jecsand838 commented on code in PR #8349:
URL: https://github.com/apache/arrow-rs/pull/8349#discussion_r2369679690


##########
arrow-avro/src/reader/record.rs:
##########
@@ -1518,19 +1104,340 @@ impl Decoder {
                     .map_err(|e| ArrowError::ParseError(e.to_string()))?;
                 Arc::new(vals)
             }
-            Self::Union(fields, type_ids, offsets, encodings, _, None) => {
-                flush_union!(fields, type_ids, offsets, encodings)
-            }
-            Self::Union(fields, type_ids, offsets, encodings, _, Some(union_resolution)) => {
-                match &mut union_resolution.kind {
-                    UnionResolvedKind::Both { .. } | UnionResolvedKind::FromSingle { .. } => {
-                        flush_union!(fields, type_ids, offsets, encodings)
-                    }
-                    UnionResolvedKind::ToSingle { target } => target.flush(nulls)?,
+            Self::Union(u) => u.flush(nulls)?,
+        })
+    }
+}
+
+#[derive(Debug)]
+struct DispatchLut {
+    to_reader: Box<[i16]>,
+    promotion: Box<[Promotion]>,
+}
+
+impl DispatchLut {
+    fn from_writer_to_reader(promotion_map: &[Option<(usize, Promotion)>]) -> Self {
+        let mut to_reader = Vec::with_capacity(promotion_map.len());
+        let mut promotion = Vec::with_capacity(promotion_map.len());
+        for map in promotion_map {
+            match *map {
+                Some((idx, promo)) => {
+                    debug_assert!(idx <= i16::MAX as usize);
+                    to_reader.push(idx as i16);
+                    promotion.push(promo);
+                }
+                None => {
+                    to_reader.push(-1);
+                    promotion.push(Promotion::Direct);
                 }
             }
+        }
+        Self {
+            to_reader: to_reader.into_boxed_slice(),
+            promotion: promotion.into_boxed_slice(),
+        }
+    }
+
+    // Resolve a writer branch index to (reader_idx, promotion)
+    #[inline]
+    fn resolve(&self, writer_idx: usize) -> Option<(usize, Promotion)> {
+        if writer_idx >= self.to_reader.len() {
+            return None;
+        }
+        let reader_index = self.to_reader[writer_idx];
+        if reader_index < 0 {
+            None
+        } else {
+            Some((reader_index as usize, self.promotion[writer_idx]))
+        }
+    }
+}
+
+#[derive(Debug)]
+struct UnionDecoder {
+    fields: UnionFields,
+    type_ids: Vec<i8>,
+    offsets: Vec<i32>,
+    branches: Vec<Decoder>,
+    counts: Vec<i32>,
+    type_id_by_reader_idx: Arc<[i8]>,
+    null_branch: Option<usize>,
+    default_emit_idx: usize,
+    null_emit_idx: usize,
+    plan: UnionReadPlan,
+}
+
+impl Default for UnionDecoder {
+    fn default() -> Self {
+        Self {
+            fields: UnionFields::empty(),
+            type_ids: Vec::new(),
+            offsets: Vec::new(),
+            branches: Vec::new(),
+            counts: Vec::new(),
+            type_id_by_reader_idx: Arc::from([]),
+            null_branch: None,
+            default_emit_idx: 0,
+            null_emit_idx: 0,
+            plan: UnionReadPlan::Passthrough,
+        }
+    }
+}
+
+#[derive(Debug)]
+enum UnionReadPlan {
+    ReaderUnion {
+        lookup_table: DispatchLut,
+    },
+    FromSingle {
+        reader_idx: usize,
+        promotion: Promotion,
+    },
+    ToSingle {
+        target: Box<Decoder>,
+        lookup_table: DispatchLut,
+    },
+    Passthrough,
+}
+
+impl UnionDecoder {
+    fn try_new(
+        fields: UnionFields,
+        branches: Vec<Decoder>,
+        resolved: Option<ResolvedUnion>,
+    ) -> Result<Self, ArrowError> {
+        let reader_type_codes: Arc<[i8]> =
+            Arc::from(fields.iter().map(|(tid, _)| tid).collect::<Vec<i8>>());
+        let null_branch = branches.iter().position(|b| matches!(b, Decoder::Null(_)));
+        let default_emit_idx = 0;
+        let null_emit_idx = null_branch.unwrap_or(default_emit_idx);
+        let plan = Self::plan_from_resolved(resolved)?;
+        let branch_len = branches.len().max(reader_type_codes.len());
+        Ok(Self {
+            fields,
+            type_ids: Vec::with_capacity(DEFAULT_CAPACITY),
+            offsets: Vec::with_capacity(DEFAULT_CAPACITY),
+            branches,
+            counts: vec![0; branch_len],
+            type_id_by_reader_idx: reader_type_codes,
+            null_branch,
+            default_emit_idx,
+            null_emit_idx,
+            plan,
         })
     }
+
+    fn try_new_from_writer_union(
+        info: ResolvedUnion,
+        target: Box<Decoder>,
+    ) -> Result<Self, ArrowError> {
+        // This constructor is only for writer-union to single-type resolution
+        debug_assert!(info.writer_is_union && !info.reader_is_union);
+        let lookup_table = DispatchLut::from_writer_to_reader(&info.writer_to_reader);
+        Ok(Self {
+            plan: UnionReadPlan::ToSingle {
+                target,
+                lookup_table,
+            },
+            ..Self::default()
+        })
+    }
+
+    fn plan_from_resolved(resolved: Option<ResolvedUnion>) -> Result<UnionReadPlan, ArrowError> {
+        match resolved {
+            None => Ok(UnionReadPlan::Passthrough),
+            Some(info) => match (info.writer_is_union, info.reader_is_union) {
+                (true, true) => {
+                    let lookup_table = DispatchLut::from_writer_to_reader(&info.writer_to_reader);
+                    Ok(UnionReadPlan::ReaderUnion { lookup_table })
+                }
+                (false, true) => {
+                    let (reader_idx, promotion) =
+                        info.writer_to_reader.first().and_then(|x| *x).ok_or_else(|| {
+                            ArrowError::SchemaError(
+                                "Writer type does not match any reader union branch".to_string(),
+                            )
+                        })?;
+                    Ok(UnionReadPlan::FromSingle {
+                        reader_idx,
+                        promotion,
+                    })
+                }
+                (true, false) => Err(ArrowError::InvalidArgumentError(
+                    "UnionDecoder::try_new cannot build writer-union to single; use UnionDecoderBuilder with a target"
+                        .to_string(),
+                )),
+                (false, false) => Ok(UnionReadPlan::Passthrough),
+            },
+        }
+    }
+
+    #[inline]
+    fn read_tag(buf: &mut AvroCursor<'_>) -> Result<usize, ArrowError> {
+        let tag = buf.get_long()?;
+        if tag < 0 {
+            return Err(ArrowError::ParseError(format!(
+                "Negative union branch index {tag}"
+            )));
+        }
+        Ok(tag as usize)
+    }
+
+    #[inline]
+    fn emit_to(&mut self, reader_idx: usize) -> Result<&mut Decoder, ArrowError> {
+        if reader_idx >= self.branches.len() {
+            return Err(ArrowError::ParseError(format!(
+                "Union branch index {reader_idx} out of range ({} branches)",
+                self.branches.len()
+            )));
+        }
+        self.type_ids.push(self.type_id_by_reader_idx[reader_idx]);
+        self.offsets.push(self.counts[reader_idx]);
+        self.counts[reader_idx] += 1;
+        Ok(&mut self.branches[reader_idx])
+    }
+
+    #[inline]
+    fn on_decoder<F>(&mut self, fallback_idx: usize, action: F) -> Result<(), ArrowError>
+    where
+        F: FnOnce(&mut Decoder) -> Result<(), ArrowError>,
+    {
+        if let UnionReadPlan::ToSingle { target, .. } = &mut self.plan {
+            return action(target);
+        }
+        let reader_idx = match &self.plan {
+            UnionReadPlan::FromSingle { reader_idx, .. } => *reader_idx,
+            _ => fallback_idx,
+        };
+        self.emit_to(reader_idx).and_then(action)
+    }
+
+    fn append_null(&mut self) -> Result<(), ArrowError> {
+        self.on_decoder(self.null_emit_idx, |decoder| decoder.append_null())
+    }
+
+    fn append_default(&mut self, lit: &AvroLiteral) -> Result<(), ArrowError> {
+        self.on_decoder(self.default_emit_idx, |decoder| decoder.append_default(lit))
+    }
+
+    fn decode(&mut self, buf: &mut AvroCursor<'_>) -> Result<(), ArrowError> {
+        let (reader_idx, promotion) = match &mut self.plan {
+            UnionReadPlan::ToSingle {
+                target,
+                lookup_table,
+            } => {
+                let idx = Self::read_tag(buf)?;
+                return match lookup_table.resolve(idx) {
+                    Some((_, promotion)) => target.decode_with_promotion(buf, promotion),
+                    None => Err(ArrowError::ParseError(format!(
+                        "Writer union branch {idx} does not resolve to reader type"
+                    ))),
+                };
+            }
+            UnionReadPlan::Passthrough => (Self::read_tag(buf)?, Promotion::Direct),
+            UnionReadPlan::ReaderUnion { lookup_table } => {
+                let idx = Self::read_tag(buf)?;
+                lookup_table.resolve(idx).ok_or_else(|| {
+                    ArrowError::ParseError(format!(
+                        "Union branch index {idx} not resolvable by reader schema"
+                    ))
+                })?
+            }
+            UnionReadPlan::FromSingle {
+                reader_idx,
+                promotion,
+            } => (*reader_idx, *promotion),
+            UnionReadPlan::ToSingle { .. } => {
+                return Err(ArrowError::ParseError(
+                    "Invalid union read plan state".to_string(),
+                ));
+            }
+        };
+        let decoder = self.emit_to(reader_idx)?;
+        decoder.decode_with_promotion(buf, promotion)
+    }
+
+    fn flush(&mut self, nulls: Option<NullBuffer>) -> Result<ArrayRef, ArrowError> {
+        match &mut self.plan {
+            UnionReadPlan::ToSingle { target, .. } => target.flush(nulls),
+            _ => {
+                debug_assert!(
+                    nulls.is_none(),
+                    "UnionArray does not accept a validity bitmap; \
+                     nulls should have been materialized as a Null child during decode"
+                );
+                let children = self
+                    .branches
+                    .iter_mut()
+                    .map(|d| d.flush(None))
+                    .collect::<Result<Vec<_>, _>>()?;
+                let type_ids_buf: ScalarBuffer<i8> =
+                    flush_values(&mut self.type_ids).into_iter().collect();
+                let offsets_buf: ScalarBuffer<i32> =
+                    flush_values(&mut self.offsets).into_iter().collect();
+                let arr = UnionArray::try_new(
+                    self.fields.clone(),
+                    type_ids_buf,
+                    Some(offsets_buf),
+                    children,
+                )
+                .map_err(|e| ArrowError::ParseError(e.to_string()))?;
+                Ok(Arc::new(arr))
+            }
+        }
+    }
+}
+
+#[derive(Debug, Default)]
+struct UnionDecoderBuilder {
+    fields: Option<UnionFields>,
+    branches: Option<Vec<Decoder>>,
+    resolved: Option<ResolvedUnion>,
+    target: Option<Box<Decoder>>,
+}
+
+impl UnionDecoderBuilder {
+    fn new() -> Self {
+        Self::default()
+    }
+
+    fn with_fields(mut self, fields: UnionFields) -> Self {
+        self.fields = Some(fields);
+        self
+    }
+
+    fn with_branches(mut self, branches: Vec<Decoder>) -> Self {
+        self.branches = Some(branches);
+        self
+    }
+
+    fn with_resolved_union(mut self, resolved_union: ResolvedUnion) -> Self {
+        self.resolved = Some(resolved_union);
+        self
+    }
+
+    fn with_target(mut self, target: Box<Decoder>) -> Self {
+        self.target = Some(target);
+        self
+    }
+
+    fn build(self) -> Result<UnionDecoder, ArrowError> {
+        match (self.resolved, self.fields, self.branches, self.target) {
+            (resolved, Some(fields), Some(branches), _) => {

Review Comment:
   This actually popped up in my head after I went to bed last night lol. 100% 
good catch here. I'll tighten this up.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to