This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new faa8a0ad46 [Unity][nn.Module] Introduce operator `empty` (#16327)
faa8a0ad46 is described below

commit faa8a0ad46d2e3159680df0e09a84e5d6376b1fd
Author: Junru Shao <[email protected]>
AuthorDate: Mon Jan 1 20:30:57 2024 -0800

    [Unity][nn.Module] Introduce operator `empty` (#16327)
    
    This PR introduces an operator `op.empty` in the `nn.Module` frontend.
    It helps us create uninitialized memory from the memory pool, which
    can be used as temporary scratchpad memory for handcrafted operators.
---
 python/tvm/relax/frontend/nn/op.py        | 59 +++++++++++++++++++++++++++++++
 tests/python/relax/test_frontend_nn_op.py | 27 ++++++++++++--
 2 files changed, 83 insertions(+), 3 deletions(-)

diff --git a/python/tvm/relax/frontend/nn/op.py b/python/tvm/relax/frontend/nn/op.py
index 3197145289..66f023ef9d 100644
--- a/python/tvm/relax/frontend/nn/op.py
+++ b/python/tvm/relax/frontend/nn/op.py
@@ -1142,6 +1142,65 @@ def zeros(
     return wrap_nested(_op.zeros(shape, dtype), name)
 
 
+def ones(
+    shape: Sequence[IntExpr],
+    dtype: str = "float32",
+    name: str = "ones",
+) -> Tensor:
+    """Construct a tensor of all ones, with the input shape and dtype.
+
+    Parameters
+    ----------
+    shape : Sequence[IntExpr]
+        The shape of the created tensor.
+
+    dtype : str
+        The data type of the created tensor.
+
+    name : str
+        Name hint.
+
+    Returns
+    -------
+    result : Tensor
+        The result tensor.
+    """
+    return wrap_nested(_op.ones(shape, dtype), name)
+
+
+def empty(
+    shape: Sequence[IntExpr],
+    dtype: str = "float32",
+    name: str = "empty",
+) -> Tensor:
+    """Construct an uninitialized tensor, with the input shape and dtype.
+
+    Parameters
+    ----------
+    shape : Sequence[IntExpr]
+        The shape of the created tensor.
+
+    dtype : str
+        The data type of the created tensor.
+
+    name : str
+        Name hint.
+
+    Returns
+    -------
+    result : Tensor
+        The result tensor.
+    """
+    return wrap_nested(  # type: ignore
+        _op.builtin.alloc_tensor(
+            rx.ShapeExpr(shape),  # type: ignore
+            dtype,
+            runtime_device_index=0,
+        ),
+        name,
+    )
+
+
 def split(
     ary: Tensor,
     indices_or_sections: Union[int, Sequence[int]],
diff --git a/tests/python/relax/test_frontend_nn_op.py b/tests/python/relax/test_frontend_nn_op.py
index 55870426e4..43f4a9efc0 100644
--- a/tests/python/relax/test_frontend_nn_op.py
+++ b/tests/python/relax/test_frontend_nn_op.py
@@ -17,12 +17,14 @@
 # pylint: disable=missing-docstring, invalid-name
 import tvm
 import tvm.testing
-from tvm import tir
+from tvm import relax, tir
 from tvm.relax.frontend.nn import Module, Tensor, op, spec
 from tvm.script import ir as I
 from tvm.script import relax as R
 from tvm.script import tir as T
 
+# mypy: disable-error-code="attr-defined,valid-type,name-defined"
+
 
 def test_binary():
     class Model(Module):
@@ -174,7 +176,7 @@ def test_image():
         def test(self, x: Tensor, weight: Tensor, bias: Tensor):
             padded = op.pad(x, [0, 0, 0, 0, 1, 1, 1, 1])
             conv2d = op.conv2d(padded, weight, bias)
-            interpolate = op.interpolate(x, size=[40, 40])
+            interpolate = op.interpolate(x, size=[40, 40])  # type: ignore
             return (conv2d, interpolate)
 
     @R.function
@@ -347,7 +349,7 @@ def test_create():
     class Model(Module):
         def test(self, x: Tensor):
             triu_out = op.triu(x)
-            full_with_scalar_out = op.full([10, 10], fill_value=10)
+            full_with_scalar_out = op.full([10, 10], fill_value=10)  # type: ignore
             full_with_FloatImm_out = op.full(
                 [10, 10], fill_value=tir.FloatImm(dtype="float32", value=10)
             )
@@ -638,5 +640,24 @@ def test_extern():
     tvm.ir.assert_structural_equal(irmodule, Expected)
 
 
+def test_empty():
+    @tvm.register_func("test_empty_assert", override=True)
+    def test_empty_assert(_lineno, x):
+        assert x.shape == (10, 10)
+        assert x.dtype == "float32"
+
+    class Model(Module):
+        def test(self):
+            result = op.empty([10, 10], dtype="float32")
+            op.debug_func("test_empty_assert", result)
+            return result
+
+    irmodule, _ = Model().export_tvm(spec={"test": {}}, debug=True)
+    ex = relax.build(irmodule, "llvm")
+    vm = relax.VirtualMachine(ex, tvm.cpu())
+    effects = vm["_initialize_effect"]()
+    vm["test"](*effects)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to