Repository: spark
Updated Branches:
  refs/heads/master 25520e976 -> 8a977b065


[SPARK-16100][SQL] fix bug when use Map as the buffer type of Aggregator

## What changes were proposed in this pull request?

The root cause is in `MapObjects`. Its parameter `loopVar` is not declared as
a child, but it can sometimes be the same object as `lambdaFunction` (e.g.
when the function that takes `loopVar` and produces `lambdaFunction` is
`identity`), which is a child. This causes trouble when `withNewChildren` is
called: it may mistakenly treat `loopVar` as a child and fail with
`IndexOutOfBoundsException: 0` later.
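
To make the aliasing concrete, here is a minimal, self-contained sketch of the failure mode (simplified stand-ins, not the real Catalyst `TreeNode`/`MapObjects` code; the real implementation fails with `IndexOutOfBoundsException: 0` at the analogous point):

```scala
import scala.collection.mutable

sealed trait Expr { def children: Seq[Expr] }
case class Leaf(name: String) extends Expr { val children: Seq[Expr] = Nil }

// `loopVar` is deliberately NOT listed in `children`, mirroring the old
// MapObjects signature where loopVar was a constructor arg but not a child.
case class Node(loopVar: Expr, lambdaFunction: Expr, input: Expr) extends Expr {
  val children: Seq[Expr] = Seq(lambdaFunction, input)

  // Rebuild the node by matching constructor args against declared children,
  // in the spirit of TreeNode.withNewChildren.
  def withNewChildren(newChildren: Seq[Expr]): Node = {
    val remaining = mutable.Queue(newChildren: _*)
    val args = productIterator.map {
      // Any constructor arg that is `eq` to some child gets replaced. When
      // loopVar aliases lambdaFunction, it wrongly consumes a replacement
      // too, so the queue runs dry one argument early.
      case e: Expr if children.exists(_ eq e) => remaining.dequeue()
      case other => other
    }.toVector
    Node(args(0).asInstanceOf[Expr], args(1).asInstanceOf[Expr], args(2).asInstanceOf[Expr])
  }
}

val v = Leaf("loopVar")
val node = Node(loopVar = v, lambdaFunction = v, input = Leaf("in")) // the `identity` case
node.withNewChildren(Seq(Leaf("newFunc"), Leaf("newIn"))) // throws: queue exhausted
```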

This PR fixes the bug by pulling the parameters out of `LambdaVariable` and
passing them to `MapObjects` directly.

## How was this patch tested?

new test in `DatasetAggregatorSuite`
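
At the user level, the new test boils down to a snippet like this (a sketch; it assumes a SparkSession with `spark.implicits._` in scope and mirrors the `MapTypeBufferAgg` object added in the diff below):

```scala
import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.expressions.Aggregator

// A no-op Aggregator whose buffer type is a Map. Before this fix, encoding
// the Map buffer built a MapObjects whose loopVar could alias its
// lambdaFunction, and the query failed with IndexOutOfBoundsException: 0.
object MapTypeBufferAgg extends Aggregator[Int, Map[Int, Int], Int] {
  override def zero: Map[Int, Int] = Map.empty
  override def reduce(b: Map[Int, Int], a: Int): Map[Int, Int] = b
  override def merge(b1: Map[Int, Int], b2: Map[Int, Int]): Map[Int, Int] = b1
  override def finish(reduction: Map[Int, Int]): Int = 1
  override def bufferEncoder: Encoder[Map[Int, Int]] = ExpressionEncoder()
  override def outputEncoder: Encoder[Int] = Encoders.scalaInt
}

val ds = Seq(1, 2, 3).toDS()                // needs spark.implicits._
ds.select(MapTypeBufferAgg.toColumn).show() // used to throw; now returns 1
```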

Author: Wenchen Fan <wenc...@databricks.com>

Closes #13835 from cloud-fan/map-objects.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/8a977b06
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/8a977b06
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/8a977b06

Branch: refs/heads/master
Commit: 8a977b065418f07d2bf4fe1607a5534c32d04c47
Parents: 25520e9
Author: Wenchen Fan <wenc...@databricks.com>
Authored: Wed Jun 29 06:39:28 2016 +0800
Committer: Cheng Lian <l...@databricks.com>
Committed: Wed Jun 29 06:39:28 2016 +0800

----------------------------------------------------------------------
 .../catalyst/expressions/objects/objects.scala  | 28 ++++++++++++--------
 .../spark/sql/DatasetAggregatorSuite.scala      | 15 +++++++++++
 2 files changed, 32 insertions(+), 11 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/8a977b06/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index c597a2a..ea4dee1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -353,7 +353,7 @@ object MapObjects {
     val loopValue = "MapObjects_loopValue" + curId.getAndIncrement()
     val loopIsNull = "MapObjects_loopIsNull" + curId.getAndIncrement()
     val loopVar = LambdaVariable(loopValue, loopIsNull, elementType)
-    MapObjects(loopVar, function(loopVar), inputData)
+    MapObjects(loopValue, loopIsNull, elementType, function(loopVar), inputData)
   }
 }
 
@@ -365,14 +365,20 @@ object MapObjects {
  * The following collection ObjectTypes are currently supported:
  *   Seq, Array, ArrayData, java.util.List
  *
- * @param loopVar A place holder that used as the loop variable when iterate the collection, and
- *                used as input for the `lambdaFunction`. It also carries the element type info.
+ * @param loopValue the name of the loop variable that used when iterate the collection, and used
+ *                  as input for the `lambdaFunction`
+ * @param loopIsNull the nullity of the loop variable that used when iterate the collection, and
+ *                   used as input for the `lambdaFunction`
+ * @param loopVarDataType the data type of the loop variable that used when iterate the collection,
+ *                        and used as input for the `lambdaFunction`
  * @param lambdaFunction A function that take the `loopVar` as input, and used as lambda function
  *                       to handle collection elements.
  * @param inputData An expression that when evaluated returns a collection object.
  */
 case class MapObjects private(
-    loopVar: LambdaVariable,
+    loopValue: String,
+    loopIsNull: String,
+    loopVarDataType: DataType,
     lambdaFunction: Expression,
     inputData: Expression) extends Expression with NonSQLExpression {
 
@@ -386,9 +392,9 @@ case class MapObjects private(
   override def dataType: DataType = ArrayType(lambdaFunction.dataType)
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    val elementJavaType = ctx.javaType(loopVar.dataType)
-    ctx.addMutableState("boolean", loopVar.isNull, "")
-    ctx.addMutableState(elementJavaType, loopVar.value, "")
+    val elementJavaType = ctx.javaType(loopVarDataType)
+    ctx.addMutableState("boolean", loopIsNull, "")
+    ctx.addMutableState(elementJavaType, loopValue, "")
     val genInputData = inputData.genCode(ctx)
     val genFunction = lambdaFunction.genCode(ctx)
     val dataLength = ctx.freshName("dataLength")
@@ -443,11 +449,11 @@ case class MapObjects private(
     }
 
     val loopNullCheck = inputData.dataType match {
-      case _: ArrayType => s"${loopVar.isNull} = ${genInputData.value}.isNullAt($loopIndex);"
+      case _: ArrayType => s"$loopIsNull = ${genInputData.value}.isNullAt($loopIndex);"
       // The element of primitive array will never be null.
       case ObjectType(cls) if cls.isArray && cls.getComponentType.isPrimitive =>
-        s"${loopVar.isNull} = false"
-      case _ => s"${loopVar.isNull} = ${loopVar.value} == null;"
+        s"$loopIsNull = false"
+      case _ => s"$loopIsNull = $loopValue == null;"
     }
 
     val code = s"""
@@ -462,7 +468,7 @@ case class MapObjects private(
 
         int $loopIndex = 0;
         while ($loopIndex < $dataLength) {
-          ${loopVar.value} = ($elementJavaType) ($getLoopVar);
+          $loopValue = ($elementJavaType) ($getLoopVar);
           $loopNullCheck
 
           ${genFunction.code}

http://git-wip-us.apache.org/repos/asf/spark/blob/8a977b06/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
index f955120..32fcf84 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
@@ -74,6 +74,16 @@ object ComplexBufferAgg extends Aggregator[AggData, (Int, AggData), Int] {
 }
 
 
+object MapTypeBufferAgg extends Aggregator[Int, Map[Int, Int], Int] {
+  override def zero: Map[Int, Int] = Map.empty
+  override def reduce(b: Map[Int, Int], a: Int): Map[Int, Int] = b
+  override def finish(reduction: Map[Int, Int]): Int = 1
+  override def merge(b1: Map[Int, Int], b2: Map[Int, Int]): Map[Int, Int] = b1
+  override def bufferEncoder: Encoder[Map[Int, Int]] = ExpressionEncoder()
+  override def outputEncoder: Encoder[Int] = ExpressionEncoder()
+}
+
+
 object NameAgg extends Aggregator[AggData, String, String] {
   def zero: String = ""
   def reduce(b: String, a: AggData): String = a.b + b
@@ -290,4 +300,9 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
       ds.groupByKey(_.a).agg(NullResultAgg.toColumn),
       1 -> AggData(1, "one"), 2 -> null)
   }
+
+  test("SPARK-16100: use Map as the buffer type of Aggregator") {
+    val ds = Seq(1, 2, 3).toDS()
+    checkDataset(ds.select(MapTypeBufferAgg.toColumn), 1)
+  }
 }

