comaniac commented on pull request #5689: URL: https://github.com/apache/incubator-tvm/pull/5689#issuecomment-635649921
I made an example and tested it with the old `MergeComposite` pass (master commit 6100112a150540588ddc9abb36dea0ff961f4301). It behaves as I expected: the partitioned composite function still has 3 arguments even though `%w` has been bound. @mbaret could you double check?

```python
import tvm
from tvm import relay
from tvm import tir
from tvm.relay.testing import run_opt_pass
from tvm.relay.build_module import bind_params_by_name
import numpy as np

# Make a graph
x = relay.var('x', shape=(1, 3, 224, 224))
w = relay.var('w', shape=(3, 3, 3, 3))
b = relay.var('b', shape=(3,))
conv2d = relay.op.nn.conv2d(x, w)
out = relay.op.nn.bias_add(conv2d, b)
func = relay.Function([x, w, b], out)
mod = tvm.IRModule.from_expr(func)

# Bind `w` to a constant so it is no longer a parameter of main
mod["main"] = bind_params_by_name(mod["main"], {'w': tvm.nd.array(np.ones(shape=(3, 3, 3, 3)))})
print('=== Before ===')
print(mod['main'].body)

# The pattern still uses a plain var for the weight
def pat():
    x = relay.var('x', shape=(1, 3, 224, 224))
    w = relay.var('w', shape=(3, 3, 3, 3))
    b = relay.var('b', shape=(3,))
    conv2d = relay.op.nn.conv2d(x, w)
    out = relay.op.nn.bias_add(conv2d, b)
    return out

pattern_table = [('pat', pat())]
result = run_opt_pass(mod['main'], relay.transform.MergeComposite(pattern_table))
print('=== After ===')
print(result)
```

```
=== Before ===
free_var %x: Tensor[(1, 3, 224, 224), float32]
%0 = nn.conv2d(%x, meta[relay.Constant][0] /* ty=Tensor[(3, 3, 3, 3), float64] */ /* ty=Tensor[(3, 3, 3, 3), float64] */, padding=[0, 0, 0, 0]) /* ty=Tensor[(1, 3, 222, 222), float32] */;
free_var %b: Tensor[(3), float32]
nn.bias_add(%0, %b) /* ty=Tensor[(1, 3, 222, 222), float32] */
// meta data omitted. you can use show_meta_data=True to include meta data
=== After ===
fn (%x: Tensor[(1, 3, 224, 224), float32], %b: Tensor[(3), float32]) -> Tensor[(1, 3, 222, 222), float32] {
  %1 = fn (%x1: Tensor[(1, 3, 224, 224), float32], %w: Tensor[(3, 3, 3, 3), float64], %b1: Tensor[(3), float32], Composite="pat") -> Tensor[(1, 3, 222, 222), float32] {
    %0 = nn.conv2d(%x1, %w, padding=[0, 0, 0, 0]) /* ty=Tensor[(1, 3, 222, 222), float32] */;
    nn.bias_add(%0, %b1) /* ty=Tensor[(1, 3, 222, 222), float32] */
  };
  %1(%x, meta[relay.Constant][0] /* ty=Tensor[(3, 3, 3, 3), float64] */ /* ty=Tensor[(3, 3, 3, 3), float64] */, %b) /* ty=Tensor[(1, 3, 222, 222), float32] */
}
// meta data omitted. you can use show_meta_data=True to include meta data
```

@masahi I just checked with this PR. The unit test hits the line you pointed out twice.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org