This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 0c6495902d [Unity][Op] Legalize `round`, `floor`, `ceil`, `sign` (#14198)
0c6495902d is described below
commit 0c6495902dc8a07888f1062953e607ca5c0a9d6d
Author: Ruihang Lai <[email protected]>
AuthorDate: Sat Mar 4 21:09:18 2023 -0500
[Unity][Op] Legalize `round`, `floor`, `ceil`, `sign` (#14198)
This PR implements the legalization for four unary operators:
* round,
* floor,
* ceil,
* sign.
Unit tests are provided accordingly.
---
python/tvm/relax/transform/legalize_ops/unary.py | 4 +
.../relax/test_transform_legalize_ops_unary.py | 405 +++++++++++++++++++++
2 files changed, 409 insertions(+)
diff --git a/python/tvm/relax/transform/legalize_ops/unary.py
b/python/tvm/relax/transform/legalize_ops/unary.py
index cd29182c4d..ca84cbf0ad 100644
--- a/python/tvm/relax/transform/legalize_ops/unary.py
+++ b/python/tvm/relax/transform/legalize_ops/unary.py
@@ -21,11 +21,15 @@ from .common import _call_topi_without_attr, register_legalize
# To avoid conflict of IRModule function name and libc function name, we add
# "tir_" as the prefix of the generated PrimFunc name.
register_legalize("relax.abs", _call_topi_without_attr(topi.abs, "tir_abs"))
+register_legalize("relax.ceil", _call_topi_without_attr(topi.ceil, "tir_ceil"))
register_legalize("relax.cos", _call_topi_without_attr(topi.cos, "tir_cos"))
register_legalize("relax.log", _call_topi_without_attr(topi.log, "tir_log"))
register_legalize("relax.exp", _call_topi_without_attr(topi.exp, "tir_exp"))
+register_legalize("relax.floor", _call_topi_without_attr(topi.floor, "tir_floor"))
register_legalize("relax.negative", _call_topi_without_attr(topi.negative, "tir_negative"))
+register_legalize("relax.round", _call_topi_without_attr(topi.round, "tir_round"))
register_legalize("relax.sigmoid", _call_topi_without_attr(topi.sigmoid, "tir_sigmoid"))
+register_legalize("relax.sign", _call_topi_without_attr(topi.sign, "tir_sign"))
register_legalize("relax.sin", _call_topi_without_attr(topi.sin, "tir_sin"))
register_legalize("relax.sqrt", _call_topi_without_attr(topi.sqrt, "tir_sqrt"))
register_legalize("relax.tanh", _call_topi_without_attr(topi.tanh, "tir_tanh"))
diff --git a/tests/python/relax/test_transform_legalize_ops_unary.py
b/tests/python/relax/test_transform_legalize_ops_unary.py
index 7250e711be..97904e2a5f 100644
--- a/tests/python/relax/test_transform_legalize_ops_unary.py
+++ b/tests/python/relax/test_transform_legalize_ops_unary.py
@@ -18,6 +18,7 @@
import tvm
import tvm.testing
from tvm.relax.transform import LegalizeOps
+from tvm.script import ir as I
from tvm.script import relax as R
from tvm.script import tir as T
@@ -92,6 +93,107 @@ def test_abs_symbolic():
tvm.ir.assert_structural_equal(mod, Expected)
+def test_ceil():
+ # fmt: off
+ @tvm.script.ir_module
+ class Ceil:
+ @R.function
+ def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3),
"float32"):
+ gv: R.Tensor((2, 3), "float32") = R.ceil(x)
+ return gv
+
+ @I.ir_module
+ class Expected:
+ @T.prim_func
+ def tir_ceil(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)),
"float32"), compute: T.Buffer((T.int64(2), T.int64(3)), "float32")):
+ T.func_attr({"tir.noalias": True})
+ for i0, i1 in T.grid(T.int64(2), T.int64(3)):
+ with T.block("compute"):
+ v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
+ T.reads(rxplaceholder[v_i0, v_i1])
+ T.writes(compute[v_i0, v_i1])
+ compute[v_i0, v_i1] = T.ceil(rxplaceholder[v_i0, v_i1])
+
+ @R.function
+ def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3),
dtype="float32"):
+ gv = R.call_tir(tir_ceil, (x,), out_sinfo=R.Tensor((2, 3),
dtype="float32"))
+ return gv
+ # fmt: on
+
+ mod = LegalizeOps()(Ceil)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_ceil_int():
+ # fmt: off
+ @tvm.script.ir_module
+ class Ceil:
+ @R.function
+ def main(x: R.Tensor((2, 3), "int32")) -> R.Tensor((2, 3), "int32"):
+ gv: R.Tensor((2, 3), "int32") = R.ceil(x)
+ return gv
+
+ @I.ir_module
+ class Expected:
+ @T.prim_func
+ def tir_ceil(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)),
"int32"), compute: T.Buffer((T.int64(2), T.int64(3)), "int32")):
+ T.func_attr({"tir.noalias": True})
+ for i0, i1 in T.grid(T.int64(2), T.int64(3)):
+ with T.block("compute"):
+ v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
+ T.reads(rxplaceholder[v_i0, v_i1])
+ T.writes(compute[v_i0, v_i1])
+ compute[v_i0, v_i1] = rxplaceholder[v_i0, v_i1]
+
+ @R.function
+ def main(x: R.Tensor((2, 3), dtype="int32")) -> R.Tensor((2, 3),
dtype="int32"):
+ gv = R.call_tir(tir_ceil, (x,), out_sinfo=R.Tensor((2, 3),
dtype="int32"))
+ return gv
+ # fmt: on
+
+ mod = LegalizeOps()(Ceil)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_ceil_symbolic():
+ # fmt: off
+ @tvm.script.ir_module
+ class Ceil:
+ @R.function
+ def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"),
"float32"):
+ m = T.int64()
+ n = T.int64()
+ gv: R.Tensor((m, n), "float32") = R.ceil(x)
+ return gv
+
+ @I.ir_module
+ class Expected:
+ @T.prim_func
+ def tir_ceil(var_rxplaceholder: T.handle, var_compute: T.handle):
+ T.func_attr({"tir.noalias": True})
+ m = T.int64()
+ n = T.int64()
+ rxplaceholder = T.match_buffer(var_rxplaceholder, (m, n))
+ compute = T.match_buffer(var_compute, (m, n))
+ for i0, i1 in T.grid(m, n):
+ with T.block("compute"):
+ v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
+ T.reads(rxplaceholder[v_i0, v_i1])
+ T.writes(compute[v_i0, v_i1])
+ compute[v_i0, v_i1] = T.ceil(rxplaceholder[v_i0, v_i1])
+
+ @R.function
+ def main(x: R.Tensor(("m", "n"), dtype="float32")) -> R.Tensor(("m",
"n"), dtype="float32"):
+ m = T.int64()
+ n = T.int64()
+ gv = R.call_tir(tir_ceil, (x,), out_sinfo=R.Tensor((m, n),
dtype="float32"))
+ return gv
+ # fmt: on
+
+ mod = LegalizeOps()(Ceil)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
def test_cos():
# fmt: off
@tvm.script.ir_module
@@ -232,6 +334,107 @@ def test_exp_symbolic():
tvm.ir.assert_structural_equal(mod, Expected)
+def test_floor():
+ # fmt: off
+ @tvm.script.ir_module
+ class Floor:
+ @R.function
+ def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3),
"float32"):
+ gv: R.Tensor((2, 3), "float32") = R.floor(x)
+ return gv
+
+ @I.ir_module
+ class Expected:
+ @T.prim_func
+ def tir_floor(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)),
"float32"), compute: T.Buffer((T.int64(2), T.int64(3)), "float32")):
+ T.func_attr({"tir.noalias": True})
+ for i0, i1 in T.grid(T.int64(2), T.int64(3)):
+ with T.block("compute"):
+ v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
+ T.reads(rxplaceholder[v_i0, v_i1])
+ T.writes(compute[v_i0, v_i1])
+ compute[v_i0, v_i1] = T.floor(rxplaceholder[v_i0, v_i1])
+
+ @R.function
+ def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3),
dtype="float32"):
+ gv = R.call_tir(tir_floor, (x,), out_sinfo=R.Tensor((2, 3),
dtype="float32"))
+ return gv
+ # fmt: on
+
+ mod = LegalizeOps()(Floor)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_floor_int():
+ # fmt: off
+ @tvm.script.ir_module
+ class Floor:
+ @R.function
+ def main(x: R.Tensor((2, 3), "int32")) -> R.Tensor((2, 3), "int32"):
+ gv: R.Tensor((2, 3), "int32") = R.floor(x)
+ return gv
+
+ @I.ir_module
+ class Expected:
+ @T.prim_func
+ def tir_floor(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)),
"int32"), compute: T.Buffer((T.int64(2), T.int64(3)), "int32")):
+ T.func_attr({"tir.noalias": True})
+ for i0, i1 in T.grid(T.int64(2), T.int64(3)):
+ with T.block("compute"):
+ v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
+ T.reads(rxplaceholder[v_i0, v_i1])
+ T.writes(compute[v_i0, v_i1])
+ compute[v_i0, v_i1] = rxplaceholder[v_i0, v_i1]
+
+ @R.function
+ def main(x: R.Tensor((2, 3), dtype="int32")) -> R.Tensor((2, 3),
dtype="int32"):
+ gv = R.call_tir(tir_floor, (x,), out_sinfo=R.Tensor((2, 3),
dtype="int32"))
+ return gv
+ # fmt: on
+
+ mod = LegalizeOps()(Floor)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_floor_symbolic():
+ # fmt: off
+ @tvm.script.ir_module
+ class Floor:
+ @R.function
+ def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"),
"float32"):
+ m = T.int64()
+ n = T.int64()
+ gv: R.Tensor((m, n), "float32") = R.floor(x)
+ return gv
+
+ @I.ir_module
+ class Expected:
+ @T.prim_func
+ def tir_floor(var_rxplaceholder: T.handle, var_compute: T.handle):
+ T.func_attr({"tir.noalias": True})
+ m = T.int64()
+ n = T.int64()
+ rxplaceholder = T.match_buffer(var_rxplaceholder, (m, n))
+ compute = T.match_buffer(var_compute, (m, n))
+ for i0, i1 in T.grid(m, n):
+ with T.block("compute"):
+ v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
+ T.reads(rxplaceholder[v_i0, v_i1])
+ T.writes(compute[v_i0, v_i1])
+ compute[v_i0, v_i1] = T.floor(rxplaceholder[v_i0, v_i1])
+
+ @R.function
+ def main(x: R.Tensor(("m", "n"), dtype="float32")) -> R.Tensor(("m",
"n"), dtype="float32"):
+ m = T.int64()
+ n = T.int64()
+ gv = R.call_tir(tir_floor, (x,), out_sinfo=R.Tensor((m, n),
dtype="float32"))
+ return gv
+ # fmt: on
+
+ mod = LegalizeOps()(Floor)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
def test_log():
# fmt: off
@tvm.script.ir_module
@@ -372,6 +575,107 @@ def test_negative_symbolic():
tvm.ir.assert_structural_equal(mod, Expected)
+def test_round():
+ # fmt: off
+ @tvm.script.ir_module
+ class Round:
+ @R.function
+ def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3),
"float32"):
+ gv: R.Tensor((2, 3), "float32") = R.round(x)
+ return gv
+
+ @tvm.script.ir_module
+ class Expected:
+ @T.prim_func
+ def tir_round(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)),
"float32"), compute: T.Buffer((T.int64(2), T.int64(3)), "float32")):
+ T.func_attr({"tir.noalias": True})
+ for i0, i1 in T.grid(T.int64(2), T.int64(3)):
+ with T.block("compute"):
+ v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
+ T.reads(rxplaceholder[v_i0, v_i1])
+ T.writes(compute[v_i0, v_i1])
+ compute[v_i0, v_i1] = T.round(rxplaceholder[v_i0, v_i1])
+
+ @R.function
+ def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3),
dtype="float32"):
+ gv = R.call_tir(tir_round, (x,), out_sinfo=R.Tensor((2, 3),
dtype="float32"))
+ return gv
+ # fmt: on
+
+ mod = LegalizeOps()(Round)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_round_int():
+ # fmt: off
+ @tvm.script.ir_module
+ class Round:
+ @R.function
+ def main(x: R.Tensor((2, 3), "int32")) -> R.Tensor((2, 3), "int32"):
+ gv: R.Tensor((2, 3), "int32") = R.round(x)
+ return gv
+
+ @tvm.script.ir_module
+ class Expected:
+ @T.prim_func
+ def tir_round(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)),
"int32"), compute: T.Buffer((T.int64(2), T.int64(3)), "int32")):
+ T.func_attr({"tir.noalias": True})
+ for i0, i1 in T.grid(T.int64(2), T.int64(3)):
+ with T.block("compute"):
+ v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
+ T.reads(rxplaceholder[v_i0, v_i1])
+ T.writes(compute[v_i0, v_i1])
+ compute[v_i0, v_i1] = rxplaceholder[v_i0, v_i1]
+
+ @R.function
+ def main(x: R.Tensor((2, 3), dtype="int32")) -> R.Tensor((2, 3),
dtype="int32"):
+ gv = R.call_tir(tir_round, (x,), out_sinfo=R.Tensor((2, 3),
dtype="int32"))
+ return gv
+ # fmt: on
+
+ mod = LegalizeOps()(Round)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_round_symbolic():
+ # fmt: off
+ @tvm.script.ir_module
+ class Round:
+ @R.function
+ def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"),
"float32"):
+ m = T.int64()
+ n = T.int64()
+ gv: R.Tensor((m, n), "float32") = R.round(x)
+ return gv
+
+ @tvm.script.ir_module
+ class Expected:
+ @T.prim_func
+ def tir_round(var_rxplaceholder: T.handle, var_compute: T.handle):
+ T.func_attr({"tir.noalias": True})
+ m = T.int64()
+ n = T.int64()
+ rxplaceholder = T.match_buffer(var_rxplaceholder, (m, n))
+ compute = T.match_buffer(var_compute, (m, n))
+ for i0, i1 in T.grid(m, n):
+ with T.block("compute"):
+ v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
+ T.reads(rxplaceholder[v_i0, v_i1])
+ T.writes(compute[v_i0, v_i1])
+ compute[v_i0, v_i1] = T.round(rxplaceholder[v_i0, v_i1])
+
+ @R.function
+ def main(x: R.Tensor(("m", "n"), dtype="float32")) -> R.Tensor(("m",
"n"), dtype="float32"):
+ m = T.int64()
+ n = T.int64()
+ gv = R.call_tir(tir_round, (x,), out_sinfo=R.Tensor((m, n),
dtype="float32"))
+ return gv
+ # fmt: on
+
+ mod = LegalizeOps()(Round)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
def test_sigmoid():
# fmt: off
@tvm.script.ir_module
@@ -442,6 +746,107 @@ def test_sigmoid_symbolic():
tvm.ir.assert_structural_equal(mod, Expected)
+def test_sign():
+ # fmt: off
+ @tvm.script.ir_module
+ class Sign:
+ @R.function
+ def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3),
"float32"):
+ gv: R.Tensor((2, 3), "float32") = R.sign(x)
+ return gv
+
+ @I.ir_module
+ class Expected:
+ @T.prim_func
+ def tir_sign(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)),
"float32"), T_sign: T.Buffer((T.int64(2), T.int64(3)), "float32")):
+ T.func_attr({"tir.noalias": True})
+ for ax0, ax1 in T.grid(T.int64(2), T.int64(3)):
+ with T.block("T_sign"):
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+ T.reads(rxplaceholder[v_ax0, v_ax1])
+ T.writes(T_sign[v_ax0, v_ax1])
+ T_sign[v_ax0, v_ax1] = T.Select(T.float32(0) <
rxplaceholder[v_ax0, v_ax1], T.float32(1), T.Select(rxplaceholder[v_ax0, v_ax1]
< T.float32(0), T.float32(-1), T.float32(0)))
+
+ @R.function
+ def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3),
dtype="float32"):
+ gv = R.call_tir(tir_sign, (x,), out_sinfo=R.Tensor((2, 3),
dtype="float32"))
+ return gv
+ # fmt: on
+
+ mod = LegalizeOps()(Sign)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_sign_int():
+ # fmt: off
+ @tvm.script.ir_module
+ class Sign:
+ @R.function
+ def main(x: R.Tensor((2, 3), "int32")) -> R.Tensor((2, 3), "int32"):
+ gv: R.Tensor((2, 3), "int32") = R.sign(x)
+ return gv
+
+ @I.ir_module
+ class Expected:
+ @T.prim_func
+ def tir_sign(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)),
"int32"), T_sign: T.Buffer((T.int64(2), T.int64(3)), "int32")):
+ T.func_attr({"tir.noalias": True})
+ for ax0, ax1 in T.grid(T.int64(2), T.int64(3)):
+ with T.block("T_sign"):
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+ T.reads(rxplaceholder[v_ax0, v_ax1])
+ T.writes(T_sign[v_ax0, v_ax1])
+ T_sign[v_ax0, v_ax1] = T.Select(0 < rxplaceholder[v_ax0,
v_ax1], 1, T.Select(rxplaceholder[v_ax0, v_ax1] < 0, -1, 0))
+
+ @R.function
+ def main(x: R.Tensor((2, 3), dtype="int32")) -> R.Tensor((2, 3),
dtype="int32"):
+ gv = R.call_tir(tir_sign, (x,), out_sinfo=R.Tensor((2, 3),
dtype="int32"))
+ return gv
+ # fmt: on
+
+ mod = LegalizeOps()(Sign)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_sign_symbolic():
+ # fmt: off
+ @tvm.script.ir_module
+ class Sign:
+ @R.function
+ def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"),
"float32"):
+ m = T.int64()
+ n = T.int64()
+ gv: R.Tensor((m, n), "float32") = R.sign(x)
+ return gv
+
+ @I.ir_module
+ class Expected:
+ @T.prim_func
+ def tir_sign(var_rxplaceholder: T.handle, var_T_sign: T.handle):
+ T.func_attr({"tir.noalias": True})
+ m = T.int64()
+ n = T.int64()
+ rxplaceholder = T.match_buffer(var_rxplaceholder, (m, n))
+ T_sign = T.match_buffer(var_T_sign, (m, n))
+ for ax0, ax1 in T.grid(m, n):
+ with T.block("T_sign"):
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+ T.reads(rxplaceholder[v_ax0, v_ax1])
+ T.writes(T_sign[v_ax0, v_ax1])
+ T_sign[v_ax0, v_ax1] = T.Select(T.float32(0) <
rxplaceholder[v_ax0, v_ax1], T.float32(1), T.Select(rxplaceholder[v_ax0, v_ax1]
< T.float32(0), T.float32(-1), T.float32(0)))
+
+ @R.function
+ def main(x: R.Tensor(("m", "n"), dtype="float32")) -> R.Tensor(("m",
"n"), dtype="float32"):
+ m = T.int64()
+ n = T.int64()
+ gv = R.call_tir(tir_sign, (x,), out_sinfo=R.Tensor((m, n),
dtype="float32"))
+ return gv
+ # fmt: on
+
+ mod = LegalizeOps()(Sign)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
def test_sin():
# fmt: off
@tvm.script.ir_module