Repository: spark
Updated Branches:
  refs/heads/master 441cdcca6 -> e4899a253


[SPARK-2254] [SQL] ScalaReflection should mark primitive types as non-nullable.

Author: Takuya UESHIN <ues...@happy-camper.st>

Closes #1193 from ueshin/issues/SPARK-2254 and squashes the following commits:

cfd6088 [Takuya UESHIN] Modify ScalaReflection.schemaFor method to return nullability of Scala Type.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/e4899a25
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/e4899a25
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/e4899a25

Branch: refs/heads/master
Commit: e4899a253728bfa7c78709a37a4837f74b72bd61
Parents: 441cdcc
Author: Takuya UESHIN <ues...@happy-camper.st>
Authored: Wed Jun 25 23:55:31 2014 -0700
Committer: Reynold Xin <r...@apache.org>
Committed: Wed Jun 25 23:55:31 2014 -0700

----------------------------------------------------------------------
 .../spark/sql/catalyst/ScalaReflection.scala    |  65 ++++-----
 .../sql/catalyst/ScalaReflectionSuite.scala     | 131 +++++++++++++++++++
 2 files changed, 165 insertions(+), 31 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/e4899a25/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 196695a..ada48ea 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -30,53 +30,56 @@ import org.apache.spark.sql.catalyst.types._
 object ScalaReflection {
   import scala.reflect.runtime.universe._
 
+  case class Schema(dataType: DataType, nullable: Boolean)
+
   /** Returns a Sequence of attributes for the given case class type. */
   def attributesFor[T: TypeTag]: Seq[Attribute] = schemaFor[T] match {
-    case s: StructType =>
-      s.fields.map(f => AttributeReference(f.name, f.dataType, nullable = true)())
+    case Schema(s: StructType, _) =>
+      s.fields.map(f => AttributeReference(f.name, f.dataType, f.nullable)())
   }
 
-  /** Returns a catalyst DataType for the given Scala Type using reflection. */
-  def schemaFor[T: TypeTag]: DataType = schemaFor(typeOf[T])
+  /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */
+  def schemaFor[T: TypeTag]: Schema = schemaFor(typeOf[T])
 
-  /** Returns a catalyst DataType for the given Scala Type using reflection. */
-  def schemaFor(tpe: `Type`): DataType = tpe match {
+  /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */
+  def schemaFor(tpe: `Type`): Schema = tpe match {
     case t if t <:< typeOf[Option[_]] =>
       val TypeRef(_, _, Seq(optType)) = t
-      schemaFor(optType)
+      Schema(schemaFor(optType).dataType, nullable = true)
     case t if t <:< typeOf[Product] =>
       val params = t.member("<init>": TermName).asMethod.paramss
-      StructType(
-        params.head.map(p =>
-          StructField(p.name.toString, schemaFor(p.typeSignature), nullable = true)))
+      Schema(StructType(
+        params.head.map { p =>
+          val Schema(dataType, nullable) = schemaFor(p.typeSignature)
+          StructField(p.name.toString, dataType, nullable)
+        }), nullable = true)
     // Need to decide if we actually need a special type here.
-    case t if t <:< typeOf[Array[Byte]] => BinaryType
+    case t if t <:< typeOf[Array[Byte]] => Schema(BinaryType, nullable = true)
     case t if t <:< typeOf[Array[_]] =>
       sys.error(s"Only Array[Byte] supported now, use Seq instead of $t")
     case t if t <:< typeOf[Seq[_]] =>
       val TypeRef(_, _, Seq(elementType)) = t
-      ArrayType(schemaFor(elementType))
+      Schema(ArrayType(schemaFor(elementType).dataType), nullable = true)
     case t if t <:< typeOf[Map[_,_]] =>
       val TypeRef(_, _, Seq(keyType, valueType)) = t
-      MapType(schemaFor(keyType), schemaFor(valueType))
-    case t if t <:< typeOf[String] => StringType
-    case t if t <:< typeOf[Timestamp] => TimestampType
-    case t if t <:< typeOf[BigDecimal] => DecimalType
-    case t if t <:< typeOf[java.lang.Integer] => IntegerType
-    case t if t <:< typeOf[java.lang.Long] => LongType
-    case t if t <:< typeOf[java.lang.Double] => DoubleType
-    case t if t <:< typeOf[java.lang.Float] => FloatType
-    case t if t <:< typeOf[java.lang.Short] => ShortType
-    case t if t <:< typeOf[java.lang.Byte] => ByteType
-    case t if t <:< typeOf[java.lang.Boolean] => BooleanType
-    // TODO: The following datatypes could be marked as non-nullable.
-    case t if t <:< definitions.IntTpe => IntegerType
-    case t if t <:< definitions.LongTpe => LongType
-    case t if t <:< definitions.DoubleTpe => DoubleType
-    case t if t <:< definitions.FloatTpe => FloatType
-    case t if t <:< definitions.ShortTpe => ShortType
-    case t if t <:< definitions.ByteTpe => ByteType
-    case t if t <:< definitions.BooleanTpe => BooleanType
+      Schema(MapType(schemaFor(keyType).dataType, schemaFor(valueType).dataType), nullable = true)
+    case t if t <:< typeOf[String] => Schema(StringType, nullable = true)
+    case t if t <:< typeOf[Timestamp] => Schema(TimestampType, nullable = true)
+    case t if t <:< typeOf[BigDecimal] => Schema(DecimalType, nullable = true)
+    case t if t <:< typeOf[java.lang.Integer] => Schema(IntegerType, nullable = true)
+    case t if t <:< typeOf[java.lang.Long] => Schema(LongType, nullable = true)
+    case t if t <:< typeOf[java.lang.Double] => Schema(DoubleType, nullable = true)
+    case t if t <:< typeOf[java.lang.Float] => Schema(FloatType, nullable = true)
+    case t if t <:< typeOf[java.lang.Short] => Schema(ShortType, nullable = true)
+    case t if t <:< typeOf[java.lang.Byte] => Schema(ByteType, nullable = true)
+    case t if t <:< typeOf[java.lang.Boolean] => Schema(BooleanType, nullable = true)
+    case t if t <:< definitions.IntTpe => Schema(IntegerType, nullable = false)
+    case t if t <:< definitions.LongTpe => Schema(LongType, nullable = false)
+    case t if t <:< definitions.DoubleTpe => Schema(DoubleType, nullable = false)
+    case t if t <:< definitions.FloatTpe => Schema(FloatType, nullable = false)
+    case t if t <:< definitions.ShortTpe => Schema(ShortType, nullable = false)
+    case t if t <:< definitions.ByteTpe => Schema(ByteType, nullable = false)
+    case t if t <:< definitions.BooleanTpe => Schema(BooleanType, nullable = false)
   }
 
   implicit class CaseClassRelation[A <: Product : TypeTag](data: Seq[A]) {
 
   implicit class CaseClassRelation[A <: Product : TypeTag](data: Seq[A]) {

http://git-wip-us.apache.org/repos/asf/spark/blob/e4899a25/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
new file mode 100644
index 0000000..489d7e9
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
@@ -0,0 +1,131 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst
+
+import java.sql.Timestamp
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.types._
+
+case class PrimitiveData(
+    intField: Int,
+    longField: Long,
+    doubleField: Double,
+    floatField: Float,
+    shortField: Short,
+    byteField: Byte,
+    booleanField: Boolean)
+
+case class NullableData(
+    intField: java.lang.Integer,
+    longField: java.lang.Long,
+    doubleField: java.lang.Double,
+    floatField: java.lang.Float,
+    shortField: java.lang.Short,
+    byteField: java.lang.Byte,
+    booleanField: java.lang.Boolean,
+    stringField: String,
+    decimalField: BigDecimal,
+    timestampField: Timestamp,
+    binaryField: Array[Byte])
+
+case class OptionalData(
+    intField: Option[Int],
+    longField: Option[Long],
+    doubleField: Option[Double],
+    floatField: Option[Float],
+    shortField: Option[Short],
+    byteField: Option[Byte],
+    booleanField: Option[Boolean])
+
+case class ComplexData(
+    arrayField: Seq[Int],
+    mapField: Map[Int, String],
+    structField: PrimitiveData)
+
+class ScalaReflectionSuite extends FunSuite {
+  import ScalaReflection._
+
+  test("primitive data") {
+    val schema = schemaFor[PrimitiveData]
+    assert(schema === Schema(
+      StructType(Seq(
+        StructField("intField", IntegerType, nullable = false),
+        StructField("longField", LongType, nullable = false),
+        StructField("doubleField", DoubleType, nullable = false),
+        StructField("floatField", FloatType, nullable = false),
+        StructField("shortField", ShortType, nullable = false),
+        StructField("byteField", ByteType, nullable = false),
+        StructField("booleanField", BooleanType, nullable = false))),
+      nullable = true))
+  }
+
+  test("nullable data") {
+    val schema = schemaFor[NullableData]
+    assert(schema === Schema(
+      StructType(Seq(
+        StructField("intField", IntegerType, nullable = true),
+        StructField("longField", LongType, nullable = true),
+        StructField("doubleField", DoubleType, nullable = true),
+        StructField("floatField", FloatType, nullable = true),
+        StructField("shortField", ShortType, nullable = true),
+        StructField("byteField", ByteType, nullable = true),
+        StructField("booleanField", BooleanType, nullable = true),
+        StructField("stringField", StringType, nullable = true),
+        StructField("decimalField", DecimalType, nullable = true),
+        StructField("timestampField", TimestampType, nullable = true),
+        StructField("binaryField", BinaryType, nullable = true))),
+      nullable = true))
+  }
+
+  test("optinal data") {
+    val schema = schemaFor[OptionalData]
+    assert(schema === Schema(
+      StructType(Seq(
+        StructField("intField", IntegerType, nullable = true),
+        StructField("longField", LongType, nullable = true),
+        StructField("doubleField", DoubleType, nullable = true),
+        StructField("floatField", FloatType, nullable = true),
+        StructField("shortField", ShortType, nullable = true),
+        StructField("byteField", ByteType, nullable = true),
+        StructField("booleanField", BooleanType, nullable = true))),
+      nullable = true))
+  }
+
+  test("complex data") {
+    val schema = schemaFor[ComplexData]
+    assert(schema === Schema(
+      StructType(Seq(
+        StructField("arrayField", ArrayType(IntegerType), nullable = true),
+        StructField("mapField", MapType(IntegerType, StringType), nullable = true),
+        StructField(
+          "structField",
+          StructType(Seq(
+            StructField("intField", IntegerType, nullable = false),
+            StructField("longField", LongType, nullable = false),
+            StructField("doubleField", DoubleType, nullable = false),
+            StructField("floatField", FloatType, nullable = false),
+            StructField("shortField", ShortType, nullable = false),
+            StructField("byteField", ByteType, nullable = false),
+            StructField("booleanField", BooleanType, nullable = false))),
+          nullable = true))),
+      nullable = true))
+  }
+}

Reply via email to