tinywisdom opened a new issue, #18363:
URL: https://github.com/apache/tvm/issues/18363
### Summary
Importing a `torch.export`-ed program into TVM Relax triggers a segmentation fault inside the FFI during construction of a Relax `Tuple`. The minimal model performs a 4D advanced-indexing write using two integer index tensors on the last two dims (`L[..., idx, idx] = ...`) and returns a Python tuple of tensors (`x[..., :1], L`). The exported graph contains no RNG ops (no `randn`), so the crash appears to be related to the combination of `aten.index_put_` lowering and tuple-output construction.
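For readers less familiar with the trigger, here is a pure-Python sketch of what `L[..., idx, idx] = ...` does. The two integer index tensors are paired element-wise, so they address the N diagonal entries of each N x N slice (giving shape `[B, K, N]`) rather than an N x N block. The nested-list representation and the helper names `elu` and `diag_index_put` are illustrative assumptions, not PyTorch or TVM APIs; the sketch only mirrors the semantics that the `aten.index_put_` lowering has to reproduce.

```python
import math


def elu(v, alpha=1.0):
    # Same formula as torch.nn.functional.elu with its default alpha=1.0
    return v if v > 0 else alpha * (math.exp(v) - 1.0)


def diag_index_put(L, idx, fn):
    """Hypothetical helper mimicking `L[..., idx, idx] = fn(L[..., idx, idx])`.

    L is a nested list of shape [B][K][N][N]; idx is a list of N ints.
    """
    pairs = list(zip(idx, idx))  # element-wise pairing: (0, 0), (1, 1), ...
    for batch in L:
        for mat in batch:  # one N x N slice per (b, k); "..." covers these dims
            gathered = [mat[i][j] for i, j in pairs]  # like L[..., idx, idx]: N values
            for (i, j), v in zip(pairs, gathered):
                mat[i][j] = fn(v)  # like L[..., idx, idx] = diag
    return L


# Same shapes and transform as the repro below
B, K, N = 2, 3, 5
L = [[[[0.0] * N for _ in range(N)] for _ in range(K)] for _ in range(B)]
diag_index_put(L, list(range(N)), lambda v: elu(v) + 1.0 + 1e-8)
```

With the all-zero tensor used in the repro, every diagonal entry becomes `elu(0) + 1.0 + 1e-8` and all off-diagonal entries stay zero.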
### Actual behavior
```
[1] torch.export ...
=== Exported ops ===
... (as above)
[2] tvm.relax.frontend.torch.from_exported_program ...
!!!!!!! TVM FFI encountered a Segfault !!!!!!!
...
tvm::relax::Tuple::Tuple(tvm::ffi::Array<tvm::RelaxExpr, void>, tvm::Span) [clone .cold]
...
Segmentation fault (core dumped)
```
### Environment
- OS: Ubuntu 22.04.4 LTS (x86_64)
- TVM version: release v0.21.0
- Python: 3.10.16
- LLVM: 17.0.6
- PyTorch: 2.8.0
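A small sketch for collecting most of these fields when re-running the repro on another machine. `module_version` and `environment_report` are hypothetical helper names, and the sketch assumes `tvm` and `torch` expose `__version__`. Note that `platform.platform()` reports a kernel/platform string rather than the distribution name, and the LLVM version is not covered, so those still need to be filled in by hand.

```python
import importlib
import platform


def module_version(name):
    """Return `name.__version__` if the module imports cleanly, else None."""
    try:
        mod = importlib.import_module(name)
    except Exception:  # missing or broken install
        return None
    return getattr(mod, "__version__", None)


def environment_report():
    # LLVM is intentionally omitted; fill it in manually
    return {
        "OS": platform.platform(),
        "Python": platform.python_version(),
        "TVM": module_version("tvm"),
        "PyTorch": module_version("torch"),
    }


if __name__ == "__main__":
    for key, value in environment_report().items():
        print(f"- {key}: {value}")
```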
### Steps to reproduce
```python
import os
os.environ["CUDA_VISIBLE_DEVICES"] = ""  # avoid GPU warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.export import export as torch_export
from tvm.relax.frontend.torch import from_exported_program


class M4D(nn.Module):
    def forward(self, x):
        B, K, N = 2, 3, 5
        L = x.new_zeros(B, K, N, N)  # tensor construction only; no randomness
        idx = torch.arange(N, device=x.device)

        # Key trigger: gather the diagonal, apply a smooth monotonic transform, scatter it back
        diag = L[..., idx, idx]  # shape: [B, K, N]
        diag = F.elu(diag) + 1.0 + 1e-8  # avoid all-zero; any smooth transform works
        L[..., idx, idx] = diag  # advanced-indexing write (two integer index tensors)

        # Key trigger: return a Python-level tuple of two tensors
        return x[..., :1], L


if __name__ == "__main__":
    torch.manual_seed(0)
    m = M4D().eval()
    ex_in = torch.zeros(2, 3, 5)  # any input works; ensures no randn is exported

    print("[1] torch.export ...")
    ep = torch_export(m, (ex_in,))

    # Sanity check: list the exported ops
    try:
        print("=== Exported ops ===")
        for n in ep.graph.nodes:
            print(getattr(n, "op", None), getattr(n, "target", None))
    except Exception:
        pass

    print("[2] tvm.relax.frontend.torch.from_exported_program ...")
    mod = from_exported_program(ep)  # <-- segfaults inside FFI Tuple construction
    print("[OK] Converted without segfault (if you see this, env may differ)")
```
### Triage
* needs-triage
* bug
--
This is an automated message from the Apache Git Service.