This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new f4c4e78eb9 [Relax] support masked_scatter (#17525)
f4c4e78eb9 is described below

commit f4c4e78eb9a6c58d6c3388c79211bfd24501198b
Author: Archermmt <[email protected]>
AuthorDate: Sun Nov 17 12:58:33 2024 +0800

    [Relax] support masked_scatter (#17525)
    
    * support masked_scatter
    
    * remove logging
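    
    A minimal NumPy sketch of the decomposition the new converter emits
    (1-D case; example values are hypothetical, not taken from the tests):
    
        import numpy as np
    
        data = np.array([1.0, 2.0, 3.0, 4.0, 5.0], dtype=np.float32)
        mask = np.array([True, False, True, False, True])
        src = np.arange(10, dtype=np.float32)
    
        # A running count of True entries maps the k-th True position to
        # source index k; stale indices at False positions are discarded
        # by the final where.
        index = np.cumsum(mask.astype(np.int32)) - 1
        gathered = np.take(src, index)
        out = np.where(mask, gathered, data)
        # out -> [0., 2., 1., 4., 2.]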
---
 python/tvm/relax/frontend/torch/fx_translator.py   | 26 +++++++
 src/contrib/msc/core/ir/graph_builder.cc           |  4 +-
 src/contrib/msc/core/transform/set_expr_layout.cc  | 13 +++-
 src/contrib/msc/framework/torch/torch_opcode.cc    | 20 +++++
 tests/python/contrib/test_msc/test_graph_build.py  | 85 ++++++++++++++++++++++
 .../contrib/test_msc/test_translate_relax.py       | 23 ++++++
 .../contrib/test_msc/test_translate_torch.py       | 23 ++++++
 tests/python/relax/test_frontend_from_fx.py        | 61 ++++++++++++++++
 8 files changed, 251 insertions(+), 4 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 746010a4dc..52122ce333 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -472,6 +472,31 @@ class TorchFXImporter(BaseFXGraphImporter):
         values = self.block_builder.emit(relax.op.full_like(x, rx_value))
         return self.block_builder.emit(relax.op.where(mask, values, x))
 
+    def _masked_scatter(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        mask = self.env[node.args[1]]
+        source = self.env[node.args[2]]
+        ndim = len(mask.struct_info.shape)
+        if ndim == 1:
+            index = self.block_builder.emit(relax.op.cumsum(mask, 0, dtype="int32"))
+            index = self.block_builder.emit(relax.op.subtract(index, relax.const(1, "int32")))
+            gathered_source = self.block_builder.emit(relax.op.take(source, index, axis=0))
+        else:
+            f_mask = self.block_builder.emit(relax.op.reshape(mask, [-1]))
+            index = self.block_builder.emit(relax.op.cumsum(f_mask, 0, dtype="int32"))
+            index = self.block_builder.emit(relax.op.subtract(index, relax.const(1, "int32")))
+            source_shape = [-1] + [
+                s for idx, s in enumerate(source.struct_info.shape) if idx >= ndim
+            ]
+            f_source = self.block_builder.emit(relax.op.reshape(source, source_shape))
+            gathered_source = self.block_builder.emit(relax.op.take(f_source, index, axis=0))
+            gathered_source = self.block_builder.emit(
+                relax.op.reshape(gathered_source, x.struct_info.shape)
+            )
+        if ndim != len(x.struct_info.shape):
+            mask = self.block_builder.emit(relax.op.broadcast_to(mask, x.struct_info.shape))
+        return self.block_builder.emit(relax.op.where(mask, gathered_source, x))
+
     def _ones(self, node: fx.Node) -> relax.Var:
         import torch
 
@@ -695,6 +720,7 @@ class TorchFXImporter(BaseFXGraphImporter):
             "index_select": self._index_select,
             "masked_fill_": self._inplace_masked_fill,
             "masked_fill": self._masked_fill,
+            "masked_scatter": self._masked_scatter,
             "new_ones": self._new_ones,
             "ones": self._ones,
             "tensor": self._tensor,
diff --git a/src/contrib/msc/core/ir/graph_builder.cc b/src/contrib/msc/core/ir/graph_builder.cc
index abb7dfbd5e..27115cb130 100644
--- a/src/contrib/msc/core/ir/graph_builder.cc
+++ b/src/contrib/msc/core/ir/graph_builder.cc
@@ -704,7 +704,9 @@ const MSCPrim RelaxGraphBuilder::MatchOrCreatePrim(const PrimExpr& prim, const S
 }
 
 void RelaxGraphBuilder::VisitExpr_(const relax::ConstantNode* op) {
-  AddNode(GetRef<relax::Constant>(op));
+  if (!expr_tensor_map_.count(GetRef<relax::Constant>(op))) {
+    AddNode(GetRef<relax::Constant>(op));
+  }
 }
 
 void RelaxGraphBuilder::VisitBinding_(const relax::VarBindingNode* binding,
diff --git a/src/contrib/msc/core/transform/set_expr_layout.cc b/src/contrib/msc/core/transform/set_expr_layout.cc
index a3902a44bf..f3504d7723 100644
--- a/src/contrib/msc/core/transform/set_expr_layout.cc
+++ b/src/contrib/msc/core/transform/set_expr_layout.cc
@@ -492,9 +492,16 @@ InferLayoutOutput ForwardInferLayoutTake(const Call& call,
     return InferLayoutOutput({input_layout, indices_layout}, {output_layout}, Attrs());
   }
   if (indices_layout->layout.defined()) {
-    size_t indices_size = indices_layout->layout.ndim();
-    LayoutDecision output_layout =
-        LayoutUtils::ExpandLayout(indices_layout, std::vector<size_t>{indices_size});
+    std::vector<size_t> expand_axes;
+    for (size_t i = indices_layout->layout.ndim(); i < output_shape.size(); i++) {
+      expand_axes.push_back(i);
+    }
+    LayoutDecision output_layout;
+    if (expand_axes.size() == 0) {
+      output_layout = indices_layout;
+    } else {
+      output_layout = LayoutUtils::ExpandLayout(indices_layout, expand_axes);
+    }
     return InferLayoutOutput({input_layout, indices_layout}, {output_layout}, Attrs());
   }
   return InferLayoutOutput();
diff --git a/src/contrib/msc/framework/torch/torch_opcode.cc b/src/contrib/msc/framework/torch/torch_opcode.cc
index abac3682fb..f5784efe3d 100644
--- a/src/contrib/msc/framework/torch/torch_opcode.cc
+++ b/src/contrib/msc/framework/torch/torch_opcode.cc
@@ -224,6 +224,12 @@ class TorchConstantCodeGen : public TorchOpCode {
       } else if (dtype == "float32") {
         stack_.assign(module_ref(), node()->GetTypeAttr<float>("scalar"));
       }
+    } else if (dtype == "bool") {
+      stack_.func_call("register_buffer", "", "self")
+          .call_arg(DocUtils::ToStr(ref_name))
+          .inplace_start("torch.BoolTensor")
+          .call_arg(DocUtils::ToDocList(node()->OutputAt(0)->shape))
+          .inplace_end();
     } else if (dtype == "int32") {
       stack_.func_call("register_buffer", "", "self")
           .call_arg(DocUtils::ToStr(ref_name))
@@ -658,6 +664,18 @@ class TorchStridedSliceCodeGen : public TorchOpCode {
   }
 };
 
+class TorchTakeCodeGen : public TorchOpCode {
+  TORCH_OP_CODEGEN_METHODS(TorchTakeCodeGen)
+
+ protected:
+  void CodeGenForward() final {
+    if (node()->InputAt(1)->DTypeName() == "int32") {
+      stack_.func_call("to", IdxInput(1), IdxInput(1)).call_arg("torch.int64");
+    }
+    stack_.assign(IdxNode(), DocUtils::ToIndex(IdxInput(0), IdxInput(1)));
+  }
+};
+
 class TorchTriCodeGen : public TorchOpCode {
   TORCH_OP_CODEGEN_METHODS(TorchTriCodeGen)
 
@@ -738,6 +756,7 @@ const std::shared_ptr<std::unordered_map<String, std::shared_ptr<TorchOpCode>>>
   map->emplace("subtract", std::make_shared<TorchSimpleCodeGen>("", 
"torch.subtract"));
   map->emplace("tan", std::make_shared<TorchSimpleCodeGen>("", "torch.tan"));
   map->emplace("tanh", std::make_shared<TorchSimpleCodeGen>("", "torch.tanh"));
+  map->emplace("where", std::make_shared<TorchSimpleCodeGen>("", 
"torch.where"));
 
   // reduce ops
   map->emplace("max", std::make_shared<TorchReduceAxesCodeGen>("", 
"torch.max"));
@@ -771,6 +790,7 @@ const std::shared_ptr<std::unordered_map<String, std::shared_ptr<TorchOpCode>>>
   map->emplace("scatter_nd", std::make_shared<TorchScatterNDCodeGen>("", ""));
   map->emplace("split", std::make_shared<TorchSplitCodeGen>("", 
"torch.split"));
   map->emplace("strided_slice", std::make_shared<TorchStridedSliceCodeGen>("", 
""));
+  map->emplace("take", std::make_shared<TorchTakeCodeGen>("", ""));
 
   // create ops
   map->emplace("constant", 
std::make_shared<TorchConstantCodeGen>("nn.Parameter", ""));
diff --git a/tests/python/contrib/test_msc/test_graph_build.py b/tests/python/contrib/test_msc/test_graph_build.py
index 647879378e..3b514ad6d8 100644
--- a/tests/python/contrib/test_msc/test_graph_build.py
+++ b/tests/python/contrib/test_msc/test_graph_build.py
@@ -2472,6 +2472,91 @@ def test_scatter(dynamic):
     )
 
 
+@pytest.mark.parametrize("dynamic", [True, False])
+def test_masked_scatter(dynamic):
+    """test graph builder for masked_scatter"""
+
+    dim = "dim" if dynamic else 5
+
+    class MaskedScatter1(Module):
+        def forward(self, data, mask, src):
+            return data.masked_scatter(mask, src)
+
+    class MaskedScatter2(Module):
+        def forward(self, data, mask, src):
+            return data.masked_scatter(mask, src)
+
+    expected1 = {
+        "inputs": [
+            {"name": "inp_0", "shape": [dim], "dtype": "float32", "layout": 
"A"},
+            {"name": "inp_1", "shape": [dim], "dtype": "bool", "layout": "A"},
+            {"name": "inp_2", "shape": [10], "dtype": "float32", "layout": 
"A"},
+        ],
+        "outputs": [{"name": "where", "shape": [dim], "dtype": "float32", 
"layout": "A"}],
+        "nodes": {
+            "total": 8,
+            "input": 3,
+            "cumsum": 1,
+            "constant": 1,
+            "subtract": 1,
+            "take": 1,
+            "where": 1,
+        },
+    }
+    expected2 = {
+        "inputs": [
+            {
+                "name": "inp_0",
+                "shape": [2, dim],
+                "dtype": "float32",
+                "layout": "" if dynamic else "BA",
+            },
+            {
+                "name": "inp_1",
+                "shape": [2, dim],
+                "dtype": "bool",
+                "layout": "" if dynamic else "BA",
+            },
+            {
+                "name": "inp_2",
+                "shape": [3, dim],
+                "dtype": "float32",
+                "layout": "" if dynamic else "BA",
+            },
+        ],
+        "outputs": [
+            {
+                "name": "where",
+                "shape": [2, dim],
+                "dtype": "float32",
+                "layout": "" if dynamic else "BA",
+            }
+        ],
+        "nodes": {
+            "total": 11,
+            "input": 3,
+            "reshape": 3,
+            "cumsum": 1,
+            "constant": 1,
+            "subtract": 1,
+            "take": 1,
+            "where": 1,
+        },
+    }
+    if dynamic:
+        expected1["prims"] = {"total": 1, "shape": 1}
+        expected2["prims"] = {"total": 5, "shape": 1, "Int": 2, "Mul": 2}
+
+    verify_model(
+        MaskedScatter1(), [([dim], "float32"), ([dim], "bool"), ([10], "float32")], expected1
+    )
+    verify_model(
+        MaskedScatter2(),
+        [([2, dim], "float32"), ([2, dim], "bool"), ([3, dim], "float32")],
+        expected2,
+    )
+
+
 def test_put():
     """test graph builder for index_put"""
 
diff --git a/tests/python/contrib/test_msc/test_translate_relax.py b/tests/python/contrib/test_msc/test_translate_relax.py
index 27a02844e1..d8f746d688 100644
--- a/tests/python/contrib/test_msc/test_translate_relax.py
+++ b/tests/python/contrib/test_msc/test_translate_relax.py
@@ -1193,6 +1193,29 @@ def test_scatter():
     verify_model(Scatter2(), [([20, 20], "float32"), ([2, 5], "int64"), ([2, 5], "float32")])
 
 
+def test_masked_scatter():
+    """test relax translator for masked_scatter"""
+
+    class MaskedScatter1(Module):
+        def __init__(self):
+            super().__init__()
+            self.mask = msc_utils.random_data([(5,), "bool"], MSCFramework.TORCH)
+
+        def forward(self, data, src):
+            return data.masked_scatter(self.mask, src)
+
+    class MaskedScatter2(Module):
+        def __init__(self):
+            super().__init__()
+            self.mask = msc_utils.random_data([(2, 5), "bool"], MSCFramework.TORCH)
+
+        def forward(self, data, src):
+            return data.masked_scatter(self.mask, src)
+
+    verify_model(MaskedScatter1(), [([5], "float32"), ([10], "float32")])
+    verify_model(MaskedScatter2(), [([2, 5], "float32"), ([3, 5], "float32")])
+
+
 def test_put():
     """test relax translator for index_put"""
 
diff --git a/tests/python/contrib/test_msc/test_translate_torch.py b/tests/python/contrib/test_msc/test_translate_torch.py
index 6ed28c0ac0..6535ef66c8 100644
--- a/tests/python/contrib/test_msc/test_translate_torch.py
+++ b/tests/python/contrib/test_msc/test_translate_torch.py
@@ -1173,6 +1173,29 @@ def test_scatter():
         )
 
 
+def test_masked_scatter():
+    """test torch translator for masked_scatter"""
+
+    class MaskedScatter1(Module):
+        def __init__(self):
+            super().__init__()
+            self.mask = msc_utils.random_data([(5,), "bool"], MSCFramework.TORCH)
+
+        def forward(self, data, src):
+            return data.masked_scatter(self.mask, src)
+
+    class MaskedScatter2(Module):
+        def __init__(self):
+            super().__init__()
+            self.mask = msc_utils.random_data([(2, 5), "bool"], MSCFramework.TORCH)
+
+        def forward(self, data, src):
+            return data.masked_scatter(self.mask, src)
+
+    verify_model(MaskedScatter1(), [([5], "float32"), ([10], "float32")], True)
+    verify_model(MaskedScatter2(), [([2, 5], "float32"), ([3, 5], "float32")], True)
+
+
 def test_put():
     """test torch translator for index_put"""
 
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index 08331f0861..d9857723b1 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -4023,5 +4023,66 @@ def test_scatter():
     verify_model(Scatter(), input_info, {}, expected)
 
 
+def test_masked_scatter():
+    class MaskedScatter1(Module):
+        def forward(self, data, mask, src):
+            return data.masked_scatter(mask, src)
+
+    class MaskedScatter2(Module):
+        def forward(self, data, mask, src):
+            return data.masked_scatter(mask, src)
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            inp_0: R.Tensor((5,), dtype="float32"),
+            inp_1: R.Tensor((5,), dtype="bool"),
+            inp_2: R.Tensor((10,), dtype="float32"),
+        ) -> R.Tensor((5,), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((5,), dtype="int32") = R.cumsum(
+                    inp_1, axis=0, dtype="int32", exclusive=False
+                )
+                lv1: R.Tensor((5,), dtype="int32") = R.subtract(lv, R.const(1, "int32"))
+                lv2: R.Tensor((5,), dtype="float32") = R.take(inp_2, lv1, axis=0)
+                lv3: R.Tensor((5,), dtype="float32") = R.where(inp_1, lv2, inp_0)
+                gv: R.Tensor((5,), dtype="float32") = lv3
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class expected2:
+        @R.function
+        def main(
+            inp_0: R.Tensor((2, 5), dtype="float32"),
+            inp_1: R.Tensor((2, 5), dtype="bool"),
+            inp_2: R.Tensor((3, 5), dtype="float32"),
+        ) -> R.Tensor((2, 5), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((10,), dtype="bool") = R.reshape(inp_1, R.shape([10]))
+                lv1: R.Tensor((10,), dtype="int32") = R.cumsum(
+                    lv, axis=0, dtype="int32", exclusive=False
+                )
+                lv2: R.Tensor((10,), dtype="int32") = R.subtract(lv1, R.const(1, "int32"))
+                lv3: R.Tensor((15,), dtype="float32") = R.reshape(inp_2, R.shape([15]))
+                lv4: R.Tensor((10,), dtype="float32") = R.take(lv3, lv2, axis=0)
+                lv5: R.Tensor((2, 5), dtype="float32") = R.reshape(lv4, R.shape([2, 5]))
+                lv6: R.Tensor((2, 5), dtype="float32") = R.where(inp_1, lv5, inp_0)
+                gv: R.Tensor((2, 5), dtype="float32") = lv6
+                R.output(gv)
+            return gv
+
+    verify_model(
+        MaskedScatter1(), [([5], "float32"), ([5], "bool"), ([10], "float32")], {}, expected1
+    )
+    verify_model(
+        MaskedScatter2(),
+        [([2, 5], "float32"), ([2, 5], "bool"), ([3, 5], "float32")],
+        {},
+        expected2,
+    )
+
+
 if __name__ == "__main__":
     tvm.testing.main()

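For readers unfamiliar with the operator: masked_scatter fills the True
positions of a boolean mask with consecutive elements of a source tensor,
read in row-major order, and keeps the input elsewhere. A minimal PyTorch
example of the semantics the new tests exercise (values are illustrative):

    import torch

    data = torch.zeros(2, 5)
    mask = torch.tensor([[True, False, True, False, True],
                         [False, True, False, True, False]])
    src = torch.arange(15, dtype=torch.float32).reshape(3, 5)

    # src is read flat; its first five elements fill the five True cells.
    out = data.masked_scatter(mask, src)
    # tensor([[0., 0., 1., 0., 2.],
    #         [0., 3., 0., 4., 0.]])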