This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
     new f5ab3f05eb [TIR] [Analysis] Calculate allocated memory at module level (#14711)
f5ab3f05eb is described below

commit f5ab3f05eb3190b836f41bbeb975258232010def
Author: Anirudh Sundar Subramaniam <quic_sanir...@quicinc.com>
AuthorDate: Tue Apr 25 09:18:42 2023 +0530

    [TIR] [Analysis] Calculate allocated memory at module level (#14711)

    * [TIR] [Analysis] Calculate allocated memory at module level

      This patch modifies the existing analysis pass `tir.calculate_allocated_bytes`
      to accept an IRModule as an argument and return allocated bytes for all
      prim_funcs in the IRModule.

    * Fix docstring and modify python API to be consistent with c++
---
 include/tvm/tir/analysis.h                         | 12 +++-
 python/tvm/tir/analysis/analysis.py                | 21 +++++--
 src/tir/analysis/calculate_allocated_memory.cc     | 36 ++++++++---
 ...test_tir_analysis_calculate_allocated_memory.py | 69 +++++++++++++++++----
 4 files changed, 109 insertions(+), 29 deletions(-)

diff --git a/include/tvm/tir/analysis.h b/include/tvm/tir/analysis.h
index 4ed164e5ad..3b5959e781 100644
--- a/include/tvm/tir/analysis.h
+++ b/include/tvm/tir/analysis.h
@@ -266,8 +266,18 @@ TVM_DLL size_t CalculateWorkspaceBytes(const PrimFunc& func,
 /*!
  * \brief Calculate the allocated memory per scope in bytes needed inside the TIR PrimFunc
  * \param func The TIR PrimFunc for which the allocated memory size is to be calculated
+ * \return Allocated memory size per scope in bytes inside the PrimFunc returned as a Map with
+ * key "main" and a Map of allocated sizes as values.
  */
-TVM_DLL tvm::Map<String, Integer> CalculateAllocatedBytes(const PrimFunc& func);
+TVM_DLL tvm::Map<String, tvm::Map<String, Integer>> CalculateAllocatedBytes(const PrimFunc& func);
+
+/*!
+ * \brief Calculate the allocated memory per scope in bytes for each function inside the module
+ * \param mod The IRModule for which the allocated memory size has to be calculated
+ * \return Allocated memory size per scope in bytes for each function in the IRModule returned as a
+ * Map with function names as keys and a Map of allocated sizes as values.
+ */
+TVM_DLL tvm::Map<String, tvm::Map<String, Integer>> CalculateAllocatedBytes(const IRModule& mod);
 
 /*!
  * \brief Detect the lowest common ancestor(LCA) of buffer access, including both high-level
diff --git a/python/tvm/tir/analysis/analysis.py b/python/tvm/tir/analysis/analysis.py
index 5feb630e48..387ea04980 100644
--- a/python/tvm/tir/analysis/analysis.py
+++ b/python/tvm/tir/analysis/analysis.py
@@ -201,20 +201,29 @@ def calculate_constant_bytes(func: PrimFunc, constant_byte_alignment: int) -> in
     return _ffi_api.calculate_constant_bytes(func, constant_byte_alignment)  # type: ignore
 
 
-def calculate_allocated_bytes(func: PrimFunc) -> Dict[str, int]:
+def calculate_allocated_bytes(
+    func_or_mod: Union[PrimFunc, IRModule]
+) -> Union[Dict[str, int], Dict[str, Dict[str, int]]]:
     """Calculate allocated memory per memory scope required by TIR PrimFuncs.
 
     Parameters
     ----------
-    func: tvm.tir.PrimFunc
-        The function to be detected.
+    func_or_mod: Union[PrimFunc, IRModule]
+        The function or module to be detected. If a module is passed, allocated
+        memory is calculated for all PrimFuncs inside the module.
 
     Returns
     -------
-    result : Dict[String, int]
-        Allocated memory size per scope in bytes.
+    result : Union[Dict[str, int], Dict[str, Dict[str, int]]]
+        Allocated memory size per scope in bytes for each function in the IRModule returned as a
+        dict with function names as keys and a dict of allocated sizes as values. If a single
+        PrimFunc is passed, the function name is returned as "main".
     """
-    return _ffi_api.calculate_allocated_bytes(func)  # type: ignore
+    if not isinstance(func_or_mod, (PrimFunc, IRModule)):
+        raise TypeError(
+            f"Expected argument to be PrimFunc or IRModule, but received {type(func_or_mod)}"
+        )
+    return _ffi_api.calculate_allocated_bytes(func_or_mod)  # type: ignore
 
 
 def detect_buffer_access_lca(func: PrimFunc) -> Dict[Buffer, Stmt]:
diff --git a/src/tir/analysis/calculate_allocated_memory.cc b/src/tir/analysis/calculate_allocated_memory.cc
index ffdfc1f801..8680f57e4c 100644
--- a/src/tir/analysis/calculate_allocated_memory.cc
+++ b/src/tir/analysis/calculate_allocated_memory.cc
@@ -79,16 +79,38 @@ void AllocationCalculator<T>::VisitStmt_(const T* op) {
   _current_size[storage_scope] -= size;
 }
 
-tvm::Map<String, Integer> CalculateAllocatedBytes(const PrimFunc& func) {
-  return AllocationCalculator<AllocateNode>()(func);
+tvm::Map<String, tvm::Map<String, Integer> > CalculateAllocatedBytes(const PrimFunc& func) {
+  tvm::Map<String, tvm::Map<String, Integer> > results;
+  results.Set("main", AllocationCalculator<AllocateNode>()(func));
+  return results;
 }
 
-TVM_REGISTER_GLOBAL("tir.analysis.calculate_allocated_bytes").set_body_typed([](PrimFunc func) {
-  return CalculateAllocatedBytes(func);
-});
+tvm::Map<String, tvm::Map<String, Integer> > CalculateAllocatedBytes(const IRModule& mod) {
+  tvm::Map<String, tvm::Map<String, Integer> > results;
+  for (const auto& kv : mod->functions) {
+    if (auto prim_func = kv.second.as<tir::PrimFunc>()) {
+      String func_name = kv.first->name_hint;
+      results.Set(func_name, AllocationCalculator<AllocateNode>()(prim_func.value()));
+    }
+  }
+  return results;
+}
+
+TVM_REGISTER_GLOBAL("tir.analysis.calculate_allocated_bytes")
+    .set_body_typed([](ObjectRef obj) -> tvm::Map<String, tvm::Map<String, Integer> > {
+      if (auto func = obj.as<PrimFunc>()) {
+        return CalculateAllocatedBytes(func.value());
+      } else if (auto mod = obj.as<IRModule>()) {
+        return CalculateAllocatedBytes(mod.value());
+      } else {
+        LOG(FATAL) << "TypeError: Expect the input to be either PrimFunc or IRModule, but gets: "
+                   << obj->GetTypeKey();
+        throw;
+      }
+    });
 
 bool VerifyVTCMLimit(const PrimFunc& func, Integer limit) {
-  auto sizes = CalculateAllocatedBytes(func);
+  auto sizes = CalculateAllocatedBytes(func)["main"];
   const auto vtcm_allocated = sizes.Get("global.vtcm").value_or(0);
   if (limit.IntValue() > 0 && vtcm_allocated.IntValue() > limit.IntValue()) {
     return false;
@@ -121,7 +143,7 @@ Pass VerifyVTCMLimit(Optional<Target> default_target) {
       }
 
       if (limit.has_value() && limit.value() > 0) {
-        auto sizes = CalculateAllocatedBytes(func);
+        auto sizes = CalculateAllocatedBytes(func)["main"];
         const auto vtcm_allocated = sizes.Get("global.vtcm").value_or(0);
         if (vtcm_allocated.IntValue() > limit.value()) {
           LOG(FATAL) << "RuntimeError: The global.vtcm memory allocation limit has been exceeded "
diff --git a/tests/python/unittest/test_tir_analysis_calculate_allocated_memory.py b/tests/python/unittest/test_tir_analysis_calculate_allocated_memory.py
index 2311bfbbef..cb3a663c03 100644
--- a/tests/python/unittest/test_tir_analysis_calculate_allocated_memory.py
+++ b/tests/python/unittest/test_tir_analysis_calculate_allocated_memory.py
@@ -14,32 +14,42 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring
 import pytest
 
 import tvm
 from tvm import tir
 from tvm.script import tir as T
 
+# fmt: off
+# pylint: disable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks
 
-@T.prim_func
-def scale_by_two(a: T.Buffer((128,), "int8"), c: T.Buffer((128,), "int8")):
-    for i in T.serial(128):
-        with T.block("C"):
-            c[i] = a[i] * T.int8(2)
+@tvm.script.ir_module
+class Module:
+    @T.prim_func
+    def scale_by_two(a: T.Buffer((128,), "int8"), c: T.Buffer((128,), "int8")):
+        for i in T.serial(128):
+            with T.block("C"):
+                c[i] = a[i] * T.int8(2)
 
-@T.prim_func
-def scale_by_two_three(a: T.Buffer((128,), "int8"), c: T.Buffer((128,), "int8")):
-    B = T.alloc_buffer([128], dtype="int8", scope="global.vtcm")
-    for i in T.serial(128):
-        with T.block("B"):
-            B[i] = a[i] * T.int8(2)
-    for i in T.serial(128):
-        with T.block("C"):
-            c[i] = B[i] * T.int8(3)
+    @T.prim_func
+    def scale_by_two_three(a: T.Buffer((128,), "int8"), c: T.Buffer((128,), "int8")):
+        B = T.alloc_buffer([128], dtype="int8", scope="global.vtcm")
+        for i in T.serial(128):
+            with T.block("B"):
+                B[i] = a[i] * T.int8(2)
+        for i in T.serial(128):
+            with T.block("C"):
+                c[i] = B[i] * T.int8(3)
+
+# pylint: enable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks
+# fmt: on
 
-@pytest.mark.parametrize("primFunc,size", [(scale_by_two, 128), (scale_by_two_three, 256)])
+@pytest.mark.parametrize(
+    "primFunc,size", [(Module["scale_by_two"], 128), (Module["scale_by_two_three"], 256)]
+)
 def test_scale_by(primFunc, size):
     """Test calculate allocated bytes per scope"""
     mod = tvm.IRModule.from_expr(primFunc.with_attr("global_symbol", "main"))
@@ -53,6 +63,8 @@ def test_scale_by(primFunc, size):
     mod = tvm.tir.transform.ConvertBlocksToOpaque()(mod)
     mod = tvm.tir.transform.LowerOpaqueBlock()(mod)
     sizes = tvm.tir.analysis.calculate_allocated_bytes(mod["main"])
+    assert "main" in sizes, 'Calls with a PrimFunc are expected to return with the function key "main"'
+    sizes = sizes["main"]
     assert sizes.get("global.vtcm", 0) == size
 
 
@@ -94,8 +106,35 @@ def test_matmul_mix_scope(scope, size):
     mod = tvm.tir.transform.ConvertBlocksToOpaque()(mod)
     mod = tvm.tir.transform.LowerOpaqueBlock()(mod)
     sizes = tvm.tir.analysis.calculate_allocated_bytes(mod["main"])
+    assert "main" in sizes, 'Calls with a PrimFunc are expected to return with the function key "main"'
+    sizes = sizes["main"]
     assert sizes.get(scope, 0) == size
 
 
+def test_full_mod_calculator():
+    def apply_schedule(sch, func_name):
+        sch.work_on(func_name)
+        block_c = sch.get_block("C")
+        sch.cache_read(block_c, 0, storage_scope="global.vtcm")
+
+    sch = tvm.tir.Schedule(Module, debug_mask="all")
+    apply_schedule(sch, "scale_by_two")
+    apply_schedule(sch, "scale_by_two_three")
+    mod = tvm.tir.transform.ConvertBlocksToOpaque()(sch.mod)
+    mod = tvm.tir.transform.LowerOpaqueBlock()(mod)
+    sizes = tvm.tir.analysis.calculate_allocated_bytes(mod)
+    assert "scale_by_two" in sizes, "Values for scale_by_two not found"
+    scale_by_two_sizes = sizes["scale_by_two"]
+    assert (
+        "global.vtcm" in scale_by_two_sizes
+    ), "Expected global.vtcm allocation to be calculated for scale_by_two"
+    assert scale_by_two_sizes["global.vtcm"] == 128, "Expected the calculated size to be 128"
+    scale_by_two_three_sizes = sizes["scale_by_two_three"]
+    assert (
+        "global.vtcm" in scale_by_two_three_sizes
+    ), "Expected global.vtcm allocation to be calculated for scale_by_two_three"
+    assert scale_by_two_three_sizes["global.vtcm"] == 256, "Expected the calculated size to be 256"
+
+
 if __name__ == "__main__":
     tvm.testing.main()
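
For readers trying out the new API, here is a minimal usage sketch. It is not part of the commit; it follows the lowering pipeline used in the tests above, and the module name `Example` and function name `scale` are illustrative only.

# Minimal usage sketch of tir.calculate_allocated_bytes after this patch.
# Not part of the commit; names `Example` and `scale` are illustrative.
import tvm
from tvm.script import tir as T


@tvm.script.ir_module
class Example:
    @T.prim_func
    def scale(a: T.Buffer((128,), "int8"), c: T.Buffer((128,), "int8")):
        # 128 x int8 intermediate buffer in the global.vtcm scope, as in the tests above.
        B = T.alloc_buffer([128], dtype="int8", scope="global.vtcm")
        for i in T.serial(128):
            with T.block("B"):
                B[i] = a[i] * T.int8(2)
        for i in T.serial(128):
            with T.block("C"):
                c[i] = B[i] * T.int8(3)


# Lower block-level buffer allocations to Allocate nodes so the analysis can see them,
# mirroring the pipeline used in the tests.
mod = tvm.tir.transform.ConvertBlocksToOpaque()(Example)
mod = tvm.tir.transform.LowerOpaqueBlock()(mod)

# Module-level call (new in this patch): one entry per PrimFunc, keyed by function name.
per_func = tvm.tir.analysis.calculate_allocated_bytes(mod)
print(per_func["scale"]["global.vtcm"])  # expected: 128

# PrimFunc-level call: after this patch the per-scope sizes sit under the "main" key.
per_scope = tvm.tir.analysis.calculate_allocated_bytes(mod["scale"])["main"]
print(per_scope["global.vtcm"])  # expected: 128

Both call forms return the same nested Map<String, Map<String, Integer>> at the FFI boundary, which is why the PrimFunc form now nests its per-scope result under "main".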