This is an automated email from the ASF dual-hosted git repository.

paleolimbot pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/sedona-db.git


The following commit(s) were added to refs/heads/main by this push:
     new 62f8376a perf(rust/sedona-spatial-join): Use row count first to decide join order (#725)
62f8376a is described below

commit 62f8376ac50650d9fe7cc5ac6fb299e13dca023b
Author: Yongting You <[email protected]>
AuthorDate: Fri Mar 20 02:31:53 2026 +0800

    perf(rust/sedona-spatial-join): Use row count first to decide join order (#725)
---
 .../src/planner/physical_planner.rs                |  34 ++-
 .../tests/spatial_join_integration.rs              | 280 ++++++++++++++++++++-
 2 files changed, 297 insertions(+), 17 deletions(-)

diff --git a/rust/sedona-spatial-join/src/planner/physical_planner.rs b/rust/sedona-spatial-join/src/planner/physical_planner.rs
index 33fa3a44..b99fe4e8 100644
--- a/rust/sedona-spatial-join/src/planner/physical_planner.rs
+++ b/rust/sedona-spatial-join/src/planner/physical_planner.rs
@@ -187,23 +187,37 @@ impl ExtensionPlanner for SpatialJoinExtensionPlanner {
     }
 }
 
+/// Spatial join reordering heuristic:
+/// 1. Put the input with fewer rows on the build side, because fewer entries
+///    produce a smaller and more efficient spatial index (R-tree).
+/// 2. If row-count statistics are unavailable (for example, for CSV sources),
+///    fall back to total input size as an estimate.
+/// 3. Do not swap the join order if no relevant statistics are available.
 fn should_swap_join_order(left: &dyn ExecutionPlan, right: &dyn ExecutionPlan) -> Result<bool> {
     let left_stats = left.partition_statistics(None)?;
     let right_stats = right.partition_statistics(None)?;
 
-    match (
-        left_stats.total_byte_size.get_value(),
-        right_stats.total_byte_size.get_value(),
-    ) {
-        (Some(l), Some(r)) => Ok(l > r),
+    let left_num_rows = left_stats.num_rows;
+    let right_num_rows = right_stats.num_rows;
+    let left_total_byte_size = left_stats.total_byte_size;
+    let right_total_byte_size = right_stats.total_byte_size;
+
+    let should_swap = match (left_num_rows.get_value(), 
right_num_rows.get_value()) {
+        (Some(l), Some(r)) => l > r,
         _ => match (
-            left_stats.num_rows.get_value(),
-            right_stats.num_rows.get_value(),
+            left_total_byte_size.get_value(),
+            right_total_byte_size.get_value(),
         ) {
-            (Some(l), Some(r)) => Ok(l > r),
-            _ => Ok(false),
+            (Some(l), Some(r)) => l > r,
+            _ => false,
         },
-    }
+    };
+
+    log::info!(
+        "spatial join swap heuristic: left_num_rows={left_num_rows:?}, right_num_rows={right_num_rows:?}, left_total_byte_size={left_total_byte_size:?}, right_total_byte_size={right_total_byte_size:?}, should_swap={should_swap}"
+    );
+
+    Ok(should_swap)
 }
 
 /// This function is mostly taken from the match arm for handling 
LogicalPlan::Join in
diff --git a/rust/sedona-spatial-join/tests/spatial_join_integration.rs b/rust/sedona-spatial-join/tests/spatial_join_integration.rs
index 35338017..da54f52e 100644
--- a/rust/sedona-spatial-join/tests/spatial_join_integration.rs
+++ b/rust/sedona-spatial-join/tests/spatial_join_integration.rs
@@ -15,22 +15,27 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use std::sync::Arc;
+use std::{any::Any, sync::Arc};
 
-use arrow_array::{Array, RecordBatch};
+use arrow_array::{Array, Float64Array, Int32Array, RecordBatch, StringArray};
 use arrow_schema::{DataType, Field, Schema, SchemaRef};
+use async_trait::async_trait;
 use datafusion::{
-    catalog::{MemTable, TableProvider},
-    datasource::empty::EmptyTable,
+    catalog::{MemTable, Session, TableProvider},
+    datasource::{empty::EmptyTable, TableType},
     execution::SessionStateBuilder,
     prelude::{SessionConfig, SessionContext},
 };
 use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion};
-use datafusion_common::{JoinSide, Result};
-use datafusion_expr::{ColumnarValue, JoinType};
+use datafusion_common::{stats::Precision, JoinSide, Result, Statistics};
+use datafusion_execution::TaskContext;
+use datafusion_expr::{ColumnarValue, Expr, JoinType};
 use datafusion_physical_plan::filter::FilterExec;
 use datafusion_physical_plan::joins::NestedLoopJoinExec;
-use datafusion_physical_plan::ExecutionPlan;
+use datafusion_physical_plan::{
+    DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, 
PlanProperties,
+    SendableRecordBatchStream,
+};
 use geo::{Distance, Euclidean};
 use geo_types::{Coord, Rect};
 use rstest::rstest;
@@ -175,6 +180,232 @@ fn setup_context(options: Option<SpatialJoinOptions>, batch_size: usize) -> Resu
     Ok(ctx)
 }
 
+#[derive(Debug)]
+struct StatsOverrideTableProvider {
+    inner: Arc<dyn TableProvider>,
+    stats: Statistics,
+}
+
+impl StatsOverrideTableProvider {
+    fn new(inner: Arc<dyn TableProvider>, stats: Statistics) -> Self {
+        Self { inner, stats }
+    }
+}
+
+#[async_trait]
+impl TableProvider for StatsOverrideTableProvider {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn schema(&self) -> SchemaRef {
+        self.inner.schema()
+    }
+
+    fn table_type(&self) -> TableType {
+        self.inner.table_type()
+    }
+
+    async fn scan(
+        &self,
+        state: &dyn Session,
+        projection: Option<&Vec<usize>>,
+        filters: &[Expr],
+        limit: Option<usize>,
+    ) -> Result<Arc<dyn ExecutionPlan>> {
+        let inner = self.inner.scan(state, projection, filters, limit).await?;
+        Ok(Arc::new(StatsOverrideExec::new(inner, self.stats.clone())))
+    }
+}
+
+#[derive(Debug)]
+struct StatsOverrideExec {
+    inner: Arc<dyn ExecutionPlan>,
+    stats: Statistics,
+    properties: PlanProperties,
+}
+
+impl StatsOverrideExec {
+    fn new(inner: Arc<dyn ExecutionPlan>, stats: Statistics) -> Self {
+        let properties = PlanProperties::new(
+            inner.equivalence_properties().clone(),
+            inner.output_partitioning().clone(),
+            inner.pipeline_behavior(),
+            inner.boundedness(),
+        );
+        Self {
+            inner,
+            stats,
+            properties,
+        }
+    }
+}
+
+#[derive(Clone, Copy)]
+enum OriginalInputSide {
+    Left,
+    Right,
+}
+
+impl DisplayAs for StatsOverrideExec {
+    fn fmt_as(&self, _t: DisplayFormatType, f: &mut std::fmt::Formatter) -> 
std::fmt::Result {
+        write!(f, "StatsOverrideExec")
+    }
+}
+
+fn stats_with(
+    schema: &Schema,
+    num_rows: Option<usize>,
+    total_byte_size: Option<usize>,
+) -> Statistics {
+    let mut stats = Statistics::new_unknown(schema);
+    if let Some(num_rows) = num_rows {
+        stats = stats.with_num_rows(Precision::Exact(num_rows));
+    }
+    if let Some(total_byte_size) = total_byte_size {
+        stats = stats.with_total_byte_size(Precision::Exact(total_byte_size));
+    }
+    stats
+}
+
+fn single_row_table(schema: SchemaRef, id: i32, marker: &str) -> 
Result<Arc<dyn TableProvider>> {
+    let batch = RecordBatch::try_new(
+        schema.clone(),
+        vec![
+            Arc::new(Int32Array::from(vec![id])),
+            Arc::new(Float64Array::from(vec![0.0])),
+            Arc::new(StringArray::from(vec![marker])),
+        ],
+    )?;
+    Ok(Arc::new(MemTable::try_new(schema, vec![vec![batch]])?))
+}
+
+// Keep the data fixed and vary only the advertised stats so the planner swap
+// decision is explained entirely by the heuristic under test.
+async fn assert_build_side_from_stats(
+    left_num_rows: Option<usize>,
+    right_num_rows: Option<usize>,
+    left_total_byte_size: Option<usize>,
+    right_total_byte_size: Option<usize>,
+    expected_build_side: OriginalInputSide,
+) -> Result<()> {
+    let left_schema = Arc::new(Schema::new(vec![
+        Field::new("id", DataType::Int32, false),
+        Field::new("x", DataType::Float64, false),
+        Field::new("l_marker", DataType::Utf8, false),
+    ]));
+    let right_schema = Arc::new(Schema::new(vec![
+        Field::new("id", DataType::Int32, false),
+        Field::new("x", DataType::Float64, false),
+        Field::new("r_marker", DataType::Utf8, false),
+    ]));
+
+    let left_provider: Arc<dyn TableProvider> = 
Arc::new(StatsOverrideTableProvider::new(
+        single_row_table(left_schema.clone(), 1, "left")?,
+        stats_with(left_schema.as_ref(), left_num_rows, left_total_byte_size),
+    ));
+    let right_provider: Arc<dyn TableProvider> = 
Arc::new(StatsOverrideTableProvider::new(
+        single_row_table(right_schema.clone(), 10, "right")?,
+        stats_with(right_schema.as_ref(), right_num_rows, 
right_total_byte_size),
+    ));
+
+    let ctx = setup_context(Some(SpatialJoinOptions::default()), 10)?;
+    ctx.register_table("L", left_provider)?;
+    ctx.register_table("R", right_provider)?;
+
+    let df = ctx
+        .sql("SELECT * FROM L JOIN R ON ST_Intersects(ST_Point(L.x, 0), 
ST_Point(R.x, 0))")
+        .await?;
+    let plan = df.clone().create_physical_plan().await?;
+    let spatial_join_execs = collect_spatial_join_exec(&plan)?;
+    assert_eq!(
+        spatial_join_execs.len(),
+        1,
+        "expected exactly one SpatialJoinExec"
+    );
+
+    let spatial_join = spatial_join_execs[0];
+    let expected_marker = match expected_build_side {
+        OriginalInputSide::Left => "l_marker",
+        OriginalInputSide::Right => "r_marker",
+    };
+    let expected_probe_marker = match expected_build_side {
+        OriginalInputSide::Left => "r_marker",
+        OriginalInputSide::Right => "l_marker",
+    };
+    assert!(spatial_join.left.schema().index_of(expected_marker).is_ok());
+    assert!(spatial_join
+        .right
+        .schema()
+        .index_of(expected_probe_marker)
+        .is_ok());
+
+    let result_batches = df.collect().await?;
+    assert_eq!(
+        result_batches
+            .iter()
+            .map(RecordBatch::num_rows)
+            .sum::<usize>(),
+        1
+    );
+
+    Ok(())
+}
+
+impl ExecutionPlan for StatsOverrideExec {
+    fn name(&self) -> &str {
+        "StatsOverrideExec"
+    }
+
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn properties(&self) -> &PlanProperties {
+        &self.properties
+    }
+
+    fn maintains_input_order(&self) -> Vec<bool> {
+        vec![true]
+    }
+
+    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
+        vec![&self.inner]
+    }
+
+    fn with_new_children(
+        self: Arc<Self>,
+        children: Vec<Arc<dyn ExecutionPlan>>,
+    ) -> Result<Arc<dyn ExecutionPlan>> {
+        match children.as_slice() {
+            [child] => Ok(Arc::new(Self::new(Arc::clone(child), 
self.stats.clone()))),
+            _ => datafusion_common::internal_err!(
+                "StatsOverrideExec expects exactly one child, got {}",
+                children.len()
+            ),
+        }
+    }
+
+    fn execute(
+        &self,
+        partition: usize,
+        context: Arc<TaskContext>,
+    ) -> Result<SendableRecordBatchStream> {
+        self.inner.execute(partition, context)
+    }
+
+    fn statistics(&self) -> Result<Statistics> {
+        Ok(self.stats.clone())
+    }
+
+    fn partition_statistics(&self, partition: Option<usize>) -> 
Result<Statistics> {
+        match partition {
+            None => Ok(self.stats.clone()),
+            Some(partition) => 
self.inner.partition_statistics(Some(partition)),
+        }
+    }
+}
+
 #[tokio::test]
 async fn test_empty_data() -> Result<()> {
     let schema = Arc::new(Schema::new(vec![
@@ -456,6 +687,41 @@ async fn test_spatial_join_swap_inputs_produces_same_plan(
     Ok(())
 }
 
+#[tokio::test]
+// When both row count and size are present, the planner swaps to put the
+// smaller-row input on the build side even if it is larger by byte size.
+async fn test_spatial_join_reordering_uses_row_count() -> Result<()> {
+    assert_build_side_from_stats(
+        Some(100),
+        Some(10),
+        Some(100),
+        Some(10_000),
+        OriginalInputSide::Right,
+    )
+    .await
+}
+
+#[tokio::test]
+// When row count is absent on both sides, the planner swaps to put the
+// smaller-bytes input on the build side.
+async fn test_spatial_join_reordering_uses_size_fallback() -> Result<()> {
+    assert_build_side_from_stats(
+        None,
+        None,
+        Some(10_000),
+        Some(100),
+        OriginalInputSide::Right,
+    )
+    .await
+}
+
+#[tokio::test]
+// When both row count and size are absent, the planner preserves the original
+// join order.
+async fn test_spatial_join_reordering_preserves_order_without_stats() -> 
Result<()> {
+    assert_build_side_from_stats(None, None, None, None, 
OriginalInputSide::Left).await
+}
+
 #[tokio::test]
 async fn test_range_join_with_empty_partitions() -> Result<()> {
     let ((left_schema, left_partitions), (right_schema, right_partitions)) =

Reply via email to