gemini-code-assist[bot] commented on code in PR #18360:
URL: https://github.com/apache/tvm/pull/18360#discussion_r2404954112


##########
python/tvm/relax/frontend/torch/exported_program_translator.py:
##########
@@ -391,6 +391,300 @@ def _lstm(self, node: fx.Node) -> relax.Var:
             output = self.block_builder.emit(relax.op.permute_dims(output, 
axes=[1, 0, 2]))
         return output
 
+    def _gru(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        input_tensor = args[0]
+        hx = args[1] if len(args) > 1 else None
+        params = args[2] if len(args) > 2 else None
+        has_biases = args[3] if len(args) > 3 else True
+        num_layers = args[4] if len(args) > 4 else 1
+        _dropout = args[5] if len(args) > 5 else 0.0  # Not used in inference
+        _train = args[6] if len(args) > 6 else False  # Not used in inference
+        bidirectional = args[7] if len(args) > 7 else False
+        batch_first = args[8] if len(args) > 8 else False
+
+        if bidirectional:
+            raise NotImplementedError("Bidirectional GRU is not yet supported")
+
+        input_shape = self.shape_of(input_tensor)
+        if batch_first:
+            batch_size, seq_len, input_size = input_shape
+        else:
+            seq_len, batch_size, input_size = input_shape
+
+        if isinstance(seq_len, tvm.tir.IntImm):
+            seq_len = seq_len.value
+        if isinstance(batch_size, tvm.tir.IntImm):
+            batch_size = batch_size.value
+        if isinstance(input_size, tvm.tir.IntImm):
+            input_size = input_size.value
+
+        if params and len(params) >= 2:
+            # For multi-layer, we need to extract the first layer's weights
+            # to determine hidden size
+            if num_layers > 1:
+                # Multi-layer: params[0] is first layer's weight_ih
+                weight_ih = params[0]
+            else:
+                # Single layer: params[0] is weight_ih
+                weight_ih = params[0]
+            # Extract hidden size from weight dimensions
+            # weight_ih has shape (3 * hidden_size, input_size)
+            weight_ih_shape = self.shape_of(weight_ih)
+            hidden_size = weight_ih_shape[0] // 3  # 3 gates: reset, update, 
new
+        else:
+            # Fallback to a default hidden size
+            hidden_size = 16
+
+        # Implement actual GRU computation using Relax operations
+        # GRU equations:
+        # r_t = sigmoid(W_ir * x_t + b_ir + W_hr * h_{t-1} + b_hr)
+        # z_t = sigmoid(W_iz * x_t + b_iz + W_hz * h_{t-1} + b_hz)
+        # n_t = tanh(W_in * x_t + b_in + r_t * (W_hn * h_{t-1} + b_hn))
+        # h_t = (1 - z_t) * n_t + z_t * h_{t-1}
+        dtype = input_tensor.struct_info.dtype
+
+        # Reshape input for processing
+        if batch_first:
+            # Input: (batch, seq_len, input_size) -> (seq_len, batch, 
input_size)
+            input_reshaped = self.block_builder.emit(
+                relax.op.permute_dims(input_tensor, axes=[1, 0, 2])
+            )
+        else:
+            input_reshaped = input_tensor
+
+        # Initialize hidden states for all layers
+        if hx is not None:
+            # hx shape: (num_layers, batch_size, hidden_size)
+            h_states = []
+            for layer in range(num_layers):
+                h_layer = self.block_builder.emit(
+                    relax.op.take(hx, relax.const(layer, "int64"), axis=0, 
mode="clip")
+                )
+                h_states.append(h_layer)
+        else:
+            h_states = []
+            for layer in range(num_layers):
+                h_layer = self.block_builder.emit(
+                    relax.op.zeros(relax.ShapeExpr((batch_size, hidden_size)), 
dtype)
+                )
+                h_states.append(h_layer)
+
+        outputs = []
+
+        for t in range(seq_len):
+            # Get input at time t: (batch_size, input_size)
+            x_t = self.block_builder.emit(
+                relax.op.take(input_reshaped, relax.const(t, "int64"), axis=0, 
mode="clip")
+            )
+
+            # Process through each layer
+            current_input = x_t
+            new_h_states = []
+
+            for layer in range(num_layers):
+                # Get layer parameters
+                if params and len(params) >= 4 * num_layers:
+                    # Multi-layer case: params are organized as
+                    # [layer0_ih, layer0_hh, layer0_bias_ih, layer0_bias_hh, 
layer1_ih, ...]
+                    param_offset = layer * 4
+                    weight_ih = params[param_offset]
+                    weight_hh = params[param_offset + 1]
+                    bias_ih = params[param_offset + 2] if has_biases else None
+                    bias_hh = params[param_offset + 3] if has_biases else None
+                elif params and len(params) >= 4:
+                    # Single layer case
+                    weight_ih = params[0]
+                    weight_hh = params[1]
+                    bias_ih = params[2] if has_biases else None
+                    bias_hh = params[3] if has_biases else None
+                else:
+                    # Fallback: create zero weights
+                    weight_ih = self.block_builder.emit(
+                        relax.op.zeros(
+                            relax.ShapeExpr(
+                                (3 * hidden_size, input_size if layer == 0 
else hidden_size)
+                            ),
+                            dtype,
+                        )
+                    )
+                    weight_hh = self.block_builder.emit(
+                        relax.op.zeros(relax.ShapeExpr((3 * hidden_size, 
hidden_size)), dtype)
+                    )
+                    bias_ih = None
+                    bias_hh = None
+
+                # Get previous hidden state for this layer
+                h_prev = h_states[layer]
+
+                # Split weights by gates: PyTorch GRU gate order: reset, 
update, new (r, z, n)
+                gate_size = hidden_size
+
+                # Reset gate weights
+                weight_ih_r = self.block_builder.emit(
+                    relax.op.strided_slice(weight_ih, axes=[0], begin=[0], 
end=[gate_size])
+                )
+                weight_hh_r = self.block_builder.emit(
+                    relax.op.strided_slice(weight_hh, axes=[0], begin=[0], 
end=[gate_size])
+                )
+
+                # Update gate weights
+                weight_ih_z = self.block_builder.emit(
+                    relax.op.strided_slice(
+                        weight_ih, axes=[0], begin=[gate_size], end=[2 * 
gate_size]
+                    )
+                )
+                weight_hh_z = self.block_builder.emit(
+                    relax.op.strided_slice(
+                        weight_hh, axes=[0], begin=[gate_size], end=[2 * 
gate_size]
+                    )
+                )
+
+                # New gate weights
+                weight_ih_n = self.block_builder.emit(
+                    relax.op.strided_slice(
+                        weight_ih, axes=[0], begin=[2 * gate_size], end=[3 * 
gate_size]
+                    )
+                )
+                weight_hh_n = self.block_builder.emit(
+                    relax.op.strided_slice(
+                        weight_hh, axes=[0], begin=[2 * gate_size], end=[3 * 
gate_size]
+                    )
+                )
+
+                # Transpose weights for matmul
+                weight_ih_r_t = self.block_builder.emit(
+                    relax.op.permute_dims(weight_ih_r, axes=[1, 0])
+                )
+                weight_hh_r_t = self.block_builder.emit(
+                    relax.op.permute_dims(weight_hh_r, axes=[1, 0])
+                )
+                weight_ih_z_t = self.block_builder.emit(
+                    relax.op.permute_dims(weight_ih_z, axes=[1, 0])
+                )
+                weight_hh_z_t = self.block_builder.emit(
+                    relax.op.permute_dims(weight_hh_z, axes=[1, 0])
+                )
+                weight_ih_n_t = self.block_builder.emit(
+                    relax.op.permute_dims(weight_ih_n, axes=[1, 0])
+                )
+                weight_hh_n_t = self.block_builder.emit(
+                    relax.op.permute_dims(weight_hh_n, axes=[1, 0])
+                )

Review Comment:
   ![high](https://www.gstatic.com/codereviewagent/high-priority.svg)
   
   The weight slicing and transposition operations are performed inside the
   time-step loop (`for t in range(seq_len)`). Since these weights do not
   depend on the time step `t`, the same operations are emitted again on every
   iteration, which is redundant and highly inefficient, especially for long
   sequences. They should be hoisted out of the time-step loop and computed
   only once per layer. The same applies to bias slicing (e.g., lines 583-588,
   607-616, 635-644 of `exported_program_translator.py`). This will result in
   a much smaller and more efficient computation graph.



##########
tests/python/relax/test_frontend_from_exported_program.py:
##########
@@ -6050,5 +6050,76 @@ def main(
     verify_model(TensorNoneModel(), example_args, {}, Expected)
 
 
+def test_gru():
+    class BasicGRU(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.gru = nn.GRU(
+                input_size=4,
+                hidden_size=8,
+                num_layers=1,
+                batch_first=True,
+                bidirectional=False,
+            )
+
+        def forward(self, x):
+            y, _ = self.gru(x)
+            return y
+
+    torch.manual_seed(42)
+    x = torch.randn(2, 3, 4, dtype=torch.float32)
+    model = BasicGRU()
+    with torch.no_grad():
+        pytorch_output = model(x)
+    exported_program = export(model, args=(x,))
+    mod = from_exported_program(exported_program)
+    target = tvm.target.Target("llvm")
+    ex = relax.build(mod, target)
+    vm = relax.VirtualMachine(ex, tvm.cpu())
+    x_tvm = tvm.runtime.tensor(x.numpy())
+    tvm_output = vm["main"](x_tvm)
+    if hasattr(tvm_output, "numpy"):
+        tvm_output_np = tvm_output.numpy()
+    else:
+        tvm_output_np = tvm_output[0].numpy()
+    assert (
+        pytorch_output.shape == tvm_output_np.shape
+    ), f"Shape mismatch: PyTorch {pytorch_output.shape} vs TVM 
{tvm_output_np.shape}"
+    np.testing.assert_allclose(pytorch_output.numpy(), tvm_output_np, 
rtol=1e-4, atol=1e-5)
+
+    class SeqFirstGRU(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.gru = nn.GRU(
+                input_size=3,
+                hidden_size=6,
+                num_layers=1,
+                batch_first=False,
+                bidirectional=False,
+            )
+
+        def forward(self, x):
+            y, _ = self.gru(x)
+            return y
+
+    torch.manual_seed(43)
+    x2 = torch.randn(4, 2, 3, dtype=torch.float32)
+    model2 = SeqFirstGRU()
+    with torch.no_grad():
+        pytorch_output2 = model2(x2)
+    exported_program2 = export(model2, args=(x2,))
+    mod2 = from_exported_program(exported_program2)
+    ex2 = relax.build(mod2, target)
+    vm2 = relax.VirtualMachine(ex2, tvm.cpu())
+    x2_tvm = tvm.runtime.tensor(x2.numpy())
+    tvm_output2 = vm2["main"](x2_tvm)
+    if hasattr(tvm_output2, "numpy"):
+        tvm_output2_np = tvm_output2.numpy()
+    else:
+        tvm_output2_np = tvm_output2[0].numpy()
+    assert pytorch_output2.shape == tvm_output2_np.shape
+    np.testing.assert_allclose(pytorch_output2.numpy(), tvm_output2_np, 
rtol=1e-4, atol=1e-5)

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   The test function `test_gru` contains two very similar blocks of code for 
testing `BasicGRU` (`batch_first=True`) and `SeqFirstGRU` 
(`batch_first=False`). This code duplication makes the test harder to read and 
maintain. Consider refactoring the common testing logic into a helper function 
that can be called for both GRU configurations. This helper could take the 
model class, input data, and other relevant parameters as arguments.



##########
python/tvm/relax/frontend/torch/exported_program_translator.py:
##########
@@ -391,6 +391,300 @@ def _lstm(self, node: fx.Node) -> relax.Var:
             output = self.block_builder.emit(relax.op.permute_dims(output, 
axes=[1, 0, 2]))
         return output
 
+    def _gru(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        input_tensor = args[0]
+        hx = args[1] if len(args) > 1 else None
+        params = args[2] if len(args) > 2 else None
+        has_biases = args[3] if len(args) > 3 else True
+        num_layers = args[4] if len(args) > 4 else 1
+        _dropout = args[5] if len(args) > 5 else 0.0  # Not used in inference
+        _train = args[6] if len(args) > 6 else False  # Not used in inference
+        bidirectional = args[7] if len(args) > 7 else False
+        batch_first = args[8] if len(args) > 8 else False
+
+        if bidirectional:
+            raise NotImplementedError("Bidirectional GRU is not yet supported")
+
+        input_shape = self.shape_of(input_tensor)
+        if batch_first:
+            batch_size, seq_len, input_size = input_shape
+        else:
+            seq_len, batch_size, input_size = input_shape
+
+        if isinstance(seq_len, tvm.tir.IntImm):
+            seq_len = seq_len.value
+        if isinstance(batch_size, tvm.tir.IntImm):
+            batch_size = batch_size.value
+        if isinstance(input_size, tvm.tir.IntImm):
+            input_size = input_size.value
+
+        if params and len(params) >= 2:
+            # For multi-layer, we need to extract the first layer's weights
+            # to determine hidden size
+            if num_layers > 1:
+                # Multi-layer: params[0] is first layer's weight_ih
+                weight_ih = params[0]
+            else:
+                # Single layer: params[0] is weight_ih
+                weight_ih = params[0]

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   This `if/else` block is redundant as both branches execute the same code 
(`weight_ih = params[0]`). This can be simplified to improve code clarity and 
maintainability.
   
   ```python
               # For multi-layer, we need to extract the first layer's weights
               # to determine hidden size. params[0] is the first layer's
               # weight_ih for both single and multi-layer cases.
               weight_ih = params[0]
   ```



##########
tests/python/relax/test_frontend_from_exported_program.py:
##########
@@ -6050,5 +6050,76 @@ def main(
     verify_model(TensorNoneModel(), example_args, {}, Expected)
 
 
+def test_gru():
+    class BasicGRU(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.gru = nn.GRU(
+                input_size=4,
+                hidden_size=8,
+                num_layers=1,
+                batch_first=True,
+                bidirectional=False,
+            )
+
+        def forward(self, x):
+            y, _ = self.gru(x)
+            return y
+
+    torch.manual_seed(42)
+    x = torch.randn(2, 3, 4, dtype=torch.float32)
+    model = BasicGRU()
+    with torch.no_grad():
+        pytorch_output = model(x)
+    exported_program = export(model, args=(x,))
+    mod = from_exported_program(exported_program)
+    target = tvm.target.Target("llvm")
+    ex = relax.build(mod, target)
+    vm = relax.VirtualMachine(ex, tvm.cpu())
+    x_tvm = tvm.runtime.tensor(x.numpy())
+    tvm_output = vm["main"](x_tvm)
+    if hasattr(tvm_output, "numpy"):
+        tvm_output_np = tvm_output.numpy()
+    else:
+        tvm_output_np = tvm_output[0].numpy()
+    assert (
+        pytorch_output.shape == tvm_output_np.shape
+    ), f"Shape mismatch: PyTorch {pytorch_output.shape} vs TVM 
{tvm_output_np.shape}"
+    np.testing.assert_allclose(pytorch_output.numpy(), tvm_output_np, 
rtol=1e-4, atol=1e-5)
+
+    class SeqFirstGRU(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.gru = nn.GRU(
+                input_size=3,
+                hidden_size=6,
+                num_layers=1,
+                batch_first=False,
+                bidirectional=False,
+            )
+
+        def forward(self, x):
+            y, _ = self.gru(x)
+            return y
+
+    torch.manual_seed(43)
+    x2 = torch.randn(4, 2, 3, dtype=torch.float32)
+    model2 = SeqFirstGRU()
+    with torch.no_grad():
+        pytorch_output2 = model2(x2)
+    exported_program2 = export(model2, args=(x2,))
+    mod2 = from_exported_program(exported_program2)
+    ex2 = relax.build(mod2, target)
+    vm2 = relax.VirtualMachine(ex2, tvm.cpu())
+    x2_tvm = tvm.runtime.tensor(x2.numpy())
+    tvm_output2 = vm2["main"](x2_tvm)
+    if hasattr(tvm_output2, "numpy"):
+        tvm_output2_np = tvm_output2.numpy()
+    else:
+        tvm_output2_np = tvm_output2[0].numpy()
+    assert pytorch_output2.shape == tvm_output2_np.shape
+    np.testing.assert_allclose(pytorch_output2.numpy(), tvm_output2_np, 
rtol=1e-4, atol=1e-5)

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   The `_gru` implementation supports multi-layer GRUs, GRUs with an initial 
hidden state (`hx`), and GRUs with/without biases. However, the tests only 
cover single-layer GRUs without an initial hidden state and with biases. To 
ensure the implementation is robust and prevent future regressions, please add 
test cases for:
   
   - Multi-layer GRU (`num_layers > 1`).
   - GRU with a provided initial hidden state (`hx`).
   - GRU without biases (`bias=False` in `nn.GRU`).



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to