Jefffrey commented on code in PR #8985:
URL: https://github.com/apache/arrow-datafusion/pull/8985#discussion_r1490785559


##########
datafusion/core/tests/user_defined/user_defined_scalar_functions.rs:
##########
@@ -494,6 +498,127 @@ async fn test_user_defined_functions_zero_argument() -> 
Result<()> {
     Ok(())
 }
 
+#[derive(Debug)]
+struct TakeUDF {
+    signature: Signature,
+}
+
+impl TakeUDF {
+    fn new() -> Self {
+        Self {
+            signature: Signature::any(3, Volatility::Immutable),
+        }
+    }
+}
+
+/// Implement a ScalarUDFImpl whose return type is a function of the input 
values
+impl ScalarUDFImpl for TakeUDF {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+    fn name(&self) -> &str {
+        "take"
+    }
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+    fn return_type(&self, _args: &[DataType]) -> Result<DataType> {
+        not_impl_err!("Not called because the return_type_from_exprs is 
implemented")
+    }
+
+    /// Thus function returns the type of the first or second argument based on

Review Comment:
   ```suggestion
       /// This function returns the type of the first or second argument based 
on
   ```



##########
datafusion/expr/src/udf.rs:
##########
@@ -249,6 +257,43 @@ pub trait ScalarUDFImpl: Debug + Send + Sync {
     /// the arguments
     fn return_type(&self, arg_types: &[DataType]) -> Result<DataType>;
 
+    /// What [`DataType`] will be returned by this function, given the
+    /// arguments?
+    ///
+    /// Note most UDFs should implement [`Self::return_type`] and not this
+    /// function. The output type for most functions only depends on the types
+    /// of their inputs (e.g. `sqrt(f32)` is always `f32`).
+    ///
+    /// By default, this function calls [`Self::return_type`] with the
+    /// types of each argument.
+    ///
+    /// This method can be overridden for functions that return different
+    /// *types* based on the *values* of their arguments.
+    ///
+    /// For example, the following two function calls get the same argument
+    /// types (something and a `Utf8` string) but return different types based
+    /// on the value of the second argument:
+    ///
+    /// * `arrow_cast(x, 'Int16')` --> `Int16`
+    /// * `arrow_cast(x, 'Float32')` --> `Float32`
+    ///
+    /// # Notes:
+    ///
+    /// This function must consistently return the same type for the same
+    /// logical input even if the input is simplified (e.g. it must return the 
same
+    /// value for `('foo' | 'bar')` as it does for ('foobar').

Review Comment:
   Maybe add some documentation about what would happen if a user tries to 
implement both `return_type()` and `return_type_from_exprs()`? (Which takes 
priority, etc.)
   
   And what the suggested implementation for `return_type()` should be if they choose 
to implement `return_type_from_exprs()` instead.



##########
datafusion/core/tests/user_defined/user_defined_scalar_functions.rs:
##########
@@ -494,6 +498,127 @@ async fn test_user_defined_functions_zero_argument() -> 
Result<()> {
     Ok(())
 }
 
+#[derive(Debug)]
+struct TakeUDF {
+    signature: Signature,
+}
+
+impl TakeUDF {
+    fn new() -> Self {
+        Self {
+            signature: Signature::any(3, Volatility::Immutable),
+        }
+    }
+}
+
+/// Implement a ScalarUDFImpl whose return type is a function of the input 
values
+impl ScalarUDFImpl for TakeUDF {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+    fn name(&self) -> &str {
+        "take"
+    }
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+    fn return_type(&self, _args: &[DataType]) -> Result<DataType> {
+        not_impl_err!("Not called because the return_type_from_exprs is 
implemented")
+    }
+
+    /// Thus function returns the type of the first or second argument based on
+    /// the third argument:
+    ///
+    /// 1. If the third argument is '0', return the type of the first argument
+    /// 2. If the third argument is '1', return the type of the second argument
+    fn return_type_from_exprs(
+        &self,
+        arg_exprs: &[Expr],
+        schema: &dyn ExprSchema,
+    ) -> Result<DataType> {
+        if arg_exprs.len() != 3 {
+            return plan_err!("Expected 3 arguments, got {}.", arg_exprs.len());
+        }
+
+        let take_idx = if let 
Some(Expr::Literal(ScalarValue::Int64(Some(idx)))) =
+            arg_exprs.get(2)
+        {
+            if *idx == 0 || *idx == 1 {
+                *idx as usize
+            } else {
+                return plan_err!("The third argument must be 0 or 1, got: 
{idx}");
+            }
+        } else {
+            return plan_err!(
+                "The third argument must be a literal of type int64, but got 
{:?}",
+                arg_exprs.get(2)
+            );
+        };
+
+        arg_exprs.get(take_idx).unwrap().get_type(schema)
+    }
+
+    // The actual implementation rethr

Review Comment:
   ```suggestion
       // The actual implementation
   ```



##########
datafusion/core/tests/user_defined/user_defined_scalar_functions.rs:
##########
@@ -494,6 +498,127 @@ async fn test_user_defined_functions_zero_argument() -> 
Result<()> {
     Ok(())
 }
 
+#[derive(Debug)]
+struct TakeUDF {
+    signature: Signature,
+}
+
+impl TakeUDF {
+    fn new() -> Self {
+        Self {
+            signature: Signature::any(3, Volatility::Immutable),
+        }
+    }
+}
+
+/// Implement a ScalarUDFImpl whose return type is a function of the input 
values
+impl ScalarUDFImpl for TakeUDF {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+    fn name(&self) -> &str {
+        "take"
+    }
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+    fn return_type(&self, _args: &[DataType]) -> Result<DataType> {
+        not_impl_err!("Not called because the return_type_from_exprs is 
implemented")
+    }
+
+    /// Thus function returns the type of the first or second argument based on
+    /// the third argument:
+    ///
+    /// 1. If the third argument is '0', return the type of the first argument
+    /// 2. If the third argument is '1', return the type of the second argument
+    fn return_type_from_exprs(
+        &self,
+        arg_exprs: &[Expr],
+        schema: &dyn ExprSchema,
+    ) -> Result<DataType> {
+        if arg_exprs.len() != 3 {
+            return plan_err!("Expected 3 arguments, got {}.", arg_exprs.len());
+        }
+
+        let take_idx = if let 
Some(Expr::Literal(ScalarValue::Int64(Some(idx)))) =
+            arg_exprs.get(2)
+        {
+            if *idx == 0 || *idx == 1 {
+                *idx as usize
+            } else {
+                return plan_err!("The third argument must be 0 or 1, got: 
{idx}");
+            }
+        } else {
+            return plan_err!(
+                "The third argument must be a literal of type int64, but got 
{:?}",
+                arg_exprs.get(2)
+            );
+        };
+
+        arg_exprs.get(take_idx).unwrap().get_type(schema)
+    }
+
+    // The actual implementation rethr
+    fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
+        let take_idx = match &args[2] {
+            ColumnarValue::Scalar(ScalarValue::Int64(Some(v))) if v < &2 => *v 
as usize,
+            _ => unreachable!(),
+        };
+        match &args[take_idx] {
+            ColumnarValue::Array(array) => 
Ok(ColumnarValue::Array(array.clone())),
+            ColumnarValue::Scalar(_) => unimplemented!(),
+        }
+    }
+}
+
+#[tokio::test]
+async fn verify_udf_return_type() -> Result<()> {
+    // Create a new ScalarUDF from the implementation
+    let take = ScalarUDF::from(TakeUDF::new());
+
+    // SELECT
+    //   take(smallint_col, double_col, 0) as take0,
+    //   take(smallint_col, double_col, 1) as take1
+    // FROM alltypes_plain;
+    let exprs = vec![
+        take.call(vec![col("smallint_col"), col("double_col"), lit(0_i64)])
+            .alias("take0"),
+        take.call(vec![col("smallint_col"), col("double_col"), lit(1_i64)])
+            .alias("take1"),
+    ];
+
+    let ctx = SessionContext::new();
+    register_alltypes_parquet(&ctx).await?;
+
+    let df = ctx.table("alltypes_plain").await?.select(exprs)?;
+
+    let schema = df.schema();
+
+    // The output schema should be
+    // * type of column smallint_col (float64)
+    // * type of column double_col (float32)

Review Comment:
   ```suggestion
       // * type of column smallint_col (int32)
       // * type of column double_col (float64)
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to