Github user gatorsmile commented on a diff in the pull request: https://github.com/apache/spark/pull/19082#discussion_r156184646 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala --- @@ -863,25 +984,43 @@ case class HashAggregateExec( } val updateRowInUnsafeRowMap: String = { - ctx.INPUT_ROW = unsafeRowBuffer + // We need to copy the aggregation row buffer to a local row first because each aggregate + // function directly updates the buffer when it finishes. + val localRowBuffer = ctx.freshName("localUnsafeRowBuffer") + val initLocalRowBuffer = s"InternalRow $localRowBuffer = $unsafeRowBuffer.copy();" + + ctx.INPUT_ROW = localRowBuffer val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttr)) val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr) val effectiveCodes = subExprs.codes.mkString("\n") val unsafeRowBufferEvals = ctx.withSubExprEliminationExprs(subExprs.states) { boundUpdateExpr.map(_.genCode(ctx)) } - val updateUnsafeRowBuffer = unsafeRowBufferEvals.zipWithIndex.map { case (ev, i) => + + val evalAndUpdateCodes = unsafeRowBufferEvals.zipWithIndex.map { case (ev, i) => val dt = updateExpr(i).dataType - ctx.updateColumn(unsafeRowBuffer, dt, i, ev, updateExpr(i).nullable) + val updateColumnCode = ctx.updateColumn(unsafeRowBuffer, dt, i, ev, updateExpr(i).nullable) + s""" + | // evaluate aggregate function + | ${ev.code} + | // update unsafe row buffer + | $updateColumnCode + """.stripMargin } + + val updateAggValCode = splitAggregateExpressions( + ctx, boundUpdateExpr, evalAndUpdateCodes, subExprs.states, + Seq(("InternalRow", unsafeRowBuffer))) --- End diff -- ``` ctx, boundUpdateExpr, evalAndUpdateCodes, subExprs.states, Seq(("InternalRow", unsafeRowBuffer))) ```
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org For additional commands, e-mail: reviews-help@spark.apache.org