This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new b1df4b0856 [Unity][Web][Fix] Fix fetchNDArray for f32-to-bf16 (#16294)
b1df4b0856 is described below
commit b1df4b085608158b451d4777743d63f9fab3e0e3
Author: Charlie Ruan <[email protected]>
AuthorDate: Tue Jan 2 06:10:59 2024 +0800
[Unity][Web][Fix] Fix fetchNDArray for f32-to-bf16 (#16294)
Currently, when loading the params, we always decode from bf16 to f32,
regardless of each param's dtype, because every param is stored with
"format=f32-to-bf16" in the record whatever its dtype is. This is wrong for
params whose dtype is not float32.
We fix this by also checking the dtype, just like the C++ counterpart:
https://github.com/apache/tvm/blob/4e66690a4d033af912f5051c0e5a16c9c10691d9/src/runtime/relax_vm/ndarray_cache_support.cc#L168-L172
---
web/emcc/wasm_runtime.cc | 4 ++--
web/src/runtime.ts | 2 +-
2 files changed, 3 insertions(+), 3 deletions(-)
diff --git a/web/emcc/wasm_runtime.cc b/web/emcc/wasm_runtime.cc
index 60f40adbf4..311bbd9971 100644
--- a/web/emcc/wasm_runtime.cc
+++ b/web/emcc/wasm_runtime.cc
@@ -126,8 +126,8 @@
TVM_REGISTER_GLOBAL("testing.object_use_count").set_body([](TVMArgs args, TVMRet
*ret = (obj.use_count() - 1);
});
-void ArrayDecodeStorage(NDArray cpu_arr, std::string bytes, std::string
format) {
- if (format == "f32-to-bf16") {
+void ArrayDecodeStorage(NDArray cpu_arr, std::string bytes, std::string
format, std::string dtype) {
+ if (format == "f32-to-bf16" && dtype == "float32") {
std::vector<uint16_t> buffer(bytes.length() / 2);
std::memcpy(buffer.data(), bytes.data(), buffer.size() * 2);
// decode bf16 to f32
diff --git a/web/src/runtime.ts b/web/src/runtime.ts
index f842b2723f..5aa38dee39 100644
--- a/web/src/runtime.ts
+++ b/web/src/runtime.ts
@@ -1556,7 +1556,7 @@ export class Instance implements Disposable {
});
const recSource = buffer.slice(rec.byteOffset, rec.byteOffset +
rec.nbytes);
// first sync copy to cpu.
- this.ctx.arrayDecodeStorage(cpu_arr, new Uint8Array(recSource),
rec.format);
+ this.ctx.arrayDecodeStorage(cpu_arr, new Uint8Array(recSource),
rec.format, rec.dtype);
// then async stream into GPU if needed
if (device.deviceType === DeviceStrToEnum.cpu) {
this.ndarrayCacheUpdate(rec.name, cpu_arr, false);