This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new f7835a6f80 [Unity][CUTLASS] Require the residual input to have the same shape as input (#14657)
f7835a6f80 is described below

commit f7835a6f808183aaae2ac35ed3e8769884be0173
Author: masahi <masahi...@gmail.com>
AuthorDate: Thu Apr 20 04:02:04 2023 +0900

    [Unity][CUTLASS] Require the residual input to have the same shape as input (#14657)
    
    Require residual input to have the same shape as input
---
 python/tvm/relax/backend/contrib/cutlass.py | 35 +++++++++++++++++++++--------
 tests/python/relax/test_codegen_cutlass.py  | 31 +++++++++++++++++++++++++
 2 files changed, 57 insertions(+), 9 deletions(-)
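
In short: the CUTLASS BYOC checks for conv2d and matmul now share a `_check_residual` helper that, in addition to the existing dependency test, requires the residual's shape to equal the root op's output shape exactly. Both checks feed the pattern registry consumed by `partition_for_cutlass`, which the new test invokes; a minimal sketch of that entry point (assuming an IRModule named `Module`, as in the test further down; partitioning is a compile-time pass and needs no GPU):

    from tvm.relax.backend.contrib.cutlass import partition_for_cutlass

    # Runs pattern matching plus the checks patched below.
    rewritten = partition_for_cutlass(Module)
    # Fused function names indicate which ops were offloaded.
    print([gv.name_hint for gv in rewritten.functions.keys()])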

diff --git a/python/tvm/relax/backend/contrib/cutlass.py b/python/tvm/relax/backend/contrib/cutlass.py
index 0c2f38e300..7d6dc6bf89 100644
--- a/python/tvm/relax/backend/contrib/cutlass.py
+++ b/python/tvm/relax/backend/contrib/cutlass.py
@@ -20,7 +20,7 @@
 from typing import Mapping, Sequence
 
 from tvm.contrib.cutlass.build import is_shape_valid_for_cutlass_matmul
-from tvm.relax import DataflowVar, Var, transform
+from tvm.relax import DataflowVar, Var, transform, Call
 from tvm.relax.transform import PatternCheckContext
 
 from ..pattern_registry import get_patterns_with_prefix, register_patterns
@@ -82,6 +82,26 @@ def _has_dependency(from_var: Var, to_var: Var, var_usages: Mapping[Var, Sequenc
     return False
 
 
+def _check_residual(root_call: Call, context: PatternCheckContext) -> bool:
+    if "residual" in context.annotated_expr:
+        residual = context.annotated_expr["residual"]
+        if not isinstance(residual, Var):
+            residual = context.value_to_bound_var[residual]
+
+        root_var = context.value_to_bound_var[root_call]
+        if _has_dependency(from_var=residual, to_var=root_var, var_usages=context.var_usages):
+            # If residual depends on the result of the root call, this cannot be handled by cutlass.
+            return False
+
+        shape1 = [int(s) for s in root_var.struct_info.shape]
+        shape2 = [int(s) for s in residual.struct_info.shape]
+
+        if shape1 != shape2:
+            return False
+
+    return True
+
+
 def _check_conv2d(context: PatternCheckContext) -> bool:
     """Check if the given conv2d workload can be offloaded to CUTLASS."""
     if _has_leaking_intermediate_variables(context):
@@ -98,14 +118,8 @@ def _check_conv2d(context: PatternCheckContext) -> bool:
     ):
         return False
 
-    if "residual" in context.annotated_expr:
-        residual = context.annotated_expr["residual"]
-        if not isinstance(residual, Var):
-            residual = context.value_to_bound_var[residual]
-        conv2d_var = context.value_to_bound_var[conv2d_call]
-        if _has_dependency(from_var=residual, to_var=conv2d_var, var_usages=context.var_usages):
-            # If residual depends on the result of conv2d, this cannot be handled by cutlass.
-            return False
+    if not _check_residual(conv2d_call, context):
+        return False
 
     # pylint: disable=invalid-name
     IC = data.struct_info.shape.values[3]
@@ -127,6 +141,9 @@ def _check_matmul(context: PatternCheckContext) -> bool:
     if not _is_supported_dtype(lhs_dtype, rhs_dtype):
         return False
 
+    if not _check_residual(lhs, context):
+        return False
+
     lhs_shape = lhs.struct_info.shape.values
     rhs_shape = rhs.struct_info.shape.values
     return is_shape_valid_for_cutlass_matmul(lhs_shape, rhs_shape)
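
For readers skimming the hunk above: `_check_residual` compares concrete extents element-wise, so a residual that is merely broadcast-compatible with the output is rejected even though the surrounding add is legal Relax. A standalone sketch of that comparison (plain Python; the shapes are illustrative, not taken from the commit):

    def shapes_match(out_shape, residual_shape):
        # Mirrors the int-wise equality in _check_residual: every extent
        # must match exactly; broadcastability alone is not enough.
        return [int(s) for s in out_shape] == [int(s) for s in residual_shape]

    assert shapes_match((2, 64, 64, 8), (2, 64, 64, 8))    # fusable residual
    assert not shapes_match((2, 64, 64, 8), (2, 1, 1, 8))  # rejected
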
diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py
index db8abf34c2..a5fbf0f642 100644
--- a/tests/python/relax/test_codegen_cutlass.py
+++ b/tests/python/relax/test_codegen_cutlass.py
@@ -744,5 +744,36 @@ def test_stacked_attention_strided_slice_offload(stacked_attention_size):
     tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
 
 
+def test_invalid_residual():
+    @tvm.script.ir_module
+    class Module:
+        @R.function
+        def main(
+            x: R.Tensor((2, 64, 64, 8), dtype="float16"),
+            w: R.Tensor((8, 3, 3, 8), dtype="float16"),
+            bias: R.Tensor((1, 1, 8), dtype="float16"),
+            residual: R.Tensor((2, 1, 1, 8), dtype="float16"),
+        ) -> R.Tensor((2, 64, 64, 8), dtype="float16"):
+            with R.dataflow():
+                conv = R.nn.conv2d(
+                    x,
+                    w,
+                    padding=[1, 1, 1, 1],
+                    out_dtype="float16",
+                    data_layout="NHWC",
+                    kernel_layout="OHWI",
+                )
+                bias_out = R.add(conv, bias)
+                out = R.add(bias_out, residual)
+                R.output(out)
+            return out
+
+    rewritten = partition_for_cutlass(Module)
+    func_names = [gv.name_hint for gv in rewritten.functions.keys()]
+
+    assert "fused_relax_nn_conv2d_relax_add_relax_add_cutlass" not in 
func_names
+    assert "fused_relax_nn_conv2d_relax_add_cutlass" in func_names
+
+
 if __name__ == "__main__":
     tvm.testing.main()
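
The test exercises exactly the new rejection path: the residual has shape (2, 1, 1, 8) while the conv2d output is (2, 64, 64, 8), so only conv2d + bias add is offloaded and the trailing add stays outside the CUTLASS partition. A quick NumPy check (a sketch, not part of the commit) of why the add is nonetheless well-formed via broadcasting:

    import numpy as np

    out = np.zeros((2, 64, 64, 8), dtype="float16")    # conv2d + bias output
    residual = np.ones((2, 1, 1, 8), dtype="float16")  # the test's residual input

    # Broadcasting makes the add legal, so partitioning must reject the
    # fusion explicitly rather than relying on a shape error.
    assert (out + residual).shape == (2, 64, 64, 8)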
