This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new a25b9772d1 Support array concatenation for arrays with different dimensions (#6872)
a25b9772d1 is described below
commit a25b9772d1ff1d57b657c71e545f795523b28e17
Author: Jay Zhan <[email protected]>
AuthorDate: Sun Jul 9 21:48:10 2023 +0800
Support array concatenation for arrays with different dimensions (#6872)
* add diff dims support for array concat
Signed-off-by: jayzhan211 <[email protected]>
* address comment
Signed-off-by: jayzhan211 <[email protected]>
---------
Signed-off-by: jayzhan211 <[email protected]>
---
.../core/tests/sqllogictests/test_files/array.slt | 24 ++++++
datafusion/physical-expr/src/array_expressions.rs | 88 ++++++++++++++++------
2 files changed, 87 insertions(+), 25 deletions(-)
diff --git a/datafusion/core/tests/sqllogictests/test_files/array.slt b/datafusion/core/tests/sqllogictests/test_files/array.slt
index 7eebb23d9c..2d7a609989 100644
--- a/datafusion/core/tests/sqllogictests/test_files/array.slt
+++ b/datafusion/core/tests/sqllogictests/test_files/array.slt
@@ -365,6 +365,30 @@ select array_concat(make_array(), make_array(2, 3));
----
[2, 3]
+# array_concat with different dimensions #1 (2D + 1D)
+query ?
+select array_concat(make_array([1,2], [3,4]), make_array(5, 6))
+----
+[[1, 2], [3, 4], [5, 6]]
+
+# array_concat with different dimensions #2 (1D + 2D)
+query ?
+select array_concat(make_array(5, 6), make_array([1,2], [3,4]))
+----
+[[5, 6], [1, 2], [3, 4]]
+
+# array_concat with different dimensions #3 (2D + 1D + 1D)
+query ?
+select array_concat(make_array([1,2], [3,4]), make_array(5, 6), make_array(7,8))
+----
+[[1, 2], [3, 4], [5, 6], [7, 8]]
+
+# array_concat with different dimensions #4 (1D + 2D + 3D)
+query ?
+select array_concat(make_array(10, 20), make_array([30, 40]), make_array([[50, 60]]))
+----
+[[[10, 20]], [[30, 40]], [[50, 60]]]
+
## array_position
# array_position scalar function #1
diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs
index cd174918db..cbf5896b85 100644
--- a/datafusion/physical-expr/src/array_expressions.rs
+++ b/datafusion/physical-expr/src/array_expressions.rs
@@ -373,20 +373,83 @@ pub fn array_prepend(args: &[ArrayRef]) -> Result<ArrayRef> {
Ok(res)
}
+fn compute_array_ndims(arg: u8, arr: ArrayRef) -> Result<u8> {
+ match arr.data_type() {
+ DataType::List(..) => {
+ let list_array = downcast_arg!(arr, ListArray);
+ compute_array_ndims(arg + 1, list_array.value(0))
+ }
+ DataType::Null
+ | DataType::Utf8
+ | DataType::LargeUtf8
+ | DataType::Boolean
+ | DataType::Float32
+ | DataType::Float64
+ | DataType::Int8
+ | DataType::Int16
+ | DataType::Int32
+ | DataType::Int64
+ | DataType::UInt8
+ | DataType::UInt16
+ | DataType::UInt32
+ | DataType::UInt64 => Ok(arg),
+ data_type => Err(DataFusionError::NotImplemented(format!(
+ "Array is not implemented for type '{data_type:?}'."
+ ))),
+ }
+}
+
+fn align_array_dimensions(args: Vec<ArrayRef>) -> Result<Vec<ArrayRef>> {
+ // Find the maximum number of dimensions
+ let max_ndim: u8 = *args
+ .iter()
+ .map(|arr| compute_array_ndims(0, arr.clone()))
+ .collect::<Result<Vec<u8>>>()?
+ .iter()
+ .max()
+ .unwrap();
+
+ // Align the dimensions of the arrays
+ let aligned_args: Result<Vec<ArrayRef>> = args
+ .into_iter()
+ .map(|array| {
+ let ndim = compute_array_ndims(0, array.clone())?;
+ if ndim < max_ndim {
+ let mut aligned_array = array.clone();
+ for _ in 0..(max_ndim - ndim) {
+ let data_type = aligned_array.as_ref().data_type().clone();
+ aligned_array = array_array(&[aligned_array], data_type)?;
+ }
+ Ok(aligned_array)
+ } else {
+ Ok(array.clone())
+ }
+ })
+ .collect();
+
+ aligned_args
+}
+
/// Array_concat/Array_cat SQL function
pub fn array_concat(args: &[ArrayRef]) -> Result<ArrayRef> {
match args[0].data_type() {
DataType::List(field) => match field.data_type() {
DataType::Null => array_concat(&args[1..]),
_ => {
+ let args = align_array_dimensions(args.to_vec())?;
+
let list_arrays = downcast_vec!(args, ListArray)
.collect::<Result<Vec<&ListArray>>>()?;
+
let len: usize = list_arrays.iter().map(|a| a.values().len()).sum();
+
let capacity =
Capacities::Array(list_arrays.iter().map(|a| a.len()).sum());
let array_data: Vec<_> =
list_arrays.iter().map(|a| a.to_data()).collect::<Vec<_>>();
+
let array_data = array_data.iter().collect();
+
let mut mutable =
MutableArrayData::with_capacities(array_data, false, capacity);
@@ -1217,31 +1280,6 @@ pub fn array_dims(args: &[ArrayRef]) -> Result<ArrayRef> {
/// Array_ndims SQL function
pub fn array_ndims(args: &[ArrayRef]) -> Result<ArrayRef> {
- fn compute_array_ndims(arg: u8, arr: ArrayRef) -> Result<u8> {
- match arr.data_type() {
- DataType::List(..) => {
- let list_array = downcast_arg!(arr, ListArray);
- compute_array_ndims(arg + 1, list_array.value(0))
- }
- DataType::Null
- | DataType::Utf8
- | DataType::LargeUtf8
- | DataType::Boolean
- | DataType::Float32
- | DataType::Float64
- | DataType::Int8
- | DataType::Int16
- | DataType::Int32
- | DataType::Int64
- | DataType::UInt8
- | DataType::UInt16
- | DataType::UInt32
- | DataType::UInt64 => Ok(arg),
- data_type => Err(DataFusionError::NotImplemented(format!(
- "Array is not implemented for type '{data_type:?}'."
- ))),
- }
- }
let arg: u8 = 0;
Ok(Arc::new(UInt8Array::from(vec![compute_array_ndims(
arg,