This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 70ea70f31d [Unity][Frontend] FX translator supporting more ops (#14196)
70ea70f31d is described below

commit 70ea70f31d0c43a5349344a2f92bf4936bf4e10a
Author: Ruihang Lai <[email protected]>
AuthorDate: Sat Mar 4 21:57:38 2023 -0500

    [Unity][Frontend] FX translator supporting more ops (#14196)
    
    This PR improves the torch FX translator in the following aspects:
    * support unary ops `sigmoid` and `round`,
    * support in-place `fill`, `triu` and `tril`,
    * support `tensor`, `arange`, `empty`,
    * support `bmm` (batch matrix multiplication),
    * support `astype`,
    * support `chunk` and `squeeze`.
    
    This PR also fixes `Embedding`. Previously the translation assumed
    that the input to Embedding is 1-dimensional, and it threw an
    exception when the input had more than one dimension (i.e., was
    batched). This PR adds support for batched inputs.
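    
    A minimal usage sketch (hypothetical module and input shapes, for
    illustration only) that exercises some of the newly supported ops
    through the existing `from_fx` entry point:
    
        import torch
        from torch import fx, nn
        from tvm.relax.frontend.torch import from_fx
    
        class Model(nn.Module):
            def forward(self, x):
                # `sigmoid` and `squeeze` are among the ops added here
                return torch.sigmoid(x).squeeze()
    
        # trace with torch.fx, then import the graph into Relax
        graph_model = fx.symbolic_trace(Model())
        mod = from_fx(graph_model, [([1, 3, 10, 10], "float32")])
        print(mod)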
---
 python/tvm/relax/frontend/torch/fx_translator.py | 165 ++++++++++-
 tests/python/relax/test_frontend_from_fx.py      | 344 ++++++++++++++++++++++-
 2 files changed, 496 insertions(+), 13 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 1d132c855e..b580e1679b 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -63,13 +63,15 @@ class TorchFXImporter:
         """converts the PyTorch scalar type input_type to a TVM dtype."""
         import torch  # type: ignore
 
-        input_type = input_type.lower()
+        input_type = input_type.lower() if isinstance(input_type, str) else input_type
         if input_type in ["float", "float32", "torch.float32", torch.float32]:
             return "float32"
         elif input_type in ["float16", "torch.float16", torch.float16]:
             return "float16"
         elif input_type in ["int64", "torch.int64", torch.int64]:
             return "int64"
+        elif input_type in ["int32", "torch.int32", torch.int32]:
+            return "int32"
         else:
             raise NotImplementedError("input_type {} is not handled yet".format(input_type))
 
@@ -134,12 +136,21 @@ class TorchFXImporter:
     def _sin(self, node: fx.node.Node) -> relax.Var:
         return self.block_builder.emit(relax.op.sin(self.env[node.args[0]]))
 
+    def _sigmoid(self, node: fx.node.Node) -> relax.Var:
+        return self.block_builder.emit(relax.op.sigmoid(self.env[node.args[0]]))
+
     def _sqrt(self, node: fx.node.Node) -> relax.Expr:
         arg = self.env[node.args[0]]
         if isinstance(arg, (int, float)):
             arg = relax.const(arg, "float32")
         return self.block_builder.emit(relax.op.sqrt(arg))
 
+    def _round(self, node: fx.node.Node) -> relax.Expr:
+        if "decimals" in node.kwargs and node.kwargs["decimals"] != 0:
+            raise ValueError("specifying decimals for round is not supported yet")
+        arg = self.env[node.args[0]]
+        return self.block_builder.emit(relax.op.round(arg))
+
     def _add(self, node: fx.node.Node) -> relax.Expr:
         lhs, rhs = self.retrieve_args(node)
         if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var):
@@ -200,11 +211,93 @@ class TorchFXImporter:
 
     ########## Creation ##########
 
-    def _tril(self, node: fx.node.Node) -> relax.Var:
-        x = self.env[node.args[0]]
-        k = node.args[1] if len(node.args) > 1 else 0
-        assert isinstance(k, int)
-        return self.block_builder.emit(relax.op.create.tril(x, k))
+    def _arange(self, node: fx.node.Node) -> relax.Var:
+        import torch
+        import numpy as np
+
+        start_end_step = [None, None, None]
+        if "start" in node.kwargs:
+            start_end_step[0] = node.kwargs["start"]
+        if "end" in node.kwargs:
+            start_end_step[1] = node.kwargs["end"]
+        if "step" in node.kwargs:
+            start_end_step[2] = node.kwargs["step"]
+
+        if len(node.args) == 1:
+            assert start_end_step[1] is None
+            start_end_step[1] = node.args[0]
+        elif len(node.args) == 2:
+            assert start_end_step[0] is None
+            assert start_end_step[1] is None
+            start_end_step[0] = node.args[0]
+            start_end_step[1] = node.args[1]
+        elif len(node.args) == 3:
+            assert start_end_step[0] is None
+            assert start_end_step[1] is None
+            assert start_end_step[2] is None
+            start_end_step[0] = node.args[0]
+            start_end_step[1] = node.args[1]
+            start_end_step[2] = node.args[2]
+
+        if start_end_step[0] is None:
+            start_end_step[0] = 0
+        if start_end_step[2] is None:
+            start_end_step[2] = 1
+
+        if "dtype" in node.kwargs:
+            dtype = TorchFXImporter._convert_data_type(str(node.kwargs["dtype"]))
+        elif any([isinstance(x, float) for x in start_end_step]):
+            dtype = TorchFXImporter._convert_data_type(torch.get_default_dtype())
+        else:
+            dtype = "int64"
+
+        return relax.const(np.arange(*start_end_step, dtype=dtype))
+
+    def _empty(self, node: fx.node.Node) -> relax.Var:
+        dtype = TorchFXImporter._convert_data_type(str(node.kwargs["dtype"]))
+        return self.block_builder.emit(relax.op.zeros(node.args, dtype))
+
+    def _inplace_fill(self, node: fx.node.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        x = args[0]
+        dtype = x.struct_info.dtype
+        value = args[1] if isinstance(args[1], relax.Expr) else relax.const(args[1], dtype)
+        filled = self.block_builder.emit(relax.op.full(x.struct_info.shape, value, dtype))
+        self.env[node.args[0]] = filled
+        return filled
+
+    def _tensor(self, node: fx.node.Node) -> relax.Var:
+        dtype = node.kwargs["dtype"] if "dtype" in node.kwargs else None
+        if isinstance(node.args[0], float):
+            return relax.const(node.args[0], dtype if dtype is not None else "float64")
+        elif isinstance(node.args[0], int):
+            return relax.const(node.args[0], dtype if dtype is not None else "int64")
+        raise ValueError("torch.tensor with value not a float or int is not accepted")
+
+    def _tril_triu(self, op: Callable) -> Callable:
+        from torch import fx
+
+        def convert(node: fx.node.Node) -> relax.Var:
+            x = self.env[node.args[0]]
+            k = node.args[1] if len(node.args) > 1 else 0
+            assert isinstance(k, int)
+            return self.block_builder.emit(op(x, k))
+
+        return convert
+
+    def _inplace_tril_triu(self, op: Callable) -> Callable:
+        from torch import fx
+
+        def convert(node: fx.node.Node) -> relax.Var:
+            x = self.env[node.args[0]]
+            k = node.args[1] if len(node.args) > 1 else 0
+            assert isinstance(k, int)
+
+            mutated = self.block_builder.emit(op(x, k))
+            self.env[node.args[0]] = mutated
+            return mutated
+
+        return convert
 
     def _new_ones(self, node: fx.node.Node) -> relax.Var:
         args = self.retrieve_args(node)
@@ -238,8 +331,9 @@ class TorchFXImporter:
         return self.block_builder.emit(relax.op.astype(self.env[node.args[0]], "float16"))
 
     def _type(self, node: fx.node.Node) -> relax.Var:
-        args = self.retrieve_args(node)
-        return self.block_builder.emit(relax.op.astype(args[0], args[1]))
+        x = self.env[node.args[0]]
+        dtype = self._convert_data_type(node.args[1])
+        return self.block_builder.emit(relax.op.astype(x, dtype))
 
     ########## Linear Algebra ##########
 
@@ -313,12 +407,35 @@ class TorchFXImporter:
         n_section = (self.shape_of(x)[dim].value + split_size - 1) // split_size
         return self.block_builder.emit(relax.op.split(x, n_section, dim))
 
+    def _chunk(self, node: fx.node.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        chunks = node.args[1]
+
+        if "dim" in node.kwargs:
+            dim = node.kwargs["dim"]
+        elif len(node.args) > 2:
+            dim = node.args[2]
+        else:
+            dim = 0
+        return self.block_builder.emit(relax.op.split(x, chunks, dim))
+
     def _transpose(self, node: fx.node.Node) -> relax.Var:
         args = self.retrieve_args(node)
         full_idx = list(range(len(self.shape_of(args[0]))))
         full_idx[args[1]], full_idx[args[2]] = full_idx[args[2]], full_idx[args[1]]
         return self.block_builder.emit(relax.op.permute_dims(args[0], full_idx))
 
+    def _squeeze(self, node: fx.node.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+
+        if "dim" in node.kwargs:
+            dim = node.kwargs["dim"]
+        elif len(node.args) > 1:
+            dim = node.args[1]
+        else:
+            dim = None
+        return self.block_builder.emit(relax.op.squeeze(x, dim))
+
     ########## Search ##########
 
     def _argmax_argmin(self, op: Callable) -> Callable:
@@ -521,7 +638,16 @@ class TorchFXImporter:
         module = self.named_modules[node.target]
         weight = self.params[module.weight]
         x = self.block_builder.emit(relax.op.astype(x, "int32"))
-        return self.block_builder.emit(relax.op.take(weight, x, axis=0))
+
+        ndim = x.struct_info.ndim
+        if ndim == 1:
+            return self.block_builder.emit(relax.op.take(weight, x, axis=0))
+        else:
+            x_shape = x.struct_info.shape.values
+            emb_size = weight.struct_info.shape.values[-1]
+            x = self.block_builder.emit(relax.op.reshape(x, shape=[-1]))
+            embedding = self.block_builder.emit(relax.op.take(weight, x, axis=0))
+            return self.block_builder.emit(relax.op.reshape(embedding, [*x_shape, emb_size]))
 
     def _interpolate(self, node: fx.node.Node) -> relax.Var:
         # torch.nn.functional.interpolate(
@@ -620,6 +746,7 @@ class TorchFXImporter:
             while i < len(shape):
                 begin.append(0)
                 end.append(shape[i])
+                stride.append(1)
                 axes.append(i)
                 i = i + 1
             sliced = self.block_builder.emit(relax.op.strided_slice(x, axes, begin, end, stride))
@@ -627,6 +754,9 @@ class TorchFXImporter:
             for i in expand_dim:
                 sliced_shape.insert(i, 1)
             return self.block_builder.emit(relax.op.reshape(sliced, sliced_shape))
+        elif isinstance(x, relax.Constant):
+            dtype = x.struct_info.dtype
+            return relax.const(x.data.numpy()[node.args[1]], dtype)
         else:
             assert False
 
@@ -660,24 +790,37 @@ class TorchFXImporter:
             "mul": self._mul,
             "sub": self._sub,
             "pow": self._pow,
+            "sigmoid": self._sigmoid,
             "sqrt": self._sqrt,
+            "round": self._round,
             "lt": self._lt,
             "truediv": self._truediv,
+            "fill_": self._inplace_fill,
             "new_ones": self._new_ones,
-            "tril": self._tril,
+            "arange": self._arange,
+            "empty": self._empty,
+            "tensor": self._tensor,
+            "tril": self._tril_triu(relax.op.tril),
+            "triu": self._tril_triu(relax.op.triu),
+            "tril_": self._inplace_tril_triu(relax.op.tril),
+            "triu_": self._inplace_tril_triu(relax.op.triu),
             "sum": self._sum,
             "float": self._float,
             "half": self._half,
             "type": self._type,
+            "astype": self._type,
             "matmul": self._matmul,
             "addmm": self._addmm,
+            "bmm": self._matmul,
             "cat": self._cat,
             "expand": self._expand,
             "flatten": self._flatten,
             "permute": self._permute,
             "reshape": self._reshape,
             "split": self._split,
+            "chunk": self._chunk,
             "transpose": self._transpose,
+            "squeeze": self._squeeze,
             "unsqueeze": lambda node: self.block_builder.emit(
                 relax.op.expand_dims(self.env[node.args[0]], node.args[1])
             ),
@@ -685,6 +828,7 @@ class TorchFXImporter:
             "argmax": self._argmax_argmin(relax.op.argmax),
             "argmin": self._argmax_argmin(relax.op.argmin),
             "softmax": self._softmax,
+            "dropout": lambda node: self.env[node.args[0]],
             "clamp": self._clamp,
             "relu": lambda node: 
self.block_builder.emit(relax.op.nn.relu(self.env[node.args[0]])),
             "gelu": lambda node: 
self.block_builder.emit(relax.op.nn.gelu(self.env[node.args[0]])),
@@ -693,6 +837,7 @@ class TorchFXImporter:
             "getattr": self._getattr,
             "getitem": self._getitem,
             "contiguous": lambda node: self.env[node.args[0]],
+            "to": lambda node: self.env[node.args[0]],
             "adaptive_avg_pool2d": self._adaptive_avg_pool2d(is_module=False),
         }
 
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index 84fc97be27..9ab0b3304c 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -222,6 +222,45 @@ def test_linear():
     )
 
 
+@tvm.testing.requires_gpu
+def test_bmm():
+    import torch
+    from torch.nn import Module
+
+    torch.set_grad_enabled(False)
+    torch.random.manual_seed(0)
+
+    class BMM(Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x, y):
+            return torch.bmm(x, y)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            input_1: R.Tensor((4, 128, 256), dtype="float32"),
+            input_2: R.Tensor((4, 256, 512), dtype="float32"),
+        ) -> R.Tensor((4, 128, 512), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((4, 128, 512), dtype="float32") = R.matmul(
+                    input_1, input_2, out_dtype="float32"
+                )
+                gv: R.Tensor((4, 128, 512), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(
+        BMM(),
+        [((4, 128, 256), "float32"), ((4, 256, 512), "float32")],
+        {},
+        Expected,
+    )
+
+
 @tvm.testing.requires_gpu
 def test_relu():
     import torch
@@ -576,7 +615,7 @@ def test_dropout():
 
     input_info = [([1, 3, 10, 10], "float32")]
 
-    class Dropout(Module):
+    class Dropout1(Module):
         def __init__(self):
             super().__init__()
             self.dropout = torch.nn.Dropout(0.5)
@@ -584,6 +623,10 @@ def test_dropout():
         def forward(self, input):
             return self.dropout(input)
 
+    class Dropout2(Module):
+        def forward(self, input):
+            return torch.dropout(input, 0.5, train=True)
+
     @tvm.script.ir_module
     class expected1:
         @R.function
@@ -596,7 +639,8 @@ def test_dropout():
                 R.output(gv)
             return gv
 
-    verify_model(Dropout(), input_info, {}, expected1)
+    verify_model(Dropout1(), input_info, {}, expected1)
+    verify_model(Dropout2(), input_info, {}, expected1)
 
 
 @tvm.testing.requires_gpu
@@ -1078,6 +1122,52 @@ def test_size():
     verify_model(Size(), input_info, {}, expected1)
 
 
+@tvm.testing.requires_gpu
+def test_squeeze():
+    import torch
+    from torch.nn import Module
+
+    torch.set_grad_enabled(False)
+    torch.random.manual_seed(0)
+
+    input_info = [([3, 1, 4, 1], "float32")]
+
+    class Squeeze1(Module):
+        def forward(self, input):
+            return input.squeeze(1)
+
+    @tvm.script.ir_module
+    class Expected1:
+        @R.function
+        def main(
+            inp_0: R.Tensor((3, 1, 4, 1), dtype="float32")
+        ) -> R.Tensor((3, 4, 1), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((3, 4, 1), dtype="float32") = R.squeeze(inp_0, axis=[1])
+                gv: R.Tensor((3, 4, 1), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    class Squeeze2(Module):
+        def forward(self, input):
+            return input.squeeze()
+
+    @tvm.script.ir_module
+    class Expected2:
+        @R.function
+        def main(
+            inp_0: R.Tensor((3, 1, 4, 1), dtype="float32")
+        ) -> R.Tensor((3, 4), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((3, 4), dtype="float32") = R.squeeze(inp_0, axis=None)
+                gv: R.Tensor((3, 4), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(Squeeze1(), input_info, {}, Expected1)
+    verify_model(Squeeze2(), input_info, {}, Expected2)
+
+
 @tvm.testing.requires_gpu
 def test_unsqueeze():
     import torch
@@ -1260,6 +1350,46 @@ def test_unary():
 
     verify_model(Sqrt(), input_info, {}, expected3)
 
+    # sigmoid
+    class Sigmoid(Module):
+        def forward(self, input):
+            return torch.sigmoid(input)
+
+    @tvm.script.ir_module
+    class expected4:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.sigmoid(input_1)
+                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(Sigmoid(), input_info, {}, expected4)
+
+    # round
+    class Round(Module):
+        def forward(self, input):
+            return torch.round(input)
+
+    @tvm.script.ir_module
+    class expected5:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.round(input_1)
+                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(Round(), input_info, {}, expected5)
+
 
 @tvm.testing.requires_gpu
 def test_gelu():
@@ -1467,6 +1597,159 @@ def test_split():
     verify_model(Split(), input_info, {}, expected1)
 
 
+@tvm.testing.requires_gpu
+def test_chunk():
+    import torch
+    from torch.nn import Module
+
+    torch.set_grad_enabled(False)
+    torch.random.manual_seed(0)
+
+    input_info = [([1, 3, 10, 10], "float32")]
+
+    class Chunk(Module):
+        def forward(self, input):
+            return torch.chunk(input, 3, dim=1)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tuple(
+            R.Tensor((1, 1, 10, 10), dtype="float32"),
+            R.Tensor((1, 1, 10, 10), dtype="float32"),
+            R.Tensor((1, 1, 10, 10), dtype="float32"),
+        ):
+            # block 0
+            with R.dataflow():
+                lv: R.Tuple(
+                    R.Tensor((1, 1, 10, 10), dtype="float32"),
+                    R.Tensor((1, 1, 10, 10), dtype="float32"),
+                    R.Tensor((1, 1, 10, 10), dtype="float32"),
+                ) = R.split(input_1, indices_or_sections=3, axis=1)
+                gv: R.Tuple(
+                    R.Tensor((1, 1, 10, 10), dtype="float32"),
+                    R.Tensor((1, 1, 10, 10), dtype="float32"),
+                    R.Tensor((1, 1, 10, 10), dtype="float32"),
+                ) = lv
+                R.output(gv)
+            return gv
+
+    verify_model(Chunk(), input_info, {}, Expected)
+
+
+@tvm.testing.requires_gpu
+def test_inplace_fill():
+    import torch
+    from torch.nn import Module
+
+    torch.set_grad_enabled(False)
+    torch.random.manual_seed(0)
+
+    class InplaceFill(Module):
+        def forward(self, input):
+            input.fill_(1.5)
+            return input
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(inp_0: R.Tensor((10, 10), dtype="float32")) -> R.Tensor((10, 10), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((10, 10), dtype="float32") = R.full(
+                    R.shape([10, 10]), R.const(1.5, "float32"), dtype="float32"
+                )
+                gv: R.Tensor((10, 10), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(InplaceFill(), [([10, 10], "float32")], {}, Expected)
+
+
+@tvm.testing.requires_gpu
+def test_arange():
+    import numpy as np
+    import torch
+    from torch import fx
+    from torch.nn import Module
+    from tvm.relax.frontend.torch import from_fx
+
+    torch.set_grad_enabled(False)
+    torch.random.manual_seed(0)
+
+    class Arange(Module):
+        def forward(self, input):
+            return torch.arange(0, 20, dtype=torch.int32)
+
+    graph_model = fx.symbolic_trace(Arange())
+    mod = from_fx(graph_model, [([10, 10], "float32")])
+    assert len(mod["main"].body.blocks) == 1
+    assert len(mod["main"].body.blocks[0].bindings) == 1
+    assert isinstance(mod["main"].body.blocks[0].bindings[0].value, relax.Constant)
+    tvm.testing.assert_allclose(
+        mod["main"].body.blocks[0].bindings[0].value.data.numpy(), 
np.arange(0, 20, dtype="int32")
+    )
+
+
+@tvm.testing.requires_gpu
+def test_empty():
+    import torch
+    from torch import fx
+    from torch.nn import Module
+    from tvm.relax.frontend.torch import from_fx
+
+    torch.set_grad_enabled(False)
+    torch.random.manual_seed(0)
+
+    class Empty(Module):
+        def forward(self, input):
+            return torch.empty((10, 10), dtype=torch.float32)
+
+    graph_model = fx.symbolic_trace(Empty())
+    mod = from_fx(graph_model, [([10, 10], "float32")])
+    assert len(mod["main"].body.blocks) == 1
+    assert len(mod["main"].body.blocks[0].bindings) == 1
+    assert isinstance(mod["main"].body.blocks[0].bindings[0].value, relax.Constant)
+    assert mod["main"].body.blocks[0].bindings[0].value.data.shape == (10, 10)
+    assert mod["main"].body.blocks[0].bindings[0].value.data.dtype == "float32"
+
+
+@tvm.testing.requires_gpu
+def test_tensor():
+    import torch
+    from torch import fx
+    from torch.nn import Module
+    from tvm.relax.frontend.torch import from_fx
+
+    torch.set_grad_enabled(False)
+    torch.random.manual_seed(0)
+
+    class Empty1(Module):
+        def forward(self, input):
+            return torch.tensor(3, dtype=torch.float32)
+
+    class Empty2(Module):
+        def forward(self, input):
+            return torch.tensor(3)
+
+    graph_model1 = fx.symbolic_trace(Empty1())
+    mod1 = from_fx(graph_model1, [([10, 10], "float32")])
+    assert len(mod1["main"].body.blocks) == 1
+    assert len(mod1["main"].body.blocks[0].bindings) == 1
+    assert isinstance(mod1["main"].body.blocks[0].bindings[0].value, relax.Constant)
+    assert mod1["main"].body.blocks[0].bindings[0].value.data.shape == ()
+    assert mod1["main"].body.blocks[0].bindings[0].value.data.dtype == "float32"
+
+    graph_model2 = fx.symbolic_trace(Empty2())
+    mod2 = from_fx(graph_model2, [([10, 10], "float32")])
+    assert len(mod2["main"].body.blocks) == 1
+    assert len(mod2["main"].body.blocks[0].bindings) == 1
+    assert isinstance(mod2["main"].body.blocks[0].bindings[0].value, relax.Constant)
+    assert mod2["main"].body.blocks[0].bindings[0].value.data.shape == ()
+    assert mod2["main"].body.blocks[0].bindings[0].value.data.dtype == "int64"
+
+
 @tvm.testing.requires_gpu
 def test_tril():
     import torch
@@ -1481,6 +1764,11 @@ def test_tril():
         def forward(self, input):
             return torch.tril(input, 1)
 
+    class InplaceTril(Module):
+        def forward(self, input):
+            input.tril_(1)
+            return input
+
     @tvm.script.ir_module
     class expected1:
         @R.function
@@ -1495,6 +1783,43 @@ def test_tril():
             return gv
 
     verify_model(Tril(), input_info, {}, expected1)
+    verify_model(InplaceTril(), input_info, {}, expected1)
+
+
+@tvm.testing.requires_gpu
+def test_triu():
+    import torch
+    from torch.nn import Module
+
+    torch.set_grad_enabled(False)
+    torch.random.manual_seed(0)
+
+    input_info = [([10, 10], "float32")]
+
+    class Triu(Module):
+        def forward(self, input):
+            return torch.triu(input, 1)
+
+    class InplaceTriu(Module):
+        def forward(self, input):
+            input.triu_(1)
+            return input
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            input_1: R.Tensor((10, 10), dtype="float32")
+        ) -> R.Tensor((10, 10), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((10, 10), dtype="float32") = R.triu(input_1, 1)
+                gv: R.Tensor((10, 10), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(Triu(), input_info, {}, expected1)
+    verify_model(InplaceTriu(), input_info, {}, expected1)
 
 
 @tvm.testing.requires_gpu
@@ -1589,7 +1914,7 @@ def test_reduce():
 
 
 @tvm.testing.requires_gpu
-def test_to():
+def test_datatype():
     import torch
     from torch.nn import Module
 
@@ -1638,6 +1963,19 @@ def test_to():
 
     verify_model(ToHalf(), input_info, {}, expected2)
 
+    # type
+    class Type(Module):
+        def forward(self, x):
+            return x.type(torch.float32)
+
+    # astype
+    class AsType(Module):
+        def forward(self, x):
+            return x.astype(torch.float32)
+
+    verify_model(Type(), input_info, {}, expected1)
+    verify_model(AsType(), input_info, {}, expected1)
+
 
 @tvm.testing.requires_gpu
 def test_permute():
