This is an automated email from the ASF dual-hosted git repository.
yongwww pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 9265328563 [Relax][PyTorch] Re-enable test_subgraph_capture in dynamo
test (#17925)
9265328563 is described below
commit 9265328563d3a59471ffa762b09e02f7eed01622
Author: Masahiro Hiramori <[email protected]>
AuthorDate: Tue May 20 10:59:41 2025 +0900
[Relax][PyTorch] Re-enable test_subgraph_capture in dynamo test (#17925)
* fix test_subgraph_capture
* no need to skip
---
tests/python/relax/test_frontend_dynamo.py | 36 +++++++++---------------------
1 file changed, 11 insertions(+), 25 deletions(-)
diff --git a/tests/python/relax/test_frontend_dynamo.py b/tests/python/relax/test_frontend_dynamo.py
index 3deed8c2bf..fb1544be68 100644
--- a/tests/python/relax/test_frontend_dynamo.py
+++ b/tests/python/relax/test_frontend_dynamo.py
@@ -157,10 +157,6 @@ def test_relax_dynamo_dynamic():
tvm.testing.assert_allclose(opt_func(x, y), opt_func(x, y))
-@pytest.mark.skipif(
- version.parse(torch_version) >= version.parse("2.6.0"),
- reason="Tests not compatible with PyTorch >= 2.6",
-)
def test_subgraph_capture():
import torch
from tvm.relax.frontend.torch.dynamo import dynamo_capture_subgraphs
@@ -178,13 +174,13 @@ def test_subgraph_capture():
@R.function
def subgraph_0(
inp_0: R.Tensor((10, 100), dtype="float32"),
- w0: R.Tensor((10, 100), dtype="float32"),
w1: R.Tensor((10,), dtype="float32"),
+ w0: R.Tensor((10, 100), dtype="float32"),
) -> R.Tensor((10, 10), dtype="float32"):
# block 0
with R.dataflow():
- lv: R.Tensor((100, 10), dtype="float32") = R.permute_dims(w0, axes=None)
- lv1: R.Tensor((10, 10), dtype="float32") = R.matmul(inp_0, lv, out_dtype="float32")
+ lv: R.Tensor((100, 10), dtype="float32") = R.permute_dims(inp_0, axes=None)
+ lv1: R.Tensor((10, 10), dtype="float32") = R.matmul(w0, lv, out_dtype="float32")
lv2: R.Tensor((10, 10), dtype="float32") = R.add(lv1, w1)
lv3: R.Tensor((10, 10), dtype="float32") = R.nn.relu(lv2)
gv: R.Tensor((10, 10), dtype="float32") = lv3
@@ -193,10 +189,7 @@ def test_subgraph_capture():
model = Input1()
mod = dynamo_capture_subgraphs(model, torch.randn(10, 100))
- binding = {"w0": model.lin.weight.detach().numpy(), "w1": model.lin.bias.detach().numpy()}
- binding = {k: tvm.nd.array(v) for k, v in binding.items()}
- expected = relax.transform.BindParams("subgraph_0", binding)(Expected1)
- tvm.ir.assert_structural_equal(mod, expected)
+ tvm.ir.assert_structural_equal(mod, Expected1)
def Input2(a, b):
x = a / (torch.sin(a) + 1)
@@ -258,27 +251,20 @@ def test_subgraph_capture():
) -> R.Tensor((10, 10), dtype="float32"):
# block 0
with R.dataflow():
- lv0 = R.add(inp_0, R.const(1, "float32"))
- lv: R.Tensor((100, 10), dtype="float32") = R.permute_dims(w0, axes=None)
- lv1: R.Tensor((10, 10), dtype="float32") = R.matmul(lv0, lv, out_dtype="float32")
- lv2: R.Tensor((10, 10), dtype="float32") = R.add(lv1, w1)
- lv3: R.Tensor((10, 10), dtype="float32") = R.nn.relu(lv2)
- gv: R.Tensor((10, 10), dtype="float32") = lv3
+ lv: R.Tensor((10, 100), dtype="float32") = R.add(inp_0, R.const(1.0, "float32"))
+ lv1: R.Tensor((100, 10), dtype="float32") = R.permute_dims(w0, axes=None)
+ lv2: R.Tensor((10, 10), dtype="float32") = R.matmul(lv, lv1, out_dtype="float32")
+ lv3: R.Tensor((10, 10), dtype="float32") = R.add(lv2, w1)
+ lv4: R.Tensor((10, 10), dtype="float32") = R.nn.relu(lv3)
+ gv: R.Tensor((10, 10), dtype="float32") = lv4
R.output(gv)
return gv
model = Input3()
mod = dynamo_capture_subgraphs(model, torch.randn(10, 100), add_one=True)
- binding = {"w0": model.lin.weight.detach().numpy(), "w1": model.lin.bias.detach().numpy()}
- binding = {k: tvm.nd.array(v) for k, v in binding.items()}
- expected = relax.transform.BindParams("subgraph_0", binding)(Expected3)
- tvm.ir.assert_structural_equal(mod, expected)
+ tvm.ir.assert_structural_equal(mod, Expected3)
-@pytest.mark.skipif(
- version.parse(torch_version) >= version.parse("2.6.0"),
- reason="Tests not compatible with PyTorch >= 2.6",
-)
def verify_dynamo_model(torch_model, input_info, binding, expected):
import torch
import torch._dynamo as dynamo