This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 9eac0e1635 [Relax][PyTorch] Fix scalar parameter inputs in Dynamo (#18725)
9eac0e1635 is described below
commit 9eac0e1635c83904aedd1100c9471c608d89a7ad
Author: Masahiro Hiramori <[email protected]>
AuthorDate: Sun Feb 8 05:32:57 2026 +0900
[Relax][PyTorch] Fix scalar parameter inputs in Dynamo (#18725)
Ensure scalar parameter placeholders are forwarded to the Relax VM.
Fixes a model failure reported in https://github.com/pytorch/pytorch/issues/169188
---
.../relax/frontend/torch/base_fx_graph_translator.py | 4 ++--
python/tvm/relax/frontend/torch/dynamo.py | 7 +++----
tests/python/relax/test_frontend_dynamo.py | 19 +++++++++++++++++++
3 files changed, 24 insertions(+), 6 deletions(-)
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index d04dfbb6c3..447f4a4dc6 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -505,9 +505,9 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
lhs, rhs = self.retrieve_args(node)
if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var):
return call_binary_op(relax_op, lhs, rhs)
- elif isinstance(lhs, relax.expr.Constant):
+ elif isinstance(lhs, relax.expr.Constant) and not isinstance(rhs, relax.expr.Constant):
return call_binary_op(relax_op, lhs, relax.const(rhs, dtype=lhs.struct_info.dtype))
- elif isinstance(rhs, relax.expr.Constant):
+ elif isinstance(rhs, relax.expr.Constant) and not isinstance(lhs, relax.expr.Constant):
return call_binary_op(relax_op, relax.const(lhs, dtype=rhs.struct_info.dtype), rhs)
return intrinsic_op(lhs, rhs)
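For readers skimming the patch, the dispatch this hunk encodes can be restated as a standalone sketch. The helper name below is hypothetical, and relax.op.add / operator.add stand in for the importer's relax_op / intrinsic_op arguments:

    import operator
    from tvm import relax

    def add_with_promotion(lhs, rhs):
        # Any relax.Var operand forces the relax-level op.
        if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var):
            return relax.op.add(lhs, rhs)
        # Exactly one Constant: promote the plain Python scalar on the other
        # side to a Constant with the same dtype.
        if isinstance(lhs, relax.expr.Constant) and not isinstance(rhs, relax.expr.Constant):
            return relax.op.add(lhs, relax.const(rhs, dtype=lhs.struct_info.dtype))
        if isinstance(rhs, relax.expr.Constant) and not isinstance(lhs, relax.expr.Constant):
            return relax.op.add(relax.const(lhs, dtype=rhs.struct_info.dtype), rhs)
        # Two Constants, or two plain Python scalars, fall through unchanged.
        return operator.add(lhs, rhs)

The added "and not isinstance(...)" guards matter when both operands are already Constants: the old branches would call relax.const on a value that is already a Constant, whereas the new code lets that case fall through to intrinsic_op.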
diff --git a/python/tvm/relax/frontend/torch/dynamo.py b/python/tvm/relax/frontend/torch/dynamo.py
index 8dc9e2a55a..dea08256de 100644
--- a/python/tvm/relax/frontend/torch/dynamo.py
+++ b/python/tvm/relax/frontend/torch/dynamo.py
@@ -137,10 +137,9 @@ def relax_dynamo(pipeline: Optional[tvm.transform.Pass] = None):
args = [a.contiguous() for a in i_args if isinstance(a, torch.Tensor)]
vm_args = list()
for arg in args:
- if arg.dim() != 0:
- if arg.requires_grad:
- arg = arg.detach()
- vm_args.append(to_tvm_tensor(arg))
+ if arg.requires_grad:
+ arg = arg.detach()
+ vm_args.append(to_tvm_tensor(arg))
outputs = vm["main"](*vm_args)
return to_torch_tensor(outputs)
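The dynamo.py hunk is what actually forwards scalar parameters: a scalar nn.Parameter is a 0-d tensor, so it failed the removed arg.dim() != 0 check and was never appended to vm_args, leaving vm["main"] short of the placeholders the traced graph expects (that is the reading suggested by the commit message; the new test below exercises exactly this shape). A small illustration of the tensor properties the new loop relies on:

    import torch

    # Scalar parameters are 0-d tensors that require grad, so the old
    # "if arg.dim() != 0:" guard skipped them entirely.
    p = torch.nn.Parameter(torch.tensor(1.0))
    assert p.dim() == 0
    assert p.requires_grad
    # The new loop only detaches such tensors before converting them,
    # instead of dropping them.
    assert not p.detach().requires_grad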
diff --git a/tests/python/relax/test_frontend_dynamo.py b/tests/python/relax/test_frontend_dynamo.py
index a48907eae5..b6d2345571 100644
--- a/tests/python/relax/test_frontend_dynamo.py
+++ b/tests/python/relax/test_frontend_dynamo.py
@@ -123,6 +123,25 @@ def test_relax_dynamo():
tvm.testing.assert_allclose(optimized_output, default_output, rtol=1e-5, atol=1e-5)
+def test_relax_dynamo_scalar_params():
+ class ScalarParams(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.x = torch.nn.Parameter(torch.tensor(1.0))
+ self.y = torch.nn.Parameter(torch.tensor(2.0))
+
+ def forward(self):
+ return self.x + self.y
+
+ model = ScalarParams()
+
+ opt_model = torch.compile(model, backend=relax_dynamo())
+
+ default_output = model().detach().numpy()
+ optimized_output = opt_model().detach().numpy()
+ tvm.testing.assert_allclose(optimized_output, default_output, rtol=1e-5, atol=1e-5)
+
+
def test_relax_dynamo_dynamic():
class Input1(torch.nn.Module):
def __init__(self):