This is an automated email from the ASF dual-hosted git repository.

sunchao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new f19d1ed  Add dictionary support for C data interface (#1407)
f19d1ed is described below

commit f19d1ed71c2318407cf6beaab226355e0d005daa
Author: Chao Sun <[email protected]>
AuthorDate: Tue Mar 8 23:27:43 2022 -0800

    Add dictionary support for C data interface (#1407)
    
    * initial commit
    
    * add integration tests for python
    
    * address comments
---
 .../tests/test_sql.py                              |  19 ++--
 arrow/src/array/ffi.rs                             |  22 ++++
 arrow/src/datatypes/ffi.rs                         | 104 ++++++++++---------
 arrow/src/ffi.rs                                   | 111 +++++++++++++++++++--
 4 files changed, 192 insertions(+), 64 deletions(-)

diff --git a/arrow-pyarrow-integration-testing/tests/test_sql.py b/arrow-pyarrow-integration-testing/tests/test_sql.py
index bacd118..058a32e 100644
--- a/arrow-pyarrow-integration-testing/tests/test_sql.py
+++ b/arrow-pyarrow-integration-testing/tests/test_sql.py
@@ -79,6 +79,7 @@ _supported_pyarrow_types = [
             pa.field("c", pa.string()),
         ]
     ),
+    pa.dictionary(pa.int8(), pa.string()),
 ]
 
 _unsupported_pyarrow_types = [
@@ -122,14 +123,6 @@ def test_type_roundtrip_raises(pyarrow_type):
     with pytest.raises(pa.ArrowException):
         rust.round_trip_type(pyarrow_type)
 
-
-def test_dictionary_type_roundtrip():
-    # the dictionary type conversion is incomplete
-    pyarrow_type = pa.dictionary(pa.int32(), pa.string())
-    ty = rust.round_trip_type(pyarrow_type)
-    assert ty == pa.int32()
-
-
 @pytest.mark.parametrize('pyarrow_type', _supported_pyarrow_types, ids=str)
 def test_field_roundtrip(pyarrow_type):
     pyarrow_field = pa.field("test", pyarrow_type, nullable=True)
@@ -263,3 +256,13 @@ def test_decimal_python():
     assert a == b
     del a
     del b
+
+def test_dictionary_python():
+    """
+    Python -> Rust -> Python
+    """
+    a = pa.array(["a", None, "b", None, "a"], type=pa.dictionary(pa.int8(), pa.string()))
+    b = rust.round_trip_array(a)
+    assert a == b
+    del a
+    del b
diff --git a/arrow/src/array/ffi.rs b/arrow/src/array/ffi.rs
index 847649c..976c6b8 100644
--- a/arrow/src/array/ffi.rs
+++ b/arrow/src/array/ffi.rs
@@ -45,6 +45,7 @@ impl TryFrom<ArrayData> for ffi::ArrowArray {
 
 #[cfg(test)]
 mod tests {
+    use crate::array::{DictionaryArray, Int32Array, StringArray};
     use crate::error::Result;
     use crate::{
         array::{
@@ -127,4 +128,25 @@ mod tests {
         let data = array.data();
         test_round_trip(data)
     }
+
+    #[test]
+    fn test_dictionary() -> Result<()> {
+        let values = StringArray::from(vec![Some("foo"), Some("bar"), None]);
+        let keys = Int32Array::from(vec![
+            Some(0),
+            Some(1),
+            None,
+            Some(1),
+            Some(1),
+            None,
+            Some(1),
+            Some(2),
+            Some(1),
+            None,
+        ]);
+        let array = DictionaryArray::try_new(&keys, &values)?;
+
+        let data = array.data();
+        test_round_trip(data)
+    }
 }
diff --git a/arrow/src/datatypes/ffi.rs b/arrow/src/datatypes/ffi.rs
index fbff4a0..10645fb 100644
--- a/arrow/src/datatypes/ffi.rs
+++ b/arrow/src/datatypes/ffi.rs
@@ -28,7 +28,7 @@ impl TryFrom<&FFI_ArrowSchema> for DataType {
 
     /// See [CDataInterface docs](https://arrow.apache.org/docs/format/CDataInterface.html#data-type-description-format-strings)
     fn try_from(c_schema: &FFI_ArrowSchema) -> Result<Self> {
-        let dtype = match c_schema.format() {
+        let mut dtype = match c_schema.format() {
             "n" => DataType::Null,
             "b" => DataType::Boolean,
             "c" => DataType::Int8,
@@ -134,6 +134,12 @@ impl TryFrom<&FFI_ArrowSchema> for DataType {
                 }
             }
         };
+
+        if let Some(dict_schema) = c_schema.dictionary() {
+            let value_type = Self::try_from(dict_schema)?;
+            dtype = DataType::Dictionary(Box::new(dtype), Box::new(value_type));
+        }
+
         Ok(dtype)
     }
 }
@@ -169,49 +175,7 @@ impl TryFrom<&DataType> for FFI_ArrowSchema {
 
     /// See [CDataInterface docs](https://arrow.apache.org/docs/format/CDataInterface.html#data-type-description-format-strings)
     fn try_from(dtype: &DataType) -> Result<Self> {
-        let format = match dtype {
-            DataType::Null => "n".to_string(),
-            DataType::Boolean => "b".to_string(),
-            DataType::Int8 => "c".to_string(),
-            DataType::UInt8 => "C".to_string(),
-            DataType::Int16 => "s".to_string(),
-            DataType::UInt16 => "S".to_string(),
-            DataType::Int32 => "i".to_string(),
-            DataType::UInt32 => "I".to_string(),
-            DataType::Int64 => "l".to_string(),
-            DataType::UInt64 => "L".to_string(),
-            DataType::Float16 => "e".to_string(),
-            DataType::Float32 => "f".to_string(),
-            DataType::Float64 => "g".to_string(),
-            DataType::Binary => "z".to_string(),
-            DataType::LargeBinary => "Z".to_string(),
-            DataType::Utf8 => "u".to_string(),
-            DataType::LargeUtf8 => "U".to_string(),
-            DataType::Decimal(precision, scale) => format!("d:{},{}", precision, scale),
-            DataType::Date32 => "tdD".to_string(),
-            DataType::Date64 => "tdm".to_string(),
-            DataType::Time32(TimeUnit::Second) => "tts".to_string(),
-            DataType::Time32(TimeUnit::Millisecond) => "ttm".to_string(),
-            DataType::Time64(TimeUnit::Microsecond) => "ttu".to_string(),
-            DataType::Time64(TimeUnit::Nanosecond) => "ttn".to_string(),
-            DataType::Timestamp(TimeUnit::Second, None) => "tss:".to_string(),
-            DataType::Timestamp(TimeUnit::Millisecond, None) => "tsm:".to_string(),
-            DataType::Timestamp(TimeUnit::Microsecond, None) => "tsu:".to_string(),
-            DataType::Timestamp(TimeUnit::Nanosecond, None) => "tsn:".to_string(),
-            DataType::Timestamp(TimeUnit::Second, Some(tz)) => format!("tss:{}", tz),
-            DataType::Timestamp(TimeUnit::Millisecond, Some(tz)) => format!("tsm:{}", tz),
-            DataType::Timestamp(TimeUnit::Microsecond, Some(tz)) => format!("tsu:{}", tz),
-            DataType::Timestamp(TimeUnit::Nanosecond, Some(tz)) => format!("tsn:{}", tz),
-            DataType::List(_) => "+l".to_string(),
-            DataType::LargeList(_) => "+L".to_string(),
-            DataType::Struct(_) => "+s".to_string(),
-            other => {
-                return Err(ArrowError::CDataInterface(format!(
-                    "The datatype \"{:?}\" is still not supported in Rust implementation",
-                    other
-                )))
-            }
-        };
+        let format = get_format_string(dtype)?;
         // allocate and hold the children
         let children = match dtype {
             DataType::List(child) | DataType::LargeList(child) => {
@@ -223,7 +187,57 @@ impl TryFrom<&DataType> for FFI_ArrowSchema {
                 .collect::<Result<Vec<_>>>()?,
             _ => vec![],
         };
-        FFI_ArrowSchema::try_new(&format, children)
+        let dictionary = if let DataType::Dictionary(_, value_data_type) = dtype {
+            Some(Self::try_from(value_data_type.as_ref())?)
+        } else {
+            None
+        };
+        FFI_ArrowSchema::try_new(&format, children, dictionary)
+    }
+}
+
+fn get_format_string(dtype: &DataType) -> Result<String> {
+    match dtype {
+        DataType::Null => Ok("n".to_string()),
+        DataType::Boolean => Ok("b".to_string()),
+        DataType::Int8 => Ok("c".to_string()),
+        DataType::UInt8 => Ok("C".to_string()),
+        DataType::Int16 => Ok("s".to_string()),
+        DataType::UInt16 => Ok("S".to_string()),
+        DataType::Int32 => Ok("i".to_string()),
+        DataType::UInt32 => Ok("I".to_string()),
+        DataType::Int64 => Ok("l".to_string()),
+        DataType::UInt64 => Ok("L".to_string()),
+        DataType::Float16 => Ok("e".to_string()),
+        DataType::Float32 => Ok("f".to_string()),
+        DataType::Float64 => Ok("g".to_string()),
+        DataType::Binary => Ok("z".to_string()),
+        DataType::LargeBinary => Ok("Z".to_string()),
+        DataType::Utf8 => Ok("u".to_string()),
+        DataType::LargeUtf8 => Ok("U".to_string()),
+        DataType::Decimal(precision, scale) => Ok(format!("d:{},{}", precision, scale)),
+        DataType::Date32 => Ok("tdD".to_string()),
+        DataType::Date64 => Ok("tdm".to_string()),
+        DataType::Time32(TimeUnit::Second) => Ok("tts".to_string()),
+        DataType::Time32(TimeUnit::Millisecond) => Ok("ttm".to_string()),
+        DataType::Time64(TimeUnit::Microsecond) => Ok("ttu".to_string()),
+        DataType::Time64(TimeUnit::Nanosecond) => Ok("ttn".to_string()),
+        DataType::Timestamp(TimeUnit::Second, None) => Ok("tss:".to_string()),
+        DataType::Timestamp(TimeUnit::Millisecond, None) => Ok("tsm:".to_string()),
+        DataType::Timestamp(TimeUnit::Microsecond, None) => Ok("tsu:".to_string()),
+        DataType::Timestamp(TimeUnit::Nanosecond, None) => Ok("tsn:".to_string()),
+        DataType::Timestamp(TimeUnit::Second, Some(tz)) => Ok(format!("tss:{}", tz)),
+        DataType::Timestamp(TimeUnit::Millisecond, Some(tz)) => Ok(format!("tsm:{}", tz)),
+        DataType::Timestamp(TimeUnit::Microsecond, Some(tz)) => Ok(format!("tsu:{}", tz)),
+        DataType::Timestamp(TimeUnit::Nanosecond, Some(tz)) => Ok(format!("tsn:{}", tz)),
+        DataType::List(_) => Ok("+l".to_string()),
+        DataType::LargeList(_) => Ok("+L".to_string()),
+        DataType::Struct(_) => Ok("+s".to_string()),
+        DataType::Dictionary(key_data_type, _) => get_format_string(key_data_type),
+        other => Err(ArrowError::CDataInterface(format!(
+            "The datatype \"{:?}\" is still not supported in Rust implementation",
+            other
+        ))),
     }
 }
 
diff --git a/arrow/src/ffi.rs b/arrow/src/ffi.rs
index b7a22de..461995b 100644
--- a/arrow/src/ffi.rs
+++ b/arrow/src/ffi.rs
@@ -122,6 +122,7 @@ pub struct FFI_ArrowSchema {
 
 struct SchemaPrivateData {
     children: Box<[*mut FFI_ArrowSchema]>,
+    dictionary: *mut FFI_ArrowSchema,
 }
 
 // callback used to drop [FFI_ArrowSchema] when it is exported.
@@ -141,6 +142,10 @@ unsafe extern "C" fn release_schema(schema: *mut FFI_ArrowSchema) {
         for child in private_data.children.iter() {
             drop(Box::from_raw(*child))
         }
+        if !private_data.dictionary.is_null() {
+            drop(Box::from_raw(private_data.dictionary));
+        }
+
         drop(private_data);
     }
 
@@ -150,7 +155,11 @@ unsafe extern "C" fn release_schema(schema: *mut FFI_ArrowSchema) {
 impl FFI_ArrowSchema {
     /// create a new [`FFI_ArrowSchema`]. This fails if the fields'
     /// [`DataType`] is not supported.
-    pub fn try_new(format: &str, children: Vec<FFI_ArrowSchema>) -> Result<Self> {
+    pub fn try_new(
+        format: &str,
+        children: Vec<FFI_ArrowSchema>,
+        dictionary: Option<FFI_ArrowSchema>,
+    ) -> Result<Self> {
         let mut this = Self::empty();
 
         let children_ptr = children
@@ -163,13 +172,20 @@ impl FFI_ArrowSchema {
         this.release = Some(release_schema);
         this.n_children = children_ptr.len() as i64;
 
+        let dictionary_ptr = dictionary
+            .map(|d| Box::into_raw(Box::new(d)))
+            .unwrap_or(std::ptr::null_mut());
+
         let mut private_data = Box::new(SchemaPrivateData {
             children: children_ptr,
+            dictionary: dictionary_ptr,
         });
 
         // intentionally set from private_data (see https://github.com/apache/arrow-rs/issues/580)
         this.children = private_data.children.as_mut_ptr();
 
+        this.dictionary = dictionary_ptr;
+
         this.private_data = Box::into_raw(private_data) as *mut c_void;
 
         Ok(this)
@@ -233,6 +249,10 @@ impl FFI_ArrowSchema {
     pub fn nullable(&self) -> bool {
         (self.flags / 2) & 1 == 1
     }
+
+    pub fn dictionary(&self) -> Option<&Self> {
+        unsafe { self.dictionary.as_ref() }
+    }
 }
 
 impl Drop for FFI_ArrowSchema {
@@ -356,6 +376,9 @@ unsafe extern "C" fn release_array(array: *mut FFI_ArrowArray) {
     for child in private.children.iter() {
         let _ = Box::from_raw(*child);
     }
+    if !private.dictionary.is_null() {
+        let _ = Box::from_raw(private.dictionary);
+    }
 
     array.release = None;
 }
@@ -365,6 +388,7 @@ struct ArrayPrivateData {
     buffers: Vec<Option<Buffer>>,
     buffers_ptr: Box<[*const c_void]>,
     children: Box<[*mut FFI_ArrowArray]>,
+    dictionary: *mut FFI_ArrowArray,
 }
 
 impl FFI_ArrowArray {
@@ -389,8 +413,16 @@ impl FFI_ArrowArray {
             })
             .collect::<Box<[_]>>();
 
-        let children = data
-            .child_data()
+        let empty = vec![];
+        let (child_data, dictionary) = match data.data_type() {
+            DataType::Dictionary(_, _) => (
+                empty.as_slice(),
+                Box::into_raw(Box::new(FFI_ArrowArray::new(&data.child_data()[0]))),
+            ),
+            _ => (data.child_data(), std::ptr::null_mut()),
+        };
+
+        let children = child_data
             .iter()
             .map(|child| Box::into_raw(Box::new(FFI_ArrowArray::new(child))))
             .collect::<Box<_>>();
@@ -402,6 +434,7 @@ impl FFI_ArrowArray {
             buffers,
             buffers_ptr,
             children,
+            dictionary,
         });
 
         Self {
@@ -412,7 +445,7 @@ impl FFI_ArrowArray {
             n_children,
             buffers: private_data.buffers_ptr.as_mut_ptr(),
             children: private_data.children.as_mut_ptr(),
-            dictionary: std::ptr::null_mut(),
+            dictionary,
             release: Some(release_array),
             private_data: Box::into_raw(private_data) as *mut c_void,
         }
@@ -508,7 +541,7 @@ pub trait ArrowArrayRef {
         let buffers = self.buffers()?;
         let null_bit_buffer = self.null_bit_buffer();
 
-        let child_data = (0..self.array().n_children as usize)
+        let mut child_data: Vec<ArrayData> = (0..self.array().n_children as usize)
             .map(|i| {
                 let child = self.child(i);
                 child.to_data()
@@ -516,6 +549,13 @@ pub trait ArrowArrayRef {
             .map(|d| d.unwrap())
             .collect();
 
+        if let Some(d) = self.dictionary() {
+            // For dictionary type there should only be a single child, so we don't need to worry if
+            // there are other children added above.
+            assert!(child_data.is_empty());
+            child_data.push(d.to_data()?);
+        }
+
         // Should FFI be checking validity?
         Ok(unsafe {
             ArrayData::new_unchecked(
@@ -555,10 +595,15 @@ pub trait ArrowArrayRef {
     // for variable-sized buffers, such as the second buffer of a stringArray, we need
     // to fetch offset buffer's len to build the second buffer.
     fn buffer_len(&self, i: usize) -> Result<usize> {
-        // Inner type is not important for buffer length.
-        let data_type = &self.data_type()?;
+        // Special handling for dictionary type as we only care about the key type in the case.
+        let t = self.data_type()?;
+        let data_type = match &t {
+            DataType::Dictionary(key_data_type, _) => key_data_type.as_ref(),
+            dt => dt,
+        };
 
-        Ok(match (data_type, i) {
+        // Inner type is not important for buffer length.
+        Ok(match (&data_type, i) {
             (DataType::Utf8, 1)
             | (DataType::LargeUtf8, 1)
             | (DataType::Binary, 1)
@@ -622,6 +667,21 @@ pub trait ArrowArrayRef {
     fn array(&self) -> &FFI_ArrowArray;
     fn schema(&self) -> &FFI_ArrowSchema;
     fn data_type(&self) -> Result<DataType>;
+    fn dictionary(&self) -> Option<ArrowArrayChild> {
+        unsafe {
+            assert!(!(self.array().dictionary.is_null() ^ self.schema().dictionary.is_null()),
+                    "Dictionary should both be set or not set in FFI_ArrowArray and FFI_ArrowSchema");
+            if !self.array().dictionary.is_null() {
+                Some(ArrowArrayChild::from_raw(
+                    &*self.array().dictionary,
+                    &*self.schema().dictionary,
+                    self.owner().clone(),
+                ))
+            } else {
+                None
+            }
+        }
+    }
 }
 
 #[allow(rustdoc::private_intra_doc_links)]
@@ -763,12 +823,12 @@ mod tests {
     use super::*;
     use crate::array::{
         make_array, Array, ArrayData, BinaryOffsetSizeTrait, BooleanArray, DecimalArray,
-        GenericBinaryArray, GenericListArray, GenericStringArray, Int32Array,
-        OffsetSizeTrait, StringOffsetSizeTrait, Time32MillisecondArray,
+        DictionaryArray, GenericBinaryArray, GenericListArray, GenericStringArray,
+        Int32Array, OffsetSizeTrait, StringOffsetSizeTrait, Time32MillisecondArray,
         TimestampMillisecondArray,
     };
     use crate::compute::kernels;
-    use crate::datatypes::Field;
+    use crate::datatypes::{Field, Int8Type};
     use std::convert::TryFrom;
 
     #[test]
@@ -1075,4 +1135,33 @@ mod tests {
         // (drop/release)
         Ok(())
     }
+
+    #[test]
+    fn test_dictionary() -> Result<()> {
+        // create an array natively
+        let values = vec!["a", "aaa", "aaa"];
+        let dict_array: DictionaryArray<Int8Type> = values.into_iter().collect();
+
+        // export it
+        let array = ArrowArray::try_from(dict_array.data().clone())?;
+
+        // (simulate consumer) import it
+        let data = ArrayData::try_from(array)?;
+        let array = make_array(data);
+
+        // perform some operation
+        let array = kernels::concat::concat(&[array.as_ref(), array.as_ref()]).unwrap();
+        let actual = array
+            .as_any()
+            .downcast_ref::<DictionaryArray<Int8Type>>()
+            .unwrap();
+
+        // verify
+        let new_values = vec!["a", "aaa", "aaa", "a", "aaa", "aaa"];
+        let expected: DictionaryArray<Int8Type> = new_values.into_iter().collect();
+        assert_eq!(actual, &expected);
+
+        // (drop/release)
+        Ok(())
+    }
 }

Reply via email to