This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 14ee0d8f04f [SPARK-37544][SQL] Correct date arithmetic in sequences
14ee0d8f04f is described below

commit 14ee0d8f04f218ad61688196a0b984f024151468
Author: Bruce Robbins <bersprock...@gmail.com>
AuthorDate: Sun May 15 09:26:19 2022 +0900

    [SPARK-37544][SQL] Correct date arithmetic in sequences

    ### What changes were proposed in this pull request?

    Change `InternalSequenceBase` to pass a time-zone aware value to
    `DateTimeUtils#timestampAddInterval`, rather than a time-zone agnostic value,
    when performing `Date` arithmetic.

    ### Why are the changes needed?

    The following query gets the wrong answer if run in the America/Los_Angeles time zone:
    ```
    spark-sql> select sequence(date '2021-01-01', date '2022-01-01', interval '3' month) x;
    [2021-01-01,2021-03-31,2021-06-30,2021-09-30,2022-01-01]
    Time taken: 0.664 seconds, Fetched 1 row(s)
    spark-sql>
    ```
    The answer should be
    ```
    [2021-01-01,2021-04-01,2021-07-01,2021-10-01,2022-01-01]
    ```
    `InternalSequenceBase` converts the date to micros by multiplying days by micros
    per day. This converts the date into a time-zone agnostic timestamp. However,
    `InternalSequenceBase` uses `DateTimeUtils#timestampAddInterval` to perform the
    arithmetic, and that function assumes a _time-zone aware_ timestamp.

    One simple fix would be to call `DateTimeUtils#timestampNTZAddInterval` instead
    for date arithmetic. However, Spark date arithmetic is typically time-zone aware
    (see the comment in the test added by this PR), so this PR converts the date to
    a time-zone aware value before calling `DateTimeUtils#timestampAddInterval`.

    ### Does this PR introduce _any_ user-facing change?

    No.

    ### How was this patch tested?

    New unit test.

    Closes #36546 from bersprockets/date_sequence_issue.

    Authored-by: Bruce Robbins <bersprock...@gmail.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
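For readers who want to see the mechanics outside of codegen, the sketch below contrasts the two conversions the commit message describes. It is not part of the commit: it is a minimal standalone illustration that assumes the spark-catalyst artifact is on the classpath and uses only `DateTimeUtils` methods referenced by the patch (`daysToMicros`, `microsToDays`, `timestampAddInterval`); the object name `DateSequenceSketch` is made up. The two printed dates are the wrong and right sequence elements quoted above.

```scala
import java.time.{LocalDate, ZoneId}

import org.apache.spark.sql.catalyst.util.DateTimeConstants.MICROS_PER_DAY
import org.apache.spark.sql.catalyst.util.DateTimeUtils

object DateSequenceSketch {
  def main(args: Array[String]): Unit = {
    val la = ZoneId.of("America/Los_Angeles")
    val startDays = DateTimeUtils.localDateToDays(LocalDate.parse("2021-01-01"))

    // Old behavior: days * MICROS_PER_DAY yields a time-zone agnostic value
    // (local midnight), but timestampAddInterval interprets its input as an
    // instant in the given zone, so the result drifts by the UTC offset.
    val agnostic = startDays * MICROS_PER_DAY
    val drifted = DateTimeUtils.timestampAddInterval(agnostic, 3, 0, 0L, la)
    // Mirrors the old `fromLong(t / scale)` conversion back to days.
    println(DateTimeUtils.daysToLocalDate((drifted / MICROS_PER_DAY).toInt)) // 2021-03-31

    // Fixed behavior: convert days to an instant in the session zone first,
    // add the interval in that same zone, then map the instant back to days.
    val aware = DateTimeUtils.daysToMicros(startDays, la)
    val correct = DateTimeUtils.timestampAddInterval(aware, 3, 0, 0L, la)
    println(DateTimeUtils.daysToLocalDate(DateTimeUtils.microsToDays(correct, la))) // 2021-04-01
  }
}
```

The patch to `InternalSequenceBase` below is exactly the second path: convert days to micros with `daysToMicros` before calling `timestampAddInterval`, and convert back with `microsToDays`.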
---
 .../expressions/collectionOperations.scala         | 58 ++++++++++++++++++----
 .../expressions/CollectionExpressionsSuite.scala   | 46 ++++++++++++++++-
 2 files changed, 93 insertions(+), 11 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 1bd934214f5..f38beb480e6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -3012,6 +3012,22 @@ object Sequence {
       case TimestampNTZType => timestampNTZAddInterval
     }
 
+    private def toMicros(value: Long, scale: Long): Long = {
+      if (scale == MICROS_PER_DAY) {
+        daysToMicros(value.toInt, zoneId)
+      } else {
+        value * scale
+      }
+    }
+
+    private def fromMicros(value: Long, scale: Long): Long = {
+      if (scale == MICROS_PER_DAY) {
+        microsToDays(value, zoneId).toLong
+      } else {
+        value / scale
+      }
+    }
+
     override def eval(input1: Any, input2: Any, input3: Any): Array[T] = {
       val start = input1.asInstanceOf[T]
       val stop = input2.asInstanceOf[T]
@@ -3035,8 +3051,9 @@ object Sequence {
         // about a month length in days and a day length in microseconds
         val intervalStepInMicros =
           stepMicros + stepMonths * microsPerMonth + stepDays * MICROS_PER_DAY
-        val startMicros: Long = num.toLong(start) * scale
-        val stopMicros: Long = num.toLong(stop) * scale
+
+        val startMicros: Long = toMicros(num.toLong(start), scale)
+        val stopMicros: Long = toMicros(num.toLong(stop), scale)
 
         val maxEstimatedArrayLength =
           getSequenceLength(startMicros, stopMicros, input3, intervalStepInMicros)
@@ -3048,7 +3065,8 @@ object Sequence {
         var i = 0
 
         while (t < exclusiveItem ^ stepSign < 0) {
-          arr(i) = fromLong(t / scale)
+          val result = fromMicros(t, scale)
+          arr(i) = fromLong(result)
           i += 1
           t = addInterval(startMicros, i * stepMonths, i * stepDays, i * stepMicros, zoneId)
         }
@@ -3061,13 +3079,16 @@ object Sequence {
     protected def stepSplitCode(
         stepMonths: String, stepDays: String, stepMicros: String, step: String): String
 
+    private val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
+
     private val addIntervalCode = outerDataType match {
-      case TimestampType | DateType =>
-        "org.apache.spark.sql.catalyst.util.DateTimeUtils.timestampAddInterval"
-      case TimestampNTZType =>
-        "org.apache.spark.sql.catalyst.util.DateTimeUtils.timestampNTZAddInterval"
+      case TimestampType | DateType => s"$dtu.timestampAddInterval"
+      case TimestampNTZType => s"$dtu.timestampNTZAddInterval"
     }
 
+    private val daysToMicrosCode = s"$dtu.daysToMicros"
+    private val microsToDaysCode = s"$dtu.microsToDays"
+
     override def genCode(
         ctx: CodegenContext,
         start: String,
@@ -3111,6 +3132,24 @@ object Sequence {
 
       val stepSplits = stepSplitCode(stepMonths, stepDays, stepMicros, step)
 
+      val toMicrosCode = if (scale == MICROS_PER_DAY) {
+        s"""
+           | final long $startMicros = $daysToMicrosCode((int) $start, $zid);
+           | final long $stopMicros = $daysToMicrosCode((int) $stop, $zid);
+           |""".stripMargin
+      } else {
+        s"""
+           | final long $startMicros = $start * ${scale}L;
+           | final long $stopMicros = $stop * ${scale}L;
+           |""".stripMargin
+      }
+
+      val fromMicrosCode = if (scale == MICROS_PER_DAY) {
+        s"($elemType) $microsToDaysCode($t, $zid)"
+      } else {
+        s"($elemType) ($t / ${scale}L)"
+      }
+
       s"""
         |$stepSplits
         |
@@ -3122,8 +3161,7 @@ object Sequence {
        |} else if ($stepMonths == 0 && $stepDays == 0 && ${scale}L == 1) {
        |  ${backedSequenceImpl.genCode(ctx, start, stop, stepMicros, arr, elemType)};
        |} else {
-       |  final long $startMicros = $start * ${scale}L;
-       |  final long $stopMicros = $stop * ${scale}L;
+       |  $toMicrosCode
        |
        |  $sequenceLengthCode
        |
@@ -3135,7 +3173,7 @@ object Sequence {
        |  int $i = 0;
        |
        |  while ($t < $exclusiveItem ^ $stepSign < 0) {
-       |    $arr[$i] = ($elemType) ($t / ${scale}L);
+       |    $arr[$i] = $fromMicrosCode;
        |    $i += 1;
        |    $t = $addIntervalCode(
        |      $startMicros, $i * $stepMonths, $i * $stepDays, $i * $stepMicros, $zid);
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
index fb4bf43ba83..a8c4b16c7a0 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -29,7 +29,7 @@ import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.util.{DateTimeTestUtils, DateTimeUtils}
-import org.apache.spark.sql.catalyst.util.DateTimeTestUtils.UTC
+import org.apache.spark.sql.catalyst.util.DateTimeTestUtils.{outstandingZoneIds, LA, UTC}
 import org.apache.spark.sql.catalyst.util.IntervalUtils._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
@@ -964,6 +964,50 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
     }
   }
 
+  test("SPARK-37544: Time zone should not affect date sequence with month interval") {
+    outstandingZoneIds.foreach { zid =>
+      DateTimeTestUtils.withDefaultTimeZone(zid) {
+        checkEvaluation(new Sequence(
+          Literal(Date.valueOf("2021-01-01")),
+          Literal(Date.valueOf("2022-01-01")),
+          Literal(stringToInterval("interval 3 month"))),
+          Seq(
+            Date.valueOf("2021-01-01"),
+            Date.valueOf("2021-04-01"),
+            Date.valueOf("2021-07-01"),
+            Date.valueOf("2021-10-01"),
+            Date.valueOf("2022-01-01")))
+      }
+    }
+
+    // However, time zone should still affect sequences generated using hours interval,
+    // especially if the sequence's start-stop includes a "spring forward".
+    // Take, for example, the following Spark date arithmetic:
+    //   select cast(date'2022-03-09' + interval '4' days '23' hour as date) as x;
+    // In the America/Los_Angeles time zone, it returns 2022-03-14.
+    // In the UTC time zone, it instead returns 2022-03-13.
+    // The sequence function should be consistent with the date arithmetic.
+    DateTimeTestUtils.withDefaultTimeZone(LA) {
+      checkEvaluation(new Sequence(
+        Literal(Date.valueOf("2022-03-09")),
+        Literal(Date.valueOf("2022-03-15")),
+        Literal(stringToInterval("interval 4 days 23 hours"))),
+        Seq(
+          Date.valueOf("2022-03-09"),
+          Date.valueOf("2022-03-14")))
+    }
+
+    DateTimeTestUtils.withDefaultTimeZone(UTC) {
+      checkEvaluation(new Sequence(
+        Literal(Date.valueOf("2022-03-09")),
+        Literal(Date.valueOf("2022-03-15")),
+        Literal(stringToInterval("interval 4 days 23 hours"))),
+        Seq(
+          Date.valueOf("2022-03-09"),
+          Date.valueOf("2022-03-13"))) // this is different from LA time zone above
+    }
+  }
+
   test("SPARK-35088: Accept ANSI intervals by the Sequence expression") {
     checkEvaluation(new Sequence(
       Literal(Timestamp.valueOf("2018-01-01 00:00:00")),

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
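As a footnote to the "spring forward" comment in the new test: the same asymmetry can be reproduced with plain `java.time`, under the assumption (consistent with the patched code's use of `timestampAddInterval`) that the day component of an interval advances the local calendar while the hour component advances physical time. This is an illustrative sketch, not Spark code; `SpringForwardSketch` and `steppedDate` are made-up names.

```scala
import java.time.{Duration, LocalDate, Period, ZoneId}

object SpringForwardSketch {
  // One step of the test's interval: 4 calendar days, then 23 physical hours.
  def steppedDate(zone: ZoneId): LocalDate =
    LocalDate.parse("2022-03-09")
      .atStartOfDay(zone)
      .plus(Period.ofDays(4))      // calendar days: keeps local midnight
      .plus(Duration.ofHours(23))  // physical hours: may cross a DST gap
      .toLocalDate

  def main(args: Array[String]): Unit = {
    println(steppedDate(ZoneId.of("America/Los_Angeles"))) // 2022-03-14
    println(steppedDate(ZoneId.of("UTC")))                 // 2022-03-13
  }
}
```

Because 2022-03-13 has only 23 local hours in America/Los_Angeles, the 23-hour tail of the step crosses local midnight there but not in UTC, which is why the two sequences in the test end on different dates.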