This is an automated email from the ASF dual-hosted git repository. ruihangl pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push: new 57316dae14 [Web] Support string[] in setPackedFunc() and exceptionally long arrays (#16910) 57316dae14 is described below commit 57316dae1497c36ed57732a7a610018a990f1927 Author: Charlie Ruan <53290280+charliefr...@users.noreply.github.com> AuthorDate: Sun Apr 21 17:40:26 2024 -0400 [Web] Support string[] in setPackedFunc() and exceptionally long arrays (#16910) There are two changes in this PR. #### Change 1: Support `string[]` in `setPackedFunc()` Prior to this PR, we could not pass a `string[]` from TypeScript to a TVM PackedFunc and needed to convert it to `TVMArray<TVMString>` first (for instance in `getParamsFromCacheByName()`). This is inconvenient when the PackedFunc's caller is not internal to tvmjs. Thus, this PR moves the conversion into `setPackedFunc()` instead. #### Change 2: Support exceptionally long TVM arrays The second change deals with exceptionally long TVM arrays. In cases like passing in a token table, we need to pass in a long `string[]` (in Llama-3's case, of size 128000), leading to the JS error `RangeError: Maximum call stack size exceeded`, since we treat each string element as a separate argument, as in `this.ctx.arrayMake(...inputs)`. This PR sets an empirical call stack limit of 30000 and chunks the array elements in `makeTVMArray()`, converting each chunk to its own TVMArray. Then we concatenate them with the newly implemented `tvmjs.runtime.ArrayConcat`, which concatenates N TVMArrays. Tested end-to-end in WebLLM. 
--- web/emcc/wasm_runtime.cc | 17 +++++++++++++++++ web/package-lock.json | 4 ++-- web/src/runtime.ts | 32 ++++++++++++++++++++++++++------ 3 files changed, 45 insertions(+), 8 deletions(-) diff --git a/web/emcc/wasm_runtime.cc b/web/emcc/wasm_runtime.cc index 00c37dd22a..2f71355958 100644 --- a/web/emcc/wasm_runtime.cc +++ b/web/emcc/wasm_runtime.cc @@ -156,5 +156,22 @@ void ArrayDecodeStorage(NDArray cpu_arr, std::string bytes, std::string format, } TVM_REGISTER_GLOBAL("tvmjs.array.decode_storage").set_body_typed(ArrayDecodeStorage); + +// Concatenate n TVMArrays +TVM_REGISTER_GLOBAL("tvmjs.runtime.ArrayConcat").set_body([](TVMArgs args, TVMRetValue* ret) { + std::vector<ObjectRef> data; + for (int i = 0; i < args.size(); ++i) { + // Get i-th TVMArray + ICHECK_EQ(args[i].type_code(), kTVMObjectHandle); + Object* ptr = static_cast<Object*>(args[i].value().v_handle); + ICHECK(ptr->IsInstance<ArrayNode>()); + auto* arr_i = static_cast<const ArrayNode*>(ptr); + for (size_t j = 0; j < arr_i->size(); ++j) { + // Push back each j-th element of the i-th array + data.push_back(arr_i->at(j)); + } + } + *ret = Array<ObjectRef>(data); +}); } // namespace runtime } // namespace tvm diff --git a/web/package-lock.json b/web/package-lock.json index 74561324c9..75efcbcc7b 100644 --- a/web/package-lock.json +++ b/web/package-lock.json @@ -1,12 +1,12 @@ { "name": "tvmjs", - "version": "0.16.0-dev0", + "version": "0.17.0-dev0", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "tvmjs", - "version": "0.16.0-dev0", + "version": "0.17.0-dev0", "license": "Apache-2.0", "devDependencies": { "@rollup/plugin-commonjs": "^20.0.0", diff --git a/web/src/runtime.ts b/web/src/runtime.ts index 4b40bbc341..ff4dce497d 100644 --- a/web/src/runtime.ts +++ b/web/src/runtime.ts @@ -156,6 +156,7 @@ class RuntimeContext implements Disposable { arrayGetItem: PackedFunc; arrayGetSize: PackedFunc; arrayMake: PackedFunc; + arrayConcat: PackedFunc; stringMake: PackedFunc; getFFIString: 
PackedFunc; getSysLib: PackedFunc; @@ -180,6 +181,7 @@ class RuntimeContext implements Disposable { this.arrayGetItem = getGlobalFunc("runtime.ArrayGetItem"); this.arrayGetSize = getGlobalFunc("runtime.ArraySize"); this.arrayMake = getGlobalFunc("runtime.Array"); + this.arrayConcat = getGlobalFunc("tvmjs.runtime.ArrayConcat"); this.stringMake = getGlobalFunc("runtime.String"); this.getFFIString = getGlobalFunc("runtime.GetFFIString"); this.getSysLib = getGlobalFunc("runtime.SystemLib"); @@ -205,6 +207,7 @@ class RuntimeContext implements Disposable { this.arrayGetItem.dispose(); this.arrayGetSize.dispose(); this.arrayMake.dispose(); + this.arrayConcat.dispose(); this.stringMake.dispose(); this.getFFIString.dispose(); this.arrayCacheGet.dispose(); @@ -1382,11 +1385,7 @@ export class Instance implements Disposable { * @returns Parameters read. */ getParamsFromCacheByName(paramNames: Array<string>): TVMObject { - // Convert Array<string> to Array<TVMString> - const paramNamesTVM: TVMString[] = []; - paramNames.forEach(paramName => { paramNamesTVM.push(this.makeString(paramName)) }); - return (this.ctx.paramModuleFromCacheByName( - this.makeTVMArray(paramNamesTVM)) as Module).getFunction("get_params")(); + return (this.ctx.paramModuleFromCacheByName(paramNames) as Module).getFunction("get_params")(); } /** @@ -1873,7 +1872,20 @@ export class Instance implements Disposable { makeTVMArray( inputs: Array<TVMObjectBase> ): TVMArray { - return this.ctx.arrayMake(...inputs) as TVMArray; + const CALL_STACK_LIMIT = 30000; + const inputsLength = inputs.length; + if (inputsLength <= CALL_STACK_LIMIT) { + return this.ctx.arrayMake(...inputs) as TVMArray; + } + // If too many elements, TypeScript would complain `Maximum call stack size exceeded` + // So we make several arrays and concatenate them + const listOfArrays: Array<TVMArray> = []; + for (let begin = 0; begin < inputsLength; begin += CALL_STACK_LIMIT) { + const end = Math.min(inputsLength, begin + CALL_STACK_LIMIT); + 
const chunk: Array<TVMObjectBase> = inputs.slice(begin, end); + listOfArrays.push(this.ctx.arrayMake(...chunk) as TVMArray); + } + return this.ctx.arrayConcat(...listOfArrays) as TVMArray; } /** @@ -2230,6 +2242,14 @@ export class Instance implements Disposable { const tp = typeof val; const valueOffset = argsValue + i * SizeOf.TVMValue; const codeOffset = argsCode + i * SizeOf.I32; + + // Convert string[] to a TVMArray of TVMString, hence treated as a TVMObject + if (val instanceof Array && val.every(e => typeof e === "string")) { + const tvmStringArray: TVMString[] = []; + val.forEach(e => { tvmStringArray.push(this.makeString(e)) }); + val = this.makeTVMArray(tvmStringArray); + } + if (val instanceof NDArray) { if (!val.isView) { stack.storePtr(valueOffset, val.getHandle());