This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new d4ad4b74d Clean the code in `field.rs` and add more tests (#2345)
d4ad4b74d is described below
commit d4ad4b74db387fc9d686dbe8bb274785fa745cc9
Author: Remzi Yang <[email protected]>
AuthorDate: Wed Aug 10 22:56:57 2022 +0800
Clean the code in `field.rs` and add more tests (#2345)
* clean up the field
Signed-off-by: remzi <[email protected]>
* test to check same field
Signed-off-by: remzi <[email protected]>
* fix nit
Signed-off-by: remzi <[email protected]>
* fix fmt
Signed-off-by: remzi <[email protected]>
---
arrow/src/datatypes/field.rs | 178 ++++++++++++++++++++++++++-----------------
1 file changed, 108 insertions(+), 70 deletions(-)
diff --git a/arrow/src/datatypes/field.rs b/arrow/src/datatypes/field.rs
index abb80d64a..f50ebadd5 100644
--- a/arrow/src/datatypes/field.rs
+++ b/arrow/src/datatypes/field.rs
@@ -209,23 +209,17 @@ impl Field {
}
fn _fields<'a>(&'a self, dt: &'a DataType) -> Vec<&Field> {
- let mut collected_fields = vec![];
-
match dt {
DataType::Struct(fields) | DataType::Union(fields, _, _) => {
- collected_fields.extend(fields.iter().flat_map(|f| f.fields()))
+ fields.iter().flat_map(|f| f.fields()).collect()
}
DataType::List(field)
| DataType::LargeList(field)
| DataType::FixedSizeList(field, _)
- | DataType::Map(field, _) => collected_fields.extend(field.fields()),
- DataType::Dictionary(_, value_field) => {
- collected_fields.append(&mut self._fields(value_field.as_ref()))
- }
- _ => (),
+ | DataType::Map(field, _) => field.fields(),
+ DataType::Dictionary(_, value_field) => self._fields(value_field.as_ref()),
+ _ => vec![],
}
-
- collected_fields
}
/// Returns a vector containing all (potentially nested) `Field` instances selected by the
@@ -506,12 +500,10 @@ impl Field {
pub fn to_json(&self) -> Value {
let children: Vec<Value> = match self.data_type() {
DataType::Struct(fields) => fields.iter().map(|f| f.to_json()).collect(),
- DataType::List(field) => vec![field.to_json()],
- DataType::LargeList(field) => vec![field.to_json()],
- DataType::FixedSizeList(field, _) => vec![field.to_json()],
- DataType::Map(field, _) => {
- vec![field.to_json()]
- }
+ DataType::List(field)
+ | DataType::LargeList(field)
+ | DataType::FixedSizeList(field, _)
+ | DataType::Map(field, _) => vec![field.to_json()],
_ => vec![],
};
match self.data_type() {
@@ -550,6 +542,17 @@ impl Field {
/// assert!(field.is_nullable());
/// ```
pub fn try_merge(&mut self, from: &Field) -> Result<()> {
+ if from.dict_id != self.dict_id {
+ return Err(ArrowError::SchemaError(
+ "Fail to merge schema Field due to conflicting dict_id".to_string(),
+ ));
+ }
+ if from.dict_is_ordered != self.dict_is_ordered {
+ return Err(ArrowError::SchemaError(
+ "Fail to merge schema Field due to conflicting dict_is_ordered"
+ .to_string(),
+ ));
+ }
// merge metadata
match (self.metadata(), from.metadata()) {
(Some(self_metadata), Some(from_metadata)) => {
@@ -572,31 +575,16 @@ impl Field {
}
_ => {}
}
- if from.dict_id != self.dict_id {
- return Err(ArrowError::SchemaError(
- "Fail to merge schema Field due to conflicting dict_id".to_string(),
- ));
- }
- if from.dict_is_ordered != self.dict_is_ordered {
- return Err(ArrowError::SchemaError(
- "Fail to merge schema Field due to conflicting dict_is_ordered"
- .to_string(),
- ));
- }
match &mut self.data_type {
DataType::Struct(nested_fields) => match &from.data_type {
DataType::Struct(from_nested_fields) => {
for from_field in from_nested_fields {
- let mut is_new_field = true;
- for self_field in nested_fields.iter_mut() {
- if self_field.name != from_field.name {
- continue;
- }
- is_new_field = false;
- self_field.try_merge(from_field)?;
- }
- if is_new_field {
- nested_fields.push(from_field.clone());
+ match nested_fields
+ .iter_mut()
+ .find(|self_field| self_field.name == from_field.name)
+ {
+ Some(self_field) => self_field.try_merge(from_field)?,
+ None => nested_fields.push(from_field.clone()),
}
}
}
@@ -685,9 +673,7 @@ impl Field {
}
}
}
- if from.nullable {
- self.nullable = from.nullable;
- }
+ self.nullable |= from.nullable;
Ok(())
}
@@ -698,41 +684,25 @@ impl Field {
/// * self.metadata is a superset of other.metadata
/// * all other fields are equal
pub fn contains(&self, other: &Field) -> bool {
- if self.name != other.name
- || self.data_type != other.data_type
- || self.dict_id != other.dict_id
- || self.dict_is_ordered != other.dict_is_ordered
- {
- return false;
- }
-
- if self.nullable != other.nullable && !self.nullable {
- return false;
- }
-
+ self.name == other.name
+ && self.data_type == other.data_type
+ && self.dict_id == other.dict_id
+ && self.dict_is_ordered == other.dict_is_ordered
+ // self need to be nullable or both of them are not nullable
+ && (self.nullable || !other.nullable)
// make sure self.metadata is a superset of other.metadata
- match (&self.metadata, &other.metadata) {
- (None, Some(_)) => {
- return false;
- }
+ && match (&self.metadata, &other.metadata) {
+ (_, None) => true,
+ (None, Some(_)) => false,
(Some(self_meta), Some(other_meta)) => {
- for (k, v) in other_meta.iter() {
+ other_meta.iter().all(|(k, v)| {
match self_meta.get(k) {
- Some(s) => {
- if s != v {
- return false;
- }
- }
- None => {
- return false;
- }
+ Some(s) => s == v,
+ None => false
}
- }
+ })
}
- _ => {}
}
-
- true
}
}
@@ -745,7 +715,7 @@ impl std::fmt::Display for Field {
#[cfg(test)]
mod test {
- use super::{DataType, Field};
+ use super::*;
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
@@ -840,4 +810,72 @@ mod test {
assert_ne!(dict1, dict2);
assert_ne!(get_field_hash(&dict1), get_field_hash(&dict2));
}
+
+ #[test]
+ fn test_contains_reflexivity() {
+ let mut field = Field::new("field1", DataType::Float16, false);
+ field.set_metadata(Some(BTreeMap::from([
+ (String::from("k0"), String::from("v0")),
+ (String::from("k1"), String::from("v1")),
+ ])));
+ assert!(field.contains(&field))
+ }
+
+ #[test]
+ fn test_contains_transitivity() {
+ let child_field = Field::new("child1", DataType::Float16, false);
+
+ let mut field1 = Field::new("field1", DataType::Struct(vec![child_field]), false);
+ field1.set_metadata(Some(BTreeMap::from([(
+ String::from("k1"),
+ String::from("v1"),
+ )])));
+
+ let mut field2 = Field::new("field1", DataType::Struct(vec![]), true);
+ field2.set_metadata(Some(BTreeMap::from([(
+ String::from("k2"),
+ String::from("v2"),
+ )])));
+ field2.try_merge(&field1).unwrap();
+
+ let mut field3 = Field::new("field1", DataType::Struct(vec![]), false);
+ field3.set_metadata(Some(BTreeMap::from([(
+ String::from("k3"),
+ String::from("v3"),
+ )])));
+ field3.try_merge(&field2).unwrap();
+
+ assert!(field2.contains(&field1));
+ assert!(field3.contains(&field2));
+ assert!(field3.contains(&field1));
+
+ assert!(!field1.contains(&field2));
+ assert!(!field1.contains(&field3));
+ assert!(!field2.contains(&field3));
+ }
+
+ #[test]
+ fn test_contains_nullable() {
+ let field1 = Field::new("field1", DataType::Boolean, true);
+ let field2 = Field::new("field1", DataType::Boolean, false);
+ assert!(field1.contains(&field2));
+ assert!(!field2.contains(&field1));
+ }
+
+ #[test]
+ fn test_contains_must_have_same_fields() {
+ let child_field1 = Field::new("child1", DataType::Float16, false);
+ let child_field2 = Field::new("child2", DataType::Float16, false);
+
+ let field1 =
+ Field::new("field1", DataType::Struct(vec![child_field1.clone()]), true);
+ let field2 = Field::new(
+ "field1",
+ DataType::Struct(vec![child_field1, child_field2]),
+ true,
+ );
+
+ assert!(!field1.contains(&field2));
+ assert!(!field2.contains(&field1));
+ }
}