alamb commented on code in PR #18329:
URL: https://github.com/apache/datafusion/pull/18329#discussion_r2475217780


##########
datafusion/physical-expr/src/expressions/case.rs:
##########
@@ -68,7 +70,106 @@ enum EvalMethod {
     /// if there is just one when/then pair and both the `then` and `else` are 
expressions
     ///
     /// CASE WHEN condition THEN expression ELSE expression END
-    ExpressionOrExpression,
+    ExpressionOrExpression(ProjectedCaseBody),
+}
+
+/// The body of a CASE expression which consists of an optional base 
expression, the "when/then"
+/// branches and an optional "else" branch.
+#[derive(Debug, Hash, PartialEq, Eq)]
+struct CaseBody {
+    /// Optional base expression that can be compared to literal values in the 
"when" expressions
+    expr: Option<Arc<dyn PhysicalExpr>>,
+    /// One or more when/then expressions
+    when_then_expr: Vec<WhenThen>,
+    /// Optional "else" expression
+    else_expr: Option<Arc<dyn PhysicalExpr>>,
+}
+
+impl CaseBody {
+    /// Derives a [ProjectedCaseBody] from this [CaseBody].
+    fn project(&self) -> Result<ProjectedCaseBody> {
+        // Determine the set of columns that are used in all the expressions 
of the case body.
+        let mut used_column_indices = HashSet::<usize>::new();
+        let mut collect_column_indices = |expr: &Arc<dyn PhysicalExpr>| {
+            expr.apply(|expr| {
+                if let Some(column) = expr.as_any().downcast_ref::<Column>() {
+                    used_column_indices.insert(column.index());
+                }
+                Ok(TreeNodeRecursion::Continue)
+            })
+            .expect("Closure cannot fail");
+        };
+
+        if let Some(e) = &self.expr {
+            collect_column_indices(e);
+        }
+        self.when_then_expr.iter().for_each(|(w, t)| {
+            collect_column_indices(w);
+            collect_column_indices(t);
+        });
+        if let Some(e) = &self.else_expr {
+            collect_column_indices(e);
+        }

Review Comment:
   > . I haven't reviewed this PR in detail but there may be other helper bits 
that you can use and generally it would be nice if we coalesce projection 
manipulation into ProjectionExprs because I feel like there's a lot of 
duplicate code in random places right now (obviously needs to be balanced with 
keeping the API surface area on ProjectionExprs reasonable).
   
   I also agree it feels like there is lots of random remapping code floating 
around
   
   However, that being said it is not a problem this PR introduces (though it 
may make it slightly worse)
   
   



##########
datafusion/physical-expr/src/expressions/case.rs:
##########
@@ -845,6 +935,79 @@ impl CaseExpr {
         result_builder.finish()
     }
 
+    /// See [CaseExpr::expr_or_expr].
+    fn expr_or_expr(
+        &self,
+        batch: &RecordBatch,
+        when_value: &BooleanArray,
+        return_type: &DataType,
+    ) -> Result<ColumnarValue> {
+        let then_value = self.when_then_expr[0]
+            .1
+            .evaluate_selection(batch, when_value)?
+            .into_array(batch.num_rows())?;
+
+        // evaluate else expression on the values not covered by when_value
+        let remainder = not(when_value)?;
+        let e = self.else_expr.as_ref().unwrap();
+        // keep `else_expr`'s data type and return type consistent
+        let expr = try_cast(Arc::clone(e), &batch.schema(), 
return_type.clone())
+            .unwrap_or_else(|_| Arc::clone(e));
+        let else_ = expr
+            .evaluate_selection(batch, &remainder)?
+            .into_array(batch.num_rows())?;
+
+        Ok(ColumnarValue::Array(zip(&remainder, &else_, &then_value)?))
+    }
+}
+
+impl CaseExpr {
+    /// This function evaluates the form of CASE that matches an expression to 
fixed values.
+    ///
+    /// CASE expression
+    ///     WHEN value THEN result
+    ///     [WHEN ...]
+    ///     [ELSE result]
+    /// END
+    fn case_when_with_expr(
+        &self,
+        batch: &RecordBatch,
+        projected: &ProjectedCaseBody,
+    ) -> Result<ColumnarValue> {
+        let return_type = self.data_type(&batch.schema())?;
+        if projected.projection.len() < batch.num_columns() {

Review Comment:
   I was confused about this check -- when would `projection.len()` not be less 
than `batch.num_columns()`? When all the input columns are used?
   
   If that is correct, can you please add some comments about that invariant in 
`ProjectedCaseBody`? Or maybe even better you could represent the idea of no 
projection more explicitly
   
   Perhaps something like 
   ```rust
   pub struct CaseExpr {
       /// The case expression body
       body: CaseBody,
       /// Optional projection to apply
       projection: Option<Projection>,
       /// Evaluation method to use
       eval_method: EvalMethod,
   }
   ```



##########
datafusion/physical-expr/src/expressions/case.rs:
##########
@@ -46,13 +48,13 @@ enum EvalMethod {
     ///      [WHEN ...]
     ///      [ELSE result]
     /// END
-    NoExpression,
+    NoExpression(ProjectedCaseBody),

Review Comment:
   I am somewhat concerned about the duplication here -- the EvalMethod has a 
CaseBody (embedded in ProjectedCaseBody) but the CaseExpr also (still) has a 
`CaseBody` -- could they get out of sync?



##########
datafusion/physical-expr/src/expressions/case.rs:
##########
@@ -68,7 +70,104 @@ enum EvalMethod {
     /// if there is just one when/then pair and both the `then` and `else` are 
expressions
     ///
     /// CASE WHEN condition THEN expression ELSE expression END
-    ExpressionOrExpression,
+    ExpressionOrExpression(ProjectedCaseBody),
+}
+
+/// The body of a CASE expression which consists of an optional base 
expression, the "when/then"
+/// branches and an optional "else" branch.
+#[derive(Debug, Hash, PartialEq, Eq)]
+struct CaseBody {
+    /// Optional base expression that can be compared to literal values in the 
"when" expressions
+    expr: Option<Arc<dyn PhysicalExpr>>,
+    /// One or more when/then expressions
+    when_then_expr: Vec<WhenThen>,
+    /// Optional "else" expression
+    else_expr: Option<Arc<dyn PhysicalExpr>>,
+}
+
+impl CaseBody {
+    /// Derives a [ProjectedCaseBody] from this [CaseBody].
+    fn project(&self) -> Result<ProjectedCaseBody> {
+        // Determine the set of columns that are used in all the expressions 
of the case body.
+        let mut used_column_indices = HashSet::<usize>::new();
+        let mut collect_column_indices = |expr: &Arc<dyn PhysicalExpr>| {
+            expr.apply(|expr| {
+                if let Some(column) = expr.as_any().downcast_ref::<Column>() {
+                    used_column_indices.insert(column.index());
+                }
+                Ok(TreeNodeRecursion::Continue)
+            })
+            .expect("Closure cannot fail");
+        };
+
+        if let Some(e) = &self.expr {
+            collect_column_indices(e);
+        }
+        self.when_then_expr.iter().for_each(|(w, t)| {
+            collect_column_indices(w);
+            collect_column_indices(t);
+        });
+        if let Some(e) = &self.else_expr {
+            collect_column_indices(e);
+        }
+
+        // Construct a mapping from the original column index to the projected 
column index.
+        let column_index_map = used_column_indices
+            .iter()
+            .enumerate()
+            .map(|(projected, original)| (*original, projected))
+            .collect::<HashMap<usize, usize>>();
+
+        // Construct the projected body by rewriting each expression from the 
original body
+        // using the column index mapping.
+        let project = |expr: &Arc<dyn PhysicalExpr>| -> Result<Arc<dyn 
PhysicalExpr>> {
+            Arc::clone(expr)
+                .transform_down(|e| {
+                    if let Some(column) = e.as_any().downcast_ref::<Column>() {
+                        let original = column.index();
+                        let projected = 
*column_index_map.get(&original).unwrap();
+                        if projected != original {
+                            return Ok(Transformed::yes(Arc::new(Column::new(
+                                column.name(),
+                                projected,
+                            ))));
+                        }
+                    }
+                    Ok(Transformed::no(e))
+                })
+                .map(|t| t.data)
+        };
+
+        let projected_body = CaseBody {
+            expr: self.expr.as_ref().map(project).transpose()?,
+            when_then_expr: self
+                .when_then_expr
+                .iter()
+                .map(|(e, t)| Ok((project(e)?, project(t)?)))
+                .collect::<Result<Vec<_>>>()?,
+            else_expr: self.else_expr.as_ref().map(project).transpose()?,
+        };
+
+        // Construct the projection vector
+        let projection = column_index_map
+            .iter()
+            .sorted_by_key(|(_, v)| **v)
+            .map(|(k, _)| *k)
+            .collect::<Vec<_>>();
+
+        Ok(ProjectedCaseBody {
+            projection,
+            body: projected_body,
+        })
+    }
+}
+
+/// A derived case body that can be used to evaluate a case expression after 
projecting
+/// record batches using a projection vector.
+#[derive(Debug, Hash, PartialEq, Eq)]
+struct ProjectedCaseBody {
+    projection: Vec<usize>,
+    body: CaseBody,

Review Comment:
   I think a few more comments would help here as I had to read the code 
carefully to understand what was being projected
   
   perhaps something like
   
   ```suggestion
   /// A derived case body that can be used to evaluate a case expression after 
projecting
   /// record batches using a projection vector.
   ///
   /// This is used to avoid filtering / copying columns that are not used in 
the
   /// input `RecordBatch` when progressively evaluating a `CASE` expression's
   /// remainder batches.
   ///
   /// For example, if we are evaluating the following case expression that
   /// only references columns B and D:
   ///
   /// ```sql
   /// CASE WHEN B > 10 THEN D ELSE NULL END
   /// ```
   ///
   /// If the input has 4 columns, `[A, B, C, D]`, it is wasteful to carry 
through
   /// columns A and C that are not used in the expression. Instead, the 
projection 
   /// vector will be `[1, 3]` and the case body
   /// will be rewritten to refer to columns `[0, 1]`
   /// (i.e. remap references from `B` -> `0`, and `D` -> `1`)
   #[derive(Debug, Hash, PartialEq, Eq)]
   struct ProjectedCaseBody {
       projection: Vec<usize>,
       body: CaseBody,
   ```



##########
datafusion/physical-expr/src/expressions/case.rs:
##########
@@ -556,63 +651,61 @@ impl CaseExpr {
         };
 
         if when_then_expr.is_empty() {
-            exec_err!("There must be at least one WHEN clause")
-        } else {
-            let eval_method = if expr.is_some() {
-                EvalMethod::WithExpression
-            } else if when_then_expr.len() == 1
-                && is_cheap_and_infallible(&(when_then_expr[0].1))
-                && else_expr.is_none()
-            {
-                EvalMethod::InfallibleExprOrNull
-            } else if when_then_expr.len() == 1
-                && when_then_expr[0].1.as_any().is::<Literal>()
-                && else_expr.is_some()
-                && else_expr.as_ref().unwrap().as_any().is::<Literal>()
-            {
-                EvalMethod::ScalarOrScalar
-            } else if when_then_expr.len() == 1 && else_expr.is_some() {
-                EvalMethod::ExpressionOrExpression
-            } else {
-                EvalMethod::NoExpression
-            };
-
-            Ok(Self {
-                expr,
-                when_then_expr,
-                else_expr,
-                eval_method,
-            })
+            return exec_err!("There must be at least one WHEN clause");
         }
+
+        let body = CaseBody {
+            expr,
+            when_then_expr,
+            else_expr,
+        };
+
+        let eval_method = if body.expr.is_some() {
+            EvalMethod::WithExpression(body.project()?)
+        } else if body.when_then_expr.len() == 1
+            && is_cheap_and_infallible(&(body.when_then_expr[0].1))
+            && body.else_expr.is_none()
+        {
+            EvalMethod::InfallibleExprOrNull
+        } else if body.when_then_expr.len() == 1
+            && body.when_then_expr[0].1.as_any().is::<Literal>()
+            && body.else_expr.is_some()
+            && body.else_expr.as_ref().unwrap().as_any().is::<Literal>()
+        {
+            EvalMethod::ScalarOrScalar
+        } else if body.when_then_expr.len() == 1 && body.else_expr.is_some() {
+            EvalMethod::ExpressionOrExpression(body.project()?)
+        } else {
+            EvalMethod::NoExpression(body.project()?)
+        };
+
+        Ok(Self { body, eval_method })
     }
 
     /// Optional base expression that can be compared to literal values in the 
"when" expressions
     pub fn expr(&self) -> Option<&Arc<dyn PhysicalExpr>> {
-        self.expr.as_ref()
+        self.body.expr.as_ref()
     }
 
     /// One or more when/then expressions
     pub fn when_then_expr(&self) -> &[WhenThen] {
-        &self.when_then_expr
+        &self.body.when_then_expr
     }
 
     /// Optional "else" expression
     pub fn else_expr(&self) -> Option<&Arc<dyn PhysicalExpr>> {
-        self.else_expr.as_ref()
+        self.body.else_expr.as_ref()
     }
 }
 
-impl CaseExpr {
-    /// This function evaluates the form of CASE that matches an expression to 
fixed values.
-    ///
-    /// CASE expression
-    ///     WHEN value THEN result
-    ///     [WHEN ...]
-    ///     [ELSE result]
-    /// END
-    fn case_when_with_expr(&self, batch: &RecordBatch) -> 
Result<ColumnarValue> {
-        let return_type = self.data_type(&batch.schema())?;
-        let mut result_builder = ResultBuilder::new(&return_type, 
batch.num_rows());
+impl CaseBody {
+    /// See [CaseExpr::case_when_with_expr].
+    fn case_when_with_expr(
+        &self,
+        batch: &RecordBatch,
+        return_type: &DataType,
+    ) -> Result<ColumnarValue> {
+        let mut result_builder = ResultBuilder::new(return_type, 
batch.num_rows());
 
         // `remainder_rows` contains the indices of the rows that need to be 
evaluated
         let mut remainder_rows: ArrayRef =

Review Comment:
   It took me some reading to understand that the point of the PR is to remove 
unused columns carried along in `remainder_rows`.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to