This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 7e9b88bfceb [SPARK-27561][SQL] Support implicit lateral column alias resolution on Project

7e9b88bfceb is described below

commit 7e9b88bfceb86d3b32e82a86b672aab3c74def8c
Author: Xinyi Yu <xinyi...@databricks.com>
AuthorDate: Wed Dec 14 00:14:06 2022 +0800

[SPARK-27561][SQL] Support implicit lateral column alias resolution on Project

### What changes were proposed in this pull request?

This PR implements a new feature: implicit lateral column alias resolution for the `Project` case. It is temporarily controlled by `spark.sql.lateralColumnAlias.enableImplicitResolution` (default `false` for now; the conf will be turned on once the feature is completely merged).

#### Lateral column alias

See https://issues.apache.org/jira/browse/SPARK-27561 for more details on lateral column alias. There are two main cases to support: LCA in Project, and LCA in Aggregate.

```sql
-- LCA in Project. The base_salary references an attribute defined by a previous alias
SELECT salary AS base_salary, base_salary + bonus AS total_salary
FROM employee

-- LCA in Aggregate. The avg_salary references an attribute defined by a previous alias
SELECT dept, average(salary) AS avg_salary, avg_salary + average(bonus)
FROM employee
GROUP BY dept
```

This **implicit** lateral column alias (no explicit keyword, e.g. `lateral.base_salary`) should be supported.

#### High level design

This PR defines a new Resolution rule, `ResolveLateralColumnAlias`, to resolve the implicit lateral column alias, covering the `Project` case. It introduces a new leaf node NamedExpression, `LateralColumnAliasReference`, as a placeholder used to hold a reference that has been temporarily resolved as the reference to a lateral column alias.

The whole process is generally divided into two phases:
1) recognize **resolved** lateral aliases and wrap the attributes referencing them with `LateralColumnAliasReference`.
2) when the whole operator is resolved, unwrap `LateralColumnAliasReference`. For Project, it further resolves the attributes and pushes down the referenced lateral aliases to the new Project.

For example:
```
// Before
Project [age AS a, 'a + 1]
+- Child

// After phase 1
Project [age AS a, lateralalias(a) + 1]
+- Child

// After phase 2
Project [a, a + 1]
+- Project [child output, age AS a]
   +- Child
```

#### Resolution order

Given this new rule, the name resolution order will be (higher -> lower):
```
local table column > local metadata attribute > local lateral column alias > all others (outer reference of subquery, parameters of SQL UDF, ...)
```

There is a recent refactor that moves the creation of `OuterReference` into the Resolution batch: https://github.com/apache/spark/pull/38851. Because lateral column alias has higher resolution priority than outer reference, the rule also tries to resolve an `OuterReference` using lateral column aliases, the same way it handles an `UnresolvedAttribute`. If successful, it strips the `OuterReference` and wraps the result with `LateralColumnAliasReference`.

### Why are the changes needed?

The lateral column alias is a popular, long-requested feature. It is supported by many other database vendors (Redshift, Snowflake, etc.) and provides a better user experience.

### Does this PR introduce _any_ user-facing change?

Yes. As shown in the above example, lateral column aliases can now be resolved. I will write the migration guide or release note when most PRs of this feature are merged.

### How was this patch tested?

Existing tests and newly added tests.

Closes #38776 from anchovYu/SPARK-27561-refactor.
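The two phases described above can be illustrated with a toy model. The Scala sketch below is not Catalyst code. `Expr`, `Col`, `Attr`, `LcaRef`, `Alias` and `ToyLca` are hypothetical stand-ins, every alias is treated as already resolved, and alias names are assumed to be unique, so ambiguity checking is left out. Phase 1 wraps a name as `LcaRef` only when it is not a child column and matches an earlier alias. Phase 2 pushes each referenced, non-chained alias into an inner project and leaves chained references for a later round.

```scala
// Toy model of the two-phase lateral column alias (LCA) rewrite for a Project.
// Every type here is a hypothetical stand-in, not a Catalyst class.
import scala.collection.mutable

sealed trait Expr
final case class Col(name: String) extends Expr            // unresolved name, like 'a
final case class Attr(name: String) extends Expr           // resolved attribute
final case class Lit(value: Int) extends Expr
final case class Add(left: Expr, right: Expr) extends Expr
final case class LcaRef(name: String) extends Expr         // placeholder, like lateralalias(a)
final case class Alias(child: Expr, name: String) extends Expr

object ToyLca {
  // Phase 1: scan left to right. A name resolves to a child column first; otherwise,
  // if an earlier alias has that name, it is wrapped as an LcaRef placeholder.
  def wrap(projectList: Seq[Expr], childOutput: Set[String]): Seq[Expr] = {
    def go(e: Expr, seen: Set[String]): Expr = e match {
      case Col(n) if childOutput.contains(n) => Attr(n)
      case Col(n) if seen.contains(n)        => LcaRef(n)
      case Add(l, r)                         => Add(go(l, seen), go(r, seen))
      case Alias(c, n)                       => Alias(go(c, seen), n)
      case other                             => other
    }
    projectList.foldLeft((Vector.empty[Expr], Set.empty[String])) {
      case ((acc, seen), a @ Alias(_, n)) => (acc :+ go(a, seen), seen + n)
      case ((acc, seen), e)               => (acc :+ go(e, seen), seen)
    }._1
  }

  private def containsLca(e: Expr): Boolean = e match {
    case LcaRef(_)   => true
    case Add(l, r)   => containsLca(l) || containsLca(r)
    case Alias(c, _) => containsLca(c)
    case _           => false
  }

  // Phase 2: unwrap references to aliases that contain no LcaRef themselves and push
  // those aliases into a new inner Project. Chained references stay for a later round.
  // Returns (outer list, Some(inner list)), or (input, None) if nothing was unwrapped.
  def unwrap(projectList: Seq[Expr], childOutput: Seq[String]): (Seq[Expr], Option[Seq[Expr]]) = {
    val aliases = projectList.collect { case a: Alias => a.name -> a }.toMap
    val pushed = mutable.LinkedHashSet.empty[String]
    def go(e: Expr): Expr = e match {
      case ref @ LcaRef(n) =>
        aliases.get(n) match {
          case Some(a) if !containsLca(a.child) => pushed += n; Attr(n)
          case _                                => ref
        }
      case Add(l, r)   => Add(go(l), go(r))
      case Alias(c, n) => Alias(go(c), n)
      case other       => other
    }
    val rewritten = projectList.map(go)
    if (pushed.isEmpty) (projectList, None)
    else {
      // The outer Project now just reads each pushed-down alias as an attribute.
      val outer = rewritten.map {
        case Alias(_, n) if pushed.contains(n) => Attr(n)
        case e                                 => e
      }
      val inner: Seq[Expr] = childOutput.map(Attr(_)) ++ pushed.toSeq.map(aliases)
      (outer, Some(inner))
    }
  }
}
```

Running `wrap` and then `unwrap` on the `Project [age AS a, 'a + 1]` example gives the outer list `[a, a + 1]` and the inner list `[age, age AS a]`, which matches the plans shown above. The sketch leaves chained references in place rather than resolving them recursively. This follows the comment in `ResolveLateralColumnAliasReference` that chained references are saved for future rounds.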
Authored-by: Xinyi Yu <xinyi...@databricks.com>
Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 core/src/main/resources/error/error-classes.json   |   6 +
 .../sql/catalyst/expressions/AttributeMap.scala    |   3 +-
 .../sql/catalyst/expressions/AttributeMap.scala    |   3 +
 .../spark/sql/catalyst/analysis/Analyzer.scala     | 119 +++++++-
 .../sql/catalyst/analysis/CheckAnalysis.scala      |  25 +-
 .../ResolveLateralColumnAliasReference.scala       | 135 +++++++++
 .../catalyst/expressions/namedExpressions.scala    |  33 +++
 .../spark/sql/catalyst/expressions/subquery.scala  |   9 +-
 .../sql/catalyst/rules/RuleIdCollection.scala      |   2 +
 .../spark/sql/catalyst/trees/TreePatterns.scala    |   1 +
 .../spark/sql/errors/QueryCompilationErrors.scala  |  19 ++
 .../org/apache/spark/sql/internal/SQLConf.scala    |  11 +
 .../apache/spark/sql/LateralColumnAliasSuite.scala | 327 +++++++++++++++++++++
 13 files changed, 686 insertions(+), 7 deletions(-)

diff --git a/core/src/main/resources/error/error-classes.json b/core/src/main/resources/error/error-classes.json
index 54ee0bc2e74..25362d5893f 100644
--- a/core/src/main/resources/error/error-classes.json
+++ b/core/src/main/resources/error/error-classes.json
@@ -5,6 +5,12 @@
     ],
     "sqlState" : "42000"
   },
+  "AMBIGUOUS_LATERAL_COLUMN_ALIAS" : {
+    "message" : [
+      "Lateral column alias <name> is ambiguous and has <n> matches."
+    ],
+    "sqlState" : "42000"
+  },
   "AMBIGUOUS_REFERENCE" : {
     "message" : [
       "Reference <name> is ambiguous, could be: <referenceNames>."
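The new `AMBIGUOUS_LATERAL_COLUMN_ALIAS` error is raised when a name referenced later in the SELECT list matches more than one earlier alias. The sketch below is a hypothetical simplification, not the patch's code. It mimics how `WrapLateralColumnAliasReference` keeps every alias under its case-insensitive name so that it can count the matches. A lowercase-keyed `Map` stands in for `CaseInsensitiveMap`, `IllegalArgumentException` stands in for `AnalysisException`, and the priority of table columns is ignored. `ToyAliasMap`, `Define` and `Use` are illustrative names.

```scala
// Hypothetical sketch: lowercase keys stand in for CaseInsensitiveMap and
// IllegalArgumentException stands in for AnalysisException.
import java.util.Locale

object ToyAliasMap {
  final case class AliasEntry(name: String, index: Int)
  type AliasMap = Map[String, Seq[AliasEntry]]

  sealed trait Item
  final case class Define(name: String) extends Item // `expr AS name`
  final case class Use(name: String) extends Item    // a later bare reference to `name`

  private def key(name: String): String = name.toLowerCase(Locale.ROOT)

  // Append the alias under its case-insensitive name and keep earlier entries,
  // so duplicate aliases are only reported if they are actually referenced.
  def insert(m: AliasMap, name: String, idx: Int): AliasMap =
    m + (key(name) -> (m.getOrElse(key(name), Seq.empty) :+ AliasEntry(name, idx)))

  // None: not a lateral alias; Some(entry): a unique match; several matches: error.
  def lookup(m: AliasMap, name: String): Option[AliasEntry] =
    m.get(key(name)) match {
      case Some(Seq(only)) => Some(only)
      case Some(many) =>
        throw new IllegalArgumentException(
          s"Lateral column alias `$name` is ambiguous and has ${many.size} matches.")
      case None => None
    }

  // Scan a SELECT list left to right; each Use sees only the aliases defined before it.
  def scan(items: Seq[Item]): Seq[Option[Int]] = {
    var m: AliasMap = Map.empty
    val out = Seq.newBuilder[Option[Int]]
    items.zipWithIndex.foreach {
      case (Define(n), idx) => m = insert(m, n, idx)
      case (Use(n), _)      => out += lookup(m, n).map(_.index)
    }
    out.result()
  }
}
```

In `SELECT salary AS d, d, 10000 AS d`, the single reference sees only one `d` and resolves to it. In the suite's `SELECT salary * 1.5 AS d, d, 10000 AS d, d + 1` case, the last reference sees two aliases and fails with 2 matches. The new `LateralColumnAliasSuite` expects both of these outcomes.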
diff --git a/sql/catalyst/src/main/scala-2.12/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala b/sql/catalyst/src/main/scala-2.12/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
index c55c542d957..504b65e3db6 100644
--- a/sql/catalyst/src/main/scala-2.12/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
+++ b/sql/catalyst/src/main/scala-2.12/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
@@ -49,7 +49,8 @@ class AttributeMap[A](val baseMap: Map[ExprId, (Attribute, A)])
 
   override def contains(k: Attribute): Boolean = get(k).isDefined
 
-  override def + [B1 >: A](kv: (Attribute, B1)): Map[Attribute, B1] = baseMap.values.toMap + kv
+  override def + [B1 >: A](kv: (Attribute, B1)): AttributeMap[B1] =
+    AttributeMap(baseMap.values.toMap + kv)
 
   override def iterator: Iterator[(Attribute, A)] = baseMap.valuesIterator
 
diff --git a/sql/catalyst/src/main/scala-2.13/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala b/sql/catalyst/src/main/scala-2.13/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
index 3d5d6471d26..ac6149f3acc 100644
--- a/sql/catalyst/src/main/scala-2.13/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
+++ b/sql/catalyst/src/main/scala-2.13/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
@@ -49,6 +49,9 @@ class AttributeMap[A](val baseMap: Map[ExprId, (Attribute, A)])
 
   override def contains(k: Attribute): Boolean = get(k).isDefined
 
+  override def + [B1 >: A](kv: (Attribute, B1)): AttributeMap[B1] =
+    AttributeMap(baseMap.values.toMap + kv)
+
   override def updated[B1 >: A](key: Attribute, value: B1): Map[Attribute, B1] =
     baseMap.values.toMap + (key -> value)
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 3f806137bab..04234204445 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -41,7 +41,7 @@ import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2
 import org.apache.spark.sql.catalyst.trees.{AlwaysProcess, CurrentOrigin}
 import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin
 import org.apache.spark.sql.catalyst.trees.TreePattern._
-import org.apache.spark.sql.catalyst.util.{toPrettySQL, CharVarcharUtils, StringUtils}
+import org.apache.spark.sql.catalyst.util.{toPrettySQL, CaseInsensitiveMap, CharVarcharUtils, StringUtils}
 import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns._
 import org.apache.spark.sql.connector.catalog.{View => _, _}
 import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
@@ -288,6 +288,8 @@ class Analyzer(override val catalogManager: CatalogManager)
       AddMetadataColumns ::
       DeduplicateRelations ::
       ResolveReferences ::
+      WrapLateralColumnAliasReference ::
+      ResolveLateralColumnAliasReference ::
       ResolveExpressionsWithNamePlaceholders ::
       ResolveDeserializer ::
       ResolveNewInstance ::
@@ -1672,7 +1674,7 @@ class Analyzer(override val catalogManager: CatalogManager)
           // Only Project and Aggregate can host star expressions.
           case u @ (_: Project | _: Aggregate) =>
             Try(s.expand(u.children.head, resolver)) match {
-              case Success(expanded) => expanded.map(wrapOuterReference)
+              case Success(expanded) => expanded.map(wrapOuterReference(_))
               case Failure(_) => throw e
             }
           // Do not use the outer plan to resolve the star expression
@@ -1761,6 +1763,117 @@ class Analyzer(override val catalogManager: CatalogManager)
     }
   }
 
+  /**
+   * The first phase to resolve lateral column alias. See comments in
+   * [[ResolveLateralColumnAliasReference]] for more detailed explanation.
+   */
+  object WrapLateralColumnAliasReference extends Rule[LogicalPlan] {
+    import ResolveLateralColumnAliasReference.AliasEntry
+
+    private def insertIntoAliasMap(
+        a: Alias,
+        idx: Int,
+        aliasMap: CaseInsensitiveMap[Seq[AliasEntry]]): CaseInsensitiveMap[Seq[AliasEntry]] = {
+      val prevAliases = aliasMap.getOrElse(a.name, Seq.empty[AliasEntry])
+      aliasMap + (a.name -> (prevAliases :+ AliasEntry(a, idx)))
+    }
+
+    /**
+     * Use the given lateral alias to resolve the unresolved attribute with the name parts.
+     *
+     * Construct a dummy plan with the given lateral alias as project list, use the output of the
+     * plan to resolve.
+     * @return The resolved [[LateralColumnAliasReference]] if succeeds. None if fails to resolve.
+     */
+    private def resolveByLateralAlias(
+        nameParts: Seq[String], lateralAlias: Alias): Option[LateralColumnAliasReference] = {
+      val resolvedAttr = resolveExpressionByPlanOutput(
+        expr = UnresolvedAttribute(nameParts),
+        plan = LocalRelation(Seq(lateralAlias.toAttribute)),
+        throws = false
+      ).asInstanceOf[NamedExpression]
+      if (resolvedAttr.resolved) {
+        Some(LateralColumnAliasReference(resolvedAttr, nameParts, lateralAlias.toAttribute))
+      } else {
+        None
+      }
+    }
+
+    /**
+     * Recognize all the attributes in the given expression that reference lateral column aliases
+     * by looking up the alias map. Resolve these attributes and replace by wrapping with
+     * [[LateralColumnAliasReference]].
+     *
+     * @param currentPlan Because lateral alias has lower resolution priority than table columns,
+     *                    the current plan is needed to first try resolving the attribute by its
+     *                    children
+     */
+    private def wrapLCARef(
+        e: NamedExpression,
+        currentPlan: LogicalPlan,
+        aliasMap: CaseInsensitiveMap[Seq[AliasEntry]]): NamedExpression = {
+      e.transformWithPruning(_.containsAnyPattern(UNRESOLVED_ATTRIBUTE, OUTER_REFERENCE)) {
+        case u: UnresolvedAttribute if aliasMap.contains(u.nameParts.head) &&
+          resolveExpressionByPlanChildren(u, currentPlan).isInstanceOf[UnresolvedAttribute] =>
+          val aliases = aliasMap.get(u.nameParts.head).get
+          aliases.size match {
+            case n if n > 1 =>
+              throw QueryCompilationErrors.ambiguousLateralColumnAlias(u.name, n)
+            case n if n == 1 && aliases.head.alias.resolved =>
+              // Only resolved alias can be the lateral column alias
+              // The lateral alias can be a struct and have nested field, need to construct
+              // a dummy plan to resolve the expression
+              resolveByLateralAlias(u.nameParts, aliases.head.alias).getOrElse(u)
+            case _ => u
+          }
+        case o: OuterReference
+          if aliasMap.contains(
+            o.getTagValue(ResolveLateralColumnAliasReference.NAME_PARTS_FROM_UNRESOLVED_ATTR)
+              .map(_.head)
+              .getOrElse(o.name)) =>
+          // handle OuterReference exactly same as UnresolvedAttribute
+          val nameParts = o
+            .getTagValue(ResolveLateralColumnAliasReference.NAME_PARTS_FROM_UNRESOLVED_ATTR)
+            .getOrElse(Seq(o.name))
+          val aliases = aliasMap.get(nameParts.head).get
+          aliases.size match {
+            case n if n > 1 =>
+              throw QueryCompilationErrors.ambiguousLateralColumnAlias(nameParts, n)
+            case n if n == 1 && aliases.head.alias.resolved =>
+              resolveByLateralAlias(nameParts, aliases.head.alias).getOrElse(o)
+            case _ => o
+          }
+      }.asInstanceOf[NamedExpression]
+    }
+
+    override def apply(plan: LogicalPlan): LogicalPlan = {
+      if (!conf.getConf(SQLConf.LATERAL_COLUMN_ALIAS_IMPLICIT_ENABLED)) {
+        plan
+      } else {
+        plan.resolveOperatorsUpWithPruning(
+          _.containsAnyPattern(UNRESOLVED_ATTRIBUTE, OUTER_REFERENCE), ruleId) {
+          case p @ Project(projectList, _) if p.childrenResolved
+            && !ResolveReferences.containsStar(projectList)
+            && projectList.exists(_.containsAnyPattern(UNRESOLVED_ATTRIBUTE, OUTER_REFERENCE)) =>
+            var aliasMap = CaseInsensitiveMap(Map[String, Seq[AliasEntry]]())
+            val newProjectList = projectList.zipWithIndex.map {
+              case (a: Alias, idx) =>
+                val lcaWrapped = wrapLCARef(a, p, aliasMap).asInstanceOf[Alias]
+                // Insert the LCA-resolved alias instead of the unresolved one into map. If it is
+                // resolved, it can be referenced as LCA by later expressions (chaining).
+                // Unresolved Alias is also added to the map to perform ambiguous name check, but
+                // only resolved alias can be LCA.
+                aliasMap = insertIntoAliasMap(lcaWrapped, idx, aliasMap)
+                lcaWrapped
+              case (e, _) =>
+                wrapLCARef(e, p, aliasMap)
+            }
+            p.copy(projectList = newProjectList)
+        }
+      }
+    }
+  }
+
   private def containsDeserializer(exprs: Seq[Expression]): Boolean = {
     exprs.exists(_.exists(_.isInstanceOf[UnresolvedDeserializer]))
   }
@@ -2143,7 +2256,7 @@ class Analyzer(override val catalogManager: CatalogManager)
         case u @ UnresolvedAttribute(nameParts) => withPosition(u) {
           try {
             AnalysisContext.get.outerPlan.get.resolveChildren(nameParts, resolver) match {
-              case Some(resolved) => wrapOuterReference(resolved)
+              case Some(resolved) => wrapOuterReference(resolved, Some(nameParts))
               case None => u
             }
           } catch {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index be812adaaa1..e7e153a319d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.optimizer.{BooleanSimplification, Decorrela
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.trees.TreeNodeTag
-import org.apache.spark.sql.catalyst.trees.TreePattern.UNRESOLVED_WINDOW_EXPRESSION
+import org.apache.spark.sql.catalyst.trees.TreePattern.{LATERAL_COLUMN_ALIAS_REFERENCE, UNRESOLVED_WINDOW_EXPRESSION}
 import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, StringUtils, TypeUtils}
 import org.apache.spark.sql.connector.catalog.{LookupCatalog, SupportsPartitionManagement}
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase}
@@ -638,6 +638,16 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
               case UnresolvedWindowExpression(_, windowSpec) =>
                 throw QueryCompilationErrors.windowSpecificationNotDefinedError(windowSpec.name)
             })
+            // This should not happen, resolved Project or Aggregate should restore or resolve
+            // all lateral column alias references. Add check for extra safe.
+            projectList.foreach(_.transformDownWithPruning(
+              _.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) {
+              case lcaRef: LateralColumnAliasReference if p.resolved =>
+                throw SparkException.internalError("Resolved Project should not contain " +
+                  s"any LateralColumnAliasReference.\nDebugging information: plan: $p",
+                  context = lcaRef.origin.getQueryContext,
+                  summary = lcaRef.origin.context.summary)
+            })
 
           case j: Join if !j.duplicateResolved =>
             val conflictingAttributes = j.left.outputSet.intersect(j.right.outputSet)
@@ -714,6 +724,19 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
                 "operator" -> other.nodeName,
                 "invalidExprSqls" -> invalidExprSqls.mkString(", ")))
 
+          // This should not happen, resolved Project or Aggregate should restore or resolve
+          // all lateral column alias references. Add check for extra safe.
+          case agg @ Aggregate(_, aggList, _)
+            if aggList.exists(_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) && agg.resolved =>
+            aggList.foreach(_.transformDownWithPruning(
+              _.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) {
+              case lcaRef: LateralColumnAliasReference =>
+                throw SparkException.internalError("Resolved Aggregate should not contain " +
+                  s"any LateralColumnAliasReference.\nDebugging information: plan: $agg",
+                  context = lcaRef.origin.getQueryContext,
+                  summary = lcaRef.origin.context.summary)
+            })
+
           case _ => // Analysis successful!
         }
       }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAliasReference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAliasReference.scala
new file mode 100644
index 00000000000..2ca187b95ff
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAliasReference.scala
@@ -0,0 +1,135 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeMap, LateralColumnAliasReference, NamedExpression}
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.trees.TreeNodeTag
+import org.apache.spark.sql.catalyst.trees.TreePattern.LATERAL_COLUMN_ALIAS_REFERENCE
+import org.apache.spark.sql.internal.SQLConf
+
+/**
+ * This rule is the second phase to resolve lateral column alias.
+ *
+ * Resolve lateral column alias, which references the alias defined previously in the SELECT list.
+ * Plan-wise, it handles two types of operators: Project and Aggregate.
+ * - in Project, pushing down the referenced lateral alias into a newly created Project, resolve
+ *   the attributes referencing these aliases
+ * - in Aggregate TODO.
+ *
+ * The whole process is generally divided into two phases:
+ * 1) recognize resolved lateral alias, wrap the attributes referencing them with
+ *    [[LateralColumnAliasReference]]
+ * 2) when the whole operator is resolved, unwrap [[LateralColumnAliasReference]].
+ *    For Project, it further resolves the attributes and push down the referenced lateral aliases.
+ *    For Aggregate, TODO
+ *
+ * Example for Project:
+ * Before rewrite:
+ * Project [age AS a, 'a + 1]
+ * +- Child
+ *
+ * After phase 1:
+ * Project [age AS a, lateralalias(a) + 1]
+ * +- Child
+ *
+ * After phase 2:
+ * Project [a, a + 1]
+ * +- Project [child output, age AS a]
+ *    +- Child
+ *
+ * Example for Aggregate TODO
+ *
+ *
+ * The name resolution priority:
+ * local table column > local lateral column alias > outer reference
+ *
+ * Because lateral column alias has higher resolution priority than outer reference, it will try
+ * to resolve an [[OuterReference]] using lateral column alias in phase 1, similar as an
+ * [[UnresolvedAttribute]]. If success, it strips [[OuterReference]] and also wraps it with
+ * [[LateralColumnAliasReference]].
+ */
+object ResolveLateralColumnAliasReference extends Rule[LogicalPlan] {
+  case class AliasEntry(alias: Alias, index: Int)
+
+  /**
+   * A tag to store the nameParts from the original unresolved attribute.
+   * It is set for [[OuterReference]], used in the current rule to convert [[OuterReference]] back
+   * to [[LateralColumnAliasReference]].
+   */
+  val NAME_PARTS_FROM_UNRESOLVED_ATTR = TreeNodeTag[Seq[String]]("name_parts_from_unresolved_attr")
+
+  override def apply(plan: LogicalPlan): LogicalPlan = {
+    if (!conf.getConf(SQLConf.LATERAL_COLUMN_ALIAS_IMPLICIT_ENABLED)) {
+      plan
+    } else {
+      // phase 2: unwrap
+      plan.resolveOperatorsUpWithPruning(
+        _.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE), ruleId) {
+        case p @ Project(projectList, child) if p.resolved
+          && projectList.exists(_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) =>
+          var aliasMap = AttributeMap.empty[AliasEntry]
+          val referencedAliases = collection.mutable.Set.empty[AliasEntry]
+          def unwrapLCAReference(e: NamedExpression): NamedExpression = {
+            e.transformWithPruning(_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) {
+              case lcaRef: LateralColumnAliasReference if aliasMap.contains(lcaRef.a) =>
+                val aliasEntry = aliasMap.get(lcaRef.a).get
+                // If there is no chaining of lateral column alias reference, push down the alias
+                // and unwrap the LateralColumnAliasReference to the NamedExpression inside
+                // If there is chaining, don't resolve and save to future rounds
+                if (!aliasEntry.alias.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) {
+                  referencedAliases += aliasEntry
+                  lcaRef.ne
+                } else {
+                  lcaRef
+                }
+              case lcaRef: LateralColumnAliasReference if !aliasMap.contains(lcaRef.a) =>
+                // It shouldn't happen, but restore to unresolved attribute to be safe.
+                UnresolvedAttribute(lcaRef.nameParts)
+            }.asInstanceOf[NamedExpression]
+          }
+          val newProjectList = projectList.zipWithIndex.map {
+            case (a: Alias, idx) =>
+              val lcaResolved = unwrapLCAReference(a)
+              // Insert the original alias instead of rewritten one to detect chained LCA
+              aliasMap += (a.toAttribute -> AliasEntry(a, idx))
+              lcaResolved
+            case (e, _) =>
+              unwrapLCAReference(e)
+          }
+
+          if (referencedAliases.isEmpty) {
+            p
+          } else {
+            val outerProjectList = collection.mutable.Seq(newProjectList: _*)
+            val innerProjectList =
+              collection.mutable.ArrayBuffer(child.output.map(_.asInstanceOf[NamedExpression]): _*)
+            referencedAliases.foreach { case AliasEntry(alias: Alias, idx) =>
+              outerProjectList.update(idx, alias.toAttribute)
+              innerProjectList += alias
+            }
+            p.copy(
+              projectList = outerProjectList.toSeq,
+              child = Project(innerProjectList.toSeq, child)
+            )
+          }
+      }
+    }
+  }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index 8dd28e9aaae..0f5239be6ca 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -428,6 +428,39 @@ case class OuterReference(e: NamedExpression)
   final override val nodePatterns: Seq[TreePattern] = Seq(OUTER_REFERENCE)
 }
 
+/**
+ * A placeholder used to hold a [[NamedExpression]] that has been temporarily resolved as the
+ * reference to a lateral column alias.
+ *
+ * This is created and removed by Analyzer rule [[ResolveLateralColumnAlias]].
+ * There should be no [[LateralColumnAliasReference]] beyond analyzer: if the plan passes all
+ * analysis check, then all [[LateralColumnAliasReference]] should already be removed.
+ *
+ * @param ne the resolved [[NamedExpression]] by lateral column alias
+ * @param nameParts the named parts of the original [[UnresolvedAttribute]]. Used to restore back
+ *                  to [[UnresolvedAttribute]] when needed
+ * @param a the attribute of referenced lateral column alias. Used to match alias when unwrapping
+ *          and resolving LateralColumnAliasReference
+ */
+case class LateralColumnAliasReference(ne: NamedExpression, nameParts: Seq[String], a: Attribute)
+  extends LeafExpression with NamedExpression with Unevaluable {
+  assert(ne.resolved)
+  override def name: String =
+    nameParts.map(n => if (n.contains(".")) s"`$n`" else n).mkString(".")
+  override def exprId: ExprId = ne.exprId
+  override def qualifier: Seq[String] = ne.qualifier
+  override def toAttribute: Attribute = ne.toAttribute
+  override def newInstance(): NamedExpression =
+    LateralColumnAliasReference(ne.newInstance(), nameParts, a)
+
+  override def nullable: Boolean = ne.nullable
+  override def dataType: DataType = ne.dataType
+  override def prettyName: String = "lateralAliasReference"
+  override def sql: String = s"$prettyName($name)"
+
+  final override val nodePatterns: Seq[TreePattern] = Seq(LATERAL_COLUMN_ALIAS_REFERENCE)
+}
+
 object VirtualColumn {
   // The attribute name used by Hive, which has different result than Spark, deprecated.
   val hiveGroupingIdName: String = "grouping__id"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
index e7384dac2d5..b510893f370 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
 
 import scala.collection.mutable.ArrayBuffer
 
+import org.apache.spark.sql.catalyst.analysis.ResolveLateralColumnAliasReference.NAME_PARTS_FROM_UNRESOLVED_ATTR
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.plans.logical.{Filter, HintInfo, LogicalPlan}
@@ -158,8 +159,12 @@ object SubExprUtils extends PredicateHelper {
   /**
    * Wrap attributes in the expression with [[OuterReference]]s.
    */
-  def wrapOuterReference[E <: Expression](e: E): E = {
-    e.transform { case a: Attribute => OuterReference(a) }.asInstanceOf[E]
+  def wrapOuterReference[E <: Expression](e: E, nameParts: Option[Seq[String]] = None): E = {
+    e.transform { case a: Attribute =>
+      val o = OuterReference(a)
+      nameParts.map(o.setTagValue(NAME_PARTS_FROM_UNRESOLVED_ATTR, _))
+      o
+    }.asInstanceOf[E]
   }
 
   /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
index f6bef88ab86..efafd3cfbcd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
@@ -77,6 +77,7 @@ object RuleIdCollection {
       "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveWindowFrame" ::
       "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveWindowOrder" ::
       "org.apache.spark.sql.catalyst.analysis.Analyzer$WindowsSubstitution" ::
+      "org.apache.spark.sql.catalyst.analysis.Analyzer$WrapLateralColumnAliasReference" ::
       "org.apache.spark.sql.catalyst.analysis.AnsiTypeCoercion$AnsiCombinedTypeCoercionRule" ::
       "org.apache.spark.sql.catalyst.analysis.ApplyCharTypePadding" ::
       "org.apache.spark.sql.catalyst.analysis.DeduplicateRelations" ::
@@ -88,6 +89,7 @@ object RuleIdCollection {
       "org.apache.spark.sql.catalyst.analysis.ResolveHints$ResolveJoinStrategyHints" ::
       "org.apache.spark.sql.catalyst.analysis.ResolveInlineTables" ::
       "org.apache.spark.sql.catalyst.analysis.ResolveLambdaVariables" ::
+      "org.apache.spark.sql.catalyst.analysis.ResolveLateralColumnAliasReference" ::
       "org.apache.spark.sql.catalyst.analysis.ResolveTimeZone" ::
       "org.apache.spark.sql.catalyst.analysis.ResolveUnion" ::
       "org.apache.spark.sql.catalyst.analysis.SubstituteUnresolvedOrdinals" ::
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
index 8fca9ec60cd..1a8ad7c7d62 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
@@ -58,6 +58,7 @@ object TreePattern extends Enumeration {
   val JSON_TO_STRUCT: Value = Value
   val LAMBDA_FUNCTION: Value = Value
   val LAMBDA_VARIABLE: Value = Value
+  val LATERAL_COLUMN_ALIAS_REFERENCE: Value = Value
   val LATERAL_SUBQUERY: Value = Value
   val LIKE_FAMLIY: Value = Value
   val LIST_SUBQUERY: Value = Value
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
index fbd8fb4a2a9..b329f6689d4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
@@ -3399,4 +3399,23 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase {
         cause = Option(other))
     }
   }
+
+  def ambiguousLateralColumnAlias(name: String, numOfMatches: Int): Throwable = {
+    new AnalysisException(
+      errorClass = "AMBIGUOUS_LATERAL_COLUMN_ALIAS",
+      messageParameters = Map(
+        "name" -> toSQLId(name),
+        "n" -> numOfMatches.toString
+      )
+    )
+  }
+  def ambiguousLateralColumnAlias(nameParts: Seq[String], numOfMatches: Int): Throwable = {
+    new AnalysisException(
+      errorClass = "AMBIGUOUS_LATERAL_COLUMN_ALIAS",
+      messageParameters = Map(
+        "name" -> toSQLId(nameParts),
+        "n" -> numOfMatches.toString
+      )
+    )
+  }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 84d78f365ac..575775a0f55 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -4027,6 +4027,17 @@ object SQLConf {
       .checkValues(ErrorMessageFormat.values.map(_.toString))
       .createWithDefault(ErrorMessageFormat.PRETTY.toString)
 
+  val LATERAL_COLUMN_ALIAS_IMPLICIT_ENABLED =
+    buildConf("spark.sql.lateralColumnAlias.enableImplicitResolution")
+      .internal()
+      .doc("Enable resolving implicit lateral column alias defined in the same SELECT list. For " +
+        "example, with this conf turned on, for query `SELECT 1 AS a, a + 1` the `a` in `a + 1` " +
+        "can be resolved as the previously defined `1 AS a`. But note that table column has " +
+        "higher resolution priority than the lateral column alias.")
+      .version("3.4.0")
+      .booleanConf
+      .createWithDefault(false)
+
   /**
    * Holds information about keys that have been deprecated.
    *
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala
new file mode 100644
index 00000000000..abeb3bb7841
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala
@@ -0,0 +1,327 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.scalactic.source.Position
+import org.scalatest.Tag
+
+import org.apache.spark.sql.catalyst.trees.TreePattern.OUTER_REFERENCE
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SharedSparkSession
+
+class LateralColumnAliasSuite extends QueryTest with SharedSparkSession {
+  protected val testTable: String = "employee"
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    sql(
+      s"""
+         |CREATE TABLE $testTable (
+         |  dept INTEGER,
+         |  name String,
+         |  salary INTEGER,
+         |  bonus INTEGER,
+         |  properties STRUCT<joinYear INTEGER, mostRecentEmployer STRING>)
+         |USING orc
+         |""".stripMargin)
+    sql(
+      s"""
+         |INSERT INTO $testTable VALUES
+         |  (1, 'amy', 10000, 1000, named_struct('joinYear', 2019, 'mostRecentEmployer', 'A')),
+         |  (2, 'alex', 12000, 1200, named_struct('joinYear', 2017, 'mostRecentEmployer', 'A')),
+         |  (1, 'cathy', 9000, 1200, named_struct('joinYear', 2020, 'mostRecentEmployer', 'B')),
+         |  (2, 'david', 10000, 1300, named_struct('joinYear', 2019, 'mostRecentEmployer', 'C')),
+         |  (6, 'jen', 12000, 1200, named_struct('joinYear', 2018, 'mostRecentEmployer', 'D'))
+         |""".stripMargin)
+  }
+
+  override def afterAll(): Unit = {
+    try {
+      sql(s"DROP TABLE IF EXISTS $testTable")
+    } finally {
+      super.afterAll()
+    }
+  }
+
+  val lcaEnabled: Boolean = true
+  // by default the tests in this suites run with LCA on
+  override protected def test(testName: String, testTags: Tag*)(testFun: => Any)
+      (implicit pos: Position): Unit = {
+    super.test(testName, testTags: _*) {
+      withSQLConf(SQLConf.LATERAL_COLUMN_ALIAS_IMPLICIT_ENABLED.key -> lcaEnabled.toString) {
+        testFun
+      }
+    }
+  }
+  // mark special testcases test both LCA on and off
+  protected def testOnAndOff(testName: String, testTags: Tag*)(testFun: => Any)
+      (implicit pos: Position): Unit = {
+    super.test(testName, testTags: _*)(testFun)
+  }
+
+  private def withLCAOff(f: => Unit): Unit = {
withSQLConf(SQLConf.LATERAL_COLUMN_ALIAS_IMPLICIT_ENABLED.key -> "false") { + f + } + } + private def withLCAOn(f: => Unit): Unit = { + withSQLConf(SQLConf.LATERAL_COLUMN_ALIAS_IMPLICIT_ENABLED.key -> "true") { + f + } + } + + testOnAndOff("Lateral alias basics - Project") { + def checkAnswerWhenOnAndExceptionWhenOff(query: String, expectedAnswerLCAOn: Row): Unit = { + withLCAOn { checkAnswer(sql(query), expectedAnswerLCAOn) } + withLCAOff { + assert(intercept[AnalysisException]{ sql(query) } + .getErrorClass == "UNRESOLVED_COLUMN.WITH_SUGGESTION") + } + } + + checkAnswerWhenOnAndExceptionWhenOff( + s"select dept as d, d + 1 as e from $testTable where name = 'amy'", + Row(1, 2)) + + checkAnswerWhenOnAndExceptionWhenOff( + s"select salary * 2 as new_salary, new_salary + bonus from $testTable where name = 'amy'", + Row(20000, 21000)) + checkAnswerWhenOnAndExceptionWhenOff( + s"select salary * 2 as new_salary, new_salary + bonus * 2 as new_income from $testTable" + + s" where name = 'amy'", + Row(20000, 22000)) + + checkAnswerWhenOnAndExceptionWhenOff( + "select salary * 2 as new_salary, (new_salary + bonus) * 3 - new_salary * 2 as " + + s"new_income from $testTable where name = 'amy'", + Row(20000, 23000)) + + // should referring to the previously defined LCA + checkAnswerWhenOnAndExceptionWhenOff( + s"SELECT salary * 1.5 AS d, d, 10000 AS d FROM $testTable WHERE name = 'jen'", + Row(18000, 18000, 10000) + ) + } + + test("Duplicated lateral alias names - Project") { + def checkDuplicatedAliasErrorHelper(query: String, parameters: Map[String, String]): Unit = { + checkError( + exception = intercept[AnalysisException] {sql(query)}, + errorClass = "AMBIGUOUS_LATERAL_COLUMN_ALIAS", + sqlState = "42000", + parameters = parameters + ) + } + + // Has duplicated names but not referenced is fine + checkAnswer( + sql(s"SELECT salary AS d, bonus AS d FROM $testTable WHERE name = 'jen'"), + Row(12000, 1200) + ) + checkAnswer( + sql(s"SELECT salary AS d, d, 10000 AS d FROM 
$testTable WHERE name = 'jen'"), + Row(12000, 12000, 10000) + ) + checkAnswer( + sql(s"SELECT salary * 1.5 AS d, d, 10000 AS d FROM $testTable WHERE name = 'jen'"), + Row(18000, 18000, 10000) + ) + checkAnswer( + sql(s"SELECT salary + 1000 AS new_salary, new_salary * 1.0 AS new_salary " + + s"FROM $testTable WHERE name = 'jen'"), + Row(13000, 13000.0)) + + // Referencing duplicated names raises error + checkDuplicatedAliasErrorHelper( + s"SELECT salary * 1.5 AS d, d, 10000 AS d, d + 1 FROM $testTable", + parameters = Map("name" -> "`d`", "n" -> "2") + ) + checkDuplicatedAliasErrorHelper( + s"SELECT 10000 AS d, d * 1.0, salary * 1.5 AS d, d FROM $testTable", + parameters = Map("name" -> "`d`", "n" -> "2") + ) + checkDuplicatedAliasErrorHelper( + s"SELECT salary AS d, d + 1 AS d, d + 1 AS d FROM $testTable", + parameters = Map("name" -> "`d`", "n" -> "2") + ) + checkDuplicatedAliasErrorHelper( + s"SELECT salary * 1.5 AS d, d, bonus * 1.5 AS d, d + d FROM $testTable", + parameters = Map("name" -> "`d`", "n" -> "2") + ) + + checkAnswer( + sql( + s""" + |SELECT salary * 1.5 AS salary, salary, 10000 AS salary, salary + |FROM $testTable + |WHERE name = 'jen' + |""".stripMargin), + Row(18000, 12000, 10000, 12000) + ) + } + + test("Lateral alias conflicts with table column - Project") { + checkAnswer( + sql( + "select salary * 2 as salary, salary * 2 + bonus as " + + s"new_income from $testTable where name = 'amy'"), + Row(20000, 21000)) + + checkAnswer( + sql( + "select salary * 2 as salary, (salary + bonus) * 3 - (salary + bonus) as " + + s"new_income from $testTable where name = 'amy'"), + Row(20000, 22000)) + + checkAnswer( + sql( + "select salary * 2 as salary, (salary + bonus) * 2 as bonus, " + + s"salary + bonus as prev_income, prev_income + bonus + salary from $testTable" + + " where name = 'amy'"), + Row(20000, 22000, 11000, 22000)) + + checkAnswer( + sql(s"SELECT named_struct('joinYear', 2022) AS properties, properties.joinYear " + + s"FROM $testTable WHERE name = 
'amy'"), + Row(Row(2022), 2019)) + + checkAnswer( + sql(s"SELECT named_struct('name', 'someone') AS $testTable, $testTable.name " + + s"FROM $testTable WHERE name = 'amy'"), + Row(Row("someone"), "amy")) + } + + testOnAndOff("Lateral alias conflicts with OuterReference - Project") { + // an attribute can both be resolved as LCA and OuterReference + val query1 = + s""" + |SELECT * + |FROM range(1, 7) + |WHERE ( + | SELECT id2 + | FROM (SELECT 1 AS id, id + 1 AS id2)) > 5 + |ORDER BY id + |""".stripMargin + withLCAOff { checkAnswer(sql(query1), Row(5) :: Row(6) :: Nil) } + withLCAOn { checkAnswer(sql(query1), Seq.empty) } + + // an attribute can only be resolved as LCA + val query2 = + s""" + |SELECT * + |FROM range(1, 7) + |WHERE ( + | SELECT id2 + | FROM (SELECT 1 AS id1, id1 + 1 AS id2)) > 5 + |""".stripMargin + withLCAOff { + assert(intercept[AnalysisException] { sql(query2) } + .getErrorClass == "UNRESOLVED_COLUMN.WITHOUT_SUGGESTION") + } + withLCAOn { checkAnswer(sql(query2), Seq.empty) } + + // an attribute should only be resolved as OuterReference + val query3 = + s""" + |SELECT * + |FROM range(1, 7) outer_table + |WHERE ( + | SELECT id2 + | FROM (SELECT 1 AS id, outer_table.id + 1 AS id2)) > 5 + |""".stripMargin + withLCAOff { checkAnswer(sql(query3), Row(5) :: Row(6) :: Nil) } + withLCAOn { checkAnswer(sql(query3), Row(5) :: Row(6) :: Nil) } + + // a bit complex subquery that the id + 1 is first wrapped with OuterReference + // test if lca rule strips the OuterReference and resolves to lateral alias + val query4 = + s""" + |SELECT * + |FROM range(1, 7) + |WHERE ( + | SELECT id2 + | FROM (SELECT dept * 2.0 AS id, id + 1 AS id2 FROM $testTable)) > 5 + |ORDER BY id + |""".stripMargin + withLCAOff { intercept[AnalysisException] { sql(query4) } } // surprisingly can't run .. 
+    withLCAOn {
+      val analyzedPlan = sql(query4).queryExecution.analyzed
+      assert(!analyzedPlan.containsPattern(OUTER_REFERENCE))
+      // but running it triggers exception
+      // checkAnswer(sql(query4), Range(1, 7).map(Row(_)))
+    }
+  }
+  // TODO: more tests on LCA in subquery
+
+  test("Lateral alias of a complex type - Project") {
+    checkAnswer(
+      sql("SELECT named_struct('a', 1) AS foo, foo.a + 1 AS bar, bar + 1"),
+      Row(Row(1), 2, 3))
+
+    checkAnswer(
+      sql("SELECT named_struct('a', named_struct('b', 1)) AS foo, foo.a.b + 1 AS bar"),
+      Row(Row(Row(1)), 2)
+    )
+
+    checkAnswer(
+      sql("SELECT array(1, 2, 3) AS foo, foo[1] AS bar, bar + 1"),
+      Row(Seq(1, 2, 3), 2, 3)
+    )
+    checkAnswer(
+      sql("SELECT array(array(1, 2), array(1, 2, 3), array(100)) AS foo, foo[2][0] + 1 AS bar"),
+      Row(Seq(Seq(1, 2), Seq(1, 2, 3), Seq(100)), 101)
+    )
+    checkAnswer(
+      sql("SELECT array(named_struct('a', 1), named_struct('a', 2)) AS foo, foo[0].a + 1 AS bar"),
+      Row(Seq(Row(1), Row(2)), 2)
+    )
+
+    checkAnswer(
+      sql("SELECT map('a', 1, 'b', 2) AS foo, foo['b'] AS bar, bar + 1"),
+      Row(Map("a" -> 1, "b" -> 2), 2, 3)
+    )
+  }
+
+  test("Lateral alias reference attribute further be used by upper plan - Project") {
+    // this is out of the scope of lateral alias project functionality requirements, but naturally
+    // supported by the current design
+    checkAnswer(
+      sql(s"SELECT properties AS new_properties, new_properties.joinYear AS new_join_year " +
+        s"FROM $testTable WHERE dept = 1 ORDER BY new_join_year DESC"),
+      Row(Row(2020, "B"), 2020) :: Row(Row(2019, "A"), 2019) :: Nil
+    )
+  }
+
+  test("Lateral alias chaining - Project") {
+    checkAnswer(
+      sql(
+        s"""
+           |SELECT bonus * 1.1 AS new_bonus, salary + new_bonus AS new_base,
+           | new_base * 1.1 AS new_total, new_total - new_base AS r,
+           | new_total - r
+           |FROM $testTable WHERE name = 'cathy'
+           |""".stripMargin),
+      Row(1320, 10320, 11352, 1032, 10320)
+    )
+
+    checkAnswer(
+      sql("SELECT 1 AS a, a + 1 AS b, b - 1, b + 1 AS c, c + 1 AS d, d - a AS e, e + 1"),
+      Row(1, 2, 1, 3, 4, 3, 4)
+    )
+  }
+}
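The resolution priorities that `LateralColumnAliasSuite` checks for the `Project` case can be summarized in a small, self-contained model. The sketch below is a hypothetical illustration and not Spark's `ResolveLateralColumnAlias` rule. It evaluates one SELECT list over a single input row instead of rewriting a logical plan. The names `LateralAliasModel`, `Item`, `AmbiguousLateralAlias` and `UnresolvedColumn` are invented for this sketch. It assumes exact (case-sensitive) name matching and models only what the tests show:

- A table column always wins over a lateral alias.
- A reference resolves to an alias defined earlier in the same list.
- Referencing a name that more than one earlier alias defines is an error. This mirrors `AMBIGUOUS_LATERAL_COLUMN_ALIAS` and its `name` and `n` parameters.

Outer references, nested fields and the `Aggregate` case are out of scope.

```scala
import scala.collection.mutable.ArrayBuffer

// Hypothetical model of implicit lateral column alias (LCA) resolution in one
// SELECT list. Not Spark's implementation: it evaluates values directly over a
// single input row instead of rewriting a logical plan.
object LateralAliasModel {
  sealed trait Expr
  final case class Lit(value: Double) extends Expr
  final case class Ref(name: String) extends Expr
  final case class Add(left: Expr, right: Expr) extends Expr
  final case class Mul(left: Expr, right: Expr) extends Expr

  // One SELECT-list item: an expression with an optional alias.
  final case class Item(expr: Expr, alias: Option[String] = None)

  // Hypothetical error types; the first mirrors the `name` and `n` parameters
  // of the AMBIGUOUS_LATERAL_COLUMN_ALIAS error class checked in the suite.
  final case class AmbiguousLateralAlias(name: String, n: Int)
    extends Exception(s"Lateral column alias `$name` is ambiguous: $n matches")
  final case class UnresolvedColumn(name: String)
    extends Exception(s"Cannot resolve column `$name`")

  def select(
      row: Map[String, Double],
      items: Seq[Item],
      lcaEnabled: Boolean = true): Seq[Double] = {
    // Aliases defined so far, in SELECT-list order, with their values.
    val defined = ArrayBuffer.empty[(String, Double)]

    def eval(e: Expr): Double = e match {
      case Lit(v) => v
      // A table column has higher priority than any lateral alias.
      case Ref(n) if row.contains(n) => row(n)
      case Ref(n) =>
        // Only aliases defined earlier in the list are visible.
        val matches = defined.filter(_._1 == n)
        if (!lcaEnabled || matches.isEmpty) throw UnresolvedColumn(n)
        else if (matches.size > 1) throw AmbiguousLateralAlias(n, matches.size)
        else matches.head._2
      case Add(l, r) => eval(l) + eval(r)
      case Mul(l, r) => eval(l) * eval(r)
    }

    items.map { item =>
      val value = eval(item.expr)
      // Register the alias only after its own expression is evaluated,
      // so an item never sees itself.
      item.alias.foreach(a => defined.append((a, value)))
      value
    }
  }
}
```

For example, with the `amy` row (`salary` 10000, `bonus` 1000), the model reproduces `Row(20000, 21000)` for `select salary * 2 as salary, salary * 2 + bonus as new_income`. The `salary` in the second item still reads the table column, not the alias.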