This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new a0edf24c60 [TIR] Refactor BF16Legalize (#14405)
a0edf24c60 is described below

commit a0edf24c60bad81a6f4a4333fbf2b63255a37882
Author: Tianqi Chen <[email protected]>
AuthorDate: Mon Mar 27 22:35:20 2023 -0400

    [TIR] Refactor BF16Legalize (#14405)
    
    This PR refactors BF16Legalize to enable more f32 computations.
    We also split the BF16Legalize into two steps.
    
    - BF16ComputeLegalize changes all computation to f32 while keeping
      the external BF16 storages.
    - BF16StorageLegalize changes all storage to u16.
    
    Now BF16 kernels accept tvm.nd.array inputs that are created with the bfloat16 type.
---
 include/tvm/tir/transform.h                        |  10 +-
 include/tvm/topi/elemwise.h                        |   6 +-
 python/tvm/tir/transform/transform.py              |  45 +-
 src/driver/driver_api.cc                           |   3 +-
 .../postproc/disallow_async_strided_mem_copy.cc    |   2 +-
 src/meta_schedule/postproc/verify_gpu_code.cc      |   2 +-
 src/target/codegen.cc                              |   1 -
 src/target/llvm/codegen_llvm.cc                    |   4 +
 src/target/llvm/llvm_module.cc                     |   1 -
 src/tir/op/op.cc                                   |   2 +
 src/tir/transforms/arg_binder.cc                   |   2 +-
 src/tir/transforms/bf16_legalize.cc                | 696 ++++++++++++++-------
 src/tir/transforms/storage_access.h                |   1 -
 tests/python/frontend/onnx/test_forward.py         |   8 +
 tests/python/unittest/test_target_codegen_llvm.py  |   6 +-
 .../unittest/test_tir_transform_bf16_legalize.py   | 257 +++-----
 16 files changed, 623 insertions(+), 423 deletions(-)

diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h
index 0aaa8b3e8a..d4f537ff31 100644
--- a/include/tvm/tir/transform.h
+++ b/include/tvm/tir/transform.h
@@ -337,11 +337,17 @@ TVM_DLL Pass CombineContextCall();
 TVM_DLL Pass NarrowDataType(int target_bits);
 
 /*!
- * \brief Legalize bf16 typed Ops. Add a cast to fp32
+ * \brief Legalize bf16 compute Ops. Add a cast to fp32
  *   before Ops, then add a cast back to bf16.
  * \return The pass.
  */
-TVM_DLL Pass BF16Legalize();
+TVM_DLL Pass BF16ComputeLegalize();
+
+/*!
+ * \brief Legalize bf16 storage types to u16.
+ * \return The pass.
+ */
+TVM_DLL Pass BF16StorageLegalize();
 
 /*!
  * \brief Rewrite the pointer content type of arguments,
diff --git a/include/tvm/topi/elemwise.h b/include/tvm/topi/elemwise.h
index 49b50019f0..132992c57d 100644
--- a/include/tvm/topi/elemwise.h
+++ b/include/tvm/topi/elemwise.h
@@ -310,11 +310,7 @@ inline Tensor cast(const Tensor& x, DataType type, 
std::string name = "T_cast",
 inline Tensor reinterpret(const Tensor& x, DataType type, std::string name = 
"tensor",
                           std::string tag = kElementWise) {
   return compute(
-      x->shape,
-      [&](const Array<Var>& i) {
-        return tvm::tir::Call(type, tvm::tir::builtin::reinterpret(), {x(i)});
-      },
-      name, tag);
+      x->shape, [&](const Array<Var>& i) { return reinterpret(type, x(i)); }, 
name, tag);
 }
 
 /*!
diff --git a/python/tvm/tir/transform/transform.py 
b/python/tvm/tir/transform/transform.py
index a18d698e54..1df2ac76b5 100644
--- a/python/tvm/tir/transform/transform.py
+++ b/python/tvm/tir/transform/transform.py
@@ -286,59 +286,26 @@ def RemoveStoreUndef():
     return _ffi_api.RemoveStoreUndef()  # type: ignore
 
 
-def BF16Legalize():
-    """Legalize bf16 typed Ops.
-    Runs BF16Promote, BF16CastElimination and BF16TypeLowering
+def BF16ComputeLegalize():
+    """Legalize bf16 compute Ops.
 
     Returns
     -------
     fpass : tvm.transform.Pass
         The result pass
     """
-    return _ffi_api.BF16Legalize()  # type: ignore
+    return _ffi_api.BF16ComputeLegalize()  # type: ignore
 
 
-def BF16Promote():
-    """Promote bf16 to fp32. Add a cast to fp32
-    before Ops, then add a cast back to bf16.
+def BF16StorageLegalize():
+    """Legalize bf16 storage types to u16.
 
     Returns
     -------
     fpass : tvm.transform.Pass
         The result pass
     """
-    return _ffi_api.BF16Promote()  # type: ignore
-
-
-def BF16CastElimination():
-    """Eliminate verbose casting between fp32 and bf16
-    Checks if the AST has the pattern:
-    castto32(castto16(some_fp32_op(...)))
-    The verbose casting is generated by BF16Promote for multiple
-    bf16 Ops in a row. e.g.:
-    X[i] + Y[i] + T[i] =>
-    bf16((float32(bf16((float32(X[i]) + float32(Y[i])))) + float32(T[i])))
-    After this pass:
-    bf16(float32(X[i]) + float32(Y[i]) + float32(T[i]))
-
-    Returns
-    -------
-    fpass : tvm.transform.Pass
-        The result pass
-    """
-    return _ffi_api.BF16CastElimination()  # type: ignore
-
-
-def BF16TypeLowering():
-    """Replace all bf16 type with uint16. Also lower the casting
-    between fp32 and bf16
-
-    Returns
-    -------
-    fpass : tvm.transform.Pass
-        The result pass
-    """
-    return _ffi_api.BF16TypeLowering()  # type: ignore
+    return _ffi_api.BF16StorageLegalize()  # type: ignore
 
 
 def CommonSubexprElimTIR(enable_cse_tir: bool = True, identify_equiv_terms: 
bool = False):
diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc
index 3458376848..569864a29e 100644
--- a/src/driver/driver_api.cc
+++ b/src/driver/driver_api.cc
@@ -218,7 +218,7 @@ Array<tvm::transform::Pass> CreatePassList(bool 
disable_loop_partition) {
   pass_list.push_back(tir::transform::InjectSoftwarePipeline());
   pass_list.push_back(tir::transform::LowerOpaqueBlock());
   pass_list.push_back(tir::transform::FlattenBuffer());
-  pass_list.push_back(tir::transform::BF16Legalize());
+  pass_list.push_back(tir::transform::BF16ComputeLegalize());
   pass_list.push_back(tir::transform::NarrowDataType(32));
   pass_list.push_back(tir::transform::Simplify());
 
@@ -605,6 +605,7 @@ transform::Sequential MixedModulePassManager(IRModule 
mixed_mod, Target target)
   } else {
     mixed_pass_list.push_back(tir::transform::MakePackedAPI());
   }
+  mixed_pass_list.push_back(tir::transform::BF16StorageLegalize());
   mixed_pass_list.push_back(tir::transform::SplitHostDevice());
 
   return transform::Sequential(mixed_pass_list);
diff --git a/src/meta_schedule/postproc/disallow_async_strided_mem_copy.cc 
b/src/meta_schedule/postproc/disallow_async_strided_mem_copy.cc
index 2d1507c899..d654e467f1 100644
--- a/src/meta_schedule/postproc/disallow_async_strided_mem_copy.cc
+++ b/src/meta_schedule/postproc/disallow_async_strided_mem_copy.cc
@@ -138,7 +138,7 @@ class DisallowAsyncStridedMemCopyNode : public PostprocNode 
{
           pass_list.push_back(tir::transform::InjectSoftwarePipeline());
           pass_list.push_back(tir::transform::LowerOpaqueBlock());
           pass_list.push_back(tir::transform::FlattenBuffer());
-          pass_list.push_back(tir::transform::BF16Legalize());
+          pass_list.push_back(tir::transform::BF16ComputeLegalize());
           pass_list.push_back(tir::transform::NarrowDataType(32));
           pass_list.push_back(tir::transform::Simplify());
           pass_list.push_back(tir::transform::InjectVirtualThread());
diff --git a/src/meta_schedule/postproc/verify_gpu_code.cc 
b/src/meta_schedule/postproc/verify_gpu_code.cc
index 0013106b09..3240496afe 100644
--- a/src/meta_schedule/postproc/verify_gpu_code.cc
+++ b/src/meta_schedule/postproc/verify_gpu_code.cc
@@ -169,7 +169,7 @@ class VerifyGPUCodeNode : public PostprocNode {
           pass_list.push_back(tir::transform::InjectSoftwarePipeline());
           pass_list.push_back(tir::transform::LowerOpaqueBlock());
           pass_list.push_back(tir::transform::FlattenBuffer());
-          pass_list.push_back(tir::transform::BF16Legalize());
+          pass_list.push_back(tir::transform::BF16ComputeLegalize());
           pass_list.push_back(tir::transform::NarrowDataType(32));
           pass_list.push_back(tir::transform::Simplify());
           // Phase 2
diff --git a/src/target/codegen.cc b/src/target/codegen.cc
index f6b694cb7c..24dbfebe55 100644
--- a/src/target/codegen.cc
+++ b/src/target/codegen.cc
@@ -46,7 +46,6 @@ runtime::Module Build(IRModule mod, Target target) {
           .value()) {
     mod = tir::transform::SkipAssert()(mod);
   }
-
   auto target_attr_map = 
tvm::TargetKind::GetAttrMap<FTVMTIRToRuntime>("TIRToRuntime");
   if (target_attr_map.count(target->kind)) {
     return target_attr_map[target->kind](mod, target);
diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc
index 365beb5f5d..7c32f3cfa1 100644
--- a/src/target/llvm/codegen_llvm.cc
+++ b/src/target/llvm/codegen_llvm.cc
@@ -828,6 +828,10 @@ void CodeGenLLVM::CreateSerialFor(llvm::Value* begin, 
llvm::Value* end, llvm::Va
 llvm::Value* CodeGenLLVM::CreateCast(DataType from, DataType to, llvm::Value* 
value) {
   llvm::Type* target = DTypeToLLVMType(to);
   if (value->getType() == target) return value;
+  // TODO(tvm-team): consider add native support
+  ICHECK(!from.is_bfloat16()) << "BF16 needs to be storaged lowered first";
+  ICHECK(!to.is_bfloat16()) << "BF16 needs to be storaged lowered first";
+
   if (to.is_handle()) {
     return builder_->CreateBitCast(value, target);
   } else if (to.is_uint() && to.bits() == 1) {
diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc
index 8749925781..50dcd7402a 100644
--- a/src/target/llvm/llvm_module.cc
+++ b/src/target/llvm/llvm_module.cc
@@ -325,7 +325,6 @@ void LLVMModuleNode::Init(const IRModule& mod, const 
Target& target) {
   if (tm->getTargetTriple().isOSDarwin()) {
     module_->addModuleFlag(llvm::Module::Override, "Dwarf Version", 2);
   }
-
   std::string verify_errors_storage;
   llvm::raw_string_ostream verify_errors(verify_errors_storage);
   LOG_IF(FATAL, llvm::verifyModule(*module_, &verify_errors))
diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc
index 828ab01083..4439a9c3d7 100644
--- a/src/tir/op/op.cc
+++ b/src/tir/op/op.cc
@@ -324,6 +324,8 @@ PrimExpr cast(const DataType& t, PrimExpr value, Span span) 
{
 // reinterpret
 PrimExpr reinterpret(const DataType& t, PrimExpr value, Span span) {
   if (value.dtype() == t) return value;
+  ICHECK(value.dtype().bits() * value.dtype().lanes() == t.bits() * t.lanes())
+      << "Bitcast requires size match " << t << " vs " << value.dtype();
   return tir::Call(t, tir::builtin::reinterpret(), {value}, span);
 }
 
diff --git a/src/tir/transforms/arg_binder.cc b/src/tir/transforms/arg_binder.cc
index 2fc3bd2dca..c785b732ab 100644
--- a/src/tir/transforms/arg_binder.cc
+++ b/src/tir/transforms/arg_binder.cc
@@ -184,7 +184,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const 
PrimExpr& device_type,
                    TVMArrayGet(DataType::UInt(16), handle, 
builtin::kArrTypeLanes) ==
                        IntImm(DataType::UInt(16), buffer->dtype.lanes()));
   if (!(buffer->dtype == DataType::Int(1) || buffer->dtype == DataType::Int(4) 
||
-        buffer->dtype == DataType::UInt(4) || buffer->dtype == 
DataType::UInt(16))) {
+        buffer->dtype == DataType::UInt(4))) {
     auto type_msg = tvm::tir::StringImm(type_err_msg.str());
     asserts_.emplace_back(AssertStmt(a_ndim == v_ndim, msg, nop));
     asserts_.emplace_back(AssertStmt(cond, type_msg, nop));
diff --git a/src/tir/transforms/bf16_legalize.cc 
b/src/tir/transforms/bf16_legalize.cc
index 3b89558622..99fad558cf 100644
--- a/src/tir/transforms/bf16_legalize.cc
+++ b/src/tir/transforms/bf16_legalize.cc
@@ -25,173 +25,198 @@
 #include <tvm/runtime/registry.h>
 #include <tvm/tir/builtin.h>
 #include <tvm/tir/op.h>
+#include <tvm/tir/stmt_functor.h>
 #include <tvm/tir/transform.h>
 
 #include <cmath>
 #include <tuple>
 
-#include "../../arith/ir_mutator_with_analyzer.h"
-#include "../../arith/ir_visitor_with_analyzer.h"
-
 namespace tvm {
 namespace tir {
 
-using arith::Analyzer;
-using arith::IRMutatorWithAnalyzer;
-
-class BF16PromoteRewriter : public StmtExprMutator {
+// NOTE: do not touch buffers on the function boundary.
+// Remap internal bf16 buffers to f32 if they meet the following conditions:
+// - constant allocation size
+// - do not have raw pointer access to the buffer
+//
+// populate the buffer_remap and var_remap accordingly.
+class BF16ComputeLegalizePlanner : public StmtExprVisitor {
  public:
-  BF16PromoteRewriter() {}
-
-  Stmt operator()(Stmt s) { return VisitStmt(s); }
-
-  PrimExpr VisitExpr_(const AddNode* op) final;
-  PrimExpr VisitExpr_(const SubNode* op) final;
-  PrimExpr VisitExpr_(const MulNode* op) final;
-  PrimExpr VisitExpr_(const DivNode* op) final;
-  PrimExpr VisitExpr_(const MinNode* op) final;
-  PrimExpr VisitExpr_(const MaxNode* op) final;
-  PrimExpr VisitExpr_(const LTNode* op) final;
-  PrimExpr VisitExpr_(const LENode* op) final;
-  PrimExpr VisitExpr_(const GTNode* op) final;
-  PrimExpr VisitExpr_(const GENode* op) final;
-  PrimExpr VisitExpr_(const CallNode* op) final;
-};
+  BF16ComputeLegalizePlanner(
+      std::unordered_map<Buffer, Buffer, ObjectPtrHash, ObjectPtrEqual>* 
buffer_remap,
+      std::unordered_map<Var, Var, ObjectPtrHash, ObjectPtrEqual>* var_remap)
+      : buffer_remap_(buffer_remap), var_remap_(var_remap) {}
+
+  // run planning to populate buffer remap and var remap.
+  void Plan(PrimFunc func) {
+    this->VisitStmt(func->body);
+    // if there are opaque var access, then we cannot
+    // do remap of var and buffer, post-hoc remove these items.
+    for (Var var : opaque_var_access_) {
+      auto it = var_remap_->find(var);
+      if (it != var_remap_->end()) {
+        var_remap_->erase(it);
+      }
+    }
+    Array<Buffer> drop_buffers;
+    for (auto kv : *buffer_remap_) {
+      if (opaque_var_access_.count(kv.first->data)) {
+        drop_buffers.push_back(kv.first);
+      }
+    }
+    for (Buffer buffer : drop_buffers) {
+      auto it = buffer_remap_->find(buffer);
+      ICHECK(it != buffer_remap_->end());
+      buffer_remap_->erase(it);
+    }
+  }
 
-#define DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(OP, FUNC, NEEDCAST)            
    \
-  PrimExpr BF16PromoteRewriter::VisitExpr_(const OP* op) {                     
    \
-    PrimExpr origin_a = this->VisitExpr(op->a);                                
    \
-    PrimExpr origin_b = this->VisitExpr(op->b);                                
    \
-    bool a_is_bfloat16 = origin_a->dtype.is_bfloat16();                        
    \
-    bool b_is_bfloat16 = origin_b->dtype.is_bfloat16();                        
    \
-    bool both_bfloat16 = a_is_bfloat16 && b_is_bfloat16;                       
    \
-    bool none_bfloat16 = !(a_is_bfloat16 || b_is_bfloat16);                    
    \
-    if (none_bfloat16) {                                                       
    \
-      return GetRef<PrimExpr>(op);                                             
    \
-    }                                                                          
    \
-    DataType float32_dtype(kDLFloat, 32, 1);                                   
    \
-    PrimExpr float32_a = a_is_bfloat16 ? Cast(float32_dtype, origin_a) : 
origin_a; \
-    PrimExpr float32_b = b_is_bfloat16 ? Cast(float32_dtype, origin_b) : 
origin_b; \
-    PrimExpr result = FUNC(float32_a, float32_b);                              
    \
-    DataType bfloat16_dtype(kDLBfloat, 16, 1);                                 
    \
-    bool do_cast = both_bfloat16 && NEEDCAST;                                  
    \
-    return do_cast ? Cast(bfloat16_dtype, result) : result;                    
    \
-  }
-
-DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(AddNode, operator+, true)
-DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(SubNode, operator-, true)
-DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MulNode, operator*, true)
-DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(DivNode, div, true)
-DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MinNode, min, true)
-DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MaxNode, max, true)
-DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(LTNode, operator<, false)
-DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(LENode, operator<=, false)
-DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(GTNode, operator>, false)
-DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(GENode, operator>=, false)
-
-PrimExpr BF16PromoteRewriter::VisitExpr_(const CallNode* op) {
-  Array<PrimExpr> args;
-  for (auto& arg : op->args) {
-    PrimExpr x = this->VisitExpr(arg);
-    if (x.dtype().is_bfloat16()) {
-      DataType fp32_dtype(kDLFloat, 32, x.dtype().lanes());
-      args.push_back(Cast(fp32_dtype, {x}, op->span));
-    } else {
-      args.push_back(x);
+  void VisitStmt_(const AllocateNode* op) final {
+    // remap all intermediate constant buffers to fp32
+    if (op->dtype.is_bfloat16() && op->ConstantAllocationSize() != 0) {
+      DataType dtype = DataType::Float(32, op->dtype.lanes());
+      Var buffer_var = Var(op->buffer_var->name_hint, 
PointerType(PrimType(dtype)));
+      (*var_remap_)[op->buffer_var] = buffer_var;
     }
+    return StmtExprVisitor::VisitStmt_(op);
   }
-  if (op->dtype.is_bfloat16()) {
-    DataType fp32_dtype(kDLFloat, 32, op->dtype.lanes());
-    PrimExpr result_fp32 = Call(fp32_dtype, op->op, args, op->span);
-    return Cast(op->dtype, {result_fp32}, op->span);
-  } else {
-    return Call(op->dtype, op->op, args, op->span);
+
+  void VisitStmt_(const BufferStoreNode* op) final {
+    StmtExprVisitor::VisitStmt_(op);
+    this->PopulateBufferRemap(op->buffer);
   }
-}
 
-/*
- * Eliminate verbose casting between fp32 and bf16
- * Checks if the AST has the pattern:
- *     castto32(castto16(some_fp32_op(...)))
- * The verbose casting is generated by BF16Promote for multiple
- * bf16 Ops in a row. e.g.:
- *  X[i] + Y[i] + T[i] =>
- *  bf16((float32(bf16((float32(X[i]) + float32(Y[i])))) + float32(T[i])))
- * After this pass:
- *  bf16(float32(X[i]) + float32(Y[i]) + float32(T[i]))
- */
-class BF16CastEliminationRewriter : public StmtExprMutator {
- public:
-  BF16CastEliminationRewriter() {}
+  void VisitExpr_(const BufferLoadNode* op) final {
+    StmtExprVisitor::VisitExpr_(op);
+    this->PopulateBufferRemap(op->buffer);
+  }
 
-  Stmt operator()(Stmt s) { return VisitStmt(s); }
+  void VisitStmt_(const DeclBufferNode* op) final {
+    StmtExprVisitor::VisitStmt_(op);
+    this->PopulateBufferRemap(op->buffer);
+  }
 
-  PrimExpr VisitExpr_(const CastNode* op) final {
-    auto op_val = StmtExprMutator::VisitExpr(op->value);
-    if (op->dtype.is_float() && op->dtype.bits() == 32) {
-      // if is cast_to_fp32, check if op->value is cast_to_fp16
-      // and op->value->value is a float32
-      if (auto innercast = op_val.as<CastNode>()) {
-        if (innercast->dtype.is_bfloat16() && 
innercast->value->dtype.is_float() &&
-            innercast->value->dtype.bits() == 32) {
-          return innercast->value;
-        }
-      }
+  void VisitExpr_(const VarNode* op) final {
+    StmtExprVisitor::VisitExpr_(op);
+    Var buffer_var = GetRef<Var>(op);
+    if (buffer_var.dtype().is_handle()) {
+      opaque_var_access_.insert(buffer_var);
     }
-    if (op->value.same_as(op_val)) return GetRef<PrimExpr>(op);
-    return Cast(op->dtype, op_val);
   }
-};
 
-union FloatCaster {
-  uint32_t u32;
-  float f32;
+ private:
+  void PopulateBufferRemap(Buffer buf) {
+    auto var_it = var_remap_->find(buf->data);
+    if (var_it == var_remap_->end()) return;
+
+    Buffer new_buffer(var_it->second, DataType::Float(32, buf->dtype.lanes()), 
buf->shape,
+                      buf->strides, buf->elem_offset, buf->name, 
buf->data_alignment,
+                      buf->offset_factor, buf->buffer_type, 
buf->axis_separators, buf->span);
+    (*buffer_remap_)[buf] = new_buffer;
+  }
+
+  std::unordered_map<Buffer, Buffer, ObjectPtrHash, ObjectPtrEqual>* 
buffer_remap_;
+  std::unordered_map<Var, Var, ObjectPtrHash, ObjectPtrEqual>* var_remap_;
+  std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> opaque_var_access_;
 };
 
-uint16_t RoundToNearestEven(float src) {
-  if (std::isnan(src)) {
-    return UINT16_C(0x7FC0);
-  } else {
-    FloatCaster caster;
-    caster.f32 = src;
-    uint32_t rounding_bias = ((caster.u32 >> 16) & 1) + UINT32_C(0x7FFF);
-    return static_cast<uint16_t>((caster.u32 + rounding_bias) >> 16);
+#define DEFINE_BIOP_EXPR_LEGALIZE(OP, FUNC)                       \
+  PrimExpr VisitExpr_(const OP* op) final {                       \
+    PrimExpr origin_a = PromoteBF16ToF32(this->VisitExpr(op->a)); \
+    PrimExpr origin_b = PromoteBF16ToF32(this->VisitExpr(op->b)); \
+                                                                  \
+    if (origin_a.same_as(op->a) && origin_b.same_as(op->b)) {     \
+      return GetRef<PrimExpr>(op);                                \
+    } else {                                                      \
+      return FUNC(origin_a, origin_b);                            \
+    }                                                             \
   }
-}
 
-/*
- * Lower the bf16 type to int16
- * Lower cast between bf16 and fp32
- * Lower bf16 FloatImm to int16
- */
-class BF16LowerRewriter : public StmtExprMutator {
+// NOTE: Legalize the BF16 computations
+// to floating point computations and only keep the
+// bf16 storage, which can further be legalized by BF16StorageLegalizer.
+// BF16StorageLegalizer will be called at a much later time
+// point in the TIR lowering phases.
+class BF16ComputeLegalizer : public StmtExprMutator {
  public:
-  BF16LowerRewriter() {}
-
-  using StmtExprMutator::operator();
+  PrimFunc Legalize(PrimFunc func) {
+    BF16ComputeLegalizePlanner planner(&buffer_remap_, &var_remap_);
+    planner.Plan(func);
+    auto* n = func.CopyOnWrite();
+    n->body = this->VisitStmt(std::move(n->body));
+    return func;
+  }
 
+ protected:
   PrimExpr VisitExpr_(const CastNode* op) final {
-    PrimExpr op_val = StmtExprMutator::VisitExpr(op->value);
-    DataType uint32_dtype(kDLUInt, 32, op_val->dtype.lanes());
-    DataType float32_dtype(kDLFloat, 32, op_val->dtype.lanes());
-    if (op->value->dtype.is_bfloat16()) {  // cast from bf16
-      PrimExpr uint32_v = Cast(uint32_dtype, op_val);
-      PrimExpr float32_v = Call(float32_dtype, builtin::reinterpret(), 
{uint32_v << 16});
-      bool is_to_float32 = op->dtype.is_float() && op->dtype.bits() == 32;
-      return is_to_float32 ? float32_v : Cast(op->dtype, float32_v);
-    } else if (op->dtype.is_bfloat16()) {  // cast to bf16
-      bool is_from_float32 = op->value->dtype.is_float() && 
op->value->dtype.bits() == 32;
-      PrimExpr float32_v = is_from_float32 ? op_val : Cast(float32_dtype, 
op_val);
-      PrimExpr uint32_v = Call(uint32_dtype, builtin::reinterpret(), 
{float32_v});
-      DataType uint16_dtype(kDLUInt, 16, op_val->dtype.lanes());
-      /* the following TIR is equivalent to the C++ code below:
-      uint32_t rounding_bias = ((U32 >> 16) & 1) + UINT32_C(0x7FFF);
-      return static_cast<uint16_t>((U32 + rounding_bias) >> 16);*/
-      PrimExpr rounding_bias = ((uint32_v >> 16) & 1) + 
make_const(uint16_dtype, 0x7FFF);
-      return Cast(uint16_dtype, {(uint32_v + rounding_bias) >> 16});
-    }
-    if (op->value.same_as(op_val)) return GetRef<PrimExpr>(op);
-    return Cast(op->dtype, op_val);
+    auto op_val = PromoteBF16ToF32(this->VisitExpr(op->value));
+
+    // all casts to BF16 becomes f32
+    if (op->dtype.is_bfloat16()) {
+      return cast(DataType::Float(32, op->dtype.lanes()), op_val);
+    }
+
+    if (op_val.same_as(op->value)) {
+      return GetRef<PrimExpr>(op);
+    } else {
+      return cast(op->dtype, op_val);
+    }
+  }
+
+  PrimExpr VisitExpr_(const SelectNode* op) final {
+    PrimExpr condition = this->VisitExpr(op->condition);
+    PrimExpr true_value = PromoteBF16ToF32(this->VisitExpr(op->true_value));
+    PrimExpr false_value = PromoteBF16ToF32(this->VisitExpr(op->false_value));
+    if (condition.same_as(op->condition) && true_value.same_as(op->true_value) 
&&
+        false_value.same_as(op->false_value)) {
+      return GetRef<PrimExpr>(op);
+    } else {
+      return Select(condition, true_value, false_value);
+    }
+  }
+
+  PrimExpr VisitExpr_(const BroadcastNode* op) final {
+    PrimExpr value = PromoteBF16ToF32(this->VisitExpr(op->value));
+    if (value.same_as(op->value)) {
+      return GetRef<PrimExpr>(op);
+    } else {
+      return Broadcast(value, op->lanes);
+    }
+  }
+
+  PrimExpr VisitExpr_(const ShuffleNode* op) final {
+    auto fexpr = [this](const PrimExpr& e) { return 
PromoteBF16ToF32(this->VisitExpr(e)); };
+    auto vectors = op->vectors.Map(fexpr);
+    if (vectors.same_as(op->vectors)) {
+      return GetRef<PrimExpr>(op);
+    } else {
+      return Shuffle(vectors, op->indices);
+    }
+  }
+
+  PrimExpr VisitExpr_(const CallNode* op) final {
+    // preserve reinterpret<bf16>() behavior.
+    if (op->op.same_as(builtin::reinterpret())) {
+      return StmtExprMutator::VisitExpr_(op);
+    }
+    // update normal computations to return f32 instead.
+    auto fmutate = [this](const PrimExpr& e) { return 
PromoteBF16ToF32(this->VisitExpr(e)); };
+    Array<PrimExpr> args = op->args.Map(fmutate);
+    if (op->dtype.is_bfloat16()) {
+      return Call(DataType::Float(32, op->dtype.lanes()), op->op, args);
+    }
+    if (args.same_as(op->args)) {
+      return GetRef<PrimExpr>(op);
+    } else {
+      return Call(op->dtype, op->op, args);
+    }
+  }
+
+  PrimExpr VisitExpr_(const FloatImmNode* op) final {
+    if (op->dtype.is_bfloat16()) {
+      return FloatImm(DataType::Float(32), op->value);
+    }
+    return GetRef<PrimExpr>(op);
   }
 
   PrimExpr VisitExpr_(const VarNode* op) final {
@@ -205,26 +230,73 @@ class BF16LowerRewriter : public StmtExprMutator {
     }
   }
 
-  Stmt VisitStmt_(const AllocateNode* op) final {
-    if (op->dtype.is_bfloat16()) {
-      DataType dtype = DataType::UInt(16, op->dtype.lanes());
-      Var buffer_var = Var(op->buffer_var->name_hint, 
PointerType(PrimType(dtype)));
-      var_remap_[op->buffer_var] = buffer_var;
-      return VisitStmt(Allocate(buffer_var, dtype, op->extents, op->condition, 
op->body));
+  PrimExpr VisitExpr_(const LetNode* op) final {
+    PrimExpr value = PromoteBF16ToF32(op->value);
+    Var var = op->var;
+    if (value.dtype() != op->value.dtype()) {
+      var = op->var.copy_with_dtype(op->value.dtype());
+      var_remap_[op->var] = var;
+    }
+
+    PrimExpr body = VisitExpr(op->body);
+
+    if (value.same_as(op->value) && var.same_as(op->var) && 
body.same_as(op->body)) {
+      return GetRef<PrimExpr>(op);
     } else {
-      return StmtExprMutator::VisitStmt_(op);
+      return Let(var, value, body);
+    }
+  }
+
+  DEFINE_BIOP_EXPR_LEGALIZE(AddNode, operator+);
+  DEFINE_BIOP_EXPR_LEGALIZE(SubNode, operator-);
+  DEFINE_BIOP_EXPR_LEGALIZE(MulNode, operator*);
+  DEFINE_BIOP_EXPR_LEGALIZE(DivNode, div);
+  DEFINE_BIOP_EXPR_LEGALIZE(MinNode, min);
+  DEFINE_BIOP_EXPR_LEGALIZE(MaxNode, max);
+  DEFINE_BIOP_EXPR_LEGALIZE(LTNode, operator<);  // NOLINT(*)
+  DEFINE_BIOP_EXPR_LEGALIZE(LENode, operator<=);
+  DEFINE_BIOP_EXPR_LEGALIZE(GTNode, operator>);  // NOLINT(*)
+  DEFINE_BIOP_EXPR_LEGALIZE(GENode, operator>=);
+  DEFINE_BIOP_EXPR_LEGALIZE(EQNode, operator==);
+  DEFINE_BIOP_EXPR_LEGALIZE(NENode, operator!=);
+
+  Stmt VisitStmt_(const LetStmtNode* op) final {
+    PrimExpr value = PromoteBF16ToF32(op->value);
+    Var var = op->var;
+    if (value.dtype() != op->value.dtype()) {
+      var = op->var.copy_with_dtype(op->value.dtype());
+      var_remap_[op->var] = var;
+    }
+    Stmt body = VisitStmt(op->body);
+
+    if (value.same_as(op->value) && var.same_as(op->var) && 
body.same_as(op->body)) {
+      return GetRef<Stmt>(op);
+    } else {
+      return LetStmt(var, value, body);
     }
   }
 
   Stmt VisitStmt_(const BufferStoreNode* op) final {
-    Stmt ret = StmtExprMutator::VisitStmt_(op);
-    op = ret.as<BufferStoreNode>();
+    PrimExpr value = this->VisitExpr(op->value);
+    auto fmutate = [this](const PrimExpr& e) { return this->VisitExpr(e); };
+
+    Array<PrimExpr> indices = op->indices.Map(fmutate);
 
     Buffer new_buf = GetRemappedBuffer(op->buffer);
-    if (new_buf.same_as(op->buffer)) {
-      return ret;
+
+    if (value.same_as(op->value) && indices.same_as(op->indices) && 
new_buf.same_as(op->buffer)) {
+      return GetRef<Stmt>(op);
     } else {
-      return BufferStore(new_buf, op->value, op->indices);
+      if (new_buf->dtype.is_bfloat16()) {
+        value = CastF32ToBF16(value);
+      }
+      if (value.dtype() != new_buf->dtype) {
+        // this happens when buffer get rewritten to f32
+        // but values remain as bf16
+        ICHECK(value.dtype().is_bfloat16());
+        value = cast(new_buf->dtype.with_lanes(value.dtype().lanes()), value);
+      }
+      return BufferStore(new_buf, value, indices);
     }
   }
 
@@ -258,6 +330,35 @@ class BF16LowerRewriter : public StmtExprMutator {
     }
   }
 
+  Stmt VisitStmt_(const DeclBufferNode* op) final {
+    Stmt ret = StmtExprMutator::VisitStmt_(op);
+    op = ret.as<DeclBufferNode>();
+
+    Buffer new_buf = GetRemappedBuffer(op->buffer);
+    if (new_buf.same_as(op->buffer)) {
+      return ret;
+    } else {
+      return DeclBuffer(new_buf, op->body);
+    }
+  }
+
+  Stmt VisitStmt_(const AllocateNode* op) final {
+    Stmt ret = StmtExprMutator::VisitStmt_(op);
+    op = ret.as<AllocateNode>();
+
+    auto it = var_remap_.find(op->buffer_var);
+    if (it != var_remap_.end()) {
+      Var remapped_var = it->second;
+      auto* ptr = remapped_var->type_annotation.as<PointerTypeNode>();
+      ICHECK(ptr);
+      auto* prim_type = ptr->element_type.as<PrimTypeNode>();
+      ICHECK(prim_type);
+      return Allocate(remapped_var, prim_type->dtype, op->extents, 
op->condition, op->body);
+    } else {
+      return ret;
+    }
+  }
+
   PrimExpr VisitExpr_(const BufferLoadNode* op) final {
     PrimExpr ret = StmtExprMutator::VisitExpr_(op);
     op = ret.as<BufferLoadNode>();
@@ -270,48 +371,244 @@ class BF16LowerRewriter : public StmtExprMutator {
     }
   }
 
-  PrimExpr VisitExpr_(const FloatImmNode* op) final {
+ private:
+  /*!
+   * \brief promote BF16 to F32 and keep other values unchanged.
+   * \param value The input value.
+   * \return The converted value.
+   */
+  PrimExpr PromoteBF16ToF32(PrimExpr value) {
+    if (!value.dtype().is_bfloat16()) return value;
+    if (const CastNode* cast = value.as<CastNode>()) {
+      if (cast->value.dtype() == DataType::Float(32)) return cast->value;
+    }
+    DataType f32 = DataType::Float(32, value.dtype().lanes());
+    DataType u16 = DataType::UInt(16, value.dtype().lanes());
+    DataType u32 = DataType::UInt(32, value.dtype().lanes());
+    // reinterpret<f32>((cast<u32>(reinterpret<u16>(bf16_value)) << 16))
+    return reinterpret(f32, cast(u32, reinterpret(u16, value)) << 16);
+  }
+
+  /*!
+   * \brief Cast an F32 value to BF16 and keep other values unchanged.
+   * \param value The input value
+   * \return The converted value.
+   */
+  PrimExpr CastF32ToBF16(PrimExpr value) {
+    if (!value.dtype().is_float()) return value;
+    ICHECK_EQ(value.dtype().bits(), 32);
+    DataType bf16 = DataType::BFloat(16, value.dtype().lanes());
+    DataType u16 = DataType::UInt(16, value.dtype().lanes());
+    DataType u32 = DataType::UInt(32, value.dtype().lanes());
+    PrimExpr u32_val = reinterpret(u32, value);
+
+    if (round_to_even_) {
+      PrimExpr rounding_bias = ((u32_val >> 16) & 1) + make_const(u32, 0x7FFF);
+      u32_val = u32_val + rounding_bias;
+    }
+    // reinterpret<bf16>((cast<u16>(reinterpret<u32>(f32_value)) >> 16))
+    return reinterpret(bf16, cast(u16, u32_val >> 16));
+  }
+
+  Buffer GetRemappedBuffer(Buffer buf) {
+    auto buf_it = buffer_remap_.find(buf);
+    if (buf_it != buffer_remap_.end()) {
+      return buf_it->second;
+    }
+    return buf;
+  }
+
+  bool round_to_even_{true};
+
+  std::unordered_map<Buffer, Buffer, ObjectPtrHash, ObjectPtrEqual> 
buffer_remap_;
+  std::unordered_map<Var, Var, ObjectPtrHash, ObjectPtrEqual> var_remap_;
+};
+
+/*!
+ * \brief This Pass legalizes remaining BF16 storages to u16
+ *
+ * This pass needs to happen after BF16ComputeLegalizer and serves
+ * as a way to support BF16 on platforms that do not have native support.
+ */
+class BF16StorageLegalizer : public StmtExprMutator {
+ public:
+  PrimFunc Legalize(PrimFunc func) {
+    ICHECK_EQ(func->buffer_map.size(), 0) << "This pass must be called after 
MakePackedAPI";
+    auto* n = func.CopyOnWrite();
+    n->params = n->params.Map([this](Var var) { return this->RemapVarDef(var); 
});
+    n->body = this->VisitStmt(std::move(n->body));
+    return func;
+  }
+
+ private:
+  PrimExpr VisitExpr_(const VarNode* op) final {
+    Var var = GetRef<Var>(op);
+    auto itr = var_remap_.find(var);
+    if (itr != var_remap_.end()) {
+      return itr->second;
+    } else {
+      return std::move(var);
+    }
+  }
+
+  Stmt VisitStmt_(const AllocateNode* op) final {
     if (op->dtype.is_bfloat16()) {
-      return IntImm(DataType::UInt(16, op->dtype.lanes()),
-                    RoundToNearestEven(static_cast<float>(op->value)));
+      DataType dtype = DataType::UInt(16, op->dtype.lanes());
+      Var buffer_var = Var(op->buffer_var->name_hint, 
PointerType(PrimType(dtype)));
+      var_remap_[op->buffer_var] = buffer_var;
+      return VisitStmt(Allocate(buffer_var, dtype, op->extents, op->condition, 
op->body));
+    } else {
+      return StmtExprMutator::VisitStmt_(op);
     }
-    return StmtExprMutator::VisitExpr_(op);
   }
 
-  void AlterBuffers(PrimFuncNode* op) {
-    Map<Var, Buffer> new_buffer_map;
-
-    for (auto& itr : op->buffer_map) {
-      auto param_var = itr.first;
-      auto oldbuf = itr.second;
-      if (oldbuf->dtype.is_bfloat16()) {
-        DataType dtype = DataType::UInt(16, oldbuf->dtype.lanes());
-        Var buffer_var = Var(oldbuf->data->name_hint, 
PointerType(PrimType(dtype)));
-        auto newbuf = Buffer(buffer_var, dtype, oldbuf->shape, 
oldbuf->strides, oldbuf->elem_offset,
-                             oldbuf->name, oldbuf->data_alignment, 
oldbuf->offset_factor,
-                             oldbuf->buffer_type);
-        buffer_remap_[oldbuf] = newbuf;
-        var_remap_[oldbuf->data] = buffer_var;
-        new_buffer_map.Set(param_var, newbuf);
-      } else {
-        new_buffer_map.Set(param_var, oldbuf);
+  Stmt VisitStmt_(const DeclBufferNode* op) final {
+    Buffer buf = GetRemappedBuffer(op->buffer);
+    // in a rare case the buffer didn't get remapped
+    // because the original var is not bfloat*
+    // force remap here
+    if (buf->dtype.is_bfloat16()) {
+      buf = Buffer(buf->data, DataType::UInt(16, buf->dtype.lanes()), 
buf->shape, buf->strides,
+                   buf->elem_offset, buf->name, buf->data_alignment, 
buf->offset_factor,
+                   buf->buffer_type, buf->axis_separators, buf->span);
+      buffer_remap_[op->buffer] = buf;
+    }
+    Stmt body = VisitStmt(op->body);
+    if (buf.same_as(op->buffer) && body.same_as(op->body)) {
+      return GetRef<Stmt>(op);
+    } else {
+      return DeclBuffer(buf, body, op->span);
+    }
+  }
+
+  PrimExpr VisitExpr_(const LetNode* op) final {
+    PrimExpr value = VisitExpr(op->value);
+    Var var = RemapVarDef(op->var);
+    PrimExpr body = VisitExpr(op->body);
+
+    if (value.same_as(op->value) && var.same_as(op->var) && 
body.same_as(op->body)) {
+      return GetRef<PrimExpr>(op);
+    } else {
+      return Let(var, value, body);
+    }
+  }
+
+  Stmt VisitStmt_(const LetStmtNode* op) final {
+    PrimExpr value = VisitExpr(op->value);
+    Var var = RemapVarDef(op->var);
+    Stmt body = VisitStmt(op->body);
+
+    if (value.same_as(op->value) && var.same_as(op->var) && 
body.same_as(op->body)) {
+      return GetRef<Stmt>(op);
+    } else {
+      return LetStmt(var, value, body);
+    }
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* op) final {
+    PrimExpr value = this->ChangeBF16ToU16(VisitExpr(op->value));
+    Buffer new_buf = GetRemappedBuffer(op->buffer);
+    auto indices = op->indices.Map([this](PrimExpr expr) { return 
this->VisitExpr(expr); });
+    if (new_buf.same_as(op->buffer) && indices.same_as(op->indices) && 
value.same_as(op->value)) {
+      return GetRef<Stmt>(op);
+    } else {
+      if (op->value.dtype().is_bfloat16()) {
+        ICHECK(new_buf->dtype.is_uint());
+      }
+      return BufferStore(new_buf, value, indices);
+    }
+  }
+
+  Stmt VisitStmt_(const AttrStmtNode* op) final {
+    Stmt ret = StmtExprMutator::VisitStmt_(op);
+    op = ret.as<AttrStmtNode>();
+
+    if (auto* buffer = op->node.as<BufferNode>()) {
+      auto it = buffer_remap_.find(GetRef<Buffer>(buffer));
+      if (it != buffer_remap_.end()) {
+        return AttrStmt(it->second, op->attr_key, op->value, op->body);
+      }
+    } else if (auto* var = op->node.as<VarNode>()) {
+      auto it = var_remap_.find(GetRef<Var>(var));
+      if (it != var_remap_.end()) {
+        return AttrStmt(it->second, op->attr_key, op->value, op->body);
       }
     }
+    return ret;
+  }
 
-    if (buffer_remap_.size() != 0) {
-      op->buffer_map = new_buffer_map;
+  Stmt VisitStmt_(const BufferRealizeNode* op) final {
+    LOG(FATAL) << "Do not expect buffer realize";
+  }
+
+  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+    PrimExpr ret = StmtExprMutator::VisitExpr_(op);
+    op = ret.as<BufferLoadNode>();
+    Buffer new_buf = GetRemappedBuffer(op->buffer);
+    if (new_buf.same_as(op->buffer)) {
+      return ret;
+    } else {
+      return BufferLoad(new_buf, op->indices);
     }
   }
 
+  PrimExpr VisitExpr_(const CallNode* op) final {
+    // remap reinterpret so unnecessary reinterprets can be skipped.
+    if (op->op.same_as(builtin::reinterpret())) {
+      PrimExpr value = VisitExpr(op->args[0]);
+      // sometimes the input dtype can change and we can skip.
+      if (value.dtype() == op->dtype) return value;
+      if (op->dtype.is_bfloat16()) {
+        return reinterpret(DataType::UInt(16, op->dtype.lanes()), value);
+      }
+      if (op->args[0].same_as(value)) {
+        return GetRef<PrimExpr>(op);
+      } else {
+        return reinterpret(op->dtype, value);
+      }
+    }
+    return StmtExprMutator::VisitExpr_(op);
+  }
+
  private:
+  /*!
+   * \brief Change BF16 value to U16 value.
+   * \param value The input value.
+   * \return The converted value.
+   */
+  PrimExpr ChangeBF16ToU16(PrimExpr value) {
+    if (!value.dtype().is_bfloat16()) return value;
+    auto* call = value.as<CallNode>();
+    if (call && call->op.same_as(builtin::reinterpret())) {
+      return reinterpret(DataType::UInt(16, value.dtype().lanes()), 
call->args[0]);
+    } else {
+      return value;
+    }
+  }
+
+  Var RemapVarDef(Var var) {
+    // remap the var
+    if (var.dtype().is_handle()) {
+      if (auto* ptr_type = var->type_annotation.as<PointerTypeNode>()) {
+        if (auto* elem_type = ptr_type->element_type.as<PrimTypeNode>()) {
+          if (elem_type->dtype.is_bfloat16()) {
+            Var new_var = Var(var->name_hint,
+                              PointerType(PrimType(DataType::UInt(16, 
elem_type->dtype.lanes()))));
+            var_remap_[var] = new_var;
+            return new_var;
+          }
+        }
+      }
+    }
+    return var;
+  }
+
   Buffer GetRemappedBuffer(Buffer buf) {
     auto buf_it = buffer_remap_.find(buf);
     if (buf_it != buffer_remap_.end()) {
       return buf_it->second;
     }
-
     Buffer new_buf = buf;
-
     auto var_it = var_remap_.find(buf->data);
     if (var_it != var_remap_.end()) {
       DataType dtype =
@@ -319,6 +616,8 @@ class BF16LowerRewriter : public StmtExprMutator {
       new_buf = Buffer(var_it->second, dtype, buf->shape, buf->strides, 
buf->elem_offset, buf->name,
                        buf->data_alignment, buf->offset_factor, 
buf->buffer_type,
                        buf->axis_separators, buf->span);
+    } else {
+      ICHECK(!buf->dtype.is_bfloat16()) << "Cannot find var remap for " << buf;
     }
 
     buffer_remap_[buf] = new_buf;
@@ -332,46 +631,25 @@ class BF16LowerRewriter : public StmtExprMutator {
 
 namespace transform {
 
-Pass BF16Promote() {
+Pass BF16ComputeLegalize() {
   auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
-    auto* n = f.CopyOnWrite();
-    n->body = BF16PromoteRewriter()(std::move(n->body));
-    return f;
+    // TODO(tvm-team): skip if the target supports bf16
+    return BF16ComputeLegalizer().Legalize(f);
   };
-  return CreatePrimFuncPass(pass_func, 0, "tir.BF16Promote", {});
+  return CreatePrimFuncPass(pass_func, 0, "tir.BF16ComputeLegalize", {});
 }
 
-TVM_REGISTER_GLOBAL("tir.transform.BF16Promote").set_body_typed(BF16Promote);
+TVM_REGISTER_GLOBAL("tir.transform.BF16ComputeLegalize").set_body_typed(BF16ComputeLegalize);
 
-Pass BF16CastElimination() {
+Pass BF16StorageLegalize() {
   auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
-    auto* n = f.CopyOnWrite();
-    n->body = BF16CastEliminationRewriter()(std::move(n->body));
-    return f;
+    // TODO(tvm-team): skip if the target supports bf16
+    return BF16StorageLegalizer().Legalize(f);
   };
-  return CreatePrimFuncPass(pass_func, 0, "tir.BF16CastElimination", {});
-}
-
-TVM_REGISTER_GLOBAL("tir.transform.BF16CastElimination").set_body_typed(BF16CastElimination);
-
-Pass BF16TypeLowering() {
-  auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
-    auto* n = f.CopyOnWrite();
-    BF16LowerRewriter lowerer;
-    lowerer.AlterBuffers(n);
-    n->body = lowerer(std::move(n->body));
-    return f;
-  };
-  return CreatePrimFuncPass(pass_func, 0, "tir.BF16TypeLowering", {});
-}
-
-TVM_REGISTER_GLOBAL("tir.transform.BF16TypeLowering").set_body_typed(BF16TypeLowering);
-
-Pass BF16Legalize() {
-  return Sequential({BF16Promote(), BF16CastElimination(), 
BF16TypeLowering()}, "tir.BF16Legalize");
+  return CreatePrimFuncPass(pass_func, 0, "tir.BF16StorageLegalize", {});
 }
 
-TVM_REGISTER_GLOBAL("tir.transform.BF16Legalize").set_body_typed(BF16Legalize);
+TVM_REGISTER_GLOBAL("tir.transform.BF16StorageLegalize").set_body_typed(BF16StorageLegalize);
 
 }  // namespace transform
 }  // namespace tir
diff --git a/src/tir/transforms/storage_access.h 
b/src/tir/transforms/storage_access.h
index 8fac0c302a..119e595f59 100644
--- a/src/tir/transforms/storage_access.h
+++ b/src/tir/transforms/storage_access.h
@@ -139,7 +139,6 @@ class StorageAccessVisitor : public StmtExprVisitor {
   // The involving threads
   Array<IterVar> env_threads_;
 };
-
 }  // namespace tir
 }  // namespace tvm
 #endif  // TVM_TIR_TRANSFORMS_STORAGE_ACCESS_H_
diff --git a/tests/python/frontend/onnx/test_forward.py 
b/tests/python/frontend/onnx/test_forward.py
index 116c023caa..34dda6cf6f 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -99,6 +99,14 @@ def get_tvm_output_with_vm(
             freeze_params=freeze_params,
             convert_config=convert_config,
         )
+        # handle the bfloat16 so we explicitly allocate
+        # bfloat16 arrays as input
+        for i, param in enumerate(mod["main"].params):
+            if param.type_annotation.dtype == "bfloat16":
+                input_data[i] = tvm.nd.empty(input_data[i].shape, 
"bfloat16").copyfrom(
+                    input_data[i]
+                )
+
     if validate_structural_equal:
         with tvm.testing.enable_span_filling():
             mod_with_span, _ = relay.frontend.from_onnx(
diff --git a/tests/python/unittest/test_target_codegen_llvm.py 
b/tests/python/unittest/test_target_codegen_llvm.py
index 44f950c82a..3190115aa6 100644
--- a/tests/python/unittest/test_target_codegen_llvm.py
+++ b/tests/python/unittest/test_target_codegen_llvm.py
@@ -707,7 +707,7 @@ def np_float2tvm_bf16(arr):
     """Convert a numpy array of float to a TVM array
     of bf16"""
     nparr = np_float2np_bf16(arr)
-    return tvm.nd.empty(nparr.shape, "uint16").copyfrom(nparr)
+    return tvm.nd.empty(nparr.shape, "bfloat16").copyfrom(nparr)
 
 
 def np_bf162np_float(arr):
@@ -730,9 +730,9 @@ def test_llvm_bf16():
         B = te.placeholder((32,), dtype="bfloat16")
         d = te.compute((32,), lambda x: A[x] + B[x])
         sch = te.create_schedule(d.op)
-        print(tvm.lower(sch, [A, B, d]))
         if do_vectorize:
             sch[d].vectorize(d.op.axis[0])
+
         module = tvm.build(sch, [A, B, d])
         npa = np.random.rand(32).astype("float32")
         npb = np.random.rand(32).astype("float32")
@@ -741,7 +741,7 @@ def test_llvm_bf16():
         res = np_bf16_cast_and_cast_back(va + vb)
         a_ = np_float2tvm_bf16(npa)
         b_ = np_float2tvm_bf16(npb)
-        c_ = tvm.nd.empty((32,), "uint16")
+        c_ = tvm.nd.empty((32,), "bfloat16")
         module(a_, b_, c_)
         tvm.testing.assert_allclose(np_bf162np_float(c_.numpy()), res)
 
diff --git a/tests/python/unittest/test_tir_transform_bf16_legalize.py 
b/tests/python/unittest/test_tir_transform_bf16_legalize.py
index 1e3c8061e0..ababfd489a 100644
--- a/tests/python/unittest/test_tir_transform_bf16_legalize.py
+++ b/tests/python/unittest/test_tir_transform_bf16_legalize.py
@@ -15,164 +15,105 @@
 # specific language governing permissions and limitations
 # under the License.
 import tvm
-from tvm import topi
-from tvm import te
-
-
-def lower_stmt(sche, params, passfunc):
-    func = tvm.driver.build_module.schedule_to_module(sche, params, "main", 
None)["main"]
-    func = passfunc()(tvm.IRModule.from_expr(func))["main"]
-    stmt = func.body
-    return stmt
-
-
-def test_promote():
-    def runpass(op, passfunc):
-        a = te.placeholder((100,), dtype="bfloat16")
-        b = te.placeholder((100,), dtype="bfloat16")
-        c = te.compute((100,), lambda i: op(a[i], b[i]))
-        s = te.create_schedule(c.op)
-        return lower_stmt(s, [a, b, c], passfunc)
-
-    def get_promoted(op):
-        a = te.placeholder((100,), dtype="bfloat16")
-        b = te.placeholder((100,), dtype="bfloat16")
-        c = te.compute(
-            (100,),
-            lambda i: topi.cast(op(topi.cast(a[i], "float"), topi.cast(b[i], 
"float")), "bfloat16"),
-        )
-        s = te.create_schedule(c.op)
-        func = tvm.driver.build_module.schedule_to_module(s, [a, b, c], 
"main", None)["main"]
-        return func.body
-
-    def test_promoted(op):
-        stmt = runpass(op, tvm.tir.transform.BF16Promote)
-        tvm.ir.assert_structural_equal(stmt, get_promoted(op))
-
-    test_promoted(topi.add)
-    test_promoted(topi.subtract)
-    test_promoted(topi.multiply)
-    test_promoted(topi.divide)
-
-
-def test_eliminate():
-    def to32(v):
-        return topi.cast(v, "float")
-
-    def to16(v):
-        return topi.cast(v, "bfloat16")
-
-    def get_eliminated():
-        a = te.placeholder((100,), dtype="bfloat16")
-        b = te.placeholder((100,), dtype="bfloat16")
-        c = te.compute(
-            (100,),
-            lambda i: to16(
-                topi.add(
-                    to32(
-                        to16(
-                            topi.add(
-                                to32(a[i]),
-                                to32(b[i]),
-                            )
-                        )
-                    ),
-                    to32(
-                        to16(
-                            topi.add(
-                                to32(a[i]),
-                                to32(b[i]),
-                            )
-                        )
-                    ),
-                )
-            ),
-        )
-        s = te.create_schedule(c.op)
-        stmt = lower_stmt(s, [a, b, c], tvm.tir.transform.BF16CastElimination)
-        return stmt
-
-    def get_target():
-        a = te.placeholder((100,), dtype="bfloat16")
-        b = te.placeholder((100,), dtype="bfloat16")
-        c = te.compute(
-            (100,),
-            lambda i: to16(
-                topi.add(
-                    topi.add(
-                        to32(a[i]),
-                        to32(b[i]),
-                    ),
-                    topi.add(
-                        to32(a[i]),
-                        to32(b[i]),
-                    ),
-                )
-            ),
-        )
-        s = te.create_schedule(c.op)
-        func = tvm.driver.build_module.schedule_to_module(s, [a, b, c], 
"main", None)["main"]
-        return func.body
-
-    tvm.ir.assert_structural_equal(get_eliminated(), get_target())
-
-
-def test_legalize():
-    def to32(v):
-        uint32_v = topi.cast(v, "uint32")
-        uint32_v = tvm.tir.call_intrin(
-            "uint32", "tir.shift_left", uint32_v, tvm.tir.const(16, "uint32")
-        )
-        return tvm.tir.call_intrin("float32", "tir.reinterpret", uint32_v)
-
-    def to16(v):
-        uint32_v = tvm.tir.call_intrin("uint32", "tir.reinterpret", v)
-        rounding_bias = tvm.tir.call_intrin(
-            "uint32", "tir.shift_right", uint32_v, tvm.tir.const(16, "uint32")
-        )
-        rounding_bias = tvm.tir.call_intrin(
-            "uint32", "tir.bitwise_and", rounding_bias, tvm.tir.const(1, 
"uint32")
-        )
-        rounding_bias = rounding_bias + tvm.tir.const(0x7FFF, "uint16")
-        uint32_v = uint32_v + rounding_bias
-        uint32_v = tvm.tir.call_intrin(
-            "uint32", "tir.shift_right", uint32_v, tvm.tir.const(16, "uint32")
-        )
-        return topi.cast(uint32_v, "uint16")
-
-    def check(fcompute_before, fcompute_after):
-        a = te.placeholder((100,), dtype="bfloat16", name="A")
-        b = te.placeholder((100,), dtype="bfloat16", name="B")
-        c = te.compute((100,), fcompute_before(a, b), name="C")
-        s = te.create_schedule(c.op)
-        stmt = lower_stmt(s, [a, b, c], tvm.tir.transform.BF16Legalize)
-
-        a = te.placeholder((100,), dtype="uint16", name="A")
-        b = te.placeholder((100,), dtype="uint16", name="B")
-        c = te.compute((100,), fcompute_after(a, b), name="C")
-        s = te.create_schedule(c.op)
-        func = tvm.driver.build_module.schedule_to_module(s, [a, b, c], 
"main", None)["main"]
-        tvm.ir.assert_structural_equal(stmt, func.body)
-
-    def orig1(a, b):
-        return lambda i: a[i] + b[i] + a[99 - i] + b[99 - i]
-
-    def after1(a, b):
-        return lambda i: to16(to32(a[i]) + to32(b[i]) + to32(a[99 - i]) + 
to32(b[99 - i]))
-
-    def orig2(a, b):
-        return lambda i: a[i] * b[i] + a[99 - i] * b[99 - i] + a[i]
-
-    def after2(a, b):
-        return lambda i: to16(
-            to32(a[i]) * to32(b[i]) + to32(a[99 - i]) * to32(b[99 - i]) + 
to32(a[i])
-        )
-
-    check(orig1, after1)
-    check(orig2, after2)
+import tvm.script
+from tvm.script import tir as T
+
+
+def get_before():
+    @tvm.script.ir_module
+    class Before:
+        @T.prim_func
+        def main(
+            Aptr: T.handle("bfloat16"), Bptr: T.handle("bfloat16"), Dptr: 
T.handle("bfloat16")
+        ):
+            T.func_attr({"global_symbol": "main"})
+            A = T.decl_buffer((100,), "bfloat16", data=Aptr)
+            B = T.decl_buffer((100,), "bfloat16", data=Bptr)
+            D = T.decl_buffer((100,), "bfloat16", data=Dptr)
+            C = T.decl_buffer((100,), "bfloat16")
+            for i in T.grid(100):
+                C[i] = A[i] + B[i]
+                D[i] = T.exp(C[i])
+
+    return Before
+
+
+def u16tof32(v):
+    uint32_v = v.astype("uint32")
+    uint32_v = uint32_v << tvm.tir.const(16, "uint32")
+    return T.reinterpret("float32", uint32_v)
+
+
+def bf16tof32(v):
+    return u16tof32(T.reinterpret("uint16", v))
+
+
+def f32tou16(v):
+    uint32_v = T.reinterpret("uint32", v)
+    rounding_bias = (uint32_v >> tvm.tir.const(16, "uint32")) & 
tvm.tir.const(1, "uint32")
+    rounding_bias += tvm.tir.const(0x7FFF, "uint32")
+    uint32_v = uint32_v + rounding_bias
+    return uint32_v >> tvm.tir.const(16, "uint32")
+
+
+def f32tobf16(v):
+    uint32_v = f32tou16(v)
+    return T.reinterpret("bfloat16", uint32_v.astype("uint16"))
+
+
+def get_after_compute_legalize():
+    @tvm.script.ir_module
+    class After:
+        @T.prim_func
+        def main(
+            Aptr: T.handle("bfloat16"), Bptr: T.handle("bfloat16"), Dptr: 
T.handle("bfloat16")
+        ):
+            T.func_attr({"global_symbol": "main"})
+            A = T.decl_buffer((100,), "bfloat16", data=Aptr)
+            B = T.decl_buffer((100,), "bfloat16", data=Bptr)
+            D = T.decl_buffer((100,), "bfloat16", data=Dptr)
+            C = T.decl_buffer((100,), "float32")
+            for i in T.grid(100):
+                C[i] = bf16tof32(A[i]) + bf16tof32(B[i])
+                D[i] = f32tobf16(T.exp(C[i]))
+
+    return After
+
+
+def get_after_storage_legalize():
+    @tvm.script.ir_module
+    class After:
+        @T.prim_func
+        def main(Aptr: T.handle("uint16"), Bptr: T.handle("uint16"), Dptr: 
T.handle("uint16")):
+            T.func_attr({"global_symbol": "main"})
+            A = T.decl_buffer((100,), "uint16", data=Aptr)
+            B = T.decl_buffer((100,), "uint16", data=Bptr)
+            D = T.decl_buffer((100,), "uint16", data=Dptr)
+            C = T.decl_buffer((100,), "float32")
+            for i in T.grid(100):
+                C[i] = u16tof32(A[i]) + u16tof32(B[i])
+                D[i] = f32tou16(T.exp(C[i]))
+
+    return After
+
+
+def test_bf16_compute_legalize():
+    before = get_before()
+    expected = get_after_compute_legalize()
+    # run the transform twice to ensure we can afford to deal
+    # with this repetitive optimization
+    after = tvm.tir.transform.BF16ComputeLegalize()(before)
+    after = tvm.tir.transform.BF16ComputeLegalize()(after)
+
+    tvm.ir.assert_structural_equal(after, expected)
+
+
+def test_bf16_storage_legalize():
+    before = get_after_compute_legalize()
+    after = tvm.tir.transform.BF16StorageLegalize()(before)
+    expected = get_after_storage_legalize()
+    tvm.ir.assert_structural_equal(after, expected)
 
 
 if __name__ == "__main__":
-    test_promote()
-    test_eliminate()
-    test_legalize()
+    test_bf16_storage_legalize()

Reply via email to