This is an automated email from the ASF dual-hosted git repository.
tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new 12ff1eac2 fix: Correctly handle take on dense union of a single selected type (#6209)
12ff1eac2 is described below
commit 12ff1eac23058691bb157d7ef9981e6a0e7bcd0d
Author: gstvg <[email protected]>
AuthorDate: Thu Aug 8 07:58:12 2024 -0300
fix: Correctly handle take on dense union of a single selected type (#6209)
* fix: use filter instead of filter_primitive
* fix: remove pub(crate) from filter_primitive
* fix: run cargo fmt
* fix: clippy
---
arrow-select/src/filter.rs | 5 +----
arrow-select/src/take.rs | 27 +++++++++++++++++++++------
2 files changed, 22 insertions(+), 10 deletions(-)
diff --git a/arrow-select/src/filter.rs b/arrow-select/src/filter.rs
index 65ccbe1e0..8e06b07f5 100644
--- a/arrow-select/src/filter.rs
+++ b/arrow-select/src/filter.rs
@@ -552,10 +552,7 @@ fn filter_native<T: ArrowNativeType>(values: &[T], predicate: &FilterPredicate)
}
/// `filter` implementation for primitive arrays
-pub(crate) fn filter_primitive<T>(
- array: &PrimitiveArray<T>,
- predicate: &FilterPredicate,
-) -> PrimitiveArray<T>
+fn filter_primitive<T>(array: &PrimitiveArray<T>, predicate: &FilterPredicate) -> PrimitiveArray<T>
where
T: ArrowPrimitiveType,
{
diff --git a/arrow-select/src/take.rs b/arrow-select/src/take.rs
index b66133ac7..ed7179fd3 100644
--- a/arrow-select/src/take.rs
+++ b/arrow-select/src/take.rs
@@ -31,8 +31,6 @@ use arrow_schema::{ArrowError, DataType, FieldRef, UnionMode};
use num::{One, Zero};
-use crate::filter::{filter_primitive, FilterBuilder};
-
/// Take elements by index from [Array], creating a new [Array] from those indexes.
///
/// ```text
@@ -251,13 +249,12 @@ fn take_impl<IndexType: ArrowPrimitiveType>(
let children = fields.iter()
.map(|(field_type_id, _)| {
let mask = BooleanArray::from_unary(&type_ids, |value_type_id| value_type_id == field_type_id);
- let predicate = FilterBuilder::new(&mask).build();
- let indices = filter_primitive(&offsets, &predicate);
+ let indices = crate::filter::filter(&offsets, &mask)?;
let values = values.child(field_type_id);
- take_impl(values, &indices)
+ take_impl(values, indices.as_primitive::<Int32Type>())
})
.collect::<Result<_, _>>()?;
@@ -885,7 +882,7 @@ mod tests {
use super::*;
use arrow_array::builder::*;
use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano};
- use arrow_schema::{Field, Fields, TimeUnit};
+ use arrow_schema::{Field, Fields, TimeUnit, UnionFields};
fn test_take_decimal_arrays(
data: Vec<Option<i128>>,
@@ -2308,4 +2305,22 @@ mod tests {
take(&union, &indices, None).unwrap().to_data()
);
}
+
+ #[test]
+ fn test_take_union_dense_all_match_issue_6206() {
+ let fields = UnionFields::new(vec![0], vec![Field::new("a", DataType::Int64, false)]);
+ let ints = Arc::new(Int64Array::from(vec![1, 2, 3, 4, 5]));
+
+ let array = UnionArray::try_new(
+ fields,
+ ScalarBuffer::from(vec![0_i8, 0, 0, 0, 0]),
+ Some(ScalarBuffer::from_iter(0_i32..5)),
+ vec![ints],
+ )
+ .unwrap();
+
+ let indicies = Int64Array::from(vec![0, 2, 4]);
+ let array = take(&array, &indicies, None).unwrap();
+ assert_eq!(array.len(), 3);
+ }
}