Repository: spark
Updated Branches:
  refs/heads/master 25520e976 -> 8a977b065
[SPARK-16100][SQL] fix bug when using Map as the buffer type of Aggregator

## What changes were proposed in this pull request?

The root cause is in `MapObjects`. Its parameter `loopVar` is not declared as a child, yet it can sometimes be the same object as `lambdaFunction`, which is a child (e.g. when the function that takes `loopVar` and produces `lambdaFunction` is `identity`). This causes trouble when `withNewChildren` is called: it may mistakenly treat `loopVar` as a child and later fail with `IndexOutOfBoundsException: 0`. (A minimal sketch of this failure mode follows the diff below.) This PR fixes the bug by simply pulling the parameters out of `LambdaVariable` and passing them to `MapObjects` directly.

## How was this patch tested?

A new test in `DatasetAggregatorSuite`.

Author: Wenchen Fan <wenc...@databricks.com>

Closes #13835 from cloud-fan/map-objects.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/8a977b06
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/8a977b06
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/8a977b06

Branch: refs/heads/master
Commit: 8a977b065418f07d2bf4fe1607a5534c32d04c47
Parents: 25520e9
Author: Wenchen Fan <wenc...@databricks.com>
Authored: Wed Jun 29 06:39:28 2016 +0800
Committer: Cheng Lian <l...@databricks.com>
Committed: Wed Jun 29 06:39:28 2016 +0800

----------------------------------------------------------------------
 .../catalyst/expressions/objects/objects.scala | 28 ++++++++++++--------
 .../spark/sql/DatasetAggregatorSuite.scala     | 15 +++++++++++
 2 files changed, 32 insertions(+), 11 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/8a977b06/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index c597a2a..ea4dee1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -353,7 +353,7 @@ object MapObjects {
     val loopValue = "MapObjects_loopValue" + curId.getAndIncrement()
     val loopIsNull = "MapObjects_loopIsNull" + curId.getAndIncrement()
     val loopVar = LambdaVariable(loopValue, loopIsNull, elementType)
-    MapObjects(loopVar, function(loopVar), inputData)
+    MapObjects(loopValue, loopIsNull, elementType, function(loopVar), inputData)
   }
 }
 
@@ -365,14 +365,20 @@
  * The following collection ObjectTypes are currently supported:
  *   Seq, Array, ArrayData, java.util.List
  *
- * @param loopVar A place holder that used as the loop variable when iterate the collection, and
- *                used as input for the `lambdaFunction`. It also carries the element type info.
+ * @param loopValue the name of the loop variable that used when iterate the collection, and used
+ *                  as input for the `lambdaFunction`
+ * @param loopIsNull the nullity of the loop variable that used when iterate the collection, and
+ *                   used as input for the `lambdaFunction`
+ * @param loopVarDataType the data type of the loop variable that used when iterate the collection,
+ *                        and used as input for the `lambdaFunction`
 * @param lambdaFunction A function that take the `loopVar` as input, and used as lambda function
 *                       to handle collection elements.
 * @param inputData An expression that when evaluated returns a collection object.
 */
 case class MapObjects private(
-    loopVar: LambdaVariable,
+    loopValue: String,
+    loopIsNull: String,
+    loopVarDataType: DataType,
     lambdaFunction: Expression,
     inputData: Expression) extends Expression with NonSQLExpression {
 
@@ -386,9 +392,9 @@ case class MapObjects private(
   override def dataType: DataType = ArrayType(lambdaFunction.dataType)
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    val elementJavaType = ctx.javaType(loopVar.dataType)
-    ctx.addMutableState("boolean", loopVar.isNull, "")
-    ctx.addMutableState(elementJavaType, loopVar.value, "")
+    val elementJavaType = ctx.javaType(loopVarDataType)
+    ctx.addMutableState("boolean", loopIsNull, "")
+    ctx.addMutableState(elementJavaType, loopValue, "")
     val genInputData = inputData.genCode(ctx)
     val genFunction = lambdaFunction.genCode(ctx)
     val dataLength = ctx.freshName("dataLength")
@@ -443,11 +449,11 @@ case class MapObjects private(
     }
 
     val loopNullCheck = inputData.dataType match {
-      case _: ArrayType => s"${loopVar.isNull} = ${genInputData.value}.isNullAt($loopIndex);"
+      case _: ArrayType => s"$loopIsNull = ${genInputData.value}.isNullAt($loopIndex);"
       // The element of primitive array will never be null.
       case ObjectType(cls) if cls.isArray && cls.getComponentType.isPrimitive =>
-        s"${loopVar.isNull} = false"
-      case _ => s"${loopVar.isNull} = ${loopVar.value} == null;"
+        s"$loopIsNull = false"
+      case _ => s"$loopIsNull = $loopValue == null;"
     }
 
     val code = s"""
@@ -462,7 +468,7 @@ case class MapObjects private(
 
       int $loopIndex = 0;
       while ($loopIndex < $dataLength) {
-        ${loopVar.value} = ($elementJavaType) ($getLoopVar);
+        $loopValue = ($elementJavaType) ($getLoopVar);
         $loopNullCheck
 
         ${genFunction.code}


http://git-wip-us.apache.org/repos/asf/spark/blob/8a977b06/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
index f955120..32fcf84 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
@@ -74,6 +74,16 @@ object ComplexBufferAgg extends Aggregator[AggData, (Int, AggData), Int] {
 }
 
 
+object MapTypeBufferAgg extends Aggregator[Int, Map[Int, Int], Int] {
+  override def zero: Map[Int, Int] = Map.empty
+  override def reduce(b: Map[Int, Int], a: Int): Map[Int, Int] = b
+  override def finish(reduction: Map[Int, Int]): Int = 1
+  override def merge(b1: Map[Int, Int], b2: Map[Int, Int]): Map[Int, Int] = b1
+  override def bufferEncoder: Encoder[Map[Int, Int]] = ExpressionEncoder()
+  override def outputEncoder: Encoder[Int] = ExpressionEncoder()
+}
+
+
 object NameAgg extends Aggregator[AggData, String, String] {
   def zero: String = ""
   def reduce(b: String, a: AggData): String = a.b + b
@@ -290,4 +300,9 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
       ds.groupByKey(_.a).agg(NullResultAgg.toColumn),
       1 -> AggData(1, "one"), 2 -> null)
   }
+
+  test("SPARK-16100: use Map as the buffer type of Aggregator") {
+    val ds = Seq(1, 2, 3).toDS()
+    checkDataset(ds.select(MapTypeBufferAgg.toColumn), 1)
+  }
 }
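----------------------------------------------------------------------

To make the root cause described above concrete, the following is a minimal,
self-contained sketch of the failure mode. It is NOT Spark's actual TreeNode
code: `Var`, `FakeMapObjects`, and the simplified `withNewChildren` below are
hypothetical stand-ins that only model how the real implementation matches
constructor arguments against the old children.

import scala.collection.mutable.ArrayBuffer

object WithNewChildrenBugDemo extends App {
  case class Var(name: String)

  // Mirrors the old MapObjects(loopVar, lambdaFunction, inputData): loopVar
  // is a constructor argument but deliberately NOT declared as a child.
  case class FakeMapObjects(loopVar: Var, lambdaFunction: Var, inputData: Var) {
    def children: Seq[Var] = Seq(lambdaFunction, inputData)

    // Simplified withNewChildren: each constructor argument that equals an
    // old child is replaced by the next new child, in order.
    def withNewChildren(newChildren: Seq[Var]): FakeMapObjects = {
      val remaining = ArrayBuffer(newChildren: _*)
      val oldChildren = children.toSet
      def next(arg: Var): Var =
        if (oldChildren.contains(arg)) remaining.remove(0) // fails once exhausted
        else arg
      FakeMapObjects(next(loopVar), next(lambdaFunction), next(inputData))
    }
  }

  // When the function passed to MapObjects.apply is `identity`, the
  // lambdaFunction child is the very same variable as loopVar:
  val loopVar = Var("MapObjects_loopValue0")
  val node = FakeMapObjects(loopVar, loopVar, Var("inputData"))

  // There are only two real children, but three constructor arguments now
  // compare equal to a child, so the third lookup hits an empty buffer:
  node.withNewChildren(Seq(Var("newLambda"), Var("newInput")))
  // => java.lang.IndexOutOfBoundsException: 0
}

This is why the patch stores the plain strings `loopValue` and `loopIsNull`
plus `loopVarDataType` on `MapObjects` instead of a `LambdaVariable`: none of
those arguments can compare equal to the `lambdaFunction` child, so the
argument-to-child matching consumes exactly as many new children as there
are real children.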