This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/main by this push:
     new a42c475419 Fix union cast incorrectness for duplicate field names 
(#9666)
a42c475419 is described below

commit a42c475419fea382dc66642d01ebf1abf92a9f6d
Author: Matthew Kim <[email protected]>
AuthorDate: Fri Apr 10 14:32:20 2026 -0400

    Fix union cast incorrectness for duplicate field names (#9666)
    
    # Which issue does this PR close?
    
    - Closes #9664
    
    # Rationale for this change
    
    This PR fixes `union_extract_by_type` to select children by type id
    instead of by field name.
    
    `union_extract` resolves children by field name, greedily returning the
    first match. When `union_extract_by_type` resolved the correct child by
    type and then called `union_extract` with just the field name, that
    name-based lookup could pick the wrong child if another field shared the
    same name.
    
    This commit exposes a public function `union_extract_by_id` that selects
    the child by type id directly, which avoids the ambiguity described
    above.
---
 arrow-cast/src/cast/union.rs      |  66 ++++++++++++++++--
 arrow-select/src/union_extract.rs | 140 +++++++++++++++++++++++++++++++++++++-
 2 files changed, 200 insertions(+), 6 deletions(-)

diff --git a/arrow-cast/src/cast/union.rs b/arrow-cast/src/cast/union.rs
index 7681e04356..89929e3c88 100644
--- a/arrow-cast/src/cast/union.rs
+++ b/arrow-cast/src/cast/union.rs
@@ -21,7 +21,7 @@ use crate::cast::can_cast_types;
 use crate::cast_with_options;
 use arrow_array::{Array, ArrayRef, UnionArray};
 use arrow_schema::{ArrowError, DataType, FieldRef, UnionFields};
-use arrow_select::union_extract::union_extract;
+use arrow_select::union_extract::union_extract_by_id;
 
 use super::CastOptions;
 
@@ -64,7 +64,7 @@ fn same_type_family(a: &DataType, b: &DataType) -> bool {
 pub(crate) fn resolve_child_array<'a>(
     fields: &'a UnionFields,
     target_type: &DataType,
-) -> Option<&'a FieldRef> {
+) -> Option<(i8, &'a FieldRef)> {
     fields
         .iter()
         .find(|(_, f)| f.data_type() == target_type)
@@ -84,7 +84,6 @@ pub(crate) fn resolve_child_array<'a>(
                 .iter()
                 .find(|(_, f)| can_cast_types(f.data_type(), target_type))
         })
-        .map(|(_, f)| f)
 }
 
 /// Extracts the best-matching child array from a [`UnionArray`] for a given 
target type,
@@ -137,7 +136,7 @@ pub fn union_extract_by_type(
         _ => unreachable!("union_extract_by_type called on non-union array"),
     };
 
-    let Some(field) = resolve_child_array(fields, target_type) else {
+    let Some((type_id, _)) = resolve_child_array(fields, target_type) else {
         return Err(ArrowError::CastError(format!(
             "cannot cast Union with fields {} to {}",
             fields
@@ -149,7 +148,7 @@ pub fn union_extract_by_type(
         )));
     };
 
-    let extracted = union_extract(union_array, field.name())?;
+    let extracted = union_extract_by_id(union_array, type_id)?;
 
     if extracted.data_type() == target_type {
         return Ok(extracted);
@@ -355,6 +354,63 @@ mod tests {
         assert!(!arr.value(2));
     }
 
+    // duplicate field names: ensure we resolve by type_id, not field name.
+    // Union has two children both named "val" — Int32 (type_id 0) and Utf8 
(type_id 1).
+    // Casting to Utf8 should select the Utf8 child (type_id 1), not the Int32 
child (type_id 0).
+    #[test]
+    fn test_duplicate_field_names() {
+        let fields = UnionFields::try_new(
+            [0, 1],
+            [
+                Field::new("val", DataType::Int32, true),
+                Field::new("val", DataType::Utf8, true),
+            ],
+        )
+        .unwrap();
+
+        let target = DataType::Utf8;
+
+        let sparse = UnionArray::try_new(
+            fields.clone(),
+            vec![0_i8, 1, 0, 1].into(),
+            None,
+            vec![
+                Arc::new(Int32Array::from(vec![Some(42), None, Some(99), 
None])) as ArrayRef,
+                Arc::new(StringArray::from(vec![
+                    None,
+                    Some("hello"),
+                    None,
+                    Some("world"),
+                ])),
+            ],
+        )
+        .unwrap();
+
+        let result = cast::cast(&sparse, &target).unwrap();
+        let arr = result.as_any().downcast_ref::<StringArray>().unwrap();
+        assert!(arr.is_null(0));
+        assert_eq!(arr.value(1), "hello");
+        assert!(arr.is_null(2));
+        assert_eq!(arr.value(3), "world");
+
+        let dense = UnionArray::try_new(
+            fields,
+            vec![0_i8, 1, 1].into(),
+            Some(vec![0_i32, 0, 1].into()),
+            vec![
+                Arc::new(Int32Array::from(vec![Some(42)])) as ArrayRef,
+                Arc::new(StringArray::from(vec![Some("hello"), 
Some("world")])),
+            ],
+        )
+        .unwrap();
+
+        let result = cast::cast(&dense, &target).unwrap();
+        let arr = result.as_any().downcast_ref::<StringArray>().unwrap();
+        assert!(arr.is_null(0));
+        assert_eq!(arr.value(1), "hello");
+        assert_eq!(arr.value(2), "world");
+    }
+
     // no matching child array, all three passes fail.
     // Union(Int32, Utf8) targeting Struct({x: Int32}). neither Int32 nor Utf8
     // can be cast to a Struct, so both can_cast_types and cast return errors.
diff --git a/arrow-select/src/union_extract.rs 
b/arrow-select/src/union_extract.rs
index 3accecc359..893b13554b 100644
--- a/arrow-select/src/union_extract.rs
+++ b/arrow-select/src/union_extract.rs
@@ -89,6 +89,40 @@ pub fn union_extract(union_array: &UnionArray, target: &str) 
-> Result<ArrayRef,
             ArrowError::InvalidArgumentError(format!("field {target} not found 
on union"))
         })?;
 
+    union_extract_impl(union_array, fields, target_type_id)
+}
+
+/// Like [`union_extract`], but selects the child by `type_id` rather than by
+/// field name.
+///
+/// This avoids ambiguity when the union contains duplicate field names.
+///
+/// # Errors
+///
+/// Returns error if `target_type_id` does not correspond to a field in the 
union.
+pub fn union_extract_by_id(
+    union_array: &UnionArray,
+    target_type_id: i8,
+) -> Result<ArrayRef, ArrowError> {
+    let fields = match union_array.data_type() {
+        DataType::Union(fields, _) => fields,
+        _ => unreachable!(),
+    };
+
+    if fields.iter().all(|(id, _)| id != target_type_id) {
+        return Err(ArrowError::InvalidArgumentError(format!(
+            "type_id {target_type_id} not found on union"
+        )));
+    }
+
+    union_extract_impl(union_array, fields, target_type_id)
+}
+
+fn union_extract_impl(
+    union_array: &UnionArray,
+    fields: &UnionFields,
+    target_type_id: i8,
+) -> Result<ArrayRef, ArrowError> {
     match union_array.offsets() {
         Some(_) => extract_dense(union_array, fields, target_type_id),
         None => extract_sparse(union_array, fields, target_type_id),
@@ -399,7 +433,9 @@ fn is_sequential_generic<const N: usize>(offsets: &[i32]) 
-> bool {
 
 #[cfg(test)]
 mod tests {
-    use super::{BoolValue, eq_scalar_inner, is_sequential_generic, 
union_extract};
+    use super::{
+        BoolValue, eq_scalar_inner, is_sequential_generic, union_extract, 
union_extract_by_id,
+    };
     use arrow_array::{Array, Int32Array, NullArray, StringArray, UnionArray, 
new_null_array};
     use arrow_buffer::{BooleanBuffer, ScalarBuffer};
     use arrow_schema::{ArrowError, DataType, Field, UnionFields, UnionMode};
@@ -1236,4 +1272,106 @@ mod tests {
             ArrowError::InvalidArgumentError("field a not found on 
union".into()).to_string()
         );
     }
+
+    #[test]
+    fn extract_by_id_sparse_duplicate_names() {
+        // Two fields with the same name "val" but different type_ids and types
+        let fields = UnionFields::try_new(
+            [0, 1],
+            [
+                Field::new("val", DataType::Int32, true),
+                Field::new("val", DataType::Utf8, true),
+            ],
+        )
+        .unwrap();
+
+        let union = UnionArray::try_new(
+            fields,
+            vec![0_i8, 1, 0, 1].into(),
+            None,
+            vec![
+                Arc::new(Int32Array::from(vec![Some(42), None, Some(99), 
None])) as _,
+                Arc::new(StringArray::from(vec![
+                    None,
+                    Some("hello"),
+                    None,
+                    Some("world"),
+                ])),
+            ],
+        )
+        .unwrap();
+
+        // union_extract by name always returns type_id 0 (first match)
+        let by_name = union_extract(&union, "val").unwrap();
+        assert_eq!(
+            *by_name,
+            Int32Array::from(vec![Some(42), None, Some(99), None])
+        );
+
+        // union_extract_by_id can select type_id 1 (the Utf8 child)
+        let by_id = union_extract_by_id(&union, 1).unwrap();
+        assert_eq!(
+            *by_id,
+            StringArray::from(vec![None, Some("hello"), None, Some("world")])
+        );
+    }
+
+    #[test]
+    fn extract_by_id_dense_duplicate_names() {
+        let fields = UnionFields::try_new(
+            [0, 1],
+            [
+                Field::new("val", DataType::Int32, true),
+                Field::new("val", DataType::Utf8, true),
+            ],
+        )
+        .unwrap();
+
+        let union = UnionArray::try_new(
+            fields,
+            vec![0_i8, 1, 0].into(),
+            Some(vec![0_i32, 0, 1].into()),
+            vec![
+                Arc::new(Int32Array::from(vec![Some(42), Some(99)])) as _,
+                Arc::new(StringArray::from(vec![Some("hello")])),
+            ],
+        )
+        .unwrap();
+
+        // by type_id 0 → Int32 child
+        let by_id_0 = union_extract_by_id(&union, 0).unwrap();
+        assert_eq!(*by_id_0, Int32Array::from(vec![Some(42), None, Some(99)]));
+
+        // by type_id 1 → Utf8 child
+        let by_id_1 = union_extract_by_id(&union, 1).unwrap();
+        assert_eq!(*by_id_1, StringArray::from(vec![None, Some("hello"), 
None]));
+    }
+
+    #[test]
+    fn extract_by_id_not_found() {
+        let fields = UnionFields::try_new(
+            [0, 1],
+            [
+                Field::new("a", DataType::Int32, true),
+                Field::new("b", DataType::Utf8, true),
+            ],
+        )
+        .unwrap();
+
+        let union = UnionArray::try_new(
+            fields,
+            vec![0_i8, 1].into(),
+            None,
+            vec![
+                Arc::new(Int32Array::from(vec![Some(1), None])) as _,
+                Arc::new(StringArray::from(vec![None, Some("x")])),
+            ],
+        )
+        .unwrap();
+
+        assert_eq!(
+            union_extract_by_id(&union, 5).unwrap_err().to_string(),
+            ArrowError::InvalidArgumentError("type_id 5 not found on 
union".into()).to_string()
+        );
+    }
 }

Reply via email to