This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-2.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-2.4 by this push:
     new 3dbe065  [SPARK-29213][SQL] Generate extra IsNotNull predicate in FilterExec
3dbe065 is described below

commit 3dbe06561c3d645182b1b512ae8d545056b7613b
Author: Wang Shuo <wangshuo...@gmail.com>
AuthorDate: Fri Sep 27 15:14:17 2019 +0800

    [SPARK-29213][SQL] Generate extra IsNotNull predicate in FilterExec
    
    Currently, `FilterExec` decides output nullability and generates null checks in two different ways, so a nullable attribute can be treated as not nullable by mistake.
    
    In `FilterExec.output`, an attribute is marked as nullable or not by looking up its `exprId` in `notNullAttributes`:
    ```
    a.nullable && notNullAttributes.contains(a.exprId)
    ```
    But in `FilterExec.doConsume`, whether a null check is generated for a predicate is decided by whether there is a semantically equal not-null predicate:
    ```
          val nullChecks = c.references.map { r =>
            val idx = notNullPreds.indexWhere { n => n.asInstanceOf[IsNotNull].child.semanticEquals(r)}
            if (idx != -1 && !generatedIsNotNullChecks(idx)) {
              generatedIsNotNullChecks(idx) = true
              // Use the child's output. The nullability is what the child produced.
              genPredicate(notNullPreds(idx), input, child.output)
            } else {
              ""
            }
          }.mkString("\n").trim
    ```
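    
    The two checks can disagree: for a predicate like `isnotnull(cast(x as string))`, the attribute `x` matches by `exprId` but not by semantic equality, so the output is narrowed to not-null while no null check is emitted. A minimal, hypothetical sketch against the Catalyst expression API (assuming `spark-catalyst` on the classpath; the attribute and cast are illustrative only):
    ```
    import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Cast, IsNotNull}
    import org.apache.spark.sql.types.{LongType, StringType}

    val x = AttributeReference("x", LongType, nullable = true)()
    val notNullPred = IsNotNull(Cast(x, StringType)) // isnotnull(cast(x as string))

    // Output side: x's exprId is among the IsNotNull predicate's references,
    // so FilterExec.output marks x as not nullable.
    val notNullAttributes = notNullPred.references.toSeq.map(_.exprId)
    assert(notNullAttributes.contains(x.exprId))

    // doConsume side: the IsNotNull child is cast(x as string), which is not
    // semantically equal to x itself, so no null check is generated for
    // other predicates that reference x.
    assert(!notNullPred.child.semanticEquals(x))
    ```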
    An NPE is thrown when running the SQL below:
    ```
    sql("create table table1(x string)")
    sql("create table table2(x bigint)")
    sql("create table table3(x string)")
    sql("insert into table2 select null as x")
    sql(
      """
        |select t1.x
        |from (
        |    select x from table1) t1
        |left join (
        |    select x from (
        |        select x from table2
        |        union all
        |        select substr(x,5) x from table3
        |    ) a
        |    where length(x)>0
        |) t3
        |on t1.x=t3.x
      """.stripMargin).collect()
    ```
    NPE Exception:
    ```
    java.lang.NullPointerException
        at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage2.processNext(generated.java:40)
        at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
        at org.apache.spark.sql.execution.WholeStageCodegenExec$$anon$1.hasNext(WholeStageCodegenExec.scala:726)
        at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:458)
        at org.apache.spark.shuffle.sort.BypassMergeSortShuffleWriter.write(BypassMergeSortShuffleWriter.java:135)
        at org.apache.spark.shuffle.ShuffleWriteProcessor.write(ShuffleWriteProcessor.scala:59)
        at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:94)
        at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:52)
        at org.apache.spark.scheduler.Task.run(Task.scala:127)
        at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:449)
        at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1377)
        at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:452)
        at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
        at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
        at java.lang.Thread.run(Thread.java:748)
    ```
    The generated code (the NPE is thrown at line 040, where `numChars()` is called on `filter_value_2` without a preceding null check):
    ```
    == Subtree 4 / 5 ==
    *(2) Project [cast(x#7L as string) AS x#9]
    +- *(2) Filter ((length(cast(x#7L as string)) > 0) AND isnotnull(cast(x#7L as string)))
       +- Scan hive default.table2 [x#7L], HiveTableRelation `default`.`table2`, org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe, [x#7L]
    
    Generated code:
    /* 001 */ public Object generate(Object[] references) {
    /* 002 */   return new GeneratedIteratorForCodegenStage2(references);
    /* 003 */ }
    /* 004 */
    /* 005 */ // codegenStageId=2
    /* 006 */ final class GeneratedIteratorForCodegenStage2 extends org.apache.spark.sql.execution.BufferedRowIterator {
    /* 007 */   private Object[] references;
    /* 008 */   private scala.collection.Iterator[] inputs;
    /* 009 */   private scala.collection.Iterator inputadapter_input_0;
    /* 010 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[] filter_mutableStateArray_0 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[2];
    /* 011 */
    /* 012 */   public GeneratedIteratorForCodegenStage2(Object[] references) {
    /* 013 */     this.references = references;
    /* 014 */   }
    /* 015 */
    /* 016 */   public void init(int index, scala.collection.Iterator[] inputs) {
    /* 017 */     partitionIndex = index;
    /* 018 */     this.inputs = inputs;
    /* 019 */     inputadapter_input_0 = inputs[0];
    /* 020 */     filter_mutableStateArray_0[0] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 0);
    /* 021 */     filter_mutableStateArray_0[1] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 32);
    /* 022 */
    /* 023 */   }
    /* 024 */
    /* 025 */   protected void processNext() throws java.io.IOException {
    /* 026 */     while ( inputadapter_input_0.hasNext()) {
    /* 027 */       InternalRow inputadapter_row_0 = (InternalRow) inputadapter_input_0.next();
    /* 028 */
    /* 029 */       do {
    /* 030 */         boolean inputadapter_isNull_0 = inputadapter_row_0.isNullAt(0);
    /* 031 */         long inputadapter_value_0 = inputadapter_isNull_0 ?
    /* 032 */         -1L : (inputadapter_row_0.getLong(0));
    /* 033 */
    /* 034 */         boolean filter_isNull_2 = inputadapter_isNull_0;
    /* 035 */         UTF8String filter_value_2 = null;
    /* 036 */         if (!inputadapter_isNull_0) {
    /* 037 */           filter_value_2 = UTF8String.fromString(String.valueOf(inputadapter_value_0));
    /* 038 */         }
    /* 039 */         int filter_value_1 = -1;
    /* 040 */         filter_value_1 = (filter_value_2).numChars();
    /* 041 */
    /* 042 */         boolean filter_value_0 = false;
    /* 043 */         filter_value_0 = filter_value_1 > 0;
    /* 044 */         if (!filter_value_0) continue;
    /* 045 */
    /* 046 */         boolean filter_isNull_6 = inputadapter_isNull_0;
    /* 047 */         UTF8String filter_value_6 = null;
    /* 048 */         if (!inputadapter_isNull_0) {
    /* 049 */           filter_value_6 = UTF8String.fromString(String.valueOf(inputadapter_value_0));
    /* 050 */         }
    /* 051 */         if (!(!filter_isNull_6)) continue;
    /* 052 */
    /* 053 */         ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(1);
    /* 054 */
    /* 055 */         boolean project_isNull_0 = false;
    /* 056 */         UTF8String project_value_0 = null;
    /* 057 */         if (!false) {
    /* 058 */           project_value_0 = UTF8String.fromString(String.valueOf(inputadapter_value_0));
    /* 059 */         }
    /* 060 */         filter_mutableStateArray_0[1].reset();
    /* 061 */
    /* 062 */         filter_mutableStateArray_0[1].zeroOutNullBytes();
    /* 063 */
    /* 064 */         if (project_isNull_0) {
    /* 065 */           filter_mutableStateArray_0[1].setNullAt(0);
    /* 066 */         } else {
    /* 067 */           filter_mutableStateArray_0[1].write(0, project_value_0);
    /* 068 */         }
    /* 069 */         append((filter_mutableStateArray_0[1].getRow()));
    /* 070 */
    /* 071 */       } while(false);
    /* 072 */       if (shouldStop()) return;
    /* 073 */     }
    /* 074 */   }
    /* 075 */
    /* 076 */ }
    
    ```
    
    This PR makes the two code paths consistent by generating an extra `IsNotNull` predicate in `FilterExec.doConsume` for any attribute that is marked not nullable in `FilterExec.output` (matched by `exprId`) but has no semantically equal `IsNotNull` predicate.
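    
    A sketch of the adjusted null-check generation, assembled from the diff at the bottom of this message (in the actual change, the `extraIsNotNullAttrs` set is declared next to `generatedIsNotNullChecks`):
    ```
    val extraIsNotNullAttrs = mutable.Set[Attribute]()
    val nullChecks = c.references.map { r =>
      val idx = notNullPreds.indexWhere { n => n.asInstanceOf[IsNotNull].child.semanticEquals(r) }
      if (idx != -1 && !generatedIsNotNullChecks(idx)) {
        generatedIsNotNullChecks(idx) = true
        // Use the child's output. The nullability is what the child produced.
        genPredicate(notNullPreds(idx), input, child.output)
      } else if (notNullAttributes.contains(r.exprId) && !extraIsNotNullAttrs.contains(r)) {
        // The attribute is treated as not-null in the output (matched by exprId),
        // but no semantically equal IsNotNull predicate was found above, so emit
        // an extra IsNotNull check before the predicate is evaluated.
        extraIsNotNullAttrs += r
        genPredicate(IsNotNull(r), input, child.output)
      } else {
        ""
      }
    }.mkString("\n").trim
    ```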
    
    With this PR, the generated code becomes:
    ```
    == Subtree 2 / 5 ==
    *(3) Project [substring(x#8, 5, 2147483647) AS x#5]
    +- *(3) Filter ((length(substring(x#8, 5, 2147483647)) > 0) AND isnotnull(substring(x#8, 5, 2147483647)))
       +- Scan hive default.table3 [x#8], HiveTableRelation `default`.`table3`, org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe, [x#8]
    
    Generated code:
    /* 001 */ public Object generate(Object[] references) {
    /* 002 */   return new GeneratedIteratorForCodegenStage3(references);
    /* 003 */ }
    /* 004 */
    /* 005 */ // codegenStageId=3
    /* 006 */ final class GeneratedIteratorForCodegenStage3 extends org.apache.spark.sql.execution.BufferedRowIterator {
    /* 007 */   private Object[] references;
    /* 008 */   private scala.collection.Iterator[] inputs;
    /* 009 */   private scala.collection.Iterator inputadapter_input_0;
    /* 010 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[] filter_mutableStateArray_0 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[2];
    /* 011 */
    /* 012 */   public GeneratedIteratorForCodegenStage3(Object[] references) {
    /* 013 */     this.references = references;
    /* 014 */   }
    /* 015 */
    /* 016 */   public void init(int index, scala.collection.Iterator[] inputs) {
    /* 017 */     partitionIndex = index;
    /* 018 */     this.inputs = inputs;
    /* 019 */     inputadapter_input_0 = inputs[0];
    /* 020 */     filter_mutableStateArray_0[0] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 32);
    /* 021 */     filter_mutableStateArray_0[1] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 32);
    /* 022 */
    /* 023 */   }
    /* 024 */
    /* 025 */   protected void processNext() throws java.io.IOException {
    /* 026 */     while ( inputadapter_input_0.hasNext()) {
    /* 027 */       InternalRow inputadapter_row_0 = (InternalRow) inputadapter_input_0.next();
    /* 028 */
    /* 029 */       do {
    /* 030 */         boolean inputadapter_isNull_0 = inputadapter_row_0.isNullAt(0);
    /* 031 */         UTF8String inputadapter_value_0 = inputadapter_isNull_0 ?
    /* 032 */         null : (inputadapter_row_0.getUTF8String(0));
    /* 033 */
    /* 034 */         boolean filter_isNull_0 = true;
    /* 035 */         boolean filter_value_0 = false;
    /* 036 */         boolean filter_isNull_2 = true;
    /* 037 */         UTF8String filter_value_2 = null;
    /* 038 */
    /* 039 */         if (!inputadapter_isNull_0) {
    /* 040 */           filter_isNull_2 = false; // resultCode could change nullability.
    /* 041 */           filter_value_2 = inputadapter_value_0.substringSQL(5, 2147483647);
    /* 042 */
    /* 043 */         }
    /* 044 */         boolean filter_isNull_1 = filter_isNull_2;
    /* 045 */         int filter_value_1 = -1;
    /* 046 */
    /* 047 */         if (!filter_isNull_2) {
    /* 048 */           filter_value_1 = (filter_value_2).numChars();
    /* 049 */         }
    /* 050 */         if (!filter_isNull_1) {
    /* 051 */           filter_isNull_0 = false; // resultCode could change nullability.
    /* 052 */           filter_value_0 = filter_value_1 > 0;
    /* 053 */
    /* 054 */         }
    /* 055 */         if (filter_isNull_0 || !filter_value_0) continue;
    /* 056 */         boolean filter_isNull_8 = true;
    /* 057 */         UTF8String filter_value_8 = null;
    /* 058 */
    /* 059 */         if (!inputadapter_isNull_0) {
    /* 060 */           filter_isNull_8 = false; // resultCode could change nullability.
    /* 061 */           filter_value_8 = inputadapter_value_0.substringSQL(5, 2147483647);
    /* 062 */
    /* 063 */         }
    /* 064 */         if (!(!filter_isNull_8)) continue;
    /* 065 */
    /* 066 */         ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(1);
    /* 067 */
    /* 068 */         boolean project_isNull_0 = true;
    /* 069 */         UTF8String project_value_0 = null;
    /* 070 */
    /* 071 */         if (!inputadapter_isNull_0) {
    /* 072 */           project_isNull_0 = false; // resultCode could change nullability.
    /* 073 */           project_value_0 = inputadapter_value_0.substringSQL(5, 2147483647);
    /* 074 */
    /* 075 */         }
    /* 076 */         filter_mutableStateArray_0[1].reset();
    /* 077 */
    /* 078 */         filter_mutableStateArray_0[1].zeroOutNullBytes();
    /* 079 */
    /* 080 */         if (project_isNull_0) {
    /* 081 */           filter_mutableStateArray_0[1].setNullAt(0);
    /* 082 */         } else {
    /* 083 */           filter_mutableStateArray_0[1].write(0, project_value_0);
    /* 084 */         }
    /* 085 */         append((filter_mutableStateArray_0[1].getRow()));
    /* 086 */
    /* 087 */       } while(false);
    /* 088 */       if (shouldStop()) return;
    /* 089 */     }
    /* 090 */   }
    /* 091 */
    /* 092 */ }
    ```
    Fixes an NPE bug in `FilterExec`. No user-facing change. Tested with a new unit test.
    
    Closes #25902 from wangshuo128/filter-codegen-npe.
    
    Authored-by: Wang Shuo <wangshuo...@gmail.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
    (cherry picked from commit bd28e8e1794bb0ae833ffa8098b2236b93d34623)
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../sql/execution/basicPhysicalOperators.scala     |  5 +++++
 .../scala/org/apache/spark/sql/SQLQuerySuite.scala | 26 ++++++++++++++++++++++
 2 files changed, 31 insertions(+)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index a0e7559..4f86f3b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.execution
 
+import scala.collection.mutable
 import scala.concurrent.{ExecutionContext, Future}
 import scala.concurrent.duration.Duration
 
@@ -160,6 +161,7 @@ case class FilterExec(condition: Expression, child: SparkPlan)
     // This is very perf sensitive.
     // TODO: revisit this. We can consider reordering predicates as well.
     val generatedIsNotNullChecks = new Array[Boolean](notNullPreds.length)
+    val extraIsNotNullAttrs = mutable.Set[Attribute]()
     val generated = otherPreds.map { c =>
       val nullChecks = c.references.map { r =>
        val idx = notNullPreds.indexWhere { n => n.asInstanceOf[IsNotNull].child.semanticEquals(r)}
@@ -167,6 +169,9 @@ case class FilterExec(condition: Expression, child: SparkPlan)
           generatedIsNotNullChecks(idx) = true
           // Use the child's output. The nullability is what the child produced.
           genPredicate(notNullPreds(idx), input, child.output)
+        } else if (notNullAttributes.contains(r.exprId) && !extraIsNotNullAttrs.contains(r)) {
+          extraIsNotNullAttrs += r
+          genPredicate(IsNotNull(r), input, child.output)
         } else {
           ""
         }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index eb15da5..38fb5d8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -3034,6 +3034,32 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
     }
 
   }
+
+  test("SPARK-29213: FilterExec should not throw NPE") {
+    withTempView("t1", "t2", "t3") {
+      sql("SELECT 
''").as[String].map(identity).toDF("x").createOrReplaceTempView("t1")
+      sql("SELECT * FROM VALUES 0, CAST(NULL AS BIGINT)")
+        .as[java.lang.Long]
+        .map(identity)
+        .toDF("x")
+        .createOrReplaceTempView("t2")
+      sql("SELECT 
''").as[String].map(identity).toDF("x").createOrReplaceTempView("t3")
+      sql(
+        """
+          |SELECT t1.x
+          |FROM t1
+          |LEFT JOIN (
+          |    SELECT x FROM (
+          |        SELECT x FROM t2
+          |        UNION ALL
+          |        SELECT SUBSTR(x,5) x FROM t3
+          |    ) a
+          |    WHERE LENGTH(x)>0
+          |) t3
+          |ON t1.x=t3.x
+        """.stripMargin).collect()
+    }
+  }
 }
 
 case class Foo(bar: Option[String])


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
