tinywisdom opened a new issue, #18364:
URL: https://github.com/apache/tvm/issues/18364
### Summary
Importing a toy STFT + RNN model exported with `torch.export` through the TVM Relax PyTorch frontend (`from_exported_program`) fails with:
```
AssertionError: Unsupported function types ['rnn_tanh.input', 'real.default', 'unfold.default', 'imag.default', 'fft_fft.default']
```
This exposes several gaps in operator coverage that are common in audio pipelines: 1D framing (`unfold`), FFT (`fft_fft`), complex tensor accessors (`real`/`imag`), and the fused RNN (`rnn_tanh.input`).
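For reference, framing and the FFT can be written without any complex dtype. `unfold` is an index gather, and the real and imaginary parts of a zero-padded DFT are two real matrix products against cosine and sine bases. The NumPy sketch below checks this against `np.fft.fft`. The helper names `frame_1d` and `dft_real_imag` are hypothetical and are not TVM or PyTorch APIs. The sketch only pins down the semantics that a converter, or a user-side rewrite of `forward`, would need to reproduce. It does not describe how TVM implements these ops or should implement them.

```python
import numpy as np


def frame_1d(x: np.ndarray, win_len: int, hop: int) -> np.ndarray:
    """Same frames as torch's x.unfold(-1, win_len, hop): (..., L) -> (..., T, win_len)."""
    num_frames = (x.shape[-1] - win_len) // hop + 1
    # idx[t, k] = t * hop + k is the position of sample k of frame t
    idx = np.arange(num_frames)[:, None] * hop + np.arange(win_len)[None, :]
    return x[..., idx]


def dft_real_imag(frames: np.ndarray, n_fft: int):
    """Real and imaginary parts of fft(frames, n=n_fft), using only real matmuls.

    Assumes win_len <= n_fft: fft zero-pads each frame to n_fft, and the padded
    zeros add nothing to the DFT sum, so only win_len rows of the basis are needed.
    """
    win_len = frames.shape[-1]
    assert win_len <= n_fft
    n = np.arange(win_len)[:, None]
    k = np.arange(n_fft)[None, :]
    angle = 2.0 * np.pi * n * k / n_fft  # (win_len, n_fft)
    # X[k] = sum_n x[n] * (cos(angle) - i * sin(angle))
    real = frames @ np.cos(angle)
    imag = -(frames @ np.sin(angle))
    return real, imag


rng = np.random.default_rng(0)
x = rng.standard_normal((1, 4096))
frames = frame_1d(x, win_len=320, hop=160)  # (1, 24, 320), no padding applied here
real, imag = dft_real_imag(frames, n_fft=512)  # each (1, 24, 512)
mag = np.sqrt(real * real + imag * imag)

ref = np.fft.fft(frames, n=512)
print(np.allclose(real, ref.real), np.allclose(imag, ref.imag))
```

The index array and the cos/sin bases could be precomputed as buffers in a PyTorch module. It is an assumption, not something checked here, that the resulting gather and matmul ops convert cleanly through `from_exported_program`.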
### Environment
- OS: Ubuntu 22.04.4 LTS (x86_64)
- TVM version: release v0.21.0
- Python: 3.10.16
- LLVM: 17.0.6
- PyTorch: 2.7.1
### Steps to reproduce
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def get_input(batch=1, length=4096, device="cpu", dtype=torch.float32):
    return torch.randn(batch, length, device=device, dtype=dtype)


class MiniSTFTRNN(nn.Module):
    def __init__(self, win_len=320, n_fft=512, hop=160):
        super().__init__()
        self.register_buffer("window", torch.hann_window(win_len))
        self.win_len = win_len
        self.n_fft = n_fft
        self.hop = hop
        # use per-frame FFT spectrum as features
        self.rnn = nn.RNN(input_size=n_fft, hidden_size=8, num_layers=1,
                          batch_first=True, nonlinearity="tanh")
        self.fc = nn.Linear(8, 4)

    def forward(self, x):
        # x: (B, L)
        pad_tail = ((0, self.n_fft - (x.shape[-1] % self.hop))
                    if (x.shape[-1] % self.hop) != 0 else (0, 0))
        x = F.pad(x, pad_tail, mode="constant")  # align for framing
        frames = x.unfold(-1, self.win_len, self.hop)  # (B, T, win_len) -> aten::unfold
        frames = frames * self.window  # windowing
        spec = torch.fft.fft(frames, n=self.n_fft)  # aten::fft_fft
        real = spec.real  # aten::real.default
        imag = spec.imag  # aten::imag.default
        mag = torch.sqrt(real * real + imag * imag)  # (B, T, n_fft)
        out, _ = self.rnn(mag)  # aten::rnn_tanh.input
        return self.fc(out[:, -1, :])  # (B, 4)


def main():
    import numpy as np
    from torch.export import export as torch_export
    from tvm.relax.frontend.torch import from_exported_program

    torch.manual_seed(0)
    np.random.seed(0)
    model = MiniSTFTRNN().eval()
    inp = get_input(batch=1, length=4096)
    with torch.inference_mode():
        _ = model(inp)  # sanity
    ep = torch_export(model, (inp,))
    mod = from_exported_program(ep)  # <- raises assertion


if __name__ == "__main__":
    main()
```
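One possible PyTorch-side workaround for `rnn_tanh.input` is to unroll the recurrence by hand before exporting. The rewrite uses the documented `nn.RNN` update h_t = tanh(x_t W_ih^T + b_ih + h_{t-1} W_hh^T + b_hh) and the module's own `weight_ih_l0`/`weight_hh_l0`/`bias_ih_l0`/`bias_hh_l0` parameters. The helper `unrolled_rnn_tanh` is hypothetical and is not part of PyTorch or TVM. This sketch assumes that the resulting matmul/add/tanh/stack ops are accepted by the Relax frontend, which has not been confirmed against the v0.21.0 converter table.

```python
import torch
import torch.nn as nn


def unrolled_rnn_tanh(rnn: nn.RNN, x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: evaluate a single-layer tanh nn.RNN as an explicit loop.

    Returns the same per-step outputs as `rnn(x)[0]` for a batch_first input of
    shape (B, T, input_size), starting from a zero hidden state.
    """
    assert rnn.num_layers == 1 and not rnn.bidirectional
    assert rnn.batch_first and rnn.nonlinearity == "tanh" and rnn.bias
    batch, steps, _ = x.shape
    h = x.new_zeros(batch, rnn.hidden_size)
    outputs = []
    for t in range(steps):  # a Python loop is unrolled when traced with a static T
        # h_t = tanh(x_t W_ih^T + b_ih + h_{t-1} W_hh^T + b_hh)
        h = torch.tanh(
            x[:, t] @ rnn.weight_ih_l0.T + rnn.bias_ih_l0
            + h @ rnn.weight_hh_l0.T + rnn.bias_hh_l0
        )
        outputs.append(h)
    return torch.stack(outputs, dim=1)  # (B, T, hidden_size)


# Compare with nn.RNN on the repro's feature shape (T = 27 frames after padding, n_fft = 512)
torch.manual_seed(0)
rnn = nn.RNN(input_size=512, hidden_size=8, num_layers=1,
             batch_first=True, nonlinearity="tanh")
feats = torch.randn(1, 27, 512)
with torch.no_grad():
    ref, _ = rnn(feats)
    out = unrolled_rnn_tanh(rnn, feats)
print(torch.allclose(out, ref, atol=1e-5))
```

In the repro's `forward`, `out, _ = self.rnn(mag)` would become `out = unrolled_rnn_tanh(self.rnn, mag)`. This only sidesteps `rnn_tanh.input`, and the exported graph grows linearly with the number of frames. The STFT-related ops (`unfold`, `fft_fft`, `real`, `imag`) are unaffected.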
### Output
```
Traceback (most recent call last):
  ...
  File ".../base_fx_graph_translator.py", line 116, in _check_unsupported_func_type
    assert not missing_func_types, f"Unsupported function types {missing_func_types}"
AssertionError: Unsupported function types ['rnn_tanh.input', 'real.default', 'unfold.default', 'imag.default', 'fft_fft.default']
```
### Triage
* needs-triage
* bug