majiang31312 edited a comment on issue #5686: URL: https://github.com/apache/incubator-tvm/issues/5686#issuecomment-638953820
The fix seems quite simple, but I'm not sure whether it's complete. Please take a look at the Discussion section. Thanks! @tqchen @wpan11nv

Problem: when `num_thread = 1` (which is the case for Vulkan, because `CreateTarget` in `target.cc` sets `thread_warp_size` to 1), the schedule

```python
ko, ki = s[B].split(B.op.reduce_axis[0], factor=num_thread)
s[B].bind(ki, te.thread_axis("threadIdx.x"))
```

triggers `TVMError: Check failed: v:` in `MakeAllreduce`. With `factor=1`, the Simplify pass replaces the IterVar with a constant node, but `MakeAllreduce` expects a var node.

Reproduce:

```python
import tvm
from tvm import te

n, m = 32, 32
num_thread = 1  # thread_warp_size is 1 for Vulkan
A = te.placeholder((n, m), name="A", dtype="int8")
k = te.reduce_axis((0, m), "k")
B = te.compute((n,), lambda i: te.sum(A[i, k], axis=[k]), name="B")
s = te.create_schedule(B.op)
ko, ki = s[B].split(B.op.reduce_axis[0], factor=num_thread)
s[B].bind(ki, te.thread_axis("threadIdx.z"))
# target = tvm.target.create("vulkan")
target = tvm.target.create("cuda")
mod = tvm.lower(s, [A, B])
mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))(mod)
# With factor=1, Simplify folds the extent-1 thread IterVar ki into the constant 0.
mod = tvm.tir.transform.Simplify()(mod)
print(mod)
# Fails with "Check failed: v:" in MakeAllreduce before the fix.
mod = tvm.tir.transform.LowerThreadAllreduce()(mod)
```

Fix:

```diff
--- a/src/tir/transforms/lower_thread_allreduce.cc
+++ b/src/tir/transforms/lower_thread_allreduce.cc
@@ -154,9 +154,17 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
     std::unordered_set<const VarNode*> reduce_set;
     for (size_t i = 2 + 2 * size; i < call->args.size(); ++i) {
       const VarNode* v = call->args[i].as<VarNode>();
-      CHECK(v);
-      reduce_set.insert(v);
+      // The Simplify pass replaces an iteration variable with a constant
+      // when the extent of the iteration is 1. As a threaded IterVar always
+      // starts from 0, we can just ignore this variable in this case.
+      if (v) {
+        reduce_set.insert(v);
+      } else {
+        CHECK(call->args[i].as<IntImmNode>() && call->args[i].as<IntImmNode>()->value == 0)
+            << "arg " << i << " should be a VarNode or IntImmNode";
+      }
     }
+
     size_t nmatch = 0;
     std::vector<ThreadEntry> vred, vpar;
     for (const AttrStmtNode* attr : thread_extents_) {
@@ -170,6 +178,11 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
       const auto* ptr = attr->value.as<IntImmNode>();
       CHECK(ptr) << "Need constant extent for reduce set " << iv;
       e.extent = static_cast<int>(ptr->value);
+      // Ignore thread axes with extent 1: their only value is 0.
+      if (e.extent == 1) {
+        continue;
+      }
+
       if (reduce_set.count(iv->var.get())) {
         vred.push_back(e);
         ++nmatch;
```

Discussion: at the moment, a threaded IterVar always starts from 0, so we can safely ignore the constant node. Maybe we could keep a record somewhere after we replace a VarNode with an IntImmNode? I think that would help us handle cases like this more cleanly.

By the way, `analyzer_.Simplify` in `BufIndex` does not work as expected. It looks like the analyzer has not been initialized properly. I can provide test cases if someone wants to take a look.
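The two changes in the patch can be illustrated with a toy Python model that does not depend on TVM. This is a sketch under stated assumptions, not the actual C++ implementation:

- `Var`, `IntImm`, `build_reduce_set_strict`, `build_reduce_set`, and `split_thread_entries` are hypothetical names for illustration, not TVM APIs.
- Thread entries are simplified to `(name, extent)` pairs instead of `ThreadEntry`.
- Argument indices count from the first reduction argument rather than from `2 + 2 * size`.

```python
from dataclasses import dataclass
from typing import List, Set, Tuple, Union


@dataclass(frozen=True)
class Var:
    """Hypothetical stand-in for a TIR VarNode (not a TVM API)."""
    name: str


@dataclass(frozen=True)
class IntImm:
    """Hypothetical stand-in for a TIR IntImmNode (not a TVM API)."""
    value: int


Arg = Union[Var, IntImm]
Entry = Tuple[str, int]  # simplified (thread axis name, extent) pair


def build_reduce_set_strict(reduce_args: List[Arg]) -> Set[str]:
    """Pre-patch behavior: every reduction argument must be a Var, like CHECK(v)."""
    reduce_set = set()
    for i, arg in enumerate(reduce_args):
        if not isinstance(arg, Var):
            raise ValueError(f"Check failed: v (arg {i} is {arg!r})")
        reduce_set.add(arg.name)
    return reduce_set


def build_reduce_set(reduce_args: List[Arg]) -> Set[str]:
    """Patched behavior: also accept the constant 0 that Simplify leaves behind."""
    reduce_set = set()
    for i, arg in enumerate(reduce_args):
        if isinstance(arg, Var):
            reduce_set.add(arg.name)
        elif isinstance(arg, IntImm) and arg.value == 0:
            # An extent-1 thread axis was folded to its only value, 0: ignore it.
            continue
        else:
            raise ValueError(f"arg {i} should be a VarNode or IntImmNode, got {arg!r}")
    return reduce_set


def split_thread_entries(
    thread_extents: List[Entry], reduce_set: Set[str]
) -> Tuple[List[Entry], List[Entry]]:
    """Partition bound thread axes into reduction (vred) and parallel (vpar) axes."""
    vred: List[Entry] = []
    vpar: List[Entry] = []
    for name, extent in thread_extents:
        if extent == 1:
            # Mirrors the patch's `continue`: an extent-1 axis joins neither list.
            continue
        if name in reduce_set:
            vred.append((name, extent))
        else:
            vpar.append((name, extent))
    return vred, vpar


if __name__ == "__main__":
    # Toy version of the repro: factor=1 makes Simplify fold ki to IntImm(0).
    args = [IntImm(0)]
    extents = [("threadIdx.z", 1)]
    try:
        build_reduce_set_strict(args)
    except ValueError as err:
        print("before the patch:", err)
    reduce_set = build_reduce_set(args)
    print("after the patch:", reduce_set, split_thread_entries(extents, reduce_set))
```

In this toy version of the repro, the strict variant fails on `IntImm(0)` just as `CHECK(v)` does. The patched variant returns an empty reduce set and skips the extent-1 `threadIdx.z` axis, so neither `vred` nor `vpar` receives it.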