Jefffrey commented on code in PR #8985:
URL: https://github.com/apache/arrow-datafusion/pull/8985#discussion_r1490785559
##########
datafusion/core/tests/user_defined/user_defined_scalar_functions.rs:
##########
@@ -494,6 +498,127 @@ async fn test_user_defined_functions_zero_argument() ->
Result<()> {
Ok(())
}
+#[derive(Debug)]
+struct TakeUDF {
+ signature: Signature,
+}
+
+impl TakeUDF {
+ fn new() -> Self {
+ Self {
+ signature: Signature::any(3, Volatility::Immutable),
+ }
+ }
+}
+
+/// Implement a ScalarUDFImpl whose return type is a function of the input
values
+impl ScalarUDFImpl for TakeUDF {
+ fn as_any(&self) -> &dyn Any {
+ self
+ }
+ fn name(&self) -> &str {
+ "take"
+ }
+ fn signature(&self) -> &Signature {
+ &self.signature
+ }
+ fn return_type(&self, _args: &[DataType]) -> Result<DataType> {
+ not_impl_err!("Not called because the return_type_from_exprs is
implemented")
+ }
+
+ /// Thus function returns the type of the first or second argument based on
Review Comment:
```suggestion
/// This function returns the type of the first or second argument based
on
```
##########
datafusion/expr/src/udf.rs:
##########
@@ -249,6 +257,43 @@ pub trait ScalarUDFImpl: Debug + Send + Sync {
/// the arguments
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType>;
+ /// What [`DataType`] will be returned by this function, given the
+ /// arguments?
+ ///
+ /// Note most UDFs should implement [`Self::return_type`] and not this
+ /// function. The output type for most functions only depends on the types
+ /// of their inputs (e.g. `sqrt(f32)` is always `f32`).
+ ///
+ /// By default, this function calls [`Self::return_type`] with the
+ /// types of each argument.
+ ///
+ /// This method can be overridden for functions that return different
+ /// *types* based on the *values* of their arguments.
+ ///
+ /// For example, the following two function calls get the same argument
+ /// types (something and a `Utf8` string) but return different types based
+ /// on the value of the second argument:
+ ///
+ /// * `arrow_cast(x, 'Int16')` --> `Int16`
+ /// * `arrow_cast(x, 'Float32')` --> `Float32`
+ ///
+ /// # Notes:
+ ///
+ /// This function must consistently return the same type for the same
+ /// logical input even if the input is simplified (e.g. it must return the
same
+ /// value for `('foo' | 'bar')` as it does for ('foobar').
Review Comment:
Maybe add some documentation about what would happen if a user tries to
implement both `return_type()` and `return_type_from_exprs()`? (Which takes
priority, etc.)
And what the suggested implementation for `return_type()` should be if they
choose to implement `return_type_from_exprs()` instead
##########
datafusion/core/tests/user_defined/user_defined_scalar_functions.rs:
##########
@@ -494,6 +498,127 @@ async fn test_user_defined_functions_zero_argument() ->
Result<()> {
Ok(())
}
+#[derive(Debug)]
+struct TakeUDF {
+ signature: Signature,
+}
+
+impl TakeUDF {
+ fn new() -> Self {
+ Self {
+ signature: Signature::any(3, Volatility::Immutable),
+ }
+ }
+}
+
+/// Implement a ScalarUDFImpl whose return type is a function of the input
values
+impl ScalarUDFImpl for TakeUDF {
+ fn as_any(&self) -> &dyn Any {
+ self
+ }
+ fn name(&self) -> &str {
+ "take"
+ }
+ fn signature(&self) -> &Signature {
+ &self.signature
+ }
+ fn return_type(&self, _args: &[DataType]) -> Result<DataType> {
+ not_impl_err!("Not called because the return_type_from_exprs is
implemented")
+ }
+
+ /// Thus function returns the type of the first or second argument based on
+ /// the third argument:
+ ///
+ /// 1. If the third argument is '0', return the type of the first argument
+ /// 2. If the third argument is '1', return the type of the second argument
+ fn return_type_from_exprs(
+ &self,
+ arg_exprs: &[Expr],
+ schema: &dyn ExprSchema,
+ ) -> Result<DataType> {
+ if arg_exprs.len() != 3 {
+ return plan_err!("Expected 3 arguments, got {}.", arg_exprs.len());
+ }
+
+ let take_idx = if let
Some(Expr::Literal(ScalarValue::Int64(Some(idx)))) =
+ arg_exprs.get(2)
+ {
+ if *idx == 0 || *idx == 1 {
+ *idx as usize
+ } else {
+ return plan_err!("The third argument must be 0 or 1, got:
{idx}");
+ }
+ } else {
+ return plan_err!(
+ "The third argument must be a literal of type int64, but got
{:?}",
+ arg_exprs.get(2)
+ );
+ };
+
+ arg_exprs.get(take_idx).unwrap().get_type(schema)
+ }
+
+ // The actual implementation rethr
Review Comment:
```suggestion
// The actual implementation
```
##########
datafusion/core/tests/user_defined/user_defined_scalar_functions.rs:
##########
@@ -494,6 +498,127 @@ async fn test_user_defined_functions_zero_argument() ->
Result<()> {
Ok(())
}
+#[derive(Debug)]
+struct TakeUDF {
+ signature: Signature,
+}
+
+impl TakeUDF {
+ fn new() -> Self {
+ Self {
+ signature: Signature::any(3, Volatility::Immutable),
+ }
+ }
+}
+
+/// Implement a ScalarUDFImpl whose return type is a function of the input
values
+impl ScalarUDFImpl for TakeUDF {
+ fn as_any(&self) -> &dyn Any {
+ self
+ }
+ fn name(&self) -> &str {
+ "take"
+ }
+ fn signature(&self) -> &Signature {
+ &self.signature
+ }
+ fn return_type(&self, _args: &[DataType]) -> Result<DataType> {
+ not_impl_err!("Not called because the return_type_from_exprs is
implemented")
+ }
+
+ /// Thus function returns the type of the first or second argument based on
+ /// the third argument:
+ ///
+ /// 1. If the third argument is '0', return the type of the first argument
+ /// 2. If the third argument is '1', return the type of the second argument
+ fn return_type_from_exprs(
+ &self,
+ arg_exprs: &[Expr],
+ schema: &dyn ExprSchema,
+ ) -> Result<DataType> {
+ if arg_exprs.len() != 3 {
+ return plan_err!("Expected 3 arguments, got {}.", arg_exprs.len());
+ }
+
+ let take_idx = if let
Some(Expr::Literal(ScalarValue::Int64(Some(idx)))) =
+ arg_exprs.get(2)
+ {
+ if *idx == 0 || *idx == 1 {
+ *idx as usize
+ } else {
+ return plan_err!("The third argument must be 0 or 1, got:
{idx}");
+ }
+ } else {
+ return plan_err!(
+ "The third argument must be a literal of type int64, but got
{:?}",
+ arg_exprs.get(2)
+ );
+ };
+
+ arg_exprs.get(take_idx).unwrap().get_type(schema)
+ }
+
+ // The actual implementation rethr
+ fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
+ let take_idx = match &args[2] {
+ ColumnarValue::Scalar(ScalarValue::Int64(Some(v))) if v < &2 => *v
as usize,
+ _ => unreachable!(),
+ };
+ match &args[take_idx] {
+ ColumnarValue::Array(array) =>
Ok(ColumnarValue::Array(array.clone())),
+ ColumnarValue::Scalar(_) => unimplemented!(),
+ }
+ }
+}
+
+#[tokio::test]
+async fn verify_udf_return_type() -> Result<()> {
+ // Create a new ScalarUDF from the implementation
+ let take = ScalarUDF::from(TakeUDF::new());
+
+ // SELECT
+ // take(smallint_col, double_col, 0) as take0,
+ // take(smallint_col, double_col, 1) as take1
+ // FROM alltypes_plain;
+ let exprs = vec![
+ take.call(vec![col("smallint_col"), col("double_col"), lit(0_i64)])
+ .alias("take0"),
+ take.call(vec![col("smallint_col"), col("double_col"), lit(1_i64)])
+ .alias("take1"),
+ ];
+
+ let ctx = SessionContext::new();
+ register_alltypes_parquet(&ctx).await?;
+
+ let df = ctx.table("alltypes_plain").await?.select(exprs)?;
+
+ let schema = df.schema();
+
+ // The output schema should be
+ // * type of column smallint_col (float64)
+ // * type of column double_col (float32)
Review Comment:
```suggestion
// * type of column smallint_col (int32)
// * type of column double_col (float64)
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]