This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 486c49895d [TIR] [Hexagon] Add get_vtcm_allocation_sizes with lowering 
(#14720)
486c49895d is described below

commit 486c49895dab63c28e44227b18ae51cca45593d0
Author: Anirudh Sundar Subramaniam <quic_sanir...@quicinc.com>
AuthorDate: Wed Apr 26 09:25:30 2023 +0530

    [TIR] [Hexagon] Add get_vtcm_allocation_sizes with lowering (#14720)
    
    This patch adds a utility function for getting the VTCM sizes allocated
    in an IRModule. To do that, we've exposed the list of lowering passes to
    Python and refactored PostprocVerifyVTCMLimit so that it is computed for
    the whole module using the same list of lowering passes.
---
 include/tvm/tir/analysis.h                      | 16 ++++++++
 python/tvm/tir/analysis/analysis.py             | 13 +++++++
 python/tvm/topi/hexagon/utils.py                | 52 +++++++++++++++++++++++--
 src/meta_schedule/postproc/verify_vtcm_limit.cc | 44 ++++-----------------
 src/tir/analysis/calculate_allocated_memory.cc  | 33 ++++++++++++++++
 5 files changed, 119 insertions(+), 39 deletions(-)

diff --git a/include/tvm/tir/analysis.h b/include/tvm/tir/analysis.h
index 3b5959e781..f4684231f0 100644
--- a/include/tvm/tir/analysis.h
+++ b/include/tvm/tir/analysis.h
@@ -184,6 +184,22 @@ TVM_DLL bool VerifyMemory(const PrimFunc& func);
  */
 TVM_DLL bool VerifyGPUCode(const PrimFunc& func, Map<String, PrimExpr> 
constraints);
 
+/**
+ * @brief Utility function to get the list of lowering passes to be applied to 
calculate the
+ * compacted VTCM allocation size
+ *
+ * @return returns list of passes
+ */
+TVM_DLL Array<tvm::transform::Pass> GetVTCMCompactionPasses();
+
+/*!
+ * \brief Verifies that the VTCM usage of all prim_funcs in the given IRModule is within the limit.
+ * \param mod The module to be checked
+ * \param limit The limit to check.
+ * \return true if the VTCM usage is within the provided limit.
+ */
+TVM_DLL bool VerifyVTCMLimit(const IRModule& mod, Integer limit);
+
 /*!
  * \brief Verifies that the VTCM usage of the given prim_func is within the 
provided limit.
  * \param func The function to be checked.
diff --git a/python/tvm/tir/analysis/analysis.py 
b/python/tvm/tir/analysis/analysis.py
index 387ea04980..1a5f8b9781 100644
--- a/python/tvm/tir/analysis/analysis.py
+++ b/python/tvm/tir/analysis/analysis.py
@@ -18,6 +18,7 @@
 # pylint: disable=invalid-name
 from typing import Dict, List, Union
 
+import tvm
 from tvm import Object
 from tvm.ir import IRModule
 from tvm.tir.expr import Var
@@ -384,3 +385,15 @@ def find_anchor_block(mod: IRModule) -> Block:
         The anchor block if found, None otherwise.
     """
     return _ffi_api.find_anchor_block(mod)  # type: ignore # pylint: 
disable=no-member
+
+
+def get_vtcm_compaction_passes() -> List[tvm.transform.Pass]:
+    """Utility function to get the list of lowering passes to be applied to 
calculate the compacted
+    VTCM allocation size
+
+    Returns
+    -------
+    result : List[tvm.transform.Pass]
+        returns list of passes
+    """
+    return _ffi_api.get_vtcm_compaction_passes()  # type: ignore # pylint: 
disable=no-member
diff --git a/python/tvm/topi/hexagon/utils.py b/python/tvm/topi/hexagon/utils.py
index 3148360cc2..f017aaebbd 100644
--- a/python/tvm/topi/hexagon/utils.py
+++ b/python/tvm/topi/hexagon/utils.py
@@ -21,9 +21,11 @@
 """Common hexagon specific utilities"""
 import math
 import struct
-from typing import Tuple
-from tvm import te
-from tvm.tir import IndexMap
+from typing import Dict, Tuple, Union
+
+import tvm
+from tvm import IRModule, te
+from tvm.tir import IndexMap, PrimFunc
 
 
 def n11c_1024c_2d(n, h, w, c):
@@ -354,3 +356,47 @@ def get_fixed_point_value(flp: float, dtype: str = 
"int16") -> Tuple[int, int]:
 def saturate(x: te.Tensor, dtype: str):
     """Saturate value for the specified data type"""
     return te.max(te.min_value(dtype), te.min(x, te.max_value(dtype)))
+
+
+def get_vtcm_allocation_sizes(
+    func_or_mod: Union[PrimFunc, IRModule], compacted=True
+) -> Dict[str, int]:
+    """Calculate and return the vtcm allocation sizes for all the functions in
+    the IRModule or just the vtcm size if a single PrimFunc is passed
+
+    Parameters
+    ----------
+    func_or_mod : Union[PrimFunc, IRModule]
+        PrimFunc or IRModule for which VTCM allocation size is to be calculated
+    compacted : bool
+        Whether to calculate the sizes after applying VTCM lowering passes for
+        buffer compaction. This helps return the VTCM size that would get
+        allocated after lowering
+
+    Returns
+    -------
+    result : Dict[str, int]
+        A dict with function names as keys and the number of VTCM bytes
+        allocated inside that function as values
+
+    """
+    if not isinstance(func_or_mod, (PrimFunc, IRModule)):
+        raise TypeError(
+            f"Expected argument to be PrimFunc or IRModule, but received 
{type(func_or_mod)}"
+        )
+    if isinstance(func_or_mod, tvm.tir.PrimFunc):
+        mod = tvm.IRModule.from_expr(func_or_mod)
+    else:
+        mod = func_or_mod
+    if compacted:
+        passes = tvm.tir.analysis.get_vtcm_compaction_passes()
+        mod = tvm.transform.Sequential(list(passes))(mod)
+
+    result = {}
+    all_sizes = tvm.tir.analysis.calculate_allocated_bytes(mod)
+    for func_name, sizes in all_sizes.items():
+        if "global.vtcm" in sizes:
+            result[func_name] = sizes["global.vtcm"]
+        else:
+            result[func_name] = 0
+    return result
diff --git a/src/meta_schedule/postproc/verify_vtcm_limit.cc 
b/src/meta_schedule/postproc/verify_vtcm_limit.cc
index 46bc7486e1..4de9750896 100644
--- a/src/meta_schedule/postproc/verify_vtcm_limit.cc
+++ b/src/meta_schedule/postproc/verify_vtcm_limit.cc
@@ -36,48 +36,20 @@ class VerifyVTCMLimitNode : public PostprocNode {
   }
 
   bool Verify(const IRModule& mod) const {
-    for (const auto& kv : mod->functions) {
-      if (auto prim_func = kv.second.as<tir::PrimFunc>()) {
-        if (!tir::VerifyVTCMLimit(prim_func.value(), vtcm_capacity)) {
-          return false;
-        }
-      }
+    if (!tir::VerifyVTCMLimit(mod, vtcm_capacity)) {
+      return false;
     }
     return true;
   }
 
   bool Apply(const tir::Schedule& sch) final {
     IRModule mod = sch->mod();
-    for (const auto& kv : mod->functions) {
-      const GlobalVar& g_var = kv.first;
-      const BaseFunc& base_func = kv.second;
-      if (const auto* prim_func = base_func.as<tir::PrimFuncNode>()) {
-        IRModule lowered{nullptr};
-        try {
-          auto pass_list = Array<tvm::transform::Pass>();
-          pass_list.push_back(tir::transform::LowerInitBlock());
-          
pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation());
-          pass_list.push_back(tir::transform::ConvertBlocksToOpaque());
-          pass_list.push_back(tir::transform::CompactBufferAllocation());
-          pass_list.push_back(tir::transform::LowerMatchBuffer());
-          pass_list.push_back(tir::transform::InjectSoftwarePipeline());
-          pass_list.push_back(tir::transform::LowerOpaqueBlock());
-          pass_list.push_back(tir::transform::FlattenBuffer());
-          pass_list.push_back(tir::transform::Simplify());
-          pass_list.push_back(tir::transform::VectorizeLoop(true));
-          pass_list.push_back(tir::transform::StorageRewrite());
-          transform::PassContext pass_ctx = transform::PassContext::Current();
-          tir::PrimFunc f = WithAttr(GetRef<tir::PrimFunc>(prim_func), 
"global_symbol",
-                                     runtime::String(g_var->name_hint));
-          IRModule mod = IRModule(Map<GlobalVar, 
BaseFunc>({{GlobalVar(g_var->name_hint), f}}));
-          lowered = tvm::transform::Sequential(pass_list)(std::move(mod));
-        } catch (const dmlc::Error& e) {
-          return false;
-        }
-        if (!Verify(lowered)) {
-          return false;
-        }
-      }
+    IRModule lowered{nullptr};
+    auto pass_list = tir::GetVTCMCompactionPasses();
+    transform::PassContext pass_ctx = transform::PassContext::Current();
+    lowered = tvm::transform::Sequential(pass_list)(std::move(mod));
+    if (!Verify(lowered)) {
+      return false;
     }
     return true;
   }
diff --git a/src/tir/analysis/calculate_allocated_memory.cc 
b/src/tir/analysis/calculate_allocated_memory.cc
index 8680f57e4c..3a41c5ac5a 100644
--- a/src/tir/analysis/calculate_allocated_memory.cc
+++ b/src/tir/analysis/calculate_allocated_memory.cc
@@ -27,6 +27,7 @@
 #include <tvm/tir/analysis.h>
 #include <tvm/tir/function.h>
 #include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
 #include <tvm/tir/usmp/utils.h>
 
 #include <algorithm>
@@ -109,6 +110,18 @@ 
TVM_REGISTER_GLOBAL("tir.analysis.calculate_allocated_bytes")
       }
     });
 
+bool VerifyVTCMLimit(const IRModule& mod, Integer limit) {
+  auto all_sizes = CalculateAllocatedBytes(mod);
+  for (const auto& kv : all_sizes) {
+    auto sizes = kv.second;
+    const auto vtcm_allocated = sizes.Get("global.vtcm").value_or(0);
+    if (limit.IntValue() > 0 && vtcm_allocated.IntValue() > limit.IntValue()) {
+      return false;
+    }
+  }
+  return true;
+}
+
 bool VerifyVTCMLimit(const PrimFunc& func, Integer limit) {
   auto sizes = CalculateAllocatedBytes(func)["main"];
   const auto vtcm_allocated = sizes.Get("global.vtcm").value_or(0);
@@ -127,6 +140,26 @@ int64_t GetVTCMCapacity(Target target, const 
transform::PassContext& pass_ctx) {
   return pass_ctx->GetConfig<Integer>("tir.vtcm_capacity", 
Integer(0)).value()->value;
 }
 
+Array<tvm::transform::Pass> GetVTCMCompactionPasses() {
+  auto pass_list = Array<tvm::transform::Pass>();
+  pass_list.push_back(tir::transform::LowerInitBlock());
+  pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation());
+  pass_list.push_back(tir::transform::ConvertBlocksToOpaque());
+  pass_list.push_back(tir::transform::CompactBufferAllocation());
+  pass_list.push_back(tir::transform::LowerMatchBuffer());
+  pass_list.push_back(tir::transform::InjectSoftwarePipeline());
+  pass_list.push_back(tir::transform::LowerOpaqueBlock());
+  pass_list.push_back(tir::transform::FlattenBuffer());
+  pass_list.push_back(tir::transform::Simplify());
+  pass_list.push_back(tir::transform::VectorizeLoop(true));
+  pass_list.push_back(tir::transform::StorageRewrite());
+  return pass_list;
+}
+
+TVM_REGISTER_GLOBAL("tir.analysis.get_vtcm_compaction_passes").set_body_typed([]()
 {
+  return GetVTCMCompactionPasses();
+});
+
 namespace transform {
 
 Pass VerifyVTCMLimit(Optional<Target> default_target) {

Reply via email to