This is an automated email from the ASF dual-hosted git repository.
mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 2b4a1e2fef [Relax][Frontend] Introduce ModuleDict (#18551)
2b4a1e2fef is described below
commit 2b4a1e2fefb226127b950528689a8b7947ad43bd
Author: Masahiro Hiramori <[email protected]>
AuthorDate: Sun Dec 7 13:24:12 2025 +0900
[Relax][Frontend] Introduce ModuleDict (#18551)
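Add `nn.ModuleDict`, a container that holds submodules in a dict, to the Relax `nn` frontend, and teach parameter-name discovery and `nn.Mutator` to traverse it.
It is modeled on [ModuleDict in PyTorch](https://docs.pytorch.org/docs/stable/generated/torch.nn.ModuleDict.html).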
---
python/tvm/relax/frontend/nn/__init__.py | 2 +-
python/tvm/relax/frontend/nn/core.py | 61 +++++++++++++++++++++++++
python/tvm/relax/frontend/nn/visitor.py | 40 ++++++++++++++--
tests/python/relax/test_frontend_nn_modules.py | 17 +++++++
tests/python/relax/test_frontend_nn_mutator.py | 63 +++++++++++++++++++++++++-
5 files changed, 178 insertions(+), 5 deletions(-)
diff --git a/python/tvm/relax/frontend/nn/__init__.py b/python/tvm/relax/frontend/nn/__init__.py
index f490af7062..d903634883 100644
--- a/python/tvm/relax/frontend/nn/__init__.py
+++ b/python/tvm/relax/frontend/nn/__init__.py
@@ -17,7 +17,7 @@
"""A PyTorch-like API to build IRModules."""
# pylint: disable=redefined-builtin
from . import op, spec
-from .core import Effect, Module, ModuleList, Object, Parameter, Tensor
+from .core import Effect, Module, ModuleDict, ModuleList, Object, Parameter,
Tensor
from .exporter import add_extern
from .extern import ExternModule, ObjectModule, SourceModule
from .modules import (
diff --git a/python/tvm/relax/frontend/nn/core.py b/python/tvm/relax/frontend/nn/core.py
index 8529dda006..b15ba685b7 100644
--- a/python/tvm/relax/frontend/nn/core.py
+++ b/python/tvm/relax/frontend/nn/core.py
@@ -540,6 +540,56 @@ class Module(SubroutineMixin):
raise ValueError(f"Unknown out_format: {out_format}")
+class ModuleDict(Module):
+ """Holds submodules in a dict."""
+
+ def __init__(self, modules: Optional[OrderedDict[str, Module]] = None):
+ if modules is None:
+ self.modules = OrderedDict()
+ else:
+ self.modules = OrderedDict(modules)
+
+ def __iter__(self):
+ return iter(self.modules.values())
+
+ def __getitem__(self, key: str) -> Module:
+ return self.modules[key]
+
+ def __setitem__(self, key: str, module: Module) -> None:
+ self.modules[key] = module
+
+ def __len__(self) -> int:
+ return len(self.modules)
+
+ def keys(self) -> Iterator[str]:
+ return self.modules.keys()
+
+ def values(self) -> Iterator[Module]:
+ return self.modules.values()
+
+ def items(self) -> Iterator[Tuple[str, Module]]:
+ return self.modules.items()
+
+ def get(self, key: str, default: Optional[Module] = None) ->
Optional[Module]:
+ return self.modules.get(key, default)
+
+ def update(self, modules: Dict[str, Module]) -> None:
+ self.modules.update(modules)
+
+ def clear(self) -> None:
+ self.modules.clear()
+
+ def pop(self, key: str) -> Module:
+ return self.modules.pop(key)
+
+ def __contains__(self, key: str) -> bool:
+ return key in self.modules
+
+ def to(self, dtype: Optional[str] = None) -> None: # pylint:
disable=invalid-name
+ for module in self.modules.values():
+ module.to(dtype=dtype)
+
+
class ModuleList(Module):
"""Holds submodules in a list."""
@@ -611,6 +661,10 @@ def _attribute_finder(root: Module, prefix: str, condition_yield: Callable[[Any]
for i, subitem in enumerate(root):
yield from _attribute_finder(subitem, prefix + f"{i}.",
condition_yield)
return
+ elif isinstance(root, ModuleDict):
+ for name, subitem in root.items():
+ yield from _attribute_finder(subitem, prefix + f"{name}.",
condition_yield)
+ return
for name, item in root.__dict__.items():
if condition_yield(item):
yield prefix + name, item
@@ -620,6 +674,13 @@ def _attribute_finder(root: Module, prefix: str, condition_yield: Callable[[Any]
prefix + name + ".",
condition_yield,
)
+ elif isinstance(item, ModuleDict):
+ for sub_name, sub_item in item.items():
+ yield from _attribute_finder(
+ sub_item,
+ prefix + name + f".{sub_name}.",
+ condition_yield,
+ )
elif isinstance(item, Module):
yield from _attribute_finder(
item,
diff --git a/python/tvm/relax/frontend/nn/visitor.py b/python/tvm/relax/frontend/nn/visitor.py
index 82f3010066..d2467a2bf8 100644
--- a/python/tvm/relax/frontend/nn/visitor.py
+++ b/python/tvm/relax/frontend/nn/visitor.py
@@ -79,6 +79,24 @@ class Mutator:
"""
return self.visit(name, node)
+ def visit_moduledict(self, name: str, node: nn.ModuleDict) -> Any:
+ """The base visiting method for mutation of nn.ModuleDict nodes.
+
+ Parameters
+ ----------
+ name : str
+ The name of the current node in parent's attribute.
+
+ node : nn.ModuleDict
+ The current node of nn.ModuleDict to mutate.
+
+ Returns
+ ------
+ ret_node: Any
+ The new node to replace current node.
+ """
+ return self.visit(name, node)
+
def visit_modulelist(self, name: str, node: nn.ModuleList) -> Any:
"""The base visiting method for mutation of nn.ModuleList nodes.
@@ -88,7 +106,7 @@ class Mutator:
The name of the current node in parent's attribute.
node : nn.ModuleList
- The current node of nn.MoModuleListdule to mutate.
+ The current node of nn.ModuleList to mutate.
Returns
------
@@ -124,7 +142,9 @@ class Mutator:
if isinstance(node, nn.ModuleList):
for i in range(len(node)):
- if isinstance(node[i], nn.ModuleList):
+ if isinstance(node[i], nn.ModuleDict):
+ node[i] = self.visit_moduledict(f"{name}.{i}", node[i])
+ elif isinstance(node[i], nn.ModuleList):
node[i] = self.visit_modulelist(f"{name}.{i}", node[i])
elif isinstance(node[i], nn.Module):
node[i] = self.visit_module(f"{name}.{i}", node[i])
@@ -132,9 +152,23 @@ class Mutator:
node[i] = self.visit_effect(f"{name}.{i}", node[i])
elif isinstance(node[i], nn.Parameter):
node[i] = self.visit_param(f"{name}.{i}", node[i])
+ elif isinstance(node, nn.ModuleDict):
+ for k, v in node.items():
+ if isinstance(v, nn.ModuleDict):
+ node[k] = self.visit_moduledict(_get_child_name(name, k),
v)
+ elif isinstance(v, nn.ModuleList):
+ node[k] = self.visit_modulelist(_get_child_name(name, k),
v)
+ elif isinstance(v, nn.Module):
+ node[k] = self.visit_module(_get_child_name(name, k), v)
+ elif isinstance(v, nn.Effect):
+ node[k] = self.visit_effect(_get_child_name(name, k), v)
+ elif isinstance(v, nn.Parameter):
+ node[k] = self.visit_param(_get_child_name(name, k), v)
else:
for key, value in node.__dict__.items():
- if isinstance(value, nn.ModuleList):
+ if isinstance(value, nn.ModuleDict):
+ setattr(node, key,
self.visit_moduledict(_get_child_name(name, key), value))
+ elif isinstance(value, nn.ModuleList):
setattr(node, key,
self.visit_modulelist(_get_child_name(name, key), value))
elif isinstance(value, nn.Module):
setattr(node, key, self.visit_module(_get_child_name(name,
key), value))
diff --git a/tests/python/relax/test_frontend_nn_modules.py b/tests/python/relax/test_frontend_nn_modules.py
index 23250f28aa..e9a4a6f624 100644
--- a/tests/python/relax/test_frontend_nn_modules.py
+++ b/tests/python/relax/test_frontend_nn_modules.py
@@ -715,5 +715,22 @@ def test_module_list():
assert ["layers.0.0.weight", "layers.0.1.weight"] ==
sorted(list(named_params.keys()))
+def test_module_dict():
+ class Module(nn.Module):
+ def __init__(self):
+ self.layers = nn.ModuleDict(
+ {"linear0": nn.Linear(4, 4, bias=False), "linear1":
nn.Linear(4, 4, bias=False)}
+ )
+
+ def forward(self, x: nn.Tensor):
+ x = self.layers["linear0"](x)
+ x = self.layers["linear1"](x)
+ return x
+
+ mod = Module()
+ named_params = dict(mod.named_parameters())
+ assert ["layers.linear0.weight", "layers.linear1.weight"] ==
sorted(list(named_params.keys()))
+
+
if __name__ == "__main__":
tvm.testing.main()
diff --git a/tests/python/relax/test_frontend_nn_mutator.py b/tests/python/relax/test_frontend_nn_mutator.py
index ffb6586159..253e24a4ed 100644
--- a/tests/python/relax/test_frontend_nn_mutator.py
+++ b/tests/python/relax/test_frontend_nn_mutator.py
@@ -65,6 +65,37 @@ def test_mutator_naming_basic():
mutator.visit("mod3", mod3)
+def test_mutator_naming_moduledict():
+ class Module(nn.Module):
+ def __init__(self, dtype) -> None:
+ super().__init__()
+ self.param = nn.Parameter((32, 128), dtype)
+
+ class Mutator(nn.Mutator):
+ def visit_param(self, name: str, node: nn.Parameter) -> Any:
+ if node.dtype == "float64":
+ assert name == "mod_dict.k0.0.param"
+ return node
+ elif node.dtype == "float32":
+ assert name == "mod_dict.k0.1.param"
+ return node
+ elif node.dtype == "float16":
+ assert name == "mod_dict.k1.0.param"
+ return node
+ elif node.dtype == "float8":
+ assert name == "mod_dict.k1.1.param"
+ return node
+
+ mod_dict = nn.ModuleDict(
+ {
+ "k0": nn.ModuleList([Module("float64"), Module("float32")]),
+ "k1": nn.ModuleList([Module("float16"), Module("float8")]),
+ }
+ )
+ mutator = Mutator()
+ mutator.visit("mod_dict", mod_dict)
+
+
def test_mutator_naming_modulelist():
class Module(nn.Module):
def __init__(self, dtype) -> None:
@@ -124,6 +155,37 @@ def test_mutator_module():
assert isinstance(module.mod, SubModule2)
+def test_mutator_moduledict():
+ class Module1(nn.Module):
+ def __init__(self) -> None:
+ super().__init__()
+
+ class Module2(nn.Module):
+ def __init__(self) -> None:
+ super().__init__()
+
+ class Module3(nn.Module):
+ def __init__(self) -> None:
+ super().__init__()
+
+ class Mutator(nn.Mutator):
+ def visit_module(self, name: str, node: nn.Module) -> Any:
+ if isinstance(node, Module3):
+ return Module1()
+ else:
+ return node
+
+ mutator = Mutator()
+ module_dict = nn.ModuleDict({"k0": Module1(), "k1": Module2(), "k2":
Module3()})
+ assert isinstance(module_dict["k0"], Module1)
+ assert isinstance(module_dict["k1"], Module2)
+ assert isinstance(module_dict["k2"], Module3)
+ module_dict = mutator.visit("", module_dict)
+ assert isinstance(module_dict["k0"], Module1)
+ assert isinstance(module_dict["k1"], Module2)
+ assert isinstance(module_dict["k2"], Module1)
+
+
def test_mutator_modulelist():
class Module1(nn.Module):
def __init__(self) -> None:
@@ -150,7 +212,6 @@ def test_mutator_modulelist():
assert isinstance(module_list[1], Module2)
assert isinstance(module_list[2], Module3)
module_list = mutator.visit("", module_list)
- print(module_list[2])
assert isinstance(module_list[0], Module1)
assert isinstance(module_list[1], Module2)
assert isinstance(module_list[2], Module1)