This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push: new 8e23806 [Topi, Relay] Add cumprod (#7722) 8e23806 is described below commit 8e23806d2d522b71979d0a2730b38cc5c3bf6185 Author: AndrewZhaoLuo <andrew.zhao....@gmail.com> AuthorDate: Wed Mar 24 21:25:18 2021 -0700 [Topi, Relay] Add cumprod (#7722) * make cumbinop, refactor cumsum, add cumprod * cumsum exclusive test * Add cumprod + flesh out cumsum tests add cumprod and tests reinstate tests rethink * add rudimentary scan implementation * add attributes of cumprod node * add cumprod strategy * add cuda strategy * python relay node construction * change attrs to be reusuable * add cumprod nodes * complete tests * Fix some typos about sum --> prod typos fix sum -> prod more typos more typo fixes more typos add doc strings * Use Bool instead of int to represent exclusive make exclusive a bool up and down stack fix x fix bool err it is a bool now fix fix thing formatting to pass linter lint python cumprod pylint fix attribute fix ordering add exclusivity tests for end to end fix things cuda identity_value * Overall improve formatting, add doc message corrections simplify construction clang-format more tests undo simpler construction due to function passing stuff fix docs more exclusive doc changes more fixins" * merge cumsum and cumprod to scan, merge tests fix stuff * remove other mentions of cumbinop -> scanop * lint formatting Co-authored-by: Andrew Zhao Luo <andrewzhaoluo@Andrews-MacBook-Pro.local> --- include/tvm/relay/attrs/transform.h | 14 +- python/tvm/relay/op/_transform.py | 19 ++- python/tvm/relay/op/strategy/cuda.py | 19 ++- python/tvm/relay/op/strategy/generic.py | 29 +++- python/tvm/relay/op/transform.py | 65 +++++++- python/tvm/topi/__init__.py | 2 +- python/tvm/topi/cuda/scan.py | 196 +++++++++++++++++++--- python/tvm/topi/cumsum.py | 121 -------------- python/tvm/topi/scan.py | 236 +++++++++++++++++++++++++++ python/tvm/topi/unique.py | 2 +- src/relay/op/tensor/transform.cc | 34 +++- tests/python/relay/test_op_level3.py | 77 ++++++--- tests/python/topi/python/test_topi_cumsum.py | 79 --------- tests/python/topi/python/test_topi_scan.py | 144 ++++++++++++++++ 14 files changed, 758 insertions(+), 279 deletions(-) diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index ff344f5..a5544c8 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -438,17 +438,19 @@ struct MatrixSetDiagAttrs : public tvm::AttrsNode<MatrixSetDiagAttrs> { } }; // struct MatrixSetDiagAttrs -/*! \brief Attributes used in cumsum operator */ -struct CumsumAttrs : public tvm::AttrsNode<CumsumAttrs> { +/*! 
\brief Attributes used in cumsum and cumprod operator */ +struct ScanopAttrs : public tvm::AttrsNode<ScanopAttrs> { Integer axis; DataType dtype; - Integer exclusive; - TVM_DECLARE_ATTRS(CumsumAttrs, "relay.attrs.CumsumAttrs") { - TVM_ATTR_FIELD(axis).describe("The axis to sum over").set_default(NullValue<Integer>()); + Bool exclusive = Bool(false); + TVM_DECLARE_ATTRS(ScanopAttrs, "relay.attrs.ScanopAttrs") { + TVM_ATTR_FIELD(axis).describe("The axis to operate over").set_default(NullValue<Integer>()); TVM_ATTR_FIELD(dtype).describe("Output data type").set_default(NullValue<DataType>()); + + // Default is 0 which is "false" TVM_ATTR_FIELD(exclusive) .describe("The first element is not included") - .set_default(NullValue<Integer>()); + .set_default(Bool(false)); } }; diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index e90263d..1626283 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -19,16 +19,17 @@ # pylint: disable=too-many-local-variables, too-many-arguments, no-else-return from __future__ import absolute_import + import tvm -from tvm import te -from tvm.te.hybrid import script +from tvm import te, topi from tvm.runtime import convert -from tvm import topi +from tvm.te.hybrid import script from tvm.topi.utils import get_const_int, get_const_tuple + from . import op as _reg from . import strategy -from .op import OpPattern from ._tensor import elemwise_shape_func +from .op import OpPattern _reg.register_broadcast_schedule("broadcast_to") _reg.register_broadcast_schedule("broadcast_to_like") @@ -159,6 +160,16 @@ def compute_cumsum(attrs, inputs, output_type): _reg.register_strategy("cumsum", strategy.cumsum_strategy) _reg.register_shape_func("cumsum", False, elemwise_shape_func) +# cumprod +@_reg.register_compute("cumprod") +def compute_cumprod(attrs, inputs, output_type): + """Compute definition of cumprod""" + return [topi.cumprod(inputs[0], attrs.axis, attrs.dtype, attrs.exclusive)] + + +_reg.register_strategy("cumprod", strategy.cumprod_strategy) +_reg.register_shape_func("cumprod", False, elemwise_shape_func) + @_reg.register_compute("unique") def compute_unique(attrs, inputs, output_type): diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py index e0d0f16..1a67425 100644 --- a/python/tvm/relay/op/strategy/cuda.py +++ b/python/tvm/relay/op/strategy/cuda.py @@ -18,11 +18,12 @@ # pylint: disable=invalid-name,unused-argument,wildcard-import,unused-wildcard-import from tvm import topi from tvm.auto_scheduler import is_auto_scheduler_enabled -from tvm.te import SpecializedCondition from tvm.contrib import nvcc from tvm.contrib.thrust import can_use_thrust -from .generic import * +from tvm.te import SpecializedCondition + from .. 
import op as _op +from .generic import * @schedule_injective.register(["cuda", "gpu"]) @@ -1017,13 +1018,25 @@ def cumsum_strategy_cuda(attrs, inputs, out_type, target): """cumsum cuda strategy""" strategy = _op.OpStrategy() strategy.add_implementation( - wrap_compute_cumsum(topi.cuda.cumsum), + wrap_compute_scanop(topi.cuda.cumsum), wrap_topi_schedule(topi.cuda.schedule_scan), name="cumsum.cuda", ) return strategy +@cumprod_strategy.register(["cuda", "gpu"]) +def cumprod_strategy_cuda(attrs, inputs, out_type, target): + """cumprod cuda strategy""" + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_compute_scanop(topi.cuda.cumprod), + wrap_topi_schedule(topi.cuda.schedule_scan), + name="cumprod.cuda", + ) + return strategy + + @unique_strategy.register(["cuda", "gpu"]) def unique_strategy_cuda(attrs, inputs, out_type, target): """unique cuda strategy""" diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index 04f2564..322a360 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -17,11 +17,12 @@ """Definition of generic operator strategy.""" # pylint: disable=invalid-name,unused-argument import logging - import re -from tvm import topi, _ffi, te, ir -from tvm.topi.utils import get_const_int, get_const_float, get_const_tuple, get_float_tuple + +from tvm import _ffi, ir, te, topi from tvm.target import generic_func, override_native_generic_func +from tvm.topi.utils import get_const_float, get_const_int, get_const_tuple, get_float_tuple + from .. import op as _op logger = logging.getLogger("strategy") @@ -1463,13 +1464,13 @@ def threefry_split_strategy(attrs, inputs, out_type, target): return strategy -def wrap_compute_cumsum(topi_compute): - """Wrap cumsum topi compute""" +def wrap_compute_scanop(topi_compute): + """Wrap scanop style topi compute""" - def _compute_cumsum(attrs, inputs, _): + def _compute_scanop(attrs, inputs, _): return [topi_compute(inputs[0], attrs.axis, attrs.dtype, attrs.exclusive)] - return _compute_cumsum + return _compute_scanop @override_native_generic_func("cumsum_strategy") @@ -1477,13 +1478,25 @@ def cumsum_strategy(attrs, inputs, out_type, target): """cumsum generic strategy""" strategy = _op.OpStrategy() strategy.add_implementation( - wrap_compute_cumsum(topi.cumsum), + wrap_compute_scanop(topi.cumsum), wrap_topi_schedule(topi.generic.schedule_extern), name="cumsum.generic", ) return strategy +@override_native_generic_func("cumprod_strategy") +def cumprod_strategy(attrs, inputs, out_type, target): + """cumprod generic strategy""" + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_compute_scanop(topi.cumprod), + wrap_topi_schedule(topi.generic.schedule_extern), + name="cumprod.generic", + ) + return strategy + + def wrap_compute_unique(topi_compute): """Wrap unique topi compute""" diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index df0ae76..f94a00d 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -18,11 +18,11 @@ # pylint: disable=import-outside-toplevel """Transform operators.""" +from ...tir import expr as _expr +from ..expr import Constant, Expr, Tuple, TupleWrapper, const from . 
import _make from .dyn import _make as _dyn_make from .tensor import shape_of -from ..expr import TupleWrapper, const, Constant, Expr, Tuple -from ...tir import expr as _expr def cast(data, dtype): @@ -1539,9 +1539,9 @@ def cumsum(data, axis=None, dtype=None, exclusive=None): Type of the returned array and of the accumulator in which the elements are summed. If dtype is not specified, it defaults to the dtype of data. - exclusive : int, optional - If set to 1 will return exclusive sum in which the first element is not - included. In other terms, if set to 1, the j-th output element would be + exclusive : bool, optional + If true will return exclusive sum in which the first element is not + included. In other terms, if true, the j-th output element would be the sum of the first (j-1) elements. Otherwise, it would be the sum of the first j elements. @@ -1577,6 +1577,61 @@ def cumsum(data, axis=None, dtype=None, exclusive=None): return _make.cumsum(data, axis, dtype, exclusive) +def cumprod(data, axis=None, dtype=None, exclusive=None): + """Numpy style cumprod op. Return the cumulative inclusive product of the elements along + a given axis. + + Parameters + ---------- + data : relay.Expr + The input data to the operator. + + axis : int, optional + Axis along which the cumulative product is computed. The default (None) is to compute + the cumprod over the flattened array. + + dtype : string, optional + Type of the returned array and of the accumulator in which the elements are multiplied. + If dtype is not specified, it defaults to the dtype of data. + + exclusive : bool, optional + If true will return exclusive product in which the first element is not + included. In other terms, if true, the j-th output element would be + the product of the first (j-1) elements. Otherwise, it would be the product of + the first j elements. The product of zero elements will be 1. + + Returns + ------- + result : relay.Expr + The result has the same size as data, and the same shape as data if axis is not None. + If axis is None, the result is a 1-d array. + + Examples + -------- + .. code-block:: python + a = [[1,2,3], [4,5,6]] + + cumprod(a) # if axis is not provided, cumprod is done over the flattened input. + -> [ 1, 2, 6, 24, 120, 720] + + cumprod(a, dtype="float32") + -> [ 1., 2., 6., 24., 120., 720.] + + cumprod(a, axis=0) # multiply over rows for each of the 3 columns + -> [[1, 2, 3], + [4, 10, 18]] + + cumprod(a, axis=1) + -> [[ 1, 2, 6], + [ 4, 20, 120]] + + a = [1, 1, 1, 0, 1, 1, 0] # a is a boolean array + cumprod(a, dtype=int32) # dtype should be provided to get the expected results + -> [1, 1, 1, 0, 0, 0, 0] + """ + return _make.cumprod(data, axis, dtype, exclusive) + + def unique(data, is_sorted=True, return_counts=False): """ Find the unique elements of a 1-D tensor. Please note `output` and `counts` are all padded to diff --git a/python/tvm/topi/__init__.py b/python/tvm/topi/__init__.py index c196b33..ef2c5c1 100644 --- a/python/tvm/topi/__init__.py +++ b/python/tvm/topi/__init__.py @@ -42,7 +42,7 @@ from .sparse_fill_empty_rows import * from .sparse_reshape import * from .scatter_add import * from .argwhere import * -from .cumsum import * +from .scan import * from .einsum import * from .unique import * from . import generic diff --git a/python/tvm/topi/cuda/scan.py b/python/tvm/topi/cuda/scan.py index 84ab5dc..3240ebc 100644 --- a/python/tvm/topi/cuda/scan.py +++ b/python/tvm/topi/cuda/scan.py @@ -16,13 +16,16 @@ # under the License. 
# pylint: disable=invalid-name, too-many-locals, too-many-statements "Scan related operators" +from typing import Callable, Optional, Union + import tvm from tvm import te -from tvm.contrib.thrust import can_use_thrust, can_use_rocthrust -from ..transform import expand_dims, squeeze, transpose, reshape -from ..utils import ceil_div, swap, prod, get_const_int -from ..math import cast +from tvm.contrib.thrust import can_use_rocthrust, can_use_thrust + from .. import tag +from ..math import cast +from ..transform import expand_dims, reshape, squeeze, transpose +from ..utils import ceil_div, get_const_int, prod, swap from .injective import schedule_injective_from_existing @@ -32,7 +35,7 @@ def _get_thrust_func_name(tvmop): return tvmop_to_thrust_func_name[tvmop] -def exclusive_scan_ir(data, output, reduction=None, binop=tvm.tir.generic.add): +def exclusive_scan_ir(data, output, reduction=None, binop=tvm.tir.generic.add, identity_value=0): """Low level IR to do exclusive sum scan along rows of 2D input. Parameters @@ -50,6 +53,11 @@ def exclusive_scan_ir(data, output, reduction=None, binop=tvm.tir.generic.add): A binary associative op to use for scan. The function takes two TIR expressions and produce a new TIR expression. By default it uses tvm.tir.generic.add to compute prefix sum. + + identity_value: int or float + A value for the binary operation which provides the identity property. E.g. if * is + your operator and i is the identity_value then a * i = a for all a in the domain of + your operation. """ batch_size = prod(data.shape[:-1]) @@ -134,7 +142,7 @@ def exclusive_scan_ir(data, output, reduction=None, binop=tvm.tir.generic.add): with ib.if_scope(bx < batch_size): if reduction is not None: reduction[bx] = output[(bx + 1) * scan_axis_size - 1] - output[(bx + 1) * scan_axis_size - 1] = cast(0, out_dtype) + output[(bx + 1) * scan_axis_size - 1] = cast(identity_value, out_dtype) with ib.for_range(0, lim, dtype="int64") as l2_width: width = 2 << (lim - l2_width - 1) @@ -309,7 +317,12 @@ def scan_thrust( def exclusive_scan( - data, axis=-1, return_reduction=False, output_dtype=None, binop=tvm.tir.generic.add + data, + axis=-1, + return_reduction=False, + output_dtype=None, + binop=tvm.tir.generic.add, + identity_value=0, ): """Do exclusive scan on 1D or multidimensional input. @@ -335,6 +348,11 @@ def exclusive_scan( and produce a new TIR expression. By default it uses tvm.tir.generic.add to compute prefix sum. + identity_value: int or float + A value for the binary operation which provides the identity property. E.g. if * is + your operator and i is the identity_value then a * i = a for all a in the domain of + your operation. 
+ Returns ------- output : tvm.te.Tensor @@ -347,9 +365,15 @@ def exclusive_scan( def do_scan(data, output_dtype): target = tvm.target.Target.current() - if target and ( - can_use_thrust(target, "tvm.contrib.thrust.sum_scan") - or can_use_rocthrust(target, "tvm.contrib.thrust.sum_scan") + + # TODO: add support for a prod_scan + if ( + target + and binop == tvm.tir.generic.add + and ( + can_use_thrust(target, "tvm.contrib.thrust.sum_scan") + or can_use_rocthrust(target, "tvm.contrib.thrust.sum_scan") + ) ): return scan_thrust( data, output_dtype, exclusive=True, return_reduction=return_reduction, binop=binop @@ -366,7 +390,9 @@ def exclusive_scan( output, reduction = te.extern( [data.shape, data.shape[:-1]], [data], - lambda ins, outs: exclusive_scan_ir(ins[0], outs[0], outs[1], binop=binop), + lambda ins, outs: exclusive_scan_ir( + ins[0], outs[0], outs[1], binop=binop, identity_value=identity_value + ), dtype=[data.dtype, output_dtype], in_buffers=[data_buf], name="exclusive_scan", @@ -376,7 +402,9 @@ def exclusive_scan( output = te.extern( [data.shape], [data], - lambda ins, outs: exclusive_scan_ir(ins[0], outs[0], binop=binop), + lambda ins, outs: exclusive_scan_ir( + ins[0], outs[0], binop=binop, identity_value=identity_value + ), dtype=[output_dtype], in_buffers=[data_buf], out_buffers=[output_buf], @@ -423,7 +451,7 @@ def exclusive_scan( return output -def inclusive_scan(data, axis=-1, output_dtype=None, binop=tvm.tir.generic.add): +def inclusive_scan(data, axis=-1, output_dtype=None, binop=tvm.tir.generic.add, identity_value=0): """Do inclusive scan on 1D or multidimensional input. Parameters @@ -442,12 +470,19 @@ def inclusive_scan(data, axis=-1, output_dtype=None, binop=tvm.tir.generic.add): and produce a new TIR expression. By default it uses tvm.tir.generic.add to compute prefix sum. + identity_value: int or float + A value for the binary operation which provides the identity property. E.g. if * is + your operator and i is the identity_value then a * i = a for all a in the domain of + your operation. + Returns ------- output : tvm.te.Tensor A N-D tensor of the same rank N as the input data. """ - ex_scan = exclusive_scan(data, axis, output_dtype=output_dtype, binop=binop) + ex_scan = exclusive_scan( + data, axis, output_dtype=output_dtype, binop=binop, identity_value=identity_value + ) if output_dtype is not None and data.dtype != output_dtype and output_dtype != "": data = cast(data, output_dtype) @@ -486,7 +521,74 @@ def schedule_scan(outs): return s -def cumsum(data, axis=None, dtype=None, exclusive=None): +def scanop( + data: tvm.te.Tensor, + binop: Callable[["tvm.Expr", "tvm.Expr"], "tvm.Expr"], + identity_value: Union[float, int], + axis: Optional[int] = None, + dtype: Optional[str] = None, + exclusive: Optional[bool] = None, +) -> tvm.te.Tensor: + """Cumulative binary operator (scan) with similar axis behavior as np.cumsum and np.cumprod. + + See cumprod and cumsum for an example of use. + + E.g. if * is your binary operator and the input tensor is [1, 2, 3, 4] the output may be + [1, 1 * 2, 1 * 2 * 3, 1 * 2 * 3 * 4] + + Parameters + ---------- + data : tvm.te.Tensor + The input data to the operator. + + binop: Callable (tvm.Expr, tvm.Expr) -> tvm.Expr + A binary operator which should be associative and commutative. E.g. if * is your + operator then a * (b * c) = (a * b) * c and a * b = b * a + + identity_value: int or float + A value for the binary operation which provides the identity property. E.g. 
if * is + your operator and i is the identity_value then a * i = a for all a in the domain of + your operation. + + axis : int, optional + Axis along which the operation is computed. The default (None) is to compute + the cumulative operation over the flattened array. + + dtype : string, optional + Type of the returned array and of the accumulator in which the elements are computed. + If dtype is not specified, it defaults to the dtype of data. + + exclusive : bool, optional + If true will return exclusive cumulative operation in which the first element is not + included. In other terms, if true, the j-th output element would be + the cumulative operation of the first (j-1) elements. Otherwise, it would be the + cumulative operation of the first j elements. + + Returns + ------- + result : tvm.te.Tensor + The result has the same size as data, and the same shape as data if axis is not None. + If axis is None, the result is a 1-d array. + """ + if axis is None: + axis = 0 + data = reshape(data, (prod(data.shape),)) + axis = get_const_int(axis) + if exclusive is not None and exclusive: + return exclusive_scan( + data, axis, output_dtype=dtype, binop=binop, identity_value=identity_value + ) + return inclusive_scan( + data, axis, output_dtype=dtype, binop=binop, identity_value=identity_value + ) + + +def cumsum( + data: tvm.te.Tensor, + axis: Optional[int] = None, + dtype: Optional[int] = None, + exclusive: Optional[bool] = None, +) -> tvm.te.Tensor: """Numpy style cumsum op. Return the cumulative sum of the elements along a given axis. Parameters @@ -502,9 +604,9 @@ def cumsum(data, axis=None, dtype=None, exclusive=None): Type of the returned array and of the accumulator in which the elements are summed. If dtype is not specified, it defaults to the dtype of data. - exclusive : int, optional - If set to 1 will return exclusive sum in which the first element is not - included. In other terms, if set to 1, the j-th output element would be + exclusive : bool, optional + If true will return exclusive sum in which the first element is not + included. In other terms, if true, the j-th output element would be the sum of the first (j-1) elements. Otherwise, it would be the sum of the first j elements. @@ -514,10 +616,54 @@ def cumsum(data, axis=None, dtype=None, exclusive=None): The result has the same size as data, and the same shape as data if axis is not None. If axis is None, the result is a 1-d array. """ - if axis is None: - axis = 0 - data = reshape(data, (prod(data.shape),)) - axis = get_const_int(axis) - if exclusive is not None and exclusive != 0: - return exclusive_scan(data, axis, output_dtype=dtype, binop=tvm.tir.generic.add) - return inclusive_scan(data, axis, output_dtype=dtype, binop=tvm.tir.generic.add) + return scanop( + data=data, + binop=tvm.tir.generic.add, + identity_value=0, + axis=axis, + dtype=dtype, + exclusive=exclusive, + ) + + +def cumprod( + data: tvm.te.Tensor, + axis: Optional[int] = None, + dtype: Optional[int] = None, + exclusive: Optional[bool] = None, +): + """Numpy style cumprod op. Return the cumulative product of the elements along a given axis. + + Parameters + ---------- + data : tvm.te.Tensor + The input data to the operator. + + axis : int, optional + Axis along which the cumulative product is computed. The default (None) is to compute + the cumproduct over the flattened array. + + dtype : string, optional + Type of the returned array and of the accumulator in which the elements are multiplied. + If dtype is not specified, it defaults to the dtype of data. 
+ + exclusive : bool, optional + If True, will return exclusive product in which the first element is not + included. In other terms, if True, the j-th output element would be + the product of the first (j-1) elements. Otherwise, it would be the product of + the first j elements. + + Returns + ------- + result : tvm.te.Tensor + The result has the same size as data, and the same shape as data if axis is not None. + If axis is None, the result is a 1-d array. + """ + return scanop( + data=data, + binop=tvm.tir.generic.multiply, + identity_value=1, + axis=axis, + dtype=dtype, + exclusive=exclusive, + ) diff --git a/python/tvm/topi/cumsum.py b/python/tvm/topi/cumsum.py deleted file mode 100644 index 2013a35..0000000 --- a/python/tvm/topi/cumsum.py +++ /dev/null @@ -1,121 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=invalid-name -"""Cumsum operator""" -from ..tir import decl_buffer, ir_builder -from ..te import extern -from .utils import prod, get_const_int -from .math import cast - - -def cumsum(data, axis=None, dtype=None, exclusive=None): - """Numpy style cumsum op. Return the cumulative sum of the elements along a given axis. - - Parameters - ---------- - data : tvm.te.Tensor - The input data to the operator. - - axis : int, optional - Axis along which the cumulative sum is computed. The default (None) is to compute - the cumsum over the flattened array. - - dtype : string, optional - Type of the returned array and of the accumulator in which the elements are summed. - If dtype is not specified, it defaults to the dtype of data. - - exclusive : int, optional - If set to 1 will return exclusive sum in which the first element is not - included. In other terms, if set to 1, the j-th output element would be - the sum of the first (j-1) elements. Otherwise, it would be the sum of - the first j elements. - - Returns - ------- - result : tvm.te.Tensor - The result has the same size as data, and the same shape as data if axis is not None. - If axis is None, the result is a 1-d array. 
- """ - if dtype is None or dtype == "": - dtype = data.dtype - - def maybe_cast(x): - if dtype != data.dtype: - return cast(x, dtype) - return x - - axis_mul_before = 1 - axis_mul_after = 1 - - if axis is None: - axis = 0 - cumsum_axis_len = prod(data.shape) - shape = (cumsum_axis_len,) - else: - if not isinstance(axis, int): - axis = get_const_int(axis) - - shape = data.shape - cumsum_axis_len = shape[axis] - - if axis < 0: - axis = len(shape) + axis - - for i, value in enumerate(shape, 0): - if i < axis: - axis_mul_before *= value - elif i > axis: - axis_mul_after *= value - - if exclusive is None: - exclusive = 0 - - def gen_ir(data_buf, out_buf): - ib = ir_builder.create() - data_buf = ib.buffer_ptr(data_buf) - out_buf = ib.buffer_ptr(out_buf) - - with ib.for_range(0, axis_mul_before * axis_mul_after, "fused", kind="parallel") as fused: - i = fused // axis_mul_after - j = fused % axis_mul_after - base_idx = i * cumsum_axis_len * axis_mul_after + j - if exclusive == 0: - out_buf[base_idx] = maybe_cast(data_buf[base_idx]) - else: - out_buf[base_idx] = cast(0, dtype) - with ib.for_range(0, cumsum_axis_len - 1, "_k") as _k: - k = _k + 1 - cur_idx = base_idx + k * axis_mul_after - prev_idx = base_idx + (k - 1) * axis_mul_after - if exclusive == 0: - out_buf[cur_idx] = out_buf[prev_idx] + maybe_cast(data_buf[cur_idx]) - else: - out_buf[cur_idx] = out_buf[prev_idx] + maybe_cast(data_buf[prev_idx]) - - return ib.get() - - out_buf = decl_buffer(shape, dtype, "out_buf") - - return extern( - [shape], - [data], - lambda ins, outs: gen_ir(ins[0], outs[0]), - dtype=dtype, - out_buffers=[out_buf], - name="cumsum_generic", - tag="cumsum_generic", - ) diff --git a/python/tvm/topi/scan.py b/python/tvm/topi/scan.py new file mode 100644 index 0000000..f579673 --- /dev/null +++ b/python/tvm/topi/scan.py @@ -0,0 +1,236 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""Scan (cumulative binary) operators""" +from typing import Callable, Optional + +import tvm + +from ..te import extern +from ..tir import decl_buffer, generic, ir_builder +from .math import cast +from .utils import get_const_int, prod + + +def scanop( + data: tvm.te.Tensor, + binop: Callable[["tvm.Expr", "tvm.Expr"], "tvm.Expr"], + identity_value: "tvm.Expr", + op_name: str, + axis: Optional[int] = None, + dtype: Optional[str] = None, + exclusive: Optional[bool] = None, +) -> tvm.te.Tensor: + """Cumulative binary operator (scan) with similar axis behavior as np.cumsum and np.cumprod. + + See cumprod and cumsum for an example of use. + + E.g. if * is your binary operator and the input tensor is [1, 2, 3, 4] the output may be + [1, 1 * 2, 1 * 2 * 3, 1 * 2 * 3 * 4] + + Parameters + ---------- + data : tvm.te.Tensor + The input data to the operator. 
+ + binop: Callable (tvm.Expr, tvm.Expr) -> tvm.Expr + A binary operator which should be associative and commutative. E.g. if * is your + operator then a * (b * c) = (a * b) * c and a * b = b * a + + identity_value: tvm.Expr + A value for the binary operation which provides the identity property. E.g. if * is + your operator and i is the identity_value then a * i = a for all a in the domain of + your operation. + + axis : int, optional + Axis along which the operation is computed. The default (None) is to compute + the cumulative operation over the flattened array. + + dtype : string, optional + Type of the returned array and of the accumulator in which the elements are computed. + If dtype is not specified, it defaults to the dtype of data. + + exclusive : bool, optional + If True will return exclusive cumulative operation in which the first element is not + included. In other terms, if True, the j-th output element would be + the cumulative operation of the first (j-1) elements. Otherwise, it would be the + cumulative operation of the first j elements. The cumulative operation of zero elements + is assumed to be the identity_value. + + Returns + ------- + result : tvm.te.Tensor + The result has the same size as data, and the same shape as data if axis is not None. + If axis is None, the result is a 1-d array. + """ + if dtype is None or dtype == "": + dtype = data.dtype + + if exclusive is None: + exclusive = False + + def maybe_cast(x): + if dtype != data.dtype: + return cast(x, dtype) + return x + + axis_mul_before = 1 + axis_mul_after = 1 + + if axis is None: + axis = 0 + cumsum_axis_len = prod(data.shape) + shape = (cumsum_axis_len,) + else: + if not isinstance(axis, int): + axis = get_const_int(axis) + + shape = data.shape + cumsum_axis_len = shape[axis] + + if axis < 0: + axis = len(shape) + axis + + for i, value in enumerate(shape, 0): + if i < axis: + axis_mul_before *= value + elif i > axis: + axis_mul_after *= value + + def gen_ir(data_buf, out_buf): + ib = ir_builder.create() + data_buf = ib.buffer_ptr(data_buf) + out_buf = ib.buffer_ptr(out_buf) + + with ib.for_range(0, axis_mul_before * axis_mul_after, "fused", kind="parallel") as fused: + i = fused // axis_mul_after + j = fused % axis_mul_after + base_idx = i * cumsum_axis_len * axis_mul_after + j + if exclusive: + out_buf[base_idx] = cast(identity_value, dtype) + else: + out_buf[base_idx] = maybe_cast(data_buf[base_idx]) + with ib.for_range(0, cumsum_axis_len - 1, "_k") as _k: + k = _k + 1 + cur_idx = base_idx + k * axis_mul_after + prev_idx = base_idx + (k - 1) * axis_mul_after + if exclusive: + out_buf[cur_idx] = binop(out_buf[prev_idx], maybe_cast(data_buf[prev_idx])) + else: + out_buf[cur_idx] = binop(out_buf[prev_idx], maybe_cast(data_buf[cur_idx])) + + return ib.get() + + out_buf = decl_buffer(shape, dtype, "out_buf") + + return extern( + [shape], + [data], + lambda ins, outs: gen_ir(ins[0], outs[0]), + dtype=dtype, + out_buffers=[out_buf], + name=op_name, + tag=op_name, + ) + + +def cumsum( + data: tvm.te.Tensor, + axis: Optional[int] = None, + dtype: Optional[int] = None, + exclusive: Optional[bool] = None, +) -> tvm.te.Tensor: + """Numpy style cumsum op. Return the cumulative sum of the elements along a given axis. + + Parameters + ---------- + data : tvm.te.Tensor + The input data to the operator. + + axis : int, optional + Axis along which the cumulative sum is computed. The default (None) is to compute + the cumsum over the flattened array. 
+ + dtype : string, optional + Type of the returned array and of the accumulator in which the elements are summed. + If dtype is not specified, it defaults to the dtype of data. + + exclusive : bool, optional + If True, will return exclusive sum in which the first element is not + included. In other terms, if True, the j-th output element would be + the sum of the first (j-1) elements. Otherwise, it would be the sum of + the first j elements. + + Returns + ------- + result : tvm.te.Tensor + The result has the same size as data, and the same shape as data if axis is not None. + If axis is None, the result is a 1-d array. + """ + return scanop( + data=data, + binop=generic.add, + identity_value=0, + op_name="cumsum_generic", + axis=axis, + dtype=dtype, + exclusive=exclusive, + ) + + +def cumprod( + data: tvm.te.Tensor, + axis: Optional[int] = None, + dtype: Optional[int] = None, + exclusive: Optional[bool] = None, +) -> tvm.te.Tensor: + """Numpy style cumprod op. Return the cumulative product of the elements along a given axis. + + Parameters + ---------- + data : tvm.te.Tensor + The input data to the operator. + + axis : int, optional + Axis along which the cumulative product is computed. The default (None) is to compute + the cumproduct over the flattened array. + + dtype : string, optional + Type of the returned array and of the accumulator in which the elements are multiplied. + If dtype is not specified, it defaults to the dtype of data. + + exclusive : bool, optional + If True, will return exclusive product in which the first element is not + included. In other terms, if True, the j-th output element would be + the product of the first (j-1) elements. Otherwise, it would be the product of + the first j elements. + + Returns + ------- + result : tvm.te.Tensor + The result has the same size as data, and the same shape as data if axis is not None. + If axis is None, the result is a 1-d array. 
+ """ + return scanop( + data=data, + binop=generic.multiply, + identity_value=1, + op_name="cumprod_generic", + axis=axis, + dtype=dtype, + exclusive=exclusive, + ) diff --git a/python/tvm/topi/unique.py b/python/tvm/topi/unique.py index b4f27b3..e725655 100644 --- a/python/tvm/topi/unique.py +++ b/python/tvm/topi/unique.py @@ -18,7 +18,7 @@ """Unique operator""" from tvm import te, tir from ..te import hybrid -from .cumsum import cumsum +from .scan import cumsum from .sort import sort, argsort diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index b65068b..6fb9f77 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -3772,20 +3772,20 @@ RELAY_REGISTER_OP("adv_index") .set_attr<TOpPattern>("TOpPattern", kInjective) .set_attr<FTVMCompute>("FTVMCompute", AdvIndexCompute); -TVM_REGISTER_NODE_TYPE(CumsumAttrs); +TVM_REGISTER_NODE_TYPE(ScanopAttrs); -bool CumsumRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, +bool ScanopRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { // types: [data, output] ICHECK_EQ(types.size(), 2) << "Expects two types, one for the input and another for the output"; const auto* data = types[0].as<TensorTypeNode>(); if (data == nullptr) { ICHECK(types[0].as<IncompleteTypeNode>()) - << "cumsum: expect input type to be TensorType but get " << types[0]; + << "Scanop: expect input type to be TensorType but get " << types[0]; return false; } - const auto* param = attrs.as<CumsumAttrs>(); + const auto* param = attrs.as<ScanopAttrs>(); auto dtype = param->dtype; if (dtype.is_void()) { @@ -3805,8 +3805,8 @@ bool CumsumRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, return true; } -Expr MakeCumsum(Expr data, Integer axis, DataType dtype, Integer exclusive) { - auto attrs = make_object<CumsumAttrs>(); +Expr MakeCumsum(Expr data, Integer axis, DataType dtype, Bool exclusive) { + auto attrs = make_object<ScanopAttrs>(); attrs->dtype = dtype; attrs->axis = axis; attrs->exclusive = exclusive; @@ -3822,7 +3822,27 @@ RELAY_REGISTER_OP("cumsum") .set_num_inputs(1) .add_argument("data", "Tensor", "The input tensor.") .set_support_level(3) - .add_type_rel("Cumsum", CumsumRel) + .add_type_rel("Cumsum", ScanopRel) + .set_attr<TOpPattern>("TOpPattern", kOpaque); + +Expr MakeCumprod(Expr data, Integer axis, DataType dtype, Bool exclusive) { + auto attrs = make_object<ScanopAttrs>(); + attrs->dtype = dtype; + attrs->axis = axis; + attrs->exclusive = exclusive; + static const Op& op = Op::Get("cumprod"); + return Call(op, {data}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relay.op._make.cumprod").set_body_typed(MakeCumprod); + +RELAY_REGISTER_OP("cumprod") + .describe( + R"doc(Return the cumulative product of the elements along a given axis.)doc" TVM_ADD_FILELINE) + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(3) + .add_type_rel("Cumprod", ScanopRel) .set_attr<TOpPattern>("TOpPattern", kOpaque); TVM_REGISTER_NODE_TYPE(UniqueAttrs); diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index d2a5090..7e443aa 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -16,16 +16,16 @@ # under the License. """ Support level3 operator test cases. 
""" +from typing import Callable, Optional + import numpy as np import pytest import tvm -from tvm import te -from tvm import relay +import tvm.testing +from tvm import relay, te from tvm.error import TVMError from tvm.relay import create_executor, transform from tvm.relay.testing import check_grad, run_infer_type -from typing import Optional -import tvm.testing def test_zeros_ones(): @@ -1758,12 +1758,27 @@ def test_adv_index(): verify_adv_index((10, 5, 15), [(1, 2, 1), (1, 2, 7)]) -@tvm.testing.parametrize_targets -def test_cumsum(target, ctx): - def verify_cumsum(data_np, np_out, axis=None, out_dtype=None, rtol=1e-5, atol=1e-5): +# Helper for testing binop functions +scanops_supported = {"cumsum": relay.op.cumsum, "cumprod": relay.op.cumprod} + + +def run_binop_tests( + target, ctx, binop_type: str, gt_func: Callable[..., np.array], identity_value: int +): + def assert_relay_scanop( + data_np: np.array, + np_out: np.array, + axis: int = None, + out_dtype: str = None, + rtol: float = 1e-5, + atol: float = 1e-5, + exclusive: bool = False, + ): inp = relay.var("data", relay.TensorType(data_np.shape, str(data_np.dtype))) - out = relay.op.cumsum(inp, axis, out_dtype) + if binop_type not in scanops_supported.keys(): + raise ValueError(f"Unknown function {binop_type}. Options: {scanops_supported.keys()}") + out = scanops_supported[binop_type](inp, axis, out_dtype, exclusive=exclusive) func = relay.Function([inp], out) for kind in ["graph", "debug"]: @@ -1772,24 +1787,48 @@ def test_cumsum(target, ctx): tvm.testing.assert_allclose(op_res.asnumpy(), np_out, rtol=rtol, atol=atol) data = np.array([2, 3, 0]) - verify_cumsum(data, np.cumsum(data)) - verify_cumsum(data, np.cumsum(data), out_dtype="int64") + assert_relay_scanop(data, gt_func(data)) + assert_relay_scanop(data, gt_func(data), out_dtype="int64") data = np.random.randn(10, 10) - verify_cumsum(data, np.cumsum(data)) - verify_cumsum(data, np.cumsum(data, axis=0), axis=0) - verify_cumsum(data, np.cumsum(data, axis=1), axis=1) + assert_relay_scanop(data, gt_func(data)) + assert_relay_scanop(data, gt_func(data, axis=0), axis=0) + assert_relay_scanop(data, gt_func(data, axis=1), axis=1) data = np.random.randn(10, 5, 10).astype("float32") - verify_cumsum(data, np.cumsum(data), rtol=1e-4, atol=1e-4) - verify_cumsum(data, np.cumsum(data, axis=0), axis=0, rtol=1e-4, atol=1e-4) - verify_cumsum(data, np.cumsum(data, axis=1), axis=1, rtol=1e-4, atol=1e-4) - verify_cumsum(data, np.cumsum(data, axis=-1), axis=-1, rtol=1e-4, atol=1e-4) + assert_relay_scanop(data, gt_func(data), rtol=1e-4, atol=1e-4) + assert_relay_scanop(data, gt_func(data, axis=0), axis=0, rtol=1e-4, atol=1e-4) + assert_relay_scanop(data, gt_func(data, axis=1), axis=1, rtol=1e-4, atol=1e-4) + assert_relay_scanop(data, gt_func(data, axis=-1), axis=-1, rtol=1e-4, atol=1e-4) data = np.random.rand(10) > 0.5 data = data.astype(np.int32) - verify_cumsum(data, np.cumsum(data, dtype=np.int32)) - verify_cumsum(data, np.cumsum(data, dtype="int64"), out_dtype="int64") + assert_relay_scanop(data, gt_func(data, dtype=np.int32)) + assert_relay_scanop(data, gt_func(data, dtype="int64"), out_dtype="int64") + + # Test exclusivity operations + data = np.random.randint(-100, 100, size=(10, 10)).astype("int64") + expected_result = np.roll(gt_func(data), 1) + expected_result[0] = identity_value + assert_relay_scanop(data, expected_result, exclusive=True) + + expected_result = np.roll(gt_func(data, axis=0), 1, axis=0) + expected_result[0, :] = identity_value + assert_relay_scanop(data, expected_result, 
exclusive=True, axis=0) + + expected_result = np.roll(gt_func(data, axis=1), 1, axis=1) + expected_result[:, 0] = identity_value + assert_relay_scanop(data, expected_result, exclusive=True, axis=1) + + +@tvm.testing.parametrize_targets +def test_cumsum(target, ctx): + run_binop_tests(target, ctx, binop_type="cumsum", gt_func=np.cumsum, identity_value=0) + + +@tvm.testing.parametrize_targets +def test_cumprod(target, ctx): + run_binop_tests(target, ctx, binop_type="cumprod", gt_func=np.cumprod, identity_value=1) @tvm.testing.parametrize_targets diff --git a/tests/python/topi/python/test_topi_cumsum.py b/tests/python/topi/python/test_topi_cumsum.py deleted file mode 100644 index cfe5130..0000000 --- a/tests/python/topi/python/test_topi_cumsum.py +++ /dev/null @@ -1,79 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -import numpy as np -import tvm -import tvm.testing -from tvm import topi -import tvm.topi.testing - - -@tvm.testing.parametrize_targets -def test_cumsum(ctx, target): - def check_cumsum(np_ref, data, axis=None, dtype=None): - implementations = { - "generic": (lambda x: topi.cumsum(x, axis, dtype), topi.generic.schedule_extern), - "cuda": (lambda x: topi.cuda.cumsum(x, axis, dtype), topi.cuda.schedule_scan), - "nvptx": (lambda x: topi.cuda.cumsum(x, axis, dtype), topi.cuda.schedule_scan), - "vulkan": (lambda x: topi.cuda.cumsum(x, axis, dtype), topi.cuda.schedule_scan), - "metal": (lambda x: topi.cuda.cumsum(x, axis, dtype), topi.cuda.schedule_scan), - } - fcompute, fschedule = tvm.topi.testing.dispatch(target, implementations) - tvm.topi.testing.compare_numpy_tvm([data], np_ref, target, ctx, fcompute, fschedule) - - data = np.array([2, 3, 0]) - check_cumsum(np.cumsum(data), data) - - data = np.random.rand(10) > 0.5 - data = data.astype(np.int32) - check_cumsum(np.cumsum(data, dtype=np.int32), data) - check_cumsum(np.cumsum(data), data, dtype="int64") - - data = np.random.rand(10) > 0.5 - check_cumsum(np.cumsum(data, dtype=np.int32), data, dtype="int32") - - for in_dtype in ["float32", "float64"]: - if target == "metal" and in_dtype == "float64": - # float64 is not supported in metal - continue - data = np.random.randn(10, 10).astype(in_dtype) - check_cumsum(np.cumsum(data), data) - check_cumsum(np.cumsum(data, axis=0), data, axis=0) - check_cumsum(np.cumsum(data, axis=1), data, axis=1) - - data = np.random.randn(10, 5, 10).astype(in_dtype) - check_cumsum(np.cumsum(data), data) - check_cumsum(np.cumsum(data, axis=0), data, axis=0) - check_cumsum(np.cumsum(data, axis=1), data, axis=1) - check_cumsum(np.cumsum(data, axis=-1), data, axis=-1) - - for in_dtype in ["int32", "int64"]: - data = np.random.randint(-100, 100, size=(100, 100)).astype(in_dtype) - check_cumsum(np.cumsum(data, dtype=in_dtype), data) - check_cumsum(np.cumsum(data), data, 
dtype="int64") - check_cumsum(np.cumsum(data, axis=0, dtype=in_dtype), data, axis=0) - check_cumsum(np.cumsum(data, axis=1, dtype=in_dtype), data, axis=1) - - data = np.random.randint(1 << 30, (1 << 31) - 1, size=(100)).astype(in_dtype) - check_cumsum(np.cumsum(data), data, dtype="int64") - - -if __name__ == "__main__": - test_cumsum(tvm.context("cpu"), tvm.target.Target("llvm")) - test_cumsum(tvm.context("cuda"), tvm.target.Target("cuda")) - test_cumsum(tvm.context("nvptx"), tvm.target.Target("nvptx")) - test_cumsum(tvm.context("vulkan"), tvm.target.Target("vulkan")) - test_cumsum(tvm.context("metal"), tvm.target.Target("metal")) diff --git a/tests/python/topi/python/test_topi_scan.py b/tests/python/topi/python/test_topi_scan.py new file mode 100644 index 0000000..020fde5 --- /dev/null +++ b/tests/python/topi/python/test_topi_scan.py @@ -0,0 +1,144 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from typing import Callable + +import numpy as np +import tvm +import tvm.testing +import tvm.topi.testing +from tvm import topi + +topi_funcs = { + "cumsum": {"generic": topi.cumsum, "cuda": topi.cuda.cumsum}, + "cumprod": {"generic": topi.cumprod, "cuda": topi.cuda.cumprod}, +} + +identity_value = {"cumsum": 0, "cumprod": 1} + + +def get_implementations(name, axis, dtype, exclusive): + topi_func_generic = topi_funcs[name]["generic"] + topi_func_cuda = topi_funcs[name]["cuda"] + + return { + "generic": ( + lambda x: topi_func_generic(x, axis, dtype, exclusive=exclusive), + topi.generic.schedule_extern, + ), + "cuda": ( + lambda x: topi_func_cuda(x, axis, dtype, exclusive=exclusive), + topi.cuda.schedule_scan, + ), + "nvptx": ( + lambda x: topi_func_cuda(x, axis, dtype, exclusive=exclusive), + topi.cuda.schedule_scan, + ), + "vulkan": ( + lambda x: topi_func_cuda(x, axis, dtype, exclusive=exclusive), + topi.cuda.schedule_scan, + ), + "metal": ( + lambda x: topi_func_cuda(x, axis, dtype, exclusive=exclusive), + topi.cuda.schedule_scan, + ), + } + + +def _run_tests( + ctx, + target, + op_name: str = "cumsum", + gt_func: Callable[..., np.array] = np.cumsum, +): + def check_scan(np_ref, data, axis=None, dtype=None, exclusive=False): + implementations = get_implementations(op_name, axis, dtype, exclusive) + fcompute, fschedule = tvm.topi.testing.dispatch(target, implementations) + tvm.topi.testing.compare_numpy_tvm([data], np_ref, target, ctx, fcompute, fschedule) + + data = np.array([2, 3, 0]) + check_scan(gt_func(data), data) + + data = np.random.rand(10) > 0.5 + data = data.astype(np.int32) + check_scan(gt_func(data, dtype=np.int32), data) + check_scan(gt_func(data), data, dtype="int64") + + data = np.random.rand(10) > 0.5 + check_scan(gt_func(data, dtype=np.int32), data, dtype="int32") + + for in_dtype in ["float32", "float64"]: + if target == "metal" and 
in_dtype == "float64": + # float64 is not supported in metal + continue + data = np.random.randn(10, 10).astype(in_dtype) + check_scan(gt_func(data), data) + check_scan(gt_func(data, axis=0), data, axis=0) + check_scan(gt_func(data, axis=1), data, axis=1) + + data = np.random.randn(10, 5, 10).astype(in_dtype) + check_scan(gt_func(data), data) + check_scan(gt_func(data, axis=0), data, axis=0) + check_scan(gt_func(data, axis=1), data, axis=1) + check_scan(gt_func(data, axis=-1), data, axis=-1) + + for in_dtype in ["int32", "int64"]: + data = np.random.randint(-100, 100, size=(100, 100)).astype(in_dtype) + check_scan(gt_func(data, dtype=in_dtype), data) + check_scan(gt_func(data), data, dtype="int64") + check_scan(gt_func(data, axis=0, dtype=in_dtype), data, axis=0) + check_scan(gt_func(data, axis=1, dtype=in_dtype), data, axis=1) + + data = np.random.randint(1 << 30, (1 << 31) - 1, size=(100)).astype(in_dtype) + check_scan(gt_func(data), data, dtype="int64") + + data = np.random.randint(-100, 100, size=(100, 100)).astype("int64") + + expected_result = np.roll(gt_func(data), 1) + expected_result[0] = identity_value[op_name] + check_scan(expected_result, data, dtype="int64", exclusive=True) + + expected_result = np.roll(gt_func(data, axis=0, dtype=in_dtype), 1, axis=0) + expected_result[0, :] = identity_value[op_name] + check_scan(expected_result, data, axis=0, exclusive=True) + + expected_result = np.roll(gt_func(data, axis=1, dtype=in_dtype), 1, axis=1) + expected_result[:, 0] = identity_value[op_name] + check_scan(gt_func(data, axis=1, dtype=in_dtype), data, axis=1) + + +@tvm.testing.parametrize_targets +def test_cumsum(ctx, target): + _run_tests(ctx, target, op_name="cumsum", gt_func=np.cumsum) + + +@tvm.testing.parametrize_targets +def test_cumprod(ctx, target): + _run_tests(ctx, target, op_name="cumprod", gt_func=np.cumprod) + + +if __name__ == "__main__": + test_cumsum(tvm.context("cpu"), tvm.target.Target("llvm")) + test_cumsum(tvm.context("cuda"), tvm.target.Target("cuda")) + test_cumsum(tvm.context("nvptx"), tvm.target.Target("nvptx")) + test_cumsum(tvm.context("vulkan"), tvm.target.Target("vulkan")) + test_cumsum(tvm.context("metal"), tvm.target.Target("metal")) + + test_cumprod(tvm.context("cpu"), tvm.target.Target("llvm")) + test_cumprod(tvm.context("cuda"), tvm.target.Target("cuda")) + test_cumprod(tvm.context("nvptx"), tvm.target.Target("nvptx")) + test_cumprod(tvm.context("vulkan"), tvm.target.Target("vulkan")) + test_cumprod(tvm.context("metal"), tvm.target.Target("metal"))
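
For readers who want to try the new operator, below is a minimal, illustrative sketch (not part of the commit) that follows the pattern of the relay tests added above; the "llvm" target, CPU context, and graph executor are arbitrary choices, and the expected outputs come from the cumprod docstring and exclusive-scan semantics described in this diff.

    import numpy as np
    import tvm
    from tvm import relay

    # Sample input taken from the cumprod docstring in python/tvm/relay/op/transform.py.
    data_np = np.array([[1, 2, 3], [4, 5, 6]]).astype("int32")

    inp = relay.var("data", relay.TensorType(data_np.shape, "int32"))

    # Inclusive cumulative product along axis 1.
    func = relay.Function([inp], relay.op.cumprod(inp, axis=1, exclusive=False))
    res = relay.create_executor("graph", ctx=tvm.cpu(), target="llvm").evaluate(func)(data_np)
    print(res.asnumpy())  # -> [[1, 2, 6], [4, 20, 120]]

    # With exclusive=True the j-th output is the product of the first (j-1) elements,
    # so each row starts at the multiplicative identity value 1.
    func_excl = relay.Function([inp], relay.op.cumprod(inp, axis=1, exclusive=True))
    res_excl = relay.create_executor("graph", ctx=tvm.cpu(), target="llvm").evaluate(func_excl)(data_np)
    print(res_excl.asnumpy())  # -> [[1, 1, 2], [1, 4, 20]]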