Dandandan commented on code in PR #3632:
URL: https://github.com/apache/datafusion-comet/pull/3632#discussion_r3176806147


##########
native/core/src/execution/operators/grace_hash_join.rs:
##########
@@ -0,0 +1,2837 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Grace Hash Join operator for Apache DataFusion Comet.
+//!
+//! Partitions both build and probe sides into N buckets by hashing join keys,
+//! then performs per-partition hash joins. Spills partitions to disk (Arrow IPC)
+//! when memory is tight.
+//!
+//! Supports all join types. Recursively repartitions oversized partitions
+//! up to `MAX_RECURSION_DEPTH` levels.
+
+use std::any::Any;
+use std::fmt;
+use std::fs::File;
+use std::io::{BufReader, BufWriter};
+use std::sync::Arc;
+use std::sync::Mutex;
+
+use ahash::RandomState;
+use arrow::array::UInt32Array;
+use arrow::compute::{concat_batches, take};
+use arrow::datatypes::SchemaRef;
+use arrow::ipc::reader::StreamReader;
+use arrow::ipc::writer::{IpcWriteOptions, StreamWriter};
+use arrow::ipc::CompressionType;
+use arrow::record_batch::RecordBatch;
+use datafusion::common::hash_utils::create_hashes;
+use datafusion::common::{DataFusionError, JoinType, NullEquality, Result as 
DFResult};
+use datafusion::execution::context::TaskContext;
+use datafusion::execution::disk_manager::RefCountedTempFile;
+use datafusion::execution::memory_pool::{MemoryConsumer, MemoryReservation};
+use datafusion::physical_expr::EquivalenceProperties;
+use datafusion::physical_expr::PhysicalExpr;
+use datafusion::physical_plan::display::DisplayableExecutionPlan;
+use datafusion::physical_plan::joins::utils::JoinFilter;
+use datafusion::physical_plan::joins::{HashJoinExec, PartitionMode};
+use datafusion::physical_plan::metrics::{
+    BaselineMetrics, Count, ExecutionPlanMetricsSet, MetricBuilder, 
MetricsSet, Time,
+};
+use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
+use datafusion::physical_plan::{
+    DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties,
+    SendableRecordBatchStream,
+};
+use futures::stream::{self, StreamExt, TryStreamExt};
+use futures::Stream;
+use log::info;
+use tokio::sync::mpsc;
+
+/// Process-wide counter that hands out a unique id to every grace hash join
+/// instance; the id is only used to tag log messages (debug tracing).
+static GHJ_INSTANCE_COUNTER: std::sync::atomic::AtomicUsize =
+    std::sync::atomic::AtomicUsize::new(0);
+
+/// Borrowed slice of `(left_key, right_key)` join expression pairs.
+type JoinOnRef<'a> = &'a [(Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>)];
+
+/// Bucket count used when the caller requests 0 partitions.
+const DEFAULT_NUM_PARTITIONS: usize = 16;
+
+/// Maximum recursion depth for repartitioning oversized partitions.
+/// With 16 partitions per level, depth 3 yields up to 16^3 = 4096 effective
+/// partitions.
+const MAX_RECURSION_DEPTH: usize = 3;
+
+/// Buffer size (1 MiB) for spill-file `BufReader`/`BufWriter`s. The default
+/// of 8 KB is far too small for multi-GB spill files; 1 MiB gives good
+/// sequential throughput while keeping per-partition memory overhead modest.
+const SPILL_IO_BUFFER_SIZE: usize = 1 << 20;
+
+/// Emit a progress log line each time this many more probe rows have been
+/// accumulated.
+const PROBE_PROGRESS_MILESTONE_ROWS: usize = 5_000_000;
+
+/// Row count at which sub-batches read back from a spill file are merged
+/// into one batch. Partitioning leaves many tiny sub-batches in each spill
+/// file; coalescing them reduces per-batch overhead in the hash join kernel
+/// and in channel send/recv.
+const SPILL_READ_COALESCE_TARGET: usize = 8_192;
+
+/// Desired build-side size (32 MiB) per merged partition. After Phase 2,
+/// adjacent `FinishedPartition`s are merged until each group holds roughly
+/// this much build data, reducing the number of per-partition HashJoinExec
+/// calls.
+const TARGET_PARTITION_BUILD_SIZE: usize = 32 << 20;
+
+/// Builds the `RandomState` used to hash join keys into partitions.
+///
+/// The fixed seeds are chosen to differ from DataFusion's `HashJoinExec` so
+/// partition assignment is not correlated with the per-partition hash table.
+/// `recursion_level` is XORed into the first seed, so every level of
+/// recursive repartitioning hashes with a different function.
+fn partition_random_state(recursion_level: usize) -> RandomState {
+    const BASE_K0: u64 = 0x517c_c1b7_2722_0a95;
+    const BASE_K1: u64 = 0x3a8b_7c9d_1e2f_4056;
+
+    let level_salt = recursion_level as u64;
+    RandomState::with_seeds(BASE_K0 ^ level_salt, BASE_K1, 0, 0)
+}
+
+// ---------------------------------------------------------------------------
+// SpillWriter: incremental append to Arrow IPC spill files
+// ---------------------------------------------------------------------------
+
+/// Incremental Arrow IPC writer for a single spill file.
+///
+/// The `StreamWriter` stays open between calls, so each batch is appended
+/// exactly once instead of re-reading and rewriting the file (an O(n²)
+/// pattern in the number of batches).
+struct SpillWriter {
+    writer: StreamWriter<BufWriter<File>>,
+    temp_file: RefCountedTempFile,
+    /// Sum of `get_array_memory_size()` over all batches written: the
+    /// uncompressed in-memory footprint, NOT the LZ4-compressed on-disk size.
+    bytes_written: usize,
+}
+
+impl SpillWriter {
+    /// Opens `temp_file` for writing (creating or truncating it) and sets up
+    /// an LZ4-frame-compressed IPC stream writer for `schema`.
+    ///
+    /// # Errors
+    /// Fails if the file cannot be opened or the IPC writer cannot be built.
+    fn new(temp_file: RefCountedTempFile, schema: &SchemaRef) -> DFResult<Self> {
+        let opened = std::fs::OpenOptions::new()
+            .write(true)
+            .create(true)
+            .truncate(true)
+            .open(temp_file.path());
+        let file = match opened {
+            Ok(file) => file,
+            Err(e) => {
+                return Err(DataFusionError::Execution(format!(
+                    "Failed to open spill file: {e}"
+                )));
+            }
+        };
+        let sink = BufWriter::with_capacity(SPILL_IO_BUFFER_SIZE, file);
+        let options =
+            IpcWriteOptions::default().try_with_compression(Some(CompressionType::LZ4_FRAME))?;
+        let writer = StreamWriter::try_new_with_options(sink, schema, options)?;
+        Ok(Self {
+            writer,
+            temp_file,
+            bytes_written: 0,
+        })
+    }
+
+    /// Appends `batch`; batches with zero rows are skipped entirely.
+    fn write_batch(&mut self, batch: &RecordBatch) -> DFResult<()> {
+        if batch.num_rows() == 0 {
+            return Ok(());
+        }
+        self.bytes_written += batch.get_array_memory_size();
+        self.writer.write(batch)?;
+        Ok(())
+    }
+
+    /// Appends every batch in order, stopping at the first error.
+    fn write_batches(&mut self, batches: &[RecordBatch]) -> DFResult<()> {
+        batches.iter().try_for_each(|batch| self.write_batch(batch))
+    }
+
+    /// Finalizes the IPC stream; must be called before the file is read back.
+    /// Returns the temp-file handle and the accumulated `bytes_written`.
+    fn finish(mut self) -> DFResult<(RefCountedTempFile, usize)> {
+        self.writer.finish()?;
+        let Self {
+            temp_file,
+            bytes_written,
+            ..
+        } = self;
+        Ok((temp_file, bytes_written))
+    }
+}
+
+// ---------------------------------------------------------------------------
+// SpillReaderExec: streaming ExecutionPlan for reading spill files
+// ---------------------------------------------------------------------------
+
+/// Leaf `ExecutionPlan` that replays the record batches of an Arrow IPC
+/// spill file. The join phase uses it so spilled probe data is read from
+/// disk on demand rather than loaded into memory all at once.
+#[derive(Debug)]
+struct SpillReaderExec {
+    spill_file: RefCountedTempFile,
+    schema: SchemaRef,
+    cache: Arc<PlanProperties>,
+}
+
+impl SpillReaderExec {
+    /// Wraps `spill_file`, reporting `schema` as the output schema
+    /// (presumably the schema the file was written with). The plan advertises
+    /// one bounded, incrementally-emitting partition.
+    fn new(spill_file: RefCountedTempFile, schema: SchemaRef) -> Self {
+        use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType};
+
+        let eq_properties = EquivalenceProperties::new(Arc::clone(&schema));
+        let properties = PlanProperties::new(
+            eq_properties,
+            Partitioning::UnknownPartitioning(1),
+            EmissionType::Incremental,
+            Boundedness::Bounded,
+        );
+        Self {
+            spill_file,
+            schema,
+            cache: Arc::new(properties),
+        }
+    }
+}
+
+impl DisplayAs for SpillReaderExec {
+    /// Renders just the node name, whatever display format is requested.
+    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result {
+        f.write_str("SpillReaderExec")
+    }
+}
+
+impl ExecutionPlan for SpillReaderExec {
+    fn name(&self) -> &str {
+        "SpillReaderExec"
+    }
+
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn schema(&self) -> SchemaRef {
+        Arc::clone(&self.schema)
+    }
+
+    /// Leaf node: no children.
+    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
+        vec![]
+    }
+
+    /// Leaf node: there is nothing to replace, so `self` is returned.
+    fn with_new_children(
+        self: Arc<Self>,
+        _children: Vec<Arc<dyn ExecutionPlan>>,
+    ) -> DFResult<Arc<dyn ExecutionPlan>> {
+        Ok(self)
+    }
+
+    fn properties(&self) -> &Arc<PlanProperties> {
+        &self.cache
+    }
+
+    /// Streams the spill file's batches, coalesced to at least
+    /// `SPILL_READ_COALESCE_TARGET` rows each (except possibly the last).
+    ///
+    /// File I/O runs on a `spawn_blocking` thread that hands batches over a
+    /// bounded channel, so the async executor is never blocked and
+    /// `select_all` can interleave several partition streams. When the
+    /// channel closes, the producer task is joined. If it panicked or was
+    /// cancelled, the stream yields an error instead of simply ending, which
+    /// would otherwise silently drop spilled rows from the join result.
+    fn execute(
+        &self,
+        _partition: usize,
+        _context: Arc<TaskContext>,
+    ) -> DFResult<SendableRecordBatchStream> {
+        // Merges the buffered sub-batches into one batch (passing a lone
+        // batch through without copying) and empties `pending`, keeping its
+        // capacity for reuse.
+        fn drain_coalesced(
+            schema: &SchemaRef,
+            pending: &mut Vec<RecordBatch>,
+        ) -> DFResult<RecordBatch> {
+            let merged = if pending.len() == 1 {
+                Ok(pending.pop().expect("pending holds exactly one batch"))
+            } else {
+                concat_batches(schema, pending.iter())
+                    .map_err(|e| DataFusionError::ArrowError(Box::new(e), None))
+            };
+            pending.clear();
+            merged
+        }
+
+        let stream_schema = Arc::clone(&self.schema);
+        let coalesce_schema = Arc::clone(&self.schema);
+        let path = self.spill_file.path().to_path_buf();
+        // Moved into the blocking closure so the temp file outlives the reader.
+        let spill_file_handle = self.spill_file.clone();
+
+        let (tx, rx) = mpsc::channel::<DFResult<RecordBatch>>(4);
+
+        let producer = tokio::task::spawn_blocking(move || {
+            let _keep_alive = spill_file_handle;
+            let file = match File::open(&path) {
+                Ok(f) => f,
+                Err(e) => {
+                    let _ = tx.blocking_send(Err(DataFusionError::Execution(format!(
+                        "Failed to open spill file: {e}"
+                    ))));
+                    return;
+                }
+            };
+            let reader = match StreamReader::try_new(
+                BufReader::with_capacity(SPILL_IO_BUFFER_SIZE, file),
+                None,
+            ) {
+                Ok(r) => r,
+                Err(e) => {
+                    let _ =
+                        tx.blocking_send(Err(DataFusionError::ArrowError(Box::new(e), None)));
+                    return;
+                }
+            };
+
+            // Partitioning leaves many tiny sub-batches in the file; merge
+            // them to reduce per-batch overhead in the downstream hash join.
+            let mut pending: Vec<RecordBatch> = Vec::new();
+            let mut pending_rows = 0usize;
+
+            for batch_result in reader {
+                let batch = match batch_result {
+                    Ok(b) => b,
+                    Err(e) => {
+                        let _ = tx
+                            .blocking_send(Err(DataFusionError::ArrowError(Box::new(e), None)));
+                        return;
+                    }
+                };
+                if batch.num_rows() == 0 {
+                    continue;
+                }
+                pending_rows += batch.num_rows();
+                pending.push(batch);
+
+                if pending_rows >= SPILL_READ_COALESCE_TARGET {
+                    pending_rows = 0;
+                    let merged = drain_coalesced(&coalesce_schema, &mut pending);
+                    if tx.blocking_send(merged).is_err() {
+                        // Receiver dropped: nobody wants the remaining data.
+                        return;
+                    }
+                }
+            }
+
+            // Flush whatever is left below the coalescing target.
+            if !pending.is_empty() {
+                let _ = tx.blocking_send(drain_coalesced(&coalesce_schema, &mut pending));
+            }
+        });
+
+        let batch_stream = futures::stream::unfold(
+            (rx, Some(producer)),
+            |(mut rx, producer)| async move {
+                if let Some(item) = rx.recv().await {
+                    return Some((item, (rx, producer)));
+                }
+                // The channel closes when the producer returns *or* panics;
+                // join it so a failure is reported instead of truncating.
+                let Some(producer) = producer else {
+                    return None;
+                };
+                match producer.await {
+                    Ok(()) => None,
+                    Err(e) => Some((
+                        Err(DataFusionError::Execution(format!(
+                            "Spill reader task failed: {e}"
+                        ))),
+                        (rx, None),
+                    )),
+                }
+            },
+        );
+        Ok(Box::pin(RecordBatchStreamAdapter::new(
+            stream_schema,
+            batch_stream,
+        )))
+    }
+}
+
+// ---------------------------------------------------------------------------
+// StreamSourceExec: wrap an existing stream as an ExecutionPlan
+// ---------------------------------------------------------------------------
+
+/// Leaf `ExecutionPlan` that hands out a stream which already exists.
+/// The fast path uses it to feed the probe side's live stream into a
+/// `HashJoinExec` without buffering or spilling. The wrapped stream can be
+/// handed out only once.
+struct StreamSourceExec {
+    stream: Mutex<Option<SendableRecordBatchStream>>,
+    schema: SchemaRef,
+    cache: Arc<PlanProperties>,
+}
+
+impl StreamSourceExec {
+    /// Wraps `stream`, reporting `schema` as its output schema and a single
+    /// bounded, incrementally-emitting partition.
+    fn new(stream: SendableRecordBatchStream, schema: SchemaRef) -> Self {
+        use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType};
+
+        let cache = Arc::new(PlanProperties::new(
+            EquivalenceProperties::new(Arc::clone(&schema)),
+            Partitioning::UnknownPartitioning(1),
+            EmissionType::Incremental,
+            Boundedness::Bounded,
+        ));
+        let stream = Mutex::new(Some(stream));
+        Self {
+            stream,
+            schema,
+            cache,
+        }
+    }
+}
+
+impl fmt::Debug for StreamSourceExec {
+    /// Prints only the type name; the wrapped stream is intentionally omitted.
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        f.write_str("StreamSourceExec")
+    }
+}
+
+impl DisplayAs for StreamSourceExec {
+    /// Renders just the node name, whatever display format is requested.
+    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result {
+        f.write_str("StreamSourceExec")
+    }
+}
+
+impl ExecutionPlan for StreamSourceExec {
+    fn name(&self) -> &str {
+        "StreamSourceExec"
+    }
+
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn schema(&self) -> SchemaRef {
+        Arc::clone(&self.schema)
+    }
+
+    /// Leaf node: no children.
+    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
+        Vec::new()
+    }
+
+    /// Leaf node: nothing to replace, so `self` is returned unchanged.
+    fn with_new_children(
+        self: Arc<Self>,
+        _children: Vec<Arc<dyn ExecutionPlan>>,
+    ) -> DFResult<Arc<dyn ExecutionPlan>> {
+        Ok(self)
+    }
+
+    fn properties(&self) -> &Arc<PlanProperties> {
+        &self.cache
+    }
+
+    /// Hands out the wrapped stream. Only the first call succeeds; every later
+    /// call (for any partition) fails.
+    ///
+    /// # Errors
+    /// `Internal` if the mutex is poisoned or the stream was already taken.
+    fn execute(
+        &self,
+        _partition: usize,
+        _context: Arc<TaskContext>,
+    ) -> DFResult<SendableRecordBatchStream> {
+        let mut slot = match self.stream.lock() {
+            Ok(guard) => guard,
+            Err(e) => {
+                return Err(DataFusionError::Internal(format!("lock poisoned: {e}")));
+            }
+        };
+        match slot.take() {
+            Some(stream) => Ok(stream),
+            None => Err(DataFusionError::Internal(
+                "StreamSourceExec: stream already consumed".to_string(),
+            )),
+        }
+    }
+}
+
+// ---------------------------------------------------------------------------
+// GraceHashJoinMetrics
+// ---------------------------------------------------------------------------
+
+/// Metrics reported by the Grace Hash Join operator for one partition.
+struct GraceHashJoinMetrics {
+    /// Output rows and elapsed compute time.
+    baseline: BaselineMetrics,
+    /// Time spent partitioning the build side.
+    build_time: Time,
+    /// Time spent partitioning the probe side.
+    probe_time: Time,
+    /// Number of spill events.
+    spill_count: Count,
+    /// Total bytes spilled to disk.
+    spilled_bytes: Count,
+    /// Build-side input rows.
+    build_input_rows: Count,
+    /// Build-side input batches.
+    build_input_batches: Count,
+    /// Probe-side input rows.
+    input_rows: Count,
+    /// Probe-side input batches.
+    input_batches: Count,
+    /// Output batches produced.
+    output_batches: Count,
+    /// Time spent in the per-partition joins.
+    join_time: Time,
+}
+
+impl GraceHashJoinMetrics {
+    /// Registers every metric in `metrics` for `partition`. Fields are
+    /// initialized in declaration order, which is also the order in which
+    /// the metrics are registered.
+    fn new(metrics: &ExecutionPlanMetricsSet, partition: usize) -> Self {
+        let builder = move || MetricBuilder::new(metrics);
+        Self {
+            baseline: BaselineMetrics::new(metrics, partition),
+            build_time: builder().subset_time("build_time", partition),
+            probe_time: builder().subset_time("probe_time", partition),
+            spill_count: builder().spill_count(partition),
+            spilled_bytes: builder().spilled_bytes(partition),
+            build_input_rows: builder().counter("build_input_rows", partition),
+            build_input_batches: builder().counter("build_input_batches", partition),
+            input_rows: builder().counter("input_rows", partition),
+            input_batches: builder().counter("input_batches", partition),
+            output_batches: builder().counter("output_batches", partition),
+            join_time: builder().subset_time("join_time", partition),
+        }
+    }
+}
+
+// ---------------------------------------------------------------------------
+// GraceHashJoinExec
+// ---------------------------------------------------------------------------
+
+/// Grace Hash Join execution plan.
+///
+/// Hashes both inputs into N buckets and joins each bucket independently
+/// with DataFusion's `HashJoinExec`, spilling partitions to disk when memory
+/// pressure is detected.
+#[derive(Debug)]
+pub struct GraceHashJoinExec {
+    /// Left input.
+    left: Arc<dyn ExecutionPlan>,
+    /// Right input.
+    right: Arc<dyn ExecutionPlan>,
+    /// Join key pairs as `(left_key, right_key)`.
+    on: Vec<(Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>)>,
+    /// Optional filter applied after key matching.
+    filter: Option<JoinFilter>,
+    /// Join type.
+    join_type: JoinType,
+    /// Number of hash partitions (never 0: `try_new` substitutes the default).
+    num_partitions: usize,
+    /// `true` if the left input is the build side, `false` if the right is.
+    build_left: bool,
+    /// Maximum build-side bytes for the fast path (0 disables it).
+    fast_path_threshold: usize,
+    /// Output schema.
+    schema: SchemaRef,
+    /// Cached plan properties.
+    cache: Arc<PlanProperties>,
+    /// Execution metrics.
+    metrics: ExecutionPlanMetricsSet,
+}
+
+impl GraceHashJoinExec {
+    /// Creates a grace hash join over `left` and `right`.
+    ///
+    /// The output schema and plan properties are taken from an equivalent
+    /// `HashJoinExec` in `CollectLeft` mode, which treats its left input as
+    /// the build side. When `build_left` is false that join is swapped first
+    /// so the schema matches a build-right join. The node itself still stores
+    /// the original inputs, keys and join type.
+    ///
+    /// A `num_partitions` of 0 selects `DEFAULT_NUM_PARTITIONS`.
+    ///
+    /// # Errors
+    /// Propagates any error from building or swapping the `HashJoinExec`.
+    #[allow(clippy::too_many_arguments)]
+    pub fn try_new(
+        left: Arc<dyn ExecutionPlan>,
+        right: Arc<dyn ExecutionPlan>,
+        on: Vec<(Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>)>,
+        filter: Option<JoinFilter>,
+        join_type: &JoinType,
+        num_partitions: usize,
+        build_left: bool,
+        fast_path_threshold: usize,
+    ) -> DFResult<Self> {
+        let template = HashJoinExec::try_new(
+            Arc::clone(&left),
+            Arc::clone(&right),
+            on.clone(),
+            filter.clone(),
+            join_type,
+            None,
+            PartitionMode::CollectLeft,
+            NullEquality::NullEqualsNothing,
+            false,
+        )?;
+
+        let (schema, cache) = if !build_left {
+            // Build-right: derive the output schema from the swapped join.
+            let swapped = template.swap_inputs(PartitionMode::CollectLeft)?;
+            (swapped.schema(), swapped.properties().clone())
+        } else {
+            (template.schema(), template.properties().clone())
+        };
+
+        let num_partitions = match num_partitions {
+            0 => DEFAULT_NUM_PARTITIONS,
+            requested => requested,
+        };
+
+        Ok(Self {
+            left,
+            right,
+            on,
+            filter,
+            join_type: *join_type,
+            num_partitions,
+            build_left,
+            fast_path_threshold,
+            schema,
+            cache,
+            metrics: ExecutionPlanMetricsSet::new(),
+        })
+    }
+}
+
+impl DisplayAs for GraceHashJoinExec {
+    /// One-line summary with the join type, the `(left, right)` key pairs and
+    /// the partition count. All display formats currently render the same.
+    fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result {
+        match t {
+            DisplayFormatType::Default
+            | DisplayFormatType::Verbose
+            | DisplayFormatType::TreeRender => {
+                let keys = self
+                    .on
+                    .iter()
+                    .map(|(left_key, right_key)| format!("({left_key}, {right_key})"))
+                    .collect::<Vec<_>>()
+                    .join(", ");
+                write!(
+                    f,
+                    "GraceHashJoinExec: join_type={:?}, on=[{}], num_partitions={}",
+                    self.join_type, keys, self.num_partitions,
+                )
+            }
+        }
+    }
+}
+
+impl ExecutionPlan for GraceHashJoinExec {
+    fn name(&self) -> &str {
+        "GraceHashJoinExec"
+    }
+
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn schema(&self) -> SchemaRef {
+        Arc::clone(&self.schema)
+    }
+
+    /// `[left, right]`, in that order regardless of which side builds.
+    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
+        vec![&self.left, &self.right]
+    }
+
+    /// Rebuilds this node over new `[left, right]` inputs, keeping every
+    /// other setting.
+    ///
+    /// # Errors
+    /// `Internal` if `children` does not hold exactly two plans (rather than
+    /// panicking on an out-of-bounds index), or any error from `try_new`.
+    fn with_new_children(
+        self: Arc<Self>,
+        children: Vec<Arc<dyn ExecutionPlan>>,
+    ) -> DFResult<Arc<dyn ExecutionPlan>> {
+        let [left, right]: [Arc<dyn ExecutionPlan>; 2] =
+            children
+                .try_into()
+                .map_err(|rejected: Vec<Arc<dyn ExecutionPlan>>| {
+                    DataFusionError::Internal(format!(
+                        "GraceHashJoinExec expects exactly 2 children, got {}",
+                        rejected.len()
+                    ))
+                })?;
+        Ok(Arc::new(GraceHashJoinExec::try_new(
+            left,
+            right,
+            self.on.clone(),
+            self.filter.clone(),
+            &self.join_type,
+            self.num_partitions,
+            self.build_left,
+            self.fast_path_threshold,
+        )?))
+    }
+
+    fn properties(&self) -> &Arc<PlanProperties> {
+        &self.cache
+    }
+
+    /// Executes `partition` of both inputs and joins them.
+    ///
+    /// Streams, schemas and key expressions are oriented as (build, probe)
+    /// according to `build_left` before being handed to
+    /// `execute_grace_hash_join`, together with the original `on` pairs. The
+    /// join itself only starts when the returned stream is first polled.
+    fn execute(
+        &self,
+        partition: usize,
+        context: Arc<TaskContext>,
+    ) -> DFResult<SendableRecordBatchStream> {
+        info!(
+            "GraceHashJoin: execute() called. build_left={}, join_type={:?}, \
+             num_partitions={}, fast_path_threshold={}\n  left: {}\n  right: {}",
+            self.build_left,
+            self.join_type,
+            self.num_partitions,
+            self.fast_path_threshold,
+            DisplayableExecutionPlan::new(self.left.as_ref()).one_line(),
+            DisplayableExecutionPlan::new(self.right.as_ref()).one_line(),
+        );
+        let left_stream = self.left.execute(partition, Arc::clone(&context))?;
+        let right_stream = self.right.execute(partition, Arc::clone(&context))?;
+
+        let join_metrics = GraceHashJoinMetrics::new(&self.metrics, partition);
+
+        // Split the key pairs once, then orient everything as (build, probe).
+        // The internal execution always treats the first side as build.
+        let (left_keys, right_keys): (Vec<_>, Vec<_>) = self
+            .on
+            .iter()
+            .map(|(l, r)| (Arc::clone(l), Arc::clone(r)))
+            .unzip();
+        let (build_stream, probe_stream, build_schema, probe_schema, build_on, probe_on) =
+            if self.build_left {
+                (
+                    left_stream,
+                    right_stream,
+                    self.left.schema(),
+                    self.right.schema(),
+                    left_keys,
+                    right_keys,
+                )
+            } else {
+                // Build right: right is the build side, left is the probe side.
+                (
+                    right_stream,
+                    left_stream,
+                    self.right.schema(),
+                    self.left.schema(),
+                    right_keys,
+                    left_keys,
+                )
+            };
+
+        let on = self.on.clone();
+        let filter = self.filter.clone();
+        let join_type = self.join_type;
+        let num_partitions = self.num_partitions;
+        let build_left = self.build_left;
+        let fast_path_threshold = self.fast_path_threshold;
+
+        let result_stream = futures::stream::once(async move {
+            execute_grace_hash_join(
+                build_stream,
+                probe_stream,
+                build_on,
+                probe_on,
+                on,
+                filter,
+                join_type,
+                num_partitions,
+                build_left,
+                fast_path_threshold,
+                build_schema,
+                probe_schema,
+                context,
+                join_metrics,
+            )
+            .await
+        })
+        .try_flatten();
+
+        Ok(Box::pin(RecordBatchStreamAdapter::new(
+            Arc::clone(&self.schema),
+            result_stream,
+        )))
+    }
+
+    fn metrics(&self) -> Option<MetricsSet> {
+        Some(self.metrics.clone_inner())
+    }
+}
+
+// ---------------------------------------------------------------------------
+// Per-partition state
+// ---------------------------------------------------------------------------
+
+/// State of one hash partition: buffered build/probe batches, plus the
+/// incremental spill writers used once a side has been spilled to disk.
+struct HashPartition {
+    /// In-memory build-side batches.
+    build_batches: Vec<RecordBatch>,
+    /// In-memory probe-side batches.
+    probe_batches: Vec<RecordBatch>,
+    /// Build-side spill writer; `Some` once the build side is spilling.
+    build_spill_writer: Option<SpillWriter>,
+    /// Probe-side spill writer; `Some` once the probe side is spilling.
+    probe_spill_writer: Option<SpillWriter>,
+    /// Approximate memory held by `build_batches`.
+    build_mem_size: usize,
+    /// Approximate memory held by `probe_batches`.
+    probe_mem_size: usize,
+}
+
+impl HashPartition {
+    /// An empty partition: nothing buffered and nothing spilled.
+    fn new() -> Self {
+        Self {
+            build_batches: Vec::new(),
+            build_spill_writer: None,
+            build_mem_size: 0,
+            probe_batches: Vec::new(),
+            probe_spill_writer: None,
+            probe_mem_size: 0,
+        }
+    }
+
+    /// Whether a build-side spill writer exists, i.e. the build side has been
+    /// spilled to disk.
+    fn build_spilled(&self) -> bool {
+        self.build_spill_writer.is_some()
+    }
+}
+
+// ---------------------------------------------------------------------------
+// Main execution logic
+// ---------------------------------------------------------------------------
+
+/// Main execution logic for the grace hash join.
+///
+/// `build_stream`/`probe_stream`: already swapped based on build_left.
+/// `build_keys`/`probe_keys`: key expressions for their respective sides.
+/// `original_on`: original (left_key, right_key) pairs for HashJoinExec.
+/// `build_left`: whether left is build side (affects HashJoinExec 
construction).
+#[allow(clippy::too_many_arguments)]
+async fn execute_grace_hash_join(
+    build_stream: SendableRecordBatchStream,
+    probe_stream: SendableRecordBatchStream,
+    build_keys: Vec<Arc<dyn PhysicalExpr>>,
+    probe_keys: Vec<Arc<dyn PhysicalExpr>>,
+    original_on: Vec<(Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>)>,
+    filter: Option<JoinFilter>,
+    join_type: JoinType,
+    num_partitions: usize,
+    build_left: bool,
+    fast_path_threshold: usize,
+    build_schema: SchemaRef,
+    probe_schema: SchemaRef,
+    context: Arc<TaskContext>,
+    metrics: GraceHashJoinMetrics,
+) -> DFResult<impl Stream<Item = DFResult<RecordBatch>>> {
+    let ghj_id = GHJ_INSTANCE_COUNTER.fetch_add(1, 
std::sync::atomic::Ordering::Relaxed);
+
+    // Set up memory reservation (shared across build and probe phases)
+    let mut reservation = MutableReservation(
+        MemoryConsumer::new("GraceHashJoinExec")
+            .with_can_spill(true)
+            .register(&context.runtime_env().memory_pool),
+    );
+
+    info!(
+        "GHJ#{}: started. build_left={}, join_type={:?}, pool reserved={}",
+        ghj_id,
+        build_left,
+        join_type,
+        context.runtime_env().memory_pool.reserved(),
+    );
+
+    // Optimistic fast path: if the fast path threshold is set, try buffering
+    // the build side without partitioning. This avoids the overhead of hash
+    // computation, prefix-sum, and per-partition take() for every build batch,
+    // which is wasted work when the build side fits in memory and the fast 
path
+    // is taken (the common case with a generous threshold).
+    if fast_path_threshold > 0 {
+        let build_result = {
+            let _timer = metrics.build_time.timer();
+            buffer_build_optimistic(build_stream, &mut reservation, 
&metrics).await?
+        };
+
+        match build_result {
+            BuildBufferResult::Complete(build_batches, actual_build_bytes)
+                if actual_build_bytes <= fast_path_threshold =>
+            {
+                // Fast path: all build data buffered, no memory pressure.
+                // Skip partitioning entirely and stream probe through 
HashJoinExec.
+                let total_build_rows: usize = build_batches.iter().map(|b| 
b.num_rows()).sum();
+                info!(
+                    "GHJ#{}: optimistic fast path — build side ({} rows, {} 
bytes). \
+                     Streaming probe directly through HashJoinExec. pool 
reserved={}",
+                    ghj_id,
+                    total_build_rows,
+                    actual_build_bytes,
+                    context.runtime_env().memory_pool.reserved(),
+                );
+
+                reservation.free();
+
+                // Wrap probe stream to count input_batches and input_rows
+                // (normally counted during partition_probe_side, which is
+                // skipped in the fast path).
+                let probe_input_batches = metrics.input_batches.clone();
+                let probe_input_rows = metrics.input_rows.clone();
+                let probe_schema_clone = Arc::clone(&probe_schema);
+                let counting_probe = probe_stream.inspect_ok(move |batch| {
+                    probe_input_batches.add(1);
+                    probe_input_rows.add(batch.num_rows());
+                });
+                let counting_probe: SendableRecordBatchStream = Box::pin(
+                    RecordBatchStreamAdapter::new(probe_schema_clone, 
counting_probe),
+                );
+
+                let stream = create_fast_path_stream(
+                    build_batches,
+                    counting_probe,
+                    &original_on,
+                    &filter,
+                    &join_type,
+                    build_left,
+                    &build_schema,
+                    &probe_schema,
+                    &context,
+                )?;
+
+                let output_metrics = metrics.baseline.clone();
+                let output_batch_count = metrics.output_batches.clone();
+                let join_time = metrics.join_time.clone();
+                let result_stream = stream.inspect_ok(move |batch| {
+                    let _timer = join_time.timer();
+                    output_metrics.record_output(batch.num_rows());
+                    output_batch_count.add(1);
+                });
+
+                return Ok(result_stream.boxed());
+            }
+            result => {
+                // Build side too large for fast path, or memory pressure 
occurred.
+                // Partition the buffered batches offline and continue with 
slow path.
+                let (buffered_batches, remaining_stream) = match result {
+                    BuildBufferResult::Complete(batches, _) => (batches, None),
+                    BuildBufferResult::NeedPartition(batches, stream) => 
(batches, Some(stream)),
+                };
+
+                info!(
+                    "GHJ#{}: optimistic buffer fallback — partitioning {} 
buffered batches. \
+                     pool reserved={}",
+                    ghj_id,
+                    buffered_batches.len(),
+                    context.runtime_env().memory_pool.reserved(),
+                );
+
+                // Free reservation for buffered batches; partition_from_buffer
+                // and partition_build_side will re-track per-partition memory.
+                reservation.free();
+
+                let mut partitions: Vec<HashPartition> =
+                    (0..num_partitions).map(|_| 
HashPartition::new()).collect();
+                let mut scratch = ScratchSpace::default();
+
+                {
+                    let _timer = metrics.build_time.timer();
+                    partition_from_buffer(
+                        buffered_batches,
+                        &build_keys,
+                        num_partitions,
+                        &build_schema,
+                        &mut partitions,
+                        &mut reservation,
+                        &context,
+                        &metrics,
+                        &mut scratch,
+                    )?;
+
+                    // Continue reading remaining stream if optimistic buffer 
was interrupted
+                    if let Some(remaining) = remaining_stream {
+                        partition_build_side(
+                            remaining,
+                            &build_keys,
+                            num_partitions,
+                            &build_schema,
+                            &mut partitions,
+                            &mut reservation,
+                            &context,
+                            &metrics,
+                            &mut scratch,
+                        )
+                        .await?;
+                    }
+                }
+
+                return execute_slow_path(
+                    ghj_id,
+                    partitions,
+                    probe_stream,
+                    probe_keys,
+                    original_on,
+                    filter,
+                    join_type,
+                    num_partitions,
+                    build_left,
+                    build_schema,
+                    probe_schema,
+                    context,
+                    metrics,
+                    reservation,
+                    scratch,
+                )
+                .await
+                .map(|s| s.boxed());
+            }
+        }
+    }
+
+    // Non-optimistic path: fast_path_threshold == 0 (disabled).
+    // Always partition the build side.
+    let mut partitions: Vec<HashPartition> =
+        (0..num_partitions).map(|_| HashPartition::new()).collect();
+    let mut scratch = ScratchSpace::default();
+
+    {
+        let _timer = metrics.build_time.timer();
+        partition_build_side(
+            build_stream,
+            &build_keys,
+            num_partitions,
+            &build_schema,
+            &mut partitions,
+            &mut reservation,
+            &context,
+            &metrics,
+            &mut scratch,
+        )
+        .await?;
+    }
+
+    execute_slow_path(
+        ghj_id,
+        partitions,
+        probe_stream,
+        probe_keys,
+        original_on,
+        filter,
+        join_type,
+        num_partitions,
+        build_left,
+        build_schema,
+        probe_schema,
+        context,
+        metrics,
+        reservation,
+        scratch,
+    )
+    .await
+    .map(|s| s.boxed())
+}
+
+/// Result of optimistic build-side buffering.
+enum BuildBufferResult {
+    /// All build batches buffered successfully with memory tracking.
+    Complete(Vec<RecordBatch>, usize),
+    /// Memory pressure occurred — returns buffered batches and remaining 
stream.
+    NeedPartition(Vec<RecordBatch>, SendableRecordBatchStream),
+}
+
+/// Buffer the build side without partitioning. Returns all batches and total 
bytes,
+/// or signals memory pressure with the partially-buffered data and remaining 
stream.
+async fn buffer_build_optimistic(
+    mut input: SendableRecordBatchStream,
+    reservation: &mut MutableReservation,
+    metrics: &GraceHashJoinMetrics,
+) -> DFResult<BuildBufferResult> {
+    let mut batches = Vec::new();
+    let mut total_bytes = 0usize;
+
+    while let Some(batch) = input.next().await {
+        let batch = batch?;
+        if batch.num_rows() == 0 {
+            continue;
+        }
+
+        metrics.build_input_batches.add(1);
+        metrics.build_input_rows.add(batch.num_rows());
+
+        let batch_size = batch.get_array_memory_size();
+
+        if reservation.try_grow(batch_size).is_err() {
+            // Memory pressure — return what we have and the remaining stream.
+            // The caller will partition the buffered data and continue 
streaming.
+            batches.push(batch);
+            return Ok(BuildBufferResult::NeedPartition(batches, input));
+        }
+
+        total_bytes += batch_size;
+        batches.push(batch);
+    }
+
+    Ok(BuildBufferResult::Complete(batches, total_bytes))
+}
+
+/// Partition already-buffered build batches into the partition structure.
+/// Used when the optimistic fast path falls back to the slow path.
+#[allow(clippy::too_many_arguments)]
+fn partition_from_buffer(
+    batches: Vec<RecordBatch>,
+    keys: &[Arc<dyn PhysicalExpr>],
+    num_partitions: usize,
+    schema: &SchemaRef,
+    partitions: &mut [HashPartition],
+    reservation: &mut MutableReservation,
+    context: &Arc<TaskContext>,
+    metrics: &GraceHashJoinMetrics,
+    scratch: &mut ScratchSpace,
+) -> DFResult<()> {
+    for batch in batches {
+        if batch.num_rows() == 0 {
+            continue;
+        }
+
+        let total_batch_size = batch.get_array_memory_size();
+        let total_rows = batch.num_rows();
+
+        scratch.compute_partitions(&batch, keys, num_partitions, 0)?;
+
+        #[allow(clippy::needless_range_loop)]
+        for part_idx in 0..num_partitions {
+            if scratch.partition_len(part_idx) == 0 {
+                continue;
+            }
+
+            let sub_rows = scratch.partition_len(part_idx);
+            let sub_batch = if sub_rows == total_rows {
+                batch.clone()
+            } else {
+                scratch.take_partition(&batch, part_idx)?.unwrap()
+            };
+            let batch_size = if total_rows > 0 {
+                (total_batch_size as u64 * sub_rows as u64 / total_rows as 
u64) as usize
+            } else {
+                0
+            };
+
+            if partitions[part_idx].build_spilled() {
+                if let Some(ref mut writer) = 
partitions[part_idx].build_spill_writer {
+                    writer.write_batch(&sub_batch)?;
+                }
+            } else {
+                if reservation.try_grow(batch_size).is_err() {
+                    info!(
+                        "GraceHashJoin: memory pressure during buffer 
partition, \
+                         spilling largest partition"
+                    );
+                    spill_largest_partition(partitions, schema, context, 
reservation, metrics)?;
+
+                    if reservation.try_grow(batch_size).is_err() {
+                        spill_partition_build(
+                            &mut partitions[part_idx],
+                            schema,
+                            context,
+                            reservation,
+                            metrics,
+                        )?;
+                        if let Some(ref mut writer) = 
partitions[part_idx].build_spill_writer {
+                            writer.write_batch(&sub_batch)?;
+                        }
+                        continue;
+                    }
+                }
+
+                partitions[part_idx].build_mem_size += batch_size;
+                partitions[part_idx].build_batches.push(sub_batch);
+            }
+        }
+    }
+
+    Ok(())
+}
+
+/// Create and execute a HashJoinExec, handling build_left swap logic.
+///
+/// When `build_left` is true, the left source is the build side and 
CollectLeft
+/// mode works directly. When `build_left` is false, we create the join with
+/// the original left/right order then swap inputs so the right side is 
collected.
+fn execute_hash_join(
+    left_source: Arc<dyn ExecutionPlan>,
+    right_source: Arc<dyn ExecutionPlan>,
+    original_on: JoinOnRef<'_>,
+    filter: &Option<JoinFilter>,
+    join_type: &JoinType,
+    build_left: bool,
+    context: &Arc<TaskContext>,
+) -> DFResult<SendableRecordBatchStream> {
+    if build_left {
+        let hash_join = HashJoinExec::try_new(
+            left_source,
+            right_source,
+            original_on.to_vec(),
+            filter.clone(),
+            join_type,
+            None,
+            PartitionMode::CollectLeft,
+            NullEquality::NullEqualsNothing,
+            false,
+        )?;
+        hash_join.execute(0, context_for_join_output(context))
+    } else {
+        let hash_join = Arc::new(HashJoinExec::try_new(
+            left_source,
+            right_source,
+            original_on.to_vec(),
+            filter.clone(),
+            join_type,
+            None,
+            PartitionMode::CollectLeft,
+            NullEquality::NullEqualsNothing,
+            false,
+        )?);
+        let swapped = hash_join.swap_inputs(PartitionMode::CollectLeft)?;
+        swapped.execute(0, context_for_join_output(context))
+    }
+}
+
+/// Create the fast-path HashJoinExec stream (no partitioning, no spilling).
+#[allow(clippy::too_many_arguments, clippy::type_complexity)]
+fn create_fast_path_stream(
+    build_data: Vec<RecordBatch>,
+    probe_stream: SendableRecordBatchStream,
+    original_on: &[(Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>)],
+    filter: &Option<JoinFilter>,
+    join_type: &JoinType,
+    build_left: bool,
+    build_schema: &SchemaRef,
+    probe_schema: &SchemaRef,
+    context: &Arc<TaskContext>,
+) -> DFResult<SendableRecordBatchStream> {
+    let build_source = memory_source_exec(build_data, build_schema)?;
+    let probe_source: Arc<dyn ExecutionPlan> = Arc::new(StreamSourceExec::new(
+        probe_stream,
+        Arc::clone(probe_schema),
+    ));
+
+    let (left_source, right_source): (Arc<dyn ExecutionPlan>, Arc<dyn 
ExecutionPlan>) =
+        if build_left {
+            (build_source, probe_source)
+        } else {
+            (probe_source, build_source)
+        };
+
+    execute_hash_join(
+        left_source,
+        right_source,
+        original_on,
+        filter,
+        join_type,
+        build_left,
+        context,
+    )
+}
+
+/// Execute the slow path: partition probe side, merge partitions, and join.
+#[allow(clippy::too_many_arguments)]
+async fn execute_slow_path(
+    ghj_id: usize,
+    mut partitions: Vec<HashPartition>,
+    probe_stream: SendableRecordBatchStream,
+    probe_keys: Vec<Arc<dyn PhysicalExpr>>,
+    original_on: Vec<(Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>)>,
+    filter: Option<JoinFilter>,
+    join_type: JoinType,
+    num_partitions: usize,
+    build_left: bool,
+    build_schema: SchemaRef,
+    probe_schema: SchemaRef,
+    context: Arc<TaskContext>,
+    metrics: GraceHashJoinMetrics,
+    mut reservation: MutableReservation,
+    mut scratch: ScratchSpace,
+) -> DFResult<impl Stream<Item = DFResult<RecordBatch>>> {
+    let build_spilled = partitions.iter().any(|p| p.build_spilled());
+    let actual_build_bytes: usize = partitions
+        .iter()
+        .flat_map(|p| p.build_batches.iter())
+        .map(|b| b.get_array_memory_size())
+        .sum();
+    let total_build_rows: usize = partitions
+        .iter()
+        .flat_map(|p| p.build_batches.iter())
+        .map(|b| b.num_rows())
+        .sum();
+    info!(
+        "GHJ#{}: slow path — build spilled={}, {} rows, {} bytes (actual). \
+         join_type={:?}, build_left={}. pool reserved={}. Partitioning probe 
side.",
+        ghj_id,
+        build_spilled,
+        total_build_rows,
+        actual_build_bytes,
+        join_type,
+        build_left,
+        context.runtime_env().memory_pool.reserved(),
+    );
+
+    // Phase 2: Partition the probe side
+    {
+        let _timer = metrics.probe_time.timer();
+        partition_probe_side(
+            probe_stream,
+            &probe_keys,
+            num_partitions,
+            &probe_schema,
+            &mut partitions,
+            &mut reservation,
+            &build_schema,
+            &context,
+            &metrics,
+            &mut scratch,
+        )
+        .await?;
+    }
+
+    // Log probe-side partition summary
+    {
+        let total_probe_rows: usize = partitions
+            .iter()
+            .flat_map(|p| p.probe_batches.iter())
+            .map(|b| b.num_rows())
+            .sum();
+        let total_probe_bytes: usize = partitions.iter().map(|p| 
p.probe_mem_size).sum();
+        let probe_spilled = partitions
+            .iter()
+            .filter(|p| p.probe_spill_writer.is_some())
+            .count();
+        info!(
+            "GHJ#{}: probe phase complete. \
+             total probe (in-memory): {} rows, {} bytes, {} spilled. \
+             reservation={}, pool reserved={}",
+            ghj_id,
+            total_probe_rows,
+            total_probe_bytes,
+            probe_spilled,
+            reservation.0.size(),
+            context.runtime_env().memory_pool.reserved(),
+        );
+    }
+
+    // Finish all open spill writers before reading back
+    let finished_partitions = finish_spill_writers(partitions)?;
+
+    // Merge adjacent partitions to reduce the number of HashJoinExec calls.
+    // Compute desired partition count from total build bytes.
+    let total_build_bytes: usize = finished_partitions.iter().map(|p| 
p.build_bytes).sum();
+    let desired_partitions = if total_build_bytes > 0 {
+        let desired = total_build_bytes.div_ceil(TARGET_PARTITION_BUILD_SIZE);
+        desired.max(1).min(num_partitions)
+    } else {
+        1
+    };
+    let original_partition_count = finished_partitions.len();
+    let finished_partitions = merge_finished_partitions(finished_partitions, 
desired_partitions);
+    if finished_partitions.len() < original_partition_count {
+        info!(
+            "GraceHashJoin: merged {} partitions into {} (total build {} 
bytes, \
+             target {} bytes/partition)",
+            original_partition_count,
+            finished_partitions.len(),
+            total_build_bytes,
+            TARGET_PARTITION_BUILD_SIZE,
+        );
+    }
+
+    // Release all remaining reservation before Phase 3. The in-memory
+    // partition data is now owned by finished_partitions and will be moved
+    // into per-partition HashJoinExec instances (which track memory via
+    // their own HashJoinInput reservations). Keeping our reservation alive
+    // would double-count the memory and starve other consumers.
+    info!(
+        "GHJ#{}: freeing reservation ({} bytes) before Phase 3. pool 
reserved={}",
+        ghj_id,
+        reservation.0.size(),
+        context.runtime_env().memory_pool.reserved(),
+    );
+    reservation.free();
+
+    // Phase 3: Join partitions sequentially.
+    // We use a concurrency limit of 1 to avoid creating multiple simultaneous
+    // HashJoinInput reservations per task. With multiple Spark tasks sharing
+    // the same memory pool, even modest build sides (e.g. 22 MB) can exhaust
+    // memory when many tasks run concurrent hash table builds simultaneously.
+    const MAX_CONCURRENT_PARTITIONS: usize = 1;
+    let semaphore = 
Arc::new(tokio::sync::Semaphore::new(MAX_CONCURRENT_PARTITIONS));
+    let (tx, rx) = 
mpsc::channel::<DFResult<RecordBatch>>(MAX_CONCURRENT_PARTITIONS * 2);
+
+    for partition in finished_partitions {
+        let tx = tx.clone();
+        let sem = Arc::clone(&semaphore);
+        let original_on = original_on.clone();
+        let filter = filter.clone();
+        let build_schema = Arc::clone(&build_schema);
+        let probe_schema = Arc::clone(&probe_schema);
+        let context = Arc::clone(&context);
+
+        tokio::spawn(async move {
+            let _permit = match sem.acquire().await {
+                Ok(p) => p,
+                Err(_) => return, // semaphore closed
+            };
+            match join_single_partition(
+                partition,
+                original_on,
+                filter,
+                join_type,
+                build_left,
+                build_schema,
+                probe_schema,
+                context,
+            )
+            .await
+            {
+                Ok(streams) => {
+                    for mut stream in streams {
+                        while let Some(batch) = stream.next().await {
+                            if tx.send(batch).await.is_err() {
+                                return;
+                            }
+                        }
+                    }
+                }
+                Err(e) => {
+                    let _ = tx.send(Err(e)).await;
+                }
+            }
+        });
+    }
+    drop(tx);
+
+    let output_metrics = metrics.baseline.clone();
+    let output_batch_count = metrics.output_batches.clone();
+    let join_time = metrics.join_time.clone();
+    let output_row_count = 
std::sync::Arc::new(std::sync::atomic::AtomicUsize::new(0));
+    let counter = Arc::clone(&output_row_count);
+    let jt = join_type;
+    let bl = build_left;
+    let result_stream = futures::stream::unfold(rx, |mut rx| async move {
+        rx.recv().await.map(|batch| (batch, rx))
+    })
+    .inspect_ok(move |batch| {
+        let _timer = join_time.timer();
+        output_metrics.record_output(batch.num_rows());
+        output_batch_count.add(1);
+        let prev = counter.fetch_add(batch.num_rows(), 
std::sync::atomic::Ordering::Relaxed);
+        let new_total = prev + batch.num_rows();
+        // Log every ~1M rows to detect exploding joins
+        if new_total / 1_000_000 > prev / 1_000_000 {
+            info!(
+                "GraceHashJoin: slow path output: {} rows emitted so far \
+                 (join_type={:?}, build_left={})",
+                new_total, jt, bl,
+            );
+        }
+    });
+
+    Ok(result_stream.boxed())
+}
+
+/// Wraps MemoryReservation to allow mutation through reference.
+struct MutableReservation(MemoryReservation);
+
+impl MutableReservation {
+    fn try_grow(&mut self, additional: usize) -> DFResult<()> {
+        self.0.try_grow(additional)
+    }
+
+    fn shrink(&mut self, amount: usize) {
+        self.0.shrink(amount);
+    }
+
+    fn free(&mut self) -> usize {
+        self.0.free()
+    }
+}
+
+// ---------------------------------------------------------------------------
+// ScratchSpace: reusable buffers for efficient hash partitioning
+// ---------------------------------------------------------------------------
+
/// Scratch buffers reused across batches when hash-partitioning rows.
///
/// Partitioning uses a counting / prefix-sum pass, the same approach as the
/// shuffle writer's `multi_partition.rs`. The pass groups each partition's
/// rows into one contiguous run of indices in a single sweep, instead of
/// issuing N separate `take()` kernel calls.
#[derive(Default)]
struct ScratchSpace {
    /// One hash value per input row.
    hashes: Vec<u64>,
    /// Destination partition of each input row.
    partition_ids: Vec<u32>,
    /// Input row indices, permuted so rows of the same partition are adjacent.
    partition_row_indices: Vec<u32>,
    /// Offsets into `partition_row_indices`. Partition `k` owns the half-open
    /// range `partition_starts[k]..partition_starts[k + 1]`.
    partition_starts: Vec<u32>,
}
+
+impl ScratchSpace {
+    /// Compute hashes and partition ids, then build the prefix-sum index
+    /// structures for the given batch.
+    fn compute_partitions(
+        &mut self,
+        batch: &RecordBatch,
+        keys: &[Arc<dyn PhysicalExpr>],
+        num_partitions: usize,
+        recursion_level: usize,
+    ) -> DFResult<()> {
+        let num_rows = batch.num_rows();
+
+        // Evaluate key columns
+        let key_columns: Vec<_> = keys
+            .iter()
+            .map(|expr| expr.evaluate(batch).and_then(|cv| 
cv.into_array(num_rows)))
+            .collect::<DFResult<Vec<_>>>()?;
+
+        // Hash
+        self.hashes.resize(num_rows, 0);
+        self.hashes.truncate(num_rows);
+        self.hashes.fill(0);
+        let random_state = partition_random_state(recursion_level);
+        create_hashes(&key_columns, &random_state, &mut self.hashes)?;
+
+        // Assign partition ids
+        self.partition_ids.resize(num_rows, 0);
+        for (i, hash) in self.hashes[..num_rows].iter().enumerate() {
+            self.partition_ids[i] = (*hash as u32) % (num_partitions as u32);

Review Comment:
   https://github.com/apache/datafusion/pull/21900
   This is a cool technique that could make this partition-id computation (`hash % num_partitions`) faster.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to