This is an automated email from the ASF dual-hosted git repository.

srowen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 54e058d  [SPARK-28416][SQL] Use java.time API in timestampAddInterval
54e058d is described below

commit 54e058dff2037066ddaf0ada1357701de04169a5
Author: Maxim Gekk <max.g...@gmail.com>
AuthorDate: Thu Jul 18 19:17:23 2019 -0400

    [SPARK-28416][SQL] Use java.time API in timestampAddInterval
    
    ## What changes were proposed in this pull request?
    
    The `DateTimeUtils.timestampAddInterval` method was rewritten to use the
    Java 8 time API. To add months and microseconds, I used the `plusMonths()`
    and `plus()` methods of `ZonedDateTime`. Also, the signature of
    `timestampAddInterval()` was changed to accept a `ZoneId` instance instead
    of a `TimeZone`. Using `ZoneId` avoids the `TimeZone` -> `ZoneId`
    conversion on every invocation of `timestampAddInterval()`. A sketch of the
    new arithmetic is shown below.
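    
    For illustration, a minimal standalone sketch of the approach (`addInterval`
    and its parameter names are hypothetical here, not the actual Spark
    internals, which use `microsToInstant`/`instantToMicros`):
    
    ```scala
    import java.time.{Instant, ZoneId}
    import java.time.temporal.ChronoUnit
    
    def addInterval(startMicros: Long, months: Int, micros: Long, zoneId: ZoneId): Long = {
      // Interpret the epoch microseconds in the given zone, so that the month
      // part is calendar-aware (month lengths, DST transitions), then add the
      // sub-month part as an exact number of microseconds.
      val result = Instant.EPOCH.plus(startMicros, ChronoUnit.MICROS)
        .atZone(zoneId)
        .plusMonths(months)              // calendar-based part of the interval
        .plus(micros, ChronoUnit.MICROS) // exact-duration part of the interval
      ChronoUnit.MICROS.between(Instant.EPOCH, result.toInstant)
    }
    ```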
    
    ## How was this patch tested?
    
    By the existing test suites `DateExpressionsSuite`, `TypeCoercionSuite` and
    `CollectionExpressionsSuite`.
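    
    For example (using the hypothetical `addInterval` sketch above, mirroring
    the "timestamp add months" case in `DateTimeUtilsSuite`), adding 36 months
    and 123000 microseconds to 1997-02-28 10:30:00 UTC should yield
    2000-02-28 10:30:00.123 UTC:
    
    ```scala
    import java.time.{Instant, ZoneId}
    import java.time.temporal.ChronoUnit
    
    val start = Instant.parse("1997-02-28T10:30:00Z")
    val startMicros = ChronoUnit.MICROS.between(Instant.EPOCH, start)
    val expected = Instant.parse("2000-02-28T10:30:00.123Z")
    assert(addInterval(startMicros, 36, 123000L, ZoneId.of("UTC")) ==
      ChronoUnit.MICROS.between(Instant.EPOCH, expected))
    ```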
    
    Closes #25173 from MaxGekk/timestamp-add-interval.
    
    Authored-by: Maxim Gekk <max.g...@gmail.com>
    Signed-off-by: Sean Owen <sean.o...@databricks.com>
---
 .../sql/catalyst/expressions/collectionOperations.scala   | 15 ++++++++-------
 .../sql/catalyst/expressions/datetimeExpressions.scala    | 12 ++++++------
 .../apache/spark/sql/catalyst/util/DateTimeUtils.scala    | 15 +++++++--------
 .../spark/sql/catalyst/util/DateTimeUtilsSuite.scala      |  7 +++----
 4 files changed, 24 insertions(+), 25 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index f671ede..5314821 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -16,7 +16,8 @@
  */
 package org.apache.spark.sql.catalyst.expressions
 
-import java.util.{Comparator, TimeZone}
+import java.time.ZoneId
+import java.util.Comparator
 
 import scala.collection.mutable
 import scala.reflect.ClassTag
@@ -2459,10 +2460,10 @@ case class Sequence(
       new IntegralSequenceImpl(iType)(ct, iType.integral)
 
     case TimestampType =>
-      new TemporalSequenceImpl[Long](LongType, 1, identity, timeZone)
+      new TemporalSequenceImpl[Long](LongType, 1, identity, zoneId)
 
     case DateType =>
-      new TemporalSequenceImpl[Int](IntegerType, MICROS_PER_DAY, _.toInt, timeZone)
+      new TemporalSequenceImpl[Int](IntegerType, MICROS_PER_DAY, _.toInt, zoneId)
   }
 
   override def eval(input: InternalRow): Any = {
@@ -2603,7 +2604,7 @@ object Sequence {
   }
 
   private class TemporalSequenceImpl[T: ClassTag]
-      (dt: IntegralType, scale: Long, fromLong: Long => T, timeZone: TimeZone)
+      (dt: IntegralType, scale: Long, fromLong: Long => T, zoneId: ZoneId)
       (implicit num: Integral[T]) extends SequenceImpl {
 
     override val defaultStep: DefaultStep = new DefaultStep(
@@ -2642,7 +2643,7 @@ object Sequence {
         while (t < exclusiveItem ^ stepSign < 0) {
           arr(i) = fromLong(t / scale)
           i += 1
-          t = timestampAddInterval(startMicros, i * stepMonths, i * stepMicros, timeZone)
+          t = timestampAddInterval(startMicros, i * stepMonths, i * stepMicros, zoneId)
         }
 
         // truncate array to the correct length
@@ -2668,7 +2669,7 @@ object Sequence {
       val exclusiveItem = ctx.freshName("exclusiveItem")
       val t = ctx.freshName("t")
       val i = ctx.freshName("i")
-      val genTimeZone = ctx.addReferenceObj("timeZone", timeZone, 
classOf[TimeZone].getName)
+      val zid = ctx.addReferenceObj("zoneId", zoneId, classOf[ZoneId].getName)
 
       val sequenceLengthCode =
         s"""
@@ -2701,7 +2702,7 @@ object Sequence {
          |    $arr[$i] = ($elemType) ($t / ${scale}L);
          |    $i += 1;
         |    $t = org.apache.spark.sql.catalyst.util.DateTimeUtils.timestampAddInterval(
-         |       $startMicros, $i * $stepMonths, $i * $stepMicros, $genTimeZone);
+         |       $startMicros, $i * $stepMonths, $i * $stepMicros, $zid);
          |  }
          |
          |  if ($arr.length > $i) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
index ccf6b36..53329fd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
@@ -996,14 +996,14 @@ case class TimeAdd(start: Expression, interval: Expression, timeZoneId: Option[S
   override def nullSafeEval(start: Any, interval: Any): Any = {
     val itvl = interval.asInstanceOf[CalendarInterval]
     DateTimeUtils.timestampAddInterval(
-      start.asInstanceOf[Long], itvl.months, itvl.microseconds, timeZone)
+      start.asInstanceOf[Long], itvl.months, itvl.microseconds, zoneId)
   }
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    val tz = ctx.addReferenceObj("timeZone", timeZone)
+    val zid = ctx.addReferenceObj("zoneId", zoneId, classOf[ZoneId].getName)
     val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
     defineCodeGen(ctx, ev, (sd, i) => {
-      s"""$dtu.timestampAddInterval($sd, $i.months, $i.microseconds, $tz)"""
+      s"""$dtu.timestampAddInterval($sd, $i.months, $i.microseconds, $zid)"""
     })
   }
 }
@@ -1111,14 +1111,14 @@ case class TimeSub(start: Expression, interval: Expression, timeZoneId: Option[S
   override def nullSafeEval(start: Any, interval: Any): Any = {
     val itvl = interval.asInstanceOf[CalendarInterval]
     DateTimeUtils.timestampAddInterval(
-      start.asInstanceOf[Long], 0 - itvl.months, 0 - itvl.microseconds, timeZone)
+      start.asInstanceOf[Long], 0 - itvl.months, 0 - itvl.microseconds, zoneId)
   }
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    val tz = ctx.addReferenceObj("timeZone", timeZone)
+    val zid = ctx.addReferenceObj("zoneId", zoneId, classOf[ZoneId].getName)
     val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
     defineCodeGen(ctx, ev, (sd, i) => {
-      s"""$dtu.timestampAddInterval($sd, 0 - $i.months, 0 - $i.microseconds, 
$tz)"""
+      s"""$dtu.timestampAddInterval($sd, 0 - $i.months, 0 - $i.microseconds, 
$zid)"""
     })
   }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
index 1daf65a..10a7f9b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
@@ -19,8 +19,7 @@ package org.apache.spark.sql.catalyst.util
 
 import java.sql.{Date, Timestamp}
 import java.time._
-import java.time.Year.isLeap
-import java.time.temporal.IsoFields
+import java.time.temporal.{ChronoUnit, IsoFields}
 import java.util.{Locale, TimeZone}
 import java.util.concurrent.TimeUnit._
 
@@ -521,12 +520,12 @@ object DateTimeUtils {
       start: SQLTimestamp,
       months: Int,
       microseconds: Long,
-      timeZone: TimeZone): SQLTimestamp = {
-    val days = millisToDays(MICROSECONDS.toMillis(start), timeZone)
-    val newDays = dateAddMonths(days, months)
-    start +
-      MILLISECONDS.toMicros(daysToMillis(newDays, timeZone) - daysToMillis(days, timeZone)) +
-      microseconds
+      zoneId: ZoneId): SQLTimestamp = {
+    val resultTimestamp = microsToInstant(start)
+      .atZone(zoneId)
+      .plusMonths(months)
+      .plus(microseconds, ChronoUnit.MICROS)
+    instantToMicros(resultTimestamp.toInstant)
   }
 
   /**
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala
index 4f83539..8ff691f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala
@@ -31,7 +31,6 @@ import org.apache.spark.unsafe.types.UTF8String
 class DateTimeUtilsSuite extends SparkFunSuite {
 
   val TimeZonePST = TimeZone.getTimeZone("PST")
-  private def defaultTz = DateTimeUtils.defaultTimeZone()
   private def defaultZoneId = ZoneId.systemDefault()
 
   test("nanoseconds truncation") {
@@ -366,13 +365,13 @@ class DateTimeUtilsSuite extends SparkFunSuite {
   test("timestamp add months") {
     val ts1 = date(1997, 2, 28, 10, 30, 0)
     val ts2 = date(2000, 2, 28, 10, 30, 0, 123000)
-    assert(timestampAddInterval(ts1, 36, 123000, defaultTz) === ts2)
+    assert(timestampAddInterval(ts1, 36, 123000, defaultZoneId) === ts2)
 
     val ts3 = date(1997, 2, 27, 16, 0, 0, 0, TimeZonePST)
     val ts4 = date(2000, 2, 27, 16, 0, 0, 123000, TimeZonePST)
     val ts5 = date(2000, 2, 28, 0, 0, 0, 123000, TimeZoneGMT)
-    assert(timestampAddInterval(ts3, 36, 123000, TimeZonePST) === ts4)
-    assert(timestampAddInterval(ts3, 36, 123000, TimeZoneGMT) === ts5)
+    assert(timestampAddInterval(ts3, 36, 123000, TimeZonePST.toZoneId) === ts4)
+    assert(timestampAddInterval(ts3, 36, 123000, TimeZoneGMT.toZoneId) === ts5)
   }
 
   test("monthsBetween") {

