This is an automated email from the ASF dual-hosted git repository.

xiyou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 647be2b425 [MetaSchedule] Add `from-target` Defaults for LLVM VNNI Targets (#13383)
647be2b425 is described below

commit 647be2b42510bffb3ed78267c19e76263adcac36
Author: Xiyou Zhou <xi...@octoml.ai>
AuthorDate: Mon Nov 14 20:58:12 2022 -0800

    [MetaSchedule] Add `from-target` Defaults for LLVM VNNI Targets (#13383)
---
 include/tvm/meta_schedule/mutator.h                |  2 +
 include/tvm/meta_schedule/postproc.h               |  2 +
 include/tvm/meta_schedule/schedule_rule.h          |  2 +
 src/meta_schedule/mutator/mutator.cc               |  2 +
 src/meta_schedule/postproc/postproc.cc             |  8 ++++
 src/meta_schedule/schedule_rule/schedule_rule.cc   | 45 ++++++++++++++++++++++
 .../space_generator/space_generator.cc             | 11 ++++++
 7 files changed, 72 insertions(+)

diff --git a/include/tvm/meta_schedule/mutator.h b/include/tvm/meta_schedule/mutator.h
index 08a8248dfd..4095d6ca03 100644
--- a/include/tvm/meta_schedule/mutator.h
+++ b/include/tvm/meta_schedule/mutator.h
@@ -131,6 +131,8 @@ class Mutator : public runtime::ObjectRef {
                                    FApply f_apply, FClone f_clone, FAsString f_as_string);
   /*! \brief Create default mutators for LLVM */
   TVM_DLL static Map<Mutator, FloatImm, void> DefaultLLVM();
+  /*! \brief Create default mutators for x86 VNNI */
+  TVM_DLL static Map<Mutator, FloatImm, void> DefaultVNNI();
   /*! \brief Create default mutators for CUDA */
   TVM_DLL static Map<Mutator, FloatImm, void> DefaultCUDA();
   /*! \brief Create default mutators for CUDA with TensorCore */
diff --git a/include/tvm/meta_schedule/postproc.h b/include/tvm/meta_schedule/postproc.h
index a680a64795..13fe470587 100644
--- a/include/tvm/meta_schedule/postproc.h
+++ b/include/tvm/meta_schedule/postproc.h
@@ -152,6 +152,8 @@ class Postproc : public runtime::ObjectRef {
   TVM_DLL static Postproc RewriteLayout();
   /*! \brief Create default postprocessors for LLVM */
   TVM_DLL static Array<Postproc, void> DefaultLLVM();
+  /*! \brief Create default postprocessors for x86 VNNI */
+  TVM_DLL static Array<Postproc, void> DefaultVNNI();
   /*! \brief Create default postprocessors for CUDA */
   TVM_DLL static Array<Postproc, void> DefaultCUDA();
   /*! \brief Create default postprocessors for CUDA with TensorCore */
diff --git a/include/tvm/meta_schedule/schedule_rule.h b/include/tvm/meta_schedule/schedule_rule.h
index 70dec47e60..a3d6c7ef68 100644
--- a/include/tvm/meta_schedule/schedule_rule.h
+++ b/include/tvm/meta_schedule/schedule_rule.h
@@ -285,6 +285,8 @@ class ScheduleRule : public runtime::ObjectRef {
 
   /*! \brief Create default schedule rules for LLVM */
   TVM_DLL static Array<ScheduleRule, void> DefaultLLVM();
+  /*! \brief Create default schedule rules for x86 VNNI */
+  TVM_DLL static Array<ScheduleRule, void> DefaultVNNI();
   /*! \brief Create default schedule rules for CUDA */
   TVM_DLL static Array<ScheduleRule, void> DefaultCUDA();
   /*! \brief Create default postprocessors for CUDA with TensorCore */
diff --git a/src/meta_schedule/mutator/mutator.cc b/src/meta_schedule/mutator/mutator.cc
index 8e9bfc8bde..8f3d14b6c4 100644
--- a/src/meta_schedule/mutator/mutator.cc
+++ b/src/meta_schedule/mutator/mutator.cc
@@ -59,6 +59,8 @@ Map<Mutator, FloatImm> Mutator::DefaultLLVM() {
       {Mutator::MutateParallel(/*max_jobs_per_core=*/16), FloatImm(DataType::Float(64), 0.02)}};
 }
 
+Map<Mutator, FloatImm> Mutator::DefaultVNNI() { return Mutator::DefaultLLVM(); }  // same as LLVM
+
 Map<Mutator, FloatImm> Mutator::DefaultCUDA() {
   return Map<Mutator, FloatImm>{
       {Mutator::MutateTileSize(), FloatImm(DataType::Float(64), 0.9)},
diff --git a/src/meta_schedule/postproc/postproc.cc b/src/meta_schedule/postproc/postproc.cc
index 0738c87112..c614f3230d 100644
--- a/src/meta_schedule/postproc/postproc.cc
+++ b/src/meta_schedule/postproc/postproc.cc
@@ -59,6 +59,14 @@ Array<Postproc> Postproc::DefaultLLVM() {
   };
 }
 
+Array<Postproc> Postproc::DefaultVNNI() {  // default postprocessors for x86 VNNI targets
+  return Array<Postproc>{
+      Postproc::DisallowDynamicLoop(),   Postproc::RewriteParallelVectorizeUnroll(),
+      Postproc::RewriteReductionBlock(), Postproc::RewriteTensorize(/*vectorize_init_loop=*/true),
+      Postproc::RewriteLayout(),
+  };
+}
+
 Array<Postproc> Postproc::DefaultCUDA() {
   return Array<Postproc>{
       Postproc::DisallowDynamicLoop(),
diff --git a/src/meta_schedule/schedule_rule/schedule_rule.cc b/src/meta_schedule/schedule_rule/schedule_rule.cc
index b1e8c3695d..e4f97c1fa6 100644
--- a/src/meta_schedule/schedule_rule/schedule_rule.cc
+++ b/src/meta_schedule/schedule_rule/schedule_rule.cc
@@ -85,6 +85,51 @@ Array<ScheduleRule> ScheduleRule::DefaultLLVM() {
   };
 }
 
+Array<ScheduleRule> ScheduleRule::DefaultVNNI() {  // default schedule rules for x86 VNNI
+  return {
+      ScheduleRule::ApplyCustomRule(),
+      ScheduleRule::InlineConstantScalars(),
+      ScheduleRule::AutoInline(
+          /*into_producer=*/false,
+          /*into_consumer=*/true,
+          /*inline_const_tensor=*/true,
+          /*disallow_if_then_else=*/true,
+          /*require_injective=*/true,
+          /*require_ordered=*/true,
+          /*disallow_op=*/Array<String>{"tir.exp"}),
+      ScheduleRule::AddRFactor(
+          /*max_jobs_per_core=*/16,
+          /*max_innermost_factor=*/Integer(64)),
+      ScheduleRule::MultiLevelTilingWithIntrin(  // listed before the plain tiling rule below
+          /*intrin_name=*/"dot_16x4_vnni",
+          /*structure=*/"SSRSRS",
+          /*tile_binds=*/NullOpt,
+          /*max_innermost_factor=*/Integer(64),
+          /*vector_load_lens=*/NullOpt,
+          /*reuse_read=*/NullOpt,
+          /*reuse_write=*/
+          Map<String, ObjectRef>{{"req", String("may")},
+                                 {"levels", Array<Integer>{1, 2}},
+                                 {"scope", String("global")}}),
+      ScheduleRule::MultiLevelTiling(  // same arguments as the rule above, minus intrin_name
+          /*structure=*/"SSRSRS",
+          /*tile_binds=*/NullOpt,
+          /*max_innermost_factor=*/Integer(64),
+          /*vector_load_lens=*/NullOpt,
+          /*reuse_read=*/NullOpt,
+          /*reuse_write=*/
+          Map<String, ObjectRef>{{"req", String("may")},
+                                 {"levels", Array<Integer>{1, 2}},
+                                 {"scope", String("global")}}),
+      ScheduleRule::ParallelizeVectorizeUnroll(
+          /*max_jobs_per_core=*/16,
+          /*max_vectorize_extent=*/64,
+          /*unroll_max_steps=*/Array<Integer>{0, 16, 64, 512},
+          /*unroll_explicit=*/true),
+      ScheduleRule::RandomComputeLocation(),
+  };
+}
+
 Array<ScheduleRule> ScheduleRule::DefaultCUDA() {
   return {
       ScheduleRule::ApplyCustomRule(),
diff --git a/src/meta_schedule/space_generator/space_generator.cc b/src/meta_schedule/space_generator/space_generator.cc
index bcc0673e59..bd124511b8 100644
--- a/src/meta_schedule/space_generator/space_generator.cc
+++ b/src/meta_schedule/space_generator/space_generator.cc
@@ -23,6 +23,13 @@ namespace meta_schedule {
 
 String GetRuleKindFromTarget(const Target& target) {
   if (target->kind->name == "llvm") {
+    static const PackedFunc* f_check_vnni =  // Get() yields nullptr if the name is unregistered
+        runtime::Registry::Get("tvm.topi.x86.utils.target_has_vnni");
+    ICHECK(f_check_vnni != nullptr) << "The `target_has_vnni` func is not in tvm registry.";
+    if (target->GetAttr<String>("mcpu") &&  // without an mcpu attr the kind stays "llvm"
+        (*f_check_vnni)(target->GetAttr<String>("mcpu").value())) {
+      return "vnni";
+    }
     return "llvm";
   }
   if (target->kind->name == "hexagon") {
@@ -79,6 +86,10 @@ void SpaceGeneratorNode::InitializeWithTuneContext(const TuneContext& context) {
       default_sch_rules = ScheduleRule::DefaultHexagon();
       default_postprocs = Postproc::DefaultHexagon();
       default_mutator_probs = Mutator::DefaultHexagon();
+    } else if (kind == "vnni") {
+      default_sch_rules = ScheduleRule::DefaultVNNI();
+      default_postprocs = Postproc::DefaultVNNI();
+      default_mutator_probs = Mutator::DefaultVNNI();
     } else {
       LOG(FATAL) << "Unsupported kind: " << kind;
       throw;

Reply via email to