jecsand838 commented on code in PR #8274:
URL: https://github.com/apache/arrow-rs/pull/8274#discussion_r2322745292


##########
arrow-avro/src/writer/encoder.rs:
##########
@@ -275,3 +384,434 @@ impl F64Encoder<'_> {
             .map_err(|e| ArrowError::IoError(format!("write f64: {e}"), e))
     }
 }
+
+/// Encoder for Arrow string arrays, generic over the offset width `O`.
+///
+/// Holds a borrowed [`GenericStringArray`] and emits one value per call.
+struct Utf8GenericEncoder<'a, O: OffsetSizeTrait>(&'a GenericStringArray<O>);
+
+impl<'a, O: OffsetSizeTrait> Utf8GenericEncoder<'a, O> {
+    /// Writes the string stored at row `idx` to `out` via `write_len_prefixed`,
+    /// passing the value's raw bytes as the payload.
+    #[inline]
+    fn encode<W: Write + ?Sized>(&mut self, idx: usize, out: &mut W) -> Result<(), ArrowError> {
+        let text = self.0.value(idx);
+        let payload = text.as_bytes();
+        write_len_prefixed(out, payload)
+    }
+}
+
+/// String encoder specialised to `i32` offsets (used for `DataType::Utf8` columns).
+type Utf8Encoder<'a> = Utf8GenericEncoder<'a, i32>;
+/// String encoder specialised to `i64` offsets (used for `DataType::LargeUtf8` columns).
+type Utf8LargeEncoder<'a> = Utf8GenericEncoder<'a, i64>;
+
+/// Unified field encoder:
+/// - Holds the inner `Encoder` (by value)
+/// - Tracks the column/site null buffer and whether any nulls exist
+/// - Carries per-site Avro `Nullability` and precomputed union branch (fast path)
+pub struct FieldEncoder<'a> {
+    /// Type-specific encoder that writes the value itself.
+    encoder: Encoder<'a>,
+    /// Clone of the array's null buffer, if it has one.
+    nulls: Option<NullBuffer>,
+    /// `true` when the array reported a non-zero null count at construction time.
+    has_nulls: bool,
+    /// Avro nullability of this site; `None` means no union branch is written.
+    nullability: Option<Nullability>,
+    /// Precomputed constant branch byte if the site is nullable but contains no nulls
+    pre: Option<u8>,
+}
+
+impl<'a> FieldEncoder<'a> {
+    /// Builds the encoder for one value site, selecting the concrete encoder from the
+    /// `plan` variant and the Arrow data type of `array`.
+    ///
+    /// The array's null buffer is cloned and its null count is sampled once. The returned
+    /// encoder carries no Avro nullability yet (`nullability` and `pre` are `None`).
+    /// `field` is accepted for signature symmetry but is not read by this function.
+    ///
+    /// # Errors
+    /// - `ArrowError::SchemaError` when `array` is not of the Arrow type the plan requires.
+    /// - `ArrowError::NotYetImplemented` for scalar data types or plan variants that have
+    ///   no encoder.
+    fn make_encoder(
+        array: &'a dyn Array,
+        field: &Field,
+        plan: PlanRef<'_>,
+    ) -> Result<Self, ArrowError> {
+        let nulls = array.nulls().cloned();
+        let has_nulls = array.null_count() > 0;
+        let encoder = match plan {
+            FieldPlan::Struct { encoders } => {
+                let arr = array
+                    .as_any()
+                    .downcast_ref::<StructArray>()
+                    .ok_or_else(|| ArrowError::SchemaError("Expected StructArray".into()))?;
+                Encoder::Struct(Box::new(StructEncoder::try_new(arr, encoders)?))
+            }
+            FieldPlan::List {
+                items_nullability,
+                item_plan,
+            } => match array.data_type() {
+                DataType::List(_) => {
+                    let arr = array
+                        .as_any()
+                        .downcast_ref::<ListArray>()
+                        .ok_or_else(|| ArrowError::SchemaError("Expected ListArray".into()))?;
+                    Encoder::List(Box::new(ListEncoder32::try_new(
+                        arr,
+                        *items_nullability,
+                        item_plan.as_ref(),
+                    )?))
+                }
+                DataType::LargeList(_) => {
+                    let arr = array
+                        .as_any()
+                        .downcast_ref::<LargeListArray>()
+                        .ok_or_else(|| ArrowError::SchemaError("Expected LargeListArray".into()))?;
+                    Encoder::LargeList(Box::new(ListEncoder64::try_new(
+                        arr,
+                        *items_nullability,
+                        item_plan.as_ref(),
+                    )?))
+                }
+                other => {
+                    return Err(ArrowError::SchemaError(format!(
+                        "Avro array site requires Arrow List/LargeList, found: {other:?}"
+                    )))
+                }
+            },
+            FieldPlan::Scalar => match array.data_type() {
+                DataType::Boolean => Encoder::Boolean(BooleanEncoder(array.as_boolean())),
+                DataType::Utf8 => {
+                    Encoder::Utf8(Utf8GenericEncoder::<i32>(array.as_string::<i32>()))
+                }
+                DataType::LargeUtf8 => {
+                    Encoder::Utf8Large(Utf8GenericEncoder::<i64>(array.as_string::<i64>()))
+                }
+                DataType::Int32 => Encoder::Int(IntEncoder(array.as_primitive::<Int32Type>())),
+                DataType::Int64 => Encoder::Long(LongEncoder(array.as_primitive::<Int64Type>())),
+                DataType::Float32 => {
+                    Encoder::Float32(F32Encoder(array.as_primitive::<Float32Type>()))
+                }
+                DataType::Float64 => {
+                    Encoder::Float64(F64Encoder(array.as_primitive::<Float64Type>()))
+                }
+                DataType::Binary => Encoder::Binary(BinaryEncoder(array.as_binary::<i32>())),
+                DataType::LargeBinary => {
+                    Encoder::LargeBinary(BinaryEncoder(array.as_binary::<i64>()))
+                }
+                DataType::Timestamp(TimeUnit::Microsecond, _) => Encoder::Timestamp(LongEncoder(
+                    array.as_primitive::<TimestampMicrosecondType>(),
+                )),
+                other => {
+                    return Err(ArrowError::NotYetImplemented(format!(
+                        "Avro scalar type not yet supported: {other:?}"
+                    )));
+                }
+            },
+            other => {
+                // `format!` is required: a bare string literal would emit the `{other:?}`
+                // placeholder verbatim instead of the plan's debug representation.
+                return Err(ArrowError::NotYetImplemented(format!(
+                    "Avro writer: {other:?} not yet supported"
+                )));
+            }
+        };
+        Ok(Self {
+            encoder,
+            nulls,
+            has_nulls,
+            nullability: None,
+            pre: None,
+        })
+    }
+
+    /// Returns the `has_nulls` flag recorded when this encoder was built.
+    #[inline]
+    fn has_nulls(&self) -> bool {
+        self.has_nulls
+    }
+
+    /// Reports whether row `idx` is null according to the stored null buffer.
+    ///
+    /// Always `false` when no null buffer was captured.
+    #[inline]
+    fn is_null(&self, idx: usize) -> bool {
+        match &self.nulls {
+            Some(buffer) => buffer.is_null(idx),
+            None => false,
+        }
+    }
+
+    /// Attaches the site's Avro nullability and caches the constant union branch byte
+    /// (if one applies), returning the updated encoder.
+    #[inline]
+    fn with_effective_nullability(mut self, order: Option<Nullability>) -> Self {
+        // The branch byte depends only on `order` and the null flag, so it can be
+        // computed before the fields are updated.
+        let branch = self.precomputed_union_value_branch(order);
+        self.nullability = order;
+        self.pre = branch;
+        self
+    }
+
+    /// Returns the constant union branch byte for a nullable site that holds no nulls.
+    ///
+    /// Yields `None` when the array contains nulls or when `order` is `None`, in which
+    /// case the branch has to be decided per row.
+    #[inline]
+    fn precomputed_union_value_branch(&self, order: Option<Nullability>) -> Option<u8> {
+        if self.has_nulls() {
+            return None;
+        }
+        match order {
+            // value branch index 1
+            Some(Nullability::NullFirst) => Some(0x02),
+            // value branch index 0
+            Some(Nullability::NullSecond) => Some(0x00),
+            _ => None,
+        }
+    }
+
+    /// Encodes the value at row `idx` with the wrapped `Encoder`, without writing any
+    /// union branch prefix.
+    #[inline]
+    fn encode_inner<W: Write + ?Sized>(
+        &mut self,
+        idx: usize,
+        out: &mut W,
+    ) -> Result<(), ArrowError> {
+        self.encoder.encode(idx, out)
+    }
+
+    /// Encodes row `idx`, preceded by the union branch when the site is nullable.
+    ///
+    /// - With a precomputed branch byte, that byte is written and the value follows.
+    /// - With a nullability but no precomputed byte, the branch is chosen per row via
+    ///   `write_optional_index`; null rows write nothing after the branch.
+    /// - Without a nullability, only the value is written.
+    #[inline]
+    fn encode<W: Write + ?Sized>(&mut self, idx: usize, out: &mut W) -> Result<(), ArrowError> {
+        if let Some(branch) = self.pre {
+            out.write_all(&[branch])
+                .map_err(|e| ArrowError::IoError(format!("write union value branch: {e}"), e))?;
+            return self.encode_inner(idx, out);
+        }
+        match self.nullability {
+            Some(order) => {
+                let is_null = self.is_null(idx);
+                write_optional_index(out, is_null, order)?;
+                if is_null {
+                    Ok(())
+                } else {
+                    self.encode_inner(idx, out)
+                }
+            }
+            None => self.encode_inner(idx, out),
+        }
+    }
+}
+
+/// Encodes a `StructArray` row by running each bound child encoder in binding order.
+struct StructEncoder<'a> {
+    encoders: Vec<FieldEncoder<'a>>,
+}
+
+impl<'a> StructEncoder<'a> {
+    /// Builds one child encoder per entry of `field_bindings`.
+    ///
+    /// Each binding's `arrow_index` selects both the child column and its field
+    /// definition from `array`.
+    ///
+    /// # Errors
+    /// `ArrowError::SchemaError` if `array` is not a struct or an index is out of range;
+    /// any error from building a child encoder is propagated.
+    fn try_new(
+        array: &'a StructArray,
+        field_bindings: &[FieldBinding],
+    ) -> Result<Self, ArrowError> {
+        let DataType::Struct(fields) = array.data_type() else {
+            return Err(ArrowError::SchemaError("Expected Struct".into()));
+        };
+        let out_of_range = |idx: usize| {
+            ArrowError::SchemaError(format!("Struct child index {idx} out of range"))
+        };
+        let encoders = field_bindings
+            .iter()
+            .map(|binding| {
+                let idx = binding.arrow_index;
+                let column = array.columns().get(idx).ok_or_else(|| out_of_range(idx))?;
+                let field = fields.get(idx).ok_or_else(|| out_of_range(idx))?;
+                prepare_value_site_encoder(
+                    column.as_ref(),
+                    field.as_ref(),
+                    binding.nullability,
+                    &binding.plan,
+                )
+            })
+            .collect::<Result<Vec<_>, ArrowError>>()?;
+        Ok(Self { encoders })
+    }
+
+    /// Encodes row `idx` of every child in order, stopping at the first error.
+    #[inline]
+    fn encode<W: Write + ?Sized>(&mut self, idx: usize, out: &mut W) -> Result<(), ArrowError> {
+        self.encoders
+            .iter_mut()
+            .try_for_each(|child| child.encode(idx, out))
+    }
+}
+
+/// Writes the items in `start..end` as a single block followed by the end marker.
+///
+/// A non-empty range emits its item count, then each item through `write_item`, then a
+/// `0` terminator. An empty (or inverted) range emits only the `0` terminator.
+#[inline]
+fn encode_blocked_range<W: Write + ?Sized, F>(
+    out: &mut W,
+    start: usize,
+    end: usize,
+    mut write_item: F,
+) -> Result<(), ArrowError>
+where
+    F: FnMut(usize, &mut W) -> Result<(), ArrowError>,
+{
+    let count = end.saturating_sub(start);
+    if count > 0 {
+        // Emit a single positive block for performance.
+        write_long(out, count as i64)?;
+        for item in start..end {
+            write_item(item, out)?;
+        }
+    }
+    // Zero-length terminator per Avro spec; it is the whole output for an empty range.
+    write_long(out, 0)?;
+    Ok(())
+}
+
+/// Encoder for Arrow list arrays, generic over the offset width `O`.
+struct ListEncoder<'a, O: OffsetSizeTrait> {
+    /// The list array whose offsets delimit each row's items.
+    list: &'a GenericListArray<O>,
+    /// Encoder for the list's child values.
+    values: FieldEncoder<'a>,
+    /// `offset()` of the child values array, subtracted from item indices before encoding.
+    values_offset: usize,
+}
+
+/// List encoder for `ListArray` (`i32` offsets).
+type ListEncoder32<'a> = ListEncoder<'a, i32>;
+/// List encoder for `LargeListArray` (`i64` offsets).
+type ListEncoder64<'a> = ListEncoder<'a, i64>;
+
+impl<'a, O: OffsetSizeTrait> ListEncoder<'a, O> {
+    /// Builds a list encoder whose child encoder follows `item_plan` and
+    /// `items_nullability`.
+    ///
+    /// # Errors
+    /// `ArrowError::SchemaError` if `list` does not report a `List`/`LargeList` data type;
+    /// any error from building the child encoder is propagated.
+    fn try_new(
+        list: &'a GenericListArray<O>,
+        items_nullability: Option<Nullability>,
+        item_plan: &FieldPlan,
+    ) -> Result<Self, ArrowError> {
+        let (DataType::List(child) | DataType::LargeList(child)) = list.data_type() else {
+            return Err(ArrowError::SchemaError(
+                "Expected List or LargeList for ListEncoder".into(),
+            ));
+        };
+        let values = prepare_value_site_encoder(
+            list.values().as_ref(),
+            child.as_ref(),
+            items_nullability,
+            item_plan,
+        )?;
+        let values_offset = list.values().offset();
+        Ok(Self {
+            list,
+            values,
+            values_offset,
+        })
+    }
+
+    /// Encodes the child items in `start..end` as one block plus terminator.
+    ///
+    /// Each index has `values_offset` subtracted (saturating at zero) before it is handed
+    /// to the child encoder.
+    #[inline]
+    fn encode_list_range<W: Write + ?Sized>(
+        &mut self,
+        out: &mut W,
+        start: usize,
+        end: usize,
+    ) -> Result<(), ArrowError> {
+        let shift = self.values_offset;
+        let child = &mut self.values;
+        encode_blocked_range(out, start, end, |row, out| {
+            child.encode(row.saturating_sub(shift), out)
+        })
+    }
+
+    /// Encodes list row `idx` using the offsets at `idx` and `idx + 1` as item bounds.
+    ///
+    /// # Errors
+    /// `ArrowError::InvalidArgumentError` if either offset cannot be converted to `usize`.
+    #[inline]
+    fn encode<W: Write + ?Sized>(&mut self, idx: usize, out: &mut W) -> Result<(), ArrowError> {
+        let offsets = self.list.offsets();
+        let bound = |pos: usize| {
+            offsets[pos].to_usize().ok_or_else(|| {
+                ArrowError::InvalidArgumentError(format!(
+                    "Error converting offset[{pos}] to usize"
+                ))
+            })
+        };
+        let start = bound(idx)?;
+        let end = bound(idx + 1)?;
+        self.encode_list_range(out, start, end)
+    }
+}
+
+/// Builds the encoder for a value site and applies the site's Avro nullability to it.
+///
+/// # Errors
+/// Propagates any error from `FieldEncoder::make_encoder`.
+#[inline]
+fn prepare_value_site_encoder<'a>(
+    values_array: &'a dyn Array,
+    value_field: &Field,
+    site_nullability: Option<Nullability>,
+    plan: PlanRef<'_>,
+) -> Result<FieldEncoder<'a>, ArrowError> {
+    let base = FieldEncoder::make_encoder(values_array, value_field, plan)?;
+    // Effective nullability is exactly the site's Avro-declared nullability.
+    Ok(base.with_effective_nullability(site_nullability))
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use arrow_array::types::Int32Type;
+    use arrow_array::{
+        Array, ArrayRef, BinaryArray, BooleanArray, Float32Array, 
Float64Array, Int32Array,
+        Int64Array, LargeBinaryArray, LargeListArray, LargeStringArray, 
ListArray, StringArray,
+        TimestampMicrosecondArray,
+    };
+    use arrow_schema::{DataType, Field, Fields};
+
+    /// Zig-zag maps a signed 64-bit value onto an unsigned one.
+    fn zigzag_i64(v: i64) -> u64 {
+        let doubled = v << 1;
+        // Arithmetic shift: all ones for negative input, all zeros otherwise.
+        let sign_mask = v >> 63;
+        (doubled ^ sign_mask) as u64
+    }
+
+    /// Encodes `x` as a little-endian base-128 varint: seven payload bits per byte, with
+    /// the high bit set on every byte except the last.
+    fn varint(mut x: u64) -> Vec<u8> {
+        let mut bytes = Vec::new();
+        loop {
+            let low = (x & 0x7f) as u8;
+            x >>= 7;
+            if x == 0 {
+                bytes.push(low);
+                return bytes;
+            }
+            bytes.push(low | 0x80);
+        }
+    }
+
+    /// Expected byte encoding of a long: zig-zag the value, then varint-encode it.
+    fn avro_long_bytes(v: i64) -> Vec<u8> {
+        let unsigned = zigzag_i64(v);
+        varint(unsigned)
+    }
+
+    /// Expected encoding of a length-prefixed payload: the length as a long, followed by
+    /// the raw bytes.
+    fn avro_len_prefixed_bytes(payload: &[u8]) -> Vec<u8> {
+        let prefix = avro_long_bytes(payload.len() as i64);
+        [prefix.as_slice(), payload].concat()
+    }
+
+    fn encode_all(array: &dyn Array, plan: &FieldPlan, site: 
Option<Nullability>) -> Vec<u8> {

Review Comment:
   I'll rename it to `nullability`; that's a good recommendation.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to