This is an automated email from the ASF dual-hosted git repository.
mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
     new be37afdf30 [Relax][PyTorch] Add support for decomposed operators and fix IR of ops tests(2) (#18403)
be37afdf30 is described below
commit be37afdf300569ccb2fff5b71170985065597335
Author: Shushi Hong <[email protected]>
AuthorDate: Thu Oct 30 12:00:02 2025 -0400
    [Relax][PyTorch] Add support for decomposed operators and fix IR of ops tests(2) (#18403)
---
.../frontend/torch/exported_program_translator.py | 10 ++
.../relax/test_frontend_from_exported_program.py | 112 +++++++++++++--------
2 files changed, 79 insertions(+), 43 deletions(-)
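Background note: the handlers added below cover ATen ops that typically only surface once an exported program is decomposed, which is why the tests now pass run_ep_decomposition=True. A minimal sketch of how such a graph can be inspected (the module here is illustrative, not part of the commit):

    import torch

    class M(torch.nn.Module):
        def forward(self, x):
            return torch.nn.functional.softmax(x, dim=-1)

    ep = torch.export.export(M(), (torch.randn(2, 3),))
    ep = ep.run_decompositions()
    # After decomposition, softmax.int is lowered to aten._softmax.default,
    # one of the two spellings the importer table below now maps to _softmax.
    print(ep.graph)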
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 011e23f1df..5bb7a9ea8b 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -760,6 +760,14 @@ class ExportedProgramImporter(BaseFXGraphImporter):
         )
         return self.block_builder.emit(relax.op.zeros(size, dtype))
 
+    def _scalar_tensor(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        scalar_value = args[0]
+        dtype = self._convert_data_type(
+            node.kwargs.get("dtype", torch.get_default_dtype()), self.env
+        )
+        return self.block_builder.emit(relax.const(scalar_value, dtype))
+
     def _instance_norm(self, node: fx.Node):
         import numpy as np
@@ -851,6 +859,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             "relu6_.default": self._unary_op(relax.op.nn.relu6),
             "round.default": self._round,
             "rsqrt.default": self._unary_op(relax.op.rsqrt),
+            "scalar_tensor.default": self._scalar_tensor,
             "rsub.Tensor": self._rsub,
             "rsub.Scalar": self._rsub,
             "selu.default": self._unary_op(relax.op.nn.selu),
@@ -861,6 +870,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             "sin.default": self._unary_op(relax.op.sin),
             "sinh.default": self._unary_op(relax.op.sinh),
             "softmax.int": self._softmax,
+            "_softmax.default": self._softmax,
             "softplus.default": self._softplus,
             "softshrink.default": self._softshrink,
             "softsign.default": self._softsign,
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 9851804e2a..ac36c3fe8f 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -1103,8 +1103,8 @@ def test_softmax():
             return gv
 
     example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
-    verify_model(Softmax(), example_args, {}, expected1)
-    verify_model(Softmax2(), example_args, {}, expected1)
+    verify_model(Softmax(), example_args, {}, expected1, run_ep_decomposition=True)
+    verify_model(Softmax2(), example_args, {}, expected1, run_ep_decomposition=True)
 
 
 def test_softsign():
@@ -1135,8 +1135,8 @@ def test_softsign():
             return gv
 
     example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
-    verify_model(Softsign(), example_args, {}, expected_softsign)
-    verify_model(Softsign2(), example_args, {}, expected_softsign)
+    verify_model(Softsign(), example_args, {}, expected_softsign, run_ep_decomposition=True)
+    verify_model(Softsign2(), example_args, {}, expected_softsign, run_ep_decomposition=True)
 
 
 def test_softshrink():
@@ -1159,32 +1159,24 @@ def test_softshrink():
             input: R.Tensor((1, 3, 10, 10), dtype="float32"),
         ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
             with R.dataflow():
-                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.subtract(
-                    input, R.const(0.5, "float32")
-                )
-                lv1: R.Tensor((1, 3, 10, 10), dtype="bool") = R.greater(
-                    input, R.const(0.5, "float32")
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.abs(input)
+                lv1: R.Tensor((1, 3, 10, 10), dtype="bool") = R.greater(lv, R.const(0.5, "float32"))
+                lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.sign(input)
+                lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(
+                    lv2, R.const(0.5, "float32")
                 )
-                lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.astype(lv1, "float32")
-                lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(lv, lv2)
-
-                lv4: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(
-                    input, R.const(0.5, "float32")
+                lv4: R.Tensor((1, 3, 10, 10), dtype="float32") = R.subtract(input, lv3)
+                lv5: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(
+                    input, R.const(0.0, "float32")
                 )
-                lv5: R.Tensor((), dtype="float32") = R.negative(R.const(0.5, "float32"))
-                lv6: R.Tensor((1, 3, 10, 10), dtype="bool") = R.less(input, lv5)
-                lv7: R.Tensor((1, 3, 10, 10), dtype="float32") = R.astype(lv6, "float32")
-                lv8: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(lv4, lv7)
-
-                lv9: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(lv3, lv8)
-
-                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv9,)
+                lv6: R.Tensor((1, 3, 10, 10), dtype="float32") = R.where(lv1, lv4, lv5)
+                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv6,)
                 R.output(gv)
             return gv
 
     example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
-    verify_model(Softshrink(), example_args, {}, expected_softshrink)
-    verify_model(Softshrink2(), example_args, {}, expected_softshrink)
+    verify_model(Softshrink(), example_args, {}, expected_softshrink, run_ep_decomposition=True)
+    verify_model(Softshrink2(), example_args, {}, expected_softshrink, run_ep_decomposition=True)
 
 
 def test_tril_triu():
@@ -1198,16 +1190,27 @@ def test_tril_triu():
     class expected_tril:
         @R.function
         def main(
-            input_1: R.Tensor((10, 10), dtype="float32")
+            input: R.Tensor((10, 10), dtype="float32")
         ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
             # block 0
             with R.dataflow():
-                lv: R.Tensor((10, 10), dtype="float32") = R.tril(input_1, 1)
-                gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,)
+                lv: R.Tensor((10,), dtype="int64") = R.arange(
+                    R.prim_value(0), R.prim_value(10), R.prim_value(1), dtype="int64"
+                )
+                lv1: R.Tensor((1, 10), dtype="int64") = R.expand_dims(lv, axis=[-2])
+                lv2: R.Tensor((10,), dtype="int64") = R.arange(
+                    R.prim_value(0), R.prim_value(10), R.prim_value(1), dtype="int64"
+                )
+                lv3: R.Tensor((10, 1), dtype="int64") = R.expand_dims(lv2, axis=[-1])
+                lv4: R.Tensor((10, 10), dtype="int64") = R.subtract(lv1, lv3)
+                lv5: R.Tensor((10, 10), dtype="bool") = R.less_equal(lv4, R.const(1, "int64"))
+                lv6: R.Tensor((), dtype="float32") = R.const(0.0, "float32")
+                lv7: R.Tensor((10, 10), dtype="float32") = R.where(lv5, input, lv6)
+                gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv7,)
                 R.output(gv)
             return gv
 
-    verify_model(Tril(), example_args, {}, expected_tril)
+    verify_model(Tril(), example_args, {}, expected_tril, run_ep_decomposition=True)
 
     class Triu(Module):
         def forward(self, input):
@@ -1217,16 +1220,27 @@ def test_tril_triu():
     class expected_triu:
         @R.function
         def main(
-            input_1: R.Tensor((10, 10), dtype="float32")
+            input: R.Tensor((10, 10), dtype="float32")
         ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
             # block 0
             with R.dataflow():
-                lv: R.Tensor((10, 10), dtype="float32") = R.triu(input_1, 1)
-                gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,)
+                lv: R.Tensor((10,), dtype="int64") = R.arange(
+                    R.prim_value(0), R.prim_value(10), R.prim_value(1), dtype="int64"
+                )
+                lv1: R.Tensor((1, 10), dtype="int64") = R.expand_dims(lv, axis=[-2])
+                lv2: R.Tensor((10,), dtype="int64") = R.arange(
+                    R.prim_value(0), R.prim_value(10), R.prim_value(1), dtype="int64"
+                )
+                lv3: R.Tensor((10, 1), dtype="int64") = R.expand_dims(lv2, axis=[-1])
+                lv4: R.Tensor((10, 10), dtype="int64") = R.subtract(lv1, lv3)
+                lv5: R.Tensor((10, 10), dtype="bool") = R.greater_equal(lv4, R.const(1, "int64"))
+                lv6: R.Tensor((), dtype="float32") = R.const(0.0, "float32")
+                lv7: R.Tensor((10, 10), dtype="float32") = R.where(lv5, input, lv6)
+                gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv7,)
                 R.output(gv)
             return gv
 
-    verify_model(Triu(), example_args, {}, expected_triu)
+    verify_model(Triu(), example_args, {}, expected_triu, run_ep_decomposition=True)
 
 
 operator_binary_1 = [
@@ -1501,7 +1515,7 @@ def test_div_mode():
         torch.randn(64, 64, dtype=torch.float32),
         torch.randn(64, dtype=torch.float32),
     )
-    verify_model(DivModel(), example_args, {}, expected_div)
+    verify_model(DivModel(), example_args, {}, expected_div, run_ep_decomposition=True)
 
     # Case 2: Division with trunc rounding
     class DivTruncModel(torch.nn.Module):
@@ -1521,7 +1535,7 @@ def test_div_mode():
                 R.output(gv)
             return gv
 
-    verify_model(DivTruncModel(), example_args, {}, expected_div_trunc)
+    verify_model(DivTruncModel(), example_args, {}, expected_div_trunc, run_ep_decomposition=True)
 
     # Case 3: Division with floor rounding
     class DivFloorModel(torch.nn.Module):
@@ -1540,7 +1554,7 @@ def test_div_mode():
                 R.output(gv)
             return gv
 
-    verify_model(DivFloorModel(), example_args, {}, expected_div_floor)
+    verify_model(DivFloorModel(), example_args, {}, expected_div_floor, run_ep_decomposition=True)
 
 
 def test_batchnorm2d():
@@ -1578,6 +1592,8 @@ def test_batchnorm2d():
                     epsilon=1e-05,
                     center=True,
                     scale=True,
+                    momentum=1e-05,
+                    training=False,
                 )
                 lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[0]
                 gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv1,)
@@ -1593,7 +1609,7 @@ def test_batchnorm2d():
         "w3": model.bn.running_mean.detach().numpy(),
         "w4": model.bn.running_var.detach().numpy(),
     }
-    verify_model(model, example_args, binding, expected1)
+    verify_model(model, example_args, binding, expected1, run_ep_decomposition=True)
 
 
 def test_adaptive_avgpool1d():
@@ -1748,8 +1764,8 @@ def test_addmm():
         torch.randn(10, 10, dtype=torch.float32),
     )
 
-    verify_model(Addmm1(), example_args, {}, expected1)
-    verify_model(Addmm2(), example_args, {}, expected2)
+    verify_model(Addmm1(), example_args, {}, expected1, run_ep_decomposition=True)
+    verify_model(Addmm2(), example_args, {}, expected2, run_ep_decomposition=True)
 
 
 def test_avg_pool1d():
@@ -2054,8 +2070,10 @@ def test_baddbmm():
             inp_2: R.Tensor((4, 256, 512), dtype="float32"),
         ) -> R.Tuple(R.Tensor((4, 128, 512), dtype="float32")):
             with R.dataflow():
-                lv: R.Tensor((4, 128, 512), dtype="float32") = R.matmul(inp_1, inp_2)
-                lv1: R.Tensor((4, 128, 512), dtype="float32") = R.add(lv, inp_0)
+                lv: R.Tensor((4, 128, 512), dtype="float32") = R.matmul(
+                    inp_1, inp_2, out_dtype="float32"
+                )
+                lv1: R.Tensor((4, 128, 512), dtype="float32") = R.add(inp_0, lv)
                 gv: R.Tuple(R.Tensor((4, 128, 512), dtype="float32")) = (lv1,)
                 R.output(gv)
             return gv
@@ -2076,7 +2094,9 @@ def test_baddbmm():
             inp_2: R.Tensor((4, 256, 512), dtype="float32"),
         ) -> R.Tuple(R.Tensor((4, 128, 512), dtype="float32")):
             with R.dataflow():
-                lv: R.Tensor((4, 128, 512), dtype="float32") = R.matmul(inp_1, inp_2)
+                lv: R.Tensor((4, 128, 512), dtype="float32") = R.matmul(
+                    inp_1, inp_2, out_dtype="float32"
+                )
                 lv1: R.Tensor((4, 128, 512), dtype="float32") = R.multiply(
                     lv, R.const(2, "float32")
                 )
@@ -2100,14 +2120,16 @@ def test_baddbmm():
             inp_2: R.Tensor((4, 256, 512), dtype="float32"),
         ) -> R.Tuple(R.Tensor((4, 128, 512), dtype="float32")):
             with R.dataflow():
-                lv: R.Tensor((4, 128, 512), dtype="float32") = R.matmul(inp_1, inp_2)
+                lv: R.Tensor((4, 128, 512), dtype="float32") = R.matmul(
+                    inp_1, inp_2, out_dtype="float32"
+                )
                 lv1: R.Tensor((4, 128, 512), dtype="float32") = R.multiply(
                     lv, R.const(2, "float32")
                 )
                 lv2: R.Tensor((4, 128, 512), dtype="float32") = R.multiply(
                     inp_0, R.const(3, "float32")
                 )
-                lv3: R.Tensor((4, 128, 512), dtype="float32") = R.add(lv1, lv2)
+                lv3: R.Tensor((4, 128, 512), dtype="float32") = R.add(lv2, lv1)
                 gv: R.Tuple(R.Tensor((4, 128, 512), dtype="float32")) = (lv3,)
                 R.output(gv)
             return gv
@@ -2122,6 +2144,7 @@ def test_baddbmm():
         example_args,
         {},
         Expected1,
+        run_ep_decomposition=True,
     )
 
     verify_model(
@@ -2129,6 +2152,7 @@ def test_baddbmm():
         example_args,
         {},
         Expected2,
+        run_ep_decomposition=True,
    )
 
     verify_model(
@@ -2136,6 +2160,7 @@ def test_baddbmm():
         example_args,
         {},
         Expected3,
+        run_ep_decomposition=True,
     )
@@ -2172,6 +2197,7 @@ def test_bmm():
         example_args,
         {},
         Expected,
+        run_ep_decomposition=True,
     )
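As a sanity check on the updated expected IR, a NumPy sketch (illustrative, not part of the test suite) of the decomposed softshrink and tril computations shown above:

    import numpy as np

    def softshrink_ref(x, lambd=0.5):
        # mirrors expected_softshrink: where(|x| > lambd, x - sign(x) * lambd, x * 0)
        return np.where(np.abs(x) > lambd, x - np.sign(x) * lambd, x * 0.0)

    def tril_ref(x, diagonal=1):
        # mirrors expected_tril: keep entries where col - row <= diagonal
        row = np.arange(x.shape[0])[:, None]
        col = np.arange(x.shape[1])[None, :]
        return np.where(col - row <= diagonal, x, np.float32(0.0))

    x = np.random.randn(10, 10).astype("float32")
    assert np.allclose(tril_ref(x), np.tril(x, k=1))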