Repository: spark Updated Branches: refs/heads/branch-2.1 9dfdd2adf -> a04428fe2
[SPARK-19980][SQL][BACKPORT-2.1] Add NULL checks in Bean serializer ## What changes were proposed in this pull request? A Bean serializer in `ExpressionEncoder` could change values when Beans contain NULL fields. A concrete example is as follows: ``` scala> :paste class Outer extends Serializable { private var cls: Inner = _ def setCls(c: Inner): Unit = cls = c def getCls(): Inner = cls } class Inner extends Serializable { private var str: String = _ def setStr(s: String): Unit = str = s def getStr(): String = str } scala> Seq("""{"cls":null}""", """{"cls": {"str":null}}""").toDF().write.text("data") scala> val encoder = Encoders.bean(classOf[Outer]) scala> val schema = encoder.schema scala> val df = spark.read.schema(schema).json("data").as[Outer](encoder) scala> df.show +------+ | cls| +------+ |[null]| | null| +------+ scala> df.map(x => x)(encoder).show() +------+ | cls| +------+ |[null]| |[null]| // <-- Value changed +------+ ``` This is because the Bean serializer lacks the NULL-check expressions that the serializer for Scala's product types has. In contrast, this value change does not happen with Scala's product types: ``` scala> :paste case class Outer(cls: Inner) case class Inner(str: String) scala> val encoder = Encoders.product[Outer] scala> val schema = encoder.schema scala> val df = spark.read.schema(schema).json("data").as[Outer](encoder) scala> df.show +------+ | cls| +------+ |[null]| | null| +------+ scala> df.map(x => x)(encoder).show() +------+ | cls| +------+ |[null]| | null| +------+ ``` This PR adds NULL-check expressions to the Bean serializer, in the same way as the serializer for Scala's product types. ## How was this patch tested? Added tests in `JavaDatasetSuite`. Author: Takeshi Yamamuro <yamam...@apache.org> Closes #17372 from maropu/SPARK-19980-BACKPORT2.1. 
Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/a04428fe Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/a04428fe Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/a04428fe Branch: refs/heads/branch-2.1 Commit: a04428fe26b5b3ad998a88c81c829050fe4a0256 Parents: 9dfdd2a Author: Takeshi Yamamuro <yamam...@apache.org> Authored: Wed Mar 22 08:37:54 2017 +0800 Committer: Wenchen Fan <wenc...@databricks.com> Committed: Wed Mar 22 08:37:54 2017 +0800 ---------------------------------------------------------------------- .../spark/sql/catalyst/JavaTypeInference.scala | 11 +++++++-- .../org/apache/spark/sql/JavaDatasetSuite.java | 24 ++++++++++++++++++++ 2 files changed, 33 insertions(+), 2 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/a04428fe/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index 61c153c..2de066f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -334,7 +334,11 @@ object JavaTypeInference { */ def serializerFor(beanClass: Class[_]): CreateNamedStruct = { val inputObject = BoundReference(0, ObjectType(beanClass), nullable = true) - serializerFor(inputObject, TypeToken.of(beanClass)).asInstanceOf[CreateNamedStruct] + val nullSafeInput = AssertNotNull(inputObject, Seq("top level input bean")) + serializerFor(nullSafeInput, TypeToken.of(beanClass)) match { + case expressions.If(_, _, s: CreateNamedStruct) => s + case other => CreateNamedStruct(expressions.Literal("value") :: other 
:: Nil) + } } private def serializerFor(inputObject: Expression, typeToken: TypeToken[_]): Expression = { @@ -417,7 +421,7 @@ object JavaTypeInference { case other => val properties = getJavaBeanProperties(other) if (properties.length > 0) { - CreateNamedStruct(properties.flatMap { p => + val nonNullOutput = CreateNamedStruct(properties.flatMap { p => val fieldName = p.getName val fieldType = typeToken.method(p.getReadMethod).getReturnType val fieldValue = Invoke( @@ -426,6 +430,9 @@ object JavaTypeInference { inferExternalType(fieldType.getRawType)) expressions.Literal(fieldName) :: serializerFor(fieldValue, fieldType) :: Nil }) + + val nullOutput = expressions.Literal.create(null, nonNullOutput.dataType) + expressions.If(IsNull(inputObject), nullOutput, nonNullOutput) } else { throw new UnsupportedOperationException( s"Cannot infer type for class ${other.getName} because it is not bean-compliant") http://git-wip-us.apache.org/repos/asf/spark/blob/a04428fe/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java ---------------------------------------------------------------------- diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index 8304b72..b25e349 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -1305,4 +1305,28 @@ public class JavaDatasetSuite implements Serializable { spark.createDataset(data, Encoders.bean(NestedComplicatedJavaBean.class)); ds.collectAsList(); } + + @Test(expected = RuntimeException.class) + public void testNullInTopLevelBean() { + NestedSmallBean bean = new NestedSmallBean(); + // We cannot set null in top-level bean + spark.createDataset(Arrays.asList(bean, null), Encoders.bean(NestedSmallBean.class)); + } + + @Test + public void testSerializeNull() { + NestedSmallBean bean = new NestedSmallBean(); + 
Encoder<NestedSmallBean> encoder = Encoders.bean(NestedSmallBean.class); + List<NestedSmallBean> beans = Arrays.asList(bean); + Dataset<NestedSmallBean> ds1 = spark.createDataset(beans, encoder); + Assert.assertEquals(beans, ds1.collectAsList()); + Dataset<NestedSmallBean> ds2 = + ds1.map(new MapFunction<NestedSmallBean, NestedSmallBean>() { + @Override + public NestedSmallBean call(NestedSmallBean b) throws Exception { + return b; + } + }, encoder); + Assert.assertEquals(beans, ds2.collectAsList()); + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org