This is an automated email from the ASF dual-hosted git repository.
github-bot pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new fc8824011b Refactor `power()` signature away from user defined (#18968)
fc8824011b is described below
commit fc8824011bf5d4baccbfe51b3888ed5573ef3bfb
Author: Jeffrey Vo <[email protected]>
AuthorDate: Tue Dec 16 09:30:51 2025 +0900
Refactor `power()` signature away from user defined (#18968)
## Which issue does this PR close?
<!--
We generally require a GitHub issue to be filed for all bug fixes and
enhancements and this helps us generate change logs for our releases.
You can link an issue to this PR using the GitHub syntax. For example
`Closes #123` indicates that this PR will close issue #123.
-->
- Part of https://github.com/apache/datafusion/issues/12725
## Rationale for this change
<!--
Why are you proposing this change? If this is already explained clearly
in the issue then this section is not needed.
Explaining clearly why changes are proposed helps reviewers understand
your changes and offer better suggestions for fixes.
-->
Prefer to avoid `user_defined` signatures, for consistency across function definitions.
## What changes are included in this PR?
<!--
There is no need to duplicate the description in the issue here but it
is sometimes worth providing a summary of the individual changes in this
PR.
-->
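Refactor the signature of `power` away from `user_defined`, replacing the hand-written `coerce_types` logic with declarative `Coercible` type signatures.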
## Are these changes tested?
<!--
We typically require tests for all PRs in order to:
1. Prevent the code from being accidentally broken by subsequent changes
2. Serve as another way to document the expected behavior of the code
If tests are not included in your PR, please explain why (for example,
are they covered by existing tests)?
-->
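Existing tests. The invocation unit tests removed from `power.rs` are replaced by sqllogictest cases in `math.slt`.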
## Are there any user-facing changes?
<!--
If there are user-facing changes then we may require documentation to be
updated before approving the PR.
-->
No.
<!--
If there are any breaking changes to public APIs, please add the `api
change` label.
-->
---
datafusion/functions/src/math/power.rs | 351 ++++----------------------
datafusion/sqllogictest/test_files/math.slt | 54 ++++
datafusion/sqllogictest/test_files/scalar.slt | 2 +-
3 files changed, 105 insertions(+), 302 deletions(-)
diff --git a/datafusion/functions/src/math/power.rs
b/datafusion/functions/src/math/power.rs
index 198ad88b94..6b8eaa0be0 100644
--- a/datafusion/functions/src/math/power.rs
+++ b/datafusion/functions/src/math/power.rs
@@ -27,15 +27,15 @@ use arrow::datatypes::{
Decimal256Type, Float64Type, Int64Type,
};
use arrow::error::ArrowError;
+use datafusion_common::types::{NativeType, logical_float64, logical_int64};
use datafusion_common::utils::take_function_args;
-use datafusion_common::{Result, ScalarValue, exec_err, plan_datafusion_err};
+use datafusion_common::{Result, ScalarValue, internal_err};
use datafusion_expr::expr::ScalarFunction;
use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};
-use datafusion_expr::type_coercion::is_decimal;
use datafusion_expr::{
- ColumnarValue, Documentation, Expr, ScalarFunctionArgs, ScalarUDF,
+ Coercion, ColumnarValue, Documentation, Expr, ScalarFunctionArgs,
ScalarUDF,
+ ScalarUDFImpl, Signature, TypeSignature, TypeSignatureClass, Volatility,
lit,
};
-use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
use datafusion_macros::user_doc;
#[user_doc(
@@ -67,8 +67,26 @@ impl Default for PowerFunc {
impl PowerFunc {
pub fn new() -> Self {
+ let integer = Coercion::new_implicit(
+ TypeSignatureClass::Native(logical_int64()),
+ vec![TypeSignatureClass::Integer],
+ NativeType::Int64,
+ );
+ let decimal = Coercion::new_exact(TypeSignatureClass::Decimal);
+ let float = Coercion::new_implicit(
+ TypeSignatureClass::Native(logical_float64()),
+ vec![TypeSignatureClass::Numeric],
+ NativeType::Float64,
+ );
Self {
- signature: Signature::user_defined(Volatility::Immutable),
+ signature: Signature::one_of(
+ vec![
+ TypeSignature::Coercible(vec![decimal.clone(), integer]),
+ TypeSignature::Coercible(vec![decimal.clone(),
float.clone()]),
+ TypeSignature::Coercible(vec![float; 2]),
+ ],
+ Volatility::Immutable,
+ ),
aliases: vec![String::from("pow")],
}
}
@@ -153,6 +171,7 @@ impl ScalarUDFImpl for PowerFunc {
fn as_any(&self) -> &dyn Any {
self
}
+
fn name(&self) -> &str {
"power"
}
@@ -162,57 +181,23 @@ impl ScalarUDFImpl for PowerFunc {
}
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
- Ok(arg_types[0].clone())
+ if arg_types[0].is_null() {
+ Ok(DataType::Float64)
+ } else {
+ Ok(arg_types[0].clone())
+ }
}
fn aliases(&self) -> &[String] {
&self.aliases
}
- fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
- let [arg1, arg2] = take_function_args(self.name(), arg_types)?;
-
- fn coerced_type_exp(name: &str, data_type: &DataType) ->
Result<DataType> {
- match data_type {
- DataType::Null => Ok(DataType::Int64),
- d if d.is_floating() => Ok(DataType::Float64),
- d if d.is_integer() => Ok(DataType::Int64),
- d if is_decimal(d) => Ok(DataType::Float64),
- other => {
- exec_err!("Unsupported data type {other:?} for {}
function", name)
- }
- }
- }
-
- // Determine the exponent type first, as it affects base coercion
- let exp_type = coerced_type_exp(self.name(), arg2)?;
-
- // For base coercion: always use Float64 for integer/null bases
- // This matches PostgreSQL behavior and handles negative exponents
correctly
- fn coerced_type_base(name: &str, data_type: &DataType) ->
Result<DataType> {
- match data_type {
- d if d.is_floating() => Ok(DataType::Float64),
- // Integer and Null bases always coerce to Float64
- // (integer power doesn't support negative exponents, and pow()
- // should return float like PostgreSQL does)
- DataType::Null => Ok(DataType::Float64),
- d if d.is_integer() => Ok(DataType::Float64),
- d if is_decimal(d) => Ok(d.clone()),
- other => {
- exec_err!("Unsupported data type {other:?} for {}
function", name)
- }
- }
- }
-
- Ok(vec![coerced_type_base(self.name(), arg1)?, exp_type])
- }
-
fn invoke_with_args(&self, args: ScalarFunctionArgs) ->
Result<ColumnarValue> {
- let base = &args.args[0].to_array(args.number_rows)?;
- let exponent = &args.args[1];
+ let [base, exponent] = take_function_args(self.name(), &args.args)?;
+ let base = base.to_array(args.number_rows)?;
let arr: ArrayRef = match (base.data_type(), exponent.data_type()) {
- (DataType::Float64, _) => {
+ (DataType::Float64, DataType::Float64) => {
calculate_binary_math::<Float64Type, Float64Type, Float64Type,
_>(
&base,
exponent,
@@ -322,9 +307,8 @@ impl ScalarUDFImpl for PowerFunc {
)?
}
(base_type, exp_type) => {
- return exec_err!(
- "Unsupported data types for base {base_type:?} and
exponent {exp_type:?} for function {}",
- self.name()
+ return internal_err!(
+ "Unsupported data types for base {base_type:?} and
exponent {exp_type:?} for power"
);
}
};
@@ -332,30 +316,33 @@ impl ScalarUDFImpl for PowerFunc {
}
/// Simplify the `power` function by the relevant rules:
- /// 1. Power(a, 0) ===> 0
+ /// 1. Power(a, 0) ===> 1
/// 2. Power(a, 1) ===> a
/// 3. Power(a, Log(a, b)) ===> b
fn simplify(
&self,
- mut args: Vec<Expr>,
+ args: Vec<Expr>,
info: &dyn SimplifyInfo,
) -> Result<ExprSimplifyResult> {
- let exponent = args.pop().ok_or_else(|| {
- plan_datafusion_err!("Expected power to have 2 arguments, got 0")
- })?;
- let base = args.pop().ok_or_else(|| {
- plan_datafusion_err!("Expected power to have 2 arguments, got 1")
- })?;
-
+ let [base, exponent] = take_function_args("power", args)?;
+ let base_type = info.get_data_type(&base)?;
let exponent_type = info.get_data_type(&exponent)?;
+
+ // Null propagation
+ if base_type.is_null() || exponent_type.is_null() {
+ let return_type = self.return_type(&[base_type, exponent_type])?;
+ return Ok(ExprSimplifyResult::Simplified(lit(
+ ScalarValue::Null.cast_to(&return_type)?
+ )));
+ }
+
match exponent {
Expr::Literal(value, _)
if value == ScalarValue::new_zero(&exponent_type)? =>
{
- Ok(ExprSimplifyResult::Simplified(Expr::Literal(
- ScalarValue::new_one(&info.get_data_type(&base)?)?,
- None,
- )))
+ Ok(ExprSimplifyResult::Simplified(lit(ScalarValue::new_one(
+ &base_type,
+ )?)))
}
Expr::Literal(value, _) if value ==
ScalarValue::new_one(&exponent_type)? => {
Ok(ExprSimplifyResult::Simplified(base))
@@ -383,202 +370,6 @@ fn is_log(func: &ScalarUDF) -> bool {
#[cfg(test)]
mod tests {
use super::*;
- use arrow::array::{Array, Decimal128Array, Float64Array, Int64Array};
- use arrow::datatypes::{DECIMAL128_MAX_SCALE, Field};
- use datafusion_common::cast::{as_decimal128_array, as_float64_array};
- use datafusion_common::config::ConfigOptions;
- use std::sync::Arc;
-
- #[cfg(test)]
- #[ctor::ctor]
- fn init() {
- // Enable RUST_LOG logging configuration for test
- let _ = env_logger::try_init();
- }
-
- #[test]
- fn test_power_f64() {
- let arg_fields = vec![
- Field::new("a", DataType::Float64, true).into(),
- Field::new("a", DataType::Float64, true).into(),
- ];
- let args = ScalarFunctionArgs {
- args: vec![
- ColumnarValue::Array(Arc::new(Float64Array::from(vec![
- 2.0, 2.0, 3.0, 5.0,
- ]))), // base
- ColumnarValue::Array(Arc::new(Float64Array::from(vec![
- 3.0, 2.0, 4.0, 4.0,
- ]))), // exponent
- ],
- arg_fields,
- number_rows: 4,
- return_field: Field::new("f", DataType::Float64, true).into(),
- config_options: Arc::new(ConfigOptions::default()),
- };
- let result = PowerFunc::new()
- .invoke_with_args(args)
- .expect("failed to initialize function power");
-
- match result {
- ColumnarValue::Array(arr) => {
- let floats = as_float64_array(&arr)
- .expect("failed to convert result to a Float64Array");
- assert_eq!(floats.len(), 4);
- assert_eq!(floats.value(0), 8.0);
- assert_eq!(floats.value(1), 4.0);
- assert_eq!(floats.value(2), 81.0);
- assert_eq!(floats.value(3), 625.0);
- }
- ColumnarValue::Scalar(_) => {
- panic!("Expected an array value")
- }
- }
- }
-
- #[test]
- fn test_power_i128() {
- let arg_fields = vec![
- Field::new(
- "a",
- DataType::Decimal128(DECIMAL128_MAX_SCALE as u8, 0),
- true,
- )
- .into(),
- Field::new("a", DataType::Int64, true).into(),
- ];
- let args = ScalarFunctionArgs {
- args: vec![
- ColumnarValue::Array(Arc::new(
- Decimal128Array::from(vec![2, 2, 3, 5, 0, 5])
- .with_precision_and_scale(DECIMAL128_MAX_SCALE as u8,
0)
- .unwrap(),
- )), // base
- ColumnarValue::Array(Arc::new(Int64Array::from(vec![3, 2, 4,
4, 4, 0]))), // exponent
- ],
- arg_fields,
- number_rows: 6,
- return_field: Field::new(
- "f",
- DataType::Decimal128(DECIMAL128_MAX_SCALE as u8, 0),
- true,
- )
- .into(),
- config_options: Arc::new(ConfigOptions::default()),
- };
- let result = PowerFunc::new()
- .invoke_with_args(args)
- .expect("failed to initialize function power");
-
- match result {
- ColumnarValue::Array(arr) => {
- let ints = as_decimal128_array(&arr)
- .expect("failed to convert result to an array");
-
- assert_eq!(ints.len(), 6);
- assert_eq!(ints.value(0), i128::from(8));
- assert_eq!(ints.value(1), i128::from(4));
- assert_eq!(ints.value(2), i128::from(81));
- assert_eq!(ints.value(3), i128::from(625));
- assert_eq!(ints.value(4), i128::from(0));
- assert_eq!(ints.value(5), i128::from(1));
- }
- ColumnarValue::Scalar(_) => {
- panic!("Expected an array value")
- }
- }
- }
-
- #[test]
- fn test_power_array_null() {
- let arg_fields = vec![
- Field::new("a", DataType::Float64, true).into(),
- Field::new("a", DataType::Float64, true).into(),
- ];
- let args = ScalarFunctionArgs {
- args: vec![
- ColumnarValue::Array(Arc::new(Float64Array::from(vec![2.0,
2.0, 2.0]))), // base
- ColumnarValue::Array(Arc::new(Float64Array::from(vec![
- Some(1.0),
- None,
- Some(3.0),
- ]))), // exponent
- ],
- arg_fields,
- number_rows: 3,
- return_field: Field::new("f", DataType::Float64, true).into(),
- config_options: Arc::new(ConfigOptions::default()),
- };
- let result = PowerFunc::new()
- .invoke_with_args(args)
- .expect("failed to initialize function power");
-
- match result {
- ColumnarValue::Array(arr) => {
- let floats =
- as_float64_array(&arr).expect("failed to convert result to
an array");
-
- assert_eq!(floats.len(), 3);
- assert!(!floats.is_null(0));
- assert_eq!(floats.value(0), 2.0);
- assert!(floats.is_null(1));
- assert!(!floats.is_null(2));
- assert_eq!(floats.value(2), 8.0);
- }
- ColumnarValue::Scalar(_) => {
- panic!("Expected an array value")
- }
- }
- }
-
- #[test]
- fn test_power_decimal_with_scale() {
- // 2.5 ^ 4 = 39
- // 2.5 is 25 in Decimal128(2, 1) by parsing rules
- // Signature is Decimal128(2, 1) -> Int64 -> Decimal128(2, 1),
therefore
- // result is 390 in Decimal128(2, 1) aka 39 in unscaled Decimal128(2,
0)
- let arg_fields = vec![
- Field::new(
- "a",
- DataType::Decimal128(DECIMAL128_MAX_SCALE as u8, 0),
- true,
- )
- .into(),
- Field::new("a", DataType::Int64, true).into(),
- ];
- let args = ScalarFunctionArgs {
- args: vec![
- ColumnarValue::Scalar(ScalarValue::Decimal128(
- Some(i128::from(25)),
- 2,
- 1,
- )), // base
- ColumnarValue::Scalar(ScalarValue::Int64(Some(4))), // exponent
- ],
- arg_fields,
- number_rows: 1,
- return_field: Field::new("f", DataType::Decimal128(2, 1),
true).into(),
- config_options: Arc::new(ConfigOptions::default()),
- };
- let result = PowerFunc::new()
- .invoke_with_args(args)
- .expect("failed to initialize function power");
-
- match result {
- ColumnarValue::Array(arr) => {
- let ints = as_decimal128_array(&arr)
- .expect("failed to convert result to an array");
-
- assert_eq!(ints.len(), 1);
- assert_eq!(ints.value(0), i128::from(390));
- // Signature stays the same as input
- assert_eq!(*arr.data_type(), DataType::Decimal128(2, 1));
- }
- ColumnarValue::Scalar(_) => {
- panic!("Expected an array value")
- }
- }
- }
#[test]
fn test_pow_decimal128_helper() {
@@ -601,46 +392,4 @@ mod tests {
"Not yet implemented: Negative scale is not yet supported value:
-1"
);
}
-
- #[test]
- fn test_power_coerce_types() {
- let power_func = PowerFunc::new();
-
- // Int64 base with Int64 exponent -> base coerced to Float64 (like
PostgreSQL)
- // This allows negative exponents to work correctly
- let result = power_func
- .coerce_types(&[DataType::Int64, DataType::Int64])
- .unwrap();
- assert_eq!(result, vec![DataType::Float64, DataType::Int64]);
-
- // Float64 base with Float64 exponent -> both stay Float64
- let result = power_func
- .coerce_types(&[DataType::Float64, DataType::Float64])
- .unwrap();
- assert_eq!(result, vec![DataType::Float64, DataType::Float64]);
-
- // Int64 base with Float64 exponent -> base coerced to Float64
- let result = power_func
- .coerce_types(&[DataType::Int64, DataType::Float64])
- .unwrap();
- assert_eq!(result, vec![DataType::Float64, DataType::Float64]);
-
- // Int32 base with Float32 exponent -> both coerced to Float64
- let result = power_func
- .coerce_types(&[DataType::Int32, DataType::Float32])
- .unwrap();
- assert_eq!(result, vec![DataType::Float64, DataType::Float64]);
-
- // Null base with Float64 exponent -> base coerced to Float64
- let result = power_func
- .coerce_types(&[DataType::Null, DataType::Float64])
- .unwrap();
- assert_eq!(result, vec![DataType::Float64, DataType::Float64]);
-
- // Null base with Int64 exponent -> base coerced to Float64 (like
PostgreSQL)
- let result = power_func
- .coerce_types(&[DataType::Null, DataType::Int64])
- .unwrap();
- assert_eq!(result, vec![DataType::Float64, DataType::Int64]);
- }
}
diff --git a/datafusion/sqllogictest/test_files/math.slt
b/datafusion/sqllogictest/test_files/math.slt
index 322ba7a104..53cf17fe7a 100644
--- a/datafusion/sqllogictest/test_files/math.slt
+++ b/datafusion/sqllogictest/test_files/math.slt
@@ -740,6 +740,60 @@ select power(2107754225, 1221660777);
----
Infinity
+query R rowsort
+select power(base::double, exponent::double)
+from values
+ (2.0, 2.0),
+ (5.0, 4.0),
+ (2.0, 3.0),
+ (3.0, 4.0) as t(base, exponent);
+----
+4
+625
+8
+81
+
+query R rowsort
+select power(base::bigint, exponent::bigint)
+from values
+ (2, 2),
+ (5, 4),
+ (2, 3),
+ (3, 4),
+ (2, NULL) as t(base, exponent);
+----
+4
+625
+8
+81
+NULL
+
+query RT rowsort
+select
+ power(base::decimal(38, 0), exponent::decimal(38, 0)),
+ arrow_typeof(power(base::decimal(38, 0), exponent::decimal(38, 0)))
+from values
+ (0, 4),
+ (5, 0),
+ (2, 2),
+ (5, 4),
+ (2, 3),
+ (3, 4) as t(base, exponent);
+----
+0 Decimal128(38, 0)
+1 Decimal128(38, 0)
+4 Decimal128(38, 0)
+625 Decimal128(38, 0)
+8 Decimal128(38, 0)
+81 Decimal128(38, 0)
+
+query RT
+select
+ pow(2.5::decimal(2, 1), 4::bigint),
+ arrow_typeof(pow(2.5::decimal(2, 1), 4::bigint));
+----
+39 Decimal128(2, 1)
+
# factorial overflow
query error DataFusion error: Arrow error: Compute error: Overflow happened on
FACTORIAL\(350943270\)
select FACTORIAL(350943270);
diff --git a/datafusion/sqllogictest/test_files/scalar.slt
b/datafusion/sqllogictest/test_files/scalar.slt
index 7c6b38b78e..9c7071cb65 100644
--- a/datafusion/sqllogictest/test_files/scalar.slt
+++ b/datafusion/sqllogictest/test_files/scalar.slt
@@ -1883,7 +1883,7 @@ D false
# test string_temporal_coercion
query BBBBBBBBBB
-select
+select
arrow_cast(to_timestamp('2020-01-01 01:01:11.1234567890Z'),
'Timestamp(Second, None)') == '2020-01-01T01:01:11',
arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'),
'Timestamp(Second, None)') == arrow_cast('2020-01-02T01:01:11', 'LargeUtf8'),
arrow_cast(to_timestamp('2020-01-03 01:01:11.1234567890Z'),
'Time32(Second)') == '01:01:11',
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]