This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 70ea70f31d [Unity][Frontend] FX translator supporting more ops (#14196)
70ea70f31d is described below
commit 70ea70f31d0c43a5349344a2f92bf4936bf4e10a
Author: Ruihang Lai <[email protected]>
AuthorDate: Sat Mar 4 21:57:38 2023 -0500
[Unity][Frontend] FX translator supporting more ops (#14196)
This PR improves the torch FX translator in the following ways:
* support the unary ops `sigmoid` and `round`,
* support in-place `fill`, `triu` and `tril`,
* support `tensor`, `arange`, `empty`,
* support `bmm` (batch matrix multiplication),
* support `astype`,
* support `chunk` and `squeeze`.
This PR also fixes `Embedding`. Previously the translation assumed that
the input to Embedding would always be 1-dimensional, and threw an
exception when the input had more than one dimension (i.e., batched
input). This PR adds support for that case.
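For reference, a minimal usage sketch (not part of this patch) of how a traced
module exercising several of the newly supported ops can be imported through
the FX frontend; the `Demo` module and its input shapes below are illustrative
only:

    import torch
    from torch import fx
    from torch.nn import Module

    from tvm.relax.frontend.torch import from_fx

    class Demo(Module):
        def forward(self, x, y):
            # bmm, sigmoid and squeeze are among the ops this patch adds support for.
            z = torch.bmm(x, y)
            z = torch.sigmoid(z)
            return z.squeeze()

    # Trace with torch.fx and translate into a Relax IRModule. The input info
    # is a list of (shape, dtype) pairs, matching the convention in the tests.
    graph_model = fx.symbolic_trace(Demo())
    mod = from_fx(graph_model, [((4, 128, 256), "float32"), ((4, 256, 512), "float32")])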
---
python/tvm/relax/frontend/torch/fx_translator.py | 165 ++++++++++-
tests/python/relax/test_frontend_from_fx.py | 344 ++++++++++++++++++++++-
2 files changed, 496 insertions(+), 13 deletions(-)
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 1d132c855e..b580e1679b 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -63,13 +63,15 @@ class TorchFXImporter:
"""converts the PyTorch scalar type input_type to a TVM dtype."""
import torch # type: ignore
- input_type = input_type.lower()
+ input_type = input_type.lower() if isinstance(input_type, str) else input_type
if input_type in ["float", "float32", "torch.float32", torch.float32]:
return "float32"
elif input_type in ["float16", "torch.float16", torch.float16]:
return "float16"
elif input_type in ["int64", "torch.int64", torch.int64]:
return "int64"
+ elif input_type in ["int32", "torch.int32", torch.int32]:
+ return "int32"
else:
raise NotImplementedError("input_type {} is not handled
yet".format(input_type))
@@ -134,12 +136,21 @@ class TorchFXImporter:
def _sin(self, node: fx.node.Node) -> relax.Var:
return self.block_builder.emit(relax.op.sin(self.env[node.args[0]]))
+ def _sigmoid(self, node: fx.node.Node) -> relax.Var:
+ return self.block_builder.emit(relax.op.sigmoid(self.env[node.args[0]]))
+
def _sqrt(self, node: fx.node.Node) -> relax.Expr:
arg = self.env[node.args[0]]
if isinstance(arg, (int, float)):
arg = relax.const(arg, "float32")
return self.block_builder.emit(relax.op.sqrt(arg))
+ def _round(self, node: fx.node.Node) -> relax.Expr:
+ if "decimals" in node.kwargs and node.kwargs["decimals"] != 0:
+ raise ValueError("specifying decimals for round is not supported
yet")
+ arg = self.env[node.args[0]]
+ return self.block_builder.emit(relax.op.round(arg))
+
def _add(self, node: fx.node.Node) -> relax.Expr:
lhs, rhs = self.retrieve_args(node)
if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var):
@@ -200,11 +211,93 @@ class TorchFXImporter:
########## Creation ##########
- def _tril(self, node: fx.node.Node) -> relax.Var:
- x = self.env[node.args[0]]
- k = node.args[1] if len(node.args) > 1 else 0
- assert isinstance(k, int)
- return self.block_builder.emit(relax.op.create.tril(x, k))
+ def _arange(self, node: fx.node.Node) -> relax.Var:
+ import torch
+ import numpy as np
+
+ start_end_step = [None, None, None]
+ if "start" in node.kwargs:
+ start_end_step[0] = node.kwargs["start"]
+ if "end" in node.kwargs:
+ start_end_step[1] = node.kwargs["end"]
+ if "step" in node.kwargs:
+ start_end_step[2] = node.kwargs["step"]
+
+ if len(node.args) == 1:
+ assert start_end_step[1] is None
+ start_end_step[1] = node.args[0]
+ elif len(node.args) == 2:
+ assert start_end_step[0] is None
+ assert start_end_step[1] is None
+ start_end_step[0] = node.args[0]
+ start_end_step[1] = node.args[1]
+ elif len(node.args) == 3:
+ assert start_end_step[0] is None
+ assert start_end_step[1] is None
+ assert start_end_step[2] is None
+ start_end_step[0] = node.args[0]
+ start_end_step[1] = node.args[1]
+ start_end_step[2] = node.args[2]
+
+ if start_end_step[0] is None:
+ start_end_step[0] = 0
+ if start_end_step[2] is None:
+ start_end_step[2] = 1
+
+ if "dtype" in node.kwargs:
+ dtype = TorchFXImporter._convert_data_type(str(node.kwargs["dtype"]))
+ elif any([isinstance(x, float) for x in start_end_step]):
+ dtype = TorchFXImporter._convert_data_type(torch.get_default_dtype())
+ else:
+ dtype = "int64"
+
+ return relax.const(np.arange(*start_end_step, dtype=dtype))
+
+ def _empty(self, node: fx.node.Node) -> relax.Var:
+ dtype = TorchFXImporter._convert_data_type(str(node.kwargs["dtype"]))
+ return self.block_builder.emit(relax.op.zeros(node.args, dtype))
+
+ def _inplace_fill(self, node: fx.node.Node) -> relax.Var:
+ args = self.retrieve_args(node)
+ x = args[0]
+ dtype = x.struct_info.dtype
+ value = args[1] if isinstance(args[1], relax.Expr) else relax.const(args[1], dtype)
+ filled = self.block_builder.emit(relax.op.full(x.struct_info.shape, value, dtype))
+ self.env[node.args[0]] = filled
+ return filled
+
+ def _tensor(self, node: fx.node.Node) -> relax.Var:
+ dtype = node.kwargs["dtype"] if "dtype" in node.kwargs else None
+ if isinstance(node.args[0], float):
+ return relax.const(node.args[0], dtype if dtype is not None else "float64")
+ elif isinstance(node.args[0], int):
+ return relax.const(node.args[0], dtype if dtype is not None else "int64")
+ raise ValueError("torch.tensor with value not a float or int is not
accepted")
+
+ def _tril_triu(self, op: Callable) -> Callable:
+ from torch import fx
+
+ def convert(node: fx.node.Node) -> relax.Var:
+ x = self.env[node.args[0]]
+ k = node.args[1] if len(node.args) > 1 else 0
+ assert isinstance(k, int)
+ return self.block_builder.emit(op(x, k))
+
+ return convert
+
+ def _inplace_tril_triu(self, op: Callable) -> Callable:
+ from torch import fx
+
+ def convert(node: fx.node.Node) -> relax.Var:
+ x = self.env[node.args[0]]
+ k = node.args[1] if len(node.args) > 1 else 0
+ assert isinstance(k, int)
+
+ mutated = self.block_builder.emit(op(x, k))
+ self.env[node.args[0]] = mutated
+ return mutated
+
+ return convert
def _new_ones(self, node: fx.node.Node) -> relax.Var:
args = self.retrieve_args(node)
@@ -238,8 +331,9 @@ class TorchFXImporter:
return self.block_builder.emit(relax.op.astype(self.env[node.args[0]], "float16"))
def _type(self, node: fx.node.Node) -> relax.Var:
- args = self.retrieve_args(node)
- return self.block_builder.emit(relax.op.astype(args[0], args[1]))
+ x = self.env[node.args[0]]
+ dtype = self._convert_data_type(node.args[1])
+ return self.block_builder.emit(relax.op.astype(x, dtype))
########## Linear Algebra ##########
@@ -313,12 +407,35 @@ class TorchFXImporter:
n_section = (self.shape_of(x)[dim].value + split_size - 1) // split_size
return self.block_builder.emit(relax.op.split(x, n_section, dim))
+ def _chunk(self, node: fx.node.Node) -> relax.Var:
+ x = self.env[node.args[0]]
+ chunks = node.args[1]
+
+ if "dim" in node.kwargs:
+ dim = node.kwargs["dim"]
+ elif len(node.args) > 2:
+ dim = node.args[2]
+ else:
+ dim = 0
+ return self.block_builder.emit(relax.op.split(x, chunks, dim))
+
def _transpose(self, node: fx.node.Node) -> relax.Var:
args = self.retrieve_args(node)
full_idx = list(range(len(self.shape_of(args[0]))))
full_idx[args[1]], full_idx[args[2]] = full_idx[args[2]], full_idx[args[1]]
return self.block_builder.emit(relax.op.permute_dims(args[0], full_idx))
+ def _squeeze(self, node: fx.node.Node) -> relax.Var:
+ x = self.env[node.args[0]]
+
+ if "dim" in node.kwargs:
+ dim = node.kwargs["dim"]
+ elif len(node.args) > 1:
+ dim = node.args[1]
+ else:
+ dim = None
+ return self.block_builder.emit(relax.op.squeeze(x, dim))
+
########## Search ##########
def _argmax_argmin(self, op: Callable) -> Callable:
@@ -521,7 +638,16 @@ class TorchFXImporter:
module = self.named_modules[node.target]
weight = self.params[module.weight]
x = self.block_builder.emit(relax.op.astype(x, "int32"))
- return self.block_builder.emit(relax.op.take(weight, x, axis=0))
+
+ ndim = x.struct_info.ndim
+ if ndim == 1:
+ return self.block_builder.emit(relax.op.take(weight, x, axis=0))
+ else:
+ x_shape = x.struct_info.shape.values
+ emb_size = weight.struct_info.shape.values[-1]
+ x = self.block_builder.emit(relax.op.reshape(x, shape=[-1]))
+ embedding = self.block_builder.emit(relax.op.take(weight, x, axis=0))
+ return self.block_builder.emit(relax.op.reshape(embedding, [*x_shape, emb_size]))
def _interpolate(self, node: fx.node.Node) -> relax.Var:
# torch.nn.functional.interpolate(
@@ -620,6 +746,7 @@ class TorchFXImporter:
while i < len(shape):
begin.append(0)
end.append(shape[i])
+ stride.append(1)
axes.append(i)
i = i + 1
sliced = self.block_builder.emit(relax.op.strided_slice(x, axes, begin, end, stride))
@@ -627,6 +754,9 @@ class TorchFXImporter:
for i in expand_dim:
sliced_shape.insert(i, 1)
return self.block_builder.emit(relax.op.reshape(sliced, sliced_shape))
+ elif isinstance(x, relax.Constant):
+ dtype = x.struct_info.dtype
+ return relax.const(x.data.numpy()[node.args[1]], dtype)
else:
assert False
@@ -660,24 +790,37 @@ class TorchFXImporter:
"mul": self._mul,
"sub": self._sub,
"pow": self._pow,
+ "sigmoid": self._sigmoid,
"sqrt": self._sqrt,
+ "round": self._round,
"lt": self._lt,
"truediv": self._truediv,
+ "fill_": self._inplace_fill,
"new_ones": self._new_ones,
- "tril": self._tril,
+ "arange": self._arange,
+ "empty": self._empty,
+ "tensor": self._tensor,
+ "tril": self._tril_triu(relax.op.tril),
+ "triu": self._tril_triu(relax.op.triu),
+ "tril_": self._inplace_tril_triu(relax.op.tril),
+ "triu_": self._inplace_tril_triu(relax.op.triu),
"sum": self._sum,
"float": self._float,
"half": self._half,
"type": self._type,
+ "astype": self._type,
"matmul": self._matmul,
"addmm": self._addmm,
+ "bmm": self._matmul,
"cat": self._cat,
"expand": self._expand,
"flatten": self._flatten,
"permute": self._permute,
"reshape": self._reshape,
"split": self._split,
+ "chunk": self._chunk,
"transpose": self._transpose,
+ "squeeze": self._squeeze,
"unsqueeze": lambda node: self.block_builder.emit(
relax.op.expand_dims(self.env[node.args[0]], node.args[1])
),
@@ -685,6 +828,7 @@ class TorchFXImporter:
"argmax": self._argmax_argmin(relax.op.argmax),
"argmin": self._argmax_argmin(relax.op.argmin),
"softmax": self._softmax,
+ "dropout": lambda node: self.env[node.args[0]],
"clamp": self._clamp,
"relu": lambda node:
self.block_builder.emit(relax.op.nn.relu(self.env[node.args[0]])),
"gelu": lambda node:
self.block_builder.emit(relax.op.nn.gelu(self.env[node.args[0]])),
@@ -693,6 +837,7 @@ class TorchFXImporter:
"getattr": self._getattr,
"getitem": self._getitem,
"contiguous": lambda node: self.env[node.args[0]],
+ "to": lambda node: self.env[node.args[0]],
"adaptive_avg_pool2d": self._adaptive_avg_pool2d(is_module=False),
}
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index 84fc97be27..9ab0b3304c 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -222,6 +222,45 @@ def test_linear():
)
+@tvm.testing.requires_gpu
+def test_bmm():
+ import torch
+ from torch.nn import Module
+
+ torch.set_grad_enabled(False)
+ torch.random.manual_seed(0)
+
+ class BMM(Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x, y):
+ return torch.bmm(x, y)
+
+ @tvm.script.ir_module
+ class Expected:
+ @R.function
+ def main(
+ input_1: R.Tensor((4, 128, 256), dtype="float32"),
+ input_2: R.Tensor((4, 256, 512), dtype="float32"),
+ ) -> R.Tensor((4, 128, 512), dtype="float32"):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((4, 128, 512), dtype="float32") = R.matmul(
+ input_1, input_2, out_dtype="float32"
+ )
+ gv: R.Tensor((4, 128, 512), dtype="float32") = lv
+ R.output(gv)
+ return gv
+
+ verify_model(
+ BMM(),
+ [((4, 128, 256), "float32"), ((4, 256, 512), "float32")],
+ {},
+ Expected,
+ )
+
+
@tvm.testing.requires_gpu
def test_relu():
import torch
@@ -576,7 +615,7 @@ def test_dropout():
input_info = [([1, 3, 10, 10], "float32")]
- class Dropout(Module):
+ class Dropout1(Module):
def __init__(self):
super().__init__()
self.dropout = torch.nn.Dropout(0.5)
@@ -584,6 +623,10 @@ def test_dropout():
def forward(self, input):
return self.dropout(input)
+ class Dropout2(Module):
+ def forward(self, input):
+ return torch.dropout(input, 0.5, train=True)
+
@tvm.script.ir_module
class expected1:
@R.function
@@ -596,7 +639,8 @@ def test_dropout():
R.output(gv)
return gv
- verify_model(Dropout(), input_info, {}, expected1)
+ verify_model(Dropout1(), input_info, {}, expected1)
+ verify_model(Dropout2(), input_info, {}, expected1)
@tvm.testing.requires_gpu
@@ -1078,6 +1122,52 @@ def test_size():
verify_model(Size(), input_info, {}, expected1)
+@tvm.testing.requires_gpu
+def test_squeeze():
+ import torch
+ from torch.nn import Module
+
+ torch.set_grad_enabled(False)
+ torch.random.manual_seed(0)
+
+ input_info = [([3, 1, 4, 1], "float32")]
+
+ class Squeeze1(Module):
+ def forward(self, input):
+ return input.squeeze(1)
+
+ @tvm.script.ir_module
+ class Expected1:
+ @R.function
+ def main(
+ inp_0: R.Tensor((3, 1, 4, 1), dtype="float32")
+ ) -> R.Tensor((3, 4, 1), dtype="float32"):
+ with R.dataflow():
+ lv: R.Tensor((3, 4, 1), dtype="float32") = R.squeeze(inp_0, axis=[1])
+ gv: R.Tensor((3, 4, 1), dtype="float32") = lv
+ R.output(gv)
+ return gv
+
+ class Squeeze2(Module):
+ def forward(self, input):
+ return input.squeeze()
+
+ @tvm.script.ir_module
+ class Expected2:
+ @R.function
+ def main(
+ inp_0: R.Tensor((3, 1, 4, 1), dtype="float32")
+ ) -> R.Tensor((3, 4), dtype="float32"):
+ with R.dataflow():
+ lv: R.Tensor((3, 4), dtype="float32") = R.squeeze(inp_0, axis=None)
+ gv: R.Tensor((3, 4), dtype="float32") = lv
+ R.output(gv)
+ return gv
+
+ verify_model(Squeeze1(), input_info, {}, Expected1)
+ verify_model(Squeeze2(), input_info, {}, Expected2)
+
+
@tvm.testing.requires_gpu
def test_unsqueeze():
import torch
@@ -1260,6 +1350,46 @@ def test_unary():
verify_model(Sqrt(), input_info, {}, expected3)
+ # sigmoid
+ class Sigmoid(Module):
+ def forward(self, input):
+ return torch.sigmoid(input)
+
+ @tvm.script.ir_module
+ class expected4:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+ ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.sigmoid(input_1)
+ gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
+ R.output(gv)
+ return gv
+
+ verify_model(Sigmoid(), input_info, {}, expected4)
+
+ # round
+ class Round(Module):
+ def forward(self, input):
+ return torch.round(input)
+
+ @tvm.script.ir_module
+ class expected5:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+ ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.round(input_1)
+ gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
+ R.output(gv)
+ return gv
+
+ verify_model(Round(), input_info, {}, expected5)
+
@tvm.testing.requires_gpu
def test_gelu():
@@ -1467,6 +1597,159 @@ def test_split():
verify_model(Split(), input_info, {}, expected1)
+@tvm.testing.requires_gpu
+def test_chunk():
+ import torch
+ from torch.nn import Module
+
+ torch.set_grad_enabled(False)
+ torch.random.manual_seed(0)
+
+ input_info = [([1, 3, 10, 10], "float32")]
+
+ class Chunk(Module):
+ def forward(self, input):
+ return torch.chunk(input, 3, dim=1)
+
+ @tvm.script.ir_module
+ class Expected:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+ ) -> R.Tuple(
+ R.Tensor((1, 1, 10, 10), dtype="float32"),
+ R.Tensor((1, 1, 10, 10), dtype="float32"),
+ R.Tensor((1, 1, 10, 10), dtype="float32"),
+ ):
+ # block 0
+ with R.dataflow():
+ lv: R.Tuple(
+ R.Tensor((1, 1, 10, 10), dtype="float32"),
+ R.Tensor((1, 1, 10, 10), dtype="float32"),
+ R.Tensor((1, 1, 10, 10), dtype="float32"),
+ ) = R.split(input_1, indices_or_sections=3, axis=1)
+ gv: R.Tuple(
+ R.Tensor((1, 1, 10, 10), dtype="float32"),
+ R.Tensor((1, 1, 10, 10), dtype="float32"),
+ R.Tensor((1, 1, 10, 10), dtype="float32"),
+ ) = lv
+ R.output(gv)
+ return gv
+
+ verify_model(Chunk(), input_info, {}, Expected)
+
+
+@tvm.testing.requires_gpu
+def test_inplace_fill():
+ import torch
+ from torch.nn import Module
+
+ torch.set_grad_enabled(False)
+ torch.random.manual_seed(0)
+
+ class InplaceFill(Module):
+ def forward(self, input):
+ input.fill_(1.5)
+ return input
+
+ @tvm.script.ir_module
+ class Expected:
+ @R.function
+ def main(inp_0: R.Tensor((10, 10), dtype="float32")) -> R.Tensor((10, 10), dtype="float32"):
+ with R.dataflow():
+ lv: R.Tensor((10, 10), dtype="float32") = R.full(
+ R.shape([10, 10]), R.const(1.5, "float32"), dtype="float32"
+ )
+ gv: R.Tensor((10, 10), dtype="float32") = lv
+ R.output(gv)
+ return gv
+
+ verify_model(InplaceFill(), [([10, 10], "float32")], {}, Expected)
+
+
+@tvm.testing.requires_gpu
+def test_arange():
+ import numpy as np
+ import torch
+ from torch import fx
+ from torch.nn import Module
+ from tvm.relax.frontend.torch import from_fx
+
+ torch.set_grad_enabled(False)
+ torch.random.manual_seed(0)
+
+ class Arange(Module):
+ def forward(self, input):
+ return torch.arange(0, 20, dtype=torch.int32)
+
+ graph_model = fx.symbolic_trace(Arange())
+ mod = from_fx(graph_model, [([10, 10], "float32")])
+ assert len(mod["main"].body.blocks) == 1
+ assert len(mod["main"].body.blocks[0].bindings) == 1
+ assert isinstance(mod["main"].body.blocks[0].bindings[0].value,
relax.Constant)
+ tvm.testing.assert_allclose(
+ mod["main"].body.blocks[0].bindings[0].value.data.numpy(),
np.arange(0, 20, dtype="int32")
+ )
+
+
+@tvm.testing.requires_gpu
+def test_empty():
+ import torch
+ from torch import fx
+ from torch.nn import Module
+ from tvm.relax.frontend.torch import from_fx
+
+ torch.set_grad_enabled(False)
+ torch.random.manual_seed(0)
+
+ class Empty(Module):
+ def forward(self, input):
+ return torch.empty((10, 10), dtype=torch.float32)
+
+ graph_model = fx.symbolic_trace(Empty())
+ mod = from_fx(graph_model, [([10, 10], "float32")])
+ assert len(mod["main"].body.blocks) == 1
+ assert len(mod["main"].body.blocks[0].bindings) == 1
+ assert isinstance(mod["main"].body.blocks[0].bindings[0].value,
relax.Constant)
+ assert mod["main"].body.blocks[0].bindings[0].value.data.shape == (10, 10)
+ assert mod["main"].body.blocks[0].bindings[0].value.data.dtype == "float32"
+
+
+@tvm.testing.requires_gpu
+def test_tensor():
+ import torch
+ from torch import fx
+ from torch.nn import Module
+ from tvm.relax.frontend.torch import from_fx
+
+ torch.set_grad_enabled(False)
+ torch.random.manual_seed(0)
+
+ class Empty1(Module):
+ def forward(self, input):
+ return torch.tensor(3, dtype=torch.float32)
+
+ class Empty2(Module):
+ def forward(self, input):
+ return torch.tensor(3)
+
+ graph_model1 = fx.symbolic_trace(Empty1())
+ mod1 = from_fx(graph_model1, [([10, 10], "float32")])
+ assert len(mod1["main"].body.blocks) == 1
+ assert len(mod1["main"].body.blocks[0].bindings) == 1
+ assert isinstance(mod1["main"].body.blocks[0].bindings[0].value,
relax.Constant)
+ assert mod1["main"].body.blocks[0].bindings[0].value.data.shape == ()
+ assert mod1["main"].body.blocks[0].bindings[0].value.data.dtype ==
"float32"
+
+ graph_model2 = fx.symbolic_trace(Empty2())
+ mod2 = from_fx(graph_model2, [([10, 10], "float32")])
+ assert len(mod2["main"].body.blocks) == 1
+ assert len(mod2["main"].body.blocks[0].bindings) == 1
+ assert isinstance(mod2["main"].body.blocks[0].bindings[0].value,
relax.Constant)
+ assert mod2["main"].body.blocks[0].bindings[0].value.data.shape == ()
+ assert mod2["main"].body.blocks[0].bindings[0].value.data.dtype == "int64"
+
+
@tvm.testing.requires_gpu
def test_tril():
import torch
@@ -1481,6 +1764,11 @@ def test_tril():
def forward(self, input):
return torch.tril(input, 1)
+ class InplaceTril(Module):
+ def forward(self, input):
+ input.tril_(1)
+ return input
+
@tvm.script.ir_module
class expected1:
@R.function
@@ -1495,6 +1783,43 @@ def test_tril():
return gv
verify_model(Tril(), input_info, {}, expected1)
+ verify_model(InplaceTril(), input_info, {}, expected1)
+
+
+@tvm.testing.requires_gpu
+def test_triu():
+ import torch
+ from torch.nn import Module
+
+ torch.set_grad_enabled(False)
+ torch.random.manual_seed(0)
+
+ input_info = [([10, 10], "float32")]
+
+ class Triu(Module):
+ def forward(self, input):
+ return torch.triu(input, 1)
+
+ class InplaceTriu(Module):
+ def forward(self, input):
+ input.triu_(1)
+ return input
+
+ @tvm.script.ir_module
+ class expected1:
+ @R.function
+ def main(
+ input_1: R.Tensor((10, 10), dtype="float32")
+ ) -> R.Tensor((10, 10), dtype="float32"):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((10, 10), dtype="float32") = R.triu(input_1, 1)
+ gv: R.Tensor((10, 10), dtype="float32") = lv
+ R.output(gv)
+ return gv
+
+ verify_model(Triu(), input_info, {}, expected1)
+ verify_model(InplaceTriu(), input_info, {}, expected1)
@tvm.testing.requires_gpu
@@ -1589,7 +1914,7 @@ def test_reduce():
@tvm.testing.requires_gpu
-def test_to():
+def test_datatype():
import torch
from torch.nn import Module
@@ -1638,6 +1963,19 @@ def test_to():
verify_model(ToHalf(), input_info, {}, expected2)
+ # type
+ class Type(Module):
+ def forward(self, x):
+ return x.type(torch.float32)
+
+ # astype
+ class AsType(Module):
+ def forward(self, x):
+ return x.astype(torch.float32)
+
+ verify_model(Type(), input_info, {}, expected1)
+ verify_model(AsType(), input_info, {}, expected1)
+
@tvm.testing.requires_gpu
def test_permute():