This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 8e23806  [Topi, Relay] Add cumprod (#7722)
8e23806 is described below

commit 8e23806d2d522b71979d0a2730b38cc5c3bf6185
Author: AndrewZhaoLuo <andrew.zhao....@gmail.com>
AuthorDate: Wed Mar 24 21:25:18 2021 -0700

    [Topi, Relay] Add cumprod (#7722)
    
    * make cumbinop, refactor cumsum, add cumprod
    
    * cumsum exclusive test
    
    * Add cumprod + flesh out cumsum tests
    
    add cumprod and tests
    
    reinstate tests
    
    rethink
    
    * add rudimentary scan implementation
    
    * add attributes of cumprod node
    
    * add cumprod strategy
    
    * add cuda strategy
    
    * python relay node construction
    
    * change attrs to be reusable
    
    * add cumprod nodes
    
    * complete tests
    
    * Fix some typos about sum --> prod
    
    typos fix sum -> prod
    
    more typos
    
    more typo fixes
    
    more typos
    
    add doc strings
    
    * Use Bool instead of int to represent exclusive
    
    make exclusive a bool up and down stack
    
    fix x
    
    fix bool err
    
    it is a bool now
    
    fix
    
    fix thing
    
    formatting to pass linter
    
    lint python
    
    cumprod pylint
    
    fix attribute
    
    fix ordering
    
    add exclusivity tests for end to end
    
    fix things
    
    cuda identity_value
    
    * Overall improve formatting, add doc message corrections
    
    simplify construction
    
    clang-format
    
    more tests
    
    undo simpler construction due to function passing stuff
    
    fix docs
    
    more exclusive doc changes
    
    more fixins"
    
    * merge cumsum and cumprod to scan, merge tests
    
    fix stuff
    
    * remove other mentions of cumbinop -> scanop
    
    * lint formatting
    
    Co-authored-by: Andrew Zhao Luo <andrewzhaoluo@Andrews-MacBook-Pro.local>
---
 include/tvm/relay/attrs/transform.h          |  14 +-
 python/tvm/relay/op/_transform.py            |  19 ++-
 python/tvm/relay/op/strategy/cuda.py         |  19 ++-
 python/tvm/relay/op/strategy/generic.py      |  29 +++-
 python/tvm/relay/op/transform.py             |  65 +++++++-
 python/tvm/topi/__init__.py                  |   2 +-
 python/tvm/topi/cuda/scan.py                 | 196 +++++++++++++++++++---
 python/tvm/topi/cumsum.py                    | 121 --------------
 python/tvm/topi/scan.py                      | 236 +++++++++++++++++++++++++++
 python/tvm/topi/unique.py                    |   2 +-
 src/relay/op/tensor/transform.cc             |  34 +++-
 tests/python/relay/test_op_level3.py         |  77 ++++++---
 tests/python/topi/python/test_topi_cumsum.py |  79 ---------
 tests/python/topi/python/test_topi_scan.py   | 144 ++++++++++++++++
 14 files changed, 758 insertions(+), 279 deletions(-)

diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h
index ff344f5..a5544c8 100644
--- a/include/tvm/relay/attrs/transform.h
+++ b/include/tvm/relay/attrs/transform.h
@@ -438,17 +438,19 @@ struct MatrixSetDiagAttrs : public tvm::AttrsNode<MatrixSetDiagAttrs> {
   }
 };  // struct MatrixSetDiagAttrs
 
-/*! \brief Attributes used in cumsum operator */
-struct CumsumAttrs : public tvm::AttrsNode<CumsumAttrs> {
+/*! \brief Attributes used in cumsum and cumprod operators */
+struct ScanopAttrs : public tvm::AttrsNode<ScanopAttrs> {
   Integer axis;
   DataType dtype;
-  Integer exclusive;
-  TVM_DECLARE_ATTRS(CumsumAttrs, "relay.attrs.CumsumAttrs") {
-    TVM_ATTR_FIELD(axis).describe("The axis to sum 
over").set_default(NullValue<Integer>());
+  Bool exclusive = Bool(false);
+  TVM_DECLARE_ATTRS(ScanopAttrs, "relay.attrs.ScanopAttrs") {
+    TVM_ATTR_FIELD(axis).describe("The axis to operate 
over").set_default(NullValue<Integer>());
     TVM_ATTR_FIELD(dtype).describe("Output data 
type").set_default(NullValue<DataType>());
+
+    // Default is 0 which is "false"
     TVM_ATTR_FIELD(exclusive)
         .describe("The first element is not included")
-        .set_default(NullValue<Integer>());
+        .set_default(Bool(false));
   }
 };
 
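
For context, the semantics the new Bool `exclusive` attribute selects, sketched
with numpy (an illustration, not part of the patch):

    import numpy as np

    a = np.array([1, 2, 3, 4])
    np.cumsum(a)                               # inclusive sum:     [ 1  3  6 10]
    np.concatenate(([0], np.cumsum(a)[:-1]))   # exclusive sum:     [0 1 3 6]
    np.cumprod(a)                              # inclusive product: [ 1  2  6 24]
    np.concatenate(([1], np.cumprod(a)[:-1]))  # exclusive product: [1 1 2 6]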
diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py
index e90263d..1626283 100644
--- a/python/tvm/relay/op/_transform.py
+++ b/python/tvm/relay/op/_transform.py
@@ -19,16 +19,17 @@
 # pylint: disable=too-many-local-variables, too-many-arguments, no-else-return
 
 from __future__ import absolute_import
+
 import tvm
-from tvm import te
-from tvm.te.hybrid import script
+from tvm import te, topi
 from tvm.runtime import convert
-from tvm import topi
+from tvm.te.hybrid import script
 from tvm.topi.utils import get_const_int, get_const_tuple
+
 from . import op as _reg
 from . import strategy
-from .op import OpPattern
 from ._tensor import elemwise_shape_func
+from .op import OpPattern
 
 _reg.register_broadcast_schedule("broadcast_to")
 _reg.register_broadcast_schedule("broadcast_to_like")
@@ -159,6 +160,16 @@ def compute_cumsum(attrs, inputs, output_type):
 _reg.register_strategy("cumsum", strategy.cumsum_strategy)
 _reg.register_shape_func("cumsum", False, elemwise_shape_func)
 
+# cumprod
+@_reg.register_compute("cumprod")
+def compute_cumprod(attrs, inputs, output_type):
+    """Compute definition of cumprod"""
+    return [topi.cumprod(inputs[0], attrs.axis, attrs.dtype, attrs.exclusive)]
+
+
+_reg.register_strategy("cumprod", strategy.cumprod_strategy)
+_reg.register_shape_func("cumprod", False, elemwise_shape_func)
+
 
 @_reg.register_compute("unique")
 def compute_unique(attrs, inputs, output_type):
diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py
index e0d0f16..1a67425 100644
--- a/python/tvm/relay/op/strategy/cuda.py
+++ b/python/tvm/relay/op/strategy/cuda.py
@@ -18,11 +18,12 @@
 # pylint: disable=invalid-name,unused-argument,wildcard-import,unused-wildcard-import
 from tvm import topi
 from tvm.auto_scheduler import is_auto_scheduler_enabled
-from tvm.te import SpecializedCondition
 from tvm.contrib import nvcc
 from tvm.contrib.thrust import can_use_thrust
-from .generic import *
+from tvm.te import SpecializedCondition
+
 from .. import op as _op
+from .generic import *
 
 
 @schedule_injective.register(["cuda", "gpu"])
@@ -1017,13 +1018,25 @@ def cumsum_strategy_cuda(attrs, inputs, out_type, target):
     """cumsum cuda strategy"""
     strategy = _op.OpStrategy()
     strategy.add_implementation(
-        wrap_compute_cumsum(topi.cuda.cumsum),
+        wrap_compute_scanop(topi.cuda.cumsum),
         wrap_topi_schedule(topi.cuda.schedule_scan),
         name="cumsum.cuda",
     )
     return strategy
 
 
+@cumprod_strategy.register(["cuda", "gpu"])
+def cumprod_strategy_cuda(attrs, inputs, out_type, target):
+    """cumprod cuda strategy"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_scanop(topi.cuda.cumprod),
+        wrap_topi_schedule(topi.cuda.schedule_scan),
+        name="cumprod.cuda",
+    )
+    return strategy
+
+
 @unique_strategy.register(["cuda", "gpu"])
 def unique_strategy_cuda(attrs, inputs, out_type, target):
     """unique cuda strategy"""
diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py
index 04f2564..322a360 100644
--- a/python/tvm/relay/op/strategy/generic.py
+++ b/python/tvm/relay/op/strategy/generic.py
@@ -17,11 +17,12 @@
 """Definition of generic operator strategy."""
 # pylint: disable=invalid-name,unused-argument
 import logging
-
 import re
-from tvm import topi, _ffi, te, ir
-from tvm.topi.utils import get_const_int, get_const_float, get_const_tuple, get_float_tuple
+
+from tvm import _ffi, ir, te, topi
 from tvm.target import generic_func, override_native_generic_func
+from tvm.topi.utils import get_const_float, get_const_int, get_const_tuple, get_float_tuple
+
 from .. import op as _op
 
 logger = logging.getLogger("strategy")
@@ -1463,13 +1464,13 @@ def threefry_split_strategy(attrs, inputs, out_type, target):
     return strategy
 
 
-def wrap_compute_cumsum(topi_compute):
-    """Wrap cumsum topi compute"""
+def wrap_compute_scanop(topi_compute):
+    """Wrap scanop style topi compute"""
 
-    def _compute_cumsum(attrs, inputs, _):
+    def _compute_scanop(attrs, inputs, _):
         return [topi_compute(inputs[0], attrs.axis, attrs.dtype, attrs.exclusive)]
 
-    return _compute_cumsum
+    return _compute_scanop
 
 
 @override_native_generic_func("cumsum_strategy")
@@ -1477,13 +1478,25 @@ def cumsum_strategy(attrs, inputs, out_type, target):
     """cumsum generic strategy"""
     strategy = _op.OpStrategy()
     strategy.add_implementation(
-        wrap_compute_cumsum(topi.cumsum),
+        wrap_compute_scanop(topi.cumsum),
         wrap_topi_schedule(topi.generic.schedule_extern),
         name="cumsum.generic",
     )
     return strategy
 
 
+@override_native_generic_func("cumprod_strategy")
+def cumprod_strategy(attrs, inputs, out_type, target):
+    """cumprod generic strategy"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_scanop(topi.cumprod),
+        wrap_topi_schedule(topi.generic.schedule_extern),
+        name="cumprod.generic",
+    )
+    return strategy
+
+
 def wrap_compute_unique(topi_compute):
     """Wrap unique topi compute"""
 
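
Because the compute adapter is now shared, any future scan-style op can register
a strategy the same way. A sketch with a made-up "cummax" op (topi.cummax is
hypothetical, not part of this patch):

    @override_native_generic_func("cummax_strategy")
    def cummax_strategy(attrs, inputs, out_type, target):
        """Hypothetical cummax generic strategy reusing the shared wrapper."""
        strategy = _op.OpStrategy()
        strategy.add_implementation(
            wrap_compute_scanop(topi.cummax),  # hypothetical TOPI kernel
            wrap_topi_schedule(topi.generic.schedule_extern),
            name="cummax.generic",
        )
        return strategy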
diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py
index df0ae76..f94a00d 100644
--- a/python/tvm/relay/op/transform.py
+++ b/python/tvm/relay/op/transform.py
@@ -18,11 +18,11 @@
 # pylint: disable=import-outside-toplevel
 """Transform operators."""
 
+from ...tir import expr as _expr
+from ..expr import Constant, Expr, Tuple, TupleWrapper, const
 from . import _make
 from .dyn import _make as _dyn_make
 from .tensor import shape_of
-from ..expr import TupleWrapper, const, Constant, Expr, Tuple
-from ...tir import expr as _expr
 
 
 def cast(data, dtype):
@@ -1539,9 +1539,9 @@ def cumsum(data, axis=None, dtype=None, exclusive=None):
         Type of the returned array and of the accumulator in which the elements are summed.
         If dtype is not specified, it defaults to the dtype of data.
 
-    exclusive : int, optional
-        If set to 1 will return exclusive sum in which the first element is not
-        included. In other terms, if set to 1, the j-th output element would be
+    exclusive : bool, optional
+        If true will return exclusive sum in which the first element is not
+        included. In other terms, if true, the j-th output element would be
         the sum of the first (j-1) elements. Otherwise, it would be the sum of
         the first j elements.
 
@@ -1577,6 +1577,61 @@ def cumsum(data, axis=None, dtype=None, exclusive=None):
     return _make.cumsum(data, axis, dtype, exclusive)
 
 
+def cumprod(data, axis=None, dtype=None, exclusive=None):
+    """Numpy style cumprod op. Return the cumulative inclusive product of the 
elements along
+    a given axis.
+
+    Parameters
+    ----------
+    data : relay.Expr
+        The input data to the operator.
+
+    axis : int, optional
+        Axis along which the cumulative product is computed. The default (None) is to compute
+        the cumprod over the flattened array.
+
+    dtype : string, optional
+        Type of the returned array and of the accumulator in which the elements are multiplied.
+        If dtype is not specified, it defaults to the dtype of data.
+
+    exclusive : bool, optional
+        If true will return exclusive product in which the first element is not
+        included. In other terms, if true, the j-th output element would be
+        the product of the first (j-1) elements. Otherwise, it would be the product of
+        the first j elements. The product of zero elements will be 1.
+
+    Returns
+    -------
+    result : relay.Expr
+        The result has the same size as data, and the same shape as data if axis is not None.
+        If axis is None, the result is a 1-d array.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        a = [[1,2,3], [4,5,6]]
+
+        cumprod(a)  # if axis is not provided, cumprod is done over the flattened input.
+        -> [ 1,  2,  6, 24, 120, 720]
+
+        cumprod(a, dtype="float32")
+        -> [  1.,  2.,  6., 24., 120., 720.]
+
+        cumprod(a, axis=0)  # multiply over rows for each of the 3 columns
+        -> [[1, 2, 3],
+            [4, 10, 18]]
+
+        cumprod(a, axis=1)
+        -> [[ 1,  2,  6],
+            [ 4,  20, 120]]
+
+        a = [1, 1, 1, 0, 1, 1, 0]  # a is a boolean array
+        cumprod(a, dtype="int32")  # dtype should be provided to get the expected results
+        -> [1, 1, 1, 0, 0, 0, 0]
+    """
+    return _make.cumprod(data, axis, dtype, exclusive)
+
+
 def unique(data, is_sorted=True, return_counts=False):
     """
     Find the unique elements of a 1-D tensor. Please note `output` and `counts` are all padded to
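
The cumprod docstring examples above mirror numpy's behavior; a quick
cross-check (illustration only, not part of the patch):

    import numpy as np

    a = np.array([[1, 2, 3], [4, 5, 6]])
    print(np.cumprod(a))          # [  1   2   6  24 120 720]
    print(np.cumprod(a, axis=0))  # [[ 1  2  3]
                                  #  [ 4 10 18]]
    print(np.cumprod(a, axis=1))  # [[  1   2   6]
                                  #  [  4  20 120]]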
diff --git a/python/tvm/topi/__init__.py b/python/tvm/topi/__init__.py
index c196b33..ef2c5c1 100644
--- a/python/tvm/topi/__init__.py
+++ b/python/tvm/topi/__init__.py
@@ -42,7 +42,7 @@ from .sparse_fill_empty_rows import *
 from .sparse_reshape import *
 from .scatter_add import *
 from .argwhere import *
-from .cumsum import *
+from .scan import *
 from .einsum import *
 from .unique import *
 from . import generic
diff --git a/python/tvm/topi/cuda/scan.py b/python/tvm/topi/cuda/scan.py
index 84ab5dc..3240ebc 100644
--- a/python/tvm/topi/cuda/scan.py
+++ b/python/tvm/topi/cuda/scan.py
@@ -16,13 +16,16 @@
 # under the License.
 # pylint: disable=invalid-name, too-many-locals, too-many-statements
 "Scan related operators"
+from typing import Callable, Optional, Union
+
 import tvm
 from tvm import te
-from tvm.contrib.thrust import can_use_thrust, can_use_rocthrust
-from ..transform import expand_dims, squeeze, transpose, reshape
-from ..utils import ceil_div, swap, prod, get_const_int
-from ..math import cast
+from tvm.contrib.thrust import can_use_rocthrust, can_use_thrust
+
 from .. import tag
+from ..math import cast
+from ..transform import expand_dims, reshape, squeeze, transpose
+from ..utils import ceil_div, get_const_int, prod, swap
 from .injective import schedule_injective_from_existing
 
 
@@ -32,7 +35,7 @@ def _get_thrust_func_name(tvmop):
     return tvmop_to_thrust_func_name[tvmop]
 
 
-def exclusive_scan_ir(data, output, reduction=None, binop=tvm.tir.generic.add):
+def exclusive_scan_ir(data, output, reduction=None, binop=tvm.tir.generic.add, identity_value=0):
     """Low level IR to do exclusive sum scan along rows of 2D input.
 
     Parameters
@@ -50,6 +53,11 @@ def exclusive_scan_ir(data, output, reduction=None, binop=tvm.tir.generic.add):
         A binary associative op to use for scan. The function takes two TIR expressions
         and produce a new TIR expression. By default it uses tvm.tir.generic.add to compute
         prefix sum.
+
+    identity_value: int or float
+        A value for the binary operation which provides the identity property. 
E.g. if * is
+        your operator and i is the identity_value then a * i = a for all a in 
the domain of
+        your operation.
     """
 
     batch_size = prod(data.shape[:-1])
@@ -134,7 +142,7 @@ def exclusive_scan_ir(data, output, reduction=None, binop=tvm.tir.generic.add):
             with ib.if_scope(bx < batch_size):
                 if reduction is not None:
                     reduction[bx] = output[(bx + 1) * scan_axis_size - 1]
-                output[(bx + 1) * scan_axis_size - 1] = cast(0, out_dtype)
+                output[(bx + 1) * scan_axis_size - 1] = cast(identity_value, out_dtype)
 
         with ib.for_range(0, lim, dtype="int64") as l2_width:
             width = 2 << (lim - l2_width - 1)
@@ -309,7 +317,12 @@ def scan_thrust(
 
 
 def exclusive_scan(
-    data, axis=-1, return_reduction=False, output_dtype=None, binop=tvm.tir.generic.add
+    data,
+    axis=-1,
+    return_reduction=False,
+    output_dtype=None,
+    binop=tvm.tir.generic.add,
+    identity_value=0,
 ):
     """Do exclusive scan on 1D or multidimensional input.
 
@@ -335,6 +348,11 @@ def exclusive_scan(
         and produce a new TIR expression. By default it uses tvm.tir.generic.add to compute
         prefix sum.
 
+    identity_value: int or float
+        A value for the binary operation which provides the identity property. E.g. if * is
+        your operator and i is the identity_value then a * i = a for all a in the domain of
+        your operation.
+
     Returns
     -------
     output : tvm.te.Tensor
@@ -347,9 +365,15 @@ def exclusive_scan(
 
     def do_scan(data, output_dtype):
         target = tvm.target.Target.current()
-        if target and (
-            can_use_thrust(target, "tvm.contrib.thrust.sum_scan")
-            or can_use_rocthrust(target, "tvm.contrib.thrust.sum_scan")
+
+        # TODO: add support for a prod_scan
+        if (
+            target
+            and binop == tvm.tir.generic.add
+            and (
+                can_use_thrust(target, "tvm.contrib.thrust.sum_scan")
+                or can_use_rocthrust(target, "tvm.contrib.thrust.sum_scan")
+            )
         ):
             return scan_thrust(
                 data, output_dtype, exclusive=True, return_reduction=return_reduction, binop=binop
@@ -366,7 +390,9 @@ def exclusive_scan(
             output, reduction = te.extern(
                 [data.shape, data.shape[:-1]],
                 [data],
-                lambda ins, outs: exclusive_scan_ir(ins[0], outs[0], outs[1], binop=binop),
+                lambda ins, outs: exclusive_scan_ir(
+                    ins[0], outs[0], outs[1], binop=binop, identity_value=identity_value
+                ),
                 dtype=[data.dtype, output_dtype],
                 in_buffers=[data_buf],
                 name="exclusive_scan",
@@ -376,7 +402,9 @@ def exclusive_scan(
             output = te.extern(
                 [data.shape],
                 [data],
-                lambda ins, outs: exclusive_scan_ir(ins[0], outs[0], binop=binop),
+                lambda ins, outs: exclusive_scan_ir(
+                    ins[0], outs[0], binop=binop, identity_value=identity_value
+                ),
                 dtype=[output_dtype],
                 in_buffers=[data_buf],
                 out_buffers=[output_buf],
@@ -423,7 +451,7 @@ def exclusive_scan(
     return output
 
 
-def inclusive_scan(data, axis=-1, output_dtype=None, binop=tvm.tir.generic.add):
+def inclusive_scan(data, axis=-1, output_dtype=None, binop=tvm.tir.generic.add, identity_value=0):
     """Do inclusive scan on 1D or multidimensional input.
 
     Parameters
@@ -442,12 +470,19 @@ def inclusive_scan(data, axis=-1, output_dtype=None, binop=tvm.tir.generic.add):
         and produce a new TIR expression. By default it uses tvm.tir.generic.add to compute
         prefix sum.
 
+    identity_value: int or float
+        A value for the binary operation which provides the identity property. E.g. if * is
+        your operator and i is the identity_value then a * i = a for all a in the domain of
+        your operation.
+
     Returns
     -------
     output : tvm.te.Tensor
         A N-D tensor of the same rank N as the input data.
     """
-    ex_scan = exclusive_scan(data, axis, output_dtype=output_dtype, binop=binop)
+    ex_scan = exclusive_scan(
+        data, axis, output_dtype=output_dtype, binop=binop, identity_value=identity_value
+    )
 
     if output_dtype is not None and data.dtype != output_dtype and output_dtype != "":
         data = cast(data, output_dtype)
@@ -486,7 +521,74 @@ def schedule_scan(outs):
     return s
 
 
-def cumsum(data, axis=None, dtype=None, exclusive=None):
+def scanop(
+    data: tvm.te.Tensor,
+    binop: Callable[["tvm.Expr", "tvm.Expr"], "tvm.Expr"],
+    identity_value: Union[float, int],
+    axis: Optional[int] = None,
+    dtype: Optional[str] = None,
+    exclusive: Optional[bool] = None,
+) -> tvm.te.Tensor:
+    """Cumulative binary operator (scan) with similar axis behavior as 
np.cumsum and np.cumprod.
+
+    See cumprod and cumsum for an example of use.
+
+    E.g. if * is your binary operator and the input tensor is [1, 2, 3, 4] the output may be
+    [1, 1 * 2, 1 * 2 * 3, 1 * 2 * 3 * 4]
+
+    Parameters
+    ----------
+    data : tvm.te.Tensor
+        The input data to the operator.
+
+    binop: Callable (tvm.Expr, tvm.Expr) -> tvm.Expr
+        A binary operator which should be associative and commutative. E.g. if * is your
+        operator then a * (b * c) = (a * b) * c and a * b = b * a
+
+    identity_value: int or float
+        A value for the binary operation which provides the identity property. E.g. if * is
+        your operator and i is the identity_value then a * i = a for all a in the domain of
+        your operation.
+
+    axis : int, optional
+        Axis along which the operation is computed. The default (None) is to compute
+        the cumulative operation over the flattened array.
+
+    dtype : string, optional
+        Type of the returned array and of the accumulator in which the elements are computed.
+        If dtype is not specified, it defaults to the dtype of data.
+
+    exclusive : bool, optional
+        If true will return exclusive cumulative operation in which the first element is not
+        included. In other terms, if true, the j-th output element would be
+        the cumulative operation of the first (j-1) elements. Otherwise, it would be the
+        cumulative operation of the first j elements.
+
+    Returns
+    -------
+    result : tvm.te.Tensor
+        The result has the same size as data, and the same shape as data if axis is not None.
+        If axis is None, the result is a 1-d array.
+    """
+    if axis is None:
+        axis = 0
+        data = reshape(data, (prod(data.shape),))
+    axis = get_const_int(axis)
+    if exclusive is not None and exclusive:
+        return exclusive_scan(
+            data, axis, output_dtype=dtype, binop=binop, identity_value=identity_value
+        )
+    return inclusive_scan(
+        data, axis, output_dtype=dtype, binop=binop, identity_value=identity_value
+    )
+
+
+def cumsum(
+    data: tvm.te.Tensor,
+    axis: Optional[int] = None,
+    dtype: Optional[str] = None,
+    exclusive: Optional[bool] = None,
+) -> tvm.te.Tensor:
     """Numpy style cumsum op. Return the cumulative sum of the elements along 
a given axis.
 
     Parameters
@@ -502,9 +604,9 @@ def cumsum(data, axis=None, dtype=None, exclusive=None):
         Type of the returned array and of the accumulator in which the elements are summed.
         If dtype is not specified, it defaults to the dtype of data.
 
-    exclusive : int, optional
-        If set to 1 will return exclusive sum in which the first element is not
-        included. In other terms, if set to 1, the j-th output element would be
+    exclusive : bool, optional
+        If true will return exclusive sum in which the first element is not
+        included. In other terms, if true, the j-th output element would be
         the sum of the first (j-1) elements. Otherwise, it would be the sum of
         the first j elements.
 
@@ -514,10 +616,54 @@ def cumsum(data, axis=None, dtype=None, exclusive=None):
         The result has the same size as data, and the same shape as data if axis is not None.
         If axis is None, the result is a 1-d array.
     """
-    if axis is None:
-        axis = 0
-        data = reshape(data, (prod(data.shape),))
-    axis = get_const_int(axis)
-    if exclusive is not None and exclusive != 0:
-        return exclusive_scan(data, axis, output_dtype=dtype, binop=tvm.tir.generic.add)
-    return inclusive_scan(data, axis, output_dtype=dtype, binop=tvm.tir.generic.add)
+    return scanop(
+        data=data,
+        binop=tvm.tir.generic.add,
+        identity_value=0,
+        axis=axis,
+        dtype=dtype,
+        exclusive=exclusive,
+    )
+
+
+def cumprod(
+    data: tvm.te.Tensor,
+    axis: Optional[int] = None,
+    dtype: Optional[str] = None,
+    exclusive: Optional[bool] = None,
+):
+    """Numpy style cumprod op. Return the cumulative product of the elements 
along a given axis.
+
+    Parameters
+    ----------
+    data : tvm.te.Tensor
+        The input data to the operator.
+
+    axis : int, optional
+        Axis along which the cumulative product is computed. The default (None) is to compute
+        the cumprod over the flattened array.
+
+    dtype : string, optional
+        Type of the returned array and of the accumulator in which the elements are multiplied.
+        If dtype is not specified, it defaults to the dtype of data.
+
+    exclusive : bool, optional
+        If True, will return exclusive product in which the first element is not
+        included. In other terms, if True, the j-th output element would be
+        the product of the first (j-1) elements. Otherwise, it would be the product of
+        the first j elements.
+
+    Returns
+    -------
+    result : tvm.te.Tensor
+        The result has the same size as data, and the same shape as data if axis is not None.
+        If axis is None, the result is a 1-d array.
+    """
+    return scanop(
+        data=data,
+        binop=tvm.tir.generic.multiply,
+        identity_value=1,
+        axis=axis,
+        dtype=dtype,
+        exclusive=exclusive,
+    )
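
A minimal reference model of the exclusive scan above, showing why cumprod must
pass identity_value=1 where cumsum passes 0 (a numpy sketch, not part of the
patch):

    import numpy as np

    def exclusive_scan_ref(a, binop, identity_value):
        # Shift the scan right by one and seed position 0 with the identity.
        out = np.empty_like(a)
        out[0] = identity_value
        for i in range(1, len(a)):
            out[i] = binop(out[i - 1], a[i - 1])
        return out

    exclusive_scan_ref(np.array([1, 2, 3, 4]), np.add, 0)       # [0 1 3 6]
    exclusive_scan_ref(np.array([1, 2, 3, 4]), np.multiply, 1)  # [1 1 2 6]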
diff --git a/python/tvm/topi/cumsum.py b/python/tvm/topi/cumsum.py
deleted file mode 100644
index 2013a35..0000000
--- a/python/tvm/topi/cumsum.py
+++ /dev/null
@@ -1,121 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-# pylint: disable=invalid-name
-"""Cumsum operator"""
-from ..tir import decl_buffer, ir_builder
-from ..te import extern
-from .utils import prod, get_const_int
-from .math import cast
-
-
-def cumsum(data, axis=None, dtype=None, exclusive=None):
-    """Numpy style cumsum op. Return the cumulative sum of the elements along 
a given axis.
-
-    Parameters
-    ----------
-    data : tvm.te.Tensor
-        The input data to the operator.
-
-    axis : int, optional
-        Axis along which the cumulative sum is computed. The default (None) is to compute
-        the cumsum over the flattened array.
-
-    dtype : string, optional
-        Type of the returned array and of the accumulator in which the elements are summed.
-        If dtype is not specified, it defaults to the dtype of data.
-
-    exclusive : int, optional
-        If set to 1 will return exclusive sum in which the first element is not
-        included. In other terms, if set to 1, the j-th output element would be
-        the sum of the first (j-1) elements. Otherwise, it would be the sum of
-        the first j elements.
-
-    Returns
-    -------
-    result : tvm.te.Tensor
-        The result has the same size as data, and the same shape as data if axis is not None.
-        If axis is None, the result is a 1-d array.
-    """
-    if dtype is None or dtype == "":
-        dtype = data.dtype
-
-    def maybe_cast(x):
-        if dtype != data.dtype:
-            return cast(x, dtype)
-        return x
-
-    axis_mul_before = 1
-    axis_mul_after = 1
-
-    if axis is None:
-        axis = 0
-        cumsum_axis_len = prod(data.shape)
-        shape = (cumsum_axis_len,)
-    else:
-        if not isinstance(axis, int):
-            axis = get_const_int(axis)
-
-        shape = data.shape
-        cumsum_axis_len = shape[axis]
-
-        if axis < 0:
-            axis = len(shape) + axis
-
-        for i, value in enumerate(shape, 0):
-            if i < axis:
-                axis_mul_before *= value
-            elif i > axis:
-                axis_mul_after *= value
-
-    if exclusive is None:
-        exclusive = 0
-
-    def gen_ir(data_buf, out_buf):
-        ib = ir_builder.create()
-        data_buf = ib.buffer_ptr(data_buf)
-        out_buf = ib.buffer_ptr(out_buf)
-
-        with ib.for_range(0, axis_mul_before * axis_mul_after, "fused", 
kind="parallel") as fused:
-            i = fused // axis_mul_after
-            j = fused % axis_mul_after
-            base_idx = i * cumsum_axis_len * axis_mul_after + j
-            if exclusive == 0:
-                out_buf[base_idx] = maybe_cast(data_buf[base_idx])
-            else:
-                out_buf[base_idx] = cast(0, dtype)
-            with ib.for_range(0, cumsum_axis_len - 1, "_k") as _k:
-                k = _k + 1
-                cur_idx = base_idx + k * axis_mul_after
-                prev_idx = base_idx + (k - 1) * axis_mul_after
-                if exclusive == 0:
-                    out_buf[cur_idx] = out_buf[prev_idx] + maybe_cast(data_buf[cur_idx])
-                else:
-                    out_buf[cur_idx] = out_buf[prev_idx] + maybe_cast(data_buf[prev_idx])
-
-        return ib.get()
-
-    out_buf = decl_buffer(shape, dtype, "out_buf")
-
-    return extern(
-        [shape],
-        [data],
-        lambda ins, outs: gen_ir(ins[0], outs[0]),
-        dtype=dtype,
-        out_buffers=[out_buf],
-        name="cumsum_generic",
-        tag="cumsum_generic",
-    )
diff --git a/python/tvm/topi/scan.py b/python/tvm/topi/scan.py
new file mode 100644
index 0000000..f579673
--- /dev/null
+++ b/python/tvm/topi/scan.py
@@ -0,0 +1,236 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name
+"""Scan (cumulative binary) operators"""
+from typing import Callable, Optional
+
+import tvm
+
+from ..te import extern
+from ..tir import decl_buffer, generic, ir_builder
+from .math import cast
+from .utils import get_const_int, prod
+
+
+def scanop(
+    data: tvm.te.Tensor,
+    binop: Callable[["tvm.Expr", "tvm.Expr"], "tvm.Expr"],
+    identity_value: "tvm.Expr",
+    op_name: str,
+    axis: Optional[int] = None,
+    dtype: Optional[str] = None,
+    exclusive: Optional[bool] = None,
+) -> tvm.te.Tensor:
+    """Cumulative binary operator (scan) with similar axis behavior as 
np.cumsum and np.cumprod.
+
+    See cumprod and cumsum for an example of use.
+
+    E.g. if * is your binary operator and the input tensor is [1, 2, 3, 4] the output may be
+    [1, 1 * 2, 1 * 2 * 3, 1 * 2 * 3 * 4]
+
+    Parameters
+    ----------
+    data : tvm.te.Tensor
+        The input data to the operator.
+
+    binop: Callable (tvm.Expr, tvm.Expr) -> tvm.Expr
+        A binary operator which should be associative and commutative. E.g. if * is your
+        operator then a * (b * c) = (a * b) * c and a * b = b * a
+
+    identity_value: tvm.Expr
+        A value for the binary operation which provides the identity property. E.g. if * is
+        your operator and i is the identity_value then a * i = a for all a in the domain of
+        your operation.
+
+    axis : int, optional
+        Axis along which the operation is computed. The default (None) is to compute
+        the cumulative operation over the flattened array.
+
+    dtype : string, optional
+        Type of the returned array and of the accumulator in which the elements are computed.
+        If dtype is not specified, it defaults to the dtype of data.
+
+    exclusive : bool, optional
+        If True will return exclusive cumulative operation in which the first element is not
+        included. In other terms, if True, the j-th output element would be
+        the cumulative operation of the first (j-1) elements. Otherwise, it would be the
+        cumulative operation of the first j elements. The cumulative operation of zero elements
+        is assumed to be the identity_value.
+
+    Returns
+    -------
+    result : tvm.te.Tensor
+        The result has the same size as data, and the same shape as data if axis is not None.
+        If axis is None, the result is a 1-d array.
+    """
+    if dtype is None or dtype == "":
+        dtype = data.dtype
+
+    if exclusive is None:
+        exclusive = False
+
+    def maybe_cast(x):
+        if dtype != data.dtype:
+            return cast(x, dtype)
+        return x
+
+    axis_mul_before = 1
+    axis_mul_after = 1
+
+    if axis is None:
+        axis = 0
+        cumsum_axis_len = prod(data.shape)
+        shape = (cumsum_axis_len,)
+    else:
+        if not isinstance(axis, int):
+            axis = get_const_int(axis)
+
+        shape = data.shape
+        cumsum_axis_len = shape[axis]
+
+        if axis < 0:
+            axis = len(shape) + axis
+
+        for i, value in enumerate(shape, 0):
+            if i < axis:
+                axis_mul_before *= value
+            elif i > axis:
+                axis_mul_after *= value
+
+    def gen_ir(data_buf, out_buf):
+        ib = ir_builder.create()
+        data_buf = ib.buffer_ptr(data_buf)
+        out_buf = ib.buffer_ptr(out_buf)
+
+        with ib.for_range(0, axis_mul_before * axis_mul_after, "fused", 
kind="parallel") as fused:
+            i = fused // axis_mul_after
+            j = fused % axis_mul_after
+            base_idx = i * cumsum_axis_len * axis_mul_after + j
+            if exclusive:
+                out_buf[base_idx] = cast(identity_value, dtype)
+            else:
+                out_buf[base_idx] = maybe_cast(data_buf[base_idx])
+            with ib.for_range(0, cumsum_axis_len - 1, "_k") as _k:
+                k = _k + 1
+                cur_idx = base_idx + k * axis_mul_after
+                prev_idx = base_idx + (k - 1) * axis_mul_after
+                if exclusive:
+                    out_buf[cur_idx] = binop(out_buf[prev_idx], maybe_cast(data_buf[prev_idx]))
+                else:
+                    out_buf[cur_idx] = binop(out_buf[prev_idx], maybe_cast(data_buf[cur_idx]))
+
+        return ib.get()
+
+    out_buf = decl_buffer(shape, dtype, "out_buf")
+
+    return extern(
+        [shape],
+        [data],
+        lambda ins, outs: gen_ir(ins[0], outs[0]),
+        dtype=dtype,
+        out_buffers=[out_buf],
+        name=op_name,
+        tag=op_name,
+    )
+
+
+def cumsum(
+    data: tvm.te.Tensor,
+    axis: Optional[int] = None,
+    dtype: Optional[str] = None,
+    exclusive: Optional[bool] = None,
+) -> tvm.te.Tensor:
+    """Numpy style cumsum op. Return the cumulative sum of the elements along 
a given axis.
+
+    Parameters
+    ----------
+    data : tvm.te.Tensor
+        The input data to the operator.
+
+    axis : int, optional
+        Axis along which the cumulative sum is computed. The default (None) is to compute
+        the cumsum over the flattened array.
+
+    dtype : string, optional
+        Type of the returned array and of the accumulator in which the elements are summed.
+        If dtype is not specified, it defaults to the dtype of data.
+
+    exclusive : bool, optional
+        If True, will return exclusive sum in which the first element is not
+        included. In other terms, if True, the j-th output element would be
+        the sum of the first (j-1) elements. Otherwise, it would be the sum of
+        the first j elements.
+
+    Returns
+    -------
+    result : tvm.te.Tensor
+        The result has the same size as data, and the same shape as data if axis is not None.
+        If axis is None, the result is a 1-d array.
+    """
+    return scanop(
+        data=data,
+        binop=generic.add,
+        identity_value=0,
+        op_name="cumsum_generic",
+        axis=axis,
+        dtype=dtype,
+        exclusive=exclusive,
+    )
+
+
+def cumprod(
+    data: tvm.te.Tensor,
+    axis: Optional[int] = None,
+    dtype: Optional[str] = None,
+    exclusive: Optional[bool] = None,
+) -> tvm.te.Tensor:
+    """Numpy style cumprod op. Return the cumulative product of the elements 
along a given axis.
+
+    Parameters
+    ----------
+    data : tvm.te.Tensor
+        The input data to the operator.
+
+    axis : int, optional
+        Axis along which the cumulative product is computed. The default (None) is to compute
+        the cumprod over the flattened array.
+
+    dtype : string, optional
+        Type of the returned array and of the accumulator in which the elements are multiplied.
+        If dtype is not specified, it defaults to the dtype of data.
+
+    exclusive : bool, optional
+        If True, will return exclusive product in which the first element is not
+        included. In other terms, if True, the j-th output element would be
+        the product of the first (j-1) elements. Otherwise, it would be the product of
+        the first j elements.
+
+    Returns
+    -------
+    result : tvm.te.Tensor
+        The result has the same size as data, and the same shape as data if axis is not None.
+        If axis is None, the result is a 1-d array.
+    """
+    return scanop(
+        data=data,
+        binop=generic.multiply,
+        identity_value=1,
+        op_name="cumprod_generic",
+        axis=axis,
+        dtype=dtype,
+        exclusive=exclusive,
+    )
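
A smoke-test sketch for the new generic kernel (assumes an LLVM-enabled local
build; not part of the patch):

    import numpy as np
    import tvm
    from tvm import te, topi

    data = te.placeholder((6,), name="data", dtype="int32")
    out = topi.cumprod(data, axis=0, exclusive=True)
    s = topi.generic.schedule_extern([out])
    f = tvm.build(s, [data, out], "llvm")

    a = tvm.nd.array(np.arange(1, 7, dtype="int32"))
    b = tvm.nd.array(np.empty(6, dtype="int32"))
    f(a, b)
    print(b.asnumpy())  # expected: [  1   1   2   6  24 120]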
diff --git a/python/tvm/topi/unique.py b/python/tvm/topi/unique.py
index b4f27b3..e725655 100644
--- a/python/tvm/topi/unique.py
+++ b/python/tvm/topi/unique.py
@@ -18,7 +18,7 @@
 """Unique operator"""
 from tvm import te, tir
 from ..te import hybrid
-from .cumsum import cumsum
+from .scan import cumsum
 from .sort import sort, argsort
 
 
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index b65068b..6fb9f77 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -3772,20 +3772,20 @@ RELAY_REGISTER_OP("adv_index")
     .set_attr<TOpPattern>("TOpPattern", kInjective)
     .set_attr<FTVMCompute>("FTVMCompute", AdvIndexCompute);
 
-TVM_REGISTER_NODE_TYPE(CumsumAttrs);
+TVM_REGISTER_NODE_TYPE(ScanopAttrs);
 
-bool CumsumRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+bool ScanopRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                const TypeReporter& reporter) {
   // types: [data, output]
   ICHECK_EQ(types.size(), 2) << "Expects two types, one for the input and 
another for the output";
   const auto* data = types[0].as<TensorTypeNode>();
   if (data == nullptr) {
     ICHECK(types[0].as<IncompleteTypeNode>())
-        << "cumsum: expect input type to be TensorType but get " << types[0];
+        << "Scanop: expect input type to be TensorType but get " << types[0];
     return false;
   }
 
-  const auto* param = attrs.as<CumsumAttrs>();
+  const auto* param = attrs.as<ScanopAttrs>();
 
   auto dtype = param->dtype;
   if (dtype.is_void()) {
@@ -3805,8 +3805,8 @@ bool CumsumRel(const Array<Type>& types, int num_inputs, 
const Attrs& attrs,
   return true;
 }
 
-Expr MakeCumsum(Expr data, Integer axis, DataType dtype, Integer exclusive) {
-  auto attrs = make_object<CumsumAttrs>();
+Expr MakeCumsum(Expr data, Integer axis, DataType dtype, Bool exclusive) {
+  auto attrs = make_object<ScanopAttrs>();
   attrs->dtype = dtype;
   attrs->axis = axis;
   attrs->exclusive = exclusive;
@@ -3822,7 +3822,27 @@ RELAY_REGISTER_OP("cumsum")
     .set_num_inputs(1)
     .add_argument("data", "Tensor", "The input tensor.")
     .set_support_level(3)
-    .add_type_rel("Cumsum", CumsumRel)
+    .add_type_rel("Cumsum", ScanopRel)
+    .set_attr<TOpPattern>("TOpPattern", kOpaque);
+
+Expr MakeCumprod(Expr data, Integer axis, DataType dtype, Bool exclusive) {
+  auto attrs = make_object<ScanopAttrs>();
+  attrs->dtype = dtype;
+  attrs->axis = axis;
+  attrs->exclusive = exclusive;
+  static const Op& op = Op::Get("cumprod");
+  return Call(op, {data}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relay.op._make.cumprod").set_body_typed(MakeCumprod);
+
+RELAY_REGISTER_OP("cumprod")
+    .describe(
+        R"doc(Return the cumulative product of the elements along a given 
axis.)doc" TVM_ADD_FILELINE)
+    .set_num_inputs(1)
+    .add_argument("data", "Tensor", "The input tensor.")
+    .set_support_level(3)
+    .add_type_rel("Cumprod", ScanopRel)
     .set_attr<TOpPattern>("TOpPattern", kOpaque);
 
 TVM_REGISTER_NODE_TYPE(UniqueAttrs);
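
An end-to-end sketch of the newly registered op (assumes an LLVM-enabled build;
not part of the patch):

    import numpy as np
    from tvm import relay

    x = relay.var("x", shape=(3,), dtype="int32")
    func = relay.Function([x], relay.op.cumprod(x, exclusive=True))
    op_res = relay.create_executor(kind="graph").evaluate(func)(
        np.array([2, 3, 4], dtype="int32")
    )
    print(op_res.asnumpy())  # expected: [1 2 6]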
diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py
index d2a5090..7e443aa 100644
--- a/tests/python/relay/test_op_level3.py
+++ b/tests/python/relay/test_op_level3.py
@@ -16,16 +16,16 @@
 # under the License.
 """ Support level3 operator test cases.
 """
+from typing import Callable, Optional
+
 import numpy as np
 import pytest
 import tvm
-from tvm import te
-from tvm import relay
+import tvm.testing
+from tvm import relay, te
 from tvm.error import TVMError
 from tvm.relay import create_executor, transform
 from tvm.relay.testing import check_grad, run_infer_type
-from typing import Optional
-import tvm.testing
 
 
 def test_zeros_ones():
@@ -1758,12 +1758,27 @@ def test_adv_index():
     verify_adv_index((10, 5, 15), [(1, 2, 1), (1, 2, 7)])
 
 
-@tvm.testing.parametrize_targets
-def test_cumsum(target, ctx):
-    def verify_cumsum(data_np, np_out, axis=None, out_dtype=None, rtol=1e-5, atol=1e-5):
+# Helper for testing binop functions
+scanops_supported = {"cumsum": relay.op.cumsum, "cumprod": relay.op.cumprod}
+
+
+def run_binop_tests(
+    target, ctx, binop_type: str, gt_func: Callable[..., np.array], identity_value: int
+):
+    def assert_relay_scanop(
+        data_np: np.array,
+        np_out: np.array,
+        axis: int = None,
+        out_dtype: str = None,
+        rtol: float = 1e-5,
+        atol: float = 1e-5,
+        exclusive: bool = False,
+    ):
         inp = relay.var("data", relay.TensorType(data_np.shape, 
str(data_np.dtype)))
 
-        out = relay.op.cumsum(inp, axis, out_dtype)
+        if binop_type not in scanops_supported.keys():
+            raise ValueError(f"Unknown function {binop_type}. Options: 
{scanops_supported.keys()}")
+        out = scanops_supported[binop_type](inp, axis, out_dtype, exclusive=exclusive)
         func = relay.Function([inp], out)
 
         for kind in ["graph", "debug"]:
@@ -1772,24 +1787,48 @@ def test_cumsum(target, ctx):
             tvm.testing.assert_allclose(op_res.asnumpy(), np_out, rtol=rtol, atol=atol)
 
     data = np.array([2, 3, 0])
-    verify_cumsum(data, np.cumsum(data))
-    verify_cumsum(data, np.cumsum(data), out_dtype="int64")
+    assert_relay_scanop(data, gt_func(data))
+    assert_relay_scanop(data, gt_func(data), out_dtype="int64")
 
     data = np.random.randn(10, 10)
-    verify_cumsum(data, np.cumsum(data))
-    verify_cumsum(data, np.cumsum(data, axis=0), axis=0)
-    verify_cumsum(data, np.cumsum(data, axis=1), axis=1)
+    assert_relay_scanop(data, gt_func(data))
+    assert_relay_scanop(data, gt_func(data, axis=0), axis=0)
+    assert_relay_scanop(data, gt_func(data, axis=1), axis=1)
 
     data = np.random.randn(10, 5, 10).astype("float32")
-    verify_cumsum(data, np.cumsum(data), rtol=1e-4, atol=1e-4)
-    verify_cumsum(data, np.cumsum(data, axis=0), axis=0, rtol=1e-4, atol=1e-4)
-    verify_cumsum(data, np.cumsum(data, axis=1), axis=1, rtol=1e-4, atol=1e-4)
-    verify_cumsum(data, np.cumsum(data, axis=-1), axis=-1, rtol=1e-4, atol=1e-4)
+    assert_relay_scanop(data, gt_func(data), rtol=1e-4, atol=1e-4)
+    assert_relay_scanop(data, gt_func(data, axis=0), axis=0, rtol=1e-4, atol=1e-4)
+    assert_relay_scanop(data, gt_func(data, axis=1), axis=1, rtol=1e-4, atol=1e-4)
+    assert_relay_scanop(data, gt_func(data, axis=-1), axis=-1, rtol=1e-4, atol=1e-4)
 
     data = np.random.rand(10) > 0.5
     data = data.astype(np.int32)
-    verify_cumsum(data, np.cumsum(data, dtype=np.int32))
-    verify_cumsum(data, np.cumsum(data, dtype="int64"), out_dtype="int64")
+    assert_relay_scanop(data, gt_func(data, dtype=np.int32))
+    assert_relay_scanop(data, gt_func(data, dtype="int64"), out_dtype="int64")
+
+    # Test exclusivity operations
+    data = np.random.randint(-100, 100, size=(10, 10)).astype("int64")
+    expected_result = np.roll(gt_func(data), 1)
+    expected_result[0] = identity_value
+    assert_relay_scanop(data, expected_result, exclusive=True)
+
+    expected_result = np.roll(gt_func(data, axis=0), 1, axis=0)
+    expected_result[0, :] = identity_value
+    assert_relay_scanop(data, expected_result, exclusive=True, axis=0)
+
+    expected_result = np.roll(gt_func(data, axis=1), 1, axis=1)
+    expected_result[:, 0] = identity_value
+    assert_relay_scanop(data, expected_result, exclusive=True, axis=1)
+
+
+@tvm.testing.parametrize_targets
+def test_cumsum(target, ctx):
+    run_binop_tests(target, ctx, binop_type="cumsum", gt_func=np.cumsum, 
identity_value=0)
+
+
+@tvm.testing.parametrize_targets
+def test_cumprod(target, ctx):
+    run_binop_tests(target, ctx, binop_type="cumprod", gt_func=np.cumprod, 
identity_value=1)
 
 
 @tvm.testing.parametrize_targets
diff --git a/tests/python/topi/python/test_topi_cumsum.py b/tests/python/topi/python/test_topi_cumsum.py
deleted file mode 100644
index cfe5130..0000000
--- a/tests/python/topi/python/test_topi_cumsum.py
+++ /dev/null
@@ -1,79 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-import numpy as np
-import tvm
-import tvm.testing
-from tvm import topi
-import tvm.topi.testing
-
-
-@tvm.testing.parametrize_targets
-def test_cumsum(ctx, target):
-    def check_cumsum(np_ref, data, axis=None, dtype=None):
-        implementations = {
-            "generic": (lambda x: topi.cumsum(x, axis, dtype), 
topi.generic.schedule_extern),
-            "cuda": (lambda x: topi.cuda.cumsum(x, axis, dtype), 
topi.cuda.schedule_scan),
-            "nvptx": (lambda x: topi.cuda.cumsum(x, axis, dtype), 
topi.cuda.schedule_scan),
-            "vulkan": (lambda x: topi.cuda.cumsum(x, axis, dtype), 
topi.cuda.schedule_scan),
-            "metal": (lambda x: topi.cuda.cumsum(x, axis, dtype), 
topi.cuda.schedule_scan),
-        }
-        fcompute, fschedule = tvm.topi.testing.dispatch(target, implementations)
-        tvm.topi.testing.compare_numpy_tvm([data], np_ref, target, ctx, fcompute, fschedule)
-
-    data = np.array([2, 3, 0])
-    check_cumsum(np.cumsum(data), data)
-
-    data = np.random.rand(10) > 0.5
-    data = data.astype(np.int32)
-    check_cumsum(np.cumsum(data, dtype=np.int32), data)
-    check_cumsum(np.cumsum(data), data, dtype="int64")
-
-    data = np.random.rand(10) > 0.5
-    check_cumsum(np.cumsum(data, dtype=np.int32), data, dtype="int32")
-
-    for in_dtype in ["float32", "float64"]:
-        if target == "metal" and in_dtype == "float64":
-            # float64 is not supported in metal
-            continue
-        data = np.random.randn(10, 10).astype(in_dtype)
-        check_cumsum(np.cumsum(data), data)
-        check_cumsum(np.cumsum(data, axis=0), data, axis=0)
-        check_cumsum(np.cumsum(data, axis=1), data, axis=1)
-
-        data = np.random.randn(10, 5, 10).astype(in_dtype)
-        check_cumsum(np.cumsum(data), data)
-        check_cumsum(np.cumsum(data, axis=0), data, axis=0)
-        check_cumsum(np.cumsum(data, axis=1), data, axis=1)
-        check_cumsum(np.cumsum(data, axis=-1), data, axis=-1)
-
-    for in_dtype in ["int32", "int64"]:
-        data = np.random.randint(-100, 100, size=(100, 100)).astype(in_dtype)
-        check_cumsum(np.cumsum(data, dtype=in_dtype), data)
-        check_cumsum(np.cumsum(data), data, dtype="int64")
-        check_cumsum(np.cumsum(data, axis=0, dtype=in_dtype), data, axis=0)
-        check_cumsum(np.cumsum(data, axis=1, dtype=in_dtype), data, axis=1)
-
-        data = np.random.randint(1 << 30, (1 << 31) - 1, 
size=(100)).astype(in_dtype)
-        check_cumsum(np.cumsum(data), data, dtype="int64")
-
-
-if __name__ == "__main__":
-    test_cumsum(tvm.context("cpu"), tvm.target.Target("llvm"))
-    test_cumsum(tvm.context("cuda"), tvm.target.Target("cuda"))
-    test_cumsum(tvm.context("nvptx"), tvm.target.Target("nvptx"))
-    test_cumsum(tvm.context("vulkan"), tvm.target.Target("vulkan"))
-    test_cumsum(tvm.context("metal"), tvm.target.Target("metal"))
diff --git a/tests/python/topi/python/test_topi_scan.py b/tests/python/topi/python/test_topi_scan.py
new file mode 100644
index 0000000..020fde5
--- /dev/null
+++ b/tests/python/topi/python/test_topi_scan.py
@@ -0,0 +1,144 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from typing import Callable
+
+import numpy as np
+import tvm
+import tvm.testing
+import tvm.topi.testing
+from tvm import topi
+
+topi_funcs = {
+    "cumsum": {"generic": topi.cumsum, "cuda": topi.cuda.cumsum},
+    "cumprod": {"generic": topi.cumprod, "cuda": topi.cuda.cumprod},
+}
+
+identity_value = {"cumsum": 0, "cumprod": 1}
+
+
+def get_implementations(name, axis, dtype, exclusive):
+    topi_func_generic = topi_funcs[name]["generic"]
+    topi_func_cuda = topi_funcs[name]["cuda"]
+
+    return {
+        "generic": (
+            lambda x: topi_func_generic(x, axis, dtype, exclusive=exclusive),
+            topi.generic.schedule_extern,
+        ),
+        "cuda": (
+            lambda x: topi_func_cuda(x, axis, dtype, exclusive=exclusive),
+            topi.cuda.schedule_scan,
+        ),
+        "nvptx": (
+            lambda x: topi_func_cuda(x, axis, dtype, exclusive=exclusive),
+            topi.cuda.schedule_scan,
+        ),
+        "vulkan": (
+            lambda x: topi_func_cuda(x, axis, dtype, exclusive=exclusive),
+            topi.cuda.schedule_scan,
+        ),
+        "metal": (
+            lambda x: topi_func_cuda(x, axis, dtype, exclusive=exclusive),
+            topi.cuda.schedule_scan,
+        ),
+    }
+
+
+def _run_tests(
+    ctx,
+    target,
+    op_name: str = "cumsum",
+    gt_func: Callable[..., np.array] = np.cumsum,
+):
+    def check_scan(np_ref, data, axis=None, dtype=None, exclusive=False):
+        implementations = get_implementations(op_name, axis, dtype, exclusive)
+        fcompute, fschedule = tvm.topi.testing.dispatch(target, implementations)
+        tvm.topi.testing.compare_numpy_tvm([data], np_ref, target, ctx, fcompute, fschedule)
+
+    data = np.array([2, 3, 0])
+    check_scan(gt_func(data), data)
+
+    data = np.random.rand(10) > 0.5
+    data = data.astype(np.int32)
+    check_scan(gt_func(data, dtype=np.int32), data)
+    check_scan(gt_func(data), data, dtype="int64")
+
+    data = np.random.rand(10) > 0.5
+    check_scan(gt_func(data, dtype=np.int32), data, dtype="int32")
+
+    for in_dtype in ["float32", "float64"]:
+        if target == "metal" and in_dtype == "float64":
+            # float64 is not supported in metal
+            continue
+        data = np.random.randn(10, 10).astype(in_dtype)
+        check_scan(gt_func(data), data)
+        check_scan(gt_func(data, axis=0), data, axis=0)
+        check_scan(gt_func(data, axis=1), data, axis=1)
+
+        data = np.random.randn(10, 5, 10).astype(in_dtype)
+        check_scan(gt_func(data), data)
+        check_scan(gt_func(data, axis=0), data, axis=0)
+        check_scan(gt_func(data, axis=1), data, axis=1)
+        check_scan(gt_func(data, axis=-1), data, axis=-1)
+
+    for in_dtype in ["int32", "int64"]:
+        data = np.random.randint(-100, 100, size=(100, 100)).astype(in_dtype)
+        check_scan(gt_func(data, dtype=in_dtype), data)
+        check_scan(gt_func(data), data, dtype="int64")
+        check_scan(gt_func(data, axis=0, dtype=in_dtype), data, axis=0)
+        check_scan(gt_func(data, axis=1, dtype=in_dtype), data, axis=1)
+
+        data = np.random.randint(1 << 30, (1 << 31) - 1, 
size=(100)).astype(in_dtype)
+        check_scan(gt_func(data), data, dtype="int64")
+
+    data = np.random.randint(-100, 100, size=(100, 100)).astype("int64")
+
+    expected_result = np.roll(gt_func(data), 1)
+    expected_result[0] = identity_value[op_name]
+    check_scan(expected_result, data, dtype="int64", exclusive=True)
+
+    expected_result = np.roll(gt_func(data, axis=0, dtype="int64"), 1, axis=0)
+    expected_result[0, :] = identity_value[op_name]
+    check_scan(expected_result, data, axis=0, exclusive=True)
+
+    expected_result = np.roll(gt_func(data, axis=1, dtype="int64"), 1, axis=1)
+    expected_result[:, 0] = identity_value[op_name]
+    check_scan(expected_result, data, axis=1, exclusive=True)
+
+
+@tvm.testing.parametrize_targets
+def test_cumsum(ctx, target):
+    _run_tests(ctx, target, op_name="cumsum", gt_func=np.cumsum)
+
+
+@tvm.testing.parametrize_targets
+def test_cumprod(ctx, target):
+    _run_tests(ctx, target, op_name="cumprod", gt_func=np.cumprod)
+
+
+if __name__ == "__main__":
+    test_cumsum(tvm.context("cpu"), tvm.target.Target("llvm"))
+    test_cumsum(tvm.context("cuda"), tvm.target.Target("cuda"))
+    test_cumsum(tvm.context("nvptx"), tvm.target.Target("nvptx"))
+    test_cumsum(tvm.context("vulkan"), tvm.target.Target("vulkan"))
+    test_cumsum(tvm.context("metal"), tvm.target.Target("metal"))
+
+    test_cumprod(tvm.context("cpu"), tvm.target.Target("llvm"))
+    test_cumprod(tvm.context("cuda"), tvm.target.Target("cuda"))
+    test_cumprod(tvm.context("nvptx"), tvm.target.Target("nvptx"))
+    test_cumprod(tvm.context("vulkan"), tvm.target.Target("vulkan"))
+    test_cumprod(tvm.context("metal"), tvm.target.Target("metal"))
