This is an automated email from the ASF dual-hosted git repository.
paleolimbot pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/sedona-db.git
The following commit(s) were added to refs/heads/main by this push:
new 62f8376a perf(rust/sedona-spatial-join): Use row count first to decide
join order (#725)
62f8376a is described below
commit 62f8376ac50650d9fe7cc5ac6fb299e13dca023b
Author: Yongting You <[email protected]>
AuthorDate: Fri Mar 20 02:31:53 2026 +0800
perf(rust/sedona-spatial-join): Use row count first to decide join order
(#725)
---
.../src/planner/physical_planner.rs | 34 ++-
.../tests/spatial_join_integration.rs | 280 ++++++++++++++++++++-
2 files changed, 297 insertions(+), 17 deletions(-)
diff --git a/rust/sedona-spatial-join/src/planner/physical_planner.rs
b/rust/sedona-spatial-join/src/planner/physical_planner.rs
index 33fa3a44..b99fe4e8 100644
--- a/rust/sedona-spatial-join/src/planner/physical_planner.rs
+++ b/rust/sedona-spatial-join/src/planner/physical_planner.rs
@@ -187,23 +187,37 @@ impl ExtensionPlanner for SpatialJoinExtensionPlanner {
}
}
+/// Spatial join reordering heuristic:
+/// 1. Put the input with fewer rows on the build side, because fewer entries
+/// produce a smaller and more efficient spatial index (R-tree).
+/// 2. If row-count statistics are unavailable on either side (for example,
+///    for CSV sources), fall back to total input size as an estimate.
+/// 3. Do not swap the join order if no relevant statistics are available.
fn should_swap_join_order(left: &dyn ExecutionPlan, right: &dyn ExecutionPlan)
-> Result<bool> {
let left_stats = left.partition_statistics(None)?;
let right_stats = right.partition_statistics(None)?;
- match (
- left_stats.total_byte_size.get_value(),
- right_stats.total_byte_size.get_value(),
- ) {
- (Some(l), Some(r)) => Ok(l > r),
+ let left_num_rows = left_stats.num_rows;
+ let right_num_rows = right_stats.num_rows;
+ let left_total_byte_size = left_stats.total_byte_size;
+ let right_total_byte_size = right_stats.total_byte_size;
+
+ let should_swap = match (left_num_rows.get_value(),
right_num_rows.get_value()) {
+ (Some(l), Some(r)) => l > r,
_ => match (
- left_stats.num_rows.get_value(),
- right_stats.num_rows.get_value(),
+ left_total_byte_size.get_value(),
+ right_total_byte_size.get_value(),
) {
- (Some(l), Some(r)) => Ok(l > r),
- _ => Ok(false),
+ (Some(l), Some(r)) => l > r,
+ _ => false,
},
- }
+ };
+
+ log::info!(
+ "spatial join swap heuristic: left_num_rows={left_num_rows:?},
right_num_rows={right_num_rows:?},
left_total_byte_size={left_total_byte_size:?},
right_total_byte_size={right_total_byte_size:?}, should_swap={should_swap}"
+ );
+
+ Ok(should_swap)
}
/// This function is mostly taken from the match arm for handling
LogicalPlan::Join in
diff --git a/rust/sedona-spatial-join/tests/spatial_join_integration.rs
b/rust/sedona-spatial-join/tests/spatial_join_integration.rs
index 35338017..da54f52e 100644
--- a/rust/sedona-spatial-join/tests/spatial_join_integration.rs
+++ b/rust/sedona-spatial-join/tests/spatial_join_integration.rs
@@ -15,22 +15,27 @@
// specific language governing permissions and limitations
// under the License.
-use std::sync::Arc;
+use std::{any::Any, sync::Arc};
-use arrow_array::{Array, RecordBatch};
+use arrow_array::{Array, Float64Array, Int32Array, RecordBatch, StringArray};
use arrow_schema::{DataType, Field, Schema, SchemaRef};
+use async_trait::async_trait;
use datafusion::{
- catalog::{MemTable, TableProvider},
- datasource::empty::EmptyTable,
+ catalog::{MemTable, Session, TableProvider},
+ datasource::{empty::EmptyTable, TableType},
execution::SessionStateBuilder,
prelude::{SessionConfig, SessionContext},
};
use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion};
-use datafusion_common::{JoinSide, Result};
-use datafusion_expr::{ColumnarValue, JoinType};
+use datafusion_common::{stats::Precision, JoinSide, Result, Statistics};
+use datafusion_execution::TaskContext;
+use datafusion_expr::{ColumnarValue, Expr, JoinType};
use datafusion_physical_plan::filter::FilterExec;
use datafusion_physical_plan::joins::NestedLoopJoinExec;
-use datafusion_physical_plan::ExecutionPlan;
+use datafusion_physical_plan::{
+ DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties,
PlanProperties,
+ SendableRecordBatchStream,
+};
use geo::{Distance, Euclidean};
use geo_types::{Coord, Rect};
use rstest::rstest;
@@ -175,6 +180,232 @@ fn setup_context(options: Option<SpatialJoinOptions>,
batch_size: usize) -> Resu
Ok(ctx)
}
+#[derive(Debug)]
+struct StatsOverrideTableProvider {
+ inner: Arc<dyn TableProvider>,
+ stats: Statistics,
+}
+
+impl StatsOverrideTableProvider {
+ fn new(inner: Arc<dyn TableProvider>, stats: Statistics) -> Self {
+ Self { inner, stats }
+ }
+}
+
+#[async_trait]
+impl TableProvider for StatsOverrideTableProvider {
+ fn as_any(&self) -> &dyn Any {
+ self
+ }
+
+ fn schema(&self) -> SchemaRef {
+ self.inner.schema()
+ }
+
+ fn table_type(&self) -> TableType {
+ self.inner.table_type()
+ }
+
+ async fn scan(
+ &self,
+ state: &dyn Session,
+ projection: Option<&Vec<usize>>,
+ filters: &[Expr],
+ limit: Option<usize>,
+ ) -> Result<Arc<dyn ExecutionPlan>> {
+ let inner = self.inner.scan(state, projection, filters, limit).await?;
+ Ok(Arc::new(StatsOverrideExec::new(inner, self.stats.clone())))
+ }
+}
+
+#[derive(Debug)]
+struct StatsOverrideExec {
+ inner: Arc<dyn ExecutionPlan>,
+ stats: Statistics,
+ properties: PlanProperties,
+}
+
+impl StatsOverrideExec {
+ fn new(inner: Arc<dyn ExecutionPlan>, stats: Statistics) -> Self {
+ let properties = PlanProperties::new(
+ inner.equivalence_properties().clone(),
+ inner.output_partitioning().clone(),
+ inner.pipeline_behavior(),
+ inner.boundedness(),
+ );
+ Self {
+ inner,
+ stats,
+ properties,
+ }
+ }
+}
+
+#[derive(Clone, Copy)]
+enum OriginalInputSide {
+ Left,
+ Right,
+}
+
+impl DisplayAs for StatsOverrideExec {
+ fn fmt_as(&self, _t: DisplayFormatType, f: &mut std::fmt::Formatter) ->
std::fmt::Result {
+ write!(f, "StatsOverrideExec")
+ }
+}
+
+fn stats_with(
+ schema: &Schema,
+ num_rows: Option<usize>,
+ total_byte_size: Option<usize>,
+) -> Statistics {
+ let mut stats = Statistics::new_unknown(schema);
+ if let Some(num_rows) = num_rows {
+ stats = stats.with_num_rows(Precision::Exact(num_rows));
+ }
+ if let Some(total_byte_size) = total_byte_size {
+ stats = stats.with_total_byte_size(Precision::Exact(total_byte_size));
+ }
+ stats
+}
+
+fn single_row_table(schema: SchemaRef, id: i32, marker: &str) ->
Result<Arc<dyn TableProvider>> {
+ let batch = RecordBatch::try_new(
+ schema.clone(),
+ vec![
+ Arc::new(Int32Array::from(vec![id])),
+ Arc::new(Float64Array::from(vec![0.0])),
+ Arc::new(StringArray::from(vec![marker])),
+ ],
+ )?;
+ Ok(Arc::new(MemTable::try_new(schema, vec![vec![batch]])?))
+}
+
+// Keep the data fixed and vary only the advertised stats so the planner swap
+// decision is explained entirely by the heuristic under test.
+async fn assert_build_side_from_stats(
+ left_num_rows: Option<usize>,
+ right_num_rows: Option<usize>,
+ left_total_byte_size: Option<usize>,
+ right_total_byte_size: Option<usize>,
+ expected_build_side: OriginalInputSide,
+) -> Result<()> {
+ let left_schema = Arc::new(Schema::new(vec![
+ Field::new("id", DataType::Int32, false),
+ Field::new("x", DataType::Float64, false),
+ Field::new("l_marker", DataType::Utf8, false),
+ ]));
+ let right_schema = Arc::new(Schema::new(vec![
+ Field::new("id", DataType::Int32, false),
+ Field::new("x", DataType::Float64, false),
+ Field::new("r_marker", DataType::Utf8, false),
+ ]));
+
+ let left_provider: Arc<dyn TableProvider> =
Arc::new(StatsOverrideTableProvider::new(
+ single_row_table(left_schema.clone(), 1, "left")?,
+ stats_with(left_schema.as_ref(), left_num_rows, left_total_byte_size),
+ ));
+ let right_provider: Arc<dyn TableProvider> =
Arc::new(StatsOverrideTableProvider::new(
+ single_row_table(right_schema.clone(), 10, "right")?,
+ stats_with(right_schema.as_ref(), right_num_rows,
right_total_byte_size),
+ ));
+
+ let ctx = setup_context(Some(SpatialJoinOptions::default()), 10)?;
+ ctx.register_table("L", left_provider)?;
+ ctx.register_table("R", right_provider)?;
+
+ let df = ctx
+ .sql("SELECT * FROM L JOIN R ON ST_Intersects(ST_Point(L.x, 0),
ST_Point(R.x, 0))")
+ .await?;
+ let plan = df.clone().create_physical_plan().await?;
+ let spatial_join_execs = collect_spatial_join_exec(&plan)?;
+ assert_eq!(
+ spatial_join_execs.len(),
+ 1,
+ "expected exactly one SpatialJoinExec"
+ );
+
+ let spatial_join = spatial_join_execs[0];
+ let expected_marker = match expected_build_side {
+ OriginalInputSide::Left => "l_marker",
+ OriginalInputSide::Right => "r_marker",
+ };
+ let expected_probe_marker = match expected_build_side {
+ OriginalInputSide::Left => "r_marker",
+ OriginalInputSide::Right => "l_marker",
+ };
+ assert!(spatial_join.left.schema().index_of(expected_marker).is_ok());
+ assert!(spatial_join
+ .right
+ .schema()
+ .index_of(expected_probe_marker)
+ .is_ok());
+
+ let result_batches = df.collect().await?;
+ assert_eq!(
+ result_batches
+ .iter()
+ .map(RecordBatch::num_rows)
+ .sum::<usize>(),
+ 1
+ );
+
+ Ok(())
+}
+
+impl ExecutionPlan for StatsOverrideExec {
+ fn name(&self) -> &str {
+ "StatsOverrideExec"
+ }
+
+ fn as_any(&self) -> &dyn Any {
+ self
+ }
+
+ fn properties(&self) -> &PlanProperties {
+ &self.properties
+ }
+
+ fn maintains_input_order(&self) -> Vec<bool> {
+ vec![true]
+ }
+
+ fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
+ vec![&self.inner]
+ }
+
+ fn with_new_children(
+ self: Arc<Self>,
+ children: Vec<Arc<dyn ExecutionPlan>>,
+ ) -> Result<Arc<dyn ExecutionPlan>> {
+ match children.as_slice() {
+ [child] => Ok(Arc::new(Self::new(Arc::clone(child),
self.stats.clone()))),
+ _ => datafusion_common::internal_err!(
+ "StatsOverrideExec expects exactly one child, got {}",
+ children.len()
+ ),
+ }
+ }
+
+ fn execute(
+ &self,
+ partition: usize,
+ context: Arc<TaskContext>,
+ ) -> Result<SendableRecordBatchStream> {
+ self.inner.execute(partition, context)
+ }
+
+ fn statistics(&self) -> Result<Statistics> {
+ Ok(self.stats.clone())
+ }
+
+ fn partition_statistics(&self, partition: Option<usize>) ->
Result<Statistics> {
+ match partition {
+ None => Ok(self.stats.clone()),
+ Some(partition) =>
self.inner.partition_statistics(Some(partition)),
+ }
+ }
+}
+
#[tokio::test]
async fn test_empty_data() -> Result<()> {
let schema = Arc::new(Schema::new(vec![
@@ -456,6 +687,41 @@ async fn test_spatial_join_swap_inputs_produces_same_plan(
Ok(())
}
+#[tokio::test]
+// When both row count and size are present, the planner swaps to put the
+// smaller-row input on the build side even if it is larger by byte size.
+async fn test_spatial_join_reordering_uses_row_count() -> Result<()> {
+ assert_build_side_from_stats(
+ Some(100),
+ Some(10),
+ Some(100),
+ Some(10_000),
+ OriginalInputSide::Right,
+ )
+ .await
+}
+
+#[tokio::test]
+// When row count is absent on both sides, the planner swaps to put the
+// smaller-bytes input on the build side.
+async fn test_spatial_join_reordering_uses_size_fallback() -> Result<()> {
+ assert_build_side_from_stats(
+ None,
+ None,
+ Some(10_000),
+ Some(100),
+ OriginalInputSide::Right,
+ )
+ .await
+}
+
+#[tokio::test]
+// When both row count and size are absent, the planner preserves the original
+// join order.
+async fn test_spatial_join_reordering_preserves_order_without_stats() ->
Result<()> {
+ assert_build_side_from_stats(None, None, None, None,
OriginalInputSide::Left).await
+}
+
#[tokio::test]
async fn test_range_join_with_empty_partitions() -> Result<()> {
let ((left_schema, left_partitions), (right_schema, right_partitions)) =