This is an automated email from the ASF dual-hosted git repository.
sunchao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new f19d1ed Add dictionary support for C data interface (#1407)
f19d1ed is described below
commit f19d1ed71c2318407cf6beaab226355e0d005daa
Author: Chao Sun <[email protected]>
AuthorDate: Tue Mar 8 23:27:43 2022 -0800
Add dictionary support for C data interface (#1407)
* initial commit
* add integration tests for python
* address comments
---
.../tests/test_sql.py | 19 ++--
arrow/src/array/ffi.rs | 22 ++++
arrow/src/datatypes/ffi.rs | 104 ++++++++++---------
arrow/src/ffi.rs | 111 +++++++++++++++++++--
4 files changed, 192 insertions(+), 64 deletions(-)
diff --git a/arrow-pyarrow-integration-testing/tests/test_sql.py b/arrow-pyarrow-integration-testing/tests/test_sql.py
index bacd118..058a32e 100644
--- a/arrow-pyarrow-integration-testing/tests/test_sql.py
+++ b/arrow-pyarrow-integration-testing/tests/test_sql.py
@@ -79,6 +79,7 @@ _supported_pyarrow_types = [
pa.field("c", pa.string()),
]
),
+ pa.dictionary(pa.int8(), pa.string()),
]
_unsupported_pyarrow_types = [
@@ -122,14 +123,6 @@ def test_type_roundtrip_raises(pyarrow_type):
with pytest.raises(pa.ArrowException):
rust.round_trip_type(pyarrow_type)
-
-def test_dictionary_type_roundtrip():
- # the dictionary type conversion is incomplete
- pyarrow_type = pa.dictionary(pa.int32(), pa.string())
- ty = rust.round_trip_type(pyarrow_type)
- assert ty == pa.int32()
-
-
@pytest.mark.parametrize('pyarrow_type', _supported_pyarrow_types, ids=str)
def test_field_roundtrip(pyarrow_type):
pyarrow_field = pa.field("test", pyarrow_type, nullable=True)
@@ -263,3 +256,13 @@ def test_decimal_python():
assert a == b
del a
del b
+
+def test_dictionary_python():
+ """
+ Python -> Rust -> Python
+ """
+ a = pa.array(["a", None, "b", None, "a"], type=pa.dictionary(pa.int8(), pa.string()))
+ b = rust.round_trip_array(a)
+ assert a == b
+ del a
+ del b
diff --git a/arrow/src/array/ffi.rs b/arrow/src/array/ffi.rs
index 847649c..976c6b8 100644
--- a/arrow/src/array/ffi.rs
+++ b/arrow/src/array/ffi.rs
@@ -45,6 +45,7 @@ impl TryFrom<ArrayData> for ffi::ArrowArray {
#[cfg(test)]
mod tests {
+ use crate::array::{DictionaryArray, Int32Array, StringArray};
use crate::error::Result;
use crate::{
array::{
@@ -127,4 +128,25 @@ mod tests {
let data = array.data();
test_round_trip(data)
}
+
+ #[test]
+ fn test_dictionary() -> Result<()> {
+ let values = StringArray::from(vec![Some("foo"), Some("bar"), None]);
+ let keys = Int32Array::from(vec![
+ Some(0),
+ Some(1),
+ None,
+ Some(1),
+ Some(1),
+ None,
+ Some(1),
+ Some(2),
+ Some(1),
+ None,
+ ]);
+ let array = DictionaryArray::try_new(&keys, &values)?;
+
+ let data = array.data();
+ test_round_trip(data)
+ }
}
diff --git a/arrow/src/datatypes/ffi.rs b/arrow/src/datatypes/ffi.rs
index fbff4a0..10645fb 100644
--- a/arrow/src/datatypes/ffi.rs
+++ b/arrow/src/datatypes/ffi.rs
@@ -28,7 +28,7 @@ impl TryFrom<&FFI_ArrowSchema> for DataType {
/// See [CDataInterface docs](https://arrow.apache.org/docs/format/CDataInterface.html#data-type-description-format-strings)
fn try_from(c_schema: &FFI_ArrowSchema) -> Result<Self> {
- let dtype = match c_schema.format() {
+ let mut dtype = match c_schema.format() {
"n" => DataType::Null,
"b" => DataType::Boolean,
"c" => DataType::Int8,
@@ -134,6 +134,12 @@ impl TryFrom<&FFI_ArrowSchema> for DataType {
}
}
};
+
+ if let Some(dict_schema) = c_schema.dictionary() {
+ let value_type = Self::try_from(dict_schema)?;
+ dtype = DataType::Dictionary(Box::new(dtype), Box::new(value_type));
+ }
+
Ok(dtype)
}
}
@@ -169,49 +175,7 @@ impl TryFrom<&DataType> for FFI_ArrowSchema {
/// See [CDataInterface docs](https://arrow.apache.org/docs/format/CDataInterface.html#data-type-description-format-strings)
fn try_from(dtype: &DataType) -> Result<Self> {
- let format = match dtype {
- DataType::Null => "n".to_string(),
- DataType::Boolean => "b".to_string(),
- DataType::Int8 => "c".to_string(),
- DataType::UInt8 => "C".to_string(),
- DataType::Int16 => "s".to_string(),
- DataType::UInt16 => "S".to_string(),
- DataType::Int32 => "i".to_string(),
- DataType::UInt32 => "I".to_string(),
- DataType::Int64 => "l".to_string(),
- DataType::UInt64 => "L".to_string(),
- DataType::Float16 => "e".to_string(),
- DataType::Float32 => "f".to_string(),
- DataType::Float64 => "g".to_string(),
- DataType::Binary => "z".to_string(),
- DataType::LargeBinary => "Z".to_string(),
- DataType::Utf8 => "u".to_string(),
- DataType::LargeUtf8 => "U".to_string(),
- DataType::Decimal(precision, scale) => format!("d:{},{}", precision, scale),
- DataType::Date32 => "tdD".to_string(),
- DataType::Date64 => "tdm".to_string(),
- DataType::Time32(TimeUnit::Second) => "tts".to_string(),
- DataType::Time32(TimeUnit::Millisecond) => "ttm".to_string(),
- DataType::Time64(TimeUnit::Microsecond) => "ttu".to_string(),
- DataType::Time64(TimeUnit::Nanosecond) => "ttn".to_string(),
- DataType::Timestamp(TimeUnit::Second, None) => "tss:".to_string(),
- DataType::Timestamp(TimeUnit::Millisecond, None) => "tsm:".to_string(),
- DataType::Timestamp(TimeUnit::Microsecond, None) => "tsu:".to_string(),
- DataType::Timestamp(TimeUnit::Nanosecond, None) => "tsn:".to_string(),
- DataType::Timestamp(TimeUnit::Second, Some(tz)) => format!("tss:{}", tz),
- DataType::Timestamp(TimeUnit::Millisecond, Some(tz)) => format!("tsm:{}", tz),
- DataType::Timestamp(TimeUnit::Microsecond, Some(tz)) => format!("tsu:{}", tz),
- DataType::Timestamp(TimeUnit::Nanosecond, Some(tz)) => format!("tsn:{}", tz),
- DataType::List(_) => "+l".to_string(),
- DataType::LargeList(_) => "+L".to_string(),
- DataType::Struct(_) => "+s".to_string(),
- other => {
- return Err(ArrowError::CDataInterface(format!(
- "The datatype \"{:?}\" is still not supported in Rust implementation",
- other
- )))
- }
- };
+ let format = get_format_string(dtype)?;
// allocate and hold the children
let children = match dtype {
DataType::List(child) | DataType::LargeList(child) => {
@@ -223,7 +187,57 @@ impl TryFrom<&DataType> for FFI_ArrowSchema {
.collect::<Result<Vec<_>>>()?,
_ => vec![],
};
- FFI_ArrowSchema::try_new(&format, children)
+ let dictionary = if let DataType::Dictionary(_, value_data_type) = dtype {
+ Some(Self::try_from(value_data_type.as_ref())?)
+ } else {
+ None
+ };
+ FFI_ArrowSchema::try_new(&format, children, dictionary)
+ }
+}
+
+fn get_format_string(dtype: &DataType) -> Result<String> {
+ match dtype {
+ DataType::Null => Ok("n".to_string()),
+ DataType::Boolean => Ok("b".to_string()),
+ DataType::Int8 => Ok("c".to_string()),
+ DataType::UInt8 => Ok("C".to_string()),
+ DataType::Int16 => Ok("s".to_string()),
+ DataType::UInt16 => Ok("S".to_string()),
+ DataType::Int32 => Ok("i".to_string()),
+ DataType::UInt32 => Ok("I".to_string()),
+ DataType::Int64 => Ok("l".to_string()),
+ DataType::UInt64 => Ok("L".to_string()),
+ DataType::Float16 => Ok("e".to_string()),
+ DataType::Float32 => Ok("f".to_string()),
+ DataType::Float64 => Ok("g".to_string()),
+ DataType::Binary => Ok("z".to_string()),
+ DataType::LargeBinary => Ok("Z".to_string()),
+ DataType::Utf8 => Ok("u".to_string()),
+ DataType::LargeUtf8 => Ok("U".to_string()),
+ DataType::Decimal(precision, scale) => Ok(format!("d:{},{}", precision, scale)),
+ DataType::Date32 => Ok("tdD".to_string()),
+ DataType::Date64 => Ok("tdm".to_string()),
+ DataType::Time32(TimeUnit::Second) => Ok("tts".to_string()),
+ DataType::Time32(TimeUnit::Millisecond) => Ok("ttm".to_string()),
+ DataType::Time64(TimeUnit::Microsecond) => Ok("ttu".to_string()),
+ DataType::Time64(TimeUnit::Nanosecond) => Ok("ttn".to_string()),
+ DataType::Timestamp(TimeUnit::Second, None) => Ok("tss:".to_string()),
+ DataType::Timestamp(TimeUnit::Millisecond, None) => Ok("tsm:".to_string()),
+ DataType::Timestamp(TimeUnit::Microsecond, None) => Ok("tsu:".to_string()),
+ DataType::Timestamp(TimeUnit::Nanosecond, None) => Ok("tsn:".to_string()),
+ DataType::Timestamp(TimeUnit::Second, Some(tz)) => Ok(format!("tss:{}", tz)),
+ DataType::Timestamp(TimeUnit::Millisecond, Some(tz)) => Ok(format!("tsm:{}", tz)),
+ DataType::Timestamp(TimeUnit::Microsecond, Some(tz)) => Ok(format!("tsu:{}", tz)),
+ DataType::Timestamp(TimeUnit::Nanosecond, Some(tz)) => Ok(format!("tsn:{}", tz)),
+ DataType::List(_) => Ok("+l".to_string()),
+ DataType::LargeList(_) => Ok("+L".to_string()),
+ DataType::Struct(_) => Ok("+s".to_string()),
+ DataType::Dictionary(key_data_type, _) => get_format_string(key_data_type),
+ other => Err(ArrowError::CDataInterface(format!(
+ "The datatype \"{:?}\" is still not supported in Rust implementation",
+ other
+ ))),
}
}
diff --git a/arrow/src/ffi.rs b/arrow/src/ffi.rs
index b7a22de..461995b 100644
--- a/arrow/src/ffi.rs
+++ b/arrow/src/ffi.rs
@@ -122,6 +122,7 @@ pub struct FFI_ArrowSchema {
struct SchemaPrivateData {
children: Box<[*mut FFI_ArrowSchema]>,
+ dictionary: *mut FFI_ArrowSchema,
}
// callback used to drop [FFI_ArrowSchema] when it is exported.
@@ -141,6 +142,10 @@ unsafe extern "C" fn release_schema(schema: *mut FFI_ArrowSchema) {
for child in private_data.children.iter() {
drop(Box::from_raw(*child))
}
+ if !private_data.dictionary.is_null() {
+ drop(Box::from_raw(private_data.dictionary));
+ }
+
drop(private_data);
}
@@ -150,7 +155,11 @@ unsafe extern "C" fn release_schema(schema: *mut FFI_ArrowSchema) {
impl FFI_ArrowSchema {
/// create a new [`FFI_ArrowSchema`]. This fails if the fields'
/// [`DataType`] is not supported.
- pub fn try_new(format: &str, children: Vec<FFI_ArrowSchema>) -> Result<Self> {
+ pub fn try_new(
+ format: &str,
+ children: Vec<FFI_ArrowSchema>,
+ dictionary: Option<FFI_ArrowSchema>,
+ ) -> Result<Self> {
let mut this = Self::empty();
let children_ptr = children
@@ -163,13 +172,20 @@ impl FFI_ArrowSchema {
this.release = Some(release_schema);
this.n_children = children_ptr.len() as i64;
+ let dictionary_ptr = dictionary
+ .map(|d| Box::into_raw(Box::new(d)))
+ .unwrap_or(std::ptr::null_mut());
+
let mut private_data = Box::new(SchemaPrivateData {
children: children_ptr,
+ dictionary: dictionary_ptr,
});
// intentionally set from private_data (see https://github.com/apache/arrow-rs/issues/580)
this.children = private_data.children.as_mut_ptr();
+ this.dictionary = dictionary_ptr;
+
this.private_data = Box::into_raw(private_data) as *mut c_void;
Ok(this)
@@ -233,6 +249,10 @@ impl FFI_ArrowSchema {
pub fn nullable(&self) -> bool {
(self.flags / 2) & 1 == 1
}
+
+ pub fn dictionary(&self) -> Option<&Self> {
+ unsafe { self.dictionary.as_ref() }
+ }
}
impl Drop for FFI_ArrowSchema {
@@ -356,6 +376,9 @@ unsafe extern "C" fn release_array(array: *mut FFI_ArrowArray) {
for child in private.children.iter() {
let _ = Box::from_raw(*child);
}
+ if !private.dictionary.is_null() {
+ let _ = Box::from_raw(private.dictionary);
+ }
array.release = None;
}
@@ -365,6 +388,7 @@ struct ArrayPrivateData {
buffers: Vec<Option<Buffer>>,
buffers_ptr: Box<[*const c_void]>,
children: Box<[*mut FFI_ArrowArray]>,
+ dictionary: *mut FFI_ArrowArray,
}
impl FFI_ArrowArray {
@@ -389,8 +413,16 @@ impl FFI_ArrowArray {
})
.collect::<Box<[_]>>();
- let children = data
- .child_data()
+ let empty = vec![];
+ let (child_data, dictionary) = match data.data_type() {
+ DataType::Dictionary(_, _) => (
+ empty.as_slice(),
+ Box::into_raw(Box::new(FFI_ArrowArray::new(&data.child_data()[0]))),
+ ),
+ _ => (data.child_data(), std::ptr::null_mut()),
+ };
+
+ let children = child_data
.iter()
.map(|child| Box::into_raw(Box::new(FFI_ArrowArray::new(child))))
.collect::<Box<_>>();
@@ -402,6 +434,7 @@ impl FFI_ArrowArray {
buffers,
buffers_ptr,
children,
+ dictionary,
});
Self {
@@ -412,7 +445,7 @@ impl FFI_ArrowArray {
n_children,
buffers: private_data.buffers_ptr.as_mut_ptr(),
children: private_data.children.as_mut_ptr(),
- dictionary: std::ptr::null_mut(),
+ dictionary,
release: Some(release_array),
private_data: Box::into_raw(private_data) as *mut c_void,
}
@@ -508,7 +541,7 @@ pub trait ArrowArrayRef {
let buffers = self.buffers()?;
let null_bit_buffer = self.null_bit_buffer();
- let child_data = (0..self.array().n_children as usize)
+ let mut child_data: Vec<ArrayData> = (0..self.array().n_children as usize)
.map(|i| {
let child = self.child(i);
child.to_data()
@@ -516,6 +549,13 @@ pub trait ArrowArrayRef {
.map(|d| d.unwrap())
.collect();
+ if let Some(d) = self.dictionary() {
+ // For dictionary type there should only be a single child, so we don't need to worry if
+ // there are other children added above.
+ assert!(child_data.is_empty());
+ child_data.push(d.to_data()?);
+ }
+
// Should FFI be checking validity?
Ok(unsafe {
ArrayData::new_unchecked(
@@ -555,10 +595,15 @@ pub trait ArrowArrayRef {
// for variable-sized buffers, such as the second buffer of a stringArray, we need
// to fetch offset buffer's len to build the second buffer.
fn buffer_len(&self, i: usize) -> Result<usize> {
- // Inner type is not important for buffer length.
- let data_type = &self.data_type()?;
+ // Special handling for dictionary type as we only care about the key type in the case.
+ let t = self.data_type()?;
+ let data_type = match &t {
+ DataType::Dictionary(key_data_type, _) => key_data_type.as_ref(),
+ dt => dt,
+ };
- Ok(match (data_type, i) {
+ // Inner type is not important for buffer length.
+ Ok(match (&data_type, i) {
(DataType::Utf8, 1)
| (DataType::LargeUtf8, 1)
| (DataType::Binary, 1)
@@ -622,6 +667,21 @@ pub trait ArrowArrayRef {
fn array(&self) -> &FFI_ArrowArray;
fn schema(&self) -> &FFI_ArrowSchema;
fn data_type(&self) -> Result<DataType>;
+ fn dictionary(&self) -> Option<ArrowArrayChild> {
+ unsafe {
+ assert!(!(self.array().dictionary.is_null() ^ self.schema().dictionary.is_null()),
+ "Dictionary should both be set or not set in FFI_ArrowArray and FFI_ArrowSchema");
+ if !self.array().dictionary.is_null() {
+ Some(ArrowArrayChild::from_raw(
+ &*self.array().dictionary,
+ &*self.schema().dictionary,
+ self.owner().clone(),
+ ))
+ } else {
+ None
+ }
+ }
+ }
}
#[allow(rustdoc::private_intra_doc_links)]
@@ -763,12 +823,12 @@ mod tests {
use super::*;
use crate::array::{
make_array, Array, ArrayData, BinaryOffsetSizeTrait, BooleanArray, DecimalArray,
- GenericBinaryArray, GenericListArray, GenericStringArray, Int32Array,
- OffsetSizeTrait, StringOffsetSizeTrait, Time32MillisecondArray,
+ DictionaryArray, GenericBinaryArray, GenericListArray, GenericStringArray,
+ Int32Array, OffsetSizeTrait, StringOffsetSizeTrait, Time32MillisecondArray,
TimestampMillisecondArray,
};
use crate::compute::kernels;
- use crate::datatypes::Field;
+ use crate::datatypes::{Field, Int8Type};
use std::convert::TryFrom;
#[test]
@@ -1075,4 +1135,33 @@ mod tests {
// (drop/release)
Ok(())
}
+
+ #[test]
+ fn test_dictionary() -> Result<()> {
+ // create an array natively
+ let values = vec!["a", "aaa", "aaa"];
+ let dict_array: DictionaryArray<Int8Type> = values.into_iter().collect();
+
+ // export it
+ let array = ArrowArray::try_from(dict_array.data().clone())?;
+
+ // (simulate consumer) import it
+ let data = ArrayData::try_from(array)?;
+ let array = make_array(data);
+
+ // perform some operation
+ let array = kernels::concat::concat(&[array.as_ref(), array.as_ref()]).unwrap();
+ let actual = array
+ .as_any()
+ .downcast_ref::<DictionaryArray<Int8Type>>()
+ .unwrap();
+
+ // verify
+ let new_values = vec!["a", "aaa", "aaa", "a", "aaa", "aaa"];
+ let expected: DictionaryArray<Int8Type> = new_values.into_iter().collect();
+ assert_eq!(actual, &expected);
+
+ // (drop/release)
+ Ok(())
+ }
}