apeskov commented on code in PR #11513:
URL: https://github.com/apache/tvm/pull/11513#discussion_r893219794


##########
python/tvm/relay/op/contrib/dnnl.py:
##########
@@ -594,3 +611,104 @@ def rewrite_layer_norm(mod):
     """
     mod["main"] = rewrite(LayerNormRewrite(), mod["main"])
     return mod
+
+
+class DenseReshapeBiasGeluRewrite(DFPatternCallback):
+    """
+    A callback to reorder the reshape operator when the graph matches one of the patterns below:
+
+    Pattern #1:
+    1   %62 = nn.dense(%61, meta[relay.Constant][13] /* ty=Tensor[(64, 64), float32] */,
+                units=None, out_dtype="float32") /* ty=Tensor[(3136, 64), float32] */;
+    2   %63 = reshape(%62, newshape=[1, 3136, 64]) /* ty=Tensor[(1, 3136, 64), float32] */;
+    3   %64 = add(meta[relay.Constant][4] /* ty=Tensor[(64), float32] */, %63)
+                /* ty=Tensor[(1, 3136, 64), float32] */;
+
+    Pattern #2:
+    1   %76 = nn.dense(%75, meta[relay.Constant][18] /* ty=Tensor[(512, 64), float32] */,
+                units=None, out_dtype="float32") /* ty=Tensor[(3136, 512), float32] */;
+    2   %77 = reshape(%76, newshape=[1, 3136, 512]) /* ty=Tensor[(1, 3136, 512), float32] */;
+    3   %78 = add(meta[relay.Constant][15] /* ty=Tensor[(512), float32] */, %77)
+                /* ty=Tensor[(1, 3136, 512), float32] */;
+    4   %79 = divide(%78, 1.41421f /* ty=float32 */) /* ty=Tensor[(1, 3136, 512), float32] */;
+    5   %80 = erf(%79) /* ty=Tensor[(1, 3136, 512), float32] */;
+    6   %81 = add(%80, 1f /* ty=float32 */) /* ty=Tensor[(1, 3136, 512), float32] */;
+    7   %82 = multiply(%78, %81) /* ty=Tensor[(1, 3136, 512), float32] */;
+    8   %83 = multiply(%82, 0.5f /* ty=float32 */) /* ty=Tensor[(1, 3136, 512), float32] */;
+    """
+
+    def __init__(self, has_gelu=True):
+        super(DenseReshapeBiasGeluRewrite, self).__init__()
+        self.data = wildcard()
+        self.weight = wildcard()
+        self.bias = wildcard()
+        self.const1 = wildcard()
+        self.const2 = wildcard()
+        self.const3 = wildcard()
+
+        self.attr_map = {}
+        self.has_gelu = has_gelu
+
+        den = is_op("nn.dense")(self.data, self.weight)
+        re_den = is_op("reshape")(den)
+        added = is_op("add")(self.bias, re_den)
+        if self.has_gelu:
+            divisor = is_op("divide")(added, self.const1)
+            val_erf = is_op("erf")(divisor)
+            added_erf = is_op("add")(val_erf, self.const2)
+            mul1 = is_op("multiply")(added, added_erf)
+            mul2 = is_op("multiply")(mul1, self.const3)
+            self.pattern = mul2
+        else:
+            self.pattern = added
+
+    def get_attr(self, pre):
+        """Recursively retrieve attributes from reshape operator."""
+
+        def visit_func(expr):
+            if isinstance(expr, _expr.Call) and expr.op == relay.op.get("reshape"):
+                new_attrs = {}
+                for k in expr.attrs.keys():
+                    new_attrs[k] = expr.attrs[k]
+                self.attr_map["reshape"] = new_attrs
+
+        _analysis.post_order_visit(pre, visit_func)
+
+    def callback(self, pre, post, node_map):
+        self.get_attr(pre)
+
+        data = node_map[self.data][0]
+        weight = node_map[self.weight][0]
+        bias = node_map[self.bias][0]
+
+        den = relay.op.nn.dense(data, weight)
+        added = relay.op.add(bias, den)
+        if not self.has_gelu:
+            return relay.op.reshape(added, self.attr_map["reshape"]["newshape"])
+
+        const1 = node_map[self.const1][0]
+        const2 = node_map[self.const2][0]
+        const3 = node_map[self.const3][0]
+
+        divisor = relay.op.divide(added, const1)
+        val_erf = relay.op.erf(divisor)
+        added_erf = relay.op.add(val_erf, const2)
+        mul1 = relay.op.multiply(added, added_erf)
+        mul2 = relay.op.multiply(mul1, const3)
+        return relay.op.reshape(mul2, self.attr_map["reshape"]["newshape"])
+
+
+def rewrite_dense_bias_gelu_reshape_last(mod):

Review Comment:
   I cannot find where this is used in the tests. I see that you added several "gelu" tests, but where is this transformation pass applied?
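
   For reference, here is a minimal sketch (not part of this PR's tests; the shapes follow Pattern #1 in the docstring, and the variable names are mine) of how the non-gelu branch of this rewrite could be exercised, assuming the truncated rewrite_dense_bias_gelu_reshape_last applies both variants of the callback above via rewrite():

       import tvm
       from tvm import relay
       from tvm.relay.op.contrib import dnnl

       # dense -> reshape -> bias-add, as in Pattern #1 of the docstring.
       data = relay.var("data", shape=(3136, 64), dtype="float32")
       weight = relay.var("weight", shape=(64, 64), dtype="float32")
       bias = relay.var("bias", shape=(64,), dtype="float32")

       dense = relay.nn.dense(data, weight)                     # (3136, 64)
       reshaped = relay.reshape(dense, newshape=[1, 3136, 64])  # reshape before the bias-add
       biased = relay.add(bias, reshaped)                       # add(bias, reshape(...))

       mod = tvm.IRModule.from_expr(biased)
       mod = relay.transform.InferType()(mod)

       # After the rewrite, the reshape should come last, so dense+bias
       # (and dense+bias+gelu) can be fused and offloaded to DNNL.
       mod = dnnl.rewrite_dense_bias_gelu_reshape_last(mod)
       print(mod["main"])

   If the pass is meant to run automatically rather than be called by the user, wiring it in next to rewrite_layer_norm (visible at the top of this hunk) inside partition_for_dnnl would seem the natural place, but that is for the author to confirm.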


