This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new ed045d9895 Add Decimal256 to `ScalarValue` (#7048)
ed045d9895 is described below
commit ed045d989501946d9a73d8e1c3b884f279a0a00d
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Tue Jul 25 05:22:57 2023 -0700
Add Decimal256 to `ScalarValue` (#7048)
* Initial support for Decimal256 in ScalarValue
* Add Decimal256 to proto
* Update protobuf code
* Add Decimal256 to from_proto
* Update datafusion/expr/src/type_coercion/aggregates.rs
Co-authored-by: Daniël Heres <[email protected]>
---------
Co-authored-by: Daniël Heres <[email protected]>
Co-authored-by: Andrew Lamb <[email protected]>
---
datafusion/common/src/cast.rs | 6 +
datafusion/common/src/scalar.rs | 161 +++++++++++++++++++--
.../sqllogictests/test_files/arrow_typeof.slt | 22 ++-
.../tests/sqllogictests/test_files/decimal.slt | 9 ++
datafusion/expr/src/type_coercion/aggregates.rs | 19 +++
datafusion/physical-expr/src/aggregate/average.rs | 4 +-
datafusion/physical-expr/src/aggregate/sum.rs | 22 ++-
datafusion/proto/proto/datafusion.proto | 8 +
datafusion/proto/src/generated/pbjson.rs | 145 +++++++++++++++++++
datafusion/proto/src/generated/prost.rs | 14 +-
datafusion/proto/src/logical_plan/from_proto.rs | 10 +-
datafusion/proto/src/logical_plan/to_proto.rs | 18 +++
12 files changed, 411 insertions(+), 27 deletions(-)
diff --git a/datafusion/common/src/cast.rs b/datafusion/common/src/cast.rs
index 04ae32ec35..4356f36b18 100644
--- a/datafusion/common/src/cast.rs
+++ b/datafusion/common/src/cast.rs
@@ -34,6 +34,7 @@ use arrow::{
},
datatypes::{ArrowDictionaryKeyType, ArrowPrimitiveType},
};
+use arrow_array::Decimal256Array;
// Downcast ArrayRef to Date32Array
pub fn as_date32_array(array: &dyn Array) -> Result<&Date32Array> {
@@ -65,6 +66,11 @@ pub fn as_decimal128_array(array: &dyn Array) ->
Result<&Decimal128Array> {
Ok(downcast_value!(array, Decimal128Array))
}
+// Downcast ArrayRef to Decimal256Array
+// (256-bit counterpart of `as_decimal128_array`; a failed downcast is
+// reported through `downcast_value!`)
+pub fn as_decimal256_array(array: &dyn Array) -> Result<&Decimal256Array> {
+ Ok(downcast_value!(array, Decimal256Array))
+}
+
// Downcast ArrayRef to Float32Array
pub fn as_float32_array(array: &dyn Array) -> Result<&Float32Array> {
Ok(downcast_value!(array, Float32Array))
diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs
index 99ff5f3384..4a7767023f 100644
--- a/datafusion/common/src/scalar.rs
+++ b/datafusion/common/src/scalar.rs
@@ -26,14 +26,14 @@ use std::str::FromStr;
use std::{convert::TryFrom, fmt, iter::repeat, sync::Arc};
use crate::cast::{
- as_decimal128_array, as_dictionary_array, as_fixed_size_binary_array,
- as_fixed_size_list_array, as_list_array, as_struct_array,
+ as_decimal128_array, as_decimal256_array, as_dictionary_array,
+ as_fixed_size_binary_array, as_fixed_size_list_array, as_list_array,
as_struct_array,
};
use crate::delta::shift_months;
use crate::error::{DataFusionError, Result};
use arrow::buffer::NullBuffer;
use arrow::compute::nullif;
-use arrow::datatypes::{FieldRef, Fields, SchemaBuilder};
+use arrow::datatypes::{i256, FieldRef, Fields, SchemaBuilder};
use arrow::{
array::*,
compute::kernels::cast::{cast_with_options, CastOptions},
@@ -47,6 +47,7 @@ use arrow::{
},
};
use arrow_array::timezone::Tz;
+use arrow_array::ArrowNativeTypeOp;
use chrono::{Datelike, Duration, NaiveDate, NaiveDateTime};
// Constants we use throughout this file:
@@ -75,6 +76,8 @@ pub enum ScalarValue {
Float64(Option<f64>),
/// 128bit decimal, using the i128 to represent the decimal, precision
scale
Decimal128(Option<i128>, u8, i8),
+ /// 256bit decimal, using the i256 to represent the decimal, precision
scale
+ Decimal256(Option<i256>, u8, i8),
/// signed 8bit int
Int8(Option<i8>),
/// signed 16bit int
@@ -160,6 +163,10 @@ impl PartialEq for ScalarValue {
v1.eq(v2) && p1.eq(p2) && s1.eq(s2)
}
(Decimal128(_, _, _), _) => false,
+ (Decimal256(v1, p1, s1), Decimal256(v2, p2, s2)) => {
+ v1.eq(v2) && p1.eq(p2) && s1.eq(s2)
+ }
+ (Decimal256(_, _, _), _) => false,
(Boolean(v1), Boolean(v2)) => v1.eq(v2),
(Boolean(_), _) => false,
(Float32(v1), Float32(v2)) => match (v1, v2) {
@@ -283,6 +290,15 @@ impl PartialOrd for ScalarValue {
}
}
(Decimal128(_, _, _), _) => None,
+ (Decimal256(v1, p1, s1), Decimal256(v2, p2, s2)) => {
+ if p1.eq(p2) && s1.eq(s2) {
+ v1.partial_cmp(v2)
+ } else {
+ // Two decimal values can be compared if they have the
same precision and scale.
+ None
+ }
+ }
+ (Decimal256(_, _, _), _) => None,
(Boolean(v1), Boolean(v2)) => v1.partial_cmp(v2),
(Boolean(_), _) => None,
(Float32(v1), Float32(v2)) => match (v1, v2) {
@@ -1038,6 +1054,7 @@ macro_rules! impl_op_arithmetic {
get_sign!($OPERATION),
true,
)))),
+ // todo: Add Decimal256 support
_ => Err(DataFusionError::Internal(format!(
"Operator {} is not implemented for types {:?} and {:?}",
stringify!($OPERATION),
@@ -1516,6 +1533,11 @@ impl std::hash::Hash for ScalarValue {
p.hash(state);
s.hash(state)
}
+ Decimal256(v, p, s) => {
+ v.hash(state);
+ p.hash(state);
+ s.hash(state)
+ }
Boolean(v) => v.hash(state),
Float32(v) => v.map(Fl).hash(state),
Float64(v) => v.map(Fl).hash(state),
@@ -1994,6 +2016,9 @@ impl ScalarValue {
ScalarValue::Decimal128(_, precision, scale) => {
DataType::Decimal128(*precision, *scale)
}
+ ScalarValue::Decimal256(_, precision, scale) => {
+ DataType::Decimal256(*precision, *scale)
+ }
ScalarValue::TimestampSecond(_, tz_opt) => {
DataType::Timestamp(TimeUnit::Second, tz_opt.clone())
}
@@ -2083,6 +2108,9 @@ impl ScalarValue {
ScalarValue::Decimal128(Some(v), precision, scale) => {
Ok(ScalarValue::Decimal128(Some(-v), *precision, *scale))
}
+ ScalarValue::Decimal256(Some(v), precision, scale) => Ok(
+ ScalarValue::Decimal256(Some(v.neg_wrapping()), *precision,
*scale),
+ ),
value => Err(DataFusionError::Internal(format!(
"Can not run arithmetic negative on scalar value {value:?}"
))),
@@ -2154,6 +2182,7 @@ impl ScalarValue {
ScalarValue::Float32(v) => v.is_none(),
ScalarValue::Float64(v) => v.is_none(),
ScalarValue::Decimal128(v, _, _) => v.is_none(),
+ ScalarValue::Decimal256(v, _, _) => v.is_none(),
ScalarValue::Int8(v) => v.is_none(),
ScalarValue::Int16(v) => v.is_none(),
ScalarValue::Int32(v) => v.is_none(),
@@ -2415,10 +2444,10 @@ impl ScalarValue {
ScalarValue::iter_to_decimal_array(scalars, *precision,
*scale)?;
Arc::new(decimal_array)
}
- DataType::Decimal256(_, _) => {
- return Err(DataFusionError::Internal(
- "Decimal256 is not supported for ScalarValue".to_string(),
- ));
+ DataType::Decimal256(precision, scale) => {
+ let decimal_array =
+ ScalarValue::iter_to_decimal256_array(scalars, *precision,
*scale)?;
+ Arc::new(decimal_array)
}
DataType::Null => ScalarValue::iter_to_null_array(scalars),
DataType::Boolean => build_array_primitive!(BooleanArray, Boolean),
@@ -2680,6 +2709,22 @@ impl ScalarValue {
Ok(array)
}
+    /// Builds a [`Decimal256Array`] from an iterator of `ScalarValue::Decimal256`
+    /// values and stamps the result with `precision` and `scale`.
+    ///
+    /// Only the inner `Option<i256>` of each element is used. Per-element
+    /// precision and scale are ignored in favour of the arguments.
+    ///
+    /// # Errors
+    ///
+    /// Returns `DataFusionError::Internal` if any element is not a
+    /// `ScalarValue::Decimal256`. Also propagates the error from
+    /// `with_precision_and_scale` when `precision`/`scale` are rejected.
+    fn iter_to_decimal256_array(
+        scalars: impl IntoIterator<Item = ScalarValue>,
+        precision: u8,
+        scale: i8,
+    ) -> Result<Decimal256Array> {
+        let array = scalars
+            .into_iter()
+            .map(|element: ScalarValue| match element {
+                ScalarValue::Decimal256(v1, _, _) => Ok(v1),
+                // Report mixed-type input as an error instead of panicking:
+                // this function already returns `Result`.
+                other => Err(DataFusionError::Internal(format!(
+                    "Expected ScalarValue::Decimal256 in iter_to_decimal256_array, got {other:?}"
+                ))),
+            })
+            .collect::<Result<Decimal256Array>>()?
+            .with_precision_and_scale(precision, scale)?;
+        Ok(array)
+    }
+
fn iter_to_array_list(
scalars: impl IntoIterator<Item = ScalarValue>,
data_type: &DataType,
@@ -2764,12 +2809,28 @@ impl ScalarValue {
}
}
+    /// Builds a [`Decimal256Array`] of `size` rows, every row equal to `value`
+    /// (all rows null when `value` is `None`), with the given precision and scale.
+    ///
+    /// # Panics
+    ///
+    /// Panics if `with_precision_and_scale` rejects `precision`/`scale`. That
+    /// can only happen when the originating `ScalarValue::Decimal256` carried
+    /// an invalid precision or scale.
+    fn build_decimal256_array(
+        value: Option<i256>,
+        precision: u8,
+        scale: i8,
+        size: usize,
+    ) -> Decimal256Array {
+        std::iter::repeat(value)
+            .take(size)
+            .collect::<Decimal256Array>()
+            .with_precision_and_scale(precision, scale)
+            .expect("Decimal256 ScalarValue must carry a valid precision and scale")
+    }
+
/// Converts a scalar value into an array of `size` rows.
pub fn to_array_of_size(&self, size: usize) -> ArrayRef {
match self {
ScalarValue::Decimal128(e, precision, scale) => Arc::new(
ScalarValue::build_decimal_array(*e, *precision, *scale, size),
),
+ ScalarValue::Decimal256(e, precision, scale) => Arc::new(
+ ScalarValue::build_decimal256_array(*e, *precision, *scale,
size),
+ ),
ScalarValue::Boolean(e) => {
Arc::new(BooleanArray::from(vec![*e; size])) as ArrayRef
}
@@ -3044,12 +3105,28 @@ impl ScalarValue {
precision: u8,
scale: i8,
) -> Result<ScalarValue> {
- let array = as_decimal128_array(array)?;
- if array.is_null(index) {
- Ok(ScalarValue::Decimal128(None, precision, scale))
- } else {
- let value = array.value(index);
- Ok(ScalarValue::Decimal128(Some(value), precision, scale))
+ match array.data_type() {
+ DataType::Decimal128(_, _) => {
+ let array = as_decimal128_array(array)?;
+ if array.is_null(index) {
+ Ok(ScalarValue::Decimal128(None, precision, scale))
+ } else {
+ let value = array.value(index);
+ Ok(ScalarValue::Decimal128(Some(value), precision, scale))
+ }
+ }
+ DataType::Decimal256(_, _) => {
+ let array = as_decimal256_array(array)?;
+ if array.is_null(index) {
+ Ok(ScalarValue::Decimal256(None, precision, scale))
+ } else {
+ let value = array.value(index);
+ Ok(ScalarValue::Decimal256(Some(value), precision, scale))
+ }
+ }
+ _ => Err(DataFusionError::Internal(
+ "Unsupported decimal type".to_string(),
+ )),
}
}
@@ -3067,6 +3144,11 @@ impl ScalarValue {
array, index, *precision, *scale,
)?
}
+ DataType::Decimal256(precision, scale) => {
+ ScalarValue::get_decimal_value_from_array(
+ array, index, *precision, *scale,
+ )?
+ }
DataType::Boolean => typed_cast!(array, index, BooleanArray,
Boolean),
DataType::Float64 => typed_cast!(array, index, Float64Array,
Float64),
DataType::Float32 => typed_cast!(array, index, Float32Array,
Float32),
@@ -3265,6 +3347,25 @@ impl ScalarValue {
}
}
+    /// Returns whether row `index` of `array` equals the Decimal256 scalar
+    /// described by `value`, `precision` and `scale`.
+    ///
+    /// A precision or scale that differs from the array's type counts as
+    /// "not equal". A `None` value matches only a null row, and a `Some`
+    /// value never matches a null row.
+    ///
+    /// # Errors
+    ///
+    /// Propagates any error from `as_decimal256_array` (presumably raised
+    /// when `array` is not a `Decimal256Array`).
+    fn eq_array_decimal256(
+        array: &ArrayRef,
+        index: usize,
+        value: Option<&i256>,
+        precision: u8,
+        scale: i8,
+    ) -> Result<bool> {
+        let array = as_decimal256_array(array)?;
+        if array.precision() != precision || array.scale() != scale {
+            return Ok(false);
+        }
+        // Check nullness once and reuse it for both the `Some` and `None` cases.
+        let is_null = array.is_null(index);
+        Ok(match value {
+            Some(v) => !is_null && array.value(index) == *v,
+            None => is_null,
+        })
+    }
+
/// Compares a single row of array @ index for equality with self,
/// in an optimized fashion.
///
@@ -3294,6 +3395,16 @@ impl ScalarValue {
)
.unwrap()
}
+ ScalarValue::Decimal256(v, precision, scale) => {
+ ScalarValue::eq_array_decimal256(
+ array,
+ index,
+ v.as_ref(),
+ *precision,
+ *scale,
+ )
+ .unwrap()
+ }
ScalarValue::Boolean(val) => {
eq_array_primitive!(array, index, BooleanArray, val)
}
@@ -3416,6 +3527,7 @@ impl ScalarValue {
| ScalarValue::Float32(_)
| ScalarValue::Float64(_)
| ScalarValue::Decimal128(_, _, _)
+ | ScalarValue::Decimal256(_, _, _)
| ScalarValue::Int8(_)
| ScalarValue::Int16(_)
| ScalarValue::Int32(_)
@@ -3647,6 +3759,22 @@ impl TryFrom<ScalarValue> for i128 {
}
}
+// Special implementation for i256 because of Decimal256: that variant carries
+// precision and scale alongside the value, unlike the single-field variants
+// handled by `impl_try_from!`.
+impl TryFrom<ScalarValue> for i256 {
+    type Error = DataFusionError;
+
+    /// Extracts the non-null `i256` from a `ScalarValue::Decimal256`, ignoring
+    /// its precision and scale.
+    ///
+    /// # Errors
+    ///
+    /// Returns `DataFusionError::Internal` for every other variant, including
+    /// a null `ScalarValue::Decimal256(None, _, _)`.
+    fn try_from(value: ScalarValue) -> Result<Self> {
+        match value {
+            ScalarValue::Decimal256(Some(inner_value), _, _) => Ok(inner_value),
+            _ => Err(DataFusionError::Internal(format!(
+                "Cannot convert {:?} to {}",
+                value,
+                std::any::type_name::<Self>()
+            ))),
+        }
+    }
+}
+
impl_try_from!(UInt8, u8);
impl_try_from!(UInt16, u16);
impl_try_from!(UInt32, u32);
@@ -3684,6 +3812,9 @@ impl TryFrom<&DataType> for ScalarValue {
DataType::Decimal128(precision, scale) => {
ScalarValue::Decimal128(None, *precision, *scale)
}
+ DataType::Decimal256(precision, scale) => {
+ ScalarValue::Decimal256(None, *precision, *scale)
+ }
DataType::Utf8 => ScalarValue::Utf8(None),
DataType::LargeUtf8 => ScalarValue::LargeUtf8(None),
DataType::Binary => ScalarValue::Binary(None),
@@ -3753,6 +3884,9 @@ impl fmt::Display for ScalarValue {
ScalarValue::Decimal128(v, p, s) => {
write!(f, "{v:?},{p:?},{s:?}")?;
}
+ ScalarValue::Decimal256(v, p, s) => {
+ write!(f, "{v:?},{p:?},{s:?}")?;
+ }
ScalarValue::Boolean(e) => format_option!(f, e)?,
ScalarValue::Float32(e) => format_option!(f, e)?,
ScalarValue::Float64(e) => format_option!(f, e)?,
@@ -3830,6 +3964,7 @@ impl fmt::Debug for ScalarValue {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
ScalarValue::Decimal128(_, _, _) => write!(f,
"Decimal128({self})"),
+ ScalarValue::Decimal256(_, _, _) => write!(f,
"Decimal256({self})"),
ScalarValue::Boolean(_) => write!(f, "Boolean({self})"),
ScalarValue::Float32(_) => write!(f, "Float32({self})"),
ScalarValue::Float64(_) => write!(f, "Float64({self})"),
diff --git a/datafusion/core/tests/sqllogictests/test_files/arrow_typeof.slt
b/datafusion/core/tests/sqllogictests/test_files/arrow_typeof.slt
index 4a3d39bdeb..5c82c7e009 100644
--- a/datafusion/core/tests/sqllogictests/test_files/arrow_typeof.slt
+++ b/datafusion/core/tests/sqllogictests/test_files/arrow_typeof.slt
@@ -180,23 +180,29 @@ drop table foo
statement ok
create table foo as select
- arrow_cast(100, 'Decimal128(5,2)') as col_d128
- -- Can't make a decimal 156:
- -- This feature is not implemented: Can't create a scalar from array of type
"Decimal256(3, 2)"
- --arrow_cast(100, 'Decimal256(5,2)') as col_d256
+ arrow_cast(100, 'Decimal128(5,2)') as col_d128,
+ arrow_cast(100, 'Decimal256(5,2)') as col_d256
;
## Ensure each column in the table has the expected type
-query T
+query TT
SELECT
- arrow_typeof(col_d128)
- -- arrow_typeof(col_d256),
+ arrow_typeof(col_d128),
+ arrow_typeof(col_d256)
FROM foo;
----
-Decimal128(5, 2)
+Decimal128(5, 2) Decimal256(5, 2)
+
+# NOTE(review): both columns hold the same value but render differently
+# (`100` vs `100.00`); presumably the test runner normalizes Decimal128 and
+# Decimal256 output differently. Confirm this is intended.
+query RR
+SELECT
+ col_d128,
+ col_d256
+ FROM foo;
+----
+100 100.00
statement ok
drop table foo
diff --git a/datafusion/core/tests/sqllogictests/test_files/decimal.slt
b/datafusion/core/tests/sqllogictests/test_files/decimal.slt
index f413517741..8fd08f87c8 100644
--- a/datafusion/core/tests/sqllogictests/test_files/decimal.slt
+++ b/datafusion/core/tests/sqllogictests/test_files/decimal.slt
@@ -612,3 +612,12 @@ insert into foo VALUES (1, 5);
query error DataFusion error: Arrow error: Compute error: Overflow happened
on: 100000000000000000000 \* 100000000000000000000000000000000000000
select a / b from foo;
+
+# A Decimal256 column can be created, but AVG over it currently fails:
+# ScalarValue arithmetic has no Decimal256 support yet (see the
+# "todo: Add Decimal256 support" in impl_op_arithmetic). The expected error
+# below pins that limitation; update it once the support is added.
+statement ok
+create table t as values (arrow_cast(123, 'Decimal256(5,2)'));
+
+query error DataFusion error: Internal error: Operator \+ is not implemented
for types Decimal256\(None,15,2\) and Decimal256\(Some\(12300\),15,2\)\. This
was likely caused by a bug in DataFusion's code and we would welcome that you
file an bug report in our issue tracker
+select AVG(column1) from t;
+
+statement ok
+drop table t;
diff --git a/datafusion/expr/src/type_coercion/aggregates.rs
b/datafusion/expr/src/type_coercion/aggregates.rs
index 1fccdcbd2c..dec2eb7f12 100644
--- a/datafusion/expr/src/type_coercion/aggregates.rs
+++ b/datafusion/expr/src/type_coercion/aggregates.rs
@@ -17,6 +17,7 @@
use arrow::datatypes::{
DataType, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE,
+ DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE,
};
use datafusion_common::{DataFusionError, Result};
use std::ops::Deref;
@@ -360,6 +361,12 @@ pub fn sum_return_type(arg_type: &DataType) ->
Result<DataType> {
let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 10);
Ok(DataType::Decimal128(new_precision, *scale))
}
+ DataType::Decimal256(precision, scale) => {
+ // in the spark, the result type is DECIMAL(min(38,precision+10),
s)
+ // ref:
https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
+ let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 10);
+ Ok(DataType::Decimal256(new_precision, *scale))
+ }
DataType::Dictionary(_, dict_value_type) => {
sum_return_type(dict_value_type.as_ref())
}
@@ -423,6 +430,13 @@ pub fn avg_return_type(arg_type: &DataType) ->
Result<DataType> {
let new_scale = DECIMAL128_MAX_SCALE.min(*scale + 4);
Ok(DataType::Decimal128(new_precision, new_scale))
}
+ DataType::Decimal256(precision, scale) => {
+ // in the spark, the result type is DECIMAL(min(38,precision+4),
min(38,scale+4)).
+ // ref:
https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66
+ let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 4);
+ let new_scale = DECIMAL256_MAX_SCALE.min(*scale + 4);
+ Ok(DataType::Decimal256(new_precision, new_scale))
+ }
arg_type if NUMERICS.contains(arg_type) => Ok(DataType::Float64),
DataType::Dictionary(_, dict_value_type) => {
avg_return_type(dict_value_type.as_ref())
@@ -441,6 +455,11 @@ pub fn avg_sum_type(arg_type: &DataType) ->
Result<DataType> {
let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 10);
Ok(DataType::Decimal128(new_precision, *scale))
}
+        DataType::Decimal256(precision, scale) => {
+            // Spark's internal sum type for avg is DECIMAL(min(38, precision + 10), s);
+            // for Decimal256 the cap is DECIMAL256_MAX_PRECISION instead of 38.
+            let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 10);
+            Ok(DataType::Decimal256(new_precision, *scale))
+        }
arg_type if NUMERICS.contains(arg_type) => Ok(DataType::Float64),
DataType::Dictionary(_, dict_value_type) => {
avg_sum_type(dict_value_type.as_ref())
diff --git a/datafusion/physical-expr/src/aggregate/average.rs
b/datafusion/physical-expr/src/aggregate/average.rs
index a1d77a2d88..9c01093edf 100644
--- a/datafusion/physical-expr/src/aggregate/average.rs
+++ b/datafusion/physical-expr/src/aggregate/average.rs
@@ -77,12 +77,12 @@ impl Avg {
// the internal sum data type of avg just support FLOAT64 and Decimal
data type.
assert!(matches!(
sum_data_type,
- DataType::Float64 | DataType::Decimal128(_, _)
+ DataType::Float64 | DataType::Decimal128(_, _) |
DataType::Decimal256(_, _)
));
// the result of avg just support FLOAT64 and Decimal data type.
assert!(matches!(
rt_data_type,
- DataType::Float64 | DataType::Decimal128(_, _)
+ DataType::Float64 | DataType::Decimal128(_, _) |
DataType::Decimal256(_, _)
));
Self {
name: name.into(),
diff --git a/datafusion/physical-expr/src/aggregate/sum.rs
b/datafusion/physical-expr/src/aggregate/sum.rs
index 45e2be7fb4..9ac90cef4b 100644
--- a/datafusion/physical-expr/src/aggregate/sum.rs
+++ b/datafusion/physical-expr/src/aggregate/sum.rs
@@ -28,6 +28,7 @@ use crate::expressions::format_state_name;
use crate::{AggregateExpr, GroupsAccumulator, PhysicalExpr};
use arrow::array::Array;
use arrow::array::Decimal128Array;
+use arrow::array::Decimal256Array;
use arrow::compute;
use arrow::compute::kernels::cast;
use arrow::datatypes::DataType;
@@ -39,8 +40,8 @@ use arrow::{
datatypes::Field,
};
use arrow_array::types::{
- Decimal128Type, Float32Type, Float64Type, Int32Type, Int64Type, UInt32Type,
- UInt64Type,
+ Decimal128Type, Decimal256Type, Float32Type, Float64Type, Int32Type,
Int64Type,
+ UInt32Type, UInt64Type,
};
use datafusion_common::{downcast_value, DataFusionError, Result, ScalarValue};
use datafusion_expr::Accumulator;
@@ -169,6 +170,10 @@ impl AggregateExpr for Sum {
instantiate_primitive_accumulator!(self, Decimal128Type, |x,
y| x
.add_assign(y))
}
+ DataType::Decimal256(_, _) => {
+ instantiate_primitive_accumulator!(self, Decimal256Type, |x,
y| *x =
+ *x + y)
+ }
_ => Err(DataFusionError::NotImplemented(format!(
"GroupsAccumulator not supported for {}: {}",
self.name, self.data_type
@@ -250,6 +255,16 @@ fn sum_decimal_batch(values: &ArrayRef, precision: u8,
scale: i8) -> Result<Scal
Ok(ScalarValue::Decimal128(result, precision, scale))
}
+/// Sums a `Decimal256Array` batch into a `ScalarValue::Decimal256` tagged with
+/// the requested `precision` and `scale`.
+///
+/// NOTE(review): arrow documents `compute::sum` as not detecting overflow (the
+/// result wraps around); `compute::sum_checked` is the checked alternative if
+/// that matters for wide Decimal256 sums.
+fn sum_decimal256_batch(
+    values: &ArrayRef,
+    precision: u8,
+    scale: i8,
+) -> Result<ScalarValue> {
+    let decimals = downcast_value!(values, Decimal256Array);
+    Ok(ScalarValue::Decimal256(
+        compute::sum(decimals),
+        precision,
+        scale,
+    ))
+}
+
// sums the array and returns a ScalarValue of its corresponding type.
pub(crate) fn sum_batch(values: &ArrayRef, sum_type: &DataType) ->
Result<ScalarValue> {
// TODO refine the cast kernel in arrow-rs
@@ -263,6 +278,9 @@ pub(crate) fn sum_batch(values: &ArrayRef, sum_type:
&DataType) -> Result<Scalar
DataType::Decimal128(precision, scale) => {
sum_decimal_batch(values, *precision, *scale)?
}
+ DataType::Decimal256(precision, scale) => {
+ sum_decimal256_batch(values, *precision, *scale)?
+ }
DataType::Float64 => typed_sum_delta_batch!(values, Float64Array,
Float64),
DataType::Float32 => typed_sum_delta_batch!(values, Float32Array,
Float32),
DataType::Int64 => typed_sum_delta_batch!(values, Int64Array, Int64),
diff --git a/datafusion/proto/proto/datafusion.proto
b/datafusion/proto/proto/datafusion.proto
index 8192a403d3..f7247effdd 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -908,6 +908,8 @@ message ScalarValue{
//WAS: ScalarType null_list_value = 18;
Decimal128 decimal128_value = 20;
+ Decimal256 decimal256_value = 39;
+
int64 date_64_value = 21;
int32 interval_yearmonth_value = 24;
int64 interval_daytime_value = 25;
@@ -934,6 +936,12 @@ message Decimal128{
int64 s = 3;
}
+// 256-bit decimal scalar value.
+message Decimal256{
+ // Unscaled value as big-endian bytes: written with `i256::to_be_bytes`
+ // in to_proto.rs and read back with `i256::from_be_bytes` in from_proto.rs.
+ bytes value = 1;
+ // Decimal precision.
+ int64 p = 2;
+ // Decimal scale.
+ int64 s = 3;
+}
+
// Serialized data type
message ArrowType{
oneof arrow_type_enum {
diff --git a/datafusion/proto/src/generated/pbjson.rs
b/datafusion/proto/src/generated/pbjson.rs
index 05bfbd089d..aaf6bb97bb 100644
--- a/datafusion/proto/src/generated/pbjson.rs
+++ b/datafusion/proto/src/generated/pbjson.rs
@@ -4983,6 +4983,137 @@ impl<'de> serde::Deserialize<'de> for Decimal128 {
deserializer.deserialize_struct("datafusion.Decimal128", FIELDS,
GeneratedVisitor)
}
}
+// NOTE: this file lives under src/generated and appears to be produced from
+// datafusion.proto; regenerate it rather than editing this impl by hand.
+impl serde::Serialize for Decimal256 {
+ #[allow(deprecated)]
+ fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok,
S::Error>
+ where
+ S: serde::Serializer,
+ {
+ use serde::ser::SerializeStruct;
+ let mut len = 0;
+ if !self.value.is_empty() {
+ len += 1;
+ }
+ if self.p != 0 {
+ len += 1;
+ }
+ if self.s != 0 {
+ len += 1;
+ }
+ let mut struct_ser =
serializer.serialize_struct("datafusion.Decimal256", len)?;
+ // Default-valued fields are omitted; `value` is written as base64 and the
+ // int64 fields as decimal strings.
+ if !self.value.is_empty() {
+ struct_ser.serialize_field("value",
pbjson::private::base64::encode(&self.value).as_str())?;
+ }
+ if self.p != 0 {
+ struct_ser.serialize_field("p",
ToString::to_string(&self.p).as_str())?;
+ }
+ if self.s != 0 {
+ struct_ser.serialize_field("s",
ToString::to_string(&self.s).as_str())?;
+ }
+ struct_ser.end()
+ }
+}
+// NOTE: generated code (src/generated); regenerate from datafusion.proto
+// rather than editing by hand. Accepts the fields "value", "p" and "s",
+// rejecting unknown and duplicate keys.
+impl<'de> serde::Deserialize<'de> for Decimal256 {
+ #[allow(deprecated)]
+ fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
+ where
+ D: serde::Deserializer<'de>,
+ {
+ const FIELDS: &[&str] = &[
+ "value",
+ "p",
+ "s",
+ ];
+
+ #[allow(clippy::enum_variant_names)]
+ enum GeneratedField {
+ Value,
+ P,
+ S,
+ }
+ impl<'de> serde::Deserialize<'de> for GeneratedField {
+ fn deserialize<D>(deserializer: D) ->
std::result::Result<GeneratedField, D::Error>
+ where
+ D: serde::Deserializer<'de>,
+ {
+ struct GeneratedVisitor;
+
+ impl<'de> serde::de::Visitor<'de> for GeneratedVisitor {
+ type Value = GeneratedField;
+
+ fn expecting(&self, formatter: &mut
std::fmt::Formatter<'_>) -> std::fmt::Result {
+ write!(formatter, "expected one of: {:?}", &FIELDS)
+ }
+
+ #[allow(unused_variables)]
+ fn visit_str<E>(self, value: &str) ->
std::result::Result<GeneratedField, E>
+ where
+ E: serde::de::Error,
+ {
+ match value {
+ "value" => Ok(GeneratedField::Value),
+ "p" => Ok(GeneratedField::P),
+ "s" => Ok(GeneratedField::S),
+ _ => Err(serde::de::Error::unknown_field(value,
FIELDS)),
+ }
+ }
+ }
+ deserializer.deserialize_identifier(GeneratedVisitor)
+ }
+ }
+ struct GeneratedVisitor;
+ impl<'de> serde::de::Visitor<'de> for GeneratedVisitor {
+ type Value = Decimal256;
+
+ fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) ->
std::fmt::Result {
+ formatter.write_str("struct datafusion.Decimal256")
+ }
+
+ fn visit_map<V>(self, mut map: V) ->
std::result::Result<Decimal256, V::Error>
+ where
+ V: serde::de::MapAccess<'de>,
+ {
+ let mut value__ = None;
+ let mut p__ = None;
+ let mut s__ = None;
+ while let Some(k) = map.next_key()? {
+ match k {
+ GeneratedField::Value => {
+ if value__.is_some() {
+ return
Err(serde::de::Error::duplicate_field("value"));
+ }
+ value__ =
+
Some(map.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0)
+ ;
+ }
+ GeneratedField::P => {
+ if p__.is_some() {
+ return
Err(serde::de::Error::duplicate_field("p"));
+ }
+ p__ =
+
Some(map.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0)
+ ;
+ }
+ GeneratedField::S => {
+ if s__.is_some() {
+ return
Err(serde::de::Error::duplicate_field("s"));
+ }
+ s__ =
+
Some(map.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0)
+ ;
+ }
+ }
+ }
+ // Absent fields fall back to their defaults (empty bytes / 0).
+ Ok(Decimal256 {
+ value: value__.unwrap_or_default(),
+ p: p__.unwrap_or_default(),
+ s: s__.unwrap_or_default(),
+ })
+ }
+ }
+ deserializer.deserialize_struct("datafusion.Decimal256", FIELDS,
GeneratedVisitor)
+ }
+}
impl serde::Serialize for DfField {
#[allow(deprecated)]
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok,
S::Error>
@@ -19125,6 +19256,9 @@ impl serde::Serialize for ScalarValue {
scalar_value::Value::Decimal128Value(v) => {
struct_ser.serialize_field("decimal128Value", v)?;
}
+ scalar_value::Value::Decimal256Value(v) => {
+ struct_ser.serialize_field("decimal256Value", v)?;
+ }
scalar_value::Value::Date64Value(v) => {
struct_ser.serialize_field("date64Value",
ToString::to_string(&v).as_str())?;
}
@@ -19218,6 +19352,8 @@ impl<'de> serde::Deserialize<'de> for ScalarValue {
"listValue",
"decimal128_value",
"decimal128Value",
+ "decimal256_value",
+ "decimal256Value",
"date_64_value",
"date64Value",
"interval_yearmonth_value",
@@ -19270,6 +19406,7 @@ impl<'de> serde::Deserialize<'de> for ScalarValue {
Time32Value,
ListValue,
Decimal128Value,
+ Decimal256Value,
Date64Value,
IntervalYearmonthValue,
IntervalDaytimeValue,
@@ -19324,6 +19461,7 @@ impl<'de> serde::Deserialize<'de> for ScalarValue {
"time32Value" | "time32_value" =>
Ok(GeneratedField::Time32Value),
"listValue" | "list_value" =>
Ok(GeneratedField::ListValue),
"decimal128Value" | "decimal128_value" =>
Ok(GeneratedField::Decimal128Value),
+ "decimal256Value" | "decimal256_value" =>
Ok(GeneratedField::Decimal256Value),
"date64Value" | "date_64_value" =>
Ok(GeneratedField::Date64Value),
"intervalYearmonthValue" |
"interval_yearmonth_value" => Ok(GeneratedField::IntervalYearmonthValue),
"intervalDaytimeValue" | "interval_daytime_value"
=> Ok(GeneratedField::IntervalDaytimeValue),
@@ -19471,6 +19609,13 @@ impl<'de> serde::Deserialize<'de> for ScalarValue {
return
Err(serde::de::Error::duplicate_field("decimal128Value"));
}
value__ =
map.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::Decimal128Value)
+;
+ }
+ GeneratedField::Decimal256Value => {
+ if value__.is_some() {
+ return
Err(serde::de::Error::duplicate_field("decimal256Value"));
+ }
+ value__ =
map.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::Decimal256Value)
;
}
GeneratedField::Date64Value => {
diff --git a/datafusion/proto/src/generated/prost.rs
b/datafusion/proto/src/generated/prost.rs
index f50754494d..e1ad6acec8 100644
--- a/datafusion/proto/src/generated/prost.rs
+++ b/datafusion/proto/src/generated/prost.rs
@@ -1097,7 +1097,7 @@ pub struct ScalarFixedSizeBinary {
pub struct ScalarValue {
#[prost(
oneof = "scalar_value::Value",
- tags = "33, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 17, 20,
21, 24, 25, 35, 36, 37, 38, 26, 27, 28, 29, 30, 31, 32, 34"
+ tags = "33, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 17, 20,
39, 21, 24, 25, 35, 36, 37, 38, 26, 27, 28, 29, 30, 31, 32, 34"
)]
pub value: ::core::option::Option<scalar_value::Value>,
}
@@ -1146,6 +1146,8 @@ pub mod scalar_value {
ListValue(super::ScalarListValue),
#[prost(message, tag = "20")]
Decimal128Value(super::Decimal128),
+ #[prost(message, tag = "39")]
+ Decimal256Value(super::Decimal256),
#[prost(int64, tag = "21")]
Date64Value(i64),
#[prost(int32, tag = "24")]
@@ -1188,6 +1190,16 @@ pub struct Decimal128 {
#[prost(int64, tag = "3")]
pub s: i64,
}
+/// Protobuf form of a 256-bit decimal scalar (`message Decimal256` in
+/// datafusion.proto; this file is generated, so regenerate rather than edit).
+#[allow(clippy::derive_partial_eq_without_eq)]
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct Decimal256 {
+ /// Big-endian bytes of the unscaled `i256` value.
+ #[prost(bytes = "vec", tag = "1")]
+ pub value: ::prost::alloc::vec::Vec<u8>,
+ /// Decimal precision.
+ #[prost(int64, tag = "2")]
+ pub p: i64,
+ /// Decimal scale.
+ #[prost(int64, tag = "3")]
+ pub s: i64,
+}
/// Serialized data type
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
diff --git a/datafusion/proto/src/logical_plan/from_proto.rs
b/datafusion/proto/src/logical_plan/from_proto.rs
index 674588692d..71a1bf87db 100644
--- a/datafusion/proto/src/logical_plan/from_proto.rs
+++ b/datafusion/proto/src/logical_plan/from_proto.rs
@@ -26,7 +26,7 @@ use crate::protobuf::{
OptimizedPhysicalPlanType, PlaceholderNode, RollupNode,
};
use arrow::datatypes::{
- DataType, Field, IntervalMonthDayNanoType, IntervalUnit, Schema, TimeUnit,
+ i256, DataType, Field, IntervalMonthDayNanoType, IntervalUnit, Schema,
TimeUnit,
UnionFields, UnionMode,
};
use datafusion::execution::registry::FunctionRegistry;
@@ -648,6 +648,14 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue {
val.s as i8,
)
}
+ Value::Decimal256Value(val) => {
+ // `value` holds the big-endian bytes written by `i256::to_be_bytes` in to_proto.
+ // NOTE(review): `vec_to_array` is defined elsewhere; if it panics on a length
+ // mismatch, a malformed message could crash decoding instead of returning an
+ // error. Confirm, and consider a checked conversion. `p`/`s` are also narrowed
+ // with `as`, which silently truncates out-of-range values.
+ let array = vec_to_array(val.value.clone());
+ Self::Decimal256(
+ Some(i256::from_be_bytes(array)),
+ val.p as u8,
+ val.s as i8,
+ )
+ }
Value::Date64Value(v) => Self::Date64(Some(*v)),
Value::Time32Value(v) => {
let time_value =
diff --git a/datafusion/proto/src/logical_plan/to_proto.rs
b/datafusion/proto/src/logical_plan/to_proto.rs
index 072bc84d54..f1a9615761 100644
--- a/datafusion/proto/src/logical_plan/to_proto.rs
+++ b/datafusion/proto/src/logical_plan/to_proto.rs
@@ -1148,6 +1148,24 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue {
)),
}),
},
+ ScalarValue::Decimal256(val, p, s) => match *val {
+ Some(v) => {
+ let array = v.to_be_bytes();
+ let vec_val: Vec<u8> = array.to_vec();
+ Ok(protobuf::ScalarValue {
+ value:
Some(Value::Decimal256Value(protobuf::Decimal256 {
+ value: vec_val,
+ p: *p as i64,
+ s: *s as i64,
+ })),
+ })
+ }
+ None => Ok(protobuf::ScalarValue {
+ value: Some(protobuf::scalar_value::Value::NullValue(
+ (&data_type).try_into()?,
+ )),
+ }),
+ },
ScalarValue::Date64(val) => {
create_proto_scalar(val.as_ref(), &data_type, |s|
Value::Date64Value(*s))
}