This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new abb901f08c [Relax] Support left_shift and right_shift op (#17448)
abb901f08c is described below
commit abb901f08cdc646d69758eb32503dcab59a904e0
Author: Siyuan Feng <[email protected]>
AuthorDate: Mon Oct 7 22:56:54 2024 +0800
[Relax] Support left_shift and right_shift op (#17448)
Introduced the left_shift and right_shift ops in Relax, with ONNX frontend
support (new BitShift and BitwiseNot converters).
---
python/tvm/relax/frontend/onnx/onnx_frontend.py | 104 ++++++++++++++++++++--
python/tvm/relax/op/__init__.py | 2 +
python/tvm/relax/op/binary.py | 32 +++++++
python/tvm/relax/transform/legalize_ops/binary.py | 2 +
python/tvm/script/ir_builder/relax/ir.py | 4 +
src/relax/op/distributed/binary.cc | 2 +
src/relax/op/tensor/binary.cc | 2 +
src/relax/op/tensor/binary.h | 6 ++
tests/python/relax/test_frontend_onnx.py | 36 ++++++++
tests/python/relax/test_op_binary.py | 2 +
10 files changed, 184 insertions(+), 8 deletions(-)
diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py
b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 36a7823f86..aa156a025f 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -244,7 +244,8 @@ class BinaryBase(OnnxOpConverter):
relax_op: Callable = None
@classmethod
- def _impl_v1(cls, bb, inputs, attr, params):
+ def base_impl(cls, bb, inputs, attr, params):
+ """Base implementation for binary operations."""
if cls.numpy_op is None or cls.relax_op is None:
raise ValueError("Numpy and Relax operators must be defined for
BinaryBase.")
if all([isinstance(inp, relax.Constant) for inp in inputs]):
@@ -274,6 +275,10 @@ class Add(BinaryBase):
numpy_op = _np.add
relax_op = relax.op.add
+    @classmethod
+    def _impl_v1(cls, bb, inputs, attr, params):
+        """Opset-1 converter: delegate to the shared ``BinaryBase.base_impl``."""
+        return cls.base_impl(bb, inputs, attr=attr, params=params)
+
class Sub(BinaryBase):
"""Converts an onnx Sub node into an equivalent Relax expression."""
@@ -281,6 +286,10 @@ class Sub(BinaryBase):
numpy_op = _np.subtract
relax_op = relax.op.subtract
+    @classmethod
+    def _impl_v1(cls, bb, inputs, attr, params):
+        """Opset-1 converter: delegate to the shared ``BinaryBase.base_impl``."""
+        return cls.base_impl(bb, inputs, attr=attr, params=params)
+
class Mul(BinaryBase):
"""Converts an onnx Mul node into an equivalent Relax expression."""
@@ -288,6 +297,10 @@ class Mul(BinaryBase):
numpy_op = _np.multiply
relax_op = relax.op.multiply
+    @classmethod
+    def _impl_v1(cls, bb, inputs, attr, params):
+        """Opset-1 converter: delegate to the shared ``BinaryBase.base_impl``."""
+        return cls.base_impl(bb, inputs, attr=attr, params=params)
+
class Div(BinaryBase):
"""Converts an onnx Div node into an equivalent Relax expression."""
@@ -295,6 +308,10 @@ class Div(BinaryBase):
numpy_op = _np.divide
relax_op = relax.op.divide
+    @classmethod
+    def _impl_v1(cls, bb, inputs, attr, params):
+        """Opset-1 converter: delegate to the shared ``BinaryBase.base_impl``."""
+        return cls.base_impl(bb, inputs, attr=attr, params=params)
+
class Pow(BinaryBase):
"""Converts an onnx Pow node into an equivalent Relax expression."""
@@ -302,6 +319,10 @@ class Pow(BinaryBase):
numpy_op = _np.power
relax_op = relax.op.power
+    @classmethod
+    def _impl_v1(cls, bb, inputs, attr, params):
+        """Opset-1 converter: delegate to the shared ``BinaryBase.base_impl``."""
+        return cls.base_impl(bb, inputs, attr=attr, params=params)
+
class And(BinaryBase):
"""Converts an onnx And node into an equivalent Relax expression."""
@@ -309,6 +330,10 @@ class And(BinaryBase):
numpy_op = _np.logical_and
relax_op = relax.op.logical_and
+    @classmethod
+    def _impl_v1(cls, bb, inputs, attr, params):
+        """Opset-1 converter: delegate to the shared ``BinaryBase.base_impl``."""
+        return cls.base_impl(bb, inputs, attr=attr, params=params)
+
class Or(BinaryBase):
"""Converts an onnx Or node into an equivalent Relax expression."""
@@ -316,6 +341,10 @@ class Or(BinaryBase):
numpy_op = _np.logical_or
relax_op = relax.op.logical_or
+    @classmethod
+    def _impl_v1(cls, bb, inputs, attr, params):
+        """Opset-1 converter: delegate to the shared ``BinaryBase.base_impl``."""
+        return cls.base_impl(bb, inputs, attr=attr, params=params)
+
class Xor(BinaryBase):
"""Converts an onnx Xor node into an equivalent Relax expression."""
@@ -323,6 +352,10 @@ class Xor(BinaryBase):
numpy_op = _np.logical_xor
relax_op = relax.op.logical_xor
+    @classmethod
+    def _impl_v1(cls, bb, inputs, attr, params):
+        """Opset-1 converter: delegate to the shared ``BinaryBase.base_impl``."""
+        return cls.base_impl(bb, inputs, attr=attr, params=params)
+
class Less(BinaryBase):
"""Converts an onnx Less node into an equivalent Relax expression."""
@@ -330,6 +363,10 @@ class Less(BinaryBase):
numpy_op = _np.less
relax_op = relax.op.less
+    @classmethod
+    def _impl_v1(cls, bb, inputs, attr, params):
+        """Opset-1 converter: delegate to the shared ``BinaryBase.base_impl``."""
+        return cls.base_impl(bb, inputs, attr=attr, params=params)
+
class LessOrEqual(BinaryBase):
"""Converts an onnx LessEqual node into an equivalent Relax expression."""
@@ -337,6 +374,10 @@ class LessOrEqual(BinaryBase):
numpy_op = _np.less_equal
relax_op = relax.op.less_equal
+    @classmethod
+    def _impl_v1(cls, bb, inputs, attr, params):
+        """Opset-1 converter: delegate to the shared ``BinaryBase.base_impl``."""
+        return cls.base_impl(bb, inputs, attr=attr, params=params)
+
class Greater(BinaryBase):
"""Converts an onnx Greater node into an equivalent Relax expression."""
@@ -344,6 +385,10 @@ class Greater(BinaryBase):
numpy_op = _np.greater
relax_op = relax.op.greater
+    @classmethod
+    def _impl_v1(cls, bb, inputs, attr, params):
+        """Opset-1 converter: delegate to the shared ``BinaryBase.base_impl``."""
+        return cls.base_impl(bb, inputs, attr=attr, params=params)
+
class GreaterOrEqual(BinaryBase):
"""Converts an onnx GreaterEqual node into an equivalent Relax
expression."""
@@ -351,6 +396,10 @@ class GreaterOrEqual(BinaryBase):
numpy_op = _np.greater_equal
relax_op = relax.op.greater_equal
+    @classmethod
+    def _impl_v1(cls, bb, inputs, attr, params):
+        """Opset-1 converter: delegate to the shared ``BinaryBase.base_impl``."""
+        return cls.base_impl(bb, inputs, attr=attr, params=params)
+
class Equal(OnnxOpConverter):
"""Converts an onnx Equal node into an equivalent Relax expression."""
@@ -374,7 +423,8 @@ class BitwiseBase(BinaryBase):
"""Converts an onnx BitwiseBase node into an equivalent Relax
expression."""
@classmethod
- def base_impl(cls, bb, inputs, attr, params, py_func, relax_op):
+ def base_impl(cls, bb, inputs, attr, params):
+ """Base implementation for bitwise operations."""
valid_types = ["int8", "int16", "int32", "int64", "uint8", "uint16",
"uint32", "uint64"]
for num, inp in enumerate(inputs):
if inp.struct_info.dtype not in valid_types:
@@ -382,31 +432,69 @@ class BitwiseBase(BinaryBase):
f"Bitwise operations expect all inputs to have integer
types, "
f"got {inp.struct_info.dtype} for input {num}"
)
- return BinaryBase.base_impl(bb, inputs, attr, params, py_func,
relax_op)
+ return super().base_impl(bb, inputs, attr, params)
class BitwiseAnd(BitwiseBase):
"""Converts an onnx BitwiseAnd node into an equivalent Relax expression."""
+    # Operator pair read by BinaryBase.base_impl (reached via BitwiseBase.base_impl).
+ numpy_op = _np.bitwise_and
+ relax_op = relax.op.bitwise_and
+
@classmethod
def _impl_v18(cls, bb, inputs, attr, params):
- return cls.base_impl(bb, inputs, attr, params, lambda x, y: x & y,
relax.op.bitwise_and)
+ return cls.base_impl(bb, inputs, attr, params)
class BitwiseOr(BitwiseBase):
"""Converts an onnx BitwiseOr node into an equivalent Relax expression."""
+    # Operator pair read by BinaryBase.base_impl (reached via BitwiseBase.base_impl).
+ numpy_op = _np.bitwise_or
+ relax_op = relax.op.bitwise_or
+
@classmethod
def _impl_v18(cls, bb, inputs, attr, params):
- return cls.base_impl(bb, inputs, attr, params, lambda x, y: x | y,
relax.op.bitwise_or)
+ return cls.base_impl(bb, inputs, attr, params)
class BitwiseXor(BitwiseBase):
"""Converts an onnx BitwiseXor node into an equivalent Relax expression."""
+    # Operator pair read by BinaryBase.base_impl (reached via BitwiseBase.base_impl).
+ numpy_op = _np.bitwise_xor
+ relax_op = relax.op.bitwise_xor
+
@classmethod
def _impl_v18(cls, bb, inputs, attr, params):
- return cls.base_impl(bb, inputs, attr, params, lambda x, y: x ^ y,
relax.op.bitwise_xor)
+ return cls.base_impl(bb, inputs, attr, params)
+
+
+class BitwiseNot(BitwiseBase):
+    """Converts an onnx BitwiseNot node into an equivalent Relax expression.
+
+    BitwiseNot is unary, so it is not routed through the inherited
+    ``base_impl``: that path ends in ``BinaryBase.base_impl``, the shared
+    implementation for *binary* operations. Instead, the integer-dtype check
+    is applied here and the single input is converted directly.
+    """
+
+    numpy_op = _np.bitwise_not
+    relax_op = relax.op.bitwise_not
+
+    @classmethod
+    def _impl_v18(cls, bb, inputs, attr, params):
+        data = inputs[0]
+        dtype = data.struct_info.dtype
+        # Same integer-only restriction that BitwiseBase.base_impl enforces.
+        valid_types = ["int8", "int16", "int32", "int64", "uint8", "uint16", "uint32", "uint64"]
+        if dtype not in valid_types:
+            raise ValueError(
+                f"Bitwise operations expect all inputs to have integer types, "
+                f"got {dtype} for input 0"
+            )
+        if isinstance(data, relax.Constant):
+            # Constant input: evaluate with NumPy instead of emitting a Relax op.
+            return relax.const(_np.bitwise_not(data.data.numpy()), dtype)
+        return relax.op.bitwise_not(data)
+
+
+class BitShift(BitwiseBase):
+    """Converts an onnx BitShift node into an equivalent Relax expression.
+
+    The node's ``direction`` attribute ("LEFT" or "RIGHT") selects between
+    ``relax.op.left_shift`` and ``relax.op.right_shift``. Each direction is
+    bound to its own private converter class, so converting a node never
+    rewrites ``numpy_op``/``relax_op`` on ``BitShift`` itself; otherwise the
+    last-seen direction would persist as shared class state.
+    """
+
+    class _LeftShift(BitwiseBase):
+        """BitwiseBase converter bound to the left-shift operators."""
+
+        numpy_op = _np.left_shift
+        relax_op = relax.op.left_shift
+
+    class _RightShift(BitwiseBase):
+        """BitwiseBase converter bound to the right-shift operators."""
+
+        numpy_op = _np.right_shift
+        relax_op = relax.op.right_shift
+
+    _converters = {"LEFT": _LeftShift, "RIGHT": _RightShift}
+
+    @classmethod
+    def _impl_v11(cls, bb, inputs, attr, params):
+        # String attributes are expected as bytes here (hence the decode). The
+        # fallback must be bytes too: a str default has no ``.decode`` in Python 3.
+        direction = attr.get("direction", b"LEFT")
+        if isinstance(direction, bytes):
+            direction = direction.decode("ascii")
+        converter = cls._converters.get(direction)
+        if converter is None:
+            raise ValueError(f"Unsupported Shift Direction: {direction}")
+        # Runs BitwiseBase's dtype validation, then BinaryBase.base_impl,
+        # using the operator pair of the selected direction.
+        return converter.base_impl(bb, inputs, attr, params)
class Sigmoid(OnnxOpConverter):
@@ -2654,8 +2742,8 @@ def _get_convert_map():
"BitwiseAnd": BitwiseAnd,
"BitwiseOr": BitwiseOr,
"BitwiseXor": BitwiseXor,
- # "BitwiseNot": BitwiseNot,
- # "BitwiseShift": BitwiseShift,
+ "BitwiseNot": BitwiseNot,
+ "BitShift": BitShift,
"And": And,
"Or": Or,
"Xor": Xor,
diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py
index 4581defa1a..c99201e969 100644
--- a/python/tvm/relax/op/__init__.py
+++ b/python/tvm/relax/op/__init__.py
@@ -52,6 +52,7 @@ from .binary import (
floor_divide,
greater,
greater_equal,
+ left_shift,
less,
less_equal,
logical_and,
@@ -62,6 +63,7 @@ from .binary import (
multiply,
not_equal,
power,
+ right_shift,
subtract,
)
from .create import (
diff --git a/python/tvm/relax/op/binary.py b/python/tvm/relax/op/binary.py
index 982b3a24f2..7632235cb3 100644
--- a/python/tvm/relax/op/binary.py
+++ b/python/tvm/relax/op/binary.py
@@ -386,3 +386,35 @@ def bitwise_xor(x1: Expr, x2: Expr) -> Expr:
The computed result.
"""
return _ffi_api.bitwise_xor(x1, x2)
+
+
+def left_shift(x1: Expr, x2: Expr) -> Expr:
+    """Broadcasted element-wise bitwise left shift.
+
+    Shifts each element of ``x1`` left by the corresponding element of
+    ``x2``; ``x1`` and ``x2`` are broadcast against each other.
+
+    Parameters
+    ----------
+    x1 : relax.Expr
+        The input tensor to be shifted.
+
+    x2 : relax.Expr
+        The number of bit positions to shift each element of ``x1`` by.
+
+    Returns
+    -------
+    result : relax.Expr
+        The computed result.
+    """
+    return _ffi_api.left_shift(x1, x2)
+
+
+def right_shift(x1: Expr, x2: Expr) -> Expr:
+    """Broadcasted element-wise bitwise right shift.
+
+    Shifts each element of ``x1`` right by the corresponding element of
+    ``x2``; ``x1`` and ``x2`` are broadcast against each other.
+    NOTE(review): whether signed inputs shift arithmetically or logically is
+    decided by the lowering (``topi.right_shift``), not here — confirm there.
+
+    Parameters
+    ----------
+    x1 : relax.Expr
+        The input tensor to be shifted.
+
+    x2 : relax.Expr
+        The number of bit positions to shift each element of ``x1`` by.
+
+    Returns
+    -------
+    result : relax.Expr
+        The computed result.
+    """
+    return _ffi_api.right_shift(x1, x2)
diff --git a/python/tvm/relax/transform/legalize_ops/binary.py
b/python/tvm/relax/transform/legalize_ops/binary.py
index 16d6c02696..d28e100edb 100644
--- a/python/tvm/relax/transform/legalize_ops/binary.py
+++ b/python/tvm/relax/transform/legalize_ops/binary.py
@@ -62,6 +62,8 @@ register_legalize("relax.minimum", _binary(topi.minimum))
register_legalize("relax.bitwise_and", _binary(topi.bitwise_and))
register_legalize("relax.bitwise_or", _binary(topi.bitwise_or))
register_legalize("relax.bitwise_xor", _binary(topi.bitwise_xor))
+# bitwise shifts
+register_legalize("relax.left_shift", _binary(topi.left_shift))
+register_legalize("relax.right_shift", _binary(topi.right_shift))
# logical
register_legalize("relax.logical_and", _binary(topi.logical_and))
diff --git a/python/tvm/script/ir_builder/relax/ir.py
b/python/tvm/script/ir_builder/relax/ir.py
index c4be8afac4..e6ff35ebe5 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -102,6 +102,7 @@ from tvm.relax.op import (
isinf,
isnan,
layout_transform,
+ left_shift,
less,
less_equal,
linear,
@@ -133,6 +134,7 @@ from tvm.relax.op import (
quantize,
repeat,
reshape,
+ right_shift,
round,
rsqrt,
scatter_elements,
@@ -773,6 +775,7 @@ __all__ = [
"isinf",
"isnan",
"layout_transform",
+ "left_shift",
"less",
"less_equal",
"linear",
@@ -809,6 +812,7 @@ __all__ = [
"repeat",
"reshape",
"rewriter",
+ "right_shift",
"tensor_to_shape",
"shape_to_tensor",
"rocm",
diff --git a/src/relax/op/distributed/binary.cc
b/src/relax/op/distributed/binary.cc
index 63f4f356c0..6ad71e0f85 100644
--- a/src/relax/op/distributed/binary.cc
+++ b/src/relax/op/distributed/binary.cc
@@ -68,6 +68,8 @@
RELAX_REGISTER_BINARY_BROADCAST_DIST_INFER_STRUCT_INFO(logical_xor);
RELAX_REGISTER_BINARY_BROADCAST_DIST_INFER_STRUCT_INFO(bitwise_and);
RELAX_REGISTER_BINARY_BROADCAST_DIST_INFER_STRUCT_INFO(bitwise_or);
RELAX_REGISTER_BINARY_BROADCAST_DIST_INFER_STRUCT_INFO(bitwise_xor);
+// Bitwise shifts use the same broadcasting binary distributed struct-info inference.
+RELAX_REGISTER_BINARY_BROADCAST_DIST_INFER_STRUCT_INFO(left_shift);
+RELAX_REGISTER_BINARY_BROADCAST_DIST_INFER_STRUCT_INFO(right_shift);
} // namespace distributed
} // namespace relax
diff --git a/src/relax/op/tensor/binary.cc b/src/relax/op/tensor/binary.cc
index afc0fb7303..f1dc3d4904 100644
--- a/src/relax/op/tensor/binary.cc
+++ b/src/relax/op/tensor/binary.cc
@@ -207,6 +207,8 @@ RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(logical_xor);
RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(bitwise_and);
RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(bitwise_or);
RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(bitwise_xor);
+// Bitwise shifts (x1 shifted by x2 bits), registered as broadcasting binary ops.
+RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(left_shift);
+RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(right_shift);
} // namespace relax
} // namespace tvm
diff --git a/src/relax/op/tensor/binary.h b/src/relax/op/tensor/binary.h
index b28a6c3369..003bcb7e27 100644
--- a/src/relax/op/tensor/binary.h
+++ b/src/relax/op/tensor/binary.h
@@ -129,6 +129,12 @@ Expr bitwise_or(Expr x1, Expr x2);
/*! \brief Broadcasted element-wise bitwise xor */
Expr bitwise_xor(Expr x1, Expr x2);
+/*! \brief Broadcasted element-wise bitwise left shift of x1 by x2 bit positions */
+Expr left_shift(Expr x1, Expr x2);
+
+/*! \brief Broadcasted element-wise bitwise right shift of x1 by x2 bit positions */
+Expr right_shift(Expr x1, Expr x2);
+
} // namespace relax
} // namespace tvm
diff --git a/tests/python/relax/test_frontend_onnx.py
b/tests/python/relax/test_frontend_onnx.py
index f2bbd3f3f5..e3ed3a3a9d 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -358,6 +358,42 @@ def test_binary_bool(op_name: str):
verify_binary(op_name, [32, 32], [32, 32], [32, 32],
dtype=TensorProto.BOOL)
+@pytest.mark.skip(reason="opset 18 is not supported in CI")
+@pytest.mark.parametrize("op_name", ["BitwiseAnd", "BitwiseOr", "BitwiseXor"])
+def test_bitwise(op_name: str):
+    """Check the opset-18 BitwiseAnd/Or/Xor converters on uint64 inputs."""
+    verify_binary(op_name, [32, 32], [32, 32], [32, 32], dtype=TensorProto.UINT64, opset=18)
+
+
+@pytest.mark.skip(reason="opset 18 is not supported in CI")
+def test_bitwise_not():
+    """Check the opset-18 BitwiseNot converter on a uint64 input."""
+    # NOTE: skipped in CI, so the BitwiseNot converter is not exercised here.
+    verify_unary(
+        "BitwiseNot",
+        [32, 32],
+        input_dtype=TensorProto.UINT64,
+        output_dtype=TensorProto.UINT64,
+        opset=18,
+    )
+
+
+@pytest.mark.parametrize("direction", ["LEFT", "RIGHT"])
+def test_bitwise_shift(direction: str):
+    """Check the BitShift converter for both directions on uint64 inputs."""
+    shape = [32, 32]
+    dtype = TensorProto.UINT64
+    test_node = helper.make_node("BitShift", ["a", "b"], ["c"], direction=direction)
+    graph = helper.make_graph(
+        [test_node],
+        "binary_test",
+        inputs=[
+            helper.make_tensor_value_info("a", dtype, shape),
+            helper.make_tensor_value_info("b", dtype, shape),
+        ],
+        outputs=[helper.make_tensor_value_info("c", dtype, shape)],
+    )
+
+    model = helper.make_model(graph, producer_name="binary_test")
+    # Only the shift amounts "b" are supplied explicitly, kept small (0-7).
+    check_correctness(model, inputs={"b": np.random.randint(0, 8, shape).astype("uint64")})
+
+
@pytest.mark.parametrize(
"op_name",
[
diff --git a/tests/python/relax/test_op_binary.py
b/tests/python/relax/test_op_binary.py
index 85842f1578..20c111495d 100644
--- a/tests/python/relax/test_op_binary.py
+++ b/tests/python/relax/test_op_binary.py
@@ -46,6 +46,8 @@ def test_op_correctness():
assert relax.op.bitwise_and(x, y).op == Op.get("relax.bitwise_and")
assert relax.op.bitwise_or(x, y).op == Op.get("relax.bitwise_or")
assert relax.op.bitwise_xor(x, y).op == Op.get("relax.bitwise_xor")
+ assert relax.op.left_shift(x, y).op == Op.get("relax.left_shift")
+ assert relax.op.right_shift(x, y).op == Op.get("relax.right_shift")
x = relax.Var("x", R.Tensor((2, 3), "bool"))
y = relax.Var("y", R.Tensor((2, 3), "bool"))