This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 3b33caf757 [Unity] Cover all Relax functions in implicit attention rewrite (#14818)
3b33caf757 is described below

commit 3b33caf757163c2e360e6f6f978e907ba183bcbf
Author: Lite Ye <yelite...@gmail.com>
AuthorDate: Thu May 11 01:27:23 2023 -0400

    [Unity] Cover all Relax functions in implicit attention rewrite (#14818)
    
    * Rewrite all functions in attention op rewriting
    
    * Fix lint
---
 python/tvm/relax/backend/contrib/cutlass.py | 16 +++++++++-------
 1 file changed, 9 insertions(+), 7 deletions(-)

diff --git a/python/tvm/relax/backend/contrib/cutlass.py b/python/tvm/relax/backend/contrib/cutlass.py
index d5940ac5e4..19fc2a39ea 100644
--- a/python/tvm/relax/backend/contrib/cutlass.py
+++ b/python/tvm/relax/backend/contrib/cutlass.py
@@ -17,24 +17,24 @@
 
 """Pattern table for CUTLASS backend"""
 import operator
-from typing import Mapping, Sequence
 from functools import reduce
+from typing import Mapping, Sequence
 
 import tvm
 from tvm.contrib.cutlass.build import is_shape_valid_for_cutlass_matmul
-from tvm.relax import DataflowVar, Var, transform, Call, PyExprMutator, expr_functor, Function
-from tvm.relax.transform import PatternCheckContext
+from tvm.relax import Call, DataflowVar, Function, PyExprMutator, Var, expr_functor, transform
 from tvm.relax.dpl import rewrite_call
+from tvm.relax.transform import PatternCheckContext
 
 from ..pattern_registry import get_patterns_with_prefix, register_patterns
 from ..patterns import (
     make_attention_pattern,
+    make_attention_rewrite_pattern,
     make_fused_bias_activation_pattern,
+    make_layer_norm_pattern,
     make_matmul_pattern,
     make_residual_block_pattern,
     make_stacked_attention_pattern,
-    make_layer_norm_pattern,
-    make_attention_rewrite_pattern,
 )
 
 
@@ -435,8 +435,10 @@ def partition_for_cutlass(mod, annotate_codegen=True):
         The resulting IRModule, containing partitioned subgraphs to be
         compiled by the CUTLASS backend.
     """
-    for pattern, rewriter in _REWRITE_PATTERNS:
-        mod["main"] = rewrite_call(pattern, rewriter, mod["main"])
+    for func_name, func in mod.functions.items():
+        if isinstance(func, Function):
+            for pattern, rewriter in _REWRITE_PATTERNS:
+                mod[func_name] = rewrite_call(pattern, rewriter, func)
     patterns = get_patterns_with_prefix("cutlass")
     return tvm.transform.Sequential(
         [
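
For reference, a rough standalone sketch of the idea behind this change (not the literal patch; the helper name apply_rewrites_to_all_functions is illustrative): walk every global function in the IRModule, skip anything that is not a relax.Function (e.g. TIR PrimFuncs), apply each registered (pattern, rewriter) pair with rewrite_call, and write the result back under the same GlobalVar, so attention rewriting is no longer limited to mod["main"].

    # Illustrative sketch only, assuming the unity-branch Relax APIs used in
    # the diff above (mod.functions, relax.Function, tvm.relax.dpl.rewrite_call).
    from tvm import relax
    from tvm.relax.dpl import rewrite_call

    def apply_rewrites_to_all_functions(mod, rewrite_patterns):
        """Apply every (pattern, rewriter) pair to each Relax function in mod."""
        for gvar, func in mod.functions.items():
            if isinstance(func, relax.Function):
                # Thread the (possibly) rewritten function through each pattern,
                # then store it back under the same global symbol.
                for pattern, rewriter in rewrite_patterns:
                    func = rewrite_call(pattern, rewriter, func)
                mod[gvar] = func
        return mod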
