This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 92593cd11f Add pattern-based dataflow block rewriting
92593cd11f is described below

commit 92593cd11f74b0e8e872926907d634978148dda5
Author: Masahiro Masuda <masahi...@gmail.com>
AuthorDate: Sat Apr 1 04:13:15 2023 +0900

    Add pattern-based dataflow block rewriting
---
 python/tvm/relax/dpl/pattern.py             |  34 ++++++-
 src/relax/ir/dataflow_matcher.cc            | 113 ++++++++++++++++++--
 tests/python/relax/test_dataflow_pattern.py | 153 +++++++++++++++++++++++++++-
 3 files changed, 285 insertions(+), 15 deletions(-)
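At a glance, this change renames the call-level entry point `rewrite` to `rewrite_call` and adds a block-level `rewrite_bindings`. A minimal sketch of the two call shapes, condensed from the docstrings and tests in the diff below (all names come from this commit):

```
# Call-level: rewriter(orig_call, matchings) returns a new (or the original) call
rewritten_func = rewrite_call(pattern, rewriter, func)

# Block-level: rewriter(matchings) returns {matched bound var: new value}
rewritten_func = rewrite_bindings(ctx, rewriter, func)
```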

diff --git a/python/tvm/relax/dpl/pattern.py b/python/tvm/relax/dpl/pattern.py
index acabac2dcb..3026213ba2 100644
--- a/python/tvm/relax/dpl/pattern.py
+++ b/python/tvm/relax/dpl/pattern.py
@@ -1125,7 +1125,7 @@ def make_fused_bias_activation_pattern(op_name, with_bias=False, activation=None
     return out
 
 
-def rewrite(
+def rewrite_call(
    pattern: DFPattern, rewriter: Callable[[Expr, Dict[DFPattern, Expr]], Expr], func: Function
 ) -> Function:
     """
@@ -1158,4 +1158,34 @@ def rewrite(
     rewritten_func: Function
        The rewritten or the input function, depending on the pattern matching result.
     """
-    return ffi.rewrite(pattern, rewriter, func)
+    return ffi.rewrite_call(pattern, rewriter, func)
+
+
+def rewrite_bindings(
+    ctx: "PatternContext", rewriter: "Callable[[Dict[DFPattern, Var]], Dict[Var, Expr]]", func: Function
+) -> Function:
+    """
+    Rewrite bindings inside dataflow blocks of a function, using the given pattern
+    context and the rewriter function.
+    Parameters
+    ----------
+    ctx: PatternContext
+        The pattern constraint context under which rewriting takes place.
+    rewriter: Callable[[Dict[DFPattern, Var]], Dict[Var, Expr]]
+        The function to be called on a successful matching for rewriting. Given the map of
+        patterns and the corresponding matched variables (bound variables or parameters), it
+        should return a map that specifies new values for the matched bound variables. For
+        example, to replace x + x with 2 * x for a bound variable, we can write the rewriter
+        as follows:
+        ```
+        with PatternContext() as ctx:
+            x = wildcard()
+            pattern = is_op("relax.add")(x, x)
+        def rewriter(matchings):
+            return {matchings[pattern]: R.multiply(matchings[x], R.const(2, "float32"))}
+        ```
+    func: Function
+        The function to rewrite.
+    Returns
+    -------
+    rewritten_func: Function
+        The rewritten or the input function, depending on the pattern matching result.
+    """
+    return ffi.rewrite_bindings(ctx, rewriter, func)
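A condensed, self-contained usage sketch of the new `rewrite_bindings` entry point, adapted from `test_combine_matmul_twice` later in this diff; it assumes these helpers are exported from `tvm.relax.dpl` the same way the test module uses them:

```
from tvm.relax.dpl import PatternContext, is_op, wildcard, rewrite_bindings
from tvm.script import relax as R

with PatternContext() as ctx:
    inp_pat = wildcard()                     # shared matmul input
    w1_pat, w2_pat = wildcard(), wildcard()  # two weight matrices
    matmul1 = is_op("relax.matmul")(inp_pat, w1_pat)
    matmul2 = is_op("relax.matmul")(inp_pat, w2_pat)

def rewriter(matchings):
    # matchings maps each pattern to the bound variable it matched.
    inp = matchings[inp_pat]
    w1, w2 = matchings[w1_pat], matchings[w2_pat]
    width = w1.struct_info.shape[1]
    fused = R.matmul(inp, R.concat([w1, w2], axis=1))
    s1 = R.strided_slice(fused, axes=[2], begin=[0], end=[width])
    s2 = R.strided_slice(fused, axes=[2], begin=[width], end=[width * 2])
    # New values for the matched bound variables.
    return {matchings[matmul1]: s1, matchings[matmul2]: s2}

# rewritten = rewrite_bindings(ctx, rewriter, func)  # func: a relax.Function
```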
diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc
index 88381d6e26..0055929a78 100644
--- a/src/relax/ir/dataflow_matcher.cc
+++ b/src/relax/ir/dataflow_matcher.cc
@@ -780,18 +780,29 @@ Optional<Map<DFPattern, Var>> MatchGraph(const PatternContext& ctx, const Datafl
 TVM_REGISTER_GLOBAL("relax.dpl.match_dfb").set_body_typed(MatchGraph);
 
 /*!
- * \brief Apply pattern matching to each call node and replace matching ones with the output of
- * a user-provided rewriter function.
+ * \brief Apply pattern matching to each call node and dataflow block, and replace matching ones
+ * with the output of a user-provided rewriter function.
  */
 class PatternRewriter : ExprMutator {
  public:
+  using ExprMutator::VisitBindingBlock_;
   using ExprMutator::VisitExpr_;
 
-  PatternRewriter(DFPattern pat, PackedFunc rewriter_func)
-      : pattern_(pat), rewriter_func_(rewriter_func) {}
+  PatternRewriter(DFPattern pat, PackedFunc rewriter_func,
+                  const std::unordered_set<const VarNode*>& params)
+      : pattern_(pat), rewriter_func_(rewriter_func), params_(params) {}
 
-  static Expr Run(DFPattern pat, PackedFunc rewriter_func, Function f) {
-    PatternRewriter rewriter(pat, rewriter_func);
+  PatternRewriter(const PatternContext& ctx, PackedFunc rewriter_func,
+                  const std::unordered_set<const VarNode*>& params)
+      : ctx_(ctx), rewriter_func_(rewriter_func), params_(params) {}
+
+  template <typename PatternType>
+  static Expr Run(PatternType pat, PackedFunc rewriter_func, Function f) {
+    std::unordered_set<const VarNode*> params;
+    for (const auto& p : f->params) {
+      params.insert(p.get());
+    }
+    PatternRewriter rewriter(pat, rewriter_func, params);
     return RemoveAllUnused(Downcast<Function>(rewriter.VisitExpr(f)));
   }
 
@@ -807,7 +818,9 @@ class PatternRewriter : ExprMutator {
 
   Expr VisitExpr_(const CallNode* call_node) final {
     auto call = ExprMutator::VisitExpr_(call_node);
-    if (auto matches_opt = ExtractMatchedExpr(pattern_, call, bindings_)) {
+    if (!pattern_) {
+      return call;
+    } else if (auto matches_opt = ExtractMatchedExpr(pattern_.value(), call, bindings_)) {
       auto rewriten_expr = rewriter_func_(call, matches_opt.value());
       memo_[call_node] = rewriten_expr;
       return rewriten_expr;
@@ -815,17 +828,99 @@ class PatternRewriter : ExprMutator {
     return call;
   }
 
+  BindingBlock VisitBindingBlock_(const DataflowBlockNode* block_node) final {
+    if (!ctx_) {
+      return ExprMutator::VisitBindingBlock_(block_node);
+    }
+    return RewriteDataflowBlockFixedPoint(GetRef<DataflowBlock>(block_node));
+  }
+
  private:
-  DFPattern pattern_;
+  void EmitUsedVars(Expr val, const Array<Binding>& pending_bindings,
+                    std::unordered_set<const VarNode*>* emitted_vars) {
+    std::unordered_set<const VarNode*> unemitted_vars;
+    PostOrderVisit(val, [=, &unemitted_vars](Expr e) {
+      if (auto v = e.as<VarNode>(); v && !emitted_vars->count(v)) {
+        unemitted_vars.insert(v);
+      }
+    });
+
+    if (unemitted_vars.empty()) {
+      return;
+    }
+
+    size_t num_unemitted = unemitted_vars.size();
+    for (const auto& binding : pending_bindings) {
+      if (auto var_bind = binding.as<VarBindingNode>();
+          var_bind && unemitted_vars.count(var_bind->var.get())) {
+        EmitUsedVars(var_bind->value, pending_bindings, emitted_vars);
+        this->VisitBinding(binding);
+        emitted_vars->insert(var_bind->var.get());
+        if (--num_unemitted == 0) {
+          return;
+        }
+      }
+    }
+  }
+
+  // Repeat until all matchable subsets of bindings are rewritten.
+  BindingBlock RewriteDataflowBlockFixedPoint(BindingBlock block) {
+    if (auto matches = MatchGraph(ctx_.value(), Downcast<DataflowBlock>(block))) {
+      builder_->BeginDataflowBlock();
+      Map<Var, Expr> replacements = rewriter_func_(matches.value());
+
+      std::unordered_set<const VarNode*> emitted_vars;
+
+      for (size_t i = 0; i < block->bindings.size(); ++i) {
+        const auto& binding = block->bindings[i];
+        if (auto var_bind = binding.as<VarBindingNode>()) {
+          if (replacements.count(var_bind->var)) {
+            auto new_val = replacements[var_bind->var];
+            Array<Binding> pending_bindings(block->bindings.begin() + i + 1, block->bindings.end());
+            // Make sure there is no unbound variable used in the new value before it is emitted
+            EmitUsedVars(new_val, pending_bindings, &emitted_vars);
+            this->ReEmitBinding(var_bind, builder_->Normalize(new_val));
+          } else if (!emitted_vars.count(var_bind->var.get())) {
+            this->VisitBinding(binding);
+            emitted_vars.insert(var_bind->var.get());
+          }
+        } else {
+          this->VisitBinding(binding);
+        }
+      }
+      return RewriteDataflowBlockFixedPoint(builder_->EndBlock());
+    }
+    return block;
+  }
+
+  /*! \brief The pattern for rewriting call nodes */
+  Optional<DFPattern> pattern_;
+  /*! \brief The pattern constraint contexts for rewriting dataflow blocks */
+  Optional<PatternContext> ctx_;
+  /*!
+   * \brief The user-provided rewriter function. Its signature and semantics are:
+   * - (Call, Map<DFPattern, Expr>) -> Call for call node rewriting. Given the matched
+   *    call node and the map of patterns and matched expressions, it should return a new call node
+   *    to replace the original one or the original matched call node as is.
+   * - Map<DFPattern, Var> -> Map<Var, Expr> for dataflow block rewriting. Given the map of patterns
+   *   and corresponding variables (bound variables or parameters), it should return a map that
+   *   specifies new values for matched bound variables.
+   */
   PackedFunc rewriter_func_;
+  std::unordered_set<const VarNode*> params_;
   Map<Var, Expr> bindings_;
   std::unordered_map<const Object*, Expr> memo_;
 };
 
-TVM_REGISTER_GLOBAL("relax.dpl.rewrite")
+TVM_REGISTER_GLOBAL("relax.dpl.rewrite_call")
     .set_body_typed([](DFPattern pat, PackedFunc rewriter, Function f) {
       return PatternRewriter::Run(pat, rewriter, f);
     });
 
+TVM_REGISTER_GLOBAL("relax.dpl.rewrite_bindings")
    .set_body_typed([](const PatternContext& ctx, PackedFunc rewriter, Function f) {
+      return PatternRewriter::Run(ctx, rewriter, f);
+    });
+
 }  // namespace relax
 }  // namespace tvm
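RewriteDataflowBlockFixedPoint above re-runs the matcher until no constraint matches, and EmitUsedVars guarantees that a replacement value never references a variable that has not been emitted yet: any later binding it depends on is hoisted in front of it. Below is a self-contained toy model of that emit-order logic in plain Python (toy data structures only, not TVM APIs, and a single pass rather than the full fixed point):

```
def rewrite_block(bindings, replacements):
    """bindings: list of (var, value) where value is a tuple of used var names;
    replacements: {var: new value}."""
    bound = dict(bindings)
    out, emitted = [], set()

    def emit_deps(value, pending):
        # Emit any not-yet-emitted variable used by `value` that is bound in a
        # later (pending) binding, recursively emitting its own deps first.
        pending_vars = {var for var, _ in pending}
        for v in value:
            if v in bound and v not in emitted and v in pending_vars:
                emit_deps(bound[v], pending)
                out.append((v, bound[v]))
                emitted.add(v)

    for i, (var, value) in enumerate(bindings):
        if var in replacements:
            emit_deps(replacements[var], bindings[i + 1:])
            out.append((var, replacements[var]))
            emitted.add(var)
        elif var not in emitted:
            out.append((var, value))
            emitted.add(var)
    return out

# Mirrors test_combine_matmul_emit_order below: the replacement for lv0 uses
# w1_t_t and w2_t, which are bound after lv0, so they are hoisted before it.
bindings = [
    ("w0_t", ()), ("lv0", ("w0_t",)), ("w1_t", ()), ("w1_t_t", ("w1_t",)),
    ("lv1", ("w1_t_t",)), ("w2_t", ()), ("lv2", ("w2_t",)),
]
new_block = rewrite_block(bindings, {"lv0": ("w0_t", "w1_t_t", "w2_t")})
# -> emitted in order: w0_t, w1_t, w1_t_t, w2_t, lv0, lv1, lv2
```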
diff --git a/tests/python/relax/test_dataflow_pattern.py b/tests/python/relax/test_dataflow_pattern.py
index f18244096e..e4d7f7972c 100644
--- a/tests/python/relax/test_dataflow_pattern.py
+++ b/tests/python/relax/test_dataflow_pattern.py
@@ -918,7 +918,7 @@ def test_rewrite_simple():
     def rewriter(_, matchings):
         return R.multiply(matchings[x], R.const(2, "float32"))
 
-    rewritten = rewrite(pattern, rewriter, main)
+    rewritten = rewrite_call(pattern, rewriter, main)
     tvm.ir.assert_structural_equal(rewritten, expected1)
 
     add1 = is_op("relax.add")(x, x)
@@ -927,14 +927,14 @@ def test_rewrite_simple():
     def rewriter(_, matchings):
         return R.multiply(matchings[x], R.const(4, "float32"))
 
-    rewritten = rewrite(pattern, rewriter, main)
+    rewritten = rewrite_call(pattern, rewriter, main)
     tvm.ir.assert_structural_equal(rewritten, expected2)
 
     # No rewriting, return the original call node as is
     def rewriter(orig, _):
         return orig
 
-    rewritten = rewrite(pattern, rewriter, main)
+    rewritten = rewrite_call(pattern, rewriter, main)
     tvm.ir.assert_structural_equal(rewritten, main)
 
 
@@ -1002,7 +1002,7 @@ def test_rewrite_attention():
     def rewriter(_, matchings):
         return R.nn.attention(matchings[Q], matchings[K], matchings[V])
 
-    rewritten = rewrite(pattern, rewriter, main)
+    rewritten = rewrite_call(pattern, rewriter, main)
     tvm.ir.assert_structural_equal(rewritten, expected)
 
 
@@ -1075,5 +1075,150 @@ def test_attention_fake_qkv():
         assert ctx.match_dfb(dfb) is None
 
 
+def get_qkv_proj_rewriter(
+    inp_pat, Q_weight_pat, K_weight_pat, V_weight_pat, matmul1, matmul2, matmul3
+):
+    def qkv_proj_rewriter(matchings):
+        inp = matchings[inp_pat]
+        Q_weight = matchings[Q_weight_pat]
+        K_weight = matchings[K_weight_pat]
+        V_weight = matchings[V_weight_pat]
+        width = Q_weight.struct_info.shape[1]
+
+        concat = R.concat([Q_weight, K_weight, V_weight], axis=1)
+        matmul = R.matmul(inp, concat)
+        Q = R.strided_slice(matmul, axes=[2], begin=[0], end=[width])
+        K = R.strided_slice(matmul, axes=[2], begin=[width], end=[width * 2])
+        V = R.strided_slice(matmul, axes=[2], begin=[width * 2], end=[width * 3])
+
+        return {matchings[matmul1]: Q, matchings[matmul2]: K, matchings[matmul3]: V}
+
+    return qkv_proj_rewriter
+
+
+def test_combine_matmul_twice():
+    @R.function
+    def qkv_x2(
+        x1: R.Tensor((2, 1024, 640), "float32"),
+        x2: R.Tensor((2, 1024, 640), "float32"),
+        w0: R.Tensor((640, 640), "float32"),
+        w1: R.Tensor((640, 640), "float32"),
+        w2: R.Tensor((640, 640), "float32"),
+        w3: R.Tensor((640, 640), "float32"),
+        w4: R.Tensor((640, 640), "float32"),
+        w5: R.Tensor((640, 640), "float32"),
+    ) -> R.Tensor:
+        with R.dataflow():
+            lv0 = R.matmul(x1, w0)
+            lv1 = R.matmul(x1, w1)
+            lv2 = R.matmul(x1, w2)
+            lv3 = R.matmul(x2, w3)
+            lv4 = R.matmul(x2, w4)
+            lv5 = R.matmul(x2, w5)
+            out = (lv0, lv1, lv2, lv3, lv4, lv5)
+            R.output(out)
+        return out
+
+    @R.function
+    def expected(
+        x1: R.Tensor((2, 1024, 640), "float32"),
+        x2: R.Tensor((2, 1024, 640), "float32"),
+        w0: R.Tensor((640, 640), "float32"),
+        w1: R.Tensor((640, 640), "float32"),
+        w2: R.Tensor((640, 640), "float32"),
+        w3: R.Tensor((640, 640), "float32"),
+        w4: R.Tensor((640, 640), "float32"),
+        w5: R.Tensor((640, 640), "float32"),
+    ) -> R.Tensor:
+        with R.dataflow():
+            lv = R.concat((w0, w1, w2), axis=1)
+            lv1 = R.matmul(x1, lv)
+            lv0 = R.strided_slice(lv1, axes=[2], begin=[0], end=[640])
+            lv1_1 = R.strided_slice(lv1, axes=[2], begin=[640], end=[1280])
+            lv2 = R.strided_slice(lv1, axes=[2], begin=[1280], end=[1920])
+            lv2_1 = R.concat((w3, w4, w5), axis=1)
+            lv3 = R.matmul(x2, lv2_1, out_dtype="void")
+            lv3_1 = R.strided_slice(lv3, axes=[2], begin=[0], end=[640])
+            lv4 = R.strided_slice(lv3, axes=[2], begin=[640], end=[1280])
+            lv5 = R.strided_slice(lv3, axes=[2], begin=[1280], end=[1920])
+            out = lv0, lv1_1, lv2, lv3_1, lv4, lv5
+            R.output(out)
+        return out
+
+    with PatternContext() as ctx:
+        inp_pat = wildcard()
+        Q_weight_pat = wildcard()
+        K_weight_pat = wildcard()
+        V_weight_pat = wildcard()
+
+        matmul1 = is_op("relax.matmul")(inp_pat, Q_weight_pat)
+        matmul2 = is_op("relax.matmul")(inp_pat, K_weight_pat)
+        matmul3 = is_op("relax.matmul")(inp_pat, V_weight_pat)
+
+        rewriter = get_qkv_proj_rewriter(
+            inp_pat, Q_weight_pat, K_weight_pat, V_weight_pat, matmul1, matmul2, matmul3
+        )
+        rewritten = rewrite_bindings(ctx, rewriter, qkv_x2)
+        tvm.ir.assert_structural_equal(rewritten, expected)
+
+
+def test_combine_matmul_emit_order():
+    @R.function
+    def main(
+        x1: R.Tensor((2, 1024, 640), "float32"),
+        w0: R.Tensor((640, 640), "float32"),
+        w1: R.Tensor((640, 640), "float32"),
+        w2: R.Tensor((640, 640), "float32"),
+    ) -> R.Tensor:
+        with R.dataflow():
+            w0_t = R.permute_dims(w0, axes=None)
+            lv0 = R.matmul(x1, w0_t)
+            w1_t = R.permute_dims(w1, axes=None)
+            w1_t_t = R.permute_dims(w1_t, axes=None)
+            lv1 = R.matmul(x1, w1_t_t)
+            w2_t = R.permute_dims(w2, axes=None)
+            lv2 = R.matmul(x1, w2_t)
+            out = (lv0, lv1, lv2)
+            R.output(out)
+        return out
+
+    @R.function
+    def expected(
+        x1: R.Tensor((2, 1024, 640), dtype="float32"),
+        w0: R.Tensor((640, 640), dtype="float32"),
+        w1: R.Tensor((640, 640), dtype="float32"),
+        w2: R.Tensor((640, 640), dtype="float32"),
+    ) -> R.Tensor:
+        with R.dataflow():
+            w0_t = R.permute_dims(w0, axes=None)
+            w1_t = R.permute_dims(w1, axes=None)
+            w1_t_t = R.permute_dims(w1_t, axes=None)
+            w2_t = R.permute_dims(w2, axes=None)
+            lv = R.concat((w0_t, w1_t_t, w2_t), axis=1)
+            lv1 = R.matmul(x1, lv, out_dtype="void")
+            lv0 = R.strided_slice(lv1, axes=[2], begin=[0], end=[640])
+            lv1_1 = R.strided_slice(lv1, axes=[2], begin=[640], end=[1280])
+            lv2 = R.strided_slice(lv1, axes=[2], begin=[1280], end=[1920])
+            out = lv0, lv1_1, lv2
+            R.output(out)
+        return out
+
+    with PatternContext() as ctx:
+        inp_pat = wildcard()
+        Q_weight_pat = wildcard()
+        K_weight_pat = wildcard()
+        V_weight_pat = wildcard()
+
+        matmul1 = is_op("relax.matmul")(inp_pat, Q_weight_pat)
+        matmul2 = is_op("relax.matmul")(inp_pat, K_weight_pat)
+        matmul3 = is_op("relax.matmul")(inp_pat, V_weight_pat)
+
+        rewriter = get_qkv_proj_rewriter(
+            inp_pat, Q_weight_pat, K_weight_pat, V_weight_pat, matmul1, matmul2, matmul3
+        )
+        rewritten = rewrite_bindings(ctx, rewriter, main)
+        tvm.ir.assert_structural_equal(rewritten, expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
