This is an automated email from the ASF dual-hosted git repository. yamamuro pushed a commit to branch branch-3.1 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.1 by this push: new 448b8d0 [SPARK-34749][SQL][3.1] Simplify ResolveCreateNamedStruct 448b8d0 is described below commit 448b8d07df41040058c21e6102406e1656727599 Author: Wenchen Fan <wenc...@databricks.com> AuthorDate: Thu Mar 18 07:44:11 2021 +0900 [SPARK-34749][SQL][3.1] Simplify ResolveCreateNamedStruct backports https://github.com/apache/spark/pull/31843 ### What changes were proposed in this pull request? This is a follow-up of https://github.com/apache/spark/pull/31808 and simplifies its fix to one line (excluding comments). ### Why are the changes needed? code simplification ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? N/A Closes #31867 from cloud-fan/backport. Authored-by: Wenchen Fan <wenc...@databricks.com> Signed-off-by: Takeshi Yamamuro <yamam...@apache.org> --- .../org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 2 -- .../spark/sql/catalyst/expressions/complexTypeCreator.scala | 10 +++++++++- .../sql/catalyst/expressions/complexTypeExtractors.scala | 11 +---------- .../spark/sql/catalyst/parser/ExpressionParserSuite.scala | 2 +- 4 files changed, 11 insertions(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index f98f33b..f4cdeab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -3840,8 +3840,6 @@ object ResolveCreateNamedStruct extends Rule[LogicalPlan] { val children = e.children.grouped(2).flatMap { case Seq(NamePlaceholder, e: NamedExpression) if e.resolved => Seq(Literal(e.name), e) - case Seq(NamePlaceholder, e: ExtractValue) if e.resolved && e.name.isDefined => - Seq(Literal(e.name.get), e) case kv => kv } diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index cb59fbd..1779d41 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.{Resolver, TypeCheckResult, TypeCoercion, UnresolvedExtractValue} +import org.apache.spark.sql.catalyst.analysis.{Resolver, TypeCheckResult, TypeCoercion, UnresolvedAttribute, UnresolvedExtractValue} import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.{FUNC_ALIAS, FunctionBuilder} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ @@ -336,6 +336,14 @@ object CreateStruct { */ def apply(children: Seq[Expression]): CreateNamedStruct = { CreateNamedStruct(children.zipWithIndex.flatMap { + // For multi-part column name like `struct(a.b.c)`, it may be resolved into: + // 1. Attribute if `a.b.c` is simply a qualified column name. + // 2. GetStructField if `a.b` refers to a struct-type column. + // 3. GetArrayStructFields if `a.b` refers to an array-of-struct-type column. + // 4. GetMapValue if `a.b` refers to a map-type column. + // We should always use the last part of the column name (`c` in the above example) as the + // alias name inside CreateNamedStruct. 
+ case (u: UnresolvedAttribute, _) => Seq(Literal(u.nameParts.last), u) case (e: NamedExpression, _) if e.resolved => Seq(Literal(e.name), e) case (e: NamedExpression, _) => Seq(NamePlaceholder, e) case (e, index) => Seq(Literal(s"col${index + 1}"), e) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 9b80140..ef247ef 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -94,10 +94,7 @@ object ExtractValue { } } -trait ExtractValue extends Expression { - // The name that is used to extract the value. - def name: Option[String] -} +trait ExtractValue extends Expression /** * Returns the value of fields in the Struct `child`. @@ -163,7 +160,6 @@ case class GetArrayStructFields( override def dataType: DataType = ArrayType(field.dataType, containsNull) override def toString: String = s"$child.${field.name}" override def sql: String = s"${child.sql}.${quoteIdentifier(field.name)}" - override def name: Option[String] = Some(field.name) protected override def nullSafeEval(input: Any): Any = { val array = input.asInstanceOf[ArrayData] @@ -241,7 +237,6 @@ case class GetArrayItem( override def toString: String = s"$child[$ordinal]" override def sql: String = s"${child.sql}[${ordinal.sql}]" - override def name: Option[String] = None override def left: Expression = child override def right: Expression = ordinal @@ -461,10 +456,6 @@ case class GetMapValue( override def toString: String = s"$child[$key]" override def sql: String = s"${child.sql}[${key.sql}]" - override def name: Option[String] = key match { - case NonNullLiteral(s, StringType) => Some(s.toString) - case _ => None - } override def left: Expression = child override def right: Expression = 
key diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala index 9f6a76b..9711cdc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala @@ -425,7 +425,7 @@ class ExpressionParserSuite extends AnalysisTest { assertEqual("(a + b).b", ('a + 'b).getField("b")) // This will fail analysis. assertEqual( "struct(a, b).b", - namedStruct(NamePlaceholder, 'a, NamePlaceholder, 'b).getField("b")) + namedStruct(Literal("a"), 'a, Literal("b"), 'b).getField("b")) } test("reference") { --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org