tinywisdom opened a new issue, #18340:
URL: https://github.com/apache/tvm/issues/18340

   Converting a PyTorch program exported with `torch.export` into TVM Relax 
with `from_exported_program` fails when the model contains `nn.LSTM`. The 
frontend reports:
   ```
   AssertionError: Unsupported function types ['lstm.input']
   ```
   This indicates that `lstm.input` (the op emitted by `torch.export` for 
`nn.LSTM`) is currently not supported in the TVM Relax PyTorch frontend.
   
   
   ### Expected behavior
   
   + The Relax Torch frontend should lower `nn.LSTM` (emitted as `lstm.input` 
in `torch.export`) to a supported Relax representation:
   
     + Either a high-level RNN/LSTM composite (if available), or
   
     + A lower-level decomposition into primitive ops 
(matmul/elementwise/activations) wrapped as a Relax subgraph / `call_tir` where 
appropriate.
   
   + If certain LSTM configurations are not yet supported (e.g., bidirectional, 
multi-layer, projections), the importer should:
   
     + Accept supported subsets, and
   
     + Emit a clear Python exception for unsupported variants with guidance.
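
To make the "lower-level decomposition" option concrete, here is a minimal sketch in plain PyTorch (not TVM code) of a single-layer, unidirectional, `batch_first=True` LSTM written with only matmul, elementwise ops, and sigmoid/tanh. The helper name `lstm_decomposed` is hypothetical, and zero initial states are assumed. It relies on PyTorch's documented gate ordering (input, forget, cell, output) in the stacked weights. A Relax lowering could follow the same structure, but that mapping is an assumption and not something the frontend provides today.

```python
import torch
import torch.nn as nn


def lstm_decomposed(x, w_ih, w_hh, b_ih, b_hh):
    """Hypothetical reference: one-layer, unidirectional, batch_first LSTM.

    x:    (batch, steps, input_size)
    w_ih: (4 * hidden_size, input_size), gates stacked as i, f, g, o
    w_hh: (4 * hidden_size, hidden_size)
    """
    batch, steps, _ = x.shape
    hidden = w_hh.shape[1]
    # Assumption: zero initial hidden and cell states (nn.LSTM's default).
    h = x.new_zeros(batch, hidden)
    c = x.new_zeros(batch, hidden)
    outputs = []
    for t in range(steps):
        # One fused matmul per input and per hidden state, then split into gates.
        gates = x[:, t, :] @ w_ih.T + b_ih + h @ w_hh.T + b_hh
        i, f, g, o = gates.chunk(4, dim=1)
        i, f, o = torch.sigmoid(i), torch.sigmoid(f), torch.sigmoid(o)
        g = torch.tanh(g)
        c = f * c + i * g
        h = o * torch.tanh(c)
        outputs.append(h)
    # Stack along the time axis to match batch_first output layout.
    return torch.stack(outputs, dim=1)


torch.manual_seed(0)
lstm = nn.LSTM(input_size=8, hidden_size=16, batch_first=True).eval()
x = torch.randn(2, 4, 8)
with torch.no_grad():
    expected, _ = lstm(x)
    actual = lstm_decomposed(
        x, lstm.weight_ih_l0, lstm.weight_hh_l0, lstm.bias_ih_l0, lstm.bias_hh_l0
    )
print("max abs diff:", (expected - actual).abs().max().item())
```

The sketch deliberately covers only the simplest configuration. Bidirectional, multi-layer, and projected variants would need extra handling, which is where a clear exception for unsupported variants would apply.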
   
   ### Actual behavior
   
   ```
   PyTorch eager OK, y.shape = (2, 4, 16)
   ExportedProgram created.
   Traceback (most recent call last):
     ...
     File ".../base_fx_graph_translator.py", line 116, in _check_unsupported_func_type
       assert not missing_func_types, f"Unsupported function types {missing_func_types}"
   AssertionError: Unsupported function types ['lstm.input']
   ```
   
   ### Environment
   
   + OS: Ubuntu 22.04.4 LTS (x86_64)
   + TVM version: release v0.21.0
   + Python: 3.10.16
   + LLVM: 17.0.6
   
   ### Steps to reproduce
   
   ```python
   import torch
   import torch.nn as nn
   from torch.export import export as torch_export
   from tvm.relax.frontend.torch import from_exported_program
   
   class M(nn.Module):
       def __init__(self, input_size=8, hidden_size=16, num_layers=1, bidirectional=False):
           super().__init__()
           self.lstm = nn.LSTM(
               input_size=input_size,
               hidden_size=hidden_size,
               num_layers=num_layers,
               batch_first=True,
               bidirectional=bidirectional,
           )
   
       def forward(self, x):
           # Only return the output sequence; drop (h_n, c_n)
           y, _ = self.lstm(x)
           return y
   
   def main():
       torch.manual_seed(0)
       m = M().eval()
   
       # Minimal input: B=2, T=4, C=8
       x = torch.randn(2, 4, 8, dtype=torch.float32)
   
       # 1) Sanity check in eager
       with torch.inference_mode():
           y = m(x)
       print("PyTorch eager OK, y.shape =", tuple(y.shape))
   
       # 2) Export to ExportedProgram
       ep = torch_export(m, (x,))
       print("ExportedProgram created.")
   
       # 3) Import into TVM Relax — triggers unsupported function type
       _ = from_exported_program(ep)
   
   if __name__ == "__main__":
       main()
   ```
   
   ### Triage
   
   
   * needs-triage
   * bug
   

