Repository: spark
Updated Branches:
  refs/heads/master 402bf2a50 -> 295747e59


[SPARK-19716][SQL] support by-name resolution for struct type elements in array

## What changes were proposed in this pull request?

Previously, when constructing the deserializer expression for an array type, we 
would first cast the corresponding field to the expected array type and then 
apply `MapObjects`.

However, doing so loses the opportunity to perform by-name resolution for 
struct types inside the array. This PR introduces `UnresolvedMapObjects` to 
hold the lambda function and the input array expression. During analysis, once 
the input array expression is resolved, we obtain the actual array element type 
and apply by-name resolution. As a result, we no longer need to add a `Cast` 
for the array type when constructing the deserializer expression, since the 
element type is determined later by the analyzer.
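
As a rough end-to-end illustration (not part of this change; the session setup and the `ClassData` case class below are assumptions borrowed from `DatasetSuite`), by-name resolution now lets a `Dataset` deserialize array elements whose struct fields are stored in a different order than the case class declares:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{array, lit, struct}

// Hypothetical sketch: ClassData(a: String, b: Int) mirrors the case class
// already defined for DatasetSuite.
case class ClassData(a: String, b: Int)

object ByNameArrayResolutionSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("sketch").getOrCreate()
    import spark.implicits._

    // The stored struct lists its fields as (b, a), i.e. not in the order of
    // ClassData's constructor. With UnresolvedMapObjects the deserializer is
    // resolved against the actual element type during analysis, so the fields
    // are matched by name rather than by position.
    val df = spark.range(3).select(
      array(struct($"id".cast("int").as("b"), lit("a").as("a"))).as("arr"))

    val ds = df.as[Seq[ClassData]]
    // Expected: Seq(ClassData(a,0)), Seq(ClassData(a,1)), Seq(ClassData(a,2))
    ds.collect().foreach(println)

    spark.stop()
  }
}
```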

## How was this patch tested?

New regression tests in `EncoderResolutionSuite` and `DatasetSuite`.
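
Roughly, the new tests check that a struct element inside an array may carry extra fields and list them in a different order than the target class, and that incompatible shapes fail analysis with a clear error message. A condensed, hypothetical sketch of the same behaviour through the public Dataset API (the class names, column names, and session setup below are illustrative, not taken from the patch):

```scala
import org.apache.spark.sql.{AnalysisException, SparkSession}

case class Item(a: String, b: Int)
case class Wrapper(arr: Seq[Item])

object ArrayEncoderResolutionChecks {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("checks").getOrCreate()
    import spark.implicits._

    // Compatible: the element struct has an extra field 'c' and lists 'b'
    // before 'a'; by-name resolution still locates 'a' and 'b'.
    val ok = spark.range(2).selectExpr(
      "array(named_struct('b', cast(id as int), 'a', 'x', 'c', 0)) as arr")
    ok.as[Wrapper].collect().foreach(println)

    // Incompatible: 'arr' is not an array at all, so encoder resolution fails
    // (at .as or on the first action, depending on the Spark version) with an
    // AnalysisException along the lines of "need an array field but got int".
    val bad = spark.range(2).selectExpr("cast(id as int) as arr")
    try {
      bad.as[Wrapper].collect()
    } catch {
      case e: AnalysisException => println(e.getMessage)
    }

    spark.stop()
  }
}
```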

Author: Wenchen Fan <wenc...@databricks.com>

Closes #17398 from cloud-fan/dataset.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/295747e5
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/295747e5
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/295747e5

Branch: refs/heads/master
Commit: 295747e59739ee8a697ac3eba485d3439e4a04c3
Parents: 402bf2a
Author: Wenchen Fan <wenc...@databricks.com>
Authored: Tue Apr 4 16:38:32 2017 -0700
Committer: Cheng Lian <l...@databricks.com>
Committed: Tue Apr 4 16:38:32 2017 -0700

----------------------------------------------------------------------
 .../spark/sql/catalyst/ScalaReflection.scala    | 66 +++++++++++---------
 .../spark/sql/catalyst/analysis/Analyzer.scala  | 19 +++++-
 .../expressions/complexTypeExtractors.scala     |  2 +-
 .../catalyst/expressions/objects/objects.scala  | 32 +++++++---
 .../encoders/EncoderResolutionSuite.scala       | 52 +++++++++++++++
 .../sql/expressions/ReduceAggregator.scala      |  2 +-
 .../org/apache/spark/sql/DatasetSuite.scala     |  9 +++
 7 files changed, 141 insertions(+), 41 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/295747e5/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index da37eb0..206ae2f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -92,7 +92,7 @@ object ScalaReflection extends ScalaReflection {
    * Array[T].  Special handling is performed for primitive types to map them back to their raw
    * JVM form instead of the Scala Array that handles auto boxing.
    */
-  private def arrayClassFor(tpe: `Type`): DataType = ScalaReflectionLock.synchronized {
+  private def arrayClassFor(tpe: `Type`): ObjectType = ScalaReflectionLock.synchronized {
     val cls = tpe match {
       case t if t <:< definitions.IntTpe => classOf[Array[Int]]
       case t if t <:< definitions.LongTpe => classOf[Array[Long]]
@@ -178,15 +178,17 @@ object ScalaReflection extends ScalaReflection {
      * is [a: int, b: long], then we will hit runtime error and say that we can't construct class
      * `Data` with int and long, because we lost the information that `b` should be a string.
      *
-     * This method help us "remember" the required data type by adding a `UpCast`.  Note that we
-     * don't need to cast struct type because there must be `UnresolvedExtractValue` or
-     * `GetStructField` wrapping it, thus we only need to handle leaf type.
+     * This method help us "remember" the required data type by adding a `UpCast`. Note that we
+     * only need to do this for leaf nodes.
      */
     def upCastToExpectedType(
         expr: Expression,
         expected: DataType,
         walkedTypePath: Seq[String]): Expression = expected match {
       case _: StructType => expr
+      case _: ArrayType => expr
+      // TODO: ideally we should also skip MapType, but nested StructType inside MapType is rare and
+      // it's not trivial to support by-name resolution for StructType inside MapType.
       case _ => UpCast(expr, expected, walkedTypePath)
     }
 
@@ -265,42 +267,48 @@ object ScalaReflection extends ScalaReflection {
 
       case t if t <:< localTypeOf[Array[_]] =>
         val TypeRef(_, _, Seq(elementType)) = t
+        val Schema(_, elementNullable) = schemaFor(elementType)
+        val className = getClassNameFromType(elementType)
+        val newTypePath = s"""- array element class: "$className"""" +: walkedTypePath
 
-        // TODO: add runtime null check for primitive array
-        val primitiveMethod = elementType match {
-          case t if t <:< definitions.IntTpe => Some("toIntArray")
-          case t if t <:< definitions.LongTpe => Some("toLongArray")
-          case t if t <:< definitions.DoubleTpe => Some("toDoubleArray")
-          case t if t <:< definitions.FloatTpe => Some("toFloatArray")
-          case t if t <:< definitions.ShortTpe => Some("toShortArray")
-          case t if t <:< definitions.ByteTpe => Some("toByteArray")
-          case t if t <:< definitions.BooleanTpe => Some("toBooleanArray")
-          case _ => None
+        val mapFunction: Expression => Expression = p => {
+          val converter = deserializerFor(elementType, Some(p), newTypePath)
+          if (elementNullable) {
+            converter
+          } else {
+            AssertNotNull(converter, newTypePath)
+          }
         }
 
-        primitiveMethod.map { method =>
-          Invoke(getPath, method, arrayClassFor(elementType))
-        }.getOrElse {
-          val className = getClassNameFromType(elementType)
-          val newTypePath = s"""- array element class: "$className"""" +: walkedTypePath
-          Invoke(
-            MapObjects(
-              p => deserializerFor(elementType, Some(p), newTypePath),
-              getPath,
-              schemaFor(elementType).dataType),
-            "array",
-            arrayClassFor(elementType))
+        val arrayData = UnresolvedMapObjects(mapFunction, getPath)
+        val arrayCls = arrayClassFor(elementType)
+
+        if (elementNullable) {
+          Invoke(arrayData, "array", arrayCls)
+        } else {
+          val primitiveMethod = elementType match {
+            case t if t <:< definitions.IntTpe => "toIntArray"
+            case t if t <:< definitions.LongTpe => "toLongArray"
+            case t if t <:< definitions.DoubleTpe => "toDoubleArray"
+            case t if t <:< definitions.FloatTpe => "toFloatArray"
+            case t if t <:< definitions.ShortTpe => "toShortArray"
+            case t if t <:< definitions.ByteTpe => "toByteArray"
+            case t if t <:< definitions.BooleanTpe => "toBooleanArray"
+            case other => throw new IllegalStateException("expect primitive array element type " +
+              "but got " + other)
+          }
+          Invoke(arrayData, primitiveMethod, arrayCls)
         }
 
       case t if t <:< localTypeOf[Seq[_]] =>
         val TypeRef(_, _, Seq(elementType)) = t
-        val Schema(dataType, nullable) = schemaFor(elementType)
+        val Schema(_, elementNullable) = schemaFor(elementType)
         val className = getClassNameFromType(elementType)
         val newTypePath = s"""- array element class: "$className"""" +: walkedTypePath
 
         val mapFunction: Expression => Expression = p => {
           val converter = deserializerFor(elementType, Some(p), newTypePath)
-          if (nullable) {
+          if (elementNullable) {
             converter
           } else {
             AssertNotNull(converter, newTypePath)
@@ -312,7 +320,7 @@ object ScalaReflection extends ScalaReflection {
           case NoSymbol => classOf[Seq[_]]
           case _ => mirror.runtimeClass(t.typeSymbol.asClass)
         }
-        MapObjects(mapFunction, getPath, dataType, Some(cls))
+        UnresolvedMapObjects(mapFunction, getPath, Some(cls))
 
       case t if t <:< localTypeOf[Map[_, _]] =>
         // TODO: add walked type path for map

http://git-wip-us.apache.org/repos/asf/spark/blob/295747e5/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 2d53d24..c698ca6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.catalog._
 import org.apache.spark.sql.catalyst.encoders.OuterScopes
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.catalyst.expressions.objects.NewInstance
+import org.apache.spark.sql.catalyst.expressions.objects.{MapObjects, NewInstance, UnresolvedMapObjects}
 import org.apache.spark.sql.catalyst.expressions.SubExprUtils._
 import org.apache.spark.sql.catalyst.optimizer.BooleanSimplification
 import org.apache.spark.sql.catalyst.plans._
@@ -2227,8 +2227,21 @@ class Analyzer(
           validateTopLevelTupleFields(deserializer, inputs)
           val resolved = resolveExpression(
             deserializer, LocalRelation(inputs), throws = true)
-          validateNestedTupleFields(resolved)
-          resolved
+          val result = resolved transformDown {
+            case UnresolvedMapObjects(func, inputData, cls) if inputData.resolved =>
+              inputData.dataType match {
+                case ArrayType(et, _) =>
+                  val expr = MapObjects(func, inputData, et, cls) transformUp {
+                    case UnresolvedExtractValue(child, fieldName) if child.resolved =>
+                      ExtractValue(child, fieldName, resolver)
+                  }
+                  expr
+                case other =>
+                  throw new AnalysisException("need an array field but got " + other.simpleString)
+              }
+          }
+          validateNestedTupleFields(result)
+          result
       }
     }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/295747e5/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
index de1594d..ef88cfb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
@@ -68,7 +68,7 @@ object ExtractValue {
           case StructType(_) =>
             s"Field name should be String Literal, but it's $extraction"
           case other =>
-            s"Can't extract value from $child"
+            s"Can't extract value from $child: need struct type but got 
${other.simpleString}"
         }
         throw new AnalysisException(errorMsg)
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/295747e5/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index bb584f7..00e2ac9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -448,6 +448,17 @@ object MapObjects {
   }
 }
 
+case class UnresolvedMapObjects(
+    function: Expression => Expression,
+    child: Expression,
+    customCollectionCls: Option[Class[_]] = None) extends UnaryExpression with Unevaluable {
+  override lazy val resolved = false
+
+  override def dataType: DataType = customCollectionCls.map(ObjectType.apply).getOrElse {
+    throw new UnsupportedOperationException("not resolved")
+  }
+}
+
 /**
  * Applies the given expression to every element of a collection of items, returning the result
  * as an ArrayType or ObjectType. This is similar to a typical map operation, but where the lambda
@@ -581,17 +592,24 @@ case class MapObjects private(
           // collection
           val collObjectName = s"${cls.getName}$$.MODULE$$"
           val getBuilderVar = s"$collObjectName.newBuilder()"
-
-          (s"""${classOf[Builder[_, _]].getName} $builderValue = 
$getBuilderVar;
-        $builderValue.sizeHint($dataLength);""",
+          (
+            s"""
+               ${classOf[Builder[_, _]].getName} $builderValue = $getBuilderVar;
+               $builderValue.sizeHint($dataLength);
+             """,
             genValue => s"$builderValue.$$plus$$eq($genValue);",
-            s"(${cls.getName}) $builderValue.result();")
+            s"(${cls.getName}) $builderValue.result();"
+          )
         case None =>
           // array
-          (s"""$convertedType[] $convertedArray = null;
-        $convertedArray = $arrayConstructor;""",
+          (
+            s"""
+               $convertedType[] $convertedArray = null;
+               $convertedArray = $arrayConstructor;
+             """,
             genValue => s"$convertedArray[$loopIndex] = $genValue;",
-            s"new ${classOf[GenericArrayData].getName}($convertedArray);")
+            s"new ${classOf[GenericArrayData].getName}($convertedArray);"
+          )
       }
 
     val code = s"""

http://git-wip-us.apache.org/repos/asf/spark/blob/295747e5/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
index 802397d..e5a3e1f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
@@ -33,6 +33,10 @@ case class StringIntClass(a: String, b: Int)
 
 case class ComplexClass(a: Long, b: StringLongClass)
 
+case class ArrayClass(arr: Seq[StringIntClass])
+
+case class NestedArrayClass(nestedArr: Array[ArrayClass])
+
 class EncoderResolutionSuite extends PlanTest {
   private val str = UTF8String.fromString("hello")
 
@@ -62,6 +66,54 @@ class EncoderResolutionSuite extends PlanTest {
     encoder.resolveAndBind(attrs).fromRow(InternalRow(InternalRow(str, 1.toByte), 2))
   }
 
+  test("real type doesn't match encoder schema but they are compatible: 
array") {
+    val encoder = ExpressionEncoder[ArrayClass]
+    val attrs = Seq('arr.array(new StructType().add("a", "int").add("b", 
"int").add("c", "int")))
+    val array = new GenericArrayData(Array(InternalRow(1, 2, 3)))
+    encoder.resolveAndBind(attrs).fromRow(InternalRow(array))
+  }
+
+  test("real type doesn't match encoder schema but they are compatible: nested 
array") {
+    val encoder = ExpressionEncoder[NestedArrayClass]
+    val et = new StructType().add("arr", ArrayType(
+      new StructType().add("a", "int").add("b", "int").add("c", "int")))
+    val attrs = Seq('nestedArr.array(et))
+    val innerArr = new GenericArrayData(Array(InternalRow(1, 2, 3)))
+    val outerArr = new GenericArrayData(Array(InternalRow(innerArr)))
+    encoder.resolveAndBind(attrs).fromRow(InternalRow(outerArr))
+  }
+
+  test("the real type is not compatible with encoder schema: non-array field") 
{
+    val encoder = ExpressionEncoder[ArrayClass]
+    val attrs = Seq('arr.int)
+    assert(intercept[AnalysisException](encoder.resolveAndBind(attrs)).message ==
+      "need an array field but got int")
+  }
+
+  test("the real type is not compatible with encoder schema: array element 
type") {
+    val encoder = ExpressionEncoder[ArrayClass]
+    val attrs = Seq('arr.array(new StructType().add("c", "int")))
+    assert(intercept[AnalysisException](encoder.resolveAndBind(attrs)).message ==
+      "No such struct field a in c")
+  }
+
+  test("the real type is not compatible with encoder schema: nested array 
element type") {
+    val encoder = ExpressionEncoder[NestedArrayClass]
+
+    withClue("inner element is not array") {
+      val attrs = Seq('nestedArr.array(new StructType().add("arr", "int")))
+      assert(intercept[AnalysisException](encoder.resolveAndBind(attrs)).message ==
+        "need an array field but got int")
+    }
+
+    withClue("nested array element type is not compatible") {
+      val attrs = Seq('nestedArr.array(new StructType()
+        .add("arr", ArrayType(new StructType().add("c", "int")))))
+      assert(intercept[AnalysisException](encoder.resolveAndBind(attrs)).message ==
+        "No such struct field a in c")
+    }
+  }
+
   test("nullability of array type element should not fail analysis") {
     val encoder = ExpressionEncoder[Seq[Int]]
     val attrs = 'a.array(IntegerType) :: Nil

http://git-wip-us.apache.org/repos/asf/spark/blob/295747e5/sql/core/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala
index 1743783..e266ae5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala
@@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 private[sql] class ReduceAggregator[T: Encoder](func: (T, T) => T)
   extends Aggregator[T, (Boolean, T), T] {
 
-  private val encoder = implicitly[Encoder[T]]
+  @transient private val encoder = implicitly[Encoder[T]]
 
   override def zero: (Boolean, T) = (false, null.asInstanceOf[T])
 

http://git-wip-us.apache.org/repos/asf/spark/blob/295747e5/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 68e071a..5b5cd28 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -142,6 +142,15 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     assert(ds.take(2) === Array(ClassData("a", 1), ClassData("b", 2)))
   }
 
+  test("as seq of case class - reorder fields by name") {
+    val df = spark.range(3).select(array(struct($"id".cast("int").as("b"), lit("a").as("a"))))
+    val ds = df.as[Seq[ClassData]]
+    assert(ds.collect() === Array(
+      Seq(ClassData("a", 0)),
+      Seq(ClassData("a", 1)),
+      Seq(ClassData("a", 2))))
+  }
+
   test("map") {
     val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS()
     checkDataset(

