gemini-code-assist[bot] commented on code in PR #18346:
URL: https://github.com/apache/tvm/pull/18346#discussion_r2380853039


##########
python/tvm/relax/frontend/torch/exported_program_translator.py:
##########
@@ -231,6 +231,166 @@ def _upsample_bicubic2d(self, node: fx.node) -> relax.Var:
             align_corners=align_corners,
         )
 
+    def _lstm(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        input_tensor = args[0]
+        hx = args[1] if len(args) > 1 else None
+        params = args[2] if len(args) > 2 else None
+        has_biases = args[3] if len(args) > 3 else True
+        num_layers = args[4] if len(args) > 4 else 1
+        dropout = args[5] if len(args) > 5 else 0.0  # Not used in inference
+        train = args[6] if len(args) > 6 else False  # Not used in inference
+        bidirectional = args[7] if len(args) > 7 else False
+        batch_first = args[8] if len(args) > 8 else False
+        if bidirectional:
+            raise NotImplementedError("Bidirectional LSTM is not yet 
supported")
+        if num_layers > 1:
+            raise NotImplementedError("Multi-layer LSTM is not yet supported")
+        input_shape = self.shape_of(input_tensor)
+        if batch_first:
+            # Input shape: (batch, seq_len, input_size)
+            batch_size, seq_len, input_size = input_shape
+        else:
+            # Input shape: (seq_len, batch, input_size)
+            seq_len, batch_size, input_size = input_shape
+
+        if hasattr(seq_len, "value"):
+            seq_len = seq_len.value
+        if hasattr(batch_size, "value"):
+            batch_size = batch_size.value
+        if hasattr(input_size, "value"):
+            input_size = input_size.value

Review Comment:
   ![high](https://www.gstatic.com/codereviewagent/high-priority.svg)
   
   The logic to extract shape components such as `seq_len` does not correctly handle dynamic shapes. The `hasattr(seq_len, "value")` pattern only converts `tir.IntImm` dimensions to Python integers; a `tir.Var`, which represents a dynamic dimension, has no `.value` attribute, so it is silently left symbolic instead of being handled or rejected. Checking with `isinstance` makes the intent explicit.
   
   Additionally, the loop `for t in range(seq_len):` requires `seq_len` to be a 
compile-time constant, which prevents support for dynamic sequence lengths. The 
implementation should be updated to handle dynamic shapes, for example by using 
`relax.scan` instead of unrolling the loop.
   
   ```suggestion
           if isinstance(seq_len, tvm.tir.IntImm):
               seq_len = seq_len.value
           if isinstance(batch_size, tvm.tir.IntImm):
               batch_size = batch_size.value
           if isinstance(input_size, tvm.tir.IntImm):
               input_size = input_size.value
   ```



##########
python/tvm/relax/frontend/torch/base_fx_graph_translator.py:
##########
@@ -1997,6 +1997,12 @@ def _getitem(self, node: fx.Node) -> relax.Var:
                 return self.block_builder.emit(relax.TupleGetItem(x, node.args[1]))
 
             assert isinstance(x.struct_info, relax.TensorStructInfo)
+            if isinstance(node.args[1], int):
+                return x
+            if not isinstance(node.args[1], (list, tuple)):
+                indices = [node.args[1]]
+            else:
+                indices = node.args[1]

Review Comment:
   ![critical](https://www.gstatic.com/codereviewagent/critical.svg)
   
   The new logic for handling integer indexing on tensors is incorrect. It 
returns the tensor itself, which breaks the semantics of tensor indexing. For a 
tensor `x`, `x[0]` should return the first slice along axis 0, which has a 
reduced rank. The current implementation returns `x` unmodified. Since 
`_getitem` is a general-purpose function, this change can cause incorrect 
behavior for other operators that rely on it.
   
   This seems to be a workaround for an incomplete `_lstm` implementation. The 
correct fix should be in the `_lstm` operator implementation to return a proper 
tuple output, and this logic should be removed from `_getitem`.
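   
   If integer indexing on a tensor is kept here at all, a minimal sketch of the expected lowering (illustrative only, assuming `relax.op.take` with a scalar index drops the indexed axis, as this PR already relies on elsewhere) would be:
   
   ```python
               if isinstance(node.args[1], int):
                   # Take a single slice along axis 0 so the result has reduced rank,
                   # matching x[i] semantics, rather than returning x unchanged.
                   return self.block_builder.emit(
                       relax.op.take(x, relax.const(node.args[1], "int64"), axis=0)
                   )
   ```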



##########
python/tvm/relax/frontend/torch/exported_program_translator.py:
##########
@@ -231,6 +231,166 @@ def _upsample_bicubic2d(self, node: fx.node) -> relax.Var:
             align_corners=align_corners,
         )
 
+    def _lstm(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        input_tensor = args[0]
+        hx = args[1] if len(args) > 1 else None
+        params = args[2] if len(args) > 2 else None
+        has_biases = args[3] if len(args) > 3 else True
+        num_layers = args[4] if len(args) > 4 else 1
+        dropout = args[5] if len(args) > 5 else 0.0  # Not used in inference
+        train = args[6] if len(args) > 6 else False  # Not used in inference
+        bidirectional = args[7] if len(args) > 7 else False
+        batch_first = args[8] if len(args) > 8 else False
+        if bidirectional:
+            raise NotImplementedError("Bidirectional LSTM is not yet 
supported")
+        if num_layers > 1:
+            raise NotImplementedError("Multi-layer LSTM is not yet supported")
+        input_shape = self.shape_of(input_tensor)
+        if batch_first:
+            # Input shape: (batch, seq_len, input_size)
+            batch_size, seq_len, input_size = input_shape
+        else:
+            # Input shape: (seq_len, batch, input_size)
+            seq_len, batch_size, input_size = input_shape
+
+        if hasattr(seq_len, "value"):
+            seq_len = seq_len.value
+        if hasattr(batch_size, "value"):
+            batch_size = batch_size.value
+        if hasattr(input_size, "value"):
+            input_size = input_size.value
+        # Extract hidden size from the LSTM parameters
+        # The parameters are: [weight_ih, weight_hh, bias_ih, bias_hh]
+        # weight_ih shape: (4 * hidden_size, input_size)
+        # weight_hh shape: (4 * hidden_size, hidden_size)
+        if params and len(params) >= 2:
+            weight_ih = params[0]
+            weight_hh = params[1]
+            # Extract hidden size from weight dimensions
+            # weight_ih has shape (4 * hidden_size, input_size)
+            weight_ih_shape = self.shape_of(weight_ih)
+            hidden_size = weight_ih_shape[0] // 4  # 4 gates: input, forget, cell, output
+        else:
+            # Fallback to a default hidden size
+            hidden_size = 16
+        # Implement actual LSTM computation using  Relax operations
+        # LSTM equations:
+        # i_t = sigmoid(W_ii * x_t + b_ii + W_hi * h_{t-1} + b_hi)
+        # f_t = sigmoid(W_if * x_t + b_if + W_hf * h_{t-1} + b_hf)
+        # g_t = tanh(W_ig * x_t + b_ig + W_hg * h_{t-1} + b_hg)
+        # o_t = sigmoid(W_io * x_t + b_io + W_ho * h_{t-1} + b_ho)
+        # c_t = f_t * c_{t-1} + i_t * g_t
+        # h_t = o_t * tanh(c_t)
+        dtype = input_tensor.struct_info.dtype
+        if params and len(params) >= 4:
+            weight_ih = params[0]  # (4 * hidden_size, input_size)
+            weight_hh = params[1]  # (4 * hidden_size, hidden_size)
+            bias_ih = params[2] if has_biases else None  # (4 * hidden_size,)
+            bias_hh = params[3] if has_biases else None  # (4 * hidden_size,)
+        else:
+            # Fallback: create zero weights
+            weight_ih = self.block_builder.emit(
+                relax.op.zeros(relax.ShapeExpr((4 * hidden_size, input_size)), dtype)
+            )
+            weight_hh = self.block_builder.emit(
+                relax.op.zeros(relax.ShapeExpr((4 * hidden_size, hidden_size)), dtype)
+            )
+            bias_ih = None
+            bias_hh = None
+        # Initialize hidden and cell states
+        if hx is not None and len(hx) >= 2:
+            h_0 = hx[0]  # (num_layers, batch_size, hidden_size)
+            c_0 = hx[1]  # (num_layers, batch_size, hidden_size)
+            # Extract the first layer's hidden state
+            h_prev = self.block_builder.emit(
+                relax.op.take(h_0, relax.const(0, "int64"), axis=0, mode="clip")
+            )
+            c_prev = self.block_builder.emit(
+                relax.op.take(c_0, relax.const(0, "int64"), axis=0, mode="clip")
+            )
+        else:
+            h_prev = self.block_builder.emit(
+                relax.op.zeros(relax.ShapeExpr((batch_size, hidden_size)), dtype)
+            )
+            c_prev = self.block_builder.emit(
+                relax.op.zeros(relax.ShapeExpr((batch_size, hidden_size)), dtype)
+            )
+        # Reshape input for processing
+        if batch_first:
+            # Input: (batch, seq_len, input_size) -> (seq_len, batch, input_size)
+            input_reshaped = self.block_builder.emit(
+                relax.op.permute_dims(input_tensor, axes=[1, 0, 2])
+            )
+        else:
+            input_reshaped = input_tensor
+        weight_ih_t = self.block_builder.emit(relax.op.permute_dims(weight_ih, axes=[1, 0]))
+        weight_hh_t = self.block_builder.emit(relax.op.permute_dims(weight_hh, axes=[1, 0]))
+        outputs = []
+        for t in range(seq_len):
+            # Get input at time t: (batch_size, input_size)
+            x_t = self.block_builder.emit(
+                relax.op.take(input_reshaped, relax.const(t, "int64"), axis=0, mode="clip")
+            )
+            # Compute gates: W_ih * x_t + W_hh * h_{t-1} + bias
+            # Input-to-hidden: (batch_size, input_size) @ (4*hidden_size, input_size).T
+            ih_gates = self.block_builder.emit(relax.op.linear_algebra.matmul(x_t, weight_ih_t))
+
+            # Hidden-to-hidden: (batch_size, hidden_size) @ (4*hidden_size, hidden_size).T
+            hh_gates = self.block_builder.emit(relax.op.linear_algebra.matmul(h_prev, weight_hh_t))
+            # Add biases if present
+            if bias_ih is not None and bias_hh is not None:
+                gates = self.block_builder.emit(
+                    relax.op.add(relax.op.add(relax.op.add(ih_gates, bias_ih), hh_gates), bias_hh)
+                )
+            elif bias_ih is not None:
+                gates = self.block_builder.emit(
+                    relax.op.add(relax.op.add(ih_gates, bias_ih), hh_gates)
+                )
+            elif bias_hh is not None:
+                gates = self.block_builder.emit(
+                    relax.op.add(relax.op.add(ih_gates, hh_gates), bias_hh)
+                )
+            else:
+                gates = self.block_builder.emit(relax.op.add(ih_gates, hh_gates))
+            # Split gates: (batch_size, 4 * hidden_size) -> 4 x (batch_size, hidden_size)
+            gate_size = hidden_size
+            i_gate = self.block_builder.emit(
+                relax.op.strided_slice(gates, axes=[1], begin=[0], end=[gate_size])
+            )
+            f_gate = self.block_builder.emit(
+                relax.op.strided_slice(gates, axes=[1], begin=[gate_size], end=[2 * gate_size])
+            )
+            g_gate = self.block_builder.emit(
+                relax.op.strided_slice(gates, axes=[1], begin=[2 * gate_size], end=[3 * gate_size])
+            )
+            o_gate = self.block_builder.emit(
+                relax.op.strided_slice(gates, axes=[1], begin=[3 * gate_size], end=[4 * gate_size])
+            )

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   The four gates (input, forget, cell, output) are split from the concatenated 
gates tensor using four separate `strided_slice` operations. This can be done 
more efficiently and concisely using a single `relax.op.split` operation, which 
would also improve readability.
   
   ```python
               gate_tuple = self.block_builder.emit(relax.op.split(gates, 4, axis=1))
               i_gate = self.block_builder.emit(relax.TupleGetItem(gate_tuple, 0))
               f_gate = self.block_builder.emit(relax.TupleGetItem(gate_tuple, 1))
               g_gate = self.block_builder.emit(relax.TupleGetItem(gate_tuple, 2))
               o_gate = self.block_builder.emit(relax.TupleGetItem(gate_tuple, 3))
   ```



##########
python/tvm/relax/frontend/torch/exported_program_translator.py:
##########
@@ -231,6 +231,166 @@ def _upsample_bicubic2d(self, node: fx.node) -> relax.Var:
             align_corners=align_corners,
         )
 
+    def _lstm(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        input_tensor = args[0]
+        hx = args[1] if len(args) > 1 else None
+        params = args[2] if len(args) > 2 else None
+        has_biases = args[3] if len(args) > 3 else True
+        num_layers = args[4] if len(args) > 4 else 1
+        dropout = args[5] if len(args) > 5 else 0.0  # Not used in inference
+        train = args[6] if len(args) > 6 else False  # Not used in inference
+        bidirectional = args[7] if len(args) > 7 else False
+        batch_first = args[8] if len(args) > 8 else False
+        if bidirectional:
+            raise NotImplementedError("Bidirectional LSTM is not yet 
supported")
+        if num_layers > 1:
+            raise NotImplementedError("Multi-layer LSTM is not yet supported")
+        input_shape = self.shape_of(input_tensor)
+        if batch_first:
+            # Input shape: (batch, seq_len, input_size)
+            batch_size, seq_len, input_size = input_shape
+        else:
+            # Input shape: (seq_len, batch, input_size)
+            seq_len, batch_size, input_size = input_shape
+
+        if hasattr(seq_len, "value"):
+            seq_len = seq_len.value
+        if hasattr(batch_size, "value"):
+            batch_size = batch_size.value
+        if hasattr(input_size, "value"):
+            input_size = input_size.value
+        # Extract hidden size from the LSTM parameters
+        # The parameters are: [weight_ih, weight_hh, bias_ih, bias_hh]
+        # weight_ih shape: (4 * hidden_size, input_size)
+        # weight_hh shape: (4 * hidden_size, hidden_size)
+        if params and len(params) >= 2:
+            weight_ih = params[0]
+            weight_hh = params[1]
+            # Extract hidden size from weight dimensions
+            # weight_ih has shape (4 * hidden_size, input_size)
+            weight_ih_shape = self.shape_of(weight_ih)
+            hidden_size = weight_ih_shape[0] // 4  # 4 gates: input, forget, cell, output
+        else:
+            # Fallback to a default hidden size
+            hidden_size = 16

Review Comment:
   ![high](https://www.gstatic.com/codereviewagent/high-priority.svg)
   
   The fallback logic for when LSTM parameters are not provided is problematic. 
It defaults to `hidden_size = 16`. This can lead to silent correctness issues 
and hard-to-debug errors. It would be better to raise a `ValueError` if the 
parameters are not available to determine `hidden_size`, as a valid LSTM layer 
must have weights.
   
   ```suggestion
           else:
               raise ValueError("Cannot determine hidden_size. LSTM params 
(weights) are required.")
   ```



##########
python/tvm/relax/frontend/torch/exported_program_translator.py:
##########
@@ -231,6 +231,166 @@ def _upsample_bicubic2d(self, node: fx.node) -> relax.Var:
             align_corners=align_corners,
         )
 
+    def _lstm(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        input_tensor = args[0]
+        hx = args[1] if len(args) > 1 else None
+        params = args[2] if len(args) > 2 else None
+        has_biases = args[3] if len(args) > 3 else True
+        num_layers = args[4] if len(args) > 4 else 1
+        dropout = args[5] if len(args) > 5 else 0.0  # Not used in inference
+        train = args[6] if len(args) > 6 else False  # Not used in inference
+        bidirectional = args[7] if len(args) > 7 else False
+        batch_first = args[8] if len(args) > 8 else False
+        if bidirectional:
+            raise NotImplementedError("Bidirectional LSTM is not yet 
supported")
+        if num_layers > 1:
+            raise NotImplementedError("Multi-layer LSTM is not yet supported")
+        input_shape = self.shape_of(input_tensor)
+        if batch_first:
+            # Input shape: (batch, seq_len, input_size)
+            batch_size, seq_len, input_size = input_shape
+        else:
+            # Input shape: (seq_len, batch, input_size)
+            seq_len, batch_size, input_size = input_shape
+
+        if hasattr(seq_len, "value"):
+            seq_len = seq_len.value
+        if hasattr(batch_size, "value"):
+            batch_size = batch_size.value
+        if hasattr(input_size, "value"):
+            input_size = input_size.value
+        # Extract hidden size from the LSTM parameters
+        # The parameters are: [weight_ih, weight_hh, bias_ih, bias_hh]
+        # weight_ih shape: (4 * hidden_size, input_size)
+        # weight_hh shape: (4 * hidden_size, hidden_size)
+        if params and len(params) >= 2:
+            weight_ih = params[0]
+            weight_hh = params[1]
+            # Extract hidden size from weight dimensions
+            # weight_ih has shape (4 * hidden_size, input_size)
+            weight_ih_shape = self.shape_of(weight_ih)
+            hidden_size = weight_ih_shape[0] // 4  # 4 gates: input, forget, cell, output
+        else:
+            # Fallback to a default hidden size
+            hidden_size = 16
+        # Implement actual LSTM computation using  Relax operations
+        # LSTM equations:
+        # i_t = sigmoid(W_ii * x_t + b_ii + W_hi * h_{t-1} + b_hi)
+        # f_t = sigmoid(W_if * x_t + b_if + W_hf * h_{t-1} + b_hf)
+        # g_t = tanh(W_ig * x_t + b_ig + W_hg * h_{t-1} + b_hg)
+        # o_t = sigmoid(W_io * x_t + b_io + W_ho * h_{t-1} + b_ho)
+        # c_t = f_t * c_{t-1} + i_t * g_t
+        # h_t = o_t * tanh(c_t)
+        dtype = input_tensor.struct_info.dtype
+        if params and len(params) >= 4:
+            weight_ih = params[0]  # (4 * hidden_size, input_size)
+            weight_hh = params[1]  # (4 * hidden_size, hidden_size)
+            bias_ih = params[2] if has_biases else None  # (4 * hidden_size,)
+            bias_hh = params[3] if has_biases else None  # (4 * hidden_size,)
+        else:
+            # Fallback: create zero weights
+            weight_ih = self.block_builder.emit(
+                relax.op.zeros(relax.ShapeExpr((4 * hidden_size, input_size)), dtype)
+            )
+            weight_hh = self.block_builder.emit(
+                relax.op.zeros(relax.ShapeExpr((4 * hidden_size, hidden_size)), dtype)
+            )
+            )
+            bias_ih = None
+            bias_hh = None

Review Comment:
   ![high](https://www.gstatic.com/codereviewagent/high-priority.svg)
   
   Creating zero-tensors for weights as a fallback is problematic. This can 
lead to silent correctness issues where the model compiles but produces 
incorrect (zero) outputs. It's better to raise an error if weights are not 
provided, as they are essential for a functional LSTM layer.
   
   ```suggestion
           else:
               raise ValueError("LSTM params (weights) are required.")
   ```



##########
tests/python/relax/test_frontend_from_exported_program.py:
##########
@@ -5914,6 +5915,102 @@ def main(
     verify_model(Model(), example_args, {}, Expected)
 
 
+def test_mm():
+    class MatrixMultiply(Module):
+        def forward(self, a, b):
+            return torch.mm(a, b)
+
+    example_args = (
+        torch.randn(2, 3, dtype=torch.float32),
+        torch.randn(3, 4, dtype=torch.float32),
+    )
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            a: R.Tensor((2, 3), dtype="float32"),
+            b: R.Tensor((3, 4), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((2, 4), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((2, 4), dtype="float32") = R.matmul(a, b, out_dtype="float32")
+                gv: R.Tuple(R.Tensor((2, 4), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    verify_model(MatrixMultiply(), example_args, {}, Expected)
+
+
+def test_lstm():
+    class BasicLSTM(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.lstm = nn.LSTM(
+                input_size=4,
+                hidden_size=8,
+                num_layers=1,
+                batch_first=True,
+                bidirectional=False,
+            )
+
+        def forward(self, x):
+            y, _ = self.lstm(x)
+            return y
+
+    torch.manual_seed(42)
+    x = torch.randn(2, 3, 4, dtype=torch.float32)
+    model = BasicLSTM()
+    with torch.no_grad():
+        pytorch_output = model(x)
+    exported_program = export(model, args=(x,))
+    mod = from_exported_program(exported_program)
+    target = tvm.target.Target("llvm")
+    ex = relax.build(mod, target)
+    vm = relax.VirtualMachine(ex, tvm.cpu())
+    x_tvm = tvm.runtime.tensor(x.numpy())
+    tvm_output = vm["main"](x_tvm)
+    if hasattr(tvm_output, "numpy"):
+        tvm_output_np = tvm_output.numpy()
+    else:
+        tvm_output_np = tvm_output[0].numpy()
+    assert (
+        pytorch_output.shape == tvm_output_np.shape
+    ), f"Shape mismatch: PyTorch {pytorch_output.shape} vs TVM 
{tvm_output_np.shape}"
+    np.testing.assert_allclose(pytorch_output.numpy(), tvm_output_np, 
rtol=1e-4, atol=1e-5)
+
+    class SeqFirstLSTM(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.lstm = nn.LSTM(
+                input_size=3,
+                hidden_size=6,
+                num_layers=1,
+                batch_first=False,
+                bidirectional=False,
+            )
+
+        def forward(self, x):
+            y, _ = self.lstm(x)
+            return y
+
+    torch.manual_seed(43)
+    x2 = torch.randn(4, 2, 3, dtype=torch.float32)
+    model2 = SeqFirstLSTM()
+    with torch.no_grad():
+        pytorch_output2 = model2(x2)
+    exported_program2 = export(model2, args=(x2,))
+    mod2 = from_exported_program(exported_program2)
+    ex2 = relax.build(mod2, target)
+    vm2 = relax.VirtualMachine(ex2, tvm.cpu())
+    x2_tvm = tvm.runtime.tensor(x2.numpy())
+    tvm_output2 = vm2["main"](x2_tvm)
+    if hasattr(tvm_output2, "numpy"):
+        tvm_output2_np = tvm_output2.numpy()
+    else:
+        tvm_output2_np = tvm_output2[0].numpy()
+    assert pytorch_output2.shape == tvm_output2_np.shape
+    np.testing.assert_allclose(pytorch_output2.numpy(), tvm_output2_np, rtol=1e-4, atol=1e-5)

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   The `test_lstm` function contains a significant amount of duplicated code 
for testing the `batch_first=True` and `batch_first=False` cases. This can be 
refactored into a helper function to improve readability and maintainability. 
The helper function could take the model and input tensor as arguments and 
perform the verification logic.
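   
   A possible shape for such a helper, distilled from the two cases above (the name `_verify_lstm` is illustrative, not part of the PR), is:
   
   ```python
   def _verify_lstm(model, x, rtol=1e-4, atol=1e-5):
       """Export the model, build it with Relax, and compare against PyTorch."""
       with torch.no_grad():
           pytorch_output = model(x)
       exported_program = export(model, args=(x,))
       mod = from_exported_program(exported_program)
       ex = relax.build(mod, tvm.target.Target("llvm"))
       vm = relax.VirtualMachine(ex, tvm.cpu())
       tvm_output = vm["main"](tvm.runtime.tensor(x.numpy()))
       tvm_output_np = tvm_output.numpy() if hasattr(tvm_output, "numpy") else tvm_output[0].numpy()
       assert pytorch_output.shape == tvm_output_np.shape
       np.testing.assert_allclose(pytorch_output.numpy(), tvm_output_np, rtol=rtol, atol=atol)
   ```
   
   Each case then reduces to constructing the module and calling, e.g., `_verify_lstm(BasicLSTM(), torch.randn(2, 3, 4))`.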



##########
python/tvm/relax/frontend/torch/exported_program_translator.py:
##########
@@ -231,6 +231,166 @@ def _upsample_bicubic2d(self, node: fx.node) -> relax.Var:
             align_corners=align_corners,
         )
 
+    def _lstm(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        input_tensor = args[0]
+        hx = args[1] if len(args) > 1 else None
+        params = args[2] if len(args) > 2 else None
+        has_biases = args[3] if len(args) > 3 else True
+        num_layers = args[4] if len(args) > 4 else 1
+        dropout = args[5] if len(args) > 5 else 0.0  # Not used in inference
+        train = args[6] if len(args) > 6 else False  # Not used in inference
+        bidirectional = args[7] if len(args) > 7 else False
+        batch_first = args[8] if len(args) > 8 else False
+        if bidirectional:
+            raise NotImplementedError("Bidirectional LSTM is not yet 
supported")
+        if num_layers > 1:
+            raise NotImplementedError("Multi-layer LSTM is not yet supported")
+        input_shape = self.shape_of(input_tensor)
+        if batch_first:
+            # Input shape: (batch, seq_len, input_size)
+            batch_size, seq_len, input_size = input_shape
+        else:
+            # Input shape: (seq_len, batch, input_size)
+            seq_len, batch_size, input_size = input_shape
+
+        if hasattr(seq_len, "value"):
+            seq_len = seq_len.value
+        if hasattr(batch_size, "value"):
+            batch_size = batch_size.value
+        if hasattr(input_size, "value"):
+            input_size = input_size.value
+        # Extract hidden size from the LSTM parameters
+        # The parameters are: [weight_ih, weight_hh, bias_ih, bias_hh]
+        # weight_ih shape: (4 * hidden_size, input_size)
+        # weight_hh shape: (4 * hidden_size, hidden_size)
+        if params and len(params) >= 2:
+            weight_ih = params[0]
+            weight_hh = params[1]
+            # Extract hidden size from weight dimensions
+            # weight_ih has shape (4 * hidden_size, input_size)
+            weight_ih_shape = self.shape_of(weight_ih)
+            hidden_size = weight_ih_shape[0] // 4  # 4 gates: input, forget, cell, output
+        else:
+            # Fallback to a default hidden size
+            hidden_size = 16
+        # Implement actual LSTM computation using  Relax operations
+        # LSTM equations:
+        # i_t = sigmoid(W_ii * x_t + b_ii + W_hi * h_{t-1} + b_hi)
+        # f_t = sigmoid(W_if * x_t + b_if + W_hf * h_{t-1} + b_hf)
+        # g_t = tanh(W_ig * x_t + b_ig + W_hg * h_{t-1} + b_hg)
+        # o_t = sigmoid(W_io * x_t + b_io + W_ho * h_{t-1} + b_ho)
+        # c_t = f_t * c_{t-1} + i_t * g_t
+        # h_t = o_t * tanh(c_t)
+        dtype = input_tensor.struct_info.dtype
+        if params and len(params) >= 4:
+            weight_ih = params[0]  # (4 * hidden_size, input_size)
+            weight_hh = params[1]  # (4 * hidden_size, hidden_size)
+            bias_ih = params[2] if has_biases else None  # (4 * hidden_size,)
+            bias_hh = params[3] if has_biases else None  # (4 * hidden_size,)
+        else:
+            # Fallback: create zero weights
+            weight_ih = self.block_builder.emit(
+                relax.op.zeros(relax.ShapeExpr((4 * hidden_size, input_size)), dtype)
+            )
+            weight_hh = self.block_builder.emit(
+                relax.op.zeros(relax.ShapeExpr((4 * hidden_size, hidden_size)), dtype)
+            )
+            bias_ih = None
+            bias_hh = None
+        # Initialize hidden and cell states
+        if hx is not None and len(hx) >= 2:
+            h_0 = hx[0]  # (num_layers, batch_size, hidden_size)
+            c_0 = hx[1]  # (num_layers, batch_size, hidden_size)
+            # Extract the first layer's hidden state
+            h_prev = self.block_builder.emit(
+                relax.op.take(h_0, relax.const(0, "int64"), axis=0, mode="clip")
+            )
+            c_prev = self.block_builder.emit(
+                relax.op.take(c_0, relax.const(0, "int64"), axis=0, mode="clip")
+            )
+        else:
+            h_prev = self.block_builder.emit(
+                relax.op.zeros(relax.ShapeExpr((batch_size, hidden_size)), dtype)
+            )
+            c_prev = self.block_builder.emit(
+                relax.op.zeros(relax.ShapeExpr((batch_size, hidden_size)), dtype)
+            )
+        # Reshape input for processing
+        if batch_first:
+            # Input: (batch, seq_len, input_size) -> (seq_len, batch, input_size)
+            input_reshaped = self.block_builder.emit(
+                relax.op.permute_dims(input_tensor, axes=[1, 0, 2])
+            )
+        else:
+            input_reshaped = input_tensor
+        weight_ih_t = self.block_builder.emit(relax.op.permute_dims(weight_ih, axes=[1, 0]))
+        weight_hh_t = self.block_builder.emit(relax.op.permute_dims(weight_hh, axes=[1, 0]))
+        outputs = []
+        for t in range(seq_len):
+            # Get input at time t: (batch_size, input_size)
+            x_t = self.block_builder.emit(
+                relax.op.take(input_reshaped, relax.const(t, "int64"), axis=0, mode="clip")
+            )
+            # Compute gates: W_ih * x_t + W_hh * h_{t-1} + bias
+            # Input-to-hidden: (batch_size, input_size) @ (4*hidden_size, input_size).T
+            ih_gates = self.block_builder.emit(relax.op.linear_algebra.matmul(x_t, weight_ih_t))
+
+            # Hidden-to-hidden: (batch_size, hidden_size) @ (4*hidden_size, hidden_size).T
+            hh_gates = self.block_builder.emit(relax.op.linear_algebra.matmul(h_prev, weight_hh_t))
+            # Add biases if present
+            if bias_ih is not None and bias_hh is not None:
+                gates = self.block_builder.emit(
+                    relax.op.add(relax.op.add(relax.op.add(ih_gates, bias_ih), hh_gates), bias_hh)
+                )
+            elif bias_ih is not None:
+                gates = self.block_builder.emit(
+                    relax.op.add(relax.op.add(ih_gates, bias_ih), hh_gates)
+                )
+            elif bias_hh is not None:
+                gates = self.block_builder.emit(
+                    relax.op.add(relax.op.add(ih_gates, hh_gates), bias_hh)
+                )
+            else:
+                gates = self.block_builder.emit(relax.op.add(ih_gates, hh_gates))
+            # Split gates: (batch_size, 4 * hidden_size) -> 4 x (batch_size, hidden_size)
+            gate_size = hidden_size
+            i_gate = self.block_builder.emit(
+                relax.op.strided_slice(gates, axes=[1], begin=[0], end=[gate_size])
+            )
+            f_gate = self.block_builder.emit(
+                relax.op.strided_slice(gates, axes=[1], begin=[gate_size], end=[2 * gate_size])
+            )
+            g_gate = self.block_builder.emit(
+                relax.op.strided_slice(gates, axes=[1], begin=[2 * gate_size], end=[3 * gate_size])
+            )
+            o_gate = self.block_builder.emit(
+                relax.op.strided_slice(gates, axes=[1], begin=[3 * gate_size], end=[4 * gate_size])
+            )
+            # Apply activations
+            i_t = self.block_builder.emit(relax.op.sigmoid(i_gate))
+            f_t = self.block_builder.emit(relax.op.sigmoid(f_gate))
+            g_t = self.block_builder.emit(relax.op.tanh(g_gate))
+            o_t = self.block_builder.emit(relax.op.sigmoid(o_gate))
+            # Update cell state: c_t = f_t * c_{t-1} + i_t * g_t
+            c_t = self.block_builder.emit(
+                relax.op.add(relax.op.multiply(f_t, c_prev), relax.op.multiply(i_t, g_t))
+            )
+            # Update hidden state: h_t = o_t * tanh(c_t)
+            h_t = self.block_builder.emit(relax.op.multiply(o_t, relax.op.tanh(c_t)))
+            # Store output
+            outputs.append(h_t)
+            # Update for next iteration
+            h_prev = h_t
+            c_prev = c_t
+        # Stack outputs: (seq_len, batch_size, hidden_size)
+        output = self.block_builder.emit(relax.op.stack(outputs, axis=0))
+        # Reshape back to batch_first if needed
+        if batch_first:
+            # (seq_len, batch_size, hidden_size) -> (batch_size, seq_len, hidden_size)
+            output = self.block_builder.emit(relax.op.permute_dims(output, axes=[1, 0, 2]))
+        return output

Review Comment:
   ![high](https://www.gstatic.com/codereviewagent/high-priority.svg)
   
   The `_lstm` implementation is incomplete. It only returns the output 
sequence but not the final hidden and cell states, which are part of the 
standard `torch.nn.LSTM` output `(output, (h_n, c_n))`. This will lead to 
incorrect behavior for models that use these states. The function should be 
updated to return a tuple containing the output sequence and the final 
hidden/cell states to fully match the PyTorch operator's behavior.
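   
   As a rough sketch of what that could look like (illustrative only, assuming a single layer and direction so the final states only gain a leading axis of size 1), the end of `_lstm` could emit something along these lines:
   
   ```python
           # Final states shaped (num_layers * num_directions, batch, hidden_size)
           h_n = self.block_builder.emit(relax.op.expand_dims(h_prev, axis=0))
           c_n = self.block_builder.emit(relax.op.expand_dims(c_prev, axis=0))
           # aten::lstm produces (output, h_n, c_n); return them as a Relax tuple
           return self.block_builder.emit(relax.Tuple([output, h_n, c_n]))
   ```
   
   Returning a real tuple here would also remove the need for the integer-indexing workaround in `_getitem`.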



##########
python/tvm/relax/frontend/torch/exported_program_translator.py:
##########
@@ -231,6 +231,166 @@ def _upsample_bicubic2d(self, node: fx.node) -> relax.Var:
             align_corners=align_corners,
         )
 
+    def _lstm(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        input_tensor = args[0]
+        hx = args[1] if len(args) > 1 else None
+        params = args[2] if len(args) > 2 else None
+        has_biases = args[3] if len(args) > 3 else True
+        num_layers = args[4] if len(args) > 4 else 1
+        dropout = args[5] if len(args) > 5 else 0.0  # Not used in inference
+        train = args[6] if len(args) > 6 else False  # Not used in inference
+        bidirectional = args[7] if len(args) > 7 else False
+        batch_first = args[8] if len(args) > 8 else False
+        if bidirectional:
+            raise NotImplementedError("Bidirectional LSTM is not yet 
supported")
+        if num_layers > 1:
+            raise NotImplementedError("Multi-layer LSTM is not yet supported")
+        input_shape = self.shape_of(input_tensor)
+        if batch_first:
+            # Input shape: (batch, seq_len, input_size)
+            batch_size, seq_len, input_size = input_shape
+        else:
+            # Input shape: (seq_len, batch, input_size)
+            seq_len, batch_size, input_size = input_shape
+
+        if hasattr(seq_len, "value"):
+            seq_len = seq_len.value
+        if hasattr(batch_size, "value"):
+            batch_size = batch_size.value
+        if hasattr(input_size, "value"):
+            input_size = input_size.value
+        # Extract hidden size from the LSTM parameters
+        # The parameters are: [weight_ih, weight_hh, bias_ih, bias_hh]
+        # weight_ih shape: (4 * hidden_size, input_size)
+        # weight_hh shape: (4 * hidden_size, hidden_size)
+        if params and len(params) >= 2:
+            weight_ih = params[0]
+            weight_hh = params[1]
+            # Extract hidden size from weight dimensions
+            # weight_ih has shape (4 * hidden_size, input_size)
+            weight_ih_shape = self.shape_of(weight_ih)
+            hidden_size = weight_ih_shape[0] // 4  # 4 gates: input, forget, cell, output
+        else:
+            # Fallback to a default hidden size
+            hidden_size = 16
+        # Implement actual LSTM computation using  Relax operations
+        # LSTM equations:
+        # i_t = sigmoid(W_ii * x_t + b_ii + W_hi * h_{t-1} + b_hi)
+        # f_t = sigmoid(W_if * x_t + b_if + W_hf * h_{t-1} + b_hf)
+        # g_t = tanh(W_ig * x_t + b_ig + W_hg * h_{t-1} + b_hg)
+        # o_t = sigmoid(W_io * x_t + b_io + W_ho * h_{t-1} + b_ho)
+        # c_t = f_t * c_{t-1} + i_t * g_t
+        # h_t = o_t * tanh(c_t)
+        dtype = input_tensor.struct_info.dtype
+        if params and len(params) >= 4:
+            weight_ih = params[0]  # (4 * hidden_size, input_size)
+            weight_hh = params[1]  # (4 * hidden_size, hidden_size)
+            bias_ih = params[2] if has_biases else None  # (4 * hidden_size,)
+            bias_hh = params[3] if has_biases else None  # (4 * hidden_size,)
+        else:
+            # Fallback: create zero weights
+            weight_ih = self.block_builder.emit(
+                relax.op.zeros(relax.ShapeExpr((4 * hidden_size, input_size)), dtype)
+            )
+            weight_hh = self.block_builder.emit(
+                relax.op.zeros(relax.ShapeExpr((4 * hidden_size, hidden_size)), dtype)
+            )
+            bias_ih = None
+            bias_hh = None
+        # Initialize hidden and cell states
+        if hx is not None and len(hx) >= 2:
+            h_0 = hx[0]  # (num_layers, batch_size, hidden_size)
+            c_0 = hx[1]  # (num_layers, batch_size, hidden_size)
+            # Extract the first layer's hidden state
+            h_prev = self.block_builder.emit(
+                relax.op.take(h_0, relax.const(0, "int64"), axis=0, mode="clip")
+            )
+            c_prev = self.block_builder.emit(
+                relax.op.take(c_0, relax.const(0, "int64"), axis=0, mode="clip")
+            )
+        else:
+            h_prev = self.block_builder.emit(
+                relax.op.zeros(relax.ShapeExpr((batch_size, hidden_size)), dtype)
+            )
+            c_prev = self.block_builder.emit(
+                relax.op.zeros(relax.ShapeExpr((batch_size, hidden_size)), dtype)
+            )
+        # Reshape input for processing
+        if batch_first:
+            # Input: (batch, seq_len, input_size) -> (seq_len, batch, input_size)
+            input_reshaped = self.block_builder.emit(
+                relax.op.permute_dims(input_tensor, axes=[1, 0, 2])
+            )
+        else:
+            input_reshaped = input_tensor
+        weight_ih_t = self.block_builder.emit(relax.op.permute_dims(weight_ih, axes=[1, 0]))
+        weight_hh_t = self.block_builder.emit(relax.op.permute_dims(weight_hh, axes=[1, 0]))
+        outputs = []
+        for t in range(seq_len):
+            # Get input at time t: (batch_size, input_size)
+            x_t = self.block_builder.emit(
+                relax.op.take(input_reshaped, relax.const(t, "int64"), axis=0, mode="clip")
+            )
+            # Compute gates: W_ih * x_t + W_hh * h_{t-1} + bias
+            # Input-to-hidden: (batch_size, input_size) @ (4*hidden_size, input_size).T
+            ih_gates = self.block_builder.emit(relax.op.linear_algebra.matmul(x_t, weight_ih_t))
+
+            # Hidden-to-hidden: (batch_size, hidden_size) @ (4*hidden_size, hidden_size).T
+            hh_gates = self.block_builder.emit(relax.op.linear_algebra.matmul(h_prev, weight_hh_t))
+            # Add biases if present
+            if bias_ih is not None and bias_hh is not None:
+                gates = self.block_builder.emit(
+                    relax.op.add(relax.op.add(relax.op.add(ih_gates, bias_ih), hh_gates), bias_hh)
+                )
+            elif bias_ih is not None:
+                gates = self.block_builder.emit(
+                    relax.op.add(relax.op.add(ih_gates, bias_ih), hh_gates)
+                )
+            elif bias_hh is not None:
+                gates = self.block_builder.emit(
+                    relax.op.add(relax.op.add(ih_gates, hh_gates), bias_hh)
+                )
+            else:
+                gates = self.block_builder.emit(relax.op.add(ih_gates, hh_gates))

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   The bias-addition logic is verbose, with four `if/elif/else` branches. It can be simplified for readability and maintainability: compute the sum of the input and hidden projections first, then conditionally add each bias.
   
   ```suggestion
               gates = self.block_builder.emit(relax.op.add(ih_gates, hh_gates))
               if bias_ih is not None:
                   gates = self.block_builder.emit(relax.op.add(gates, bias_ih))
               if bias_hh is not None:
                   gates = self.block_builder.emit(relax.op.add(gates, bias_hh))
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

