This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new 73f14051 chore: Move more expressions from core crate to spark-expr crate (#1152)
73f14051 is described below
commit 73f14051adc3bfa513adc54a9af157928472ee0b
Author: Andy Grove <[email protected]>
AuthorDate: Mon Dec 9 11:45:12 2024 -0700
chore: Move more expressions from core crate to spark-expr crate (#1152)
* move aggregate expressions to spark-expr crate
* move more expressions
* move benchmark
* normalize_nan
* bitwise not
* comet scalar funcs
* update bench imports
---
native/Cargo.lock | 2 ++
native/Cargo.toml | 1 +
native/core/Cargo.toml | 7 +---
native/core/src/common/bit.rs | 6 ++--
.../datafusion/expressions/checkoverflow.rs | 15 +--------
.../src/execution/datafusion/expressions/mod.rs | 11 ------
native/core/src/execution/datafusion/planner.rs | 18 +++-------
native/core/src/lib.rs | 27 ---------------
native/core/src/parquet/read/levels.rs | 7 ++--
native/core/src/parquet/read/values.rs | 2 +-
native/spark-expr/Cargo.toml | 8 +++++
native/{core => spark-expr}/benches/aggregate.rs | 6 ++--
.../expressions => spark-expr/src}/avg.rs | 0
.../expressions => spark-expr/src}/avg_decimal.rs | 2 +-
.../expressions => spark-expr/src}/bitwise_not.rs | 18 +---------
.../src}/comet_scalar_funcs.rs | 6 ++--
.../expressions => spark-expr/src}/correlation.rs | 5 ++-
.../expressions => spark-expr/src}/covariance.rs | 0
native/spark-expr/src/lib.rs | 20 +++++++++++
.../src}/normalize_nan.rs | 0
.../expressions => spark-expr/src}/stddev.rs | 2 +-
.../expressions => spark-expr/src}/sum_decimal.rs | 10 ++----
native/spark-expr/src/utils.rs | 39 +++++++++++++++++++++-
.../expressions => spark-expr/src}/variance.rs | 0
24 files changed, 96 insertions(+), 116 deletions(-)
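
For context: code that previously reached into the core crate's comet::execution::datafusion::expressions module now imports these types from the spark-expr crate root. A minimal sketch of the new import surface (paths taken from the lib.rs and utils.rs hunks below; the importing module itself is hypothetical):

    use datafusion_comet_spark_expr::{Avg, AvgDecimal, BitwiseNotExpr, NormalizeNaNAndZero, SumDecimal};
    use datafusion_comet_spark_expr::utils::{is_valid_decimal_precision, likely, unlikely};
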
diff --git a/native/Cargo.lock b/native/Cargo.lock
index a7f8359d..67d041a3 100644
--- a/native/Cargo.lock
+++ b/native/Cargo.lock
@@ -942,10 +942,12 @@ dependencies = [
"datafusion-common",
"datafusion-expr",
"datafusion-physical-expr",
+ "futures",
"num",
"rand",
"regex",
"thiserror",
+ "tokio",
"twox-hash 2.0.1",
]
diff --git a/native/Cargo.toml b/native/Cargo.toml
index 85c46a6d..4ac85479 100644
--- a/native/Cargo.toml
+++ b/native/Cargo.toml
datafusion-comet-spark-expr = { path = "spark-expr", version = "0.5.0" }
datafusion-comet-proto = { path = "proto", version = "0.5.0" }
chrono = { version = "0.4", default-features = false, features = ["clock"] }
chrono-tz = { version = "0.8" }
+futures = "0.3.28"
num = "0.4"
rand = "0.8"
regex = "1.9.6"
diff --git a/native/core/Cargo.toml b/native/core/Cargo.toml
index daa0837c..4b9753ec 100644
--- a/native/core/Cargo.toml
+++ b/native/core/Cargo.toml
@@ -42,7 +42,7 @@ arrow-data = { workspace = true }
arrow-schema = { workspace = true }
parquet = { workspace = true, default-features = false, features = ["experimental"] }
half = { version = "2.4.1", default-features = false }
-futures = "0.3.28"
+futures = { workspace = true }
mimalloc = { version = "*", default-features = false, optional = true }
tokio = { version = "1", features = ["rt-multi-thread"] }
async-trait = "0.1"
@@ -88,7 +88,6 @@ hex = "0.4.3"
[features]
default = []
-nightly = []
[lib]
name = "comet"
@@ -123,10 +122,6 @@ harness = false
name = "filter"
harness = false
-[[bench]]
-name = "aggregate"
-harness = false
-
[[bench]]
name = "bloom_filter_agg"
harness = false
diff --git a/native/core/src/common/bit.rs b/native/core/src/common/bit.rs
index 871786bb..72d7729d 100644
--- a/native/core/src/common/bit.rs
+++ b/native/core/src/common/bit.rs
@@ -17,14 +17,12 @@
use std::{cmp::min, mem::size_of};
-use arrow::buffer::Buffer;
-
use crate::{
errors::CometResult as Result,
- likely,
parquet::{data_type::AsBytes, util::bit_packing::unpack32},
- unlikely,
};
+use arrow::buffer::Buffer;
+use datafusion_comet_spark_expr::utils::{likely, unlikely};
#[inline]
pub fn from_ne_slice<T: FromBytes>(bs: &[u8]) -> T {
diff --git a/native/core/src/execution/datafusion/expressions/checkoverflow.rs b/native/core/src/execution/datafusion/expressions/checkoverflow.rs
index ed03ab66..e922171b 100644
--- a/native/core/src/execution/datafusion/expressions/checkoverflow.rs
+++ b/native/core/src/execution/datafusion/expressions/checkoverflow.rs
@@ -27,8 +27,7 @@ use arrow::{
datatypes::{Decimal128Type, DecimalType},
record_batch::RecordBatch,
};
-use arrow_data::decimal::{MAX_DECIMAL_FOR_EACH_PRECISION, MIN_DECIMAL_FOR_EACH_PRECISION};
-use arrow_schema::{DataType, Schema, DECIMAL128_MAX_PRECISION};
+use arrow_schema::{DataType, Schema};
use datafusion::logical_expr::ColumnarValue;
use datafusion::physical_expr_common::physical_expr::down_cast_any_ref;
use datafusion_common::{DataFusionError, ScalarValue};
@@ -172,15 +171,3 @@ impl PhysicalExpr for CheckOverflow {
self.hash(&mut s);
}
}
-
-/// Adapted from arrow-rs `validate_decimal_precision` but returns bool
-/// instead of Err to avoid the cost of formatting the error strings, and is
-/// optimized to remove a memcpy that exists in the original function.
-/// We can remove this code once we upgrade to a version of arrow-rs that
-/// includes https://github.com/apache/arrow-rs/pull/6419
-#[inline]
-pub fn is_valid_decimal_precision(value: i128, precision: u8) -> bool {
- precision <= DECIMAL128_MAX_PRECISION
- && value >= MIN_DECIMAL_FOR_EACH_PRECISION[precision as usize - 1]
- && value <= MAX_DECIMAL_FOR_EACH_PRECISION[precision as usize - 1]
-}
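
Note for reviewers: this helper is not deleted, it moves to the spark-expr crate's utils module (see the utils.rs hunk near the end of this diff). Its contract: a Decimal128 value is valid for precision p iff it fits in p decimal digits. A small illustration, assuming the new crate path:

    use datafusion_comet_spark_expr::utils::is_valid_decimal_precision;

    #[test]
    fn decimal_precision_bounds() {
        // precision 5 admits values in [-99_999, 99_999]
        assert!(is_valid_decimal_precision(99_999, 5));
        assert!(!is_valid_decimal_precision(100_000, 5));
    }
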
diff --git a/native/core/src/execution/datafusion/expressions/mod.rs b/native/core/src/execution/datafusion/expressions/mod.rs
index 48b80384..2bb14df3 100644
--- a/native/core/src/execution/datafusion/expressions/mod.rs
+++ b/native/core/src/execution/datafusion/expressions/mod.rs
@@ -17,26 +17,15 @@
//! Native DataFusion expressions
-pub mod bitwise_not;
pub mod checkoverflow;
-mod normalize_nan;
-pub use normalize_nan::NormalizeNaNAndZero;
use crate::errors::CometError;
-pub mod avg;
-pub mod avg_decimal;
pub mod bloom_filter_agg;
pub mod bloom_filter_might_contain;
-pub mod comet_scalar_funcs;
-pub mod correlation;
-pub mod covariance;
pub mod negative;
-pub mod stddev;
pub mod strings;
pub mod subquery;
-pub mod sum_decimal;
pub mod unbound;
-pub mod variance;
pub use datafusion_comet_spark_expr::{EvalMode, SparkError};
diff --git a/native/core/src/execution/datafusion/planner.rs b/native/core/src/execution/datafusion/planner.rs
index 33c4924c..a83dba5d 100644
--- a/native/core/src/execution/datafusion/planner.rs
+++ b/native/core/src/execution/datafusion/planner.rs
@@ -18,29 +18,19 @@
//! Converts Spark physical plan to DataFusion physical plan
use super::expressions::EvalMode;
-use crate::execution::datafusion::expressions::comet_scalar_funcs::create_comet_physical_fun;
use crate::execution::operators::{CopyMode, FilterExec};
use crate::{
errors::ExpressionError,
execution::{
datafusion::{
expressions::{
- avg::Avg,
- avg_decimal::AvgDecimal,
- bitwise_not::BitwiseNotExpr,
bloom_filter_agg::BloomFilterAgg,
bloom_filter_might_contain::BloomFilterMightContain,
checkoverflow::CheckOverflow,
- correlation::Correlation,
- covariance::Covariance,
negative,
- stddev::Stddev,
strings::{Contains, EndsWith, Like, StartsWith, StringSpaceExpr, SubstringExpr},
subquery::Subquery,
- sum_decimal::SumDecimal,
unbound::UnboundColumn,
- variance::Variance,
- NormalizeNaNAndZero,
},
operators::expand::CometExpandExec,
shuffle_writer::ShuffleWriterExec,
@@ -82,6 +72,7 @@ use datafusion::{
},
prelude::SessionContext,
};
+use datafusion_comet_spark_expr::create_comet_physical_fun;
use datafusion_functions_nested::concat::ArrayAppend;
use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr};
@@ -99,9 +90,10 @@ use datafusion_comet_proto::{
spark_partitioning::{partitioning::PartitioningStruct, Partitioning as SparkPartitioning},
};
use datafusion_comet_spark_expr::{
-    ArrayInsert, Cast, CreateNamedStruct, DateTruncExpr, GetArrayStructFields, GetStructField,
-    HourExpr, IfExpr, ListExtract, MinuteExpr, RLike, SecondExpr, SparkCastOptions,
-    TimestampTruncExpr, ToJson,
+    ArrayInsert, Avg, AvgDecimal, BitwiseNotExpr, Cast, Correlation, Covariance, CreateNamedStruct,
+    DateTruncExpr, GetArrayStructFields, GetStructField, HourExpr, IfExpr, ListExtract, MinuteExpr,
+    NormalizeNaNAndZero, RLike, SecondExpr, SparkCastOptions, Stddev, SumDecimal,
+    TimestampTruncExpr, ToJson, Variance,
};
use datafusion_common::scalar::ScalarStructBuilder;
use datafusion_common::{
diff --git a/native/core/src/lib.rs b/native/core/src/lib.rs
index c6a7a414..68c8ae72 100644
--- a/native/core/src/lib.rs
+++ b/native/core/src/lib.rs
@@ -104,30 +104,3 @@ fn default_logger_config() -> CometResult<Config> {
.build(root)
.map_err(|err| CometError::Config(err.to_string()))
}
-
-// These are borrowed from hashbrown crate:
-// https://github.com/rust-lang/hashbrown/blob/master/src/raw/mod.rs
-
-// On stable we can use #[cold] to get an equivalent effect: this attribute
-// suggests that the function is unlikely to be called
-#[cfg(not(feature = "nightly"))]
-#[inline]
-#[cold]
-fn cold() {}
-
-#[cfg(not(feature = "nightly"))]
-#[inline]
-fn likely(b: bool) -> bool {
- if !b {
- cold();
- }
- b
-}
-#[cfg(not(feature = "nightly"))]
-#[inline]
-fn unlikely(b: bool) -> bool {
- if b {
- cold();
- }
- b
-}
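
These branch-prediction hints move to spark-expr unchanged, except that the unused nightly feature gate is dropped (matching the Cargo.toml hunk above). They mark one side of a branch as cold so the compiler can lay out the hot path first. A hedged usage sketch; the accumulate function is hypothetical:

    use datafusion_comet_spark_expr::utils::unlikely;

    fn accumulate(sum: &mut i64, v: i64) -> Result<(), String> {
        let (s, overflowed) = sum.overflowing_add(v);
        if unlikely(overflowed) {
            // cold path: error formatting is costly but should be rare
            return Err("i64 overflow in accumulate".to_string());
        }
        *sum = s;
        Ok(())
    }
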
diff --git a/native/core/src/parquet/read/levels.rs b/native/core/src/parquet/read/levels.rs
index 3d74b277..9077c0e4 100644
--- a/native/core/src/parquet/read/levels.rs
+++ b/native/core/src/parquet/read/levels.rs
@@ -17,15 +17,14 @@
use std::mem;
-use arrow::buffer::Buffer;
-use parquet::schema::types::ColumnDescPtr;
-
use super::values::Decoder;
use crate::{
common::bit::{self, read_u32, BitReader},
parquet::ParquetMutableVector,
- unlikely,
};
+use arrow::buffer::Buffer;
+use datafusion_comet_spark_expr::utils::unlikely;
+use parquet::schema::types::ColumnDescPtr;
const INITIAL_BUF_LEN: usize = 16;
diff --git a/native/core/src/parquet/read/values.rs b/native/core/src/parquet/read/values.rs
index b439e29e..71cd035d 100644
--- a/native/core/src/parquet/read/values.rs
+++ b/native/core/src/parquet/read/values.rs
@@ -28,9 +28,9 @@ use crate::write_val_or_null;
use crate::{
common::bit::{self, BitReader},
parquet::{data_type::*, ParquetMutableVector},
- unlikely,
};
use arrow::datatypes::DataType as ArrowDataType;
+use datafusion_comet_spark_expr::utils::unlikely;
pub fn get_decoder<T: DataType>(
value_data: Buffer,
diff --git a/native/spark-expr/Cargo.toml b/native/spark-expr/Cargo.toml
index 532bf743..65517431 100644
--- a/native/spark-expr/Cargo.toml
+++ b/native/spark-expr/Cargo.toml
@@ -29,6 +29,7 @@ edition = { workspace = true }
[dependencies]
arrow = { workspace = true }
arrow-array = { workspace = true }
+arrow-data = { workspace = true }
arrow-schema = { workspace = true }
chrono = { workspace = true }
datafusion = { workspace = true }
@@ -39,12 +40,14 @@ chrono-tz = { workspace = true }
num = { workspace = true }
regex = { workspace = true }
thiserror = { workspace = true }
+futures = { workspace = true }
twox-hash = "2.0.0"
[dev-dependencies]
arrow-data = {workspace = true}
criterion = "0.5.1"
rand = { workspace = true}
+tokio = { version = "1", features = ["rt-multi-thread"] }
[lib]
@@ -66,3 +69,8 @@ harness = false
[[bench]]
name = "decimal_div"
harness = false
+
+[[bench]]
+name = "aggregate"
+harness = false
+
diff --git a/native/core/benches/aggregate.rs b/native/spark-expr/benches/aggregate.rs
similarity index 97%
rename from native/core/benches/aggregate.rs
rename to native/spark-expr/benches/aggregate.rs
index c6209406..43194fdd 100644
--- a/native/core/benches/aggregate.rs
+++ b/native/spark-expr/benches/aggregate.rs
@@ -19,16 +19,16 @@ use arrow::datatypes::{DataType, Field, Schema};
use arrow_array::builder::{Decimal128Builder, StringBuilder};
use arrow_array::{ArrayRef, RecordBatch};
use arrow_schema::SchemaRef;
-use comet::execution::datafusion::expressions::avg_decimal::AvgDecimal;
-use comet::execution::datafusion::expressions::sum_decimal::SumDecimal;
use criterion::{black_box, criterion_group, criterion_main, Criterion};
+use datafusion::execution::TaskContext;
use datafusion::functions_aggregate::average::avg_udaf;
use datafusion::functions_aggregate::sum::sum_udaf;
use datafusion::physical_expr::PhysicalExpr;
use datafusion::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy};
use datafusion::physical_plan::memory::MemoryExec;
use datafusion::physical_plan::ExecutionPlan;
-use datafusion_execution::TaskContext;
+use datafusion_comet_spark_expr::AvgDecimal;
+use datafusion_comet_spark_expr::SumDecimal;
use datafusion_expr::AggregateUDF;
use datafusion_physical_expr::aggregate::AggregateExprBuilder;
use datafusion_physical_expr::expressions::Column;
diff --git a/native/core/src/execution/datafusion/expressions/avg.rs b/native/spark-expr/src/avg.rs
similarity index 100%
rename from native/core/src/execution/datafusion/expressions/avg.rs
rename to native/spark-expr/src/avg.rs
diff --git a/native/core/src/execution/datafusion/expressions/avg_decimal.rs b/native/spark-expr/src/avg_decimal.rs
similarity index 99%
rename from native/core/src/execution/datafusion/expressions/avg_decimal.rs
rename to native/spark-expr/src/avg_decimal.rs
index a265fdc2..163e1560 100644
--- a/native/core/src/execution/datafusion/expressions/avg_decimal.rs
+++ b/native/spark-expr/src/avg_decimal.rs
@@ -28,7 +28,7 @@ use datafusion_common::{not_impl_err, Result, ScalarValue};
use datafusion_physical_expr::{expressions::format_state_name, PhysicalExpr};
use std::{any::Any, sync::Arc};
-use crate::execution::datafusion::expressions::checkoverflow::is_valid_decimal_precision;
+use crate::utils::is_valid_decimal_precision;
use arrow_array::ArrowNativeTypeOp;
use arrow_data::decimal::{MAX_DECIMAL_FOR_EACH_PRECISION, MIN_DECIMAL_FOR_EACH_PRECISION};
use datafusion::logical_expr::Volatility::Immutable;
diff --git a/native/core/src/execution/datafusion/expressions/bitwise_not.rs b/native/spark-expr/src/bitwise_not.rs
similarity index 88%
rename from native/core/src/execution/datafusion/expressions/bitwise_not.rs
rename to native/spark-expr/src/bitwise_not.rs
index a2b9ebe5..36234935 100644
--- a/native/core/src/execution/datafusion/expressions/bitwise_not.rs
+++ b/native/spark-expr/src/bitwise_not.rs
@@ -28,7 +28,7 @@ use arrow::{
};
use datafusion::physical_expr_common::physical_expr::down_cast_any_ref;
use datafusion::{error::DataFusionError, logical_expr::ColumnarValue};
-use datafusion_common::{Result, ScalarValue};
+use datafusion_common::Result;
use datafusion_physical_expr::PhysicalExpr;
macro_rules! compute_op {
@@ -135,22 +135,6 @@ pub fn bitwise_not(arg: Arc<dyn PhysicalExpr>) -> Result<Arc<dyn PhysicalExpr>>
Ok(Arc::new(BitwiseNotExpr::new(arg)))
}
-fn scalar_bitwise_not(scalar: ScalarValue) -> Result<ScalarValue> {
- match scalar {
- ScalarValue::Int8(None)
- | ScalarValue::Int16(None)
- | ScalarValue::Int32(None)
- | ScalarValue::Int64(None) => Ok(scalar),
- ScalarValue::Int8(Some(v)) => Ok(ScalarValue::Int8(Some(!v))),
- ScalarValue::Int16(Some(v)) => Ok(ScalarValue::Int16(Some(!v))),
- ScalarValue::Int32(Some(v)) => Ok(ScalarValue::Int32(Some(!v))),
- ScalarValue::Int64(Some(v)) => Ok(ScalarValue::Int64(Some(!v))),
- value => Err(DataFusionError::Internal(format!(
- "Can not run ! on scalar value {value:?}"
- ))),
- }
-}
-
#[cfg(test)]
mod tests {
use arrow::datatypes::*;
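
The scalar helper removed above appears to have had no remaining callers; the array evaluation path in BitwiseNotExpr is unchanged. A sketch of building the expression through the helper that the crate now re-exports (the column name and index are illustrative):

    use std::sync::Arc;
    use datafusion_comet_spark_expr::bitwise_not;
    use datafusion_physical_expr::expressions::Column;
    use datafusion_physical_expr::PhysicalExpr;

    // wraps the input in BitwiseNotExpr, which evaluates element-wise !x
    // over Int8/Int16/Int32/Int64 arrays
    fn build_not_a() -> Arc<dyn PhysicalExpr> {
        bitwise_not(Arc::new(Column::new("a", 0))).unwrap()
    }
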
diff --git a/native/core/src/execution/datafusion/expressions/comet_scalar_funcs.rs b/native/spark-expr/src/comet_scalar_funcs.rs
similarity index 98%
rename from native/core/src/execution/datafusion/expressions/comet_scalar_funcs.rs
rename to native/spark-expr/src/comet_scalar_funcs.rs
index 06717aab..71ff0e9d 100644
--- a/native/core/src/execution/datafusion/expressions/comet_scalar_funcs.rs
+++ b/native/spark-expr/src/comet_scalar_funcs.rs
@@ -15,15 +15,15 @@
// specific language governing permissions and limitations
// under the License.
-use arrow_schema::DataType;
-use datafusion_comet_spark_expr::scalar_funcs::hash_expressions::{
+use crate::scalar_funcs::hash_expressions::{
spark_sha224, spark_sha256, spark_sha384, spark_sha512,
};
-use datafusion_comet_spark_expr::scalar_funcs::{
+use crate::scalar_funcs::{
    spark_ceil, spark_date_add, spark_date_sub, spark_decimal_div, spark_floor, spark_hex,
    spark_isnan, spark_make_decimal, spark_murmur3_hash, spark_read_side_padding, spark_round,
spark_unhex, spark_unscaled_value, spark_xxhash64, SparkChrFunc,
};
+use arrow_schema::DataType;
use datafusion_common::{DataFusionError, Result as DataFusionResult};
use datafusion_expr::registry::FunctionRegistry;
use datafusion_expr::{
diff --git a/native/core/src/execution/datafusion/expressions/correlation.rs b/native/spark-expr/src/correlation.rs
similarity index 98%
rename from native/core/src/execution/datafusion/expressions/correlation.rs
rename to native/spark-expr/src/correlation.rs
index 6bf35e71..e5f36c6f 100644
--- a/native/core/src/execution/datafusion/expressions/correlation.rs
+++ b/native/spark-expr/src/correlation.rs
@@ -19,9 +19,8 @@ use arrow::compute::{and, filter, is_not_null};
use std::{any::Any, sync::Arc};
-use crate::execution::datafusion::expressions::{
- covariance::CovarianceAccumulator, stddev::StddevAccumulator,
-};
+use crate::covariance::CovarianceAccumulator;
+use crate::stddev::StddevAccumulator;
use arrow::{
array::ArrayRef,
datatypes::{DataType, Field},
diff --git a/native/core/src/execution/datafusion/expressions/covariance.rs b/native/spark-expr/src/covariance.rs
similarity index 100%
rename from native/core/src/execution/datafusion/expressions/covariance.rs
rename to native/spark-expr/src/covariance.rs
diff --git a/native/spark-expr/src/lib.rs b/native/spark-expr/src/lib.rs
index c227b3a0..15f446ef 100644
--- a/native/spark-expr/src/lib.rs
+++ b/native/spark-expr/src/lib.rs
@@ -23,18 +23,38 @@ mod cast;
mod error;
mod if_expr;
+mod avg;
+pub use avg::Avg;
+mod bitwise_not;
+pub use bitwise_not::{bitwise_not, BitwiseNotExpr};
+mod avg_decimal;
+pub use avg_decimal::AvgDecimal;
+mod correlation;
+pub use correlation::Correlation;
+mod covariance;
+pub use covariance::Covariance;
mod kernels;
mod list;
mod regexp;
pub mod scalar_funcs;
pub mod spark_hash;
+mod stddev;
+pub use stddev::Stddev;
mod structs;
+mod sum_decimal;
+pub use sum_decimal::SumDecimal;
+mod normalize_nan;
mod temporal;
pub mod timezone;
mod to_json;
pub mod utils;
+pub use normalize_nan::NormalizeNaNAndZero;
+mod variance;
+pub use variance::Variance;
+mod comet_scalar_funcs;
pub use cast::{spark_cast, Cast, SparkCastOptions};
+pub use comet_scalar_funcs::create_comet_physical_fun;
pub use error::{SparkError, SparkResult};
pub use if_expr::IfExpr;
pub use list::{ArrayInsert, GetArrayStructFields, ListExtract};
diff --git a/native/core/src/execution/datafusion/expressions/normalize_nan.rs b/native/spark-expr/src/normalize_nan.rs
similarity index 100%
rename from native/core/src/execution/datafusion/expressions/normalize_nan.rs
rename to native/spark-expr/src/normalize_nan.rs
diff --git a/native/core/src/execution/datafusion/expressions/stddev.rs b/native/spark-expr/src/stddev.rs
similarity index 98%
rename from native/core/src/execution/datafusion/expressions/stddev.rs
rename to native/spark-expr/src/stddev.rs
index 1ba495e2..3cf604da 100644
--- a/native/core/src/execution/datafusion/expressions/stddev.rs
+++ b/native/spark-expr/src/stddev.rs
@@ -17,7 +17,7 @@
use std::{any::Any, sync::Arc};
-use crate::execution::datafusion::expressions::variance::VarianceAccumulator;
+use crate::variance::VarianceAccumulator;
use arrow::{
array::ArrayRef,
datatypes::{DataType, Field},
diff --git a/native/core/src/execution/datafusion/expressions/sum_decimal.rs b/native/spark-expr/src/sum_decimal.rs
similarity index 98%
rename from native/core/src/execution/datafusion/expressions/sum_decimal.rs
rename to native/spark-expr/src/sum_decimal.rs
index d885ff90..ab142aee 100644
--- a/native/core/src/execution/datafusion/expressions/sum_decimal.rs
+++ b/native/spark-expr/src/sum_decimal.rs
@@ -15,8 +15,7 @@
// specific language governing permissions and limitations
// under the License.
-use crate::execution::datafusion::expressions::checkoverflow::is_valid_decimal_precision;
-use crate::unlikely;
+use crate::utils::{is_valid_decimal_precision, unlikely};
use arrow::{
array::BooleanBufferBuilder,
buffer::{BooleanBuffer, NullBuffer},
@@ -113,7 +112,6 @@ impl AggregateUDFImpl for SumDecimal {
Ok(Box::new(SumDecimalGroupsAccumulator::new(
self.result_type.clone(),
self.precision,
- self.scale,
)))
}
@@ -286,18 +284,16 @@ struct SumDecimalGroupsAccumulator {
sum: Vec<i128>,
result_type: DataType,
precision: u8,
- scale: i8,
}
impl SumDecimalGroupsAccumulator {
- fn new(result_type: DataType, precision: u8, scale: i8) -> Self {
+ fn new(result_type: DataType, precision: u8) -> Self {
Self {
is_not_null: BooleanBufferBuilder::new(0),
is_empty: BooleanBufferBuilder::new(0),
sum: Vec::new(),
result_type,
precision,
- scale,
}
}
@@ -488,11 +484,11 @@ mod tests {
use arrow::datatypes::*;
use arrow_array::builder::{Decimal128Builder, StringBuilder};
use arrow_array::RecordBatch;
+ use datafusion::execution::TaskContext;
    use datafusion::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy};
use datafusion::physical_plan::memory::MemoryExec;
use datafusion::physical_plan::ExecutionPlan;
use datafusion_common::Result;
- use datafusion_execution::TaskContext;
use datafusion_expr::AggregateUDF;
use datafusion_physical_expr::aggregate::AggregateExprBuilder;
use datafusion_physical_expr::expressions::{Column, Literal};
diff --git a/native/spark-expr/src/utils.rs b/native/spark-expr/src/utils.rs
index db4ad195..18a2314f 100644
--- a/native/spark-expr/src/utils.rs
+++ b/native/spark-expr/src/utils.rs
@@ -19,7 +19,7 @@ use arrow_array::{
cast::as_primitive_array,
types::{Int32Type, TimestampMicrosecondType},
};
-use arrow_schema::{ArrowError, DataType};
+use arrow_schema::{ArrowError, DataType, DECIMAL128_MAX_PRECISION};
use std::sync::Arc;
use crate::timezone::Tz;
@@ -27,6 +27,7 @@ use arrow::{
array::{as_dictionary_array, Array, ArrayRef, PrimitiveArray},
temporal_conversions::as_datetime,
};
+use arrow_data::decimal::{MAX_DECIMAL_FOR_EACH_PRECISION, MIN_DECIMAL_FOR_EACH_PRECISION};
use chrono::{DateTime, Offset, TimeZone};
/// Preprocesses input arrays to add timezone information from Spark to Arrow array datatype or
@@ -176,3 +177,39 @@ fn pre_timestamp_cast(array: ArrayRef, timezone: String) -> Result<ArrayRef, Arr
_ => Ok(array),
}
}
+
+/// Adapted from arrow-rs `validate_decimal_precision` but returns bool
+/// instead of Err to avoid the cost of formatting the error strings, and is
+/// optimized to remove a memcpy that exists in the original function.
+/// We can remove this code once we upgrade to a version of arrow-rs that
+/// includes https://github.com/apache/arrow-rs/pull/6419
+#[inline]
+pub fn is_valid_decimal_precision(value: i128, precision: u8) -> bool {
+ precision <= DECIMAL128_MAX_PRECISION
+ && value >= MIN_DECIMAL_FOR_EACH_PRECISION[precision as usize - 1]
+ && value <= MAX_DECIMAL_FOR_EACH_PRECISION[precision as usize - 1]
+}
+
+// These are borrowed from hashbrown crate:
+// https://github.com/rust-lang/hashbrown/blob/master/src/raw/mod.rs
+
+// On stable we can use #[cold] to get an equivalent effect: this attribute
+// suggests that the function is unlikely to be called
+#[inline]
+#[cold]
+pub fn cold() {}
+
+#[inline]
+pub fn likely(b: bool) -> bool {
+ if !b {
+ cold();
+ }
+ b
+}
+#[inline]
+pub fn unlikely(b: bool) -> bool {
+ if b {
+ cold();
+ }
+ b
+}
diff --git a/native/core/src/execution/datafusion/expressions/variance.rs b/native/spark-expr/src/variance.rs
similarity index 100%
rename from native/core/src/execution/datafusion/expressions/variance.rs
rename to native/spark-expr/src/variance.rs