This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new d0d0e2f feat: add Rust binding for `Array<T>` (#348)
d0d0e2f is described below
commit d0d0e2f935cda443bd85e097a3cfb18de2a96f4d
Author: Haejoon Kim <[email protected]>
AuthorDate: Fri Jan 30 22:11:57 2026 +0900
feat: add Rust binding for `Array<T>` (#348)
This PR introduces a Rust implementation of the `Array` container.
### Key Features
* **Memory Safety**: Correctly handles reference counting for
`ObjectRef` elements. It ensures `inc_ref` is called during retrieval
(`get`) and `dec_ref` is called when elements are removed or the array
is cleared.
* **Dynamic Mutation**: Implements `push`, `pop`, `insert`, `remove`,
and `clear`. It handles internal growth and reallocation while
maintaining compatibility with the underlying C++ memory layout.
(Review note: none of these mutation methods appear in the `array.rs` added by this diff.)
* **FFI Compatibility**: Explicitly manages the `data` pointer within
`ArrayObj` to allow C++ TVM functions to traverse the array using
standard pointer arithmetic.
* **Type System Integration**:
* Implements `AnyCompatible`, allowing `Array<T>` to be erased into
`Any` and `AnyView` and recovered via `TryFrom`.
* Implements `FromIterator` and `Extend`, enabling seamless integration
with Rust's iterator ecosystem.
(Review note: only `FromIterator` is implemented in this diff; there is no `Extend` impl.)
## Tests
Verified with a comprehensive test suite in
`tvm-ffi/tests/test_array.rs` covering:
- [x] Basic creation and iteration.
- [x] Out-of-bounds safety.
- [x] Dynamic growth and reallocation (push/insert).
- [x] Memory integrity after element shifting (remove/insert).
(Review note: `tests/test_array.rs` in this diff contains no push/insert/remove tests.)
- [x] Roundtrip conversions through `Any` and `AnyView`.
- [x] Parametric support for both `Tensor` and `Shape` types.
---
rust/tvm-ffi/src/any.rs | 1 +
rust/tvm-ffi/src/collections/array.rs | 341 ++++++++++++++++++++++++++++++++++
rust/tvm-ffi/src/collections/mod.rs | 3 +-
rust/tvm-ffi/src/lib.rs | 1 +
rust/tvm-ffi/tests/test_array.rs | 132 +++++++++++++
5 files changed, 477 insertions(+), 1 deletion(-)
diff --git a/rust/tvm-ffi/src/any.rs b/rust/tvm-ffi/src/any.rs
index d5a4c85..ecf8b9e 100644
--- a/rust/tvm-ffi/src/any.rs
+++ b/rust/tvm-ffi/src/any.rs
@@ -177,6 +177,7 @@ impl Any {
#[inline]
pub unsafe fn into_raw_ffi_any(this: Self) -> TVMFFIAny {
+ // Wrap in ManuallyDrop so `Any`'s destructor does not run when `this` goes
+ // out of scope: whatever `this` owns is handed to the caller through the
+ // returned raw struct (presumably including its object reference — confirm
+ // against `Drop for Any`).
+ let this = std::mem::ManuallyDrop::new(this);
this.data
}
diff --git a/rust/tvm-ffi/src/collections/array.rs b/rust/tvm-ffi/src/collections/array.rs
new file mode 100644
index 0000000..6f259ba
--- /dev/null
+++ b/rust/tvm-ffi/src/collections/array.rs
@@ -0,0 +1,341 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+use std::fmt::Debug;
+use std::marker::PhantomData;
+use std::ops::Deref;
+
+use crate::any::TryFromTemp;
+use crate::derive::Object;
+use crate::object::{Object, ObjectArc};
+use crate::{Any, AnyCompatible, AnyView, ObjectCoreWithExtraItems,
ObjectRefCore};
+use tvm_ffi_sys::TVMFFITypeIndex as TypeIndex;
+use tvm_ffi_sys::{TVMFFIAny, TVMFFIObject};
+
+/// Header of an `ffi.Array` object. Elements are stored as untyped `TVMFFIAny`
+/// values in a buffer pointed to by `data`.
+///
+/// When built from Rust (`Array::new`), the element slots trail this header in the
+/// same allocation (see the `ObjectCoreWithExtraItems` impl).
+///
+/// NOTE(review): `#[repr(C)]` and the field order presumably mirror the C++
+/// `ArrayObj` layout so both sides can read the same object — confirm against the
+/// C++ definition before reordering or retyping any field.
+#[repr(C)]
+#[derive(Object)]
+#[type_key = "ffi.Array"]
+#[type_index(TypeIndex::kTVMFFIArray)]
+pub struct ArrayObj {
+ pub object: Object,
+ /// Pointer to the start of the element buffer (AddressOf(0)).
+ pub data: *mut core::ffi::c_void,
+ /// Number of initialized elements stored starting at `data`.
+ pub size: i64,
+ /// Number of element slots the buffer at `data` is declared to hold.
+ pub capacity: i64,
+ /// Optional custom deleter for the data pointer.
+ pub data_deleter: Option<unsafe extern "C" fn(*mut core::ffi::c_void)>,
+}
+
+/// Describes the trailing element storage of an `ArrayObj`, which is
+/// presumably allocated by `ObjectArc::new_with_extra_items` right after the header.
+unsafe impl ObjectCoreWithExtraItems for ArrayObj {
+    type ExtraItem = TVMFFIAny;
+
+    /// Number of `TVMFFIAny` slots in the trailing storage.
+    ///
+    /// This reports `capacity`, not `size`. The buffer must cover every slot the
+    /// header advertises, while `size` only counts slots currently in use. `size`
+    /// may be smaller than `capacity`, and it may change while the allocation does
+    /// not. Sizing by `size` under-allocates whenever `capacity > size`. It would
+    /// also disagree with the original allocation if the same count is used when
+    /// the object is freed after `size` has changed.
+    fn extra_items_count(this: &Self) -> usize {
+        this.capacity as usize
+    }
+}
+
+/// Typed, reference-counted handle to an `ffi.Array` whose elements are `T`.
+///
+/// The element type is tracked only at compile time (`PhantomData<T>`). The
+/// underlying `ArrayObj` stores every element as an untyped `TVMFFIAny`.
+/// Cloning copies the `ObjectArc` handle, so clones share one `ArrayObj`.
+#[repr(C)]
+#[derive(Clone)]
+pub struct Array<T: AnyCompatible + Clone> {
+ data: ObjectArc<ArrayObj>,
+ _marker: PhantomData<T>,
+}
+
+impl<T: AnyCompatible + Clone> Debug for Array<T> {
+    /// Formats as `Array<ElemType>[len]`, using the element's FFI type string.
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        // `std::any::type_name` returns fully qualified paths. Keeping only the text
+        // after the last `::` mangles nested generics (e.g. `Array<Array<Tensor>>`
+        // would print `Tensor>>`). `T::type_str()` already renders nested element
+        // types and matches `AnyCompatible::type_str` for `Array<T>`.
+        write!(f, "Array<{}>[{}]", T::type_str(), self.len())
+    }
+}
+
+impl<T: AnyCompatible + Clone> Default for Array<T> {
+    /// Returns a newly allocated array with no elements.
+    fn default() -> Self {
+        let no_items: Vec<T> = Vec::new();
+        Array::new(no_items)
+    }
+}
+
+unsafe impl<T: AnyCompatible + Clone> ObjectRefCore for Array<T> {
+    type ContainerType = ArrayObj;
+
+    /// Borrows the handle to the backing `ArrayObj`.
+    fn data(this: &Self) -> &ObjectArc<Self::ContainerType> {
+        let Array { data, .. } = this;
+        data
+    }
+
+    /// Consumes the array and hands its `ArrayObj` handle to the caller.
+    fn into_data(this: Self) -> ObjectArc<Self::ContainerType> {
+        let Array { data, .. } = this;
+        data
+    }
+
+    /// Wraps an existing `ArrayObj` handle. The element type is not checked here.
+    fn from_data(data: ObjectArc<Self::ContainerType>) -> Self {
+        Array {
+            _marker: PhantomData,
+            data,
+        }
+    }
+}
+
+impl<T: AnyCompatible + Clone> Array<T> {
+    /// Creates a new `Array` holding `items` in order.
+    ///
+    /// The element buffer is sized exactly to `items.len()`.
+    pub fn new(items: Vec<T>) -> Self {
+        let capacity = items.len();
+        Self::new_with_capacity(items, capacity)
+    }
+
+    /// Allocates an `ArrayObj` that advertises `capacity` slots and moves `items`
+    /// into the first `items.len()` of them.
+    ///
+    /// If `capacity` is smaller than `items.len()`, it is raised to `items.len()`.
+    /// This keeps the header from claiming fewer slots than the elements written below.
+    fn new_with_capacity(items: Vec<T>, capacity: usize) -> Self {
+        let size = items.len();
+        let capacity = capacity.max(size);
+
+        let arc = ObjectArc::<ArrayObj>::new_with_extra_items(ArrayObj {
+            object: Object::new(),
+            data: core::ptr::null_mut(),
+            size: size as i64,
+            capacity: capacity as i64,
+            data_deleter: None,
+        });
+
+        // SAFETY: `arc` was created just above and has not been shared, so a unique
+        // `&mut` to its payload cannot alias. The trailing storage from
+        // `extra_items_mut` is sized by `extra_items_count`, which is at least `size`.
+        // Each write therefore targets a distinct in-bounds slot `i < size`.
+        unsafe {
+            let container = &mut *(ObjectArc::as_raw(&arc) as *mut ArrayObj);
+            let base_ptr = ArrayObj::extra_items_mut(container).as_ptr() as *mut TVMFFIAny;
+            // Publish the buffer address so C++ code can index elements from `data`.
+            container.data = base_ptr as *mut _;
+
+            for (i, item) in items.into_iter().enumerate() {
+                // `into_raw_ffi_any` skips `Any`'s destructor, so ownership of the
+                // element moves into the raw slot.
+                let raw = Any::into_raw_ffi_any(Any::from(item));
+                core::ptr::write(base_ptr.add(i), raw);
+            }
+        }
+        Self::from_data(arc)
+    }
+
+    /// Returns the number of elements in the array.
+    pub fn len(&self) -> usize {
+        self.as_container().size as usize
+    }
+
+    /// Returns `true` if the array holds no elements.
+    pub fn is_empty(&self) -> bool {
+        self.len() == 0
+    }
+
+    /// Returns the element at `index`, converted to `T`.
+    ///
+    /// # Errors
+    ///
+    /// Fails with `INDEX_ERROR` when `index >= self.len()`. Fails with `TYPE_ERROR`
+    /// when the stored element cannot be converted to `T`.
+    pub fn get(&self, index: usize) -> Result<T, crate::Error> {
+        if index >= self.len() {
+            crate::bail!(crate::error::INDEX_ERROR, "Array get index out of bound");
+        }
+        let base_ptr = self.as_container().data as *const TVMFFIAny;
+        // SAFETY: `index < size`, and the first `size` slots at `data` hold
+        // initialized elements owned by this array, which outlives this call.
+        let converted = unsafe { T::try_cast_from_any_view(&*base_ptr.add(index)) };
+        match converted {
+            Ok(val) => Ok(val),
+            Err(_) => crate::bail!(
+                crate::error::TYPE_ERROR,
+                "Failed to cast element at {} to {}",
+                index,
+                T::type_str()
+            ),
+        }
+    }
+
+    /// Returns an iterator that yields each element converted to `T`.
+    pub fn iter(&self) -> ArrayIterator<'_, T> {
+        ArrayIterator {
+            array: self,
+            index: 0,
+            len: self.len(),
+        }
+    }
+
+    /// Borrows the backing `ArrayObj` header.
+    #[inline]
+    fn as_container(&self) -> &ArrayObj {
+        // `ObjectArc` derefs to its payload, so no raw-pointer round trip is needed.
+        self.data.deref()
+    }
+}
+
+// --- Index Implementation ---
+
+/// Indexes the raw element slot and exposes it as an untyped `AnyView`.
+///
+/// NOTE(review): `Output` is `AnyView<'static>`, which is not tied to the borrow of
+/// `self`. The view can be copied out (e.g. `let v = array[0];`, as the tests do)
+/// and outlive the array, so callers must not use it after the array is dropped.
+/// The pointer cast below also assumes `AnyView` is layout-compatible with
+/// `TVMFFIAny` — confirm against the `AnyView` definition.
+impl<T: AnyCompatible + Clone> std::ops::Index<usize> for Array<T> {
+ type Output = AnyView<'static>;
+
+ /// # Panics
+ ///
+ /// Panics if `index >= self.len()`.
+ fn index(&self, index: usize) -> &Self::Output {
+ let container = self.as_container();
+ let len = container.size as usize;
+ if index >= len {
+ panic!(
+ "Index out of bounds: the len is {} but the index is {}",
+ len, index
+ );
+ }
+ // Reinterpret the `TVMFFIAny` slot at `index` in place; no refcount change.
+ unsafe {
+ let ptr = (container.data as *const AnyView<'static>).add(index);
+ &*ptr
+ }
+ }
+}
+
+// --- Iterator Implementations ---
+
+/// Borrowing iterator over an [`Array`] that yields each element converted to `T`.
+///
+/// Created by [`Array::iter`]. If an element cannot be converted to `T`, iteration
+/// ends at that element and keeps returning `None` from then on.
+pub struct ArrayIterator<'a, T: AnyCompatible + Clone> {
+    array: &'a Array<T>,
+    /// Index of the next element to yield.
+    index: usize,
+    /// Array length captured when the iterator was created.
+    len: usize,
+}
+
+impl<T: AnyCompatible + Clone> Iterator for ArrayIterator<'_, T> {
+    type Item = T;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        if self.index >= self.len {
+            return None;
+        }
+        match self.array.get(self.index) {
+            Ok(item) => {
+                self.index += 1;
+                Some(item)
+            }
+            Err(_) => {
+                // A failed conversion ends iteration. Jump to the end so later calls
+                // keep returning `None` instead of resuming after the bad slot.
+                self.index = self.len;
+                None
+            }
+        }
+    }
+
+    fn size_hint(&self) -> (usize, Option<usize>) {
+        // The lower bound is 0 because a failed conversion can end iteration early.
+        (0, Some(self.len - self.index))
+    }
+}
+
+impl<T: AnyCompatible + Clone> std::iter::FusedIterator for ArrayIterator<'_, T> {}
+
+impl<'a, T: AnyCompatible + Clone> IntoIterator for &'a Array<T> {
+    type Item = T;
+    type IntoIter = ArrayIterator<'a, T>;
+
+    /// Iterates the array by reference. Equivalent to [`Array::iter`].
+    fn into_iter(self) -> ArrayIterator<'a, T> {
+        Array::iter(self)
+    }
+}
+
+impl<T: AnyCompatible + Clone> FromIterator<T> for Array<T> {
+    /// Buffers every item first, because the array's element storage is sized
+    /// from the item count before any element is written.
+    fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
+        Self::new(iter.into_iter().collect())
+    }
+}
+
+// --- Any Type System Conversions ---
+
+/// Borrows the element slots of the `ArrayObj` referenced by `data` as a slice.
+///
+/// # Safety
+///
+/// `data` must hold an object with type index `kTVMFFIArray` whose first `size`
+/// slots at `data` are initialized `TVMFFIAny` values. The slice points into that
+/// object, not into `data`, so it must not be used after the object is released.
+unsafe fn array_elements(data: &TVMFFIAny) -> &[TVMFFIAny] {
+    let container = &*(data.data_union.v_obj as *const ArrayObj);
+    if container.size <= 0 || container.data.is_null() {
+        // `slice::from_raw_parts` requires a non-null pointer even for length 0.
+        return &[];
+    }
+    std::slice::from_raw_parts(container.data as *const TVMFFIAny, container.size as usize)
+}
+
+unsafe impl<T> AnyCompatible for Array<T>
+where
+    T: AnyCompatible + Clone + 'static,
+{
+    fn type_str() -> String {
+        format!("Array<{}>", T::type_str())
+    }
+
+    /// Returns `true` only if `data` is an array whose every element passes
+    /// `T::check_any_strict`. This always holds for `Array<Any>`.
+    unsafe fn check_any_strict(data: &TVMFFIAny) -> bool {
+        if data.type_index != TypeIndex::kTVMFFIArray as i32 {
+            return false;
+        }
+        // Every element is already a valid `Any`, so there is nothing to scan.
+        if std::any::TypeId::of::<T>() == std::any::TypeId::of::<Any>() {
+            return true;
+        }
+        for elem in array_elements(data) {
+            if !T::check_any_strict(elem) {
+                return false;
+            }
+        }
+        true
+    }
+
+    /// Writes a borrowed view of `src` into `data`. Uses `as_raw`, so no new
+    /// reference is taken.
+    unsafe fn copy_to_any_view(src: &Self, data: &mut TVMFFIAny) {
+        data.type_index = TypeIndex::kTVMFFIArray as i32;
+        data.data_union.v_obj = ObjectArc::as_raw(Self::data(src)) as *mut TVMFFIObject;
+        data.small_str_len = 0;
+    }
+
+    /// Moves the reference held by `src` into `data` via `into_raw`.
+    unsafe fn move_to_any(src: Self, data: &mut TVMFFIAny) {
+        data.type_index = TypeIndex::kTVMFFIArray as i32;
+        data.data_union.v_obj = ObjectArc::into_raw(Self::into_data(src)) as *mut TVMFFIObject;
+        data.small_str_len = 0;
+    }
+
+    /// Returns a new handle to the array referenced by `data`.
+    unsafe fn copy_from_any_view_after_check(data: &TVMFFIAny) -> Self {
+        let ptr = data.data_union.v_obj as *const ArrayObj;
+        // `from_raw` adopts one reference, so add one for the handle returned here.
+        crate::object::unsafe_::inc_ref(ptr as *mut TVMFFIObject);
+        Self::from_data(ObjectArc::from_raw(ptr))
+    }
+
+    /// Takes over the reference held by `data` and resets `data` to `None`.
+    unsafe fn move_from_any_after_check(data: &mut TVMFFIAny) -> Self {
+        let ptr = data.data_union.v_obj as *const ArrayObj;
+        let obj = Self::from_data(ObjectArc::from_raw(ptr));
+
+        data.type_index = TypeIndex::kTVMFFINone as i32;
+        data.data_union.v_int64 = 0;
+
+        obj
+    }
+
+    /// Converts `data` to `Array<T>`.
+    ///
+    /// If every element already strictly matches `T`, the existing object is
+    /// shared. Otherwise a new array of converted elements is built.
+    unsafe fn try_cast_from_any_view(data: &TVMFFIAny) -> Result<Self, ()> {
+        // Must reject non-arrays here: the slow path reads `v_obj` as an `ArrayObj`.
+        if data.type_index != TypeIndex::kTVMFFIArray as i32 {
+            return Err(());
+        }
+
+        // Fast path: if types match exactly, we can just copy the reference.
+        if Self::check_any_strict(data) {
+            return Ok(Self::copy_from_any_view_after_check(data));
+        }
+
+        // Slow path: convert element by element, failing on the first mismatch.
+        let elements = array_elements(data);
+        let mut items = Vec::with_capacity(elements.len());
+        for elem in elements {
+            items.push(T::try_cast_from_any_view(elem)?);
+        }
+        Ok(Array::new(items))
+    }
+}
+
+impl<T> TryFrom<Any> for Array<T>
+where
+    T: AnyCompatible + Clone + 'static,
+{
+    type Error = crate::error::Error;
+
+    /// Converts an owned `Any` into `Array<T>`, failing if it does not hold a
+    /// compatible array.
+    fn try_from(value: Any) -> Result<Self, Self::Error> {
+        let checked = TryFromTemp::<Self>::try_from(value)?;
+        Ok(TryFromTemp::into_value(checked))
+    }
+}
+
+impl<'a, T> TryFrom<AnyView<'a>> for Array<T>
+where
+    T: AnyCompatible + Clone + 'static,
+{
+    type Error = crate::error::Error;
+
+    /// Converts a borrowed `AnyView` into `Array<T>`, failing if it does not hold
+    /// a compatible array.
+    fn try_from(value: AnyView<'a>) -> Result<Self, Self::Error> {
+        let checked = TryFromTemp::<Self>::try_from(value)?;
+        Ok(TryFromTemp::into_value(checked))
+    }
+}
diff --git a/rust/tvm-ffi/src/collections/mod.rs b/rust/tvm-ffi/src/collections/mod.rs
index 85635a7..ad17dcc 100644
--- a/rust/tvm-ffi/src/collections/mod.rs
+++ b/rust/tvm-ffi/src/collections/mod.rs
@@ -16,6 +16,7 @@
* specific language governing permissions and limitations
* under the License.
*/
-pub mod shape;
-/// Collection types
+//! Collection types
+
+pub mod array;
+pub mod shape;
pub mod tensor;
diff --git a/rust/tvm-ffi/src/lib.rs b/rust/tvm-ffi/src/lib.rs
index 94e87b0..fad8260 100644
--- a/rust/tvm-ffi/src/lib.rs
+++ b/rust/tvm-ffi/src/lib.rs
@@ -32,6 +32,7 @@ pub mod type_traits;
pub use tvm_ffi_sys;
pub use crate::any::{Any, AnyView};
+pub use crate::collections::array::Array;
pub use crate::collections::shape::Shape;
pub use crate::collections::tensor::{CPUNDAlloc, NDAllocator, Tensor};
pub use crate::device::{current_stream, with_stream};
diff --git a/rust/tvm-ffi/tests/test_array.rs b/rust/tvm-ffi/tests/test_array.rs
new file mode 100644
index 0000000..fe87c5f
--- /dev/null
+++ b/rust/tvm-ffi/tests/test_array.rs
@@ -0,0 +1,132 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+use tvm_ffi::*;
+
+/// Helper to create a Tensor with a specific float value and shape
+///
+/// Only the first element is set to `val`, so `shape` must describe at least one
+/// element. Panics if the tensor's data cannot be viewed as `f32`. A broken fixture
+/// then fails here instead of surfacing later as a confusing value mismatch.
+fn create_tensor(val: f32, shape: &[i64]) -> Tensor {
+    let dtype = DLDataType::new(DLDataTypeCode::kDLFloat, 32, 1);
+    let device = DLDevice::new(DLDeviceType::kDLCPU, 0);
+    let tensor = Tensor::from_nd_alloc(CPUNDAlloc {}, shape, dtype, device);
+    let slice = tensor
+        .data_as_slice_mut::<f32>()
+        .expect("freshly allocated f32 CPU tensor should be writable as f32");
+    slice[0] = val;
+    tensor
+}
+
+/// Helper to extract the first float value from a Tensor
+///
+/// Panics if the tensor cannot be viewed as `f32` or holds no elements.
+fn get_val(tensor: &Tensor) -> f32 {
+    let values = tensor
+        .data_as_slice::<f32>()
+        .expect("Type mismatch or null");
+    values[0]
+}
+
+#[test]
+fn test_array_core_and_iteration() {
+    let t1 = create_tensor(10.0, &[1, 2]);
+    let t2 = create_tensor(20.0, &[3, 4, 5]);
+
+    let array = Array::new(vec![t1.clone(), t2.clone()]);
+
+    // Core Accessors
+    assert_eq!(array.len(), 2);
+    assert!(!array.is_empty());
+
+    // Value Integrity (via Index)
+    assert_eq!(get_val(&Tensor::try_from(array[0]).unwrap()), 10.0);
+    assert_eq!(get_val(&Tensor::try_from(array[1]).unwrap()), 20.0);
+    assert_eq!(Tensor::try_from(array[0]).unwrap().ndim(), 2);
+    assert_eq!(Tensor::try_from(array[1]).unwrap().ndim(), 3);
+
+    // Out-of-bounds `get` reports an error instead of reading past the end
+    assert!(array.get(2).is_err());
+    assert!(array.get(usize::MAX).is_err());
+
+    // Iteration
+    let vals: Vec<f32> = array.iter().map(|t| get_val(&t)).collect();
+    assert_eq!(vals, vec![10.0, 20.0]);
+
+    // The original handles remain usable alongside the array
+    assert_eq!(get_val(&t1), 10.0);
+    assert_eq!(get_val(&t2), 20.0);
+}
+
+#[test]
+#[should_panic(expected = "Index out of bounds")]
+fn test_array_index_out_of_bounds_panics() {
+    let array = Array::new(vec![create_tensor(1.0, &[1])]);
+    let _ = array[1];
+}
+
+#[test]
+fn test_array_empty() {
+    let array: Array<Tensor> = Array::default();
+    assert_eq!(array.len(), 0);
+    assert!(array.is_empty());
+    assert!(array.get(0).is_err());
+    assert_eq!(array.iter().count(), 0);
+}
+
+#[test]
+fn test_array_any_conversions() {
+    let array = Array::new(
+        [1.0, 2.0, 3.0]
+            .iter()
+            .map(|&v| create_tensor(v, &[1]))
+            .collect(),
+    );
+
+    // Owned round trip: Array -> Any -> Array (exercises AnyCompatible and trait bounds)
+    let any = Any::from(array);
+    assert_eq!(any.type_index(), TypeIndex::kTVMFFIArray as i32);
+    let back: Array<Tensor> = Array::try_from(any).expect("Any -> Array failed");
+    assert_eq!(back.len(), 3);
+    assert_eq!(get_val(&back.get(2).unwrap()), 3.0);
+
+    // Borrowed round trip: &Array -> AnyView -> Array
+    let view = AnyView::from(&back);
+    let back_from_view: Array<Tensor> =
+        Array::try_from(view).expect("AnyView -> Array failed");
+    assert_eq!(back_from_view.len(), 3);
+}
+
+#[test]
+fn test_array_recursive_type_checking() {
+    // An Array<Shape> erased into Any...
+    let any_val = Any::from(Array::new(vec![
+        Shape::from(vec![1, 2]),
+        Shape::from(vec![3]),
+    ]));
+
+    // ...must not be recoverable as Array<Tensor>, because its Shape elements
+    // cannot be converted to Tensor...
+    assert!(
+        Array::<Tensor>::try_from(any_val.clone()).is_err(),
+        "Should not be able to cast Array<Shape> to Array<Tensor>"
+    );
+
+    // ...but converts back to its real element type
+    assert!(
+        Array::<Shape>::try_from(any_val).is_ok(),
+        "Should be able to cast back to correct type"
+    );
+}
+
+/// Checks that `Array` works for different element types (`Shape`, `Function`)
+/// and that elements fetched with `get` are usable as their concrete type.
+#[test]
+fn test_array_parametric_heterogeneity() {
+ // Verify Array works with different ObjectRefCore types
+ let shape_array = Array::new(vec![Shape::from(vec![1, 2, 3]),
Shape::from(vec![10])]);
+ assert_eq!(shape_array.get(0).unwrap().as_slice(), &[1, 2, 3]);
+ assert_eq!(shape_array.get(1).unwrap().as_slice(), &[10]);
+
+ // Requires the global functions "ffi.String" and "ffi.Bytes" to be registered;
+ // the asserts below expect each to return its argument unchanged.
+ let function_array = Array::new(vec![
+ Function::get_global("ffi.String").unwrap(),
+ Function::get_global("ffi.Bytes").unwrap(),
+ ]);
+ assert_eq!(
+ into_typed_fn!(
+ function_array.get(0).unwrap(),
+ Fn(String) -> Result<String>
+ )("hello".into())
+ .unwrap(),
+ "hello"
+ );
+ assert_eq!(
+ into_typed_fn!(
+ function_array.get(1).unwrap(),
+ Fn(Bytes) -> Result<Bytes>
+ )([1, 2, 3].into())
+ .unwrap(),
+ &[1, 2, 3]
+ );
+}