This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 1c52e633c7 [TIR][Schedule] Method returning the function being worked on (#14593)
1c52e633c7 is described below

commit 1c52e633c79afa4a6e5cf90fc97445448e486bf5
Author: Ruihang Lai <ruiha...@cs.cmu.edu>
AuthorDate: Tue Apr 11 21:43:42 2023 -0400

    [TIR][Schedule] Method returning the function being worked on (#14593)
    
    PR #11999 introduced the syntactic-sugar method `work_on` to TIR Schedule,
    along with a new field `func_working_on_` on ScheduleNode. In some
    cases we may want to know which function a ScheduleNode is working on,
    which was not previously supported.
    
    Therefore, this PR introduces a method to ScheduleNode that returns
    the function (more precisely, its GlobalVar) currently being worked on.
    It is exposed in Python as the `Schedule.func_working_on` property, so
    callers can now query which function the schedule is working on.
---
 include/tvm/tir/schedule/schedule.h                  | 2 ++
 python/tvm/tir/schedule/schedule.py                  | 7 ++++++-
 src/tir/schedule/concrete_schedule.h                 | 1 +
 src/tir/schedule/schedule.cc                         | 2 ++
 tests/python/unittest/test_tir_schedule_utilities.py | 1 +
 5 files changed, 12 insertions(+), 1 deletion(-)

diff --git a/include/tvm/tir/schedule/schedule.h 
b/include/tvm/tir/schedule/schedule.h
index c294d0ae87..69f0520117 100644
--- a/include/tvm/tir/schedule/schedule.h
+++ b/include/tvm/tir/schedule/schedule.h
@@ -115,6 +115,8 @@ class ScheduleNode : public runtime::Object {
   virtual ScheduleState state() const = 0;
   /*! \return The internally maintained trace of scheduling program execution 
*/
   virtual Optional<Trace> trace() const = 0;
+  /*! \return The GlobalVar of the func that the schedule is currently working 
on */
+  virtual Optional<GlobalVar> func_working_on() const = 0;
   /*!
    * \brief Instruct the schedule to work on a function in the IRModule.
    *
diff --git a/python/tvm/tir/schedule/schedule.py 
b/python/tvm/tir/schedule/schedule.py
index b19e30848f..34fd649a5d 100644
--- a/python/tvm/tir/schedule/schedule.py
+++ b/python/tvm/tir/schedule/schedule.py
@@ -19,7 +19,7 @@ from typing import Callable, Dict, List, Optional, Tuple, 
Union
 
 from tvm._ffi import register_object as _register_object
 from tvm.error import TVMError, register_error
-from tvm.ir import IRModule, PrimExpr
+from tvm.ir import GlobalVar, IRModule, PrimExpr
 from tvm.runtime import Object, String
 from tvm.tir import Block, Buffer, FloatImm, For, IntImm, PrimFunc
 
@@ -207,6 +207,11 @@ class Schedule(Object):
         """Returns the internally maintained trace of scheduling program 
execution"""
         return _ffi_api.ScheduleGetTrace(self)  # type: ignore # pylint: 
disable=no-member
 
+    @property
+    def func_working_on(self) -> Optional[GlobalVar]:
+        """Returns the GlobalVar of the func that the schedule is currently 
working on"""
+        return _ffi_api.ScheduleGetFuncWorkingOn(self)  # type: ignore # 
pylint: disable=no-member
+
     def work_on(self, func_name: str) -> None:
         """Instruct the schedule to work on a function in the IRModule.
 
diff --git a/src/tir/schedule/concrete_schedule.h 
b/src/tir/schedule/concrete_schedule.h
index eb7c38753c..16065df3cd 100644
--- a/src/tir/schedule/concrete_schedule.h
+++ b/src/tir/schedule/concrete_schedule.h
@@ -64,6 +64,7 @@ class ConcreteScheduleNode : public ScheduleNode {
  public:
   ScheduleState state() const final { return state_; }
   Optional<Trace> trace() const override { return NullOpt; }
+  Optional<GlobalVar> func_working_on() const final { return func_working_on_; 
}
   void WorkOn(const String& func_name) final;
   Schedule Copy() override;
   void Seed(support::LinearCongruentialEngine::TRandState seed) final;
diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc
index 20a044439b..8663ac2b97 100644
--- a/src/tir/schedule/schedule.cc
+++ b/src/tir/schedule/schedule.cc
@@ -50,6 +50,8 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetState")  //
     .set_body_method<Schedule>(&ScheduleNode::state);
 TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetTrace")  //
     .set_body_method<Schedule>(&ScheduleNode::trace);
+TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetFuncWorkingOn")  //
+    .set_body_method<Schedule>(&ScheduleNode::func_working_on);
 TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCopy")  //
     .set_body_method<Schedule>(&ScheduleNode::Copy);
 TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSeed")  //
diff --git a/tests/python/unittest/test_tir_schedule_utilities.py 
b/tests/python/unittest/test_tir_schedule_utilities.py
index ba2c134def..a8be97488b 100644
--- a/tests/python/unittest/test_tir_schedule_utilities.py
+++ b/tests/python/unittest/test_tir_schedule_utilities.py
@@ -193,6 +193,7 @@ def test_tir_schedule_work_on():
         sch.get_block(name="init")
     sch.work_on(func_name="vector_add")
     sch.get_block(name="init")
+    assert sch.func_working_on == sch.mod.get_global_var("vector_add")
 
 
 def test_tir_schedule_get_loops(use_block_name):

Reply via email to