This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new f5ab3f05eb [TIR] [Analysis] Calculate allocated memory at module level (#14711)
f5ab3f05eb is described below

commit f5ab3f05eb3190b836f41bbeb975258232010def
Author: Anirudh Sundar Subramaniam <quic_sanir...@quicinc.com>
AuthorDate: Tue Apr 25 09:18:42 2023 +0530

    [TIR] [Analysis] Calculate allocated memory at module level (#14711)
    
    * [TIR] [Analysis] Calculate allocated memory at module level
    
    This patch modifies the existing analysis pass
    `tir.calculate_allocated_bytes` to accept an IRModule as an argument and
    return allocated bytes for all prim_funcs in the IRModule.
    
    * Fix docstring and modify python API to be consistent with c++
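
    A minimal usage sketch of the updated API ("vtcm_copy" and the lowering
    steps below are illustrative, mirroring the updated unit test rather than
    prescribing a required flow):

        import tvm
        from tvm.script import tir as T

        @T.prim_func
        def vtcm_copy(a: T.Buffer((128,), "int8"), c: T.Buffer((128,), "int8")):
            # A 128-byte intermediate allocated in the global.vtcm scope.
            B = T.alloc_buffer([128], dtype="int8", scope="global.vtcm")
            for i in T.serial(128):
                with T.block("B"):
                    B[i] = a[i]
            for i in T.serial(128):
                with T.block("C"):
                    c[i] = B[i]

        mod = tvm.IRModule.from_expr(vtcm_copy.with_attr("global_symbol", "main"))
        mod = tvm.tir.transform.ConvertBlocksToOpaque()(mod)
        mod = tvm.tir.transform.LowerOpaqueBlock()(mod)

        # Passing a single PrimFunc: the result is keyed by "main",
        # e.g. {"main": {"global.vtcm": 128}}.
        print(tvm.tir.analysis.calculate_allocated_bytes(mod["main"]))

        # Passing the whole IRModule: one entry per PrimFunc, keyed by the
        # function's name in the module.
        print(tvm.tir.analysis.calculate_allocated_bytes(mod))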
---
 include/tvm/tir/analysis.h                         | 12 +++-
 python/tvm/tir/analysis/analysis.py                | 21 +++++--
 src/tir/analysis/calculate_allocated_memory.cc     | 36 ++++++++---
 ...test_tir_analysis_calculate_allocated_memory.py | 69 +++++++++++++++++-----
 4 files changed, 109 insertions(+), 29 deletions(-)

diff --git a/include/tvm/tir/analysis.h b/include/tvm/tir/analysis.h
index 4ed164e5ad..3b5959e781 100644
--- a/include/tvm/tir/analysis.h
+++ b/include/tvm/tir/analysis.h
@@ -266,8 +266,18 @@ TVM_DLL size_t CalculateWorkspaceBytes(const PrimFunc& func,
 /*!
  * \brief Calculate the allocated memory per scope in bytes needed inside the TIR PrimFunc
  * \param func The TIR PrimFunc for which the allocated memory size is to be calculated
+ * \return Allocated memory size per scope in bytes inside the PrimFunc returned as a Map with
+ * key "main" and a Map of allocated sizes as values.
  */
-TVM_DLL tvm::Map<String, Integer> CalculateAllocatedBytes(const PrimFunc& func);
+TVM_DLL tvm::Map<String, tvm::Map<String, Integer>> CalculateAllocatedBytes(const PrimFunc& func);
+
+/*!
+ * \brief Calculate the allocated memory per scope in bytes for each function inside the module
+ * \param mod The IRModule for which the allocated memory size is to be calculated
+ * \return Allocated memory size per scope in bytes for each function in the IRModule returned as a
+ *         Map with function names as keys and a Map of allocated sizes as values.
+ */
+TVM_DLL tvm::Map<String, tvm::Map<String, Integer>> CalculateAllocatedBytes(const IRModule& mod);
 
 /*!
  * \brief Detect the lowest common ancestor(LCA) of buffer access, including both high-level
diff --git a/python/tvm/tir/analysis/analysis.py b/python/tvm/tir/analysis/analysis.py
index 5feb630e48..387ea04980 100644
--- a/python/tvm/tir/analysis/analysis.py
+++ b/python/tvm/tir/analysis/analysis.py
@@ -201,20 +201,29 @@ def calculate_constant_bytes(func: PrimFunc, constant_byte_alignment: int) -> in
     return _ffi_api.calculate_constant_bytes(func, constant_byte_alignment)  # type: ignore
 
 
-def calculate_allocated_bytes(func: PrimFunc) -> Dict[str, int]:
+def calculate_allocated_bytes(
+    func_or_mod: Union[PrimFunc, IRModule]
+) -> Union[Dict[str, int], Dict[str, Dict[str, int]]]:
     """Calculate allocated memory per memory scope required by TIR PrimFuncs.
 
     Parameters
     ----------
-    func: tvm.tir.PrimFunc
-        The function to be detected.
+    func_or_mod: Union[PrimFunc, IRModule]
+        The function or module to be analyzed. If a module is passed, allocated
+        memory is calculated for all PrimFuncs inside the module.
 
     Returns
     -------
-    result : Dict[String, int]
-        Allocated memory size per scope in bytes.
+    result : Union[Dict[str, int], Dict[str, Dict[str, int]]]
+        Allocated memory size per scope in bytes for each function in the IRModule returned as a
+        dict with function names as keys and a dict of allocated sizes as values. If a single
+        PrimFunc is passed, the function name is returned as "main".
     """
-    return _ffi_api.calculate_allocated_bytes(func)  # type: ignore
+    if not isinstance(func_or_mod, (PrimFunc, IRModule)):
+        raise TypeError(
+            f"Expected argument to be PrimFunc or IRModule, but received {type(func_or_mod)}"
+        )
+    return _ffi_api.calculate_allocated_bytes(func_or_mod)  # type: ignore
 
 
 def detect_buffer_access_lca(func: PrimFunc) -> Dict[Buffer, Stmt]:
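
The Python wrapper above now validates the argument type before crossing the
FFI boundary, so unsupported inputs fail early with a TypeError; a small
illustrative check (pytest is used here only for brevity):

    import pytest
    import tvm

    with pytest.raises(TypeError):
        tvm.tir.analysis.calculate_allocated_bytes("neither PrimFunc nor IRModule")
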
diff --git a/src/tir/analysis/calculate_allocated_memory.cc b/src/tir/analysis/calculate_allocated_memory.cc
index ffdfc1f801..8680f57e4c 100644
--- a/src/tir/analysis/calculate_allocated_memory.cc
+++ b/src/tir/analysis/calculate_allocated_memory.cc
@@ -79,16 +79,38 @@ void AllocationCalculator<T>::VisitStmt_(const T* op) {
   _current_size[storage_scope] -= size;
 }
 
-tvm::Map<String, Integer> CalculateAllocatedBytes(const PrimFunc& func) {
-  return AllocationCalculator<AllocateNode>()(func);
+tvm::Map<String, tvm::Map<String, Integer> > CalculateAllocatedBytes(const PrimFunc& func) {
+  tvm::Map<String, tvm::Map<String, Integer> > results;
+  results.Set("main", AllocationCalculator<AllocateNode>()(func));
+  return results;
 }
 
-TVM_REGISTER_GLOBAL("tir.analysis.calculate_allocated_bytes").set_body_typed([](PrimFunc func) {
-  return CalculateAllocatedBytes(func);
-});
+tvm::Map<String, tvm::Map<String, Integer> > CalculateAllocatedBytes(const IRModule& mod) {
+  tvm::Map<String, tvm::Map<String, Integer> > results;
+  for (const auto& kv : mod->functions) {
+    if (auto prim_func = kv.second.as<tir::PrimFunc>()) {
+      String func_name = kv.first->name_hint;
+      results.Set(func_name, AllocationCalculator<AllocateNode>()(prim_func.value()));
+    }
+  }
+  return results;
+}
+
+TVM_REGISTER_GLOBAL("tir.analysis.calculate_allocated_bytes")
+    .set_body_typed([](ObjectRef obj) -> tvm::Map<String, tvm::Map<String, Integer> > {
+      if (auto func = obj.as<PrimFunc>()) {
+        return CalculateAllocatedBytes(func.value());
+      } else if (auto mod = obj.as<IRModule>()) {
+        return CalculateAllocatedBytes(mod.value());
+      } else {
+        LOG(FATAL) << "TypeError: Expect the input to be either PrimFunc or IRModule, but gets: "
+                   << obj->GetTypeKey();
+        throw;
+      }
+    });
 
 bool VerifyVTCMLimit(const PrimFunc& func, Integer limit) {
-  auto sizes = CalculateAllocatedBytes(func);
+  auto sizes = CalculateAllocatedBytes(func)["main"];
   const auto vtcm_allocated = sizes.Get("global.vtcm").value_or(0);
   if (limit.IntValue() > 0 && vtcm_allocated.IntValue() > limit.IntValue()) {
     return false;
@@ -121,7 +143,7 @@ Pass VerifyVTCMLimit(Optional<Target> default_target) {
         }
 
         if (limit.has_value() && limit.value() > 0) {
-          auto sizes = CalculateAllocatedBytes(func);
+          auto sizes = CalculateAllocatedBytes(func)["main"];
           const auto vtcm_allocated = sizes.Get("global.vtcm").value_or(0);
           if (vtcm_allocated.IntValue() > limit.value()) {
             LOG(FATAL) << "RuntimeError: The global.vtcm memory allocation limit has been exceeded "
diff --git a/tests/python/unittest/test_tir_analysis_calculate_allocated_memory.py b/tests/python/unittest/test_tir_analysis_calculate_allocated_memory.py
index 2311bfbbef..cb3a663c03 100644
--- a/tests/python/unittest/test_tir_analysis_calculate_allocated_memory.py
+++ b/tests/python/unittest/test_tir_analysis_calculate_allocated_memory.py
@@ -14,32 +14,42 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring
 import pytest
 
 import tvm
 from tvm import tir
 from tvm.script import tir as T
 
+# fmt: off
+# pylint: disable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks
 
-@T.prim_func
-def scale_by_two(a: T.Buffer((128,), "int8"), c: T.Buffer((128,), "int8")):
-    for i in T.serial(128):
-        with T.block("C"):
-            c[i] = a[i] * T.int8(2)
+@tvm.script.ir_module
+class Module:
+    @T.prim_func
+    def scale_by_two(a: T.Buffer((128,), "int8"), c: T.Buffer((128,), "int8")):
+        for i in T.serial(128):
+            with T.block("C"):
+                c[i] = a[i] * T.int8(2)
 
 
-@T.prim_func
-def scale_by_two_three(a: T.Buffer((128,), "int8"), c: T.Buffer((128,), "int8")):
-    B = T.alloc_buffer([128], dtype="int8", scope="global.vtcm")
-    for i in T.serial(128):
-        with T.block("B"):
-            B[i] = a[i] * T.int8(2)
-    for i in T.serial(128):
-        with T.block("C"):
-            c[i] = B[i] * T.int8(3)
+    @T.prim_func
+    def scale_by_two_three(a: T.Buffer((128,), "int8"), c: T.Buffer((128,), "int8")):
+        B = T.alloc_buffer([128], dtype="int8", scope="global.vtcm")
+        for i in T.serial(128):
+            with T.block("B"):
+                B[i] = a[i] * T.int8(2)
+        for i in T.serial(128):
+            with T.block("C"):
+                c[i] = B[i] * T.int8(3)
+
+# pylint: enable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks
+# fmt: on
 
 
-@pytest.mark.parametrize("primFunc,size", [(scale_by_two, 128), (scale_by_two_three, 256)])
+@pytest.mark.parametrize(
+    "primFunc,size", [(Module["scale_by_two"], 128), (Module["scale_by_two_three"], 256)]
+)
 def test_scale_by(primFunc, size):
     """Test calculate allocated bytes per scope"""
     mod = tvm.IRModule.from_expr(primFunc.with_attr("global_symbol", "main"))
@@ -53,6 +63,8 @@ def test_scale_by(primFunc, size):
     mod = tvm.tir.transform.ConvertBlocksToOpaque()(mod)
     mod = tvm.tir.transform.LowerOpaqueBlock()(mod)
     sizes = tvm.tir.analysis.calculate_allocated_bytes(mod["main"])
+    assert "main" in sizes, 'Calls with PrimFunc is expected to return with function key as "main"'
+    sizes = sizes["main"]
     assert sizes.get("global.vtcm", 0) == size
 
 
@@ -94,8 +106,35 @@ def test_matmul_mix_scope(scope, size):
     mod = tvm.tir.transform.ConvertBlocksToOpaque()(mod)
     mod = tvm.tir.transform.LowerOpaqueBlock()(mod)
     sizes = tvm.tir.analysis.calculate_allocated_bytes(mod["main"])
+    assert "main" in sizes, 'Calls with PrimFunc is expected to return with function key as "main"'
+    sizes = sizes["main"]
     assert sizes.get(scope, 0) == size
 
 
+def test_full_mod_calculator():
+    def apply_schedule(sch, func_name):
+        sch.work_on(func_name)
+        block_c = sch.get_block("C")
+        sch.cache_read(block_c, 0, storage_scope="global.vtcm")
+
+    sch = tvm.tir.Schedule(Module, debug_mask="all")
+    apply_schedule(sch, "scale_by_two")
+    apply_schedule(sch, "scale_by_two_three")
+    mod = tvm.tir.transform.ConvertBlocksToOpaque()(sch.mod)
+    mod = tvm.tir.transform.LowerOpaqueBlock()(mod)
+    sizes = tvm.tir.analysis.calculate_allocated_bytes(mod)
+    assert "scale_by_two" in sizes, "Values for scale_by_two not found"
+    scale_by_two_sizes = sizes["scale_by_two"]
+    assert (
+        "global.vtcm" in scale_by_two_sizes
+    ), "Expected global.vtcm allocation to be calculated scale_by_two"
+    assert scale_by_two_sizes["global.vtcm"] == 128, "Expected the calculated size to be 128"
+    scale_by_two_three_sizes = sizes["scale_by_two_three"]
+    assert (
+        "global.vtcm" in scale_by_two_three_sizes
+    ), "Expected global.vtcm allocation to be calculated scale_by_two_three"
+    assert scale_by_two_three_sizes["global.vtcm"] == 256, "Expected the calculated size to be 256"
+
+
 if __name__ == "__main__":
     tvm.testing.main()
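
For completeness, the VerifyVTCMLimit changes above consume the new nested
result by indexing the "main" entry before looking up the "global.vtcm" scope.
A rough Python equivalent of that check (a sketch; vtcm_limit is an assumed
byte budget, not a value taken from any API):

    import tvm

    def exceeds_vtcm_limit(prim_func, vtcm_limit: int) -> bool:
        # Mirror of the VerifyVTCMLimit check: take the "main" entry first,
        # then read the "global.vtcm" scope, defaulting to 0 when absent.
        sizes = tvm.tir.analysis.calculate_allocated_bytes(prim_func)["main"]
        vtcm_allocated = int(sizes.get("global.vtcm", 0))
        return vtcm_limit > 0 and vtcm_allocated > vtcm_limit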
