This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
     new 92593cd11f Add pattern-based dataflow block rewriting
92593cd11f is described below

commit 92593cd11f74b0e8e872926907d634978148dda5
Author: Masahiro Masuda <masahi...@gmail.com>
AuthorDate: Sat Apr 1 04:13:15 2023 +0900

    Add pattern-based dataflow block rewriting
---
 python/tvm/relax/dpl/pattern.py             |  34 ++++++-
 src/relax/ir/dataflow_matcher.cc            | 113 ++++++++++++++++++--
 tests/python/relax/test_dataflow_pattern.py | 153 +++++++++++++++++++++++++++-
 3 files changed, 285 insertions(+), 15 deletions(-)

diff --git a/python/tvm/relax/dpl/pattern.py b/python/tvm/relax/dpl/pattern.py
index acabac2dcb..3026213ba2 100644
--- a/python/tvm/relax/dpl/pattern.py
+++ b/python/tvm/relax/dpl/pattern.py
@@ -1125,7 +1125,7 @@ def make_fused_bias_activation_pattern(op_name, with_bias=False, activation=None
     return out
 
 
-def rewrite(
+def rewrite_call(
    pattern: DFPattern, rewriter: Callable[[Expr, Dict[DFPattern, Expr]], Expr], func: Function
 ) -> Function:
     """
@@ -1158,4 +1158,34 @@ def rewrite(
     rewritten_func: Function
         The rewritten or the input function, depending on the pattern matching result.
     """
-    return ffi.rewrite(pattern, rewriter, func)
+    return ffi.rewrite_call(pattern, rewriter, func)
+
+
+def rewrite_bindings(
+    ctx, rewriter: Callable[[Dict[DFPattern, Var]], Dict[Var, Expr]], func: Function
+) -> Function:
+    """
+    Rewrite a function with the given pattern context and the rewriter function.
+    Parameters
+    ----------
+    ctx: PatternContext
+        The pattern constraint context under which rewriting takes place.
+    rewriter: Callable[[Dict[DFPattern, Var]], Dict[Var, Expr]]
+        The function to be called on a successful matching for rewriting. Given the map of
+        patterns and corresponding variables (bound variables or parameters), it should
+        return a map that specifies new values for matched bound variables.
+        For example, to replace x + x with 2 * x, we can write the rewriter as follows:
+        ```
+        with PatternContext() as ctx:
+            x = wildcard()
+            pattern = is_op("relax.add")(x, x)
+
+        def rewriter(matchings):
+            return {matchings[pattern]: R.multiply(matchings[x], R.const(2, "float32"))}
+        ```
+    func: Function
+        The function to rewrite.
+    Returns
+    -------
+    rewritten_func: Function
+        The rewritten or the input function, depending on the pattern matching result.
+    """
+    return ffi.rewrite_bindings(ctx, rewriter, func)
diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc
index 88381d6e26..0055929a78 100644
--- a/src/relax/ir/dataflow_matcher.cc
+++ b/src/relax/ir/dataflow_matcher.cc
@@ -780,18 +780,29 @@ Optional<Map<DFPattern, Var>> MatchGraph(const PatternContext& ctx, const Datafl
 TVM_REGISTER_GLOBAL("relax.dpl.match_dfb").set_body_typed(MatchGraph);
 
 /*!
- * \brief Apply pattern matching to each call node and replace matching ones with the output of
- * a user-provided rewriter function.
+ * \brief Apply pattern matching to each call node and dataflow block, and replace matching ones
+ * with the output of a user-provided rewriter function.
  */
 class PatternRewriter : ExprMutator {
  public:
+  using ExprMutator::VisitBindingBlock_;
   using ExprMutator::VisitExpr_;
 
-  PatternRewriter(DFPattern pat, PackedFunc rewriter_func)
-      : pattern_(pat), rewriter_func_(rewriter_func) {}
+  PatternRewriter(DFPattern pat, PackedFunc rewriter_func,
+                  const std::unordered_set<const VarNode*>& params)
+      : pattern_(pat), rewriter_func_(rewriter_func), params_(params) {}
 
-  static Expr Run(DFPattern pat, PackedFunc rewriter_func, Function f) {
-    PatternRewriter rewriter(pat, rewriter_func);
+  PatternRewriter(const PatternContext& ctx, PackedFunc rewriter_func,
+                  const std::unordered_set<const VarNode*>& params)
+      : ctx_(ctx), rewriter_func_(rewriter_func), params_(params) {}
+
+  template <typename PatternType>
+  static Expr Run(PatternType pat, PackedFunc rewriter_func, Function f) {
+    std::unordered_set<const VarNode*> params;
+    for (const auto& p : f->params) {
+      params.insert(p.get());
+    }
+    PatternRewriter rewriter(pat, rewriter_func, params);
     return RemoveAllUnused(Downcast<Function>(rewriter.VisitExpr(f)));
   }
 
@@ -807,7 +818,9 @@ class PatternRewriter : ExprMutator {
   Expr VisitExpr_(const CallNode* call_node) final {
     auto call = ExprMutator::VisitExpr_(call_node);
 
-    if (auto matches_opt = ExtractMatchedExpr(pattern_, call, bindings_)) {
+    if (!pattern_) {
+      return call;
+    } else if (auto matches_opt = ExtractMatchedExpr(pattern_.value(), call, bindings_)) {
       auto rewriten_expr = rewriter_func_(call, matches_opt.value());
       memo_[call_node] = rewriten_expr;
       return rewriten_expr;
@@ -815,17 +828,99 @@ class PatternRewriter : ExprMutator {
     return call;
   }
 
+  BindingBlock VisitBindingBlock_(const DataflowBlockNode* block_node) final {
+    if (!ctx_) {
+      return ExprMutator::VisitBindingBlock_(block_node);
+    }
+    return RewriteDataflowBlockFixedPoint(GetRef<DataflowBlock>(block_node));
+  }
+
  private:
-  DFPattern pattern_;
+  void EmitUsedVars(Expr val, const Array<Binding>& pending_bindings,
+                    std::unordered_set<const VarNode*>* emitted_vars) {
+    std::unordered_set<const VarNode*> unemitted_vars;
+    PostOrderVisit(val, [=, &unemitted_vars](Expr e) {
+      if (auto v = e.as<VarNode>(); v && !emitted_vars->count(v)) {
+        unemitted_vars.insert(v);
+      }
+    });
+
+    if (unemitted_vars.empty()) {
+      return;
+    }
+
+    size_t num_unemitted = unemitted_vars.size();
+    for (const auto& binding : pending_bindings) {
+      if (auto var_bind = binding.as<VarBindingNode>();
+          var_bind && unemitted_vars.count(var_bind->var.get())) {
+        EmitUsedVars(var_bind->value, pending_bindings, emitted_vars);
+        this->VisitBinding(binding);
+        emitted_vars->insert(var_bind->var.get());
+        if (--num_unemitted == 0) {
+          return;
+        }
+      }
+    }
+  }
+
+  // Repeat until all matchable subsets of bindings are rewritten.
+  BindingBlock RewriteDataflowBlockFixedPoint(BindingBlock block) {
+    if (auto matches = MatchGraph(ctx_.value(), Downcast<DataflowBlock>(block))) {
+      builder_->BeginDataflowBlock();
+      Map<Var, Expr> replacements = rewriter_func_(matches.value());
+
+      std::unordered_set<const VarNode*> emitted_vars;
+
+      for (size_t i = 0; i < block->bindings.size(); ++i) {
+        const auto& binding = block->bindings[i];
+        if (auto var_bind = binding.as<VarBindingNode>()) {
+          if (replacements.count(var_bind->var)) {
+            auto new_val = replacements[var_bind->var];
+            Array<Binding> pending_bindings(block->bindings.begin() + i + 1, block->bindings.end());
+            // Make sure there is no unbound variable used in the new value before it is emitted
+            EmitUsedVars(new_val, pending_bindings, &emitted_vars);
+            this->ReEmitBinding(var_bind, builder_->Normalize(new_val));
+          } else if (!emitted_vars.count(var_bind->var.get())) {
+            this->VisitBinding(binding);
+            emitted_vars.insert(var_bind->var.get());
+          }
+        } else {
+          this->VisitBinding(binding);
+        }
+      }
+      return RewriteDataflowBlockFixedPoint(builder_->EndBlock());
+    }
+    return block;
+  }
+
+  /*! \brief The pattern for rewriting call nodes */
+  Optional<DFPattern> pattern_;
+  /*! \brief The pattern constraint contexts for rewriting dataflow blocks */
+  Optional<PatternContext> ctx_;
+  /*!
+   * \brief The user-provided rewriter function. Its signature and semantics are:
+   * - (Call, Map<DFPattern, Expr>) -> Call for call node rewriting. Given the matched
+   *   call node and the map of patterns and matched expressions, it should return a new call node
+   *   to replace the original one or the original matched call node as is.
+   * - Map<DFPattern, Var> -> Map<Var, Expr> for dataflow block rewriting. Given the map of
+   *   patterns and corresponding variables (bound variables or parameters), it should return
+   *   a map that specifies new values for matched bound variables.
+   */
   PackedFunc rewriter_func_;
+  std::unordered_set<const VarNode*> params_;
   Map<Var, Expr> bindings_;
   std::unordered_map<const Object*, Expr> memo_;
 };
 
-TVM_REGISTER_GLOBAL("relax.dpl.rewrite")
+TVM_REGISTER_GLOBAL("relax.dpl.rewrite_call")
     .set_body_typed([](DFPattern pat, PackedFunc rewriter, Function f) {
       return PatternRewriter::Run(pat, rewriter, f);
     });
 
+TVM_REGISTER_GLOBAL("relax.dpl.rewrite_bindings")
+    .set_body_typed([](const PatternContext& ctx, PackedFunc rewriter, Function f) {
+      return PatternRewriter::Run(ctx, rewriter, f);
+    });
+
 }  // namespace relax
 }  // namespace tvm
diff --git a/tests/python/relax/test_dataflow_pattern.py b/tests/python/relax/test_dataflow_pattern.py
index f18244096e..e4d7f7972c 100644
--- a/tests/python/relax/test_dataflow_pattern.py
+++ b/tests/python/relax/test_dataflow_pattern.py
@@ -918,7 +918,7 @@ def test_rewrite_simple():
     def rewriter(_, matchings):
         return R.multiply(matchings[x], R.const(2, "float32"))
 
-    rewritten = rewrite(pattern, rewriter, main)
+    rewritten = rewrite_call(pattern, rewriter, main)
     tvm.ir.assert_structural_equal(rewritten, expected1)
 
     add1 = is_op("relax.add")(x, x)
@@ -927,14 +927,14 @@ def test_rewrite_simple():
     def rewriter(_, matchings):
         return R.multiply(matchings[x], R.const(4, "float32"))
 
-    rewritten = rewrite(pattern, rewriter, main)
+    rewritten = rewrite_call(pattern, rewriter, main)
     tvm.ir.assert_structural_equal(rewritten, expected2)
 
     # No rewriting, return the original call node as is
     def rewriter(orig, _):
         return orig
 
-    rewritten = rewrite(pattern, rewriter, main)
+    rewritten = rewrite_call(pattern, rewriter, main)
     tvm.ir.assert_structural_equal(rewritten, main)
 
@@ -1002,7 +1002,7 @@ def test_rewrite_attention():
     def rewriter(_, matchings):
         return R.nn.attention(matchings[Q], matchings[K], matchings[V])
 
-    rewritten = rewrite(pattern, rewriter, main)
+    rewritten = rewrite_call(pattern, rewriter, main)
     tvm.ir.assert_structural_equal(rewritten, expected)
 
@@ -1075,5 +1075,150 @@ def test_attention_fake_qkv():
     assert ctx.match_dfb(dfb) is None
 
 
+def get_qkv_proj_rewriter(
+    inp_pat, Q_weight_pat, K_weight_pat, V_weight_pat, matmul1, matmul2, matmul3
+):
+    def qkv_proj_rewriter(matchings):
+        inp = matchings[inp_pat]
+        Q_weight = matchings[Q_weight_pat]
+        K_weight = matchings[K_weight_pat]
+        V_weight = matchings[V_weight_pat]
+        width = Q_weight.struct_info.shape[1]
+
+        concat = R.concat([Q_weight, K_weight, V_weight], axis=1)
+        matmul = R.matmul(inp, concat)
+        Q = R.strided_slice(matmul, axes=[2], begin=[0], end=[width])
+        K = R.strided_slice(matmul, axes=[2], begin=[width], end=[width * 2])
+        V = R.strided_slice(matmul, axes=[2], begin=[width * 2], end=[width * 3])
+
+        return {matchings[matmul1]: Q, matchings[matmul2]: K, matchings[matmul3]: V}
+
+    return qkv_proj_rewriter
+
+
+def test_combine_matmul_twice():
+    @R.function
+    def qkv_x2(
+        x1: R.Tensor((2, 1024, 640), "float32"),
+        x2: R.Tensor((2, 1024, 640), "float32"),
+        w0: R.Tensor((640, 640), "float32"),
+        w1: R.Tensor((640, 640), "float32"),
+        w2: R.Tensor((640, 640), "float32"),
+        w3: R.Tensor((640, 640), "float32"),
+        w4: R.Tensor((640, 640), "float32"),
+        w5: R.Tensor((640, 640), "float32"),
+    ) -> R.Tensor:
+        with R.dataflow():
+            lv0 = R.matmul(x1, w0)
+            lv1 = R.matmul(x1, w1)
+            lv2 = R.matmul(x1, w2)
+            lv3 = R.matmul(x2, w3)
+            lv4 = R.matmul(x2, w4)
+            lv5 = R.matmul(x2, w5)
+            out = (lv0, lv1, lv2, lv3, lv4, lv5)
+            R.output(out)
+        return out
+
+    @R.function
+    def expected(
+        x1: R.Tensor((2, 1024, 640), "float32"),
+        x2: R.Tensor((2, 1024, 640), "float32"),
+        w0: R.Tensor((640, 640), "float32"),
+        w1: R.Tensor((640, 640), "float32"),
+        w2: R.Tensor((640, 640), "float32"),
+        w3: R.Tensor((640, 640), "float32"),
+        w4: R.Tensor((640, 640), "float32"),
+        w5: R.Tensor((640, 640), "float32"),
+    ) -> R.Tensor:
+        with R.dataflow():
+            lv = R.concat((w0, w1, w2), axis=1)
+            lv1 = R.matmul(x1, lv)
+            lv0 = R.strided_slice(lv1, axes=[2], begin=[0], end=[640])
+            lv1_1 = R.strided_slice(lv1, axes=[2], begin=[640], end=[1280])
+            lv2 = R.strided_slice(lv1, axes=[2], begin=[1280], end=[1920])
+            lv2_1 = R.concat((w3, w4, w5), axis=1)
+            lv3 = R.matmul(x2, lv2_1, out_dtype="void")
+            lv3_1 = R.strided_slice(lv3, axes=[2], begin=[0], end=[640])
+            lv4 = R.strided_slice(lv3, axes=[2], begin=[640], end=[1280])
+            lv5 = R.strided_slice(lv3, axes=[2], begin=[1280], end=[1920])
+            out = lv0, lv1_1, lv2, lv3_1, lv4, lv5
+            R.output(out)
+        return out
+
+    with PatternContext() as ctx:
+        inp_pat = wildcard()
+        Q_weight_pat = wildcard()
+        K_weight_pat = wildcard()
+        V_weight_pat = wildcard()
+
+        matmul1 = is_op("relax.matmul")(inp_pat, Q_weight_pat)
+        matmul2 = is_op("relax.matmul")(inp_pat, K_weight_pat)
+        matmul3 = is_op("relax.matmul")(inp_pat, V_weight_pat)
+
+        rewriter = get_qkv_proj_rewriter(
+            inp_pat, Q_weight_pat, K_weight_pat, V_weight_pat, matmul1, matmul2, matmul3
+        )
+        rewritten = rewrite_bindings(ctx, rewriter, qkv_x2)
+        tvm.ir.assert_structural_equal(rewritten, expected)
+
+
+def test_combine_matmul_emit_order():
+    @R.function
+    def main(
+        x1: R.Tensor((2, 1024, 640), "float32"),
+        w0: R.Tensor((640, 640), "float32"),
+        w1: R.Tensor((640, 640), "float32"),
+        w2: R.Tensor((640, 640), "float32"),
+    ) -> R.Tensor:
+        with R.dataflow():
+            w0_t = R.permute_dims(w0, axes=None)
+            lv0 = R.matmul(x1, w0_t)
+            w1_t = R.permute_dims(w1, axes=None)
+            w1_t_t = R.permute_dims(w1_t, axes=None)
+            lv1 = R.matmul(x1, w1_t_t)
+            w2_t = R.permute_dims(w2, axes=None)
+            lv2 = R.matmul(x1, w2_t)
+            out = (lv0, lv1, lv2)
+            R.output(out)
+        return out
+
+    @R.function
+    def expected(
+        x1: R.Tensor((2, 1024, 640), dtype="float32"),
+        w0: R.Tensor((640, 640), dtype="float32"),
+        w1: R.Tensor((640, 640), dtype="float32"),
+        w2: R.Tensor((640, 640), dtype="float32"),
+    ) -> R.Tensor:
+        with R.dataflow():
+            w0_t = R.permute_dims(w0, axes=None)
+            w1_t = R.permute_dims(w1, axes=None)
+            w1_t_t = R.permute_dims(w1_t, axes=None)
+            w2_t = R.permute_dims(w2, axes=None)
+            lv = R.concat((w0_t, w1_t_t, w2_t), axis=1)
+            lv1 = R.matmul(x1, lv, out_dtype="void")
+            lv0 = R.strided_slice(lv1, axes=[2], begin=[0], end=[640])
+            lv1_1 = R.strided_slice(lv1, axes=[2], begin=[640], end=[1280])
+            lv2 = R.strided_slice(lv1, axes=[2], begin=[1280], end=[1920])
+            out = lv0, lv1_1, lv2
+            R.output(out)
+        return out
+
+    with PatternContext() as ctx:
+        inp_pat = wildcard()
+        Q_weight_pat = wildcard()
+        K_weight_pat = wildcard()
+        V_weight_pat = wildcard()
+
+        matmul1 = is_op("relax.matmul")(inp_pat, Q_weight_pat)
+        matmul2 = is_op("relax.matmul")(inp_pat, K_weight_pat)
+        matmul3 = is_op("relax.matmul")(inp_pat, V_weight_pat)
+
+        rewriter = get_qkv_proj_rewriter(
+            inp_pat, Q_weight_pat, K_weight_pat, V_weight_pat, matmul1, matmul2, matmul3
+        )
+        rewritten = rewrite_bindings(ctx, rewriter, main)
+        tvm.ir.assert_structural_equal(rewritten, expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
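
For readers skimming the patch, here is a minimal standalone sketch of the renamed call-node entry point, condensed from test_rewrite_simple above. It is not part of the commit; the toy function and the exact import paths are assumptions.

```python
# Sketch of call-node rewriting with rewrite_call (adapted from the tests above).
import tvm
from tvm.script import relax as R
from tvm.relax.dpl import is_op, wildcard, rewrite_call


@R.function
def before(a: R.Tensor((16, 16), "float32")) -> R.Tensor:
    with R.dataflow():
        lv = R.add(a, a)
        R.output(lv)
    return lv


x = wildcard()
pattern = is_op("relax.add")(x, x)


# Given the matched call and the pattern -> expression map, return the
# replacement call (or return the original call to leave it unchanged).
def rewriter(orig, matchings):
    return R.multiply(matchings[x], R.const(2, "float32"))


rewritten = rewrite_call(pattern, rewriter, before)
```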
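The new rewrite_bindings entry point pairs a PatternContext with a rewriter that maps matched bound variables to new values, which is what lets it fuse several bindings of a dataflow block at once. Below is a sketch condensed from test_combine_matmul_twice; the two-matmul toy function and import paths are assumptions, not part of the commit.

```python
# Sketch of dataflow-block rewriting with rewrite_bindings (following the qkv tests).
import tvm
from tvm.script import relax as R
from tvm.relax.dpl import PatternContext, is_op, wildcard, rewrite_bindings


@R.function
def two_matmul(
    x: R.Tensor((2, 1024, 640), "float32"),
    w0: R.Tensor((640, 640), "float32"),
    w1: R.Tensor((640, 640), "float32"),
) -> R.Tensor:
    with R.dataflow():
        lv0 = R.matmul(x, w0)
        lv1 = R.matmul(x, w1)
        out = (lv0, lv1)
        R.output(out)
    return out


# Patterns registered under a PatternContext constrain bindings within a
# dataflow block rather than a single call node.
with PatternContext() as ctx:
    inp_pat = wildcard()
    w0_pat = wildcard()
    w1_pat = wildcard()
    matmul0 = is_op("relax.matmul")(inp_pat, w0_pat)
    matmul1 = is_op("relax.matmul")(inp_pat, w1_pat)


# Given the pattern -> matched-variable map, return new values for the
# matched bound variables; here the two matmuls are fused into one.
def rewriter(matchings):
    inp = matchings[inp_pat]
    w = R.concat([matchings[w0_pat], matchings[w1_pat]], axis=1)
    fused = R.matmul(inp, w)
    lhs = R.strided_slice(fused, axes=[2], begin=[0], end=[640])
    rhs = R.strided_slice(fused, axes=[2], begin=[640], end=[1280])
    return {matchings[matmul0]: lhs, matchings[matmul1]: rhs}


rewritten = rewrite_bindings(ctx, rewriter, two_matmul)
```

As the rewriter's EmitUsedVars/fixed-point logic above shows, replacement values may reference bindings that appear later in the original block; the pass re-emits any such dependencies before the rewritten binding and repeats until no pattern matches remain.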