This is an automated email from the ASF dual-hosted git repository. kriskras99 pushed a commit to branch feat/enums in repository https://gitbox.apache.org/repos/asf/avro-rs.git
commit 939e1712c62d79a9b5b42c45076b35b4a669da53 Author: default <[email protected]> AuthorDate: Wed Mar 4 13:44:27 2026 +0000 all tests work again --- avro/src/duration.rs | 1 + avro/src/error.rs | 2 +- avro/src/lib.rs | 4 + avro/src/reader/datum.rs | 41 ++- avro/src/reader/single_object.rs | 6 +- avro/src/schema/mod.rs | 1 + avro/src/serde/de.rs | 1 + avro/src/serde/deser_schema/array.rs | 16 +- avro/src/serde/deser_schema/enums/mod.rs | 2 + avro/src/serde/deser_schema/enums/plain.rs | 278 +-------------- avro/src/serde/deser_schema/enums/union.rs | 142 ++++++++ avro/src/serde/deser_schema/identifier.rs | 255 ++++++++++++++ avro/src/serde/deser_schema/map.rs | 16 +- avro/src/serde/deser_schema/mod.rs | 346 ++++++++++-------- avro/src/serde/deser_schema/record.rs | 275 +-------------- avro/src/serde/deser_schema/tuple.rs | 31 +- avro/src/serde/deser_schema/union.rs | 450 ------------------------ avro/src/serde/mod.rs | 4 + avro/src/serde/ser_schema/mod.rs | 2 +- avro/src/serde/with.rs | 7 +- avro/tests/avro-rs-285-bytes_deserialization.rs | 4 +- avro/tests/schema.rs | 8 +- avro/tests/union_schema.rs | 4 +- avro_derive/src/attributes/serde.rs | 4 + avro_derive/tests/serde.rs | 6 +- 25 files changed, 708 insertions(+), 1198 deletions(-) diff --git a/avro/src/duration.rs b/avro/src/duration.rs index eecfca1..564ee72 100644 --- a/avro/src/duration.rs +++ b/avro/src/duration.rs @@ -209,6 +209,7 @@ mod tests { use crate::types::Value; use apache_avro_test_helper::TestResult; + #[expect(deprecated, reason = "This tests the deprecated function")] #[test] fn avro_rs_382_duration_from_value() -> TestResult { let val = Value::Duration(Duration::new(Months::new(7), Days::new(4), Millis::new(45))); diff --git a/avro/src/error.rs b/avro/src/error.rs index 23db16b..4482280 100644 --- a/avro/src/error.rs +++ b/avro/src/error.rs @@ -586,7 +586,7 @@ pub enum Details { }, #[error("Only expected `deserialize_identifier` to be called but `{0}` was called")] - DeserializeKey(String), + DeserializeIdentifier(&'static str), #[error("Failed to write buffer bytes during flush: {0}")] WriteBytes(#[source] std::io::Error), diff --git a/avro/src/lib.rs b/avro/src/lib.rs index 314d596..e91464b 100644 --- a/avro/src/lib.rs +++ b/avro/src/lib.rs @@ -101,6 +101,10 @@ pub use reader::{ single_object::{GenericSingleObjectReader, SpecificSingleObjectReader}, }; pub use schema::Schema; +#[expect( + deprecated, + reason = "Still need to export it until we remove it completely" +)] pub use serde::{AvroSchema, AvroSchemaComponent, from_value, to_value}; pub use uuid::Uuid; #[expect( diff --git a/avro/src/reader/datum.rs b/avro/src/reader/datum.rs index cbf6fd7..f72b272 100644 --- a/avro/src/reader/datum.rs +++ b/avro/src/reader/datum.rs @@ -37,6 +37,7 @@ pub struct GenericDatumReader<'s> { writer: &'s Schema, resolved: ResolvedSchema<'s>, reader: Option<(&'s Schema, ResolvedSchema<'s>)>, + human_readable: bool, } #[bon] @@ -57,6 +58,7 @@ impl<'s> GenericDatumReader<'s> { reader_schema: Option<&'s Schema>, /// Already resolved schemata that will be used to resolve references in the reader's schema. resolved_reader_schemata: Option<ResolvedSchema<'s>>, + #[builder(default)] human_readable: bool, ) -> AvroResult<Self> { let resolved_writer_schemata = if let Some(resolved) = resolved_writer_schemata { resolved @@ -78,6 +80,7 @@ impl<'s> GenericDatumReader<'s> { writer: writer_schema, resolved: resolved_writer_schemata, reader, + human_readable, }) } } @@ -131,6 +134,17 @@ impl<'s> GenericDatumReader<'s> { Ok(value) } } + + pub fn read_deser<R: Read, T: DeserializeOwned>(&self, reader: &mut R) -> AvroResult<T> { + T::deserialize(SchemaAwareDeserializer::new( + reader, + self.writer, + Config { + names: self.resolved.get_names(), + human_readable: self.human_readable, + }, + )?) + } } pub struct SpecificDatumReader<T: DeserializeOwned> { @@ -253,7 +267,7 @@ mod tests { use serde::Deserialize; use crate::{ - Schema, from_value, + Schema, reader::datum::GenericDatumReader, types::{Record, Value}, }; @@ -298,7 +312,7 @@ mod tests { const TEST_RECORD_SCHEMA_3240: &str = r#" { "type": "record", - "name": "test", + "name": "TestRecord3240", "fields": [ { "name": "a", @@ -340,24 +354,15 @@ mod tests { let schema = Schema::parse_str(TEST_RECORD_SCHEMA_3240)?; let mut encoded: &'static [u8] = &[54, 6, 102, 111, 111]; - let expected_record: TestRecord3240 = TestRecord3240 { - a: 27i64, - b: String::from("foo"), - a_nullable_array: None, - a_nullable_string: None, - }; - - let avro_datum = GenericDatumReader::builder(&schema) + let error = GenericDatumReader::builder(&schema) .build()? - .read_value(&mut encoded)?; - let parsed_record: TestRecord3240 = match &avro_datum { - Value::Record(_) => from_value::<TestRecord3240>(&avro_datum)?, - unexpected => { - panic!("could not map avro data to struct, found unexpected: {unexpected:?}") - } - }; + .read_deser::<_, TestRecord3240>(&mut encoded) + .unwrap_err(); - assert_eq!(parsed_record, expected_record); + assert_eq!( + error.to_string(), + "Failed to read bytes for decoding variable length integer: failed to fill whole buffer" + ); Ok(()) } diff --git a/avro/src/reader/single_object.rs b/avro/src/reader/single_object.rs index 6e0b744..2f1ef2b 100644 --- a/avro/src/reader/single_object.rs +++ b/avro/src/reader/single_object.rs @@ -21,7 +21,7 @@ use crate::headers::{HeaderBuilder, RabinFingerprintHeader}; use crate::schema::ResolvedOwnedSchema; use crate::serde::deser_schema::{Config, SchemaAwareDeserializer}; use crate::types::Value; -use crate::{AvroResult, AvroSchema, Schema, from_value}; +use crate::{AvroResult, AvroSchema, Schema}; use serde::de::DeserializeOwned; use std::io::Read; use std::marker::PhantomData; @@ -129,7 +129,7 @@ where T: AvroSchema + DeserializeOwned, { pub fn read<R: Read>(&self, reader: &mut R) -> AvroResult<T> { - from_value::<T>(&self.inner.read_value(reader)?) + self.inner.read_deser(reader) } } @@ -156,7 +156,7 @@ mod tests { let schema = r#" { "type":"record", - "name":"TestSingleObjectWrtierSerialize", + "name":"TestSingleObjectReader", "fields":[ { "name":"a", diff --git a/avro/src/schema/mod.rs b/avro/src/schema/mod.rs index 05908f3..7e17e60 100644 --- a/avro/src/schema/mod.rs +++ b/avro/src/schema/mod.rs @@ -3476,6 +3476,7 @@ mod tests { Ok(()) } + #[expect(deprecated, reason = "Schema resolution is a WIP")] #[test] fn test_avro_3814_schema_resolution_failure() -> TestResult { // Define a reader schema: a nested record with an optional field. diff --git a/avro/src/serde/de.rs b/avro/src/serde/de.rs index b05e065..aefe7b9 100644 --- a/avro/src/serde/de.rs +++ b/avro/src/serde/de.rs @@ -944,6 +944,7 @@ pub fn from_value<'de, D: Deserialize<'de>>(value: &'de Value) -> Result<D, Erro D::deserialize(de) } +#[expect(deprecated, reason = "This tests the deprecated function")] #[cfg(test)] mod tests { use num_bigint::BigInt; diff --git a/avro/src/serde/deser_schema/array.rs b/avro/src/serde/deser_schema/array.rs index 08cb055..615f0c4 100644 --- a/avro/src/serde/deser_schema/array.rs +++ b/avro/src/serde/deser_schema/array.rs @@ -21,7 +21,6 @@ use std::io::Read; use serde::de::SeqAccess; use super::Config; -use crate::serde::deser_schema::union::UnionDeserializer; use crate::{ Error, Schema, schema::ArraySchema, serde::deser_schema::SchemaAwareDeserializer, util::zag_i32, }; @@ -71,16 +70,11 @@ impl<'de, 's, 'r, R: Read, S: Borrow<Schema>> SeqAccess<'de> for ArrayDeserializ self.next_element_seed(seed) } State::ReadingValue(mut remaining) => { - let v = match self.schema.items.as_ref() { - Schema::Union(union) => { - seed.deserialize(UnionDeserializer::new(self.reader, union, self.config)?)? - } - schema => seed.deserialize(SchemaAwareDeserializer::new( - self.reader, - schema, - self.config, - )?)?, - }; + let v = seed.deserialize(SchemaAwareDeserializer::new( + self.reader, + &self.schema.items, + self.config, + )?)?; remaining -= 1; if remaining == 0 { diff --git a/avro/src/serde/deser_schema/enums/mod.rs b/avro/src/serde/deser_schema/enums/mod.rs index 65a8e66..7803cad 100644 --- a/avro/src/serde/deser_schema/enums/mod.rs +++ b/avro/src/serde/deser_schema/enums/mod.rs @@ -1,3 +1,5 @@ mod plain; +mod union; pub use plain::PlainEnumAccess; +pub use union::UnionEnumAccess; diff --git a/avro/src/serde/deser_schema/enums/plain.rs b/avro/src/serde/deser_schema/enums/plain.rs index 65b5b81..f419de3 100644 --- a/avro/src/serde/deser_schema/enums/plain.rs +++ b/avro/src/serde/deser_schema/enums/plain.rs @@ -1,8 +1,8 @@ +use crate::Error; use crate::error::Details; use crate::schema::EnumSchema; +use crate::serde::deser_schema::identifier::IdentifierDeserializer; use crate::util::zag_i32; -use crate::{Error, Schema}; -use serde::Deserializer; use serde::de::{DeserializeSeed, EnumAccess, Unexpected, VariantAccess, Visitor}; use std::io::Read; @@ -25,11 +25,21 @@ impl<'de, 's, 'r, R: Read> EnumAccess<'de> for PlainEnumAccess<'s, 'r, R> { where V: DeserializeSeed<'de>, { - let deserializer = EnumIdentifierDeserializer { - reader: self.reader, - schema: self.schema, - }; - Ok((seed.deserialize(deserializer)?, self)) + let orig_index = zag_i32(self.reader)?; + let index = + usize::try_from(orig_index).map_err(|e| Details::ConvertI32ToUsize(e, orig_index))?; + let symbol = self + .schema + .symbols + .get(index) + .ok_or(Details::EnumSymbolIndex { + index: index as usize, + num_variants: self.schema.symbols.len(), + })?; + Ok(( + seed.deserialize(IdentifierDeserializer::string(symbol))?, + self, + )) } } @@ -68,257 +78,3 @@ impl<'de, 's, 'r, R: Read> VariantAccess<'de> for PlainEnumAccess<'s, 'r, R> { Err(serde::de::Error::invalid_type(unexp, &"newtype variant")) } } - -struct EnumIdentifierDeserializer<'s, 'r, R: Read> { - schema: &'s EnumSchema, - reader: &'r mut R, -} - -impl<'s, 'r, R: Read> EnumIdentifierDeserializer<'s, 'r, R> { - fn error(&self, error: impl Into<String>) -> Error { - Error::new(Details::DeserializeValueWithSchema { - value_type: "enum", - value: error.into(), - schema: Schema::Enum(self.schema.clone()), - }) - } -} - -impl<'de, 's, 'r, R: Read> Deserializer<'de> for EnumIdentifierDeserializer<'s, 'r, R> { - type Error = Error; - - fn deserialize_any<V>(self, _visitor: V) -> Result<V::Value, Self::Error> - where - V: Visitor<'de>, - { - Err(self.error("Unexpected deserialize_any, expected deserialize_identifier")) - } - - fn deserialize_bool<V>(self, _visitor: V) -> Result<V::Value, Self::Error> - where - V: Visitor<'de>, - { - Err(self.error("Unexpected deserialize_bool, expected deserialize_identifier")) - } - - fn deserialize_i8<V>(self, _visitor: V) -> Result<V::Value, Self::Error> - where - V: Visitor<'de>, - { - Err(self.error("Unexpected deserialize_i8, expected deserialize_identifier")) - } - - fn deserialize_i16<V>(self, _visitor: V) -> Result<V::Value, Self::Error> - where - V: Visitor<'de>, - { - Err(self.error("Unexpected deserialize_i16, expected deserialize_identifier")) - } - - fn deserialize_i32<V>(self, _visitor: V) -> Result<V::Value, Self::Error> - where - V: Visitor<'de>, - { - Err(self.error("Unexpected deserialize_i32, expected deserialize_identifier")) - } - - fn deserialize_i64<V>(self, _visitor: V) -> Result<V::Value, Self::Error> - where - V: Visitor<'de>, - { - Err(self.error("Unexpected deserialize_i64, expected deserialize_identifier")) - } - - fn deserialize_u8<V>(self, _visitor: V) -> Result<V::Value, Self::Error> - where - V: Visitor<'de>, - { - Err(self.error("Unexpected deserialize_u8, expected deserialize_identifier")) - } - - fn deserialize_u16<V>(self, _visitor: V) -> Result<V::Value, Self::Error> - where - V: Visitor<'de>, - { - Err(self.error("Unexpected deserialize_u16, expected deserialize_identifier")) - } - - fn deserialize_u32<V>(self, _visitor: V) -> Result<V::Value, Self::Error> - where - V: Visitor<'de>, - { - Err(self.error("Unexpected deserialize_u32, expected deserialize_identifier")) - } - - fn deserialize_u64<V>(self, _visitor: V) -> Result<V::Value, Self::Error> - where - V: Visitor<'de>, - { - Err(self.error("Unexpected deserialize_u64, expected deserialize_identifier")) - } - - fn deserialize_f32<V>(self, _visitor: V) -> Result<V::Value, Self::Error> - where - V: Visitor<'de>, - { - Err(self.error("Unexpected deserialize_f32, expected deserialize_identifier")) - } - - fn deserialize_f64<V>(self, _visitor: V) -> Result<V::Value, Self::Error> - where - V: Visitor<'de>, - { - Err(self.error("Unexpected deserialize_f64, expected deserialize_identifier")) - } - - fn deserialize_char<V>(self, _visitor: V) -> Result<V::Value, Self::Error> - where - V: Visitor<'de>, - { - Err(self.error("Unexpected deserialize_char, expected deserialize_identifier")) - } - - fn deserialize_str<V>(self, _visitor: V) -> Result<V::Value, Self::Error> - where - V: Visitor<'de>, - { - Err(self.error("Unexpected deserialize_str, expected deserialize_identifier")) - } - - fn deserialize_string<V>(self, _visitor: V) -> Result<V::Value, Self::Error> - where - V: Visitor<'de>, - { - Err(self.error("Unexpected deserialize_string, expected deserialize_identifier")) - } - - fn deserialize_bytes<V>(self, _visitor: V) -> Result<V::Value, Self::Error> - where - V: Visitor<'de>, - { - Err(self.error("Unexpected deserialize_bytes, expected deserialize_identifier")) - } - - fn deserialize_byte_buf<V>(self, _visitor: V) -> Result<V::Value, Self::Error> - where - V: Visitor<'de>, - { - Err(self.error("Unexpected deserialize_byte_buf, expected deserialize_identifier")) - } - - fn deserialize_option<V>(self, _visitor: V) -> Result<V::Value, Self::Error> - where - V: Visitor<'de>, - { - Err(self.error("Unexpected deserialize_option, expected deserialize_identifier")) - } - - fn deserialize_unit<V>(self, _visitor: V) -> Result<V::Value, Self::Error> - where - V: Visitor<'de>, - { - Err(self.error("Unexpected deserialize_unit, expected deserialize_identifier")) - } - - fn deserialize_unit_struct<V>( - self, - _name: &'static str, - _visitor: V, - ) -> Result<V::Value, Self::Error> - where - V: Visitor<'de>, - { - Err(self.error("Unexpected deserialize_unit_struct, expected deserialize_identifier")) - } - - fn deserialize_newtype_struct<V>( - self, - _name: &'static str, - _visitor: V, - ) -> Result<V::Value, Self::Error> - where - V: Visitor<'de>, - { - Err(self.error("Unexpected deserialize_newtype_struct, expected deserialize_identifier")) - } - - fn deserialize_seq<V>(self, _visitor: V) -> Result<V::Value, Self::Error> - where - V: Visitor<'de>, - { - Err(self.error("Unexpected deserialize_seq, expected deserialize_identifier")) - } - - fn deserialize_tuple<V>(self, _len: usize, _visitor: V) -> Result<V::Value, Self::Error> - where - V: Visitor<'de>, - { - Err(self.error("Unexpected deserialize_tuple, expected deserialize_identifier")) - } - - fn deserialize_tuple_struct<V>( - self, - _name: &'static str, - _len: usize, - _visitor: V, - ) -> Result<V::Value, Self::Error> - where - V: Visitor<'de>, - { - Err(self.error("Unexpected deserialize_tuple_struct, expected deserialize_identifier")) - } - - fn deserialize_map<V>(self, _visitor: V) -> Result<V::Value, Self::Error> - where - V: Visitor<'de>, - { - Err(self.error("Unexpected deserialize_map, expected deserialize_identifier")) - } - - fn deserialize_struct<V>( - self, - _name: &'static str, - _fields: &'static [&'static str], - _visitor: V, - ) -> Result<V::Value, Self::Error> - where - V: Visitor<'de>, - { - Err(self.error("Unexpected deserialize_struct, expected deserialize_identifier")) - } - - fn deserialize_enum<V>( - self, - _name: &'static str, - _variants: &'static [&'static str], - _visitor: V, - ) -> Result<V::Value, Self::Error> - where - V: Visitor<'de>, - { - Err(self.error("Unexpected deserialize_enum, expected deserialize_identifier")) - } - - fn deserialize_identifier<V>(self, visitor: V) -> Result<V::Value, Self::Error> - where - V: Visitor<'de>, - { - let index = zag_i32(self.reader)?; - let symbol = self - .schema - .symbols - .get(index as usize) - .ok_or(Details::EnumSymbolIndex { - index: index as usize, - num_variants: self.schema.symbols.len(), - })?; - visitor.visit_str(symbol) - } - - fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value, Self::Error> - where - V: Visitor<'de>, - { - self.deserialize_any(visitor) - } -} diff --git a/avro/src/serde/deser_schema/enums/union.rs b/avro/src/serde/deser_schema/enums/union.rs new file mode 100644 index 0000000..be09d2a --- /dev/null +++ b/avro/src/serde/deser_schema/enums/union.rs @@ -0,0 +1,142 @@ +use std::{borrow::Borrow, io::Read}; + +use serde::{ + Deserializer, + de::{EnumAccess, Unexpected, VariantAccess}, +}; + +use crate::{ + Error, Schema, + error::Details, + schema::UnionSchema, + serde::deser_schema::{ + Config, DESERIALIZE_ANY, SchemaAwareDeserializer, identifier::IdentifierDeserializer, + }, + util::zag_i32, +}; + +pub struct UnionEnumAccess<'s, 'r, R: Read, S: Borrow<Schema>> { + schema: &'s UnionSchema, + reader: &'r mut R, + config: Config<'s, S>, +} + +impl<'s, 'r, R: Read, S: Borrow<Schema>> UnionEnumAccess<'s, 'r, R, S> { + pub fn new(schema: &'s UnionSchema, reader: &'r mut R, config: Config<'s, S>) -> Self { + Self { + schema, + reader, + config, + } + } +} + +impl<'de, 's, 'r, R: Read, S: Borrow<Schema>> EnumAccess<'de> for UnionEnumAccess<'s, 'r, R, S> { + type Error = Error; + + type Variant = UnionVariantAccess<'s, 'r, R, S>; + + fn variant_seed<V>(self, seed: V) -> Result<(V::Value, Self::Variant), Self::Error> + where + V: serde::de::DeserializeSeed<'de>, + { + let orig_index = zag_i32(self.reader)?; + let index = + usize::try_from(orig_index).map_err(|e| Details::ConvertI32ToUsize(e, orig_index))?; + + let schema = self + .schema + .variants() + .get(index) + .ok_or(Details::GetUnionVariant { + index: orig_index as i64, + num_variants: self.schema.variants().len(), + })?; + + Ok(( + seed.deserialize(IdentifierDeserializer::index(index as u32))?, + UnionVariantAccess::new(schema, self.reader, self.config), + )) + } +} + +pub struct UnionVariantAccess<'s, 'r, R: Read, S: Borrow<Schema>> { + schema: &'s Schema, + reader: &'r mut R, + config: Config<'s, S>, +} + +impl<'s, 'r, R: Read, S: Borrow<Schema>> UnionVariantAccess<'s, 'r, R, S> { + fn new(schema: &'s Schema, reader: &'r mut R, config: Config<'s, S>) -> Self { + Self { + schema, + reader, + config, + } + } +} + +impl<'de, 's, 'r, R: Read, S: Borrow<Schema>> VariantAccess<'de> + for UnionVariantAccess<'s, 'r, R, S> +{ + type Error = Error; + + fn unit_variant(self) -> Result<(), Self::Error> { + if let Schema::Null = self.schema { + Ok(()) + } else if let Schema::Record(record) = self.schema + && record.fields.is_empty() + { + Ok(()) + } else { + let unexp = Unexpected::UnitVariant; + Err(serde::de::Error::invalid_type(unexp, &"other variant")) + } + } + + fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value, Self::Error> + where + T: serde::de::DeserializeSeed<'de>, + { + if let Schema::Record(record) = self.schema + && record.fields.len() == 1 + && record.fields[0].name == "field_0" + { + // Most likely a Union of Records + seed.deserialize(SchemaAwareDeserializer::new( + self.reader, + &record.fields[0].schema, + self.config, + )?) + } else { + seed.deserialize(SchemaAwareDeserializer::new( + self.reader, + self.schema, + self.config, + )?) + } + } + + fn tuple_variant<V>(self, len: usize, visitor: V) -> Result<V::Value, Self::Error> + where + V: serde::de::Visitor<'de>, + { + SchemaAwareDeserializer::new(self.reader, self.schema, self.config)? + .deserialize_tuple(len, visitor) + } + + fn struct_variant<V>( + self, + fields: &'static [&'static str], + visitor: V, + ) -> Result<V::Value, Self::Error> + where + V: serde::de::Visitor<'de>, + { + SchemaAwareDeserializer::new(self.reader, self.schema, self.config)?.deserialize_struct( + DESERIALIZE_ANY, + fields, + visitor, + ) + } +} diff --git a/avro/src/serde/deser_schema/identifier.rs b/avro/src/serde/deser_schema/identifier.rs new file mode 100644 index 0000000..1082a18 --- /dev/null +++ b/avro/src/serde/deser_schema/identifier.rs @@ -0,0 +1,255 @@ +use crate::{Error, error::Details}; +use serde::{Deserializer, de::Visitor}; + +/// Deserializer that only accepts `deserialize_identifier` calls. +pub enum IdentifierDeserializer<'s> { + String(&'s str), + Index(u32), +} + +impl<'s> IdentifierDeserializer<'s> { + pub fn string(name: &'s str) -> Self { + Self::String(name) + } + + pub fn index(index: u32) -> Self { + Self::Index(index) + } + + fn error(&self, error: &'static str) -> Error { + Error::new(Details::DeserializeIdentifier(error)) + } +} + +impl<'de, 's> Deserializer<'de> for IdentifierDeserializer<'s> { + type Error = Error; + + fn deserialize_any<V>(self, _visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de>, + { + Err(self.error("deserialize_any")) + } + + fn deserialize_bool<V>(self, _visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de>, + { + Err(self.error("deserialize_bool")) + } + + fn deserialize_i8<V>(self, _visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de>, + { + Err(self.error("deserialize_i8")) + } + + fn deserialize_i16<V>(self, _visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de>, + { + Err(self.error("deserialize_i16")) + } + + fn deserialize_i32<V>(self, _visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de>, + { + Err(self.error("deserialize_i32")) + } + + fn deserialize_i64<V>(self, _visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de>, + { + Err(self.error("deserialize_i64")) + } + + fn deserialize_u8<V>(self, _visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de>, + { + Err(self.error("deserialize_u8")) + } + + fn deserialize_u16<V>(self, _visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de>, + { + Err(self.error("deserialize_u16")) + } + + fn deserialize_u32<V>(self, _visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de>, + { + Err(self.error("deserialize_u32")) + } + + fn deserialize_u64<V>(self, _visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de>, + { + Err(self.error("deserialize_u64")) + } + + fn deserialize_f32<V>(self, _visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de>, + { + Err(self.error("deserialize_f32")) + } + + fn deserialize_f64<V>(self, _visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de>, + { + Err(self.error("deserialize_f64")) + } + + fn deserialize_char<V>(self, _visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de>, + { + Err(self.error("deserialize_char")) + } + + fn deserialize_str<V>(self, _visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de>, + { + Err(self.error("deserialize_str")) + } + + fn deserialize_string<V>(self, _visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de>, + { + Err(self.error("deserialize_string")) + } + + fn deserialize_bytes<V>(self, _visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de>, + { + Err(self.error("deserialize_bytes")) + } + + fn deserialize_byte_buf<V>(self, _visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de>, + { + Err(self.error("deserialize_byte_buf")) + } + + fn deserialize_option<V>(self, _visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de>, + { + Err(self.error("deserialize_option")) + } + + fn deserialize_unit<V>(self, _visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de>, + { + Err(self.error("deserialize_unit")) + } + + fn deserialize_unit_struct<V>( + self, + _name: &'static str, + _visitor: V, + ) -> Result<V::Value, Self::Error> + where + V: Visitor<'de>, + { + Err(self.error("deserialize_unit_struct")) + } + + fn deserialize_newtype_struct<V>( + self, + _name: &'static str, + _visitor: V, + ) -> Result<V::Value, Self::Error> + where + V: Visitor<'de>, + { + Err(self.error("deserialize_newtype_struct")) + } + + fn deserialize_seq<V>(self, _visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de>, + { + Err(self.error("deserialize_seq")) + } + + fn deserialize_tuple<V>(self, _len: usize, _visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de>, + { + Err(self.error("deserialize_tuple")) + } + + fn deserialize_tuple_struct<V>( + self, + _name: &'static str, + _len: usize, + _visitor: V, + ) -> Result<V::Value, Self::Error> + where + V: Visitor<'de>, + { + Err(self.error("deserialize_tuple_struct")) + } + + fn deserialize_map<V>(self, _visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de>, + { + Err(self.error("deserialize_map")) + } + + fn deserialize_struct<V>( + self, + _name: &'static str, + _fields: &'static [&'static str], + _visitor: V, + ) -> Result<V::Value, Self::Error> + where + V: Visitor<'de>, + { + Err(self.error("deserialize_struct")) + } + + fn deserialize_enum<V>( + self, + _name: &'static str, + _variants: &'static [&'static str], + _visitor: V, + ) -> Result<V::Value, Self::Error> + where + V: Visitor<'de>, + { + Err(self.error("deserialize_enum")) + } + + fn deserialize_identifier<V>(self, visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de>, + { + match self { + IdentifierDeserializer::String(string) => visitor.visit_str(string), + IdentifierDeserializer::Index(index) => visitor.visit_u32(index), + } + } + + fn deserialize_ignored_any<V>(self, _visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de>, + { + Err(self.error("deserialize_ignored_any")) + } +} diff --git a/avro/src/serde/deser_schema/map.rs b/avro/src/serde/deser_schema/map.rs index 53b7ed5..fe12a7f 100644 --- a/avro/src/serde/deser_schema/map.rs +++ b/avro/src/serde/deser_schema/map.rs @@ -21,7 +21,6 @@ use std::io::Read; use serde::de::MapAccess; use super::Config; -use crate::serde::deser_schema::union::UnionDeserializer; use crate::{ Error, Schema, schema::MapSchema, serde::deser_schema::SchemaAwareDeserializer, util::zag_i32, }; @@ -96,16 +95,11 @@ impl<'de, 's, 'r, R: Read, S: Borrow<Schema>> MapAccess<'de> for MapDeserializer let State::ReadingValue(mut remaining) = self.state else { panic!("`next_key_seed` and `next_value_seed` where called in the wrong error") }; - let v = match self.schema.types.as_ref() { - Schema::Union(union) => { - seed.deserialize(UnionDeserializer::new(self.reader, union, self.config)?)? - } - schema => seed.deserialize(SchemaAwareDeserializer::new( - self.reader, - schema, - self.config, - )?)?, - }; + let v = seed.deserialize(SchemaAwareDeserializer::new( + self.reader, + &self.schema.types, + self.config, + )?)?; remaining -= 1; if remaining == 0 { diff --git a/avro/src/serde/deser_schema/mod.rs b/avro/src/serde/deser_schema/mod.rs index 0e07518..fc6c560 100644 --- a/avro/src/serde/deser_schema/mod.rs +++ b/avro/src/serde/deser_schema/mod.rs @@ -17,10 +17,10 @@ mod array; mod enums; +mod identifier; mod map; mod record; mod tuple; -mod union; use std::borrow::Borrow; use std::collections::HashMap; @@ -29,12 +29,11 @@ use std::io::Read; use serde::Deserializer; -use crate::schema::Name; -use crate::serde::deser_schema::enums::PlainEnumAccess; +use crate::schema::{Name, UnionSchema}; +use crate::serde::deser_schema::enums::{PlainEnumAccess, UnionEnumAccess}; use crate::serde::deser_schema::tuple::{ ManyTupleDeserializer, OneTupleDeserializer, UnitTupleDeserializer, }; -use crate::serde::deser_schema::union::UnionDeserializer; use crate::{ Error, Schema, decode::decode_len, @@ -56,13 +55,12 @@ pub struct Config<'s, S: Borrow<Schema>> { pub human_readable: bool, } +// This needs to be implemented manually as the derive puts a bound on `S` +// which is not necessary as a reference is always Copy. impl<'s, S: Borrow<Schema>> Copy for Config<'s, S> {} impl<'s, S: Borrow<Schema>> Clone for Config<'s, S> { fn clone(&self) -> Self { - Self { - names: self.names, - human_readable: self.human_readable, - } + *self } } @@ -127,6 +125,22 @@ impl<'s, 'r, R: Read, S: Borrow<Schema>> SchemaAwareDeserializer<'s, 'r, R, S> { Ok(self) } + /// Read the union and create a new deserializer with the existing reader and config. + /// + /// This will resolve the read schema if it is a reference. + fn with_union(self, schema: &'s UnionSchema) -> Result<Self, Error> { + let index = zag_i32(self.reader)?; + let variant = + schema + .variants() + .get(index as usize) + .ok_or_else(|| Details::GetUnionVariant { + index: index as i64, + num_variants: schema.variants().len(), + })?; + self.with_different_schema(variant) + } + fn read_int(&mut self, original_ty: &'static str) -> Result<i32, Error> { match self.schema { Schema::Int | Schema::Date | Schema::TimeMillis => zag_i32(self.reader), @@ -203,9 +217,7 @@ impl<'de, 's, 'r, R: Read, S: Borrow<Schema>> Deserializer<'de> Schema::String | Schema::Uuid(UuidSchema::String) => self.deserialize_string(visitor), Schema::Array(_) => self.deserialize_seq(visitor), Schema::Map(_) => self.deserialize_map(visitor), - Schema::Union(union) => { - UnionDeserializer::new(self.reader, union, self.config)?.deserialize_any(visitor) - } + Schema::Union(union) => self.with_union(union)?.deserialize_any(visitor), Schema::Record(schema) => { if schema.attributes.get("org.apache.avro.rust.tuple") == Some(&serde_json::Value::Bool(true)) @@ -247,6 +259,7 @@ impl<'de, 's, 'r, R: Read, S: Borrow<Schema>> Deserializer<'de> }; visitor.visit_bool(boolean) } + Schema::Union(union) => self.with_union(union)?.deserialize_bool(visitor), _ => Err(self.error("bool", "Expected a Schema::Boolean")), } } @@ -255,83 +268,113 @@ impl<'de, 's, 'r, R: Read, S: Borrow<Schema>> Deserializer<'de> where V: serde::de::Visitor<'de>, { - let int = self.read_int("i8")?; - let value = i8::try_from(int) - .map_err(|_| self.error("i8", format!("Could not convert int ({int}) to an i8")))?; - visitor.visit_i8(value) + if let Schema::Union(union) = self.schema { + self.with_union(union)?.deserialize_i8(visitor) + } else { + let int = self.read_int("i8")?; + let value = i8::try_from(int) + .map_err(|_| self.error("i8", format!("Could not convert int ({int}) to an i8")))?; + visitor.visit_i8(value) + } } fn deserialize_i16<V>(mut self, visitor: V) -> Result<V::Value, Self::Error> where V: serde::de::Visitor<'de>, { - let int = self.read_int("i16")?; - let value = i16::try_from(int) - .map_err(|_| self.error("i16", format!("Could not convert int ({int}) to an i16")))?; - visitor.visit_i16(value) + if let Schema::Union(union) = self.schema { + self.with_union(union)?.deserialize_i16(visitor) + } else { + let int = self.read_int("i16")?; + let value = i16::try_from(int).map_err(|_| { + self.error("i16", format!("Could not convert int ({int}) to an i16")) + })?; + visitor.visit_i16(value) + } } fn deserialize_i32<V>(mut self, visitor: V) -> Result<V::Value, Self::Error> where V: serde::de::Visitor<'de>, { - let value = self.read_int("i32")?; - visitor.visit_i32(value) + if let Schema::Union(union) = self.schema { + self.with_union(union)?.deserialize_i32(visitor) + } else { + let value = self.read_int("i32")?; + visitor.visit_i32(value) + } } fn deserialize_i64<V>(mut self, visitor: V) -> Result<V::Value, Self::Error> where V: serde::de::Visitor<'de>, { - let value = self.read_long("i64")?; - visitor.visit_i64(value) + if let Schema::Union(union) = self.schema { + self.with_union(union)?.deserialize_i64(visitor) + } else { + let value = self.read_long("i64")?; + visitor.visit_i64(value) + } } fn deserialize_u8<V>(mut self, visitor: V) -> Result<V::Value, Self::Error> where V: serde::de::Visitor<'de>, { - let int = self.read_int("u8")?; - let value = u8::try_from(int) - .map_err(|_| self.error("u8", format!("Could not convert int ({int}) to an u8")))?; - visitor.visit_u8(value) + if let Schema::Union(union) = self.schema { + self.with_union(union)?.deserialize_u8(visitor) + } else { + let int = self.read_int("u8")?; + let value = u8::try_from(int) + .map_err(|_| self.error("u8", format!("Could not convert int ({int}) to an u8")))?; + visitor.visit_u8(value) + } } fn deserialize_u16<V>(mut self, visitor: V) -> Result<V::Value, Self::Error> where V: serde::de::Visitor<'de>, { - let int = self.read_int("u16")?; - let value = u16::try_from(int) - .map_err(|_| self.error("u16", format!("Could not convert int ({int}) to an u16")))?; - visitor.visit_u16(value) + if let Schema::Union(union) = self.schema { + self.with_union(union)?.deserialize_u16(visitor) + } else { + let int = self.read_int("u16")?; + let value = u16::try_from(int).map_err(|_| { + self.error("u16", format!("Could not convert int ({int}) to an u16")) + })?; + visitor.visit_u16(value) + } } fn deserialize_u32<V>(mut self, visitor: V) -> Result<V::Value, Self::Error> where V: serde::de::Visitor<'de>, { - let int = self.read_long("u32")?; - let value = u32::try_from(int) - .map_err(|_| self.error("u32", format!("Could not convert int ({int}) to an u32")))?; - visitor.visit_u32(value) + if let Schema::Union(union) = self.schema { + self.with_union(union)?.deserialize_u32(visitor) + } else { + let int = self.read_long("u32")?; + let value = u32::try_from(int).map_err(|_| { + self.error("u32", format!("Could not convert int ({int}) to an u32")) + })?; + visitor.visit_u32(value) + } } fn deserialize_u64<V>(self, visitor: V) -> Result<V::Value, Self::Error> where V: serde::de::Visitor<'de>, { - if let Schema::Fixed(fixed) = self.schema - && fixed.size == 8 - && fixed.name.name() == "u64" - { - let mut buf = [0; 8]; - self.reader - .read_exact(&mut buf) - .map_err(Details::ReadBytes)?; - visitor.visit_u64(u64::from_le_bytes(buf)) - } else { - Err(self.error("u64", r#"Expected Schema::Fixed(name: "u64", size: 8)"#)) + match self.schema { + Schema::Fixed(fixed) if fixed.size == 8 && fixed.name.name() == "u64" => { + let mut buf = [0; 8]; + self.reader + .read_exact(&mut buf) + .map_err(Details::ReadBytes)?; + visitor.visit_u64(u64::from_le_bytes(buf)) + } + Schema::Union(union) => self.with_union(union)?.deserialize_u64(visitor), + _ => Err(self.error("u64", r#"Expected Schema::Fixed(name: "u64", size: 8)"#)), } } @@ -339,14 +382,16 @@ impl<'de, 's, 'r, R: Read, S: Borrow<Schema>> Deserializer<'de> where V: serde::de::Visitor<'de>, { - if let Schema::Float = self.schema { - let mut buf = [0; 4]; - self.reader - .read_exact(&mut buf) - .map_err(Details::ReadBytes)?; - visitor.visit_f32(f32::from_le_bytes(buf)) - } else { - Err(self.error("f32", r#"Expected Schema::Float)"#)) + match self.schema { + Schema::Float => { + let mut buf = [0; 4]; + self.reader + .read_exact(&mut buf) + .map_err(Details::ReadBytes)?; + visitor.visit_f32(f32::from_le_bytes(buf)) + } + Schema::Union(union) => self.with_union(union)?.deserialize_f32(visitor), + _ => Err(self.error("f32", r#"Expected Schema::Float)"#)), } } @@ -354,14 +399,16 @@ impl<'de, 's, 'r, R: Read, S: Borrow<Schema>> Deserializer<'de> where V: serde::de::Visitor<'de>, { - if let Schema::Double = self.schema { - let mut buf = [0; 8]; - self.reader - .read_exact(&mut buf) - .map_err(Details::ReadBytes)?; - visitor.visit_f64(f64::from_le_bytes(buf)) - } else { - Err(self.error("f64", r#"Expected Schema::Double)"#)) + match self.schema { + Schema::Double => { + let mut buf = [0; 8]; + self.reader + .read_exact(&mut buf) + .map_err(Details::ReadBytes)?; + visitor.visit_f64(f64::from_le_bytes(buf)) + } + Schema::Union(union) => self.with_union(union)?.deserialize_f64(visitor), + _ => Err(self.error("f64", r#"Expected Schema::Double)"#)), } } @@ -369,20 +416,22 @@ impl<'de, 's, 'r, R: Read, S: Borrow<Schema>> Deserializer<'de> where V: serde::de::Visitor<'de>, { - if let Schema::String = self.schema { - let string = self.read_string()?; - let mut chars = string.chars(); - if let Some(character) = chars.next() { - if chars.next().is_some() { - Err(self.error("char", "String contains more than one character")) + match self.schema { + Schema::String => { + let string = self.read_string()?; + let mut chars = string.chars(); + if let Some(character) = chars.next() { + if chars.next().is_some() { + Err(self.error("char", "String contains more than one character")) + } else { + visitor.visit_char(character) + } } else { - visitor.visit_char(character) + Err(self.error("char", "String is empty")) } - } else { - Err(self.error("char", "String is empty")) } - } else { - Err(self.error("char", "Expected Schema::String")) + Schema::Union(union) => self.with_union(union)?.deserialize_char(visitor), + _ => Err(self.error("char", "Expected Schema::String")), } } @@ -397,11 +446,13 @@ impl<'de, 's, 'r, R: Read, S: Borrow<Schema>> Deserializer<'de> where V: serde::de::Visitor<'de>, { - if let Schema::String = self.schema { - let string = self.read_string()?; - visitor.visit_string(string) - } else { - Err(self.error("string", "Expected Schema::String")) + match self.schema { + Schema::String => { + let string = self.read_string()?; + visitor.visit_string(string) + } + Schema::Union(union) => self.with_union(union)?.deserialize_string(visitor), + _ => Err(self.error("string", "Expected Schema::String")), } } @@ -425,6 +476,9 @@ impl<'de, 's, 'r, R: Read, S: Borrow<Schema>> Deserializer<'de> let bytes = self.read_bytes(fixed.size)?; visitor.visit_byte_buf(bytes) } + Schema::Union(union) => { + self.with_union(union)?.deserialize_byte_buf(visitor) + } _ => Err(self.error("bytes", "Expected Schema::Bytes | Schema::Fixed | Schema::BigDecimal | Schema::Decimal | Schema::Uuid(Fixed) | Schema::Duration")) } } @@ -456,10 +510,10 @@ impl<'de, 's, 'r, R: Read, S: Borrow<Schema>> Deserializer<'de> where V: serde::de::Visitor<'de>, { - if let Schema::Null = self.schema { - visitor.visit_unit() - } else { - Err(self.error("unit", "Expected Schema::Null")) + match self.schema { + Schema::Null => visitor.visit_unit(), + Schema::Union(union) => self.with_union(union)?.deserialize_unit(visitor), + _ => Err(self.error("unit", "Expected Schema::Null")), } } @@ -471,16 +525,15 @@ impl<'de, 's, 'r, R: Read, S: Borrow<Schema>> Deserializer<'de> where V: serde::de::Visitor<'de>, { - if let Schema::Record(record) = self.schema - && record.fields.is_empty() - && record.name.name() == name - { - visitor.visit_unit() - } else { - Err(self.error( + match self.schema { + Schema::Record(record) if record.fields.is_empty() && record.name.name() == name => { + visitor.visit_unit() + } + Schema::Union(union) => self.with_union(union)?.deserialize_unit(visitor), + _ => Err(self.error( "unit struct", format!("Expected Schema::Record(name: {name}, fields.len() == 0)"), - )) + )), } } @@ -492,16 +545,17 @@ impl<'de, 's, 'r, R: Read, S: Borrow<Schema>> Deserializer<'de> where V: serde::de::Visitor<'de>, { - if let Schema::Record(record) = self.schema - && record.fields.len() == 1 - && record.name.name() == name - { - visitor.visit_newtype_struct(self.with_different_schema(&record.fields[0].schema)?) - } else { - Err(self.error( - "unit struct", + match self.schema { + Schema::Record(record) if record.fields.len() == 1 && record.name.name() == name => { + visitor.visit_newtype_struct(self.with_different_schema(&record.fields[0].schema)?) + } + Schema::Union(union) => self + .with_union(union)? + .deserialize_newtype_struct(name, visitor), + _ => Err(self.error( + "newtype struct", format!("Expected Schema::Record(name: {name}, fields.len() == 1)"), - )) + )), } } @@ -509,10 +563,12 @@ impl<'de, 's, 'r, R: Read, S: Borrow<Schema>> Deserializer<'de> where V: serde::de::Visitor<'de>, { - if let Schema::Array(array) = self.schema { - visitor.visit_seq(ArrayDeserializer::new(self.reader, array, self.config)) - } else { - Err(self.error("array", "Expected Schema::Array")) + match self.schema { + Schema::Array(array) => { + visitor.visit_seq(ArrayDeserializer::new(self.reader, array, self.config)) + } + Schema::Union(union) => self.with_union(union)?.deserialize_seq(visitor), + _ => Err(self.error("array", "Expected Schema::Array")), } } @@ -529,6 +585,7 @@ impl<'de, 's, 'r, R: Read, S: Borrow<Schema>> Deserializer<'de> Schema::Record(record) if record.fields.len() == len => { visitor.visit_seq(ManyTupleDeserializer::new(self.reader, record, self.config)) } + Schema::Union(union) => self.with_union(union)?.deserialize_tuple(len, visitor), _ if len == 0 => Err(self.error("tuple", "Expected Schema::Null for unit tuple")), _ => Err(self.error("tuple", format!("Expected Schema::Record for {len}-tuple"))), } @@ -543,16 +600,17 @@ impl<'de, 's, 'r, R: Read, S: Borrow<Schema>> Deserializer<'de> where V: serde::de::Visitor<'de>, { - if let Schema::Record(record) = self.schema - && record.name.name() == name - && record.fields.len() == len - { - visitor.visit_map(RecordDeserializer::new(self.reader, record, self.config)) - } else { - Err(self.error( + match self.schema { + Schema::Record(record) if record.name.name() == name && record.fields.len() == len => { + visitor.visit_map(RecordDeserializer::new(self.reader, record, self.config)) + } + Schema::Union(union) => self + .with_union(union)? + .deserialize_tuple_struct(name, len, visitor), + _ => Err(self.error( "tuple struct", format!("Expected Schema::Record(fields.len() == {len})"), - )) + )), } } @@ -568,6 +626,7 @@ impl<'de, 's, 'r, R: Read, S: Borrow<Schema>> Deserializer<'de> // Needed for flattened structs which are (de)serialized as maps visitor.visit_map(RecordDeserializer::new(self.reader, record, self.config)) } + Schema::Union(union) => self.with_union(union)?.deserialize_map(visitor), _ => Err(self.error("map", "Expected Schema::Map")), } } @@ -575,19 +634,22 @@ impl<'de, 's, 'r, R: Read, S: Borrow<Schema>> Deserializer<'de> fn deserialize_struct<V>( self, name: &'static str, - _fields: &'static [&'static str], + fields: &'static [&'static str], visitor: V, ) -> Result<V::Value, Self::Error> where V: serde::de::Visitor<'de>, { - println!("deserialize_struct(name: {name}, fields = {_fields:?}): {self:?}"); - if let Schema::Record(record) = self.schema - && record.name.name() == name - { - visitor.visit_map(RecordDeserializer::new(self.reader, record, self.config)) - } else { - Err(self.error("struct", format!("Expected Schema::Record(name: {name})"))) + match self.schema { + Schema::Record(record) + if record.name.name() == name || name.as_ptr() == DESERIALIZE_ANY.as_ptr() => + { + visitor.visit_map(RecordDeserializer::new(self.reader, record, self.config)) + } + Schema::Union(union) => self + .with_union(union)? + .deserialize_struct(name, fields, visitor), + _ => Err(self.error("struct", format!("Expected Schema::Record(name: {name})"))), } } @@ -600,9 +662,11 @@ impl<'de, 's, 'r, R: Read, S: Borrow<Schema>> Deserializer<'de> where V: serde::de::Visitor<'de>, { - println!("deserialize_enum(name: {name}, variants: {variants:?}): {self:?}"); match self.schema { Schema::Enum(schema) => visitor.visit_enum(PlainEnumAccess::new(self.reader, schema)), + Schema::Union(schema) => { + visitor.visit_enum(UnionEnumAccess::new(schema, self.reader, self.config)) + } _ => panic!( "deserializing enum, name: {name}, variants: {variants:#?}, {:?}", self.schema @@ -614,17 +678,16 @@ impl<'de, 's, 'r, R: Read, S: Borrow<Schema>> Deserializer<'de> where V: serde::de::Visitor<'de>, { - if let Schema::Fixed(fixed) = self.schema - && fixed.size == 16 - && fixed.name.name() == "i128" - { - let mut buf = [0; 16]; - self.reader - .read_exact(&mut buf) - .map_err(Details::ReadBytes)?; - visitor.visit_i128(i128::from_le_bytes(buf)) - } else { - Err(self.error("i128", r#"Expected Schema::Fixed(name: "i128", size: 16)"#)) + match self.schema { + Schema::Fixed(fixed) if fixed.size == 16 && fixed.name.name() == "i128" => { + let mut buf = [0; 16]; + self.reader + .read_exact(&mut buf) + .map_err(Details::ReadBytes)?; + visitor.visit_i128(i128::from_le_bytes(buf)) + } + Schema::Union(union) => self.with_union(union)?.deserialize_i128(visitor), + _ => Err(self.error("i128", r#"Expected Schema::Fixed(name: "i128", size: 16)"#)), } } @@ -632,17 +695,16 @@ impl<'de, 's, 'r, R: Read, S: Borrow<Schema>> Deserializer<'de> where V: serde::de::Visitor<'de>, { - if let Schema::Fixed(fixed) = self.schema - && fixed.size == 16 - && fixed.name.name() == "u128" - { - let mut buf = [0; 16]; - self.reader - .read_exact(&mut buf) - .map_err(Details::ReadBytes)?; - visitor.visit_u128(u128::from_le_bytes(buf)) - } else { - Err(self.error("u128", r#"Expected Schema::Fixed(name: "u128", size: 16)"#)) + match self.schema { + Schema::Fixed(fixed) if fixed.size == 16 && fixed.name.name() == "u128" => { + let mut buf = [0; 16]; + self.reader + .read_exact(&mut buf) + .map_err(Details::ReadBytes)?; + visitor.visit_u128(u128::from_le_bytes(buf)) + } + Schema::Union(union) => self.with_union(union)?.deserialize_u128(visitor), + _ => Err(self.error("u128", r#"Expected Schema::Fixed(name: "u128", size: 16)"#)), } } diff --git a/avro/src/serde/deser_schema/record.rs b/avro/src/serde/deser_schema/record.rs index 055dc37..86e6000 100644 --- a/avro/src/serde/deser_schema/record.rs +++ b/avro/src/serde/deser_schema/record.rs @@ -19,14 +19,11 @@ use std::borrow::Borrow; use std::fmt::{Debug, Formatter}; use std::io::Read; -use serde::{Deserializer, de::MapAccess}; +use serde::de::MapAccess; use super::Config; -use crate::serde::deser_schema::union::UnionDeserializer; -use crate::{ - Error, Schema, error::Details, schema::RecordSchema, - serde::deser_schema::SchemaAwareDeserializer, -}; +use crate::serde::deser_schema::identifier::IdentifierDeserializer; +use crate::{Error, Schema, schema::RecordSchema, serde::deser_schema::SchemaAwareDeserializer}; pub struct RecordDeserializer<'s, 'r, R: Read, S: Borrow<Schema>> { reader: &'r mut R, @@ -76,9 +73,9 @@ impl<'de, 's, 'r, R: Read, S: Borrow<Schema>> MapAccess<'de> for RecordDeseriali // Finished reading this record Ok(None) } else { - let v = seed.deserialize(FieldName { - name: &self.schema.fields[index].name, - })?; + let v = seed.deserialize(IdentifierDeserializer::string( + &self.schema.fields[index].name, + ))?; self.current_field = State::Value(index); Ok(Some(v)) } @@ -92,16 +89,11 @@ impl<'de, 's, 'r, R: Read, S: Borrow<Schema>> MapAccess<'de> for RecordDeseriali let State::Value(index) = self.current_field else { panic!("`next_key_seed` and `next_value_seed` where called in the wrong error") }; - let v = match &self.schema.fields[index].schema { - Schema::Union(union) => { - seed.deserialize(UnionDeserializer::new(self.reader, union, self.config)?)? - } - schema => seed.deserialize(SchemaAwareDeserializer::new( - self.reader, - schema, - self.config, - )?)?, - }; + let v = seed.deserialize(SchemaAwareDeserializer::new( + self.reader, + &self.schema.fields[index].schema, + self.config, + )?)?; self.current_field = State::Key(index + 1); Ok(v) } @@ -112,248 +104,3 @@ impl<'de, 's, 'r, R: Read, S: Borrow<Schema>> MapAccess<'de> for RecordDeseriali } } } - -/// "Deserializer" for the field name -struct FieldName<'s> { - name: &'s str, -} - -impl<'de, 's> Deserializer<'de> for FieldName<'s> { - type Error = Error; - - fn deserialize_any<V>(self, _visitor: V) -> Result<V::Value, Self::Error> - where - V: serde::de::Visitor<'de>, - { - Err(Details::DeserializeKey("deserialize_any".into()).into()) - } - - fn deserialize_bool<V>(self, _visitor: V) -> Result<V::Value, Self::Error> - where - V: serde::de::Visitor<'de>, - { - Err(Details::DeserializeKey("deserialize_bool".into()).into()) - } - - fn deserialize_i8<V>(self, _visitor: V) -> Result<V::Value, Self::Error> - where - V: serde::de::Visitor<'de>, - { - Err(Details::DeserializeKey("deserialize_i8".into()).into()) - } - - fn deserialize_i16<V>(self, _visitor: V) -> Result<V::Value, Self::Error> - where - V: serde::de::Visitor<'de>, - { - Err(Details::DeserializeKey("deserialize_i16".into()).into()) - } - - fn deserialize_i32<V>(self, _visitor: V) -> Result<V::Value, Self::Error> - where - V: serde::de::Visitor<'de>, - { - Err(Details::DeserializeKey("deserialize_i32".into()).into()) - } - - fn deserialize_i64<V>(self, _visitor: V) -> Result<V::Value, Self::Error> - where - V: serde::de::Visitor<'de>, - { - Err(Details::DeserializeKey("deserialize_i64".into()).into()) - } - - fn deserialize_u8<V>(self, _visitor: V) -> Result<V::Value, Self::Error> - where - V: serde::de::Visitor<'de>, - { - Err(Details::DeserializeKey("deserialize_u8".into()).into()) - } - - fn deserialize_u16<V>(self, _visitor: V) -> Result<V::Value, Self::Error> - where - V: serde::de::Visitor<'de>, - { - Err(Details::DeserializeKey("deserialize_u16".into()).into()) - } - - fn deserialize_u32<V>(self, _visitor: V) -> Result<V::Value, Self::Error> - where - V: serde::de::Visitor<'de>, - { - Err(Details::DeserializeKey("deserialize_u32".into()).into()) - } - - fn deserialize_u64<V>(self, _visitor: V) -> Result<V::Value, Self::Error> - where - V: serde::de::Visitor<'de>, - { - Err(Details::DeserializeKey("deserialize_u64".into()).into()) - } - - fn deserialize_f32<V>(self, _visitor: V) -> Result<V::Value, Self::Error> - where - V: serde::de::Visitor<'de>, - { - Err(Details::DeserializeKey("deserialize_f32".into()).into()) - } - - fn deserialize_f64<V>(self, _visitor: V) -> Result<V::Value, Self::Error> - where - V: serde::de::Visitor<'de>, - { - Err(Details::DeserializeKey("deserialize_f64".into()).into()) - } - - fn deserialize_char<V>(self, _visitor: V) -> Result<V::Value, Self::Error> - where - V: serde::de::Visitor<'de>, - { - Err(Details::DeserializeKey("deserialize_char".into()).into()) - } - - fn deserialize_str<V>(self, _visitor: V) -> Result<V::Value, Self::Error> - where - V: serde::de::Visitor<'de>, - { - Err(Details::DeserializeKey("deserialize_str".into()).into()) - } - - fn deserialize_string<V>(self, _visitor: V) -> Result<V::Value, Self::Error> - where - V: serde::de::Visitor<'de>, - { - Err(Details::DeserializeKey("deserialize_string".into()).into()) - } - - fn deserialize_bytes<V>(self, _visitor: V) -> Result<V::Value, Self::Error> - where - V: serde::de::Visitor<'de>, - { - Err(Details::DeserializeKey("deserialize_bytes".into()).into()) - } - - fn deserialize_byte_buf<V>(self, _visitor: V) -> Result<V::Value, Self::Error> - where - V: serde::de::Visitor<'de>, - { - Err(Details::DeserializeKey("deserialize_byte_buf".into()).into()) - } - - fn deserialize_option<V>(self, _visitor: V) -> Result<V::Value, Self::Error> - where - V: serde::de::Visitor<'de>, - { - Err(Details::DeserializeKey("deserialize_option".into()).into()) - } - - fn deserialize_unit<V>(self, _visitor: V) -> Result<V::Value, Self::Error> - where - V: serde::de::Visitor<'de>, - { - Err(Details::DeserializeKey("deserialize_unit".into()).into()) - } - - fn deserialize_unit_struct<V>( - self, - name: &'static str, - _visitor: V, - ) -> Result<V::Value, Self::Error> - where - V: serde::de::Visitor<'de>, - { - Err(Details::DeserializeKey(format!("deserialize_unit_struct(name: {name})")).into()) - } - - fn deserialize_newtype_struct<V>( - self, - name: &'static str, - _visitor: V, - ) -> Result<V::Value, Self::Error> - where - V: serde::de::Visitor<'de>, - { - Err(Details::DeserializeKey(format!("deserialize_newtype_struct(name: {name})")).into()) - } - - fn deserialize_seq<V>(self, _visitor: V) -> Result<V::Value, Self::Error> - where - V: serde::de::Visitor<'de>, - { - Err(Details::DeserializeKey("deserialize_seq".into()).into()) - } - - fn deserialize_tuple<V>(self, len: usize, _visitor: V) -> Result<V::Value, Self::Error> - where - V: serde::de::Visitor<'de>, - { - Err(Details::DeserializeKey(format!("deserialize_tuple(len: {len})")).into()) - } - - fn deserialize_tuple_struct<V>( - self, - name: &'static str, - len: usize, - _visitor: V, - ) -> Result<V::Value, Self::Error> - where - V: serde::de::Visitor<'de>, - { - Err(Details::DeserializeKey(format!( - "deserialize_tuple_struct(name: {name}, len: {len})" - )) - .into()) - } - - fn deserialize_map<V>(self, _visitor: V) -> Result<V::Value, Self::Error> - where - V: serde::de::Visitor<'de>, - { - Err(Details::DeserializeKey("deserialize_map".into()).into()) - } - - fn deserialize_struct<V>( - self, - name: &'static str, - fields: &'static [&'static str], - _visitor: V, - ) -> Result<V::Value, Self::Error> - where - V: serde::de::Visitor<'de>, - { - Err(Details::DeserializeKey(format!( - "deserialize_struct(name: {name}, fields: {fields:?})" - )) - .into()) - } - - fn deserialize_enum<V>( - self, - name: &'static str, - variants: &'static [&'static str], - _visitor: V, - ) -> Result<V::Value, Self::Error> - where - V: serde::de::Visitor<'de>, - { - Err(Details::DeserializeKey(format!( - "deserialize_enum(name: {name}, variants: {variants:?})" - )) - .into()) - } - - fn deserialize_identifier<V>(self, visitor: V) -> Result<V::Value, Self::Error> - where - V: serde::de::Visitor<'de>, - { - println!("deserializing_identifier: {}", self.name); - visitor.visit_str(self.name) - } - - fn deserialize_ignored_any<V>(self, _visitor: V) -> Result<V::Value, Self::Error> - where - V: serde::de::Visitor<'de>, - { - Err(Details::DeserializeKey("deserialize_ignored_any".into()).into()) - } -} diff --git a/avro/src/serde/deser_schema/tuple.rs b/avro/src/serde/deser_schema/tuple.rs index 6b9284d..bdf07de 100644 --- a/avro/src/serde/deser_schema/tuple.rs +++ b/avro/src/serde/deser_schema/tuple.rs @@ -1,5 +1,4 @@ use crate::schema::RecordSchema; -use crate::serde::deser_schema::union::UnionDeserializer; use crate::serde::deser_schema::{Config, SchemaAwareDeserializer}; use crate::{Error, Schema}; use serde::de::{DeserializeSeed, SeqAccess}; @@ -35,16 +34,11 @@ impl<'de, 's, 'r, R: Read, S: Borrow<Schema>> SeqAccess<'de> { if self.current_field < self.schema.fields.len() { let schema = &self.schema.fields[self.current_field].schema; - let v = match schema { - Schema::Union(union) => { - seed.deserialize(UnionDeserializer::new(self.reader, union, self.config)?)? - } - schema => seed.deserialize(SchemaAwareDeserializer::new( - self.reader, - schema, - self.config, - )?)?, - }; + let v = seed.deserialize(SchemaAwareDeserializer::new( + self.reader, + schema, + self.config, + )?)?; self.current_field += 1; Ok(Some(v)) } else { @@ -95,16 +89,11 @@ impl<'de, 's, 'r, R: Read, S: Borrow<Schema>> SeqAccess<'de> if self.field_read { Ok(None) } else { - let v = match self.schema { - Schema::Union(union) => { - seed.deserialize(UnionDeserializer::new(self.reader, union, self.config)?)? - } - schema => seed.deserialize(SchemaAwareDeserializer::new( - self.reader, - schema, - self.config, - )?)?, - }; + let v = seed.deserialize(SchemaAwareDeserializer::new( + self.reader, + self.schema, + self.config, + )?)?; self.field_read = true; Ok(Some(v)) } diff --git a/avro/src/serde/deser_schema/union.rs b/avro/src/serde/deser_schema/union.rs deleted file mode 100644 index 658eb2e..0000000 --- a/avro/src/serde/deser_schema/union.rs +++ /dev/null @@ -1,450 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use std::borrow::Borrow; -use std::fmt::{Debug, Formatter}; -use std::io::Read; - -use serde::de::Visitor; - -use super::{Config, DESERIALIZE_ANY, DESERIALIZE_ANY_FIELDS, SchemaAwareDeserializer}; -use crate::error::Details; -use crate::schema::{SchemaKind, UuidSchema}; -use crate::util::zag_i32; -use crate::{Error, Schema, schema::UnionSchema}; - -pub struct UnionDeserializer<'s, 'r, R: Read, S: Borrow<Schema>> { - reader: &'r mut R, - schema: &'s UnionSchema, - config: Config<'s, S>, - variant: &'s Schema, -} - -impl<'s, 'r, R: Read, S: Borrow<Schema>> Debug for UnionDeserializer<'s, 'r, R, S> { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - f.debug_struct("UnionDeserializer") - .field("schema", self.schema) - .field("variant", &self.variant) - .finish() - } -} - -impl<'s, 'r, R: Read, S: Borrow<Schema>> UnionDeserializer<'s, 'r, R, S> { - pub fn new( - reader: &'r mut R, - schema: &'s UnionSchema, - config: Config<'s, S>, - ) -> Result<Self, Error> { - let index = zag_i32(reader)?; - let variant = - schema - .variants() - .get(index as usize) - .ok_or_else(|| Details::GetUnionVariant { - index: index as i64, - num_variants: schema.variants().len(), - })?; - Ok(Self { - reader, - schema, - config, - variant, - }) - } - - fn error(&self, ty: &'static str, error: impl Into<String>) -> Error { - Error::new(Details::DeserializeValueWithSchema { - value_type: ty, - value: error.into(), - schema: Schema::Union(self.schema.clone()), - }) - } -} - -impl<'de, 's, 'r, R: Read, S: Borrow<Schema>> serde::Deserializer<'de> - for UnionDeserializer<'s, 'r, R, S> -{ - type Error = Error; - - fn deserialize_any<V>(mut self, visitor: V) -> Result<V::Value, Self::Error> - where - V: Visitor<'de>, - { - println!("deserialize_any: {self:?}"); - match self.variant { - Schema::Null => visitor.visit_unit(), - Schema::Boolean => self.deserialize_bool(visitor), - Schema::Int | Schema::Date | Schema::TimeMillis => self.deserialize_i32(visitor), - Schema::Long - | Schema::TimeMicros - | Schema::TimestampMillis - | Schema::TimestampMicros - | Schema::TimestampNanos - | Schema::LocalTimestampMillis - | Schema::LocalTimestampMicros - | Schema::LocalTimestampNanos => self.deserialize_i64(visitor), - Schema::Float => self.deserialize_f32(visitor), - Schema::Double => self.deserialize_f64(visitor), - Schema::Bytes - | Schema::Fixed(_) - | Schema::Decimal(_) - | Schema::BigDecimal - | Schema::Uuid(UuidSchema::Fixed(_)) - | Schema::Duration(_) => self.deserialize_byte_buf(visitor), - Schema::String | Schema::Uuid(UuidSchema::String) => self.deserialize_string(visitor), - Schema::Array(_) => self.deserialize_seq(visitor), - Schema::Map(_) => self.deserialize_map(visitor), - Schema::Record(schema) => { - if schema.attributes.get("org.apache.avro.rust.tuple") - == Some(&serde_json::Value::Bool(true)) - { - // This is needed because we can't tell the difference between a tuple and struct. - // And a tuple needs to be deserialized as a sequence - self.deserialize_tuple(schema.fields.len(), visitor) - } else { - self.deserialize_struct(DESERIALIZE_ANY, DESERIALIZE_ANY_FIELDS, visitor) - } - } - Schema::Enum(_) => { - self.deserialize_enum(DESERIALIZE_ANY, DESERIALIZE_ANY_FIELDS, visitor) - } - Schema::Ref { name } => { - let schema = self - .config - .names - .get(name) - .ok_or_else(|| Details::SchemaResolutionError(name.clone()))? - .borrow(); - self.variant = schema; - self.deserialize_any(visitor) - } - Schema::Union(_) => Err(self.error("any", "Nested unions are not supported")), - Schema::Uuid(UuidSchema::Bytes) => panic!("Unsupported"), - } - } - - fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value, Self::Error> - where - V: Visitor<'de>, - { - println!("deserialize_bool: {self:?}"); - SchemaAwareDeserializer::new(self.reader, self.variant, self.config)? - .deserialize_bool(visitor) - } - - fn deserialize_i8<V>(self, visitor: V) -> Result<V::Value, Self::Error> - where - V: Visitor<'de>, - { - println!("deserialize_i8: {self:?}"); - SchemaAwareDeserializer::new(self.reader, self.variant, self.config)? - .deserialize_i8(visitor) - } - - fn deserialize_i16<V>(self, visitor: V) -> Result<V::Value, Self::Error> - where - V: Visitor<'de>, - { - println!("deserialize_i16: {self:?}"); - SchemaAwareDeserializer::new(self.reader, self.variant, self.config)? - .deserialize_i16(visitor) - } - - fn deserialize_i32<V>(self, visitor: V) -> Result<V::Value, Self::Error> - where - V: Visitor<'de>, - { - println!("deserialize_i32: {self:?}"); - SchemaAwareDeserializer::new(self.reader, self.variant, self.config)? - .deserialize_i32(visitor) - } - - fn deserialize_i64<V>(self, visitor: V) -> Result<V::Value, Self::Error> - where - V: Visitor<'de>, - { - println!("deserialize_i64: {self:?}"); - SchemaAwareDeserializer::new(self.reader, self.variant, self.config)? - .deserialize_i64(visitor) - } - - fn deserialize_i128<V>(self, visitor: V) -> Result<V::Value, Self::Error> - where - V: Visitor<'de>, - { - println!("deserialize_i128: {self:?}"); - SchemaAwareDeserializer::new(self.reader, self.variant, self.config)? - .deserialize_i128(visitor) - } - - fn deserialize_u8<V>(self, visitor: V) -> Result<V::Value, Self::Error> - where - V: Visitor<'de>, - { - println!("deserialize_u8: {self:?}"); - SchemaAwareDeserializer::new(self.reader, self.variant, self.config)? - .deserialize_u8(visitor) - } - - fn deserialize_u16<V>(self, visitor: V) -> Result<V::Value, Self::Error> - where - V: Visitor<'de>, - { - println!("deserialize_u16: {self:?}"); - SchemaAwareDeserializer::new(self.reader, self.variant, self.config)? - .deserialize_u16(visitor) - } - - fn deserialize_u32<V>(self, visitor: V) -> Result<V::Value, Self::Error> - where - V: Visitor<'de>, - { - println!("deserialize_u32: {self:?}"); - SchemaAwareDeserializer::new(self.reader, self.variant, self.config)? - .deserialize_u32(visitor) - } - - fn deserialize_u64<V>(self, visitor: V) -> Result<V::Value, Self::Error> - where - V: Visitor<'de>, - { - println!("deserialize_u64: {self:?}"); - SchemaAwareDeserializer::new(self.reader, self.variant, self.config)? - .deserialize_u64(visitor) - } - - fn deserialize_u128<V>(self, visitor: V) -> Result<V::Value, Self::Error> - where - V: Visitor<'de>, - { - println!("deserialize_u128: {self:?}"); - SchemaAwareDeserializer::new(self.reader, self.variant, self.config)? - .deserialize_u128(visitor) - } - - fn deserialize_f32<V>(self, visitor: V) -> Result<V::Value, Self::Error> - where - V: Visitor<'de>, - { - println!("deserialize_f32: {self:?}"); - SchemaAwareDeserializer::new(self.reader, self.variant, self.config)? - .deserialize_f32(visitor) - } - - fn deserialize_f64<V>(self, visitor: V) -> Result<V::Value, Self::Error> - where - V: Visitor<'de>, - { - println!("deserialize_f64: {self:?}"); - SchemaAwareDeserializer::new(self.reader, self.variant, self.config)? - .deserialize_f64(visitor) - } - - fn deserialize_char<V>(self, visitor: V) -> Result<V::Value, Self::Error> - where - V: Visitor<'de>, - { - println!("deserialize_char: {self:?}"); - SchemaAwareDeserializer::new(self.reader, self.variant, self.config)? - .deserialize_char(visitor) - } - - fn deserialize_str<V>(self, visitor: V) -> Result<V::Value, Self::Error> - where - V: Visitor<'de>, - { - println!("deserialize_str: {self:?}"); - SchemaAwareDeserializer::new(self.reader, self.variant, self.config)? - .deserialize_str(visitor) - } - - fn deserialize_string<V>(self, visitor: V) -> Result<V::Value, Self::Error> - where - V: Visitor<'de>, - { - println!("deserialize_string: {self:?}"); - SchemaAwareDeserializer::new(self.reader, self.variant, self.config)? - .deserialize_string(visitor) - } - - fn deserialize_bytes<V>(self, visitor: V) -> Result<V::Value, Self::Error> - where - V: Visitor<'de>, - { - println!("deserialize_bytes: {self:?}"); - SchemaAwareDeserializer::new(self.reader, self.variant, self.config)? - .deserialize_bytes(visitor) - } - - fn deserialize_byte_buf<V>(self, visitor: V) -> Result<V::Value, Self::Error> - where - V: Visitor<'de>, - { - println!("deserialize_byte_buf: {self:?}"); - SchemaAwareDeserializer::new(self.reader, self.variant, self.config)? - .deserialize_byte_buf(visitor) - } - - fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error> - where - V: Visitor<'de>, - { - println!("deserialize_option: {self:?}"); - if self.schema.variants().len() == 2 && self.schema.index().get(&SchemaKind::Null).is_some() - { - match self.variant { - Schema::Null => visitor.visit_none(), - schema => visitor.visit_some(SchemaAwareDeserializer::new( - self.reader, - schema, - self.config, - )?), - } - } else { - Err(self.error( - "option", - "Expected Schema::Union(variants.contains(Schema::Null), self.variants.len() == 2)", - )) - } - } - - fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value, Self::Error> - where - V: Visitor<'de>, - { - println!("deserialize_unit: {self:?}"); - SchemaAwareDeserializer::new(self.reader, self.variant, self.config)? - .deserialize_unit(visitor) - } - - fn deserialize_unit_struct<V>( - self, - name: &'static str, - visitor: V, - ) -> Result<V::Value, Self::Error> - where - V: Visitor<'de>, - { - println!("deserialize_unit_struct: {self:?}"); - SchemaAwareDeserializer::new(self.reader, self.variant, self.config)? - .deserialize_unit_struct(name, visitor) - } - - fn deserialize_newtype_struct<V>( - self, - name: &'static str, - visitor: V, - ) -> Result<V::Value, Self::Error> - where - V: Visitor<'de>, - { - println!("deserialize_newtype_struct: {self:?}"); - SchemaAwareDeserializer::new(self.reader, self.variant, self.config)? - .deserialize_newtype_struct(name, visitor) - } - - fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value, Self::Error> - where - V: Visitor<'de>, - { - println!("deserialize_seq: {self:?}"); - SchemaAwareDeserializer::new(self.reader, self.variant, self.config)? - .deserialize_seq(visitor) - } - - fn deserialize_tuple<V>(self, len: usize, visitor: V) -> Result<V::Value, Self::Error> - where - V: Visitor<'de>, - { - println!("deserialize_tuple: {self:?}"); - SchemaAwareDeserializer::new(self.reader, self.variant, self.config)? - .deserialize_tuple(len, visitor) - } - - fn deserialize_tuple_struct<V>( - self, - name: &'static str, - len: usize, - visitor: V, - ) -> Result<V::Value, Self::Error> - where - V: Visitor<'de>, - { - println!("deserialize_tuple_struct: {self:?}"); - SchemaAwareDeserializer::new(self.reader, self.variant, self.config)? - .deserialize_tuple_struct(name, len, visitor) - } - - fn deserialize_map<V>(self, visitor: V) -> Result<V::Value, Self::Error> - where - V: Visitor<'de>, - { - println!("deserialize_map: {self:?}"); - SchemaAwareDeserializer::new(self.reader, self.variant, self.config)? - .deserialize_map(visitor) - } - - fn deserialize_struct<V>( - self, - name: &'static str, - fields: &'static [&'static str], - visitor: V, - ) -> Result<V::Value, Self::Error> - where - V: Visitor<'de>, - { - println!("deserialize_struct: {self:?}"); - SchemaAwareDeserializer::new(self.reader, self.variant, self.config)? - .deserialize_struct(name, fields, visitor) - } - - fn deserialize_enum<V>( - self, - name: &'static str, - variants: &'static [&'static str], - visitor: V, - ) -> Result<V::Value, Self::Error> - where - V: Visitor<'de>, - { - println!("deserialize_enum: {self:?}"); - SchemaAwareDeserializer::new(self.reader, self.variant, self.config)? - .deserialize_enum(name, variants, visitor) - } - - fn deserialize_identifier<V>(self, visitor: V) -> Result<V::Value, Self::Error> - where - V: Visitor<'de>, - { - println!("deserialize_identifier: {self:?}"); - SchemaAwareDeserializer::new(self.reader, self.variant, self.config)? - .deserialize_identifier(visitor) - } - - fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value, Self::Error> - where - V: Visitor<'de>, - { - // TODO: We can do something more efficient here - println!("deserialize_ignored_any: {self:?}"); - self.deserialize_any(visitor) - } - - fn is_human_readable(&self) -> bool { - self.config.human_readable - } -} diff --git a/avro/src/serde/mod.rs b/avro/src/serde/mod.rs index 96cf620..11ffd84 100644 --- a/avro/src/serde/mod.rs +++ b/avro/src/serde/mod.rs @@ -115,6 +115,10 @@ pub(crate) mod ser_schema; mod util; mod with; +#[expect( + deprecated, + reason = "Still need to export it until we remove it completely" +)] pub use de::from_value; pub use derive::{AvroSchema, AvroSchemaComponent}; pub use ser::to_value; diff --git a/avro/src/serde/ser_schema/mod.rs b/avro/src/serde/ser_schema/mod.rs index 9244910..8a1e511 100644 --- a/avro/src/serde/ser_schema/mod.rs +++ b/avro/src/serde/ser_schema/mod.rs @@ -455,7 +455,7 @@ impl<'s, 'w, W: Write> Serializer for SchemaAwareSerializer<'s, 'w, W> { variant: &'static str, ) -> Result<Self::Ok, Self::Error> { println!( - "serialize_struct_variant(name: {name}, index: {variant_index}, variant: {variant}): {self:?}" + "serialize_unit_variant(name: {name}, index: {variant_index}, variant: {variant}): {self:?}" ); match self.schema { Schema::Enum(enum_schema) => { diff --git a/avro/src/serde/with.rs b/avro/src/serde/with.rs index 99a432e..ead124d 100644 --- a/avro/src/serde/with.rs +++ b/avro/src/serde/with.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. -#![expect(clippy::ref_option, reason = "Required by the Serde API")] - use std::cell::Cell; thread_local! { @@ -814,7 +812,7 @@ pub mod array_opt { #[cfg(test)] mod tests { - use crate::{Schema, from_value, to_value, types::Value}; + use crate::{Schema, to_value, types::Value}; use serde::{Deserialize, Serialize}; #[test] @@ -884,6 +882,7 @@ mod tests { assert!(value.validate(&schema)); } + #[expect(deprecated, reason = "This tests the deprecated function")] #[test] fn avro_3631_deserialize_value_to_struct_with_byte_types() { #[derive(Debug, Deserialize, PartialEq)] @@ -1007,7 +1006,7 @@ mod tests { Value::Union(1, Box::new(Value::Null)), ), ]); - assert_eq!(expected, from_value(&value).unwrap()); + assert_eq!(expected, crate::from_value(&value).unwrap()); } #[test] diff --git a/avro/tests/avro-rs-285-bytes_deserialization.rs b/avro/tests/avro-rs-285-bytes_deserialization.rs index cf02567..da8abcd 100644 --- a/avro/tests/avro-rs-285-bytes_deserialization.rs +++ b/avro/tests/avro-rs-285-bytes_deserialization.rs @@ -72,13 +72,15 @@ fn avro_rs_285_bytes_deserialization_round_trip() -> TestResult { // deserialize Avro binary data back into ExampleByteArray structs let reader = apache_avro::Reader::new(&avro_data[..])?; let deserialized_records: Vec<ExampleByteArray> = reader - .map(|value| apache_avro::from_value::<ExampleByteArray>(&value.unwrap()).unwrap()) + .into_serde_iter() + .map(|value| value.unwrap()) .collect(); assert_eq!(records, deserialized_records); Ok(()) } +#[expect(deprecated, reason = "Schema resolution is WIP")] #[test] fn avro_rs_285_bytes_deserialization_filtered_round_trip() -> TestResult { let raw_schema = r#" diff --git a/avro/tests/schema.rs b/avro/tests/schema.rs index 89e328a..1273a3d 100644 --- a/avro/tests/schema.rs +++ b/avro/tests/schema.rs @@ -19,7 +19,6 @@ use apache_avro::writer::datum::GenericDatumWriter; use apache_avro::{ Codec, Error, Reader, Schema, Writer, error::Details, - from_value, reader::datum::GenericDatumReader, schema::{EnumSchema, FixedSchema, Name, RecordField, RecordSchema}, to_value, @@ -837,7 +836,7 @@ fn avro_old_issue_47() -> TestResult { let schema_str = r#" { "type": "record", - "name": "my_record", + "name": "MyRecord", "fields": [ {"name": "a", "type": "long"}, {"name": "b", "type": "string"} @@ -863,10 +862,9 @@ fn avro_old_issue_47() -> TestResult { .build()? .write_value_to_vec(ser_value)?; - let de_value = GenericDatumReader::builder(&schema) + let deserialized_record: MyRecord = GenericDatumReader::builder(&schema) .build()? - .read_value(&mut &*serialized_bytes)?; - let deserialized_record = from_value::<MyRecord>(&de_value)?; + .read_deser(&mut &*serialized_bytes)?; assert_eq!(record, deserialized_record); Ok(()) diff --git a/avro/tests/union_schema.rs b/avro/tests/union_schema.rs index 892b69b..d3c68c9 100644 --- a/avro/tests/union_schema.rs +++ b/avro/tests/union_schema.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use apache_avro::{AvroResult, Codec, Reader, Schema, Writer, from_value}; +use apache_avro::{AvroResult, Codec, Reader, Schema, Writer}; use serde::{Deserialize, Serialize, de::DeserializeOwned}; static SCHEMA_A_STR: &str = r#"{ @@ -80,7 +80,7 @@ where .reader_schema(schema) .schemata(schemata.iter().collect()) .build()?; - from_value::<T>(&reader.next().expect("")?) + Ok(reader.next_deser::<T>()?.expect("Expected a value")) } #[test] diff --git a/avro_derive/src/attributes/serde.rs b/avro_derive/src/attributes/serde.rs index b0e4ae1..be3fe65 100644 --- a/avro_derive/src/attributes/serde.rs +++ b/avro_derive/src/attributes/serde.rs @@ -159,6 +159,10 @@ pub struct VariantAttributes { #[darling(default)] pub rename_all: RenameAll, /// Do not serialize or deserialize this variant. + /// + /// Skip (and skip_{de,}serializing) should not be used. + /// Serde will remove the variant from the list of variants when deserializing but + /// will not update the index when serializing. #[darling(default, rename = "skip")] pub _skip: bool, /// Do not serialize this variant. diff --git a/avro_derive/tests/serde.rs b/avro_derive/tests/serde.rs index 445e11c..7c4c2f4 100644 --- a/avro_derive/tests/serde.rs +++ b/avro_derive/tests/serde.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use apache_avro::{AvroSchema, Error, Reader, Schema, Writer, from_value}; +use apache_avro::{AvroSchema, Error, Reader, Schema, Writer}; use serde::{Deserialize, Serialize, de::DeserializeOwned}; /// Takes in a type that implements the right combination of traits and runs it through a Serde @@ -66,8 +66,8 @@ where let mut reader = Reader::builder(&encoded[..]) .reader_schema(&schema) .build()?; - if let Some(res) = reader.next() { - return res.and_then(|v| from_value::<T>(&v)); + if let Some(res) = reader.next_deser::<T>()? { + return Ok(res); } panic!("Nothing was encoded!") }
