gemini-code-assist[bot] commented on code in PR #18408:
URL: https://github.com/apache/tvm/pull/18408#discussion_r2477518058


##########
position_id_fix.py:
##########
@@ -0,0 +1,78 @@
+# sol-script-fixed.py
+import torch
+import torch.nn as nn
+from transformers import AutoModel
+from torch.export import export as torch_export
+from tvm.relax.frontend.torch import from_exported_program
+
+class StateDictWrapper(dict):
+    """Wrap exported state_dict and inject extra keys (non-persistent 
buffers)."""
+    def __init__(self, base_dict, extra):
+        super().__init__(base_dict)
+        self.extra = extra
+
+    def __getitem__(self, key):
+        if key in self.extra:
+            return self.extra[key]
+        return super().__getitem__(key)
+
+    def get(self, key, default=None):
+        if key in self.extra:
+            return self.extra[key]
+        return super().get(key, default)
+
+class M(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.bert = AutoModel.from_pretrained("bert-base-multilingual-uncased")
+        self.cls = nn.Linear(self.bert.config.hidden_size, 2)
+
+    def forward(self, x, mask=None):
+        out = self.bert(x, attention_mask=mask).last_hidden_state[:, 0, :]
+        return self.cls(out)
+
+def main():
+    torch.manual_seed(0)
+    m = M().eval()
+
+    x = torch.randint(0, 30522, (2, 16))
+    mask = torch.ones_like(x)
+
+    ep = torch_export(m, (x, mask))
+    print("\n torch.export completed successfully\n")
+
+    # --- Build extra buffers dict ---
+    extra = {}
+    for buf_name in m.bert.embeddings._non_persistent_buffers_set:
+        tensor = m.bert.embeddings._buffers.get(buf_name)
+        if tensor is not None:
+            extra[f"bert.embeddings.{buf_name}"] = tensor
+            print(f"Injecting buffer: bert.embeddings.{buf_name} -> shape 
{tensor.shape}")
+
+    # Wrap exported state_dict
+    sd_wrapped = StateDictWrapper(ep.state_dict, extra)
+
+    # EP wrapper to override state_dict access
+    class EPWrapper:
+        def __init__(self, ep, sd_wrapped):
+            self.__dict__["_ep"] = ep
+            self.__dict__["_sd"] = sd_wrapped
+
+        def __getattr__(self, name):
+            if name == "state_dict":
+                return self._sd
+            return getattr(self._ep, name)
+
+    ep_wrapped = EPWrapper(ep, sd_wrapped)
+
+    # Import to TVM
+    try:
+        mod = from_exported_program(ep_wrapped)
+        print("\n TVM import succeeded — all non-persistent buffers 
injected!\n")
+    except Exception as e:
+        print("\n TVM import failed with exception:")
+        import traceback
+        traceback.print_exc()

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   It is a standard Python convention (PEP 8) to place all imports at the top 
of the file. This improves readability and makes it easier to see the script's 
dependencies at a glance. Please move `import traceback` to the top of the 
script with the other imports.
   
   ```suggestion
           traceback.print_exc()
   ```



##########
position_id_fix.py:
##########
@@ -0,0 +1,78 @@
+# sol-script-fixed.py
+import torch
+import torch.nn as nn
+from transformers import AutoModel
+from torch.export import export as torch_export
+from tvm.relax.frontend.torch import from_exported_program
+
+class StateDictWrapper(dict):
+    """Wrap exported state_dict and inject extra keys (non-persistent 
buffers)."""
+    def __init__(self, base_dict, extra):
+        super().__init__(base_dict)
+        self.extra = extra
+
+    def __getitem__(self, key):
+        if key in self.extra:
+            return self.extra[key]
+        return super().__getitem__(key)
+
+    def get(self, key, default=None):
+        if key in self.extra:
+            return self.extra[key]
+        return super().get(key, default)
+
+class M(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.bert = AutoModel.from_pretrained("bert-base-multilingual-uncased")
+        self.cls = nn.Linear(self.bert.config.hidden_size, 2)
+
+    def forward(self, x, mask=None):
+        out = self.bert(x, attention_mask=mask).last_hidden_state[:, 0, :]
+        return self.cls(out)
+
+def main():
+    torch.manual_seed(0)
+    m = M().eval()
+
+    x = torch.randint(0, 30522, (2, 16))

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
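   The value `30522` is a magic number. It is the vocabulary size of the English 
`bert-base-uncased` checkpoint, not of 'bert-base-multilingual-uncased', whose 
vocabulary is considerably larger, so the hard-coded bound only samples a subset 
of the valid token IDs. Fetching the value from the model's configuration keeps 
the dummy input consistent with whichever checkpoint is loaded, and improves 
readability and maintainability.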
   
   ```suggestion
       x = torch.randint(0, m.bert.config.vocab_size, (2, 16))
   ```



##########
position_id_fix.py:
##########
@@ -0,0 +1,78 @@
+# sol-script-fixed.py
+import torch
+import torch.nn as nn
+from transformers import AutoModel
+from torch.export import export as torch_export
+from tvm.relax.frontend.torch import from_exported_program
+
+class StateDictWrapper(dict):
+    """Wrap exported state_dict and inject extra keys (non-persistent 
buffers)."""
+    def __init__(self, base_dict, extra):
+        super().__init__(base_dict)
+        self.extra = extra
+
+    def __getitem__(self, key):
+        if key in self.extra:
+            return self.extra[key]
+        return super().__getitem__(key)
+
+    def get(self, key, default=None):
+        if key in self.extra:
+            return self.extra[key]
+        return super().get(key, default)
+
+class M(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.bert = AutoModel.from_pretrained("bert-base-multilingual-uncased")
+        self.cls = nn.Linear(self.bert.config.hidden_size, 2)
+
+    def forward(self, x, mask=None):
+        out = self.bert(x, attention_mask=mask).last_hidden_state[:, 0, :]
+        return self.cls(out)
+
+def main():
+    torch.manual_seed(0)
+    m = M().eval()
+
+    x = torch.randint(0, 30522, (2, 16))
+    mask = torch.ones_like(x)
+
+    ep = torch_export(m, (x, mask))
+    print("\n torch.export completed successfully\n")
+
+    # --- Build extra buffers dict ---
+    extra = {}
+    for buf_name in m.bert.embeddings._non_persistent_buffers_set:
+        tensor = m.bert.embeddings._buffers.get(buf_name)
+        if tensor is not None:
+            extra[f"bert.embeddings.{buf_name}"] = tensor
+            print(f"Injecting buffer: bert.embeddings.{buf_name} -> shape 
{tensor.shape}")
+
+    # Wrap exported state_dict
+    sd_wrapped = StateDictWrapper(ep.state_dict, extra)
+
+    # EP wrapper to override state_dict access
+    class EPWrapper:
+        def __init__(self, ep, sd_wrapped):
+            self.__dict__["_ep"] = ep
+            self.__dict__["_sd"] = sd_wrapped
+
+        def __getattr__(self, name):
+            if name == "state_dict":
+                return self._sd
+            return getattr(self._ep, name)
+

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
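   This proxy implementation for `ExportedProgram` is minimal. A more robust and 
idiomatic approach is to expose `state_dict` as a property instead of 
special-casing it inside `__getattr__`; `__getattr__` then only forwards the 
remaining attributes to the wrapped program. Because `__getattr__` is consulted 
only for attributes that normal lookup does not find, plain assignments in 
`__init__` are safe, and the `self.__dict__[...]` workaround is unnecessary. The 
suggested implementation makes the intent clearer and is less prone to subtle 
bugs if `from_exported_program` interacts with the object in more complex ways. 
Note that the wrapper is still not an `ExportedProgram` instance, so any 
`isinstance` check in the importer would reject it.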
   
   ```suggestion
       class EPWrapper:
           def __init__(self, ep, sd_wrapped):
               self._ep = ep
               self._sd = sd_wrapped
   
           @property
           def state_dict(self):
               return self._sd
   
           def __getattr__(self, name):
               return getattr(self._ep, name)
   ```



##########
position_id_fix.py:
##########
@@ -0,0 +1,78 @@
+# sol-script-fixed.py
+import torch
+import torch.nn as nn
+from transformers import AutoModel
+from torch.export import export as torch_export
+from tvm.relax.frontend.torch import from_exported_program
+
+class StateDictWrapper(dict):
+    """Wrap exported state_dict and inject extra keys (non-persistent 
buffers)."""
+    def __init__(self, base_dict, extra):
+        super().__init__(base_dict)
+        self.extra = extra
+
+    def __getitem__(self, key):
+        if key in self.extra:
+            return self.extra[key]
+        return super().__getitem__(key)
+
+    def get(self, key, default=None):
+        if key in self.extra:
+            return self.extra[key]
+        return super().get(key, default)
+
+class M(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.bert = AutoModel.from_pretrained("bert-base-multilingual-uncased")
+        self.cls = nn.Linear(self.bert.config.hidden_size, 2)
+
+    def forward(self, x, mask=None):
+        out = self.bert(x, attention_mask=mask).last_hidden_state[:, 0, :]
+        return self.cls(out)
+
+def main():
+    torch.manual_seed(0)
+    m = M().eval()
+
+    x = torch.randint(0, 30522, (2, 16))
+    mask = torch.ones_like(x)
+
+    ep = torch_export(m, (x, mask))
+    print("\n torch.export completed successfully\n")
+
+    # --- Build extra buffers dict ---
+    extra = {}
+    for buf_name in m.bert.embeddings._non_persistent_buffers_set:
+        tensor = m.bert.embeddings._buffers.get(buf_name)
+        if tensor is not None:
+            extra[f"bert.embeddings.{buf_name}"] = tensor
+            print(f"Injecting buffer: bert.embeddings.{buf_name} -> shape 
{tensor.shape}")
+
+    # Wrap exported state_dict
+    sd_wrapped = StateDictWrapper(ep.state_dict, extra)

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
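   The custom `StateDictWrapper` class can be replaced by 
`collections.ChainMap` for a more concise and idiomatic implementation. 
`ChainMap` is designed for linking multiple dictionaries, and it is also more 
complete. `StateDictWrapper` overrides only `__getitem__` and `get`, so `in`, 
`keys()`, `items()` and iteration do not see the injected buffers. A `ChainMap` 
exposes them consistently through every mapping operation. Note that `ChainMap` 
is a `MutableMapping` but not a `dict` subclass, in case the importer checks for 
`dict` explicitly.
   
   After this change, you can remove the `StateDictWrapper` class definition 
(lines 8-22) and add `import collections` to the top of the file.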
   
   ```suggestion
       sd_wrapped = collections.ChainMap(extra, ep.state_dict)
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to