This is an automated email from the ASF dual-hosted git repository.

kevinthesun pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
     new 1228111  [Relay][Topi][Op]Advanced indexing (#6388)
1228111 is described below

commit 122811137a6fe3e787d5174a36d99eebefb479d0
Author: Yao Wang <kevinthesu...@gmail.com>
AuthorDate: Thu Sep 10 23:09:45 2020 -0700

    [Relay][Topi][Op]Advanced indexing (#6388)

    * Add Relay adv_index op

    * Support single index tensor dynamic shape

    * Support more dynamic index

    * Fix lint

    * Minor fix for comment

    * Fix lint

    * Fix lint

    * Fix test

    * Fix
---
 include/tvm/topi/transform.h                    | 80 ++++++++++++++++++++++++
 python/tvm/relay/frontend/pytorch.py            | 40 +----------
 python/tvm/relay/op/_transform.py               | 33 ++++++++++
 python/tvm/relay/op/transform.py                | 18 ++++++
 python/tvm/topi/transform.py                    | 18 ++++++
 src/relay/op/tensor/transform.cc                | 83 +++++++++++++++++++++++++
 src/topi/transform.cc                           |  4 ++
 tests/python/relay/test_any.py                  | 15 +++++
 tests/python/relay/test_op_level3.py            | 26 ++++++++
 tests/python/topi/python/test_topi_transform.py | 42 +++++++++++++
 10 files changed, 321 insertions(+), 38 deletions(-)

diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h
index af59928..2c0d102 100644
--- a/include/tvm/topi/transform.h
+++ b/include/tvm/topi/transform.h
@@ -26,6 +26,7 @@
 #include <tvm/te/operation.h>
 #include <tvm/tir/data_layout.h>
+#include <tvm/topi/broadcast.h>
 #include <tvm/topi/detail/constant_utils.h>
 #include <tvm/topi/detail/ravel_unravel.h>
 #include <tvm/topi/detail/tensor_utils.h>
@@ -1551,6 +1552,85 @@ inline Tensor matrix_set_diag(const Tensor& input, const Tensor& diagonal,
                               name, tag);
 }

+/*!
+ * \brief Numpy style advanced indexing with tensors.
+ * \param data The input data.
+ * \param indices The list of index tensors.
+ * \param name The output tensor name.
+ * \param tag The output tensor tag.
+ * \return Output tensor.
+ */
+inline Tensor adv_index(const Tensor& data, const Array<Tensor>& indices,
+                        const std::string name = "advanced_index",
+                        const std::string tag = kInjective) {
+  Array<PrimExpr> oshape;
+  Array<PrimExpr> broadcast_shape;
+  Array<Tensor> bindices;
+  std::vector<int64_t> flatten_shape_lens;
+  int64_t num_picked_elems = 1;
+  bool has_dyn_shape = false;
+
+  if (indices.size() == 1) {
+    broadcast_shape = indices[0]->shape;
+    bindices = indices;
+  } else {
+    for (const auto& index : indices) {
+      int64_t flatten_len = 1;
+      for (const auto& dim : index->shape) {
+        const IntImmNode* axis_len = dim.as<IntImmNode>();
+        if (!axis_len) {
+          broadcast_shape = index->shape;
+          has_dyn_shape = true;
+          break;
+        }
+        flatten_len *= axis_len->value;
+      }
+      if (has_dyn_shape) break;
+      flatten_shape_lens.push_back(flatten_len);
+      if (flatten_len > num_picked_elems) {
+        num_picked_elems = flatten_len;
+        broadcast_shape = index->shape;
+      }
+    }
+
+    // Do broadcast for indices
+    for (size_t i = 0; i < indices.size(); ++i) {
+      if (!has_dyn_shape && flatten_shape_lens[i] < num_picked_elems) {
+        bindices.push_back(broadcast_to(indices[i], broadcast_shape));
+      } else {
+        bindices.push_back(indices[i]);
+      }
+    }
+  }
+
+  for (const auto& dim : broadcast_shape) {
+    oshape.push_back(dim);
+  }
+  for (size_t i = indices.size(); i < data->shape.size(); ++i) {
+    oshape.push_back(data->shape[i]);
+  }
+
+  return compute(
+      oshape,
+      [&](const Array<Var>& iter_var) {
+        Array<PrimExpr> tensor_indices;
+        for (size_t i = 0; i < broadcast_shape.size(); ++i) {
+          tensor_indices.push_back(iter_var[i]);
+        }
+
+        Array<PrimExpr> real_indices;
+        for (size_t i = 0; i < bindices.size(); ++i) {
+          real_indices.push_back(bindices[i](tensor_indices));
+        }
+        for (size_t i = broadcast_shape.size(); i < iter_var.size(); ++i) {
+          real_indices.push_back(iter_var[i]);
+        }
+
+        return data(real_indices);
+      },
+      name, tag);
+}
+
 }  // namespace topi
 }  // namespace tvm
 #endif  // TVM_TOPI_TRANSFORM_H_
diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 7203150..19cbf75 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -1816,44 +1816,8 @@ def _one_hot():
 def _index():
     def _impl(inputs, input_types):
         data = inputs[0]
-        indices = []
-        raw_indices = []
-        max_indices_len = -1
-        for index in inputs[1]:
-            if not isinstance(index, _expr.Constant):
-                try:
-                    index = _expr.const(_infer_value(index, {}))
-                except Exception:
-                    raise RuntimeError("Only supports constant indices for "
-                                       "pytorch advanced indexing ")
-            raw_indices.append(index)
-            cindex_len = index.data.shape[0]
-            if cindex_len > max_indices_len:
-                max_indices_len = cindex_len
-
-        for index in raw_indices:
-            cnp = index.data.asnumpy()
-            cindex_len = cnp.shape[0]
-            if cindex_len < max_indices_len:
-                cnp = np.tile(cnp, max_indices_len // cindex_len)
-            indices.append(cnp)
-
-        ret = []
-        slice_map = {}
-        for i in range(indices[0].shape[0]):
-            tmp = data
-            current_indices = []
-            for index in indices:
-                current_indices.append(index[i])
-            index_key = tuple(current_indices)
-            if index_key in slice_map:
-                tmp = slice_map[index_key]
-            else:
-                tmp = _op.take(tmp, _expr.const(index[i]), axis=0)
-                slice_map[index_key] = tmp
-            ret.append(_op.expand_dims(tmp, axis=0))
-
-        return _op.concatenate(ret, axis=0)
+        indices = inputs[1]
+        return _op.adv_index([data] + indices)
     return _impl
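With aten::index now lowered straight to adv_index, the frontend no longer requires constant index tensors. A minimal conversion sketch, not part of the commit (the module and the input name "x" are hypothetical):

    import torch
    from tvm import relay

    class IndexModel(torch.nn.Module):
        def forward(self, x):
            # Tensor indexing traces to aten::index, which the frontend
            # now maps to a single relay.adv_index call.
            rows = torch.tensor([0, 2, 1])
            cols = torch.tensor([1, 3, 0])
            return x[rows, cols]

    traced = torch.jit.trace(IndexModel().eval(), torch.randn(3, 4, 5))
    mod, params = relay.frontend.from_pytorch(traced, [("x", (3, 4, 5))])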
diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py
index 9d7c389..98ff0b3 100644
--- a/python/tvm/relay/op/_transform.py
+++ b/python/tvm/relay/op/_transform.py
@@ -62,6 +62,7 @@ _reg.register_reduce_schedule("collapse_sum_to")
 _reg.register_injective_schedule("unravel_index")
 _reg.register_injective_schedule("sparse_to_dense")
 _reg.register_injective_schedule("matrix_set_diag")
+_reg.register_injective_schedule("adv_index")

 # concatenate
 _reg.register_schedule("concatenate", strategy.schedule_concatenate)
@@ -661,3 +662,35 @@ def split_shape_func(attrs, inputs, _):
             convert(i), convert(indices_or_sections), convert(axis))
             for i in range(num_out)]
+
+@script
+def _adv_index_shape_func(inputs):
+    index_rank = inputs[1].shape[0]
+    data_rank = inputs[0].shape[0]
+    out = output_tensor((data_rank + index_rank - len(inputs) + 1,), "int64")
+
+    max_flatten_len = int64(1)
+    for i in const_range(index_rank):
+        max_flatten_len *= inputs[1][i]
+        out[i] = inputs[1][i]
+    for i in const_range(len(inputs) - 2):
+        flatten_len = int64(1)
+        for j in const_range(index_rank):
+            flatten_len *= inputs[i + 2][j]
+        if flatten_len > max_flatten_len:
+            max_flatten_len = flatten_len
+            for k in const_range(index_rank):
+                out[k] = inputs[i + 2][k]
+
+    for i in const_range(data_rank - len(inputs) + 1):
+        out[i + index_rank] = inputs[0][i + len(inputs) - 1]
+
+    return out
+
+@_reg.register_shape_func("adv_index", False)
+def adv_index_shape_func(attrs, inputs, _):
+    """
+    Shape func for adv_index.
+    Only allow single index tensor.
+    """
+    return [_adv_index_shape_func(inputs)]
diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py
index 01466f7..0ce59ad 100644
--- a/python/tvm/relay/op/transform.py
+++ b/python/tvm/relay/op/transform.py
@@ -1213,3 +1213,21 @@ def matrix_set_diag(data, diagonal):
          [7, 7, 6, 7]]]
     """
     return _make.matrix_set_diag(data, diagonal)
+
+
+def adv_index(inputs):
+    """
+    Numpy style advanced indexing. Index with a list of tensors.
+
+    Parameters
+    ----------
+    inputs : Union(List[relay.Expr], Tuple[relay.Expr])
+        Input tensor and indices.
+        The first tensor is the input data and the rest are index tensors.
+
+    Returns
+    -------
+    result : relay.Expr
+        Output tensor.
+    """
+    return _make.adv_index(Tuple(inputs))
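A small usage sketch of the new Python API, mirroring the NumPy semantics; this is not from the commit and the shapes and variable names are illustrative:

    import numpy as np
    import tvm
    from tvm import relay

    data = relay.var("data", shape=(10, 5), dtype="float32")
    index = relay.var("index", shape=(3,), dtype="int64")
    # adv_index takes a single list: the data tensor first, then the indices.
    out = relay.adv_index([data, index])
    func = relay.Function([data, index], out)

    np_data = np.random.uniform(size=(10, 5)).astype("float32")
    np_index = np.array([0, 7, 2], dtype="int64")
    intrp = relay.create_executor("graph", ctx=tvm.cpu(0), target="llvm")
    res = intrp.evaluate(func)(np_data, np_index)
    # Matches NumPy advanced indexing on the first axis.
    np.testing.assert_allclose(res.asnumpy(), np_data[np_index])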
diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py
index f3e5a6a..1681d87 100644
--- a/python/tvm/topi/transform.py
+++ b/python/tvm/topi/transform.py
@@ -838,3 +838,21 @@ def matrix_set_diag(data, diagonal):
          [7, 7, 6, 7]]]
     """
     return cpp.matrix_set_diag(data, diagonal)
+
+def adv_index(data, indices):
+    """Numpy style indexing with tensors.
+
+    Parameters
+    ----------
+    data : tvm.te.Tensor
+        Input data.
+
+    indices : list of tvm.te.Tensor
+        Index tensors.
+
+    Returns
+    -------
+    result : tvm.te.Tensor
+        Output tensor.
+    """
+    return cpp.adv_index(data, indices)
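This wrapper forwards to the C++ adv_index above, which broadcasts the index tensors to the shape of the one with the largest flattened length and then appends the unindexed trailing data dimensions. A quick NumPy illustration of the resulting output shape (not from the commit):

    import numpy as np

    data = np.arange(10 * 5 * 15).reshape(10, 5, 15)
    i0 = np.zeros((1, 2, 1), dtype="int64")  # smaller, broadcasts against i1
    i1 = np.ones((1, 2, 7), dtype="int64")   # largest flattened length wins
    out = data[i0, i1]                       # indexes the first two axes
    # Output shape = broadcast index shape + remaining data dims.
    assert out.shape == (1, 2, 7, 15)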
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index 88179b7..e3d0950 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -3163,5 +3163,88 @@ RELAY_REGISTER_OP("matrix_set_diag")
     .set_attr<FTVMCompute>("FTVMCompute", MatrixSetDiagCompute)
     .set_attr<TOpPattern>("TOpPattern", kInjective);

+// adv_index
+bool AdvIndexRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                 const TypeReporter& reporter) {
+  CHECK_EQ(num_inputs, 1);
+  auto inputs = types[0].as<TupleTypeNode>();
+  if (inputs == nullptr) {
+    return false;
+  }
+  auto data = inputs->fields[0].as<TensorTypeNode>();
+  if (data == nullptr) {
+    return false;
+  }
+
+  Array<IndexExpr> oshape;
+  Array<IndexExpr> broadcast_shape;
+  int64_t num_picked_elems = 1;
+
+  if (inputs->fields.size() == 2) {
+    broadcast_shape = inputs->fields[1].as<TensorTypeNode>()->shape;
+  } else {
+    for (size_t i = 1; i < inputs->fields.size(); ++i) {
+      auto index_type = inputs->fields[i].as<TensorTypeNode>();
+      if (index_type == nullptr) {
+        return false;
+      }
+      CHECK(index_type->dtype.is_int()) << "indices must be tensor of integers";
+
+      int64_t flatten_len = 1;
+      bool has_dyn_shape = false;
+      for (const auto& dim : index_type->shape) {
+        const IntImmNode* axis_len = dim.as<IntImmNode>();
+        if (!axis_len) {
+          // If a dynamic shape appears, just use the first such shape
+          broadcast_shape = index_type->shape;
+          has_dyn_shape = true;
+          break;
+        }
+        flatten_len *= axis_len->value;
+      }
+      if (has_dyn_shape) break;
+      if (flatten_len > num_picked_elems) {
+        num_picked_elems = flatten_len;
+        broadcast_shape = index_type->shape;
+      }
+    }
+  }
+
+  for (const auto& dim : broadcast_shape) {
+    oshape.push_back(dim);
+  }
+  for (size_t i = inputs->fields.size() - 1; i < data->shape.size(); ++i) {
+    oshape.push_back(data->shape[i]);
+  }
+  reporter->Assign(types[1], TensorType(oshape, data->dtype));
+  return true;
+}
+
+Array<te::Tensor> AdvIndexCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
+                                  const Type& out_type) {
+  Array<te::Tensor> indices;
+  for (size_t i = 1; i < inputs.size(); ++i) {
+    indices.push_back(inputs[i]);
+  }
+  return {topi::adv_index(inputs[0], indices)};
+}
+
+Expr MakeAdvIndex(Expr inputs) {
+  static const Op& op = Op::Get("adv_index");
+  return Call(op, {inputs}, Attrs(), {});
+}
+
+TVM_REGISTER_GLOBAL("relay.op._make.adv_index").set_body_typed(MakeAdvIndex);
+
+RELAY_REGISTER_OP("adv_index")
+    .describe(R"code(Numpy style advanced indexing. Index with a list of tensors.
+    )code" TVM_ADD_FILELINE)
+    .set_num_inputs(1)
+    .set_support_level(3)
+    .add_argument("inputs", "Tuple of Tensors", "Input tensor and indices.")
+    .add_type_rel("AdvIndex", AdvIndexRel)
+    .set_attr<TOpIsStateful>("TOpIsStateful", false)
+    .set_attr<TOpPattern>("TOpPattern", kInjective)
+    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
+    .set_attr<FTVMCompute>("FTVMCompute", AdvIndexCompute);
+
 }  // namespace relay
 }  // namespace tvm
diff --git a/src/topi/transform.cc b/src/topi/transform.cc
index 154933f..bf7e1e6 100644
--- a/src/topi/transform.cc
+++ b/src/topi/transform.cc
@@ -180,5 +180,9 @@ TVM_REGISTER_GLOBAL("topi.matrix_set_diag").set_body([](TVMArgs args, TVMRetValu
   *rv = matrix_set_diag(args[0], args[1]);
 });

+TVM_REGISTER_GLOBAL("topi.adv_index").set_body([](TVMArgs args, TVMRetValue* rv) {
+  *rv = adv_index(args[0], args[1]);
+});
+
 }  // namespace topi
 }  // namespace tvm
diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py
index 6bb34d3..3a46fdd 100644
--- a/tests/python/relay/test_any.py
+++ b/tests/python/relay/test_any.py
@@ -892,5 +892,20 @@ def test_reshape_concat():
                              np.reshape(np_data1, np_shape_like1.shape)], axis=0)
     check_result([np_data0, np_data1, np_shape_like0, np_shape_like1], mod, ref_res)

+def test_any_adv_index():
+    data = relay.var("data", shape=(5, relay.Any(), relay.Any()), dtype='float32')
+    index0 = relay.var("index0", shape=(1, relay.Any()), dtype='int64')
+    index1 = relay.var("index1", shape=(1, relay.Any()), dtype='int64')
+    out = relay.adv_index([data, index0, index1])
+    mod = tvm.IRModule()
+    mod['main'] = relay.Function([data, index0, index1], out)
+    np_data_shape = (5, 5, 10)
+    np_index_shape = (1, 4)
+    np_data = np.random.uniform(size=np_data_shape).astype('float32')
+    np_index = np.random.uniform(0, np_data_shape[0], size=np_index_shape).astype('int64')
+    ref_res = np_data[tuple([np_index, np_index])]
+    check_result([np_data, np_index, np_index], mod, ref_res)
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py
index f709aa2..98ef38d 100644
--- a/tests/python/relay/test_op_level3.py
+++ b/tests/python/relay/test_op_level3.py
@@ -1091,6 +1091,31 @@ def test_sparse_to_dense():
     #sparse_indices should not be > 2d tensor
     #verify_sparse_to_dense([[[[0, 1, 4], [0, 2, 4]]]], [[[[3.1, 3.1, 3.1]]]], 3.5, [5], [3.1, 3.1, 3.5, 3.5, 3.1])

+def test_adv_index():
+    def verify_adv_index(data_shape, index_shapes):
+        dtype = "float32"
+        inputs = [relay.var("data", relay.TensorType(data_shape, dtype))]
+        np_data = np.random.uniform(size=data_shape).astype(dtype)
+        np_indices = []
+        for i, index_shape in enumerate(index_shapes):
+            limit = data_shape[i]
+            np_indices.append(np.random.uniform(0, limit - 1, size=index_shape).astype("int64"))
+            inputs.append(relay.var("index_{}".format(i), relay.TensorType(index_shape, "int64")))
+        np_out = np_data[tuple(np_indices)]
+        np_args = [np_data] + np_indices
+        out = relay.op.adv_index(inputs)
+
+        func = relay.Function(inputs, out)
+        for target, ctx in tvm.testing.enabled_targets():
+            for kind in ["graph", "debug"]:
+                intrp = relay.create_executor(kind, ctx=ctx, target=target)
+                op_res = intrp.evaluate(func)(*np_args)
+                tvm.testing.assert_allclose(op_res.asnumpy(), np_out, rtol=1e-5)
+
+    verify_adv_index((10, 5), [(3, 4), (3, 1)])
+    verify_adv_index((10, 5), [(2,),])
+    verify_adv_index((10, 5, 15), [(1, 2, 1), (1, 2, 7)])
+
 if __name__ == "__main__":
     test_cast()
     test_zeros_ones()
@@ -1127,3 +1152,4 @@ if __name__ == "__main__":
     test_unravel_index()
     test_sparse_to_dense()
     test_fixed_point_multiply()
+    test_adv_index()
diff --git a/tests/python/topi/python/test_topi_transform.py b/tests/python/topi/python/test_topi_transform.py
index a061ba9..fc6f19f 100644
--- a/tests/python/topi/python/test_topi_transform.py
+++ b/tests/python/topi/python/test_topi_transform.py
@@ -678,6 +678,7 @@ def verify_matrix_set_diag(input_shape, dtype):
     input = te.placeholder(shape=input_shape, name="input", dtype=dtype)
     diagonal = te.placeholder(shape=diagonal_shape, name="diagonal", dtype=dtype)
     matrix_set_diag_result = topi.transform.matrix_set_diag(input, diagonal)
+
     def check_device(device, ctx):
         ctx = tvm.context(device, 0)
         print("Running on target: %s" % device)
@@ -697,6 +698,40 @@ def verify_matrix_set_diag(input_shape, dtype):
     for target, ctx in tvm.testing.enabled_targets():
         check_device(target, ctx)

+def verify_adv_index(data_shape, index_shapes):
+    dtype = "float32"
+    data = te.placeholder(shape=data_shape, name="data", dtype=dtype)
+    indices = []
+    np_data = np.random.uniform(size=data_shape).astype(dtype)
+    np_indices = []
+    for i, index_shape in enumerate(index_shapes):
+        limit = data_shape[i]
+        np_indices.append(np.random.uniform(0, limit - 1, size=index_shape).astype("int64"))
+        indices.append(te.placeholder(shape=index_shape, name="index_{}".format(i), dtype="int64"))
+    np_out = np_data[tuple(np_indices)]
+    out = topi.adv_index(data, indices)
+
+    def check_device(device, ctx):
+        ctx = tvm.context(device, 0)
+        if not ctx.exist:
+            print("Skip because %s is not enabled" % device)
+            return
+        print("Running on target: %s" % device)
+        with tvm.target.create(device):
+            s = tvm.topi.testing.get_injective_schedule(device)(out)
+
+        func = tvm.build(s, [data] + indices + [out], device, name="adv_index")
+
+        nd_list = [tvm.nd.array(np_data, ctx)]
+        for np_index in np_indices:
+            nd_list.append(tvm.nd.array(np_index, ctx))
+        nd_list.append(tvm.nd.empty(out.shape, ctx=ctx, dtype=data.dtype))
+
+        func(*nd_list)
+        tvm.testing.assert_allclose(nd_list[-1].asnumpy(), np.array(np_out))
+
+    for target, ctx in tvm.testing.enabled_targets():
+        check_device(target, ctx)

 @tvm.testing.uses_gpu
 def test_strided_slice():
@@ -1071,6 +1106,12 @@ def test_matrix_set_diag():
             verify_matrix_set_diag((4, 3, 3), dtype)
             verify_matrix_set_diag((2, 3, 4), dtype)

+@tvm.testing.uses_gpu
+def test_adv_index():
+    verify_adv_index((3, 4, 5), [(2,), (2,), (1,)])
+    verify_adv_index((10, 15, 5), [(1, 1), (2, 7)])
+    verify_adv_index((10, 5, 15), [(1, 2, 1), (1, 2, 7)])
+
 if __name__ == "__main__":
     test_strided_slice()
     test_concatenate()
@@ -1097,3 +1138,4 @@ if __name__ == "__main__":
     test_unravel_index()
     test_sparse_to_dense()
     test_matrix_set_diag()
+    test_adv_index()
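For the dynamic-shape path exercised by test_any_adv_index, modules with relay.Any() dimensions go through the registered shape function and the VM executor, as the check_result helper in test_any.py does. A hedged sketch under those assumptions (not from the commit):

    import numpy as np
    import tvm
    from tvm import relay

    data = relay.var("data", shape=(5, relay.Any(), relay.Any()), dtype="float32")
    index = relay.var("index", shape=(1, relay.Any()), dtype="int64")
    mod = tvm.IRModule()
    mod["main"] = relay.Function([data, index], relay.adv_index([data, index]))

    np_data = np.random.uniform(size=(5, 5, 10)).astype("float32")
    np_index = np.random.uniform(0, 5, size=(1, 4)).astype("int64")

    # The VM executor handles relay.Any() shapes; the graph runtime would not.
    ex = relay.create_executor("vm", mod=mod, ctx=tvm.cpu(0), target="llvm")
    res = ex.evaluate()(np_data, np_index)
    np.testing.assert_allclose(res.asnumpy(), np_data[np_index])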