masahi commented on a change in pull request #7346:
URL: https://github.com/apache/tvm/pull/7346#discussion_r565698955



##########
File path: python/tvm/relay/frontend/pytorch_utils.py
##########
@@ -169,10 +218,193 @@ def callback(self, pre, post, node_map):
         return self.convert_batched_nms(boxes, scores, idxs, iou_thres, 
num_boxes, indices)
 
 
+class PostNMSTopKRewrite(DFPatternCallback):
+    """A callback to rewrite nms to exploit max_out_size parameter."""
+
+    def __init__(self):
+        super().__init__()
+        self.cond = wildcard()
+        self.true_branch = wildcard()
+        self.data = wildcard()
+        self.valid_count = wildcard()
+        self.indices = wildcard()
+        self.iou_threshold = wildcard()
+
+        self.pattern = topk_after_batch_nms_pattern(
+            self.cond,
+            self.true_branch,
+            self.data,
+            self.valid_count,
+            self.indices,
+            self.iou_threshold,
+        )
+
+    def rewrite_batch_nms_with_max_out_size(
+        self, cond, true_branch, data, valid_count, indices, iou_threshold, 
post_nms_topk
+    ):
+        """Use the detected post NMS topk parameter in NMS op."""
+        nms_ret = op.vision.non_max_suppression(
+            data=data,
+            valid_count=valid_count,
+            indices=indices,
+            max_output_size=post_nms_topk,
+            iou_threshold=iou_threshold,
+            force_suppress=False,
+            top_k=-1,
+            coord_start=2,
+            score_index=1,
+            id_index=0,
+            return_indices=True,
+            invalid_to_bottom=False,
+        )
+
+        size = op.squeeze(nms_ret[1], axis=[1])
+        data_slice = op.squeeze(nms_ret[0], axis=[0])
+
+        ret = op.strided_slice(data_slice, begin=expr.const([0]), end=size, 
slice_mode="size")
+
+        nms_result = op.cast(ret, "int64")
+
+        return expr.If(cond, true_branch, nms_result)
+
+    def callback(self, pre, post, node_map):
+        post_nms_topk = post.attrs.end[0].value
+        return self.rewrite_batch_nms_with_max_out_size(
+            node_map[self.cond][0],
+            node_map[self.true_branch][0],
+            node_map[self.data][0],
+            node_map[self.valid_count][0],
+            node_map[self.indices][0],
+            node_map[self.iou_threshold][0],
+            post_nms_topk,
+        )
+
+
+def scatter_roi_align_result_pattern(levels, roi_align_results, num_scales):
+    """Detect the Relay subgraph corresponding to the following PyTorch code
+
+    first_result = roi_align_results[0]
+    dtype, device = first_result.dtype, first_result.device
+    res = torch.zeros((levels.size(0), first_result.size(1),
+                       first_result.size(2), first_result.size(3)),
+                      dtype=dtype, device=device)
+    for level in range(len(roi_align_results)):
+        index = torch.where(levels == level)[0].view(-1, 1, 1, 1)
+        index = index.expand(index.size(0),
+                             roi_align_results[level].size(1),
+                             roi_align_results[level].size(2),
+                             roi_align_results[level].size(3))
+        res = res.scatter(0, index, roi_align_results[level])
+    return res
+    """
+
+    def do_where(levels, _):
+        idx_in_level = is_op("argwhere")(is_op("equal")(levels, is_constant()))
+        idx_in_level = is_op("split")(idx_in_level)
+        idx_in_level = is_tuple_get_item(idx_in_level, 0)
+        idx_in_level = is_op("squeeze")(idx_in_level)
+        idx_in_level = is_tuple_get_item(is_tuple([idx_in_level]), 0)
+        return idx_in_level
+
+    scatter_res = wildcard()
+
+    for i in range(num_scales):
+        # index = torch.where(levels == level)[0].view(-1, 1, 1, 1)
+        scatter_indices = do_where(levels, i)
+        scatter_indices = is_op("reshape")(scatter_indices)
+
+        # index = index.expand(index.size(0),

Review comment:
       Actually, this is the equivalent PyTorch code included to explain what
the pattern does :) I intentionally kept it.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to