gemini-code-assist[bot] commented on code in PR #18551:
URL: https://github.com/apache/tvm/pull/18551#discussion_r2593587859
##########
python/tvm/relax/frontend/nn/core.py:
##########
@@ -540,6 +542,53 @@ def _compile(spec, device, pipeline, debug):
raise ValueError(f"Unknown out_format: {out_format}")
+class ModuleDict(Module):
+    """Holds submodules in a dict."""
+
+    def __init__(self, modules: OrderedDict[str, Module]):
+        self.modules = OrderedDict(modules)
Review Comment:

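The `__init__` method for `ModuleDict` is a bit restrictive: `modules` is a
required argument, and its annotation advertises only `OrderedDict`, even
though `OrderedDict(modules)` already accepts a plain `dict` at runtime (the
new test passes one). To align better with PyTorch's `ModuleDict` and improve
usability, I recommend making the `modules` argument optional and annotating
it with a general dictionary type. This would allow creating an empty
`ModuleDict` with `nn.ModuleDict()` and make it clear that regular
dictionaries can be passed during initialization. `Optional` and `Dict` need
to be imported from `typing` if `core.py` does not import them already.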
```suggestion
    def __init__(self, modules: Optional[Dict[str, Module]] = None):
        if modules is None:
            self.modules = OrderedDict()
        else:
            self.modules = OrderedDict(modules)
```
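For illustration, here is a minimal, TVM-independent sketch of a container using the suggested constructor. The stand-in `Module` class and the name `ModuleDictSketch` are hypothetical, and the `__getitem__`/`__setitem__`/`__len__`/`__iter__` methods are assumptions about what the PR's `ModuleDict` provides (the new test only relies on `__getitem__`); this is not the implementation in `core.py`.

```python
from collections import OrderedDict
from typing import Dict, Iterator, Optional


class Module:
    """Hypothetical stand-in for nn.Module (not TVM's class)."""


class ModuleDictSketch(Module):
    """Hypothetical sketch of a ModuleDict with the suggested constructor."""

    def __init__(self, modules: Optional[Dict[str, Module]] = None):
        # OrderedDict(None) raises TypeError, so a missing argument must be
        # mapped to an empty OrderedDict explicitly.
        if modules is None:
            self.modules = OrderedDict()
        else:
            # Accepts a plain dict (or any mapping) and keeps its insertion order.
            self.modules = OrderedDict(modules)

    def __getitem__(self, key: str) -> Module:
        return self.modules[key]

    def __setitem__(self, key: str, module: Module) -> None:
        self.modules[key] = module

    def __len__(self) -> int:
        return len(self.modules)

    def __iter__(self) -> Iterator[str]:
        # Iterate over keys, as iterating a regular dict does.
        return iter(self.modules)


# Both construction styles from the review comment work:
empty = ModuleDictSketch()
layers = ModuleDictSketch({"linear0": Module(), "linear1": Module()})
print(len(empty), list(layers))  # 0 ['linear0', 'linear1']
```

Because Python dicts preserve insertion order, converting a plain `dict` literal to an `OrderedDict` keeps the keys in the order they were written, so iteration order stays deterministic.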
##########
tests/python/relax/test_frontend_nn_modules.py:
##########
@@ -715,5 +715,22 @@ def forward(self, x: nn.Tensor):
assert ["layers.0.0.weight", "layers.0.1.weight"] == sorted(list(named_params.keys()))
+def test_module_dict():
+    class Module(nn.Module):
+        def __init__(self):
+            self.layers = nn.ModuleDict(
+                {"linear0": nn.Linear(4, 4, bias=False), "linear1": nn.Linear(4, 4, bias=False)}
+            )
+
+        def forward(self, x: nn.Tensor):
+            x = self.layers["linear0"](x)
+            x = self.layers["linear1"](x)
+            return x
+
+    mod = Module()
+    named_params = dict(mod.named_parameters())
+    assert ["layers.linear0.weight", "layers.linear1.weight"] == sorted(list(named_params.keys()))
Review Comment:

This test is a good start for verifying `named_parameters`. However, to
ensure `ModuleDict` is fully integrated and works correctly during compilation,
I suggest extending this test to cover the `forward` pass. You could export the
module to a TVM IRModule and use `assert_structural_equal` against an expected
Relax function, similar to how other modules are tested in this file. This
would provide a stronger guarantee of correctness.
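Independently of the export-based check, the names asserted in the new test (`layers.linear0.weight`, `layers.linear1.weight`) can be understood with a small pure-Python sketch: parameter paths are built by joining attribute names and `ModuleDict` keys with dots. The classes below are hypothetical stand-ins, not TVM's `nn.Module`, `nn.Linear`, or `nn.Parameter`, and the traversal is one plausible scheme rather than the one in `core.py`.

```python
from collections import OrderedDict
from typing import Dict, Iterator, Optional, Tuple


class Parameter:
    """Hypothetical stand-in for nn.Parameter."""


def _walk(value, path: str) -> Iterator[Tuple[str, Parameter]]:
    # Parameters are leaves; submodules are recursed into with a "." separator.
    if isinstance(value, Parameter):
        yield path, value
    elif isinstance(value, Module):
        yield from value.named_parameters(path + ".")


class Module:
    """Hypothetical stand-in for nn.Module."""

    def named_parameters(self, prefix: str = "") -> Iterator[Tuple[str, Parameter]]:
        # Attribute names become path components (e.g. "layers").
        for name, value in vars(self).items():
            yield from _walk(value, prefix + name)


class ModuleDict(Module):
    def __init__(self, modules: Optional[Dict[str, Module]] = None):
        self.modules = OrderedDict() if modules is None else OrderedDict(modules)

    def named_parameters(self, prefix: str = "") -> Iterator[Tuple[str, Parameter]]:
        # Dict keys (e.g. "linear0") become path components; the internal
        # attribute name "modules" is deliberately left out of the path.
        for key, module in self.modules.items():
            yield from _walk(module, prefix + key)


class Linear(Module):
    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        # Shapes are omitted; only the parameter names matter here.
        self.weight = Parameter()
        self.bias = Parameter() if bias else None  # None is skipped by _walk


class ModuleWithDict(Module):
    def __init__(self):
        self.layers = ModuleDict(
            {"linear0": Linear(4, 4, bias=False), "linear1": Linear(4, 4, bias=False)}
        )


names = sorted(name for name, _ in ModuleWithDict().named_parameters())
print(names)  # ['layers.linear0.weight', 'layers.linear1.weight']
```

Note that `ModuleDict` overrides `named_parameters` so that its dict keys, not its internal `modules` attribute, appear in the path; without that override the names would come out as `layers.modules.linear0.weight`.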
##########
python/tvm/relax/frontend/nn/core.py:
##########
@@ -60,6 +60,8 @@
from .subroutine import SubroutineMixin
if TYPE_CHECKING:
+    from collections.abc import Iterable, Iterator, Mapping
Review Comment:

The `Mapping` type is imported but not used within this file. Please remove
it to avoid clutter.
```suggestion
    from collections.abc import Iterable, Iterator
```
##########
python/tvm/relax/frontend/nn/core.py:
##########
@@ -23,7 +23,7 @@
- Effect, a non-user-facing class that encloses potential side effects, for example, IO,
impure external function callings, inplace mutation, etc.
"""
-from collections import OrderedDict
+from collections import abc as container_abcs, OrderedDict
Review Comment:

The imported name `container_abcs` is not used anywhere in this file. To
keep the code clean, it's best to remove unused imports.
```suggestion
from collections import OrderedDict
```
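Linters such as pylint (`unused-import`) or flake8 (`F401`) report both this `container_abcs` import and the `Mapping` import flagged above. As a rough, hypothetical sketch of what such a check does, the helper below (`unused_imports` is an illustrative name, not an existing tool) uses the standard `ast` module to list imported names that never appear as a `Name` node. It is a simplification: names referenced only inside string annotations are not seen and would be reported as unused.

```python
import ast


def unused_imports(source: str) -> list:
    """Return imported names that never appear as a Name node in `source`.

    Simplification: names used only inside string annotations
    (e.g. "Mapping[str, int]") are not detected as used.
    """
    tree = ast.parse(source)
    imported = []
    for node in ast.walk(tree):
        if isinstance(node, (ast.Import, ast.ImportFrom)):
            for alias in node.names:
                # `import a.b` binds `a`; `... as x` binds `x`.
                imported.append(alias.asname or alias.name.split(".")[0])
    # Attribute access such as `container_abcs.Mapping` still contains a Name node.
    used = {node.id for node in ast.walk(tree) if isinstance(node, ast.Name)}
    return [name for name in imported if name not in used]


src = '''
from collections import abc as container_abcs, OrderedDict
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from collections.abc import Iterable, Iterator, Mapping


def keys(d: OrderedDict) -> Iterator:
    return iter(d)


def total(xs: Iterable) -> int:
    return sum(xs)
'''

print(unused_imports(src))  # ['container_abcs', 'Mapping']
```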
##########
tests/python/relax/test_frontend_nn_modules.py:
##########
@@ -715,5 +715,22 @@ def forward(self, x: nn.Tensor):
assert ["layers.0.0.weight", "layers.0.1.weight"] == sorted(list(named_params.keys()))
+def test_module_dict():
+    class Module(nn.Module):
Review Comment:

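The class name `Module` is the same as the name of its base class
`nn.Module`. To improve clarity and avoid potential confusion, it's better to
use a more descriptive name for your test module, such as `ModuleWithDict`.
The instantiation `mod = Module()` later in the test needs to be renamed as
well.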
```suggestion
    class ModuleWithDict(nn.Module):
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.