majiang31312 edited a comment on issue #5686:
URL: https://github.com/apache/incubator-tvm/issues/5686#issuecomment-638953820


   The fix seems quite simple, but I'm not sure whether it's complete. 
   Please take a look at the Discussion section. Thanks! @tqchen @wpan11nv 
   
   Problem:
   When num_thread = 1 (which is the case for Vulkan, because CreateTarget in
   target.cc sets thread_warp_size to 1), the schedule
   ```
   ko, ki = s[B].split(B.op.reduce_axis[0], factor=num_thread)
   s[B].bind(ki, te.thread_axis("threadIdx.x"))
   ```
   will trigger "TVMError: Check failed: v:" in MakeAllreduce.
   With factor=1, the Simplify pass replaces the IterVar with a constant node,
   but MakeAllreduce expects a VarNode.
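To see why substituting a constant is sound here, the following is a toy pure-Python model, not TVM code (`split_extents` and `reduce_via_split` are hypothetical helpers). It splits a reduction axis of extent m by `factor`, with the outer extent modeled as ceil(m / factor) and the thread-bound inner index ranging over `range(factor)` starting at 0. With factor=1 the inner index can only ever be 0.

```python
def split_extents(m, factor):
    """Return (outer, inner) extents for splitting an axis of extent m by factor."""
    outer = (m + factor - 1) // factor  # ceil division
    return outer, factor


def reduce_via_split(row, factor):
    """Sum `row` by iterating over (ko, ki), guarding k when factor does not divide m."""
    m = len(row)
    outer, inner = split_extents(m, factor)
    total = 0
    for ko in range(outer):
        for ki in range(inner):  # a thread-bound ki starts at 0
            k = ko * factor + ki
            if k < m:
                total += row[k]
    return total


row = list(range(32))
# With factor=1 the inner axis has extent 1, so ki only ever takes the value 0.
print(split_extents(32, 1))  # (32, 1)
print(reduce_via_split(row, 1) == sum(row))  # True
```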
   
   
   Reproduce:
   ```
   import tvm
   from tvm import te
   
   n, m = 32, 32
   # On Vulkan, CreateTarget sets thread_warp_size to 1, so num_thread ends up as 1.
   num_thread = 1
   A = te.placeholder((n, m), name="A", dtype="int8")
   k = te.reduce_axis((0, m), "k")
   B = te.compute((n,), lambda i: te.sum(A[i, k], axis=[k]), name="B")
   
   s = te.create_schedule(B.op)
   # With factor=1 the inner reduction axis ki has extent 1.
   ko, ki = s[B].split(B.op.reduce_axis[0], factor=num_thread)
   s[B].bind(ki, te.thread_axis("threadIdx.z"))
   
   # target = tvm.target.create("vulkan")
   target = tvm.target.create("cuda")
   mod = tvm.lower(s, [A, B])
   mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))(mod)
   # Simplify replaces the extent-1 thread IterVar with a constant.
   mod = tvm.tir.transform.Simplify()(mod)
   print(mod)
   # Without the fix, this fails with "TVMError: Check failed: v:".
   mod = tvm.tir.transform.LowerThreadAllreduce()(mod)
   ```
   
   Fix:
   ```
   --- a/src/tir/transforms/lower_thread_allreduce.cc
   +++ b/src/tir/transforms/lower_thread_allreduce.cc
   @@ -154,9 +154,17 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
        std::unordered_set<const VarNode*> reduce_set;
        for (size_t i = 2 + 2 * size; i < call->args.size(); ++i) {
          const VarNode* v = call->args[i].as<VarNode>();
   -      CHECK(v);
   -      reduce_set.insert(v);
   +      // The Simplify pass replaces an iteration variable with a constant
   +      // when the extent of the iteration is 1. As a threaded IterVar always starts from 0,
   +      // we can just ignore this variable in this case.
   +      if (v) {
   +        reduce_set.insert(v);
   +      } else {
   +        CHECK(call->args[i].as<IntImmNode>() && call->args[i].as<IntImmNode>()->value == 0)
   +          << "arg " << i << " should be a VarNode or IntImmNode";
   +      }
        }
   +      
        size_t nmatch = 0;
        std::vector<ThreadEntry> vred, vpar;
        for (const AttrStmtNode* attr : thread_extents_) {
   @@ -170,6 +178,11 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
            const auto* ptr = attr->value.as<IntImmNode>();
            CHECK(ptr) << "Need constant extent for reduce set " << iv;
            e.extent = static_cast<int>(ptr->value);
   +        // Skip thread axes of extent 1: their index is always 0.
   +        if (e.extent == 1) {
   +          continue;
   +        }
   +
            if (reduce_set.count(iv->var.get())) {
              vred.push_back(e);
              ++nmatch;
   ```
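The two hunks can be modeled in plain Python to check the intended behavior on a few inputs. This is a hypothetical sketch, not TVM code: `build_reduce_set` and `classify_threads` are made-up names, a `str` stands for a `VarNode`, an `int` stands for an `IntImmNode`, and `reduce_args` holds only the trailing reduce-index arguments (those from index `2 + 2 * size` on).

```python
from typing import List, Sequence, Set, Tuple, Union

# Toy stand-ins: a str models a VarNode (by name), an int models an IntImmNode.
Arg = Union[str, int]


def build_reduce_set(reduce_args: Sequence[Arg]) -> Set[str]:
    """Mirror of the patched first loop: keep vars, accept only the constant 0."""
    reduce_set: Set[str] = set()
    for i, arg in enumerate(reduce_args):
        if isinstance(arg, str):
            reduce_set.add(arg)
        elif arg != 0:
            raise ValueError(f"arg {i} should be a VarNode or IntImmNode equal to 0")
    return reduce_set


def classify_threads(
    thread_extents: Sequence[Tuple[str, int]], reduce_set: Set[str]
) -> Tuple[List[str], List[str], int]:
    """Mirror of the patched second loop: skip extent-1 threads, split the rest."""
    vred: List[str] = []
    vpar: List[str] = []
    nmatch = 0
    for var, extent in thread_extents:
        if extent == 1:
            continue  # the index of an extent-1 thread is always 0
        if var in reduce_set:
            vred.append(var)
            nmatch += 1
        else:
            vpar.append(var)
    return vred, vpar, nmatch


# The reproducer: Simplify has already turned threadIdx.z into the constant 0.
rs = build_reduce_set([0])
print(rs, classify_threads([("threadIdx.z", 1)], rs))  # set() ([], [], 0)
```

Note that in this model an extent-1 thread is skipped even when its variable was not folded into a constant, so it is then not counted in `nmatch`.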
   
   Discussion:
     At the moment, a threaded IterVar always starts from 0, so we can safely
     ignore the constant node.
     Maybe we could keep a record somewhere when we replace a VarNode with an
     IntImmNode? I think that would help handle such cases more cleanly.
     By the way, 'analyzer_.Simplify' in BufIndex does not work as expected.
     It looks like the analyzer has not been initialized properly. I can
     provide test cases if someone wants to take a look.
   



