This is an automated email from the ASF dual-hosted git repository.
jayzhan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 8db30e25d6 Introduce `Signature::Coercible` (#12275)
8db30e25d6 is described below
commit 8db30e25d6fe65a9779d237cf48aea9aee297502
Author: Jay Zhan <[email protected]>
AuthorDate: Tue Sep 3 07:57:40 2024 +0800
Introduce `Signature::Coercible` (#12275)
* introduce signature float
Signed-off-by: jayzhan211 <[email protected]>
* fix test
Signed-off-by: jayzhan211 <[email protected]>
* change float to coercible
Signed-off-by: jayzhan211 <[email protected]>
* fmt
Signed-off-by: jayzhan211 <[email protected]>
* rm test
Signed-off-by: jayzhan211 <[email protected]>
* add comment
Signed-off-by: jayzhan211 <[email protected]>
* typo
Signed-off-by: jayzhan211 <[email protected]>
---------
Signed-off-by: jayzhan211 <[email protected]>
---
datafusion/core/src/dataframe/mod.rs | 6 ++--
datafusion/expr-common/src/signature.rs | 15 +++++++++-
.../expr-common/src/type_coercion/aggregates.rs | 6 ++--
datafusion/expr/src/type_coercion/functions.rs | 33 +++++++++++++++++++++-
datafusion/functions-aggregate/src/stddev.rs | 11 ++++----
datafusion/functions-aggregate/src/variance.rs | 11 ++++----
6 files changed, 64 insertions(+), 18 deletions(-)
diff --git a/datafusion/core/src/dataframe/mod.rs
b/datafusion/core/src/dataframe/mod.rs
index b8c0bd9d74..2138bd1294 100644
--- a/datafusion/core/src/dataframe/mod.rs
+++ b/datafusion/core/src/dataframe/mod.rs
@@ -2428,7 +2428,8 @@ mod tests {
let df: Vec<RecordBatch> = df.select(aggr_expr)?.collect().await?;
assert_batches_sorted_eq!(
-
["+-------------+----------+-----------------+---------------+--------+-----+------+----+------+",
+ [
+
"+-------------+----------+-----------------+---------------+--------+-----+------+----+------+",
"| first_value | last_val | approx_distinct | approx_median |
median | max | min | c2 | c3 |",
"+-------------+----------+-----------------+---------------+--------+-----+------+----+------+",
"| | | | |
| | | 1 | -85 |",
@@ -2452,7 +2453,8 @@ mod tests {
"| -85 | 45 | 8 | -34 |
45 | 83 | -85 | 3 | -72 |",
"| -85 | 65 | 17 | -17 |
65 | 83 | -101 | 5 | -101 |",
"| -85 | 83 | 5 | -25 |
83 | 83 | -85 | 2 | -48 |",
-
"+-------------+----------+-----------------+---------------+--------+-----+------+----+------+"],
+
"+-------------+----------+-----------------+---------------+--------+-----+------+----+------+",
+ ],
&df
);
diff --git a/datafusion/expr-common/src/signature.rs
b/datafusion/expr-common/src/signature.rs
index 4dcfa423e3..2043757a49 100644
--- a/datafusion/expr-common/src/signature.rs
+++ b/datafusion/expr-common/src/signature.rs
@@ -105,6 +105,11 @@ pub enum TypeSignature {
Uniform(usize, Vec<DataType>),
/// Exact number of arguments of an exact type
Exact(Vec<DataType>),
+ /// A fixed number of arguments, each coerced in order to the matching target type.
+ /// For example, `Coercible(vec![DataType::Float64])` accepts
+ /// arguments like `vec![DataType::Int32]` or `vec![DataType::Float32]`
+ /// since i32 and f32 can be cast to f64
+ Coercible(Vec<DataType>),
/// Fixed number of arguments of arbitrary types
/// If a function takes 0 argument, its `TypeSignature` should be `Any(0)`
Any(usize),
@@ -188,7 +193,7 @@ impl TypeSignature {
TypeSignature::Numeric(num) => {
vec![format!("Numeric({})", num)]
}
- TypeSignature::Exact(types) => {
+ TypeSignature::Exact(types) | TypeSignature::Coercible(types) => {
vec![Self::join_types(types, ", ")]
}
TypeSignature::Any(arg_count) => {
@@ -300,6 +305,14 @@ impl Signature {
volatility,
}
}
+ /// Signature whose arguments are coerced, in order, to the given `target_types`
+ pub fn coercible(target_types: Vec<DataType>, volatility: Volatility) ->
Self {
+ Self {
+ type_signature: TypeSignature::Coercible(target_types),
+ volatility,
+ }
+ }
+
/// A specified number of arguments of any type
pub fn any(arg_count: usize, volatility: Volatility) -> Self {
Signature {
diff --git a/datafusion/expr-common/src/type_coercion/aggregates.rs
b/datafusion/expr-common/src/type_coercion/aggregates.rs
index 40ee596eee..2add9e7c18 100644
--- a/datafusion/expr-common/src/type_coercion/aggregates.rs
+++ b/datafusion/expr-common/src/type_coercion/aggregates.rs
@@ -128,9 +128,11 @@ pub fn check_arg_count(
);
}
}
- TypeSignature::UserDefined | TypeSignature::Numeric(_) => {
+ TypeSignature::UserDefined
+ | TypeSignature::Numeric(_)
+ | TypeSignature::Coercible(_) => {
// User-defined signature is validated in `coerce_types`
- // Numreic signature is validated in `get_valid_types`
+ // Numeric and Coercible signatures are validated in
`get_valid_types`
}
_ => {
return internal_err!(
diff --git a/datafusion/expr/src/type_coercion/functions.rs
b/datafusion/expr/src/type_coercion/functions.rs
index b0b14a1a4e..d30d202df0 100644
--- a/datafusion/expr/src/type_coercion/functions.rs
+++ b/datafusion/expr/src/type_coercion/functions.rs
@@ -175,7 +175,14 @@ fn try_coerce_types(
let mut valid_types = valid_types;
// Well-supported signature that returns exact valid types.
- if !valid_types.is_empty() && matches!(type_signature,
TypeSignature::UserDefined) {
+ if !valid_types.is_empty()
+ && matches!(
+ type_signature,
+ TypeSignature::UserDefined
+ | TypeSignature::Numeric(_)
+ | TypeSignature::Coercible(_)
+ )
+ {
// exact valid types
assert_eq!(valid_types.len(), 1);
let valid_types = valid_types.swap_remove(0);
@@ -397,6 +404,30 @@ fn get_valid_types(
vec![vec![valid_type; *number]]
}
+ TypeSignature::Coercible(target_types) => {
+ if target_types.is_empty() {
+ return plan_err!(
+ "The signature expected at least one argument but received
{}",
+ current_types.len()
+ );
+ }
+ if target_types.len() != current_types.len() {
+ return plan_err!(
+ "The signature expected {} arguments but received {}",
+ target_types.len(),
+ current_types.len()
+ );
+ }
+
+ for (data_type, target_type) in
current_types.iter().zip(target_types.iter())
+ {
+ if !can_cast_types(data_type, target_type) {
+ return plan_err!("{data_type} is not coercible to
{target_type}");
+ }
+ }
+
+ vec![target_types.to_owned()]
+ }
TypeSignature::Uniform(number, valid_types) => valid_types
.iter()
.map(|valid_type| (0..*number).map(|_|
valid_type.clone()).collect())
diff --git a/datafusion/functions-aggregate/src/stddev.rs
b/datafusion/functions-aggregate/src/stddev.rs
index 3534fb5b4d..a25ab5e319 100644
--- a/datafusion/functions-aggregate/src/stddev.rs
+++ b/datafusion/functions-aggregate/src/stddev.rs
@@ -68,7 +68,10 @@ impl Stddev {
/// Create a new STDDEV aggregate function
pub fn new() -> Self {
Self {
- signature: Signature::numeric(1, Volatility::Immutable),
+ signature: Signature::coercible(
+ vec![DataType::Float64],
+ Volatility::Immutable,
+ ),
alias: vec!["stddev_samp".to_string()],
}
}
@@ -88,11 +91,7 @@ impl AggregateUDFImpl for Stddev {
&self.signature
}
- fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
- if !arg_types[0].is_numeric() {
- return plan_err!("Stddev requires numeric input types");
- }
-
+ fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
Ok(DataType::Float64)
}
diff --git a/datafusion/functions-aggregate/src/variance.rs
b/datafusion/functions-aggregate/src/variance.rs
index f5f2d06e38..367a8669ab 100644
--- a/datafusion/functions-aggregate/src/variance.rs
+++ b/datafusion/functions-aggregate/src/variance.rs
@@ -79,7 +79,10 @@ impl VarianceSample {
pub fn new() -> Self {
Self {
aliases: vec![String::from("var_sample"),
String::from("var_samp")],
- signature: Signature::numeric(1, Volatility::Immutable),
+ signature: Signature::coercible(
+ vec![DataType::Float64],
+ Volatility::Immutable,
+ ),
}
}
}
@@ -97,11 +100,7 @@ impl AggregateUDFImpl for VarianceSample {
&self.signature
}
- fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
- if !arg_types[0].is_numeric() {
- return plan_err!("Variance requires numeric input types");
- }
-
+ fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
Ok(DataType::Float64)
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]