maropu commented on a change in pull request #32470: URL: https://github.com/apache/spark/pull/32470#discussion_r649623086
########## File path: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala ########## @@ -2457,164 +2450,127 @@ class Analyzer(override val catalogManager: CatalogManager) _.containsPattern(AGGREGATE), ruleId) { // Resolve aggregate with having clause to Filter(..., Aggregate()). Note, to avoid wrongly // resolve the having condition expression, here we skip resolving it in ResolveReferences - // and transform it to Filter after aggregate is resolved. See more details in SPARK-31519. + // and transform it to Filter after aggregate is resolved. Basically columns in HAVING should + // be resolved with `agg.child.output` first. See more details in SPARK-31519. case UnresolvedHaving(cond, agg: Aggregate) if agg.resolved => - resolveHaving(Filter(cond, agg), agg) - - case f @ Filter(_, agg: Aggregate) if agg.resolved => - resolveHaving(f, agg) - - case sort @ Sort(sortOrder, global, aggregate: Aggregate) if aggregate.resolved => - - // Try resolving the ordering as though it is in the aggregate clause. - try { - // If a sort order is unresolved, containing references not in aggregate, or containing - // `AggregateExpression`, we need to push down it to the underlying aggregate operator. - val unresolvedSortOrders = sortOrder.filter { s => - !s.resolved || !s.references.subsetOf(aggregate.outputSet) || containsAggregate(s) + resolveOperatorWithAggregate(Seq(cond), agg, (newExprs, newChild) => { + Filter(newExprs.head, newChild) + }) + + case Filter(cond, agg: Aggregate) if agg.resolved => + // We should resolve the references normally based on child.output first. + val maybeResolved = resolveExpressionByPlanOutput(cond, agg) + resolveOperatorWithAggregate(Seq(maybeResolved), agg, (newExprs, newChild) => { + Filter(newExprs.head, newChild) + }) + + case Sort(sortOrder, global, agg: Aggregate) if agg.resolved => + // We should resolve the references normally based on child.output first. 
+ val maybeResolved = sortOrder.map(_.child).map(resolveExpressionByPlanOutput(_, agg)) + resolveOperatorWithAggregate(maybeResolved, agg, (newExprs, newChild) => { + val newSortOrder = sortOrder.zip(newExprs).map { + case (sortOrder, expr) => sortOrder.copy(child = expr) } - val aliasedOrdering = unresolvedSortOrders.map(o => Alias(o.child, "aggOrder")()) - - val aggregateWithExtraOrdering = aggregate.copy( - aggregateExpressions = aggregate.aggregateExpressions ++ aliasedOrdering) - - val resolvedAggregate: Aggregate = - executeSameContext(aggregateWithExtraOrdering).asInstanceOf[Aggregate] - - val (reResolvedAggExprs, resolvedAliasedOrdering) = - resolvedAggregate.aggregateExpressions.splitAt(aggregate.aggregateExpressions.length) - - // If we pass the analysis check, then the ordering expressions should only reference to - // aggregate expressions or grouping expressions, and it's safe to push them down to - // Aggregate. - checkAnalysis(resolvedAggregate) - - val originalAggExprs = aggregate.aggregateExpressions.map(trimNonTopLevelAliases) - - // If the ordering expression is same with original aggregate expression, we don't need - // to push down this ordering expression and can reference the original aggregate - // expression instead. 
- val needsPushDown = ArrayBuffer.empty[NamedExpression] - val orderToAlias = unresolvedSortOrders.zip(aliasedOrdering) - val evaluatedOrderings = - resolvedAliasedOrdering.asInstanceOf[Seq[Alias]].zip(orderToAlias).map { - case (evaluated, (order, aliasOrder)) => - val index = reResolvedAggExprs.indexWhere { - case Alias(child, _) => child semanticEquals evaluated.child - case other => other semanticEquals evaluated.child - } - - if (index == -1) { - if (hasCharVarchar(evaluated)) { - needsPushDown += aliasOrder - order.copy(child = aliasOrder) - } else { - needsPushDown += evaluated - order.copy(child = evaluated.toAttribute) - } - } else { - order.copy(child = originalAggExprs(index).toAttribute) - } + Sort(newSortOrder, global, newChild) + }) + } + + def resolveExprsWithAggregate( + exprs: Seq[Expression], + agg: Aggregate): (Seq[NamedExpression], Seq[Expression]) = { Review comment: How about leaving some comments about what the two returned values are? ########## File path: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala ########## @@ -2457,164 +2450,127 @@ class Analyzer(override val catalogManager: CatalogManager) _.containsPattern(AGGREGATE), ruleId) { // Resolve aggregate with having clause to Filter(..., Aggregate()). Note, to avoid wrongly // resolve the having condition expression, here we skip resolving it in ResolveReferences - // and transform it to Filter after aggregate is resolved. See more details in SPARK-31519. + // and transform it to Filter after aggregate is resolved. Basically columns in HAVING should + // be resolved with `agg.child.output` first. See more details in SPARK-31519.
case UnresolvedHaving(cond, agg: Aggregate) if agg.resolved => - resolveHaving(Filter(cond, agg), agg) - - case f @ Filter(_, agg: Aggregate) if agg.resolved => - resolveHaving(f, agg) - - case sort @ Sort(sortOrder, global, aggregate: Aggregate) if aggregate.resolved => - - // Try resolving the ordering as though it is in the aggregate clause. - try { - // If a sort order is unresolved, containing references not in aggregate, or containing - // `AggregateExpression`, we need to push down it to the underlying aggregate operator. - val unresolvedSortOrders = sortOrder.filter { s => - !s.resolved || !s.references.subsetOf(aggregate.outputSet) || containsAggregate(s) + resolveOperatorWithAggregate(Seq(cond), agg, (newExprs, newChild) => { + Filter(newExprs.head, newChild) + }) + + case Filter(cond, agg: Aggregate) if agg.resolved => + // We should resolve the references normally based on child.output first. + val maybeResolved = resolveExpressionByPlanOutput(cond, agg) + resolveOperatorWithAggregate(Seq(maybeResolved), agg, (newExprs, newChild) => { + Filter(newExprs.head, newChild) + }) + + case Sort(sortOrder, global, agg: Aggregate) if agg.resolved => + // We should resolve the references normally based on child.output first. 
+ val maybeResolved = sortOrder.map(_.child).map(resolveExpressionByPlanOutput(_, agg)) + resolveOperatorWithAggregate(maybeResolved, agg, (newExprs, newChild) => { + val newSortOrder = sortOrder.zip(newExprs).map { + case (sortOrder, expr) => sortOrder.copy(child = expr) } - val aliasedOrdering = unresolvedSortOrders.map(o => Alias(o.child, "aggOrder")()) - - val aggregateWithExtraOrdering = aggregate.copy( - aggregateExpressions = aggregate.aggregateExpressions ++ aliasedOrdering) - - val resolvedAggregate: Aggregate = - executeSameContext(aggregateWithExtraOrdering).asInstanceOf[Aggregate] - - val (reResolvedAggExprs, resolvedAliasedOrdering) = - resolvedAggregate.aggregateExpressions.splitAt(aggregate.aggregateExpressions.length) - - // If we pass the analysis check, then the ordering expressions should only reference to - // aggregate expressions or grouping expressions, and it's safe to push them down to - // Aggregate. - checkAnalysis(resolvedAggregate) - - val originalAggExprs = aggregate.aggregateExpressions.map(trimNonTopLevelAliases) - - // If the ordering expression is same with original aggregate expression, we don't need - // to push down this ordering expression and can reference the original aggregate - // expression instead. 
- val needsPushDown = ArrayBuffer.empty[NamedExpression] - val orderToAlias = unresolvedSortOrders.zip(aliasedOrdering) - val evaluatedOrderings = - resolvedAliasedOrdering.asInstanceOf[Seq[Alias]].zip(orderToAlias).map { - case (evaluated, (order, aliasOrder)) => - val index = reResolvedAggExprs.indexWhere { - case Alias(child, _) => child semanticEquals evaluated.child - case other => other semanticEquals evaluated.child - } - - if (index == -1) { - if (hasCharVarchar(evaluated)) { - needsPushDown += aliasOrder - order.copy(child = aliasOrder) - } else { - needsPushDown += evaluated - order.copy(child = evaluated.toAttribute) - } - } else { - order.copy(child = originalAggExprs(index).toAttribute) - } + Sort(newSortOrder, global, newChild) + }) + } + + def resolveExprsWithAggregate( + exprs: Seq[Expression], + agg: Aggregate): (Seq[NamedExpression], Seq[Expression]) = { + val aggregateExpressions = ArrayBuffer.empty[NamedExpression] Review comment: nit: how about `aggregateExpressions` -> `extraAggExprs`? ########## File path: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala ########## @@ -2457,164 +2450,127 @@ class Analyzer(override val catalogManager: CatalogManager) _.containsPattern(AGGREGATE), ruleId) { // Resolve aggregate with having clause to Filter(..., Aggregate()). Note, to avoid wrongly // resolve the having condition expression, here we skip resolving it in ResolveReferences - // and transform it to Filter after aggregate is resolved. See more details in SPARK-31519. + // and transform it to Filter after aggregate is resolved. Basically columns in HAVING should + // be resolved with `agg.child.output` first. See more details in SPARK-31519. 
case UnresolvedHaving(cond, agg: Aggregate) if agg.resolved => - resolveHaving(Filter(cond, agg), agg) - - case f @ Filter(_, agg: Aggregate) if agg.resolved => - resolveHaving(f, agg) - - case sort @ Sort(sortOrder, global, aggregate: Aggregate) if aggregate.resolved => - - // Try resolving the ordering as though it is in the aggregate clause. - try { - // If a sort order is unresolved, containing references not in aggregate, or containing - // `AggregateExpression`, we need to push down it to the underlying aggregate operator. - val unresolvedSortOrders = sortOrder.filter { s => - !s.resolved || !s.references.subsetOf(aggregate.outputSet) || containsAggregate(s) + resolveOperatorWithAggregate(Seq(cond), agg, (newExprs, newChild) => { + Filter(newExprs.head, newChild) + }) + + case Filter(cond, agg: Aggregate) if agg.resolved => + // We should resolve the references normally based on child.output first. + val maybeResolved = resolveExpressionByPlanOutput(cond, agg) + resolveOperatorWithAggregate(Seq(maybeResolved), agg, (newExprs, newChild) => { + Filter(newExprs.head, newChild) + }) + + case Sort(sortOrder, global, agg: Aggregate) if agg.resolved => + // We should resolve the references normally based on child.output first. 
+ val maybeResolved = sortOrder.map(_.child).map(resolveExpressionByPlanOutput(_, agg)) + resolveOperatorWithAggregate(maybeResolved, agg, (newExprs, newChild) => { + val newSortOrder = sortOrder.zip(newExprs).map { + case (sortOrder, expr) => sortOrder.copy(child = expr) } - val aliasedOrdering = unresolvedSortOrders.map(o => Alias(o.child, "aggOrder")()) - - val aggregateWithExtraOrdering = aggregate.copy( - aggregateExpressions = aggregate.aggregateExpressions ++ aliasedOrdering) - - val resolvedAggregate: Aggregate = - executeSameContext(aggregateWithExtraOrdering).asInstanceOf[Aggregate] - - val (reResolvedAggExprs, resolvedAliasedOrdering) = - resolvedAggregate.aggregateExpressions.splitAt(aggregate.aggregateExpressions.length) - - // If we pass the analysis check, then the ordering expressions should only reference to - // aggregate expressions or grouping expressions, and it's safe to push them down to - // Aggregate. - checkAnalysis(resolvedAggregate) - - val originalAggExprs = aggregate.aggregateExpressions.map(trimNonTopLevelAliases) - - // If the ordering expression is same with original aggregate expression, we don't need - // to push down this ordering expression and can reference the original aggregate - // expression instead. 
- val needsPushDown = ArrayBuffer.empty[NamedExpression] - val orderToAlias = unresolvedSortOrders.zip(aliasedOrdering) - val evaluatedOrderings = - resolvedAliasedOrdering.asInstanceOf[Seq[Alias]].zip(orderToAlias).map { - case (evaluated, (order, aliasOrder)) => - val index = reResolvedAggExprs.indexWhere { - case Alias(child, _) => child semanticEquals evaluated.child - case other => other semanticEquals evaluated.child - } - - if (index == -1) { - if (hasCharVarchar(evaluated)) { - needsPushDown += aliasOrder - order.copy(child = aliasOrder) - } else { - needsPushDown += evaluated - order.copy(child = evaluated.toAttribute) - } - } else { - order.copy(child = originalAggExprs(index).toAttribute) - } + Sort(newSortOrder, global, newChild) + }) + } + + def resolveExprsWithAggregate( + exprs: Seq[Expression], + agg: Aggregate): (Seq[NamedExpression], Seq[Expression]) = { + val aggregateExpressions = ArrayBuffer.empty[NamedExpression] + val transformed = exprs.map { e => + // Try resolving the expression as though it is in the aggregate clause. 
+ def resolveCol(input: Expression): Expression = { + resolveExpressionByPlanOutput(input, agg.child) + } + def resolveSubQuery(input: Expression): Expression = { + if (SubqueryExpression.hasSubquery(input)) { + val fake = Project(Alias(input, "fake")() :: Nil, agg.child) + ResolveSubquery(fake).asInstanceOf[Project].projectList.head.asInstanceOf[Alias].child + } else { + input } - - val sortOrdersMap = unresolvedSortOrders - .map(new TreeNodeRef(_)) - .zip(evaluatedOrderings) - .toMap - val finalSortOrders = sortOrder.map(s => sortOrdersMap.getOrElse(new TreeNodeRef(s), s)) - - // Since we don't rely on sort.resolved as the stop condition for this rule, - // we need to check this and prevent applying this rule multiple times - if (sortOrder == finalSortOrders) { - sort + } + val maybeResolved = resolveSubQuery(resolveCol(e)) + if (maybeResolved.resolved && maybeResolved.references.subsetOf(agg.outputSet) && + !containsAggregate(maybeResolved)) { + // The given expression is valid and doesn't need extra resolution. + maybeResolved + } else if (containsUnresolvedFunc(maybeResolved)) { + // The given expression has unresolved functions which may be aggregate functions and we + // need to wait for other rules to resolve the functions first. + maybeResolved + } else { + // Avoid adding an extra aggregate expression if it's already present in + // `agg.aggregateExpressions`. + val index = if (maybeResolved.resolved) { + agg.aggregateExpressions.indexWhere { + case Alias(child, _) => child semanticEquals maybeResolved + case other => other semanticEquals maybeResolved + } } else { - Project(aggregate.output, - Sort(finalSortOrders, global, - aggregate.copy(aggregateExpressions = originalAggExprs ++ needsPushDown))) + -1 + } + if (index >= 0) { + agg.aggregateExpressions(index).toAttribute + } else { + buildAggExprList(maybeResolved, agg, aggregateExpressions) } - } catch { - // Attempting to resolve in the aggregate can result in ambiguity. 
When this happens, - // just return the original plan. - case ae: AnalysisException => sort } + } + (aggregateExpressions.toSeq, transformed) } - def hasCharVarchar(expr: Alias): Boolean = { - expr.find { - case ne: NamedExpression => CharVarcharUtils.getRawType(ne.metadata).nonEmpty - case _ => false - }.nonEmpty + private def buildAggExprList( + expr: Expression, + agg: Aggregate, + aggExprList: ArrayBuffer[NamedExpression]): Expression = expr match { + case ae: AggregateExpression if ae.resolved => + val alias = Alias(ae, ae.toString)() + aggExprList += alias + alias.toAttribute + // Grouping functions are handled in the rule [[ResolveGroupingAnalytics]]. + case grouping: Expression if grouping.resolved && + agg.groupingExpressions.exists(_.semanticEquals(grouping)) && + !ResolveGroupingAnalytics.hasGroupingFunction(grouping) && + !agg.output.exists(_.semanticEquals(grouping)) => + grouping match { + case ne: NamedExpression => + aggExprList += ne + ne.toAttribute + case _ => + val alias = Alias(grouping, grouping.toString)() + aggExprList += alias + alias.toAttribute + } + case a: Attribute if agg.child.outputSet.contains(a) && !agg.outputSet.contains(a) => + // Undo the resolution. This attribute is neither inside aggregate functions nor a + // grouping column. It shouldn't be resolved with `agg.child.output`. + CurrentOrigin.withOrigin(a.origin)(UnresolvedAttribute(Seq(a.name))) Review comment: Just out of curiosity; what's this handling for? ########## File path: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala ########## @@ -623,28 +622,26 @@ class Analyzer(override val catalogManager: CatalogManager) // groupingExpressions for condition resolving. 
val aggForResolving = aggregate.copy(groupingExpressions = groupByExprs) // Try resolving the condition of the filter as though it is in the aggregate clause - val resolvedInfo = - ResolveAggregateFunctions.resolveFilterCondInAggregate(h.havingCondition, aggForResolving) + val (extraAggExprs, Seq(afterResolve)) = Review comment: nit: how about `afterResolve` -> `resolvedHavingCond`? ########## File path: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala ########## @@ -2457,164 +2450,127 @@ class Analyzer(override val catalogManager: CatalogManager) _.containsPattern(AGGREGATE), ruleId) { // Resolve aggregate with having clause to Filter(..., Aggregate()). Note, to avoid wrongly // resolve the having condition expression, here we skip resolving it in ResolveReferences - // and transform it to Filter after aggregate is resolved. See more details in SPARK-31519. + // and transform it to Filter after aggregate is resolved. Basically columns in HAVING should + // be resolved with `agg.child.output` first. See more details in SPARK-31519. case UnresolvedHaving(cond, agg: Aggregate) if agg.resolved => - resolveHaving(Filter(cond, agg), agg) - - case f @ Filter(_, agg: Aggregate) if agg.resolved => - resolveHaving(f, agg) - - case sort @ Sort(sortOrder, global, aggregate: Aggregate) if aggregate.resolved => - - // Try resolving the ordering as though it is in the aggregate clause. - try { - // If a sort order is unresolved, containing references not in aggregate, or containing - // `AggregateExpression`, we need to push down it to the underlying aggregate operator. - val unresolvedSortOrders = sortOrder.filter { s => - !s.resolved || !s.references.subsetOf(aggregate.outputSet) || containsAggregate(s) + resolveOperatorWithAggregate(Seq(cond), agg, (newExprs, newChild) => { + Filter(newExprs.head, newChild) + }) + + case Filter(cond, agg: Aggregate) if agg.resolved => + // We should resolve the references normally based on child.output first. 
+ val maybeResolved = resolveExpressionByPlanOutput(cond, agg) + resolveOperatorWithAggregate(Seq(maybeResolved), agg, (newExprs, newChild) => { + Filter(newExprs.head, newChild) + }) + + case Sort(sortOrder, global, agg: Aggregate) if agg.resolved => + // We should resolve the references normally based on child.output first. + val maybeResolved = sortOrder.map(_.child).map(resolveExpressionByPlanOutput(_, agg)) + resolveOperatorWithAggregate(maybeResolved, agg, (newExprs, newChild) => { + val newSortOrder = sortOrder.zip(newExprs).map { + case (sortOrder, expr) => sortOrder.copy(child = expr) } - val aliasedOrdering = unresolvedSortOrders.map(o => Alias(o.child, "aggOrder")()) - - val aggregateWithExtraOrdering = aggregate.copy( - aggregateExpressions = aggregate.aggregateExpressions ++ aliasedOrdering) - - val resolvedAggregate: Aggregate = - executeSameContext(aggregateWithExtraOrdering).asInstanceOf[Aggregate] - - val (reResolvedAggExprs, resolvedAliasedOrdering) = - resolvedAggregate.aggregateExpressions.splitAt(aggregate.aggregateExpressions.length) - - // If we pass the analysis check, then the ordering expressions should only reference to - // aggregate expressions or grouping expressions, and it's safe to push them down to - // Aggregate. - checkAnalysis(resolvedAggregate) - - val originalAggExprs = aggregate.aggregateExpressions.map(trimNonTopLevelAliases) - - // If the ordering expression is same with original aggregate expression, we don't need - // to push down this ordering expression and can reference the original aggregate - // expression instead. 
- val needsPushDown = ArrayBuffer.empty[NamedExpression] - val orderToAlias = unresolvedSortOrders.zip(aliasedOrdering) - val evaluatedOrderings = - resolvedAliasedOrdering.asInstanceOf[Seq[Alias]].zip(orderToAlias).map { - case (evaluated, (order, aliasOrder)) => - val index = reResolvedAggExprs.indexWhere { - case Alias(child, _) => child semanticEquals evaluated.child - case other => other semanticEquals evaluated.child - } - - if (index == -1) { - if (hasCharVarchar(evaluated)) { - needsPushDown += aliasOrder - order.copy(child = aliasOrder) - } else { - needsPushDown += evaluated - order.copy(child = evaluated.toAttribute) - } - } else { - order.copy(child = originalAggExprs(index).toAttribute) - } + Sort(newSortOrder, global, newChild) + }) + } + + def resolveExprsWithAggregate( + exprs: Seq[Expression], + agg: Aggregate): (Seq[NamedExpression], Seq[Expression]) = { + val aggregateExpressions = ArrayBuffer.empty[NamedExpression] + val transformed = exprs.map { e => + // Try resolving the expression as though it is in the aggregate clause. 
+ def resolveCol(input: Expression): Expression = { + resolveExpressionByPlanOutput(input, agg.child) + } + def resolveSubQuery(input: Expression): Expression = { + if (SubqueryExpression.hasSubquery(input)) { + val fake = Project(Alias(input, "fake")() :: Nil, agg.child) + ResolveSubquery(fake).asInstanceOf[Project].projectList.head.asInstanceOf[Alias].child + } else { + input } - - val sortOrdersMap = unresolvedSortOrders - .map(new TreeNodeRef(_)) - .zip(evaluatedOrderings) - .toMap - val finalSortOrders = sortOrder.map(s => sortOrdersMap.getOrElse(new TreeNodeRef(s), s)) - - // Since we don't rely on sort.resolved as the stop condition for this rule, - // we need to check this and prevent applying this rule multiple times - if (sortOrder == finalSortOrders) { - sort + } + val maybeResolved = resolveSubQuery(resolveCol(e)) + if (maybeResolved.resolved && maybeResolved.references.subsetOf(agg.outputSet) && + !containsAggregate(maybeResolved)) { + // The given expression is valid and doesn't need extra resolution. + maybeResolved + } else if (containsUnresolvedFunc(maybeResolved)) { + // The given expression has unresolved functions which may be aggregate functions and we + // need to wait for other rules to resolve the functions first. + maybeResolved + } else { + // Avoid adding an extra aggregate expression if it's already present in + // `agg.aggregateExpressions`. + val index = if (maybeResolved.resolved) { + agg.aggregateExpressions.indexWhere { + case Alias(child, _) => child semanticEquals maybeResolved + case other => other semanticEquals maybeResolved + } } else { - Project(aggregate.output, - Sort(finalSortOrders, global, - aggregate.copy(aggregateExpressions = originalAggExprs ++ needsPushDown))) + -1 + } + if (index >= 0) { Review comment: It looks like we don't need this `if` condition?
``` if (maybeResolved.resolved) { val index = agg.aggregateExpressions.indexWhere { case Alias(child, _) => child semanticEquals maybeResolved case other => other semanticEquals maybeResolved } agg.aggregateExpressions(index).toAttribute } else { buildAggExprList(maybeResolved, agg, aggregateExpressions) } ``` ? ########## File path: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala ########## @@ -2457,164 +2450,127 @@ class Analyzer(override val catalogManager: CatalogManager) _.containsPattern(AGGREGATE), ruleId) { // Resolve aggregate with having clause to Filter(..., Aggregate()). Note, to avoid wrongly // resolve the having condition expression, here we skip resolving it in ResolveReferences - // and transform it to Filter after aggregate is resolved. See more details in SPARK-31519. + // and transform it to Filter after aggregate is resolved. Basically columns in HAVING should + // be resolved with `agg.child.output` first. See more details in SPARK-31519. case UnresolvedHaving(cond, agg: Aggregate) if agg.resolved => - resolveHaving(Filter(cond, agg), agg) - - case f @ Filter(_, agg: Aggregate) if agg.resolved => - resolveHaving(f, agg) - - case sort @ Sort(sortOrder, global, aggregate: Aggregate) if aggregate.resolved => - - // Try resolving the ordering as though it is in the aggregate clause. - try { - // If a sort order is unresolved, containing references not in aggregate, or containing - // `AggregateExpression`, we need to push down it to the underlying aggregate operator. - val unresolvedSortOrders = sortOrder.filter { s => - !s.resolved || !s.references.subsetOf(aggregate.outputSet) || containsAggregate(s) + resolveOperatorWithAggregate(Seq(cond), agg, (newExprs, newChild) => { + Filter(newExprs.head, newChild) + }) + + case Filter(cond, agg: Aggregate) if agg.resolved => + // We should resolve the references normally based on child.output first. 
+ val maybeResolved = resolveExpressionByPlanOutput(cond, agg) + resolveOperatorWithAggregate(Seq(maybeResolved), agg, (newExprs, newChild) => { + Filter(newExprs.head, newChild) + }) + + case Sort(sortOrder, global, agg: Aggregate) if agg.resolved => + // We should resolve the references normally based on child.output first. + val maybeResolved = sortOrder.map(_.child).map(resolveExpressionByPlanOutput(_, agg)) + resolveOperatorWithAggregate(maybeResolved, agg, (newExprs, newChild) => { + val newSortOrder = sortOrder.zip(newExprs).map { + case (sortOrder, expr) => sortOrder.copy(child = expr) } - val aliasedOrdering = unresolvedSortOrders.map(o => Alias(o.child, "aggOrder")()) - - val aggregateWithExtraOrdering = aggregate.copy( - aggregateExpressions = aggregate.aggregateExpressions ++ aliasedOrdering) - - val resolvedAggregate: Aggregate = - executeSameContext(aggregateWithExtraOrdering).asInstanceOf[Aggregate] - - val (reResolvedAggExprs, resolvedAliasedOrdering) = - resolvedAggregate.aggregateExpressions.splitAt(aggregate.aggregateExpressions.length) - - // If we pass the analysis check, then the ordering expressions should only reference to - // aggregate expressions or grouping expressions, and it's safe to push them down to - // Aggregate. - checkAnalysis(resolvedAggregate) - - val originalAggExprs = aggregate.aggregateExpressions.map(trimNonTopLevelAliases) - - // If the ordering expression is same with original aggregate expression, we don't need - // to push down this ordering expression and can reference the original aggregate - // expression instead. 
- val needsPushDown = ArrayBuffer.empty[NamedExpression] - val orderToAlias = unresolvedSortOrders.zip(aliasedOrdering) - val evaluatedOrderings = - resolvedAliasedOrdering.asInstanceOf[Seq[Alias]].zip(orderToAlias).map { - case (evaluated, (order, aliasOrder)) => - val index = reResolvedAggExprs.indexWhere { - case Alias(child, _) => child semanticEquals evaluated.child - case other => other semanticEquals evaluated.child - } - - if (index == -1) { - if (hasCharVarchar(evaluated)) { - needsPushDown += aliasOrder - order.copy(child = aliasOrder) - } else { - needsPushDown += evaluated - order.copy(child = evaluated.toAttribute) - } - } else { - order.copy(child = originalAggExprs(index).toAttribute) - } + Sort(newSortOrder, global, newChild) + }) + } + + def resolveExprsWithAggregate( + exprs: Seq[Expression], + agg: Aggregate): (Seq[NamedExpression], Seq[Expression]) = { + val aggregateExpressions = ArrayBuffer.empty[NamedExpression] + val transformed = exprs.map { e => + // Try resolving the expression as though it is in the aggregate clause. 
+ def resolveCol(input: Expression): Expression = { + resolveExpressionByPlanOutput(input, agg.child) + } + def resolveSubQuery(input: Expression): Expression = { + if (SubqueryExpression.hasSubquery(input)) { + val fake = Project(Alias(input, "fake")() :: Nil, agg.child) + ResolveSubquery(fake).asInstanceOf[Project].projectList.head.asInstanceOf[Alias].child + } else { + input } - - val sortOrdersMap = unresolvedSortOrders - .map(new TreeNodeRef(_)) - .zip(evaluatedOrderings) - .toMap - val finalSortOrders = sortOrder.map(s => sortOrdersMap.getOrElse(new TreeNodeRef(s), s)) - - // Since we don't rely on sort.resolved as the stop condition for this rule, - // we need to check this and prevent applying this rule multiple times - if (sortOrder == finalSortOrders) { - sort + } + val maybeResolved = resolveSubQuery(resolveCol(e)) + if (maybeResolved.resolved && maybeResolved.references.subsetOf(agg.outputSet) && + !containsAggregate(maybeResolved)) { + // The given expression is valid and doesn't need extra resolution. + maybeResolved + } else if (containsUnresolvedFunc(maybeResolved)) { + // The given expression has unresolved functions which may be aggregate functions and we + // need to wait for other rules to resolve the functions first. + maybeResolved + } else { + // Avoid adding an extra aggregate expression if it's already present in + // `agg.aggregateExpressions`. + val index = if (maybeResolved.resolved) { + agg.aggregateExpressions.indexWhere { + case Alias(child, _) => child semanticEquals maybeResolved + case other => other semanticEquals maybeResolved + } } else { - Project(aggregate.output, - Sort(finalSortOrders, global, - aggregate.copy(aggregateExpressions = originalAggExprs ++ needsPushDown))) + -1 + } + if (index >= 0) { Review comment: Ah, I see. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. 
For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org