This is an automated email from the ASF dual-hosted git repository.
github-bot pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 1d5d63c413 Feat: Allow pow with negative & non-integer exponent on
decimals (#19369)
1d5d63c413 is described below
commit 1d5d63c413ed73d1ef389e2b560b73975fb6b0e7
Author: Yuvraj <[email protected]>
AuthorDate: Thu Jan 8 19:37:32 2026 +0530
Feat: Allow pow with negative & non-integer exponent on decimals (#19369)
## Which issue does this PR close?
Closes #19348
## Rationale for this change
Previously, pow() on decimal types returned an error for negative
and non-integer exponents, with messages like:
- Arrow error: Arithmetic overflow: Unsupported exp value: -5
- Compute error: Cannot use non-integer exp
This was a regression: decimals used to be cast to float before
pow(), but the efficient integer-based algorithm for computing powers of
scaled integers cannot handle these cases.
## What changes are included in this PR?
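- Modified pow_decimal_int to fall back to pow_decimal_float for negative
exponents
- Modified pow_decimal_float to use the efficient integer path for
non-negative integer exponents, and otherwise fall back to f64 computation
- Added a pow_decimal_float_fallback function that:
  - Converts the decimal to f64
  - Computes powf(exp)
  - Converts the result back to the original decimal type with proper
    scaling, rounding to the nearest representable value (negative scales
    are rejected as not yet implemented)
- Added a decimal_from_i128 helper to convert i128 results back to the
generic decimal native type (Decimal32/64/128); Decimal256 instead uses
dedicated pow_decimal256_int, pow_decimal256_float and
pow_decimal256_float_fallback functions that build i256 values directly
- When the exponent is an array rather than a scalar, decimal bases are cast
to Float64, powf is computed per row, and the result is cast back to the
original decimal type
- Updated sqllogictests to expect success for negative/non-integer
exponents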
## Are these changes tested?
Yes:
- Unit tests (test_pow_decimal_float_fallback) covering negative exponents,
fractional exponents, a cube root, and non-finite exponents
- Updated SQL logic tests in decimal.slt
## Are there any user-facing changes?
Yes. The following queries now work instead of returning errors:
```sql
-- Negative exponent
SELECT power(4::decimal(38, 5), -1); -- Returns 0.25
-- Non-integer exponent
SELECT power(2.5, 4.2); -- Returns 46.9
-- Square root via power
SELECT power(4::decimal, 0.5); -- Returns 2
```
---
datafusion/functions/src/math/power.rs | 321 ++++++++++++++++++++++---
datafusion/sqllogictest/test_files/decimal.slt | 28 ++-
2 files changed, 306 insertions(+), 43 deletions(-)
diff --git a/datafusion/functions/src/math/power.rs
b/datafusion/functions/src/math/power.rs
index fafadd3ba4..489c59aa3d 100644
--- a/datafusion/functions/src/math/power.rs
+++ b/datafusion/functions/src/math/power.rs
@@ -22,9 +22,10 @@ use super::log::LogFunc;
use crate::utils::{calculate_binary_decimal_math, calculate_binary_math};
use arrow::array::{Array, ArrayRef};
+use arrow::datatypes::i256;
use arrow::datatypes::{
- ArrowNativeTypeOp, DataType, Decimal32Type, Decimal64Type, Decimal128Type,
- Decimal256Type, Float64Type, Int64Type,
+ ArrowNativeType, ArrowNativeTypeOp, DataType, Decimal32Type, Decimal64Type,
+ Decimal128Type, Decimal256Type, Float64Type, Int64Type,
};
use arrow::error::ArrowError;
use datafusion_common::types::{NativeType, logical_float64, logical_int64};
@@ -37,6 +38,7 @@ use datafusion_expr::{
ScalarUDFImpl, Signature, TypeSignature, TypeSignatureClass, Volatility,
lit,
};
use datafusion_macros::user_doc;
+use num_traits::{NumCast, ToPrimitive};
#[user_doc(
doc_section(label = "Math Functions"),
@@ -112,12 +114,15 @@ impl PowerFunc {
/// 2.5 is represented as 25 with scale 1
/// The unscaled result is 25^4 = 390625
/// Scale it back to 1: 390625 / 10^4 = 39
-///
-/// Returns error if base is invalid
fn pow_decimal_int<T>(base: T, scale: i8, exp: i64) -> Result<T, ArrowError>
where
- T: From<i32> + ArrowNativeTypeOp,
+ T: ArrowNativeType + ArrowNativeTypeOp + ToPrimitive + NumCast + Copy,
{
+ // Negative exponent: fall back to float computation
+ if exp < 0 {
+ return pow_decimal_float(base, scale, exp as f64);
+ }
+
let exp: u32 = exp.try_into().map_err(|_| {
ArrowError::ArithmeticOverflow(format!("Unsupported exp value: {exp}"))
})?;
@@ -125,13 +130,13 @@ where
// If scale < 0, 10^scale (e.g., 10^-2 = 0.01) becomes 0 in integer
arithmetic.
if exp == 0 {
return if scale >= 0 {
- T::from(10).pow_checked(scale as u32).map_err(|_| {
+ T::usize_as(10).pow_checked(scale as u32).map_err(|_| {
ArrowError::ArithmeticOverflow(format!(
"Cannot make unscale factor for {scale} and {exp}"
))
})
} else {
- Ok(T::from(0))
+ Ok(T::ZERO)
};
}
let powered: T = base.pow_checked(exp).map_err(|_| {
@@ -149,11 +154,12 @@ where
// If mul_exp is positive, we divide (standard case).
// If mul_exp is negative, we multiply (negative scale case).
if mul_exp > 0 {
- let div_factor: T = T::from(10).pow_checked(mul_exp as
u32).map_err(|_| {
- ArrowError::ArithmeticOverflow(format!(
- "Cannot make div factor for {scale} and {exp}"
- ))
- })?;
+ let div_factor: T =
+ T::usize_as(10).pow_checked(mul_exp as u32).map_err(|_| {
+ ArrowError::ArithmeticOverflow(format!(
+ "Cannot make div factor for {scale} and {exp}"
+ ))
+ })?;
powered.div_checked(div_factor)
} else {
// mul_exp is negative, so we multiply by 10^(-mul_exp)
@@ -162,33 +168,227 @@ where
"Overflow while negating scale exponent".to_string(),
)
})?;
- let mul_factor: T = T::from(10).pow_checked(abs_exp as
u32).map_err(|_| {
- ArrowError::ArithmeticOverflow(format!(
- "Cannot make mul factor for {scale} and {exp}"
- ))
- })?;
+ let mul_factor: T =
+ T::usize_as(10).pow_checked(abs_exp as u32).map_err(|_| {
+ ArrowError::ArithmeticOverflow(format!(
+ "Cannot make mul factor for {scale} and {exp}"
+ ))
+ })?;
powered.mul_checked(mul_factor)
}
}
/// Binary function to calculate a math power to float exponent
/// for scaled integer types.
-/// Returns error if exponent is negative or non-integer, or base invalid
fn pow_decimal_float<T>(base: T, scale: i8, exp: f64) -> Result<T, ArrowError>
where
- T: From<i32> + ArrowNativeTypeOp,
+ T: ArrowNativeType + ArrowNativeTypeOp + ToPrimitive + NumCast + Copy,
{
- if !exp.is_finite() || exp.trunc() != exp {
+ if exp.is_finite() && exp.trunc() == exp && exp >= 0f64 && exp < u32::MAX
as f64 {
+ return pow_decimal_int(base, scale, exp as i64);
+ }
+
+ if !exp.is_finite() {
return Err(ArrowError::ComputeError(format!(
- "Cannot use non-integer exp: {exp}"
+ "Cannot use non-finite exp: {exp}"
)));
}
- if exp < 0f64 || exp >= u32::MAX as f64 {
+
+ pow_decimal_float_fallback(base, scale, exp)
+}
+
+/// Compute the f64 power result and scale it back.
+/// Returns the rounded i128 result for conversion to target type.
+#[inline]
+fn compute_pow_f64_result(
+ base_f64: f64,
+ scale: i8,
+ exp: f64,
+) -> Result<i128, ArrowError> {
+ let result_f64 = base_f64.powf(exp);
+
+ if !result_f64.is_finite() {
return Err(ArrowError::ArithmeticOverflow(format!(
- "Unsupported exp value: {exp}"
+ "Result of {base_f64}^{exp} is not finite"
+ )));
+ }
+
+ let scale_factor = 10f64.powi(scale as i32);
+ let result_scaled = result_f64 * scale_factor;
+ let result_rounded = result_scaled.round();
+
+ if result_rounded.abs() > i128::MAX as f64 {
+ return Err(ArrowError::ArithmeticOverflow(format!(
+ "Result {result_rounded} is too large for the target decimal type"
+ )));
+ }
+
+ Ok(result_rounded as i128)
+}
+
+/// Convert i128 result to target decimal native type using NumCast.
+/// Returns error if value overflows the target type.
+#[inline]
+fn decimal_from_i128<T>(value: i128) -> Result<T, ArrowError>
+where
+ T: NumCast,
+{
+ NumCast::from(value).ok_or_else(|| {
+ ArrowError::ArithmeticOverflow(format!(
+ "Value {value} is too large for the target decimal type"
+ ))
+ })
+}
+
+/// Fallback implementation using f64 for negative or non-integer exponents.
+/// This handles cases that cannot be computed using integer arithmetic.
+fn pow_decimal_float_fallback<T>(base: T, scale: i8, exp: f64) -> Result<T,
ArrowError>
+where
+ T: ToPrimitive + NumCast + Copy,
+{
+ if scale < 0 {
+ return Err(ArrowError::NotYetImplemented(format!(
+ "Negative scale is not yet supported: {scale}"
)));
}
- pow_decimal_int(base, scale, exp as i64)
+
+ let scale_factor = 10f64.powi(scale as i32);
+ let base_f64 = base.to_f64().ok_or_else(|| {
+ ArrowError::ComputeError("Cannot convert base to f64".to_string())
+ })? / scale_factor;
+
+ let result_i128 = compute_pow_f64_result(base_f64, scale, exp)?;
+
+ decimal_from_i128(result_i128)
+}
+
+/// Decimal256 specialized float exponent version.
+fn pow_decimal256_float(base: i256, scale: i8, exp: f64) -> Result<i256,
ArrowError> {
+ if exp.is_finite() && exp.trunc() == exp && exp >= 0f64 && exp < u32::MAX
as f64 {
+ return pow_decimal256_int(base, scale, exp as i64);
+ }
+
+ if !exp.is_finite() {
+ return Err(ArrowError::ComputeError(format!(
+ "Cannot use non-finite exp: {exp}"
+ )));
+ }
+
+ pow_decimal256_float_fallback(base, scale, exp)
+}
+
+/// Decimal256 specialized integer exponent version.
+fn pow_decimal256_int(base: i256, scale: i8, exp: i64) -> Result<i256,
ArrowError> {
+ if exp < 0 {
+ return pow_decimal256_float(base, scale, exp as f64);
+ }
+
+ let exp: u32 = exp.try_into().map_err(|_| {
+ ArrowError::ArithmeticOverflow(format!("Unsupported exp value: {exp}"))
+ })?;
+
+ if exp == 0 {
+ return if scale >= 0 {
+ i256::from_i128(10).pow_checked(scale as u32).map_err(|_| {
+ ArrowError::ArithmeticOverflow(format!(
+ "Cannot make unscale factor for {scale} and {exp}"
+ ))
+ })
+ } else {
+ Ok(i256::from_i128(0))
+ };
+ }
+
+ let powered: i256 = base.pow_checked(exp).map_err(|_| {
+ ArrowError::ArithmeticOverflow(format!("Cannot raise base {base:?} to
exp {exp}"))
+ })?;
+
+ let mul_exp = (scale as i64).wrapping_mul(exp as i64 - 1);
+
+ if mul_exp == 0 {
+ return Ok(powered);
+ }
+
+ if mul_exp > 0 {
+ let div_factor: i256 =
+ i256::from_i128(10)
+ .pow_checked(mul_exp as u32)
+ .map_err(|_| {
+ ArrowError::ArithmeticOverflow(format!(
+ "Cannot make div factor for {scale} and {exp}"
+ ))
+ })?;
+ powered.div_checked(div_factor)
+ } else {
+ let abs_exp = mul_exp.checked_neg().ok_or_else(|| {
+ ArrowError::ArithmeticOverflow(
+ "Overflow while negating scale exponent".to_string(),
+ )
+ })?;
+ let mul_factor: i256 =
+ i256::from_i128(10)
+ .pow_checked(abs_exp as u32)
+ .map_err(|_| {
+ ArrowError::ArithmeticOverflow(format!(
+ "Cannot make mul factor for {scale} and {exp}"
+ ))
+ })?;
+ powered.mul_checked(mul_factor)
+ }
+}
+
+/// Fallback implementation for Decimal256.
+fn pow_decimal256_float_fallback(
+ base: i256,
+ scale: i8,
+ exp: f64,
+) -> Result<i256, ArrowError> {
+ if scale < 0 {
+ return Err(ArrowError::NotYetImplemented(format!(
+ "Negative scale is not yet supported: {scale}"
+ )));
+ }
+
+ let scale_factor = 10f64.powi(scale as i32);
+ let base_f64 = base.to_f64().ok_or_else(|| {
+ ArrowError::ComputeError("Cannot convert base to f64".to_string())
+ })? / scale_factor;
+
+ let result_i128 = compute_pow_f64_result(base_f64, scale, exp)?;
+
+ // i256 can be constructed from i128 directly
+ Ok(i256::from_i128(result_i128))
+}
+
+/// Fallback implementation for decimal power when exponent is an array.
+/// Casts decimal to float64, computes power, and casts back to original
decimal type.
+/// This is used for performance when exponent varies per-row.
+fn pow_decimal_with_float_fallback(
+ base: &ArrayRef,
+ exponent: &ColumnarValue,
+ num_rows: usize,
+) -> Result<ColumnarValue> {
+ use arrow::compute::cast;
+
+ let original_type = base.data_type().clone();
+ let base_f64 = cast(base.as_ref(), &DataType::Float64)?;
+
+ let exp_f64 = match exponent {
+ ColumnarValue::Array(arr) => cast(arr.as_ref(), &DataType::Float64)?,
+ ColumnarValue::Scalar(scalar) => {
+ let scalar_f64 = scalar.cast_to(&DataType::Float64)?;
+ scalar_f64.to_array_of_size(num_rows)?
+ }
+ };
+
+ let result_f64 = calculate_binary_math::<Float64Type, Float64Type,
Float64Type, _>(
+ &base_f64,
+ &ColumnarValue::Array(exp_f64),
+ |b, e| Ok(f64::powf(b, e)),
+ )?;
+
+ let result = cast(result_f64.as_ref(), &original_type)?;
+ Ok(ColumnarValue::Array(result))
}
impl ScalarUDFImpl for PowerFunc {
@@ -218,8 +418,25 @@ impl ScalarUDFImpl for PowerFunc {
fn invoke_with_args(&self, args: ScalarFunctionArgs) ->
Result<ColumnarValue> {
let [base, exponent] = take_function_args(self.name(), &args.args)?;
+
+ // For decimal types, only use native decimal
+ // operations when we have a scalar exponent. When the exponent is an
array,
+ // fall back to float computation for better performance.
+ let use_float_fallback = matches!(
+ base.data_type(),
+ DataType::Decimal32(_, _)
+ | DataType::Decimal64(_, _)
+ | DataType::Decimal128(_, _)
+ | DataType::Decimal256(_, _)
+ ) && matches!(exponent, ColumnarValue::Array(_));
+
let base = base.to_array(args.number_rows)?;
+ // If decimal with array exponent, cast to float and compute
+ if use_float_fallback {
+ return pow_decimal_with_float_fallback(&base, exponent,
args.number_rows);
+ }
+
let arr: ArrayRef = match (base.data_type(), exponent.data_type()) {
(DataType::Float64, DataType::Float64) => {
calculate_binary_math::<Float64Type, Float64Type, Float64Type,
_>(
@@ -311,7 +528,7 @@ impl ScalarUDFImpl for PowerFunc {
>(
&base,
exponent,
- |b, e| pow_decimal_int(b, *scale, e),
+ |b, e| pow_decimal256_int(b, *scale, e),
*precision,
*scale,
)?
@@ -325,7 +542,7 @@ impl ScalarUDFImpl for PowerFunc {
>(
&base,
exponent,
- |b, e| pow_decimal_float(b, *scale, e),
+ |b, e| pow_decimal256_float(b, *scale, e),
*precision,
*scale,
)?
@@ -398,19 +615,53 @@ mod tests {
#[test]
fn test_pow_decimal128_helper() {
// Expression: 2.5 ^ 4 = 39.0625
- assert_eq!(pow_decimal_int(25, 1, 4).unwrap(), i128::from(390));
- assert_eq!(pow_decimal_int(2500, 3, 4).unwrap(), i128::from(39062));
- assert_eq!(pow_decimal_int(25000, 4, 4).unwrap(), i128::from(390625));
+ assert_eq!(pow_decimal_int(25i128, 1, 4).unwrap(), 390i128);
+ assert_eq!(pow_decimal_int(2500i128, 3, 4).unwrap(), 39062i128);
+ assert_eq!(pow_decimal_int(25000i128, 4, 4).unwrap(), 390625i128);
// Expression: 25 ^ 4 = 390625
- assert_eq!(pow_decimal_int(25, 0, 4).unwrap(), i128::from(390625));
+ assert_eq!(pow_decimal_int(25i128, 0, 4).unwrap(), 390625i128);
// Expressions for edge cases
- assert_eq!(pow_decimal_int(25, 1, 1).unwrap(), i128::from(25));
- assert_eq!(pow_decimal_int(25, 0, 1).unwrap(), i128::from(25));
- assert_eq!(pow_decimal_int(25, 0, 0).unwrap(), i128::from(1));
- assert_eq!(pow_decimal_int(25, 1, 0).unwrap(), i128::from(10));
+ assert_eq!(pow_decimal_int(25i128, 1, 1).unwrap(), 25i128);
+ assert_eq!(pow_decimal_int(25i128, 0, 1).unwrap(), 25i128);
+ assert_eq!(pow_decimal_int(25i128, 0, 0).unwrap(), 1i128);
+ assert_eq!(pow_decimal_int(25i128, 1, 0).unwrap(), 10i128);
+
+ assert_eq!(pow_decimal_int(25i128, -1, 4).unwrap(), 390625000i128);
+ }
+
+ #[test]
+ fn test_pow_decimal_float_fallback() {
+ // Test negative exponent: 4^(-1) = 0.25
+ // 4 with scale 2 = 400, result should be 25 (0.25 with scale 2)
+ let result: i128 = pow_decimal_float(400i128, 2, -1.0).unwrap();
+ assert_eq!(result, 25);
+
+ // Test non-integer exponent: 4^0.5 = 2
+ // 4 with scale 2 = 400, result should be 200 (2.0 with scale 2)
+ let result: i128 = pow_decimal_float(400i128, 2, 0.5).unwrap();
+ assert_eq!(result, 200);
+
+ // Test 8^(1/3) = 2 (cube root)
+ // 8 with scale 1 = 80, result should be 20 (2.0 with scale 1)
+ let result: i128 = pow_decimal_float(80i128, 1, 1.0 / 3.0).unwrap();
+ assert_eq!(result, 20);
+
+ // Test negative base with integer exponent still works
+ // (-2)^3 = -8
+ // -2 with scale 1 = -20, result should be -80 (-8.0 with scale 1)
+ let result: i128 = pow_decimal_float(-20i128, 1, 3.0).unwrap();
+ assert_eq!(result, -80);
+
+ // Test positive integer exponent goes through fast path
+ // 2.5^4 = 39.0625
+ // 25 with scale 1, result should be 390 (39.0 with scale 1) -
truncated
+ let result: i128 = pow_decimal_float(25i128, 1, 4.0).unwrap();
+ assert_eq!(result, 390); // Uses integer path
- assert_eq!(pow_decimal_int(25, -1, 4).unwrap(), i128::from(390625000));
+ // Test non-finite exponent returns error
+ assert!(pow_decimal_float(100i128, 2, f64::NAN).is_err());
+ assert!(pow_decimal_float(100i128, 2, f64::INFINITY).is_err());
}
}
diff --git a/datafusion/sqllogictest/test_files/decimal.slt
b/datafusion/sqllogictest/test_files/decimal.slt
index 85f2559f58..f53f493929 100644
--- a/datafusion/sqllogictest/test_files/decimal.slt
+++ b/datafusion/sqllogictest/test_files/decimal.slt
@@ -1095,8 +1095,17 @@ SELECT power(2, 100000000000)
----
Infinity
-query error Arrow error: Arithmetic overflow: Unsupported exp value
-SELECT power(2::decimal(38, 0), -5)
+# Negative exponent now works (fallback to f64)
+query RT
+SELECT power(2::decimal(38, 0), -5), arrow_typeof(power(2::decimal(38, 0),
-5));
+----
+0 Decimal128(38, 0)
+
+# Negative exponent with scale preserves decimal places
+query RT
+SELECT power(4::decimal(38, 5), -1), arrow_typeof(power(4::decimal(38, 5),
-1));
+----
+0.25 Decimal128(38, 5)
# Expected to have `16 Decimal128(38, 0)`
# Due to type coericion, it becomes Float -> Float -> Float
@@ -1116,20 +1125,23 @@ SELECT power(2.5, 4.0), arrow_typeof(power(2.5, 4.0));
----
39 Decimal128(2, 1)
-query error Compute error: Cannot use non-integer exp
+# Non-integer exponent now works (fallback to f64)
+query RT
SELECT power(2.5, 4.2), arrow_typeof(power(2.5, 4.2));
+----
+46.9 Decimal128(2, 1)
-query error Compute error: Cannot use non-integer exp: NaN
+query error Compute error: Cannot use non-finite exp: NaN
SELECT power(2::decimal(38, 0), arrow_cast('NaN','Float64'))
-query error Compute error: Cannot use non-integer exp: inf
+query error Compute error: Cannot use non-finite exp: inf
SELECT power(2::decimal(38, 0), arrow_cast('INF','Float64'))
-# Floating above u32::max
-query error Compute error: Cannot use non-integer exp
+# Floating above u32::max now works (fallback to f64, returns infinity which
is an error)
+query error Arrow error: Arithmetic overflow: Result of 2\^5000000000.1 is not
finite
SELECT power(2::decimal(38, 0), 5000000000.1)
-# Integer Above u32::max
+# Integer Above u32::max - still goes through integer path which fails
query error Arrow error: Arithmetic overflow: Unsupported exp value
SELECT power(2::decimal(38, 0), 5000000000)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]