Repository: spark Updated Branches: refs/heads/master 77361a433 -> 55cc1c991
[SPARK-14139][SQL] RowEncoder should preserve schema nullability ## What changes were proposed in this pull request? The problem is: In `RowEncoder`, we use `Invoke` to get the field of an external row, which loses the nullability information. This PR creates a `GetExternalRowField` expression, so that we can preserve the nullability info. TODO: simplify the null handling logic in `RowEncoder`, to remove so many if branches, in follow-up PR. ## How was this patch tested? new tests in `RowEncoderSuite` Note that this PR takes over https://github.com/apache/spark/pull/11980, with a little simplification, so all credits should go to koertkuipers Author: Wenchen Fan <wenc...@databricks.com> Author: Koert Kuipers <ko...@tresata.com> Closes #12364 from cloud-fan/nullable. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/55cc1c99 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/55cc1c99 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/55cc1c99 Branch: refs/heads/master Commit: 55cc1c991a9e39efb14177a948b09b7909e53e25 Parents: 77361a4 Author: Wenchen Fan <wenc...@databricks.com> Authored: Fri May 6 01:08:04 2016 +0800 Committer: Cheng Lian <l...@databricks.com> Committed: Fri May 6 01:08:04 2016 +0800 ---------------------------------------------------------------------- .../sql/catalyst/encoders/RowEncoder.scala | 36 ++++++++++------- .../sql/catalyst/expressions/objects.scala | 42 ++++++++++++++++++++ .../sql/catalyst/encoders/RowEncoderSuite.scala | 8 ++++ .../org/apache/spark/sql/DatasetSuite.scala | 18 ++++++++- 4 files changed, 88 insertions(+), 16 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/55cc1c99/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala ---------------------------------------------------------------------- diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index 44e135c..cfde3bf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -35,9 +35,8 @@ import org.apache.spark.unsafe.types.UTF8String object RowEncoder { def apply(schema: StructType): ExpressionEncoder[Row] = { val cls = classOf[Row] - val inputObject = BoundReference(0, ObjectType(cls), nullable = true) - // We use an If expression to wrap extractorsFor result of StructType - val serializer = serializerFor(inputObject, schema).asInstanceOf[If].falseValue + val inputObject = BoundReference(0, ObjectType(cls), nullable = false) + val serializer = serializerFor(inputObject, schema) val deserializer = deserializerFor(schema) new ExpressionEncoder[Row]( schema, @@ -130,21 +129,28 @@ object RowEncoder { case StructType(fields) => val convertedFields = fields.zipWithIndex.map { case (f, i) => - val method = if (f.dataType.isInstanceOf[StructType]) { - "getStruct" + val fieldValue = serializerFor( + GetExternalRowField(inputObject, i, externalDataTypeForInput(f.dataType)), + f.dataType + ) + if (f.nullable) { + If( + Invoke(inputObject, "isNullAt", BooleanType, Literal(i) :: Nil), + Literal.create(null, f.dataType), + fieldValue + ) } else { - "get" + fieldValue } - If( - Invoke(inputObject, "isNullAt", BooleanType, Literal(i) :: Nil), - Literal.create(null, f.dataType), - serializerFor( - Invoke(inputObject, method, externalDataTypeForInput(f.dataType), Literal(i) :: Nil), - f.dataType)) } - If(IsNull(inputObject), - Literal.create(null, inputType), - CreateStruct(convertedFields)) + + if (inputObject.nullable) { + If(IsNull(inputObject), + Literal.create(null, inputType), + CreateStruct(convertedFields)) + } else { + CreateStruct(convertedFields) + } } /** 
http://git-wip-us.apache.org/repos/asf/spark/blob/55cc1c99/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index 523eed8..dbaff16 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -688,3 +688,45 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String]) ev.copy(code = code, isNull = "false", value = childGen.value) } } + +/** + * Returns the value of field at index `index` from the external row `child`. + * This class can be viewed as [[GetStructField]] for [[Row]]s instead of [[InternalRow]]s. + * + * Note that the input row and the field we try to get are both guaranteed to be not null, if they + * are null, a runtime exception will be thrown. 
+ */ +case class GetExternalRowField( + child: Expression, + index: Int, + dataType: DataType) extends UnaryExpression with NonSQLExpression { + + override def nullable: Boolean = false + + override def eval(input: InternalRow): Any = + throw new UnsupportedOperationException("Only code-generated evaluation is supported") + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val row = child.genCode(ctx) + + val getField = dataType match { + case ObjectType(x) if x == classOf[Row] => s"""${row.value}.getStruct($index)""" + case _ => s"""(${ctx.boxedType(dataType)}) ${row.value}.get($index)""" + } + + val code = s""" + ${row.code} + + if (${row.isNull}) { + throw new RuntimeException("The input external row cannot be null."); + } + + if (${row.value}.isNullAt($index)) { + throw new RuntimeException("The ${index}th field of input row cannot be null."); + } + + final ${ctx.javaType(dataType)} ${ev.value} = $getField; + """ + ev.copy(code = code, isNull = "false") + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/55cc1c99/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala index a8fa372..98be3b0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala @@ -160,6 +160,14 @@ class RowEncoderSuite extends SparkFunSuite { .compareTo(convertedBack.getDecimal(3)) == 0) } + test("RowEncoder should preserve schema nullability") { + val schema = new StructType().add("int", IntegerType, nullable = false) + val encoder = RowEncoder(schema) + assert(encoder.serializer.length == 1) + 
assert(encoder.serializer.head.dataType == IntegerType) + assert(encoder.serializer.head.nullable == false) + } + private def encodeDecodeTest(schema: StructType): Unit = { test(s"encode/decode: ${schema.simpleString}") { val encoder = RowEncoder(schema) http://git-wip-us.apache.org/repos/asf/spark/blob/55cc1c99/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 68a12b0..3cb4e52 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -22,7 +22,7 @@ import java.sql.{Date, Timestamp} import scala.language.postfixOps -import org.apache.spark.sql.catalyst.encoders.OuterScopes +import org.apache.spark.sql.catalyst.encoders.{OuterScopes, RowEncoder} import org.apache.spark.sql.execution.streaming.MemoryStream import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext @@ -658,6 +658,22 @@ class DatasetSuite extends QueryTest with SharedSQLContext { val dataset = Seq(1, 2, 3).toDS() checkDataset(DatasetTransform.addOne(dataset), 2, 3, 4) } + + test("runtime null check for RowEncoder") { + val schema = new StructType().add("i", IntegerType, nullable = false) + val df = sqlContext.range(10).map(l => { + if (l % 5 == 0) { + Row(null) + } else { + Row(l) + } + })(RowEncoder(schema)) + + val message = intercept[Exception] { + df.collect() + }.getMessage + assert(message.contains("The 0th field of input row cannot be null")) + } } case class OtherTuple(_1: String, _2: Int) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org