gemini-code-assist[bot] commented on code in PR #18346:
URL: https://github.com/apache/tvm/pull/18346#discussion_r2380853039
##########
python/tvm/relax/frontend/torch/exported_program_translator.py:
##########
@@ -231,6 +231,166 @@ def _upsample_bicubic2d(self, node: fx.node) -> relax.Var:
align_corners=align_corners,
)
+ def _lstm(self, node: fx.Node) -> relax.Var:
+ args = self.retrieve_args(node)
+ input_tensor = args[0]
+ hx = args[1] if len(args) > 1 else None
+ params = args[2] if len(args) > 2 else None
+ has_biases = args[3] if len(args) > 3 else True
+ num_layers = args[4] if len(args) > 4 else 1
+ dropout = args[5] if len(args) > 5 else 0.0 # Not used in inference
+ train = args[6] if len(args) > 6 else False # Not used in inference
+ bidirectional = args[7] if len(args) > 7 else False
+ batch_first = args[8] if len(args) > 8 else False
+ if bidirectional:
+ raise NotImplementedError("Bidirectional LSTM is not yet supported")
+ if num_layers > 1:
+ raise NotImplementedError("Multi-layer LSTM is not yet supported")
+ input_shape = self.shape_of(input_tensor)
+ if batch_first:
+ # Input shape: (batch, seq_len, input_size)
+ batch_size, seq_len, input_size = input_shape
+ else:
+ # Input shape: (seq_len, batch, input_size)
+ seq_len, batch_size, input_size = input_shape
+
+ if hasattr(seq_len, "value"):
+ seq_len = seq_len.value
+ if hasattr(batch_size, "value"):
+ batch_size = batch_size.value
+ if hasattr(input_size, "value"):
+ input_size = input_size.value
Review Comment:

The logic to extract shape components like `seq_len` does not correctly
handle dynamic shapes. Using `hasattr(seq_len, "value")` and then accessing
`.value` only works for `tir.IntImm`, but not for `tir.Var` which represents a
dynamic dimension. This will cause an `AttributeError` when tracing models with
dynamic sequence lengths.
Additionally, the loop `for t in range(seq_len):` requires `seq_len` to be a
compile-time constant, which prevents support for dynamic sequence lengths. The
implementation should be updated to handle dynamic shapes, for example by using
`relax.scan` instead of unrolling the loop.
```suggestion
if isinstance(seq_len, tvm.tir.IntImm):
    seq_len = seq_len.value
if isinstance(batch_size, tvm.tir.IntImm):
    batch_size = batch_size.value
if isinstance(input_size, tvm.tir.IntImm):
    input_size = input_size.value
```
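Until dynamic shapes are supported, an explicit guard would at least fail loudly instead of raising an `AttributeError` deep inside the translator. A minimal sketch (assuming `self.shape_of` returns `tir.PrimExpr` entries, so a dynamic dimension appears as a `tvm.tir.Var`):
```python
import tvm

# Sketch: reject dynamic sequence lengths explicitly until a Relax-level loop is implemented.
if not isinstance(seq_len, tvm.tir.IntImm):
    raise NotImplementedError("LSTM with a dynamic sequence length is not yet supported")
seq_len = int(seq_len.value)
```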
##########
python/tvm/relax/frontend/torch/base_fx_graph_translator.py:
##########
@@ -1997,6 +1997,12 @@ def _getitem(self, node: fx.Node) -> relax.Var:
return self.block_builder.emit(relax.TupleGetItem(x, node.args[1]))
assert isinstance(x.struct_info, relax.TensorStructInfo)
+ if isinstance(node.args[1], int):
+ return x
+ if not isinstance(node.args[1], (list, tuple)):
+ indices = [node.args[1]]
+ else:
+ indices = node.args[1]
Review Comment:

The new logic for handling integer indexing on tensors is incorrect. It
returns the tensor itself, which breaks the semantics of tensor indexing. For a
tensor `x`, `x[0]` should return the first slice along axis 0, which has a
reduced rank. The current implementation returns `x` unmodified. Since
`_getitem` is a general-purpose function, this change can cause incorrect
behavior for other operators that rely on it.
This seems to be a workaround for an incomplete `_lstm` implementation. The
correct fix should be in the `_lstm` operator implementation to return a proper
tuple output, and this logic should be removed from `_getitem`.
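If integer indexing on tensors is to be supported here at all, a hedged sketch of the expected lowering (reusing the `relax.op.take` pattern this PR already applies to `h_0`/`c_0`, where a scalar index drops the indexed axis) would be:
```python
# Sketch only: lower x[i] to a scalar take along axis 0, which returns the
# reduced-rank slice instead of the unmodified tensor.
if isinstance(node.args[1], int):
    return self.block_builder.emit(
        relax.op.take(x, relax.const(node.args[1], "int64"), axis=0)
    )
```
That said, if `_lstm` emits a proper tuple result (see the comment on its return value below), this branch should simply be removed.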
##########
python/tvm/relax/frontend/torch/exported_program_translator.py:
##########
@@ -231,6 +231,166 @@ def _upsample_bicubic2d(self, node: fx.node) -> relax.Var:
align_corners=align_corners,
)
+ def _lstm(self, node: fx.Node) -> relax.Var:
+ args = self.retrieve_args(node)
+ input_tensor = args[0]
+ hx = args[1] if len(args) > 1 else None
+ params = args[2] if len(args) > 2 else None
+ has_biases = args[3] if len(args) > 3 else True
+ num_layers = args[4] if len(args) > 4 else 1
+ dropout = args[5] if len(args) > 5 else 0.0 # Not used in inference
+ train = args[6] if len(args) > 6 else False # Not used in inference
+ bidirectional = args[7] if len(args) > 7 else False
+ batch_first = args[8] if len(args) > 8 else False
+ if bidirectional:
+ raise NotImplementedError("Bidirectional LSTM is not yet supported")
+ if num_layers > 1:
+ raise NotImplementedError("Multi-layer LSTM is not yet supported")
+ input_shape = self.shape_of(input_tensor)
+ if batch_first:
+ # Input shape: (batch, seq_len, input_size)
+ batch_size, seq_len, input_size = input_shape
+ else:
+ # Input shape: (seq_len, batch, input_size)
+ seq_len, batch_size, input_size = input_shape
+
+ if hasattr(seq_len, "value"):
+ seq_len = seq_len.value
+ if hasattr(batch_size, "value"):
+ batch_size = batch_size.value
+ if hasattr(input_size, "value"):
+ input_size = input_size.value
+ # Extract hidden size from the LSTM parameters
+ # The parameters are: [weight_ih, weight_hh, bias_ih, bias_hh]
+ # weight_ih shape: (4 * hidden_size, input_size)
+ # weight_hh shape: (4 * hidden_size, hidden_size)
+ if params and len(params) >= 2:
+ weight_ih = params[0]
+ weight_hh = params[1]
+ # Extract hidden size from weight dimensions
+ # weight_ih has shape (4 * hidden_size, input_size)
+ weight_ih_shape = self.shape_of(weight_ih)
+ hidden_size = weight_ih_shape[0] // 4 # 4 gates: input, forget, cell, output
+ else:
+ # Fallback to a default hidden size
+ hidden_size = 16
+ # Implement actual LSTM computation using Relax operations
+ # LSTM equations:
+ # i_t = sigmoid(W_ii * x_t + b_ii + W_hi * h_{t-1} + b_hi)
+ # f_t = sigmoid(W_if * x_t + b_if + W_hf * h_{t-1} + b_hf)
+ # g_t = tanh(W_ig * x_t + b_ig + W_hg * h_{t-1} + b_hg)
+ # o_t = sigmoid(W_io * x_t + b_io + W_ho * h_{t-1} + b_ho)
+ # c_t = f_t * c_{t-1} + i_t * g_t
+ # h_t = o_t * tanh(c_t)
+ dtype = input_tensor.struct_info.dtype
+ if params and len(params) >= 4:
+ weight_ih = params[0] # (4 * hidden_size, input_size)
+ weight_hh = params[1] # (4 * hidden_size, hidden_size)
+ bias_ih = params[2] if has_biases else None # (4 * hidden_size,)
+ bias_hh = params[3] if has_biases else None # (4 * hidden_size,)
+ else:
+ # Fallback: create zero weights
+ weight_ih = self.block_builder.emit(
+ relax.op.zeros(relax.ShapeExpr((4 * hidden_size, input_size)), dtype)
+ )
+ weight_hh = self.block_builder.emit(
+ relax.op.zeros(relax.ShapeExpr((4 * hidden_size, hidden_size)), dtype)
+ )
+ bias_ih = None
+ bias_hh = None
+ # Initialize hidden and cell states
+ if hx is not None and len(hx) >= 2:
+ h_0 = hx[0] # (num_layers, batch_size, hidden_size)
+ c_0 = hx[1] # (num_layers, batch_size, hidden_size)
+ # Extract the first layer's hidden state
+ h_prev = self.block_builder.emit(
+ relax.op.take(h_0, relax.const(0, "int64"), axis=0, mode="clip")
+ )
+ c_prev = self.block_builder.emit(
+ relax.op.take(c_0, relax.const(0, "int64"), axis=0, mode="clip")
+ )
+ else:
+ h_prev = self.block_builder.emit(
+ relax.op.zeros(relax.ShapeExpr((batch_size, hidden_size)), dtype)
+ )
+ c_prev = self.block_builder.emit(
+ relax.op.zeros(relax.ShapeExpr((batch_size, hidden_size)), dtype)
+ )
+ # Reshape input for processing
+ if batch_first:
+ # Input: (batch, seq_len, input_size) -> (seq_len, batch, input_size)
+ input_reshaped = self.block_builder.emit(
+ relax.op.permute_dims(input_tensor, axes=[1, 0, 2])
+ )
+ else:
+ input_reshaped = input_tensor
+ weight_ih_t = self.block_builder.emit(relax.op.permute_dims(weight_ih, axes=[1, 0]))
+ weight_hh_t = self.block_builder.emit(relax.op.permute_dims(weight_hh, axes=[1, 0]))
+ outputs = []
+ for t in range(seq_len):
+ # Get input at time t: (batch_size, input_size)
+ x_t = self.block_builder.emit(
+ relax.op.take(input_reshaped, relax.const(t, "int64"), axis=0, mode="clip")
+ )
+ # Compute gates: W_ih * x_t + W_hh * h_{t-1} + bias
+ # Input-to-hidden: (batch_size, input_size) @ (4*hidden_size, input_size).T
+ ih_gates = self.block_builder.emit(relax.op.linear_algebra.matmul(x_t, weight_ih_t))
+
+ # Hidden-to-hidden: (batch_size, hidden_size) @ (4*hidden_size, hidden_size).T
+ hh_gates = self.block_builder.emit(relax.op.linear_algebra.matmul(h_prev, weight_hh_t))
+ # Add biases if present
+ if bias_ih is not None and bias_hh is not None:
+ gates = self.block_builder.emit(
+ relax.op.add(relax.op.add(relax.op.add(ih_gates, bias_ih), hh_gates), bias_hh)
+ )
+ elif bias_ih is not None:
+ gates = self.block_builder.emit(
+ relax.op.add(relax.op.add(ih_gates, bias_ih), hh_gates)
+ )
+ elif bias_hh is not None:
+ gates = self.block_builder.emit(
+ relax.op.add(relax.op.add(ih_gates, hh_gates), bias_hh)
+ )
+ else:
+ gates = self.block_builder.emit(relax.op.add(ih_gates, hh_gates))
+ # Split gates: (batch_size, 4 * hidden_size) -> 4 x (batch_size, hidden_size)
+ gate_size = hidden_size
+ i_gate = self.block_builder.emit(
+ relax.op.strided_slice(gates, axes=[1], begin=[0], end=[gate_size])
+ )
+ f_gate = self.block_builder.emit(
+ relax.op.strided_slice(gates, axes=[1], begin=[gate_size], end=[2 * gate_size])
+ )
+ g_gate = self.block_builder.emit(
+ relax.op.strided_slice(gates, axes=[1], begin=[2 * gate_size], end=[3 * gate_size])
+ )
+ o_gate = self.block_builder.emit(
+ relax.op.strided_slice(gates, axes=[1], begin=[3 * gate_size], end=[4 * gate_size])
+ )
Review Comment:

The four gates (input, forget, cell, output) are split from the concatenated
gates tensor using four separate `strided_slice` operations. This can be done
more efficiently and concisely using a single `relax.op.split` operation, which
would also improve readability.
```python
gate_tuple = self.block_builder.emit(relax.op.split(gates, 4, axis=1))
i_gate = self.block_builder.emit(relax.TupleGetItem(gate_tuple, 0))
f_gate = self.block_builder.emit(relax.TupleGetItem(gate_tuple, 1))
g_gate = self.block_builder.emit(relax.TupleGetItem(gate_tuple, 2))
o_gate = self.block_builder.emit(relax.TupleGetItem(gate_tuple, 3))
```
##########
python/tvm/relax/frontend/torch/exported_program_translator.py:
##########
@@ -231,6 +231,166 @@ def _upsample_bicubic2d(self, node: fx.node) -> relax.Var:
align_corners=align_corners,
)
+ def _lstm(self, node: fx.Node) -> relax.Var:
+ args = self.retrieve_args(node)
+ input_tensor = args[0]
+ hx = args[1] if len(args) > 1 else None
+ params = args[2] if len(args) > 2 else None
+ has_biases = args[3] if len(args) > 3 else True
+ num_layers = args[4] if len(args) > 4 else 1
+ dropout = args[5] if len(args) > 5 else 0.0 # Not used in inference
+ train = args[6] if len(args) > 6 else False # Not used in inference
+ bidirectional = args[7] if len(args) > 7 else False
+ batch_first = args[8] if len(args) > 8 else False
+ if bidirectional:
+ raise NotImplementedError("Bidirectional LSTM is not yet supported")
+ if num_layers > 1:
+ raise NotImplementedError("Multi-layer LSTM is not yet supported")
+ input_shape = self.shape_of(input_tensor)
+ if batch_first:
+ # Input shape: (batch, seq_len, input_size)
+ batch_size, seq_len, input_size = input_shape
+ else:
+ # Input shape: (seq_len, batch, input_size)
+ seq_len, batch_size, input_size = input_shape
+
+ if hasattr(seq_len, "value"):
+ seq_len = seq_len.value
+ if hasattr(batch_size, "value"):
+ batch_size = batch_size.value
+ if hasattr(input_size, "value"):
+ input_size = input_size.value
+ # Extract hidden size from the LSTM parameters
+ # The parameters are: [weight_ih, weight_hh, bias_ih, bias_hh]
+ # weight_ih shape: (4 * hidden_size, input_size)
+ # weight_hh shape: (4 * hidden_size, hidden_size)
+ if params and len(params) >= 2:
+ weight_ih = params[0]
+ weight_hh = params[1]
+ # Extract hidden size from weight dimensions
+ # weight_ih has shape (4 * hidden_size, input_size)
+ weight_ih_shape = self.shape_of(weight_ih)
+ hidden_size = weight_ih_shape[0] // 4 # 4 gates: input, forget, cell, output
+ else:
+ # Fallback to a default hidden size
+ hidden_size = 16
Review Comment:

The fallback logic for when LSTM parameters are not provided is problematic.
It defaults to `hidden_size = 16`. This can lead to silent correctness issues
and hard-to-debug errors. It would be better to raise a `ValueError` if the
parameters are not available to determine `hidden_size`, as a valid LSTM layer
must have weights.
```suggestion
else:
    raise ValueError("Cannot determine hidden_size. LSTM params (weights) are required.")
```
##########
python/tvm/relax/frontend/torch/exported_program_translator.py:
##########
@@ -231,6 +231,166 @@ def _upsample_bicubic2d(self, node: fx.node) -> relax.Var:
align_corners=align_corners,
)
+ def _lstm(self, node: fx.Node) -> relax.Var:
+ args = self.retrieve_args(node)
+ input_tensor = args[0]
+ hx = args[1] if len(args) > 1 else None
+ params = args[2] if len(args) > 2 else None
+ has_biases = args[3] if len(args) > 3 else True
+ num_layers = args[4] if len(args) > 4 else 1
+ dropout = args[5] if len(args) > 5 else 0.0 # Not used in inference
+ train = args[6] if len(args) > 6 else False # Not used in inference
+ bidirectional = args[7] if len(args) > 7 else False
+ batch_first = args[8] if len(args) > 8 else False
+ if bidirectional:
+ raise NotImplementedError("Bidirectional LSTM is not yet supported")
+ if num_layers > 1:
+ raise NotImplementedError("Multi-layer LSTM is not yet supported")
+ input_shape = self.shape_of(input_tensor)
+ if batch_first:
+ # Input shape: (batch, seq_len, input_size)
+ batch_size, seq_len, input_size = input_shape
+ else:
+ # Input shape: (seq_len, batch, input_size)
+ seq_len, batch_size, input_size = input_shape
+
+ if hasattr(seq_len, "value"):
+ seq_len = seq_len.value
+ if hasattr(batch_size, "value"):
+ batch_size = batch_size.value
+ if hasattr(input_size, "value"):
+ input_size = input_size.value
+ # Extract hidden size from the LSTM parameters
+ # The parameters are: [weight_ih, weight_hh, bias_ih, bias_hh]
+ # weight_ih shape: (4 * hidden_size, input_size)
+ # weight_hh shape: (4 * hidden_size, hidden_size)
+ if params and len(params) >= 2:
+ weight_ih = params[0]
+ weight_hh = params[1]
+ # Extract hidden size from weight dimensions
+ # weight_ih has shape (4 * hidden_size, input_size)
+ weight_ih_shape = self.shape_of(weight_ih)
+ hidden_size = weight_ih_shape[0] // 4 # 4 gates: input, forget, cell, output
+ else:
+ # Fallback to a default hidden size
+ hidden_size = 16
+ # Implement actual LSTM computation using Relax operations
+ # LSTM equations:
+ # i_t = sigmoid(W_ii * x_t + b_ii + W_hi * h_{t-1} + b_hi)
+ # f_t = sigmoid(W_if * x_t + b_if + W_hf * h_{t-1} + b_hf)
+ # g_t = tanh(W_ig * x_t + b_ig + W_hg * h_{t-1} + b_hg)
+ # o_t = sigmoid(W_io * x_t + b_io + W_ho * h_{t-1} + b_ho)
+ # c_t = f_t * c_{t-1} + i_t * g_t
+ # h_t = o_t * tanh(c_t)
+ dtype = input_tensor.struct_info.dtype
+ if params and len(params) >= 4:
+ weight_ih = params[0] # (4 * hidden_size, input_size)
+ weight_hh = params[1] # (4 * hidden_size, hidden_size)
+ bias_ih = params[2] if has_biases else None # (4 * hidden_size,)
+ bias_hh = params[3] if has_biases else None # (4 * hidden_size,)
+ else:
+ # Fallback: create zero weights
+ weight_ih = self.block_builder.emit(
+ relax.op.zeros(relax.ShapeExpr((4 * hidden_size, input_size)), dtype)
+ )
+ weight_hh = self.block_builder.emit(
+ relax.op.zeros(relax.ShapeExpr((4 * hidden_size, hidden_size)), dtype)
+ )
+ bias_ih = None
+ bias_hh = None
Review Comment:

Creating zero-tensors for weights as a fallback is problematic. This can
lead to silent correctness issues where the model compiles but produces
incorrect (zero) outputs. It's better to raise an error if weights are not
provided, as they are essential for a functional LSTM layer.
```suggestion
else:
    raise ValueError("LSTM params (weights) are required.")
```
##########
tests/python/relax/test_frontend_from_exported_program.py:
##########
@@ -5914,6 +5915,102 @@ def main(
verify_model(Model(), example_args, {}, Expected)
+def test_mm():
+ class MatrixMultiply(Module):
+ def forward(self, a, b):
+ return torch.mm(a, b)
+
+ example_args = (
+ torch.randn(2, 3, dtype=torch.float32),
+ torch.randn(3, 4, dtype=torch.float32),
+ )
+
+ @tvm.script.ir_module
+ class Expected:
+ @R.function
+ def main(
+ a: R.Tensor((2, 3), dtype="float32"),
+ b: R.Tensor((3, 4), dtype="float32"),
+ ) -> R.Tuple(R.Tensor((2, 4), dtype="float32")):
+ with R.dataflow():
lv: R.Tensor((2, 4), dtype="float32") = R.matmul(a, b, out_dtype="float32")
+ gv: R.Tuple(R.Tensor((2, 4), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ verify_model(MatrixMultiply(), example_args, {}, Expected)
+
+
+def test_lstm():
+ class BasicLSTM(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.lstm = nn.LSTM(
+ input_size=4,
+ hidden_size=8,
+ num_layers=1,
+ batch_first=True,
+ bidirectional=False,
+ )
+
+ def forward(self, x):
+ y, _ = self.lstm(x)
+ return y
+
+ torch.manual_seed(42)
+ x = torch.randn(2, 3, 4, dtype=torch.float32)
+ model = BasicLSTM()
+ with torch.no_grad():
+ pytorch_output = model(x)
+ exported_program = export(model, args=(x,))
+ mod = from_exported_program(exported_program)
+ target = tvm.target.Target("llvm")
+ ex = relax.build(mod, target)
+ vm = relax.VirtualMachine(ex, tvm.cpu())
+ x_tvm = tvm.runtime.tensor(x.numpy())
+ tvm_output = vm["main"](x_tvm)
+ if hasattr(tvm_output, "numpy"):
+ tvm_output_np = tvm_output.numpy()
+ else:
+ tvm_output_np = tvm_output[0].numpy()
+ assert (
+ pytorch_output.shape == tvm_output_np.shape
+ ), f"Shape mismatch: PyTorch {pytorch_output.shape} vs TVM
{tvm_output_np.shape}"
+ np.testing.assert_allclose(pytorch_output.numpy(), tvm_output_np,
rtol=1e-4, atol=1e-5)
+
+ class SeqFirstLSTM(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.lstm = nn.LSTM(
+ input_size=3,
+ hidden_size=6,
+ num_layers=1,
+ batch_first=False,
+ bidirectional=False,
+ )
+
+ def forward(self, x):
+ y, _ = self.lstm(x)
+ return y
+
+ torch.manual_seed(43)
+ x2 = torch.randn(4, 2, 3, dtype=torch.float32)
+ model2 = SeqFirstLSTM()
+ with torch.no_grad():
+ pytorch_output2 = model2(x2)
+ exported_program2 = export(model2, args=(x2,))
+ mod2 = from_exported_program(exported_program2)
+ ex2 = relax.build(mod2, target)
+ vm2 = relax.VirtualMachine(ex2, tvm.cpu())
+ x2_tvm = tvm.runtime.tensor(x2.numpy())
+ tvm_output2 = vm2["main"](x2_tvm)
+ if hasattr(tvm_output2, "numpy"):
+ tvm_output2_np = tvm_output2.numpy()
+ else:
+ tvm_output2_np = tvm_output2[0].numpy()
+ assert pytorch_output2.shape == tvm_output2_np.shape
+ np.testing.assert_allclose(pytorch_output2.numpy(), tvm_output2_np, rtol=1e-4, atol=1e-5)
Review Comment:

The `test_lstm` function contains a significant amount of duplicated code
for testing the `batch_first=True` and `batch_first=False` cases. This can be
refactored into a helper function to improve readability and maintainability.
The helper function could take the model and input tensor as arguments and
perform the verification logic.
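For instance, a hypothetical helper (the name `_verify_lstm_model` and the tolerances are illustrative, mirroring the steps already present in the test) could look like:
```python
def _verify_lstm_model(model, x, rtol=1e-4, atol=1e-5):
    # Export the module, translate it to Relax, compile, run, and compare against PyTorch.
    with torch.no_grad():
        pytorch_output = model(x)
    exported_program = export(model, args=(x,))
    mod = from_exported_program(exported_program)
    ex = relax.build(mod, tvm.target.Target("llvm"))
    vm = relax.VirtualMachine(ex, tvm.cpu())
    tvm_output = vm["main"](tvm.runtime.tensor(x.numpy()))
    tvm_output_np = tvm_output.numpy() if hasattr(tvm_output, "numpy") else tvm_output[0].numpy()
    assert pytorch_output.shape == tvm_output_np.shape
    np.testing.assert_allclose(pytorch_output.numpy(), tvm_output_np, rtol=rtol, atol=atol)
```
The two cases would then reduce to `_verify_lstm_model(BasicLSTM(), torch.randn(2, 3, 4))` and `_verify_lstm_model(SeqFirstLSTM(), torch.randn(4, 2, 3))`.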
##########
python/tvm/relax/frontend/torch/exported_program_translator.py:
##########
@@ -231,6 +231,166 @@ def _upsample_bicubic2d(self, node: fx.node) -> relax.Var:
align_corners=align_corners,
)
+ def _lstm(self, node: fx.Node) -> relax.Var:
+ args = self.retrieve_args(node)
+ input_tensor = args[0]
+ hx = args[1] if len(args) > 1 else None
+ params = args[2] if len(args) > 2 else None
+ has_biases = args[3] if len(args) > 3 else True
+ num_layers = args[4] if len(args) > 4 else 1
+ dropout = args[5] if len(args) > 5 else 0.0 # Not used in inference
+ train = args[6] if len(args) > 6 else False # Not used in inference
+ bidirectional = args[7] if len(args) > 7 else False
+ batch_first = args[8] if len(args) > 8 else False
+ if bidirectional:
+ raise NotImplementedError("Bidirectional LSTM is not yet supported")
+ if num_layers > 1:
+ raise NotImplementedError("Multi-layer LSTM is not yet supported")
+ input_shape = self.shape_of(input_tensor)
+ if batch_first:
+ # Input shape: (batch, seq_len, input_size)
+ batch_size, seq_len, input_size = input_shape
+ else:
+ # Input shape: (seq_len, batch, input_size)
+ seq_len, batch_size, input_size = input_shape
+
+ if hasattr(seq_len, "value"):
+ seq_len = seq_len.value
+ if hasattr(batch_size, "value"):
+ batch_size = batch_size.value
+ if hasattr(input_size, "value"):
+ input_size = input_size.value
+ # Extract hidden size from the LSTM parameters
+ # The parameters are: [weight_ih, weight_hh, bias_ih, bias_hh]
+ # weight_ih shape: (4 * hidden_size, input_size)
+ # weight_hh shape: (4 * hidden_size, hidden_size)
+ if params and len(params) >= 2:
+ weight_ih = params[0]
+ weight_hh = params[1]
+ # Extract hidden size from weight dimensions
+ # weight_ih has shape (4 * hidden_size, input_size)
+ weight_ih_shape = self.shape_of(weight_ih)
+ hidden_size = weight_ih_shape[0] // 4 # 4 gates: input, forget, cell, output
+ else:
+ # Fallback to a default hidden size
+ hidden_size = 16
+ # Implement actual LSTM computation using Relax operations
+ # LSTM equations:
+ # i_t = sigmoid(W_ii * x_t + b_ii + W_hi * h_{t-1} + b_hi)
+ # f_t = sigmoid(W_if * x_t + b_if + W_hf * h_{t-1} + b_hf)
+ # g_t = tanh(W_ig * x_t + b_ig + W_hg * h_{t-1} + b_hg)
+ # o_t = sigmoid(W_io * x_t + b_io + W_ho * h_{t-1} + b_ho)
+ # c_t = f_t * c_{t-1} + i_t * g_t
+ # h_t = o_t * tanh(c_t)
+ dtype = input_tensor.struct_info.dtype
+ if params and len(params) >= 4:
+ weight_ih = params[0] # (4 * hidden_size, input_size)
+ weight_hh = params[1] # (4 * hidden_size, hidden_size)
+ bias_ih = params[2] if has_biases else None # (4 * hidden_size,)
+ bias_hh = params[3] if has_biases else None # (4 * hidden_size,)
+ else:
+ # Fallback: create zero weights
+ weight_ih = self.block_builder.emit(
+ relax.op.zeros(relax.ShapeExpr((4 * hidden_size, input_size)), dtype)
+ )
+ weight_hh = self.block_builder.emit(
+ relax.op.zeros(relax.ShapeExpr((4 * hidden_size, hidden_size)), dtype)
+ )
+ bias_ih = None
+ bias_hh = None
+ # Initialize hidden and cell states
+ if hx is not None and len(hx) >= 2:
+ h_0 = hx[0] # (num_layers, batch_size, hidden_size)
+ c_0 = hx[1] # (num_layers, batch_size, hidden_size)
+ # Extract the first layer's hidden state
+ h_prev = self.block_builder.emit(
+ relax.op.take(h_0, relax.const(0, "int64"), axis=0, mode="clip")
+ )
+ c_prev = self.block_builder.emit(
+ relax.op.take(c_0, relax.const(0, "int64"), axis=0, mode="clip")
+ )
+ else:
+ h_prev = self.block_builder.emit(
+ relax.op.zeros(relax.ShapeExpr((batch_size, hidden_size)), dtype)
+ )
+ c_prev = self.block_builder.emit(
+ relax.op.zeros(relax.ShapeExpr((batch_size, hidden_size)), dtype)
+ )
+ # Reshape input for processing
+ if batch_first:
+ # Input: (batch, seq_len, input_size) -> (seq_len, batch, input_size)
+ input_reshaped = self.block_builder.emit(
+ relax.op.permute_dims(input_tensor, axes=[1, 0, 2])
+ )
+ else:
+ input_reshaped = input_tensor
+ weight_ih_t = self.block_builder.emit(relax.op.permute_dims(weight_ih, axes=[1, 0]))
+ weight_hh_t = self.block_builder.emit(relax.op.permute_dims(weight_hh, axes=[1, 0]))
+ outputs = []
+ for t in range(seq_len):
+ # Get input at time t: (batch_size, input_size)
+ x_t = self.block_builder.emit(
+ relax.op.take(input_reshaped, relax.const(t, "int64"), axis=0, mode="clip")
+ )
+ # Compute gates: W_ih * x_t + W_hh * h_{t-1} + bias
+ # Input-to-hidden: (batch_size, input_size) @ (4*hidden_size, input_size).T
+ ih_gates = self.block_builder.emit(relax.op.linear_algebra.matmul(x_t, weight_ih_t))
+
+ # Hidden-to-hidden: (batch_size, hidden_size) @ (4*hidden_size, hidden_size).T
+ hh_gates = self.block_builder.emit(relax.op.linear_algebra.matmul(h_prev, weight_hh_t))
+ # Add biases if present
+ if bias_ih is not None and bias_hh is not None:
+ gates = self.block_builder.emit(
+ relax.op.add(relax.op.add(relax.op.add(ih_gates, bias_ih), hh_gates), bias_hh)
+ )
+ elif bias_ih is not None:
+ gates = self.block_builder.emit(
+ relax.op.add(relax.op.add(ih_gates, bias_ih), hh_gates)
+ )
+ elif bias_hh is not None:
+ gates = self.block_builder.emit(
+ relax.op.add(relax.op.add(ih_gates, hh_gates), bias_hh)
+ )
+ else:
+ gates = self.block_builder.emit(relax.op.add(ih_gates, hh_gates))
+ # Split gates: (batch_size, 4 * hidden_size) -> 4 x (batch_size, hidden_size)
+ gate_size = hidden_size
+ i_gate = self.block_builder.emit(
+ relax.op.strided_slice(gates, axes=[1], begin=[0], end=[gate_size])
+ )
+ f_gate = self.block_builder.emit(
+ relax.op.strided_slice(gates, axes=[1], begin=[gate_size], end=[2 * gate_size])
+ )
+ g_gate = self.block_builder.emit(
+ relax.op.strided_slice(gates, axes=[1], begin=[2 * gate_size], end=[3 * gate_size])
+ )
+ o_gate = self.block_builder.emit(
+ relax.op.strided_slice(gates, axes=[1], begin=[3 * gate_size], end=[4 * gate_size])
+ )
+ # Apply activations
+ i_t = self.block_builder.emit(relax.op.sigmoid(i_gate))
+ f_t = self.block_builder.emit(relax.op.sigmoid(f_gate))
+ g_t = self.block_builder.emit(relax.op.tanh(g_gate))
+ o_t = self.block_builder.emit(relax.op.sigmoid(o_gate))
+ # Update cell state: c_t = f_t * c_{t-1} + i_t * g_t
+ c_t = self.block_builder.emit(
+ relax.op.add(relax.op.multiply(f_t, c_prev), relax.op.multiply(i_t, g_t))
+ )
+ # Update hidden state: h_t = o_t * tanh(c_t)
+ h_t = self.block_builder.emit(relax.op.multiply(o_t, relax.op.tanh(c_t)))
+ # Store output
+ outputs.append(h_t)
+ # Update for next iteration
+ h_prev = h_t
+ c_prev = c_t
+ # Stack outputs: (seq_len, batch_size, hidden_size)
+ output = self.block_builder.emit(relax.op.stack(outputs, axis=0))
+ # Reshape back to batch_first if needed
+ if batch_first:
+ # (seq_len, batch_size, hidden_size) -> (batch_size, seq_len, hidden_size)
+ output = self.block_builder.emit(relax.op.permute_dims(output, axes=[1, 0, 2]))
+ return output
Review Comment:

The `_lstm` implementation is incomplete. It only returns the output
sequence but not the final hidden and cell states, which are part of the
standard `torch.nn.LSTM` output `(output, (h_n, c_n))`. This will lead to
incorrect behavior for models that use these states. The function should be
updated to return a tuple containing the output sequence and the final
hidden/cell states to fully match the PyTorch operator's behavior.
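A minimal sketch of the return path (assuming the final `h_prev`/`c_prev` from the unrolled loop are the last hidden and cell states, and that callers expect a leading `num_layers` axis of size 1):
```python
# Sketch: return (output, h_n, c_n) so that downstream getitem nodes see a tuple.
h_n = self.block_builder.emit(relax.op.expand_dims(h_prev, axis=0))  # (1, batch, hidden)
c_n = self.block_builder.emit(relax.op.expand_dims(c_prev, axis=0))  # (1, batch, hidden)
return self.block_builder.emit(relax.Tuple([output, h_n, c_n]))
```
With a tuple result, the existing `TupleGetItem` branch in `_getitem` handles the indexing, and the tensor-indexing workaround added in `base_fx_graph_translator.py` becomes unnecessary.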
##########
python/tvm/relax/frontend/torch/exported_program_translator.py:
##########
@@ -231,6 +231,166 @@ def _upsample_bicubic2d(self, node: fx.node) -> relax.Var:
align_corners=align_corners,
)
+ def _lstm(self, node: fx.Node) -> relax.Var:
+ args = self.retrieve_args(node)
+ input_tensor = args[0]
+ hx = args[1] if len(args) > 1 else None
+ params = args[2] if len(args) > 2 else None
+ has_biases = args[3] if len(args) > 3 else True
+ num_layers = args[4] if len(args) > 4 else 1
+ dropout = args[5] if len(args) > 5 else 0.0 # Not used in inference
+ train = args[6] if len(args) > 6 else False # Not used in inference
+ bidirectional = args[7] if len(args) > 7 else False
+ batch_first = args[8] if len(args) > 8 else False
+ if bidirectional:
+ raise NotImplementedError("Bidirectional LSTM is not yet supported")
+ if num_layers > 1:
+ raise NotImplementedError("Multi-layer LSTM is not yet supported")
+ input_shape = self.shape_of(input_tensor)
+ if batch_first:
+ # Input shape: (batch, seq_len, input_size)
+ batch_size, seq_len, input_size = input_shape
+ else:
+ # Input shape: (seq_len, batch, input_size)
+ seq_len, batch_size, input_size = input_shape
+
+ if hasattr(seq_len, "value"):
+ seq_len = seq_len.value
+ if hasattr(batch_size, "value"):
+ batch_size = batch_size.value
+ if hasattr(input_size, "value"):
+ input_size = input_size.value
+ # Extract hidden size from the LSTM parameters
+ # The parameters are: [weight_ih, weight_hh, bias_ih, bias_hh]
+ # weight_ih shape: (4 * hidden_size, input_size)
+ # weight_hh shape: (4 * hidden_size, hidden_size)
+ if params and len(params) >= 2:
+ weight_ih = params[0]
+ weight_hh = params[1]
+ # Extract hidden size from weight dimensions
+ # weight_ih has shape (4 * hidden_size, input_size)
+ weight_ih_shape = self.shape_of(weight_ih)
+ hidden_size = weight_ih_shape[0] // 4 # 4 gates: input, forget, cell, output
+ else:
+ # Fallback to a default hidden size
+ hidden_size = 16
+ # Implement actual LSTM computation using Relax operations
+ # LSTM equations:
+ # i_t = sigmoid(W_ii * x_t + b_ii + W_hi * h_{t-1} + b_hi)
+ # f_t = sigmoid(W_if * x_t + b_if + W_hf * h_{t-1} + b_hf)
+ # g_t = tanh(W_ig * x_t + b_ig + W_hg * h_{t-1} + b_hg)
+ # o_t = sigmoid(W_io * x_t + b_io + W_ho * h_{t-1} + b_ho)
+ # c_t = f_t * c_{t-1} + i_t * g_t
+ # h_t = o_t * tanh(c_t)
+ dtype = input_tensor.struct_info.dtype
+ if params and len(params) >= 4:
+ weight_ih = params[0] # (4 * hidden_size, input_size)
+ weight_hh = params[1] # (4 * hidden_size, hidden_size)
+ bias_ih = params[2] if has_biases else None # (4 * hidden_size,)
+ bias_hh = params[3] if has_biases else None # (4 * hidden_size,)
+ else:
+ # Fallback: create zero weights
+ weight_ih = self.block_builder.emit(
+ relax.op.zeros(relax.ShapeExpr((4 * hidden_size, input_size)), dtype)
+ )
+ weight_hh = self.block_builder.emit(
+ relax.op.zeros(relax.ShapeExpr((4 * hidden_size, hidden_size)), dtype)
+ )
+ bias_ih = None
+ bias_hh = None
+ # Initialize hidden and cell states
+ if hx is not None and len(hx) >= 2:
+ h_0 = hx[0] # (num_layers, batch_size, hidden_size)
+ c_0 = hx[1] # (num_layers, batch_size, hidden_size)
+ # Extract the first layer's hidden state
+ h_prev = self.block_builder.emit(
+ relax.op.take(h_0, relax.const(0, "int64"), axis=0, mode="clip")
+ )
+ c_prev = self.block_builder.emit(
+ relax.op.take(c_0, relax.const(0, "int64"), axis=0, mode="clip")
+ )
+ else:
+ h_prev = self.block_builder.emit(
+ relax.op.zeros(relax.ShapeExpr((batch_size, hidden_size)), dtype)
+ )
+ c_prev = self.block_builder.emit(
+ relax.op.zeros(relax.ShapeExpr((batch_size, hidden_size)), dtype)
+ )
+ # Reshape input for processing
+ if batch_first:
+ # Input: (batch, seq_len, input_size) -> (seq_len, batch, input_size)
+ input_reshaped = self.block_builder.emit(
+ relax.op.permute_dims(input_tensor, axes=[1, 0, 2])
+ )
+ else:
+ input_reshaped = input_tensor
+ weight_ih_t = self.block_builder.emit(relax.op.permute_dims(weight_ih, axes=[1, 0]))
+ weight_hh_t = self.block_builder.emit(relax.op.permute_dims(weight_hh, axes=[1, 0]))
+ outputs = []
+ for t in range(seq_len):
+ # Get input at time t: (batch_size, input_size)
+ x_t = self.block_builder.emit(
+ relax.op.take(input_reshaped, relax.const(t, "int64"), axis=0, mode="clip")
+ )
+ # Compute gates: W_ih * x_t + W_hh * h_{t-1} + bias
+ # Input-to-hidden: (batch_size, input_size) @ (4*hidden_size, input_size).T
+ ih_gates = self.block_builder.emit(relax.op.linear_algebra.matmul(x_t, weight_ih_t))
+
+ # Hidden-to-hidden: (batch_size, hidden_size) @ (4*hidden_size, hidden_size).T
+ hh_gates = self.block_builder.emit(relax.op.linear_algebra.matmul(h_prev, weight_hh_t))
+ # Add biases if present
+ if bias_ih is not None and bias_hh is not None:
+ gates = self.block_builder.emit(
+ relax.op.add(relax.op.add(relax.op.add(ih_gates, bias_ih), hh_gates), bias_hh)
+ )
+ elif bias_ih is not None:
+ gates = self.block_builder.emit(
+ relax.op.add(relax.op.add(ih_gates, bias_ih), hh_gates)
+ )
+ elif bias_hh is not None:
+ gates = self.block_builder.emit(
+ relax.op.add(relax.op.add(ih_gates, hh_gates), bias_hh)
+ )
+ else:
+ gates = self.block_builder.emit(relax.op.add(ih_gates, hh_gates))
Review Comment:

The logic for adding biases is quite verbose with multiple `if/elif/else`
branches. This can be simplified for better readability and maintainability.
You can calculate the total gates first, and then conditionally add the biases.
```suggestion
gates = self.block_builder.emit(relax.op.add(ih_gates, hh_gates))
if bias_ih is not None:
    gates = self.block_builder.emit(relax.op.add(gates, bias_ih))
if bias_hh is not None:
    gates = self.block_builder.emit(relax.op.add(gates, bias_hh))
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]