Repository: spark
Updated Branches:
  refs/heads/master b79bf1df6 -> 6c400b4f3


[SPARK-9354][SQL] Remove InternalRow.get generic getter call in Hive integration code.

Replaced them with get(ordinal, dataType) so we can use UnsafeRow here.

I passed the data types throughout.
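
For context, a minimal sketch (not from this patch) of the calling convention the
change moves to; the readStruct helper and its parameters below are hypothetical
and only illustrate the typed getter:

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types.{DataType, StructType}

// Hypothetical sketch. Before this patch the Hive integration code used the
// generic getter row.get(i); afterwards it passes the Catalyst DataType so that
// specialized rows such as UnsafeRow can serve the value:
//   row.get(i)            // generic getter, removed from this code path
//   row.get(i, dataType)  // typed getter used instead
def readStruct(row: InternalRow, structType: StructType): Seq[Any] = {
  (0 until structType.length).map { i =>
    val tpe: DataType = structType(i).dataType
    row.get(i, tpe)  // mirrors the pattern used in HiveInspectors.wrap below
  }
}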

Author: Reynold Xin <r...@databricks.com>

Closes #7669 from rxin/row-generic-getter-hive and squashes the following commits:

3467d8e [Reynold Xin] [SPARK-9354][SQL] Remove InternalRow.get generic getter call in Hive integration code.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/6c400b4f
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/6c400b4f
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/6c400b4f

Branch: refs/heads/master
Commit: 6c400b4f39be3fb5f473b8d2db11d239ea8ddf42
Parents: b79bf1d
Author: Reynold Xin <r...@databricks.com>
Authored: Sun Jul 26 10:27:39 2015 -0700
Committer: Reynold Xin <r...@databricks.com>
Committed: Sun Jul 26 10:27:39 2015 -0700

----------------------------------------------------------------------
 .../apache/spark/sql/hive/HiveInspectors.scala  | 43 +++++++-----
 .../org/apache/spark/sql/hive/hiveUDFs.scala    | 74 ++++++++++++--------
 .../spark/sql/hive/HiveInspectorSuite.scala     | 53 ++++++++------
 3 files changed, 102 insertions(+), 68 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/6c400b4f/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
index 16977ce..f467500 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
@@ -46,7 +46,7 @@ import scala.collection.JavaConversions._
  *     long / scala.Long
  *     short / scala.Short
  *     byte / scala.Byte
- *     org.apache.spark.sql.types.Decimal
+ *     [[org.apache.spark.sql.types.Decimal]]
  *     Array[Byte]
  *     java.sql.Date
  *     java.sql.Timestamp
@@ -54,7 +54,7 @@ import scala.collection.JavaConversions._
  *    Map: scala.collection.immutable.Map
  *    List: scala.collection.immutable.Seq
  *    Struct:
- *           org.apache.spark.sql.catalyst.expression.Row
+ *           [[org.apache.spark.sql.catalyst.InternalRow]]
  *    Union: NOT SUPPORTED YET
  *  The Complex types plays as a container, which can hold arbitrary data types.
  *
@@ -454,7 +454,7 @@ private[hive] trait HiveInspectors {
    *
    *  NOTICE: the complex data type requires recursive wrapping.
    */
-  def wrap(a: Any, oi: ObjectInspector): AnyRef = oi match {
+  def wrap(a: Any, oi: ObjectInspector, dataType: DataType): AnyRef = oi match {
     case x: ConstantObjectInspector => x.getWritableConstantValue
     case _ if a == null => null
     case x: PrimitiveObjectInspector => x match {
@@ -488,43 +488,50 @@ private[hive] trait HiveInspectors {
     }
     case x: SettableStructObjectInspector =>
       val fieldRefs = x.getAllStructFieldRefs
+      val structType = dataType.asInstanceOf[StructType]
       val row = a.asInstanceOf[InternalRow]
       // 1. create the pojo (most likely) object
       val result = x.create()
       var i = 0
       while (i < fieldRefs.length) {
         // 2. set the property for the pojo
+        val tpe = structType(i).dataType
         x.setStructFieldData(
           result,
           fieldRefs.get(i),
-          wrap(row.get(i), fieldRefs.get(i).getFieldObjectInspector))
+          wrap(row.get(i, tpe), fieldRefs.get(i).getFieldObjectInspector, tpe))
         i += 1
       }
 
       result
     case x: StructObjectInspector =>
       val fieldRefs = x.getAllStructFieldRefs
+      val structType = dataType.asInstanceOf[StructType]
       val row = a.asInstanceOf[InternalRow]
       val result = new java.util.ArrayList[AnyRef](fieldRefs.length)
       var i = 0
       while (i < fieldRefs.length) {
-        result.add(wrap(row.get(i), fieldRefs.get(i).getFieldObjectInspector))
+        val tpe = structType(i).dataType
+        result.add(wrap(row.get(i, tpe), fieldRefs.get(i).getFieldObjectInspector, tpe))
         i += 1
       }
 
       result
     case x: ListObjectInspector =>
       val list = new java.util.ArrayList[Object]
+      val tpe = dataType.asInstanceOf[ArrayType].elementType
       a.asInstanceOf[Seq[_]].foreach {
-        v => list.add(wrap(v, x.getListElementObjectInspector))
+        v => list.add(wrap(v, x.getListElementObjectInspector, tpe))
       }
       list
     case x: MapObjectInspector =>
+      val keyType = dataType.asInstanceOf[MapType].keyType
+      val valueType = dataType.asInstanceOf[MapType].valueType
       // Some UDFs seem to assume we pass in a HashMap.
       val hashMap = new java.util.HashMap[AnyRef, AnyRef]()
-      hashMap.putAll(a.asInstanceOf[Map[_, _]].map {
-        case (k, v) =>
-          wrap(k, x.getMapKeyObjectInspector) -> wrap(v, x.getMapValueObjectInspector)
+      hashMap.putAll(a.asInstanceOf[Map[_, _]].map { case (k, v) =>
+        wrap(k, x.getMapKeyObjectInspector, keyType) ->
+          wrap(v, x.getMapValueObjectInspector, valueType)
       })
 
       hashMap
@@ -533,22 +540,24 @@ private[hive] trait HiveInspectors {
   def wrap(
       row: InternalRow,
       inspectors: Seq[ObjectInspector],
-      cache: Array[AnyRef]): Array[AnyRef] = {
+      cache: Array[AnyRef],
+      dataTypes: Array[DataType]): Array[AnyRef] = {
     var i = 0
     while (i < inspectors.length) {
-      cache(i) = wrap(row.get(i), inspectors(i))
+      cache(i) = wrap(row.get(i, dataTypes(i)), inspectors(i), dataTypes(i))
       i += 1
     }
     cache
   }
 
   def wrap(
-    row: Seq[Any],
-    inspectors: Seq[ObjectInspector],
-    cache: Array[AnyRef]): Array[AnyRef] = {
+      row: Seq[Any],
+      inspectors: Seq[ObjectInspector],
+      cache: Array[AnyRef],
+      dataTypes: Array[DataType]): Array[AnyRef] = {
     var i = 0
     while (i < inspectors.length) {
-      cache(i) = wrap(row(i), inspectors(i))
+      cache(i) = wrap(row(i), inspectors(i), dataTypes(i))
       i += 1
     }
     cache
@@ -625,7 +634,7 @@ private[hive] trait HiveInspectors {
         ObjectInspectorFactory.getStandardConstantListObjectInspector(listObjectInspector, null)
       } else {
         val list = new java.util.ArrayList[Object]()
-        value.asInstanceOf[Seq[_]].foreach(v => list.add(wrap(v, listObjectInspector)))
+        value.asInstanceOf[Seq[_]].foreach(v => list.add(wrap(v, listObjectInspector, dt)))
         ObjectInspectorFactory.getStandardConstantListObjectInspector(listObjectInspector, list)
       }
     case Literal(value, MapType(keyType, valueType, _)) =>
@@ -636,7 +645,7 @@ private[hive] trait HiveInspectors {
       } else {
         val map = new java.util.HashMap[Object, Object]()
         value.asInstanceOf[Map[_, _]].foreach (entry => {
-          map.put(wrap(entry._1, keyOI), wrap(entry._2, valueOI))
+          map.put(wrap(entry._1, keyOI, keyType), wrap(entry._2, valueOI, valueType))
         })
         ObjectInspectorFactory.getStandardConstantMapObjectInspector(keyOI, valueOI, map)
       }

http://git-wip-us.apache.org/repos/asf/spark/blob/6c400b4f/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
index 3259b50..54bf6bd 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
@@ -83,24 +83,22 @@ private[hive] class HiveFunctionRegistry(underlying: analysis.FunctionRegistry)
 private[hive] case class HiveSimpleUDF(funcWrapper: HiveFunctionWrapper, children: Seq[Expression])
   extends Expression with HiveInspectors with CodegenFallback with Logging {
 
-  type UDFType = UDF
-
   override def deterministic: Boolean = isUDFDeterministic
 
   override def nullable: Boolean = true
 
   @transient
-  lazy val function = funcWrapper.createFunction[UDFType]()
+  lazy val function = funcWrapper.createFunction[UDF]()
 
   @transient
-  protected lazy val method =
+  private lazy val method =
     function.getResolver.getEvalMethod(children.map(_.dataType.toTypeInfo))
 
   @transient
-  protected lazy val arguments = children.map(toInspector).toArray
+  private lazy val arguments = children.map(toInspector).toArray
 
   @transient
-  protected lazy val isUDFDeterministic = {
+  private lazy val isUDFDeterministic = {
     val udfType = function.getClass().getAnnotation(classOf[HiveUDFType])
     udfType != null && udfType.deterministic()
   }
@@ -109,7 +107,7 @@ private[hive] case class HiveSimpleUDF(funcWrapper: HiveFunctionWrapper, childre
 
   // Create parameter converters
   @transient
-  protected lazy val conversionHelper = new ConversionHelper(method, arguments)
+  private lazy val conversionHelper = new ConversionHelper(method, arguments)
 
   @transient
   lazy val dataType = javaClassToDataType(method.getReturnType)
@@ -119,14 +117,19 @@ private[hive] case class HiveSimpleUDF(funcWrapper: HiveFunctionWrapper, childre
     method.getGenericReturnType(), ObjectInspectorOptions.JAVA)
 
   @transient
-  protected lazy val cached: Array[AnyRef] = new Array[AnyRef](children.length)
+  private lazy val cached: Array[AnyRef] = new Array[AnyRef](children.length)
+
+  @transient
+  private lazy val inputDataTypes: Array[DataType] = children.map(_.dataType).toArray
 
   // TODO: Finish input output types.
   override def eval(input: InternalRow): Any = {
-    unwrap(
-      FunctionRegistry.invoke(method, function, conversionHelper
-        .convertIfNecessary(wrap(children.map(c => c.eval(input)), arguments, cached): _*): _*),
-      returnInspector)
+    val inputs = wrap(children.map(c => c.eval(input)), arguments, cached, inputDataTypes)
+    val ret = FunctionRegistry.invoke(
+      method,
+      function,
+      conversionHelper.convertIfNecessary(inputs : _*): _*)
+    unwrap(ret, returnInspector)
   }
 
   override def toString: String = {
@@ -135,47 +138,48 @@ private[hive] case class HiveSimpleUDF(funcWrapper: HiveFunctionWrapper, childre
 }
 
 // Adapter from Catalyst ExpressionResult to Hive DeferredObject
-private[hive] class DeferredObjectAdapter(oi: ObjectInspector)
+private[hive] class DeferredObjectAdapter(oi: ObjectInspector, dataType: DataType)
   extends DeferredObject with HiveInspectors {
+
   private var func: () => Any = _
   def set(func: () => Any): Unit = {
     this.func = func
   }
   override def prepare(i: Int): Unit = {}
-  override def get(): AnyRef = wrap(func(), oi)
+  override def get(): AnyRef = wrap(func(), oi, dataType)
 }
 
 private[hive] case class HiveGenericUDF(funcWrapper: HiveFunctionWrapper, children: Seq[Expression])
   extends Expression with HiveInspectors with CodegenFallback with Logging {
-  type UDFType = GenericUDF
+
+  override def nullable: Boolean = true
 
   override def deterministic: Boolean = isUDFDeterministic
 
-  override def nullable: Boolean = true
+  override def foldable: Boolean =
+    isUDFDeterministic && returnInspector.isInstanceOf[ConstantObjectInspector]
 
   @transient
-  lazy val function = funcWrapper.createFunction[UDFType]()
+  lazy val function = funcWrapper.createFunction[GenericUDF]()
 
   @transient
-  protected lazy val argumentInspectors = children.map(toInspector)
+  private lazy val argumentInspectors = children.map(toInspector)
 
   @transient
-  protected lazy val returnInspector = {
+  private lazy val returnInspector = {
     function.initializeAndFoldConstants(argumentInspectors.toArray)
   }
 
   @transient
-  protected lazy val isUDFDeterministic = {
+  private lazy val isUDFDeterministic = {
     val udfType = function.getClass.getAnnotation(classOf[HiveUDFType])
     udfType != null && udfType.deterministic()
   }
 
-  override def foldable: Boolean =
-    isUDFDeterministic && returnInspector.isInstanceOf[ConstantObjectInspector]
-
   @transient
-  protected lazy val deferedObjects =
-    argumentInspectors.map(new DeferredObjectAdapter(_)).toArray[DeferredObject]
+  private lazy val deferedObjects = argumentInspectors.zip(children).map { case (inspect, child) =>
+    new DeferredObjectAdapter(inspect, child.dataType)
+  }.toArray[DeferredObject]
 
   lazy val dataType: DataType = inspectorToDataType(returnInspector)
 
@@ -354,6 +358,9 @@ private[hive] case class HiveWindowFunction(
   // Output buffer.
   private var outputBuffer: Any = _
 
+  @transient
+  private lazy val inputDataTypes: Array[DataType] = children.map(_.dataType).toArray
+
   override def init(): Unit = {
     evaluator.init(GenericUDAFEvaluator.Mode.COMPLETE, inputInspectors)
   }
@@ -368,8 +375,13 @@ private[hive] case class HiveWindowFunction(
   }
 
   override def prepareInputParameters(input: InternalRow): AnyRef = {
-    wrap(inputProjection(input), inputInspectors, new Array[AnyRef](children.length))
+    wrap(
+      inputProjection(input),
+      inputInspectors,
+      new Array[AnyRef](children.length),
+      inputDataTypes)
   }
+
   // Add input parameters for a single row.
   override def update(input: AnyRef): Unit = {
     evaluator.iterate(hiveEvaluatorBuffer, input.asInstanceOf[Array[AnyRef]])
@@ -510,12 +522,15 @@ private[hive] case class HiveGenericUDTF(
     field => (inspectorToDataType(field.getFieldObjectInspector), true)
   }
 
+  @transient
+  private lazy val inputDataTypes: Array[DataType] = children.map(_.dataType).toArray
+
   override def eval(input: InternalRow): TraversableOnce[InternalRow] = {
     outputInspector // Make sure initialized.
 
     val inputProjection = new InterpretedProjection(children)
 
-    function.process(wrap(inputProjection(input), inputInspectors, udtInput))
+    function.process(wrap(inputProjection(input), inputInspectors, udtInput, inputDataTypes))
     collector.collectRows()
   }
 
@@ -584,9 +599,12 @@ private[hive] case class HiveUDAFFunction(
   @transient
   protected lazy val cached = new Array[AnyRef](exprs.length)
 
+  @transient
+  private lazy val inputDataTypes: Array[DataType] = exprs.map(_.dataType).toArray
+
   def update(input: InternalRow): Unit = {
     val inputs = inputProjection(input)
-    function.iterate(buffer, wrap(inputs, inspectors, cached))
+    function.iterate(buffer, wrap(inputs, inspectors, cached, inputDataTypes))
   }
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/6c400b4f/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala
index 8bb498a..0330013 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala
@@ -48,7 +48,11 @@ class HiveInspectorSuite extends SparkFunSuite with HiveInspectors {
       ObjectInspectorOptions.JAVA).asInstanceOf[StructObjectInspector]
 
     val a = unwrap(state, soi).asInstanceOf[InternalRow]
-    val b = wrap(a, soi).asInstanceOf[UDAFPercentile.State]
+
+    val dt = new StructType()
+      .add("counts", MapType(LongType, LongType))
+      .add("percentiles", ArrayType(DoubleType))
+    val b = wrap(a, soi, dt).asInstanceOf[UDAFPercentile.State]
 
     val sfCounts = soi.getStructFieldRef("counts")
     val sfPercentiles = soi.getStructFieldRef("percentiles")
@@ -158,44 +162,45 @@ class HiveInspectorSuite extends SparkFunSuite with HiveInspectors {
     val writableOIs = dataTypes.map(toWritableInspector)
     val nullRow = data.map(d => null)
 
-    checkValues(nullRow, nullRow.zip(writableOIs).map {
-      case (d, oi) => unwrap(wrap(d, oi), oi)
+    checkValues(nullRow, nullRow.zip(writableOIs).zip(dataTypes).map {
+      case ((d, oi), dt) => unwrap(wrap(d, oi, dt), oi)
     })
 
     // struct couldn't be constant, sweep it out
     val constantExprs = data.filter(!_.dataType.isInstanceOf[StructType])
+    val constantTypes = constantExprs.map(_.dataType)
     val constantData = constantExprs.map(_.eval())
     val constantNullData = constantData.map(_ => null)
     val constantWritableOIs = constantExprs.map(e => toWritableInspector(e.dataType))
     val constantNullWritableOIs =
       constantExprs.map(e => toInspector(Literal.create(null, e.dataType)))
 
-    checkValues(constantData, constantData.zip(constantWritableOIs).map {
-      case (d, oi) => unwrap(wrap(d, oi), oi)
+    checkValues(constantData, constantData.zip(constantWritableOIs).zip(constantTypes).map {
+      case ((d, oi), dt) => unwrap(wrap(d, oi, dt), oi)
     })
 
-    checkValues(constantNullData, constantData.zip(constantNullWritableOIs).map {
-      case (d, oi) => unwrap(wrap(d, oi), oi)
+    checkValues(constantNullData, constantData.zip(constantNullWritableOIs).zip(constantTypes).map {
+      case ((d, oi), dt) => unwrap(wrap(d, oi, dt), oi)
     })
 
-    checkValues(constantNullData, constantNullData.zip(constantWritableOIs).map {
-      case (d, oi) => unwrap(wrap(d, oi), oi)
+    checkValues(constantNullData, constantNullData.zip(constantWritableOIs).zip(constantTypes).map {
+      case ((d, oi), dt) => unwrap(wrap(d, oi, dt), oi)
     })
   }
 
   test("wrap / unwrap primitive writable object inspector") {
     val writableOIs = dataTypes.map(toWritableInspector)
 
-    checkValues(row, row.zip(writableOIs).map {
-      case (data, oi) => unwrap(wrap(data, oi), oi)
+    checkValues(row, row.zip(writableOIs).zip(dataTypes).map {
+      case ((data, oi), dt) => unwrap(wrap(data, oi, dt), oi)
     })
   }
 
   test("wrap / unwrap primitive java object inspector") {
     val ois = dataTypes.map(toInspector)
 
-    checkValues(row, row.zip(ois).map {
-      case (data, oi) => unwrap(wrap(data, oi), oi)
+    checkValues(row, row.zip(ois).zip(dataTypes).map {
+      case ((data, oi), dt) => unwrap(wrap(data, oi, dt), oi)
     })
   }
 
@@ -205,31 +210,33 @@ class HiveInspectorSuite extends SparkFunSuite with HiveInspectors {
     })
     val inspector = toInspector(dt)
     checkValues(row,
-      unwrap(wrap(InternalRow.fromSeq(row), inspector), inspector).asInstanceOf[InternalRow])
-    checkValue(null, unwrap(wrap(null, toInspector(dt)), toInspector(dt)))
+      unwrap(wrap(InternalRow.fromSeq(row), inspector, dt), inspector).asInstanceOf[InternalRow])
+    checkValue(null, unwrap(wrap(null, toInspector(dt), dt), toInspector(dt)))
   }
 
   test("wrap / unwrap Array Type") {
     val dt = ArrayType(dataTypes(0))
 
     val d = row(0) :: row(0) :: Nil
-    checkValue(d, unwrap(wrap(d, toInspector(dt)), toInspector(dt)))
-    checkValue(null, unwrap(wrap(null, toInspector(dt)), toInspector(dt)))
+    checkValue(d, unwrap(wrap(d, toInspector(dt), dt), toInspector(dt)))
+    checkValue(null, unwrap(wrap(null, toInspector(dt), dt), toInspector(dt)))
     checkValue(d,
-      unwrap(wrap(d, toInspector(Literal.create(d, dt))), toInspector(Literal.create(d, dt))))
+      unwrap(wrap(d, toInspector(Literal.create(d, dt)), dt), toInspector(Literal.create(d, dt))))
     checkValue(d,
-      unwrap(wrap(null, toInspector(Literal.create(d, dt))), toInspector(Literal.create(d, dt))))
+      unwrap(wrap(null, toInspector(Literal.create(d, dt)), dt),
+        toInspector(Literal.create(d, dt))))
   }
 
   test("wrap / unwrap Map Type") {
     val dt = MapType(dataTypes(0), dataTypes(1))
 
     val d = Map(row(0) -> row(1))
-    checkValue(d, unwrap(wrap(d, toInspector(dt)), toInspector(dt)))
-    checkValue(null, unwrap(wrap(null, toInspector(dt)), toInspector(dt)))
+    checkValue(d, unwrap(wrap(d, toInspector(dt), dt), toInspector(dt)))
+    checkValue(null, unwrap(wrap(null, toInspector(dt), dt), toInspector(dt)))
     checkValue(d,
-      unwrap(wrap(d, toInspector(Literal.create(d, dt))), toInspector(Literal.create(d, dt))))
+      unwrap(wrap(d, toInspector(Literal.create(d, dt)), dt), toInspector(Literal.create(d, dt))))
     checkValue(d,
-      unwrap(wrap(null, toInspector(Literal.create(d, dt))), toInspector(Literal.create(d, dt))))
+      unwrap(wrap(null, toInspector(Literal.create(d, dt)), dt),
+        toInspector(Literal.create(d, dt))))
   }
 }

