LiSsHhUuAaIi opened a new issue, #18439:
URL: https://github.com/apache/tvm/issues/18439

   ### Expected behavior
   
   When converting a PyTorch model that uses `torch.utils.checkpoint.checkpoint` to a TVM Relax module via `torch.export`, a `KeyError` occurs during conversion. The TVM frontend fails to find the `'dtype'` key in a node's kwargs while processing certain operations (according to the traceback, in the `_empty` converter).
   
   The PyTorch model with gradient checkpointing should convert to a TVM Relax module without a `KeyError`, since it runs correctly in native PyTorch and exports successfully with `torch.export`.
   
   ### Actual behavior
   
   A `KeyError: 'dtype'` is raised during `from_exported_program` conversion, indicating that the TVM frontend expects a `dtype` kwarg that is not present in the exported graph node.
   
   ### Environment
   
   * **OS:** Ubuntu 20.04.6 LTS
   * **TVM version:** 0.23.dev0
   * **Python version:** 3.11.14
   
   ### Steps to reproduce
   
   ```python
   import torch
   import torch.nn as nn
   import tvm
   from tvm import relax
   
   class TestModel(nn.Module):
       def __init__(self):
           super().__init__()
           self.linear = nn.Linear(100, 10)
   
       def forward(self, x):
           x = torch.utils.checkpoint.checkpoint(self.linear, x, use_reentrant=False)
           return x
   
   # Build the model in inference mode
   model = TestModel()
   model.eval()
   
   x = torch.randn(32, 100)
   
   # PyTorch execution works
   with torch.no_grad():
       output = model(x)
   
   # PyTorch export works
   exported_program = torch.export.export(model, (x,))
   
   # TVM conversion fails
   from tvm.relax.frontend.torch import from_exported_program
   mod = from_exported_program(exported_program)  # KeyError here
   ```
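To narrow down which node reaches the `_empty` converter, one can walk the exported graph and list empty-like calls that carry no `dtype` kwarg. The sketch below is a hypothetical diagnostic, not part of TVM: the function name `find_empty_nodes_without_dtype` is made up, it assumes the offending op's target name contains `empty`, and it only relies on the standard `torch.fx.Node` attributes `op`, `target`, and `kwargs`. `SimpleNamespace` objects stand in for real nodes so the snippet runs without PyTorch.

```python
from types import SimpleNamespace


def find_empty_nodes_without_dtype(nodes):
    """Hypothetical diagnostic: return call_function nodes whose target
    name contains "empty" but whose kwargs have no "dtype" entry."""
    hits = []
    for node in nodes:
        if node.op != "call_function":
            continue
        if "empty" in str(node.target) and "dtype" not in node.kwargs:
            hits.append(node)
    return hits


# Illustrative stand-ins for torch.fx.Node objects; the target names are
# examples, not taken from the actual exported graph in this issue.
fake_nodes = [
    SimpleNamespace(op="placeholder", target="x", kwargs={}),
    SimpleNamespace(op="call_function", target="aten.empty.memory_format",
                    kwargs={"device": "cpu"}),
    SimpleNamespace(op="call_function", target="aten.empty.memory_format",
                    kwargs={"dtype": "torch.float32"}),
    SimpleNamespace(op="call_function", target="aten.linear.default", kwargs={}),
]

for node in find_empty_nodes_without_dtype(fake_nodes):
    print(node.target, node.kwargs)
```

With the real repro, one would pass `exported_program.graph.nodes` instead of `fake_nodes` (or simply run `exported_program.graph.print_tabular()`) to inspect the actual kwargs of the failing node.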
   
   ### Error logs
   ```
   Traceback (most recent call last):
     File "/mnt/e/DL_Compiler_Test/tvm_code/test.py", line 30, in <module>
       mod = from_exported_program(exported_program)  # KeyError here
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
     File "/mnt/e/DL_Compiler_Test/tvm_build/apache-tvm/python/tvm/relax/frontend/torch/exported_program_translator.py", line 1334, in from_exported_program
       return ExportedProgramImporter().from_exported_program(
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
     File "/mnt/e/DL_Compiler_Test/tvm_build/apache-tvm/python/tvm/relax/frontend/torch/exported_program_translator.py", line 1221, in from_exported_program
       self.env[node] = self.convert_map[func_name](node)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
     File "/mnt/e/DL_Compiler_Test/tvm_build/apache-tvm/python/tvm/relax/frontend/torch/base_fx_graph_translator.py", line 1939, in _empty
       dtype = self._convert_data_type(str(node.kwargs["dtype"]), self.env)
                                           ~~~~~~~~~~~^^^^^^^^^
   KeyError: 'dtype'
   ```
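The failing line in `_empty` indexes `node.kwargs["dtype"]` directly, so any empty-like node exported without that kwarg raises. A minimal sketch of a more defensive lookup follows, assuming that a missing or `None` dtype should fall back to PyTorch's default dtype (`float32` unless changed with `torch.set_default_dtype`). The helper name `resolve_empty_dtype` is hypothetical and this is not TVM's actual fix; a `SimpleNamespace` stands in for an FX node.

```python
from types import SimpleNamespace


def resolve_empty_dtype(node, default="float32"):
    """Hypothetical helper: return the dtype string for an empty-like node.

    Uses node.kwargs["dtype"] when present and not None; otherwise falls
    back to `default`, mirroring how torch.empty uses the default dtype
    when none is given.
    """
    dtype = node.kwargs.get("dtype")
    if dtype is None:
        return default
    return str(dtype)


# Stand-ins for FX nodes: one with an explicit dtype, one without.
with_dtype = SimpleNamespace(kwargs={"dtype": "torch.float16"})
without_dtype = SimpleNamespace(kwargs={"device": "cpu"})

print(resolve_empty_dtype(with_dtype))     # torch.float16
print(resolve_empty_dtype(without_dtype))  # float32
```

Using `dict.get` keeps the explicit-dtype path unchanged while turning the missing-key case into a default instead of a `KeyError`.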
   
   ### Triage
   * needs-triage
   * bug
   * frontend:pytorch


-- 
This is an automated message from the Apache Git Service.