This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 87e5adc9c7 feat: Implement the bitwise_not in NotExpr (#5902)
87e5adc9c7 is described below
commit 87e5adc9c7253f9aa50581f82953bfd3800d0b85
Author: RT_Enzyme <[email protected]>
AuthorDate: Wed Apr 12 03:40:40 2023 +0800
feat: Implement the bitwise_not in NotExpr (#5902)
* feat: make NotExpr support integer Bitwise_not
* feat: make NotExpr support integer Bitwise_not
---------
Co-authored-by: RT_Enzyme <[email protected]>
---
.../src/physical_optimizer/sort_enforcement.rs | 3 +-
datafusion/core/tests/sql/expr.rs | 15 +-
datafusion/physical-expr/src/expressions/mod.rs | 2 +-
datafusion/physical-expr/src/expressions/not.rs | 195 ++++++++++++++++++---
4 files changed, 187 insertions(+), 28 deletions(-)
diff --git a/datafusion/core/src/physical_optimizer/sort_enforcement.rs b/datafusion/core/src/physical_optimizer/sort_enforcement.rs
index bada74193b..76628fdcb0 100644
--- a/datafusion/core/src/physical_optimizer/sort_enforcement.rs
+++ b/datafusion/core/src/physical_optimizer/sort_enforcement.rs
@@ -895,7 +895,8 @@ mod tests {
fn create_test_schema() -> Result<SchemaRef> {
let nullable_column = Field::new("nullable_col", DataType::Int32,
true);
- let non_nullable_column = Field::new("non_nullable_col",
DataType::Int32, false);
+ let non_nullable_column =
+ Field::new("non_nullable_col", DataType::Boolean, false);
let schema = Arc::new(Schema::new(vec![nullable_column,
non_nullable_column]));
Ok(schema)
diff --git a/datafusion/core/tests/sql/expr.rs b/datafusion/core/tests/sql/expr.rs
index 672571670e..3d8ccf8667 100644
--- a/datafusion/core/tests/sql/expr.rs
+++ b/datafusion/core/tests/sql/expr.rs
@@ -622,6 +622,17 @@ async fn test_not_expressions() -> Result<()> {
];
assert_batches_eq!(expected, &actual);
+ let sql = "SELECT not(1), not(0)";
+ let actual = execute_to_batches(&ctx, sql).await;
+ let expected = vec![
+ "+--------------+--------------+",
+ "| NOT Int64(1) | NOT Int64(0) |",
+ "+--------------+--------------+",
+ "| -2 | -1 |",
+ "+--------------+--------------+",
+ ];
+ assert_batches_eq!(expected, &actual);
+
let sql = "SELECT null, not(null)";
let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
@@ -638,9 +649,7 @@ async fn test_not_expressions() -> Result<()> {
match result {
Ok(_) => panic!("expected error"),
Err(e) => {
- assert_contains!(e.to_string(),
- "NOT 'Literal { value: Utf8(\"hi\") }' can't be
evaluated because the expression's type is Utf8, not boolean or NULL"
- );
+ assert_contains!(e.to_string(), "Can't NOT or BITWISE_NOT
datatype: 'Utf8'");
}
}
Ok(())
diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs
index 63fb7b7d37..ad4b7031c0 100644
--- a/datafusion/physical-expr/src/expressions/mod.rs
+++ b/datafusion/physical-expr/src/expressions/mod.rs
@@ -87,7 +87,7 @@ pub use like::{like, LikeExpr};
pub use literal::{lit, Literal};
pub use negative::{negative, NegativeExpr};
pub use no_op::NoOp;
-pub use not::{not, NotExpr};
+pub use not::{bitwise_not, not, NotExpr};
pub use nullif::nullif_func;
pub use try_cast::{try_cast, TryCastExpr};
diff --git a/datafusion/physical-expr/src/expressions/not.rs b/datafusion/physical-expr/src/expressions/not.rs
index bf935aa97e..ee44b91c9d 100644
--- a/datafusion/physical-expr/src/expressions/not.rs
+++ b/datafusion/physical-expr/src/expressions/not.rs
@@ -23,8 +23,12 @@ use std::sync::Arc;
use crate::physical_expr::down_cast_any_ref;
use crate::PhysicalExpr;
-use arrow::datatypes::{DataType, Schema};
+use arrow::datatypes::{
+ DataType, Int16Type, Int32Type, Int64Type, Int8Type, Schema, UInt16Type,
UInt32Type,
+ UInt64Type, UInt8Type,
+};
use arrow::record_batch::RecordBatch;
+use datafusion_common::cast::as_primitive_array;
use datafusion_common::{cast::as_boolean_array, DataFusionError, Result,
ScalarValue};
use datafusion_expr::ColumnarValue;
@@ -59,8 +63,25 @@ impl PhysicalExpr for NotExpr {
self
}
-    fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
-        Ok(DataType::Boolean)
+    /// Returns the output type of this NOT expression.
+    ///
+    /// `NOT` is logical for `Boolean` inputs and bitwise for integer inputs.
+    /// In both cases, and for `Null`, the output type is the input type.
+    /// Any other input type is rejected with an internal error.
+    fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
+        let data_type = self.arg.data_type(input_schema)?;
+        match data_type {
+            DataType::Boolean
+            | DataType::UInt8
+            | DataType::UInt16
+            | DataType::UInt32
+            | DataType::UInt64
+            | DataType::Int8
+            | DataType::Int16
+            | DataType::Int32
+            | DataType::Int64
+            | DataType::Null => Ok(data_type),
+            _ => Err(DataFusionError::Internal(format!(
+                "Can't NOT or BITWISE_NOT datatype: '{:?}'",
+                data_type
+            ))),
+        }
     }
fn nullable(&self, input_schema: &Schema) -> Result<bool> {
@@ -71,26 +92,46 @@ impl PhysicalExpr for NotExpr {
let evaluate_arg = self.arg.evaluate(batch)?;
match evaluate_arg {
ColumnarValue::Array(array) => {
- let array = as_boolean_array(&array)?;
- Ok(ColumnarValue::Array(Arc::new(
- arrow::compute::kernels::boolean::not(array)?,
- )))
+ match array.data_type() {
+ DataType::Boolean => {
+ let array = as_boolean_array(&array)?;
+ Ok(ColumnarValue::Array(Arc::new(
+ arrow::compute::kernels::boolean::not(array)?,
+ )))
+ },
+ DataType::UInt8 => expr_array_not!(&array, UInt8Type),
+ DataType::UInt16 => expr_array_not!(&array, UInt16Type),
+ DataType::UInt32 => expr_array_not!(&array, UInt32Type),
+ DataType::UInt64 => expr_array_not!(&array, UInt64Type),
+ DataType::Int8 => expr_array_not!(&array, Int8Type),
+ DataType::Int16 => expr_array_not!(&array, Int16Type),
+ DataType::Int32 => expr_array_not!(&array, Int32Type),
+ DataType::Int64 => expr_array_not!(&array, Int64Type),
+ _ => Err(DataFusionError::Internal(format!(
+ "NOT or Bitwise_not can't be evaluated because the
expression's typs is {:?}, not boolean or integer",
+ array.data_type(),
+ )))
+ }
}
ColumnarValue::Scalar(scalar) => {
- if scalar.is_null() {
- return
Ok(ColumnarValue::Scalar(ScalarValue::Boolean(None)));
+ match scalar {
+ ScalarValue::Boolean(v) => expr_not!(v, Boolean),
+ ScalarValue::Int8(v) => expr_not!(v, Int8),
+ ScalarValue::Int16(v) => expr_not!(v, Int16),
+ ScalarValue::Int32(v) => expr_not!(v, Int32),
+ ScalarValue::Int64(v) => expr_not!(v, Int64),
+ ScalarValue::UInt8(v) => expr_not!(v, UInt8),
+ ScalarValue::UInt16(v) => expr_not!(v, UInt16),
+ ScalarValue::UInt32(v) => expr_not!(v, UInt32),
+ ScalarValue::UInt64(v) => expr_not!(v, UInt64),
+ ScalarValue::Null =>
Ok(ColumnarValue::Scalar(ScalarValue::Null)),
+ _ => {
+ Err(DataFusionError::Internal(format!(
+ "NOT/BITWISE_NOT '{:?}' can't be evaluated because
the expression's type is {:?}, not boolean or NULL or Integer",
+ self.arg, scalar.get_datatype(),
+ )))
+ }
}
- let value_type = scalar.get_datatype();
- if value_type != DataType::Boolean {
- return Err(DataFusionError::Internal(format!(
- "NOT '{:?}' can't be evaluated because the
expression's type is {:?}, not boolean or NULL",
- self.arg, value_type,
- )));
- }
- let bool_value: bool = scalar.try_into()?;
- Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(
- !bool_value,
- ))))
}
}
}
@@ -121,12 +162,46 @@ pub fn not(arg: Arc<dyn PhysicalExpr>) -> Result<Arc<dyn PhysicalExpr>> {
Ok(Arc::new(NotExpr::new(arg)))
}
+/// Create a unary BITWISE_NOT expression.
+///
+/// This builds the same [`NotExpr`] as [`not`]. Whether evaluation performs a
+/// logical or a bitwise NOT is decided by the argument's data type, so no
+/// separate expression node is needed.
+pub fn bitwise_not(arg: Arc<dyn PhysicalExpr>) -> Result<Arc<dyn PhysicalExpr>> {
+    not(arg)
+}
+
+/// Apply `!` to an optional scalar value and re-wrap the result in the same
+/// `ScalarValue` variant. A `None` (NULL) input stays `None`.
+///
+/// For `Boolean` this is a logical NOT; for the integer variants Rust's `!`
+/// is the bitwise complement.
+macro_rules! expr_not {
+    ($VALUE:expr, $SCALAR_TY:ident) => {
+        Ok(ColumnarValue::Scalar(ScalarValue::$SCALAR_TY(
+            $VALUE.map(|v| !v),
+        )))
+    };
+}
+
+/// Compute the bitwise NOT of an integer array.
+///
+/// `$ARRAY` is downcast with `as_primitive_array::<$PRIMITIVE_TY>` and handed
+/// to arrow's `bitwise_not` kernel. Errors from either step are propagated
+/// with `?` out of the enclosing function.
+macro_rules! expr_array_not {
+    ($ARRAY:expr, $PRIMITIVE_TY:ident) => {{
+        let primitive = as_primitive_array::<$PRIMITIVE_TY>($ARRAY)?;
+        let negated = arrow::compute::kernels::bitwise::bitwise_not(primitive)?;
+        Ok(ColumnarValue::Array(Arc::new(negated)))
+    }};
+}
+
+use expr_array_not;
+use expr_not;
+
#[cfg(test)]
mod tests {
use super::*;
- use crate::expressions::col;
- use arrow::{array::BooleanArray, datatypes::*};
- use datafusion_common::Result;
+ use crate::expressions::{col, lit};
+ use arrow::{array::*, datatypes::*};
+ use arrow_schema::DataType::{
+ Int16, Int32, Int64, Int8, UInt16, UInt32, UInt64, UInt8,
+ };
+ use datafusion_common::{
+ cast::{as_boolean_array, as_primitive_array},
+ Result,
+ };
+ use paste::paste;
#[test]
fn neg_op() -> Result<()> {
@@ -149,4 +224,78 @@ mod tests {
Ok(())
}
+
+    /// `bitwise_not` on scalar literals flips every bit and keeps the input's
+    /// scalar type. A NULL input stays NULL.
+    #[test]
+    fn scalar_bitwise_not_op() -> Result<()> {
+        // `evaluate` needs a batch. Every case below is a literal, so a
+        // one-row dummy batch is enough.
+        let schema = Arc::new(Schema::new(vec![Field::new("t", DataType::UInt8, true)]));
+        let dummy_input = Arc::new(UInt8Array::from(vec![Some(1u8)]));
+        let batch = RecordBatch::try_new(schema, vec![dummy_input])?;
+
+        let cases = [
+            (lit(0u8), ScalarValue::UInt8(Some(255))),
+            (lit(1u32), ScalarValue::UInt32(Some(u32::MAX - 1))),
+            (lit(ScalarValue::UInt16(None)), ScalarValue::UInt16(None)),
+            // 3i8: 0000 0011 => !3i8: 1111 1100 = -4
+            (lit(3i8), ScalarValue::Int8(Some(-4i8))),
+        ];
+
+        for (input, expected) in cases {
+            match bitwise_not(input)?.evaluate(&batch)? {
+                ColumnarValue::Scalar(actual) => assert_eq!(actual, expected),
+                _ => unreachable!("should be ColumnarValue::Scalar datatype"),
+            }
+        }
+
+        Ok(())
+    }
+
+    /// Build a nullable `$DATA_TY` column holding the given values followed by
+    /// a NULL. Apply `not` to it and compare the result with the element-wise
+    /// `!` computed in Rust, with the trailing NULL preserved.
+    macro_rules! test_array_bitwise_not_op {
+        ($DATA_TY:tt, $($VALUE:expr),* ) => {
+            let schema = Schema::new(vec![Field::new("a", DataType::$DATA_TY, true)]);
+            let expr = not(col("a", &schema)?)?;
+            assert_eq!(expr.data_type(&schema)?, DataType::$DATA_TY);
+            assert!(expr.nullable(&schema)?);
+            let mut arr = Vec::new();
+            let mut arr_expected = Vec::new();
+            $(
+                arr.push(Some($VALUE));
+                arr_expected.push(Some(!$VALUE));
+            )*
+            arr.push(None);
+            arr_expected.push(None);
+            let input = paste!{[<$DATA_TY Array>]::from(arr)};
+            let expected = &paste!{[<$DATA_TY Array>]::from(arr_expected)};
+            let batch =
+                RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(input)])?;
+            let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
+            // Only format the panic message on failure (clippy::expect_fun_call).
+            let result = as_primitive_array(&result).unwrap_or_else(|e| {
+                panic!("failed to downcast to {:?}Array: {}", $DATA_TY, e)
+            });
+            assert_eq!(result, expected);
+        };
+    }
+
+    /// Array-level `not` on every supported integer type.
+    #[test]
+    fn array_bitwise_not_op() -> Result<()> {
+        // The macro always appends a NULL row. The MIN/MAX values cover the
+        // all-zeros <-> all-ones case for unsigned types and the sign-bit case
+        // (!MIN == MAX) for signed types.
+        test_array_bitwise_not_op!(UInt8, 0x1, 0xF2, u8::MIN, u8::MAX);
+        test_array_bitwise_not_op!(UInt16, 32u16, 255u16, u16::MIN, u16::MAX);
+        test_array_bitwise_not_op!(UInt32, 144u32, 166u32, u32::MIN, u32::MAX);
+        test_array_bitwise_not_op!(UInt64, 123u64, 321u64, u64::MIN, u64::MAX);
+        test_array_bitwise_not_op!(Int8, -1i8, 1i8, i8::MIN, i8::MAX);
+        test_array_bitwise_not_op!(Int16, -123i16, 123i16, i16::MIN, i16::MAX);
+        test_array_bitwise_not_op!(Int32, -1234i32, 1234i32, i32::MIN, i32::MAX);
+        test_array_bitwise_not_op!(Int64, -12345i64, 12345i64, i64::MIN, i64::MAX);
+        Ok(())
+    }
}