This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git

The following commit(s) were added to refs/heads/branch-3.5 by this push:
     new 5af0819654a  [SPARK-41086][SQL] Use DataFrame ID to semantically validate CollectMetrics
5af0819654a is described below

commit 5af0819654aca896d73c16875b07b2143cb1132c
Author: Rui Wang <rui.w...@databricks.com>
AuthorDate: Fri Sep 22 11:07:25 2023 +0800

[SPARK-41086][SQL] Use DataFrame ID to semantically validate CollectMetrics

### What changes were proposed in this pull request?

In the existing code, plan matching is used to detect whether two CollectMetrics nodes share the same name but differ semantically. That plan-matching approach is fragile. A better way to tackle this is to rely on the unique DataFrame id: the observe API is only available through the DataFrame API, and SQL has no equivalent syntax, so two CollectMetrics nodes are semantically the same if and only if they have the same name and the same DataFrame id.
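For illustration, a minimal sketch of the intended semantics through the public DataFrame API (not part of this patch; it assumes an existing `SparkSession` named `spark`):

```scala
import org.apache.spark.sql.functions.{count, lit}

val df = spark.range(10).toDF("v")

// Reusing one observed Dataset: both CollectMetrics nodes carry the same
// name AND the same DataFrame id, so the duplicate name is accepted.
val observed = df.observe("metrics", count(lit(1)).as("cnt"))
observed.union(observed).collect()

// Observing two distinct DataFrames under the same name: same name but
// different DataFrame ids, so analysis fails with DUPLICATED_METRICS_NAME.
val left = df.observe("metrics", count(lit(1)).as("cnt"))
val right = spark.range(10).toDF("v").observe("metrics", count(lit(1)).as("cnt"))
// left.union(right).collect()   // would throw AnalysisException
```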
### Why are the changes needed?

This replaces a fragile plan-matching approach with a more stable one.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Unit tests.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #43010 from amaliujia/another_approch_for_collect_metrics.

Authored-by: Rui Wang <rui.w...@databricks.com>
Signed-off-by: Wenchen Fan <wenc...@databricks.com>
(cherry picked from commit 7c3c7c5a4bd94c9e05b5e680a5242c2485875633)
Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../sql/connect/planner/SparkConnectPlanner.scala  |  6 +--
 python/pyspark/sql/connect/plan.py                 |  1 +
 .../spark/sql/catalyst/analysis/Analyzer.scala     |  4 +-
 .../sql/catalyst/analysis/CheckAnalysis.scala      | 36 ++------------
 .../plans/logical/basicLogicalOperators.scala      |  3 +-
 .../sql/catalyst/analysis/AnalysisSuite.scala      | 55 +++++++++-------------
 .../main/scala/org/apache/spark/sql/Dataset.scala  |  2 +-
 .../spark/sql/execution/SparkStrategies.scala      |  2 +-
 8 files changed, 35 insertions(+), 74 deletions(-)

diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 641dfc5dcd3..50a55f5e641 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -164,7 +164,7 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
       case proto.Relation.RelTypeCase.CACHED_REMOTE_RELATION =>
         transformCachedRemoteRelation(rel.getCachedRemoteRelation)
       case proto.Relation.RelTypeCase.COLLECT_METRICS =>
-        transformCollectMetrics(rel.getCollectMetrics)
+        transformCollectMetrics(rel.getCollectMetrics, rel.getCommon.getPlanId)
       case proto.Relation.RelTypeCase.PARSE => transformParse(rel.getParse)
       case proto.Relation.RelTypeCase.RELTYPE_NOT_SET =>
         throw new IndexOutOfBoundsException("Expected Relation to be set, but is empty.")
@@ -1054,12 +1054,12 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
       numPartitionsOpt)
   }

-  private def transformCollectMetrics(rel: proto.CollectMetrics): LogicalPlan = {
+  private def transformCollectMetrics(rel: proto.CollectMetrics, planId: Long): LogicalPlan = {
     val metrics = rel.getMetricsList.asScala.toSeq.map { expr =>
       Column(transformExpression(expr))
     }

-    CollectMetrics(rel.getName, metrics.map(_.named), transformRelation(rel.getInput))
+    CollectMetrics(rel.getName, metrics.map(_.named), transformRelation(rel.getInput), planId)
   }

   private def transformDeduplicate(rel: proto.Deduplicate): LogicalPlan = {
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index 196b1f119ba..b7ea1f94993 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -1196,6 +1196,7 @@ class CollectMetrics(LogicalPlan):
         assert self._child is not None

         plan = proto.Relation()
+        plan.common.plan_id = self._child._plan_id
         plan.collect_metrics.input.CopyFrom(self._child.plan(session))
         plan.collect_metrics.name = self._name
         plan.collect_metrics.metrics.extend([self.col_to_expr(x, session) for x in self._exprs])
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 6c5d19f58ac..8e3c9b30c61 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -3796,9 +3796,9 @@ object CleanupAliases extends Rule[LogicalPlan] with AliasHelper {
         Window(cleanedWindowExprs, partitionSpec.map(trimAliases),
           orderSpec.map(trimAliases(_).asInstanceOf[SortOrder]), child)

-      case CollectMetrics(name, metrics, child) =>
+      case CollectMetrics(name, metrics, child, dataframeId) =>
         val cleanedMetrics = metrics.map(trimNonTopLevelAliases)
-        CollectMetrics(name, cleanedMetrics, child)
+        CollectMetrics(name, cleanedMetrics, child, dataframeId)

       case Unpivot(ids, values, aliases, variableColumnName, valueColumnNames, child) =>
         val cleanedIds = ids.map(_.map(trimNonTopLevelAliases))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 139fa34a1df..511f3622e7e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -484,7 +484,7 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
         groupingExprs.foreach(checkValidGroupingExprs)
         aggregateExprs.foreach(checkValidAggregateExpression)

-      case CollectMetrics(name, metrics, _) =>
+      case CollectMetrics(name, metrics, _, _) =>
         if (name == null || name.isEmpty) {
           operator.failAnalysis(
             errorClass = "INVALID_OBSERVED_METRICS.MISSING_NAME",
@@ -1075,17 +1075,15 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
    * are allowed (e.g. self-joins).
    */
   private def checkCollectedMetrics(plan: LogicalPlan): Unit = {
-    val metricsMap = mutable.Map.empty[String, LogicalPlan]
+    val metricsMap = mutable.Map.empty[String, CollectMetrics]
     def check(plan: LogicalPlan): Unit = plan.foreach { node =>
       node match {
-        case metrics @ CollectMetrics(name, _, _) =>
-          val simplifiedMetrics = simplifyPlanForCollectedMetrics(metrics.canonicalized)
+        case metrics @ CollectMetrics(name, _, _, dataframeId) =>
          metricsMap.get(name) match {
            case Some(other) =>
-              val simplifiedOther = simplifyPlanForCollectedMetrics(other.canonicalized)
              // Exact duplicates are allowed. They can be the result
              // of a CTE that is used multiple times or a self join.
-              if (simplifiedMetrics != simplifiedOther) {
+              if (dataframeId != other.dataframeId) {
                failAnalysis(
                  errorClass = "DUPLICATED_METRICS_NAME",
                  messageParameters = Map("metricName" -> name))
@@ -1104,32 +1102,6 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
     check(plan)
   }

-  /**
-   * This method is only used for checking collected metrics. This method tries to
-   * remove extra project which only re-assign expr ids from the plan so that we can identify exact
-   * duplicates metric definition.
-   */
-  def simplifyPlanForCollectedMetrics(plan: LogicalPlan): LogicalPlan = {
-    plan.resolveOperators {
-      case p: Project if p.projectList.size == p.child.output.size =>
-        val assignExprIdOnly = p.projectList.zipWithIndex.forall {
-          case (Alias(attr: AttributeReference, _), index) =>
-            // The input plan of this method is already canonicalized. The attribute id becomes the
-            // ordinal of this attribute in the child outputs. So an alias-only Project means the
-            // the id of the aliased attribute is the same as its index in the project list.
-            attr.exprId.id == index
-          case (left: AttributeReference, index) =>
-            left.exprId.id == index
-          case _ => false
-        }
-        if (assignExprIdOnly) {
-          p.child
-        } else {
-          p
-        }
-    }
-  }
-
   /**
    * Validates to make sure the outer references appearing inside the subquery
    * are allowed.
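The hunks above are the heart of the change: the canonicalized-plan comparison in checkCollectedMetrics becomes a plain id comparison, and simplifyPlanForCollectedMetrics is no longer needed. As a standalone sketch with hypothetical types (not the actual Spark source), the validation reduces to:

```scala
import scala.collection.mutable

object DuplicateMetricsCheck {
  // A stand-in for CollectMetrics: just the fields the check needs.
  final case class MetricsNode(name: String, dataframeId: Long)

  // Returns the first metric name that is reused across different DataFrame
  // ids; exact duplicates (same name, same id) are allowed.
  def findConflict(nodes: Seq[MetricsNode]): Option[String] = {
    val seen = mutable.Map.empty[String, Long]
    // getOrElseUpdate registers the first id seen for a name; any later node
    // with the same name but a different id is a conflict.
    nodes.collectFirst {
      case m if seen.getOrElseUpdate(m.name, m.dataframeId) != m.dataframeId => m.name
    }
  }

  def main(args: Array[String]): Unit = {
    // CTE reuse / self-union: the same node observed twice is fine.
    assert(findConflict(Seq(MetricsNode("evt1", 0), MetricsNode("evt1", 0))).isEmpty)
    // Same name from two different DataFrames triggers DUPLICATED_METRICS_NAME.
    assert(findConflict(Seq(MetricsNode("evt1", 0), MetricsNode("evt1", 1))).contains("evt1"))
  }
}
```

Keying the map by metric name and comparing only the recorded DataFrame id avoids canonicalizing and traversing plans entirely, which is what made the previous approach fragile.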
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index 4bb830662a3..96b67fc52e0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -1952,7 +1952,8 @@ trait SupportsSubquery extends LogicalPlan
 case class CollectMetrics(
     name: String,
     metrics: Seq[NamedExpression],
-    child: LogicalPlan)
+    child: LogicalPlan,
+    dataframeId: Long)
   extends UnaryNode {

   override lazy val resolved: Boolean = {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index 57b37e67b32..802b6d471a6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -771,34 +771,35 @@ class AnalysisSuite extends AnalysisTest with Matchers {
     val literal = Literal(1).as("lit")

     // Ok
-    assert(CollectMetrics("event", literal :: sum :: random_sum :: Nil, testRelation).resolved)
+    assert(CollectMetrics("event", literal :: sum :: random_sum :: Nil, testRelation, 0).resolved)

     // Bad name
-    assert(!CollectMetrics("", sum :: Nil, testRelation).resolved)
+    assert(!CollectMetrics("", sum :: Nil, testRelation, 0).resolved)
     assertAnalysisErrorClass(
-      CollectMetrics("", sum :: Nil, testRelation),
+      CollectMetrics("", sum :: Nil, testRelation, 0),
       expectedErrorClass = "INVALID_OBSERVED_METRICS.MISSING_NAME",
       expectedMessageParameters = Map(
-        "operator" -> "'CollectMetrics , [sum(a#x) AS sum#xL]\n+- LocalRelation <empty>, [a#x]\n")
+        "operator" ->
+          "'CollectMetrics , [sum(a#x) AS sum#xL], 0\n+- LocalRelation <empty>, [a#x]\n")
     )

     // No columns
-    assert(!CollectMetrics("evt", Nil, testRelation).resolved)
+    assert(!CollectMetrics("evt", Nil, testRelation, 0).resolved)

     def checkAnalysisError(exprs: Seq[NamedExpression], errors: String*): Unit = {
-      assertAnalysisError(CollectMetrics("event", exprs, testRelation), errors)
+      assertAnalysisError(CollectMetrics("event", exprs, testRelation, 0), errors)
     }

     // Unwrapped attribute
     assertAnalysisErrorClass(
-      CollectMetrics("event", a :: Nil, testRelation),
+      CollectMetrics("event", a :: Nil, testRelation, 0),
       expectedErrorClass = "INVALID_OBSERVED_METRICS.NON_AGGREGATE_FUNC_ARG_IS_ATTRIBUTE",
       expectedMessageParameters = Map("expr" -> "\"a\"")
     )

     // Unwrapped non-deterministic expression
     assertAnalysisErrorClass(
-      CollectMetrics("event", Rand(10).as("rnd") :: Nil, testRelation),
+      CollectMetrics("event", Rand(10).as("rnd") :: Nil, testRelation, 0),
       expectedErrorClass = "INVALID_OBSERVED_METRICS.NON_AGGREGATE_FUNC_ARG_IS_NON_DETERMINISTIC",
       expectedMessageParameters = Map("expr" -> "\"rand(10) AS rnd\"")
     )
@@ -808,7 +809,7 @@ class AnalysisSuite extends AnalysisTest with Matchers {
       CollectMetrics(
         "event",
         Sum(a).toAggregateExpression(isDistinct = true).as("sum") :: Nil,
-        testRelation),
+        testRelation, 0),
       expectedErrorClass =
         "INVALID_OBSERVED_METRICS.AGGREGATE_EXPRESSION_WITH_DISTINCT_UNSUPPORTED",
       expectedMessageParameters = Map("expr" -> "\"sum(DISTINCT a) AS sum\"")
@@ -819,7 +820,7 @@ class AnalysisSuite extends AnalysisTest with Matchers {
       CollectMetrics(
         "event",
         Sum(Sum(a).toAggregateExpression()).toAggregateExpression().as("sum") :: Nil,
-        testRelation),
+        testRelation, 0),
       expectedErrorClass = "INVALID_OBSERVED_METRICS.NESTED_AGGREGATES_UNSUPPORTED",
       expectedMessageParameters = Map("expr" -> "\"sum(sum(a)) AS sum\"")
     )
@@ -830,7 +831,7 @@ class AnalysisSuite extends AnalysisTest with Matchers {
       WindowSpecDefinition(Nil, a.asc :: Nil,
         SpecifiedWindowFrame(RowFrame, UnboundedPreceding, CurrentRow)))
     assertAnalysisErrorClass(
-      CollectMetrics("event", windowExpr.as("rn") :: Nil, testRelation),
+      CollectMetrics("event", windowExpr.as("rn") :: Nil, testRelation, 0),
       expectedErrorClass = "INVALID_OBSERVED_METRICS.WINDOW_EXPRESSIONS_UNSUPPORTED",
       expectedMessageParameters = Map(
         "expr" ->
@@ -848,14 +849,14 @@ class AnalysisSuite extends AnalysisTest with Matchers {

     // Same result - duplicate names are allowed
     assertAnalysisSuccess(Union(
-      CollectMetrics("evt1", count :: Nil, testRelation) ::
-      CollectMetrics("evt1", count :: Nil, testRelation) :: Nil))
+      CollectMetrics("evt1", count :: Nil, testRelation, 0) ::
+      CollectMetrics("evt1", count :: Nil, testRelation, 0) :: Nil))

     // Same children, structurally different metrics - fail
     assertAnalysisErrorClass(
       Union(
-        CollectMetrics("evt1", count :: Nil, testRelation) ::
-        CollectMetrics("evt1", sum :: Nil, testRelation) :: Nil),
+        CollectMetrics("evt1", count :: Nil, testRelation, 0) ::
+        CollectMetrics("evt1", sum :: Nil, testRelation, 1) :: Nil),
       expectedErrorClass = "DUPLICATED_METRICS_NAME",
       expectedMessageParameters = Map("metricName" -> "evt1")
     )
@@ -865,17 +866,17 @@ class AnalysisSuite extends AnalysisTest with Matchers {
     val tblB = LocalRelation(b)
     assertAnalysisErrorClass(
       Union(
-        CollectMetrics("evt1", count :: Nil, testRelation) ::
-        CollectMetrics("evt1", count :: Nil, tblB) :: Nil),
+        CollectMetrics("evt1", count :: Nil, testRelation, 0) ::
+        CollectMetrics("evt1", count :: Nil, tblB, 1) :: Nil),
       expectedErrorClass = "DUPLICATED_METRICS_NAME",
       expectedMessageParameters = Map("metricName" -> "evt1")
     )

     // Subquery different tree - fail
-    val subquery = Aggregate(Nil, sum :: Nil, CollectMetrics("evt1", count :: Nil, testRelation))
+    val subquery = Aggregate(Nil, sum :: Nil, CollectMetrics("evt1", count :: Nil, testRelation, 0))
     val query = Project(
       b :: ScalarSubquery(subquery, Nil).as("sum") :: Nil,
-      CollectMetrics("evt1", count :: Nil, tblB))
+      CollectMetrics("evt1", count :: Nil, tblB, 1))
     assertAnalysisErrorClass(
       query,
       expectedErrorClass = "DUPLICATED_METRICS_NAME",
@@ -887,7 +888,7 @@ class AnalysisSuite extends AnalysisTest with Matchers {
       case a: AggregateExpression => a.copy(filter = Some(true))
     }.asInstanceOf[NamedExpression]
     assertAnalysisErrorClass(
-      CollectMetrics("evt1", sumWithFilter :: Nil, testRelation),
+      CollectMetrics("evt1", sumWithFilter :: Nil, testRelation, 0),
       expectedErrorClass =
         "INVALID_OBSERVED_METRICS.AGGREGATE_EXPRESSION_WITH_FILTER_UNSUPPORTED",
       expectedMessageParameters = Map("expr" -> "\"sum(a) FILTER (WHERE true) AS sum\"")
@@ -1667,18 +1668,4 @@ class AnalysisSuite extends AnalysisTest with Matchers {
       checkAnalysis(ident2.select($"a"), testRelation.select($"a").analyze)
     }
   }
-
-  test("simplifyPlanForCollectedMetrics should handle non alias-only project case") {
-    val inner = Project(
-      Seq(
-        Alias(testRelation2.output(0), "a")(),
-        testRelation2.output(1),
-        Alias(testRelation2.output(2), "c")(),
-        testRelation2.output(3),
-        testRelation2.output(4)
-      ),
-      testRelation2)
-    val actualPlan = getAnalyzer.simplifyPlanForCollectedMetrics(inner.canonicalized)
-    assert(actualPlan == testRelation2.canonicalized)
-  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index fd8421fa096..e047b927b90 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -2189,7 +2189,7 @@ class Dataset[T] private[sql](
    */
   @varargs
   def observe(name: String, expr: Column, exprs: Column*): Dataset[T] = withTypedPlan {
-    CollectMetrics(name, (expr +: exprs).map(_.named), logicalPlan)
+    CollectMetrics(name, (expr +: exprs).map(_.named), logicalPlan, id)
   }

   /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 903565a6d59..d851eacd5ab 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -935,7 +935,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         throw QueryExecutionErrors.ddlUnsupportedTemporarilyError("UPDATE TABLE")
       case _: MergeIntoTable =>
         throw QueryExecutionErrors.ddlUnsupportedTemporarilyError("MERGE INTO TABLE")
-      case logical.CollectMetrics(name, metrics, child) =>
+      case logical.CollectMetrics(name, metrics, child, _) =>
         execution.CollectMetricsExec(name, metrics, planLater(child)) :: Nil
       case WriteFiles(child, fileFormat, partitionColumns, bucket, options, staticPartitions) =>
         WriteFilesExec(planLater(child), fileFormat, partitionColumns, bucket, options,

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org