tinywisdom opened a new issue, #18356:
URL: https://github.com/apache/tvm/issues/18356
### Expected behavior
Translating a tiny torch.export-ed model that contains nn.GRU fails in the TVM
Relax Torch frontend with:
```
AssertionError: Unsupported function types ['gru.input']
```
This looks like missing coverage for the fused RNN op (the aten::gru overload
gru.input) in from_exported_program. Even if it remains unsupported for now, it would be
helpful to either (a) add a lowering for GRU, or (b) fail with a more actionable
message / doc pointer describing the current RNN support scope and workarounds.
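One pre-export avenue that might be worth checking (untested here, so treat it as an assumption rather than a known fix) is whether ExportedProgram.run_decompositions() breaks the fused RNN op into primitives the frontend already covers before the graph reaches from_exported_program:

```python
# Sketch only: it is unclear whether the default decomposition table covers
# aten::gru.input; if it does, this may sidestep the assertion below.
import torch
from tvm.relax.frontend.torch import from_exported_program

ep = torch.export.export(m, (inp,))  # m / inp as in the repro below
ep = ep.run_decompositions()         # returns a new, decomposed ExportedProgram
mod = from_exported_program(ep)
```

If the fused op survives decomposition, a dedicated converter in the frontend seems to be the only way to support nn.GRU directly.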
### Actual behavior
```
=== Versions ===
Python : 3.10.16 | packaged by conda-forge | (main, Apr 8, 2025, 20:53:32) [GCC 13.3.0]
Torch : 2.8.0+cu128
TVM : 0.21.0
================
torch.export: OK
TVM from_exported_program: FAILED as expected
Error: Unsupported function types ['gru.input']
Traceback (most recent call last):
...
File ".../tvm/relax/frontend/torch/base_fx_graph_translator.py", line 116,
in _check_unsupported_func_type
assert not missing_func_types, f"Unsupported function types
{missing_func_types}"
AssertionError: Unsupported function types ['gru.input']
```
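For reference, the fused op can be confirmed independently of the assertion by listing the call_function targets in the exported graph (using the ep produced in the repro below); the fused GRU overload is expected to show up among them:

```python
# List every call_function target in the ExportedProgram graph; the fused
# GRU overload (gru.input) should appear here, matching the error above.
for node in ep.graph.nodes:
    if node.op == "call_function":
        print(node.target)
```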
### Environment
- OS: Ubuntu 22.04.4 LTS (x86_64)
- TVM version: release v0.21.0
- Python: 3.10.16
- LLVM: 17.0.6
### Steps to reproduce
```python
# minimal_gru_from_exported_program_repro.py
import torch
import torch.nn as nn
import sys


def print_env():
    print("=== Versions ===")
    print("Python :", sys.version.replace("\n", " "))
    print("Torch :", torch.__version__)
    try:
        import torchaudio, torchvision
        print("torchaudio:", getattr(torchaudio, '__version__', 'unknown'))
        print("torchvision:", getattr(torchvision, '__version__', 'unknown'))
    except Exception:
        pass
    try:
        import tvm
        print("TVM :", tvm.__version__)
    except Exception as e:
        print("TVM : import error ->", e)
    print("================")


class M(nn.Module):
    """
    Tiny model that triggers a fused GRU op (aten::gru / gru.input)
    in the ExportedProgram graph.
    """

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(12, 10)  # map to GRU input_size=10
        self.gru = nn.GRU(10, 20, num_layers=2)

    def forward(self, x):
        # x: (B=1, 12)
        x = self.proj(x)       # (1, 10)
        x = x.unsqueeze(0)     # (T=1, B=1, 10) → seq len 1
        y, h = self.gru(x)     # y: (1, 1, 20)
        return y


def main():
    torch.manual_seed(0)
    print_env()
    m = M().eval()
    inp = torch.randn(1, 12)

    # Eager sanity check
    with torch.inference_mode():
        _ = m(inp)

    # torch.export succeeds
    from torch.export import export as torch_export
    ep = torch_export(m, (inp,))
    print("torch.export: OK")

    # TVM: translate ExportedProgram → Relax
    from tvm.relax.frontend.torch import from_exported_program
    try:
        mod = from_exported_program(ep)
        print("TVM from_exported_program: OK (unexpected)")
    except AssertionError as e:
        print("TVM from_exported_program: FAILED as expected")
        print("Error:", e)  # shows "Unsupported function types ['gru.input']"
        raise
    except Exception as e:
        print("TVM from_exported_program: FAILED (different exception)")
        raise


if __name__ == "__main__":
    main()
```
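As a possible interim workaround (a sketch only, not verified against the frontend's current op coverage), the GRU can be re-expressed with ops that torch.export lowers to linear/matmul, sigmoid, tanh, and elementwise arithmetic. The ManualGRU module below is hypothetical, single-layer, and unidirectional; it follows the standard GRU gate equations and does not reuse nn.GRU's packed weights:

```python
import torch
import torch.nn as nn


class ManualGRU(nn.Module):
    """Single-layer, unidirectional GRU built from Linear/sigmoid/tanh only."""

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        # Gate order (r, z, n) mirrors the usual GRU formulation.
        self.x2h = nn.Linear(input_size, 3 * hidden_size)
        self.h2h = nn.Linear(hidden_size, 3 * hidden_size)

    def forward(self, x, h0=None):
        # x: (T, B, input_size), as with batch_first=False
        T, B, _ = x.shape
        h = x.new_zeros(B, self.hidden_size) if h0 is None else h0
        outputs = []
        for t in range(T):
            gi = self.x2h(x[t])               # (B, 3H)
            gh = self.h2h(h)                  # (B, 3H)
            i_r, i_z, i_n = gi.chunk(3, dim=1)
            h_r, h_z, h_n = gh.chunk(3, dim=1)
            r = torch.sigmoid(i_r + h_r)      # reset gate
            z = torch.sigmoid(i_z + h_z)      # update gate
            n = torch.tanh(i_n + r * h_n)     # candidate state
            h = (1.0 - z) * n + z * h         # new hidden state
            outputs.append(h)
        return torch.stack(outputs, dim=0), h  # y: (T, B, H), h: (B, H)
```

In the repro, nn.GRU(10, 20, num_layers=2) would become two stacked ManualGRU layers; whether every op in the unrolled loop maps cleanly onto the Relax frontend is exactly what would need to be verified.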
### Triage
* needs-triage
* bug