pepijnve commented on code in PR #18152:
URL: https://github.com/apache/datafusion/pull/18152#discussion_r2466125372


##########
datafusion/physical-expr/src/expressions/case.rs:
##########
@@ -196,82 +577,135 @@ impl CaseExpr {
     /// END
     fn case_when_with_expr(&self, batch: &RecordBatch) -> 
Result<ColumnarValue> {
         let return_type = self.data_type(&batch.schema())?;
-        let expr = self.expr.as_ref().unwrap();
-        let base_value = expr.evaluate(batch)?;
-        let base_value = base_value.into_array(batch.num_rows())?;
+        let mut result_builder = ResultBuilder::new(&return_type, 
batch.num_rows());
+
+        // `remainder_rows` contains the indices of the rows that need to be 
evaluated
+        let mut remainder_rows: ArrayRef =
+            Arc::new(UInt32Array::from_iter_values(0..batch.num_rows() as 
u32));
+        // `remainder_batch` contains the rows themselves that need to be 
evaluated
+        let mut remainder_batch = Cow::Borrowed(batch);
+
+        // evaluate the base expression
+        let mut base_value = self
+            .expr
+            .as_ref()
+            .unwrap()
+            .evaluate(batch)?
+            .into_array(batch.num_rows())?;
+
+        // Fill in a result value already for rows where the base expression 
value is null
+        // Since each when expression is tested against the base expression 
using the equality
+        // operator, null base values can never match any when expression. `x 
= NULL` is falsy,
+        // for all possible values of `x`.
         let base_nulls = is_null(base_value.as_ref())?;
+        if base_nulls.true_count() > 0 {
+            // If there is an else expression, use that as the default value 
for the null rows
+            // Otherwise the default `null` value from the result builder will 
be used.
+            if let Some(e) = self.else_expr() {
+                let expr = try_cast(Arc::clone(e), &batch.schema(), 
return_type.clone())?;
 
-        // start with nulls as default output
-        let mut current_value = new_null_array(&return_type, batch.num_rows());
-        // We only consider non-null values while comparing with whens
-        let mut remainder = not(&base_nulls)?;
-        let mut non_null_remainder_count = remainder.true_count();
-        for i in 0..self.when_then_expr.len() {
-            // If there are no rows left to process, break out of the loop 
early
-            if non_null_remainder_count == 0 {
-                break;
+                if base_nulls.true_count() == remainder_batch.num_rows() {
+                    // All base values were null, so no need to filter
+                    let nulls_value = expr.evaluate(&remainder_batch)?;
+                    result_builder.add_branch_result(&remainder_rows, 
nulls_value)?;
+                } else {
+                    let nulls_filter = create_filter(&base_nulls);
+                    let nulls_batch =
+                        filter_record_batch(&remainder_batch, &nulls_filter)?;
+                    let nulls_rows = filter_array(&remainder_rows, 
&nulls_filter)?;
+                    let nulls_value = expr.evaluate(&nulls_batch)?;
+                    result_builder.add_branch_result(&nulls_rows, 
nulls_value)?;
+                }
             }
 
-            let when_predicate = &self.when_then_expr[i].0;
-            let when_value = when_predicate.evaluate_selection(batch, 
&remainder)?;
-            let when_value = when_value.into_array(batch.num_rows())?;
-            // build boolean array representing which rows match the "when" 
value
-            let when_match = compare_with_eq(
-                &when_value,
-                &base_value,
-                // The types of case and when expressions will be coerced to 
match.
-                // We only need to check if the base_value is nested.
-                base_value.data_type().is_nested(),
-            )?;
-            // Treat nulls as false
-            let when_match = match when_match.null_count() {
-                0 => Cow::Borrowed(&when_match),
-                _ => Cow::Owned(prep_null_mask_filter(&when_match)),
-            };
-            // Make sure we only consider rows that have not been matched yet
-            let when_value = and(&when_match, &remainder)?;
+            // All base values were null, so we can return early
+            if base_nulls.true_count() == remainder_batch.num_rows() {
+                return result_builder.finish();
+            }
+
+            // Remove the null rows from the remainder batch
+            let not_null_filter = create_filter(&not(&base_nulls)?);
+            remainder_batch =
+                Cow::Owned(filter_record_batch(&remainder_batch, 
&not_null_filter)?);
+            remainder_rows = filter_array(&remainder_rows, &not_null_filter)?;
+            base_value = filter_array(&base_value, &not_null_filter)?;
+        }
+
+        // The types of case and when expressions will be coerced to match.
+        // We only need to check if the base_value is nested.
+        let base_value_is_nested = base_value.data_type().is_nested();
+
+        for i in 0..self.when_then_expr.len() {
+            // Evaluate the 'when' predicate for the remainder batch
+            // This results in a boolean array with the same length as the 
remaining number of rows
+            let when_expr = &self.when_then_expr[i].0;
+            let when_value = match when_expr.evaluate(&remainder_batch)? {
+                ColumnarValue::Array(a) => {
+                    compare_with_eq(&a, &base_value, base_value_is_nested)
+                }
+                ColumnarValue::Scalar(s) => {
+                    let scalar = Scalar::new(s.to_array()?);
+                    compare_with_eq(&scalar, &base_value, base_value_is_nested)
+                }
+            }?;
 
-            // If the predicate did not match any rows, continue to the next 
branch immediately
             let when_match_count = when_value.true_count();
+
+            // If the 'when' predicate did not match any rows, continue to the 
next branch immediately
             if when_match_count == 0 {
                 continue;
             }
 
-            let then_expression = &self.when_then_expr[i].1;
-            let then_value = then_expression.evaluate_selection(batch, 
&when_value)?;
+            // If the 'when' predicate matched all remaining rows, there is no 
need to filter
+            if when_match_count == remainder_batch.num_rows() {
+                let then_expression = &self.when_then_expr[i].1;
+                let then_value = then_expression.evaluate(&remainder_batch)?;
+                result_builder.add_branch_result(&remainder_rows, then_value)?;
+                return result_builder.finish();
+            }
 
-            current_value = match then_value {
-                ColumnarValue::Scalar(ScalarValue::Null) => {
-                    nullif(current_value.as_ref(), &when_value)?
-                }
-                ColumnarValue::Scalar(then_value) => {
-                    zip(&when_value, &then_value.to_scalar()?, &current_value)?
-                }
-                ColumnarValue::Array(then_value) => {
-                    zip(&when_value, &then_value, &current_value)?
-                }
+            // Make sure 'NULL' is treated as false
+            let when_value = match when_value.null_count() {
+                0 => when_value,
+                _ => prep_null_mask_filter(&when_value),
             };
 
-            remainder = and_not(&remainder, &when_value)?;
-            non_null_remainder_count -= when_match_count;
-        }
+            // Filter the remainder batch based on the 'when' value
+            // This results in a batch containing only the rows that need to 
be evaluated
+            // for the current branch
+            let then_filter = create_filter(&when_value);
+            let then_batch = filter_record_batch(&remainder_batch, 
&then_filter)?;
+            let then_rows = filter_array(&remainder_rows, &then_filter)?;
 
-        if let Some(e) = self.else_expr() {
-            // null and unmatched tuples should be assigned else value
-            remainder = or(&base_nulls, &remainder)?;
-
-            if remainder.true_count() > 0 {
-                // keep `else_expr`'s data type and return type consistent
-                let expr = try_cast(Arc::clone(e), &batch.schema(), 
return_type.clone())?;
+            let then_expression = &self.when_then_expr[i].1;
+            let then_value = then_expression.evaluate(&then_batch)?;
+            result_builder.add_branch_result(&then_rows, then_value)?;
 
-                let else_ = expr
-                    .evaluate_selection(batch, &remainder)?
-                    .into_array(batch.num_rows())?;
-                current_value = zip(&remainder, &else_, &current_value)?;
+            // If this is the last 'when' branch and there is no 'else' 
expression, there's no
+            // point in calculating the remaining rows.
+            if self.else_expr.is_none() && i == self.when_then_expr.len() - 1 {
+                return result_builder.finish();
             }

Review Comment:
   I went for an early return (with the minor code duplication that comes with 
it) to ensure that `remainder_batch` is not processed again, since we're 
skipping the bit of code that would make it empty. Using `break` here would 
break (pun not intended) the invariant that, at every point in the function, 
`remainder_rows` is the set of rows that still need to be processed.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to