Repository: spark Updated Branches: refs/heads/master 39ad53f7f -> 01277d4b2
[SPARK-16097][SQL] Encoders.tuple should handle null object correctly ## What changes were proposed in this pull request? Although the top-level input object cannot be null, when we use `Encoders.tuple` to combine 2 encoders, their input objects are no longer top level and can be null. We should handle this case. ## How was this patch tested? new test in DatasetSuite Author: Wenchen Fan <wenc...@databricks.com> Closes #13807 from cloud-fan/bug. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/01277d4b Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/01277d4b Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/01277d4b Branch: refs/heads/master Commit: 01277d4b259dcf9cad25eece1377162b7a8c946d Parents: 39ad53f Author: Wenchen Fan <wenc...@databricks.com> Authored: Wed Jun 22 18:32:14 2016 +0800 Committer: Cheng Lian <l...@databricks.com> Committed: Wed Jun 22 18:32:14 2016 +0800 ---------------------------------------------------------------------- .../catalyst/encoders/ExpressionEncoder.scala | 48 ++++++++++++++------ .../org/apache/spark/sql/DatasetSuite.scala | 7 +++ 2 files changed, 42 insertions(+), 13 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/01277d4b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 0023ce6..1fac26c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -28,7 +28,7 @@ import 
org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, Invoke, NewInstance} import org.apache.spark.sql.catalyst.optimizer.SimplifyCasts import org.apache.spark.sql.catalyst.plans.logical.{CatalystSerde, DeserializeToObject, LocalRelation} -import org.apache.spark.sql.types.{ObjectType, StructField, StructType} +import org.apache.spark.sql.types.{BooleanType, ObjectType, StructField, StructType} import org.apache.spark.util.Utils /** @@ -110,16 +110,34 @@ object ExpressionEncoder { val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}") - val serializer = encoders.map { - case e if e.flat => e.serializer.head - case other => CreateStruct(other.serializer) - }.zipWithIndex.map { case (expr, index) => - expr.transformUp { - case BoundReference(0, t, _) => - Invoke( - BoundReference(0, ObjectType(cls), nullable = true), - s"_${index + 1}", - t) + val serializer = encoders.zipWithIndex.map { case (enc, index) => + val originalInputObject = enc.serializer.head.collect { case b: BoundReference => b }.head + val newInputObject = Invoke( + BoundReference(0, ObjectType(cls), nullable = true), + s"_${index + 1}", + originalInputObject.dataType) + + val newSerializer = enc.serializer.map(_.transformUp { + case b: BoundReference if b == originalInputObject => newInputObject + }) + + if (enc.flat) { + newSerializer.head + } else { + // For non-flat encoder, the input object is not top level anymore after being combined to + // a tuple encoder, thus it can be null and we should wrap the `CreateStruct` with `If` and + // null check to handle null case correctly. + // e.g. for Encoder[(Int, String)], the serializer expressions will create 2 columns, and is + // not able to handle the case when the input tuple is null. This is not a problem as there + // is a check to make sure the input object won't be null. 
However, if this encoder is used + // to create a bigger tuple encoder, the original input object becomes a filed of the new + // input tuple and can be null. So instead of creating a struct directly here, we should add + // a null/None check and return a null struct if the null/None check fails. + val struct = CreateStruct(newSerializer) + val nullCheck = Or( + IsNull(newInputObject), + Invoke(Literal.fromObject(None), "equals", BooleanType, newInputObject :: Nil)) + If(nullCheck, Literal.create(null, struct.dataType), struct) } } @@ -203,8 +221,12 @@ case class ExpressionEncoder[T]( // (intermediate value is not an attribute). We assume that all serializer expressions use a same // `BoundReference` to refer to the object, and throw exception if they don't. assert(serializer.forall(_.references.isEmpty), "serializer cannot reference to any attributes.") - assert(serializer.flatMap(_.collect { case b: BoundReference => b}).distinct.length <= 1, - "all serializer expressions must use the same BoundReference.") + assert(serializer.flatMap { ser => + val boundRefs = ser.collect { case b: BoundReference => b } + assert(boundRefs.nonEmpty, + "each serializer expression should contains at least one `BoundReference`") + boundRefs + }.distinct.length <= 1, "all serializer expressions must use the same BoundReference.") /** * Returns a new copy of this encoder, where the `deserializer` is resolved and bound to the http://git-wip-us.apache.org/repos/asf/spark/blob/01277d4b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index f02a314..bd8479b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -830,6 +830,13 @@ class DatasetSuite extends QueryTest with 
SharedSQLContext { ds.dropDuplicates("_1", "_2"), ("a", 1), ("a", 2), ("b", 1)) } + + test("SPARK-16097: Encoders.tuple should handle null object correctly") { + val enc = Encoders.tuple(Encoders.tuple(Encoders.STRING, Encoders.STRING), Encoders.STRING) + val data = Seq((("a", "b"), "c"), (null, "d")) + val ds = spark.createDataset(data)(enc) + checkDataset(ds, (("a", "b"), "c"), (null, "d")) + } } case class Generic[T](id: T, value: Double) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org