This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/main by this push:
new d74d9baff6 Adds Map & Enum support, round-trip & benchmark tests (#8353)
d74d9baff6 is described below
commit d74d9baff62ad5a61d50f6b13577274e0356aa90
Author: nathaniel-d-ef <[email protected]>
AuthorDate: Wed Sep 17 17:23:44 2025 +0200
Adds Map & Enum support, round-trip & benchmark tests (#8353)
# Which issue does this PR close?
- Part of https://github.com/apache/arrow-rs/issues/4886
- Related to https://github.com/apache/arrow-rs/pull/8274 and https://github.com/apache/arrow-rs/pull/8298
# Rationale for this change
This PR adds Map and Enum encoders to the arrow-avro writer, along with
round-trip tests and benchmark coverage for the remaining data types.
# What changes are included in this PR?
New encoders:
- **Map**
- **Enum**

Corresponding changes in `FieldEncoder` and `FieldPlan` to support these
encoders (an illustrative sketch of the enum input shape follows).
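For orientation, here is a minimal sketch of the input shape the Enum encoder expects, distilled from the benchmark and test code in the diff below. It is not part of this commit, and the helper name `enum_batch` is invented: an Avro enum column is an Arrow `Dictionary<Int32, Utf8>` field whose dictionary values equal the Avro symbols, in order, declared via `avro.enum.symbols` field metadata.

```rust
use std::{collections::HashMap, sync::Arc};

use arrow_array::{
    types::Int32Type, ArrayRef, DictionaryArray, Int32Array, RecordBatch, StringArray,
};
use arrow_schema::{DataType, Field, Schema};

// Sketch only; `enum_batch` is a hypothetical helper, not part of the commit.
fn enum_batch() -> RecordBatch {
    // Declare the Avro enum symbols on the field metadata; the writer
    // validates that the dictionary values match them exactly, in order.
    let mut metadata = HashMap::new();
    metadata.insert(
        "avro.enum.symbols".to_string(),
        r#"["RED", "GREEN", "BLUE"]"#.to_string(),
    );
    let dict_type = DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8));
    let field = Field::new("color_enum", dict_type, false).with_metadata(metadata);
    let schema = Arc::new(Schema::new(vec![field]));
    // Keys are zero-based symbol indices; Avro encodes each row as that int.
    let keys = Int32Array::from(vec![2, 0, 1]); // BLUE, RED, GREEN
    let values: ArrayRef = Arc::new(StringArray::from(vec!["RED", "GREEN", "BLUE"]));
    let dict = DictionaryArray::<Int32Type>::try_new(keys, values).unwrap();
    RecordBatch::try_new(schema, vec![Arc::new(dict) as ArrayRef]).unwrap()
}
```

Per the Avro spec, each row is then written as the int index of its symbol, which is exactly the dictionary key.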
## Additional round trip tests in `mod.rs`
New tests follow the existing file-read pattern:
- simple_fixed
- duration_uuid
- nonnullable.impala.avro
- decimals
- enum
## Additional benchmark tests for data types
- Utf8
- List<Utf8>
- Struct
- FixedSizeBinary16
- UUID
- IntervalMonthDayNanoDuration
- Decimal32(bytes)
- Decimal64(bytes)
- Decimal128(bytes)
- Decimal128(fixed16)
- Decimal256(bytes)
- Map
- Enum
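For the Map benchmark above, the input is a standard Arrow `MapArray` with string keys, as built by the `MAP_DATA` fixture in the diff below. A minimal sketch of constructing such a batch (the helper name `map_batch` is invented; `MapBuilder`'s default field names "entries"/"keys"/"values" are assumed):

```rust
use std::sync::Arc;

use arrow_array::{
    builder::{MapBuilder, StringBuilder},
    ArrayRef, RecordBatch,
};
use arrow_schema::{Field, Schema};

// Sketch only; `map_batch` is a hypothetical helper, not part of the commit.
fn map_batch() -> RecordBatch {
    // Avro maps require string keys; values here are strings too.
    let mut builder = MapBuilder::new(None, StringBuilder::new(), StringBuilder::new());
    builder.keys().append_value("k1");
    builder.values().append_value("v1");
    builder.keys().append_value("k2");
    builder.values().append_value("v2");
    builder.append(true).unwrap(); // row 0: {"k1": "v1", "k2": "v2"}
    builder.append(true).unwrap(); // row 1: {}
    let map = builder.finish();
    // Derive the field type from the array to keep schema and data in sync.
    let schema = Arc::new(Schema::new(vec![Field::new(
        "field1",
        map.data_type().clone(),
        false,
    )]));
    RecordBatch::try_new(schema, vec![Arc::new(map) as ArrayRef]).unwrap()
}
```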
# Are these changes tested?
Yes. New complex-type unit tests have been added for Map and Enum, and the
remaining changes are themselves tests. All tests, new and existing, pass.
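All of the new round-trip tests follow the same write-then-read-back pattern. A condensed sketch of the in-memory variant (APIs as used in the `writer/mod.rs` tests in the diff below; the helper name `assert_round_trip` is invented):

```rust
use std::io::Cursor;

use arrow_array::RecordBatch;
use arrow_avro::{reader::ReaderBuilder, writer::AvroWriter};
use arrow_schema::ArrowError;

// Sketch only; `assert_round_trip` is a hypothetical helper.
fn assert_round_trip(original: &RecordBatch) -> Result<(), ArrowError> {
    // Write the batch to an in-memory Avro Object Container File.
    let mut writer = AvroWriter::new(Vec::<u8>::new(), original.schema().as_ref().clone())?;
    writer.write(original)?;
    writer.finish()?;
    let bytes = writer.into_inner();
    // Read it back and compare schema + data for exact equality.
    let mut reader = ReaderBuilder::new().build(Cursor::new(bytes))?;
    let schema = reader.schema();
    let batches = reader.collect::<Result<Vec<_>, _>>()?;
    let round_trip = arrow::compute::concat_batches(&schema, &batches)?;
    assert_eq!(&round_trip, original);
    Ok(())
}
```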
# Are there any user-facing changes?
N/A; the arrow-avro crate is not yet public.
---------
Co-authored-by: Connor Sanders <[email protected]>
Co-authored-by: Andrew Lamb <[email protected]>
---
arrow-avro/benches/avro_writer.rs | 456 +++++++++++++++++++++++++++++++++++++-
arrow-avro/src/writer/encoder.rs | 309 ++++++++++++++++++++++++++
arrow-avro/src/writer/mod.rs | 218 ++++++++++++++++++
3 files changed, 976 insertions(+), 7 deletions(-)
diff --git a/arrow-avro/benches/avro_writer.rs b/arrow-avro/benches/avro_writer.rs
index 924cbbdc84..aeb9edbac8 100644
--- a/arrow-avro/benches/avro_writer.rs
+++ b/arrow-avro/benches/avro_writer.rs
@@ -15,19 +15,22 @@
// specific language governing permissions and limitations
// under the License.
-//! Benchmarks for `arrow‑avro` **Writer** (Avro Object Container Files)
-//!
+//! Benchmarks for `arrow-avro` Writer (Avro Object Container File)
extern crate arrow_avro;
extern crate criterion;
extern crate once_cell;
use arrow_array::{
- types::{Int32Type, Int64Type, TimestampMicrosecondType},
- ArrayRef, BinaryArray, BooleanArray, Float32Array, Float64Array, PrimitiveArray, RecordBatch,
+ builder::{ListBuilder, StringBuilder},
+ types::{Int32Type, Int64Type, IntervalMonthDayNanoType, TimestampMicrosecondType},
+ ArrayRef, BinaryArray, BooleanArray, Decimal128Array, Decimal256Array, Decimal32Array,
+ Decimal64Array, FixedSizeBinaryArray, Float32Array, Float64Array, ListArray, PrimitiveArray,
+ RecordBatch, StringArray, StructArray,
};
use arrow_avro::writer::AvroWriter;
-use arrow_schema::{DataType, Field, Schema, TimeUnit};
+use arrow_buffer::i256;
+use arrow_schema::{DataType, Field, IntervalUnit, Schema, TimeUnit};
use criterion::{criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion, Throughput};
use once_cell::sync::Lazy;
use rand::{
@@ -35,6 +38,7 @@ use rand::{
rngs::StdRng,
Rng, SeedableRng,
};
+use std::collections::HashMap;
use std::io::Cursor;
use std::sync::Arc;
use std::time::Duration;
@@ -63,7 +67,9 @@ where
fn make_bool_array_with_tag(n: usize, tag: u64) -> BooleanArray {
let mut rng = rng_for(tag, n);
+ // Can't use SampleUniform for bool; use the RNG's boolean helper
let values = (0..n).map(|_| rng.random_bool(0.5));
+ // This repo exposes `from_iter`, not `from_iter_values` for BooleanArray
BooleanArray::from_iter(values.map(Some))
}
@@ -81,6 +87,21 @@ fn make_i64_array_with_tag(n: usize, tag: u64) -> PrimitiveArray<Int64Type> {
PrimitiveArray::<Int64Type>::from_iter_values(values)
}
+#[inline]
+fn rand_ascii_string(rng: &mut StdRng, min_len: usize, max_len: usize) -> String {
+ let len = rng.random_range(min_len..=max_len);
+ (0..len)
+ .map(|_| (rng.random_range(b'a'..=b'z') as char))
+ .collect()
+}
+
+#[inline]
+fn make_utf8_array_with_tag(n: usize, tag: u64) -> StringArray {
+ let mut rng = rng_for(tag, n);
+ let data: Vec<String> = (0..n).map(|_| rand_ascii_string(&mut rng, 3, 16)).collect();
+ StringArray::from_iter_values(data)
+}
+
#[inline]
fn make_f32_array_with_tag(n: usize, tag: u64) -> Float32Array {
let mut rng = rng_for(tag, n);
@@ -98,14 +119,52 @@ fn make_f64_array_with_tag(n: usize, tag: u64) -> Float64Array {
#[inline]
fn make_binary_array_with_tag(n: usize, tag: u64) -> BinaryArray {
let mut rng = rng_for(tag, n);
- let mut payloads: Vec<[u8; 16]> = vec![[0; 16]; n];
- for p in payloads.iter_mut() {
+ let mut payloads: Vec<Vec<u8>> = Vec::with_capacity(n);
+ for _ in 0..n {
+ let len = rng.random_range(1..=16);
+ let mut p = vec![0u8; len];
rng.fill(&mut p[..]);
+ payloads.push(p);
}
let views: Vec<&[u8]> = payloads.iter().map(|p| &p[..]).collect();
+ // This repo exposes a simple `from_vec` for BinaryArray
BinaryArray::from_vec(views)
}
+#[inline]
+fn make_fixed16_array_with_tag(n: usize, tag: u64) -> FixedSizeBinaryArray {
+ let mut rng = rng_for(tag, n);
+ let payloads = (0..n)
+ .map(|_| {
+ let mut b = [0u8; 16];
+ rng.fill(&mut b);
+ b
+ })
+ .collect::<Vec<[u8; 16]>>();
+ // Fixed-size constructor available in this repo
+ FixedSizeBinaryArray::try_from_iter(payloads.into_iter()).expect("build FixedSizeBinaryArray")
+}
+
+/// Make an Arrow `Interval(IntervalUnit::MonthDayNano)` array with **non-negative**
+/// (months, days, nanos) values, and nanos as **multiples of 1_000_000** (whole ms),
+/// per Avro `duration` constraints used by the writer.
+#[inline]
+fn make_interval_mdn_array_with_tag(
+ n: usize,
+ tag: u64,
+) -> PrimitiveArray<IntervalMonthDayNanoType> {
+ let mut rng = rng_for(tag, n);
+ let values = (0..n).map(|_| {
+ let months: i32 = rng.random_range(0..=120);
+ let days: i32 = rng.random_range(0..=31);
+ // pick millis within a day (safe within u32::MAX and realistic)
+ let millis: u32 = rng.random_range(0..=86_400_000);
+ let nanos: i64 = (millis as i64) * 1_000_000;
+ IntervalMonthDayNanoType::make_value(months, days, nanos)
+ });
+ PrimitiveArray::<IntervalMonthDayNanoType>::from_iter_values(values)
+}
+
#[inline]
fn make_ts_micros_array_with_tag(n: usize, tag: u64) -> PrimitiveArray<TimestampMicrosecondType> {
let mut rng = rng_for(tag, n);
@@ -115,6 +174,77 @@ fn make_ts_micros_array_with_tag(n: usize, tag: u64) -> PrimitiveArray<Timestamp
PrimitiveArray::<TimestampMicrosecondType>::from_iter_values(values)
}
+// === Decimal helpers & generators ===
+
+#[inline]
+fn pow10_i32(p: u8) -> i32 {
+ (0..p).fold(1i32, |acc, _| acc.saturating_mul(10))
+}
+
+#[inline]
+fn pow10_i64(p: u8) -> i64 {
+ (0..p).fold(1i64, |acc, _| acc.saturating_mul(10))
+}
+
+#[inline]
+fn pow10_i128(p: u8) -> i128 {
+ (0..p).fold(1i128, |acc, _| acc.saturating_mul(10))
+}
+
+#[inline]
+fn make_decimal32_array_with_tag(n: usize, tag: u64, precision: u8, scale: i8) -> Decimal32Array {
+ let mut rng = rng_for(tag, n);
+ let max = pow10_i32(precision).saturating_sub(1);
+ let values = (0..n).map(|_| rng.random_range(-max..=max));
+ Decimal32Array::from_iter_values(values)
+ .with_precision_and_scale(precision, scale)
+ .expect("set precision/scale on Decimal32Array")
+}
+
+#[inline]
+fn make_decimal64_array_with_tag(n: usize, tag: u64, precision: u8, scale: i8) -> Decimal64Array {
+ let mut rng = rng_for(tag, n);
+ let max = pow10_i64(precision).saturating_sub(1);
+ let values = (0..n).map(|_| rng.random_range(-max..=max));
+ Decimal64Array::from_iter_values(values)
+ .with_precision_and_scale(precision, scale)
+ .expect("set precision/scale on Decimal64Array")
+}
+
+#[inline]
+fn make_decimal128_array_with_tag(n: usize, tag: u64, precision: u8, scale: i8) -> Decimal128Array {
+ let mut rng = rng_for(tag, n);
+ let max = pow10_i128(precision).saturating_sub(1);
+ let values = (0..n).map(|_| rng.random_range(-max..=max));
+ Decimal128Array::from_iter_values(values)
+ .with_precision_and_scale(precision, scale)
+ .expect("set precision/scale on Decimal128Array")
+}
+
+#[inline]
+fn make_decimal256_array_with_tag(n: usize, tag: u64, precision: u8, scale: i8) -> Decimal256Array {
+ // Generate within i128 range and widen to i256 to keep generation cheap and portable
+ let mut rng = rng_for(tag, n);
+ let max128 = pow10_i128(30).saturating_sub(1);
+ let values = (0..n).map(|_| {
+ let v: i128 = rng.random_range(-max128..=max128);
+ i256::from_i128(v)
+ });
+ Decimal256Array::from_iter_values(values)
+ .with_precision_and_scale(precision, scale)
+ .expect("set precision/scale on Decimal256Array")
+}
+
+#[inline]
+fn make_fixed16_array(n: usize) -> FixedSizeBinaryArray {
+ make_fixed16_array_with_tag(n, 0xF15E_D016)
+}
+
+#[inline]
+fn make_interval_mdn_array(n: usize) -> PrimitiveArray<IntervalMonthDayNanoType> {
+ make_interval_mdn_array_with_tag(n, 0xD0_1E_AD)
+}
+
#[inline]
fn make_bool_array(n: usize) -> BooleanArray {
make_bool_array_with_tag(n, 0xB001)
@@ -143,6 +273,57 @@ fn make_binary_array(n: usize) -> BinaryArray {
fn make_ts_micros_array(n: usize) -> PrimitiveArray<TimestampMicrosecondType> {
make_ts_micros_array_with_tag(n, 0x7157_0001)
}
+#[inline]
+fn make_utf8_array(n: usize) -> StringArray {
+ make_utf8_array_with_tag(n, 0x5712_07F8)
+}
+#[inline]
+fn make_list_utf8_array(n: usize) -> ListArray {
+ make_list_utf8_array_with_tag(n, 0x0A11_57ED)
+}
+#[inline]
+fn make_struct_array(n: usize) -> StructArray {
+ make_struct_array_with_tag(n, 0x57_AB_C7)
+}
+
+#[inline]
+fn make_list_utf8_array_with_tag(n: usize, tag: u64) -> ListArray {
+ let mut rng = rng_for(tag, n);
+ let mut builder = ListBuilder::new(StringBuilder::new());
+ for _ in 0..n {
+ let items = rng.random_range(0..=5);
+ for _ in 0..items {
+ let s = rand_ascii_string(&mut rng, 1, 12);
+ builder.values().append_value(s.as_str());
+ }
+ builder.append(true);
+ }
+ builder.finish()
+}
+
+#[inline]
+fn make_struct_array_with_tag(n: usize, tag: u64) -> StructArray {
+ let s_tag = tag ^ 0x5u64;
+ let i_tag = tag ^ 0x6u64;
+ let f_tag = tag ^ 0x7u64;
+ let s_col: ArrayRef = Arc::new(make_utf8_array_with_tag(n, s_tag));
+ let i_col: ArrayRef = Arc::new(make_i32_array_with_tag(n, i_tag));
+ let f_col: ArrayRef = Arc::new(make_f64_array_with_tag(n, f_tag));
+ StructArray::from(vec![
+ (
+ Arc::new(Field::new("s1", DataType::Utf8, false)),
+ s_col.clone(),
+ ),
+ (
+ Arc::new(Field::new("s2", DataType::Int32, false)),
+ i_col.clone(),
+ ),
+ (
+ Arc::new(Field::new("s3", DataType::Float64, false)),
+ f_col.clone(),
+ ),
+ ])
+}
#[inline]
fn schema_single(name: &str, dt: DataType) -> Arc<Schema> {
@@ -159,6 +340,36 @@ fn schema_mixed() -> Arc<Schema> {
]))
}
+#[inline]
+fn schema_fixed16() -> Arc<Schema> {
+ schema_single("field1", DataType::FixedSizeBinary(16))
+}
+
+#[inline]
+fn schema_uuid16() -> Arc<Schema> {
+ let mut md = HashMap::new();
+ md.insert("logicalType".to_string(), "uuid".to_string());
+ let field = Field::new("uuid", DataType::FixedSizeBinary(16),
false).with_metadata(md);
+ Arc::new(Schema::new(vec![field]))
+}
+
+#[inline]
+fn schema_interval_mdn() -> Arc<Schema> {
+ schema_single("duration", DataType::Interval(IntervalUnit::MonthDayNano))
+}
+
+#[inline]
+fn schema_decimal_with_size(name: &str, dt: DataType, size_meta: Option<usize>) -> Arc<Schema> {
+ let field = if let Some(size) = size_meta {
+ let mut md = HashMap::new();
+ md.insert("size".to_string(), size.to_string());
+ Field::new(name, dt, false).with_metadata(md)
+ } else {
+ Field::new(name, dt, false)
+ };
+ Arc::new(Schema::new(vec![field]))
+}
+
static BOOLEAN_DATA: Lazy<Vec<RecordBatch>> = Lazy::new(|| {
let schema = schema_single("field1", DataType::Boolean);
SIZES
@@ -225,6 +436,40 @@ static BINARY_DATA: Lazy<Vec<RecordBatch>> = Lazy::new(|| {
.collect()
});
+static FIXED16_DATA: Lazy<Vec<RecordBatch>> = Lazy::new(|| {
+ let schema = schema_fixed16();
+ SIZES
+ .iter()
+ .map(|&n| {
+ let col: ArrayRef = Arc::new(make_fixed16_array(n));
+ RecordBatch::try_new(schema.clone(), vec![col]).unwrap()
+ })
+ .collect()
+});
+
+static UUID16_DATA: Lazy<Vec<RecordBatch>> = Lazy::new(|| {
+ let schema = schema_uuid16();
+ SIZES
+ .iter()
+ .map(|&n| {
+ // Same values as Fixed16; writer path differs because of field metadata
+ let col: ArrayRef = Arc::new(make_fixed16_array_with_tag(n, 0x7575_6964_7575_6964));
+ RecordBatch::try_new(schema.clone(), vec![col]).unwrap()
+ })
+ .collect()
+});
+
+static INTERVAL_MDN_DATA: Lazy<Vec<RecordBatch>> = Lazy::new(|| {
+ let schema = schema_interval_mdn();
+ SIZES
+ .iter()
+ .map(|&n| {
+ let col: ArrayRef = Arc::new(make_interval_mdn_array(n));
+ RecordBatch::try_new(schema.clone(), vec![col]).unwrap()
+ })
+ .collect()
+});
+
static TIMESTAMP_US_DATA: Lazy<Vec<RecordBatch>> = Lazy::new(|| {
let schema = schema_single("field1",
DataType::Timestamp(TimeUnit::Microsecond, None));
SIZES
@@ -250,6 +495,190 @@ static MIXED_DATA: Lazy<Vec<RecordBatch>> = Lazy::new(|| {
.collect()
});
+static UTF8_DATA: Lazy<Vec<RecordBatch>> = Lazy::new(|| {
+ let schema = schema_single("field1", DataType::Utf8);
+ SIZES
+ .iter()
+ .map(|&n| {
+ let col: ArrayRef = Arc::new(make_utf8_array(n));
+ RecordBatch::try_new(schema.clone(), vec![col]).unwrap()
+ })
+ .collect()
+});
+
+static LIST_UTF8_DATA: Lazy<Vec<RecordBatch>> = Lazy::new(|| {
+ // IMPORTANT: ListBuilder creates a child field named "item" that is nullable by default.
+ // Make the schema's list item nullable to match the array we construct.
+ let item_field = Arc::new(Field::new("item", DataType::Utf8, true));
+ let schema = schema_single("field1", DataType::List(item_field));
+ SIZES
+ .iter()
+ .map(|&n| {
+ let col: ArrayRef = Arc::new(make_list_utf8_array(n));
+ RecordBatch::try_new(schema.clone(), vec![col]).unwrap()
+ })
+ .collect()
+});
+
+static STRUCT_DATA: Lazy<Vec<RecordBatch>> = Lazy::new(|| {
+ let struct_dt = DataType::Struct(
+ vec![
+ Field::new("s1", DataType::Utf8, false),
+ Field::new("s2", DataType::Int32, false),
+ Field::new("s3", DataType::Float64, false),
+ ]
+ .into(),
+ );
+ let schema = schema_single("field1", struct_dt);
+ SIZES
+ .iter()
+ .map(|&n| {
+ let col: ArrayRef = Arc::new(make_struct_array(n));
+ RecordBatch::try_new(schema.clone(), vec![col]).unwrap()
+ })
+ .collect()
+});
+
+static DECIMAL32_DATA: Lazy<Vec<RecordBatch>> = Lazy::new(|| {
+ // Choose a representative precision/scale within Decimal32 limits
+ let precision: u8 = 7;
+ let scale: i8 = 2;
+ let schema = schema_single("amount", DataType::Decimal32(precision,
scale));
+ SIZES
+ .iter()
+ .map(|&n| {
+ let arr = make_decimal32_array_with_tag(n, 0xDEC_0032, precision,
scale);
+ let col: ArrayRef = Arc::new(arr);
+ RecordBatch::try_new(schema.clone(), vec![col]).unwrap()
+ })
+ .collect()
+});
+
+static DECIMAL64_DATA: Lazy<Vec<RecordBatch>> = Lazy::new(|| {
+ let precision: u8 = 13;
+ let scale: i8 = 3;
+ let schema = schema_single("amount", DataType::Decimal64(precision,
scale));
+ SIZES
+ .iter()
+ .map(|&n| {
+ let arr = make_decimal64_array_with_tag(n, 0xDEC_0064, precision,
scale);
+ let col: ArrayRef = Arc::new(arr);
+ RecordBatch::try_new(schema.clone(), vec![col]).unwrap()
+ })
+ .collect()
+});
+
+static DECIMAL128_BYTES_DATA: Lazy<Vec<RecordBatch>> = Lazy::new(|| {
+ let precision: u8 = 25;
+ let scale: i8 = 6;
+ let schema = schema_single("amount", DataType::Decimal128(precision,
scale));
+ SIZES
+ .iter()
+ .map(|&n| {
+ let arr = make_decimal128_array_with_tag(n, 0xDEC_0128, precision,
scale);
+ let col: ArrayRef = Arc::new(arr);
+ RecordBatch::try_new(schema.clone(), vec![col]).unwrap()
+ })
+ .collect()
+});
+
+static DECIMAL128_FIXED16_DATA: Lazy<Vec<RecordBatch>> = Lazy::new(|| {
+ // Same logical type as above but force Avro fixed(16) via metadata "size": "16"
+ let precision: u8 = 25;
+ let scale: i8 = 6;
+ let schema =
+ schema_decimal_with_size("amount", DataType::Decimal128(precision, scale), Some(16));
+ SIZES
+ .iter()
+ .map(|&n| {
+ let arr = make_decimal128_array_with_tag(n, 0xDEC_F128, precision, scale);
+ let col: ArrayRef = Arc::new(arr);
+ RecordBatch::try_new(schema.clone(), vec![col]).unwrap()
+ })
+ .collect()
+});
+
+static DECIMAL256_DATA: Lazy<Vec<RecordBatch>> = Lazy::new(|| {
+ // Use a higher precision typical of 256-bit decimals
+ let precision: u8 = 50;
+ let scale: i8 = 10;
+ let schema = schema_single("amount", DataType::Decimal256(precision,
scale));
+ SIZES
+ .iter()
+ .map(|&n| {
+ let arr = make_decimal256_array_with_tag(n, 0xDEC_0256, precision,
scale);
+ let col: ArrayRef = Arc::new(arr);
+ RecordBatch::try_new(schema.clone(), vec![col]).unwrap()
+ })
+ .collect()
+});
+
+static MAP_DATA: Lazy<Vec<RecordBatch>> = Lazy::new(|| {
+ use arrow_array::builder::{MapBuilder, StringBuilder};
+
+ let key_field = Arc::new(Field::new("keys", DataType::Utf8, false));
+ let value_field = Arc::new(Field::new("values", DataType::Utf8, true));
+ let entry_struct = Field::new(
+ "entries",
+ DataType::Struct(vec![key_field.as_ref().clone(), value_field.as_ref().clone()].into()),
+ false,
+ );
+ let map_dt = DataType::Map(Arc::new(entry_struct), false);
+ let schema = schema_single("field1", map_dt);
+
+ SIZES
+ .iter()
+ .map(|&n| {
+ // Build a MapArray with n rows
+ let mut builder = MapBuilder::new(None, StringBuilder::new(), StringBuilder::new());
+ let mut rng = rng_for(0x00D0_0D1A, n);
+ for _ in 0..n {
+ let entries = rng.random_range(0..=5);
+ for _ in 0..entries {
+ let k = rand_ascii_string(&mut rng, 3, 10);
+ let v = rand_ascii_string(&mut rng, 0, 12);
+ // keys are non-nullable; values may be nullable, but we append only non-null values here
+ builder.keys().append_value(k);
+ builder.values().append_value(v);
+ }
+ builder.append(true).expect("Error building MapArray");
+ }
+ let col: ArrayRef = Arc::new(builder.finish());
+ RecordBatch::try_new(schema.clone(), vec![col]).unwrap()
+ })
+ .collect()
+});
+
+static ENUM_DATA: Lazy<Vec<RecordBatch>> = Lazy::new(|| {
+ // To represent an Avro enum, the Arrow writer expects a Dictionary<Int32, Utf8>
+ // field with metadata specifying the enum symbols.
+ let enum_symbols = r#"["RED", "GREEN", "BLUE"]"#;
+ let mut metadata = HashMap::new();
+ metadata.insert("avro.enum.symbols".to_string(), enum_symbols.to_string());
+
+ let dict_type = DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8));
+ let field = Field::new("color_enum", dict_type, false).with_metadata(metadata);
+ let schema = Arc::new(Schema::new(vec![field]));
+
+ let dict_values: ArrayRef = Arc::new(StringArray::from(vec!["RED", "GREEN", "BLUE"]));
+
+ SIZES
+ .iter()
+ .map(|&n| {
+ use arrow_array::DictionaryArray;
+ let mut rng = rng_for(0x3A7A, n);
+ let keys_vec: Vec<i32> = (0..n).map(|_| rng.random_range(0..=2)).collect();
+ let keys = PrimitiveArray::<Int32Type>::from(keys_vec);
+
+ let dict_array =
+ DictionaryArray::<Int32Type>::try_new(keys, dict_values.clone()).unwrap();
+ let col: ArrayRef = Arc::new(dict_array);
+
+ RecordBatch::try_new(schema.clone(), vec![col]).unwrap()
+ })
+ .collect()
+});
+
fn ocf_size_for_batch(batch: &RecordBatch) -> usize {
let schema_owned: Schema = (*batch.schema()).clone();
let cursor = Cursor::new(Vec::<u8>::with_capacity(1024));
@@ -314,6 +743,19 @@ fn criterion_benches(c: &mut Criterion) {
bench_writer_scenario(c, "write-Binary(Bytes)", &BINARY_DATA);
bench_writer_scenario(c, "write-TimestampMicros", &TIMESTAMP_US_DATA);
bench_writer_scenario(c, "write-Mixed", &MIXED_DATA);
+ bench_writer_scenario(c, "write-Utf8", &UTF8_DATA);
+ bench_writer_scenario(c, "write-List<Utf8>", &LIST_UTF8_DATA);
+ bench_writer_scenario(c, "write-Struct", &STRUCT_DATA);
+ bench_writer_scenario(c, "write-FixedSizeBinary16", &FIXED16_DATA);
+ bench_writer_scenario(c, "write-UUID(logicalType)", &UUID16_DATA);
+ bench_writer_scenario(c, "write-IntervalMonthDayNanoDuration",
&INTERVAL_MDN_DATA);
+ bench_writer_scenario(c, "write-Decimal32(bytes)", &DECIMAL32_DATA);
+ bench_writer_scenario(c, "write-Decimal64(bytes)", &DECIMAL64_DATA);
+ bench_writer_scenario(c, "write-Decimal128(bytes)",
&DECIMAL128_BYTES_DATA);
+ bench_writer_scenario(c, "write-Decimal128(fixed16)",
&DECIMAL128_FIXED16_DATA);
+ bench_writer_scenario(c, "write-Decimal256(bytes)", &DECIMAL256_DATA);
+ bench_writer_scenario(c, "write-Map", &MAP_DATA);
+ bench_writer_scenario(c, "write-Enum", &ENUM_DATA);
}
criterion_group! {
diff --git a/arrow-avro/src/writer/encoder.rs b/arrow-avro/src/writer/encoder.rs
index d80a3e739a..fd61924961 100644
--- a/arrow-avro/src/writer/encoder.rs
+++ b/arrow-avro/src/writer/encoder.rs
@@ -363,6 +363,60 @@ impl<'a> FieldEncoder<'a> {
.ok_or_else(|| ArrowError::SchemaError("Expected
FixedSizeBinaryArray".into()))?;
Encoder::Uuid(UuidEncoder(arr))
}
+ FieldPlan::Map { values_nullability, value_plan } => {
+ let arr = array
+ .as_any()
+ .downcast_ref::<MapArray>()
+ .ok_or_else(|| ArrowError::SchemaError("Expected MapArray".into()))?;
+ Encoder::Map(Box::new(MapEncoder::try_new(arr, *values_nullability, value_plan.as_ref())?))
+ }
+ FieldPlan::Enum { symbols } => match array.data_type() {
+ DataType::Dictionary(key_dt, value_dt) => {
+ if **key_dt != DataType::Int32 || **value_dt != DataType::Utf8 {
+ return Err(ArrowError::SchemaError(
+ "Avro enum requires Dictionary<Int32, Utf8>".into(),
+ ));
+ }
+ let dict = array
+ .as_any()
+ .downcast_ref::<DictionaryArray<Int32Type>>()
+ .ok_or_else(|| {
+ ArrowError::SchemaError("Expected
DictionaryArray<Int32>".into())
+ })?;
+
+ let values = dict
+ .values()
+ .as_any()
+ .downcast_ref::<StringArray>()
+ .ok_or_else(|| {
+ ArrowError::SchemaError("Dictionary values must be
Utf8".into())
+ })?;
+ if values.len() != symbols.len() {
+ return Err(ArrowError::SchemaError(format!(
+ "Enum symbol length {} != dictionary size {}",
+ symbols.len(),
+ values.len()
+ )));
+ }
+ for i in 0..values.len() {
+ if values.value(i) != symbols[i].as_str() {
+ return Err(ArrowError::SchemaError(format!(
+ "Enum symbol mismatch at {i}: schema='{}'
dict='{}'",
+ symbols[i],
+ values.value(i)
+ )));
+ }
+ }
+ let keys = dict.keys();
+ Encoder::Enum(EnumEncoder { keys })
+ }
+ other => {
+ return Err(ArrowError::SchemaError(format!(
+ "Avro enum site requires DataType::Dictionary, found:
{other:?}"
+ )))
+ }
+ }
other => {
return Err(ArrowError::NotYetImplemented(format!(
"Avro writer: {other:?} not yet supported",
@@ -443,6 +497,14 @@ enum FieldPlan {
Decimal { size: Option<usize> },
/// Avro UUID logical type (fixed)
Uuid,
+ /// Avro map with value-site nullability and nested plan
+ Map {
+ values_nullability: Option<Nullability>,
+ value_plan: Box<FieldPlan>,
+ },
+ /// Avro enum; maps to Arrow Dictionary<Int32, Utf8> with dictionary values
+ /// exactly equal and ordered as the Avro enum `symbols`.
+ Enum { symbols: Arc<[String]> },
}
#[derive(Debug, Clone)]
@@ -631,6 +693,54 @@ impl FieldPlan {
"Avro array maps to Arrow List/LargeList, found: {other:?}"
))),
},
+ Codec::Map(values_dt) => {
+ let entries_field = match arrow_field.data_type() {
+ DataType::Map(entries, _sorted) => entries.as_ref(),
+ other => {
+ return Err(ArrowError::SchemaError(format!(
+ "Avro map maps to Arrow DataType::Map, found:
{other:?}"
+ )))
+ }
+ };
+ let entries_struct_fields = match entries_field.data_type() {
+ DataType::Struct(fs) => fs,
+ other => {
+ return Err(ArrowError::SchemaError(format!(
+ "Arrow Map entries must be Struct, found:
{other:?}"
+ )))
+ }
+ };
+ let value_idx =
+ find_map_value_field_index(entries_struct_fields).ok_or_else(|| {
+ ArrowError::SchemaError("Map entries struct missing value field".into())
+ })?;
+ let value_field = entries_struct_fields[value_idx].as_ref();
+ let value_plan = FieldPlan::build(values_dt.as_ref(), value_field)?;
+ Ok(FieldPlan::Map {
+ values_nullability: values_dt.nullability(),
+ value_plan: Box::new(value_plan),
+ })
+ }
+ Codec::Enum(symbols) => match arrow_field.data_type() {
+ DataType::Dictionary(key_dt, value_dt) => {
+ if **key_dt != DataType::Int32 {
+ return Err(ArrowError::SchemaError(
+ "Avro enum requires Dictionary<Int32,
Utf8>".into(),
+ ));
+ }
+ if **value_dt != DataType::Utf8 {
+ return Err(ArrowError::SchemaError(
+ "Avro enum requires Dictionary<Int32,
Utf8>".into(),
+ ));
+ }
+ Ok(FieldPlan::Enum {
+ symbols: symbols.clone(),
+ })
+ }
+ other => Err(ArrowError::SchemaError(format!(
+ "Avro enum maps to Arrow Dictionary<Int32, Utf8>, found:
{other:?}"
+ ))),
+ },
// decimal site (bytes or fixed(N)) with precision/scale validation
Codec::Decimal(precision, scale_opt, fixed_size_opt) => {
let (ap, as_) = match arrow_field.data_type() {
@@ -700,6 +810,9 @@ enum Encoder<'a> {
Decimal64(Decimal64Encoder<'a>),
Decimal128(Decimal128Encoder<'a>),
Decimal256(Decimal256Encoder<'a>),
+ /// Avro `enum` encoder: writes the key (int) as the enum index.
+ Enum(EnumEncoder<'a>),
+ Map(Box<MapEncoder<'a>>),
}
impl<'a> Encoder<'a> {
@@ -730,6 +843,8 @@ impl<'a> Encoder<'a> {
Encoder::Decimal64(e) => (e).encode(out, idx),
Encoder::Decimal128(e) => (e).encode(out, idx),
Encoder::Decimal256(e) => (e).encode(out, idx),
+ Encoder::Map(e) => (e).encode(out, idx),
+ Encoder::Enum(e) => (e).encode(out, idx),
}
}
}
@@ -795,6 +910,139 @@ impl<'a, O: OffsetSizeTrait> Utf8GenericEncoder<'a, O> {
type Utf8Encoder<'a> = Utf8GenericEncoder<'a, i32>;
type Utf8LargeEncoder<'a> = Utf8GenericEncoder<'a, i64>;
+
+/// Internal key array kind used by Map encoder.
+enum KeyKind<'a> {
+ Utf8(&'a GenericStringArray<i32>),
+ LargeUtf8(&'a GenericStringArray<i64>),
+}
+struct MapEncoder<'a> {
+ map: &'a MapArray,
+ keys: KeyKind<'a>,
+ values: FieldEncoder<'a>,
+ keys_offset: usize,
+ values_offset: usize,
+}
+
+impl<'a> MapEncoder<'a> {
+ fn try_new(
+ map: &'a MapArray,
+ values_nullability: Option<Nullability>,
+ value_plan: &FieldPlan,
+ ) -> Result<Self, ArrowError> {
+ let keys_arr = map.keys();
+ let keys_kind = match keys_arr.data_type() {
+ DataType::Utf8 => KeyKind::Utf8(keys_arr.as_string::<i32>()),
+ DataType::LargeUtf8 => KeyKind::LargeUtf8(keys_arr.as_string::<i64>()),
+ other => {
+ return Err(ArrowError::SchemaError(format!(
+ "Avro map requires string keys; Arrow key type must be
Utf8/LargeUtf8, found: {other:?}"
+ )))
+ }
+ };
+
+ let entries_struct_fields = match map.data_type() {
+ DataType::Map(entries, _) => match entries.data_type() {
+ DataType::Struct(fs) => fs,
+ other => {
+ return Err(ArrowError::SchemaError(format!(
+ "Arrow Map entries must be Struct, found: {other:?}"
+ )))
+ }
+ },
+ _ => {
+ return Err(ArrowError::SchemaError(
+ "Expected MapArray with DataType::Map".into(),
+ ))
+ }
+ };
+
+ let v_idx = find_map_value_field_index(entries_struct_fields).ok_or_else(|| {
+ ArrowError::SchemaError("Map entries struct missing value field".into())
+ })?;
+ let value_field = entries_struct_fields[v_idx].as_ref();
+
+ let values_enc = prepare_value_site_encoder(
+ map.values().as_ref(),
+ value_field,
+ values_nullability,
+ value_plan,
+ )?;
+
+ Ok(Self {
+ map,
+ keys: keys_kind,
+ values: values_enc,
+ keys_offset: keys_arr.offset(),
+ values_offset: map.values().offset(),
+ })
+ }
+
+ fn encode_map_entries<W, O>(
+ out: &mut W,
+ keys: &GenericStringArray<O>,
+ keys_offset: usize,
+ start: usize,
+ end: usize,
+ mut write_item: impl FnMut(&mut W, usize) -> Result<(), ArrowError>,
+ ) -> Result<(), ArrowError>
+ where
+ W: Write + ?Sized,
+ O: OffsetSizeTrait,
+ {
+ encode_blocked_range(out, start, end, |out, j| {
+ let j_key = j.saturating_sub(keys_offset);
+ write_len_prefixed(out, keys.value(j_key).as_bytes())?;
+ write_item(out, j)
+ })
+ }
+
+ fn encode<W: Write + ?Sized>(&mut self, out: &mut W, idx: usize) -> Result<(), ArrowError> {
+ let offsets = self.map.offsets();
+ let start = offsets[idx] as usize;
+ let end = offsets[idx + 1] as usize;
+
+ let mut write_item = |out: &mut W, j: usize| {
+ let j_val = j.saturating_sub(self.values_offset);
+ self.values.encode(out, j_val)
+ };
+
+ match self.keys {
+ KeyKind::Utf8(arr) => MapEncoder::<'a>::encode_map_entries(
+ out,
+ arr,
+ self.keys_offset,
+ start,
+ end,
+ write_item,
+ ),
+ KeyKind::LargeUtf8(arr) => MapEncoder::<'a>::encode_map_entries(
+ out,
+ arr,
+ self.keys_offset,
+ start,
+ end,
+ write_item,
+ ),
+ }
+ }
+}
+
+/// Avro `enum` encoder for Arrow `DictionaryArray<Int32, Utf8>`.
+///
+/// Per Avro spec, an enum is encoded as an **int** equal to the
+/// zero-based position of the symbol in the schema’s `symbols` list.
+/// We validate at construction that the dictionary values equal the symbols,
+/// so we can directly write the key value here.
+struct EnumEncoder<'a> {
+ keys: &'a PrimitiveArray<Int32Type>,
+}
+impl EnumEncoder<'_> {
+ fn encode<W: Write + ?Sized>(&mut self, out: &mut W, row: usize) -> Result<(), ArrowError> {
+ write_int(out, self.keys.value(row))
+ }
+}
+
struct StructEncoder<'a> {
encoders: Vec<FieldEncoder<'a>>,
}
@@ -1314,6 +1562,25 @@ mod tests {
assert_bytes_eq(&got, &expected);
}
+ #[test]
+ fn enum_encoder_dictionary() {
+ // symbols: ["A","B","C"], keys [2,0,1]
+ let dict_values = StringArray::from(vec!["A", "B", "C"]);
+ let keys = Int32Array::from(vec![2, 0, 1]);
+ let dict =
+ DictionaryArray::<Int32Type>::try_new(keys, Arc::new(dict_values) as ArrayRef).unwrap();
+ let symbols = Arc::<[String]>::from(
+ vec!["A".to_string(), "B".to_string(), "C".to_string()].into_boxed_slice(),
+ );
+ let plan = FieldPlan::Enum { symbols };
+ let got = encode_all(&dict, &plan, None);
+ let mut expected = Vec::new();
+ expected.extend(avro_long_bytes(2));
+ expected.extend(avro_long_bytes(0));
+ expected.extend(avro_long_bytes(1));
+ assert_bytes_eq(&got, &expected);
+ }
+
#[test]
fn decimal_bytes_and_fixed() {
// Use Decimal128 with small positives and negatives
@@ -1498,6 +1765,48 @@ mod tests {
}
}
+ #[test]
+ fn map_encoder_string_keys_int_values() {
+ // Build MapArray with two rows
+ // Row0: {"k1":1, "k2":2}
+ // Row1: {}
+ let keys = StringArray::from(vec!["k1", "k2"]);
+ let values = Int32Array::from(vec![1, 2]);
+ let entries_fields = Fields::from(vec![
+ Field::new("key", DataType::Utf8, false),
+ Field::new("value", DataType::Int32, true),
+ ]);
+ let entries = StructArray::new(
+ entries_fields,
+ vec![Arc::new(keys) as ArrayRef, Arc::new(values) as ArrayRef],
+ None,
+ );
+ let offsets = arrow_buffer::OffsetBuffer::new(vec![0i32, 2, 2].into());
+ let map = MapArray::new(
+ Field::new("entries", entries.data_type().clone(), false).into(),
+ offsets,
+ entries,
+ None,
+ false,
+ );
+ let plan = FieldPlan::Map {
+ values_nullability: None,
+ value_plan: Box::new(FieldPlan::Scalar),
+ };
+ let got = encode_all(&map, &plan, None);
+ let mut expected = Vec::new();
+ // Row0: block 2 then pairs
+ expected.extend(avro_long_bytes(2));
+ expected.extend(avro_len_prefixed_bytes(b"k1"));
+ expected.extend(avro_long_bytes(1));
+ expected.extend(avro_len_prefixed_bytes(b"k2"));
+ expected.extend(avro_long_bytes(2));
+ expected.extend(avro_long_bytes(0));
+ // Row1: empty
+ expected.extend(avro_long_bytes(0));
+ assert_bytes_eq(&got, &expected);
+ }
+
#[test]
fn list64_encoder_int32() {
// LargeList [[1,2,3], []]
diff --git a/arrow-avro/src/writer/mod.rs b/arrow-avro/src/writer/mod.rs
index a5b2691bb8..f5e84eeb50 100644
--- a/arrow-avro/src/writer/mod.rs
+++ b/arrow-avro/src/writer/mod.rs
@@ -415,4 +415,222 @@ mod tests {
);
Ok(())
}
+
+ #[test]
+ fn test_round_trip_simple_fixed_ocf() -> Result<(), ArrowError> {
+ let path = arrow_test_data("avro/simple_fixed.avro");
+ let rdr_file = File::open(&path).expect("open avro/simple_fixed.avro");
+ let mut reader = ReaderBuilder::new()
+ .build(BufReader::new(rdr_file))
+ .expect("build avro reader");
+ let schema = reader.schema();
+ let input_batches = reader.collect::<Result<Vec<_>, _>>()?;
+ let original =
+ arrow::compute::concat_batches(&schema, &input_batches).expect("concat input");
+ let tmp = NamedTempFile::new().expect("create temp file");
+ let out_file = File::create(tmp.path()).expect("create temp avro");
+ let mut writer = AvroWriter::new(out_file, original.schema().as_ref().clone())?;
+ writer.write(&original)?;
+ writer.finish()?;
+ drop(writer);
+ let rt_file = File::open(tmp.path()).expect("open round_trip avro");
+ let mut rt_reader = ReaderBuilder::new()
+ .build(BufReader::new(rt_file))
+ .expect("build round_trip reader");
+ let rt_schema = rt_reader.schema();
+ let rt_batches = rt_reader.collect::<Result<Vec<_>, _>>()?;
+ let round_trip =
+ arrow::compute::concat_batches(&rt_schema, &rt_batches).expect("concat round_trip");
+ assert_eq!(round_trip, original);
+ Ok(())
+ }
+
+ #[cfg(not(feature = "canonical_extension_types"))]
+ #[test]
+ fn test_round_trip_duration_and_uuid_ocf() -> Result<(), ArrowError> {
+ let in_file =
+ File::open("test/data/duration_uuid.avro").expect("open
test/data/duration_uuid.avro");
+ let mut reader = ReaderBuilder::new()
+ .build(BufReader::new(in_file))
+ .expect("build reader for duration_uuid.avro");
+ let in_schema = reader.schema();
+ let has_mdn = in_schema.fields().iter().any(|f| {
+ matches!(
+ f.data_type(),
+ DataType::Interval(IntervalUnit::MonthDayNano)
+ )
+ });
+ assert!(
+ has_mdn,
+ "expected at least one Interval(MonthDayNano) field in
duration_uuid.avro"
+ );
+ let has_uuid_fixed = in_schema
+ .fields()
+ .iter()
+ .any(|f| matches!(f.data_type(), DataType::FixedSizeBinary(16)));
+ assert!(
+ has_uuid_fixed,
+ "expected at least one FixedSizeBinary(16) (uuid) field in
duration_uuid.avro"
+ );
+ let input_batches = reader.collect::<Result<Vec<_>, _>>()?;
+ let input =
+ arrow::compute::concat_batches(&in_schema, &input_batches).expect("concat input");
+ let tmp = NamedTempFile::new().expect("create temp file");
+ {
+ let out_file = File::create(tmp.path()).expect("create temp avro");
+ let mut writer = AvroWriter::new(out_file, in_schema.as_ref().clone())?;
+ writer.write(&input)?;
+ writer.finish()?;
+ }
+ let rt_file = File::open(tmp.path()).expect("open round_trip avro");
+ let mut rt_reader = ReaderBuilder::new()
+ .build(BufReader::new(rt_file))
+ .expect("build round_trip reader");
+ let rt_schema = rt_reader.schema();
+ let rt_batches = rt_reader.collect::<Result<Vec<_>, _>>()?;
+ let round_trip =
+ arrow::compute::concat_batches(&rt_schema, &rt_batches).expect("concat round_trip");
+ assert_eq!(round_trip, input);
+ Ok(())
+ }
+
+ // This test reads the same 'nonnullable.impala.avro' used by the reader tests,
+ // writes it back out with the writer (hitting Map encoding paths), then reads it
+ // again and asserts exact Arrow equivalence.
+ #[test]
+ fn test_nonnullable_impala_roundtrip_writer() -> Result<(), ArrowError> {
+ // Load source Avro with Map fields
+ let path = arrow_test_data("avro/nonnullable.impala.avro");
+ let rdr_file = File::open(&path).expect("open
avro/nonnullable.impala.avro");
+ let mut reader = ReaderBuilder::new()
+ .build(BufReader::new(rdr_file))
+ .expect("build reader for nonnullable.impala.avro");
+ // Collect all input batches and concatenate to a single RecordBatch
+ let in_schema = reader.schema();
+ // Sanity: ensure the file actually contains at least one Map field
+ let has_map = in_schema
+ .fields()
+ .iter()
+ .any(|f| matches!(f.data_type(), DataType::Map(_, _)));
+ assert!(
+ has_map,
+ "expected at least one Map field in avro/nonnullable.impala.avro"
+ );
+
+ let input_batches = reader.collect::<Result<Vec<_>, _>>()?;
+ let original =
+ arrow::compute::concat_batches(&in_schema, &input_batches).expect("concat input");
+ // Write out using the OCF writer into an in-memory Vec<u8>
+ let buffer = Vec::<u8>::new();
+ let mut writer = AvroWriter::new(buffer, in_schema.as_ref().clone())?;
+ writer.write(&original)?;
+ writer.finish()?;
+ let out_bytes = writer.into_inner();
+ // Read the produced bytes back with the Reader
+ let mut rt_reader = ReaderBuilder::new()
+ .build(Cursor::new(out_bytes))
+ .expect("build reader for round-tripped in-memory OCF");
+ let rt_schema = rt_reader.schema();
+ let rt_batches = rt_reader.collect::<Result<Vec<_>, _>>()?;
+ let roundtrip =
+ arrow::compute::concat_batches(&rt_schema, &rt_batches).expect("concat roundtrip");
+ // Exact value fidelity (schema + data)
+ assert_eq!(
+ roundtrip, original,
+ "Round-trip Avro map data mismatch for nonnullable.impala.avro"
+ );
+ Ok(())
+ }
+
+ #[test]
+ fn test_roundtrip_decimals_via_writer() -> Result<(), ArrowError> {
+ // (file, resolve via ARROW_TEST_DATA?)
+ let files: [(&str, bool); 8] = [
+ ("avro/fixed_length_decimal.avro", true), // fixed-backed ->
Decimal128(25,2)
+ ("avro/fixed_length_decimal_legacy.avro", true), // legacy
fixed[8] -> Decimal64(13,2)
+ ("avro/int32_decimal.avro", true), // bytes-backed ->
Decimal32(4,2)
+ ("avro/int64_decimal.avro", true), // bytes-backed ->
Decimal64(10,2)
+ ("test/data/int256_decimal.avro", false), // bytes-backed ->
Decimal256(76,2)
+ ("test/data/fixed256_decimal.avro", false), // fixed[32]-backed ->
Decimal256(76,10)
+ ("test/data/fixed_length_decimal_legacy_32.avro", false), //
legacy fixed[4] -> Decimal32(9,2)
+ ("test/data/int128_decimal.avro", false), // bytes-backed ->
Decimal128(38,2)
+ ];
+ for (rel, in_test_data_dir) in files {
+ // Resolve path the same way as reader::test_decimal
+ let path: String = if in_test_data_dir {
+ arrow_test_data(rel)
+ } else {
+ PathBuf::from(env!("CARGO_MANIFEST_DIR"))
+ .join(rel)
+ .to_string_lossy()
+ .into_owned()
+ };
+ // Read original file into a single RecordBatch for comparison
+ let f_in = File::open(&path).expect("open input avro");
+ let mut rdr = ReaderBuilder::new().build(BufReader::new(f_in))?;
+ let in_schema = rdr.schema();
+ let in_batches = rdr.collect::<Result<Vec<_>, _>>()?;
+ let original =
+ arrow::compute::concat_batches(&in_schema, &in_batches).expect("concat input");
+ // Write it out with the OCF writer (no special compression)
+ let tmp = NamedTempFile::new().expect("create temp file");
+ let out_path = tmp.into_temp_path();
+ let out_file = File::create(&out_path).expect("create temp avro");
+ let mut writer = AvroWriter::new(out_file, original.schema().as_ref().clone())?;
+ writer.write(&original)?;
+ writer.finish()?;
+ // Read back the file we just wrote and compare equality (schema + data)
+ let f_rt = File::open(&out_path).expect("open roundtrip avro");
+ let mut rt_rdr = ReaderBuilder::new().build(BufReader::new(f_rt))?;
+ let rt_schema = rt_rdr.schema();
+ let rt_batches = rt_rdr.collect::<Result<Vec<_>, _>>()?;
+ let roundtrip =
+ arrow::compute::concat_batches(&rt_schema, &rt_batches).expect("concat rt");
+ assert_eq!(roundtrip, original, "decimal round-trip mismatch for {rel}");
+ }
+ Ok(())
+ }
+
+ #[test]
+ fn test_enum_roundtrip_uses_reader_fixture() -> Result<(), ArrowError> {
+ // Read the known-good enum file (same as reader::test_simple)
+ let path = arrow_test_data("avro/simple_enum.avro");
+ let rdr_file = File::open(&path).expect("open avro/simple_enum.avro");
+ let mut reader = ReaderBuilder::new()
+ .build(BufReader::new(rdr_file))
+ .expect("build reader for simple_enum.avro");
+ // Concatenate all batches to one RecordBatch for a clean equality check
+ let in_schema = reader.schema();
+ let input_batches = reader.collect::<Result<Vec<_>, _>>()?;
+ let original =
+ arrow::compute::concat_batches(&in_schema, &input_batches).expect("concat input");
+ // Sanity: expect at least one Dictionary(Int32, Utf8) column (enum)
+ let has_enum_dict = in_schema.fields().iter().any(|f| {
+ matches!(
+ f.data_type(),
+ DataType::Dictionary(k, v) if **k == DataType::Int32 && **v == DataType::Utf8
+ )
+ });
+ assert!(
+ has_enum_dict,
+ "Expected at least one enum-mapped Dictionary<Int32, Utf8> field"
+ );
+ // Write with OCF writer into memory using the reader-provided Arrow schema.
+ // The writer will embed the Avro JSON from `avro.schema` metadata if present.
+ let buffer: Vec<u8> = Vec::new();
+ let mut writer = AvroWriter::new(buffer, in_schema.as_ref().clone())?;
+ writer.write(&original)?;
+ writer.finish()?;
+ let bytes = writer.into_inner();
+ // Read back and compare for exact equality (schema + data)
+ let mut rt_reader = ReaderBuilder::new()
+ .build(Cursor::new(bytes))
+ .expect("reader for round-trip");
+ let rt_schema = rt_reader.schema();
+ let rt_batches = rt_reader.collect::<Result<Vec<_>, _>>()?;
+ let roundtrip =
+ arrow::compute::concat_batches(&rt_schema, &rt_batches).expect("concat roundtrip");
+ assert_eq!(roundtrip, original, "Avro enum round-trip mismatch");
+ Ok(())
+ }
}