This is an automated email from the ASF dual-hosted git repository.

gengliang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new e858cd5  [SPARK-36724][SQL] Support timestamp_ntz as a type of time column for SessionWindow
e858cd5 is described below

commit e858cd568a74123f7fd8fe4c3d2917a7e5bbb685
Author: Kousuke Saruta <saru...@oss.nttdata.com>
AuthorDate: Mon Sep 13 21:47:43 2021 +0800

    [SPARK-36724][SQL] Support timestamp_ntz as a type of time column for SessionWindow
    
    ### What changes were proposed in this pull request?
    
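    This PR proposes supporting `timestamp_ntz` as a time column type for
    `SessionWindow`, as `TimeWindow` already does. The `start` and `end` fields
    of the resulting session window struct now take the data type of the time
    column instead of always being `TimestampType`.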
    
    ### Why are the changes needed?
    
    For better usability: `session_window` can now be applied to `timestamp_ntz`
    columns, consistent with `TimeWindow`.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Added a new test to `DataFrameSessionWindowingSuite`.
    
    Closes #33965 from sarutak/session-window-ntz.
    
    Authored-by: Kousuke Saruta <saru...@oss.nttdata.com>
    Signed-off-by: Gengliang Wang <gengli...@apache.org>
---
 .../spark/sql/catalyst/analysis/Analyzer.scala     |  9 +++---
 .../sql/catalyst/expressions/SessionWindow.scala   |  6 ++--
 .../spark/sql/DataFrameSessionWindowingSuite.scala | 33 ++++++++++++++++++++--
 3 files changed, 39 insertions(+), 9 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 340b859..0f90159 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -3999,7 +3999,8 @@ object SessionWindowing extends Rule[LogicalPlan] {
         val sessionAttr = AttributeReference(
           SESSION_COL_NAME, session.dataType, metadata = newMetadata)()
 
-        val sessionStart = PreciseTimestampConversion(session.timeColumn, 
TimestampType, LongType)
+        val sessionStart =
+          PreciseTimestampConversion(session.timeColumn, 
session.timeColumn.dataType, LongType)
         val gapDuration = session.gapDuration match {
           case expr if Cast.canCast(expr.dataType, CalendarIntervalType) =>
             Cast(expr, CalendarIntervalType)
@@ -4007,13 +4008,13 @@ object SessionWindowing extends Rule[LogicalPlan] {
             throw 
QueryCompilationErrors.sessionWindowGapDurationDataTypeError(other.dataType)
         }
         val sessionEnd = PreciseTimestampConversion(session.timeColumn + 
gapDuration,
-          TimestampType, LongType)
+          session.timeColumn.dataType, LongType)
 
         val literalSessionStruct = CreateNamedStruct(
           Literal(SESSION_START) ::
-            PreciseTimestampConversion(sessionStart, LongType, TimestampType) 
::
+            PreciseTimestampConversion(sessionStart, LongType, 
session.timeColumn.dataType) ::
             Literal(SESSION_END) ::
-            PreciseTimestampConversion(sessionEnd, LongType, TimestampType) ::
+            PreciseTimestampConversion(sessionEnd, LongType, 
session.timeColumn.dataType) ::
             Nil)
 
         val sessionStruct = Alias(literalSessionStruct, SESSION_COL_NAME)(
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SessionWindow.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SessionWindow.scala
index 796ea27..77e8dfd 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SessionWindow.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SessionWindow.scala
@@ -69,10 +69,10 @@ case class SessionWindow(timeColumn: Expression, 
gapDuration: Expression) extend
   with NonSQLExpression {
 
   override def children: Seq[Expression] = Seq(timeColumn, gapDuration)
-  override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, 
AnyDataType)
+  override def inputTypes: Seq[AbstractDataType] = Seq(AnyTimestampType, 
AnyDataType)
   override def dataType: DataType = new StructType()
-    .add(StructField("start", TimestampType))
-    .add(StructField("end", TimestampType))
+    .add(StructField("start", timeColumn.dataType))
+    .add(StructField("end", timeColumn.dataType))
 
   // This expression is replaced in the analyzer.
   override lazy val resolved = false
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala
index 7a0cd42..b3d2127 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala
@@ -17,12 +17,15 @@
 
 package org.apache.spark.sql
 
+import java.time.LocalDateTime
+
 import org.scalatest.BeforeAndAfterEach
 
-import org.apache.spark.sql.catalyst.plans.logical.Expand
+import org.apache.spark.sql.catalyst.expressions.AttributeReference
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.SharedSparkSession
-import org.apache.spark.sql.types.StringType
+import org.apache.spark.sql.types._
 
 class DataFrameSessionWindowingSuite extends QueryTest with SharedSparkSession
   with BeforeAndAfterEach {
@@ -377,4 +380,30 @@ class DataFrameSessionWindowingSuite extends QueryTest 
with SharedSparkSession
       )
     }
   }
+
+  test("SPARK-36724: Support timestamp_ntz as a type of time column for 
SessionWindow") {
+    val df = Seq((LocalDateTime.parse("2016-03-27T19:39:30"), 1, "a"),
+      (LocalDateTime.parse("2016-03-27T19:39:25"), 2, "a")).toDF("time", 
"value", "id")
+    val aggDF =
+      df.groupBy(session_window($"time", "10 seconds"))
+        .agg(count("*").as("counts"))
+        .orderBy($"session_window.start".asc)
+        .select($"session_window.start".cast("string"),
+          $"session_window.end".cast("string"), $"counts")
+
+    val aggregate = aggDF.queryExecution.analyzed.children(0).children(0)
+    assert(aggregate.isInstanceOf[Aggregate])
+
+    val timeWindow = aggregate.asInstanceOf[Aggregate].groupingExpressions(0)
+    assert(timeWindow.isInstanceOf[AttributeReference])
+
+    val attributeReference = timeWindow.asInstanceOf[AttributeReference]
+    assert(attributeReference.name == "session_window")
+
+    val expectedSchema = StructType(
+      Seq(StructField("start", TimestampNTZType), StructField("end", 
TimestampNTZType)))
+    assert(attributeReference.dataType == expectedSchema)
+
+    checkAnswer(aggDF, Seq(Row("2016-03-27 19:39:25", "2016-03-27 19:39:40", 
2)))
+  }
 }

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to