This is an automated email from the ASF dual-hosted git repository.
ytyou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 4b9a468cc1 feat: Add
`ScalarValue::{new_one,new_zero,new_ten,distance}` support for `Decimal128` and
`Decimal256` (#16831)
4b9a468cc1 is described below
commit 4b9a468cc1949062cf3cd8685ba8ced377fd212e
Author: theirix <[email protected]>
AuthorDate: Sun Jul 27 05:09:38 2025 +0100
feat: Add `ScalarValue::{new_one,new_zero,new_ten,distance}` support for
`Decimal128` and `Decimal256` (#16831)
* Add missing ScalarValue impls for large decimals
Add methods distance, new_zero, new_one, new_ten for Decimal128,
Decimal256
* Support expr simplification for Decimal256
* Replace lookup table with i128::pow
* Support different scales for Decimal constructors
- Allow constructing one and ten with different scales
- Add tests for new_one, new_ten
- Add test for distance
* Revert "Replace lookup table with i128::pow"
This reverts commit ba23e8c40c97088a405a36b8f1e1c84146178b73.
* Use Arrow error directly
---
datafusion/common/src/scalar/mod.rs | 301 ++++++++++++++++++++-
.../optimizer/src/simplify_expressions/utils.rs | 88 ++++++
2 files changed, 381 insertions(+), 8 deletions(-)
diff --git a/datafusion/common/src/scalar/mod.rs
b/datafusion/common/src/scalar/mod.rs
index 62ae19fd5c..1ced4ab825 100644
--- a/datafusion/common/src/scalar/mod.rs
+++ b/datafusion/common/src/scalar/mod.rs
@@ -74,12 +74,13 @@ use arrow::compute::kernels::numeric::{
add, add_wrapping, div, mul, mul_wrapping, rem, sub, sub_wrapping,
};
use arrow::datatypes::{
- i256, ArrowDictionaryKeyType, ArrowNativeType, ArrowTimestampType,
DataType,
- Date32Type, Field, Float32Type, Int16Type, Int32Type, Int64Type, Int8Type,
- IntervalDayTime, IntervalDayTimeType, IntervalMonthDayNano,
IntervalMonthDayNanoType,
- IntervalUnit, IntervalYearMonthType, TimeUnit, TimestampMicrosecondType,
- TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType,
UInt16Type,
- UInt32Type, UInt64Type, UInt8Type, UnionFields, UnionMode,
DECIMAL128_MAX_PRECISION,
+ i256, validate_decimal_precision_and_scale, ArrowDictionaryKeyType,
ArrowNativeType,
+ ArrowTimestampType, DataType, Date32Type, Decimal128Type, Decimal256Type,
Field,
+ Float32Type, Int16Type, Int32Type, Int64Type, Int8Type, IntervalDayTime,
+ IntervalDayTimeType, IntervalMonthDayNano, IntervalMonthDayNanoType,
IntervalUnit,
+ IntervalYearMonthType, TimeUnit, TimestampMicrosecondType,
TimestampMillisecondType,
+ TimestampNanosecondType, TimestampSecondType, UInt16Type, UInt32Type,
UInt64Type,
+ UInt8Type, UnionFields, UnionMode, DECIMAL128_MAX_PRECISION,
};
use arrow::util::display::{array_value_to_string, ArrayFormatter,
FormatOptions};
use cache::{get_or_create_cached_key_array, get_or_create_cached_null_array};
@@ -1516,6 +1517,34 @@ impl ScalarValue {
DataType::Float16 =>
ScalarValue::Float16(Some(f16::from_f32(1.0))),
DataType::Float32 => ScalarValue::Float32(Some(1.0)),
DataType::Float64 => ScalarValue::Float64(Some(1.0)),
+ DataType::Decimal128(precision, scale) => {
+ validate_decimal_precision_and_scale::<Decimal128Type>(
+ *precision, *scale,
+ )?;
+ if *scale < 0 {
+ return _internal_err!("Negative scale is not supported");
+ }
+ match i128::from(10).checked_pow(*scale as u32) {
+ Some(value) => {
+ ScalarValue::Decimal128(Some(value), *precision,
*scale)
+ }
+ None => return _internal_err!("Unsupported scale {scale}"),
+ }
+ }
+ DataType::Decimal256(precision, scale) => {
+ validate_decimal_precision_and_scale::<Decimal256Type>(
+ *precision, *scale,
+ )?;
+ if *scale < 0 {
+ return _internal_err!("Negative scale is not supported");
+ }
+ match i256::from(10).checked_pow(*scale as u32) {
+ Some(value) => {
+ ScalarValue::Decimal256(Some(value), *precision,
*scale)
+ }
+ None => return _internal_err!("Unsupported scale {scale}"),
+ }
+ }
_ => {
return _not_impl_err!(
"Can't create an one scalar from data_type
\"{datatype:?}\""
@@ -1534,6 +1563,34 @@ impl ScalarValue {
DataType::Float16 =>
ScalarValue::Float16(Some(f16::from_f32(-1.0))),
DataType::Float32 => ScalarValue::Float32(Some(-1.0)),
DataType::Float64 => ScalarValue::Float64(Some(-1.0)),
+ DataType::Decimal128(precision, scale) => {
+ validate_decimal_precision_and_scale::<Decimal128Type>(
+ *precision, *scale,
+ )?;
+ if *scale < 0 {
+ return _internal_err!("Negative scale is not supported");
+ }
+ match i128::from(10).checked_pow(*scale as u32) {
+ Some(value) => {
+ ScalarValue::Decimal128(Some(-value), *precision,
*scale)
+ }
+ None => return _internal_err!("Unsupported scale {scale}"),
+ }
+ }
+ DataType::Decimal256(precision, scale) => {
+ validate_decimal_precision_and_scale::<Decimal256Type>(
+ *precision, *scale,
+ )?;
+ if *scale < 0 {
+ return _internal_err!("Negative scale is not supported");
+ }
+ match i256::from(10).checked_pow(*scale as u32) {
+ Some(value) => {
+ ScalarValue::Decimal256(Some(-value), *precision,
*scale)
+ }
+ None => return _internal_err!("Unsupported scale {scale}"),
+ }
+ }
_ => {
return _not_impl_err!(
"Can't create a negative one scalar from data_type
\"{datatype:?}\""
@@ -1555,6 +1612,38 @@ impl ScalarValue {
DataType::Float16 =>
ScalarValue::Float16(Some(f16::from_f32(10.0))),
DataType::Float32 => ScalarValue::Float32(Some(10.0)),
DataType::Float64 => ScalarValue::Float64(Some(10.0)),
+ DataType::Decimal128(precision, scale) => {
+ if let Err(err) =
validate_decimal_precision_and_scale::<Decimal128Type>(
+ *precision, *scale,
+ ) {
+ return _internal_err!("Invalid precision and scale {err}");
+ }
+ if *scale <= 0 {
+ return _internal_err!("Negative scale is not supported");
+ }
+ match i128::from(10).checked_pow((*scale + 1) as u32) {
+ Some(value) => {
+ ScalarValue::Decimal128(Some(value), *precision,
*scale)
+ }
+ None => return _internal_err!("Unsupported scale {scale}"),
+ }
+ }
+ DataType::Decimal256(precision, scale) => {
+ if let Err(err) =
validate_decimal_precision_and_scale::<Decimal256Type>(
+ *precision, *scale,
+ ) {
+ return _internal_err!("Invalid precision and scale {err}");
+ }
+ if *scale <= 0 {
+ return _internal_err!("Negative scale is not supported");
+ }
+ match i256::from(10).checked_pow((*scale + 1) as u32) {
+ Some(value) => {
+ ScalarValue::Decimal256(Some(value), *precision,
*scale)
+ }
+ None => return _internal_err!("Unsupported scale {scale}"),
+ }
+ }
_ => {
return _not_impl_err!(
"Can't create a ten scalar from data_type \"{datatype:?}\""
@@ -1924,6 +2013,26 @@ impl ScalarValue {
(Self::Float64(Some(l)), Self::Float64(Some(r))) => {
Some((l - r).abs().round() as _)
}
+ (
+ Self::Decimal128(Some(l), lprecision, lscale),
+ Self::Decimal128(Some(r), rprecision, rscale),
+ ) => {
+ if lprecision == rprecision && lscale == rscale {
+ l.checked_sub(*r)?.checked_abs()?.to_usize()
+ } else {
+ None
+ }
+ }
+ (
+ Self::Decimal256(Some(l), lprecision, lscale),
+ Self::Decimal256(Some(r), rprecision, rscale),
+ ) => {
+ if lprecision == rprecision && lscale == rscale {
+ l.checked_sub(*r)?.checked_abs()?.to_usize()
+ } else {
+ None
+ }
+ }
_ => None,
}
}
@@ -4489,7 +4598,9 @@ mod tests {
};
use arrow::buffer::{Buffer, OffsetBuffer};
use arrow::compute::{is_null, kernels};
- use arrow::datatypes::{ArrowNumericType, Fields, Float64Type};
+ use arrow::datatypes::{
+ ArrowNumericType, Fields, Float64Type, DECIMAL256_MAX_PRECISION,
+ };
use arrow::error::ArrowError;
use arrow::util::pretty::pretty_format_columns;
use chrono::NaiveDate;
@@ -5225,6 +5336,116 @@ mod tests {
Ok(())
}
+ #[test]
+ fn test_new_one_decimal128() {
+ assert_eq!(
+ ScalarValue::new_one(&DataType::Decimal128(5, 0)).unwrap(),
+ ScalarValue::Decimal128(Some(1), 5, 0)
+ );
+ assert_eq!(
+ ScalarValue::new_one(&DataType::Decimal128(5, 1)).unwrap(),
+ ScalarValue::Decimal128(Some(10), 5, 1)
+ );
+ assert_eq!(
+ ScalarValue::new_one(&DataType::Decimal128(5, 2)).unwrap(),
+ ScalarValue::Decimal128(Some(100), 5, 2)
+ );
+ // More precision
+ assert_eq!(
+ ScalarValue::new_one(&DataType::Decimal128(7, 2)).unwrap(),
+ ScalarValue::Decimal128(Some(100), 7, 2)
+ );
+ // No negative scale
+ assert!(ScalarValue::new_one(&DataType::Decimal128(5, -1)).is_err());
+ // Invalid combination
+ assert!(ScalarValue::new_one(&DataType::Decimal128(0, 2)).is_err());
+ assert!(ScalarValue::new_one(&DataType::Decimal128(5, 7)).is_err());
+ }
+
+ #[test]
+ fn test_new_one_decimal256() {
+ assert_eq!(
+ ScalarValue::new_one(&DataType::Decimal256(5, 0)).unwrap(),
+ ScalarValue::Decimal256(Some(1.into()), 5, 0)
+ );
+ assert_eq!(
+ ScalarValue::new_one(&DataType::Decimal256(5, 1)).unwrap(),
+ ScalarValue::Decimal256(Some(10.into()), 5, 1)
+ );
+ assert_eq!(
+ ScalarValue::new_one(&DataType::Decimal256(5, 2)).unwrap(),
+ ScalarValue::Decimal256(Some(100.into()), 5, 2)
+ );
+ // More precision
+ assert_eq!(
+ ScalarValue::new_one(&DataType::Decimal256(7, 2)).unwrap(),
+ ScalarValue::Decimal256(Some(100.into()), 7, 2)
+ );
+ // No negative scale
+ assert!(ScalarValue::new_one(&DataType::Decimal256(5, -1)).is_err());
+ // Invalid combination
+ assert!(ScalarValue::new_one(&DataType::Decimal256(0, 2)).is_err());
+ assert!(ScalarValue::new_one(&DataType::Decimal256(5, 7)).is_err());
+ }
+
+ #[test]
+ fn test_new_ten_decimal128() {
+ assert_eq!(
+ ScalarValue::new_ten(&DataType::Decimal128(5, 1)).unwrap(),
+ ScalarValue::Decimal128(Some(100), 5, 1)
+ );
+ assert_eq!(
+ ScalarValue::new_ten(&DataType::Decimal128(5, 2)).unwrap(),
+ ScalarValue::Decimal128(Some(1000), 5, 2)
+ );
+ // More precision
+ assert_eq!(
+ ScalarValue::new_ten(&DataType::Decimal128(7, 2)).unwrap(),
+ ScalarValue::Decimal128(Some(1000), 7, 2)
+ );
+ // No negative or zero scale
+ assert!(ScalarValue::new_ten(&DataType::Decimal128(5, 0)).is_err());
+ assert!(ScalarValue::new_ten(&DataType::Decimal128(5, -1)).is_err());
+ // Invalid combination
+ assert!(ScalarValue::new_ten(&DataType::Decimal128(0, 2)).is_err());
+ assert!(ScalarValue::new_ten(&DataType::Decimal128(5, 7)).is_err());
+ }
+
+ #[test]
+ fn test_new_ten_decimal256() {
+ assert_eq!(
+ ScalarValue::new_ten(&DataType::Decimal256(5, 1)).unwrap(),
+ ScalarValue::Decimal256(Some(100.into()), 5, 1)
+ );
+ assert_eq!(
+ ScalarValue::new_ten(&DataType::Decimal256(5, 2)).unwrap(),
+ ScalarValue::Decimal256(Some(1000.into()), 5, 2)
+ );
+ // More precision
+ assert_eq!(
+ ScalarValue::new_ten(&DataType::Decimal256(7, 2)).unwrap(),
+ ScalarValue::Decimal256(Some(1000.into()), 7, 2)
+ );
+ // No negative or zero scale
+ assert!(ScalarValue::new_ten(&DataType::Decimal256(5, 0)).is_err());
+ assert!(ScalarValue::new_ten(&DataType::Decimal256(5, -1)).is_err());
+ // Invalid combination
+ assert!(ScalarValue::new_ten(&DataType::Decimal256(0, 2)).is_err());
+ assert!(ScalarValue::new_ten(&DataType::Decimal256(5, 7)).is_err());
+ }
+
+ #[test]
+ fn test_new_negative_one_decimal128() {
+ assert_eq!(
+ ScalarValue::new_negative_one(&DataType::Decimal128(5,
0)).unwrap(),
+ ScalarValue::Decimal128(Some(-1), 5, 0)
+ );
+ assert_eq!(
+ ScalarValue::new_negative_one(&DataType::Decimal128(5,
2)).unwrap(),
+ ScalarValue::Decimal128(Some(-100), 5, 2)
+ );
+ }
+
#[test]
fn test_list_partial_cmp() {
let a =
@@ -7275,6 +7496,26 @@ mod tests {
ScalarValue::Float64(Some(-9.9)),
5,
),
+ (
+ ScalarValue::Decimal128(Some(10), 1, 0),
+ ScalarValue::Decimal128(Some(5), 1, 0),
+ 5,
+ ),
+ (
+ ScalarValue::Decimal128(Some(5), 1, 0),
+ ScalarValue::Decimal128(Some(10), 1, 0),
+ 5,
+ ),
+ (
+ ScalarValue::Decimal256(Some(10.into()), 1, 0),
+ ScalarValue::Decimal256(Some(5.into()), 1, 0),
+ 5,
+ ),
+ (
+ ScalarValue::Decimal256(Some(5.into()), 1, 0),
+ ScalarValue::Decimal256(Some(10.into()), 1, 0),
+ 5,
+ ),
];
for (lhs, rhs, expected) in cases.iter() {
let distance = lhs.distance(rhs).unwrap();
@@ -7282,6 +7523,24 @@ mod tests {
}
}
+ #[test]
+ fn test_distance_none() {
+ let cases = [
+ (
+ ScalarValue::Decimal128(Some(i128::MAX),
DECIMAL128_MAX_PRECISION, 0),
+ ScalarValue::Decimal128(Some(-i128::MAX),
DECIMAL128_MAX_PRECISION, 0),
+ ),
+ (
+ ScalarValue::Decimal256(Some(i256::MAX),
DECIMAL256_MAX_PRECISION, 0),
+ ScalarValue::Decimal256(Some(-i256::MAX),
DECIMAL256_MAX_PRECISION, 0),
+ ),
+ ];
+ for (lhs, rhs) in cases.iter() {
+ let distance = lhs.distance(rhs);
+ assert!(distance.is_none(), "{lhs} vs {rhs}");
+ }
+ }
+
#[test]
fn test_scalar_distance_invalid() {
let cases = [
@@ -7323,7 +7582,33 @@ mod tests {
(ScalarValue::Date64(Some(0)), ScalarValue::Date64(Some(1))),
(
ScalarValue::Decimal128(Some(123), 5, 5),
- ScalarValue::Decimal128(Some(120), 5, 5),
+ ScalarValue::Decimal128(Some(120), 5, 3),
+ ),
+ (
+ ScalarValue::Decimal128(Some(123), 5, 5),
+ ScalarValue::Decimal128(Some(120), 3, 5),
+ ),
+ (
+ ScalarValue::Decimal256(Some(123.into()), 5, 5),
+ ScalarValue::Decimal256(Some(120.into()), 3, 5),
+ ),
+ // Distance 2 * 2^50 is larger than usize
+ (
+ ScalarValue::Decimal256(
+ Some(i256::from_parts(0, 2_i64.pow(50).into())),
+ 1,
+ 0,
+ ),
+ ScalarValue::Decimal256(
+ Some(i256::from_parts(0, (-(2_i64).pow(50)).into())),
+ 1,
+ 0,
+ ),
+ ),
+ // Distance overflow
+ (
+ ScalarValue::Decimal256(Some(i256::from_parts(0, i128::MAX)),
1, 0),
+ ScalarValue::Decimal256(Some(i256::from_parts(0, -i128::MAX)),
1, 0),
),
];
for (lhs, rhs) in cases {
diff --git a/datafusion/optimizer/src/simplify_expressions/utils.rs
b/datafusion/optimizer/src/simplify_expressions/utils.rs
index 4df0e125eb..2f7dadceba 100644
--- a/datafusion/optimizer/src/simplify_expressions/utils.rs
+++ b/datafusion/optimizer/src/simplify_expressions/utils.rs
@@ -17,6 +17,7 @@
//! Utility functions for expression simplification
+use arrow::datatypes::i256;
use datafusion_common::{internal_err, Result, ScalarValue};
use datafusion_expr::{
expr::{Between, BinaryExpr, InList},
@@ -150,6 +151,11 @@ pub fn is_zero(s: &Expr) -> bool {
Expr::Literal(ScalarValue::Float32(Some(v)), _) if *v == 0. => true,
Expr::Literal(ScalarValue::Float64(Some(v)), _) if *v == 0. => true,
Expr::Literal(ScalarValue::Decimal128(Some(v), _p, _s), _) if *v == 0
=> true,
+ Expr::Literal(ScalarValue::Decimal256(Some(v), _p, _s), _)
+ if *v == i256::ZERO =>
+ {
+ true
+ }
_ => false,
}
}
@@ -173,6 +179,13 @@ pub fn is_one(s: &Expr) -> bool {
.map(|x| x == v)
.unwrap_or_default()
}
+ Expr::Literal(ScalarValue::Decimal256(Some(v), _p, s), _) => {
+ *s >= 0
+ && match i256::from(10).checked_pow(*s as u32) {
+ Some(res) => res == *v,
+ None => false,
+ }
+ }
_ => false,
}
}
@@ -365,3 +378,78 @@ pub fn distribute_negation(expr: Expr) -> Expr {
_ => Expr::Negative(Box::new(expr)),
}
}
+
+#[cfg(test)]
+mod tests {
+ use super::{is_one, is_zero};
+ use arrow::datatypes::i256;
+ use datafusion_common::ScalarValue;
+ use datafusion_expr::lit;
+
+ #[test]
+ fn test_is_zero() {
+ assert!(is_zero(&lit(ScalarValue::Int8(Some(0)))));
+ assert!(is_zero(&lit(ScalarValue::Float32(Some(0.0)))));
+ assert!(is_zero(&lit(ScalarValue::Decimal128(
+ Some(i128::from(0)),
+ 9,
+ 0
+ ))));
+ assert!(is_zero(&lit(ScalarValue::Decimal128(
+ Some(i128::from(0)),
+ 9,
+ 5
+ ))));
+ assert!(is_zero(&lit(ScalarValue::Decimal256(
+ Some(i256::ZERO),
+ 9,
+ 0
+ ))));
+ assert!(is_zero(&lit(ScalarValue::Decimal256(
+ Some(i256::ZERO),
+ 9,
+ 5
+ ))));
+ }
+
+ #[test]
+ fn test_is_one() {
+ assert!(is_one(&lit(ScalarValue::Int8(Some(1)))));
+ assert!(is_one(&lit(ScalarValue::Float32(Some(1.0)))));
+ assert!(is_one(&lit(ScalarValue::Decimal128(
+ Some(i128::from(1)),
+ 9,
+ 0
+ ))));
+ assert!(is_one(&lit(ScalarValue::Decimal128(
+ Some(i128::from(10)),
+ 9,
+ 1
+ ))));
+ assert!(is_one(&lit(ScalarValue::Decimal128(
+ Some(i128::from(100)),
+ 9,
+ 2
+ ))));
+ assert!(is_one(&lit(ScalarValue::Decimal256(
+ Some(i256::from(1)),
+ 9,
+ 0
+ ))));
+ assert!(is_one(&lit(ScalarValue::Decimal256(
+ Some(i256::from(10)),
+ 9,
+ 1
+ ))));
+ assert!(is_one(&lit(ScalarValue::Decimal256(
+ Some(i256::from(100)),
+ 9,
+ 2
+ ))));
+ assert!(!is_one(&lit(ScalarValue::Decimal256(
+ Some(i256::from(100)),
+ 9,
+ -1
+ ))));
+ }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]