This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new cb6efae413 [Unity] Support pattern-based rewriting (#14312)
cb6efae413 is described below

commit cb6efae413a1068efbb1b3509c3198c96cade138
Author: masahi <masahi...@gmail.com>
AuthorDate: Fri Mar 17 04:20:51 2023 +0900

    [Unity] Support pattern-based rewriting (#14312)
    
    * stub
    
    * wip
    
    * works
    
    * restore binding
    
    * attention test work
    
    * use RemoveAllUnused
    
    * simplified callback api
    
    * pass original call node to callback
    
    * clean test
    
    * add doc
    
    * add test for the case where the original call is returned
    
    * callback -> rewriter and other doc improvement
---
 python/tvm/relax/dpl/pattern.py             |  40 +++++++++-
 src/relax/ir/dataflow_matcher.cc            |  48 +++++++++++
 tests/python/relax/test_dataflow_pattern.py | 118 ++++++++++++++++++++++++++++
 3 files changed, 204 insertions(+), 2 deletions(-)

diff --git a/python/tvm/relax/dpl/pattern.py b/python/tvm/relax/dpl/pattern.py
index 1ca41b378d..248e957726 100644
--- a/python/tvm/relax/dpl/pattern.py
+++ b/python/tvm/relax/dpl/pattern.py
@@ -20,7 +20,7 @@
 # pylint: disable=pointless-statement
 
 import typing
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union, Callable
 
 import tvm
 import tvm._ffi as tvm_ffi
@@ -31,7 +31,7 @@ from tvm.relay.op import get
 from ...ir import make_node
 from ...ir.base import Node
 from ...runtime import Object
-from ..expr import Expr, Var
+from ..expr import Expr, Var, Function
 from . import _ffi as ffi
 
 
@@ -1115,3 +1115,39 @@ def make_fused_bias_activation_pattern(op_name, with_bias=False, activation=None
         return is_op(activation)(out)
 
     return out
+
+
+def rewrite(
+    pattern: DFPattern, rewriter: Callable[[Expr, Dict[DFPattern, Expr]], Expr], func: Function
+) -> Function:
+    """
+    Rewrite a function with the given pattern and the rewriter function.
+
+    Parameters
+    ----------
+    pattern: DFPattern
+        The pattern to match.
+
+    rewriter: Callable[[Expr, Dict[DFPattern, Expr]], Expr]
+        The function to be called on a successful match. Given the matched call node and
+        the map of patterns to matched expressions, it should return either a new call node
+        to replace the original one, or the original matched call node unchanged.
+
+        For example, to replace x + x with 2 * x, we can write the rewriter as follows:
+        ```
+        x = wildcard()
+        pattern = is_op("relax.add")(x, x)
+
+        def rewriter(orig, matchings):
+            return R.multiply(matchings[x], R.const(2, "float32"))
+        ```
+
+    func: Function
+        The function to rewrite.
+
+    Returns
+    -------
+    rewritten_func: Function
+        The rewritten function, or the original function if the pattern did not match.
+    """
+    return ffi.rewrite(pattern, rewriter, func)
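
For reference, here is a minimal end-to-end sketch of the new `rewrite` helper, assembled from the docstring example above and the tests added below. The function name `before` and the exact import lines are illustrative assumptions, not part of this commit:

```
import tvm
from tvm.script import relax as R
from tvm.relax.dpl import is_op, rewrite, wildcard


# A toy Relax function containing a single x + x for the pattern to match.
@R.function
def before(x: R.Tensor((16, 16), "float32")) -> R.Tensor((16, 16), "float32"):
    with R.dataflow():
        y = R.add(x, x)
        R.output(y)
    return y


# Match x + x: a relax.add whose two arguments match the same wildcard.
x = wildcard()
pattern = is_op("relax.add")(x, x)


# On a successful match, replace the call with 2 * x. Returning `orig`
# unchanged would leave the matched call as is.
def rewriter(orig, matchings):
    return R.multiply(matchings[x], R.const(2, "float32"))


rewritten = rewrite(pattern, rewriter, before)
```
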
diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc
index da8c6ce2da..c6d705b5b4 100644
--- a/src/relax/ir/dataflow_matcher.cc
+++ b/src/relax/ir/dataflow_matcher.cc
@@ -766,5 +766,53 @@ Map<DFPattern, Var> MatchGraph(const PatternContext& ctx, const DataflowBlock& d
 
 TVM_REGISTER_GLOBAL("relax.dpl.match_dfb").set_body_typed(MatchGraph);
 
+/*!
+ * \brief Apply pattern matching to each call node and replace matching ones with the output of
+ * a user-provided rewriter function.
+ */
+class PatternRewriter : ExprMutator {
+ public:
+  using ExprMutator::VisitExpr_;
+
+  PatternRewriter(DFPattern pat, PackedFunc rewriter_func)
+      : pattern_(pat), rewriter_func_(rewriter_func) {}
+
+  static Expr Run(DFPattern pat, PackedFunc rewriter_func, Function f) {
+    PatternRewriter rewriter(pat, rewriter_func);
+    return RemoveAllUnused(Downcast<Function>(rewriter.VisitExpr(f)));
+  }
+
+  void VisitBinding_(const VarBindingNode* binding) final {
+    bindings_.Set(binding->var, binding->value);
+    ExprMutator::VisitBinding_(binding);
+    if (auto it = memo_.find(binding->value.get()); it != memo_.end()) {
+      // Update the binding passed to ExtractMatchedExpr so that the rewritten
+      // expression can be subject to further pattern matching.
+      bindings_.Set(binding->var, it->second);
+    }
+  }
+
+  Expr VisitExpr_(const CallNode* call_node) final {
+    auto call = ExprMutator::VisitExpr_(call_node);
+    if (auto matches_opt = ExtractMatchedExpr(pattern_, call, bindings_)) {
+      auto rewritten_expr = rewriter_func_(call, matches_opt.value());
+      // Record the rewrite so that VisitBinding_ can update bindings_ accordingly.
+      memo_[call_node] = rewritten_expr;
+      return rewritten_expr;
+    }
+    return call;
+  }
+
+ private:
+  DFPattern pattern_;
+  PackedFunc rewriter_func_;
+  Map<Var, Expr> bindings_;
+  std::unordered_map<const Object*, Expr> memo_;
+};
+
+TVM_REGISTER_GLOBAL("relax.dpl.rewrite")
+    .set_body_typed([](DFPattern pat, PackedFunc rewriter, Function f) {
+      return PatternRewriter::Run(pat, rewriter, f);
+    });
+
 }  // namespace relax
 }  // namespace tvm
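
The `relax.dpl.rewrite` global registered above is what backs the Python-side `ffi.rewrite(pattern, rewriter, func)` call in pattern.py. As a hedged sketch (assuming a TVM build with the Relax/unity branch), the packed function can also be looked up directly:

```
import tvm

# TVM_REGISTER_GLOBAL("relax.dpl.rewrite") exposes the rewriter through the
# global packed-function registry under the same name.
rewrite_pf = tvm.get_global_func("relax.dpl.rewrite")
# rewrite_pf(pattern, rewriter, func) takes the same arguments as the Python wrapper.
```
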
diff --git a/tests/python/relax/test_dataflow_pattern.py b/tests/python/relax/test_dataflow_pattern.py
index b57dca19f2..a40faf3bcb 100644
--- a/tests/python/relax/test_dataflow_pattern.py
+++ b/tests/python/relax/test_dataflow_pattern.py
@@ -888,5 +888,123 @@ def test_incremental_solving_counter():
             assert not ctx1.match_dfb(simple_chain.body.blocks[0])
 
 
+def test_rewrite_simple():
+    @R.function
+    def main(x: R.Tensor((16, 16), "float32")) -> R.Tensor((16, 16), "float32"):
+        with R.dataflow():
+            x2 = R.add(x, x)
+            x4 = R.add(x2, x2)
+            R.output(x4)
+        return x4
+
+    @R.function
+    def expected1(x: R.Tensor((16, 16), dtype="float32")) -> R.Tensor((16, 16), dtype="float32"):
+        with R.dataflow():
+            lv: R.Tensor((16, 16), dtype="float32") = R.multiply(x, R.const(2, "float32"))
+            x4: R.Tensor((16, 16), dtype="float32") = R.multiply(lv, R.const(2, "float32"))
+            R.output(x4)
+        return x4
+
+    @R.function
+    def expected2(x: R.Tensor((16, 16), dtype="float32")) -> R.Tensor((16, 16), dtype="float32"):
+        with R.dataflow():
+            x4: R.Tensor((16, 16), dtype="float32") = R.multiply(x, R.const(4, "float32"))
+            R.output(x4)
+        return x4
+
+    x = wildcard()
+    pattern = is_op("relax.add")(x, x)
+
+    def rewriter(_, matchings):
+        return R.multiply(matchings[x], R.const(2, "float32"))
+
+    rewritten = rewrite(pattern, rewriter, main)
+    tvm.ir.assert_structural_equal(rewritten, expected1)
+
+    add1 = is_op("relax.add")(x, x)
+    pattern = is_op("relax.add")(add1, add1)
+
+    def rewriter(_, matchings):
+        return R.multiply(matchings[x], R.const(4, "float32"))
+
+    rewritten = rewrite(pattern, rewriter, main)
+    tvm.ir.assert_structural_equal(rewritten, expected2)
+
+    # No rewriting, return the original call node as is
+    def rewriter(orig, _):
+        return orig
+
+    rewritten = rewrite(pattern, rewriter, main)
+    tvm.ir.assert_structural_equal(rewritten, main)
+
+
+def test_rewrite_attention():
+    @R.function
+    def main(
+        Q: R.Tensor((2, 4096, 8, 40), "float32"),
+        K: R.Tensor((2, 4096, 8, 40), "float32"),
+        V: R.Tensor((2, 4096, 8, 40), "float32"),
+    ) -> R.Tensor((2, 4096, 8, 40), "float32"):
+        with R.dataflow():
+            lv58 = R.permute_dims(Q, axes=[0, 2, 1, 3])
+            lv59 = R.reshape(lv58, R.shape([16, 4096, 40]))
+
+            lv61 = R.permute_dims(K, axes=[0, 2, 1, 3])
+            lv62 = R.reshape(lv61, R.shape([16, 4096, 40]))
+
+            lv64 = R.permute_dims(V, axes=[0, 2, 1, 3])
+            lv65 = R.reshape(lv64, R.shape([16, 4096, 40]))
+
+            lv62_transposed = R.permute_dims(lv62, axes=[0, 2, 1])
+            lv3_1 = R.matmul(lv59, lv62_transposed)
+            lv68 = R.multiply(lv3_1, R.const(0.15811388194561005, "float32"))
+            lv69 = R.nn.softmax(lv68, axis=-1)
+            lv_3 = R.matmul(lv69, lv65)
+
+            lv71 = R.reshape(lv_3, R.shape([2, 8, 4096, 40]))
+            lv72 = R.permute_dims(lv71, axes=[0, 2, 1, 3])
+            R.output(lv72)
+
+        return lv72
+
+    @R.function
+    def expected(
+        Q: R.Tensor((2, 4096, 8, 40), dtype="float32"),
+        K: R.Tensor((2, 4096, 8, 40), dtype="float32"),
+        V: R.Tensor((2, 4096, 8, 40), dtype="float32"),
+    ) -> R.Tensor((2, 4096, 8, 40), dtype="float32"):
+        with R.dataflow():
+            lv72: R.Tensor((2, 4096, 8, 40), dtype="float32") = R.nn.attention(Q, V, K)
+            R.output(lv72)
+        return lv72
+
+    def BSNH_to_BSH(tensor):
+        return is_op("relax.reshape")(is_op("relax.permute_dims")(tensor), 
wildcard())
+
+    def BSH_to_BSNH(tensor):
+        return is_op("relax.permute_dims")(is_op("relax.reshape")(tensor, 
wildcard()))
+
+    Q = wildcard()
+    K = wildcard()
+    V = wildcard()
+
+    Q_3D = BSNH_to_BSH(Q)
+    V_3D = BSNH_to_BSH(V)
+    K_3D = BSNH_to_BSH(K)
+
+    matmul1 = is_op("relax.matmul")(Q_3D, is_op("relax.permute_dims")(V_3D))
+    multiply = is_op("relax.multiply")(matmul1, is_const())
+    softmax = is_op("relax.nn.softmax")(multiply)
+    matmul2 = is_op("relax.matmul")(softmax, K_3D)
+
+    pattern = BSH_to_BSNH(matmul2)
+
+    def rewriter(_, matchings):
+        return R.nn.attention(matchings[Q], matchings[K], matchings[V])
+
+    rewritten = rewrite(pattern, rewriter, main)
+    tvm.ir.assert_structural_equal(rewritten, expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
