alamb commented on code in PR #18329:
URL: https://github.com/apache/datafusion/pull/18329#discussion_r2477529977


##########
datafusion/physical-expr/src/expressions/case.rs:
##########
@@ -68,7 +70,104 @@ enum EvalMethod {
     /// if there is just one when/then pair and both the `then` and `else` are expressions
     ///
     /// CASE WHEN condition THEN expression ELSE expression END
-    ExpressionOrExpression,
+    ExpressionOrExpression(ProjectedCaseBody),
+}
+
+/// The body of a CASE expression which consists of an optional base expression, the "when/then"
+/// branches and an optional "else" branch.
+#[derive(Debug, Hash, PartialEq, Eq)]
+struct CaseBody {
+    /// Optional base expression that can be compared to literal values in the "when" expressions
+    expr: Option<Arc<dyn PhysicalExpr>>,
+    /// One or more when/then expressions
+    when_then_expr: Vec<WhenThen>,
+    /// Optional "else" expression
+    else_expr: Option<Arc<dyn PhysicalExpr>>,
+}
+
+impl CaseBody {
+    /// Derives a [ProjectedCaseBody] from this [CaseBody].
+    fn project(&self) -> Result<ProjectedCaseBody> {
+        // Determine the set of columns that are used in all the expressions of the case body.
+        let mut used_column_indices = HashSet::<usize>::new();
+        let mut collect_column_indices = |expr: &Arc<dyn PhysicalExpr>| {
+            expr.apply(|expr| {
+                if let Some(column) = expr.as_any().downcast_ref::<Column>() {
+                    used_column_indices.insert(column.index());
+                }
+                Ok(TreeNodeRecursion::Continue)
+            })
+            .expect("Closure cannot fail");
+        };
+
+        if let Some(e) = &self.expr {
+            collect_column_indices(e);
+        }
+        self.when_then_expr.iter().for_each(|(w, t)| {
+            collect_column_indices(w);
+            collect_column_indices(t);
+        });
+        if let Some(e) = &self.else_expr {
+            collect_column_indices(e);
+        }
+
+        // Construct a mapping from the original column index to the projected column index.
+        let column_index_map = used_column_indices
+            .iter()
+            .enumerate()
+            .map(|(projected, original)| (*original, projected))
+            .collect::<HashMap<usize, usize>>();
+
+        // Construct the projected body by rewriting each expression from the original body
+        // using the column index mapping.
+        let project = |expr: &Arc<dyn PhysicalExpr>| -> Result<Arc<dyn 
PhysicalExpr>> {
+            Arc::clone(expr)
+                .transform_down(|e| {
+                    if let Some(column) = e.as_any().downcast_ref::<Column>() {
+                        let original = column.index();
+                        let projected = 
*column_index_map.get(&original).unwrap();
+                        if projected != original {
+                            return Ok(Transformed::yes(Arc::new(Column::new(
+                                column.name(),
+                                projected,
+                            ))));
+                        }
+                    }
+                    Ok(Transformed::no(e))
+                })
+                .map(|t| t.data)
+        };
+
+        let projected_body = CaseBody {
+            expr: self.expr.as_ref().map(project).transpose()?,
+            when_then_expr: self
+                .when_then_expr
+                .iter()
+                .map(|(e, t)| Ok((project(e)?, project(t)?)))
+                .collect::<Result<Vec<_>>>()?,
+            else_expr: self.else_expr.as_ref().map(project).transpose()?,
+        };
+
+        // Construct the projection vector
+        let projection = column_index_map
+            .iter()
+            .sorted_by_key(|(_, v)| **v)
+            .map(|(k, _)| *k)
+            .collect::<Vec<_>>();
+
+        Ok(ProjectedCaseBody {
+            projection,
+            body: projected_body,
+        })
+    }
+}
+
+/// A derived case body that can be used to evaluate a case expression after projecting
+/// record batches using a projection vector.
+#[derive(Debug, Hash, PartialEq, Eq)]
+struct ProjectedCaseBody {
+    projection: Vec<usize>,
+    body: CaseBody,

Review Comment:
   I personally found the term `project` to make sense and to be consistent with
the rest of the code.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to