tinywisdom opened a new issue, #18338:
URL: https://github.com/apache/tvm/issues/18338

   ### Expected behavior
   
   + `unbind` producing a single-element tuple should be represented as a tuple in Relax IR, not misinterpreted as a tensor.
   
   + The frontend should either:
   
     + Correctly lower to a `Tuple` with one element (a tensor of shape `(3,)` in this case), or
   
     + Gracefully reject with a clear Python exception, not an internal assertion failure.
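
For reference, eager PyTorch's `unbind` returns a tuple whose length equals the size of the unbound dimension, and this holds even when that size is 1. The following check uses only PyTorch (no TVM) and is meant to show the reference semantics the frontend output should match:

```python
import torch

# Eager PyTorch reference: unbind(dim) returns a tuple with one entry per
# slice along `dim`, even when that dimension has size 1.
one = torch.zeros(1, 3).unbind(0)  # size-1 axis, the case from this issue
two = torch.zeros(2, 3).unbind(0)  # size-2 axis, for comparison

print(type(one).__name__, len(one), tuple(one[0].shape))  # tuple 1 (3,)
print(type(two).__name__, len(two), tuple(two[0].shape))  # tuple 2 (3,)
```

The first line is the case that should map to a one-element `Tuple` in Relax.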
   
   ### Actual behavior
   
   When a program exported with `torch.export` is imported into TVM Relax and the model applies `unbind(dim=0)` to a tensor whose size along that axis is 1, the frontend crashes with:
   ```
   Check failed: (opt) is false: The struct info of Tuple must be TupleStructInfo,
   but expression lv3 has struct info R.Tensor((1, 3), dtype="float32")
   ```
   This indicates a mismatch between the expected `TupleStructInfo` and the actual `TensorStructInfo`.
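
The tensor shape in the error, `(1, 3)`, is the shape of `compared`. That tensor is the input to `unbind` in the model under Steps to reproduce, and its shape differs from the `(3,)` shape of the single unbound element. This fits a tensor showing up where a one-element tuple was expected, but it does not by itself show which frontend step is at fault. A quick eager-PyTorch check of the two shapes (tensor values are irrelevant here):

```python
import torch

# Same ops as the reproducer's forward(), run eagerly to inspect shapes only.
x = torch.randn(2, 3)
y = torch.randn(1, 3)
indices = torch.zeros((1,) + x.size()[1:], dtype=torch.long)
compared = torch.max(torch.gather(x, 0, indices), y)  # input to unbind
outs = compared.unbind(0)                              # result of unbind

print(tuple(compared.shape))                      # (1, 3)
print(len(outs), [tuple(t.shape) for t in outs])  # 1 [(3,)]
```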
   
   ```
   [INFO] start importing exported program into TVM Relax...
   [REPRODUCED] Caught exception while importing:
   Check failed: (opt) is false: The struct info of Tuple must be TupleStructInfo,
   but expression lv3 has struct info R.Tensor((1, 3), dtype="float32")
   
   tvm.error.InternalError: Check failed: (opt) is false: The struct info of Tuple must be TupleStructInfo, but expression lv3 has struct info R.Tensor((1, 3), dtype="float32")
   [...]/src/relax/ir/block_builder.cc:64: Warning: BlockBuilder destroyed with remaining blocks!
   ```
   
   ### Environment
   
   + OS: Ubuntu 22.04.4 LTS (x86_64)
   + TVM version: release v0.21.0
   + Python: 3.10.16
   + LLVM: 17.0.6
   
   ### Steps to reproduce
   
   ```python
   import torch
   import torch.nn as nn
   
   # Minimal model: gather -> max -> unbind(0)
   class MyModel(nn.Module):
       def forward(self, x, y):
           # x: (2, 3), y: (1, 3)
           indices = torch.zeros((1,) + x.size()[1:], dtype=torch.long,
                                 device=x.device)
           x_gathered = torch.gather(x, 0, indices)   # (1, 3)
           compared = torch.max(x_gathered, y)        # (1, 3)
           outs = compared.unbind(0)                  # tuple of length 1, element shape (3,)
           return outs
   
   def get_inputs():
       torch.manual_seed(0)
       x = torch.randn(2, 3)
       y = torch.randn(1, 3)
       return x, y
   
   if __name__ == "__main__":
       from torch.export import export as torch_export
       from tvm.relax.frontend.torch import from_exported_program
   
       m = MyModel().eval()
       args = get_inputs()
   
       ep = torch_export(m, args)
   
       print("[INFO] start importing exported program into TVM Relax...")
       try:
           mod = from_exported_program(ep)
           print("[UNEXPECTED] Import succeeded (no error).")
       except Exception as e:
           print("[REPRODUCED] Caught exception while importing:")
           print(e)
           raise
   ```
   
   ### Triage
   
   * needs-triage
   * bug
   