This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new f4c4e78eb9 [Relax] support masked_scatter (#17525)
f4c4e78eb9 is described below
commit f4c4e78eb9a6c58d6c3388c79211bfd24501198b
Author: Archermmt <[email protected]>
AuthorDate: Sun Nov 17 12:58:33 2024 +0800
[Relax] support masked_scatter (#17525)
* support masked_scatter
* remove logging
---
 python/tvm/relax/frontend/torch/fx_translator.py  | 26 +++++++
 src/contrib/msc/core/ir/graph_builder.cc          |  4 +-
 src/contrib/msc/core/transform/set_expr_layout.cc | 13 +++-
 src/contrib/msc/framework/torch/torch_opcode.cc   | 20 +++++
 tests/python/contrib/test_msc/test_graph_build.py | 85 ++++++++++++++++++++++
 .../contrib/test_msc/test_translate_relax.py      | 23 ++++++
 .../contrib/test_msc/test_translate_torch.py      | 23 ++++++
 tests/python/relax/test_frontend_from_fx.py       | 61 ++++++++++++++++
 8 files changed, 251 insertions(+), 4 deletions(-)
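
For context: torch.Tensor.masked_scatter copies elements of the source tensor, in order, into the positions of the input where the mask is True. A minimal PyTorch illustration (not part of this commit):

import torch

data = torch.zeros(5)
mask = torch.tensor([True, False, True, False, True])
src = torch.arange(10, dtype=torch.float32)

# True positions 0, 2, 4 receive src[0], src[1], src[2], in order
out = data.masked_scatter(mask, src)
assert torch.equal(out, torch.tensor([0.0, 0.0, 1.0, 0.0, 2.0]))
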
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 746010a4dc..52122ce333 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -472,6 +472,31 @@ class TorchFXImporter(BaseFXGraphImporter):
         values = self.block_builder.emit(relax.op.full_like(x, rx_value))
         return self.block_builder.emit(relax.op.where(mask, values, x))
 
+    def _masked_scatter(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        mask = self.env[node.args[1]]
+        source = self.env[node.args[2]]
+        ndim = len(mask.struct_info.shape)
+        if ndim == 1:
+            index = self.block_builder.emit(relax.op.cumsum(mask, 0, dtype="int32"))
+            index = self.block_builder.emit(relax.op.subtract(index, relax.const(1, "int32")))
+            gathered_source = self.block_builder.emit(relax.op.take(source, index, axis=0))
+        else:
+            f_mask = self.block_builder.emit(relax.op.reshape(mask, [-1]))
+            index = self.block_builder.emit(relax.op.cumsum(f_mask, 0, dtype="int32"))
+            index = self.block_builder.emit(relax.op.subtract(index, relax.const(1, "int32")))
+            source_shape = [-1] + [
+                s for idx, s in enumerate(source.struct_info.shape) if idx >= ndim
+            ]
+            f_source = self.block_builder.emit(relax.op.reshape(source, source_shape))
+            gathered_source = self.block_builder.emit(relax.op.take(f_source, index, axis=0))
+            gathered_source = self.block_builder.emit(
+                relax.op.reshape(gathered_source, x.struct_info.shape)
+            )
+        if ndim != len(x.struct_info.shape):
+            mask = self.block_builder.emit(relax.op.broadcast_to(mask, x.struct_info.shape))
+        return self.block_builder.emit(relax.op.where(mask, gathered_source, x))
+
     def _ones(self, node: fx.Node) -> relax.Var:
         import torch
@@ -695,6 +720,7 @@ class TorchFXImporter(BaseFXGraphImporter):
"index_select": self._index_select,
"masked_fill_": self._inplace_masked_fill,
"masked_fill": self._masked_fill,
+ "masked_scatter": self._masked_scatter,
"new_ones": self._new_ones,
"ones": self._ones,
"tensor": self._tensor,
diff --git a/src/contrib/msc/core/ir/graph_builder.cc b/src/contrib/msc/core/ir/graph_builder.cc
index abb7dfbd5e..27115cb130 100644
--- a/src/contrib/msc/core/ir/graph_builder.cc
+++ b/src/contrib/msc/core/ir/graph_builder.cc
@@ -704,7 +704,9 @@ const MSCPrim RelaxGraphBuilder::MatchOrCreatePrim(const PrimExpr& prim, const S
 }
 void RelaxGraphBuilder::VisitExpr_(const relax::ConstantNode* op) {
-  AddNode(GetRef<relax::Constant>(op));
+  if (!expr_tensor_map_.count(GetRef<relax::Constant>(op))) {
+    AddNode(GetRef<relax::Constant>(op));
+  }
 }
void RelaxGraphBuilder::VisitBinding_(const relax::VarBindingNode* binding,
diff --git a/src/contrib/msc/core/transform/set_expr_layout.cc b/src/contrib/msc/core/transform/set_expr_layout.cc
index a3902a44bf..f3504d7723 100644
--- a/src/contrib/msc/core/transform/set_expr_layout.cc
+++ b/src/contrib/msc/core/transform/set_expr_layout.cc
@@ -492,9 +492,16 @@ InferLayoutOutput ForwardInferLayoutTake(const Call& call,
     return InferLayoutOutput({input_layout, indices_layout}, {output_layout}, Attrs());
   }
   if (indices_layout->layout.defined()) {
-    size_t indices_size = indices_layout->layout.ndim();
-    LayoutDecision output_layout =
-        LayoutUtils::ExpandLayout(indices_layout, std::vector<size_t>{indices_size});
+    std::vector<size_t> expand_axes;
+    for (size_t i = indices_layout->layout.ndim(); i < output_shape.size(); i++) {
+      expand_axes.push_back(i);
+    }
+    LayoutDecision output_layout;
+    if (expand_axes.size() == 0) {
+      output_layout = indices_layout;
+    } else {
+      output_layout = LayoutUtils::ExpandLayout(indices_layout, expand_axes);
+    }
     return InferLayoutOutput({input_layout, indices_layout}, {output_layout}, Attrs());
   }
   return InferLayoutOutput();
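
The change keeps the indices' layout as-is when take's output has no extra trailing dims (the masked_scatter path, where indices and output are both 1-D), and only expands over the axes the gather appends. A rough model of the axis computation in Python (helper name is hypothetical):

def take_output_expand_axes(indices_ndim, output_ndim):
    # axes present in the output but not covered by the indices layout
    return list(range(indices_ndim, output_ndim))

assert take_output_expand_axes(1, 1) == []   # ranks match; reuse indices layout
assert take_output_expand_axes(1, 2) == [1]  # one trailing axis to expand
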
diff --git a/src/contrib/msc/framework/torch/torch_opcode.cc b/src/contrib/msc/framework/torch/torch_opcode.cc
index abac3682fb..f5784efe3d 100644
--- a/src/contrib/msc/framework/torch/torch_opcode.cc
+++ b/src/contrib/msc/framework/torch/torch_opcode.cc
@@ -224,6 +224,12 @@ class TorchConstantCodeGen : public TorchOpCode {
       } else if (dtype == "float32") {
         stack_.assign(module_ref(), node()->GetTypeAttr<float>("scalar"));
       }
+    } else if (dtype == "bool") {
+      stack_.func_call("register_buffer", "", "self")
+          .call_arg(DocUtils::ToStr(ref_name))
+          .inplace_start("torch.BoolTensor")
+          .call_arg(DocUtils::ToDocList(node()->OutputAt(0)->shape))
+          .inplace_end();
     } else if (dtype == "int32") {
       stack_.func_call("register_buffer", "", "self")
           .call_arg(DocUtils::ToStr(ref_name))
@@ -658,6 +664,18 @@ class TorchStridedSliceCodeGen : public TorchOpCode {
   }
 };
 
+class TorchTakeCodeGen : public TorchOpCode {
+  TORCH_OP_CODEGEN_METHODS(TorchTakeCodeGen)
+
+ protected:
+  void CodeGenForward() final {
+    if (node()->InputAt(1)->DTypeName() == "int32") {
+      stack_.func_call("to", IdxInput(1), IdxInput(1)).call_arg("torch.int64");
+    }
+    stack_.assign(IdxNode(), DocUtils::ToIndex(IdxInput(0), IdxInput(1)));
+  }
+};
+
class TorchTriCodeGen : public TorchOpCode {
TORCH_OP_CODEGEN_METHODS(TorchTriCodeGen)
@@ -738,6 +756,7 @@ const std::shared_ptr<std::unordered_map<String, std::shared_ptr<TorchOpCode>>>
   map->emplace("subtract", std::make_shared<TorchSimpleCodeGen>("", "torch.subtract"));
   map->emplace("tan", std::make_shared<TorchSimpleCodeGen>("", "torch.tan"));
   map->emplace("tanh", std::make_shared<TorchSimpleCodeGen>("", "torch.tanh"));
+  map->emplace("where", std::make_shared<TorchSimpleCodeGen>("", "torch.where"));
   // reduce ops
   map->emplace("max", std::make_shared<TorchReduceAxesCodeGen>("", "torch.max"));
@@ -771,6 +790,7 @@ const std::shared_ptr<std::unordered_map<String, std::shared_ptr<TorchOpCode>>>
   map->emplace("scatter_nd", std::make_shared<TorchScatterNDCodeGen>("", ""));
   map->emplace("split", std::make_shared<TorchSplitCodeGen>("", "torch.split"));
   map->emplace("strided_slice", std::make_shared<TorchStridedSliceCodeGen>("", ""));
+  map->emplace("take", std::make_shared<TorchTakeCodeGen>("", ""));
   // create ops
   map->emplace("constant", std::make_shared<TorchConstantCodeGen>("nn.Parameter", ""));
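
For the new take entry, the generated forward code should amount to integer-array indexing with int64 indices, roughly as below (variable names are illustrative, not taken from the codegen):

import torch

inp_0 = torch.arange(10, dtype=torch.float32)        # data
inp_1 = torch.tensor([0, 2, 4], dtype=torch.int32)   # indices

inp_1 = inp_1.to(torch.int64)  # cast emitted only when indices are int32
out = inp_0[inp_1]
assert torch.equal(out, torch.tensor([0.0, 2.0, 4.0]))
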
diff --git a/tests/python/contrib/test_msc/test_graph_build.py b/tests/python/contrib/test_msc/test_graph_build.py
index 647879378e..3b514ad6d8 100644
--- a/tests/python/contrib/test_msc/test_graph_build.py
+++ b/tests/python/contrib/test_msc/test_graph_build.py
@@ -2472,6 +2472,91 @@ def test_scatter(dynamic):
     )
 
 
+@pytest.mark.parametrize("dynamic", [True, False])
+def test_masked_scatter(dynamic):
+    """test graph builder for masked_scatter"""
+
+    dim = "dim" if dynamic else 5
+
+    class MaskedScatter1(Module):
+        def forward(self, data, mask, src):
+            return data.masked_scatter(mask, src)
+
+    class MaskedScatter2(Module):
+        def forward(self, data, mask, src):
+            return data.masked_scatter(mask, src)
+
+    expected1 = {
+        "inputs": [
+            {"name": "inp_0", "shape": [dim], "dtype": "float32", "layout": "A"},
+            {"name": "inp_1", "shape": [dim], "dtype": "bool", "layout": "A"},
+            {"name": "inp_2", "shape": [10], "dtype": "float32", "layout": "A"},
+        ],
+        "outputs": [{"name": "where", "shape": [dim], "dtype": "float32", "layout": "A"}],
+        "nodes": {
+            "total": 8,
+            "input": 3,
+            "cumsum": 1,
+            "constant": 1,
+            "subtract": 1,
+            "take": 1,
+            "where": 1,
+        },
+    }
+    expected2 = {
+        "inputs": [
+            {
+                "name": "inp_0",
+                "shape": [2, dim],
+                "dtype": "float32",
+                "layout": "" if dynamic else "BA",
+            },
+            {
+                "name": "inp_1",
+                "shape": [2, dim],
+                "dtype": "bool",
+                "layout": "" if dynamic else "BA",
+            },
+            {
+                "name": "inp_2",
+                "shape": [3, dim],
+                "dtype": "float32",
+                "layout": "" if dynamic else "BA",
+            },
+        ],
+        "outputs": [
+            {
+                "name": "where",
+                "shape": [2, dim],
+                "dtype": "float32",
+                "layout": "" if dynamic else "BA",
+            }
+        ],
+        "nodes": {
+            "total": 11,
+            "input": 3,
+            "reshape": 3,
+            "cumsum": 1,
+            "constant": 1,
+            "subtract": 1,
+            "take": 1,
+            "where": 1,
+        },
+    }
+    if dynamic:
+        expected1["prims"] = {"total": 1, "shape": 1}
+        expected2["prims"] = {"total": 5, "shape": 1, "Int": 2, "Mul": 2}
+
+    verify_model(
+        MaskedScatter1(), [([dim], "float32"), ([dim], "bool"), ([10], "float32")], expected1
+    )
+    verify_model(
+        MaskedScatter2(),
+        [([2, dim], "float32"), ([2, dim], "bool"), ([3, dim], "float32")],
+        expected2,
+    )
+
+
 def test_put():
     """test graph builder for index_put"""
diff --git a/tests/python/contrib/test_msc/test_translate_relax.py b/tests/python/contrib/test_msc/test_translate_relax.py
index 27a02844e1..d8f746d688 100644
--- a/tests/python/contrib/test_msc/test_translate_relax.py
+++ b/tests/python/contrib/test_msc/test_translate_relax.py
@@ -1193,6 +1193,29 @@ def test_scatter():
     verify_model(Scatter2(), [([20, 20], "float32"), ([2, 5], "int64"), ([2, 5], "float32")])
 
 
+def test_masked_scatter():
+    """test relax translator for masked_scatter"""
+
+    class MaskedScatter1(Module):
+        def __init__(self):
+            super().__init__()
+            self.mask = msc_utils.random_data([(5,), "bool"], MSCFramework.TORCH)
+
+        def forward(self, data, src):
+            return data.masked_scatter(self.mask, src)
+
+    class MaskedScatter2(Module):
+        def __init__(self):
+            super().__init__()
+            self.mask = msc_utils.random_data([(2, 5), "bool"], MSCFramework.TORCH)
+
+        def forward(self, data, src):
+            return data.masked_scatter(self.mask, src)
+
+    verify_model(MaskedScatter1(), [([5], "float32"), ([10], "float32")])
+    verify_model(MaskedScatter2(), [([2, 5], "float32"), ([3, 5], "float32")])
+
+
 def test_put():
     """test relax translator for index_put"""
diff --git a/tests/python/contrib/test_msc/test_translate_torch.py b/tests/python/contrib/test_msc/test_translate_torch.py
index 6ed28c0ac0..6535ef66c8 100644
--- a/tests/python/contrib/test_msc/test_translate_torch.py
+++ b/tests/python/contrib/test_msc/test_translate_torch.py
@@ -1173,6 +1173,29 @@ def test_scatter():
     )
 
 
+def test_masked_scatter():
+    """test torch translator for masked_scatter"""
+
+    class MaskedScatter1(Module):
+        def __init__(self):
+            super().__init__()
+            self.mask = msc_utils.random_data([(5,), "bool"], MSCFramework.TORCH)
+
+        def forward(self, data, src):
+            return data.masked_scatter(self.mask, src)
+
+    class MaskedScatter2(Module):
+        def __init__(self):
+            super().__init__()
+            self.mask = msc_utils.random_data([(2, 5), "bool"], MSCFramework.TORCH)
+
+        def forward(self, data, src):
+            return data.masked_scatter(self.mask, src)
+
+    verify_model(MaskedScatter1(), [([5], "float32"), ([10], "float32")], True)
+    verify_model(MaskedScatter2(), [([2, 5], "float32"), ([3, 5], "float32")], True)
+
+
 def test_put():
     """test torch translator for index_put"""
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index 08331f0861..d9857723b1 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -4023,5 +4023,66 @@ def test_scatter():
     verify_model(Scatter(), input_info, {}, expected)
 
 
+def test_masked_scatter():
+    class MaskedScatter1(Module):
+        def forward(self, data, mask, src):
+            return data.masked_scatter(mask, src)
+
+    class MaskedScatter2(Module):
+        def forward(self, data, mask, src):
+            return data.masked_scatter(mask, src)
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            inp_0: R.Tensor((5,), dtype="float32"),
+            inp_1: R.Tensor((5,), dtype="bool"),
+            inp_2: R.Tensor((10,), dtype="float32"),
+        ) -> R.Tensor((5,), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((5,), dtype="int32") = R.cumsum(
+                    inp_1, axis=0, dtype="int32", exclusive=False
+                )
+                lv1: R.Tensor((5,), dtype="int32") = R.subtract(lv, R.const(1, "int32"))
+                lv2: R.Tensor((5,), dtype="float32") = R.take(inp_2, lv1, axis=0)
+                lv3: R.Tensor((5,), dtype="float32") = R.where(inp_1, lv2, inp_0)
+                gv: R.Tensor((5,), dtype="float32") = lv3
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class expected2:
+        @R.function
+        def main(
+            inp_0: R.Tensor((2, 5), dtype="float32"),
+            inp_1: R.Tensor((2, 5), dtype="bool"),
+            inp_2: R.Tensor((3, 5), dtype="float32"),
+        ) -> R.Tensor((2, 5), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((10,), dtype="bool") = R.reshape(inp_1, R.shape([10]))
+                lv1: R.Tensor((10,), dtype="int32") = R.cumsum(
+                    lv, axis=0, dtype="int32", exclusive=False
+                )
+                lv2: R.Tensor((10,), dtype="int32") = R.subtract(lv1, R.const(1, "int32"))
+                lv3: R.Tensor((15,), dtype="float32") = R.reshape(inp_2, R.shape([15]))
+                lv4: R.Tensor((10,), dtype="float32") = R.take(lv3, lv2, axis=0)
+                lv5: R.Tensor((2, 5), dtype="float32") = R.reshape(lv4, R.shape([2, 5]))
+                lv6: R.Tensor((2, 5), dtype="float32") = R.where(inp_1, lv5, inp_0)
+                gv: R.Tensor((2, 5), dtype="float32") = lv6
+                R.output(gv)
+            return gv
+
+    verify_model(
+        MaskedScatter1(), [([5], "float32"), ([5], "bool"), ([10], "float32")], {}, expected1
+    )
+    verify_model(
+        MaskedScatter2(),
+        [([2, 5], "float32"), ([2, 5], "bool"), ([3, 5], "float32")],
+        {},
+        expected2,
+    )
+
+
 if __name__ == "__main__":
     tvm.testing.main()
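
The new tests can be run selectively with pytest's -k filter, for example:

pytest tests/python/relax/test_frontend_from_fx.py -k masked_scatter
pytest tests/python/contrib/test_msc -k masked_scatter
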