This is an automated email from the ASF dual-hosted git repository.

kparzysz pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new b48fcaba22 [TIR][Hexagon] Use the "target" value in T.func_attr for VTCM limit (#14567)
b48fcaba22 is described below

commit b48fcaba227c6d455c30bec2216183fed9853677
Author: Eric Lunderberg <lunderb...@users.noreply.github.com>
AuthorDate: Thu Apr 13 14:16:30 2023 -0500

    [TIR][Hexagon] Use the "target" value in T.func_attr for VTCM limit (#14567)
    
    * [TIR][Hexagon] Use the "target" value in T.func_attr for VTCM limit
    
    For the VerifyVTCMLimit pass, read the VTCM capacity directly from the
    function's `target` attribute if the function has already been
    annotated with a target.
    
    * Retain passing of the target to VerifyVTCMLimit, used as a fallback
      for functions that are not annotated with a `target` attribute
---
 include/tvm/tir/analysis.h                     |  6 ++--
 src/auto_scheduler/feature.cc                  |  4 +--
 src/driver/driver_api.cc                       | 11 +-------
 src/tir/analysis/calculate_allocated_memory.cc | 39 +++++++++++++++++++-------
 4 files changed, 35 insertions(+), 25 deletions(-)

diff --git a/include/tvm/tir/analysis.h b/include/tvm/tir/analysis.h
index 5bac25faa5..4ed164e5ad 100644
--- a/include/tvm/tir/analysis.h
+++ b/include/tvm/tir/analysis.h
@@ -26,6 +26,7 @@
 
 #include <tvm/ir/module.h>
 #include <tvm/ir/transform.h>
+#include <tvm/target/target.h>
 #include <tvm/tir/expr.h>
 #include <tvm/tir/function.h>
 #include <tvm/tir/op_attr_types.h>
@@ -348,12 +349,13 @@ TVM_DLL Pass VerifyGPUCode(Map<String, PrimExpr> constraints);
 /*!
  * \brief Pass to checks if the size of the allocated vtcm memory satisfies 
the limit
  *
- * \param limit The limit to check.
+ * \param target The target whose VTCM limit should be used for any
+ * functions not already annotated with `tvm::attr::kTarget`.
  *
  * \returns The pass.
  * \sa tvm::tir::CalculateAllocatedBytes
  */
-TVM_DLL Pass VerifyVTCMLimit(const Integer& limit);
+TVM_DLL Pass VerifyVTCMLimit(Optional<Target> target = NullOpt);
 
 /*!
  * \brief Statically check TIR code for out of bounds array access.
diff --git a/src/auto_scheduler/feature.cc b/src/auto_scheduler/feature.cc
index 884215c24a..65cc13eb61 100644
--- a/src/auto_scheduler/feature.cc
+++ b/src/auto_scheduler/feature.cc
@@ -1408,9 +1408,7 @@ void GetPerStoreFeaturesWorkerFunc(const SearchTask& task, const State& state, i
     }
     if (IsHexagonTask(task)) {
       Target target = task->target;
-      const auto vtcm_capacity = 
target->GetAttr<Integer>("vtcm-capacity").value().IntValue();
-      const auto& optimize =
-          
tir::transform::Sequential({tir::transform::VerifyVTCMLimit(vtcm_capacity)});
+      const auto& optimize = 
tir::transform::Sequential({tir::transform::VerifyVTCMLimit(target)});
       optimize(mod);
     }
     const auto& optimize =
diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc
index 1962b9ab3b..486b40c994 100644
--- a/src/driver/driver_api.cc
+++ b/src/driver/driver_api.cc
@@ -544,22 +544,13 @@ runtime::Module build(const IRModule& funcs, const Target& target_arg,
   return TIRToRuntime(inputs, target_host);
 }
 
-int64_t GetVTCMCapacity(Target target, const transform::PassContext& pass_ctx) 
{
-  if (!target.defined()) target = Target::Current(/*allow_not_defined=*/true);
-  if (target.defined() && target->kind->name == "hexagon") {
-    auto value = Downcast<Integer>(target->attrs.at("vtcm-capacity"))->value;
-    if (value > 0) return value;
-  }
-  return pass_ctx->GetConfig<Integer>("tir.vtcm_capacity", 
Integer(0)).value()->value;
-}
-
 transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target 
target) {
   transform::PassContext pass_ctx = transform::PassContext::Current();
 
   Array<Pass> mixed_pass_list;
 
   // VerifyVTCMLimit must occur before LowerVtcmAlloc
-  
mixed_pass_list.push_back(tir::transform::VerifyVTCMLimit(GetVTCMCapacity(target,
 pass_ctx)));
+  mixed_pass_list.push_back(tir::transform::VerifyVTCMLimit(target));
   // LowerVtcmAlloc must occur after any transformations that modify memory 
allocation locations
   mixed_pass_list.push_back(tir::transform::LowerVtcmAlloc());
 
diff --git a/src/tir/analysis/calculate_allocated_memory.cc b/src/tir/analysis/calculate_allocated_memory.cc
index 95fd7f134e..ffdfc1f801 100644
--- a/src/tir/analysis/calculate_allocated_memory.cc
+++ b/src/tir/analysis/calculate_allocated_memory.cc
@@ -96,20 +96,39 @@ bool VerifyVTCMLimit(const PrimFunc& func, Integer limit) {
   return true;
 }
 
+int64_t GetVTCMCapacity(Target target, const transform::PassContext& pass_ctx) 
{
+  if (!target.defined()) target = Target::Current(/*allow_not_defined=*/true);
+  if (target.defined() && target->kind->name == "hexagon") {
+    auto value = Downcast<Integer>(target->attrs.at("vtcm-capacity"))->value;
+    if (value > 0) return value;
+  }
+  return pass_ctx->GetConfig<Integer>("tir.vtcm_capacity", 
Integer(0)).value()->value;
+}
+
 namespace transform {
 
-Pass VerifyVTCMLimit(const Integer& limit) {
+Pass VerifyVTCMLimit(Optional<Target> default_target) {
   auto pass_func = [=](IRModule mod, PassContext ctx) {
     for (auto kv : mod->functions) {
-      if (auto func = kv.second.as<PrimFunc>()) {
-        auto sizes = CalculateAllocatedBytes(func.value());
-        const auto vtcm_allocated = sizes.Get("global.vtcm").value_or(0);
-        if (limit.IntValue() > 0 && vtcm_allocated.IntValue() > 
limit.IntValue()) {
-          LOG(FATAL) << "RuntimeError: The global.vtcm memory allocation limit 
has been "
-                        "exceeded(allocated: "
-                     << vtcm_allocated << ", limit: " << limit << ").\n"
-                     << "In function\n"
-                     << func;
+      if (auto opt = kv.second.as<PrimFunc>()) {
+        auto func = opt.value();
+
+        std::optional<int64_t> limit = std::nullopt;
+        if (auto func_target = func->GetAttr<Target>(tvm::attr::kTarget)) {
+          limit = GetVTCMCapacity(func_target.value(), ctx);
+        } else if (default_target) {
+          limit = GetVTCMCapacity(default_target.value(), ctx);
+        }
+
+        if (limit.has_value() && limit.value() > 0) {
+          auto sizes = CalculateAllocatedBytes(func);
+          const auto vtcm_allocated = sizes.Get("global.vtcm").value_or(0);
+          if (vtcm_allocated.IntValue() > limit.value()) {
+            LOG(FATAL) << "RuntimeError: The global.vtcm memory allocation 
limit has been exceeded "
+                       << "(allocated: " << vtcm_allocated << ", limit: " << 
limit.value() << ").\n"
+                       << "In function\n"
+                       << func;
+          }
         }
       }
     }

Reply via email to