sunggg commented on code in PR #12895:
URL: https://github.com/apache/tvm/pull/12895#discussion_r985956777


##########
python/tvm/meta_schedule/database/database.py:
##########
@@ -361,6 +363,44 @@ def current() -> Optional["Database"]:
         """Get the current database under scope."""
         return _ffi_api.DatabaseCurrent()  # type: ignore # pylint: disable=no-member
 
+    @staticmethod
+    def create(  # pylint: disable=keyword-arg-before-vararg

Review Comment:
   Just out of curiosity: is the motivation for this API to remove user-side imports for the various database types? For example, with this, users wouldn't need to import every database class individually; they could just call this one API.
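   If so, a quick sketch of the usage I have in mind (the `"json"` kind string and the `work_dir` keyword are my assumptions about the intended API, not necessarily what this PR lands):
   
   ```python
   from tvm import meta_schedule as ms
   
   # One factory entry point instead of importing JSONDatabase,
   # MemoryDatabase, etc. individually; the kind string selects
   # the concrete database implementation.
   db = ms.database.Database.create("json", work_dir="./tuning_logs")
   ```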



##########
python/tvm/meta_schedule/relay_integration.py:
##########
@@ -69,47 +117,218 @@ def extract_task_from_relay(
     """
     # pylint: disable=import-outside-toplevel
     from tvm import autotvm
-    from tvm.relay import Function as RelayFunc
 
     # pylint: enable=import-outside-toplevel
+    mod, target, params = _normalize_params(mod, target, params)
+    if target.kind.name != "cuda" and isinstance(
+        autotvm.DispatchContext.current, autotvm.FallbackContext
+    ):
+        tophub_context = autotvm.tophub.context(target)
+    else:
+        tophub_context = autotvm.utils.EmptyContext()
+    with Profiler.timeit("TaskExtraction"):
+        with target, _autotvm_silencer(), tophub_context:
+            with transform.PassContext(
+                opt_level=3,
+                config={
+                    "relay.backend.use_meta_schedule": True,
+                    "relay.backend.tir_converter": tir_converter,
+                },
+            ):
+                return list(_extract_task(mod, target, params))
 
-    extract_task_func = get_global_func(
-        "relay.backend.MetaScheduleExtractTask",
-        allow_missing=False,
-    )
 
-    if isinstance(mod, RelayFunc):
-        mod = IRModule.from_expr(mod)
-    if not isinstance(target, Target):
-        target = Target(target)
-    if disabled_pass is None:
-        disabled_pass = []
-    if pass_config is None:
-        pass_config = {
-            "relay.backend.use_meta_schedule": True,
-            "relay.backend.tir_converter": tir_converter,
-        }
-    if params is None:
-        params = {}
-    relay_params = {}
-    for name, param in params.items():
-        if isinstance(param, np.ndarray):
-            param = nd.array(param)
-        relay_params[name] = param
+def extracted_tasks_to_tune_contexts(
+    extracted_tasks: List[ExtractedTask],
+    work_dir: str,
+    space: SpaceGenerator.SpaceGeneratorType = "post-order-apply",
+    strategy: SearchStrategy.SearchStrategyType = "evolutionary",
+    num_threads: Union[Literal["physical", "logical"], int] = "physical",
+    seed: Optional[int] = None,
+) -> Tuple[List[TuneContext], List[float]]:
+    """Convert ExtractedTask to TuneContext.
 
-    with target, autotvm_silencer(), transform.PassContext(
-        opt_level=opt_level,
-        config=pass_config,
-        disabled_pass=disabled_pass,
+    Parameters
+    ----------
+    tasks : List[ExtractedTask]
+        The tasks to be converted
+    work_dir : str
+        The working directory to store logs and databases
+    space : SpaceGenerator.SpaceGeneratorType
+        The space generator to use.
+    strategy : SearchStrategy.SearchStrategyType
+        The search strategy to use.
+    num_threads : Union[Literal["physical", "logical"], int]
+        The number of threads to use.
+    seed : Optional[int]
+        The random seed to use.
+
+    Returns
+    -------
+    tasks : List[TuneContext]
+        The converted tasks
+    task_weights : List[float]
+        The weights of the tasks
+    """
+    tasks: List[TuneContext] = []
+    task_weights: List[float] = []
+    for task, logger, rand_state in zip(
+        extracted_tasks,
+        get_loggers_from_work_dir(work_dir, [t.task_name for t in extracted_tasks]),
+        fork_seed(seed, n=len(extracted_tasks)),
     ):
-        if target.kind.name != "cuda" and isinstance(
-            autotvm.DispatchContext.current, autotvm.FallbackContext
-        ):
-            tophub_context = autotvm.tophub.context(target)
-        else:
-            tophub_context = autotvm.utils.EmptyContext()
-        with tophub_context:
-            return list(extract_task_func(mod, target, relay_params))
+        tasks.append(
+            TuneContext(
+                mod=task.dispatched[0],
+                target=task.target,
+                space_generator=space,
+                search_strategy=strategy,
+                task_name=task.task_name,
+                logger=logger,
+                rand_state=rand_state,
+                num_threads=num_threads,
+            ).clone()
+        )
+        task_weights.append(task.weight)
+    return tasks, task_weights
+
+
+def tune_relay(

Review Comment:
   Given that this function calls several functions internally, I think it would be great if it exposed all of their arguments, or at least the important ones. For example, the current function does not allow customizing `num_threads` for `extracted_tasks_to_tune_contexts`.
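   A minimal sketch of what I mean (the added `num_threads`/`seed` parameters on `tune_relay` are hypothetical here, not the PR's current signature):
   
   ```python
   def tune_relay(mod, params, target, work_dir, *,
                  num_threads="physical",  # hypothetical: forwarded below
                  seed=None):
       extracted_tasks = extract_task_from_relay(mod, target, params)
       tasks, task_weights = extracted_tasks_to_tune_contexts(
           extracted_tasks=extracted_tasks,
           work_dir=work_dir,
           num_threads=num_threads,  # now customizable by the caller
           seed=seed,
       )
       ...
   ```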



##########
python/tvm/meta_schedule/relay_integration.py:
##########
@@ -69,47 +117,218 @@ def extract_task_from_relay(
     """
     # pylint: disable=import-outside-toplevel
     from tvm import autotvm
-    from tvm.relay import Function as RelayFunc
 
     # pylint: enable=import-outside-toplevel
+    mod, target, params = _normalize_params(mod, target, params)
+    if target.kind.name != "cuda" and isinstance(
+        autotvm.DispatchContext.current, autotvm.FallbackContext
+    ):
+        tophub_context = autotvm.tophub.context(target)
+    else:
+        tophub_context = autotvm.utils.EmptyContext()
+    with Profiler.timeit("TaskExtraction"):
+        with target, _autotvm_silencer(), tophub_context:
+            with transform.PassContext(
+                opt_level=3,
+                config={
+                    "relay.backend.use_meta_schedule": True,
+                    "relay.backend.tir_converter": tir_converter,
+                },
+            ):
+                return list(_extract_task(mod, target, params))
 
-    extract_task_func = get_global_func(
-        "relay.backend.MetaScheduleExtractTask",
-        allow_missing=False,
-    )
 
-    if isinstance(mod, RelayFunc):
-        mod = IRModule.from_expr(mod)
-    if not isinstance(target, Target):
-        target = Target(target)
-    if disabled_pass is None:
-        disabled_pass = []
-    if pass_config is None:
-        pass_config = {
-            "relay.backend.use_meta_schedule": True,
-            "relay.backend.tir_converter": tir_converter,
-        }
-    if params is None:
-        params = {}
-    relay_params = {}
-    for name, param in params.items():
-        if isinstance(param, np.ndarray):
-            param = nd.array(param)
-        relay_params[name] = param
+def extracted_tasks_to_tune_contexts(
+    extracted_tasks: List[ExtractedTask],
+    work_dir: str,
+    space: SpaceGenerator.SpaceGeneratorType = "post-order-apply",
+    strategy: SearchStrategy.SearchStrategyType = "evolutionary",
+    num_threads: Union[Literal["physical", "logical"], int] = "physical",
+    seed: Optional[int] = None,
+) -> Tuple[List[TuneContext], List[float]]:
+    """Convert ExtractedTask to TuneContext.
 
-    with target, autotvm_silencer(), transform.PassContext(
-        opt_level=opt_level,
-        config=pass_config,
-        disabled_pass=disabled_pass,
+    Parameters
+    ----------
+    tasks : List[ExtractedTask]
+        The tasks to be converted
+    work_dir : str
+        The working directory to store logs and databases
+    space : SpaceGenerator.SpaceGeneratorType
+        The space generator to use.
+    strategy : SearchStrategy.SearchStrategyType
+        The search strategy to use.
+    num_threads : Union[Literal["physical", "logical"], int]

Review Comment:
   Is this `num_threads` used to speed up the compilation? Would you clarify this in the docstring?
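   Something along these lines, perhaps (the wording is only a suggestion, based on my reading of what the threads are used for):
   
   ```python
   num_threads : Union[Literal["physical", "logical"], int]
       The number of CPU threads each task may use, e.g. to parallelize
       work during tuning; "physical"/"logical" resolve to the host's
       physical or logical core count.
   ```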



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscr...@tvm.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org
