tinywisdom opened a new issue, #18407:
URL: https://github.com/apache/tvm/issues/18407
### Expected behavior
A tiny Transformer-like block exported via `torch.export` should import cleanly with
`tvm.relax.frontend.torch.from_exported_program(ep)` (or at worst fail with a
Python-level error). Instead, TVM crashes during the import. Before the crash,
PyTorch warns that `torch.export` inserted a `get_attr` node without a backing
submodule/parameter/buffer; TVM then segfaults in `tvm::relax::Tuple::Tuple(...)`
on the FFI path while translating the exported program.
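For reference, the flow I expect to work is roughly the sketch below; the lowering/compile steps after the import are my assumption about typical Relax usage on v0.21 and are not required to reproduce the crash.
```python
# Sketch of the expected, non-crashing flow (model/example_args as in the repro below).
import tvm
from tvm import relax
from torch.export import export
from tvm.relax.frontend.torch import from_exported_program

ep = export(model, example_args)        # torch.export.ExportedProgram
mod = from_exported_program(ep)         # expected: a tvm.IRModule, no segfault
mod = relax.get_pipeline("zero")(mod)   # default Relax lowering pipeline (assumed usage)
ex = relax.build(mod, target="llvm")    # compile for CPU
vm = relax.VirtualMachine(ex, tvm.cpu())
```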
### Actual behavior
`torch.export` succeeds but prints the `get_attr` lifting warnings described above.
Immediately afterwards, `tvm.relax.frontend.torch.from_exported_program(ep)`
triggers an FFI segfault. In my run the FFI backtrace ends in
`tvm::relax::Tuple::Tuple(...)`:
```
!!!!!!! TVM FFI encountered a Segfault !!!!!!!
... tvm::relax::Tuple::Tuple(...) ...
Segmentation fault (core dumped)
```
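To see what the lifting warning refers to, one can dump the `get_attr` nodes in the exported graph and check whether each target actually resolves on the lifted graph module. This is a diagnostic sketch of mine (reusing `model`/`args` from the repro below), not part of TVM's importer:
```python
from operator import attrgetter

ep = torch_export(model, args)     # same model/args as in the repro script
gm = ep.graph_module               # lifted torch.fx.GraphModule
for node in gm.graph.nodes:
    if node.op == "get_attr":
        try:
            attrgetter(node.target)(gm)
            status = "backed"
        except AttributeError:
            status = "MISSING (what the export warning points at)"
        print(f"get_attr {node.target}: {status}")
```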
### Environment
+ OS: Ubuntu 22.04.4 LTS (x86_64)
+ TVM version: release v0.21.0
+ Python: 3.10.16
+ LLVM: 17.0.6
+ PyTorch: 2.7.1
### Steps to reproduce
```python
# mini_repro_export_tvm_segfault.py
import math
import torch
import torch.nn as nn


def get_attn_pad_mask(seq_q, seq_k):
    B, Lq = seq_q.size()
    B2, Lk = seq_k.size()
    assert B == B2
    pad_attn_mask = seq_k.data.eq(0).unsqueeze(1)  # (B,1,Lk)
    return pad_attn_mask.expand(B, Lq, Lk)         # (B,Lq,Lk)


class TinyMHA(nn.Module):
    def __init__(self, d_model=64, d_k=16, n_heads=4, dropout=0.1):
        super().__init__()
        self.h, self.dk = n_heads, d_k
        self.W_Q = nn.Linear(d_model, d_k * n_heads, bias=False)
        self.W_K = nn.Linear(d_model, d_k * n_heads, bias=False)
        self.W_V = nn.Linear(d_model, d_k * n_heads, bias=False)
        self.proj = nn.Linear(d_k * n_heads, d_model, bias=False)
        self.ln = nn.LayerNorm(d_model)
        self.drop = nn.Dropout(dropout)

    def forward(self, x, attn_mask):  # x: (B,L,dm), attn_mask: (B,L,L)
        B, L, _ = x.shape
        q = self.W_Q(x).view(B, L, self.h, self.dk).transpose(1, 2)  # (B,H,L,dk)
        k = self.W_K(x).view(B, L, self.h, self.dk).transpose(1, 2)
        v = self.W_V(x).view(B, L, self.h, self.dk).transpose(1, 2)
        scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.dk)  # (B,H,L,L)
        # In-place masked_fill_ with a broadcasted mask coming from eq(0) + expand
        scores.masked_fill_(attn_mask.unsqueeze(1), -1e9)
        attn = torch.softmax(scores, dim=-1)
        ctx = torch.matmul(attn, v).transpose(1, 2).reshape(B, L, self.h * self.dk)
        out = self.drop(self.proj(ctx))
        return self.ln(out + x)


class MiniModel(nn.Module):
    def __init__(self, vocab=10000, d_model=64):
        super().__init__()
        self.emb = nn.Embedding(vocab, d_model)
        self.mha = TinyMHA(d_model=d_model, d_k=16, n_heads=4, dropout=0.1)
        self.proj = nn.Linear(d_model, vocab, bias=False)

    def forward(self, enc_inputs, dec_inputs_unused=None):
        x = self.emb(enc_inputs)                          # (B,L,dm)
        mask = get_attn_pad_mask(enc_inputs, enc_inputs)  # (B,L,L)
        y = self.mha(x, mask)                             # (B,L,dm)
        logits = self.proj(y)                             # (B,L,V)
        return logits.reshape(-1, logits.size(-1))        # (B*L, V)


def my_model_function():
    return MiniModel()


def GetInput():
    enc = torch.randint(0, 10000, (2, 5))
    enc[0, 0] = 0  # ensure the eq(0) path is taken
    dec = torch.randint(0, 10000, (2, 5))
    return (enc, dec)


import numpy as np
from torch.export import export as torch_export
from tvm.relax.frontend.torch import from_exported_program


def trigger_known_bugs(model=None):
    if model is None:
        model = my_model_function()
    torch.manual_seed(42)
    np.random.seed(42)
    model.eval()
    args = GetInput()
    ep = torch_export(model, args)   # emits the get_attr lifting warnings
    mod = from_exported_program(ep)  # <-- TVM segfaults here in my environment
    print(mod)


if __name__ == "__main__":
    import os
    os.environ.setdefault("CUDA_VISIBLE_DEVICES", "6,7")
    trigger_known_bugs()
```
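If it helps narrow things down, my guess (untested, purely an assumption) is that the lifted `get_attr` constants come from the `.data` access in `get_attn_pad_mask` and/or the in-place `masked_fill_`; an out-of-place variant of both keeps the same semantics:
```python
# Hypothetical narrowing step (assumption, not a confirmed fix): drop .data and
# use the out-of-place masked_fill so torch.export has no in-place op to lift.
def get_attn_pad_mask(seq_q, seq_k):
    B, Lq = seq_q.size()
    _, Lk = seq_k.size()
    pad_attn_mask = seq_k.eq(0).unsqueeze(1)  # (B,1,Lk), no .data access
    return pad_attn_mask.expand(B, Lq, Lk)    # (B,Lq,Lk)

# ...and in TinyMHA.forward:
# scores = scores.masked_fill(attn_mask.unsqueeze(1), -1e9)  # out-of-place
```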
### Triage
* needs-triage
* bug