cloud-fan commented on code in PR #54976:
URL: https://github.com/apache/spark/pull/54976#discussion_r3288078816


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveZip.scala:
##########
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap, Expression, NamedExpression, PythonUDF}
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Zip}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.trees.TreePattern.ZIP
+
+/**
+ * Resolves a [[Zip]] node by rewriting it into a single [[Project]] over the shared base plan.
+ *
+ * Both children of Zip must derive from the same base plan through chains of scalar Project
+ * nodes (1:1 row mapping). `Project.resolved` already rejects Generator, AggregateExpression,
+ * and WindowExpression. This rule additionally rejects non-scalar Python UDFs (e.g.
+ * GROUPED_MAP), which are not caught by `Project.resolved`.
+ *
+ * This rule:
+ * 1. Waits for both children to be resolved
+ * 2. Strips Project layers from each side to find the base plan, composing alias
+ *    substitutions so the resulting expressions reference the base plan's attributes directly
+ * 3. Verifies the base plans produce the same result (via `sameResult`)
+ * 4. Verifies neither side contains a non-scalar Python UDF
+ * 5. Remaps the right side's attribute references to the left base plan's output
+ * 6. Produces a single Project that combines both sides' expressions
+ *
+ * If the base plans do not match, or a non-scalar Python UDF is present, the Zip node remains
+ * unresolved and CheckAnalysis will report a `ZIP_PLANS_NOT_MERGEABLE` error.
+ */
+object ResolveZip extends Rule[LogicalPlan] {
+  def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning(
+    _.containsPattern(ZIP), ruleId) {
+    case z: Zip if z.childrenResolved =>
+      val (leftExprs, leftBase) = extractProjectAndBase(z.left)
+      val (rightExprs, rightBase) = extractProjectAndBase(z.right)
+      if (leftBase.sameResult(rightBase) && allScalar(leftExprs ++ rightExprs)) {
+        // Build an attribute mapping from rightBase output to leftBase output (by position)
+        val attrMapping = AttributeMap(rightBase.output.zip(leftBase.output))
+        // Remap right expressions to reference leftBase's attributes
+        val remappedRightExprs = rightExprs.map { expr =>
+          expr.transform {
+            case a: Attribute => attrMapping.getOrElse(a, a)
+          }.asInstanceOf[NamedExpression]
+        }
+        Project(leftExprs ++ remappedRightExprs, leftBase)
+      } else {
+        z
+      }
+  }
+
+  /**
+   * Walks down a chain of [[Project]] nodes, composing alias substitutions so the returned
+   * expressions reference the deepest non-Project base directly. Necessary because chains of
+   * `select`/`withColumn` produce nested Projects, and merging two Zip sides requires both to
+   * reach the same `sameResult` base regardless of chain depth.
+   */
+  private def extractProjectAndBase(
+      plan: LogicalPlan): (Seq[NamedExpression], LogicalPlan) = plan match {
+    case Project(projectList, child) => stripProjects(child, projectList)
+    case other => (other.output, other)
+  }
+
+  @scala.annotation.tailrec
+  private def stripProjects(
+      plan: LogicalPlan,
+      outerExprs: Seq[NamedExpression]): (Seq[NamedExpression], LogicalPlan) = plan match {
+    case Project(innerExprs, child) =>
+      val aliasMap = AttributeMap(innerExprs.collect {
+        case a: Alias => a.toAttribute -> a.child
+      })
+      val composed = outerExprs.map(substitute(_, aliasMap))
+      stripProjects(child, composed)
+    case other => (outerExprs, other)
+  }
+
+  /**
+   * Replaces references to inner aliases inside `expr` with the underlying expressions. When
+   * `expr` is a bare [[Attribute]] that matches an inner alias, wraps the substituted expression
+   * in a fresh [[Alias]] preserving the outer name and exprId so downstream references stay
+   * stable.
+   */
+  private def substitute(
+      expr: NamedExpression, aliasMap: AttributeMap[Expression]): NamedExpression = expr match {
+    case attr: Attribute if aliasMap.contains(attr) =>
+      Alias(aliasMap(attr), attr.name)(exprId = attr.exprId)
+    case _ =>
+      expr.transform {
+        case a: Attribute if aliasMap.contains(a) => aliasMap(a)
+      }.asInstanceOf[NamedExpression]
+  }

Review Comment:
   This `stripProjects`/`substitute` chain is an unguarded re-implementation of 
`CollapseProject` in the analyzer. `CollapseProject.canCollapseExpressions` 
(Optimizer.scala:1429) refuses to inline an alias body when (a) the producer is 
nondeterministic or (b) a non-cheap producer is consumed more than once — 
exactly to prevent the rewritten plan from evaluating the producer more times 
than the original. `substitute` has none of those guards, so for any chain like 
`df.withColumn("r", rand()).withColumn("x", $"r" + $"r")` (or any `withColumn` 
that re-references an earlier alias more than once), the rewrite produces 
`Project([rand() + rand() AS x], df)` — two `Rand` evaluations per row, so `x` 
is no longer `2 * r`. The same query without `zip` returns `2 * r` because 
`CollapseProject` keeps the chain intact. The same hazard applies to `uuid()` 
and subquery expressions and, as a performance problem, to Python UDFs or other 
heavy expressions referenced multiple times.
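   A toy sketch of the hazard, assuming a hypothetical mini expression model (`Ref`, `Add`, and a `Rand` stand-in) rather than Catalyst's classes. `inlineAliases` mimics the unguarded substitution; `canInline` is a simplified guard built only from the two conditions cited above (the real `CollapseProject` check also lets cheap producers be consumed more than once):

```scala
// Hypothetical toy model; none of these types are Catalyst classes.
object InlineHazardDemo {
  sealed trait Expr
  final case class Ref(name: String) extends Expr
  final case class Add(left: Expr, right: Expr) extends Expr
  case object Rand extends Expr // stands in for rand(), uuid(), or an expensive UDF

  // Unguarded substitution, similar in spirit to `substitute` in the diff:
  // every reference to an inner alias is replaced by that alias's body.
  def inlineAliases(e: Expr, aliases: Map[String, Expr]): Expr = e match {
    case Ref(n) if aliases.contains(n) => aliases(n)
    case Add(l, r) => Add(inlineAliases(l, aliases), inlineAliases(r, aliases))
    case other => other
  }

  // In a single merged Project, each Rand node is evaluated once per row.
  def randCount(e: Expr): Int = e match {
    case Rand => 1
    case Add(l, r) => randCount(l) + randCount(r)
    case Ref(_) => 0
  }

  def refCount(e: Expr, name: String): Int = e match {
    case Ref(n) => if (n == name) 1 else 0
    case Add(l, r) => refCount(l, name) + refCount(r, name)
    case Rand => 0
  }

  // Simplified guard: a referenced producer must be deterministic and consumed once.
  def canInline(consumer: Expr, aliases: Map[String, Expr]): Boolean =
    aliases.forall { case (name, body) =>
      val uses = refCount(consumer, name)
      uses == 0 || (randCount(body) == 0 && uses == 1)
    }

  // withColumn("r", rand()).withColumn("x", r + r), reduced to the relevant expressions.
  val producers: Map[String, Expr] = Map("r" -> Rand)
  val consumer: Expr = Add(Ref("r"), Ref("r"))
}
```

   Inlining `r` into `r + r` yields two producer nodes, i.e. two evaluations per row, and the guard refuses exactly that rewrite while still allowing a deterministic alias consumed once.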
   
   The cleanest fix is to drop the "must rewrite to a single Project" 
constraint rather than guard the substitution. Concretely: rewrite `Zip(left, 
right)` into a dependency-ordered chain of Projects, where each user-written 
alias stays as a separate `Alias` in its own layer. Sketch:
   
   1. Find the merge point — both sides share a base after stripping outer 
`Project` chains (same `sameResult` predicate this rule uses today).
   2. For each side, collect the *new* aliases each `Project` layer introduces 
(skip pure passthrough Attribute references — only `Alias` entries count).
   3. Order the collected aliases by dependency depth: depth-1 aliases 
reference only base attributes, depth-k aliases reference at least one 
depth-(k-1) alias. Each depth becomes one Project layer.
   4. At each depth's Project layer, include the new aliases at that depth plus 
passthroughs of any attribute referenced by a higher depth or by the final top 
Project.
   5. Top: `Project(left.top.output ++ right.top.output_remapped, deepest_layer)`.
   
   This handles asymmetric chain lengths naturally — the merged plan's depth is 
whatever the dependency graph demands (typically `max(left.depth, right.depth)` 
once pure passthrough layers like `select` get absorbed). No substitution, so 
`rand()` stays as one alias, expensive UDFs stay as one alias, and 
`CollapseProject` runs afterward with all its existing guards intact.
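   A rough sketch of steps 2-4, assuming a hypothetical model where each collected alias is just a name plus the set of names it references (not Catalyst's `Alias`/`AttributeSet`). `layers` groups aliases by dependency depth, and `passthroughs` computes which names a given layer must forward for the layers above it and the top Project:

```scala
// Toy model of steps 2-4 (hypothetical; not Catalyst code).
object AliasLayering {
  // An alias collected in step 2: its output name and the names it references.
  final case class AliasDef(name: String, refs: Set[String])

  // Depth 1 references only base attributes; depth k references at least one
  // depth-(k-1) alias. Assumes the alias graph is acyclic, as Project chains are.
  def depths(aliases: Seq[AliasDef]): Map[String, Int] = {
    val byName = aliases.map(a => a.name -> a).toMap
    val memo = scala.collection.mutable.Map.empty[String, Int]
    def depthOf(name: String): Int = memo.get(name) match {
      case Some(d) => d
      case None =>
        val deps = byName(name).refs.filter(byName.contains)
        val d = if (deps.isEmpty) 1 else deps.map(depthOf).max + 1
        memo(name) = d
        d
    }
    aliases.map(a => a.name -> depthOf(a.name)).toMap
  }

  // Step 3: one Project layer per depth, keeping the collected order within a layer.
  def layers(aliases: Seq[AliasDef]): Seq[Seq[String]] = {
    val d = depths(aliases)
    val maxDepth = if (d.isEmpty) 0 else d.values.max
    (1 to maxDepth).map(k => aliases.filter(a => d(a.name) == k).map(_.name))
  }

  // Step 4: names layer k must forward because a higher layer or the top Project
  // reads them and they are produced below k (base attributes or shallower aliases).
  def passthroughs(aliases: Seq[AliasDef], topOutput: Set[String], k: Int): Set[String] = {
    val d = depths(aliases)
    val neededAbove = aliases.filter(a => d(a.name) > k).flatMap(_.refs).toSet ++ topOutput
    neededAbove.filter(n => d.get(n).forall(_ < k))
  }

  // left = df.withColumn("x", a + 1).withColumn("y", x * 2); right = df.withColumn("z", b - 1)
  val example: Seq[AliasDef] =
    Seq(AliasDef("x", Set("a")), AliasDef("y", Set("x")), AliasDef("z", Set("b")))
  val exampleTop: Set[String] = Set("a", "b", "c", "x", "y", "z")
}
```

   For the example, `x` and `z` share layer 1 and `y` lands in layer 2, so the two sides' chains interleave without any alias body being copied.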



##########
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveZipSuite.scala:
##########
@@ -0,0 +1,151 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.analysis
+
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.rules.RuleExecutor
+
+class ResolveZipSuite extends AnalysisTest {
+
+  private val base = LocalRelation($"a".int, $"b".int, $"c".int)
+
+  object Resolve extends RuleExecutor[LogicalPlan] {
+    override val batches: Seq[Batch] = Seq(
+      Batch("ResolveZip", Once, ResolveZip))
+  }
+
+  test("resolve Zip: both sides have Project over same base") {
+    val left = Project(Seq(base.output(0)), base)
+    val right = Project(Seq(base.output(1)), base)
+    val zip = Zip(left, right)
+
+    val resolved = Resolve.execute(zip)
+    val expected = Project(Seq(base.output(0), base.output(1)), base)
+    comparePlans(resolved, expected)
+  }
+
+  test("resolve Zip: left is bare plan, right has Project") {
+    val right = Project(Seq(base.output(0)), base)
+    val zip = Zip(base, right)
+
+    val resolved = Resolve.execute(zip)
+    val expected = Project(base.output ++ Seq(base.output(0)), base)
+    comparePlans(resolved, expected)
+  }
+
+  test("resolve Zip: both sides are bare same plan") {
+    val zip = Zip(base, base)
+
+    val resolved = Resolve.execute(zip)
+    val expected = Project(base.output ++ base.output, base)
+    comparePlans(resolved, expected)
+  }
+
+  test("resolve Zip: both sides have expressions over same base") {
+    val left = base.select(($"a" + 1).as("a_plus_1"))
+    val right = base.select(($"b" * 2).as("b_times_2"))
+    val zip = Zip(left.analyze, right.analyze)
+
+    val resolved = Resolve.execute(zip)
+    assert(!resolved.isInstanceOf[Zip], "Zip should have been resolved to a Project")
+    assert(resolved.isInstanceOf[Project])
+    assert(resolved.output.length == 2)
+    assert(resolved.output(0).name == "a_plus_1")
+    assert(resolved.output(1).name == "b_times_2")
+  }
+
+  test("resolve Zip: different base plans - Zip remains unresolved") {
+    val base2 = LocalRelation($"x".int, $"y".int, $"z".int, $"w".int)
+    val left = Project(Seq(base.output(0)), base)
+    val right = Project(Seq(base2.output(0)), base2)
+    val zip = Zip(left, right)
+
+    val resolved = Resolve.execute(zip)
+    // ResolveZip cannot merge, so Zip stays
+    assert(resolved.isInstanceOf[Zip])
+  }
+
+  test("resolve Zip: skipped when children are unresolved") {
+    val unresolvedChild = Project(
+      Seq(UnresolvedAttribute("a")),
+      UnresolvedRelation(Seq("t")))
+    val zip = Zip(unresolvedChild, unresolvedChild)
+
+    val result = Resolve.execute(zip)
+    // Zip should remain unchanged because children are not resolved
+    assert(result.isInstanceOf[Zip])
+  }
+
+  test("CheckAnalysis: different base plans throws ZIP_PLANS_NOT_MERGEABLE") {
+    val base2 = LocalRelation($"x".int, $"y".int, $"z".int, $"w".int)
+    val left = Project(Seq(base.output(0)), base)
+    val right = Project(Seq(base2.output(0)), base2)
+    val zip = Zip(left, right)
+
+    assertAnalysisErrorCondition(
+      zip,
+      expectedErrorCondition = "ZIP_PLANS_NOT_MERGEABLE",
+      expectedMessageParameters = Map.empty
+    )
+  }
+
+  test("resolve Zip: longer chain of selects on both sides") {
+    // Left has 3 nested Projects, right has 1 Project. Both reach the same base.
+    val left = Project(Seq(base.output(0)),
+      Project(Seq(base.output(0), base.output(1)),
+        Project(base.output, base)))
+    val right = Project(Seq(base.output(1)), base)
+    val zip = Zip(left, right)
+
+    val resolved = Resolve.execute(zip)
+    assert(resolved.isInstanceOf[Project], "Asymmetric chain should still merge to a Project")
+    assert(resolved.output.map(_.name) == Seq("a", "b"))
+  }
+
+  test("resolve Zip: chained Project with aliases composes substitutions") {
+    // Build df.select(a + 1 AS x).select(x * 2 AS y) — outer references the inner alias.

Review Comment:
   Non-ASCII em-dash `—` will trip scalastyle's nonascii rule. Use `--`.
   
   ```suggestion
       // Build df.select(a + 1 AS x).select(x * 2 AS y) -- outer references the inner alias.
   ```
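   For reference, a minimal Scala sketch of the kind of check involved, reporting the 1-based columns of any character above U+007F. This is an illustration only, not scalastyle's implementation:

```scala
// Illustrative only; not scalastyle's implementation.
object NonAsciiCheck {
  // 1-based columns of characters outside 7-bit ASCII (code unit > 127).
  def nonAsciiColumns(line: String): Seq[Int] =
    line.zipWithIndex.collect { case (ch, i) if ch.toInt > 127 => i + 1 }
}
```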



##########
sql/api/src/main/scala/org/apache/spark/sql/Dataset.scala:
##########
@@ -819,6 +819,26 @@ abstract class Dataset[T] extends Serializable {
    */
   def crossJoin(right: Dataset[_]): DataFrame
 
+  /**
+   * Combines the columns of this DataFrame with another DataFrame side-by-side, preserving row
+   * alignment between the two inputs.
+   *
+   * Both DataFrames must derive from a common source DataFrame through column-only operations
+   * (such as `select` or `withColumn`) that preserve the row-to-row mapping of the source.
+   * Operations that change row identity or count -- including `filter`, `join`, `groupBy`,
+   * `distinct`, `orderBy`, `limit`, and non-scalar Python UDFs -- are not supported on either
+   * side. An `AnalysisException` is thrown when the two DataFrames cannot be aligned.

Review Comment:
   The actual rule is more permissive than this passage suggests — `ResolveZip` 
strips outer `Project` chains and then checks `sameResult` on the bases, so 
anything you put *below* the strippable Projects is accepted as long as it's 
identical on both sides. `df.filter(p).select(a).zip(df.filter(p).select(b))`, 
two identical `orderBy`s, and two identical aggregates all merge today, and 
they're row-aligned for free (the shared sub-plan still executes exactly once 
since the right base is discarded). Worth restating the contract honestly — 
e.g., "both DataFrames must produce the same canonicalized plan after stripping 
outer `Project` chains" — and fixing the same passage in 
`python/pyspark/sql/dataframe.py`.
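   A toy illustration of the predicate described above, assuming hypothetical `Plan` case classes in which structural equality stands in for `sameResult` (the real method compares canonicalized plans). Anything below the outer `Project` layers, here a `Filter` or `Sort`, is accepted as long as both sides match:

```scala
// Toy model (hypothetical case classes, not Catalyst's LogicalPlan).
object ZipContractDemo {
  sealed trait Plan
  final case class Relation(name: String) extends Plan
  final case class Filter(condition: String, child: Plan) extends Plan
  final case class Sort(key: String, child: Plan) extends Plan
  final case class Project(columns: Seq[String], child: Plan) extends Plan

  // Same shape as the CheckAnalysis helper in this PR: peel off outer Projects only.
  @scala.annotation.tailrec
  def stripProjects(plan: Plan): Plan = plan match {
    case Project(_, child) => stripProjects(child)
    case other => other
  }

  // Structural equality stands in for `sameResult` here.
  def mergeable(left: Plan, right: Plan): Boolean =
    stripProjects(left) == stripProjects(right)

  val df: Plan = Relation("df")
  // df.filter(p).select(a) vs df.filter(p).select(b)
  val filteredLeft: Plan = Project(Seq("a"), Filter("p", df))
  val filteredRight: Plan = Project(Seq("b"), Filter("p", df))
  // df.orderBy(k).select(a) vs df.orderBy(k).select(b)
  val sortedLeft: Plan = Project(Seq("a"), Sort("k", df))
  val sortedRight: Plan = Project(Seq("b"), Sort("k", df))
  // Different filters below the Projects: not mergeable.
  val otherFilter: Plan = Project(Seq("b"), Filter("q", df))
}
```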



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala:
##########
@@ -690,6 +690,19 @@ trait CheckAnalysis extends LookupCatalog with QueryErrorsBase with PlanToString
               errorClass = "NEAREST_BY_JOIN.EXACT_WITH_NONDETERMINISTIC_EXPRESSION",
               messageParameters = Map("expression" -> toSQLExpr(rankingExpression)))
 
+          case z: Zip =>
+            def stripProjects(plan: LogicalPlan): LogicalPlan = plan match {
+              case Project(_, child) => stripProjects(child)
+              case other => other
+            }
+            val leftBase = stripProjects(z.left)
+            val rightBase = stripProjects(z.right)
+            if (!leftBase.sameResult(rightBase)) {
+              z.failAnalysis(
+                errorClass = "ZIP_PLANS_NOT_MERGEABLE",
+                messageParameters = Map.empty)
+            }

Review Comment:
   Two concerns about this block:
   
   (a) **Logic duplication.** `stripProjects` + `sameResult` is exactly the 
predicate `ResolveZip` already uses. If `ResolveZip` ever broadens the strip, 
this check silently drifts. Worth extracting a helper on `ResolveZip` and 
calling it from both sites.
   
   (b) **Cryptic error for the non-scalar Python UDF case.** When the two sides 
share a base but one contains a non-scalar Python UDF, `ResolveZip` returns the 
unresolved `Zip` (because `allScalar` is false), and this guard's `!sameResult` 
is false — so no error is emitted here. The traversal then hits the `case o if 
!o.resolved` catch-all at line 1093 and throws 
`SparkException.internalError("Found the unresolved operator: Zip ...")`. The 
PR description promises a clean rejection of non-scalar UDFs; instead the user 
sees an `INTERNAL_ERROR`. Suggest dropping the `sameResult` guard so any 
surviving `Zip` raises `ZIP_PLANS_NOT_MERGEABLE` (and broadening the error 
message accordingly), or adding a distinct error class for the UDF case.
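   A hypothetical sketch of both suggestions using toy types (not Spark's `Rule`/`CheckAnalysis` API; the non-scalar message text is illustrative, not an existing error class): one `verdict` helper is shared by the rule and the check, and the check reports every surviving `Zip` with a user-facing message:

```scala
// Hypothetical toy model; not Spark's Rule/CheckAnalysis API.
object ZipCheckSketch {
  sealed trait Plan
  final case class Relation(name: String) extends Plan
  // `scalarOnly = false` marks a layer holding a non-scalar Python UDF.
  final case class Project(cols: Seq[String], child: Plan, scalarOnly: Boolean = true) extends Plan
  final case class Zip(left: Plan, right: Plan) extends Plan
  final case class Merged(base: Plan) extends Plan // output expressions omitted in this toy

  sealed trait Verdict
  case object Mergeable extends Verdict
  case object BasesDiffer extends Verdict
  case object NonScalarUdf extends Verdict

  @scala.annotation.tailrec
  private def strip(p: Plan, scalar: Boolean): (Boolean, Plan) = p match {
    case Project(_, child, s) => strip(child, scalar && s)
    case other => (scalar, other)
  }

  // (a) Single helper shared by the rule and the check, so the two cannot drift apart.
  def verdict(z: Zip): Verdict = {
    val (lScalar, lBase) = strip(z.left, scalar = true)
    val (rScalar, rBase) = strip(z.right, scalar = true)
    if (lBase != rBase) BasesDiffer
    else if (!(lScalar && rScalar)) NonScalarUdf
    else Mergeable
  }

  // Rule: rewrite a Zip only when the shared helper says it is mergeable.
  def resolve(plan: Plan): Plan = plan match {
    case z: Zip if verdict(z) == Mergeable => Merged(strip(z.left, scalar = true)._2)
    case other => other
  }

  // (b) Check: any Zip that survives the rule gets a user-facing error instead of
  // falling through to a generic "unresolved operator" internal error.
  def checkError(plan: Plan): Option[String] = plan match {
    case z: Zip if verdict(z) == NonScalarUdf =>
      Some("ZIP_PLANS_NOT_MERGEABLE: one side contains a non-scalar Python UDF")
    case _: Zip => Some("ZIP_PLANS_NOT_MERGEABLE")
    case _ => None
  }

  val df: Plan = Relation("df")
  val ok: Zip = Zip(Project(Seq("a"), df), Project(Seq("b"), df))
  val withUdf: Zip = Zip(Project(Seq("a"), df), Project(Seq("g"), df, scalarOnly = false))
  val differentBase: Zip = Zip(Project(Seq("a"), df), Project(Seq("x"), Relation("other")))
}
```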



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
