mbrookhart commented on a change in pull request #6388:
URL: https://github.com/apache/incubator-tvm/pull/6388#discussion_r483284293



##########
File path: src/relay/op/tensor/transform.cc
##########
@@ -3146,5 +3146,84 @@ RELAY_REGISTER_OP("matrix_set_diag")
     .set_attr<FTVMCompute>("FTVMCompute", MatrixSetDiagCompute)
     .set_attr<TOpPattern>("TOpPattern", kInjective);
 
+// adv_index
+bool AdvIndexRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                 const TypeReporter& reporter) {
+  CHECK_EQ(num_inputs, 1);
+  auto inputs = types[0].as<TupleTypeNode>();
+  CHECK(inputs != nullptr);

Review comment:
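       Don't CHECK here: it will fail during partial type inference, when the
input types may not be resolved yet (I ran into a lot of issues like this in the dynamic ONNX PR). Returning false instead tells the solver to retry once more type information is available. Instead, do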
   
   ```
   if (inputs == nullptr) {
       return false;
   }
   ```

##########
File path: src/relay/op/tensor/transform.cc
##########
@@ -3146,5 +3146,84 @@ RELAY_REGISTER_OP("matrix_set_diag")
     .set_attr<FTVMCompute>("FTVMCompute", MatrixSetDiagCompute)
     .set_attr<TOpPattern>("TOpPattern", kInjective);
 
+// adv_index
+bool AdvIndexRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                 const TypeReporter& reporter) {
+  CHECK_EQ(num_inputs, 1);
+  auto inputs = types[0].as<TupleTypeNode>();
+  CHECK(inputs != nullptr);
+  auto data = inputs->fields[0].as<TensorTypeNode>();
+  CHECK(data != nullptr);

Review comment:
       Same as above: return false instead of CHECK, since `data` may not be a resolved TensorType yet during partial type inference.

##########
File path: src/relay/op/tensor/transform.cc
##########
@@ -3146,5 +3146,84 @@ RELAY_REGISTER_OP("matrix_set_diag")
     .set_attr<FTVMCompute>("FTVMCompute", MatrixSetDiagCompute)
     .set_attr<TOpPattern>("TOpPattern", kInjective);
 
+// adv_index
+bool AdvIndexRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                 const TypeReporter& reporter) {
+  CHECK_EQ(num_inputs, 1);
+  auto inputs = types[0].as<TupleTypeNode>();
+  CHECK(inputs != nullptr);
+  auto data = inputs->fields[0].as<TensorTypeNode>();
+  CHECK(data != nullptr);
+
+  Array<IndexExpr> oshape;
+  Array<IndexExpr> broadcast_shape;
+  int64_t num_picked_elems = 1;
+
+  if (inputs->fields.size() == 2) {
+    broadcast_shape = inputs->fields[1].as<TensorTypeNode>()->shape;
+  } else {
+    for (size_t i = 1; i < inputs->fields.size(); ++i) {
+      auto index_type = inputs->fields[i].as<TensorTypeNode>();
+      CHECK(index_type != nullptr);

Review comment:
       Same here: return false instead of CHECK, since an index type may not be resolved yet during partial type inference.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to