This is an automated email from the ASF dual-hosted git repository. kriskras99 pushed a commit to branch feat/schema_aware_deserializer in repository https://gitbox.apache.org/repos/asf/avro-rs.git
commit c8192df18a22fe2f110e7b26d2f47dd05a219151 Author: default <[email protected]> AuthorDate: Thu Feb 26 17:37:13 2026 +0000 wip: Serde deserializer direct from writer --- avro/src/decode.rs | 2 +- avro/src/error.rs | 10 + avro/src/schema/union.rs | 4 + avro/src/serde/deser_schema/array.rs | 86 +++++ avro/src/serde/deser_schema/map.rs | 113 +++++++ avro/src/serde/deser_schema/mod.rs | 606 ++++++++++++++++++++++++++++++++++ avro/src/serde/deser_schema/record.rs | 338 +++++++++++++++++++ avro/src/serde/mod.rs | 1 + 8 files changed, 1159 insertions(+), 1 deletion(-) diff --git a/avro/src/decode.rs b/avro/src/decode.rs index dfa4bd3..d102a9d 100644 --- a/avro/src/decode.rs +++ b/avro/src/decode.rs @@ -42,7 +42,7 @@ pub(crate) fn decode_long<R: Read>(reader: &mut R) -> AvroResult<Value> { } #[inline] -fn decode_int<R: Read>(reader: &mut R) -> AvroResult<Value> { +pub(crate) fn decode_int<R: Read>(reader: &mut R) -> AvroResult<Value> { zag_i32(reader).map(Value::Int) } diff --git a/avro/src/error.rs b/avro/src/error.rs index 50a09af..8f930a5 100644 --- a/avro/src/error.rs +++ b/avro/src/error.rs @@ -569,6 +569,16 @@ pub enum Details { #[error("Failed to deserialize Avro value into value: {0}")] DeserializeValue(String), + #[error("Failed to deserialize value of type {value_type} using schema {schema:?}: {value}")] + DeserializeValueWithSchema { + value_type: &'static str, + value: String, + schema: Schema, + }, + + #[error("Only expected `deserialize_identifier` to be called but `{0}` was called")] + DeserializeKey(String), + #[error("Failed to write buffer bytes during flush: {0}")] WriteBytes(#[source] std::io::Error), diff --git a/avro/src/schema/union.rs b/avro/src/schema/union.rs index bca79aa..d8ea91d 100644 --- a/avro/src/schema/union.rs +++ b/avro/src/schema/union.rs @@ -82,6 +82,10 @@ impl UnionSchema { self.schemas.iter().any(|x| matches!(x, Schema::Null)) } + pub fn index(&self) -> &BTreeMap<SchemaKind, usize> { + &self.variant_index + } + /// Optionally 
returns a reference to the schema matched by this value, as well as its position /// within this union. /// diff --git a/avro/src/serde/deser_schema/array.rs b/avro/src/serde/deser_schema/array.rs new file mode 100644 index 0000000..13c9005 --- /dev/null +++ b/avro/src/serde/deser_schema/array.rs @@ -0,0 +1,86 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+
+use std::io::Read;
+
+use serde::de::SeqAccess;
+
+use crate::{
+    Error, schema::ArraySchema, serde::deser_schema::SchemaAwareDeserializer, util::zag_i32,
+};
+
+use super::Config;
+
+/// [`SeqAccess`] implementation that reads the items of an Avro array straight from the reader.
+///
+/// The items are read block by block: every block starts with an item count, a negative count
+/// is followed by one extra number that is read and discarded, and a count of zero ends the
+/// array.
+pub struct ArrayDeserializer<'s, 'r, R: Read> {
+    reader: &'r mut R,
+    schema: &'s ArraySchema,
+    config: Config<'s>,
+    state: State,
+}
+
+impl<'s, 'r, R: Read> ArrayDeserializer<'s, 'r, R> {
+    /// Create a deserializer that is positioned at the header of the first block.
+    pub fn new(reader: &'r mut R, schema: &'s ArraySchema, config: Config<'s>) -> Self {
+        Self {
+            reader,
+            schema,
+            config,
+            state: State::EndOfBlock,
+        }
+    }
+}
+
+/// Position of the reader inside the block structure of the array.
+enum State {
+    /// The next thing to read is a block header.
+    EndOfBlock,
+    /// Inside a block with this many items left to read (always at least one).
+    ReadingValue(u32),
+    /// The terminating zero-length block has been read.
+    Finished,
+}
+
+impl<'de, 's, 'r, R: Read> SeqAccess<'de> for ArrayDeserializer<'s, 'r, R> {
+    type Error = Error;
+
+    /// Read the next item, or return `Ok(None)` once the terminating block has been reached.
+    fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, Self::Error>
+    where
+        T: serde::de::DeserializeSeed<'de>,
+    {
+        loop {
+            match self.state {
+                State::EndOfBlock => {
+                    // NOTE(review): the Avro specification defines the count and the size as
+                    // `long`s, they are read as `int`s here.
+                    let count = zag_i32(self.reader)?;
+                    if count < 0 {
+                        // The extra number after a negative count is not needed, skip it.
+                        let _bytes = zag_i32(self.reader)?;
+                    }
+                    self.state = if count == 0 {
+                        State::Finished
+                    } else {
+                        // `unsigned_abs` cannot overflow, unlike `abs` on `i32::MIN`.
+                        State::ReadingValue(count.unsigned_abs())
+                    };
+                }
+                State::ReadingValue(remaining) => {
+                    let v = seed.deserialize(SchemaAwareDeserializer::new(
+                        self.reader,
+                        &self.schema.items,
+                        self.config,
+                    )?)?;
+
+                    // One item of the block has been consumed. Without this bookkeeping the
+                    // next block header would never be read and items would be decoded past
+                    // the end of the array.
+                    self.state = if remaining > 1 {
+                        State::ReadingValue(remaining - 1)
+                    } else {
+                        State::EndOfBlock
+                    };
+
+                    return Ok(Some(v));
+                }
+                State::Finished => return Ok(None),
+            }
+        }
+    }
+}
diff --git a/avro/src/serde/deser_schema/map.rs b/avro/src/serde/deser_schema/map.rs
new file mode 100644
index 0000000..61e07c0
--- /dev/null
+++ b/avro/src/serde/deser_schema/map.rs
@@ -0,0 +1,113 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. 
The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::io::Read;
+
+use serde::de::MapAccess;
+
+use crate::{
+    Error, Schema, schema::MapSchema, serde::deser_schema::SchemaAwareDeserializer, util::zag_i32,
+};
+
+use super::Config;
+
+/// [`MapAccess`] implementation that reads the entries of an Avro map straight from the reader.
+///
+/// The entries are read block by block: every block starts with an entry count, a negative
+/// count is followed by one extra number that is read and discarded, and a count of zero ends
+/// the map. Keys are read with [`Schema::String`], values with the value schema of the map.
+pub struct MapDeserializer<'s, 'r, R: Read> {
+    reader: &'r mut R,
+    schema: &'s MapSchema,
+    config: Config<'s>,
+    state: State,
+}
+
+impl<'s, 'r, R: Read> MapDeserializer<'s, 'r, R> {
+    /// Create a deserializer that is positioned at the header of the first block.
+    pub fn new(reader: &'r mut R, schema: &'s MapSchema, config: Config<'s>) -> Self {
+        Self {
+            reader,
+            schema,
+            config,
+            state: State::EndOfBlock,
+        }
+    }
+}
+
+/// Position of the reader inside the block structure of the map.
+enum State {
+    /// The next thing to read is a block header.
+    EndOfBlock,
+    /// The next thing to read is a key; the number is the amount of entries left in the block.
+    ReadingKey(u32),
+    /// The key of the current entry was read, its value is next.
+    ReadingValue(u32),
+    /// The terminating zero-length block has been read.
+    Finished,
+}
+
+impl<'de, 's, 'r, R: Read> MapAccess<'de> for MapDeserializer<'s, 'r, R> {
+    type Error = Error;
+
+    /// Read the key of the next entry, or return `Ok(None)` at the end of the map.
+    ///
+    /// # Panics
+    /// Panics when the value of the previous entry has not been read yet.
+    fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>, Self::Error>
+    where
+        K: serde::de::DeserializeSeed<'de>,
+    {
+        loop {
+            match self.state {
+                State::EndOfBlock => {
+                    // NOTE(review): the Avro specification defines the count and the size as
+                    // `long`s, they are read as `int`s here.
+                    let count = zag_i32(self.reader)?;
+                    if count < 0 {
+                        // The extra number after a negative count is not needed, skip it.
+                        let _bytes = zag_i32(self.reader)?;
+                    }
+                    self.state = if count == 0 {
+                        State::Finished
+                    } else {
+                        // `unsigned_abs` cannot overflow, unlike `abs` on `i32::MIN`.
+                        State::ReadingKey(count.unsigned_abs())
+                    };
+                }
+                State::ReadingKey(remaining) => {
+                    let v = seed.deserialize(SchemaAwareDeserializer::new(
+                        self.reader,
+                        &Schema::String,
+                        self.config,
+                    )?)?;
+
+                    self.state = State::ReadingValue(remaining);
+
+                    return Ok(Some(v));
+                }
+                State::Finished => return Ok(None),
+                State::ReadingValue(_) => {
+                    panic!("`next_key_seed` and `next_value_seed` were called in the wrong order")
+                }
+            }
+        }
+    }
+
+    /// Read the value that belongs to the key returned by the last `next_key_seed` call.
+    ///
+    /// # Panics
+    /// Panics when it is not called directly after a successful `next_key_seed`.
+    fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value, Self::Error>
+    where
+        V: serde::de::DeserializeSeed<'de>,
+    {
+        let State::ReadingValue(remaining) = self.state else {
+            panic!("`next_key_seed` and `next_value_seed` were called in the wrong order")
+        };
+        let v = seed.deserialize(SchemaAwareDeserializer::new(
+            self.reader,
+            &self.schema.types,
+            self.config,
+        )?)?;
+
+        // `remaining` counts the entry that was just read, so it is always at least one here.
+        self.state = if remaining > 1 {
+            State::ReadingKey(remaining - 1)
+        } else {
+            State::EndOfBlock
+        };
+
+        Ok(v)
+    }
+}
diff --git a/avro/src/serde/deser_schema/mod.rs b/avro/src/serde/deser_schema/mod.rs
new file mode 100644
index 0000000..f301ba8
--- /dev/null
+++ b/avro/src/serde/deser_schema/mod.rs
@@ -0,0 +1,606 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License. 
+
+mod array;
+mod map;
+mod record;
+
+use std::io::Read;
+
+use serde::Deserializer;
+
+use crate::{
+    Error, Schema,
+    decode::decode_len,
+    error::Details,
+    schema::{DecimalSchema, InnerDecimalSchema, NamesRef, SchemaKind, UuidSchema},
+    serde::deser_schema::{
+        array::ArrayDeserializer, map::MapDeserializer, record::RecordDeserializer,
+    },
+    util::{zag_i32, zag_i64},
+};
+
+/// Settings that are passed on to every nested deserializer.
+#[derive(Clone, Copy)]
+pub struct Config<'s> {
+    /// All names that can be referenced in the schema being used for deserialization.
+    pub names: &'s NamesRef<'s>,
+    /// The value reported by [`Deserializer::is_human_readable`].
+    ///
+    /// `Deserialize` implementations may expect a different representation depending on this
+    /// flag, so it should match the value that was used when the data was serialized.
+    pub human_readable: bool,
+}
+
+/// A [`Deserializer`] that decodes a value directly from `reader`, guided by `schema`.
+pub struct SchemaAwareDeserializer<'s, 'r, R: Read> {
+    reader: &'r mut R,
+    schema: &'s Schema,
+    config: Config<'s>,
+}
+
+impl<'s, 'r, R: Read> SchemaAwareDeserializer<'s, 'r, R> {
+    /// Create a deserializer for `schema`.
+    ///
+    /// A [`Schema::Ref`] is resolved through `config.names`.
+    ///
+    /// # Errors
+    /// Fails when `schema` is a reference to a name that is not in `config.names`.
+    pub fn new(reader: &'r mut R, schema: &'s Schema, config: Config<'s>) -> Result<Self, Error> {
+        if let Schema::Ref { name } = schema {
+            let schema = config
+                .names
+                .get(name)
+                .ok_or_else(|| Details::SchemaResolutionError(name.clone()))?;
+            Self::new(reader, schema, config)
+        } else {
+            Ok(Self {
+                reader,
+                schema,
+                config,
+            })
+        }
+    }
+
+    /// Build the error for a value of type `ty` that cannot be read with the current schema.
+    fn error(&self, ty: &'static str, error: impl Into<String>) -> Error {
+        Error::new(Details::DeserializeValueWithSchema {
+            value_type: ty,
+            value: error.into(),
+            schema: self.schema.clone(),
+        })
+    }
+
+    /// Create a new deserializer with the existing reader and config.
+    ///
+    /// This will resolve the schema if it is a reference.
+    fn with_different_schema(mut self, schema: &'s Schema) -> Result<Self, Error> {
+        let schema = if let Schema::Ref { name } = schema {
+            self.config
+                .names
+                .get(name)
+                .copied()
+                .ok_or_else(|| Details::SchemaResolutionError(name.clone()))?
+        } else {
+            schema
+        };
+        self.schema = schema;
+        Ok(self)
+    }
+
+    /// Read an `int`; `original_ty` is the Rust type that was requested, for the error message.
+    fn read_int(&mut self, original_ty: &'static str) -> Result<i32, Error> {
+        match self.schema {
+            Schema::Int | Schema::Date | Schema::TimeMillis => zag_i32(self.reader),
+            _ => Err(self.error(
+                original_ty,
+                "Expected Schema::Int | Schema::Date | Schema::TimeMillis",
+            )),
+        }
+    }
+
+    /// Read a `long`; `original_ty` is the Rust type that was requested, for the error message.
+    fn read_long(&mut self, original_ty: &'static str) -> Result<i64, Error> {
+        match self.schema {
+            Schema::Long
+            | Schema::TimeMicros
+            | Schema::TimestampMillis
+            | Schema::TimestampMicros
+            | Schema::TimestampNanos
+            | Schema::LocalTimestampMillis
+            | Schema::LocalTimestampMicros
+            | Schema::LocalTimestampNanos => zag_i64(self.reader),
+            _ => Err(self.error(
+                original_ty,
+                "Expected Schema::Long | Schema::TimeMicros | Schema::{,Local}Timestamp{Millis,Micros,Nanos}",
+            )),
+        }
+    }
+
+    /// Read a length-prefixed UTF-8 string.
+    fn read_string(&mut self) -> Result<String, Error> {
+        let bytes = self.read_bytes_with_len()?;
+        Ok(String::from_utf8(bytes).map_err(Details::ConvertToUtf8)?)
+    }
+
+    /// Read a length prefix followed by that many bytes.
+    fn read_bytes_with_len(&mut self) -> Result<Vec<u8>, Error> {
+        let length = decode_len(self.reader)?;
+        self.read_bytes(length)
+    }
+
+    /// Read exactly `length` bytes.
+    fn read_bytes(&mut self, length: usize) -> Result<Vec<u8>, Error> {
+        let mut buf = vec![0; length];
+        self.reader
+            .read_exact(&mut buf)
+            .map_err(Details::ReadBytes)?;
+        Ok(buf)
+    }
+}
+
+/// [`SeqAccess`](serde::de::SeqAccess) over the fields of a record, in schema order.
+///
+/// Used for tuples and tuple structs: their visitors expect a sequence, not a map.
+struct RecordFieldsSeq<'s, 'r, R: Read> {
+    reader: &'r mut R,
+    schema: &'s crate::schema::RecordSchema,
+    config: Config<'s>,
+    /// Index of the next field to read.
+    next: usize,
+}
+
+impl<'de, 's, 'r, R: Read> serde::de::SeqAccess<'de> for RecordFieldsSeq<'s, 'r, R> {
+    type Error = Error;
+
+    fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, Self::Error>
+    where
+        T: serde::de::DeserializeSeed<'de>,
+    {
+        let Some(field) = self.schema.fields.get(self.next) else {
+            // Every field has been read
+            return Ok(None);
+        };
+        self.next += 1;
+        seed.deserialize(SchemaAwareDeserializer::new(
+            self.reader,
+            &field.schema,
+            self.config,
+        )?)
+        .map(Some)
+    }
+
+    fn size_hint(&self) -> Option<usize> {
+        Some(self.schema.fields.len() - self.next)
+    }
+}
+
+// Placeholder arguments for the `deserialize_enum` calls made by `deserialize_any`, where no
+// Rust type name or variant list is available.
+const DESERIALIZE_ANY: &str = "This value is compared by pointer value";
+const DESERIALIZE_ANY_FIELDS: &[&str] = &[];
+
+impl<'de, 's, 'r, R: Read> Deserializer<'de> for SchemaAwareDeserializer<'s, 'r, R> {
+    type Error = Error;
+
+    /// Deserialize whatever the schema describes.
+    fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
+    where
+        V: serde::de::Visitor<'de>,
+    {
+        match self.schema {
+            Schema::Null => self.deserialize_unit(visitor),
+            Schema::Boolean => self.deserialize_bool(visitor),
+            Schema::Int | Schema::Date | Schema::TimeMillis => self.deserialize_i32(visitor),
+            Schema::Long
+            | Schema::TimeMicros
+            | Schema::TimestampMillis
+            | Schema::TimestampMicros
+            | Schema::TimestampNanos
+            | Schema::LocalTimestampMillis
+            | Schema::LocalTimestampMicros
+            | Schema::LocalTimestampNanos => self.deserialize_i64(visitor),
+            Schema::Float => self.deserialize_f32(visitor),
+            Schema::Double => self.deserialize_f64(visitor),
+            Schema::Bytes
+            | Schema::Fixed(_)
+            | Schema::Decimal(_)
+            | Schema::BigDecimal
+            | Schema::Uuid(UuidSchema::Fixed(_))
+            | Schema::Duration(_) => self.deserialize_byte_buf(visitor),
+            Schema::String | Schema::Uuid(UuidSchema::String) => self.deserialize_string(visitor),
+            Schema::Array(_) => self.deserialize_seq(visitor),
+            Schema::Map(_) => self.deserialize_map(visitor),
+            // There is no Rust type name to check the record name against, so the record is
+            // visited directly instead of going through `deserialize_struct`.
+            Schema::Record(record) => {
+                visitor.visit_map(RecordDeserializer::new(self.reader, record, self.config))
+            }
+            Schema::Union(_) | Schema::Enum(_) => {
+                self.deserialize_enum(DESERIALIZE_ANY, DESERIALIZE_ANY_FIELDS, visitor)
+            }
+            Schema::Ref { .. } => unreachable!("References are resolved on deserializer creation"),
+            // Report valid but unsupported schemas as an error instead of panicking.
+            Schema::Uuid(UuidSchema::Bytes) => {
+                Err(self.error("any", "Schema::Uuid(Bytes) is not supported"))
+            }
+        }
+    }
+
+    /// Read one byte that must be `0` or `1`.
+    fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value, Self::Error>
+    where
+        V: serde::de::Visitor<'de>,
+    {
+        // Like every other method: refuse to consume data that the schema does not describe.
+        if !matches!(self.schema, Schema::Boolean) {
+            return Err(self.error("bool", "Expected Schema::Boolean"));
+        }
+        let mut buf = [0];
+        self.reader
+            .read_exact(&mut buf)
+            .map_err(Details::ReadBytes)?;
+        // TODO: The TryFrom implementation wasn't working??
+        let boolean = match buf[0] {
+            0 => false,
+            1 => true,
+            _ => return Err(self.error("bool", format!("{} is not a valid boolean", buf[0]))),
+        };
+        visitor.visit_bool(boolean)
+    }
+
+    fn deserialize_i8<V>(mut self, visitor: V) -> Result<V::Value, Self::Error>
+    where
+        V: serde::de::Visitor<'de>,
+    {
+        let int = self.read_int("i8")?;
+        let value = i8::try_from(int)
+            .map_err(|_| self.error("i8", format!("Could not convert int ({int}) to an i8")))?;
+        visitor.visit_i8(value)
+    }
+
+    fn deserialize_i16<V>(mut self, visitor: V) -> Result<V::Value, Self::Error>
+    where
+        V: serde::de::Visitor<'de>,
+    {
+        let int = self.read_int("i16")?;
+        let value = i16::try_from(int)
+            .map_err(|_| self.error("i16", format!("Could not convert int ({int}) to an i16")))?;
+        visitor.visit_i16(value)
+    }
+
+    fn deserialize_i32<V>(mut self, visitor: V) -> Result<V::Value, Self::Error>
+    where
+        V: serde::de::Visitor<'de>,
+    {
+        let value = self.read_int("i32")?;
+        visitor.visit_i32(value)
+    }
+
+    fn deserialize_i64<V>(mut self, visitor: V) -> Result<V::Value, Self::Error>
+    where
+        V: serde::de::Visitor<'de>,
+    {
+        let value = self.read_long("i64")?;
+        visitor.visit_i64(value)
+    }
+
+    fn deserialize_u8<V>(mut self, visitor: V) -> Result<V::Value, Self::Error>
+    where
+        V: serde::de::Visitor<'de>,
+    {
+        let int = self.read_int("u8")?;
+        let value = u8::try_from(int)
+            .map_err(|_| self.error("u8", format!("Could not convert int ({int}) to an u8")))?;
+        visitor.visit_u8(value)
+    }
+
+    fn deserialize_u16<V>(mut self, visitor: V) -> Result<V::Value, Self::Error>
+    where
+        V: serde::de::Visitor<'de>,
+    {
+        let int = self.read_int("u16")?;
+        let value = u16::try_from(int)
+            .map_err(|_| self.error("u16", format!("Could not convert int ({int}) to an u16")))?;
+        visitor.visit_u16(value)
+    }
+
+    /// A `u32` does not fit in an `int`, so it is read from a `long`.
+    fn deserialize_u32<V>(mut self, visitor: V) -> Result<V::Value, Self::Error>
+    where
+        V: serde::de::Visitor<'de>,
+    {
+        let int = self.read_long("u32")?;
+        let value = u32::try_from(int)
+            .map_err(|_| self.error("u32", format!("Could not convert int ({int}) to an u32")))?;
+        visitor.visit_u32(value)
+    }
+
+    /// Read a `u64` from 8 little-endian bytes; requires a fixed schema named `u64` of size 8.
+    fn deserialize_u64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
+    where
+        V: serde::de::Visitor<'de>,
+    {
+        if let Schema::Fixed(fixed) = self.schema
+            && fixed.size == 8
+            && fixed.name.name == "u64"
+        {
+            let mut buf = [0; 8];
+            self.reader
+                .read_exact(&mut buf)
+                .map_err(Details::ReadBytes)?;
+            visitor.visit_u64(u64::from_le_bytes(buf))
+        } else {
+            Err(self.error("u64", r#"Expected Schema::Fixed(name: "u64", size: 8)"#))
+        }
+    }
+
+    fn deserialize_f32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
+    where
+        V: serde::de::Visitor<'de>,
+    {
+        if let Schema::Float = self.schema {
+            let mut buf = [0; 4];
+            self.reader
+                .read_exact(&mut buf)
+                .map_err(Details::ReadBytes)?;
+            visitor.visit_f32(f32::from_le_bytes(buf))
+        } else {
+            Err(self.error("f32", "Expected Schema::Float"))
+        }
+    }
+
+    fn deserialize_f64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
+    where
+        V: serde::de::Visitor<'de>,
+    {
+        if let Schema::Double = self.schema {
+            let mut buf = [0; 8];
+            self.reader
+                .read_exact(&mut buf)
+                .map_err(Details::ReadBytes)?;
+            visitor.visit_f64(f64::from_le_bytes(buf))
+        } else {
+            Err(self.error("f64", "Expected Schema::Double"))
+        }
+    }
+
+    /// Read a string that must consist of exactly one character.
+    fn deserialize_char<V>(mut self, visitor: V) -> Result<V::Value, Self::Error>
+    where
+        V: serde::de::Visitor<'de>,
+    {
+        if let Schema::String = self.schema {
+            let string = self.read_string()?;
+            let mut chars = string.chars();
+            if let Some(character) = chars.next() {
+                if chars.next().is_some() {
+                    Err(self.error("char", "String contains more than one character"))
+                } else {
+                    visitor.visit_char(character)
+                }
+            } else {
+                Err(self.error("char", "String is empty"))
+            }
+        } else {
+            Err(self.error("char", "Expected Schema::String"))
+        }
+    }
+
+    fn deserialize_str<V>(self, visitor: V) -> Result<V::Value, Self::Error>
+    where
+        V: serde::de::Visitor<'de>,
+    {
+        self.deserialize_string(visitor)
+    }
+
+    fn deserialize_string<V>(mut self, visitor: V) -> Result<V::Value, Self::Error>
+    where
+        V: serde::de::Visitor<'de>,
+    {
+        // `deserialize_any` sends string-based UUIDs here as well, so they must be accepted.
+        if let Schema::String | Schema::Uuid(UuidSchema::String) = self.schema {
+            let string = self.read_string()?;
+            visitor.visit_string(string)
+        } else {
+            Err(self.error(
+                "string",
+                "Expected Schema::String | Schema::Uuid(String)",
+            ))
+        }
+    }
+
+    fn deserialize_bytes<V>(self, visitor: V) -> Result<V::Value, Self::Error>
+    where
+        V: serde::de::Visitor<'de>,
+    {
+        self.deserialize_byte_buf(visitor)
+    }
+
+    /// Read length-prefixed bytes, or exactly `size` bytes for the fixed-based schemas.
+    fn deserialize_byte_buf<V>(mut self, visitor: V) -> Result<V::Value, Self::Error>
+    where
+        V: serde::de::Visitor<'de>,
+    {
+        match self.schema {
+            Schema::Bytes
+            | Schema::BigDecimal
+            | Schema::Decimal(DecimalSchema {
+                inner: InnerDecimalSchema::Bytes,
+                ..
+            }) => {
+                let bytes = self.read_bytes_with_len()?;
+                visitor.visit_byte_buf(bytes)
+            }
+            Schema::Fixed(fixed)
+            | Schema::Decimal(DecimalSchema {
+                inner: InnerDecimalSchema::Fixed(fixed),
+                ..
+            })
+            | Schema::Uuid(UuidSchema::Fixed(fixed))
+            | Schema::Duration(fixed) => {
+                let bytes = self.read_bytes(fixed.size)?;
+                visitor.visit_byte_buf(bytes)
+            }
+            _ => Err(self.error(
+                "bytes",
+                "Expected Schema::Bytes | Schema::Fixed | Schema::BigDecimal | Schema::Decimal | Schema::Uuid(Fixed) | Schema::Duration",
+            )),
+        }
+    }
+
+    /// Read an option from a union of exactly two variants of which one is `null`.
+    fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>
+    where
+        V: serde::de::Visitor<'de>,
+    {
+        if let Schema::Union(union) = self.schema
+            && union.variants().len() == 2
+            && let Some(null_index) = union.index().get(&SchemaKind::Null).copied()
+        {
+            let raw_index = zag_i32(self.reader)?;
+            // The union has exactly two variants, so only 0 and 1 are valid
+            let index = match usize::try_from(raw_index) {
+                Ok(index) if index < 2 => index,
+                _ => {
+                    return Err(
+                        self.error("option", format!("Invalid union index {raw_index}"))
+                    );
+                }
+            };
+            if index == null_index {
+                visitor.visit_none()
+            } else {
+                visitor.visit_some(self.with_different_schema(&union.variants()[index])?)
+            }
+        } else {
+            Err(self.error("option", "Expected Schema::Union([Null, _])"))
+        }
+    }
+
+    fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value, Self::Error>
+    where
+        V: serde::de::Visitor<'de>,
+    {
+        if let Schema::Null = self.schema {
+            visitor.visit_unit()
+        } else {
+            Err(self.error("unit", "Expected Schema::Null"))
+        }
+    }
+
+    /// A unit struct is a record with the same name and no fields; nothing is read.
+    fn deserialize_unit_struct<V>(
+        self,
+        name: &'static str,
+        visitor: V,
+    ) -> Result<V::Value, Self::Error>
+    where
+        V: serde::de::Visitor<'de>,
+    {
+        if let Schema::Record(record) = self.schema
+            && record.fields.is_empty()
+            && record.name.name == name
+        {
+            visitor.visit_unit()
+        } else {
+            Err(self.error(
+                "unit struct",
+                format!("Expected Schema::Record(name: {name}, fields.len() == 0)"),
+            ))
+        }
+    }
+
+    /// A newtype struct is a record with the same name and exactly one field.
+    fn deserialize_newtype_struct<V>(
+        self,
+        name: &'static str,
+        visitor: V,
+    ) -> Result<V::Value, Self::Error>
+    where
+        V: serde::de::Visitor<'de>,
+    {
+        if let Schema::Record(record) = self.schema
+            && record.fields.len() == 1
+            && record.name.name == name
+        {
+            visitor.visit_newtype_struct(self.with_different_schema(&record.fields[0].schema)?)
+        } else {
+            Err(self.error(
+                "newtype struct",
+                format!("Expected Schema::Record(name: {name}, fields.len() == 1)"),
+            ))
+        }
+    }
+
+    fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value, Self::Error>
+    where
+        V: serde::de::Visitor<'de>,
+    {
+        if let Schema::Array(array) = self.schema {
+            visitor.visit_seq(ArrayDeserializer::new(self.reader, array, self.config))
+        } else {
+            Err(self.error("array", "Expected Schema::Array"))
+        }
+    }
+
+    /// A tuple is a record with `len` fields, which are handed to the visitor as a sequence.
+    fn deserialize_tuple<V>(self, len: usize, visitor: V) -> Result<V::Value, Self::Error>
+    where
+        V: serde::de::Visitor<'de>,
+    {
+        if let Schema::Record(record) = self.schema
+            && record.fields.len() == len
+        {
+            // Tuple visitors implement `visit_seq`, they reject `visit_map`.
+            visitor.visit_seq(RecordFieldsSeq {
+                reader: self.reader,
+                schema: record,
+                config: self.config,
+                next: 0,
+            })
+        } else {
+            Err(self.error(
+                "tuple",
+                format!("Expected Schema::Record(fields.len() == {len})"),
+            ))
+        }
+    }
+
+    /// A tuple struct is a record with the same name and `len` fields, visited as a sequence.
+    fn deserialize_tuple_struct<V>(
+        self,
+        name: &'static str,
+        len: usize,
+        visitor: V,
+    ) -> Result<V::Value, Self::Error>
+    where
+        V: serde::de::Visitor<'de>,
+    {
+        if let Schema::Record(record) = self.schema
+            && record.name.name == name
+            && record.fields.len() == len
+        {
+            // Tuple struct visitors implement `visit_seq`, they reject `visit_map`.
+            visitor.visit_seq(RecordFieldsSeq {
+                reader: self.reader,
+                schema: record,
+                config: self.config,
+                next: 0,
+            })
+        } else {
+            Err(self.error(
+                "tuple struct",
+                format!("Expected Schema::Record(name: {name}, fields.len() == {len})"),
+            ))
+        }
+    }
+
+    fn deserialize_map<V>(self, visitor: V) -> Result<V::Value, Self::Error>
+    where
+        V: serde::de::Visitor<'de>,
+    {
+        if let Schema::Map(map) = self.schema {
+            visitor.visit_map(MapDeserializer::new(self.reader, map, self.config))
+        } else {
+            Err(self.error("map", "Expected Schema::Map"))
+        }
+    }
+
+    /// A struct is a record with the same name; the fields are handed to the visitor as a map.
+    fn deserialize_struct<V>(
+        self,
+        name: &'static str,
+        _fields: &'static [&'static str],
+        visitor: V,
+    ) -> Result<V::Value, Self::Error>
+    where
+        V: serde::de::Visitor<'de>,
+    {
+        if let Schema::Record(record) = self.schema
+            && record.name.name == name
+        {
+            visitor.visit_map(RecordDeserializer::new(self.reader, record, self.config))
+        } else {
+            Err(self.error("struct", format!("Expected Schema::Record(name: {name})")))
+        }
+    }
+
+    /// Enums and unions are not supported yet, this always returns an error.
+    fn deserialize_enum<V>(
+        self,
+        _name: &'static str,
+        _variants: &'static [&'static str],
+        _visitor: V,
+    ) -> Result<V::Value, Self::Error>
+    where
+        V: serde::de::Visitor<'de>,
+    {
+        // TODO: implement this. An error is returned instead of panicking because this is
+        // also reached through `deserialize_any` for data that merely contains an enum or union.
+        Err(self.error(
+            "enum",
+            "Deserializing Schema::Enum and Schema::Union is not implemented yet",
+        ))
+    }
+
+    /// Read an `i128` from 16 little-endian bytes; requires a fixed schema named `i128` of size 16.
+    fn deserialize_i128<V>(self, visitor: V) -> Result<V::Value, Self::Error>
+    where
+        V: serde::de::Visitor<'de>,
+    {
+        if let Schema::Fixed(fixed) = self.schema
+            && fixed.size == 16
+            && fixed.name.name == "i128"
+        {
+            let mut buf = [0; 16];
+            self.reader
+                .read_exact(&mut buf)
+                .map_err(Details::ReadBytes)?;
+            visitor.visit_i128(i128::from_le_bytes(buf))
+        } else {
+            Err(self.error("i128", r#"Expected Schema::Fixed(name: "i128", size: 16)"#))
+        }
+    }
+
+    /// Read a `u128` from 16 little-endian bytes; requires a fixed schema named `u128` of size 16.
+    fn deserialize_u128<V>(self, visitor: V) -> Result<V::Value, Self::Error>
+    where
+        V: serde::de::Visitor<'de>,
+    {
+        if let Schema::Fixed(fixed) = self.schema
+            && fixed.size == 16
+            && fixed.name.name == "u128"
+        {
+            let mut buf = [0; 16];
+            self.reader
+                .read_exact(&mut buf)
+                .map_err(Details::ReadBytes)?;
+            visitor.visit_u128(u128::from_le_bytes(buf))
+        } else {
+            Err(self.error("u128", r#"Expected Schema::Fixed(name: "u128", size: 16)"#))
+        }
+    }
+
+    fn is_human_readable(&self) -> bool {
+        self.config.human_readable
+    }
+
+    /// Visit the name of the schema; nothing is read from the reader.
+    fn deserialize_identifier<V>(self, visitor: V) -> Result<V::Value, Self::Error>
+    where
+        V: serde::de::Visitor<'de>,
+    {
+        if let Some(name) = self.schema.name() {
+            visitor.visit_str(&name.name)
+        } else {
+            Err(self.error("identifier", "Expected a named schema"))
+        }
+    }
+
+    fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
+    where
+        V: serde::de::Visitor<'de>,
+    {
+        // TODO: We can probably do something more efficient, but that might need the Seek trait bound
+        self.deserialize_any(visitor)
+    }
+}
diff --git a/avro/src/serde/deser_schema/record.rs b/avro/src/serde/deser_schema/record.rs
new file mode 100644
index 0000000..0510f3a
--- /dev/null
+++ 
b/avro/src/serde/deser_schema/record.rs
@@ -0,0 +1,338 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::io::Read;
+
+use serde::{Deserializer, de::MapAccess};
+
+use crate::{
+    Error, error::Details, schema::RecordSchema, serde::deser_schema::SchemaAwareDeserializer,
+};
+
+use super::Config;
+
+/// [`MapAccess`] implementation that yields the fields of a record in schema order.
+///
+/// The keys are the field names from the schema, only the values are read from the reader.
+pub struct RecordDeserializer<'s, 'r, R: Read> {
+    reader: &'r mut R,
+    schema: &'s RecordSchema,
+    config: Config<'s>,
+    current_field: State,
+}
+
+impl<'s, 'r, R: Read> RecordDeserializer<'s, 'r, R> {
+    /// Create a deserializer that starts at the first field of `schema`.
+    pub fn new(reader: &'r mut R, schema: &'s RecordSchema, config: Config<'s>) -> Self {
+        Self {
+            reader,
+            schema,
+            config,
+            current_field: State::Key(0),
+        }
+    }
+}
+
+/// Which half of which field is next; the number is the index of the field in the schema.
+enum State {
+    Key(usize),
+    Value(usize),
+}
+
+impl<'de, 's, 'r, R: Read> MapAccess<'de> for RecordDeserializer<'s, 'r, R> {
+    type Error = Error;
+
+    /// Yield the name of the next field, or `Ok(None)` when every field has been read.
+    ///
+    /// # Panics
+    /// Panics when the value of the previous field has not been read yet.
+    fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>, Self::Error>
+    where
+        K: serde::de::DeserializeSeed<'de>,
+    {
+        let State::Key(index) = self.current_field else {
+            panic!("`next_key_seed` and `next_value_seed` were called in the wrong order")
+        };
+        let Some(field) = self.schema.fields.get(index) else {
+            // Finished reading this record
+            return Ok(None);
+        };
+        let v = seed.deserialize(FieldName { name: &field.name })?;
+        self.current_field = State::Value(index);
+        Ok(Some(v))
+    }
+
+    /// Read the value of the field whose name was returned by the last `next_key_seed` call.
+    ///
+    /// # Panics
+    /// Panics when it is not called directly after a successful `next_key_seed`.
+    fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value, Self::Error>
+    where
+        V: serde::de::DeserializeSeed<'de>,
+    {
+        let State::Value(index) = self.current_field else {
+            panic!("`next_key_seed` and `next_value_seed` were called in the wrong order")
+        };
+        let v = seed.deserialize(SchemaAwareDeserializer::new(
+            self.reader,
+            &self.schema.fields[index].schema,
+            self.config,
+        )?)?;
+        self.current_field = State::Key(index + 1);
+        Ok(v)
+    }
+
+    /// The number of fields whose value has not been read yet.
+    fn size_hint(&self) -> Option<usize> {
+        match self.current_field {
+            State::Key(index) | State::Value(index) => Some(self.schema.fields.len() - index),
+        }
+    }
+}
+
+/// "Deserializer" for the field name
+///
+/// Only `deserialize_identifier` is supported, every other method returns
+/// [`Details::DeserializeKey`] with the name of the method that was called.
+struct FieldName<'s> {
+    name: &'s str,
+}
+
+/// Generates `Deserializer` methods that only take a visitor and always fail with
+/// [`Details::DeserializeKey`], using the method name as the message.
+macro_rules! reject_non_identifier {
+    ($($method:ident)*) => {
+        $(
+            fn $method<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
+            where
+                V: serde::de::Visitor<'de>,
+            {
+                Err(Details::DeserializeKey(stringify!($method).into()).into())
+            }
+        )*
+    };
+}
+
+impl<'de, 's> Deserializer<'de> for FieldName<'s> {
+    type Error = Error;
+
+    reject_non_identifier! {
+        deserialize_any
+        deserialize_bool
+        deserialize_i8
+        deserialize_i16
+        deserialize_i32
+        deserialize_i64
+        deserialize_u8
+        deserialize_u16
+        deserialize_u32
+        deserialize_u64
+        deserialize_f32
+        deserialize_f64
+        deserialize_char
+        deserialize_str
+        deserialize_string
+        deserialize_bytes
+        deserialize_byte_buf
+        deserialize_option
+        deserialize_unit
+        deserialize_seq
+        deserialize_map
+        deserialize_ignored_any
+    }
+
+    fn deserialize_unit_struct<V>(
+        self,
+        name: &'static str,
+        _visitor: V,
+    ) -> Result<V::Value, Self::Error>
+    where
+        V: serde::de::Visitor<'de>,
+    {
+        Err(Details::DeserializeKey(format!("deserialize_unit_struct(name: {name})")).into())
+    }
+
+    fn deserialize_newtype_struct<V>(
+        self,
+        name: &'static str,
+        _visitor: V,
+    ) -> Result<V::Value, Self::Error>
+    where
+        V: serde::de::Visitor<'de>,
+    {
+        Err(Details::DeserializeKey(format!("deserialize_newtype_struct(name: {name})")).into())
+    }
+
+    fn deserialize_tuple<V>(self, len: usize, _visitor: V) -> Result<V::Value, Self::Error>
+    where
+        V: serde::de::Visitor<'de>,
+    {
+        Err(Details::DeserializeKey(format!("deserialize_tuple(len: {len})")).into())
+    }
+
+    fn deserialize_tuple_struct<V>(
+        self,
+        name: &'static str,
+        len: usize,
+        _visitor: V,
+    ) -> Result<V::Value, Self::Error>
+    where
+        V: serde::de::Visitor<'de>,
+    {
+        Err(Details::DeserializeKey(format!(
+            "deserialize_tuple_struct(name: {name}, len: {len})"
+        ))
+        .into())
+    }
+
+    fn deserialize_struct<V>(
+        self,
+        name: &'static str,
+        fields: &'static [&'static str],
+        _visitor: V,
+    ) -> Result<V::Value, Self::Error>
+    where
+        V: serde::de::Visitor<'de>,
+    {
+        Err(Details::DeserializeKey(format!(
+            "deserialize_struct(name: {name}, fields: {fields:?})"
+        ))
+        .into())
+    }
+
+    fn deserialize_enum<V>(
+        self,
+        name: &'static str,
+        variants: &'static [&'static str],
+        _visitor: V,
+    ) -> Result<V::Value, Self::Error>
+    where
+        V: serde::de::Visitor<'de>,
+    {
+        Err(Details::DeserializeKey(format!(
+            "deserialize_enum(name: {name}, variants: {variants:?})"
+        ))
+        .into())
+    }
+
+    /// The only supported method: hand the field name to the visitor.
+    fn deserialize_identifier<V>(self, visitor: V) -> Result<V::Value, Self::Error>
+    where
+        V: serde::de::Visitor<'de>,
+    {
+        visitor.visit_str(self.name)
+    }
+}
diff --git a/avro/src/serde/mod.rs b/avro/src/serde/mod.rs
index b3bfd2a..0a17bc7 100644
--- a/avro/src/serde/mod.rs
+++ b/avro/src/serde/mod.rs
@@ -109,6 +109,7 @@
 
 mod de;
 mod derive;
+pub mod deser_schema;
 mod ser;
 pub(crate) mod ser_schema;
 mod util;
