This is an automated email from the ASF dual-hosted git repository.

yumwang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 0e8e303fb4b [SPARK-46069][SQL] Support unwrap timestamp type to date type
0e8e303fb4b is described below

commit 0e8e303fb4b25b2254791abbf900d115232eb966
Author: Kun Wan <wan...@apache.org>
AuthorDate: Sun Dec 3 12:26:41 2023 +0800

    [SPARK-46069][SQL] Support unwrap timestamp type to date type
    
    ### What changes were proposed in this pull request?
    
    Just like [[SPARK-42597][SQL] Support unwrap date type to timestamp type](https://github.com/apache/spark/pull/40190), this PR enhances `UnwrapCastInBinaryComparison` to support unwrapping the timestamp type to the date type.
    
    Add three new expressions:
    1. floorDate: the largest date that is less than or equal to the input timestamp.
    2. dateAddOne: floorDate + one day.
    3. isStartOfDay: returns true if ts == floorDate.
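
    As a minimal sketch of these semantics (plain java.time rather than the
    Catalyst Cast/DateAdd expressions used by the rule, and ignoring time zones;
    the timestamp value matches the end-to-end test below):

        import java.time.{LocalDate, LocalDateTime, LocalTime}

        val ts = LocalDateTime.of(2023, 1, 2, 10, 0, 0)
        // Largest date that is <= ts.
        val floorDate: LocalDate = ts.toLocalDate                        // 2023-01-02
        // floorDate + one day.
        val dateAddOne: LocalDate = floorDate.plusDays(1)                // 2023-01-03
        // True only when ts falls exactly on a day boundary.
        val isStartOfDay: Boolean = ts.toLocalTime == LocalTime.MIDNIGHT // false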
    
    The rules to unwrap the timestamp type to the date type are:
    
    1. CAST(date AS timestamp) > ts ===> date > floorDate
    2. CAST(date AS timestamp) >= ts ===> if (isStartOfDay) date >= floorDate else date >= dateAddOne
    3. CAST(date AS timestamp) === ts ===>
            if (isStartOfDay) {
              fromExp === floorDate
            } else if (!fromExp.nullable) {
              FalseLiteral
            } else {
              fromExp === floorDate AND fromExp === dateAddOne
            }
            (the conjunction is always false for a non-null date; it is written this way to preserve nullability)
    4. CAST(date AS timestamp) <=> ts ===> if (isStartOfDay) date <=> floorDate else FalseLiteral
    5. CAST(date AS timestamp) < ts ===> if (isStartOfDay) date < floorDate else date < dateAddOne
    6. CAST(date AS timestamp) <= ts ===> date <= floorDate
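
    For example, with ts = timestamp'2023-01-02 10:00:00' (not the start of a day),
    rule 2 rewrites CAST(dt AS timestamp) >= ts ===> dt >= date'2023-01-03', whereas
    with ts = timestamp'2023-01-02 00:00:00' it rewrites to dt >= date'2023-01-02';
    both cases are asserted in the end-to-end test below.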
    
    ### Why are the changes needed?
    
    Improve query performance: with the cast removed from the date column, the rewritten predicate can be pushed down to data sources and used for partition pruning.
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    Unit test.
    
    Closes #43982 from wankunde/date_ts.
    
    Lead-authored-by: Kun Wan <wan...@apache.org>
    Co-authored-by: wankun <wan...@apache.org>
    Signed-off-by: Yuming Wang <yumw...@ebay.com>
---
 .../optimizer/UnwrapCastInBinaryComparison.scala   | 47 ++++++++++++++++++++
 .../UnwrapCastInBinaryComparisonSuite.scala        | 51 +++++++++++++++++-----
 .../sql/UnwrapCastInComparisonEndToEndSuite.scala  | 25 +++++++++++
 3 files changed, 113 insertions(+), 10 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala
index 54b1dd419fb..34b98e43038 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala
@@ -138,6 +138,11 @@ object UnwrapCastInBinaryComparison extends Rule[LogicalPlan] {
         if AnyTimestampType.acceptsType(fromExp.dataType) && value != null =>
       Some(unwrapDateToTimestamp(be, fromExp, date, timeZoneId, evalMode))
 
+    case be @ BinaryComparison(
+      Cast(fromExp, _, timeZoneId, evalMode), ts @ Literal(value, _))
+        if AnyTimestampType.acceptsType(ts.dataType) && value != null =>
+      Some(unwrapTimeStampToDate(be, fromExp, ts, timeZoneId, evalMode))
+
     // As the analyzer makes sure that the list of In is already of the same data type, then the
     // rule can simply check the first literal in `in.list` can implicitly cast to `toType` or not,
     // and note that:
@@ -329,6 +334,48 @@ object UnwrapCastInBinaryComparison extends Rule[LogicalPlan] {
     }
   }
 
+  private def unwrapTimeStampToDate(
+      exp: BinaryComparison,
+      fromExp: Expression,
+      ts: Literal,
+      tz: Option[String],
+      evalMode: EvalMode.Value): Expression = {
+    val floorDate = Cast(ts, fromExp.dataType, tz, evalMode)
+    val dateAddOne = DateAdd(floorDate, Literal(1, IntegerType))
+    val isStartOfDay =
+      EqualTo(ts, Cast(floorDate, ts.dataType, tz, evalMode)).eval(EmptyRow).asInstanceOf[Boolean]
+
+    exp match {
+      case _: GreaterThan =>
+        GreaterThan(fromExp, floorDate)
+      case _: GreaterThanOrEqual =>
+        if (isStartOfDay) {
+          GreaterThanOrEqual(fromExp, floorDate)
+        } else {
+          GreaterThanOrEqual(fromExp, dateAddOne)
+        }
+      case _: EqualTo =>
+        if (isStartOfDay) {
+          EqualTo(fromExp, floorDate)
+        } else if (!fromExp.nullable) {
+          FalseLiteral
+        } else {
+          And(EqualTo(fromExp, floorDate), EqualTo(fromExp, dateAddOne))
+        }
+      case _: EqualNullSafe =>
+        if (isStartOfDay) EqualNullSafe(fromExp, floorDate) else FalseLiteral
+      case _: LessThan =>
+        if (isStartOfDay) {
+          LessThan(fromExp, floorDate)
+        } else {
+          LessThan(fromExp, dateAddOne)
+        }
+      case _: LessThanOrEqual =>
+        LessThanOrEqual(fromExp, floorDate)
+      case _ => exp
+    }
+  }
+
   private def simplifyIn[IN <: Expression](
       fromExp: Expression,
       toType: NumericType,
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparisonSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparisonSuite.scala
index 0f0acb669c2..5646f3f4345 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparisonSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparisonSuite.scala
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans.DslLogicalPlan
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.IntegralLiteralTestUtils._
+import org.apache.spark.sql.catalyst.expressions.Literal.FalseLiteral
 import org.apache.spark.sql.catalyst.optimizer.UnwrapCastInBinaryComparison._
 import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical._
@@ -41,13 +42,14 @@ class UnwrapCastInBinaryComparisonSuite extends PlanTest with ExpressionEvalHelp
   }
 
   val testRelation: LocalRelation = LocalRelation($"a".short, $"b".float,
-    $"c".decimal(5, 2), $"d".boolean, $"e".timestamp, $"f".timestampNTZ)
+    $"c".decimal(5, 2), $"d".boolean, $"e".timestamp, $"f".timestampNTZ, 
$"g".date)
   val f: BoundReference = $"a".short.canBeNull.at(0)
   val f2: BoundReference = $"b".float.canBeNull.at(1)
   val f3: BoundReference = $"c".decimal(5, 2).canBeNull.at(2)
   val f4: BoundReference = $"d".boolean.canBeNull.at(3)
   val f5: BoundReference = $"e".timestamp.notNull.at(4)
   val f6: BoundReference = $"f".timestampNTZ.canBeNull.at(5)
+  val f7: BoundReference = $"g".date.canBeNull.at(6)
 
   test("unwrap casts when literal == max") {
     val v = Short.MaxValue
@@ -403,11 +405,40 @@ class UnwrapCastInBinaryComparisonSuite extends PlanTest with ExpressionEvalHelp
       castDate(f5) > nullLit || castDate(f6) > nullLit)
   }
 
+  test("SPARK-46069: Support unwrap timestamp type to date type") {
+    val tsLit = Literal.create(ts1, TimestampType)
+    val tsNTZLit = Literal.create(ts1, TimestampNTZType)
+    val floorDate = Cast(tsLit, DateType, Some(conf.sessionLocalTimeZone))
+    val floorDateNTZ = Cast(tsNTZLit, DateType, Some(conf.sessionLocalTimeZone))
+    val dateAddOne = DateAdd(floorDate, Literal(1, IntegerType))
+    val dateAddOneNTZ = DateAdd(floorDateNTZ, Literal(1, IntegerType))
+    assertEquivalent(
+      castTimestamp(f7) > tsLit || castTimestampNTZ(f7) > tsNTZLit,
+      f7 > floorDate || f7 > floorDateNTZ)
+    assertEquivalent(
+      castTimestamp(f7) >= tsLit || castTimestampNTZ(f7) >= tsNTZLit,
+      f7 >= dateAddOne || f7 >= dateAddOneNTZ)
+    assertEquivalent(
+      castTimestamp(f7) === tsLit || castTimestampNTZ(f7) === tsNTZLit,
+      f7 === floorDate && f7 === dateAddOne || f7 === floorDateNTZ && f7 === dateAddOneNTZ)
+    assertEquivalent(
+      castTimestamp(f7) <=> tsLit || castTimestampNTZ(f7) <=> tsNTZLit,
+      FalseLiteral || FalseLiteral)
+    assertEquivalent(
+      castTimestamp(f7) < tsLit || castTimestampNTZ(f7) < tsNTZLit,
+      f7 < dateAddOne || f7 < dateAddOneNTZ)
+    assertEquivalent(
+      castTimestamp(f7) <= tsLit || castTimestampNTZ(f7) <= tsNTZLit,
+      f7 <= floorDate || f7 <= floorDateNTZ)
+  }
+
   private val ts1 = LocalDateTime.of(2023, 1, 1, 23, 59, 59, 99999000)
   private val ts2 = LocalDateTime.of(2023, 1, 1, 23, 59, 59, 999998000)
   private val ts3 = LocalDateTime.of(9999, 12, 31, 23, 59, 59, 999999999)
   private val ts4 = LocalDateTime.of(1, 1, 1, 0, 0, 0, 0)
 
+  private val dt1 = java.sql.Date.valueOf("2023-01-01")
+
   private def castInt(e: Expression): Expression = Cast(e, IntegerType)
   private def castDouble(e: Expression): Expression = Cast(e, DoubleType)
   private def castDecimal2(e: Expression): Expression = Cast(e, DecimalType(10, 4))
@@ -429,16 +460,16 @@ class UnwrapCastInBinaryComparisonSuite extends PlanTest with ExpressionEvalHelp
 
     if (evaluate) {
       Seq(
-        (100.toShort, 3.14.toFloat, decimal2(100), true, ts1, ts1),
-        (-300.toShort, 3.1415927.toFloat, decimal2(-3000.50), false, ts2, ts2),
-        (null, Float.NaN, decimal2(12345.6789), null, null, null),
-        (null, null, null, null, null, null),
-        (Short.MaxValue, Float.PositiveInfinity, decimal2(Short.MaxValue), true, ts3, ts3),
-        (Short.MinValue, Float.NegativeInfinity, decimal2(Short.MinValue), false, ts4, ts4),
-        (0.toShort, Float.MaxValue, decimal2(0), null, null, null),
-        (0.toShort, Float.MinValue, decimal2(0.01), null, null, null)
+        (100.toShort, 3.14.toFloat, decimal2(100), true, ts1, ts1, dt1),
+        (-300.toShort, 3.1415927.toFloat, decimal2(-3000.50), false, ts2, ts2, dt1),
+        (null, Float.NaN, decimal2(12345.6789), null, null, null, null),
+        (null, null, null, null, null, null, null),
+        (Short.MaxValue, Float.PositiveInfinity, decimal2(Short.MaxValue), true, ts3, ts3, dt1),
+        (Short.MinValue, Float.NegativeInfinity, decimal2(Short.MinValue), false, ts4, ts4, dt1),
+        (0.toShort, Float.MaxValue, decimal2(0), null, null, null, null),
+        (0.toShort, Float.MinValue, decimal2(0.01), null, null, null, null)
       ).foreach(v => {
-        val row = create_row(v._1, v._2, v._3, v._4, v._5, v._6)
+        val row = create_row(v._1, v._2, v._3, v._4, v._5, v._6, v._7)
         checkEvaluation(e1, e2.eval(row), row)
       })
     }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UnwrapCastInComparisonEndToEndSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UnwrapCastInComparisonEndToEndSuite.scala
index 468915aa493..2657ab310cc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UnwrapCastInComparisonEndToEndSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UnwrapCastInComparisonEndToEndSuite.scala
@@ -264,5 +264,30 @@ class UnwrapCastInComparisonEndToEndSuite extends QueryTest with SharedSparkSess
     }
   }
 
+  test("SPARK-46069: Support unwrap timestamp type to date type") {
+    val d1 = java.sql.Date.valueOf("2023-01-01")
+    val d2 = java.sql.Date.valueOf("2023-01-02")
+    val d3 = java.sql.Date.valueOf("2023-01-03")
+
+    withTable(t) {
+      Seq(d1, d2, d3).toDF("dt").write.saveAsTable(t)
+      val df = spark.table(t)
+
+      val ts1 = "timestamp'2023-01-02 10:00:00'"
+      checkAnswer(df.where(s"cast(dt as timestamp) > $ts1"), Seq(d3).map(Row(_)))
+      checkAnswer(df.where(s"cast(dt as timestamp) >= $ts1"), Seq(d3).map(Row(_)))
+      checkAnswer(df.where(s"cast(dt as timestamp) = $ts1"), Seq().map(Row(_)))
+      checkAnswer(df.where(s"cast(dt as timestamp) < $ts1"), Seq(d1, d2).map(Row(_)))
+      checkAnswer(df.where(s"cast(dt as timestamp) <= $ts1"), Seq(d1, d2).map(Row(_)))
+
+      val ts2 = "timestamp'2023-01-02 00:00:00'"
+      checkAnswer(df.where(s"cast(dt as timestamp) > $ts2"), Seq(d3).map(Row(_)))
+      checkAnswer(df.where(s"cast(dt as timestamp) >= $ts2"), Seq(d2, d3).map(Row(_)))
+      checkAnswer(df.where(s"cast(dt as timestamp) = $ts2"), Seq(d2).map(Row(_)))
+      checkAnswer(df.where(s"cast(dt as timestamp) < $ts2"), Seq(d1).map(Row(_)))
+      checkAnswer(df.where(s"cast(dt as timestamp) <= $ts2"), Seq(d1, d2).map(Row(_)))
+    }
+  }
+
   private def decimal(v: BigDecimal): Decimal = Decimal(v, 5, 2)
 }

