grf53 opened a new pull request, #18556:
URL: https://github.com/apache/tvm/pull/18556
## Description
This PR fixes a conversion bug that occurs when performing operations on
`bfloat16` tensors.
In short, when the `BF16ComputeLegalize` compile pass visits a `BufferStoreNode`
whose stored value has a different dtype from the buffer, `DTypeConversion()`
should be used instead of a plain `cast` so that the appropriate conversion
logic is applied.
## Test
I added a test for this situation based on the existing tests.
With the fix, `B[i] = A[i]` is properly rewritten to `B[i] = bf16tof32(A[i])`,
so the test passes.
I'm not sure whether the structure or name of the added test is appropriate,
so I'm happy to modify it if there are any comments.
## Process
### Problem observed
This bug was identified when applying `nn.Linear()` to a `bfloat16` tensor
produced excessively large values.
While it appears to affect other operations as well, it is particularly
noticeable when the inner dimension of `MatMul` is a multiple of `8` (`16` for
CUDA and ROCm).
#### Example of problematic code
```python
import numpy as np
import tvm
from ml_dtypes import bfloat16
from tvm.relax.frontend import nn
from tvm.relax.frontend.nn import Tensor, op
from tvm.target import Target

dtype = "bfloat16"
n = 10
INNER_DIM = 8 * n  # the bug shows up when INNER_DIM is a multiple of 8


class TestModule(nn.Module):
    def __init__(self):
        self.weight = nn.Parameter((32, INNER_DIM), dtype=dtype)

    def run(self, x: Tensor):
        t = op.matmul(self.weight, x, out_dtype=dtype)
        return t

    def get_default_spec(self):
        mod_spec = {
            "run": {
                "x": nn.spec.Tensor([INNER_DIM, 100], dtype),
                "$": {
                    "param_mode": "packed",
                    "effect_mode": "none",
                },
            },
        }
        return nn.spec.ModuleSpec.from_raw(mod_spec, self)


def compile_module(model, target):
    ...  # body elided in the original report


def main():
    target = "metal"  # or "cuda", "vulkan", ...
    model = TestModule()
    ex, _ = compile_module(model, target)
    device = tvm.device(target, 0)
    vm = create_vm(ex, device=device)  # create_vm: helper not shown in the report
    frun = vm["run"]
    params = []
    param = tvm.runtime.empty(
        (32, INNER_DIM),
        dtype="bfloat16",
        device=device,
    )
    param.copyfrom(np.ones((32, INNER_DIM), dtype=bfloat16))
    params.append(param)
    inputs = np.ones((INNER_DIM, 100), dtype=bfloat16)
    arr = frun(inputs, params)
    print(f"{arr=}")  # arr has weird values!


if __name__ == "__main__":
    main()
```
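For comparison, the expected output is easy to compute on the host: with all-ones inputs, every element of `weight @ x` equals `INNER_DIM` (80 here). The NumPy sketch below is a hypothetical helper for checking the VM output (`reference_matmul` and `looks_correct` are not part of TVM or of this PR). It assumes the result has already been copied into a NumPy array, for example via `arr.numpy()`.

```python
import numpy as np


def reference_matmul(m: int, inner_dim: int, n: int) -> np.ndarray:
    """Expected output of the example: ones(m, K) @ ones(K, n) computed in float32."""
    weight = np.ones((m, inner_dim), dtype=np.float32)
    x = np.ones((inner_dim, n), dtype=np.float32)
    return weight @ x


def looks_correct(result: np.ndarray, inner_dim: int) -> bool:
    """Compare a host copy of the VM result against the float32 reference."""
    expected = reference_matmul(result.shape[0], inner_dim, result.shape[1])
    # bfloat16 keeps 8 significant bits, so a small integer like 80 is exact;
    # a loose relative tolerance still covers rounding in larger cases.
    return np.allclose(result.astype(np.float32), expected, rtol=1e-2)
```

With the bug present, the values are off by orders of magnitude, so even this loose tolerance separates a correct run from a broken one.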
When the inner dimension is not a multiple of `8` (or `16`), the issue is
avoided because `PadEinsum` wraps the load in `T.if_then_else()`.
`PadEinsum` itself was not the cause; rather, it helped identify the issue.
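Conceptually, the guard introduced by padding zero-fills the tail of the reduction axis, so the padded matmul gives the same result as the unpadded one. The NumPy sketch below only illustrates that semantics; `pad_inner_dim` is a hypothetical helper and not how TVM's `PadEinsum` is implemented.

```python
import numpy as np


def pad_inner_dim(weight: np.ndarray, multiple: int) -> np.ndarray:
    """Zero-pad the last axis of `weight` up to a multiple of `multiple`.

    Mirrors a guarded load such as
    T.if_then_else(v2 < K, weight[v1, v2], 0.0): positions at or past the
    original extent K read as zero instead of touching the buffer.
    """
    rows, k = weight.shape
    padded_k = -(-k // multiple) * multiple  # ceil(k / multiple) * multiple
    out = np.empty((rows, padded_k), dtype=weight.dtype)
    for v2 in range(padded_k):
        out[:, v2] = weight[:, v2] if v2 < k else 0.0
    return out
```

For example, a reduction extent of 511 is padded to 512, and the extra column contributes nothing to the dot products.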
### Problem identified
I found that the problem was avoided when the expression was wrapped with
`T.if_then_else()` or `T.Cast()` before the `BF16ComputeLegalize` compile pass
was applied.
#### Statement with problem
```python
weight_reindex_shared[v0, v1, v2] = weight[v1, v2]
```
#### Statements without problem
```python
# 1) wrapped with T.if_then_else()
weight_reindex_pad_shared[v0, v1, v2] = T.if_then_else(
    v2 < 511, weight[v1, v2], T.bfloat16(0.0)
)
# 2) wrapped with T.Cast()
weight_reindex_pad_shared[v0, v1, v2] = T.Cast("float32", weight[v1, v2])
# ...
```
In the `BF16ComputeLegalize` compile pass, when an `Expr` (here,
`weight[...]`) is processed by `PromoteToTarget()` (which eventually calls
`DTypeConversion()`), it is rewritten into the TO-BE form below, which applies
the proper conversion logic. The problematic statement, by contrast, only gets
a plain `T.Cast()` (AS-IS).
#### AS-IS
```python
T.Cast("float32", weight[...])
```
#### TO-BE
```python
T.reinterpret(
    "float32",
    T.shift_left(T.Cast("uint32", T.reinterpret("uint16", weight[...])), T.uint32(16)),
)
```
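To see why the AS-IS form can produce huge numbers, the NumPy sketch below emulates both expressions on raw bit patterns. It assumes (our reading, not stated in the PR) that the bfloat16 value is eventually held as its `uint16` bit pattern, so a plain cast converts that integer numerically: `1.0` in bfloat16 is `0x3F80`, which a numeric cast turns into `16256.0`. The helper names are hypothetical.

```python
import numpy as np


def f32_to_bf16_bits(values) -> np.ndarray:
    """Keep the high 16 bits of each float32 (truncation, exact for the values used here)."""
    f32 = np.atleast_1d(np.asarray(values, dtype=np.float32))
    return (f32.view(np.uint32) >> np.uint32(16)).astype(np.uint16)


def as_is_cast(bits: np.ndarray) -> np.ndarray:
    """AS-IS: a numeric cast of the raw uint16 bit pattern to float32."""
    return bits.astype(np.float32)


def to_be_conversion(bits: np.ndarray) -> np.ndarray:
    """TO-BE: reinterpret("float32", shift_left(cast("uint32", bits), 16))."""
    widened = bits.astype(np.uint32) << np.uint32(16)
    return widened.view(np.float32)
```

For `[1.0, 2.0, -3.5]`, the TO-BE path recovers the original values, while the AS-IS path yields `16256.0`, `16384.0` and `49248.0`.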
### Fixing the problem
This situation is caused by L332 in the code below. Changing that line to
apply `DTypeConversion()` instead of `cast()` resolves the issue. (When the
`Expr` is wrapped with `T.if_then_else()` or a similar expression, it is
processed properly by the other visit functions through L312 or L313, which is
why the problem was avoided in those cases.)
#### L332
```diff
- value = cast(new_buf->dtype.with_lanes(value.dtype().lanes()), value);
+ value = DTypeConversion(value,
+                         new_buf->dtype.with_lanes(value.dtype().lanes()));
```
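The effect of this one-line change can be modelled with a toy rewrite of the `BufferStore` visitor. Everything below (`Load`, `Cast`, `BF16ToF32`, `Store`, `legalize_store`, `run_store`) is a hypothetical mini-IR written for illustration, not TVM's API. It assumes bfloat16 values are held as raw `uint16` bit patterns, so a plain cast converts the bits numerically while the conversion shifts them into the high half of a float32.

```python
import struct
from dataclasses import dataclass
from typing import Union


@dataclass
class Load:  # reads a raw bfloat16 bit pattern (held as uint16) from a buffer
    buffer: str
    index: int
    dtype: str = "bfloat16"


@dataclass
class Cast:  # numeric cast, standing in for the AS-IS cast()
    dtype: str
    value: "Expr"


@dataclass
class BF16ToF32:  # bit-level conversion, standing in for DTypeConversion()
    value: "Expr"
    dtype: str = "float32"


Expr = Union[Load, Cast, BF16ToF32]


@dataclass
class Store:
    buffer: str
    index: int
    buffer_dtype: str
    value: Expr


def legalize_store(store: Store, fixed: bool) -> Store:
    """Toy BufferStore visitor: wrap the value when its dtype differs from the buffer's."""
    if store.value.dtype == store.buffer_dtype:
        return store
    if fixed:
        new_value = BF16ToF32(store.value)
    else:
        new_value = Cast(store.buffer_dtype, store.value)
    return Store(store.buffer, store.index, store.buffer_dtype, new_value)


def evaluate(e: Expr, memory) -> float:
    if isinstance(e, Load):
        return memory[e.buffer][e.index]  # raw uint16 bits
    if isinstance(e, Cast):
        return float(evaluate(e.value, memory))  # numeric cast of the bits
    bits = int(evaluate(e.value, memory)) << 16  # move bits into the high half
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def run_store(store: Store, memory) -> None:
    memory[store.buffer][store.index] = evaluate(store.value, memory)
```

Running `B[0] = A[0]` with `A[0] = 0x3F80` (bfloat16 `1.0`) stores `1.0` when `fixed=True` and `16256.0` when `fixed=False`.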
https://github.com/apache/tvm/blob/26b107fa12672c3b958da222fc87755a69d64c42/src/tir/transforms/unsupported_dtype_legalize.cc#L311-L338