cblmemo opened a new pull request, #14413:
URL: https://github.com/apache/tvm/pull/14413
This PR provides a workaround for a bug that occurs when the schedule primitive
`reverse_compute_inline` creates a predicate. Previously, the block structure
looked like this:
```
For v0 in ... {
  For v1 in ... {
    BlockRealize {
      Predicate; // <-- `reverse_compute_inline` will generate a predicate here
      LoopBody;
    }
  }
}
```
Here the predicate depends on the loop vars, such as `v0` and `v1`, which are
already in scope at that point. However, the block generated by memhammer looks
like this:
```
BlockRealize {
  Predicate; // <-- `reverse_compute_inline` will generate a predicate here
  For v0 in ... {
    For v1 in ... {
      LoopBody;
    }
  }
}
```
Because the predicate now sits outside the loops, every loop var it references
is used before it is defined, which produces erroneous TIR like this:
```python
from tvm.script import ir as I
from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(a: T.Buffer((128,), "float32"), b: T.Buffer((128,), "float32")):
        T.func_attr({"global_symbol": "main"})
        # with T.block("root"):
        b_reindex_shared_dyn = T.alloc_buffer((32, 4), scope="shared.dyn")
        for i in T.thread_binding(32, thread="threadIdx.x"):
            for j in range(4):
                with T.block("B"):
                    v0, v1 = T.axis.remap("SS", [i, j])
                    T.reads(a[v0 * 4 + v1])
                    T.writes(b_reindex_shared_dyn[v0, v1])
                    b_reindex_shared_dyn[v0, v1] = a[v0 * 4 + v1] * T.float32(2)
                ax0 = T.int32()
                ax1 = T.int32()
                with T.block("b_reindex_shared.dyn"):
                    v0 = T.axis.spatial(4, j)
                    T.where(0 <= ax0 and ax0 < 32 and 0 <= j + ax1 and j + ax1 < 4)  # ax0, ax1 used here
                    T.reads(b_reindex_shared_dyn[0:32, v0])
                    T.writes(b[v0:v0 + 125])
                    T.block_attr({"auto_copy": 1})
                    for ax0, ax1 in T.grid(32, 1):  # ax0, ax1 defined here
                        b[ax0 * 4 + (v0 + ax1)] = b_reindex_shared_dyn[ax0, v0 + ax1]
```
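The scoping problem can be illustrated with a minimal sketch in plain Python. It does not use TVM. `For`, `BlockRealize`, `LoopBody`, and `use_before_def` are hypothetical stand-ins that only model the two nesting orders described in this PR. The checker reports any predicate var that no enclosing loop has defined yet.

```python
from dataclasses import dataclass
from typing import Set, Union

# Hypothetical, simplified IR nodes: NOT TVM's TIR classes, only a model of
# where the predicate sits relative to the loops that define its vars.
@dataclass
class For:
    var: str
    body: "Node"

@dataclass
class BlockRealize:
    predicate_vars: Set[str]  # loop vars referenced by the block predicate
    body: "Node"

@dataclass
class LoopBody:
    pass

Node = Union[For, BlockRealize, LoopBody]

def use_before_def(node: Node, defined: frozenset = frozenset()) -> Set[str]:
    """Return predicate vars referenced before an enclosing loop defines them."""
    if isinstance(node, For):
        # A loop brings its var into scope for everything nested inside it.
        return use_before_def(node.body, defined | {node.var})
    if isinstance(node, BlockRealize):
        undefined = node.predicate_vars - defined
        return undefined | use_before_def(node.body, defined)
    return set()

# Previous structure: the loops wrap the BlockRealize, so v0 and v1 are in scope.
old = For("v0", For("v1", BlockRealize({"v0", "v1"}, LoopBody())))
# Memhammer structure: the BlockRealize wraps the loops, so its predicate
# refers to v0 and v1 before they are defined.
new = BlockRealize({"v0", "v1"}, For("v0", For("v1", LoopBody())))

print(sorted(use_before_def(old)))  # []
print(sorted(use_before_def(new)))  # ['v0', 'v1']
```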
For now, this PR provides a workaround for one special case: predicates that
are constant true. By checking whether the predicate can be proved to be true,
we can drop it and eliminate this error in that case, which is a common one. We
still need to discuss how to fix the underlying bug properly.
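The workaround depends on proving that the predicate holds over the whole loop domain. The sketch below shows one simple way to do this: interval bounds on affine expressions, applied to the predicate from the TIR above. It is not the PR's code. The representation and the names `affine_bounds` and `provably_true` are hypothetical. In TVM, such a check would normally go through the arithmetic analyzer, for example `tvm.arith.Analyzer.can_prove`, rather than this hand-rolled version.

```python
from typing import Dict, List, Tuple

# Hypothetical representation: an affine expression is ({var: coeff}, const),
# and a domain maps each loop var to its inclusive integer range (lo, hi).
Affine = Tuple[Dict[str, int], int]
Domain = Dict[str, Tuple[int, int]]

def affine_bounds(expr: Affine, dom: Domain) -> Tuple[int, int]:
    """Inclusive (min, max) of an affine expression over a box-shaped domain."""
    coeffs, const = expr
    lo = hi = const
    for var, c in coeffs.items():
        vmin, vmax = dom[var]
        if c >= 0:
            lo += c * vmin
            hi += c * vmax
        else:
            lo += c * vmax
            hi += c * vmin
    return lo, hi

def provably_true(conjuncts: List[Tuple[Affine, str, int]], dom: Domain) -> bool:
    """True only if every `expr OP rhs` conjunct holds at every point of dom."""
    for expr, op, rhs in conjuncts:
        lo, hi = affine_bounds(expr, dom)
        if op == "<":
            holds = hi < rhs
        elif op == ">=":
            holds = lo >= rhs
        else:
            raise ValueError(f"unsupported comparison {op!r}")
        if not holds:
            return False
    return True

# 0 <= ax0 and ax0 < 32 and 0 <= j + ax1 and j + ax1 < 4
predicate = [
    (({"ax0": 1}, 0), ">=", 0),
    (({"ax0": 1}, 0), "<", 32),
    (({"j": 1, "ax1": 1}, 0), ">=", 0),
    (({"j": 1, "ax1": 1}, 0), "<", 4),
]
# Ranges from `for j in range(4)` and `T.grid(32, 1)`, as inclusive bounds.
domain = {"j": (0, 3), "ax0": (0, 31), "ax1": (0, 0)}

print(provably_true(predicate, domain))  # True, so the predicate can be dropped
```

This check is sound but incomplete. A `True` result means the predicate holds everywhere and is safe to remove. `False` only means that this simple interval reasoning could not prove it. For example, if `ax1` ranged over `[0, 1]`, then `j + ax1 < 4` would fail at `j = 3`, and the predicate would have to stay.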
--
This is an automated message from the Apache Git Service.