This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 6b2492628c6 [SPARK-39857][SQL] V2ExpressionBuilder uses the wrong 
LiteralValue data type for In predicate
6b2492628c6 is described below

commit 6b2492628c60fc1c4f70889c71cc3a9403a0adbc
Author: huaxingao <huaxin_...@apple.com>
AuthorDate: Mon Jul 25 08:11:19 2022 -0700

    [SPARK-39857][SQL] V2ExpressionBuilder uses the wrong LiteralValue data 
type for In predicate
    
    ### What changes were proposed in this pull request?
    When building the V2 `In` predicate in `V2ExpressionBuilder`, `InSet.dataType` 
(which is `BooleanType`) is used to build each `LiteralValue`; 
`InSet.child.dataType` should be used instead.
    
    ### Why are the changes needed?
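    Bug fix: the V2 `IN` predicate built from `InSet` labelled its literal values with `BooleanType` instead of the child column's data type.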
    
    ### Does this PR introduce _any_ user-facing change?
    no
    
    ### How was this patch tested?
    New tests in `DataSourceV2StrategySuite`, including translation of `InSet` and `In` over integer columns.
    
    Closes #37271 from huaxingao/inset.
    
    Authored-by: huaxingao <huaxin_...@apple.com>
    Signed-off-by: Dongjoon Hyun <dongj...@apache.org>
---
 .../sql/catalyst/util/V2ExpressionBuilder.scala    |   4 +-
 .../datasources/v2/DataSourceV2StrategySuite.scala | 229 ++++++++++++++++++++-
 2 files changed, 228 insertions(+), 5 deletions(-)

diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
index 70c85def45d..07d681a6616 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
@@ -52,10 +52,10 @@ class V2ExpressionBuilder(e: Expression, isPredicate: 
Boolean = false) {
       } else {
         Some(ref)
       }
-    case in @ InSet(child, hset) =>
+    case InSet(child, hset) =>
       generateExpression(child).map { v =>
         val children =
-          (v +: hset.toSeq.map(elem => LiteralValue(elem, 
in.dataType))).toArray[V2Expression]
+          (v +: hset.toSeq.map(elem => LiteralValue(elem, 
child.dataType))).toArray[V2Expression]
         new V2Predicate("IN", children)
       }
     // Because we only convert In to InSet in Optimizer when there are more 
than certain
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StrategySuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StrategySuite.scala
index c3f51bed269..5fefcadca3e 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StrategySuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StrategySuite.scala
@@ -21,9 +21,10 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.connector.expressions.{FieldReference, 
LiteralValue}
-import org.apache.spark.sql.connector.expressions.filter.Predicate
+import org.apache.spark.sql.connector.expressions.filter.{And => V2And, Not => 
V2Not, Or => V2Or, Predicate}
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types.{BooleanType, IntegerType, StringType, 
StructField, StructType}
+import org.apache.spark.unsafe.types.UTF8String
 
 class DataSourceV2StrategySuite extends PlanTest with SharedSparkSession {
   val attrInts = Seq(
@@ -55,8 +56,37 @@ class DataSourceV2StrategySuite extends PlanTest with 
SharedSparkSession {
     "a.b.cint" // three level nested field
   ))
 
-  test("SPARK-39784: translate binary expression") { attrInts
-    .foreach { case (attrInt, intColName) =>
+  val attrStrs = Seq(
+    $"cstr".string,
+    $"c.str".string,
+    GetStructField($"a".struct(StructType(
+      StructField("cint", IntegerType, nullable = true) ::
+        StructField("cstr", StringType, nullable = true) :: Nil)), 1, None),
+    GetStructField($"a".struct(StructType(
+      StructField("c.str", StringType, nullable = true) ::
+        StructField("cint", IntegerType, nullable = true) :: Nil)), 0, None),
+    GetStructField($"a.b".struct(StructType(
+      StructField("cint1", IntegerType, nullable = true) ::
+        StructField("cint2", IntegerType, nullable = true) ::
+        StructField("cstr", StringType, nullable = true) :: Nil)), 2, None),
+    GetStructField($"a.b".struct(StructType(
+      StructField("c.str", StringType, nullable = true) :: Nil)), 0, None),
+    GetStructField(GetStructField($"a".struct(StructType(
+      StructField("cint1", IntegerType, nullable = true) ::
+        StructField("b", StructType(StructField("cstr", StringType, nullable = 
true) ::
+          StructField("cint2", IntegerType, nullable = true) :: Nil)) :: 
Nil)), 1, None), 0, None)
+  ).zip(Seq(
+    "cstr", // single level field
+    "`c.str`", // single level field that contains `dot` in name
+    "a.cstr", // two level nested field
+    "a.`c.str`", // two level nested field, and nested level contains `dot`
+    "`a.b`.cstr", // two level nested field, and top level contains `dot`
+    "`a.b`.`c.str`", // two level nested field, and both levels contain `dot`
+    "a.b.cstr" // three level nested field
+  ))
+
+  test("translate simple expression") { attrInts.zip(attrStrs)
+    .foreach { case ((attrInt, intColName), (attrStr, strColName)) =>
       testTranslateFilter(EqualTo(attrInt, 1),
         Some(new Predicate("=", Array(FieldReference(intColName), 
LiteralValue(1, IntegerType)))))
       testTranslateFilter(EqualTo(1, attrInt),
@@ -86,6 +116,199 @@ class DataSourceV2StrategySuite extends PlanTest with 
SharedSparkSession {
         Some(new Predicate("<=", Array(FieldReference(intColName), 
LiteralValue(1, IntegerType)))))
       testTranslateFilter(LessThanOrEqual(1, attrInt),
         Some(new Predicate(">=", Array(FieldReference(intColName), 
LiteralValue(1, IntegerType)))))
+
+      testTranslateFilter(IsNull(attrInt),
+        Some(new Predicate("IS_NULL", Array(FieldReference(intColName)))))
+      testTranslateFilter(IsNotNull(attrInt),
+        Some(new Predicate("IS_NOT_NULL", Array(FieldReference(intColName)))))
+
+      testTranslateFilter(InSet(attrInt, Set(1, 2, 3)),
+        Some(new Predicate("IN", Array(FieldReference(intColName),
+          LiteralValue(1, IntegerType), LiteralValue(2, IntegerType),
+          LiteralValue(3, IntegerType)))))
+
+      testTranslateFilter(In(attrInt, Seq(1, 2, 3)),
+        Some(new Predicate("IN", Array(FieldReference(intColName),
+          LiteralValue(1, IntegerType), LiteralValue(2, IntegerType),
+          LiteralValue(3, IntegerType)))))
+
+      // cint > 1 AND cint < 10
+      testTranslateFilter(And(
+        GreaterThan(attrInt, 1),
+        LessThan(attrInt, 10)),
+        Some(new V2And(
+          new Predicate(">", Array(FieldReference(intColName), LiteralValue(1, 
IntegerType))),
+          new Predicate("<", Array(FieldReference(intColName), 
LiteralValue(10, IntegerType))))))
+
+      // cint >= 8 OR cint <= 2
+      testTranslateFilter(Or(
+        GreaterThanOrEqual(attrInt, 8),
+        LessThanOrEqual(attrInt, 2)),
+        Some(new V2Or(
+          new Predicate(">=", Array(FieldReference(intColName), 
LiteralValue(8, IntegerType))),
+          new Predicate("<=", Array(FieldReference(intColName), 
LiteralValue(2, IntegerType))))))
+
+      testTranslateFilter(Not(GreaterThanOrEqual(attrInt, 8)),
+        Some(new V2Not(new Predicate(">=", Array(FieldReference(intColName),
+          LiteralValue(8, IntegerType))))))
+
+      testTranslateFilter(StartsWith(attrStr, "a"),
+        Some(new Predicate("STARTS_WITH", Array(FieldReference(strColName),
+          LiteralValue(UTF8String.fromString("a"), StringType)))))
+
+      testTranslateFilter(EndsWith(attrStr, "a"),
+        Some(new Predicate("ENDS_WITH", Array(FieldReference(strColName),
+          LiteralValue(UTF8String.fromString("a"), StringType)))))
+
+      testTranslateFilter(Contains(attrStr, "a"),
+        Some(new Predicate("CONTAINS", Array(FieldReference(strColName),
+          LiteralValue(UTF8String.fromString("a"), StringType)))))
+    }
+  }
+
+  test("translate complex expression") {
+    attrInts.foreach { case (attrInt, intColName) =>
+
+      // ABS(cint) - 2 <= 1
+      testTranslateFilter(LessThanOrEqual(
+        // Expressions are not supported
+        // Functions such as 'Abs' are not supported
+        Subtract(Abs(attrInt), 2), 1), None)
+
+      // (cint > 1 AND cint < 10) OR (cint > 50 AND cint < 100)
+      testTranslateFilter(Or(
+        And(
+          GreaterThan(attrInt, 1),
+          LessThan(attrInt, 10)
+        ),
+        And(
+          GreaterThan(attrInt, 50),
+          LessThan(attrInt, 100))),
+        Some(new V2Or(
+          new V2And(
+            new Predicate(">", Array(FieldReference(intColName), 
LiteralValue(1, IntegerType))),
+            new Predicate("<", Array(FieldReference(intColName), 
LiteralValue(10, IntegerType)))),
+          new V2And(
+            new Predicate(">", Array(FieldReference(intColName), 
LiteralValue(50, IntegerType))),
+            new Predicate("<", Array(FieldReference(intColName),
+              LiteralValue(100, IntegerType)))))
+        )
+      )
+
+      // (cint > 1 AND ABS(cint) < 10) OR (cint > 50 AND cint < 100)
+      testTranslateFilter(Or(
+        And(
+          GreaterThan(attrInt, 1),
+          // Functions such as 'Abs' are not supported
+          LessThan(Abs(attrInt), 10)
+        ),
+        And(
+          GreaterThan(attrInt, 50),
+          LessThan(attrInt, 100))), None)
+
+      // NOT ((cint <= 1 OR ABS(cint) >= 10) AND (cint <= 50 OR cint >= 100))
+      testTranslateFilter(Not(And(
+        Or(
+          LessThanOrEqual(attrInt, 1),
+          // Functions such as 'Abs' are not supported
+          GreaterThanOrEqual(Abs(attrInt), 10)
+        ),
+        Or(
+          LessThanOrEqual(attrInt, 50),
+          GreaterThanOrEqual(attrInt, 100)))), None)
+
+      // (cint = 1 OR cint = 10) OR (cint > 0 OR cint < -10)
+      testTranslateFilter(Or(
+        Or(
+          EqualTo(attrInt, 1),
+          EqualTo(attrInt, 10)
+        ),
+        Or(
+          GreaterThan(attrInt, 0),
+          LessThan(attrInt, -10))),
+        Some(new V2Or(
+          new V2Or(
+            new Predicate("=", Array(FieldReference(intColName), 
LiteralValue(1, IntegerType))),
+            new Predicate("=", Array(FieldReference(intColName), 
LiteralValue(10, IntegerType)))),
+          new V2Or(
+            new Predicate(">", Array(FieldReference(intColName), 
LiteralValue(0, IntegerType))),
+            new Predicate("<", Array(FieldReference(intColName), 
LiteralValue(-10, IntegerType)))))
+        )
+      )
+
+      // (cint = 1 OR ABS(cint) = 10) OR (cint > 0 OR cint < -10)
+      testTranslateFilter(Or(
+        Or(
+          EqualTo(attrInt, 1),
+          // Functions such as 'Abs' are not supported
+          EqualTo(Abs(attrInt), 10)
+        ),
+        Or(
+          GreaterThan(attrInt, 0),
+          LessThan(attrInt, -10))), None)
+
+      // In end-to-end testing, conjunctive predicates should have been split
+      // before reaching DataSourceStrategy.translateFilter.
+      // This is for unit-test purposes, to exercise each [[case]].
+      // (cint > 1 AND cint < 10) AND (cint = 6 AND cint IS NOT NULL)
+      testTranslateFilter(And(
+        And(
+          GreaterThan(attrInt, 1),
+          LessThan(attrInt, 10)
+        ),
+        And(
+          EqualTo(attrInt, 6),
+          IsNotNull(attrInt))),
+        Some(new V2And(
+          new V2And(
+            new Predicate(">", Array(FieldReference(intColName), 
LiteralValue(1, IntegerType))),
+            new Predicate("<", Array(FieldReference(intColName), 
LiteralValue(10, IntegerType)))),
+          new V2And(
+            new Predicate("=", Array(FieldReference(intColName), 
LiteralValue(6, IntegerType))),
+            new Predicate("IS_NOT_NULL", Array(FieldReference(intColName)))))
+        )
+      )
+
+      // (cint > 1 AND cint < 10) AND (ABS(cint) = 6 AND cint IS NOT NULL)
+      testTranslateFilter(And(
+        And(
+          GreaterThan(attrInt, 1),
+          LessThan(attrInt, 10)
+        ),
+        And(
+          // Functions such as 'Abs' are not supported
+          EqualTo(Abs(attrInt), 6),
+          IsNotNull(attrInt))), None)
+
+      // (cint > 1 OR cint < 10) AND (cint = 6 OR cint IS NOT NULL)
+      testTranslateFilter(And(
+        Or(
+          GreaterThan(attrInt, 1),
+          LessThan(attrInt, 10)
+        ),
+        Or(
+          EqualTo(attrInt, 6),
+          IsNotNull(attrInt))),
+        Some(new V2And(
+          new V2Or(
+            new Predicate(">", Array(FieldReference(intColName), 
LiteralValue(1, IntegerType))),
+            new Predicate("<", Array(FieldReference(intColName), 
LiteralValue(10, IntegerType)))),
+          new V2Or(
+            new Predicate("=", Array(FieldReference(intColName), 
LiteralValue(6, IntegerType))),
+            new Predicate("IS_NOT_NULL", Array(FieldReference(intColName)))))
+        )
+      )
+
+      // (cint > 1 OR cint < 10) AND (ABS(cint) = 6 OR cint IS NOT NULL)
+      testTranslateFilter(And(
+        Or(
+          GreaterThan(attrInt, 1),
+          LessThan(attrInt, 10)
+        ),
+        Or(
+          // Functions such as 'Abs' are not supported
+          EqualTo(Abs(attrInt), 6),
+          IsNotNull(attrInt))), None)
     }
   }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to