cloud-fan commented on a change in pull request #27627:
URL: https://github.com/apache/spark/pull/27627#discussion_r416408564



##########
File path: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
##########
@@ -62,38 +62,113 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast
 
   private lazy val sum = AttributeReference("sum", sumDataType)()
 
+  private lazy val isEmptyOrNulls = AttributeReference("isEmptyOrNulls", BooleanType, false)()
+
   private lazy val zero = Literal.default(sumDataType)
 
-  override lazy val aggBufferAttributes = sum :: Nil
+  override lazy val aggBufferAttributes = sum :: isEmptyOrNulls :: Nil
 
   override lazy val initialValues: Seq[Expression] = Seq(
-    /* sum = */ Literal.create(null, sumDataType)
+    /* sum = */  zero,
+    /* isEmptyOrNulls = */ Literal.create(true, BooleanType)
   )
 
+  /**
+   * For decimal types and when child is nullable:
+   * isEmptyOrNulls flag is a boolean to represent if there are no rows or if all rows that
+   * have been seen are null. This will be used to identify if the end result of sum in
+   * evaluateExpression should be null or not.
+   *
+   * Update of the isEmptyOrNulls flag:
+   * If this flag is false, then keep it as is.
+   * If this flag is true, then check if the incoming value is null and if it is null, keep it
+   * as true else update it to false.
+   * Once this flag is switched to false, it will remain false.
+   *
+   * The update of the sum is as follows:
+   * If sum is null, then we have a case of overflow, so keep sum as is.
+   * If sum is not null, and the incoming value is not null, then perform the addition along
+   * with the overflow checking. Note that if overflow occurs, then sum will be null here.
+   * If the new incoming value is null, we will keep the sum in buffer as is and skip this
+   * incoming null.
+   */
   override lazy val updateExpressions: Seq[Expression] = {
     if (child.nullable) {
-      Seq(
-        /* sum = */
-        coalesce(coalesce(sum, zero) + child.cast(sumDataType), sum)
-      )
+      resultType match {
+        case d: DecimalType =>
+          Seq(
+            /* sum */
+            If(IsNull(sum), sum,
+              If(IsNotNull(child.cast(sumDataType)),
+                CheckOverflow(sum + child.cast(sumDataType), d, true), sum)),
+            /* isEmptyOrNulls */
+            If(isEmptyOrNulls, IsNull(child.cast(sumDataType)), isEmptyOrNulls)
+          )
+        case _ =>
+          Seq(
+            coalesce(sum + child.cast(sumDataType), sum),
+            If(isEmptyOrNulls, IsNull(child.cast(sumDataType)), isEmptyOrNulls)
+          )
+      }
     } else {
-      Seq(
-        /* sum = */
-        coalesce(sum, zero) + child.cast(sumDataType)
-      )
+      resultType match {
+        case d: DecimalType =>
+          Seq(
+            /* sum */
+            If(IsNull(sum), sum, CheckOverflow(sum + child.cast(sumDataType), d, true)),
+            /* isEmptyOrNulls */
+            false
+          )
+        case _ => Seq(sum + child.cast(sumDataType), false)
+      }
     }
   }
 
+  /**
+   * For decimal type:
+   * update of the sum is as follows:
+   * Check if either left.sum or right.sum has overflowed.
+   * If it has, then the sum value will remain null.
+   * If it did not overflow, then add sum.left and sum.right and check for overflow.
+   *
+   * isEmptyOrNulls: set to false if either the left or the right flag is false. This
+   * means we have seen at least one row that was not null.
+   * If the values from bufferLeft and bufferRight are both true, then this will be true.
+   */
   override lazy val mergeExpressions: Seq[Expression] = {
-    Seq(
-      /* sum = */
-      coalesce(coalesce(sum.left, zero) + sum.right, sum.left)
-    )
+    resultType match {
+      case d: DecimalType =>
+        Seq(
+          /* sum = */
+          If(And(IsNull(sum.left), EqualTo(isEmptyOrNulls.left, false)) ||
+            And(IsNull(sum.right), EqualTo(isEmptyOrNulls.right, false)),
+              Literal.create(null, resultType),
+              CheckOverflow(sum.left + sum.right, d, true)),
+          /* isEmptyOrNulls = */
+          And(isEmptyOrNulls.left, isEmptyOrNulls.right)
+          )
+      case _ =>
+        Seq(
+          coalesce(sum.left + sum.right, sum.left),
+          And(isEmptyOrNulls.left, isEmptyOrNulls.right)
+        )
+    }
   }
 
+  /**
+   * If isEmptyOrNulls is true, either there are no rows or all the rows were null, so the
+   * result will be null.
+   * If isEmptyOrNulls is false and sum is null, an overflow has happened. In that case, throw
+   * an exception if ANSI mode is enabled; otherwise return null.
+   * If sum is not null, then return the sum.

Review comment:
       If we don't check overflow at https://github.com/apache/spark/pull/27627/files#r416407527, we can just use `CheckOverflow` here, which respects the ANSI flag.
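The following is a minimal sketch of the trade-off behind this suggestion, written as plain Scala outside Spark. All of its names are hypothetical stand-ins and not Catalyst's `CheckOverflow` or `SQLConf`: `DecimalSumSketch`, `Buffer`, `checkOverflow`, `updateUnchecked`, and the `ansiEnabled` parameter. `None` models SQL NULL. "Overflow" is assumed to mean that the value, after rounding to the target scale, needs more digits than the target precision allows.

```scala
import java.math.{BigDecimal => JBigDecimal, RoundingMode}

// Hypothetical, Spark-free model of the decimal Sum buffer discussed above.
// None stands in for SQL NULL. This is not Catalyst's API.
object DecimalSumSketch {
  // Assumed result type DECIMAL(precision, scale).
  final case class DecimalType(precision: Int, scale: Int)

  // The two buffer slots from the diff: sum and isEmptyOrNulls.
  final case class Buffer(sum: Option[JBigDecimal], isEmptyOrNulls: Boolean)

  // Initial values as in the diff: sum = 0, isEmptyOrNulls = true.
  val initial: Buffer = Buffer(Some(JBigDecimal.ZERO), isEmptyOrNulls = true)

  // Stand-in for CheckOverflow: round to the target scale, then treat more
  // digits than the target precision as overflow. nullOnOverflow = true
  // yields None; false throws, which models ANSI mode.
  def checkOverflow(v: JBigDecimal, dt: DecimalType, nullOnOverflow: Boolean): Option[JBigDecimal] = {
    val rounded = v.setScale(dt.scale, RoundingMode.HALF_UP)
    if (rounded.precision() <= dt.precision) Some(rounded)
    else if (nullOnOverflow) None
    else throw new ArithmeticException(
      s"$v cannot be represented as DECIMAL(${dt.precision}, ${dt.scale})")
  }

  // Flag update from the diff: it flips to false on the first non-null input.
  private def nextFlag(flag: Boolean, input: Option[JBigDecimal]): Boolean =
    flag && input.isEmpty

  // Update as proposed in the diff: overflow is checked on every row, and an
  // overflowed (None) sum stays None for the rest of the group.
  def updateCheckedPerRow(buf: Buffer, input: Option[JBigDecimal], dt: DecimalType): Buffer = {
    val newSum = (buf.sum, input) match {
      case (Some(s), Some(x)) => checkOverflow(s.add(x), dt, nullOnOverflow = true)
      case _                  => buf.sum // overflowed sum or null input: keep as is
    }
    Buffer(newSum, nextFlag(buf.isEmptyOrNulls, input))
  }

  // Update under the review suggestion: plain addition, no overflow check.
  def updateUnchecked(buf: Buffer, input: Option[JBigDecimal]): Buffer = {
    val newSum = (buf.sum, input) match {
      case (Some(s), Some(x)) => Some(s.add(x))
      case _                  => buf.sum
    }
    Buffer(newSum, nextFlag(buf.isEmptyOrNulls, input))
  }

  // Merge for the unchecked variant: add the partial sums and AND the flags.
  def mergeUnchecked(left: Buffer, right: Buffer): Buffer = {
    val sum = for (a <- left.sum; b <- right.sum) yield a.add(b)
    Buffer(sum, left.isEmptyOrNulls && right.isEmptyOrNulls)
  }

  // Evaluate under the review suggestion: one overflow check at the end that
  // returns NULL or throws depending on the (assumed) ANSI flag. Intended for
  // unchecked buffers; a per-row-checked buffer that already overflowed is
  // None here and would yield NULL even with ANSI on.
  def evaluate(buf: Buffer, dt: DecimalType, ansiEnabled: Boolean): Option[JBigDecimal] =
    if (buf.isEmptyOrNulls) None
    else buf.sum.flatMap(s => checkOverflow(s, dt, nullOnOverflow = !ansiEnabled))
}
```

With per-row checking, an overflowed sum is already `None` by the time `evaluate` runs. Evaluation therefore needs its own ANSI branch, which is the one the doc comment above describes. Deferring the check lets a single `checkOverflow` call choose between NULL and an exception.

The deferred check also sees only the final total. In this model, a group such as 600 + 500 + (-500) in DECIMAL(5, 2) therefore yields 600.00 instead of NULL. The sketch does not model whether Spark's aggregation buffer can actually hold such an out-of-range intermediate value.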




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


