This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch branch-2.4 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-2.4 by this push: new 3dbe065 [SPARK-29213][SQL] Generate extra IsNotNull predicate in FilterExec 3dbe065 is described below commit 3dbe06561c3d645182b1b512ae8d545056b7613b Author: Wang Shuo <wangshuo...@gmail.com> AuthorDate: Fri Sep 27 15:14:17 2019 +0800 [SPARK-29213][SQL] Generate extra IsNotNull predicate in FilterExec Currently, `FilterExec` decides attribute nullability differently when computing its output and when generating null checks. As a result, some nullable attributes can mistakenly be treated as not nullable. In `FilterExec.output`, an attribute is marked as nullable or not by looking up its `exprId` in `notNullAttributes`: ``` a.nullable && notNullAttributes.contains(a.exprId) ``` But in `FilterExec.doConsume`, whether a `nullCheck` is generated for a predicate depends on whether there is a semantically equal `IsNotNull` predicate: ``` val nullChecks = c.references.map { r => val idx = notNullPreds.indexWhere { n => n.asInstanceOf[IsNotNull].child.semanticEquals(r)} if (idx != -1 && !generatedIsNotNullChecks(idx)) { generatedIsNotNullChecks(idx) = true // Use the child's output. The nullability is what the child produced. 
genPredicate(notNullPreds(idx), input, child.output) } else { "" } }.mkString("\n").trim ``` An NPE is thrown when running the SQL below: ``` sql("create table table1(x string)") sql("create table table2(x bigint)") sql("create table table3(x string)") sql("insert into table2 select null as x") sql( """ |select t1.x |from ( | select x from table1) t1 |left join ( | select x from ( | select x from table2 | union all | select substr(x,5) x from table3 | ) a | where length(x)>0 |) t3 |on t1.x=t3.x """.stripMargin).collect() ``` The NPE stack trace: ``` java.lang.NullPointerException at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage2.processNext(generated.java:40) at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43) at org.apache.spark.sql.execution.WholeStageCodegenExec$$anon$1.hasNext(WholeStageCodegenExec.scala:726) at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:458) at org.apache.spark.shuffle.sort.BypassMergeSortShuffleWriter.write(BypassMergeSortShuffleWriter.java:135) at org.apache.spark.shuffle.ShuffleWriteProcessor.write(ShuffleWriteProcessor.scala:59) at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:94) at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:52) at org.apache.spark.scheduler.Task.run(Task.scala:127) at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:449) at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1377) at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:452) at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149) at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624) at java.lang.Thread.run(Thread.java:748) ``` The generated code: ``` == Subtree 4 / 5 == *(2) Project [cast(x#7L as string) AS x#9] +- *(2) Filter ((length(cast(x#7L as string)) > 0) AND isnotnull(cast(x#7L as string))) +- Scan hive 
default.table2 [x#7L], HiveTableRelation `default`.`table2`, org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe, [x#7L] Generated code: /* 001 */ public Object generate(Object[] references) { /* 002 */ return new GeneratedIteratorForCodegenStage2(references); /* 003 */ } /* 004 */ /* 005 */ // codegenStageId=2 /* 006 */ final class GeneratedIteratorForCodegenStage2 extends org.apache.spark.sql.execution.BufferedRowIterator { /* 007 */ private Object[] references; /* 008 */ private scala.collection.Iterator[] inputs; /* 009 */ private scala.collection.Iterator inputadapter_input_0; /* 010 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[] filter_mutableStateArray_0 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[2]; /* 011 */ /* 012 */ public GeneratedIteratorForCodegenStage2(Object[] references) { /* 013 */ this.references = references; /* 014 */ } /* 015 */ /* 016 */ public void init(int index, scala.collection.Iterator[] inputs) { /* 017 */ partitionIndex = index; /* 018 */ this.inputs = inputs; /* 019 */ inputadapter_input_0 = inputs[0]; /* 020 */ filter_mutableStateArray_0[0] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 0); /* 021 */ filter_mutableStateArray_0[1] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 32); /* 022 */ /* 023 */ } /* 024 */ /* 025 */ protected void processNext() throws java.io.IOException { /* 026 */ while ( inputadapter_input_0.hasNext()) { /* 027 */ InternalRow inputadapter_row_0 = (InternalRow) inputadapter_input_0.next(); /* 028 */ /* 029 */ do { /* 030 */ boolean inputadapter_isNull_0 = inputadapter_row_0.isNullAt(0); /* 031 */ long inputadapter_value_0 = inputadapter_isNull_0 ? 
/* 032 */ -1L : (inputadapter_row_0.getLong(0)); /* 033 */ /* 034 */ boolean filter_isNull_2 = inputadapter_isNull_0; /* 035 */ UTF8String filter_value_2 = null; /* 036 */ if (!inputadapter_isNull_0) { /* 037 */ filter_value_2 = UTF8String.fromString(String.valueOf(inputadapter_value_0)); /* 038 */ } /* 039 */ int filter_value_1 = -1; /* 040 */ filter_value_1 = (filter_value_2).numChars(); /* 041 */ /* 042 */ boolean filter_value_0 = false; /* 043 */ filter_value_0 = filter_value_1 > 0; /* 044 */ if (!filter_value_0) continue; /* 045 */ /* 046 */ boolean filter_isNull_6 = inputadapter_isNull_0; /* 047 */ UTF8String filter_value_6 = null; /* 048 */ if (!inputadapter_isNull_0) { /* 049 */ filter_value_6 = UTF8String.fromString(String.valueOf(inputadapter_value_0)); /* 050 */ } /* 051 */ if (!(!filter_isNull_6)) continue; /* 052 */ /* 053 */ ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(1); /* 054 */ /* 055 */ boolean project_isNull_0 = false; /* 056 */ UTF8String project_value_0 = null; /* 057 */ if (!false) { /* 058 */ project_value_0 = UTF8String.fromString(String.valueOf(inputadapter_value_0)); /* 059 */ } /* 060 */ filter_mutableStateArray_0[1].reset(); /* 061 */ /* 062 */ filter_mutableStateArray_0[1].zeroOutNullBytes(); /* 063 */ /* 064 */ if (project_isNull_0) { /* 065 */ filter_mutableStateArray_0[1].setNullAt(0); /* 066 */ } else { /* 067 */ filter_mutableStateArray_0[1].write(0, project_value_0); /* 068 */ } /* 069 */ append((filter_mutableStateArray_0[1].getRow())); /* 070 */ /* 071 */ } while(false); /* 072 */ if (shouldStop()) return; /* 073 */ } /* 074 */ } /* 075 */ /* 076 */ } ``` This PR proposes to use semantic comparison both in `FilterExec.output` and `FilterExec.doConsume` for nullable attribute. 
With this PR, the generated code snippet is below: ``` == Subtree 2 / 5 == *(3) Project [substring(x#8, 5, 2147483647) AS x#5] +- *(3) Filter ((length(substring(x#8, 5, 2147483647)) > 0) AND isnotnull(substring(x#8, 5, 2147483647))) +- Scan hive default.table3 [x#8], HiveTableRelation `default`.`table3`, org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe, [x#8] Generated code: /* 001 */ public Object generate(Object[] references) { /* 002 */ return new GeneratedIteratorForCodegenStage3(references); /* 003 */ } /* 004 */ /* 005 */ // codegenStageId=3 /* 006 */ final class GeneratedIteratorForCodegenStage3 extends org.apache.spark.sql.execution.BufferedRowIterator { /* 007 */ private Object[] references; /* 008 */ private scala.collection.Iterator[] inputs; /* 009 */ private scala.collection.Iterator inputadapter_input_0; /* 010 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[] filter_mutableStateArray_0 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[2]; /* 011 */ /* 012 */ public GeneratedIteratorForCodegenStage3(Object[] references) { /* 013 */ this.references = references; /* 014 */ } /* 015 */ /* 016 */ public void init(int index, scala.collection.Iterator[] inputs) { /* 017 */ partitionIndex = index; /* 018 */ this.inputs = inputs; /* 019 */ inputadapter_input_0 = inputs[0]; /* 020 */ filter_mutableStateArray_0[0] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 32); /* 021 */ filter_mutableStateArray_0[1] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 32); /* 022 */ /* 023 */ } /* 024 */ /* 025 */ protected void processNext() throws java.io.IOException { /* 026 */ while ( inputadapter_input_0.hasNext()) { /* 027 */ InternalRow inputadapter_row_0 = (InternalRow) inputadapter_input_0.next(); /* 028 */ /* 029 */ do { /* 030 */ boolean inputadapter_isNull_0 = inputadapter_row_0.isNullAt(0); /* 031 */ UTF8String inputadapter_value_0 = inputadapter_isNull_0 ? 
/* 032 */ null : (inputadapter_row_0.getUTF8String(0)); /* 033 */ /* 034 */ boolean filter_isNull_0 = true; /* 035 */ boolean filter_value_0 = false; /* 036 */ boolean filter_isNull_2 = true; /* 037 */ UTF8String filter_value_2 = null; /* 038 */ /* 039 */ if (!inputadapter_isNull_0) { /* 040 */ filter_isNull_2 = false; // resultCode could change nullability. /* 041 */ filter_value_2 = inputadapter_value_0.substringSQL(5, 2147483647); /* 042 */ /* 043 */ } /* 044 */ boolean filter_isNull_1 = filter_isNull_2; /* 045 */ int filter_value_1 = -1; /* 046 */ /* 047 */ if (!filter_isNull_2) { /* 048 */ filter_value_1 = (filter_value_2).numChars(); /* 049 */ } /* 050 */ if (!filter_isNull_1) { /* 051 */ filter_isNull_0 = false; // resultCode could change nullability. /* 052 */ filter_value_0 = filter_value_1 > 0; /* 053 */ /* 054 */ } /* 055 */ if (filter_isNull_0 || !filter_value_0) continue; /* 056 */ boolean filter_isNull_8 = true; /* 057 */ UTF8String filter_value_8 = null; /* 058 */ /* 059 */ if (!inputadapter_isNull_0) { /* 060 */ filter_isNull_8 = false; // resultCode could change nullability. /* 061 */ filter_value_8 = inputadapter_value_0.substringSQL(5, 2147483647); /* 062 */ /* 063 */ } /* 064 */ if (!(!filter_isNull_8)) continue; /* 065 */ /* 066 */ ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(1); /* 067 */ /* 068 */ boolean project_isNull_0 = true; /* 069 */ UTF8String project_value_0 = null; /* 070 */ /* 071 */ if (!inputadapter_isNull_0) { /* 072 */ project_isNull_0 = false; // resultCode could change nullability. 
/* 073 */ project_value_0 = inputadapter_value_0.substringSQL(5, 2147483647); /* 074 */ /* 075 */ } /* 076 */ filter_mutableStateArray_0[1].reset(); /* 077 */ /* 078 */ filter_mutableStateArray_0[1].zeroOutNullBytes(); /* 079 */ /* 080 */ if (project_isNull_0) { /* 081 */ filter_mutableStateArray_0[1].setNullAt(0); /* 082 */ } else { /* 083 */ filter_mutableStateArray_0[1].write(0, project_value_0); /* 084 */ } /* 085 */ append((filter_mutableStateArray_0[1].getRow())); /* 086 */ /* 087 */ } while(false); /* 088 */ if (shouldStop()) return; /* 089 */ } /* 090 */ } /* 091 */ /* 092 */ } ``` This fixes an NPE bug in FilterExec. There is no user-facing change. A new UT is added. Closes #25902 from wangshuo128/filter-codegen-npe. Authored-by: Wang Shuo <wangshuo...@gmail.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> (cherry picked from commit bd28e8e1794bb0ae833ffa8098b2236b93d34623) Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../sql/execution/basicPhysicalOperators.scala | 5 +++++ .../scala/org/apache/spark/sql/SQLQuerySuite.scala | 26 ++++++++++++++++++++++ 2 files changed, 31 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index a0e7559..4f86f3b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution +import scala.collection.mutable import scala.concurrent.{ExecutionContext, Future} import scala.concurrent.duration.Duration @@ -160,6 +161,7 @@ case class FilterExec(condition: Expression, child: SparkPlan) // This is very perf sensitive. // TODO: revisit this. We can consider reordering predicates as well. 
val generatedIsNotNullChecks = new Array[Boolean](notNullPreds.length) + val extraIsNotNullAttrs = mutable.Set[Attribute]() val generated = otherPreds.map { c => val nullChecks = c.references.map { r => val idx = notNullPreds.indexWhere { n => n.asInstanceOf[IsNotNull].child.semanticEquals(r)} @@ -167,6 +169,9 @@ case class FilterExec(condition: Expression, child: SparkPlan) generatedIsNotNullChecks(idx) = true // Use the child's output. The nullability is what the child produced. genPredicate(notNullPreds(idx), input, child.output) + } else if (notNullAttributes.contains(r.exprId) && !extraIsNotNullAttrs.contains(r)) { + extraIsNotNullAttrs += r + genPredicate(IsNotNull(r), input, child.output) } else { "" } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index eb15da5..38fb5d8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -3034,6 +3034,32 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } } + + test("SPARK-29213: FilterExec should not throw NPE") { + withTempView("t1", "t2", "t3") { + sql("SELECT ''").as[String].map(identity).toDF("x").createOrReplaceTempView("t1") + sql("SELECT * FROM VALUES 0, CAST(NULL AS BIGINT)") + .as[java.lang.Long] + .map(identity) + .toDF("x") + .createOrReplaceTempView("t2") + sql("SELECT ''").as[String].map(identity).toDF("x").createOrReplaceTempView("t3") + sql( + """ + |SELECT t1.x + |FROM t1 + |LEFT JOIN ( + | SELECT x FROM ( + | SELECT x FROM t2 + | UNION ALL + | SELECT SUBSTR(x,5) x FROM t3 + | ) a + | WHERE LENGTH(x)>0 + |) t3 + |ON t1.x=t3.x + """.stripMargin).collect() + } + } } case class Foo(bar: Option[String]) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org