This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.0 by this push:
     new efa0269  [SPARK-31854][SQL] Invoke in MapElementsExec should not propagate null
efa0269 is described below

commit efa0269080cb7f6e2591caedcdac554beaf2661b
Author: Takeshi Yamamuro <yamam...@apache.org>
AuthorDate: Mon Jun 1 04:50:00 2020 +0000

    [SPARK-31854][SQL] Invoke in MapElementsExec should not propagate null
    
    This PR intends to fix a bug in `Dataset.map`, shown below, that occurs when whole-stage codegen is enabled;
    ```
    scala> val ds = Seq(1.asInstanceOf[Integer], null.asInstanceOf[Integer]).toDS()
    
    scala> sql("SET spark.sql.codegen.wholeStage=true")
    
    scala> ds.map(v=>(v,v)).explain
    == Physical Plan ==
    *(1) SerializeFromObject [assertnotnull(input[0, scala.Tuple2, true])._1.intValue AS _1#69, assertnotnull(input[0, scala.Tuple2, true])._2.intValue AS _2#70]
    +- *(1) MapElements <function1>, obj#68: scala.Tuple2
       +- *(1) DeserializeToObject staticinvoke(class java.lang.Integer, ObjectType(class java.lang.Integer), valueOf, value#1, true, false), obj#67: java.lang.Integer
          +- LocalTableScan [value#1]
    
    // `AssertNotNull` in `SerializeFromObject` will fail;
    scala> ds.map(v => (v, v)).show()
    java.lang.NullPointerException: Null value appeared in non-nullable field:
    top level Product input object
    If the schema is inferred from a Scala tuple/case class, or a Java bean, please try to use scala.Option[_] or other nullable types (e.g. java.lang.Integer instead of int/scala.Int).
    
    // When whole-stage codegen is disabled, the query works fine;
    scala> sql("SET spark.sql.codegen.wholeStage=false")
    scala> ds.map(v=>(v,v)).show()
    +----+----+
    |  _1|  _2|
    +----+----+
    |   1|   1|
    |null|null|
    +----+----+
    ```
    The root cause is that the `Invoke` used in `MapElementsExec` propagates a null input, and then [AssertNotNull](https://github.com/apache/spark/blob/1b780f364bfbb46944fe805a024bb6c32f5d2dde/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala#L253-L255) in `SerializeFromObject` fails because the top-level row becomes null. So `MapElementsExec` should return `(null, null)` rather than `null`.
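
    For intuition, the mapped closure itself handles a null element just fine; a minimal plain-Scala sketch (no Spark involved) of the intended semantics:
    ```
    // The closure must be applied even when its argument is null, so a null
    // element yields (null, null) rather than a null row object.
    val f: Integer => (Integer, Integer) = v => (v, v)
    assert(f(null) == ((null, null)))
    assert(f(1) == ((1, 1)))
    ```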
    
    NOTE: the generated code for the query above on the current master; the input's null flag is copied into `mapelements_isNull_1` (via `mapelements_resultIsNull_0`) at /* 040 */ and /* 044 */, so the `apply` call at /* 047 */ is skipped for a null input;
    ```
    /* 033 */   private void mapelements_doConsume_0(java.lang.Integer mapelements_expr_0_0, boolean mapelements_exprIsNull_0_0) throws java.io.IOException {
    /* 034 */     boolean mapelements_isNull_1 = true;
    /* 035 */     scala.Tuple2 mapelements_value_1 = null;
    /* 036 */     if (!false) {
    /* 037 */       mapelements_resultIsNull_0 = false;
    /* 038 */
    /* 039 */       if (!mapelements_resultIsNull_0) {
    /* 040 */         mapelements_resultIsNull_0 = mapelements_exprIsNull_0_0;
    /* 041 */         mapelements_mutableStateArray_0[0] = mapelements_expr_0_0;
    /* 042 */       }
    /* 043 */
    /* 044 */       mapelements_isNull_1 = mapelements_resultIsNull_0;
    /* 045 */       if (!mapelements_isNull_1) {
    /* 046 */         Object mapelements_funcResult_0 = null;
    /* 047 */         mapelements_funcResult_0 = ((scala.Function1) references[1] /* literal */).apply(mapelements_mutableStateArray_0[0]);
    /* 048 */
    /* 049 */         if (mapelements_funcResult_0 != null) {
    /* 050 */           mapelements_value_1 = (scala.Tuple2) mapelements_funcResult_0;
    /* 051 */         } else {
    /* 052 */           mapelements_isNull_1 = true;
    /* 053 */         }
    /* 054 */
    /* 055 */       }
    /* 056 */     }
    /* 057 */
    /* 058 */     serializefromobject_doConsume_0(mapelements_value_1, mapelements_isNull_1);
    /* 059 */
    /* 060 */   }
    ```
    
    The generated code with this fix; `mapelements_isNull_1` is now unconditionally set to `false` before the call at /* 038 */, so the user function is always invoked;
    ```
    /* 032 */   private void mapelements_doConsume_0(java.lang.Integer mapelements_expr_0_0, boolean mapelements_exprIsNull_0_0) throws java.io.IOException {
    /* 033 */     boolean mapelements_isNull_1 = true;
    /* 034 */     scala.Tuple2 mapelements_value_1 = null;
    /* 035 */     if (!false) {
    /* 036 */       mapelements_mutableStateArray_0[0] = mapelements_expr_0_0;
    /* 037 */
    /* 038 */       mapelements_isNull_1 = false;
    /* 039 */       if (!mapelements_isNull_1) {
    /* 040 */         Object mapelements_funcResult_0 = null;
    /* 041 */         mapelements_funcResult_0 = ((scala.Function1) references[1] /* literal */).apply(mapelements_mutableStateArray_0[0]);
    /* 042 */
    /* 043 */         if (mapelements_funcResult_0 != null) {
    /* 044 */           mapelements_value_1 = (scala.Tuple2) mapelements_funcResult_0;
    /* 045 */           mapelements_isNull_1 = false;
    /* 046 */         } else {
    /* 047 */           mapelements_isNull_1 = true;
    /* 048 */         }
    /* 049 */
    /* 050 */       }
    /* 051 */     }
    /* 052 */
    /* 053 */     serializefromobject_doConsume_0(mapelements_value_1, mapelements_isNull_1);
    /* 054 */
    /* 055 */   }
    ```
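
    Conceptually, `Invoke` with `propagateNull = true` (the default) short-circuits to a null result whenever an argument is null, skipping the call entirely, while `propagateNull = false` always performs the call. A simplified model of that evaluation rule (an illustration only, not Spark's actual `Invoke` implementation):
    ```
    // Simplified model of Invoke's null handling for a one-argument method.
    def invoke(target: AnyRef => AnyRef, arg: AnyRef, propagateNull: Boolean): AnyRef =
      if (propagateNull && arg == null) {
        null          // short-circuit: the call is skipped and null flows downstream
      } else {
        target(arg)   // the method runs even for a null argument
      }
    ```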
    
    Why are the changes needed?
    Bugfix.

    Does this PR introduce any user-facing change?
    No.

    How was this patch tested?
    Added tests.
    
    Closes #28681 from maropu/SPARK-31854.
    
    Authored-by: Takeshi Yamamuro <yamam...@apache.org>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
    (cherry picked from commit b806fc458265578fddf544363b60fb5e122439b5)
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../main/scala/org/apache/spark/sql/execution/objects.scala    |  4 ++--
 .../src/test/scala/org/apache/spark/sql/DatasetSuite.scala     | 10 ++++++++++
 2 files changed, 12 insertions(+), 2 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
index d051134..4b2d419 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
@@ -276,12 +276,12 @@ case class MapElementsExec(
   }
 
   override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
-    val (funcClass, methodName) = func match {
+    val (funcClass, funcName) = func match {
       case m: MapFunction[_, _] => classOf[MapFunction[_, _]] -> "call"
       case _ => FunctionUtils.getFunctionOneName(outputObjectType, child.output(0).dataType)
     }
     val funcObj = Literal.create(func, ObjectType(funcClass))
-    val callFunc = Invoke(funcObj, methodName, outputObjectType, child.output)
+    val callFunc = Invoke(funcObj, funcName, outputObjectType, child.output, propagateNull = false)
 
     val result = BindReferences.bindReference(callFunc, child.output).genCode(ctx)
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 4e4558b..a1e8132 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -1901,6 +1901,16 @@ class DatasetSuite extends QueryTest
 
     assert(active eq SparkSession.getActiveSession.get)
   }
+
+  test("SPARK-31854: Invoke in MapElementsExec should not propagate null") {
+    Seq("true", "false").foreach { wholeStage =>
+      withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> wholeStage) {
+        val ds = Seq(1.asInstanceOf[Integer], null.asInstanceOf[Integer]).toDS()
+        val expectedAnswer = Seq[(Integer, Integer)]((1, 1), (null, null))
+        checkDataset(ds.map(v => (v, v)), expectedAnswer: _*)
+      }
+    }
+  }
 }
 
 object AssertExecutionId {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
