Repository: spark Updated Branches: refs/heads/master 32be51fba -> ebfe3a1f2
[SPARK-15192][SQL] null check for SparkSession.createDataFrame ## What changes were proposed in this pull request? This PR adds a null check in `SparkSession.createDataFrame`, so that we can make sure the passed-in rows match the given schema. ## How was this patch tested? New tests in `DatasetSuite`. Author: Wenchen Fan <wenc...@databricks.com> Closes #13008 from cloud-fan/row-encoder. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/ebfe3a1f Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/ebfe3a1f Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/ebfe3a1f Branch: refs/heads/master Commit: ebfe3a1f2c77e6869c3c36ba67afb7fabe6a94f5 Parents: 32be51f Author: Wenchen Fan <wenc...@databricks.com> Authored: Wed May 18 18:06:38 2016 -0700 Committer: Yin Huai <yh...@databricks.com> Committed: Wed May 18 18:06:38 2016 -0700 ---------------------------------------------------------------------- .../scala/org/apache/spark/mllib/fpm/FPGrowth.scala | 2 +- .../apache/spark/sql/catalyst/ScalaReflection.scala | 4 ++-- .../spark/sql/catalyst/encoders/RowEncoder.scala | 10 +++------- .../sql/catalyst/expressions/BoundAttribute.scala | 2 +- .../sql/catalyst/expressions/objects/objects.scala | 4 +++- .../main/scala/org/apache/spark/sql/SparkSession.scala | 4 ++-- .../scala/org/apache/spark/sql/api/r/SQLUtils.scala | 5 ++++- .../test/scala/org/apache/spark/sql/DatasetSuite.scala | 13 +++++++++++-- .../scala/org/apache/spark/sql/test/SQLTestUtils.scala | 6 +----- 9 files changed, 28 insertions(+), 22 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/ebfe3a1f/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala 
b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala index 9166faa..28e4966 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala @@ -116,7 +116,7 @@ object FPGrowthModel extends Loader[FPGrowthModel[_]] { StructField("freq", LongType)) val schema = StructType(fields) val rowDataRDD = model.freqItemsets.map { x => - Row(x.items, x.freq) + Row(x.items.toSeq, x.freq) } sqlContext.createDataFrame(rowDataRDD, schema).write.parquet(Loader.dataPath(path)) } http://git-wip-us.apache.org/repos/asf/spark/blob/ebfe3a1f/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index cb9a62d..c0fa220 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -113,8 +113,8 @@ object ScalaReflection extends ScalaReflection { * Returns true if the value of this data type is same between internal and external. 
*/ def isNativeType(dt: DataType): Boolean = dt match { - case BooleanType | ByteType | ShortType | IntegerType | LongType | - FloatType | DoubleType | BinaryType => true + case NullType | BooleanType | ByteType | ShortType | IntegerType | LongType | + FloatType | DoubleType | BinaryType | CalendarIntervalType => true case _ => false } http://git-wip-us.apache.org/repos/asf/spark/blob/ebfe3a1f/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index a5f39aa..71b39c5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -70,8 +70,7 @@ object RowEncoder { private def serializerFor( inputObject: Expression, inputType: DataType): Expression = inputType match { - case NullType | BooleanType | ByteType | ShortType | IntegerType | LongType | - FloatType | DoubleType | BinaryType | CalendarIntervalType => inputObject + case dt if ScalaReflection.isNativeType(dt) => inputObject case p: PythonUserDefinedType => serializerFor(inputObject, p.sqlType) @@ -151,7 +150,7 @@ object RowEncoder { case StructType(fields) => val convertedFields = fields.zipWithIndex.map { case (f, i) => val fieldValue = serializerFor( - GetExternalRowField(inputObject, i, externalDataTypeForInput(f.dataType)), + GetExternalRowField(inputObject, i, f.name, externalDataTypeForInput(f.dataType)), f.dataType ) if (f.nullable) { @@ -193,7 +192,6 @@ object RowEncoder { private def externalDataTypeFor(dt: DataType): DataType = dt match { case _ if ScalaReflection.isNativeType(dt) => dt - case CalendarIntervalType => dt case TimestampType => ObjectType(classOf[java.sql.Timestamp]) case DateType 
=> ObjectType(classOf[java.sql.Date]) case _: DecimalType => ObjectType(classOf[java.math.BigDecimal]) @@ -202,7 +200,6 @@ object RowEncoder { case _: MapType => ObjectType(classOf[scala.collection.Map[_, _]]) case _: StructType => ObjectType(classOf[Row]) case udt: UserDefinedType[_] => ObjectType(udt.userClass) - case _: NullType => ObjectType(classOf[java.lang.Object]) } private def deserializerFor(schema: StructType): Expression = { @@ -222,8 +219,7 @@ object RowEncoder { } private def deserializerFor(input: Expression): Expression = input.dataType match { - case NullType | BooleanType | ByteType | ShortType | IntegerType | LongType | - FloatType | DoubleType | BinaryType | CalendarIntervalType => input + case dt if ScalaReflection.isNativeType(dt) => input case udt: UserDefinedType[_] => val annotation = udt.userClass.getAnnotation(classOf[SQLUserDefinedType]) http://git-wip-us.apache.org/repos/asf/spark/blob/ebfe3a1f/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 99f156a..a38f1ec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.types._ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) extends LeafExpression { - override def toString: String = s"input[$ordinal, ${dataType.simpleString}]" + override def toString: String = s"input[$ordinal, ${dataType.simpleString}, $nullable]" // Use special getter for primitive types (for UnsafeRow) override def eval(input: InternalRow): Any = { 
http://git-wip-us.apache.org/repos/asf/spark/blob/ebfe3a1f/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 7df6e06..fc38369 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -693,6 +693,7 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String]) case class GetExternalRowField( child: Expression, index: Int, + fieldName: String, dataType: DataType) extends UnaryExpression with NonSQLExpression { override def nullable: Boolean = false @@ -716,7 +717,8 @@ case class GetExternalRowField( } if (${row.value}.isNullAt($index)) { - throw new RuntimeException("The ${index}th field of input row cannot be null."); + throw new RuntimeException("The ${index}th field '$fieldName' of input row " + + "cannot be null."); } final ${ctx.javaType(dataType)} ${ev.value} = $getField; http://git-wip-us.apache.org/repos/asf/spark/blob/ebfe3a1f/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index da575c7..629243b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -478,8 +478,8 @@ class SparkSession private( // TODO: use MutableProjection when rowRDD is another DataFrame and the applied // schema differs from the existing schema on any field data type. 
val catalystRows = if (needsConversion) { - val converter = CatalystTypeConverters.createToCatalystConverter(schema) - rowRDD.map(converter(_).asInstanceOf[InternalRow]) + val encoder = RowEncoder(schema) + rowRDD.map(encoder.toRow) } else { rowRDD.map{r: Row => InternalRow.fromSeq(r.toSeq)} } http://git-wip-us.apache.org/repos/asf/spark/blob/ebfe3a1f/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala index ffb606f..486a440 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.api.r import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} +import scala.collection.JavaConverters._ import scala.util.matching.Regex import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} @@ -108,6 +109,8 @@ private[sql] object SQLUtils { data match { case d: java.lang.Double if dataType == FloatType => new java.lang.Float(d) + // Scala Map is the only allowed external type of map type in Row. 
+ case m: java.util.Map[_, _] => m.asScala case _ => data } } @@ -118,7 +121,7 @@ private[sql] object SQLUtils { val num = SerDe.readInt(dis) Row.fromSeq((0 until num).map { i => doConversion(SerDe.readObject(dis), schema.fields(i).dataType) - }.toSeq) + }) } private[sql] def rowToRBytes(row: Row): Array[Byte] = { http://git-wip-us.apache.org/repos/asf/spark/blob/ebfe3a1f/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index b02b714..1935e41 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -507,7 +507,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { val schema = StructType(Seq( StructField("f", StructType(Seq( StructField("a", StringType, nullable = true), - StructField("b", IntegerType, nullable = false) + StructField("b", IntegerType, nullable = true) )), nullable = true) )) @@ -684,7 +684,16 @@ class DatasetSuite extends QueryTest with SharedSQLContext { val message = intercept[Exception] { df.collect() }.getMessage - assert(message.contains("The 0th field of input row cannot be null")) + assert(message.contains("The 0th field 'i' of input row cannot be null")) + } + + test("row nullability mismatch") { + val schema = new StructType().add("a", StringType, true).add("b", StringType, false) + val rdd = sqlContext.sparkContext.parallelize(Row(null, "123") :: Row("234", null) :: Nil) + val message = intercept[Exception] { + sqlContext.createDataFrame(rdd, schema).collect() + }.getMessage + assert(message.contains("The 1th field 'b' of input row cannot be null")) } test("createTempView") { http://git-wip-us.apache.org/repos/asf/spark/blob/ebfe3a1f/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala 
---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index 45a9c9d..51538ec 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -217,11 +217,7 @@ private[sql] trait SQLTestUtils case FilterExec(_, child) => child } - val childRDD = withoutFilters - .execute() - .map(row => Row.fromSeq(row.copy().toSeq(schema))) - - spark.createDataFrame(childRDD, schema) + spark.internalCreateDataFrame(withoutFilters.execute(), schema) } /** --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org