kevinthesun commented on a change in pull request #4312: [TOPI][Relay][OP] Dynamic NMS and strided_slice
URL: https://github.com/apache/incubator-tvm/pull/4312#discussion_r345407087
 
 

 ##########
 File path: python/tvm/relay/frontend/tensorflow.py
 ##########
 @@ -324,6 +325,53 @@ def _impl(inputs, attr, params):
         return out
     return _impl
 
+def _nms():
+    def _impl(inputs, attr, params):
+        # Get parameter values
+        max_output_size = int(np.atleast_1d(inputs[2].data.asnumpy().astype("int64"))[0])
+        iou_threshold = np.atleast_1d(inputs[3].data.asnumpy())[0]
+        # score_threshold was introduced from V3
+        score_threshold = np.atleast_1d(inputs[4].data.asnumpy())[0] if len(inputs) > 4 else None
+
+        scores = AttrCvt(op_name="expand_dims",
+                         extras={'axis': -1, 'num_newaxis': 1})([inputs[1]], attr)
+
+        data = get_relay_op('concatenate')([scores, inputs[0]], -1)
+        # expand to [class_id, prob, box]
+        # data = _get_relay_op('concatenate')([scores, data], -1)
+        # expand to [batch_size, num_anchors, 6] or [batch_size, num_anchors, 5]
+
+        data = get_relay_op('expand_dims')(data, 0, 1)
+
+        # Don't need to call get_valid_counts for TensorFlow and ONNX
+        # ct, data = _get_relay_op('get_valid_counts')(data, score_threshold=score_threshold,
+        #                                              id_index=-1, score_index=0)
+        # get the number of anchors
+        data_shape = attr['_input_shapes'][inputs[1]]
+        valid_cnt = _expr.const(data_shape)
+        # TensorFlow NMS doesn't have parameter top_k
+        top_k = -1
+        # score_index is 0 since TF doesn't have class id for nms input
+        score_index = 0
+        nms_ret = get_relay_op('non_max_suppression')(data=data,
+                                                      valid_count=valid_cnt,
+                                                      max_output_size=max_output_size,
+                                                      score_threshold=score_threshold,
 
 Review comment:
   Same here. We still need to call get_valid_counts before non_max_suppression.
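
   For context, in the diff above valid_cnt is derived from the static input shape rather than computed from the scores. The following is a minimal NumPy sketch of the kind of filtering a get_valid_counts-style step performs on the [batch_size, num_anchors, 5] layout built in the diff. The helper name get_valid_counts_ref, the strict ">" comparison, and the -1 fill value are assumptions made for illustration; this is not TVM's implementation and the sample boxes are made up.

```python
import numpy as np


def get_valid_counts_ref(data, score_threshold, score_index=0):
    """Hypothetical NumPy reference, not TVM's implementation.

    data: float array of shape [batch_size, num_anchors, elem_length].
    Returns (valid_count, out): rows whose score exceeds the threshold
    are moved to the front of each batch, the rest are filled with -1.
    """
    batch_size = data.shape[0]
    valid_count = np.zeros((batch_size,), dtype="int32")
    out = np.full_like(data, -1.0)
    for b in range(batch_size):
        # Boolean mask of anchors whose score passes the threshold.
        keep = data[b, :, score_index] > score_threshold
        n = int(keep.sum())
        valid_count[b] = n
        # Compact the surviving rows to the top, preserving their order.
        out[b, :n] = data[b, keep]
    return valid_count, out


# Build the [1, num_anchors, 5] layout used in the diff: [score, box].
boxes = np.array([[0, 0, 1, 1], [0, 0, 2, 2], [1, 1, 3, 3]], dtype="float32")
scores = np.array([0.9, 0.1, 0.6], dtype="float32")
data = np.concatenate([scores[:, None], boxes], axis=-1)[None, ...]

valid_count, filtered = get_valid_counts_ref(data, score_threshold=0.5)
```

   With a per-batch count computed this way, the value passed as valid_count reflects the anchors that actually pass the threshold instead of the full anchor dimension.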
