Re: [PR] [BugTIR] fix thread_sync occurs in letstmt [tvm]
vinx13 merged PR #16454: URL: https://github.com/apache/tvm/pull/16454 -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: commits-unsubscr...@tvm.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org
Re: [PR] [BugTIR] fix thread_sync occurs in letstmt [tvm]
JackWeiw commented on code in PR #16454:
URL: https://github.com/apache/tvm/pull/16454#discussion_r1468355794

## src/tir/transforms/storage_access.cc: ##
@@ -94,6 +94,21 @@ void StorageAccessVisitor::VisitStmt_(const EvaluateNode* op) {
   allow_append_ = false;
 }

+void StorageAccessVisitor::VisitStmt_(const LetStmtNode* op) {
+  allow_append_ = true;
+  curr_stmt_.access.clear();

Review Comment:
Thank you for your review. The unnecessary `clear()` call has been removed.
Re: [PR] [BugTIR] fix thread_sync occurs in letstmt [tvm]
vinx13 commented on code in PR #16454:
URL: https://github.com/apache/tvm/pull/16454#discussion_r1468269410

## src/tir/transforms/storage_access.cc: ##
@@ -94,6 +94,21 @@ void StorageAccessVisitor::VisitStmt_(const EvaluateNode* op) {
   allow_append_ = false;
 }

+void StorageAccessVisitor::VisitStmt_(const LetStmtNode* op) {
+  allow_append_ = true;
+  curr_stmt_.access.clear();

Review Comment:
```suggestion
```
This should work. You can remove this line, since the assertion `ICHECK_EQ(curr_stmt_.access.size(), 0U);` holds.
Re: [PR] [BugTIR] fix thread_sync occurs in letstmt [tvm]
JackWeiw commented on code in PR #16454:
URL: https://github.com/apache/tvm/pull/16454#discussion_r1467227960

## src/tir/transforms/storage_access.cc: ##
@@ -94,6 +98,20 @@ void StorageAccessVisitor::VisitStmt_(const EvaluateNode* op) {
   allow_append_ = false;
 }

+void StorageAccessVisitor::VisitStmt_(const LetStmtNode* op) {
+  allow_append_ = true;
+  curr_stmt_.access.clear();

Review Comment:
Can you check my new solution?
Re: [PR] [BugTIR] fix thread_sync occurs in letstmt [tvm]
JackWeiw commented on code in PR #16454:
URL: https://github.com/apache/tvm/pull/16454#discussion_r1467213250

## tests/python/tir-transform/test_tir_transform_thread_sync.py: ##
@@ -160,8 +160,49 @@ def expected(A: T.Buffer((4, 4), "float32"), E: T.Buffer((4, 4), "float32")):
     tvm.ir.assert_structural_equal(mod["main"], expected)

+@tvm.testing.requires_cuda
+def test_sync_let_stmt():
+    @T.prim_func(private=True)
+    def func(A: T.Buffer((16 * 512), "float32")):
+        blockIdx_x = T.launch_thread("blockIdx.x", 16)
+        A_shared = T.allocate([512], "float32", "shared")
+        in_thread_A_temp = T.allocate([1], "float32", "local")
+        cross_thread_A_temp = T.allocate([1], "float32", "local")
+        threadIdx_x = T.launch_thread("threadIdx.x", 128)
+        A_shared_1 = T.Buffer((512,), data=A_shared, scope="shared")
+        for ax0 in range(512):
+            A_shared_1[ax0] = A[blockIdx_x * 512 + ax0]
+        in_thread_A_temp_1 = T.Buffer((1,), data=in_thread_A_temp, scope="local")
+        in_thread_A_temp_1[0] = T.float32(0)
+        with T.LetStmt(in_thread_A_temp_1[0] + A_shared_1[threadIdx_x]) as A_temp:

Review Comment:
In this case, before the change, `StorageAccessVisitor` calls `VisitExpr_(const BufferLoadNode* op)` directly, rather than reaching it from a parent statement visitor such as `VisitStmt_(const BufferStoreNode* op)` traversing its children. As a result, `ICHECK(allow_append_) << op << " " << scope.to_string();` fails.

If we add `VisitStmt_(const LetStmtNode* op)` and traverse the children of the `LetStmtNode`, then in this case it first visits the `BufferLoadNode` for `A_shared`, which increases `curr_stmt_.access` by 1, and then visits the `BufferStoreNode` for `in_thread_A_temp`, where `ICHECK_EQ(curr_stmt_.access.size(), 0U);` fails.

Do you have any insights on how to solve this problem? @vinx13
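The two check failures described in this comment can be reproduced in a toy model. The sketch below is not TVM code: `ToyAccessVisitor`, `CheckFailed`, `let_mode`, `run`, and the `Load`/`Store`/`Let` classes are hypothetical stand-ins, and the `"flush"` mode is only an assumption about what a working `LetStmtNode` handler needs to do (record the bound value as its own entry before visiting the body).

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple

Access = Tuple[str, str]  # (kind, buffer name)


class CheckFailed(Exception):
    """Stands in for a failing ICHECK in this toy model."""


@dataclass
class Load:  # expression, models BufferLoadNode
    buffer: str


@dataclass
class Store:  # statement, models BufferStoreNode
    buffer: str
    value: Optional[Load] = None


@dataclass
class Let:  # statement, models LetStmtNode
    value: Load
    body: Store


class ToyAccessVisitor:
    """Hypothetical, heavily simplified model of the visitor's bookkeeping."""

    def __init__(self, let_mode: str) -> None:
        self.let_mode = let_mode               # "default", "naive" or "flush"
        self.allow_append = False              # models allow_append_
        self.access: List[Access] = []         # models curr_stmt_.access
        self.entries: List[List[Access]] = []  # models scope_.back()

    def _flush(self) -> None:
        # A statement is finished: push its accesses and reset the state.
        self.entries.append(list(self.access))
        self.access.clear()
        self.allow_append = False

    def visit_load(self, op: Load) -> None:
        # Models ICHECK(allow_append_).
        if not self.allow_append:
            raise CheckFailed("allow_append is false")
        self.access.append(("read", op.buffer))

    def visit_store(self, op: Store) -> None:
        self.allow_append = True
        # Models ICHECK_EQ(curr_stmt_.access.size(), 0U).
        if self.access:
            raise CheckFailed("access list is not empty")
        self.access.append(("write", op.buffer))
        if op.value is not None:
            self.visit_load(op.value)
        self._flush()

    def visit_let(self, op: Let) -> None:
        if self.let_mode == "flush":
            # Assumed fix: treat the bound value as its own entry.
            self.allow_append = True
            if self.access:
                raise CheckFailed("access list is not empty")
            self.visit_load(op.value)
            self._flush()
        elif self.let_mode == "naive":
            # Allows the append but leaves the recorded read behind.
            self.allow_append = True
            self.visit_load(op.value)
        else:
            # "default": no dedicated handler, the value is a bare expression.
            self.visit_load(op.value)
        self.visit_store(op.body)


def run(let_mode: str) -> str:
    # Models: with T.LetStmt(... A_shared[...]) as A_temp: in_thread_A_temp[0] = A_temp
    stmt = Let(Load("A_shared"), Store("in_thread_A_temp"))
    visitor = ToyAccessVisitor(let_mode)
    try:
        visitor.visit_let(stmt)
    except CheckFailed as err:
        return f"failed: {err}"
    return f"ok: {len(visitor.entries)} entries"
```

In this model, `run("default")` trips the first check, `run("naive")` trips the second, and `run("flush")` completes with one entry for the let value and one for the store in its body.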
Re: [PR] [BugTIR] fix thread_sync occurs in letstmt [tvm]
vinx13 commented on code in PR #16454:
URL: https://github.com/apache/tvm/pull/16454#discussion_r1466957964

## tests/python/tir-transform/test_tir_transform_thread_sync.py: ##
@@ -160,8 +160,49 @@ def expected(A: T.Buffer((4, 4), "float32"), E: T.Buffer((4, 4), "float32")):
     tvm.ir.assert_structural_equal(mod["main"], expected)

+@tvm.testing.requires_cuda
+def test_sync_let_stmt():
+    @T.prim_func(private=True)
+    def func(A: T.Buffer((16 * 512), "float32")):
+        blockIdx_x = T.launch_thread("blockIdx.x", 16)
+        A_shared = T.allocate([512], "float32", "shared")
+        in_thread_A_temp = T.allocate([1], "float32", "local")
+        cross_thread_A_temp = T.allocate([1], "float32", "local")
+        threadIdx_x = T.launch_thread("threadIdx.x", 128)
+        A_shared_1 = T.Buffer((512,), data=A_shared, scope="shared")
+        for ax0 in range(512):
+            A_shared_1[ax0] = A[blockIdx_x * 512 + ax0]
+        in_thread_A_temp_1 = T.Buffer((1,), data=in_thread_A_temp, scope="local")
+        in_thread_A_temp_1[0] = T.float32(0)
+        with T.LetStmt(in_thread_A_temp_1[0] + A_shared_1[threadIdx_x]) as A_temp:
+            in_thread_A_temp_1[0] = A_temp
+        with T.LetStmt(in_thread_A_temp_1[0] + A_shared_1[threadIdx_x + 128]) as A_temp:
+            in_thread_A_temp_1[0] = A_temp
+        with T.LetStmt(in_thread_A_temp_1[0] + A_shared_1[threadIdx_x + 256]) as A_temp:
+            in_thread_A_temp_1[0] = A_temp
+        with T.LetStmt(in_thread_A_temp_1[0] + A_shared_1[threadIdx_x + 384]) as A_temp:
+            in_thread_A_temp_1[0] = A_temp
+        cross_thread_A_temp_1 = T.Buffer((1,), data=cross_thread_A_temp, scope="local")
+        with T.attr(
+            T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]),
+            "reduce_scope",
+            T.reinterpret("handle", T.uint64(0)),
+        ):
+            T.tvm_thread_allreduce(
+                T.uint32(1),
+                in_thread_A_temp_1[0],
+                T.bool(True),
+                cross_thread_A_temp_1[0],
+                threadIdx_x,
+            )
+
+    mod = run_passes(func)
+    assert "T.tvm_storage_sync" in str(mod)

Review Comment:
It is better to use `tvm.ir.assert_structural_equal` if possible, since it is not obvious where `T.tvm_storage_sync` should be inserted.

## src/tir/transforms/storage_access.cc: ##
@@ -94,6 +98,20 @@ void StorageAccessVisitor::VisitStmt_(const EvaluateNode* op) {
   allow_append_ = false;
 }

+void StorageAccessVisitor::VisitStmt_(const LetStmtNode* op) {
+  allow_append_ = true;
+  curr_stmt_.access.clear();

Review Comment:
`clear()` should not be called here; it can only be used after a statement has been fully handled.

## src/tir/transforms/storage_access.cc: ##
@@ -51,6 +51,10 @@ void StorageAccessVisitor::VisitExpr_(const BufferLoadNode* op) {
   }
   // traverse child
   StmtExprVisitor::VisitExpr_(op);
+  // push to the scope
+  scope_.back().push_back(curr_stmt_);
+  // clear access entry.
+  curr_stmt_.access.clear();

Review Comment:
`BufferLoad` is an expression, not a statement, so we should not push and clear here.
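One way to see why an expression visitor should not push and clear is to count the entries a single statement produces. The sketch below is a toy model rather than TVM code: `ToyRecorder`, `flush_in_load`, and `entries_for` are hypothetical names, with `access` standing in for `curr_stmt_.access` and `entries` for `scope_.back()`. It assumes one entry is meant to describe one whole statement.

```python
from typing import List, Tuple

Access = Tuple[str, str]  # (kind, buffer name)


class ToyRecorder:
    """Hypothetical model of per-statement access bookkeeping."""

    def __init__(self, flush_in_load: bool) -> None:
        self.flush_in_load = flush_in_load
        self.access: List[Access] = []         # models curr_stmt_.access
        self.entries: List[List[Access]] = []  # models scope_.back()

    def _flush(self) -> None:
        # Push the accumulated accesses as one entry, then clear them.
        self.entries.append(list(self.access))
        self.access.clear()

    def visit_load(self, buffer: str) -> None:
        self.access.append(("read", buffer))
        if self.flush_in_load:
            # The pattern the review objects to: an expression closes an entry.
            self._flush()

    def visit_store(self, buffer: str, loads: List[str]) -> None:
        self.access.append(("write", buffer))
        for name in loads:
            self.visit_load(name)
        # The statement is finished only here, so this is where it is pushed.
        self._flush()


def entries_for(flush_in_load: bool) -> List[List[Access]]:
    recorder = ToyRecorder(flush_in_load)
    # Models a single statement: B[i] = A[i] + A[i + 1]
    recorder.visit_store("B", ["A", "A"])
    return recorder.entries


if __name__ == "__main__":
    print(entries_for(False))
    print(entries_for(True))
```

With `flush_in_load=False` the store yields a single entry holding its write and both reads. With `flush_in_load=True` the same statement is split across several entries, one of them empty, so an entry no longer corresponds to a statement.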
Re: [PR] [BugTIR] fix thread_sync occurs in letstmt [tvm]
JackWeiw commented on PR #16454:
URL: https://github.com/apache/tvm/pull/16454#issuecomment-1909980005

cc @vinx13 @spectrometerHBH: please take a moment to review this if you see it.
Re: [PR] [BugTIR] fix thread_sync occurs in letstmt [tvm]
Hzfengsy commented on PR #16454:
URL: https://github.com/apache/tvm/pull/16454#issuecomment-1905259889

cc @vinx13 @spectrometerHBH
Re: [PR] [BugTIR] fix thread_sync occurs in letstmt [tvm]
JackWeiw closed pull request #16447: [BugTIR] fix thread_sync occurs in letstmt
URL: https://github.com/apache/tvm/pull/16447