This is an automated email from the ASF dual-hosted git repository. dongjoon pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 6b2492628c6 [SPARK-39857][SQL] V2ExpressionBuilder uses the wrong LiteralValue data type for In predicate 6b2492628c6 is described below commit 6b2492628c60fc1c4f70889c71cc3a9403a0adbc Author: huaxingao <huaxin_...@apple.com> AuthorDate: Mon Jul 25 08:11:19 2022 -0700 [SPARK-39857][SQL] V2ExpressionBuilder uses the wrong LiteralValue data type for In predicate ### What changes were proposed in this pull request? When building V2 `In` Predicate in `V2ExpressionBuilder`, `InSet.dataType` (which is `BooleanType`) is used to build the `LiteralValue`, `InSet.child.dataType` should be used instead. ### Why are the changes needed? bug fix ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? new test Closes #37271 from huaxingao/inset. Authored-by: huaxingao <huaxin_...@apple.com> Signed-off-by: Dongjoon Hyun <dongj...@apache.org> --- .../sql/catalyst/util/V2ExpressionBuilder.scala | 4 +- .../datasources/v2/DataSourceV2StrategySuite.scala | 229 ++++++++++++++++++++- 2 files changed, 228 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala index 70c85def45d..07d681a6616 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala @@ -52,10 +52,10 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) { } else { Some(ref) } - case in @ InSet(child, hset) => + case InSet(child, hset) => generateExpression(child).map { v => val children = - (v +: hset.toSeq.map(elem => LiteralValue(elem, in.dataType))).toArray[V2Expression] + (v +: hset.toSeq.map(elem => LiteralValue(elem, child.dataType))).toArray[V2Expression] new V2Predicate("IN", children) } // Because we only 
convert In to InSet in Optimizer when there are more than certain diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StrategySuite.scala index c3f51bed269..5fefcadca3e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StrategySuite.scala @@ -21,9 +21,10 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.connector.expressions.{FieldReference, LiteralValue} -import org.apache.spark.sql.connector.expressions.filter.Predicate +import org.apache.spark.sql.connector.expressions.filter.{And => V2And, Not => V2Not, Or => V2Or, Predicate} import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.{BooleanType, IntegerType, StringType, StructField, StructType} +import org.apache.spark.unsafe.types.UTF8String class DataSourceV2StrategySuite extends PlanTest with SharedSparkSession { val attrInts = Seq( @@ -55,8 +56,37 @@ class DataSourceV2StrategySuite extends PlanTest with SharedSparkSession { "a.b.cint" // three level nested field )) - test("SPARK-39784: translate binary expression") { attrInts - .foreach { case (attrInt, intColName) => + val attrStrs = Seq( + $"cstr".string, + $"c.str".string, + GetStructField($"a".struct(StructType( + StructField("cint", IntegerType, nullable = true) :: + StructField("cstr", StringType, nullable = true) :: Nil)), 1, None), + GetStructField($"a".struct(StructType( + StructField("c.str", StringType, nullable = true) :: + StructField("cint", IntegerType, nullable = true) :: Nil)), 0, None), + GetStructField($"a.b".struct(StructType( + StructField("cint1", IntegerType, nullable = true) 
:: + StructField("cint2", IntegerType, nullable = true) :: + StructField("cstr", StringType, nullable = true) :: Nil)), 2, None), + GetStructField($"a.b".struct(StructType( + StructField("c.str", StringType, nullable = true) :: Nil)), 0, None), + GetStructField(GetStructField($"a".struct(StructType( + StructField("cint1", IntegerType, nullable = true) :: + StructField("b", StructType(StructField("cstr", StringType, nullable = true) :: + StructField("cint2", IntegerType, nullable = true) :: Nil)) :: Nil)), 1, None), 0, None) + ).zip(Seq( + "cstr", + "`c.str`", // single level field that contains `dot` in name + "a.cstr", // two level nested field + "a.`c.str`", // two level nested field, and nested level contains `dot` + "`a.b`.cstr", // two level nested field, and top level contains `dot` + "`a.b`.`c.str`", // two level nested field, and both levels contain `dot` + "a.b.cstr" // three level nested field + )) + + test("translate simple expression") { attrInts.zip(attrStrs) + .foreach { case ((attrInt, intColName), (attrStr, strColName)) => testTranslateFilter(EqualTo(attrInt, 1), Some(new Predicate("=", Array(FieldReference(intColName), LiteralValue(1, IntegerType))))) testTranslateFilter(EqualTo(1, attrInt), @@ -86,6 +116,199 @@ class DataSourceV2StrategySuite extends PlanTest with SharedSparkSession { Some(new Predicate("<=", Array(FieldReference(intColName), LiteralValue(1, IntegerType))))) testTranslateFilter(LessThanOrEqual(1, attrInt), Some(new Predicate(">=", Array(FieldReference(intColName), LiteralValue(1, IntegerType))))) + + testTranslateFilter(IsNull(attrInt), + Some(new Predicate("IS_NULL", Array(FieldReference(intColName))))) + testTranslateFilter(IsNotNull(attrInt), + Some(new Predicate("IS_NOT_NULL", Array(FieldReference(intColName))))) + + testTranslateFilter(InSet(attrInt, Set(1, 2, 3)), + Some(new Predicate("IN", Array(FieldReference(intColName), + LiteralValue(1, IntegerType), LiteralValue(2, IntegerType), + LiteralValue(3, IntegerType))))) + + 
testTranslateFilter(In(attrInt, Seq(1, 2, 3)), + Some(new Predicate("IN", Array(FieldReference(intColName), + LiteralValue(1, IntegerType), LiteralValue(2, IntegerType), + LiteralValue(3, IntegerType))))) + + // cint > 1 AND cint < 10 + testTranslateFilter(And( + GreaterThan(attrInt, 1), + LessThan(attrInt, 10)), + Some(new V2And( + new Predicate(">", Array(FieldReference(intColName), LiteralValue(1, IntegerType))), + new Predicate("<", Array(FieldReference(intColName), LiteralValue(10, IntegerType)))))) + + // cint >= 8 OR cint <= 2 + testTranslateFilter(Or( + GreaterThanOrEqual(attrInt, 8), + LessThanOrEqual(attrInt, 2)), + Some(new V2Or( + new Predicate(">=", Array(FieldReference(intColName), LiteralValue(8, IntegerType))), + new Predicate("<=", Array(FieldReference(intColName), LiteralValue(2, IntegerType)))))) + + testTranslateFilter(Not(GreaterThanOrEqual(attrInt, 8)), + Some(new V2Not(new Predicate(">=", Array(FieldReference(intColName), + LiteralValue(8, IntegerType)))))) + + testTranslateFilter(StartsWith(attrStr, "a"), + Some(new Predicate("STARTS_WITH", Array(FieldReference(strColName), + LiteralValue(UTF8String.fromString("a"), StringType))))) + + testTranslateFilter(EndsWith(attrStr, "a"), + Some(new Predicate("ENDS_WITH", Array(FieldReference(strColName), + LiteralValue(UTF8String.fromString("a"), StringType))))) + + testTranslateFilter(Contains(attrStr, "a"), + Some(new Predicate("CONTAINS", Array(FieldReference(strColName), + LiteralValue(UTF8String.fromString("a"), StringType))))) + } + } + + test("translate complex expression") { + attrInts.foreach { case (attrInt, intColName) => + + // ABS(cint) - 2 <= 1 + testTranslateFilter(LessThanOrEqual( + // Expressions are not supported + // Functions such as 'Abs' are not supported + Subtract(Abs(attrInt), 2), 1), None) + + // (cint > 1 AND cint < 10) OR (cint > 50 AND cint < 100) + testTranslateFilter(Or( + And( + GreaterThan(attrInt, 1), + LessThan(attrInt, 10) + ), + And( + GreaterThan(attrInt, 50), + 
LessThan(attrInt, 100))), + Some(new V2Or( + new V2And( + new Predicate(">", Array(FieldReference(intColName), LiteralValue(1, IntegerType))), + new Predicate("<", Array(FieldReference(intColName), LiteralValue(10, IntegerType)))), + new V2And( + new Predicate(">", Array(FieldReference(intColName), LiteralValue(50, IntegerType))), + new Predicate("<", Array(FieldReference(intColName), + LiteralValue(100, IntegerType))))) + ) + ) + + // (cint > 1 AND ABS(cint) < 10) OR (cint > 50 AND cint < 100) + testTranslateFilter(Or( + And( + GreaterThan(attrInt, 1), + // Functions such as 'Abs' are not supported + LessThan(Abs(attrInt), 10) + ), + And( + GreaterThan(attrInt, 50), + LessThan(attrInt, 100))), None) + + // NOT ((cint <= 1 OR ABS(cint) >= 10) AND (cint <= 50 OR cint >= 100)) + testTranslateFilter(Not(And( + Or( + LessThanOrEqual(attrInt, 1), + // Functions such as 'Abs' are not supported + GreaterThanOrEqual(Abs(attrInt), 10) + ), + Or( + LessThanOrEqual(attrInt, 50), + GreaterThanOrEqual(attrInt, 100)))), None) + + // (cint = 1 OR cint = 10) OR (cint > 0 OR cint < -10) + testTranslateFilter(Or( + Or( + EqualTo(attrInt, 1), + EqualTo(attrInt, 10) + ), + Or( + GreaterThan(attrInt, 0), + LessThan(attrInt, -10))), + Some(new V2Or( + new V2Or( + new Predicate("=", Array(FieldReference(intColName), LiteralValue(1, IntegerType))), + new Predicate("=", Array(FieldReference(intColName), LiteralValue(10, IntegerType)))), + new V2Or( + new Predicate(">", Array(FieldReference(intColName), LiteralValue(0, IntegerType))), + new Predicate("<", Array(FieldReference(intColName), LiteralValue(-10, IntegerType))))) + ) + ) + + // (cint = 1 OR ABS(cint) = 10) OR (cint > 0 OR cint < -10) + testTranslateFilter(Or( + Or( + EqualTo(attrInt, 1), + // Functions such as 'Abs' are not supported + EqualTo(Abs(attrInt), 10) + ), + Or( + GreaterThan(attrInt, 0), + LessThan(attrInt, -10))), None) + + // In end-to-end testing, conjunctive predicate should have been split + // before reaching 
DataSourceStrategy.translateFilter. + // This is for UT purpose to test each [[case]]. + // (cint > 1 AND cint < 10) AND (cint = 6 AND cint IS NOT NULL) + testTranslateFilter(And( + And( + GreaterThan(attrInt, 1), + LessThan(attrInt, 10) + ), + And( + EqualTo(attrInt, 6), + IsNotNull(attrInt))), + Some(new V2And( + new V2And( + new Predicate(">", Array(FieldReference(intColName), LiteralValue(1, IntegerType))), + new Predicate("<", Array(FieldReference(intColName), LiteralValue(10, IntegerType)))), + new V2And( + new Predicate("=", Array(FieldReference(intColName), LiteralValue(6, IntegerType))), + new Predicate("IS_NOT_NULL", Array(FieldReference(intColName))))) + ) + ) + + // (cint > 1 AND cint < 10) AND (ABS(cint) = 6 AND cint IS NOT NULL) + testTranslateFilter(And( + And( + GreaterThan(attrInt, 1), + LessThan(attrInt, 10) + ), + And( + // Functions such as 'Abs' are not supported + EqualTo(Abs(attrInt), 6), + IsNotNull(attrInt))), None) + + // (cint > 1 OR cint < 10) AND (cint = 6 OR cint IS NOT NULL) + testTranslateFilter(And( + Or( + GreaterThan(attrInt, 1), + LessThan(attrInt, 10) + ), + Or( + EqualTo(attrInt, 6), + IsNotNull(attrInt))), + Some(new V2And( + new V2Or( + new Predicate(">", Array(FieldReference(intColName), LiteralValue(1, IntegerType))), + new Predicate("<", Array(FieldReference(intColName), LiteralValue(10, IntegerType)))), + new V2Or( + new Predicate("=", Array(FieldReference(intColName), LiteralValue(6, IntegerType))), + new Predicate("IS_NOT_NULL", Array(FieldReference(intColName))))) + ) + ) + + // (cint > 1 OR cint < 10) AND (ABS(cint) = 6 OR cint IS NOT NULL) + testTranslateFilter(And( + Or( + GreaterThan(attrInt, 1), + LessThan(attrInt, 10) + ), + Or( + // Functions such as 'Abs' are not supported + EqualTo(Abs(attrInt), 6), + IsNotNull(attrInt))), None) } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: 
commits-h...@spark.apache.org