This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new c07c26cd1e Fix incorrect results in `BitAnd` GroupsAccumulator (#6957)
c07c26cd1e is described below
commit c07c26cd1e237cda4f8db332f6b7acec3ab4055c
Author: Andrew Lamb <[email protected]>
AuthorDate: Thu Jul 13 13:19:34 2023 -0400
Fix incorrect results in `BitAnd` GroupsAccumulator (#6957)
Fix accumulator
---
.../tests/sqllogictests/test_files/aggregate.slt | 184 +++++++++++++--------
.../physical-expr/src/aggregate/bit_and_or_xor.rs | 83 ++++------
.../src/aggregate/groups_accumulator/prim_op.rs | 12 +-
3 files changed, 160 insertions(+), 119 deletions(-)
diff --git a/datafusion/core/tests/sqllogictests/test_files/aggregate.slt
b/datafusion/core/tests/sqllogictests/test_files/aggregate.slt
index 95cf51d571..72b9e8400b 100644
--- a/datafusion/core/tests/sqllogictests/test_files/aggregate.slt
+++ b/datafusion/core/tests/sqllogictests/test_files/aggregate.slt
@@ -1420,65 +1420,95 @@ select var(sq.column1), var_pop(sq.column1),
stddev(sq.column1), stddev_pop(sq.c
2 1 1.414213562373 1
-# sum / count for all nulls
-statement ok
-create table the_nulls as values (null::bigint, 1), (null::bigint, 1),
(null::bigint, 2);
-# counts should be zeros (even for nulls)
-query II
-SELECT count(column1), column2 from the_nulls group by column2 order by
column2;
-----
-0 1
-0 2
-
-# sums should be null
-query II
-SELECT sum(column1), column2 from the_nulls group by column2 order by column2;
+# aggregates on empty tables
+statement ok
+CREATE TABLE empty (column1 bigint, column2 int);
+
+# no group by column
+query IIRIIIII
+SELECT
+ count(column1), -- counts should be zero, even for nulls
+ sum(column1), -- other aggregates should be null
+ avg(column1),
+ min(column1),
+ max(column1),
+ bit_and(column1),
+ bit_or(column1),
+ bit_xor(column1)
+FROM empty
+----
+0 NULL NULL NULL NULL NULL NULL NULL
+
+# Same query but with grouping (no groups, so no output)
+query IIRIIIIII
+SELECT
+ count(column1),
+ sum(column1),
+ avg(column1),
+ min(column1),
+ max(column1),
+ bit_and(column1),
+ bit_or(column1),
+ bit_xor(column1),
+ column2
+FROM empty
+GROUP BY column2
+ORDER BY column2;
----
-NULL 1
-NULL 2
-# avg should be null
-query RI
-SELECT avg(column1), column2 from the_nulls group by column2 order by column2;
-----
-NULL 1
-NULL 2
-# bit_and should be null
-query II
-SELECT bit_and(column1), column2 from the_nulls group by column2 order by
column2;
-----
-NULL 1
-NULL 2
+statement ok
+drop table empty
-# bit_or should be null
-query II
-SELECT bit_or(column1), column2 from the_nulls group by column2 order by
column2;
-----
-NULL 1
-NULL 2
+# aggregates on all nulls
+statement ok
+CREATE TABLE the_nulls
+AS VALUES
+ (null::bigint, 1),
+ (null::bigint, 1),
+ (null::bigint, 2);
-# bit_xor should be null
query II
-SELECT bit_xor(column1), column2 from the_nulls group by column2 order by
column2;
+select * from the_nulls
----
NULL 1
-NULL 2
-
-# min should be null
-query II
-SELECT min(column1), column2 from the_nulls group by column2 order by column2;
-----
NULL 1
NULL 2
-# max should be null
-query II
-SELECT max(column1), column2 from the_nulls group by column2 order by column2;
-----
-NULL 1
-NULL 2
+# no group by column
+query IIRIIIII
+SELECT
+ count(column1), -- counts should be zero, even for nulls
+ sum(column1), -- other aggregates should be null
+ avg(column1),
+ min(column1),
+ max(column1),
+ bit_and(column1),
+ bit_or(column1),
+ bit_xor(column1)
+FROM the_nulls
+----
+0 NULL NULL NULL NULL NULL NULL NULL
+
+# Same query but with grouping
+query IIRIIIIII
+SELECT
+ count(column1), -- counts should be zero, even for nulls
+ sum(column1), -- other aggregates should be null
+ avg(column1),
+ min(column1),
+ max(column1),
+ bit_and(column1),
+ bit_or(column1),
+ bit_xor(column1),
+ column2
+FROM the_nulls
+GROUP BY column2
+ORDER BY column2;
+----
+0 NULL NULL NULL NULL NULL NULL NULL 1
+0 NULL NULL NULL NULL NULL NULL NULL 2
statement ok
@@ -1489,29 +1519,49 @@ create table bit_aggregate_functions (
c1 SMALLINT NOT NULL,
c2 SMALLINT NOT NULL,
c3 SMALLINT,
+ tag varchar
)
as values
- (5, 10, 11),
- (33, 11, null),
- (9, 12, null);
-
-# query_bit_and
-query III
-SELECT bit_and(c1), bit_and(c2), bit_and(c3) FROM bit_aggregate_functions
-----
-1 8 11
-
-# query_bit_or
-query III
-SELECT bit_or(c1), bit_or(c2), bit_or(c3) FROM bit_aggregate_functions
-----
-45 15 11
+ (5, 10, 11, 'A'),
+ (33, 11, null, 'B'),
+ (9, 12, null, 'A');
+
+# query_bit_and, query_bit_or, query_bit_xor
+query IIIIIIIII
+SELECT
+ bit_and(c1),
+ bit_and(c2),
+ bit_and(c3),
+ bit_or(c1),
+ bit_or(c2),
+ bit_or(c3),
+ bit_xor(c1),
+ bit_xor(c2),
+ bit_xor(c3)
+FROM bit_aggregate_functions
+----
+1 8 11 45 15 11 45 13 11
+
+# query_bit_and, query_bit_or, query_bit_xor, with group
+query IIIIIIIIIT
+SELECT
+ bit_and(c1),
+ bit_and(c2),
+ bit_and(c3),
+ bit_or(c1),
+ bit_or(c2),
+ bit_or(c3),
+ bit_xor(c1),
+ bit_xor(c2),
+ bit_xor(c3),
+ tag
+FROM bit_aggregate_functions
+GROUP BY tag
+ORDER BY tag
+----
+1 8 11 13 14 11 12 6 11 A
+33 11 NULL 33 11 NULL 33 11 NULL B
-# query_bit_xor
-query III
-SELECT bit_xor(c1), bit_xor(c2), bit_xor(c3) FROM bit_aggregate_functions
-----
-45 13 11
statement ok
create table bool_aggregate_functions (
diff --git a/datafusion/physical-expr/src/aggregate/bit_and_or_xor.rs
b/datafusion/physical-expr/src/aggregate/bit_and_or_xor.rs
index ab37e5891e..6a2d509389 100644
--- a/datafusion/physical-expr/src/aggregate/bit_and_or_xor.rs
+++ b/datafusion/physical-expr/src/aggregate/bit_and_or_xor.rs
@@ -49,15 +49,16 @@ use arrow::compute::{bit_and, bit_or, bit_xor};
use datafusion_row::accessor::RowAccessor;
/// Creates a [`PrimitiveGroupsAccumulator`] with the specified
-/// [`ArrowPrimitiveType`] which applies `$FN` to each element
+/// [`ArrowPrimitiveType`] that initializes each accumulator to `$START`
+/// and applies `$FN` to each element
///
/// [`ArrowPrimitiveType`]: arrow::datatypes::ArrowPrimitiveType
-macro_rules! instantiate_primitive_accumulator {
- ($SELF:expr, $PRIMTYPE:ident, $FN:expr) => {{
- Ok(Box::new(PrimitiveGroupsAccumulator::<$PRIMTYPE, _>::new(
- &$SELF.data_type,
- $FN,
- )))
+macro_rules! instantiate_accumulator {
+ ($SELF:expr, $START:expr, $PRIMTYPE:ident, $FN:expr) => {{
+ Ok(Box::new(
+ PrimitiveGroupsAccumulator::<$PRIMTYPE, _>::new(&$SELF.data_type,
$FN)
+ .with_starting_value($START),
+ ))
}};
}
@@ -279,35 +280,31 @@ impl AggregateExpr for BitAnd {
use std::ops::BitAndAssign;
match self.data_type {
DataType::Int8 => {
- instantiate_primitive_accumulator!(self, Int8Type, |x, y| x
- .bitand_assign(y))
+ instantiate_accumulator!(self, -1, Int8Type, |x, y|
x.bitand_assign(y))
}
DataType::Int16 => {
- instantiate_primitive_accumulator!(self, Int16Type, |x, y| x
- .bitand_assign(y))
+ instantiate_accumulator!(self, -1, Int16Type, |x, y|
x.bitand_assign(y))
}
DataType::Int32 => {
- instantiate_primitive_accumulator!(self, Int32Type, |x, y| x
- .bitand_assign(y))
+ instantiate_accumulator!(self, -1, Int32Type, |x, y|
x.bitand_assign(y))
}
DataType::Int64 => {
- instantiate_primitive_accumulator!(self, Int64Type, |x, y| x
- .bitand_assign(y))
+ instantiate_accumulator!(self, -1, Int64Type, |x, y|
x.bitand_assign(y))
}
DataType::UInt8 => {
- instantiate_primitive_accumulator!(self, UInt8Type, |x, y| x
+ instantiate_accumulator!(self, u8::MAX, UInt8Type, |x, y| x
.bitand_assign(y))
}
DataType::UInt16 => {
- instantiate_primitive_accumulator!(self, UInt16Type, |x, y| x
+ instantiate_accumulator!(self, u16::MAX, UInt16Type, |x, y| x
.bitand_assign(y))
}
DataType::UInt32 => {
- instantiate_primitive_accumulator!(self, UInt32Type, |x, y| x
+ instantiate_accumulator!(self, u32::MAX, UInt32Type, |x, y| x
.bitand_assign(y))
}
DataType::UInt64 => {
- instantiate_primitive_accumulator!(self, UInt64Type, |x, y| x
+ instantiate_accumulator!(self, u64::MAX, UInt64Type, |x, y| x
.bitand_assign(y))
}
@@ -517,36 +514,28 @@ impl AggregateExpr for BitOr {
use std::ops::BitOrAssign;
match self.data_type {
DataType::Int8 => {
- instantiate_primitive_accumulator!(self, Int8Type, |x, y| x
- .bitor_assign(y))
+ instantiate_accumulator!(self, 0, Int8Type, |x, y|
x.bitor_assign(y))
}
DataType::Int16 => {
- instantiate_primitive_accumulator!(self, Int16Type, |x, y| x
- .bitor_assign(y))
+ instantiate_accumulator!(self, 0, Int16Type, |x, y|
x.bitor_assign(y))
}
DataType::Int32 => {
- instantiate_primitive_accumulator!(self, Int32Type, |x, y| x
- .bitor_assign(y))
+ instantiate_accumulator!(self, 0, Int32Type, |x, y|
x.bitor_assign(y))
}
DataType::Int64 => {
- instantiate_primitive_accumulator!(self, Int64Type, |x, y| x
- .bitor_assign(y))
+ instantiate_accumulator!(self, 0, Int64Type, |x, y|
x.bitor_assign(y))
}
DataType::UInt8 => {
- instantiate_primitive_accumulator!(self, UInt8Type, |x, y| x
- .bitor_assign(y))
+ instantiate_accumulator!(self, 0, UInt8Type, |x, y|
x.bitor_assign(y))
}
DataType::UInt16 => {
- instantiate_primitive_accumulator!(self, UInt16Type, |x, y| x
- .bitor_assign(y))
+ instantiate_accumulator!(self, 0, UInt16Type, |x, y|
x.bitor_assign(y))
}
DataType::UInt32 => {
- instantiate_primitive_accumulator!(self, UInt32Type, |x, y| x
- .bitor_assign(y))
+ instantiate_accumulator!(self, 0, UInt32Type, |x, y|
x.bitor_assign(y))
}
DataType::UInt64 => {
- instantiate_primitive_accumulator!(self, UInt64Type, |x, y| x
- .bitor_assign(y))
+ instantiate_accumulator!(self, 0, UInt64Type, |x, y|
x.bitor_assign(y))
}
_ => Err(DataFusionError::NotImplemented(format!(
@@ -756,36 +745,28 @@ impl AggregateExpr for BitXor {
use std::ops::BitXorAssign;
match self.data_type {
DataType::Int8 => {
- instantiate_primitive_accumulator!(self, Int8Type, |x, y| x
- .bitxor_assign(y))
+ instantiate_accumulator!(self, 0, Int8Type, |x, y|
x.bitxor_assign(y))
}
DataType::Int16 => {
- instantiate_primitive_accumulator!(self, Int16Type, |x, y| x
- .bitxor_assign(y))
+ instantiate_accumulator!(self, 0, Int16Type, |x, y|
x.bitxor_assign(y))
}
DataType::Int32 => {
- instantiate_primitive_accumulator!(self, Int32Type, |x, y| x
- .bitxor_assign(y))
+ instantiate_accumulator!(self, 0, Int32Type, |x, y|
x.bitxor_assign(y))
}
DataType::Int64 => {
- instantiate_primitive_accumulator!(self, Int64Type, |x, y| x
- .bitxor_assign(y))
+ instantiate_accumulator!(self, 0, Int64Type, |x, y|
x.bitxor_assign(y))
}
DataType::UInt8 => {
- instantiate_primitive_accumulator!(self, UInt8Type, |x, y| x
- .bitxor_assign(y))
+ instantiate_accumulator!(self, 0, UInt8Type, |x, y|
x.bitxor_assign(y))
}
DataType::UInt16 => {
- instantiate_primitive_accumulator!(self, UInt16Type, |x, y| x
- .bitxor_assign(y))
+ instantiate_accumulator!(self, 0, UInt16Type, |x, y|
x.bitxor_assign(y))
}
DataType::UInt32 => {
- instantiate_primitive_accumulator!(self, UInt32Type, |x, y| x
- .bitxor_assign(y))
+ instantiate_accumulator!(self, 0, UInt32Type, |x, y|
x.bitxor_assign(y))
}
DataType::UInt64 => {
- instantiate_primitive_accumulator!(self, UInt64Type, |x, y| x
- .bitxor_assign(y))
+ instantiate_accumulator!(self, 0, UInt64Type, |x, y|
x.bitxor_assign(y))
}
_ => Err(DataFusionError::NotImplemented(format!(
diff --git
a/datafusion/physical-expr/src/aggregate/groups_accumulator/prim_op.rs
b/datafusion/physical-expr/src/aggregate/groups_accumulator/prim_op.rs
index 8603010789..a49651a5e3 100644
--- a/datafusion/physical-expr/src/aggregate/groups_accumulator/prim_op.rs
+++ b/datafusion/physical-expr/src/aggregate/groups_accumulator/prim_op.rs
@@ -47,6 +47,9 @@ where
/// The output type (needed for Decimal precision and scale)
data_type: DataType,
+ /// The starting value for new groups
+ starting_value: T::Native,
+
/// Track nulls in the input / filters
null_state: NullState,
@@ -64,9 +67,16 @@ where
values: vec![],
data_type: data_type.clone(),
null_state: NullState::new(),
+ starting_value: T::default_value(),
prim_fn,
}
}
+
+ /// Set the starting value used to initialize each new group
+ pub fn with_starting_value(mut self, starting_value: T::Native) -> Self {
+ self.starting_value = starting_value;
+ self
+ }
}
impl<T, F> GroupsAccumulator for PrimitiveGroupsAccumulator<T, F>
@@ -85,7 +95,7 @@ where
let values = values[0].as_primitive::<T>();
// update values
- self.values.resize(total_num_groups, T::default_value());
+ self.values.resize(total_num_groups, self.starting_value);
// NullState dispatches / handles tracking nulls and groups that saw
no values
self.null_state.accumulate(