This is an automated email from the ASF dual-hosted git repository.
dtenedor pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 525caed9c660 [SPARK-56977][SQL] RewriteNearestByJoin should respect
joinType in the synthetic join
525caed9c660 is described below
commit 525caed9c6609204b3b32ce1d1c237d7ed37bb62
Author: Zero Qu <[email protected]>
AuthorDate: Thu May 21 14:06:59 2026 -0700
[SPARK-56977][SQL] RewriteNearestByJoin should respect joinType in the
synthetic join
### What changes were proposed in this pull request?
This PR changes `RewriteNearestByJoin` to construct its synthetic
cross-join with the user's `joinType` (`Inner` or `LeftOuter`) instead of
always using `LeftOuter`. The `Generate` operator's `outer` flag continues to
be derived from `joinType == LeftOuter`, so the externally observable semantics
are unchanged.
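
As a rough illustration of the change described above, the following self-contained sketch models how one `joinType` value now drives both the synthetic join and the `Generate` flag. It uses toy stand-in types, not Spark's Catalyst classes; `RewriteSketch`, `SyntheticJoin`, `GenerateStep`, and `rewrite` are hypothetical names introduced only for this example.

```scala
object RewriteSketch {
  // Toy stand-ins for Catalyst's join types (assumption: not the real classes).
  sealed trait JoinType
  case object Inner extends JoinType
  case object LeftOuter extends JoinType

  // Minimal models of the two plan nodes affected by the change.
  final case class SyntheticJoin(joinType: JoinType)
  final case class GenerateStep(outer: Boolean)

  // After this change, the user's join type is passed through to the join,
  // and the Generate flag is still derived from `joinType == LeftOuter`.
  def rewrite(joinType: JoinType): (SyntheticJoin, GenerateStep) =
    (SyntheticJoin(joinType), GenerateStep(outer = joinType == LeftOuter))
}
```

Before the change, the first element would have been `SyntheticJoin(LeftOuter)` for both inputs; the second element is the same before and after.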
### Why are the changes needed?
The original implementation hardcoded the synthetic join to `LeftOuter` and
justified it on the grounds that `LEFT OUTER` and `INNER` are equivalent for an
unconditioned join when the right side is non-empty, and `Generate(outer =
false)` would drop unwanted rows for `INNER` when right is empty.
That reasoning holds for correctness but has a major performance cost:
- **`INNER NEAREST BY` cannot be planned as a Cartesian product.** Spark's
strategy picks `CartesianProductExec` only for `Inner` joins with no condition;
an unconditioned `LeftOuter` join falls back to `BroadcastNestedLoopJoin`,
which tries to broadcast the right side. When the right relation is large, the
broadcast either OOMs or exceeds `spark.sql.autoBroadcastJoinThreshold` and the
planner is left with no good option. `CartesianProductExec` partitions both
sides and streams pairs, [...]
- Hardcoding `LeftOuter` also makes the `EXPLAIN` output misleading (it shows
`LeftOuter` even though the user wrote `INNER`).
- For `INNER` with an empty right side, the old plan generates one row per
left input and then filters them away via `Generate(outer = false)` and the
`size(matches) > 0` filter -- extra work that respecting `joinType` avoids at
the source.
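
The equivalence argument above can be checked on a toy model. The sketch below is not Spark code: it simulates the join, top-k aggregate, and generate steps over plain integer sequences, assuming distinct left values (standing in for `__qid`) and `|l - r|` as the ranking. All names (`NearestBySemanticsSketch`, `crossJoin`, `nearest`) are hypothetical.

```scala
object NearestBySemanticsSketch {
  sealed trait JoinType
  case object Inner extends JoinType
  case object LeftOuter extends JoinType

  // Unconditioned join: every left row pairs with every right row.
  // LeftOuter keeps a left row as (l, None) when `right` is empty.
  def crossJoin(left: Seq[Int], right: Seq[Int], jt: JoinType): Seq[(Int, Option[Int])] =
    left.flatMap { l =>
      val pairs = right.map(r => (l, Option(r)))
      if (pairs.isEmpty && jt == LeftOuter) Seq((l, Option.empty[Int])) else pairs
    }

  // syntheticJoin: join type used inside the rewrite.
  // userJoin: join type the user wrote; it controls the Generate `outer` flag.
  def nearest(
      left: Seq[Int],
      right: Seq[Int],
      k: Int,
      syntheticJoin: JoinType,
      userJoin: JoinType): Seq[(Int, Option[Int])] = {
    val grouped = crossJoin(left, right, syntheticJoin).groupBy(_._1)
    left.filter(grouped.contains).flatMap { l =>
      // Top-k aggregate: keep the k right rows closest to l.
      val matches = grouped(l).flatMap(_._2).sortBy(r => math.abs(l - r)).take(k)
      if (matches.nonEmpty) matches.map(r => (l, Option(r)))
      else if (userJoin == LeftOuter) Seq((l, Option.empty[Int])) // outer = true
      else Seq.empty                                              // outer = false
    }
  }
}
```

Calling `nearest` with `syntheticJoin = LeftOuter, userJoin = Inner` (the old plan) and with `syntheticJoin = Inner, userJoin = Inner` (the new plan) returns the same rows for both empty and non-empty `right`. The difference is that with an empty `right` the old plan first materializes one row per left input and then discards it.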
### Does this PR introduce _any_ user-facing change?
No change in query results. `EXPLAIN` output for `INNER NEAREST BY` queries
now shows `Inner` rather than `LeftOuter` for the synthetic join node, and the
physical plan for such queries can now use `CartesianProductExec` instead of
`BroadcastNestedLoopJoin` when the right relation is too large to broadcast.
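
The planning consequence can be summarized with a deliberately simplified decision function. This is a sketch of the behavior described in this message, not Spark's actual join-selection code: it ignores hints, build-side choice, and adaptive execution, and `JoinOperatorSketch`, `pickOperator`, and `fitsBroadcast` are hypothetical names.

```scala
object JoinOperatorSketch {
  sealed trait JoinType
  case object Inner extends JoinType
  case object LeftOuter extends JoinType

  // Simplified model for an unconditioned join:
  //  - a side small enough to broadcast leads to a broadcast nested loop join;
  //  - otherwise only an Inner join can be planned as a Cartesian product;
  //  - an unconditioned LeftOuter join falls back to a broadcast nested loop join.
  def pickOperator(joinType: JoinType, fitsBroadcast: Boolean): String =
    if (fitsBroadcast) "BroadcastNestedLoopJoin"
    else if (joinType == Inner) "CartesianProduct"
    else "BroadcastNestedLoopJoin"
}
```

Under this model, `pickOperator(Inner, fitsBroadcast = false)` yields `"CartesianProduct"`, whereas the previously hardcoded `LeftOuter` yields `"BroadcastNestedLoopJoin"` for the same input.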
### How was this patch tested?
- `RewriteNearestByJoinSuite`: `expectedRewrite` now takes a `joinType:
JoinType` and the existing tests (similarity/distance x inner/leftouter, EXACT,
boundary k, self-join, nondeterministic ranking) assert the synthetic join
matches the user's join type.
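- Golden file `sql-tests/results/join-nearest-by.sql.out` regenerated, with a
new `INNER` case over an empty right side added to
`sql-tests/inputs/join-nearest-by.sql`.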
### Was this patch authored or co-authored using generative AI tooling?
Coauthored-by: Claude Code (Opus 4.7), human-reviewed and tested
Closes #56023 from zhidongqu-db/respect-nn-join-type.
Authored-by: Zero Qu <[email protected]>
Signed-off-by: Daniel Tenedorio <[email protected]>
---
.../catalyst/optimizer/RewriteNearestByJoin.scala | 14 +++---
.../optimizer/RewriteNearestByJoinSuite.scala | 51 +++++++++++++++++-----
.../analyzer-results/join-nearest-by.sql.out | 21 +++++++++
.../resources/sql-tests/inputs/join-nearest-by.sql | 5 +++
.../sql-tests/results/join-nearest-by.sql.out | 26 +++++++----
5 files changed, 90 insertions(+), 27 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteNearestByJoin.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteNearestByJoin.scala
index 3d45855cd60d..e920bbfffc55 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteNearestByJoin.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteNearestByJoin.scala
@@ -41,7 +41,7 @@ import org.apache.spark.sql.catalyst.rules._
* +- Aggregate [__qid],
* [first(left.col0) AS left.col0, ..., first(left.colN-1) AS left.colN-1,
* max_by(struct(right.*), expr, k) AS _matches]
- * +- Join LeftOuter
+ * +- Join Inner // or LeftOuter for `LEFT OUTER NEAREST BY`
* :- Project [left.*, uuid() AS __qid]
* : +- left
* +- right
@@ -79,18 +79,18 @@ object RewriteNearestByJoin extends Rule[LogicalPlan] {
val taggedLeft = Project(left.output :+ qidAlias, left)
val qidAttr = qidAlias.toAttribute
- // 2. LEFT OUTER-join the tagged left with right (no join condition). LEFT OUTER
- // (rather than INNER) preserves left rows even when `right` is empty, so that a
- // `LEFT OUTER NEAREST BY` query still returns those rows with `NULL` right-side
- // columns after the aggregate + inline below. When `right` is non-empty every left
- // row already has right-row pairings, so LEFT OUTER and INNER are equivalent.
+ // 2. Join the tagged left with right (no join condition), using the user's join type.
+ // For `LEFT OUTER`, left rows with no right-side match are preserved with `NULL`
+ // right-side columns through the aggregate + inline below; for `INNER`, such rows
+ // are dropped. When `right` is non-empty every left row already has right-row
+ // pairings, so `LEFT OUTER` and `INNER` are equivalent in that case.
//
// This synthetic join is an unconditioned cross-product, so `NEAREST BY` queries
// are subject to `CheckCartesianProducts` and will be rejected when the user has
// set `spark.sql.crossJoin.enabled = false`. That is intentional: if the user has
// opted out of cross-products, the NEAREST BY rewrite -- which is itself a bounded
// cross-product today -- should not silently bypass that choice.
- val join = Join(taggedLeft, right, LeftOuter, None, JoinHint.NONE)
+ val join = Join(taggedLeft, right, joinType, None, JoinHint.NONE)
val (aggInput, rankingForAgg) = if (!rankingExpression.deterministic) {
val rankingAlias = Alias(rankingExpression, "__ranking__")()
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteNearestByJoinSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteNearestByJoinSuite.scala
index 650bdc7a6c35..729b58394d4b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteNearestByJoinSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteNearestByJoinSuite.scala
@@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, CreateStruct, Inline, Literal, Rand, Uuid}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, First, MaxMinByK}
-import org.apache.spark.sql.catalyst.plans.{Inner, LeftOuter, NearestByDistance, NearestBySimilarity, PlanTest}
+import org.apache.spark.sql.catalyst.plans.{Inner, JoinType, LeftOuter, NearestByDistance, NearestBySimilarity, PlanTest}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Generate, Join, JoinHint, LocalRelation, NearestByJoin, Project}
import org.apache.spark.sql.types.IntegerType
@@ -41,10 +41,10 @@ class RewriteNearestByJoinSuite extends PlanTest {
numResults: Int,
ranking: org.apache.spark.sql.catalyst.expressions.Expression,
reverse: Boolean,
- outer: Boolean) = {
+ joinType: JoinType) = {
val qidAlias = Alias(Uuid(Some(0L)), "__qid")()
val taggedLeft = Project(left.output :+ qidAlias, left)
- val join = Join(taggedLeft, right, LeftOuter, None, JoinHint.NONE)
+ val join = Join(taggedLeft, right, joinType, None, JoinHint.NONE)
val rightStruct = CreateStruct(right.output)
val topKAgg = MaxMinByK(
@@ -66,7 +66,7 @@ class RewriteNearestByJoinSuite extends PlanTest {
val generate = Generate(
Inline(matchesAlias.toAttribute),
unrequiredChildIndex = Seq(aggregate.output.indexOf(matchesAlias.toAttribute)),
- outer = outer,
+ outer = joinType == LeftOuter,
qualifier = None,
generatorOutput = generatorOutput,
child = aggregate)
@@ -89,7 +89,7 @@ class RewriteNearestByJoinSuite extends PlanTest {
val expected = expectedRewrite(
left, right, 5,
ranking = left.output(0) + right.output(0),
- reverse = false, outer = false)
+ reverse = false, joinType = Inner)
comparePlans(normalizeUuidSeed(rewritten), expected, checkAnalysis = false)
}
@@ -106,7 +106,7 @@ class RewriteNearestByJoinSuite extends PlanTest {
val expected = expectedRewrite(
left, right, 3,
ranking = left.output(0) - right.output(0),
- reverse = true, outer = false)
+ reverse = true, joinType = Inner)
comparePlans(normalizeUuidSeed(rewritten), expected, checkAnalysis = false)
}
@@ -123,7 +123,7 @@ class RewriteNearestByJoinSuite extends PlanTest {
val expected = expectedRewrite(
left, right, 1,
ranking = left.output(0) + right.output(0),
- reverse = false, outer = true)
+ reverse = false, joinType = LeftOuter)
comparePlans(normalizeUuidSeed(rewritten), expected, checkAnalysis = false)
}
@@ -140,11 +140,38 @@ class RewriteNearestByJoinSuite extends PlanTest {
val expected = expectedRewrite(
left, right, 2,
ranking = left.output(0) - right.output(0),
- reverse = true, outer = true)
+ reverse = true, joinType = LeftOuter)
comparePlans(normalizeUuidSeed(rewritten), expected, checkAnalysis = false)
}
+ test("synthetic Join uses the user's joinType") {
+ // Locks in that the rewrite's synthetic Join carries the user's `joinType`
+ // (Inner or LeftOuter).
+ val left = LocalRelation($"a".int, $"b".int)
+ val right = LocalRelation($"x".int, $"y".int)
+ Seq(Inner, LeftOuter).foreach { joinType =>
+ val query = NearestByJoin(
+ left, right, joinType, approx = true, numResults = 1,
+ rankingExpression = left.output(0) + right.output(0),
+ direction = NearestBySimilarity)
+
+ val rewritten = RewriteNearestByJoin(query.analyze)
+ val syntheticJoin = rewritten.collect { case j: Join => j }
+ assert(syntheticJoin.size == 1,
+ s"expected exactly one synthetic Join in the rewritten plan, got ${syntheticJoin.size}")
+ assert(syntheticJoin.head.joinType == joinType,
+ s"expected synthetic Join to use $joinType, got ${syntheticJoin.head.joinType}")
+
+ val generate = rewritten.collect { case g: Generate => g }
+ assert(generate.size == 1,
+ s"expected exactly one Generate in the rewritten plan, got ${generate.size}")
+ val expectedOuter = joinType == LeftOuter
+ assert(generate.head.outer == expectedOuter,
+ s"expected Generate.outer == $expectedOuter for $joinType, got ${generate.head.outer}")
+ }
+ }
+
test("EXACT (approx = false) produces the same rewrite as APPROX") {
// Locks in the current invariant that APPROX and EXACT lower through the same
// brute-force rewrite. If a future change diverges them (e.g. an APPROX-only
@@ -160,7 +187,7 @@ class RewriteNearestByJoinSuite extends PlanTest {
val expected = expectedRewrite(
left, right, 5,
ranking = left.output(0) + right.output(0),
- reverse = false, outer = false)
+ reverse = false, joinType = Inner)
comparePlans(normalizeUuidSeed(rewritten), expected, checkAnalysis = false)
}
@@ -177,7 +204,7 @@ class RewriteNearestByJoinSuite extends PlanTest {
val expected = expectedRewrite(
left, right, 1,
ranking = left.output(0) + right.output(0),
- reverse = false, outer = false)
+ reverse = false, joinType = Inner)
comparePlans(normalizeUuidSeed(rewritten), expected, checkAnalysis = false)
}
@@ -194,7 +221,7 @@ class RewriteNearestByJoinSuite extends PlanTest {
val expected = expectedRewrite(
left, right, NearestByJoin.MaxNumResults,
ranking = left.output(0) + right.output(0),
- reverse = false, outer = false)
+ reverse = false, joinType = Inner)
comparePlans(normalizeUuidSeed(rewritten), expected, checkAnalysis = false)
}
@@ -214,7 +241,7 @@ class RewriteNearestByJoinSuite extends PlanTest {
val expected = expectedRewrite(
t, tDup, 1,
ranking = t.output(0) + tDup.output(0),
- reverse = false, outer = false)
+ reverse = false, joinType = Inner)
comparePlans(normalizeUuidSeed(rewritten), expected, checkAnalysis = false)
}
diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/join-nearest-by.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/join-nearest-by.sql.out
index 7a795123cdcc..48819f172310 100644
--- a/sql/core/src/test/resources/sql-tests/analyzer-results/join-nearest-by.sql.out
+++ b/sql/core/src/test/resources/sql-tests/analyzer-results/join-nearest-by.sql.out
@@ -129,6 +129,27 @@ Project [user_id#x, product#x]
+- LocalRelation [col1#x, col2#x]
+-- !query
+SELECT u.user_id, p.product
+FROM users u INNER JOIN (SELECT * FROM products WHERE false) p
+ APPROX NEAREST 1 BY SIMILARITY -abs(u.score - p.pscore)
+-- !query analysis
+Project [user_id#x, product#x]
++- NearestByJoin Inner, true, 1, -abs((score#x - pscore#x)),
NearestBySimilarity
+ :- SubqueryAlias u
+ : +- SubqueryAlias spark_catalog.default.users
+ : +- View (`spark_catalog`.`default`.`users`, [user_id#x, score#x])
+ : +- Project [cast(col1#x as int) AS user_id#x, cast(col2#x as
decimal(3,1)) AS score#x]
+ : +- LocalRelation [col1#x, col2#x]
+ +- SubqueryAlias p
+ +- Project [product#x, pscore#x]
+ +- Filter false
+ +- SubqueryAlias spark_catalog.default.products
+ +- View (`spark_catalog`.`default`.`products`, [product#x,
pscore#x])
+ +- Project [cast(col1#x as string) AS product#x, cast(col2#x
as decimal(3,1)) AS pscore#x]
+ +- LocalRelation [col1#x, col2#x]
+
+
-- !query
SELECT u.user_id, p.product
FROM users u INNER JOIN products p
diff --git a/sql/core/src/test/resources/sql-tests/inputs/join-nearest-by.sql b/sql/core/src/test/resources/sql-tests/inputs/join-nearest-by.sql
index 20b9b2fb7316..6b3dc63d28e3 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/join-nearest-by.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/join-nearest-by.sql
@@ -36,6 +36,11 @@ SELECT u.user_id, p.product
FROM users u LEFT OUTER JOIN (SELECT * FROM products WHERE false) p
APPROX NEAREST 1 BY SIMILARITY -abs(u.score - p.pscore);
+-- INNER JOIN with NEAREST BY, empty right side
+SELECT u.user_id, p.product
+FROM users u INNER JOIN (SELECT * FROM products WHERE false) p
+ APPROX NEAREST 1 BY SIMILARITY -abs(u.score - p.pscore);
+
-- Explicit INNER keyword
SELECT u.user_id, p.product
FROM users u INNER JOIN products p
diff --git a/sql/core/src/test/resources/sql-tests/results/join-nearest-by.sql.out b/sql/core/src/test/resources/sql-tests/results/join-nearest-by.sql.out
index 286c61723b28..d06fb53686e7 100644
--- a/sql/core/src/test/resources/sql-tests/results/join-nearest-by.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/join-nearest-by.sql.out
@@ -90,6 +90,16 @@ struct<user_id:int,product:string>
3 NULL
+-- !query
+SELECT u.user_id, p.product
+FROM users u INNER JOIN (SELECT * FROM products WHERE false) p
+ APPROX NEAREST 1 BY SIMILARITY -abs(u.score - p.pscore)
+-- !query schema
+struct<user_id:int,product:string>
+-- !query output
+
+
+
-- !query
SELECT u.user_id, p.product
FROM users u INNER JOIN products p
@@ -286,12 +296,12 @@ AdaptiveSparkPlan isFinalPlan=false
+- SortAggregate(key=[__qid#x],
functions=[partial_first(user_id#x, false),
partial_max_by(named_struct(product, product#x, pscore, pscore#x),
__ranking__#x, 1, false, 0, 0)])
+- Sort [__qid#x ASC NULLS FIRST], false, 0
+- Project [user_id#x, __qid#x, product#x, pscore#x,
(rand(0) + cast(pscore#x as double)) AS __ranking__#x]
- +- BroadcastNestedLoopJoin BuildRight, LeftOuter
- :- Project [col1#x AS user_id#x, uuid(Some(x))
AS __qid#x]
- : +- LocalTableScan [col1#x, col2#x]
- +- BroadcastExchange IdentityBroadcastMode,
[plan_id=x]
- +- Project [col1#x AS product#x, col2#x AS
pscore#x]
- +- LocalTableScan [col1#x, col2#x]
+ +- BroadcastNestedLoopJoin BuildLeft, Inner
+ :- BroadcastExchange IdentityBroadcastMode,
[plan_id=x]
+ : +- Project [col1#x AS user_id#x,
uuid(Some(x)) AS __qid#x]
+ : +- LocalTableScan [col1#x, col2#x]
+ +- Project [col1#x AS product#x, col2#x AS
pscore#x]
+ +- LocalTableScan [col1#x, col2#x]
-- !query
@@ -313,7 +323,7 @@ AdaptiveSparkPlan isFinalPlan=false
+- Exchange hashpartitioning(__qid#x, 4), ENSURE_REQUIREMENTS,
[plan_id=x]
+- SortAggregate(key=[__qid#x],
functions=[partial_first(user_id#x, false),
partial_min_by(named_struct(product, product#x, pscore, pscore#x), abs((score#x
- pscore#x)), 2, true, 0, 0)])
+- Sort [__qid#x ASC NULLS FIRST], false, 0
- +- BroadcastNestedLoopJoin BuildRight, LeftOuter
+ +- BroadcastNestedLoopJoin BuildRight, Inner
:- Filter (user_id#x > 1)
: +- Project [col1#x AS user_id#x, col2#x AS
score#x, uuid(Some(x)) AS __qid#x]
: +- LocalTableScan [col1#x, col2#x]
@@ -342,7 +352,7 @@ AdaptiveSparkPlan isFinalPlan=false
+- Exchange hashpartitioning(__qid#x, 4),
ENSURE_REQUIREMENTS, [plan_id=x]
+- SortAggregate(key=[__qid#x],
functions=[partial_first(user_id#x, false),
partial_min_by(named_struct(product, product#x, pscore, pscore#x), abs((score#x
- pscore#x)), 2, true, 0, 0)])
+- Sort [__qid#x ASC NULLS FIRST], false, 0
- +- BroadcastNestedLoopJoin BuildRight, LeftOuter
+ +- BroadcastNestedLoopJoin BuildRight, Inner
:- Project [col1#x AS user_id#x, col2#x AS
score#x, uuid(Some(x)) AS __qid#x]
: +- LocalTableScan [col1#x, col2#x]
+- BroadcastExchange IdentityBroadcastMode,
[plan_id=x]