This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/main by this push:
new a42c475419 Fix union cast incorrectness for duplicate field names
(#9666)
a42c475419 is described below
commit a42c475419fea382dc66642d01ebf1abf92a9f6d
Author: Matthew Kim <[email protected]>
AuthorDate: Fri Apr 10 14:32:20 2026 -0400
Fix union cast incorrectness for duplicate field names (#9666)
# Which issue does this PR close?
- Closes #9664
# Rationale for this change
This PR fixes `union_extract_by_type` to select children by type id
instead of by field name.
`union_extract` resolves children by field name, greedily returning the
first match. When `union_extract_by_type` resolved the correct child by
data type and then called `union_extract` with just the field name, that
name-based lookup could pick the wrong child if another field shared the
same name.
This commit exposes a public function `union_extract_by_id` that selects
the child by type id directly, which avoids the ambiguity described
above.
---
arrow-cast/src/cast/union.rs | 66 ++++++++++++++++--
arrow-select/src/union_extract.rs | 140 +++++++++++++++++++++++++++++++++++++-
2 files changed, 200 insertions(+), 6 deletions(-)
diff --git a/arrow-cast/src/cast/union.rs b/arrow-cast/src/cast/union.rs
index 7681e04356..89929e3c88 100644
--- a/arrow-cast/src/cast/union.rs
+++ b/arrow-cast/src/cast/union.rs
@@ -21,7 +21,7 @@ use crate::cast::can_cast_types;
use crate::cast_with_options;
use arrow_array::{Array, ArrayRef, UnionArray};
use arrow_schema::{ArrowError, DataType, FieldRef, UnionFields};
-use arrow_select::union_extract::union_extract;
+use arrow_select::union_extract::union_extract_by_id;
use super::CastOptions;
@@ -64,7 +64,7 @@ fn same_type_family(a: &DataType, b: &DataType) -> bool {
pub(crate) fn resolve_child_array<'a>(
fields: &'a UnionFields,
target_type: &DataType,
-) -> Option<&'a FieldRef> {
+) -> Option<(i8, &'a FieldRef)> {
fields
.iter()
.find(|(_, f)| f.data_type() == target_type)
@@ -84,7 +84,6 @@ pub(crate) fn resolve_child_array<'a>(
.iter()
.find(|(_, f)| can_cast_types(f.data_type(), target_type))
})
- .map(|(_, f)| f)
}
/// Extracts the best-matching child array from a [`UnionArray`] for a given
target type,
@@ -137,7 +136,7 @@ pub fn union_extract_by_type(
_ => unreachable!("union_extract_by_type called on non-union array"),
};
- let Some(field) = resolve_child_array(fields, target_type) else {
+ let Some((type_id, _)) = resolve_child_array(fields, target_type) else {
return Err(ArrowError::CastError(format!(
"cannot cast Union with fields {} to {}",
fields
@@ -149,7 +148,7 @@ pub fn union_extract_by_type(
)));
};
- let extracted = union_extract(union_array, field.name())?;
+ let extracted = union_extract_by_id(union_array, type_id)?;
if extracted.data_type() == target_type {
return Ok(extracted);
@@ -355,6 +354,63 @@ mod tests {
assert!(!arr.value(2));
}
+ // duplicate field names: ensure we resolve by type_id, not field name.
+ // Union has two children both named "val" — Int32 (type_id 0) and Utf8
(type_id 1).
+ // Casting to Utf8 should select the Utf8 child (type_id 1), not the Int32
child (type_id 0).
+ #[test]
+ fn test_duplicate_field_names() {
+ let fields = UnionFields::try_new(
+ [0, 1],
+ [
+ Field::new("val", DataType::Int32, true),
+ Field::new("val", DataType::Utf8, true),
+ ],
+ )
+ .unwrap();
+
+ let target = DataType::Utf8;
+
+ let sparse = UnionArray::try_new(
+ fields.clone(),
+ vec![0_i8, 1, 0, 1].into(),
+ None,
+ vec![
+ Arc::new(Int32Array::from(vec![Some(42), None, Some(99),
None])) as ArrayRef,
+ Arc::new(StringArray::from(vec![
+ None,
+ Some("hello"),
+ None,
+ Some("world"),
+ ])),
+ ],
+ )
+ .unwrap();
+
+ let result = cast::cast(&sparse, &target).unwrap();
+ let arr = result.as_any().downcast_ref::<StringArray>().unwrap();
+ assert!(arr.is_null(0));
+ assert_eq!(arr.value(1), "hello");
+ assert!(arr.is_null(2));
+ assert_eq!(arr.value(3), "world");
+
+ let dense = UnionArray::try_new(
+ fields,
+ vec![0_i8, 1, 1].into(),
+ Some(vec![0_i32, 0, 1].into()),
+ vec![
+ Arc::new(Int32Array::from(vec![Some(42)])) as ArrayRef,
+ Arc::new(StringArray::from(vec![Some("hello"),
Some("world")])),
+ ],
+ )
+ .unwrap();
+
+ let result = cast::cast(&dense, &target).unwrap();
+ let arr = result.as_any().downcast_ref::<StringArray>().unwrap();
+ assert!(arr.is_null(0));
+ assert_eq!(arr.value(1), "hello");
+ assert_eq!(arr.value(2), "world");
+ }
+
// no matching child array, all three passes fail.
// Union(Int32, Utf8) targeting Struct({x: Int32}). neither Int32 nor Utf8
// can be cast to a Struct, so both can_cast_types and cast return errors.
diff --git a/arrow-select/src/union_extract.rs
b/arrow-select/src/union_extract.rs
index 3accecc359..893b13554b 100644
--- a/arrow-select/src/union_extract.rs
+++ b/arrow-select/src/union_extract.rs
@@ -89,6 +89,40 @@ pub fn union_extract(union_array: &UnionArray, target: &str)
-> Result<ArrayRef,
ArrowError::InvalidArgumentError(format!("field {target} not found
on union"))
})?;
+ union_extract_impl(union_array, fields, target_type_id)
+}
+
+/// Like [`union_extract`], but selects the child by `type_id` rather than by
+/// field name.
+///
+/// This avoids ambiguity when the union contains duplicate field names.
+///
+/// # Errors
+///
+/// Returns error if `target_type_id` does not correspond to a field in the
union.
+pub fn union_extract_by_id(
+ union_array: &UnionArray,
+ target_type_id: i8,
+) -> Result<ArrayRef, ArrowError> {
+ let fields = match union_array.data_type() {
+ DataType::Union(fields, _) => fields,
+ _ => unreachable!(),
+ };
+
+ if fields.iter().all(|(id, _)| id != target_type_id) {
+ return Err(ArrowError::InvalidArgumentError(format!(
+ "type_id {target_type_id} not found on union"
+ )));
+ }
+
+ union_extract_impl(union_array, fields, target_type_id)
+}
+
+fn union_extract_impl(
+ union_array: &UnionArray,
+ fields: &UnionFields,
+ target_type_id: i8,
+) -> Result<ArrayRef, ArrowError> {
match union_array.offsets() {
Some(_) => extract_dense(union_array, fields, target_type_id),
None => extract_sparse(union_array, fields, target_type_id),
@@ -399,7 +433,9 @@ fn is_sequential_generic<const N: usize>(offsets: &[i32])
-> bool {
#[cfg(test)]
mod tests {
- use super::{BoolValue, eq_scalar_inner, is_sequential_generic,
union_extract};
+ use super::{
+ BoolValue, eq_scalar_inner, is_sequential_generic, union_extract,
union_extract_by_id,
+ };
use arrow_array::{Array, Int32Array, NullArray, StringArray, UnionArray,
new_null_array};
use arrow_buffer::{BooleanBuffer, ScalarBuffer};
use arrow_schema::{ArrowError, DataType, Field, UnionFields, UnionMode};
@@ -1236,4 +1272,106 @@ mod tests {
ArrowError::InvalidArgumentError("field a not found on
union".into()).to_string()
);
}
+
+ #[test]
+ fn extract_by_id_sparse_duplicate_names() {
+ // Two fields with the same name "val" but different type_ids and types
+ let fields = UnionFields::try_new(
+ [0, 1],
+ [
+ Field::new("val", DataType::Int32, true),
+ Field::new("val", DataType::Utf8, true),
+ ],
+ )
+ .unwrap();
+
+ let union = UnionArray::try_new(
+ fields,
+ vec![0_i8, 1, 0, 1].into(),
+ None,
+ vec![
+ Arc::new(Int32Array::from(vec![Some(42), None, Some(99),
None])) as _,
+ Arc::new(StringArray::from(vec![
+ None,
+ Some("hello"),
+ None,
+ Some("world"),
+ ])),
+ ],
+ )
+ .unwrap();
+
+ // union_extract by name always returns type_id 0 (first match)
+ let by_name = union_extract(&union, "val").unwrap();
+ assert_eq!(
+ *by_name,
+ Int32Array::from(vec![Some(42), None, Some(99), None])
+ );
+
+ // union_extract_by_id can select type_id 1 (the Utf8 child)
+ let by_id = union_extract_by_id(&union, 1).unwrap();
+ assert_eq!(
+ *by_id,
+ StringArray::from(vec![None, Some("hello"), None, Some("world")])
+ );
+ }
+
+ #[test]
+ fn extract_by_id_dense_duplicate_names() {
+ let fields = UnionFields::try_new(
+ [0, 1],
+ [
+ Field::new("val", DataType::Int32, true),
+ Field::new("val", DataType::Utf8, true),
+ ],
+ )
+ .unwrap();
+
+ let union = UnionArray::try_new(
+ fields,
+ vec![0_i8, 1, 0].into(),
+ Some(vec![0_i32, 0, 1].into()),
+ vec![
+ Arc::new(Int32Array::from(vec![Some(42), Some(99)])) as _,
+ Arc::new(StringArray::from(vec![Some("hello")])),
+ ],
+ )
+ .unwrap();
+
+ // by type_id 0 → Int32 child
+ let by_id_0 = union_extract_by_id(&union, 0).unwrap();
+ assert_eq!(*by_id_0, Int32Array::from(vec![Some(42), None, Some(99)]));
+
+ // by type_id 1 → Utf8 child
+ let by_id_1 = union_extract_by_id(&union, 1).unwrap();
+ assert_eq!(*by_id_1, StringArray::from(vec![None, Some("hello"),
None]));
+ }
+
+ #[test]
+ fn extract_by_id_not_found() {
+ let fields = UnionFields::try_new(
+ [0, 1],
+ [
+ Field::new("a", DataType::Int32, true),
+ Field::new("b", DataType::Utf8, true),
+ ],
+ )
+ .unwrap();
+
+ let union = UnionArray::try_new(
+ fields,
+ vec![0_i8, 1].into(),
+ None,
+ vec![
+ Arc::new(Int32Array::from(vec![Some(1), None])) as _,
+ Arc::new(StringArray::from(vec![None, Some("x")])),
+ ],
+ )
+ .unwrap();
+
+ assert_eq!(
+ union_extract_by_id(&union, 5).unwrap_err().to_string(),
+ ArrowError::InvalidArgumentError("type_id 5 not found on
union".into()).to_string()
+ );
+ }
}