This is an automated email from the ASF dual-hosted git repository. masahi pushed a commit to branch unity in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
     new 3b33caf757 [Unity] Cover all Relax functions in implicit attention rewrite (#14818)
3b33caf757 is described below

commit 3b33caf757163c2e360e6f6f978e907ba183bcbf
Author: Lite Ye <yelite...@gmail.com>
AuthorDate: Thu May 11 01:27:23 2023 -0400

    [Unity] Cover all Relax functions in implicit attention rewrite (#14818)

    * Rewrite all functions in attention op rewriting

    * Fix lint
---
 python/tvm/relax/backend/contrib/cutlass.py | 16 +++++++++-------
 1 file changed, 9 insertions(+), 7 deletions(-)

diff --git a/python/tvm/relax/backend/contrib/cutlass.py b/python/tvm/relax/backend/contrib/cutlass.py
index d5940ac5e4..19fc2a39ea 100644
--- a/python/tvm/relax/backend/contrib/cutlass.py
+++ b/python/tvm/relax/backend/contrib/cutlass.py
@@ -17,24 +17,24 @@
 """Pattern table for CUTLASS backend"""
 
 import operator
-from typing import Mapping, Sequence
 from functools import reduce
+from typing import Mapping, Sequence
 
 import tvm
 from tvm.contrib.cutlass.build import is_shape_valid_for_cutlass_matmul
-from tvm.relax import DataflowVar, Var, transform, Call, PyExprMutator, expr_functor, Function
-from tvm.relax.transform import PatternCheckContext
+from tvm.relax import Call, DataflowVar, Function, PyExprMutator, Var, expr_functor, transform
 from tvm.relax.dpl import rewrite_call
+from tvm.relax.transform import PatternCheckContext
 
 from ..pattern_registry import get_patterns_with_prefix, register_patterns
 from ..patterns import (
     make_attention_pattern,
+    make_attention_rewrite_pattern,
     make_fused_bias_activation_pattern,
+    make_layer_norm_pattern,
     make_matmul_pattern,
     make_residual_block_pattern,
     make_stacked_attention_pattern,
-    make_layer_norm_pattern,
-    make_attention_rewrite_pattern,
 )
 
 
@@ -435,8 +435,10 @@ def partition_for_cutlass(mod, annotate_codegen=True):
         The resulting IRModule, containing partitioned subgraphs to be
         compiled by the CUTLASS backend.
     """
-    for pattern, rewriter in _REWRITE_PATTERNS:
-        mod["main"] = rewrite_call(pattern, rewriter, mod["main"])
+    for func_name, func in mod.functions.items():
+        if isinstance(func, Function):
+            for pattern, rewriter in _REWRITE_PATTERNS:
+                mod[func_name] = rewrite_call(pattern, rewriter, func)
     patterns = get_patterns_with_prefix("cutlass")
     return tvm.transform.Sequential(
         [