This is an automated email from the ASF dual-hosted git repository.

yongwww pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 9265328563 [Relax][PyTorch] Re-enable test_subgraph_capture in dynamo test (#17925)
9265328563 is described below

commit 9265328563d3a59471ffa762b09e02f7eed01622
Author: Masahiro Hiramori <[email protected]>
AuthorDate: Tue May 20 10:59:41 2025 +0900

    [Relax][PyTorch] Re-enable test_subgraph_capture in dynamo test (#17925)
    
    * fix test_subgraph_capture
    
    * no need to skip
---
 tests/python/relax/test_frontend_dynamo.py | 36 +++++++++---------------------
 1 file changed, 11 insertions(+), 25 deletions(-)

diff --git a/tests/python/relax/test_frontend_dynamo.py b/tests/python/relax/test_frontend_dynamo.py
index 3deed8c2bf..fb1544be68 100644
--- a/tests/python/relax/test_frontend_dynamo.py
+++ b/tests/python/relax/test_frontend_dynamo.py
@@ -157,10 +157,6 @@ def test_relax_dynamo_dynamic():
             tvm.testing.assert_allclose(opt_func(x, y), opt_func(x, y))
 
 
-@pytest.mark.skipif(
-    version.parse(torch_version) >= version.parse("2.6.0"),
-    reason="Tests not compatible with PyTorch >= 2.6",
-)
 def test_subgraph_capture():
     import torch
     from tvm.relax.frontend.torch.dynamo import dynamo_capture_subgraphs
@@ -178,13 +174,13 @@ def test_subgraph_capture():
         @R.function
         def subgraph_0(
             inp_0: R.Tensor((10, 100), dtype="float32"),
-            w0: R.Tensor((10, 100), dtype="float32"),
             w1: R.Tensor((10,), dtype="float32"),
+            w0: R.Tensor((10, 100), dtype="float32"),
         ) -> R.Tensor((10, 10), dtype="float32"):
             # block 0
             with R.dataflow():
-                lv: R.Tensor((100, 10), dtype="float32") = R.permute_dims(w0, 
axes=None)
-                lv1: R.Tensor((10, 10), dtype="float32") = R.matmul(inp_0, lv, 
out_dtype="float32")
+                lv: R.Tensor((100, 10), dtype="float32") = 
R.permute_dims(inp_0, axes=None)
+                lv1: R.Tensor((10, 10), dtype="float32") = R.matmul(w0, lv, 
out_dtype="float32")
                 lv2: R.Tensor((10, 10), dtype="float32") = R.add(lv1, w1)
                 lv3: R.Tensor((10, 10), dtype="float32") = R.nn.relu(lv2)
                 gv: R.Tensor((10, 10), dtype="float32") = lv3
@@ -193,10 +189,7 @@ def test_subgraph_capture():
 
     model = Input1()
     mod = dynamo_capture_subgraphs(model, torch.randn(10, 100))
-    binding = {"w0": model.lin.weight.detach().numpy(), "w1": 
model.lin.bias.detach().numpy()}
-    binding = {k: tvm.nd.array(v) for k, v in binding.items()}
-    expected = relax.transform.BindParams("subgraph_0", binding)(Expected1)
-    tvm.ir.assert_structural_equal(mod, expected)
+    tvm.ir.assert_structural_equal(mod, Expected1)
 
     def Input2(a, b):
         x = a / (torch.sin(a) + 1)
@@ -258,27 +251,20 @@ def test_subgraph_capture():
         ) -> R.Tensor((10, 10), dtype="float32"):
             # block 0
             with R.dataflow():
-                lv0 = R.add(inp_0, R.const(1, "float32"))
-                lv: R.Tensor((100, 10), dtype="float32") = R.permute_dims(w0, 
axes=None)
-                lv1: R.Tensor((10, 10), dtype="float32") = R.matmul(lv0, lv, 
out_dtype="float32")
-                lv2: R.Tensor((10, 10), dtype="float32") = R.add(lv1, w1)
-                lv3: R.Tensor((10, 10), dtype="float32") = R.nn.relu(lv2)
-                gv: R.Tensor((10, 10), dtype="float32") = lv3
+                lv: R.Tensor((10, 100), dtype="float32") = R.add(inp_0, 
R.const(1.0, "float32"))
+                lv1: R.Tensor((100, 10), dtype="float32") = R.permute_dims(w0, 
axes=None)
+                lv2: R.Tensor((10, 10), dtype="float32") = R.matmul(lv, lv1, 
out_dtype="float32")
+                lv3: R.Tensor((10, 10), dtype="float32") = R.add(lv2, w1)
+                lv4: R.Tensor((10, 10), dtype="float32") = R.nn.relu(lv3)
+                gv: R.Tensor((10, 10), dtype="float32") = lv4
                 R.output(gv)
             return gv
 
     model = Input3()
     mod = dynamo_capture_subgraphs(model, torch.randn(10, 100), add_one=True)
-    binding = {"w0": model.lin.weight.detach().numpy(), "w1": 
model.lin.bias.detach().numpy()}
-    binding = {k: tvm.nd.array(v) for k, v in binding.items()}
-    expected = relax.transform.BindParams("subgraph_0", binding)(Expected3)
-    tvm.ir.assert_structural_equal(mod, expected)
+    tvm.ir.assert_structural_equal(mod, Expected3)
 
 
-@pytest.mark.skipif(
-    version.parse(torch_version) >= version.parse("2.6.0"),
-    reason="Tests not compatible with PyTorch >= 2.6",
-)
 def verify_dynamo_model(torch_model, input_info, binding, expected):
     import torch
     import torch._dynamo as dynamo

Reply via email to