This is an automated email from the ASF dual-hosted git repository. comaniac pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
     new 5ad2f77  [Relay] Gather op dynamic input support (#9240)
5ad2f77 is described below

commit 5ad2f77403bed9a2bf356cc0d3d785ecc13e6c58
Author: masahi <masahi...@gmail.com>
AuthorDate: Tue Oct 12 01:22:10 2021 +0900

    [Relay] Gather op dynamic input support (#9240)

    * support gather op dynamic input

    * fix shape func and add test

    * remove constness check

    * fix shape func output rank

    * restore check

    Co-authored-by: masa <masa@pop-os.localdomain>
---
 include/tvm/topi/transform.h      |  6 ++++--
 python/tvm/relay/op/_transform.py | 20 ++++++++++++++++++++
 src/relay/op/tensor/transform.cc  |  6 ++++--
 tests/python/relay/test_any.py    | 22 ++++++++++++++++++++++
 4 files changed, 50 insertions(+), 4 deletions(-)

diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h
index 8d1a49a..3df9caf 100644
--- a/include/tvm/topi/transform.h
+++ b/include/tvm/topi/transform.h
@@ -1233,8 +1233,10 @@ inline Tensor gather(const Tensor& data, int axis, const Tensor& indices,
   }
   ICHECK_GE(axis, 0);
   ICHECK_LT(axis, ndim_d);
-  size_t indices_dim_i = static_cast<size_t>(GetConstInt(indices->shape[axis]));
-  ICHECK_GE(indices_dim_i, 1);
+  if (indices->shape[axis].as<IntImmNode>()) {
+    size_t indices_dim_i = static_cast<size_t>(GetConstInt(indices->shape[axis]));
+    ICHECK_GE(indices_dim_i, 1);
+  }
   ICHECK(indices->dtype.is_int());
 
   Array<PrimExpr> out_shape;
diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py
index 0284d24..76c8069 100644
--- a/python/tvm/relay/op/_transform.py
+++ b/python/tvm/relay/op/_transform.py
@@ -1174,3 +1174,23 @@ def gather_nd_shape_func(attrs, inputs, _):
     assert index_rank > 0, "index_rank needs to be specified for dynamic gather_nd"
 
     return [_gather_nd_shape(inputs[0], inputs[1], convert(batch_dims), convert(index_rank))]
+
+
+@script
+def _gather_shape(data_shape, indices_shape, axis):
+    out_shape = output_tensor((data_shape.shape[0],), "int64")
+    for i in range(data_shape.shape[0]):
+        if i != axis:
+            assert (
+                data_shape[i] == indices_shape[i]
+            ), "data and indices size at non-gather axes must be the same"
+        out_shape[i] = indices_shape[i]
+    return out_shape
+
+
+@_reg.register_shape_func("gather", False)
+def gather_shape_func(attrs, inputs, _):
+    """
+    Shape func for gather operator.
+    """
+    return [_gather_shape(inputs[0], inputs[1], attrs.axis)]
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index 3781107..fa5b31a 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -3260,8 +3260,10 @@ bool GatherRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   oshape.reserve(ndim_data);
   for (size_t i = 0; i < ndim_data; ++i) {
     if (i == static_cast<size_t>(axis)) {
-      const int64_t* indice_shape_i = tir::as_const_int(indices->shape[i]);
-      ICHECK_GE(*indice_shape_i, 1);
+      if (indices->shape[i].as<IntImmNode>()) {
+        const int64_t* indice_shape_i = tir::as_const_int(indices->shape[i]);
+        ICHECK_GE(*indice_shape_i, 1);
+      }
     } else {
       ICHECK(reporter->AssertEQ(indices->shape[i], data->shape[i]));
     }
diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py
index decddc1..8788faf 100644
--- a/tests/python/relay/test_any.py
+++ b/tests/python/relay/test_any.py
@@ -2064,5 +2064,27 @@ def test_scatter_nd():
     verify_scatter_nd(data, indices, updates, out)
 
 
+@tvm.testing.uses_gpu
+def test_gather():
+    def verify_gather(data_shape, indices_shape, data_shape_np, indices_shape_np, axis):
+        x = relay.var("x", relay.TensorType(data_shape, "float32"))
+        y = relay.var("y", relay.TensorType(indices_shape, "int32"))
+        z = relay.gather(x, axis, y)
+
+        mod = tvm.IRModule()
+        mod["main"] = relay.Function([x, y], z)
+
+        data_np = np.random.uniform(size=data_shape_np).astype("float32")
+        indices_np = np.random.randint(low=0, high=2, size=indices_shape_np, dtype="int32")
+
+        ref_res = tvm.topi.testing.gather_python(data_np, axis, indices_np)
+        check_result([data_np, indices_np], mod, [ref_res])
+
+    verify_gather((relay.Any(),), (relay.Any(),), (10,), (10,), 0)
+    verify_gather((2, 2), (2, relay.Any()), (2, 2), (2, 3), 1)
+    verify_gather((relay.Any(), 2), (2, relay.Any()), (2, 2), (2, 3), 1)
+    verify_gather((relay.Any(), relay.Any()), (relay.Any(), relay.Any()), (2, 3), (1, 3), 0)
+
+
 if __name__ == "__main__":
     pytest.main([__file__])