This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity-staging
in repository https://gitbox.apache.org/repos/asf/tvm.git
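The commit below adds a Python surface under relax.op.nn plus C++ struct-info inference for each operator. As a rough usage sketch (assuming this commit is applied; it mirrors the tests added under tests/python/relax/ in the diff), the new operators can be constructed as ordinary Relax calls and their output shape/dtype inferred by normalizing them with a BlockBuilder:

    from tvm import relax
    from tvm.script import relax as R

    bb = relax.BlockBuilder()
    x = relax.Var("x", R.Tensor((1, 3, 32, 32), "float32"))
    w = relax.Var("w", R.Tensor((16, 3, 3, 3), "float32"))
    gamma = relax.Var("gamma", R.Tensor((3,), "float32"))
    beta = relax.Var("beta", R.Tensor((3,), "float32"))

    # Struct info (shape and dtype) is inferred when the call is normalized.
    conv = bb.normalize(relax.op.nn.conv2d(x, w, strides=1, padding=1))
    print(conv.struct_info)  # shape (1, 16, 32, 32), dtype float32

    pool = bb.normalize(relax.op.nn.max_pool2d(x, pool_size=2, strides=2))
    print(pool.struct_info)  # shape (1, 3, 16, 16), dtype float32

    ln = bb.normalize(relax.op.nn.layer_norm(x, gamma, beta, axes=1))
    print(ln.struct_info)    # same shape and dtype as x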
commit 7dd27f075832f9ea9d89d1116f637de8480bd5c9 Author: Ruihang Lai <ruiha...@cs.cmu.edu> AuthorDate: Tue Feb 14 15:03:22 2023 -0500 [Unity] Relax op: neural networks (#13993) This PR is about the high-level tensor computation operators in Relax. This PR includes the neural network operators. --- include/tvm/relax/attrs/nn.h | 190 +++++ python/tvm/relax/op/__init__.py | 1 + python/tvm/relax/op/{ => nn}/__init__.py | 31 +- .../tvm/relax/op/{__init__.py => nn/_ffi_api.py} | 30 +- python/tvm/relax/op/nn/nn.py | 524 ++++++++++++ python/tvm/relax/op/op_attrs.py | 35 + python/tvm/script/ir_builder/relax/ir.py | 2 + src/relax/op/nn/convolution.cc | 146 ++++ src/relax/op/nn/convolution.h | 63 ++ src/relax/op/nn/nn.cc | 245 ++++++ src/relax/op/nn/nn.h | 81 ++ src/relax/op/nn/pooling.cc | 184 ++++ src/relax/op/nn/pooling.h | 46 + tests/python/relax/test_op_nn.py | 929 +++++++++++++++++++++ tests/python/relax/test_op_nn_convolution.py | 429 ++++++++++ tests/python/relax/test_op_nn_pooling.py | 429 ++++++++++ tests/python/relax/test_tvmscript_parser_op_nn.py | 193 +++++ 17 files changed, 3503 insertions(+), 55 deletions(-) diff --git a/include/tvm/relax/attrs/nn.h b/include/tvm/relax/attrs/nn.h new file mode 100644 index 0000000000..694a510706 --- /dev/null +++ b/include/tvm/relax/attrs/nn.h @@ -0,0 +1,190 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/attrs/nn.h + * \brief Attributes for neural network operators. + */ +#ifndef TVM_RELAX_ATTRS_NN_H_ +#define TVM_RELAX_ATTRS_NN_H_ + +#include <tvm/relax/expr.h> + +namespace tvm { +namespace relax { + +/*! \brief Attributes used in Conv2d operator */ +struct Conv2DAttrs : public tvm::AttrsNode<Conv2DAttrs> { + Array<IntImm> strides; + Array<IntImm> padding; + Array<IntImm> dilation; + int groups; + String data_layout; + String kernel_layout; + String out_layout; + DataType out_dtype; + + TVM_DECLARE_ATTRS(Conv2DAttrs, "relax.attrs.Conv2DAttrs") { + TVM_ATTR_FIELD(strides).describe("Specifies the strides of the convolution."); + TVM_ATTR_FIELD(padding).describe( + "If padding is non-zero, then the input is implicitly zero-padded" + "Padding support both symmetric and asymmetric as" + "one int : same padding used on all sides" + "two int : bottom, right will use same padding as top, left" + "four int : padding width in the order of (top, left, bottom, right)"); + TVM_ATTR_FIELD(dilation).describe( + "Specifies the dilation rate to use for dilated convolution."); + TVM_ATTR_FIELD(groups).describe( + "Number of groups to split the input into for grouped convolution. The number of input and " + "output channels should be divisible by the number of groups."); + TVM_ATTR_FIELD(data_layout) + .describe( + "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." 
+ "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Convolution is applied on the 'H' and" + "'W' dimensions."); + TVM_ATTR_FIELD(kernel_layout) + .describe( + "Dimension ordering of weight. Can be 'OIHW', 'OIHW16o16i', etc." + "'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width" + "dimensions respectively."); + TVM_ATTR_FIELD(out_layout) + .describe( + "Dimension ordering of output. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Default to be same as input layout."); + TVM_ATTR_FIELD(out_dtype).describe( + "Output data type, set to explicit type under mixed precision setting"); + } +}; // struct Conv2dAttrs + +/*! \brief Attributes used in max_pool2d operator */ +struct MaxPool2DAttrs : public tvm::AttrsNode<MaxPool2DAttrs> { + Array<IntImm> pool_size; + Array<IntImm> strides; + Array<IntImm> padding; + Array<IntImm> dilation; + bool ceil_mode; + String layout; + String out_layout; + + TVM_DECLARE_ATTRS(MaxPool2DAttrs, "relax.attrs.MaxPool2DAttrs") { + TVM_ATTR_FIELD(pool_size).describe("Size of the pooling windows."); + TVM_ATTR_FIELD(strides).describe("Specifies the strides of the convolution."); + TVM_ATTR_FIELD(dilation).describe("Specifies the dilation of the convolution."); + TVM_ATTR_FIELD(padding).describe( + "If padding is non-zero, then the input is implicitly zero-padded" + "Padding support both symmetric and asymmetric as" + "one int : same padding used on all sides" + "two int : bottom, right will use same padding as top, left" + "four int : padding width in the order of (top, left, bottom, right)"); + TVM_ATTR_FIELD(ceil_mode).describe( + "A boolean indicating if use ceil or floor to compute the output shape. By using ceil, " + "every element in the input tensor will be covered by a sliding window."); + TVM_ATTR_FIELD(layout).describe( + "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Pooling is applied on the 'H' and" + "'W' dimensions."); + TVM_ATTR_FIELD(out_layout) + .describe( + "Dimension ordering of output data. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Pooling is applied on the 'H' and" + "'W' dimensions."); + } +}; // struct MaxPool2dAttrs + +/*! \brief Attributes for 2d adaptive pool operator */ +struct AdaptivePool2DAttrs : public tvm::AttrsNode<AdaptivePool2DAttrs> { + Optional<Array<IntImm>> output_size; + String layout; + String out_layout; + + TVM_DECLARE_ATTRS(AdaptivePool2DAttrs, "relax.attrs.AdaptivePool2DAttrs") { + TVM_ATTR_FIELD(output_size).describe("Output height and width."); + TVM_ATTR_FIELD(layout).describe( + "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Pooling is applied on the 'H' and" + "'W' dimensions."); + TVM_ATTR_FIELD(out_layout) + .describe( + "Dimension ordering of output data. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Pooling is applied on the 'H' and" + "'W' dimensions."); + } +}; // struct AdaptivePool2DAttrs + +/*! 
\brief Attributes used in softmax operators */ +struct SoftmaxAttrs : public tvm::AttrsNode<SoftmaxAttrs> { + int axis; + + TVM_DECLARE_ATTRS(SoftmaxAttrs, "relax.attrs.SoftmaxAttrs") { + TVM_ATTR_FIELD(axis).describe("The axis to sum over when computing softmax."); + } +}; + +/*! \brief Attributes used in batch_norm operator */ +struct BatchNormAttrs : public tvm::AttrsNode<BatchNormAttrs> { + int axis; + double epsilon; + bool center; + bool scale; + + TVM_DECLARE_ATTRS(BatchNormAttrs, "relax.attrs.BatchNormAttrs") { + TVM_ATTR_FIELD(axis).describe("The axis along which the normalization is applied."); + TVM_ATTR_FIELD(epsilon).describe("Small float added to variance to avoid dividing by zero"); + TVM_ATTR_FIELD(center).describe( + "Indicating if the beta offset will be added to the normalized tensor."); + TVM_ATTR_FIELD(scale).describe("Indicating if the gamma scale will be multiplied."); + } +}; // struct BatchNormAttrs + +/*! \brief Attributes used in layer_norm operator */ +struct LayerNormAttrs : public tvm::AttrsNode<LayerNormAttrs> { + Array<Integer> axes; + double epsilon; + bool center; + bool scale; + + TVM_DECLARE_ATTRS(LayerNormAttrs, "relax.attrs.LayerNormAttrs") { + TVM_ATTR_FIELD(axes).describe("The axes that along which the normalization is applied."); + TVM_ATTR_FIELD(epsilon).describe("Small float added to variance to avoid dividing by zero"); + TVM_ATTR_FIELD(center).describe( + "Indicating if the beta offset will be added to the normalized tensor."); + TVM_ATTR_FIELD(scale).describe("Indicating if the gamma scale will be multiplied."); + } +}; // struct LayerNormAttrs + +/*! \brief Attributes used in dropout operator */ +struct DropoutAttrs : public tvm::AttrsNode<DropoutAttrs> { + double rate; + + TVM_DECLARE_ATTRS(DropoutAttrs, "relax.attrs.DropoutAttrs") { + TVM_ATTR_FIELD(rate).describe( + "Fraction of the input that gets dropped out during training time"); + } +}; // struct DropoutAttrs + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_ATTRS_NN_H_ diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py index 68152c2056..6c6fffc7c6 100644 --- a/python/tvm/relax/op/__init__.py +++ b/python/tvm/relax/op/__init__.py @@ -31,6 +31,7 @@ from .unary import * from . import builtin from . import image from . import memory +from . import nn def _register_op_make(): diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/nn/__init__.py similarity index 57% copy from python/tvm/relax/op/__init__.py copy to python/tvm/relax/op/nn/__init__.py index 68152c2056..af2aa106bc 100644 --- a/python/tvm/relax/op/__init__.py +++ b/python/tvm/relax/op/nn/__init__.py @@ -14,31 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=wildcard-import, redefined-builtin -"""Relax core operators.""" - -# Operators -from .base import * -from .binary import * -from .datatype import * -from .index import * -from .manipulate import * -from .op_attrs import * -from .statistical import * -from .set import * -from .ternary import * -from .unary import * -from . import builtin -from . import image -from . import memory - - -def _register_op_make(): - # pylint: disable=import-outside-toplevel - from . import _ffi_api - from .. 
import expr - - expr._op_ffi_api = _ffi_api # type: ignore - - -_register_op_make() +# pylint: disable=wildcard-import +"""Neural network related operators.""" +from .nn import * diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/nn/_ffi_api.py similarity index 57% copy from python/tvm/relax/op/__init__.py copy to python/tvm/relax/op/nn/_ffi_api.py index 68152c2056..1785345ac1 100644 --- a/python/tvm/relax/op/__init__.py +++ b/python/tvm/relax/op/nn/_ffi_api.py @@ -14,31 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=wildcard-import, redefined-builtin -"""Relax core operators.""" +"""Constructor APIs""" +import tvm._ffi -# Operators -from .base import * -from .binary import * -from .datatype import * -from .index import * -from .manipulate import * -from .op_attrs import * -from .statistical import * -from .set import * -from .ternary import * -from .unary import * -from . import builtin -from . import image -from . import memory - - -def _register_op_make(): - # pylint: disable=import-outside-toplevel - from . import _ffi_api - from .. import expr - - expr._op_ffi_api = _ffi_api # type: ignore - - -_register_op_make() +tvm._ffi._init_api("relax.op.nn", __name__) diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py new file mode 100644 index 0000000000..cdf0e96464 --- /dev/null +++ b/python/tvm/relax/op/nn/nn.py @@ -0,0 +1,524 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Relax Neural Network (NN) operators""" +from typing import List, Optional, Tuple, Union + +from tvm import DataType + +from . import _ffi_api +from ...expr import Expr + + +def conv2d( + data: Expr, + weight: Expr, + strides: Union[int, Tuple[int, int]] = (1, 1), + padding: Union[int, Tuple[int, ...]] = (0, 0), + dilation: Union[int, Tuple[int, int]] = (1, 1), + groups: int = 1, + data_layout: str = "NCHW", + kernel_layout: str = "OIHW", + out_layout: Optional[str] = None, + out_dtype: Optional[Union[str, DataType]] = None, +) -> Expr: + r"""2D convolution. + + This operator takes the weight as the convolution kernel + and convolves it with data to produce an output. + + + In the default case, where the data_layout is `NCHW` + and kernel_layout is `OIHW`, conv2d takes in + a data Tensor with shape `(batch_size, in_channels, height, width)`, + and a weight Tensor with shape `(channels, in_channels, kernel_h, kernel_w)`, + where `kernel_h` and `kernel_w` is the lengths of the `H` and `W` kernel dimensions, + to produce an output Tensor with the following rule: + + .. 
math:: + + \mbox{out}[b, c, y, x] = \sum_{dy, dx, k} + \mbox{data}[b, k, \mbox{strides}[0] * y + dy, \mbox{strides}[1] * x + dx] * + \mbox{weight}[c, k, dy, dx] + + Padding and dilation are applied to data and weight respectively before the computation. + This operator accepts data layout specification. + Semantically, the operator will convert the layout to the canonical layout + (`NCHW` for data and `OIHW` for weight), perform the computation, + then convert to the out_layout. + + Parameters + ---------- + data : relax.Expr + The input data to the operator. + + weight : relax.Expr + The weight expressions. + + strides : Union[int, Tuple[int, int]] + The strides of convolution. It is required to have length either 1 or 2. + + padding : Union[int, Tuple[int, ...]] + The padding of convolution on both sides of inputs before convolution. + It is required to have length either 1, 2 or 4. + + dilation : Union[int, Tuple[int, int]] + Specifies the dilation rate to be used for dilated convolution. + It is required to have length either 1 or 2. + + groups : int + Number of groups to split the input into for grouped convolution. + The number of input and output channels should be divisible by the number of groups. + + data_layout : str + Layout of the input. + + kernel_layout : str + Layout of the weight. + + out_layout : Optional[str] + Layout of the output. If not specified, it is the same as data_layout + + out_dtype : Optional[Union[str, DataType]] + Specifies the output data type for mixed precision conv2d. + + Returns + ------- + result : relax.Expr + The computed result. + """ + if isinstance(strides, int): + strides = (strides, strides) + if isinstance(dilation, int): + dilation = (dilation, dilation) + if isinstance(padding, int): + padding = (padding, padding, padding, padding) + + return _ffi_api.conv2d( # type: ignore + data, + weight, + strides, + padding, + dilation, + groups, + data_layout, + kernel_layout, + out_layout, + out_dtype, + ) + + +def max_pool2d( + data: Expr, + pool_size: Union[int, Tuple[int, int]] = (1, 1), + strides: Union[int, Tuple[int, int]] = (1, 1), + padding: Union[int, Tuple[int, ...]] = (0, 0), + dilation: Union[int, Tuple[int, int]] = (1, 1), + ceil_mode: bool = False, + layout: str = "NCHW", + out_layout: Optional[str] = None, +) -> Expr: + r"""2D maximum pooling operator. + + This operator takes data as input and does 2D max value calculation + with in pool_size sized window by striding defined by stride + + + In the default case, where the data_layout is `NCHW` + a data Tensor with shape `(batch_size, in_channels, height, width)`, + to produce an output Tensor with the following rule: + + with data of shape (b, c, h, w) and pool_size (kh, kw) + + .. math:: + + \mbox{out}(b, c, y, x) = \max_{m=0, \ldots, kh-1} \max_{n=0, \ldots, kw-1} + \mbox{data}(b, c, \mbox{stride}[0] * y + m, \mbox{stride}[1] * x + n) + + Padding is applied to data before the computation. + ceil_mode is used to take ceil or floor while computing out shape. + This operator accepts data layout specification. + + Parameters + ---------- + data : relax.Expr + The input data to the operator. + + pool_size : Union[int, Tuple[int, int]] + The size of window for pooling. It is required to have length either 1 or 2. + + strides : Union[int, Tuple[int, int]] + The strides of pooling. It is required to have length either 1 or 2. + + padding : Union[int, Tuple[int, ...]] + The padding for pooling. It is required to have length either 1, 2 or 4. 
+ + dilation : Union[int, Tuple[int, int]] + The dilation of pooling. It is required to have length either 1 or 2. + + ceil_mode : bool + A boolean indicating if use ceil or floor to compute the output shape. + By using ceil, every element in the input tensor will be covered by a sliding window. + + layout : str + Layout of the input. + + out_layout : Optional[str] + Layout of the output. If not specified, it is the same as data_layout + + Returns + ------- + result : Expr + The computed result. + """ + if isinstance(pool_size, int): + pool_size = (pool_size, pool_size) + if isinstance(strides, int): + strides = (strides, strides) + if isinstance(dilation, int): + dilation = (dilation, dilation) + if isinstance(padding, int): + padding = (padding, padding, padding, padding) + + return _ffi_api.max_pool2d( # type: ignore + data, pool_size, strides, padding, dilation, ceil_mode, layout, out_layout + ) + + +def adaptive_avg_pool2d( + data: Expr, + output_size: Optional[Union[int, Tuple[int, int]]] = None, + layout: str = "NCHW", + out_layout: Optional[str] = None, +) -> Expr: + r"""2D adaptive average pooling operator. This operator is experimental. + + This operator takes data as input and does 2D average value calculation + across each window represented by WxH. + + + In the default case, where the data_layout is `NCHW` + a data Tensor with shape `(batch_size, in_channels, height, width)`, + to produce an output Tensor with shape + (batch_size, in_channels, output_height, output_width). + + The pooling kernel and stride sizes are automatically chosen for + desired output sizes. + + For output_size: + If this argument is not provided, input height and width will be used + as output height and width. + + If a single integer is provided for output_size, the output size is + (N x C x output_size x output_size) for any input (NCHW). + + If a tuple of integers (height, width) are provided for output_size, + the output size is (N x C x height x width) for any input (NCHW). + + Parameters + ---------- + data : relax.Expr + The input data to the operator. + + output_size : Optional[Union[int, Tuple[int, int]]] + Output height and width. + If not specified, it will be the same as the input height and width. + If specified, it is required to have length either 1 or 2. + + layout : str + Layout of the input. + + out_layout : Optional[str] + Layout of the output. If not specified, it is the same as data_layout + + Returns + ------- + result : relax.Expr + The computed result. + """ + if isinstance(output_size, int): + output_size = (output_size, output_size) + return _ffi_api.adaptive_avg_pool2d(data, output_size, layout, out_layout) # type: ignore + + +def relu(data: Expr) -> Expr: + """Rectified linear unit. + + .. math:: + text{ReLU}(x) = max(x, 0) + + Parameters + ---------- + data : relax.Expr + The input data + + Returns + ------- + result : relax.Expr + The computed result. + """ + return _ffi_api.relu(data) # type: ignore + + +def gelu(data: Expr) -> Expr: + """Gaussian Error Linear Units function + + .. math:: + text{GeLU}(x) = 0.5 * x * (1 + erf(x * 0.5**0.5)) + + where :math:`erf` is the Gauss Error function. + + Parameters + ---------- + data : relax.Expr + The input data + + Returns + ------- + result : relax.Expr + The computed result. + + Note + ---- + The input tensor is required to have float dtype + """ + return _ffi_api.gelu(data) # type: ignore + + +def silu(data: Expr) -> Expr: + """Sigmoid Linear Unit function + + .. 
math:: + text{SiLU}(x) = x * sigmoid(x) + + Parameters + ---------- + data : relax.Expr + The input data + + Returns + ------- + result : relax.Expr + The computed result. + + Note + ---- + The input tensor is required to have float dtype + """ + return _ffi_api.silu(data) # type: ignore + + +def softmax(data: Expr, axis: int = -1) -> Expr: + r"""Computes softmax. + + .. math:: text{softmax}(x)_i = frac{exp(x_i)}{\sum_j exp(x_j)} + + Parameters + ---------- + data: relax.Expr + The input data to the operator. + + axis: int + The axis to sum over when computing softmax. + If not specified, it is by default the last axis of the input tensor. + Supports negative indexing. + + Returns + ------- + result : relax.Expr + The computed result. + + Note + ---- + The input tensor is required to have float dtype + """ + return _ffi_api.softmax(data, axis) # type: ignore + + +def batch_norm( + data: Expr, + gamma: Expr, + beta: Expr, + moving_mean: Expr, + moving_var: Expr, + axis: int, + epsilon: float = 1e-5, + center: bool = True, + scale: bool = True, +) -> Expr: + r""" + Batch normalization layer (Ioffe and Szegedy, 2014). + Normalizes the input at each batch, i.e. applies a transformation + that maintains the mean activation close to 0 and the activation + standard deviation close to 1. + + .. math:: + + data\_mean[i] = mean(data[:,i,:,...]) \\ + data\_var[i] = var(data[:,i,:,...]) + + Then compute the normalized output, which has the same shape as input, as following: + + .. math:: + + out[:,i,:,...] = \frac{data[:,i,:,...] - data\_mean[i]}{\sqrt{data\_var[i]+\epsilon}} + * gamma[i] + beta[i] + + Both *mean* and *var* returns a scalar by treating the input as a vector. + + Assume the input has size *k* on axis 1, then both ``gamma`` and ``beta`` + have shape *(k,)*. + + Besides the inputs and the outputs, this operator accepts two auxiliary + states, ``moving_mean`` and ``moving_var``, which are *k*-length + vectors. They are global statistics for the whole dataset, which are updated by + + .. code:: python + + moving_mean = moving_mean * momentum + data_mean * (1 - momentum) + moving_var = moving_var * momentum + data_var * (1 - momentum) + + The parameter ``axis`` specifies which axis of the input shape denotes + the 'channel' (separately normalized groups). The default is 1. + Specifying -1 sets the channel axis to be the last item in the input shape. + + .. note:: + + This operator can be optimized away for inference. + + Parameters + ---------- + data : relax.Expr + The input data to the operator. + + gamma : relax.Expr + The gamma scale factor. + + beta : relax.Expr + The beta offset factor. + + moving_mean : relax.Expr + Running mean of input. + + moving_var : relax.Expr + Running variance of input. + + axis : int + The axis along which the normalization is applied. + + epsilon : float + Small float added to variance to avoid dividing by zero. + + center : bool + Indicating if the beta offset will be added to the normalized tensor. + + scale : bool + Indicating if the gamma scale will be multiplied. + + Returns + ------- + result : relax.Expr + The computed result. + """ + return _ffi_api.batch_norm( # type: ignore + data, gamma, beta, moving_mean, moving_var, axis, epsilon, center, scale + ) + + +def layer_norm( + data: Expr, + gamma: Expr, + beta: Expr, + axes: Union[int, List[int]], + epsilon: float = 1e-5, + center: bool = True, + scale: bool = True, +) -> Expr: + r""" + Layer normalization (Lei Ba and et al., 2016). + Applies layer normalization to the n-dimensional input array. 
+ This operator takes an n-dimensional input array and normalizes + the input using the given axis: + + .. math:: + + out = \frac{data - mean(data, axis)}{\sqrt{var(data, axis)+\epsilon}} + * gamma + beta + + Unlike batch normalization, the mean and var are computed along the channel dimension. + + Assume the input has size k on axis 1, then both gamma and beta have shape (k,). + + .. note:: + + This operator can be optimized away for inference. + + Parameters + ---------- + data : relax.Expr + Input to which layer_norm will be applied. + + gamma : relax.Expr + The gamma scale factor. + + beta : relax.Expr + The beta offset factor. + + axes : Union[int, List[int]] + The axes that along which the normalization is applied. + + epsilon : float + Small float added to variance to avoid dividing by zero. + + center : bool + Indicating if the beta offset will be added to the normalized tensor. + + scale : bool + Indicating if the gamma scale will be multiplied. + + Returns + ------- + result : relax.Expr + The computed result. + """ + if isinstance(axes, int): + axes = [axes] + return _ffi_api.layer_norm(data, gamma, beta, axes, epsilon, center, scale) # type: ignore + + +def dropout(data: Expr, rate: float = 0.5) -> Expr: + """Applies the dropout operation to the input tensor. + + During training, each element of the input is set to zero with + probability ``p``. The whole array is scaled by ``1/(1-p)`` + to keep the expected sum of the input unchanged. + + Parameters + ---------- + data : relax.Expr + The input data to the operator. + + rate : float + The probability for an element to be reset to 0. + + Returns + ------- + result : relax.Expr + The result of dropout, which is a tuple of two tensors. + The first one is the original tensor and the second one is a + mask tensor (1.0 where element not dropped, 0.0 where dropped) + """ + return _ffi_api.dropout(data, rate) # type: ignore diff --git a/python/tvm/relax/op/op_attrs.py b/python/tvm/relax/op/op_attrs.py index 1fb8853040..68f84b3514 100644 --- a/python/tvm/relax/op/op_attrs.py +++ b/python/tvm/relax/op/op_attrs.py @@ -34,6 +34,41 @@ class StridedSliceAttrs(Attrs): """Attributes used in strided_slice operator""" +@tvm._ffi.register_object("relax.attrs.Conv2DAttrs") +class Conv2DAttrs(Attrs): + """Attributes for nn.conv2d""" + + +@tvm._ffi.register_object("relax.attrs.MaxPool2DAttrs") +class MaxPool2DAttrs(Attrs): + """Attributes for nn.max_pool2d""" + + +@tvm._ffi.register_object("relax.attrs.AdaptivePool2DAttrs") +class AdaptivePool2DAttrs(Attrs): + """Attributes for 2d adaptive pool operator""" + + +@tvm._ffi.register_object("relax.attrs.SoftmaxAttrs") +class SoftmaxAttrs(Attrs): + """Attributes for nn.softmax""" + + +@tvm._ffi.register_object("relax.attrs.BatchNormAttrs") +class BatchNormAttrs(Attrs): + """Attributes used in batch_norm operator""" + + +@tvm._ffi.register_object("relax.attrs.LayerNormAttrs") +class LayerNormAttrs(Attrs): + """Attributes used in layer_norm operator""" + + +@tvm._ffi.register_object("relax.attrs.DropoutAttrs") +class DropoutAttrs(Attrs): + """Attributes for dropout operator""" + + @tvm._ffi.register_object("relax.attrs.StatisticalAttrs") class StatisticalAttrs(Attrs): """Attributes used in statistical operator""" diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index 47779a6024..1f0e31428c 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -93,6 +93,7 @@ from tvm.relax.op import ( tan, tanh, unique, + nn, ) 
from tvm.relax.struct_info import StructInfo from tvm.relax.utils import args_converter @@ -530,4 +531,5 @@ __all__ = [ "tuple", "variance", "unique", + "nn", ] diff --git a/src/relax/op/nn/convolution.cc b/src/relax/op/nn/convolution.cc new file mode 100644 index 0000000000..a3ddd3e350 --- /dev/null +++ b/src/relax/op/nn/convolution.cc @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/op/nn/convolution.cc + * \brief Convolution operators + */ + +#include "convolution.h" + +#include <vector> + +namespace tvm { +namespace relax { + +/* relax.nn.conv2d */ +TVM_REGISTER_NODE_TYPE(Conv2DAttrs); + +Expr conv2d(Expr data, Expr weight, Array<IntImm> strides, Array<IntImm> padding, + Array<IntImm> dilation, int groups, String data_layout, String kernel_layout, + Optional<String> out_layout, DataType out_dtype) { + padding = GetCompletePadding2D(std::move(padding)); + if (strides.size() == 1) { + strides.push_back(strides[0]); + } + if (dilation.size() == 1) { + dilation.push_back(dilation[0]); + } + + CHECK_GT(groups, 0) << "The number of groups in convolution is expected to be positive. However, " + "the given number of groups is " + << groups; + CHECK_EQ(strides.size(), 2) + << "The input strides length is expected to be 2. However, the given strides is " << strides; + CHECK_EQ(dilation.size(), 2) + << "The input dilation length is expected to be 2. 
However, the given dilation is " + << dilation; + return MakeConv<Conv2DAttrs>(std::move(data), std::move(weight), std::move(strides), + std::move(padding), std::move(dilation), groups, data_layout, + std::move(kernel_layout), out_layout.value_or(data_layout), + out_dtype, /*op_name=*/"relax.nn.conv2d"); +} + +TVM_REGISTER_GLOBAL("relax.op.nn.conv2d").set_body_typed(conv2d); + +StructInfo InferStructInfoConv2d(const Call& call, const BlockBuilder& ctx) { + Array<TensorStructInfo> input_sinfo = GetInputTensorStructInfo(call, ctx); + TensorStructInfo data_sinfo = input_sinfo[0]; + TensorStructInfo weight_sinfo = input_sinfo[1]; + + const auto* attrs = call->attrs.as<Conv2DAttrs>(); + auto [data_layout, data2NCHW] = CheckTensorLayout(call, ctx, attrs->data_layout, // + /*tgt_layout=*/"NCHW", // + /*tensor_name=*/"data"); + auto [weight_layout, weight2OIHW] = CheckTensorLayout(call, ctx, attrs->kernel_layout, // + /*tgt_layout=*/"OIHW", // + /*tensor_name=*/"kernel"); + auto [out_layout, out2NCHW] = CheckTensorLayout(call, ctx, attrs->out_layout, // + /*tgt_layout=*/"NCHW", // + /*tensor_name=*/"output"); + + Optional<ShapeExpr> data_shape = + CheckNdimPerLayoutAndGetShape(call, ctx, data_sinfo, data_layout); + Optional<ShapeExpr> weight_shape = + CheckNdimPerLayoutAndGetShape(call, ctx, weight_sinfo, weight_layout); + + DataType out_dtype = attrs->out_dtype.is_void() + ? InferBinaryArithOpOutDtype(call, ctx, data_sinfo, weight_sinfo) + : attrs->out_dtype; + if (!data_shape.defined() || !weight_shape.defined()) { + return TensorStructInfo(out_dtype, out_layout.ndim()); + } + + Array<PrimExpr> data_NCHW_shape = data2NCHW.ForwardShape(data_shape.value()->values); + Array<PrimExpr> weight_OIHW_shape = weight2OIHW.ForwardShape(weight_shape.value()->values); + + arith::Analyzer* analyzer = ctx->GetAnalyzer(); + PrimExpr input_channel_data = data_NCHW_shape[1]; + PrimExpr input_channel_kernel = weight_OIHW_shape[1]; + if (analyzer->CanProve(input_channel_data != input_channel_kernel * attrs->groups)) { + ctx->ReportFatal( + Diagnostic::Error(call) + << "The channel size of the data should equal to the product of input channel size of the " + "weight and the number of groups. However, the data channel size is " + << input_channel_data << " while the weight input channel size and number of groups are " + << input_channel_kernel << " and " << attrs->groups); + } else if (!analyzer->CanProveEqual(input_channel_data, input_channel_kernel * attrs->groups)) { + // Todo(relax-team): Trust the input shape at this moment, and revisit + // this condition with runtime shape check + } + if (analyzer->CanProve(floormod(weight_OIHW_shape[0], attrs->groups) != 0)) { + ctx->ReportFatal(Diagnostic::Error(call) + << "Conv2d expects the number of output channels to be divisible by the " + "number of groups. 
However, the number of output channels is " + << weight_OIHW_shape[0] << " while the number of groups is " << attrs->groups); + } else if (!analyzer->CanProveEqual(floormod(weight_OIHW_shape[0], attrs->groups), 0)) { + // Todo(relax-team): Trust the input shape at this moment, and revisit + // this condition with runtime shape check + } + + PrimExpr input_h = data_NCHW_shape[2]; + PrimExpr input_w = data_NCHW_shape[3]; + PrimExpr kernel_h = weight_OIHW_shape[2]; + PrimExpr kernel_w = weight_OIHW_shape[3]; + PrimExpr padding_h = attrs->padding[0] + attrs->padding[2]; + PrimExpr padding_w = attrs->padding[1] + attrs->padding[3]; + + std::vector<PrimExpr> out_NCHW_shape; + out_NCHW_shape.resize(4); + out_NCHW_shape[0] = data_NCHW_shape[0]; + out_NCHW_shape[1] = weight_OIHW_shape[0]; + + PrimExpr numerator_h = input_h + padding_h - attrs->dilation[0] * (kernel_h - 1) - 1; + PrimExpr numerator_w = input_w + padding_w - attrs->dilation[1] * (kernel_w - 1) - 1; + out_NCHW_shape[2] = analyzer->Simplify(floordiv(numerator_h, attrs->strides[0]) + 1); + out_NCHW_shape[3] = analyzer->Simplify(floordiv(numerator_w, attrs->strides[1]) + 1); + + Array<PrimExpr> out_shape = out2NCHW.BackwardShape(out_NCHW_shape); + return TensorStructInfo(ShapeExpr(out_shape), out_dtype); +} + +TVM_REGISTER_OP("relax.nn.conv2d") + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("weight", "Tensor", "The weight tensor.") + .set_attrs_type<Conv2DAttrs>() + .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoConv2d); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/op/nn/convolution.h b/src/relax/op/nn/convolution.h new file mode 100644 index 0000000000..a65617b48d --- /dev/null +++ b/src/relax/op/nn/convolution.h @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file convolution.h + * \brief The functions to make Relax neural network convolution operator calls. 
+ */ + +#ifndef TVM_RELAX_OP_NN_CONVOLUTION_H_ +#define TVM_RELAX_OP_NN_CONVOLUTION_H_ + +#include <tvm/relax/attrs/nn.h> + +#include <string> +#include <utility> + +#include "../op_common.h" + +namespace tvm { +namespace relax { + +template <typename T> +inline Expr MakeConv(Expr data, Expr weight, Array<IntImm> strides, Array<IntImm> padding, + Array<IntImm> dilation, int groups, String data_layout, String kernel_layout, + String out_layout, DataType out_dtype, std::string op_name) { + auto attrs = make_object<T>(); + attrs->strides = ConvertIntImmToInt64(strides); + attrs->padding = ConvertIntImmToInt64(padding); + attrs->dilation = ConvertIntImmToInt64(dilation); + attrs->groups = groups; + attrs->data_layout = std::move(data_layout); + attrs->kernel_layout = std::move(kernel_layout); + attrs->out_layout = std::move(out_layout); + attrs->out_dtype = std::move(out_dtype); + const Op& op = Op::Get(op_name); + return Call(op, {data, weight}, Attrs(attrs), {}); +} + +/*! \brief 2D convolution */ +Expr conv2d(Expr data, Expr weight, Array<IntImm> strides, Array<IntImm> padding, + Array<IntImm> dilation, int groups, String data_layout, String kernel_layout, + Optional<String> out_layout, DataType out_dtype); + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_OP_NN_CONVOLUTION_H_ diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc new file mode 100644 index 0000000000..66ae10fe6c --- /dev/null +++ b/src/relax/op/nn/nn.cc @@ -0,0 +1,245 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "nn.h" + +#include <utility> +#include <vector> + +namespace tvm { +namespace relax { + +/* relax.nn.relu */ +RELAX_REGISTER_UNARY_NN_OP_AND_IMPL(relu, "nn.relu", /*require_float_dtype=*/false); + +/* relax.nn.gelu */ +RELAX_REGISTER_UNARY_NN_OP_AND_IMPL(gelu, "nn.gelu", /*require_float_dtype=*/true); + +/* relax.nn.silu */ +RELAX_REGISTER_UNARY_NN_OP_AND_IMPL(silu, "nn.silu", /*require_float_dtype=*/true); + +/* relax.nn.softmax */ +TVM_REGISTER_NODE_TYPE(SoftmaxAttrs); + +Expr softmax(Expr data, int axis) { + auto attrs = make_object<SoftmaxAttrs>(); + attrs->axis = axis; + static const Op& op = Op::Get("relax.nn.softmax"); + return Call(op, {data}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.nn.softmax").set_body_typed(softmax); + +StructInfo InferStructInfoSoftmax(const Call& call, const BlockBuilder& ctx) { + TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); + if (data_sinfo->IsUnknownNdim()) { + return data_sinfo; + } + if (!data_sinfo->IsUnknownDtype() && !data_sinfo->dtype.is_float()) { + ctx->ReportFatal(Diagnostic::Error(call) << "Softmax requires the input tensor to have float " + "dtype. 
However, the given input dtype is " + << data_sinfo->dtype); + } + const auto* attrs = call->attrs.as<SoftmaxAttrs>(); + NormalizeAxis(call, ctx, data_sinfo->ndim, attrs->axis); + + return data_sinfo; +} + +TVM_REGISTER_OP("relax.nn.softmax") + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_attrs_type<SoftmaxAttrs>() + .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoSoftmax); + +bool NormCheckDtypeAndShape(const Call& call, const BlockBuilder& ctx, + const Array<TensorStructInfo>& input_sinfo, Array<Integer> axes) { + Op op = Downcast<Op>(call->op); + int n_input = op->arguments.size(); + + TensorStructInfo data_sinfo = input_sinfo[0]; + + std::vector<int> axes_non_neg; + if (!data_sinfo->IsUnknownNdim()) { + axes_non_neg = NormalizeAxes(call, ctx, data_sinfo->ndim, axes); + } + int n_axis = axes.size(); + if (!data_sinfo->IsUnknownDtype() && !data_sinfo->dtype.is_float()) { + ctx->ReportFatal( + Diagnostic::Error(call) + << op << " requires the input data to have float dtype. However, the given data dtype is " + << data_sinfo->dtype); + } + for (int i = 1; i < n_input; ++i) { + if (input_sinfo[i]->dtype != data_sinfo->dtype) { + ctx->ReportFatal(Diagnostic::Error(call) + << op + << " requires all the input tensors to have the same dtype. However, the " + << op->arguments[i]->name << " has dtype " << input_sinfo[i]->dtype + << " which is other than the input data's dtype " << data_sinfo->dtype); + } else if (input_sinfo[i]->ndim != n_axis) { + ctx->ReportFatal(Diagnostic::Error(call) + << op << " requires the input " << op->arguments[i]->name + << " to have as many dimensions as the length of input axes. However, the " + "given one has ndim " + << input_sinfo[i]->ndim << ", which is other than the length of axes " + << n_axis); + } + } + + std::vector<Array<PrimExpr>> axis_lengths; + axis_lengths.reserve(n_input); + if (const auto* data_shape = data_sinfo->shape.as<ShapeExprNode>()) { + std::vector<PrimExpr> lengths; + lengths.reserve(n_axis); + for (int d = 0; d < n_axis; ++d) { + lengths.push_back(data_shape->values[axes_non_neg[d]]); + } + axis_lengths.push_back(lengths); + } + for (int i = 1; i < n_input; ++i) { + if (const auto* shape = input_sinfo[i]->shape.as<ShapeExprNode>()) { + axis_lengths.push_back(shape->values); + } + } + + arith::Analyzer* analyzer = ctx->GetAnalyzer(); + for (int i = 1; i < static_cast<int>(axis_lengths.size()); ++i) { + for (int d = 0; d < n_axis; ++d) { + if (analyzer->CanProve(axis_lengths[0][d] != axis_lengths[i][d])) { + ctx->ReportFatal(Diagnostic::Error(call) + << op + << " requires the input gamma, beta, etc., to have size same as the " + "lengths of the data on the given axes. 
However, there exists " + << axis_lengths[0] << " and " << axis_lengths[i] << " that are unequal."); + } else if (!analyzer->CanProveEqual(axis_lengths[0][d], axis_lengths[i][d])) { + return true; + } + } + } + return false; +} + +/* relax.nn.batch_norm */ +TVM_REGISTER_NODE_TYPE(BatchNormAttrs); + +Expr batch_norm(Expr data, Expr gamma, Expr beta, Expr moving_mean, Expr moving_var, // + int axis, double epsilon, bool center, bool scale) { + ObjectPtr<BatchNormAttrs> attrs = make_object<BatchNormAttrs>(); + attrs->axis = axis; + attrs->epsilon = epsilon; + attrs->center = center; + attrs->scale = scale; + + static const Op& op = Op::Get("relax.nn.batch_norm"); + return Call(op, + {std::move(data), std::move(gamma), std::move(beta), std::move(moving_mean), + std::move(moving_var)}, + Attrs{attrs}, {}); +} + +TVM_REGISTER_GLOBAL("relax.op.nn.batch_norm").set_body_typed(batch_norm); + +StructInfo InferStructInfoBatchNorm(const Call& call, const BlockBuilder& ctx) { + Array<TensorStructInfo> input_sinfo = GetInputTensorStructInfo(call, ctx); + + const auto* attrs = call->attrs.as<BatchNormAttrs>(); + bool unknown_shape = NormCheckDtypeAndShape(call, ctx, input_sinfo, {attrs->axis}); + + DataType dtype = input_sinfo[0]->dtype; + if (unknown_shape) { + return TupleStructInfo({TensorStructInfo(dtype, input_sinfo[0]->ndim), + TensorStructInfo(dtype, /*ndim=*/1), + TensorStructInfo(dtype, /*ndim=*/1)}); + } else { + return TupleStructInfo({input_sinfo[0], input_sinfo[3], input_sinfo[4]}); + } +} + +TVM_REGISTER_OP("relax.nn.batch_norm") + .set_attrs_type<BatchNormAttrs>() + .set_num_inputs(5) + .add_argument("data", "Tensor", "Input to which batch_norm will be applied.") + .add_argument("gamma", "Tensor", "The gamma scale factor.") + .add_argument("beta", "Tensor", "The beta offset factor.") + .add_argument("moving_mean", "Tensor", "Running mean of input.") + .add_argument("moving_var", "Tensor", "Running variance of input.") + .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoBatchNorm); + +/* relax.nn.layer_norm */ +TVM_REGISTER_NODE_TYPE(LayerNormAttrs); + +Expr layer_norm(Expr data, Expr gamma, Expr beta, Array<Integer> axes, double epsilon, bool center, + bool scale) { + ObjectPtr<LayerNormAttrs> attrs = make_object<LayerNormAttrs>(); + attrs->axes = std::move(axes); + attrs->epsilon = epsilon; + attrs->center = center; + attrs->scale = scale; + + static const Op& op = Op::Get("relax.nn.layer_norm"); + return Call(op, {std::move(data), std::move(gamma), std::move(beta)}, Attrs{attrs}, {}); +} + +TVM_REGISTER_GLOBAL("relax.op.nn.layer_norm").set_body_typed(layer_norm); + +StructInfo InferStructInfoLayerNorm(const Call& call, const BlockBuilder& ctx) { + Array<TensorStructInfo> input_sinfo = GetInputTensorStructInfo(call, ctx); + + const auto* attrs = call->attrs.as<LayerNormAttrs>(); + bool unknown_shape = NormCheckDtypeAndShape(call, ctx, input_sinfo, attrs->axes); + + return unknown_shape ? 
TensorStructInfo(input_sinfo[0]->dtype, input_sinfo[0]->ndim) + : input_sinfo[0]; +} + +TVM_REGISTER_OP("relax.nn.layer_norm") + .set_attrs_type<LayerNormAttrs>() + .set_num_inputs(3) + .add_argument("data", "Tensor", "Input to which batch_norm will be applied.") + .add_argument("gamma", "Tensor", "The gamma scale factor.") + .add_argument("beta", "Tensor", "The beta offset factor.") + .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoLayerNorm); + +/* relax.nn.dropout */ +TVM_REGISTER_NODE_TYPE(DropoutAttrs); + +Expr dropout(Expr data, double rate) { + ObjectPtr<DropoutAttrs> attrs = make_object<DropoutAttrs>(); + attrs->rate = rate; + + static const Op& op = Op::Get("relax.nn.dropout"); + return Call(op, {std::move(data)}, Attrs{attrs}, {}); +} + +TVM_REGISTER_GLOBAL("relax.op.nn.dropout").set_body_typed(dropout); + +StructInfo InferStructInfoDropout(const Call& call, const BlockBuilder& ctx) { + TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); + return TupleStructInfo({data_sinfo, data_sinfo}); +} + +TVM_REGISTER_OP("relax.nn.dropout") + .set_attrs_type<DropoutAttrs>() + .set_num_inputs(1) + .add_argument("data", "Tensor", "Input to which dropout will be applied.") + .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoDropout); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/op/nn/nn.h b/src/relax/op/nn/nn.h new file mode 100644 index 0000000000..df2b978fc2 --- /dev/null +++ b/src/relax/op/nn/nn.h @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file nn.h + * \brief The functions to make Relax neural network operator calls. + */ + +#ifndef TVM_RELAX_OP_NN_NN_H_ +#define TVM_RELAX_OP_NN_NN_H_ + +#include <tvm/relax/attrs/nn.h> + +#include "../op_common.h" + +namespace tvm { +namespace relax { + +/*! + * \brief Quick helper macro to + * - expose a make-function interface which construct the call node. + * - register op to the registry. + * \param OpName The name of operator to register. + * \param OpRegName The identifier of the operator in the registry. + * \param RequireFloatDtype A boolean indicating if the input is required to have float dtype. + */ +#define RELAX_REGISTER_UNARY_NN_OP_AND_IMPL(OpName, OpRegName, RequireFloatDtype) \ + RELAX_REGISTER_UNARY_OP(OpRegName).set_attr<FInferStructInfo>( \ + "FInferStructInfo", InferStructInfoUnaryArith<RequireFloatDtype>); \ + RELAX_UNARY_OP_INTERFACE(OpName, OpRegName); + +/*! \brief Rectified linear unit. */ +Expr relu(Expr data); + +/*! \brief Gaussian Error Linear Units function. */ +Expr gelu(Expr data); + +/*! \brief Sigmoid Linear Unit function. */ +Expr silu(Expr data); + +/*! \brief Softmax function. */ +Expr softmax(Expr data, int axis); + +/*! \brief Compute batch normalization. 
*/ +Expr batch_norm(Expr data, Expr gamma, Expr beta, Expr moving_mean, Expr moving_var, // + int axis, double epsilon, bool center, bool scale); + +/*! \brief Compute layer normalization. */ +Expr layer_norm(Expr data, Expr gamma, Expr beta, Array<Integer> axes, double epsilon, bool center, + bool scale); + +/*! + * \brief Applies the dropout operation to the input tensor. + * \param data The input data to the operator. + * \param rate The probability for an element to be reset to 0. + * \return A Tuple of two tensors. + * The first one is the original tensor and the second one is a + * mask tensor (1.0 where element not dropped, 0.0 where dropped) + */ +Expr dropout(Expr data, double rate); + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_OP_NN_NN_H_ diff --git a/src/relax/op/nn/pooling.cc b/src/relax/op/nn/pooling.cc new file mode 100644 index 0000000000..a4c1e6b17d --- /dev/null +++ b/src/relax/op/nn/pooling.cc @@ -0,0 +1,184 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "pooling.h" + +#include <utility> +#include <vector> + +namespace tvm { +namespace relax { + +/* relax.nn.max_pool2d */ +TVM_REGISTER_NODE_TYPE(MaxPool2DAttrs); + +Expr max_pool2d(Expr data, Array<IntImm> pool_size, Array<IntImm> strides, Array<IntImm> padding, + Array<IntImm> dilation, bool ceil_mode, String layout, + Optional<String> out_layout) { + padding = GetCompletePadding2D(std::move(padding)); + if (pool_size.size() == 1) { + pool_size.push_back(pool_size[0]); + } + if (strides.size() == 1) { + strides.push_back(strides[0]); + } + if (dilation.size() == 1) { + dilation.push_back(dilation[0]); + } + + CHECK_EQ(pool_size.size(), 2) + << "The input pool_size length is expected to be 2. However, the given pool_size is " + << pool_size; + CHECK_EQ(strides.size(), 2) + << "The input strides length is expected to be 2. However, the given strides is " << strides; + CHECK_EQ(dilation.size(), 2) + << "The input dilation length is expected to be 2. 
However, the given dilation is " + << dilation; + + auto attrs = make_object<MaxPool2DAttrs>(); + attrs->pool_size = std::move(pool_size); + attrs->strides = ConvertIntImmToInt64(strides); + attrs->padding = ConvertIntImmToInt64(padding); + attrs->dilation = ConvertIntImmToInt64(dilation); + attrs->ceil_mode = ceil_mode; + attrs->layout = layout; + attrs->out_layout = out_layout.value_or(layout); + static const Op& op = Op::Get("relax.nn.max_pool2d"); + return Call(op, {std::move(data)}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.nn.max_pool2d").set_body_typed(max_pool2d); + +StructInfo InferStructInfoMaxPool2D(const Call& call, const BlockBuilder& ctx) { + TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); + + const auto* attrs = call->attrs.as<MaxPool2DAttrs>(); + auto [data_layout, data2NCHW] = CheckTensorLayout(call, ctx, attrs->layout, // + /*tgt_layout=*/"NCHW", // + /*tensor_name=*/"data"); + auto [out_layout, out2NCHW] = CheckTensorLayout(call, ctx, attrs->out_layout, // + /*tgt_layout=*/"NCHW", // + /*tensor_name=*/"output"); + + Optional<ShapeExpr> data_shape = + CheckNdimPerLayoutAndGetShape(call, ctx, data_sinfo, data_layout); + if (!data_shape.defined()) { + return TensorStructInfo(data_sinfo->dtype, out_layout.ndim()); + } + + Array<PrimExpr> data_NCHW_shape = data2NCHW.ForwardShape(data_shape.value()->values); + + PrimExpr input_h = data_NCHW_shape[2]; + PrimExpr input_w = data_NCHW_shape[3]; + PrimExpr kernel_h = attrs->pool_size[0]; + PrimExpr kernel_w = attrs->pool_size[1]; + PrimExpr padding_h = attrs->padding[0] + attrs->padding[2]; + PrimExpr padding_w = attrs->padding[1] + attrs->padding[3]; + + arith::Analyzer* analyzer = ctx->GetAnalyzer(); + std::vector<PrimExpr> out_NCHW_shape; + out_NCHW_shape.resize(4); + out_NCHW_shape[0] = data_NCHW_shape[0]; + out_NCHW_shape[1] = data_NCHW_shape[1]; + + PrimExpr numerator_h = input_h + padding_h - attrs->dilation[0] * (kernel_h - 1) - 1; + PrimExpr numerator_w = input_w + padding_w - attrs->dilation[1] * (kernel_w - 1) - 1; + if (attrs->ceil_mode) { + numerator_h += attrs->strides[0] - 1; + numerator_w += attrs->strides[1] - 1; + } + out_NCHW_shape[2] = analyzer->Simplify(floordiv(numerator_h, attrs->strides[0]) + 1); + out_NCHW_shape[3] = analyzer->Simplify(floordiv(numerator_w, attrs->strides[1]) + 1); + + Array<PrimExpr> out_shape = out2NCHW.BackwardShape(out_NCHW_shape); + return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype); +} + +TVM_REGISTER_OP("relax.nn.max_pool2d") + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor") + .set_attrs_type<MaxPool2DAttrs>() + .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoMaxPool2D); + +/* relax.nn.adaptive_avg_pool2d */ +TVM_REGISTER_NODE_TYPE(AdaptivePool2DAttrs); + +Expr adaptive_avg_pool2d(Expr data, Optional<Array<IntImm>> output_size, String layout, + Optional<String> out_layout) { + ObjectPtr<AdaptivePool2DAttrs> attrs = make_object<AdaptivePool2DAttrs>(); + attrs->layout = layout; + attrs->out_layout = out_layout.value_or(layout); + if (output_size.defined()) { + Array<IntImm> _output_size = output_size.value(); + if (_output_size.size() == 1) { + _output_size.push_back(_output_size[0]); + } + CHECK_EQ(_output_size.size(), 2) + << "The output_size length is expected to be 2. 
However, the given output_size is " + << _output_size; + attrs->output_size = std::move(_output_size); + } + + static const Op& op = Op::Get("relax.nn.adaptive_avg_pool2d"); + return Call(op, {std::move(data)}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.nn.adaptive_avg_pool2d").set_body_typed(adaptive_avg_pool2d); + +StructInfo InferStructInfoAdaptiveAvgPool2D(const Call& call, const BlockBuilder& ctx) { + TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); + + const auto* attrs = call->attrs.as<AdaptivePool2DAttrs>(); + auto [data_layout, data2NCHW] = CheckTensorLayout(call, ctx, attrs->layout, // + /*tgt_layout=*/"NCHW", // + /*tensor_name=*/"data"); + auto [out_layout, out2NCHW] = CheckTensorLayout(call, ctx, attrs->out_layout, // + /*tgt_layout=*/"NCHW", // + /*tensor_name=*/"output"); + + Optional<ShapeExpr> data_shape = + CheckNdimPerLayoutAndGetShape(call, ctx, data_sinfo, data_layout); + if (!data_shape.defined()) { + if (data_sinfo->shape.defined() && attrs->out_layout == attrs->layout && + !attrs->output_size.defined()) { + return data_sinfo; + } else { + return TensorStructInfo(data_sinfo->dtype, out_layout.ndim()); + } + } + + Array<PrimExpr> data_NCHW_shape = data2NCHW.ForwardShape(data_shape.value()->values); + Array<PrimExpr> out_NCHW_shape(data_NCHW_shape); + if (attrs->output_size.defined()) { + out_NCHW_shape.Set(2, attrs->output_size.value()[0]); + out_NCHW_shape.Set(3, attrs->output_size.value()[1]); + } + + Array<PrimExpr> out_shape = out2NCHW.BackwardShape(out_NCHW_shape); + return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype); +} + +TVM_REGISTER_OP("relax.nn.adaptive_avg_pool2d") + .set_attrs_type<AdaptivePool2DAttrs>() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor") + .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoAdaptiveAvgPool2D); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/op/nn/pooling.h b/src/relax/op/nn/pooling.h new file mode 100644 index 0000000000..3c1792d21f --- /dev/null +++ b/src/relax/op/nn/pooling.h @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file pooling.h + * \brief The functions to make Relax neural network pooling operator calls. + */ + +#ifndef TVM_RELAX_OP_NN_POOLING_H_ +#define TVM_RELAX_OP_NN_POOLING_H_ + +#include <tvm/relax/attrs/nn.h> + +#include "../op_common.h" + +namespace tvm { +namespace relax { + +/*! \brief 2D maximum pooling operator. */ +Expr max_pool2d(Expr data, Array<IntImm> pool_size, Array<IntImm> strides, Array<IntImm> padding, + Array<IntImm> dilation, bool ceil_mode, String layout, Optional<String> out_layout); + +/*! \brief 2D adaptive average pooling operator. 
*/ +Expr adaptive_avg_pool2d(Expr data, Optional<Array<IntImm>> output_size, String layout, + Optional<String> out_layout); + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_OP_NN_POOLING_H_ diff --git a/tests/python/relax/test_op_nn.py b/tests/python/relax/test_op_nn.py new file mode 100644 index 0000000000..d047448309 --- /dev/null +++ b/tests/python/relax/test_op_nn.py @@ -0,0 +1,929 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest +import tvm +import tvm.testing +from tvm import relax, tir +from tvm import TVMError +from tvm.ir import Op +from tvm.script import relax as R + + +def test_op_correctness(): + x = relax.Var("x", R.Tensor((2, 3), "float32")) + assert relax.op.nn.relu(x).op == Op.get("relax.nn.relu") + assert relax.op.nn.gelu(x).op == Op.get("relax.nn.gelu") + assert relax.op.nn.silu(x).op == Op.get("relax.nn.silu") + assert relax.op.nn.softmax(x).op == Op.get("relax.nn.softmax") + assert relax.op.nn.dropout(x).op == Op.get("relax.nn.dropout") + + x = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32")) + gamma = relax.Var("gamma", R.Tensor((3,), "float32")) + beta = relax.Var("beta", R.Tensor((3,), "float32")) + moving_mean = relax.Var("moving_mean", R.Tensor((3,), "float32")) + moving_var = relax.Var("moving_var", R.Tensor((3,), "float32")) + assert relax.op.nn.batch_norm(x, gamma, beta, moving_mean, moving_var, axis=1).op == Op.get( + "relax.nn.batch_norm" + ) + assert relax.op.nn.layer_norm(x, gamma, beta, axes=1).op == Op.get("relax.nn.layer_norm") + + +def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo): + ret = bb.normalize(call) + tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo) + + +def test_linear_unit_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=3)) + x2 = relax.Var("x", R.Tensor("float32", ndim=-1)) + x3 = relax.Var("x", R.Tensor((2, 3))) + x4 = relax.Var("x", R.Tensor()) + + _check_inference(bb, relax.op.nn.relu(x0), relax.TensorStructInfo((2, 3), "float32")) + _check_inference(bb, relax.op.nn.silu(x1), relax.TensorStructInfo(dtype="float32", ndim=3)) + _check_inference(bb, relax.op.nn.gelu(x2), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.nn.relu(x3), relax.TensorStructInfo((2, 3), dtype="")) + _check_inference(bb, relax.op.nn.gelu(x4), relax.TensorStructInfo(dtype="")) + + +def test_linear_unit_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + m = tir.Var("m", "int64") + n = tir.Var("n", "int64") + x0 = relax.Var("x", R.Tensor((m, n), "float32")) + x1 = relax.Var("x", R.Tensor((4, n), "float32")) + + _check_inference(bb, relax.op.nn.silu(x0), relax.TensorStructInfo((m, n), "float32")) + 
_check_inference(bb, relax.op.nn.relu(x1), relax.TensorStructInfo((4, n), "float32")) + + +def test_linear_unit_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) + s1 = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + + _check_inference(bb, relax.op.nn.gelu(x0), relax.TensorStructInfo(s0, "float32")) + _check_inference(bb, relax.op.nn.relu(x1), relax.TensorStructInfo(s1, "float32")) + + +def test_linear_unit_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3), "float64")) + x1 = relax.Var("x", R.Tensor((2, 3), "int8")) + x2 = relax.Var("x", R.Tensor((2, 3), "int64")) + + _check_inference(bb, relax.op.nn.relu(x0), relax.TensorStructInfo((2, 3), "float64")) + _check_inference(bb, relax.op.nn.relu(x1), relax.TensorStructInfo((2, 3), "int8")) + _check_inference(bb, relax.op.nn.relu(x2), relax.TensorStructInfo((2, 3), "int64")) + + +def test_linear_unit_infer_struct_info_invalid_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3), "int8")) + x1 = relax.Var("x", R.Tensor((2, 3), "int64")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.gelu(x0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.silu(x1)) + + +def test_linear_unit_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", relax.ShapeStructInfo((2, 3))) + x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3), "float32"))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.gelu(x0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.silu(x1)) + + +def test_softmax_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=3)) + x2 = relax.Var("x", R.Tensor("float32", ndim=-1)) + x3 = relax.Var("x", R.Tensor((2, 3))) + x4 = relax.Var("x", R.Tensor()) + + _check_inference(bb, relax.op.nn.softmax(x0), relax.TensorStructInfo((2, 3), "float32")) + _check_inference( + bb, relax.op.nn.softmax(x1, axis=0), relax.TensorStructInfo(dtype="float32", ndim=3) + ) + _check_inference(bb, relax.op.nn.softmax(x2, axis=1), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.nn.softmax(x3, axis=-1), relax.TensorStructInfo((2, 3), dtype="")) + _check_inference(bb, relax.op.nn.softmax(x4, axis=-2), relax.TensorStructInfo(dtype="")) + + +def test_softmax_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + m = tir.Var("m", "int64") + n = tir.Var("n", "int64") + x0 = relax.Var("x", R.Tensor((m, n), "float32")) + x1 = relax.Var("x", R.Tensor((4, n), "float32")) + + _check_inference(bb, relax.op.nn.softmax(x0), relax.TensorStructInfo((m, n), "float32")) + _check_inference(bb, relax.op.nn.softmax(x1, axis=0), relax.TensorStructInfo((4, n), "float32")) + + +def test_softmax_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) + s1 = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + + _check_inference(bb, relax.op.nn.softmax(x0), relax.TensorStructInfo(s0, "float32")) + _check_inference(bb, relax.op.nn.softmax(x1), relax.TensorStructInfo(s1, "float32")) + + +def test_softmax_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + 
x0 = relax.Var("x", R.Tensor((2, 3), "float16")) + x1 = relax.Var("x", R.Tensor((2, 3), "float64")) + + _check_inference(bb, relax.op.nn.softmax(x0), relax.TensorStructInfo((2, 3), "float16")) + _check_inference(bb, relax.op.nn.softmax(x1), relax.TensorStructInfo((2, 3), "float64")) + + +def test_softmax_infer_struct_info_invalid_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3), "int8")) + x1 = relax.Var("x", R.Tensor((2, 3), "int64")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.softmax(x0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.softmax(x1)) + + +def test_softmax_infer_struct_info_axis_out_of_range(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3, 4), "float32")) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.softmax(x, axis=3)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.softmax(x, axis=-4)) + + +def test_softmax_wrong_with_multiple_axes(): + x = relax.Var("x", R.Tensor((2, 3, 4), "float32")) + with pytest.raises(TVMError): + relax.op.nn.softmax(x, axis=[1, 2]) + with pytest.raises(TVMError): + relax.op.nn.softmax(x, axis=[-1, -2, -3]) + + +def test_softmax_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", relax.ShapeStructInfo((2, 3))) + x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3), "float32"))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.softmax(x0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.softmax(x1)) + + +def test_batch_norm_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=4)) + x2 = relax.Var("x", R.Tensor("float32")) + x3 = relax.Var("x", R.Tensor(ndim=4)) + x4 = relax.Var("x", R.Tensor()) + gamma0 = relax.Var("gamma", R.Tensor((3,), "float32")) + gamma1 = relax.Var("gamma", R.Tensor("float32", ndim=1)) + gamma2 = relax.Var("gamma", R.Tensor(ndim=1)) + beta0 = relax.Var("beta", R.Tensor((3,), "float32")) + beta1 = relax.Var("beta", R.Tensor((3,))) + moving_mean0 = relax.Var("moving_mean", R.Tensor((3,), "float32")) + moving_mean1 = relax.Var("moving_mean", R.Tensor((3,))) + moving_var0 = relax.Var("moving_var", R.Tensor((3,), "float32")) + moving_var1 = relax.Var("moving_var", R.Tensor("float32", ndim=1)) + moving_var2 = relax.Var("moving_var", R.Tensor(ndim=1)) + + _check_inference( + bb, + relax.op.nn.batch_norm(x0, gamma0, beta0, moving_mean0, moving_var0, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo((2, 3, 28, 28), "float32"), + relax.TensorStructInfo((3,), "float32"), + relax.TensorStructInfo((3,), "float32"), + ] + ), + ) + _check_inference( + bb, + relax.op.nn.batch_norm(x0, gamma0, beta0, moving_mean0, moving_var0, axis=-3), + relax.TupleStructInfo( + [ + relax.TensorStructInfo((2, 3, 28, 28), "float32"), + relax.TensorStructInfo((3,), "float32"), + relax.TensorStructInfo((3,), "float32"), + ] + ), + ) + _check_inference( + bb, + relax.op.nn.batch_norm(x1, gamma0, beta0, moving_mean0, moving_var0, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=4), + relax.TensorStructInfo((3,), "float32"), + relax.TensorStructInfo((3,), "float32"), + ] + ), + ) + _check_inference( + bb, + relax.op.nn.batch_norm(x0, gamma1, beta0, moving_mean0, moving_var0, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo((2, 3, 28, 28), "float32"), + relax.TensorStructInfo((3,), "float32"), + relax.TensorStructInfo((3,), "float32"), + 
] + ), + ) + _check_inference( + bb, + relax.op.nn.batch_norm(x0, gamma0, beta0, moving_mean0, moving_var1, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo((2, 3, 28, 28), "float32"), + relax.TensorStructInfo((3,), "float32"), + relax.TensorStructInfo(dtype="float32", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.nn.batch_norm(x1, gamma1, beta0, moving_mean0, moving_var1, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=4), + relax.TensorStructInfo((3,), "float32"), + relax.TensorStructInfo(dtype="float32", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.nn.batch_norm(x2, gamma1, beta0, moving_mean0, moving_var1, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32"), + relax.TensorStructInfo((3,), "float32"), + relax.TensorStructInfo(dtype="float32", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.nn.batch_norm(x3, gamma2, beta1, moving_mean1, moving_var2, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(ndim=4, dtype=""), + relax.TensorStructInfo((3,), dtype=""), + relax.TensorStructInfo(dtype="", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.nn.batch_norm(x4, gamma2, beta1, moving_mean1, moving_var2, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype=""), + relax.TensorStructInfo((3,), dtype=""), + relax.TensorStructInfo(dtype="", ndim=1), + ] + ), + ) + + +def test_batch_norm_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + n = tir.Var("n", "int64") + c0 = tir.Var("c", "int64") + c1 = tir.Var("c", "int64") + h = tir.Var("h", "int64") + w = tir.Var("w", "int64") + x0 = relax.Var("x", R.Tensor((n, c0, h, w), "float32")) + x1 = relax.Var("x", R.Tensor((n, c1, h, w), "float32")) + x2 = relax.Var("x", R.Tensor("float32", ndim=4)) + gamma0 = relax.Var("gamma", R.Tensor((c0,), "float32")) + gamma1 = relax.Var("gamma", R.Tensor((c1,), "float32")) + gamma2 = relax.Var("gamma", R.Tensor("float32", ndim=1)) + beta = relax.Var("beta", R.Tensor((c0,), "float32")) + moving_mean = relax.Var("moving_mean", R.Tensor((c0,), "float32")) + moving_var0 = relax.Var("moving_var", R.Tensor((c0,), "float32")) + moving_var1 = relax.Var("moving_var", R.Tensor((c1,), "float32")) + moving_var2 = relax.Var("moving_var", R.Tensor("float32", ndim=1)) + + _check_inference( + bb, + relax.op.nn.batch_norm(x0, gamma0, beta, moving_mean, moving_var0, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo((n, c0, h, w), "float32"), + relax.TensorStructInfo((c0,), "float32"), + relax.TensorStructInfo((c0,), "float32"), + ] + ), + ) + _check_inference( + bb, + relax.op.nn.batch_norm(x1, gamma0, beta, moving_mean, moving_var0, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=4), + relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorStructInfo(dtype="float32", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.nn.batch_norm(x2, gamma0, beta, moving_mean, moving_var0, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=4), + relax.TensorStructInfo((c0,), "float32"), + relax.TensorStructInfo((c0,), "float32"), + ] + ), + ) + _check_inference( + bb, + relax.op.nn.batch_norm(x0, gamma1, beta, moving_mean, moving_var0, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=4), + relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorStructInfo(dtype="float32", ndim=1), + ] + ), + ) + _check_inference( + bb, + 
relax.op.nn.batch_norm(x0, gamma0, beta, moving_mean, moving_var1, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=4), + relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorStructInfo(dtype="float32", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.nn.batch_norm(x0, gamma2, beta, moving_mean, moving_var0, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo((n, c0, h, w), "float32"), + relax.TensorStructInfo((c0,), "float32"), + relax.TensorStructInfo((c0,), "float32"), + ] + ), + ) + _check_inference( + bb, + relax.op.nn.batch_norm(x0, gamma0, beta, moving_mean, moving_var2, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo((n, c0, h, w), "float32"), + relax.TensorStructInfo((c0,), "float32"), + relax.TensorStructInfo(dtype="float32", ndim=1), + ] + ), + ) + + +def test_batch_norm_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s0", relax.ShapeStructInfo(ndim=4)) + s1 = relax.Var("s1", relax.ShapeStructInfo()) + s2 = relax.Var("s2", relax.ShapeStructInfo(ndim=1)) + s3 = relax.Var("s3", relax.ShapeStructInfo(ndim=1)) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + gamma = relax.Var("gamma", relax.TensorStructInfo(s2, "float32")) + beta = relax.Var("beta", relax.TensorStructInfo(s3, "float32")) + moving_mean = relax.Var("moving_mean", relax.TensorStructInfo(s2, "float32")) + moving_var = relax.Var("moving_var", relax.TensorStructInfo(s3, "float32")) + + _check_inference( + bb, + relax.op.nn.batch_norm(x0, gamma, beta, moving_mean, moving_var, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(s0, "float32"), + relax.TensorStructInfo(s2, "float32"), + relax.TensorStructInfo(s3, "float32"), + ] + ), + ) + _check_inference( + bb, + relax.op.nn.batch_norm(x1, gamma, beta, moving_mean, moving_var, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(s1, "float32"), + relax.TensorStructInfo(s2, "float32"), + relax.TensorStructInfo(s3, "float32"), + ] + ), + ) + + +def test_batch_norm_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float16")) + gamma = relax.Var("gamma", R.Tensor((3,), "float16")) + beta = relax.Var("beta", R.Tensor((3,), "float16")) + moving_mean = relax.Var("moving_mean", R.Tensor((3,), "float16")) + moving_var = relax.Var("moving_var", R.Tensor((3,), "float16")) + + _check_inference( + bb, + relax.op.nn.batch_norm(x, gamma, beta, moving_mean, moving_var, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo((2, 3, 28, 28), "float16"), + relax.TensorStructInfo((3,), "float16"), + relax.TensorStructInfo((3,), "float16"), + ] + ), + ) + + +def test_batch_norm_infer_struct_info_invalid_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 28, 28), "int8")) + gamma0 = relax.Var("gamma", R.Tensor((3,), "int8")) + beta0 = relax.Var("beta", R.Tensor((3,), "int8")) + moving_mean0 = relax.Var("moving_mean", R.Tensor((3,), "int8")) + moving_var0 = relax.Var("moving_var", R.Tensor((3,), "int8")) + x1 = relax.Var("x", R.Tensor((2, 3, 28, 28), "int32")) + gamma1 = relax.Var("gamma", R.Tensor((3,), "int32")) + beta1 = relax.Var("beta", R.Tensor((3,), "int32")) + moving_mean1 = relax.Var("moving_mean", R.Tensor((3,), "int32")) + moving_var1 = relax.Var("moving_var", R.Tensor((3,), "int32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.batch_norm(x0, gamma0, beta0, 
moving_mean0, moving_var0, axis=1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.batch_norm(x1, gamma1, beta1, moving_mean1, moving_var1, axis=1)) + + +def test_batch_norm_infer_struct_info_axis_out_of_range(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) + gamma = relax.Var("gamma", R.Tensor((3,), "float32")) + beta = relax.Var("beta", R.Tensor((3,), "float32")) + moving_mean = relax.Var("moving_mean", R.Tensor((3,), "float32")) + moving_var = relax.Var("moving_var", R.Tensor((3,), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.batch_norm(x, gamma, beta, moving_mean, moving_var, axis=4)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.batch_norm(x, gamma, beta, moving_mean, moving_var, axis=-5)) + + +def test_batch_norm_infer_struct_info_dtype_mismatch(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) + x1 = relax.Var("x", R.Tensor((2, 3, 28, 28), "int8")) + gamma0 = relax.Var("gamma", R.Tensor((3,), "float32")) + gamma1 = relax.Var("gamma", R.Tensor((3,))) + beta = relax.Var("beta", R.Tensor((3,), "float32")) + moving_mean = relax.Var("moving_mean", R.Tensor((3,), "float32")) + moving_var0 = relax.Var("moving_var", R.Tensor((3,), "float32")) + moving_var1 = relax.Var("moving_var", R.Tensor((3,), "float16")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.batch_norm(x1, gamma0, beta, moving_mean, moving_var0, axis=1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.batch_norm(x0, gamma1, beta, moving_mean, moving_var0, axis=1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.batch_norm(x0, gamma0, beta, moving_mean, moving_var1, axis=1)) + + +def test_batch_norm_infer_struct_info_ndim_mismatch(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) + gamma0 = relax.Var("gamma", R.Tensor((3,), "float32")) + gamma1 = relax.Var("gamma", R.Tensor((3, 1), "float32")) + beta = relax.Var("beta", R.Tensor((3,), "float32")) + moving_mean = relax.Var("moving_mean", R.Tensor((3,), "float32")) + moving_var0 = relax.Var("moving_var", R.Tensor((3,), "float32")) + moving_var1 = relax.Var("moving_var", R.Tensor((1, 3), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.batch_norm(x, gamma1, beta, moving_mean, moving_var0, axis=1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.batch_norm(x, gamma0, beta, moving_mean, moving_var1, axis=1)) + + +def test_batch_norm_infer_struct_info_shape_mismatch(): + bb = relax.BlockBuilder() + c = tir.Var("c", "int64") + x0 = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) + x1 = relax.Var("x", R.Tensor((2, c, 28, 28), "float32")) + gamma0 = relax.Var("gamma", R.Tensor((3,), "float32")) + gamma1 = relax.Var("gamma", R.Tensor((4,), "float32")) + gamma2 = relax.Var("gamma", R.Tensor((c + 2,), "float32")) + beta0 = relax.Var("beta", R.Tensor((3,), "float32")) + beta1 = relax.Var("beta", R.Tensor((c,), "float32")) + moving_mean0 = relax.Var("moving_mean", R.Tensor((3,), "float32")) + moving_mean1 = relax.Var("moving_mean", R.Tensor((c,), "float32")) + moving_var0 = relax.Var("moving_var", R.Tensor((3,), "float32")) + moving_var1 = relax.Var("moving_var", R.Tensor((4,), "float32")) + moving_var2 = relax.Var("moving_var", R.Tensor((c,), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.batch_norm(x0, gamma1, beta0, moving_mean0, moving_var0, axis=1)) + with pytest.raises(TVMError): + 
bb.normalize(relax.op.nn.batch_norm(x0, gamma0, beta0, moving_mean0, moving_var1, axis=1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.batch_norm(x1, gamma2, beta1, moving_mean1, moving_var2, axis=1)) + + +def test_batch_norm_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) + x1 = relax.Var("x", relax.ShapeStructInfo((2, 3, 28, 28))) + gamma0 = relax.Var("gamma", R.Tensor((3,), "float32")) + gamma1 = relax.Var("gamma", relax.FuncStructInfo([], R.Tensor((3,), "float32"))) + beta = relax.Var("beta", R.Tensor((3,), "float32")) + moving_mean = relax.Var("moving_mean", R.Tensor((3,), "float32")) + moving_var = relax.Var("moving_var", R.Tensor((3,), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.batch_norm(x1, gamma0, beta, moving_mean, moving_var, axis=1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.batch_norm(x0, gamma1, beta, moving_mean, moving_var, axis=1)) + + +def test_layer_norm_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=4)) + x2 = relax.Var("x", R.Tensor("float32")) + x3 = relax.Var("x", R.Tensor((2, 3, 4, 5))) + gamma0 = relax.Var("gamma", R.Tensor((4, 5), "float32")) + gamma1 = relax.Var("gamma", R.Tensor("float32", ndim=2)) + gamma2 = relax.Var("gamma", R.Tensor((4, 5))) + beta0 = relax.Var("beta", R.Tensor((4, 5), "float32")) + beta1 = relax.Var("beta", R.Tensor((4, 5))) + + _check_inference( + bb, + relax.op.nn.layer_norm(x0, gamma0, beta0, axes=[-2, -1]), + relax.TensorStructInfo((2, 3, 4, 5), "float32"), + ) + _check_inference( + bb, + relax.op.nn.layer_norm(x0, gamma0, beta0, axes=[-2, 3]), + relax.TensorStructInfo((2, 3, 4, 5), "float32"), + ) + _check_inference( + bb, + relax.op.nn.layer_norm(x1, gamma0, beta0, axes=[-2, -1]), + relax.TensorStructInfo(dtype="float32", ndim=4), + ) + _check_inference( + bb, + relax.op.nn.layer_norm(x2, gamma0, beta0, axes=[-2, -1]), + relax.TensorStructInfo(dtype="float32"), + ) + _check_inference( + bb, + relax.op.nn.layer_norm(x0, gamma1, beta0, axes=[-2, -1]), + relax.TensorStructInfo((2, 3, 4, 5), dtype="float32"), + ) + _check_inference( + bb, + relax.op.nn.layer_norm(x3, gamma2, beta1, axes=[-2, -1]), + relax.TensorStructInfo((2, 3, 4, 5), dtype=""), + ) + + +def test_layer_norm_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + n = tir.Var("n", "int64") + a = tir.Var("a", "int64") + b = tir.Var("b", "int64") + c0 = tir.Var("c", "int64") + c1 = tir.Var("c", "int64") + x0 = relax.Var("x", R.Tensor((n, a, b, c0), "float32")) + x1 = relax.Var("x", R.Tensor((n, a, b, c1), "float32")) + x2 = relax.Var("x", R.Tensor("float32", ndim=4)) + gamma0 = relax.Var("gamma", R.Tensor((b, c0), "float32")) + gamma1 = relax.Var("gamma", R.Tensor((b, c1), "float32")) + beta = relax.Var("beta", R.Tensor((b, c0), "float32")) + + _check_inference( + bb, + relax.op.nn.layer_norm(x0, gamma0, beta, axes=[-2, -1]), + relax.TensorStructInfo((n, a, b, c0), "float32"), + ) + _check_inference( + bb, + relax.op.nn.layer_norm(x1, gamma0, beta, axes=[-2, -1]), + relax.TensorStructInfo(dtype="float32", ndim=4), + ) + _check_inference( + bb, + relax.op.nn.layer_norm(x0, gamma1, beta, axes=[-2, -1]), + relax.TensorStructInfo(dtype="float32", ndim=4), + ) + _check_inference( + bb, + relax.op.nn.layer_norm(x2, gamma0, beta, axes=[-2, -1]), + relax.TensorStructInfo(dtype="float32", ndim=4), + ) + _check_inference( + bb, + 
relax.op.nn.layer_norm(x2, gamma1, beta, axes=[-2, -1]), + relax.TensorStructInfo(dtype="float32", ndim=4), + ) + + +def test_layer_norm_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s0", relax.ShapeStructInfo(ndim=4)) + s1 = relax.Var("s1", relax.ShapeStructInfo()) + s2 = relax.Var("s2", relax.ShapeStructInfo(ndim=2)) + s3 = relax.Var("s3", relax.ShapeStructInfo(ndim=2)) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + gamma = relax.Var("gamma", relax.TensorStructInfo(s2, "float32")) + beta = relax.Var("beta", relax.TensorStructInfo(s3, "float32")) + + _check_inference( + bb, + relax.op.nn.layer_norm(x0, gamma, beta, axes=[2, 3]), + relax.TensorStructInfo(s0, "float32"), + ) + _check_inference( + bb, + relax.op.nn.layer_norm(x1, gamma, beta, axes=[2, 3]), + relax.TensorStructInfo(s1, "float32"), + ) + + +def test_layer_norm_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float16")) + gamma0 = relax.Var("gamma", R.Tensor((4, 5), "float16")) + beta0 = relax.Var("beta", R.Tensor((4, 5), "float16")) + x1 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float64")) + gamma1 = relax.Var("gamma", R.Tensor((4, 5), "float64")) + beta1 = relax.Var("beta", R.Tensor((4, 5), "float64")) + + _check_inference( + bb, + relax.op.nn.layer_norm(x0, gamma0, beta0, axes=[-2, -1]), + relax.TensorStructInfo((2, 3, 4, 5), "float16"), + ) + _check_inference( + bb, + relax.op.nn.layer_norm(x1, gamma1, beta1, axes=[-2, -1]), + relax.TensorStructInfo((2, 3, 4, 5), "float64"), + ) + + +def test_layer_norm_infer_struct_info_invalid_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "int8")) + gamma0 = relax.Var("gamma", R.Tensor((4, 5), "int8")) + beta0 = relax.Var("beta", R.Tensor((4, 5), "int8")) + x1 = relax.Var("x", R.Tensor((2, 3, 4, 5), "int32")) + gamma1 = relax.Var("gamma", R.Tensor((4, 5), "int32")) + beta1 = relax.Var("beta", R.Tensor((4, 5), "int32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.layer_norm(x0, gamma0, beta0, axes=[-2, -1])) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.layer_norm(x1, gamma1, beta1, axes=[-2, -1])) + + +def test_layer_norm_infer_struct_info_axis_out_of_range_and_repetitive(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) + gamma = relax.Var("gamma", R.Tensor((4, 5), "float32")) + beta = relax.Var("beta", R.Tensor((4, 5), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.layer_norm(x, gamma, beta, axes=[3, 4])) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.layer_norm(x, gamma, beta, axes=[3, -1])) + + +def test_layer_norm_infer_struct_info_dtype_mismatch(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) + gamma0 = relax.Var("gamma", R.Tensor((4, 5), "float32")) + gamma1 = relax.Var("gamma", R.Tensor((4, 5), "int8")) + beta0 = relax.Var("beta", R.Tensor((4, 5), "float32")) + beta1 = relax.Var("beta", R.Tensor((4, 5))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.layer_norm(x, gamma1, beta0, axes=[-2, -1])) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.layer_norm(x, gamma0, beta1, axes=[-2, -1])) + + +def test_layer_norm_infer_struct_info_ndim_mismatch(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) + gamma0 = relax.Var("gamma", R.Tensor((4, 5), "float32")) + gamma1 = 
relax.Var("gamma", R.Tensor((4,), "float32")) + beta0 = relax.Var("beta", R.Tensor((4, 5), "float32")) + beta1 = relax.Var("beta", R.Tensor((3, 4, 5), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.layer_norm(x, gamma1, beta0, axes=[-2, -1])) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.layer_norm(x, gamma0, beta1, axes=[-2, -1])) + + +def test_layer_norm_infer_struct_info_shape_mismatch(): + bb = relax.BlockBuilder() + c0 = tir.Var("c", "int64") + x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) + x1 = relax.Var("x", R.Tensor((2, 3, 4, c0), "float32")) + gamma0 = relax.Var("gamma", R.Tensor((4, 6), "float32")) + gamma1 = relax.Var("gamma", R.Tensor((4, c0), "float32")) + beta0 = relax.Var("beta", R.Tensor((4, 5), "float32")) + beta1 = relax.Var("beta", R.Tensor((4, c0 - 2), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.layer_norm(x0, gamma0, beta0, axes=[-2, -1])) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.layer_norm(x1, gamma1, beta1, axes=[-2, -1])) + + +def test_layer_norm_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) + x1 = relax.Var("x", relax.ShapeStructInfo((2, 3, 4, 5))) + gamma0 = relax.Var("gamma", R.Tensor((4, 5), "float32")) + gamma1 = relax.Var("gamma", relax.FuncStructInfo([], R.Tensor((4, 5), "float32"))) + beta = relax.Var("beta", R.Tensor((4, 5), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.layer_norm(x1, gamma0, beta, axes=[-2, -1])) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.layer_norm(x0, gamma1, beta, axes=[-2, -1])) + + +def test_dropout_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=3)) + x2 = relax.Var("x", R.Tensor("float32", ndim=-1)) + x3 = relax.Var("x", R.Tensor((2, 3))) + x4 = relax.Var("x", R.Tensor()) + + _check_inference( + bb, + relax.op.nn.dropout(x0), + relax.TupleStructInfo( + [relax.TensorStructInfo((2, 3), "float32"), relax.TensorStructInfo((2, 3), "float32")] + ), + ) + _check_inference( + bb, + relax.op.nn.dropout(x1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorStructInfo(dtype="float32", ndim=3), + ] + ), + ) + _check_inference( + bb, + relax.op.nn.dropout(x2), + relax.TupleStructInfo( + [relax.TensorStructInfo(dtype="float32"), relax.TensorStructInfo(dtype="float32")] + ), + ) + _check_inference( + bb, + relax.op.nn.dropout(x3), + relax.TupleStructInfo( + [relax.TensorStructInfo((2, 3), dtype=""), relax.TensorStructInfo((2, 3), dtype="")] + ), + ) + _check_inference( + bb, + relax.op.nn.dropout(x4), + relax.TupleStructInfo([relax.TensorStructInfo(dtype=""), relax.TensorStructInfo(dtype="")]), + ) + + +def test_dropout_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + m = tir.Var("m", "int64") + n = tir.Var("n", "int64") + x = relax.Var("x", R.Tensor((m, n), "float32")) + + _check_inference( + bb, + relax.op.nn.dropout(x), + relax.TupleStructInfo( + [relax.TensorStructInfo((m, n), "float32"), relax.TensorStructInfo((m, n), "float32")] + ), + ) + + +def test_dropout_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) + s1 = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + + _check_inference( + bb, + 
relax.op.nn.dropout(x0), + relax.TupleStructInfo( + [relax.TensorStructInfo(s0, "float32"), relax.TensorStructInfo(s0, "float32")] + ), + ) + _check_inference( + bb, + relax.op.nn.dropout(x1), + relax.TupleStructInfo( + [relax.TensorStructInfo(s1, "float32"), relax.TensorStructInfo(s1, "float32")] + ), + ) + + +def test_dropout_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3), "float64")) + x1 = relax.Var("x", R.Tensor((2, 3), "int8")) + x2 = relax.Var("x", R.Tensor((2, 3), "int64")) + + _check_inference( + bb, + relax.op.nn.dropout(x0), + relax.TupleStructInfo( + [relax.TensorStructInfo((2, 3), "float64"), relax.TensorStructInfo((2, 3), "float64")] + ), + ) + _check_inference( + bb, + relax.op.nn.dropout(x1), + relax.TupleStructInfo( + [relax.TensorStructInfo((2, 3), "int8"), relax.TensorStructInfo((2, 3), "int8")] + ), + ) + _check_inference( + bb, + relax.op.nn.dropout(x2), + relax.TupleStructInfo( + [relax.TensorStructInfo((2, 3), "int64"), relax.TensorStructInfo((2, 3), "int64")] + ), + ) + + +def test_dropout_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", relax.ShapeStructInfo((2, 3))) + x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3), "float32"))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.dropout(x0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.dropout(x1)) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_op_nn_convolution.py b/tests/python/relax/test_op_nn_convolution.py new file mode 100644 index 0000000000..6533d43420 --- /dev/null +++ b/tests/python/relax/test_op_nn_convolution.py @@ -0,0 +1,429 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+import pytest +import tvm +import tvm.testing +from tvm import relax, tir +from tvm import TVMError +from tvm.ir import Op +from tvm.script import relax as R + + +def test_op_correctness(): + x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) + w = relax.Var("w", R.Tensor((4, 3, 3, 3), "float32")) + assert relax.op.nn.conv2d(x, w).op == Op.get("relax.nn.conv2d") + + +def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo): + ret = bb.normalize(call) + tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo) + + +def test_conv2d_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) + x1 = relax.Var("x", R.Tensor((2, 28, 28, 3), "float32")) + x2 = relax.Var("x", R.Tensor("float32", ndim=4)) + x3 = relax.Var("x", R.Tensor("float32")) + x4 = relax.Var("x", R.Tensor()) + x5 = relax.Var("x", R.Tensor((2, 4, 28, 28, 16), "float32")) + w0 = relax.Var("w", R.Tensor((4, 3, 3, 3), "float32")) + w1 = relax.Var("w", R.Tensor((3, 4, 3, 3), "float32")) + w2 = relax.Var("w", R.Tensor("float32", ndim=4)) + w3 = relax.Var("w", R.Tensor("float32")) + w4 = relax.Var("w", R.Tensor((48, 4, 3, 3, 16), "float32")) + + _check_inference( + bb, relax.op.nn.conv2d(x0, w0), relax.TensorStructInfo((2, 4, 26, 26), "float32") + ) + _check_inference( + bb, + relax.op.nn.conv2d(x0, w0, out_dtype="float16"), + relax.TensorStructInfo((2, 4, 26, 26), "float16"), + ) + _check_inference( + bb, relax.op.nn.conv2d(x0, w0, padding=1), relax.TensorStructInfo((2, 4, 28, 28), "float32") + ) + _check_inference( + bb, + relax.op.nn.conv2d(x0, w0, padding=[1, 2]), + relax.TensorStructInfo((2, 4, 28, 30), "float32"), + ) + _check_inference( + bb, + relax.op.nn.conv2d(x0, w0, padding=[1, 2, 3, 4]), + relax.TensorStructInfo((2, 4, 30, 32), "float32"), + ) + _check_inference( + bb, + relax.op.nn.conv2d(x0, w0, strides=2), + relax.TensorStructInfo((2, 4, 13, 13), "float32"), + ) + _check_inference( + bb, + relax.op.nn.conv2d(x0, w0, strides=(2, 3)), + relax.TensorStructInfo((2, 4, 13, 9), "float32"), + ) + _check_inference( + bb, + relax.op.nn.conv2d(x0, w0, dilation=2), + relax.TensorStructInfo((2, 4, 24, 24), "float32"), + ) + _check_inference( + bb, + relax.op.nn.conv2d(x0, w0, dilation=(2, 1)), + relax.TensorStructInfo((2, 4, 24, 26), "float32"), + ) + _check_inference( + bb, + relax.op.nn.conv2d(x1, w0, data_layout="NHWC"), + relax.TensorStructInfo((2, 26, 26, 4), "float32"), + ) + _check_inference( + bb, + relax.op.nn.conv2d(x0, w0, out_layout="NHWC"), + relax.TensorStructInfo((2, 26, 26, 4), "float32"), + ) + _check_inference( + bb, + relax.op.nn.conv2d(x0, w1, kernel_layout="IOHW"), + relax.TensorStructInfo((2, 4, 26, 26), "float32"), + ) + _check_inference( + bb, + relax.op.nn.conv2d( + x5, w4, data_layout="NCHW16c", kernel_layout="OIHW16i", out_layout="NHWC16c" + ), + relax.TensorStructInfo((2, 26, 26, 3, 16), "float32"), + ) + _check_inference( + bb, relax.op.nn.conv2d(x2, w0), relax.TensorStructInfo(dtype="float32", ndim=4) + ) + _check_inference( + bb, relax.op.nn.conv2d(x3, w0), relax.TensorStructInfo(dtype="float32", ndim=4) + ) + _check_inference( + bb, relax.op.nn.conv2d(x0, w2), relax.TensorStructInfo(dtype="float32", ndim=4) + ) + _check_inference( + bb, relax.op.nn.conv2d(x0, w3), relax.TensorStructInfo(dtype="float32", ndim=4) + ) + _check_inference(bb, relax.op.nn.conv2d(x4, w0), relax.TensorStructInfo(dtype="", ndim=4)) + + +def test_conv2d_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + n = 
tir.Var("n", "int64") + c = tir.Var("c", "int64") + c16 = tir.Var("c16", "int64") + ih = tir.Var("ih", "int64") + iw = tir.Var("iw", "int64") + ki = tir.Var("ki", "int64") + ko = tir.Var("ko", "int64") + kh = tir.Var("kh", "int64") + kw = tir.Var("kw", "int64") + x0 = relax.Var("x", R.Tensor((n, c, ih, iw), "float32")) + x1 = relax.Var("x", R.Tensor((n, c, ih, iw, c16), "float32")) + w0 = relax.Var("w", R.Tensor((ko, ki, kh, kw), "float32")) + w1 = relax.Var("w", R.Tensor((ko, c, kh, kw), "float32")) + w2 = relax.Var("w", R.Tensor((ko, c, kh, kw, c16), "float32")) + + _check_inference( + bb, + relax.op.nn.conv2d(x0, w0), + relax.TensorStructInfo((n, ko, ih + 1 - kh, iw + 1 - kw), "float32"), + ) + _check_inference( + bb, + relax.op.nn.conv2d(x0, w1), + relax.TensorStructInfo((n, ko, ih + 1 - kh, iw + 1 - kw), "float32"), + ) + _check_inference( + bb, + relax.op.nn.conv2d( + x1, w2, data_layout="NCHW16c", kernel_layout="OIHW16i", out_layout="NCHW" + ), + relax.TensorStructInfo((n, ko, ih + 1 - kh, iw + 1 - kw), "float32"), + ) + _check_inference( + bb, + relax.op.nn.conv2d(x0, w0, strides=(2, 2), padding=(1, 1), dilation=(2, 2)), + relax.TensorStructInfo( + (n, ko, tvm.tir.floordiv(ih + 3, 2) + 1 - kh, tvm.tir.floordiv(iw + 3, 2) + 1 - kw), + "float32", + ), + ) + + +def test_conv2d_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo(ndim=4)) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=5)) + s2 = relax.Var("s", relax.ShapeStructInfo(ndim=4)) + s3 = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x2 = relax.Var("x", relax.TensorStructInfo(s3, "float32")) + w = relax.Var("w", relax.TensorStructInfo(s2, "float32")) + + _check_inference(bb, relax.op.nn.conv2d(x0, w), relax.TensorStructInfo(dtype="float32", ndim=4)) + _check_inference( + bb, + relax.op.nn.conv2d(x1, w, data_layout="NCHW16c"), + relax.TensorStructInfo(dtype="float32", ndim=5), + ) + _check_inference( + bb, + relax.op.nn.conv2d(x0, w, out_layout="NCHW16c"), + relax.TensorStructInfo(dtype="float32", ndim=5), + ) + _check_inference( + bb, + relax.op.nn.conv2d(x2, w), + relax.TensorStructInfo(dtype="float32", ndim=4), + ) + + +def test_conv2d_infer_struct_info_groups(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 128, 28, 28), "float32")) + x1 = relax.Var("x", R.Tensor((2, 8, 28, 28, 16), "float32")) + w0 = relax.Var("w", R.Tensor((48, 16, 3, 3), "float32")) + w1 = relax.Var("w", R.Tensor((48, 2, 3, 3, 8), "float32")) + + _check_inference( + bb, relax.op.nn.conv2d(x0, w0, groups=8), relax.TensorStructInfo((2, 48, 26, 26), "float32") + ) + _check_inference( + bb, + relax.op.nn.conv2d(x0, w1, kernel_layout="OIHW8i", groups=8), + relax.TensorStructInfo((2, 48, 26, 26), "float32"), + ) + _check_inference( + bb, + relax.op.nn.conv2d(x1, w0, data_layout="NCHW16c", groups=8), + relax.TensorStructInfo((2, 3, 26, 26, 16), "float32"), + ) + + +def test_conv2d_infer_struct_info_symbolic_groups(): + bb = relax.BlockBuilder() + n = tir.Var("n", "int64") + ic = tir.Var("c", "int64") + oc = tir.Var("oc", "int64") + x = relax.Var("x", R.Tensor((n, ic * 4, 28, 28), "float32")) + w0 = relax.Var("w", R.Tensor((oc * 4, ic, 3, 3), "float32")) + w1 = relax.Var("w", R.Tensor((oc, ic, 3, 3), "float32")) + + _check_inference( + bb, + relax.op.nn.conv2d(x, w0, groups=4), + relax.TensorStructInfo((n, oc * 4, 26, 26), "float32"), + ) + _check_inference( + bb, 
relax.op.nn.conv2d(x, w1, groups=4), relax.TensorStructInfo((n, oc, 26, 26), "float32") + ) + + +def test_conv2d_infer_struct_info_input_channel_group_incompatible(): + bb = relax.BlockBuilder() + n = tir.Var("n", "int64") + ic = tir.Var("c", "int64") + oc = tir.Var("oc", "int64") + x0 = relax.Var("x", R.Tensor((2, 128, 28, 28), "float32")) + w0 = relax.Var("w", R.Tensor((48, 20, 3, 3), "float32")) + x1 = relax.Var("x", R.Tensor((n, ic * 6, 28, 28), "float32")) + w1 = relax.Var("w", R.Tensor((oc, ic - 1, 3, 3), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv2d(x0, w0, groups=6)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv2d(x1, w1, groups=6)) + + +def test_conv2d_infer_struct_info_output_channel_group_incompatible(): + bb = relax.BlockBuilder() + n = tir.Var("n", "int64") + ic = tir.Var("c", "int64") + oc = tir.Var("oc", "int64") + x0 = relax.Var("x", R.Tensor((2, 120, 28, 28), "float32")) + w0 = relax.Var("w", R.Tensor((128, 20, 3, 3), "float32")) + x1 = relax.Var("x", R.Tensor((n, ic * 6, 28, 28), "float32")) + w1 = relax.Var("w", R.Tensor((oc * 6 + 4, ic * 6, 3, 3), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv2d(x0, w0, groups=6)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv2d(x1, w1, groups=6)) + + +def test_conv2d_non_positive_group(): + x = relax.Var("x", R.Tensor((2, 128, 28, 28), "float32")) + w = relax.Var("w", R.Tensor((48, 16, 3, 3), "float32")) + + with pytest.raises(TVMError): + relax.op.nn.conv2d(x, w, groups=0) + with pytest.raises(TVMError): + relax.op.nn.conv2d(x, w, groups=-2) + + +def test_conv2d_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 28, 28), "float16")) + w0 = relax.Var("w", R.Tensor((4, 3, 3, 3), "float16")) + x1 = relax.Var("x", R.Tensor((2, 3, 28, 28), "float64")) + w1 = relax.Var("w", R.Tensor((4, 3, 3, 3), "float64")) + x2 = relax.Var("x", R.Tensor((2, 3, 28, 28), "int8")) + w2 = relax.Var("w", R.Tensor((4, 3, 3, 3), "int8")) + x3 = relax.Var("x", R.Tensor((2, 3, 28, 28), "int32")) + w3 = relax.Var("w", R.Tensor((4, 3, 3, 3), "int32")) + + _check_inference( + bb, relax.op.nn.conv2d(x0, w0), relax.TensorStructInfo((2, 4, 26, 26), "float16") + ) + _check_inference( + bb, relax.op.nn.conv2d(x1, w1), relax.TensorStructInfo((2, 4, 26, 26), "float64") + ) + _check_inference(bb, relax.op.nn.conv2d(x2, w2), relax.TensorStructInfo((2, 4, 26, 26), "int8")) + _check_inference( + bb, relax.op.nn.conv2d(x3, w3), relax.TensorStructInfo((2, 4, 26, 26), "int32") + ) + + +def test_conv2d_infer_struct_info_mixed_precision(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 28, 28), "float16")) + w0 = relax.Var("w", R.Tensor((4, 3, 3, 3), "float16")) + x1 = relax.Var("x", R.Tensor((2, 3, 28, 28), "int8")) + w1 = relax.Var("w", R.Tensor((4, 3, 3, 3), "int8")) + x2 = relax.Var("x", R.Tensor((2, 3, 28, 28))) + w2 = relax.Var("w", R.Tensor((4, 3, 3, 3))) + + _check_inference( + bb, + relax.op.nn.conv2d(x0, w0, out_dtype="float32"), + relax.TensorStructInfo((2, 4, 26, 26), "float32"), + ) + _check_inference( + bb, + relax.op.nn.conv2d(x1, w1, out_dtype="int32"), + relax.TensorStructInfo((2, 4, 26, 26), "int32"), + ) + _check_inference( + bb, + relax.op.nn.conv2d(x2, w2, out_dtype="float32"), + relax.TensorStructInfo((2, 4, 26, 26), "float32"), + ) + + +def test_conv2d_unequal_input_channel(): + bb = relax.BlockBuilder() + ic = tir.Var("ic", "int64") + x0 = relax.Var("x", R.Tensor([2, 3, 28, 28], "float32")) 
+ w0 = relax.Var("w", R.Tensor([3, 4, 3, 3], "float32")) + x1 = relax.Var("x", R.Tensor([2, ic, 28, 28], "float32")) + w1 = relax.Var("w", R.Tensor([4, ic + 2, 3, 3], "float32")) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv2d(x0, w0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv2d(x1, w1)) + + +def test_conv2d_stride_padding_dilation_int64(): + x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) + w = relax.Var("w", R.Tensor((4, 3, 3, 3), "float32")) + conv2d = relax.op.nn.conv2d(x, w, strides=(1, 1), padding=(1, 1), dilation=(1, 1)) + + assert conv2d.attrs.strides[0].dtype == "int64" + assert conv2d.attrs.strides[1].dtype == "int64" + assert conv2d.attrs.padding[0].dtype == "int64" + assert conv2d.attrs.padding[1].dtype == "int64" + assert conv2d.attrs.padding[2].dtype == "int64" + assert conv2d.attrs.padding[3].dtype == "int64" + assert conv2d.attrs.dilation[0].dtype == "int64" + assert conv2d.attrs.dilation[1].dtype == "int64" + + +def test_conv2d_wrong_strides_padding_dilation_length(): + x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) + w = relax.Var("w", R.Tensor((4, 3, 3, 3), "float32")) + with pytest.raises(TVMError): + relax.op.nn.conv2d(x, w, strides=(1, 2, 3)) + with pytest.raises(TVMError): + relax.op.nn.conv2d(x, w, padding=(1, 2, 3)) + with pytest.raises(TVMError): + relax.op.nn.conv2d(x, w, dilation=(1, 2, 3)) + + +def test_conv2d_infer_struct_info_wrong_layout_string(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) + w = relax.Var("w", R.Tensor((4, 3, 3, 3), "float32")) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv2d(x, w, data_layout="OIHW")) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv2d(x, w, kernel_layout="NHWC")) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv2d(x, w, out_layout="OHWI")) + + +def test_conv2d_dtype_mismatch(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) + w = relax.Var("w", R.Tensor((4, 3, 3, 3), "int8")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv2d(x, w)) + + +def test_conv2d_wrong_input_ndim(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) + x1 = relax.Var("x", R.Tensor((2, 3, 28, 28, 3), "float32")) + x2 = relax.Var("x", R.Tensor("float32", ndim=3)) + w0 = relax.Var("w", R.Tensor((4, 3, 3, 3), "float32")) + w1 = relax.Var("w", R.Tensor((4, 3, 6, 3, 3), "float32")) + w2 = relax.Var("w", R.Tensor("float32", ndim=6)) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv2d(x0, w1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv2d(x0, w1, data_layout="NCHW16c")) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv2d(x0, w2)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv2d(x1, w0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv2d(x2, w0)) + + +def test_conv2d_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) + x1 = relax.Var("x", relax.ShapeStructInfo((2, 3, 28, 28))) + w0 = relax.Var("w", R.Tensor((4, 3, 3, 3), "float32")) + w1 = relax.Var("w", relax.FuncStructInfo([], R.Tensor((4, 3, 3, 3), "float32"))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv2d(x0, w1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv2d(x1, w0)) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_op_nn_pooling.py 
b/tests/python/relax/test_op_nn_pooling.py new file mode 100644 index 0000000000..0eec5de21c --- /dev/null +++ b/tests/python/relax/test_op_nn_pooling.py @@ -0,0 +1,429 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest +import tvm +import tvm.testing +from tvm import relax, tir +from tvm import TVMError +from tvm.ir import Op +from tvm.script import relax as R + + +def test_op_correctness(): + x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) + assert relax.op.nn.max_pool2d(x).op == Op.get("relax.nn.max_pool2d") + assert relax.op.nn.adaptive_avg_pool2d(x).op == Op.get("relax.nn.adaptive_avg_pool2d") + + +def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo): + ret = bb.normalize(call) + tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo) + + +def test_max_pool2d_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32")) + x1 = relax.Var("x", R.Tensor((2, 32, 32, 3), "float32")) + x2 = relax.Var("x", R.Tensor("float32", ndim=4)) + x3 = relax.Var("x", R.Tensor("float32")) + x4 = relax.Var("x", R.Tensor(ndim=4)) + x5 = relax.Var("x", R.Tensor()) + x6 = relax.Var("x", R.Tensor((2, 4, 32, 32, 16), "float32")) + + _check_inference( + bb, relax.op.nn.max_pool2d(x0), relax.TensorStructInfo((2, 3, 32, 32), "float32") + ) + _check_inference( + bb, + relax.op.nn.max_pool2d(x0, pool_size=3), + relax.TensorStructInfo((2, 3, 30, 30), "float32"), + ) + _check_inference( + bb, + relax.op.nn.max_pool2d(x0, pool_size=(5, 3)), + relax.TensorStructInfo((2, 3, 28, 30), "float32"), + ) + _check_inference( + bb, relax.op.nn.max_pool2d(x0, padding=1), relax.TensorStructInfo((2, 3, 34, 34), "float32") + ) + _check_inference( + bb, + relax.op.nn.max_pool2d(x0, padding=[1, 2]), + relax.TensorStructInfo((2, 3, 34, 36), "float32"), + ) + _check_inference( + bb, + relax.op.nn.max_pool2d(x0, strides=2), + relax.TensorStructInfo((2, 3, 16, 16), "float32"), + ) + _check_inference( + bb, + relax.op.nn.max_pool2d(x0, dilation=2), + relax.TensorStructInfo((2, 3, 32, 32), "float32"), + ) + _check_inference( + bb, + relax.op.nn.max_pool2d(x1, layout="NHWC"), + relax.TensorStructInfo((2, 32, 32, 3), "float32"), + ) + _check_inference( + bb, + relax.op.nn.max_pool2d(x0, out_layout="NHWC"), + relax.TensorStructInfo((2, 32, 32, 3), "float32"), + ) + _check_inference( + bb, + relax.op.nn.max_pool2d(x6, layout="NCHW16c", out_layout="NHWC16c"), + relax.TensorStructInfo((2, 32, 32, 4, 16), "float32"), + ) + _check_inference( + bb, relax.op.nn.max_pool2d(x2), relax.TensorStructInfo(dtype="float32", ndim=4) + ) + _check_inference( + bb, relax.op.nn.max_pool2d(x3), relax.TensorStructInfo(dtype="float32", ndim=4) + ) + _check_inference(bb, relax.op.nn.max_pool2d(x4), 
relax.TensorStructInfo(dtype="", ndim=4)) + _check_inference(bb, relax.op.nn.max_pool2d(x5), relax.TensorStructInfo(dtype="", ndim=4)) + + +def test_max_pool2d_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + n = tir.Var("n", "int64") + c = tir.Var("c", "int64") + c16 = tir.Var("c16", "int64") + ih = tir.Var("ih", "int64") + iw = tir.Var("iw", "int64") + x0 = relax.Var("x", R.Tensor((n, c, ih, iw), "float32")) + x1 = relax.Var("x", R.Tensor((n, c, ih, iw, c16), "float32")) + + _check_inference( + bb, + relax.op.nn.max_pool2d( + x0, pool_size=(3, 3), strides=(3, 3), padding=(2, 2), dilation=(2, 2) + ), + relax.TensorStructInfo( + ( + n, + c, + tvm.tir.floordiv(ih - 1, 3) + 1, + tvm.tir.floordiv(iw - 1, 3) + 1, + ), + "float32", + ), + ) + _check_inference( + bb, + relax.op.nn.max_pool2d(x1, layout="NCHW16c", out_layout="NHWC"), + relax.TensorStructInfo((n, ih, iw, c * 16), "float32"), + ) + + +def test_max_pool2d_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo(ndim=4)) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=5)) + s2 = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + + _check_inference( + bb, relax.op.nn.max_pool2d(x0), relax.TensorStructInfo(dtype="float32", ndim=4) + ) + _check_inference( + bb, + relax.op.nn.max_pool2d(x1, layout="NCHW16c"), + relax.TensorStructInfo(dtype="float32", ndim=5), + ) + _check_inference( + bb, + relax.op.nn.max_pool2d(x2), + relax.TensorStructInfo(dtype="float32", ndim=4), + ) + + +def test_max_pool2d_infer_struct_info_ceil_mode(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32")) + + _check_inference( + bb, + relax.op.nn.max_pool2d(x, pool_size=3, strides=2, ceil_mode=True), + relax.TensorStructInfo((2, 3, 16, 16), "float32"), + ) + _check_inference( + bb, + relax.op.nn.max_pool2d(x, pool_size=(5, 3), strides=2, ceil_mode=True), + relax.TensorStructInfo((2, 3, 15, 16), "float32"), + ) + + +def test_max_pool2d_infer_struct_info_ceil_mode_symbolic(): + bb = relax.BlockBuilder() + n = tir.Var("n", "int64") + c = tir.Var("c", "int64") + ih = tir.Var("ih", "int64") + iw = tir.Var("iw", "int64") + x = relax.Var("x", R.Tensor((n, c, ih, iw), "float32")) + + _check_inference( + bb, + relax.op.nn.max_pool2d( + x, pool_size=(3, 3), strides=(2, 2), padding=(1, 1), dilation=(2, 2), ceil_mode=True + ), + relax.TensorStructInfo((n, c, tvm.tir.floordiv(ih, 2), tvm.tir.floordiv(iw, 2)), "float32"), + ) + + +def test_max_pool2d_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float16")) + x1 = relax.Var("x", R.Tensor((2, 3, 32, 32), "int8")) + x2 = relax.Var("x", R.Tensor((2, 3, 32, 32), "int64")) + _check_inference( + bb, relax.op.nn.max_pool2d(x0), relax.TensorStructInfo((2, 3, 32, 32), "float16") + ) + _check_inference(bb, relax.op.nn.max_pool2d(x1), relax.TensorStructInfo((2, 3, 32, 32), "int8")) + _check_inference( + bb, relax.op.nn.max_pool2d(x2), relax.TensorStructInfo((2, 3, 32, 32), "int64") + ) + + +def test_max_pool2d_stride_padding_dilation_int64(): + x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) + max_pool2d = relax.op.nn.max_pool2d(x, (3, 3), strides=(1, 1), padding=(1, 1), dilation=(1, 1)) + + assert max_pool2d.attrs.strides[0].dtype == "int64" + assert max_pool2d.attrs.strides[1].dtype == "int64" + 
assert max_pool2d.attrs.padding[0].dtype == "int64" + assert max_pool2d.attrs.padding[1].dtype == "int64" + assert max_pool2d.attrs.padding[2].dtype == "int64" + assert max_pool2d.attrs.padding[3].dtype == "int64" + assert max_pool2d.attrs.dilation[0].dtype == "int64" + assert max_pool2d.attrs.dilation[1].dtype == "int64" + + +def test_max_pool2d_wrong_pool_size_strides_padding_dilation_length(): + x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) + with pytest.raises(TVMError): + relax.op.nn.max_pool2d(x, pool_size=(1, 2, 3)) + with pytest.raises(TVMError): + relax.op.nn.max_pool2d(x, strides=(1, 2, 3)) + with pytest.raises(TVMError): + relax.op.nn.max_pool2d(x, padding=(1, 2, 3)) + with pytest.raises(TVMError): + relax.op.nn.max_pool2d(x, dilation=(1, 2, 3)) + + +def test_max_pool2d_infer_struct_info_wrong_layout_string(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.max_pool2d(x, layout="OIHW")) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.max_pool2d(x, out_layout="OHWI")) + + +def test_max_pool2d_wrong_input_ndim(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 28, 28, 3), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=3)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.max_pool2d(x0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.max_pool2d(x1)) + + +def test_max_pool2d_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 28, 28))) + x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 28, 28), "float32"))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.max_pool2d(x0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.max_pool2d(x1)) + + +def test_adaptive_avg_pool2d_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32")) + x1 = relax.Var("x", R.Tensor((2, 32, 32, 3), "float32")) + x2 = relax.Var("x", R.Tensor("float32", ndim=4)) + x3 = relax.Var("x", R.Tensor("float32")) + x4 = relax.Var("x", R.Tensor(ndim=4)) + x5 = relax.Var("x", R.Tensor()) + x6 = relax.Var("x", R.Tensor((2, 4, 32, 32, 16), "float32")) + + _check_inference( + bb, relax.op.nn.adaptive_avg_pool2d(x0), relax.TensorStructInfo((2, 3, 32, 32), "float32") + ) + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool2d(x0, output_size=30), + relax.TensorStructInfo((2, 3, 30, 30), "float32"), + ) + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool2d(x0, output_size=(28, 30)), + relax.TensorStructInfo((2, 3, 28, 30), "float32"), + ) + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool2d(x1, layout="NHWC"), + relax.TensorStructInfo((2, 32, 32, 3), "float32"), + ) + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool2d(x0, out_layout="NHWC"), + relax.TensorStructInfo((2, 32, 32, 3), "float32"), + ) + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool2d(x6, layout="NCHW16c", out_layout="NHWC16c"), + relax.TensorStructInfo((2, 32, 32, 4, 16), "float32"), + ) + _check_inference( + bb, relax.op.nn.adaptive_avg_pool2d(x2), relax.TensorStructInfo(dtype="float32", ndim=4) + ) + _check_inference( + bb, relax.op.nn.adaptive_avg_pool2d(x3), relax.TensorStructInfo(dtype="float32", ndim=4) + ) + _check_inference( + bb, relax.op.nn.adaptive_avg_pool2d(x4), relax.TensorStructInfo(dtype="", ndim=4) + ) + _check_inference( + bb, relax.op.nn.adaptive_avg_pool2d(x5), relax.TensorStructInfo(dtype="", ndim=4) 
+ ) + + +def test_adaptive_avg_pool2d_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + n = tir.Var("n", "int64") + c = tir.Var("c", "int64") + c16 = tir.Var("c16", "int64") + ih = tir.Var("ih", "int64") + iw = tir.Var("iw", "int64") + x0 = relax.Var("x", R.Tensor((n, c, ih, iw), "float32")) + x1 = relax.Var("x", R.Tensor((n, c, ih, iw, c16), "float32")) + + _check_inference( + bb, relax.op.nn.adaptive_avg_pool2d(x0), relax.TensorStructInfo((n, c, ih, iw), "float32") + ) + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool2d(x0, output_size=256), + relax.TensorStructInfo((n, c, 256, 256), "float32"), + ) + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool2d(x0, output_size=(256, 128)), + relax.TensorStructInfo((n, c, 256, 128), "float32"), + ) + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool2d(x1, layout="NCHW16c", out_layout="NHWC"), + relax.TensorStructInfo((n, ih, iw, c * 16), "float32"), + ) + + +def test_adaptive_avg_pool2d_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo(ndim=4)) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=5)) + s2 = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + + _check_inference(bb, relax.op.nn.adaptive_avg_pool2d(x0), relax.TensorStructInfo(s0, "float32")) + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool2d(x0, output_size=32), + relax.TensorStructInfo(dtype="float32", ndim=4), + ) + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool2d(x1, layout="NCHW16c"), + relax.TensorStructInfo(s1, "float32"), + ) + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool2d(x0, out_layout="NCHW16c"), + relax.TensorStructInfo(dtype="float32", ndim=5), + ) + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool2d(x2, out_layout="NCHW16c"), + relax.TensorStructInfo(dtype="float32", ndim=5), + ) + + +def test_adaptive_avg_pool2d_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float16")) + x1 = relax.Var("x", R.Tensor((2, 3, 32, 32), "int8")) + x2 = relax.Var("x", R.Tensor((2, 3, 32, 32), "int64")) + _check_inference( + bb, relax.op.nn.adaptive_avg_pool2d(x0), relax.TensorStructInfo((2, 3, 32, 32), "float16") + ) + _check_inference( + bb, relax.op.nn.adaptive_avg_pool2d(x1), relax.TensorStructInfo((2, 3, 32, 32), "int8") + ) + _check_inference( + bb, relax.op.nn.adaptive_avg_pool2d(x2), relax.TensorStructInfo((2, 3, 32, 32), "int64") + ) + + +def test_adaptive_avg_pool2d_wrong_output_size_ndim(): + x = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32")) + with pytest.raises(TVMError): + relax.op.nn.adaptive_avg_pool2d(x, (32, 32, 32)) + + +def test_adaptive_avg_pool2d_infer_struct_info_wrong_layout_string(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.adaptive_avg_pool2d(x, layout="OIHW")) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.adaptive_avg_pool2d(x, out_layout="OHWI")) + + +def test_adaptive_avg_pool2d_wrong_input_ndim(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 28, 28, 3), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=3)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.adaptive_avg_pool2d(x0)) + with pytest.raises(TVMError): + 
bb.normalize(relax.op.nn.adaptive_avg_pool2d(x1)) + + +def test_adaptive_avg_pool2d_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 28, 28))) + x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 28, 28), "float32"))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.adaptive_avg_pool2d(x0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.adaptive_avg_pool2d(x1)) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_tvmscript_parser_op_nn.py b/tests/python/relax/test_tvmscript_parser_op_nn.py new file mode 100644 index 0000000000..4e52bccb86 --- /dev/null +++ b/tests/python/relax/test_tvmscript_parser_op_nn.py @@ -0,0 +1,193 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Optional, Union + +import tvm +import tvm.script +import tvm.testing +from tvm import IRModule, relax +from tvm.script import relax as R + + +def _check( + parsed: Union[relax.Function, IRModule], + expect: Optional[Union[relax.Function, IRModule]], +): + test = parsed.script(show_meta=True) + roundtrip_mod = tvm.script.from_source(test) + tvm.ir.assert_structural_equal(parsed, roundtrip_mod) + if expect: + tvm.ir.assert_structural_equal(parsed, expect) + + +def test_conv2d(): + @R.function + def foo( + x: R.Tensor((2, 3, 228, 228), "float32"), w: R.Tensor((16, 3, 5, 5), "float32") + ) -> R.Tensor((2, 16, 224, 224), "float16"): + gv: R.Tensor((2, 16, 224, 224), "float16") = R.nn.conv2d(x, w, out_dtype="float16") + return gv + + x = relax.Var("x", R.Tensor([2, 3, 228, 228], "float32")) + w = relax.Var("w", R.Tensor([16, 3, 5, 5], "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x, w]): + gv = bb.emit(relax.op.nn.conv2d(x, w, out_dtype="float16")) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_max_pool2d(): + @R.function + def foo( + x: R.Tensor((1, 1, 32, 32), dtype="float32") + ) -> R.Tensor((1, 1, 30, 30), dtype="float32"): + gv: R.Tensor((1, 1, 30, 30), dtype="float32") = R.nn.max_pool2d(x, pool_size=(3,)) + return gv + + x = relax.Var("x", R.Tensor([1, 1, 32, 32], "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(relax.op.nn.max_pool2d(x, pool_size=(3,))) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_adaptive_avg_pool2d(): + @R.function + def foo(x: R.Tensor((2, 64, 8, 9), "float32")) -> R.Tensor((2, 64, 7, 7), "float32"): + gv: R.Tensor((2, 64, 7, 7), "float32") = R.nn.adaptive_avg_pool2d(x, output_size=(7, 7)) + return gv + + x = relax.Var("x", R.Tensor((2, 64, 8, 9), dtype="float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(relax.op.nn.adaptive_avg_pool2d(x, output_size=(7, 7))) + 
bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_gelu(): + @R.function + def foo(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv: R.Tensor((2, 3), "float32") = R.nn.gelu(x) + return gv + + x = relax.Var("x", R.Tensor((2, 3), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(relax.op.nn.gelu(x)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_softmax(): + @R.function + def foo(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv: R.Tensor((2, 3), "float32") = R.nn.softmax(x) + return gv + + x = relax.Var("x", R.Tensor((2, 3), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(relax.op.nn.softmax(x)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_batch_norm(): + @R.function + def foo( + x: R.Tensor((2, 4, 3, 3), dtype="float32"), + gamma: R.Tensor((4,), dtype="float32"), + beta: R.Tensor((4,), dtype="float32"), + moving_mean: R.Tensor((4,), dtype="float32"), + moving_var: R.Tensor((4,), dtype="float32"), + ) -> R.Tuple( + R.Tensor((2, 4, 3, 3), dtype="float32"), + R.Tensor((4,), dtype="float32"), + R.Tensor((4,), dtype="float32"), + ): + gv: R.Tuple( + R.Tensor((2, 4, 3, 3), dtype="float32"), + R.Tensor((4,), dtype="float32"), + R.Tensor((4,), dtype="float32"), + ) = R.nn.batch_norm(x, gamma, beta, moving_mean, moving_var, axis=1) + return gv + + x = relax.Var("x", R.Tensor((2, 4, 3, 3), "float32")) + gamma = relax.Var("gamma", R.Tensor((4,), "float32")) + beta = relax.Var("beta", R.Tensor((4,), "float32")) + moving_mean = relax.Var("moving_mean", R.Tensor((4,), "float32")) + moving_var = relax.Var("moving_var", R.Tensor((4,), "float32")) + + bb = relax.BlockBuilder() + with bb.function("foo", [x, gamma, beta, moving_mean, moving_var]): + gv = bb.emit(relax.op.nn.batch_norm(x, gamma, beta, moving_mean, moving_var, axis=1)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_layer_norm(): + @R.function + def foo( + x: R.Tensor((2, 3, 4, 5), "float32"), + gamma: R.Tensor((4, 5), "float32"), + beta: R.Tensor((4, 5), "float32"), + ) -> R.Tensor((2, 3, 4, 5), "float32"): + gv: R.Tensor((2, 3, 4, 5), "float32") = R.nn.layer_norm(x, gamma, beta, axes=[-2, -1]) + return gv + + x = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) + gamma = relax.Var("gamma", R.Tensor((4, 5), "float32")) + beta = relax.Var("beta", R.Tensor((4, 5), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x, gamma, beta]): + gv = bb.emit(relax.op.nn.layer_norm(x, gamma, beta, axes=[-2, -1])) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_dropout(): + @R.function + def foo( + x: R.Tensor((2, 3), "float32") + ) -> R.Tuple(R.Tensor((2, 3), "float32"), R.Tensor((2, 3), "float32")): + gv: R.Tuple(R.Tensor((2, 3), "float32"), R.Tensor((2, 3), "float32")) = R.nn.dropout( + x, rate=0.5 + ) + return gv + + x = relax.Var("x", R.Tensor((2, 3), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(relax.op.nn.dropout(x)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +if __name__ == "__main__": + tvm.testing.main()
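
Note (illustrative sketch, not part of the diff above): the expected spatial sizes asserted in the convolution and pooling tests follow the usual sliding-window output-size relation. The snippet below recomputes two of them under the settings those tests use (unit stride, no padding, unit dilation); the helper name is hypothetical.

# Illustrative only: recompute expected spatial sizes from the tests above.
def _out_size(in_size, kernel, stride=1, pad=0, dilation=1):
    # Standard sliding-window output-size relation.
    return (in_size + 2 * pad - dilation * (kernel - 1) - 1) // stride + 1

assert _out_size(228, 5) == 224  # test_conv2d: 228x228 input, 5x5 kernel -> 224x224
assert _out_size(32, 3) == 30    # test_max_pool2d: 32x32 input, 3x3 pool -> 30x30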