Repository: spark
Updated Branches:
  refs/heads/master 66a3a5a2d -> 1035aaa61


[SPARK-23587][SQL] Add interpreted execution for MapObjects expression

## What changes were proposed in this pull request?

Add interpreted execution for the `MapObjects` expression.
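
Previously `MapObjects.eval` threw `UnsupportedOperationException`; with this change it evaluates the lambda per element directly. A minimal sketch of the new interpreted path (mirroring the added test; the values are illustrative only):

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Add, BoundReference, Expression, Literal}
import org.apache.spark.sql.catalyst.expressions.objects.MapObjects
import org.apache.spark.sql.types.{IntegerType, ObjectType}

// Build a `MapObjects` that adds 1 to each element of a Seq[Int] input.
val function = (lambda: Expression) => Add(lambda, Literal(1))
val inputObject = BoundReference(0, ObjectType(classOf[Seq[Int]]), nullable = true)
val mapObj = MapObjects(function, inputObject, IntegerType, true, Some(classOf[Seq[Int]]))

// Evaluate without codegen: returns Seq(2, 3, 4).
val result = mapObj.eval(InternalRow.fromSeq(Seq(Seq(1, 2, 3))))
```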

## How was this patch tested?

Added a unit test.

Author: Liang-Chi Hsieh <vii...@gmail.com>

Closes #20771 from viirya/SPARK-23587.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/1035aaa6
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/1035aaa6
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/1035aaa6

Branch: refs/heads/master
Commit: 1035aaa61704b2790192d3186fe37e678553d36d
Parents: 66a3a5a
Author: Liang-Chi Hsieh <vii...@gmail.com>
Authored: Wed Apr 4 01:36:58 2018 +0200
Committer: Herman van Hovell <hvanhov...@databricks.com>
Committed: Wed Apr 4 01:36:58 2018 +0200

----------------------------------------------------------------------
 .../catalyst/expressions/objects/objects.scala  | 110 +++++++++++++++++--
 .../expressions/ObjectExpressionsSuite.scala    |  67 ++++++++++-
 2 files changed, 165 insertions(+), 12 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/1035aaa6/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index adf9ddf..0e9d357 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.objects
 
 import java.lang.reflect.Modifier
 
+import scala.collection.JavaConverters._
 import scala.collection.mutable.Builder
 import scala.language.existentials
 import scala.reflect.ClassTag
@@ -501,12 +502,22 @@ case class LambdaVariable(
     value: String,
     isNull: String,
     dataType: DataType,
-    nullable: Boolean = true) extends LeafExpression
-  with Unevaluable with NonSQLExpression {
+    nullable: Boolean = true) extends LeafExpression with NonSQLExpression {
+
+  // Interpreted execution of `LambdaVariable` always gets the 0-index element from the input row.
+  override def eval(input: InternalRow): Any = {
+    assert(input.numFields == 1,
+      "The input row of interpreted LambdaVariable should have only 1 field.")
+    input.get(0, dataType)
+  }
 
   override def genCode(ctx: CodegenContext): ExprCode = {
    ExprCode(code = "", value = value, isNull = if (nullable) isNull else "false")
   }
+
+  // This won't be called as `genCode` is overridden; we override it only to make
+  // `LambdaVariable` non-abstract.
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = ev
 }
 
 /**
@@ -599,8 +610,92 @@ case class MapObjects private(
 
   override def children: Seq[Expression] = lambdaFunction :: inputData :: Nil
 
-  override def eval(input: InternalRow): Any =
-    throw new UnsupportedOperationException("Only code-generated evaluation is supported")
+  // Data with a UserDefinedType is actually stored with the data type of its sqlType.
+  // When we want to apply MapObjects to it, we have to use that sqlType.
+  private lazy val inputDataType = inputData.dataType match {
+    case u: UserDefinedType[_] => u.sqlType
+    case _ => inputData.dataType
+  }
+
+  private def executeFuncOnCollection(inputCollection: Seq[_]): Iterator[_] = {
+    val row = new GenericInternalRow(1)
+    inputCollection.toIterator.map { element =>
+      row.update(0, element)
+      lambdaFunction.eval(row)
+    }
+  }
+
+  private lazy val convertToSeq: Any => Seq[_] = inputDataType match {
+    case ObjectType(cls) if classOf[Seq[_]].isAssignableFrom(cls) =>
+      _.asInstanceOf[Seq[_]]
+    case ObjectType(cls) if cls.isArray =>
+      _.asInstanceOf[Array[_]].toSeq
+    case ObjectType(cls) if classOf[java.util.List[_]].isAssignableFrom(cls) =>
+      _.asInstanceOf[java.util.List[_]].asScala
+    case ObjectType(cls) if cls == classOf[Object] =>
+      (inputCollection) => {
+        if (inputCollection.getClass.isArray) {
+          inputCollection.asInstanceOf[Array[_]].toSeq
+        } else {
+          inputCollection.asInstanceOf[Seq[_]]
+        }
+      }
+    case ArrayType(et, _) =>
+      _.asInstanceOf[ArrayData].array
+  }
+
+  private lazy val mapElements: Seq[_] => Any = customCollectionCls match {
+    case Some(cls) if classOf[Seq[_]].isAssignableFrom(cls) =>
+      // Scala sequence
+      executeFuncOnCollection(_).toSeq
+    case Some(cls) if classOf[scala.collection.Set[_]].isAssignableFrom(cls) =>
+      // Scala set
+      executeFuncOnCollection(_).toSet
+    case Some(cls) if classOf[java.util.List[_]].isAssignableFrom(cls) =>
+      // Java list
+      if (cls == classOf[java.util.List[_]] || cls == classOf[java.util.AbstractList[_]] ||
+          cls == classOf[java.util.AbstractSequentialList[_]]) {
+        // Specifying non concrete implementations of `java.util.List`
+        executeFuncOnCollection(_).toSeq.asJava
+      } else {
+        val constructors = cls.getConstructors()
+        val intParamConstructor = constructors.find { constructor =>
+          constructor.getParameterCount == 1 && constructor.getParameterTypes()(0) == classOf[Int]
+        }
+        val noParamConstructor = constructors.find { constructor =>
+          constructor.getParameterCount == 0
+        }
+
+        val constructor = intParamConstructor.map { intConstructor =>
+          (len: Int) => intConstructor.newInstance(len.asInstanceOf[Object])
+        }.getOrElse {
+          (_: Int) => noParamConstructor.get.newInstance()
+        }
+
+        // Specifying concrete implementations of `java.util.List`
+        (inputs) => {
+          val results = executeFuncOnCollection(inputs)
+          val builder = constructor(inputs.length).asInstanceOf[java.util.List[Any]]
+          results.foreach(builder.add(_))
+          builder
+        }
+      }
+    case None =>
+      // array
+      x => new GenericArrayData(executeFuncOnCollection(x).toArray)
+    case Some(cls) =>
+      throw new RuntimeException(s"class `${cls.getName}` is not supported by `MapObjects` as " +
+        "resulting collection.")
+  }
+
+  override def eval(input: InternalRow): Any = {
+    val inputCollection = inputData.eval(input)
+
+    if (inputCollection == null) {
+      return null
+    }
+    mapElements(convertToSeq(inputCollection))
+  }
 
   override def dataType: DataType =
     customCollectionCls.map(ObjectType.apply).getOrElse(
@@ -647,13 +742,6 @@ case class MapObjects private(
       case _ => ""
     }
 
-    // The data with PythonUserDefinedType are actually stored with the data type of its sqlType.
-    // When we want to apply MapObjects on it, we have to use it.
-    val inputDataType = inputData.dataType match {
-      case p: PythonUserDefinedType => p.sqlType
-      case _ => inputData.dataType
-    }
-
     // `MapObjects` generates a while loop to traverse the elements of the input collection. We
     // need to take care of Seq and List because they may have O(n) complexity for indexed accessing
     // like `list.get(1)`. Here we use Iterator to traverse Seq and List.
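
The interpreted loop added above works because `LambdaVariable.eval` asserts a single field and reads index 0, so `executeFuncOnCollection` can reuse one mutable row across elements, and traversing via `Iterator` avoids O(n) indexed access on Seq/List. A standalone sketch of that pattern, with a hypothetical `incrementByOne` standing in for the real `lambdaFunction.eval`:

```scala
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow

// One mutable single-field row is reused for every element; in the real
// code, LambdaVariable.eval(row) reads slot 0 of this row.
def executeFuncOnCollectionSketch(input: Seq[Int]): Iterator[Int] = {
  val row = new GenericInternalRow(1)
  // Hypothetical stand-in for the lambda body (here: add 1).
  val incrementByOne = (r: GenericInternalRow) => r.getInt(0) + 1
  input.toIterator.map { element =>
    row.update(0, element)  // expose the current element at index 0
    incrementByOne(row)
  }
}

// executeFuncOnCollectionSketch(Seq(1, 2, 3)).toSeq == Seq(2, 3, 4)
```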

http://git-wip-us.apache.org/repos/asf/spark/blob/1035aaa6/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
index 1f6964d..0edd27c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
+import scala.collection.JavaConverters._
 import scala.reflect.ClassTag
 
 import org.apache.spark.{SparkConf, SparkFunSuite}
@@ -25,7 +26,7 @@ import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions.objects._
-import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData}
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData}
 import org.apache.spark.sql.types._
 
 
@@ -135,6 +136,70 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     }
   }
 
+  test("SPARK-23587: MapObjects should support interpreted execution") {
+    def testMapObjects(collection: Any, collectionCls: Class[_], inputType: DataType): Unit = {
+      val function = (lambda: Expression) => Add(lambda, Literal(1))
+      val elementType = IntegerType
+      val expected = Seq(2, 3, 4)
+
+      val inputObject = BoundReference(0, inputType, nullable = true)
+      val optClass = Option(collectionCls)
+      val mapObj = MapObjects(function, inputObject, elementType, true, optClass)
+      val row = InternalRow.fromSeq(Seq(collection))
+      val result = mapObj.eval(row)
+
+      collectionCls match {
+        case null =>
+          assert(result.asInstanceOf[ArrayData].array.toSeq == expected)
+        case l if classOf[java.util.List[_]].isAssignableFrom(l) =>
+          assert(result.asInstanceOf[java.util.List[_]].asScala.toSeq == expected)
+        case s if classOf[Seq[_]].isAssignableFrom(s) =>
+          assert(result.asInstanceOf[Seq[_]].toSeq == expected)
+        case s if classOf[scala.collection.Set[_]].isAssignableFrom(s) =>
+          assert(result.asInstanceOf[scala.collection.Set[_]] == expected.toSet)
+      }
+    }
+
+    val customCollectionClasses = Seq(classOf[Seq[Int]], classOf[scala.collection.Set[Int]],
+      classOf[java.util.List[Int]], classOf[java.util.AbstractList[Int]],
+      classOf[java.util.AbstractSequentialList[Int]], classOf[java.util.Vector[Int]],
+      classOf[java.util.Stack[Int]], null)
+
+    val list = new java.util.ArrayList[Int]()
+    list.add(1)
+    list.add(2)
+    list.add(3)
+    val arrayData = new GenericArrayData(Array(1, 2, 3))
+    val vector = new java.util.Vector[Int]()
+    vector.add(1)
+    vector.add(2)
+    vector.add(3)
+    val stack = new java.util.Stack[Int]()
+    stack.add(1)
+    stack.add(2)
+    stack.add(3)
+
+    Seq(
+      (Seq(1, 2, 3), ObjectType(classOf[Seq[Int]])),
+      (Array(1, 2, 3), ObjectType(classOf[Array[Int]])),
+      (Seq(1, 2, 3), ObjectType(classOf[Object])),
+      (Array(1, 2, 3), ObjectType(classOf[Object])),
+      (list, ObjectType(classOf[java.util.List[Int]])),
+      (vector, ObjectType(classOf[java.util.Vector[Int]])),
+      (stack, ObjectType(classOf[java.util.Stack[Int]])),
+      (arrayData, ArrayType(IntegerType))
+    ).foreach { case (collection, inputType) =>
+      customCollectionClasses.foreach(testMapObjects(collection, _, inputType))
+
+      // Unsupported custom collection class
+      val errMsg = intercept[RuntimeException] {
+        testMapObjects(collection, classOf[scala.collection.Map[Int, Int]], inputType)
+      }.getMessage()
+      assert(errMsg.contains("`scala.collection.Map` is not supported by `MapObjects` " +
+        "as resulting collection."))
+    }
+  }
+
   test("SPARK-23592: DecodeUsingSerializer should support interpreted 
execution") {
     val cls = classOf[java.lang.Integer]
    val inputObject = BoundReference(0, ObjectType(classOf[Array[Byte]]), nullable = true)

