This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.5 by this push:
     new 5af0819654a [SPARK-41086][SQL] Use DataFrame ID to semantically validate CollectMetrics
5af0819654a is described below

commit 5af0819654aca896d73c16875b07b2143cb1132c
Author: Rui Wang <rui.w...@databricks.com>
AuthorDate: Fri Sep 22 11:07:25 2023 +0800

    [SPARK-41086][SQL] Use DataFrame ID to semantically validate CollectMetrics
    
    ### What changes were proposed in this pull request?
    
    In the existing code, plan matching is used to validate whether two
    CollectMetrics nodes share the same name but have different semantics.
    However, the plan-matching approach is fragile. A better way to tackle
    this is to use the unique DataFrame id, which works because the observe
    API is only available through the DataFrame API; SQL has no equivalent
    syntax.
    
    So two CollectMetrics nodes are semantically the same if and only if
    they have the same name and the same DataFrame id.
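    
    A minimal sketch of the semantics this enforces (hypothetical snippet,
    not part of the patch; the `spark` session, column names, and metric
    names are assumptions for illustration):
    
    ```scala
    import org.apache.spark.sql.functions._
    
    // Observations that share a name are fine when they come from the same
    // DataFrame (same DataFrame id), e.g. an observed DataFrame used twice:
    val observed = spark.range(10).observe("metric", count(lit(1)).as("cnt"))
    observed.union(observed)   // analyzes fine: exact duplicate
    
    // The same metric name attached to a different DataFrame (different
    // DataFrame id) now fails analysis:
    val other = spark.range(10).observe("metric", count(lit(1)).as("cnt"))
    // observed.union(other)   // would fail with DUPLICATED_METRICS_NAME
    ```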
    
    ### Why are the changes needed?
    
    This replaces the fragile plan-matching approach with a more stable one.
    
    ### Does this PR introduce _any_ user-facing change?
    
    NO
    
    ### How was this patch tested?
    
    UT
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    NO
    
    Closes #43010 from amaliujia/another_approch_for_collect_metrics.
    
    Authored-by: Rui Wang <rui.w...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
    (cherry picked from commit 7c3c7c5a4bd94c9e05b5e680a5242c2485875633)
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../sql/connect/planner/SparkConnectPlanner.scala  |  6 +--
 python/pyspark/sql/connect/plan.py                 |  1 +
 .../spark/sql/catalyst/analysis/Analyzer.scala     |  4 +-
 .../sql/catalyst/analysis/CheckAnalysis.scala      | 36 ++------------
 .../plans/logical/basicLogicalOperators.scala      |  3 +-
 .../sql/catalyst/analysis/AnalysisSuite.scala      | 55 +++++++++-------------
 .../main/scala/org/apache/spark/sql/Dataset.scala  |  2 +-
 .../spark/sql/execution/SparkStrategies.scala      |  2 +-
 8 files changed, 35 insertions(+), 74 deletions(-)

diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 641dfc5dcd3..50a55f5e641 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -164,7 +164,7 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
       case proto.Relation.RelTypeCase.CACHED_REMOTE_RELATION =>
         transformCachedRemoteRelation(rel.getCachedRemoteRelation)
       case proto.Relation.RelTypeCase.COLLECT_METRICS =>
-        transformCollectMetrics(rel.getCollectMetrics)
+        transformCollectMetrics(rel.getCollectMetrics, rel.getCommon.getPlanId)
       case proto.Relation.RelTypeCase.PARSE => transformParse(rel.getParse)
       case proto.Relation.RelTypeCase.RELTYPE_NOT_SET =>
        throw new IndexOutOfBoundsException("Expected Relation to be set, but is empty.")
@@ -1054,12 +1054,12 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
       numPartitionsOpt)
   }
 
-  private def transformCollectMetrics(rel: proto.CollectMetrics): LogicalPlan = {
+  private def transformCollectMetrics(rel: proto.CollectMetrics, planId: Long): LogicalPlan = {
     val metrics = rel.getMetricsList.asScala.toSeq.map { expr =>
       Column(transformExpression(expr))
     }
 
-    CollectMetrics(rel.getName, metrics.map(_.named), transformRelation(rel.getInput))
+    CollectMetrics(rel.getName, metrics.map(_.named), transformRelation(rel.getInput), planId)
   }
 
   private def transformDeduplicate(rel: proto.Deduplicate): LogicalPlan = {
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index 196b1f119ba..b7ea1f94993 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -1196,6 +1196,7 @@ class CollectMetrics(LogicalPlan):
         assert self._child is not None
 
         plan = proto.Relation()
+        plan.common.plan_id = self._child._plan_id
         plan.collect_metrics.input.CopyFrom(self._child.plan(session))
         plan.collect_metrics.name = self._name
        plan.collect_metrics.metrics.extend([self.col_to_expr(x, session) for x in self._exprs])
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 6c5d19f58ac..8e3c9b30c61 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -3796,9 +3796,9 @@ object CleanupAliases extends Rule[LogicalPlan] with AliasHelper {
       Window(cleanedWindowExprs, partitionSpec.map(trimAliases),
         orderSpec.map(trimAliases(_).asInstanceOf[SortOrder]), child)
 
-    case CollectMetrics(name, metrics, child) =>
+    case CollectMetrics(name, metrics, child, dataframeId) =>
       val cleanedMetrics = metrics.map(trimNonTopLevelAliases)
-      CollectMetrics(name, cleanedMetrics, child)
+      CollectMetrics(name, cleanedMetrics, child, dataframeId)
 
    case Unpivot(ids, values, aliases, variableColumnName, valueColumnNames, child) =>
       val cleanedIds = ids.map(_.map(trimNonTopLevelAliases))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 139fa34a1df..511f3622e7e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -484,7 +484,7 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
             groupingExprs.foreach(checkValidGroupingExprs)
             aggregateExprs.foreach(checkValidAggregateExpression)
 
-          case CollectMetrics(name, metrics, _) =>
+          case CollectMetrics(name, metrics, _, _) =>
             if (name == null || name.isEmpty) {
               operator.failAnalysis(
                 errorClass = "INVALID_OBSERVED_METRICS.MISSING_NAME",
@@ -1075,17 +1075,15 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
    * are allowed (e.g. self-joins).
    */
   private def checkCollectedMetrics(plan: LogicalPlan): Unit = {
-    val metricsMap = mutable.Map.empty[String, LogicalPlan]
+    val metricsMap = mutable.Map.empty[String, CollectMetrics]
     def check(plan: LogicalPlan): Unit = plan.foreach { node =>
       node match {
-        case metrics @ CollectMetrics(name, _, _) =>
-          val simplifiedMetrics = simplifyPlanForCollectedMetrics(metrics.canonicalized)
+        case metrics @ CollectMetrics(name, _, _, dataframeId) =>
           metricsMap.get(name) match {
             case Some(other) =>
-              val simplifiedOther = simplifyPlanForCollectedMetrics(other.canonicalized)
               // Exact duplicates are allowed. They can be the result
               // of a CTE that is used multiple times or a self join.
-              if (simplifiedMetrics != simplifiedOther) {
+              if (dataframeId != other.dataframeId) {
                 failAnalysis(
                   errorClass = "DUPLICATED_METRICS_NAME",
                   messageParameters = Map("metricName" -> name))
@@ -1104,32 +1102,6 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
     check(plan)
   }
 
-  /**
-   * This method is only used for checking collected metrics. This method tries to
-   * remove extra project which only re-assign expr ids from the plan so that we can identify exact
-   * duplicates metric definition.
-   */
-  def simplifyPlanForCollectedMetrics(plan: LogicalPlan): LogicalPlan = {
-    plan.resolveOperators {
-      case p: Project if p.projectList.size == p.child.output.size =>
-        val assignExprIdOnly = p.projectList.zipWithIndex.forall {
-          case (Alias(attr: AttributeReference, _), index) =>
-            // The input plan of this method is already canonicalized. The attribute id becomes the
-            // ordinal of this attribute in the child outputs. So an alias-only Project means the
-            // the id of the aliased attribute is the same as its index in the project list.
-            attr.exprId.id == index
-          case (left: AttributeReference, index) =>
-            left.exprId.id == index
-          case _ => false
-        }
-        if (assignExprIdOnly) {
-          p.child
-        } else {
-          p
-        }
-    }
-  }
-
   /**
    * Validates to make sure the outer references appearing inside the subquery
    * are allowed.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index 4bb830662a3..96b67fc52e0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -1952,7 +1952,8 @@ trait SupportsSubquery extends LogicalPlan
 case class CollectMetrics(
     name: String,
     metrics: Seq[NamedExpression],
-    child: LogicalPlan)
+    child: LogicalPlan,
+    dataframeId: Long)
   extends UnaryNode {
 
   override lazy val resolved: Boolean = {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index 57b37e67b32..802b6d471a6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -771,34 +771,35 @@ class AnalysisSuite extends AnalysisTest with Matchers {
     val literal = Literal(1).as("lit")
 
     // Ok
-    assert(CollectMetrics("event", literal :: sum :: random_sum :: Nil, testRelation).resolved)
+    assert(CollectMetrics("event", literal :: sum :: random_sum :: Nil, testRelation, 0).resolved)
 
     // Bad name
-    assert(!CollectMetrics("", sum :: Nil, testRelation).resolved)
+    assert(!CollectMetrics("", sum :: Nil, testRelation, 0).resolved)
     assertAnalysisErrorClass(
-      CollectMetrics("", sum :: Nil, testRelation),
+      CollectMetrics("", sum :: Nil, testRelation, 0),
       expectedErrorClass = "INVALID_OBSERVED_METRICS.MISSING_NAME",
       expectedMessageParameters = Map(
-        "operator" -> "'CollectMetrics , [sum(a#x) AS sum#xL]\n+- 
LocalRelation <empty>, [a#x]\n")
+        "operator" ->
+          "'CollectMetrics , [sum(a#x) AS sum#xL], 0\n+- LocalRelation <empty>, [a#x]\n")
     )
 
     // No columns
-    assert(!CollectMetrics("evt", Nil, testRelation).resolved)
+    assert(!CollectMetrics("evt", Nil, testRelation, 0).resolved)
 
    def checkAnalysisError(exprs: Seq[NamedExpression], errors: String*): Unit = {
-      assertAnalysisError(CollectMetrics("event", exprs, testRelation), errors)
+      assertAnalysisError(CollectMetrics("event", exprs, testRelation, 0), errors)
     }
 
     // Unwrapped attribute
     assertAnalysisErrorClass(
-      CollectMetrics("event", a :: Nil, testRelation),
+      CollectMetrics("event", a :: Nil, testRelation, 0),
      expectedErrorClass = "INVALID_OBSERVED_METRICS.NON_AGGREGATE_FUNC_ARG_IS_ATTRIBUTE",
       expectedMessageParameters = Map("expr" -> "\"a\"")
     )
 
     // Unwrapped non-deterministic expression
     assertAnalysisErrorClass(
-      CollectMetrics("event", Rand(10).as("rnd") :: Nil, testRelation),
+      CollectMetrics("event", Rand(10).as("rnd") :: Nil, testRelation, 0),
      expectedErrorClass = "INVALID_OBSERVED_METRICS.NON_AGGREGATE_FUNC_ARG_IS_NON_DETERMINISTIC",
       expectedMessageParameters = Map("expr" -> "\"rand(10) AS rnd\"")
     )
@@ -808,7 +809,7 @@ class AnalysisSuite extends AnalysisTest with Matchers {
       CollectMetrics(
         "event",
         Sum(a).toAggregateExpression(isDistinct = true).as("sum") :: Nil,
-        testRelation),
+        testRelation, 0),
      expectedErrorClass =
        "INVALID_OBSERVED_METRICS.AGGREGATE_EXPRESSION_WITH_DISTINCT_UNSUPPORTED",
       expectedMessageParameters = Map("expr" -> "\"sum(DISTINCT a) AS sum\"")
@@ -819,7 +820,7 @@ class AnalysisSuite extends AnalysisTest with Matchers {
       CollectMetrics(
         "event",
        Sum(Sum(a).toAggregateExpression()).toAggregateExpression().as("sum") :: Nil,
-        testRelation),
+        testRelation, 0),
      expectedErrorClass = "INVALID_OBSERVED_METRICS.NESTED_AGGREGATES_UNSUPPORTED",
       expectedMessageParameters = Map("expr" -> "\"sum(sum(a)) AS sum\"")
     )
@@ -830,7 +831,7 @@ class AnalysisSuite extends AnalysisTest with Matchers {
       WindowSpecDefinition(Nil, a.asc :: Nil,
         SpecifiedWindowFrame(RowFrame, UnboundedPreceding, CurrentRow)))
     assertAnalysisErrorClass(
-      CollectMetrics("event", windowExpr.as("rn") :: Nil, testRelation),
+      CollectMetrics("event", windowExpr.as("rn") :: Nil, testRelation, 0),
      expectedErrorClass = "INVALID_OBSERVED_METRICS.WINDOW_EXPRESSIONS_UNSUPPORTED",
       expectedMessageParameters = Map(
         "expr" ->
@@ -848,14 +849,14 @@ class AnalysisSuite extends AnalysisTest with Matchers {
 
     // Same result - duplicate names are allowed
     assertAnalysisSuccess(Union(
-      CollectMetrics("evt1", count :: Nil, testRelation) ::
-      CollectMetrics("evt1", count :: Nil, testRelation) :: Nil))
+      CollectMetrics("evt1", count :: Nil, testRelation, 0) ::
+      CollectMetrics("evt1", count :: Nil, testRelation, 0) :: Nil))
 
     // Same children, structurally different metrics - fail
     assertAnalysisErrorClass(
       Union(
-        CollectMetrics("evt1", count :: Nil, testRelation) ::
-          CollectMetrics("evt1", sum :: Nil, testRelation) :: Nil),
+        CollectMetrics("evt1", count :: Nil, testRelation, 0) ::
+          CollectMetrics("evt1", sum :: Nil, testRelation, 1) :: Nil),
       expectedErrorClass = "DUPLICATED_METRICS_NAME",
       expectedMessageParameters = Map("metricName" -> "evt1")
     )
@@ -865,17 +866,17 @@ class AnalysisSuite extends AnalysisTest with Matchers {
     val tblB = LocalRelation(b)
     assertAnalysisErrorClass(
       Union(
-        CollectMetrics("evt1", count :: Nil, testRelation) ::
-          CollectMetrics("evt1", count :: Nil, tblB) :: Nil),
+        CollectMetrics("evt1", count :: Nil, testRelation, 0) ::
+          CollectMetrics("evt1", count :: Nil, tblB, 1) :: Nil),
       expectedErrorClass = "DUPLICATED_METRICS_NAME",
       expectedMessageParameters = Map("metricName" -> "evt1")
     )
 
     // Subquery different tree - fail
-    val subquery = Aggregate(Nil, sum :: Nil, CollectMetrics("evt1", count :: Nil, testRelation))
+    val subquery = Aggregate(Nil, sum :: Nil, CollectMetrics("evt1", count :: Nil, testRelation, 0))
     val query = Project(
       b :: ScalarSubquery(subquery, Nil).as("sum") :: Nil,
-      CollectMetrics("evt1", count :: Nil, tblB))
+      CollectMetrics("evt1", count :: Nil, tblB, 1))
     assertAnalysisErrorClass(
       query,
       expectedErrorClass = "DUPLICATED_METRICS_NAME",
@@ -887,7 +888,7 @@ class AnalysisSuite extends AnalysisTest with Matchers {
       case a: AggregateExpression => a.copy(filter = Some(true))
     }.asInstanceOf[NamedExpression]
     assertAnalysisErrorClass(
-      CollectMetrics("evt1", sumWithFilter :: Nil, testRelation),
+      CollectMetrics("evt1", sumWithFilter :: Nil, testRelation, 0),
      expectedErrorClass =
        "INVALID_OBSERVED_METRICS.AGGREGATE_EXPRESSION_WITH_FILTER_UNSUPPORTED",
      expectedMessageParameters = Map("expr" -> "\"sum(a) FILTER (WHERE true) AS sum\"")
@@ -1667,18 +1668,4 @@ class AnalysisSuite extends AnalysisTest with Matchers {
       checkAnalysis(ident2.select($"a"), testRelation.select($"a").analyze)
     }
   }
-
-  test("simplifyPlanForCollectedMetrics should handle non alias-only project 
case") {
-    val inner = Project(
-      Seq(
-        Alias(testRelation2.output(0), "a")(),
-        testRelation2.output(1),
-        Alias(testRelation2.output(2), "c")(),
-        testRelation2.output(3),
-        testRelation2.output(4)
-      ),
-      testRelation2)
-    val actualPlan = getAnalyzer.simplifyPlanForCollectedMetrics(inner.canonicalized)
-    assert(actualPlan == testRelation2.canonicalized)
-  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index fd8421fa096..e047b927b90 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -2189,7 +2189,7 @@ class Dataset[T] private[sql](
   */
   @varargs
   def observe(name: String, expr: Column, exprs: Column*): Dataset[T] = withTypedPlan {
-    CollectMetrics(name, (expr +: exprs).map(_.named), logicalPlan)
+    CollectMetrics(name, (expr +: exprs).map(_.named), logicalPlan, id)
   }
 
   /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 903565a6d59..d851eacd5ab 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -935,7 +935,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
        throw QueryExecutionErrors.ddlUnsupportedTemporarilyError("UPDATE TABLE")
       case _: MergeIntoTable =>
        throw QueryExecutionErrors.ddlUnsupportedTemporarilyError("MERGE INTO TABLE")
-      case logical.CollectMetrics(name, metrics, child) =>
+      case logical.CollectMetrics(name, metrics, child, _) =>
         execution.CollectMetricsExec(name, metrics, planLater(child)) :: Nil
      case WriteFiles(child, fileFormat, partitionColumns, bucket, options, staticPartitions) =>
        WriteFilesExec(planLater(child), fileFormat, partitionColumns, bucket, options,

