This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 7fcb12684af3 [SPARK-51964][SQL] Correctly resolve attributes from
hidden output in ORDER BY and HAVING on top of an Aggregate in single-pass
Analyzer
7fcb12684af3 is described below
commit 7fcb12684af308941c67aa61e29580ed7345dbd4
Author: Vladimir Golubev <[email protected]>
AuthorDate: Mon May 5 22:40:12 2025 +0800
[SPARK-51964][SQL] Correctly resolve attributes from hidden output in ORDER
BY and HAVING on top of an Aggregate in single-pass Analyzer
### What changes were proposed in this pull request?
Correctly resolve names in HAVING and ORDER BY in single-pass Analyzer. In
case we are resolving those expression trees on top of an `Aggregate`, we can
naturally only access attributes which are directly outputted from the
`Aggregate` (main output), or if they are present in `Aggregate`'s grouping
expressions (hidden output). However, we can actually access any attributes
from hidden output if we are resolving an `AggregateExpression` in ORDER BY or
HAVING. This `AggregateExpression` [...]
This query fails, because "col2" is not present in grouping expressions and
is not present in `Aggregate`'s output:
`SELECT COUNT(col1) FROM VALUES (1, 2) GROUP BY col1 ORDER BY col2;`
This query succeeds, because we are resolving "col2" under `MAX`:
`SELECT COUNT(col1) FROM VALUES (1, 2) GROUP BY col1 ORDER BY MAX(col2);`
### Why are the changes needed?
This improves name resolution in the new single-pass Analyzer.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Existing tests.
### Was this patch authored or co-authored using generative AI tooling?
copilot.nvim.
Closes #50769 from
vladimirg-db/vladimir-golubev_data/single-pass-analyzer-support-correct-attribute-resolution-under-aggregate.
Authored-by: Vladimir Golubev <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../analysis/resolver/AggregateResolver.scala | 20 +++-
.../resolver/ExpressionResolutionContext.scala | 8 +-
.../analysis/resolver/ExpressionResolver.scala | 7 +-
.../analysis/resolver/FunctionResolver.scala | 12 ++
.../sql/catalyst/analysis/resolver/NameScope.scala | 125 ++++++++++++++++-----
.../apache/spark/sql/catalyst/util/package.scala | 21 ++++
6 files changed, 161 insertions(+), 32 deletions(-)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/AggregateResolver.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/AggregateResolver.scala
index f39e036807b6..b643ab657717 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/AggregateResolver.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/AggregateResolver.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.catalyst.analysis.resolver
-import java.util.LinkedHashMap
+import java.util.{HashSet, LinkedHashMap}
import scala.jdk.CollectionConverters._
@@ -29,7 +29,9 @@ import org.apache.spark.sql.catalyst.analysis.{
}
import org.apache.spark.sql.catalyst.expressions.{
Alias,
+ AttributeReference,
Expression,
+ ExprId,
ExprUtils,
IntegerLiteral,
Literal,
@@ -109,7 +111,8 @@ class AggregateResolver(operatorResolver: Resolver,
expressionResolver: Expressi
ExprUtils.assertValidAggregation(resolvedAggregate)
scopes.overwriteOutputAndExtendHiddenOutput(
- output = resolvedAggregate.aggregateExpressions.map(_.toAttribute)
+ output = resolvedAggregate.aggregateExpressions.map(_.toAttribute),
+ groupingAttributeIds = Some(getGroupingAttributeIds(resolvedAggregate))
)
resolvedAggregate
@@ -329,4 +332,17 @@ class AggregateResolver(operatorResolver: Resolver,
expressionResolver: Expressi
.candidates
.isEmpty
}
+
+ private def getGroupingAttributeIds(aggregate: Aggregate): HashSet[ExprId] =
{
+ val groupingAttributeIds = new
HashSet[ExprId](aggregate.groupingExpressions.size)
+ aggregate.groupingExpressions.foreach { rootExpression =>
+ rootExpression.foreach {
+ case attribute: AttributeReference =>
+ groupingAttributeIds.add(attribute.exprId)
+ case _ =>
+ }
+ }
+
+ groupingAttributeIds
+ }
}
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ExpressionResolutionContext.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ExpressionResolutionContext.scala
index 66fa5a4226e6..cb8bad19b977 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ExpressionResolutionContext.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ExpressionResolutionContext.scala
@@ -43,6 +43,8 @@ package org.apache.spark.sql.catalyst.analysis.resolver
* Otherwise, extra [[Alias]]es have to be stripped away.
* @param resolvingGroupingExpressions A flag indicating whether an expression
we are resolving is
* one of [[Aggregate.groupingExpressions]].
+ * @param resolvingTreeUnderAggregateExpression A flag indicating whether we are
resolving an expression
+ * tree under an [[AggregateExpression]].
*/
class ExpressionResolutionContext(
val isRoot: Boolean = false,
@@ -52,7 +54,8 @@ class ExpressionResolutionContext(
var hasAttributeOutsideOfAggregateExpressions: Boolean = false,
var hasLateralColumnAlias: Boolean = false,
var isTopOfProjectList: Boolean = false,
- var resolvingGroupingExpressions: Boolean = false) {
+ var resolvingGroupingExpressions: Boolean = false,
+ var resolvingTreeUnderAggregateExpression: Boolean = false) {
/**
* Propagate generic information that is valid across the whole expression
tree from the
@@ -81,7 +84,8 @@ object ExpressionResolutionContext {
)
} else {
new ExpressionResolutionContext(
- resolvingGroupingExpressions = parent.resolvingGroupingExpressions
+ resolvingGroupingExpressions = parent.resolvingGroupingExpressions,
+ resolvingTreeUnderAggregateExpression =
parent.resolvingTreeUnderAggregateExpression
)
}
}
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ExpressionResolver.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ExpressionResolver.scala
index 2202d90a04a7..bbee4fbb1011 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ExpressionResolver.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ExpressionResolver.scala
@@ -651,7 +651,12 @@ class ExpressionResolver(
.peek()
.resolvingGroupingExpressions && conf.groupByAliases
),
- canResolveNameByHiddenOutput = canResolveNameByHiddenOutput
+ canResolveNameByHiddenOutput = canResolveNameByHiddenOutput,
+ canReferenceAggregatedAccessOnlyAttributes = (
+ expressionResolutionContextStack
+ .peek()
+ .resolvingTreeUnderAggregateExpression
+ )
)
val candidate = nameTarget.pickCandidate(unresolvedAttribute)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/FunctionResolver.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/FunctionResolver.scala
index e02cd600b888..4b6cf80a915a 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/FunctionResolver.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/FunctionResolver.scala
@@ -77,9 +77,14 @@ class FunctionResolver(
private val typeCoercionResolver: TypeCoercionResolver =
new TypeCoercionResolver(timezoneAwareExpressionResolver)
+ private val expressionResolutionContextStack =
+ expressionResolver.getExpressionResolutionContextStack
/**
* Main method used to resolve an [[UnresolvedFunction]]. It resolves it in
the following steps:
+ * - Check if the `unresolvedFunction` is an aggregate expression. Set
+ * `resolvingTreeUnderAggregateExpression` to `true` in that case so we
can properly resolve
+ * attributes in ORDER BY and HAVING.
* - If the function is `count(*)` it is replaced with `count(1)` (please
check
* [[normalizeCountExpression]] documentation for more details).
Otherwise, we resolve the
* children of it.
@@ -93,6 +98,13 @@ class FunctionResolver(
* - Apply timezone, if the resulting expression is
[[TimeZoneAwareExpression]].
*/
override def resolve(unresolvedFunction: UnresolvedFunction): Expression = {
+ val expressionInfo = functionResolution.lookupBuiltinOrTempFunction(
+ unresolvedFunction.nameParts, Some(unresolvedFunction)
+ )
+ if (expressionInfo.exists(_.getGroup == "agg_funcs")) {
+
expressionResolutionContextStack.peek().resolvingTreeUnderAggregateExpression =
true
+ }
+
val functionWithResolvedChildren =
if (isCountStarExpansionAllowed(unresolvedFunction)) {
normalizeCountExpression(unresolvedFunction)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/NameScope.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/NameScope.scala
index 86fc43fd5224..977e5938dce5 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/NameScope.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/NameScope.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.catalyst.analysis.resolver
-import java.util.{ArrayDeque, HashMap, LinkedHashMap}
+import java.util.{ArrayDeque, HashMap, HashSet, LinkedHashMap}
import scala.collection.mutable
import scala.jdk.CollectionConverters._
@@ -169,9 +169,20 @@ class NameScope(
/**
* [[hiddenAttributesForResolution]] is an [[AttributeSeq]] that is used for
resolution of
* multipart attribute names, by hidden output. It's created from the
`hiddenOutput` when
- * [[NameScope]] is updated.
+ * [[NameScope]] is updated. [[AGGREGATED_ACCESS_ONLY]] attributes are
excluded from
+ * resolution by default, since they can only be referenced in specific
cases (see
+ * [[resolveMultipartName]] for more details).
*/
private lazy val hiddenAttributesForResolution: AttributeSeq =
+ AttributeSeq.fromNormalOutput(hiddenOutput.filter(!_.aggregatedAccessOnly))
+
+ /**
+ * [[hiddenAttributesForResolutionWithAggregatedOnlyAccess]] is an
[[AttributeSeq]] that is used
+ * for resolution of multipart attribute names, by hidden output including
attributes with
+ * [[AGGREGATED_ACCESS_ONLY]]. These attributes can only be accessed if we
are resolving a tree
+ * under [[AggregateExpression]] (see [[resolveMultipartName]] for more
details).
+ */
+ private lazy val hiddenAttributesForResolutionWithAggregatedOnlyAccess:
AttributeSeq =
AttributeSeq.fromNormalOutput(hiddenOutput)
/**
@@ -403,6 +414,25 @@ class NameScope(
*
* {{ SELECT col1 + col2 AS a FROM VALUES (1, 2) GROUP BY a; }}}
*
+ * In case we are resolving names in expression trees from HAVING or ORDER
BY on top of
+ * [[Aggregate]], we are able to resolve hidden attributes only if those are
present in
+ * grouping expressions, or if the reference itself is under an
[[AggregateExpression]].
+ * In the latter case `canReferenceAggregatedAccessOnlyAttributes` will be
true, and
+ * `hiddenAttributesForResolutionWithAggregatedOnlyAccess` will be used
instead of
+ * `hiddenAttributesForResolution`. Consider the following example:
+ *
+ * {{{
+ * -- This succeeds, because `col2` is in the grouping expressions.
+ * SELECT COUNT(col1) FROM t1 GROUP BY col1, col2 ORDER BY col2;
+ *
+ * -- This fails, because `col2` is not in the grouping expressions.
+ * SELECT COUNT(col1) FROM t1 GROUP BY col1 ORDER BY col2;
+ *
+ * -- This succeeds, despite the fact that `col2` is not in the grouping
expressions.
+ * -- Such references are allowed under an aggregate expression (MAX).
+ * SELECT COUNT(col1) FROM t1 GROUP BY col1 ORDER BY MAX(col2);
+ * }}}
+ *
* We are relying on the [[AttributeSeq]] to perform that work, since it
requires complex
* resolution logic involving nested field extraction and multipart name
matching.
*
@@ -412,7 +442,13 @@ class NameScope(
multipartName: Seq[String],
canLaterallyReferenceColumn: Boolean = true,
canReferenceAggregateExpressionAliases: Boolean = false,
- canResolveNameByHiddenOutput: Boolean = false): NameTarget = {
+ canResolveNameByHiddenOutput: Boolean = false,
+ canReferenceAggregatedAccessOnlyAttributes: Boolean = false): NameTarget
= {
+ val currentHiddenAttributesForResolution = if
(canReferenceAggregatedAccessOnlyAttributes) {
+ hiddenAttributesForResolutionWithAggregatedOnlyAccess
+ } else {
+ hiddenAttributesForResolution
+ }
val resolvedMultipartName: ResolvedMultipartName =
tryResolveMultipartNameByOutput(
@@ -432,7 +468,7 @@ class NameScope(
tryResolveMultipartNameByOutput(
multipartName,
nameComparator,
- hiddenAttributesForResolution,
+ currentHiddenAttributesForResolution,
canResolveByProposedAttributes = canResolveNameByHiddenOutput
)
)
@@ -676,9 +712,9 @@ class NameScopeStack extends SQLConfHelper {
def overwriteCurrent(
output: Option[Seq[Attribute]] = None,
hiddenOutput: Option[Seq[Attribute]] = None): Unit = {
- val hiddenOutputWithUpdatedNullabilities =
updateNullabilitiesInHiddenOutput(
- output.getOrElse(stack.peek().output),
- hiddenOutput.getOrElse(stack.peek().hiddenOutput)
+ val hiddenOutputWithUpdatedNullabilities = updateHiddenOutputProperties(
+ output = output.getOrElse(stack.peek().output),
+ hiddenOutput = hiddenOutput.getOrElse(stack.peek().hiddenOutput)
)
val newScope = stack.pop.overwriteOutput(output,
Some(hiddenOutputWithUpdatedNullabilities))
@@ -708,17 +744,24 @@ class NameScopeStack extends SQLConfHelper {
* output we have to have both hidden output from the previous scope and
the provided output.
* This is done for [[Project]] and [[Aggregate]] operators.
*
- * 2. updates nullabilities of attributes in hidden output from new output,
so that if attribute
- * was nullable in either old hidden output or new output, it must stay
nullable in new hidden
- * output as well.
+ * 2. updates properties of attributes in hidden output. This includes
nullabilities and access
+ * modes. See [[updateHiddenOutputProperties]] for more details.
*/
- def overwriteOutputAndExtendHiddenOutput(output: Seq[Attribute]): Unit = {
+ def overwriteOutputAndExtendHiddenOutput(
+ output: Seq[Attribute],
+ groupingAttributeIds: Option[HashSet[ExprId]] = None): Unit = {
val prevScope = stack.pop
- val hiddenOutputWithUpdatedNullabilities =
- updateNullabilitiesInHiddenOutput(output, prevScope.hiddenOutput)
- val hiddenOutput = hiddenOutputWithUpdatedNullabilities ++ output.filter {
attribute =>
- prevScope.getHiddenAttributeById(attribute.exprId).isEmpty
- }
+
+ val hiddenOutputWithUpdatedProperties = updateHiddenOutputProperties(
+ output = output,
+ hiddenOutput = prevScope.hiddenOutput,
+ groupingAttributeIds = groupingAttributeIds
+ )
+
+ val hiddenOutput = hiddenOutputWithUpdatedProperties ++ output.filter {
attribute =>
+ prevScope.getHiddenAttributeById(attribute.exprId).isEmpty
+ }
+
val newScope = prevScope.overwriteOutput(
output = Some(output),
hiddenOutput = Some(hiddenOutput)
@@ -848,12 +891,14 @@ class NameScopeStack extends SQLConfHelper {
multipartName: Seq[String],
canLaterallyReferenceColumn: Boolean = true,
canReferenceAggregateExpressionAliases: Boolean = false,
- canResolveNameByHiddenOutput: Boolean = false): NameTarget = {
+ canResolveNameByHiddenOutput: Boolean = false,
+ canReferenceAggregatedAccessOnlyAttributes: Boolean = false): NameTarget
= {
val nameTargetFromCurrentScope = current.resolveMultipartName(
multipartName,
canLaterallyReferenceColumn = canLaterallyReferenceColumn,
canReferenceAggregateExpressionAliases =
canReferenceAggregateExpressionAliases,
- canResolveNameByHiddenOutput = canResolveNameByHiddenOutput
+ canResolveNameByHiddenOutput = canResolveNameByHiddenOutput,
+ canReferenceAggregatedAccessOnlyAttributes =
canReferenceAggregatedAccessOnlyAttributes
)
if (nameTargetFromCurrentScope.candidates.nonEmpty) {
@@ -864,7 +909,8 @@ class NameScopeStack extends SQLConfHelper {
val nameTarget = outer.resolveMultipartName(
multipartName,
canLaterallyReferenceColumn = false,
- canReferenceAggregateExpressionAliases = false
+ canReferenceAggregateExpressionAliases = false,
+ canReferenceAggregatedAccessOnlyAttributes =
canReferenceAggregatedAccessOnlyAttributes
)
if (nameTarget.candidates.nonEmpty) {
@@ -918,20 +964,45 @@ class NameScopeStack extends SQLConfHelper {
}
/**
- * When the scope gets the new output, we need to refresh nullabilities in
its `hiddenOutput`. If
- * an attribute is nullable in either old hidden output or new output, it
must remain nullable in
- * new hidden output as well.
+ * Update attribute properties when overwriting the current outputs.
+ *
+ * 1. When the scope gets the new output, we need to refresh nullabilities
in its `hiddenOutput`.
+ * If an attribute is nullable in either old hidden output or new output, it
must remain nullable
+ * in new hidden output as well.
+ *
+ * 2. If we are updating the hidden output on top of an [[Aggregate]],
HAVING and ORDER BY clauses
+ * may later reference either attributes from grouping expressions, or any
other attributes
+ * under the condition that they are referenced under
[[AggregateExpression]]. We mark those
+ * attributes as [[AGGREGATED_ACCESS_ONLY]] to reference them in
[[resolveMultipartName]] only
+ * if `canReferenceAggregatedAccessOnlyAttributes` is set to `true`.
+ * Attributes from grouping expressions lose their access metadata (e.g.
+ * [[QUALIFIED_ACCESS_ONLY]]) - grouping expression attributes can be simply
referenced given
+ * that the relevant expression tree is canonically equal to the grouping
expression tree.
*/
- private def updateNullabilitiesInHiddenOutput(
+ private def updateHiddenOutputProperties(
output: Seq[Attribute],
- hiddenOutput: Seq[Attribute]) = {
+ hiddenOutput: Seq[Attribute],
+ groupingAttributeIds: Option[HashSet[ExprId]] = None) = {
val outputLookup = new HashMap[ExprId, Attribute](output.size)
output.foreach(attribute => outputLookup.put(attribute.exprId, attribute))
- hiddenOutput.map {
- case attribute if outputLookup.containsKey(attribute.exprId) =>
+ hiddenOutput.map { attribute =>
+ val attributeWithUpdatedNullability = if
(outputLookup.containsKey(attribute.exprId)) {
attribute.withNullability(attribute.nullable ||
outputLookup.get(attribute.exprId).nullable)
- case attribute => attribute
+ } else {
+ attribute
+ }
+
+ groupingAttributeIds match {
+ case Some(groupingAttributeIds) =>
+ if (groupingAttributeIds.contains(attribute.exprId)) {
+ attributeWithUpdatedNullability.markAsAllowAnyAccess()
+ } else {
+ attributeWithUpdatedNullability.markAsAggregatedAccessOnly()
+ }
+ case None =>
+ attributeWithUpdatedNullability
+ }
}
}
}
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala
index 841f367896c4..a94666088b9b 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala
@@ -156,6 +156,14 @@ package object util extends Logging {
*/
val QUALIFIED_ACCESS_ONLY = "__qualified_access_only"
+ /**
+ * If set, this metadata column can only be accessed under
[[AggregateExpression]]. This is
+ * important when resolving columns in ORDER BY and HAVING clauses on top of
[[Aggregate]].
+ * In this case we can only reference attributes from grouping expressions,
or attributes marked
+ * as "__aggregated_access_only" under [[AggregateExpression]].
+ */
+ val AGGREGATED_ACCESS_ONLY = "__aggregated_access_only"
+
implicit class MetadataColumnHelper(attr: Attribute) {
def isMetadataCol: Boolean = MetadataAttribute.isValid(attr.metadata)
@@ -164,6 +172,10 @@ package object util extends Logging {
attr.metadata.contains(QUALIFIED_ACCESS_ONLY) &&
attr.metadata.getBoolean(QUALIFIED_ACCESS_ONLY)
+ def aggregatedAccessOnly: Boolean = attr.isMetadataCol &&
+ attr.metadata.contains(AGGREGATED_ACCESS_ONLY) &&
+ attr.metadata.getBoolean(AGGREGATED_ACCESS_ONLY)
+
def markAsQualifiedAccessOnly(): Attribute = attr.withMetadata(
new MetadataBuilder()
.withMetadata(attr.metadata)
@@ -172,12 +184,21 @@ package object util extends Logging {
.build()
)
+ def markAsAggregatedAccessOnly(): Attribute = attr.withMetadata(
+ new MetadataBuilder()
+ .withMetadata(attr.metadata)
+ .putString(METADATA_COL_ATTR_KEY, attr.name)
+ .putBoolean(AGGREGATED_ACCESS_ONLY, true)
+ .build()
+ )
+
def markAsAllowAnyAccess(): Attribute = {
if (qualifiedAccessOnly) {
attr.withMetadata(
new MetadataBuilder()
.withMetadata(attr.metadata)
.remove(QUALIFIED_ACCESS_ONLY)
+ .remove(AGGREGATED_ACCESS_ONLY)
.build()
)
} else {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]