cloud-fan commented on a change in pull request #27627: URL: https://github.com/apache/spark/pull/27627#discussion_r416408564
########## File path: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala ########## @@ -62,38 +62,113 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast private lazy val sum = AttributeReference("sum", sumDataType)() + private lazy val isEmptyOrNulls = AttributeReference("isEmptyOrNulls", BooleanType, false)() + private lazy val zero = Literal.default(sumDataType) - override lazy val aggBufferAttributes = sum :: Nil + override lazy val aggBufferAttributes = sum :: isEmptyOrNulls :: Nil override lazy val initialValues: Seq[Expression] = Seq( - /* sum = */ Literal.create(null, sumDataType) + /* sum = */ zero, + /* isEmptyOrNulls = */ Literal.create(true, BooleanType) ) + /** + * For decimal types and when child is nullable: + * isEmptyOrNulls flag is a boolean to represent if there are no rows or if all rows that + * have been seen are null. This will be used to identify if the end result of sum in + * evaluateExpression should be null or not. + * + * Update of the isEmptyOrNulls flag: + * If this flag is false, then keep it as is. + * If this flag is true, then check if the incoming value is null and if it is null, keep it + * as true else update it to false. + * Once this flag is switched to false, it will remain false. + * + * The update of the sum is as follows: + * If sum is null, then we have a case of overflow, so keep sum as is. + * If sum is not null, and the incoming value is not null, then perform the addition along + * with the overflow checking. Note, that if overflow occurs, then sum will be null here. 
+ * If the new incoming value is null, we will keep the sum in buffer as is and skip this + * incoming null + */ override lazy val updateExpressions: Seq[Expression] = { if (child.nullable) { - Seq( - /* sum = */ - coalesce(coalesce(sum, zero) + child.cast(sumDataType), sum) - ) + resultType match { + case d: DecimalType => + Seq( + /* sum */ + If(IsNull(sum), sum, + If(IsNotNull(child.cast(sumDataType)), + CheckOverflow(sum + child.cast(sumDataType), d, true), sum)), + /* isEmptyOrNulls */ + If(isEmptyOrNulls, IsNull(child.cast(sumDataType)), isEmptyOrNulls) + ) + case _ => + Seq( + coalesce(sum + child.cast(sumDataType), sum), + If(isEmptyOrNulls, IsNull(child.cast(sumDataType)), isEmptyOrNulls) + ) + } } else { - Seq( - /* sum = */ - coalesce(sum, zero) + child.cast(sumDataType) - ) + resultType match { + case d: DecimalType => + Seq( + /* sum */ + If(IsNull(sum), sum, CheckOverflow(sum + child.cast(sumDataType), d, true)), + /* isEmptyOrNulls */ + false + ) + case _ => Seq(sum + child.cast(sumDataType), false) + } } } + /** + * For decimal type: + * update of the sum is as follows: + * Check if either portion of the left.sum or right.sum has overflowed + * If it has, then the sum value will remain null. + * If it did not have overflow, then add the sum.left and sum.right and check for overflow. + * + * isEmptyOrNulls: Set to false if either one of the left or right is set to false. This + * means we have seen at least a row that was not null. + * If the values from bufferLeft and bufferRight are both true, then this will be true.
+ */ override lazy val mergeExpressions: Seq[Expression] = { - Seq( - /* sum = */ - coalesce(coalesce(sum.left, zero) + sum.right, sum.left) - ) + resultType match { + case d: DecimalType => + Seq( + /* sum = */ + If(And(IsNull(sum.left), EqualTo(isEmptyOrNulls.left, false)) || + And(IsNull(sum.right), EqualTo(isEmptyOrNulls.right, false)), + Literal.create(null, resultType), + CheckOverflow(sum.left + sum.right, d, true)), + /* isEmptyOrNulls = */ + And(isEmptyOrNulls.left, isEmptyOrNulls.right) + ) + case _ => + Seq( + coalesce(sum.left + sum.right, sum.left), + And(isEmptyOrNulls.left, isEmptyOrNulls.right) + ) + } } + /** + * If the isEmptyOrNulls is true, then it means either there are no rows, or all the rows were + * null, so the result will be null. + * If the isEmptyOrNulls is false, then if sum is null that means an overflow has happened. + * So now, if ansi is enabled, then throw exception, if not then return null. + * If sum is not null, then return the sum. Review comment: If we don't check overflow at https://github.com/apache/spark/pull/27627/files#r416407527 , we can just use `CheckOverflow` here, which respects the ansi flag. ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org