Repository: spark
Updated Branches:
  refs/heads/master 5197562af -> a35523653


[SPARK-23583][SQL] Invoke should support interpreted execution

## What changes were proposed in this pull request?

This PR adds interpreted execution for `Invoke`.
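
For illustration, a minimal sketch of what this enables, assuming the `InvokeTargetClass` fixture added to `ObjectExpressionsSuite` below; before this change `Invoke.eval` threw `UnsupportedOperationException`, so the expression only worked under codegen:

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{BoundReference, Literal}
import org.apache.spark.sql.catalyst.expressions.objects.Invoke
import org.apache.spark.sql.types._

val target = Literal.create(new InvokeTargetClass, ObjectType(classOf[InvokeTargetClass]))
val args = Seq(
  BoundReference(0, IntegerType, nullable = false),
  BoundReference(1, DoubleType, nullable = false))

// Interpreted evaluation path: no code generation involved.
val result = Invoke(target, "binOp", DoubleType, args).eval(InternalRow(1, 0.25))
// result == 1.25
```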

## How was this patch tested?

Added tests in `ObjectExpressionsSuite`.
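
The new tests also cover null handling: with a null target object or a null argument, interpreted `Invoke` returns null rather than throwing. A condensed sketch, reusing `args` from the snippet above:

```scala
// A null target object evaluates to null instead of raising an NPE
// (mirrors the funcNullObj test case in the diff below).
val nullTarget = Literal.create(null, ObjectType(classOf[InvokeTargetClass]))
assert(Invoke(nullTarget, "binOp", DoubleType, args).eval(InternalRow(1, 0.25)) == null)
```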

Author: Kazuaki Ishizaki <ishiz...@jp.ibm.com>

Closes #20797 from kiszk/SPARK-28583.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/a3552365
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/a3552365
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/a3552365

Branch: refs/heads/master
Commit: a35523653cdac039ee2ddff316bc2c25d6514a91
Parents: 5197562
Author: Kazuaki Ishizaki <ishiz...@jp.ibm.com>
Authored: Wed Apr 4 18:36:15 2018 +0200
Committer: Herman van Hovell <hvanhov...@databricks.com>
Committed: Wed Apr 4 18:36:15 2018 +0200

----------------------------------------------------------------------
 .../spark/sql/catalyst/ScalaReflection.scala    | 48 ++++++++++++++-
 .../catalyst/expressions/objects/objects.scala  | 56 +++++++++++++++--
 .../expressions/ObjectExpressionsSuite.scala    | 65 ++++++++++++++++++++
 3 files changed, 163 insertions(+), 6 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/a3552365/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 9a4bf00..1aae3ae 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst
 import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedAttribute, UnresolvedExtractValue}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.objects._
-import org.apache.spark.sql.catalyst.util.{DateTimeUtils, GenericArrayData}
+import org.apache.spark.sql.catalyst.util.{ArrayData, DateTimeUtils, GenericArrayData, MapData}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
 
@@ -794,6 +794,52 @@ object ScalaReflection extends ScalaReflection {
     "interface", "long", "native", "new", "null", "package", "private", 
"protected", "public",
     "return", "short", "static", "strictfp", "super", "switch", 
"synchronized", "this", "throw",
     "throws", "transient", "true", "try", "void", "volatile", "while")
+
+  val typeJavaMapping = Map[DataType, Class[_]](
+    BooleanType -> classOf[Boolean],
+    ByteType -> classOf[Byte],
+    ShortType -> classOf[Short],
+    IntegerType -> classOf[Int],
+    LongType -> classOf[Long],
+    FloatType -> classOf[Float],
+    DoubleType -> classOf[Double],
+    StringType -> classOf[UTF8String],
+    DateType -> classOf[DateType.InternalType],
+    TimestampType -> classOf[TimestampType.InternalType],
+    BinaryType -> classOf[BinaryType.InternalType],
+    CalendarIntervalType -> classOf[CalendarInterval]
+  )
+
+  val typeBoxedJavaMapping = Map[DataType, Class[_]](
+    BooleanType -> classOf[java.lang.Boolean],
+    ByteType -> classOf[java.lang.Byte],
+    ShortType -> classOf[java.lang.Short],
+    IntegerType -> classOf[java.lang.Integer],
+    LongType -> classOf[java.lang.Long],
+    FloatType -> classOf[java.lang.Float],
+    DoubleType -> classOf[java.lang.Double],
+    DateType -> classOf[java.lang.Integer],
+    TimestampType -> classOf[java.lang.Long]
+  )
+
+  def dataTypeJavaClass(dt: DataType): Class[_] = {
+    dt match {
+      case _: DecimalType => classOf[Decimal]
+      case _: StructType => classOf[InternalRow]
+      case _: ArrayType => classOf[ArrayData]
+      case _: MapType => classOf[MapData]
+      case ObjectType(cls) => cls
+      case _ => typeJavaMapping.getOrElse(dt, classOf[java.lang.Object])
+    }
+  }
+
+  def expressionJavaClasses(arguments: Seq[Expression]): Seq[Class[_]] = {
+    if (arguments != Nil) {
+      arguments.map(e => dataTypeJavaClass(e.dataType))
+    } else {
+      Seq.empty
+    }
+  }
 }
 
 /**

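For context, `dataTypeJavaClass` above is what lets the interpreted path resolve a method overload via reflection; a few illustrative lookups (return values inferred from the mapping above, not part of the patch):

```scala
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.types._

// Primitive Catalyst types map to Java primitives, complex types to their
// internal representations, and anything unmapped falls back to Object.
ScalaReflection.dataTypeJavaClass(IntegerType)            // classOf[Int]
ScalaReflection.dataTypeJavaClass(StringType)             // classOf[UTF8String]
ScalaReflection.dataTypeJavaClass(ArrayType(IntegerType)) // classOf[ArrayData]
```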
http://git-wip-us.apache.org/repos/asf/spark/blob/a3552365/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 0e9d357..a455c1c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.catalyst.expressions.objects
 
-import java.lang.reflect.Modifier
+import java.lang.reflect.{Method, Modifier}
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable.Builder
@@ -28,7 +28,7 @@ import scala.util.Try
 import org.apache.spark.{SparkConf, SparkEnv}
 import org.apache.spark.serializer._
 import org.apache.spark.sql.Row
-import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.{InternalRow, ScalaReflection}
 import org.apache.spark.sql.catalyst.ScalaReflection.universe.TermName
 import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.catalyst.expressions._
@@ -104,6 +104,38 @@ trait InvokeLike extends Expression with NonSQLExpression {
 
     (argCode, argValues.mkString(", "), resultIsNull)
   }
+
+  /**
+   * Evaluate each argument with a given row, invoke a method with a given object and arguments,
+   * and cast a return value if the return type can be mapped to a Java boxed type.
+   *
+   * @param obj the object for the method to be called. If null, performs a static method call
+   * @param method the method object to be called
+   * @param arguments the arguments used for the method call
+   * @param input the row used for evaluating arguments
+   * @param dataType the data type of the return object
+   * @return the return object of a method call
+   */
+  def invoke(
+      obj: Any,
+      method: Method,
+      arguments: Seq[Expression],
+      input: InternalRow,
+      dataType: DataType): Any = {
+    val args = arguments.map(e => e.eval(input).asInstanceOf[Object])
+    if (needNullCheck && args.exists(_ == null)) {
+      // return null if one of arguments is null
+      null
+    } else {
+      val ret = method.invoke(obj, args: _*)
+      val boxedClass = ScalaReflection.typeBoxedJavaMapping.get(dataType)
+      if (boxedClass.isDefined) {
+        boxedClass.get.cast(ret)
+      } else {
+        ret
+      }
+    }
+  }
 }
 
 /**
@@ -264,12 +296,11 @@ case class Invoke(
     propagateNull: Boolean = true,
     returnNullable : Boolean = true) extends InvokeLike {
 
+  lazy val argClasses = ScalaReflection.expressionJavaClasses(arguments)
+
   override def nullable: Boolean = targetObject.nullable || needNullCheck || returnNullable
   override def children: Seq[Expression] = targetObject +: arguments
 
-  override def eval(input: InternalRow): Any =
-    throw new UnsupportedOperationException("Only code-generated evaluation is supported.")
-
   private lazy val encodedFunctionName = TermName(functionName).encodedName.toString
 
   @transient lazy val method = targetObject.dataType match {
@@ -283,6 +314,21 @@ case class Invoke(
     case _ => None
   }
 
+  override def eval(input: InternalRow): Any = {
+    val obj = targetObject.eval(input)
+    if (obj == null) {
+      // return null if obj is null
+      null
+    } else {
+      val invokeMethod = if (method.isDefined) {
+        method.get
+      } else {
+        obj.getClass.getDeclaredMethod(functionName, argClasses: _*)
+      }
+      invoke(obj, invokeMethod, arguments, input, dataType)
+    }
+  }
+
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val javaType = CodeGenerator.javaType(dataType)
     val obj = targetObject.genCode(ctx)

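A standalone sketch (not part of the patch) of why `invoke` casts through `typeBoxedJavaMapping`: `Method.invoke` returns `Object`, so a primitive `double` result arrives boxed, and the cast pins it to the boxed class Catalyst expects for `DoubleType`:

```scala
import java.lang.reflect.Method

// Hypothetical illustration using the InvokeTargetClass test fixture below.
val m: Method = classOf[InvokeTargetClass].getDeclaredMethod(
  "binOp", classOf[Int], classOf[Double])
val raw = m.invoke(new InvokeTargetClass, Integer.valueOf(1), java.lang.Double.valueOf(0.25))
val ret = classOf[java.lang.Double].cast(raw) // 1.25
```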
http://git-wip-us.apache.org/repos/asf/spark/blob/a3552365/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
index 0edd27c..9bfe291 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
@@ -24,11 +24,23 @@ import org.apache.spark.{SparkConf, SparkFunSuite}
 import org.apache.spark.serializer.{JavaSerializer, KryoSerializer}
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.ResolveTimeZone
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
 import org.apache.spark.sql.catalyst.expressions.objects._
 import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData}
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 
+class InvokeTargetClass extends Serializable {
+  def filterInt(e: Any): Any = e.asInstanceOf[Int] > 0
+  def filterPrimitiveInt(e: Int): Boolean = e > 0
+  def binOp(e1: Int, e2: Double): Double = e1 + e2
+}
+
+class InvokeTargetSubClass extends InvokeTargetClass {
+  override def binOp(e1: Int, e2: Double): Double = e1 - e2
+}
 
 class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
 
@@ -81,6 +93,41 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
       UnsafeProjection) // TODO(hvanhovell) revert this when SPARK-23587 is fixed
   }
 
+  test("SPARK-23583: Invoke should support interpreted execution") {
+    val targetObject = new InvokeTargetClass
+    val funcClass = classOf[InvokeTargetClass]
+    val funcObj = Literal.create(targetObject, ObjectType(funcClass))
+    val targetSubObject = new InvokeTargetSubClass
+    val funcSubObj = Literal.create(targetSubObject, ObjectType(classOf[InvokeTargetSubClass]))
+    val funcNullObj = Literal.create(null, ObjectType(funcClass))
+
+    val inputInt = Seq(BoundReference(0, ObjectType(classOf[Any]), true))
+    val inputPrimitiveInt = Seq(BoundReference(0, IntegerType, false))
+    val inputSum = Seq(BoundReference(0, IntegerType, false), BoundReference(1, DoubleType, false))
+
+    checkObjectExprEvaluation(
+      Invoke(funcObj, "filterInt", ObjectType(classOf[Any]), inputInt),
+      java.lang.Boolean.valueOf(true), InternalRow.fromSeq(Seq(Integer.valueOf(1))))
+
+    checkObjectExprEvaluation(
+      Invoke(funcObj, "filterPrimitiveInt", BooleanType, inputPrimitiveInt),
+      false, InternalRow.fromSeq(Seq(-1)))
+
+    checkObjectExprEvaluation(
+      Invoke(funcObj, "filterInt", ObjectType(classOf[Any]), inputInt),
+      null, InternalRow.fromSeq(Seq(null)))
+
+    checkObjectExprEvaluation(
+      Invoke(funcNullObj, "filterInt", ObjectType(classOf[Any]), inputInt),
+      null, InternalRow.fromSeq(Seq(Integer.valueOf(1))))
+
+    checkObjectExprEvaluation(
+      Invoke(funcObj, "binOp", DoubleType, inputSum), 1.25, 
InternalRow.apply(1, 0.25))
+
+    checkObjectExprEvaluation(
+      Invoke(funcSubObj, "binOp", DoubleType, inputSum), 0.75, InternalRow.apply(1, 0.25))
+  }
+
   test("SPARK-23585: UnwrapOption should support interpreted execution") {
     val cls = classOf[Option[Int]]
     val inputObject = BoundReference(0, ObjectType(cls), nullable = true)
@@ -105,6 +152,24 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(createExternalRow, Row.fromSeq(Seq(1, "x")), InternalRow.fromSeq(Seq()))
   }
 
+  // Checks evaluation with expected results given by Scala values instead of Catalyst values.
+  private def checkObjectExprEvaluation(
+      expression: => Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = {
+    val serializer = new JavaSerializer(new SparkConf()).newInstance
+    val resolver = ResolveTimeZone(new SQLConf)
+    val expr = resolver.resolveTimeZones(serializer.deserialize(serializer.serialize(expression)))
+    checkEvaluationWithoutCodegen(expr, expected, inputRow)
+    checkEvaluationWithGeneratedMutableProjection(expr, expected, inputRow)
+    if (GenerateUnsafeProjection.canSupport(expr.dataType)) {
+      checkEvaluationWithUnsafeProjection(
+        expr,
+        expected,
+        inputRow,
+        UnsafeProjection) // TODO(hvanhovell) revert this when SPARK-23587 is fixed
+    }
+    checkEvaluationWithOptimization(expr, expected, inputRow)
+  }
+
   test("SPARK-23594 GetExternalRowField should support interpreted execution") 
{
     val inputObject = BoundReference(0, ObjectType(classOf[Row]), nullable = 
true)
     val getRowField = GetExternalRowField(inputObject, index = 0, fieldName = 
"c0")

