cloud-fan commented on a change in pull request #27627:
URL: https://github.com/apache/spark/pull/27627#discussion_r416407135



##########
File path: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
##########
@@ -62,38 +62,113 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast
 
   private lazy val sum = AttributeReference("sum", sumDataType)()
 
+  private lazy val isEmptyOrNulls = AttributeReference("isEmptyOrNulls", BooleanType, false)()
+
   private lazy val zero = Literal.default(sumDataType)
 
-  override lazy val aggBufferAttributes = sum :: Nil
+  override lazy val aggBufferAttributes = sum :: isEmptyOrNulls :: Nil
 
   override lazy val initialValues: Seq[Expression] = Seq(
-    /* sum = */ Literal.create(null, sumDataType)
+    /* sum = */  zero,
+    /* isEmptyOrNulls = */ Literal.create(true, BooleanType)
   )
 
+  /**
+   * For decimal types and when child is nullable:
+   * isEmptyOrNulls flag is a boolean to represent if there are no rows or if all rows that
+   * have been seen are null.  This will be used to identify if the end result of sum in
+   * evaluateExpression should be null or not.
+   *
+   * Update of the isEmptyOrNulls flag:
+   * If this flag is false, then keep it as is.
+   * If this flag is true, then check if the incoming value is null and if it is null, keep it
+   * as true else update it to false.
+   * Once this flag is switched to false, it will remain false.
+   *
+   * The update of the sum is as follows:
+   * If sum is null, then we have a case of overflow, so keep sum as is.
+   * If sum is not null, and the incoming value is not null, then perform the addition along
+   * with the overflow checking. Note, that if overflow occurs, then sum will be null here.
+   * If the new incoming value is null, we will keep the sum in buffer as is and skip this
+   * incoming null
+   */
   override lazy val updateExpressions: Seq[Expression] = {
     if (child.nullable) {
-      Seq(
-        /* sum = */
-        coalesce(coalesce(sum, zero) + child.cast(sumDataType), sum)
-      )
+      resultType match {
+        case d: DecimalType =>
+          Seq(
+            /* sum */
+            If(IsNull(sum), sum,
+              If(IsNotNull(child.cast(sumDataType)),
+                CheckOverflow(sum + child.cast(sumDataType), d, true), sum)),
+            /* isEmptyOrNulls */
+            If(isEmptyOrNulls, IsNull(child.cast(sumDataType)), isEmptyOrNulls)
+          )
+        case _ =>
+          Seq(
+            coalesce(sum + child.cast(sumDataType), sum),
+            If(isEmptyOrNulls, IsNull(child.cast(sumDataType)), isEmptyOrNulls)
+          )
+      }
     } else {
-      Seq(
-        /* sum = */
-        coalesce(sum, zero) + child.cast(sumDataType)
-      )
+      resultType match {
+        case d: DecimalType =>
+          Seq(
+            /* sum */
+            If(IsNull(sum), sum, CheckOverflow(sum + child.cast(sumDataType), d, true)),
+            /* isEmptyOrNulls */
+            false
+          )
+        case _ => Seq(sum + child.cast(sumDataType), false)
+      }
     }
   }
 
+  /**
+   * For decimal type:
+   * update of the sum is as follows:
+   * Check if either portion of the left.sum or right.sum has overflowed

Review comment:
       We should explain how overflow is detected: `sum` is null while `isEmpty` is false.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org



---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to