Repository: spark
Updated Branches:
  refs/heads/branch-2.0 e6e1d8232 -> 38a626a54
[SPARK-15657][SQL] RowEncoder should validate the data type of input object

## What changes were proposed in this pull request?

This PR improves the error handling of `RowEncoder`. When we create a `RowEncoder` with a given schema, we should validate the data type of the input object, e.g., we should throw an exception when a field is a boolean but is declared as a string column.

This PR also removes the support for using `Product` as a valid external type of struct type. This support was added in https://github.com/apache/spark/pull/9712, but it is incomplete: nested products and products inside arrays, for example, both fail. However, we never officially supported this feature, and I think it's OK to ban it.

## How was this patch tested?

New tests in `RowEncoderSuite`.

Author: Wenchen Fan <wenc...@databricks.com>

Closes #13401 from cloud-fan/bug.

(cherry picked from commit 30c4774f33fed63b7d400d220d710fb432f599a8)
Signed-off-by: Cheng Lian <l...@databricks.com>

Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/38a626a5
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/38a626a5
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/38a626a5

Branch: refs/heads/branch-2.0
Commit: 38a626a54dd0fac0ca460e1ba534048de513bc29
Parents: e6e1d82
Author: Wenchen Fan <wenc...@databricks.com>
Authored: Sun Jun 5 15:59:52 2016 -0700
Committer: Cheng Lian <l...@databricks.com>
Committed: Sun Jun 5 16:00:00 2016 -0700

----------------------------------------------------------------------
 .../main/scala/org/apache/spark/sql/Row.scala   | 10 +---
 .../sql/catalyst/encoders/RowEncoder.scala      | 17 ++++--
 .../catalyst/expressions/objects/objects.scala  | 61 +++++++++++++++++---
 .../sql/catalyst/encoders/RowEncoderSuite.scala | 47 ++++++++++-----
 4 files changed, 95 insertions(+), 40 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/38a626a5/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
index a257b83..391001d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
@@ -304,15 +304,7 @@ trait Row extends Serializable {
    *
    * @throws ClassCastException when data type does not match.
    */
-  def getStruct(i: Int): Row = {
-    // Product and Row both are recognized as StructType in a Row
-    val t = get(i)
-    if (t.isInstanceOf[Product]) {
-      Row.fromTuple(t.asInstanceOf[Product])
-    } else {
-      t.asInstanceOf[Row]
-    }
-  }
+  def getStruct(i: Int): Row = getAs[Row](i)
 
   /**
    * Returns the value at position i.
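The `getStruct` simplification above is the user-visible half of the `Product` ban: a nested struct value must now actually be a `Row`. A minimal sketch of the behavioral difference (variable names are illustrative):

```scala
import org.apache.spark.sql.Row

// getStruct is now a plain getAs[Row], so only Row values are accepted.
val outer = Row(Row(1, "a"))
val nested: Row = outer.getStruct(0)  // Row(1, "a"), works as before

// Previously a Product such as a tuple was silently converted via
// Row.fromTuple; after this patch the cast fails instead.
val outerWithTuple = Row((1, "a"))
// val bad: Row = outerWithTuple.getStruct(0)  // now throws ClassCastException
```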
http://git-wip-us.apache.org/repos/asf/spark/blob/38a626a5/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index 6cd7b34..67fca15 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -51,7 +51,7 @@ import org.apache.spark.unsafe.types.UTF8String
  *   BinaryType -> byte array
  *   ArrayType -> scala.collection.Seq or Array
  *   MapType -> scala.collection.Map
- *   StructType -> org.apache.spark.sql.Row or Product
+ *   StructType -> org.apache.spark.sql.Row
  * }}}
  */
 object RowEncoder {
@@ -121,11 +121,15 @@ object RowEncoder {
     case t @ ArrayType(et, _) => et match {
       case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType =>
+        // TODO: validate input type for primitive array.
         NewInstance(
           classOf[GenericArrayData],
           inputObject :: Nil,
           dataType = t)
-      case _ => MapObjects(serializerFor(_, et), inputObject, externalDataTypeForInput(et))
+      case _ => MapObjects(
+        element => serializerFor(ValidateExternalType(element, et), et),
+        inputObject,
+        ObjectType(classOf[Object]))
     }
 
     case t @ MapType(kt, vt, valueNullable) =>
@@ -151,8 +155,9 @@ object RowEncoder {
     case StructType(fields) =>
       val nonNullOutput = CreateNamedStruct(fields.zipWithIndex.flatMap { case (field, index) =>
         val fieldValue = serializerFor(
-          GetExternalRowField(
-            inputObject, index, field.name, externalDataTypeForInput(field.dataType)),
+          ValidateExternalType(
+            GetExternalRowField(inputObject, index, field.name),
+            field.dataType),
           field.dataType)
         val convertedField = if (field.nullable) {
           If(
@@ -183,7 +188,7 @@ object RowEncoder {
    * can be `scala.math.BigDecimal`, `java.math.BigDecimal`, or
    * `org.apache.spark.sql.types.Decimal`.
    */
-  private def externalDataTypeForInput(dt: DataType): DataType = dt match {
+  def externalDataTypeForInput(dt: DataType): DataType = dt match {
     // In order to support both Decimal and java/scala BigDecimal in external row, we make this
     // as java.lang.Object.
     case _: DecimalType => ObjectType(classOf[java.lang.Object])
@@ -192,7 +197,7 @@ object RowEncoder {
     case _ => externalDataTypeFor(dt)
   }
 
-  private def externalDataTypeFor(dt: DataType): DataType = dt match {
+  def externalDataTypeFor(dt: DataType): DataType = dt match {
     case _ if ScalaReflection.isNativeType(dt) => dt
     case TimestampType => ObjectType(classOf[java.sql.Timestamp])
     case DateType => ObjectType(classOf[java.sql.Date])
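Note that `externalDataTypeForInput` deliberately widens decimals to `java.lang.Object`, since several external classes are legal for a decimal column. A short sketch of what that permits, assuming Spark 2.0's catalyst internals and an illustrative field name:

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types._

val schema = new StructType().add("price", DecimalType(10, 2))
val encoder = RowEncoder(schema)

// DecimalType maps to java.lang.Object externally, so ValidateExternalType
// accepts any of the three decimal representations at runtime:
encoder.toRow(Row(new java.math.BigDecimal("1.23")))
encoder.toRow(Row(BigDecimal("1.23")))  // scala.math.BigDecimal
encoder.toRow(Row(Decimal("1.23")))     // org.apache.spark.sql.types.Decimal
```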
http://git-wip-us.apache.org/repos/asf/spark/blob/38a626a5/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index d4c71bf..87c8a2e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -26,6 +26,7 @@ import org.apache.spark.SparkConf
 import org.apache.spark.serializer._
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
 import org.apache.spark.sql.catalyst.util.GenericArrayData
@@ -692,22 +693,17 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String])
 case class GetExternalRowField(
     child: Expression,
     index: Int,
-    fieldName: String,
-    dataType: DataType) extends UnaryExpression with NonSQLExpression {
+    fieldName: String) extends UnaryExpression with NonSQLExpression {
 
   override def nullable: Boolean = false
 
+  override def dataType: DataType = ObjectType(classOf[Object])
+
   override def eval(input: InternalRow): Any =
     throw new UnsupportedOperationException("Only code-generated evaluation is supported")
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val row = child.genCode(ctx)
-
-    val getField = dataType match {
-      case ObjectType(x) if x == classOf[Row] => s"""${row.value}.getStruct($index)"""
-      case _ => s"""(${ctx.boxedType(dataType)}) ${row.value}.get($index)"""
-    }
-
     val code = s"""
       ${row.code}
 
@@ -720,8 +716,55 @@ case class GetExternalRowField(
           "cannot be null.");
       }
 
-      final ${ctx.javaType(dataType)} ${ev.value} = $getField;
+      final Object ${ev.value} = ${row.value}.get($index);
     """
     ev.copy(code = code, isNull = "false")
   }
 }
+
+/**
+ * Validates the actual data type of input expression at runtime.  If it doesn't match the
+ * expectation, throw an exception.
+ */
+case class ValidateExternalType(child: Expression, expected: DataType)
+  extends UnaryExpression with NonSQLExpression with ExpectsInputTypes {
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(ObjectType(classOf[Object]))
+
+  override def nullable: Boolean = child.nullable
+
+  override def dataType: DataType = RowEncoder.externalDataTypeForInput(expected)
+
+  override def eval(input: InternalRow): Any =
+    throw new UnsupportedOperationException("Only code-generated evaluation is supported")
+
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    val input = child.genCode(ctx)
+    val obj = input.value
+
+    val typeCheck = expected match {
+      case _: DecimalType =>
+        Seq(classOf[java.math.BigDecimal], classOf[scala.math.BigDecimal], classOf[Decimal])
+          .map(cls => s"$obj instanceof ${cls.getName}").mkString(" || ")
+      case _: ArrayType =>
+        s"$obj instanceof ${classOf[Seq[_]].getName} || $obj.getClass().isArray()"
+      case _ =>
+        s"$obj instanceof ${ctx.boxedType(dataType)}"
+    }
+
+    val code = s"""
+      ${input.code}
+      ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
+      if (!${input.isNull}) {
+        if ($typeCheck) {
+          ${ev.value} = (${ctx.boxedType(dataType)}) $obj;
+        } else {
+          throw new RuntimeException($obj.getClass().getName() + " is not a valid " +
+            "external type for schema of ${expected.simpleString}");
+        }
+      }
+    """
+    ev.copy(code = code, isNull = input.isNull)
+  }
+}


http://git-wip-us.apache.org/repos/asf/spark/blob/38a626a5/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
index 16abde0..2e513ea 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
@@ -127,22 +127,6 @@ class RowEncoderSuite extends SparkFunSuite {
         new StructType().add("array", arrayOfString).add("map", mapOfString))
       .add("structOfUDT", structOfUDT))
 
-  test(s"encode/decode: Product") {
-    val schema = new StructType()
-      .add("structAsProduct",
-        new StructType()
-          .add("int", IntegerType)
-          .add("string", StringType)
-          .add("double", DoubleType))
-
-    val encoder = RowEncoder(schema).resolveAndBind()
-
-    val input: Row = Row((100, "test", 0.123))
-    val row = encoder.toRow(input)
-    val convertedBack = encoder.fromRow(row)
-    assert(input.getStruct(0) == convertedBack.getStruct(0))
-  }
-
   test("encode/decode decimal type") {
     val schema = new StructType()
       .add("int", IntegerType)
@@ -232,6 +216,37 @@ class RowEncoderSuite extends SparkFunSuite {
     assert(e.getMessage.contains("top level row object"))
   }
 
+  test("RowEncoder should validate external type") {
+    val e1 = intercept[RuntimeException] {
+      val schema = new StructType().add("a", IntegerType)
+      val encoder = RowEncoder(schema)
+      encoder.toRow(Row(1.toShort))
+    }
+    assert(e1.getMessage.contains("java.lang.Short is not a valid external type"))
+
+    val e2 = intercept[RuntimeException] {
+      val schema = new StructType().add("a", StringType)
+      val encoder = RowEncoder(schema)
+      encoder.toRow(Row(1))
+    }
+    assert(e2.getMessage.contains("java.lang.Integer is not a valid external type"))
+
+    val e3 = intercept[RuntimeException] {
+      val schema = new StructType().add("a",
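The heart of the patch is the `typeCheck` string built in `ValidateExternalType.doGenCode`. For readability, here is a hypothetical Scala helper that mirrors the instanceof checks the generated Java performs; it is not part of the patch, and `externalClass` stands in for the boxed external class derived from `ctx.boxedType(dataType)`:

```scala
import org.apache.spark.sql.types._

// Mirrors the three branches of the generated instanceof check. Assumes obj is
// non-null; the generated code only runs the check when the input is non-null.
def isValidExternalType(obj: Any, expected: DataType, externalClass: Class[_]): Boolean =
  expected match {
    case _: DecimalType =>
      // Decimals may arrive as any of three classes.
      obj.isInstanceOf[java.math.BigDecimal] ||
        obj.isInstanceOf[scala.math.BigDecimal] ||
        obj.isInstanceOf[Decimal]
    case _: ArrayType =>
      // Array columns accept either a Scala Seq or a JVM array.
      obj.isInstanceOf[Seq[_]] || obj.getClass.isArray
    case _ =>
      // All other types: a plain check against the expected external class.
      externalClass.isInstance(obj)
  }
```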
+        new StructType().add("b", IntegerType).add("c", StringType))
+      val encoder = RowEncoder(schema)
+      encoder.toRow(Row(1 -> "a"))
+    }
+    assert(e3.getMessage.contains("scala.Tuple2 is not a valid external type"))
+
+    val e4 = intercept[RuntimeException] {
+      val schema = new StructType().add("a", ArrayType(TimestampType))
+      val encoder = RowEncoder(schema)
+      encoder.toRow(Row(Array("a")))
+    }
+    assert(e4.getMessage.contains("java.lang.String is not a valid external type"))
+  }
+
   private def encodeDecodeTest(schema: StructType): Unit = {
     test(s"encode/decode: ${schema.simpleString}") {
       val encoder = RowEncoder(schema).resolveAndBind()
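Taken together, the motivating case from the PR description now fails fast with a descriptive message instead of producing corrupt data or a late ClassCastException. A minimal sketch (field name illustrative):

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types._

// A boolean field declared as a string column, per the PR description:
val schema = new StructType().add("flag", StringType)
val encoder = RowEncoder(schema)

encoder.toRow(Row(true))
// => RuntimeException: java.lang.Boolean is not a valid external type
//    for schema of string
```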