gemini-code-assist[bot] commented on code in PR #18551:
URL: https://github.com/apache/tvm/pull/18551#discussion_r2593587859


##########
python/tvm/relax/frontend/nn/core.py:
##########
@@ -540,6 +542,53 @@ def _compile(spec, device, pipeline, debug):
         raise ValueError(f"Unknown out_format: {out_format}")
 
 
+class ModuleDict(Module):
+    """Holds submodules in a dict."""
+
+    def __init__(self, modules: OrderedDict[str, Module]):
+        self.modules = OrderedDict(modules)

Review Comment:
   ![high](https://www.gstatic.com/codereviewagent/high-priority.svg)
   
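   The `__init__` method for `ModuleDict` is more restrictive than it needs to be: `modules` is a required argument, and its annotation asks for an `OrderedDict`. To align better with PyTorch's `ModuleDict`, whose `modules` argument is optional, and to improve usability, I recommend making `modules` optional and accepting a more general mapping type.
   
   This would allow creating an empty `ModuleDict` with `nn.ModuleDict()`. It would also make passing a regular `dict` match the annotation. Such calls already work at runtime, because the argument is copied with `OrderedDict(modules)`.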
   
   ```suggestion
       def __init__(self, modules: Optional[Dict[str, Module]] = None):
           if modules is None:
               self.modules = OrderedDict()
           else:
               self.modules = OrderedDict(modules)
   ```
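
A minimal, self-contained sketch of the suggested constructor, assuming PyTorch-style dictionary access. `Module`, `Parameter`, `Linear`, and the attribute-walking `named_parameters` are simplified hypothetical stand-ins written for illustration, not TVM's classes. `__setitem__` is an assumed addition that makes an empty `ModuleDict` useful. `modules` is annotated with `Mapping`, as PyTorch's `ModuleDict` does, rather than `Dict`, so any read-only mapping is also accepted.

```python
from collections import OrderedDict
from typing import Iterator, Mapping, Optional, Tuple


class Parameter:
    """Hypothetical stand-in for a weight tensor; only records its shape."""

    def __init__(self, shape: Tuple[int, ...]):
        self.shape = shape


class Module:
    """Hypothetical stand-in base class, not tvm.relax.frontend.nn.Module."""

    def named_parameters(self, prefix: str = "") -> Iterator[Tuple[str, Parameter]]:
        # Walk instance attributes in assignment order, building dotted names.
        for name, value in vars(self).items():
            if isinstance(value, Parameter):
                yield prefix + name, value
            elif isinstance(value, Module):
                yield from value.named_parameters(prefix + name + ".")


class Linear(Module):
    """Hypothetical stand-in for a bias-free linear layer."""

    def __init__(self, in_features: int, out_features: int):
        self.weight = Parameter((out_features, in_features))


class ModuleDict(Module):
    """Holds submodules in a dict."""

    def __init__(self, modules: Optional[Mapping[str, Module]] = None):
        # None means "start empty"; any mapping (dict, OrderedDict, ...) is copied.
        if modules is None:
            self.modules = OrderedDict()
        else:
            self.modules = OrderedDict(modules)

    def __getitem__(self, key: str) -> Module:
        return self.modules[key]

    def __setitem__(self, key: str, module: Module) -> None:
        # Assumed PyTorch-style insertion; lets an empty ModuleDict be filled later.
        self.modules[key] = module

    def __len__(self) -> int:
        return len(self.modules)

    def __iter__(self) -> Iterator[str]:
        return iter(self.modules)

    def named_parameters(self, prefix: str = "") -> Iterator[Tuple[str, Parameter]]:
        # Use the dict keys directly as name components, skipping the
        # internal `modules` attribute name.
        for key, module in self.modules.items():
            yield from module.named_parameters(f"{prefix}{key}.")


class ModuleWithDict(Module):
    def __init__(self):
        self.layers = ModuleDict({"linear0": Linear(4, 4), "linear1": Linear(4, 4)})


empty = ModuleDict()  # allowed once `modules` is optional
empty["proj"] = Linear(4, 8)
print(len(empty), sorted(name for name, _ in ModuleWithDict().named_parameters()))
# prints: 1 ['layers.linear0.weight', 'layers.linear1.weight']
```

Because this `ModuleDict` overrides `named_parameters` to iterate its own keys, parameters come out as `layers.linear0.weight` rather than `layers.modules.linear0.weight`, which is the naming that `test_module_dict` in this PR asserts.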



##########
tests/python/relax/test_frontend_nn_modules.py:
##########
@@ -715,5 +715,22 @@ def forward(self, x: nn.Tensor):
     assert ["layers.0.0.weight", "layers.0.1.weight"] == sorted(list(named_params.keys()))
 
 
+def test_module_dict():
+    class Module(nn.Module):
+        def __init__(self):
+            self.layers = nn.ModuleDict(
+                {"linear0": nn.Linear(4, 4, bias=False), "linear1": nn.Linear(4, 4, bias=False)}
+            )
+
+        def forward(self, x: nn.Tensor):
+            x = self.layers["linear0"](x)
+            x = self.layers["linear1"](x)
+            return x
+
+    mod = Module()
+    named_params = dict(mod.named_parameters())
+    assert ["layers.linear0.weight", "layers.linear1.weight"] == sorted(list(named_params.keys()))

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   This test is a good start for verifying `named_parameters`. However, to 
ensure `ModuleDict` is fully integrated and works correctly during compilation, 
I suggest extending this test to cover the `forward` pass. You could export the 
module to a TVM IRModule and use `assert_structural_equal` against an expected 
Relax function, similar to how other modules are tested in this file. This 
would provide a stronger guarantee of correctness.



##########
python/tvm/relax/frontend/nn/core.py:
##########
@@ -60,6 +60,8 @@
 from .subroutine import SubroutineMixin
 
 if TYPE_CHECKING:
+    from collections.abc import Iterable, Iterator, Mapping

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   The `Mapping` type is imported but not used within this file. Please remove 
it to avoid clutter.
   
   ```suggestion
       from collections.abc import Iterable, Iterator
   ```



##########
python/tvm/relax/frontend/nn/core.py:
##########
@@ -23,7 +23,7 @@
 - Effect, a non-user-facing class that encloses potential side effects, for example, IO,
   impure external function callings, inplace mutation, etc.
 """
-from collections import OrderedDict
+from collections import abc as container_abcs, OrderedDict

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   The imported name `container_abcs` is not used anywhere in this file. To 
keep the code clean, it's best to remove unused imports.
   
   ```suggestion
   from collections import OrderedDict
   ```
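
Both unused-import comments in this review can be checked mechanically. The sketch below is a rough stand-in for a linter check such as pyflakes/flake8 `F401` or pylint `unused-import`, using only the standard-library `ast` module. `unused_imports` and `sample` are hypothetical names, not part of TVM. Imports under `if TYPE_CHECKING:` are usually referenced only inside quoted annotations, so the sketch also parses string constants to avoid falsely reporting names like `Iterable` and `Iterator`.

```python
import ast
from typing import Set


def unused_imports(source: str) -> Set[str]:
    """Return names bound by import statements that are never referenced.

    Heuristic: string constants that parse as Python expressions are scanned
    too, so names used only in quoted annotations (typical for imports under
    `if TYPE_CHECKING:`) count as used.
    """
    tree = ast.parse(source)
    imported: Set[str] = set()
    used: Set[str] = set()

    def collect_used(node: ast.AST) -> None:
        for sub in ast.walk(node):
            if isinstance(sub, ast.Name):
                used.add(sub.id)
            elif isinstance(sub, ast.Constant) and isinstance(sub.value, str):
                try:
                    collect_used(ast.parse(sub.value, mode="eval"))
                except (SyntaxError, ValueError):
                    pass  # an ordinary string, not an expression

    for node in ast.walk(tree):
        if isinstance(node, ast.ImportFrom) and node.module == "__future__":
            continue  # compiler directives, not real imports
        if isinstance(node, (ast.Import, ast.ImportFrom)):
            for alias in node.names:
                if alias.name == "*":
                    continue
                # `import a.b` binds `a`; an `as` clause binds the alias instead.
                imported.add(alias.asname or alias.name.split(".")[0])

    collect_used(tree)
    return imported - used


sample = '''
from collections import abc as container_abcs, OrderedDict
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from collections.abc import Iterable, Iterator, Mapping


def dedupe(xs: "Iterable[int]") -> "Iterator[int]":
    return iter(OrderedDict.fromkeys(xs))
'''
print(sorted(unused_imports(sample)))  # ['Mapping', 'container_abcs']
```

A real linter tracks scopes and rebinding, which this heuristic ignores. For example, any string that happens to spell an imported name also marks that name as used.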



##########
tests/python/relax/test_frontend_nn_modules.py:
##########
@@ -715,5 +715,22 @@ def forward(self, x: nn.Tensor):
     assert ["layers.0.0.weight", "layers.0.1.weight"] == sorted(list(named_params.keys()))
 
 
+def test_module_dict():
+    class Module(nn.Module):

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
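   The class name `Module` reuses the name of its base class `nn.Module`. Strictly, a local class only shadows a bare `Module` name, not the `nn.Module` attribute, but reusing the base-class name still makes the test harder to read. A more descriptive name, such as `ModuleWithDict`, would be clearer.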
   
   ```suggestion
       class ModuleWithDict(nn.Module):
   ```



-- 
This is an automated message from the Apache Git Service.