Repository: spark
Updated Branches:
  refs/heads/branch-2.0 36acf8856 -> f5784459e


[SPARK-15192][SQL] null check for SparkSession.createDataFrame

## What changes were proposed in this pull request?

This PR adds a null check in `SparkSession.createDataFrame`, so that we can make sure the passed-in rows match the given schema.
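
For illustration, here is a minimal sketch of the new behavior; it mirrors the test added to `DatasetSuite` in this patch (standard imports assumed):

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._

val schema = new StructType().add("a", StringType, true).add("b", StringType, false)
val rdd = sqlContext.sparkContext.parallelize(Row(null, "123") :: Row("234", null) :: Nil)

// The second row puts null into the non-nullable field "b", so collecting now
// fails with a descriptive error instead of silently accepting a row that
// violates the schema:
//   java.lang.RuntimeException: The 1th field 'b' of input row cannot be null.
sqlContext.createDataFrame(rdd, schema).collect()
```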

## How was this patch tested?

new tests in `DatasetSuite`

Author: Wenchen Fan <wenc...@databricks.com>

Closes #13008 from cloud-fan/row-encoder.

(cherry picked from commit ebfe3a1f2c77e6869c3c36ba67afb7fabe6a94f5)
Signed-off-by: Yin Huai <yh...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f5784459
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f5784459
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f5784459

Branch: refs/heads/branch-2.0
Commit: f5784459e84e84d6641725bf0d0fb31db742456c
Parents: 36acf88
Author: Wenchen Fan <wenc...@databricks.com>
Authored: Wed May 18 18:06:38 2016 -0700
Committer: Yin Huai <yh...@databricks.com>
Committed: Wed May 18 18:06:52 2016 -0700

----------------------------------------------------------------------
 .../scala/org/apache/spark/mllib/fpm/FPGrowth.scala    |  2 +-
 .../apache/spark/sql/catalyst/ScalaReflection.scala    |  4 ++--
 .../spark/sql/catalyst/encoders/RowEncoder.scala       | 10 +++-------
 .../sql/catalyst/expressions/BoundAttribute.scala      |  2 +-
 .../sql/catalyst/expressions/objects/objects.scala     |  4 +++-
 .../main/scala/org/apache/spark/sql/SparkSession.scala |  4 ++--
 .../scala/org/apache/spark/sql/api/r/SQLUtils.scala    |  5 ++++-
 .../test/scala/org/apache/spark/sql/DatasetSuite.scala | 13 +++++++++++--
 .../scala/org/apache/spark/sql/test/SQLTestUtils.scala |  6 +-----
 9 files changed, 28 insertions(+), 22 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/f5784459/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
index 9166faa..28e4966 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
@@ -116,7 +116,7 @@ object FPGrowthModel extends Loader[FPGrowthModel[_]] {
         StructField("freq", LongType))
       val schema = StructType(fields)
       val rowDataRDD = model.freqItemsets.map { x =>
-        Row(x.items, x.freq)
+        Row(x.items.toSeq, x.freq)
       }
      sqlContext.createDataFrame(rowDataRDD, schema).write.parquet(Loader.dataPath(path))
     }
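
A hedged note on the change above: `freqItemsets` exposes `items` as an `Array`, but the encoder-based conversion introduced below expects `Seq` as the external representation of an array column, hence the explicit `toSeq`:

```scala
// Hypothetical rows for a schema whose first field is an ArrayType:
Row(Array("a", "b"), 3L)        // Array: no longer accepted by the stricter RowEncoder path
Row(Array("a", "b").toSeq, 3L)  // Seq: the expected external type
```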

http://git-wip-us.apache.org/repos/asf/spark/blob/f5784459/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index cb9a62d..c0fa220 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -113,8 +113,8 @@ object ScalaReflection extends ScalaReflection {
   * Returns true if the value of this data type is same between internal and external.
    */
   def isNativeType(dt: DataType): Boolean = dt match {
-    case BooleanType | ByteType | ShortType | IntegerType | LongType |
-         FloatType | DoubleType | BinaryType => true
+    case NullType | BooleanType | ByteType | ShortType | IntegerType | LongType |
+         FloatType | DoubleType | BinaryType | CalendarIntervalType => true
     case _ => false
   }
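
With this change, `NullType` and `CalendarIntervalType` are classified as native, i.e. represented identically inside and outside Catalyst, which is what lets `RowEncoder` below collapse its hand-written type lists into this single predicate:

```scala
ScalaReflection.isNativeType(NullType)              // now true
ScalaReflection.isNativeType(CalendarIntervalType)  // now true
ScalaReflection.isNativeType(StringType)            // still false: UTF8String internally, String externally
```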
 

http://git-wip-us.apache.org/repos/asf/spark/blob/f5784459/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index a5f39aa..71b39c5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -70,8 +70,7 @@ object RowEncoder {
   private def serializerFor(
       inputObject: Expression,
       inputType: DataType): Expression = inputType match {
-    case NullType | BooleanType | ByteType | ShortType | IntegerType | LongType |
-         FloatType | DoubleType | BinaryType | CalendarIntervalType => inputObject
+    case dt if ScalaReflection.isNativeType(dt) => inputObject
 
     case p: PythonUserDefinedType => serializerFor(inputObject, p.sqlType)
 
@@ -151,7 +150,7 @@ object RowEncoder {
     case StructType(fields) =>
       val convertedFields = fields.zipWithIndex.map { case (f, i) =>
         val fieldValue = serializerFor(
-          GetExternalRowField(inputObject, i, externalDataTypeForInput(f.dataType)),
+          GetExternalRowField(inputObject, i, f.name, externalDataTypeForInput(f.dataType)),
           f.dataType
         )
         if (f.nullable) {
@@ -193,7 +192,6 @@ object RowEncoder {
 
   private def externalDataTypeFor(dt: DataType): DataType = dt match {
     case _ if ScalaReflection.isNativeType(dt) => dt
-    case CalendarIntervalType => dt
     case TimestampType => ObjectType(classOf[java.sql.Timestamp])
     case DateType => ObjectType(classOf[java.sql.Date])
     case _: DecimalType => ObjectType(classOf[java.math.BigDecimal])
@@ -202,7 +200,6 @@ object RowEncoder {
     case _: MapType => ObjectType(classOf[scala.collection.Map[_, _]])
     case _: StructType => ObjectType(classOf[Row])
     case udt: UserDefinedType[_] => ObjectType(udt.userClass)
-    case _: NullType => ObjectType(classOf[java.lang.Object])
   }
 
   private def deserializerFor(schema: StructType): Expression = {
@@ -222,8 +219,7 @@ object RowEncoder {
   }
 
  private def deserializerFor(input: Expression): Expression = input.dataType match {
-    case NullType | BooleanType | ByteType | ShortType | IntegerType | LongType |
-         FloatType | DoubleType | BinaryType | CalendarIntervalType => input
+    case dt if ScalaReflection.isNativeType(dt) => input
 
     case udt: UserDefinedType[_] =>
       val annotation = udt.userClass.getAnnotation(classOf[SQLUserDefinedType])
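
Behavior for native types is unchanged by this refactoring: both directions remain a pass-through. A minimal sketch, which is essentially what the new `createDataFrame` path below does per row:

```scala
// Imports as in this file; BooleanType is native, so no conversion happens.
val enc = RowEncoder(new StructType().add("b", BooleanType))
val internal = enc.toRow(Row(true))
internal.getBoolean(0)  // true
```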

http://git-wip-us.apache.org/repos/asf/spark/blob/f5784459/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
index 99f156a..a38f1ec 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
@@ -31,7 +31,7 @@ import org.apache.spark.sql.types._
 case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
   extends LeafExpression {
 
-  override def toString: String = s"input[$ordinal, ${dataType.simpleString}]"
+  override def toString: String = s"input[$ordinal, ${dataType.simpleString}, $nullable]"
 
   // Use special getter for primitive types (for UnsafeRow)
   override def eval(input: InternalRow): Any = {
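
The extra flag simply surfaces nullability in explain and debug output, e.g. for a hypothetical reference:

```scala
BoundReference(0, IntegerType, nullable = true).toString
// before: input[0, int]
// after:  input[0, int, true]
```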

http://git-wip-us.apache.org/repos/asf/spark/blob/f5784459/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 7df6e06..fc38369 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -693,6 +693,7 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String])
 case class GetExternalRowField(
     child: Expression,
     index: Int,
+    fieldName: String,
     dataType: DataType) extends UnaryExpression with NonSQLExpression {
 
   override def nullable: Boolean = false
@@ -716,7 +717,8 @@ case class GetExternalRowField(
       }
 
       if (${row.value}.isNullAt($index)) {
-        throw new RuntimeException("The ${index}th field of input row cannot be null.");
+        throw new RuntimeException("The ${index}th field '$fieldName' of input row " +
+          "cannot be null.");
       }
 
       final ${ctx.javaType(dataType)} ${ev.value} = $getField;
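
The new `fieldName` parameter exists only to make the generated error message self-describing. Sketch, using the field `i` from the `DatasetSuite` test below (`inputObject` stands for the expression producing the external row):

```scala
GetExternalRowField(inputObject, 0, "i", IntegerType)
// Message raised when that field is null:
//   before: The 0th field of input row cannot be null.
//   after:  The 0th field 'i' of input row cannot be null.
```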

http://git-wip-us.apache.org/repos/asf/spark/blob/f5784459/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index aa974f2..8ef6cd3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -478,8 +478,8 @@ class SparkSession private(
    // TODO: use MutableProjection when rowRDD is another DataFrame and the applied
     // schema differs from the existing schema on any field data type.
     val catalystRows = if (needsConversion) {
-      val converter = CatalystTypeConverters.createToCatalystConverter(schema)
-      rowRDD.map(converter(_).asInstanceOf[InternalRow])
+      val encoder = RowEncoder(schema)
+      rowRDD.map(encoder.toRow)
     } else {
       rowRDD.map{r: Row => InternalRow.fromSeq(r.toSeq)}
     }
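
This swap is where the null check lands in `createDataFrame`: the encoder's serializer both converts and validates each external `Row`. Hedged per-row sketch, for some external `row: Row` against `schema`:

```scala
val encoder = RowEncoder(schema)
// Converts an external Row to an InternalRow; for fields declared
// nullable = false, the generated code now throws if the value is null.
val internal: InternalRow = encoder.toRow(row)
```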

http://git-wip-us.apache.org/repos/asf/spark/blob/f5784459/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
index 36173a4..2021cdd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.api.r
 
 import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
 
+import scala.collection.JavaConverters._
 import scala.util.matching.Regex
 
 import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
@@ -110,6 +111,8 @@ private[sql] object SQLUtils {
     data match {
       case d: java.lang.Double if dataType == FloatType =>
         new java.lang.Float(d)
+      // Scala Map is the only allowed external type of map type in Row.
+      case m: java.util.Map[_, _] => m.asScala
       case _ => data
     }
   }
@@ -120,7 +123,7 @@ private[sql] object SQLUtils {
     val num = SerDe.readInt(dis)
     Row.fromSeq((0 until num).map { i =>
       doConversion(SerDe.readObject(dis), schema.fields(i).dataType)
-    }.toSeq)
+    })
   }
 
   private[sql] def rowToRBytes(row: Row): Array[Byte] = {
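
A hedged illustration of the new `doConversion` case above: R's SerDe can hand maps over as `java.util.Map`, but a Scala `Map` is the only external map representation the encoder path accepts, so they are converted eagerly:

```scala
import scala.collection.JavaConverters._

val jmap = new java.util.HashMap[String, Integer]()
jmap.put("a", 1)
jmap.asScala  // a scala.collection.Map, as RowEncoder expects for a MapType column
```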

http://git-wip-us.apache.org/repos/asf/spark/blob/f5784459/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index b02b714..1935e41 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -507,7 +507,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     val schema = StructType(Seq(
       StructField("f", StructType(Seq(
         StructField("a", StringType, nullable = true),
-        StructField("b", IntegerType, nullable = false)
+        StructField("b", IntegerType, nullable = true)
       )), nullable = true)
     ))
 
@@ -684,7 +684,16 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     val message = intercept[Exception] {
       df.collect()
     }.getMessage
-    assert(message.contains("The 0th field of input row cannot be null"))
+    assert(message.contains("The 0th field 'i' of input row cannot be null"))
+  }
+
+  test("row nullability mismatch") {
+    val schema = new StructType().add("a", StringType, true).add("b", StringType, false)
+    val rdd = sqlContext.sparkContext.parallelize(Row(null, "123") :: Row("234", null) :: Nil)
+    val message = intercept[Exception] {
+      sqlContext.createDataFrame(rdd, schema).collect()
+    }.getMessage
+    assert(message.contains("The 1th field 'b' of input row cannot be null"))
   }
 
   test("createTempView") {

http://git-wip-us.apache.org/repos/asf/spark/blob/f5784459/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
index 45a9c9d..51538ec 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
@@ -217,11 +217,7 @@ private[sql] trait SQLTestUtils
       case FilterExec(_, child) => child
     }
 
-    val childRDD = withoutFilters
-      .execute()
-      .map(row => Row.fromSeq(row.copy().toSeq(schema)))
-
-    spark.createDataFrame(childRDD, schema)
+    spark.internalCreateDataFrame(withoutFilters.execute(), schema)
   }
 
   /**

