This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 49564f7a2c Add masked_fill_.scalar, logical_not.default in Exported Program frontend (#17909)
49564f7a2c is described below
commit 49564f7a2cc8b79c43323003bde94eb8df1fbef3
Author: Pratheesh-04-MCW <[email protected]>
AuthorDate: Wed Apr 30 10:28:00 2025 +0530
Add masked_fill_.scalar, logical_not.default in Exported Program frontend (#17909)
add op support for masked_fill_
---
.../frontend/torch/exported_program_translator.py | 2 ++
.../relax/test_frontend_from_exported_program.py | 24 +++++++++++++++++++++
tests/python/relax/test_frontend_from_fx.py | 25 ++++++++++++++++++++++
3 files changed, 51 insertions(+)
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 86f5de5f36..7919e288c2 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -307,6 +307,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
"log2.default": self._log2,
"log10.default": self._log10,
"log1p.default": self._log1p,
+ "logical_not.default": self._unary_op(relax.op.logical_not),
"log_softmax.int": self._log_softmax,
"neg.default": self._unary_op(relax.op.negative),
"pad.default": self._pad,
@@ -481,6 +482,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
"lift_fresh_copy.default": self._to_copy,
"linspace.default": self._linspace,
"masked_fill.Scalar": self._masked_fill,
+ "masked_fill_.Scalar": self._inplace_masked_fill,
"new_ones.default": self._new_ones,
"one_hot.default": self._one_hot,
"ones.default": self._ones,
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index dd1869a23c..b9385d1cc2 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -3727,6 +3727,30 @@ def test_masked_fill():
verify_model(Masked_Fill(), example_args, {}, Expected)
+def test_masked_fill_inplace():
+ class Masked_Fill_Inplace(Module):
+ def forward(self, input: torch.Tensor, mask: torch.Tensor):
+ return input.masked_fill_(mask, 1.5)
+
+ @tvm.script.ir_module
+ class Expected:
+ @R.function
+ def main(
+ input: R.Tensor((128, 128), dtype="float32"), mask: R.Tensor((128, 128), dtype="bool")
+ ) -> R.Tuple(R.Tensor((128, 128), dtype="float32")):
+ with R.dataflow():
+ lv: R.Tensor((128, 128), dtype="float32") = R.full_like(
+ input, R.const(1.5, "float32"), dtype="void"
+ )
+ lv1: R.Tensor((128, 128), dtype="float32") = R.where(mask, lv, input)
+ gv: R.Tuple(R.Tensor((128, 128), dtype="float32")) = (lv1,)
+ R.output(gv)
+ return gv
+
+ example_args = (torch.randn(128, 128, dtype=torch.float32), torch.rand(128, 128) < 0.5)
+ verify_model(Masked_Fill_Inplace(), example_args, {}, Expected)
+
+
def test_new_ones():
class NewOnes(Module):
def forward(self, x):
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index f60f158cbf..fdec5ed19c 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -3243,6 +3243,31 @@ def test_inplace_fill():
verify_model(InplaceFill(), [([10, 10], "float32")], {}, Expected)
+def test_masked_fill_inplace():
+ class Masked_Fill_Inplace(Module):
+ def forward(self, input: torch.Tensor, mask: torch.Tensor):
+ input.masked_fill_(mask, 1.5)
+ return input
+
+ @tvm.script.ir_module
+ class Expected:
+ @R.function
+ def main(
+ input: R.Tensor((10, 10), dtype="float32"), mask: R.Tensor((10, 10), dtype="bool")
+ ) -> R.Tensor((10, 10), dtype="float32"):
+ with R.dataflow():
+ lv: R.Tensor((10, 10), dtype="float32") = R.full_like(
+ input, R.const(1.5, "float32"), dtype="void"
+ )
+ lv1: R.Tensor((10, 10), dtype="float32") = R.where(mask, lv, input)
+ gv: R.Tensor((10, 10), dtype="float32") = lv1
+ R.output(gv)
+ return gv
+
+ input_info = [((10, 10), "float32"), ((10, 10), "bool")]
+ verify_model(Masked_Fill_Inplace(), input_info, {}, Expected)
+
+
def test_arange():
import numpy as np