Hello, I'm trying to partition a Relay graph into functions and then rewrite those functions with a pattern callback, but the rewrite fails. Here's a minimal working example:
```python
import tvm
import tvm.relay as relay
from tvm.relay.dataflow_pattern import (
    wildcard,
    is_op,
    rewrite,
    DFPatternCallback,
    FunctionPattern,
)


class TestCallback(DFPatternCallback):
    def __init__(self):
        super(TestCallback, self).__init__()
        self.x = wildcard()
        self.y = wildcard()
        # Match a two-parameter function whose body is add(x, y)
        pattern = is_op('add')(self.x, self.y)
        pattern = FunctionPattern([wildcard(), wildcard()], pattern)
        self.pattern = pattern

    def callback(self, pre, post, node_map):
        print('here')  # debug marker: shows whether the callback fires
        x = node_map[self.x][0]
        y = node_map[self.y][0]
        return x - y


x = relay.var('x')
y = relay.var('y')
z = relay.var('z')
expr = (x + y) * z

# Partition the `add` into its own function
p = wildcard() + wildcard()
fp = FunctionPattern([wildcard(), wildcard()], p)
print(expr)
expr_p = p.partition(expr)
print(expr_p)

# Try to rewrite the partitioned function with TestCallback
expr_r = rewrite(TestCallback(), expr_p)
print(expr_r)
```
The third print statement prints the same output as the second one, because `TestCallback` fails to match the `add` op. Can anyone help?

Thanks in advance!