In my schedule there are two ops: one computes the result with a GEMM, and the other reshapes it. The lowered function looks like this:
```
for (i.outer.outer, 0, 98) {
  for (j.outer.outer, 0, 16) {
    for (ii, 0, 8) {
      for (jj, 0, 8) {
        gemm_C[((((i.outer.outer*1024) + (j.outer.outer*64)) + (ii*8)) + jj)] = gemm_C.wmma.accumulator[((((i.outer.outer*1024) + (j.outer.outer*64)) + (ii*8)) + jj)]
      }
    }
  }
}
for (n.oh.fused.ow.fused.outer.outer.outer, 0, 98) {
  for (oc.outer.outer.outer, 0, 16) {
    for (n.oh.fused.ow.fused.inner, 0, 8) {
      for (oc.inner, 0, 8) {
        output[((((n.oh.fused.ow.fused.outer.outer.outer*1024) + (n.oh.fused.ow.fused.inner*128)) + (oc.outer.outer.outer*8)) + oc.inner)] = gemm_C[((((n.oh.fused.ow.fused.outer.outer.outer*1024) + (oc.outer.outer.outer*64)) + (n.oh.fused.ow.fused.inner*8)) + oc.inner)]
      }
    }
  }
}
```
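For context, the two ops come from a compute definition that looks roughly like the sketch below. This is only an illustration: the shapes (784 x 128 with a reduction of 64) and the identity index mapping in the second stage are assumptions I read off the loop extents above, and in the real schedule `gemm_C` is produced through a `wmma.accumulator` cache stage.
```
import tvm
from tvm import te

# Assumed shapes, inferred from the loop extents above (98*8 = 784, 16*8 = 128).
A = te.placeholder((784, 64), name="A", dtype="float16")
B = te.placeholder((64, 128), name="B", dtype="float16")
k = te.reduce_axis((0, 64), name="k")

# First op: the GEMM result (tensorized with wmma in the real schedule).
gemm_C = te.compute(
    (784, 128),
    lambda i, j: te.sum(A[i, k].astype("float32") * B[k, j].astype("float32"), axis=k),
    name="gemm_C",
)

# Second op: the reshape stage that writes the final output.
output = te.compute((784, 128), lambda no, oc: gemm_C[no, oc], name="output")

s = te.create_schedule(output.op)
```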
I want these two operations to be in the same kernel, with the `gemm_C` result stored in shared memory. I first bind the output axes to blocks and threads.
```
for (i, 0, 98) {
  for (j, 0, 16) {
    for (ii, 0, 8) {
      for (jj, 0, 8) {
        gemm_C[((((i*1024) + (j*64)) + (ii*8)) + jj)] = gemm_C.wmma.accumulator[((((i*1024) + (j*64)) + (ii*8)) + jj)]
      }
    }
  }
}
// attr [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = 98
// attr [iter_var(blockIdx.y, , blockIdx.y)] thread_extent = 16
// attr [iter_var(threadIdx.y, , threadIdx.y)] thread_extent = 1
// attr [iter_var(threadIdx.z, , threadIdx.z)] thread_extent = 1
for (n.oh.fused.ow.fused.inner, 0, 8) {
  for (oc.inner, 0, 8) {
    output[((((blockIdx.x*1024) + (n.oh.fused.ow.fused.inner*128)) + (blockIdx.y*8)) + oc.inner)] = gemm_C[((((blockIdx.x*1024) + (blockIdx.y*64)) + (n.oh.fused.ow.fused.inner*8)) + oc.inner)]
  }
}
```
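For reference, the binding step is roughly the following sketch; the split factors and the variable names (`fused_outer`, `oc_outer`, ...) are my placeholders, not the exact code:
```
# Split the fused n/oh/ow axis and the oc axis of the output, then bind
# the outer parts to blocks; this gives blockIdx.x = 98 and blockIdx.y = 16.
fused_outer, fused_inner = s[output].split(output.op.axis[0], factor=8)
oc_outer, oc_inner = s[output].split(output.op.axis[1], factor=8)
s[output].reorder(fused_outer, oc_outer, fused_inner, oc_inner)
s[output].bind(fused_outer, te.thread_axis("blockIdx.x"))
s[output].bind(oc_outer, te.thread_axis("blockIdx.y"))
```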
I then try to set the scope of `gemm_C` with `s[gemm_C].set_scope('shared')`, or to attach it with `compute_at()`. Both methods give a result like this:
```
for (i, 0, 98) {
  for (j, 0, (16 - blockIdx.y)) {
    for (ii, 0, 8) {
      for (jj, 0, 8) {
        if (likely(((j + blockIdx.y) < 16))) {
          gemm_C[(((((i*(16 - blockIdx.y))*64) + (j*64)) + (ii*8)) + jj)] = gemm_C.wmma.accumulator[(((((i*(16 - blockIdx.y))*64) + (j*64)) + (ii*8)) + jj)]
        }
      }
    }
  }
}
// attr [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = 98
// attr [iter_var(blockIdx.y, , blockIdx.y)] thread_extent = 16
// attr [iter_var(threadIdx.y, , threadIdx.y)] thread_extent = 1
// attr [iter_var(threadIdx.z, , threadIdx.z)] thread_extent = 1
for (n.oh.fused.ow.fused.inner, 0, 8) {
  for (oc.inner, 0, 8) {
    output[((((blockIdx.x*1024) + (n.oh.fused.ow.fused.inner*128)) + (blockIdx.y*8)) + oc.inner)] = gemm_C[((((blockIdx.x*(16 - blockIdx.y))*64) + (n.oh.fused.ow.fused.inner*8)) + oc.inner)]
  }
}
```
The `j` axis of `gemm_C` is inferred as `(j, 0, (16 - blockIdx.y))`. Because of this odd inference I can't bind this axis to `block_y`.
Am I taking the right approach to achieve my goal? What could cause the `iter_var` to be inferred like this, and how should I solve the problem?
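For reference, the scope / compute_at attempts look roughly like this (again a sketch, reusing the placeholder names from the binding sketch above):
```
# Attempt 1: move gemm_C into shared memory.
s[gemm_C].set_scope("shared")

# Attempt 2: attach the gemm_C copy under the block-bound output loop,
# so that both ops end up in the same kernel.
s[gemm_C].compute_at(s[output], oc_outer)
```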