This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new c1af3b33 [FEAT] Introduce mutable Dict (#463)
c1af3b33 is described below
commit c1af3b337645bed13f573560910dee2743d7d3b1
Author: Tianqi Chen <[email protected]>
AuthorDate: Thu Feb 19 11:50:02 2026 -0500
[FEAT] Introduce mutable Dict (#463)
This PR introduces Dict, which serves as a mutable
variant of the Map data structure.
Dict shares the same base class that backs the immutable Map,
which makes it possible in the future to freeze a Dict into an immutable Map.
---
include/tvm/ffi/c_api.h | 2 +
include/tvm/ffi/container/dict.h | 374 ++++++++++++++++++++++++++
include/tvm/ffi/container/map.h | 94 +------
include/tvm/ffi/container/map_base.h | 298 ++++++++++++++++++--
include/tvm/ffi/object.h | 2 +
python/tvm_ffi/__init__.py | 3 +-
python/tvm_ffi/_ffi_api.py | 18 ++
python/tvm_ffi/_optional_torch_c_dlpack.py | 1 +
python/tvm_ffi/container.py | 148 +++++++++-
python/tvm_ffi/cpp/dtype.py | 1 +
python/tvm_ffi/cython/base.pxi | 1 +
python/tvm_ffi/cython/type_info.pxi | 1 +
python/tvm_ffi/registry.py | 2 +-
python/tvm_ffi/testing/_ffi_api.py | 4 +
src/ffi/container.cc | 34 ++-
src/ffi/extra/deep_copy.cc | 13 +
src/ffi/extra/json_writer.cc | 1 +
src/ffi/extra/repr_print.cc | 15 +-
src/ffi/extra/serialization.cc | 37 ++-
src/ffi/extra/structural_equal.cc | 9 +-
src/ffi/extra/structural_hash.cc | 13 +-
src/ffi/object.cc | 1 +
src/ffi/testing/testing.cc | 7 +
tests/cpp/extra/test_serialization.cc | 72 +++++
tests/cpp/extra/test_structural_equal_hash.cc | 53 ++++
tests/cpp/test_dict.cc | 229 ++++++++++++++++
tests/python/test_container.py | 219 +++++++++++++++
tests/python/test_copy.py | 47 ++++
tests/python/test_cubin_launcher.py | 1 +
tests/python/test_metadata.py | 2 +
tests/python/test_optional_torch_c_dlpack.py | 1 +
tests/python/test_repr.py | 42 ++-
tests/python/test_serialization.py | 60 ++++-
tests/python/test_structural.py | 21 ++
tests/python/test_tensor.py | 1 +
35 files changed, 1691 insertions(+), 136 deletions(-)
diff --git a/include/tvm/ffi/c_api.h b/include/tvm/ffi/c_api.h
index a48641be..c195ad30 100644
--- a/include/tvm/ffi/c_api.h
+++ b/include/tvm/ffi/c_api.h
@@ -175,6 +175,8 @@ typedef enum {
kTVMFFIOpaquePyObject = 74,
/*! \brief List object. */
kTVMFFIList = 75,
+ /*! \brief Dict object. */
+ kTVMFFIDict = 76,
//----------------------------------------------------------------
// more complex objects
//----------------------------------------------------------------
diff --git a/include/tvm/ffi/container/dict.h b/include/tvm/ffi/container/dict.h
new file mode 100644
index 00000000..d28f39de
--- /dev/null
+++ b/include/tvm/ffi/container/dict.h
@@ -0,0 +1,374 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/ffi/container/dict.h
+ * \brief Mutable dictionary container type.
+ *
+ * All handles sharing the same DictObj see mutations immediately.
+ */
+#ifndef TVM_FFI_CONTAINER_DICT_H_
+#define TVM_FFI_CONTAINER_DICT_H_
+
+#include <tvm/ffi/any.h>
+#include <tvm/ffi/container/container_details.h>
+#include <tvm/ffi/container/map_base.h>
+#include <tvm/ffi/memory.h>
+#include <tvm/ffi/object.h>
+#include <tvm/ffi/optional.h>
+
+#include <unordered_map>
+
+namespace tvm {
+namespace ffi {
+
+/*! \brief Dict object — mutable map with shared reference semantics. */
+class DictObj : public MapBaseObj {
+ public:
+ /// \cond Doxygen_Suppress
+ static constexpr const int32_t _type_index = TypeIndex::kTVMFFIDict;
+ static const constexpr bool _type_final = true;
+ TVM_FFI_DECLARE_OBJECT_INFO_STATIC(StaticTypeKey::kTVMFFIDict, DictObj,
Object);
+ /// \endcond
+
+ protected:
+ template <typename, typename, typename>
+ friend class Dict;
+};
+
+static_assert(sizeof(DictObj) == sizeof(MapBaseObj), "DictObj must match
MapBaseObj layout");
+
+/*!
+ * \brief Mutable dictionary container with shared reference semantics.
+ *
+ * Mutations happen directly on the underlying shared DictObj.
+ * All handles sharing the same DictObj see mutations immediately.
+ *
+ * \tparam K The key type.
+ * \tparam V The value type.
+ */
+template <typename K, typename V,
+ typename = typename std::enable_if_t<details::storage_enabled_v<K> &&
+ details::storage_enabled_v<V>>>
+class Dict : public ObjectRef {
+ public:
+ /*! \brief The key type of the dict */
+ using key_type = K;
+ /*! \brief The mapped type of the dict */
+ using mapped_type = V;
+ /*! \brief The iterator type of the dict */
+ class iterator;
+ /*!
+ * \brief Construct a Dict with UnsafeInit
+ */
+ explicit Dict(UnsafeInit tag) : ObjectRef(tag) {}
+ /*!
+ * \brief default constructor
+ */
+ Dict() { data_ = DictObj::Empty<DictObj>(); }
+ /*!
+ * \brief move constructor
+ * \param other source
+ */
+ Dict(Dict<K, V>&& other) // NOLINT(google-explicit-constructor)
+ : ObjectRef(std::move(other.data_)) {}
+ /*!
+ * \brief copy constructor
+ * \param other source
+ */
+ Dict(const Dict<K, V>& other) // NOLINT(google-explicit-constructor)
+ : ObjectRef(other.data_) {}
+
+ /*!
+ * \brief Move constructor
+ * \param other The other dict
+ * \tparam KU The key type of the other dict
+ * \tparam VU The mapped type of the other dict
+ */
+ template <typename KU, typename VU,
+ typename = std::enable_if_t<details::type_contains_v<K, KU> &&
+ details::type_contains_v<V, VU>>>
+ Dict(Dict<KU, VU>&& other) // NOLINT(google-explicit-constructor)
+ : ObjectRef(std::move(other.data_)) {}
+
+ /*!
+ * \brief Copy constructor
+ * \param other The other dict
+ * \tparam KU The key type of the other dict
+ * \tparam VU The mapped type of the other dict
+ */
+ template <typename KU, typename VU,
+ typename = std::enable_if_t<details::type_contains_v<K, KU> &&
+ details::type_contains_v<V, VU>>>
+ // NOLINTNEXTLINE(google-explicit-constructor)
+ Dict(const Dict<KU, VU>& other) : ObjectRef(other.data_) {}
+
+ /*!
+ * \brief Move assignment
+ * \param other The other dict
+ */
+ Dict<K, V>& operator=(Dict<K, V>&& other) {
+ data_ = std::move(other.data_);
+ return *this;
+ }
+
+ /*!
+ * \brief Copy assignment
+ * \param other The other dict
+ */
+ Dict<K, V>& operator=(const Dict<K, V>& other) {
+ data_ = other.data_;
+ return *this;
+ }
+
+ /*!
+ * \brief Move assignment
+ * \param other The other dict
+ * \tparam KU The key type of the other dict
+ * \tparam VU The mapped type of the other dict
+ */
+ template <typename KU, typename VU,
+ typename = std::enable_if_t<details::type_contains_v<K, KU> &&
+ details::type_contains_v<V, VU>>>
+ Dict<K, V>& operator=(Dict<KU, VU>&& other) {
+ data_ = std::move(other.data_);
+ return *this;
+ }
+
+ /*!
+ * \brief Copy assignment
+ * \param other The other dict
+ * \tparam KU The key type of the other dict
+ * \tparam VU The mapped type of the other dict
+ */
+ template <typename KU, typename VU,
+ typename = std::enable_if_t<details::type_contains_v<K, KU> &&
+ details::type_contains_v<V, VU>>>
+ Dict<K, V>& operator=(const Dict<KU, VU>& other) {
+ data_ = other.data_;
+ return *this;
+ }
+ /*!
+ * \brief constructor from pointer
+ * \param n the container pointer
+ */
+ explicit Dict(ObjectPtr<Object> n) : ObjectRef(n) {}
+ /*!
+ * \brief constructor from iterator
+ * \param begin begin of iterator
+ * \param end end of iterator
+ * \tparam IterType The type of iterator
+ */
+ template <typename IterType>
+ Dict(IterType begin, IterType end) {
+ data_ = DictObj::CreateFromRange<DictObj>(begin, end);
+ }
+ /*!
+ * \brief constructor from initializer list
+ * \param init The initalizer list
+ */
+ Dict(std::initializer_list<std::pair<K, V>> init) {
+ data_ = DictObj::CreateFromRange<DictObj>(init.begin(), init.end());
+ }
+ /*!
+ * \brief constructor from unordered_map
+ * \param init The unordered_map
+ */
+ template <typename Hash, typename Equal>
+ Dict(const std::unordered_map<K, V, Hash, Equal>& init) { // NOLINT(*)
+ data_ = DictObj::CreateFromRange<DictObj>(init.begin(), init.end());
+ }
+ /*!
+ * \brief Read element from dict.
+ * \param key The key
+ * \return the corresponding element.
+ */
+ V at(const K& key) const {
+ return
details::AnyUnsafe::CopyFromAnyViewAfterCheck<V>(GetDictObj()->at(key));
+ }
+ /*!
+ * \brief Read element from dict.
+ * \param key The key
+ * \return the corresponding element.
+ */
+ V operator[](const K& key) const { return this->at(key); }
+ /*! \return The size of the dict */
+ size_t size() const {
+ DictObj* n = GetDictObj();
+ return n == nullptr ? 0 : n->size();
+ }
+ /*! \return The number of elements of the key */
+ size_t count(const K& key) const {
+ DictObj* n = GetDictObj();
+ return n == nullptr ? 0 : n->count(key);
+ }
+ /*! \return whether dict is empty */
+ bool empty() const { return size() == 0; }
+ /*! \brief Release reference to all the elements */
+ void clear() {
+ DictObj* n = GetDictObj();
+ if (n != nullptr) {
+ n->clear();
+ }
+ }
+ /*!
+ * \brief Set a key-value pair in the Dict (mutates in-place).
+ * \param key The index key.
+ * \param value The value to be set.
+ */
+ void Set(const K& key, const V& value) {
+ EnsureDictObj();
+ ObjectPtr<Object> new_container =
+ MapBaseObj::InsertMaybeReHash<DictObj>(DictObj::KVType(key, value),
data_);
+ if (new_container != nullptr) {
+
static_cast<MapBaseObj*>(data_.get())->InplaceSwitchTo(std::move(new_container));
+ }
+ }
+ /*! \return begin iterator */
+ iterator begin() const { return iterator(GetDictObj()->begin()); }
+ /*! \return end iterator */
+ iterator end() const { return iterator(GetDictObj()->end()); }
+ /*! \return find the key and returns the associated iterator */
+ iterator find(const K& key) const { return
iterator(GetDictObj()->find(key)); }
+ /*! \return The value associated with the key, std::nullopt if not found */
+ std::optional<V> Get(const K& key) const {
+ DictObj::iterator iter = GetDictObj()->find(key);
+ if (iter == GetDictObj()->end()) {
+ return std::nullopt;
+ }
+ return details::AnyUnsafe::CopyFromAnyViewAfterCheck<V>(iter->second);
+ }
+
+ /*!
+ * \brief Erase the entry associated with the key (mutates in-place)
+ * \param key The key
+ */
+ void erase(const K& key) {
+ DictObj* n = GetDictObj();
+ if (n != nullptr) {
+ n->erase(key);
+ }
+ }
+
+ /*! \brief specify container node */
+ using ContainerType = DictObj;
+
+ /// \cond Doxygen_Suppress
+ /*! \brief Iterator of the hash map */
+ class iterator {
+ public:
+ using iterator_category = std::bidirectional_iterator_tag;
+ using difference_type = int64_t;
+ using value_type = const std::pair<K, V>;
+ using pointer = value_type*;
+ using reference = value_type;
+
+ iterator() : itr() {}
+
+ /*! \brief Compare iterators */
+ bool operator==(const iterator& other) const { return itr == other.itr; }
+ /*! \brief Compare iterators */
+ bool operator!=(const iterator& other) const { return itr != other.itr; }
+ /*! \brief De-reference iterators is not allowed */
+ pointer operator->() const = delete;
+ /*! \brief De-reference iterators */
+ reference operator*() const {
+ auto& kv = *itr;
+ return
std::make_pair(details::AnyUnsafe::CopyFromAnyViewAfterCheck<K>(kv.first),
+
details::AnyUnsafe::CopyFromAnyViewAfterCheck<V>(kv.second));
+ }
+ /*! \brief Prefix self increment, e.g. ++iter */
+ iterator& operator++() {
+ ++itr;
+ return *this;
+ }
+ /*! \brief Suffix self increment */
+ iterator operator++(int) {
+ iterator copy = *this;
+ ++(*this);
+ return copy;
+ }
+
+ /*! \brief Prefix self decrement, e.g. --iter */
+ iterator& operator--() {
+ --itr;
+ return *this;
+ }
+ /*! \brief Suffix self decrement */
+ iterator operator--(int) {
+ iterator copy = *this;
+ --(*this);
+ return copy;
+ }
+
+ private:
+ iterator(const DictObj::iterator& itr) // NOLINT(*)
+ : itr(itr) {}
+
+ template <typename, typename, typename>
+ friend class Dict;
+
+ DictObj::iterator itr;
+ };
+ /// \endcond
+
+ private:
+ /*! \brief Return data_ as type of pointer of DictObj */
+ DictObj* GetDictObj() const { return static_cast<DictObj*>(data_.get()); }
+
+ /*! \brief Ensure we have a valid DictObj */
+ void EnsureDictObj() {
+ if (data_ == nullptr) {
+ data_ = DictObj::Empty<DictObj>();
+ }
+ }
+
+ template <typename, typename, typename>
+ friend class Dict;
+};
+
+// Traits for Dict
+template <typename K, typename V>
+inline constexpr bool use_default_type_traits_v<Dict<K, V>> = false;
+
+template <typename K, typename V>
+struct TypeTraits<Dict<K, V>> : public MapTypeTraitsBase<TypeTraits<Dict<K,
V>>, Dict<K, V>, K, V> {
+ static constexpr int32_t kPrimaryTypeIndex = TypeIndex::kTVMFFIDict;
+ static constexpr int32_t kOtherTypeIndex = TypeIndex::kTVMFFIMap;
+ static constexpr const char* kTypeName = "Dict";
+
+ TVM_FFI_INLINE static std::string TypeSchema() {
+ std::ostringstream oss;
+ oss << R"({"type":")" << StaticTypeKey::kTVMFFIDict << R"(","args":[)";
+ oss << details::TypeSchema<K>::v() << ",";
+ oss << details::TypeSchema<V>::v();
+ oss << "]}";
+ return oss.str();
+ }
+};
+
+namespace details {
+template <typename K, typename V, typename KU, typename VU>
+inline constexpr bool type_contains_v<Dict<K, V>, Dict<KU, VU>> =
+ type_contains_v<K, KU> && type_contains_v<V, VU>;
+} // namespace details
+
+} // namespace ffi
+} // namespace tvm
+#endif // TVM_FFI_CONTAINER_DICT_H_
diff --git a/include/tvm/ffi/container/map.h b/include/tvm/ffi/container/map.h
index cfd9f720..20153a06 100644
--- a/include/tvm/ffi/container/map.h
+++ b/include/tvm/ffi/container/map.h
@@ -19,7 +19,7 @@
/*!
* \file tvm/ffi/container/map.h
- * \brief Runtime Map container types.
+ * \brief Immutable Map container type.
*/
#ifndef TVM_FFI_CONTAINER_MAP_H_
#define TVM_FFI_CONTAINER_MAP_H_
@@ -229,7 +229,11 @@ class Map : public ObjectRef {
*/
void Set(const K& key, const V& value) {
CopyOnWrite();
- MapObj::InsertMaybeReHash<MapObj>(MapObj::KVType(key, value), &data_);
+ ObjectPtr<Object> new_data =
+ MapObj::InsertMaybeReHash<MapObj>(MapObj::KVType(key, value), data_);
+ if (new_data != nullptr) {
+ data_ = std::move(new_data);
+ }
}
/*! \return begin iterator */
iterator begin() const { return iterator(GetMapObj()->begin()); }
@@ -359,89 +363,11 @@ template <typename K, typename V>
inline constexpr bool use_default_type_traits_v<Map<K, V>> = false;
template <typename K, typename V>
-struct TypeTraits<Map<K, V>> : public ObjectRefTypeTraitsBase<Map<K, V>> {
- static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIMap;
- using ObjectRefTypeTraitsBase<Map<K, V>>::CopyFromAnyViewAfterCheck;
-
- TVM_FFI_INLINE static std::string GetMismatchTypeInfo(const TVMFFIAny* src) {
- if (src->type_index != TypeIndex::kTVMFFIMap) {
- return TypeTraitsBase::GetMismatchTypeInfo(src);
- }
- if constexpr (!std::is_same_v<K, Any> || !std::is_same_v<V, Any>) {
- const MapObj* n = reinterpret_cast<const MapObj*>(src->v_obj);
- for (const auto& kv : *n) {
- if constexpr (!std::is_same_v<K, Any>) {
- if (!details::AnyUnsafe::CheckAnyStrict<K>(kv.first) &&
- !kv.first.try_cast<K>().has_value()) {
- return "Map[some key is " +
details::AnyUnsafe::GetMismatchTypeInfo<K>(kv.first) +
- ", V]";
- }
- }
- if constexpr (!std::is_same_v<V, Any>) {
- if (!details::AnyUnsafe::CheckAnyStrict<V>(kv.second) &&
- !kv.second.try_cast<V>().has_value()) {
- return "Map[K, some value is " +
details::AnyUnsafe::GetMismatchTypeInfo<V>(kv.second) +
- "]";
- }
- }
- }
- }
- TVM_FFI_THROW(InternalError) << "Cannot reach here";
- TVM_FFI_UNREACHABLE();
- }
-
- TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny* src) {
- if (src->type_index != TypeIndex::kTVMFFIMap) return false;
- if constexpr (std::is_same_v<K, Any> && std::is_same_v<V, Any>) {
- return true;
- } else {
- const MapObj* n = reinterpret_cast<const MapObj*>(src->v_obj);
- for (const auto& kv : *n) {
- if constexpr (!std::is_same_v<K, Any>) {
- if (!details::AnyUnsafe::CheckAnyStrict<K>(kv.first)) return false;
- }
- if constexpr (!std::is_same_v<V, Any>) {
- if (!details::AnyUnsafe::CheckAnyStrict<V>(kv.second)) return false;
- }
- }
- return true;
- }
- }
-
- TVM_FFI_INLINE static std::optional<Map<K, V>> TryCastFromAnyView(const
TVMFFIAny* src) {
- if (src->type_index != TypeIndex::kTVMFFIMap) return std::nullopt;
- if constexpr (!std::is_same_v<K, Any> || !std::is_same_v<V, Any>) {
- const MapObj* n = reinterpret_cast<const MapObj*>(src->v_obj);
- bool storage_check = [&]() {
- for (const auto& kv : *n) {
- if constexpr (!std::is_same_v<K, Any>) {
- if (!details::AnyUnsafe::CheckAnyStrict<K>(kv.first)) return false;
- }
- if constexpr (!std::is_same_v<V, Any>) {
- if (!details::AnyUnsafe::CheckAnyStrict<V>(kv.second)) return
false;
- }
- }
- return true;
- }();
- // fast path, if storage check passes, we can return the array directly.
- if (storage_check) return CopyFromAnyViewAfterCheck(src);
- // slow path, we need to create a new map and convert to the target type.
- Map<K, V> ret;
- for (const auto& kv : *n) {
- auto k = kv.first.try_cast<K>();
- auto v = kv.second.try_cast<V>();
- if (!k.has_value() || !v.has_value()) return std::nullopt;
- ret.Set(*std::move(k), *std::move(v));
- }
- return ret;
- } else {
- return CopyFromAnyViewAfterCheck(src);
- }
- }
+struct TypeTraits<Map<K, V>> : public MapTypeTraitsBase<TypeTraits<Map<K, V>>,
Map<K, V>, K, V> {
+ static constexpr int32_t kPrimaryTypeIndex = TypeIndex::kTVMFFIMap;
+ static constexpr int32_t kOtherTypeIndex = TypeIndex::kTVMFFIDict;
+ static constexpr const char* kTypeName = "Map";
- TVM_FFI_INLINE static std::string TypeStr() {
- return "Map<" + details::Type2Str<K>::v() + ", " +
details::Type2Str<V>::v() + ">";
- }
TVM_FFI_INLINE static std::string TypeSchema() {
std::ostringstream oss;
oss << R"({"type":")" << StaticTypeKey::kTVMFFIMap << R"(","args":[)";
diff --git a/include/tvm/ffi/container/map_base.h
b/include/tvm/ffi/container/map_base.h
index 1009bf15..2bd41d01 100644
--- a/include/tvm/ffi/container/map_base.h
+++ b/include/tvm/ffi/container/map_base.h
@@ -195,6 +195,15 @@ class MapBaseObj : public Object {
#if TVM_FFI_DEBUG_WITH_ABI_CHANGE
uint64_t state_marker;
#endif // TVM_FFI_DEBUG_WITH_ABI_CHANGE
+ /*!
+ * \brief Inplace switch the storage contents to the other map
+ *
+ * The other map will be consumed after this operation
+ * The current content will be reset after this operation
+ *
+ * \param other The other map
+ */
+ inline void InplaceSwitchTo(ObjectPtr<Object>&& other);
/*!
* \brief Create an empty container
* \return The object created
@@ -214,11 +223,12 @@ class MapBaseObj : public Object {
/*!
* \brief InsertMaybeReHash an entry into the given hash map
* \param kv The entry to be inserted
- * \param map The pointer to the map, can be changed if re-hashing happens
+ * \param map The reference to the map container
* \tparam MapObjType The type of map object
+ * \return A new container if re-hashing happens, nullptr otherwise
*/
template <typename MapObjType>
- static inline void InsertMaybeReHash(KVType&& kv, ObjectPtr<Object>* map);
+ static inline ObjectPtr<Object> InsertMaybeReHash(KVType&& kv, const
ObjectPtr<Object>& map);
/*!
* \brief Create an empty container with elements copying from another
SmallMapBaseObj
* \param from The source container
@@ -257,6 +267,9 @@ class MapBaseObj : public Object {
friend class SmallMapBaseObj;
friend class DenseMapBaseObj;
+ template <typename, typename, typename>
+ friend class Dict;
+
template <typename, typename>
friend struct TypeTraits;
};
@@ -682,6 +695,35 @@ class DenseMapBaseObj : public MapBaseObj {
clear();
ReleaseMemory();
}
+ /*!
+ * \brief Inplace switch the storage contents to the other map
+ *
+ * The other map will be consumed after this operation
+ * The current content will be reset after this operation
+ *
+ * \param other The other map
+ */
+ void InplaceSwitchTo(ObjectPtr<Object>&& other) {
+ this->Reset();
+ MapBaseObj* other_map = static_cast<MapBaseObj*>(other.get());
+ // to simplify implementation, inplace switch from another dense map
+ // since all current switch usecases are pushing elements to map
+ // so we don't need to handle small -> dense switch
+ TVM_FFI_ICHECK(!other_map->IsSmallMap());
+ DenseMapBaseObj* other_dense_map =
static_cast<DenseMapBaseObj*>(other_map);
+ DenseMapBaseObj* this_dense_map = this;
+ this_dense_map->size_ = other_dense_map->size_;
+ this_dense_map->slots_ = other_dense_map->slots_;
+ this_dense_map->data_ = other_dense_map->data_;
+ this_dense_map->data_deleter_ = other_dense_map->data_deleter_;
+ this_dense_map->fib_shift_ = other_dense_map->fib_shift_;
+ this_dense_map->iter_list_head_ = other_dense_map->iter_list_head_;
+ this_dense_map->iter_list_tail_ = other_dense_map->iter_list_tail_;
+ other_dense_map->data_ = nullptr;
+ other_dense_map->data_deleter_ = nullptr;
+ other_dense_map->size_ = 0;
+ other_dense_map->slots_ = 0;
+ }
/*!
* \brief Release the memory acquired by the container without deleting its
entries stored inside
*/
@@ -777,15 +819,15 @@ class DenseMapBaseObj : public MapBaseObj {
* \param map The pointer to the map, can be changed if re-hashing happens
*/
template <typename MapObjType>
- static void InsertMaybeReHash(KVType&& kv, ObjectPtr<Object>* map) {
- DenseMapBaseObj* map_node = static_cast<DenseMapBaseObj*>(map->get());
+ static ObjectPtr<Object> InsertMaybeReHash(KVType&& kv, const
ObjectPtr<Object>& map) {
+ DenseMapBaseObj* map_node = static_cast<DenseMapBaseObj*>(map.get());
ListNode iter;
// Try to insert. If succeed, we simply return
if (map_node->TryInsert(kv.first, &iter)) {
iter.Val() = std::move(kv.second);
// update the iter list relation
map_node->IterListPushBack(iter);
- return;
+ return ObjectPtr<Object>(nullptr);
}
TVM_FFI_ICHECK(!map_node->IsSmallMap());
// Otherwise, start rehash
@@ -796,7 +838,8 @@ class DenseMapBaseObj : public MapBaseObj {
ListNode node(index, map_node);
// now try move src_data into the new map, note that src may still not
// be fully consumed into the call, but destructor will be called.
- InsertMaybeReHash<MapObjType>(std::move(node.Data()), &p);
+ ObjectPtr<Object> rehashed =
InsertMaybeReHash<MapObjType>(std::move(node.Data()), p);
+ if (rehashed != nullptr) p = std::move(rehashed);
// Important, needs to explicit call destructor in case move did remove
// node's internal item
index = node.Item().next;
@@ -807,9 +850,12 @@ class DenseMapBaseObj : public MapBaseObj {
// Remove this call will cause memory leak very likely.
node.DestructData();
}
- InsertMaybeReHash<MapObjType>(std::move(kv), &p);
+ {
+ ObjectPtr<Object> rehashed =
InsertMaybeReHash<MapObjType>(std::move(kv), p);
+ if (rehashed != nullptr) p = std::move(rehashed);
+ }
map_node->ReleaseMemory();
- *map = p;
+ return p;
}
/*!
* \brief Check whether the hash table is full
@@ -1064,7 +1110,17 @@ class SmallMapBaseObj : public MapBaseObj {
*/
uint64_t NumSlots() const { return slots_ & ~kSmallTagMask; }
- ~SmallMapBaseObj() { this->Reset(); }
+ ~SmallMapBaseObj() {
+ // in destructor, need to check if the map was inplace switched to a dense
map.
+ MapBaseObj* base_map = static_cast<MapBaseObj*>(this);
+ if (base_map->IsSmallMap()) {
+ this->Reset();
+ } else {
+ // this map was inplace switched to a dense map.
+ DenseMapBaseObj* this_dense_map =
static_cast<DenseMapBaseObj*>(base_map);
+ this_dense_map->Reset();
+ }
+ }
/*!
* \brief clear all entries
*/
@@ -1152,6 +1208,69 @@ class SmallMapBaseObj : public MapBaseObj {
data_deleter_ = nullptr;
}
}
+ /*!
+ * \brief Inplace deleter from data
+ * \param data The data
+ */
+ static void InplaceSmallMapDeleterFromData(void* data) {
+ details::ObjectUnsafe::ObjectPtrFromOwned<SmallMapBaseObj>(
+ reinterpret_cast<Object*>(reinterpret_cast<char*>(data) -
sizeof(SmallMapBaseObj)))
+ .reset();
+ }
+ /*!
+ * \brief Inplace switch the storage contents to the other map
+ *
+ * The other map will be consumed after this operation
+ * The current content will be reset after this operation
+ *
+ * \param other The other map
+ */
+ void InplaceSwitchTo(ObjectPtr<Object>&& other) {
+ // invariant this map have not been inplace switched to a dense map
+ TVM_FFI_ICHECK(this->IsSmallMap());
+ this->Reset();
+ MapBaseObj* other_map = static_cast<MapBaseObj*>(other.get());
+ if (other_map->IsSmallMap()) {
+ SmallMapBaseObj* other_small_map =
static_cast<SmallMapBaseObj*>(other_map);
+ SmallMapBaseObj* this_small_map = this;
+ this_small_map->size_ = other_small_map->size_;
+ this_small_map->slots_ = other_small_map->slots_;
+ this_small_map->data_ = other_small_map->data_;
+ if (other_small_map->data_deleter_ != nullptr) {
+ this_small_map->data_deleter_ = other_small_map->data_deleter_;
+ other_small_map->data_deleter_ = nullptr;
+ } else {
+ // we switch to the inplace deleter from the data
+ this_small_map->data_deleter_ = InplaceSmallMapDeleterFromData;
+ // move out other from the ptr so the deletion can only be triggered
+ // via InplaceSmallMapDeleterFromData
+
details::ObjectUnsafe::MoveObjectPtrToTVMFFIObjectPtr(std::move(other));
+ }
+ other_small_map->data_ = nullptr;
+ other_small_map->size_ = 0;
+ other_small_map->slots_ = 0;
+ } else {
+ // Reinterpret this SmallMapBaseObj's memory as DenseMapBaseObj to write
fields at the
+ // correct offsets. This is raw memory manipulation: SmallMapBaseObj's
allocation is
+ // guaranteed large enough (see static_assert in Empty()), and all
member access
+ // compiles to fixed-offset stores with no virtual dispatch involved.
+ // The destructor will also cross check and apply the correct deletion.
+ // As a result, we can inplace switch the container storage to dense map
+ DenseMapBaseObj* other_dense_map =
static_cast<DenseMapBaseObj*>(other_map);
+ DenseMapBaseObj* this_dense_map =
reinterpret_cast<DenseMapBaseObj*>(this);
+ this_dense_map->size_ = other_dense_map->size_;
+ this_dense_map->slots_ = other_dense_map->slots_;
+ this_dense_map->data_ = other_dense_map->data_;
+ this_dense_map->data_deleter_ = other_dense_map->data_deleter_;
+ this_dense_map->fib_shift_ = other_dense_map->fib_shift_;
+ this_dense_map->iter_list_head_ = other_dense_map->iter_list_head_;
+ this_dense_map->iter_list_tail_ = other_dense_map->iter_list_tail_;
+ other_dense_map->data_ = nullptr;
+ other_dense_map->data_deleter_ = nullptr;
+ other_dense_map->size_ = 0;
+ other_dense_map->slots_ = 0;
+ }
+ }
/*!
* \brief Remove a position in SmallMapBaseObj
* \param index The position to be removed
@@ -1184,6 +1303,12 @@ class SmallMapBaseObj : public MapBaseObj {
*/
template <typename MapObjType>
static ObjectPtr<Object> Empty(uint64_t n = kInitSize) {
+ // We always allocate a SmallMapObj to be large enough so it can inplace
switch to DenseMapObj
+ static_assert(alignof(SmallMapBaseObj) % alignof(KVType) == 0);
+ static_assert(sizeof(SmallMapBaseObj) + kInitSize * sizeof(KVType) >=
sizeof(DenseMapBaseObj));
+ n = std::max(n, static_cast<uint64_t>(kInitSize));
+ // allocate a SmallMapBaseObj with enough space to inplace store n
elements and switch to
+ // DenseMapBaseObj
ObjectPtr<SmallMapBaseObj> p =
ffi::make_inplace_array_object<SmallMapBaseObj, KVType>(n);
// data_ is after the SmallMapBaseObj header
p->data_ = reinterpret_cast<char*>(p.get()) + sizeof(SmallMapBaseObj);
@@ -1232,26 +1357,27 @@ class SmallMapBaseObj : public MapBaseObj {
* \param map The pointer to the map, can be changed if re-hashing happens
*/
template <typename MapObjType>
- static void InsertMaybeReHash(KVType&& kv, ObjectPtr<Object>* map) {
- SmallMapBaseObj* map_node = static_cast<SmallMapBaseObj*>(map->get());
+ static ObjectPtr<Object> InsertMaybeReHash(KVType&& kv, const
ObjectPtr<Object>& map) {
+ SmallMapBaseObj* map_node = static_cast<SmallMapBaseObj*>(map.get());
iterator itr = map_node->find(kv.first);
if (itr.index < map_node->size_) {
itr->second = kv.second;
- return;
+ return ObjectPtr<Object>(nullptr);
}
if (map_node->size_ < map_node->NumSlots()) {
KVType* ptr = static_cast<KVType*>(map_node->data_) + map_node->size_;
new (ptr) KVType(std::move(kv));
++map_node->size_;
- return;
+ return ObjectPtr<Object>(nullptr);
}
uint64_t next_size = std::max(map_node->NumSlots() * 2, kInitSize);
next_size = std::min(next_size, kMaxSize);
TVM_FFI_ICHECK_GT(next_size, map_node->NumSlots());
ObjectPtr<Object> new_map =
CreateFromRange<MapObjType>(next_size, map_node->begin(),
map_node->end());
- InsertMaybeReHash<MapObjType>(std::move(kv), &new_map);
- *map = std::move(new_map);
+ ObjectPtr<Object> rehashed = InsertMaybeReHash<MapObjType>(std::move(kv),
new_map);
+ if (rehashed != nullptr) new_map = std::move(rehashed);
+ return new_map;
}
/*!
* \brief Increment the pointer
@@ -1356,6 +1482,10 @@ inline void MapBaseObj::erase(const
MapBaseObj::iterator& position) {
inline void MapBaseObj::clear() {
TVM_FFI_DISPATCH_MAP(this, p, { p->clear(); });
}
+
+inline void MapBaseObj::InplaceSwitchTo(ObjectPtr<Object>&& other) {
+ TVM_FFI_DISPATCH_MAP(this, p, { p->InplaceSwitchTo(std::move(other)); });
+}
/// \endcond
#undef TVM_FFI_DISPATCH_MAP
@@ -1390,7 +1520,9 @@ inline ObjectPtr<Object>
MapBaseObj::CreateFromRange(IterType first, IterType la
ObjectPtr<Object> obj = SmallMapBaseObj::Empty<MapObjType>(cap);
for (; first != last; ++first) {
KVType kv(*first);
- SmallMapBaseObj::InsertMaybeReHash<MapObjType>(std::move(kv), &obj);
+ ObjectPtr<Object> rehashed =
+ SmallMapBaseObj::InsertMaybeReHash<MapObjType>(std::move(kv), obj);
+ if (rehashed != nullptr) obj = std::move(rehashed);
}
return obj;
} else {
@@ -1400,35 +1532,40 @@ inline ObjectPtr<Object>
MapBaseObj::CreateFromRange(IterType first, IterType la
ObjectPtr<Object> obj = DenseMapBaseObj::Empty<MapObjType>(fib_shift,
n_slots);
for (; first != last; ++first) {
KVType kv(*first);
- DenseMapBaseObj::InsertMaybeReHash<MapObjType>(std::move(kv), &obj);
+ ObjectPtr<Object> rehashed =
+ DenseMapBaseObj::InsertMaybeReHash<MapObjType>(std::move(kv), obj);
+ if (rehashed != nullptr) obj = std::move(rehashed);
}
return obj;
}
}
template <typename MapObjType>
-inline void MapBaseObj::InsertMaybeReHash(KVType&& kv, ObjectPtr<Object>* map)
{
- MapBaseObj* base = static_cast<MapBaseObj*>(map->get());
+inline ObjectPtr<Object> MapBaseObj::InsertMaybeReHash(KVType&& kv, const
ObjectPtr<Object>& map) {
+ MapBaseObj* base = static_cast<MapBaseObj*>(map.get());
#if TVM_FFI_DEBUG_WITH_ABI_CHANGE
base->state_marker++;
#endif // TVM_FFI_DEBUG_WITH_ABI_CHANGE
if (base->IsSmallMap()) {
SmallMapBaseObj* sm = static_cast<SmallMapBaseObj*>(base);
if (sm->NumSlots() < SmallMapBaseObj::kMaxSize) {
- SmallMapBaseObj::InsertMaybeReHash<MapObjType>(std::move(kv), map);
+ return SmallMapBaseObj::InsertMaybeReHash<MapObjType>(std::move(kv),
map);
} else if (sm->NumSlots() == SmallMapBaseObj::kMaxSize) {
if (base->size_ < sm->NumSlots()) {
- SmallMapBaseObj::InsertMaybeReHash<MapObjType>(std::move(kv), map);
+ return SmallMapBaseObj::InsertMaybeReHash<MapObjType>(std::move(kv),
map);
} else {
ObjectPtr<Object> new_map =
MapBaseObj::CreateFromRange<MapObjType>(base->begin(),
base->end());
- DenseMapBaseObj::InsertMaybeReHash<MapObjType>(std::move(kv),
&new_map);
- *map = std::move(new_map);
+ ObjectPtr<Object> rehashed =
+ DenseMapBaseObj::InsertMaybeReHash<MapObjType>(std::move(kv),
new_map);
+ if (rehashed != nullptr) new_map = std::move(rehashed);
+ return new_map;
}
}
} else {
- DenseMapBaseObj::InsertMaybeReHash<MapObjType>(std::move(kv), map);
+ return DenseMapBaseObj::InsertMaybeReHash<MapObjType>(std::move(kv), map);
}
+ return ObjectPtr<Object>(nullptr);
}
/// \cond Doxygen_Suppress
@@ -1439,6 +1576,119 @@ inline void MapBaseObj::InsertMaybeReHash(KVType&& kv,
ObjectPtr<Object>* map) {
template <>
inline ObjectPtr<MapBaseObj> make_object<>() = delete;
/// \endcond
+/*!
+ * \brief CRTP base for map type-traits (Map, Dict).
+ *
+ * \tparam Derived Must expose:
+ * - `static constexpr int32_t kPrimaryTypeIndex` — the canonical FFI type
index
+ * - `static constexpr int32_t kOtherTypeIndex` — an alternative accepted
type index
+ * - `static constexpr const char* kTypeName` — human-readable name for
diagnostics
+ */
+template <typename Derived, typename MapRef, typename K, typename V>
+struct MapTypeTraitsBase : public ObjectRefTypeTraitsBase<MapRef> {
+ using Base = ObjectRefTypeTraitsBase<MapRef>;
+ using Base::CopyFromAnyViewAfterCheck;
+
+ TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny* src) {
+ if (src->type_index != Derived::kPrimaryTypeIndex) return false;
+ if constexpr (std::is_same_v<K, Any> && std::is_same_v<V, Any>) {
+ return true;
+ } else {
+ const MapBaseObj* n = reinterpret_cast<const MapBaseObj*>(src->v_obj);
+ for (const auto& kv : *n) {
+ if constexpr (!std::is_same_v<K, Any>) {
+ if (!details::AnyUnsafe::CheckAnyStrict<K>(kv.first)) return false;
+ }
+ if constexpr (!std::is_same_v<V, Any>) {
+ if (!details::AnyUnsafe::CheckAnyStrict<V>(kv.second)) return false;
+ }
+ }
+ return true;
+ }
+ }
+
+ TVM_FFI_INLINE static std::string GetMismatchTypeInfo(const TVMFFIAny* src) {
+ if (src->type_index != Derived::kPrimaryTypeIndex &&
+ src->type_index != Derived::kOtherTypeIndex) {
+ return TypeTraitsBase::GetMismatchTypeInfo(src);
+ }
+ if constexpr (!std::is_same_v<K, Any> || !std::is_same_v<V, Any>) {
+ const MapBaseObj* n = reinterpret_cast<const MapBaseObj*>(src->v_obj);
+ for (const auto& kv : *n) {
+ if constexpr (!std::is_same_v<K, Any>) {
+ if (!details::AnyUnsafe::CheckAnyStrict<K>(kv.first) &&
+ !kv.first.try_cast<K>().has_value()) {
+ return std::string(Derived::kTypeName) + "[some key is " +
+ details::AnyUnsafe::GetMismatchTypeInfo<K>(kv.first) + ",
V]";
+ }
+ }
+ if constexpr (!std::is_same_v<V, Any>) {
+ if (!details::AnyUnsafe::CheckAnyStrict<V>(kv.second) &&
+ !kv.second.try_cast<V>().has_value()) {
+ return std::string(Derived::kTypeName) + "[K, some value is " +
+ details::AnyUnsafe::GetMismatchTypeInfo<V>(kv.second) + "]";
+ }
+ }
+ }
+ }
+ TVM_FFI_THROW(InternalError) << "Cannot reach here";
+ TVM_FFI_UNREACHABLE();
+ }
+
+ TVM_FFI_INLINE static std::optional<MapRef> TryCastFromAnyView(const
TVMFFIAny* src) {
+ if (src->type_index != Derived::kPrimaryTypeIndex &&
+ src->type_index != Derived::kOtherTypeIndex) {
+ return std::nullopt;
+ }
+ const MapBaseObj* n = reinterpret_cast<const MapBaseObj*>(src->v_obj);
+ if constexpr (!std::is_same_v<K, Any> || !std::is_same_v<V, Any>) {
+ bool storage_check = [&]() {
+ for (const auto& kv : *n) {
+ if constexpr (!std::is_same_v<K, Any>) {
+ if (!details::AnyUnsafe::CheckAnyStrict<K>(kv.first)) return false;
+ }
+ if constexpr (!std::is_same_v<V, Any>) {
+ if (!details::AnyUnsafe::CheckAnyStrict<V>(kv.second)) return
false;
+ }
+ }
+ return true;
+ }();
+ // fast path: if storage check passes and type is primary, return
directly.
+ if (storage_check && src->type_index == Derived::kPrimaryTypeIndex) {
+ return CopyFromAnyViewAfterCheck(src);
+ }
+ // slow path: create a new map and convert each key-value pair.
+ MapRef ret;
+ for (const auto& kv : *n) {
+ auto k = kv.first.try_cast<K>();
+ auto v = kv.second.try_cast<V>();
+ if (!k.has_value() || !v.has_value()) return std::nullopt;
+ ret.Set(*std::move(k), *std::move(v));
+ }
+ return ret;
+ } else {
+ if (src->type_index == Derived::kPrimaryTypeIndex) {
+ return CopyFromAnyViewAfterCheck(src);
+ }
+ // cross-type conversion for Any,Any: create new MapRef, copy all
entries.
+ MapRef ret;
+ for (const auto& kv : *n) {
+ ret.Set(kv.first, kv.second);
+ }
+ return ret;
+ }
+ }
+
+ TVM_FFI_INLINE static std::string TypeStr() {
+ return std::string(Derived::kTypeName) + "<" + details::Type2Str<K>::v() +
", " +
+ details::Type2Str<V>::v() + ">";
+ }
+
+ private:
+ MapTypeTraitsBase() = default;
+ friend Derived;
+};
+
} // namespace ffi
} // namespace tvm
#endif // TVM_FFI_CONTAINER_MAP_BASE_H_
diff --git a/include/tvm/ffi/object.h b/include/tvm/ffi/object.h
index c97ab714..c95f2bdb 100644
--- a/include/tvm/ffi/object.h
+++ b/include/tvm/ffi/object.h
@@ -116,6 +116,8 @@ struct StaticTypeKey {
static constexpr const char* kTVMFFIMap = "ffi.Map";
/*! \brief The type key for Module */
static constexpr const char* kTVMFFIModule = "ffi.Module";
+ /*! \brief The type key for Dict */
+ static constexpr const char* kTVMFFIDict = "ffi.Dict";
/*! \brief The type key for OpaquePyObject */
static constexpr const char* kTVMFFIOpaquePyObject = "ffi.OpaquePyObject";
};
diff --git a/python/tvm_ffi/__init__.py b/python/tvm_ffi/__init__.py
index 3eedbc55..603fe8ae 100644
--- a/python/tvm_ffi/__init__.py
+++ b/python/tvm_ffi/__init__.py
@@ -51,7 +51,7 @@ from ._convert import convert
from .error import register_error
from ._tensor import Device, device, DLDeviceType
from ._tensor import from_dlpack, Tensor, Shape
-from .container import Array, List, Map
+from .container import Array, Dict, List, Map
from .module import Module, system_lib, load_module
from .stream import StreamContext, get_raw_stream, use_raw_stream,
use_torch_stream
from .structural import (
@@ -103,6 +103,7 @@ __all__ = [
"Array",
"DLDeviceType",
"Device",
+ "Dict",
"Function",
"List",
"Map",
diff --git a/python/tvm_ffi/_ffi_api.py b/python/tvm_ffi/_ffi_api.py
index d20e5fb1..b60ab808 100644
--- a/python/tvm_ffi/_ffi_api.py
+++ b/python/tvm_ffi/_ffi_api.py
@@ -40,6 +40,15 @@ if TYPE_CHECKING:
def ArraySize(_0: Sequence[Any], /) -> int: ...
def ArrayContains(_0: Sequence[Any], _1: Any, /) -> bool: ...
def Bytes(_0: bytes, /) -> bytes: ...
+ def Dict(*args: Any) -> Any: ...
+ def DictClear(_0: Any, /) -> None: ...
+ def DictCount(_0: Any, _1: Any, /) -> int: ...
+ def DictErase(_0: Any, _1: Any, /) -> None: ...
+ def DictForwardIterFunctor(_0: Any, /) -> Callable[..., Any]: ...
+ def DictGetItem(_0: Any, _1: Any, /) -> Any: ...
+ def DictGetItemOrMissing(_0: Any, _1: Any, /) -> Any: ...
+ def DictSetItem(_0: Any, _1: Any, _2: Any, /) -> None: ...
+ def DictSize(_0: Any, /) -> int: ...
def FromJSONGraph(_0: Any, /) -> Any: ...
def FromJSONGraphString(_0: str, /) -> Any: ...
def GetInvalidObject() -> Any: ...
@@ -103,6 +112,15 @@ __all__ = [
"ArrayGetItem",
"ArraySize",
"Bytes",
+ "Dict",
+ "DictClear",
+ "DictCount",
+ "DictErase",
+ "DictForwardIterFunctor",
+ "DictGetItem",
+ "DictGetItemOrMissing",
+ "DictSetItem",
+ "DictSize",
"FromJSONGraph",
"FromJSONGraphString",
"FunctionListGlobalNamesFunctor",
diff --git a/python/tvm_ffi/_optional_torch_c_dlpack.py
b/python/tvm_ffi/_optional_torch_c_dlpack.py
index ca73f164..594abaee 100644
--- a/python/tvm_ffi/_optional_torch_c_dlpack.py
+++ b/python/tvm_ffi/_optional_torch_c_dlpack.py
@@ -92,6 +92,7 @@ def _check_and_update_dlpack_c_exchange_api(tensor_cls:
object) -> bool:
def load_torch_c_dlpack_extension() -> Any: # noqa: PLR0912, PLR0915
try:
import torch # noqa: PLC0415
+ import torch.version # noqa: PLC0415
if _check_and_update_dlpack_c_exchange_api(torch.Tensor):
# skip loading the extension if the __dlpack_c_exchange_api__
diff --git a/python/tvm_ffi/container.py b/python/tvm_ffi/container.py
index 009c4813..d53366ae 100644
--- a/python/tvm_ffi/container.py
+++ b/python/tvm_ffi/container.py
@@ -42,6 +42,7 @@ if sys.version_info >= (3, 9):
Iterable,
Iterator,
Mapping,
+ MutableMapping,
MutableSequence,
Sequence,
)
@@ -61,6 +62,7 @@ else: # Python 3.8
Iterable,
Iterator,
Mapping,
+ MutableMapping,
MutableSequence,
Sequence,
)
@@ -71,7 +73,7 @@ else: # Python 3.8
ValuesView as ValuesViewBase,
)
-__all__ = ["Array", "List", "Map"]
+__all__ = ["Array", "Dict", "List", "Map"]
T = TypeVar("T")
@@ -365,15 +367,20 @@ class List(core.Object, MutableSequence[T]):
class KeysView(KeysViewBase[K]):
"""Helper class to return keys view."""
- def __init__(self, backend_map: Map[K, V]) -> None:
+ def __init__(
+ self,
+ backend_map: Map[K, V] | Dict[K, V],
+ iter_functor_getter: Callable[..., Callable[[int], Any]] | None = None,
+ ) -> None:
self._backend_map = backend_map
+ self._iter_functor_getter = iter_functor_getter or
_ffi_api.MapForwardIterFunctor
def __len__(self) -> int:
return len(self._backend_map)
def __iter__(self) -> Iterator[K]:
size = len(self._backend_map)
- functor: Callable[[int], Any] =
_ffi_api.MapForwardIterFunctor(self._backend_map)
+ functor: Callable[[int], Any] =
self._iter_functor_getter(self._backend_map)
for _ in range(size):
key = cast(K, functor(0))
yield key
@@ -387,15 +394,20 @@ class KeysView(KeysViewBase[K]):
class ValuesView(ValuesViewBase[V]):
"""Helper class to return values view."""
- def __init__(self, backend_map: Map[K, V]) -> None:
+ def __init__(
+ self,
+ backend_map: Map[K, V] | Dict[K, V],
+ iter_functor_getter: Callable[..., Callable[[int], Any]] | None = None,
+ ) -> None:
self._backend_map = backend_map
+ self._iter_functor_getter = iter_functor_getter or
_ffi_api.MapForwardIterFunctor
def __len__(self) -> int:
return len(self._backend_map)
def __iter__(self) -> Iterator[V]:
size = len(self._backend_map)
- functor: Callable[[int], Any] =
_ffi_api.MapForwardIterFunctor(self._backend_map)
+ functor: Callable[[int], Any] =
self._iter_functor_getter(self._backend_map)
for _ in range(size):
value = cast(V, functor(1))
yield value
@@ -406,15 +418,20 @@ class ValuesView(ValuesViewBase[V]):
class ItemsView(ItemsViewBase[K, V]):
"""Helper class to return items view."""
- def __init__(self, backend_map: Map[K, V]) -> None:
+ def __init__(
+ self,
+ backend_map: Map[K, V] | Dict[K, V],
+ iter_functor_getter: Callable[..., Callable[[int], Any]] | None = None,
+ ) -> None:
self._backend_map = backend_map
+ self._iter_functor_getter = iter_functor_getter or
_ffi_api.MapForwardIterFunctor
def __len__(self) -> int:
return len(self._backend_map)
def __iter__(self) -> Iterator[tuple[K, V]]:
size = len(self._backend_map)
- functor: Callable[[int], Any] =
_ffi_api.MapForwardIterFunctor(self._backend_map)
+ functor: Callable[[int], Any] =
self._iter_functor_getter(self._backend_map)
for _ in range(size):
key = cast(K, functor(0))
value = cast(V, functor(1))
@@ -544,3 +561,120 @@ class Map(core.Object, Mapping[K, V]):
if self.__chandle__() == 0:
return type(self).__name__ + "(chandle=None)"
return str(core.__object_repr__(self)) # ty:
ignore[unresolved-attribute]
+
+
+@register_object("ffi.Dict")
+class Dict(core.Object, MutableMapping[K, V]):
+ """Mutable dictionary container with shared reference semantics.
+
+ Unlike :class:`Map`, ``Dict`` does NOT implement copy-on-write.
+ Mutations happen directly on the underlying shared object.
+ All Python references sharing the same ``Dict`` see mutations immediately.
+
+ Parameters
+ ----------
+ input_dict
+ The dictionary of values to be stored.
+
+ Examples
+ --------
+ .. code-block:: python
+
+ import tvm_ffi
+
+ d = tvm_ffi.Dict({"a": 1, "b": 2})
+ d["c"] = 3
+ assert len(d) == 3
+
+ """
+
+ def __init__(self, input_dict: Mapping[K, V] | None = None) -> None:
+ """Construct a Dict from a Python mapping."""
+ list_kvs: list[Any] = []
+ if input_dict is not None:
+ for k, v in input_dict.items():
+ list_kvs.append(k)
+ list_kvs.append(v)
+ self.__init_handle_by_constructor__(_ffi_api.Dict, *list_kvs)
+
+ def __getitem__(self, k: K) -> V:
+ """Return the value for key `k` or raise KeyError."""
+ return cast(V, _ffi_api.DictGetItem(self, k))
+
+ def __setitem__(self, k: K, v: V) -> None:
+ """Set the value for key `k`."""
+ _ffi_api.DictSetItem(self, k, v)
+
+ def __delitem__(self, k: K) -> None:
+ """Delete the entry for key `k`."""
+ if _ffi_api.DictCount(self, k) == 0:
+ raise KeyError(k)
+ _ffi_api.DictErase(self, k)
+
+ def __contains__(self, k: object) -> bool:
+ """Return True if the dict contains key `k`."""
+ return _ffi_api.DictCount(self, k) != 0
+
+ def __len__(self) -> int:
+ """Return the number of items in the dict."""
+ return _ffi_api.DictSize(self)
+
+ def __bool__(self) -> bool:
+ """Return True if the dict is non-empty."""
+ return len(self) > 0
+
+ def __iter__(self) -> Iterator[K]:
+ """Iterate over the dict's keys."""
+ return iter(self.keys())
+
+ def keys(self) -> KeysView[K]:
+ """Return a dynamic view of the dict's keys."""
+ return KeysView(self, _ffi_api.DictForwardIterFunctor)
+
+ def values(self) -> ValuesView[V]:
+ """Return a dynamic view of the dict's values."""
+ return ValuesView(self, _ffi_api.DictForwardIterFunctor)
+
+ def items(self) -> ItemsView[K, V]:
+ """Get the items from the dict."""
+ return ItemsView(self, _ffi_api.DictForwardIterFunctor)
+
+ @overload
+ def get(self, key: K) -> V | None: ...
+
+ @overload
+ def get(self, key: K, default: V | _DefaultT) -> V | _DefaultT: ...
+
+ def get(self, key: K, default: V | _DefaultT | None = None) -> V |
_DefaultT | None:
+ """Get an element with a default value."""
+ ret = _ffi_api.DictGetItemOrMissing(self, key)
+ if MISSING.same_as(ret):
+ return default
+ return ret
+
+ def pop(self, key: K, *args: V | _DefaultT) -> V | _DefaultT:
+ """Remove and return value for key, or default if not present."""
+ if len(args) > 1:
+ raise TypeError(f"pop expected at most 2 arguments, got {1 +
len(args)}")
+ ret = _ffi_api.DictGetItemOrMissing(self, key)
+ if MISSING.same_as(ret):
+ if args:
+ return args[0]
+ raise KeyError(key)
+ _ffi_api.DictErase(self, key)
+ return cast(V, ret)
+
+ def clear(self) -> None:
+ """Remove all elements from the dict."""
+ _ffi_api.DictClear(self)
+
+ def update(self, other: Mapping[K, V]) -> None: # type: ignore[override]
+ """Update the dict from a mapping."""
+ for k, v in other.items():
+ self[k] = v
+
+ def __repr__(self) -> str:
+ """Return a string representation of the dict."""
+ if self.__chandle__() == 0:
+ return type(self).__name__ + "(chandle=None)"
+ return str(core.__object_repr__(self)) # ty:
ignore[unresolved-attribute]
diff --git a/python/tvm_ffi/cpp/dtype.py b/python/tvm_ffi/cpp/dtype.py
index 4a782b0d..01981109 100644
--- a/python/tvm_ffi/cpp/dtype.py
+++ b/python/tvm_ffi/cpp/dtype.py
@@ -63,6 +63,7 @@ ROCM_DTYPE_MAP = {
def _determine_backend_once() -> Literal["cpu", "cuda", "rocm"]:
try:
import torch # noqa: PLC0415
+ import torch.version # noqa: PLC0415
if torch.cuda.is_available():
if torch.version.cuda is not None:
diff --git a/python/tvm_ffi/cython/base.pxi b/python/tvm_ffi/cython/base.pxi
index 4512936c..e80a4802 100644
--- a/python/tvm_ffi/cython/base.pxi
+++ b/python/tvm_ffi/cython/base.pxi
@@ -152,6 +152,7 @@ cdef extern from "tvm/ffi/c_api.h":
kTVMFFIModule = 73
kTVMFFIOpaquePyObject = 74
kTVMFFIList = 75
+ kTVMFFIDict = 76
ctypedef void* TVMFFIObjectHandle
diff --git a/python/tvm_ffi/cython/type_info.pxi
b/python/tvm_ffi/cython/type_info.pxi
index d0e98a47..e2ecf9e2 100644
--- a/python/tvm_ffi/cython/type_info.pxi
+++ b/python/tvm_ffi/cython/type_info.pxi
@@ -75,6 +75,7 @@ _TYPE_SCHEMA_ORIGIN_CONVERTER = {
"ffi.Array": "list",
"ffi.List": "list",
"ffi.Map": "dict",
+ "ffi.Dict": "dict",
"ffi.OpaquePyObject": "Any",
"ffi.Object": "Object",
"ffi.Tensor": "Tensor",
diff --git a/python/tvm_ffi/registry.py b/python/tvm_ffi/registry.py
index 049ace0f..e39cc44d 100644
--- a/python/tvm_ffi/registry.py
+++ b/python/tvm_ffi/registry.py
@@ -353,7 +353,7 @@ def _add_class_attrs(type_cls: type, type_info: TypeInfo)
-> type:
setattr(type_cls, "__init__", getattr(type_cls, "__ffi_init__"))
elif not issubclass(type_cls, core.PyNativeObject):
setattr(type_cls, "__init__", __init__invalid)
- is_container = type_info.type_key in ("ffi.Array", "ffi.Map")
+ is_container = type_info.type_key in ("ffi.Array", "ffi.Map", "ffi.List",
"ffi.Dict")
_setup_copy_methods(type_cls, has_shallow_copy, is_container=is_container)
return type_cls
diff --git a/python/tvm_ffi/testing/_ffi_api.py
b/python/tvm_ffi/testing/_ffi_api.py
index 8bd11c6b..a83254f2 100644
--- a/python/tvm_ffi/testing/_ffi_api.py
+++ b/python/tvm_ffi/testing/_ffi_api.py
@@ -55,6 +55,8 @@ if TYPE_CHECKING:
def schema_id_bool(_0: bool, /) -> bool: ...
def schema_id_bytes(_0: bytes, /) -> bytes: ...
def schema_id_device(_0: Device, /) -> Device: ...
+ def schema_id_dict_str_int(_0: Mapping[str, int], /) -> Mapping[str, int]:
...
+ def schema_id_dict_str_str(_0: Mapping[str, str], /) -> Mapping[str, str]:
...
def schema_id_dltensor(_0: Tensor, /) -> Tensor: ...
def schema_id_dtype(_0: dtype, /) -> dtype: ...
def schema_id_float(_0: float, /) -> float: ...
@@ -107,6 +109,8 @@ __all__ = [
"schema_id_bool",
"schema_id_bytes",
"schema_id_device",
+ "schema_id_dict_str_int",
+ "schema_id_dict_str_str",
"schema_id_dltensor",
"schema_id_dtype",
"schema_id_float",
diff --git a/src/ffi/container.cc b/src/ffi/container.cc
index 570fd755..2c5cfe13 100644
--- a/src/ffi/container.cc
+++ b/src/ffi/container.cc
@@ -21,6 +21,7 @@
* \file src/ffi/container.cc
*/
#include <tvm/ffi/container/array.h>
+#include <tvm/ffi/container/dict.h>
#include <tvm/ffi/container/list.h>
#include <tvm/ffi/container/map.h>
#include <tvm/ffi/function.h>
@@ -163,7 +164,38 @@ TVM_FFI_STATIC_INIT_BLOCK() {
return GetMissingObject();
}
})
- .def("ffi.GetInvalidObject", []() -> ObjectRef { return
GetMissingObject(); });
+ .def("ffi.GetInvalidObject", []() -> ObjectRef { return
GetMissingObject(); })
+ .def_packed("ffi.Dict",
+ [](ffi::PackedArgs args, Any* ret) {
+ TVM_FFI_ICHECK_EQ(args.size() % 2, 0);
+ Dict<Any, Any> data;
+ for (int i = 0; i < args.size(); i += 2) {
+ data.Set(args[i], args[i + 1]);
+ }
+ *ret = data;
+ })
+ .def("ffi.DictSize",
+ [](const ffi::DictObj* n) -> int64_t { return
static_cast<int64_t>(n->size()); })
+ .def("ffi.DictGetItem", [](const ffi::DictObj* n, const Any& k) -> Any {
return n->at(k); })
+ .def("ffi.DictSetItem",
+ [](ffi::Dict<Any, Any> d, const Any& k, const Any& v) -> void {
d.Set(k, v); })
+ .def("ffi.DictCount",
+ [](const ffi::DictObj* n, const Any& k) -> int64_t {
+ return static_cast<int64_t>(n->count(k));
+ })
+ .def("ffi.DictErase", [](ffi::Dict<Any, Any> d, const Any& k) -> void {
d.erase(k); })
+ .def("ffi.DictClear", [](ffi::Dict<Any, Any> d) -> void { d.clear(); })
+ .def("ffi.DictForwardIterFunctor",
+ [](const ffi::DictObj* n) -> ffi::Function {
+ return ffi::Function::FromTyped(MapForwardIterFunctor(n->begin(),
n->end()));
+ })
+ .def("ffi.DictGetItemOrMissing", [](const ffi::DictObj* n, const Any& k)
-> Any {
+ try {
+ return n->at(k);
+ } catch (const tvm::ffi::Error& e) {
+ return GetMissingObject();
+ }
+ });
}
} // namespace ffi
} // namespace tvm
diff --git a/src/ffi/extra/deep_copy.cc b/src/ffi/extra/deep_copy.cc
index 0b390948..5555c7bc 100644
--- a/src/ffi/extra/deep_copy.cc
+++ b/src/ffi/extra/deep_copy.cc
@@ -23,6 +23,7 @@
*/
#include <tvm/ffi/any.h>
#include <tvm/ffi/container/array.h>
+#include <tvm/ffi/container/dict.h>
#include <tvm/ffi/container/list.h>
#include <tvm/ffi/container/map.h>
#include <tvm/ffi/container/shape.h>
@@ -115,6 +116,18 @@ class ObjectDeepCopier {
copy_map_[obj] = new_map;
return new_map;
}
+ if (ti == TypeIndex::kTVMFFIDict) {
+ // Dict is mutable, so cyclic self-references are possible.
+ // Register the empty copy in copy_map_ before resolving children
+ // so that back-references resolve to the same new Dict.
+ const DictObj* orig = value.as<DictObj>();
+ Dict<Any, Any> new_dict;
+ copy_map_[obj] = new_dict;
+ for (const auto& [k, v] : *orig) {
+ new_dict.Set(Resolve(k), Resolve(v));
+ }
+ return new_dict;
+ }
// General object: shallow-copy, register, and queue for field resolution.
const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(ti);
TVM_FFI_ICHECK((*column_)[ti] != nullptr)
diff --git a/src/ffi/extra/json_writer.cc b/src/ffi/extra/json_writer.cc
index 8abb17b9..3061e312 100644
--- a/src/ffi/extra/json_writer.cc
+++ b/src/ffi/extra/json_writer.cc
@@ -23,6 +23,7 @@
*/
#include <tvm/ffi/any.h>
#include <tvm/ffi/container/array.h>
+#include <tvm/ffi/container/dict.h>
#include <tvm/ffi/container/list.h>
#include <tvm/ffi/container/map.h>
#include <tvm/ffi/error.h>
diff --git a/src/ffi/extra/repr_print.cc b/src/ffi/extra/repr_print.cc
index f7aecd5f..cb357894 100644
--- a/src/ffi/extra/repr_print.cc
+++ b/src/ffi/extra/repr_print.cc
@@ -23,6 +23,7 @@
*/
#include <tvm/ffi/any.h>
#include <tvm/ffi/container/array.h>
+#include <tvm/ffi/container/dict.h>
#include <tvm/ffi/container/list.h>
#include <tvm/ffi/container/map.h>
#include <tvm/ffi/container/shape.h>
@@ -225,8 +226,9 @@ class ReprPrinter {
// Generic reflection-based repr
result = GenericRepr(obj);
}
- // For Array/List: append address if env var is set
+ // For containers: append address if env var is set
if (show_addr_ && (ti == TypeIndex::kTVMFFIArray || ti ==
TypeIndex::kTVMFFIList ||
+ ti == TypeIndex::kTVMFFIMap || ti ==
TypeIndex::kTVMFFIDict ||
ti == TypeIndex::kTVMFFITensor)) {
result += "@" + AddressStr(obj);
}
@@ -335,7 +337,7 @@ String ReprList(const ListObj* obj, const Function&
fn_repr) {
return String(os.str());
}
-String ReprMap(const MapObj* obj, const Function& fn_repr) {
+String ReprMapBase(const MapBaseObj* obj, const Function& fn_repr) {
std::ostringstream os;
os << "{";
bool first = true;
@@ -350,6 +352,14 @@ String ReprMap(const MapObj* obj, const Function& fn_repr)
{
return String(os.str());
}
+String ReprDict(const DictObj* obj, const Function& fn_repr) {
+ return ReprMapBase(static_cast<const MapBaseObj*>(obj), fn_repr);
+}
+
+String ReprMap(const MapObj* obj, const Function& fn_repr) {
+ return ReprMapBase(static_cast<const MapBaseObj*>(obj), fn_repr);
+}
+
/*!
* \brief Register a built-in __ffi_repr__ function for a given type index.
*/
@@ -381,6 +391,7 @@ TVM_FFI_STATIC_INIT_BLOCK() {
RegisterBuiltinRepr(TypeIndex::kTVMFFIArray, ReprArray);
RegisterBuiltinRepr(TypeIndex::kTVMFFIList, ReprList);
RegisterBuiltinRepr(TypeIndex::kTVMFFIMap, ReprMap);
+ RegisterBuiltinRepr(TypeIndex::kTVMFFIDict, ReprDict);
// Register global function
refl::GlobalDef().def("ffi.ReprPrint",
[](const Any& value) -> String { return
ReprPrint(value); });
diff --git a/src/ffi/extra/serialization.cc b/src/ffi/extra/serialization.cc
index f639d87c..2351f7ca 100644
--- a/src/ffi/extra/serialization.cc
+++ b/src/ffi/extra/serialization.cc
@@ -23,6 +23,7 @@
*/
#include <tvm/ffi/any.h>
#include <tvm/ffi/container/array.h>
+#include <tvm/ffi/container/dict.h>
#include <tvm/ffi/container/list.h>
#include <tvm/ffi/container/map.h>
#include <tvm/ffi/container/shape.h>
@@ -133,7 +134,19 @@ class ObjectGraphSerializer {
case TypeIndex::kTVMFFIMap: {
Map<Any, Any> map =
details::AnyUnsafe::CopyFromAnyViewAfterCheck<Map<Any, Any>>(value);
node.Set("type", ffi::StaticTypeKey::kTVMFFIMap);
- node.Set("data", CreateMapData(map));
+ node.Set("data", CreateMapBaseData(static_cast<const
MapBaseObj*>(map.get())));
+ break;
+ }
+ case TypeIndex::kTVMFFIDict: {
+ Dict<Any, Any> dict =
details::AnyUnsafe::CopyFromAnyViewAfterCheck<Dict<Any, Any>>(value);
+ node.Set("type", ffi::StaticTypeKey::kTVMFFIDict);
+ const void* dict_ptr = static_cast<const void*>(dict.get());
+ if (!active_lists_.insert(dict_ptr).second) {
+ TVM_FFI_THROW(ValueError)
+ << "Cycle detected during serialization: a Dict contains itself";
+ }
+ node.Set("data", CreateMapBaseData(static_cast<const
MapBaseObj*>(dict.get())));
+ active_lists_.erase(dict_ptr);
break;
}
case TypeIndex::kTVMFFIShape: {
@@ -169,12 +182,12 @@ class ObjectGraphSerializer {
return data;
}
- json::Array CreateMapData(const Map<Any, Any>& value) {
+ json::Array CreateMapBaseData(const MapBaseObj* value) {
json::Array data;
- data.reserve(static_cast<int64_t>(value.size()) * 2);
- for (const auto& [key, value] : value) {
+ data.reserve(static_cast<int64_t>(value->size()) * 2);
+ for (const auto& [key, val] : *value) {
data.push_back(GetOrCreateNodeIndex(key));
- data.push_back(GetOrCreateNodeIndex(value));
+ data.push_back(GetOrCreateNodeIndex(val));
}
return data;
}
@@ -307,7 +320,10 @@ class ObjectGraphDeserializer {
return Base64Decode(node["data"].cast<String>());
}
case TypeIndex::kTVMFFIMap: {
- return DecodeMapData(node["data"].cast<json::Array>());
+ return DecodeMapLikeData<Map<Any,
Any>>(node["data"].cast<json::Array>());
+ }
+ case TypeIndex::kTVMFFIDict: {
+ return DecodeMapLikeData<Dict<Any,
Any>>(node["data"].cast<json::Array>());
}
case TypeIndex::kTVMFFIArray: {
return
DecodeSequenceData<Array<Any>>(node["data"].cast<json::Array>());
@@ -335,15 +351,16 @@ class ObjectGraphDeserializer {
return sequence;
}
- Map<Any, Any> DecodeMapData(const json::Array& data) {
- Map<Any, Any> map;
+ template <typename MapType>
+ MapType DecodeMapLikeData(const json::Array& data) {
+ MapType result;
const int64_t n = static_cast<int64_t>(data.size());
for (int64_t i = 0; i < n; i += 2) {
int64_t key_index = data[i].cast<int64_t>();
int64_t value_index = data[i + 1].cast<int64_t>();
- map.Set(GetOrDecodeNode(key_index), GetOrDecodeNode(value_index));
+ result.Set(GetOrDecodeNode(key_index), GetOrDecodeNode(value_index));
}
- return map;
+ return result;
}
Any DecodeObjectData(int32_t type_index, const json::Value& data) {
diff --git a/src/ffi/extra/structural_equal.cc
b/src/ffi/extra/structural_equal.cc
index 3e8cb54a..6897ad55 100644
--- a/src/ffi/extra/structural_equal.cc
+++ b/src/ffi/extra/structural_equal.cc
@@ -22,6 +22,7 @@
* \brief Structural equal implementation.
*/
#include <tvm/ffi/container/array.h>
+#include <tvm/ffi/container/dict.h>
#include <tvm/ffi/container/list.h>
#include <tvm/ffi/container/map.h>
#include <tvm/ffi/container/shape.h>
@@ -113,6 +114,10 @@ class StructEqualHandler {
return CompareMap(AnyUnsafe::MoveFromAnyAfterCheck<Map<Any,
Any>>(std::move(lhs)),
AnyUnsafe::MoveFromAnyAfterCheck<Map<Any,
Any>>(std::move(rhs)));
}
+ case TypeIndex::kTVMFFIDict: {
+ return CompareMap(AnyUnsafe::MoveFromAnyAfterCheck<Dict<Any,
Any>>(std::move(lhs)),
+ AnyUnsafe::MoveFromAnyAfterCheck<Dict<Any,
Any>>(std::move(rhs)));
+ }
case TypeIndex::kTVMFFIShape: {
return
CompareShape(AnyUnsafe::MoveFromAnyAfterCheck<Shape>(std::move(lhs)),
AnyUnsafe::MoveFromAnyAfterCheck<Shape>(std::move(rhs)));
@@ -261,8 +266,8 @@ class StructEqualHandler {
}
}
- // NOLINTNEXTLINE(performance-unnecessary-value-param)
- bool CompareMap(Map<Any, Any> lhs, Map<Any, Any> rhs) {
+ template <typename MapType>
+ bool CompareMap(const MapType& lhs, const MapType& rhs) {
if (lhs.size() != rhs.size()) {
// size mismatch, and there is no path tracing
// return false since we don't need informative error message
diff --git a/src/ffi/extra/structural_hash.cc b/src/ffi/extra/structural_hash.cc
index aed6abeb..6b125cae 100644
--- a/src/ffi/extra/structural_hash.cc
+++ b/src/ffi/extra/structural_hash.cc
@@ -22,6 +22,7 @@
* \brief Structural equal implementation.
*/
#include <tvm/ffi/container/array.h>
+#include <tvm/ffi/container/dict.h>
#include <tvm/ffi/container/list.h>
#include <tvm/ffi/container/map.h>
#include <tvm/ffi/container/shape.h>
@@ -85,6 +86,10 @@ class StructuralHashHandler {
case TypeIndex::kTVMFFIMap: {
return HashMap(AnyUnsafe::MoveFromAnyAfterCheck<Map<Any,
Any>>(std::move(src)));
}
+ case TypeIndex::kTVMFFIDict: {
+ Dict<Any, Any> dict = AnyUnsafe::MoveFromAnyAfterCheck<Dict<Any,
Any>>(std::move(src));
+ return HashMapBase(static_cast<const MapBaseObj*>(dict.get()));
+ }
case TypeIndex::kTVMFFIShape: {
return
HashShape(AnyUnsafe::MoveFromAnyAfterCheck<Shape>(std::move(src)));
}
@@ -243,10 +248,14 @@ class StructuralHashHandler {
// NOLINTNEXTLINE(performance-unnecessary-value-param)
uint64_t HashMap(Map<Any, Any> map) {
+ return HashMapBase(static_cast<const MapBaseObj*>(map.get()));
+ }
+
+ uint64_t HashMapBase(const MapBaseObj* map) {
// Compute a deterministic hash value for the map.
- uint64_t hash_value = details::StableHashCombine(map->GetTypeKeyHash(),
map.size());
+ uint64_t hash_value = details::StableHashCombine(map->GetTypeKeyHash(),
map->size());
std::vector<std::pair<uint64_t, Any>> items;
- for (auto [key, value] : map) {
+ for (const auto& [key, value] : *map) {
// if we cannot find order independent hash, we skip the key
if (auto hash_key = FindOrderIndependentHash(key)) {
items.emplace_back(*hash_key, value);
diff --git a/src/ffi/object.cc b/src/ffi/object.cc
index 7e5e00e1..cb43adf8 100644
--- a/src/ffi/object.cc
+++ b/src/ffi/object.cc
@@ -366,6 +366,7 @@ class TypeTable {
ReserveDepthOneObjectTypeIndex(StaticTypeKey::kTVMFFIArray,
TypeIndex::kTVMFFIArray);
ReserveDepthOneObjectTypeIndex(StaticTypeKey::kTVMFFIList,
TypeIndex::kTVMFFIList);
ReserveDepthOneObjectTypeIndex(StaticTypeKey::kTVMFFIMap,
TypeIndex::kTVMFFIMap);
+ ReserveDepthOneObjectTypeIndex(StaticTypeKey::kTVMFFIDict,
TypeIndex::kTVMFFIDict);
ReserveDepthOneObjectTypeIndex(StaticTypeKey::kTVMFFIModule,
TypeIndex::kTVMFFIModule);
ReserveDepthOneObjectTypeIndex(StaticTypeKey::kTVMFFIOpaquePyObject,
TypeIndex::kTVMFFIOpaquePyObject);
diff --git a/src/ffi/testing/testing.cc b/src/ffi/testing/testing.cc
index b9b70bc5..d1a0a161 100644
--- a/src/ffi/testing/testing.cc
+++ b/src/ffi/testing/testing.cc
@@ -22,6 +22,7 @@
#include <dlpack/dlpack.h>
#include <tvm/ffi/any.h>
#include <tvm/ffi/container/array.h>
+#include <tvm/ffi/container/dict.h>
#include <tvm/ffi/container/list.h>
#include <tvm/ffi/container/map.h>
#include <tvm/ffi/container/tensor.h>
@@ -388,6 +389,10 @@ List<int64_t> schema_id_list_int(List<int64_t> lst) {
return lst; }
List<String> schema_id_list_str(List<String> lst) { return lst; }
List<ObjectRef> schema_id_list_obj(List<ObjectRef> lst) { return lst; }
+// Dict types
+Dict<String, int64_t> schema_id_dict_str_int(Dict<String, int64_t> d) { return
d; }
+Dict<String, String> schema_id_dict_str_str(Dict<String, String> d) { return
d; }
+
// Complex nested types
Map<String, Array<int64_t>> schema_arr_map_opt(const Array<Optional<int64_t>>&
arr,
Map<String, Array<int64_t>> mp,
@@ -557,6 +562,8 @@ TVM_FFI_STATIC_INIT_BLOCK() {
.def("testing.schema_id_map_str_str",
schema_test_impl::schema_id_map_str_str)
.def("testing.schema_id_map_str_obj",
schema_test_impl::schema_id_map_str_obj)
.def("testing.schema_id_map", schema_test_impl::schema_id_map)
+ .def("testing.schema_id_dict_str_int",
schema_test_impl::schema_id_dict_str_int)
+ .def("testing.schema_id_dict_str_str",
schema_test_impl::schema_id_dict_str_str)
.def("testing.schema_id_variant_int_str",
schema_test_impl::schema_id_variant_int_str)
.def_packed("testing.schema_packed", [](PackedArgs args, Any* ret) {})
.def("testing.schema_arr_map_opt", schema_test_impl::schema_arr_map_opt)
diff --git a/tests/cpp/extra/test_serialization.cc
b/tests/cpp/extra/test_serialization.cc
index ce6e42cc..664c055a 100644
--- a/tests/cpp/extra/test_serialization.cc
+++ b/tests/cpp/extra/test_serialization.cc
@@ -18,6 +18,7 @@
*/
#include <gtest/gtest.h>
#include <tvm/ffi/container/array.h>
+#include <tvm/ffi/container/dict.h>
#include <tvm/ffi/container/list.h>
#include <tvm/ffi/container/map.h>
#include <tvm/ffi/container/shape.h>
@@ -275,6 +276,77 @@ TEST(Serialization, Maps) {
EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_duplicated),
duplicated_map));
}
+TEST(Serialization, Dicts) {
+ // Test empty dict
+ Dict<String, Any> empty_dict;
+ json::Object expected_empty = json::Object{
+ {"root_index", 0},
+ {"nodes", json::Array{json::Object{{"type", "ffi.Dict"}, {"data",
json::Array{}}}}}};
+ EXPECT_TRUE(StructuralEqual()(ToJSONGraph(empty_dict), expected_empty));
+ EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_empty), empty_dict));
+
+ // Test single element dict
+ Dict<String, Any> single_dict{{"key", 42}};
+ json::Object expected_single = json::Object{
+ {"root_index", 2},
+ {"nodes", json::Array{json::Object{{"type", "ffi.String"}, {"data",
String("key")}},
+ json::Object{{"type", "int"}, {"data", 42}},
+ json::Object{{"type", "ffi.Dict"}, {"data",
json::Array{0, 1}}}}}};
+ EXPECT_TRUE(StructuralEqual()(ToJSONGraph(single_dict), expected_single));
+ EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_single), single_dict));
+
+ // Test duplicated element dict
+ Dict<String, Any> duplicated_dict{{"b", 42}, {"a", 42}};
+ json::Object expected_duplicated = json::Object{
+ {"root_index", 3},
+ {"nodes", json::Array{
+ json::Object{{"type", "ffi.String"}, {"data", "b"}},
+ json::Object{{"type", "int"}, {"data", 42}},
+ json::Object{{"type", "ffi.String"}, {"data", "a"}},
+ json::Object{{"type", "ffi.Dict"}, {"data", json::Array{0,
1, 2, 1}}},
+ }}};
+ EXPECT_TRUE(StructuralEqual()(ToJSONGraph(duplicated_dict),
expected_duplicated));
+ EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_duplicated),
duplicated_dict));
+}
+
+TEST(Serialization, DictWithIntKeys) {
+ Dict<Any, Any> dict;
+ dict.Set(static_cast<int64_t>(1), String("one"));
+ dict.Set(static_cast<int64_t>(2), String("two"));
+
+ json::Value serialized = ToJSONGraph(dict);
+ Any deserialized = FromJSONGraph(serialized);
+ Dict<Any, Any> result = deserialized.cast<Dict<Any, Any>>();
+ EXPECT_EQ(result.size(), 2);
+ EXPECT_EQ(std::string(result[1].cast<String>()), "one");
+ EXPECT_EQ(std::string(result[2].cast<String>()), "two");
+}
+
+TEST(Serialization, DictWithArrayValues) {
+ Array<Any> arr;
+ arr.push_back(10);
+ arr.push_back(20);
+ Dict<String, Any> dict{{"nums", arr}};
+
+ json::Value serialized = ToJSONGraph(dict);
+ Any deserialized = FromJSONGraph(serialized);
+ Dict<String, Any> result = deserialized.cast<Dict<String, Any>>();
+ Array<Any> result_arr = result["nums"].cast<Array<Any>>();
+ EXPECT_EQ(result_arr.size(), 2);
+ EXPECT_EQ(result_arr[0].cast<int64_t>(), 10);
+ EXPECT_EQ(result_arr[1].cast<int64_t>(), 20);
+}
+
+TEST(Serialization, DictOfObjects) {
+ TVar x("x");
+ Dict<String, Any> dict{{"var", x}};
+
+ json::Value serialized = ToJSONGraph(dict);
+ Any deserialized = FromJSONGraph(serialized);
+ Dict<String, Any> result = deserialized.cast<Dict<String, Any>>();
+ EXPECT_EQ(std::string(result["var"].cast<TVar>()->name), "x");
+}
+
TEST(Serialization, Shapes) {
Shape empty_shape;
diff --git a/tests/cpp/extra/test_structural_equal_hash.cc
b/tests/cpp/extra/test_structural_equal_hash.cc
index a768319f..ad081e30 100644
--- a/tests/cpp/extra/test_structural_equal_hash.cc
+++ b/tests/cpp/extra/test_structural_equal_hash.cc
@@ -19,6 +19,7 @@
*/
#include <gtest/gtest.h>
#include <tvm/ffi/container/array.h>
+#include <tvm/ffi/container/dict.h>
#include <tvm/ffi/container/list.h>
#include <tvm/ffi/container/map.h>
#include <tvm/ffi/extra/structural_equal.h>
@@ -118,6 +119,58 @@ TEST(StructuralEqualHash, NestedMapArray) {
refl::AccessPath::Root()->MapItem("b"));
}
+TEST(StructuralEqualHash, Dict) {
+ // same dict but different insertion order
+ Dict<String, int> a = {{"a", 1}, {"b", 2}, {"c", 3}};
+ Dict<String, int> b = {{"b", 2}, {"c", 3}, {"a", 1}};
+ EXPECT_TRUE(StructuralEqual()(a, b));
+ EXPECT_EQ(StructuralHash()(a), StructuralHash()(b));
+
+ Dict<String, int> c = {{"a", 1}, {"b", 2}, {"c", 4}};
+ EXPECT_FALSE(StructuralEqual()(a, c));
+ EXPECT_NE(StructuralHash()(a), StructuralHash()(c));
+
+ auto diff_a_c = StructuralEqual::GetFirstMismatch(a, c);
+ auto expected_diff_a_c =
refl::AccessPathPair(refl::AccessPath::Root()->MapItem("c"),
+
refl::AccessPath::Root()->MapItem("c"));
+ EXPECT_TRUE(diff_a_c.has_value());
+ EXPECT_TRUE(StructuralEqual()(diff_a_c, expected_diff_a_c));
+}
+
+TEST(StructuralEqualHash, NestedDictArray) {
+ Dict<String, Array<Any>> a = {{"a", {1, 2, 3}}, {"b", {4, "hello", 6}}};
+ Dict<String, Array<Any>> b = {{"a", {1, 2, 3}}, {"b", {4, "hello", 6}}};
+ EXPECT_TRUE(StructuralEqual()(a, b));
+ EXPECT_EQ(StructuralHash()(a), StructuralHash()(b));
+
+ Dict<String, Array<Any>> c = {{"a", {1, 2, 3}}, {"b", {4, "world", 6}}};
+ EXPECT_FALSE(StructuralEqual()(a, c));
+ EXPECT_NE(StructuralHash()(a), StructuralHash()(c));
+
+ auto diff_a_c = StructuralEqual::GetFirstMismatch(a, c);
+ auto expected_diff_a_c =
+
refl::AccessPathPair(refl::AccessPath::Root()->MapItem("b")->ArrayItem(1),
+
refl::AccessPath::Root()->MapItem("b")->ArrayItem(1));
+ EXPECT_TRUE(diff_a_c.has_value());
+ EXPECT_TRUE(StructuralEqual()(diff_a_c, expected_diff_a_c));
+
+ Dict<String, Array<Any>> d = {{"a", {1, 2, 3}}};
+ auto diff_a_d = StructuralEqual::GetFirstMismatch(a, d);
+ auto expected_diff_a_d =
refl::AccessPathPair(refl::AccessPath::Root()->MapItem("b"),
+
refl::AccessPath::Root()->MapItemMissing("b"));
+ EXPECT_TRUE(diff_a_d.has_value());
+ EXPECT_TRUE(StructuralEqual()(diff_a_d, expected_diff_a_d));
+}
+
+TEST(StructuralEqualHash, DictVsMapDifferentType) {
+ Map<String, int> m = {{"a", 1}, {"b", 2}};
+ Dict<String, int> d = {{"a", 1}, {"b", 2}};
+ // Different type_index => not equal
+ EXPECT_FALSE(StructuralEqual()(m, d));
+ // Different type_key_hash => different hash (very likely)
+ EXPECT_NE(StructuralHash()(m), StructuralHash()(d));
+}
+
TEST(StructuralEqualHash, FreeVar) {
TVar a = TVar("a");
TVar b = TVar("b");
diff --git a/tests/cpp/test_dict.cc b/tests/cpp/test_dict.cc
new file mode 100644
index 00000000..6e6acaaf
--- /dev/null
+++ b/tests/cpp/test_dict.cc
@@ -0,0 +1,229 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <gtest/gtest.h>
+#include <tvm/ffi/container/dict.h>
+#include <tvm/ffi/container/map.h>
+#include <tvm/ffi/function.h>
+
+namespace {
+
+using namespace tvm::ffi;
+
+TEST(Dict, Basic) {
+ Dict<String, int> d;
+ d.Set("a", 1);
+ d.Set("b", 2);
+
+ EXPECT_EQ(d.size(), 2);
+ EXPECT_EQ(d.at("a"), 1);
+ EXPECT_EQ(d["b"], 2);
+ EXPECT_EQ(d.count("a"), 1);
+ EXPECT_EQ(d.count("c"), 0);
+ EXPECT_FALSE(d.empty());
+}
+
+TEST(Dict, FindAndGet) {
+ Dict<String, int> d;
+ d.Set("x", 42);
+
+ auto it = d.find("x");
+ EXPECT_TRUE(it != d.end());
+ EXPECT_EQ((*it).second, 42);
+
+ auto it2 = d.find("y");
+ EXPECT_TRUE(it2 == d.end());
+
+ auto opt = d.Get("x");
+ ASSERT_TRUE(opt.has_value());
+ EXPECT_EQ(opt.value(), 42); // NOLINT(bugprone-unchecked-optional-access)
+
+ auto opt2 = d.Get("y");
+ EXPECT_FALSE(opt2.has_value());
+}
+
+TEST(Dict, SharedMutation) {
+ // Two handles point to the same DictObj
+ Dict<String, int> d1;
+ d1.Set("a", 1);
+ Dict<String, int> d2 = d1; // shallow copy of ObjectRef
+
+ // Mutate through d1
+ d1.Set("b", 2);
+
+ // d2 should see the change (no COW)
+ EXPECT_EQ(d2.size(), 2);
+ EXPECT_EQ(d2["b"], 2);
+
+ // Same underlying object
+ EXPECT_EQ(d1.get(), d2.get());
+}
+
+TEST(Dict, InplaceSwitchTo) {
+ // Insert >4 elements to trigger transition from small to dense layout.
+ // Verify the ObjectPtr address stays the same.
+ Dict<String, int> d1;
+ d1.Set("a", 1);
+ Dict<String, int> d2 = d1; // alias
+
+ const void* original_ptr = d1.get();
+
+ // Insert enough elements to trigger rehash
+ d1.Set("b", 2);
+ d1.Set("c", 3);
+ d1.Set("d", 4);
+ d1.Set("e", 5);
+ d1.Set("f", 6);
+
+ // ObjectPtr must be stable (InplaceSwitchTo)
+ EXPECT_EQ(static_cast<const void*>(d1.get()), original_ptr);
+ // Alias must point to same object and see all elements
+ EXPECT_EQ(static_cast<const void*>(d2.get()), original_ptr);
+ EXPECT_EQ(d2.size(), 6);
+ EXPECT_EQ(d2["f"], 6);
+}
+
+TEST(Dict, ManyElements) {
+ Dict<int, int> d;
+ for (int i = 0; i < 100; ++i) {
+ d.Set(i, i * 10);
+ }
+ EXPECT_EQ(d.size(), 100);
+ for (int i = 0; i < 100; ++i) {
+ EXPECT_EQ(d[i], i * 10);
+ }
+}
+
+TEST(Dict, Erase) {
+ Dict<String, int> d;
+ d.Set("a", 1);
+ d.Set("b", 2);
+ d.Set("c", 3);
+
+ d.erase("b");
+ EXPECT_EQ(d.size(), 2);
+ EXPECT_EQ(d.count("b"), 0);
+ EXPECT_EQ(d["a"], 1);
+ EXPECT_EQ(d["c"], 3);
+}
+
+TEST(Dict, Clear) {
+ Dict<String, int> d;
+ d.Set("a", 1);
+ d.Set("b", 2);
+ d.clear();
+ EXPECT_EQ(d.size(), 0);
+ EXPECT_TRUE(d.empty());
+}
+
+TEST(Dict, Iteration) {
+ Dict<String, int> d;
+ d.Set("x", 10);
+ d.Set("y", 20);
+
+ int count = 0;
+ for (auto [k, v] : d) {
+ ++count;
+ if (k == "x") {
+ EXPECT_EQ(v, 10);
+ }
+ if (k == "y") {
+ EXPECT_EQ(v, 20);
+ }
+ }
+ EXPECT_EQ(count, 2);
+}
+
+TEST(Dict, PODKeys) {
+ Dict<int, String> d;
+ d.Set(1, "one");
+ d.Set(2, "two");
+ EXPECT_EQ(d[1], "one");
+ EXPECT_EQ(d[2], "two");
+}
+
+TEST(Dict, AnyConversion) {
+ Dict<Any, Any> d;
+ d.Set(String("key"), 42);
+
+ Any any_d = d;
+ auto d2 = any_d.cast<Dict<Any, Any>>();
+ EXPECT_EQ(d2.size(), 1);
+}
+
+TEST(Dict, InitializerList) {
+ Dict<String, int> d{{"a", 1}, {"b", 2}};
+ EXPECT_EQ(d.size(), 2);
+ EXPECT_EQ(d["a"], 1);
+ EXPECT_EQ(d["b"], 2);
+}
+
+TEST(Dict, UpdateExistingKey) {
+ Dict<String, int> d;
+ d.Set("a", 1);
+ d.Set("a", 2);
+ EXPECT_EQ(d.size(), 1);
+ EXPECT_EQ(d["a"], 2);
+}
+
+TEST(Dict, DefaultConstruction) {
+ Dict<String, int> d;
+ EXPECT_EQ(d.size(), 0);
+ EXPECT_TRUE(d.empty());
+ // Set on default-constructed should work
+ d.Set("a", 1);
+ EXPECT_EQ(d.size(), 1);
+}
+
+TEST(Dict, CrossConvMapToDict) {
+ Map<String, int> m{{"a", 1}, {"b", 2}};
+ Any any_m = m;
+ // Cast Map to Dict via Any — triggers cross-conversion
+ auto d = any_m.cast<Dict<String, int>>();
+ EXPECT_EQ(d.size(), 2);
+ EXPECT_EQ(d["a"], 1);
+ EXPECT_EQ(d["b"], 2);
+}
+
+TEST(Dict, CrossConvDictToMap) {
+ Dict<String, int> d{{"x", 10}, {"y", 20}};
+ Any any_d = d;
+ // Cast Dict to Map via Any — triggers cross-conversion
+ auto m = any_d.cast<Map<String, int>>();
+ EXPECT_EQ(m.size(), 2);
+ EXPECT_EQ(m["x"], 10);
+ EXPECT_EQ(m["y"], 20);
+}
+
+TEST(Dict, CrossConvEmptyMapToDict) {
+ Map<String, int> m;
+ Any any_m = m;
+ auto d = any_m.cast<Dict<String, int>>();
+ EXPECT_EQ(d.size(), 0);
+ EXPECT_TRUE(d.empty());
+}
+
+TEST(Dict, CrossConvEmptyDictToMap) {
+ Dict<String, int> d;
+ Any any_d = d;
+ auto m = any_d.cast<Map<String, int>>();
+ EXPECT_EQ(m.size(), 0);
+ EXPECT_TRUE(m.empty());
+}
+
+} // namespace
diff --git a/tests/python/test_container.py b/tests/python/test_container.py
index 64468f41..bd7e1d16 100644
--- a/tests/python/test_container.py
+++ b/tests/python/test_container.py
@@ -514,3 +514,222 @@ def test_missing_object() -> None:
assert m.get("a") == 1
assert m.get("b") is None
assert m.get("b", 42) == 42
+
+
+# ---------------------------------------------------------------------------
+# Dict tests
+# ---------------------------------------------------------------------------
+def test_dict_basic() -> None:
+ d = tvm_ffi.Dict({"a": 1, "b": 2})
+ assert len(d) == 2
+ assert "a" in d
+ assert d["a"] == 1
+ assert d["b"] == 2
+ assert "c" not in d
+
+
+def test_dict_setitem_delitem() -> None:
+ d = tvm_ffi.Dict({"a": 1})
+ d["b"] = 2
+ assert len(d) == 2
+ assert d["b"] == 2
+ del d["a"]
+ assert len(d) == 1
+ assert "a" not in d
+ with pytest.raises(KeyError):
+ del d["nonexistent"]
+
+
+def test_dict_shared_mutation() -> None:
+ d1 = tvm_ffi.Dict({"a": 1})
+ d2 = d1 # alias, same underlying object
+ d1["b"] = 2
+ assert d2["b"] == 2
+ assert len(d2) == 2
+
+
+def test_dict_keys_values_items() -> None:
+ d = tvm_ffi.Dict({"x": 10, "y": 20})
+ keys = set(d.keys())
+ assert keys == {"x", "y"}
+ values = set(d.values())
+ assert values == {10, 20}
+ items = set(d.items())
+ assert items == {("x", 10), ("y", 20)}
+
+
+def test_dict_get_with_default() -> None:
+ d = tvm_ffi.Dict({"a": 1})
+ assert d.get("a") == 1
+ assert d.get("b") is None
+ assert d.get("b", 42) == 42
+
+
+def test_dict_pop() -> None:
+ d = tvm_ffi.Dict({"a": 1, "b": 2})
+ val = d.pop("a")
+ assert val == 1
+ assert len(d) == 1
+ # pop with default
+ val2 = d.pop("nonexistent", 99)
+ assert val2 == 99
+ # pop missing without default
+ with pytest.raises(KeyError):
+ d.pop("nonexistent")
+
+
+def test_dict_clear() -> None:
+ d = tvm_ffi.Dict({"a": 1, "b": 2})
+ d.clear()
+ assert len(d) == 0
+
+
+def test_dict_update() -> None:
+ d = tvm_ffi.Dict({"a": 1})
+ d.update({"b": 2, "c": 3})
+ assert len(d) == 3
+ assert d["b"] == 2
+
+
+def test_dict_iteration() -> None:
+ d = tvm_ffi.Dict({"a": 1, "b": 2})
+ keys = list(d)
+ assert set(keys) == {"a", "b"}
+
+
+def test_dict_iteration_many_entries() -> None:
+ """Regression: ensure iterator visits ALL entries, not just the first."""
+ entries = {f"key_{i}": i for i in range(10)}
+ d = tvm_ffi.Dict(entries)
+ # keys
+ keys_list = list(d.keys())
+ assert len(keys_list) == 10
+ assert set(keys_list) == set(entries.keys())
+ # values
+ values_list = list(d.values())
+ assert len(values_list) == 10
+ assert set(values_list) == set(entries.values())
+ # items
+ items_list = list(d.items())
+ assert len(items_list) == 10
+ assert set(items_list) == set(entries.items())
+ # direct iteration (keys)
+ iter_keys = list(d)
+ assert len(iter_keys) == 10
+ assert set(iter_keys) == set(entries.keys())
+
+
+def test_dict_iteration_after_mutation() -> None:
+ """Regression: iteration after insert/delete visits correct elements."""
+ d = tvm_ffi.Dict({"a": 1, "b": 2, "c": 3})
+ d["d"] = 4
+ del d["b"]
+ # should have a, c, d
+ keys = list(d.keys())
+ assert len(keys) == 3
+ assert set(keys) == {"a", "c", "d"}
+ values = list(d.values())
+ assert len(values) == 3
+ assert set(values) == {1, 3, 4}
+ items = list(d.items())
+ assert len(items) == 3
+ assert set(items) == {("a", 1), ("c", 3), ("d", 4)}
+
+
+def test_dict_repr() -> None:
+ d = tvm_ffi.Dict({"a": 1})
+ r = repr(d)
+ assert isinstance(r, str)
+ # repr should look dict-like
+ assert ":" in r
+
+
+def test_dict_bool() -> None:
+ assert not tvm_ffi.Dict()
+ assert tvm_ffi.Dict({"a": 1})
+
+
+def test_dict_empty_init() -> None:
+ d = tvm_ffi.Dict()
+ assert len(d) == 0
+ d["a"] = 1
+ assert d["a"] == 1
+
+
+def test_dict_pickle_roundtrip() -> None:
+ d = tvm_ffi.Dict({"a": 1, "b": 2})
+ data = pickle.dumps(d)
+ d2 = pickle.loads(data)
+ assert isinstance(d2, tvm_ffi.Dict)
+ assert len(d2) == 2
+ assert d2["a"] == 1
+ assert d2["b"] == 2
+
+
+# ---------------------------------------------------------------------------
+# Map <-> Dict cross-conversion tests
+# ---------------------------------------------------------------------------
+
+
+def test_map_cross_conv_dict_to_map_str_int() -> None:
+ """Dict<String, int> passed to a function expecting Map<String, int>."""
+ d = tvm_ffi.Dict({"a": 1, "b": 2})
+ result = testing.schema_id_map_str_int(d)
+ assert isinstance(result, tvm_ffi.Map)
+ assert dict(result.items()) == {"a": 1, "b": 2}
+
+
+def test_map_cross_conv_map_to_dict_str_int() -> None:
+ """Map<String, int> passed to a function expecting Dict<String, int>."""
+ m = tvm_ffi.Map({"a": 1, "b": 2})
+ result = testing.schema_id_dict_str_int(m)
+ assert isinstance(result, tvm_ffi.Dict)
+ assert result["a"] == 1
+ assert result["b"] == 2
+
+
+def test_map_cross_conv_dict_to_map_str_str() -> None:
+ """Dict<String, String> passed to a function expecting Map<String,
String>."""
+ d = tvm_ffi.Dict({"x": "hello", "y": "world"})
+ result = testing.schema_id_map_str_str(d)
+ assert isinstance(result, tvm_ffi.Map)
+ assert dict(result.items()) == {"x": "hello", "y": "world"}
+
+
+def test_map_cross_conv_map_to_dict_str_str() -> None:
+ """Map<String, String> passed to a function expecting Dict<String,
String>."""
+ m = tvm_ffi.Map({"x": "hello", "y": "world"})
+ result = testing.schema_id_dict_str_str(m)
+ assert isinstance(result, tvm_ffi.Dict)
+ assert result["x"] == "hello"
+ assert result["y"] == "world"
+
+
+def test_map_cross_conv_empty_dict_to_map() -> None:
+ """Empty Dict passed to a function expecting Map<String, int>."""
+ d = tvm_ffi.Dict({})
+ result = testing.schema_id_map_str_int(d)
+ assert isinstance(result, tvm_ffi.Map)
+ assert len(result) == 0
+
+
+def test_map_cross_conv_empty_map_to_dict() -> None:
+ """Empty Map passed to a function expecting Dict<String, int>."""
+ m = tvm_ffi.Map({})
+ result = testing.schema_id_dict_str_int(m)
+ assert isinstance(result, tvm_ffi.Dict)
+ assert len(result) == 0
+
+
+def test_map_cross_conv_incompatible_dict_to_map() -> None:
+ """Dict with incompatible value types should fail when cast to Map<String,
int>."""
+ d = tvm_ffi.Dict({"a": "not_int", "b": "still_not_int"})
+ with pytest.raises(TypeError):
+ testing.schema_id_map_str_int(d) # type: ignore[arg-type]
+
+
+def test_map_cross_conv_incompatible_map_to_dict() -> None:
+ """Map with incompatible value types should fail when cast to Dict<String,
int>."""
+ m = tvm_ffi.Map({"a": "not_int", "b": "still_not_int"})
+ with pytest.raises(TypeError):
+ testing.schema_id_dict_str_int(m) # type: ignore[arg-type]
diff --git a/tests/python/test_copy.py b/tests/python/test_copy.py
index 0a8b28fa..3ec9a55c 100644
--- a/tests/python/test_copy.py
+++ b/tests/python/test_copy.py
@@ -238,6 +238,16 @@ class TestDeepCopy:
assert not m["key"].same_as(m_deep["key"])
assert m_deep["key"].a == 3
+ def test_dict_root(self) -> None:
+ """Deepcopy with a bare Dict as root should create a new dict."""
+ inner = tvm_ffi.testing.TestIntPair(3, 4) # ty:
ignore[too-many-positional-arguments]
+ d = tvm_ffi.Dict({"key": inner})
+ d_deep = copy.deepcopy(d)
+ assert not d.same_as(d_deep)
+ # inner object is deep-copied
+ assert not d["key"].same_as(d_deep["key"])
+ assert d_deep["key"].a == 3
+
def test_auto_deepcopy_for_cxx_class(self) -> None:
# _TestCxxClassBase is copy-constructible, so deepcopy is auto-enabled
# Note: _TestCxxClassBase.__init__ adds 1 to v_i64 and 2 to v_i32
@@ -466,6 +476,36 @@ class TestDeepCopyBranches:
assert not m.same_as(m_deep)
assert len(m_deep) == 0
+ # --- Resolve(): dict with various key/value types ---
+
+ def test_dict_primitive_keys_and_values(self) -> None:
+ d = tvm_ffi.Dict({"a": 1, "b": 2, "c": 3})
+ d_deep = copy.deepcopy(d)
+ assert not d.same_as(d_deep)
+ assert d_deep["a"] == 1
+ assert d_deep["b"] == 2
+ assert d_deep["c"] == 3
+
+ def test_dict_with_container_values(self) -> None:
+ inner_arr = tvm_ffi.convert([1, 2])
+ d = tvm_ffi.Dict({"arr": inner_arr})
+ d_deep = copy.deepcopy(d)
+ assert not d["arr"].same_as(d_deep["arr"])
+ assert list(d_deep["arr"]) == [1, 2]
+
+ def test_dict_with_none_values(self) -> None:
+ d = tvm_ffi.Dict({"a": None, "b": 1})
+ d_deep = copy.deepcopy(d)
+ assert not d.same_as(d_deep)
+ assert d_deep["a"] is None
+ assert d_deep["b"] == 1
+
+ def test_dict_empty(self) -> None:
+ d = tvm_ffi.Dict({})
+ d_deep = copy.deepcopy(d)
+ assert not d.same_as(d_deep)
+ assert len(d_deep) == 0
+
# --- Resolve(): copy_map_ hit (shared references across containers) ---
def test_shared_array_identity_in_outer_array(self) -> None:
@@ -484,6 +524,13 @@ class TestDeepCopyBranches:
assert outer_deep[0].same_as(outer_deep[1])
assert not outer[0].same_as(outer_deep[0])
+ def test_shared_dict_identity_in_outer_array(self) -> None:
+ shared = tvm_ffi.Dict({"x": 1})
+ outer = tvm_ffi.convert([shared, shared])
+ outer_deep = copy.deepcopy(outer)
+ assert outer_deep[0].same_as(outer_deep[1])
+ assert not outer[0].same_as(outer_deep[0])
+
def test_shared_object_across_array_and_map(self) -> None:
"""Same object referenced from both v_array and v_map."""
pair = tvm_ffi.testing.TestIntPair(7, 8) # ty:
ignore[too-many-positional-arguments]
diff --git a/tests/python/test_cubin_launcher.py
b/tests/python/test_cubin_launcher.py
index 5409777b..c953f882 100644
--- a/tests/python/test_cubin_launcher.py
+++ b/tests/python/test_cubin_launcher.py
@@ -27,6 +27,7 @@ import pytest
try:
import torch
+ import torch.version
except ImportError:
torch = None # ty: ignore[invalid-assignment]
diff --git a/tests/python/test_metadata.py b/tests/python/test_metadata.py
index 4d91462d..ec0d0de7 100644
--- a/tests/python/test_metadata.py
+++ b/tests/python/test_metadata.py
@@ -60,6 +60,8 @@ def _replace_list_dict(ty: str) -> str:
("testing.schema_id_map_str_str", "Callable[[dict[str, str]],
dict[str, str]]"),
("testing.schema_id_map_str_obj", "Callable[[dict[str, Object]],
dict[str, Object]]"),
("testing.schema_id_map", "Callable[[dict[Any, Any]], dict[Any,
Any]]"),
+ ("testing.schema_id_dict_str_int", "Callable[[dict[str, int]],
dict[str, int]]"),
+ ("testing.schema_id_dict_str_str", "Callable[[dict[str, str]],
dict[str, str]]"),
("testing.schema_id_variant_int_str", "Callable[[int | str], int |
str]"),
("testing.schema_packed", "Callable[..., Any]"),
(
diff --git a/tests/python/test_optional_torch_c_dlpack.py
b/tests/python/test_optional_torch_c_dlpack.py
index c409da28..6640d6fb 100644
--- a/tests/python/test_optional_torch_c_dlpack.py
+++ b/tests/python/test_optional_torch_c_dlpack.py
@@ -24,6 +24,7 @@ import pytest
try:
import torch
+ import torch.version
except ImportError:
torch = None # ty: ignore[invalid-assignment]
diff --git a/tests/python/test_repr.py b/tests/python/test_repr.py
index bc962117..312c517b 100644
--- a/tests/python/test_repr.py
+++ b/tests/python/test_repr.py
@@ -124,6 +124,39 @@ def test_repr_map_empty() -> None:
assert ReprPrint(tvm_ffi.Map({})) == "{}"
+# ---------- Dict ----------
+
+
+def test_repr_dict() -> None:
+ """Test repr of FFI Dict."""
+ assert ReprPrint(tvm_ffi.Dict({"key": "value"})) == '{"key": "value"}'
+
+
+def test_repr_dict_empty() -> None:
+ """Test repr of empty Dict."""
+ assert ReprPrint(tvm_ffi.Dict({})) == "{}"
+
+
+def test_repr_dict_int_keys() -> None:
+ """Test repr of Dict with integer keys."""
+ d = tvm_ffi.Dict({1: 2, 3: 4})
+ result = ReprPrint(d)
+ # Dict iteration order is hash-dependent; match either ordering.
+ _check(result, r"(?:\{1: 2, 3: 4\}|\{3: 4, 1: 2\})")
+
+
+def test_repr_dict_with_array_values() -> None:
+ """Test repr of Dict with Array values."""
+ assert ReprPrint(tvm_ffi.Dict({1: tvm_ffi.Array([10, 20])})) == "{1: (10,
20)}"
+
+
+def test_repr_dict_with_object_values() -> None:
+ """Test repr of Dict with object values."""
+ pair = tvm_ffi.testing.create_object("testing.TestIntPair", a=1, b=2)
+ d = tvm_ffi.Dict({"obj": pair})
+ assert ReprPrint(d) == '{"obj": testing.TestIntPair(a=1, b=2)}'
+
+
# ---------- Tensor ----------
@@ -527,6 +560,13 @@ def test_repr_with_addr_list(monkeypatch:
pytest.MonkeyPatch) -> None:
_check(ReprPrint(lst), rf"\[10, 20\]@{A}")
+def test_repr_with_addr_dict(monkeypatch: pytest.MonkeyPatch) -> None:
+ """Test that Dict shows address suffix when TVM_FFI_REPR_WITH_ADDR is
set."""
+ monkeypatch.setenv("TVM_FFI_REPR_WITH_ADDR", "1")
+ d = tvm_ffi.Dict({"a": 1})
+ _check(ReprPrint(d), rf'\{{"a": 1\}}@{A}')
+
+
def test_repr_with_addr_dag(monkeypatch: pytest.MonkeyPatch) -> None:
"""Test DAG with addresses: both occurrences show full form with same
address."""
monkeypatch.setenv("TVM_FFI_REPR_WITH_ADDR", "1")
@@ -556,7 +596,7 @@ def test_repr_with_addr_cycle(monkeypatch:
pytest.MonkeyPatch) -> None:
_check(
result,
rf"testing\.TestObjectDerived@(?P<obj>{A})\("
- rf'v_i64=1, v_f64=0, v_str="", v_map=\{{\}}, '
+ rf'v_i64=1, v_f64=0, v_str="", v_map=\{{\}}@{A}, '
rf"v_array=\(\.\.\.@(?P=obj),\)@{A}"
rf"\)",
)
diff --git a/tests/python/test_serialization.py
b/tests/python/test_serialization.py
index d2357fa4..83a9cf96 100644
--- a/tests/python/test_serialization.py
+++ b/tests/python/test_serialization.py
@@ -55,7 +55,7 @@ def _assert_any_equal(a: Any, b: Any) -> None:
assert len(a) == len(b)
for x, y in zip(a, b):
_assert_any_equal(x, y)
- elif isinstance(b, tvm_ffi.Map):
+ elif isinstance(b, (tvm_ffi.Map, tvm_ffi.Dict)):
assert len(a) == len(b)
for k in b:
_assert_any_equal(a[k], b[k])
@@ -329,6 +329,54 @@ class TestMap:
assert list(result["nums"]) == [10, 20]
+class TestDict:
+ """Roundtrip tests for ffi.Dict containers."""
+
+ def test_empty(self) -> None:
+ """Empty dict roundtrips correctly."""
+ d = tvm_ffi.Dict({})
+ _assert_roundtrip_eq(d)
+
+ def test_single_entry(self) -> None:
+ """Single-entry dict roundtrips correctly."""
+ d = tvm_ffi.Dict({"key": 42})
+ result = _roundtrip(d)
+ assert len(result) == 1
+ assert result["key"] == 42
+
+ def test_multiple_entries(self) -> None:
+ """Multi-entry dict roundtrips correctly."""
+ d = tvm_ffi.Dict({"a": 1, "b": 2, "c": 3})
+ result = _roundtrip(d)
+ assert len(result) == 3
+ assert result["a"] == 1
+ assert result["b"] == 2
+ assert result["c"] == 3
+
+ def test_mixed_value_types(self) -> None:
+ """Dict with mixed value types roundtrips correctly."""
+ d = tvm_ffi.Dict({"int": 42, "str": "hello", "bool": True, "none":
None})
+ result = _roundtrip(d)
+ assert result["int"] == 42
+ assert result["str"] == "hello"
+ assert result["bool"] is True
+ assert result["none"] is None
+
+ def test_nested_dict(self) -> None:
+ """Nested dicts roundtrip correctly."""
+ inner = tvm_ffi.Dict({"x": 1})
+ outer = tvm_ffi.Dict({"inner": inner})
+ result = _roundtrip(outer)
+ assert result["inner"]["x"] == 1
+
+ def test_dict_with_array_value(self) -> None:
+ """Dict with array values roundtrips correctly."""
+ arr = tvm_ffi.convert([10, 20])
+ d = tvm_ffi.Dict({"nums": arr})
+ result = _roundtrip(d)
+ assert list(result["nums"]) == [10, 20]
+
+
class TestShape:
"""Roundtrip tests for ffi.Shape containers."""
@@ -448,6 +496,16 @@ class TestJSONStructure:
# key-value pairs flattened: [key_idx, val_idx]
assert len(root["data"]) == 2
+ def test_dict_json_structure(self) -> None:
+ """Dict data contains flattened key-value node index pairs."""
+ s = to_json_graph_str(tvm_ffi.Dict({"a": 1}))
+ parsed = json.loads(s)
+ root = parsed["nodes"][parsed["root_index"]]
+ assert root["type"] == "ffi.Dict"
+ assert isinstance(root["data"], list)
+ # key-value pairs flattened: [key_idx, val_idx]
+ assert len(root["data"]) == 2
+
def test_node_dedup(self) -> None:
"""Duplicate values should share the same node index."""
s = to_json_graph_str(tvm_ffi.convert([42, 42, 42]))
diff --git a/tests/python/test_structural.py b/tests/python/test_structural.py
index df800f7c..952f7369 100644
--- a/tests/python/test_structural.py
+++ b/tests/python/test_structural.py
@@ -59,6 +59,27 @@ def test_structural_key_in_map() -> None:
assert m[k3] == 3
+def test_structural_equal_dict() -> None:
+ d1 = tvm_ffi.Dict({"a": 1, "b": 2, "c": 3})
+ d2 = tvm_ffi.Dict({"c": 3, "b": 2, "a": 1})
+ d3 = tvm_ffi.Dict({"a": 1, "b": 2, "c": 4})
+
+ assert tvm_ffi.structural_equal(d1, d2)
+ assert tvm_ffi.structural_hash(d1) == tvm_ffi.structural_hash(d2)
+ assert not tvm_ffi.structural_equal(d1, d3)
+ assert tvm_ffi.structural_hash(d1) != tvm_ffi.structural_hash(d3)
+ assert tvm_ffi.get_first_structural_mismatch(d1, d2) is None
+ assert tvm_ffi.get_first_structural_mismatch(d1, d3) is not None
+
+
+def test_structural_dict_vs_map_different_type() -> None:
+ m = tvm_ffi.Map({"a": 1, "b": 2})
+ d = tvm_ffi.Dict({"a": 1, "b": 2})
+ # Different type_index => not structurally equal
+ assert not tvm_ffi.structural_equal(m, d)
+ assert tvm_ffi.structural_hash(m) != tvm_ffi.structural_hash(d)
+
+
def test_structural_key_in_python_dict() -> None:
k1 = tvm_ffi.StructuralKey({"name": ["a", "b"], "ver": [1]})
k2 = tvm_ffi.StructuralKey({"ver": [1], "name": ["a", "b"]})
diff --git a/tests/python/test_tensor.py b/tests/python/test_tensor.py
index df8023d1..cd5ac090 100644
--- a/tests/python/test_tensor.py
+++ b/tests/python/test_tensor.py
@@ -24,6 +24,7 @@ import pytest
try:
import torch
+ import torch.version
except ImportError:
torch = None # ty: ignore[invalid-assignment]