Lunderberg commented on code in PR #16591:
URL: https://github.com/apache/tvm/pull/16591#discussion_r1500762292


##########
src/relax/transform/combine_parallel_matmul.cc:
##########
@@ -140,40 +149,68 @@ runtime::TypedPackedFunc<Map<Var, Expr>(Map<DFPattern, Var>, Map<Var, Expr>)> Ge
     for (const auto& [rhs_dim, indices] : GroupShapes(rhs_shapes)) {
       if (indices.size() == 1 || !batch_dims_compatible(rhs_dim, indices, rhs_shapes)) continue;
 
-      auto inp = matchings[patterns.input];
+      auto lhs = matchings[patterns.input];
+
+      const auto& patterns_to_replace = [&patterns, &branch_info]() {
+        if (branch_info.activation) return patterns.activation;
+        if (branch_info.bias_dim) return patterns.bias_add;
+        return patterns.matmul;
+      }();
 
-      Array<Var> rhs, bias;
-      for (auto ind : indices) {
-        rhs.push_back(matchings[patterns.rhs[ind]]);
-        if (branch_info.bias_dim) {
-          ICHECK(matchings.count(patterns.bias[ind]));
-          bias.push_back(matchings[patterns.bias[ind]]);
+      std::vector<SplitInfo> splits;
+      for (auto index : indices) {
+        Var rhs = matchings[patterns.rhs[index]];
+        Optional<Var> bias = NullOpt;
+        if (branch_info.bias_dim.has_value()) {
+          bias = matchings[patterns.bias[index]];
         }
+        PrimExpr split_size = GetTensorSInfo(rhs)->GetShape().value()[rhs_dim - 1];
+        DFPattern pattern_to_replace = patterns_to_replace[index];
+        splits.push_back(SplitInfo{rhs, bias, split_size, pattern_to_replace});
+      }
+      // At most one dynamic output shape can be part of the combined
+      // matmul, and it must be the last item in the split.  Use
+      // `std::stable_sort` instead of `std::sort` to maintain a
+      // consistent order for all static shapes, and to consistently
+      // select the same dynamic weight to participate.
+      auto is_dynamic_split = [](const SplitInfo& split) -> bool {
+        return !split.split_size->IsInstance<IntImmNode>();
+      };
+      std::stable_sort(splits.begin(), splits.end(),
+                       [&is_dynamic_split](const auto& a, const auto& b) {
+                         return is_dynamic_split(a) < is_dynamic_split(b);

Review Comment:
   Thank you.  I went back and forth on whether this was reasonably clever, or 
too clever, and I think I like it.
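
For readers following along, the idiom under discussion is a stable sort keyed on a boolean: since `false < true`, every static split sorts ahead of every dynamic one, and stability keeps the original order within each group. Below is a minimal standalone sketch of that idea. The `Split` struct and `MoveDynamicSplitsLast` are hypothetical stand-ins for illustration, not the PR's `SplitInfo` or any TVM API, and a missing size is assumed to mean "dynamic".

```cpp
#include <algorithm>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in for the PR's SplitInfo.
// An empty `static_size` marks a split whose extent is dynamic.
struct Split {
  std::string name;
  std::optional<int> static_size;
};

// Move dynamic splits to the end while preserving the relative
// order of the static splits and of the dynamic splits.
inline void MoveDynamicSplitsLast(std::vector<Split>& splits) {
  auto is_dynamic = [](const Split& s) { return !s.static_size.has_value(); };
  // Comparing the two bools gives a strict weak ordering:
  // "a before b" holds only when a is static and b is dynamic.
  std::stable_sort(splits.begin(), splits.end(),
                   [&is_dynamic](const Split& a, const Split& b) {
                     return is_dynamic(a) < is_dynamic(b);
                   });
}
```

With a two-valued key like this, `std::stable_partition` with a "is static" predicate would produce the same ordering, so the choice between the two is mostly a matter of which reads more clearly.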


