This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new c12debb462 [Relax][PyTorch] Fix segfault in from_exported_program when
model uses index_put_ with tuple output (#19488)
c12debb462 is described below
commit c12debb462c2e43270987db15ebe5074d33c190d
Author: Neo Chien <[email protected]>
AuthorDate: Fri May 8 17:49:18 2026 +0800
[Relax][PyTorch] Fix segfault in from_exported_program when model uses
index_put_ with tuple output (#19488)
Hi Committers,
This PR attempts to fix issue
https://github.com/apache/tvm/issues/18363. Any suggestions would be
appreciated.
### Root Cause
- When an ExportedProgram's FX graph output node returns a **nested
Python tuple** (e.g., buffer mutation outputs + user-defined tuple
returns), `_translate_fx_graph()` passes the raw nested structure
directly to the Relax FFI Tuple constructor.
- The C++ Array<Expr> initializer cannot handle heterogeneous/nested
Python containers, causing a segmentation fault in `expr.cc`.
- Additionally, index_put_ (in-place write op) did not update self.env
to alias the source tensor to the mutated output, causing subsequent FX
nodes that read the same tensor to observe **stale pre-mutation
values**.
### Solution
- exported_program_translator.py
- Added static method `_flatten_output_args()` that recursively walks
any Python `tuple`/`list`, collects the `relax.Expr` leaves, and preserves
explicit `None` outputs as Relax null objects.
- Replaced the fragile `assert isinstance(output_args, tuple |
relax.Tuple)` guard with a call to `_flatten_output_args()`, producing a
clean flat tuple of `relax.Expr` before FFI construction.
- base_fx_graph_translator.py
- In `_index_put()`, after emitting the `relax.op.index_put(...)` call,
added an env alias update: when the target op name starts with
`index_put_`, every `self.env` entry bound to the mutated input
expression is rebound to `output`, preserving correct in-place
mutation semantics for downstream FX nodes.
---------
Co-authored-by: cchung100m <[email protected]>
---
.../frontend/torch/base_fx_graph_translator.py | 20 +++-
.../frontend/torch/exported_program_translator.py | 37 ++++++-
.../relax/test_frontend_from_exported_program.py | 108 +++++++++++++++++++++
3 files changed, 162 insertions(+), 3 deletions(-)
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 138176155a..89c91e3773 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1921,7 +1921,25 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
indices = relax.Tuple(processed_indices)
else:
indices = relax.Tuple(indices)
- return self.block_builder.emit(relax.op.index_put(tensor, indices,
values, accumulate))
+
+ output = self.block_builder.emit(relax.op.index_put(tensor, indices,
values, accumulate))
+
+ target_name = (
+ node.target if isinstance(node.target, str) else
getattr(node.target, "__name__", "")
+ )
+ if target_name.startswith("index_put_") and len(node.args) > 0:
+ from torch import fx
+
+ if isinstance(node.args[0], fx.Node):
+ # `index_put_` is in-place. If the mutated input is an alias
of another
+ # FX node, later reads via either the alias node or the
original node
+ # must observe the updated tensor.
+ aliased_expr = tensor
+ for env_node, env_expr in list(self.env.items()):
+ if env_expr is aliased_expr:
+ self.env[env_node] = output
+
+ return output
def _index_tensor(self, node: fx.Node) -> relax.Var:
args = self.retrieve_args(node)
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py
b/python/tvm/relax/frontend/torch/exported_program_translator.py
index cc37554bf3..5bd2c785f2 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -1338,7 +1338,40 @@ class ExportedProgramImporter(BaseFXGraphImporter):
raise ValueError(f"Unsupported op {node.op}")
assert output_args is not None
- return output_args
+ return self._flatten_output_args(output_args)
+
+ @staticmethod
+ def _flatten_output_args(output_args) -> tuple[relax.Expr, ...]:
+ """Flatten output args into a tuple of Relax expressions.
+
+ ExportedProgram output trees contain nested Python tuple/list
containers
+ (e.g. mutation outputs + user tuple outputs). Emitting nested Python
tuples
+ directly through FFI may construct invalid Relax tuples.
+ """
+
+ flattened: list[relax.Expr] = []
+
+ def _visit(value):
+ if isinstance(value, relax.Expr):
+ flattened.append(value)
+ elif isinstance(value, list | tuple):
+ for item in value:
+ _visit(item)
+ elif value is None:
+ # Preserve explicit None outputs as Relax null objects.
+ flattened.append(relax.op.null_value())
+ else:
+ raise ValueError(
+ "Unsupported output type in exported graph output: "
+ f"{type(value)}"
+ )
+
+ _visit(output_args)
+
+ if not flattened:
+ raise ValueError("Exported graph produced no Relax outputs")
+
+ return tuple(flattened)
def _import_branch_subgraph(
self,
@@ -1995,7 +2028,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
output_args = self._translate_fx_graph(
exported_program.graph_module, nodes, inputs_vars,
custom_ops
)
- assert isinstance(output_args, tuple | relax.Tuple)
+ output_args = self._flatten_output_args(output_args)
if unwrap_unit_return_tuple and len(output_args) == 1:
ret = output_args[0]
diff --git a/tests/python/relax/test_frontend_from_exported_program.py
b/tests/python/relax/test_frontend_from_exported_program.py
index e2f9751c15..f3e2e581e1 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -7402,6 +7402,114 @@ def test_index_put():
verify_model(IndexPutBatchedWithNone(), example_args_batched_none, {},
ExpectedBatchedWithNone)
+def test_index_put_with_tuple_output():
+ class IndexPutTupleOutput(Module):
+ def forward(self, x, l, idx):
+ values = x
+ l[..., idx, idx] = values
+ return x[..., 1], l
+
+ example_args = (
+ torch.ones(2, 3, 5, dtype=torch.float32),
+ torch.zeros(2, 3, 5, 5, dtype=torch.float32),
+ torch.tensor([0, 1, 2, 3, 4], dtype=torch.int64),
+ )
+
+ exported_program = export(IndexPutTupleOutput(), args=example_args)
+ mod = from_exported_program(exported_program)
+
+ ret_sinfo = mod["main"].ret_struct_info
+ assert isinstance(ret_sinfo, relax.TupleStructInfo)
+
+ tensor_fields = [f for f in ret_sinfo.fields if isinstance(f,
relax.TensorStructInfo)]
+ assert len(tensor_fields) >= 2
+
+ assert any(
+ len(f.shape) == 4 and int(f.shape[-2]) == 5 and int(f.shape[-1]) == 5
+ for f in tensor_fields
+ )
+
+
+def test_m4d_diag_index_put_tuple_output_regression():
+ class M4D(Module):
+ def forward(self, x):
+ b, k, n = 2, 3, 5
+ l = x.new_zeros(b, k, n, n)
+ idx = torch.arange(n, device=x.device)
+
+ diag = l[..., idx, idx]
+ diag = torch.nn.functional.elu(diag) + 1.0 + 1e-8
+ l[..., idx, idx] = diag
+
+ return x[..., :1], l
+
+ ex_in = torch.zeros(2, 3, 5, dtype=torch.float32)
+ exported_program = export(M4D().eval(), args=(ex_in,))
+
+ exported_targets = [str(getattr(n, "target", "")) for n in
exported_program.graph.nodes]
+ assert any("index_put" in target for target in exported_targets)
+
+ # Regression focus: importing this graph should not segfault at Tuple
construction.
+ mod = from_exported_program(exported_program)
+ ret_sinfo = mod["main"].ret_struct_info
+ assert isinstance(ret_sinfo, relax.TupleStructInfo)
+
+ tensor_fields = [f for f in ret_sinfo.fields if isinstance(f,
relax.TensorStructInfo)]
+ assert len(tensor_fields) >= 2
+ # x: (2, 3, 5) → x[..., :1]: (2, 3, 1)
+ assert any(len(f.shape) == 3 and int(f.shape[-1]) == 1 for f in
tensor_fields)
+ # l: (2, 3, 5, 5) → 4-D with spatial dims 5×5
+ assert any(
+ len(f.shape) == 4 and int(f.shape[-2]) == 5 and int(f.shape[-1]) == 5
+ for f in tensor_fields
+ )
+
+
+def test_index_put_mutation_through_alias_regression():
+ class IndexPutAlias(Module):
+ def forward(self, x, idx, values):
+ y = torch.ops.aten.alias.default(x)
+ y[idx] = values
+ return x, y
+
+ example_args = (
+ torch.zeros(5, dtype=torch.float32),
+ torch.tensor([1, 3], dtype=torch.int64),
+ torch.tensor([2.0, 4.0], dtype=torch.float32),
+ )
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(
+ x: R.Tensor((5,), dtype="float32"),
+ idx: R.Tensor((2,), dtype="int64"),
+ values: R.Tensor((2,), dtype="float32"),
+ ) -> R.Tuple(
+ R.Tensor((5,), dtype="float32"),
+ R.Tensor((5,), dtype="float32"),
+ R.Tensor((5,), dtype="float32"),
+ ):
+ with R.dataflow():
+ lv: R.Tensor((5,), dtype="float32") = R.index_put(
+ x, (idx,), values, accumulate=False
+ )
+ # ExportedProgram may include an additional mutation output.
+ gv: R.Tuple(
+ R.Tensor((5,), dtype="float32"),
+ R.Tensor((5,), dtype="float32"),
+ R.Tensor((5,), dtype="float32"),
+ ) = (
+ lv,
+ lv,
+ lv,
+ )
+ R.output(gv)
+ return gv
+
+ verify_model(IndexPutAlias(), example_args, {}, Expected)
+
+
def test_flip():
class Flip0(Module):
def forward(self, data):