This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new dcad7373c7 [Relax][Torch] Avoid decomposition crash with sparse CSR buffers (#18670)
dcad7373c7 is described below
commit dcad7373c7fe25b571d8bd7c8b57b5ab993510b5
Author: YinHanke <[email protected]>
AuthorDate: Fri Feb 6 04:36:06 2026 +0800
[Relax][Torch] Avoid decomposition crash with sparse CSR buffers (#18670)
### Motivation
The Relax Torch frontend crashes when importing an exported program that
includes a `torch.sparse_csr_tensor` registered as a buffer. The crash happens
during `from_exported_program` because `run_decompositions()` triggers a
PyTorch `layout_impl` error for sparse tensors.
This PR avoids the crash while keeping the import pipeline functional for such
models, even though Relax does not yet support sparse tensors.
### Changes
- Skip `run_decompositions()` when the exported program contains sparse tensors
- Treat `aten.to_sparse.default` as a no-op in the Relax Torch frontend
- Add a regression test that imports a model with a sparse CSR buffer
### Testing
- test_frontend_from_exported_program.py
Fixes: [[Bug] Relax Torch frontend crash with sparse CSR buffer in ExportedProgram](https://github.com/apache/tvm/issues/18648)
---
.../frontend/torch/exported_program_translator.py | 28 +++++++++++++++++++++-
.../relax/test_frontend_from_exported_program.py | 25 +++++++++++++++++++
2 files changed, 52 insertions(+), 1 deletion(-)
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 0a97614eb5..959f7fd4a2 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -113,6 +113,11 @@ class ExportedProgramImporter(BaseFXGraphImporter):
return self.block_builder.emit(relax.op.rsqrt(x))
+ def _to_sparse(self, node: fx.Node) -> relax.Var:
+ """Fallback for sparse conversion: Relax does not support sparse
tensors yet."""
+ args = self.retrieve_args(node)
+ return args[0]
+
########## Neural Network ##########
def _batch_norm(self, node: fx.Node, training: bool, return_tuple: bool =
False) -> relax.Var:
@@ -1254,6 +1259,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
"square.default": self._unary_op(relax.op.square),
"tan.default": self._unary_op(relax.op.tan),
"tanh.default": self._unary_op(relax.op.tanh),
+ "to_sparse.default": self._to_sparse,
"tril.default": self._tril_triu(relax.op.tril),
"triu.default": self._tril_triu(relax.op.triu),
"trunc.default": self._unary_op(relax.op.trunc),
@@ -1812,8 +1818,28 @@ def from_exported_program(
# Use the importer to import the ExportedProgram to Relax.
mod: tvm.IRModule = from_exported_program(exported_program)
"""
+
+ def _is_sparse_tensor(value: object) -> bool:
+ if not isinstance(value, torch.Tensor):
+ return False
+ try:
+ return value.layout != torch.strided
+ except RuntimeError:
+ return False
+
+ def _has_sparse_tensors(ep: torch.export.ExportedProgram) -> bool:
+ from itertools import chain
+
+ all_potential_tensors = chain(
+ (t for _, t in ep.named_buffers()),
+ (t for _, t in ep.named_parameters()),
+ ep.constants.values(),
+ ep.tensor_constants.values(),
+ )
+ return any(_is_sparse_tensor(t) for t in all_potential_tensors)
+
# Conditionally decompose into Core ATen operators
- if run_ep_decomposition:
+ if run_ep_decomposition and not _has_sparse_tensors(exported_program):
exported_program = exported_program.run_decompositions()
return ExportedProgramImporter().from_exported_program(
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 01a24ada1f..374c21d560 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -8787,5 +8787,30 @@ def test_upsample_nearest2d():
verify_model(UpsampleNearest2dSize(), example_args, {}, expected_size)
+def test_from_exported_program_sparse_csr_buffer():
+ class SparseCsrBufferModule(nn.Module):
+ def __init__(self):
+ super().__init__()
+ crow_indices = torch.tensor([0, 1, 2], dtype=torch.int64)
+ col_indices = torch.tensor([0, 1], dtype=torch.int64)
+ values = torch.tensor([1.0, 1.0], dtype=torch.float32,
requires_grad=True)
+ csr_tensor = torch.sparse_csr_tensor(
+ crow_indices, col_indices, values, dtype=torch.float32
+ )
+ self.register_buffer("csr_tensor", csr_tensor)
+ self.csr_tensor.requires_grad_(True)
+
+ def forward(self, x):
+ csr2 = self.csr_tensor.to_sparse(layout=torch.sparse_csr)
+ y = torch.matmul(csr2, x)
+ return y.sum()
+
+ model = SparseCsrBufferModule().eval()
+ x = torch.ones((2, 1), dtype=torch.float32)
+ exported_program = export(model, (x,))
+ mod = from_exported_program(exported_program)
+ assert isinstance(mod, tvm.IRModule)
+
+
# Entry point when this test file is executed directly; delegates to tvm.testing.main().
if __name__ == "__main__":
tvm.testing.main()