vranes commented on code in PR #55535:
URL: https://github.com/apache/spark/pull/55535#discussion_r3161738884
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala:
##########
@@ -3897,3 +3899,178 @@ case class TimestampDiff(
copy(startTimestamp = newLeft, endTimestamp = newRight)
}
}
+
/**
 * Aligns a timestamp to the start of a fixed-size interval bucket.
 *
 * Returns the start of the half-open bucket [start, start + bucketSize) containing `ts`, where
 * bucket boundaries are aligned to `originTs`. All computation is performed on UTC values.
 * Returns null if any input is null.
 *
 * @param bucketSize a foldable, strictly positive day-time or year-month interval giving the
 *                   width of each bucket
 * @param ts         the timestamp to bucket; the result has the same data type
 * @param originTs   a foldable timestamp with the same data type as `ts` that buckets are
 *                   aligned to
 */
case class TimeBucket(
    bucketSize: Expression,
    ts: Expression,
    originTs: Expression)
  extends TernaryExpression with ExpectsInputTypes {

  override def nullIntolerant: Boolean = true

  override def first: Expression = bucketSize
  override def second: Expression = ts
  override def third: Expression = originTs

  override def inputTypes: Seq[AbstractDataType] = Seq(
    TypeCollection(DayTimeIntervalType, YearMonthIntervalType),
    AnyTimestampType,
    AnyTimestampType)

  override def dataType: DataType = ts.dataType

  override def checkInputDataTypes(): TypeCheckResult = {
    val defaultCheck = super.checkInputDataTypes()
    if (defaultCheck.isFailure) return defaultCheck

    // bucketSize must be a constant: its value is evaluated and range-checked below.
    if (!bucketSize.foldable) {
      return DataTypeMismatch(
        errorSubClass = "NON_FOLDABLE_INPUT",
        messageParameters = Map(
          "inputName" -> toSQLId("bucketSize"),
          "inputType" -> toSQLType(bucketSize.dataType),
          "inputExpr" -> toSQLExpr(bucketSize)))
    }

    // A null bucketSize passes analysis; being null-intolerant, the expression yields null.
    val bucketSizeValue = bucketSize.eval()
    if (bucketSizeValue != null) {
      val isNonPositive = bucketSize.dataType match {
        case _: DayTimeIntervalType => bucketSizeValue.asInstanceOf[Long] <= 0
        case _: YearMonthIntervalType => bucketSizeValue.asInstanceOf[Int] <= 0
        case other => unexpectedBucketSizeType(other)
      }
      if (isNonPositive) {
        return DataTypeMismatch(
          errorSubClass = "VALUE_OUT_OF_RANGE",
          messageParameters = Map(
            // Name the offending argument (not the function), like the other checks here do.
            "exprName" -> toSQLId("bucketSize"),
            "valueRange" -> "(0, inf)",
            "currentValue" -> toSQLValue(bucketSizeValue, bucketSize.dataType)))
      }
    }

    if (!originTs.foldable) {
      return DataTypeMismatch(
        errorSubClass = "NON_FOLDABLE_INPUT",
        messageParameters = Map(
          "inputName" -> toSQLId("origin"),
          "inputType" -> toSQLType(originTs.dataType),
          "inputExpr" -> toSQLExpr(originTs)))
    }

    // inputTypes accepts any timestamp type for both ts and origin; they must also agree.
    if (ts.dataType != originTs.dataType) {
      return DataTypeMismatch(
        errorSubClass = "UNEXPECTED_INPUT_TYPE",
        messageParameters = Map(
          "paramIndex" -> ordinalNumber(2),
          "requiredType" -> toSQLType(ts.dataType),
          "inputSql" -> toSQLExpr(originTs),
          "inputType" -> toSQLType(originTs.dataType)))
    }

    TypeCheckSuccess
  }

  override def nullSafeEval(bucketSizeVal: Any, tsVal: Any, originVal: Any): Any = {
    first.dataType match {
      case _: DayTimeIntervalType =>
        DateTimeUtils.timeBucketDTInterval(
          bucketSizeVal.asInstanceOf[Long], tsVal.asInstanceOf[Long],
          originVal.asInstanceOf[Long])
      case _: YearMonthIntervalType =>
        DateTimeUtils.timeBucketYMInterval(
          bucketSizeVal.asInstanceOf[Int], tsVal.asInstanceOf[Long],
          originVal.asInstanceOf[Long])
      case other => unexpectedBucketSizeType(other)
    }
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
    first.dataType match {
      case _: DayTimeIntervalType =>
        defineCodeGen(ctx, ev, (bucketSizeCode, tsCode, originCode) =>
          s"$dtu.timeBucketDTInterval($bucketSizeCode, $tsCode, $originCode)")
      case _: YearMonthIntervalType =>
        defineCodeGen(ctx, ev, (bucketSizeCode, tsCode, originCode) =>
          s"$dtu.timeBucketYMInterval($bucketSizeCode, $tsCode, $originCode)")
      case other => unexpectedBucketSizeType(other)
    }
  }

  /** Fails on a bucketSize type that `inputTypes` should already have rejected. */
  private def unexpectedBucketSizeType(other: DataType): Nothing =
    throw SparkException.internalError(s"Unexpected bucketSize type: $other")

  override def prettyName: String = "time_bucket"

  override protected def withNewChildrenInternal(
      newFirst: Expression, newSecond: Expression, newThird: Expression): TimeBucket =
    copy(bucketSize = newFirst, ts = newSecond, originTs = newThird)
}
+
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = """
+ _FUNC_(bucketSize, ts[, origin]) - Returns the start of the bucket that
`ts` falls into,
+ where buckets are defined by the given `bucketSize` interval aligned to
`origin`. All
+ bucketing is performed on UTC micros, the session time zone does not
affect bucket
Review Comment:
Changed the docstring; resolving.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]