This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 23371cadb4 [Unity][MSC][M1.4] Add Runner and test with relax (#15997)
23371cadb4 is described below
commit 23371cadb40241af80db4350cd62254c0f45715a
Author: Archermmt <[email protected]>
AuthorDate: Thu Nov 2 16:21:03 2023 +0800
[Unity][MSC][M1.4] Add Runner and test with relax (#15997)
* add relax runner
* minor fix
* update runner
---
python/tvm/contrib/msc/core/codegen/codegen.py | 2 +-
.../contrib/msc/core/{ir => frontend}/__init__.py | 3 +-
python/tvm/contrib/msc/core/frontend/translate.py | 341 +++++++++
python/tvm/contrib/msc/core/ir/__init__.py | 1 -
.../msc/{framework => core/runtime}/__init__.py | 4 +-
python/tvm/contrib/msc/core/runtime/runner.py | 818 +++++++++++++++++++++
python/tvm/contrib/msc/core/utils/__init__.py | 2 +
python/tvm/contrib/msc/core/utils/dataset.py | 78 +-
python/tvm/contrib/msc/core/utils/file.py | 112 ++-
python/tvm/contrib/msc/core/utils/info.py | 166 ++++-
python/tvm/contrib/msc/core/utils/log.py | 132 ++++
python/tvm/contrib/msc/core/utils/message.py | 133 ++++
python/tvm/contrib/msc/core/utils/namespace.py | 4 +
python/tvm/contrib/msc/framework/__init__.py | 2 +-
.../msc/framework/tensorflow/codegen/codegen.py | 9 +-
.../msc/framework/tensorflow/frontend/__init__.py | 2 +
.../msc/framework/tensorflow/frontend/translate.py | 15 +-
.../msc/framework/tensorrt/codegen/codegen.py | 4 +-
.../msc/framework/tensorrt/frontend/__init__.py | 2 +
.../msc/framework/tensorrt/frontend/translate.py | 14 +-
.../contrib/msc/framework/torch/codegen/codegen.py | 10 +-
.../msc/framework/torch/frontend/__init__.py | 2 +
.../msc/framework/torch/frontend/translate.py | 67 +-
.../contrib/msc/framework/tvm/codegen/codegen.py | 10 +-
.../msc/framework/{ => tvm/runtime}/__init__.py | 4 +-
.../contrib/msc/framework/tvm/runtime/runner.py | 129 ++++
tests/python/contrib/test_msc/test_runner.py | 89 +++
.../contrib/test_msc/test_translate_tensorrt.py | 5 +-
28 files changed, 2057 insertions(+), 103 deletions(-)
diff --git a/python/tvm/contrib/msc/core/codegen/codegen.py
b/python/tvm/contrib/msc/core/codegen/codegen.py
index ceb322b03b..f0884bf2d6 100644
--- a/python/tvm/contrib/msc/core/codegen/codegen.py
+++ b/python/tvm/contrib/msc/core/codegen/codegen.py
@@ -23,7 +23,7 @@ from typing import Dict, List, Optional, Any, Callable
import tvm
from tvm.relax.transform import BindParams
from tvm.contrib.msc.core.ir import MSCGraph
-from tvm.contrib.msc.core.ir.translate import from_relay
+from tvm.contrib.msc.core.frontend import from_relay
from tvm.contrib.msc.core import utils as msc_utils
diff --git a/python/tvm/contrib/msc/core/ir/__init__.py
b/python/tvm/contrib/msc/core/frontend/__init__.py
similarity index 94%
copy from python/tvm/contrib/msc/core/ir/__init__.py
copy to python/tvm/contrib/msc/core/frontend/__init__.py
index 81a34bedb6..a5fd7a01a8 100644
--- a/python/tvm/contrib/msc/core/ir/__init__.py
+++ b/python/tvm/contrib/msc/core/frontend/__init__.py
@@ -14,7 +14,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""tvm.contrib.msc.core.ir"""
+"""tvm.contrib.msc.core.frontend"""
-from .graph import *
from .translate import *
diff --git a/python/tvm/contrib/msc/core/frontend/translate.py
b/python/tvm/contrib/msc/core/frontend/translate.py
new file mode 100644
index 0000000000..e1dce2ae28
--- /dev/null
+++ b/python/tvm/contrib/msc/core/frontend/translate.py
@@ -0,0 +1,341 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""tvm.contrib.msc.core.frontend.translate"""
+
+from typing import Dict, Optional, Tuple, List
+
+import tvm
+from tvm.relax.transform import BindParams
+from tvm.relax import PyExprVisitor
+from tvm.relax.backend.pattern_registry import get_patterns_with_prefix
+from tvm.relay.expr_functor import ExprVisitor
+from tvm.relay.build_module import bind_params_by_name
+from tvm.relay import dataflow_pattern as relay_pattern
+from tvm.contrib.msc.core import transform as msc_transform
+from tvm.contrib.msc.core import _ffi_api
+from tvm.contrib.msc.core import utils as msc_utils
+from tvm.contrib.msc.core.ir import MSCGraph, MSCTensor
+
+
+def normalize_weights(
+ t_weights: Dict[MSCTensor, tvm.nd.array], graph: MSCGraph
+) -> Dict[str, tvm.nd.array]:
+ """Normalize the weghts.
+
+ Parameters
+ ----------
+ t_weights: dict of <MSCTensor, tvm.nd.array>
+ The weights extracted from IRModule.
+ graph: tvm.contrib.msc.core.ir.MSCGraph
+ The translated graph.
+
+ Returns
+ -------
+ weights: dict of <string:tvm.ndarray>
+ The normalized weights.
+ """
+
+ def _to_data(ref_t, data):
+ weight_t = graph.find_tensor(ref_t.name)
+ if weight_t.ndim == 1:
+ if ref_t.ndim != weight_t.ndim:
+ return
tvm.nd.array(data.asnumpy().reshape(weight_t.get_shape()))
+ return data
+ if ref_t.layout and weight_t.layout:
+ ref_layout, weight_layout = ref_t.layout.name, weight_t.layout.name
+ if ref_layout != weight_layout:
+ assert all(
+ l in ref_layout for l in weight_layout
+ ), "layout mismatch {} compare to {}".format(ref_t, weight_t)
+ permute = [ref_layout.index(l) for l in weight_layout]
+ return tvm.nd.array(data.asnumpy().transpose(*permute))
+ return data
+
+ weights = {t.name: _to_data(t, d) for t, d in t_weights.items() if
graph.has_tensor(t.name)}
+ return weights
+
+
+def from_relax(
+ mod: tvm.IRModule,
+ params: Optional[Dict[str, tvm.nd.array]] = None,
+ trans_config: Optional[Dict[str, str]] = None,
+ build_config: Optional[Dict[str, str]] = None,
+ opt_config: Optional[Dict[str, str]] = None,
+) -> Tuple[MSCGraph, Dict[str, tvm.nd.array]]:
+ """Change IRModule to MSCGraph.
+
+ Parameters
+ ----------
+ mod: IRModule
+ The IRModule of relax.
+ params: dict of <string:tvm.ndarray>
+ The parameters of the IRModule.
+ trans_config: dict
+ The config for transform IRModule.
+ build_config: dict
+ The config for build MSCGraph.
+ opt_config: dict
+ The config for optimize the relax before translate.
+
+ Returns
+ -------
+ graph: tvm.contrib.msc.core.ir.MSCGraph
+ The translated graph.
+ weights: dict of <string:tvm.ndarray>
+ The weights from the IRModule.
+ """
+
+ trans_config = msc_utils.copy_dict(trans_config)
+ build_config = msc_utils.copy_dict(build_config)
+ opt_config = msc_utils.copy_dict(opt_config)
+ entry = trans_config.get("entry", "main")
+ if params:
+ mod = BindParams("main", params)(mod)
+ opt_level = opt_config.get("opt_level", 1)
+ if opt_level > 0:
+ mod = tvm.transform.Sequential(
+ [
+ tvm.relax.transform.FoldConstant(),
+ ]
+ )(mod)
+ patterns = get_patterns_with_prefix("msc.")
+ passes = [
+ tvm.relax.transform.FuseOpsByPattern(
+ patterns, bind_constants=False, annotate_codegen=False
+ ),
+ msc_transform.SetExprName(entry_name=entry,
target=trans_config.get("target", "")),
+ msc_transform.SetExprLayout(
+ trans_config.get("allow_layout_missing", True), entry_name=entry
+ ),
+ ]
+ mod = tvm.transform.Sequential(passes)(mod)
+ graph = _ffi_api.BuildFromRelax(mod, entry,
msc_utils.dump_dict(build_config))
+ t_weights = _ffi_api.GetRelaxWeights(mod, entry)
+ return graph, normalize_weights(t_weights, graph)
+
+
+def get_relay_patterns(
+ mod: tvm.IRModule,
+ entry_name: str = "main",
+) -> List[Tuple[str, relay_pattern.DFPattern, callable]]:
+ """Filter relay patterns based on mod.
+
+ Parameters
+ ----------
+ mod: IRModule
+ The IRModule of relay.
+ entry_name: str
+ The entry name.
+
+ Returns
+ -------
+ patterns: list
+ The useful patterns for relay
+ """
+
+ class OpExtractor(ExprVisitor):
+ """Extract ops from expr."""
+
+ def extract(self, expr):
+ self._optypes = set()
+ super().visit(expr)
+ return self._optypes
+
+ def visit_call(self, expr):
+ super().visit_call(expr)
+ if isinstance(expr.op, tvm.ir.Op):
+ self._optypes.add(expr.op.name)
+
+ op_names = OpExtractor().extract(mod[entry_name])
+ skip_tags, patterns = set(),
list(tvm.relay.op.contrib.get_pattern_table("msc"))
+ if "nn.conv1d" not in op_names or "add" not in op_names:
+ skip_tags.add("msc.conv1d_bias")
+ if "nn.conv2d" not in op_names or "add" not in op_names:
+ skip_tags.add("msc.conv2d_bias")
+ if "nn.batch_matmul" not in op_names or "add" not in op_names:
+ skip_tags.add("msc.linear_bias")
+ if "nn.batch_matmul" not in op_names:
+ skip_tags |= set(p[0] for p in patterns if
p[0].startswith("msc.linear"))
+ if "nn.dense" not in op_names:
+ skip_tags |= set(p[0] for p in patterns if
p[0].startswith("msc.matmul"))
+ if "take" not in op_names:
+ skip_tags |= set(p[0] for p in patterns if
p[0].startswith("msc.embedding"))
+ if "erf" not in op_names:
+ skip_tags |= set(p[0] for p in patterns if p[0].startswith("msc.gelu"))
+ valid_patterns = [p for p in patterns if p[0] not in skip_tags]
+ return valid_patterns
+
+
+def from_relay(
+ mod: tvm.IRModule,
+ params: Optional[Dict[str, tvm.nd.array]] = None,
+ trans_config: Optional[Dict[str, str]] = None,
+ build_config: Optional[Dict[str, str]] = None,
+ opt_config: Optional[Dict[str, str]] = None,
+) -> Tuple[MSCGraph, Dict[str, tvm.nd.array]]:
+ """Change IRModule to MSCGraph.
+
+ Parameters
+ ----------
+ mod: IRModule
+ The IRModule of relay.
+ params: dict of <string:tvm.ndarray>
+ The parameters of the IRModule.
+ trans_config: dict
+ The config for transform IRModule.
+ build_config: dict
+ The config for build MSCGraph.
+ opt_config: dict
+ The config for optimize the relay before translate.
+
+ Returns
+ -------
+ graph: tvm.contrib.msc.core.ir.MSCGraph
+ The translated graph.
+ weights: dict of <string:tvm.ndarray>
+ The weights from the IRModule.
+ """
+
+ trans_config = msc_utils.copy_dict(trans_config)
+ build_config = msc_utils.copy_dict(build_config)
+ opt_config = msc_utils.copy_dict(opt_config)
+ # TODO(tong.meng): optimize before translate?
+ opt_level = opt_config.get("opt_level", 0)
+ if params:
+ mod["main"] = bind_params_by_name(mod["main"], params)
+ if opt_level > 0:
+ target = opt_config.get("target", "llvm")
+ disabled_pass = opt_config.get("disabled_pass", []) + [
+ "SimplifyInference",
+ "CanonicalizeOps",
+ "FuseOps",
+ "AlterOpLayout",
+ ]
+ with tvm.transform.PassContext(opt_level=opt_level,
disabled_pass=disabled_pass):
+ mod, params = tvm.relay.optimize(mod, target=target, params=params)
+ patterns = get_relay_patterns(mod)
+ passes = [
+ tvm.relay.transform.InferType(),
+ tvm.relay.transform.MergeComposite(patterns),
+ msc_transform.SetExprName(as_relax=False),
+ ]
+ mod = tvm.transform.Sequential(passes)(mod)
+ graph = _ffi_api.BuildFromRelay(mod, "main",
msc_utils.dump_dict(build_config))
+ t_weights = _ffi_api.GetRelayWeights(mod, "main")
+ return graph, normalize_weights(t_weights, graph)
+
+
[email protected]_functor.visitor
+class BYOCChecker(PyExprVisitor):
+ """Checker to check if any non-target ops exist"""
+
+ def check(self, func_names, expr):
+ self._func_names = func_names
+ self._non_target_exprs = []
+ if isinstance(expr, tvm.relax.Expr):
+ self.visit_expr(expr)
+ elif isinstance(expr, tvm.relax.BindingBlock):
+ self.visit_binding_block(expr)
+ assert len(self._non_target_exprs) == 0, "Some exprs not on target
{}".format(expr)
+
+ def visit_var_binding_(self, binding) -> None:
+ super().visit_var_binding_(binding)
+ if isinstance(binding.value, tvm.relax.Call):
+ if isinstance(binding.value.op, tvm.relax.GlobalVar):
+ if binding.value.op.name_hint not in self._func_names:
+ self._non_target_exprs.append(binding.value)
+ else:
+ self._non_target_exprs.append(binding.value)
+ elif not isinstance(binding.value, tvm.relax.DataflowVar):
+ self._non_target_exprs.append(binding.value)
+
+
+def byoc_partition(
+ target: str,
+ mod: tvm.IRModule,
+ params: Optional[Dict[str, tvm.nd.array]] = None,
+ trans_config: Optional[Dict[str, str]] = None,
+ build_config: Optional[Dict[str, str]] = None,
+) -> Tuple[tvm.IRModule, List[Tuple[MSCGraph, Dict[str, tvm.nd.array]]]]:
+ """Partition module to target sub functions.
+
+ Parameters
+ ----------
+ target: str
+ The target for the BYOC.
+ mod: IRModule
+ The IRModule of relax.
+ trans_config: dict
+ The config for transform IRModule.
+ params: dict of <string:tvm.ndarray>
+ The parameters of the IRModule.
+ build_config: dict
+ The config for build MSCGraph.
+
+ Returns
+ -------
+ mod: IRModule
+ The IRModule of partitioned relax.
+ graphs_info: list<<MSCGraph, weights>>
+ The func <MSCGraph and weights> list, each element for a sub graph.
+ """
+
+ trans_config = msc_utils.copy_dict(trans_config)
+ build_config = msc_utils.copy_dict(build_config)
+ build_config["target"] = target
+ for key in ["input_aliases", "output_aliases"]:
+ if key in build_config:
+ build_config.pop(key)
+ entry = trans_config.get("entry", "main")
+ if params:
+ mod = BindParams("main", params)(mod)
+
+ def _partition_mod(mod, as_msc=True):
+ patterns = get_patterns_with_prefix(target)
+ if as_msc:
+ passes = [tvm.relax.transform.FuseOpsByPattern(patterns,
bind_constants=False)]
+ else:
+ passes = [tvm.relax.transform.FuseOpsByPattern(patterns,
bind_constants=True)]
+ passes.extend(
+ [
+ msc_transform.BindShape(),
+ msc_transform.FuseTuple(target),
+ tvm.relax.transform.MergeCompositeFunctions(),
+ msc_transform.SetExprName(target=target),
+
msc_transform.SetExprLayout(trans_config.get("allow_layout_missing", True)),
+ ]
+ )
+ return tvm.transform.Sequential(passes)(mod)
+
+ def _is_target_func(func):
+ if "Codegen" not in func.attrs:
+ return False
+ return func.attrs["Codegen"] == target
+
+ msc_mod = _partition_mod(mod)
+ func_names = [var.name_hint for var, func in msc_mod.functions.items() if
_is_target_func(func)]
+
+ if not trans_config.get("allow_incomplete", False):
+ assert len(func_names) == 1, "More than 1 target func is found: " +
str(msc_mod)
+ BYOCChecker().check(func_names, msc_mod[entry])
+
+ graphs_info, all_weights = [], _ffi_api.GetRelaxWeights(msc_mod, entry)
+ for name in func_names:
+ build_config.update({"graph_name": name, "byoc_entry": name})
+ graph = _ffi_api.BuildFromRelax(msc_mod, entry,
msc_utils.dump_dict(build_config))
+ graphs_info.append((graph, normalize_weights(all_weights, graph)))
+ return _partition_mod(mod, False), graphs_info
diff --git a/python/tvm/contrib/msc/core/ir/__init__.py
b/python/tvm/contrib/msc/core/ir/__init__.py
index 81a34bedb6..ce23a2dd8b 100644
--- a/python/tvm/contrib/msc/core/ir/__init__.py
+++ b/python/tvm/contrib/msc/core/ir/__init__.py
@@ -17,4 +17,3 @@
"""tvm.contrib.msc.core.ir"""
from .graph import *
-from .translate import *
diff --git a/python/tvm/contrib/msc/framework/__init__.py
b/python/tvm/contrib/msc/core/runtime/__init__.py
similarity index 93%
copy from python/tvm/contrib/msc/framework/__init__.py
copy to python/tvm/contrib/msc/core/runtime/__init__.py
index 17c974aab9..a0ccca5b2b 100644
--- a/python/tvm/contrib/msc/framework/__init__.py
+++ b/python/tvm/contrib/msc/core/runtime/__init__.py
@@ -14,4 +14,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""tvm.contrib.msc.framework"""
+"""tvm.contrib.msc.core.runtime"""
+
+from .runner import *
diff --git a/python/tvm/contrib/msc/core/runtime/runner.py
b/python/tvm/contrib/msc/core/runtime/runner.py
new file mode 100644
index 0000000000..65e86e4896
--- /dev/null
+++ b/python/tvm/contrib/msc/core/runtime/runner.py
@@ -0,0 +1,818 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=unused-argument
+"""tvm.contrib.msc.core.runtime.runner"""
+
+import os
+import json
+import logging
+from typing import Dict, Optional, Any, List, Tuple, Union
+import numpy as np
+
+import tvm
+from tvm.contrib.msc.core.ir import MSCGraph
+from tvm.contrib.msc.core.frontend import from_relax
+from tvm.contrib.msc.core.utils.namespace import MSCFramework
+from tvm.contrib.msc.core import utils as msc_utils
+from tvm.contrib.msc.core import _ffi_api
+
+
+class BaseRunner(object):
+ """Basic runner of MSC
+
+ Parameters
+ ----------
+ mod: IRModule
+ The IRModule of relax.
+ params: dict of <string:tvm.ndarray>
+ The parameters of the IRModule.
+ tools_config: dict
+ The config of MSC Tools.
+ translate_config: dict
+ The config for translate IRModule to MSCGraph.
+ codegen_config: dict
+ The config for build MSCGraph to runnable model.
+ name: str
+ The name of the runner
+ device: str
+ The device of the model, cpu| cuda| cuda:0|...
+ is_training: bool
+ Whether use model in training
+ logger: logging.Logger
+ The logger
+ """
+
+ def __init__(
+ self,
+ mod: tvm.IRModule,
+ tools_config: Optional[Dict[str, Any]] = None,
+ translate_config: Optional[Dict[str, str]] = None,
+ load_config: Optional[Dict[str, str]] = None,
+ name: str = "main",
+ device: str = "cpu",
+ is_training: bool = False,
+ logger: logging.Logger = None,
+ ):
+ self._mod = mod
+ self._tools_config = tools_config or {}
+ self._translate_config = translate_config or {}
+ self._load_config = load_config or {}
+ self._name = name
+ self._device = device if self._device_enabled(device) else "cpu"
+ self._is_training = is_training
+ self._logger = logger or msc_utils.get_global_logger()
+ self.setup()
+ config = {
+ "class": self.__class__.__name__,
+ "tools_config": self._tools_config,
+ "translate_config": self._translate_config,
+ "load_config": self._load_config,
+ "name": self._name,
+ "device": self._device,
+ "is_training": self._is_training,
+ }
+ self._logger.debug(msc_utils.msg_block("RUNNER_CONFIG", config))
+
+ def setup(self):
+ """Setup the runner"""
+
+ self._graphs, self._weights = [], []
+ self._model, self._model_info = None, {}
+ self._runnable = None
+
+ def build(self, cache_dir: msc_utils.MSCDirectory = None, build_graph:
bool = False) -> object:
+ """Build the runnable object
+
+ Parameters
+ -------
+ cache_dir: MSCDirectory
+ cache path for save/load info
+ build_graph: bool
+ Whether to build the MSCGraphs.
+
+ Returns
+ -------
+ runnable: object
+ The runnable object.
+ """
+
+ if cache_dir and os.path.isfile(cache_dir.relpath("cache_info.json")):
+ cache_info =
msc_utils.load_dict(cache_dir.relpath("cache_info.json"))
+ else:
+ cache_info = {}
+
+ # Load graphs from cache
+ if cache_info.get("graphs"):
+ self._graphs, self._weights = self._load_graphs(cache_dir,
cache_info["graphs"])
+ self._logger.debug(
+ "Load {} graphs from cache @ {}".format(len(self._graphs),
cache_dir)
+ )
+
+ # Get or rebuild graphs
+ if build_graph or not self._graphs:
+ self._graphs, self._weights = self._translate()
+ self._logger.debug("Translate {} graphs from
module".format(len(self._graphs)))
+
+ # Save graphs for debug
+ for graph in self._graphs:
+ graph.visualize(msc_utils.get_debug_dir().relpath(graph.name +
".prototxt"))
+
+ # Create tools
+ if self._tools_config:
+ raise NotImplementedError("Build runner with tools is not
supported")
+
+ if cache_info.get("model") and not build_graph:
+ # Load model from cache
+ self._model = self._load_model(cache_dir, cache_info["model"])
+ else:
+ # Generate and save model
+ self._model = self._generate_model()
+ if "loader" in self._load_config:
+ loader, load_config = self._load_config["loader"]
+ self._model = loader(self._model, **load_config)
+ self._logger.info(
+ "Model({}) processed by customize loader {}({})".format(
+ self.framework, loader, load_config
+ )
+ )
+ self._model_info = self._inspect_model()
+ self._logger.debug(msc_utils.msg_block("MODEL_INFO", self._model_info))
+
+ if cache_info.get("runnable") and not build_graph:
+ # Load runnable from cache
+ self._runnable = self._load_runnable(cache_dir,
cache_info["runnable"])
+ else:
+ # Build runnable on device
+ self._runnable = self._to_runnable(self._model, self._device,
self._is_training)
+ self._logger.info(
+ "Runnable({}, {}) loaded on device {}".format(
+ self.framework, "train" if self._is_training else "eval",
self._device
+ )
+ )
+ return self._runnable
+
+ def save_cache(self, cache_dir: msc_utils.MSCDirectory):
+ """Save runner to cache
+
+ Parameters
+ -------
+ cache_dir: MSCDirectory
+ cache path for save/load info
+ """
+
+ cache_info = {
+ "graphs": self._save_graphs(cache_dir),
+ "model": self._save_model(cache_dir),
+ "runnable": self._save_runnable(cache_dir),
+ }
+ with open(cache_dir.relpath("cache_info.json"), "w") as f:
+ f.write(json.dumps(cache_info, indent=2))
+ self._logger.debug("Runner save cache -> " + str(cache_dir.path))
+ self._logger.debug(msc_utils.msg_block("CACHE_INFO", cache_info))
+
+ def run(
+ self, inputs: Union[List[np.ndarray], Dict[str, np.ndarray]],
ret_type="dict"
+ ) -> Union[List[np.ndarray], Dict[str, np.ndarray]]:
+ """Run the model to get outputs
+
+ Parameters
+ -------
+ inputs: list<data> or dict<str, data>
+ The inputs in list or dict.
+ ret_type: str
+ The return type list| dict
+
+ Returns
+ -------
+ outputs: dict<str, data>
+ The outputs in dict.
+ """
+
+ model_inputs = self.get_inputs()
+ model_outputs = self.get_outputs()
+ if isinstance(inputs, (list, tuple)):
+ assert len(inputs) == len(
+ model_inputs
+ ), "inputs({}) mismatch with model inputs {}".format(len(inputs),
model_inputs)
+ inputs = {info["name"]: data for info, data in zip(model_inputs,
inputs)}
+ assert isinstance(inputs, dict), "Expect inputs as list or dict, get
{}({})".format(
+ inputs, type(inputs)
+ )
+ assert all(
+ isinstance(data, np.ndarray) for data in inputs.values()
+ ), "Expected all inputs as np.ndarray"
+ inputs = {i["name"]: inputs[i["name"]] for i in model_inputs}
+ outputs = self._call_runnable(self._runnable, inputs, self._device)
+ if ret_type == "dict":
+ if isinstance(outputs, (list, tuple)):
+ assert len(outputs) == len(
+ model_outputs
+ ), "outputs({}) mismatch with model outputs
{}".format(len(outputs), model_outputs)
+ outputs = {info["name"]: data for info, data in
zip(model_outputs, outputs)}
+ if not isinstance(outputs, dict):
+ assert len(model_outputs) == 1, "Expect model_outputs with len
1, get " + str(
+ model_outputs
+ )
+ outputs = {model_outputs[0]["name"]: outputs}
+ outputs = {name: msc_utils.cast_array(data) for name, data in
outputs.items()}
+ elif ret_type == "list":
+ if isinstance(outputs, dict):
+ assert len(outputs) == len(
+ model_outputs
+ ), "outputs({}) mismatch with model outputs
{}".format(len(outputs), model_outputs)
+ outputs = [outputs[o["name"]] for o in model_outputs]
+ if not isinstance(outputs, (list, tuple)):
+ outputs = [outputs]
+ outputs = [msc_utils.cast_array(data) for data in outputs]
+ return outputs
+
+ def get_inputs(self) -> List[Dict[str, str]]:
+ """Get the inputs of the model
+
+ Returns
+ -------
+ inputs: list<tensor_des>
+ The inputs info.
+ """
+
+ return self._model_info["inputs"]
+
+ def get_outputs(self) -> List[Dict[str, str]]:
+ """Get the outputs of the model
+
+ Returns
+ -------
+ outputs: list<tensor_des>
+ The outputs info.
+ """
+
+ return self._model_info["outputs"]
+
+ def destory(self):
+ """Destory runner"""
+
+ if self._model:
+ del self._model
+
+ def _translate(self) -> Tuple[List[MSCGraph], Dict[str, tvm.nd.array]]:
+ """Translate IRModule to MSCgraphs
+
+ Returns
+ -------
+ graph_list: list<MSCGraph>
+ The translated graphs
+ weights_list: list<dict<str, tvm.nd.array>>
+ The translated weights
+ """
+
+ raise NotImplementedError("_translate is not implemented for " +
str(self.__class__))
+
+ def _load_graphs(
+ self, cache_dir: msc_utils.MSCDirectory, cache_info: dict
+ ) -> Tuple[List[MSCGraph], Dict[str, tvm.nd.array]]:
+ """Load MSCgraphs from cache
+
+ Parameters
+ -------
+ cache_dir: MSCDirectory
+ cache path for save/load info
+ cache_info: dict
+ The cache info.
+
+ Returns
+ -------
+ graph_list: list<MSCGraph>
+ The translated graphs
+ weights_list: list<dict<str, tvm.nd.array>>
+ The translated weights
+ """
+
+ raise NotImplementedError("_load_graphs is not implemented for " +
str(self.__class__))
+
+ def _save_graphs(self, cache_dir: msc_utils.MSCDirectory) -> dict:
+ """Save MSCgraphs to cache
+
+ Parameters
+ -------
+ cache_dir: MSCDirectory
+ cache path for save/load info
+
+ Returns
+ -------
+ cache_info: dict
+ The cache info.
+ """
+
+ raise NotImplementedError("_save_graphs is not implemented for " +
str(self.__class__))
+
+ def _generate_model(self) -> object:
+ """Codegen the model according to framework
+
+ Returns
+ -------
+ model: object
+ The meta model
+ """
+
+ raise NotImplementedError("_load is not implemented for " +
str(self.__class__))
+
+ def _load_model(self, cache_dir: msc_utils.MSCDirectory, cache_info: dict)
-> object:
+ """Load the model from cache
+
+ Parameters
+ -------
+ cache_dir: MSCDirectory
+ cache path for save/load info
+ cache_info: dict
+ The cache info.
+
+ Returns
+ -------
+ model: object
+ The meta model
+ """
+
+ raise NotImplementedError("_load_model is not implemented for " +
str(self.__class__))
+
+ def _save_model(self, cache_dir: msc_utils.MSCDirectory) -> dict:
+ """Save model to cache
+
+ Parameters
+ -------
+ cache_dir: MSCDirectory
+ cache path for save/load info
+
+ Returns
+ -------
+ cache_info: dict
+ The cache info.
+ """
+
+ # disable save model by default
+ return {}
+
+ def _to_runnable(self, model: object, device: str, is_training: bool) ->
object:
+ """Build runnable object
+
+ Parameters
+ -------
+ model: object
+ The meta model.
+ device: str
+ The device for place model
+ is_training: bool
+ Whether to load model for training
+
+ Returns
+ -------
+ runnable: object
+ The runnable
+ """
+
+ raise NotImplementedError("_to_runnable is not implemented for " +
str(self.__class__))
+
+ def _load_runnable(self, cache_dir: msc_utils.MSCDirectory, cache_info:
dict) -> object:
+ """Load the runnable from cache
+
+ Parameters
+ -------
+ cache_dir: MSCDirectory
+ cache path for save/load info
+ cache_info: dict
+ The cache info.
+
+ Returns
+ -------
+ runnable: object
+ The runnable
+ """
+
+ raise NotImplementedError("_load_runnable is not implemented for " +
str(self.__class__))
+
+ def _save_runnable(self, cache_dir: msc_utils.MSCDirectory) -> dict:
+ """Save runnable to cache
+
+ Parameters
+ -------
+ cache_dir: MSCDirectory
+ cache path for save/load info
+
+ Returns
+ -------
+ cache_info: dict
+ The cache info.
+ """
+
+ # disable save runnable by default
+ return {}
+
+ def _inspect_model(self) -> dict:
+ """Inspect the model
+
+ Returns
+ -------
+ model_info: dict
+ The inspected model info
+ """
+
+ raise NotImplementedError("_inspect_model is not implemented for " +
str(self.__class__))
+
+ def _call_runnable(
+ self, runnable: object, inputs: Dict[str, np.ndarray], device: str
+ ) -> Union[List[np.ndarray], Dict[str, np.ndarray]]:
+ """Call the runnable to get outputs
+
+ Parameters
+ -------
+ model:
+ The runnable model.
+ inputs: dict<str, data>
+ The inputs in dict.
+ device: str
+ The device.
+
+ Returns
+ -------
+ outputs: list<data> or dict<str, data>
+ The outputs in list or dict.
+ """
+
+ raise NotImplementedError("_call_runnable is not implemented for " +
str(self.__class__))
+
+ def _device_enabled(self, device: str) -> bool:
+ """Check if the device is enabled
+
+ Returns
+ -------
+ enabled: bool
+ Whether the device is enabled.
+ """
+
+ return True
+
+ @property
+ def model(self):
+ return self._model
+
+ @property
+ def runnable(self):
+ return self._runnable
+
+ @property
+ def device(self):
+ return self._device
+
+ @property
+ def codegen_func(self):
+ raise NotImplementedError("codegen_func is not implemented for " +
str(self.__class__))
+
+ @property
+ def framework(self):
+ return MSCFramework.MSC
+
+
+class ModelRunner(BaseRunner):
+ """Model runner of MSC"""
+
+ def _translate(self) -> Tuple[List[MSCGraph], Dict[str, tvm.nd.array]]:
+ """Translate IRModule to MSCgraphs
+
+ Returns
+ -------
+ graph_list: list<MSCGraph>
+ The translated graphs
+ weights_list: list<dict<str, tvm.nd.array>>
+ The translated weights
+ """
+
+ graph, weights = from_relax(
+ self._mod,
+ trans_config=self._translate_config.get("transform"),
+ build_config=self._translate_config.get("build"),
+ opt_config=self._translate_config.get("optimize"),
+ )
+ return [graph], [weights]
+
+ def _load_graphs(
+ self, cache_dir: msc_utils.MSCDirectory, cache_info: dict
+ ) -> Tuple[List[MSCGraph], Dict[str, tvm.nd.array]]:
+ """Load MSCgraphs from cache
+
+ Parameters
+ -------
+ cache_dir: MSCDirectory
+ cache path for save/load info
+ cache_info: dict
+ The cache info.
+
+ Returns
+ -------
+ graph_list: list<MSCGraph>
+ The translated graphs
+ weights_list: list<dict<str, tvm.nd.array>>
+ The translated weights
+ """
+
+ assert "main" in cache_info, "main should be given in cache_info, get
" + str(cache_info)
+ graph =
MSCGraph.from_json(cache_dir.relpath(cache_info["main"]["graph"]))
+ with open(cache_dir.relpath(cache_info["main"]["weights"]), "rb") as f:
+ weights = tvm.runtime.load_param_dict(f.read())
+ return [graph], [weights]
+
+ def _save_graphs(self, cache_dir: msc_utils.MSCDirectory) -> dict:
+ """Save MSCgraphs to cache
+
+ Parameters
+ -------
+ cache_dir: MSCDirectory
+ cache path for save/load info
+
+ Returns
+ -------
+ cache_info: dict
+ The cache info.
+ """
+
+ main_info = {
+ "graph": self._graphs[0].name + "_graph.json",
+ "weights": self._graphs[0].name + "_params.bin",
+ }
+ with cache_dir:
+ with open(main_info["graph"], "w") as f_graph:
+ f_graph.write(self._graphs[0].to_json())
+ with open(main_info["weights"], "wb") as f_params:
+ f_params.write(tvm.runtime.save_param_dict(self._weights[0]))
+ return {"main": main_info}
+
+ def _generate_model(self) -> object:
+ """Codegen the model according to framework
+
+ Returns
+ -------
+ model: object
+ The runnable model
+ """
+
+ return self.codegen_func(
+ self._graphs[0],
+ self._weights[0],
+ codegen_config=self._load_config.get("codegen"),
+ print_config=self._load_config.get("build"),
+ build_folder=self._load_config.get("build_folder",
msc_utils.get_build_dir()),
+ )
+
+ def _inspect_model(self) -> dict:
+ """Inspect the model
+
+ Returns
+ -------
+ model_info: dict
+ The inspected model info
+ """
+
+ return self._graphs[0].inspect()
+
+
class BYOCRunner(BaseRunner):
    """BYOC runner of MSC"""

    def setup(self):
        """Setup the runner"""

        super().setup()
        self._byoc_mod, self._byoc_graph = None, None
        self._graph_infos = {}

    def _translate(self) -> Tuple[List[MSCGraph], Dict[str, tvm.nd.array]]:
        """Translate IRModule to MSCgraphs

        Returns
        -------
        graph_list: list<MSCGraph>
            The translated graphs
        weights_list: list<dict<str, tvm.nd.array>>
            The translated weights
        """

        # partition the relax module into byoc sub-functions with their graphs/weights
        self._byoc_mod, self._graph_infos = self.partition_func(
            self._mod,
            trans_config=self._translate_config.get("transform"),
            build_config=self._translate_config.get("build"),
        )
        graphs, weights = [], []
        for graph, sub_weights in self._graph_infos:
            graphs.append(graph)
            weights.append(sub_weights)
        self._byoc_graph = _ffi_api.BuildFromRelax(
            self._byoc_mod, "main", msc_utils.dump_dict(self._translate_config.get("build"))
        )
        # dump the whole-model graph for debugging
        self._byoc_graph.visualize(
            msc_utils.get_debug_dir().relpath(self._byoc_graph.name + ".prototxt")
        )
        return graphs, weights

    def _load_graphs(
        self, cache_dir: msc_utils.MSCDirectory, cache_info: dict
    ) -> Tuple[List[MSCGraph], Dict[str, tvm.nd.array]]:
        """Load MSCgraphs from cache

        Parameters
        -------
        cache_dir: MSCDirectory
            cache path for save/load info
        cache_info: dict
            The cache info.

        Returns
        -------
        graph_list: list<MSCGraph>
            The translated graphs
        weights_list: list<dict<str, tvm.nd.array>>
            The translated weights
        """

        for key in ("byoc_mod", "byoc_graph", "sub_graphs"):
            assert key in cache_info, "{} should be given in cache_info, get {}".format(
                key, cache_info
            )

        self._byoc_mod = tvm.ir.load_json(cache_dir.relpath(cache_info["byoc_mod"]))
        graphs, weights = [], []
        for f_graph, f_weights in cache_info["sub_graphs"]:
            graphs.append(MSCGraph.from_json(cache_dir.relpath(f_graph)))
            with open(cache_dir.relpath(f_weights), "rb") as f:
                # BUGFIX: append each sub-weights dict instead of overwriting the
                # list; the old `weights = ...` made zip(graphs, weights) below
                # pair graphs with dict keys and broke the declared return type
                weights.append(tvm.runtime.load_param_dict(f.read()))
        self._graph_infos = list(zip(graphs, weights))
        self._byoc_graph = MSCGraph.from_json(cache_dir.relpath(cache_info["byoc_graph"]))
        self._byoc_graph.visualize(
            msc_utils.get_debug_dir().relpath(self._byoc_graph.name + ".prototxt")
        )
        return graphs, weights

    def _save_graphs(self, cache_dir: msc_utils.MSCDirectory) -> dict:
        """Save MSCgraphs to cache

        Parameters
        -------
        cache_dir: MSCDirectory
            cache path for save/load info

        Returns
        -------
        cache_info: dict
            The cache info.
        """

        sub_graphs = [
            (graph.name + "_graph.info", graph.name + "_params.bin") for graph in self._graphs
        ]
        with cache_dir:
            for graph, weights, info in zip(self._graphs, self._weights, sub_graphs):
                with open(info[0], "w") as f_graph:
                    f_graph.write(graph.to_json())
                with open(info[1], "wb") as f_params:
                    f_params.write(tvm.runtime.save_param_dict(weights))
            with open("byoc_graph.json", "w") as f:
                f.write(self._byoc_graph.to_json())
            with open("byoc_module.json", "w") as f:
                f.write(tvm.ir.save_json(self._byoc_mod))
        return {
            "sub_graphs": sub_graphs,
            "byoc_graph": "byoc_graph.json",
            "byoc_mod": "byoc_module.json",
        }

    def _generate_model(self) -> tvm.IRModule:
        """Codegen the model according to framework

        Returns
        -------
        model: tvm.IRModule
            The relax module
        """

        return self.codegen_func(
            self._byoc_mod,
            self._graph_infos,
            codegen_config=self._load_config.get("codegen"),
            print_config=self._load_config.get("build"),
            build_folder=self._load_config.get("build_folder", msc_utils.get_build_dir()),
            output_folder=self._load_config.get("output_folder", msc_utils.get_output_dir()),
        )

    def _to_runnable(self, model: object, device: str, is_training: bool) -> object:
        """Build runnable object

        Parameters
        -------
        model: object
            The runnable model on cpu.
        device: str
            The device for place model
        is_training: bool
            Whether to load model for training

        Returns
        -------
        runnable: object
            The runnable
        """

        model = tvm.relax.transform.LegalizeOps()(model)
        if device == "cpu":
            target = tvm.target.Target("llvm")
            with tvm.transform.PassContext(opt_level=3):
                relax_exec = tvm.relax.build(model, target)
                runnable = tvm.relax.VirtualMachine(relax_exec, tvm.cpu())
        elif device.startswith("cuda"):
            target = tvm.target.Target("cuda")
            # schedule the legalized tir functions before building for gpu
            with target:
                model = tvm.tir.transform.DefaultGPUSchedule()(model)
            with tvm.transform.PassContext(opt_level=3):
                relax_exec = tvm.relax.build(model, target)
                runnable = tvm.relax.VirtualMachine(relax_exec, tvm.cuda())
        else:
            raise NotImplementedError("Unsupported device " + str(device))
        return runnable

    def _call_runnable(
        self, runnable: tvm.relax.VirtualMachine, inputs: Dict[str, np.ndarray], device: str
    ) -> Union[List[np.ndarray], Dict[str, np.ndarray]]:
        """Call the runnable to get outputs

        Parameters
        -------
        runnable: tvm.relax.VirtualMachine
            The virtual machine.
        inputs: dict<str, data>
            The inputs in dict.
        device: str
            The device.

        Returns
        -------
        outputs: list<data>
            The outputs in list.
        """

        model_inputs = self.get_inputs()
        if device == "cpu":
            tvm_inputs = [tvm.nd.array(inputs[i["name"]]) for i in model_inputs]
        elif device.startswith("cuda"):
            # "cuda" or "cuda:<id>"; default to device 0
            dev_id = int(device.split(":")[1]) if ":" in device else 0
            tvm_inputs = [
                tvm.nd.array(inputs[i["name"]], device=tvm.cuda(dev_id)) for i in model_inputs
            ]
        else:
            raise NotImplementedError("Unsupported device " + str(device))
        return runnable["main"](*tvm_inputs)

    def _inspect_model(self) -> dict:
        """Inspect the model

        Returns
        -------
        model_info: dict
            The inspected model info
        """

        return self._byoc_graph.inspect()

    def _device_enabled(self, device: str) -> bool:
        """Check if the device is enabled

        Parameters
        -------
        device: str
            The device, "cpu" or "cuda"/"cuda:<id>".

        Returns
        -------
        enabled: bool
            Whether the device is enabled.
        """

        if device == "cpu":
            return True
        if device.startswith("cuda"):
            dev_id = int(device.split(":")[1]) if ":" in device else 0
            return tvm.cuda(dev_id).exist
        return False

    @property
    def partition_func(self):
        raise NotImplementedError("partition_func is not implemented for " + str(self.__class__))
diff --git a/python/tvm/contrib/msc/core/utils/__init__.py
b/python/tvm/contrib/msc/core/utils/__init__.py
index 4df8b087cb..a76659609d 100644
--- a/python/tvm/contrib/msc/core/utils/__init__.py
+++ b/python/tvm/contrib/msc/core/utils/__init__.py
@@ -22,3 +22,5 @@ from .file import *
from .namespace import *
from .register import *
from .dataset import *
+from .log import *
+from .message import *
diff --git a/python/tvm/contrib/msc/core/utils/dataset.py
b/python/tvm/contrib/msc/core/utils/dataset.py
index 2a9c5f4a4a..68760f07ae 100644
--- a/python/tvm/contrib/msc/core/utils/dataset.py
+++ b/python/tvm/contrib/msc/core/utils/dataset.py
@@ -17,8 +17,9 @@
"""tvm.contrib.msc.core.utils.dataset"""
import os
+import shutil
import json
-from typing import List
+from typing import List, Union, Dict
import numpy as np
from .info import load_dict
@@ -38,7 +39,6 @@ class MSCDataLoader(object):
"""
def __init__(self, folder: str, start: int = 0, end: int = -1):
- super(MSCDataLoader, self).__init__()
self._folder = folder
self._start = start
self._current = 0
@@ -93,10 +93,15 @@ class MSCDataLoader(object):
The loaded data.
"""
- f_path = os.path.join(self._folder, name,
"batch_{}.bin".format(self._start + index))
+ save_name = info.get("save_name", name)
+ f_path = os.path.join(self._folder, save_name,
"batch_{}.bin".format(self._start + index))
assert os.path.isfile(f_path), "Can not find data file " + str(f_path)
return np.fromfile(f_path, dtype=info["dtype"]).reshape(info["shape"])
+ @property
+ def info(self):
+ return self._info
+
class MSCDataSaver(object):
"""Dataset Saver for MSC
@@ -123,9 +128,9 @@ class MSCDataSaver(object):
start: int = 0,
max_size: int = -1,
):
- super(MSCDataSaver, self).__init__()
- if not os.path.isdir(folder):
- os.mkdir(folder)
+ if os.path.isdir(folder):
+ shutil.rmtree(folder)
+ os.mkdir(folder)
self._folder = folder
self._input_names = input_names
self._output_names = output_names
@@ -146,31 +151,48 @@ class MSCDataSaver(object):
def reset(self):
self._current = 0
- def save(self, inputs: List[np.ndarray], outputs: List[np.ndarray] = None):
+ def save(
+ self,
+ inputs: Union[Dict[str, np.ndarray], List[np.ndarray]],
+ outputs: Union[Dict[str, np.ndarray], List[np.ndarray]] = None,
+ ):
"""Save 1 batch inputs and outputs.
Parameters
-------
- inputs: list<np.ndarray>
+ inputs: list<np.ndarray>/dict<str, np.ndarray>
The inputs datas.
- outputs: list<np.ndarray>
+ outputs: list<np.ndarray>/dict<str, np.ndarray>
The outputs datas.
"""
- assert len(inputs) == len(
- self._input_names
- ), "inputs size {} mismatch with input_names {}".format(len(inputs),
self._input_names)
- for idx, i_data in enumerate(inputs):
- self._save_data(self._input_names[idx], i_data, True)
+ if isinstance(inputs, dict):
+ assert set(inputs.keys()) == set(
+ self._input_names
+ ), "Input names mismatch {} with {}".format(inputs.keys(),
self._input_names)
+ elif isinstance(inputs, (tuple, list)):
+ assert len(inputs) == len(
+ self._input_names
+ ), "Inputs size {} mismatch with input_names
{}".format(len(inputs), self._input_names)
+ inputs = dict(zip(self._input_names, inputs))
+ for name, data in inputs.items():
+ self._save_data(name, data, True)
if outputs:
- assert len(outputs) == len(
- self._output_names
- ), "outputs size {} mismatch with output_names {}".format(
- len(outputs), self._output_names
- )
- for idx, o_data in enumerate(outputs):
- self._save_data(self._output_names[idx], o_data, False)
+ if isinstance(outputs, dict):
+ assert set(outputs.keys()) == set(
+ self._output_names
+ ), "Output names mismatch {} with {}".format(outputs.keys(),
self._output_names)
+ elif isinstance(outputs, (tuple, list)):
+ assert len(outputs) == len(
+ self._output_names
+ ), "Outputs size {} mismatch with input_names {}".format(
+ len(outputs), self._output_names
+ )
+ outputs = dict(zip(self._output_names, outputs))
+ for name, data in outputs.items():
+ self._save_data(name, data, False)
self._current += 1
+ return self._current
def _save_data(self, name: str, data: np.ndarray, is_input: bool):
"""Save data to file.
@@ -185,7 +207,8 @@ class MSCDataSaver(object):
Whether the data is input.
"""
- sub_folder = f_path = os.path.join(self._folder, name)
+ save_name = name.replace("/", "_")
+ sub_folder = f_path = os.path.join(self._folder, save_name)
if not os.path.isdir(sub_folder):
os.mkdir(sub_folder)
f_path = os.path.join(sub_folder, "batch_{}.bin".format(self._start +
self._current))
@@ -203,5 +226,16 @@ class MSCDataSaver(object):
"shape": list(data.shape),
"dtype": data.dtype.name,
"bytes": data.size * data.itemsize,
+ "save_name": save_name,
}
data.tofile(f_path)
+
+ @property
+ def info(self):
+ return self._info
+
+
def is_dataset(folder: str) -> bool:
    """Check if a folder is MSC dataset"""

    # a dataset folder is marked by its msc_info.json file
    info_file = os.path.join(folder, "msc_info.json")
    return os.path.isfile(info_file)
diff --git a/python/tvm/contrib/msc/core/utils/file.py
b/python/tvm/contrib/msc/core/utils/file.py
index 2e049b9ee6..278d9d56b9 100644
--- a/python/tvm/contrib/msc/core/utils/file.py
+++ b/python/tvm/contrib/msc/core/utils/file.py
@@ -20,13 +20,42 @@ import os
import shutil
import tempfile
import types
+from functools import partial
from typing import List
from importlib.machinery import SourceFileLoader
-from .namespace import MSCFramework
+from .namespace import MSCMap, MSCKey, MSCFramework
from .register import get_registered_func
def load_callable(name: str, framework: str = MSCFramework.MSC) -> callable:
    """Load a callable object.

    Parameters
    ----------
    name: string
        The name of the registered func or path:f_name str.
    framework: string
        Should be from MSCFramework.

    Returns
    -------
    func: callable
        The function.
    """

    func = get_registered_func(name, framework)
    if func:
        return func
    if ".py:" in name:
        # split on the last ':' so paths containing ':' (e.g. Windows drives) survive
        path, func_name = name.rsplit(":", 1)
        loader = SourceFileLoader(path.replace(".py", ""), path)
        mod = types.ModuleType(loader.name)
        loader.exec_module(mod)
        return getattr(mod, func_name)
    # BUGFIX: the placeholder was never filled with the offending name (and fix typo)
    raise Exception("Func {} is neither registered nor a path.py:name string".format(name))
+
+
class MSCDirectory(object):
"""Create a directory manager for MSC"""
@@ -152,7 +181,7 @@ class MSCDirectory(object):
os.remove(dir_path)
return self.__class__(dir_path, keep_history=keep_history,
cleanup=cleanup)
- def relpath(self, name: str) -> str:
+ def relpath(self, name: str, keep_history: bool = True) -> str:
"""Relative path in dir
Parameters
@@ -166,7 +195,12 @@ class MSCDirectory(object):
The concatenated path.
"""
- return os.path.join(self._path, name)
+ f_path = os.path.join(self._path, name)
+ if os.path.isfile(f_path) and not keep_history:
+ os.remove(f_path)
+ if os.path.isdir(f_path) and not keep_history:
+ shutil.rmtree(f_path)
+ return f_path
def listdir(self) -> List[str]:
"""List contents in the dir.
@@ -211,29 +245,65 @@ def msc_dir(path: str = None, keep_history: bool = True,
cleanup: bool = False)
return MSCDirectory(path, keep_history, cleanup)
-def load_callable(name: str, framework: str = MSCFramework.MSC) -> callable:
- """Load a callable object.
+def set_workspace(
+ path: str = None, keep_history: bool = True, cleanup: bool = False
+) -> MSCDirectory:
+ """Create MSCDirectory as worksapce and set to map
Parameters
----------
- name: string
- The name of the registered func or path:f_name str.
- framework: string
- Should be from MSCFramework.
+ path: str
+ The path of the dir.
+ keep_history: bool
+ Whether to remove files before start.
+ cleanup: bool
+ Whether to clean up before exit.
Returns
-------
- func: callable
- The function.
+ dir: MSCDirectory
+ The created dir.
"""
- func = get_registered_func(name, framework)
- if func:
- return func
- if ".py:" in name:
- path, func_name = name.split(":")
- loader = SourceFileLoader(path.replace(".py", ""), path)
- mod = types.ModuleType(loader.name)
- loader.exec_module(mod)
- return getattr(mod, func_name)
- raise Exception("Func {} is neighter registered nor path.py:name string")
+ path = path or "msc_workspace"
+ workspace = MSCDirectory(path, keep_history, cleanup)
+ MSCMap.set(MSCKey.WORKSPACE, workspace)
+ return workspace
+
+
def get_workspace() -> MSCDirectory:
    """Get workspace from MSCMap

    Returns
    -------
    dir: MSCDirectory
        The workspace dir.
    """

    workspace = MSCMap.get(MSCKey.WORKSPACE)
    # set_workspace must be called first; there is no implicit default here
    assert workspace, "Can not find workspace, please call set_workspace"
    return workspace
+
+
def get_workspace_subdir(name: str = None) -> MSCDirectory:
    """Create sub dir for workspace

    Parameters
    ----------
    name: str
        The sub dir name under workspace.

    Returns
    -------
    dir: MSCDirectory
        The created dir.
    """

    workspace = get_workspace()
    return workspace.create_dir(name)
+
+
+get_build_dir = partial(get_workspace_subdir, name="Build")
+get_output_dir = partial(get_workspace_subdir, name="Output")
+get_dataset_dir = partial(get_workspace_subdir, name="Dataset")
+get_debug_dir = partial(get_workspace_subdir, name="Debug")
+get_cache_dir = partial(get_workspace_subdir, name="Cache")
diff --git a/python/tvm/contrib/msc/core/utils/info.py
b/python/tvm/contrib/msc/core/utils/info.py
index b2ea97d2e5..894447b169 100644
--- a/python/tvm/contrib/msc/core/utils/info.py
+++ b/python/tvm/contrib/msc/core/utils/info.py
@@ -18,13 +18,133 @@
import os
import json
-from typing import List
+import copy
+from typing import List, Tuple, Dict
from distutils.version import LooseVersion
+import numpy as np
import tvm
from .namespace import MSCFramework
class MSCArray(object):
    """MSC wrapper for array like object

    Parameters
    ----------
    data: array_like: np.ndarray| torch.Tensor| tvm.ndarray| ...
        The data object.
    """

    def __init__(self, data: object):
        self._type, self._data = self._analysis(data)

    def __str__(self):
        return "<{}>{}".format(self._type, self.abstract())

    def _analysis(self, data: object) -> Tuple[str, np.ndarray]:
        """Detect the source framework and return (framework_tag, numpy_view)"""
        if isinstance(data, np.ndarray):
            return "np", data
        if isinstance(data, tvm.runtime.NDArray):
            # NOTE(review): asnumpy() is the legacy alias of numpy(); kept for compatibility
            return "tvm", data.asnumpy()
        try:
            import torch  # pylint: disable=import-outside-toplevel

            if isinstance(data, torch.Tensor):
                return "torch", data.detach().cpu().numpy()
        except:  # pylint: disable=bare-except
            # torch is optional; fall through to the generic failure below
            pass

        # BUGFIX: fixed garbled error message ("Unkonwn" -> "Unknown")
        raise Exception("Unknown data {}({})".format(data, type(data)))

    def abstract(self) -> str:
        """Get abstract describe of the data"""

        return "[S:{},D:{}] Max {:g}, Min {:g}, Avg {:g}".format(
            ";".join([str(s) for s in self._data.shape]),
            self._data.dtype.name,
            self._data.max(),
            self._data.min(),
            self._data.sum() / self._data.size,
        )

    @property
    def type(self):
        # framework tag: "np" | "tvm" | "torch"
        return self._type

    @property
    def data(self):
        # the numpy view of the wrapped data
        return self._data
+
+
def cast_array(data: object):
    """Cast array like object to np.ndarray

    Parameters
    ----------
    data: array_like: np.ndarray| torch.Tensor| tvm.ndarray| ...
        The data object.

    Returns
    -------
    output: np.ndarray
        The output as numpy array.
    """

    wrapper = MSCArray(data)
    return wrapper.data
+
+
def compare_arrays(
    golden: Dict[str, np.ndarray],
    datas: Dict[str, np.ndarray],
    atol: float = 1e-2,
    rtol: float = 1e-2,
) -> dict:
    """Compare elements in array

    Parameters
    ----------
    golden: dict<str, np.ndarray>
        The golden datas.
    datas: dict<str, np.ndarray>
        The datas to be compared.
    atol: float
        The atol for compare.
    rtol: float
        The rtol for compare.

    Returns
    -------
    report: dict
        The compare results.
    """

    assert golden.keys() == datas.keys(), "golden {} and datas {} mismatch".format(
        golden.keys(), datas.keys()
    )
    # report counts every entry; "info" records a <Pass>/<Fail> line per name
    report = {"total": 0, "passed": 0, "info": {}}
    for name, gol in golden.items():
        report["total"] += 1
        data = datas[name]
        # shape and dtype must match exactly before any numeric compare
        if list(gol.shape) != list(data.shape):
            report["info"][name] = "<Fail> shape mismatch [G]{} vs [D]{}".format(
                gol.shape, data.shape
            )
            continue
        if gol.dtype != data.dtype:
            report["info"][name] = "<Fail> dtype mismatch [G]{} vs [D]{}".format(
                gol.dtype, data.dtype
            )
            continue
        # elementwise difference summary goes in the report either way
        diff = MSCArray(gol - data)
        try:
            np.testing.assert_allclose(gol, data, rtol=rtol, atol=atol, verbose=False)
            report["info"][name] = "<Pass> diff {}".format(diff.abstract())
            report["passed"] += 1
        except:  # pylint: disable=bare-except
            report["info"][name] = "<Fail> diff {}".format(diff.abstract())
    return report
+
+
def load_dict(str_dict: str, flavor: str = "json") -> dict:
"""Load the string/file to dict.
@@ -70,6 +190,31 @@ def dump_dict(dict_obj: dict, flavor: str = "dmlc") -> str:
return ""
if flavor == "dmlc":
return json.dumps({k: int(v) if isinstance(v, bool) else v for k, v in
dict_obj.items()})
+ if flavor == "table":
+
+ def _get_lines(value, indent=0):
+ lines = []
+ for k, v in value.items():
+ if isinstance(v, dict):
+ lines.append("{}{}:".format(indent * " ", k))
+ lines.extend(_get_lines(v, indent + 2))
+ elif isinstance(v, (tuple, list)) and len(str(v)) > 100:
+ lines.append("{}{}:".format(indent * " ", k))
+ lines.extend(
+ [
+ "{}<{}>{}".format((indent + 2) * " ", idx, ele)
+ for idx, ele in enumerate(v)
+ ]
+ )
+ elif isinstance(v, bool):
+ lines.append("{}{}: {}".format(indent * " ", k, "true" if
v else "false"))
+ elif isinstance(v, np.ndarray):
+ lines.append("{}{}: {}".format(indent * " ", k,
MSCArray(v).abstract()))
+ else:
+ lines.append("{}{}: {}".format(indent * " ", k, v))
+ return lines
+
+ return "\n".join(_get_lines(dict_obj))
return json.dumps(dict_obj)
@@ -103,6 +248,25 @@ def dict_equal(dict_a: dict, dict_b: dict) -> bool:
return True
def copy_dict(dict_obj: dict) -> dict:
    """Deepcopy dict object

    Parameters
    ----------
    dict_obj: dict
        The source dict.

    Returns
    -------
    dict_obj: dict
        The copied dict.
    """

    # None/empty input normalizes to a fresh empty dict
    if dict_obj:
        return copy.deepcopy(dict_obj)
    return {}
+
+
def get_version(framework: str) -> List[int]:
"""Get the version list of framework.
diff --git a/python/tvm/contrib/msc/core/utils/log.py
b/python/tvm/contrib/msc/core/utils/log.py
new file mode 100644
index 0000000000..525f1706f9
--- /dev/null
+++ b/python/tvm/contrib/msc/core/utils/log.py
@@ -0,0 +1,132 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""tvm.contrib.msc.core.utils.log"""
+
+import os
+import logging
+from typing import Union
+
+from .file import get_workspace
+from .namespace import MSCMap, MSCKey
+
+
class IOLogger(object):
    """IO Logger for MSC"""

    def __init__(self):
        # one print-wrapper per ANSI color code, keyed by color name
        def _colored(code):
            return lambda m: print("\033[{}m {}\033[00m".format(code, m))

        self._printers = {
            "red": _colored(91),
            "green": _colored(92),
            "yellow": _colored(93),
            "purple": _colored(95),
            "cyan": _colored(96),
            "gray": _colored(97),
            "black": _colored(98),
        }

    def info(self, msg):
        self._printers["green"]("[MSC_INFO] " + str(msg))

    def debug(self, msg):
        self._printers["green"]("[MSC_DEBUG] " + str(msg))

    def warning(self, msg):
        self._printers["yellow"]("[MSC_WARNING] " + str(msg))

    def error(self, msg):
        # print the message, then fail hard — errors are not recoverable here
        self._printers["red"]("[MSC_ERROR] " + str(msg))
        raise Exception(msg)
+
+
def create_file_logger(level=logging.INFO, path: str = None) -> logging.Logger:
    """Create file logger

    Parameters
    ----------
    level: logging level
        The logging level.
    path: str
        The file path.

    Returns
    -------
    logger: logging.Logger
        The logger.
    """

    log_path = path or os.path.join(get_workspace(), "MSC_LOG")
    logger = logging.getLogger(os.path.basename(log_path))
    logger.setLevel(level)
    # already configured for this file -> reuse instead of stacking handlers
    for handler in logger.handlers:
        if isinstance(handler, logging.FileHandler) and handler.baseFilename == log_path:
            return logger
    formatter = logging.Formatter(
        "%(asctime)s %(filename)s[ln:%(lineno)d]<%(levelname)s> %(message)s"
    )
    file_handler = logging.FileHandler(log_path, mode="a", encoding=None, delay=False)
    stream_handler = logging.StreamHandler()
    for handler in (file_handler, stream_handler):
        handler.setLevel(level)
        handler.setFormatter(formatter)
        logger.addHandler(handler)
    return logger
+
+
def set_global_logger(level: Union[str, int] = logging.INFO, path: str = None) -> logging.Logger:
    """Create file logger and set to global

    Parameters
    ----------
    level: logging level
        The logging level, int or one of "debug"| "info"| "warn".
    path: str
        The file path.

    Returns
    -------
    logger: logging.Logger
        The logger.
    """

    if isinstance(level, str):
        verbose_map = {"debug": logging.DEBUG, "info": logging.INFO, "warn": logging.WARN}
        if level not in verbose_map:
            # BUGFIX: the placeholder was never formatted with the offending value
            raise Exception("Unexpected verbose {}, should be debug| info| warn".format(level))
        level = verbose_map[level]
    logger = create_file_logger(level, path)
    MSCMap.set(MSCKey.GLOBALE_LOGGER, logger)
    return logger
+
+
def get_global_logger() -> logging.Logger:
    """Get the global logger

    Returns
    -------
    logger: logging.Logger
        The logger.
    """

    # fall back to the console IOLogger until set_global_logger is called
    logger = MSCMap.get(MSCKey.GLOBALE_LOGGER)
    if not logger:
        logger = IOLogger()
        MSCMap.set(MSCKey.GLOBALE_LOGGER, logger)
    return MSCMap.get(MSCKey.GLOBALE_LOGGER)
diff --git a/python/tvm/contrib/msc/core/utils/message.py
b/python/tvm/contrib/msc/core/utils/message.py
new file mode 100644
index 0000000000..b3508f0e7c
--- /dev/null
+++ b/python/tvm/contrib/msc/core/utils/message.py
@@ -0,0 +1,133 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""tvm.contrib.msc.core.utils.message"""
+
+import datetime
+import logging
+
+from .info import dump_dict
+from .log import get_global_logger
+from .namespace import MSCMap, MSCKey
+
+
def time_stamp(
    stage: str, mark_stage: bool = False, log_stage: bool = True, logger: logging.Logger = None
):
    """Mark the stamp and record time.

    Parameters
    ----------
    stage: str
        The stage name.
    mark_stage: bool
        Whether to mark the stage.
    log_stage: bool
        Whether to log the stage
    logger: logging.Logger
        The logger.
    """

    logger = logger or get_global_logger()
    # append (stage, now) to the global stamp list used by get_duration
    time_stamps = MSCMap.get(MSCKey.TIME_STAMPS, [])
    time_stamps.append((stage, datetime.datetime.now()))
    MSCMap.set(MSCKey.TIME_STAMPS, time_stamps)
    if log_stage:
        if mark_stage:
            # close the banner of the previous marked stage, if any
            last_stage = MSCMap.get(MSCKey.MSC_STAGE)
            if last_stage:
                end_msg = "[MSC] End {}".format(last_stage)
                logger.info("\n{0} {1} {0}\n".format("#" * 20, end_msg.center(40)))
            start_msg = "[MSC] Start {}".format(stage)
            logger.info("\n{0} {1} {0}".format("#" * 20, start_msg.center(40)))
            MSCMap.set(MSCKey.MSC_STAGE, stage)
        else:
            logger.debug("Start {}".format(stage))
+
+
def get_duration() -> dict:
    """Get duration of the whole process.

    Returns
    -------
    duration: dict
        The duration of the process.
    """

    time_stamps = MSCMap.get(MSCKey.TIME_STAMPS, [])
    if not time_stamps:
        return {}

    # seconds elapsed between two stamps recorded by time_stamp()
    def _get_duration(start_idx, end_idx):
        return (time_stamps[end_idx][1] - time_stamps[start_idx][1]).total_seconds()

    total = _get_duration(0, -1)
    duration = {"total": total}
    # each stage's duration is the gap to the following stamp
    for idx in range(len(time_stamps) - 1):
        duration[time_stamps[idx][0]] = _get_duration(idx, idx + 1)
    # fold "main.sub" stages into per-main nested dicts with their own totals
    sub_durations = {}
    for stage, _ in time_stamps:
        if stage not in duration:
            continue
        if "." in stage:
            main_stage = stage.split(".")[0]
            if main_stage not in sub_durations:
                sub_durations[main_stage] = {"total": 0}
            if main_stage in duration and "init" not in sub_durations[main_stage]:
                sub_durations[main_stage]["init"] = duration[main_stage]
                sub_durations[main_stage]["total"] += duration[main_stage]
            sub_duration = duration.pop(stage)
            sub_durations[main_stage][stage.replace(main_stage + ".", "")] = sub_duration
            sub_durations[main_stage]["total"] += sub_duration

    # change to report format
    def _to_str(dur):
        return "{:.2f} s({:.2f}%)".format(dur, dur * 100 / total)

    for sub_dur in sub_durations.values():
        for stage in sub_dur:
            sub_dur[stage] = _to_str(sub_dur[stage])
    for stage in duration:
        duration[stage] = _to_str(duration[stage])
    duration.update(sub_durations)
    return duration
+
+
def msg_block(title: str, msg: str):
    """Log message in block format

    Parameters
    ----------
    title: str
        The title of the block
    msg: str
        The message to log.

    Returns
    -------
    msg: str
        The block message.
    """

    # dicts are rendered as an indented table before framing
    body = dump_dict(msg, "table") if isinstance(msg, dict) else msg
    head_mark, tail_mark = ">" * 20, "<" * 20
    centered = title.center(40)
    return "\n{0} {1} {0}\n{2}\n{3} {1} {3}".format(head_mark, centered, body, tail_mark)
+
+
def current_stage():
    """Get the current stage"""

    stage = MSCMap.get(MSCKey.MSC_STAGE, "Unknown")
    return stage
diff --git a/python/tvm/contrib/msc/core/utils/namespace.py
b/python/tvm/contrib/msc/core/utils/namespace.py
index e9d72f1a70..dcf20bef1b 100644
--- a/python/tvm/contrib/msc/core/utils/namespace.py
+++ b/python/tvm/contrib/msc/core/utils/namespace.py
@@ -53,9 +53,13 @@ class MSCKey:
WORKSPACE = "workspace"
VERBOSE = "verbose"
+ GLOBALE_LOGGER = "global_logger"
REGISTERED_FUNCS = "registered_funcs"
REGISTERED_TOOLS = "registered_tools"
+ MSC_STAGE = "msc_stage"
+ TIME_STAMPS = "time_stamps"
+
class MSCFramework:
"""Framework type for the MSC"""
diff --git a/python/tvm/contrib/msc/framework/__init__.py
b/python/tvm/contrib/msc/framework/__init__.py
index 17c974aab9..fcdf0c886c 100644
--- a/python/tvm/contrib/msc/framework/__init__.py
+++ b/python/tvm/contrib/msc/framework/__init__.py
@@ -14,4 +14,4 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""tvm.contrib.msc.framework"""
+"""tvm.contrib.msc.framework"""
diff --git a/python/tvm/contrib/msc/framework/tensorflow/codegen/codegen.py
b/python/tvm/contrib/msc/framework/tensorflow/codegen/codegen.py
index 5fd67735f5..4555d23528 100644
--- a/python/tvm/contrib/msc/framework/tensorflow/codegen/codegen.py
+++ b/python/tvm/contrib/msc/framework/tensorflow/codegen/codegen.py
@@ -16,7 +16,7 @@
# under the License.
"""tvm.contrib.msc.framework.tensorflow.codegen.codegen"""
-from typing import Dict, Optional, Union, List
+from typing import Dict, Optional
import tvm
from tvm.contrib.msc.core.ir import MSCGraph
@@ -54,16 +54,13 @@ def to_tensorflow(
The tensorflow Graph.
"""
- def _bind_weights(
- outs: Union[tf_v1.Tensor, List[tf_v1.Tensor]], folder:
msc_utils.MSCDirectory
- ) -> Union[tf_v1.Tensor, List[tf_v1.Tensor]]:
+ def _save_weights(folder: msc_utils.MSCDirectory):
if weights:
with open(folder.relpath(graph.name + "_params.bin"), "wb") as
f_params:
f_params.write(tvm.runtime.save_param_dict(weights))
- return outs
inputs = [tf_v1.placeholder(i.dtype_name, i.get_shape(), i.alias) for i in
graph.get_inputs()]
codegen = CodeGen(
graph, _ffi_api.GetTensorflowSources, codegen_config, print_config,
build_folder
)
- return codegen.load(inputs + [weights], post_load=_bind_weights)
+ return codegen.load(inputs + [weights], pre_load=_save_weights)
diff --git a/python/tvm/contrib/msc/framework/tensorflow/frontend/__init__.py
b/python/tvm/contrib/msc/framework/tensorflow/frontend/__init__.py
index 557020a80e..419b4f8fef 100644
--- a/python/tvm/contrib/msc/framework/tensorflow/frontend/__init__.py
+++ b/python/tvm/contrib/msc/framework/tensorflow/frontend/__init__.py
@@ -15,3 +15,5 @@
# specific language governing permissions and limitations
# under the License.
"""tvm.contrib.msc.framework.tensorflow.frontend"""
+
+from .translate import *
diff --git a/python/tvm/contrib/msc/framework/tensorflow/frontend/translate.py
b/python/tvm/contrib/msc/framework/tensorflow/frontend/translate.py
index 8ccf8820aa..dc97a315d0 100644
--- a/python/tvm/contrib/msc/framework/tensorflow/frontend/translate.py
+++ b/python/tvm/contrib/msc/framework/tensorflow/frontend/translate.py
@@ -16,13 +16,13 @@
# under the License.
"""tvm.contrib.msc.framework.torch.frontend.translate"""
-from typing import Dict, Optional, Tuple, List
+from typing import Dict, Optional, Tuple, List, Union
import tvm
from tvm.contrib.msc.core.ir.graph import MSCGraph
from tvm.contrib.msc.core import transform as msc_transform
-from tvm.contrib.msc.core.ir.translate import from_relax
+from tvm.contrib.msc.core.frontend import from_relax
from tvm.contrib.msc.core.codegen import relay_to_relax
from tvm.contrib.msc.framework.tensorflow import tf_v1
@@ -35,7 +35,8 @@ def from_tensorflow(
trans_config: Optional[Dict[str, str]] = None,
build_config: Optional[Dict[str, str]] = None,
opt_config: Optional[Dict[str, str]] = None,
-) -> Tuple[MSCGraph, Dict[str, tvm.nd.array]]:
+ as_msc: bool = True,
+) -> Tuple[Union[MSCGraph, tvm.IRModule], Dict[str, tvm.nd.array]]:
"""Change tensorflow GraphDef to MSCGraph.
Parameters
@@ -54,11 +55,13 @@ def from_tensorflow(
The config for build MSCGraph.
opt_config: dict
The config for optimize the relay before translate.
+ as_msc: bool
+ Set to True to return MSCGraph, otherwise return the relax mod
Returns
-------
- graph: tvm.contrib.msc.core.ir.MSCGraph
- The translated graph.
+ graph/mod: tvm.contrib.msc.core.ir.MSCGraph/tvm.IRModule
+ The translated graph/IRModule.
weights: dict of <string:tvm.ndarray>
The weights from the IRModule.
"""
@@ -70,6 +73,8 @@ def from_tensorflow(
passes = [msc_transform.BindExprName()]
relay_mod = tvm.transform.Sequential(passes)(relay_mod)
relax_mod = relay_to_relax(relay_mod, params, trans_config, build_config,
opt_config)
+ if not as_msc:
+ return relax_mod, params
build_config = build_config or {}
build_config["use_var_name"] = True
graph, weights = from_relax(relax_mod, trans_config=trans_config,
build_config=build_config)
diff --git a/python/tvm/contrib/msc/framework/tensorrt/codegen/codegen.py
b/python/tvm/contrib/msc/framework/tensorrt/codegen/codegen.py
index 574c2cc31b..be233d9465 100644
--- a/python/tvm/contrib/msc/framework/tensorrt/codegen/codegen.py
+++ b/python/tvm/contrib/msc/framework/tensorrt/codegen/codegen.py
@@ -149,11 +149,11 @@ def to_tensorrt(
"""
target_options = {}
- for name, graph, weights in graph_infos:
+ for graph, weights in graph_infos:
options = to_sub_tensorrt(
graph, weights, codegen_config, print_config, build_folder,
output_folder
)
- target_options[name] = msc_utils.dump_dict(options)
+ target_options[graph.name] = msc_utils.dump_dict(options)
mod = tvm.transform.Sequential(
[
tvm.relax.transform.RunCodegen({"msc_tensorrt": target_options}),
diff --git a/python/tvm/contrib/msc/framework/tensorrt/frontend/__init__.py
b/python/tvm/contrib/msc/framework/tensorrt/frontend/__init__.py
index 85b163d7c6..f917195115 100644
--- a/python/tvm/contrib/msc/framework/tensorrt/frontend/__init__.py
+++ b/python/tvm/contrib/msc/framework/tensorrt/frontend/__init__.py
@@ -15,3 +15,5 @@
# specific language governing permissions and limitations
# under the License.
"""tvm.contrib.msc.framework.tensorrt.frontend"""
+
+from .translate import *
diff --git a/python/tvm/contrib/msc/framework/tensorrt/frontend/translate.py
b/python/tvm/contrib/msc/framework/tensorrt/frontend/translate.py
index 845a661396..a165a106ec 100644
--- a/python/tvm/contrib/msc/framework/tensorrt/frontend/translate.py
+++ b/python/tvm/contrib/msc/framework/tensorrt/frontend/translate.py
@@ -21,7 +21,8 @@ from typing import Dict, Optional, Tuple, List
import tvm
from tvm import relax
from tvm.contrib.msc.core import transform as msc_transform
-from tvm.contrib.msc.core.ir import MSCGraph, byoc_partition
+from tvm.contrib.msc.core.ir import MSCGraph
+from tvm.contrib.msc.core.frontend import byoc_partition
from tvm.contrib.msc.framework.tensorrt import transform as trt_transform
@@ -30,8 +31,7 @@ def partition_for_tensorrt(
params: Optional[Dict[str, tvm.nd.array]] = None,
trans_config: Optional[Dict[str, str]] = None,
build_config: Optional[Dict[str, str]] = None,
- allow_incomplete: bool = True,
-) -> Tuple[tvm.IRModule, List[Tuple[str, MSCGraph, Dict[str, tvm.nd.array]]]]:
+) -> Tuple[tvm.IRModule, List[Tuple[MSCGraph, Dict[str, tvm.nd.array]]]]:
"""Partition module to tensorrt sub functions.
Parameters
@@ -44,15 +44,13 @@ def partition_for_tensorrt(
The parameters of the IRModule.
build_config: dict
The config for build MSCGraph.
- allow_incomplete: bool
- Whether allow some ops not on tensorrt
Returns
-------
mod: IRModule
The IRModule of partitioned relax.
- graphs_info: list<<str, MSCGraph, weights>>
- The func <name, MSCGraph and weights> list, each element for a sub
graph.
+ graphs_info: list<<MSCGraph, weights>>
+ The func <MSCGraph and weights> list, each element for a sub graph.
"""
trans_config = trans_config or {}
@@ -63,4 +61,4 @@ def partition_for_tensorrt(
relax.transform.FoldConstant(),
]
)(mod)
- return byoc_partition("msc_tensorrt", mod, params, trans_config,
build_config, allow_incomplete)
+ return byoc_partition("msc_tensorrt", mod, params, trans_config,
build_config)
diff --git a/python/tvm/contrib/msc/framework/torch/codegen/codegen.py
b/python/tvm/contrib/msc/framework/torch/codegen/codegen.py
index 6bfe86056e..f885c81aa6 100644
--- a/python/tvm/contrib/msc/framework/torch/codegen/codegen.py
+++ b/python/tvm/contrib/msc/framework/torch/codegen/codegen.py
@@ -54,7 +54,7 @@ def to_torch(
The torch.nn.Module.
"""
- def _bind_weights(model: torch.nn.Module, folder: msc_utils.MSCDirectory)
-> torch.nn.Module:
+ def _save_weights(folder: msc_utils.MSCDirectory):
if weights:
state_dict = {}
for name, data in weights.items():
@@ -64,9 +64,13 @@ def to_torch(
w_tensor = graph.find_tensor(name)
w_name = w_tensor.alias or name
state_dict[w_name] = torch.from_numpy(data.asnumpy())
- model.load_state_dict(state_dict)
torch.save(state_dict, folder.relpath(graph.name + ".pth"))
+
+ def _bind_weights(model: torch.nn.Module, folder: msc_utils.MSCDirectory)
-> torch.nn.Module:
+ if weights:
+ state_dict = torch.load(folder.relpath(graph.name + ".pth"))
+ model.load_state_dict(state_dict)
return model
codegen = CodeGen(graph, _ffi_api.GetTorchSources, codegen_config,
print_config, build_folder)
- return codegen.load([], post_load=_bind_weights)
+ return codegen.load([], pre_load=_save_weights, post_load=_bind_weights)
diff --git a/python/tvm/contrib/msc/framework/torch/frontend/__init__.py
b/python/tvm/contrib/msc/framework/torch/frontend/__init__.py
index 84f4bae2f4..5572720a69 100644
--- a/python/tvm/contrib/msc/framework/torch/frontend/__init__.py
+++ b/python/tvm/contrib/msc/framework/torch/frontend/__init__.py
@@ -15,3 +15,5 @@
# specific language governing permissions and limitations
# under the License.
"""tvm.contrib.msc.framework.torch.frontend"""
+
+from .translate import *
diff --git a/python/tvm/contrib/msc/framework/torch/frontend/translate.py
b/python/tvm/contrib/msc/framework/torch/frontend/translate.py
index 67f4b1e7e7..2dce394708 100644
--- a/python/tvm/contrib/msc/framework/torch/frontend/translate.py
+++ b/python/tvm/contrib/msc/framework/torch/frontend/translate.py
@@ -16,7 +16,7 @@
# under the License.
"""tvm.contrib.msc.framework.torch.frontend.translate"""
-from typing import Dict, Optional, Tuple, List
+from typing import Dict, Optional, Tuple, List, Union
import numpy as np
import torch
@@ -24,10 +24,43 @@ import tvm
from tvm.relax.frontend.torch import from_fx
from tvm.contrib.msc.core.ir.graph import MSCGraph
-from tvm.contrib.msc.core.ir.translate import from_relax
+from tvm.contrib.msc.core.frontend import from_relax
from tvm.contrib.msc.core.codegen import relay_to_relax
+def set_weight_alias(graph: MSCGraph) -> MSCGraph:
+ """Set weight with alias in MSCGraph.
+
+ Parameters
+ ----------
+ graph: MSCGraph
+ The graph.
+
+ Returns
+ -------
+ graph: MSCGraph
+ The graph with weight alias.
+ """
+
+ for node in graph.get_nodes():
+ for ref, weight in node.get_weights().items():
+ if node.optype == "constant":
+ alias = node.name.replace(".", "_")
+ elif node.optype in ("nn.batch_norm", "nn.layer_norm",
"nn.group_norm"):
+ if ref == "gamma":
+ alias = node.name.replace(".", "_") + ".weight"
+ elif ref == "beta":
+ alias = node.name.replace(".", "_") + ".bias"
+ elif ref == "mean":
+ alias = node.name.replace(".", "_") + ".running_mean"
+ elif ref == "var":
+ alias = node.name.replace(".", "_") + ".running_var"
+ else:
+ alias = node.name.replace(".", "_") + "." + ref
+ weight.set_alias(alias)
+ return graph
+
+
def from_torch(
model: torch.nn.Module,
input_info: List[Tuple[Tuple[int], str]],
@@ -36,7 +69,8 @@ def from_torch(
trans_config: Optional[Dict[str, str]] = None,
build_config: Optional[Dict[str, str]] = None,
opt_config: Optional[Dict[str, str]] = None,
-) -> Tuple[MSCGraph, Dict[str, tvm.nd.array]]:
+ as_msc: bool = True,
+) -> Tuple[Union[MSCGraph, tvm.IRModule], Dict[str, tvm.nd.array]]:
"""Change torch nn.Module to MSCGraph.
Parameters
@@ -55,11 +89,13 @@ def from_torch(
The config for build MSCGraph.
opt_config: dict
The config for optimize the relay before translate.
+ as_msc: bool
+ Set to True to return MSCGraph, otherwise return the relax mod
Returns
-------
- graph: tvm.contrib.msc.core.ir.MSCGraph
- The translated graph.
+ graph/mod: tvm.contrib.msc.core.ir.MSCGraph/tvm.IRModule
+ The translated graph/IRModule.
weights: dict of <string:tvm.ndarray>
The weights from the IRModule.
"""
@@ -82,22 +118,7 @@ def from_torch(
shape_list = [("input" + str(idx), i_info) for idx, i_info in
enumerate(input_info)]
relay_mod, params = tvm.relay.frontend.from_pytorch(scripted_model,
shape_list)
relax_mod = relay_to_relax(relay_mod, params, trans_config,
build_config, opt_config)
+ if not as_msc:
+ return relax_mod, params
graph, weights = from_relax(relax_mod, trans_config=trans_config,
build_config=build_config)
- # set alias for weights
- for node in graph.get_nodes():
- for ref, weight in node.get_weights().items():
- if node.optype == "constant":
- alias = node.name.replace(".", "_")
- elif node.optype in ("nn.batch_norm", "nn.layer_norm",
"nn.group_norm"):
- if ref == "gamma":
- alias = node.name.replace(".", "_") + ".weight"
- elif ref == "beta":
- alias = node.name.replace(".", "_") + ".bias"
- elif ref == "mean":
- alias = node.name.replace(".", "_") + ".running_mean"
- elif ref == "var":
- alias = node.name.replace(".", "_") + ".running_var"
- else:
- alias = node.name.replace(".", "_") + "." + ref
- weight.set_alias(alias)
- return graph, weights
+ return set_weight_alias(graph), weights
diff --git a/python/tvm/contrib/msc/framework/tvm/codegen/codegen.py
b/python/tvm/contrib/msc/framework/tvm/codegen/codegen.py
index 0edce88365..c30e05ed98 100644
--- a/python/tvm/contrib/msc/framework/tvm/codegen/codegen.py
+++ b/python/tvm/contrib/msc/framework/tvm/codegen/codegen.py
@@ -59,12 +59,16 @@ def to_relax(
for i in graph.get_inputs()
]
- def _bind_weights(mod: tvm.IRModule, folder: msc_utils.MSCDirectory) ->
tvm.IRModule:
+ def _save_weights(folder: msc_utils.MSCDirectory):
if weights:
- mod = BindParams("main", weights)(mod)
with open(folder.relpath(graph.name + "_params.bin"), "wb") as
f_params:
f_params.write(tvm.runtime.save_param_dict(weights))
+
+ # pylint: disable=unused-argument
+ def _bind_weights(mod: tvm.IRModule, folder: msc_utils.MSCDirectory) ->
tvm.IRModule:
+ if weights:
+ mod = BindParams("main", weights)(mod)
return mod
codegen = CodeGen(graph, _ffi_api.GetRelaxSources, codegen_config,
print_config, build_folder)
- return codegen.load(inputs, post_load=_bind_weights)
+ return codegen.load(inputs, pre_load=_save_weights,
post_load=_bind_weights)
diff --git a/python/tvm/contrib/msc/framework/__init__.py
b/python/tvm/contrib/msc/framework/tvm/runtime/__init__.py
similarity index 92%
copy from python/tvm/contrib/msc/framework/__init__.py
copy to python/tvm/contrib/msc/framework/tvm/runtime/__init__.py
index 17c974aab9..73ea8bb06c 100644
--- a/python/tvm/contrib/msc/framework/__init__.py
+++ b/python/tvm/contrib/msc/framework/tvm/runtime/__init__.py
@@ -14,4 +14,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""tvm.contrib.msc.framework"""
+"""tvm.contrib.msc.framework.tvm.runtime"""
+
+from .runner import *
diff --git a/python/tvm/contrib/msc/framework/tvm/runtime/runner.py
b/python/tvm/contrib/msc/framework/tvm/runtime/runner.py
new file mode 100644
index 0000000000..90ba8e4cce
--- /dev/null
+++ b/python/tvm/contrib/msc/framework/tvm/runtime/runner.py
@@ -0,0 +1,129 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""tvm.contrib.msc.framework.tvm.runtime.runner"""
+
+from typing import Dict, List, Union
+import numpy as np
+
+import tvm
+from tvm.contrib.msc.core.runtime import ModelRunner
+from tvm.contrib.msc.core.utils.namespace import MSCFramework
+from tvm.contrib.msc.framework.tvm.codegen import to_relax
+
+
+class TVMRunner(ModelRunner):
+ """Runner of Relax"""
+
+ def _to_runnable(self, model: object, device: str, is_training: bool) ->
object:
+ """Build runnable object
+
+ Parameters
+ -------
+ model: object
+ The meta model.
+ device: str
+ The device for place model
+ is_training: bool
+ Whether to load model for training
+
+ Returns
+ -------
+ runnable: object
+ The runnable
+ """
+
+ if "builder" in self._load_config:
+ builder, build_config = self._load_config["builder"]
+ runnable = builder(model, **build_config)
+ self._logger.info(
+ "Model({}) processed by customize builder {}({})".format(
+ self.framework, builder, build_config
+ )
+ )
+ else:
+ model = tvm.relax.transform.LegalizeOps()(model)
+ if device == "cpu":
+ target = tvm.target.Target("llvm")
+ with tvm.transform.PassContext(opt_level=3):
+ relax_exec = tvm.relax.build(model, target)
+ runnable = tvm.relax.VirtualMachine(relax_exec, tvm.cpu())
+ elif device.startswith("cuda"):
+ target = tvm.target.Target("cuda")
+ with target:
+ model = tvm.tir.transform.DefaultGPUSchedule()(model)
+ with tvm.transform.PassContext(opt_level=3):
+ relax_exec = tvm.relax.build(model, target)
+ runnable = tvm.relax.VirtualMachine(relax_exec, tvm.cuda())
+ else:
+ raise NotImplementedError("Unsupported device " + str(device))
+ return runnable
+
+ def _call_runnable(
+ self, runnable: tvm.relax.VirtualMachine, inputs: Dict[str,
np.ndarray], device: str
+ ) -> Union[List[np.ndarray], Dict[str, np.ndarray]]:
+ """Call the runnable to get outputs
+
+ Parameters
+ -------
+ runnable: tvm.relax.VirtualMachine
+ The virtual machine.
+ inputs: dict<str, data>
+ The inputs in dict.
+ device: str
+ The device.
+
+ Returns
+ -------
+ outputs: list<data>
+ The outputs in list.
+ """
+
+ model_inputs = self.get_inputs()
+ if device == "cpu":
+ tvm_inputs = [tvm.nd.array(inputs[i["name"]]) for i in
model_inputs]
+ elif device.startswith("cuda"):
+ dev_id = int(device.split(":")[1]) if ":" in device else 0
+ tvm_inputs = [
+ tvm.nd.array(inputs[i["name"]], device=tvm.cuda(dev_id)) for i
in model_inputs
+ ]
+ else:
+ raise NotImplementedError("Unsupported device " + str(device))
+ return runnable["main"](*tvm_inputs)
+
+ def _device_enabled(self, device: str) -> bool:
+ """Check if the device is enabled
+
+ Returns
+ -------
+ enabled: bool
+ Whether the device is enabled.
+ """
+
+ if device == "cpu":
+ return True
+ if device.startswith("cuda"):
+ dev_id = int(device.split(":")[1]) if ":" in device else 0
+ return tvm.cuda(dev_id).exist
+ return False
+
+ @property
+ def codegen_func(self):
+ return to_relax
+
+ @property
+ def framework(self):
+ return MSCFramework.TVM
diff --git a/tests/python/contrib/test_msc/test_runner.py
b/tests/python/contrib/test_msc/test_runner.py
new file mode 100644
index 0000000000..4653a2225b
--- /dev/null
+++ b/tests/python/contrib/test_msc/test_runner.py
@@ -0,0 +1,89 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+""" Test Runners in MSC. """
+
+import pytest
+import numpy as np
+
+import torch
+from torch import fx
+
+import tvm.testing
+from tvm.relax.frontend.torch import from_fx
+from tvm.contrib.msc.framework.tvm.runtime import TVMRunner
+from tvm.contrib.msc.core import utils as msc_utils
+
+requires_tensorrt = pytest.mark.skipif(
+ tvm.get_global_func("relax.ext.tensorrt", True) is None,
+ reason="TENSORRT is not enabled",
+)
+
+
+def _get_torch_model(name, is_training=False):
+ """Get model from torch vision"""
+ # pylint: disable=import-outside-toplevel
+ try:
+ import torchvision
+
+ model = getattr(torchvision.models, name)(pretrained=True)
+ if is_training:
+ model = model.train()
+ else:
+ model = model.eval()
+ return model
+ except: # pylint: disable=bare-except
+ print("please install torchvision package")
+ return None
+
+
+def _test_from_torch(runner_cls, device, is_training=False, atol=1e-3,
rtol=1e-3):
+ """Test runner from torch model"""
+ torch_model = _get_torch_model("resnet50", is_training)
+ if torch_model:
+ workspace = msc_utils.set_workspace()
+ input_info = [([1, 3, 224, 224], "float32")]
+ datas = [np.random.rand(*i[0]).astype(i[1]) for i in input_info]
+ torch_datas = [torch.from_numpy(d) for d in datas]
+ graph_model = fx.symbolic_trace(torch_model)
+ with torch.no_grad():
+ golden = torch_model(*torch_datas)
+ mod = from_fx(graph_model, input_info)
+ runner = runner_cls(mod, device=device, is_training=is_training)
+ runner.build()
+ outputs = runner.run(datas, ret_type="list")
+ golden = [msc_utils.cast_array(golden)]
+ for gol_r, out_r in zip(golden, outputs):
+ tvm.testing.assert_allclose(gol_r, out_r, atol=atol, rtol=rtol)
+ workspace.destory()
+
+
+def test_tvm_runner_cpu():
+ """Test runner for tvm on cpu"""
+
+ _test_from_torch(TVMRunner, "cpu", is_training=True)
+
+
[email protected]_gpu
+def test_tvm_runner_gpu():
+ """Test runner for tvm on gpu"""
+
+ _test_from_torch(TVMRunner, "cuda", is_training=True)
+
+
+if __name__ == "__main__":
+ tvm.testing.main()
diff --git a/tests/python/contrib/test_msc/test_translate_tensorrt.py
b/tests/python/contrib/test_msc/test_translate_tensorrt.py
index e9981237be..0c142c1d46 100644
--- a/tests/python/contrib/test_msc/test_translate_tensorrt.py
+++ b/tests/python/contrib/test_msc/test_translate_tensorrt.py
@@ -30,7 +30,6 @@ from tvm.contrib.msc.framework.tensorrt.frontend import
translate
from tvm.contrib.msc.framework.tensorrt import codegen
from tvm.contrib.msc.core import utils as msc_utils
-
requires_tensorrt = pytest.mark.skipif(
tvm.get_global_func("relax.ext.tensorrt", True) is None,
reason="TENSORRT is not enabled",
@@ -66,7 +65,9 @@ def verify_model(torch_model, input_info,
allow_incomplete=False):
golden = [golden]
golden = [g.detach().cpu().numpy() for g in golden]
# partition module for tensorrt
- mod, graph_infos = translate.partition_for_tensorrt(mod,
allow_incomplete=allow_incomplete)
+ mod, graph_infos = translate.partition_for_tensorrt(
+ mod, trans_config={"allow_incomplete": allow_incomplete}
+ )
output_folder = msc_utils.msc_dir()
# translate to tensorrt
mod = codegen.to_tensorrt(mod, graph_infos, output_folder=output_folder)