jverma-quic commented on code in PR #12971:
URL: https://github.com/apache/tvm/pull/12971#discussion_r988341680


##########
src/tir/transforms/profile_instrumentation.cc:
##########
@@ -0,0 +1,287 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file profile_instrumentation.cc
+ */
+// Insert profile intrinsics at the loop and function levels. During codegen,
+// these intrinsics can be replaced with calls to a target-specific handler
+// and used to capture profiling information such as processor cycles.
+
+#include <tvm/tir/builtin.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/stmt.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
+namespace tvm {
+namespace tir {
+namespace lwp {
+
+// Pass config options for the "lwp" profile instrumentation. How each option is
+// consumed is not visible in this excerpt. NOTE(review): presumably they are read by
+// the pass entry point and forwarded to InstrumentIntrin (max depth / min height /
+// instr_siblings); confirm.
+TVM_REGISTER_PASS_CONFIG_OPTION("tir.lwp_disable_func_prof", Bool);
+TVM_REGISTER_PASS_CONFIG_OPTION("tir.lwp_max_depth", Integer);
+TVM_REGISTER_PASS_CONFIG_OPTION("tir.lwp_min_height", Integer);
+TVM_REGISTER_PASS_CONFIG_OPTION("tir.instr_siblings", Bool);
+TVM_REGISTER_PASS_CONFIG_OPTION("tir.reset_start_id", Bool);
+
+// Next loop id to hand out. LoopAnalyzer reads and post-increments it for every loop
+// it records. Because it is file-scope state, ids keep growing across analyzed
+// functions (presumably until "tir.reset_start_id" resets it; not visible here).
+// Accesses are not synchronized.
+static int32_t start_id = 0;
+
+// Per-loop facts computed by LoopAnalyzer and consumed by InstrumentIntrin.
+// Every member has a default initializer, so a default-constructed LoopInfo (e.g. one
+// created by std::unordered_map::operator[]) never holds indeterminate values.
+struct LoopInfo {
+  LoopInfo() = default;
+  LoopInfo(unsigned i, unsigned d, unsigned h = 0) : id(i), depth(d), height(h) {}
+  // Unique number of the loop, taken from the file-scope 'start_id' counter.
+  unsigned id = 0;
+  // Nesting depth; an outermost loop has depth 0.
+  unsigned depth = 0;
+  // Height of the loop nest below this loop; an inner-most loop has height 0.
+  int32_t height = 0;
+  // Set to 'true' if another loop is a direct element of the same SeqStmt.
+  bool has_siblings = false;
+  // Set to 'true' if one of the loop's ancestors is ForKind::kParallel. A kParallel
+  // loop itself keeps 'false' because its intrinsic is added outside of the loop.
+  bool has_parallel = false;
+};
+
+using LoopInfoMap = std::unordered_map<const ForNode*, LoopInfo>;
+// Traverse loops depth first and assign them a unique number, recording for each
+// loop its depth, height, sibling status and whether it sits inside a parallel loop.
+class LoopAnalyzer : public StmtExprVisitor {
+ public:
+  LoopInfoMap Analyze(const Stmt& stmt) {
+    this->VisitStmt(stmt);
+    return loops;
+  }
+
+  // Entry point for an outermost loop (depth 0). The loop itself may be instrumented
+  // (its intrinsic goes outside the loop), but when it is ForKind::kParallel every
+  // loop nested inside it is marked has_parallel, exactly as for nested loops.
+  // NOTE(review): outermost loops are never marked has_siblings, even when several of
+  // them share a top-level SeqStmt; confirm whether that is intended.
+  void VisitStmt_(const ForNode* op) final {
+    LoopInfo loop_info(start_id, 0);
+    start_id++;
+    loop_info.height = TraverseLoop(op->body, 0, op->kind == ForKind::kParallel);
+    loops[op] = loop_info;
+  }
+
+  // Records LoopInfo for every loop reachable from `stmt` and returns the height that
+  // `stmt` contributes to its enclosing loop: 0 when it contains no loop, otherwise
+  // one more than the height of its tallest loop.
+  // `parent_depth` is the depth of the enclosing loop; `has_parallel` is 'true' when an
+  // enclosing loop is ForKind::kParallel.
+  unsigned TraverseLoop(const Stmt& stmt, unsigned parent_depth, bool has_parallel = false) {
+    if (const auto* f = stmt.as<ForNode>()) {
+      return TraverseFor(f, parent_depth, has_parallel);
+    }
+    if (const auto* n = stmt.as<SeqStmtNode>()) {
+      unsigned height = 0;
+      std::vector<const ForNode*> siblings;
+      for (const Stmt& s : n->seq) {
+        // Traverse every element so that loops nested under an if/let/attr/allocate
+        // inside the sequence are analyzed too; only direct loops count as siblings.
+        if (const auto* f = s.as<ForNode>()) siblings.push_back(f);
+        height = std::max(height, TraverseLoop(s, parent_depth, has_parallel));
+      }
+      if (siblings.size() > 1) {
+        for (const ForNode* l : siblings) {
+          loops[l].has_siblings = true;
+        }
+      }
+      return height;  // Parent's height: max over all children
+    }
+    if (const auto* n = stmt.as<IfThenElseNode>()) {
+      unsigned height = TraverseLoop(n->then_case, parent_depth, has_parallel);
+      if (n->else_case.defined()) {
+        height = std::max(height, TraverseLoop(n->else_case, parent_depth, has_parallel));
+      }
+      return height;
+    }
+    if (const auto* n = stmt.as<LetStmtNode>()) {
+      return TraverseLoop(n->body, parent_depth, has_parallel);
+    }
+    if (const auto* n = stmt.as<AttrStmtNode>()) {
+      return TraverseLoop(n->body, parent_depth, has_parallel);
+    }
+    if (const auto* n = stmt.as<AllocateNode>()) {
+      return TraverseLoop(n->body, parent_depth, has_parallel);
+    }
+    return 0;  // No loop below this statement
+  }
+
+ private:
+  // Records a nested loop one level below `parent_depth` and returns its height + 1.
+  // The id is taken before the body is visited, so ids follow pre-order.
+  unsigned TraverseFor(const ForNode* f, unsigned parent_depth, bool has_parallel) {
+    LoopInfo loop_info(start_id, parent_depth + 1);
+    start_id++;
+    // A loop inside a parallel region must not be instrumented. A kParallel loop itself
+    // stays instrumentable (its intrinsic is added outside of the loop), but nothing
+    // nested inside it is.
+    loop_info.has_parallel = has_parallel;
+    const bool body_parallel = has_parallel || f->kind == ForKind::kParallel;
+    // Use this loop's own subtree height, not a running max over its siblings.
+    const unsigned height = TraverseLoop(f->body, parent_depth + 1, body_parallel);
+    loop_info.height = height;
+    loops[f] = loop_info;
+    return height + 1;
+  }
+
+  LoopInfoMap loops;
+};
+
+class InstrumentIntrin : public StmtMutator {
+ public:
+  InstrumentIntrin(int32_t max_depth, int32_t min_height, bool instr_siblings)
+      : max_instr_depth_(max_depth),
+        min_instr_height_(min_height),
+        instr_siblings_(instr_siblings) {}
+
+  void GetLoopInfo(PrimFuncNode* op) {
+    LoopAnalyzer analzer;
+    loops_ = std::move(analzer.Analyze(op->body));
+  }
+
+  Stmt VisitStmt_(const SeqStmtNode* op) final {
+    Stmt stmt = StmtMutator::VisitStmt_(op);
+    return SeqStmt::Flatten(stmt);
+  }
+
+  Stmt VisitStmt_(const ForNode* op) final {
+    Stmt stmt = StmtMutator::VisitStmt_(op);
+    if (loops_.count(op) < 1) return stmt;
+
+    LoopInfo loop_info = loops_[op];
+
+    if (loop_info.has_parallel) {
+      return stmt;
+    }
+
+    // Exclude inner-most loops from instrumentation. The inner-most loop has
+    // height '0' and it increases as we move outward in the loop nest.
+    if (loop_info.height < min_instr_height_) {
+      return stmt;
+    }
+
+    // Only instrument loops with a sibling
+    if (instr_siblings_ && !loop_info.has_siblings) {
+      return stmt;
+    }
+
+    // If instr_siblings_ is set, ignore max depth for instrumentation
+    if (!instr_siblings_ && loop_info.depth > max_instr_depth_) {
+      return stmt;
+    }
+    PrimExpr id = static_cast<int32_t>(loop_info.id);
+    PrimExpr call = Call(DataType::Handle(), builtin::profile_intrinsic(), 
{id});

Review Comment:
   Sure! It makes sense to have two separate intrinsics, one marking the start
and one marking the end of each code segment. That would definitely be more
useful for some metrics.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscr...@tvm.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

Reply via email to