This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 0c6495902d [Unity][Op] Legalize `round`, `floor`, `ceil`, `sign` (#14198)
0c6495902d is described below

commit 0c6495902dc8a07888f1062953e607ca5c0a9d6d
Author: Ruihang Lai <[email protected]>
AuthorDate: Sat Mar 4 21:09:18 2023 -0500

    [Unity][Op] Legalize `round`, `floor`, `ceil`, `sign` (#14198)
    
    This PR implements the legalization for four unary operators:
    * round,
    * floor,
    * ceil,
    * sign.
    
    Unit tests are provided accordingly.
---
 python/tvm/relax/transform/legalize_ops/unary.py   |   4 +
 .../relax/test_transform_legalize_ops_unary.py     | 405 +++++++++++++++++++++
 2 files changed, 409 insertions(+)

diff --git a/python/tvm/relax/transform/legalize_ops/unary.py b/python/tvm/relax/transform/legalize_ops/unary.py
index cd29182c4d..ca84cbf0ad 100644
--- a/python/tvm/relax/transform/legalize_ops/unary.py
+++ b/python/tvm/relax/transform/legalize_ops/unary.py
@@ -21,11 +21,15 @@ from .common import _call_topi_without_attr, register_legalize
 # To avoid conflict of IRModule function name and libc function name, we add
 # "tir_" as the prefix of the generated PrimFunc name.
 register_legalize("relax.abs", _call_topi_without_attr(topi.abs, "tir_abs"))
+register_legalize("relax.ceil", _call_topi_without_attr(topi.ceil, "tir_ceil"))
 register_legalize("relax.cos", _call_topi_without_attr(topi.cos, "tir_cos"))
 register_legalize("relax.log", _call_topi_without_attr(topi.log, "tir_log"))
 register_legalize("relax.exp", _call_topi_without_attr(topi.exp, "tir_exp"))
+register_legalize("relax.floor", _call_topi_without_attr(topi.floor, "tir_floor"))
 register_legalize("relax.negative", _call_topi_without_attr(topi.negative, "tir_negative"))
+register_legalize("relax.round", _call_topi_without_attr(topi.round, "tir_round"))
 register_legalize("relax.sigmoid", _call_topi_without_attr(topi.sigmoid, "tir_sigmoid"))
+register_legalize("relax.sign", _call_topi_without_attr(topi.sign, "tir_sign"))
 register_legalize("relax.sin", _call_topi_without_attr(topi.sin, "tir_sin"))
 register_legalize("relax.sqrt", _call_topi_without_attr(topi.sqrt, "tir_sqrt"))
 register_legalize("relax.tanh", _call_topi_without_attr(topi.tanh, "tir_tanh"))
diff --git a/tests/python/relax/test_transform_legalize_ops_unary.py b/tests/python/relax/test_transform_legalize_ops_unary.py
index 7250e711be..97904e2a5f 100644
--- a/tests/python/relax/test_transform_legalize_ops_unary.py
+++ b/tests/python/relax/test_transform_legalize_ops_unary.py
@@ -18,6 +18,7 @@
 import tvm
 import tvm.testing
 from tvm.relax.transform import LegalizeOps
+from tvm.script import ir as I
 from tvm.script import relax as R
 from tvm.script import tir as T
 
@@ -92,6 +93,107 @@ def test_abs_symbolic():
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_ceil():
+    # fmt: off
+    @tvm.script.ir_module
+    class Ceil:
+        @R.function
+        def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"):
+            gv: R.Tensor((2, 3), "float32") = R.ceil(x)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func
+        def tir_ceil(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), compute: T.Buffer((T.int64(2), T.int64(3)), "float32")):
+            T.func_attr({"tir.noalias": True})
+            for i0, i1 in T.grid(T.int64(2), T.int64(3)):
+                with T.block("compute"):
+                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
+                    T.reads(rxplaceholder[v_i0, v_i1])
+                    T.writes(compute[v_i0, v_i1])
+                    compute[v_i0, v_i1] = T.ceil(rxplaceholder[v_i0, v_i1])
+
+        @R.function
+        def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float32"):
+            gv = R.call_tir(tir_ceil, (x,), out_sinfo=R.Tensor((2, 3), dtype="float32"))
+            return gv
+    # fmt: on
+
+    mod = LegalizeOps()(Ceil)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_ceil_int():
+    # fmt: off
+    @tvm.script.ir_module
+    class Ceil:
+        @R.function
+        def main(x: R.Tensor((2, 3), "int32")) -> R.Tensor((2, 3), "int32"):
+            gv: R.Tensor((2, 3), "int32") = R.ceil(x)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func
+        def tir_ceil(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "int32"), compute: T.Buffer((T.int64(2), T.int64(3)), "int32")):
+            T.func_attr({"tir.noalias": True})
+            for i0, i1 in T.grid(T.int64(2), T.int64(3)):
+                with T.block("compute"):
+                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
+                    T.reads(rxplaceholder[v_i0, v_i1])
+                    T.writes(compute[v_i0, v_i1])
+                    compute[v_i0, v_i1] = rxplaceholder[v_i0, v_i1]
+
+        @R.function
+        def main(x: R.Tensor((2, 3), dtype="int32")) -> R.Tensor((2, 3), dtype="int32"):
+            gv = R.call_tir(tir_ceil, (x,), out_sinfo=R.Tensor((2, 3), dtype="int32"))
+            return gv
+    # fmt: on
+
+    mod = LegalizeOps()(Ceil)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_ceil_symbolic():
+    # fmt: off
+    @tvm.script.ir_module
+    class Ceil:
+        @R.function
+        def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"):
+            m = T.int64()
+            n = T.int64()
+            gv: R.Tensor((m, n), "float32") = R.ceil(x)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func
+        def tir_ceil(var_rxplaceholder: T.handle, var_compute: T.handle):
+            T.func_attr({"tir.noalias": True})
+            m = T.int64()
+            n = T.int64()
+            rxplaceholder = T.match_buffer(var_rxplaceholder, (m, n))
+            compute = T.match_buffer(var_compute, (m, n))
+            for i0, i1 in T.grid(m, n):
+                with T.block("compute"):
+                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
+                    T.reads(rxplaceholder[v_i0, v_i1])
+                    T.writes(compute[v_i0, v_i1])
+                    compute[v_i0, v_i1] = T.ceil(rxplaceholder[v_i0, v_i1])
+
+        @R.function
+        def main(x: R.Tensor(("m", "n"), dtype="float32")) -> R.Tensor(("m", "n"), dtype="float32"):
+            m = T.int64()
+            n = T.int64()
+            gv = R.call_tir(tir_ceil, (x,), out_sinfo=R.Tensor((m, n), dtype="float32"))
+            return gv
+    # fmt: on
+
+    mod = LegalizeOps()(Ceil)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
 def test_cos():
     # fmt: off
     @tvm.script.ir_module
@@ -232,6 +334,107 @@ def test_exp_symbolic():
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_floor():
+    # fmt: off
+    @tvm.script.ir_module
+    class Floor:
+        @R.function
+        def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"):
+            gv: R.Tensor((2, 3), "float32") = R.floor(x)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func
+        def tir_floor(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), compute: T.Buffer((T.int64(2), T.int64(3)), "float32")):
+            T.func_attr({"tir.noalias": True})
+            for i0, i1 in T.grid(T.int64(2), T.int64(3)):
+                with T.block("compute"):
+                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
+                    T.reads(rxplaceholder[v_i0, v_i1])
+                    T.writes(compute[v_i0, v_i1])
+                    compute[v_i0, v_i1] = T.floor(rxplaceholder[v_i0, v_i1])
+
+        @R.function
+        def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float32"):
+            gv = R.call_tir(tir_floor, (x,), out_sinfo=R.Tensor((2, 3), dtype="float32"))
+            return gv
+    # fmt: on
+
+    mod = LegalizeOps()(Floor)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_floor_int():
+    # fmt: off
+    @tvm.script.ir_module
+    class Floor:
+        @R.function
+        def main(x: R.Tensor((2, 3), "int32")) -> R.Tensor((2, 3), "int32"):
+            gv: R.Tensor((2, 3), "int32") = R.floor(x)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func
+        def tir_floor(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "int32"), compute: T.Buffer((T.int64(2), T.int64(3)), "int32")):
+            T.func_attr({"tir.noalias": True})
+            for i0, i1 in T.grid(T.int64(2), T.int64(3)):
+                with T.block("compute"):
+                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
+                    T.reads(rxplaceholder[v_i0, v_i1])
+                    T.writes(compute[v_i0, v_i1])
+                    compute[v_i0, v_i1] = rxplaceholder[v_i0, v_i1]
+
+        @R.function
+        def main(x: R.Tensor((2, 3), dtype="int32")) -> R.Tensor((2, 3), dtype="int32"):
+            gv = R.call_tir(tir_floor, (x,), out_sinfo=R.Tensor((2, 3), dtype="int32"))
+            return gv
+    # fmt: on
+
+    mod = LegalizeOps()(Floor)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_floor_symbolic():
+    # fmt: off
+    @tvm.script.ir_module
+    class Floor:
+        @R.function
+        def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"):
+            m = T.int64()
+            n = T.int64()
+            gv: R.Tensor((m, n), "float32") = R.floor(x)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func
+        def tir_floor(var_rxplaceholder: T.handle, var_compute: T.handle):
+            T.func_attr({"tir.noalias": True})
+            m = T.int64()
+            n = T.int64()
+            rxplaceholder = T.match_buffer(var_rxplaceholder, (m, n))
+            compute = T.match_buffer(var_compute, (m, n))
+            for i0, i1 in T.grid(m, n):
+                with T.block("compute"):
+                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
+                    T.reads(rxplaceholder[v_i0, v_i1])
+                    T.writes(compute[v_i0, v_i1])
+                    compute[v_i0, v_i1] = T.floor(rxplaceholder[v_i0, v_i1])
+
+        @R.function
+        def main(x: R.Tensor(("m", "n"), dtype="float32")) -> R.Tensor(("m", "n"), dtype="float32"):
+            m = T.int64()
+            n = T.int64()
+            gv = R.call_tir(tir_floor, (x,), out_sinfo=R.Tensor((m, n), dtype="float32"))
+            return gv
+    # fmt: on
+
+    mod = LegalizeOps()(Floor)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
 def test_log():
     # fmt: off
     @tvm.script.ir_module
@@ -372,6 +575,107 @@ def test_negative_symbolic():
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_round():
+    # fmt: off
+    @tvm.script.ir_module
+    class Round:
+        @R.function
+        def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"):
+            gv: R.Tensor((2, 3), "float32") = R.round(x)
+            return gv
+
+    @tvm.script.ir_module
+    class Expected:
+        @T.prim_func
+        def tir_round(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), compute: T.Buffer((T.int64(2), T.int64(3)), "float32")):
+            T.func_attr({"tir.noalias": True})
+            for i0, i1 in T.grid(T.int64(2), T.int64(3)):
+                with T.block("compute"):
+                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
+                    T.reads(rxplaceholder[v_i0, v_i1])
+                    T.writes(compute[v_i0, v_i1])
+                    compute[v_i0, v_i1] = T.round(rxplaceholder[v_i0, v_i1])
+
+        @R.function
+        def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float32"):
+            gv = R.call_tir(tir_round, (x,), out_sinfo=R.Tensor((2, 3), dtype="float32"))
+            return gv
+    # fmt: on
+
+    mod = LegalizeOps()(Round)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_round_int():
+    # fmt: off
+    @tvm.script.ir_module
+    class Round:
+        @R.function
+        def main(x: R.Tensor((2, 3), "int32")) -> R.Tensor((2, 3), "int32"):
+            gv: R.Tensor((2, 3), "int32") = R.round(x)
+            return gv
+
+    @tvm.script.ir_module
+    class Expected:
+        @T.prim_func
+        def tir_round(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "int32"), compute: T.Buffer((T.int64(2), T.int64(3)), "int32")):
+            T.func_attr({"tir.noalias": True})
+            for i0, i1 in T.grid(T.int64(2), T.int64(3)):
+                with T.block("compute"):
+                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
+                    T.reads(rxplaceholder[v_i0, v_i1])
+                    T.writes(compute[v_i0, v_i1])
+                    compute[v_i0, v_i1] = rxplaceholder[v_i0, v_i1]
+
+        @R.function
+        def main(x: R.Tensor((2, 3), dtype="int32")) -> R.Tensor((2, 3), dtype="int32"):
+            gv = R.call_tir(tir_round, (x,), out_sinfo=R.Tensor((2, 3), dtype="int32"))
+            return gv
+    # fmt: on
+
+    mod = LegalizeOps()(Round)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_round_symbolic():
+    # fmt: off
+    @tvm.script.ir_module
+    class Round:
+        @R.function
+        def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"):
+            m = T.int64()
+            n = T.int64()
+            gv: R.Tensor((m, n), "float32") = R.round(x)
+            return gv
+
+    @tvm.script.ir_module
+    class Expected:
+        @T.prim_func
+        def tir_round(var_rxplaceholder: T.handle, var_compute: T.handle):
+            T.func_attr({"tir.noalias": True})
+            m = T.int64()
+            n = T.int64()
+            rxplaceholder = T.match_buffer(var_rxplaceholder, (m, n))
+            compute = T.match_buffer(var_compute, (m, n))
+            for i0, i1 in T.grid(m, n):
+                with T.block("compute"):
+                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
+                    T.reads(rxplaceholder[v_i0, v_i1])
+                    T.writes(compute[v_i0, v_i1])
+                    compute[v_i0, v_i1] = T.round(rxplaceholder[v_i0, v_i1])
+
+        @R.function
+        def main(x: R.Tensor(("m", "n"), dtype="float32")) -> R.Tensor(("m", "n"), dtype="float32"):
+            m = T.int64()
+            n = T.int64()
+            gv = R.call_tir(tir_round, (x,), out_sinfo=R.Tensor((m, n), dtype="float32"))
+            return gv
+    # fmt: on
+
+    mod = LegalizeOps()(Round)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
 def test_sigmoid():
     # fmt: off
     @tvm.script.ir_module
@@ -442,6 +746,107 @@ def test_sigmoid_symbolic():
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_sign():
+    # fmt: off
+    @tvm.script.ir_module
+    class Sign:
+        @R.function
+        def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"):
+            gv: R.Tensor((2, 3), "float32") = R.sign(x)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func
+        def tir_sign(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_sign: T.Buffer((T.int64(2), T.int64(3)), "float32")):
+            T.func_attr({"tir.noalias": True})
+            for ax0, ax1 in T.grid(T.int64(2), T.int64(3)):
+                with T.block("T_sign"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(rxplaceholder[v_ax0, v_ax1])
+                    T.writes(T_sign[v_ax0, v_ax1])
+                    T_sign[v_ax0, v_ax1] = T.Select(T.float32(0) < rxplaceholder[v_ax0, v_ax1], T.float32(1), T.Select(rxplaceholder[v_ax0, v_ax1] < T.float32(0), T.float32(-1), T.float32(0)))
+
+        @R.function
+        def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float32"):
+            gv = R.call_tir(tir_sign, (x,), out_sinfo=R.Tensor((2, 3), dtype="float32"))
+            return gv
+    # fmt: on
+
+    mod = LegalizeOps()(Sign)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_sign_int():
+    # fmt: off
+    @tvm.script.ir_module
+    class Sign:
+        @R.function
+        def main(x: R.Tensor((2, 3), "int32")) -> R.Tensor((2, 3), "int32"):
+            gv: R.Tensor((2, 3), "int32") = R.sign(x)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func
+        def tir_sign(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "int32"), T_sign: T.Buffer((T.int64(2), T.int64(3)), "int32")):
+            T.func_attr({"tir.noalias": True})
+            for ax0, ax1 in T.grid(T.int64(2), T.int64(3)):
+                with T.block("T_sign"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(rxplaceholder[v_ax0, v_ax1])
+                    T.writes(T_sign[v_ax0, v_ax1])
+                    T_sign[v_ax0, v_ax1] = T.Select(0 < rxplaceholder[v_ax0, v_ax1], 1, T.Select(rxplaceholder[v_ax0, v_ax1] < 0, -1, 0))
+
+        @R.function
+        def main(x: R.Tensor((2, 3), dtype="int32")) -> R.Tensor((2, 3), dtype="int32"):
+            gv = R.call_tir(tir_sign, (x,), out_sinfo=R.Tensor((2, 3), dtype="int32"))
+            return gv
+    # fmt: on
+
+    mod = LegalizeOps()(Sign)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_sign_symbolic():
+    # fmt: off
+    @tvm.script.ir_module
+    class Sign:
+        @R.function
+        def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"):
+            m = T.int64()
+            n = T.int64()
+            gv: R.Tensor((m, n), "float32") = R.sign(x)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func
+        def tir_sign(var_rxplaceholder: T.handle, var_T_sign: T.handle):
+            T.func_attr({"tir.noalias": True})
+            m = T.int64()
+            n = T.int64()
+            rxplaceholder = T.match_buffer(var_rxplaceholder, (m, n))
+            T_sign = T.match_buffer(var_T_sign, (m, n))
+            for ax0, ax1 in T.grid(m, n):
+                with T.block("T_sign"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(rxplaceholder[v_ax0, v_ax1])
+                    T.writes(T_sign[v_ax0, v_ax1])
+                    T_sign[v_ax0, v_ax1] = T.Select(T.float32(0) < rxplaceholder[v_ax0, v_ax1], T.float32(1), T.Select(rxplaceholder[v_ax0, v_ax1] < T.float32(0), T.float32(-1), T.float32(0)))
+
+        @R.function
+        def main(x: R.Tensor(("m", "n"), dtype="float32")) -> R.Tensor(("m", "n"), dtype="float32"):
+            m = T.int64()
+            n = T.int64()
+            gv = R.call_tir(tir_sign, (x,), out_sinfo=R.Tensor((m, n), dtype="float32"))
+            return gv
+    # fmt: on
+
+    mod = LegalizeOps()(Sign)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
 def test_sin():
     # fmt: off
     @tvm.script.ir_module

Reply via email to