This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 2de5b1e08c [Relax][PyTorch] Add torch.isin Op Support for Exported Program and FX graph (#17878)
2de5b1e08c is described below
commit 2de5b1e08c8e8541412a701f3714a6cc0dd69d10
Author: Deivanayaki S <[email protected]>
AuthorDate: Sat Apr 26 09:25:23 2025 +0530
[Relax][PyTorch] Add torch.isin Op Support for Exported Program and FX graph (#17878)
* add torch.isin op support into torch frontends
* fix lint issues in test script
---------
Co-authored-by: deivanayakisankaralingam <deiva@Deivanayaki>
---
.../frontend/torch/base_fx_graph_translator.py | 14 ++++++++++
.../frontend/torch/exported_program_translator.py | 1 +
python/tvm/relax/frontend/torch/fx_translator.py | 1 +
.../relax/test_frontend_from_exported_program.py | 31 ++++++++++++++++++++++
tests/python/relax/test_frontend_from_fx.py | 29 ++++++++++++++++++++
5 files changed, 76 insertions(+)
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 3e81ff1f0b..c1a1a61398 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -417,6 +417,20 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
return self.block_builder.emit(relax.op.subtract(rhs, lhs))
+ def _isin(self, node: fx.Node) -> relax.Var:
+ args = self.retrieve_args(node)
+ elements = args[0]
+ test_elements = args[1]
+
+ expanded_elements = relax.op.expand_dims(elements, axis=-1)
+ flattened_test_elements = relax.op.reshape(test_elements, (-1,))
+
+ comparison = relax.op.equal(expanded_elements, flattened_test_elements)
+ summed = relax.op.sum(comparison, axis=-1)
+ result = relax.op.greater(summed, relax.const(0, dtype=elements.struct_info.dtype))
+
+ return self.block_builder.emit(result)
+
########## Neural Network ##########
def _adaptive_avg_pool2d(self, node: fx.Node) -> relax.Var:
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index a3ab575c4b..88f6dd538d 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -299,6 +299,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
"hardtanh_.default": self._hardtanh,
"isfinite.default": self._unary_op(relax.op.isfinite),
"isinf.default": self._unary_op(relax.op.isinf),
+ "isin.Tensor_Tensor": self._isin,
"isnan.default": self._unary_op(relax.op.isnan),
"leaky_relu.default": self._leakyrelu,
"leaky_relu_.default": self._leakyrelu,
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 18dba2d988..0d3dafc8d5 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -693,6 +693,7 @@ class TorchFXImporter(BaseFXGraphImporter):
"hardtanh": self._hardtanh,
"isfinite": self._unary_op(relax.op.isfinite),
"isinf": self._unary_op(relax.op.isinf),
+ "isin": self._isin,
"isnan": self._unary_op(relax.op.isnan),
"leaky_relu": self._leakyrelu,
"log": self._unary_op(relax.op.log),
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index e3b6f4ad9c..8cc3dde397 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -1060,6 +1060,37 @@ def test_binary3():
verify_model(RSub2(), example_args2, {}, expected_rsub2)
+# IsIn
+
+
+def test_isin():
+ class IsInModel(torch.nn.Module):
+ def forward(self, x, test_elements):
+ return torch.isin(x, test_elements)
+
+ @tvm.script.ir_module
+ class expected:
+ @R.function
+ def main(
+ x: R.Tensor((10, 10), dtype="float32"), test_elements: R.Tensor((8,), dtype="float32")
+ ) -> R.Tuple(R.Tensor((10, 10), dtype="bool")):
+ with R.dataflow():
+ lv: R.Tensor((10, 10, 1), dtype="float32") = R.expand_dims(x, axis=[-1])
+ lv1: R.Tensor((8,), dtype="float32") = R.reshape(test_elements, R.shape([8]))
+ lv2: R.Tensor((10, 10, 8), dtype="bool") = R.equal(lv, lv1)
+ lv3: R.Tensor((10, 10), dtype="bool") = R.sum(lv2, axis=[-1], keepdims=False)
+ lv4: R.Tensor((10, 10), dtype="bool") = R.greater(lv3, R.const(0.0, "float32"))
+ gv: R.Tuple(R.Tensor((10, 10), dtype="bool")) = (lv4,)
+ R.output(gv)
+ return gv
+
+ example_args = (
+ torch.randn(10, 10, dtype=torch.float32),
+ torch.randn(8, dtype=torch.float32),
+ )
+ verify_model(IsInModel(), example_args, {}, expected)
+
+
def test_batchnorm2d():
class BatchNorm2d(Module):
def __init__(self):
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index 4003202d4f..48c2cec8c0 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -1868,6 +1868,35 @@ def test_rsub():
verify_model(RSub2(), input_info2, {}, expected_rsub2)
+# IsIn
+
+
+def test_isin():
+ input_info = [([10, 10], "float32"), ([8], "float32")]
+
+ class IsInModel(torch.nn.Module):
+ def forward(self, x, test_elements):
+ return torch.isin(x, test_elements)
+
+ @tvm.script.ir_module
+ class expected:
+ @R.function
+ def main(
+ inp_0: R.Tensor((10, 10), dtype="float32"), inp_1: R.Tensor((8,), dtype="float32")
+ ) -> R.Tensor((10, 10), dtype="bool"):
+ with R.dataflow():
+ lv: R.Tensor((10, 10, 1), dtype="float32") = R.expand_dims(inp_0, axis=[-1])
+ lv1: R.Tensor((8,), dtype="float32") = R.reshape(inp_1, R.shape([8]))
+ lv2: R.Tensor((10, 10, 8), dtype="bool") = R.equal(lv, lv1)
+ lv3: R.Tensor((10, 10), dtype="bool") = R.sum(lv2, axis=[-1], keepdims=False)
+ lv4: R.Tensor((10, 10), dtype="bool") = R.greater(lv3, R.const(0.0, "float32"))
+ gv: R.Tensor((10, 10), dtype="bool") = lv4
+ R.output(gv)
+ return gv
+
+ verify_model(IsInModel(), input_info, {}, expected)
+
+
def test_size():
input_info = [([1, 3, 10, 10], "float32")]