Recently I tried to implement GE-SpMM in TVM, using hybrid script:
```python
import tvm

def mergespmm(num_rows, num_cols, nnz, indice_type, feat_type, feat_len):
    indptr = tvm.te.placeholder((num_rows + 1,), indice_type, name='indptr')
    indices = tvm.te.placeholder((nnz,), indice_type, name='indices')
    ufeat = tvm.te.placeholder((num_cols, feat_len), feat_type, name='ufeat')
    CF = 1 if feat_len < 64 else 2          # coarsening factor along the feature dim
    row_factor = 4 if feat_len < 64 else 8  # rows handled per thread block

    @tvm.te.hybrid.script
    def _mergespmm(indptr, indices, ufeat):
        out = output_tensor((indptr.shape[0] - 1, ufeat.shape[1]), 'float32')
        sm_k = allocate((32 * row_factor,), 'int32', 'shared')  # staged column indices, one 32-wide slice per warp
        result = allocate((CF,), 'float32', 'local')
        row_start = allocate((1,), 'int32', 'local')
        row_end = allocate((1,), 'int32', 'local')
        for row_outer in bind('blockIdx.x', (indptr.shape[0] + row_factor - 2) // row_factor):
            for feat_outer in bind('blockIdx.y', feat_len // 32 // CF):
                for row_inner in bind('threadIdx.y', row_factor):
                    for elem_inner in bind('threadIdx.x', 32):
                        if row_outer * row_factor + row_inner < indptr.shape[0] - 1:
                            row_start[0] = indptr[row_outer * row_factor + row_inner]
                            row_end[0] = indptr[row_outer * row_factor + row_inner + 1]
                            # zero the accumulator before the reduction
                            for cf in unroll(CF):
                                result[cf] = 0.0
                            for elem_outer in range((row_end[0] - row_start[0] + 31) // 32):
                                if row_start[0] + elem_outer * 32 + elem_inner < row_end[0]:
                                    sm_k[row_inner * 32 + elem_inner] = indices[row_start[0] + elem_outer * 32 + elem_inner]
                                for kk in range(32):
                                    if row_start[0] + elem_outer * 32 + kk < row_end[0]:
                                        for cf in unroll(CF):
                                            result[cf] += ufeat[sm_k[row_inner * 32 + kk], feat_outer * CF * 32 + cf * 32 + elem_inner]
                            for cf in unroll(CF):
                                out[row_outer * row_factor + row_inner, feat_outer * CF * 32 + cf * 32 + elem_inner] = result[cf]
        return out

    out = _mergespmm(indptr, indices, ufeat)
    sched = tvm.te.create_schedule(out.op)
    f = tvm.build(sched, [indptr, indices, ufeat, out], target='cuda')
    print(f.imported_modules[0].get_source())
    return f
```
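For context, the kernel computes a plain CSR sum-aggregation SpMM: `out[r, :]` is the sum of the feature rows `ufeat[j, :]` over the column indices `j` of row `r` (only `indices` is gathered; there are no edge values). A minimal NumPy sketch of these reference semantics (the helper name `spmm_reference` is mine, not from TVM or GE-SpMM):

```python
import numpy as np

def spmm_reference(indptr, indices, ufeat):
    """Reference for the kernel above: row-wise sum of gathered feature rows."""
    num_rows = len(indptr) - 1
    out = np.zeros((num_rows, ufeat.shape[1]), dtype=ufeat.dtype)
    for r in range(num_rows):
        # columns of row r in CSR layout
        for j in indices[indptr[r]:indptr[r + 1]]:
            out[r] += ufeat[j]
    return out
```

This is what the CUDA kernel should match, e.g. `spmm_reference([0, 2, 3], [0, 1, 1], ufeat)` sums rows 0 and 1 of `ufeat` into output row 0 and copies row 1 into output row 1.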
This fails at src/tir/transforms/thread_storage_sync.cc:100 with a message that syncs cannot be inserted inside a condition. The restriction is reasonable in general, because a `__syncthreads()` under divergent control flow can deadlock. In this GE-SpMM kernel, however, warps in a block do not share shared memory: each `threadIdx.y` (one warp, since `blockDim.x == 32`) only touches its own 32-element slice of `sm_k`, so no `__syncthreads()` is needed at all. Is it possible to let programmers control sync insertion? Or do we need another pass that checks whether a sync is actually required?
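The "warps do not share shared memory" claim follows from the index math: warp `row_inner` only ever accesses `sm_k[row_inner * 32 : row_inner * 32 + 32]`, and those slices are pairwise disjoint. A trivial Python check of that disjointness (plain illustration, not TVM code):

```python
row_factor = 4  # threadIdx.y extent; each y-slice is one warp when blockDim.x == 32

# shared-memory index range touched by each warp
ranges = {ry: set(range(ry * 32, ry * 32 + 32)) for ry in range(row_factor)}

# no two warps touch a common sm_k element, so no barrier is required
for a in range(row_factor):
    for b in range(a + 1, row_factor):
        assert ranges[a].isdisjoint(ranges[b])
```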
Besides, I am not sure whether it is related, but I ran into this issue before:
https://discuss.tvm.apache.org/t/tvm-access-beyond-array-boundary/6998
@Huyuwei
---
[Visit
Topic](https://discuss.tvm.apache.org/t/tvm-cuda-generating-unnecessary-sync-operations/7975/1)
to respond.