This is an automated email from the ASF dual-hosted git repository.

cloud-fan pushed a commit to branch branch-4.x
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-4.x by this push:
     new c7d4cdd687b5 [SPARK-56977][SQL] RewriteNearestByJoin should respect joinType in the synthetic join
c7d4cdd687b5 is described below

commit c7d4cdd687b5083a8e4998c550822fb7824e87cc
Author: Zero Qu <[email protected]>
AuthorDate: Fri May 22 16:45:57 2026 +0800

    [SPARK-56977][SQL] RewriteNearestByJoin should respect joinType in the synthetic join
    
    ### What changes were proposed in this pull request?
    
    Backport https://github.com/apache/spark/pull/56052/changes/0c83226c80df5fcb2b68c7543b6af30a461d662a to branch-4.x.
    
    ### Why are the changes needed?
    
    The original implementation hardcoded the synthetic join to `LeftOuter`,
    justifying it on the grounds that `LEFT OUTER` and `INNER` are equivalent
    for an unconditioned join when the right side is non-empty, and that
    `Generate(outer = false)` drops the unwanted rows for `INNER` when the right
    side is empty.
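
    The equivalence argument above can be checked with a small plain-Scala
    model. This is a sketch under stated assumptions, not Spark code: rows are
    plain values, `None` stands in for `NULL` right-side columns, and
    `UnconditionedJoinModel`, `inner`, `leftOuter`, and `oldInnerPath` are
    hypothetical names. It shows that an unconditioned `LEFT OUTER` join equals
    `INNER` whenever the right side is non-empty. It also shows that on an
    empty right side the old `INNER` path still materializes one row per left
    row before discarding it.

    ```scala
    // Plain-Scala model of an unconditioned (condition-free) join. This is NOT
    // Spark code: rows are plain values and `None` stands in for NULL
    // right-side columns. All names here are hypothetical.
    object UnconditionedJoinModel {
      // INNER with no condition: the full cross product; empty if `right` is empty.
      def inner[L, R](left: Seq[L], right: Seq[R]): Seq[(L, Option[R])] =
        for {
          l <- left
          r <- right
        } yield (l, Some(r))

      // LEFT OUTER with no condition: every left row matches every right row, so
      // this equals `inner` when `right` is non-empty. Only an empty `right`
      // makes it differ, emitting one NULL-padded row per left row.
      def leftOuter[L, R](left: Seq[L], right: Seq[R]): Seq[(L, Option[R])] =
        if (right.nonEmpty) inner(left, right)
        else left.map(l => (l, None))

      // Old INNER NEAREST BY shape (modeled): hardcoded LEFT OUTER, then drop
      // the NULL-padded rows later, standing in for Generate(outer = false)
      // plus the size(matches) > 0 filter.
      // Returns (rows produced by the join, rows that survive).
      def oldInnerPath[L, R](left: Seq[L], right: Seq[R]): (Int, Seq[(L, Option[R])]) = {
        val joined = leftOuter(left, right)
        (joined.size, joined.filter(_._2.isDefined))
      }
    }
    ```

    With three left rows and an empty right side, `oldInnerPath` reports three
    joined rows and none surviving. By contrast, `inner` produces nothing in
    the first place.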
    
    That reasoning holds for correctness, but hardcoding `LeftOuter` has real
    costs, most importantly for performance:
    
    - **`INNER NEAREST BY` cannot be planned as a Cartesian product.** Spark's
      strategy picks `CartesianProductExec` only for `Inner` joins with no
      condition; an unconditioned `LeftOuter` join falls back to
      `BroadcastNestedLoopJoin`, which tries to broadcast the right side. When
      the right relation is large, the broadcast either OOMs or exceeds
      `spark.sql.autoBroadcastJoinThreshold`, and the planner is left with no
      good option. `CartesianProductExec` partitions both sides and streams
      pairs, [...]
    - The hardcoded join type also makes the `EXPLAIN` output misleading: it
      shows `LeftOuter` even though the user wrote `INNER`.
    - For `INNER` with an empty right side, the old plan generates one row per
      left input and then filters them away via `Generate(outer = false)` and
      the `size(matches) > 0` filter -- extra work that respecting `joinType`
      avoids at the source.
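
    The planning consequence can be illustrated with a deliberately simplified,
    hypothetical model of operator choice for an unconditioned join. It encodes
    only the rules stated above. It adds one assumption of its own: when both
    sides could be broadcast, the model prefers the smaller side. It is not
    Spark's `JoinSelection` code, and `UnconditionedJoinPlanner`, `choose`, and
    the byte-size inputs are illustrative names.

    ```scala
    // Hypothetical, simplified model of physical-operator choice for an
    // unconditioned join. This is NOT Spark's JoinSelection code; names and
    // size inputs are illustrative only.
    object UnconditionedJoinPlanner {
      sealed trait JoinType
      case object Inner extends JoinType
      case object LeftOuter extends JoinType

      sealed trait PhysicalJoin
      final case class BroadcastNestedLoop(buildSide: String) extends PhysicalJoin
      case object CartesianProduct extends PhysicalJoin

      def choose(
          joinType: JoinType,
          leftBytes: Long,
          rightBytes: Long,
          broadcastThreshold: Long): PhysicalJoin = {
        val leftFits = leftBytes <= broadcastThreshold
        val rightFits = rightBytes <= broadcastThreshold
        joinType match {
          // INNER may build either side; this model prefers the smaller side
          // that fits (ties go to the right side).
          case Inner if leftFits || rightFits =>
            if (leftFits && (!rightFits || leftBytes < rightBytes)) BroadcastNestedLoop("BuildLeft")
            else BroadcastNestedLoop("BuildRight")
          // INNER with neither side broadcastable: partition both sides instead.
          case Inner => CartesianProduct
          // LEFT OUTER must preserve unmatched left rows, so only the right side
          // can be the broadcast side, however large it is.
          case LeftOuter => BroadcastNestedLoop("BuildRight")
        }
      }
    }
    ```

    In this model, once both sides exceed the threshold an `INNER` join gets
    `CartesianProduct`, while `LEFT OUTER` is stuck broadcasting the right side.
    Because `INNER` may also build the left side, it can yield plans shaped like
    the updated golden `EXPLAIN` output (`BuildRight, LeftOuter` becoming
    `BuildLeft, Inner`).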
    
    ### Does this PR introduce _any_ user-facing change?
    
    No change in query results. `EXPLAIN` output for `INNER NEAREST BY` queries
    now shows `Inner` rather than `LeftOuter` for the synthetic join node, and
    the physical plan for such queries can now use `CartesianProductExec`
    instead of `BroadcastNestedLoopJoin` when the right relation is too large
    to broadcast.
    
    ### How was this patch tested?
    
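    - `RewriteNearestByJoinSuite`: `expectedRewrite` now takes a
      `joinType: JoinType`, and the existing tests (similarity/distance x
      inner/left outer, EXACT, boundary k, self-join, nondeterministic ranking)
      assert that the synthetic join matches the user's join type. A new test,
      "synthetic Join uses the user's joinType", checks `Inner` and `LeftOuter`
      directly, including the matching `Generate.outer` flag.
    - Golden files `sql-tests/results/join-nearest-by.sql.out` and
      `sql-tests/analyzer-results/join-nearest-by.sql.out`, including a new
      `INNER JOIN` query with an empty right side.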
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    Coauthored-by: Claude Code (Opus 4.7), human-reviewed and tested
    
    Closes #56052 from zhidongqu-db/nearest-rewrite-fix-branch-4.x.
    
    Authored-by: Zero Qu <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../catalyst/optimizer/RewriteNearestByJoin.scala  | 14 +++---
 .../optimizer/RewriteNearestByJoinSuite.scala      | 51 +++++++++++++++++-----
 .../analyzer-results/join-nearest-by.sql.out       | 21 +++++++++
 .../resources/sql-tests/inputs/join-nearest-by.sql |  5 +++
 .../sql-tests/results/join-nearest-by.sql.out      | 26 +++++++----
 5 files changed, 90 insertions(+), 27 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteNearestByJoin.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteNearestByJoin.scala
index 3d45855cd60d..e920bbfffc55 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteNearestByJoin.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteNearestByJoin.scala
@@ -41,7 +41,7 @@ import org.apache.spark.sql.catalyst.rules._
  *         +- Aggregate [__qid],
  *              [first(left.col0) AS left.col0, ..., first(left.colN-1) AS left.colN-1,
  *               max_by(struct(right.*), expr, k) AS _matches]
- *             +- Join LeftOuter
+ *             +- Join Inner   // or LeftOuter for `LEFT OUTER NEAREST BY`
  *                :- Project [left.*, uuid() AS __qid]
  *                :  +- left
  *                +- right
@@ -79,18 +79,18 @@ object RewriteNearestByJoin extends Rule[LogicalPlan] {
       val taggedLeft = Project(left.output :+ qidAlias, left)
       val qidAttr = qidAlias.toAttribute
 
-      // 2. LEFT OUTER-join the tagged left with right (no join condition). LEFT OUTER
-      //    (rather than INNER) preserves left rows even when `right` is empty, so that a
-      //    `LEFT OUTER NEAREST BY` query still returns those rows with `NULL` right-side
-      //    columns after the aggregate + inline below. When `right` is non-empty every left
-      //    row already has right-row pairings, so LEFT OUTER and INNER are equivalent.
+      // 2. Join the tagged left with right (no join condition), using the user's join type.
+      //    For `LEFT OUTER`, left rows with no right-side match are preserved with `NULL`
+      //    right-side columns through the aggregate + inline below; for `INNER`, such rows
+      //    are dropped. When `right` is non-empty every left row already has right-row
+      //    pairings, so `LEFT OUTER` and `INNER` are equivalent in that case.
       //
       //    This synthetic join is an unconditioned cross-product, so `NEAREST BY` queries
       //    are subject to `CheckCartesianProducts` and will be rejected when the user has
       //    set `spark.sql.crossJoin.enabled = false`. That is intentional: if the user has
       //    opted out of cross-products, the NEAREST BY rewrite -- which is itself a bounded
       //    cross-product today -- should not silently bypass that choice.
-      val join = Join(taggedLeft, right, LeftOuter, None, JoinHint.NONE)
+      val join = Join(taggedLeft, right, joinType, None, JoinHint.NONE)
 
       val (aggInput, rankingForAgg) = if (!rankingExpression.deterministic) {
         val rankingAlias = Alias(rankingExpression, "__ranking__")()
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteNearestByJoinSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteNearestByJoinSuite.scala
index 650bdc7a6c35..729b58394d4b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteNearestByJoinSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteNearestByJoinSuite.scala
@@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, CreateStruct, Inline, Literal, Rand, Uuid}
 import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, First, MaxMinByK}
-import org.apache.spark.sql.catalyst.plans.{Inner, LeftOuter, NearestByDistance, NearestBySimilarity, PlanTest}
+import org.apache.spark.sql.catalyst.plans.{Inner, JoinType, LeftOuter, NearestByDistance, NearestBySimilarity, PlanTest}
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Generate, Join, JoinHint, LocalRelation, NearestByJoin, Project}
 import org.apache.spark.sql.types.IntegerType
 
@@ -41,10 +41,10 @@ class RewriteNearestByJoinSuite extends PlanTest {
       numResults: Int,
       ranking: org.apache.spark.sql.catalyst.expressions.Expression,
       reverse: Boolean,
-      outer: Boolean) = {
+      joinType: JoinType) = {
     val qidAlias = Alias(Uuid(Some(0L)), "__qid")()
     val taggedLeft = Project(left.output :+ qidAlias, left)
-    val join = Join(taggedLeft, right, LeftOuter, None, JoinHint.NONE)
+    val join = Join(taggedLeft, right, joinType, None, JoinHint.NONE)
 
     val rightStruct = CreateStruct(right.output)
     val topKAgg = MaxMinByK(
@@ -66,7 +66,7 @@ class RewriteNearestByJoinSuite extends PlanTest {
     val generate = Generate(
       Inline(matchesAlias.toAttribute),
       unrequiredChildIndex = Seq(aggregate.output.indexOf(matchesAlias.toAttribute)),
-      outer = outer,
+      outer = joinType == LeftOuter,
       qualifier = None,
       generatorOutput = generatorOutput,
       child = aggregate)
@@ -89,7 +89,7 @@ class RewriteNearestByJoinSuite extends PlanTest {
     val expected = expectedRewrite(
       left, right, 5,
       ranking = left.output(0) + right.output(0),
-      reverse = false, outer = false)
+      reverse = false, joinType = Inner)
 
     comparePlans(normalizeUuidSeed(rewritten), expected, checkAnalysis = false)
   }
@@ -106,7 +106,7 @@ class RewriteNearestByJoinSuite extends PlanTest {
     val expected = expectedRewrite(
       left, right, 3,
       ranking = left.output(0) - right.output(0),
-      reverse = true, outer = false)
+      reverse = true, joinType = Inner)
 
     comparePlans(normalizeUuidSeed(rewritten), expected, checkAnalysis = false)
   }
@@ -123,7 +123,7 @@ class RewriteNearestByJoinSuite extends PlanTest {
     val expected = expectedRewrite(
       left, right, 1,
       ranking = left.output(0) + right.output(0),
-      reverse = false, outer = true)
+      reverse = false, joinType = LeftOuter)
 
     comparePlans(normalizeUuidSeed(rewritten), expected, checkAnalysis = false)
   }
@@ -140,11 +140,38 @@ class RewriteNearestByJoinSuite extends PlanTest {
     val expected = expectedRewrite(
       left, right, 2,
       ranking = left.output(0) - right.output(0),
-      reverse = true, outer = true)
+      reverse = true, joinType = LeftOuter)
 
     comparePlans(normalizeUuidSeed(rewritten), expected, checkAnalysis = false)
   }
 
+  test("synthetic Join uses the user's joinType") {
+    // Locks in that the rewrite's synthetic Join carries the user's `joinType`
+    // (Inner or LeftOuter).
+    val left = LocalRelation($"a".int, $"b".int)
+    val right = LocalRelation($"x".int, $"y".int)
+    Seq(Inner, LeftOuter).foreach { joinType =>
+      val query = NearestByJoin(
+        left, right, joinType, approx = true, numResults = 1,
+        rankingExpression = left.output(0) + right.output(0),
+        direction = NearestBySimilarity)
+
+      val rewritten = RewriteNearestByJoin(query.analyze)
+      val syntheticJoin = rewritten.collect { case j: Join => j }
+      assert(syntheticJoin.size == 1,
+        s"expected exactly one synthetic Join in the rewritten plan, got ${syntheticJoin.size}")
+      assert(syntheticJoin.head.joinType == joinType,
+        s"expected synthetic Join to use $joinType, got ${syntheticJoin.head.joinType}")
+
+      val generate = rewritten.collect { case g: Generate => g }
+      assert(generate.size == 1,
+        s"expected exactly one Generate in the rewritten plan, got ${generate.size}")
+      val expectedOuter = joinType == LeftOuter
+      assert(generate.head.outer == expectedOuter,
+        s"expected Generate.outer == $expectedOuter for $joinType, got ${generate.head.outer}")
+    }
+  }
+
   test("EXACT (approx = false) produces the same rewrite as APPROX") {
     // Locks in the current invariant that APPROX and EXACT lower through the same
     // brute-force rewrite. If a future change diverges them (e.g. an APPROX-only
@@ -160,7 +187,7 @@ class RewriteNearestByJoinSuite extends PlanTest {
     val expected = expectedRewrite(
       left, right, 5,
       ranking = left.output(0) + right.output(0),
-      reverse = false, outer = false)
+      reverse = false, joinType = Inner)
 
     comparePlans(normalizeUuidSeed(rewritten), expected, checkAnalysis = false)
   }
@@ -177,7 +204,7 @@ class RewriteNearestByJoinSuite extends PlanTest {
     val expected = expectedRewrite(
       left, right, 1,
       ranking = left.output(0) + right.output(0),
-      reverse = false, outer = false)
+      reverse = false, joinType = Inner)
 
     comparePlans(normalizeUuidSeed(rewritten), expected, checkAnalysis = false)
   }
@@ -194,7 +221,7 @@ class RewriteNearestByJoinSuite extends PlanTest {
     val expected = expectedRewrite(
       left, right, NearestByJoin.MaxNumResults,
       ranking = left.output(0) + right.output(0),
-      reverse = false, outer = false)
+      reverse = false, joinType = Inner)
 
     comparePlans(normalizeUuidSeed(rewritten), expected, checkAnalysis = false)
   }
@@ -214,7 +241,7 @@ class RewriteNearestByJoinSuite extends PlanTest {
     val expected = expectedRewrite(
       t, tDup, 1,
       ranking = t.output(0) + tDup.output(0),
-      reverse = false, outer = false)
+      reverse = false, joinType = Inner)
 
     comparePlans(normalizeUuidSeed(rewritten), expected, checkAnalysis = false)
   }
diff --git 
a/sql/core/src/test/resources/sql-tests/analyzer-results/join-nearest-by.sql.out
 
b/sql/core/src/test/resources/sql-tests/analyzer-results/join-nearest-by.sql.out
index 7a795123cdcc..48819f172310 100644
--- 
a/sql/core/src/test/resources/sql-tests/analyzer-results/join-nearest-by.sql.out
+++ 
b/sql/core/src/test/resources/sql-tests/analyzer-results/join-nearest-by.sql.out
@@ -129,6 +129,27 @@ Project [user_id#x, product#x]
                      +- LocalRelation [col1#x, col2#x]
 
 
+-- !query
+SELECT u.user_id, p.product
+FROM users u INNER JOIN (SELECT * FROM products WHERE false) p
+  APPROX NEAREST 1 BY SIMILARITY -abs(u.score - p.pscore)
+-- !query analysis
+Project [user_id#x, product#x]
++- NearestByJoin Inner, true, 1, -abs((score#x - pscore#x)), NearestBySimilarity
+   :- SubqueryAlias u
+   :  +- SubqueryAlias spark_catalog.default.users
+   :     +- View (`spark_catalog`.`default`.`users`, [user_id#x, score#x])
+   :        +- Project [cast(col1#x as int) AS user_id#x, cast(col2#x as decimal(3,1)) AS score#x]
+   :           +- LocalRelation [col1#x, col2#x]
+   +- SubqueryAlias p
+      +- Project [product#x, pscore#x]
+         +- Filter false
+            +- SubqueryAlias spark_catalog.default.products
+               +- View (`spark_catalog`.`default`.`products`, [product#x, pscore#x])
+                  +- Project [cast(col1#x as string) AS product#x, cast(col2#x as decimal(3,1)) AS pscore#x]
+                     +- LocalRelation [col1#x, col2#x]
+
+
 -- !query
 SELECT u.user_id, p.product
 FROM users u INNER JOIN products p
diff --git a/sql/core/src/test/resources/sql-tests/inputs/join-nearest-by.sql b/sql/core/src/test/resources/sql-tests/inputs/join-nearest-by.sql
index 20b9b2fb7316..6b3dc63d28e3 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/join-nearest-by.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/join-nearest-by.sql
@@ -36,6 +36,11 @@ SELECT u.user_id, p.product
 FROM users u LEFT OUTER JOIN (SELECT * FROM products WHERE false) p
   APPROX NEAREST 1 BY SIMILARITY -abs(u.score - p.pscore);
 
+-- INNER JOIN with NEAREST BY, empty right side
+SELECT u.user_id, p.product
+FROM users u INNER JOIN (SELECT * FROM products WHERE false) p
+  APPROX NEAREST 1 BY SIMILARITY -abs(u.score - p.pscore);
+
 -- Explicit INNER keyword
 SELECT u.user_id, p.product
 FROM users u INNER JOIN products p
diff --git a/sql/core/src/test/resources/sql-tests/results/join-nearest-by.sql.out b/sql/core/src/test/resources/sql-tests/results/join-nearest-by.sql.out
index 286c61723b28..d06fb53686e7 100644
--- a/sql/core/src/test/resources/sql-tests/results/join-nearest-by.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/join-nearest-by.sql.out
@@ -90,6 +90,16 @@ struct<user_id:int,product:string>
 3      NULL
 
 
+-- !query
+SELECT u.user_id, p.product
+FROM users u INNER JOIN (SELECT * FROM products WHERE false) p
+  APPROX NEAREST 1 BY SIMILARITY -abs(u.score - p.pscore)
+-- !query schema
+struct<user_id:int,product:string>
+-- !query output
+
+
+
 -- !query
 SELECT u.user_id, p.product
 FROM users u INNER JOIN products p
@@ -286,12 +296,12 @@ AdaptiveSparkPlan isFinalPlan=false
                   +- SortAggregate(key=[__qid#x], functions=[partial_first(user_id#x, false), partial_max_by(named_struct(product, product#x, pscore, pscore#x), __ranking__#x, 1, false, 0, 0)])
                      +- Sort [__qid#x ASC NULLS FIRST], false, 0
                         +- Project [user_id#x, __qid#x, product#x, pscore#x, (rand(0) + cast(pscore#x as double)) AS __ranking__#x]
-                           +- BroadcastNestedLoopJoin BuildRight, LeftOuter
-                              :- Project [col1#x AS user_id#x, uuid(Some(x)) AS __qid#x]
-                              :  +- LocalTableScan [col1#x, col2#x]
-                              +- BroadcastExchange IdentityBroadcastMode, [plan_id=x]
-                                 +- Project [col1#x AS product#x, col2#x AS pscore#x]
-                                    +- LocalTableScan [col1#x, col2#x]
+                           +- BroadcastNestedLoopJoin BuildLeft, Inner
+                              :- BroadcastExchange IdentityBroadcastMode, [plan_id=x]
+                              :  +- Project [col1#x AS user_id#x, uuid(Some(x)) AS __qid#x]
+                              :     +- LocalTableScan [col1#x, col2#x]
+                              +- Project [col1#x AS product#x, col2#x AS pscore#x]
+                                 +- LocalTableScan [col1#x, col2#x]
 
 
 -- !query
@@ -313,7 +323,7 @@ AdaptiveSparkPlan isFinalPlan=false
               +- Exchange hashpartitioning(__qid#x, 4), ENSURE_REQUIREMENTS, [plan_id=x]
                  +- SortAggregate(key=[__qid#x], functions=[partial_first(user_id#x, false), partial_min_by(named_struct(product, product#x, pscore, pscore#x), abs((score#x - pscore#x)), 2, true, 0, 0)])
                      +- Sort [__qid#x ASC NULLS FIRST], false, 0
-                        +- BroadcastNestedLoopJoin BuildRight, LeftOuter
+                        +- BroadcastNestedLoopJoin BuildRight, Inner
                            :- Filter (user_id#x > 1)
                            :  +- Project [col1#x AS user_id#x, col2#x AS score#x, uuid(Some(x)) AS __qid#x]
                            :     +- LocalTableScan [col1#x, col2#x]
@@ -342,7 +352,7 @@ AdaptiveSparkPlan isFinalPlan=false
                  +- Exchange hashpartitioning(__qid#x, 4), ENSURE_REQUIREMENTS, [plan_id=x]
                     +- SortAggregate(key=[__qid#x], functions=[partial_first(user_id#x, false), partial_min_by(named_struct(product, product#x, pscore, pscore#x), abs((score#x - pscore#x)), 2, true, 0, 0)])
                         +- Sort [__qid#x ASC NULLS FIRST], false, 0
-                           +- BroadcastNestedLoopJoin BuildRight, LeftOuter
+                           +- BroadcastNestedLoopJoin BuildRight, Inner
                              :- Project [col1#x AS user_id#x, col2#x AS score#x, uuid(Some(x)) AS __qid#x]
                               :  +- LocalTableScan [col1#x, col2#x]
                              +- BroadcastExchange IdentityBroadcastMode, [plan_id=x]
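
The rewrite in the patch above can be modeled end to end in plain Scala. The sketch below is not Spark code. It makes three assumptions: it covers the similarity direction only (the highest score wins), it uses the row index in place of `uuid()` for `__qid`, and it represents `NULL` right-side columns as `None`. `NearestByRewriteModel` and `nearestBy` are hypothetical names. The model mirrors the new golden-file query: `INNER` with an empty right side returns no rows, while `LEFT OUTER` keeps each left row with a `NULL` match.

```scala
// Plain-Scala model (NOT Spark code) of the NEAREST BY rewrite, similarity
// direction only. All names are hypothetical; `None` stands in for NULL
// right-side columns.
object NearestByRewriteModel {
  sealed trait JoinType
  case object Inner extends JoinType
  case object LeftOuter extends JoinType

  def nearestBy[L, R](left: Seq[L], right: Seq[R], joinType: JoinType, k: Int)(
      score: (L, R) => Double): Seq[(L, Option[R])] = {
    // 1. Tag each left row with a unique id (stands in for `uuid() AS __qid`).
    val tagged: Seq[(Int, L)] = left.zipWithIndex.map { case (l, qid) => (qid, l) }

    // 2. Unconditioned join using the user's join type. With no condition,
    //    LeftOuter only differs from Inner when `right` is empty.
    val joined: Seq[(Int, L, Option[R])] =
      if (joinType == LeftOuter && right.isEmpty)
        tagged.map { case (qid, l) => (qid, l, None) }
      else
        for { (qid, l) <- tagged; r <- right } yield (qid, l, Some(r))

    // 3. Group by __qid; keep the left row and the k best right rows
    //    (the role of first(left.*) and max_by(struct(right.*), expr, k)).
    val aggregated: Seq[(L, Seq[R])] =
      joined.groupBy(_._1).toSeq.sortBy(_._1).map { case (_, rows) =>
        val ranked = rows.flatMap { case (_, l, r) => r.map(rv => (rv, score(l, rv))) }
        (rows.head._2, ranked.sortWith(_._2 > _._2).take(k).map(_._1))
      }

    // 4. Inline the matches with outer = (joinType == LeftOuter): a left row
    //    with no matches survives as one NULL-padded row only for LeftOuter.
    val outer = joinType == LeftOuter
    aggregated.flatMap { case (l, matches) =>
      val rows: Seq[(L, Option[R])] =
        if (matches.isEmpty && outer) Seq((l, None)) else matches.map(m => (l, Some(m)))
      rows
    }
  }
}
```

For example, take left scores `1.0, 4.0, 9.0`, right scores `2.0, 8.0`, and `k = 1`. Both join types return `(1.0, Some(2.0))`, `(4.0, Some(2.0))`, `(9.0, Some(8.0))`. With an empty right side, only `LeftOuter` returns rows, each paired with `None`.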

