alamb commented on a change in pull request #7880:
URL: https://github.com/apache/arrow/pull/7880#discussion_r470753855
##########
File path: rust/datafusion/src/optimizer/type_coercion.rs
##########
@@ -43,138 +45,77 @@ impl<'a> TypeCoercionRule<'a> {
Self { scalar_functions }
}
- /// Rewrite an expression list to include explicit CAST operations when
required
- fn rewrite_expr_list(&self, expr: &[Expr], schema: &Schema) ->
Result<Vec<Expr>> {
- Ok(expr
+ /// Rewrite an expression to include explicit CAST operations when required
+ fn rewrite_expr(&self, expr: &Expr, schema: &Schema) -> Result<Expr> {
+ let expressions = utils::expr_expressions(expr)?;
+
+ // recurse of the re-write
+ let mut expressions = expressions
.iter()
.map(|e| self.rewrite_expr(e, schema))
- .collect::<Result<Vec<_>>>()?)
- }
+ .collect::<Result<Vec<_>>>()?;
- /// Rewrite an expression to include explicit CAST operations when required
- fn rewrite_expr(&self, expr: &Expr, schema: &Schema) -> Result<Expr> {
+ // modify `expressions` by introducing casts when necessary
match expr {
- Expr::BinaryExpr { left, op, right } => {
- let left = self.rewrite_expr(left, schema)?;
- let right = self.rewrite_expr(right, schema)?;
- let left_type = left.get_type(schema)?;
- let right_type = right.get_type(schema)?;
- if left_type == right_type {
- Ok(Expr::BinaryExpr {
- left: Box::new(left),
- op: op.clone(),
- right: Box::new(right),
- })
- } else {
+ Expr::BinaryExpr { .. } => {
+ let left_type = expressions[0].get_type(schema)?;
+ let right_type = expressions[1].get_type(schema)?;
+ if left_type != right_type {
let super_type = utils::get_supertype(&left_type,
&right_type)?;
- Ok(Expr::BinaryExpr {
- left: Box::new(left.cast_to(&super_type, schema)?),
- op: op.clone(),
- right: Box::new(right.cast_to(&super_type, schema)?),
- })
+
+ expressions[0] = expressions[0].cast_to(&super_type,
schema)?;
+ expressions[1] = expressions[1].cast_to(&super_type,
schema)?;
}
}
- Expr::IsNull(e) => Ok(Expr::IsNull(Box::new(self.rewrite_expr(e,
schema)?))),
- Expr::IsNotNull(e) => {
- Ok(Expr::IsNotNull(Box::new(self.rewrite_expr(e, schema)?)))
- }
- Expr::ScalarFunction {
- name,
- args,
- return_type,
- } => {
+ Expr::ScalarFunction { name, .. } => {
// cast the inputs of scalar functions to the appropriate type
where possible
match self.scalar_functions.get(name) {
Some(func_meta) => {
- let mut func_args = Vec::with_capacity(args.len());
- for i in 0..args.len() {
+ for i in 0..expressions.len() {
let field = &func_meta.args[i];
- let expr = self.rewrite_expr(&args[i], schema)?;
- let actual_type = expr.get_type(schema)?;
+ let actual_type = expressions[i].get_type(schema)?;
let required_type = field.data_type();
- if &actual_type == required_type {
- func_args.push(expr)
- } else {
+ if &actual_type != required_type {
let super_type =
utils::get_supertype(&actual_type,
required_type)?;
- func_args.push(expr.cast_to(&super_type,
schema)?);
- }
+ expressions[i] =
+ expressions[i].cast_to(&super_type,
schema)?
+ };
}
-
- Ok(Expr::ScalarFunction {
- name: name.clone(),
- args: func_args,
- return_type: return_type.clone(),
- })
}
- _ => Err(ExecutionError::General(format!(
- "Invalid scalar function {}",
- name
- ))),
+ _ => {
+ return Err(ExecutionError::General(format!(
+ "Invalid scalar function {}",
+ name
+ )))
+ }
}
}
- Expr::AggregateFunction {
- name,
- args,
- return_type,
- } => Ok(Expr::AggregateFunction {
- name: name.clone(),
- args: args
- .iter()
- .map(|a| self.rewrite_expr(a, schema))
- .collect::<Result<Vec<_>>>()?,
- return_type: return_type.clone(),
- }),
- Expr::Cast { .. } => Ok(expr.clone()),
- Expr::Column(_) => Ok(expr.clone()),
- Expr::Alias(expr, alias) => Ok(Expr::Alias(
- Box::new(self.rewrite_expr(expr, schema)?),
- alias.to_owned(),
- )),
- Expr::Literal(_) => Ok(expr.clone()),
- Expr::Not(_) => Ok(expr.clone()),
- Expr::Sort { .. } => Ok(expr.clone()),
- Expr::Wildcard { .. } => Err(ExecutionError::General(
- "Wildcard expressions are not valid in a logical query
plan".to_owned(),
- )),
- Expr::Nested(e) => self.rewrite_expr(e, schema),
- }
+ _ => {}
+ };
+ utils::from_expression(expr, &expressions)
}
}
impl<'a> OptimizerRule for TypeCoercionRule<'a> {
fn optimize(&mut self, plan: &LogicalPlan) -> Result<LogicalPlan> {
- match plan {
- LogicalPlan::Projection { expr, input, .. } => {
- LogicalPlanBuilder::from(&self.optimize(input)?)
- .project(self.rewrite_expr_list(expr, input.schema())?)?
- .build()
- }
- LogicalPlan::Selection { expr, input, .. } => {
- LogicalPlanBuilder::from(&self.optimize(input)?)
- .filter(self.rewrite_expr(expr, input.schema())?)?
- .build()
- }
- LogicalPlan::Aggregate {
- input,
- group_expr,
- aggr_expr,
- ..
- } => LogicalPlanBuilder::from(&self.optimize(input)?)
- .aggregate(
- self.rewrite_expr_list(group_expr, input.schema())?,
- self.rewrite_expr_list(aggr_expr, input.schema())?,
- )?
- .build(),
- LogicalPlan::TableScan { .. } => Ok(plan.clone()),
- LogicalPlan::InMemoryScan { .. } => Ok(plan.clone()),
- LogicalPlan::ParquetScan { .. } => Ok(plan.clone()),
- LogicalPlan::CsvScan { .. } => Ok(plan.clone()),
- LogicalPlan::EmptyRelation { .. } => Ok(plan.clone()),
- LogicalPlan::Limit { .. } => Ok(plan.clone()),
- LogicalPlan::Sort { .. } => Ok(plan.clone()),
- LogicalPlan::CreateExternalTable { .. } => Ok(plan.clone()),
- }
+ let inputs = utils::inputs(plan);
+ let expressions = utils::expressions(plan);
+
+ // apply the optimization to all inputs of the plan
+ let new_inputs = inputs
+ .iter()
+ .map(|plan| self.optimize(*plan))
+ .collect::<Result<Vec<_>>>()?;
+ // re-write all expressions on this plan.
+ // This assumes a single input, [0]. It wont work for join, subqueries
and union operations with more than one input.
+ // It is currently not an issue as we do not have any plan with more
than one input.
+ let new_expressions = expressions
Review comment:
```suggestion
assert!(expressions.len() == 0 || inputs.len() > 0, "Assume that all
plan nodes with expressions had inputs");
let new_expressions = expressions
```
I think `EmptyRelation`
(https://github.com/apache/arrow/blob/master/rust/datafusion/src/logicalplan.rs#L761-L764),
for example, has no input `LogicalPlan`. But perhaps you are saying: "even though
`EmptyRelation` has no inputs (and thus `inputs[0].schema()` could
panic), it also has no expressions, so the potentially panicking code won't be
run."
I guess I was thinking ahead to a future where we add expressions to root nodes
(e.g. perhaps filtering *during* a table scan), which would then
have expressions but no input.
I think this code is fine as is. Perhaps we could make the code slightly
easier to work with in the future by adding something like the assert
suggested above. It checks that any plan node with expressions also has at least
one input, so a violation fails with a clear message rather than an
index-out-of-bounds panic.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org