This is an automated email from the ASF dual-hosted git repository.

github-bot pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 1d5d63c413 Feat: Allow pow with negative & non-integer exponent on 
decimals (#19369)
1d5d63c413 is described below

commit 1d5d63c413ed73d1ef389e2b560b73975fb6b0e7
Author: Yuvraj <[email protected]>
AuthorDate: Thu Jan 8 19:37:32 2026 +0530

    Feat: Allow pow with negative & non-integer exponent on decimals (#19369)
    
    ## Which issue does this PR close?
    Closes #19348
    
    ## Rationale for this change
    Previously, pow() on decimal types would error for negative exponents
    and non-integer exponents with messages like:
    
    - Arrow error: Arithmetic overflow: Unsupported exp value: -5
    - Compute error: Cannot use non-integer exp
    
    This was a regression from when decimals were cast to float before
    pow(). The efficient integer-based algorithm for computing power on
    scaled integers cannot handle these cases.
    
    ## What changes are included in this PR?
    - Modified pow_decimal_int to fall back to pow_decimal_float for negative
    exponents
    - Modified pow_decimal_float to use an efficient integer path for
    non-negative integer exponents, and otherwise fall back to f64 computation
    - Added pow_decimal_float_fallback function that:
      - Converts the decimal to f64
      - Computes powf(exp)
      - Converts back to the original decimal type with proper scaling
    - Added decimal_from_i128 helper to convert i128 results back to generic
    decimal types (needed for Decimal256 support)
    - Updated sqllogictests to expect success for negative/non-integer
    exponents
    
    ## Are these changes tested?
    Yes:
    
    - Unit tests for pow_decimal_float_fallback covering negative exponents,
    fractional exponents, and cube roots
    - Updated SQL logic tests in decimal.slt
    
    ## Are there any user-facing changes?
    Yes. The following queries now work instead of returning errors:
    
    ```sql
    -- Negative exponent
    SELECT power(4::decimal(38, 5), -1);  -- Returns 0.25
    
    -- Non-integer exponent
    SELECT power(2.5, 4.2);  -- Returns 46.9
    
    -- Square root via power
    SELECT power(4::decimal, 0.5);  -- Returns 2
    ```
---
 datafusion/functions/src/math/power.rs         | 321 ++++++++++++++++++++++---
 datafusion/sqllogictest/test_files/decimal.slt |  28 ++-
 2 files changed, 306 insertions(+), 43 deletions(-)

diff --git a/datafusion/functions/src/math/power.rs 
b/datafusion/functions/src/math/power.rs
index fafadd3ba4..489c59aa3d 100644
--- a/datafusion/functions/src/math/power.rs
+++ b/datafusion/functions/src/math/power.rs
@@ -22,9 +22,10 @@ use super::log::LogFunc;
 
 use crate::utils::{calculate_binary_decimal_math, calculate_binary_math};
 use arrow::array::{Array, ArrayRef};
+use arrow::datatypes::i256;
 use arrow::datatypes::{
-    ArrowNativeTypeOp, DataType, Decimal32Type, Decimal64Type, Decimal128Type,
-    Decimal256Type, Float64Type, Int64Type,
+    ArrowNativeType, ArrowNativeTypeOp, DataType, Decimal32Type, Decimal64Type,
+    Decimal128Type, Decimal256Type, Float64Type, Int64Type,
 };
 use arrow::error::ArrowError;
 use datafusion_common::types::{NativeType, logical_float64, logical_int64};
@@ -37,6 +38,7 @@ use datafusion_expr::{
     ScalarUDFImpl, Signature, TypeSignature, TypeSignatureClass, Volatility, 
lit,
 };
 use datafusion_macros::user_doc;
+use num_traits::{NumCast, ToPrimitive};
 
 #[user_doc(
     doc_section(label = "Math Functions"),
@@ -112,12 +114,15 @@ impl PowerFunc {
 ///   2.5 is represented as 25 with scale 1
 ///   The unscaled result is 25^4 = 390625
 ///   Scale it back to 1: 390625 / 10^4 = 39
-///
-/// Returns error if base is invalid
 fn pow_decimal_int<T>(base: T, scale: i8, exp: i64) -> Result<T, ArrowError>
 where
-    T: From<i32> + ArrowNativeTypeOp,
+    T: ArrowNativeType + ArrowNativeTypeOp + ToPrimitive + NumCast + Copy,
 {
+    // Negative exponent: fall back to float computation
+    if exp < 0 {
+        return pow_decimal_float(base, scale, exp as f64);
+    }
+
     let exp: u32 = exp.try_into().map_err(|_| {
         ArrowError::ArithmeticOverflow(format!("Unsupported exp value: {exp}"))
     })?;
@@ -125,13 +130,13 @@ where
     // If scale < 0, 10^scale (e.g., 10^-2 = 0.01) becomes 0 in integer 
arithmetic.
     if exp == 0 {
         return if scale >= 0 {
-            T::from(10).pow_checked(scale as u32).map_err(|_| {
+            T::usize_as(10).pow_checked(scale as u32).map_err(|_| {
                 ArrowError::ArithmeticOverflow(format!(
                     "Cannot make unscale factor for {scale} and {exp}"
                 ))
             })
         } else {
-            Ok(T::from(0))
+            Ok(T::ZERO)
         };
     }
     let powered: T = base.pow_checked(exp).map_err(|_| {
@@ -149,11 +154,12 @@ where
     // If mul_exp is positive, we divide (standard case).
     // If mul_exp is negative, we multiply (negative scale case).
     if mul_exp > 0 {
-        let div_factor: T = T::from(10).pow_checked(mul_exp as 
u32).map_err(|_| {
-            ArrowError::ArithmeticOverflow(format!(
-                "Cannot make div factor for {scale} and {exp}"
-            ))
-        })?;
+        let div_factor: T =
+            T::usize_as(10).pow_checked(mul_exp as u32).map_err(|_| {
+                ArrowError::ArithmeticOverflow(format!(
+                    "Cannot make div factor for {scale} and {exp}"
+                ))
+            })?;
         powered.div_checked(div_factor)
     } else {
         // mul_exp is negative, so we multiply by 10^(-mul_exp)
@@ -162,33 +168,227 @@ where
                 "Overflow while negating scale exponent".to_string(),
             )
         })?;
-        let mul_factor: T = T::from(10).pow_checked(abs_exp as 
u32).map_err(|_| {
-            ArrowError::ArithmeticOverflow(format!(
-                "Cannot make mul factor for {scale} and {exp}"
-            ))
-        })?;
+        let mul_factor: T =
+            T::usize_as(10).pow_checked(abs_exp as u32).map_err(|_| {
+                ArrowError::ArithmeticOverflow(format!(
+                    "Cannot make mul factor for {scale} and {exp}"
+                ))
+            })?;
         powered.mul_checked(mul_factor)
     }
 }
 
 /// Binary function to calculate a math power to float exponent
 /// for scaled integer types.
-/// Returns error if exponent is negative or non-integer, or base invalid
 fn pow_decimal_float<T>(base: T, scale: i8, exp: f64) -> Result<T, ArrowError>
 where
-    T: From<i32> + ArrowNativeTypeOp,
+    T: ArrowNativeType + ArrowNativeTypeOp + ToPrimitive + NumCast + Copy,
 {
-    if !exp.is_finite() || exp.trunc() != exp {
+    if exp.is_finite() && exp.trunc() == exp && exp >= 0f64 && exp < u32::MAX 
as f64 {
+        return pow_decimal_int(base, scale, exp as i64);
+    }
+
+    if !exp.is_finite() {
         return Err(ArrowError::ComputeError(format!(
-            "Cannot use non-integer exp: {exp}"
+            "Cannot use non-finite exp: {exp}"
         )));
     }
-    if exp < 0f64 || exp >= u32::MAX as f64 {
+
+    pow_decimal_float_fallback(base, scale, exp)
+}
+
+/// Compute the f64 power result and scale it back.
+/// Returns the rounded i128 result for conversion to target type.
+#[inline]
+fn compute_pow_f64_result(
+    base_f64: f64,
+    scale: i8,
+    exp: f64,
+) -> Result<i128, ArrowError> {
+    let result_f64 = base_f64.powf(exp);
+
+    if !result_f64.is_finite() {
         return Err(ArrowError::ArithmeticOverflow(format!(
-            "Unsupported exp value: {exp}"
+            "Result of {base_f64}^{exp} is not finite"
+        )));
+    }
+
+    let scale_factor = 10f64.powi(scale as i32);
+    let result_scaled = result_f64 * scale_factor;
+    let result_rounded = result_scaled.round();
+
+    if result_rounded.abs() > i128::MAX as f64 {
+        return Err(ArrowError::ArithmeticOverflow(format!(
+            "Result {result_rounded} is too large for the target decimal type"
+        )));
+    }
+
+    Ok(result_rounded as i128)
+}
+
+/// Convert i128 result to target decimal native type using NumCast.
+/// Returns error if value overflows the target type.
+#[inline]
+fn decimal_from_i128<T>(value: i128) -> Result<T, ArrowError>
+where
+    T: NumCast,
+{
+    NumCast::from(value).ok_or_else(|| {
+        ArrowError::ArithmeticOverflow(format!(
+            "Value {value} is too large for the target decimal type"
+        ))
+    })
+}
+
+/// Fallback implementation using f64 for negative or non-integer exponents.
+/// This handles cases that cannot be computed using integer arithmetic.
+fn pow_decimal_float_fallback<T>(base: T, scale: i8, exp: f64) -> Result<T, 
ArrowError>
+where
+    T: ToPrimitive + NumCast + Copy,
+{
+    if scale < 0 {
+        return Err(ArrowError::NotYetImplemented(format!(
+            "Negative scale is not yet supported: {scale}"
         )));
     }
-    pow_decimal_int(base, scale, exp as i64)
+
+    let scale_factor = 10f64.powi(scale as i32);
+    let base_f64 = base.to_f64().ok_or_else(|| {
+        ArrowError::ComputeError("Cannot convert base to f64".to_string())
+    })? / scale_factor;
+
+    let result_i128 = compute_pow_f64_result(base_f64, scale, exp)?;
+
+    decimal_from_i128(result_i128)
+}
+
+/// Decimal256 specialized float exponent version.
+fn pow_decimal256_float(base: i256, scale: i8, exp: f64) -> Result<i256, 
ArrowError> {
+    if exp.is_finite() && exp.trunc() == exp && exp >= 0f64 && exp < u32::MAX 
as f64 {
+        return pow_decimal256_int(base, scale, exp as i64);
+    }
+
+    if !exp.is_finite() {
+        return Err(ArrowError::ComputeError(format!(
+            "Cannot use non-finite exp: {exp}"
+        )));
+    }
+
+    pow_decimal256_float_fallback(base, scale, exp)
+}
+
+/// Decimal256 specialized integer exponent version.
+fn pow_decimal256_int(base: i256, scale: i8, exp: i64) -> Result<i256, 
ArrowError> {
+    if exp < 0 {
+        return pow_decimal256_float(base, scale, exp as f64);
+    }
+
+    let exp: u32 = exp.try_into().map_err(|_| {
+        ArrowError::ArithmeticOverflow(format!("Unsupported exp value: {exp}"))
+    })?;
+
+    if exp == 0 {
+        return if scale >= 0 {
+            i256::from_i128(10).pow_checked(scale as u32).map_err(|_| {
+                ArrowError::ArithmeticOverflow(format!(
+                    "Cannot make unscale factor for {scale} and {exp}"
+                ))
+            })
+        } else {
+            Ok(i256::from_i128(0))
+        };
+    }
+
+    let powered: i256 = base.pow_checked(exp).map_err(|_| {
+        ArrowError::ArithmeticOverflow(format!("Cannot raise base {base:?} to 
exp {exp}"))
+    })?;
+
+    let mul_exp = (scale as i64).wrapping_mul(exp as i64 - 1);
+
+    if mul_exp == 0 {
+        return Ok(powered);
+    }
+
+    if mul_exp > 0 {
+        let div_factor: i256 =
+            i256::from_i128(10)
+                .pow_checked(mul_exp as u32)
+                .map_err(|_| {
+                    ArrowError::ArithmeticOverflow(format!(
+                        "Cannot make div factor for {scale} and {exp}"
+                    ))
+                })?;
+        powered.div_checked(div_factor)
+    } else {
+        let abs_exp = mul_exp.checked_neg().ok_or_else(|| {
+            ArrowError::ArithmeticOverflow(
+                "Overflow while negating scale exponent".to_string(),
+            )
+        })?;
+        let mul_factor: i256 =
+            i256::from_i128(10)
+                .pow_checked(abs_exp as u32)
+                .map_err(|_| {
+                    ArrowError::ArithmeticOverflow(format!(
+                        "Cannot make mul factor for {scale} and {exp}"
+                    ))
+                })?;
+        powered.mul_checked(mul_factor)
+    }
+}
+
+/// Fallback implementation for Decimal256.
+fn pow_decimal256_float_fallback(
+    base: i256,
+    scale: i8,
+    exp: f64,
+) -> Result<i256, ArrowError> {
+    if scale < 0 {
+        return Err(ArrowError::NotYetImplemented(format!(
+            "Negative scale is not yet supported: {scale}"
+        )));
+    }
+
+    let scale_factor = 10f64.powi(scale as i32);
+    let base_f64 = base.to_f64().ok_or_else(|| {
+        ArrowError::ComputeError("Cannot convert base to f64".to_string())
+    })? / scale_factor;
+
+    let result_i128 = compute_pow_f64_result(base_f64, scale, exp)?;
+
+    // i256 can be constructed from i128 directly
+    Ok(i256::from_i128(result_i128))
+}
+
+/// Fallback implementation for decimal power when exponent is an array.
+/// Casts decimal to float64, computes power, and casts back to original 
decimal type.
+/// This is used for performance when exponent varies per-row.
+fn pow_decimal_with_float_fallback(
+    base: &ArrayRef,
+    exponent: &ColumnarValue,
+    num_rows: usize,
+) -> Result<ColumnarValue> {
+    use arrow::compute::cast;
+
+    let original_type = base.data_type().clone();
+    let base_f64 = cast(base.as_ref(), &DataType::Float64)?;
+
+    let exp_f64 = match exponent {
+        ColumnarValue::Array(arr) => cast(arr.as_ref(), &DataType::Float64)?,
+        ColumnarValue::Scalar(scalar) => {
+            let scalar_f64 = scalar.cast_to(&DataType::Float64)?;
+            scalar_f64.to_array_of_size(num_rows)?
+        }
+    };
+
+    let result_f64 = calculate_binary_math::<Float64Type, Float64Type, 
Float64Type, _>(
+        &base_f64,
+        &ColumnarValue::Array(exp_f64),
+        |b, e| Ok(f64::powf(b, e)),
+    )?;
+
+    let result = cast(result_f64.as_ref(), &original_type)?;
+    Ok(ColumnarValue::Array(result))
 }
 
 impl ScalarUDFImpl for PowerFunc {
@@ -218,8 +418,25 @@ impl ScalarUDFImpl for PowerFunc {
 
     fn invoke_with_args(&self, args: ScalarFunctionArgs) -> 
Result<ColumnarValue> {
         let [base, exponent] = take_function_args(self.name(), &args.args)?;
+
+        // For decimal types, only use native decimal
+        // operations when we have a scalar exponent. When the exponent is an 
array,
+        // fall back to float computation for better performance.
+        let use_float_fallback = matches!(
+            base.data_type(),
+            DataType::Decimal32(_, _)
+                | DataType::Decimal64(_, _)
+                | DataType::Decimal128(_, _)
+                | DataType::Decimal256(_, _)
+        ) && matches!(exponent, ColumnarValue::Array(_));
+
         let base = base.to_array(args.number_rows)?;
 
+        // If decimal with array exponent, cast to float and compute
+        if use_float_fallback {
+            return pow_decimal_with_float_fallback(&base, exponent, 
args.number_rows);
+        }
+
         let arr: ArrayRef = match (base.data_type(), exponent.data_type()) {
             (DataType::Float64, DataType::Float64) => {
                 calculate_binary_math::<Float64Type, Float64Type, Float64Type, 
_>(
@@ -311,7 +528,7 @@ impl ScalarUDFImpl for PowerFunc {
                 >(
                     &base,
                     exponent,
-                    |b, e| pow_decimal_int(b, *scale, e),
+                    |b, e| pow_decimal256_int(b, *scale, e),
                     *precision,
                     *scale,
                 )?
@@ -325,7 +542,7 @@ impl ScalarUDFImpl for PowerFunc {
                 >(
                     &base,
                     exponent,
-                    |b, e| pow_decimal_float(b, *scale, e),
+                    |b, e| pow_decimal256_float(b, *scale, e),
                     *precision,
                     *scale,
                 )?
@@ -398,19 +615,53 @@ mod tests {
     #[test]
     fn test_pow_decimal128_helper() {
         // Expression: 2.5 ^ 4 = 39.0625
-        assert_eq!(pow_decimal_int(25, 1, 4).unwrap(), i128::from(390));
-        assert_eq!(pow_decimal_int(2500, 3, 4).unwrap(), i128::from(39062));
-        assert_eq!(pow_decimal_int(25000, 4, 4).unwrap(), i128::from(390625));
+        assert_eq!(pow_decimal_int(25i128, 1, 4).unwrap(), 390i128);
+        assert_eq!(pow_decimal_int(2500i128, 3, 4).unwrap(), 39062i128);
+        assert_eq!(pow_decimal_int(25000i128, 4, 4).unwrap(), 390625i128);
 
         // Expression: 25 ^ 4 = 390625
-        assert_eq!(pow_decimal_int(25, 0, 4).unwrap(), i128::from(390625));
+        assert_eq!(pow_decimal_int(25i128, 0, 4).unwrap(), 390625i128);
 
         // Expressions for edge cases
-        assert_eq!(pow_decimal_int(25, 1, 1).unwrap(), i128::from(25));
-        assert_eq!(pow_decimal_int(25, 0, 1).unwrap(), i128::from(25));
-        assert_eq!(pow_decimal_int(25, 0, 0).unwrap(), i128::from(1));
-        assert_eq!(pow_decimal_int(25, 1, 0).unwrap(), i128::from(10));
+        assert_eq!(pow_decimal_int(25i128, 1, 1).unwrap(), 25i128);
+        assert_eq!(pow_decimal_int(25i128, 0, 1).unwrap(), 25i128);
+        assert_eq!(pow_decimal_int(25i128, 0, 0).unwrap(), 1i128);
+        assert_eq!(pow_decimal_int(25i128, 1, 0).unwrap(), 10i128);
+
+        assert_eq!(pow_decimal_int(25i128, -1, 4).unwrap(), 390625000i128);
+    }
+
+    #[test]
+    fn test_pow_decimal_float_fallback() {
+        // Test negative exponent: 4^(-1) = 0.25
+        // 4 with scale 2 = 400, result should be 25 (0.25 with scale 2)
+        let result: i128 = pow_decimal_float(400i128, 2, -1.0).unwrap();
+        assert_eq!(result, 25);
+
+        // Test non-integer exponent: 4^0.5 = 2
+        // 4 with scale 2 = 400, result should be 200 (2.0 with scale 2)
+        let result: i128 = pow_decimal_float(400i128, 2, 0.5).unwrap();
+        assert_eq!(result, 200);
+
+        // Test 8^(1/3) = 2 (cube root)
+        // 8 with scale 1 = 80, result should be 20 (2.0 with scale 1)
+        let result: i128 = pow_decimal_float(80i128, 1, 1.0 / 3.0).unwrap();
+        assert_eq!(result, 20);
+
+        // Test negative base with integer exponent still works
+        // (-2)^3 = -8
+        // -2 with scale 1 = -20, result should be -80 (-8.0 with scale 1)
+        let result: i128 = pow_decimal_float(-20i128, 1, 3.0).unwrap();
+        assert_eq!(result, -80);
+
+        // Test positive integer exponent goes through fast path
+        // 2.5^4 = 39.0625
+        // 25 with scale 1, result should be 390 (39.0 with scale 1) - 
truncated
+        let result: i128 = pow_decimal_float(25i128, 1, 4.0).unwrap();
+        assert_eq!(result, 390); // Uses integer path
 
-        assert_eq!(pow_decimal_int(25, -1, 4).unwrap(), i128::from(390625000));
+        // Test non-finite exponent returns error
+        assert!(pow_decimal_float(100i128, 2, f64::NAN).is_err());
+        assert!(pow_decimal_float(100i128, 2, f64::INFINITY).is_err());
     }
 }
diff --git a/datafusion/sqllogictest/test_files/decimal.slt 
b/datafusion/sqllogictest/test_files/decimal.slt
index 85f2559f58..f53f493929 100644
--- a/datafusion/sqllogictest/test_files/decimal.slt
+++ b/datafusion/sqllogictest/test_files/decimal.slt
@@ -1095,8 +1095,17 @@ SELECT power(2, 100000000000)
 ----
 Infinity
 
-query error Arrow error: Arithmetic overflow: Unsupported exp value
-SELECT power(2::decimal(38, 0), -5)
+# Negative exponent now works (fallback to f64)
+query RT
+SELECT power(2::decimal(38, 0), -5), arrow_typeof(power(2::decimal(38, 0), 
-5));
+----
+0 Decimal128(38, 0)
+
+# Negative exponent with scale preserves decimal places
+query RT
+SELECT power(4::decimal(38, 5), -1), arrow_typeof(power(4::decimal(38, 5), 
-1));
+----
+0.25 Decimal128(38, 5)
 
 # Expected to have `16 Decimal128(38, 0)`
 # Due to type coericion, it becomes Float -> Float -> Float
@@ -1116,20 +1125,23 @@ SELECT power(2.5, 4.0), arrow_typeof(power(2.5, 4.0));
 ----
 39 Decimal128(2, 1)
 
-query error Compute error: Cannot use non-integer exp
+# Non-integer exponent now works (fallback to f64)
+query RT
 SELECT power(2.5, 4.2), arrow_typeof(power(2.5, 4.2));
+----
+46.9 Decimal128(2, 1)
 
-query error Compute error: Cannot use non-integer exp: NaN
+query error Compute error: Cannot use non-finite exp: NaN
 SELECT power(2::decimal(38, 0), arrow_cast('NaN','Float64'))
 
-query error Compute error: Cannot use non-integer exp: inf
+query error Compute error: Cannot use non-finite exp: inf
 SELECT power(2::decimal(38, 0), arrow_cast('INF','Float64'))
 
-# Floating above u32::max
-query error Compute error: Cannot use non-integer exp
+# Floating above u32::max now works (fallback to f64, returns infinity which 
is an error)
+query error Arrow error: Arithmetic overflow: Result of 2\^5000000000.1 is not 
finite
 SELECT power(2::decimal(38, 0), 5000000000.1)
 
-# Integer Above u32::max
+# Integer Above u32::max - still goes through integer path which fails
 query error Arrow error: Arithmetic overflow: Unsupported exp value
 SELECT power(2::decimal(38, 0), 5000000000)
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to