Repository: spark
Updated Branches:
  refs/heads/master 8c1b867ce -> 42bfde298


[SPARK-12371][SQL] Runtime nullability check for NewInstance

This PR adds a new expression `AssertNotNull` to ensure that non-nullable fields of
Scala products (tuples and case classes) and Java beans don't receive null values
at runtime.

Author: Cheng Lian <l...@databricks.com>

Closes #10331 from liancheng/dataset-nullability-check.
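
To make the new behavior concrete, here is a minimal usage sketch (illustrative
only, not part of this patch; it assumes a spark-shell session with a SQLContext
named `sqlContext`, and mirrors the `ClassData` test added to DatasetSuite.scala
below):

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
    import sqlContext.implicits._

    case class ClassData(a: String, b: Int)  // `b` is a primitive Int, hence non-nullable

    // Column `b` is declared nullable, so rows may carry a null there.
    val schema = StructType(Seq(
      StructField("a", StringType, nullable = true),
      StructField("b", IntegerType, nullable = true)))

    val rows = sqlContext.sparkContext.parallelize(Seq(Row("hello", null)))
    val ds = sqlContext.createDataFrame(rows, schema).as[ClassData]

    // With this patch, decoding fails fast instead of silently producing a bogus value:
    //   RuntimeException: "Null value appeared in non-nullable field ... of type Int. ..."
    ds.collect()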


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/42bfde29
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/42bfde29
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/42bfde29

Branch: refs/heads/master
Commit: 42bfde29836529251a4337ea8cfc539c9c8b04b8
Parents: 8c1b867
Author: Cheng Lian <l...@databricks.com>
Authored: Tue Dec 22 19:41:44 2015 +0800
Committer: Cheng Lian <l...@databricks.com>
Committed: Tue Dec 22 19:41:44 2015 +0800

----------------------------------------------------------------------
 .../spark/sql/catalyst/JavaTypeInference.scala  |   9 +-
 .../spark/sql/catalyst/ScalaReflection.scala    |  10 +-
 .../sql/catalyst/expressions/objects.scala      |  40 ++++++
 .../encoders/EncoderResolutionSuite.scala       |  21 +++-
 .../scala/org/apache/spark/sql/Dataset.scala    |   3 +-
 .../org/apache/spark/sql/JavaDatasetSuite.java  | 126 ++++++++++++++++++-
 .../org/apache/spark/sql/DatasetSuite.scala     |  33 +++++
 7 files changed, 232 insertions(+), 10 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/42bfde29/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
index f566d1b..a1500cb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
@@ -288,7 +288,14 @@ object JavaTypeInference {
         val setters = properties.map { p =>
           val fieldName = p.getName
           val fieldType = typeToken.method(p.getReadMethod).getReturnType
-          p.getWriteMethod.getName -> constructorFor(fieldType, Some(addToPath(fieldName)))
+          val (_, nullable) = inferDataType(fieldType)
+          val constructor = constructorFor(fieldType, Some(addToPath(fieldName)))
+          val setter = if (nullable) {
+            constructor
+          } else {
+            AssertNotNull(constructor, other.getName, fieldName, fieldType.toString)
+          }
+          p.getWriteMethod.getName -> setter
         }.toMap
 
         val newInstance = NewInstance(other, Nil, propagateNull = false, ObjectType(other))

http://git-wip-us.apache.org/repos/asf/spark/blob/42bfde29/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index cc9e6af..becd019 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -326,7 +326,7 @@ object ScalaReflection extends ScalaReflection {
         val cls = getClassFromType(tpe)
 
        val arguments = params.zipWithIndex.map { case ((fieldName, fieldType), i) =>
-          val dataType = schemaFor(fieldType).dataType
+          val Schema(dataType, nullable) = schemaFor(fieldType)
           val clsName = getClassNameFromType(fieldType)
          val newTypePath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath
          // For tuples, we grab the inner fields by ordinal instead of name.
@@ -336,10 +336,16 @@ object ScalaReflection extends ScalaReflection {
               Some(addToPathOrdinal(i, dataType, newTypePath)),
               newTypePath)
           } else {
-            constructorFor(
+            val constructor = constructorFor(
               fieldType,
               Some(addToPath(fieldName, dataType, newTypePath)),
               newTypePath)
+
+            if (!nullable) {
+              AssertNotNull(constructor, t.toString, fieldName, fieldType.toString)
+            } else {
+              constructor
+            }
           }
         }
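
End to end, resolving the encoder for a case class such as `StringLongClass(a: String, b: Long)` against the attributes `('a.string, 'b.int)` now yields a construction expression in which only the non-nullable field is wrapped; this is the expected tree asserted in EncoderResolutionSuite below (`toExternalString` is a helper defined in that suite):

    NewInstance(
      cls,
      Seq(
        toExternalString('a.string),
        AssertNotNull('b.int.cast(LongType), cls.getName, "b", "Long")
      ),
      false,
      ObjectType(cls))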
 

http://git-wip-us.apache.org/repos/asf/spark/blob/42bfde29/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
index 492cc9b..d40cd96 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
@@ -624,3 +624,43 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp
      """
   }
 }
+
+/**
+ * Asserts that input values of a non-nullable child expression are not null.
+ *
+ * Note that there are cases where `child.nullable == true`, while we still need to add this
+ * assertion.  Consider a nullable column `s` whose data type is a struct containing a non-nullable
+ * `Int` field named `i`.  Expression `s.i` is nullable because `s` can be null.  However, for all
+ * non-null `s`, `s.i` can't be null.
+ */
+case class AssertNotNull(
+    child: Expression, parentType: String, fieldName: String, fieldType: String)
+  extends UnaryExpression {
+
+  override def dataType: DataType = child.dataType
+
+  override def nullable: Boolean = false
+
+  override def eval(input: InternalRow): Any =
+    throw new UnsupportedOperationException("Only code-generated evaluation is supported.")
+
+  override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    val childGen = child.gen(ctx)
+
+    ev.isNull = "false"
+    ev.value = childGen.value
+
+    s"""
+      ${childGen.code}
+
+      if (${childGen.isNull}) {
+        throw new RuntimeException(
+          "Null value appeared in non-nullable field $parentType.$fieldName of 
type $fieldType. " +
+          "If the schema is inferred from a Scala tuple/case class, or a Java 
bean, " +
+          "please try to use scala.Option[_] or other nullable types " +
+          "(e.g. java.lang.Integer instead of int/scala.Int)."
+        );
+      }
+     """
+  }
+}
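
As a reading aid, the check emitted by `genCode` behaves roughly like the hand-written Scala below (a sketch only; the actual check is generated as Java source and inlined by the code generator, and `assertNotNull` is a hypothetical name):

    // Illustrative equivalent of the generated null check.
    def assertNotNull[T <: AnyRef](value: T, parentType: String, fieldName: String, fieldType: String): T = {
      if (value == null) {
        throw new RuntimeException(
          s"Null value appeared in non-nullable field $parentType.$fieldName of type $fieldType. " +
            "If the schema is inferred from a Scala tuple/case class, or a Java bean, " +
            "please try to use scala.Option[_] or other nullable types " +
            "(e.g. java.lang.Integer instead of int/scala.Int).")
      }
      value  // non-null: pass the child's value through unchanged
    }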

http://git-wip-us.apache.org/repos/asf/spark/blob/42bfde29/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
index 815a03f..764ffdc 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
@@ -36,12 +36,16 @@ class EncoderResolutionSuite extends PlanTest {
     val encoder = ExpressionEncoder[StringLongClass]
     val cls = classOf[StringLongClass]
 
+
     {
       val attrs = Seq('a.string, 'b.int)
      val fromRowExpr: Expression = encoder.resolve(attrs, null).fromRowExpression
       val expected: Expression = NewInstance(
         cls,
-        toExternalString('a.string) :: 'b.int.cast(LongType) :: Nil,
+        Seq(
+          toExternalString('a.string),
+          AssertNotNull('b.int.cast(LongType), cls.getName, "b", "Long")
+        ),
         false,
         ObjectType(cls))
       compareExpressions(fromRowExpr, expected)
@@ -52,7 +56,10 @@ class EncoderResolutionSuite extends PlanTest {
       val fromRowExpr = encoder.resolve(attrs, null).fromRowExpression
       val expected = NewInstance(
         cls,
-        toExternalString('a.int.cast(StringType)) :: 'b.long :: Nil,
+        Seq(
+          toExternalString('a.int.cast(StringType)),
+          AssertNotNull('b.long, cls.getName, "b", "Long")
+        ),
         false,
         ObjectType(cls))
       compareExpressions(fromRowExpr, expected)
@@ -69,7 +76,7 @@ class EncoderResolutionSuite extends PlanTest {
     val expected: Expression = NewInstance(
       cls,
       Seq(
-        'a.int.cast(LongType),
+        AssertNotNull('a.int.cast(LongType), cls.getName, "a", "Long"),
         If(
           'b.struct('a.int, 'b.long).isNull,
           Literal.create(null, ObjectType(innerCls)),
@@ -78,7 +85,9 @@ class EncoderResolutionSuite extends PlanTest {
             Seq(
               toExternalString(
                GetStructField('b.struct('a.int, 'b.long), 0, Some("a")).cast(StringType)),
-              GetStructField('b.struct('a.int, 'b.long), 1, Some("b"))),
+              AssertNotNull(
+                GetStructField('b.struct('a.int, 'b.long), 1, Some("b")),
+                innerCls.getName, "b", "Long")),
             false,
             ObjectType(innerCls))
         )),
@@ -102,7 +111,9 @@ class EncoderResolutionSuite extends PlanTest {
           cls,
           Seq(
            toExternalString(GetStructField('a.struct('a.string, 'b.byte), 0, Some("a"))),
-            GetStructField('a.struct('a.string, 'b.byte), 1, Some("b")).cast(LongType)),
+            AssertNotNull(
+              GetStructField('a.struct('a.string, 'b.byte), 1, Some("b")).cast(LongType),
+              cls.getName, "b", "Long")),
           false,
           ObjectType(cls)),
         'b.int.cast(LongType)),

http://git-wip-us.apache.org/repos/asf/spark/blob/42bfde29/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index d201d65..a763a95 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql
 
 import scala.collection.JavaConverters._
 
+import org.apache.spark.Logging
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.api.java.function._
 import org.apache.spark.rdd.RDD
@@ -64,7 +65,7 @@ import org.apache.spark.util.Utils
 class Dataset[T] private[sql](
     @transient override val sqlContext: SQLContext,
     @transient override val queryExecution: QueryExecution,
-    tEncoder: Encoder[T]) extends Queryable with Serializable {
+    tEncoder: Encoder[T]) extends Queryable with Serializable with Logging {
 
   /**
   * An unresolved version of the internal encoder for the type of this [[Dataset]].  This one is

http://git-wip-us.apache.org/repos/asf/spark/blob/42bfde29/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
----------------------------------------------------------------------
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index 0dbaeb8..9f8db39 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -23,6 +23,8 @@ import java.sql.Date;
 import java.sql.Timestamp;
 import java.util.*;
 
+import com.google.common.base.Objects;
+import org.junit.rules.ExpectedException;
 import scala.Tuple2;
 import scala.Tuple3;
 import scala.Tuple4;
@@ -39,7 +41,6 @@ import org.apache.spark.sql.expressions.Aggregator;
 import org.apache.spark.sql.test.TestSQLContext;
 import org.apache.spark.sql.catalyst.encoders.OuterScopes;
 import org.apache.spark.sql.catalyst.expressions.GenericRow;
-import org.apache.spark.sql.types.DecimalType;
 import org.apache.spark.sql.types.StructType;
 
 import static org.apache.spark.sql.functions.*;
@@ -741,4 +742,127 @@ public class JavaDatasetSuite implements Serializable {
      context.createDataset(Arrays.asList(obj), Encoders.bean(SimpleJavaBean2.class));
     ds.collect();
   }
+
+  public class SmallBean implements Serializable {
+    private String a;
+
+    private int b;
+
+    public int getB() {
+      return b;
+    }
+
+    public void setB(int b) {
+      this.b = b;
+    }
+
+    public String getA() {
+      return a;
+    }
+
+    public void setA(String a) {
+      this.a = a;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+      if (this == o) return true;
+      if (o == null || getClass() != o.getClass()) return false;
+      SmallBean smallBean = (SmallBean) o;
+      return b == smallBean.b && com.google.common.base.Objects.equal(a, smallBean.a);
+    }
+
+    @Override
+    public int hashCode() {
+      return Objects.hashCode(a, b);
+    }
+  }
+
+  public class NestedSmallBean implements Serializable {
+    private SmallBean f;
+
+    public SmallBean getF() {
+      return f;
+    }
+
+    public void setF(SmallBean f) {
+      this.f = f;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+      if (this == o) return true;
+      if (o == null || getClass() != o.getClass()) return false;
+      NestedSmallBean that = (NestedSmallBean) o;
+      return Objects.equal(f, that.f);
+    }
+
+    @Override
+    public int hashCode() {
+      return Objects.hashCode(f);
+    }
+  }
+
+  @Rule
+  public transient ExpectedException nullabilityCheck = ExpectedException.none();
+
+  @Test
+  public void testRuntimeNullabilityCheck() {
+    OuterScopes.addOuterScope(this);
+
+    StructType schema = new StructType()
+      .add("f", new StructType()
+        .add("a", StringType, true)
+        .add("b", IntegerType, true), true);
+
+    // Shouldn't throw runtime exception since it passes nullability check.
+    {
+      Row row = new GenericRow(new Object[] {
+          new GenericRow(new Object[] {
+              "hello", 1
+          })
+      });
+
+      DataFrame df = context.createDataFrame(Collections.singletonList(row), schema);
+      Dataset<NestedSmallBean> ds = df.as(Encoders.bean(NestedSmallBean.class));
+
+      SmallBean smallBean = new SmallBean();
+      smallBean.setA("hello");
+      smallBean.setB(1);
+
+      NestedSmallBean nestedSmallBean = new NestedSmallBean();
+      nestedSmallBean.setF(smallBean);
+
+      Assert.assertEquals(ds.collectAsList(), Collections.singletonList(nestedSmallBean));
+    }
+
+    // Shouldn't throw runtime exception when parent object (`SmallBean`) is null
+    {
+      Row row = new GenericRow(new Object[] { null });
+
+      DataFrame df = context.createDataFrame(Collections.singletonList(row), schema);
+      Dataset<NestedSmallBean> ds = df.as(Encoders.bean(NestedSmallBean.class));
+
+      NestedSmallBean nestedSmallBean = new NestedSmallBean();
+      Assert.assertEquals(ds.collectAsList(), Collections.singletonList(nestedSmallBean));
+    }
+
+    nullabilityCheck.expect(RuntimeException.class);
+    nullabilityCheck.expectMessage(
+      "Null value appeared in non-nullable field " +
+        "test.org.apache.spark.sql.JavaDatasetSuite$SmallBean.b of type int.");
+
+    {
+      Row row = new GenericRow(new Object[] {
+          new GenericRow(new Object[] {
+              "hello", null
+          })
+      });
+
+      DataFrame df = context.createDataFrame(Collections.singletonList(row), schema);
+      Dataset<NestedSmallBean> ds = df.as(Encoders.bean(NestedSmallBean.class));
+
+      ds.collect();
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/42bfde29/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index de012a9..3337996 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -24,6 +24,7 @@ import scala.language.postfixOps
 
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
 
 
 class DatasetSuite extends QueryTest with SharedSQLContext {
@@ -515,12 +516,44 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     }
    assert(e.getMessage.contains("cannot resolve 'c' given input columns a, b"), e.getMessage)
   }
+
+  test("runtime nullability check") {
+    val schema = StructType(Seq(
+      StructField("f", StructType(Seq(
+        StructField("a", StringType, nullable = true),
+        StructField("b", IntegerType, nullable = false)
+      )), nullable = true)
+    ))
+
+    def buildDataset(rows: Row*): Dataset[NestedStruct] = {
+      val rowRDD = sqlContext.sparkContext.parallelize(rows)
+      sqlContext.createDataFrame(rowRDD, schema).as[NestedStruct]
+    }
+
+    checkAnswer(
+      buildDataset(Row(Row("hello", 1))),
+      NestedStruct(ClassData("hello", 1))
+    )
+
+    // Shouldn't throw runtime exception when parent object (`ClassData`) is null
+    assert(buildDataset(Row(null)).collect() === Array(NestedStruct(null)))
+
+    val message = intercept[RuntimeException] {
+      buildDataset(Row(Row("hello", null))).collect()
+    }.getMessage
+
+    assert(message.contains(
+      "Null value appeared in non-nullable field 
org.apache.spark.sql.ClassData.b of type Int."
+    ))
+  }
 }
 
 case class ClassData(a: String, b: Int)
 case class ClassData2(c: String, d: Int)
 case class ClassNullableData(a: String, b: Integer)
 
+case class NestedStruct(f: ClassData)
+
 /**
 * A class used to test serialization using encoders. This class throws exceptions when using
 * Java serialization -- so the only way it can be "serialized" is through our encoders.

