cloud-fan commented on code in PR #55535:
URL: https://github.com/apache/spark/pull/55535#discussion_r3146256205


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala:
##########
@@ -3897,3 +3899,178 @@ case class TimestampDiff(
     copy(startTimestamp = newLeft, endTimestamp = newRight)
   }
 }
+
+/**
+ * Aligns a timestamp to the start of a fixed-size interval bucket.
+ *
+ * Returns the start of the half-open bucket [start, start + bucketSize) containing ts.
+ * All computation is performed on UTC values.
+ */
+case class TimeBucket(
+    bucketSize: Expression,
+    ts: Expression,
+    originTs: Expression)
+  extends TernaryExpression with ExpectsInputTypes {

Review Comment:
   Should this extend `TimeZoneAwareExpression`? Every other expression in this 
file that touches `TimestampType` does — `TruncTimestamp` (2467), 
`TimestampAddInterval` (1662), `TimestampAddYMInterval` (1979), 
`ConvertTimezone`, `ParseToTimestamp`, etc. `TimeBucket` accepts `TIMESTAMP` 
(LTZ) but always bucketizes in UTC, so e.g. in `America/Los_Angeles`:
   
   ```sql
   SELECT date_trunc('MONTH', TIMESTAMP '2024-03-15 10:00:00');
   -- 2024-03-01 00:00:00
   SELECT time_bucket(INTERVAL '1' MONTH, TIMESTAMP '2024-03-15 10:00:00');
   -- 2024-02-29 16:00:00   (March 1 00:00 UTC seen in PST)
   ```
   
   Proposal: extend `TimeZoneAwareExpression`, thread `zoneId` into 
`timeBucketYMInterval`, and replace `ZoneOffset.UTC` with `zoneId` at the four 
`microsToDays` / `daysToMicros` sites in that helper (mirrors how 
`TruncTimestamp` calls `truncTimestamp(t, level, zoneId)`). For NTZ, keep 
passing `ZoneOffset.UTC`. The day-time (DT) interval path can stay as-is 
(fixed-length intervals are inherently zone-free, like `TimeWindow`), but the 
docstring should call that out explicitly rather than imply UTC for both 
flavors. This behavior will be hard to change once 4.2.0 ships.
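   
   To make the difference concrete, here is a minimal Python sketch (not Spark 
code) of the two behaviors for a one-month bucket with the default epoch 
origin. The function names are hypothetical, and it assumes the system has 
time zone data available for `zoneinfo`:

```python
from datetime import datetime, timezone
from zoneinfo import ZoneInfo

LA = ZoneInfo("America/Los_Angeles")


def month_bucket_utc(ts: datetime) -> datetime:
    """Start of the month containing ts, computed on the UTC calendar."""
    as_utc = ts.astimezone(timezone.utc)
    return as_utc.replace(day=1, hour=0, minute=0, second=0, microsecond=0)


def month_bucket_in_zone(ts: datetime, zone: ZoneInfo) -> datetime:
    """Start of the month containing ts, computed on the zone's wall clock."""
    local = ts.astimezone(zone)
    return local.replace(day=1, hour=0, minute=0, second=0, microsecond=0)


ts = datetime(2024, 3, 15, 10, 0, tzinfo=LA)

# UTC month boundary, displayed in the session zone.
utc_start = month_bucket_utc(ts).astimezone(LA)
# Month boundary on the Los Angeles wall clock.
zone_start = month_bucket_in_zone(ts, LA)

print(utc_start.replace(tzinfo=None))   # 2024-02-29 16:00:00
print(zone_start.replace(tzinfo=None))  # 2024-03-01 00:00:00
```

   The first result matches the `time_bucket` output above; the second matches 
`date_trunc`. Threading `zoneId` through the year-month helper would move 
`time_bucket` from the first behavior to the second for LTZ inputs.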



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala:
##########
@@ -1059,4 +1059,59 @@ object DateTimeUtils extends SparkDateTimeUtils {
         time, timePrecision, interval, intervalEndField)
     }
   }
+
+  /**
+   * DayTimeInterval bucketing: microsecond floor division against `originMicros`.
+   * Returns `originMicros + floorDiv(tsMicros - originMicros, bucketMicros) * bucketMicros`.
+   *
+   * `bucketMicros` must be positive; `TimeBucket.checkInputDataTypes` enforces
+   * this at analysis time.
+   *
+   * @param bucketMicros bucket size in microseconds.
+   * @param tsMicros     timestamp to bucket, in microseconds since the epoch (UTC).
+   * @param originMicros grid alignment anchor, in microseconds since the epoch (UTC).
+   */
+  def timeBucketDTInterval(bucketMicros: Long, tsMicros: Long, originMicros: Long): Long = {
+    val diff = Math.subtractExact(tsMicros, originMicros)
+    val bucketOffset = Math.multiplyExact(Math.floorDiv(diff, bucketMicros), bucketMicros)
+    Math.addExact(originMicros, bucketOffset)

Review Comment:
   Spark's convention for arithmetic with overflow checks is 
`MathUtils.subtractExact` / `multiplyExact` / `addExact` (sql/api 
`MathUtils.scala`), which wraps `java.lang.ArithmeticException` into 
`SparkArithmeticException` with the `ARITHMETIC_OVERFLOW` error class. Using 
`Math.*Exact` directly here means a user hitting an extreme `tsMicros` / 
`originMicros` / `bucketMicros` combination sees a raw, non-i18n'd 
`java.lang.ArithmeticException` — the unit test 
`intercept[ArithmeticException]` confirms this is what propagates. Same comment 
applies to `Math.subtractExact` / `multiplyExact` / `addExact` / `toIntExact` 
in `timeBucketYMInterval` below.
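   
   As a rough illustration of the intent, the Python sketch below mirrors the 
day-time formula and routes every step through one checked helper that raises 
a structured error. `ArithmeticOverflow` and `_checked` are hypothetical 
stand-ins, not Spark APIs, and the 64-bit range check only simulates what the 
JVM `*Exact` methods do natively:

```python
INT64_MIN = -(2**63)
INT64_MAX = 2**63 - 1


class ArithmeticOverflow(ArithmeticError):
    """Hypothetical stand-in for an error that carries an error class."""

    def __init__(self, op: str) -> None:
        super().__init__(f"[ARITHMETIC_OVERFLOW] long overflow in {op}")
        self.error_class = "ARITHMETIC_OVERFLOW"


def _checked(value: int, op: str) -> int:
    """Return value if it fits in a signed 64-bit long, else raise."""
    if not INT64_MIN <= value <= INT64_MAX:
        raise ArithmeticOverflow(op)
    return value


def time_bucket_dt(bucket_micros: int, ts_micros: int, origin_micros: int) -> int:
    """origin + floorDiv(ts - origin, bucket) * bucket, with overflow checks."""
    if bucket_micros <= 0:
        raise ValueError("bucket_micros must be positive")
    diff = _checked(ts_micros - origin_micros, "subtract")
    # Python's // floors toward negative infinity, like Math.floorDiv, so
    # timestamps before the origin land in the bucket that starts earlier.
    offset = _checked((diff // bucket_micros) * bucket_micros, "multiply")
    return _checked(origin_micros + offset, "add")


FIFTEEN_MIN = 15 * 60 * 1_000_000
ts_1127 = (11 * 3600 + 27 * 60) * 1_000_000

print(time_bucket_dt(FIFTEEN_MIN, ts_1127, 0))              # 40500000000 (11:15)
print(time_bucket_dt(FIFTEEN_MIN, ts_1127, 5 * 60 * 10**6))  # 40800000000 (11:20)
print(time_bucket_dt(FIFTEEN_MIN, -1, 0))                   # -900000000
```

   With this shape, an extreme input surfaces as one well-defined error type 
with an error class attached, which is the property the `MathUtils` wrappers 
provide on the Scala side.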



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala:
##########
@@ -3897,3 +3899,178 @@ case class TimestampDiff(
     copy(startTimestamp = newLeft, endTimestamp = newRight)
   }
 }
+
+/**
+ * Aligns a timestamp to the start of a fixed-size interval bucket.
+ *
+ * Returns the start of the half-open bucket [start, start + bucketSize) containing ts.
+ * All computation is performed on UTC values.
+ */
+case class TimeBucket(
+    bucketSize: Expression,
+    ts: Expression,
+    originTs: Expression)
+  extends TernaryExpression with ExpectsInputTypes {
+
+  override def nullIntolerant: Boolean = true
+
+  override def first: Expression = bucketSize
+  override def second: Expression = ts
+  override def third: Expression = originTs
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(
+    TypeCollection(DayTimeIntervalType, YearMonthIntervalType),
+    AnyTimestampType,
+    AnyTimestampType)
+
+  override def dataType: DataType = ts.dataType
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    val defaultCheck = super.checkInputDataTypes()
+    if (defaultCheck.isFailure) return defaultCheck
+
+    if (!bucketSize.foldable) {
+      return DataTypeMismatch(
+        errorSubClass = "NON_FOLDABLE_INPUT",
+        messageParameters = Map(
+          "inputName" -> toSQLId("bucketSize"),
+          "inputType" -> toSQLType(bucketSize.dataType),
+          "inputExpr" -> toSQLExpr(bucketSize)))
+    }
+
+    val bucketSizeValue = bucketSize.eval()
+    if (bucketSizeValue != null) {
+      val isNonPositive = bucketSize.dataType match {
+        case _: DayTimeIntervalType => bucketSizeValue.asInstanceOf[Long] <= 0
+        case _: YearMonthIntervalType => bucketSizeValue.asInstanceOf[Int] <= 0
+        case other => throw SparkException.internalError(
+          s"Unexpected bucketSize type: $other")
+      }
+      if (isNonPositive) {
+        return DataTypeMismatch(
+          errorSubClass = "VALUE_OUT_OF_RANGE",
+          messageParameters = Map(
+            "exprName" -> "time_bucket",
+            "valueRange" -> "(0, inf)",

Review Comment:
   Two issues here:
   
   1. `exprName` should be the parameter name (the thing being constrained), 
not the function name. Convention (`TimeWindow.scala:138-142`, 
`timeExpressions.scala:491-495`, `max_by`'s `k` arg → `"exprName" : "\`k\`"`) 
is `toSQLId("<param>")`. As written, the rendered message is "The time_bucket 
must be between (0, inf) ...", which reads as if `time_bucket` is itself the 
value being checked.
   2. `(0, inf)` reads better as a concrete bound — `TimeWindow.scala:141` uses 
`s"(0, ${Long.MaxValue}]"`.
   
   ```suggestion
               "exprName" -> toSQLId("bucketSize"),
               "valueRange" -> s"(0, ${Long.MaxValue}]",
   ```
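   
   For a quick before/after of the rendered text, here is a small Python 
sketch. It uses only the opening of the message as quoted above (the full 
template is longer), and `to_sql_id` is a hypothetical stand-in for `toSQLId` 
that assumes backtick quoting:

```python
INT64_MAX = 2**63 - 1  # same value as Scala's Long.MaxValue

# Only the opening of the message, as quoted in the comment above; the real
# template continues past this point.
TEMPLATE_PREFIX = "The {exprName} must be between {valueRange}"


def to_sql_id(name: str) -> str:
    """Hypothetical stand-in for toSQLId: quote an identifier with backticks."""
    return "`" + name.replace("`", "``") + "`"


def render(expr_name: str, value_range: str) -> str:
    """Fill the two placeholders this comment is about."""
    return TEMPLATE_PREFIX.format(exprName=expr_name, valueRange=value_range)


before = render("time_bucket", "(0, inf)")
after = render(to_sql_id("bucketSize"), f"(0, {INT64_MAX}]")

print(before)  # The time_bucket must be between (0, inf)
print(after)   # The `bucketSize` must be between (0, 9223372036854775807]
```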



##########
python/pyspark/sql/functions/builtin.py:
##########
@@ -13125,6 +13125,74 @@ def timestamp_add(unit: str, quantity: "ColumnOrName", ts: "ColumnOrName") -> Co
     )
 
 
+@_try_remote_functions
+def time_bucket(
+    bucket_size: "ColumnOrName",
+    ts: "ColumnOrName",
+    origin: Optional["ColumnOrName"] = None,
+) -> Column:
+    """
+    Aligns a timestamp to the start of a fixed-size interval bucket.
+
+    Returns the start of the bucket that ``ts`` falls into, where buckets are defined by
+    the given ``bucket_size`` interval aligned to ``origin``. All bucketing is performed on
+    UTC micros, the session time zone does not affect bucket alignment. For local wall-clock
+    alignment in a DST zone, cast the TIMESTAMP to TIMESTAMP_NTZ.
+
+    .. versionadded:: 4.2.0
+
+    Parameters
+    ----------
+    bucket_size : :class:`~pyspark.sql.Column` or column name
+        A day-time or year-month interval defining the bucket size. Must be positive
+        and foldable.
+    ts : :class:`~pyspark.sql.Column` or column name
+        A TIMESTAMP or TIMESTAMP_NTZ value to bucket.
+    origin : :class:`~pyspark.sql.Column` or column name, optional
+        Alignment anchor. Defaults to 1970-01-01 00:00:00 (UTC for TIMESTAMP). Must be
+        the same type as ``ts`` and must be foldable.
+
+    Returns
+    -------
+    :class:`~pyspark.sql.Column`
+        The start of the bucket containing ``ts``, as the same type as ``ts``.
+
+    Examples
+    --------
+    >>> spark.conf.set("spark.sql.session.timeZone", "UTC")
+    >>> import datetime
+    >>> from pyspark.sql import functions as sf
+    >>> df = spark.createDataFrame(
+    ...     [(datetime.datetime(2024, 1, 1, 11, 27, 0),)], ['ts'])
+    >>> df.select(
+    ...     sf.time_bucket(sf.expr("INTERVAL '15' MINUTE"), 'ts').alias("bucket")
+    ... ).collect()
+    [Row(bucket=datetime.datetime(2024, 1, 1, 11, 15))]
+
+    Shift the grid with an explicit origin: buckets run at :05, :20, :35, :50:
+
+    >>> df.select(
+    ...     sf.time_bucket(
+    ...         sf.expr("INTERVAL '15' MINUTE"),
+    ...         'ts',
+    ...         sf.expr("TIMESTAMP '1970-01-01 00:05:00'")
+    ...     ).alias("bucket")
+    ... ).collect()
+    [Row(bucket=datetime.datetime(2024, 1, 1, 11, 20))]
+    >>> spark.conf.unset("spark.sql.session.timeZone")
+    """
+    from pyspark.sql.classic.column import _to_java_column
+
+    if origin is None:
+        return _invoke_function("time_bucket", _to_java_column(bucket_size), _to_java_column(ts))
+    return _invoke_function(
+        "time_bucket",
+        _to_java_column(bucket_size),
+        _to_java_column(ts),
+        _to_java_column(origin),
+    )

Review Comment:
   Recent additions in this file (`time_diff`, `time_trunc`, 
`to_timestamp_ltz`) use `_invoke_function_over_columns`, which already maps 
`_to_java_column` for you internally — equivalent here, but easier to read and 
consistent with the neighbors:
   
   ```suggestion
       if origin is None:
           return _invoke_function_over_columns("time_bucket", bucket_size, ts)
       return _invoke_function_over_columns("time_bucket", bucket_size, ts, origin)
   ```
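   
   The equivalence can be sketched outside PySpark with stand-in helpers. All 
three helper functions below are hypothetical simplifications that only 
record their arguments; they are not the real PySpark internals:

```python
from typing import Any, Optional, Tuple


def to_java_column(col: Any) -> Tuple[str, Any]:
    """Hypothetical stand-in for the Column-or-name to JVM column conversion."""
    return ("jcol", col)


def invoke_function(name: str, *args: Any) -> Tuple[str, Tuple[Any, ...]]:
    """Hypothetical stand-in for the JVM call; records the name and arguments."""
    return (name, args)


def invoke_function_over_columns(name: str, *cols: Any) -> Tuple[str, Tuple[Any, ...]]:
    """Convert every column argument, then delegate to invoke_function."""
    return invoke_function(name, *[to_java_column(c) for c in cols])


def time_bucket_explicit(bucket_size: Any, ts: Any, origin: Optional[Any] = None):
    """Shape of the code under review: convert each argument by hand."""
    if origin is None:
        return invoke_function(
            "time_bucket", to_java_column(bucket_size), to_java_column(ts))
    return invoke_function(
        "time_bucket",
        to_java_column(bucket_size),
        to_java_column(ts),
        to_java_column(origin),
    )


def time_bucket_concise(bucket_size: Any, ts: Any, origin: Optional[Any] = None):
    """Shape of the suggestion: let the helper do the conversion."""
    if origin is None:
        return invoke_function_over_columns("time_bucket", bucket_size, ts)
    return invoke_function_over_columns("time_bucket", bucket_size, ts, origin)


print(time_bucket_explicit("size", "ts") == time_bucket_concise("size", "ts"))
print(time_bucket_explicit("size", "ts", "o") == time_bucket_concise("size", "ts", "o"))
```

   Both calls print `True` under these stand-ins, so the change is purely about 
readability and consistency with neighboring functions.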



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala:
##########
@@ -3897,3 +3899,178 @@ case class TimestampDiff(
     copy(startTimestamp = newLeft, endTimestamp = newRight)
   }
 }
+
+/**
+ * Aligns a timestamp to the start of a fixed-size interval bucket.
+ *
+ * Returns the start of the half-open bucket [start, start + bucketSize) containing ts.
+ * All computation is performed on UTC values.
+ */
+case class TimeBucket(
+    bucketSize: Expression,
+    ts: Expression,
+    originTs: Expression)
+  extends TernaryExpression with ExpectsInputTypes {
+
+  override def nullIntolerant: Boolean = true
+
+  override def first: Expression = bucketSize
+  override def second: Expression = ts
+  override def third: Expression = originTs
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(
+    TypeCollection(DayTimeIntervalType, YearMonthIntervalType),
+    AnyTimestampType,
+    AnyTimestampType)
+
+  override def dataType: DataType = ts.dataType
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    val defaultCheck = super.checkInputDataTypes()
+    if (defaultCheck.isFailure) return defaultCheck
+
+    if (!bucketSize.foldable) {
+      return DataTypeMismatch(
+        errorSubClass = "NON_FOLDABLE_INPUT",
+        messageParameters = Map(
+          "inputName" -> toSQLId("bucketSize"),
+          "inputType" -> toSQLType(bucketSize.dataType),
+          "inputExpr" -> toSQLExpr(bucketSize)))
+    }
+
+    val bucketSizeValue = bucketSize.eval()
+    if (bucketSizeValue != null) {
+      val isNonPositive = bucketSize.dataType match {
+        case _: DayTimeIntervalType => bucketSizeValue.asInstanceOf[Long] <= 0
+        case _: YearMonthIntervalType => bucketSizeValue.asInstanceOf[Int] <= 0
+        case other => throw SparkException.internalError(
+          s"Unexpected bucketSize type: $other")
+      }
+      if (isNonPositive) {
+        return DataTypeMismatch(
+          errorSubClass = "VALUE_OUT_OF_RANGE",
+          messageParameters = Map(
+            "exprName" -> "time_bucket",
+            "valueRange" -> "(0, inf)",
+            "currentValue" -> toSQLValue(bucketSizeValue, bucketSize.dataType)))
+      }
+    }
+
+    if (!originTs.foldable) {
+      return DataTypeMismatch(
+        errorSubClass = "NON_FOLDABLE_INPUT",
+        messageParameters = Map(
+          "inputName" -> toSQLId("origin"),
+          "inputType" -> toSQLType(originTs.dataType),
+          "inputExpr" -> toSQLExpr(originTs)))
+    }
+
+    if (ts.dataType != originTs.dataType) {
+      return DataTypeMismatch(
+        errorSubClass = "UNEXPECTED_INPUT_TYPE",
+        messageParameters = Map(
+          "paramIndex" -> ordinalNumber(2),
+          "requiredType" -> toSQLType(ts.dataType),
+          "inputSql" -> toSQLExpr(originTs),
+          "inputType" -> toSQLType(originTs.dataType)))
+    }
+
+    TypeCheckSuccess
+  }
+
+  override def nullSafeEval(bucketSizeVal: Any, tsVal: Any, originVal: Any): Any = {
+    first.dataType match {
+      case _: DayTimeIntervalType =>
+        DateTimeUtils.timeBucketDTInterval(
+          bucketSizeVal.asInstanceOf[Long], tsVal.asInstanceOf[Long],
+          originVal.asInstanceOf[Long])
+      case _: YearMonthIntervalType =>
+        DateTimeUtils.timeBucketYMInterval(
+          bucketSizeVal.asInstanceOf[Int], tsVal.asInstanceOf[Long],
+          originVal.asInstanceOf[Long])
+      case other => throw SparkException.internalError(
+        s"Unexpected bucketSize type: $other")
+    }
+  }
+
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
+    first.dataType match {
+      case _: DayTimeIntervalType =>
+        defineCodeGen(ctx, ev, (bucketSizeCode, tsCode, originCode) =>
+          s"$dtu.timeBucketDTInterval($bucketSizeCode, $tsCode, $originCode)")
+      case _: YearMonthIntervalType =>
+        defineCodeGen(ctx, ev, (bucketSizeCode, tsCode, originCode) =>
+          s"$dtu.timeBucketYMInterval($bucketSizeCode, $tsCode, $originCode)")
+      case other => throw SparkException.internalError(
+        s"Unexpected bucketSize type: $other")
+    }
+  }
+
+  override def prettyName: String = "time_bucket"
+
+  override protected def withNewChildrenInternal(
+      newFirst: Expression, newSecond: Expression, newThird: Expression): TimeBucket =
+    copy(bucketSize = newFirst, ts = newSecond, originTs = newThird)
+}
+
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+  usage = """
+    _FUNC_(bucketSize, ts[, origin]) - Returns the start of the bucket that `ts` falls into,
+      where buckets are defined by the given `bucketSize` interval aligned to `origin`. All
+      bucketing is performed on UTC micros, the session time zone does not affect bucket

Review Comment:
   Minor: comma splice — two independent clauses joined by a comma.
   
   ```suggestion
         bucketing is performed on UTC micros; the session time zone does not affect bucket
   ```



##########
python/pyspark/sql/functions/builtin.py:
##########
@@ -13125,6 +13125,74 @@ def timestamp_add(unit: str, quantity: "ColumnOrName", ts: "ColumnOrName") -> Co
     )
 
 
+@_try_remote_functions
+def time_bucket(
+    bucket_size: "ColumnOrName",
+    ts: "ColumnOrName",
+    origin: Optional["ColumnOrName"] = None,
+) -> Column:
+    """
+    Aligns a timestamp to the start of a fixed-size interval bucket.
+
+    Returns the start of the bucket that ``ts`` falls into, where buckets are defined by
+    the given ``bucket_size`` interval aligned to ``origin``. All bucketing is performed on
+    UTC micros, the session time zone does not affect bucket alignment. For local wall-clock

Review Comment:
   Same comma splice as in the Scala `@ExpressionDescription`.
   
   ```suggestion
       the given ``bucket_size`` interval aligned to ``origin``. All bucketing is performed on
       UTC micros; the session time zone does not affect bucket alignment. For local wall-clock
   ```


