This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 648a29a53a [MetaSchedule] Introduce `ScheduleFnDatabase` (#12626)
648a29a53a is described below

commit 648a29a53a641f1e923220600dce9c9215104879
Author: Junru Shao <junrushao1...@gmail.com>
AuthorDate: Mon Aug 29 00:34:11 2022 -0700

    [MetaSchedule] Introduce `ScheduleFnDatabase` (#12626)
    
    Following #12520, this PR introduces `ScheduleFnDatabase`, a mocked
    database to allow injecting handcrafted schedules provided by a schedule
    function.
    
    The schedule function comes with the following signature:
    
    ```python
    def schedule_fn(
      sch: tir.Schedule,
    ) -> bool:
      task_name = sch.mod.attrs["task_name"]
      # ^^^ the optional name of the queried task
      ...  # apply handcrafted scheduling to `sch`
      return True  # return True to use this schedule, or False to decline
    ```
    
    This mocked database helps incorporate the existing testing utility
    `apply_fixed_schedules` more formally into the MetaSchedule-Relay build
    pipeline, and allows further extension to Relax with the same interface.
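
    For reference, a minimal usage sketch, mirroring the updated tests below;
    `relay_mod`, `params`, `target`, and the handwritten `schedule_dense` helper
    are assumptions, to be defined by the user:

    ```python
    import tvm
    from tvm import meta_schedule as ms
    from tvm import relay

    def schedule_fn(sch) -> bool:
        # Apply the handcrafted schedule only to the matching task.
        if "nn_dense" in sch.mod.attrs["task_name"]:
            schedule_dense(sch)  # handwritten TIR scheduling, defined elsewhere
            return True
        return False

    with ms.database.ScheduleFnDatabase(schedule_fn), tvm.transform.PassContext(
        opt_level=3,
        config={"relay.backend.use_meta_schedule": True},
    ):
        lib = relay.build(relay_mod, target=target, params=params)
    ```

    As in the updated tests, entering the database as a context manager makes it
    the database that the MetaSchedule-aware Relay backend queries during
    `relay.build`.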
    
    As another follow-up, we will introduce ConcatDatabase, which allows
    mixing multiple databases, including this mocked one and those backed by
    JSON files.
---
 include/tvm/meta_schedule/database.h               |  19 +++-
 python/tvm/meta_schedule/database/__init__.py      |   1 +
 python/tvm/meta_schedule/database/database.py      |  41 ++++++--
 .../{__init__.py => schedule_fn_database.py}       |  29 ++++--
 python/tvm/meta_schedule/testing/utils.py          |  83 -----------------
 src/meta_schedule/database/database.cc             |  13 ++-
 src/meta_schedule/database/memory_database.cc      |  10 +-
 src/meta_schedule/database/schedule_fn_database.cc | 103 +++++++++++++++++++++
 src/relay/backend/te_compiler_cache.cc             |   5 +-
 tests/python/unittest/test_link_params.py          |  15 ++-
 .../unittest/test_meta_schedule_multi_anchor.py    |   8 +-
 .../test_meta_schedule_relay_tir_compute.py        |  18 ++--
 .../unittest/test_meta_schedule_tune_relay.py      |   7 +-
 13 files changed, 210 insertions(+), 142 deletions(-)

diff --git a/include/tvm/meta_schedule/database.h b/include/tvm/meta_schedule/database.h
index 0e7f45d393..88db2e2277 100644
--- a/include/tvm/meta_schedule/database.h
+++ b/include/tvm/meta_schedule/database.h
@@ -207,23 +207,29 @@ class DatabaseNode : public runtime::Object {
    * \brief Query the best record of the given workload from the database.
    * \param mod The IRModule to be searched for.
    * \param target The target to be searched for.
+   * \param workload_name The name of the workload to be searched for.
    * \return The best record of the given workload; NullOpt if not found.
    */
-  virtual Optional<TuningRecord> QueryTuningRecord(IRModule mod, Target target);
+  virtual Optional<TuningRecord> QueryTuningRecord(const IRModule& mod, const Target& target,
+                                                   const String& workload_name);
   /*!
    * \brief Query the best schedule of the given workload from the database.
    * \param mod The IRModule to be searched for.
    * \param target The target to be searched for.
+   * \param workload_name The name of the workload to be searched for.
   * \return The schedule in the best schedule of the given workload; NullOpt if not found.
    */
-  virtual Optional<tir::Schedule> QuerySchedule(IRModule mod, Target target);
+  virtual Optional<tir::Schedule> QuerySchedule(const IRModule& mod, const Target& target,
+                                                const String& workload_name);
   /*!
    * \brief Query the best IRModule of the given workload from the database.
    * \param mod The IRModule to be searched for.
    * \param target The target to be searched for.
+   * \param workload_name The name of the workload to be searched for.
   * \return The IRModule in the best IRModule of the given workload; NullOpt if not found.
    */
-  virtual Optional<IRModule> QueryIRModule(IRModule mod, Target target);
+  virtual Optional<IRModule> QueryIRModule(const IRModule& mod, const Target& target,
+                                           const String& workload_name);
 
   static constexpr const char* _type_key = "meta_schedule.Database";
   TVM_DECLARE_BASE_OBJECT_INFO(DatabaseNode, runtime::Object);
@@ -336,6 +342,13 @@ class Database : public runtime::ObjectRef {
  public:
   /*! An in-memory database. */
   TVM_DLL static Database MemoryDatabase();
+  /*!
+   * \brief A database for injecting handcrafted schedule functions.
+   * \param schedule_fn The function to do scheduling, which takes a TIR schedule,
+   * and returns a boolean indicating if the schedule is successful.
+   */
+  TVM_DLL static Database ScheduleFnDatabase(
+      runtime::TypedPackedFunc<bool(tir::Schedule)> schedule_fn);
   /*!
    * \brief Create a default database that uses JSON file for tuning records.
    * \param path_workload The path to the workload table.
diff --git a/python/tvm/meta_schedule/database/__init__.py b/python/tvm/meta_schedule/database/__init__.py
index 2a87eea147..7726daf6eb 100644
--- a/python/tvm/meta_schedule/database/__init__.py
+++ b/python/tvm/meta_schedule/database/__init__.py
@@ -21,3 +21,4 @@ The database that stores serialized tuning records and workloads
 from .database import Database, PyDatabase, TuningRecord, Workload
 from .json_database import JSONDatabase
 from .memory_database import MemoryDatabase
+from .schedule_fn_database import ScheduleFnDatabase
diff --git a/python/tvm/meta_schedule/database/database.py b/python/tvm/meta_schedule/database/database.py
index 68283b4554..aa509b7151 100644
--- a/python/tvm/meta_schedule/database/database.py
+++ b/python/tvm/meta_schedule/database/database.py
@@ -235,7 +235,12 @@ class Database(Object):
         """
         return _ffi_api.DatabaseSize(self)  # type: ignore # pylint: disable=no-member
 
-    def query_tuning_record(self, mod: IRModule, target: Target) -> Optional[TuningRecord]:
+    def query_tuning_record(
+        self,
+        mod: IRModule,
+        target: Target,
+        workload_name: str,
+    ) -> Optional[TuningRecord]:
         """Query the best record of the given workload from the database.
 
         Parameters
@@ -244,15 +249,22 @@ class Database(Object):
             The IRModule to be searched for.
         target : Target
             The target to be searched for.
+        workload_name : str
+            The name of the workload to be searched for.
 
         Returns
         -------
         tuning_record : Optional[TuningRecord]
             The best record of the given workload; None if not found.
         """
-        return _ffi_api.DatabaseQueryTuningRecord(self, mod, target)  # type: ignore # pylint: disable=no-member
+        return _ffi_api.DatabaseQueryTuningRecord(self, mod, target, workload_name)  # type: ignore # pylint: disable=no-member
 
-    def query_schedule(self, mod: IRModule, target: Target) -> Optional[Schedule]:
+    def query_schedule(
+        self,
+        mod: IRModule,
+        target: Target,
+        workload_name: str,
+    ) -> Optional[Schedule]:
         """Query the best schedule of the given workload from the database.
 
         Parameters
@@ -261,15 +273,22 @@ class Database(Object):
             The IRModule to be searched for.
         target : Target
             The target to be searched for.
+        workload_name : str
+            The name of the workload to be searched for.
 
         Returns
         -------
         schedule : Optional[Schedule]
             The best schedule of the given workload; None if not found.
         """
-        return _ffi_api.DatabaseQuerySchedule(self, mod, target)  # type: ignore # pylint: disable=no-member
+        return _ffi_api.DatabaseQuerySchedule(self, mod, target, workload_name)  # type: ignore # pylint: disable=no-member
 
-    def query_ir_module(self, mod: IRModule, target: Target) -> Optional[IRModule]:
+    def query_ir_module(
+        self,
+        mod: IRModule,
+        target: Target,
+        workload_name: str,
+    ) -> Optional[IRModule]:
         """Query the best IRModule of the given workload from the database.
 
         Parameters
@@ -278,18 +297,22 @@ class Database(Object):
             The IRModule to be searched for.
         target : Target
             The target to be searched for.
+        workload_name : str
+            The name of the workload to be searched for.
 
         Returns
         -------
         ir_module : Optional[IRModule]
             The best IRModule of the given workload; None if not found.
         """
-        return _ffi_api.DatabaseQueryIRModule(self, mod, target)  # type: ignore # pylint: disable=no-member
+        return _ffi_api.DatabaseQueryIRModule(self, mod, target, workload_name)  # type: ignore # pylint: disable=no-member
 
     def query(
         self,
         mod: IRModule,
         target: Target,
+        *,
+        workload_name: str = "main",
         kind: Union[
             Literal["schedule"],
             Literal["record"],
@@ -313,11 +336,11 @@ class Database(Object):
             The best optimization outcome of the given workload.
         """
         if kind == "schedule":
-            return self.query_schedule(mod, target)
+            return self.query_schedule(mod, target, workload_name)
         if kind == "record":
-            return self.query_tuning_record(mod, target)
+            return self.query_tuning_record(mod, target, workload_name)
         if kind == "ir_module":
-            return self.query_ir_module(mod, target)
+            return self.query_ir_module(mod, target, workload_name)
         raise ValueError(f'Unknown kind: {kind}. Candidates are: "schedule", "record", "ir_module"')
 
     def __enter__(self) -> "Database":
diff --git a/python/tvm/meta_schedule/database/__init__.py b/python/tvm/meta_schedule/database/schedule_fn_database.py
similarity index 55%
copy from python/tvm/meta_schedule/database/__init__.py
copy to python/tvm/meta_schedule/database/schedule_fn_database.py
index 2a87eea147..2918f05799 100644
--- a/python/tvm/meta_schedule/database/__init__.py
+++ b/python/tvm/meta_schedule/database/schedule_fn_database.py
@@ -14,10 +14,25 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""
-The tvm.meta_schedule.database package.
-The database that stores serialized tuning records and workloads
-"""
-from .database import Database, PyDatabase, TuningRecord, Workload
-from .json_database import JSONDatabase
-from .memory_database import MemoryDatabase
+"""A database for injecting handcrafted schedule functions."""
+from typing import Callable
+
+from tvm._ffi import register_object
+from tvm.tir import Schedule
+
+from .. import _ffi_api
+from .database import Database
+
+
+@register_object("meta_schedule.ScheduleFnDatabase")
+class ScheduleFnDatabase(Database):
+    """A database for injecting handcrafted schedule functions."""
+
+    def __init__(
+        self,
+        schedule_fn: Callable[[Schedule], bool],
+    ) -> None:
+        self.__init_handle_by_constructor__(
+            _ffi_api.DatabaseScheduleFnDatabase,  # type: ignore # pylint: disable=no-member
+            schedule_fn,
+        )
diff --git a/python/tvm/meta_schedule/testing/utils.py b/python/tvm/meta_schedule/testing/utils.py
deleted file mode 100644
index 5919fb47c8..0000000000
--- a/python/tvm/meta_schedule/testing/utils.py
+++ /dev/null
@@ -1,83 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""Testing utility functions in meta schedule"""
-from typing import Callable, Dict, Optional, Union
-
-from tvm import meta_schedule as ms
-from tvm.ir import IRModule, transform
-from tvm.relay import Function as RelayFunc
-from tvm.runtime import NDArray
-from tvm.target import Target
-from tvm.tir import Schedule
-
-
-def apply_fixed_schedules(
-    relay_mod: Union[RelayFunc, IRModule],
-    target: Union[str, Target],
-    params: Optional[Dict[str, NDArray]],
-    schedule_fn: Callable[[ms.ExtractedTask, Schedule], bool],
-    tir_converter: str = "default",
-):
-    """Apply fixed schedules (manually written, without any tunable knobs) as 
specified by
-    schedule_fn to extracted tasks, and return a database that can be passed 
to compilation.
-
-    Parameters
-    ----------
-    mod : Union[RelayFunc, IRModule]
-        The Relay module to apply fixed schedules.
-    target : Union[str, Target]
-        The target used to extract tasks.
-    params : Optional[Dict[str, tvm.runtime.NDArray]]
-        The associated parameters of the module.
-    schedule_fn : Callable[[ExtractedTask, Schedule], bool]
-        A callable that is applied for each extracted task and the corresponding default schedule.
-        Returns True if the given schedule should be committed to the database, False otherwise.
-    tir_converter : str
-        The filter function to filter out the extracted tasks. Builtin filters:
-          - "default"
-          - "allow_extern"
-        The converter is a PackedFunc registered as f"relay.backend.tir_converter.{tir_converter}",
-        with the signature below:
-            (args: List[te.Tensor], constants: List[NDArray]) -> Optional[tir.PrimFunc]
-
-    Returns
-    -------
-    database : Database
-        The database containing dummy tuning records for manually scheduled traces.
-    """
-    target = Target(target) if isinstance(target, str) else target
-    config = {"relay.backend.use_meta_schedule": True}
-    for k, v in transform.PassContext.current().config.items():
-        config[k] = v
-
-    extracted_tasks = ms.extract_task_from_relay(
-        relay_mod,
-        target,
-        params,
-        tir_converter=tir_converter,
-    )
-    database = ms.database.MemoryDatabase()
-    for task in extracted_tasks:
-        mod = ms.default_config.mod(task.dispatched[0])
-        sch = Schedule(mod)
-
-        if schedule_fn(task, sch):
-            workload = database.commit_workload(mod)
-            tune_rec = ms.database.TuningRecord(sch.trace, workload, [0.0], target, [])
-            database.commit_tuning_record(tune_rec)
-
-    return database
diff --git a/src/meta_schedule/database/database.cc b/src/meta_schedule/database/database.cc
index fedd2aa352..d082ff7a39 100644
--- a/src/meta_schedule/database/database.cc
+++ b/src/meta_schedule/database/database.cc
@@ -156,7 +156,8 @@ TuningRecord TuningRecord::FromJSON(const ObjectRef& json_obj, const Workload& w
 
 /******** Database ********/
 
-Optional<TuningRecord> DatabaseNode::QueryTuningRecord(IRModule mod, Target target) {
+Optional<TuningRecord> DatabaseNode::QueryTuningRecord(const IRModule& mod, const Target& target,
+                                                       const String& workload_name) {
   if (!this->HasWorkload(mod)) {
     return NullOpt;
   }
@@ -168,8 +169,9 @@ Optional<TuningRecord> DatabaseNode::QueryTuningRecord(IRModule mod, Target targ
   return records[0];
 }
 
-Optional<tir::Schedule> DatabaseNode::QuerySchedule(IRModule mod, Target target) {
-  if (Optional<TuningRecord> opt_record = this->QueryTuningRecord(mod, target)) {
+Optional<tir::Schedule> DatabaseNode::QuerySchedule(const IRModule& mod, const Target& target,
+                                                    const String& workload_name) {
+  if (Optional<TuningRecord> opt_record = this->QueryTuningRecord(mod, target, workload_name)) {
     TuningRecord record = opt_record.value();
     tir::Schedule sch =
         tir::Schedule::Traced(record->workload->mod, /*seed=*/-1, /*debug_mask=*/0,
@@ -181,8 +183,9 @@ Optional<tir::Schedule> DatabaseNode::QuerySchedule(IRModule mod, Target target)
   }
 }
 
-Optional<IRModule> DatabaseNode::QueryIRModule(IRModule mod, Target target) {
-  if (Optional<tir::Schedule> opt_sch = this->QuerySchedule(mod, target)) {
+Optional<IRModule> DatabaseNode::QueryIRModule(const IRModule& mod, const Target& target,
+                                               const String& workload_name) {
+  if (Optional<tir::Schedule> opt_sch = this->QuerySchedule(mod, target, workload_name)) {
     return opt_sch.value()->mod();
   } else {
     return NullOpt;
diff --git a/src/meta_schedule/database/memory_database.cc b/src/meta_schedule/database/memory_database.cc
index a00d5501ad..b6c6355551 100644
--- a/src/meta_schedule/database/memory_database.cc
+++ b/src/meta_schedule/database/memory_database.cc
@@ -44,7 +44,7 @@ class MemoryDatabaseNode : public DatabaseNode {
     return false;
   }
 
-  Workload CommitWorkload(const IRModule& mod) {
+  Workload CommitWorkload(const IRModule& mod) final {
     for (const auto& workload : workloads) {
       if (StructuralEqual()(workload->mod, mod)) {
         return workload;
@@ -55,9 +55,9 @@ class MemoryDatabaseNode : public DatabaseNode {
     return workload;
   }
 
-  void CommitTuningRecord(const TuningRecord& record) { records.push_back(record); }
+  void CommitTuningRecord(const TuningRecord& record) final { records.push_back(record); }
 
-  Array<TuningRecord> GetTopK(const Workload& workload, int top_k) {
+  Array<TuningRecord> GetTopK(const Workload& workload, int top_k) final {
     std::vector<std::pair<double, TuningRecord>> results;
     results.reserve(this->records.size());
     for (const TuningRecord& record : records) {
@@ -91,9 +91,9 @@ class MemoryDatabaseNode : public DatabaseNode {
     return ret;
   }
 
-  Array<TuningRecord> GetAllTuningRecords() { return records; }
+  Array<TuningRecord> GetAllTuningRecords() final { return records; }
 
-  int64_t Size() { return records.size(); }
+  int64_t Size() final { return records.size(); }
 };
 
 Database Database::MemoryDatabase() {
diff --git a/src/meta_schedule/database/schedule_fn_database.cc b/src/meta_schedule/database/schedule_fn_database.cc
new file mode 100644
index 0000000000..751721fe52
--- /dev/null
+++ b/src/meta_schedule/database/schedule_fn_database.cc
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "../utils.h"
+
+namespace tvm {
+namespace meta_schedule {
+
+class ScheduleFnDatabaseNode : public DatabaseNode {
+ public:
+  runtime::TypedPackedFunc<bool(tir::Schedule)> schedule_fn;
+
+  void VisitAttrs(AttrVisitor* v) {
+    // `schedule_fn` is not visited.
+  }
+
+  static constexpr const char* _type_key = "meta_schedule.ScheduleFnDatabase";
+  TVM_DECLARE_FINAL_OBJECT_INFO(ScheduleFnDatabaseNode, DatabaseNode);
+
+ public:
+  Optional<TuningRecord> QueryTuningRecord(const IRModule& mod, const Target& target,
+                                           const String& workload_name) final {
+    if (Optional<tir::Schedule> sch = this->QuerySchedule(mod, target, workload_name)) {
+      return TuningRecord(sch.value()->trace().value(),
+                          /*workload=*/Workload(mod, 0),  //
+                          /*run_secs=*/NullOpt,           //
+                          /*target=*/target,              //
+                          /*arg_info=*/NullOpt);
+    }
+    return NullOpt;
+  }
+
+  Optional<tir::Schedule> QuerySchedule(const IRModule& mod, const Target& target,
+                                        const String& workload_name) final {
+    tir::Schedule sch =
+        tir::Schedule::Traced(WithAttr<IRModule>(mod, "task_name", workload_name),
+                              /*rand_state=*/-1,
+                              /*debug_mode=*/0,
+                              /*error_render_level=*/tir::ScheduleErrorRenderLevel::kDetail);
+    if (!schedule_fn(sch)) {
+      return NullOpt;
+    }
+    return sch;
+  }
+
+  bool HasWorkload(const IRModule& mod) final {
+    LOG(FATAL) << "NotImplementedError: ScheduleFnDatabase.HasWorkload";
+    throw;
+  }
+
+  Workload CommitWorkload(const IRModule& mod) final {
+    LOG(FATAL) << "NotImplementedError: ScheduleFnDatabase.CommitWorkload";
+    throw;
+  }
+
+  void CommitTuningRecord(const TuningRecord& record) final {
+    LOG(FATAL) << "NotImplementedError: ScheduleFnDatabase.CommitTuningRecord";
+    throw;
+  }
+
+  Array<TuningRecord> GetTopK(const Workload& workload, int top_k) final {
+    LOG(FATAL) << "NotImplementedError: ScheduleFnDatabase.GetTopK";
+    throw;
+  }
+
+  Array<TuningRecord> GetAllTuningRecords() final {
+    LOG(FATAL) << "NotImplementedError: 
ScheduleFnDatabase.GetAllTuningRecords";
+    throw;
+  }
+
+  int64_t Size() final {
+    LOG(FATAL) << "NotImplementedError: ScheduleFnDatabase.size";
+    throw;
+  }
+};
+
+Database Database::ScheduleFnDatabase(runtime::TypedPackedFunc<bool(tir::Schedule)> schedule_fn) {
+  ObjectPtr<ScheduleFnDatabaseNode> n = make_object<ScheduleFnDatabaseNode>();
+  n->schedule_fn = std::move(schedule_fn);
+  return Database(n);
+}
+
+TVM_REGISTER_NODE_TYPE(ScheduleFnDatabaseNode);
+TVM_REGISTER_GLOBAL("meta_schedule.DatabaseScheduleFnDatabase")
+    .set_body_typed(Database::ScheduleFnDatabase);
+
+}  // namespace meta_schedule
+}  // namespace tvm
diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc
index 0e2a3e2702..1d7566ebe2 100644
--- a/src/relay/backend/te_compiler_cache.cc
+++ b/src/relay/backend/te_compiler_cache.cc
@@ -367,7 +367,8 @@ class ScheduleBuilder : public ExprVisitor {
         if (Optional<PrimFunc> f = tir_converter(te_args, constants)) {
          if (Optional<TuningRecord> opt_record = database_.value()->QueryTuningRecord(
                   /*mod=*/backend::PrimFuncToIRModule(f.value()),
-                  /*target=*/target_)) {
+                  /*target=*/target_,
+                  /*workload_name=*/prim_fn_var->name_hint)) {
             static InstructionKind kind_transform_layout = InstructionKind::Get("TransformLayout");
             TuningRecord record = opt_record.value();
             for (const Instruction& inst : record->trace->insts) {
@@ -383,6 +384,8 @@ class ScheduleBuilder : public ExprVisitor {
             ICHECK_EQ(mod->functions.size(), 1);
             mod = tir::transform::RemoveWeightLayoutRewriteBlock()(std::move(mod));
             prim_func = Downcast<PrimFunc>(mod->Lookup("main"));
+          } else {
+            LOG(WARNING) << "Cannot find workload: " << prim_fn_var->name_hint;
           }
         }
       }
diff --git a/tests/python/unittest/test_link_params.py b/tests/python/unittest/test_link_params.py
index c741ecb59a..b14c18e55f 100644
--- a/tests/python/unittest/test_link_params.py
+++ b/tests/python/unittest/test_link_params.py
@@ -29,7 +29,6 @@ import tvm.testing
 from tvm import meta_schedule as ms
 from tvm import relay
 from tvm.contrib import utils
-from tvm.meta_schedule.testing.utils import apply_fixed_schedules
 from tvm.relay.backend import Executor, Runtime
 
 INPUT_SHAPE = (1, 3, 16, 16)
@@ -407,21 +406,21 @@ def test_tir_link_params():
     target = "llvm"
     params = {"weight": weight_np}
 
-    def schedule_fn(task, sch):
-        if "nn_dense" in task.task_name:
+    def schedule_fn(sch):
+        if "nn_dense" in sch.mod.attrs["task_name"]:
             schedule_dense(sch)
             return True
         return False
 
     link_params = True
 
-    with tvm.transform.PassContext(config={"relay.FuseOps.link_params": link_params}):
-        database = apply_fixed_schedules(relay_mod, target, params, schedule_fn)
-
     with StringIO() as stderr_buf, redirect_stderr(stderr_buf):
-        with database, tvm.transform.PassContext(
+        with ms.database.ScheduleFnDatabase(schedule_fn), tvm.transform.PassContext(
             opt_level=3,
-            config={"relay.backend.use_meta_schedule": True},
+            config={
+                "relay.backend.use_meta_schedule": True,
+                "relay.FuseOps.link_params": link_params,
+            },
         ):
             executor = Executor("graph", {"link-params": link_params})
             lib = relay.build(relay_mod, target=target, executor=executor)
diff --git a/tests/python/unittest/test_meta_schedule_multi_anchor.py b/tests/python/unittest/test_meta_schedule_multi_anchor.py
index 1770017811..cb6f59c6e5 100644
--- a/tests/python/unittest/test_meta_schedule_multi_anchor.py
+++ b/tests/python/unittest/test_meta_schedule_multi_anchor.py
@@ -19,7 +19,6 @@ import tvm
 import tvm.testing
 from tvm import meta_schedule as ms
 from tvm import relay
-from tvm.meta_schedule.testing.utils import apply_fixed_schedules
 
 
 def get_dense_dense(data_shape, weight_shape):
@@ -63,14 +62,13 @@ def test_dense_dense():
     target = "llvm"
     params = {"weight1": weight1_np, "weight2": weight2_np}
 
-    def schedule_fn(task, sch):
-        if "nn_dense_nn_dense" in task.task_name:
+    def schedule_fn(sch):
+        if "nn_dense_nn_dense" in sch.mod.attrs["task_name"]:
             schedule_dense_dense(sch)
             return True
         return False
 
-    database = apply_fixed_schedules(relay_mod, target, params, schedule_fn)
-    with database:
+    with ms.database.ScheduleFnDatabase(schedule_fn):
         with tvm.transform.PassContext(
             opt_level=3,
             config={"relay.backend.use_meta_schedule": True},
diff --git a/tests/python/unittest/test_meta_schedule_relay_tir_compute.py b/tests/python/unittest/test_meta_schedule_relay_tir_compute.py
index 939851a657..b373338036 100644
--- a/tests/python/unittest/test_meta_schedule_relay_tir_compute.py
+++ b/tests/python/unittest/test_meta_schedule_relay_tir_compute.py
@@ -18,8 +18,9 @@ import numpy as np
 import tvm
 import tvm.testing
 import tvm.topi.testing
-from tvm import autotvm, relay, te
-from tvm.meta_schedule.testing.utils import apply_fixed_schedules
+from tvm import autotvm
+from tvm import meta_schedule as ms
+from tvm import relay, te
 from tvm.relay.testing.temp_op_attr import TempOpAttr
 from tvm.script import tir as T
 
@@ -139,21 +140,14 @@ def test_conv2d():
     target = "llvm"
     params = {"weight": weight_np}
 
-    def schedule_fn(task, sch):
-        if "nn_conv2d" in task.task_name:
+    def schedule_fn(sch):
+        if "nn_conv2d" in sch.mod.attrs["task_name"]:
             schedule_tir_conv2d_nchw_oihw(sch)
             return True
         return False
 
     with TempOpAttr("nn.conv2d", "FTVMStrategy", _tmp_strategy):
-        database = apply_fixed_schedules(
-            relay_mod,
-            target,
-            params,
-            schedule_fn,
-            tir_converter="allow_extern",
-        )
-        with database, tvm.transform.PassContext(
+        with ms.database.ScheduleFnDatabase(schedule_fn), tvm.transform.PassContext(
             opt_level=3,
             config={
                 "relay.backend.use_meta_schedule": True,
diff --git a/tests/python/unittest/test_meta_schedule_tune_relay.py b/tests/python/unittest/test_meta_schedule_tune_relay.py
index bc37fed7d6..b05b57feaf 100644
--- a/tests/python/unittest/test_meta_schedule_tune_relay.py
+++ b/tests/python/unittest/test_meta_schedule_tune_relay.py
@@ -29,7 +29,6 @@ from tvm._ffi import register_func
 from tvm.contrib import graph_executor
 from tvm.ir import IRModule
 from tvm.meta_schedule.testing.relay_workload import get_network
-from tvm.meta_schedule.testing.utils import apply_fixed_schedules
 from tvm.script import tir as T
 from tvm.target.target import Target
 from tvm.tir.schedule import BlockRV, Schedule
@@ -452,8 +451,8 @@ def manual_tir_common(do_tune=False):
             )
     else:
 
-        def schedule_fn(task, sch):
-            if "dense" not in task.task_name:
+        def schedule_fn(sch) -> bool:
+            if "dense" not in sch.mod.attrs["task_name"]:
                 return False
 
             block = sch.get_block("compute")
@@ -468,7 +467,7 @@ def manual_tir_common(do_tune=False):
 
             return True
 
-        database = apply_fixed_schedules(relay_mod, target, params, schedule_fn)
+        database = ms.database.ScheduleFnDatabase(schedule_fn)
 
     with database, tvm.transform.PassContext(
         opt_level=3,
