LiSsHhUuAaIi opened a new issue, #18443:
URL: https://github.com/apache/tvm/issues/18443
### Description
When converting a PyTorch model that applies `torch.sqrt` to an integer
tensor (a common pattern when computing attention scaling factors) to a TVM
Relax module via `torch.export`, an InternalError occurs. PyTorch's
`torch.sqrt` automatically promotes integer inputs to a floating-point dtype,
but TVM's `relax.sqrt` requires a float dtype and rejects integer inputs.
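
The promotion on the PyTorch side can be checked in isolation. This is a minimal sketch, assuming the default floating dtype has not been changed from `torch.float32`; the variable names are illustrative and not taken from any real model:

```python
import torch

# Integer 0-dim tensor, like the one built from x.size(-1) in an attention block.
size_tensor = torch.tensor(512)
print(size_tensor.dtype)  # torch.int64

# torch.sqrt promotes integer inputs to the default floating dtype.
root = torch.sqrt(size_tensor)
print(root.dtype)  # torch.float32 when the default dtype is unchanged
```
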
### Expected behavior
A PyTorch model that applies `sqrt` to integer tensors should convert
successfully to a TVM Relax module, matching PyTorch's automatic type
conversion.
### Actual behavior
An InternalError occurs during `from_exported_program` conversion with the
message `Op(relax.sqrt) requires the input tensor to have float dtype. However,
the given input dtype is int64`, indicating that, unlike PyTorch, TVM's sqrt
does not convert integer inputs automatically.
### Environment
* **OS:** Ubuntu 20.04.6 LTS
* **TVM version:** 0.23.dev0
* **Python version:** 3.11.14
### Steps to reproduce
```python
import torch
import torch.nn as nn
import tvm
from tvm import relax

class SqrtModel(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Common pattern in attention: sqrt of integer dimension
        size_tensor = torch.tensor(x.size(-1)) # Integer tensor
        scaling_factor = 1.0 / torch.sqrt(size_tensor) # PyTorch auto-converts
        return x * scaling_factor

model = SqrtModel()
model.eval()

x = torch.randn(2, 64, 512)

# PyTorch execution works (automatic type conversion)
with torch.no_grad():
    output = model(x)

# PyTorch export works
exported_program = torch.export.export(model, (x,))

# TVM conversion fails
from tvm.relax.frontend.torch import from_exported_program
mod = from_exported_program(exported_program) # InternalError here
```
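
Until the frontend handles this case, one possible workaround is to keep the integer out of `torch.sqrt` altogether. The sketch below is an assumption, not a confirmed fix: `SqrtModelWorkaround` is a hypothetical name, it relies on `x.size(-1)` being a plain Python int (static shapes), and it has not been verified against `from_exported_program` in TVM 0.23.dev0. It only checks that the result matches the original model in PyTorch:

```python
import math

import torch
import torch.nn as nn


class SqrtModelWorkaround(nn.Module):
    """Hypothetical variant of SqrtModel that avoids sqrt on an integer tensor."""

    def forward(self, x):
        # With static shapes, x.size(-1) is a Python int, so math.sqrt returns
        # a Python float and no integer tensor reaches torch.sqrt.
        scaling_factor = 1.0 / math.sqrt(x.size(-1))
        return x * scaling_factor


workaround = SqrtModelWorkaround().eval()
x = torch.randn(2, 64, 512)
with torch.no_grad():
    out = workaround(x)

# Reference value computed the same way as the model in the reproduction.
expected = x * (1.0 / torch.sqrt(torch.tensor(x.size(-1))))
print(torch.allclose(out, expected, rtol=1e-5, atol=1e-6))
```

Casting the integer tensor explicitly before calling `torch.sqrt` is another option, though whether the resulting cast op converts cleanly is likewise untested here.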
### Error Log
```
Traceback (most recent call last):
  File "test.py", line 30, in <module>
    mod = from_exported_program(exported_program) # InternalError here
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
...
tvm.error.InternalError: Op(relax.sqrt) requires the input tensor to have float dtype. However, the given input dtype is int64
```
### Triage
* needs-triage
* bug
* frontend: pytorch