comaniac commented on pull request #5689:
URL: https://github.com/apache/incubator-tvm/pull/5689#issuecomment-635649921


   I made an example and tested it with the old `MergeComposite` pass (master commit 6100112a150540588ddc9abb36dea0ff961f4301). It behaves as I expected: the partitioned composite function still has 3 arguments even though `%w` has been bound to a constant. @mbaret could you double check?
   
   ```python
   import tvm
   from tvm import relay
   from tvm import tir
   from tvm.relay.testing import run_opt_pass
   
   from tvm.relay.build_module import bind_params_by_name
   import numpy as np
   
   
   # Make a graph
   x = relay.var('x', shape=(1, 3, 224, 224))
   w = relay.var('w', shape=(3, 3, 3, 3))
   b = relay.var('b', shape=(3,))
   
   conv2d = relay.op.nn.conv2d(x, w)
   out = relay.op.nn.bias_add(conv2d, b)
   func = relay.Function([x, w, b], out)
   mod = tvm.IRModule.from_expr(func)
   
   mod["main"] = bind_params_by_name(mod["main"],
                                     {'w': tvm.nd.array(np.ones(shape=(3, 3, 3, 3)))})
   print('=== Before ===')
   print(mod['main'].body)
   
   
   def pat():
       x = relay.var('x', shape=(1, 3, 224, 224))
       w = relay.var('w', shape=(3, 3, 3, 3))
       b = relay.var('b', shape=(3,))
   
       conv2d = relay.op.nn.conv2d(x, w)
       out = relay.op.nn.bias_add(conv2d, b)
       return out
   
   pattern_table = [('pat', pat())]
   result = run_opt_pass(mod['main'], relay.transform.MergeComposite(pattern_table))
   print('=== After ===')
   print(result)
   ```
   
   ```
   === Before ===
   free_var %x: Tensor[(1, 3, 224, 224), float32]
   %0 = nn.conv2d(%x, meta[relay.Constant][0] /* ty=Tensor[(3, 3, 3, 3), float64] */ /* ty=Tensor[(3, 3, 3, 3), float64] */, padding=[0, 0, 0, 0]) /* ty=Tensor[(1, 3, 222, 222), float32] */;
   free_var %b: Tensor[(3), float32]
   nn.bias_add(%0, %b) /* ty=Tensor[(1, 3, 222, 222), float32] */
   // meta data omitted. you can use show_meta_data=True to include meta data
   === After ===
   fn (%x: Tensor[(1, 3, 224, 224), float32], %b: Tensor[(3), float32]) -> Tensor[(1, 3, 222, 222), float32] {
     %1 = fn (%x1: Tensor[(1, 3, 224, 224), float32], %w: Tensor[(3, 3, 3, 3), float64], %b1: Tensor[(3), float32], Composite="pat") -> Tensor[(1, 3, 222, 222), float32] {
       %0 = nn.conv2d(%x1, %w, padding=[0, 0, 0, 0]) /* ty=Tensor[(1, 3, 222, 222), float32] */;
       nn.bias_add(%0, %b1) /* ty=Tensor[(1, 3, 222, 222), float32] */
     };
     %1(%x, meta[relay.Constant][0] /* ty=Tensor[(3, 3, 3, 3), float64] */ /* ty=Tensor[(3, 3, 3, 3), float64] */, %b) /* ty=Tensor[(1, 3, 222, 222), float32] */
   }
   // meta data omitted. you can use show_meta_data=True to include meta data
   
   ```
   
   @masahi I just checked with this PR: the unit test hits the line you pointed out twice.

