This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 7c3c7c5a4bd [SPARK-41086][SQL] Use DataFrame ID to semantically validate CollectMetrics
7c3c7c5a4bd is described below

commit 7c3c7c5a4bd94c9e05b5e680a5242c2485875633
Author: Rui Wang <rui.w...@databricks.com>
AuthorDate: Fri Sep 22 11:07:25 2023 +0800

    [SPARK-41086][SQL] Use DataFrame ID to semantically validate CollectMetrics
    
    ### What changes were proposed in this pull request?
    
    In the existing code, plan matching is used to check whether two CollectMetrics
    nodes with the same name have different semantics. However, the plan-matching
    approach is fragile. A better way to tackle this is to use the unique DataFrame
    ID instead: the observe API is only available through the DataFrame API, and SQL
    has no equivalent syntax.
    
    So two CollectMetrics nodes are semantically identical if and only if they have
    the same name and the same DataFrame ID.
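    
    For example, a minimal sketch of the intended behavior (assuming a
    SparkSession named `spark`; the metric and variable names below are
    illustrative, not taken from this patch):
    
    ```scala
    import org.apache.spark.sql.functions._
    
    // Reusing the same observed DataFrame keeps the same DataFrame ID, so the
    // repeated metric name "my_event" is allowed (e.g. in a self-union).
    val observed = spark.range(10).observe("my_event", count(lit(1)).as("cnt"))
    observed.union(observed)
    
    // A different DataFrame that registers the same metric name has a different
    // DataFrame ID, so analysis fails with DUPLICATED_METRICS_NAME.
    val other = spark.range(10).observe("my_event", count(lit(1)).as("cnt"))
    observed.union(other)
    ```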
    
    ### Why are the changes needed?
    
    This replaces the fragile plan-matching approach with a more stable one.
    
    ### Does this PR introduce _any_ user-facing change?
    
    NO
    
    ### How was this patch tested?
    
    Unit tests (updated `AnalysisSuite`).
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    NO
    
    Closes #43010 from amaliujia/another_approch_for_collect_metrics.
    
    Authored-by: Rui Wang <rui.w...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../sql/connect/planner/SparkConnectPlanner.scala  |  6 +--
 python/pyspark/sql/connect/plan.py                 |  1 +
 .../spark/sql/catalyst/analysis/Analyzer.scala     |  4 +-
 .../sql/catalyst/analysis/CheckAnalysis.scala      | 36 ++------------
 .../plans/logical/basicLogicalOperators.scala      |  3 +-
 .../sql/catalyst/analysis/AnalysisSuite.scala      | 55 +++++++++-------------
 .../main/scala/org/apache/spark/sql/Dataset.scala  |  2 +-
 .../spark/sql/execution/SparkStrategies.scala      |  2 +-
 8 files changed, 35 insertions(+), 74 deletions(-)

diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 924169715f7..dda7a713fa0 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -164,7 +164,7 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
       case proto.Relation.RelTypeCase.CACHED_REMOTE_RELATION =>
         transformCachedRemoteRelation(rel.getCachedRemoteRelation)
       case proto.Relation.RelTypeCase.COLLECT_METRICS =>
-        transformCollectMetrics(rel.getCollectMetrics)
+        transformCollectMetrics(rel.getCollectMetrics, rel.getCommon.getPlanId)
       case proto.Relation.RelTypeCase.PARSE => transformParse(rel.getParse)
       case proto.Relation.RelTypeCase.RELTYPE_NOT_SET =>
         throw new IndexOutOfBoundsException("Expected Relation to be set, but is empty.")
@@ -1048,12 +1048,12 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
       numPartitionsOpt)
   }
 
-  private def transformCollectMetrics(rel: proto.CollectMetrics): LogicalPlan = {
+  private def transformCollectMetrics(rel: proto.CollectMetrics, planId: Long): LogicalPlan = {
     val metrics = rel.getMetricsList.asScala.toSeq.map { expr =>
       Column(transformExpression(expr))
     }
 
-    CollectMetrics(rel.getName, metrics.map(_.named), transformRelation(rel.getInput))
+    CollectMetrics(rel.getName, metrics.map(_.named), transformRelation(rel.getInput), planId)
   }
 
   private def transformDeduplicate(rel: proto.Deduplicate): LogicalPlan = {
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index d069081e1af..219545cf646 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -1192,6 +1192,7 @@ class CollectMetrics(LogicalPlan):
         assert self._child is not None
 
         plan = proto.Relation()
+        plan.common.plan_id = self._child._plan_id
         plan.collect_metrics.input.CopyFrom(self._child.plan(session))
         plan.collect_metrics.name = self._name
+        plan.collect_metrics.metrics.extend([self.col_to_expr(x, session) for x in self._exprs])
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index cff29de858e..aac85e19721 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -3893,9 +3893,9 @@ object CleanupAliases extends Rule[LogicalPlan] with AliasHelper {
       Window(cleanedWindowExprs, partitionSpec.map(trimAliases),
         orderSpec.map(trimAliases(_).asInstanceOf[SortOrder]), child)
 
-    case CollectMetrics(name, metrics, child) =>
+    case CollectMetrics(name, metrics, child, dataframeId) =>
       val cleanedMetrics = metrics.map(trimNonTopLevelAliases)
-      CollectMetrics(name, cleanedMetrics, child)
+      CollectMetrics(name, cleanedMetrics, child, dataframeId)
 
     case Unpivot(ids, values, aliases, variableColumnName, valueColumnNames, child) =>
       val cleanedIds = ids.map(_.map(trimNonTopLevelAliases))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 3c9a816df26..83b682bc917 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -497,7 +497,7 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
             groupingExprs.foreach(checkValidGroupingExprs)
             aggregateExprs.foreach(checkValidAggregateExpression)
 
-          case CollectMetrics(name, metrics, _) =>
+          case CollectMetrics(name, metrics, _, _) =>
             if (name == null || name.isEmpty) {
               operator.failAnalysis(
                 errorClass = "INVALID_OBSERVED_METRICS.MISSING_NAME",
@@ -1097,17 +1097,15 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
    * are allowed (e.g. self-joins).
    */
   private def checkCollectedMetrics(plan: LogicalPlan): Unit = {
-    val metricsMap = mutable.Map.empty[String, LogicalPlan]
+    val metricsMap = mutable.Map.empty[String, CollectMetrics]
     def check(plan: LogicalPlan): Unit = plan.foreach { node =>
       node match {
-        case metrics @ CollectMetrics(name, _, _) =>
-          val simplifiedMetrics = simplifyPlanForCollectedMetrics(metrics.canonicalized)
+        case metrics @ CollectMetrics(name, _, _, dataframeId) =>
           metricsMap.get(name) match {
             case Some(other) =>
-              val simplifiedOther = simplifyPlanForCollectedMetrics(other.canonicalized)
               // Exact duplicates are allowed. They can be the result
               // of a CTE that is used multiple times or a self join.
-              if (simplifiedMetrics != simplifiedOther) {
+              if (dataframeId != other.dataframeId) {
                 failAnalysis(
                   errorClass = "DUPLICATED_METRICS_NAME",
                   messageParameters = Map("metricName" -> name))
@@ -1126,32 +1124,6 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
     check(plan)
   }
 
-  /**
-   * This method is only used for checking collected metrics. This method tries to
-   * remove extra project which only re-assign expr ids from the plan so that we can identify exact
-   * duplicates metric definition.
-   */
-  def simplifyPlanForCollectedMetrics(plan: LogicalPlan): LogicalPlan = {
-    plan.resolveOperators {
-      case p: Project if p.projectList.size == p.child.output.size =>
-        val assignExprIdOnly = p.projectList.zipWithIndex.forall {
-          case (Alias(attr: AttributeReference, _), index) =>
-            // The input plan of this method is already canonicalized. The attribute id becomes the
-            // ordinal of this attribute in the child outputs. So an alias-only Project means the
-            // the id of the aliased attribute is the same as its index in the project list.
-            attr.exprId.id == index
-          case (left: AttributeReference, index) =>
-            left.exprId.id == index
-          case _ => false
-        }
-        if (assignExprIdOnly) {
-          p.child
-        } else {
-          p
-        }
-    }
-  }
-
   /**
    * Validates to make sure the outer references appearing inside the subquery
    * are allowed.
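
For reference, here is a condensed standalone sketch of the rule the rewritten checkCollectedMetrics enforces (hypothetical types, not the actual Catalyst classes): a metric name may repeat only when both occurrences come from the same DataFrame.

```scala
import scala.collection.mutable

// Hypothetical stand-in for CollectMetrics, reduced to the two fields the
// duplicate check compares: the metric name and the originating DataFrame ID.
case class MetricsNode(name: String, dataframeId: Long)

def checkDuplicateMetricNames(nodes: Seq[MetricsNode]): Unit = {
  val seen = mutable.Map.empty[String, MetricsNode]
  nodes.foreach { m =>
    seen.get(m.name) match {
      case Some(other) if other.dataframeId != m.dataframeId =>
        // Same name from a different DataFrame: the real check raises an
        // analysis error with error class DUPLICATED_METRICS_NAME.
        throw new IllegalArgumentException(s"DUPLICATED_METRICS_NAME: ${m.name}")
      case Some(_) =>
        () // Exact duplicate (same DataFrame ID), e.g. a reused CTE or self-join: allowed.
      case None =>
        seen(m.name) = m
    }
  }
}
```
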
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index efb7dbb44ef..8f976a49a2b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -1969,7 +1969,8 @@ trait SupportsSubquery extends LogicalPlan
 case class CollectMetrics(
     name: String,
     metrics: Seq[NamedExpression],
-    child: LogicalPlan)
+    child: LogicalPlan,
+    dataframeId: Long)
   extends UnaryNode {
 
   override lazy val resolved: Boolean = {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index ed3137430df..ffc12a2b981 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -779,34 +779,35 @@ class AnalysisSuite extends AnalysisTest with Matchers {
     val literal = Literal(1).as("lit")
 
     // Ok
-    assert(CollectMetrics("event", literal :: sum :: random_sum :: Nil, testRelation).resolved)
+    assert(CollectMetrics("event", literal :: sum :: random_sum :: Nil, testRelation, 0).resolved)
 
     // Bad name
-    assert(!CollectMetrics("", sum :: Nil, testRelation).resolved)
+    assert(!CollectMetrics("", sum :: Nil, testRelation, 0).resolved)
     assertAnalysisErrorClass(
-      CollectMetrics("", sum :: Nil, testRelation),
+      CollectMetrics("", sum :: Nil, testRelation, 0),
       expectedErrorClass = "INVALID_OBSERVED_METRICS.MISSING_NAME",
       expectedMessageParameters = Map(
-        "operator" -> "'CollectMetrics , [sum(a#x) AS sum#xL]\n+- LocalRelation <empty>, [a#x]\n")
+        "operator" ->
+          "'CollectMetrics , [sum(a#x) AS sum#xL], 0\n+- LocalRelation <empty>, [a#x]\n")
     )
 
     // No columns
-    assert(!CollectMetrics("evt", Nil, testRelation).resolved)
+    assert(!CollectMetrics("evt", Nil, testRelation, 0).resolved)
 
     def checkAnalysisError(exprs: Seq[NamedExpression], errors: String*): Unit = {
-      assertAnalysisError(CollectMetrics("event", exprs, testRelation), errors)
+      assertAnalysisError(CollectMetrics("event", exprs, testRelation, 0), errors)
     }
 
     // Unwrapped attribute
     assertAnalysisErrorClass(
-      CollectMetrics("event", a :: Nil, testRelation),
+      CollectMetrics("event", a :: Nil, testRelation, 0),
       expectedErrorClass = "INVALID_OBSERVED_METRICS.NON_AGGREGATE_FUNC_ARG_IS_ATTRIBUTE",
       expectedMessageParameters = Map("expr" -> "\"a\"")
     )
 
     // Unwrapped non-deterministic expression
     assertAnalysisErrorClass(
-      CollectMetrics("event", Rand(10).as("rnd") :: Nil, testRelation),
+      CollectMetrics("event", Rand(10).as("rnd") :: Nil, testRelation, 0),
       expectedErrorClass = "INVALID_OBSERVED_METRICS.NON_AGGREGATE_FUNC_ARG_IS_NON_DETERMINISTIC",
       expectedMessageParameters = Map("expr" -> "\"rand(10) AS rnd\"")
     )
@@ -816,7 +817,7 @@ class AnalysisSuite extends AnalysisTest with Matchers {
       CollectMetrics(
         "event",
         Sum(a).toAggregateExpression(isDistinct = true).as("sum") :: Nil,
-        testRelation),
+        testRelation, 0),
       expectedErrorClass =
         "INVALID_OBSERVED_METRICS.AGGREGATE_EXPRESSION_WITH_DISTINCT_UNSUPPORTED",
       expectedMessageParameters = Map("expr" -> "\"sum(DISTINCT a) AS sum\"")
@@ -827,7 +828,7 @@ class AnalysisSuite extends AnalysisTest with Matchers {
       CollectMetrics(
         "event",
         Sum(Sum(a).toAggregateExpression()).toAggregateExpression().as("sum") :: Nil,
-        testRelation),
+        testRelation, 0),
       expectedErrorClass = "INVALID_OBSERVED_METRICS.NESTED_AGGREGATES_UNSUPPORTED",
       expectedMessageParameters = Map("expr" -> "\"sum(sum(a)) AS sum\"")
     )
@@ -838,7 +839,7 @@ class AnalysisSuite extends AnalysisTest with Matchers {
       WindowSpecDefinition(Nil, a.asc :: Nil,
         SpecifiedWindowFrame(RowFrame, UnboundedPreceding, CurrentRow)))
     assertAnalysisErrorClass(
-      CollectMetrics("event", windowExpr.as("rn") :: Nil, testRelation),
+      CollectMetrics("event", windowExpr.as("rn") :: Nil, testRelation, 0),
       expectedErrorClass = "INVALID_OBSERVED_METRICS.WINDOW_EXPRESSIONS_UNSUPPORTED",
       expectedMessageParameters = Map(
         "expr" ->
@@ -856,14 +857,14 @@ class AnalysisSuite extends AnalysisTest with Matchers {
 
     // Same result - duplicate names are allowed
     assertAnalysisSuccess(Union(
-      CollectMetrics("evt1", count :: Nil, testRelation) ::
-      CollectMetrics("evt1", count :: Nil, testRelation) :: Nil))
+      CollectMetrics("evt1", count :: Nil, testRelation, 0) ::
+      CollectMetrics("evt1", count :: Nil, testRelation, 0) :: Nil))
 
     // Same children, structurally different metrics - fail
     assertAnalysisErrorClass(
       Union(
-        CollectMetrics("evt1", count :: Nil, testRelation) ::
-          CollectMetrics("evt1", sum :: Nil, testRelation) :: Nil),
+        CollectMetrics("evt1", count :: Nil, testRelation, 0) ::
+          CollectMetrics("evt1", sum :: Nil, testRelation, 1) :: Nil),
       expectedErrorClass = "DUPLICATED_METRICS_NAME",
       expectedMessageParameters = Map("metricName" -> "evt1")
     )
@@ -873,17 +874,17 @@ class AnalysisSuite extends AnalysisTest with Matchers {
     val tblB = LocalRelation(b)
     assertAnalysisErrorClass(
       Union(
-        CollectMetrics("evt1", count :: Nil, testRelation) ::
-          CollectMetrics("evt1", count :: Nil, tblB) :: Nil),
+        CollectMetrics("evt1", count :: Nil, testRelation, 0) ::
+          CollectMetrics("evt1", count :: Nil, tblB, 1) :: Nil),
       expectedErrorClass = "DUPLICATED_METRICS_NAME",
       expectedMessageParameters = Map("metricName" -> "evt1")
     )
 
     // Subquery different tree - fail
-    val subquery = Aggregate(Nil, sum :: Nil, CollectMetrics("evt1", count :: Nil, testRelation))
+    val subquery = Aggregate(Nil, sum :: Nil, CollectMetrics("evt1", count :: Nil, testRelation, 0))
     val query = Project(
       b :: ScalarSubquery(subquery, Nil).as("sum") :: Nil,
-      CollectMetrics("evt1", count :: Nil, tblB))
+      CollectMetrics("evt1", count :: Nil, tblB, 1))
     assertAnalysisErrorClass(
       query,
       expectedErrorClass = "DUPLICATED_METRICS_NAME",
@@ -895,7 +896,7 @@ class AnalysisSuite extends AnalysisTest with Matchers {
       case a: AggregateExpression => a.copy(filter = Some(true))
     }.asInstanceOf[NamedExpression]
     assertAnalysisErrorClass(
-      CollectMetrics("evt1", sumWithFilter :: Nil, testRelation),
+      CollectMetrics("evt1", sumWithFilter :: Nil, testRelation, 0),
       expectedErrorClass =
         "INVALID_OBSERVED_METRICS.AGGREGATE_EXPRESSION_WITH_FILTER_UNSUPPORTED",
       expectedMessageParameters = Map("expr" -> "\"sum(a) FILTER (WHERE true) AS sum\"")
@@ -1675,18 +1676,4 @@ class AnalysisSuite extends AnalysisTest with Matchers {
       checkAnalysis(ident2.select($"a"), testRelation.select($"a").analyze)
     }
   }
-
-  test("simplifyPlanForCollectedMetrics should handle non alias-only project case") {
-    val inner = Project(
-      Seq(
-        Alias(testRelation2.output(0), "a")(),
-        testRelation2.output(1),
-        Alias(testRelation2.output(2), "c")(),
-        testRelation2.output(3),
-        testRelation2.output(4)
-      ),
-      testRelation2)
-    val actualPlan = getAnalyzer.simplifyPlanForCollectedMetrics(inner.canonicalized)
-    assert(actualPlan == testRelation2.canonicalized)
-  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 528904bb29a..f07496e6430 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -2218,7 +2218,7 @@ class Dataset[T] private[sql](
   */
   @varargs
   def observe(name: String, expr: Column, exprs: Column*): Dataset[T] = withTypedPlan {
-    CollectMetrics(name, (expr +: exprs).map(_.named), logicalPlan)
+    CollectMetrics(name, (expr +: exprs).map(_.named), logicalPlan, id)
   }
 
   /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 903565a6d59..d851eacd5ab 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -935,7 +935,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         throw QueryExecutionErrors.ddlUnsupportedTemporarilyError("UPDATE TABLE")
       case _: MergeIntoTable =>
         throw QueryExecutionErrors.ddlUnsupportedTemporarilyError("MERGE INTO TABLE")
-      case logical.CollectMetrics(name, metrics, child) =>
+      case logical.CollectMetrics(name, metrics, child, _) =>
         execution.CollectMetricsExec(name, metrics, planLater(child)) :: Nil
       case WriteFiles(child, fileFormat, partitionColumns, bucket, options, staticPartitions) =>
         WriteFilesExec(planLater(child), fileFormat, partitionColumns, bucket, options,

