This is an automated email from the ASF dual-hosted git repository.
jayzhan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new b7f477237c Port ArrayHas family to `functions-array` (#9496)
b7f477237c is described below
commit b7f477237cc91bbdf90e655db181ca4f0a64fc25
Author: Jay Zhan <[email protected]>
AuthorDate: Sat Mar 9 08:52:41 2024 +0800
Port ArrayHas family to `functions-array` (#9496)
* array has rewrite
Signed-off-by: jayzhan211 <[email protected]>
* first draft
Signed-off-by: jayzhan211 <[email protected]>
* rm dims
Signed-off-by: jayzhan211 <[email protected]>
* replace optimizer
Signed-off-by: jayzhan211 <[email protected]>
* remove proto and import udf
Signed-off-by: jayzhan211 <[email protected]>
* Remove unnecessary dependency
* Add doc
Co-authored-by: Andrew Lamb <[email protected]>
* fmt
Signed-off-by: jayzhan211 <[email protected]>
---------
Signed-off-by: jayzhan211 <[email protected]>
Co-authored-by: Andrew Lamb <[email protected]>
---
datafusion-cli/Cargo.lock | 40 +--
datafusion/expr/src/built_in_function.rs | 25 +-
datafusion/expr/src/expr_fn.rs | 18 --
datafusion/functions-array/Cargo.toml | 1 +
datafusion/functions-array/src/array_has.rs | 306 +++++++++++++++++++++
datafusion/functions-array/src/lib.rs | 8 +
datafusion/functions-array/src/utils.rs | 34 +++
datafusion/optimizer/Cargo.toml | 4 +-
datafusion/optimizer/src/analyzer/mod.rs | 3 +
datafusion/optimizer/src/analyzer/rewrite_expr.rs | 51 ++++
datafusion/physical-expr/src/array_expressions.rs | 127 ---------
datafusion/physical-expr/src/expressions/binary.rs | 6 +-
datafusion/physical-expr/src/functions.rs | 9 -
datafusion/proto/proto/datafusion.proto | 6 +-
datafusion/proto/src/generated/pbjson.rs | 9 -
datafusion/proto/src/generated/prost.rs | 12 +-
datafusion/proto/src/logical_plan/from_proto.rs | 30 +-
datafusion/proto/src/logical_plan/to_proto.rs | 3 -
18 files changed, 444 insertions(+), 248 deletions(-)
diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock
index 3afd26a6e7..5e3c8648fc 100644
--- a/datafusion-cli/Cargo.lock
+++ b/datafusion-cli/Cargo.lock
@@ -808,9 +808,9 @@ dependencies = [
[[package]]
name = "bumpalo"
-version = "3.15.3"
+version = "3.15.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "8ea184aa71bb362a1157c896979544cc23974e08fd265f29ea96b59f0b4a555b"
+checksum = "7ff69b9dd49fd426c69a0db9fc04dd934cdb6645ff000864d98f7e2af8830eaa"
[[package]]
name = "byteorder"
@@ -857,9 +857,9 @@ dependencies = [
[[package]]
name = "cc"
-version = "1.0.89"
+version = "1.0.90"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a0ba8f7aaa012f30d5b2861462f6708eccd49c3c39863fe083a308035f63d723"
+checksum = "8cd6604a82acf3039f1144f54b8eb34e91ffba622051189e71b781822d5ee1f5"
dependencies = [
"jobserver",
"libc",
@@ -873,9 +873,9 @@ checksum =
"baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]]
name = "chrono"
-version = "0.4.34"
+version = "0.4.35"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "5bc015644b92d5890fab7489e49d21f879d5c990186827d42ec511919404f38b"
+checksum = "8eaf5903dcbc0a39312feb77df2ff4c76387d591b9fc7b04a238dcf8bb62639a"
dependencies = [
"android-tzdata",
"iana-time-zone",
@@ -1268,6 +1268,7 @@ dependencies = [
"datafusion-common",
"datafusion-execution",
"datafusion-expr",
+ "itertools",
"log",
"paste",
]
@@ -2344,9 +2345,9 @@ dependencies = [
[[package]]
name = "object_store"
-version = "0.9.0"
+version = "0.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "d139f545f64630e2e3688fd9f81c470888ab01edeb72d13b4e86c566f1130000"
+checksum = "b8718f8b65fdf67a45108d1548347d4af7d71fb81ce727bbf9e3b2535e079db3"
dependencies = [
"async-trait",
"base64 0.21.7",
@@ -2356,6 +2357,7 @@ dependencies = [
"humantime",
"hyper",
"itertools",
+ "md-5",
"parking_lot",
"percent-encoding",
"quick-xml",
@@ -2534,18 +2536,18 @@ dependencies = [
[[package]]
name = "pin-project"
-version = "1.1.4"
+version = "1.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "0302c4a0442c456bd56f841aee5c3bfd17967563f6fadc9ceb9f9c23cf3807e0"
+checksum = "b6bf43b791c5b9e34c3d182969b4abb522f9343702850a2e57f460d00d09b4b3"
dependencies = [
"pin-project-internal",
]
[[package]]
name = "pin-project-internal"
-version = "1.1.4"
+version = "1.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "266c042b60c9c76b8d53061e52b2e0d1116abc57cefc8c5cd671619a56ac3690"
+checksum = "2f38a4412a78282e09a2cf38d195ea5420d15ba0602cb375210efbc877243965"
dependencies = [
"proc-macro2",
"quote",
@@ -2767,9 +2769,9 @@ checksum =
"c08c74e62047bb2de4ff487b251e4a92e24f48745648451635cec7d591162d9f"
[[package]]
name = "reqwest"
-version = "0.11.24"
+version = "0.11.25"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c6920094eb85afde5e4a138be3f2de8bbdf28000f0029e72c45025a56b042251"
+checksum = "0eea5a9eb898d3783f17c6407670e3592fd174cb81a10e51d4c37f49450b9946"
dependencies = [
"base64 0.21.7",
"bytes",
@@ -3323,20 +3325,20 @@ checksum =
"2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160"
[[package]]
name = "system-configuration"
-version = "0.5.1"
+version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "ba3a3adc5c275d719af8cb4272ea1c4a6d668a777f37e115f6d11ddbc1c8e0e7"
+checksum = "658bc6ee10a9b4fcf576e9b0819d95ec16f4d2c02d39fd83ac1c8789785c4a42"
dependencies = [
- "bitflags 1.3.2",
+ "bitflags 2.4.2",
"core-foundation",
"system-configuration-sys",
]
[[package]]
name = "system-configuration-sys"
-version = "0.5.0"
+version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a75fb188eb626b924683e3b95e3a48e63551fcfb51949de2f06a9d91dbee93c9"
+checksum = "8e1d1b10ced5ca923a1fcb8d03e96b8d3268065d724548c0211415ff6ac6bac4"
dependencies = [
"core-foundation-sys",
"libc",
diff --git a/datafusion/expr/src/built_in_function.rs
b/datafusion/expr/src/built_in_function.rs
index a763a58379..be10da3669 100644
--- a/datafusion/expr/src/built_in_function.rs
+++ b/datafusion/expr/src/built_in_function.rs
@@ -118,12 +118,6 @@ pub enum BuiltinScalarFunction {
ArraySort,
/// array_concat
ArrayConcat,
- /// array_has
- ArrayHas,
- /// array_has_all
- ArrayHasAll,
- /// array_has_any
- ArrayHasAny,
/// array_pop_front
ArrayPopFront,
/// array_pop_back
@@ -367,9 +361,6 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::ArraySort => Volatility::Immutable,
BuiltinScalarFunction::ArrayConcat => Volatility::Immutable,
BuiltinScalarFunction::ArrayEmpty => Volatility::Immutable,
- BuiltinScalarFunction::ArrayHasAll => Volatility::Immutable,
- BuiltinScalarFunction::ArrayHasAny => Volatility::Immutable,
- BuiltinScalarFunction::ArrayHas => Volatility::Immutable,
BuiltinScalarFunction::ArrayDistinct => Volatility::Immutable,
BuiltinScalarFunction::ArrayElement => Volatility::Immutable,
BuiltinScalarFunction::ArrayExcept => Volatility::Immutable,
@@ -536,10 +527,7 @@ impl BuiltinScalarFunction {
Ok(expr_type)
}
- BuiltinScalarFunction::ArrayHasAll
- | BuiltinScalarFunction::ArrayHasAny
- | BuiltinScalarFunction::ArrayHas
- | BuiltinScalarFunction::ArrayEmpty => Ok(Boolean),
+ BuiltinScalarFunction::ArrayEmpty => Ok(Boolean),
BuiltinScalarFunction::ArrayDistinct =>
Ok(input_expr_types[0].clone()),
BuiltinScalarFunction::ArrayElement => match &input_expr_types[0] {
List(field)
@@ -849,12 +837,6 @@ impl BuiltinScalarFunction {
}
BuiltinScalarFunction::ArrayExcept => Signature::any(2,
self.volatility()),
BuiltinScalarFunction::Flatten =>
Signature::array(self.volatility()),
- BuiltinScalarFunction::ArrayHasAll |
BuiltinScalarFunction::ArrayHasAny => {
- Signature::any(2, self.volatility())
- }
- BuiltinScalarFunction::ArrayHas => {
- Signature::array_and_element(self.volatility())
- }
BuiltinScalarFunction::ArrayLength => {
Signature::variadic_any(self.volatility())
}
@@ -1423,11 +1405,6 @@ impl BuiltinScalarFunction {
],
BuiltinScalarFunction::ArrayExcept => &["array_except",
"list_except"],
BuiltinScalarFunction::Flatten => &["flatten"],
- BuiltinScalarFunction::ArrayHasAll => &["array_has_all",
"list_has_all"],
- BuiltinScalarFunction::ArrayHasAny => &["array_has_any",
"list_has_any"],
- BuiltinScalarFunction::ArrayHas => {
- &["array_has", "list_has", "array_contains", "list_contains"]
- }
BuiltinScalarFunction::ArrayLength => &["array_length",
"list_length"],
BuiltinScalarFunction::ArrayPopFront => {
&["array_pop_front", "list_pop_front"]
diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs
index 99f44a73c1..ad69208ce9 100644
--- a/datafusion/expr/src/expr_fn.rs
+++ b/datafusion/expr/src/expr_fn.rs
@@ -611,30 +611,12 @@ scalar_expr!(
);
nary_scalar_expr!(ArrayConcat, array_concat, "concatenates arrays.");
-scalar_expr!(
- ArrayHas,
- array_has,
- first_array second_array,
- "returns true, if the element appears in the first array, otherwise false."
-);
scalar_expr!(
ArrayEmpty,
array_empty,
array,
"returns true for an empty array or false for a non-empty array."
);
-scalar_expr!(
- ArrayHasAll,
- array_has_all,
- first_array second_array,
- "returns true if each element of the second array appears in the first
array; otherwise, it returns false."
-);
-scalar_expr!(
- ArrayHasAny,
- array_has_any,
- first_array second_array,
- "returns true if at least one element of the second array appears in the
first array; otherwise, it returns false."
-);
scalar_expr!(
Flatten,
flatten,
diff --git a/datafusion/functions-array/Cargo.toml
b/datafusion/functions-array/Cargo.toml
index 088babdf50..17be817238 100644
--- a/datafusion/functions-array/Cargo.toml
+++ b/datafusion/functions-array/Cargo.toml
@@ -41,5 +41,6 @@ arrow = { workspace = true }
datafusion-common = { workspace = true }
datafusion-execution = { workspace = true }
datafusion-expr = { workspace = true }
+itertools = { version = "0.12", features = ["use_std"] }
log = { workspace = true }
paste = "1.0.14"
diff --git a/datafusion/functions-array/src/array_has.rs
b/datafusion/functions-array/src/array_has.rs
new file mode 100644
index 0000000000..17c0ad1619
--- /dev/null
+++ b/datafusion/functions-array/src/array_has.rs
@@ -0,0 +1,306 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! [`ScalarUDFImpl`] definitions for array functions.
+
+use arrow::array::{Array, ArrayRef, BooleanArray, OffsetSizeTrait};
+use arrow::datatypes::DataType;
+use arrow::row::{RowConverter, SortField};
+use datafusion_common::cast::as_generic_list_array;
+use datafusion_common::{exec_err, Result};
+use datafusion_expr::expr::ScalarFunction;
+use datafusion_expr::Expr;
+use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
+
+use itertools::Itertools;
+
+use crate::utils::check_datatypes;
+
+use std::any::Any;
+use std::sync::Arc;
+
+// Create static instances of ScalarUDFs for each function. Each invocation also
+// generates the fluent `Expr` builder (e.g. `array_has(array, element)`) that is
+// re-exported from `crate::expr_fn`.
+make_udf_function!(ArrayHas,
+    array_has,
+    array element, // arg name
+    "returns true, if the element appears in the first array, otherwise false.", // doc
+    array_has_udf // internal function name
+);
+make_udf_function!(ArrayHasAll,
+    array_has_all,
+    first_array second_array, // arg name
+    "returns true if each element of the second array appears in the first array; otherwise, it returns false.", // doc
+    array_has_all_udf // internal function name
+);
+make_udf_function!(ArrayHasAny,
+    array_has_any,
+    first_array second_array, // arg name
+    "returns true if at least one element of the second array appears in the first array; otherwise, it returns false.", // doc
+    array_has_any_udf // internal function name
+);
+
+/// `array_has(array, element)`: per row, whether `element` occurs in the
+/// `List`/`LargeList` value of `array` (see `general_array_has_dispatch`).
+#[derive(Debug)]
+pub(super) struct ArrayHas {
+    signature: Signature,
+    aliases: Vec<String>,
+}
+
+impl ArrayHas {
+    pub fn new() -> Self {
+        let aliases = ["array_has", "list_has", "array_contains", "list_contains"]
+            .into_iter()
+            .map(String::from)
+            .collect();
+        Self {
+            signature: Signature::array_and_element(Volatility::Immutable),
+            aliases,
+        }
+    }
+}
+
+impl ScalarUDFImpl for ArrayHas {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "array_has"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
+        Ok(DataType::Boolean)
+    }
+
+    fn invoke(&self, args: &[ColumnarValue]) -> datafusion_common::Result<ColumnarValue> {
+        let arrays = ColumnarValue::values_to_arrays(args)?;
+        let (haystack, needle) = match arrays.as_slice() {
+            [haystack, needle] => (haystack, needle),
+            _ => return exec_err!("array_has needs two arguments"),
+        };
+
+        // Pick the kernel instantiation matching the list offset width.
+        let array_type = haystack.data_type();
+        let kernel = match array_type {
+            DataType::List(_) => general_array_has_dispatch::<i32>,
+            DataType::LargeList(_) => general_array_has_dispatch::<i64>,
+            _ => return exec_err!("array_has does not support type '{array_type:?}'."),
+        };
+        kernel(haystack, needle, ComparisonType::Single).map(ColumnarValue::Array)
+    }
+
+    fn aliases(&self) -> &[String] {
+        &self.aliases
+    }
+}
+
+/// `array_has_all(first_array, second_array)`: per row, whether every element of
+/// the second list occurs in the first (see `general_array_has_dispatch`).
+#[derive(Debug)]
+pub(super) struct ArrayHasAll {
+    signature: Signature,
+    aliases: Vec<String>,
+}
+
+impl ArrayHasAll {
+    pub fn new() -> Self {
+        let aliases = vec!["array_has_all".to_string(), "list_has_all".to_string()];
+        Self {
+            signature: Signature::any(2, Volatility::Immutable),
+            aliases,
+        }
+    }
+}
+
+impl ScalarUDFImpl for ArrayHasAll {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "array_has_all"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
+        Ok(DataType::Boolean)
+    }
+
+    fn invoke(&self, args: &[ColumnarValue]) -> datafusion_common::Result<ColumnarValue> {
+        let arrays = ColumnarValue::values_to_arrays(args)?;
+        let (haystack, wanted) = match arrays.as_slice() {
+            [haystack, wanted] => (haystack, wanted),
+            _ => return exec_err!("array_has_all needs two arguments"),
+        };
+
+        // Pick the kernel instantiation matching the list offset width.
+        let array_type = haystack.data_type();
+        let kernel = match array_type {
+            DataType::List(_) => general_array_has_dispatch::<i32>,
+            DataType::LargeList(_) => general_array_has_dispatch::<i64>,
+            _ => return exec_err!("array_has_all does not support type '{array_type:?}'."),
+        };
+        kernel(haystack, wanted, ComparisonType::All).map(ColumnarValue::Array)
+    }
+
+    fn aliases(&self) -> &[String] {
+        &self.aliases
+    }
+}
+
+/// `array_has_any(first_array, second_array)`: per row, whether at least one
+/// element of the second list occurs in the first (see `general_array_has_dispatch`).
+#[derive(Debug)]
+pub(super) struct ArrayHasAny {
+    signature: Signature,
+    aliases: Vec<String>,
+}
+
+impl ArrayHasAny {
+    pub fn new() -> Self {
+        let aliases = vec!["array_has_any".to_string(), "list_has_any".to_string()];
+        Self {
+            signature: Signature::any(2, Volatility::Immutable),
+            aliases,
+        }
+    }
+}
+
+impl ScalarUDFImpl for ArrayHasAny {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "array_has_any"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
+        Ok(DataType::Boolean)
+    }
+
+    fn invoke(&self, args: &[ColumnarValue]) -> datafusion_common::Result<ColumnarValue> {
+        let arrays = ColumnarValue::values_to_arrays(args)?;
+        let (haystack, candidates) = match arrays.as_slice() {
+            [haystack, candidates] => (haystack, candidates),
+            _ => return exec_err!("array_has_any needs two arguments"),
+        };
+
+        // Pick the kernel instantiation matching the list offset width.
+        let array_type = haystack.data_type();
+        let kernel = match array_type {
+            DataType::List(_) => general_array_has_dispatch::<i32>,
+            DataType::LargeList(_) => general_array_has_dispatch::<i64>,
+            _ => return exec_err!("array_has_any does not support type '{array_type:?}'."),
+        };
+        kernel(haystack, candidates, ComparisonType::Any).map(ColumnarValue::Array)
+    }
+
+    fn aliases(&self) -> &[String] {
+        &self.aliases
+    }
+}
+
+/// Represents the type of comparison performed by `general_array_has_dispatch`.
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+enum ComparisonType {
+    /// `array_has_all`: every element of the second list occurs in the first.
+    All,
+    /// `array_has_any`: at least one element of the second list occurs in the first.
+    Any,
+    /// `array_has`: a single element occurs in the list.
+    Single,
+}
+
+/// Shared kernel for `array_has`, `array_has_all` and `array_has_any`.
+///
+/// * `array` must be a `List`/`LargeList` column with offset type `O`.
+/// * For [`ComparisonType::Single`], `sub_array` is the element column; its type
+///   must match the list's value type (or be `Null`), as checked by `check_datatypes`.
+/// * For [`ComparisonType::All`] / [`ComparisonType::Any`], `sub_array` is a list
+///   column whose type must match `array`.
+///
+/// Elements are compared through their [`RowConverter`] encoding.
+///
+/// Returns a `BooleanArray` with exactly one entry per row of `array`. A row whose
+/// list (or sub-list) is null yields a null result.
+fn general_array_has_dispatch<O: OffsetSizeTrait>(
+    array: &ArrayRef,
+    sub_array: &ArrayRef,
+    comparison_type: ComparisonType,
+) -> Result<ArrayRef> {
+    let array = if comparison_type == ComparisonType::Single {
+        let arr = as_generic_list_array::<O>(array)?;
+        check_datatypes("array_has", &[arr.values(), sub_array])?;
+        arr
+    } else {
+        check_datatypes("array_has", &[array, sub_array])?;
+        as_generic_list_array::<O>(array)?
+    };
+
+    let mut boolean_builder = BooleanArray::builder(array.len());
+
+    let converter = RowConverter::new(vec![SortField::new(array.value_type())])?;
+
+    let element = sub_array.clone();
+    // For `Single` the "sub array" is the list itself, so the zip below yields
+    // the same null pattern on both sides.
+    let sub_array = if comparison_type != ComparisonType::Single {
+        as_generic_list_array::<O>(sub_array)?
+    } else {
+        array
+    };
+
+    // Row-encoded element column for `Single`. It is encoded at most once, on the
+    // first non-null row, rather than once per row (which made the loop O(n^2)).
+    let mut element_rows: Option<arrow::row::Rows> = None;
+
+    for (row_idx, (arr, sub_arr)) in array.iter().zip(sub_array.iter()).enumerate() {
+        match (arr, sub_arr) {
+            (Some(arr), Some(sub_arr)) => {
+                let arr_values = converter.convert_columns(&[arr])?;
+
+                let res = match comparison_type {
+                    ComparisonType::All => {
+                        let sub_arr_values = converter.convert_columns(&[sub_arr])?;
+                        sub_arr_values
+                            .iter()
+                            .dedup()
+                            .all(|elem| arr_values.iter().dedup().any(|x| x == elem))
+                    }
+                    ComparisonType::Any => {
+                        let sub_arr_values = converter.convert_columns(&[sub_arr])?;
+                        sub_arr_values
+                            .iter()
+                            .dedup()
+                            .any(|elem| arr_values.iter().dedup().any(|x| x == elem))
+                    }
+                    ComparisonType::Single => {
+                        if element_rows.is_none() {
+                            element_rows =
+                                Some(converter.convert_columns(&[element.clone()])?);
+                        }
+                        let needle = element_rows.as_ref().map(|rows| rows.row(row_idx));
+                        arr_values.iter().dedup().any(|x| Some(x) == needle)
+                    }
+                };
+
+                boolean_builder.append_value(res);
+            }
+            // A null list (or null sub-list) produces a null result. Appending here
+            // keeps the output the same length as, and aligned with, the input rows.
+            _ => boolean_builder.append_null(),
+        }
+    }
+    Ok(Arc::new(boolean_builder.finish()))
+}
diff --git a/datafusion/functions-array/src/lib.rs
b/datafusion/functions-array/src/lib.rs
index e4cdf69aa9..710f49761f 100644
--- a/datafusion/functions-array/src/lib.rs
+++ b/datafusion/functions-array/src/lib.rs
@@ -28,8 +28,10 @@
#[macro_use]
pub mod macros;
+mod array_has;
mod kernels;
mod udf;
+mod utils;
use datafusion_common::Result;
use datafusion_execution::FunctionRegistry;
@@ -39,6 +41,9 @@ use std::sync::Arc;
/// Fluent-style API for creating `Expr`s
pub mod expr_fn {
+ pub use super::array_has::array_has;
+ pub use super::array_has::array_has_all;
+ pub use super::array_has::array_has_any;
pub use super::udf::array_dims;
pub use super::udf::array_ndims;
pub use super::udf::array_to_string;
@@ -56,6 +61,9 @@ pub fn register_all(registry: &mut dyn FunctionRegistry) ->
Result<()> {
udf::array_dims_udf(),
udf::cardinality_udf(),
udf::array_ndims_udf(),
+ array_has::array_has_udf(),
+ array_has::array_has_all_udf(),
+ array_has::array_has_any_udf(),
];
functions.into_iter().try_for_each(|udf| {
let existing_udf = registry.register_udf(udf)?;
diff --git a/datafusion/functions-array/src/utils.rs
b/datafusion/functions-array/src/utils.rs
new file mode 100644
index 0000000000..d374a9f66b
--- /dev/null
+++ b/datafusion/functions-array/src/utils.rs
@@ -0,0 +1,34 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! array function utils
+
+use arrow::{array::ArrayRef, datatypes::DataType};
+use datafusion_common::{plan_err, Result};
+
+/// Checks that the arguments of the array function `name` have compatible types.
+///
+/// Every argument must have the same data type as the first one (compared with
+/// `DataType::equals_datatype`); an argument of type `Null` is also accepted.
+/// An empty `args` slice is trivially compatible.
+///
+/// # Errors
+/// Returns a plan error listing every argument's type when they are incompatible.
+pub(crate) fn check_datatypes(name: &str, args: &[&ArrayRef]) -> Result<()> {
+    // Nothing to compare against; indexing `args[0]` here would panic.
+    let data_type = match args.first() {
+        Some(first) => first.data_type(),
+        None => return Ok(()),
+    };
+
+    let compatible = args.iter().all(|arg| {
+        arg.data_type().equals_datatype(data_type)
+            || arg.data_type().equals_datatype(&DataType::Null)
+    });
+    if !compatible {
+        let types = args.iter().map(|arg| arg.data_type()).collect::<Vec<_>>();
+        return plan_err!("{name} received incompatible types: '{types:?}'.");
+    }
+
+    Ok(())
+}
diff --git a/datafusion/optimizer/Cargo.toml b/datafusion/optimizer/Cargo.toml
index 861715b351..f497f2ec86 100644
--- a/datafusion/optimizer/Cargo.toml
+++ b/datafusion/optimizer/Cargo.toml
@@ -33,8 +33,9 @@ name = "datafusion_optimizer"
path = "src/lib.rs"
[features]
+array_expressions = ["datafusion-functions-array"]
crypto_expressions = ["datafusion-physical-expr/crypto_expressions"]
-default = ["unicode_expressions", "crypto_expressions", "regex_expressions"]
+default = ["unicode_expressions", "crypto_expressions", "regex_expressions",
"array_expressions"]
regex_expressions = ["datafusion-physical-expr/regex_expressions"]
unicode_expressions = ["datafusion-physical-expr/unicode_expressions"]
@@ -44,6 +45,7 @@ async-trait = { workspace = true }
chrono = { workspace = true }
datafusion-common = { workspace = true, default-features = true }
datafusion-expr = { workspace = true }
+datafusion-functions-array = { workspace = true, optional = true }
datafusion-physical-expr = { workspace = true }
hashbrown = { version = "0.14", features = ["raw"] }
itertools = { workspace = true }
diff --git a/datafusion/optimizer/src/analyzer/mod.rs
b/datafusion/optimizer/src/analyzer/mod.rs
index 08caa4be60..ad852b460f 100644
--- a/datafusion/optimizer/src/analyzer/mod.rs
+++ b/datafusion/optimizer/src/analyzer/mod.rs
@@ -38,6 +38,9 @@ use datafusion_expr::{Expr, LogicalPlan};
use log::debug;
use std::sync::Arc;
+#[cfg(feature = "array_expressions")]
+use datafusion_functions_array::expr_fn::array_has_all;
+
use self::rewrite_expr::OperatorToFunction;
/// [`AnalyzerRule`]s transform [`LogicalPlan`]s in some way to make
diff --git a/datafusion/optimizer/src/analyzer/rewrite_expr.rs
b/datafusion/optimizer/src/analyzer/rewrite_expr.rs
index 41ebcd8e50..3ea5596b67 100644
--- a/datafusion/optimizer/src/analyzer/rewrite_expr.rs
+++ b/datafusion/optimizer/src/analyzer/rewrite_expr.rs
@@ -118,11 +118,62 @@ impl TreeNodeRewriter for OperatorToFunctionRewriter {
args: vec![left, right],
})));
}
+
+ // TODO: change OperatorToFunction to OperatoToArrayFunction and
configure it with array_expressions feature
+ // after other array functions are udf-based
+ #[cfg(feature = "array_expressions")]
+ if let Some(expr) = rewrite_array_has_all_operator_to_func(left,
op, right) {
+ return Ok(Transformed::yes(expr));
+ }
}
Ok(Transformed::no(expr))
}
}
+// NOTE: this rewrite is only performed when DataFusion's built-in
+// `array_expressions` feature is enabled. Array functions supplied by users are
+// not equivalent to DataFusion's UDF-based `array_has_all`, so rewriting the
+// operators to them would not be correct.
+#[cfg(feature = "array_expressions")]
+fn rewrite_array_has_all_operator_to_func(
+    left: &Expr,
+    op: Operator,
+    right: &Expr,
+) -> Option<Expr> {
+    use super::array_has_all;
+
+    // Only the case where both operands are `make_array(...)` calls is rewritten.
+    let is_make_array = |expr: &Expr| {
+        matches!(
+            expr,
+            Expr::ScalarFunction(ScalarFunction {
+                func_def: ScalarFunctionDefinition::BuiltIn(
+                    BuiltinScalarFunction::MakeArray
+                ),
+                ..
+            })
+        )
+    };
+
+    if !is_make_array(left) || !is_make_array(right) {
+        return None;
+    }
+
+    match op {
+        // array1 @> array2 -> array_has_all(array1, array2)
+        Operator::AtArrow => Some(array_has_all(left.clone(), right.clone())),
+        // array1 <@ array2 -> array_has_all(array2, array1)
+        Operator::ArrowAt => Some(array_has_all(right.clone(), left.clone())),
+        _ => None,
+    }
+}
+
/// Summary of the logic below:
///
/// 1) array || array -> array concat
diff --git a/datafusion/physical-expr/src/array_expressions.rs
b/datafusion/physical-expr/src/array_expressions.rs
index c10f5df540..8d2a283a05 100644
--- a/datafusion/physical-expr/src/array_expressions.rs
+++ b/datafusion/physical-expr/src/array_expressions.rs
@@ -2018,133 +2018,6 @@ pub fn array_length(args: &[ArrayRef]) ->
Result<ArrayRef> {
}
}
-/// Represents the type of comparison for array_has.
-#[derive(Debug, PartialEq)]
-enum ComparisonType {
- // array_has_all
- All,
- // array_has_any
- Any,
- // array_has
- Single,
-}
-
-fn general_array_has_dispatch<O: OffsetSizeTrait>(
- array: &ArrayRef,
- sub_array: &ArrayRef,
- comparison_type: ComparisonType,
-) -> Result<ArrayRef> {
- let array = if comparison_type == ComparisonType::Single {
- let arr = as_generic_list_array::<O>(array)?;
- check_datatypes("array_has", &[arr.values(), sub_array])?;
- arr
- } else {
- check_datatypes("array_has", &[array, sub_array])?;
- as_generic_list_array::<O>(array)?
- };
-
- let mut boolean_builder = BooleanArray::builder(array.len());
-
- let converter =
RowConverter::new(vec![SortField::new(array.value_type())])?;
-
- let element = sub_array.clone();
- let sub_array = if comparison_type != ComparisonType::Single {
- as_generic_list_array::<O>(sub_array)?
- } else {
- array
- };
-
- for (row_idx, (arr, sub_arr)) in
array.iter().zip(sub_array.iter()).enumerate() {
- if let (Some(arr), Some(sub_arr)) = (arr, sub_arr) {
- let arr_values = converter.convert_columns(&[arr])?;
- let sub_arr_values = if comparison_type != ComparisonType::Single {
- converter.convert_columns(&[sub_arr])?
- } else {
- converter.convert_columns(&[element.clone()])?
- };
-
- let mut res = match comparison_type {
- ComparisonType::All => sub_arr_values
- .iter()
- .dedup()
- .all(|elem| arr_values.iter().dedup().any(|x| x == elem)),
- ComparisonType::Any => sub_arr_values
- .iter()
- .dedup()
- .any(|elem| arr_values.iter().dedup().any(|x| x == elem)),
- ComparisonType::Single => arr_values
- .iter()
- .dedup()
- .any(|x| x == sub_arr_values.row(row_idx)),
- };
-
- if comparison_type == ComparisonType::Any {
- res |= res;
- }
-
- boolean_builder.append_value(res);
- }
- }
- Ok(Arc::new(boolean_builder.finish()))
-}
-
-/// Array_has SQL function
-pub fn array_has(args: &[ArrayRef]) -> Result<ArrayRef> {
- if args.len() != 2 {
- return exec_err!("array_has needs two arguments");
- }
-
- let array_type = args[0].data_type();
-
- match array_type {
- DataType::List(_) => {
- general_array_has_dispatch::<i32>(&args[0], &args[1],
ComparisonType::Single)
- }
- DataType::LargeList(_) => {
- general_array_has_dispatch::<i64>(&args[0], &args[1],
ComparisonType::Single)
- }
- _ => exec_err!("array_has does not support type '{array_type:?}'."),
- }
-}
-
-/// Array_has_any SQL function
-pub fn array_has_any(args: &[ArrayRef]) -> Result<ArrayRef> {
- if args.len() != 2 {
- return exec_err!("array_has_any needs two arguments");
- }
-
- let array_type = args[0].data_type();
-
- match array_type {
- DataType::List(_) => {
- general_array_has_dispatch::<i32>(&args[0], &args[1],
ComparisonType::Any)
- }
- DataType::LargeList(_) => {
- general_array_has_dispatch::<i64>(&args[0], &args[1],
ComparisonType::Any)
- }
- _ => exec_err!("array_has_any does not support type
'{array_type:?}'."),
- }
-}
-
-/// Array_has_all SQL function
-pub fn array_has_all(args: &[ArrayRef]) -> Result<ArrayRef> {
- if args.len() != 2 {
- return exec_err!("array_has_all needs two arguments");
- }
-
- let array_type = args[0].data_type();
-
- match array_type {
- DataType::List(_) => {
- general_array_has_dispatch::<i32>(&args[0], &args[1],
ComparisonType::All)
- }
- DataType::LargeList(_) => {
- general_array_has_dispatch::<i64>(&args[0], &args[1],
ComparisonType::All)
- }
- _ => exec_err!("array_has_all does not support type
'{array_type:?}'."),
- }
-}
-
/// Splits string at occurrences of delimiter and returns an array of parts
/// string_to_array('abc~@~def~@~ghi', '~@~') = '["abc", "def", "ghi"]'
pub fn string_to_array<T: OffsetSizeTrait>(args: &[ArrayRef]) ->
Result<ArrayRef> {
diff --git a/datafusion/physical-expr/src/expressions/binary.rs
b/datafusion/physical-expr/src/expressions/binary.rs
index f1842458d5..bc107e169d 100644
--- a/datafusion/physical-expr/src/expressions/binary.rs
+++ b/datafusion/physical-expr/src/expressions/binary.rs
@@ -20,7 +20,6 @@ mod kernels;
use std::hash::{Hash, Hasher};
use std::{any::Any, sync::Arc};
-use crate::array_expressions::array_has_all;
use crate::expressions::datum::{apply, apply_cmp};
use crate::intervals::cp_solver::{propagate_arithmetic, propagate_comparison};
use crate::physical_expr::down_cast_any_ref;
@@ -602,8 +601,9 @@ impl BinaryExpr {
BitwiseShiftRight => bitwise_shift_right_dyn(left, right),
BitwiseShiftLeft => bitwise_shift_left_dyn(left, right),
StringConcat => binary_string_array_op!(left, right,
concat_elements),
- AtArrow => array_has_all(&[left, right]),
- ArrowAt => array_has_all(&[right, left]),
+ AtArrow | ArrowAt => {
+ unreachable!("ArrowAt and AtArrow should be rewritten to
function")
+ }
}
}
}
diff --git a/datafusion/physical-expr/src/functions.rs
b/datafusion/physical-expr/src/functions.rs
index db53ac986d..eebbb1dbea 100644
--- a/datafusion/physical-expr/src/functions.rs
+++ b/datafusion/physical-expr/src/functions.rs
@@ -314,15 +314,6 @@ pub fn create_physical_fun(
BuiltinScalarFunction::ArrayEmpty => Arc::new(|args| {
make_scalar_function_inner(array_expressions::array_empty)(args)
}),
- BuiltinScalarFunction::ArrayHasAll => Arc::new(|args| {
- make_scalar_function_inner(array_expressions::array_has_all)(args)
- }),
- BuiltinScalarFunction::ArrayHasAny => Arc::new(|args| {
- make_scalar_function_inner(array_expressions::array_has_any)(args)
- }),
- BuiltinScalarFunction::ArrayHas => Arc::new(|args| {
- make_scalar_function_inner(array_expressions::array_has)(args)
- }),
BuiltinScalarFunction::ArrayDistinct => Arc::new(|args| {
make_scalar_function_inner(array_expressions::array_distinct)(args)
}),
diff --git a/datafusion/proto/proto/datafusion.proto
b/datafusion/proto/proto/datafusion.proto
index faffe57c07..c5b20986c3 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -651,9 +651,9 @@ enum ScalarFunction {
ArrayElement = 99;
ArraySlice = 100;
Cot = 103;
- ArrayHas = 104;
- ArrayHasAny = 105;
- ArrayHasAll = 106;
+ // 104 was ArrayHas
+ // 105 was ArrayHasAny
+ // 106 was ArrayHasAll
ArrayRemoveN = 107;
ArrayReplaceN = 108;
ArrayRemoveAll = 109;
diff --git a/datafusion/proto/src/generated/pbjson.rs
b/datafusion/proto/src/generated/pbjson.rs
index 3415574c15..b99e957406 100644
--- a/datafusion/proto/src/generated/pbjson.rs
+++ b/datafusion/proto/src/generated/pbjson.rs
@@ -22410,9 +22410,6 @@ impl serde::Serialize for ScalarFunction {
Self::ArrayElement => "ArrayElement",
Self::ArraySlice => "ArraySlice",
Self::Cot => "Cot",
- Self::ArrayHas => "ArrayHas",
- Self::ArrayHasAny => "ArrayHasAny",
- Self::ArrayHasAll => "ArrayHasAll",
Self::ArrayRemoveN => "ArrayRemoveN",
Self::ArrayReplaceN => "ArrayReplaceN",
Self::ArrayRemoveAll => "ArrayRemoveAll",
@@ -22538,9 +22535,6 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction {
"ArrayElement",
"ArraySlice",
"Cot",
- "ArrayHas",
- "ArrayHasAny",
- "ArrayHasAll",
"ArrayRemoveN",
"ArrayReplaceN",
"ArrayRemoveAll",
@@ -22695,9 +22689,6 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction {
"ArrayElement" => Ok(ScalarFunction::ArrayElement),
"ArraySlice" => Ok(ScalarFunction::ArraySlice),
"Cot" => Ok(ScalarFunction::Cot),
- "ArrayHas" => Ok(ScalarFunction::ArrayHas),
- "ArrayHasAny" => Ok(ScalarFunction::ArrayHasAny),
- "ArrayHasAll" => Ok(ScalarFunction::ArrayHasAll),
"ArrayRemoveN" => Ok(ScalarFunction::ArrayRemoveN),
"ArrayReplaceN" => Ok(ScalarFunction::ArrayReplaceN),
"ArrayRemoveAll" => Ok(ScalarFunction::ArrayRemoveAll),
diff --git a/datafusion/proto/src/generated/prost.rs
b/datafusion/proto/src/generated/prost.rs
index 4c4b17f1a8..62b3d39580 100644
--- a/datafusion/proto/src/generated/prost.rs
+++ b/datafusion/proto/src/generated/prost.rs
@@ -2739,9 +2739,9 @@ pub enum ScalarFunction {
ArrayElement = 99,
ArraySlice = 100,
Cot = 103,
- ArrayHas = 104,
- ArrayHasAny = 105,
- ArrayHasAll = 106,
+ /// 104 was ArrayHas
+ /// 105 was ArrayHasAny
+ /// 106 was ArrayHasAll
ArrayRemoveN = 107,
ArrayReplaceN = 108,
ArrayRemoveAll = 109,
@@ -2872,9 +2872,6 @@ impl ScalarFunction {
ScalarFunction::ArrayElement => "ArrayElement",
ScalarFunction::ArraySlice => "ArraySlice",
ScalarFunction::Cot => "Cot",
- ScalarFunction::ArrayHas => "ArrayHas",
- ScalarFunction::ArrayHasAny => "ArrayHasAny",
- ScalarFunction::ArrayHasAll => "ArrayHasAll",
ScalarFunction::ArrayRemoveN => "ArrayRemoveN",
ScalarFunction::ArrayReplaceN => "ArrayReplaceN",
ScalarFunction::ArrayRemoveAll => "ArrayRemoveAll",
@@ -2994,9 +2991,6 @@ impl ScalarFunction {
"ArrayElement" => Some(Self::ArrayElement),
"ArraySlice" => Some(Self::ArraySlice),
"Cot" => Some(Self::Cot),
- "ArrayHas" => Some(Self::ArrayHas),
- "ArrayHasAny" => Some(Self::ArrayHasAny),
- "ArrayHasAll" => Some(Self::ArrayHasAll),
"ArrayRemoveN" => Some(Self::ArrayRemoveN),
"ArrayReplaceN" => Some(Self::ArrayReplaceN),
"ArrayRemoveAll" => Some(Self::ArrayRemoveAll),
diff --git a/datafusion/proto/src/logical_plan/from_proto.rs
b/datafusion/proto/src/logical_plan/from_proto.rs
index 6476afca43..c26b8acbf1 100644
--- a/datafusion/proto/src/logical_plan/from_proto.rs
+++ b/datafusion/proto/src/logical_plan/from_proto.rs
@@ -48,14 +48,13 @@ use datafusion_expr::expr::Unnest;
use datafusion_expr::window_frame::{check_window_frame,
regularize_window_order_by};
use datafusion_expr::{
acosh, array, array_append, array_concat, array_distinct, array_element,
array_empty,
- array_except, array_has, array_has_all, array_has_any, array_intersect,
array_length,
- array_pop_back, array_pop_front, array_position, array_positions,
array_prepend,
- array_remove, array_remove_all, array_remove_n, array_repeat,
array_replace,
- array_replace_all, array_replace_n, array_resize, array_slice, array_sort,
- array_union, arrow_typeof, ascii, asinh, atan, atan2, atanh, bit_length,
btrim, cbrt,
- ceil, character_length, chr, coalesce, concat_expr, concat_ws_expr, cos,
cosh, cot,
- current_date, current_time, date_bin, date_part, date_trunc, degrees,
digest,
- ends_with, exp,
+ array_except, array_intersect, array_length, array_pop_back,
array_pop_front,
+ array_position, array_positions, array_prepend, array_remove,
array_remove_all,
+ array_remove_n, array_repeat, array_replace, array_replace_all,
array_replace_n,
+ array_resize, array_slice, array_sort, array_union, arrow_typeof, ascii,
asinh, atan,
+ atan2, atanh, bit_length, btrim, cbrt, ceil, character_length, chr,
coalesce,
+ concat_expr, concat_ws_expr, cos, cosh, cot, current_date, current_time,
date_bin,
+ date_part, date_trunc, degrees, digest, ends_with, exp,
expr::{self, InList, Sort, WindowFunction},
factorial, find_in_set, flatten, floor, from_unixtime, gcd, initcap,
iszero, lcm,
left, levenshtein, ln, log, log10, log2,
@@ -483,9 +482,6 @@ impl From<&protobuf::ScalarFunction> for
BuiltinScalarFunction {
ScalarFunction::ArrayConcat => Self::ArrayConcat,
ScalarFunction::ArrayEmpty => Self::ArrayEmpty,
ScalarFunction::ArrayExcept => Self::ArrayExcept,
- ScalarFunction::ArrayHasAll => Self::ArrayHasAll,
- ScalarFunction::ArrayHasAny => Self::ArrayHasAny,
- ScalarFunction::ArrayHas => Self::ArrayHas,
ScalarFunction::ArrayDistinct => Self::ArrayDistinct,
ScalarFunction::ArrayElement => Self::ArrayElement,
ScalarFunction::Flatten => Self::Flatten,
@@ -1454,18 +1450,6 @@ pub fn parse_expr(
parse_expr(&args[0], registry, codec)?,
parse_expr(&args[1], registry, codec)?,
)),
- ScalarFunction::ArrayHasAll => Ok(array_has_all(
- parse_expr(&args[0], registry, codec)?,
- parse_expr(&args[1], registry, codec)?,
- )),
- ScalarFunction::ArrayHasAny => Ok(array_has_any(
- parse_expr(&args[0], registry, codec)?,
- parse_expr(&args[1], registry, codec)?,
- )),
- ScalarFunction::ArrayHas => Ok(array_has(
- parse_expr(&args[0], registry, codec)?,
- parse_expr(&args[1], registry, codec)?,
- )),
ScalarFunction::ArrayIntersect => Ok(array_intersect(
parse_expr(&args[0], registry, codec)?,
parse_expr(&args[1], registry, codec)?,
diff --git a/datafusion/proto/src/logical_plan/to_proto.rs
b/datafusion/proto/src/logical_plan/to_proto.rs
index 0ee43ffd27..55c8542d97 100644
--- a/datafusion/proto/src/logical_plan/to_proto.rs
+++ b/datafusion/proto/src/logical_plan/to_proto.rs
@@ -1461,9 +1461,6 @@ impl TryFrom<&BuiltinScalarFunction> for
protobuf::ScalarFunction {
BuiltinScalarFunction::ArrayConcat => Self::ArrayConcat,
BuiltinScalarFunction::ArrayEmpty => Self::ArrayEmpty,
BuiltinScalarFunction::ArrayExcept => Self::ArrayExcept,
- BuiltinScalarFunction::ArrayHasAll => Self::ArrayHasAll,
- BuiltinScalarFunction::ArrayHasAny => Self::ArrayHasAny,
- BuiltinScalarFunction::ArrayHas => Self::ArrayHas,
BuiltinScalarFunction::ArrayDistinct => Self::ArrayDistinct,
BuiltinScalarFunction::ArrayElement => Self::ArrayElement,
BuiltinScalarFunction::Flatten => Self::Flatten,