paleolimbot commented on code in PR #562:
URL: https://github.com/apache/sedona-db/pull/562#discussion_r2795221835


##########
rust/sedona-spatial-join/src/planner/logical_plan_node.rs:
##########
@@ -0,0 +1,129 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::cmp::Ordering;
+use std::fmt;
+use std::sync::Arc;
+
+use datafusion_common::{plan_err, DFSchemaRef, NullEquality, Result};
+use datafusion_expr::logical_plan::UserDefinedLogicalNodeCore;
+use datafusion_expr::{Expr, JoinConstraint, JoinType, LogicalPlan};
+
+/// Logical extension node used as a planning hook for spatial joins.
+///
+/// Carries a join's inputs and filter expression so the physical planner can 
recognize and plan
+/// a `SpatialJoinExec`.
+#[derive(PartialEq, Eq, Hash)]
+pub(crate) struct SpatialJoinPlanNode {

Review Comment:
   Not in this PR, but we could also make this `pub` and construct it directly 
from the R and Python DataFrame APIs (e.g., `df.sjoin(...)`).



##########
rust/sedona-spatial-join/src/planner/logical_plan_node.rs:
##########
@@ -0,0 +1,129 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::cmp::Ordering;
+use std::fmt;
+use std::sync::Arc;
+
+use datafusion_common::{plan_err, DFSchemaRef, NullEquality, Result};
+use datafusion_expr::logical_plan::UserDefinedLogicalNodeCore;
+use datafusion_expr::{Expr, JoinConstraint, JoinType, LogicalPlan};
+
+/// Logical extension node used as a planning hook for spatial joins.
+///
+/// Carries a join's inputs and filter expression so the physical planner can 
recognize and plan
+/// a `SpatialJoinExec`.
+#[derive(PartialEq, Eq, Hash)]
+pub(crate) struct SpatialJoinPlanNode {
+    pub left: LogicalPlan,
+    pub right: LogicalPlan,
+    pub join_type: JoinType,
+    pub filter: Expr,
+    pub schema: DFSchemaRef,
+    pub join_constraint: JoinConstraint,
+    pub null_equality: NullEquality,
+}
+
+// Manual implementation needed because of `schema` field. Comparison excludes 
this field.
+// See 
https://github.com/apache/datafusion/blob/52.1.0/datafusion/expr/src/logical_plan/plan.rs#L3886
+impl PartialOrd for SpatialJoinPlanNode {
+    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
+        #[derive(PartialEq, PartialOrd)]
+        struct ComparableJoin<'a> {
+            pub left: &'a LogicalPlan,
+            pub right: &'a LogicalPlan,
+            pub filter: &'a Expr,
+            pub join_type: &'a JoinType,
+            pub join_constraint: &'a JoinConstraint,
+            pub null_equality: &'a NullEquality,
+        }
+        let comparable_self = ComparableJoin {
+            left: &self.left,
+            right: &self.right,
+            filter: &self.filter,
+            join_type: &self.join_type,
+            join_constraint: &self.join_constraint,
+            null_equality: &self.null_equality,
+        };
+        let comparable_other = ComparableJoin {
+            left: &other.left,
+            right: &other.right,
+            filter: &other.filter,
+            join_type: &other.join_type,
+            join_constraint: &other.join_constraint,
+            null_equality: &other.null_equality,
+        };
+        comparable_self
+            .partial_cmp(&comparable_other)
+            // TODO (https://github.com/apache/datafusion/issues/17477) avoid 
recomparing all fields
+            .filter(|cmp| *cmp != Ordering::Equal || self == other)
+    }
+}
+
+impl fmt::Debug for SpatialJoinPlanNode {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        UserDefinedLogicalNodeCore::fmt_for_explain(self, f)
+    }
+}
+
+impl UserDefinedLogicalNodeCore for SpatialJoinPlanNode {
+    fn name(&self) -> &str {
+        "SpatialJoin"
+    }
+
+    fn inputs(&self) -> Vec<&LogicalPlan> {
+        vec![&self.left, &self.right]
+    }
+
+    fn schema(&self) -> &DFSchemaRef {
+        &self.schema
+    }
+
+    fn expressions(&self) -> Vec<Expr> {
+        vec![self.filter.clone()]
+    }
+
+    fn fmt_for_explain(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        write!(
+            f,
+            "SpatialJoin: join_type={:?}, filter={}",
+            self.join_type, self.filter
+        )
+    }
+
+    fn with_exprs_and_inputs(
+        &self,
+        mut exprs: Vec<Expr>,
+        mut inputs: Vec<LogicalPlan>,
+    ) -> Result<Self> {
+        if exprs.len() != 1 {
+            return plan_err!("SpatialJoinPlanNode expects 1 expr");

Review Comment:
   ```suggestion
               return sedona_internal_err!("SpatialJoinPlanNode expects 1 
expr");
   ```



##########
rust/sedona-spatial-join/src/planner.rs:
##########
@@ -0,0 +1,40 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! DataFusion planner integration for Sedona spatial joins.
+//!
+//! This module wires Sedona's logical optimizer rules and physical planning 
extensions that
+//! can produce `SpatialJoinExec`.
+
+use datafusion::execution::SessionStateBuilder;
+
+mod logical_plan_node;
+mod optimizer;
+mod physical_planner;
+mod spatial_expr_utils;
+
+/// Register Sedona spatial join planning hooks.
+///
+/// Enables logical rewrites (to surface join filters) and a query planner 
extension that can
+/// plan `SpatialJoinExec`.

Review Comment:
   ```suggestion
   /// Register Sedona spatial join planning hooks.
   ///
   /// Enables logical rewrites (to surface join filters) and a query planner 
extension that can
   /// plan `SpatialJoinExec`. This is the primary entry point to leveraging 
the spatial join
   /// implementation provided by this crate and ensures joins created by SQL 
or using
   /// a DataFrame API that meet certain conditions (e.g. contain a spatial 
predicate as
   /// a join condition) are executed using the `SpatialJoinExec`.
   ```
   
   Just giving this a bit of fanfare since it's the entry point to pretty much 
everything else.



##########
rust/sedona-spatial-join/src/planner/physical_planner.rs:
##########
@@ -0,0 +1,263 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::collections::HashMap;
+use std::fmt;
+use std::sync::Arc;
+
+use async_trait::async_trait;
+
+use arrow_schema::Schema;
+
+use datafusion::execution::context::QueryPlanner;
+use datafusion::execution::session_state::{SessionState, SessionStateBuilder};
+use datafusion::physical_plan::ExecutionPlan;
+use datafusion::physical_planner::{DefaultPhysicalPlanner, ExtensionPlanner, 
PhysicalPlanner};
+use datafusion_common::{plan_err, DFSchema, Result};
+use datafusion_expr::logical_plan::UserDefinedLogicalNode;
+use datafusion_expr::LogicalPlan;
+use datafusion_physical_expr::create_physical_expr;
+use datafusion_physical_plan::joins::utils::JoinFilter;
+use datafusion_physical_plan::joins::NestedLoopJoinExec;
+use sedona_common::sedona_internal_err;
+
+use crate::exec::SpatialJoinExec;
+use crate::planner::logical_plan_node::SpatialJoinPlanNode;
+use crate::planner::spatial_expr_utils::{is_spatial_predicate_supported, 
transform_join_filter};
+use crate::spatial_predicate::SpatialPredicate;
+use sedona_common::option::SedonaOptions;
+
+/// Registers a query planner that can produce [`SpatialJoinExec`] from a 
logical extension node.
+pub fn register_spatial_join_planner(builder: SessionStateBuilder) -> 
SessionStateBuilder {
+    builder.with_query_planner(Arc::new(SedonaSpatialQueryPlanner))
+}
+
+/// Query planner that enables Sedona's spatial join planning.
+///
+/// Installs an [`ExtensionPlanner`] that recognizes `SpatialJoinPlanNode` and 
produces
+/// `SpatialJoinExec` when supported and enabled.
+pub struct SedonaSpatialQueryPlanner;

Review Comment:
   Because we only get one physical planner, this should probably always be the 
planner when constructing a `SedonaContext` (i.e., this struct can live in 
`sedona` and always be the planner). I'm not sure there's a high chance of many 
more extension logical nodes, but that is also where we do things like interpret 
SQL and potentially pre-process the logical plan (this would be the opportunity 
to post-process the logical plan), and that should probably be in the same place.
   
   ```suggestion
   /// Query planner that enables Sedona's spatial join planning.
   ///
   /// Installs an [`ExtensionPlanner`] that recognizes `SpatialJoinPlanNode` 
and produces
   /// `SpatialJoinExec` when supported and enabled.
   pub struct SedonaSpatialQueryPlanner { extension_planner: 
Vec<Arc<dyn ExtensionPlanner + Send + Sync>> }
   ```



##########
rust/sedona-spatial-join/src/planner/physical_planner.rs:
##########
@@ -0,0 +1,263 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::collections::HashMap;
+use std::fmt;
+use std::sync::Arc;
+
+use async_trait::async_trait;
+
+use arrow_schema::Schema;
+
+use datafusion::execution::context::QueryPlanner;
+use datafusion::execution::session_state::{SessionState, SessionStateBuilder};
+use datafusion::physical_plan::ExecutionPlan;
+use datafusion::physical_planner::{DefaultPhysicalPlanner, ExtensionPlanner, 
PhysicalPlanner};
+use datafusion_common::{plan_err, DFSchema, Result};
+use datafusion_expr::logical_plan::UserDefinedLogicalNode;
+use datafusion_expr::LogicalPlan;
+use datafusion_physical_expr::create_physical_expr;
+use datafusion_physical_plan::joins::utils::JoinFilter;
+use datafusion_physical_plan::joins::NestedLoopJoinExec;
+use sedona_common::sedona_internal_err;
+
+use crate::exec::SpatialJoinExec;
+use crate::planner::logical_plan_node::SpatialJoinPlanNode;
+use crate::planner::spatial_expr_utils::{is_spatial_predicate_supported, 
transform_join_filter};
+use crate::spatial_predicate::SpatialPredicate;
+use sedona_common::option::SedonaOptions;
+
+/// Registers a query planner that can produce [`SpatialJoinExec`] from a 
logical extension node.
+pub fn register_spatial_join_planner(builder: SessionStateBuilder) -> 
SessionStateBuilder {

Review Comment:
   ```suggestion
   /// Registers a query planner that can produce [`SpatialJoinExec`] from a 
logical extension node.
   pub(crate) fn register_spatial_join_planner(builder: SessionStateBuilder) -> 
SessionStateBuilder {
   ```
   
   Should this be slightly less public or link to the function people should 
more likely be using?



##########
rust/sedona-spatial-join/src/planner/optimizer.rs:
##########
@@ -0,0 +1,231 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+use std::sync::Arc;
+
+use crate::planner::logical_plan_node::SpatialJoinPlanNode;
+use crate::planner::spatial_expr_utils::collect_spatial_predicate_names;
+use crate::planner::spatial_expr_utils::is_spatial_predicate;
+use datafusion::execution::session_state::SessionStateBuilder;
+use datafusion::optimizer::{ApplyOrder, OptimizerConfig, OptimizerRule};
+use datafusion_common::tree_node::Transformed;
+use datafusion_common::NullEquality;
+use datafusion_common::Result;
+use datafusion_expr::logical_plan::Extension;
+use datafusion_expr::{BinaryExpr, Expr, Operator};
+use datafusion_expr::{Filter, Join, JoinType, LogicalPlan};
+use sedona_common::option::SedonaOptions;
+
+/// Register only the logical spatial join optimizer rule.
+///
+/// This enables building `Join(filter=...)` from patterns like 
`Filter(CrossJoin)`.
+/// It intentionally does not register any physical plan rewrite rules.
+pub fn register_spatial_join_logical_optimizer(

Review Comment:
   Should these be slightly less public or link to the public entrypoint that 
should actually be used?
   
   ```suggestion
   pub(crate) fn register_spatial_join_logical_optimizer(
   ```



##########
rust/sedona-spatial-join/src/planner/logical_plan_node.rs:
##########
@@ -0,0 +1,129 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::cmp::Ordering;
+use std::fmt;
+use std::sync::Arc;
+
+use datafusion_common::{plan_err, DFSchemaRef, NullEquality, Result};
+use datafusion_expr::logical_plan::UserDefinedLogicalNodeCore;
+use datafusion_expr::{Expr, JoinConstraint, JoinType, LogicalPlan};
+
+/// Logical extension node used as a planning hook for spatial joins.
+///
+/// Carries a join's inputs and filter expression so the physical planner can 
recognize and plan
+/// a `SpatialJoinExec`.
+#[derive(PartialEq, Eq, Hash)]
+pub(crate) struct SpatialJoinPlanNode {
+    pub left: LogicalPlan,
+    pub right: LogicalPlan,
+    pub join_type: JoinType,
+    pub filter: Expr,
+    pub schema: DFSchemaRef,
+    pub join_constraint: JoinConstraint,
+    pub null_equality: NullEquality,
+}
+
+// Manual implementation needed because of `schema` field. Comparison excludes 
this field.
+// See 
https://github.com/apache/datafusion/blob/52.1.0/datafusion/expr/src/logical_plan/plan.rs#L3886
+impl PartialOrd for SpatialJoinPlanNode {
+    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
+        #[derive(PartialEq, PartialOrd)]
+        struct ComparableJoin<'a> {
+            pub left: &'a LogicalPlan,
+            pub right: &'a LogicalPlan,
+            pub filter: &'a Expr,
+            pub join_type: &'a JoinType,
+            pub join_constraint: &'a JoinConstraint,
+            pub null_equality: &'a NullEquality,
+        }
+        let comparable_self = ComparableJoin {
+            left: &self.left,
+            right: &self.right,
+            filter: &self.filter,
+            join_type: &self.join_type,
+            join_constraint: &self.join_constraint,
+            null_equality: &self.null_equality,
+        };
+        let comparable_other = ComparableJoin {
+            left: &other.left,
+            right: &other.right,
+            filter: &other.filter,
+            join_type: &other.join_type,
+            join_constraint: &other.join_constraint,
+            null_equality: &other.null_equality,
+        };
+        comparable_self
+            .partial_cmp(&comparable_other)
+            // TODO (https://github.com/apache/datafusion/issues/17477) avoid 
recomparing all fields
+            .filter(|cmp| *cmp != Ordering::Equal || self == other)
+    }
+}
+
+impl fmt::Debug for SpatialJoinPlanNode {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        UserDefinedLogicalNodeCore::fmt_for_explain(self, f)
+    }
+}
+
+impl UserDefinedLogicalNodeCore for SpatialJoinPlanNode {
+    fn name(&self) -> &str {
+        "SpatialJoin"
+    }
+
+    fn inputs(&self) -> Vec<&LogicalPlan> {
+        vec![&self.left, &self.right]
+    }
+
+    fn schema(&self) -> &DFSchemaRef {
+        &self.schema
+    }
+
+    fn expressions(&self) -> Vec<Expr> {
+        vec![self.filter.clone()]
+    }
+
+    fn fmt_for_explain(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        write!(
+            f,
+            "SpatialJoin: join_type={:?}, filter={}",
+            self.join_type, self.filter
+        )
+    }
+
+    fn with_exprs_and_inputs(
+        &self,
+        mut exprs: Vec<Expr>,
+        mut inputs: Vec<LogicalPlan>,
+    ) -> Result<Self> {
+        if exprs.len() != 1 {
+            return plan_err!("SpatialJoinPlanNode expects 1 expr");
+        }
+        if inputs.len() != 2 {
+            return plan_err!("SpatialJoinPlanNode expects 2 inputs");

Review Comment:
   ```suggestion
               return sedona_internal_err!("SpatialJoinPlanNode expects 2 
inputs");
   ```
   
   (Unless a user can construct this incorrectly?)



##########
rust/sedona-spatial-join/src/planner/optimizer.rs:
##########
@@ -0,0 +1,231 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+use std::sync::Arc;
+
+use crate::planner::logical_plan_node::SpatialJoinPlanNode;
+use crate::planner::spatial_expr_utils::collect_spatial_predicate_names;
+use crate::planner::spatial_expr_utils::is_spatial_predicate;
+use datafusion::execution::session_state::SessionStateBuilder;
+use datafusion::optimizer::{ApplyOrder, OptimizerConfig, OptimizerRule};
+use datafusion_common::tree_node::Transformed;
+use datafusion_common::NullEquality;
+use datafusion_common::Result;
+use datafusion_expr::logical_plan::Extension;
+use datafusion_expr::{BinaryExpr, Expr, Operator};
+use datafusion_expr::{Filter, Join, JoinType, LogicalPlan};
+use sedona_common::option::SedonaOptions;
+
+/// Register only the logical spatial join optimizer rule.
+///
+/// This enables building `Join(filter=...)` from patterns like 
`Filter(CrossJoin)`.
+/// It intentionally does not register any physical plan rewrite rules.
+pub fn register_spatial_join_logical_optimizer(
+    session_state_builder: SessionStateBuilder,
+) -> SessionStateBuilder {
+    session_state_builder
+        .with_optimizer_rule(Arc::new(MergeSpatialProjectionIntoJoin))
+        .with_optimizer_rule(Arc::new(SpatialJoinLogicalRewrite))

Review Comment:
   Do we have a way to make these run before the `push_down_filter` rule so 
that filters aren't pushed through a KNN join?



##########
rust/sedona-spatial-join/src/planner/optimizer.rs:
##########
@@ -0,0 +1,231 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+use std::sync::Arc;
+
+use crate::planner::logical_plan_node::SpatialJoinPlanNode;
+use crate::planner::spatial_expr_utils::collect_spatial_predicate_names;
+use crate::planner::spatial_expr_utils::is_spatial_predicate;
+use datafusion::execution::session_state::SessionStateBuilder;
+use datafusion::optimizer::{ApplyOrder, OptimizerConfig, OptimizerRule};
+use datafusion_common::tree_node::Transformed;
+use datafusion_common::NullEquality;
+use datafusion_common::Result;
+use datafusion_expr::logical_plan::Extension;
+use datafusion_expr::{BinaryExpr, Expr, Operator};
+use datafusion_expr::{Filter, Join, JoinType, LogicalPlan};
+use sedona_common::option::SedonaOptions;
+
+/// Register only the logical spatial join optimizer rule.
+///
+/// This enables building `Join(filter=...)` from patterns like 
`Filter(CrossJoin)`.
+/// It intentionally does not register any physical plan rewrite rules.
+pub fn register_spatial_join_logical_optimizer(
+    session_state_builder: SessionStateBuilder,
+) -> SessionStateBuilder {
+    session_state_builder
+        .with_optimizer_rule(Arc::new(MergeSpatialProjectionIntoJoin))
+        .with_optimizer_rule(Arc::new(SpatialJoinLogicalRewrite))
+}
+/// Logical optimizer rule that enables spatial join planning.
+///
+/// This rule turns eligible `Join(filter=...)` nodes into a 
`SpatialJoinPlanNode` extension.
+#[derive(Default, Debug)]
+struct SpatialJoinLogicalRewrite;
+
+impl OptimizerRule for SpatialJoinLogicalRewrite {
+    fn name(&self) -> &str {
+        "spatial_join_logical_rewrite"
+    }
+
+    fn apply_order(&self) -> Option<ApplyOrder> {
+        Some(ApplyOrder::BottomUp)
+    }
+
+    fn supports_rewrite(&self) -> bool {
+        true
+    }
+
+    fn rewrite(
+        &self,
+        plan: LogicalPlan,
+        config: &dyn OptimizerConfig,
+    ) -> Result<Transformed<LogicalPlan>> {
+        let options = config.options();
+        let Some(ext) = options.extensions.get::<SedonaOptions>() else {
+            return Ok(Transformed::no(plan));
+        };
+        if !ext.spatial_join.enable {
+            return Ok(Transformed::no(plan));
+        }
+
+        let LogicalPlan::Join(join) = &plan else {
+            return Ok(Transformed::no(plan));
+        };
+
+        // v1: only rewrite joins that already have a spatial predicate in 
`filter`.

Review Comment:
   ```suggestion
           // only rewrite joins that already have a spatial predicate in 
`filter`.
   ```



##########
rust/sedona-spatial-join/src/planner/spatial_expr_utils.rs:
##########
@@ -14,566 +14,95 @@
 // KIND, either express or implied.  See the License for the
 // specific language governing permissions and limitations
 // under the License.
+
+use std::collections::HashSet;
 use std::sync::Arc;
 
-use crate::exec::SpatialJoinExec;
 use crate::spatial_predicate::{
     DistancePredicate, KNNPredicate, RelationPredicate, SpatialPredicate, 
SpatialRelationType,
 };
-use arrow_schema::{Schema, SchemaRef};
-use datafusion::optimizer::{ApplyOrder, OptimizerConfig, OptimizerRule};
-use datafusion::physical_optimizer::sanity_checker::SanityCheckPlan;
-use datafusion::{
-    config::ConfigOptions, execution::session_state::SessionStateBuilder,
-    physical_optimizer::PhysicalOptimizerRule,
-};
+use arrow_schema::Schema;
 use datafusion_common::ScalarValue;
 use datafusion_common::{
     tree_node::{Transformed, TreeNode},
     JoinSide,
 };
 use datafusion_common::{HashMap, Result};
-use datafusion_expr::{Expr, Filter, Join, JoinType, LogicalPlan, Operator};
+use datafusion_expr::{Expr, Operator};
 use datafusion_physical_expr::expressions::{BinaryExpr, Column, Literal};
 use datafusion_physical_expr::{PhysicalExpr, ScalarFunctionExpr};
-use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec;
 use datafusion_physical_plan::joins::utils::ColumnIndex;
-use datafusion_physical_plan::joins::{HashJoinExec, NestedLoopJoinExec};
-use datafusion_physical_plan::projection::ProjectionExec;
-use datafusion_physical_plan::{joins::utils::JoinFilter, ExecutionPlan};
-use sedona_common::{option::SedonaOptions, sedona_internal_err};
+use datafusion_physical_plan::joins::utils::JoinFilter;
+use sedona_common::sedona_internal_err;
 use sedona_expr::utils::{parse_distance_predicate, ParsedDistancePredicate};
 use sedona_schema::datatypes::SedonaType;
 use sedona_schema::matchers::ArgMatcher;
 
-/// Physical planner extension for spatial joins
-///
-/// This extension recognizes nested loop join operations with spatial 
predicates
-/// and converts them to SpatialJoinExec, which is specially optimized for 
spatial joins.
-#[derive(Debug, Default)]
-pub struct SpatialJoinOptimizer;
-
-impl SpatialJoinOptimizer {
-    pub fn new() -> Self {
-        Self
-    }
-}
-
-impl PhysicalOptimizerRule for SpatialJoinOptimizer {
-    fn optimize(
-        &self,
-        plan: Arc<dyn ExecutionPlan>,
-        config: &ConfigOptions,
-    ) -> Result<Arc<dyn ExecutionPlan>> {
-        let Some(extension) = config.extensions.get::<SedonaOptions>() else {
-            return Ok(plan);
-        };
-
-        if extension.spatial_join.enable {
-            let transformed = plan.transform_up(|plan| 
self.try_optimize_join(plan, config))?;
-            Ok(transformed.data)
-        } else {
-            Ok(plan)
-        }
-    }
-
-    /// A human readable name for this optimizer rule
-    fn name(&self) -> &str {
-        "spatial_join_optimizer"
-    }
-
-    /// A flag to indicate whether the physical planner should valid the rule 
will not
-    /// change the schema of the plan after the rewriting.
-    /// Some of the optimization rules might change the nullable properties of 
the schema
-    /// and should disable the schema check.
-    fn schema_check(&self) -> bool {
-        true
-    }
-}
-
-impl OptimizerRule for SpatialJoinOptimizer {
-    fn name(&self) -> &str {
-        "spatial_join_optimizer"
-    }
-
-    fn apply_order(&self) -> Option<ApplyOrder> {
-        Some(ApplyOrder::BottomUp)
-    }
-
-    /// Try to rewrite the plan containing a spatial Filter on top of a cross 
join without on or filter
-    /// to a theta-join with filter. For instance, the following query plan:
-    ///
-    /// ```text
-    /// Filter: st_intersects(l.geom, _scalar_sq_1.geom)
-    ///   Left Join (no on, no filter):
-    ///     TableScan: l projection=[id, geom]
-    ///     SubqueryAlias: __scalar_sq_1
-    ///       Projection: r.geom
-    ///         Filter: r.id = Int32(1)
-    ///           TableScan: r projection=[id, geom]
-    /// ```
-    ///
-    /// will be rewritten to
-    ///
-    /// ```text
-    /// Inner Join: Filter: st_intersects(l.geom, _scalar_sq_1.geom)
-    ///   TableScan: l projection=[id, geom]
-    ///   SubqueryAlias: __scalar_sq_1
-    ///     Projection: r.geom
-    ///       Filter: r.id = Int32(1)
-    ///         TableScan: r projection=[id, geom]
-    /// ```
-    ///
-    /// This is for enabling this logical join operator to be converted to a 
NestedLoopJoin physical
-    /// node with a spatial predicate, so that it could subsequently be 
optimized to a SpatialJoin
-    /// physical node. Please refer to the `PhysicalOptimizerRule` 
implementation of this struct
-    /// and [SpatialJoinOptimizer::try_optimize_join] for details.
-    fn rewrite(
-        &self,
-        plan: LogicalPlan,
-        config: &dyn OptimizerConfig,
-    ) -> Result<Transformed<LogicalPlan>> {
-        let options = config.options();
-        let Some(extension) = options.extensions.get::<SedonaOptions>() else {
-            return Ok(Transformed::no(plan));
-        };
-        if !extension.spatial_join.enable {
-            return Ok(Transformed::no(plan));
-        }
-
-        let LogicalPlan::Filter(Filter {
-            predicate, input, ..
-        }) = &plan
-        else {
-            return Ok(Transformed::no(plan));
-        };
-        if !is_spatial_predicate(predicate) {
-            return Ok(Transformed::no(plan));
-        }
-
-        let LogicalPlan::Join(Join {
-            ref left,
-            ref right,
-            ref on,
-            ref filter,
-            join_type,
-            ref join_constraint,
-            ref null_equality,
-            ..
-        }) = input.as_ref()
-        else {
-            return Ok(Transformed::no(plan));
-        };
-
-        // Check if this is a suitable join for rewriting
-        if !matches!(
-            join_type,
-            JoinType::Inner | JoinType::Left | JoinType::Right
-        ) || !on.is_empty()
-            || filter.is_some()
-        {
-            return Ok(Transformed::no(plan));
-        }
-
-        let rewritten_plan = Join::try_new(
-            Arc::clone(left),
-            Arc::clone(right),
-            on.clone(),
-            Some(predicate.clone()),
-            JoinType::Inner,
-            *join_constraint,
-            *null_equality,
-        )?;
-
-        Ok(Transformed::yes(LogicalPlan::Join(rewritten_plan)))
-    }
-}
-
-/// Check if a given logical expression contains a spatial predicate component 
or not. We assume that the given
+/// Collect the names of spatial predicates appeared in expr. We assume that 
the given
 /// `expr` evaluates to a boolean value and originates from a filter logical 
node.
-fn is_spatial_predicate(expr: &Expr) -> bool {
-    fn is_distance_expr(expr: &Expr) -> bool {
-        let Expr::ScalarFunction(datafusion_expr::expr::ScalarFunction { func, 
.. }) = expr else {
-            return false;
-        };
-        func.name().to_lowercase() == "st_distance"
-    }
-
-    match expr {
-        Expr::BinaryExpr(datafusion_expr::expr::BinaryExpr {
-            left, right, op, ..
-        }) => match op {
-            Operator::And => is_spatial_predicate(left) || 
is_spatial_predicate(right),
-            Operator::Lt | Operator::LtEq => is_distance_expr(left),
-            Operator::Gt | Operator::GtEq => is_distance_expr(right),
-            _ => false,
-        },
-        Expr::ScalarFunction(datafusion_expr::expr::ScalarFunction { func, .. 
}) => {
-            let func_name = func.name().to_lowercase();
-            matches!(
-                func_name.as_str(),
-                "st_intersects"
-                    | "st_contains"
-                    | "st_within"
-                    | "st_covers"
-                    | "st_covered_by"
-                    | "st_coveredby"
-                    | "st_touches"
-                    | "st_crosses"
-                    | "st_overlaps"
-                    | "st_equals"
-                    | "st_dwithin"
-                    | "st_knn"
-            )
-        }
-        _ => false,
-    }
-}
-
-impl SpatialJoinOptimizer {
-    /// Rewrite `plan` containing NestedLoopJoinExec or HashJoinExec with 
spatial predicates to SpatialJoinExec.
-    fn try_optimize_join(
-        &self,
-        plan: Arc<dyn ExecutionPlan>,
-        config: &ConfigOptions,
-    ) -> Result<Transformed<Arc<dyn ExecutionPlan>>> {
-        // Check if this is a NestedLoopJoinExec that we can convert to 
spatial join
-        if let Some(nested_loop_join) = 
plan.as_any().downcast_ref::<NestedLoopJoinExec>() {
-            if let Some(spatial_join) =
-                self.try_convert_to_spatial_join(nested_loop_join, config)?
-            {
-                return Ok(Transformed::yes(spatial_join));
-            }
-        }
-
-        // Check if this is a HashJoinExec with spatial filter that we can 
convert to spatial join
-        if let Some(hash_join) = plan.as_any().downcast_ref::<HashJoinExec>() {
-            if let Some(spatial_join) = 
self.try_convert_hash_join_to_spatial(hash_join, config)? {
-                return Ok(Transformed::yes(spatial_join));
-            }
-        }
-
-        // No optimization applied, return the original plan
-        Ok(Transformed::no(plan))
-    }
-
-    /// Try to convert a NestedLoopJoinExec with spatial predicates as join 
condition to a SpatialJoinExec.
-    /// SpatialJoinExec executes the query using an optimized algorithm, which 
is more efficient than
-    /// NestedLoopJoinExec.
-    fn try_convert_to_spatial_join(
-        &self,
-        nested_loop_join: &NestedLoopJoinExec,
-        config: &ConfigOptions,
-    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
-        let Some(options) = config.extensions.get::<SedonaOptions>() else {
-            return Ok(None);
-        };
-
-        if let Some(join_filter) = nested_loop_join.filter() {
-            if let Some((spatial_predicate, remainder)) = 
transform_join_filter(join_filter) {
-                // The left side of the nested loop join is required to have 
only one partition, while SpatialJoinExec
-                // does not have that requirement. SpatialJoinExec can consume 
the streams on the build side in parallel
-                // when the build side has multiple partitions.
-                // If the left side is a CoalescePartitionsExec, we can drop 
the CoalescePartitionsExec and directly use
-                // the input.
-                let left = nested_loop_join.left();
-                let left = if let Some(coalesce_partitions) =
-                    left.as_any().downcast_ref::<CoalescePartitionsExec>()
-                {
-                    // Remove unnecessary CoalescePartitionsExec for spatial 
joins
-                    coalesce_partitions.input()
-                } else {
-                    left
-                };
-
-                let left = left.clone();
-                let right = nested_loop_join.right().clone();
-                let join_type = nested_loop_join.join_type();
-
-                // Check if the geospatial types involved in spatial_predicate 
are supported
-                if !is_spatial_predicate_supported(
-                    &spatial_predicate,
-                    &left.schema(),
-                    &right.schema(),
-                )? {
-                    return Ok(None);
+pub(crate) fn collect_spatial_predicate_names(expr: &Expr) -> HashSet<String> {
+    fn collect(expr: &Expr, acc: &mut HashSet<String>) {
+        match expr {
+            Expr::BinaryExpr(datafusion_expr::expr::BinaryExpr {
+                left, right, op, ..
+            }) => match op {
+                Operator::And => {
+                    collect(left, acc);
+                    collect(right, acc);
                 }
-
-                // Create the spatial join
-                let spatial_join = SpatialJoinExec::try_new(
-                    left,
-                    right,
-                    spatial_predicate,
-                    remainder,
-                    join_type,
-                    nested_loop_join.projection().cloned(),
-                    &options.spatial_join,
-                )?;
-
-                return Ok(Some(Arc::new(spatial_join)));
-            }
-        }
-
-        Ok(None)
-    }
-
-    /// Try to convert a HashJoinExec with spatial predicates in the filter to 
a SpatialJoinExec.
-    /// This handles cases where there's an equi-join condition (like c.id = 
r.id) along with
-    /// the ST_KNN predicate. We flip them so the spatial predicate drives the 
join
-    /// and the equi-conditions become filters.
-    fn try_convert_hash_join_to_spatial(
-        &self,
-        hash_join: &HashJoinExec,
-        config: &ConfigOptions,
-    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
-        let Some(options) = config.extensions.get::<SedonaOptions>() else {
-            return Ok(None);
-        };
-
-        // Check if the filter contains spatial predicates
-        if let Some(join_filter) = hash_join.filter() {
-            if let Some((spatial_predicate, mut remainder)) = 
transform_join_filter(join_filter) {
-                // The transform_join_filter now prioritizes ST_KNN predicates
-                // Only proceed if we found an ST_KNN (other spatial 
predicates are left in hash join)
-                if !matches!(spatial_predicate, 
SpatialPredicate::KNearestNeighbors(_)) {
-                    return Ok(None);
+                Operator::Lt | Operator::LtEq => {
+                    if is_distance_expr(left) {
+                        acc.insert("st_dwithin".to_string());
+                    }
                 }
-
-                // Check if the geospatial types involved in spatial_predicate 
are supported (planar geometries only)
-                if !is_spatial_predicate_supported(
-                    &spatial_predicate,
-                    &hash_join.left().schema(),
-                    &hash_join.right().schema(),
-                )? {
-                    return Ok(None);
+                Operator::Gt | Operator::GtEq => {
+                    if is_distance_expr(right) {
+                        acc.insert("st_dwithin".to_string());
+                    }
+                }
+                _ => (),
+            },
+            Expr::ScalarFunction(datafusion_expr::expr::ScalarFunction { func, 
.. }) => {
+                let func_name = func.name().to_lowercase();
+                if matches!(
+                    func_name.as_str(),
+                    "st_intersects"
+                        | "st_contains"
+                        | "st_within"
+                        | "st_covers"
+                        | "st_covered_by"
+                        | "st_coveredby"
+                        | "st_touches"
+                        | "st_crosses"
+                        | "st_overlaps"
+                        | "st_equals"
+                        | "st_dwithin"
+                        | "st_knn"
+                ) {
+                    acc.insert(func_name);
                 }
-
-                // Extract the equi-join conditions and convert them to a 
filter
-                let equi_filter = 
self.create_equi_filter_from_hash_join(hash_join)?;
-
-                // Combine the equi-filter with any existing remainder
-                remainder = self.combine_filters(remainder, equi_filter)?;
-
-                // Create spatial join where:
-                // - Spatial predicate (ST_KNN) drives the join
-                // - Equi-conditions (c.id = r.id) become filters
-
-                // Create SpatialJoinExec without projection first
-                // Use try_new_with_options to mark this as converted from 
HashJoin
-                let spatial_join = 
Arc::new(SpatialJoinExec::try_new_with_options(
-                    hash_join.left().clone(),
-                    hash_join.right().clone(),
-                    spatial_predicate,
-                    remainder,
-                    hash_join.join_type(),
-                    None, // No projection in SpatialJoinExec
-                    &options.spatial_join,
-                    true, // converted_from_hash_join = true
-                )?);
-
-                // Now wrap it with ProjectionExec to match HashJoinExec's 
output schema exactly
-                let expected_schema = hash_join.schema();
-                let spatial_schema = spatial_join.schema();
-
-                // Create a projection that selects the exact columns 
HashJoinExec would output
-                let projection_exec = self.create_schema_matching_projection(
-                    spatial_join,
-                    &expected_schema,
-                    &spatial_schema,
-                )?;
-
-                return Ok(Some(projection_exec));
             }
+            _ => (),
         }
-
-        Ok(None)
     }
 
-    /// Create a filter expression from the hash join's equi-join conditions
-    fn create_equi_filter_from_hash_join(
-        &self,
-        hash_join: &HashJoinExec,
-    ) -> Result<Option<JoinFilter>> {
-        let join_keys = hash_join.on();
-
-        if join_keys.is_empty() {
-            return Ok(None);
-        }
-
-        // Build filter expressions from the equi-join conditions
-        let mut expressions = vec![];
-
-        // Get the left schema size to calculate right column offsets
-        let left_schema_size = hash_join.left().schema().fields().len();
-
-        for (left_key, right_key) in join_keys.iter() {
-            // Create equality expression: left_key = right_key
-            // But we need to adjust the column indices for SpatialJoinExec 
schema
-            if let (Some(left_col), Some(right_col)) = (
-                left_key.as_any().downcast_ref::<Column>(),
-                right_key.as_any().downcast_ref::<Column>(),
-            ) {
-                // In SpatialJoinExec schema: [left_fields..., right_fields...]
-                // Left columns keep their indices, right columns get offset 
by left_schema_size
-                let left_idx = left_col.index();
-                let right_idx = left_schema_size + right_col.index();
-
-                let left_expr =
-                    Arc::new(Column::new(left_col.name(), left_idx)) as 
Arc<dyn PhysicalExpr>;
-                let right_expr =
-                    Arc::new(Column::new(right_col.name(), right_idx)) as 
Arc<dyn PhysicalExpr>;
-
-                let eq_expr = Arc::new(BinaryExpr::new(left_expr, 
Operator::Eq, right_expr))
-                    as Arc<dyn PhysicalExpr>;
-
-                expressions.push(eq_expr);
-            }
-        }
-
-        // IMPORTANT: Create column indices for ALL columns in the spatial 
join schema
-        // not just the filter columns. This is required by 
build_batch_from_indices.
-        let left_schema = hash_join.left().schema();
-        let right_schema = hash_join.right().schema();
-        let mut column_indices = vec![];
-
-        // Add all left side columns
-        for (i, _field) in left_schema.fields().iter().enumerate() {
-            column_indices.push(ColumnIndex {
-                index: i,
-                side: JoinSide::Left,
-            });
-        }
-
-        // Add all right side columns
-        for (i, _field) in right_schema.fields().iter().enumerate() {
-            column_indices.push(ColumnIndex {
-                index: i,
-                side: JoinSide::Right,
-            });
-        }
-
-        // Combine all conditions with AND
-        let filter_expr = if expressions.len() == 1 {
-            expressions.into_iter().next().unwrap()
-        } else {
-            expressions
-                .into_iter()
-                .reduce(|acc, expr| {
-                    Arc::new(BinaryExpr::new(acc, Operator::And, expr)) as 
Arc<dyn PhysicalExpr>
-                })
-                .unwrap()
+    fn is_distance_expr(expr: &Expr) -> bool {
+        let Expr::ScalarFunction(datafusion_expr::expr::ScalarFunction { func, 
.. }) = expr else {
+            return false;

Review Comment:
   Perhaps not for this PR, but some of these may fit well in sedona-expr (so we can use them in future optimizer rules).



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to