This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch branch-3.0 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.0 by this push: new efa0269 [SPARK-31854][SQL] Invoke in MapElementsExec should not propagate null efa0269 is described below commit efa0269080cb7f6e2591caedcdac554beaf2661b Author: Takeshi Yamamuro <yamam...@apache.org> AuthorDate: Mon Jun 1 04:50:00 2020 +0000 [SPARK-31854][SQL] Invoke in MapElementsExec should not propagate null This PR fixes the following bug in `Dataset.map`, which occurs when whole-stage codegen is enabled: ``` scala> val ds = Seq(1.asInstanceOf[Integer], null.asInstanceOf[Integer]).toDS() scala> sql("SET spark.sql.codegen.wholeStage=true") scala> ds.map(v=>(v,v)).explain == Physical Plan == *(1) SerializeFromObject [assertnotnull(input[0, scala.Tuple2, true])._1.intValue AS _1#69, assertnotnull(input[0, scala.Tuple2, true])._2.intValue AS _2#70] +- *(1) MapElements <function1>, obj#68: scala.Tuple2 +- *(1) DeserializeToObject staticinvoke(class java.lang.Integer, ObjectType(class java.lang.Integer), valueOf, value#1, true, false), obj#67: java.lang.Integer +- LocalTableScan [value#1] // `AssertNotNull` in `SerializeFromObject` will fail: scala> ds.map(v => (v, v)).show() java.lang.NullPointerException: Null value appeared in non-nullable field: top level Product input object If the schema is inferred from a Scala tuple/case class, or a Java bean, please try to use scala.Option[_] or other nullable types (e.g. java.lang.Integer instead of int/scala.Int). 
// When whole-stage codegen is disabled, the query works correctly: scala> sql("SET spark.sql.codegen.wholeStage=false") scala> ds.map(v=>(v,v)).show() +----+----+ | _1| _2| +----+----+ | 1| 1| |null|null| +----+----+ ``` The root cause is that the `Invoke` used in `MapElementsExec` propagates a null input, and then [AssertNotNull](https://github.com/apache/spark/blob/1b780f364bfbb46944fe805a024bb6c32f5d2dde/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala#L253-L255) in `SerializeFromObject` fails because the top-level row becomes null. So `MapElementsExec` should return `(null, null)` rather than `null`. NOTE: the generated code for the query above in the current master: ``` /* 033 */ private void mapelements_doConsume_0(java.lang.Integer mapelements_expr_0_0, boolean mapelements_exprIsNull_0_0) throws java.io.IOException { /* 034 */ boolean mapelements_isNull_1 = true; /* 035 */ scala.Tuple2 mapelements_value_1 = null; /* 036 */ if (!false) { /* 037 */ mapelements_resultIsNull_0 = false; /* 038 */ /* 039 */ if (!mapelements_resultIsNull_0) { /* 040 */ mapelements_resultIsNull_0 = mapelements_exprIsNull_0_0; /* 041 */ mapelements_mutableStateArray_0[0] = mapelements_expr_0_0; /* 042 */ } /* 043 */ /* 044 */ mapelements_isNull_1 = mapelements_resultIsNull_0; /* 045 */ if (!mapelements_isNull_1) { /* 046 */ Object mapelements_funcResult_0 = null; /* 047 */ mapelements_funcResult_0 = ((scala.Function1) references[1] /* literal */).apply(mapelements_mutableStateArray_0[0]); /* 048 */ /* 049 */ if (mapelements_funcResult_0 != null) { /* 050 */ mapelements_value_1 = (scala.Tuple2) mapelements_funcResult_0; /* 051 */ } else { /* 052 */ mapelements_isNull_1 = true; /* 053 */ } /* 054 */ /* 055 */ } /* 056 */ } /* 057 */ /* 058 */ serializefromobject_doConsume_0(mapelements_value_1, mapelements_isNull_1); /* 059 */ /* 060 */ } ``` The generated code with this fix: ``` /* 032 */ private void mapelements_doConsume_0(java.lang.Integer 
mapelements_expr_0_0, boolean mapelements_exprIsNull_0_0) throws java.io.IOException { /* 033 */ boolean mapelements_isNull_1 = true; /* 034 */ scala.Tuple2 mapelements_value_1 = null; /* 035 */ if (!false) { /* 036 */ mapelements_mutableStateArray_0[0] = mapelements_expr_0_0; /* 037 */ /* 038 */ mapelements_isNull_1 = false; /* 039 */ if (!mapelements_isNull_1) { /* 040 */ Object mapelements_funcResult_0 = null; /* 041 */ mapelements_funcResult_0 = ((scala.Function1) references[1] /* literal */).apply(mapelements_mutableStateArray_0[0]); /* 042 */ /* 043 */ if (mapelements_funcResult_0 != null) { /* 044 */ mapelements_value_1 = (scala.Tuple2) mapelements_funcResult_0; /* 045 */ mapelements_isNull_1 = false; /* 046 */ } else { /* 047 */ mapelements_isNull_1 = true; /* 048 */ } /* 049 */ /* 050 */ } /* 051 */ } /* 052 */ /* 053 */ serializefromobject_doConsume_0(mapelements_value_1, mapelements_isNull_1); /* 054 */ /* 055 */ } ``` Why are the changes needed? Bugfix. Does this PR introduce any user-facing change? No. How was this patch tested? Added tests. Closes #28681 from maropu/SPARK-31854. 
Authored-by: Takeshi Yamamuro <yamam...@apache.org> Signed-off-by: Wenchen Fan <wenc...@databricks.com> (cherry picked from commit b806fc458265578fddf544363b60fb5e122439b5) Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../main/scala/org/apache/spark/sql/execution/objects.scala | 4 ++-- .../src/test/scala/org/apache/spark/sql/DatasetSuite.scala | 10 ++++++++++ 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala index d051134..4b2d419 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala @@ -276,12 +276,12 @@ case class MapElementsExec( } override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { - val (funcClass, methodName) = func match { + val (funcClass, funcName) = func match { case m: MapFunction[_, _] => classOf[MapFunction[_, _]] -> "call" case _ => FunctionUtils.getFunctionOneName(outputObjectType, child.output(0).dataType) } val funcObj = Literal.create(func, ObjectType(funcClass)) - val callFunc = Invoke(funcObj, methodName, outputObjectType, child.output) + val callFunc = Invoke(funcObj, funcName, outputObjectType, child.output, propagateNull = false) val result = BindReferences.bindReference(callFunc, child.output).genCode(ctx) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 4e4558b..a1e8132 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1901,6 +1901,16 @@ class DatasetSuite extends QueryTest assert(active eq SparkSession.getActiveSession.get) } + + test("SPARK-31854: Invoke in MapElementsExec should not propagate null") { + Seq("true", "false").foreach { 
wholeStage => + withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> wholeStage) { + val ds = Seq(1.asInstanceOf[Integer], null.asInstanceOf[Integer]).toDS() + val expectedAnswer = Seq[(Integer, Integer)]((1, 1), (null, null)) + checkDataset(ds.map(v => (v, v)), expectedAnswer: _*) + } + } + } } object AssertExecutionId { --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org