This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 639a6e4f3c [Contrib] Support NDArray cache taking generator (#16693)
639a6e4f3c is described below

commit 639a6e4f3ccaccfa0545113439b2604de0a1dcb6
Author: Ruihang Lai <[email protected]>
AuthorDate: Sun Mar 10 18:38:57 2024 -0400

    [Contrib] Support NDArray cache taking generator (#16693)
    
    This PR enhances the `dump_ndarray_cache` function to take a
    generator as input. Previously it could only take a dictionary.
    
    Sometimes, it is possible that the total ndarray size cannot
    fit in the main CPU memory, in which case we may turn to using
    generators so we can free some NDArray memory on the fly.
    This PR supports NDArray cache dumping with generators.
---
 python/tvm/contrib/tvmjs.py | 43 ++++++++++++++++++++++++++++++-------------
 1 file changed, 30 insertions(+), 13 deletions(-)

diff --git a/python/tvm/contrib/tvmjs.py b/python/tvm/contrib/tvmjs.py
index 4cef868cfd..8d8bd1b051 100644
--- a/python/tvm/contrib/tvmjs.py
+++ b/python/tvm/contrib/tvmjs.py
@@ -17,12 +17,14 @@
 """Namespace to store utilities for building web runtime."""
 import hashlib
 import json
+import math
 import os
 import shutil
 
 # pylint: disable=unused-import
 import sys
-from typing import Mapping, Union
+from types import GeneratorType
+from typing import Iterator, Mapping, Tuple, Union
 
 import numpy as np
 
@@ -149,18 +151,25 @@ class NDArrayCacheShardingManager:
 
 
 def dump_ndarray_cache(
-    params: Mapping[str, Union[np.ndarray, tvm.runtime.NDArray]],
+    params: Union[
+        Mapping[str, Union[np.ndarray, tvm.runtime.NDArray]],
+        Iterator[Tuple[str, Union[np.ndarray, tvm.runtime.NDArray]]],
+    ],
     cache_dir: str,
     encode_format="f32-to-bf16",
     meta_data=None,
     shard_cap_mb=32,
+    show_progress: bool = True,
 ):
     """Dump parameters to NDArray cache.
 
     Parameters
     ----------
-    params: Mapping[str, tvm.runtime.NDArray],
-        The parameter dictionary
+    params: Union[
+        Mapping[str, Union[np.ndarray, tvm.runtime.NDArray]],
+        Iterator[Tuple[str, Union[np.ndarray, tvm.runtime.NDArray]]],
+    ]
+        The parameter dictionary or generator
 
     cache_dir: str
         The path to the cache
@@ -168,18 +177,22 @@ def dump_ndarray_cache(
     encode_format: {"f32-to-bf16", "raw"}
         Encoding format.
 
-    meta_data: json-compatible-struct
-        Extra meta_data to be stored in the cache json file.
+    meta_data: json-compatible-struct or Callable[[], Any]
+        Extra meta_data to be stored in the cache json file,
+        or a callable that returns the metadata.
 
     shard_cap_mb: int
         Maximum number of MB to be kept per shard
+
+    show_progress: bool
+        A boolean indicating whether to show the dump progress.
     """
     if encode_format not in ("raw", "f32-to-bf16"):
         raise ValueError(f"Invalid encode_format {encode_format}")
 
-    meta_data = {} if meta_data is None else meta_data
     records = []
-    total = len(params)
+    from_generator = isinstance(params, GeneratorType)
+    total_bytes = 0
     counter = 0
     max_out_length = 0
 
@@ -193,7 +206,8 @@ def dump_ndarray_cache(
 
     shard_manager = NDArrayCacheShardingManager(cache_dir, "params_shard", shard_cap_nbytes)
 
-    for k, origin_v in params.items():
+    param_generator = params.items() if not from_generator else params
+    for k, origin_v in param_generator:
         shape = list(origin_v.shape)
         v = origin_v
         if not isinstance(v, np.ndarray):
@@ -201,6 +215,7 @@ def dump_ndarray_cache(
 
         # prefer to preserve original dtype, especially if the format was bfloat16
         dtype = str(origin_v.dtype) if isinstance(origin_v, tvm.nd.NDArray) else str(v.dtype)
+        total_bytes += math.prod(v.shape) * np.dtype(v.dtype).itemsize
 
         # convert fp32 to bf16
         if encode_format == "f32-to-bf16" and dtype == "float32":
@@ -212,12 +227,14 @@ def dump_ndarray_cache(
         shard_manager.append(data, name=k, shape=shape, dtype=dtype, encode_format=encode_format)
 
         counter += 1
-        last_cmd = "[%04d/%04d] saving %s" % (counter, total, k)
-        flush = "\r" + (" " * max_out_length) + "\r"
-        max_out_length = max(len(last_cmd), max_out_length)
-        sys.stdout.write(flush + last_cmd)
+        if show_progress:
+            last_cmd = "[%04d] saving %s" % (counter, k)
+            flush = "\r" + (" " * max_out_length) + "\r"
+            max_out_length = max(len(last_cmd), max_out_length)
+            sys.stdout.write(flush + last_cmd)
 
     records = shard_manager.finish()
+    meta_data = {} if meta_data is None else meta_data if not callable(meta_data) else meta_data()
 
     nd_cache_json = os.path.join(cache_dir, "ndarray-cache.json")
 

Reply via email to