This is an automated email from the ASF dual-hosted git repository. zhaowu pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push: new a867bcb [Auto Scheduler] Add target host to measure record (#7046) a867bcb is described below commit a867bcbf1ecf537cfb061a2ca4790b16a9cc9748 Author: Zhao Wu <zha...@apache.org> AuthorDate: Tue Dec 8 14:46:29 2020 +0800 [Auto Scheduler] Add target host to measure record (#7046) * [Auto Scheduler] Add target host to measure record * Fix PyLint * Fix lint * Solve the serialization logic when we don't have hardware params * update auto scheduler log --- src/auto_scheduler/measure_record.cc | 12 ++++++++-- .../python/unittest/test_auto_scheduler_measure.py | 26 ++++++++++++++++++++++ 2 files changed, 36 insertions(+), 2 deletions(-) diff --git a/src/auto_scheduler/measure_record.cc b/src/auto_scheduler/measure_record.cc index d57e2f2..aad0abe 100644 --- a/src/auto_scheduler/measure_record.cc +++ b/src/auto_scheduler/measure_record.cc @@ -163,6 +163,9 @@ struct Handler<::tvm::auto_scheduler::SearchTaskNode> { writer->WriteArrayItem(std::string(data.workload_key)); writer->WriteArrayItem(data.target->str()); writer->WriteArrayItem(*data.hardware_params.get()); + if (data.target_host.defined()) { + writer->WriteArrayItem(data.target_host->str()); + } writer->EndArray(); } inline static void Read(dmlc::JSONReader* reader, ::tvm::auto_scheduler::SearchTaskNode* data) { @@ -183,7 +186,12 @@ struct Handler<::tvm::auto_scheduler::SearchTaskNode> { reader->Read(hardware_params_node.get()); s = reader->NextArrayItem(); data->hardware_params = ::tvm::auto_scheduler::HardwareParams(hardware_params_node); - ICHECK(!s); + if (s) { + reader->Read(&str_value); + data->target_host = ::tvm::Target(str_value); + s = reader->NextArrayItem(); + ICHECK(!s); + } } } }; @@ -271,7 +279,7 @@ namespace auto_scheduler { TVM_REGISTER_OBJECT_TYPE(RecordToFileNode); TVM_REGISTER_OBJECT_TYPE(RecordReaderNode); -const std::string AUTO_SCHEDULER_LOG_VERSION = "v0.3"; // NOLINT(*) +const std::string AUTO_SCHEDULER_LOG_VERSION = "v0.4"; // NOLINT(*) RecordToFile::RecordToFile(String filename) { auto node = make_object<RecordToFileNode>(); diff --git a/tests/python/unittest/test_auto_scheduler_measure.py b/tests/python/unittest/test_auto_scheduler_measure.py index b214d9c..10bb0b4 100644 --- a/tests/python/unittest/test_auto_scheduler_measure.py +++ b/tests/python/unittest/test_auto_scheduler_measure.py @@ -250,6 +250,31 @@ def test_measure_local_builder_rpc_runner_spawn(): p.join() +@tvm.testing.requires_llvm +def test_measure_target_host(): + task = auto_scheduler.SearchTask( + func=matmul_auto_scheduler_test, + args=(512, 512, 512), + target="llvm", + target_host="llvm -mtriple=aarch64-linux-gnu", + ) + + inp = auto_scheduler.measure.MeasureInput(task, task.compute_dag.init_state) + res = auto_scheduler.measure.MeasureResult([0.1], 0, "", 0.2, 1) + + with tempfile.NamedTemporaryFile() as fp: + auto_scheduler.save_records(fp.name, [inp], [res]) + + log_reader = auto_scheduler.RecordReader(fp.name) + inputs, results = log_reader.read_lines() + assert len(inputs) == 1 + + raw_inp = inputs[0] + + recovered_inp = auto_scheduler.measure.recover_measure_input(raw_inp) + assert str(recovered_inp.task.target_host) == str(inp.task.target_host) + + if __name__ == "__main__": test_record_split_reorder_fuse_annotation() test_record_compute_at_root_inline_cache_read_write() @@ -258,3 +283,4 @@ if __name__ == "__main__": test_recover_measure_input() test_measure_local_builder_runner() test_measure_local_builder_rpc_runner() + test_measure_target_host()