jorgecarleitao commented on a change in pull request #7880:
URL: https://github.com/apache/arrow/pull/7880#discussion_r470414491
##########
File path: rust/datafusion/src/optimizer/type_coercion.rs
##########
@@ -43,138 +45,77 @@ impl<'a> TypeCoercionRule<'a> {
Self { scalar_functions }
}
- /// Rewrite an expression list to include explicit CAST operations when
required
- fn rewrite_expr_list(&self, expr: &[Expr], schema: &Schema) ->
Result<Vec<Expr>> {
- Ok(expr
+ /// Rewrite an expression to include explicit CAST operations when required
+ fn rewrite_expr(&self, expr: &Expr, schema: &Schema) -> Result<Expr> {
+ let expressions = utils::expr_expressions(expr)?;
+
+ // recurse of the re-write
+ let mut expressions = expressions
.iter()
.map(|e| self.rewrite_expr(e, schema))
- .collect::<Result<Vec<_>>>()?)
- }
+ .collect::<Result<Vec<_>>>()?;
- /// Rewrite an expression to include explicit CAST operations when required
- fn rewrite_expr(&self, expr: &Expr, schema: &Schema) -> Result<Expr> {
+ // modify `expressions` by introducing casts when necessary
match expr {
- Expr::BinaryExpr { left, op, right } => {
- let left = self.rewrite_expr(left, schema)?;
- let right = self.rewrite_expr(right, schema)?;
- let left_type = left.get_type(schema)?;
- let right_type = right.get_type(schema)?;
- if left_type == right_type {
- Ok(Expr::BinaryExpr {
- left: Box::new(left),
- op: op.clone(),
- right: Box::new(right),
- })
- } else {
+ Expr::BinaryExpr { .. } => {
+ let left_type = expressions[0].get_type(schema)?;
+ let right_type = expressions[1].get_type(schema)?;
+ if left_type != right_type {
let super_type = utils::get_supertype(&left_type,
&right_type)?;
- Ok(Expr::BinaryExpr {
- left: Box::new(left.cast_to(&super_type, schema)?),
- op: op.clone(),
- right: Box::new(right.cast_to(&super_type, schema)?),
- })
+
+ expressions[0] = expressions[0].cast_to(&super_type,
schema)?;
+ expressions[1] = expressions[1].cast_to(&super_type,
schema)?;
}
}
- Expr::IsNull(e) => Ok(Expr::IsNull(Box::new(self.rewrite_expr(e,
schema)?))),
- Expr::IsNotNull(e) => {
- Ok(Expr::IsNotNull(Box::new(self.rewrite_expr(e, schema)?)))
- }
- Expr::ScalarFunction {
- name,
- args,
- return_type,
- } => {
+ Expr::ScalarFunction { name, .. } => {
// cast the inputs of scalar functions to the appropriate type
where possible
match self.scalar_functions.get(name) {
Some(func_meta) => {
- let mut func_args = Vec::with_capacity(args.len());
- for i in 0..args.len() {
+ for i in 0..expressions.len() {
let field = &func_meta.args[i];
- let expr = self.rewrite_expr(&args[i], schema)?;
- let actual_type = expr.get_type(schema)?;
+ let actual_type = expressions[i].get_type(schema)?;
let required_type = field.data_type();
- if &actual_type == required_type {
- func_args.push(expr)
- } else {
+ if &actual_type != required_type {
let super_type =
utils::get_supertype(&actual_type,
required_type)?;
- func_args.push(expr.cast_to(&super_type,
schema)?);
- }
+ expressions[i] =
+ expressions[i].cast_to(&super_type,
schema)?
+ };
}
-
- Ok(Expr::ScalarFunction {
- name: name.clone(),
- args: func_args,
- return_type: return_type.clone(),
- })
}
- _ => Err(ExecutionError::General(format!(
- "Invalid scalar function {}",
- name
- ))),
+ _ => {
+ return Err(ExecutionError::General(format!(
+ "Invalid scalar function {}",
+ name
+ )))
+ }
}
}
- Expr::AggregateFunction {
- name,
- args,
- return_type,
- } => Ok(Expr::AggregateFunction {
- name: name.clone(),
- args: args
- .iter()
- .map(|a| self.rewrite_expr(a, schema))
- .collect::<Result<Vec<_>>>()?,
- return_type: return_type.clone(),
- }),
- Expr::Cast { .. } => Ok(expr.clone()),
- Expr::Column(_) => Ok(expr.clone()),
- Expr::Alias(expr, alias) => Ok(Expr::Alias(
- Box::new(self.rewrite_expr(expr, schema)?),
- alias.to_owned(),
- )),
- Expr::Literal(_) => Ok(expr.clone()),
- Expr::Not(_) => Ok(expr.clone()),
- Expr::Sort { .. } => Ok(expr.clone()),
- Expr::Wildcard { .. } => Err(ExecutionError::General(
- "Wildcard expressions are not valid in a logical query
plan".to_owned(),
- )),
- Expr::Nested(e) => self.rewrite_expr(e, schema),
- }
+ _ => {}
+ };
+ utils::from_expression(expr, &expressions)
}
}
impl<'a> OptimizerRule for TypeCoercionRule<'a> {
fn optimize(&mut self, plan: &LogicalPlan) -> Result<LogicalPlan> {
- match plan {
- LogicalPlan::Projection { expr, input, .. } => {
- LogicalPlanBuilder::from(&self.optimize(input)?)
- .project(self.rewrite_expr_list(expr, input.schema())?)?
- .build()
- }
- LogicalPlan::Selection { expr, input, .. } => {
- LogicalPlanBuilder::from(&self.optimize(input)?)
- .filter(self.rewrite_expr(expr, input.schema())?)?
- .build()
- }
- LogicalPlan::Aggregate {
- input,
- group_expr,
- aggr_expr,
- ..
- } => LogicalPlanBuilder::from(&self.optimize(input)?)
- .aggregate(
- self.rewrite_expr_list(group_expr, input.schema())?,
- self.rewrite_expr_list(aggr_expr, input.schema())?,
- )?
- .build(),
- LogicalPlan::TableScan { .. } => Ok(plan.clone()),
- LogicalPlan::InMemoryScan { .. } => Ok(plan.clone()),
- LogicalPlan::ParquetScan { .. } => Ok(plan.clone()),
- LogicalPlan::CsvScan { .. } => Ok(plan.clone()),
- LogicalPlan::EmptyRelation { .. } => Ok(plan.clone()),
- LogicalPlan::Limit { .. } => Ok(plan.clone()),
- LogicalPlan::Sort { .. } => Ok(plan.clone()),
- LogicalPlan::CreateExternalTable { .. } => Ok(plan.clone()),
- }
+ let inputs = utils::inputs(plan);
+ let expressions = utils::expressions(plan);
+
+ // apply the optimization to all inputs of the plan
+ let new_inputs = inputs
+ .iter()
+ .map(|plan| self.optimize(*plan))
+ .collect::<Result<Vec<_>>>()?;
+ // re-write all expressions on this plan.
+ // This assumes a single input, [0]. It wont work for join, subqueries
and union operations with more than one input.
+ // It is currently not an issue as we do not have any plan with more
than one input.
+ let new_expressions = expressions
Review comment:
Good catch. No, because I am unsure how that could be possible: if a plan
has expressions, it needs an `input` to convert them to physical
expressions and to evaluate them against. AFAIK an expression always
requires an input to be evaluated against.
Do you have an example in mind?
AFAIK even a literal expression requires a schema to pass to
`Expr::get_type` and `Expr::name`.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org