Repository: spark Updated Branches: refs/heads/master 8c1b867ce -> 42bfde298
[SPARK-12371][SQL] Runtime nullability check for NewInstance This PR adds a new expression `AssertNotNull` to ensure non-nullable fields of products and case classes don't receive null values at runtime. Author: Cheng Lian <l...@databricks.com> Closes #10331 from liancheng/dataset-nullability-check. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/42bfde29 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/42bfde29 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/42bfde29 Branch: refs/heads/master Commit: 42bfde29836529251a4337ea8cfc539c9c8b04b8 Parents: 8c1b867 Author: Cheng Lian <l...@databricks.com> Authored: Tue Dec 22 19:41:44 2015 +0800 Committer: Cheng Lian <l...@databricks.com> Committed: Tue Dec 22 19:41:44 2015 +0800 ---------------------------------------------------------------------- .../spark/sql/catalyst/JavaTypeInference.scala | 9 +- .../spark/sql/catalyst/ScalaReflection.scala | 10 +- .../sql/catalyst/expressions/objects.scala | 40 ++++++ .../encoders/EncoderResolutionSuite.scala | 21 +++- .../scala/org/apache/spark/sql/Dataset.scala | 3 +- .../org/apache/spark/sql/JavaDatasetSuite.java | 126 ++++++++++++++++++- .../org/apache/spark/sql/DatasetSuite.scala | 33 +++++ 7 files changed, 232 insertions(+), 10 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/42bfde29/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index f566d1b..a1500cb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -288,7 +288,14 @@ object JavaTypeInference { val setters = properties.map { p => val fieldName = p.getName val fieldType = typeToken.method(p.getReadMethod).getReturnType - p.getWriteMethod.getName -> constructorFor(fieldType, Some(addToPath(fieldName))) + val (_, nullable) = inferDataType(fieldType) + val constructor = constructorFor(fieldType, Some(addToPath(fieldName))) + val setter = if (nullable) { + constructor + } else { + AssertNotNull(constructor, other.getName, fieldName, fieldType.toString) + } + p.getWriteMethod.getName -> setter }.toMap val newInstance = NewInstance(other, Nil, propagateNull = false, ObjectType(other)) http://git-wip-us.apache.org/repos/asf/spark/blob/42bfde29/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index cc9e6af..becd019 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -326,7 +326,7 @@ object ScalaReflection extends ScalaReflection { val cls = getClassFromType(tpe) val arguments = params.zipWithIndex.map { case ((fieldName, fieldType), i) => - val dataType = schemaFor(fieldType).dataType + val Schema(dataType, nullable) = schemaFor(fieldType) val clsName = getClassNameFromType(fieldType) val newTypePath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath // For tuples, we based grab the inner fields by ordinal instead of name. 
@@ -336,10 +336,16 @@ object ScalaReflection extends ScalaReflection { Some(addToPathOrdinal(i, dataType, newTypePath)), newTypePath) } else { - constructorFor( + val constructor = constructorFor( fieldType, Some(addToPath(fieldName, dataType, newTypePath)), newTypePath) + + if (!nullable) { + AssertNotNull(constructor, t.toString, fieldName, fieldType.toString) + } else { + constructor + } } } http://git-wip-us.apache.org/repos/asf/spark/blob/42bfde29/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index 492cc9b..d40cd96 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -624,3 +624,43 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp """ } } + +/** + * Asserts that input values of a non-nullable child expression are not null. + * + * Note that there are cases where `child.nullable == true`, while we still need to add this + * assertion. Consider a nullable column `s` whose data type is a struct containing a non-nullable + * `Int` field named `i`. Expression `s.i` is nullable because `s` can be null. However, for all + * non-null `s`, `s.i` can't be null. 
+ */ +case class AssertNotNull( + child: Expression, parentType: String, fieldName: String, fieldType: String) + extends UnaryExpression { + + override def dataType: DataType = child.dataType + + override def nullable: Boolean = false + + override def eval(input: InternalRow): Any = + throw new UnsupportedOperationException("Only code-generated evaluation is supported.") + + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val childGen = child.gen(ctx) + + ev.isNull = "false" + ev.value = childGen.value + + s""" + ${childGen.code} + + if (${childGen.isNull}) { + throw new RuntimeException( + "Null value appeared in non-nullable field $parentType.$fieldName of type $fieldType. " + + "If the schema is inferred from a Scala tuple/case class, or a Java bean, " + + "please try to use scala.Option[_] or other nullable types " + + "(e.g. java.lang.Integer instead of int/scala.Int)." + ); + } + """ + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/42bfde29/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala index 815a03f..764ffdc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala @@ -36,12 +36,16 @@ class EncoderResolutionSuite extends PlanTest { val encoder = ExpressionEncoder[StringLongClass] val cls = classOf[StringLongClass] + { val attrs = Seq('a.string, 'b.int) val fromRowExpr: Expression = encoder.resolve(attrs, null).fromRowExpression val expected: Expression = NewInstance( cls, - toExternalString('a.string) :: 'b.int.cast(LongType) :: Nil, 
+ Seq( + toExternalString('a.string), + AssertNotNull('b.int.cast(LongType), cls.getName, "b", "Long") + ), false, ObjectType(cls)) compareExpressions(fromRowExpr, expected) @@ -52,7 +56,10 @@ class EncoderResolutionSuite extends PlanTest { val fromRowExpr = encoder.resolve(attrs, null).fromRowExpression val expected = NewInstance( cls, - toExternalString('a.int.cast(StringType)) :: 'b.long :: Nil, + Seq( + toExternalString('a.int.cast(StringType)), + AssertNotNull('b.long, cls.getName, "b", "Long") + ), false, ObjectType(cls)) compareExpressions(fromRowExpr, expected) @@ -69,7 +76,7 @@ class EncoderResolutionSuite extends PlanTest { val expected: Expression = NewInstance( cls, Seq( - 'a.int.cast(LongType), + AssertNotNull('a.int.cast(LongType), cls.getName, "a", "Long"), If( 'b.struct('a.int, 'b.long).isNull, Literal.create(null, ObjectType(innerCls)), @@ -78,7 +85,9 @@ class EncoderResolutionSuite extends PlanTest { Seq( toExternalString( GetStructField('b.struct('a.int, 'b.long), 0, Some("a")).cast(StringType)), - GetStructField('b.struct('a.int, 'b.long), 1, Some("b"))), + AssertNotNull( + GetStructField('b.struct('a.int, 'b.long), 1, Some("b")), + innerCls.getName, "b", "Long")), false, ObjectType(innerCls)) )), @@ -102,7 +111,9 @@ class EncoderResolutionSuite extends PlanTest { cls, Seq( toExternalString(GetStructField('a.struct('a.string, 'b.byte), 0, Some("a"))), - GetStructField('a.struct('a.string, 'b.byte), 1, Some("b")).cast(LongType)), + AssertNotNull( + GetStructField('a.struct('a.string, 'b.byte), 1, Some("b")).cast(LongType), + cls.getName, "b", "Long")), false, ObjectType(cls)), 'b.int.cast(LongType)), http://git-wip-us.apache.org/repos/asf/spark/blob/42bfde29/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index d201d65..a763a95 100644 
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql import scala.collection.JavaConverters._ +import org.apache.spark.Logging import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.function._ import org.apache.spark.rdd.RDD @@ -64,7 +65,7 @@ import org.apache.spark.util.Utils class Dataset[T] private[sql]( @transient override val sqlContext: SQLContext, @transient override val queryExecution: QueryExecution, - tEncoder: Encoder[T]) extends Queryable with Serializable { + tEncoder: Encoder[T]) extends Queryable with Serializable with Logging { /** * An unresolved version of the internal encoder for the type of this [[Dataset]]. This one is http://git-wip-us.apache.org/repos/asf/spark/blob/42bfde29/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java ---------------------------------------------------------------------- diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index 0dbaeb8..9f8db39 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -23,6 +23,8 @@ import java.sql.Date; import java.sql.Timestamp; import java.util.*; +import com.google.common.base.Objects; +import org.junit.rules.ExpectedException; import scala.Tuple2; import scala.Tuple3; import scala.Tuple4; @@ -39,7 +41,6 @@ import org.apache.spark.sql.expressions.Aggregator; import org.apache.spark.sql.test.TestSQLContext; import org.apache.spark.sql.catalyst.encoders.OuterScopes; import org.apache.spark.sql.catalyst.expressions.GenericRow; -import org.apache.spark.sql.types.DecimalType; import org.apache.spark.sql.types.StructType; import static org.apache.spark.sql.functions.*; @@ -741,4 +742,127 @@ public class 
JavaDatasetSuite implements Serializable { context.createDataset(Arrays.asList(obj), Encoders.bean(SimpleJavaBean2.class)); ds.collect(); } + + public class SmallBean implements Serializable { + private String a; + + private int b; + + public int getB() { + return b; + } + + public void setB(int b) { + this.b = b; + } + + public String getA() { + return a; + } + + public void setA(String a) { + this.a = a; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + SmallBean smallBean = (SmallBean) o; + return b == smallBean.b && com.google.common.base.Objects.equal(a, smallBean.a); + } + + @Override + public int hashCode() { + return Objects.hashCode(a, b); + } + } + + public class NestedSmallBean implements Serializable { + private SmallBean f; + + public SmallBean getF() { + return f; + } + + public void setF(SmallBean f) { + this.f = f; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + NestedSmallBean that = (NestedSmallBean) o; + return Objects.equal(f, that.f); + } + + @Override + public int hashCode() { + return Objects.hashCode(f); + } + } + + @Rule + public transient ExpectedException nullabilityCheck = ExpectedException.none(); + + @Test + public void testRuntimeNullabilityCheck() { + OuterScopes.addOuterScope(this); + + StructType schema = new StructType() + .add("f", new StructType() + .add("a", StringType, true) + .add("b", IntegerType, true), true); + + // Shouldn't throw runtime exception since it passes nullability check. 
+ { + Row row = new GenericRow(new Object[] { + new GenericRow(new Object[] { + "hello", 1 + }) + }); + + DataFrame df = context.createDataFrame(Collections.singletonList(row), schema); + Dataset<NestedSmallBean> ds = df.as(Encoders.bean(NestedSmallBean.class)); + + SmallBean smallBean = new SmallBean(); + smallBean.setA("hello"); + smallBean.setB(1); + + NestedSmallBean nestedSmallBean = new NestedSmallBean(); + nestedSmallBean.setF(smallBean); + + Assert.assertEquals(ds.collectAsList(), Collections.singletonList(nestedSmallBean)); + } + + // Shouldn't throw runtime exception when parent object (`SmallBean`) is null + { + Row row = new GenericRow(new Object[] { null }); + + DataFrame df = context.createDataFrame(Collections.singletonList(row), schema); + Dataset<NestedSmallBean> ds = df.as(Encoders.bean(NestedSmallBean.class)); + + NestedSmallBean nestedSmallBean = new NestedSmallBean(); + Assert.assertEquals(ds.collectAsList(), Collections.singletonList(nestedSmallBean)); + } + + nullabilityCheck.expect(RuntimeException.class); + nullabilityCheck.expectMessage( + "Null value appeared in non-nullable field " + + "test.org.apache.spark.sql.JavaDatasetSuite$SmallBean.b of type int."); + + { + Row row = new GenericRow(new Object[] { + new GenericRow(new Object[] { + "hello", null + }) + }); + + DataFrame df = context.createDataFrame(Collections.singletonList(row), schema); + Dataset<NestedSmallBean> ds = df.as(Encoders.bean(NestedSmallBean.class)); + + ds.collect(); + } + } } http://git-wip-us.apache.org/repos/asf/spark/blob/42bfde29/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index de012a9..3337996 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -24,6 +24,7 @@ import scala.language.postfixOps import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} class DatasetSuite extends QueryTest with SharedSQLContext { @@ -515,12 +516,44 @@ class DatasetSuite extends QueryTest with SharedSQLContext { } assert(e.getMessage.contains("cannot resolve 'c' given input columns a, b"), e.getMessage) } + + test("runtime nullability check") { + val schema = StructType(Seq( + StructField("f", StructType(Seq( + StructField("a", StringType, nullable = true), + StructField("b", IntegerType, nullable = false) + )), nullable = true) + )) + + def buildDataset(rows: Row*): Dataset[NestedStruct] = { + val rowRDD = sqlContext.sparkContext.parallelize(rows) + sqlContext.createDataFrame(rowRDD, schema).as[NestedStruct] + } + + checkAnswer( + buildDataset(Row(Row("hello", 1))), + NestedStruct(ClassData("hello", 1)) + ) + + // Shouldn't throw runtime exception when parent object (`ClassData`) is null + assert(buildDataset(Row(null)).collect() === Array(NestedStruct(null))) + + val message = intercept[RuntimeException] { + buildDataset(Row(Row("hello", null))).collect() + }.getMessage + + assert(message.contains( + "Null value appeared in non-nullable field org.apache.spark.sql.ClassData.b of type Int." + )) + } } case class ClassData(a: String, b: Int) case class ClassData2(c: String, d: Int) case class ClassNullableData(a: String, b: Integer) +case class NestedStruct(f: ClassData) + /** * A class used to test serialization using encoders. This class throws exceptions when using * Java serialization -- so the only way it can be "serialized" is through our encoders. --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org