vranes commented on code in PR #55535:
URL: https://github.com/apache/spark/pull/55535#discussion_r3163463520
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala:
##########
@@ -3897,3 +3899,197 @@ case class TimestampDiff(
copy(startTimestamp = newLeft, endTimestamp = newRight)
}
}
+
+/**
+ * Aligns a timestamp to the start of a fixed-size interval bucket.
+ *
+ * Returns the start of the half-open bucket [start, start + bucketSize)
containing ts.
+ * For TIMESTAMP_NTZ, bucketing is performed in UTC. For TIMESTAMP, buckets
align to
+ * the session time zone.
+ */
+case class TimeBucket(
+ bucketSize: Expression,
+ ts: Expression,
+ originTs: Expression,
+ timeZoneId: Option[String] = None)
+ extends TernaryExpression with ExpectsInputTypes with
TimeZoneAwareExpression {
+
+ override def nullIntolerant: Boolean = true
+
+ override def first: Expression = bucketSize
+ override def second: Expression = ts
+ override def third: Expression = originTs
+
+ override def withTimeZone(timeZoneId: String): TimeBucket =
+ copy(timeZoneId = Option(timeZoneId))
+
+ def this(bucketSize: Expression, ts: Expression, originTs: Expression) =
+ this(bucketSize, ts, originTs, None)
+
+ override def inputTypes: Seq[AbstractDataType] = Seq(
+ TypeCollection(DayTimeIntervalType, YearMonthIntervalType),
+ AnyTimestampType,
+ AnyTimestampType)
+
+ override def dataType: DataType = ts.dataType
+
+ override def checkInputDataTypes(): TypeCheckResult = {
+ val defaultCheck = super.checkInputDataTypes()
+ if (defaultCheck.isFailure) return defaultCheck
+
+ if (!bucketSize.foldable) {
+ return DataTypeMismatch(
+ errorSubClass = "NON_FOLDABLE_INPUT",
+ messageParameters = Map(
+ "inputName" -> toSQLId("bucketSize"),
+ "inputType" -> toSQLType(bucketSize.dataType),
+ "inputExpr" -> toSQLExpr(bucketSize)))
+ }
+
+ val bucketSizeValue = bucketSize.eval()
+ if (bucketSizeValue != null) {
+ val isNonPositive = bucketSize.dataType match {
+ case _: DayTimeIntervalType => bucketSizeValue.asInstanceOf[Long] <= 0
+ case _: YearMonthIntervalType => bucketSizeValue.asInstanceOf[Int] <= 0
+ case other => throw SparkException.internalError(
+ s"Unexpected bucketSize type: $other")
+ }
+ if (isNonPositive) {
+ return DataTypeMismatch(
+ errorSubClass = "VALUE_OUT_OF_RANGE",
+ messageParameters = Map(
+ "exprName" -> toSQLId("bucketSize"),
+ "valueRange" -> "(0, inf)",
+ "currentValue" -> toSQLValue(bucketSizeValue,
bucketSize.dataType)))
+ }
+ }
+
+ if (!originTs.foldable) {
+ return DataTypeMismatch(
+ errorSubClass = "NON_FOLDABLE_INPUT",
+ messageParameters = Map(
+ "inputName" -> toSQLId("origin"),
+ "inputType" -> toSQLType(originTs.dataType),
+ "inputExpr" -> toSQLExpr(originTs)))
+ }
+
+ if (ts.dataType != originTs.dataType) {
+ return DataTypeMismatch(
+ errorSubClass = "UNEXPECTED_INPUT_TYPE",
+ messageParameters = Map(
+ "paramIndex" -> ordinalNumber(2),
+ "requiredType" -> toSQLType(ts.dataType),
+ "inputSql" -> toSQLExpr(originTs),
+ "inputType" -> toSQLType(originTs.dataType)))
+ }
+
+ TypeCheckSuccess
+ }
+
+ override def nullSafeEval(bucketSizeVal: Any, tsVal: Any, originVal: Any):
Any = {
+ first.dataType match {
+ case _: DayTimeIntervalType =>
+ DateTimeUtils.timeBucketDTInterval(
+ bucketSizeVal.asInstanceOf[Long], tsVal.asInstanceOf[Long],
+ originVal.asInstanceOf[Long], zoneIdForType(ts.dataType))
Review Comment:
Cached `zoneIdForType(ts.dataType)` as `@transient private lazy val
zoneIdInEval`. Both `nullSafeEval` and `doGenCode` now reference
`zoneIdInEval`. Thanks!
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]