This is an automated email from the ASF dual-hosted git repository.

comaniac pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 5ad2f77  [Relay] Gather op dynamic input support (#9240)
5ad2f77 is described below

commit 5ad2f77403bed9a2bf356cc0d3d785ecc13e6c58
Author: masahi <masahi...@gmail.com>
AuthorDate: Tue Oct 12 01:22:10 2021 +0900

    [Relay] Gather op dynamic input support (#9240)
    
    * support gather op dynamic input
    
    * fix shape func and add test
    
    * remove constness check
    
    * fix shape func output rank
    
    * restore check
    
    Co-authored-by: masa <masa@pop-os.localdomain>
---
 include/tvm/topi/transform.h      |  6 ++++--
 python/tvm/relay/op/_transform.py | 20 ++++++++++++++++++++
 src/relay/op/tensor/transform.cc  |  6 ++++--
 tests/python/relay/test_any.py    | 22 ++++++++++++++++++++++
 4 files changed, 50 insertions(+), 4 deletions(-)

diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h
index 8d1a49a..3df9caf 100644
--- a/include/tvm/topi/transform.h
+++ b/include/tvm/topi/transform.h
@@ -1233,8 +1233,10 @@ inline Tensor gather(const Tensor& data, int axis, const Tensor& indices,
   }
   ICHECK_GE(axis, 0);
   ICHECK_LT(axis, ndim_d);
-  size_t indices_dim_i = static_cast<size_t>(GetConstInt(indices->shape[axis]));
-  ICHECK_GE(indices_dim_i, 1);
+  if (indices->shape[axis].as<IntImmNode>()) {
+    size_t indices_dim_i = static_cast<size_t>(GetConstInt(indices->shape[axis]));
+    ICHECK_GE(indices_dim_i, 1);
+  }
   ICHECK(indices->dtype.is_int());
 
   Array<PrimExpr> out_shape;
diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py
index 0284d24..76c8069 100644
--- a/python/tvm/relay/op/_transform.py
+++ b/python/tvm/relay/op/_transform.py
@@ -1174,3 +1174,23 @@ def gather_nd_shape_func(attrs, inputs, _):
     assert index_rank > 0, "index_rank needs to be specified for dynamic gather_nd"
 
     return [_gather_nd_shape(inputs[0], inputs[1], convert(batch_dims), convert(index_rank))]
+
+
+@script
+def _gather_shape(data_shape, indices_shape, axis):
+    out_shape = output_tensor((data_shape.shape[0],), "int64")
+    for i in range(data_shape.shape[0]):
+        if i != axis:
+            assert (
+                data_shape[i] == indices_shape[i]
+            ), "data and indices size at non-gather axes must be the same"
+        out_shape[i] = indices_shape[i]
+    return out_shape
+
+
+@_reg.register_shape_func("gather", False)
+def gather_shape_func(attrs, inputs, _):
+    """
+    Shape func for gather operator.
+    """
+    return [_gather_shape(inputs[0], inputs[1], attrs.axis)]
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index 3781107..fa5b31a 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -3260,8 +3260,10 @@ bool GatherRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   oshape.reserve(ndim_data);
   for (size_t i = 0; i < ndim_data; ++i) {
     if (i == static_cast<size_t>(axis)) {
-      const int64_t* indice_shape_i = tir::as_const_int(indices->shape[i]);
-      ICHECK_GE(*indice_shape_i, 1);
+      if (indices->shape[i].as<IntImmNode>()) {
+        const int64_t* indice_shape_i = tir::as_const_int(indices->shape[i]);
+        ICHECK_GE(*indice_shape_i, 1);
+      }
     } else {
       ICHECK(reporter->AssertEQ(indices->shape[i], data->shape[i]));
     }
diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py
index decddc1..8788faf 100644
--- a/tests/python/relay/test_any.py
+++ b/tests/python/relay/test_any.py
@@ -2064,5 +2064,27 @@ def test_scatter_nd():
     verify_scatter_nd(data, indices, updates, out)
 
 
+@tvm.testing.uses_gpu
+def test_gather():
+    def verify_gather(data_shape, indices_shape, data_shape_np, indices_shape_np, axis):
+        x = relay.var("x", relay.TensorType(data_shape, "float32"))
+        y = relay.var("y", relay.TensorType(indices_shape, "int32"))
+        z = relay.gather(x, axis, y)
+
+        mod = tvm.IRModule()
+        mod["main"] = relay.Function([x, y], z)
+
+        data_np = np.random.uniform(size=data_shape_np).astype("float32")
+        indices_np = np.random.randint(low=0, high=2, size=indices_shape_np, dtype="int32")
+
+        ref_res = tvm.topi.testing.gather_python(data_np, axis, indices_np)
+        check_result([data_np, indices_np], mod, [ref_res])
+
+    verify_gather((relay.Any(),), (relay.Any(),), (10,), (10,), 0)
+    verify_gather((2, 2), (2, relay.Any()), (2, 2), (2, 3), 1)
+    verify_gather((relay.Any(), 2), (2, relay.Any()), (2, 2), (2, 3), 1)
+    verify_gather((relay.Any(), relay.Any()), (relay.Any(), relay.Any()), (2, 3), (1, 3), 0)
+
+
 if __name__ == "__main__":
     pytest.main([__file__])

Reply via email to