jecsand838 commented on code in PR #8274:
URL: https://github.com/apache/arrow-rs/pull/8274#discussion_r2325740170
##########
arrow-avro/src/writer/encoder.rs:
##########
@@ -83,49 +81,358 @@ fn write_bool<W: Write + ?Sized>(writer: &mut W, v: bool) -> Result<(), ArrowErr
/// Branch index is 0-based per Avro unions:
/// - Null-first (default): null => 0, value => 1
/// - Null-second (Impala): value => 0, null => 1
-#[inline]
-fn write_optional_branch<W: Write + ?Sized>(
+fn write_optional_index<W: Write + ?Sized>(
writer: &mut W,
is_null: bool,
- impala_mode: bool,
+ null_order: Nullability,
) -> Result<(), ArrowError> {
- let branch = if impala_mode == is_null { 1 } else { 0 };
- write_int(writer, branch)
+ let byte = union_value_branch_byte(null_order, is_null);
+ writer
+ .write_all(&[byte])
+ .map_err(|e| ArrowError::IoError(format!("write union branch: {e}"),
e))
}
-/// Encode a `RecordBatch` in Avro binary format using **default options**.
-pub fn encode_record_batch<W: Write>(batch: &RecordBatch, out: &mut W) ->
Result<(), ArrowError> {
- encode_record_batch_with_options(batch, out, &EncoderOptions::default())
+#[derive(Debug, Clone)]
+enum NullState {
+ NonNullable,
+ NullableNoNulls {
+ byte: u8,
+ },
+ Nullable {
+ nulls: NullBuffer,
+ null_order: Nullability,
+ },
}
-/// Encode a `RecordBatch` with explicit `EncoderOptions`.
-pub fn encode_record_batch_with_options<W: Write>(
- batch: &RecordBatch,
- out: &mut W,
- opts: &EncoderOptions,
-) -> Result<(), ArrowError> {
- let mut encoders = batch
- .schema()
- .fields()
- .iter()
- .zip(batch.columns())
- .map(|(field, array)| Ok((field.is_nullable(),
make_encoder(array.as_ref())?)))
- .collect::<Result<Vec<_>, ArrowError>>()?;
- (0..batch.num_rows()).try_for_each(|row| {
- encoders.iter_mut().try_for_each(|(is_nullable, enc)| {
- if *is_nullable {
- let is_null = enc.is_null(row);
- write_optional_branch(out, is_null, opts.impala_mode)?;
- if is_null {
- return Ok(());
+/// Arrow to Avro FieldEncoder:
+/// - Holds the inner `Encoder` (by value)
+/// - Carries the per-site nullability **state** as a single enum that
enforces invariants
+pub struct FieldEncoder<'a> {
+ encoder: Encoder<'a>,
+ null_state: NullState,
+}
+
+impl<'a> FieldEncoder<'a> {
+ fn make_encoder(
+ array: &'a dyn Array,
+ field: &Field,
+ plan: &FieldPlan,
+ nullability: Option<Nullability>,
+ ) -> Result<Self, ArrowError> {
+ let has_nulls = array.null_count() > 0;
+ let encoder = match plan {
+ FieldPlan::Struct { encoders } => {
+ let arr = array
+ .as_any()
+ .downcast_ref::<StructArray>()
+ .ok_or_else(|| ArrowError::SchemaError("Expected
StructArray".into()))?;
+ Encoder::Struct(Box::new(StructEncoder::try_new(arr,
encoders)?))
+ }
+ FieldPlan::List {
+ items_nullability,
+ item_plan,
+ } => match array.data_type() {
+ DataType::List(_) => {
+ let arr = array
+ .as_any()
+ .downcast_ref::<ListArray>()
+ .ok_or_else(|| ArrowError::SchemaError("Expected
ListArray".into()))?;
+ Encoder::List(Box::new(ListEncoder32::try_new(
+ arr,
+ *items_nullability,
+ item_plan.as_ref(),
+ )?))
+ }
+ DataType::LargeList(_) => {
+ let arr = array
+ .as_any()
+ .downcast_ref::<LargeListArray>()
+ .ok_or_else(|| ArrowError::SchemaError("Expected
LargeListArray".into()))?;
+ Encoder::LargeList(Box::new(ListEncoder64::try_new(
+ arr,
+ *items_nullability,
+ item_plan.as_ref(),
+ )?))
+ }
+ other => {
+ return Err(ArrowError::SchemaError(format!(
+ "Avro array site requires Arrow List/LargeList, found:
{other:?}"
+ )))
+ }
+ },
+ FieldPlan::Scalar => match array.data_type() {
+ DataType::Boolean =>
Encoder::Boolean(BooleanEncoder(array.as_boolean())),
+ DataType::Utf8 => {
+
Encoder::Utf8(Utf8GenericEncoder::<i32>(array.as_string::<i32>()))
+ }
+ DataType::LargeUtf8 => {
+
Encoder::Utf8Large(Utf8GenericEncoder::<i64>(array.as_string::<i64>()))
+ }
+ DataType::Int32 =>
Encoder::Int(IntEncoder(array.as_primitive::<Int32Type>())),
+ DataType::Int64 =>
Encoder::Long(LongEncoder(array.as_primitive::<Int64Type>())),
+ DataType::Float32 => {
+
Encoder::Float32(F32Encoder(array.as_primitive::<Float32Type>()))
+ }
+ DataType::Float64 => {
+
Encoder::Float64(F64Encoder(array.as_primitive::<Float64Type>()))
+ }
+ DataType::Binary =>
Encoder::Binary(BinaryEncoder(array.as_binary::<i32>())),
+ DataType::LargeBinary => {
+
Encoder::LargeBinary(BinaryEncoder(array.as_binary::<i64>()))
+ }
+ DataType::Timestamp(TimeUnit::Microsecond, _) =>
Encoder::Timestamp(LongEncoder(
+ array.as_primitive::<TimestampMicrosecondType>(),
+ )),
+ other => {
+ return Err(ArrowError::NotYetImplemented(format!(
+ "Avro scalar type not yet supported: {other:?}"
+ )));
+ }
+ },
+ other => {
+ return Err(ArrowError::NotYetImplemented(
+ "Avro writer: {other:?} not yet supported".into(),
+ ));
+ }
+ };
+ // Compute the effective null state from writer-declared nullability
and data nulls.
+ let null_state = match (nullability, has_nulls) {
+ (None, false) => NullState::NonNullable,
+ (None, true) => {
+ return Err(ArrowError::InvalidArgumentError(format!(
+ "Avro site '{}' is non-nullable, but array contains nulls",
+ field.name()
+ )));
+ }
+ (Some(order), false) => {
+ // Optimization: drop any bitmap; emit a constant "value"
branch byte.
+ let byte = union_value_branch_byte(order, false);
+ NullState::NullableNoNulls { byte }
+ }
+ (Some(null_order), true) => {
+ let null_buffer = array.nulls().cloned().ok_or_else(|| {
+ ArrowError::InvalidArgumentError(format!(
+ "Array for Avro site '{}' reports nulls but has no
null buffer",
+ field.name()
+ ))
+ })?;
+ NullState::Nullable {
+ nulls: null_buffer,
+ null_order,
}
}
- enc.encode(row, out)
+ };
+ Ok(Self {
+ encoder,
+ null_state,
})
- })
+ }
+
+ fn encode<W: Write + ?Sized>(&mut self, idx: usize, out: &mut W) ->
Result<(), ArrowError> {
+ match &self.null_state {
+ NullState::NonNullable => self.encoder.encode(idx, out),
+ NullState::NullableNoNulls { byte } => {
+ // Constant non-null branch byte, then value.
+ out.write_all(&[*byte]).map_err(|e| {
+ ArrowError::IoError(format!("write union value branch:
{e}"), e)
+ })?;
+ self.encoder.encode(idx, out)
+ }
+ NullState::Nullable { nulls, null_order } => {
+ let is_null = nulls.is_null(idx);
+ write_optional_index(out, is_null, *null_order)?;
+ if is_null {
+ Ok(())
+ } else {
+ self.encoder.encode(idx, out)
+ }
+ }
+ }
+ }
+}
+
+fn union_value_branch_byte(null_order: Nullability, is_null: bool) -> u8 {
+ let nulls_first = null_order == Nullability::default();
+ if nulls_first == is_null {
+ 0x00
+ } else {
+ 0x02
+ }
+}
+
+/// Per‑site encoder plan for a field. This mirrors the Avro structure, so
nested
+/// optional branch order can be honored exactly as declared by the schema.
+#[derive(Debug, Clone)]
+enum FieldPlan {
+ /// Non-nested scalar/logical type
+ Scalar,
+ /// Record/Struct with Avro‑ordered children
+ Struct { encoders: Vec<FieldBinding> },
+ /// Array with item‑site nullability and nested plan
+ List {
+ items_nullability: Option<Nullability>,
+ item_plan: Box<FieldPlan>,
+ },
+}
+
+#[derive(Debug, Clone)]
+struct FieldBinding {
+ /// Index of the Arrow field/column associated with this Avro field site
+ arrow_index: usize,
+ /// Nullability/order for this site (None if required)
+ nullability: Option<Nullability>,
+ /// Nested plan for this site
+ plan: FieldPlan,
+}
+
+/// Builder for `RecordEncoder` write plan
+#[derive(Debug)]
+pub struct RecordEncoderBuilder<'a> {
+ avro_root: &'a AvroField,
+ arrow_schema: &'a ArrowSchema,
+}
+
+impl<'a> RecordEncoderBuilder<'a> {
+ /// Create a new builder from the Avro root and Arrow schema.
+ pub fn new(avro_root: &'a AvroField, arrow_schema: &'a ArrowSchema) ->
Self {
+ Self {
+ avro_root,
+ arrow_schema,
+ }
+ }
+
+ /// Build the `RecordEncoder` by walking the Avro **record** root in Avro
order,
+ /// resolving each field to an Arrow index by name.
+ pub fn build(self) -> Result<RecordEncoder, ArrowError> {
+ let avro_root_dt = self.avro_root.data_type();
+ let root_fields = match avro_root_dt.codec() {
+ Codec::Struct(fields) => fields,
+ _ => {
+ return Err(ArrowError::SchemaError(
+ "Top-level Avro schema must be a record/struct".into(),
+ ))
+ }
+ };
+ let mut columns = Vec::with_capacity(root_fields.len());
+ for root_field in root_fields.iter() {
Review Comment:
You're right, this is a good catch.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: github-unsubscribe@arrow.apache.org
For queries about this service, please contact Infrastructure at:
users@infra.apache.org