Repository: spark
Updated Branches:
  refs/heads/branch-2.1 9dfdd2adf -> a04428fe2


[SPARK-19980][SQL][BACKPORT-2.1] Add NULL checks in Bean serializer

## What changes were proposed in this pull request?
A Bean serializer in `ExpressionEncoder` could change values when a bean
contains NULL. A concrete example is as follows:
```
scala> :paste
class Outer extends Serializable {
  private var cls: Inner = _
  def setCls(c: Inner): Unit = cls = c
  def getCls(): Inner = cls
}

class Inner extends Serializable {
  private var str: String = _
  def setStr(s: String): Unit = str = s
  def getStr(): String = str
}

scala> Seq("""{"cls":null}""", """{"cls": {"str":null}}""").toDF().write.text("data")
scala> val encoder = Encoders.bean(classOf[Outer])
scala> val schema = encoder.schema
scala> val df = spark.read.schema(schema).json("data").as[Outer](encoder)
scala> df.show
+------+
|   cls|
+------+
|[null]|
|  null|
+------+

scala> df.map(x => x)(encoder).show()
+------+
|   cls|
+------+
|[null]|
|[null]|     // <-- Value changed
+------+
```

This is because the Bean serializer lacks the NULL-check expressions that the
serializer for Scala's product types has. Indeed, this value change does not
happen with Scala's product types:

```
scala> :paste
case class Outer(cls: Inner)
case class Inner(str: String)

scala> val encoder = Encoders.product[Outer]
scala> val schema = encoder.schema
scala> val df = spark.read.schema(schema).json("data").as[Outer](encoder)
scala> df.show
+------+
|   cls|
+------+
|[null]|
|  null|
+------+

scala> df.map(x => x)(encoder).show()
+------+
|   cls|
+------+
|[null]|
|  null|
+------+
```

This PR adds NULL-check expressions to the Bean serializer, in line with the
serializer for Scala's product types.
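
For illustration, here is a minimal sketch in plain Scala (no Catalyst
internals; `serializeOuter` is a hypothetical helper reusing the `Outer` and
`Inner` beans from the example above) of the null-safe pattern the fix
applies: the bean is checked for NULL before its getters are invoked, so a
null bean serializes to a null struct instead of a struct of null fields.
```
// Hypothetical sketch of the null-safe serialization pattern: without the
// outer Option(...) check, a null Outer would be turned into a struct whose
// fields are all null ("[null]") rather than into a null struct.
def serializeOuter(outer: Outer): Map[String, Any] =
  Option(outer).map { o =>
    // Dereference getters only after the bean itself is known to be non-null.
    val cls = Option(o.getCls()).map(i => Map("str" -> i.getStr())).orNull
    Map("cls" -> cls)
  }.orNull
```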

## How was this patch tested?
Added tests in `JavaDatasetSuite`.

Author: Takeshi Yamamuro <yamam...@apache.org>

Closes #17372 from maropu/SPARK-19980-BACKPORT2.1.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/a04428fe
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/a04428fe
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/a04428fe

Branch: refs/heads/branch-2.1
Commit: a04428fe26b5b3ad998a88c81c829050fe4a0256
Parents: 9dfdd2a
Author: Takeshi Yamamuro <yamam...@apache.org>
Authored: Wed Mar 22 08:37:54 2017 +0800
Committer: Wenchen Fan <wenc...@databricks.com>
Committed: Wed Mar 22 08:37:54 2017 +0800

----------------------------------------------------------------------
 .../spark/sql/catalyst/JavaTypeInference.scala  | 11 +++++++--
 .../org/apache/spark/sql/JavaDatasetSuite.java  | 24 ++++++++++++++++++++
 2 files changed, 33 insertions(+), 2 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/a04428fe/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
index 61c153c..2de066f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
@@ -334,7 +334,11 @@ object JavaTypeInference {
    */
   def serializerFor(beanClass: Class[_]): CreateNamedStruct = {
     val inputObject = BoundReference(0, ObjectType(beanClass), nullable = true)
-    serializerFor(inputObject, TypeToken.of(beanClass)).asInstanceOf[CreateNamedStruct]
+    val nullSafeInput = AssertNotNull(inputObject, Seq("top level input bean"))
+    serializerFor(nullSafeInput, TypeToken.of(beanClass)) match {
+      case expressions.If(_, _, s: CreateNamedStruct) => s
+      case other => CreateNamedStruct(expressions.Literal("value") :: other :: Nil)
+    }
   }
 
   private def serializerFor(inputObject: Expression, typeToken: TypeToken[_]): Expression = {
@@ -417,7 +421,7 @@ object JavaTypeInference {
         case other =>
           val properties = getJavaBeanProperties(other)
           if (properties.length > 0) {
-            CreateNamedStruct(properties.flatMap { p =>
+            val nonNullOutput = CreateNamedStruct(properties.flatMap { p =>
               val fieldName = p.getName
               val fieldType = typeToken.method(p.getReadMethod).getReturnType
               val fieldValue = Invoke(
@@ -426,6 +430,9 @@ object JavaTypeInference {
                 inferExternalType(fieldType.getRawType))
              expressions.Literal(fieldName) :: serializerFor(fieldValue, fieldType) :: Nil
             })
+
+            val nullOutput = expressions.Literal.create(null, nonNullOutput.dataType)
+            expressions.If(IsNull(inputObject), nullOutput, nonNullOutput)
           } else {
             throw new UnsupportedOperationException(
               s"Cannot infer type for class ${other.getName} because it is not 
bean-compliant")

http://git-wip-us.apache.org/repos/asf/spark/blob/a04428fe/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
----------------------------------------------------------------------
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index 8304b72..b25e349 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -1305,4 +1305,28 @@ public class JavaDatasetSuite implements Serializable {
      spark.createDataset(data, Encoders.bean(NestedComplicatedJavaBean.class));
     ds.collectAsList();
   }
+
+  @Test(expected = RuntimeException.class)
+  public void testNullInTopLevelBean() {
+    NestedSmallBean bean = new NestedSmallBean();
+    // We cannot set null in top-level bean
+    spark.createDataset(Arrays.asList(bean, null), Encoders.bean(NestedSmallBean.class));
+  }
+
+  @Test
+  public void testSerializeNull() {
+    NestedSmallBean bean = new NestedSmallBean();
+    Encoder<NestedSmallBean> encoder = Encoders.bean(NestedSmallBean.class);
+    List<NestedSmallBean> beans = Arrays.asList(bean);
+    Dataset<NestedSmallBean> ds1 = spark.createDataset(beans, encoder);
+    Assert.assertEquals(beans, ds1.collectAsList());
+    Dataset<NestedSmallBean> ds2 =
+      ds1.map(new MapFunction<NestedSmallBean, NestedSmallBean>() {
+        @Override
+        public NestedSmallBean call(NestedSmallBean b) throws Exception {
+          return b;
+        }
+      }, encoder);
+    Assert.assertEquals(beans, ds2.collectAsList());
+  }
 }

