areusch commented on code in PR #12525:
URL: https://github.com/apache/tvm/pull/12525#discussion_r955439427


##########
python/tvm/driver/tvmc/autotuner.py:
##########
@@ -453,7 +518,37 @@ def tune_model(
             hardware_params=hardware_params,
             include_simple_tasks=include_simple_tasks,
         )
+    else:
+        tasks = autotvm_get_tuning_tasks(
+            mod=mod,
+            params=params,
+            target=target,
+            alter_layout=desired_layout,
+        )
+
+    # Filter extracted tasks by provided user expression
+    if tasks_filter:
+        tasks, do_list = filter_tasks(tasks, tasks_filter)
+        if do_list:
+            print("Available Tasks for tuning:")
+            print(
+                "\n".join(
+                    [
+                        "  {}. {}".format(
+                            i, task if len(str(task)) < 100 else str(task)[:97] + "..."

Review Comment:
   In this case, does truncating the string make the listed task useless?
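For reference, the listing expression above keeps `str(task)` whole when it is shorter than 100 characters. Otherwise it cuts the string to its first 97 characters plus `...`, so every listed entry is at most 100 characters and anything past character 97 is dropped. The sketch below mirrors that logic in a hypothetical stand-alone helper, `format_task_listing` (not part of the PR), so the effect on long task names can be checked directly:

```python
from typing import Iterable


def format_task_listing(tasks: Iterable[object], max_len: int = 100) -> str:
    """Hypothetical helper mirroring the listing logic in the diff.

    Each task is rendered as "  <index>. <text>", where text is str(task),
    cut to max_len - 3 characters plus "..." once it reaches max_len characters.
    """
    lines = []
    for i, task in enumerate(tasks):
        text = str(task)
        if len(text) >= max_len:
            # Only the leading characters survive; the tail of the name is lost.
            text = text[: max_len - 3] + "..."
        lines.append("  {}. {}".format(i, text))
    return "\n".join(lines)


print(format_task_listing(["short_task", "x" * 150]))
```

If the part of a task's string form that tells tasks apart sits past character 97, printing the full string (or truncating from the middle rather than the end) would keep entries distinguishable. Which trade-off is preferable is the author's call.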



##########
tests/python/driver/tvmc/test_autotuner.py:
##########
@@ -182,3 +183,23 @@ def test_tune_rpc_tracker_parsing(mock_load_model, mock_tune_model, mock_auto_sc
     assert "10.0.0.1" == kwargs["hostname"]
     assert "port" in kwargs
     assert 9999 == kwargs["port"]
+
+
+def test_filter_tasks_valid():
+    assert filter_tasks(list(range(10)), "list") == ([], True)
+    assert filter_tasks(list(range(10)), "help") == ([], True)
+    assert filter_tasks(list(range(10)), "all") == ([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], False)
+    assert filter_tasks(list(range(10)), "5") == ([5], False)
+    assert filter_tasks(list(range(10)), "1-5") == ([1, 2, 3, 4, 5], False)
+    assert filter_tasks(list(range(10)), "-5") == ([0, 1, 2, 3, 4, 5], False)
+    assert filter_tasks(list(range(10)), "6-") == ([6, 7, 8, 9], False)
+    assert filter_tasks(list(range(10)), "0,1-3,all") == ([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], False)
+    assert filter_tasks(list(range(10)), "0,4-5,9,list") == ([0, 4, 5, 9], True)
+
+
+@pytest.mark.xfail
+def test_filter_tasks_invalid():
+    filter_tasks(list(range(10)), "10")
+    filter_tasks(list(range(10)), "5,10")
+    filter_tasks(list(range(10)), "1-10")
+    filter_tasks(list(range(10)), "-10")

Review Comment:
   Do you mind adding the customary
   
   ```
   if __name__ == "__main__":
     tvm.testing.main()
   ```
   here?



##########
python/tvm/driver/tvmc/autotuner.py:
##########
@@ -293,9 +298,66 @@ def drive_tune(args):
         include_simple_tasks=args.include_simple_tasks,
         log_estimated_latency=args.log_estimated_latency,
         additional_target_options=reconstruct_target_args(args),
+        tasks_filter=args.tasks,
     )
 
 
+def filter_tasks(
+    tasks: Optional[Union[auto_scheduler.SearchTask, autotvm.task.Task]],
+    expr: str,
+):
+    assert isinstance(expr, str), "Expected filter expression of string type"
+    assert len(expr) > 0, "Got empty filter expression"
+
+    # groups of keywords are comma-separated
+    splitted = expr.split(",")
+
+    do_list = False
+    do_filter = False
+    selected = []
+    for item in splitted:
+        if item in ["list", "help"]:
+            do_list = True
+        elif item in ["all"]:
+            selected = list(range(len(tasks)))
+        else:
+            do_filter = True
+            if "-" in item:
+                lhs, rhs = item.split("-")[:2]
+                if len(lhs) == 0:
+                    assert len(rhs) > 0
+                    assert isinstance(rhs, str)

Review Comment:
   Not sure we need these asserts, since `item` (and therefore `rhs`) should already be a `str`, right?
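On the asserts: `str.split` and `str.partition` always return `str` pieces, so `isinstance(rhs, str)` can never fail at that point. A related detail is that `item.split("-")[:2]` silently drops anything after a second `-`. A minimal sketch, using a hypothetical `split_range` helper, of splitting with `str.partition` instead, so that a malformed item such as `1-2-3` keeps its extra text and fails at the later `int()` conversion rather than being accepted as `1-2`:

```python
from typing import Tuple


def split_range(item: str) -> Tuple[str, str]:
    """Hypothetical helper: split "A-B", "-B" or "A-" into its two bound strings.

    str.partition splits on the first "-" only and always returns str parts,
    so no isinstance checks are needed. Any further "-" stays inside rhs.
    """
    lhs, _, rhs = item.partition("-")
    return lhs, rhs


# split("-")[:2] silently discards the trailing "3" ...
print("1-2-3".split("-")[:2])
# ... whereas partition keeps it in rhs, so int(rhs) would raise ValueError.
print(split_range("1-2-3"))
```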



##########
python/tvm/driver/tvmc/autotuner.py:
##########
@@ -293,9 +298,66 @@ def drive_tune(args):
         include_simple_tasks=args.include_simple_tasks,
         log_estimated_latency=args.log_estimated_latency,
         additional_target_options=reconstruct_target_args(args),
+        tasks_filter=args.tasks,
     )
 
 
+def filter_tasks(
+    tasks: Optional[Union[auto_scheduler.SearchTask, autotvm.task.Task]],
+    expr: str,
+):
+    assert isinstance(expr, str), "Expected filter expression of string type"
+    assert len(expr) > 0, "Got empty filter expression"
+
+    # groups of keywords are comma-separated
+    splitted = expr.split(",")
+
+    do_list = False
+    do_filter = False
+    selected = []
+    for item in splitted:
+        if item in ["list", "help"]:
+            do_list = True
+        elif item in ["all"]:
+            selected = list(range(len(tasks)))
+        else:
+            do_filter = True
+            if "-" in item:
+                lhs, rhs = item.split("-")[:2]
+                if len(lhs) == 0:

Review Comment:
   I think you could make this a bit more concise:
   ```suggestion
                   lhs = int(lhs) if lhs else 0
                   rhs = int(rhs) if rhs else len(tasks) - 1
                   assert 0 <= lhs < len(tasks), "explanation"
                   assert 0 <= rhs < len(tasks), "explanation"
   ```
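Putting the suggestion together, a self-contained sketch of the whole filter grammar exercised by the tests in this PR (`list`/`help`, `all`, single indices, and `A-B`, `-B`, `A-` ranges) could look like the following. `filter_tasks_sketch` is a hypothetical name, not the PR's function, and the out-of-bounds messages stand in for the `"explanation"` placeholders:

```python
from typing import List, Sequence, Set, Tuple, TypeVar

T = TypeVar("T")


def filter_tasks_sketch(tasks: Sequence[T], expr: str) -> Tuple[List[T], bool]:
    """Hypothetical re-implementation of the filter grammar from the diff.

    expr is a comma-separated list of items:
      "list" / "help"    -> request a listing of the available tasks
      "all"              -> select every task
      "N"                -> select task N
      "A-B", "-B", "A-"  -> select an inclusive range; a missing bound
                            defaults to the first / last task
    Returns (selected_tasks, do_list).
    """
    assert expr, "Got empty filter expression"

    num_tasks = len(tasks)
    do_list = False
    selected: Set[int] = set()
    for item in expr.split(","):
        if item in ("list", "help"):
            do_list = True
        elif item == "all":
            selected.update(range(num_tasks))
        elif "-" in item:
            lhs, _, rhs = item.partition("-")
            # Concise form from the review: empty bounds fall back to 0 / last index.
            start = int(lhs) if lhs else 0
            end = int(rhs) if rhs else num_tasks - 1
            assert 0 <= start < num_tasks, f"Range start {start} is out of bounds"
            assert 0 <= end < num_tasks, f"Range end {end} is out of bounds"
            selected.update(range(start, end + 1))
        else:
            index = int(item)
            assert 0 <= index < num_tasks, f"Task index {index} is out of bounds"
            selected.add(index)

    # The set collapses overlaps such as "0,1-3,all"; sorting restores task order.
    return [tasks[i] for i in sorted(selected)], do_list


tasks = [f"task_{i}" for i in range(10)]
print(filter_tasks_sketch(tasks, "0,4-5,9,list"))
# (['task_0', 'task_4', 'task_5', 'task_9'], True)
```

Collecting indices in a set is a choice made for this sketch; the diff builds a list, and `all` simply overwrites it. Note also that `assert` statements are skipped under `python -O`, so raising `ValueError` would be the more robust option if these checks guard user input.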



-- 
This is an automated message from the Apache Git Service.