This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
     new f7835a6f80 [Unity][CUTLASS] Require the residual input to have the same shape as input (#14657)
f7835a6f80 is described below

commit f7835a6f808183aaae2ac35ed3e8769884be0173
Author: masahi <masahi...@gmail.com>
AuthorDate: Thu Apr 20 04:02:04 2023 +0900

    [Unity][CUTLASS] Require the residual input to have the same shape as input (#14657)

    Require residual input to have the same shape as input
---
 python/tvm/relax/backend/contrib/cutlass.py | 35 +++++++++++++++++++++--------
 tests/python/relax/test_codegen_cutlass.py  | 31 +++++++++++++++++++++++++
 2 files changed, 57 insertions(+), 9 deletions(-)

diff --git a/python/tvm/relax/backend/contrib/cutlass.py b/python/tvm/relax/backend/contrib/cutlass.py
index 0c2f38e300..7d6dc6bf89 100644
--- a/python/tvm/relax/backend/contrib/cutlass.py
+++ b/python/tvm/relax/backend/contrib/cutlass.py
@@ -20,7 +20,7 @@ from typing import Mapping, Sequence
 
 from tvm.contrib.cutlass.build import is_shape_valid_for_cutlass_matmul
 
-from tvm.relax import DataflowVar, Var, transform
+from tvm.relax import DataflowVar, Var, transform, Call
 from tvm.relax.transform import PatternCheckContext
 
 from ..pattern_registry import get_patterns_with_prefix, register_patterns
@@ -82,6 +82,26 @@ def _has_dependency(from_var: Var, to_var: Var, var_usages: Mapping[Var, Sequenc
     return False
 
 
+def _check_residual(root_call: Call, context: PatternCheckContext) -> bool:
+    if "residual" in context.annotated_expr:
+        residual = context.annotated_expr["residual"]
+        if not isinstance(residual, Var):
+            residual = context.value_to_bound_var[residual]
+
+        root_var = context.value_to_bound_var[root_call]
+        if _has_dependency(from_var=residual, to_var=root_var, var_usages=context.var_usages):
+            # If residual depends on the result of the root call, this cannot be handled by cutlass.
+            return False
+
+        shape1 = [int(s) for s in root_var.struct_info.shape]
+        shape2 = [int(s) for s in residual.struct_info.shape]
+
+        if shape1 != shape2:
+            return False
+
+    return True
+
+
 def _check_conv2d(context: PatternCheckContext) -> bool:
     """Check if the given conv2d workload can be offloaded to CUTLASS."""
     if _has_leaking_intermediate_variables(context):
@@ -98,14 +118,8 @@ def _check_conv2d(context: PatternCheckContext) -> bool:
     ):
         return False
 
-    if "residual" in context.annotated_expr:
-        residual = context.annotated_expr["residual"]
-        if not isinstance(residual, Var):
-            residual = context.value_to_bound_var[residual]
-        conv2d_var = context.value_to_bound_var[conv2d_call]
-        if _has_dependency(from_var=residual, to_var=conv2d_var, var_usages=context.var_usages):
-            # If residual depends on the result of conv2d, this cannot be handled by cutlass.
-            return False
+    if not _check_residual(conv2d_call, context):
+        return False
 
     # pylint: disable=invalid-name
     IC = data.struct_info.shape.values[3]
@@ -127,6 +141,9 @@ def _check_matmul(context: PatternCheckContext) -> bool:
     if not _is_supported_dtype(lhs_dtype, rhs_dtype):
         return False
 
+    if not _check_residual(lhs, context):
+        return False
+
     lhs_shape = lhs.struct_info.shape.values
     rhs_shape = rhs.struct_info.shape.values
     return is_shape_valid_for_cutlass_matmul(lhs_shape, rhs_shape)
diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py
index db8abf34c2..a5fbf0f642 100644
--- a/tests/python/relax/test_codegen_cutlass.py
+++ b/tests/python/relax/test_codegen_cutlass.py
@@ -744,5 +744,36 @@ def test_stacked_attention_strided_slice_offload(stacked_attention_size):
     tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
 
 
+def test_invalid_residual():
+    @tvm.script.ir_module
+    class Module:
+        @R.function
+        def main(
+            x: R.Tensor((2, 64, 64, 8), dtype="float16"),
+            w: R.Tensor((8, 3, 3, 8), dtype="float16"),
+            bias: R.Tensor((1, 1, 8), dtype="float16"),
+            residual: R.Tensor((2, 1, 1, 8), dtype="float16"),
+        ) -> R.Tensor((1, 256, 64, 64), dtype="float16"):
+            with R.dataflow():
+                conv = R.nn.conv2d(
+                    x,
+                    w,
+                    padding=[1, 1, 1, 1],
+                    out_dtype="float16",
+                    data_layout="NHWC",
+                    kernel_layout="OHWI",
+                )
+                bias_out = R.add(conv, bias)
+                out = R.add(bias_out, residual)
+                R.output(out)
+            return out
+
+    rewritten = partition_for_cutlass(Module)
+    func_names = [gv.name_hint for gv in rewritten.functions.keys()]
+
+    assert "fused_relax_nn_conv2d_relax_add_relax_add_cutlass" not in func_names
+    assert "fused_relax_nn_conv2d_relax_add_cutlass" in func_names
+
+
 if __name__ == "__main__":
     tvm.testing.main()
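
A side note on the new check: _check_residual rejects a residual whose shape
would merely broadcast against the root op's output; the two shapes must match
element for element, presumably because the fused CUTLASS epilogue performs
the residual add without broadcasting. A minimal standalone sketch of that
comparison (plain Python, not part of the commit; residual_shape_ok is a
hypothetical name used only for illustration):

    def residual_shape_ok(out_shape, residual_shape):
        # Mirror the comparison in _check_residual: normalize both shapes
        # to lists of ints and require exact equality.
        return [int(s) for s in out_shape] == [int(s) for s in residual_shape]

    # Shapes from test_invalid_residual: R.add would broadcast (2, 1, 1, 8)
    # against the (2, 64, 64, 8) conv2d output, but the offloaded kernel
    # cannot, so the residual add stays outside the fused function.
    assert not residual_shape_ok((2, 64, 64, 8), (2, 1, 1, 8))
    assert residual_shape_ok((2, 64, 64, 8), (2, 64, 64, 8))

This is why the test expects fused_relax_nn_conv2d_relax_add_cutlass (conv2d
plus bias add) but not fused_relax_nn_conv2d_relax_add_relax_add_cutlass:
partitioning falls back to the smaller pattern rather than offloading the
broadcast residual add.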