This is an automated email from the ASF dual-hosted git repository.
findepi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 070517a87f Derive UDF equality from PartialEq, Hash (#16842)
070517a87f is described below
commit 070517a87f15fc9e5f64423e29ee155d4a6a868d
Author: Piotr Findeisen <[email protected]>
AuthorDate: Fri Jul 25 21:21:44 2025 +0200
Derive UDF equality from PartialEq, Hash (#16842)
* Derive UDF equality from PartialEq, Hash
Reduce boilerplate in cases where the implementation of
`{ScalarUDFImpl,AggregateUDFImpl,WindowUDFImpl}::{equals,hash_value}` can
be derived using the standard `PartialEq` and `Hash` traits.
This reduces code complexity.
While valuable on its own, this also prepares for more automatic
derivation of UDF equals/hash and/or removal of the default implementations
(which are currently error-prone).
* udf_equals_hash example
* test udf_equals_hash
* empty: roll the dice 🎲
---
.../user_defined/user_defined_scalar_functions.rs | 171 +++++--------------
datafusion/expr/src/async_udf.rs | 38 +++--
datafusion/expr/src/expr_fn.rs | 67 ++++----
datafusion/expr/src/udf.rs | 33 ++--
datafusion/expr/src/utils.rs | 182 ++++++++++++++++++++-
datafusion/ffi/src/udf/mod.rs | 66 ++++----
datafusion/proto/tests/cases/mod.rs | 32 +---
datafusion/sql/tests/sql_integration.rs | 38 +----
8 files changed, 335 insertions(+), 292 deletions(-)
diff --git
a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
index c5f9bdeb69..dd8283613a 100644
--- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
+++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
@@ -17,7 +17,7 @@
use std::any::Any;
use std::collections::HashMap;
-use std::hash::{DefaultHasher, Hash, Hasher};
+use std::hash::{Hash, Hasher};
use std::sync::Arc;
use arrow::array::{as_string_array, create_array, record_batch, Int8Array,
UInt64Array};
@@ -43,9 +43,9 @@ use datafusion_common::{
use datafusion_expr::expr::FieldMetadata;
use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};
use datafusion_expr::{
- lit_with_metadata, Accumulator, ColumnarValue, CreateFunction,
CreateFunctionBody,
- LogicalPlanBuilder, OperateFunctionArg, ReturnFieldArgs,
ScalarFunctionArgs,
- ScalarUDF, ScalarUDFImpl, Signature, Volatility,
+ lit_with_metadata, udf_equals_hash, Accumulator, ColumnarValue,
CreateFunction,
+ CreateFunctionBody, LogicalPlanBuilder, OperateFunctionArg,
ReturnFieldArgs,
+ ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, Volatility,
};
use datafusion_functions_nested::range::range_udf;
use parking_lot::Mutex;
@@ -181,6 +181,7 @@ async fn scalar_udf() -> Result<()> {
Ok(())
}
+#[derive(PartialEq, Hash)]
struct Simple0ArgsScalarUDF {
name: String,
signature: Signature,
@@ -218,33 +219,7 @@ impl ScalarUDFImpl for Simple0ArgsScalarUDF {
Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(100))))
}
- fn equals(&self, other: &dyn ScalarUDFImpl) -> bool {
- let Some(other) = other.as_any().downcast_ref::<Self>() else {
- return false;
- };
- let Self {
- name,
- signature,
- return_type,
- } = self;
- name == &other.name
- && signature == &other.signature
- && return_type == &other.return_type
- }
-
- fn hash_value(&self) -> u64 {
- let Self {
- name,
- signature,
- return_type,
- } = self;
- let mut hasher = DefaultHasher::new();
- std::any::type_name::<Self>().hash(&mut hasher);
- name.hash(&mut hasher);
- signature.hash(&mut hasher);
- return_type.hash(&mut hasher);
- hasher.finish()
- }
+ udf_equals_hash!(ScalarUDFImpl);
}
#[tokio::test]
@@ -517,7 +492,7 @@ async fn test_user_defined_functions_with_alias() ->
Result<()> {
}
/// Volatile UDF that should append a different value to each row
-#[derive(Debug)]
+#[derive(Debug, PartialEq, Hash)]
struct AddIndexToStringVolatileScalarUDF {
name: String,
signature: Signature,
@@ -586,33 +561,7 @@ impl ScalarUDFImpl for AddIndexToStringVolatileScalarUDF {
Ok(ColumnarValue::Array(Arc::new(StringArray::from(answer))))
}
- fn equals(&self, other: &dyn ScalarUDFImpl) -> bool {
- let Some(other) = other.as_any().downcast_ref::<Self>() else {
- return false;
- };
- let Self {
- name,
- signature,
- return_type,
- } = self;
- name == &other.name
- && signature == &other.signature
- && return_type == &other.return_type
- }
-
- fn hash_value(&self) -> u64 {
- let Self {
- name,
- signature,
- return_type,
- } = self;
- let mut hasher = DefaultHasher::new();
- std::any::type_name::<Self>().hash(&mut hasher);
- name.hash(&mut hasher);
- signature.hash(&mut hasher);
- return_type.hash(&mut hasher);
- hasher.finish()
- }
+ udf_equals_hash!(ScalarUDFImpl);
}
#[tokio::test]
@@ -992,7 +941,7 @@ impl FunctionFactory for CustomFunctionFactory {
//
// it also defines custom [ScalarUDFImpl::simplify()]
// to replace ScalarUDF expression with one instance contains.
-#[derive(Debug)]
+#[derive(Debug, PartialEq, Hash)]
struct ScalarFunctionWrapper {
name: String,
expr: Expr,
@@ -1031,37 +980,7 @@ impl ScalarUDFImpl for ScalarFunctionWrapper {
Ok(ExprSimplifyResult::Simplified(replacement))
}
- fn equals(&self, other: &dyn ScalarUDFImpl) -> bool {
- let Some(other) = other.as_any().downcast_ref::<Self>() else {
- return false;
- };
- let Self {
- name,
- expr,
- signature,
- return_type,
- } = self;
- name == &other.name
- && expr == &other.expr
- && signature == &other.signature
- && return_type == &other.return_type
- }
-
- fn hash_value(&self) -> u64 {
- let Self {
- name,
- expr,
- signature,
- return_type,
- } = self;
- let mut hasher = DefaultHasher::new();
- std::any::type_name::<Self>().hash(&mut hasher);
- name.hash(&mut hasher);
- expr.hash(&mut hasher);
- signature.hash(&mut hasher);
- return_type.hash(&mut hasher);
- hasher.finish()
- }
+ udf_equals_hash!(ScalarUDFImpl);
}
impl ScalarFunctionWrapper {
@@ -1296,6 +1215,21 @@ struct MyRegexUdf {
regex: Regex,
}
+impl PartialEq for MyRegexUdf {
+ fn eq(&self, other: &Self) -> bool {
+ let Self { signature, regex } = self;
+ signature == &other.signature && regex.as_str() == other.regex.as_str()
+ }
+}
+
+impl Hash for MyRegexUdf {
+ fn hash<H: Hasher>(&self, state: &mut H) {
+ let Self { signature, regex } = self;
+ signature.hash(state);
+ regex.as_str().hash(state);
+ }
+}
+
impl MyRegexUdf {
fn new(pattern: &str) -> Self {
Self {
@@ -1348,19 +1282,7 @@ impl ScalarUDFImpl for MyRegexUdf {
}
}
- fn equals(&self, other: &dyn ScalarUDFImpl) -> bool {
- if let Some(other) = other.as_any().downcast_ref::<MyRegexUdf>() {
- self.regex.as_str() == other.regex.as_str()
- } else {
- false
- }
- }
-
- fn hash_value(&self) -> u64 {
- let hasher = &mut DefaultHasher::new();
- self.regex.as_str().hash(hasher);
- hasher.finish()
- }
+ udf_equals_hash!(ScalarUDFImpl);
}
#[tokio::test]
@@ -1458,13 +1380,25 @@ async fn plan_and_collect(ctx: &SessionContext, sql:
&str) -> Result<Vec<RecordB
ctx.sql(sql).await?.collect().await
}
-#[derive(Debug)]
+#[derive(Debug, PartialEq)]
struct MetadataBasedUdf {
name: String,
signature: Signature,
metadata: HashMap<String, String>,
}
+impl Hash for MetadataBasedUdf {
+ fn hash<H: Hasher>(&self, state: &mut H) {
+ let Self {
+ name,
+ signature,
+ metadata: _, // unhashable
+ } = self;
+ name.hash(state);
+ signature.hash(state);
+ }
+}
+
impl MetadataBasedUdf {
fn new(metadata: HashMap<String, String>) -> Self {
// The name we return must be unique. Otherwise we will not call
distinct
@@ -1537,32 +1471,7 @@ impl ScalarUDFImpl for MetadataBasedUdf {
}
}
- fn equals(&self, other: &dyn ScalarUDFImpl) -> bool {
- let Some(other) = other.as_any().downcast_ref::<Self>() else {
- return false;
- };
- let Self {
- name,
- signature,
- metadata,
- } = self;
- name == &other.name
- && signature == &other.signature
- && metadata == &other.metadata
- }
-
- fn hash_value(&self) -> u64 {
- let Self {
- name,
- signature,
- metadata: _, // unhashable
- } = self;
- let mut hasher = DefaultHasher::new();
- std::any::type_name::<Self>().hash(&mut hasher);
- name.hash(&mut hasher);
- signature.hash(&mut hasher);
- hasher.finish()
- }
+ udf_equals_hash!(ScalarUDFImpl);
}
#[tokio::test]
diff --git a/datafusion/expr/src/async_udf.rs b/datafusion/expr/src/async_udf.rs
index 24ed124bb2..753ad7b778 100644
--- a/datafusion/expr/src/async_udf.rs
+++ b/datafusion/expr/src/async_udf.rs
@@ -15,7 +15,10 @@
// specific language governing permissions and limitations
// under the License.
-use crate::{ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl};
+use crate::utils::{arc_ptr_eq, arc_ptr_hash};
+use crate::{
+ udf_equals_hash, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF,
ScalarUDFImpl,
+};
use arrow::array::ArrayRef;
use arrow::datatypes::{DataType, FieldRef};
use async_trait::async_trait;
@@ -26,7 +29,7 @@ use datafusion_expr_common::columnar_value::ColumnarValue;
use datafusion_expr_common::signature::Signature;
use std::any::Any;
use std::fmt::{Debug, Display};
-use std::hash::{DefaultHasher, Hash, Hasher};
+use std::hash::{Hash, Hasher};
use std::sync::Arc;
/// A scalar UDF that can invoke using async methods
@@ -62,6 +65,21 @@ pub struct AsyncScalarUDF {
inner: Arc<dyn AsyncScalarUDFImpl>,
}
+impl PartialEq for AsyncScalarUDF {
+ fn eq(&self, other: &Self) -> bool {
+ let Self { inner } = self;
+ // TODO when MSRV >= 1.86.0, switch to
`inner.equals(other.inner.as_ref())` leveraging trait upcasting.
+ arc_ptr_eq(inner, &other.inner)
+ }
+}
+
+impl Hash for AsyncScalarUDF {
+ fn hash<H: Hasher>(&self, state: &mut H) {
+ let Self { inner } = self;
+ arc_ptr_hash(inner, state);
+ }
+}
+
impl AsyncScalarUDF {
pub fn new(inner: Arc<dyn AsyncScalarUDFImpl>) -> Self {
Self { inner }
@@ -113,21 +131,7 @@ impl ScalarUDFImpl for AsyncScalarUDF {
internal_err!("async functions should not be called directly")
}
- fn equals(&self, other: &dyn ScalarUDFImpl) -> bool {
- let Some(other) = other.as_any().downcast_ref::<Self>() else {
- return false;
- };
- let Self { inner } = self;
- // TODO when MSRV >= 1.86.0, switch to
`inner.equals(other.inner.as_ref())` leveraging trait upcasting
- Arc::ptr_eq(inner, &other.inner)
- }
-
- fn hash_value(&self) -> u64 {
- let Self { inner } = self;
- let mut hasher = DefaultHasher::new();
- Arc::as_ptr(inner).hash(&mut hasher);
- hasher.finish()
- }
+ udf_equals_hash!(ScalarUDFImpl);
}
impl Display for AsyncScalarUDF {
diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs
index c0351a9dca..1d8d183807 100644
--- a/datafusion/expr/src/expr_fn.rs
+++ b/datafusion/expr/src/expr_fn.rs
@@ -26,10 +26,11 @@ use crate::function::{
StateFieldsArgs,
};
use crate::select_expr::SelectExpr;
+use crate::utils::{arc_ptr_eq, arc_ptr_hash};
use crate::{
conditional_expressions::CaseBuilder, expr::Sort, logical_plan::Subquery,
- AggregateUDF, Expr, LogicalPlan, Operator, PartitionEvaluator,
ScalarFunctionArgs,
- ScalarFunctionImplementation, ScalarUDF, Signature, Volatility,
+ udf_equals_hash, AggregateUDF, Expr, LogicalPlan, Operator,
PartitionEvaluator,
+ ScalarFunctionArgs, ScalarFunctionImplementation, ScalarUDF, Signature,
Volatility,
};
use crate::{
AggregateUDFImpl, ColumnarValue, ScalarUDFImpl, WindowFrame, WindowUDF,
WindowUDFImpl,
@@ -409,6 +410,36 @@ pub struct SimpleScalarUDF {
fun: ScalarFunctionImplementation,
}
+impl PartialEq for SimpleScalarUDF {
+ fn eq(&self, other: &Self) -> bool {
+ let Self {
+ name,
+ signature,
+ return_type,
+ fun,
+ } = self;
+ name == &other.name
+ && signature == &other.signature
+ && return_type == &other.return_type
+ && arc_ptr_eq(fun, &other.fun)
+ }
+}
+
+impl Hash for SimpleScalarUDF {
+ fn hash<H: Hasher>(&self, state: &mut H) {
+ let Self {
+ name,
+ signature,
+ return_type,
+ fun,
+ } = self;
+ name.hash(state);
+ signature.hash(state);
+ return_type.hash(state);
+ arc_ptr_hash(fun, state);
+ }
+}
+
impl Debug for SimpleScalarUDF {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
f.debug_struct("SimpleScalarUDF")
@@ -476,37 +507,7 @@ impl ScalarUDFImpl for SimpleScalarUDF {
(self.fun)(&args.args)
}
- fn equals(&self, other: &dyn ScalarUDFImpl) -> bool {
- let Some(other) = other.as_any().downcast_ref::<Self>() else {
- return false;
- };
- let Self {
- name,
- signature,
- return_type,
- fun,
- } = self;
- name == &other.name
- && signature == &other.signature
- && return_type == &other.return_type
- && Arc::ptr_eq(fun, &other.fun)
- }
-
- fn hash_value(&self) -> u64 {
- let Self {
- name,
- signature,
- return_type,
- fun,
- } = self;
- let mut hasher = DefaultHasher::new();
- std::any::type_name::<Self>().hash(&mut hasher);
- name.hash(&mut hasher);
- signature.hash(&mut hasher);
- return_type.hash(&mut hasher);
- Arc::as_ptr(fun).hash(&mut hasher);
- hasher.finish()
- }
+ udf_equals_hash!(ScalarUDFImpl);
}
/// Creates a new UDAF with a specific signature, state type and return type.
diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs
index 3a94981ae4..171c4e041f 100644
--- a/datafusion/expr/src/udf.rs
+++ b/datafusion/expr/src/udf.rs
@@ -21,7 +21,7 @@ use crate::async_udf::AsyncScalarUDF;
use crate::expr::schema_name_from_exprs_comma_separated_without_space;
use crate::simplify::{ExprSimplifyResult, SimplifyInfo};
use crate::sort_properties::{ExprProperties, SortProperties};
-use crate::{ColumnarValue, Documentation, Expr, Signature};
+use crate::{udf_equals_hash, ColumnarValue, Documentation, Expr, Signature};
use arrow::datatypes::{DataType, Field, FieldRef};
use datafusion_common::{not_impl_err, ExprSchema, Result, ScalarValue};
use datafusion_expr_common::interval_arithmetic::Interval;
@@ -747,6 +747,21 @@ struct AliasedScalarUDFImpl {
aliases: Vec<String>,
}
+impl PartialEq for AliasedScalarUDFImpl {
+ fn eq(&self, other: &Self) -> bool {
+ let Self { inner, aliases } = self;
+ inner.equals(other.inner.as_ref()) && aliases == &other.aliases
+ }
+}
+
+impl Hash for AliasedScalarUDFImpl {
+ fn hash<H: Hasher>(&self, state: &mut H) {
+ let Self { inner, aliases } = self;
+ inner.hash_value().hash(state);
+ aliases.hash(state);
+ }
+}
+
impl AliasedScalarUDFImpl {
pub fn new(
inner: Arc<dyn ScalarUDFImpl>,
@@ -831,21 +846,7 @@ impl ScalarUDFImpl for AliasedScalarUDFImpl {
self.inner.coerce_types(arg_types)
}
- fn equals(&self, other: &dyn ScalarUDFImpl) -> bool {
- if let Some(other) =
other.as_any().downcast_ref::<AliasedScalarUDFImpl>() {
- self.inner.equals(other.inner.as_ref()) && self.aliases ==
other.aliases
- } else {
- false
- }
- }
-
- fn hash_value(&self) -> u64 {
- let hasher = &mut DefaultHasher::new();
- std::any::type_name::<Self>().hash(hasher);
- self.inner.hash_value().hash(hasher);
- self.aliases.hash(hasher);
- hasher.finish()
- }
+ udf_equals_hash!(ScalarUDFImpl);
fn documentation(&self) -> Option<&Documentation> {
self.inner.documentation()
diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs
index 8950f5e450..e554152328 100644
--- a/datafusion/expr/src/utils.rs
+++ b/datafusion/expr/src/utils.rs
@@ -19,6 +19,7 @@
use std::cmp::Ordering;
use std::collections::{BTreeSet, HashSet};
+use std::hash::Hasher;
use std::sync::Arc;
use crate::expr::{Alias, Sort, WildcardOptions, WindowFunctionParams};
@@ -1260,6 +1261,94 @@ pub fn collect_subquery_cols(
})
}
+/// Generates implementation of `equals` and `hash_value` methods for a trait,
delegating
+/// to [`PartialEq`] and [`Hash`] implementations on Self.
+/// Meant to be used with traits representing user-defined functions (UDFs).
+///
+/// Example showing generation of [`ScalarUDFImpl::equals`] and
[`ScalarUDFImpl::hash_value`]
+/// implementations.
+///
+/// ```
+/// # use arrow::datatypes::DataType;
+/// # use datafusion_expr::{udf_equals_hash, ScalarFunctionArgs,
ScalarUDFImpl};
+/// # use datafusion_expr_common::columnar_value::ColumnarValue;
+/// # use datafusion_expr_common::signature::Signature;
+/// # use std::any::Any;
+///
+/// // Implementing PartialEq & Hash is a prerequisite for using this macro,
+/// // but the implementation can be derived.
+/// #[derive(Debug, PartialEq, Hash)]
+/// struct VarcharToTimestampTz {
+/// safe: bool,
+/// }
+///
+/// impl ScalarUDFImpl for VarcharToTimestampTz {
+/// /* other methods omitted for brevity */
+/// # fn as_any(&self) -> &dyn Any {
+/// # self
+/// # }
+/// #
+/// # fn name(&self) -> &str {
+/// # "varchar_to_timestamp_tz"
+/// # }
+/// #
+/// # fn signature(&self) -> &Signature {
+/// # todo!()
+/// # }
+/// #
+/// # fn return_type(
+/// # &self,
+/// # _arg_types: &[DataType],
+/// # ) -> datafusion_common::Result<DataType> {
+/// # todo!()
+/// # }
+/// #
+/// # fn invoke_with_args(
+/// # &self,
+/// # args: ScalarFunctionArgs,
+/// # ) -> datafusion_common::Result<ColumnarValue> {
+/// # todo!()
+/// # }
+/// #
+/// udf_equals_hash!(ScalarUDFImpl);
+/// }
+/// ```
+///
+/// [`ScalarUDFImpl::equals`]: crate::ScalarUDFImpl::equals
+/// [`ScalarUDFImpl::hash_value`]: crate::ScalarUDFImpl::hash_value
+#[macro_export]
+macro_rules! udf_equals_hash {
+ ($udf_type:tt) => {
+ fn equals(&self, other: &dyn $udf_type) -> bool {
+ use ::core::any::Any;
+ use ::core::cmp::PartialEq;
+ let Some(other) = <dyn Any +
'static>::downcast_ref::<Self>(other.as_any())
+ else {
+ return false;
+ };
+ PartialEq::eq(self, other)
+ }
+
+ fn hash_value(&self) -> u64 {
+ use ::std::any::type_name;
+ use ::std::hash::{DefaultHasher, Hash, Hasher};
+ let hasher = &mut DefaultHasher::new();
+ type_name::<Self>().hash(hasher);
+ Hash::hash(self, hasher);
+ Hasher::finish(hasher)
+ }
+ };
+}
+
+pub fn arc_ptr_eq<T: ?Sized>(a: &Arc<T>, b: &Arc<T>) -> bool {
+ // Not necessarily equivalent to `Arc::ptr_eq` for fat pointers.
+ std::ptr::eq(Arc::as_ptr(a), Arc::as_ptr(b))
+}
+
+pub fn arc_ptr_hash<T: ?Sized>(a: &Arc<T>, hasher: &mut impl Hasher) {
+ std::ptr::hash(Arc::as_ptr(a), hasher)
+}
+
#[cfg(test)]
mod tests {
use super::*;
@@ -1268,9 +1357,13 @@ mod tests {
expr::WindowFunction,
expr_vec_fmt, grouping_set, lit, rollup,
test::function_stub::{max_udaf, min_udaf, sum_udaf},
- Cast, ExprFunctionExt, WindowFunctionDefinition,
+ Cast, ExprFunctionExt, ScalarFunctionArgs, ScalarUDFImpl,
+ WindowFunctionDefinition,
};
use arrow::datatypes::{UnionFields, UnionMode};
+ use datafusion_expr_common::columnar_value::ColumnarValue;
+ use datafusion_expr_common::signature::Volatility;
+ use std::any::Any;
#[test]
fn test_group_window_expr_by_sort_keys_empty_case() -> Result<()> {
@@ -1690,4 +1783,91 @@ mod tests {
DataType::List(Arc::new(Field::new("my_union", union_type, true)));
assert!(!can_hash(&list_union_type));
}
+
+ #[test]
+ fn test_udf_equals_hash() {
+ #[derive(Debug, PartialEq, Hash)]
+ struct StatefulFunctionWithEqHash {
+ signature: Signature,
+ state: bool,
+ }
+ impl ScalarUDFImpl for StatefulFunctionWithEqHash {
+ fn as_any(&self) -> &dyn Any {
+ self
+ }
+ fn name(&self) -> &str {
+ "StatefulFunctionWithEqHash"
+ }
+ fn signature(&self) -> &Signature {
+ &self.signature
+ }
+ fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType>
{
+ todo!()
+ }
+ fn invoke_with_args(
+ &self,
+ _args: ScalarFunctionArgs,
+ ) -> Result<ColumnarValue> {
+ todo!()
+ }
+ }
+
+ #[derive(Debug, PartialEq, Hash)]
+ struct StatefulFunctionWithEqHashWithUdfEqualsHash {
+ signature: Signature,
+ state: bool,
+ }
+ impl ScalarUDFImpl for StatefulFunctionWithEqHashWithUdfEqualsHash {
+ fn as_any(&self) -> &dyn Any {
+ self
+ }
+ fn name(&self) -> &str {
+ "StatefulFunctionWithEqHashWithUdfEqualsHash"
+ }
+ fn signature(&self) -> &Signature {
+ &self.signature
+ }
+ fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType>
{
+ todo!()
+ }
+ fn invoke_with_args(
+ &self,
+ _args: ScalarFunctionArgs,
+ ) -> Result<ColumnarValue> {
+ todo!()
+ }
+ udf_equals_hash!(ScalarUDFImpl);
+ }
+
+ let signature = Signature::exact(vec![DataType::Utf8],
Volatility::Immutable);
+
+ // Sadly, without `udf_equals_hash!` macro, the equals and hash_value
ignore state fields,
+ // even though the struct implements `PartialEq` and `Hash`.
+ let a: Box<dyn ScalarUDFImpl> = Box::new(StatefulFunctionWithEqHash {
+ signature: signature.clone(),
+ state: true,
+ });
+ let b: Box<dyn ScalarUDFImpl> = Box::new(StatefulFunctionWithEqHash {
+ signature: signature.clone(),
+ state: false,
+ });
+ assert!(a.equals(b.as_ref()));
+ assert_eq!(a.hash_value(), b.hash_value());
+
+ // With the `udf_equals_hash!` macro, equals and hash_value take the state into account,
+ // because they delegate to the struct's `PartialEq` and `Hash` implementations.
+ let a: Box<dyn ScalarUDFImpl> =
+ Box::new(StatefulFunctionWithEqHashWithUdfEqualsHash {
+ signature: signature.clone(),
+ state: true,
+ });
+ let b: Box<dyn ScalarUDFImpl> =
+ Box::new(StatefulFunctionWithEqHashWithUdfEqualsHash {
+ signature: signature.clone(),
+ state: false,
+ });
+ assert!(!a.equals(b.as_ref()));
+ // The two hashes could collide in principle, but it's very unlikely that boolean true and false hash the same
+ assert_ne!(a.hash_value(), b.hash_value());
+ }
}
diff --git a/datafusion/ffi/src/udf/mod.rs b/datafusion/ffi/src/udf/mod.rs
index 09cd0df128..1c835bd3ec 100644
--- a/datafusion/ffi/src/udf/mod.rs
+++ b/datafusion/ffi/src/udf/mod.rs
@@ -32,7 +32,7 @@ use arrow::{
ffi::{from_ffi, to_ffi, FFI_ArrowSchema},
};
use arrow_schema::FieldRef;
-use datafusion::logical_expr::ReturnFieldArgs;
+use datafusion::logical_expr::{udf_equals_hash, ReturnFieldArgs};
use datafusion::{
error::DataFusionError,
logical_expr::type_coercion::functions::data_types_with_scalar_udf,
@@ -46,7 +46,7 @@ use datafusion::{
use return_type_args::{
FFI_ReturnFieldArgs, ForeignReturnFieldArgs, ForeignReturnFieldArgsOwned,
};
-use std::hash::{DefaultHasher, Hash, Hasher};
+use std::hash::{Hash, Hasher};
use std::{ffi::c_void, sync::Arc};
pub mod return_type_args;
@@ -287,6 +287,36 @@ pub struct ForeignScalarUDF {
unsafe impl Send for ForeignScalarUDF {}
unsafe impl Sync for ForeignScalarUDF {}
+impl PartialEq for ForeignScalarUDF {
+ fn eq(&self, other: &Self) -> bool {
+ let Self {
+ name,
+ aliases,
+ udf,
+ signature,
+ } = self;
+ name == &other.name
+ && aliases == &other.aliases
+ && std::ptr::eq(udf, &other.udf)
+ && signature == &other.signature
+ }
+}
+
+impl Hash for ForeignScalarUDF {
+ fn hash<H: Hasher>(&self, state: &mut H) {
+ let Self {
+ name,
+ aliases,
+ udf,
+ signature,
+ } = self;
+ name.hash(state);
+ aliases.hash(state);
+ std::ptr::hash(udf, state);
+ signature.hash(state);
+ }
+}
+
impl TryFrom<&FFI_ScalarUDF> for ForeignScalarUDF {
type Error = DataFusionError;
@@ -409,37 +439,7 @@ impl ScalarUDFImpl for ForeignScalarUDF {
}
}
- fn equals(&self, other: &dyn ScalarUDFImpl) -> bool {
- let Some(other) = other.as_any().downcast_ref::<Self>() else {
- return false;
- };
- let Self {
- name,
- aliases,
- udf,
- signature,
- } = self;
- name == &other.name
- && aliases == &other.aliases
- && std::ptr::eq(udf, &other.udf)
- && signature == &other.signature
- }
-
- fn hash_value(&self) -> u64 {
- let Self {
- name,
- aliases,
- udf,
- signature,
- } = self;
- let mut hasher = DefaultHasher::new();
- std::any::type_name::<Self>().hash(&mut hasher);
- name.hash(&mut hasher);
- aliases.hash(&mut hasher);
- std::ptr::hash(udf, &mut hasher);
- signature.hash(&mut hasher);
- hasher.finish()
- }
+ udf_equals_hash!(ScalarUDFImpl);
}
#[cfg(test)]
diff --git a/datafusion/proto/tests/cases/mod.rs
b/datafusion/proto/tests/cases/mod.rs
index 6158b727df..ab08f5b9be 100644
--- a/datafusion/proto/tests/cases/mod.rs
+++ b/datafusion/proto/tests/cases/mod.rs
@@ -20,8 +20,8 @@ use datafusion::logical_expr::ColumnarValue;
use datafusion_common::plan_err;
use datafusion_expr::function::AccumulatorArgs;
use datafusion_expr::{
- Accumulator, AggregateUDFImpl, PartitionEvaluator, ScalarFunctionArgs,
ScalarUDFImpl,
- Signature, Volatility, WindowUDFImpl,
+ udf_equals_hash, Accumulator, AggregateUDFImpl, PartitionEvaluator,
+ ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, WindowUDFImpl,
};
use datafusion_functions_window_common::field::WindowUDFFieldArgs;
use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;
@@ -82,33 +82,7 @@ impl ScalarUDFImpl for MyRegexUdf {
&self.aliases
}
- fn equals(&self, other: &dyn ScalarUDFImpl) -> bool {
- let Some(other) = other.as_any().downcast_ref::<Self>() else {
- return false;
- };
- let Self {
- signature,
- pattern,
- aliases,
- } = self;
- signature == &other.signature
- && pattern == &other.pattern
- && aliases == &other.aliases
- }
-
- fn hash_value(&self) -> u64 {
- let Self {
- signature,
- pattern,
- aliases,
- } = self;
- let mut hasher = DefaultHasher::new();
- std::any::type_name::<Self>().hash(&mut hasher);
- signature.hash(&mut hasher);
- pattern.hash(&mut hasher);
- aliases.hash(&mut hasher);
- hasher.finish()
- }
+ udf_equals_hash!(ScalarUDFImpl);
}
#[derive(Clone, PartialEq, ::prost::Message)]
diff --git a/datafusion/sql/tests/sql_integration.rs
b/datafusion/sql/tests/sql_integration.rs
index 1e857b6a07..dd5ec4a201 100644
--- a/datafusion/sql/tests/sql_integration.rs
+++ b/datafusion/sql/tests/sql_integration.rs
@@ -16,7 +16,7 @@
// under the License.
use std::any::Any;
-use std::hash::{DefaultHasher, Hash, Hasher};
+use std::hash::Hash;
#[cfg(test)]
use std::sync::Arc;
use std::vec;
@@ -25,9 +25,9 @@ use arrow::datatypes::{TimeUnit::Nanosecond, *};
use common::MockContextProvider;
use datafusion_common::{assert_contains, DataFusionError, Result};
use datafusion_expr::{
- col, logical_plan::LogicalPlan, test::function_stub::sum_udaf,
ColumnarValue,
- CreateIndex, DdlStatement, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl,
Signature,
- Volatility,
+ col, logical_plan::LogicalPlan, test::function_stub::sum_udaf,
udf_equals_hash,
+ ColumnarValue, CreateIndex, DdlStatement, ScalarFunctionArgs, ScalarUDF,
+ ScalarUDFImpl, Signature, Volatility,
};
use datafusion_functions::{string, unicode};
use datafusion_sql::{
@@ -3312,7 +3312,7 @@ fn make_udf(name: &'static str, args: Vec<DataType>,
return_type: DataType) -> S
}
/// Mocked UDF
-#[derive(Debug)]
+#[derive(Debug, PartialEq, Hash)]
struct DummyUDF {
name: &'static str,
signature: Signature,
@@ -3350,33 +3350,7 @@ impl ScalarUDFImpl for DummyUDF {
panic!("dummy - not implemented")
}
- fn equals(&self, other: &dyn ScalarUDFImpl) -> bool {
- let Some(other) = other.as_any().downcast_ref::<Self>() else {
- return false;
- };
- let Self {
- name,
- signature,
- return_type,
- } = self;
- name == &other.name
- && signature == &other.signature
- && return_type == &other.return_type
- }
-
- fn hash_value(&self) -> u64 {
- let Self {
- name,
- signature,
- return_type,
- } = self;
- let mut hasher = DefaultHasher::new();
- std::any::type_name::<Self>().hash(&mut hasher);
- name.hash(&mut hasher);
- signature.hash(&mut hasher);
- return_type.hash(&mut hasher);
- hasher.finish()
- }
+ udf_equals_hash!(ScalarUDFImpl);
}
fn parse_decimals_parser_options() -> ParserOptions {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]