vinx13 commented on code in PR #16454: URL: https://github.com/apache/tvm/pull/16454#discussion_r1466957964
########## tests/python/tir-transform/test_tir_transform_thread_sync.py: ########## @@ -160,8 +160,49 @@ def expected(A: T.Buffer((4, 4), "float32"), E: T.Buffer((4, 4), "float32")): tvm.ir.assert_structural_equal(mod["main"], expected) +@tvm.testing.requires_cuda +def test_sync_let_stmt(): + @T.prim_func(private=True) + def func(A: T.Buffer((16 * 512), "float32")): + blockIdx_x = T.launch_thread("blockIdx.x", 16) + A_shared = T.allocate([512], "float32", "shared") + in_thread_A_temp = T.allocate([1], "float32", "local") + cross_thread_A_temp = T.allocate([1], "float32", "local") + threadIdx_x = T.launch_thread("threadIdx.x", 128) + A_shared_1 = T.Buffer((512,), data=A_shared, scope="shared") + for ax0 in range(512): + A_shared_1[ax0] = A[blockIdx_x * 512 + ax0] + in_thread_A_temp_1 = T.Buffer((1,), data=in_thread_A_temp, scope="local") + in_thread_A_temp_1[0] = T.float32(0) + with T.LetStmt(in_thread_A_temp_1[0] + A_shared_1[threadIdx_x]) as A_temp: + in_thread_A_temp_1[0] = A_temp + with T.LetStmt(in_thread_A_temp_1[0] + A_shared_1[threadIdx_x + 128]) as A_temp: + in_thread_A_temp_1[0] = A_temp + with T.LetStmt(in_thread_A_temp_1[0] + A_shared_1[threadIdx_x + 256]) as A_temp: + in_thread_A_temp_1[0] = A_temp + with T.LetStmt(in_thread_A_temp_1[0] + A_shared_1[threadIdx_x + 384]) as A_temp: + in_thread_A_temp_1[0] = A_temp + cross_thread_A_temp_1 = T.Buffer((1,), data=cross_thread_A_temp, scope="local") + with T.attr( + T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]), + "reduce_scope", + T.reinterpret("handle", T.uint64(0)), + ): + T.tvm_thread_allreduce( + T.uint32(1), + in_thread_A_temp_1[0], + T.bool(True), + cross_thread_A_temp_1[0], + threadIdx_x, + ) + + mod = run_passes(func) + assert "T.tvm_storage_sync" in str(mod) Review Comment: It's better to use `tvm.ir.assert_structural_equal` if possible, since a substring check like this doesn't make it obvious where `T.tvm_storage_sync` should be inserted. ########## src/tir/transforms/storage_access.cc: ########## @@ -94,6 +98,20 @@ void
StorageAccessVisitor::VisitStmt_(const EvaluateNode* op) { allow_append_ = false; } +void StorageAccessVisitor::VisitStmt_(const LetStmtNode* op) { + allow_append_ = true; + curr_stmt_.access.clear(); Review Comment: We shouldn't call `clear()` here; it should only be called after a statement has been fully handled. ########## src/tir/transforms/storage_access.cc: ########## @@ -51,6 +51,10 @@ void StorageAccessVisitor::VisitExpr_(const BufferLoadNode* op) { } // traverse child StmtExprVisitor::VisitExpr_(op); + // push to the scope + scope_.back().push_back(curr_stmt_); + // clear access entry. + curr_stmt_.access.clear(); Review Comment: `BufferLoad` is an expression, not a statement, so we shouldn't push and clear here. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: commits-unsubscr...@tvm.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org