This is an automated email from the ASF dual-hosted git repository.
kriskras99 pushed a commit to branch feat/enums
in repository https://gitbox.apache.org/repos/asf/avro-rs.git
The following commit(s) were added to refs/heads/feat/enums by this push:
new 1065d1e progress
1065d1e is described below
commit 1065d1ef92972cea326fe982413e2f4806995385
Author: default <[email protected]>
AuthorDate: Tue Mar 3 07:25:52 2026 +0000
progress
---
Cargo.lock | 16 ++++
avro/src/bigdecimal.rs | 6 +-
avro/src/reader/mod.rs | 48 ++++++++++-
avro/src/reader/single_object.rs | 25 ++++++
avro/src/serde/de.rs | 1 +
avro/src/serde/deser_schema/array.rs | 22 +++--
avro/src/serde/deser_schema/enums/plain.rs | 18 ++---
avro/src/serde/deser_schema/map.rs | 13 +--
avro/src/serde/deser_schema/mod.rs | 80 +++++++++++++------
avro/src/serde/deser_schema/record.rs | 14 ++--
avro/src/serde/deser_schema/tuple.rs | 25 +++---
avro/src/serde/deser_schema/union.rs | 30 ++++---
avro/src/serde/ser_schema/array.rs | 8 +-
avro/src/serde/ser_schema/map.rs | 6 +-
avro/src/serde/ser_schema/mod.rs | 32 ++++----
avro/src/serde/ser_schema/union.rs | 2 +-
avro/src/serde/with.rs | 7 +-
avro/tests/avro-rs-226.rs | 11 +--
avro_derive/Cargo.toml | 1 +
avro_derive/src/attributes/mod.rs | 70 ++++++++++++----
avro_derive/src/attributes/serde.rs | 6 +-
avro_derive/src/enums/bare_union.rs | 3 +-
avro_derive/src/lib.rs | 124 ++++++++++++++++-------------
avro_derive/tests/derive.rs | 55 ++++++++++---
avro_derive/tests/serde.rs | 43 +++++-----
25 files changed, 449 insertions(+), 217 deletions(-)
diff --git a/Cargo.lock b/Cargo.lock
index 4517f14..8504ea1 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -101,6 +101,7 @@ version = "0.22.0"
dependencies = [
"apache-avro",
"darling",
+ "hexdump",
"pretty_assertions",
"proc-macro2",
"proptest",
@@ -123,6 +124,12 @@ dependencies = [
"log",
]
+[[package]]
+name = "arrayvec"
+version = "0.5.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "23b62fc65de8e4e7f52534fb52b0f3ed04746ae267519eef2a83941e8085068b"
+
[[package]]
name = "async-trait"
version = "0.1.89"
@@ -684,6 +691,15 @@ version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e712f64ec3850b98572bffac52e2c6f282b29fe6c5fa6d42334b30be438d95c1"
+[[package]]
+name = "hexdump"
+version = "0.1.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "cf31ab66ed8145a1c7427bd8e9b42a6131bd74ccf444f69b9e620c2e73ded832"
+dependencies = [
+ "arrayvec",
+]
+
[[package]]
name = "id-arena"
version = "2.3.0"
diff --git a/avro/src/bigdecimal.rs b/avro/src/bigdecimal.rs
index c2f00b2..7df71dc 100644
--- a/avro/src/bigdecimal.rs
+++ b/avro/src/bigdecimal.rs
@@ -70,7 +70,7 @@ pub(crate) fn deserialize_big_decimal(mut bytes: &[u8]) ->
AvroResult<BigDecimal
#[cfg(test)]
mod tests {
use super::*;
- use crate::{Codec, Reader, Schema, Writer, error::Error, from_value,
types::Record};
+ use crate::{Codec, Reader, Schema, Writer, error::Error, types::Record};
use apache_avro_test_helper::TestResult;
use bigdecimal::{One, Zero};
use pretty_assertions::assert_eq;
@@ -222,9 +222,9 @@ mod tests {
// read record
let mut reader = Reader::new(&wrote_data[..])?;
- let value = reader.next().unwrap()?;
+ let value = reader.next_deser()?.unwrap();
- assert_eq!(test, from_value::<Test>(&value)?);
+ assert_eq!(test, value);
Ok(())
}
diff --git a/avro/src/reader/mod.rs b/avro/src/reader/mod.rs
index 904600d..4a7d6c0 100644
--- a/avro/src/reader/mod.rs
+++ b/avro/src/reader/mod.rs
@@ -24,12 +24,13 @@ use crate::{
AvroResult,
decode::{decode, decode_internal},
schema::{ResolvedSchema, Schema},
+ serde::deser_schema::{Config, SchemaAwareDeserializer},
types::Value,
};
use block::Block;
use bon::bon;
use serde::de::DeserializeOwned;
-use std::{collections::HashMap, io::Read};
+use std::{collections::HashMap, io::Read, marker::PhantomData};
/// Main interface for reading Avro formatted values.
///
@@ -53,6 +54,29 @@ pub struct Reader<'a, R> {
should_resolve_schema: bool,
}
+pub struct ReaderSerde<'a, R, T: DeserializeOwned> {
+ inner: Reader<'a, R>,
+ phantom: PhantomData<T>,
+}
+
+impl<R: Read, T: DeserializeOwned> Iterator for ReaderSerde<'_, R, T> {
+ type Item = AvroResult<T>;
+
+ fn next(&mut self) -> Option<Self::Item> {
+ // to prevent keep on reading after the first error occurs
+ if self.inner.errored {
+ return None;
+ };
+ match self.inner.next_deser::<T>() {
+ Ok(opt) => opt.map(Ok),
+ Err(e) => {
+ self.inner.errored = true;
+ Some(Err(e))
+ }
+ }
+ }
+}
+
#[bon]
impl<'a, R: Read> Reader<'a, R> {
/// Creates a `Reader` given something implementing the `io::Read` trait
to read from.
@@ -108,6 +132,13 @@ impl<'a, R: Read> Reader<'a, R> {
&self.block.user_metadata
}
+ pub fn into_serde_iter<T: DeserializeOwned>(self) -> ReaderSerde<'a, R, T>
{
+ ReaderSerde {
+ inner: self,
+ phantom: PhantomData,
+ }
+ }
+
#[inline]
fn read_next(&mut self) -> AvroResult<Option<Value>> {
let read_schema = if self.should_resolve_schema {
@@ -162,6 +193,21 @@ pub fn from_avro_datum<R: Read>(
}
}
+pub fn from_avro_datum_deser<R: Read, T: DeserializeOwned>(
+ writer_schema: &Schema,
+ reader: &mut R,
+) -> AvroResult<T> {
+ let names: HashMap<_, &Schema> = HashMap::new();
+ T::deserialize(SchemaAwareDeserializer::new(
+ reader,
+ writer_schema,
+ Config {
+ names: &names,
+ human_readable: false,
+ },
+ )?)
+}
+
/// Decode a `Value` from raw Avro data.
///
/// If the writer schema is incomplete, i.e. contains `Schema::Ref`s then it
will use the provided
diff --git a/avro/src/reader/single_object.rs b/avro/src/reader/single_object.rs
index c6151a3..6e0b744 100644
--- a/avro/src/reader/single_object.rs
+++ b/avro/src/reader/single_object.rs
@@ -19,6 +19,7 @@ use crate::decode::decode_internal;
use crate::error::Details;
use crate::headers::{HeaderBuilder, RabinFingerprintHeader};
use crate::schema::ResolvedOwnedSchema;
+use crate::serde::deser_schema::{Config, SchemaAwareDeserializer};
use crate::types::Value;
use crate::{AvroResult, AvroSchema, Schema, from_value};
use serde::de::DeserializeOwned;
@@ -68,6 +69,30 @@ impl GenericSingleObjectReader {
Err(io_error) => Err(Details::ReadHeader(io_error).into()),
}
}
+
+ pub fn read_deser<R: Read, T: DeserializeOwned>(&self, reader: &mut R) ->
AvroResult<T> {
+ let mut header = vec![0; self.expected_header.len()];
+ match reader.read_exact(&mut header) {
+ Ok(_) => {
+ if self.expected_header == header {
+ T::deserialize(SchemaAwareDeserializer::new(
+ reader,
+ self.write_schema.get_root_schema(),
+ Config {
+ names: self.write_schema.get_names(),
+ human_readable: false,
+ },
+ )?)
+ } else {
+ Err(
+
Details::SingleObjectHeaderMismatch(self.expected_header.clone(), header)
+ .into(),
+ )
+ }
+ }
+ Err(io_error) => Err(Details::ReadHeader(io_error).into()),
+ }
+ }
}
pub struct SpecificSingleObjectReader<T>
diff --git a/avro/src/serde/de.rs b/avro/src/serde/de.rs
index 18c42a2..bfad160 100644
--- a/avro/src/serde/de.rs
+++ b/avro/src/serde/de.rs
@@ -938,6 +938,7 @@ impl<'de> de::Deserializer<'de> for StringDeserializer {
///
/// This conversion can fail if the structure of the `Value` does not match the
/// structure expected by `D`.
+#[deprecated(since = "0.22.0", note = "Use the `deser` functions instead")]
pub fn from_value<'de, D: Deserialize<'de>>(value: &'de Value) -> Result<D,
Error> {
let de = Deserializer::new(value);
D::deserialize(de)
diff --git a/avro/src/serde/deser_schema/array.rs
b/avro/src/serde/deser_schema/array.rs
index 61415c2..08cb055 100644
--- a/avro/src/serde/deser_schema/array.rs
+++ b/avro/src/serde/deser_schema/array.rs
@@ -15,6 +15,7 @@
// specific language governing permissions and limitations
// under the License.
+use std::borrow::Borrow;
use std::io::Read;
use serde::de::SeqAccess;
@@ -25,15 +26,15 @@ use crate::{
Error, Schema, schema::ArraySchema,
serde::deser_schema::SchemaAwareDeserializer, util::zag_i32,
};
-pub struct ArrayDeserializer<'s, 'r, R: Read> {
+pub struct ArrayDeserializer<'s, 'r, R: Read, S: Borrow<Schema>> {
reader: &'r mut R,
schema: &'s ArraySchema,
- config: Config<'s>,
+ config: Config<'s, S>,
state: State,
}
-impl<'s, 'r, R: Read> ArrayDeserializer<'s, 'r, R> {
- pub fn new(reader: &'r mut R, schema: &'s ArraySchema, config: Config<'s>)
-> Self {
+impl<'s, 'r, R: Read, S: Borrow<Schema>> ArrayDeserializer<'s, 'r, R, S> {
+ pub fn new(reader: &'r mut R, schema: &'s ArraySchema, config: Config<'s,
S>) -> Self {
Self {
reader,
schema,
@@ -49,7 +50,7 @@ enum State {
Finished,
}
-impl<'de, 's, 'r, R: Read> SeqAccess<'de> for ArrayDeserializer<'s, 'r, R> {
+impl<'de, 's, 'r, R: Read, S: Borrow<Schema>> SeqAccess<'de> for
ArrayDeserializer<'s, 'r, R, S> {
type Error = Error;
fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>,
Self::Error>
@@ -65,11 +66,11 @@ impl<'de, 's, 'r, R: Read> SeqAccess<'de> for
ArrayDeserializer<'s, 'r, R> {
if remaining == 0 {
self.state = State::Finished
} else {
- self.state = State::ReadingValue(remaining.abs() as u32)
+ self.state = State::ReadingValue(remaining.unsigned_abs())
}
self.next_element_seed(seed)
}
- State::ReadingValue(remaining) => {
+ State::ReadingValue(mut remaining) => {
let v = match self.schema.items.as_ref() {
Schema::Union(union) => {
seed.deserialize(UnionDeserializer::new(self.reader,
union, self.config)?)?
@@ -81,7 +82,12 @@ impl<'de, 's, 'r, R: Read> SeqAccess<'de> for
ArrayDeserializer<'s, 'r, R> {
)?)?,
};
- self.state = State::ReadingValue(remaining);
+ remaining -= 1;
+ if remaining == 0 {
+ self.state = State::EndOfBlock;
+ } else {
+ self.state = State::ReadingValue(remaining);
+ }
Ok(Some(v))
}
diff --git a/avro/src/serde/deser_schema/enums/plain.rs
b/avro/src/serde/deser_schema/enums/plain.rs
index 5b4ca6e..65b5b81 100644
--- a/avro/src/serde/deser_schema/enums/plain.rs
+++ b/avro/src/serde/deser_schema/enums/plain.rs
@@ -304,15 +304,15 @@ impl<'de, 's, 'r, R: Read> Deserializer<'de> for
EnumIdentifierDeserializer<'s,
V: Visitor<'de>,
{
let index = zag_i32(self.reader)?;
- let symbol =
- self.schema
- .symbols
- .get(index as usize)
- .ok_or_else(|| Details::EnumSymbolIndex {
- index: index as usize,
- num_variants: self.schema.symbols.len(),
- })?;
- visitor.visit_str(&symbol)
+ let symbol = self
+ .schema
+ .symbols
+ .get(index as usize)
+ .ok_or(Details::EnumSymbolIndex {
+ index: index as usize,
+ num_variants: self.schema.symbols.len(),
+ })?;
+ visitor.visit_str(symbol)
}
fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value,
Self::Error>
diff --git a/avro/src/serde/deser_schema/map.rs
b/avro/src/serde/deser_schema/map.rs
index 6f06cf3..53b7ed5 100644
--- a/avro/src/serde/deser_schema/map.rs
+++ b/avro/src/serde/deser_schema/map.rs
@@ -15,6 +15,7 @@
// specific language governing permissions and limitations
// under the License.
+use std::borrow::Borrow;
use std::io::Read;
use serde::de::MapAccess;
@@ -25,15 +26,15 @@ use crate::{
Error, Schema, schema::MapSchema,
serde::deser_schema::SchemaAwareDeserializer, util::zag_i32,
};
-pub struct MapDeserializer<'s, 'r, R: Read> {
+pub struct MapDeserializer<'s, 'r, R: Read, S: Borrow<Schema>> {
reader: &'r mut R,
schema: &'s MapSchema,
- config: Config<'s>,
+ config: Config<'s, S>,
state: State,
}
-impl<'s, 'r, R: Read> MapDeserializer<'s, 'r, R> {
- pub fn new(reader: &'r mut R, schema: &'s MapSchema, config: Config<'s>)
-> Self {
+impl<'s, 'r, R: Read, S: Borrow<Schema>> MapDeserializer<'s, 'r, R, S> {
+ pub fn new(reader: &'r mut R, schema: &'s MapSchema, config: Config<'s,
S>) -> Self {
Self {
reader,
schema,
@@ -50,7 +51,7 @@ enum State {
Finished,
}
-impl<'de, 's, 'r, R: Read> MapAccess<'de> for MapDeserializer<'s, 'r, R> {
+impl<'de, 's, 'r, R: Read, S: Borrow<Schema>> MapAccess<'de> for
MapDeserializer<'s, 'r, R, S> {
type Error = Error;
fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>,
Self::Error>
@@ -66,7 +67,7 @@ impl<'de, 's, 'r, R: Read> MapAccess<'de> for
MapDeserializer<'s, 'r, R> {
if remaining == 0 {
self.state = State::Finished
} else {
- self.state = State::ReadingKey(remaining.abs() as u32)
+ self.state = State::ReadingKey(remaining.unsigned_abs())
}
self.next_key_seed(seed)
}
diff --git a/avro/src/serde/deser_schema/mod.rs
b/avro/src/serde/deser_schema/mod.rs
index 26efad1..0e07518 100644
--- a/avro/src/serde/deser_schema/mod.rs
+++ b/avro/src/serde/deser_schema/mod.rs
@@ -22,11 +22,14 @@ mod record;
mod tuple;
mod union;
+use std::borrow::Borrow;
+use std::collections::HashMap;
use std::fmt::{Debug, Formatter};
use std::io::Read;
use serde::Deserializer;
+use crate::schema::Name;
use crate::serde::deser_schema::enums::PlainEnumAccess;
use crate::serde::deser_schema::tuple::{
ManyTupleDeserializer, OneTupleDeserializer, UnitTupleDeserializer,
@@ -36,30 +39,40 @@ use crate::{
Error, Schema,
decode::decode_len,
error::Details,
- schema::{DecimalSchema, InnerDecimalSchema, Names, SchemaKind, UuidSchema},
+ schema::{DecimalSchema, InnerDecimalSchema, SchemaKind, UuidSchema},
serde::deser_schema::{
array::ArrayDeserializer, map::MapDeserializer,
record::RecordDeserializer,
},
util::{zag_i32, zag_i64},
};
-#[derive(Debug, Clone, Copy)]
-pub struct Config<'s> {
+#[derive(Debug)]
+pub struct Config<'s, S: Borrow<Schema>> {
/// All names that can be referenced in the schema being used for
serialisation.
- pub names: &'s Names,
+ pub names: &'s HashMap<Name, S>,
/// Should `Serialize` implementations pick a human-readable format.
///
/// It is recommended to set this to `false` as it results in compacter
output.
pub human_readable: bool,
}
-pub struct SchemaAwareDeserializer<'s, 'r, R: Read> {
+impl<'s, S: Borrow<Schema>> Copy for Config<'s, S> {}
+impl<'s, S: Borrow<Schema>> Clone for Config<'s, S> {
+ fn clone(&self) -> Self {
+ Self {
+ names: self.names,
+ human_readable: self.human_readable,
+ }
+ }
+}
+
+pub struct SchemaAwareDeserializer<'s, 'r, R: Read, S: Borrow<Schema>> {
reader: &'r mut R,
schema: &'s Schema,
- config: Config<'s>,
+ config: Config<'s, S>,
}
-impl<'s, 'r, R: Read> Debug for SchemaAwareDeserializer<'s, 'r, R> {
+impl<'s, 'r, R: Read, S: Borrow<Schema>> Debug for SchemaAwareDeserializer<'s,
'r, R, S> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
f.debug_struct("SchemaAwareDeserializer")
.field("schema", &self.schema)
@@ -67,13 +80,18 @@ impl<'s, 'r, R: Read> Debug for SchemaAwareDeserializer<'s,
'r, R> {
}
}
-impl<'s, 'r, R: Read> SchemaAwareDeserializer<'s, 'r, R> {
- pub fn new(reader: &'r mut R, schema: &'s Schema, config: Config<'s>) ->
Result<Self, Error> {
+impl<'s, 'r, R: Read, S: Borrow<Schema>> SchemaAwareDeserializer<'s, 'r, R, S>
{
+ pub fn new(
+ reader: &'r mut R,
+ schema: &'s Schema,
+ config: Config<'s, S>,
+ ) -> Result<Self, Error> {
if let Schema::Ref { name } = schema {
let schema = config
.names
.get(name)
- .ok_or_else(|| Details::SchemaResolutionError(name.clone()))?;
+ .ok_or_else(|| Details::SchemaResolutionError(name.clone()))?
+ .borrow();
Self::new(reader, schema, config)
} else {
Ok(Self {
@@ -101,6 +119,7 @@ impl<'s, 'r, R: Read> SchemaAwareDeserializer<'s, 'r, R> {
.names
.get(name)
.ok_or_else(|| Details::SchemaResolutionError(name.clone()))?
+ .borrow()
} else {
schema
};
@@ -152,7 +171,9 @@ impl<'s, 'r, R: Read> SchemaAwareDeserializer<'s, 'r, R> {
const DESERIALIZE_ANY: &str = "This value is compared by pointer value";
const DESERIALIZE_ANY_FIELDS: &[&str] = &[];
-impl<'de, 's, 'r, R: Read> Deserializer<'de> for SchemaAwareDeserializer<'s,
'r, R> {
+impl<'de, 's, 'r, R: Read, S: Borrow<Schema>> Deserializer<'de>
+ for SchemaAwareDeserializer<'s, 'r, R, S>
+{
type Error = Error;
fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
@@ -185,8 +206,16 @@ impl<'de, 's, 'r, R: Read> Deserializer<'de> for
SchemaAwareDeserializer<'s, 'r,
Schema::Union(union) => {
UnionDeserializer::new(self.reader, union,
self.config)?.deserialize_any(visitor)
}
- Schema::Record(_) => {
- self.deserialize_struct(DESERIALIZE_ANY,
DESERIALIZE_ANY_FIELDS, visitor)
+ Schema::Record(schema) => {
+ if schema.attributes.get("org.apache.avro.rust.tuple")
+ == Some(&serde_json::Value::Bool(true))
+ {
+ // This is needed because we can't tell the difference
between a tuple and struct.
+ // And a tuple needs to be deserialized as a sequence
+ self.deserialize_tuple(schema.fields.len(), visitor)
+ } else {
+ self.deserialize_struct(DESERIALIZE_ANY,
DESERIALIZE_ANY_FIELDS, visitor)
+ }
}
Schema::Enum(_) => {
self.deserialize_enum(DESERIALIZE_ANY, DESERIALIZE_ANY_FIELDS,
visitor)
@@ -246,9 +275,7 @@ impl<'de, 's, 'r, R: Read> Deserializer<'de> for
SchemaAwareDeserializer<'s, 'r,
where
V: serde::de::Visitor<'de>,
{
- let int = self.read_int("i32")?;
- let value = i32::try_from(int)
- .map_err(|_| self.error("i32", format!("Could not convert int
({int}) to an i32")))?;
+ let value = self.read_int("i32")?;
visitor.visit_i32(value)
}
@@ -256,9 +283,7 @@ impl<'de, 's, 'r, R: Read> Deserializer<'de> for
SchemaAwareDeserializer<'s, 'r,
where
V: serde::de::Visitor<'de>,
{
- let int = self.read_long("i64")?;
- let value = i64::try_from(int)
- .map_err(|_| self.error("i64", format!("Could not convert int
({int}) to an i64")))?;
+ let value = self.read_long("i64")?;
visitor.visit_i64(value)
}
@@ -413,7 +438,7 @@ impl<'de, 's, 'r, R: Read> Deserializer<'de> for
SchemaAwareDeserializer<'s, 'r,
&& let Some(null_index) =
union.index().get(&SchemaKind::Null).copied()
{
let index = zag_i32(self.reader)?;
- if index < 0 || index > 1 {
+ if !(0..=1).contains(&index) {
return Err(self.error("option", format!("Invalid union index
{index}")));
}
let index = index as usize;
@@ -447,7 +472,7 @@ impl<'de, 's, 'r, R: Read> Deserializer<'de> for
SchemaAwareDeserializer<'s, 'r,
V: serde::de::Visitor<'de>,
{
if let Schema::Record(record) = self.schema
- && record.fields.len() == 0
+ && record.fields.is_empty()
&& record.name.name() == name
{
visitor.visit_unit()
@@ -535,10 +560,15 @@ impl<'de, 's, 'r, R: Read> Deserializer<'de> for
SchemaAwareDeserializer<'s, 'r,
where
V: serde::de::Visitor<'de>,
{
- if let Schema::Map(map) = self.schema {
- visitor.visit_map(MapDeserializer::new(self.reader, map,
self.config))
- } else {
- Err(self.error("map", "Expected Schema::Map"))
+ match self.schema {
+ Schema::Map(map) => {
+ visitor.visit_map(MapDeserializer::new(self.reader, map,
self.config))
+ }
+ Schema::Record(record) => {
+ // Needed for flattened structs which are (de)serialized as
maps
+ visitor.visit_map(RecordDeserializer::new(self.reader, record,
self.config))
+ }
+ _ => Err(self.error("map", "Expected Schema::Map")),
}
}
diff --git a/avro/src/serde/deser_schema/record.rs
b/avro/src/serde/deser_schema/record.rs
index ad70999..055dc37 100644
--- a/avro/src/serde/deser_schema/record.rs
+++ b/avro/src/serde/deser_schema/record.rs
@@ -15,6 +15,7 @@
// specific language governing permissions and limitations
// under the License.
+use std::borrow::Borrow;
use std::fmt::{Debug, Formatter};
use std::io::Read;
@@ -27,25 +28,24 @@ use crate::{
serde::deser_schema::SchemaAwareDeserializer,
};
-pub struct RecordDeserializer<'s, 'r, R: Read> {
+pub struct RecordDeserializer<'s, 'r, R: Read, S: Borrow<Schema>> {
reader: &'r mut R,
schema: &'s RecordSchema,
- config: Config<'s>,
+ config: Config<'s, S>,
current_field: State,
}
-impl<'s, 'r, R: Read> Debug for RecordDeserializer<'s, 'r, R> {
+impl<'s, 'r, R: Read, S: Borrow<Schema>> Debug for RecordDeserializer<'s, 'r,
R, S> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
f.debug_struct("RecordDeserializer")
.field("schema", &self.schema)
- .field("config", &self.config)
.field("current_field", &self.current_field)
.finish()
}
}
-impl<'s, 'r, R: Read> RecordDeserializer<'s, 'r, R> {
- pub fn new(reader: &'r mut R, schema: &'s RecordSchema, config:
Config<'s>) -> Self {
+impl<'s, 'r, R: Read, S: Borrow<Schema>> RecordDeserializer<'s, 'r, R, S> {
+ pub fn new(reader: &'r mut R, schema: &'s RecordSchema, config: Config<'s,
S>) -> Self {
Self {
reader,
schema,
@@ -61,7 +61,7 @@ enum State {
Value(usize),
}
-impl<'de, 's, 'r, R: Read> MapAccess<'de> for RecordDeserializer<'s, 'r, R> {
+impl<'de, 's, 'r, R: Read, S: Borrow<Schema>> MapAccess<'de> for
RecordDeserializer<'s, 'r, R, S> {
type Error = Error;
fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>,
Self::Error>
diff --git a/avro/src/serde/deser_schema/tuple.rs
b/avro/src/serde/deser_schema/tuple.rs
index c95e38a..6b9284d 100644
--- a/avro/src/serde/deser_schema/tuple.rs
+++ b/avro/src/serde/deser_schema/tuple.rs
@@ -3,17 +3,18 @@ use crate::serde::deser_schema::union::UnionDeserializer;
use crate::serde::deser_schema::{Config, SchemaAwareDeserializer};
use crate::{Error, Schema};
use serde::de::{DeserializeSeed, SeqAccess};
+use std::borrow::Borrow;
use std::io::Read;
-pub struct ManyTupleDeserializer<'s, 'r, R: Read> {
+pub struct ManyTupleDeserializer<'s, 'r, R: Read, S: Borrow<Schema>> {
reader: &'r mut R,
schema: &'s RecordSchema,
- config: Config<'s>,
+ config: Config<'s, S>,
current_field: usize,
}
-impl<'s, 'r, R: Read> ManyTupleDeserializer<'s, 'r, R> {
- pub fn new(reader: &'r mut R, schema: &'s RecordSchema, config:
Config<'s>) -> Self {
+impl<'s, 'r, R: Read, S: Borrow<Schema>> ManyTupleDeserializer<'s, 'r, R, S> {
+ pub fn new(reader: &'r mut R, schema: &'s RecordSchema, config: Config<'s,
S>) -> Self {
Self {
reader,
schema,
@@ -23,7 +24,9 @@ impl<'s, 'r, R: Read> ManyTupleDeserializer<'s, 'r, R> {
}
}
-impl<'de, 's, 'r, R: Read> SeqAccess<'de> for ManyTupleDeserializer<'s, 'r, R>
{
+impl<'de, 's, 'r, R: Read, S: Borrow<Schema>> SeqAccess<'de>
+ for ManyTupleDeserializer<'s, 'r, R, S>
+{
type Error = Error;
fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>,
Self::Error>
@@ -62,15 +65,15 @@ impl<'de> SeqAccess<'de> for UnitTupleDeserializer {
}
}
-pub struct OneTupleDeserializer<'s, 'r, R: Read> {
+pub struct OneTupleDeserializer<'s, 'r, R: Read, S: Borrow<Schema>> {
reader: &'r mut R,
schema: &'s Schema,
- config: Config<'s>,
+ config: Config<'s, S>,
field_read: bool,
}
-impl<'s, 'r, R: Read> OneTupleDeserializer<'s, 'r, R> {
- pub fn new(reader: &'r mut R, schema: &'s Schema, config: Config<'s>) ->
Self {
+impl<'s, 'r, R: Read, S: Borrow<Schema>> OneTupleDeserializer<'s, 'r, R, S> {
+ pub fn new(reader: &'r mut R, schema: &'s Schema, config: Config<'s, S>)
-> Self {
Self {
reader,
schema,
@@ -80,7 +83,9 @@ impl<'s, 'r, R: Read> OneTupleDeserializer<'s, 'r, R> {
}
}
-impl<'de, 's, 'r, R: Read> SeqAccess<'de> for OneTupleDeserializer<'s, 'r, R> {
+impl<'de, 's, 'r, R: Read, S: Borrow<Schema>> SeqAccess<'de>
+ for OneTupleDeserializer<'s, 'r, R, S>
+{
type Error = Error;
fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>,
Self::Error>
diff --git a/avro/src/serde/deser_schema/union.rs
b/avro/src/serde/deser_schema/union.rs
index 480a474..658eb2e 100644
--- a/avro/src/serde/deser_schema/union.rs
+++ b/avro/src/serde/deser_schema/union.rs
@@ -15,6 +15,7 @@
// specific language governing permissions and limitations
// under the License.
+use std::borrow::Borrow;
use std::fmt::{Debug, Formatter};
use std::io::Read;
@@ -26,14 +27,14 @@ use crate::schema::{SchemaKind, UuidSchema};
use crate::util::zag_i32;
use crate::{Error, Schema, schema::UnionSchema};
-pub struct UnionDeserializer<'s, 'r, R: Read> {
+pub struct UnionDeserializer<'s, 'r, R: Read, S: Borrow<Schema>> {
reader: &'r mut R,
schema: &'s UnionSchema,
- config: Config<'s>,
+ config: Config<'s, S>,
variant: &'s Schema,
}
-impl<'s, 'r, R: Read> Debug for UnionDeserializer<'s, 'r, R> {
+impl<'s, 'r, R: Read, S: Borrow<Schema>> Debug for UnionDeserializer<'s, 'r,
R, S> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
f.debug_struct("UnionDeserializer")
.field("schema", self.schema)
@@ -42,11 +43,11 @@ impl<'s, 'r, R: Read> Debug for UnionDeserializer<'s, 'r,
R> {
}
}
-impl<'s, 'r, R: Read> UnionDeserializer<'s, 'r, R> {
+impl<'s, 'r, R: Read, S: Borrow<Schema>> UnionDeserializer<'s, 'r, R, S> {
pub fn new(
reader: &'r mut R,
schema: &'s UnionSchema,
- config: Config<'s>,
+ config: Config<'s, S>,
) -> Result<Self, Error> {
let index = zag_i32(reader)?;
let variant =
@@ -74,7 +75,9 @@ impl<'s, 'r, R: Read> UnionDeserializer<'s, 'r, R> {
}
}
-impl<'de, 's, 'r, R: Read> serde::Deserializer<'de> for UnionDeserializer<'s,
'r, R> {
+impl<'de, 's, 'r, R: Read, S: Borrow<Schema>> serde::Deserializer<'de>
+ for UnionDeserializer<'s, 'r, R, S>
+{
type Error = Error;
fn deserialize_any<V>(mut self, visitor: V) -> Result<V::Value,
Self::Error>
@@ -105,8 +108,16 @@ impl<'de, 's, 'r, R: Read> serde::Deserializer<'de> for
UnionDeserializer<'s, 'r
Schema::String | Schema::Uuid(UuidSchema::String) =>
self.deserialize_string(visitor),
Schema::Array(_) => self.deserialize_seq(visitor),
Schema::Map(_) => self.deserialize_map(visitor),
- Schema::Record(_) => {
- self.deserialize_struct(DESERIALIZE_ANY,
DESERIALIZE_ANY_FIELDS, visitor)
+ Schema::Record(schema) => {
+ if schema.attributes.get("org.apache.avro.rust.tuple")
+ == Some(&serde_json::Value::Bool(true))
+ {
+ // This is needed because we can't tell the difference
between a tuple and struct.
+ // And a tuple needs to be deserialized as a sequence
+ self.deserialize_tuple(schema.fields.len(), visitor)
+ } else {
+ self.deserialize_struct(DESERIALIZE_ANY,
DESERIALIZE_ANY_FIELDS, visitor)
+ }
}
Schema::Enum(_) => {
self.deserialize_enum(DESERIALIZE_ANY, DESERIALIZE_ANY_FIELDS,
visitor)
@@ -116,7 +127,8 @@ impl<'de, 's, 'r, R: Read> serde::Deserializer<'de> for
UnionDeserializer<'s, 'r
.config
.names
.get(name)
- .ok_or_else(||
Details::SchemaResolutionError(name.clone()))?;
+ .ok_or_else(||
Details::SchemaResolutionError(name.clone()))?
+ .borrow();
self.variant = schema;
self.deserialize_any(visitor)
}
diff --git a/avro/src/serde/ser_schema/array.rs
b/avro/src/serde/ser_schema/array.rs
index 1700409..a7961af 100644
--- a/avro/src/serde/ser_schema/array.rs
+++ b/avro/src/serde/ser_schema/array.rs
@@ -76,13 +76,17 @@ struct DirectArraySerializer<'s, 'w, W: Write> {
impl<'s, 'w, W: Write> DirectArraySerializer<'s, 'w, W> {
pub fn new(
- mut writer: &'w mut W,
+ writer: &'w mut W,
array: &'s ArraySchema,
config: Config<'s>,
len: usize,
mut bytes_written: usize,
) -> Result<Self, Error> {
- bytes_written += encode_int(len as i32, &mut writer)?;
+ if len != 0 {
+ // .end() always writes the zero block, so we only want to write
+ // the size for arrays that have at least one element
+ bytes_written += encode_int(len as i32, &mut *writer)?;
+ }
Ok(Self {
writer,
array,
diff --git a/avro/src/serde/ser_schema/map.rs b/avro/src/serde/ser_schema/map.rs
index 185b8a6..6870d00 100644
--- a/avro/src/serde/ser_schema/map.rs
+++ b/avro/src/serde/ser_schema/map.rs
@@ -92,7 +92,11 @@ impl<'s, 'w, W: Write> DirectMapSerializer<'s, 'w, W> {
len: usize,
mut bytes_written: usize,
) -> Result<Self, Error> {
- bytes_written += encode_int(len as i32, &mut *writer)?;
+ if len != 0 {
+ // .end() always writes the zero block, so we only want to write
+ // the size for maps that have at least one entry
+ bytes_written += encode_int(len as i32, &mut *writer)?;
+ }
Ok(Self {
writer,
map,
diff --git a/avro/src/serde/ser_schema/mod.rs b/avro/src/serde/ser_schema/mod.rs
index d355cbf..9244910 100644
--- a/avro/src/serde/ser_schema/mod.rs
+++ b/avro/src/serde/ser_schema/mod.rs
@@ -326,7 +326,7 @@ impl<'s, 'w, W: Write> Serializer for
SchemaAwareSerializer<'s, 'w, W> {
fn serialize_f64(mut self, v: f64) -> Result<Self::Ok, Self::Error> {
println!("serialize_f64({v}): {self:?}");
match self.schema {
- Schema::Float => {
+ Schema::Double => {
let bytes = v.to_le_bytes();
self.write_bytes(&bytes)
}
@@ -701,17 +701,12 @@ impl<'s, 'w, W: Write> Serializer for
SchemaAwareSerializer<'s, 'w, W> {
println!("serialize_struct(name: {name}, len: {len}): {self:?}");
match self.schema {
// For unit variants with tag,content `len` will be 1 but we
expect 2.
- Schema::Record(record)
- if record.name.name() == name && record.fields.len() == len
- || record.fields.len() == 2 =>
- {
- Ok(RecordSerializer::new(
- self.writer,
- record,
- self.config,
- None,
- ))
- }
+ Schema::Record(record) if record.name.name() == name =>
Ok(RecordSerializer::new(
+ self.writer,
+ record,
+ self.config,
+ None,
+ )),
Schema::Union(union) => UnionAwareSerializer::new(self.writer,
union, self.config)
.serialize_struct(name, len),
_ => Err(self.error(
@@ -784,7 +779,7 @@ mod tests {
use crate::schema::{FixedSchema, Name};
use crate::{
Days, Duration, Error, Millis, Months, Reader, Schema, Writer,
decimal::Decimal,
- error::Details, from_value, schema::ResolvedSchema,
+ error::Details, schema::ResolvedSchema,
};
use apache_avro_test_helper::TestResult;
use bigdecimal::BigDecimal;
@@ -1247,7 +1242,10 @@ mod tests {
schema,
}) => {
assert_eq!(value_type, "None");
- assert_eq!(value, "Expected Schema::Union([null, _])");
+ assert_eq!(
+ value,
+ "Expected Schema::Union(variants.len() == 2 &&
variants.contains(Schema::Null))"
+ );
assert_eq!(schema, schema);
}
unexpected => panic!("Expected an error. Got: {unexpected:?}"),
@@ -1393,7 +1391,7 @@ mod tests {
schema,
}) => {
assert_eq!(value_type, "str");
- assert_eq!(value, "Expected Schema::String |
Schema::Uuid(String)");
+ assert_eq!(value, "Expected Schema::String in variants");
assert_eq!(schema, schema);
}
unexpected => panic!("Expected an error. Got: {unexpected:?}"),
@@ -1448,7 +1446,7 @@ mod tests {
schema,
}) => {
assert_eq!(value_type, "f64");
- assert_eq!(value, "Expected Schema::Double");
+ assert_eq!(value, "Expected Schema::Double in variants");
assert_eq!(schema, schema);
}
unexpected => panic!("Expected an error. Got: {unexpected:?}"),
@@ -2249,7 +2247,7 @@ mod tests {
let mut reader = Reader::builder(&encoded[..])
.reader_schema(&schema)
.build()?;
- let decoded = from_value::<Foo>(&reader.next().unwrap()?)?;
+ let decoded: Foo = reader.next_deser()?.unwrap();
assert_eq!(
decoded,
Foo {
diff --git a/avro/src/serde/ser_schema/union.rs
b/avro/src/serde/ser_schema/union.rs
index 09562c6..0a75273 100644
--- a/avro/src/serde/ser_schema/union.rs
+++ b/avro/src/serde/ser_schema/union.rs
@@ -252,7 +252,7 @@ impl<'s, 'w, W: Write> Serializer for
UnionAwareSerializer<'s, 'w, W> {
};
let mut bytes_written = encode_int(index as i32, &mut *self.writer)?;
let bytes = v.as_bytes();
- bytes_written += self.write_bytes_with_len(&bytes)?;
+ bytes_written += self.write_bytes_with_len(bytes)?;
Ok(bytes_written)
}
diff --git a/avro/src/serde/with.rs b/avro/src/serde/with.rs
index fc3b5bc..99a432e 100644
--- a/avro/src/serde/with.rs
+++ b/avro/src/serde/with.rs
@@ -15,6 +15,8 @@
// specific language governing permissions and limitations
// under the License.
+#![expect(clippy::ref_option, reason = "Required by the Serde API")]
+
use std::cell::Cell;
thread_local! {
@@ -186,7 +188,6 @@ pub mod bytes_opt {
None
}
- #[expect(clippy::ref_option, reason = "Required by the Serde API")]
pub fn serialize<S, B>(bytes: &Option<B>, serializer: S) -> Result<S::Ok,
S::Error>
where
S: Serializer,
@@ -344,7 +345,6 @@ pub mod fixed_opt {
None
}
- #[expect(clippy::ref_option, reason = "Required by the Serde API")]
pub fn serialize<S, B>(bytes: &Option<B>, serializer: S) -> Result<S::Ok,
S::Error>
where
S: Serializer,
@@ -486,7 +486,6 @@ pub mod slice_opt {
None
}
- #[expect(clippy::ref_option, reason = "Required by the Serde API")]
pub fn serialize<S, B>(bytes: &Option<B>, serializer: S) -> Result<S::Ok,
S::Error>
where
S: Serializer,
@@ -629,7 +628,7 @@ pub mod bigdecimal_opt {
None
}
- pub fn serialize<S, B>(decimal: &Option<BigDecimal>, serializer: S) ->
Result<S::Ok, S::Error>
+ pub fn serialize<S>(decimal: &Option<BigDecimal>, serializer: S) ->
Result<S::Ok, S::Error>
where
S: Serializer,
{
diff --git a/avro/tests/avro-rs-226.rs b/avro/tests/avro-rs-226.rs
index fd1b6d2..63c912d 100644
--- a/avro/tests/avro-rs-226.rs
+++ b/avro/tests/avro-rs-226.rs
@@ -15,7 +15,7 @@
// specific language governing permissions and limitations
// under the License.
-use apache_avro::{AvroSchema, Schema, Writer, from_value};
+use apache_avro::{AvroSchema, Schema, Writer};
use apache_avro_test_helper::TestResult;
use serde::{Deserialize, Serialize, de::DeserializeOwned};
use std::fmt::Debug;
@@ -29,13 +29,10 @@ where
writer.append_ser(record)?;
let bytes_written = writer.into_inner()?;
- let reader = apache_avro::Reader::new(&bytes_written[..])?;
- for value in reader {
- let value = value?;
- let deserialized = from_value::<T>(&value)?;
- assert_eq!(deserialized, record2);
+ let mut reader = apache_avro::Reader::new(&bytes_written[..])?;
+ while let Some(value) = reader.next_deser::<T>()? {
+ assert_eq!(value, record2);
}
-
Ok(())
}
diff --git a/avro_derive/Cargo.toml b/avro_derive/Cargo.toml
index 9ed7b2d..d7a5fd9 100644
--- a/avro_derive/Cargo.toml
+++ b/avro_derive/Cargo.toml
@@ -46,6 +46,7 @@ proptest = { default-features = false, version = "1.10.0",
features = ["std"] }
rustversion = "1.0.22"
serde = { workspace = true }
trybuild = "1.0.116"
+hexdump = "0.1.2"
[build-dependencies]
rustversion = "1.0.22"
diff --git a/avro_derive/src/attributes/mod.rs
b/avro_derive/src/attributes/mod.rs
index 15a2561..b6ebead 100644
--- a/avro_derive/src/attributes/mod.rs
+++ b/avro_derive/src/attributes/mod.rs
@@ -49,6 +49,7 @@ pub struct NamedTypeOptions {
pub transparent: bool,
pub default: TokenStream,
pub repr: Option<EnumRepr>,
+ pub with: Option<Path>,
}
impl NamedTypeOptions {
@@ -93,6 +94,12 @@ impl NamedTypeOptions {
r#"AvroSchema: rename rules for serializing and deserializing
must match (`rename_all_fields(serialize = "..", deserialize = "..")`)"#
));
}
+ if serde.from != serde.into && serde.try_from != serde.into {
+ errors.push(syn::Error::new(
+ span,
+ r#"AvroSchema: `#[serde({try_,}from = "..")]` must match
`#[serde(into = "..")]`"#,
+ ));
+ }
// Check for conflicts between Serde and Avro
if avro.name.is_some() && avro.name != serde.rename {
@@ -118,14 +125,36 @@ impl NamedTypeOptions {
|| serde.rename_all.deserialize != RenameRule::None
|| serde.untagged
|| serde.tag.is_some()
- || serde.content.is_some())
+ || serde.content.is_some()
+ || serde.into.is_some()
+ || serde.from.is_some()
+ || serde.try_from.is_some())
{
errors.push(syn::Error::new(
span,
"AvroSchema: `#[serde(transparent)]` is incompatible with all
other attributes",
));
}
+ if serde.into.is_some()
+ && (serde.rename.is_some()
+ || avro.name.is_some()
+ || avro.namespace.is_some()
+ || avro.doc.is_some()
+ || !avro.alias.is_empty()
+ || avro.rename_all != RenameRule::None
+ || serde.rename_all.serialize != RenameRule::None
+ || serde.rename_all.deserialize != RenameRule::None
+ || serde.untagged
+ || serde.tag.is_some()
+ || serde.content.is_some())
+ {
+ errors.push(syn::Error::new(
+ span,
+ r#"AvroSchema: `#[serde({try_,}from = "..", into = "..")]`
are incompatible with all other attributes"#,
+ ));
+ }
+ #[expect(clippy::manual_map, reason = "Intent is clearer this way")]
let repr = if let Some(repr) = avro.repr {
match repr {
avro::EnumRepr::Enum => {
@@ -188,37 +217,47 @@ impl NamedTypeOptions {
}
}
}
+ } else if let Some(content) = serde.content
+ && let Some(tag) = serde.tag
+ {
+ Some(EnumRepr::RecordTagContent { tag, content })
+ } else if serde.untagged {
+ Some(EnumRepr::BareUnion)
+ } else if let Some(tag) = serde.tag {
+ Some(EnumRepr::RecordInternallyTagged { tag })
} else {
- if let Some(content) = serde.content
- && let Some(tag) = serde.tag
- {
- Some(EnumRepr::RecordTagContent { tag, content })
- } else if serde.untagged {
- Some(EnumRepr::BareUnion)
- } else if let Some(tag) = serde.tag {
- Some(EnumRepr::RecordInternallyTagged { tag })
- } else {
- None
- }
+ None
};
let default = match avro.default {
- None => quote! { None },
+ None => quote! { ::std::option::Option::None },
Some(default_value) => {
if let Err(err) =
serde_json::from_str::<serde_json::Value>(&default_value[..]) {
errors.push(syn::Error::new(
ident.span(),
format!("Invalid Avro `default` JSON: \n{err}"),
));
- quote! { None }
+ quote! { ::std::option::Option::None }
} else {
quote! {
-
Some(serde_json::from_str(#default_value).expect("Unreachable! This was checked
at compile time"))
+
::std::option::Option::Some(serde_json::from_str(#default_value).expect("Unreachable!
This was checked at compile time"))
}
}
}
};
+ let with = match serde.into.as_deref().map(Path::from_string) {
+ Some(Ok(path)) => Some(path),
+ Some(Err(err)) => {
+ errors.push(syn::Error::new(
+ span,
+ format!(r#"AvroSchema: Expected a path for `#[serde(into =
"..")]`: {err:?}"#),
+ ));
+ None
+ }
+ None => None,
+ };
+
if !errors.is_empty() {
return Err(errors);
}
@@ -241,6 +280,7 @@ impl NamedTypeOptions {
transparent: serde.transparent,
default,
repr,
+ with,
})
}
}
diff --git a/avro_derive/src/attributes/serde.rs
b/avro_derive/src/attributes/serde.rs
index 6d7f55c..b0e4ae1 100644
--- a/avro_derive/src/attributes/serde.rs
+++ b/avro_derive/src/attributes/serde.rs
@@ -124,13 +124,13 @@ pub struct ContainerAttributes {
pub transparent: bool,
/// Deserialize using the given type and then convert to this type with
`From`.
#[darling(default, rename = "from")]
- pub _from: Option<String>,
+ pub from: Option<String>,
/// Deserialize using the given type and then convert to this type with
`TryFrom`.
#[darling(default, rename = "try_from")]
- pub _try_from: Option<String>,
+ pub try_from: Option<String>,
/// Convert this type to the given type using `Into` and then serialize
using the given type.
#[darling(default, rename = "into")]
- pub _into: Option<String>,
+ pub into: Option<String>,
/// Use the Serde API at this path.
#[darling(default, rename = "crate")]
pub _crate: Option<String>,
diff --git a/avro_derive/src/enums/bare_union.rs
b/avro_derive/src/enums/bare_union.rs
index 26d9ee7..a9e5c93 100644
--- a/avro_derive/src/enums/bare_union.rs
+++ b/avro_derive/src/enums/bare_union.rs
@@ -36,7 +36,7 @@ pub fn get_data_enum_schema_def(
}
}
Fields::Unnamed(unnamed) => {
- if unnamed.unnamed.len() == 0 {
+ if unnamed.unnamed.is_empty() {
return Err(vec![syn::Error::new(
unnamed.span(),
"AvroSchema: Empty tuple variants are not supported
for bare unions",
@@ -64,6 +64,7 @@ pub fn get_data_enum_schema_def(
.fields(vec![
#(#fields,)*
])
+
.attributes([("org.apache.avro.rust.tuple".to_string(),
::serde_json::value::Value::Bool(true))].into())
.build()
)
};
diff --git a/avro_derive/src/lib.rs b/avro_derive/src/lib.rs
index 1159323..cebfe24 100644
--- a/avro_derive/src/lib.rs
+++ b/avro_derive/src/lib.rs
@@ -58,59 +58,71 @@ fn derive_avro_schema(input: DeriveInput) ->
Result<TokenStream, Vec<syn::Error>
// It would be nice to parse the attributes before the `match`, but we
first need to validate that `input` is not a union.
// Otherwise a user could get errors related to the attributes and after
fixing those get an error because the attributes were on a union.
let input_span = input.span();
- match input.data {
- syn::Data::Struct(data_struct) => {
- let named_type_options = NamedTypeOptions::new(&input.ident,
&input.attrs, input_span)?;
- if named_type_options.repr.is_some() {
- return Err(vec![syn::Error::new(
- input_span,
- r#"AvroSchema: `#[avro(repr = "..")]`, `#[serde(tag =
"..")]`, `#[serde(content = "..")]`, and `#[serde(untagged)]` are only
supported on enums"#,
- )]);
+ let named_type_options = NamedTypeOptions::new(&input.ident, &input.attrs,
input_span)?;
+ if let Some(path) = named_type_options.with {
+ Ok(create_trait_definition(
+ input.ident,
+ &input.generics,
+ quote! { #path::get_schema_in_ctxt(named_schemas,
enclosing_namespace) },
+ quote! { #path::get_record_fields_in_ctxt(named_schemas,
enclosing_namespace) },
+ quote! { #path::field_default() },
+ ))
+ } else {
+ match input.data {
+ syn::Data::Struct(data_struct) => {
+ if named_type_options.repr.is_some() {
+ return Err(vec![syn::Error::new(
+ input_span,
+ r#"AvroSchema: `#[avro(repr = "..")]`, `#[serde(tag =
"..")]`, `#[serde(content = "..")]`, and `#[serde(untagged)]` are only
supported on enums"#,
+ )]);
+ }
+ let (get_schema_impl, get_record_fields_impl) = if
named_type_options.transparent {
+ get_transparent_struct_schema_def(data_struct.fields,
input_span)?
+ } else {
+ let (schema_def, record_fields) = get_struct_schema_def(
+ &named_type_options,
+ data_struct,
+ input.ident.span(),
+ )?;
+ (
+ handle_named_schemas(named_type_options.name,
schema_def),
+ record_fields,
+ )
+ };
+ Ok(create_trait_definition(
+ input.ident,
+ &input.generics,
+ get_schema_impl,
+ get_record_fields_impl,
+ named_type_options.default,
+ ))
}
- let (get_schema_impl, get_record_fields_impl) = if
named_type_options.transparent {
- get_transparent_struct_schema_def(data_struct.fields,
input_span)?
- } else {
- let (schema_def, record_fields) =
- get_struct_schema_def(&named_type_options, data_struct,
input.ident.span())?;
- (
- handle_named_schemas(named_type_options.name, schema_def),
- record_fields,
- )
- };
- Ok(create_trait_definition(
- input.ident,
- &input.generics,
- get_schema_impl,
- get_record_fields_impl,
- named_type_options.default,
- ))
- }
- syn::Data::Enum(data_enum) => {
- let named_type_options = NamedTypeOptions::new(&input.ident,
&input.attrs, input_span)?;
- if named_type_options.transparent {
- return Err(vec![syn::Error::new(
- input_span,
- "AvroSchema: `#[serde(transparent)]` is only supported on
structs",
- )]);
- }
- let schema_def = enums::get_data_enum_schema_def(
- &named_type_options,
- data_enum,
- input.ident.span(),
- )?;
- let inner = handle_named_schemas(named_type_options.name,
schema_def);
- Ok(create_trait_definition(
- input.ident,
- &input.generics,
- inner,
- quote! { None },
- named_type_options.default,
- ))
+ syn::Data::Enum(data_enum) => {
+ if named_type_options.transparent {
+ return Err(vec![syn::Error::new(
+ input_span,
+ "AvroSchema: `#[serde(transparent)]` is only supported
on structs",
+ )]);
+ }
+ let schema_def = enums::get_data_enum_schema_def(
+ &named_type_options,
+ data_enum,
+ input.ident.span(),
+ )?;
+ let inner = handle_named_schemas(named_type_options.name,
schema_def);
+ Ok(create_trait_definition(
+ input.ident,
+ &input.generics,
+ inner,
+ quote! { None },
+ named_type_options.default,
+ ))
+ }
+ syn::Data::Union(_) => Err(vec![syn::Error::new(
+ input_span,
+ "AvroSchema: derive only works for structs and enums",
+ )]),
}
- syn::Data::Union(_) => Err(vec![syn::Error::new(
- input_span,
- "AvroSchema: derive only works for structs and enums",
- )]),
}
}
@@ -135,7 +147,7 @@ fn create_trait_definition(
}
fn field_default() -> ::std::option::Option<::serde_json::Value> {
- ::std::option::Option::#field_default_impl
+ #field_default_impl
}
}
}
@@ -615,9 +627,7 @@ mod tests {
let enclosing_namespace = name.namespace();
named_schemas.insert(name.clone());
::apache_avro::schema::Schema::Enum(apache_avro::schema::EnumSchema {
- name:
::apache_avro::schema::Name::new("Basic").expect(
- &format!("Unable to parse enum name
for schema {}", "Basic")[..]
- ),
+ name,
aliases: ::std::option::Option::None,
doc: None,
symbols: vec![
@@ -787,7 +797,7 @@ mod tests {
match syn::parse2::<DeriveInput>(test_enum) {
Ok(input) => {
let schema_res = derive_avro_schema(input);
- let expected_token_stream = r#"# [automatically_derived] impl
:: apache_avro :: AvroSchemaComponent for A { fn get_schema_in_ctxt
(named_schemas : & mut :: std :: collections :: HashSet < :: apache_avro ::
schema :: Name > , enclosing_namespace : :: apache_avro :: schema ::
NamespaceRef) -> :: apache_avro :: schema :: Schema { let name = :: apache_avro
:: schema :: Name :: new_with_enclosing_namespace ("A" , enclosing_namespace) .
expect (concat ! ("Unable to parse schema [...]
+ let expected_token_stream = r#"# [automatically_derived] impl
:: apache_avro :: AvroSchemaComponent for A { fn get_schema_in_ctxt
(named_schemas : & mut :: std :: collections :: HashSet < :: apache_avro ::
schema :: Name > , enclosing_namespace : :: apache_avro :: schema ::
NamespaceRef) -> :: apache_avro :: schema :: Schema { let name = :: apache_avro
:: schema :: Name :: new_with_enclosing_namespace ("A" , enclosing_namespace) .
expect (concat ! ("Unable to parse schema [...]
let schema_token_stream = schema_res.unwrap().to_string();
assert_eq!(schema_token_stream, expected_token_stream);
}
@@ -830,7 +840,7 @@ mod tests {
match syn::parse2::<DeriveInput>(test_enum) {
Ok(input) => {
let schema_res = derive_avro_schema(input);
- let expected_token_stream = r#"# [automatically_derived] impl
:: apache_avro :: AvroSchemaComponent for B { fn get_schema_in_ctxt
(named_schemas : & mut :: std :: collections :: HashSet < :: apache_avro ::
schema :: Name > , enclosing_namespace : :: apache_avro :: schema ::
NamespaceRef) -> :: apache_avro :: schema :: Schema { let name = :: apache_avro
:: schema :: Name :: new_with_enclosing_namespace ("B" , enclosing_namespace) .
expect (concat ! ("Unable to parse schema [...]
+ let expected_token_stream = r#"# [automatically_derived] impl
:: apache_avro :: AvroSchemaComponent for B { fn get_schema_in_ctxt
(named_schemas : & mut :: std :: collections :: HashSet < :: apache_avro ::
schema :: Name > , enclosing_namespace : :: apache_avro :: schema ::
NamespaceRef) -> :: apache_avro :: schema :: Schema { let name = :: apache_avro
:: schema :: Name :: new_with_enclosing_namespace ("B" , enclosing_namespace) .
expect (concat ! ("Unable to parse schema [...]
let schema_token_stream = schema_res.unwrap().to_string();
assert_eq!(schema_token_stream, expected_token_stream);
}
diff --git a/avro_derive/tests/derive.rs b/avro_derive/tests/derive.rs
index 1ef52e3..268e175 100644
--- a/avro_derive/tests/derive.rs
+++ b/avro_derive/tests/derive.rs
@@ -16,7 +16,7 @@
// under the License.
use apache_avro::{
- AvroSchema, AvroSchemaComponent, Reader, Schema, Writer, from_value,
+ AvroSchema, AvroSchemaComponent, Reader, Schema, Writer,
schema::{Alias, EnumSchema, FixedSchema, Name, RecordSchema},
};
use proptest::prelude::*;
@@ -64,20 +64,16 @@ where
T: DeserializeOwned + AvroSchema,
{
assert!(!encoded.is_empty());
+ hexdump::hexdump(&encoded);
let schema = T::get_schema();
let mut reader = Reader::builder(&encoded[..])
.reader_schema(&schema)
.build()
.unwrap();
- if let Some(res) = reader.next() {
- match res {
- Ok(value) => {
- return from_value::<T>(&value).unwrap();
- }
- Err(e) => panic!("{e:?}"),
- }
- }
- unreachable!()
+ reader
+ .next_deser::<T>()
+ .unwrap()
+ .expect("Did not deserialize a value from the reader")
}
#[derive(Debug, Serialize, Deserialize, AvroSchema, Clone, PartialEq, Eq)]
@@ -494,6 +490,44 @@ fn test_generic_container_1(a: String, b: Vec<i32>, c:
HashMap<String, i32>) {
serde_assert(test_generic);
}}
+#[test]
+fn test_generic_container_1_actual() {
+ let schema = r#"
+ {
+ "type":"record",
+ "name":"TestGeneric",
+ "fields":[
+ {
+ "name":"a",
+ "type":"string"
+ },
+ {
+ "name":"b",
+ "type": {
+ "type":"array",
+ "items":"int"
+ }
+ },
+ {
+ "name":"c",
+ "type": {
+ "type":"map",
+ "values":"int"
+ }
+ }
+ ]
+ }
+ "#;
+ let schema = Schema::parse_str(schema).unwrap();
+ assert_eq!(schema, TestGeneric::<i32>::get_schema());
+ let test_generic = TestGeneric::<i32> {
+ a: "".to_string(),
+ b: vec![0],
+ c: HashMap::new(),
+ };
+ serde_assert(test_generic);
+}
+
proptest! {
#[test]
fn test_generic_container_2(a: bool, b: i8, c: i16, d: i32, e: u8, f: u16, g:
i64, h: f32, i: f64, j: String) {
@@ -800,6 +834,7 @@ fn test_cons_generic() {
#[derive(Debug, Serialize, Deserialize, AvroSchema, Clone, PartialEq, Eq)]
struct TestSimpleArray {
+ #[serde(with = "apache_avro::serde::array")]
a: [i32; 4],
}
diff --git a/avro_derive/tests/serde.rs b/avro_derive/tests/serde.rs
index cf11fa4..445e11c 100644
--- a/avro_derive/tests/serde.rs
+++ b/avro_derive/tests/serde.rs
@@ -27,18 +27,18 @@ where
assert_eq!(obj, serde(obj.clone()).unwrap());
}
-/// Takes in a type that implements the right combination of traits and runs
it through a Serde
-/// round-trip and asserts that the error matches the expected string.
-fn serde_assert_err<T>(obj: T, expected: &str)
-where
- T: std::fmt::Debug + Serialize + DeserializeOwned + AvroSchema + Clone +
PartialEq,
-{
- let error = serde(obj).unwrap_err().to_string();
- assert!(
- error.contains(expected),
- "Error `{error}` does not contain `{expected}`"
- );
-}
+// /// Takes in a type that implements the right combination of traits and
runs it through a Serde
+// /// round-trip and asserts that the error matches the expected string.
+// fn serde_assert_err<T>(obj: T, expected: &str)
+// where
+// T: std::fmt::Debug + Serialize + DeserializeOwned + AvroSchema + Clone
+ PartialEq,
+// {
+// let error = serde(obj).unwrap_err().to_string();
+// assert!(
+// error.contains(expected),
+// "Error `{error}` does not contain `{expected}`"
+// );
+// }
fn serde<T>(obj: T) -> Result<T, Error>
where
@@ -231,7 +231,7 @@ mod container_attributes {
let schema = r#"
{
"type":"record",
- "name":"Foo",
+ "name":"FooFromInto",
"fields": [
{
"name":"a",
@@ -291,7 +291,7 @@ mod container_attributes {
let schema = r#"
{
"type":"record",
- "name":"Foo",
+ "name":"FooFromInto",
"fields": [
{
"name":"a",
@@ -300,6 +300,10 @@ mod container_attributes {
{
"name":"b",
"type":"int"
+ },
+ {
+ "name":"c",
+ "type":"boolean"
}
]
}
@@ -308,13 +312,10 @@ mod container_attributes {
let schema = Schema::parse_str(schema).unwrap();
assert_eq!(schema, Foo::get_schema());
- serde_assert_err(
- Foo {
- a: "spam".to_string(),
- b: 321,
- },
- "Invalid field name c",
- );
+ serde_assert(Foo {
+ a: "spam".to_string(),
+ b: 321,
+ });
}
#[test]