Repository: spark
Updated Branches:
  refs/heads/master 77361a433 -> 55cc1c991


[SPARK-14139][SQL] RowEncoder should preserve schema nullability

## What changes were proposed in this pull request?

The problem is: In `RowEncoder`, we use `Invoke` to get the field of an
external row, which loses the nullability information. This PR creates a
`GetExternalRowField` expression, so that we can preserve the nullability info.

TODO: simplify the null-handling logic in `RowEncoder` in a follow-up PR, to
remove the many `if` branches.

## How was this patch tested?

New tests in `RowEncoderSuite` and `DatasetSuite`.

Note that this PR takes over https://github.com/apache/spark/pull/11980, with
a little simplification, so all credit should go to koertkuipers.

Author: Wenchen Fan <wenc...@databricks.com>
Author: Koert Kuipers <ko...@tresata.com>

Closes #12364 from cloud-fan/nullable.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/55cc1c99
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/55cc1c99
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/55cc1c99

Branch: refs/heads/master
Commit: 55cc1c991a9e39efb14177a948b09b7909e53e25
Parents: 77361a4
Author: Wenchen Fan <wenc...@databricks.com>
Authored: Fri May 6 01:08:04 2016 +0800
Committer: Cheng Lian <l...@databricks.com>
Committed: Fri May 6 01:08:04 2016 +0800

----------------------------------------------------------------------
 .../sql/catalyst/encoders/RowEncoder.scala      | 36 ++++++++++-------
 .../sql/catalyst/expressions/objects.scala      | 42 ++++++++++++++++++++
 .../sql/catalyst/encoders/RowEncoderSuite.scala |  8 ++++
 .../org/apache/spark/sql/DatasetSuite.scala     | 18 ++++++++-
 4 files changed, 88 insertions(+), 16 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/55cc1c99/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index 44e135c..cfde3bf 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -35,9 +35,8 @@ import org.apache.spark.unsafe.types.UTF8String
 object RowEncoder {
   def apply(schema: StructType): ExpressionEncoder[Row] = {
     val cls = classOf[Row]
-    val inputObject = BoundReference(0, ObjectType(cls), nullable = true)
-    // We use an If expression to wrap extractorsFor result of StructType
-    val serializer = serializerFor(inputObject, 
schema).asInstanceOf[If].falseValue
+    val inputObject = BoundReference(0, ObjectType(cls), nullable = false)
+    val serializer = serializerFor(inputObject, schema)
     val deserializer = deserializerFor(schema)
     new ExpressionEncoder[Row](
       schema,
@@ -130,21 +129,28 @@ object RowEncoder {
 
     case StructType(fields) =>
       val convertedFields = fields.zipWithIndex.map { case (f, i) =>
-        val method = if (f.dataType.isInstanceOf[StructType]) {
-          "getStruct"
+        val fieldValue = serializerFor(
+          GetExternalRowField(inputObject, i, 
externalDataTypeForInput(f.dataType)),
+          f.dataType
+        )
+        if (f.nullable) {
+          If(
+            Invoke(inputObject, "isNullAt", BooleanType, Literal(i) :: Nil),
+            Literal.create(null, f.dataType),
+            fieldValue
+          )
         } else {
-          "get"
+          fieldValue
         }
-        If(
-          Invoke(inputObject, "isNullAt", BooleanType, Literal(i) :: Nil),
-          Literal.create(null, f.dataType),
-          serializerFor(
-            Invoke(inputObject, method, externalDataTypeForInput(f.dataType), 
Literal(i) :: Nil),
-            f.dataType))
       }
-      If(IsNull(inputObject),
-        Literal.create(null, inputType),
-        CreateStruct(convertedFields))
+
+      if (inputObject.nullable) {
+        If(IsNull(inputObject),
+          Literal.create(null, inputType),
+          CreateStruct(convertedFields))
+      } else {
+        CreateStruct(convertedFields)
+      }
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/55cc1c99/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
index 523eed8..dbaff16 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
@@ -688,3 +688,45 @@ case class AssertNotNull(child: Expression, 
walkedTypePath: Seq[String])
     ev.copy(code = code, isNull = "false", value = childGen.value)
   }
 }
+
+/**
+ * Returns the value of field at index `index` from the external row `child`.
+ * This class can be viewed as [[GetStructField]] for [[Row]]s instead of 
[[InternalRow]]s.
+ *
+ * Note that the input row and the field we try to get are both guaranteed to 
be not null, if they
+ * are null, a runtime exception will be thrown.
+ */
+case class GetExternalRowField(
+    child: Expression,
+    index: Int,
+    dataType: DataType) extends UnaryExpression with NonSQLExpression {
+
+  override def nullable: Boolean = false
+
+  override def eval(input: InternalRow): Any =
+    throw new UnsupportedOperationException("Only code-generated evaluation is 
supported")
+
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    val row = child.genCode(ctx)
+
+    val getField = dataType match {
+      case ObjectType(x) if x == classOf[Row] => 
s"""${row.value}.getStruct($index)"""
+      case _ => s"""(${ctx.boxedType(dataType)}) ${row.value}.get($index)"""
+    }
+
+    val code = s"""
+      ${row.code}
+
+      if (${row.isNull}) {
+        throw new RuntimeException("The input external row cannot be null.");
+      }
+
+      if (${row.value}.isNullAt($index)) {
+        throw new RuntimeException("The ${index}th field of input row cannot 
be null.");
+      }
+
+      final ${ctx.javaType(dataType)} ${ev.value} = $getField;
+     """
+    ev.copy(code = code, isNull = "false")
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/55cc1c99/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
index a8fa372..98be3b0 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
@@ -160,6 +160,14 @@ class RowEncoderSuite extends SparkFunSuite {
       .compareTo(convertedBack.getDecimal(3)) == 0)
   }
 
+  test("RowEncoder should preserve schema nullability") {
+    val schema = new StructType().add("int", IntegerType, nullable = false)
+    val encoder = RowEncoder(schema)
+    assert(encoder.serializer.length == 1)
+    assert(encoder.serializer.head.dataType == IntegerType)
+    assert(encoder.serializer.head.nullable == false)
+  }
+
   private def encodeDecodeTest(schema: StructType): Unit = {
     test(s"encode/decode: ${schema.simpleString}") {
       val encoder = RowEncoder(schema)

http://git-wip-us.apache.org/repos/asf/spark/blob/55cc1c99/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 68a12b0..3cb4e52 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -22,7 +22,7 @@ import java.sql.{Date, Timestamp}
 
 import scala.language.postfixOps
 
-import org.apache.spark.sql.catalyst.encoders.OuterScopes
+import org.apache.spark.sql.catalyst.encoders.{OuterScopes, RowEncoder}
 import org.apache.spark.sql.execution.streaming.MemoryStream
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.SharedSQLContext
@@ -658,6 +658,22 @@ class DatasetSuite extends QueryTest with SharedSQLContext 
{
     val dataset = Seq(1, 2, 3).toDS()
     checkDataset(DatasetTransform.addOne(dataset), 2, 3, 4)
   }
+
+  test("runtime null check for RowEncoder") {
+    val schema = new StructType().add("i", IntegerType, nullable = false)
+    val df = sqlContext.range(10).map(l => {
+      if (l % 5 == 0) {
+        Row(null)
+      } else {
+        Row(l)
+      }
+    })(RowEncoder(schema))
+
+    val message = intercept[Exception] {
+      df.collect()
+    }.getMessage
+    assert(message.contains("The 0th field of input row cannot be null"))
+  }
 }
 
 case class OtherTuple(_1: String, _2: Int)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to