gengliangwang commented on code in PR #55629:
URL: https://github.com/apache/spark/pull/55629#discussion_r3184057165
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala:
##########
@@ -2541,14 +2542,36 @@ object CheckCartesianProducts extends Rule[LogicalPlan]
with PredicateHelper {
}
}
- def apply(plan: LogicalPlan): LogicalPlan =
+ def apply(plan: LogicalPlan): LogicalPlan = {
if (conf.crossJoinEnabled) {
- plan
- } else plan.transformWithPruning(_.containsAnyPattern(INNER_LIKE_JOIN, OUTER_JOIN)) {
+ return plan
+ }
+
+ // Joins synthesized by `RewriteNearestByJoin` are an intentional, bounded cross-product
+ // wrapped by a `MaxMinByK` aggregate. Identify them by their unambiguous post-rewrite
+ // signature -- `Aggregate(_, exprs, Join(_, _, LeftOuter, None, _))` where `exprs`
+ // contains a `MaxMinByK` -- and skip them so user queries written as `NEAREST BY` are not
+ // rejected when `spark.sql.crossJoin.enabled = false`. We use structural detection rather
+ // than a `TreeNodeTag` because a tag set on the `Join` would be silently dropped by any
+ // intervening optimizer rule that constructs a fresh `Join` via the case-class
+ // constructor without calling `copyTagsFrom`.
+ val nearestByJoins: java.util.IdentityHashMap[Join, Unit] = {
+ val acc = new java.util.IdentityHashMap[Join, Unit]()
+ plan.foreach {
+ case Aggregate(_, exprs, j @ Join(_, _, LeftOuter, None, _), _)
Review Comment:
**Structural skip misses the nondeterministic-ranking shape.** When
`rankingExpression` is nondeterministic (legal only under `APPROX`),
`RewriteNearestByJoin` injects a materializing `Project(__ranking__ alias,
...)` between the `Aggregate` and the `Join`
(`RewriteNearestByJoin.scala:86-91`). The synthetic plan is then `Aggregate(_,
exprs, Project(_, Join(LeftOuter, None, ...)))`, but this `case` only matches
when `Join` is the `Aggregate`'s direct child. So the synthetic join is never
put into `nearestByJoins`, and `transformWithPruning` below throws on it
whenever `spark.sql.crossJoin.enabled = false`.
The new `.sql` tests cover (a) `crossJoin = false` with deterministic
ranking and (b) `rand()` with the default `crossJoin = true`, but never the
intersection of the two, so this case goes uncaught.
Extend the pattern to skip an optional intervening `Project`, e.g.:
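```scala
// `acc` is the IdentityHashMap declared in the enclosing block.
plan.foreach {
  case Aggregate(_, exprs, child, _)
      if exprs.exists(_.exists(_.isInstanceOf[MaxMinByK])) =>
    child match {
      // Deterministic ranking: the Join is the Aggregate's direct child.
      case j @ Join(_, _, LeftOuter, None, _) => acc.put(j, ())
      // Nondeterministic ranking (APPROX): a materializing Project sits in between.
      case Project(_, j @ Join(_, _, LeftOuter, None, _)) => acc.put(j, ())
      case _ =>
    }
  case _ =>
}
```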
And add a `.sql` test combining `crossJoin.enabled = false` with `APPROX
NEAREST ... BY ... rand() ...`. Also, the doc comment a few lines above calls
the post-rewrite signature "unambiguous", which is overstated even in this
strengthened form: a user query of the shape `Aggregate(max_by(struct(...),
expr, k), Join(LeftOuter, None, ...))` would still pass through. The wording
is worth tightening.
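To make the missed shape concrete, here is a minimal, self-contained sketch on a toy plan model. The case classes `Leaf`, `Join`, `Project`, and `Aggregate` and the object `NearestBySkipSketch` are hypothetical stand-ins that only mimic the shape of Catalyst's operators, and "contains a `MaxMinByK`" is simplified to a substring check on string expressions. It compares the current direct-child pattern with the suggested optional-`Project` pattern, and it also shows why an identity map is used: a structurally equal `Join` is `==` to the synthetic one but is not a key in the `IdentityHashMap`.

```scala
import java.util.IdentityHashMap

// Toy stand-ins for Catalyst operators (hypothetical, for illustration only).
sealed trait Plan { def children: Seq[Plan] }
case class Leaf(name: String) extends Plan { def children: Seq[Plan] = Nil }
case class Join(left: Plan, right: Plan, joinType: String, condition: Option[String]) extends Plan {
  def children: Seq[Plan] = Seq(left, right)
}
case class Project(projectList: Seq[String], child: Plan) extends Plan {
  def children: Seq[Plan] = Seq(child)
}
case class Aggregate(aggregateExprs: Seq[String], child: Plan) extends Plan {
  def children: Seq[Plan] = Seq(child)
}

object NearestBySkipSketch {
  // Pre-order traversal, analogous to TreeNode.foreach.
  def foreach(p: Plan)(f: Plan => Unit): Unit = {
    f(p)
    p.children.foreach(c => foreach(c)(f))
  }

  // Simplification: "contains a MaxMinByK" is a substring check on string expressions.
  private def hasMaxMinByK(exprs: Seq[String]): Boolean = exprs.exists(_.contains("MaxMinByK"))

  // Current pattern: the Join must be the Aggregate's direct child.
  def directOnly(plan: Plan): IdentityHashMap[Join, Unit] = {
    val acc = new IdentityHashMap[Join, Unit]()
    foreach(plan) {
      case Aggregate(exprs, j @ Join(_, _, "LeftOuter", None)) if hasMaxMinByK(exprs) =>
        acc.put(j, ())
      case _ =>
    }
    acc
  }

  // Suggested pattern: additionally allow one intervening Project.
  def withOptionalProject(plan: Plan): IdentityHashMap[Join, Unit] = {
    val acc = new IdentityHashMap[Join, Unit]()
    foreach(plan) {
      case Aggregate(exprs, child) if hasMaxMinByK(exprs) =>
        child match {
          case j @ Join(_, _, "LeftOuter", None)             => acc.put(j, ())
          case Project(_, j @ Join(_, _, "LeftOuter", None)) => acc.put(j, ())
          case _                                             =>
        }
      case _ =>
    }
    acc
  }
}
```

Under this model, `directOnly` returns an empty map for `Aggregate(Seq("MaxMinByK(...)"), Project(Seq("__ranking__"), join))`, while `withOptionalProject` records `join`. A hand-written `Aggregate` over a condition-less `LeftOuter` join with a `MaxMinByK` expression is matched as well, which is the ambiguity noted above.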
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala:
##########
@@ -2420,3 +2420,58 @@ object AsOfJoin {
}
}
}
+
+object NearestByJoin {
+ /** Upper bound on `numResults`. Mirrors the K-overload limit of `MaxMinByK`. */
+ val MaxNumResults: Int = 100000
+}
+
+/**
+ * A logical plan for a nearest-by top-K ranking join. For each row on the left side it returns
+ * up to `numResults` rows from the right side ordered by `rankingExpression`:
+ * - `NearestByDistance`: smallest values of `rankingExpression` first.
+ * - `NearestBySimilarity`: largest values of `rankingExpression` first.
+ *
+ * The `approx` field records the user's APPROX/EXACT choice from the SPIP. Today both modes
+ * use the same brute-force rewrite. The flag is preserved on the logical plan so future
+ * indexed approximate-nearest-neighbor strategies can fire only when `approx = true`,
+ * leaving EXACT queries unaffected. See the SPIP linked from SPARK-56395.
+ */
+case class NearestByJoin(
+ left: LogicalPlan,
+ right: LogicalPlan,
+ joinType: JoinType,
+ approx: Boolean,
+ numResults: Int,
+ rankingExpression: Expression,
+ direction: NearestByDirection)
+ extends BinaryNode with SupportsNonDeterministicExpression {
+
+ require(Seq(Inner, LeftOuter).contains(joinType),
+ s"Unsupported nearest-by join type $joinType")
+
+ // APPROX permits a nondeterministic ranking expression (per the SPIP); the rewrite
Review Comment:
Comment is cut off mid-sentence (`"...; the rewrite"`). Please complete or
remove the dangling phrase.
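As a reference point for the `NearestByJoin` doc comment, the following is a minimal plain-collections sketch of the brute-force semantics it describes: every right row is scored against each left row, and up to `k` rows are kept, smallest scores first for distance and largest first for similarity. `Direction`, `Distance`, `Similarity`, and `NearestBySketch.nearestBy` are hypothetical names that only mirror `NearestByDirection`, `NearestByDistance`, and `NearestBySimilarity`; the `LeftOuter` null-padded row is modeled as an empty match list, `k > 0` is an assumption of the sketch, and it assumes Scala 2.13+ for `Ordering.Double.TotalOrdering`. It is not the `RewriteNearestByJoin` rewrite, which, per the diff comment above, expresses the same idea as a condition-less join wrapped by a `MaxMinByK` aggregate.

```scala
// Hypothetical mirror of NearestByDirection; not Spark's type.
sealed trait Direction
case object Distance extends Direction   // smallest ranking values first
case object Similarity extends Direction // largest ranking values first

object NearestBySketch {
  // Mirrors NearestByJoin.MaxNumResults from the diff.
  val MaxNumResults: Int = 100000

  /** For each left row, returns up to `k` right rows ordered by `rank(l, r)`.
    * With `leftOuter = true`, a left row with no right rows is kept with an empty match list,
    * standing in for the null-padded row of a LeftOuter join; with `leftOuter = false` (Inner)
    * such a left row is dropped.
    */
  def nearestBy[L, R](
      left: Seq[L],
      right: Seq[R],
      k: Int,
      leftOuter: Boolean,
      direction: Direction)(rank: (L, R) => Double): Seq[(L, Seq[R])] = {
    // The lower bound k > 0 is an assumption of this sketch.
    require(k > 0 && k <= MaxNumResults, s"k must be in [1, $MaxNumResults], got $k")
    val ordering: Ordering[Double] = direction match {
      case Distance   => Ordering.Double.TotalOrdering
      case Similarity => Ordering.Double.TotalOrdering.reverse
    }
    left.flatMap { l =>
      // Brute force: score every right row against this left row, then keep the top k.
      val top = right
        .map(r => (rank(l, r), r))
        .sortBy(_._1)(ordering) // stable sort: ties keep their original right-side order
        .take(k)
        .map(_._2)
      if (top.isEmpty && !leftOuter) None else Some(l -> top)
    }
  }
}
```

Sorting the full candidate list per left row costs O(|right| log |right|); a bounded heap of size `k` is the usual refinement, but sorting keeps the sketch short.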
--
This is an automated message from the Apache Git Service.