This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 19e82ecb3a [Unity] MetaScheduleApplyDatabase using workload from records (#14702)
19e82ecb3a is described below

commit 19e82ecb3a539a42a4726394720e46a4cd3532e5
Author: Ruihang Lai <[email protected]>
AuthorDate: Sun Apr 23 00:43:55 2023 -0400

    [Unity] MetaScheduleApplyDatabase using workload from records (#14702)
    
    This PR fixes an issue in MetaScheduleApplyDatabase, which uses the
    IRModule being queried as the base module of the TIR schedule. This
    causes an inconsistency between the IRModule being scheduled and the
    trace of the record. For example, the trace of the record may contain
    `get_block` instructions that look up a block by name; that name exists
    in the IRModule of the record but not in the IRModule being queried
    (which is exactly the IRModule being scheduled).
    
    Therefore, this PR distinguishes two cases. When anchor-op equality is
    applied (the queried IRModule is not structurally equal to the record's
    workload), the queried IRModule is scheduled using the anchor trace.
    Otherwise, we create the TIR schedule from the IRModule of the record,
    which prevents the issue above.
    
    This PR also adds a unit test that demonstrates the issue.
---
 src/relax/transform/meta_schedule.cc               | 10 ++-
 .../test_transform_meta_schedule_apply_database.py | 84 ++++++++++++++++++++++
 2 files changed, 91 insertions(+), 3 deletions(-)

diff --git a/src/relax/transform/meta_schedule.cc b/src/relax/transform/meta_schedule.cc
index 03456d0ef8..d84eecc89d 100644
--- a/src/relax/transform/meta_schedule.cc
+++ b/src/relax/transform/meta_schedule.cc
@@ -126,15 +126,19 @@ Pass MetaScheduleApplyDatabase(Optional<String> work_dir, bool enable_warning =
         if (Optional<meta_schedule::TuningRecord> opt_record =
                 database->QueryTuningRecord(tir_mod, target, gv->name_hint)) {
           meta_schedule::TuningRecord record = opt_record.value();
-          tir::Schedule sch =
-              tir::Schedule::Traced(tir_mod, /*seed=*/-1, /*debug_mask=*/0,
-                                    /*error_render_level=*/tir::ScheduleErrorRenderLevel::kDetail);
+          tir::Schedule sch{nullptr};
           if (!mod_eq_structural->Equal(tir_mod, record->workload->mod)) {
             // When the database lookup succeeds while structural equality check fails,
             // it implies that the anchor block based equality has been used during tuning.
             // The trace in the record cannot directly be applied to this query module.
+            sch = tir::Schedule::Traced(
+                tir_mod, /*seed=*/-1, /*debug_mask=*/0,
+                /*error_render_level=*/tir::ScheduleErrorRenderLevel::kDetail);
             meta_schedule::ScheduleUsingAnchorTrace(sch, record->trace, target);
           } else {
+            sch = tir::Schedule::Traced(
+                record->workload->mod, /*seed=*/-1, /*debug_mask=*/0,
+                /*error_render_level=*/tir::ScheduleErrorRenderLevel::kDetail);
             record->trace->ApplyToSchedule(sch, /*remove_postproc=*/false);
           }
           IRModule new_mod = sch->mod();
diff --git a/tests/python/relax/test_transform_meta_schedule_apply_database.py b/tests/python/relax/test_transform_meta_schedule_apply_database.py
new file mode 100644
index 0000000000..d388ccab43
--- /dev/null
+++ b/tests/python/relax/test_transform_meta_schedule_apply_database.py
@@ -0,0 +1,84 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+import tvm.testing
+from tvm import tir
+from tvm import meta_schedule as ms
+from tvm import relax
+from tvm.script import ir as I, tir as T
+
+target = tvm.target.Target("llvm --num-cores=16")
+
+
+def test_apply_to_func_with_different_block_name():
+    @I.ir_module
+    class RecordModule:
+        @T.prim_func
+        def main(A: T.Buffer((2,), "float32"), B: T.Buffer((2,), "float32")):
+            T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
+            for i in T.serial(2):
+                with T.block("block"):
+                    vi = T.axis.spatial(2, i)
+                    B[vi] = A[vi]
+
+    @I.ir_module
+    class BlockRenamedModule:
+        @T.prim_func
+        def main(A: T.Buffer((2,), "float32"), B: T.Buffer((2,), "float32")):
+            T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
+            for i in T.serial(2):
+                with T.block("renamed_block"):
+                    vi = T.axis.spatial(2, i)
+                    B[vi] = A[vi]
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func
+        def main(A: T.Buffer((2,), "float32"), B: T.Buffer((2,), "float32")):
+            T.func_attr(
+                {
+                    "tir.is_scheduled": T.bool(True),
+                    "global_symbol": "main",
+                    "tir.noalias": T.bool(True),
+                }
+            )
+            for i in T.serial(2):
+                with T.block("renamed_block"):
+                    vi = T.axis.spatial(2, i)
+                    B[vi] = A[vi]
+
+    def create_trace(mod: tvm.IRModule):
+        sch = tir.Schedule(mod)
+        _ = sch.get_block("block")
+        return sch.trace
+
+    db = ms.database.create(kind="memory")
+    db.commit_workload(RecordModule)
+    db.commit_tuning_record(
+        ms.database.TuningRecord(
+            create_trace(RecordModule), ms.database.Workload(RecordModule), [0.0], target
+        )
+    )
+
+    with db, target:
+        mod = relax.transform.MetaScheduleApplyDatabase()(BlockRenamedModule)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()

Reply via email to