masahi commented on code in PR #12059:
URL: https://github.com/apache/tvm/pull/12059#discussion_r918501599


##########
python/tvm/meta_schedule/testing/schedule_rule.py:
##########
@@ -110,6 +112,38 @@ def multi_level_tiling(target: Target) -> ScheduleRule:
     raise NotImplementedError(f"{target.kind.name} is not supported")
 
 
+def multi_level_tiling_tensor_core(target: Target, scope="shared") -> 
ScheduleRule:
+    """Default schedule rules with multi-level tiling reuse for tensor 
core"""
+    assert scope in ["shared", "global"]
+    if target.kind.name == "cuda":
+        return MultiLevelTilingTensorCore(
+            intrin_group={
+                "init": tensor_intrin.WMMA_FILL_16x16x16_F32_INTRIN,
+                "load_a": tensor_intrin.WMMA_LOAD_16x16x16_F16_A_INTRIN,
+                "load_b": tensor_intrin.WMMA_LOAD_16x16x16_F16_B_INTRIN,
+                "compute": tensor_intrin.WMMA_SYNC_16x16x16_f16f16f32_INTRIN,
+                "store": tensor_intrin.WMMA_STORE_16x16x16_F32_SHARED_INTRIN
+                if scope == "shared"
+                else tensor_intrin.WMMA_STORE_16x16x16_F32_GLOBAL_INTRIN,

Review Comment:
   Need to check dtype



##########
include/tvm/meta_schedule/schedule_rule.h:
##########
@@ -173,6 +173,32 @@ class ScheduleRule : public runtime::ObjectRef {
       Optional<Integer> max_innermost_factor, Optional<Array<Integer>> 
vector_load_lens,
       Optional<Map<String, ObjectRef>> reuse_read, Optional<Map<String, 
ObjectRef>> reuse_write);
 
+  /*!
+   * \brief Extension of MultiLevelTiling for auto-tensorizing with a single 
group of tensor core
+   * intrinsics
+   * \param intrin_group A group of tensor core intrinsics. The map should 
contain the keys "init",
+   * "load_a", "load_b", "compute", "store", which represent the tensor intrin 
for initialization,
+   * loading operand A, loading operand B, tensor core computation, storing 
the result. The value of
+   * the map should be names of tensor intrinsics, which must be registered via 
TensorIntrin.register(...)
+   * beforehand
+   * \param structure The tiling structure. Recommended:
+   * - 'SSRSRS' on CPU
+   * - 'SSSRRSRS' on GPU
+   * \param tile_binds For each level of tiles, which thread axis it is bound 
to. Recommended:
+   * - NullOpt on CPU

Review Comment:
   CPU not relevant at L185 and L188



##########
include/tvm/meta_schedule/schedule_rule.h:
##########
@@ -173,6 +173,32 @@ class ScheduleRule : public runtime::ObjectRef {
       Optional<Integer> max_innermost_factor, Optional<Array<Integer>> 
vector_load_lens,
       Optional<Map<String, ObjectRef>> reuse_read, Optional<Map<String, 
ObjectRef>> reuse_write);
 
+  /*!
+   * \brief Extension of MultiLevelTiling for auto-tensorizing with a single 
group of tensor core
+   * intrinsics
+   * \param intrin_group A group of tensor core intrinsics. The map should 
contain the keys "init",
+   * "load_a", "load_b", "compute", "store", which represent the tensor intrin 
for initialization,
+   * loading operand A, loading operand B, tensor core computation, storing 
the result. The value of
+   * the map should be names of tensor intrinsics, which must be registered via 
TensorIntrin.register(...)
+   * beforehand
+   * \param structure The tiling structure. Recommended:
+   * - 'SSRSRS' on CPU
+   * - 'SSSRRSRS' on GPU
+   * \param tile_binds For each level of tiles, which thread axis it is bound 
to. Recommended:
+   * - NullOpt on CPU

Review Comment:
   Or do we expect this class to be useful for Intel AMX? (CPU with matrix 
intrinsic)



##########
src/meta_schedule/postproc/rewrite_tensorize.cc:
##########
@@ -35,26 +35,24 @@ void CollectTensorizationJobs(
   tir::PostOrderVisit(func->body, [=, &jobs](const ObjectRef& obj) {
     if (const auto* block = obj.as<tir::BlockNode>()) {
       tir::StmtSRef block_sref = sch->GetSRef(block);
+      std::string block_name = block_sref->StmtAs<tir::BlockNode>()->name_hint;
       if (Optional<String> intrin_name =
               tir::GetAnn<String>(block_sref, 
tir::attr::meta_schedule_auto_tensorize)) {
-        std::string block_name = 
block_sref->StmtAs<tir::BlockNode>()->name_hint;
-        if (block_name.find("init") == std::string::npos) {
-          jobs->emplace_back(block_name, func_name, [sch, 
intrin_name](tir::BlockRV block) {
-            try {
-              sch->Tensorize(block, intrin_name.value());
-            } catch (const std::exception& e) {
-              LOG(WARNING) << "Tensorize failed with error " << e.what();
-            }
-          });
-        } else if (vectorize_init_loop) {
-          jobs->emplace_back(block_name, func_name, [sch](tir::BlockRV block) {
-            Array<BlockRV> child_blocks = sch->GetChildBlocks(block);
-            ICHECK(child_blocks.size() == 1);
-            Array<LoopRV> init_loops = sch->GetLoops(child_blocks[0]);
-            ICHECK(init_loops.size() == 1);
-            sch->Vectorize(init_loops[0]);
-          });
-        }
+        jobs->emplace_back(block_name, func_name, [sch, 
intrin_name](tir::BlockRV block) {
+          try {
+            sch->Tensorize(block, intrin_name.value());
+          } catch (const std::exception& e) {
+            LOG(WARNING) << "Tensorize failed with error " << e.what();
+          }
+        });
+      } else if (block_name.find("init") && vectorize_init_loop) {

Review Comment:
   Do we ever hit this condition after your change in 
[rewrite_reduction_block.cc](https://github.com/apache/tvm/pull/12059/files#diff-470a1ee8bb8d9ce151669661a93e24fe3b9df3094d431026d0d54fda5b2e2adf)?
   
   To vectorize init loop, should we switch to using 
`tir::attr::meta_schedule_auto_tensorize_init`?



##########
src/meta_schedule/schedule_rule/multi_level_tiling.h:
##########
@@ -112,6 +117,31 @@ class State : public ObjectRef {
   TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(State, ObjectRef, StateNode);
 };
 
+class TensorCoreStateNode : public StateNode {
+ public:
+  /*! \brief The Tensor Core reindex block A for Tensor Core computation */
+  tir::BlockRV tensor_core_reindex_A;
+  /*! \brief The Tensor Core reindex block B for Tensor Core computation */
+  tir::BlockRV tensor_core_reindex_B;
+  /*! \brief The Tensor Core reindex store block for Tensor Core computation */
+  tir::BlockRV tensor_core_reindex_store;
+
+  State Copy() const final;
+
+  static constexpr const char* _type_key = "meta_schedule.TensorCoreState";
+  TVM_DECLARE_FINAL_OBJECT_INFO(TensorCoreStateNode, StateNode);
+};
+
+class TensorCoreState : public State {
+ public:
+  explicit TensorCoreState(tir::Schedule sch, tir::BlockRV block_rv,
+                           Array<Array<tir::LoopRV>> tiles = {});
+
+  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(TensorCoreState, State, 
TensorCoreStateNode);
+};

Review Comment:
   This class can be moved to `multi_level_tiling_tensor_core.cc` I think



##########
src/meta_schedule/schedule_rule/multi_level_tiling.h:
##########
@@ -112,6 +117,31 @@ class State : public ObjectRef {
   TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(State, ObjectRef, StateNode);
 };
 
+class TensorCoreStateNode : public StateNode {
+ public:
+  /*! \brief The Tensor Core reindex block A for Tensor Core computation */
+  tir::BlockRV tensor_core_reindex_A;
+  /*! \brief The Tensor Core reindex block B for Tensor Core computation */
+  tir::BlockRV tensor_core_reindex_B;
+  /*! \brief The Tensor Core reindex store block for Tensor Core computation */
+  tir::BlockRV tensor_core_reindex_store;
+
+  State Copy() const final;
+
+  static constexpr const char* _type_key = "meta_schedule.TensorCoreState";
+  TVM_DECLARE_FINAL_OBJECT_INFO(TensorCoreStateNode, StateNode);
+};
+
+class TensorCoreState : public State {
+ public:
+  explicit TensorCoreState(tir::Schedule sch, tir::BlockRV block_rv,
+                           Array<Array<tir::LoopRV>> tiles = {});
+
+  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(TensorCoreState, State, 
TensorCoreStateNode);
+};
+
+struct AutoTensorizationState : public State {};

Review Comment:
   Unused?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscr...@tvm.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

Reply via email to