This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 26b107fa12 [Relax][PyTorch] Add support for masked_select (#18535)
26b107fa12 is described below

commit 26b107fa12672c3b958da222fc87755a69d64c42
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Mon Dec 8 03:59:25 2025 +0800

    [Relax][PyTorch] Add support for masked_select (#18535)
    
    ## How
    
    Add support for masked_select by broadcasting the mask to the input
    shape, flattening data and mask, and gathering the elements at the
    nonzero mask indices. Because the number of selected elements is
    data-dependent, the importer also maps torch.export's symbolic shape
    constraint nodes to no-ops.
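    
    A minimal PyTorch illustration of the semantics being added (example
    values chosen here for clarity, not taken from the commit):
    
        import torch
    
        data = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
        mask = torch.tensor([[True, False, True], [False, True, False]])
        # masked_select broadcasts the mask against the input, then
        # returns the selected elements in row-major order as a 1-D tensor
        out = torch.masked_select(data, mask)  # tensor([1., 3., 5.])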
---
 .../frontend/torch/base_fx_graph_translator.py     | 21 ++++++++++++
 .../frontend/torch/exported_program_translator.py  | 11 +++++++
 python/tvm/script/ir_builder/relax/ir.py           |  2 ++
 .../relax/test_frontend_from_exported_program.py   | 37 ++++++++++++++++++++++
 4 files changed, 71 insertions(+)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 7ebb95c136..471d4209d7 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -23,6 +23,7 @@ from functools import reduce
 import math
 from typing import Callable, Dict, Optional, Tuple, Union, List
 
+import tvm
 from tvm import relax, tir
 
 
@@ -2385,6 +2386,26 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
 
         return self.block_builder.emit(relax.op.where(mask, values, x))
 
+    def _masked_select(self, node: fx.Node) -> relax.Var:
+        data = self.env[node.args[0]]
+        mask = self.env[node.args[1]]
+
+        data_shape = self.shape_of(data)
+        mask_shape = self.shape_of(mask)
+        shapes_equal = tvm.ir.structural_equal(data_shape, mask_shape)
+
+        if not shapes_equal:
+            mask = self.block_builder.emit(relax.op.broadcast_to(mask, data_shape))
+
+        data_flat = self.block_builder.emit(relax.op.reshape(data, [-1]))
+        mask_flat = self.block_builder.emit(relax.op.reshape(mask, [-1]))
+        indices = self.block_builder.emit(relax.op.nonzero(mask_flat))
+        indices_1d = self.block_builder.emit(relax.op.squeeze(indices, axis=[0]))
+
+        result = self.block_builder.emit(relax.op.take(data_flat, indices_1d, axis=0))
+
+        return result
+
     def _new_ones(self, node: fx.Node) -> relax.Var:
         args = self.retrieve_args(node)
         self_var = args[0]
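
The lowering in _masked_select above can be mirrored in plain NumPy; the
sketch below is illustrative reference code (not TVM API) showing why the
nonzero result is squeezed along axis 0:

    import numpy as np

    def masked_select_ref(data, mask):
        # mirrors relax.op.broadcast_to when the shapes differ
        mask = np.broadcast_to(mask, data.shape)
        data_flat = data.reshape(-1)
        mask_flat = mask.reshape(-1)
        # like relax.op.nonzero, this yields indices of shape (ndim, count);
        # for the flattened 1-D mask that is (1, count), hence the squeeze
        indices = np.stack(np.nonzero(mask_flat))
        indices_1d = indices.squeeze(axis=0)
        return np.take(data_flat, indices_1d, axis=0)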
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 641e16f599..3e2274e551 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -1153,6 +1153,11 @@ class ExportedProgramImporter(BaseFXGraphImporter):
 
         return self.block_builder.emit(relax.op.reshape(x, size))
 
+    ########## Symbolic Shape Constraints ##########
+
+    def _symbolic_comparison(self, _: fx.Node) -> relax.Expr:
+        return self.block_builder.emit(relax.const(True, dtype="bool"))
+
     ########## Others ##########
 
     def create_convert_map(
@@ -1457,6 +1462,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             "linspace.default": self._linspace,
             "masked_fill.Scalar": self._masked_fill,
             "masked_fill_.Scalar": self._inplace_masked_fill,
+            "masked_select.default": self._masked_select,
             "new_ones.default": self._new_ones,
             "new_zeros.default": self._new_zeros,
             "one_hot.default": self._one_hot,
@@ -1477,6 +1483,11 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             "item.default": self._item,
             "sym_size.int": self._sym_size_int,
             "_local_scalar_dense.default": self._item,
+            # symbolic shape constraints (no-ops for compilation)
+            "sym_constrain_range_for_size.default": lambda node: 
self.env[node.args[0]],
+            "_assert_scalar.default": lambda node: self.env[node.args[0]],
+            "ge": self._symbolic_comparison,
+            "le": self._symbolic_comparison,
         }
 
     def _process_derived_symbol(
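
These no-op mappings matter for masked_select because its output length is
data-dependent: torch.export inserts runtime-assert nodes for such sizes,
while the Relax function carries the size bounds in its tir_var_lower_bound
and tir_var_upper_bound attributes instead (see the test below). A minimal
sketch of how such nodes arise (exact graph contents vary by PyTorch
version):

    import torch

    class M(torch.nn.Module):
        def forward(self, x, m):
            return torch.masked_select(x, m)

    ep = torch.export.export(M(), (torch.randn(2, 3), torch.rand(2, 3) > 0.5))
    # the printed graph contains nodes such as
    # sym_constrain_range_for_size and _assert_scalar, which the
    # importer above forwards or folds to constants
    print(ep.graph)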
diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py
index f221a13089..141361a729 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -137,6 +137,7 @@ from tvm.relax.op import (
     multiply,
     negative,
     nn,
+    nonzero,
     not_equal,
     null_value,
     ones,
@@ -882,6 +883,7 @@ __all__ = [
     "multinomial_from_uniform",
     "multiply",
     "negative",
+    "nonzero",
     "not_equal",
     "null_value",
     "ones",
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 68567e1fc8..74ad2329fe 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -6231,6 +6231,43 @@ def test_masked_fill_inplace():
     verify_model(Masked_Fill_Inplace(), example_args, {}, Expected)
 
 
+def test_masked_select():
+    class MaskedSelect(Module):
+        def forward(self, data: torch.Tensor, mask: torch.Tensor):
+            return torch.masked_select(data, mask)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            data: R.Tensor((2, 3), dtype="float32"), mask: R.Tensor((2, 3), 
dtype="bool")
+        ) -> R.Tuple(R.Tensor(dtype="float32", ndim=1)):
+            R.func_attr(
+                {
+                    "tir_var_lower_bound": {"u0": 0, "u1": 0},
+                    "tir_var_upper_bound": {"u0": 6, "u1": 6},
+                }
+            )
+            with R.dataflow():
+                lv: R.Tensor((6,), dtype="float32") = R.reshape(data, 
R.shape([6]))
+                lv1: R.Tensor((6,), dtype="bool") = R.reshape(mask, 
R.shape([6]))
+                lv2: R.Tensor(dtype="int64", ndim=2) = R.nonzero(lv1)
+                lv3: R.Tensor(dtype="int64", ndim=1) = R.squeeze(lv2, axis=[0])
+                lv4: R.Tensor(dtype="float32", ndim=1) = R.take(lv, lv3, 
axis=0, mode="fast")
+                lv5: R.Tensor((), dtype="int64") = R.const(0, "int64")
+                lv6: R.Tensor((), dtype="bool") = R.const(True, "bool")
+                lv7: R.Tensor((), dtype="bool") = R.const(True, "bool")
+                gv: R.Tuple(R.Tensor(dtype="float32", ndim=1)) = (lv4,)
+                R.output(gv)
+            return gv
+
+    example_args = (
+        torch.randn(2, 3, dtype=torch.float32),
+        torch.tensor([[True, False, True], [False, True, False]]),
+    )
+    verify_model(MaskedSelect(), example_args, {}, Expected)
+
+
 def test_new_ones():
     class NewOnes(Module):
         def forward(self, x):
