This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new 12ff1eac2 fix: Correctly handle take on dense union of a single selected type (#6209)
12ff1eac2 is described below

commit 12ff1eac23058691bb157d7ef9981e6a0e7bcd0d
Author: gstvg <[email protected]>
AuthorDate: Thu Aug 8 07:58:12 2024 -0300

    fix: Correctly handle take on dense union of a single selected type (#6209)
    
    * fix: use filter instead of filter_primitive
    
    * fix: remove pub(crate) from filter_primitive
    
    * fix: run cargo fmt
    
    * fix: clippy
---
 arrow-select/src/filter.rs |  5 +----
 arrow-select/src/take.rs   | 27 +++++++++++++++++++++------
 2 files changed, 22 insertions(+), 10 deletions(-)

diff --git a/arrow-select/src/filter.rs b/arrow-select/src/filter.rs
index 65ccbe1e0..8e06b07f5 100644
--- a/arrow-select/src/filter.rs
+++ b/arrow-select/src/filter.rs
@@ -552,10 +552,7 @@ fn filter_native<T: ArrowNativeType>(values: &[T], predicate: &FilterPredicate)
 }
 
 /// `filter` implementation for primitive arrays
-pub(crate) fn filter_primitive<T>(
-    array: &PrimitiveArray<T>,
-    predicate: &FilterPredicate,
-) -> PrimitiveArray<T>
+fn filter_primitive<T>(array: &PrimitiveArray<T>, predicate: &FilterPredicate) -> PrimitiveArray<T>
 where
     T: ArrowPrimitiveType,
 {
diff --git a/arrow-select/src/take.rs b/arrow-select/src/take.rs
index b66133ac7..ed7179fd3 100644
--- a/arrow-select/src/take.rs
+++ b/arrow-select/src/take.rs
@@ -31,8 +31,6 @@ use arrow_schema::{ArrowError, DataType, FieldRef, UnionMode};
 
 use num::{One, Zero};
 
-use crate::filter::{filter_primitive, FilterBuilder};
-
 /// Take elements by index from [Array], creating a new [Array] from those indexes.
 ///
 /// ```text
@@ -251,13 +249,12 @@ fn take_impl<IndexType: ArrowPrimitiveType>(
             let children = fields.iter()
                 .map(|(field_type_id, _)| {
                     let mask = BooleanArray::from_unary(&type_ids, |value_type_id| value_type_id == field_type_id);
-                    let predicate = FilterBuilder::new(&mask).build();
 
-                    let indices = filter_primitive(&offsets, &predicate);
+                    let indices = crate::filter::filter(&offsets, &mask)?;
 
                     let values = values.child(field_type_id);
 
-                    take_impl(values, &indices)
+                    take_impl(values, indices.as_primitive::<Int32Type>())
                 })
                 .collect::<Result<_, _>>()?;
 
@@ -885,7 +882,7 @@ mod tests {
     use super::*;
     use arrow_array::builder::*;
     use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano};
-    use arrow_schema::{Field, Fields, TimeUnit};
+    use arrow_schema::{Field, Fields, TimeUnit, UnionFields};
 
     fn test_take_decimal_arrays(
         data: Vec<Option<i128>>,
@@ -2308,4 +2305,22 @@ mod tests {
             take(&union, &indices, None).unwrap().to_data()
         );
     }
+
+    #[test]
+    fn test_take_union_dense_all_match_issue_6206() {
+        let fields = UnionFields::new(vec![0], vec![Field::new("a", DataType::Int64, false)]);
+        let ints = Arc::new(Int64Array::from(vec![1, 2, 3, 4, 5]));
+
+        let array = UnionArray::try_new(
+            fields,
+            ScalarBuffer::from(vec![0_i8, 0, 0, 0, 0]),
+            Some(ScalarBuffer::from_iter(0_i32..5)),
+            vec![ints],
+        )
+        .unwrap();
+
+        let indicies = Int64Array::from(vec![0, 2, 4]);
+        let array = take(&array, &indicies, None).unwrap();
+        assert_eq!(array.len(), 3);
+    }
 }

Reply via email to