tinywisdom opened a new issue, #18338:
URL: https://github.com/apache/tvm/issues/18338
### Expected behavior
+ `unbind` producing a single-element tuple should be represented as a tuple in Relax IR, not misinterpreted as a tensor.
+ The frontend should either:
  + Correctly lower to a `Tuple` with one element (each tensor shaped `(3,)` in this case), or
  + Gracefully reject with a clear Python exception, not an internal assertion failure.
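
To make the expected semantics concrete, here is a minimal pure-Python sketch of unbinding along axis 0. It uses nested lists as a stand-in for tensors, and `unbind_dim0` is a hypothetical helper written for illustration only; it is neither PyTorch's nor TVM's implementation. The point it shows is that the result stays a tuple even when the axis has size 1.

```python
from typing import Any, List, Tuple


def unbind_dim0(rows: List[List[Any]]) -> Tuple[List[Any], ...]:
    """Split a nested-list "tensor" along axis 0, dropping that axis.

    Hypothetical helper for illustration; not a PyTorch or TVM API.
    """
    # One output per slice along axis 0. The container is always a tuple,
    # including the case where there is exactly one slice.
    return tuple(list(row) for row in rows)


# Stands in for `compared`, a tensor of shape (1, 3).
compared = [[0.5, -1.0, 2.0]]
outs = unbind_dim0(compared)

# The result is a tuple of length 1 whose only element has 3 entries.
print(type(outs).__name__, len(outs), len(outs[0]))  # prints: tuple 1 3
```

Under this reading, the importer would be expected to produce a one-element tuple of `(3,)` tensors rather than a single `(1, 3)` tensor.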
### Actual behavior
When importing a PyTorch program exported with `torch.export` into TVM Relax, if the
model applies `unbind(dim=0)` to a tensor whose size along that axis is 1, the
frontend crashes with:
```
Check failed: (opt) is false: The struct info of Tuple must be
TupleStructInfo,
but expression lv3 has struct info R.Tensor((1, 3), dtype="float32")
```
This indicates a mismatch between the expected `TupleStructInfo` and the actual
`TensorStructInfo`. Full output of the reproduction script:
```
[INFO] start importing exported program into TVM Relax...
[REPRODUCED] Caught exception while importing:
Check failed: (opt) is false: The struct info of Tuple must be
TupleStructInfo,
but expression lv3 has struct info R.Tensor((1, 3), dtype="float32")
tvm.error.InternalError: Check failed: (opt) is false: The struct info of
Tuple must be TupleStructInfo, but expression lv3 has struct info R.Tensor((1,
3), dtype="float32")
[...]/src/relax/ir/block_builder.cc:64: Warning: BlockBuilder destroyed with
remaining blocks!
```
### Environment
+ OS: (Ubuntu 22.04.4 LTS (x86_64))
+ TVM version: (release v0.21.0)
+ Python: (3.10.16)
+ LLVM: (17.0.6)
### Steps to reproduce
```python
import torch
import torch.nn as nn


# Minimal model: gather -> max -> unbind(0)
class MyModel(nn.Module):
    def forward(self, x, y):
        # x: (2, 3), y: (1, 3)
        indices = torch.zeros((1,) + x.size()[1:], dtype=torch.long, device=x.device)
        x_gathered = torch.gather(x, 0, indices)  # (1, 3)
        compared = torch.max(x_gathered, y)  # (1, 3)
        outs = compared.unbind(0)  # tuple of length 1, each (3,)
        return outs


def get_inputs():
    torch.manual_seed(0)
    x = torch.randn(2, 3)
    y = torch.randn(1, 3)
    return x, y


if __name__ == "__main__":
    from torch.export import export as torch_export
    from tvm.relax.frontend.torch import from_exported_program

    m = MyModel().eval()
    args = get_inputs()
    ep = torch_export(m, args)
    print("[INFO] start importing exported program into TVM Relax...")
    try:
        mod = from_exported_program(ep)
        print("[UNEXPECTED] Import succeeded (no error).")
    except Exception as e:
        print("[REPRODUCED] Caught exception while importing:")
        print(e)
        raise
```
### Triage
* needs-triage
* bug