This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
     new 5513095b5a  [Unity][BYOC] Improve expressiveness of the pattern check function in FuseOpsByPattern (#14310)
5513095b5a is described below

commit 5513095b5a28481ea4aea4979177b16bdc78dac6
Author: Lite Ye <yelite...@gmail.com>
AuthorDate: Fri Mar 17 19:07:19 2023 -0400

    [Unity][BYOC] Improve expressiveness of the pattern check function in FuseOpsByPattern (#14310)

    * Change the input of FuseOpsByPattern and add check for result dependency in cutlass conv2d residual block

    * Rename FuseOpsPattern to FusionPattern and PatternCheckFunctionInput to PatternCheckContext
---
 include/tvm/relax/transform.h                      | 103 ++++++++++++++++--
 python/tvm/contrib/cutlass/build.py                |   8 +-
 python/tvm/relax/backend/contrib/cutlass.py        |  80 +++++++-------
 python/tvm/relax/backend/pattern_registry.py       |  75 ++-----------
 python/tvm/relax/backend/patterns.py               |  49 +++++----
 python/tvm/relax/transform/transform.py            | 118 ++++++++++++++++-----
 src/relax/backend/pattern_registry.cc              |  39 ++-----
 src/relax/backend/pattern_registry.h               |  59 +----------
 src/relax/transform/fuse_ops.cc                    |  87 ++++++++++++---
 tests/python/relax/test_codegen_cutlass.py         |  39 ++++++-
 .../relax/test_transform_fuse_ops_by_pattern.py    |  30 ++++--
 11 files changed, 414 insertions(+), 273 deletions(-)

diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h
index 3ff863dd09..e0fe226e83 100644
--- a/include/tvm/relax/transform.h
+++ b/include/tvm/relax/transform.h
@@ -215,17 +215,108 @@ TVM_DLL Pass AnnotateTIROpPattern();
  */
 TVM_DLL Pass FuseOps(int fuse_opt_level = -1);
 
+/*!
+ * \brief The pattern object used as the input of FuseOpsByPattern. For bindings to be
+ * fused, it needs to be matched with `pattern` and the `check` function needs to return
+ * true.
+ */
+class FusionPatternNode : public Object {
+ public:
+  /*!
+   * \brief The name of pattern. It becomes the value of the kComposite attribute
+   * of a fused function after successful matching
+   */
+  String name;
+
+  /*!
+   * \brief The dataflow pattern that will be used to match expression in the DataflowBlock.
+   * All the call nodes covered by the pattern will be extracted into the fused function.
+   */
+  DFPattern pattern;
+
+  /*!
+   * \brief The map which is used to extract important expressions from the pattern match
+   * result. All DFPattern in this map should be part of the `pattern`.
+   */
+  Map<String, DFPattern> annotation_patterns;
+
+  /*!
+   * \brief The function to determine whether the match result is accepted. This can be
+   * NullOpt if check function is not necessary for this pattern.
+   *
+   * It should have signature
+   * bool(const PatternCheckContext& context)
+   */
+  Optional<PackedFunc> check;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("name", &name);
+    v->Visit("pattern", &pattern);
+    v->Visit("annotation_patterns", &annotation_patterns);
+    v->Visit("check", &check);
+  }
+
+  static constexpr const char* _type_key = "relax.transform.FusionPattern";
+  TVM_DECLARE_FINAL_OBJECT_INFO(FusionPatternNode, Object);
+};
+
+class FusionPattern : public ObjectRef {
+ public:
+  FusionPattern(String name, DFPattern pattern, Map<String, DFPattern> annotation_patterns,
+                Optional<PackedFunc> check);
+
+  FusionPattern(String name, DFPattern pattern) : FusionPattern(name, pattern, {}, NullOpt) {}
+
+  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(FusionPattern, ObjectRef, FusionPatternNode);
+};
+
+/*!
+ * \brief The input of FusionPattern::check.
+ */
+class PatternCheckContextNode : public Object {
+ public:
+  /*!
+   * \brief A map which contains all expressions matched by the sub patterns in
+   * FusionPattern::annotation_patterns.
+   */
+  Map<String, Expr> annotated_expr;
+
+  /*!
+   * \brief A map mapping variable definitions to a set of uses.
+   */
+  Map<Var, Array<Var>> var_usages;
+
+  /*!
+   * \brief Map from value to its bound variable.
+   */
+  Map<Expr, Var> value_to_bound_var;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("annotated_expr", &annotated_expr);
+    v->Visit("var_usages", &var_usages);
+    v->Visit("value_to_bound_var", &value_to_bound_var);
+  }
+
+  static constexpr const char* _type_key = "relax.transform.PatternCheckContext";
+  TVM_DECLARE_FINAL_OBJECT_INFO(PatternCheckContextNode, Object);
+};
+
+class PatternCheckContext : public ObjectRef {
+ public:
+  PatternCheckContext(Map<String, Expr> annotated_expr, Map<Var, Array<Var>> var_usages,
+                      Map<Expr, Var> value_to_bound_var);
+
+  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(PatternCheckContext, ObjectRef,
+                                            PatternCheckContextNode);
+};
+
 /*!
  * \brief Apply pattern matching to each function in the given module, and group matched
  * expressions into a new function. The end result is similar to FuseOps, but fusion is driven
  * completely by the provided patterns.
  *
- * \param pattern_names The name of each pattern. It becomes the value of the kComposite attribute
- * of a fused function after successful matching.
  * \param patterns The patterns to detect. The order of the patterns determines the order
  * of priority in which they are matched. Higher-priority patterns should come earlier in the list.
- * \param checks The callback functions with type (Map<DFPattern, Expr>, Expr) -> bool. It takes a
- * match result and returns a boolean value to indicate whether the match result is accepted.
  * \param bind_constants Whether or not to keep bound constants of the grouped function.
  * \param annotate_codegen If true, wrap each created composite function with another function,
  * whose body consists only of a call to the composite function, and annotate the outer function
@@ -235,9 +326,7 @@ TVM_DLL Pass FuseOps(int fuse_opt_level = -1);
  * an external backend without using the MergeCompositeFunctions pass.
  * \return The Pass.
  */
-TVM_DLL Pass FuseOpsByPattern(const tvm::Array<runtime::String>& pattern_names,
-                              const tvm::Array<DFPattern>& patterns,
-                              const tvm::Array<PackedFunc>& checks, bool bind_constants = true,
+TVM_DLL Pass FuseOpsByPattern(const tvm::Array<FusionPattern>& patterns, bool bind_constants = true,
                               bool annotate_codegen = false);
 
 /*!
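The header above defines the C++ side of the new interface; the matching Python classes appear later in this patch under python/tvm/relax/transform/transform.py. As a rough sketch (not part of this commit), the reworked pass could be driven from Python roughly as follows, assuming an existing Relax IRModule `mod`; the pattern name and the fp16 check are purely illustrative:

    from tvm import relax
    from tvm.relax.dpl.pattern import is_op, wildcard
    from tvm.relax.transform import FusionPattern, PatternCheckContext

    # Build a dataflow pattern and keep handles to the sub-patterns of interest.
    lhs, rhs = wildcard(), wildcard()
    out = is_op("relax.matmul")(lhs, rhs)
    annotations = {"lhs": lhs, "rhs": rhs, "root": out}

    def check(context: PatternCheckContext) -> bool:
        # annotated_expr gives direct access to the expressions matched by the
        # named sub-patterns; reject the match unless lhs is fp16.
        return context.annotated_expr["lhs"].struct_info.dtype == "float16"

    pattern = FusionPattern("example.matmul_fp16", out, annotations, check)
    mod = relax.transform.FuseOpsByPattern([pattern], bind_constants=False)(mod)
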
diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py index 7e92e6a887..47bdcaa790 100644 --- a/python/tvm/contrib/cutlass/build.py +++ b/python/tvm/contrib/cutlass/build.py @@ -568,11 +568,11 @@ def _extract_arg_idx(pattern_name, f): func_args = list(f.params) arg_idx = {} - for arg_name, arg_pattern in pattern_entry.arg_patterns.items(): - arg_expr = matched_expr[arg_pattern] + for name, annotation_pattern in pattern_entry.annotation_patterns.items(): + arg_expr = matched_expr[annotation_pattern] if arg_expr not in func_args: - raise ValueError(f"Cannot find arg {arg_name} in the fused function parameters") - arg_idx[arg_name] = func_args.index(arg_expr) + continue + arg_idx[name] = func_args.index(arg_expr) return arg_idx diff --git a/python/tvm/relax/backend/contrib/cutlass.py b/python/tvm/relax/backend/contrib/cutlass.py index e1b9226d68..4d539928cf 100644 --- a/python/tvm/relax/backend/contrib/cutlass.py +++ b/python/tvm/relax/backend/contrib/cutlass.py @@ -17,12 +17,12 @@ """Pattern table for CUTLASS backend""" -from typing import Mapping, Optional, Tuple +from typing import Mapping, Optional, Sequence, Tuple import tvm from tvm.contrib.cutlass.build import is_shape_valid_for_cutlass_matmul -from tvm.relax import Call, Expr, ShapeExpr, transform -from tvm.relax.dpl import CallPattern, DFPattern +from tvm.relax import ShapeExpr, Var, transform +from tvm.relax.transform import PatternCheckContext from ..pattern_registry import get_patterns_with_prefix, register_patterns from ..patterns import ( @@ -52,33 +52,27 @@ def _is_supported_dtype(lhs_dtype, rhs_dtype): ) -def _find_call(op_name: str, match_result: Mapping[DFPattern, Expr]) -> Optional[Expr]: - result = None +def _has_dependency(from_var: Var, to_var: Var, var_usages: Mapping[Var, Sequence[Var]]): + if from_var == to_var: + return True - for pattern, expr in match_result.items(): - if ( - isinstance(expr, Call) - and isinstance(pattern, CallPattern) - and isinstance(expr.op, tvm.ir.Op) - and expr.op.name == op_name - ): - if result is not None: - raise ValueError(f"Found multiple matched call node for {op_name}") - result = expr + checked = set() + vars_to_check = [to_var] + while vars_to_check: + current_var = vars_to_check.pop() + for user in var_usages.get(current_var, []): + if user == from_var: + return True + if user not in checked: + checked.add(user) + vars_to_check.append(user) - return result + return False -def _check_conv2d( - match_result: Mapping[DFPattern, Expr], - _: Expr, -): +def _check_conv2d(context: PatternCheckContext) -> bool: """Check if the given conv2d workload can be offloaded to CUTLASS.""" - - conv2d_call = _find_call("relax.nn.conv2d", match_result) - if conv2d_call is None: - return False - + conv2d_call = context.annotated_expr["root"] data_layout = conv2d_call.attrs.data_layout kernel_layout = conv2d_call.attrs.kernel_layout data, weight, *_ = conv2d_call.args @@ -89,6 +83,15 @@ def _check_conv2d( ): return False + if "residual" in context.annotated_expr: + residual = context.annotated_expr["residual"] + if not isinstance(residual, Var): + residual = context.value_to_bound_var[residual] + conv2d_var = context.value_to_bound_var[conv2d_call] + if _has_dependency(from_var=residual, to_var=conv2d_var, var_usages=context.var_usages): + # If residual depends on the result of conv2d, this cannot be handled by cutlass. 
+ return False + # pylint: disable=invalid-name IC = data.struct_info.shape.values[3] OC = weight.struct_info.shape.values[0] @@ -96,17 +99,10 @@ def _check_conv2d( return not IC == OC == conv2d_call.attrs.groups -def _check_matmul( - match_result: Mapping[DFPattern, Expr], - _: Expr, -) -> bool: +def _check_matmul(context: PatternCheckContext) -> bool: """Check if the given matmul workload can be offloaded to CUTLASS.""" - - matmul_call: Call = _find_call("relax.matmul", match_result) - if matmul_call is None: - return False - - lhs, rhs, *_ = matmul_call.args + lhs = context.annotated_expr["lhs"] + rhs = context.annotated_expr["rhs"] lhs_dtype = lhs.struct_info.dtype rhs_dtype = rhs.struct_info.dtype @@ -244,7 +240,7 @@ register_patterns( ) -def partition_for_cutlass(mod): +def partition_for_cutlass(mod, annotate_codegen=True): """ Partition the input module into CUTLASS-supported subgraphs. @@ -253,6 +249,11 @@ def partition_for_cutlass(mod): mod: tvm.IRModule The IRModule to be partitioned. + annotate_codegen: bool + Whether to wrap each created composite function with another function, whose + body consists only of a call to the composite function. See the doc of FuseOpsByPattern + for more detail. + Returns ------- mod: tvm.IRModule @@ -260,6 +261,7 @@ def partition_for_cutlass(mod): compiled by the CUTLASS backend. """ - cutlass_pattern_entries = get_patterns_with_prefix("cutlass") - patterns = [(e.name, e.pattern, e.check) for e in cutlass_pattern_entries] - return transform.FuseOpsByPattern(patterns, bind_constants=False, annotate_codegen=True)(mod) + patterns = get_patterns_with_prefix("cutlass") + return transform.FuseOpsByPattern( + patterns, bind_constants=False, annotate_codegen=annotate_codegen + )(mod) diff --git a/python/tvm/relax/backend/pattern_registry.py b/python/tvm/relax/backend/pattern_registry.py index 5a35eba03d..5ec57164eb 100644 --- a/python/tvm/relax/backend/pattern_registry.py +++ b/python/tvm/relax/backend/pattern_registry.py @@ -20,55 +20,12 @@ import atexit from typing import Callable, List, Mapping, Optional, Set, Tuple, Union -import tvm from tvm.relax.dpl import DFPattern -from tvm.runtime import Object +from tvm.relax.transform import FusionPattern from ..expr import Expr from . import _ffi_api - -@tvm._ffi.register_object("relax.backend.PatternRegistryEntry") -class PatternRegistryEntry(Object): - """ - An entry in the pattern registry. This represents a single pattern that - can be used to identify expressions that can be handled by external - backends, like CUTLASS and TensorRT. - - Parameters - ---------- - name: str - The name of pattern. Usually it starts with the name of backend, like 'cutlass.matmul'. - - pattern: DFPattern - The dataflow pattern that will be used to match expressions that can be handled - by external backends. - - arg_patterns: Mapping[str, DFPattern] - The mapping from arg name to its pattern. It can be used to extract arg expression - from match result. All DFPattern in this map should be part of the `pattern`. - - check: Callable[[Mapping[DFPattern, Expr], Expr], bool] - The function to check whether the match result is accepted. 
- """ - - name: str - pattern: DFPattern - arg_patterns: Mapping[str, DFPattern] - check: Callable[[Mapping[DFPattern, Expr], Expr], bool] - - def __init__( - self, - name: str, - pattern: DFPattern, - arg_patterns: Mapping[str, DFPattern], - check: Callable[[Mapping[DFPattern, Expr], Expr], bool], - ): - self.__init_handle_by_constructor__( - _ffi_api.PatternRegistryEntry, name, pattern, arg_patterns, check # type: ignore - ) - - _REGISTERED_PATTERN_NAMES: Set[str] = set() @@ -96,7 +53,7 @@ def _ensure_cleanup_function_registered(): CheckFunc = Callable[[Mapping[DFPattern, Expr], Expr], bool] Pattern = Union[ - PatternRegistryEntry, + FusionPattern, Tuple[str, DFPattern], Tuple[str, DFPattern, Mapping[str, DFPattern]], Tuple[str, DFPattern, Mapping[str, DFPattern], CheckFunc], @@ -118,29 +75,17 @@ def register_patterns(patterns: List[Pattern]): entries = [] for item in patterns: - if isinstance(item, PatternRegistryEntry): + if isinstance(item, FusionPattern): entries.append(item) elif isinstance(item, tuple): - name, pattern, *rest = item - - if len(rest) > 0: - arg_patterns = rest[0] - else: - arg_patterns = {} - - if len(rest) > 1: - check = rest[1] - else: - check = lambda *_: True - - entries.append(PatternRegistryEntry(name, pattern, arg_patterns, check)) - _REGISTERED_PATTERN_NAMES.add(name) + entries.append(FusionPattern(*item)) + _REGISTERED_PATTERN_NAMES.add(item[0]) else: - raise TypeError(f"Cannot register type {type(pattern)} as pattern") + raise TypeError(f"Cannot register type {type(item)} as pattern") _ffi_api.RegisterPatterns(entries) -def get_patterns_with_prefix(prefix: str) -> List[PatternRegistryEntry]: +def get_patterns_with_prefix(prefix: str) -> List[FusionPattern]: """ Get a list of patterns whose names startwith `prefix`. @@ -151,13 +96,13 @@ def get_patterns_with_prefix(prefix: str) -> List[PatternRegistryEntry]: Returns ------- - patterns: PatternRegistryEntry + patterns: FusionPattern Matched patterns, ordered by priority from high to low. """ return _ffi_api.GetPatternsWithPrefix(prefix) -def get_pattern(name: str) -> Optional[PatternRegistryEntry]: +def get_pattern(name: str) -> Optional[FusionPattern]: """ Find the pattern with a particular name. @@ -168,7 +113,7 @@ def get_pattern(name: str) -> Optional[PatternRegistryEntry]: Returns ------- - pattern: Optional[PatternRegistryEntry] + pattern: Optional[FusionPattern] The matched pattern. Returns None if such pattern is not found. """ return _ffi_api.GetPattern(name) diff --git a/python/tvm/relax/backend/patterns.py b/python/tvm/relax/backend/patterns.py index d770cc6faf..e27b91b3ea 100644 --- a/python/tvm/relax/backend/patterns.py +++ b/python/tvm/relax/backend/patterns.py @@ -24,18 +24,18 @@ from tvm.relax.dpl.pattern import DFPattern, is_op, wildcard def _with_bias_activation_pattern( out: DFPattern, - args: Dict[str, DFPattern], + annotations: Dict[str, DFPattern], with_bias: bool = False, activation: str = None, ) -> Tuple[DFPattern, Mapping[str, DFPattern]]: if with_bias: - args["bias"] = bias = wildcard() + annotations["bias"] = bias = wildcard() out = is_op("relax.add")(out, bias) if activation: out = is_op(activation)(out) - return out, args + return out, annotations def make_fused_bias_activation_pattern( @@ -62,16 +62,17 @@ def make_fused_bias_activation_pattern( pattern: DFPattern The resulting pattern describing a fused operation - args: Mapping[str, DFPattern] - The mapping from arg name to its pattern. It can be used to extract - arg expression from match result. 
+ annotations: Mapping[str, DFPattern] + A mapping from name to sub pattern. It can be used to extract + important expressions from match result, to power the partition + check function and codegen. """ lhs = wildcard() rhs = wildcard() - args = {"lhs": lhs, "rhs": rhs} out = is_op(op_name)(lhs, rhs) + annotations = {"lhs": lhs, "rhs": rhs, "root": out} - return _with_bias_activation_pattern(out, args, with_bias, activation) + return _with_bias_activation_pattern(out, annotations, with_bias, activation) def make_residual_block_pattern( @@ -99,9 +100,10 @@ def make_residual_block_pattern( pattern: DFPattern The resulting pattern describing a matrix multiplication. - args: Mapping[str, DFPattern] - The mapping from arg name to its pattern. It can be used to extract - arg expression from match result. + annotations: Mapping[str, DFPattern] + A mapping from name to sub pattern. It can be used to extract + important expressions from match result, to power the partition + check function and codegen. """ if isinstance(node_output, tuple): @@ -143,21 +145,23 @@ def make_matmul_pattern( pattern: DFPattern The resulting pattern describing a matrix multiplication. - args: Mapping[str, DFPattern] - The mapping from arg name to its pattern. It can be used to extract - arg expression from match result. + annotations: Mapping[str, DFPattern] + A mapping from name to sub pattern. It can be used to extract + important expressions from match result, to power the partition + check function and codegen. """ lhs = wildcard() rhs = wildcard() - args = {"lhs": lhs, "rhs": rhs} + annotations = {"lhs": lhs, "rhs": rhs} if transposed_rhs: rhs = is_op("relax.permute_dims")(rhs) out = is_op("relax.matmul")(lhs, rhs) + annotations["root"] = out - return _with_bias_activation_pattern(out, args, with_bias, activation) + return _with_bias_activation_pattern(out, annotations, with_bias, activation) def make_attention_pattern(with_bias: bool = False): @@ -169,19 +173,20 @@ def make_attention_pattern(with_bias: bool = False): pattern: DFPattern The resulting pattern describing a fused multi head attention. - args: Mapping[str, DFPattern] - The mapping from arg name to its pattern. It can be used to extract - arg expression from match result. + annotations: Mapping[str, DFPattern] + A mapping from name to sub pattern. It can be used to extract + important expressions from match result, to power the partition + check function and codegen. 
""" query = wildcard() key = wildcard() value = wildcard() - args = {"query": query, "key": key, "value": value} + annotations = {"query": query, "key": key, "value": value} if with_bias: bias = wildcard() - args["bias"] = bias + annotations["bias"] = bias out = is_op("relax.nn.attention_bias")(query, key, value, bias) else: out = is_op("relax.nn.attention")(query, key, value) - return out, args + return out, annotations diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index c59104ca58..0df29dc093 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -19,11 +19,16 @@ import functools import inspect import types -from typing import Callable, Dict, Union, Optional, List, Tuple -from tvm.tir import PrimFunc, IndexMap +from typing import Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union + import numpy as np # type: ignore + import tvm.ir -from tvm.runtime import NDArray +from tvm.relax import Expr, Var +from tvm.relax.dpl import DFPattern +from tvm.runtime import NDArray, Object +from tvm.tir import IndexMap, PrimFunc + from . import _ffi_api from .legalize_ops.common import LegalizeFunc @@ -283,8 +288,75 @@ def FuseTIR() -> tvm.ir.transform.Pass: return _ffi_api.FuseTIR() # type: ignore +@tvm._ffi.register_object("relax.transform.PatternCheckContext") +class PatternCheckContext(Object): + """ + The input of check function `FusionPattern.check`. + + Parameters + ---------- + annotated_expr: Mapping[str, Expr] + A map which contains all expressions matched by the sub patterns in + FusionPattern.annotation_patterns. + + var_usages: Mapping[Var, Sequence[Var]] + A map mapping variable definitions to a set of uses. + + value_to_bound_var: Mapping[Expr, Var] + Map from value to its bound variable. + """ + + annotated_expr: Mapping[str, Expr] + var_usages: Mapping[Var, Sequence[Var]] + value_to_bound_var: Mapping[Expr, Var] + + +@tvm._ffi.register_object("relax.transform.FusionPattern") +class FusionPattern(Object): + """ + The pattern used by `FuseOpsByPattern`. It's mainly DFPattern but with other + information to help during the fusion pass. + + Parameters + ---------- + name: str + The name of pattern. Usually it starts with the name of backend, like 'cutlass.matmul'. + + pattern: DFPattern + The dataflow pattern that will be used to match expressions that can be handled + by external backends. + + annotation_patterns: Mapping[str, DFPattern] + The map which is used to extract important expressions from the pattern match + result. All DFPattern in this map should be part of the `pattern`. + + check: Callable[[PatternCheckContext], bool] + The function to check whether the match result is accepted. 
+ """ + + name: str + pattern: DFPattern + annotation_patterns: Mapping[str, DFPattern] + check: Callable[[PatternCheckContext], bool] + + def __init__( + self, + name: str, + pattern: DFPattern, + annotation_patterns: Optional[Mapping[str, DFPattern]] = None, + check: Optional[Callable[[Mapping[str, Expr]], bool]] = None, + ): + if annotation_patterns is None: + annotation_patterns = {} + self.__init_handle_by_constructor__( + _ffi_api.FusionPattern, name, pattern, annotation_patterns, check # type: ignore + ) + + def FuseOpsByPattern( - patterns: List[Tuple], bind_constants: bool = True, annotate_codegen: bool = False + patterns: List[Union[FusionPattern, Tuple]], + bind_constants: bool = True, + annotate_codegen: bool = False, ) -> tvm.ir.transform.Pass: """Apply pattern matching to each function in the given module, and group matched expressions into a new function. @@ -293,15 +365,12 @@ def FuseOpsByPattern( Parameters ---------- - patterns : List[Union[Tuple[str, DFPattern], Tuple[str, DFPattern, Callable]]] - A list of tuple of (name, pattern) or (name, pattern, predicate) to be matched. - The predicate is a function with type (Map<DFPattern, Expr>, Expr) -> bool. It takes a - match result and returns a boolean value to indicate whether the match result is accepted. + patterns : List[Union[FusionPattern, Tuple]] + A list of patterns to be matched. The order of the patterns determines the order of priority + in which they are matched. Higher-priority patterns should come earlier in the list. - The patterns to detect. The order of the patterns determines the order of priority in which - they are matched. Higher-priority patterns should come earlier in the list. - The string is the name of the corresponding pattern. It becomes the value of the kComposite - attribute of a fused function after a successful matching. + In addition to FusionPattern, a tuple can be passed as item of this list. The pattern + will be constructed through FusionPattern(*item) bind_constants : bool Whether or not to keep bound constants in the grouped function. @@ -321,22 +390,19 @@ def FuseOpsByPattern( The registered pass for pattern-based fusion. 
""" - pattern_names = [] - df_patterns = [] - checks = [] - for tup in patterns: - if len(tup) == 2: - pattern_names.append(tup[0]) - df_patterns.append(tup[1]) - checks.append(lambda *_: True) - elif len(tup) == 3: - pattern_names.append(tup[0]) - df_patterns.append(tup[1]) - checks.append(tup[2]) + converted_patterns = [] + for pattern in patterns: + if isinstance(pattern, tuple): + converted_patterns.append(FusionPattern(*pattern)) + elif isinstance(pattern, FusionPattern): + converted_patterns.append(pattern) else: - raise ValueError("Invalid pattern: {}".format(tup)) + raise ValueError(f"Invalid pattern: {pattern}") + return _ffi_api.FuseOpsByPattern( - pattern_names, df_patterns, checks, bind_constants, annotate_codegen + converted_patterns, + bind_constants, + annotate_codegen, ) # type: ignore diff --git a/src/relax/backend/pattern_registry.cc b/src/relax/backend/pattern_registry.cc index 553018d690..34ebb4d6dd 100644 --- a/src/relax/backend/pattern_registry.cc +++ b/src/relax/backend/pattern_registry.cc @@ -24,25 +24,12 @@ namespace tvm { namespace relax { namespace backend { - -PatternRegistryEntry::PatternRegistryEntry(String name, DFPattern pattern, - Map<String, DFPattern> arg_patterns, PackedFunc check) { - ObjectPtr<PatternRegistryEntryNode> n = make_object<PatternRegistryEntryNode>(); - n->name = std::move(name); - n->pattern = std::move(pattern); - n->arg_patterns = std::move(arg_patterns); - n->check = check; - data_ = std::move(n); -} - -TVM_REGISTER_NODE_TYPE(PatternRegistryEntryNode); - -static std::vector<PatternRegistryEntry>* GetRegistryTable() { - static std::vector<PatternRegistryEntry> table; +static std::vector<FusionPattern>* GetRegistryTable() { + static std::vector<FusionPattern> table; return &table; } -void RegisterPatterns(Array<PatternRegistryEntry> entries) { +void RegisterPatterns(Array<FusionPattern> entries) { auto* table = GetRegistryTable(); for (const auto& entry : entries) { table->push_back(entry); @@ -53,16 +40,15 @@ void RemovePatterns(Array<String> names) { std::unordered_set<String> name_set{names.begin(), names.end()}; auto* table = GetRegistryTable(); - table->erase(std::remove_if(table->begin(), table->end(), - [&](const PatternRegistryEntry& entry) { - return name_set.count(entry->name) > 0; - }), - table->end()); + table->erase( + std::remove_if(table->begin(), table->end(), + [&](const FusionPattern& entry) { return name_set.count(entry->name) > 0; }), + table->end()); } -Array<PatternRegistryEntry> GetPatternsWithPrefix(const String& prefix) { +Array<FusionPattern> GetPatternsWithPrefix(const String& prefix) { auto* table = GetRegistryTable(); - Array<PatternRegistryEntry> result; + Array<FusionPattern> result; for (auto it = table->rbegin(); it != table->rend(); ++it) { if (support::StartsWith((*it)->name, prefix.data())) { result.push_back(*it); @@ -71,7 +57,7 @@ Array<PatternRegistryEntry> GetPatternsWithPrefix(const String& prefix) { return result; } -Optional<PatternRegistryEntry> GetPattern(const String& pattern_name) { +Optional<FusionPattern> GetPattern(const String& pattern_name) { auto* table = GetRegistryTable(); for (auto it = table->rbegin(); it != table->rend(); ++it) { if ((*it)->name == pattern_name) { @@ -81,11 +67,6 @@ Optional<PatternRegistryEntry> GetPattern(const String& pattern_name) { return NullOpt; } -TVM_REGISTER_GLOBAL("relax.backend.PatternRegistryEntry") - .set_body_typed([](String name, DFPattern pattern, Map<String, DFPattern> arg_patterns, - PackedFunc check) { - return PatternRegistryEntry(name, 
pattern, arg_patterns, check); - }); TVM_REGISTER_GLOBAL("relax.backend.RegisterPatterns").set_body_typed(RegisterPatterns); TVM_REGISTER_GLOBAL("relax.backend.RemovePatterns").set_body_typed(RemovePatterns); TVM_REGISTER_GLOBAL("relax.backend.GetPatternsWithPrefix").set_body_typed(GetPatternsWithPrefix); diff --git a/src/relax/backend/pattern_registry.h b/src/relax/backend/pattern_registry.h index e765f56b4e..72eea1238d 100644 --- a/src/relax/backend/pattern_registry.h +++ b/src/relax/backend/pattern_registry.h @@ -28,6 +28,7 @@ #include <tvm/relax/dataflow_pattern.h> #include <tvm/relax/expr.h> +#include <tvm/relax/transform.h> #include <tvm/runtime/container/optional.h> #include <tvm/runtime/object.h> @@ -35,57 +36,7 @@ namespace tvm { namespace relax { namespace backend { -/*! - * \brief An entry in the pattern registry. This represents a single pattern that - * can be used to identify expressions that can be handled by external - * backends, like CUTLASS and TensorRT. - */ -class PatternRegistryEntryNode : public Object { - public: - /*! - * \brief The name of pattern. Usually it starts with the name of backend, like - * 'cutlass.matmul'. - */ - String name; - /*! - * \brief The dataflow pattern that will be used to match expressions that can - * be handled by external backends. - */ - DFPattern pattern; - /*! - * \brief The mapping from arg name to its pattern. It can be used to extract - * arg expression from match result. All DFPattern in this map should be part of - * the `pattern`. - */ - Map<String, DFPattern> arg_patterns; - - /*! - * \brief The function to check whether the match result is accepted. - * - * It should have signature - * bool(const Map<DFPattern, Expr>& match_result, const Expr& matched_expr) - */ - PackedFunc check; - - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("name", &name); - v->Visit("pattern", &pattern); - v->Visit("arg_patterns", &arg_patterns); - v->Visit("check", &check); - } - - static constexpr const char* _type_key = "relax.backend.PatternRegistryEntry"; - TVM_DECLARE_FINAL_OBJECT_INFO(PatternRegistryEntryNode, Object); -}; - -class PatternRegistryEntry : public ObjectRef { - public: - PatternRegistryEntry(String name, DFPattern pattern, Map<String, DFPattern> arg_patterns, - PackedFunc check); - - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(PatternRegistryEntry, ObjectRef, - PatternRegistryEntryNode); -}; +using transform::FusionPattern; /*! * \brief Register patterns which will be used to partition the DataflowBlock @@ -93,7 +44,7 @@ class PatternRegistryEntry : public ObjectRef { * \param patterns Patterns to be registered. Patterns that appear later in the list have * higher priority when partitioning DataflowBlock. */ -void RegisterPatterns(Array<PatternRegistryEntry> entries); +void RegisterPatterns(Array<FusionPattern> patterns); /*! * \brief Remove patterns from the registry by their name. @@ -106,14 +57,14 @@ void RemovePatterns(Array<String> names); * \param prefx The pattern name prefix. * \return Matched patterns, ordered by priority from high to low. */ -Array<PatternRegistryEntry> GetPatternsWithPrefix(const String& prefix); +Array<FusionPattern> GetPatternsWithPrefix(const String& prefix); /*! * \brief Find the pattern with a particular name. * \param name The pattern name. * \return The matched pattern. NullOpt if not found. 
*/ -Optional<PatternRegistryEntry> GetPattern(const String& name); +Optional<FusionPattern> GetPattern(const String& name); } // namespace backend } // namespace relax diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc index 6d7c278d80..76f53eebc5 100644 --- a/src/relax/transform/fuse_ops.cc +++ b/src/relax/transform/fuse_ops.cc @@ -40,6 +40,7 @@ #include "../../relay/analysis/graph_partitioner.h" #include "../../support/arena.h" +#include "tvm/relax/expr.h" namespace tvm { namespace relax { @@ -905,18 +906,30 @@ class PatternBasedPartitioner : ExprVisitor { using Group = GraphPartitioner::Group; using GroupMap = OperatorFusor::GroupMap; using ExprVisitor::VisitExpr_; - using FCheckMatch = runtime::TypedPackedFunc<bool(const Map<DFPattern, Expr>&, const Expr&)>; + using FCheckMatch = runtime::TypedPackedFunc<bool(const transform::PatternCheckContext&)>; - static GroupMap Run(String pattern_name, DFPattern pattern, FCheckMatch check, Expr expr, + static GroupMap Run(String pattern_name, DFPattern pattern, + Map<String, DFPattern> annotation_patterns, FCheckMatch check, Expr expr, support::Arena* arena) { - PatternBasedPartitioner part(pattern_name, pattern, check, arena); + PatternBasedPartitioner part(pattern_name, pattern, annotation_patterns, check, arena); part.VisitExpr(expr); return part.group_map_; } - PatternBasedPartitioner(String pattern_name, DFPattern pattern, FCheckMatch check, + PatternBasedPartitioner(String pattern_name, DFPattern pattern, + Map<String, DFPattern> annotation_patterns, FCheckMatch check, support::Arena* arena) - : pat_name_(pattern_name), pat_(pattern), check_(check), arena_(arena) {} + : pat_name_(pattern_name), + pat_(pattern), + annotation_pat_(annotation_patterns), + check_(check), + arena_(arena) {} + + void VisitBindingBlock_(const DataflowBlockNode* block) final { + current_block_use_def_ = DataflowBlockUseDef(GetRef<DataflowBlock>(block)); + ExprVisitor::VisitBindingBlock_(block); + current_block_use_def_ = {}; + } void VisitVarDef(const Var& var) final { group_map_[var.get()] = arena_->make<Group>(); } @@ -931,7 +944,9 @@ class PatternBasedPartitioner : ExprVisitor { void VisitBinding_(const VarBindingNode* binding, const CallNode* call) final { VisitVarDef(binding->var); if (auto matches_opt = ExtractMatchedExpr(pat_, GetRef<Call>(call), bindings_)) { - if (!check_(matches_opt.value(), GetRef<Call>(call))) { + if (check_ != nullptr && + !check_(transform::PatternCheckContext(GetAnnotatedExpr(matches_opt.value()), + current_block_use_def_, value_to_bound_var_))) { return; } // If a match is found, put all matching expressions into the same group. 
@@ -975,12 +990,24 @@ class PatternBasedPartitioner : ExprVisitor { return group_map_[bound_var.get()]->FindRoot(); } + Map<String, Expr> GetAnnotatedExpr(const Map<DFPattern, Expr> matched_result) { + Map<String, Expr> annotated_expr; + for (const auto& it : annotation_pat_) { + if (matched_result.count(it.second)) { + annotated_expr.Set(it.first, matched_result[it.second]); + } + } + return annotated_expr; + } + String pat_name_; DFPattern pat_; + Map<String, DFPattern> annotation_pat_; FCheckMatch check_; support::Arena* arena_; Map<Var, Expr> bindings_; Map<Expr, Var> value_to_bound_var_; + Map<Var, Array<Var>> current_block_use_def_; GroupMap group_map_; }; @@ -1054,19 +1081,18 @@ class CompositeFunctionAnnotator : public ExprMutator { std::unordered_map<const GlobalVarNode*, GlobalVar> gvar_map_; }; -IRModule FuseOpsByPattern(const tvm::Array<String>& pattern_names, - const tvm::Array<DFPattern>& patterns, - const tvm::Array<runtime::PackedFunc>& checks, IRModule mod, +IRModule FuseOpsByPattern(const tvm::Array<transform::FusionPattern>& patterns, IRModule mod, bool bind_constants, bool annotate_codegen) { support::Arena arena; - for (size_t i = 0; i < pattern_names.size(); ++i) { + for (const auto& pattern : patterns) { OperatorFusor::GroupMap group_map; for (const auto& entry : mod->functions) { if (entry.second->IsInstance<tir::PrimFuncNode>()) { continue; } - auto map = PatternBasedPartitioner::Run(pattern_names[i], patterns[i], checks[i], - entry.second, &arena); + auto map = PatternBasedPartitioner::Run( + pattern->name, pattern->pattern, pattern->annotation_patterns, + pattern->check.value_or(nullptr), entry.second, &arena); group_map.insert(map.begin(), map.end()); } mod = MakeGroupedFunctions(mod, group_map, /*lift_constants*/ !bind_constants); @@ -1079,6 +1105,36 @@ IRModule FuseOpsByPattern(const tvm::Array<String>& pattern_names, namespace transform { +FusionPattern::FusionPattern(String name, DFPattern pattern, + Map<String, DFPattern> annotation_patterns, + Optional<PackedFunc> check) { + ObjectPtr<FusionPatternNode> n = make_object<FusionPatternNode>(); + n->name = std::move(name); + n->pattern = std::move(pattern); + n->annotation_patterns = std::move(annotation_patterns); + n->check = check; + data_ = std::move(n); +} + +TVM_REGISTER_NODE_TYPE(FusionPatternNode); +TVM_REGISTER_GLOBAL("relax.transform.FusionPattern") + .set_body_typed([](String name, DFPattern pattern, Map<String, DFPattern> annotation_patterns, + Optional<PackedFunc> check) { + return FusionPattern(name, pattern, annotation_patterns, check); + }); + +PatternCheckContext::PatternCheckContext(Map<String, Expr> annotated_expr, + Map<Var, Array<Var>> var_usages, + Map<Expr, Var> value_to_bound_var) { + ObjectPtr<PatternCheckContextNode> n = make_object<PatternCheckContextNode>(); + n->annotated_expr = std::move(annotated_expr); + n->var_usages = std::move(var_usages); + n->value_to_bound_var = std::move(value_to_bound_var); + data_ = std::move(n); +} + +TVM_REGISTER_NODE_TYPE(PatternCheckContextNode); + Pass FuseOps(int fuse_opt_level) { runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = // [=](IRModule m, PassContext pc) { @@ -1094,14 +1150,11 @@ Pass FuseOps(int fuse_opt_level) { TVM_REGISTER_GLOBAL("relax.transform.FuseOps").set_body_typed(FuseOps); -Pass FuseOpsByPattern(const tvm::Array<String>& pattern_names, - const tvm::Array<DFPattern>& patterns, - const tvm::Array<runtime::PackedFunc>& checks, bool bind_constants, +Pass FuseOpsByPattern(const tvm::Array<FusionPattern>& 
patterns, bool bind_constants, bool annotate_codegen) { runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = // [=](IRModule m, PassContext pc) { - return relax::FuseOpsByPattern(pattern_names, patterns, checks, m, bind_constants, - annotate_codegen); + return relax::FuseOpsByPattern(patterns, m, bind_constants, annotate_codegen); }; return CreateModulePass(/*pass_function=*/pass_func, // /*opt_level=*/0, // diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py index de15f7083a..0bae6801ca 100644 --- a/tests/python/relax/test_codegen_cutlass.py +++ b/tests/python/relax/test_codegen_cutlass.py @@ -26,7 +26,6 @@ from tvm.contrib.pickle_memoize import memoize from tvm.relax.backend import get_patterns_with_prefix from tvm.relax.backend.contrib.cutlass import partition_for_cutlass from tvm.script import relax as R -from tvm.script import tir as T from tvm.script.ir_builder import IRBuilder from tvm.script.ir_builder import relax as relax_builder @@ -296,6 +295,43 @@ def test_conv2d_offload(data_shape, weight_shape, dtype, epilogue, residual_bloc tvm.testing.assert_allclose(out, ref, rtol=1e-5, atol=1e-5) +def test_cutlass_partition_conv2d_residual_blocked(): + @tvm.script.ir_module + class Conv2dReLU: + """ + This conv2d should not be fused as conv2d residual block, because both lhs and rhs of + the last R.add depends on the result of conv2d. + """ + + @R.function + def main( + data: R.Tensor((32, 3, 3, 16), "float32"), + weight: R.Tensor((16, 3, 3, 16), "float32"), + bias: R.Tensor((1, 1, 1, 16), "float32"), + ): + with R.dataflow(): + conv1 = R.nn.conv2d( + data, + weight, + padding=(1, 1), + data_layout="NHWC", + kernel_layout="OHWI", + ) + out = R.nn.relu(conv1 + bias) + # residual depends on conv result, which cannot be handled in cutlass + result = out + out + R.output(result) + + return result + + mod = partition_for_cutlass(Conv2dReLU, annotate_codegen=False) + for f_var in mod.functions: + func = mod[f_var] + if func.attrs and "Composite" in func.attrs: + # verify that the function is not fused as residual block + assert func.attrs["Composite"] == "cutlass.conv2d_bias_relu" + + @pytest.mark.parametrize( "x_shape, y_shape, transpose_y, epilogue, residual_block", [ @@ -451,6 +487,7 @@ def test_cutlass_partition_matmul_blocked(x_shape, y_shape, transpose_y, dtype): mod = get_relax_matmul_module( x_shape, y_shape, dtype, with_bias=False, transposed_y=transpose_y ) + mod = partition_for_cutlass(mod) assert len(mod.functions) == 1 diff --git a/tests/python/relax/test_transform_fuse_ops_by_pattern.py b/tests/python/relax/test_transform_fuse_ops_by_pattern.py index 3816e11bc5..2f3e2d479f 100644 --- a/tests/python/relax/test_transform_fuse_ops_by_pattern.py +++ b/tests/python/relax/test_transform_fuse_ops_by_pattern.py @@ -14,14 +14,16 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-import pytest import numpy as np +import pytest import tvm - from tvm import relax -from tvm.script import relax as R, tir as T, ir as I -from tvm.relax.dpl.pattern import make_fused_bias_activation_pattern, is_op, wildcard +from tvm.relax.dpl.pattern import is_op, make_fused_bias_activation_pattern, wildcard +from tvm.relax.transform import PatternCheckContext +from tvm.script import ir as I +from tvm.script import relax as R +from tvm.script import tir as T @tvm.script.ir_module @@ -600,13 +602,23 @@ def test_unused(): def test_check_pattern(): - pat = make_fused_bias_activation_pattern("relax.nn.conv2d", with_bias=False, activation=None) - - def pred(match, expr): + lhs = wildcard() + rhs = wildcard() + out = is_op("relax.nn.conv2d")(lhs, rhs) + annotation_patterns = {"root": out, "lhs": lhs, "rhs": rhs} + + def pred(context: PatternCheckContext): + lhs = context.annotated_expr["lhs"] + rhs = context.annotated_expr["rhs"] + expr = context.annotated_expr["root"] + assert isinstance(lhs, relax.expr.Var) and lhs.name_hint == "data" + assert isinstance(rhs, relax.expr.Var) and rhs.name_hint == "weight1" assert isinstance(expr, relax.expr.Call) and expr.op.name == "relax.nn.conv2d" - return expr.struct_info.dtype == "float32" + return False - check(Conv2dx2, [("cutlass.conv2d", pat, pred)], Conv2dx2) # expect no partitioning + check( + Conv2dReLU, [("cutlass.conv2d", out, annotation_patterns, pred)], Conv2dReLU + ) # expect no partitioning def test_bind_constants():
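The updated test drives the check through annotated expressions rather than the raw DFPattern-to-Expr map, which is the main usability change of this patch. The same tuple shorthand accepted by FuseOpsByPattern also works with the backend pattern registry; a hedged sketch, where the "example.conv2d" name and annotation keys are illustrative and not patterns added by this commit:

    from tvm.relax.backend.pattern_registry import get_patterns_with_prefix, register_patterns
    from tvm.relax.dpl.pattern import is_op, wildcard

    data, weight = wildcard(), wildcard()
    conv = is_op("relax.nn.conv2d")(data, weight)

    # A (name, pattern, annotation_patterns) tuple is expanded into FusionPattern(*item).
    register_patterns([("example.conv2d", conv, {"root": conv, "data": data, "weight": weight})])

    # Returns List[FusionPattern], ordered by priority from high to low.
    patterns = get_patterns_with_prefix("example")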