This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new abb901f08c [Relax] Support left_shift and right_shift op (#17448)
abb901f08c is described below

commit abb901f08cdc646d69758eb32503dcab59a904e0
Author: Siyuan Feng <[email protected]>
AuthorDate: Mon Oct 7 22:56:54 2024 +0800

    [Relax] Support left_shift and right_shift op (#17448)
    
    Introduced left_shift and right_shift op in Relax with ONNX frontend
    support.
---
 python/tvm/relax/frontend/onnx/onnx_frontend.py   | 104 ++++++++++++++++++++--
 python/tvm/relax/op/__init__.py                   |   2 +
 python/tvm/relax/op/binary.py                     |  32 +++++++
 python/tvm/relax/transform/legalize_ops/binary.py |   2 +
 python/tvm/script/ir_builder/relax/ir.py          |   4 +
 src/relax/op/distributed/binary.cc                |   2 +
 src/relax/op/tensor/binary.cc                     |   2 +
 src/relax/op/tensor/binary.h                      |   6 ++
 tests/python/relax/test_frontend_onnx.py          |  36 ++++++++
 tests/python/relax/test_op_binary.py              |   2 +
 10 files changed, 184 insertions(+), 8 deletions(-)

diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 36a7823f86..aa156a025f 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -244,7 +244,8 @@ class BinaryBase(OnnxOpConverter):
     relax_op: Callable = None
 
     @classmethod
-    def _impl_v1(cls, bb, inputs, attr, params):
+    def base_impl(cls, bb, inputs, attr, params):
+        """Base implementation for binary operations."""
         if cls.numpy_op is None or cls.relax_op is None:
             raise ValueError("Numpy and Relax operators must be defined for BitwiseBase.")
         if all([isinstance(inp, relax.Constant) for inp in inputs]):
@@ -274,6 +275,10 @@ class Add(BinaryBase):
     numpy_op = _np.add
     relax_op = relax.op.add
 
+    @classmethod
+    def _impl_v1(cls, bb, inputs, attr, params):
+        return cls.base_impl(bb, inputs, attr, params)
+
 
 class Sub(BinaryBase):
     """Converts an onnx Sub node into an equivalent Relax expression."""
@@ -281,6 +286,10 @@ class Sub(BinaryBase):
     numpy_op = _np.subtract
     relax_op = relax.op.subtract
 
+    @classmethod
+    def _impl_v1(cls, bb, inputs, attr, params):
+        return cls.base_impl(bb, inputs, attr, params)
+
 
 class Mul(BinaryBase):
     """Converts an onnx Mul node into an equivalent Relax expression."""
@@ -288,6 +297,10 @@ class Mul(BinaryBase):
     numpy_op = _np.multiply
     relax_op = relax.op.multiply
 
+    @classmethod
+    def _impl_v1(cls, bb, inputs, attr, params):
+        return cls.base_impl(bb, inputs, attr, params)
+
 
 class Div(BinaryBase):
     """Converts an onnx Div node into an equivalent Relax expression."""
@@ -295,6 +308,10 @@ class Div(BinaryBase):
     numpy_op = _np.divide
     relax_op = relax.op.divide
 
+    @classmethod
+    def _impl_v1(cls, bb, inputs, attr, params):
+        return cls.base_impl(bb, inputs, attr, params)
+
 
 class Pow(BinaryBase):
     """Converts an onnx Pow node into an equivalent Relax expression."""
@@ -302,6 +319,10 @@ class Pow(BinaryBase):
     numpy_op = _np.power
     relax_op = relax.op.power
 
+    @classmethod
+    def _impl_v1(cls, bb, inputs, attr, params):
+        return cls.base_impl(bb, inputs, attr, params)
+
 
 class And(BinaryBase):
     """Converts an onnx And node into an equivalent Relax expression."""
@@ -309,6 +330,10 @@ class And(BinaryBase):
     numpy_op = _np.logical_and
     relax_op = relax.op.logical_and
 
+    @classmethod
+    def _impl_v1(cls, bb, inputs, attr, params):
+        return cls.base_impl(bb, inputs, attr, params)
+
 
 class Or(BinaryBase):
     """Converts an onnx Or node into an equivalent Relax expression."""
@@ -316,6 +341,10 @@ class Or(BinaryBase):
     numpy_op = _np.logical_or
     relax_op = relax.op.logical_or
 
+    @classmethod
+    def _impl_v1(cls, bb, inputs, attr, params):
+        return cls.base_impl(bb, inputs, attr, params)
+
 
 class Xor(BinaryBase):
     """Converts an onnx Xor node into an equivalent Relax expression."""
@@ -323,6 +352,10 @@ class Xor(BinaryBase):
     numpy_op = _np.logical_xor
     relax_op = relax.op.logical_xor
 
+    @classmethod
+    def _impl_v1(cls, bb, inputs, attr, params):
+        return cls.base_impl(bb, inputs, attr, params)
+
 
 class Less(BinaryBase):
     """Converts an onnx Less node into an equivalent Relax expression."""
@@ -330,6 +363,10 @@ class Less(BinaryBase):
     numpy_op = _np.less
     relax_op = relax.op.less
 
+    @classmethod
+    def _impl_v1(cls, bb, inputs, attr, params):
+        return cls.base_impl(bb, inputs, attr, params)
+
 
 class LessOrEqual(BinaryBase):
     """Converts an onnx LessEqual node into an equivalent Relax expression."""
@@ -337,6 +374,10 @@ class LessOrEqual(BinaryBase):
     numpy_op = _np.less_equal
     relax_op = relax.op.less_equal
 
+    @classmethod
+    def _impl_v1(cls, bb, inputs, attr, params):
+        return cls.base_impl(bb, inputs, attr, params)
+
 
 class Greater(BinaryBase):
     """Converts an onnx Greater node into an equivalent Relax expression."""
@@ -344,6 +385,10 @@ class Greater(BinaryBase):
     numpy_op = _np.greater
     relax_op = relax.op.greater
 
+    @classmethod
+    def _impl_v1(cls, bb, inputs, attr, params):
+        return cls.base_impl(bb, inputs, attr, params)
+
 
 class GreaterOrEqual(BinaryBase):
     """Converts an onnx GreaterEqual node into an equivalent Relax expression."""
@@ -351,6 +396,10 @@ class GreaterOrEqual(BinaryBase):
     numpy_op = _np.greater_equal
     relax_op = relax.op.greater_equal
 
+    @classmethod
+    def _impl_v1(cls, bb, inputs, attr, params):
+        return cls.base_impl(bb, inputs, attr, params)
+
 
 class Equal(OnnxOpConverter):
     """Converts an onnx Equal node into an equivalent Relax expression."""
@@ -374,7 +423,8 @@ class BitwiseBase(BinaryBase):
     """Converts an onnx BitwiseBase node into an equivalent Relax expression."""
 
     @classmethod
-    def base_impl(cls, bb, inputs, attr, params, py_func, relax_op):
+    def base_impl(cls, bb, inputs, attr, params):
+        """Base implementation for bitwise operations."""
         valid_types = ["int8", "int16", "int32", "int64", "uint8", "uint16", "uint32", "uint64"]
         for num, inp in enumerate(inputs):
             if inp.struct_info.dtype not in valid_types:
@@ -382,31 +432,69 @@ class BitwiseBase(BinaryBase):
                     f"Bitwise operations expect all inputs to have integer types, "
                     f"got {inp.struct_info.dtype} for input {num}"
                 )
-        return BinaryBase.base_impl(bb, inputs, attr, params, py_func, relax_op)
+        return super().base_impl(bb, inputs, attr, params)
 
 
 class BitwiseAnd(BitwiseBase):
     """Converts an onnx BitwiseAnd node into an equivalent Relax expression."""
 
+    numpy_op = _np.bitwise_and
+    relax_op = relax.op.bitwise_and
+
     @classmethod
     def _impl_v18(cls, bb, inputs, attr, params):
-        return cls.base_impl(bb, inputs, attr, params, lambda x, y: x & y, relax.op.bitwise_and)
+        return cls.base_impl(bb, inputs, attr, params)
 
 
 class BitwiseOr(BitwiseBase):
     """Converts an onnx BitwiseOr node into an equivalent Relax expression."""
 
+    numpy_op = _np.bitwise_or
+    relax_op = relax.op.bitwise_or
+
     @classmethod
     def _impl_v18(cls, bb, inputs, attr, params):
-        return cls.base_impl(bb, inputs, attr, params, lambda x, y: x | y, relax.op.bitwise_or)
+        return cls.base_impl(bb, inputs, attr, params)
 
 
 class BitwiseXor(BitwiseBase):
     """Converts an onnx BitwiseXor node into an equivalent Relax expression."""
 
+    numpy_op = _np.bitwise_xor
+    relax_op = relax.op.bitwise_xor
+
     @classmethod
     def _impl_v18(cls, bb, inputs, attr, params):
-        return cls.base_impl(bb, inputs, attr, params, lambda x, y: x ^ y, relax.op.bitwise_xor)
+        return cls.base_impl(bb, inputs, attr, params)
+
+
+class BitwiseNot(BitwiseBase):
+    """Converts an onnx BitwiseNot node into an equivalent Relax expression."""
+
+    numpy_op = _np.bitwise_not
+    relax_op = relax.op.bitwise_not
+
+    @classmethod
+    def _impl_v18(cls, bb, inputs, attr, params):
+        return cls.base_impl(bb, inputs, attr, params)
+
+
+class BitShift(BitwiseBase):
+    """Converts an onnx BitShift node into an equivalent Relax expression."""
+
+    @classmethod
+    def _impl_v11(cls, bb, inputs, attr, params):
+        direction = attr.get("direction", "LEFT").decode("ascii")
+        if direction == "LEFT":
+            cls.numpy_op = _np.left_shift
+            cls.relax_op = relax.op.left_shift
+        elif direction == "RIGHT":
+            cls.numpy_op = _np.right_shift
+            cls.relax_op = relax.op.right_shift
+        else:
+            raise ValueError("Unsupported Shift Direction: " + direction)
+
+        return cls.base_impl(bb, inputs, attr, params)
 
 
 class Sigmoid(OnnxOpConverter):
@@ -2654,8 +2742,8 @@ def _get_convert_map():
         "BitwiseAnd": BitwiseAnd,
         "BitwiseOr": BitwiseOr,
         "BitwiseXor": BitwiseXor,
-        # "BitwiseNot": BitwiseNot,
-        # "BitwiseShift": BitwiseShift,
+        "BitwiseNot": BitwiseNot,
+        "BitShift": BitShift,
         "And": And,
         "Or": Or,
         "Xor": Xor,
diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py
index 4581defa1a..c99201e969 100644
--- a/python/tvm/relax/op/__init__.py
+++ b/python/tvm/relax/op/__init__.py
@@ -52,6 +52,7 @@ from .binary import (
     floor_divide,
     greater,
     greater_equal,
+    left_shift,
     less,
     less_equal,
     logical_and,
@@ -62,6 +63,7 @@ from .binary import (
     multiply,
     not_equal,
     power,
+    right_shift,
     subtract,
 )
 from .create import (
diff --git a/python/tvm/relax/op/binary.py b/python/tvm/relax/op/binary.py
index 982b3a24f2..7632235cb3 100644
--- a/python/tvm/relax/op/binary.py
+++ b/python/tvm/relax/op/binary.py
@@ -386,3 +386,35 @@ def bitwise_xor(x1: Expr, x2: Expr) -> Expr:
         The computed result.
     """
     return _ffi_api.bitwise_xor(x1, x2)
+
+
+def left_shift(x1: Expr, x2: Expr) -> Expr:
+    """Bitwise Shift Left
+    Parameters
+    ----------
+    x1 : relax.Expr
+        The input tensor to be shifted.
+    x2 : relax.Expr
+        The number of positions to shift.
+    Returns
+    -------
+    result : relax.Expr
+        The computed result.
+    """
+    return _ffi_api.left_shift(x1, x2)
+
+
+def right_shift(x1: Expr, x2: Expr) -> Expr:
+    """Bitwise Shift Right
+    Parameters
+    ----------
+    x1 : relax.Expr
+        The input tensor to be shifted.
+    x2 : relax.Expr
+        The number of positions to shift.
+    Returns
+    -------
+    result : relax.Expr
+        The computed result.
+    """
+    return _ffi_api.right_shift(x1, x2)
diff --git a/python/tvm/relax/transform/legalize_ops/binary.py b/python/tvm/relax/transform/legalize_ops/binary.py
index 16d6c02696..d28e100edb 100644
--- a/python/tvm/relax/transform/legalize_ops/binary.py
+++ b/python/tvm/relax/transform/legalize_ops/binary.py
@@ -62,6 +62,8 @@ register_legalize("relax.minimum", _binary(topi.minimum))
 register_legalize("relax.bitwise_and", _binary(topi.bitwise_and))
 register_legalize("relax.bitwise_or", _binary(topi.bitwise_or))
 register_legalize("relax.bitwise_xor", _binary(topi.bitwise_xor))
+register_legalize("relax.left_shift", _binary(topi.left_shift))
+register_legalize("relax.right_shift", _binary(topi.right_shift))
 
 # logical
 register_legalize("relax.logical_and", _binary(topi.logical_and))
diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py
index c4be8afac4..e6ff35ebe5 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -102,6 +102,7 @@ from tvm.relax.op import (
     isinf,
     isnan,
     layout_transform,
+    left_shift,
     less,
     less_equal,
     linear,
@@ -133,6 +134,7 @@ from tvm.relax.op import (
     quantize,
     repeat,
     reshape,
+    right_shift,
     round,
     rsqrt,
     scatter_elements,
@@ -773,6 +775,7 @@ __all__ = [
     "isinf",
     "isnan",
     "layout_transform",
+    "left_shift",
     "less",
     "less_equal",
     "linear",
@@ -809,6 +812,7 @@ __all__ = [
     "repeat",
     "reshape",
     "rewriter",
+    "right_shift",
     "tensor_to_shape",
     "shape_to_tensor",
     "rocm",
diff --git a/src/relax/op/distributed/binary.cc b/src/relax/op/distributed/binary.cc
index 63f4f356c0..6ad71e0f85 100644
--- a/src/relax/op/distributed/binary.cc
+++ b/src/relax/op/distributed/binary.cc
@@ -68,6 +68,8 @@ RELAX_REGISTER_BINARY_BROADCAST_DIST_INFER_STRUCT_INFO(logical_xor);
 RELAX_REGISTER_BINARY_BROADCAST_DIST_INFER_STRUCT_INFO(bitwise_and);
 RELAX_REGISTER_BINARY_BROADCAST_DIST_INFER_STRUCT_INFO(bitwise_or);
 RELAX_REGISTER_BINARY_BROADCAST_DIST_INFER_STRUCT_INFO(bitwise_xor);
+RELAX_REGISTER_BINARY_BROADCAST_DIST_INFER_STRUCT_INFO(left_shift);
+RELAX_REGISTER_BINARY_BROADCAST_DIST_INFER_STRUCT_INFO(right_shift);
 
 }  // namespace distributed
 }  // namespace relax
diff --git a/src/relax/op/tensor/binary.cc b/src/relax/op/tensor/binary.cc
index afc0fb7303..f1dc3d4904 100644
--- a/src/relax/op/tensor/binary.cc
+++ b/src/relax/op/tensor/binary.cc
@@ -207,6 +207,8 @@ RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(logical_xor);
 RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(bitwise_and);
 RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(bitwise_or);
 RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(bitwise_xor);
+RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(left_shift);
+RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(right_shift);
 
 }  // namespace relax
 }  // namespace tvm
diff --git a/src/relax/op/tensor/binary.h b/src/relax/op/tensor/binary.h
index b28a6c3369..003bcb7e27 100644
--- a/src/relax/op/tensor/binary.h
+++ b/src/relax/op/tensor/binary.h
@@ -129,6 +129,12 @@ Expr bitwise_or(Expr x1, Expr x2);
 /*! \brief Broadcasted element-wise bitwise xor */
 Expr bitwise_xor(Expr x1, Expr x2);
 
+/*! \brief Broadcasted element-wise bitwise shift left */
+Expr left_shift(Expr x1, Expr x2);
+
+/*! \brief Broadcasted element-wise bitwise shift right */
+Expr right_shift(Expr x1, Expr x2);
+
 }  // namespace relax
 }  // namespace tvm
 
diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py
index f2bbd3f3f5..e3ed3a3a9d 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -358,6 +358,42 @@ def test_binary_bool(op_name: str):
     verify_binary(op_name, [32, 32], [32, 32], [32, 32], dtype=TensorProto.BOOL)
 
 
+@pytest.mark.skip(reason="opset 18 is not supported in CI")
+@pytest.mark.parametrize("op_name", ["BitwiseAnd", "BitwiseOr", "BitwiseXor"])
+def test_bitwise(op_name: str):
+    verify_binary(op_name, [32, 32], [32, 32], [32, 32], dtype=TensorProto.UINT64, opset=18)
+
+
+@pytest.mark.skip(reason="opset 18 is not supported in CI")
+def test_bitwise_not():
+    verify_unary(
+        "BitwiseNot",
+        [32, 32],
+        input_dtype=TensorProto.UINT64,
+        output_dtype=TensorProto.UINT64,
+        opset=18,
+    )
+
+
+@pytest.mark.parametrize("direction", ["LEFT", "RIGHT"])
+def test_bitwise_shift(direction: str):
+    shape = [32, 32]
+    dtype = TensorProto.UINT64
+    test_node = helper.make_node("BitShift", ["a", "b"], ["c"], direction=direction)
+    graph = helper.make_graph(
+        [test_node],
+        "binary_test",
+        inputs=[
+            helper.make_tensor_value_info("a", dtype, shape),
+            helper.make_tensor_value_info("b", dtype, shape),
+        ],
+        outputs=[helper.make_tensor_value_info("c", dtype, shape)],
+    )
+
+    model = helper.make_model(graph, producer_name="binary_test")
+    check_correctness(model, inputs={"b": np.random.randint(0, 8, shape).astype("uint64")})
+
+
 @pytest.mark.parametrize(
     "op_name",
     [
diff --git a/tests/python/relax/test_op_binary.py b/tests/python/relax/test_op_binary.py
index 85842f1578..20c111495d 100644
--- a/tests/python/relax/test_op_binary.py
+++ b/tests/python/relax/test_op_binary.py
@@ -46,6 +46,8 @@ def test_op_correctness():
     assert relax.op.bitwise_and(x, y).op == Op.get("relax.bitwise_and")
     assert relax.op.bitwise_or(x, y).op == Op.get("relax.bitwise_or")
     assert relax.op.bitwise_xor(x, y).op == Op.get("relax.bitwise_xor")
+    assert relax.op.left_shift(x, y).op == Op.get("relax.left_shift")
+    assert relax.op.right_shift(x, y).op == Op.get("relax.right_shift")
 
     x = relax.Var("x", R.Tensor((2, 3), "bool"))
     y = relax.Var("y", R.Tensor((2, 3), "bool"))

Reply via email to