This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git


The following commit(s) were added to refs/heads/main by this push:
     new 438f643  feat: Fix perf issue in `Map.get` (#341)
438f643 is described below

commit 438f6439148b059d424ce2cc2a348736923f6948
Author: Junru Shao <[email protected]>
AuthorDate: Fri Dec 12 13:08:38 2025 -0800

    feat: Fix perf issue in `Map.get` (#341)
    
    Should fix #326
---
 python/tvm_ffi/_ffi_api.py  |  6 +++++-
 python/tvm_ffi/container.py | 17 +++++++++--------
 src/ffi/container.cc        | 18 ++++++++++++++++--
 3 files changed, 30 insertions(+), 11 deletions(-)

diff --git a/python/tvm_ffi/_ffi_api.py b/python/tvm_ffi/_ffi_api.py
index 4b49716..f9850a5 100644
--- a/python/tvm_ffi/_ffi_api.py
+++ b/python/tvm_ffi/_ffi_api.py
@@ -23,7 +23,7 @@ from __future__ import annotations
 from typing import Any, Callable, TYPE_CHECKING
 if TYPE_CHECKING:
     from collections.abc import Mapping, Sequence
-    from tvm_ffi import Module
+    from tvm_ffi import Module, Object
     from tvm_ffi.access_path import AccessPath
 # isort: on
 # fmt: on
@@ -50,6 +50,8 @@ if TYPE_CHECKING:
     def MapCount(_0: Mapping[Any, Any], _1: Any, /) -> int: ...
     def MapForwardIterFunctor(_0: Mapping[Any, Any], /) -> Callable[..., Any]: ...
     def MapGetItem(_0: Mapping[Any, Any], _1: Any, /) -> Any: ...
+    def MapGetItemOrMissing(_0: Mapping[Any, Any], _1: Any, /) -> Any: ...
+    def MapGetMissingObject() -> Object: ...
     def MapSize(_0: Mapping[Any, Any], /) -> int: ...
     def ModuleClearImports(_0: Module, /) -> None: ...
     def ModuleGetFunction(_0: Module, _1: str, _2: bool, /) -> Callable[..., Any] | None: ...
@@ -95,6 +97,8 @@ __all__ = [
     "MapCount",
     "MapForwardIterFunctor",
     "MapGetItem",
+    "MapGetItemOrMissing",
+    "MapGetMissingObject",
     "MapSize",
     "ModuleClearImports",
     "ModuleGetFunction",
diff --git a/python/tvm_ffi/container.py b/python/tvm_ffi/container.py
index 06fb92e..dfa0d22 100644
--- a/python/tvm_ffi/container.py
+++ b/python/tvm_ffi/container.py
@@ -77,6 +77,8 @@ K = TypeVar("K")
 V = TypeVar("V")
 _DefaultT = TypeVar("_DefaultT")
 
+MISSING = _ffi_api.MapGetMissingObject()
+
 
 def getitem_helper(
     obj: Any,
@@ -254,12 +256,11 @@ class ItemsView(ItemsViewBase[K, V]):
         if not isinstance(item, tuple) or len(item) != 2:
             return False
         key, value = item
-        try:
-            existing_value = self._backend_map[key]
-        except KeyError:
+        actual_value = self._backend_map.get(key, MISSING)
+        if actual_value is MISSING:
             return False
-        else:
-            return existing_value == value
+        # TODO(@junrus): Is `__eq__` the right method to use here?
+        return actual_value == value
 
 
 @register_object("ffi.Map")
@@ -349,10 +350,10 @@ class Map(core.Object, Mapping[K, V]):
             The result value.
 
         """
-        try:
-            return self[key]
-        except KeyError:
+        ret = _ffi_api.MapGetItemOrMissing(self, key)
+        if MISSING.same_as(ret):
             return default
+        return ret
 
     def __repr__(self) -> str:
         """Return a string representation of the map."""
diff --git a/src/ffi/container.cc b/src/ffi/container.cc
index b777dc0..57eda37 100644
--- a/src/ffi/container.cc
+++ b/src/ffi/container.cc
@@ -55,6 +55,11 @@ class MapForwardIterFunctor {
   ffi::MapObj::iterator end_;
 };
 
+ObjectRef GetMissingObject() {
+  static ObjectRef missing_obj(make_object<Object>());
+  return missing_obj;
+}
+
 TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef()
@@ -81,8 +86,17 @@ TVM_FFI_STATIC_INIT_BLOCK() {
            [](const ffi::MapObj* n, const Any& k) -> int64_t {
              return static_cast<int64_t>(n->count(k));
            })
-      .def("ffi.MapForwardIterFunctor", [](const ffi::MapObj* n) -> ffi::Function {
-        return ffi::Function::FromTyped(MapForwardIterFunctor(n->begin(), n->end()));
+      .def("ffi.MapForwardIterFunctor",
+           [](const ffi::MapObj* n) -> ffi::Function {
+             return ffi::Function::FromTyped(MapForwardIterFunctor(n->begin(), n->end()));
+           })
+      .def("ffi.MapGetMissingObject", GetMissingObject)
+      .def("ffi.MapGetItemOrMissing", [](const ffi::MapObj* n, const Any& k) -> Any {
+        try {
+          return n->at(k);
+        } catch (const tvm::ffi::Error& e) {
+          return GetMissingObject();
+        }
       });
 }
 }  // namespace ffi

Reply via email to