tinywisdom opened a new issue, #18364:
URL: https://github.com/apache/tvm/issues/18364

   ### Summary
   Importing an STFT + RNN toy model exported with `torch.export` fails in the TVM Relax Torch frontend with:
   ```
   AssertionError: Unsupported function types ['rnn_tanh.input', 'real.default', 'unfold.default', 'imag.default', 'fft_fft.default']
   ```
   This points to several ops that are common in audio pipelines but not yet covered by the frontend: 1D framing (`unfold`), FFT (`fft_fft`), complex tensor accessors (`real`/`imag`), and the fused RNN (`rnn_tanh.input`).
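   
   For reference, these targets can be enumerated directly from the exported graph with a small sketch like the one below (it reuses the `MiniSTFTRNN` / `get_input` definitions from the reproduction script further down; the exact target spelling may differ slightly from the names in the assertion message):
   
   ```python
    # Sketch only: list the aten call_function targets that from_exported_program
    # would need to translate. Assumes MiniSTFTRNN and get_input from the repro
    # below are already defined in scope.
    from torch.export import export as torch_export
    
    ep = torch_export(MiniSTFTRNN().eval(), (get_input(),))
    targets = sorted({str(node.target) for node in ep.graph.nodes
                      if node.op == "call_function"})
    print(targets)  # e.g. includes aten.unfold.default, aten.fft_fft.default, ...
   ```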
   
   ### Environment
   
   - OS: Ubuntu 22.04.4 LTS (x86_64)
   - TVM version: release v0.21.0
   - Python: 3.10.16
   - LLVM: 17.0.6
   - PyTorch: 2.7.1
   
   ### Steps to reproduce
   
   ```python
   import torch
   import torch.nn as nn
   import torch.nn.functional as F
   
   def get_input(batch=1, length=4096, device="cpu", dtype=torch.float32):
       return torch.randn(batch, length, device=device, dtype=dtype)
   
   class MiniSTFTRNN(nn.Module):
       def __init__(self, win_len=320, n_fft=512, hop=160):
           super().__init__()
           self.register_buffer("window", torch.hann_window(win_len))
           self.win_len = win_len
           self.n_fft = n_fft
           self.hop = hop
           # use per-frame FFT spectrum as features
           self.rnn = nn.RNN(input_size=n_fft, hidden_size=8, num_layers=1,
                             batch_first=True, nonlinearity="tanh")
           self.fc = nn.Linear(8, 4)
   
       def forward(self, x):
           # x: (B, L)
        pad_tail = (0, self.n_fft - (x.shape[-1] % self.hop)) if (x.shape[-1] % self.hop) != 0 else (0, 0)
        x = F.pad(x, pad_tail, mode="constant")        # align for framing
        frames = x.unfold(-1, self.win_len, self.hop)  # (B, T, win_len) -> aten::unfold
        frames = frames * self.window                  # windowing
        spec = torch.fft.fft(frames, n=self.n_fft)     # aten::fft_fft
        real = spec.real                               # aten::real.default
        imag = spec.imag                               # aten::imag.default
        mag = torch.sqrt(real * real + imag * imag)    # (B, T, n_fft)
        out, _ = self.rnn(mag)                         # aten::rnn_tanh.input
           return self.fc(out[:, -1, :])                           # (B, 4)
   
   def main():
       import numpy as np
       from torch.export import export as torch_export
       from tvm.relax.frontend.torch import from_exported_program
   
       torch.manual_seed(0); np.random.seed(0)
       model = MiniSTFTRNN().eval()
       inp = get_input(batch=1, length=4096)
       with torch.inference_mode():
           _ = model(inp)  # sanity
   
       ep = torch_export(model, (inp,))
       mod = from_exported_program(ep)  # <- raises assertion
   
   if __name__ == "__main__":
       main()
   ```
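   
   For completeness, below is a rough sketch of a PyTorch-side rewrite that decomposes the failing ops into plain real-valued tensor ops: gather-style indexing instead of `aten::unfold`, a DFT matmul instead of `fft_fft`/`real`/`imag`, and an unrolled `nn.RNNCell` loop instead of the fused RNN. This is untested against `from_exported_program`, assumes advanced indexing and matmul are covered by the frontend, and its freshly initialized `RNNCell`/`Linear` weights will not match the original model.
   
   ```python
    # Sketch only: decomposed variant of MiniSTFTRNN that avoids the unsupported aten ops.
    import math
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    
    class MiniSTFTRNNDecomposed(nn.Module):
        def __init__(self, win_len=320, n_fft=512, hop=160):
            super().__init__()
            self.win_len, self.n_fft, self.hop = win_len, n_fft, hop
            self.register_buffer("window", torch.hann_window(win_len))
            # Real/imaginary DFT matrices so no complex dtype is ever materialized.
            k = torch.arange(n_fft, dtype=torch.float32)
            ang = 2.0 * math.pi * k[:, None] * k[None, :] / n_fft
            self.register_buffer("dft_cos", torch.cos(ang))
            self.register_buffer("dft_sin", -torch.sin(ang))
            self.cell = nn.RNNCell(n_fft, 8, nonlinearity="tanh")  # replaces fused nn.RNN
            self.fc = nn.Linear(8, 4)
    
        def forward(self, x):
            if x.shape[-1] % self.hop != 0:
                x = F.pad(x, (0, self.n_fft - x.shape[-1] % self.hop))
            # Framing via indexing instead of aten::unfold.
            n_frames = (x.shape[-1] - self.win_len) // self.hop + 1
            idx = torch.arange(self.win_len) + self.hop * torch.arange(n_frames)[:, None]
            frames = x[:, idx] * self.window                        # (B, T, win_len)
            frames = F.pad(frames, (0, self.n_fft - self.win_len))  # zero-pad like fft(n=n_fft)
            real = frames @ self.dft_cos                            # replaces spec.real
            imag = frames @ self.dft_sin                            # replaces spec.imag
            mag = torch.sqrt(real * real + imag * imag)             # (B, T, n_fft)
            h = mag.new_zeros(mag.shape[0], 8)
            for t in range(mag.shape[1]):                           # loop unrolls at export time
                h = self.cell(mag[:, t, :], h)
            return self.fc(h)                                       # (B, 4)
   ```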
   ### Output
   ```
   Traceback (most recent call last):
     ...
     File ".../base_fx_graph_translator.py", line 116, in 
_check_unsupported_func_type
       assert not missing_func_types, f"Unsupported function types 
{missing_func_types}"
   AssertionError: Unsupported function types ['rnn_tanh.input', 
'real.default', 'unfold.default', 'imag.default', 'fft_fft.default']
   ```
   
   
   ### Triage
   
   * needs-triage
   * bug
   

