This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new db0e8224e1e [SPARK-42548][SQL] Add ReferenceAllColumns to skip rewriting attributes
db0e8224e1e is described below

commit db0e8224e1e4c928fa2f7046ae13b6aad2b8cad6
Author: ulysses-you <ulyssesyo...@gmail.com>
AuthorDate: Tue Feb 28 15:52:53 2023 +0800

    [SPARK-42548][SQL] Add ReferenceAllColumns to skip rewriting attributes

    ### What changes were proposed in this pull request?

    Add a new trait `ReferenceAllColumns` that overrides `references` using the children's
    output, so that such plans can be skipped when rewriting attributes in
    `transformUpWithNewOutput`.

    ### Why are the changes needed?

    There are two reasons for this new trait:
    1. it is dangerous to call `references` on an unresolved plan whose references all come
       from its children
    2. it is unnecessary to rewrite the attributes of a plan whose references all come from
       its children

    ### Does this PR introduce _any_ user-facing change?

    No, it prevents a potential bug.

    ### How was this patch tested?

    Added a test; passes CI.

    Closes #40154 from ulysses-you/references.

    Authored-by: ulysses-you <ulyssesyo...@gmail.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../spark/sql/catalyst/plans/QueryPlan.scala        | 37 +++++++++++++---------
 .../sql/catalyst/plans/ReferenceAllColumns.scala    | 34 ++++++++++++++++++++
 .../plans/logical/ScriptTransformation.scala        |  8 ++---
 .../spark/sql/catalyst/plans/logical/object.scala   |  8 ++---
 .../sql/catalyst/analysis/TypeCoercionSuite.scala   | 18 +++++++++++
 .../org/apache/spark/sql/execution/objects.scala    |  8 ++---
 6 files changed, 81 insertions(+), 32 deletions(-)
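As a rough illustration of the new trait described above (not part of this commit; `MyPassThroughNode` is a made-up operator name), a node whose references are exactly its child's columns can now mix the trait in instead of hand-rolling a `references` override:

    import org.apache.spark.sql.catalyst.expressions.Attribute
    import org.apache.spark.sql.catalyst.plans.ReferenceAllColumns
    import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryNode}

    // Hypothetical operator, for illustration only: it forwards all of its child's columns,
    // so `references` can be derived from the child's output via ReferenceAllColumns.
    case class MyPassThroughNode(child: LogicalPlan)
      extends UnaryNode with ReferenceAllColumns[LogicalPlan] {
      override def output: Seq[Attribute] = child.output
      override protected def withNewChildInternal(newChild: LogicalPlan): MyPassThroughNode =
        copy(child = newChild)
    }

Because the trait computes `references` from `children.flatMap(_.outputSet)` and marks it `final`, `transformUpWithNewOutput` can pattern-match on `ReferenceAllColumns` and skip the attribute-rewrite step for such nodes, as the QueryPlan.scala hunk below shows.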
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index 90d1bd805cb..ae5e9789dd9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -297,21 +297,28 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]]
           newChild
         }
 
-      val attrMappingForCurrentPlan = attrMapping.filter {
-        // The `attrMappingForCurrentPlan` is used to replace the attributes of the
-        // current `plan`, so the `oldAttr` must be part of `plan.references`.
-        case (oldAttr, _) => plan.references.contains(oldAttr)
-      }
-
-      if (attrMappingForCurrentPlan.nonEmpty) {
-        assert(!attrMappingForCurrentPlan.groupBy(_._1.exprId)
-          .exists(_._2.map(_._2.exprId).distinct.length > 1),
-          "Found duplicate rewrite attributes")
-
-        val attributeRewrites = AttributeMap(attrMappingForCurrentPlan)
-        // Using attrMapping from the children plans to rewrite their parent node.
-        // Note that we shouldn't rewrite a node using attrMapping from its sibling nodes.
-        newPlan = newPlan.rewriteAttrs(attributeRewrites)
+      plan match {
+        case _: ReferenceAllColumns[_] =>
+          // It's dangerous to call `references` on an unresolved `ReferenceAllColumns`, and
+          // it's unnecessary to rewrite its attributes, as they all come from its children.
+
+        case _ =>
+          val attrMappingForCurrentPlan = attrMapping.filter {
+            // The `attrMappingForCurrentPlan` is used to replace the attributes of the
+            // current `plan`, so the `oldAttr` must be part of `plan.references`.
+            case (oldAttr, _) => plan.references.contains(oldAttr)
+          }
+
+          if (attrMappingForCurrentPlan.nonEmpty) {
+            assert(!attrMappingForCurrentPlan.groupBy(_._1.exprId)
+              .exists(_._2.map(_._2.exprId).distinct.length > 1),
+              "Found duplicate rewrite attributes")
+
+            val attributeRewrites = AttributeMap(attrMappingForCurrentPlan)
+            // Using attrMapping from the children plans to rewrite their parent node.
+            // Note that we shouldn't rewrite a node using attrMapping from its sibling nodes.
+            newPlan = newPlan.rewriteAttrs(attributeRewrites)
+          }
       }
 
       val (planAfterRule, newAttrMapping) = CurrentOrigin.withOrigin(origin) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/ReferenceAllColumns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/ReferenceAllColumns.scala
new file mode 100644
index 00000000000..613e2a06f49
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/ReferenceAllColumns.scala
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.plans
+
+import org.apache.spark.sql.catalyst.expressions.AttributeSet
+
+/**
+ * A trait that overrides `references` using the children's output.
+ *
+ * It's unnecessary to rewrite attributes for `ReferenceAllColumns` since all of its references
+ * come from its children.
+ *
+ * Note: the only place that uses this is [[QueryPlan.transformUpWithNewOutput]].
+ */
+trait ReferenceAllColumns[PlanType <: QueryPlan[PlanType]] { self: QueryPlan[PlanType] =>
+
+  @transient
+  override final lazy val references: AttributeSet = AttributeSet(children.flatMap(_.outputSet))
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala
index 5fe5dc37371..e6ebf981bc4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala
@@ -17,7 +17,8 @@
 
 package org.apache.spark.sql.catalyst.plans.logical
 
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet}
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.plans.ReferenceAllColumns
 
 /**
  * Transforms the input by forking and running the specified script.
@@ -30,10 +31,7 @@ case class ScriptTransformation(
     script: String,
     output: Seq[Attribute],
     child: LogicalPlan,
-    ioschema: ScriptInputOutputSchema) extends UnaryNode {
-  @transient
-  override lazy val references: AttributeSet = AttributeSet(child.output)
-
+    ioschema: ScriptInputOutputSchema) extends UnaryNode with ReferenceAllColumns[LogicalPlan] {
   override protected def withNewChildInternal(newChild: LogicalPlan): ScriptTransformation =
     copy(child = newChild)
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
index b27c650cfb2..c6a4779374d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer
 import org.apache.spark.sql.catalyst.encoders._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.objects.Invoke
+import org.apache.spark.sql.catalyst.plans.ReferenceAllColumns
 import org.apache.spark.sql.catalyst.trees.TreePattern._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode}
@@ -64,13 +65,8 @@ trait ObjectProducer extends LogicalPlan {
  * A trait for logical operators that consumes domain objects as input.
  * The output of its child must be a single-field row containing the input object.
  */
-trait ObjectConsumer extends UnaryNode {
+trait ObjectConsumer extends UnaryNode with ReferenceAllColumns[LogicalPlan] {
   assert(child.output.length == 1)
-
-  // This operator always need all columns of its child, even it doesn't reference to.
-  @transient
-  override lazy val references: AttributeSet = child.outputSet
-
   def inputObjAttr: Attribute = child.output.head
 }
 
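For context on the `ObjectConsumer` change above (and the matching `ObjectConsumerExec` change further down): these operators always need every column of their single-column child, even when their own expressions never mention it, and the inherited `references` preserves exactly that. A quick check under assumed setup, reusing the hypothetical `MyPassThroughNode` from the earlier sketch:

    import org.apache.spark.sql.catalyst.expressions.AttributeReference
    import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
    import org.apache.spark.sql.types.BinaryType

    // The trait derives `references` from the child, so the child's whole output stays
    // referenced, matching the behaviour of the removed hand-written overrides.
    val child = LocalRelation(AttributeReference("obj", BinaryType)())
    val consumer = MyPassThroughNode(child)
    assert(consumer.references == child.outputSet)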
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
index adce553d194..e30cce23136 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCoercion._
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.ReferenceAllColumns
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}
 import org.apache.spark.sql.internal.SQLConf
@@ -1740,6 +1741,16 @@ class TypeCoercionSuite extends TypeCoercionSuiteBase {
       }
     }
   }
+
+  test("SPARK-32638: Add ReferenceAllColumns to skip rewriting attributes") {
+    val t1 = LocalRelation(AttributeReference("c", DecimalType(1, 0))())
+    val t2 = LocalRelation(AttributeReference("c", DecimalType(2, 0))())
+    val unresolved = t1.union(t2).select(UnresolvedStar(None))
+    val referenceAllColumns = FakeReferenceAllColumns(unresolved)
+    val wp1 = widenSetOperationTypes(referenceAllColumns.select(t1.output.head))
+    assert(wp1.isInstanceOf[Project])
+    assert(wp1.expressions.forall(!_.exists(_ == t1.output.head)))
+  }
 }
 
 
@@ -1798,3 +1809,10 @@ object TypeCoercionSuite {
     copy(left = newLeft, right = newRight)
   }
 }
+
+case class FakeReferenceAllColumns(child: LogicalPlan)
+  extends UnaryNode with ReferenceAllColumns[LogicalPlan] {
+  override def output: Seq[Attribute] = child.output
+  override protected def withNewChildInternal(newChild: LogicalPlan): LogicalPlan =
+    copy(child = newChild)
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
index bda592ff929..c8d575016fc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
@@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.objects.Invoke
+import org.apache.spark.sql.catalyst.plans.ReferenceAllColumns
 import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, FunctionUtils, LogicalGroupState}
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution.python.BatchIterator
@@ -58,13 +59,8 @@ trait ObjectProducerExec extends SparkPlan {
 /**
  * Physical version of `ObjectConsumer`.
  */
-trait ObjectConsumerExec extends UnaryExecNode {
+trait ObjectConsumerExec extends UnaryExecNode with ReferenceAllColumns[SparkPlan] {
   assert(child.output.length == 1)
-
-  // This operator always need all columns of its child, even it doesn't reference to.
-  @transient
-  override lazy val references: AttributeSet = child.outputSet
-
   def inputObjectType: DataType = child.output.head.dataType
 }
 
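Finally, a rough walk-through of the skip itself (assumed setup, again reusing the hypothetical `MyPassThroughNode`; the new test above exercises the same behaviour through `widenSetOperationTypes`):

    import org.apache.spark.sql.catalyst.expressions.AttributeReference
    import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
    import org.apache.spark.sql.types.IntegerType

    // A leaf swaps its output attribute and reports the (old -> new) mapping upwards.
    val oldAttr = AttributeReference("c", IntegerType)()
    val newAttr = oldAttr.newInstance()
    val plan = MyPassThroughNode(LocalRelation(oldAttr))
    val rewritten = plan.transformUpWithNewOutput {
      case _: LocalRelation => (LocalRelation(newAttr), Seq(oldAttr -> newAttr))
    }
    // The pass-through node above the leaf is matched as ReferenceAllColumns, so
    // transformUpWithNewOutput does not attempt to rewrite its attributes; its
    // `references` simply follow the new child output.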
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org