This is an automated email from the ASF dual-hosted git repository.

jroesch pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
     new 0421efb  [RELAY,TOPI] Add scatter_nd op (#6854)
0421efb is described below

commit 0421efba4c3a42c6cf8d692734c24fe8e08e3884
Author: Tristan Konolige <tristan.konol...@gmail.com>
AuthorDate: Tue Dec 1 11:20:09 2020 -0800

    [RELAY,TOPI] Add scatter_nd op (#6854)

    * [RELAY,TOPI] Add scatter_nd op

      Scatter_nd is the inverse of gather_nd and also happens to be its
      gradient. The implementation here is not optimized; there are no
      CPU- or GPU-specific implementations.

    * formatting

    * Fix tests

    * formatting

    * specify types on test

    * Fix grad test

    * scatter_nd cuda impl

    * cuda impl

    * x86 impl

    * formatting

    * fix shape rel

    * fix tests

    * formatting
---
 include/tvm/relay/attrs/transform.h           |   8 ++
 python/tvm/relay/backend/compile_engine.py    |   5 +-
 python/tvm/relay/op/_tensor_grad.py           |   7 ++
 python/tvm/relay/op/_transform.py             |   9 ++
 python/tvm/relay/op/strategy/cuda.py          |  13 +++
 python/tvm/relay/op/strategy/generic.py       |  22 +++++
 python/tvm/relay/op/strategy/x86.py           |  13 +++
 python/tvm/relay/op/transform.py              |  24 ++++++
 python/tvm/relay/testing/__init__.py          |   2 +
 python/tvm/te/operation.py                    |   6 +-
 python/tvm/topi/cuda/scatter.py               | 106 +++++++++++++++++++++++
 python/tvm/topi/scatter.py                    | 120 +++++++++++++++++++++++-
 python/tvm/topi/testing/__init__.py           |   1 +
 python/tvm/topi/testing/common.py             |  31 +++++++
 python/tvm/topi/x86/__init__.py               |   1 +
 python/tvm/topi/x86/scatter.py                | 109 +++++++++++++++++++++++
 src/relay/analysis/type_solver.cc             |   9 +-
 src/relay/op/tensor/transform.cc              |  68 +++++++++++++++
 tests/python/relay/test_any.py                |   5 +-
 tests/python/relay/test_op_grad_level3.py     |   9 ++
 tests/python/topi/python/test_topi_scatter.py |  67 ++++++++++++++
 21 files changed, 627 insertions(+), 8 deletions(-)

diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h
index a7830cf..3ed6b83 100644
--- a/include/tvm/relay/attrs/transform.h
+++ b/include/tvm/relay/attrs/transform.h
@@ -129,6 +129,14 @@ struct ScatterAddAttrs : public tvm::AttrsNode<ScatterAddAttrs> {
   }
 };
 
+struct ScatterNDAttrs : public tvm::AttrsNode<ScatterNDAttrs> {
+  Array<Integer> out_shape;
+
+  TVM_DECLARE_ATTRS(ScatterNDAttrs, "relay.attrs.ScatterNDAttrs") {
+    TVM_ATTR_FIELD(out_shape).describe("Output shape of the scatter.");
+  }
+};
+
 struct GatherAttrs : public tvm::AttrsNode<GatherAttrs> {
   Integer axis;
 
diff --git a/python/tvm/relay/backend/compile_engine.py b/python/tvm/relay/backend/compile_engine.py
index 32affe7..a39f72e 100644
--- a/python/tvm/relay/backend/compile_engine.py
+++ b/python/tvm/relay/backend/compile_engine.py
@@ -122,7 +122,10 @@ def get_valid_implementations(op, attrs, inputs, out_type, target):
         The list of all valid op implementations.
     """
     fstrategy = op.get_attr("FTVMStrategy")
-    assert fstrategy is not None, "%s doesn't have FTVMStrategy registered" % op.name
+    assert fstrategy is not None, (
+        "%s doesn't have an FTVMStrategy registered. You can register "
+        "one in python with `tvm.relay.op.register_strategy`." % op.name
+    )
     with target:
         strategy = fstrategy(attrs, inputs, out_type, target)
     analyzer = tvm.arith.Analyzer()
diff --git a/python/tvm/relay/op/_tensor_grad.py b/python/tvm/relay/op/_tensor_grad.py
index b070d9f..9c84411 100644
--- a/python/tvm/relay/op/_tensor_grad.py
+++ b/python/tvm/relay/op/_tensor_grad.py
@@ -62,6 +62,7 @@ from .transform import (
     squeeze,
     strided_set,
     arange,
+    scatter_nd,
 )
 
 
@@ -803,3 +804,9 @@ def arange_grad(orig, grad):
     grad_step = cast_like(_sum(grad_step), step)
 
     return [grad_start, grad_stop, grad_step]
+
+
+@register_gradient("gather_nd")
+def gather_nd_grad(orig, grad):
+    data, indices = orig.args
+    return [scatter_nd(grad, indices, data.checked_type.concrete_shape), zeros_like(indices)]
diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py
index 439d44b..e1cb9e9 100644
--- a/python/tvm/relay/op/_transform.py
+++ b/python/tvm/relay/op/_transform.py
@@ -115,6 +115,15 @@ def compute_scatter_add(attrs, inputs, output_type):
 _reg.register_strategy("scatter_add", strategy.scatter_add_strategy)
 
 
+# scatter_nd
+@_reg.register_compute("scatter_nd")
+def compute_scatter_nd(attrs, inputs, output_type):
+    """Compute definition of scatter_nd"""
+    return [topi.scatter_nd(inputs[0], inputs[1], attrs.out_shape)]
+
+
+_reg.register_strategy("scatter_nd", strategy.scatter_nd_strategy)
+
 #####################
 #  Shape functions  #
 #####################
diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py
index f37fc2a..bd96cad 100644
--- a/python/tvm/relay/op/strategy/cuda.py
+++ b/python/tvm/relay/op/strategy/cuda.py
@@ -776,6 +776,19 @@ def scatter_add_cuda(attrs, inputs, out_type, target):
     return strategy
 
 
+@scatter_nd_strategy.register(["cuda", "gpu"])
+def scatter_nd_cuda(attrs, inputs, out_type, target):
+    """scatter_nd cuda strategy"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_scatter_nd(topi.cuda.scatter_nd),
+        wrap_topi_schedule(topi.generic.schedule_extern),
+        name="scatter_nd.cuda",
+        plevel=10,
+    )
+    return strategy
+
+
 @argsort_strategy.register(["cuda", "gpu"])
 def argsort_strategy_cuda(attrs, inputs, out_type, target):
     """argsort cuda strategy"""
diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py
index e49135c..ac9d3b1 100644
--- a/python/tvm/relay/op/strategy/generic.py
+++ b/python/tvm/relay/op/strategy/generic.py
@@ -1063,6 +1063,28 @@ def scatter_add_strategy(attrs, outs, out_type, target):
     return strategy
 
 
+# scatter_nd
+@override_native_generic_func("scatter_nd_strategy")
+def scatter_nd_strategy(attrs, inputs, out_type, target):
+    """scatter_nd generic strategy"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_scatter_nd(topi.scatter_nd),
+        wrap_topi_schedule(topi.generic.schedule_extern),
+        name="scatter_nd.generic",
+    )
+    return strategy
+
+
+def wrap_compute_scatter_nd(topi_compute):
+    """Wrap scatter_nd topi compute"""
+
+    def _compute_scatter_nd(attrs, inputs, _):
+        return [topi_compute(inputs[0], inputs[1], attrs.out_shape)]
+
+    return _compute_scatter_nd
+
+
 # bitserial_conv2d
 def wrap_compute_bitserial_conv2d(topi_compute):
     """wrap bitserial_conv2d topi compute"""
diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py
index 3c5735b..3f129c4 100644
--- a/python/tvm/relay/op/strategy/x86.py
+++ b/python/tvm/relay/op/strategy/x86.py
@@ -446,3 +446,16 @@ def bitserial_dense_strategy_cpu(attrs, inputs, out_type, target):
         name="bitserial_dense.x86",
     )
     return strategy
+
+
+@scatter_nd_strategy.register("cpu")
+def scatter_nd_strategy_cpu(attrs, inputs, out_type, target):
+    """scatter_nd x86 strategy"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_scatter_nd(topi.x86.scatter_nd),
+        wrap_topi_schedule(topi.generic.schedule_extern),
+        name="scatter_nd.x86",
+        plevel=10,
+    )
+    return strategy
diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py
index 19488a0..7e7f9b2 100644
--- a/python/tvm/relay/op/transform.py
+++ b/python/tvm/relay/op/transform.py
@@ -308,6 +308,30 @@ def scatter_add(data, indices, updates, axis):
     return _make.scatter_add(data, indices, updates, axis)
 
 
+def scatter_nd(data, indices, out_shape):
+    """Scatter values from an array.
+
+    See :py:func:`tvm.topi.scatter_nd` for how data is scattered.
+
+    Parameters
+    ----------
+    data : relay.Expr
+        The input data to the operator.
+
+    indices : relay.Expr
+        The index locations to update.
+
+    out_shape : Sequence[int]
+        Output shape of the scatter.
+
+    Returns
+    -------
+    ret : relay.Expr
+        The computed result.
+    """
+    return _make.scatter_nd(data, indices, out_shape)
+
+
 def reshape_like(data, shape_like, lhs_begin=0, lhs_end=None, rhs_begin=0, rhs_end=None):
     """Reshapes the input tensor by the size of another tensor.
     For an input tensor with shape ``(d0, d1, ..., d(k-1))``, `reshape_like` operation reshapes
diff --git a/python/tvm/relay/testing/__init__.py b/python/tvm/relay/testing/__init__.py
index 9c87f27..93110e3 100644
--- a/python/tvm/relay/testing/__init__.py
+++ b/python/tvm/relay/testing/__init__.py
@@ -143,6 +143,8 @@ def check_grad(
                     break
         grads = tmp
 
+    assert len(grads) > 0, "You must test at least one gradient."
+
     # Get numeric gradients for each dimension of each param, using two-sided approximation.
     approx_grads = []
     for x in test_inputs:
diff --git a/python/tvm/te/operation.py b/python/tvm/te/operation.py
index 30d0df3..0f3457a 100644
--- a/python/tvm/te/operation.py
+++ b/python/tvm/te/operation.py
@@ -317,7 +317,11 @@ def extern(
     if isinstance(body, tvm.tir.PrimExpr):
         body = tvm.tir.Evaluate(body)
     if not isinstance(body, tvm.tir.Stmt):
-        raise ValueError("Function '{}' should return PrimExpr or Stmt".format(fcompute.__name__))
+        raise ValueError(
+            "Function '{}' should return PrimExpr or Stmt, but it returned '{}'".format(
+                fcompute.__name__, type(body)
+            )
+        )
 
     op = _ffi_api.ExternOp(name, tag, attrs, inputs, input_placeholders, output_placeholders, body)
     res = [op.output(i) for i in range(len(output_placeholders))]
diff --git a/python/tvm/topi/cuda/scatter.py b/python/tvm/topi/cuda/scatter.py
index 0a3e96f..5e03faf 100644
--- a/python/tvm/topi/cuda/scatter.py
+++ b/python/tvm/topi/cuda/scatter.py
@@ -18,6 +18,7 @@
 """Scatter operator """
 import tvm
 from tvm import te
+from ..scatter import _verify_scatter_nd_inputs
 
 
 def ceil_div(a, b):
@@ -522,3 +523,108 @@ def scatter_add(data, indices, updates, axis=0):
     )
 
     return out
+
+
+def scatter_nd(data, indices, shape):
+    """Scatter elements from an n-dimension array.
+
+    Given data with shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}), indices with shape
+    (M, Y_0, ..., Y_{K-1}), and output with shape (X_0, X_1, ..., X_{N-1}), scatter_nd computes
+
+    .. code-block::
+
+        output[indices[0, y_0, ..., y_{K-1}],
+               ...,
+               indices[M-1, y_0, ..., y_{K-1}],
+               x_M,
+               ...,
+               x_{N-1}
+              ] = data[y_0, ..., y_{K-1}, x_M, ..., x_{N-1}]
+
+    All other entries in the output are 0. Repeated indices are summed.
+
+    Parameters
+    ----------
+    data : tvm.te.Tensor
+        The source array.
+
+    indices : tvm.te.Tensor
+        The indices of the locations to write to.
+
+    shape : Sequence[int]
+        The output shape. This must be specified because it cannot be inferred.
+
+    Returns
+    -------
+    ret : tvm.te.Tensor
+    """
+    _verify_scatter_nd_inputs(data, indices, shape)
+
+    def gen_ir(data_ptr, indices_ptr, out_ptr):
+        ib = tvm.tir.ir_builder.create()
+
+        data = ib.buffer_ptr(data_ptr)
+        indices = ib.buffer_ptr(indices_ptr)
+        out = ib.buffer_ptr(out_ptr)
+
+        # We combine all the indices dimensions but the first one into a single
+        # dimension so we can iterate it in a single loop instead of an arbitrary
+        # number of loops. We do the same thing for all the data dimensions.
+        fused_indices_dimension = 1
+        for i in indices_ptr.shape[1:]:
+            fused_indices_dimension *= i
+
+        fused_data_dimension = 1
+        for i in data_ptr.shape[len(indices_ptr.shape) - 1 :]:
+            fused_data_dimension *= i
+
+        fused_shape = 1
+        for i in shape:
+            fused_shape *= i
+
+        # For now we avoid parallelizing over dimensions indexed by `indices` as
+        # there may be repeated indices and handling parallel accumulation can
+        # be hard. So we parallelize over X_M .. X_{N-1} instead. This will
+        # work well when these dimensions are large enough to saturate memory
+        # bandwidth, but performance will be bad when these dimensions are
+        # small.
+        bx = te.thread_axis("blockIdx.x")
+        tx = te.thread_axis("threadIdx.x")
+        max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
+        tdim = min(max_threads, fused_data_dimension)
+        ib.scope_attr(tx, "thread_extent", tdim)
+        bdim = ceil_div(fused_data_dimension, tdim)
+        ib.scope_attr(bx, "thread_extent", bdim)
+
+        # zero data
+        # TODO(tkonolige): could we use topi.full to zero it instead?
+        with ib.for_range(0, ceil_div(fused_shape, bdim)) as i:
+            index = i * fused_data_dimension + bx * tdim + tx
+            with ib.if_scope(index < fused_shape):
+                out[index] = tvm.tir.Cast(data_ptr.dtype, 0)
+
+        with ib.for_range(0, fused_indices_dimension) as i:
+            j = bx * tdim + tx
+            with ib.if_scope(j < fused_data_dimension):
+                offset = fused_data_dimension
+                index = j  # This is the x_M, .. x_{N-1} part of the index into out.
+                # Build up the indices[0, y_0, .. y_{K-1}], .. indices[M-1, y_0, .. y_{K-1}]
+                # part of the index into out.
+                for l in reversed(range(indices_ptr.shape[0].value)):
+                    # indices[i + l * fused_indices_dimension] = indices[l, y_0, ..., y_{K-1}]
+                    index += offset * indices[i + l * fused_indices_dimension]
+                    offset *= shape[l]
+                out[index] += data[i * fused_data_dimension + j]
+
+        return ib.get()
+
+    out_buf = tvm.tir.decl_buffer(shape, data.dtype, "out_buf")
+    return te.extern(
+        [shape],
+        [data, indices],
+        lambda ins, outs: gen_ir(ins[0], ins[1], outs[0]),
+        dtype=data.dtype,
+        out_buffers=[out_buf],
+        name="scatter_nd_cuda",
+        tag="scatter_nd_cuda",
+    )
diff --git a/python/tvm/topi/scatter.py b/python/tvm/topi/scatter.py
index f1c307a..a376963 100644
--- a/python/tvm/topi/scatter.py
+++ b/python/tvm/topi/scatter.py
@@ -16,7 +16,8 @@
 # under the License.
 # pylint: disable=invalid-name, too-many-arguments, too-many-nested-blocks
 """Scatter operator"""
-from tvm.te import hybrid
+from ..tir import decl_buffer, ir_builder, Cast, AssertStmt, StringImm, Evaluate
+from ..te import extern, hybrid
 
 
 @hybrid.script
@@ -196,3 +197,120 @@ def scatter(data, indices, updates, axis=0):
     if len(data.shape) == 4:
         return _scatter_4d(data, indices, updates, axis)
     raise ValueError("scatter only support for 1-4 dimensions")
+
+
+def _verify_scatter_nd_inputs(data, indices, shape):
+    mdim = int(indices.shape[0])
+    assert mdim <= len(shape), (
+        f"The first dimension of the indices ({mdim}) must be less than or equal to "
+        f"the length of the shape of the output ({len(shape)})."
+    )
+    for i in range(len(indices.shape) - 1):
+        assert indices.shape[i + 1] == data.shape[i], (
+            f"Dimension of indices[{i+1}] ({indices.shape[i+1]}) must equal dimension of "
+            f"data[{i}] ({data.shape[i]})."
+        )
+    for i in range(mdim, len(shape)):
+        data_ind = i - mdim + len(indices.shape) - 1
+        assert data.shape[data_ind] == shape[i], (
+            f"Dimension of data[{data_ind}] ({data.shape[data_ind]}) must equal dimension "
+            f"of out_shape[{i}] ({shape[i]})."
+        )
+
+    assert (
+        "int" in indices.dtype
+    ), f"Indices must be a tensor of integers, but its elements are {indices.dtype}."
+
+
+def scatter_nd(data, indices, shape):
+    """Scatter elements from an n-dimension array.
+
+    Given data with shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}), indices with shape
+    (M, Y_0, ..., Y_{K-1}), and output with shape (X_0, X_1, ..., X_{N-1}), scatter_nd computes
+
+    .. code-block::
+
+        output[indices[0, y_0, ..., y_{K-1}],
+               ...,
+               indices[M-1, y_0, ..., y_{K-1}],
+               x_M,
+               ...,
+               x_{N-1}
+              ] = data[y_0, ..., y_{K-1}, x_M, ..., x_{N-1}]
+
+    All other entries in the output are 0. Repeated indices are summed.
+
+    Parameters
+    ----------
+    data : tvm.te.Tensor
+        The source array.
+
+    indices : tvm.te.Tensor
+        The indices of the locations to write to.
+
+    shape : Sequence[int]
+        The output shape. This must be specified because it cannot be inferred.
+
+    Returns
+    -------
+    ret : tvm.te.Tensor
+    """
+    _verify_scatter_nd_inputs(data, indices, shape)
+
+    def gen_ir(data_ptr, indices_ptr, out_ptr):
+        ib = ir_builder.create()
+
+        data = ib.buffer_ptr(data_ptr)
+        indices = ib.buffer_ptr(indices_ptr)
+        out = ib.buffer_ptr(out_ptr)
+
+        # zero data
+        # TODO(tkonolige): could we use topi.full to zero it instead?
+        fused_shape = 1
+        for i in shape:
+            fused_shape *= i
+        with ib.for_range(0, fused_shape) as i:
+            out[i] = Cast(data_ptr.dtype, 0)
+
+        # We combine all the indices dimensions but the first one into a single
+        # dimension so we can iterate it in a single loop instead of an arbitrary
+        # number of loops. We do the same thing for all the data dimensions.
+        fused_indices_dimension = 1
+        for i in indices_ptr.shape[1:]:
+            fused_indices_dimension *= i
+
+        fused_data_dimension = 1
+        for i in data_ptr.shape[len(indices_ptr.shape) - 1 :]:
+            fused_data_dimension *= i
+
+        with ib.for_range(0, fused_indices_dimension, name="i") as i:
+            with ib.for_range(0, fused_data_dimension, name="j") as j:
+                offset = fused_data_dimension
+                index = j  # This is the x_M, .. x_{N-1} part of the index into out.
+                # Build up the indices[0, y_0, .. y_{K-1}], .. indices[M-1, y_0, .. y_{K-1}]
+                # part of the index into out.
+                for l in reversed(range(indices_ptr.shape[0].value)):
+                    # indices[i + l * fused_indices_dimension] = indices[l, y_0, ..., y_{K-1}]
+                    index += offset * indices[i + l * fused_indices_dimension]
+                    ib.emit(
+                        AssertStmt(
+                            indices[i + l * fused_indices_dimension] < shape[l],
+                            StringImm("index out of bounds"),
+                            Evaluate(0),
+                        )
+                    )
+                    offset *= shape[l]
+                out[index] += data[i * fused_data_dimension + j]
+
+        return ib.get()
+
+    out_buf = decl_buffer(shape, data.dtype, "out_buf")
+    return extern(
+        [shape],
+        [data, indices],
+        lambda ins, outs: gen_ir(ins[0], ins[1], outs[0]),
+        dtype=data.dtype,
+        out_buffers=[out_buf],
+        name="scatter_nd_generic",
+        tag="scatter_nd_generic",
+    )
diff --git a/python/tvm/topi/testing/__init__.py b/python/tvm/topi/testing/__init__.py
index 4f90550..0654344 100644
--- a/python/tvm/topi/testing/__init__.py
+++ b/python/tvm/topi/testing/__init__.py
@@ -57,6 +57,7 @@ from .depth_to_space import depth_to_space_python
 from .space_to_depth import space_to_depth_python
 from .crop_and_resize_python import crop_and_resize_python
 from .common import (
+    compare_numpy_tvm,
     get_injective_schedule,
     get_reduce_schedule,
     get_broadcast_schedule,
diff --git a/python/tvm/topi/testing/common.py b/python/tvm/topi/testing/common.py
index 51ea19a..e4e5e81 100644
--- a/python/tvm/topi/testing/common.py
+++ b/python/tvm/topi/testing/common.py
@@ -17,8 +17,10 @@
 # pylint: disable=invalid-name
 """Common utility for topi test"""
 
+import numpy as np
 import tvm
 from tvm import topi
+from tvm.testing import assert_allclose
 
 _injective_schedule = {
     "generic": topi.generic.schedule_injective,
@@ -77,3 +79,32 @@ _conv2d_nchw_implement = {
 
 def get_conv2d_nchw_implement(target):
     return dispatch(target, _conv2d_nchw_implement)
+
+
+def compare_numpy_tvm(inputs, output, target, ctx, compute, schedule):
+    """Compare numpy inputs and output of a function to the results of the TVM version.
+
+    Parameters
+    ----------
+    inputs : Sequence[numpy.ndarray]
+        List of input numpy arrays to pass to the function.
+    output : numpy.ndarray
+        Verified correct function output.
+    target : tvm.target.Target
+        Target to run on.
+    ctx : tvm.TVMContext
+        Context to run on.
+    compute : callable
+        Topi compute function to test against.
+    schedule : callable
+        Topi scheduling function to test against.
+    """
+    te_inputs = [tvm.te.placeholder(shape=i.shape, dtype=str(i.dtype)) for i in inputs]
+    te_out = tvm.nd.array(np.zeros(output.shape).astype(output.dtype), ctx=ctx)
+    with tvm.target.Target(target):
+        out = compute(*te_inputs)
+        s = schedule([out])
+        func = tvm.build(s, te_inputs + [out])
+        arys = [tvm.nd.array(x, ctx=ctx) for x in inputs]
+        func(*(arys + [te_out]))
+        assert_allclose(te_out.asnumpy(), output, atol=1e-4, rtol=1e-4)
diff --git a/python/tvm/topi/x86/__init__.py b/python/tvm/topi/x86/__init__.py
index 659668c..1545110 100644
--- a/python/tvm/topi/x86/__init__.py
+++ b/python/tvm/topi/x86/__init__.py
@@ -39,3 +39,4 @@ from .conv2d_transpose import *
 from .conv3d_transpose import *
 from .sparse import *
 from .conv2d_alter_op import *
+from .scatter import *
diff --git a/python/tvm/topi/x86/scatter.py b/python/tvm/topi/x86/scatter.py
new file mode 100644
index 0000000..8147d3a
--- /dev/null
+++ b/python/tvm/topi/x86/scatter.py
@@ -0,0 +1,109 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Scatter operators for x86"""
+import tvm
+from tvm import te
+from ..scatter import _verify_scatter_nd_inputs
+
+
+def scatter_nd(data, indices, shape):
+    """Scatter elements from an n-dimension array.
+
+    Given data with shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}), indices with shape
+    (M, Y_0, ..., Y_{K-1}), and output with shape (X_0, X_1, ..., X_{N-1}), scatter_nd computes
+
+    .. code-block::
+
+        output[indices[0, y_0, ..., y_{K-1}],
+               ...,
+               indices[M-1, y_0, ..., y_{K-1}],
+               x_M,
+               ...,
+               x_{N-1}
+              ] = data[y_0, ..., y_{K-1}, x_M, ..., x_{N-1}]
+
+    All other entries in the output are 0. Repeated indices are summed.
+
+    Parameters
+    ----------
+    data : tvm.te.Tensor
+        The source array.
+
+    indices : tvm.te.Tensor
+        The indices of the locations to write to.
+
+    shape : Sequence[int]
+        The output shape. This must be specified because it cannot be inferred.
+
+    Returns
+    -------
+    ret : tvm.te.Tensor
+    """
+    _verify_scatter_nd_inputs(data, indices, shape)
+
+    def gen_ir(data_ptr, indices_ptr, out_ptr):
+        # pylint: disable=invalid-name
+        ib = tvm.tir.ir_builder.create()
+
+        data = ib.buffer_ptr(data_ptr)
+        indices = ib.buffer_ptr(indices_ptr)
+        out = ib.buffer_ptr(out_ptr)
+
+        # We combine all the indices dimensions but the first one into a single
+        # dimension so we can iterate it in a single loop instead of an arbitrary
+        # number of loops. We do the same thing for all the data dimensions.
+        fused_indices_dimension = 1
+        for i in indices_ptr.shape[1:]:
+            fused_indices_dimension *= i
+
+        fused_data_dimension = 1
+        for i in data_ptr.shape[len(indices_ptr.shape) - 1 :]:
+            fused_data_dimension *= i
+
+        fused_shape = 1
+        for i in shape:
+            fused_shape *= i
+
+        # zero data
+        # TODO(tkonolige): could we use topi.full to zero it instead?
+        with ib.for_range(0, fused_shape) as i:
+            out[i] = tvm.tir.Cast(data_ptr.dtype, 0)
+
+        with ib.for_range(0, fused_indices_dimension) as i:
+            with ib.for_range(0, fused_data_dimension, for_type="parallel") as j:
+                offset = fused_data_dimension
+                index = j  # This is the x_M, .. x_{N-1} part of the index into out.
+                # Build up the indices[0, y_0, .. y_{K-1}], .. indices[M-1, y_0, .. y_{K-1}]
+                # part of the index into out.
+                for l in reversed(range(indices_ptr.shape[0].value)):
+                    # indices[i + l * fused_indices_dimension] = indices[l, y_0, ..., y_{K-1}]
+                    index += offset * indices[i + l * fused_indices_dimension]
+                    offset *= shape[l]
+                out[index] += data[i * fused_data_dimension + j]
+
+        return ib.get()
+
+    out_buf = tvm.tir.decl_buffer(shape, data.dtype, "out_buf")
+    return te.extern(
+        [shape],
+        [data, indices],
+        lambda ins, outs: gen_ir(ins[0], ins[1], outs[0]),
+        dtype=data.dtype,
+        out_buffers=[out_buf],
+        name="scatter_nd_x86",
+        tag="scatter_nd_x86",
+    )
diff --git a/src/relay/analysis/type_solver.cc b/src/relay/analysis/type_solver.cc
index 8f14b55..64db13a 100644
--- a/src/relay/analysis/type_solver.cc
+++ b/src/relay/analysis/type_solver.cc
@@ -246,7 +246,7 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> {
     for (size_t i = 0; i < tt1->shape.size(); i++) {
       auto dim = UnifyDim(tt1->shape[i], tt2->shape[i]);
       if (!dim.defined()) {
-        // NB: We push an arbitrary dimension here so we can continue error propogation.
+        // NB: We push an arbitrary dimension here so we can continue error propagation.
         shape.push_back(tt1->shape[i]);
         tvm::PrimExpr shape1 = tt1->shape[i];
         tvm::PrimExpr shape2 = tt2->shape[i];
@@ -259,10 +259,11 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> {
 
     if (mismatches.size() != 0) {
       auto err = Diagnostic::Error(this->span);
-      err << "in particular ";
+      err << "The Relay type checker is unable to show the following types match.\n";
+      err << "In particular ";
       for (auto mismatch : mismatches) {
-        err << "dimension " << std::get<0>(mismatch) << " conflicts " << std::get<1>(mismatch)
-            << " does not match " << std::get<2>(mismatch);
+        err << "dimension " << std::get<0>(mismatch) << " conflicts: " << std::get<1>(mismatch)
+            << " does not match " << std::get<2>(mismatch) << ".";
       }
       this->solver_->diag_ctx_.Emit(err);
       return Type(nullptr);
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index d1f2f26..5a13e9a 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -977,6 +977,74 @@ RELAY_REGISTER_OP("scatter_add")
     .set_attr<TOpPattern>("TOpPattern", kOpaque)
     .set_support_level(10);
 
+// scatter_nd operator
+TVM_REGISTER_NODE_TYPE(ScatterNDAttrs);
+
+bool ScatterNDRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                  const TypeReporter& reporter) {
+  // `types` contains: [data, indices, result]
+  ICHECK_EQ(types.size(), 3);
+  const auto* data = types[0].as<TensorTypeNode>();
+  const auto* indices = types[1].as<TensorTypeNode>();
+  if (data == nullptr) {
+    ICHECK(types[0].as<IncompleteTypeNode>())
+        << "ScatterND: expect input data type to be TensorType but got " << types[0];
+    return false;
+  }
+  if (indices == nullptr) {
+    ICHECK(types[1].as<IncompleteTypeNode>())
+        << "ScatterND: expect indices type to be TensorType but got " << types[1];
+    return false;
+  }
+  ICHECK(indices->dtype.is_int()) << "ScatterND: indices must be a tensor of integers.";
+  const auto out_shape = attrs.as<ScatterNDAttrs>()->out_shape;
+  const IntImmNode* mdim = indices->shape[0].as<IntImmNode>();
+  const size_t kdim = indices->shape.size() - 1;
+  const size_t ndim = out_shape.size();
+  ICHECK_LE(size_t(mdim->value), ndim)
+      << "ScatterND: Given data with shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}), and indices "
+         "with shape (M, Y_0, ..., Y_{K-1}), M must be less than or equal to N.";
+  // Indices: (M, Y_0, .. Y_{K-1})  data: (Y_0, .. Y_{K-1}, ...), verify Y's.
+  for (size_t i = 0; i < kdim; i++) {
+    reporter->AssertEQ(indices->shape[i + 1], data->shape[i]);
+  }
+
+  std::vector<IndexExpr> oshape;
+  for (auto& x : out_shape) {
+    oshape.push_back(x);
+  }
+
+  // data: (Y_0, .. Y_{K-1}, X_M, .. X_{N-1})  out: (X_0, .. X_{N-1}), verify X_M to X_{N-1}.
+  for (size_t i = mdim->value; i < ndim; i++) {
+    reporter->AssertEQ(data->shape[i - mdim->value + kdim], oshape[i]);
+  }
+
+  reporter->Assign(types[2], TensorType(oshape, data->dtype));
+  return true;
+}
+
+Expr MakeScatterND(Expr data, Expr indices, const Array<Integer> out_shape) {
+  auto attrs = make_object<ScatterNDAttrs>();
+  attrs->out_shape = out_shape;
+  static const Op& op = Op::Get("scatter_nd");
+  return Call(op, {data, indices}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relay.op._make.scatter_nd").set_body_typed(MakeScatterND);
+
+RELAY_REGISTER_OP("scatter_nd")
+    .describe(R"code(Scatter elements or slices from data and store them in a tensor
+whose shape is defined by the out_shape attribute.
+
+Given data with shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}) and indices with shape
+(M, Y_0, ..., Y_{K-1}), the output will have shape (X_0, X_1, ..., X_{N-1}).
+)code" TVM_ADD_FILELINE)
+    .set_num_inputs(2)
+    .add_argument("data", "Tensor", "The input tensor.")
+    .set_support_level(3)
+    .add_type_rel("ScatterND", ScatterNDRel)
+    .set_attr<TOpPattern>("TOpPattern", kInjective);
+
 // Take
 TVM_REGISTER_NODE_TYPE(TakeAttrs);
 
diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py
index 5469737..eec6aa2 100644
--- a/tests/python/relay/test_any.py
+++ b/tests/python/relay/test_any.py
@@ -989,7 +989,10 @@ def test_recursive_concat_with_wrong_annotation():
     body = loop(start, relay.op.reshape(relay.const(0), newshape=(1, 1)))
     func = relay.Function([start], relay.TupleGetItem(body, 1))
     with DiagnosticTesting() as diagnostics:
-        diagnostics.assert_message("in particular dimension 0 conflicts 2 does not match 1")
+        diagnostics.assert_message(
+            "The Relay type checker is unable to show the following types "
+            "match.\nIn particular dimension 0 conflicts: 2 does not match 1."
+        )
         func = infer_type(func)
diff --git a/tests/python/relay/test_op_grad_level3.py b/tests/python/relay/test_op_grad_level3.py
index 9c27afd..98ff62e 100644
--- a/tests/python/relay/test_op_grad_level3.py
+++ b/tests/python/relay/test_op_grad_level3.py
@@ -117,5 +117,14 @@ def test_arange_grad():
     check_grad(fwd_func, inputs=values)
 
 
+def test_gather_nd_grad():
+    data = relay.var("data", relay.TensorType((2, 3), "float64"))
+    indices = relay.var("indices", relay.TensorType((2, 4), "int64"))
+    fwd = relay.Function([data, indices], relay.gather_nd(data, indices))
+    data_np = np.random.rand(2, 3).astype("float64")
+    indices_np = np.array([[0, 1, 1, 0], [0, 1, 0, 0]], dtype="int64")
+    check_grad(fwd, inputs=[data_np, indices_np], test_inputs=[data_np])
+
+
 if __name__ == "__main__":
     pytest.main()
diff --git a/tests/python/topi/python/test_topi_scatter.py b/tests/python/topi/python/test_topi_scatter.py
new file mode 100644
index 0000000..2e701e2
--- /dev/null
+++ b/tests/python/topi/python/test_topi_scatter.py
@@ -0,0 +1,67 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import numpy as np
+import tvm
+import tvm.testing
+from tvm import topi
+import tvm.topi.testing
+
+
+@tvm.testing.parametrize_targets
+def test_scatter_nd(ctx, target):
+    def check_scatter_nd(data, indices, shape, out):
+        implementations = {
+            "generic": (lambda x, y: topi.scatter_nd(x, y, shape), topi.generic.schedule_extern),
+            "gpu": (lambda x, y: topi.cuda.scatter_nd(x, y, shape), topi.generic.schedule_extern),
+            "cpu": (lambda x, y: topi.x86.scatter_nd(x, y, shape), topi.generic.schedule_extern),
+        }
+        fcompute, fschedule = tvm.topi.testing.dispatch(target, implementations)
+        tvm.topi.testing.compare_numpy_tvm([data, indices], out, target, ctx, fcompute, fschedule)
+
+    data = np.array([2, 3, 0])
+    indices = np.array([[1, 1, 0], [0, 1, 0]])
+    shape = (2, 2)
+    out = np.array([[0, 0], [2, 3]])
+    check_scatter_nd(data, indices, shape, out)
+
+    data = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
+    indices = np.array([[0, 1], [1, 1]])
+    shape = (2, 2, 2, 2)
+    out = np.array([[[[0, 0], [0, 0]], [[1, 2], [3, 4]]], [[[0, 0], [0, 0]], [[5, 6], [7, 8]]]])
+    check_scatter_nd(data, indices, shape, out)
+
+    data = np.reshape(np.arange(1560 * 3), (3, 1560)).astype("float32")
+    indices = np.array([[1, 0, 0]])
+    shape = (2, 1560)
+    out = np.zeros(shape).astype("float32")
+    out[1, :] += data[0, :]
+    out[0, :] += data[1, :]
+    out[0, :] += data[2, :]
+    check_scatter_nd(data, indices, shape, out)
+
+    data = np.ones((5, 3)).astype("float64")
+    indices = np.stack((np.random.randint(2, size=5), np.random.randint(7, size=5))).astype(
+        "int64"
+    )
+    shape = (2, 7, 3)
+    out = np.zeros(shape).astype("float64")
+    for i in range(indices.shape[1]):
+        for j in range(data.shape[1]):
+            out[indices[0, i], indices[1, i], j] += data[i, j]
+    check_scatter_nd(data, indices, shape, out)
+
+
+if __name__ == "__main__":
+    test_scatter_nd(tvm.context("cpu"), tvm.target.Target("llvm"))
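
Editor's note: for readers checking the index arithmetic in the gen_ir bodies above, here is a
minimal NumPy sketch of the scatter_nd semantics described in the docstrings. It is not part of
the commit; the name scatter_nd_ref is hypothetical and used only for illustration.

    import numpy as np


    def scatter_nd_ref(data, indices, shape):
        """NumPy reference for scatter_nd: data has shape
        (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}), indices has shape (M, Y_0, ..., Y_{K-1}),
        and the result has shape (X_0, ..., X_{N-1}), zero everywhere except the
        scattered slices; repeated indices are summed."""
        m = indices.shape[0]
        out = np.zeros(shape, dtype=data.dtype)
        # Visit every (y_0, ..., y_{K-1}) coordinate of indices.
        for y in np.ndindex(*indices.shape[1:]):
            # indices[:, y] gives the first M output coordinates; the remaining
            # output axes copy a whole trailing slice of data, per the docstring formula.
            idx = tuple(indices[(l,) + y] for l in range(m))
            out[idx] += data[y]
        return out


    # Matches the first case in test_topi_scatter.py: prints [[0 0] [2 3]].
    print(scatter_nd_ref(np.array([2, 3, 0]), np.array([[1, 1, 0], [0, 1, 0]]), (2, 2)))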