This is an automated email from the ASF dual-hosted git repository.
mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 9249061ad8 [Relax][PyTorch] Add support for decomposed operators and fix IR of ops tests(3) (#18410)
9249061ad8 is described below
commit 9249061ad80f5ab7d06ff9d259dbdc4b190b7e6c
Author: Shushi Hong <[email protected]>
AuthorDate: Sat Nov 1 10:10:57 2025 -0400
[Relax][PyTorch] Add support for decomposed operators and fix IR of ops tests(3) (#18410)
* finish1
* finish2
* finish3
---
.../frontend/torch/base_fx_graph_translator.py | 3 +
.../frontend/torch/exported_program_translator.py | 2 +
.../relax/test_frontend_from_exported_program.py | 122 ++++++++++++++++-----
3 files changed, 97 insertions(+), 30 deletions(-)
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index b17f62738f..aedef8acf8 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1722,6 +1722,9 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
def _squeeze(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
+ # Support both "dim" and "dims" parameters
+ if dim is None:
+ dim = node.kwargs.get("dims", None)
return self.block_builder.emit(relax.op.squeeze(x, dim))
def _stack(self, node: fx.Node) -> relax.Var:
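The extra "dims" lookup covers the aten::squeeze.dims overload: when torch.squeeze receives a tuple of axes, torch.export records the keyword argument as "dims" rather than "dim". A minimal sketch of the triggering pattern, assuming a recent torch release (module and variable names illustrative):

    import torch

    class SqueezeDims(torch.nn.Module):
        def forward(self, x):
            # A tuple of axes exports as aten::squeeze.dims, whose
            # keyword argument is "dims", not "dim".
            return torch.squeeze(x, (0, 2))

    ep = torch.export.export(SqueezeDims(), (torch.randn(1, 3, 1),))
    print([n.target for n in ep.graph.nodes if n.op == "call_function"])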
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 5bb7a9ea8b..48ae002c05 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -1018,6 +1018,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
"split_with_sizes.default": self._split,
"squeeze.default": self._squeeze,
"squeeze.dim": self._squeeze,
+ "squeeze.dims": self._squeeze,
"stack.default": self._stack,
"take.default": self._take,
"tile.default": self._tile,
@@ -1075,6 +1076,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
# other
"getitem": self._getitem,
"item.default": self._item,
+ "_local_scalar_dense.default": self._item,
}
def create_input_vars(
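Routing "_local_scalar_dense.default" to the existing _item handler matters because Tensor.item() decomposes to aten::_local_scalar_dense when the exported program is run through decompositions. A minimal sketch, assuming a recent torch release (the test suite exercises the same pattern):

    import torch

    class Item(torch.nn.Module):
        def forward(self, x):
            return x.item()

    # After run_decompositions(), item.default typically shows up as
    # _local_scalar_dense.default in the call_function targets.
    ep = torch.export.export(Item(), (torch.randn(1),)).run_decompositions()
    print([n.target for n in ep.graph.nodes if n.op == "call_function"])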
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index ac36c3fe8f..019d649558 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -5823,7 +5823,7 @@ def test_cumprod():
return gv
example_input = torch.randn(5, 3, dtype=torch.float32)
- verify_model(Cumprod(), (example_input,), {}, Expected)
+ verify_model(Cumprod(), (example_input,), {}, Expected, run_ep_decomposition=True)
def test_where():
@@ -5849,7 +5849,7 @@ def test_where():
x = torch.randn(5, 3, dtype=torch.float32)
y = torch.randn(5, 3, dtype=torch.float32)
- verify_model(Where(), (condition, x, y), {}, Expected)
+ verify_model(Where(), (condition, x, y), {}, Expected, run_ep_decomposition=True)
def test_bucketize():
@@ -5874,7 +5874,7 @@ def test_bucketize():
input_tensor = torch.arange(0, 20)
boundaries = torch.arange(0, 20, 2)
- verify_model(Bucketize(), (input_tensor, boundaries), {}, Expected)
+ verify_model(Bucketize(), (input_tensor, boundaries), {}, Expected, run_ep_decomposition=True)
def test_argsort():
@@ -5890,12 +5890,18 @@ def test_argsort():
lv: R.Tensor((5, 3), dtype="int32") = R.argsort(
x, axis=1, descending=True, dtype="int32"
)
- gv: R.Tuple(R.Tensor((5, 3), dtype="int32")) = (lv,)
+ lv1: R.Tensor((5, 3), dtype="float32") = R.gather_elements(x, lv, axis=1)
+ lv2: R.Tuple(R.Tensor((5, 3), dtype="float32"), R.Tensor((5, 3), dtype="int32")) = (
+ lv1,
+ lv,
+ )
+ lv3: R.Tensor((5, 3), dtype="int32") = lv2[1]
+ gv: R.Tuple(R.Tensor((5, 3), dtype="int32")) = (lv3,)
R.output(gv)
return gv
example_args = (torch.randn(5, 3, dtype=torch.float32),)
- verify_model(Argsort(), example_args, {}, Expected)
+ verify_model(Argsort(), example_args, {}, Expected, run_ep_decomposition=True)
def test_topk():
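The reworked expected IR mirrors how decomposition treats torch.argsort: it lowers to torch.sort, which returns a (values, indices) pair, so the graph also builds the sorted values with gather_elements and then discards them by indexing the tuple. A minimal numpy sketch of the same pattern (helper name hypothetical; negating the input stands in for descending=True, which numpy's argsort lacks):

    import numpy as np

    def sort_via_argsort(x, axis=1):
        order = np.argsort(-x, axis=axis)                 # R.argsort(..., descending=True)
        values = np.take_along_axis(x, order, axis=axis)  # R.gather_elements
        return values, order                              # the (lv1, lv) tuple; only order is used

    values, indices = sort_via_argsort(np.random.randn(5, 3).astype("float32"))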
@@ -5923,7 +5929,7 @@ def test_topk():
return gv
example_args = (torch.randn(5, 3, dtype=torch.float32),)
- verify_model(Topk(), example_args, {}, Expected)
+ verify_model(Topk(), example_args, {}, Expected, run_ep_decomposition=True)
def test_dynamic_shape():
@@ -5972,7 +5978,7 @@ def test_broadcast_to():
return gv
example_args = (torch.randn(5, 1, dtype=torch.float32),)
- verify_model(BroadcastTo(), example_args, {}, Expected)
+ verify_model(BroadcastTo(), example_args, {}, Expected, run_ep_decomposition=True)
def test_narrow():
@@ -5992,6 +5998,7 @@ def test_narrow():
(R.prim_value(1),),
(R.prim_value(0),),
(R.prim_value(2),),
+ (R.prim_value(1),),
assume_inbound=False,
)
gv: R.Tuple(R.Tensor((5, 2), dtype="float32")) = (lv,)
@@ -6000,7 +6007,7 @@ def test_narrow():
return gv
example_args = (torch.randn(5, 3, dtype=torch.float32),)
- verify_model(Narrow(), example_args, {}, Expected)
+ verify_model(Narrow(), example_args, {}, Expected, run_ep_decomposition=True)
def test_item():
@@ -6019,7 +6026,7 @@ def test_item():
return gv
example_args = (torch.randn(1, dtype=torch.float32),)
- verify_model(Item(), example_args, {}, Expected)
+ verify_model(Item(), example_args, {}, Expected, run_ep_decomposition=True)
def test_norm():
@@ -6131,7 +6138,9 @@ def test_norm():
example_args = (torch.randn(1, 3, 5, 3, dtype=torch.float32),)
for (p, dim, keepdim), expected in norms:
- verify_model(Norm(p, dim=dim, keepdim=keepdim), example_args, {}, expected)
+ verify_model(
+ Norm(p, dim=dim, keepdim=keepdim), example_args, {}, expected, run_ep_decomposition=True
+ )
def test_eye():
@@ -6146,8 +6155,20 @@ def test_eye():
input: R.Tensor((3, 5), dtype="float32")
) -> R.Tuple(R.Tensor((3, 5), dtype="float32")):
with R.dataflow():
- lv: R.Tensor((3, 5), dtype="float32") = R.eye(3, 5, dtype="float32")
- gv: R.Tuple(R.Tensor((3, 5), dtype="float32")) = (lv,)
+ lv: R.Tensor((3,), dtype="int64") = R.arange(
+ R.prim_value(0), R.prim_value(3), R.prim_value(1), dtype="int64"
+ )
+ lv1: R.Tensor((5,), dtype="int64") = R.arange(
+ R.prim_value(0), R.prim_value(5), R.prim_value(1), dtype="int64"
+ )
+ lv2: R.Tensor((3, 1), dtype="int64") = R.expand_dims(lv, axis=[-1])
+ lv3: R.Tensor((3, 5), dtype="bool") = R.equal(lv2, lv1)
+ lv4: R.Tensor((1,), dtype="float32") = R.full(
+ R.shape([1]), R.const(1.0, "float32"), dtype="float32"
+ )
+ lv5: R.Tensor((), dtype="float32") = R.const(0.0, "float32")
+ lv6: R.Tensor((3, 5), dtype="float32") = R.where(lv3, lv4, lv5)
+ gv: R.Tuple(R.Tensor((3, 5), dtype="float32")) = (lv6,)
R.output(gv)
return gv
@@ -6162,16 +6183,28 @@ def test_eye():
input: R.Tensor((5,), dtype="float32")
) -> R.Tuple(R.Tensor((5, 5), dtype="float32")):
with R.dataflow():
- lv: R.Tensor((5, 5), dtype="float32") = R.eye(5, dtype="float32")
- gv: R.Tuple(R.Tensor((5, 5), dtype="float32")) = (lv,)
+ lv: R.Tensor((5,), dtype="int64") = R.arange(
+ R.prim_value(0), R.prim_value(5), R.prim_value(1), dtype="int64"
+ )
+ lv1: R.Tensor((5,), dtype="int64") = R.arange(
+ R.prim_value(0), R.prim_value(5), R.prim_value(1), dtype="int64"
+ )
+ lv2: R.Tensor((5, 1), dtype="int64") = R.expand_dims(lv, axis=[-1])
+ lv3: R.Tensor((5, 5), dtype="bool") = R.equal(lv2, lv1)
+ lv4: R.Tensor((1,), dtype="float32") = R.full(
+ R.shape([1]), R.const(1.0, "float32"), dtype="float32"
+ )
+ lv5: R.Tensor((), dtype="float32") = R.const(0.0, "float32")
+ lv6: R.Tensor((5, 5), dtype="float32") = R.where(lv3, lv4, lv5)
+ gv: R.Tuple(R.Tensor((5, 5), dtype="float32")) = (lv6,)
R.output(gv)
return gv
example_args1 = (torch.randn(3, 5, dtype=torch.float32),)
- verify_model(Eye1(), example_args1, {}, Expected1)
+ verify_model(Eye1(), example_args1, {}, Expected1, run_ep_decomposition=True)
example_args2 = (torch.randn(5, dtype=torch.float32),)
- verify_model(Eye2(), example_args2, {}, Expected2)
+ verify_model(Eye2(), example_args2, {}, Expected2, run_ep_decomposition=True)
def test_cross_entropy():
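Both expected modules encode the standard eye decomposition: compare a column of row indices against a row of column indices and select 1.0 where they match, 0.0 elsewhere. A minimal numpy sketch (helper name hypothetical):

    import numpy as np

    def eye_decomposed(n, m):
        rows = np.arange(n)           # R.arange(0, n, 1)
        cols = np.arange(m)           # R.arange(0, m, 1)
        mask = rows[:, None] == cols  # R.expand_dims + R.equal -> (n, m) bool
        # R.where picks 1.0 on the diagonal and 0.0 off it.
        return np.where(mask, np.float32(1.0), np.float32(0.0))

    assert np.array_equal(eye_decomposed(3, 5), np.eye(3, 5, dtype="float32"))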
@@ -6187,21 +6220,39 @@ def test_cross_entropy():
@tvm.script.ir_module
class Expected1:
@R.function
- def main(x: R.Tensor((4, 3), dtype="float32")) -> R.Tuple(R.Tensor((), dtype="float32")):
+ def main(x: R.Tensor((4, 3), dtype="float32")) -> R.Tuple(R.Tensor((4,), dtype="float32")):
with R.dataflow():
- lv: R.Tensor((4, 3), dtype="float32") = R.nn.log_softmax(x, axis=-1)
- lv1: R.Tensor((), dtype="float32") = R.nn.nll_loss(
- lv,
- targets=R.const([0, 1, 2, 1], dtype="int64"),
- reduction="mean",
- ignore_index=-100,
+ lv: R.Tensor((4, 3), dtype="float32") = R.astype(x, dtype="float32")
+ lv1: R.Tensor((4, 3), dtype="float32") = R.nn.log_softmax(lv, axis=1)
+ lv2: R.Tensor((4,), dtype="bool") = R.not_equal(
+ R.const([0, 1, 2, 1], dtype="int64"), R.const(-100, "int64")
+ )
+ lv3: R.Tensor((), dtype="int64") = R.const(0, "int64")
+ lv4: R.Tensor((4,), dtype="int64") = R.where(
+ lv2, R.const([0, 1, 2, 1], dtype="int64"), lv3
)
- gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv1,)
+ lv5: R.Tensor((4, 1), dtype="int64") = R.expand_dims(lv4, axis=[1])
+ lv6: R.Tensor((4, 1), dtype="float32") = R.gather_elements(lv1, lv5, axis=1)
+ lv7: R.Tensor((4,), dtype="float32") = R.squeeze(lv6, axis=[1])
+ lv8: R.Tensor((4,), dtype="float32") = R.negative(lv7)
+ lv9: R.Tensor((4,), dtype="bool") = R.not_equal(
+ R.const([0, 1, 2, 1], dtype="int64"), R.const(-100, "int64")
+ )
+ lv10: R.Tensor((), dtype="float32") = R.const(0.0, "float32")
+ lv11: R.Tensor((4,), dtype="float32") = R.where(lv9, lv8, lv10)
+ lv12: R.Tensor((4,), dtype="bool") = R.not_equal(
+ R.const([0, 1, 2, 1], dtype="int64"), R.const(-100, "int64")
+ )
+ lv13: R.Tensor((4,), dtype="bool") = R.sum(lv12, axis=[], keepdims=False)
+ lv14: R.Tensor((4,), dtype="float32") = R.astype(lv13, dtype="float32")
+ lv15: R.Tensor((4,), dtype="float32") = R.sum(lv11, axis=[], keepdims=False)
+ lv16: R.Tensor((4,), dtype="float32") = R.divide(lv15, lv14)
+ gv: R.Tuple(R.Tensor((4,), dtype="float32")) = (lv16,)
R.output(gv)
return gv
example_args1 = (torch.randn(4, 3, dtype=torch.float32),)
- verify_model(CrossEntropyModule(), example_args1, {}, Expected1)
+ verify_model(CrossEntropyModule(), example_args1, {}, Expected1, run_ep_decomposition=True)
def test_linspace():
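The new Expected1 spells out the decomposed cross-entropy recipe: log_softmax, an ignore_index mask, a gather of the target log-probabilities, and a masked mean over the valid rows. A minimal numpy sketch of the same arithmetic (helper name hypothetical):

    import numpy as np

    def cross_entropy_decomposed(x, targets, ignore_index=-100):
        logp = x - np.log(np.exp(x).sum(axis=1, keepdims=True))  # R.nn.log_softmax
        valid = targets != ignore_index                          # R.not_equal
        safe = np.where(valid, targets, 0)                       # zero out ignored targets
        picked = np.take_along_axis(logp, safe[:, None], axis=1)[:, 0]
        loss = np.where(valid, -picked, 0.0)                     # R.negative + mask
        return loss.sum() / valid.sum()                          # masked mean

    x = np.random.randn(4, 3).astype("float32")
    print(cross_entropy_decomposed(x, np.array([0, 1, 2, 1])))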
@@ -6216,13 +6267,24 @@ def test_linspace():
input: R.Tensor((9, 9), dtype="float32")
) -> R.Tuple(R.Tensor((9,), dtype="float32")):
with R.dataflow():
- lv: R.Tensor((9,), dtype="float32") = R.arange(0, 1.0625, 0.125, dtype="float32")
- gv: R.Tuple(R.Tensor((9,), dtype="float32")) = (lv,)
+ lv: R.Tensor((9,), dtype="int64") = R.arange(
+ R.prim_value(0), R.prim_value(9), R.prim_value(1), dtype="int64"
+ )
+ lv1: R.Tensor((9,), dtype="bool") = R.less(lv, R.const(4, "int64"))
+ lv2: R.Tensor((9,), dtype="float32") = R.astype(lv, dtype="float32")
+ lv3: R.Tensor((9,), dtype="float32") = R.multiply(lv2, R.const(0.125, "float32"))
+ lv4: R.Tensor((9,), dtype="float32") = R.add(lv3, R.const(0.0, "float32"))
+ lv5: R.Tensor((9,), dtype="int64") = R.subtract(R.const(8, "int64"), lv)
+ lv6: R.Tensor((9,), dtype="float32") = R.astype(lv5, dtype="float32")
+ lv7: R.Tensor((9,), dtype="float32") = R.multiply(lv6, R.const(0.125, "float32"))
+ lv8: R.Tensor((9,), dtype="float32") = R.subtract(R.const(1.0, "float32"), lv7)
+ lv9: R.Tensor((9,), dtype="float32") = R.where(lv1, lv4, lv8)
+ gv: R.Tuple(R.Tensor((9,), dtype="float32")) = (lv9,)
R.output(gv)
return gv
example_args = (torch.randn(9, 9, dtype=torch.float32),)
- verify_model(Linspace(), example_args, {}, Expected)
+ verify_model(Linspace(), example_args, {}, Expected, run_ep_decomposition=True)
@pytest.mark.parametrize(
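The expected IR above follows the two-sided linspace decomposition: the first half of the points counts up from start, the second half counts down from end, which keeps both endpoints exact. A minimal numpy sketch (helper name hypothetical):

    import numpy as np

    def linspace_decomposed(start, end, steps):
        i = np.arange(steps)                # R.arange(0, steps, 1)
        step = (end - start) / (steps - 1)  # 0.125 for (0.0, 1.0, 9)
        fwd = start + i * step              # lv4: count up from start
        bwd = end - (steps - 1 - i) * step  # lv8: count down from end
        return np.where(i < steps // 2, fwd, bwd).astype("float32")  # R.where

    assert np.allclose(linspace_decomposed(0.0, 1.0, 9), np.linspace(0.0, 1.0, 9))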
@@ -6259,7 +6321,7 @@ def test_dtypes(torch_dtype, relax_dtype):
R.output(gv)
return gv
- verify_model(Model(), example_args, {}, Expected)
+ verify_model(Model(), example_args, {}, Expected, run_ep_decomposition=True)
def test_mm():
@@ -6285,7 +6347,7 @@ def test_mm():
R.output(gv)
return gv
- verify_model(MatrixMultiply(), example_args, {}, Expected)
+ verify_model(MatrixMultiply(), example_args, {}, Expected, run_ep_decomposition=True)
def test_lstm():