This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new c7cf8f7  feat(ipc): Support writing dictionaries nested in structs and unions (#870)
c7cf8f7 is described below

commit c7cf8f77318ecc61531c5ee0785e81f7f26fe69a
Author: Helgi Kristvin Sigurbjarnarson <[email protected]>
AuthorDate: Fri Oct 29 06:11:40 2021 -0700

    feat(ipc): Support writing dictionaries nested in structs and unions (#870)
    
    * feat(ipc): Support for writing dictionaries nested in structs and unions
    
    Dictionaries are lost when serializing a RecordBatch for IPC, producing
    invalid arrow data. This PR changes encoded_batch to recursively find
    all dictionary fields within the schema (currently only in structs and
    unions) so nested dictionaries are properly serialized.
    
    * address lint and clippy
---
 arrow/src/array/cast.rs |   1 +
 arrow/src/array/mod.rs  |   2 +-
 arrow/src/ipc/writer.rs | 138 +++++++++++++++++++++++++++++++++++++++++++-----
 3 files changed, 127 insertions(+), 14 deletions(-)

diff --git a/arrow/src/array/cast.rs b/arrow/src/array/cast.rs
index dfc1560..e4284ef 100644
--- a/arrow/src/array/cast.rs
+++ b/arrow/src/array/cast.rs
@@ -92,3 +92,4 @@ array_downcast_fn!(as_largestring_array, LargeStringArray);
 array_downcast_fn!(as_boolean_array, BooleanArray);
 array_downcast_fn!(as_null_array, NullArray);
 array_downcast_fn!(as_struct_array, StructArray);
+array_downcast_fn!(as_union_array, UnionArray);
diff --git a/arrow/src/array/mod.rs b/arrow/src/array/mod.rs
index 63b8b61..5d4e57a 100644
--- a/arrow/src/array/mod.rs
+++ b/arrow/src/array/mod.rs
@@ -441,7 +441,7 @@ pub use self::ord::{build_compare, DynComparator};
 pub use self::cast::{
     as_boolean_array, as_dictionary_array, as_generic_binary_array,
     as_generic_list_array, as_large_list_array, as_largestring_array, as_list_array,
-    as_null_array, as_primitive_array, as_string_array, as_struct_array,
+    as_null_array, as_primitive_array, as_string_array, as_struct_array, as_union_array,
 };
 
 // ------------------------------ C Data Interface ---------------------------
diff --git a/arrow/src/ipc/writer.rs b/arrow/src/ipc/writer.rs
index 0376265..853fc0f 100644
--- a/arrow/src/ipc/writer.rs
+++ b/arrow/src/ipc/writer.rs
@@ -25,7 +25,7 @@ use std::io::{BufWriter, Write};
 
 use flatbuffers::FlatBufferBuilder;
 
-use crate::array::{ArrayData, ArrayRef};
+use crate::array::{as_struct_array, as_union_array, ArrayData, ArrayRef};
 use crate::buffer::{Buffer, MutableBuffer};
 use crate::datatypes::*;
 use crate::error::{ArrowError, Result};
@@ -137,20 +137,45 @@ impl IpcDataGenerator {
         }
     }
 
-    pub fn encoded_batch(
+    fn encode_dictionaries(
         &self,
-        batch: &RecordBatch,
+        field: &Field,
+        column: &ArrayRef,
+        encoded_dictionaries: &mut Vec<EncodedData>,
         dictionary_tracker: &mut DictionaryTracker,
         write_options: &IpcWriteOptions,
-    ) -> Result<(Vec<EncodedData>, EncodedData)> {
-        // TODO: handle nested dictionaries
-        let schema = batch.schema();
-        let mut encoded_dictionaries = Vec::with_capacity(schema.fields().len());
-
-        for (i, field) in schema.fields().iter().enumerate() {
-            let column = batch.column(i);
-
-            if let DataType::Dictionary(_key_type, _value_type) = column.data_type() {
+    ) -> Result<()> {
+        // TODO: Handle other nested types (map, list, etc)
+        match column.data_type() {
+            DataType::Struct(fields) => {
+                let s = as_struct_array(column);
+                for (field, &column) in fields.iter().zip(s.columns().iter()) {
+                    self.encode_dictionaries(
+                        field,
+                        column,
+                        encoded_dictionaries,
+                        dictionary_tracker,
+                        write_options,
+                    )?;
+                }
+            }
+            DataType::Union(fields) => {
+                let union = as_union_array(column);
+                for (field, ref column) in fields
+                    .iter()
+                    .enumerate()
+                    .map(|(n, f)| (f, union.child(n as i8)))
+                {
+                    self.encode_dictionaries(
+                        field,
+                        column,
+                        encoded_dictionaries,
+                        dictionary_tracker,
+                        write_options,
+                    )?;
+                }
+            }
+            DataType::Dictionary(_key_type, _value_type) => {
                 let dict_id = field
                     .dict_id()
                     .expect("All Dictionary types have `dict_id`");
@@ -167,10 +192,33 @@ impl IpcDataGenerator {
                     ));
                 }
             }
+            _ => (),
         }
 
-        let encoded_message = self.record_batch_to_bytes(batch, write_options);
+        Ok(())
+    }
+
+    pub fn encoded_batch(
+        &self,
+        batch: &RecordBatch,
+        dictionary_tracker: &mut DictionaryTracker,
+        write_options: &IpcWriteOptions,
+    ) -> Result<(Vec<EncodedData>, EncodedData)> {
+        let schema = batch.schema();
+        let mut encoded_dictionaries = Vec::with_capacity(schema.fields().len());
 
+        for (i, field) in schema.fields().iter().enumerate() {
+            let column = batch.column(i);
+            self.encode_dictionaries(
+                field,
+                column,
+                &mut encoded_dictionaries,
+                dictionary_tracker,
+                write_options,
+            )?;
+        }
+
+        let encoded_message = self.record_batch_to_bytes(batch, write_options);
         Ok((encoded_dictionaries, encoded_message))
     }
 
@@ -1161,4 +1209,68 @@ mod tests {
         let arrow_json: ArrowJson = serde_json::from_str(&s).unwrap();
         arrow_json
     }
+
+    #[test]
+    fn track_union_nested_dict() {
+        let inner: DictionaryArray<Int32Type> = vec!["a", "b", "a"].into_iter().collect();
+
+        let array = Arc::new(inner) as ArrayRef;
+
+        // Dict field with id 2
+        let dctfield =
+            Field::new_dict("dict", array.data_type().clone(), false, 2, false);
+
+        let types = Buffer::from_slice_ref(&[0_i8, 0, 0]);
+        let offsets = Buffer::from_slice_ref(&[0_i32, 1, 2]);
+
+        let union =
+            UnionArray::try_new(types, Some(offsets), vec![(dctfield, array)], None)
+                .unwrap();
+
+        let schema = Arc::new(Schema::new(vec![Field::new(
+            "union",
+            union.data_type().clone(),
+            false,
+        )]));
+
+        let batch = RecordBatch::try_new(schema, vec![Arc::new(union)]).unwrap();
+
+        let gen = IpcDataGenerator {};
+        let mut dict_tracker = DictionaryTracker::new(false);
+        gen.encoded_batch(&batch, &mut dict_tracker, &Default::default())
+            .unwrap();
+
+        // Dictionary with id 2 should have been written to the dict tracker
+        assert!(dict_tracker.written.contains_key(&2));
+    }
+
+    #[test]
+    fn track_struct_nested_dict() {
+        let inner: DictionaryArray<Int32Type> = vec!["a", "b", "a"].into_iter().collect();
+
+        let array = Arc::new(inner) as ArrayRef;
+
+        // Dict field with id 2
+        let dctfield =
+            Field::new_dict("dict", array.data_type().clone(), false, 2, false);
+
+        let s = StructArray::from(vec![(dctfield, array)]);
+        let struct_array = Arc::new(s) as ArrayRef;
+
+        let schema = Arc::new(Schema::new(vec![Field::new(
+            "struct",
+            struct_array.data_type().clone(),
+            false,
+        )]));
+
+        let batch = RecordBatch::try_new(schema, vec![struct_array]).unwrap();
+
+        let gen = IpcDataGenerator {};
+        let mut dict_tracker = DictionaryTracker::new(false);
+        gen.encoded_batch(&batch, &mut dict_tracker, &Default::default())
+            .unwrap();
+
+        // Dictionary with id 2 should have been written to the dict tracker
+        assert!(dict_tracker.written.contains_key(&2));
+    }
 }

Reply via email to