This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 639a6e4f3c [Contrib] Support NDArray cache taking generator (#16693)
639a6e4f3c is described below
commit 639a6e4f3ccaccfa0545113439b2604de0a1dcb6
Author: Ruihang Lai <[email protected]>
AuthorDate: Sun Mar 10 18:38:57 2024 -0400
[Contrib] Support NDArray cache taking generator (#16693)
This PR enhances the `dump_ndarray_cache` function to accept a
generator as input. Previously, it could only take a dictionary.
Sometimes the total size of the NDArrays cannot fit in main CPU
memory. In that case, we can pass a generator instead, so that
NDArray memory can be freed on the fly as each array is dumped.
This PR adds support for dumping the NDArray cache from such generators.
---
python/tvm/contrib/tvmjs.py | 43 ++++++++++++++++++++++++++++++-------------
1 file changed, 30 insertions(+), 13 deletions(-)
diff --git a/python/tvm/contrib/tvmjs.py b/python/tvm/contrib/tvmjs.py
index 4cef868cfd..8d8bd1b051 100644
--- a/python/tvm/contrib/tvmjs.py
+++ b/python/tvm/contrib/tvmjs.py
@@ -17,12 +17,14 @@
"""Namespace to store utilities for building web runtime."""
import hashlib
import json
+import math
import os
import shutil
# pylint: disable=unused-import
import sys
-from typing import Mapping, Union
+from types import GeneratorType
+from typing import Iterator, Mapping, Tuple, Union
import numpy as np
@@ -149,18 +151,25 @@ class NDArrayCacheShardingManager:
def dump_ndarray_cache(
- params: Mapping[str, Union[np.ndarray, tvm.runtime.NDArray]],
+ params: Union[
+ Mapping[str, Union[np.ndarray, tvm.runtime.NDArray]],
+ Iterator[Tuple[str, Union[np.ndarray, tvm.runtime.NDArray]]],
+ ],
cache_dir: str,
encode_format="f32-to-bf16",
meta_data=None,
shard_cap_mb=32,
+ show_progress: bool = True,
):
"""Dump parameters to NDArray cache.
Parameters
----------
- params: Mapping[str, tvm.runtime.NDArray],
- The parameter dictionary
+ params: Union[
+ Mapping[str, Union[np.ndarray, tvm.runtime.NDArray]],
+ Iterator[Tuple[str, Union[np.ndarray, tvm.runtime.NDArray]]],
+ ]
+ The parameter dictionary or generator
cache_dir: str
The path to the cache
@@ -168,18 +177,22 @@ def dump_ndarray_cache(
encode_format: {"f32-to-bf16", "raw"}
Encoding format.
- meta_data: json-compatible-struct
- Extra meta_data to be stored in the cache json file.
+ meta_data: json-compatible-struct or Callable[[], Any]
+ Extra meta_data to be stored in the cache json file,
+ or a callable that returns the metadata.
shard_cap_mb: int
Maxinum number of MB to be kept per shard
+
+ show_progress: bool
+ A boolean indicating if to show the dump progress.
"""
if encode_format not in ("raw", "f32-to-bf16"):
raise ValueError(f"Invalie encode_format {encode_format}")
- meta_data = {} if meta_data is None else meta_data
records = []
- total = len(params)
+ from_generator = isinstance(params, GeneratorType)
+ total_bytes = 0
counter = 0
max_out_length = 0
@@ -193,7 +206,8 @@ def dump_ndarray_cache(
shard_manager = NDArrayCacheShardingManager(cache_dir, "params_shard",
shard_cap_nbytes)
- for k, origin_v in params.items():
+ param_generator = params.items() if not from_generator else params
+ for k, origin_v in param_generator:
shape = list(origin_v.shape)
v = origin_v
if not isinstance(v, np.ndarray):
@@ -201,6 +215,7 @@ def dump_ndarray_cache(
# prefer to preserve original dtype, especially if the format was
bfloat16
dtype = str(origin_v.dtype) if isinstance(origin_v, tvm.nd.NDArray)
else str(v.dtype)
+ total_bytes += math.prod(v.shape) * np.dtype(v.dtype).itemsize
# convert fp32 to bf16
if encode_format == "f32-to-bf16" and dtype == "float32":
@@ -212,12 +227,14 @@ def dump_ndarray_cache(
shard_manager.append(data, name=k, shape=shape, dtype=dtype,
encode_format=encode_format)
counter += 1
- last_cmd = "[%04d/%04d] saving %s" % (counter, total, k)
- flush = "\r" + (" " * max_out_length) + "\r"
- max_out_length = max(len(last_cmd), max_out_length)
- sys.stdout.write(flush + last_cmd)
+ if show_progress:
+ last_cmd = "[%04d] saving %s" % (counter, k)
+ flush = "\r" + (" " * max_out_length) + "\r"
+ max_out_length = max(len(last_cmd), max_out_length)
+ sys.stdout.write(flush + last_cmd)
records = shard_manager.finish()
+ meta_data = {} if meta_data is None else meta_data if not
callable(meta_data) else meta_data()
nd_cache_json = os.path.join(cache_dir, "ndarray-cache.json")