This is an automated email from the ASF dual-hosted git repository. dongjoon pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 5ae1a6b [SPARK-28052][SQL] Make `ArrayExists` follow the three-valued boolean logic. 5ae1a6b is described below commit 5ae1a6bf0dea9e74ced85686ef33a87cfa3e90c2 Author: Takuya UESHIN <ues...@databricks.com> AuthorDate: Sat Jun 15 10:48:06 2019 -0700 [SPARK-28052][SQL] Make `ArrayExists` follow the three-valued boolean logic. ## What changes were proposed in this pull request? Currently `ArrayExists` always returns boolean values (if the arguments are not `null`), but it should follow the three-valued boolean logic: - `true` if the predicate returns `true` for at least one element - otherwise, `null` if the predicate returns `null` for at least one element - otherwise, `false` This behavior change is made to match the behavior of Postgres' equivalent function `ANY/SOME (array)`: https://www.postgresql.org/docs/9.6/functions-comparisons.html#AEN21174 ## How was this patch tested? Modified tests and existing tests. Closes #24873 from ueshin/issues/SPARK-28052/fix_exists. 
Authored-by: Takuya UESHIN <ues...@databricks.com> Signed-off-by: Dongjoon Hyun <dh...@apple.com> --- docs/sql-migration-guide-upgrade.md | 2 + .../expressions/higherOrderFunctions.scala | 29 +++++++++++++- .../ReplaceNullWithFalseInPredicate.scala | 4 +- .../org/apache/spark/sql/internal/SQLConf.scala | 6 +++ .../expressions/HigherOrderFunctionsSuite.scala | 46 ++++++++++++++++------ .../ReplaceNullWithFalseInPredicateSuite.scala | 14 ++++++- .../apache/spark/sql/DataFrameFunctionsSuite.scala | 2 + ...laceNullWithFalseInPredicateEndToEndSuite.scala | 9 +++-- 8 files changed, 94 insertions(+), 18 deletions(-) diff --git a/docs/sql-migration-guide-upgrade.md b/docs/sql-migration-guide-upgrade.md index 44772cc..37be86f 100644 --- a/docs/sql-migration-guide-upgrade.md +++ b/docs/sql-migration-guide-upgrade.md @@ -139,6 +139,8 @@ license: | - Since Spark 3.0, we use a new protocol for fetching shuffle blocks, for external shuffle service users, we need to upgrade the server correspondingly. Otherwise, we'll get the error message `UnsupportedOperationException: Unexpected message: FetchShuffleBlocks`. If it is hard to upgrade the shuffle service right now, you can still use the old protocol by setting `spark.shuffle.useOldFetchProtocol` to `true`. + - Since Spark 3.0, a higher-order function `exists` follows the three-valued boolean logic, i.e., if the `predicate` returns any `null`s and no `true` is obtained, then `exists` will return `null` instead of `false`. For example, `exists(array(1, null, 3), x -> x % 2 == 0)` will be `null`. The previous behaviour can be restored by setting `spark.sql.legacy.arrayExistsFollowsThreeValuedLogic` to `false`. 
+ ## Upgrading from Spark SQL 2.4 to 2.4.1 - The value of `spark.executor.heartbeatInterval`, when specified without units like "30" rather than "30s", was diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index e6cc11d..b326e1c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, UnresolvedAttribute, UnresolvedException} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.array.ByteArrayMethods @@ -388,6 +389,10 @@ case class ArrayFilter( Examples: > SELECT _FUNC_(array(1, 2, 3), x -> x % 2 == 0); true + > SELECT _FUNC_(array(1, 2, 3), x -> x % 2 == 10); + false + > SELECT _FUNC_(array(1, null, 3), x -> x % 2 == 0); + NULL """, since = "2.4.0") case class ArrayExists( @@ -395,6 +400,16 @@ case class ArrayExists( function: Expression) extends ArrayBasedSimpleHigherOrderFunction with CodegenFallback { + private val followThreeValuedLogic = + SQLConf.get.getConf(SQLConf.LEGACY_ARRAY_EXISTS_FOLLOWS_THREE_VALUED_LOGIC) + + override def nullable: Boolean = + if (followThreeValuedLogic) { + super.nullable || function.nullable + } else { + super.nullable + } + override def dataType: DataType = BooleanType override def functionType: AbstractDataType = BooleanType @@ -410,15 +425,25 @@ case class ArrayExists( val arr = argumentValue.asInstanceOf[ArrayData] val f = functionForEval var exists = false + var foundNull = false var i = 0 while (i < arr.numElements 
&& !exists) { elementVar.value.set(arr.get(i, elementVar.dataType)) - if (f.eval(inputRow).asInstanceOf[Boolean]) { + val ret = f.eval(inputRow) + if (ret == null) { + foundNull = true + } else if (ret.asInstanceOf[Boolean]) { exists = true } i += 1 } - exists + if (exists) { + true + } else if (followThreeValuedLogic && foundNull) { + null + } else { + false + } } override def prettyName: String = "exists" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala index 689915a..b8edf98 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.expressions.{LambdaFunction, Literal, MapFi import org.apache.spark.sql.catalyst.expressions.Literal.FalseLiteral import org.apache.spark.sql.catalyst.plans.logical.{Filter, Join, LogicalPlan} import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.BooleanType import org.apache.spark.util.Utils @@ -63,7 +64,8 @@ object ReplaceNullWithFalseInPredicate extends Rule[LogicalPlan] { case af @ ArrayFilter(_, lf @ LambdaFunction(func, _, _)) => val newLambda = lf.copy(function = replaceNullWithFalse(func)) af.copy(function = newLambda) - case ae @ ArrayExists(_, lf @ LambdaFunction(func, _, _)) => + case ae @ ArrayExists(_, lf @ LambdaFunction(func, _, _)) + if !SQLConf.get.getConf(SQLConf.LEGACY_ARRAY_EXISTS_FOLLOWS_THREE_VALUED_LOGIC) => val newLambda = lf.copy(function = replaceNullWithFalse(func)) ae.copy(function = newLambda) case mf @ MapFilter(_, lf @ LambdaFunction(func, _, _)) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index b231f08..950d231 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1799,6 +1799,12 @@ object SQLConf { .doc("When true, the upcast will be loose and allows string to atomic types.") .booleanConf .createWithDefault(false) + + val LEGACY_ARRAY_EXISTS_FOLLOWS_THREE_VALUED_LOGIC = + buildConf("spark.sql.legacy.arrayExistsFollowsThreeValuedLogic") + .doc("When true, the ArrayExists will follow the three-valued boolean logic.") + .booleanConf + .createWithDefault(true) } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala index 03fb75e..1411be8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.util.ArrayBasedMapData +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -255,13 +255,26 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper val isEven: Expression => Expression = x => x % 2 === 0 val isNullOrOdd: Expression => Expression = x => x.isNull || x % 2 === 1 - - checkEvaluation(exists(ai0, isEven), true) - checkEvaluation(exists(ai0, isNullOrOdd), true) - checkEvaluation(exists(ai1, isEven), false) - checkEvaluation(exists(ai1, isNullOrOdd), true) - checkEvaluation(exists(ain, isEven), null) - 
checkEvaluation(exists(ain, isNullOrOdd), null) + val alwaysFalse: Expression => Expression = _ => Literal.FalseLiteral + val alwaysNull: Expression => Expression = _ => Literal(null, BooleanType) + + for (followThreeValuedLogic <- Seq(false, true)) { + withSQLConf(SQLConf.LEGACY_ARRAY_EXISTS_FOLLOWS_THREE_VALUED_LOGIC.key + -> followThreeValuedLogic.toString) { + checkEvaluation(exists(ai0, isEven), true) + checkEvaluation(exists(ai0, isNullOrOdd), true) + checkEvaluation(exists(ai0, alwaysFalse), false) + checkEvaluation(exists(ai0, alwaysNull), if (followThreeValuedLogic) null else false) + checkEvaluation(exists(ai1, isEven), if (followThreeValuedLogic) null else false) + checkEvaluation(exists(ai1, isNullOrOdd), true) + checkEvaluation(exists(ai1, alwaysFalse), false) + checkEvaluation(exists(ai1, alwaysNull), if (followThreeValuedLogic) null else false) + checkEvaluation(exists(ain, isEven), null) + checkEvaluation(exists(ain, isNullOrOdd), null) + checkEvaluation(exists(ain, alwaysFalse), null) + checkEvaluation(exists(ain, alwaysNull), null) + } + } val as0 = Literal.create(Seq("a0", "b1", "a2", "c3"), ArrayType(StringType, containsNull = false)) @@ -270,9 +283,20 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper val startsWithA: Expression => Expression = x => x.startsWith("a") - checkEvaluation(exists(as0, startsWithA), true) - checkEvaluation(exists(as1, startsWithA), false) - checkEvaluation(exists(asn, startsWithA), null) + for (followThreeValuedLogic <- Seq(false, true)) { + withSQLConf(SQLConf.LEGACY_ARRAY_EXISTS_FOLLOWS_THREE_VALUED_LOGIC.key + -> followThreeValuedLogic.toString) { + checkEvaluation(exists(as0, startsWithA), true) + checkEvaluation(exists(as0, alwaysFalse), false) + checkEvaluation(exists(as0, alwaysNull), if (followThreeValuedLogic) null else false) + checkEvaluation(exists(as1, startsWithA), if (followThreeValuedLogic) null else false) + checkEvaluation(exists(as1, alwaysFalse), false) + 
checkEvaluation(exists(as1, alwaysNull), if (followThreeValuedLogic) null else false) + checkEvaluation(exists(asn, startsWithA), null) + checkEvaluation(exists(asn, alwaysFalse), null) + checkEvaluation(exists(asn, alwaysNull), null) + } + } val aai = Literal.create(Seq(Seq(1, 2, 3), null, Seq(4, 5)), ArrayType(ArrayType(IntegerType, containsNull = false), containsNull = true)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala index 748075b..b692c3f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLite import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{BooleanType, IntegerType} class ReplaceNullWithFalseInPredicateSuite extends PlanTest { @@ -313,7 +314,18 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { } test("replace nulls in lambda function of ArrayExists") { - testHigherOrderFunc('a, ArrayExists, Seq(lv('e))) + withSQLConf(SQLConf.LEGACY_ARRAY_EXISTS_FOLLOWS_THREE_VALUED_LOGIC.key -> "true") { + val lambdaArgs = Seq(lv('e)) + val cond = GreaterThan(lambdaArgs.last, Literal(0)) + val lambda = LambdaFunction( + function = If(cond, Literal(null, BooleanType), TrueLiteral), + arguments = lambdaArgs) + val expr = ArrayExists('a, lambda) + testProjection(originalExpr = expr, expectedExpr = expr) + } + 
withSQLConf(SQLConf.LEGACY_ARRAY_EXISTS_FOLLOWS_THREE_VALUED_LOGIC.key -> "false") { + testHigherOrderFunc('a, ArrayExists, Seq(lv('e))) + } } test("replace nulls in lambda function of MapFilter") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index e5c2de9..3f16f64 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -2246,6 +2246,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { test("exists function - array for primitive type containing null") { val df = Seq[Seq[Integer]]( Seq(1, 9, 8, null, 7), + Seq(1, 3, 5), Seq(5, null, null, 9, 7, null), Seq.empty, null @@ -2256,6 +2257,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Seq( Row(true), Row(false), + Row(null), Row(false), Row(null))) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ReplaceNullWithFalseInPredicateEndToEndSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ReplaceNullWithFalseInPredicateEndToEndSuite.scala index 0f84b0c..1729c3c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ReplaceNullWithFalseInPredicateEndToEndSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ReplaceNullWithFalseInPredicateEndToEndSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql import org.apache.spark.sql.catalyst.expressions.{CaseWhen, If, Literal} import org.apache.spark.sql.execution.LocalTableScanExec import org.apache.spark.sql.functions.{lit, when} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.BooleanType @@ -94,9 +95,11 @@ class ReplaceNullWithFalseInPredicateEndToEndSuite extends QueryTest with Shared val df2 = spark.table("t2") // ArrayExists - val q1 = df1.selectExpr("EXISTS(a, e -> IF(e is 
null, null, true))") - checkAnswer(q1, Row(true) :: Nil) - assertNoLiteralNullInPlan(q1) + withSQLConf(SQLConf.LEGACY_ARRAY_EXISTS_FOLLOWS_THREE_VALUED_LOGIC.key -> "false") { + val q1 = df1.selectExpr("EXISTS(a, e -> IF(e is null, null, true))") + checkAnswer(q1, Row(true) :: Nil) + assertNoLiteralNullInPlan(q1) + } // ArrayFilter val q2 = df1.selectExpr("FILTER(a, e -> IF(e is null, null, true))") --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org