UBarney commented on code in PR #17632:
URL: https://github.com/apache/datafusion/pull/17632#discussion_r2365668859


##########
datafusion/physical-plan/src/joins/hash_join/shared_bounds.rs:
##########
@@ -167,139 +220,301 @@ impl SharedBoundsAccumulator {
         };
         Self {
             inner: Mutex::new(SharedBoundsState {
-                bounds: Vec::with_capacity(expected_calls),
+                bounds: HashMap::with_capacity(total_partitions),
+                filter_optimized: false,
+                completed_count: 0,
             }),
-            barrier: Barrier::new(expected_calls),
+            total_partitions,
             dynamic_filter,
             on_right,
         }
     }
 
-    /// Create a filter expression from individual partition bounds using OR 
logic.
+    /// Create a bounds predicate for a single partition: (col >= min AND col 
<= max) for all columns.
+    /// This is used in both progressive and final filter creation.
+    /// Returns None if no bounds are available for this partition.
+    fn create_partition_bounds_predicate(
+        &self,
+        partition_bounds: &PartitionBounds,
+    ) -> Result<Option<Arc<dyn PhysicalExpr>>> {
+        let mut column_predicates = Vec::with_capacity(partition_bounds.len());
+
+        for (col_idx, right_expr) in self.on_right.iter().enumerate() {
+            if let Some(column_bounds) = 
partition_bounds.get_column_bounds(col_idx) {
+                // Create predicate: col >= min AND col <= max
+                let min_expr = Arc::new(BinaryExpr::new(
+                    Arc::clone(right_expr),
+                    Operator::GtEq,
+                    lit(column_bounds.min.clone()),
+                )) as Arc<dyn PhysicalExpr>;
+                let max_expr = Arc::new(BinaryExpr::new(
+                    Arc::clone(right_expr),
+                    Operator::LtEq,
+                    lit(column_bounds.max.clone()),
+                )) as Arc<dyn PhysicalExpr>;
+                let range_expr =
+                    Arc::new(BinaryExpr::new(min_expr, Operator::And, 
max_expr))
+                        as Arc<dyn PhysicalExpr>;
+                column_predicates.push(range_expr);
+            } else {
+                // Missing bounds for this column, the created predicate will 
have lower selectivity but will still be correct
+                continue;
+            }
+        }
+
+        // Combine all column predicates for this partition with AND
+        Ok(column_predicates.into_iter().reduce(|acc, pred| {
+            Arc::new(BinaryExpr::new(acc, Operator::And, pred)) as Arc<dyn 
PhysicalExpr>
+        }))
+    }
+
+    /// Create progressive filter using hash-based expressions to avoid false 
negatives.
+    ///
+    /// This is the heart of progressive filtering. It creates a CASE 
expression that applies
+    /// bounds filtering only to rows belonging to completed partitions, while 
safely passing
+    /// through all data from incomplete partitions.
     ///
-    /// This creates a filter where each partition's bounds form a conjunction 
(AND)
-    /// of column range predicates, and all partitions are combined with OR.
+    /// ## Generated Expression Structure:
+    /// ```sql
+    /// CASE hash(cols) % num_partitions
+    ///   WHEN 0 THEN (col1 >= min1 AND col1 <= max1 AND col2 >= min2 AND col2 
<= max2)
+    ///   WHEN 1 THEN (col1 >= min3 AND col1 <= max3 AND col2 >= min4 AND col2 
<= max4)
+    ///   ...
+    ///   ELSE true  -- Critical: ensures no false negatives for incomplete 
partitions
+    /// END
+    /// ```
     ///
-    /// For example, with 2 partitions and 2 columns:
-    /// ((col0 >= p0_min0 AND col0 <= p0_max0 AND col1 >= p0_min1 AND col1 <= 
p0_max1)
-    ///  OR
-    ///  (col0 >= p1_min0 AND col0 <= p1_max0 AND col1 >= p1_min1 AND col1 <= 
p1_max1))
-    pub(crate) fn create_filter_from_partition_bounds(
+    /// ## Correctness Key Points:
+    /// - **Hash Function**: Uses the same hash as the join's partitioning 
scheme
+    /// - **Modulo Operation**: Maps hash values to partition IDs (0 to 
num_partitions-1)
+    /// - **WHEN Clauses**: Only created for partitions that have completed 
and reported bounds
+    /// - **ELSE true**: Ensures rows from incomplete partitions are never 
filtered out
+    /// - **Single Hash**: Hash is computed once per row, regardless of how 
many partitions completed
+    pub(crate) fn create_progressive_filter_from_partition_bounds(
         &self,
-        bounds: &[PartitionBounds],
+        bounds: &HashMap<usize, PartitionBounds>,
     ) -> Result<Arc<dyn PhysicalExpr>> {
-        if bounds.is_empty() {
-            return Ok(lit(true));
-        }
+        // Step 1: Create the partition assignment expression: hash(join_cols) 
% num_partitions
+        // This must match the hash function used by RepartitionExec for 
correctness
+        let hash_expr = repartition_hash(self.on_right.clone())?;
+        let total_partitions_expr =
+            lit(ScalarValue::UInt64(Some(self.total_partitions as u64)));
+        let modulo_expr = Arc::new(BinaryExpr::new(
+            hash_expr,
+            Operator::Modulo,
+            total_partitions_expr,
+        )) as Arc<dyn PhysicalExpr>;
 
-        // Create a predicate for each partition
-        let mut partition_predicates = Vec::with_capacity(bounds.len());
-
-        for partition_bounds in bounds.iter().sorted_by_key(|b| b.partition) {
-            // Create range predicates for each join key in this partition
-            let mut column_predicates = 
Vec::with_capacity(partition_bounds.len());
-
-            for (col_idx, right_expr) in self.on_right.iter().enumerate() {
-                if let Some(column_bounds) = 
partition_bounds.get_column_bounds(col_idx) {
-                    // Create predicate: col >= min AND col <= max
-                    let min_expr = Arc::new(BinaryExpr::new(
-                        Arc::clone(right_expr),
-                        Operator::GtEq,
-                        lit(column_bounds.min.clone()),
-                    )) as Arc<dyn PhysicalExpr>;
-                    let max_expr = Arc::new(BinaryExpr::new(
-                        Arc::clone(right_expr),
-                        Operator::LtEq,
-                        lit(column_bounds.max.clone()),
-                    )) as Arc<dyn PhysicalExpr>;
-                    let range_expr =
-                        Arc::new(BinaryExpr::new(min_expr, Operator::And, 
max_expr))
-                            as Arc<dyn PhysicalExpr>;
-                    column_predicates.push(range_expr);
+        // Step 2: Build WHEN clauses for each completed partition
+        // Format: WHEN partition_id THEN (bounds_predicate)
+        let when_thens = bounds.values().sorted_by_key(|b| 
b.partition).try_fold(
+            Vec::new(),
+            |mut acc, partition_bounds| {
+                // Create literal for partition ID (e.g., WHEN 0, WHEN 1, etc.)
+                let when_value =
+                    lit(ScalarValue::UInt64(Some(partition_bounds.partition as 
u64)));
+
+                // Create bounds predicate for this partition (e.g., col >= 
min AND col <= max)
+                if let Some(then_predicate) =
+                    self.create_partition_bounds_predicate(partition_bounds)?
+                {
+                    acc.push((when_value, then_predicate));
                 }
-            }
+                Ok::<_, datafusion_common::DataFusionError>(acc)
+            },
+        )?;
 
-            // Combine all column predicates for this partition with AND
-            if !column_predicates.is_empty() {
-                let partition_predicate = column_predicates
-                    .into_iter()
-                    .reduce(|acc, pred| {
-                        Arc::new(BinaryExpr::new(acc, Operator::And, pred))
-                            as Arc<dyn PhysicalExpr>
-                    })
-                    .unwrap();
-                partition_predicates.push(partition_predicate);
+        // Step 3: Build the complete CASE expression
+        use datafusion_physical_expr::expressions::case;
+        let expr = if when_thens.is_empty() {
+            // Edge case: No partitions have completed yet - pass everything 
through
+            lit(ScalarValue::Boolean(Some(true)))
+        } else {
+            // Create CASE expression with critical ELSE true clause
+            // The ELSE true ensures we never filter out rows from incomplete 
partitions
+            case(
+                Some(modulo_expr),    // CASE hash(cols) % num_partitions
+                when_thens,           // WHEN clauses for completed partitions
+                Some(lit(ScalarValue::Boolean(Some(true)))),  // ELSE true - 
no false negatives!
+            )?
+        };
+
+        Ok(expr)
+    }
+
+    /// Create final optimized filter when all partitions have completed
+    ///
+    /// This method represents the performance optimization phase of 
progressive filtering.
+    /// Once all partitions have reported their bounds, we can eliminate the 
hash-based
+    /// CASE expression and use a simpler, more efficient bounds-only filter.
+    ///
+    /// ## Optimization Benefits:
+    /// 1. **No Hash Computation**: Eliminates expensive hash calculations per 
row
+    /// 2. **Simpler Expression**: OR-based bounds are faster to evaluate than 
CASE expressions
+    /// 3. **Better Vectorization**: Simple bounds comparisons optimize better 
in Arrow
+    /// 4. **Reduced CPU Overhead**: Significant performance improvement for 
large datasets
+    ///
+    /// ## Generated Expression Structure:
+    /// ```sql
+    /// (col1 >= min1 AND col1 <= max1) OR    -- Partition 0 bounds
+    /// (col1 >= min2 AND col1 <= max2) OR    -- Partition 1 bounds
+    /// ...
+    /// (col1 >= minN AND col1 <= maxN)       -- Partition N bounds
+    /// ```
+    ///
+    /// ## Correctness Maintained:
+    /// - Each OR clause represents the exact bounds from one partition
+    /// - Union of all partition bounds = complete build-side value range
+    /// - No false negatives: if a value exists in build side, it passes this 
filter
+    /// - Same filtering effect as progressive filter, but much more efficient
+    ///
+    /// This transformation is only applied when ALL partitions have completed 
to ensure
+    /// we have complete bounds information.
+    pub(crate) fn create_optimized_filter_from_partition_bounds(
+        &self,
+        bounds: &HashMap<usize, PartitionBounds>,
+    ) -> Result<Option<Arc<dyn PhysicalExpr>>> {
+        // Build individual partition predicates - each becomes one OR clause
+        let mut partition_filters = Vec::with_capacity(bounds.len());
+
+        for partition_bounds in bounds.values().sorted_by_key(|b| b.partition) 
{
+            if let Some(filter) =
+                self.create_partition_bounds_predicate(partition_bounds)?
+            {
+                // This partition contributed bounds - include in optimized 
filter
+                partition_filters.push(filter);
             }
+            // Skip empty partitions gracefully - they don't contribute bounds 
but
+            // shouldn't prevent the optimization from proceeding
         }
 
-        // Combine all partition predicates with OR
-        let combined_predicate = partition_predicates
-            .into_iter()
-            .reduce(|acc, pred| {
-                Arc::new(BinaryExpr::new(acc, Operator::Or, pred))
-                    as Arc<dyn PhysicalExpr>
-            })
-            .unwrap_or_else(|| lit(true));
-
-        Ok(combined_predicate)
+        // Create the final OR expression: bounds_0 OR bounds_1 OR ... OR 
bounds_N
+        // This replaces the hash-based CASE expression with a much faster 
bounds-only check
+        Ok(partition_filters.into_iter().reduce(|acc, filter| {
+            Arc::new(BinaryExpr::new(acc, Operator::Or, filter)) as Arc<dyn 
PhysicalExpr>
+        }))
     }
 
-    /// Report bounds from a completed partition and update dynamic filter if 
all partitions are done
+    /// Report bounds from a completed partition and immediately update the 
dynamic filter
+    ///
+    /// This is the core method that implements progressive filtering. Unlike 
traditional approaches
+    /// that wait for all partitions to complete, this method immediately 
applies a partial filter
+    /// as soon as each partition finishes building its hash table.
+    ///
+    /// ## Progressive Filter Logic
+    ///
+    /// The method maintains correctness through careful filter design:
+    ///
+    /// **Key Insight**: We can safely filter rows that belong to completed 
partitions while
+    /// letting all other rows pass through, because the hash function 
determines partition
+    /// membership deterministically.
+    ///
+    /// ## Filter Evolution Example
+    ///
+    /// Consider a 2-partition join on column `price`:
+    ///
+    /// **Initial state**: No filter applied
+    /// ```sql
+    /// -- All probe-side rows pass through
+    /// SELECT * FROM probe_table  -- No filtering
+    /// ```
+    ///
+    /// **After Partition 0 completes** (found price range [100, 200]):
+    /// ```sql
+    /// -- Progressive filter applied
+    /// SELECT * FROM probe_table
+    /// WHERE CASE hash(price) % 2
+    ///         WHEN 0 THEN price >= 100 AND price <= 200  -- Filter 
partition-0 data
+    ///         ELSE true                                   -- Pass through 
partition-1 data
+    ///       END
+    /// ```
+    /// → Filters out probe rows with price ∉ [100, 200] that hash to 
partition 0
+    ///
+    /// **After Partition 1 completes** (found price range [500, 600]):
+    /// ```sql
+    /// -- Final optimized filter
+    /// SELECT * FROM probe_table
+    /// WHERE (price >= 100 AND price <= 200) OR (price >= 500 AND price <= 
600)
+    /// ```
+    /// → Clean bounds-only filter, no hash computation needed
+    ///
+    /// ## Correctness Guarantee
     ///
-    /// This method coordinates the dynamic filter updates across all 
partitions. It stores the
-    /// bounds from the current partition, increments the completion counter, 
and when all
-    /// partitions have reported, creates an OR'd filter from individual 
partition bounds.
+    /// This approach ensures **zero false negatives** (never incorrectly 
excludes valid joins):
     ///
-    /// This method is async and uses a [`tokio::sync::Barrier`] to wait for 
all partitions
-    /// to report their bounds. Once that occurs, the method will resolve for 
all callers and the
-    /// dynamic filter will be updated exactly once.
+    /// 1. **Completed Partitions**: Rows are filtered by actual build-side 
bounds
+    /// 2. **Incomplete Partitions**: All rows pass through (`ELSE true`)
+    /// 3. **Partition Assignment**: Hash function matches the join's 
partitioning scheme exactly
+    /// 4. **Bounds Accuracy**: Min/max values computed from actual build-side 
data
     ///
-    /// # Note
+    /// The filter may have **false positives** (includes rows that won't 
join) during the
+    /// progressive phase, but these are eliminated during the actual join 
operation.
     ///
-    /// As barriers are reusable, it is likely an error to call this method 
more times than the
-    /// total number of partitions - as it can lead to pending futures that 
never resolve. We rely
-    /// on correct usage from the caller rather than imposing additional 
checks here. If this is a concern,
-    /// consider making the resulting future shared so the ready result can be 
reused.
+    /// ## Concurrency Handling
+    ///
+    /// - **Thread Safety**: Uses mutex to coordinate between concurrent 
partition executions
+    /// - **Deduplication**: Handles multiple reports from same partition 
(CollectLeft mode)
+    /// - **Atomic Updates**: Filter updates are applied atomically to avoid 
inconsistent states
+    ///
+    /// ## Performance Impact
+    ///
+    /// - **Immediate Benefit**: Probe-side filtering starts as soon as first 
partition completes
+    /// - **I/O Reduction**: Less data read from storage/network as build 
partitions complete
+    /// - **CPU Optimization**: Final filter removes hash computation overhead
+    /// - **Scalability**: No barrier synchronization delays between partitions
     ///
     /// # Arguments
     /// * `left_side_partition_id` - The identifier for the **left-side** 
partition reporting its bounds
     /// * `partition_bounds` - The bounds computed by this partition (if any)
     ///
     /// # Returns
     /// * `Result<()>` - Ok if successful, Err if filter update failed
-    pub(crate) async fn report_partition_bounds(
+    pub(crate) fn report_partition_bounds(
         &self,
         left_side_partition_id: usize,
         partition_bounds: Option<Vec<ColumnBounds>>,
     ) -> Result<()> {
-        // Store bounds in the accumulator - this runs once per partition
-        if let Some(bounds) = partition_bounds {
-            let mut guard = self.inner.lock();
-
-            let should_push = if let Some(last_bound) = guard.bounds.last() {
-                // In `PartitionMode::CollectLeft`, all streams on the left 
side share the same partition id (0).
-                // Since this function can be called multiple times for that 
same partition, we must deduplicate
-                // by checking against the last recorded bound.
-                last_bound.partition != left_side_partition_id
-            } else {
-                true
-            };
+        let mut inner = self.inner.lock();
 
-            if should_push {
-                guard
-                    .bounds
-                    .push(PartitionBounds::new(left_side_partition_id, 
bounds));
-            }
-        }
+        // Always increment completion counter - every partition reports 
exactly once
+        inner.completed_count += 1;
 
-        if self.barrier.wait().await.is_leader() {
-            // All partitions have reported, so we can update the filter
-            let inner = self.inner.lock();
-            if !inner.bounds.is_empty() {
-                let filter_expr =
-                    self.create_filter_from_partition_bounds(&inner.bounds)?;
-                self.dynamic_filter.update(filter_expr)?;
-            }
+        // Store bounds from this partition (avoid duplicates)
+        // In CollectLeft mode, multiple streams may report the same 
partition_id,
+        // but we only want to store bounds once
+        inner
+            .bounds
+            .entry(left_side_partition_id)
+            .or_insert_with(|| {
+                if let Some(bounds) = partition_bounds {
+                    PartitionBounds::new(left_side_partition_id, bounds)
+                } else {
+                    // Insert an empty bounds entry to track this partition
+                    PartitionBounds::new(left_side_partition_id, vec![])
+                }
+            });
+
+        let completed = inner.completed_count;
+        let total = self.total_partitions;
+
+        let all_partitions_complete = completed == total;
+
+        // Create the appropriate filter based on completion status
+        let filter_expr = if all_partitions_complete && 
!inner.filter_optimized {
+            // All partitions complete - use optimized filter without hash 
checks
+            inner.filter_optimized = true;
+            self.create_optimized_filter_from_partition_bounds(&inner.bounds)?
+        } else {
+            // Progressive phase - use hash-based filter
+            
Some(self.create_progressive_filter_from_partition_bounds(&inner.bounds)?)
+        };
+
+        // Release lock before updating filter to avoid holding it during the 
update
+        drop(inner);
+
+        // Update the dynamic filter
+        if let Some(filter_expr) = filter_expr {
+            self.dynamic_filter.update(filter_expr)?;

Review Comment:
   
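   The manual `drop(inner)` before `self.dynamic_filter.update()`, which releases the `self.inner` mutex guard, may create a race condition.
   
   This opens a window in which the final, optimized filter computed by one thread can be overwritten by a stale progressive-phase filter. That stale filter was computed earlier by another thread whose call to `update()` happens to land later. It also contradicts the "Atomic Updates" point in this method's doc comment. One option is to keep the guard alive until `update()` returns, so that computing and publishing the filter happen under the same lock (this assumes `update()` never calls back into this accumulator). Another option is to record the `completed_count` each filter was built from and skip updates older than the last one applied.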
   
   
   ```
         +----------------------------------------------------------------------------------+
   Time  |      Thread 1                      |      Thread 2                      |      Thread 3
    |    |                                    |                                    |
    V    | self.inner.lock() [ACQUIRED]       |                                    |
         |   +-> Calculates filter_expr_1     |                                    |
         | drop(guard) [RELEASED]             |                                    |
         |                                    |                                    |
         |                                    | self.inner.lock() [ACQUIRED]       |
         |                                    |   +-> Calculates filter_expr_2     |
         |                                    | drop(guard) [RELEASED]             |
         |                                    |                                    |
    |    |                                    |                                    |
    |    |                                    |                                    |
    V    | self.dynamic_filter.update(...)    |                                    | self.inner.lock() [ACQUIRED]
         |   |                                |                                    |   |
         |   +-> Acquires write lock          |                                    |   |
         |                                    |                                    |   +-> Calculates filter_expr_3
         |                                    |                                    |       (The **optimal** one)
         |                                    |                                    |
         |                                    |                                    |
         |   +-> Writes filter_1 to state     |                                    |
         |       (Current State: filter_1)    |                                    |
         |                                    |                                    |
         |   +-> Releases write lock          |                                    | drop(guard) [RELEASED]
         |                                    |                                    |
    |    +----------------------------------------------------------------------------------+
    V    |                                                                                  |
         | Thread 1's update is complete. Now Thread 2 and 3 race to apply their updates.   |
         +----------------------------------------------------------------------------------+
    |    |                                    |                                    |
    V    |                                    |                                    | self.dynamic_filter.update(...)
         |                                    |                                    |   |
         |                                    |                                    |   +-> Acquires write lock
         |                                    |                                    |   +-> Writes filter_3
         |                                    |                                    |       (Current State: filter_3)
         |                                    |                                    |   +-> Releases write lock
         |                                    |                                    |
         |                                    | self.dynamic_filter.update(...)    |
         |                                    |   |                                |
         |                                    |   +-> Acquires write lock          |
         |                                    |   +-> Writes filter_2 (stale)      |
         |                                    |       (Current State: filter_2)    |
         |                                    |   +-> Releases write lock          |
         |                                    |                                    |
    |    |                                    |                                    |
    V    +----------------------------------------------------------------------------------+
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org
For additional commands, e-mail: github-h...@datafusion.apache.org

Reply via email to