masahi commented on a change in pull request #8781:
URL: https://github.com/apache/tvm/pull/8781#discussion_r693628005



##########
File path: python/tvm/relay/frontend/common.py
##########
@@ -658,6 +658,92 @@ def unbind(data, axis=0):
     return _expr.TupleWrapper(_expr.Tuple(ret), selections)
 
 
+def gru_cell(
+    input_seqs,
+    hidden_state,
+    w_inp,
+    w_hid,
+    b_inp=None,
+    b_hid=None,
+    rz_act=_op.sigmoid,
+    n_act=_op.tanh,
+    backwards=False,
+    linear_before_reset=True,
+):
+    """
+    Common implementation of GRU cell for all frontends of TVM
+    TODO(vvchernov): currently it is used by PyTorch. Extend for other frontends
+
+    Parameters
+    ----------
+    input_seqs : List[relay.Expr]
+        The sequence of input tensors
+        Input tensors should be 2-D until issue #8412 is resolved
+        Shape = (batch, feature_size)
+    hidden_state : relay.Expr
+        Hidden state. shape = (batch_size, hidden_size)
+    w_inp, w_hid : relay.Expr
+        weight matrices. wi shape = (3 * hidden_size, feature_size)
+        wh shape = (3 * hidden_size, hidden_size)
+        NOTE: wi = (w_ir|w_iz|w_in) for reset, update and new gates.
+        The order is important for correct GRU calculation!
+    b_inp, b_hid : relay.Expr
+        bias vectors. The same gate order as for the weights. shape = (3 * hidden_size,)
+    rz_act : relay.op
+        activation function for the reset and update gates. It is sigmoid by default
+    n_act : relay.op
+        activation function for the new gate. It is tanh by default
+    backwards : bool
+        Flag for reverse (right-to-left) pass of GRU
+    linear_before_reset : bool
+        If True, apply the hidden-state linear transformation before the reset
+        gate is applied (the ONNX/PyTorch convention). It is True by default
+
+    Returns
+    -------
+    result : List[relay.Expr], relay.Expr
+        The sequence of computed outputs and the final hidden state
+    """
+
+    outputs_list = []
+    for x_t in input_seqs if not backwards else reversed(input_seqs):
+        xwt = _op.nn.dense(x_t, w_inp)
+        if linear_before_reset:
+            hwt = _op.nn.dense(hidden_state, w_hid)
+            # TODO(vvchernov): It is assumed that both bias are or not

Review comment:
       What do you mean by `both bias are or not`? 
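For context, a single GRU step with `linear_before_reset=True` can be sketched in plain NumPy. This is an illustration only, not the Relay implementation: the names mirror the `gru_cell` parameters above, and the bias handling shown (each bias added only when it is provided) is one reading of what the `both bias are or not` comment presumably intends.

```python
# Minimal NumPy sketch of one GRU step, linear_before_reset=True
# (ONNX/PyTorch convention). Illustration only; not the Relay code.
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def gru_step(x_t, h, w_inp, w_hid, b_inp=None, b_hid=None):
    # Input and hidden projections. Gate order is (reset | update | new),
    # matching the (w_ir|w_iz|w_in) layout noted in the docstring.
    xwt = x_t @ w_inp.T + (b_inp if b_inp is not None else 0.0)
    hwt = h @ w_hid.T + (b_hid if b_hid is not None else 0.0)
    xr, xz, xn = np.split(xwt, 3, axis=-1)
    hr, hz, hn = np.split(hwt, 3, axis=-1)
    r = sigmoid(xr + hr)          # reset gate
    z = sigmoid(xz + hz)          # update gate
    n = np.tanh(xn + r * hn)      # reset applied AFTER the hidden linear map
    return (1.0 - z) * n + z * h  # next hidden state


# Tiny smoke test with made-up shapes.
batch, feat, hidden = 2, 4, 3
rng = np.random.default_rng(0)
x = rng.standard_normal((batch, feat))
h0 = np.zeros((batch, hidden))
w_i = rng.standard_normal((3 * hidden, feat))
w_h = rng.standard_normal((3 * hidden, hidden))
h1 = gru_step(x, h0, w_i, w_h)
print(h1.shape)  # (2, 3)
```

With `linear_before_reset=False` the reset gate would instead multiply the hidden state before the `w_hid` projection of the new-gate slice, which is why the two code paths diverge in the diff above.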




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscr...@tvm.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

