This is an automated email from the ASF dual-hosted git repository.

jayzhan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 8db30e25d6 Introduce `Signature::Coercible` (#12275)
8db30e25d6 is described below

commit 8db30e25d6fe65a9779d237cf48aea9aee297502
Author: Jay Zhan <[email protected]>
AuthorDate: Tue Sep 3 07:57:40 2024 +0800

    Introduce `Signature::Coercible` (#12275)
    
    * introduce signature float
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * fix test
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * change float to coercible
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * fmt
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * rm test
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * add comment
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * typo
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    ---------
    
    Signed-off-by: jayzhan211 <[email protected]>
---
 datafusion/core/src/dataframe/mod.rs               |  6 ++--
 datafusion/expr-common/src/signature.rs            | 15 +++++++++-
 .../expr-common/src/type_coercion/aggregates.rs    |  6 ++--
 datafusion/expr/src/type_coercion/functions.rs     | 33 +++++++++++++++++++++-
 datafusion/functions-aggregate/src/stddev.rs       | 11 ++++----
 datafusion/functions-aggregate/src/variance.rs     | 11 ++++----
 6 files changed, 64 insertions(+), 18 deletions(-)

diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs
index b8c0bd9d74..2138bd1294 100644
--- a/datafusion/core/src/dataframe/mod.rs
+++ b/datafusion/core/src/dataframe/mod.rs
@@ -2428,7 +2428,8 @@ mod tests {
         let df: Vec<RecordBatch> = df.select(aggr_expr)?.collect().await?;
 
         assert_batches_sorted_eq!(
-            
["+-------------+----------+-----------------+---------------+--------+-----+------+----+------+",
+            [
+                
"+-------------+----------+-----------------+---------------+--------+-----+------+----+------+",
                 "| first_value | last_val | approx_distinct | approx_median | 
median | max | min  | c2 | c3   |",
                 
"+-------------+----------+-----------------+---------------+--------+-----+------+----+------+",
                 "|             |          |                 |               |  
      |     |      | 1  | -85  |",
@@ -2452,7 +2453,8 @@ mod tests {
                 "| -85         | 45       | 8               | -34           | 
45     | 83  | -85  | 3  | -72  |",
                 "| -85         | 65       | 17              | -17           | 
65     | 83  | -101 | 5  | -101 |",
                 "| -85         | 83       | 5               | -25           | 
83     | 83  | -85  | 2  | -48  |",
-                
"+-------------+----------+-----------------+---------------+--------+-----+------+----+------+"],
+                
"+-------------+----------+-----------------+---------------+--------+-----+------+----+------+",
+            ],
             &df
         );
 
diff --git a/datafusion/expr-common/src/signature.rs b/datafusion/expr-common/src/signature.rs
index 4dcfa423e3..2043757a49 100644
--- a/datafusion/expr-common/src/signature.rs
+++ b/datafusion/expr-common/src/signature.rs
@@ -105,6 +105,11 @@ pub enum TypeSignature {
     Uniform(usize, Vec<DataType>),
     /// Exact number of arguments of an exact type
     Exact(Vec<DataType>),
+    /// Exactly one argument per listed type, each coercible to the type at the same position
+    /// For example, `Coercible(vec![DataType::Float64])` accepts
+    /// arguments like `vec![DataType::Int32]` or `vec![DataType::Float32]`
+    /// since i32 and f32 can be cast to f64
+    Coercible(Vec<DataType>),
     /// Fixed number of arguments of arbitrary types
     /// If a function takes 0 argument, its `TypeSignature` should be `Any(0)`
     Any(usize),
@@ -188,7 +193,7 @@ impl TypeSignature {
             TypeSignature::Numeric(num) => {
                 vec![format!("Numeric({})", num)]
             }
-            TypeSignature::Exact(types) => {
+            TypeSignature::Exact(types) | TypeSignature::Coercible(types) => {
                 vec![Self::join_types(types, ", ")]
             }
             TypeSignature::Any(arg_count) => {
@@ -300,6 +305,14 @@ impl Signature {
             volatility,
         }
     }
+    /// One argument per entry of `target_types`, each coercible to the type at its position
+    pub fn coercible(target_types: Vec<DataType>, volatility: Volatility) -> 
Self {
+        Self {
+            type_signature: TypeSignature::Coercible(target_types),
+            volatility,
+        }
+    }
+
     /// A specified number of arguments of any type
     pub fn any(arg_count: usize, volatility: Volatility) -> Self {
         Signature {
diff --git a/datafusion/expr-common/src/type_coercion/aggregates.rs b/datafusion/expr-common/src/type_coercion/aggregates.rs
index 40ee596eee..2add9e7c18 100644
--- a/datafusion/expr-common/src/type_coercion/aggregates.rs
+++ b/datafusion/expr-common/src/type_coercion/aggregates.rs
@@ -128,9 +128,11 @@ pub fn check_arg_count(
                 );
             }
         }
-        TypeSignature::UserDefined | TypeSignature::Numeric(_) => {
+        TypeSignature::UserDefined
+        | TypeSignature::Numeric(_)
+        | TypeSignature::Coercible(_) => {
             // User-defined signature is validated in `coerce_types`
-            // Numreic signature is validated in `get_valid_types`
+            // Numeric and Coercible signatures are validated in 
`get_valid_types`
         }
         _ => {
             return internal_err!(
diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs
index b0b14a1a4e..d30d202df0 100644
--- a/datafusion/expr/src/type_coercion/functions.rs
+++ b/datafusion/expr/src/type_coercion/functions.rs
@@ -175,7 +175,14 @@ fn try_coerce_types(
     let mut valid_types = valid_types;
 
     // Well-supported signature that returns exact valid types.
-    if !valid_types.is_empty() && matches!(type_signature, 
TypeSignature::UserDefined) {
+    if !valid_types.is_empty()
+        && matches!(
+            type_signature,
+            TypeSignature::UserDefined
+                | TypeSignature::Numeric(_)
+                | TypeSignature::Coercible(_)
+        )
+    {
         // exact valid types
         assert_eq!(valid_types.len(), 1);
         let valid_types = valid_types.swap_remove(0);
@@ -397,6 +404,30 @@ fn get_valid_types(
 
             vec![vec![valid_type; *number]]
         }
+        TypeSignature::Coercible(target_types) => {
+            if target_types.is_empty() {
+                return plan_err!(
+                    "The signature expected at least one argument but received 
{}",
+                    current_types.len()
+                );
+            }
+            if target_types.len() != current_types.len() {
+                return plan_err!(
+                    "The signature expected {} arguments but received {}",
+                    target_types.len(),
+                    current_types.len()
+                );
+            }
+
+            for (data_type, target_type) in 
current_types.iter().zip(target_types.iter())
+            {
+                if !can_cast_types(data_type, target_type) {
+                    return plan_err!("{data_type} is not coercible to 
{target_type}");
+                }
+            }
+
+            vec![target_types.to_owned()]
+        }
         TypeSignature::Uniform(number, valid_types) => valid_types
             .iter()
             .map(|valid_type| (0..*number).map(|_| 
valid_type.clone()).collect())
diff --git a/datafusion/functions-aggregate/src/stddev.rs b/datafusion/functions-aggregate/src/stddev.rs
index 3534fb5b4d..a25ab5e319 100644
--- a/datafusion/functions-aggregate/src/stddev.rs
+++ b/datafusion/functions-aggregate/src/stddev.rs
@@ -68,7 +68,10 @@ impl Stddev {
     /// Create a new STDDEV aggregate function
     pub fn new() -> Self {
         Self {
-            signature: Signature::numeric(1, Volatility::Immutable),
+            signature: Signature::coercible(
+                vec![DataType::Float64],
+                Volatility::Immutable,
+            ),
             alias: vec!["stddev_samp".to_string()],
         }
     }
@@ -88,11 +91,7 @@ impl AggregateUDFImpl for Stddev {
         &self.signature
     }
 
-    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
-        if !arg_types[0].is_numeric() {
-            return plan_err!("Stddev requires numeric input types");
-        }
-
+    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
         Ok(DataType::Float64)
     }
 
diff --git a/datafusion/functions-aggregate/src/variance.rs b/datafusion/functions-aggregate/src/variance.rs
index f5f2d06e38..367a8669ab 100644
--- a/datafusion/functions-aggregate/src/variance.rs
+++ b/datafusion/functions-aggregate/src/variance.rs
@@ -79,7 +79,10 @@ impl VarianceSample {
     pub fn new() -> Self {
         Self {
             aliases: vec![String::from("var_sample"), 
String::from("var_samp")],
-            signature: Signature::numeric(1, Volatility::Immutable),
+            signature: Signature::coercible(
+                vec![DataType::Float64],
+                Volatility::Immutable,
+            ),
         }
     }
 }
@@ -97,11 +100,7 @@ impl AggregateUDFImpl for VarianceSample {
         &self.signature
     }
 
-    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
-        if !arg_types[0].is_numeric() {
-            return plan_err!("Variance requires numeric input types");
-        }
-
+    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
         Ok(DataType::Float64)
     }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to