cloud-fan commented on code in PR #55535:
URL: https://github.com/apache/spark/pull/55535#discussion_r3146256205
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala:
##########
@@ -3897,3 +3899,178 @@ case class TimestampDiff(
copy(startTimestamp = newLeft, endTimestamp = newRight)
}
}
+
+/**
+ * Aligns a timestamp to the start of a fixed-size interval bucket.
+ *
+ * Returns the start of the half-open bucket [start, start + bucketSize) containing ts.
+ * All computation is performed on UTC values.
+ */
+case class TimeBucket(
+ bucketSize: Expression,
+ ts: Expression,
+ originTs: Expression)
+ extends TernaryExpression with ExpectsInputTypes {
Review Comment:
Should this extend `TimeZoneAwareExpression`? Every other expression in this
file that touches `TimestampType` does — `TruncTimestamp` (2467),
`TimestampAddInterval` (1662), `TimestampAddYMInterval` (1979),
`ConvertTimezone`, `ParseToTimestamp`, etc. `TimeBucket` accepts `TIMESTAMP`
(LTZ) but always bucketizes in UTC, so e.g. in `America/Los_Angeles`:
```sql
SELECT date_trunc('MONTH', TIMESTAMP '2024-03-15 10:00:00');
-- 2024-03-01 00:00:00
SELECT time_bucket(INTERVAL '1' MONTH, TIMESTAMP '2024-03-15 10:00:00');
-- 2024-02-29 16:00:00 (March 1 00:00 UTC seen in PST)
```
Proposal: extend `TimeZoneAwareExpression`, thread `zoneId` into
`timeBucketYMInterval`, and replace `ZoneOffset.UTC` with `zoneId` at the four
`microsToDays` / `daysToMicros` sites in that helper (mirrors how
`TruncTimestamp` calls `truncTimestamp(t, level, zoneId)`). For NTZ, keep
passing `ZoneOffset.UTC`. DT can stay as-is (fixed-length intervals are
inherently zone-free, like `TimeWindow`), but the docstring should call that
out explicitly rather than imply UTC for both flavors. Hard to walk this back
once 4.2.0 ships.
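A minimal Python sketch of the arithmetic behind the two results above. It assumes the default origin (1970-01-01 00:00:00 UTC, which sits on a month boundary), so a one-month UTC bucket is simply the first of the UTC month. The fixed-offset `PDT`/`PST` objects are stand-ins for the offsets `America/Los_Angeles` uses on 2024-03-15 (UTC-7) and 2024-03-01 (UTC-8; DST began on 2024-03-10). They are not a real zone database, and the helper names are hypothetical.

```python
from datetime import datetime, timedelta, timezone

# Fixed-offset stand-ins for America/Los_Angeles on the two dates involved
# (assumption: no tz database is used; DST in 2024 began on March 10).
PDT = timezone(timedelta(hours=-7), "PDT")  # in effect on 2024-03-15
PST = timezone(timedelta(hours=-8), "PST")  # in effect on 2024-03-01


def month_bucket_utc(ts: datetime) -> datetime:
    """Hypothetical: one-month bucket computed on the UTC instant,
    assuming the default origin 1970-01-01 00:00:00 UTC."""
    u = ts.astimezone(timezone.utc)
    return u.replace(day=1, hour=0, minute=0, second=0, microsecond=0)


def month_bucket_wall_clock(ts: datetime) -> datetime:
    """Hypothetical: one-month bucket computed on the local wall clock,
    matching what date_trunc('MONTH', ...) returns."""
    local = ts.replace(tzinfo=None)  # drop the offset, keep the wall-clock fields
    return local.replace(day=1, hour=0, minute=0, second=0, microsecond=0)


# TIMESTAMP '2024-03-15 10:00:00' entered in an America/Los_Angeles session
ts = datetime(2024, 3, 15, 10, 0, tzinfo=PDT)

utc_bucket = month_bucket_utc(ts)
print(utc_bucket)                   # 2024-03-01 00:00:00+00:00
print(utc_bucket.astimezone(PST))   # 2024-02-29 16:00:00-08:00 (what the LA session displays)
print(month_bucket_wall_clock(ts))  # 2024-03-01 00:00:00 (what date_trunc returns)
```

Under the proposal, the YM path would bucket on the session-zone wall clock. An LA session would then see the third value instead of the second.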
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala:
##########
@@ -1059,4 +1059,59 @@ object DateTimeUtils extends SparkDateTimeUtils {
time, timePrecision, interval, intervalEndField)
}
}
+
+ /**
+ * DayTimeInterval bucketing: microsecond floor division against `originMicros`.
+ * Returns `originMicros + floorDiv(tsMicros - originMicros, bucketMicros) * bucketMicros`.
+ *
+ * `bucketMicros` must be positive; `TimeBucket.checkInputDataTypes` enforces
+ * this at analysis time.
+ *
+ * @param bucketMicros bucket size in microseconds.
+ * @param tsMicros timestamp to bucket, in microseconds since the epoch (UTC).
+ * @param originMicros grid alignment anchor, in microseconds since the epoch (UTC).
+ */
+ def timeBucketDTInterval(bucketMicros: Long, tsMicros: Long, originMicros: Long): Long = {
+ val diff = Math.subtractExact(tsMicros, originMicros)
+ val bucketOffset = Math.multiplyExact(Math.floorDiv(diff, bucketMicros), bucketMicros)
+ Math.addExact(originMicros, bucketOffset)
Review Comment:
Spark's convention for arithmetic with overflow checks is
`MathUtils.subtractExact` / `multiplyExact` / `addExact` (sql/api
`MathUtils.scala`), which wraps `java.lang.ArithmeticException` into
`SparkArithmeticException` with the `ARITHMETIC_OVERFLOW` error class. Using
`Math.*Exact` directly here means a user hitting an extreme `tsMicros` /
`originMicros` / `bucketMicros` combination sees a raw, non-i18n'd
`java.lang.ArithmeticException` — the unit test
`intercept[ArithmeticException]` confirms this is what propagates. Same comment
applies to `Math.subtractExact` / `multiplyExact` / `addExact` / `toIntExact`
in `timeBucketYMInterval` below.
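A Python sketch of the `timeBucketDTInterval` formula with the overflow checks made explicit. Python integers never overflow, so the hypothetical `_exact` helper emulates `Math.*Exact` by range-checking each intermediate against the `Long` bounds. The hypothetical `ArithmeticOverflowError` stands in for the `SparkArithmeticException` (`ARITHMETIC_OVERFLOW`) that `MathUtils` would surface instead of a raw `java.lang.ArithmeticException`. Python's `//` floors like `Math.floorDiv`. That is what places a timestamp just before the origin in the previous bucket rather than in the origin's own bucket.

```python
LONG_MIN, LONG_MAX = -(2**63), 2**63 - 1


class ArithmeticOverflowError(ArithmeticError):
    """Hypothetical stand-in for SparkArithmeticException(ARITHMETIC_OVERFLOW)."""


def _exact(value: int, op: str) -> int:
    # Emulate Math.addExact / subtractExact / multiplyExact on a Scala Long.
    if not LONG_MIN <= value <= LONG_MAX:
        raise ArithmeticOverflowError(f"long overflow in {op}")
    return value


def time_bucket_dt(bucket_micros: int, ts_micros: int, origin_micros: int) -> int:
    """origin + floorDiv(ts - origin, bucket) * bucket, evaluated on 64-bit longs."""
    if bucket_micros <= 0:
        # In the PR this is rejected at analysis time by checkInputDataTypes.
        raise ValueError("bucket_micros must be positive")
    diff = _exact(ts_micros - origin_micros, "subtract")
    offset = _exact((diff // bucket_micros) * bucket_micros, "multiply")  # // floors
    return _exact(origin_micros + offset, "add")


MINUTE = 60 * 1_000_000
ts = 1_704_108_420_000_000  # 2024-01-01 11:27:00 UTC, in microseconds

print(time_bucket_dt(15 * MINUTE, ts, 0) == ts - 12 * MINUTE)          # True (11:15)
print(time_bucket_dt(15 * MINUTE, ts, 5 * MINUTE) == ts - 7 * MINUTE)  # True (11:20)
print(time_bucket_dt(10, -1, 0))  # -10 (truncating division would give 0)
```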
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala:
##########
@@ -3897,3 +3899,178 @@ case class TimestampDiff(
copy(startTimestamp = newLeft, endTimestamp = newRight)
}
}
+
+/**
+ * Aligns a timestamp to the start of a fixed-size interval bucket.
+ *
+ * Returns the start of the half-open bucket [start, start + bucketSize) containing ts.
+ * All computation is performed on UTC values.
+ */
+case class TimeBucket(
+ bucketSize: Expression,
+ ts: Expression,
+ originTs: Expression)
+ extends TernaryExpression with ExpectsInputTypes {
+
+ override def nullIntolerant: Boolean = true
+
+ override def first: Expression = bucketSize
+ override def second: Expression = ts
+ override def third: Expression = originTs
+
+ override def inputTypes: Seq[AbstractDataType] = Seq(
+ TypeCollection(DayTimeIntervalType, YearMonthIntervalType),
+ AnyTimestampType,
+ AnyTimestampType)
+
+ override def dataType: DataType = ts.dataType
+
+ override def checkInputDataTypes(): TypeCheckResult = {
+ val defaultCheck = super.checkInputDataTypes()
+ if (defaultCheck.isFailure) return defaultCheck
+
+ if (!bucketSize.foldable) {
+ return DataTypeMismatch(
+ errorSubClass = "NON_FOLDABLE_INPUT",
+ messageParameters = Map(
+ "inputName" -> toSQLId("bucketSize"),
+ "inputType" -> toSQLType(bucketSize.dataType),
+ "inputExpr" -> toSQLExpr(bucketSize)))
+ }
+
+ val bucketSizeValue = bucketSize.eval()
+ if (bucketSizeValue != null) {
+ val isNonPositive = bucketSize.dataType match {
+ case _: DayTimeIntervalType => bucketSizeValue.asInstanceOf[Long] <= 0
+ case _: YearMonthIntervalType => bucketSizeValue.asInstanceOf[Int] <= 0
+ case other => throw SparkException.internalError(
+ s"Unexpected bucketSize type: $other")
+ }
+ if (isNonPositive) {
+ return DataTypeMismatch(
+ errorSubClass = "VALUE_OUT_OF_RANGE",
+ messageParameters = Map(
+ "exprName" -> "time_bucket",
+ "valueRange" -> "(0, inf)",
Review Comment:
Two issues here:
1. `exprName` should be the parameter name (the thing being constrained),
not the function name. Convention (`TimeWindow.scala:138-142`,
`timeExpressions.scala:491-495`, `max_by`'s `k` arg → ``"exprName" : "`k`"``)
is `toSQLId("<param>")`. As written, the rendered message is "The time_bucket
must be between (0, inf) ...", which reads as if `time_bucket` is itself the
value being checked.
2. `(0, inf)` reads better as a concrete bound — `TimeWindow.scala:141` uses
`s"(0, ${Long.MaxValue}]"`.
```suggestion
"exprName" -> toSQLId("bucketSize"),
"valueRange" -> s"(0, $${Long.MaxValue}]",
```
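A small Python sketch of how the two parameter choices render. It assumes the `VALUE_OUT_OF_RANGE` template reads `The <exprName> must be between <valueRange> (current value = <currentValue>).`, which is consistent with the rendering quoted in point 1. `to_sql_id` is a simplified, hypothetical stand-in for `toSQLId`, and the current-value string is illustrative.

```python
# Assumed template for DATATYPE_MISMATCH.VALUE_OUT_OF_RANGE (paraphrased).
TEMPLATE = "The {exprName} must be between {valueRange} (current value = {currentValue})."

LONG_MAX = 2**63 - 1


def to_sql_id(name: str) -> str:
    """Hypothetical simplified stand-in for toSQLId: backtick-quote an identifier."""
    return "`" + name.replace("`", "``") + "`"


def render(expr_name: str, value_range: str, current_value: str) -> str:
    return TEMPLATE.format(
        exprName=expr_name, valueRange=value_range, currentValue=current_value)


current = "INTERVAL '-15' MINUTE"  # illustrative non-positive bucket size
before = render("time_bucket", "(0, inf)", current)
after = render(to_sql_id("bucketSize"), f"(0, {LONG_MAX}]", current)
print(before)  # The time_bucket must be between (0, inf) (current value = INTERVAL '-15' MINUTE).
print(after)   # The `bucketSize` must be between (0, 9223372036854775807] (current value = INTERVAL '-15' MINUTE).
```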
##########
python/pyspark/sql/functions/builtin.py:
##########
@@ -13125,6 +13125,74 @@ def timestamp_add(unit: str, quantity: "ColumnOrName", ts: "ColumnOrName") -> Co
)
+@_try_remote_functions
+def time_bucket(
+ bucket_size: "ColumnOrName",
+ ts: "ColumnOrName",
+ origin: Optional["ColumnOrName"] = None,
+) -> Column:
+ """
+ Aligns a timestamp to the start of a fixed-size interval bucket.
+
+ Returns the start of the bucket that ``ts`` falls into, where buckets are defined by
+ the given ``bucket_size`` interval aligned to ``origin``. All bucketing is performed on
+ UTC micros, the session time zone does not affect bucket alignment. For local wall-clock
+ alignment in a DST zone, cast the TIMESTAMP to TIMESTAMP_NTZ.
+
+ .. versionadded:: 4.2.0
+
+ Parameters
+ ----------
+ bucket_size : :class:`~pyspark.sql.Column` or column name
+ A day-time or year-month interval defining the bucket size. Must be positive
+ and foldable.
+ ts : :class:`~pyspark.sql.Column` or column name
+ A TIMESTAMP or TIMESTAMP_NTZ value to bucket.
+ origin : :class:`~pyspark.sql.Column` or column name, optional
+ Alignment anchor. Defaults to 1970-01-01 00:00:00 (UTC for TIMESTAMP). Must be
+ the same type as ``ts`` and must be foldable.
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ The start of the bucket containing ``ts``, as the same type as ``ts``.
+
+ Examples
+ --------
+ >>> spark.conf.set("spark.sql.session.timeZone", "UTC")
+ >>> import datetime
+ >>> from pyspark.sql import functions as sf
+ >>> df = spark.createDataFrame(
+ ... [(datetime.datetime(2024, 1, 1, 11, 27, 0),)], ['ts'])
+ >>> df.select(
+ ... sf.time_bucket(sf.expr("INTERVAL '15' MINUTE"), 'ts').alias("bucket")
+ ... ).collect()
+ [Row(bucket=datetime.datetime(2024, 1, 1, 11, 15))]
+
+ Shift the grid with an explicit origin: buckets run at :05, :20, :35, :50:
+
+ >>> df.select(
+ ... sf.time_bucket(
+ ... sf.expr("INTERVAL '15' MINUTE"),
+ ... 'ts',
+ ... sf.expr("TIMESTAMP '1970-01-01 00:05:00'")
+ ... ).alias("bucket")
+ ... ).collect()
+ [Row(bucket=datetime.datetime(2024, 1, 1, 11, 20))]
+ >>> spark.conf.unset("spark.sql.session.timeZone")
+ """
+ from pyspark.sql.classic.column import _to_java_column
+
+ if origin is None:
+ return _invoke_function("time_bucket", _to_java_column(bucket_size), _to_java_column(ts))
+ return _invoke_function(
+ "time_bucket",
+ _to_java_column(bucket_size),
+ _to_java_column(ts),
+ _to_java_column(origin),
+ )
Review Comment:
Recent additions in this file (`time_diff`, `time_trunc`,
`to_timestamp_ltz`) use `_invoke_function_over_columns`, which already maps
`_to_java_column` for you internally — equivalent here, but easier to read and
consistent with the neighbors:
```suggestion
if origin is None:
return _invoke_function_over_columns("time_bucket", bucket_size, ts)
return _invoke_function_over_columns("time_bucket", bucket_size, ts, origin)
```
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala:
##########
@@ -3897,3 +3899,178 @@ case class TimestampDiff(
copy(startTimestamp = newLeft, endTimestamp = newRight)
}
}
+
+/**
+ * Aligns a timestamp to the start of a fixed-size interval bucket.
+ *
+ * Returns the start of the half-open bucket [start, start + bucketSize) containing ts.
+ * All computation is performed on UTC values.
+ */
+case class TimeBucket(
+ bucketSize: Expression,
+ ts: Expression,
+ originTs: Expression)
+ extends TernaryExpression with ExpectsInputTypes {
+
+ override def nullIntolerant: Boolean = true
+
+ override def first: Expression = bucketSize
+ override def second: Expression = ts
+ override def third: Expression = originTs
+
+ override def inputTypes: Seq[AbstractDataType] = Seq(
+ TypeCollection(DayTimeIntervalType, YearMonthIntervalType),
+ AnyTimestampType,
+ AnyTimestampType)
+
+ override def dataType: DataType = ts.dataType
+
+ override def checkInputDataTypes(): TypeCheckResult = {
+ val defaultCheck = super.checkInputDataTypes()
+ if (defaultCheck.isFailure) return defaultCheck
+
+ if (!bucketSize.foldable) {
+ return DataTypeMismatch(
+ errorSubClass = "NON_FOLDABLE_INPUT",
+ messageParameters = Map(
+ "inputName" -> toSQLId("bucketSize"),
+ "inputType" -> toSQLType(bucketSize.dataType),
+ "inputExpr" -> toSQLExpr(bucketSize)))
+ }
+
+ val bucketSizeValue = bucketSize.eval()
+ if (bucketSizeValue != null) {
+ val isNonPositive = bucketSize.dataType match {
+ case _: DayTimeIntervalType => bucketSizeValue.asInstanceOf[Long] <= 0
+ case _: YearMonthIntervalType => bucketSizeValue.asInstanceOf[Int] <= 0
+ case other => throw SparkException.internalError(
+ s"Unexpected bucketSize type: $other")
+ }
+ if (isNonPositive) {
+ return DataTypeMismatch(
+ errorSubClass = "VALUE_OUT_OF_RANGE",
+ messageParameters = Map(
+ "exprName" -> "time_bucket",
+ "valueRange" -> "(0, inf)",
+ "currentValue" -> toSQLValue(bucketSizeValue,
bucketSize.dataType)))
+ }
+ }
+
+ if (!originTs.foldable) {
+ return DataTypeMismatch(
+ errorSubClass = "NON_FOLDABLE_INPUT",
+ messageParameters = Map(
+ "inputName" -> toSQLId("origin"),
+ "inputType" -> toSQLType(originTs.dataType),
+ "inputExpr" -> toSQLExpr(originTs)))
+ }
+
+ if (ts.dataType != originTs.dataType) {
+ return DataTypeMismatch(
+ errorSubClass = "UNEXPECTED_INPUT_TYPE",
+ messageParameters = Map(
+ "paramIndex" -> ordinalNumber(2),
+ "requiredType" -> toSQLType(ts.dataType),
+ "inputSql" -> toSQLExpr(originTs),
+ "inputType" -> toSQLType(originTs.dataType)))
+ }
+
+ TypeCheckSuccess
+ }
+
+ override def nullSafeEval(bucketSizeVal: Any, tsVal: Any, originVal: Any): Any = {
+ first.dataType match {
+ case _: DayTimeIntervalType =>
+ DateTimeUtils.timeBucketDTInterval(
+ bucketSizeVal.asInstanceOf[Long], tsVal.asInstanceOf[Long],
+ originVal.asInstanceOf[Long])
+ case _: YearMonthIntervalType =>
+ DateTimeUtils.timeBucketYMInterval(
+ bucketSizeVal.asInstanceOf[Int], tsVal.asInstanceOf[Long],
+ originVal.asInstanceOf[Long])
+ case other => throw SparkException.internalError(
+ s"Unexpected bucketSize type: $other")
+ }
+ }
+
+ override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
+ first.dataType match {
+ case _: DayTimeIntervalType =>
+ defineCodeGen(ctx, ev, (bucketSizeCode, tsCode, originCode) =>
+ s"$dtu.timeBucketDTInterval($bucketSizeCode, $tsCode, $originCode)")
+ case _: YearMonthIntervalType =>
+ defineCodeGen(ctx, ev, (bucketSizeCode, tsCode, originCode) =>
+ s"$dtu.timeBucketYMInterval($bucketSizeCode, $tsCode, $originCode)")
+ case other => throw SparkException.internalError(
+ s"Unexpected bucketSize type: $other")
+ }
+ }
+
+ override def prettyName: String = "time_bucket"
+
+ override protected def withNewChildrenInternal(
+ newFirst: Expression, newSecond: Expression, newThird: Expression): TimeBucket =
+ copy(bucketSize = newFirst, ts = newSecond, originTs = newThird)
+}
+
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = """
+ _FUNC_(bucketSize, ts[, origin]) - Returns the start of the bucket that `ts` falls into,
+ where buckets are defined by the given `bucketSize` interval aligned to `origin`. All
+ bucketing is performed on UTC micros, the session time zone does not affect bucket
Review Comment:
Minor: comma splice — two independent clauses joined by a comma.
```suggestion
bucketing is performed on UTC micros; the session time zone does not affect bucket
```
##########
python/pyspark/sql/functions/builtin.py:
##########
@@ -13125,6 +13125,74 @@ def timestamp_add(unit: str, quantity: "ColumnOrName", ts: "ColumnOrName") -> Co
)
+@_try_remote_functions
+def time_bucket(
+ bucket_size: "ColumnOrName",
+ ts: "ColumnOrName",
+ origin: Optional["ColumnOrName"] = None,
+) -> Column:
+ """
+ Aligns a timestamp to the start of a fixed-size interval bucket.
+
+ Returns the start of the bucket that ``ts`` falls into, where buckets are defined by
+ the given ``bucket_size`` interval aligned to ``origin``. All bucketing is performed on
+ UTC micros, the session time zone does not affect bucket alignment. For local wall-clock
Review Comment:
Same comma splice as in the Scala `@ExpressionDescription`.
```suggestion
the given ``bucket_size`` interval aligned to ``origin``. All bucketing is performed on
UTC micros; the session time zone does not affect bucket alignment. For local wall-clock
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.