This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new b1df4b0856 [Unity][Web][Fix] Fix fetchNDArray for f32-to-bf16 (#16294)
b1df4b0856 is described below

commit b1df4b085608158b451d4777743d63f9fab3e0e3
Author: Charlie Ruan <[email protected]>
AuthorDate: Tue Jan 2 06:10:59 2024 +0800

    [Unity][Web][Fix] Fix fetchNDArray for f32-to-bf16 (#16294)
    
    Currently, when loading params, we try to decode them from bf16 to f32
    regardless of each param's dtype, because every param is recorded with
    "format=f32-to-bf16" no matter what its dtype is. That decoding is wrong
    for any param whose dtype is not float32.
    
    We fix this by also checking the dtype, as the C++ counterpart does:
    
https://github.com/apache/tvm/blob/4e66690a4d033af912f5051c0e5a16c9c10691d9/src/runtime/relax_vm/ndarray_cache_support.cc#L168-L172
---
 web/emcc/wasm_runtime.cc | 4 ++--
 web/src/runtime.ts       | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/web/emcc/wasm_runtime.cc b/web/emcc/wasm_runtime.cc
index 60f40adbf4..311bbd9971 100644
--- a/web/emcc/wasm_runtime.cc
+++ b/web/emcc/wasm_runtime.cc
@@ -126,8 +126,8 @@ 
TVM_REGISTER_GLOBAL("testing.object_use_count").set_body([](TVMArgs args, TVMRet
   *ret = (obj.use_count() - 1);
 });
 
-void ArrayDecodeStorage(NDArray cpu_arr, std::string bytes, std::string 
format) {
-  if (format == "f32-to-bf16") {
+void ArrayDecodeStorage(NDArray cpu_arr, std::string bytes, std::string 
format, std::string dtype) {
+  if (format == "f32-to-bf16" && dtype == "float32") {
     std::vector<uint16_t> buffer(bytes.length() / 2);
     std::memcpy(buffer.data(), bytes.data(), buffer.size() * 2);
     // decode bf16 to f32
diff --git a/web/src/runtime.ts b/web/src/runtime.ts
index f842b2723f..5aa38dee39 100644
--- a/web/src/runtime.ts
+++ b/web/src/runtime.ts
@@ -1556,7 +1556,7 @@ export class Instance implements Disposable {
         });
         const recSource = buffer.slice(rec.byteOffset, rec.byteOffset + 
rec.nbytes);
         // first sync copy to cpu.
-        this.ctx.arrayDecodeStorage(cpu_arr, new Uint8Array(recSource), 
rec.format);
+        this.ctx.arrayDecodeStorage(cpu_arr, new Uint8Array(recSource), 
rec.format, rec.dtype);
         // then async stream into GPU if needed
         if (device.deviceType === DeviceStrToEnum.cpu) {
           this.ndarrayCacheUpdate(rec.name, cpu_arr, false);

Reply via email to