This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
     new cb6efae413  [Unity] Support pattern-based rewriting (#14312)
cb6efae413 is described below

commit cb6efae413a1068efbb1b3509c3198c96cade138
Author: masahi <masahi...@gmail.com>
AuthorDate: Fri Mar 17 04:20:51 2023 +0900

    [Unity] Support pattern-based rewriting (#14312)

    * stub
    * wip
    * works
    * restore binding
    * attention test work
    * use RemoveAllUnused
    * simplified callback api
    * pass original call node to callback
    * clean test
    * add doc
    * add test for the case where the original call is returned
    * callback -> rewriter and other doc improvement
---
 python/tvm/relax/dpl/pattern.py             |  40 +++++++++-
 src/relax/ir/dataflow_matcher.cc            |  48 +++++++++++
 tests/python/relax/test_dataflow_pattern.py | 118 ++++++++++++++++++++++++++++
 3 files changed, 204 insertions(+), 2 deletions(-)

diff --git a/python/tvm/relax/dpl/pattern.py b/python/tvm/relax/dpl/pattern.py
index 1ca41b378d..248e957726 100644
--- a/python/tvm/relax/dpl/pattern.py
+++ b/python/tvm/relax/dpl/pattern.py
@@ -20,7 +20,7 @@
 # pylint: disable=pointless-statement
 
 import typing
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union, Callable
 
 import tvm
 import tvm._ffi as tvm_ffi
@@ -31,7 +31,7 @@ from tvm.relay.op import get
 
 from ...ir import make_node
 from ...ir.base import Node
 from ...runtime import Object
-from ..expr import Expr, Var
+from ..expr import Expr, Var, Function
 from . import _ffi as ffi
 
@@ -1115,3 +1115,39 @@ def make_fused_bias_activation_pattern(op_name, with_bias=False, activation=None
         return is_op(activation)(out)
 
     return out
+
+
+def rewrite(
+    pattern: DFPattern, rewriter: Callable[[Expr, Dict[DFPattern, Expr]], Expr], func: Function
+) -> Function:
+    """
+    Rewrite a function with the given pattern and the rewriter function.
+
+    Parameters
+    ----------
+    pattern: DFPattern
+        The pattern to match.
+
+    rewriter: Callable[[Expr, Dict[DFPattern, Expr]], Expr]
+        The function called on each successful match, to do the rewriting. Given the matched
+        call node and the map from patterns to matched expressions, it should return either a
+        new call node to replace the original one, or the original matched call node as is.
+
+        For example, to replace x + x with 2 * x, we can write the rewriter as follows:
+        ```
+        x = wildcard()
+        pattern = is_op("relax.add")(x, x)
+
+        def rewriter(orig, matchings):
+            return R.multiply(matchings[x], R.const(2, "float32"))
+        ```
+
+    func: Function
+        The function to rewrite.
+
+    Returns
+    -------
+    rewritten_func: Function
+        The rewritten function, or the input function unchanged if the pattern did not match.
+    """
+    return ffi.rewrite(pattern, rewriter, func)
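For context, the new `rewrite` API above can be exercised end to end as in the sketch below. This is illustrative only, not part of the diff; it assumes `wildcard`, `is_op`, and `rewrite` are re-exported from `tvm.relax.dpl` (as the tests further down use them), and the function name `before` is made up:

```
import tvm
from tvm.relax.dpl import is_op, rewrite, wildcard
from tvm.script import relax as R


@R.function
def before(x: R.Tensor((16, 16), "float32")) -> R.Tensor((16, 16), "float32"):
    with R.dataflow():
        y = R.add(x, x)  # the subexpression the pattern below matches
        R.output(y)
    return y


x = wildcard()
pattern = is_op("relax.add")(x, x)


def rewriter(_, matchings):
    # Replace the matched x + x with 2 * x, as in the docstring example.
    return R.multiply(matchings[x], R.const(2, "float32"))


rewritten = rewrite(pattern, rewriter, before)
# rewritten now binds R.multiply(x, R.const(2)) where R.add(x, x) used to be.
```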
diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc
index da8c6ce2da..c6d705b5b4 100644
--- a/src/relax/ir/dataflow_matcher.cc
+++ b/src/relax/ir/dataflow_matcher.cc
@@ -766,5 +766,53 @@ Map<DFPattern, Var> MatchGraph(const PatternContext& ctx, const DataflowBlock& d
 
 TVM_REGISTER_GLOBAL("relax.dpl.match_dfb").set_body_typed(MatchGraph);
 
+/*!
+ * \brief Apply pattern matching to each call node, and replace matching ones with the output of
+ * a user-provided rewriter function.
+ */
+class PatternRewriter : ExprMutator {
+ public:
+  using ExprMutator::VisitExpr_;
+
+  PatternRewriter(DFPattern pat, PackedFunc rewriter_func)
+      : pattern_(pat), rewriter_func_(rewriter_func) {}
+
+  static Expr Run(DFPattern pat, PackedFunc rewriter_func, Function f) {
+    PatternRewriter rewriter(pat, rewriter_func);
+    return RemoveAllUnused(Downcast<Function>(rewriter.VisitExpr(f)));
+  }
+
+  void VisitBinding_(const VarBindingNode* binding) final {
+    bindings_.Set(binding->var, binding->value);
+    ExprMutator::VisitBinding_(binding);
+    if (auto it = memo_.find(binding->value.get()); it != memo_.end()) {
+      // Update the recorded binding to the rewritten expression passed to
+      // ExtractMatchedExpr, so that the rewritten expression can be subject
+      // to further pattern matching.
+      bindings_.Set(binding->var, it->second);
+    }
+  }
+
+  Expr VisitExpr_(const CallNode* call_node) final {
+    auto call = ExprMutator::VisitExpr_(call_node);
+    if (auto matches_opt = ExtractMatchedExpr(pattern_, call, bindings_)) {
+      auto rewritten_expr = rewriter_func_(call, matches_opt.value());
+      memo_[call_node] = rewritten_expr;
+      return rewritten_expr;
+    }
+    return call;
+  }
+
+ private:
+  DFPattern pattern_;
+  PackedFunc rewriter_func_;
+  Map<Var, Expr> bindings_;
+  std::unordered_map<const Object*, Expr> memo_;
+};
+
+TVM_REGISTER_GLOBAL("relax.dpl.rewrite")
+    .set_body_typed([](DFPattern pat, PackedFunc rewriter, Function f) {
+      return PatternRewriter::Run(pat, rewriter, f);
+    });
+
 }  // namespace relax
 }  // namespace tvm
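A note on the FFI plumbing (not part of the diff): the PackedFunc registered above under "relax.dpl.rewrite" is what `ffi.rewrite` in pattern.py resolves to, and the Python rewriter callback is converted to a PackedFunc automatically when it crosses the boundary. The registered global can also be looked up by name directly:

```
import tvm

# Fetch the global registered by TVM_REGISTER_GLOBAL("relax.dpl.rewrite");
# tvm.relax.dpl.rewrite() invokes this same PackedFunc via its ffi module.
rewrite_fn = tvm.get_global_func("relax.dpl.rewrite")
assert rewrite_fn is not None
```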
diff --git a/tests/python/relax/test_dataflow_pattern.py b/tests/python/relax/test_dataflow_pattern.py
index b57dca19f2..a40faf3bcb 100644
--- a/tests/python/relax/test_dataflow_pattern.py
+++ b/tests/python/relax/test_dataflow_pattern.py
@@ -888,5 +888,123 @@ def test_incremental_solving_counter():
     assert not ctx1.match_dfb(simple_chain.body.blocks[0])
 
 
+def test_rewrite_simple():
+    @R.function
+    def main(x: R.Tensor((16, 16), "float32")) -> R.Tensor((16, 16), "float32"):
+        with R.dataflow():
+            x2 = R.add(x, x)
+            x4 = R.add(x2, x2)
+            R.output(x4)
+        return x4
+
+    @R.function
+    def expected1(x: R.Tensor((16, 16), dtype="float32")) -> R.Tensor((16, 16), dtype="float32"):
+        with R.dataflow():
+            lv: R.Tensor((16, 16), dtype="float32") = R.multiply(x, R.const(2, "float32"))
+            x4: R.Tensor((16, 16), dtype="float32") = R.multiply(lv, R.const(2, "float32"))
+            R.output(x4)
+        return x4
+
+    @R.function
+    def expected2(x: R.Tensor((16, 16), dtype="float32")) -> R.Tensor((16, 16), dtype="float32"):
+        with R.dataflow():
+            x4: R.Tensor((16, 16), dtype="float32") = R.multiply(x, R.const(4, "float32"))
+            R.output(x4)
+        return x4
+
+    x = wildcard()
+    pattern = is_op("relax.add")(x, x)
+
+    def rewriter(_, matchings):
+        return R.multiply(matchings[x], R.const(2, "float32"))
+
+    rewritten = rewrite(pattern, rewriter, main)
+    tvm.ir.assert_structural_equal(rewritten, expected1)
+
+    add1 = is_op("relax.add")(x, x)
+    pattern = is_op("relax.add")(add1, add1)
+
+    def rewriter(_, matchings):
+        return R.multiply(matchings[x], R.const(4, "float32"))
+
+    rewritten = rewrite(pattern, rewriter, main)
+    tvm.ir.assert_structural_equal(rewritten, expected2)
+
+    # No rewriting, return the original call node as is
+    def rewriter(orig, _):
+        return orig
+
+    rewritten = rewrite(pattern, rewriter, main)
+    tvm.ir.assert_structural_equal(rewritten, main)
+
+
+def test_rewrite_attention():
+    @R.function
+    def main(
+        Q: R.Tensor((2, 4096, 8, 40), "float32"),
+        K: R.Tensor((2, 4096, 8, 40), "float32"),
+        V: R.Tensor((2, 4096, 8, 40), "float32"),
+    ) -> R.Tensor((2, 4096, 8, 40), "float32"):
+        with R.dataflow():
+            lv58 = R.permute_dims(Q, axes=[0, 2, 1, 3])
+            lv59 = R.reshape(lv58, R.shape([16, 4096, 40]))
+
+            lv61 = R.permute_dims(K, axes=[0, 2, 1, 3])
+            lv62 = R.reshape(lv61, R.shape([16, 4096, 40]))
+
+            lv64 = R.permute_dims(V, axes=[0, 2, 1, 3])
+            lv65 = R.reshape(lv64, R.shape([16, 4096, 40]))
+
+            lv62_transposed = R.permute_dims(lv62, axes=[0, 2, 1])
+            lv3_1 = R.matmul(lv59, lv62_transposed)
+            lv68 = R.multiply(lv3_1, R.const(0.15811388194561005, "float32"))
+            lv69 = R.nn.softmax(lv68, axis=-1)
+            lv_3 = R.matmul(lv69, lv65)
+
+            lv71 = R.reshape(lv_3, R.shape([2, 8, 4096, 40]))
+            lv72 = R.permute_dims(lv71, axes=[0, 2, 1, 3])
+            R.output(lv72)
+
+        return lv72
+
+    @R.function
+    def expected(
+        Q: R.Tensor((2, 4096, 8, 40), dtype="float32"),
+        K: R.Tensor((2, 4096, 8, 40), dtype="float32"),
+        V: R.Tensor((2, 4096, 8, 40), dtype="float32"),
+    ) -> R.Tensor((2, 4096, 8, 40), dtype="float32"):
+        with R.dataflow():
+            lv72: R.Tensor((2, 4096, 8, 40), dtype="float32") = R.nn.attention(Q, V, K)
+            R.output(lv72)
+        return lv72
+
+    def BSNH_to_BSH(tensor):
+        return is_op("relax.reshape")(is_op("relax.permute_dims")(tensor), wildcard())
+
+    def BSH_to_BSNH(tensor):
+        return is_op("relax.permute_dims")(is_op("relax.reshape")(tensor, wildcard()))
+
+    Q = wildcard()
+    K = wildcard()
+    V = wildcard()
+
+    Q_3D = BSNH_to_BSH(Q)
+    V_3D = BSNH_to_BSH(V)
+    K_3D = BSNH_to_BSH(K)
+
+    matmul1 = is_op("relax.matmul")(Q_3D, is_op("relax.permute_dims")(V_3D))
+    multiply = is_op("relax.multiply")(matmul1, is_const())
+    softmax = is_op("relax.nn.softmax")(multiply)
+    matmul2 = is_op("relax.matmul")(softmax, K_3D)
+
+    pattern = BSH_to_BSNH(matmul2)
+
+    def rewriter(_, matchings):
+        return R.nn.attention(matchings[Q], matchings[K], matchings[V])
+
+    rewritten = rewrite(pattern, rewriter, main)
+    tvm.ir.assert_structural_equal(rewritten, expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
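One numeric aside on test_rewrite_attention above (editorial, not from the diff): the constant matched by is_const() is the usual 1 / sqrt(d_head) softmax-attention scale, with d_head = 40 in this test; the long decimal is simply that value round-tripped through float32:

```
import math

# 1 / sqrt(40) = 0.15811388300841897; the nearest float32 prints back as
# 0.15811388194561005, the constant appearing in the test above.
assert abs(1 / math.sqrt(40) - 0.15811388194561005) < 1e-8
```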