This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 9eac0e1635 [Relax][PyTorch] Fix scalar parameter inputs in Dynamo (#18725)
9eac0e1635 is described below

commit 9eac0e1635c83904aedd1100c9471c608d89a7ad
Author: Masahiro Hiramori <[email protected]>
AuthorDate: Sun Feb 8 05:32:57 2026 +0900

    [Relax][PyTorch] Fix scalar parameter inputs in Dynamo (#18725)
    
    Ensure scalar parameter placeholders are forwarded to the Relax VM.
    This fixes the failure seen with the model reported in https://github.com/pytorch/pytorch/issues/169188
---
 .../relax/frontend/torch/base_fx_graph_translator.py  |  4 ++--
 python/tvm/relax/frontend/torch/dynamo.py             |  7 +++----
 tests/python/relax/test_frontend_dynamo.py            | 19 +++++++++++++++++++
 3 files changed, 24 insertions(+), 6 deletions(-)
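
For orientation, a minimal sketch of the scenario the commit message describes, assuming the relax_dynamo backend is the one patched below in python/tvm/relax/frontend/torch/dynamo.py and is importable from tvm.relax.frontend.torch (import path assumed from the module layout in this diff). It mirrors the regression test added to test_frontend_dynamo.py at the end of the diff: a module whose only graph inputs are 0-dim nn.Parameter values.

    import torch
    from tvm.relax.frontend.torch import relax_dynamo  # import path assumed

    class ScalarParams(torch.nn.Module):
        """All graph inputs are 0-dim (scalar) parameters."""

        def __init__(self):
            super().__init__()
            self.x = torch.nn.Parameter(torch.tensor(1.0))
            self.y = torch.nn.Parameter(torch.tensor(2.0))

        def forward(self):
            return self.x + self.y

    # Before this commit, 0-dim parameters were filtered out of the VM argument
    # list, so the compiled call received fewer inputs than the traced graph expected.
    opt_model = torch.compile(ScalarParams(), backend=relax_dynamo())
    print(opt_model())  # should print a value close to 3.0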

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index d04dfbb6c3..447f4a4dc6 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -505,9 +505,9 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
             lhs, rhs = self.retrieve_args(node)
             if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var):
                 return call_binary_op(relax_op, lhs, rhs)
-            elif isinstance(lhs, relax.expr.Constant):
+            elif isinstance(lhs, relax.expr.Constant) and not isinstance(rhs, relax.expr.Constant):
                 return call_binary_op(relax_op, lhs, relax.const(rhs, dtype=lhs.struct_info.dtype))
-            elif isinstance(rhs, relax.expr.Constant):
+            elif isinstance(rhs, relax.expr.Constant) and not isinstance(lhs, relax.expr.Constant):
                 return call_binary_op(relax_op, relax.const(lhs, dtype=rhs.struct_info.dtype), rhs)
             return intrinsic_op(lhs, rhs)
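
A note on the hunk above: the added guards matter when both operands are already relax Constants. In that case the old code passed a Constant into relax.const, which is meant for plain Python or NumPy values rather than an expression that is already a Constant; with the guards, the both-Constant case falls through to intrinsic_op. A stand-alone sketch of the resulting dispatch order, with placeholder types standing in for the relax classes:

    class Var: ...        # stand-in for relax.Var
    class Constant: ...   # stand-in for relax.expr.Constant

    def dispatch(lhs, rhs, call_binary_op, promote, intrinsic_op):
        # promote() plays the role of relax.const(): it wraps plain scalars only.
        if isinstance(lhs, Var) or isinstance(rhs, Var):
            return call_binary_op(lhs, rhs)
        if isinstance(lhs, Constant) and not isinstance(rhs, Constant):
            return call_binary_op(lhs, promote(rhs))   # rhs is a plain scalar
        if isinstance(rhs, Constant) and not isinstance(lhs, Constant):
            return call_binary_op(promote(lhs), rhs)   # lhs is a plain scalar
        return intrinsic_op(lhs, rhs)                  # both plain, or both Constant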
 
diff --git a/python/tvm/relax/frontend/torch/dynamo.py b/python/tvm/relax/frontend/torch/dynamo.py
index 8dc9e2a55a..dea08256de 100644
--- a/python/tvm/relax/frontend/torch/dynamo.py
+++ b/python/tvm/relax/frontend/torch/dynamo.py
@@ -137,10 +137,9 @@ def relax_dynamo(pipeline: Optional[tvm.transform.Pass] = None):
             args = [a.contiguous() for a in i_args if isinstance(a, torch.Tensor)]
             vm_args = list()
             for arg in args:
-                if arg.dim() != 0:
-                    if arg.requires_grad:
-                        arg = arg.detach()
-                    vm_args.append(to_tvm_tensor(arg))
+                if arg.requires_grad:
+                    arg = arg.detach()
+                vm_args.append(to_tvm_tensor(arg))
             outputs = vm["main"](*vm_args)
             return to_torch_tensor(outputs)
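
For clarity, a sketch (not the code in dynamo.py) of the per-argument handling after this change: the arg.dim() != 0 filter is gone, so 0-dim tensors go through the same detach-and-convert path as every other argument. tvm.nd.array is used here only as a stand-in for the to_tvm_tensor helper referenced above.

    import torch
    import tvm

    def prepare_vm_args(tensors):
        vm_args = []
        for arg in tensors:
            if arg.requires_grad:
                arg = arg.detach()
            # 0-dim scalars are no longer skipped; convert like any other tensor
            vm_args.append(tvm.nd.array(arg.contiguous().numpy()))
        return vm_args

    # A scalar nn.Parameter (requires_grad=True, dim() == 0) now reaches the VM.
    print(prepare_vm_args([torch.nn.Parameter(torch.tensor(1.0))]))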
 
diff --git a/tests/python/relax/test_frontend_dynamo.py b/tests/python/relax/test_frontend_dynamo.py
index a48907eae5..b6d2345571 100644
--- a/tests/python/relax/test_frontend_dynamo.py
+++ b/tests/python/relax/test_frontend_dynamo.py
@@ -123,6 +123,25 @@ def test_relax_dynamo():
     tvm.testing.assert_allclose(optimized_output, default_output, rtol=1e-5, atol=1e-5)
 
 
+def test_relax_dynamo_scalar_params():
+    class ScalarParams(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.x = torch.nn.Parameter(torch.tensor(1.0))
+            self.y = torch.nn.Parameter(torch.tensor(2.0))
+
+        def forward(self):
+            return self.x + self.y
+
+    model = ScalarParams()
+
+    opt_model = torch.compile(model, backend=relax_dynamo())
+
+    default_output = model().detach().numpy()
+    optimized_output = opt_model().detach().numpy()
+    tvm.testing.assert_allclose(optimized_output, default_output, rtol=1e-5, atol=1e-5)
+
+
 def test_relax_dynamo_dynamic():
     class Input1(torch.nn.Module):
         def __init__(self):
