This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch branch-2.4
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-2.4 by this push:
     new 6f6dcc8  [SPARK-31854][SQL][2.4] Invoke in MapElementsExec should not propagate null
6f6dcc8 is described below

commit 6f6dcc8eccacd3a567c08789c53ceae8fbc124de
Author: Takeshi Yamamuro <yamam...@apache.org>
AuthorDate: Mon Jun 1 20:09:09 2020 +0900

    [SPARK-31854][SQL][2.4] Invoke in MapElementsExec should not propagate null

### What changes were proposed in this pull request?

This PR intends to fix a bug in `Dataset.map` that occurs when whole-stage codegen is enabled:

```
scala> val ds = Seq(1.asInstanceOf[Integer], null.asInstanceOf[Integer]).toDS()
scala> sql("SET spark.sql.codegen.wholeStage=true")
scala> ds.map(v=>(v,v)).explain
== Physical Plan ==
*(1) SerializeFromObject [assertnotnull(input[0, scala.Tuple2, true])._1.intValue AS _1#69, assertnotnull(input[0, scala.Tuple2, true])._2.intValue AS _2#70]
+- *(1) MapElements <function1>, obj#68: scala.Tuple2
   +- *(1) DeserializeToObject staticinvoke(class java.lang.Integer, ObjectType(class java.lang.Integer), valueOf, value#1, true, false), obj#67: java.lang.Integer
      +- LocalTableScan [value#1]

// `AssertNotNull` in `SerializeFromObject` will fail;
scala> ds.map(v => (v, v)).show()
java.lang.NullPointerException: Null value appeared in non-nullable field:
top level Product input object
If the schema is inferred from a Scala tuple/case class, or a Java bean, please try to use scala.Option[_] or other nullable types (e.g. java.lang.Integer instead of int/scala.Int).

// When whole-stage codegen is disabled, the query works fine;
scala> sql("SET spark.sql.codegen.wholeStage=false")
scala> ds.map(v=>(v,v)).show()
+----+----+
|  _1|  _2|
+----+----+
|   1|   1|
|null|null|
+----+----+
```

The root cause is that the `Invoke` used in `MapElementsExec` propagates a null input, so [AssertNotNull](https://github.com/apache/spark/blob/1b780f364bfbb46944fe805a024bb6c32f5d2dde/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala#L253-L255) in `SerializeFromObject` fails because the top-level row becomes null. `MapElementsExec` should therefore return `(null, null)` rather than `null`.
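For intuition, note that the user function itself already handles a null input; here is a minimal plain-Scala sketch of the expected semantics (not Spark code, illustration only):

```scala
// The mapped lambda is null-tolerant: a null input maps to the row (null, null).
val f: Integer => (Integer, Integer) = v => (v, v)

val expected: (Integer, Integer) = (null, null)
assert(f(null) == expected) // a valid top-level object for SerializeFromObject

// If Invoke propagates the null instead (the buggy path), f is never called,
// and the downstream AssertNotNull sees a null top-level object and throws.
```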
NOTE: the code generated for the query above on the current master:

```
/* 033 */   private void mapelements_doConsume_0(java.lang.Integer mapelements_expr_0_0, boolean mapelements_exprIsNull_0_0) throws java.io.IOException {
/* 034 */     boolean mapelements_isNull_1 = true;
/* 035 */     scala.Tuple2 mapelements_value_1 = null;
/* 036 */     if (!false) {
/* 037 */       mapelements_resultIsNull_0 = false;
/* 038 */
/* 039 */       if (!mapelements_resultIsNull_0) {
/* 040 */         mapelements_resultIsNull_0 = mapelements_exprIsNull_0_0;
/* 041 */         mapelements_mutableStateArray_0[0] = mapelements_expr_0_0;
/* 042 */       }
/* 043 */
/* 044 */       mapelements_isNull_1 = mapelements_resultIsNull_0;
/* 045 */       if (!mapelements_isNull_1) {
/* 046 */         Object mapelements_funcResult_0 = null;
/* 047 */         mapelements_funcResult_0 = ((scala.Function1) references[1] /* literal */).apply(mapelements_mutableStateArray_0[0]);
/* 048 */
/* 049 */         if (mapelements_funcResult_0 != null) {
/* 050 */           mapelements_value_1 = (scala.Tuple2) mapelements_funcResult_0;
/* 051 */         } else {
/* 052 */           mapelements_isNull_1 = true;
/* 053 */         }
/* 054 */
/* 055 */       }
/* 056 */     }
/* 057 */
/* 058 */     serializefromobject_doConsume_0(mapelements_value_1, mapelements_isNull_1);
/* 059 */
/* 060 */   }
```

The code generated with this fix:

```
/* 032 */   private void mapelements_doConsume_0(java.lang.Integer mapelements_expr_0_0, boolean mapelements_exprIsNull_0_0) throws java.io.IOException {
/* 033 */     boolean mapelements_isNull_1 = true;
/* 034 */     scala.Tuple2 mapelements_value_1 = null;
/* 035 */     if (!false) {
/* 036 */       mapelements_mutableStateArray_0[0] = mapelements_expr_0_0;
/* 037 */
/* 038 */       mapelements_isNull_1 = false;
/* 039 */       if (!mapelements_isNull_1) {
/* 040 */         Object mapelements_funcResult_0 = null;
/* 041 */         mapelements_funcResult_0 = ((scala.Function1) references[1] /* literal */).apply(mapelements_mutableStateArray_0[0]);
/* 042 */
/* 043 */         if (mapelements_funcResult_0 != null) {
/* 044 */           mapelements_value_1 = (scala.Tuple2) mapelements_funcResult_0;
/* 045 */           mapelements_isNull_1 = false;
/* 046 */         } else {
/* 047 */           mapelements_isNull_1 = true;
/* 048 */         }
/* 049 */
/* 050 */       }
/* 051 */     }
/* 052 */
/* 053 */     serializefromobject_doConsume_0(mapelements_value_1, mapelements_isNull_1);
/* 054 */
/* 055 */   }
```

This fix comes from https://github.com/apache/spark/pull/28681.
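To make the control-flow difference above concrete, here is a hand-written Scala analogue of the two generated variants; the helper names are invented for illustration and this is not the actual codegen:

```scala
// Hypothetical helpers mirroring the two codegen variants above.

// Before the fix: Invoke propagates null, so the lambda is skipped on null input
// and the downstream operator receives a null top-level object.
def consumeBefore[U](input: Integer, inputIsNull: Boolean, f: Integer => U): (U, Boolean) =
  if (inputIsNull) (null.asInstanceOf[U], true)     // f is never called
  else { val r = f(input); (r, r == null) }

// After the fix (propagateNull = false): the lambda is always invoked, and only
// the lambda's own result determines nullness.
def consumeAfter[U](input: Integer, f: Integer => U): (U, Boolean) = {
  val r = f(input)                                  // f decides how to map null
  (r, r == null)
}

val f: Integer => (Integer, Integer) = v => (v, v)
consumeBefore(null, /* inputIsNull = */ true, f)  // (null, true)          -> AssertNotNull fails
consumeAfter(null, f)                             // ((null, null), false) -> a valid row
```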
### Why are the changes needed?

Bugfix.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Added tests.

Closes #28691 from maropu/SPARK-31854-BRANCH2.4.

Authored-by: Takeshi Yamamuro <yamam...@apache.org>
Signed-off-by: HyukjinKwon <gurwls...@apache.org>
---
 .../main/scala/org/apache/spark/sql/execution/objects.scala |  3 ++-
 .../src/test/scala/org/apache/spark/sql/DatasetSuite.scala  | 10 ++++++++++
 2 files changed, 12 insertions(+), 1 deletion(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
index 03d1bbf..27673fa 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
@@ -217,7 +217,8 @@ case class MapElementsExec(
       case _ => FunctionUtils.getFunctionOneName(outputObjAttr.dataType, child.output(0).dataType)
     }
     val funcObj = Literal.create(func, ObjectType(funcClass))
-    val callFunc = Invoke(funcObj, methodName, outputObjAttr.dataType, child.output)
+    val callFunc = Invoke(funcObj, methodName, outputObjAttr.dataType, child.output,
+      propagateNull = false)
 
     val result = BindReferences.bindReference(callFunc, child.output).genCode(ctx)
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 01d0877..08ebf8b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -1576,6 +1576,16 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     }
     assert(thrownException.message.contains("Cannot up cast `id` from bigint to tinyint"))
   }
+
+  test("SPARK-31854: Invoke in MapElementsExec should not propagate null") {
+    Seq("true", "false").foreach { wholeStage =>
+      withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> wholeStage) {
+        val ds = Seq(1.asInstanceOf[Integer], null.asInstanceOf[Integer]).toDS()
+        val expectedAnswer = Seq[(Integer, Integer)]((1, 1), (null, null))
+        checkDataset(ds.map(v => (v, v)), expectedAnswer: _*)
+      }
+    }
+  }
 }
 
 case class TestDataUnion(x: Int, y: Int, z: Int)

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org