This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
     new 9914a77b0f  [Unity] Process all Relax functions in CompositeFunctionAnnotator (#14736)
9914a77b0f is described below

commit 9914a77b0f9c202b5145d8e292cb6c5b6e5d9545
Author: Lite Ye <yelite...@gmail.com>
AuthorDate: Thu Apr 27 20:27:34 2023 -0400

    [Unity] Process all Relax functions in CompositeFunctionAnnotator (#14736)

    Process all Relax functions when annotating codegen
---
 src/relax/transform/fuse_ops.cc                 | 19 +++--
 .../relax/test_transform_fuse_ops_by_pattern.py | 86 ++++++++++++++++------
 2 files changed, 78 insertions(+), 27 deletions(-)

diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc
index f50a578954..ad1dc3eb98 100644
--- a/src/relax/transform/fuse_ops.cc
+++ b/src/relax/transform/fuse_ops.cc
@@ -1079,11 +1079,20 @@ class CompositeFunctionAnnotator : public ExprMutator {
 
   IRModule Run() {
     auto mod = builder_->GetContextIRModule();
-    auto gvar = mod->GetGlobalVar("main");
-    auto func = Downcast<Function>(mod->Lookup(gvar));
-    auto new_func =
-        Function(func->params, VisitExpr(func->body), func->ret_struct_info, func->attrs);
-    builder_->UpdateFunction(gvar, new_func);
+    auto all_functions = mod->functions;
+    for (const auto& entry : all_functions) {
+      if (const auto* func = entry.second.as<FunctionNode>()) {
+        if (func->GetAttr<String>(attr::kComposite).defined()) {
+          continue;
+        }
+        auto new_body = VisitExpr(func->body);
+        if (!new_body.same_as(func->body)) {
+          auto new_func = Function(func->params, VisitExpr(func->body), func->ret_struct_info,
+                                   func->attrs, func->span);
+          builder_->UpdateFunction(entry.first, new_func);
+        }
+      }
+    }
     return builder_->GetContextIRModule();
   }
 
diff --git a/tests/python/relax/test_transform_fuse_ops_by_pattern.py b/tests/python/relax/test_transform_fuse_ops_by_pattern.py
index 146c8e1ebc..5fb2b3332c 100644
--- a/tests/python/relax/test_transform_fuse_ops_by_pattern.py
+++ b/tests/python/relax/test_transform_fuse_ops_by_pattern.py
@@ -21,9 +21,9 @@ import tvm
 from tvm import relax
 from tvm.relax.dpl.pattern import (
     is_op,
+    is_tuple_get_item,
     make_fused_bias_activation_pattern,
     wildcard,
-    is_tuple_get_item,
 )
 from tvm.relax.transform import PatternCheckContext
 from tvm.script import ir as I
@@ -339,7 +339,7 @@ class Branch:
 @tvm.script.ir_module
 class Conv2dx2:
     @R.function
-    def main(
+    def main2(
         data: R.Tensor((16, 32, 32, 16), "float16"),
         weight1: R.Tensor((16, 3, 3, 16), "float16"),
         weight2: R.Tensor((16, 3, 3, 16), "float16"),
@@ -355,26 +355,28 @@ class Conv2dx2:
         return conv2
 
-
-@tvm.script.ir_module
-class Conv2dx2_partitioned:
     @R.function
     def main(
-        data: R.Tensor((16, 32, 32, 16), dtype="float16"),
-        weight1: R.Tensor((16, 3, 3, 16), dtype="float16"),
-        weight2: R.Tensor((16, 3, 3, 16), dtype="float16"),
-    ) -> R.Tensor((16, 32, 32, 16), dtype="float16"):
-        cls = Conv2dx2_partitioned
+        data: R.Tensor((16, 32, 32, 16), "float16"),
+        weight1: R.Tensor((16, 3, 3, 16), "float16"),
+        weight2: R.Tensor((16, 3, 3, 16), "float16"),
+    ):
         with R.dataflow():
-            lv: R.Tensor((16, 32, 32, 16), dtype="float16") = cls.fused_relax_nn_conv2d_cutlass(
-                data, weight1
+            conv1 = relax.op.nn.conv2d(
+                data, weight1, padding=(1, 1), data_layout="NHWC", kernel_layout="OHWI"
             )
-            gv: R.Tensor((16, 32, 32, 16), dtype="float16") = cls.fused_relax_nn_conv2d_cutlass(
-                lv, weight2
+            conv2 = relax.op.nn.conv2d(
+                conv1, weight2, padding=(1, 1), data_layout="NHWC", kernel_layout="OHWI"
             )
-            R.output(gv)
-        return gv
+            conv3 = Conv2dx2.main2(data, weight1, weight2)
+            result = conv2 + conv3
+            R.output(result)
+
+        return result
+
+
+@tvm.script.ir_module
+class Conv2dx2_partitioned:
     @R.function
     def fused_relax_nn_conv2d_cutlass(
         data: R.Tensor((16, 32, 32, 16), dtype="float16"),
@@ -383,26 +385,66 @@ class Conv2dx2_partitioned:
         R.func_attr({"Codegen": "cutlass", "global_symbol": "fused_relax_nn_conv2d_cutlass"})
 
         @R.function
-        def gv(
+        def gv_1(
             data_1: R.Tensor((16, 32, 32, 16), dtype="float16"),
             weight1_1: R.Tensor((16, 3, 3, 16), dtype="float16"),
         ) -> R.Tensor((16, 32, 32, 16), dtype="float16"):
             R.func_attr({"Composite": "cutlass.conv2d", "Primitive": 1})
             with R.dataflow():
-                gv_1: R.Tensor((16, 32, 32, 16), dtype="float16") = R.nn.conv2d(
+                gv_2: R.Tensor((16, 32, 32, 16), dtype="float16") = R.nn.conv2d(
                     data_1,
                     weight1_1,
+                    strides=[1, 1],
                     padding=[1, 1, 1, 1],
+                    dilation=[1, 1],
+                    groups=1,
                     data_layout="NHWC",
                     kernel_layout="OHWI",
                     out_layout="NHWC",
+                    out_dtype="void",
                 )
-                R.output(gv_1)
-            return gv_1
+                R.output(gv_2)
+            return gv_2
 
-        gv1: R.Tensor((16, 32, 32, 16), dtype="float16") = gv(data, weight1)
+        gv1: R.Tensor((16, 32, 32, 16), dtype="float16") = gv_1(data, weight1)
         return gv1
 
+    @R.function
+    def main2(
+        data: R.Tensor((16, 32, 32, 16), dtype="float16"),
+        weight1: R.Tensor((16, 3, 3, 16), dtype="float16"),
+        weight2: R.Tensor((16, 3, 3, 16), dtype="float16"),
+    ) -> R.Tensor((16, 32, 32, 16), dtype="float16"):
+        cls = Conv2dx2_partitioned
+        with R.dataflow():
+            lv: R.Tensor((16, 32, 32, 16), dtype="float16") = cls.fused_relax_nn_conv2d_cutlass(
+                data, weight1
+            )
+            gv: R.Tensor((16, 32, 32, 16), dtype="float16") = cls.fused_relax_nn_conv2d_cutlass(
+                lv, weight2
+            )
+            R.output(gv)
+        return gv
+
+    @R.function
+    def main(
+        data: R.Tensor((16, 32, 32, 16), dtype="float16"),
+        weight1: R.Tensor((16, 3, 3, 16), dtype="float16"),
+        weight2: R.Tensor((16, 3, 3, 16), dtype="float16"),
+    ) -> R.Tensor((16, 32, 32, 16), dtype="float16"):
+        cls = Conv2dx2_partitioned
+        with R.dataflow():
+            lv1: R.Tensor((16, 32, 32, 16), dtype="float16") = cls.fused_relax_nn_conv2d_cutlass(
+                data, weight1
+            )
+            lv2: R.Tensor((16, 32, 32, 16), dtype="float16") = cls.fused_relax_nn_conv2d_cutlass(
+                lv1, weight2
+            )
+            conv3: R.Tensor((16, 32, 32, 16), dtype="float16") = cls.main2(data, weight1, weight2)
+            result: R.Tensor((16, 32, 32, 16), dtype="float16") = R.add(lv2, conv3)
+            R.output(result)
+        return result
+
 
 conv2d_pat = make_fused_bias_activation_pattern("relax.nn.conv2d", activation=None)
 conv2d_relu_pat = make_fused_bias_activation_pattern("relax.nn.conv2d", activation="relax.nn.relu")
@@ -478,7 +520,7 @@ def test_annotate_codegen():
     )
 
 
-def test_multiple_calls_same_extern():
+def test_multiple_entries_multiple_calls_same_extern():
     pat = make_fused_bias_activation_pattern("relax.nn.conv2d", with_bias=False, activation=None)
 
     check(Conv2dx2, [("cutlass.conv2d", pat)], Conv2dx2_partitioned, annotate_codegen=True)
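
For reference, a minimal usage sketch of the behavior this commit enables (not part of the commit itself): applying FuseOpsByPattern with annotate_codegen=True to a module that contains more than one non-composite Relax function, so that CompositeFunctionAnnotator rewrites every function rather than only "main". The sketch assumes the Conv2dx2 module and pattern setup from the test file above are in scope.

# Sketch only: partition a module with two entry functions, "main" and "main2".
from tvm import relax
from tvm.relax.dpl.pattern import make_fused_bias_activation_pattern

# Same pattern the test uses: bare conv2d, no bias, no activation.
pat = make_fused_bias_activation_pattern(
    "relax.nn.conv2d", with_bias=False, activation=None
)

# annotate_codegen=True runs CompositeFunctionAnnotator after partitioning;
# with this commit it visits every Relax function in the module, so both
# "main" and "main2" end up calling the "Codegen"-annotated wrapper.
mod = relax.transform.FuseOpsByPattern(
    [("cutlass.conv2d", pat)], annotate_codegen=True
)(Conv2dx2)

mod.show()  # expected to match Conv2dx2_partitioned from the test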