sunchao commented on code in PR #34558:
URL: https://github.com/apache/spark/pull/34558#discussion_r3293711737


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala:
##########
@@ -886,6 +1179,114 @@ case class ArrayAggregate(
     }
   }
 
+  protected def nullSafeCodeGen(
+      ctx: CodegenContext,
+      ev: ExprCode,
+      f: String => String): ExprCode = {
+    val argumentGen = argument.genCode(ctx)
+    val resultCode = f(argumentGen.value)
+
+    if (nullable) {
+      val nullSafeEval = ctx.nullSafeExec(argument.nullable, 
argumentGen.isNull)(resultCode)
+      ev.copy(code = code"""
+        |${argumentGen.code}
+        |boolean ${ev.isNull} = ${argumentGen.isNull};
+        |${CodeGenerator.javaType(dataType)} ${ev.value} = 
${CodeGenerator.defaultValue(dataType)};
+        |$nullSafeEval
+      """)
+    } else {
+      ev.copy(code = code"""
+        |${argumentGen.code}
+        |${CodeGenerator.javaType(dataType)} ${ev.value} = 
${CodeGenerator.defaultValue(dataType)};
+        |$resultCode
+      """, isNull = FalseLiteral)
+    }
+  }
+
+  protected def assignVar(
+      varCode: ExprCode,
+      atomicVar: String,
+      value: String,
+      isNull: String,
+      nullable: Boolean): String = {
+    val atomicAssign = assignAtomic(atomicVar, value, isNull, nullable)
+    if (nullable) {
+      s"""
+        ${varCode.value} = $value;
+        ${varCode.isNull} = $isNull;
+        $atomicAssign
+      """
+    } else {
+      s"""
+        ${varCode.value} = $value;
+        $atomicAssign
+      """
+    }
+  }
+
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    ctx.withLambdaVars(Seq(elementVar, accForMergeVar, accForFinishVar), 
varCodes => {
+      val Seq(elementCode, accForMergeCode, accForFinishCode) = varCodes
+
+      nullSafeCodeGen(ctx, ev, arg => {
+        val numElements = ctx.freshName("numElements")
+        val i = ctx.freshName("i")
+
+        val zeroCode = zero.genCode(ctx)
+        val mergeCode = merge.genCode(ctx)
+        val finishCode = finish.genCode(ctx)
+
+        val elementAssignment = assignArrayElement(ctx, arg, elementCode, 
elementVar, i)
+        val mergeAtomic = ctx.addReferenceObj(accForMergeVar.name,
+          accForMergeVar.value)
+        val finishAtomic = ctx.addReferenceObj(accForFinishVar.name,
+          accForFinishVar.value)
+
+        val mergeJavaType = CodeGenerator.javaType(accForMergeVar.dataType)
+        val finishJavaType = CodeGenerator.javaType(accForFinishVar.dataType)
+
+        // Some expressions return internal buffers that we have to copy
+        val mergeCopy = if (CodeGenerator.isPrimitiveType(merge.dataType)) {
+          s"${mergeCode.value}"
+        } else {
+          s"($mergeJavaType)InternalRow.copyValue(${mergeCode.value})"
+        }
+
+        val nullCheck = if (nullable) {
+          s"${ev.isNull} = ${finishCode.isNull};"
+        } else {
+          ""
+        }
+
+        val initialAssignment = assignVar(accForMergeCode, mergeAtomic, 
zeroCode.value,
+          zeroCode.isNull, zero.nullable)
+
+        val mergeAssignment = assignVar(accForMergeCode, mergeAtomic, 
mergeCopy,
+          mergeCode.isNull, merge.nullable)

Review Comment:
   [P1] Clear `ArrayAggregate` accumulator null state on every generated assignment
   
   `accForMergeVar` is always bound as nullable in `bindInternal`, and `withLambdaVars` keeps its `isNull` flag as mutable generated state. `assignVar` only writes that flag when the *source* expression (`zero` or `merge`) is nullable. A non-nullable assignment therefore leaves a stale `isNull = true` in place, which produces wrong results in both cases below:
   
   - Within one row: `aggregate(array(CAST(id AS INT) + 1, CAST(id AS INT) + 2), CAST(NULL AS INT), (acc, x) -> coalesce(acc, 0) + x, acc -> coalesce(acc, -1))` returns `-1` under `CODEGEN_ONLY` instead of `3`, while the same query returns `3` under `NO_CODEGEN`. The nullable zero sets `isNull = true`, and the non-nullable merge never clears it.
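   - Across rows in one generated partition: after a row whose merge result becomes null, a following row with an empty array and zero `0` returns `-1` under `CODEGEN_ONLY` instead of `0`, with a finish such as `acc -> coalesce(acc, -1)`. The interpreted path returns `0`. The non-nullable zero does not reset the flag left over from the previous row.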
   
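   Please pass `accForMergeVar.nullable` (instead of `zero.nullable` / `merge.nullable`) to both `initialAssignment` and `mergeAssignment`. Because that flag is always true, every assignment will overwrite `isNull`. Please also add generated-code regression tests for both paths.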
   
   _[ :robot: posted by Codex on behalf of sunchao using the code-review-for-me 
skill :robot: ]_



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to