jaltekruse commented on a change in pull request #27224: [SPARK-30523][SQL] - Collapse nested aggregates URL: https://github.com/apache/spark/pull/27224#discussion_r367202247
########## File path: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala ########## @@ -964,6 +965,155 @@ object CombineFilters extends Rule[LogicalPlan] with PredicateHelper { } } +/** + * Combines two adjacent [[Aggregate]] operators into one, if the first one is not necessary. + * + * If we are referencing the outputs of aggregate functions in the inner aggregate from the outer + * one, check if they are being used in outer aggregates in a way that can be collapsed into a + * single aggregate. A sum of sums, or a max of max, or min of min are all collapsible. + * avg over avg will not be collapsible because different number of raw rows will have contributed + * to the partial averages of the inner aggregate + * + * Min an Max can be folded in the case described above, or if they are referencing + * the group by columns, as they can safely be computed just using the set of + * unique values. + */ +object CombineAggregates extends Rule[LogicalPlan] with PredicateHelper { + + /** + * The aggregate expression list includes both aggregate expressions and + * the projected group by keys, this filters out the aggregate expressions + * in the list leaving just the group by keys. It also unwraps aliases to + * just give a list of the projected grouping expressions themselves. + */ + def justProjectedGroupExprs(aggExprs: Seq[NamedExpression], + groupExprs: Seq[Expression]): Seq[NamedExpression] = { + aggExprs.filter(namedEx => + groupExprs.exists(_.semanticEquals(unwrapAlias(namedEx))) + ) + } + + def unwrapAlias(ex: Expression): Expression = { + if (ex.isInstanceOf[Alias]) ex.children.head + else ex + } + + /** + * Pulls up references to aliases from an earlier operator and replaces them with the + * raw expression they are associated with. + * + * The output name of the original expression is assumed to be the desired final name + * of the rewritten expression, so if necessary an alias is added to ensure the output + * name is correct. 
+ * + * @param ex expression to re-write + * @param aliasMap aliases from the input operator, mapped to their expressions + * @return rewritten expression with intermediate aliases removed + */ + def resolveAliasesMaintainingSchema(ex: NamedExpression, + aliasMap: AttributeMap[Expression]): NamedExpression = { + val ret = replaceAlias(ex, aliasMap) + ret match { + case namedEx: NamedExpression => + if (namedEx.name != ex.name) { + Alias(ret, ex.name)(ex.exprId, ex.qualifier, Some(ex.metadata)) + } else { + namedEx + } + case _ => Alias(ret, ex.name)(ex.exprId, ex.qualifier, Some(ex.metadata)) + } + } + + def collapseIntoOneAggregate(aggExprs: Seq[NamedExpression], + groupExprs: Seq[Expression], + childAgg: Aggregate): Aggregate = { + + val aliasMap = AttributeMap(childAgg.aggregateExpressions.collect { + case a: Alias => (a.toAttribute, a.child) + }) + val aliasResolvedAggExprs = aggExprs.map(resolveAliasesMaintainingSchema(_, aliasMap)) + val aliasResolvedGroupExprs = groupExprs.map(ex => replaceAlias(ex, aliasMap)) + Aggregate(aliasResolvedGroupExprs, aliasResolvedAggExprs, childAgg.child) + } + + def apply(plan: LogicalPlan): LogicalPlan = { + plan transform { + // The query execution/optimization does not guarantee the expressions are evaluated in order. + // We only can combine them if and only if both are deterministic. + case agg@Aggregate(groupExprs: Seq[Expression], + projectionsOfAggregateNode: Seq[NamedExpression], + childAgg@Aggregate(childGroupExprs, childAggExprs, grandChild)) => + + var collapsible = true + val collapsibleNestedAggs = projectionsOfAggregateNode.map(aggEx => { + // Don't need to rewrite the projected grouping key expressions, but want to maintain them + // in the list, so identify them early and keep them the same. 
+ // + // The or condition is kind of a hack for "early exit" of this loop, so later iterations + // don't overwrite the value to again declare this collapsible + if (groupExprs.exists(ex => ex.semanticEquals(aggEx)) || !collapsible) aggEx + else { + aggEx match { + case a@Alias(outerAggExpr: AggregateExpression, _) => + outerAggExpr.aggregateFunction match { + case _: Max | _: Min | _: Sum => + // look for the expressions in the inner aggregate that produce + // the columns used by the outer aggregate + val resolvedInnerExprs = childAggExprs.filter( + ex => outerAggExpr.references.exists(_.name == ex.name)) + // this rule only handles cases where outer expressions are simple + // and reference a single column from the inner aggregate. Review comment: I don't believe there are many such cases that would be collapsible; there is one case along these lines in the tests, with a comment noting that it should be collapsible. Are there particular cases you have in mind? ``` // I think this could be collapsed, I would need to make the rule smarter to understand // that sums can be added together with the standard scalar expression addition operator test("Cannot collapse nested Agg references, complex outer agg expression with sums") { val query = """|SELECT sum(sumEarnings + sumYear), instructor from ( | select sum(earnings) as sumEarnings, sum(year) as sumYear, course, instructor FROM courseSalesWider GROUP BY course, instructor |) group by instructor """.stripMargin ``` ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at: us...@infra.apache.org With regards, Apache Git Services --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org