adriangb commented on code in PR #17632:
URL: https://github.com/apache/datafusion/pull/17632#discussion_r2365684244
########## datafusion/physical-plan/src/joins/hash_join/shared_bounds.rs:
##########
@@ -167,139 +220,301 @@ impl SharedBoundsAccumulator {
         };
         Self {
             inner: Mutex::new(SharedBoundsState {
-                bounds: Vec::with_capacity(expected_calls),
+                bounds: HashMap::with_capacity(total_partitions),
+                filter_optimized: false,
+                completed_count: 0,
             }),
-            barrier: Barrier::new(expected_calls),
+            total_partitions,
             dynamic_filter,
             on_right,
         }
     }

-    /// Create a filter expression from individual partition bounds using OR logic.
+    /// Create a bounds predicate for a single partition: (col >= min AND col <= max) for all columns.
+    /// This is used in both progressive and final filter creation.
+    /// Returns None if no bounds are available for this partition.
+    fn create_partition_bounds_predicate(
+        &self,
+        partition_bounds: &PartitionBounds,
+    ) -> Result<Option<Arc<dyn PhysicalExpr>>> {
+        let mut column_predicates = Vec::with_capacity(partition_bounds.len());
+
+        for (col_idx, right_expr) in self.on_right.iter().enumerate() {
+            if let Some(column_bounds) = partition_bounds.get_column_bounds(col_idx) {
+                // Create predicate: col >= min AND col <= max
+                let min_expr = Arc::new(BinaryExpr::new(
+                    Arc::clone(right_expr),
+                    Operator::GtEq,
+                    lit(column_bounds.min.clone()),
+                )) as Arc<dyn PhysicalExpr>;
+                let max_expr = Arc::new(BinaryExpr::new(
+                    Arc::clone(right_expr),
+                    Operator::LtEq,
+                    lit(column_bounds.max.clone()),
+                )) as Arc<dyn PhysicalExpr>;
+                let range_expr =
+                    Arc::new(BinaryExpr::new(min_expr, Operator::And, max_expr))
+                        as Arc<dyn PhysicalExpr>;
+                column_predicates.push(range_expr);
+            } else {
+                // Missing bounds for this column; the resulting predicate will be less selective but still correct
+                continue;
+            }
+        }
+
+        // Combine all column predicates for this partition with AND
+        Ok(column_predicates.into_iter().reduce(|acc, pred| {
+            Arc::new(BinaryExpr::new(acc, Operator::And, pred)) as Arc<dyn PhysicalExpr>
+        }))
+    }
+
+    /// Create a progressive filter using hash-based expressions to avoid false negatives.
+    ///
+    /// This is the heart of progressive filtering. It creates a CASE expression that applies
+    /// bounds filtering only to rows belonging to completed partitions, while safely passing
+    /// through all data from incomplete partitions.
     ///
-    /// This creates a filter where each partition's bounds form a conjunction (AND)
-    /// of column range predicates, and all partitions are combined with OR.
+    /// ## Generated Expression Structure:
+    /// ```sql
+    /// CASE hash(cols) % num_partitions
+    ///     WHEN 0 THEN (col1 >= min1 AND col1 <= max1 AND col2 >= min2 AND col2 <= max2)
+    ///     WHEN 1 THEN (col1 >= min3 AND col1 <= max3 AND col2 >= min4 AND col2 <= max4)
+    ///     ...
+    ///     ELSE true -- Critical: ensures no false negatives for incomplete partitions
+    /// END
+    /// ```
     ///
-    /// For example, with 2 partitions and 2 columns:
-    /// ((col0 >= p0_min0 AND col0 <= p0_max0 AND col1 >= p0_min1 AND col1 <= p0_max1)
-    /// OR
-    /// (col0 >= p1_min0 AND col0 <= p1_max0 AND col1 >= p1_min1 AND col1 <= p1_max1))
-    pub(crate) fn create_filter_from_partition_bounds(
+    /// ## Correctness Key Points:
+    /// - **Hash Function**: Uses the same hash as the join's partitioning scheme
+    /// - **Modulo Operation**: Maps hash values to partition IDs (0 to num_partitions-1)
+    /// - **WHEN Clauses**: Only created for partitions that have completed and reported bounds
+    /// - **ELSE true**: Ensures rows from incomplete partitions are never filtered out
+    /// - **Single Hash**: Hash is computed once per row, regardless of how many partitions completed
+    pub(crate) fn create_progressive_filter_from_partition_bounds(
         &self,
-        bounds: &[PartitionBounds],
+        bounds: &HashMap<usize, PartitionBounds>,
     ) -> Result<Arc<dyn PhysicalExpr>> {
-        if bounds.is_empty() {
-            return Ok(lit(true));
-        }
+        // Step 1: Create the partition assignment expression: hash(join_cols) % num_partitions
+        // This must match the hash function used by RepartitionExec for correctness
+        let hash_expr = repartition_hash(self.on_right.clone())?;
+        let total_partitions_expr =
+            lit(ScalarValue::UInt64(Some(self.total_partitions as u64)));
+        let modulo_expr = Arc::new(BinaryExpr::new(
+            hash_expr,
+            Operator::Modulo,
+            total_partitions_expr,
+        )) as Arc<dyn PhysicalExpr>;

-        // Create a predicate for each partition
-        let mut partition_predicates = Vec::with_capacity(bounds.len());
-
-        for partition_bounds in bounds.iter().sorted_by_key(|b| b.partition) {
-            // Create range predicates for each join key in this partition
-            let mut column_predicates = Vec::with_capacity(partition_bounds.len());
-
-            for (col_idx, right_expr) in self.on_right.iter().enumerate() {
-                if let Some(column_bounds) = partition_bounds.get_column_bounds(col_idx) {
-                    // Create predicate: col >= min AND col <= max
-                    let min_expr = Arc::new(BinaryExpr::new(
-                        Arc::clone(right_expr),
-                        Operator::GtEq,
-                        lit(column_bounds.min.clone()),
-                    )) as Arc<dyn PhysicalExpr>;
-                    let max_expr = Arc::new(BinaryExpr::new(
-                        Arc::clone(right_expr),
-                        Operator::LtEq,
-                        lit(column_bounds.max.clone()),
-                    )) as Arc<dyn PhysicalExpr>;
-                    let range_expr =
-                        Arc::new(BinaryExpr::new(min_expr, Operator::And, max_expr))
-                            as Arc<dyn PhysicalExpr>;
-                    column_predicates.push(range_expr);
+        // Step 2: Build WHEN clauses for each completed partition
+        // Format: WHEN partition_id THEN (bounds_predicate)
+        let when_thens = bounds.values().sorted_by_key(|b| b.partition).try_fold(
+            Vec::new(),
+            |mut acc, partition_bounds| {
+                // Create literal for partition ID (e.g., WHEN 0, WHEN 1, etc.)
+                let when_value =
+                    lit(ScalarValue::UInt64(Some(partition_bounds.partition as u64)));
+
+                // Create bounds predicate for this partition (e.g., col >= min AND col <= max)
+                if let Some(then_predicate) =
+                    self.create_partition_bounds_predicate(partition_bounds)?
+                {
+                    acc.push((when_value, then_predicate));
+                }
-            }
+                Ok::<_, datafusion_common::DataFusionError>(acc)
+            },
+        )?;

-            // Combine all column predicates for this partition with AND
-            if !column_predicates.is_empty() {
-                let partition_predicate = column_predicates
-                    .into_iter()
-                    .reduce(|acc, pred| {
-                        Arc::new(BinaryExpr::new(acc, Operator::And, pred))
-                            as Arc<dyn PhysicalExpr>
-                    })
-                    .unwrap();
-                partition_predicates.push(partition_predicate);
+        // Step 3: Build the complete CASE expression
+        use datafusion_physical_expr::expressions::case;
+        let expr = if when_thens.is_empty() {
+            // Edge case: no partitions have completed yet - pass everything through
+            lit(ScalarValue::Boolean(Some(true)))
+        } else {
+            // Create CASE expression with the critical ELSE true clause
+            // The ELSE true ensures we never filter out rows from incomplete partitions
+            case(
+                Some(modulo_expr), // CASE hash(cols) % num_partitions
+                when_thens,        // WHEN clauses for completed partitions
+                Some(lit(ScalarValue::Boolean(Some(true)))), // ELSE true - no false negatives!
+            )?
+        };
+
+        Ok(expr)
+    }
+
+    /// Create the final optimized filter once all partitions have completed.
+    ///
+    /// This method represents the performance optimization phase of progressive filtering.
+    /// Once all partitions have reported their bounds, we can eliminate the hash-based
+    /// CASE expression and use a simpler, more efficient bounds-only filter.
+    ///
+    /// ## Optimization Benefits:
+    /// 1. **No Hash Computation**: Eliminates expensive hash calculations per row
+    /// 2. **Simpler Expression**: OR-based bounds are faster to evaluate than CASE expressions
+    /// 3. **Better Vectorization**: Simple bounds comparisons optimize better in Arrow
+    /// 4. **Reduced CPU Overhead**: Significant performance improvement for large datasets
+    ///
+    /// ## Generated Expression Structure:
+    /// ```sql
+    /// (col1 >= min1 AND col1 <= max1) OR  -- Partition 0 bounds
+    /// (col1 >= min2 AND col1 <= max2) OR  -- Partition 1 bounds
+    /// ...
+    /// (col1 >= minN AND col1 <= maxN)     -- Partition N bounds
+    /// ```
+    ///
+    /// ## Correctness Maintained:
+    /// - Each OR clause represents the exact bounds from one partition
+    /// - Union of all partition bounds = complete build-side value range
+    /// - No false negatives: if a value exists in the build side, it passes this filter
+    /// - Same filtering effect as the progressive filter, but much more efficient
+    ///
+    /// This transformation is only applied when ALL partitions have completed, to ensure
+    /// we have complete bounds information.
+    pub(crate) fn create_optimized_filter_from_partition_bounds(
+        &self,
+        bounds: &HashMap<usize, PartitionBounds>,
+    ) -> Result<Option<Arc<dyn PhysicalExpr>>> {
+        // Build individual partition predicates - each becomes one OR clause
+        let mut partition_filters = Vec::with_capacity(bounds.len());
+
+        for partition_bounds in bounds.values().sorted_by_key(|b| b.partition) {
+            if let Some(filter) =
+                self.create_partition_bounds_predicate(partition_bounds)?
+            {
+                // This partition contributed bounds - include in optimized filter
+                partition_filters.push(filter);
             }
+            // Skip empty partitions gracefully - they don't contribute bounds but
+            // shouldn't prevent the optimization from proceeding
         }

-        // Combine all partition predicates with OR
-        let combined_predicate = partition_predicates
-            .into_iter()
-            .reduce(|acc, pred| {
-                Arc::new(BinaryExpr::new(acc, Operator::Or, pred))
-                    as Arc<dyn PhysicalExpr>
-            })
-            .unwrap_or_else(|| lit(true));
-
-        Ok(combined_predicate)
+        // Create the final OR expression: bounds_0 OR bounds_1 OR ... OR bounds_N
+        // This replaces the hash-based CASE expression with a much faster bounds-only check
+        Ok(partition_filters.into_iter().reduce(|acc, filter| {
+            Arc::new(BinaryExpr::new(acc, Operator::Or, filter)) as Arc<dyn PhysicalExpr>
+        }))
     }

-    /// Report bounds from a completed partition and update dynamic filter if all partitions are done
+    /// Report bounds from a completed partition and immediately update the dynamic filter
+    ///
+    /// This is the core method that implements progressive filtering. Unlike traditional approaches
+    /// that wait for all partitions to complete, this method immediately applies a partial filter
+    /// as soon as each partition finishes building its hash table.
+    ///
+    /// ## Progressive Filter Logic
+    ///
+    /// The method maintains correctness through careful filter design:
+    ///
+    /// **Key Insight**: We can safely filter rows that belong to completed partitions while
+    /// letting all other rows pass through, because the hash function determines partition
+    /// membership deterministically.
+    ///
+    /// ## Filter Evolution Example
+    ///
+    /// Consider a 2-partition join on column `price`:
+    ///
+    /// **Initial state**: No filter applied
+    /// ```sql
+    /// -- All probe-side rows pass through
+    /// SELECT * FROM probe_table  -- No filtering
+    /// ```
+    ///
+    /// **After Partition 0 completes** (found price range [100, 200]):
+    /// ```sql
+    /// -- Progressive filter applied
+    /// SELECT * FROM probe_table
+    /// WHERE CASE hash(price) % 2
+    ///     WHEN 0 THEN price >= 100 AND price <= 200  -- Filter partition-0 data
+    ///     ELSE true                                  -- Pass through partition-1 data
+    /// END
+    /// ```
+    /// → Filters out probe rows with price ∉ [100, 200] that hash to partition 0
+    ///
+    /// **After Partition 1 completes** (found price range [500, 600]):
+    /// ```sql
+    /// -- Final optimized filter
+    /// SELECT * FROM probe_table
+    /// WHERE (price >= 100 AND price <= 200) OR (price >= 500 AND price <= 600)
+    /// ```
+    /// → Clean bounds-only filter, no hash computation needed
+    ///
+    /// ## Correctness Guarantee
     ///
-    /// This method coordinates the dynamic filter updates across all partitions. It stores the
-    /// bounds from the current partition, increments the completion counter, and when all
-    /// partitions have reported, creates an OR'd filter from individual partition bounds.
+    /// This approach ensures **zero false negatives** (never incorrectly excludes valid joins):
     ///
-    /// This method is async and uses a [`tokio::sync::Barrier`] to wait for all partitions
-    /// to report their bounds. Once that occurs, the method will resolve for all callers and the
-    /// dynamic filter will be updated exactly once.
+    /// 1. **Completed Partitions**: Rows are filtered by actual build-side bounds
+    /// 2. **Incomplete Partitions**: All rows pass through (`ELSE true`)
+    /// 3. **Partition Assignment**: Hash function matches the join's partitioning scheme exactly
+    /// 4. **Bounds Accuracy**: Min/max values computed from actual build-side data
     ///
-    /// # Note
+    /// The filter may have **false positives** (includes rows that won't join) during the
+    /// progressive phase, but these are eliminated during the actual join operation.
     ///
-    /// As barriers are reusable, it is likely an error to call this method more times than the
-    /// total number of partitions - as it can lead to pending futures that never resolve. We rely
-    /// on correct usage from the caller rather than imposing additional checks here. If this is a concern,
-    /// consider making the resulting future shared so the ready result can be reused.
+    /// ## Concurrency Handling
+    ///
+    /// - **Thread Safety**: Uses a mutex to coordinate between concurrent partition executions
+    /// - **Deduplication**: Handles multiple reports from the same partition (CollectLeft mode)
+    /// - **Atomic Updates**: Filter updates are applied atomically to avoid inconsistent states
+    ///
+    /// ## Performance Impact
+    ///
+    /// - **Immediate Benefit**: Probe-side filtering starts as soon as the first partition completes
+    /// - **I/O Reduction**: Less data read from storage/network as build partitions complete
+    /// - **CPU Optimization**: The final filter removes hash computation overhead
+    /// - **Scalability**: No barrier synchronization delays between partitions
     ///
     /// # Arguments
     /// * `left_side_partition_id` - The identifier for the **left-side** partition reporting its bounds
     /// * `partition_bounds` - The bounds computed by this partition (if any)
     ///
     /// # Returns
     /// * `Result<()>` - Ok if successful, Err if filter update failed
-    pub(crate) async fn report_partition_bounds(
+    pub(crate) fn report_partition_bounds(
         &self,
         left_side_partition_id: usize,
         partition_bounds: Option<Vec<ColumnBounds>>,
     ) -> Result<()> {
-        // Store bounds in the accumulator - this runs once per partition
-        if let Some(bounds) = partition_bounds {
-            let mut guard = self.inner.lock();
-
-            let should_push = if let Some(last_bound) = guard.bounds.last() {
-                // In `PartitionMode::CollectLeft`, all streams on the left side share the same partition id (0).
-                // Since this function can be called multiple times for that same partition, we must deduplicate
-                // by checking against the last recorded bound.
-                last_bound.partition != left_side_partition_id
-            } else {
-                true
-            };
+        let mut inner = self.inner.lock();

-            if should_push {
-                guard
-                    .bounds
-                    .push(PartitionBounds::new(left_side_partition_id, bounds));
-            }
-        }
+        // Always increment the completion counter - every partition reports exactly once
+        inner.completed_count += 1;

-        if self.barrier.wait().await.is_leader() {
-            // All partitions have reported, so we can update the filter
-            let inner = self.inner.lock();
-            if !inner.bounds.is_empty() {
-                let filter_expr =
-                    self.create_filter_from_partition_bounds(&inner.bounds)?;
-                self.dynamic_filter.update(filter_expr)?;
-            }
+        // Store bounds from this partition (avoid duplicates)
+        // In CollectLeft mode, multiple streams may report the same partition_id,
+        // but we only want to store bounds once
+        inner
+            .bounds
+            .entry(left_side_partition_id)
+            .or_insert_with(|| {
+                if let Some(bounds) = partition_bounds {
+                    PartitionBounds::new(left_side_partition_id, bounds)
+                } else {
+                    // Insert an empty bounds entry to track this partition
+                    PartitionBounds::new(left_side_partition_id, vec![])
+                }
+            });
+
+        let completed = inner.completed_count;
+        let total = self.total_partitions;
+
+        let all_partitions_complete = completed == total;
+
+        // Create the appropriate filter based on completion status
+        let filter_expr = if all_partitions_complete && !inner.filter_optimized {
+            // All partitions complete - use the optimized filter without hash checks
+            inner.filter_optimized = true;
+            self.create_optimized_filter_from_partition_bounds(&inner.bounds)?
+        } else {
+            // Progressive phase - use the hash-based filter
+            Some(self.create_progressive_filter_from_partition_bounds(&inner.bounds)?)
+        };
+
+        // Release the lock before updating the filter to avoid holding it during the update
+        drop(inner);
+
+        // Update the dynamic filter
+        if let Some(filter_expr) = filter_expr {
+            self.dynamic_filter.update(filter_expr)?;

Review Comment:
   Will try to look and address next week
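For readers following the new code paths, here is a minimal, self-contained sketch of the expression shape the progressive path produces. It mirrors the diff's `case(Some(modulo_expr), when_thens, Some(lit(...)))` construction for a single completed partition; `progressive_filter` is a hypothetical free function (not the PR's API), and the `hash_expr` parameter stands in for the output of the PR's `repartition_hash` helper, which is not shown in this hunk:

```rust
use std::sync::Arc;

use datafusion_common::{Result, ScalarValue};
use datafusion_expr::Operator;
use datafusion_physical_expr::expressions::{case, lit, BinaryExpr};
use datafusion_physical_expr::PhysicalExpr;

/// Build `CASE hash % total WHEN <partition> THEN <bounds> ELSE true END`.
/// `hash_expr` stands in for `repartition_hash(on_right)`; `bounds_pred`
/// is one partition's `col >= min AND col <= max` predicate.
fn progressive_filter(
    hash_expr: Arc<dyn PhysicalExpr>,
    bounds_pred: Arc<dyn PhysicalExpr>,
    completed_partition: u64,
    total_partitions: u64,
) -> Result<Arc<dyn PhysicalExpr>> {
    // Partition assignment: hash(join_cols) % total_partitions.
    // This must agree with RepartitionExec's routing, or rows would be
    // checked against the wrong partition's bounds.
    let modulo = Arc::new(BinaryExpr::new(
        hash_expr,
        Operator::Modulo,
        lit(ScalarValue::UInt64(Some(total_partitions))),
    )) as Arc<dyn PhysicalExpr>;

    // One WHEN arm per completed partition; ELSE true passes through rows
    // that hash to partitions which have not reported bounds yet.
    let when_thens = vec![(
        lit(ScalarValue::UInt64(Some(completed_partition))),
        bounds_pred,
    )];
    case(Some(modulo), when_thens, Some(lit(ScalarValue::Boolean(Some(true)))))
}
```

The key property is the `ELSE true` arm: rows whose hash maps to a partition without a WHEN arm are never filtered out, so the expression can only produce false positives, never false negatives.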
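The final phase then reduces per-partition range predicates with OR. A sketch under the doc comment's example bounds ([100, 200] and [500, 600]); `range_predicate` and the Int64 `price` column are illustrative stand-ins, not code from the PR:

```rust
use std::sync::Arc;

use datafusion_common::ScalarValue;
use datafusion_expr::Operator;
use datafusion_physical_expr::expressions::{lit, BinaryExpr, Column};
use datafusion_physical_expr::PhysicalExpr;

/// Build `col >= lo AND col <= hi` for one partition's reported bounds.
fn range_predicate(
    col: &Arc<dyn PhysicalExpr>,
    lo: ScalarValue,
    hi: ScalarValue,
) -> Arc<dyn PhysicalExpr> {
    let ge = Arc::new(BinaryExpr::new(Arc::clone(col), Operator::GtEq, lit(lo)))
        as Arc<dyn PhysicalExpr>;
    let le = Arc::new(BinaryExpr::new(Arc::clone(col), Operator::LtEq, lit(hi)))
        as Arc<dyn PhysicalExpr>;
    Arc::new(BinaryExpr::new(ge, Operator::And, le))
}

fn main() {
    // Probe-side join key; index 0 in the probe schema for this sketch.
    let price: Arc<dyn PhysicalExpr> = Arc::new(Column::new("price", 0));

    // One range per completed partition (values from the doc comment example).
    let ranges = vec![
        range_predicate(&price, ScalarValue::from(100i64), ScalarValue::from(200i64)),
        range_predicate(&price, ScalarValue::from(500i64), ScalarValue::from(600i64)),
    ];

    // OR the per-partition ranges together, as the final filter does with `reduce`;
    // an empty input would yield None, i.e. no filter update.
    let filter = ranges
        .into_iter()
        .reduce(|acc, r| {
            Arc::new(BinaryExpr::new(acc, Operator::Or, r)) as Arc<dyn PhysicalExpr>
        })
        .expect("at least one partition reported bounds");

    // Prints something like:
    // price@0 >= 100 AND price@0 <= 200 OR price@0 >= 500 AND price@0 <= 600
    println!("{filter}");
}
```

With no contributed bounds, `reduce` yields `None`, which matches the method's `Result<Option<...>>` contract: no bounds, no filter update.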
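Finally, the deduplication-plus-counting bookkeeping can be modeled with std types alone. This toy `Accumulator` (hypothetical, not the PR's struct) shows why `entry(...).or_insert_with(...)` makes repeated CollectLeft reports for the same left-side partition id safe, where the old `Vec` plus last-bound check only deduplicated consecutive reports:

```rust
use std::collections::HashMap;

/// Simplified stand-in for PartitionBounds: (min, max) per join column.
type Bounds = Vec<(i64, i64)>;

#[derive(Default)]
struct Accumulator {
    bounds: HashMap<usize, Bounds>,
    completed_count: usize,
}

impl Accumulator {
    fn report(&mut self, partition: usize, bounds: Option<Bounds>) {
        // Every call counts toward completion, mirroring the diff.
        self.completed_count += 1;
        // First report for a partition id wins; later duplicates (e.g. the
        // multiple CollectLeft streams that all carry partition id 0) are
        // ignored rather than stored twice.
        self.bounds
            .entry(partition)
            .or_insert_with(|| bounds.unwrap_or_default());
    }
}

fn main() {
    let mut acc = Accumulator::default();
    acc.report(0, Some(vec![(100, 200)]));
    acc.report(0, Some(vec![(100, 200)])); // duplicate report, CollectLeft-style
    assert_eq!(acc.completed_count, 2); // both calls counted
    assert_eq!(acc.bounds.len(), 1); // bounds stored once
    println!("{:?}", acc.bounds.get(&0));
}
```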