This is an automated email from the ASF dual-hosted git repository.
mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 5ca61bbf6d [Relax][PyTorch] Add support for decomposed operators and fix IR of ops tests(4) (#18414)
5ca61bbf6d is described below
commit 5ca61bbf6dee9f938e629802f4c395b078b441ce
Author: Shushi Hong <[email protected]>
AuthorDate: Sun Nov 2 20:27:24 2025 -0500
[Relax][PyTorch] Add support for decomposed operators and fix IR of ops tests(4) (#18414)
---
.../frontend/torch/exported_program_translator.py | 4 +
.../relax/test_frontend_from_exported_program.py | 154 +++++++++++++--------
2 files changed, 100 insertions(+), 58 deletions(-)
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py
b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 48ae002c05..3be255a29a 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -1003,6 +1003,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
"flip.default": self._flip,
"gather.default": self._gather,
"index.Tensor": self._index_tensor,
+ "index_put.default": self._index_put,
"index_put_.default": self._index_put,
"meshgrid.indexing": self._meshgrid,
"meshgrid.default": self._meshgrid,
@@ -1041,6 +1042,9 @@ class ExportedProgramImporter(BaseFXGraphImporter):
"contiguous.default": lambda node: self.env[node.args[0]], # no-op
"clone.default": lambda node: self.env[node.args[0]],
"bernoulli.p": lambda node: self.env[node.args[0]], # Dropout:
just return input
+ "_assert_tensor_metadata.default": lambda node: self.env[
+ node.args[0]
+ ], # metadata assertion: no-op
"empty.memory_format": self._empty,
"empty_permuted.default": self._empty, # Similar to empty with
permuted layout
"empty_like.default": self._empty_like,
diff --git a/tests/python/relax/test_frontend_from_exported_program.py
b/tests/python/relax/test_frontend_from_exported_program.py
index 019d649558..9f63743faa 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -5222,17 +5222,17 @@ def test_empty_like():
class Expected:
@R.function
def main(
- inp_0: R.Tensor((5,), dtype="float32"),
+ data: R.Tensor((5,), dtype="float32"),
) -> R.Tuple(R.Tensor((5,), dtype="float32")):
with R.dataflow():
- lv: R.Tensor((5,), dtype="float32") = R.zeros_like(inp_0,
dtype="void")
+ lv: R.Tensor((5,), dtype="float32") = R.zeros(R.shape([5]),
dtype="float32")
gv: R.Tuple(R.Tensor((5,), dtype="float32")) = (lv,)
R.output(gv)
return gv
example_args = (torch.randn(5, dtype=torch.float32),)
- verify_model(EmptyLike(), example_args, {}, Expected)
+ verify_model(EmptyLike(), example_args, {}, Expected,
run_ep_decomposition=True)
def test_one_hot():
@@ -5244,19 +5244,22 @@ def test_one_hot():
class Expected:
@R.function
def main(
- inp_0: R.Tensor((5,), dtype="int64"),
+ indices: R.Tensor((5,), dtype="int64"),
) -> R.Tuple(R.Tensor((5, 10), dtype="int64")):
with R.dataflow():
- lv: R.Tensor((5, 10), dtype="int64") = R.one_hot(
- inp_0, R.prim_value(1), R.prim_value(0), depth=10, axis=-1
+ lv: R.Tensor((10,), dtype="int64") = R.arange(
+ R.prim_value(0), R.prim_value(10), R.prim_value(1),
dtype="int64"
)
- gv: R.Tuple(R.Tensor((5, 10), dtype="int64")) = (lv,)
+ lv1: R.Tensor((5, 1), dtype="int64") = R.expand_dims(indices,
axis=[-1])
+ lv2: R.Tensor((5, 10), dtype="bool") = R.equal(lv1, lv)
+ lv3: R.Tensor((5, 10), dtype="int64") = R.astype(lv2,
dtype="int64")
+ gv: R.Tuple(R.Tensor((5, 10), dtype="int64")) = (lv3,)
R.output(gv)
return gv
example_args = (torch.randint(0, 10, (5,), dtype=torch.int64),)
- verify_model(OneHot(), example_args, {}, Expected)
+ verify_model(OneHot(), example_args, {}, Expected,
run_ep_decomposition=True)
def test_ones_like():
@@ -5271,14 +5274,16 @@ def test_ones_like():
input: R.Tensor((128, 128), dtype="float32")
) -> R.Tuple(R.Tensor((128, 128), dtype="float32")):
with R.dataflow():
- lv: R.Tensor((128, 128), dtype="float32") = R.ones_like(input,
dtype="void")
+ lv: R.Tensor((128, 128), dtype="float32") = R.full_like(
+ input, R.const(1, "int32"), dtype="void"
+ )
gv: R.Tuple(R.Tensor((128, 128), dtype="float32")) = (lv,)
R.output(gv)
return gv
example_args = (torch.rand(128, 128, dtype=torch.float32),)
- verify_model(OnesLike(), example_args, {}, Expected)
+ verify_model(OnesLike(), example_args, {}, Expected,
run_ep_decomposition=True)
def test_zero_inplace():
@@ -5291,16 +5296,23 @@ def test_zero_inplace():
@R.function
def main(
input: R.Tensor((128, 128), dtype="float32")
- ) -> R.Tuple(R.Tensor((128, 128), dtype="float32")):
+ ) -> R.Tuple(R.Tensor((128, 128), dtype="float32"), R.Tensor((128,
128), dtype="float32")):
with R.dataflow():
- lv: R.Tensor((128, 128), dtype="float32") =
R.zeros_like(input, dtype="void")
- gv: R.Tuple(R.Tensor((128, 128), dtype="float32")) = (lv,)
+ lv: R.Tensor((128, 128), dtype="float32") = R.full_like(
+ input, R.const(0, "int32"), dtype="void"
+ )
+ gv: R.Tuple(
+ R.Tensor((128, 128), dtype="float32"), R.Tensor((128,
128), dtype="float32")
+ ) = (
+ lv,
+ lv,
+ )
R.output(gv)
return gv
example_args = (torch.rand(128, 128, dtype=torch.float32),)
- verify_model(ZeroInplace(), example_args, {}, Expected)
+ verify_model(ZeroInplace(), example_args, {}, Expected,
run_ep_decomposition=True)
def test_zeros():
@@ -5315,14 +5327,16 @@ def test_zeros():
input: R.Tensor((128, 128), dtype="float32")
) -> R.Tuple(R.Tensor((5, 2), dtype="float32")):
with R.dataflow():
- lv: R.Tensor((5, 2), dtype="float32") = R.zeros(R.shape([5,
2]), dtype="float32")
+ lv: R.Tensor((5, 2), dtype="float32") = R.full(
+ R.shape([5, 2]), R.const(0.0, "float32"), dtype="float32"
+ )
gv: R.Tuple(R.Tensor((5, 2), dtype="float32")) = (lv,)
R.output(gv)
return gv
example_args = (torch.rand(128, 128, dtype=torch.float32),)
- verify_model(Zeros(), example_args, {}, Expected)
+ verify_model(Zeros(), example_args, {}, Expected,
run_ep_decomposition=True)
def test_zeros_like():
@@ -5337,13 +5351,15 @@ def test_zeros_like():
input: R.Tensor((128, 128), dtype="float32")
) -> R.Tuple(R.Tensor((128, 128), dtype="float32")):
with R.dataflow():
- lv: R.Tensor((128, 128), dtype="float32") =
R.zeros_like(input, dtype="void")
+ lv: R.Tensor((128, 128), dtype="float32") = R.full_like(
+ input, R.const(0, "int32"), dtype="void"
+ )
gv: R.Tuple(R.Tensor((128, 128), dtype="float32")) = (lv,)
R.output(gv)
return gv
example_args = (torch.rand(128, 128, dtype=torch.float32),)
- verify_model(ZerosLike(), example_args, {}, Expected)
+ verify_model(ZerosLike(), example_args, {}, Expected,
run_ep_decomposition=True)
def test_type_as():
@@ -5369,7 +5385,7 @@ def test_type_as():
torch.rand(128, 128, dtype=torch.float16),
)
- verify_model(TypeAs(), example_args, {}, Expected)
+ verify_model(TypeAs(), example_args, {}, Expected,
run_ep_decomposition=True)
def test_select():
@@ -5391,7 +5407,7 @@ def test_select():
example_args = (torch.randn(2, 3, dtype=torch.float32),)
- verify_model(Select(), example_args, {}, Expected)
+ verify_model(Select(), example_args, {}, Expected,
run_ep_decomposition=True)
def test_unflatten():
@@ -5417,8 +5433,8 @@ def test_unflatten():
example_args = (torch.randn(2, 15, 7, dtype=torch.float32),)
- verify_model(Unflatten(), example_args, {}, Expected)
- verify_model(Unflatten1(), example_args, {}, Expected)
+ verify_model(Unflatten(), example_args, {}, Expected,
run_ep_decomposition=True)
+ verify_model(Unflatten1(), example_args, {}, Expected,
run_ep_decomposition=True)
def test_gather():
@@ -5495,10 +5511,10 @@ def test_gather():
torch.randint(0, 3, (2, 3), dtype=torch.int64),
)
- verify_model(Gather0(), example_args, {}, Expected0)
- verify_model(Gather1(), example_args, {}, Expected1)
- verify_model(Gather2(), example_args, {}, Expected2)
- verify_model(Gather3(), example_args, {}, Expected3)
+ verify_model(Gather0(), example_args, {}, Expected0,
run_ep_decomposition=True)
+ verify_model(Gather1(), example_args, {}, Expected1,
run_ep_decomposition=True)
+ verify_model(Gather2(), example_args, {}, Expected2,
run_ep_decomposition=True)
+ verify_model(Gather3(), example_args, {}, Expected3,
run_ep_decomposition=True)
def test_index_put():
@@ -5521,12 +5537,15 @@ def test_index_put():
data: R.Tensor((64,), dtype="float32"),
indices_0: R.Tensor((128,), dtype="int64"),
values: R.Tensor((128,), dtype="float32"),
- ) -> R.Tuple(R.Tensor((64,), dtype="float32")):
+ ) -> R.Tuple(R.Tensor((64,), dtype="float32"), R.Tensor((64,),
dtype="float32")):
with R.dataflow():
lv: R.Tensor((64,), dtype="float32") = R.index_put(
data, R.tuple(indices_0), values, accumulate=False
)
- gv: R.Tuple(R.Tensor((64,), dtype="float32")) = (lv,)
+ gv: R.Tuple(R.Tensor((64,), dtype="float32"), R.Tensor((64,),
dtype="float32")) = (
+ lv,
+ lv,
+ )
R.output(gv)
return gv
@@ -5551,12 +5570,14 @@ def test_index_put():
indices_0: R.Tensor((128,), dtype="int64"),
indices_1: R.Tensor((128,), dtype="int64"),
values: R.Tensor((128,), dtype="float32"),
- ) -> R.Tuple(R.Tensor((32, 64), dtype="float32")):
+ ) -> R.Tuple(R.Tensor((32, 64), dtype="float32"), R.Tensor((32, 64),
dtype="float32")):
with R.dataflow():
lv: R.Tensor((32, 64), dtype="float32") = R.index_put(
data, R.tuple(indices_0, indices_1), values,
accumulate=False
)
- gv: R.Tuple(R.Tensor((32, 64), dtype="float32")) = (lv,)
+ gv: R.Tuple(
+ R.Tensor((32, 64), dtype="float32"), R.Tensor((32, 64),
dtype="float32")
+ ) = (lv, lv)
R.output(gv)
return gv
@@ -5583,12 +5604,16 @@ def test_index_put():
indices_1: R.Tensor((128,), dtype="int64"),
indices_2: R.Tensor((128,), dtype="int64"),
values: R.Tensor((128,), dtype="float32"),
- ) -> R.Tuple(R.Tensor((16, 32, 64), dtype="float32")):
+ ) -> R.Tuple(
+ R.Tensor((16, 32, 64), dtype="float32"), R.Tensor((16, 32, 64),
dtype="float32")
+ ):
with R.dataflow():
lv: R.Tensor((16, 32, 64), dtype="float32") = R.index_put(
data, R.tuple(indices_0, indices_1, indices_2), values,
accumulate=False
)
- gv: R.Tuple(R.Tensor((16, 32, 64), dtype="float32")) = (lv,)
+ gv: R.Tuple(
+ R.Tensor((16, 32, 64), dtype="float32"), R.Tensor((16, 32,
64), dtype="float32")
+ ) = (lv, lv)
R.output(gv)
return gv
@@ -5617,7 +5642,10 @@ def test_index_put():
indices_2: R.Tensor((128,), dtype="int64"),
indices_3: R.Tensor((128,), dtype="int64"),
values: R.Tensor((128,), dtype="float32"),
- ) -> R.Tuple(R.Tensor((8, 16, 32, 64), dtype="float32")):
+ ) -> R.Tuple(
+ R.Tensor((8, 16, 32, 64), dtype="float32"),
+ R.Tensor((8, 16, 32, 64), dtype="float32"),
+ ):
with R.dataflow():
lv: R.Tensor((8, 16, 32, 64), dtype="float32") = R.index_put(
data,
@@ -5625,7 +5653,10 @@ def test_index_put():
values,
accumulate=False,
)
- gv: R.Tuple(R.Tensor((8, 16, 32, 64), dtype="float32")) = (lv,)
+ gv: R.Tuple(
+ R.Tensor((8, 16, 32, 64), dtype="float32"),
+ R.Tensor((8, 16, 32, 64), dtype="float32"),
+ ) = (lv, lv)
R.output(gv)
return gv
@@ -5656,7 +5687,10 @@ def test_index_put():
indices_3: R.Tensor((128,), dtype="int64"),
indices_4: R.Tensor((128,), dtype="int64"),
values: R.Tensor((128,), dtype="float32"),
- ) -> R.Tuple(R.Tensor((4, 8, 16, 32, 64), dtype="float32")):
+ ) -> R.Tuple(
+ R.Tensor((4, 8, 16, 32, 64), dtype="float32"),
+ R.Tensor((4, 8, 16, 32, 64), dtype="float32"),
+ ):
with R.dataflow():
lv: R.Tensor((4, 8, 16, 32, 64), dtype="float32") =
R.index_put(
data,
@@ -5664,16 +5698,19 @@ def test_index_put():
values,
accumulate=False,
)
- gv: R.Tuple(R.Tensor((4, 8, 16, 32, 64), dtype="float32")) =
(lv,)
+ gv: R.Tuple(
+ R.Tensor((4, 8, 16, 32, 64), dtype="float32"),
+ R.Tensor((4, 8, 16, 32, 64), dtype="float32"),
+ ) = (lv, lv)
R.output(gv)
return gv
# Run verification for each case
- verify_model(IndexPut1D(), example_args_1d, {}, Expected1D)
- verify_model(IndexPut2D(), example_args_2d, {}, Expected2D)
- verify_model(IndexPut3D(), example_args_3d, {}, Expected3D)
- verify_model(IndexPut4D(), example_args_4d, {}, Expected4D)
- verify_model(IndexPut5D(), example_args_5d, {}, Expected5D)
+ verify_model(IndexPut1D(), example_args_1d, {}, Expected1D,
run_ep_decomposition=True)
+ verify_model(IndexPut2D(), example_args_2d, {}, Expected2D,
run_ep_decomposition=True)
+ verify_model(IndexPut3D(), example_args_3d, {}, Expected3D,
run_ep_decomposition=True)
+ verify_model(IndexPut4D(), example_args_4d, {}, Expected4D,
run_ep_decomposition=True)
+ verify_model(IndexPut5D(), example_args_5d, {}, Expected5D,
run_ep_decomposition=True)
def test_flip():
@@ -5711,8 +5748,8 @@ def test_flip():
example_args = (torch.randn(2, 2, dtype=torch.float32),)
- verify_model(Flip0(), example_args, {}, Expected0)
- verify_model(Flip1(), example_args, {}, Expected1)
+ verify_model(Flip0(), example_args, {}, Expected0,
run_ep_decomposition=True)
+ verify_model(Flip1(), example_args, {}, Expected1,
run_ep_decomposition=True)
def test_take():
@@ -5724,12 +5761,12 @@ def test_take():
class Expected:
@R.function
def main(
- inp_0: R.Tensor((5,), dtype="float32"),
- inp_1: R.Tensor((3,), dtype="int64"),
+ data: R.Tensor((5,), dtype="float32"),
+ indices: R.Tensor((3,), dtype="int64"),
) -> R.Tuple(R.Tensor((3,), dtype="float32")):
with R.dataflow():
- lv: R.Tensor((3,), dtype="int32") = R.astype(inp_1,
dtype="int32")
- lv1: R.Tensor((3,), dtype="float32") = R.take(inp_0, lv,
axis=None)
+ lv: R.Tensor((5,), dtype="float32") = R.reshape(data,
R.shape([5]))
+ lv1: R.Tensor((3,), dtype="float32") = R.index_tensor(lv,
(indices,))
gv: R.Tuple(R.Tensor((3,), dtype="float32")) = (lv1,)
R.output(gv)
return gv
@@ -5739,7 +5776,7 @@ def test_take():
torch.randint(0, 5, (3,), dtype=torch.int64),
)
- verify_model(Take(), example_args, {}, Expected)
+ verify_model(Take(), example_args, {}, Expected, run_ep_decomposition=True)
def test_std():
@@ -5751,16 +5788,17 @@ def test_std():
class Expected:
@R.function
def main(
- inp_0: R.Tensor((5, 3), dtype="float32"),
+ x: R.Tensor((5, 3), dtype="float32"),
) -> R.Tuple(R.Tensor((), dtype="float32")):
with R.dataflow():
- lv: R.Tensor((), dtype="float32") = R.std(inp_0, axis=None,
keepdims=False)
- gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv,)
+ lv: R.Tensor((), dtype="float32") = R.variance(x, axis=None,
keepdims=False)
+ lv1: R.Tensor((), dtype="float32") = R.sqrt(lv)
+ gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv1,)
R.output(gv)
return gv
example_args = (torch.randn(5, 3, dtype=torch.float32),)
- verify_model(Std(), example_args, {}, Expected)
+ verify_model(Std(), example_args, {}, Expected, run_ep_decomposition=True)
def test_var():
@@ -5772,16 +5810,16 @@ def test_var():
class Expected:
@R.function
def main(
- inp_0: R.Tensor((5, 3), dtype="float32"),
+ x: R.Tensor((5, 3), dtype="float32"),
) -> R.Tuple(R.Tensor((), dtype="float32")):
with R.dataflow():
- lv: R.Tensor((), dtype="float32") = R.variance(inp_0,
axis=None, keepdims=False)
+ lv: R.Tensor((), dtype="float32") = R.variance(x, axis=None,
keepdims=False)
gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv,)
R.output(gv)
return gv
example_args = (torch.randn(5, 3, dtype=torch.float32),)
- verify_model(Var(), example_args, {}, Expected)
+ verify_model(Var(), example_args, {}, Expected, run_ep_decomposition=True)
def test_prod():
@@ -5793,16 +5831,16 @@ def test_prod():
class Expected:
@R.function
def main(
- inp_0: R.Tensor((5, 3), dtype="float32"),
+ x: R.Tensor((5, 3), dtype="float32"),
) -> R.Tuple(R.Tensor((), dtype="float32")):
with R.dataflow():
- lv: R.Tensor((), dtype="float32") = R.prod(inp_0, axis=None,
keepdims=False)
+ lv: R.Tensor((), dtype="float32") = R.prod(x, axis=None,
keepdims=False)
gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv,)
R.output(gv)
return gv
example_args = (torch.randn(5, 3, dtype=torch.float32),)
- verify_model(Prod(), example_args, {}, Expected)
+ verify_model(Prod(), example_args, {}, Expected, run_ep_decomposition=True)
def test_cumprod():