GitHub user gatorsmile commented on a diff in the pull request:

    https://github.com/apache/spark/pull/19082#discussion_r156184646
  
    --- Diff: 
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
 ---
    @@ -863,25 +984,43 @@ case class HashAggregateExec(
         }
     
         val updateRowInUnsafeRowMap: String = {
    -      ctx.INPUT_ROW = unsafeRowBuffer
    +      // We need to copy the aggregation row buffer to a local row first 
because each aggregate
    +      // function directly updates the buffer when it finishes.
    +      val localRowBuffer = ctx.freshName("localUnsafeRowBuffer")
    +      val initLocalRowBuffer = s"InternalRow $localRowBuffer = 
$unsafeRowBuffer.copy();"
    +
    +      ctx.INPUT_ROW = localRowBuffer
           val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, 
inputAttr))
           val subExprs = 
ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr)
           val effectiveCodes = subExprs.codes.mkString("\n")
           val unsafeRowBufferEvals = 
ctx.withSubExprEliminationExprs(subExprs.states) {
             boundUpdateExpr.map(_.genCode(ctx))
           }
    -      val updateUnsafeRowBuffer = unsafeRowBufferEvals.zipWithIndex.map { 
case (ev, i) =>
    +
    +      val evalAndUpdateCodes = unsafeRowBufferEvals.zipWithIndex.map { 
case (ev, i) =>
             val dt = updateExpr(i).dataType
    -        ctx.updateColumn(unsafeRowBuffer, dt, i, ev, 
updateExpr(i).nullable)
    +        val updateColumnCode = ctx.updateColumn(unsafeRowBuffer, dt, i, 
ev, updateExpr(i).nullable)
    +        s"""
    +           | // evaluate aggregate function
    +           | ${ev.code}
    +           | // update unsafe row buffer
    +           | $updateColumnCode
    +         """.stripMargin
           }
    +
    +      val updateAggValCode = splitAggregateExpressions(
    +        ctx, boundUpdateExpr, evalAndUpdateCodes, subExprs.states,
    +        Seq(("InternalRow", unsafeRowBuffer)))
    --- End diff --
    
    ```
    ctx,
    boundUpdateExpr,
    evalAndUpdateCodes,
    subExprs.states,
    Seq(("InternalRow", unsafeRowBuffer)))
    ```


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to