This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new d4ad4b74d Clean the code in `field.rs` and add more tests (#2345)
d4ad4b74d is described below

commit d4ad4b74db387fc9d686dbe8bb274785fa745cc9
Author: Remzi Yang <[email protected]>
AuthorDate: Wed Aug 10 22:56:57 2022 +0800

    Clean the code in `field.rs` and add more tests (#2345)
    
    * clean up the field
    
    Signed-off-by: remzi <[email protected]>
    
    * test to check same field
    
    Signed-off-by: remzi <[email protected]>
    
    * fix nit
    
    Signed-off-by: remzi <[email protected]>
    
    * fix fmt
    
    Signed-off-by: remzi <[email protected]>
---
 arrow/src/datatypes/field.rs | 178 ++++++++++++++++++++++++++-----------------
 1 file changed, 108 insertions(+), 70 deletions(-)

diff --git a/arrow/src/datatypes/field.rs b/arrow/src/datatypes/field.rs
index abb80d64a..f50ebadd5 100644
--- a/arrow/src/datatypes/field.rs
+++ b/arrow/src/datatypes/field.rs
@@ -209,23 +209,17 @@ impl Field {
     }
 
     fn _fields<'a>(&'a self, dt: &'a DataType) -> Vec<&Field> {
-        let mut collected_fields = vec![];
-
         match dt {
             DataType::Struct(fields) | DataType::Union(fields, _, _) => {
-                collected_fields.extend(fields.iter().flat_map(|f| f.fields()))
+                fields.iter().flat_map(|f| f.fields()).collect()
             }
             DataType::List(field)
             | DataType::LargeList(field)
             | DataType::FixedSizeList(field, _)
-            | DataType::Map(field, _) => collected_fields.extend(field.fields()),
-            DataType::Dictionary(_, value_field) => {
-                collected_fields.append(&mut self._fields(value_field.as_ref()))
-            }
-            _ => (),
+            | DataType::Map(field, _) => field.fields(),
+            DataType::Dictionary(_, value_field) => self._fields(value_field.as_ref()),
+            _ => vec![],
         }
-
-        collected_fields
     }
 
     /// Returns a vector containing all (potentially nested) `Field` instances selected by the
@@ -506,12 +500,10 @@ impl Field {
     pub fn to_json(&self) -> Value {
         let children: Vec<Value> = match self.data_type() {
             DataType::Struct(fields) => fields.iter().map(|f| f.to_json()).collect(),
-            DataType::List(field) => vec![field.to_json()],
-            DataType::LargeList(field) => vec![field.to_json()],
-            DataType::FixedSizeList(field, _) => vec![field.to_json()],
-            DataType::Map(field, _) => {
-                vec![field.to_json()]
-            }
+            DataType::List(field)
+            | DataType::LargeList(field)
+            | DataType::FixedSizeList(field, _)
+            | DataType::Map(field, _) => vec![field.to_json()],
             _ => vec![],
         };
         match self.data_type() {
@@ -550,6 +542,17 @@ impl Field {
     /// assert!(field.is_nullable());
     /// ```
     pub fn try_merge(&mut self, from: &Field) -> Result<()> {
+        if from.dict_id != self.dict_id {
+            return Err(ArrowError::SchemaError(
+                "Fail to merge schema Field due to conflicting dict_id".to_string(),
+            ));
+        }
+        if from.dict_is_ordered != self.dict_is_ordered {
+            return Err(ArrowError::SchemaError(
+                "Fail to merge schema Field due to conflicting dict_is_ordered"
+                    .to_string(),
+            ));
+        }
         // merge metadata
         match (self.metadata(), from.metadata()) {
             (Some(self_metadata), Some(from_metadata)) => {
@@ -572,31 +575,16 @@ impl Field {
             }
             _ => {}
         }
-        if from.dict_id != self.dict_id {
-            return Err(ArrowError::SchemaError(
-                "Fail to merge schema Field due to conflicting dict_id".to_string(),
-            ));
-        }
-        if from.dict_is_ordered != self.dict_is_ordered {
-            return Err(ArrowError::SchemaError(
-                "Fail to merge schema Field due to conflicting dict_is_ordered"
-                    .to_string(),
-            ));
-        }
         match &mut self.data_type {
             DataType::Struct(nested_fields) => match &from.data_type {
                 DataType::Struct(from_nested_fields) => {
                     for from_field in from_nested_fields {
-                        let mut is_new_field = true;
-                        for self_field in nested_fields.iter_mut() {
-                            if self_field.name != from_field.name {
-                                continue;
-                            }
-                            is_new_field = false;
-                            self_field.try_merge(from_field)?;
-                        }
-                        if is_new_field {
-                            nested_fields.push(from_field.clone());
+                        match nested_fields
+                            .iter_mut()
+                            .find(|self_field| self_field.name == from_field.name)
+                        {
+                            Some(self_field) => self_field.try_merge(from_field)?,
+                            None => nested_fields.push(from_field.clone()),
                         }
                     }
                 }
@@ -685,9 +673,7 @@ impl Field {
                 }
             }
         }
-        if from.nullable {
-            self.nullable = from.nullable;
-        }
+        self.nullable |= from.nullable;
 
         Ok(())
     }
@@ -698,41 +684,25 @@ impl Field {
     /// * self.metadata is a superset of other.metadata
     /// * all other fields are equal
     pub fn contains(&self, other: &Field) -> bool {
-        if self.name != other.name
-            || self.data_type != other.data_type
-            || self.dict_id != other.dict_id
-            || self.dict_is_ordered != other.dict_is_ordered
-        {
-            return false;
-        }
-
-        if self.nullable != other.nullable && !self.nullable {
-            return false;
-        }
-
+        self.name == other.name
+        && self.data_type == other.data_type
+        && self.dict_id == other.dict_id
+        && self.dict_is_ordered == other.dict_is_ordered
+        // self need to be nullable or both of them are not nullable
+        && (self.nullable || !other.nullable)
         // make sure self.metadata is a superset of other.metadata
-        match (&self.metadata, &other.metadata) {
-            (None, Some(_)) => {
-                return false;
-            }
+        && match (&self.metadata, &other.metadata) {
+            (_, None) => true,
+            (None, Some(_)) => false,
             (Some(self_meta), Some(other_meta)) => {
-                for (k, v) in other_meta.iter() {
+                other_meta.iter().all(|(k, v)| {
                     match self_meta.get(k) {
-                        Some(s) => {
-                            if s != v {
-                                return false;
-                            }
-                        }
-                        None => {
-                            return false;
-                        }
+                        Some(s) => s == v,
+                        None => false
                     }
-                }
+                })
             }
-            _ => {}
         }
-
-        true
     }
 }
 
@@ -745,7 +715,7 @@ impl std::fmt::Display for Field {
 
 #[cfg(test)]
 mod test {
-    use super::{DataType, Field};
+    use super::*;
     use std::collections::hash_map::DefaultHasher;
     use std::hash::{Hash, Hasher};
 
@@ -840,4 +810,72 @@ mod test {
         assert_ne!(dict1, dict2);
         assert_ne!(get_field_hash(&dict1), get_field_hash(&dict2));
     }
+
+    #[test]
+    fn test_contains_reflexivity() {
+        let mut field = Field::new("field1", DataType::Float16, false);
+        field.set_metadata(Some(BTreeMap::from([
+            (String::from("k0"), String::from("v0")),
+            (String::from("k1"), String::from("v1")),
+        ])));
+        assert!(field.contains(&field))
+    }
+
+    #[test]
+    fn test_contains_transitivity() {
+        let child_field = Field::new("child1", DataType::Float16, false);
+
+        let mut field1 = Field::new("field1", DataType::Struct(vec![child_field]), false);
+        field1.set_metadata(Some(BTreeMap::from([(
+            String::from("k1"),
+            String::from("v1"),
+        )])));
+
+        let mut field2 = Field::new("field1", DataType::Struct(vec![]), true);
+        field2.set_metadata(Some(BTreeMap::from([(
+            String::from("k2"),
+            String::from("v2"),
+        )])));
+        field2.try_merge(&field1).unwrap();
+
+        let mut field3 = Field::new("field1", DataType::Struct(vec![]), false);
+        field3.set_metadata(Some(BTreeMap::from([(
+            String::from("k3"),
+            String::from("v3"),
+        )])));
+        field3.try_merge(&field2).unwrap();
+
+        assert!(field2.contains(&field1));
+        assert!(field3.contains(&field2));
+        assert!(field3.contains(&field1));
+
+        assert!(!field1.contains(&field2));
+        assert!(!field1.contains(&field3));
+        assert!(!field2.contains(&field3));
+    }
+
+    #[test]
+    fn test_contains_nullable() {
+        let field1 = Field::new("field1", DataType::Boolean, true);
+        let field2 = Field::new("field1", DataType::Boolean, false);
+        assert!(field1.contains(&field2));
+        assert!(!field2.contains(&field1));
+    }
+
+    #[test]
+    fn test_contains_must_have_same_fields() {
+        let child_field1 = Field::new("child1", DataType::Float16, false);
+        let child_field2 = Field::new("child2", DataType::Float16, false);
+
+        let field1 =
+            Field::new("field1", DataType::Struct(vec![child_field1.clone()]), true);
+        let field2 = Field::new(
+            "field1",
+            DataType::Struct(vec![child_field1, child_field2]),
+            true,
+        );
+
+        assert!(!field1.contains(&field2));
+        assert!(!field2.contains(&field1));
+    }
 }

Reply via email to