This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch refactor-s2
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/refactor-s2 by this push:
new f4eed48b58 [FFI] Cython update map items to use lazy iterator
f4eed48b58 is described below
commit f4eed48b5863c888db98fad85c5b18c37f661e3d
Author: tqchen <[email protected]>
AuthorDate: Tue Apr 29 20:21:20 2025 -0400
[FFI] Cython update map items to use lazy iterator
---
ffi/src/ffi/container.cc | 37 ++++++++++++++----
python/tvm/ffi/container.py | 80 +++++++++++++++++++++++++++++++-------
tests/python/ffi/test_container.py | 23 +++++------
3 files changed, 106 insertions(+), 34 deletions(-)
diff --git a/ffi/src/ffi/container.cc b/ffi/src/ffi/container.cc
index e5ffaadd9f..885e8395ed 100644
--- a/ffi/src/ffi/container.cc
+++ b/ffi/src/ffi/container.cc
@@ -60,14 +60,35 @@ TVM_FFI_REGISTER_GLOBAL("ffi.MapGetItem")
 TVM_FFI_REGISTER_GLOBAL("ffi.MapCount")
     .set_body_typed([](const ffi::MapObj* n, const Any& k) -> int64_t { return n->count(k); });

-TVM_FFI_REGISTER_GLOBAL("ffi.MapItems").set_body_typed([](const ffi::MapObj* n) -> Array<Any> {
-  Array<Any> rkvs;
-  for (const auto& kv : *n) {
-    rkvs.push_back(kv.first);
-    rkvs.push_back(kv.second);
-  }
-  return rkvs;
-});
+TVM_FFI_REGISTER_GLOBAL("ffi.MapForwardIterFunctor")
+    .set_body_typed([](const ffi::MapObj* n) -> ffi::Function {
+      // Forward iterator over the map, driven by an integer command:
+      //   0: return the current key, 1: return the current value,
+      //   2: advance; return true if an entry remains, false at the end.
+      // NOTE: it holds raw iterators and does not retain the map, so the
+      // caller must keep the map alive while the returned function is used.
+      class MapForwardIterFunctor {
+       public:
+        MapForwardIterFunctor(ffi::MapObj::iterator iter, ffi::MapObj::iterator end)
+            : iter_(iter), end_(end) {}
+        Any operator()(int command) const {
+          // Never touch end(): rejects empty maps and exhausted iterators.
+          if (iter_ == end_) {
+            TVM_FFI_THROW(IndexError) << "MapForwardIterFunctor: iterator is exhausted";
+          }
+          if (command == 0) return (*iter_).first;
+          if (command == 1) return (*iter_).second;
+          ++iter_;
+          if (iter_ == end_) return false;
+          return true;
+        }
+
+       private:
+        mutable ffi::MapObj::iterator iter_;
+        ffi::MapObj::iterator end_;
+      };
+      return ffi::Function::FromUnpacked(MapForwardIterFunctor(n->begin(), n->end()));
+    });

 }  // namespace ffi
 }  // namespace tvm
diff --git a/python/tvm/ffi/container.py b/python/tvm/ffi/container.py
index 829cf8cf23..9dcc737078 100644
--- a/python/tvm/ffi/container.py
+++ b/python/tvm/ffi/container.py
@@ -76,9 +76,61 @@ class Array(core.Object):
         return _ffi_api.ArraySize(self)


+class _MapView:
+    """Common base of the lazy views returned by Map.keys/values/items.
+
+    A view only stores a reference to the map; every ``iter()`` call creates
+    a fresh native forward iterator, so a view can be iterated repeatedly.
+    """
+
+    def __init__(self, map):
+        self._map = map
+
+    def __len__(self):
+        return len(self._map)
+
+    def _get(self, functor):
+        """Return the entry at the functor's current position."""
+        raise NotImplementedError()
+
+    def __iter__(self):
+        # The native iterator starts on the first entry, so an empty map
+        # must be skipped before anything is dereferenced.
+        if len(self._map) == 0:
+            return
+        # Functor commands: 0 -> current key, 1 -> current value,
+        # 2 -> advance, returning False once the end has been reached.
+        functor = _ffi_api.MapForwardIterFunctor(self._map)
+        while True:
+            yield self._get(functor)
+            if not functor(2):
+                break
+
+
+class MapKeys(_MapView):
+    """Lazy view over the keys of a Map, in the map's iteration order."""
+
+    def _get(self, functor):
+        return functor(0)
+
+
+class MapValues(_MapView):
+    """Lazy view over the values of a Map, in the map's iteration order."""
+
+    def _get(self, functor):
+        return functor(1)
+
+
+class MapItems(_MapView):
+    """Lazy view over the (key, value) pairs of a Map, in iteration order."""
+
+    def _get(self, functor):
+        return (functor(0), functor(1))
+
+
 @register_object("object.Map")
 class Map(core.Object):
-    """Map container"""
+    """Map container."""

     def __init__(self, input_dict: Dict[Any, Any]):
         list_kvs = []
@@ -93,26 +145,24 @@ class Map(core.Object):
     def __contains__(self, k):
         return _ffi_api.MapCount(self, k) != 0

-    def __iter__(self):
-        akvs = _ffi_api.MapItems(self)
-        for i in range(len(self)):
-            yield akvs[i * 2]
-
-    def __dir__(self):
-        return sorted(dir(self.__class__) + ["type_key"])
-
+    def __iter__(self):
+        """Iterate over the keys of the map, mirroring ``dict``.
+
+        This keeps ``for k in m``, ``set(m)`` and ``list(m)`` working.
+        """
+        return iter(self.keys())
+
     def keys(self):
-        return iter(self)
+        """Return a lazy view over the keys of the map."""
+        return MapKeys(self)

     def values(self):
-        akvs = _ffi_api.MapItems(self)
-        for i in range(len(self)):
-            yield akvs[i * 2 + 1]
+        """Return a lazy view over the values of the map."""
+        return MapValues(self)

     def items(self):
         """Get the items from the map"""
-        akvs = _ffi_api.MapItems(self)
-        return [(akvs[i], akvs[i + 1]) for i in range(0, len(akvs), 2)]
+        return MapItems(self)

     def __len__(self):
         return _ffi_api.MapSize(self)
diff --git a/tests/python/ffi/test_container.py b/tests/python/ffi/test_container.py
index 3a2166dd20..44cb3f321b 100644
--- a/tests/python/ffi/test_container.py
+++ b/tests/python/ffi/test_container.py
@@ -46,17 +46,18 @@ def test_int_map():
     assert 3 in dd
     assert 4 in dd
     assert 5 not in amap
-    assert {x for x in amap} == {3, 4}
-    assert set(amap.keys()) == {3, 4}
-    assert set(amap.values()) == {2, 3}
+    assert tuple(amap.items()) == ((3, 2), (4, 3))
+    assert tuple(amap.keys()) == (3, 4)
+    assert tuple(amap.values()) == (2, 3)


 def test_str_map():
-    amap = tvm_ffi.convert({"a": 2, "b": 3})
-    assert "a" in amap
-    assert len(amap) == 2
-    dd = dict(amap.items())
-    assert amap["a"] == 2
-    assert amap.get("b") == 3
-    assert "a" in dd
-    assert "b" in dd
+    data = [(f"a{i}", i) for i in reversed(range(10))]
+    amap = tvm_ffi.convert(dict(data))
+    assert tuple(amap.items()) == tuple(data)
+    for k, v in data:
+        assert k in amap
+        assert amap[k] == v
+        assert amap.get(k) == v
+    empty = tvm_ffi.convert({})
+    assert (tuple(empty.keys()), tuple(empty.values()), tuple(empty.items())) == ((), (), ())