Repository: spark
Updated Branches:
  refs/heads/branch-2.0 1426235bf -> c0bcecf91


[SPARK-15351][SQL] RowEncoder should support array as the external type for ArrayType

## What changes were proposed in this pull request?

This PR improves `RowEncoder` and `MapObjects` to support array as the 
external type for `ArrayType`. The idea is straightforward: we use `Object` 
as the external input type for `ArrayType`, and determine the actual 
collection type at runtime in `MapObjects`.
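
As a minimal sketch of the new behavior (the schema and column name below are 
illustrative, not taken from the patch), either a `Seq` or an `Array` is now 
accepted on the external side, and decoding still yields a `Seq`, per the 
`Row` type mapping:

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types._

val schema = new StructType().add("values", ArrayType(IntegerType))
val encoder = RowEncoder(schema)

// Both external representations are accepted; the collection type is
// resolved at runtime in the code generated by MapObjects.
val fromSeq = encoder.toRow(Row(Seq(1, 2, 3)))
val fromArray = encoder.toRow(Row(Array(1, 2, 3)))

// Decoding always produces a Seq, matching the documented ArrayType mapping.
assert(encoder.fromRow(fromArray).getSeq[Int](0) == Seq(1, 2, 3))
```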

## How was this patch tested?

A new test in `RowEncoderSuite`.

Author: Wenchen Fan <wenc...@databricks.com>

Closes #13138 from cloud-fan/map-object.

(cherry picked from commit c36ca651f9177f8e7a3f6a0098cba5a810ee9deb)
Signed-off-by: Wenchen Fan <wenc...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/c0bcecf9
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/c0bcecf9
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/c0bcecf9

Branch: refs/heads/branch-2.0
Commit: c0bcecf914a0e0f6669a62a50e6198af38d4aac6
Parents: 1426235
Author: Wenchen Fan <wenc...@databricks.com>
Authored: Tue May 17 17:02:52 2016 +0800
Committer: Wenchen Fan <wenc...@databricks.com>
Committed: Tue May 17 17:03:15 2016 +0800

----------------------------------------------------------------------
 .../main/scala/org/apache/spark/sql/Row.scala   |  4 +-
 .../sql/catalyst/encoders/RowEncoder.scala      | 22 +++++
 .../catalyst/expressions/objects/objects.scala  | 99 +++++++++-----------
 .../sql/catalyst/util/GenericArrayData.scala    |  5 +
 .../sql/catalyst/encoders/RowEncoderSuite.scala | 17 ++++
 5 files changed, 92 insertions(+), 55 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/c0bcecf9/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
index 726291b..a257b83 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
@@ -151,7 +151,7 @@ trait Row extends Serializable {
    *   BinaryType -> byte array
    *   ArrayType -> scala.collection.Seq (use getList for java.util.List)
    *   MapType -> scala.collection.Map (use getJavaMap for java.util.Map)
-   *   StructType -> org.apache.spark.sql.Row (or Product)
+   *   StructType -> org.apache.spark.sql.Row
    * }}}
    */
   def apply(i: Int): Any = get(i)
@@ -176,7 +176,7 @@ trait Row extends Serializable {
    *   BinaryType -> byte array
    *   ArrayType -> scala.collection.Seq (use getList for java.util.List)
    *   MapType -> scala.collection.Map (use getJavaMap for java.util.Map)
-   *   StructType -> org.apache.spark.sql.Row (or Product)
+   *   StructType -> org.apache.spark.sql.Row
    * }}}
    */
   def get(i: Int): Any

http://git-wip-us.apache.org/repos/asf/spark/blob/c0bcecf9/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index ae842a9..a5f39aa 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -32,6 +32,26 @@ import org.apache.spark.unsafe.types.UTF8String
 /**
  * A factory for constructing encoders that convert external row to/from the Spark SQL
  * internal binary representation.
+ *
+ * The following is a mapping between Spark SQL types and its allowed external types:
+ * {{{
+ *   BooleanType -> java.lang.Boolean
+ *   ByteType -> java.lang.Byte
+ *   ShortType -> java.lang.Short
+ *   IntegerType -> java.lang.Integer
+ *   FloatType -> java.lang.Float
+ *   DoubleType -> java.lang.Double
+ *   StringType -> String
+ *   DecimalType -> java.math.BigDecimal or scala.math.BigDecimal or Decimal
+ *
+ *   DateType -> java.sql.Date
+ *   TimestampType -> java.sql.Timestamp
+ *
+ *   BinaryType -> byte array
+ *   ArrayType -> scala.collection.Seq or Array
+ *   MapType -> scala.collection.Map
+ *   StructType -> org.apache.spark.sql.Row or Product
+ * }}}
  */
 object RowEncoder {
   def apply(schema: StructType): ExpressionEncoder[Row] = {
@@ -166,6 +186,8 @@ object RowEncoder {
     // In order to support both Decimal and java/scala BigDecimal in external row, we make this
     // as java.lang.Object.
     case _: DecimalType => ObjectType(classOf[java.lang.Object])
+    // In order to support both Array and Seq in external row, we make this as java.lang.Object.
+    case _: ArrayType => ObjectType(classOf[java.lang.Object])
     case _ => externalDataTypeFor(dt)
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/c0bcecf9/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index e8a6c74..7df6e06 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -376,45 +376,6 @@ case class MapObjects private(
     lambdaFunction: Expression,
     inputData: Expression) extends Expression with NonSQLExpression {
 
-  @tailrec
-  private def itemAccessorMethod(dataType: DataType): String => String = dataType match {
-    case NullType =>
-      val nullTypeClassName = NullType.getClass.getName + ".MODULE$"
-      (i: String) => s".get($i, $nullTypeClassName)"
-    case IntegerType => (i: String) => s".getInt($i)"
-    case LongType => (i: String) => s".getLong($i)"
-    case FloatType => (i: String) => s".getFloat($i)"
-    case DoubleType => (i: String) => s".getDouble($i)"
-    case ByteType => (i: String) => s".getByte($i)"
-    case ShortType => (i: String) => s".getShort($i)"
-    case BooleanType => (i: String) => s".getBoolean($i)"
-    case StringType => (i: String) => s".getUTF8String($i)"
-    case s: StructType => (i: String) => s".getStruct($i, ${s.size})"
-    case a: ArrayType => (i: String) => s".getArray($i)"
-    case _: MapType => (i: String) => s".getMap($i)"
-    case udt: UserDefinedType[_] => itemAccessorMethod(udt.sqlType)
-    case DecimalType.Fixed(p, s) => (i: String) => s".getDecimal($i, $p, $s)"
-    case DateType => (i: String) => s".getInt($i)"
-  }
-
-  private lazy val (lengthFunction, itemAccessor, primitiveElement) = inputData.dataType match {
-    case ObjectType(cls) if classOf[Seq[_]].isAssignableFrom(cls) =>
-      (".size()", (i: String) => s".apply($i)", false)
-    case ObjectType(cls) if cls.isArray =>
-      (".length", (i: String) => s"[$i]", false)
-    case ObjectType(cls) if classOf[java.util.List[_]].isAssignableFrom(cls) =>
-      (".size()", (i: String) => s".get($i)", false)
-    case ArrayType(t, _) =>
-      val (sqlType, primitiveElement) = t match {
-        case m: MapType => (m, false)
-        case s: StructType => (s, false)
-        case s: StringType => (s, false)
-        case udt: UserDefinedType[_] => (udt.sqlType, false)
-        case o => (o, true)
-      }
-      (".numElements()", itemAccessorMethod(sqlType), primitiveElement)
-  }
-
   override def nullable: Boolean = true
 
   override def children: Seq[Expression] = lambdaFunction :: inputData :: Nil
@@ -425,7 +386,6 @@ case class MapObjects private(
   override def dataType: DataType = ArrayType(lambdaFunction.dataType)
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    val javaType = ctx.javaType(dataType)
     val elementJavaType = ctx.javaType(loopVar.dataType)
     ctx.addMutableState("boolean", loopVar.isNull, "")
     ctx.addMutableState(elementJavaType, loopVar.value, "")
@@ -448,27 +408,61 @@ case class MapObjects private(
       s"new $convertedType[$dataLength]"
     }
 
-    val loopNullCheck = if (primitiveElement) {
-      s"${loopVar.isNull} = ${genInputData.value}.isNullAt($loopIndex);"
-    } else {
-      s"${loopVar.isNull} = ${genInputData.isNull} || ${loopVar.value} == null;"
+    // In RowEncoder, we use `Object` to represent Array or Seq, so we need to determine the type
+    // of input collection at runtime for this case.
+    val seq = ctx.freshName("seq")
+    val array = ctx.freshName("array")
+    val determineCollectionType = inputData.dataType match {
+      case ObjectType(cls) if cls == classOf[Object] =>
+        val seqClass = classOf[Seq[_]].getName
+        s"""
+          $seqClass $seq = null;
+          $elementJavaType[] $array = null;
+          if (${genInputData.value}.getClass().isArray()) {
+            $array = ($elementJavaType[]) ${genInputData.value};
+          } else {
+            $seq = ($seqClass) ${genInputData.value};
+          }
+         """
+      case _ => ""
+    }
+
+
+    val (getLength, getLoopVar) = inputData.dataType match {
+      case ObjectType(cls) if classOf[Seq[_]].isAssignableFrom(cls) =>
+        s"${genInputData.value}.size()" -> s"${genInputData.value}.apply($loopIndex)"
+      case ObjectType(cls) if cls.isArray =>
+        s"${genInputData.value}.length" -> s"${genInputData.value}[$loopIndex]"
+      case ObjectType(cls) if classOf[java.util.List[_]].isAssignableFrom(cls) =>
+        s"${genInputData.value}.size()" -> s"${genInputData.value}.get($loopIndex)"
+      case ArrayType(et, _) =>
+        s"${genInputData.value}.numElements()" -> ctx.getValue(genInputData.value, et, loopIndex)
+      case ObjectType(cls) if cls == classOf[Object] =>
+        s"$seq == null ? $array.length : $seq.size()" ->
+          s"$seq == null ? $array[$loopIndex] : $seq.apply($loopIndex)"
+    }
+
+    val loopNullCheck = inputData.dataType match {
+      case _: ArrayType => s"${loopVar.isNull} = ${genInputData.value}.isNullAt($loopIndex);"
+      // The element of primitive array will never be null.
+      case ObjectType(cls) if cls.isArray && cls.getComponentType.isPrimitive =>
+        s"${loopVar.isNull} = false"
+      case _ => s"${loopVar.isNull} = ${loopVar.value} == null;"
     }
 
     val code = s"""
       ${genInputData.code}
+      ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
 
-      boolean ${ev.isNull} = ${genInputData.value} == null;
-      $javaType ${ev.value} = ${ctx.defaultValue(dataType)};
-
-      if (!${ev.isNull}) {
+      if (!${genInputData.isNull}) {
+        $determineCollectionType
         $convertedType[] $convertedArray = null;
-        int $dataLength = ${genInputData.value}$lengthFunction;
+        int $dataLength = $getLength;
         $convertedArray = $arrayConstructor;
 
         int $loopIndex = 0;
         while ($loopIndex < $dataLength) {
-          ${loopVar.value} =
-            ($elementJavaType)${genInputData.value}${itemAccessor(loopIndex)};
+          ${loopVar.value} = ($elementJavaType) ($getLoopVar);
           $loopNullCheck
 
           ${genFunction.code}
@@ -481,11 +475,10 @@ case class MapObjects private(
           $loopIndex += 1;
         }
 
-        ${ev.isNull} = false;
         ${ev.value} = new ${classOf[GenericArrayData].getName}($convertedArray);
       }
     """
-    ev.copy(code = code)
+    ev.copy(code = code, isNull = genInputData.isNull)
   }
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/c0bcecf9/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala
index 2b8cdc1..3a665d3 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala
@@ -37,6 +37,11 @@ class GenericArrayData(val array: Array[Any]) extends ArrayData {
   def this(primitiveArray: Array[Byte]) = this(primitiveArray.toSeq)
   def this(primitiveArray: Array[Boolean]) = this(primitiveArray.toSeq)
 
+  def this(seqOrArray: Any) = this(seqOrArray match {
+    case seq: Seq[Any] => seq
+    case array: Array[_] => array.toSeq
+  })
+
   override def copy(): ArrayData = new GenericArrayData(array.clone())
 
   override def numElements(): Int = array.length
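
For context, a hedged sketch of how this new catch-all constructor dispatches 
(the variable names here are illustrative):

```scala
import org.apache.spark.sql.catalyst.util.GenericArrayData

// The value is statically typed as Any, so the new `this(seqOrArray: Any)`
// overload is selected, which is the same situation as the Object-typed
// value produced by the code MapObjects generates.
val external: Any = Array(1, 2, 3)
val fromArray = new GenericArrayData(external)
assert(fromArray.numElements() == 3)

// A statically Seq-typed value resolves to the existing Seq[Any] overload.
val fromSeq = new GenericArrayData(Seq[Any](1, 2, 3))
assert(fromSeq.numElements() == 3)
```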

http://git-wip-us.apache.org/repos/asf/spark/blob/c0bcecf9/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
index 4800e2e..7bb006c 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
@@ -185,6 +185,23 @@ class RowEncoderSuite extends SparkFunSuite {
     assert(encoder.serializer.head.nullable == false)
   }
 
+  test("RowEncoder should support array as the external type for ArrayType") {
+    val schema = new StructType()
+      .add("array", ArrayType(IntegerType))
+      .add("nestedArray", ArrayType(ArrayType(StringType)))
+      .add("deepNestedArray", ArrayType(ArrayType(ArrayType(LongType))))
+    val encoder = RowEncoder(schema)
+    val input = Row(
+      Array(1, 2, null),
+      Array(Array("abc", null), null),
+      Array(Seq(Array(0L, null), null), null))
+    val row = encoder.toRow(input)
+    val convertedBack = encoder.fromRow(row)
+    assert(convertedBack.getSeq(0) == Seq(1, 2, null))
+    assert(convertedBack.getSeq(1) == Seq(Seq("abc", null), null))
+    assert(convertedBack.getSeq(2) == Seq(Seq(Seq(0L, null), null), null))
+  }
+
   private def encodeDecodeTest(schema: StructType): Unit = {
     test(s"encode/decode: ${schema.simpleString}") {
       val encoder = RowEncoder(schema)

