Repository: spark
Updated Branches:
  refs/heads/master 39ad53f7f -> 01277d4b2


[SPARK-16097][SQL] Encoders.tuple should handle null object correctly

## What changes were proposed in this pull request?

Although the top-level input object cannot be null, when we use
`Encoders.tuple` to combine two encoders, their input objects are no longer
top level and can be null. We should handle this case.
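
A minimal sketch of the failing scenario, based on the new test (assumes a
`SparkSession` named `spark`; everything outside the test is illustrative):

```scala
import org.apache.spark.sql.Encoders

// Combine two encoders: the inner tuple encoder's input object is now a
// field of the outer tuple, so it can be null even though the top-level
// object cannot.
val enc = Encoders.tuple(Encoders.tuple(Encoders.STRING, Encoders.STRING), Encoders.STRING)
val data = Seq((("a", "b"), "c"), (null, "d"))  // inner tuple is null in row 2
val ds = spark.createDataset(data)(enc)         // mishandled before this patch
```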

## How was this patch tested?

A new test in `DatasetSuite`.

Author: Wenchen Fan <wenc...@databricks.com>

Closes #13807 from cloud-fan/bug.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/01277d4b
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/01277d4b
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/01277d4b

Branch: refs/heads/master
Commit: 01277d4b259dcf9cad25eece1377162b7a8c946d
Parents: 39ad53f
Author: Wenchen Fan <wenc...@databricks.com>
Authored: Wed Jun 22 18:32:14 2016 +0800
Committer: Cheng Lian <l...@databricks.com>
Committed: Wed Jun 22 18:32:14 2016 +0800

----------------------------------------------------------------------
 .../catalyst/encoders/ExpressionEncoder.scala   | 48 ++++++++++++++------
 .../org/apache/spark/sql/DatasetSuite.scala     |  7 +++
 2 files changed, 42 insertions(+), 13 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/01277d4b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index 0023ce6..1fac26c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection
 import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, Invoke, NewInstance}
 import org.apache.spark.sql.catalyst.optimizer.SimplifyCasts
 import org.apache.spark.sql.catalyst.plans.logical.{CatalystSerde, DeserializeToObject, LocalRelation}
-import org.apache.spark.sql.types.{ObjectType, StructField, StructType}
+import org.apache.spark.sql.types.{BooleanType, ObjectType, StructField, StructType}
 import org.apache.spark.util.Utils
 
 /**
@@ -110,16 +110,34 @@ object ExpressionEncoder {
 
     val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}")
 
-    val serializer = encoders.map {
-      case e if e.flat => e.serializer.head
-      case other => CreateStruct(other.serializer)
-    }.zipWithIndex.map { case (expr, index) =>
-      expr.transformUp {
-        case BoundReference(0, t, _) =>
-          Invoke(
-            BoundReference(0, ObjectType(cls), nullable = true),
-            s"_${index + 1}",
-            t)
+    val serializer = encoders.zipWithIndex.map { case (enc, index) =>
+      val originalInputObject = enc.serializer.head.collect { case b: BoundReference => b }.head
+      val newInputObject = Invoke(
+        BoundReference(0, ObjectType(cls), nullable = true),
+        s"_${index + 1}",
+        originalInputObject.dataType)
+
+      val newSerializer = enc.serializer.map(_.transformUp {
+        case b: BoundReference if b == originalInputObject => newInputObject
+      })
+
+      if (enc.flat) {
+        newSerializer.head
+      } else {
+        // For a non-flat encoder, the input object is no longer top level after being combined
+        // into a tuple encoder, thus it can be null, and we should wrap the `CreateStruct` with
+        // `If` and a null check to handle the null case correctly.
+        // e.g. for Encoder[(Int, String)], the serializer expressions will create 2 columns, and
+        // are not able to handle the case when the input tuple is null. This is not a problem as
+        // there is a check to make sure the input object won't be null. However, if this encoder
+        // is used to create a bigger tuple encoder, the original input object becomes a field of
+        // the new input tuple and can be null. So instead of creating a struct directly here, we
+        // should add a null/None check and return a null struct if the null/None check fails.
+        val struct = CreateStruct(newSerializer)
+        val nullCheck = Or(
+          IsNull(newInputObject),
+          Invoke(Literal.fromObject(None), "equals", BooleanType, newInputObject :: Nil))
+        If(nullCheck, Literal.create(null, struct.dataType), struct)
       }
     }
 
@@ -203,8 +221,12 @@ case class ExpressionEncoder[T](
   // (intermediate value is not an attribute). We assume that all serializer expressions use the
   // same `BoundReference` to refer to the object, and throw an exception if they don't.
   assert(serializer.forall(_.references.isEmpty), "serializer cannot reference any attributes.")
-  assert(serializer.flatMap(_.collect { case b: BoundReference => b}).distinct.length <= 1,
-    "all serializer expressions must use the same BoundReference.")
+  assert(serializer.flatMap { ser =>
+    val boundRefs = ser.collect { case b: BoundReference => b }
+    assert(boundRefs.nonEmpty,
+      "each serializer expression should contain at least one `BoundReference`")
+    boundRefs
+  }.distinct.length <= 1, "all serializer expressions must use the same BoundReference.")
 
   /**
   * Returns a new copy of this encoder, where the `deserializer` is resolved and bound to the
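
To make the comment in `ExpressionEncoder.tuple` concrete, here is a hedged
sketch of the expression built for one non-flat component encoder. It mirrors
the patch above (`cls`, `index`, `originalInputObject`, and `newSerializer`
are the local names from the diff); read it as the conceptual shape of the
serializer, not a literal generated plan:

```scala
// Extract field `index + 1` of the combined tuple; unlike the old code's
// top-level BoundReference, this value can be null.
val newInputObject = Invoke(
  BoundReference(0, ObjectType(cls), nullable = true),
  s"_${index + 1}",
  originalInputObject.dataType)

// The null check covers both a null reference and a None (Option-typed
// input); either way we emit a null struct instead of failing.
val struct = CreateStruct(newSerializer)
val nullCheck = Or(
  IsNull(newInputObject),
  Invoke(Literal.fromObject(None), "equals", BooleanType, newInputObject :: Nil))
If(nullCheck, Literal.create(null, struct.dataType), struct)
```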

http://git-wip-us.apache.org/repos/asf/spark/blob/01277d4b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index f02a314..bd8479b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -830,6 +830,13 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       ds.dropDuplicates("_1", "_2"),
       ("a", 1), ("a", 2), ("b", 1))
   }
+
+  test("SPARK-16097: Encoders.tuple should handle null object correctly") {
+    val enc = Encoders.tuple(Encoders.tuple(Encoders.STRING, Encoders.STRING), Encoders.STRING)
+    val data = Seq((("a", "b"), "c"), (null, "d"))
+    val ds = spark.createDataset(data)(enc)
+    checkDataset(ds, (("a", "b"), "c"), (null, "d"))
+  }
 }
 
 case class Generic[T](id: T, value: Double)
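
With the fix, the null inner tuple survives the round trip. A hedged usage
note, reusing `ds` from the repro sketch above (`checkDataset` in the new
test asserts the same thing):

```scala
// Expected after this patch: row 2's inner tuple deserializes back to null
// instead of raising an error during serialization.
assert(ds.collect().toSeq == Seq((("a", "b"), "c"), (null, "d")))
```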

