This is an automated email from the ASF dual-hosted git repository. kriskras99 pushed a commit to branch feat/enums in repository https://gitbox.apache.org/repos/asf/avro-rs.git
commit 4c0cd5605781e8240e30ab300bf4fefff176e0e0 Author: Kriskras99 <[email protected]> AuthorDate: Fri Mar 13 11:04:29 2026 +0100 partial resolving deserializer --- avro/src/bigdecimal.rs | 1 - avro/src/documentation/mod.rs | 1 + avro/src/documentation/serde_datamodel_to_avro.rs | 117 +++- avro/src/error.rs | 14 + avro/src/schema/mod.rs | 37 +- avro/src/schema/union.rs | 67 +-- avro/src/schema_compatibility.rs | 118 ++-- avro/src/serde/deser_resolving/any.rs | 218 +++++++ avro/src/serde/deser_resolving/mod.rs | 665 +++++++++++++++++++++ avro/src/serde/deser_resolving/record.rs | 1 + avro/src/serde/deser_resolving/record/default.rs | 340 +++++++++++ avro/src/serde/deser_schema/mod.rs | 86 +-- avro/src/serde/mod.rs | 1 + avro/src/serde/ser_schema/mod.rs | 122 +--- avro/src/serde/with.rs | 55 +- avro_derive/src/attributes/mod.rs | 188 +++--- avro_derive/src/fields.rs | 5 + avro_derive/tests/derive.rs | 5 +- avro_derive/tests/enum.rs | 49 ++ avro_derive/tests/serde.rs | 2 +- .../ui/avro_rs_xxx_bare_union_and_untagged.stderr | 9 - 21 files changed, 1746 insertions(+), 355 deletions(-) diff --git a/avro/src/bigdecimal.rs b/avro/src/bigdecimal.rs index 7df71dc..58efb0d 100644 --- a/avro/src/bigdecimal.rs +++ b/avro/src/bigdecimal.rs @@ -185,7 +185,6 @@ mod tests { // TODO: Needs new schema aware deserializer #[test] - #[ignore] fn avro_rs_338_deserialize_serde_way() -> TestResult { #[derive(Clone, PartialEq, Eq, Debug, Default, serde::Deserialize, serde::Serialize)] #[serde(rename = "test")] diff --git a/avro/src/documentation/mod.rs b/avro/src/documentation/mod.rs index 7771012..5598566 100644 --- a/avro/src/documentation/mod.rs +++ b/avro/src/documentation/mod.rs @@ -22,3 +22,4 @@ pub mod dynamic; pub mod primer; +pub mod serde_datamodel_to_avro; diff --git a/avro/src/documentation/serde_datamodel_to_avro.rs b/avro/src/documentation/serde_datamodel_to_avro.rs index 5e1ad7f..1fe4664 100644 --- a/avro/src/documentation/serde_datamodel_to_avro.rs +++ 
b/avro/src/documentation/serde_datamodel_to_avro.rs @@ -16,48 +16,103 @@ // under the License. //! # Mapping the Serde datamodel to the Avro datamodel -//! +//! //! When manually mapping Rust types to a Avro schema, or the reverse, it is important to understand //! how the datamodels are mapped. This should only be done in very specific circumstances, most users //! should use the [`AvroSchema`] derive macro. -//! +//! //! Only the schemas as generated by the [`AvroSchema`] derive macro and the mapping as defined here are -//! supported. Any other behaviour is not supported. -//! -//! ## Primitive types -//! - `bool` -> [`Schema::Boolean`] -//! - `i8`, `i16`, `i32`, `u8`, `u16` -> [`Schema::Int`] -//! - `i64`, `u32` -> [`Schema::Long`] -//! - `u64` -> [`Schema::Fixed`]`(name: "u64", size: 8)` -//! - This is not a `Schema::Long` as that is a signed number of maximum 64 bits. -//! - `i128` -> [`Schema::Fixed`]`(name: "i128", size: 16)` -//! - `u128` -> [`Schema::Fixed`]`(name: "u128", size: 16)` -//! - `f32` -> [`Schema::Float`] -//! - `f64` -> [`Schema::Double`] -//! - `char` -> [`Schema::String`] -//! - Only one character allowed, deserializer will return an error for strings with more than one character. -//! - `tuple` -> [`Schema::Record`] with a field for every tuple index -//! - `[T; N]` -> [`Schema::Record`] with a field for every array index -//! - To (de)serialize as a [`Schema::Array`] use [`apache_avro::serde::array`] -//! - To (de)serialize as a [`Schema::Fixed`] use [`apache_avro::serde::fixed`] -//! - To (de)serialize as a [`Schema::Bytes`] use [`apache_avro::serde::bytes`] -//! -//! ## Standard library types -//! - `string` -> [`Schema::String`] -//! - `` -//! -//! -//! [`apache_avro::serde::array`]: crate::serde::array -//! [`apache_avro::serde::bytes`]: crate::serde::bytes -//! [`apache_avro::serde::fixed`]: crate::serde::fixed +//! supported. Any other behavior is not supported. +//! +//! - **14 primitive types** +//! 
- `bool` => [`Schema::Boolean`] +//! - `i8`, `i16`, `i32`, `u8`, `u16` => [`Schema::Int`] +//! - `i64`, `u32` => [`Schema::Long`] +//! - `u64` => [`Schema::Fixed`]`(name: "u64", size: 8)` +//! - This is not a `Schema::Long` as that is a signed number of maximum 64 bits. +//! - `i128` => [`Schema::Fixed`]`(name: "i128", size: 16)` +//! - `u128` => [`Schema::Fixed`]`(name: "u128", size: 16)` +//! - `f32` => [`Schema::Float`] +//! - `f64` => [`Schema::Double`] +//! - `char` => [`Schema::String`] +//! - Only one character is allowed; the deserializer will return an error for strings with more than one character. +//! - **string** => [`Schema::String`] +//! - **byte array** => [`Schema::Bytes`] or [`Schema::Fixed`] +//! - **option** => [`Schema::Union`] of [`Schema::Null`] and the schema of the inner type +//! - **unit** => [`Schema::Null`] +//! - **unit struct** => [`Schema::Record`] with the name of the struct and zero fields +//! - **unit variant** => See [Enums](#enums) +//! - **newtype struct** => [`Schema::Record`] with the name of the struct and one field +//! - **newtype variant** => See [Enums](#enums) +//! - **seq** => [`Schema::Array`] where `items` has the schema of the inner type +//! - **tuple** +//! - => The schema of the only element for tuples with one element +//! - => [`Schema::Record`] with as many fields as there are elements +//! - **Note:** Serde (de)serializes arrays (`[T; N]`) as tuples. To (de)serialize an array as a +//! [`Schema::Array`] use [`apache_avro::serde::array`]. +//! - **tuple_struct** => [`Schema::Record`] with the name of the struct and as many fields as there are elements +//! - **Note:** Tuple structs with 0 or 1 element will also be (de)serialized as a [`Schema::Record`]. This +//! is different from normal tuples. +//! - **tuple_variant** => See [Enums](#enums) +//! - **map** => [`Schema::Map`] where `types` has the schema of the value type +//!
- **Note:** the key type of the map will be (de)serialized as a [`Schema::String`] +//! - **struct** => [`Schema::Record`] +//! - **struct_variant** => See [Enums](#enums) +//! +//! ## Enums +//! Avro doesn't have ADTs (Algebraic Data Types). It only has [`Schema::Enum`] and [`Schema::Union`]. +//! Serde also supports different ways of (de)serializing enums, which makes the mapping more complex. The mapping +//! for enums is explained per Serde enum representation. +//! +//! ### Externally tagged +//! This is the default enum representation for Serde. It can be mapped in three ways to the Avro datamodel. +//! For all three options the index of each Rust variant must match its index in the Avro schema. +//! - As a [`Schema::Enum`], this is only possible for enums where every variant is a unit variant. +//! - As a [`Schema::Union`] with a [`Schema::Record`] for every variant: +//! - **unit_variant** => [`Schema::Record`] named as the variant and with no fields. +//! - **newtype_variant** => [`Schema::Record`] named as the variant and with one field. +//! - **tuple_variant** => [`Schema::Record`] named as the variant and with as many fields as there are elements. +//! - **struct_variant** => [`Schema::Record`] named as the variant and with a field for every field of the struct variant. +//! - As a [`Schema::Union`] without the wrapper [`Schema::Record`], all schemas must be unique: +//! - **unit_variant** => [`Schema::Null`]. +//! - **newtype_variant** => The schema of the inner type. +//! - **tuple_variant** => [`Schema::Record`] named as the variant and with as many fields as there are elements. +//! - **struct_variant** => [`Schema::Record`] named as the variant and with a field for every field of the struct variant. +//! +//! ### Internally tagged +//! This enum representation is used by Serde if the attribute `#[serde(tag = "...")]` is used. +//! It maps to a [`Schema::Record`]. There must be a field that is named as the value of the +//! `tag` attribute.
If a field is not used by all variants, it must have a `default` set. +//! +//! ### Adjacently tagged +//! This enum representation is used by Serde if the attributes `#[serde(tag = "...", content = "...")]` are used. +//! It maps to a [`Schema::Record`] with two fields. One field must be named as the value of the `tag` +//! attribute and use the [`Schema::Enum`] schema. The other field must be named as the value of the +//! `content` attribute and use the [`Schema::Union`] schema. +//! +//! ### Untagged +//! This enum representation is used by Serde if the attribute `#[serde(untagged)]` is used. It maps +//! to a [`Schema::Union`] with the following schemas: +//! - **unit_variant** => [`Schema::Null`]. +//! - **newtype_variant** => The schema of the inner type. +//! - **tuple_variant** => [`Schema::Record`] named as the variant and with as many fields as there are elements. +//! - **struct_variant** => [`Schema::Record`] named as the variant and with a field for every field of the struct variant. +//! //! [`AvroSchema`]: crate::AvroSchema //! [`Schema::Array`]: crate::schema::Schema::Array //! [`Schema::Boolean`]: crate::schema::Schema::Boolean //! [`Schema::Bytes`]: crate::schema::Schema::Bytes //! [`Schema::Double`]: crate::schema::Schema::Double +//! [`Schema::Enum`]: crate::schema::Schema::Enum //! [`Schema::Fixed`]: crate::schema::Schema::Fixed //! [`Schema::Float`]: crate::schema::Schema::Float //! [`Schema::Int`]: crate::schema::Schema::Int //! [`Schema::Long`]: crate::schema::Schema::Long +//! [`Schema::Map`]: crate::schema::Schema::Map +//! [`Schema::Null`]: crate::schema::Schema::Null //! [`Schema::Record`]: crate::schema::Schema::Record -//! [`Schema::String`]: crate::schema::Schema::String \ No newline at end of file +//! [`Schema::String`]: crate::schema::Schema::String +//! [`Schema::Union`]: crate::schema::Schema::Union +//! [`apache_avro::serde::array`]: crate::serde::array +//! [`apache_avro::serde::bytes`]: crate::serde::bytes +//!
[`apache_avro::serde::fixed`]: crate::serde::fixed diff --git a/avro/src/error.rs b/avro/src/error.rs index 4482280..7b9840c 100644 --- a/avro/src/error.rs +++ b/avro/src/error.rs @@ -585,6 +585,16 @@ pub enum Details { schema: Schema, }, + #[error( + "Failed to deserialize value of type {value_type} using schema {writer_schema:?} amd resolving to schema {reader_schema:?}: {value}" + )] + DeserializeValueWithResolvingSchema { + value_type: &'static str, + value: String, + writer_schema: Schema, + reader_schema: Schema, + }, + #[error("Only expected `deserialize_identifier` to be called but `{0}` was called")] DeserializeIdentifier(&'static str), @@ -711,6 +721,10 @@ pub enum CompatibilityError { "Incompatible schemata! Unknown type for '{0}'. Make sure that the type is a valid one" )] Inconclusive(String), + + /// Error while resolving [`Schema::Ref`] + #[error("Unresolved schema reference: {0}")] + SchemaResolutionError(Name), } impl serde::ser::Error for Details { diff --git a/avro/src/schema/mod.rs b/avro/src/schema/mod.rs index 06f56be..f78efc1 100644 --- a/avro/src/schema/mod.rs +++ b/avro/src/schema/mod.rs @@ -58,7 +58,7 @@ use std::{ hash::Hash, io::Read, }; -use strum::{Display, EnumDiscriminants}; +use strum::{Display, EnumDiscriminants, IntoDiscriminant}; /// Represents documentation for complex Avro schemas. pub type Documentation = Option<String>; @@ -854,6 +854,41 @@ impl Schema { } } } + + /// Get the [`SchemaKind`] of a [`Schema`] converting logical types to their base type. 
+ pub(crate) fn to_base_schema_kind(&self) -> SchemaKind { + let kind = self.discriminant(); + match kind { + SchemaKind::Date | SchemaKind::TimeMillis => SchemaKind::Int, + SchemaKind::TimeMicros + | SchemaKind::TimestampMillis + | SchemaKind::TimestampMicros + | SchemaKind::TimestampNanos + | SchemaKind::LocalTimestampMillis + | SchemaKind::LocalTimestampMicros + | SchemaKind::LocalTimestampNanos => SchemaKind::Long, + SchemaKind::Uuid => match self { + Schema::Uuid(UuidSchema::Bytes) => SchemaKind::Bytes, + Schema::Uuid(UuidSchema::String) => SchemaKind::String, + Schema::Uuid(UuidSchema::Fixed(_)) => SchemaKind::Fixed, + _ => unreachable!(), + }, + SchemaKind::Decimal => match self { + Schema::Decimal(DecimalSchema { + inner: InnerDecimalSchema::Bytes, + .. + }) => SchemaKind::Bytes, + Schema::Decimal(DecimalSchema { + inner: InnerDecimalSchema::Fixed(_), + .. + }) => SchemaKind::Fixed, + _ => unreachable!(), + }, + SchemaKind::BigDecimal => SchemaKind::Bytes, + SchemaKind::Duration => SchemaKind::Fixed, + _ => kind, + } + } } impl Serialize for Schema { diff --git a/avro/src/schema/union.rs b/avro/src/schema/union.rs index 5a5cc9f..6704021 100644 --- a/avro/src/schema/union.rs +++ b/avro/src/schema/union.rs @@ -19,6 +19,7 @@ use crate::error::Details; use crate::schema::{ DecimalSchema, InnerDecimalSchema, Name, NamespaceRef, Schema, SchemaKind, UuidSchema, }; +use crate::schema_compatibility::Checker; use crate::types; use crate::{AvroResult, Error}; use std::borrow::Borrow; @@ -83,6 +84,30 @@ impl UnionSchema { &self.variant_index } + /// Find the first variant that is (partially) compatible with the provided schema. + /// + /// Partial in this context means that actual compatibility can only be established by reading + /// the data. 
+ pub(crate) fn find_compatible_variant<S: Borrow<Schema>>( + &self, + writer_schema: &Schema, + union_schemata: &HashMap<Name, S>, + writer_schemata: &HashMap<Name, S>, + ) -> Option<usize> { + // We reuse the same checker so that its cache of checked schema can be reused. + let mut checker = Checker::new(writer_schemata, union_schemata); + + // We can't do a fast path, because the specification specifically says: + // "The first schema in the reader’s union that matches the writer’s schema is recursively resolved against it." + // So we have to check every schema in turn + for (index, schema) in self.schemas.iter().enumerate() { + if checker.full_match_schemas(writer_schema, schema).is_ok() { + return Some(index); + } + } + None + } + /// Optionally returns a reference to the schema matched by this value, as well as its position /// within this union. /// @@ -125,7 +150,7 @@ impl UnionSchema { .copied() .map(|i| (i, &self.schemas[i])) .filter(|(_i, s)| { - let s_kind = schema_to_base_schemakind(s); + let s_kind = s.to_base_schema_kind(); s_kind == kind || s_kind == SchemaKind::Ref }) .find(|(_i, schema)| { @@ -285,7 +310,7 @@ impl UnionSchemaBuilder { self.schemas.push(schema); } } else { - let discriminant = schema_to_base_schemakind(&schema); + let discriminant = schema.to_base_schema_kind(); if discriminant == SchemaKind::Union { return Err(Details::GetNestedUnion.into()); } @@ -311,7 +336,7 @@ impl UnionSchemaBuilder { self.schemas.push(schema); } } else { - let discriminant = schema_to_base_schemakind(&schema); + let discriminant = schema.to_base_schema_kind(); if discriminant == SchemaKind::Union { return Err(Details::GetNestedUnion.into()); } @@ -334,7 +359,7 @@ impl UnionSchemaBuilder { false } } else { - let discriminant = schema_to_base_schemakind(schema); + let discriminant = schema.to_base_schema_kind(); if let Some(index) = self.variant_index.get(&discriminant).copied() { &self.schemas[index] == schema } else { @@ -356,40 +381,6 @@ impl 
UnionSchemaBuilder { } } -/// Get the [`SchemaKind`] of a [`Schema`] converting logical types to their base type. -fn schema_to_base_schemakind(schema: &Schema) -> SchemaKind { - let kind = schema.discriminant(); - match kind { - SchemaKind::Date | SchemaKind::TimeMillis => SchemaKind::Int, - SchemaKind::TimeMicros - | SchemaKind::TimestampMillis - | SchemaKind::TimestampMicros - | SchemaKind::TimestampNanos - | SchemaKind::LocalTimestampMillis - | SchemaKind::LocalTimestampMicros - | SchemaKind::LocalTimestampNanos => SchemaKind::Long, - SchemaKind::Uuid => match schema { - Schema::Uuid(UuidSchema::Bytes) => SchemaKind::Bytes, - Schema::Uuid(UuidSchema::String) => SchemaKind::String, - Schema::Uuid(UuidSchema::Fixed(_)) => SchemaKind::Fixed, - _ => unreachable!(), - }, - SchemaKind::Decimal => match schema { - Schema::Decimal(DecimalSchema { - inner: InnerDecimalSchema::Bytes, - .. - }) => SchemaKind::Bytes, - Schema::Decimal(DecimalSchema { - inner: InnerDecimalSchema::Fixed(_), - .. - }) => SchemaKind::Fixed, - _ => unreachable!(), - }, - SchemaKind::Duration => SchemaKind::Fixed, - _ => kind, - } -} - #[cfg(test)] mod tests { use super::*; diff --git a/avro/src/schema_compatibility.rs b/avro/src/schema_compatibility.rs index 2578551..86480d5 100644 --- a/avro/src/schema_compatibility.rs +++ b/avro/src/schema_compatibility.rs @@ -57,6 +57,7 @@ //! # Ok::<(), Error>(()) //! ``` //! +use crate::schema::Name; use crate::{ error::CompatibilityError, schema::{ @@ -64,13 +65,8 @@ use crate::{ Schema, UuidSchema, }, }; -use std::{ - collections::{HashMap, hash_map::DefaultHasher}, - hash::Hasher, - iter::once, - ops::BitAndAssign, - ptr, -}; +use std::borrow::Borrow; +use std::{collections::HashMap, iter::once, ops::BitAndAssign}; /// Check if two schemas can be resolved. 
/// @@ -85,7 +81,22 @@ impl SchemaCompatibility { writers_schema: &Schema, readers_schema: &Schema, ) -> Result<Compatibility, CompatibilityError> { - let mut c = Checker::new(); + let empty_schemata: HashMap<_, Schema> = HashMap::new(); + Self::can_read_with_schemata( + writers_schema, + readers_schema, + &empty_schemata, + &empty_schemata, + ) + } + + pub fn can_read_with_schemata<S: Borrow<Schema>>( + writers_schema: &Schema, + readers_schema: &Schema, + writer_schemata: &HashMap<Name, S>, + reader_schemata: &HashMap<Name, S>, + ) -> Result<Compatibility, CompatibilityError> { + let mut c = Checker::new(&writer_schemata, &reader_schemata); c.can_read(writers_schema, readers_schema) } @@ -94,9 +105,29 @@ impl SchemaCompatibility { schema_a: &Schema, schema_b: &Schema, ) -> Result<Compatibility, CompatibilityError> { - let mut c = SchemaCompatibility::can_read(schema_a, schema_b)?; - c &= SchemaCompatibility::can_read(schema_b, schema_a)?; - Ok(c) + let empty_schemata: HashMap<_, Schema> = HashMap::new(); + Self::mutal_read_with_schemata(schema_a, schema_b, &empty_schemata, &empty_schemata) + } + + pub fn mutal_read_with_schemata<S: Borrow<Schema>>( + schema_a: &Schema, + schema_b: &Schema, + schema_a_schemata: &HashMap<Name, S>, + schema_b_schemata: &HashMap<Name, S>, + ) -> Result<Compatibility, CompatibilityError> { + let mut compatibility = SchemaCompatibility::can_read_with_schemata( + schema_a, + schema_b, + schema_a_schemata, + schema_b_schemata, + )?; + compatibility &= SchemaCompatibility::can_read_with_schemata( + schema_b, + schema_a, + schema_b_schemata, + schema_a_schemata, + )?; + Ok(compatibility) } } @@ -127,15 +158,22 @@ impl BitAndAssign for Compatibility { } } -struct Checker { - recursion: HashMap<(u64, u64), Compatibility>, +pub(crate) struct Checker<'s, S: Borrow<Schema>> { + recursion: HashMap<(usize, usize), Compatibility>, + writer_schemata: &'s HashMap<Name, S>, + reader_schemata: &'s HashMap<Name, S>, } -impl Checker { +impl<'s, S: 
Borrow<Schema>> Checker<'s, S> { /// Create a new checker, with recursion set to an empty set. - pub(crate) fn new() -> Self { + pub(crate) fn new( + writer_schemata: &'s HashMap<Name, S>, + reader_schemata: &'s HashMap<Name, S>, + ) -> Self { Self { recursion: HashMap::new(), + writer_schemata, + reader_schemata, } } @@ -148,26 +186,26 @@ impl Checker { // Hash both reader and writer based on their pointer value. This is a fast way to see if // we get the exact same schemas multiple times (because of recursive types) let key = ( - Self::pointer_hash(writers_schema), - Self::pointer_hash(readers_schema), + Self::ref_to_addr(writers_schema), + Self::ref_to_addr(readers_schema), ); - // If we already saw this pairing, return the previous value - if let Some(c) = self.recursion.get(&key).copied() { - Ok(c) + // `HashMap::entry` cannot be used as that does a mutable borrow of the map and `inner_full_match_schemas` + // does a mutable borrow of `self`. + if let Some(compatibility) = self.recursion.get(&key).copied() { + // If we already saw this pairing, return the previous value + Ok(compatibility) } else { - let c = self.inner_full_match_schemas(writers_schema, readers_schema)?; + let compatibility = self.inner_full_match_schemas(writers_schema, readers_schema)?; // Insert the new value - self.recursion.insert(key, c); - Ok(c) + self.recursion.insert(key, compatibility); + Ok(compatibility) } } - /// Hash a schema based only on its pointer value. - fn pointer_hash(schema: &Schema) -> u64 { - let mut hasher = DefaultHasher::new(); - ptr::hash(schema, &mut hasher); - hasher.finish() + /// Get the address of the reference. + fn ref_to_addr(schema: &Schema) -> usize { + (schema as *const Schema).addr() } /// The actual implementation of "`full_match_schemas()`" but without the recursion protection. 
@@ -192,14 +230,18 @@ impl Checker { // Logical types are downgraded to their actual type match (writers_schema, readers_schema) { - (Schema::Ref { name: w_name }, Schema::Ref { name: r_name }) => { - if r_name == w_name { - Ok(Compatibility::Full) + (Schema::Ref { name: w_name }, _) => { + if let Some(schema) = self.writer_schemata.get(w_name) { + self.inner_full_match_schemas(schema.borrow(), readers_schema) } else { - Err(CompatibilityError::NameMismatch { - writer_name: w_name.fullname(None), - reader_name: r_name.fullname(None), - }) + Err(CompatibilityError::SchemaResolutionError(w_name.clone())) + } + } + (_, Schema::Ref { name: r_name }) => { + if let Some(schema) = self.reader_schemata.get(r_name) { + self.inner_full_match_schemas(writers_schema, schema.borrow()) + } else { + Err(CompatibilityError::SchemaResolutionError(r_name.clone())) } } (Schema::Union(writer), Schema::Union(reader)) => { @@ -1656,7 +1698,11 @@ mod tests { ]; let schemas = Schema::parse_list(schema_strs).unwrap(); - SchemaCompatibility::can_read(&schemas[1], &schemas[1])?; + let names: HashMap<_, _> = schemas + .iter() + .filter_map(|s| s.name().map(|n| (n.clone(), s))) + .collect(); + SchemaCompatibility::can_read_with_schemata(&schemas[1], &schemas[1], &names, &names)?; Ok(()) } diff --git a/avro/src/serde/deser_resolving/any.rs b/avro/src/serde/deser_resolving/any.rs new file mode 100644 index 0000000..67760a1 --- /dev/null +++ b/avro/src/serde/deser_resolving/any.rs @@ -0,0 +1,218 @@ +use std::fmt::Formatter; +use serde::de::{EnumAccess, Error, MapAccess, SeqAccess, VariantAccess, Visitor}; +use serde::{Deserialize, Deserializer}; + +pub struct AnyVisitor; + +impl<'de> Visitor<'de> for AnyVisitor { + type Value = (); + + fn expecting(&self, formatter: &mut Formatter) -> std::fmt::Result { + write!(formatter, "anything") + } + + fn visit_bool<E>(self, _: bool) -> Result<Self::Value, E> + where + E: Error + { + Ok(()) + } + + fn visit_i8<E>(self, _: i8) -> Result<Self::Value, E> + 
where + E: Error + { + Ok(()) + } + + fn visit_i16<E>(self, _: i16) -> Result<Self::Value, E> + where + E: Error + { + Ok(()) + } + + fn visit_i32<E>(self, _: i32) -> Result<Self::Value, E> + where + E: Error + { + Ok(()) + } + + fn visit_i64<E>(self, _: i64) -> Result<Self::Value, E> + where + E: Error + { + Ok(()) + } + + fn visit_i128<E>(self, _: i128) -> Result<Self::Value, E> + where + E: Error + { + Ok(()) + } + + fn visit_u8<E>(self, _: u8) -> Result<Self::Value, E> + where + E: Error + { + Ok(()) + } + + fn visit_u16<E>(self, _: u16) -> Result<Self::Value, E> + where + E: Error + { + Ok(()) + } + + fn visit_u32<E>(self, _: u32) -> Result<Self::Value, E> + where + E: Error + { + Ok(()) + } + + fn visit_u64<E>(self, _: u64) -> Result<Self::Value, E> + where + E: Error + { + Ok(()) + } + + fn visit_u128<E>(self, _: u128) -> Result<Self::Value, E> + where + E: Error + { + Ok(()) + } + + fn visit_f32<E>(self, _: f32) -> Result<Self::Value, E> + where + E: Error + { + Ok(()) + } + + fn visit_f64<E>(self, _: f64) -> Result<Self::Value, E> + where + E: Error + { + Ok(()) + } + + fn visit_char<E>(self, _: char) -> Result<Self::Value, E> + where + E: Error + { + Ok(()) + } + + fn visit_str<E>(self, _: &str) -> Result<Self::Value, E> + where + E: Error + { + Ok(()) + } + + fn visit_borrowed_str<E>(self, _: &'de str) -> Result<Self::Value, E> + where + E: Error + { + Ok(()) + } + + fn visit_string<E>(self, _: String) -> Result<Self::Value, E> + where + E: Error + { + Ok(()) + } + + fn visit_bytes<E>(self, _: &[u8]) -> Result<Self::Value, E> + where + E: Error + { + Ok(()) + } + + fn visit_borrowed_bytes<E>(self, _: &'de [u8]) -> Result<Self::Value, E> + where + E: Error + { + Ok(()) + } + + fn visit_byte_buf<E>(self, _: Vec<u8>) -> Result<Self::Value, E> + where + E: Error + { + Ok(()) + } + + fn visit_none<E>(self) -> Result<Self::Value, E> + where + E: Error + { + Ok(()) + } + + fn visit_some<D>(self, deserializer: D) -> Result<Self::Value, D::Error> + where + D: 
Deserializer<'de> + { + deserializer.deserialize_any(AnyVisitor) + } + + fn visit_unit<E>(self) -> Result<Self::Value, E> + where + E: Error + { + Ok(()) + } + + fn visit_newtype_struct<D>(self, deserializer: D) -> Result<Self::Value, D::Error> + where + D: Deserializer<'de> + { + deserializer.deserialize_any(AnyVisitor) + } + + fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error> + where + A: SeqAccess<'de> + { + while let Some(_) = seq.next_element::<AnyDeserialize>()? {} + Ok(()) + } + + fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error> + where + A: MapAccess<'de> + { + while let Some(_) = map.next_entry::<AnyDeserialize, AnyDeserialize>()? {} + Ok(()) + } + + fn visit_enum<A>(self, data: A) -> Result<Self::Value, A::Error> + where + A: EnumAccess<'de> + { + let (_, variant) = data.variant::<AnyDeserialize>()?; + variant.unit_variant()?; + Ok(()) + } +} + +pub struct AnyDeserialize; + +impl<'de> Deserialize<'de> for AnyDeserialize { + fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> + where + D: Deserializer<'de> + { + deserializer.deserialize_any(AnyVisitor).map(|()| Self) + } +} + diff --git a/avro/src/serde/deser_resolving/mod.rs b/avro/src/serde/deser_resolving/mod.rs new file mode 100644 index 0000000..78fbd6b --- /dev/null +++ b/avro/src/serde/deser_resolving/mod.rs @@ -0,0 +1,665 @@ +mod any; +mod record; + +use std::borrow::Borrow; +use std::collections::HashMap; +use std::io::Read; +use serde::de::Visitor; +use serde::Deserialize; +use crate::{Error, Schema}; +use crate::error::Details; +use crate::schema::{DecimalSchema, InnerDecimalSchema, Name, SchemaKind, UnionSchema, UuidSchema}; +use crate::serde::deser_resolving::any::AnyVisitor; +use crate::serde::deser_schema::{SchemaAwareDeserializer, DESERIALIZE_ANY, DESERIALIZE_ANY_FIELDS}; +use crate::util::zag_i32; + +#[derive(Debug)] +pub struct Config<'s, S: Borrow<Schema>> { + /// All names that can be referenced in the writer schema. 
+ pub writer_names: &'s HashMap<Name, S>, + /// All names that can be referenced in the reader schema. + pub reader_names: &'s HashMap<Name, S>, + /// Should `Deserialize` implementations pick a human-readable format. + /// + /// This should match the setting used for serialisation. + pub human_readable: bool, +} + +// This needs to be implemented manually as the derive puts a bound on `S` +// which is not necessary as a reference is always Copy. +impl<'s, S: Borrow<Schema>> Copy for Config<'s, S> {} +impl<'s, S: Borrow<Schema>> Clone for Config<'s, S> { + fn clone(&self) -> Self { + *self + } +} + +impl<'s, S: Borrow<Schema>> From<Config<'s, S>> for super::deser_schema::Config<'s, S> { + fn from(value: Config<'s, S>) -> Self { + Self { + names: value.writer_names, + human_readable: value.human_readable, + } + } +} + +pub struct ResolvingDeserializer<'s, 'r, R: Read, S: Borrow<Schema>> { + reader: &'r mut R, + writer_schema: &'s Schema, + reader_schema: &'s Schema, + config: Config<'s, S>, +} + +impl<'s, 'r, R: Read, S: Borrow<Schema>> ResolvingDeserializer<'s, 'r, R, S> { + pub fn new( + reader: &'r mut R, + writer_schema: &'s Schema, + reader_schema: &'s Schema, + config: Config<'s, S>, + ) -> Result<Self, Error> { + if let Schema::Ref { name } = writer_schema { + let writer_schema = config.writer_names.get(name).ok_or_else(|| Details::SchemaResolutionError(name.clone()))?.borrow(); + Self::new(reader, writer_schema, reader_schema, config) + } else if let Schema::Ref { name } = reader_schema { + let reader_schema = config.reader_names.get(name).ok_or_else(|| Details::SchemaResolutionError(name.clone()))?.borrow(); + Self::new(reader, writer_schema, reader_schema, config) + } else { + Ok(Self { + reader, + writer_schema, + reader_schema, + config, + }) + } + } + + fn error(&self, ty: &'static str, error: impl Into<String>) -> Error { + Error::new(Details::DeserializeValueWithResolvingSchema { + value_type: ty, + value: error.into(), + writer_schema: 
self.writer_schema.clone(), + reader_schema: self.reader_schema.clone(), + }) + } + + /// Create a new deserializer with the existing reader and config. + /// + /// This will resolve the schema if it is a reference. + fn with_different_schema(mut self, writer_schema: &'s Schema, reader_schema: &'s Schema) -> Result<Self, Error> { + if let Schema::Ref { name } = writer_schema { + let writer_schema = self.config.writer_names.get(name).ok_or_else(|| Details::SchemaResolutionError(name.clone()))?.borrow(); + self.with_different_schema(writer_schema, reader_schema) + } else if let Schema::Ref { name } = reader_schema { + let reader_schema = self.config.reader_names.get(name).ok_or_else(|| Details::SchemaResolutionError(name.clone()))?.borrow(); + self.with_different_schema(writer_schema, reader_schema) + } else { + self.writer_schema = writer_schema; + self.reader_schema = reader_schema; + Ok(self) + } + } + + /// Read the union and create a new deserializer with the existing reader and config. + /// + /// This will resolve the read schema if it is a reference. 
+ fn with_reader_union(self, reader_schema: &'s UnionSchema) -> Result<Self, Error> { + if let Schema::Union(writer_schema) = self.writer_schema { + let index = zag_i32(self.reader)?; + let writer_schema = + writer_schema + .variants() + .get(index as usize) + .ok_or_else(|| Details::GetUnionVariant { + index: index as i64, + num_variants: writer_schema.variants().len(), + })?; + let Some(index) = reader_schema.find_compatible_variant(writer_schema, self.config.reader_names, self.config.writer_names) else { + panic!("writer variant is not in reader variant") + }; + let reader_schema = &reader_schema.variants()[index]; + self.with_different_schema(writer_schema, reader_schema) + } else if let Some(index) = reader_schema.find_compatible_variant(self.writer_schema, self.config.reader_names, self.config.writer_names) { + let reader_schema = &reader_schema.variants()[index]; + let writer_schema = self.writer_schema; + self.with_different_schema(writer_schema, reader_schema) + } else { + panic!("No match found") + } + } + + /// Read the union and create a new deserializer with the existing reader and config. + /// + /// This will resolve the read schema if it is a reference. 
+ fn with_writer_union(self, writer_schema: &'s UnionSchema) -> Result<Self, Error> { + if let Schema::Union(_) = self.reader_schema { + unreachable!("This should only be called if only the writer schema is a union") + } + let index = zag_i32(self.reader)?; + let writer_schema = + writer_schema + .variants() + .get(index as usize) + .ok_or_else(|| Details::GetUnionVariant { + index: index as i64, + num_variants: writer_schema.variants().len(), + })?; + let reader_schema = self.reader_schema; + self.with_different_schema(writer_schema, reader_schema) + } + + fn read_int(&mut self, original_ty: &'static str) -> Result<i32, Error> { + match self.reader_schema { + Schema::Int | Schema::Date | Schema::TimeMillis => SchemaAwareDeserializer::new(self.reader, self.writer_schema, self.config.into())?.read_int(original_ty), + _ => Err(self.error(original_ty, "Expected Schema::Int | Schema::Date | Schema::TimeMillis for reader")), + } + } + + fn read_long(&mut self, original_ty: &'static str) -> Result<i64, Error> { + match self.reader_schema { + Schema::Long | Schema::TimeMicros | Schema::TimestampMillis | Schema::TimestampMicros + | Schema::TimestampNanos | Schema::LocalTimestampMillis | Schema::LocalTimestampMicros + | Schema::LocalTimestampNanos => match self.writer_schema { + Schema::Long | Schema::TimeMicros | Schema::TimestampMillis | Schema::TimestampMicros + | Schema::TimestampNanos | Schema::LocalTimestampMillis | Schema::LocalTimestampMicros + | Schema::LocalTimestampNanos => SchemaAwareDeserializer::new(self.reader, self.writer_schema, self.config.into())?.read_long(original_ty), + Schema::Int | Schema::Date | Schema::TimeMillis => SchemaAwareDeserializer::new(self.reader, self.writer_schema, self.config.into())?.read_int(original_ty).map(i64::from), + _ => Err(self.error(original_ty, "Expected Schema::Int | Schema::Date | Schema::Long | Schema::Time{Millis,Micros} | Schema::{,Local}Timestamp{Millis,Micros,Nanos} for writer")), + }, + _ => 
Err(self.error(original_ty, "Expected Schema::Long | Schema::TimeMicros | Schema::{,Local}Timestamp{Millis,Micros,Nanos} for reader")), + } + } +} + +impl<'de, 's, 'r, R: Read, S: Borrow<Schema>> serde::Deserializer<'de> for ResolvingDeserializer<'s, 'r, R, S> { + type Error = Error; + + fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de> + { + match self.reader_schema { + Schema::Null => self.deserialize_unit(visitor), + Schema::Boolean => self.deserialize_bool(visitor), + Schema::Int | Schema::Date | Schema::TimeMillis => self.deserialize_i32(visitor), + Schema::Long + | Schema::TimeMicros + | Schema::TimestampMillis + | Schema::TimestampMicros + | Schema::TimestampNanos + | Schema::LocalTimestampMillis + | Schema::LocalTimestampMicros + | Schema::LocalTimestampNanos => self.deserialize_i64(visitor), + Schema::Float => self.deserialize_f32(visitor), + Schema::Double => self.deserialize_f64(visitor), + Schema::Bytes + | Schema::Fixed(_) + | Schema::Decimal(_) + | Schema::BigDecimal + | Schema::Uuid(UuidSchema::Fixed(_)) + | Schema::Duration(_) => self.deserialize_byte_buf(visitor), + Schema::String | Schema::Uuid(UuidSchema::String) => self.deserialize_string(visitor), + Schema::Array(_) => self.deserialize_seq(visitor), + Schema::Map(_) => self.deserialize_map(visitor), + Schema::Union(union) => self.with_reader_union(union)?.deserialize_any(visitor), + Schema::Record(schema) => { + if schema.attributes.get("org.apache.avro.rust.tuple") + == Some(&serde_json::Value::Bool(true)) + { + // This is needed because we can't tell the difference between a tuple and struct. + // And a tuple needs to be deserialized as a sequence + self.deserialize_tuple(schema.fields.len(), visitor) + } else { + self.deserialize_struct(DESERIALIZE_ANY, DESERIALIZE_ANY_FIELDS, visitor) + } + } + Schema::Enum(_) => { + self.deserialize_enum(DESERIALIZE_ANY, DESERIALIZE_ANY_FIELDS, visitor) + } + Schema::Ref { .. 
} => unreachable!("References are resolved on deserializer creation"), + Schema::Uuid(UuidSchema::Bytes) => panic!("Unsupported"), + } + } + + fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de> + { + match self.reader_schema { + Schema::Union(union) => self.with_reader_union(union)?.deserialize_bool(visitor), + Schema::Boolean => match self.writer_schema { + Schema::Union(union) => self.with_writer_union(union)?.deserialize_bool(visitor), + Schema::Boolean => SchemaAwareDeserializer::new(self.reader, self.reader_schema, self.config.into())?.deserialize_bool(visitor), + _ => Err(self.error("bool", "Expected Schema::Boolean for writer")) + }, + _ => Err(self.error("bool", "Expected Schema::Boolean for reader")) + } + } + + fn deserialize_i8<V>(mut self, visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de> + { + if let Schema::Union(union) = self.reader_schema { + self.with_reader_union(union)?.deserialize_i8(visitor) + } else { + let int = self.read_int("i8")?; + let value = i8::try_from(int) + .map_err(|_| self.error("i8", format!("Could not convert int ({int}) to an i8")))?; + visitor.visit_i8(value) + } + } + + fn deserialize_i16<V>(mut self, visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de> + { + if let Schema::Union(union) = self.reader_schema { + self.with_reader_union(union)?.deserialize_i16(visitor) + } else { + let int = self.read_int("i16")?; + let value = i16::try_from(int).map_err(|_| { + self.error("i16", format!("Could not convert int ({int}) to an i16")) + })?; + visitor.visit_i16(value) + } + } + + fn deserialize_i32<V>(mut self, visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de> + { + if let Schema::Union(union) = self.reader_schema { + self.with_reader_union(union)?.deserialize_i32(visitor) + } else { + let value = self.read_int("i32")?; + visitor.visit_i32(value) + } + } + + fn deserialize_i64<V>(mut self, visitor: V) -> Result<V::Value, 
Self::Error> + where + V: Visitor<'de> + { + if let Schema::Union(union) = self.reader_schema { + self.with_reader_union(union)?.deserialize_i64(visitor) + } else { + let value = self.read_long("i64")?; + visitor.visit_i64(value) + } + } + + fn deserialize_i128<V>(self, visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de> + { + match self.reader_schema { + Schema::Union(union) => self.with_reader_union(union)?.deserialize_i128(visitor), + Schema::Fixed(fixed) if fixed.size == 16 && fixed.name.name() == "i128" => match self.writer_schema { + Schema::Union(union) => self.with_writer_union(union)?.deserialize_i128(visitor), + Schema::Fixed(fixed) if fixed.size == 16 && fixed.name.name() == "i128" => SchemaAwareDeserializer::new(self.reader, self.reader_schema, self.config.into())?.deserialize_i128(visitor), + _ => Err(self.error("i128", r#"Expected Schema::Fixed(name: "i128", size: 16) for writer"#)) + } + _ => Err(self.error("i128", r#"Expected Schema::Fixed(name: "i128", size: 16) for reader"#)), + } + } + + fn deserialize_u8<V>(mut self, visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de> + { + if let Schema::Union(union) = self.reader_schema { + self.with_reader_union(union)?.deserialize_u8(visitor) + } else { + let int = self.read_int("u8")?; + let value = u8::try_from(int) + .map_err(|_| self.error("u8", format!("Could not convert int ({int}) to an u8")))?; + visitor.visit_u8(value) + } + } + + fn deserialize_u16<V>(mut self, visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de> + { + if let Schema::Union(union) = self.reader_schema { + self.with_reader_union(union)?.deserialize_u16(visitor) + } else { + let int = self.read_int("u16")?; + let value = u16::try_from(int).map_err(|_| { + self.error("u16", format!("Could not convert int ({int}) to an u16")) + })?; + visitor.visit_u16(value) + } + } + + fn deserialize_u32<V>(mut self, visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de> + { + let 
int = self.read_long("u32")?; + let value = u32::try_from(int).map_err(|_| { + self.error("u32", format!("Could not convert int ({int}) to an u32")) + })?; + visitor.visit_u32(value) + } + + fn deserialize_u64<V>(self, visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de> + { + match self.reader_schema { + Schema::Union(union) => self.with_reader_union(union)?.deserialize_u64(visitor), + Schema::Fixed(fixed) if fixed.size == 8 && fixed.name.name() == "u64" => match self.writer_schema { + Schema::Union(union) => self.with_writer_union(union)?.deserialize_u64(visitor), + Schema::Fixed(fixed) if fixed.size == 8 && fixed.name.name() == "u64" => SchemaAwareDeserializer::new(self.reader, self.reader_schema, self.config.into())?.deserialize_u64(visitor), + _ => Err(self.error("u64", r#"Expected Schema::Fixed(name: "u64", size: 8) for writer"#)) + } + _ => Err(self.error("u64", r#"Expected Schema::Fixed(name: "u64", size: 8) for reader"#)), + } + } + + fn deserialize_u128<V>(self, visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de> + { + match self.reader_schema { + Schema::Union(union) => self.with_reader_union(union)?.deserialize_u128(visitor), + Schema::Fixed(fixed) if fixed.size == 8 && fixed.name.name() == "u128" => match self.writer_schema { + Schema::Union(union) => self.with_writer_union(union)?.deserialize_u128(visitor), + Schema::Fixed(fixed) if fixed.size == 8 && fixed.name.name() == "u128" => SchemaAwareDeserializer::new(self.reader, self.reader_schema, self.config.into())?.deserialize_u128(visitor), + _ => Err(self.error("u128", r#"Expected Schema::Fixed(name: "u128", size: 16) for writer"#)) + } + _ => Err(self.error("u128", r#"Expected Schema::Fixed(name: "u128", size: 16) for reader"#)), + } + } + + fn deserialize_f32<V>(mut self, visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de> + { + match self.reader_schema { + Schema::Union(union) => self.with_reader_union(union)?.deserialize_f32(visitor), + 
Schema::Float => match self.writer_schema { + Schema::Union(union) => self.with_writer_union(union)?.deserialize_f32(visitor), + Schema::Float => SchemaAwareDeserializer::new(self.reader, self.reader_schema, self.config.into())?.deserialize_f32(visitor), + Schema::Int | Schema::Date | Schema::TimeMillis => { + let value = i32::deserialize(SchemaAwareDeserializer::new(self.reader, self.writer_schema, self.config.into())?)?; + visitor.visit_f32(value as f32) + } + Schema::Long | Schema::TimeMicros | Schema::TimestampMillis | Schema::TimestampMicros + | Schema::TimestampNanos | Schema::LocalTimestampMillis | Schema::LocalTimestampMicros + | Schema::LocalTimestampNanos => { + let value = i64::deserialize(SchemaAwareDeserializer::new(self.reader, self.writer_schema, self.config.into())?)?; + visitor.visit_f32(value as f32) + } + _ => Err(self.error("f32", "Expected Schema::Float | Schema::Int | Schema::Date | Schema::Long | Schema::Time{Millis,Micros} | Schema::{,Local}Timestamp{Millis,Micros,Nanos} for writer")) + } + _ => Err(self.error("f32", "Expected Schema::Float for reader")), + } + } + + fn deserialize_f64<V>(self, visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de> + { + match self.reader_schema { + Schema::Union(union) => self.with_reader_union(union)?.deserialize_f64(visitor), + Schema::Float => match self.writer_schema { + Schema::Union(union) => self.with_writer_union(union)?.deserialize_f64(visitor), + Schema::Double => SchemaAwareDeserializer::new(self.reader, self.reader_schema, self.config.into())?.deserialize_f64(visitor), + Schema::Float => { + let float = f32::deserialize(SchemaAwareDeserializer::new(self.reader, self.reader_schema, self.config.into())?)?; + visitor.visit_f64(float as f64) + } + Schema::Int | Schema::Date | Schema::TimeMillis => { + let value = i32::deserialize(SchemaAwareDeserializer::new(self.reader, self.writer_schema, self.config.into())?)?; + visitor.visit_f64(value as f64) + } + Schema::Long | 
Schema::TimeMicros | Schema::TimestampMillis | Schema::TimestampMicros + | Schema::TimestampNanos | Schema::LocalTimestampMillis | Schema::LocalTimestampMicros + | Schema::LocalTimestampNanos => { + let value = i64::deserialize(SchemaAwareDeserializer::new(self.reader, self.writer_schema, self.config.into())?)?; + visitor.visit_f64(value as f64) + } + _ => Err(self.error("f64", "Expected Schema::Float | Schema::Double Schema::Int | Schema::Date | Schema::Long | Schema::Time{Millis,Micros} | Schema::{,Local}Timestamp{Millis,Micros,Nanos} for writer")) + } + _ => Err(self.error("f64", "Expected Schema::Double for reader")), + } + } + + fn deserialize_char<V>(mut self, visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de> + { + match self.reader_schema { + Schema::Union(union) => self.with_reader_union(union)?.deserialize_char(visitor), + Schema::String | Schema::Uuid(UuidSchema::String) => match self.writer_schema { + Schema::Union(union) => self.with_writer_union(union)?.deserialize_char(visitor), + Schema::String | Schema::Bytes | Schema::Uuid(UuidSchema::String | UuidSchema::Bytes) | Schema::Decimal(DecimalSchema { inner: InnerDecimalSchema::Bytes, ..}) | Schema::BigDecimal => { + let string = SchemaAwareDeserializer::new(self.reader, self.writer_schema, self.config.into())?.read_string()?; + let mut chars = string.chars(); + if let Some(character) = chars.next() { + if chars.next().is_some() { + Err(self.error("char", "String contains more than one character")) + } else { + visitor.visit_char(character) + } + } else { + Err(self.error("char", "String is empty")) + } + } + _ => Err(self.error("string", "Expected Schema::String | Schema::Bytes | Schema::Uuid(String | Bytes) | Schema::Decimal(Bytes) | Schema::BigDecimal for writer")), + } + _ => Err(self.error("string", "Expected Schema::String | Schema::Uuid(String) for reader")), + } + } + + fn deserialize_str<V>(self, visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de> + { + 
self.deserialize_string(visitor) + } + + fn deserialize_string<V>(self, visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de> + { + match self.reader_schema { + Schema::Union(union) => self.with_reader_union(union)?.deserialize_string(visitor), + Schema::String | Schema::Uuid(UuidSchema::String) => match self.writer_schema { + Schema::Union(union) => self.with_writer_union(union)?.deserialize_string(visitor), + Schema::String | Schema::Bytes | Schema::Uuid(UuidSchema::String | UuidSchema::Bytes) | Schema::Decimal(DecimalSchema { inner: InnerDecimalSchema::Bytes, ..}) | Schema::BigDecimal => { + let string = SchemaAwareDeserializer::new(self.reader, self.writer_schema, self.config.into())?.read_string()?; + visitor.visit_string(string) + } + _ => Err(self.error("string", "Expected Schema::String | Schema::Bytes | Schema::Uuid(String | Bytes) | Schema::Decimal(Bytes) | Schema::BigDecimal for writer")), + } + _ => Err(self.error("string", "Expected Schema::String | Schema::Uuid(String) for reader")), + } + } + + fn deserialize_bytes<V>(self, visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de> + { + self.deserialize_byte_buf(visitor) + } + + fn deserialize_byte_buf<V>(self, visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de> + { + match self.reader_schema { + Schema::Union(union) => self.with_reader_union(union)?.deserialize_bytes(visitor), + Schema::Bytes | Schema::Uuid(UuidSchema::Bytes) | Schema::Decimal(DecimalSchema { inner: InnerDecimalSchema::Bytes, ..}) | Schema::BigDecimal=> match self.writer_schema { + Schema::Union(union) => self.with_writer_union(union)?.deserialize_bytes(visitor), + Schema::String | Schema::Bytes | Schema::Uuid(UuidSchema::String | UuidSchema::Bytes) | Schema::Decimal(DecimalSchema { inner: InnerDecimalSchema::Bytes, ..}) | Schema::BigDecimal => { + let bytes = SchemaAwareDeserializer::new(self.reader, self.writer_schema, self.config.into())?.read_bytes_with_len()?; + 
visitor.visit_byte_buf(bytes) + } + _ => Err(self.error("bytes", "Expected Schema::String | Schema::Bytes | Schema::Uuid(String | Bytes) | Schema::Decimal(Bytes) | Schema::BigDecimal for writer")), + } + _ => Err(self.error("bytes", "Expected Schema::Bytes | Schema::Uuid(Bytes) | Schema::Decimal(Bytes) | Schema::BigDecimal for reader")), + } + } + + fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de> + { + // The reader schema must be a Union([Null, _]) + if let Schema::Union(union) = self.reader_schema + && union.variants().len() == 2 + && let Some(null_index) = union.index().get(&SchemaKind::Null).copied() + { + let some_index = (null_index + 1) & 1; + // Map the writer schema to the reader Some or None + match self.writer_schema { + Schema::Union(union) => { + let index = zag_i32(self.reader)? as usize; + let writer_schema = &union.variants()[index]; + if writer_schema == &Schema::Null { + visitor.visit_none() + } else { + let reader_schema = &union.variants()[some_index]; + visitor.visit_some(self.with_different_schema(writer_schema, reader_schema)?) + } + }, + Schema::Null => visitor.visit_none(), + _ => { + let reader_schema = &union.variants()[some_index]; + let writer_schema = self.writer_schema; + visitor.visit_some(self.with_different_schema(writer_schema, reader_schema)?) 
+ } + } + } else { + Err(self.error("option", "Expected Schema::Union([Null, _])")) + } + } + + fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de> + { + match self.reader_schema { + Schema::Union(union) => self.with_reader_union(union)?.deserialize_unit(visitor), + Schema::Null => match self.writer_schema { + Schema::Union(union) => self.with_writer_union(union)?.deserialize_unit(visitor), + Schema::Null => visitor.visit_unit(), + _ => Err(self.error("unit", "Expected Schema::Null for writer")), + } + _ => Err(self.error("unit", "Expected Schema::Null for reader")), + } + } + + fn deserialize_unit_struct<V>(self, name: &'static str, visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de> + { + match self.reader_schema { + Schema::Union(union) => self.with_reader_union(union)?.deserialize_unit_struct(name, visitor), + Schema::Record(record) if record.fields.is_empty() && record.name.name() == name => match self.writer_schema { + Schema::Union(union) => self.with_writer_union(union)?.deserialize_unit_struct(name, visitor), + Schema::Record(record) if record.name.name() == name => { + if !record.fields.is_empty() { + // Ignore all fields + SchemaAwareDeserializer::new(self.reader, self.writer_schema, self.config.into())?.deserialize_struct(name, DESERIALIZE_ANY_FIELDS, AnyVisitor)?; + } + visitor.visit_unit() + } + _ => Err(self.error("unit", "Expected Schema::Null for writer")), + } + _ => Err(self.error("unit", "Expected Schema::Null for reader")), + } + } + + fn deserialize_newtype_struct<V>(self, name: &'static str, visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de> + { + match self.reader_schema { + Schema::Union(union) => self.with_reader_union(union)?.deserialize_newtype_struct(name, visitor), + Schema::Record(w_record) if w_record.fields.len() == 1 && w_record.name.name() == name => match self.writer_schema { + Schema::Union(union) => 
self.with_writer_union(union)?.deserialize_newtype_struct(name, visitor), + Schema::Record(r_record) if r_record.name.name() == name && r_record.fields.is_empty() => { + let field = &w_record.fields[0]; + if let Some(default) = &field.default { + + todo!() + } else { + Err(self.error("newtype struct", "Writer is missing field and no default is available")) + } + } + Schema::Record(r_record) if r_record.name.name() == name && r_record.fields.len() == 1 => { + visitor.visit_newtype_struct(self.with_different_schema(&w_record.fields[0].schema, &r_record.fields[0].schema)?) + } + Schema::Record(r_record) if r_record.name.name() == name => { + todo!("Skip all fields that do not match") + } + _ => Err(self.error("unit", "Expected Schema::Null for writer")), + } + _ => Err(self.error("unit", "Expected Schema::Null for reader")), + } + } + + fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de> + { + todo!() + } + + fn deserialize_tuple<V>(self, len: usize, visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de> + { + todo!() + } + + fn deserialize_tuple_struct<V>(self, name: &'static str, len: usize, visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de> + { + todo!() + } + + fn deserialize_map<V>(self, visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de> + { + todo!() + } + + fn deserialize_struct<V>(self, name: &'static str, fields: &'static [&'static str], visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de> + { + todo!() + } + + fn deserialize_enum<V>(self, name: &'static str, variants: &'static [&'static str], visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de> + { + todo!() + } + + fn deserialize_identifier<V>(self, visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de> + { + todo!() + } + + fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de> + { + todo!() + } + + fn 
is_human_readable(&self) -> bool { + self.config.human_readable + } +} diff --git a/avro/src/serde/deser_resolving/record.rs b/avro/src/serde/deser_resolving/record.rs new file mode 100644 index 0000000..409dfe4 --- /dev/null +++ b/avro/src/serde/deser_resolving/record.rs @@ -0,0 +1 @@ +mod default; \ No newline at end of file diff --git a/avro/src/serde/deser_resolving/record/default.rs b/avro/src/serde/deser_resolving/record/default.rs new file mode 100644 index 0000000..f5c6718 --- /dev/null +++ b/avro/src/serde/deser_resolving/record/default.rs @@ -0,0 +1,340 @@ +use std::borrow::Borrow; +use serde::de::Visitor; +use serde::Deserializer; +use serde_json::Value; +use crate::{Error, Schema}; +use crate::error::Details; +use crate::schema::{UnionSchema, UuidSchema}; +use crate::serde::deser_schema::{Config, DESERIALIZE_ANY, DESERIALIZE_ANY_FIELDS}; +use crate::util::zag_i32; + +pub struct DefaultDeserializer<'s, S: Borrow<Schema>> { + default: &'s Value, + schema: &'s Schema, + config: &'s Config<'s, S>, +} + +impl<'s, S: Borrow<Schema>> DefaultDeserializer<'s, S> { + pub fn new(default: &'s Value, schema: &'s Schema, config: &'s Config<'s, S>) -> Result<Self, Error> { + if let Schema::Ref { name } = schema { + let schema = config + .names + .get(name) + .ok_or_else(|| Details::SchemaResolutionError(name.clone()))? + .borrow(); + Self::new(default, schema, config) + } else { + Ok(Self { + default, + schema, + config, + }) + } + } + + fn error(&self, ty: &'static str, error: impl Into<String>) -> Error { + Error::new(Details::DeserializeValueWithSchema { + value_type: ty, + value: error.into(), + schema: self.schema.clone(), + }) + } + + /// Create a new deserializer with the existing reader and config. + /// + /// This will resolve the schema if it is a reference. 
+ fn with_different_schema(mut self, schema: &'s Schema) -> Result<Self, Error> { + let schema = if let Schema::Ref { name } = schema { + self.config + .names + .get(name) + .ok_or_else(|| Details::SchemaResolutionError(name.clone()))? + .borrow() + } else { + schema + }; + self.schema = schema; + Ok(self) + } + + /// Read the union and create a new deserializer with the existing reader and config. + /// + /// This will resolve the read schema if it is a reference. + fn with_union(self, schema: &'s UnionSchema) -> Result<Self, Error> { + let index = zag_i32(self.reader)?; + let variant = + schema + .variants() + .get(index as usize) + .ok_or_else(|| Details::GetUnionVariant { + index: index as i64, + num_variants: schema.variants().len(), + })?; + self.with_different_schema(variant) + } +} + +impl<'de, 's, S: Borrow<Schema>> Deserializer<'de> for DefaultDeserializer<'s, S> { + type Error = Error; + + fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de> + { + match self.schema { + Schema::Null => self.deserialize_unit(visitor), + Schema::Boolean => self.deserialize_bool(visitor), + Schema::Int | Schema::Date | Schema::TimeMillis => self.deserialize_i32(visitor), + Schema::Long + | Schema::TimeMicros + | Schema::TimestampMillis + | Schema::TimestampMicros + | Schema::TimestampNanos + | Schema::LocalTimestampMillis + | Schema::LocalTimestampMicros + | Schema::LocalTimestampNanos => self.deserialize_i64(visitor), + Schema::Float => self.deserialize_f32(visitor), + Schema::Double => self.deserialize_f64(visitor), + Schema::Bytes + | Schema::Fixed(_) + | Schema::Decimal(_) + | Schema::BigDecimal + | Schema::Uuid(UuidSchema::Fixed(_)) + | Schema::Duration(_) => self.deserialize_byte_buf(visitor), + Schema::String | Schema::Uuid(UuidSchema::String) => self.deserialize_string(visitor), + Schema::Array(_) => self.deserialize_seq(visitor), + Schema::Map(_) => self.deserialize_map(visitor), + Schema::Union(union) => 
self.with_union(union)?.deserialize_any(visitor), + Schema::Record(schema) => { + if schema.attributes.get("org.apache.avro.rust.tuple") + == Some(&serde_json::Value::Bool(true)) + { + // This is needed because we can't tell the difference between a tuple and struct. + // And a tuple needs to be deserialized as a sequence + self.deserialize_tuple(schema.fields.len(), visitor) + } else { + self.deserialize_struct(DESERIALIZE_ANY, DESERIALIZE_ANY_FIELDS, visitor) + } + } + Schema::Enum(_) => { + self.deserialize_enum(DESERIALIZE_ANY, DESERIALIZE_ANY_FIELDS, visitor) + } + Schema::Ref { .. } => unreachable!("References are resolved on deserializer creation"), + Schema::Uuid(UuidSchema::Bytes) => panic!("Unsupported"), + } + } + + fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de> + { + todo!() + } + + fn deserialize_i8<V>(self, visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de> + { + todo!() + } + + fn deserialize_i16<V>(self, visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de> + { + todo!() + } + + fn deserialize_i32<V>(self, visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de> + { + todo!() + } + + fn deserialize_i64<V>(self, visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de> + { + todo!() + } + + fn deserialize_i128<V>(self, visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de> + { + todo!() + } + + fn deserialize_u8<V>(self, visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de> + { + todo!() + } + + fn deserialize_u16<V>(self, visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de> + { + todo!() + } + + fn deserialize_u32<V>(self, visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de> + { + todo!() + } + + fn deserialize_u64<V>(self, visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de> + { + todo!() + } + + fn deserialize_u128<V>(self, visitor: V) -> 
Result<V::Value, Self::Error> + where + V: Visitor<'de> + { + todo!() + } + + fn deserialize_f32<V>(self, visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de> + { + todo!() + } + + fn deserialize_f64<V>(self, visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de> + { + todo!() + } + + fn deserialize_char<V>(self, visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de> + { + todo!() + } + + fn deserialize_str<V>(self, visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de> + { + todo!() + } + + fn deserialize_string<V>(self, visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de> + { + todo!() + } + + fn deserialize_bytes<V>(self, visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de> + { + todo!() + } + + fn deserialize_byte_buf<V>(self, visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de> + { + todo!() + } + + fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de> + { + todo!() + } + + fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de> + { + todo!() + } + + fn deserialize_unit_struct<V>(self, name: &'static str, visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de> + { + todo!() + } + + fn deserialize_newtype_struct<V>(self, name: &'static str, visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de> + { + todo!() + } + + fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de> + { + todo!() + } + + fn deserialize_tuple<V>(self, len: usize, visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de> + { + todo!() + } + + fn deserialize_tuple_struct<V>(self, name: &'static str, len: usize, visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de> + { + todo!() + } + + fn deserialize_map<V>(self, visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de> 
+ { + todo!() + } + + fn deserialize_struct<V>(self, name: &'static str, fields: &'static [&'static str], visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de> + { + todo!() + } + + fn deserialize_enum<V>(self, name: &'static str, variants: &'static [&'static str], visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de> + { + todo!() + } + + fn deserialize_identifier<V>(self, visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de> + { + todo!() + } + + fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de> + { + todo!() + } + + fn is_human_readable(&self) -> bool { + todo!() + } +} diff --git a/avro/src/serde/deser_schema/mod.rs b/avro/src/serde/deser_schema/mod.rs index cf9d0d7..5ce3947 100644 --- a/avro/src/serde/deser_schema/mod.rs +++ b/avro/src/serde/deser_schema/mod.rs @@ -47,11 +47,11 @@ use crate::{ #[derive(Debug)] pub struct Config<'s, S: Borrow<Schema>> { - /// All names that can be referenced in the schema being used for serialisation. + /// All names that can be referenced in the schema being used for deserialization. pub names: &'s HashMap<Name, S>, - /// Should `Serialize` implementations pick a human-readable format. + /// Should `Deserialize` implementations pick a human-readable format. /// - /// It is recommended to set this to `false` as it results in compacter output. + /// This should match the setting used for serialization. 
pub human_readable: bool, } @@ -141,7 +141,7 @@ impl<'s, 'r, R: Read, S: Borrow<Schema>> SchemaAwareDeserializer<'s, 'r, R, S> { self.with_different_schema(variant) } - fn read_int(&mut self, original_ty: &'static str) -> Result<i32, Error> { + pub(crate) fn read_int(&mut self, original_ty: &'static str) -> Result<i32, Error> { match self.schema { Schema::Int | Schema::Date | Schema::TimeMillis => zag_i32(self.reader), _ => Err(self.error( @@ -151,7 +151,7 @@ impl<'s, 'r, R: Read, S: Borrow<Schema>> SchemaAwareDeserializer<'s, 'r, R, S> { } } - fn read_long(&mut self, original_ty: &'static str) -> Result<i64, Error> { + pub(crate) fn read_long(&mut self, original_ty: &'static str) -> Result<i64, Error> { match self.schema { Schema::Long | Schema::TimeMicros | Schema::TimestampMillis | Schema::TimestampMicros | Schema::TimestampNanos | Schema::LocalTimestampMillis | Schema::LocalTimestampMicros @@ -163,12 +163,12 @@ impl<'s, 'r, R: Read, S: Borrow<Schema>> SchemaAwareDeserializer<'s, 'r, R, S> { } } - fn read_string(&mut self) -> Result<String, Error> { + pub(crate) fn read_string(&mut self) -> Result<String, Error> { let bytes = self.read_bytes_with_len()?; Ok(String::from_utf8(bytes).map_err(Details::ConvertToUtf8)?) 
} - fn read_bytes_with_len(&mut self) -> Result<Vec<u8>, Error> { + pub(crate) fn read_bytes_with_len(&mut self) -> Result<Vec<u8>, Error> { let length = decode_len(self.reader)?; self.read_bytes(length) } @@ -182,8 +182,8 @@ impl<'s, 'r, R: Read, S: Borrow<Schema>> SchemaAwareDeserializer<'s, 'r, R, S> { } } -static DESERIALIZE_ANY: &str = "This value is compared by pointer value"; -static DESERIALIZE_ANY_FIELDS: &[&str] = &[]; +pub(super) static DESERIALIZE_ANY: &str = "This value is compared by pointer value"; +pub(super) static DESERIALIZE_ANY_FIELDS: &[&str] = &[]; impl<'de, 's, 'r, R: Read, S: Borrow<Schema>> Deserializer<'de> for SchemaAwareDeserializer<'s, 'r, R, S> @@ -317,6 +317,23 @@ impl<'de, 's, 'r, R: Read, S: Borrow<Schema>> Deserializer<'de> } } + fn deserialize_i128<V>(self, visitor: V) -> Result<V::Value, Self::Error> + where + V: serde::de::Visitor<'de>, + { + match self.schema { + Schema::Fixed(fixed) if fixed.size == 16 && fixed.name.name() == "i128" => { + let mut buf = [0; 16]; + self.reader + .read_exact(&mut buf) + .map_err(Details::ReadBytes)?; + visitor.visit_i128(i128::from_le_bytes(buf)) + } + Schema::Union(union) => self.with_union(union)?.deserialize_i128(visitor), + _ => Err(self.error("i128", r#"Expected Schema::Fixed(name: "i128", size: 16)"#)), + } + } + fn deserialize_u8<V>(mut self, visitor: V) -> Result<V::Value, Self::Error> where V: serde::de::Visitor<'de>, @@ -378,6 +395,23 @@ impl<'de, 's, 'r, R: Read, S: Borrow<Schema>> Deserializer<'de> } } + fn deserialize_u128<V>(self, visitor: V) -> Result<V::Value, Self::Error> + where + V: serde::de::Visitor<'de>, + { + match self.schema { + Schema::Fixed(fixed) if fixed.size == 16 && fixed.name.name() == "u128" => { + let mut buf = [0; 16]; + self.reader + .read_exact(&mut buf) + .map_err(Details::ReadBytes)?; + visitor.visit_u128(u128::from_le_bytes(buf)) + } + Schema::Union(union) => self.with_union(union)?.deserialize_u128(visitor), + _ => Err(self.error("u128", r#"Expected 
Schema::Fixed(name: "u128", size: 16)"#)), + } + } + fn deserialize_f32<V>(self, visitor: V) -> Result<V::Value, Self::Error> where V: serde::de::Visitor<'de>, @@ -673,40 +707,6 @@ impl<'de, 's, 'r, R: Read, S: Borrow<Schema>> Deserializer<'de> } } - fn deserialize_i128<V>(self, visitor: V) -> Result<V::Value, Self::Error> - where - V: serde::de::Visitor<'de>, - { - match self.schema { - Schema::Fixed(fixed) if fixed.size == 16 && fixed.name.name() == "i128" => { - let mut buf = [0; 16]; - self.reader - .read_exact(&mut buf) - .map_err(Details::ReadBytes)?; - visitor.visit_i128(i128::from_le_bytes(buf)) - } - Schema::Union(union) => self.with_union(union)?.deserialize_i128(visitor), - _ => Err(self.error("i128", r#"Expected Schema::Fixed(name: "i128", size: 16)"#)), - } - } - - fn deserialize_u128<V>(self, visitor: V) -> Result<V::Value, Self::Error> - where - V: serde::de::Visitor<'de>, - { - match self.schema { - Schema::Fixed(fixed) if fixed.size == 16 && fixed.name.name() == "u128" => { - let mut buf = [0; 16]; - self.reader - .read_exact(&mut buf) - .map_err(Details::ReadBytes)?; - visitor.visit_u128(u128::from_le_bytes(buf)) - } - Schema::Union(union) => self.with_union(union)?.deserialize_u128(visitor), - _ => Err(self.error("u128", r#"Expected Schema::Fixed(name: "u128", size: 16)"#)), - } - } - fn is_human_readable(&self) -> bool { self.config.human_readable } diff --git a/avro/src/serde/mod.rs b/avro/src/serde/mod.rs index 11ffd84..5ad5638 100644 --- a/avro/src/serde/mod.rs +++ b/avro/src/serde/mod.rs @@ -114,6 +114,7 @@ mod ser; pub(crate) mod ser_schema; mod util; mod with; +// mod deser_resolving; #[expect( deprecated, diff --git a/avro/src/serde/ser_schema/mod.rs b/avro/src/serde/ser_schema/mod.rs index 0387097..bff7ccd 100644 --- a/avro/src/serde/ser_schema/mod.rs +++ b/avro/src/serde/ser_schema/mod.rs @@ -602,25 +602,8 @@ impl<'s, 'w, W: Write> Serializer for SchemaAwareSerializer<'s, 'w, W> { && record.name.name() == variant && record.fields.len() 
== len { - // Union of records let bytes_written = encode_int(variant_index as i32, &mut *self.writer)?; - Ok(ManyTupleSerializer::new( - self.writer, - record, - self.config, - Some(bytes_written), - )) - } else if let Some((index, Schema::Record(record))) = union - .variants() - .iter() - .enumerate() - .find(|(_i, s)| s.name().is_some_and(|n| n.name() == variant)) - && record.fields.len() == len - { - // Bare union - let bytes_written = encode_int(index as i32, &mut *self.writer)?; - Ok(ManyTupleSerializer::new( self.writer, record, @@ -669,6 +652,9 @@ impl<'s, 'w, W: Write> Serializer for SchemaAwareSerializer<'s, 'w, W> { len: usize, ) -> Result<Self::SerializeStruct, Self::Error> { match self.schema { + // Serde is inconsistent with the `name` and `len` provided. When using internally tagged + // enums the name can be the name of the inner type of a newtype variant. The length can + // also change based on `serialize_if`. Schema::Record(record) => Ok(RecordSerializer::new( self.writer, record, @@ -694,25 +680,8 @@ impl<'s, 'w, W: Write> Serializer for SchemaAwareSerializer<'s, 'w, W> { && record.name.name() == variant && record.fields.len() == len { - // Union of records let bytes_written = encode_int(variant_index as i32, &mut *self.writer)?; - Ok(RecordSerializer::new( - self.writer, - record, - self.config, - Some(bytes_written), - )) - } else if let Some((index, Schema::Record(record))) = union - .variants() - .iter() - .enumerate() - .find(|(_i, s)| s.name().is_some_and(|n| n.name() == variant)) - && record.fields.len() == len - { - // Bare union - let bytes_written = encode_int(index as i32, &mut *self.writer)?; - Ok(RecordSerializer::new( self.writer, record, @@ -1993,7 +1962,6 @@ mod tests { // TODO: Figure out what to do with Option<Enum> mapping to Union([Null, ..]) #[test] - #[ignore] fn avro_rs_337_serialize_union_record_variant() -> TestResult { let schema = Schema::parse_str( r#"{ @@ -2072,90 +2040,6 @@ mod tests { Ok(()) } - // TODO: Figure out 
what to do with Option<Enum> mapping to Union([Null, ..]) - #[test] - #[ignore] - fn avro_rs_337_serialize_option_union_record_variant() -> TestResult { - let schema = Schema::parse_str( - r#"{ - "type": "record", - "name": "TestRecord", - "fields": [{ - "name": "innerUnion", "type": [ - "null", - {"type": "record", "name": "innerRecordFoo", "fields": [ - {"name": "foo", "type": "string"} - ]}, - {"type": "record", "name": "innerRecordBar", "fields": [ - {"name": "bar", "type": "string"} - ]}, - {"name": "intField", "type": "int"}, - {"name": "stringField", "type": "string"} - ], - }] - }"#, - )?; - - #[derive(Serialize)] - #[serde(rename_all = "camelCase")] - struct TestRecord { - inner_union: Option<InnerUnion>, - } - - #[derive(Serialize)] - #[serde(untagged)] - enum InnerUnion { - InnerVariantFoo(InnerRecordFoo), - InnerVariantBar(InnerRecordBar), - IntField(i32), - StringField(String), - } - - #[derive(Serialize)] - #[serde(rename = "innerRecordFoo")] - struct InnerRecordFoo { - foo: String, - } - - #[derive(Serialize)] - #[serde(rename = "innerRecordBar")] - struct InnerRecordBar { - bar: String, - } - - let mut buffer: Vec<u8> = Vec::new(); - let rs = ResolvedSchema::try_from(&schema)?; - let config = Config { - names: rs.get_names(), - target_block_size: None, - human_readable: false, - }; - - let null_record = TestRecord { inner_union: None }; - null_record.serialize(SchemaAwareSerializer::new(&mut buffer, &schema, config)?)?; - let foo_record = TestRecord { - inner_union: Some(InnerUnion::InnerVariantFoo(InnerRecordFoo { - foo: String::from("foo"), - })), - }; - foo_record.serialize(SchemaAwareSerializer::new(&mut buffer, &schema, config)?)?; - let bar_record = TestRecord { - inner_union: Some(InnerUnion::InnerVariantBar(InnerRecordBar { - bar: String::from("bar"), - })), - }; - bar_record.serialize(SchemaAwareSerializer::new(&mut buffer, &schema, config)?)?; - let int_record = TestRecord { - inner_union: Some(InnerUnion::IntField(1)), - }; - 
int_record.serialize(SchemaAwareSerializer::new(&mut buffer, &schema, config)?)?; - let string_record = TestRecord { - inner_union: Some(InnerUnion::StringField(String::from("string"))), - }; - string_record.serialize(SchemaAwareSerializer::new(&mut buffer, &schema, config)?)?; - Ok(()) - } - #[test] fn avro_rs_351_different_field_order_serde_vs_schema() -> TestResult { #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] diff --git a/avro/src/serde/with.rs b/avro/src/serde/with.rs index 8d06967..98f2887 100644 --- a/avro/src/serde/with.rs +++ b/avro/src/serde/with.rs @@ -117,6 +117,11 @@ pub mod bytes { None } + /// Returns `None` + pub fn field_default() -> Option<serde_json::Value> { + None + } + pub fn serialize<S>(bytes: &[u8], serializer: S) -> Result<S::Ok, S::Error> where S: Serializer, @@ -184,6 +189,11 @@ pub mod bytes_opt { None } + /// Returns `Some(serde_json::Value::Null)` + pub fn field_default() -> Option<serde_json::Value> { + Some(serde_json::Value::Null) + } + pub fn serialize<S, B>(bytes: &Option<B>, serializer: S) -> Result<S::Ok, S::Error> where S: Serializer, @@ -262,6 +272,11 @@ pub mod fixed { None } + /// Returns `None` + pub fn field_default() -> Option<serde_json::Value> { + None + } + pub fn serialize<S>(bytes: &[u8], serializer: S) -> Result<S::Ok, S::Error> where S: Serializer, @@ -336,6 +351,11 @@ pub mod fixed_opt { None } + /// Returns `Some(serde_json::Value::Null)` + pub fn field_default() -> Option<serde_json::Value> { + Some(serde_json::Value::Null) + } + pub fn serialize<S, B>(bytes: &Option<B>, serializer: S) -> Result<S::Ok, S::Error> where S: Serializer, @@ -405,6 +425,11 @@ pub mod slice { None } + /// Returns `None` + pub fn field_default() -> Option<serde_json::Value> { + None + } + pub fn serialize<S>(bytes: &[u8], serializer: S) -> Result<S::Ok, S::Error> where S: Serializer, @@ -475,6 +500,11 @@ pub mod slice_opt { None } + /// Returns `Some(serde_json::Value::Null)` + pub fn field_default() -> 
Option<serde_json::Value> { + Some(serde_json::Value::Null) + } + pub fn serialize<S, B>(bytes: &Option<B>, serializer: S) -> Result<S::Ok, S::Error> where S: Serializer, @@ -543,6 +573,11 @@ pub mod bigdecimal { None } + /// Returns `None` + pub fn field_default() -> Option<serde_json::Value> { + None + } + pub fn serialize<S>(decimal: &BigDecimal, serializer: S) -> Result<S::Ok, S::Error> where S: Serializer, @@ -558,9 +593,10 @@ pub mod bigdecimal { { let _bytes_guard = super::BytesTypeGuard::set(BytesType::Bytes); let _guard = super::BorrowedGuard::set(true); - let bytes: &'de [u8] = serde_bytes::deserialize(deserializer)?; + // We don't use &'de [u8] here as the deserializer doesn't support that + let bytes: Vec<u8> = serde_bytes::deserialize(deserializer)?; - deserialize_big_decimal(bytes).map_err(D::Error::custom) + deserialize_big_decimal(&bytes).map_err(D::Error::custom) } } @@ -615,6 +651,11 @@ pub mod bigdecimal_opt { None } + /// Returns `Some(serde_json::Value::Null)` + pub fn field_default() -> Option<serde_json::Value> { + Some(serde_json::Value::Null) + } + pub fn serialize<S>(decimal: &Option<BigDecimal>, serializer: S) -> Result<S::Ok, S::Error> where S: Serializer, @@ -690,6 +731,11 @@ pub mod array { None } + /// Returns `None` + pub fn field_default() -> Option<serde_json::Value> { + None + } + pub fn serialize<const N: usize, S, T>(value: &[T; N], serializer: S) -> Result<S::Ok, S::Error> where S: Serializer, @@ -765,6 +811,11 @@ pub mod array_opt { None } + /// Returns `Some(serde_json::Value::Null)` + pub fn field_default() -> Option<serde_json::Value> { + Some(serde_json::Value::Null) + } + pub fn serialize<const N: usize, S, T>( value: &Option<[T; N]>, serializer: S, diff --git a/avro_derive/src/attributes/mod.rs b/avro_derive/src/attributes/mod.rs index 396e6fa..979f03d 100644 --- a/avro_derive/src/attributes/mod.rs +++ b/avro_derive/src/attributes/mod.rs @@ -39,6 +39,99 @@ pub enum EnumRepr { RecordInternallyTagged { tag: String }, } 
+impl EnumRepr { + fn from_avro_and_serde( + avro: Option<avro::EnumRepr>, + tag: Option<String>, + content: Option<String>, + untagged: bool, + span: Span, + ) -> Result<Option<Self>, Vec<syn::Error>> { + let mut errors = Vec::new(); + + #[expect(clippy::manual_map, reason = "Intent is clearer this way")] + let repr = if let Some(repr) = avro { + match repr { + avro::EnumRepr::Enum => { + if tag.is_some() || content.is_some() || untagged { + errors.push(syn::Error::new( + span, + r#"AvroSchema: `#[avro(repr = "enum")]` is incompatible with `#[serde(tag = "..")]`, `#[serde(content = "..")]`, and `#[serde(untagged)]`"#, + )); + None + } else { + Some(EnumRepr::Enum) + } + } + avro::EnumRepr::BareUnion => { + if tag.is_some() || content.is_some() { + errors.push(syn::Error::new( + span, + r#"AvroSchema: `#[avro(repr = "bare_union")]` is incompatible with `#[serde(tag = "..")]` and `#[serde(content = "..")]`"#, + )); + None + } else { + Some(EnumRepr::BareUnion) + } + } + avro::EnumRepr::UnionOfRecords => { + if tag.is_some() || content.is_some() || untagged { + errors.push(syn::Error::new( + span, + r#"AvroSchema: `#[avro(repr = "union_of_records")]` is incompatible with `#[serde(tag = "..")]`, `#[serde(content = "..")]`, and `#[serde(untagged)]`"#, + )); + None + } else { + Some(EnumRepr::UnionOfRecords) + } + } + avro::EnumRepr::RecordTagContent => { + if let Some(tag) = tag + && let Some(content) = content + { + Some(EnumRepr::RecordTagContent { tag, content }) + } else { + errors.push(syn::Error::new( + span, + r#"AvroSchema: `#[avro(repr = "record_tag_content")]` requires `#[serde(tag = "..", content = "..")]`"#, + )); + None + } + } + avro::EnumRepr::RecordInternallyTagged => { + if let Some(tag) = tag + && content.is_none() + { + Some(EnumRepr::RecordInternallyTagged { tag }) + } else { + errors.push(syn::Error::new( + span, + r#"AvroSchema: `#[avro(repr = "record_internally_tagged")]` requires `#[serde(tag = "..")]` and is incompatible with `#[serde(content = 
"..")]`"#, + )); + None + } + } + } + } else if let Some(content) = content + && let Some(tag) = tag + { + Some(EnumRepr::RecordTagContent { tag, content }) + } else if untagged { + Some(EnumRepr::BareUnion) + } else if let Some(tag) = tag { + Some(EnumRepr::RecordInternallyTagged { tag }) + } else { + None + }; + + if errors.is_empty() { + Ok(repr) + } else { + Err(errors) + } + } +} + pub struct NamedTypeOptions { pub name: String, pub doc: Option<String>, @@ -153,79 +246,18 @@ impl NamedTypeOptions { )); } - #[expect(clippy::manual_map, reason = "Intent is clearer this way")] - let repr = if let Some(repr) = avro.repr { - match repr { - avro::EnumRepr::Enum => { - if serde.tag.is_some() || serde.content.is_some() || serde.untagged { - errors.push(syn::Error::new( - span, - r#"AvroSchema: `#[avro(repr = "enum")]` is incompatible with `#[serde(tag = "..")]`, `#[serde(content = "..")]`, and `#[serde(untagged)]`"#, - )); - None - } else { - Some(EnumRepr::Enum) - } - } - avro::EnumRepr::BareUnion => { - if serde.untagged { - Some(EnumRepr::BareUnion) - } else { - errors.push(syn::Error::new( - span, - r#"AvroSchema: `#[avro(repr = "bare_union")]` requires `#[serde(untagged)]`"#, - )); - None - } - } - avro::EnumRepr::UnionOfRecords => { - if serde.tag.is_some() || serde.content.is_some() || serde.untagged { - errors.push(syn::Error::new( - span, - r#"AvroSchema: `#[avro(repr = "union_of_records")]` is incompatible with `#[serde(tag = "..")]`, `#[serde(content = "..")]`, and `#[serde(untagged)]`"#, - )); - None - } else { - Some(EnumRepr::UnionOfRecords) - } - } - avro::EnumRepr::RecordTagContent => { - if let Some(tag) = serde.tag - && let Some(content) = serde.content - { - Some(EnumRepr::RecordTagContent { tag, content }) - } else { - errors.push(syn::Error::new( - span, - r#"AvroSchema: `#[avro(repr = "record_tag_content")]` requires `#[serde(tag = "..", content = "..")]`"#, - )); - None - } - } - avro::EnumRepr::RecordInternallyTagged => { - if let Some(tag) = 
serde.tag - && serde.content.is_none() - { - Some(EnumRepr::RecordInternallyTagged { tag }) - } else { - errors.push(syn::Error::new( - span, - r#"AvroSchema: `#[avro(repr = "discriminator_value")]` requires `#[serde(tag = "..")]` and is incompatible with `#[serde(content = "..")]`"#, - )); - None - } - } + let repr = match EnumRepr::from_avro_and_serde( + avro.repr, + serde.tag, + serde.content, + serde.untagged, + span, + ) { + Ok(repr) => repr, + Err(err) => { + errors.extend(err); + None } - } else if let Some(content) = serde.content - && let Some(tag) = serde.tag - { - Some(EnumRepr::RecordTagContent { tag, content }) - } else if serde.untagged { - Some(EnumRepr::BareUnion) - } else if let Some(tag) = serde.tag { - Some(EnumRepr::RecordInternallyTagged { tag }) - } else { - None }; let default = match avro.default { @@ -387,6 +419,8 @@ pub enum FieldDefault { Disabled, /// Use this JSON value. Value(String), + /// Use `module::get_schema_in_ctxt` where the module is defined by Serde's `with` attribute. + Serde(Path), } impl FromMeta for FieldDefault { @@ -484,6 +518,14 @@ impl FieldOptions { } }; + let default = if let With::Serde(path) = &with + && avro.default == FieldDefault::Trait + { + FieldDefault::Serde(path.clone()) + } else { + avro.default + }; + if !errors.is_empty() { return Err(errors); } @@ -492,7 +534,7 @@ impl FieldOptions { Ok(Self { doc, - default: avro.default, + default, alias: serde.alias, rename: serde.rename, skip: serde.skip || (serde.skip_serializing && serde.skip_deserializing), diff --git a/avro_derive/src/fields.rs b/avro_derive/src/fields.rs index a384818..c1daa55 100644 --- a/avro_derive/src/fields.rs +++ b/avro_derive/src/fields.rs @@ -95,6 +95,11 @@ pub fn to_default( ::std::option::Option::Some(::serde_json::from_str(#default_value).expect("Unreachable! Checked at compile time")) })) } + FieldDefault::Serde(path) => { + Ok(TypedTokenStream::<Option<serde_json::Value>>::new(quote! 
{ + #path::field_default() + })) + } } } diff --git a/avro_derive/tests/derive.rs b/avro_derive/tests/derive.rs index 1dfac6c..b26b81c 100644 --- a/avro_derive/tests/derive.rs +++ b/avro_derive/tests/derive.rs @@ -928,7 +928,6 @@ struct Testu8 { // TODO: Needs new deserializer proptest! { #[test] -#[ignore] fn test_bytes_handled(a: Vec<u8>, b: [u8; 2]) { let test = Testu8 { a, @@ -1913,6 +1912,10 @@ fn avro_rs_397_with() { ) -> Schema { Schema::Bytes } + + pub fn field_default() -> Option<serde_json::Value> { + None + } } #[allow(dead_code)] diff --git a/avro_derive/tests/enum.rs b/avro_derive/tests/enum.rs index e4e272c..83bf508 100644 --- a/avro_derive/tests/enum.rs +++ b/avro_derive/tests/enum.rs @@ -601,3 +601,52 @@ fn avro_rs_xxx_enum_repr_union_of_records_struct() { other: true, }); } + +#[test] +fn avro_rs_xxx_enum_repr_bare_union_without_untagged_plain() { + #[derive(AvroSchema, Debug, Serialize, Deserialize, Clone, PartialEq)] + #[avro(repr = "bare_union")] + enum Foo { + A, + } + + let schema = Schema::parse_str(r#"["null"]"#).unwrap(); + + assert_eq!(Foo::get_schema(), schema); + serde_assert(Foo::A); +} + +#[test] +fn avro_rs_xxx_enum_repr_bare_union_without_untagged_tuple() { + #[derive(AvroSchema, Debug, Serialize, Deserialize, Clone, PartialEq)] + #[avro(repr = "bare_union")] + enum Foo { + B(String), + #[serde(rename = "D")] + C( + String, + #[serde(rename = "is_it_true", alias = "is_it_false")] bool, + ), + } + + let schema = Schema::parse_str( + r#"[ + "string", + { + "type": "record", + "name": "D", + "default": "null", + "fields": [ + { "name": "field_0", "type": "string" }, + { "name": "is_it_true", "aliases": ["is_it_false"], "type": "boolean" } + ], + "org.apache.avro.rust.tuple": true + } + ]"#, + ) + .unwrap(); + + assert_eq!(Foo::get_schema(), schema); + serde_assert(Foo::B("Something".to_string())); + serde_assert(Foo::C("Something".to_string(), true)); +} diff --git a/avro_derive/tests/serde.rs b/avro_derive/tests/serde.rs index 
a650dc6..7c4c2f4 100644 --- a/avro_derive/tests/serde.rs +++ b/avro_derive/tests/serde.rs @@ -586,7 +586,7 @@ mod field_attributes { #[avro(with = apache_avro::serde::fixed::get_schema_in_ctxt::<6>)] #[serde(with = "apache_avro::serde::fixed")] fixed_field: [u8; 6], - #[avro(with = apache_avro::serde::fixed_opt::get_schema_in_ctxt::<7>, default = false)] + #[avro(with = apache_avro::serde::fixed_opt::get_schema_in_ctxt::<7>)] #[serde(with = "apache_avro::serde::fixed_opt")] fixed_field_opt: Option<[u8; 7]>, diff --git a/avro_derive/tests/ui/avro_rs_xxx_bare_union_and_untagged.stderr b/avro_derive/tests/ui/avro_rs_xxx_bare_union_and_untagged.stderr index dbe041a..eb30e2c 100644 --- a/avro_derive/tests/ui/avro_rs_xxx_bare_union_and_untagged.stderr +++ b/avro_derive/tests/ui/avro_rs_xxx_bare_union_and_untagged.stderr @@ -1,12 +1,3 @@ -error: AvroSchema: `#[avro(repr = "bare_union")]` requires `#[serde(untagged)]` - --> tests/ui/avro_rs_xxx_bare_union_and_untagged.rs:21:1 - | -21 | / #[avro(repr = "bare_union")] -22 | | enum A { -23 | | A -24 | | } - | |_^ - error: More than one variant maps to Schema::Null, this is not supported for bare unions --> tests/ui/avro_rs_xxx_bare_union_and_untagged.rs:29:6 |
