This is an automated email from the ASF dual-hosted git repository. jwfromm pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push: new 75c31cae75 [Relay] Bug fix when applying history using an iterator or records. (#11306) 75c31cae75 is described below commit 75c31cae75fe31af9e0901210ba7fa597e6f153a Author: Josh Fromm <jwfr...@octoml.ai> AuthorDate: Tue May 17 16:17:48 2022 -0700 [Relay] Bug fix when applying history using an iterator or records. (#11306) * Bug fix when applying history using an iterator or records. * I forgot strings are iterables. --- python/tvm/auto_scheduler/dispatcher.py | 3 ++- python/tvm/autotvm/task/dispatcher.py | 5 +++-- tests/python/relay/test_auto_scheduler_tuning.py | 7 +++++++ tests/python/unittest/test_autotvm_record.py | 5 +++++ 4 files changed, 17 insertions(+), 3 deletions(-) diff --git a/python/tvm/auto_scheduler/dispatcher.py b/python/tvm/auto_scheduler/dispatcher.py index eceeba38e0..98566f8636 100644 --- a/python/tvm/auto_scheduler/dispatcher.py +++ b/python/tvm/auto_scheduler/dispatcher.py @@ -25,6 +25,7 @@ as a schedule configuration here. import logging import pathlib +from collections.abc import Iterable import numpy as np @@ -199,7 +200,7 @@ class ApplyHistoryBest(DispatchContext): if it is not None, only load the first `n_lines` lines of log """ joint_records = [] - if not isinstance(records, (list, tuple)): + if not isinstance(records, Iterable) or isinstance(records, str): records = [records] for rec in records: diff --git a/python/tvm/autotvm/task/dispatcher.py b/python/tvm/autotvm/task/dispatcher.py index ffff50b9dc..6c072dc1fa 100644 --- a/python/tvm/autotvm/task/dispatcher.py +++ b/python/tvm/autotvm/task/dispatcher.py @@ -31,6 +31,7 @@ of the DispatchContext base class. from __future__ import absolute_import as _abs import logging +from collections.abc import Iterable import numpy as np @@ -212,7 +213,7 @@ class ApplyHistoryBest(DispatchContext): Collection of tuning records. If is str, then it should be the filename of a records log file. Each row of this file is an encoded record pair. If it is a list - it can either be a list of paths to logs that will loaded jointly or + it can either be a list of paths to logs that will be loaded jointly or an iterator of measurement results. """ # pylint: disable=import-outside-toplevel @@ -220,7 +221,7 @@ class ApplyHistoryBest(DispatchContext): from ..record import load_from_file joint_records = [] - if not isinstance(records, (list, tuple)): + if not isinstance(records, Iterable) or isinstance(records, str): records = [records] for rec in records: diff --git a/tests/python/relay/test_auto_scheduler_tuning.py b/tests/python/relay/test_auto_scheduler_tuning.py index c9ce5b59ff..735486ef27 100644 --- a/tests/python/relay/test_auto_scheduler_tuning.py +++ b/tests/python/relay/test_auto_scheduler_tuning.py @@ -62,6 +62,13 @@ def tune_network(network, target): best, auto_scheduler.dispatcher.ApplyHistoryBest ), "Unable to load multiple log files jointly." + # Confirm iterables can be directly loaded. + loaded_recs = auto_scheduler.dispatcher.load_records(log_file) + with auto_scheduler.ApplyHistoryBest(iter(loaded_recs)) as best: + assert isinstance( + best, auto_scheduler.dispatcher.ApplyHistoryBest + ), "Unable to ingest logs from an interator." + # Sample a schedule when missing with auto_scheduler.ApplyHistoryBestOrSample(None, num_measure=2): with tvm.transform.PassContext( diff --git a/tests/python/unittest/test_autotvm_record.py b/tests/python/unittest/test_autotvm_record.py index 2ee75cf18c..147122ff10 100644 --- a/tests/python/unittest/test_autotvm_record.py +++ b/tests/python/unittest/test_autotvm_record.py @@ -91,6 +91,11 @@ def test_apply_history_best(): x = hist_best.query(target, tsk.workload) assert str(x) == str(tsk.config_space.get(2)) + # Confirm same functionality for iterators. + hist_best = ApplyHistoryBest(iter(records)) + x = hist_best.query(target, tsk.workload) + assert str(x) == str(tsk.config_space.get(2)) + if __name__ == "__main__": test_load_dump()