This is an automated email from the ASF dual-hosted git repository. comaniac pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push: new 73a1a9a [TOPI] deformable_conv2d in NHWC (#6999) 73a1a9a is described below commit 73a1a9a69f62281d61148280a023e58e6dcd08f0 Author: Wuwei Lin <vincentl...@gmail.com> AuthorDate: Tue Dec 1 14:59:09 2020 -0500 [TOPI] deformable_conv2d in NHWC (#6999) * [TOPI] deformable_conv2d in NHWC * Update python/tvm/topi/generic/nn.py Co-authored-by: Cody Yu <comaniac0...@gmail.com> * Update python/tvm/topi/testing/deformable_conv2d_python.py Co-authored-by: Cody Yu <comaniac0...@gmail.com> * style * fix * style Co-authored-by: Cody Yu <comaniac0...@gmail.com> --- include/tvm/topi/detail/tensor_utils.h | 37 +++++++ python/tvm/topi/generic/nn.py | 18 ++++ python/tvm/topi/nn/deformable_conv2d.py | 110 ++++++++++++++++++++- python/tvm/topi/testing/__init__.py | 2 +- ..._nchw_python.py => deformable_conv2d_python.py} | 49 +++++++++ src/topi/schedule.cc | 4 + .../topi/python/test_topi_deformable_conv2d.py | 95 +++++++++++++++++- 7 files changed, 311 insertions(+), 4 deletions(-) diff --git a/include/tvm/topi/detail/tensor_utils.h b/include/tvm/topi/detail/tensor_utils.h index 7004c35..65a760b 100644 --- a/include/tvm/topi/detail/tensor_utils.h +++ b/include/tvm/topi/detail/tensor_utils.h @@ -89,6 +89,43 @@ inline PrimExpr bilinear_sample_nchw(const Tensor& input, const Array<PrimExpr>& D * x_lerp * y_lerp; } +/*! + * \brief Sample a point in a tensor using bilinear interpolation. + * + * \param input The input tensor. + * \param indices The index of the target point, which can be fractional + * \param max_y The maximum of y dimension + * \param max_x The maximum of x dimension + * + * \return The interpolated value in the given index. 
+ */ +inline PrimExpr bilinear_sample_nhwc(const Tensor& input, const Array<PrimExpr>& indices, + const PrimExpr max_y, const PrimExpr max_x) { + auto in_y = indices[1]; + auto yf = tvm::floor(in_y); + auto yc = tvm::cast(DataType::Int(32), tvm::ceil(in_y)); + + auto y0 = tvm::cast(DataType::Int(32), tvm::floor(in_y)); + auto y1 = tvm::if_then_else((yc > max_y), max_y, yc); + auto y_lerp = in_y - yf; + + auto in_x = indices[2]; + auto xf = tvm::floor(in_x); + auto xc = tvm::cast(DataType::Int(32), tvm::ceil(in_x)); + + auto x0 = tvm::cast(DataType::Int(32), tvm::floor(in_x)); + auto x1 = tvm::if_then_else((xc > max_x), max_x, xc); + auto x_lerp = in_x - xf; + + auto A = input(indices[0], y0, x0, indices[3]); + auto B = input(indices[0], y0, x1, indices[3]); + auto C = input(indices[0], y1, x0, indices[3]); + auto D = input(indices[0], y1, x1, indices[3]); + + return A * (1 - x_lerp) * (1 - y_lerp) + B * x_lerp * (1 - y_lerp) + C * (1 - x_lerp) * y_lerp + + D * x_lerp * y_lerp; +} + } // namespace detail } // namespace topi } // namespace tvm diff --git a/python/tvm/topi/generic/nn.py b/python/tvm/topi/generic/nn.py index 4bc3f97..60ccd0d 100644 --- a/python/tvm/topi/generic/nn.py +++ b/python/tvm/topi/generic/nn.py @@ -462,6 +462,24 @@ def schedule_deformable_conv2d_nchw(outs): return _default_schedule(outs, False) +def schedule_deformable_conv2d_nhwc(outs): + """Schedule for deformable_conv2d_nhwc. + We only use the default schedule here and rely on auto_scheduler. + + Parameters + ---------- + outs: Array of Tensor + The computation graph description of deformable_conv2d_nhwc + in the format of an array of tensors. + + Returns + ------- + sch: Schedule + The computation schedule for the op. 
+ """ + return _default_schedule(outs, False) + + def schedule_bitserial_conv2d_nchw(outs): """Schedule for bitserial_conv2d_nchw diff --git a/python/tvm/topi/nn/deformable_conv2d.py b/python/tvm/topi/nn/deformable_conv2d.py index a8c2745..780530c 100644 --- a/python/tvm/topi/nn/deformable_conv2d.py +++ b/python/tvm/topi/nn/deformable_conv2d.py @@ -21,7 +21,7 @@ from tvm import te from .utils import get_pad_tuple from ..utils import get_const_tuple -from ..cpp.utils import bilinear_sample_nchw +from ..cpp.utils import bilinear_sample_nchw, bilinear_sample_nhwc def deformable_conv2d_nchw( @@ -130,3 +130,111 @@ def deformable_conv2d_nchw( ), tag="deformable_conv2d_nchw", ) + + +def deformable_conv2d_nhwc( + data, offset, kernel, strides, padding, dilation, deformable_groups, groups, out_dtype +): + """Deformable conv2D operator in NHWC layout. + + The deformable convolution operation is described in https://arxiv.org/abs/1703.06211 + + Parameters + ---------- + data : tvm.te.Tensor + 4-D with shape [batch, in_height, in_width, in_channel] + + offset : tvm.te.Tensor + 4-D with shape [batch, out_height, out_width, + deformable_groups * filter_height * filter_width * 2]. 
+ + kernel : tvm.te.Tensor + 4-D with shape [filter_height, filter_width, in_channel, num_filter] + + strides : int or a list/tuple of two ints + stride size, or [stride_height, stride_width] + + padding : int or a list/tuple of two ints + padding size, or [pad_height, pad_width] + + dilation : int or a list/tuple of two ints + dilation size, or [dilation_height, dilation_width] + + deformable_groups : int + number of deformable groups + + groups : int + number of groups + + Returns + ------- + output : tvm.te.Tensor + 4-D with shape [batch, out_height, out_width, out_channel] + """ + if out_dtype is None: + out_dtype = data.dtype + + if isinstance(strides, int): + stride_h = stride_w = strides + else: + stride_h, stride_w = strides + + if isinstance(dilation, int): + dilation_h = dilation_w = dilation + else: + dilation_h, dilation_w = dilation + + batch, in_height, in_width, in_channel = get_const_tuple(data.shape) + kernel_h, kernel_w, channel, out_channel = get_const_tuple(kernel.shape) + _, out_height, out_width, _ = get_const_tuple(offset.shape) + assert in_channel % deformable_groups == 0, "Input cahnnels must divide deformable group size" + assert groups == 1, "deformable_conv2d_nchw does not support groups > 1" + + ic_per_dgroup = channel // deformable_groups + + dilated_kernel_h = (kernel_h - 1) * dilation_h + 1 + dilated_kernel_w = (kernel_w - 1) * dilation_w + 1 + pad_top, pad_left, _, _ = get_pad_tuple(padding, (dilated_kernel_h, dilated_kernel_w)) + rc = te.reduce_axis((0, in_channel), name="rc") + ry = te.reduce_axis((0, kernel_h), name="ry") + rx = te.reduce_axis((0, kernel_w), name="rx") + + zero = tvm.tir.const(0.0, data.dtype) + + def _bilinear(n, h, w, c): + outside = tvm.tir.any(h < 0, w < 0, h >= in_height, w >= in_width) + val = bilinear_sample_nhwc(data, (n, h, w, c), in_height - 1, in_width - 1) + return tvm.tir.if_then_else(outside, zero, val) + + data_deform = te.compute( + (batch, kernel_h, kernel_w, in_channel, out_height, out_width), + 
lambda n, kh, kw, c, y, x: _bilinear( + n, + y * stride_h + - pad_top + + kh * dilation_h + + offset[ + n, y, x, c // ic_per_dgroup * (kernel_w * kernel_h * 2) + (kh * kernel_w + kw) * 2 + ], + x * stride_w + - pad_left + + kw * dilation_w + + offset[ + n, + y, + x, + c // ic_per_dgroup * (kernel_w * kernel_h * 2) + (kh * kernel_w + kw) * 2 + 1, + ], + c, + ), + tag="data_deform", + ) + return te.compute( + (batch, out_height, out_width, out_channel), + lambda n, y, x, f: te.sum( + data_deform[n, ry, rx, rc, y, x].astype(out_dtype) + * kernel[ry, rx, rc, f].astype(out_dtype), + axis=[ry, rx, rc], + ), + tag="deformable_conv2d_nhwc", + ) diff --git a/python/tvm/topi/testing/__init__.py b/python/tvm/topi/testing/__init__.py index 0654344..85f13a7 100644 --- a/python/tvm/topi/testing/__init__.py +++ b/python/tvm/topi/testing/__init__.py @@ -31,7 +31,7 @@ from .conv3d_transpose_ncdhw_python import conv3d_transpose_ncdhw_python from .conv2d_transpose_python import conv2d_transpose_nchw_python, conv2d_transpose_nhwc_python from .conv1d_transpose_ncw_python import conv1d_transpose_ncw_python from .correlation_nchw_python import correlation_nchw_python -from .deformable_conv2d_nchw_python import deformable_conv2d_nchw_python +from .deformable_conv2d_python import deformable_conv2d_nchw_python, deformable_conv2d_nhwc_python from .depthwise_conv2d_python import depthwise_conv2d_python_nchw, depthwise_conv2d_python_nhwc from .dilate_python import dilate_python from .softmax_python import softmax_python, log_softmax_python diff --git a/python/tvm/topi/testing/deformable_conv2d_nchw_python.py b/python/tvm/topi/testing/deformable_conv2d_python.py similarity index 74% rename from python/tvm/topi/testing/deformable_conv2d_nchw_python.py rename to python/tvm/topi/testing/deformable_conv2d_python.py index 6a7afb4..0930843 100644 --- a/python/tvm/topi/testing/deformable_conv2d_nchw_python.py +++ b/python/tvm/topi/testing/deformable_conv2d_python.py @@ -119,3 +119,52 @@ def 
deformable_conv2d_nchw_python( b_np[n, f, h, w] += np.tensordot(a_deform[n, c, h, w], w_np[f, c]) return b_np + + +def deformable_conv2d_nhwc_python( + a_np, offset_np, w_np, stride, padding, dilation, deformable_groups, groups +): + """Deformable convolution operator in NHWC layout. + + Parameters + ---------- + a_np : numpy.ndarray + 4-D with shape [batch, in_height, in_width, in_channel] + + offset_np : numpy.ndarray + 4-D with shape [batch, out_height, out_width, + deformable_groups * filter_height * filter_width * 2] + + w_np : numpy.ndarray + 4-D with shape [filter_height, filter_width, in_channel, num_filter] + + stride : int or a list/tuple of two ints + Stride size, or [stride_height, stride_width] + + padding : int or str or a list/tuple of 2 or 4 ints + Padding size, or ['VALID', 'SAME'], or + [pad_height, pad_width] for 2 ints, or + [pad_top, pad_left, pad_bottom, pad_right] for 4 ints + + dilation : int or a list/tuple of two ints + Dilation size, or [dilate_height, dilate_width] + + deformable_groups : int + Number of deformable groups + + groups : int + Number of groups + + Returns + ------- + b_np : np.ndarray + 4-D with shape [batch, out_height, out_width, out_channel] + """ + a_np = np.transpose(a_np, [0, 3, 1, 2]) # NHWC -> NCHW + offset_np = np.transpose(offset_np, [0, 3, 1, 2]) # NHWC -> NCHW + w_np = np.transpose(w_np, [3, 2, 0, 1]) # HWIO -> OIHW + b_np = deformable_conv2d_nchw_python( + a_np, offset_np, w_np, stride, padding, dilation, deformable_groups, groups + ) + b_np = np.transpose(b_np, [0, 2, 3, 1]) # NCHW -> NHWC + return b_np diff --git a/src/topi/schedule.cc b/src/topi/schedule.cc index c315d40..f9400bf 100644 --- a/src/topi/schedule.cc +++ b/src/topi/schedule.cc @@ -190,6 +190,10 @@ TVM_REGISTER_GLOBAL("topi.utils.bilinear_sample_nchw").set_body([](TVMArgs args, *rv = detail::bilinear_sample_nchw(args[0], args[1], args[2], args[3]); }); +TVM_REGISTER_GLOBAL("topi.utils.bilinear_sample_nhwc").set_body([](TVMArgs args, TVMRetValue*
rv) { + *rv = detail::bilinear_sample_nhwc(args[0], args[1], args[2], args[3]); +}); + /*! \brief Builder function for instantiating schedules. */ using FTVMScheduleBuilder = std::function<tvm::te::Schedule( const tvm::Target& target, const tvm::Array<tvm::te::Tensor>& outs)>; diff --git a/tests/python/topi/python/test_topi_deformable_conv2d.py b/tests/python/topi/python/test_topi_deformable_conv2d.py index 34bfae7..cd6f33f 100644 --- a/tests/python/topi/python/test_topi_deformable_conv2d.py +++ b/tests/python/topi/python/test_topi_deformable_conv2d.py @@ -26,11 +26,15 @@ from tvm.topi.utils import get_const_tuple import tvm.testing -_deformable_conv2d_implement = { +_deformable_conv2d_nchw_implement = { "generic": (topi.nn.deformable_conv2d_nchw, topi.generic.schedule_deformable_conv2d_nchw), "cuda": (topi.cuda.deformable_conv2d_nchw, topi.cuda.schedule_deformable_conv2d_nchw), } +_deformable_conv2d_nhwc_implement = { + "generic": (topi.nn.deformable_conv2d_nhwc, topi.generic.schedule_deformable_conv2d_nhwc), +} + def verify_deformable_conv2d_nchw( batch, @@ -94,7 +98,7 @@ def verify_deformable_conv2d_nchw( print("Skip because %s is not enabled" % device) return print("Running on target: %s" % device) - fcompute, fschedule = tvm.topi.testing.dispatch(device, _deformable_conv2d_implement) + fcompute, fschedule = tvm.topi.testing.dispatch(device, _deformable_conv2d_nchw_implement) with tvm.target.Target(device): C = fcompute(A, Offset, W, stride, padding, dilation, deformable_groups, groups, dtype) s = fschedule([C]) @@ -112,6 +116,86 @@ def verify_deformable_conv2d_nchw( check_device(device) +def verify_deformable_conv2d_nhwc( + batch, + in_channel, + in_size, + num_filter, + kernel, + stride, + padding, + dilation=1, + deformable_groups=1, + groups=1, +): + print( + "Workload: (%d, %d, %d, %d, %d, %d, %d, %d, %d, %d)" + % ( + batch, + in_channel, + in_size, + num_filter, + kernel, + stride, + padding, + dilation, + deformable_groups, + groups, + ) + ) + + A = 
te.placeholder((batch, in_size, in_size, in_channel), name="A") + out_size = (in_size - (kernel - 1) * dilation - 1 + 2 * padding) // stride + 1 + Offset = te.placeholder( + (batch, out_size, out_size, deformable_groups * kernel * kernel * 2), name="offset" + ) + W = te.placeholder((kernel, kernel, in_channel, num_filter), name="W") + bias = te.placeholder((num_filter,), name="bias") + + a_shape = get_const_tuple(A.shape) + offset_shape = get_const_tuple(Offset.shape) + w_shape = get_const_tuple(W.shape) + bias_shape = get_const_tuple(bias.shape) + dtype = A.dtype + + @memoize("topi.tests.test_topi_deformable_conv2d_nchw.verify_deformable_conv2d_nhwc") + def get_ref_data(): + a_np = np.random.uniform(size=a_shape).astype(dtype) + offset_np = np.random.randn(*offset_shape).astype(dtype) + w_np = np.random.uniform(size=w_shape).astype(dtype) + b_np = np.random.uniform(size=bias_shape).astype(dtype) + c_np = tvm.topi.testing.deformable_conv2d_nhwc_python( + a_np, offset_np, w_np, stride, padding, dilation, deformable_groups, groups + ) + + return a_np, offset_np, w_np, c_np + + a_np, offset_np, w_np, c_np = get_ref_data() + + def check_device(device): + ctx = tvm.context(device, 0) + if not tvm.testing.device_enabled(device): + print("Skip because %s is not enabled" % device) + return + print("Running on target: %s" % device) + fcompute, fschedule = tvm.topi.testing.dispatch(device, _deformable_conv2d_nhwc_implement) + with tvm.target.Target(device): + C = fcompute(A, Offset, W, stride, padding, dilation, deformable_groups, groups, dtype) + s = fschedule([C]) + + a = tvm.nd.array(a_np, ctx) + offset = tvm.nd.array(offset_np, ctx) + w = tvm.nd.array(w_np, ctx) + c = tvm.nd.empty(c_np.shape, dtype=c_np.dtype, ctx=ctx) + + func = tvm.build(s, [A, Offset, W, C], device) + func(a, offset, w, c) + tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5) + + for device in ["llvm"]: + check_device(device) + + @tvm.testing.uses_gpu def test_deformable_conv2d_nchw(): 
verify_deformable_conv2d_nchw(1, 16, 7, 16, 1, 1, 0, deformable_groups=4) @@ -119,5 +203,12 @@ def test_deformable_conv2d_nchw(): verify_deformable_conv2d_nchw(1, 16, 7, 16, 3, 1, 2, dilation=2) +def test_deformable_conv2d_nhwc(): + verify_deformable_conv2d_nhwc(1, 16, 7, 16, 1, 1, 0, deformable_groups=4) + verify_deformable_conv2d_nhwc(1, 16, 7, 16, 3, 1, 1, dilation=2, deformable_groups=4) + verify_deformable_conv2d_nhwc(1, 16, 7, 16, 3, 1, 2, dilation=2) + + if __name__ == "__main__": test_deformable_conv2d_nchw() + test_deformable_conv2d_nhwc()