alamb commented on a change in pull request #7967:
URL: https://github.com/apache/arrow/pull/7967#discussion_r471445155
##########
File path: rust/datafusion/src/optimizer/type_coercion.rs
##########
@@ -345,4 +345,143 @@ mod tests {
assert_eq!(expected, format!("{:?}", expr2));
}
+
+ #[test]
+ fn test_maybe_coerce() -> Result<()> {
+ // this vec contains: arg1, arg2, expected result
+ let cases = vec![
+ // 2 entries, same values
+ (
+ vec![DataType::UInt8, DataType::UInt16],
+ vec![DataType::UInt8, DataType::UInt16],
+ Some(vec![DataType::UInt8, DataType::UInt16]),
+ ),
+ // 2 entries, can coerse values
+ (
+ vec![DataType::UInt16, DataType::UInt16],
+ vec![DataType::UInt8, DataType::UInt16],
+ Some(vec![DataType::UInt16, DataType::UInt16]),
+ ),
+ // 0 entries, all good
+ (vec![], vec![], Some(vec![])),
+ // 2 entries, can't coerce
+ (
+ vec![DataType::Boolean, DataType::UInt16],
+ vec![DataType::UInt8, DataType::UInt16],
+ None,
+ ),
+ // u32 -> u16 is possible
+ (
+ vec![DataType::Boolean, DataType::UInt32],
+ vec![DataType::Boolean, DataType::UInt16],
+ Some(vec![DataType::Boolean, DataType::UInt32]),
+ ),
+ ];
+
+ for case in cases {
+ assert_eq!(maybe_coerce(&case.0, &case.1), case.2)
+ }
+ Ok(())
+ }
+
+ #[test]
+ fn test_maybe_rewrite() -> Result<()> {
+ // create a schema
+ let schema = |t: Vec<DataType>| {
+ Schema::new(
+ t.iter()
+ .enumerate()
+ .map(|(i, t)| Field::new(&*format!("c{}", i), t.clone(),
true))
+ .collect(),
+ )
+ };
+
+ // create a vector of expressions
+ let expressions = |t: Vec<DataType>, schema| -> Result<Vec<Expr>> {
+ t.iter()
+ .enumerate()
+ .map(|(i, t)| col(&*format!("c{}", i)).cast_to(&t, &schema))
Review comment:
```suggestion
.map(|(i, t)| col(&format!("c{}", i)).cast_to(&t, &schema))
```
##########
File path: rust/datafusion/src/sql/planner.rs
##########
@@ -515,27 +515,29 @@ impl<S: SchemaProvider> SqlToRel<S> {
}
_ => match self.schema_provider.get_function_meta(&name) {
Some(fm) => {
- let rex_args = function
Review comment:
I wonder, given the emphasis on pluggable planners, if
52218c852b7b3016afeaf95d8a46d6deea89d231 (removing the type coercion from
the physical planner) is a good idea. As in: "is it ok to assume that all plans
went through the existing `Optimizer` passes before being converted?"
It seems reasonable to me, but it might be worth mentioning somewhere (e.g.
on the physical planner, etc.).
##########
File path: rust/datafusion/tests/sql.rs
##########
@@ -232,6 +326,55 @@ fn custom_sqrt(args: &[ArrayRef]) -> Result<ArrayRef> {
Ok(Arc::new(builder.finish()))
}
+fn custom_add(args: &[ArrayRef]) -> Result<ArrayRef> {
+ match (args[0].data_type(), args[1].data_type()) {
+ (DataType::Float64, DataType::Float64) => {
+ let input1 = &args[0]
+ .as_any()
+ .downcast_ref::<Float64Array>()
+ .expect("cast failed");
+ let input2 = &args[1]
+ .as_any()
+ .downcast_ref::<Float64Array>()
+ .expect("cast failed");
+
+ let mut builder = Float64Builder::new(input1.len());
+ for i in 0..input1.len() {
+ if input1.is_null(i) || input2.is_null(i) {
+ builder.append_null()?;
+ } else {
+ builder.append_value(input1.value(i) + input2.value(i))?;
+ }
+ }
+ Ok(Arc::new(builder.finish()))
+ }
+ (DataType::Float32, DataType::Float32) => {
+ // all other cases return a constant vector (just to be diferent)
+ let mut builder = Float64Builder::new(args[0].len());
+ for _ in 0..args[0].len() {
+ builder.append_value(3232.0)?;
+ }
+ Ok(Arc::new(builder.finish()))
+ }
+ (DataType::Float32, DataType::Float64) => {
+ // all other cases return a constant vector (just to be diferent)
+ let mut builder = Float64Builder::new(args[0].len());
+ for _ in 0..args[0].len() {
+ builder.append_value(3264.0)?;
+ }
+ Ok(Arc::new(builder.finish()))
+ }
+ (_, _) => {
+ // all other cases return a constant vector (just to be diferent)
Review comment:
Maybe it is worth calling `panic!` if the argument types don't match the
registration of the UDF.
##########
File path: rust/datafusion/src/execution/physical_plan/udf.rs
##########
@@ -37,8 +37,11 @@ pub type ScalarUdf = Arc<dyn Fn(&[ArrayRef]) ->
Result<ArrayRef> + Send + Sync>;
pub struct ScalarFunction {
/// Function name
pub name: String,
- /// Function argument meta-data
- pub args: Vec<Field>,
+ /// Set of valid argument types.
+ /// The first dimension (0) represents specific combinations of valid
argument types
+ /// The second dimension (1) represents the types of each argument.
+ /// For example, [[t1, t2]] is a function of 2 arguments that only accept
t1 on the first arg and t2 on the second
Review comment:
makes sense
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]