Copilot commented on code in PR #611:
URL: https://github.com/apache/sedona-db/pull/611#discussion_r2815221611
##########
rust/sedona-spatial-join/src/planner/optimizer.rs:
##########
@@ -93,47 +184,57 @@ impl OptimizerRule for SpatialJoinLogicalRewrite {
return Ok(Transformed::no(plan));
}
Review Comment:
`SpatialJoinLogicalRewrite` is documented as “non-KNN” and says KNN joins
are skipped, but the current logic only bails out for non-KNN joins with
equi-join conditions. A KNN join without `on` conditions would still be
rewritten here (i.e., after `PushDownFilter`), which — if the early rewrite
doesn’t fire for any reason — reintroduces the risk this PR is trying to
avoid. Consider explicitly returning `Transformed::no` when
`spatial_predicate_names.contains("st_knn")` to match the intent and make the
rule safer.
##########
rust/sedona-spatial-join/tests/spatial_join_integration.rs:
##########
@@ -1368,3 +1388,93 @@ async fn test_knn_join_include_tie_breakers(
Ok(())
}
+
+/// Verify that a filter on the *object* (build / right) side of a KNN join is
NOT pushed down
+/// into the build side subtree.
+///
+/// If `PushDownFilter` incorrectly pushes `R.id > 5` below the spatial join,
the set of objects
+/// considered for KNN changes, yielding wrong nearest-neighbor results.
+#[tokio::test]
+async fn test_knn_join_object_side_filter_not_pushed_down() -> Result<()> {
+ let sql = "SELECT L.id, R.id \
+ FROM L JOIN R ON ST_KNN(ST_Point(L.x, 0), ST_Point(R.x, 1), 3,
false) \
+ WHERE R.id > 5";
+ let plan = plan_for_filter_pushdown_test(sql).await?;
+
+ let spatial_joins = collect_spatial_join_exec(&plan)?;
+ assert_eq!(
+ spatial_joins.len(),
+ 1,
+ "expected exactly one SpatialJoinExec"
+ );
+ let sj = spatial_joins[0];
+
+ // The build (right / object) side must NOT have a FilterExec pushed into
it.
+ assert!(
+ !subtree_contains_filter_exec(&sj.right),
+ "FilterExec should NOT be pushed into the object (right/build) side of
a KNN join"
+ );
+
+ Ok(())
+}
+
+/// Verify that for a *non-KNN* spatial join, a filter on the build side IS
pushed down
+/// (the normal, desirable behaviour).
+#[tokio::test]
+async fn test_non_knn_join_object_side_filter_is_pushed_down() -> Result<()> {
+ let sql = "SELECT L.id, R.id \
+ FROM L JOIN R ON ST_Intersects(ST_Buffer(ST_Point(L.x, 0),
1.5), ST_Point(R.x, 1)) \
+ WHERE R.id > 5";
+ let plan = plan_for_filter_pushdown_test(sql).await?;
+
+ let spatial_joins = collect_spatial_join_exec(&plan)?;
+ assert_eq!(
+ spatial_joins.len(),
+ 1,
+ "expected exactly one SpatialJoinExec"
+ );
+ let sj = spatial_joins[0];
+
+ // For non-KNN joins, the filter SHOULD be pushed down to the build side.
+ assert!(
+ subtree_contains_filter_exec(&sj.right),
+ "FilterExec should be pushed into the object (right/build) side of a
non-KNN spatial join"
+ );
+
+ Ok(())
+}
+
+/// Recursively check whether any node in the physical plan tree is a
`FilterExec`.
+fn subtree_contains_filter_exec(plan: &Arc<dyn ExecutionPlan>) -> bool {
+ let mut found = false;
+ plan.apply(|node| {
+ if node.as_any().downcast_ref::<FilterExec>().is_some() {
+ found = true;
+ return Ok(TreeNodeRecursion::Stop);
+ }
+ Ok(TreeNodeRecursion::Continue)
+ })
+ .expect("failed to walk plan");
+ found
+}
+
+/// Create a session context with two small tables for filter-pushdown tests.
+///
+/// L(id INT, x DOUBLE) and R(id INT, x DOUBLE) each with 10 rows.
+/// Geometry is constructed in SQL via ST_Point so no geometry column exists
on the table itself.
+async fn plan_for_filter_pushdown_test(sql: &str) -> Result<Arc<dyn
ExecutionPlan>> {
+ let schema = Arc::new(Schema::new(vec![
+ Field::new("id", DataType::Int32, false),
+ Field::new("x", DataType::Float64, false),
+ ]));
+
+ let options = SpatialJoinOptions::default();
+ let ctx = setup_context(Some(options), 100)?;
+ let empty_l: Arc<dyn TableProvider> =
Arc::new(EmptyTable::new(schema.clone()));
+ let empty_r: Arc<dyn TableProvider> =
Arc::new(EmptyTable::new(schema.clone()));
+ ctx.register_table("L", empty_l)?;
+ ctx.register_table("R", empty_r)?;
Review Comment:
`plan_for_filter_pushdown_test` claims each table has 10 rows, but it
registers `EmptyTable`s, which have 0 rows. With DataFusion’s default optimizer
rules, empty inputs can be propagated and the join may be optimized away,
making these pushdown assertions brittle (e.g., no `SpatialJoinExec` to
inspect). Use a `MemTable` with a small non-empty `RecordBatch` (or otherwise
prevent empty-relation propagation) so the physical plan consistently contains
the join subtree you’re checking.
##########
rust/sedona-spatial-join/src/planner/optimizer.rs:
##########
@@ -29,20 +28,112 @@ use datafusion_expr::{BinaryExpr, Expr, Operator};
use datafusion_expr::{Filter, Join, JoinType, LogicalPlan};
use sedona_common::option::SedonaOptions;
-/// Register only the logical spatial join optimizer rule.
+/// Register the logical spatial join optimizer rules.
///
-/// This enables building `Join(filter=...)` from patterns like
`Filter(CrossJoin)`.
-/// It intentionally does not register any physical plan rewrite rules.
+/// This inserts rules at specific positions relative to DataFusion's built-in
`PushDownFilter`
+/// rule to ensure correct semantics for KNN joins:
+///
+/// - `MergeSpatialFilterIntoJoin` and `KnnJoinEarlyRewrite` are inserted
*before*
+/// `PushDownFilter` so that KNN joins are converted to
`SpatialJoinPlanNode` extension nodes
+/// before filter pushdown runs. Extension nodes naturally block filter
pushdown via
+/// `prevent_predicate_push_down_columns()`, preventing incorrect pushdown
to the build side
+/// of KNN joins.
+///
+/// - `SpatialJoinLogicalRewrite` is appended at the end so that non-KNN
spatial joins still
+/// benefit from filter pushdown before being converted to extension nodes.
pub(crate) fn register_spatial_join_logical_optimizer(
- session_state_builder: SessionStateBuilder,
+ mut session_state_builder: SessionStateBuilder,
) -> SessionStateBuilder {
+ let optimizer = session_state_builder
+ .optimizer()
+ .get_or_insert_with(Optimizer::new);
+
+ // Find PushDownFilter position by name
+ let push_down_pos = optimizer
+ .rules
+ .iter()
+ .position(|r| r.name() == "push_down_filter")
+ .expect("PushDownFilter rule not found in default optimizer rules");
+
+ // Insert KNN-specific rules BEFORE PushDownFilter.
+ // MergeSpatialFilterIntoJoin must come first because it creates the
Join(filter=...)
+ // nodes that KnnJoinEarlyRewrite then converts to SpatialJoinPlanNode.
+ optimizer
+ .rules
+ .insert(push_down_pos, Arc::new(KnnJoinEarlyRewrite));
+ optimizer
+ .rules
+ .insert(push_down_pos, Arc::new(MergeSpatialFilterIntoJoin));
Review Comment:
This will panic at runtime if the optimizer rule list doesn’t contain a rule
named exactly `push_down_filter` (e.g., custom optimizer configuration,
upstream rename, or feature-gated rule sets). Since this is library
initialization code, it would be safer to handle the “not found” case
gracefully (append near the start/end, or fall back to `with_optimizer_rule`
ordering) rather than `expect`ing.
```suggestion
// Find PushDownFilter position by name (if present)
let push_down_pos = optimizer
.rules
.iter()
.position(|r| r.name() == "push_down_filter");
// Insert KNN-specific rules relative to PushDownFilter when available.
// MergeSpatialFilterIntoJoin must come first because it creates the
Join(filter=...)
// nodes that KnnJoinEarlyRewrite then converts to SpatialJoinPlanNode.
if let Some(pos) = push_down_pos {
// Insert BEFORE PushDownFilter: insert in reverse order so that
// MergeSpatialFilterIntoJoin ends up before KnnJoinEarlyRewrite.
optimizer
.rules
.insert(pos, Arc::new(KnnJoinEarlyRewrite));
optimizer
.rules
.insert(pos, Arc::new(MergeSpatialFilterIntoJoin));
} else {
// Fallback: PushDownFilter not found (e.g., custom optimizer
configuration).
// Append rules at the end, preserving their logical order.
optimizer
.rules
.push(Arc::new(MergeSpatialFilterIntoJoin));
optimizer
.rules
.push(Arc::new(KnnJoinEarlyRewrite));
}
```
##########
rust/sedona-spatial-join/tests/spatial_join_integration.rs:
##########
@@ -1368,3 +1388,93 @@ async fn test_knn_join_include_tie_breakers(
Ok(())
}
+
+/// Verify that a filter on the *object* (build / right) side of a KNN join is
NOT pushed down
+/// into the build side subtree.
+///
+/// If `PushDownFilter` incorrectly pushes `R.id > 5` below the spatial join,
the set of objects
+/// considered for KNN changes, yielding wrong nearest-neighbor results.
Review Comment:
Grammar in the doc comment: “considered for KNN changes” is confusing; it
parses as though “KNN changes” were a noun phrase rather than the verb
“changes” applying to “the set of objects”. Consider rephrasing to something
like “considered for the KNN search”, to make the semantics clearer.
```suggestion
/// considered for the KNN search changes, yielding wrong nearest-neighbor
results.
```
##########
rust/sedona-spatial-join/tests/spatial_join_integration.rs:
##########
@@ -1088,72 +1090,90 @@ async fn test_knn_join_with_filter_correctness(
};
let k = 3;
- let sql = format!(
- "SELECT L.id AS l_id, R.id AS r_id FROM L JOIN R ON ST_KNN(L.geometry,
R.geometry, {}, false) AND (L.id % 7) = (R.id % 7)",
- k
- );
+ let sqls = [
+ format!(
+ "SELECT L.id AS l_id, R.id AS r_id FROM L JOIN R ON
ST_KNN(L.geometry, R.geometry, {}, false) AND (L.id % 7) = (R.id % 7)",
+ k
+ ),
+ format!(
+ "SELECT L.id AS l_id, R.id AS r_id FROM L JOIN R ON
ST_KNN(L.geometry, R.geometry, {}, false) AND L.id % 7 = 0",
+ k
+ ),
+ format!(
+ "SELECT L.id AS l_id, R.id AS r_id FROM L JOIN R ON
ST_KNN(L.geometry, R.geometry, {}, false) AND R.id % 7 = 0",
+ k
+ ),
+ ];
- let batches = run_spatial_join_query(
- &left_schema,
- &right_schema,
- left_partitions.clone(),
- right_partitions.clone(),
- Some(options),
- max_batch_size,
- &sql,
- )
- .await?;
+ for (idx, sql) in sqls.iter().enumerate() {
+ let batches = run_spatial_join_query(
+ &left_schema,
+ &right_schema,
+ left_partitions.clone(),
+ right_partitions.clone(),
+ Some(options.clone()),
+ max_batch_size,
+ sql,
+ )
+ .await?;
- let mut actual_results = Vec::new();
- let combined_batch = arrow::compute::concat_batches(&batches.schema(),
&[batches])?;
- let l_ids = combined_batch
- .column(0)
- .as_any()
- .downcast_ref::<arrow_array::Int32Array>()
- .unwrap();
- let r_ids = combined_batch
- .column(1)
- .as_any()
- .downcast_ref::<arrow_array::Int32Array>()
- .unwrap();
+ let mut actual_results = Vec::new();
+ let combined_batch = arrow::compute::concat_batches(&batches.schema(),
&[batches])?;
+ let l_ids = combined_batch
+ .column(0)
+ .as_any()
+ .downcast_ref::<arrow_array::Int32Array>()
+ .unwrap();
+ let r_ids = combined_batch
+ .column(1)
+ .as_any()
+ .downcast_ref::<arrow_array::Int32Array>()
+ .unwrap();
- for i in 0..combined_batch.num_rows() {
- actual_results.push((l_ids.value(i), r_ids.value(i)));
- }
- actual_results.sort_by(|a, b| a.0.cmp(&b.0).then_with(|| a.1.cmp(&b.1)));
+ for i in 0..combined_batch.num_rows() {
+ actual_results.push((l_ids.value(i), r_ids.value(i)));
+ }
+ actual_results.sort_by(|a, b| a.0.cmp(&b.0).then_with(||
a.1.cmp(&b.1)));
- // Prove the test actually exercises the "< K rows after filtering" case.
- // Build a list of all probe-side IDs and count how many results each has.
- let all_left_ids: Vec<i32> = extract_geoms_and_ids(&left_partitions)
- .into_iter()
- .map(|(id, _)| id)
- .collect();
- let mut per_left_counts: std::collections::HashMap<i32, usize> =
- std::collections::HashMap::new();
- for (l_id, _) in &actual_results {
- *per_left_counts.entry(*l_id).or_default() += 1;
- }
- let min_count = all_left_ids
- .iter()
- .map(|l_id| *per_left_counts.get(l_id).unwrap_or(&0))
- .min()
- .unwrap_or(0);
- assert!(
- min_count < k,
- "expected at least one probe row to produce < K rows after filtering;
min_count={min_count}, k={k}"
- );
+ // Prove the test actually exercises the "< K rows after filtering"
case.
+ // Build a list of all probe-side IDs and count how many results each
has.
+ let all_left_ids: Vec<i32> = extract_geoms_and_ids(&left_partitions)
+ .into_iter()
+ .map(|(id, _)| id)
+ .collect();
+ let mut per_left_counts: std::collections::HashMap<i32, usize> =
+ std::collections::HashMap::new();
+ for (l_id, _) in &actual_results {
+ *per_left_counts.entry(*l_id).or_default() += 1;
+ }
+ let min_count = all_left_ids
+ .iter()
+ .map(|l_id| *per_left_counts.get(l_id).unwrap_or(&0))
+ .min()
+ .unwrap_or(0);
+ assert!(
+ min_count < k,
+ "expected at least one probe row to produce < K rows after
filtering; min_count={min_count}, k={k}"
+ );
- let expected_results = compute_knn_ground_truth_with_pair_filter(
- &left_partitions,
- &right_partitions,
- k,
- |l_id, r_id| (l_id.rem_euclid(7)) == (r_id.rem_euclid(7)),
- )
- .into_iter()
- .map(|(l, r, _)| (l, r))
- .collect::<Vec<_>>();
+ let filter_closure = match idx {
+ 0 => |l_id: i32, r_id: i32| (l_id.rem_euclid(7)) ==
(r_id.rem_euclid(7)),
+ 1 => |l_id: i32, _r_id: i32| l_id.rem_euclid(7) == 0,
+ 2 => |_l_id: i32, r_id: i32| r_id.rem_euclid(7) == 0,
+ _ => unreachable!(),
+ };
Review Comment:
`filter_closure` is assigned from a `match` returning three different
closure types. Because `compute_knn_ground_truth_with_pair_filter` is generic
over `Fn`, this won’t coerce automatically and should fail to compile with a
“match arms have incompatible types” error. Consider using an explicit
function-pointer type (e.g., `let filter_closure: fn(i32, i32) -> bool = ...`),
boxing to `Box<dyn Fn(i32,i32)->bool>`, or computing `expected_results` inside
the `match` instead of returning closures.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]