Repository: spark Updated Branches: refs/heads/branch-2.1 99891e56e -> c2876bfbf
[SPARK-17981][SPARK-17957][SQL] Fix Incorrect Nullability Setting to False in FilterExec ### What changes were proposed in this pull request? When `FilterExec` contains `isNotNull`, which can be either inferred and pushed down or specified by users, we set the nullability of the involved columns to false if the top-layer expression is null-intolerant. However, this is not correct: if the top-layer expression is not a leaf expression, it can still tolerate nulls when it has null-tolerant child expressions. For example, consider `cast(coalesce(a#5, a#15) as double)`. Although `cast` is a null-intolerant expression, `coalesce` is obviously null-tolerant, so the whole expression can turn a null input into a non-null output. When the nullability is wrong, we can generate incorrect results in various cases. For example, ``` Scala val df1 = Seq((1, 2), (2, 3)).toDF("a", "b") val df2 = Seq((2, 5), (3, 4)).toDF("a", "c") val joinedDf = df1.join(df2, Seq("a"), "outer").na.fill(0) val df3 = Seq((3, 1)).toDF("a", "d") joinedDf.join(df3, "a").show ``` The optimized plan looks like this: ``` Project [a#29, b#30, c#31, d#42] +- Join Inner, (a#29 = a#41) :- Project [cast(coalesce(cast(coalesce(a#5, a#15) as double), 0.0) as int) AS a#29, cast(coalesce(cast(b#6 as double), 0.0) as int) AS b#30, cast(coalesce(cast(c#16 as double), 0.0) as int) AS c#31] : +- Filter isnotnull(cast(coalesce(cast(coalesce(a#5, a#15) as double), 0.0) as int)) : +- Join FullOuter, (a#5 = a#15) : :- LocalRelation [a#5, b#6] : +- LocalRelation [a#15, c#16] +- LocalRelation [a#41, d#42] ``` Without the fix, the query returns an empty result. With the fix, it returns the correct answer: ``` +---+---+---+---+ | a| b| c| d| +---+---+---+---+ | 3| 0| 4| 1| +---+---+---+---+ ``` ### How was this patch tested? Added test cases to verify the nullability changes in FilterExec. Also added a test case that reproduces the reported incorrect result. Author: gatorsmile <gatorsm...@gmail.com> Closes #15523 from gatorsmile/nullabilityFilterExec. 
(cherry picked from commit 66a99f4a411ee7dc94ff1070a8fd6865fd004093) Signed-off-by: Herman van Hovell <hvanhov...@databricks.com> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/c2876bfb Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/c2876bfb Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/c2876bfb Branch: refs/heads/branch-2.1 Commit: c2876bfbf06fe1057c4236128d41782c61685c53 Parents: 99891e5 Author: gatorsmile <gatorsm...@gmail.com> Authored: Thu Nov 3 16:35:36 2016 +0100 Committer: Herman van Hovell <hvanhov...@databricks.com> Committed: Thu Nov 3 16:35:49 2016 +0100 ---------------------------------------------------------------------- .../sql/execution/basicPhysicalOperators.scala | 8 ++- .../org/apache/spark/sql/DataFrameSuite.scala | 74 +++++++++++++++++++- 2 files changed, 79 insertions(+), 3 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/c2876bfb/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 32133f5..e6f1de5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -90,7 +90,13 @@ case class FilterExec(condition: Expression, child: SparkPlan) // Split out all the IsNotNulls from condition. 
private val (notNullPreds, otherPreds) = splitConjunctivePredicates(condition).partition { - case IsNotNull(a: NullIntolerant) if a.references.subsetOf(child.outputSet) => true + case IsNotNull(a) => isNullIntolerant(a) && a.references.subsetOf(child.outputSet) + case _ => false + } + + // If one expression and its children are null intolerant, it is null intolerant. + private def isNullIntolerant(expr: Expression): Boolean = expr match { + case e: NullIntolerant => e.children.forall(isNullIntolerant) case _ => false } http://git-wip-us.apache.org/repos/asf/spark/blob/c2876bfb/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 33b3b78..f5bc878 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -28,8 +28,8 @@ import org.scalatest.Matchers._ import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project, Union} -import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.catalyst.plans.logical.{Filter, OneRowRelation, Project, Union} +import org.apache.spark.sql.execution.{FilterExec, QueryExecution} import org.apache.spark.sql.execution.aggregate.HashAggregateExec import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchange} import org.apache.spark.sql.functions._ @@ -1635,6 +1635,76 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } } + private def verifyNullabilityInFilterExec( + df: DataFrame, + expr: String, + expectedNonNullableColumns: Seq[String]): Unit = { + val dfWithFilter = df.where(s"isnotnull($expr)").selectExpr(expr) + // In the 
logical plan, all the output columns of input dataframe are nullable + dfWithFilter.queryExecution.optimizedPlan.collect { + case e: Filter => assert(e.output.forall(_.nullable)) + } + + dfWithFilter.queryExecution.executedPlan.collect { + // When the child expression in isnotnull is null-intolerant (i.e. any null input will + // result in null output), the involved columns are converted to not nullable; + // otherwise, no change should be made. + case e: FilterExec => + assert(e.output.forall { o => + if (expectedNonNullableColumns.contains(o.name)) !o.nullable else o.nullable + }) + } + } + + test("SPARK-17957: no change on nullability in FilterExec output") { + val df = sparkContext.parallelize(Seq( + null.asInstanceOf[java.lang.Integer] -> new java.lang.Integer(3), + new java.lang.Integer(1) -> null.asInstanceOf[java.lang.Integer], + new java.lang.Integer(2) -> new java.lang.Integer(4))).toDF() + + verifyNullabilityInFilterExec(df, + expr = "Rand()", expectedNonNullableColumns = Seq.empty[String]) + verifyNullabilityInFilterExec(df, + expr = "coalesce(_1, _2)", expectedNonNullableColumns = Seq.empty[String]) + verifyNullabilityInFilterExec(df, + expr = "coalesce(_1, 0) + Rand()", expectedNonNullableColumns = Seq.empty[String]) + verifyNullabilityInFilterExec(df, + expr = "cast(coalesce(cast(coalesce(_1, _2) as double), 0.0) as int)", + expectedNonNullableColumns = Seq.empty[String]) + } + + test("SPARK-17957: set nullability to false in FilterExec output") { + val df = sparkContext.parallelize(Seq( + null.asInstanceOf[java.lang.Integer] -> new java.lang.Integer(3), + new java.lang.Integer(1) -> null.asInstanceOf[java.lang.Integer], + new java.lang.Integer(2) -> new java.lang.Integer(4))).toDF() + + verifyNullabilityInFilterExec(df, + expr = "_1 + _2 * 3", expectedNonNullableColumns = Seq("_1", "_2")) + verifyNullabilityInFilterExec(df, + expr = "_1 + _2", expectedNonNullableColumns = Seq("_1", "_2")) + verifyNullabilityInFilterExec(df, + expr = "_1", 
expectedNonNullableColumns = Seq("_1")) + // `constructIsNotNullConstraints` infers the IsNotNull(_2) from IsNotNull(_2 + Rand()) + // Thus, we are able to set nullability of _2 to false. + // If IsNotNull(_2) is not given from `constructIsNotNullConstraints`, the impl of + // isNullIntolerant in `FilterExec` needs an update for more advanced inference. + verifyNullabilityInFilterExec(df, + expr = "_2 + Rand()", expectedNonNullableColumns = Seq("_2")) + verifyNullabilityInFilterExec(df, + expr = "_2 * 3 + coalesce(_1, 0)", expectedNonNullableColumns = Seq("_2")) + verifyNullabilityInFilterExec(df, + expr = "cast((_1 + _2) as boolean)", expectedNonNullableColumns = Seq("_1", "_2")) + } + + test("SPARK-17957: outer join + na.fill") { + val df1 = Seq((1, 2), (2, 3)).toDF("a", "b") + val df2 = Seq((2, 5), (3, 4)).toDF("a", "c") + val joinedDf = df1.join(df2, Seq("a"), "outer").na.fill(0) + val df3 = Seq((3, 1)).toDF("a", "d") + checkAnswer(joinedDf.join(df3, "a"), Row(3, 0, 4, 1)) + } + test("SPARK-17123: Performing set operations that combine non-scala native types") { val dates = Seq( (new Date(0), BigDecimal.valueOf(1), new Timestamp(2)), --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org