This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new a25b9772d1 Support array concatenation for arrays with different dimensions (#6872)
a25b9772d1 is described below

commit a25b9772d1ff1d57b657c71e545f795523b28e17
Author: Jay Zhan <[email protected]>
AuthorDate: Sun Jul 9 21:48:10 2023 +0800

    Support array concatenation for arrays with different dimensions (#6872)
    
    * add diff dims support for array concat
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * address comment
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    ---------
    
    Signed-off-by: jayzhan211 <[email protected]>
---
 .../core/tests/sqllogictests/test_files/array.slt  | 24 ++++++
 datafusion/physical-expr/src/array_expressions.rs  | 88 ++++++++++++++++------
 2 files changed, 87 insertions(+), 25 deletions(-)

diff --git a/datafusion/core/tests/sqllogictests/test_files/array.slt b/datafusion/core/tests/sqllogictests/test_files/array.slt
index 7eebb23d9c..2d7a609989 100644
--- a/datafusion/core/tests/sqllogictests/test_files/array.slt
+++ b/datafusion/core/tests/sqllogictests/test_files/array.slt
@@ -365,6 +365,30 @@ select array_concat(make_array(), make_array(2, 3));
 ----
 [2, 3]
 
+# array_concat with different dimensions #1 (2D + 1D)
+query ?
+select array_concat(make_array([1,2], [3,4]), make_array(5, 6))
+----
+[[1, 2], [3, 4], [5, 6]]
+
+# array_concat with different dimensions #2 (1D + 2D)
+query ?
+select array_concat(make_array(5, 6), make_array([1,2], [3,4]))
+----
+[[5, 6], [1, 2], [3, 4]]
+
+# array_concat with different dimensions #3 (2D + 1D + 1D)
+query ?
+select array_concat(make_array([1,2], [3,4]), make_array(5, 6), make_array(7,8))
+----
+[[1, 2], [3, 4], [5, 6], [7, 8]]
+
+# array_concat with different dimensions #4 (1D + 2D + 3D)
+query ?
+select array_concat(make_array(10, 20), make_array([30, 40]), make_array([[50, 60]]))
+----
+[[[10, 20]], [[30, 40]], [[50, 60]]]
+
 ## array_position
 
 # array_position scalar function #1
diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs
index cd174918db..cbf5896b85 100644
--- a/datafusion/physical-expr/src/array_expressions.rs
+++ b/datafusion/physical-expr/src/array_expressions.rs
@@ -373,20 +373,83 @@ pub fn array_prepend(args: &[ArrayRef]) -> Result<ArrayRef> {
     Ok(res)
 }
 
+/// Returns `arg` plus the list-nesting depth of `arr`'s data type.
+///
+/// Callers in this module pass `arg = 0` to obtain the number of array
+/// dimensions (e.g. `List(Int64)` -> 1, `List(List(Int64))` -> 2).
+///
+/// The depth is read from the nested `List` field types rather than from
+/// the array values, so arrays with zero rows (or whose first nested element
+/// is an empty list) are handled without panicking.
+///
+/// # Errors
+/// Returns `DataFusionError::NotImplemented` if the innermost element type is
+/// not one of the scalar types matched below.
+fn compute_array_ndims(arg: u8, arr: ArrayRef) -> Result<u8> {
+    // Recurse on the type, not on `value(0)`: `ListArray::value(0)` panics
+    // when the list array has no rows.
+    fn ndims_of_type(depth: u8, data_type: &DataType) -> Result<u8> {
+        match data_type {
+            DataType::List(field) => ndims_of_type(depth + 1, field.data_type()),
+            DataType::Null
+            | DataType::Utf8
+            | DataType::LargeUtf8
+            | DataType::Boolean
+            | DataType::Float32
+            | DataType::Float64
+            | DataType::Int8
+            | DataType::Int16
+            | DataType::Int32
+            | DataType::Int64
+            | DataType::UInt8
+            | DataType::UInt16
+            | DataType::UInt32
+            | DataType::UInt64 => Ok(depth),
+            data_type => Err(DataFusionError::NotImplemented(format!(
+                "Array is not implemented for type '{data_type:?}'."
+            ))),
+        }
+    }
+
+    ndims_of_type(arg, arr.data_type())
+}
+
+/// Raises every array in `args` to the dimensionality of the most deeply
+/// nested one, so that all of them can be concatenated as lists of one type.
+///
+/// An array with fewer dimensions is wrapped with `array_array` once per
+/// missing level. The arrays keep their original order. An empty `args`
+/// yields an empty result instead of panicking.
+///
+/// # Errors
+/// Propagates errors from `compute_array_ndims` (unsupported element type)
+/// and from `array_array`.
+fn align_array_dimensions(args: Vec<ArrayRef>) -> Result<Vec<ArrayRef>> {
+    // Compute each array's dimensionality once and reuse it when aligning.
+    let ndims = args
+        .iter()
+        .map(|arr| compute_array_ndims(0, arr.clone()))
+        .collect::<Result<Vec<u8>>>()?;
+    let max_ndim = ndims.iter().copied().max().unwrap_or(0);
+
+    args
+        .into_iter()
+        .zip(ndims)
+        .map(|(array, ndim)| {
+            // Each pass adds one level of list nesting; the range is empty
+            // for arrays already at `max_ndim`.
+            let mut aligned_array = array;
+            for _ in ndim..max_ndim {
+                let data_type = aligned_array.data_type().clone();
+                aligned_array = array_array(&[aligned_array], data_type)?;
+            }
+            Ok(aligned_array)
+        })
+        .collect()
+}
+
 /// Array_concat/Array_cat SQL function
 pub fn array_concat(args: &[ArrayRef]) -> Result<ArrayRef> {
     match args[0].data_type() {
         DataType::List(field) => match field.data_type() {
             DataType::Null => array_concat(&args[1..]),
             _ => {
+                let args = align_array_dimensions(args.to_vec())?;
+
                 let list_arrays = downcast_vec!(args, ListArray)
                     .collect::<Result<Vec<&ListArray>>>()?;
+
                 let len: usize = list_arrays.iter().map(|a| 
a.values().len()).sum();
+
                 let capacity =
                     Capacities::Array(list_arrays.iter().map(|a| 
a.len()).sum());
                 let array_data: Vec<_> =
                     list_arrays.iter().map(|a| 
a.to_data()).collect::<Vec<_>>();
+
                 let array_data = array_data.iter().collect();
+
                 let mut mutable =
                     MutableArrayData::with_capacities(array_data, false, 
capacity);
 
@@ -1217,31 +1280,6 @@ pub fn array_dims(args: &[ArrayRef]) -> Result<ArrayRef> {
 
 /// Array_ndims SQL function
 pub fn array_ndims(args: &[ArrayRef]) -> Result<ArrayRef> {
-    fn compute_array_ndims(arg: u8, arr: ArrayRef) -> Result<u8> {
-        match arr.data_type() {
-            DataType::List(..) => {
-                let list_array = downcast_arg!(arr, ListArray);
-                compute_array_ndims(arg + 1, list_array.value(0))
-            }
-            DataType::Null
-            | DataType::Utf8
-            | DataType::LargeUtf8
-            | DataType::Boolean
-            | DataType::Float32
-            | DataType::Float64
-            | DataType::Int8
-            | DataType::Int16
-            | DataType::Int32
-            | DataType::Int64
-            | DataType::UInt8
-            | DataType::UInt16
-            | DataType::UInt32
-            | DataType::UInt64 => Ok(arg),
-            data_type => Err(DataFusionError::NotImplemented(format!(
-                "Array is not implemented for type '{data_type:?}'."
-            ))),
-        }
-    }
     let arg: u8 = 0;
     Ok(Arc::new(UInt8Array::from(vec![compute_array_ndims(
         arg,

Reply via email to