This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new a0edf24c60 [TIR] Refactor BF16Legalize (#14405)
a0edf24c60 is described below
commit a0edf24c60bad81a6f4a4333fbf2b63255a37882
Author: Tianqi Chen <[email protected]>
AuthorDate: Mon Mar 27 22:35:20 2023 -0400
[TIR] Refactor BF16Legalize (#14405)
This PR refactors BF16Legalize to enable more f32 computations.
We also split the BF16Legalize into two steps.
- BF16ComputeLegalize changes all computation to f32 while keeping
the external BF16 storages.
- BF16StorageLegalize changes all storage to u16.
Now BF16 kernels accept tvm.nd.array that are created as bfloat16 type.
---
include/tvm/tir/transform.h | 10 +-
include/tvm/topi/elemwise.h | 6 +-
python/tvm/tir/transform/transform.py | 45 +-
src/driver/driver_api.cc | 3 +-
.../postproc/disallow_async_strided_mem_copy.cc | 2 +-
src/meta_schedule/postproc/verify_gpu_code.cc | 2 +-
src/target/codegen.cc | 1 -
src/target/llvm/codegen_llvm.cc | 4 +
src/target/llvm/llvm_module.cc | 1 -
src/tir/op/op.cc | 2 +
src/tir/transforms/arg_binder.cc | 2 +-
src/tir/transforms/bf16_legalize.cc | 696 ++++++++++++++-------
src/tir/transforms/storage_access.h | 1 -
tests/python/frontend/onnx/test_forward.py | 8 +
tests/python/unittest/test_target_codegen_llvm.py | 6 +-
.../unittest/test_tir_transform_bf16_legalize.py | 257 +++-----
16 files changed, 623 insertions(+), 423 deletions(-)
diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h
index 0aaa8b3e8a..d4f537ff31 100644
--- a/include/tvm/tir/transform.h
+++ b/include/tvm/tir/transform.h
@@ -337,11 +337,17 @@ TVM_DLL Pass CombineContextCall();
TVM_DLL Pass NarrowDataType(int target_bits);
/*!
- * \brief Legalize bf16 typed Ops. Add a cast to fp32
+ * \brief Legalize bf16 compute Ops. Add a cast to fp32
* before Ops, then add a cast back to bf16.
* \return The pass.
*/
-TVM_DLL Pass BF16Legalize();
+TVM_DLL Pass BF16ComputeLegalize();
+
+/*!
+ * \brief Legalize bf16 storage types to u16.
+ * \return The pass.
+ */
+TVM_DLL Pass BF16StorageLegalize();
/*!
* \brief Rewrite the pointer content type of arguments,
diff --git a/include/tvm/topi/elemwise.h b/include/tvm/topi/elemwise.h
index 49b50019f0..132992c57d 100644
--- a/include/tvm/topi/elemwise.h
+++ b/include/tvm/topi/elemwise.h
@@ -310,11 +310,7 @@ inline Tensor cast(const Tensor& x, DataType type,
std::string name = "T_cast",
inline Tensor reinterpret(const Tensor& x, DataType type, std::string name =
"tensor",
std::string tag = kElementWise) {
return compute(
- x->shape,
- [&](const Array<Var>& i) {
- return tvm::tir::Call(type, tvm::tir::builtin::reinterpret(), {x(i)});
- },
- name, tag);
+ x->shape, [&](const Array<Var>& i) { return reinterpret(type, x(i)); },
name, tag);
}
/*!
diff --git a/python/tvm/tir/transform/transform.py
b/python/tvm/tir/transform/transform.py
index a18d698e54..1df2ac76b5 100644
--- a/python/tvm/tir/transform/transform.py
+++ b/python/tvm/tir/transform/transform.py
@@ -286,59 +286,26 @@ def RemoveStoreUndef():
return _ffi_api.RemoveStoreUndef() # type: ignore
-def BF16Legalize():
- """Legalize bf16 typed Ops.
- Runs BF16Promote, BF16CastElimination and BF16TypeLowering
+def BF16ComputeLegalize():
+ """Legalize bf16 compute Ops.
Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
- return _ffi_api.BF16Legalize() # type: ignore
+ return _ffi_api.BF16ComputeLegalize() # type: ignore
-def BF16Promote():
- """Promote bf16 to fp32. Add a cast to fp32
- before Ops, then add a cast back to bf16.
+def BF16StorageLegalize():
+ """Legalize bf16 storage types to u16.
Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
- return _ffi_api.BF16Promote() # type: ignore
-
-
-def BF16CastElimination():
- """Eliminate verbose casting between fp32 and bf16
- Checks if the AST has the pattern:
- castto32(castto16(some_fp32_op(...)))
- The verbose casting is generated by BF16Promote for multiple
- bf16 Ops in a row. e.g.:
- X[i] + Y[i] + T[i] =>
- bf16((float32(bf16((float32(X[i]) + float32(Y[i])))) + float32(T[i])))
- After this pass:
- bf16(float32(X[i]) + float32(Y[i]) + float32(T[i]))
-
- Returns
- -------
- fpass : tvm.transform.Pass
- The result pass
- """
- return _ffi_api.BF16CastElimination() # type: ignore
-
-
-def BF16TypeLowering():
- """Replace all bf16 type with uint16. Also lower the casting
- between fp32 and bf16
-
- Returns
- -------
- fpass : tvm.transform.Pass
- The result pass
- """
- return _ffi_api.BF16TypeLowering() # type: ignore
+ return _ffi_api.BF16StorageLegalize() # type: ignore
def CommonSubexprElimTIR(enable_cse_tir: bool = True, identify_equiv_terms:
bool = False):
diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc
index 3458376848..569864a29e 100644
--- a/src/driver/driver_api.cc
+++ b/src/driver/driver_api.cc
@@ -218,7 +218,7 @@ Array<tvm::transform::Pass> CreatePassList(bool
disable_loop_partition) {
pass_list.push_back(tir::transform::InjectSoftwarePipeline());
pass_list.push_back(tir::transform::LowerOpaqueBlock());
pass_list.push_back(tir::transform::FlattenBuffer());
- pass_list.push_back(tir::transform::BF16Legalize());
+ pass_list.push_back(tir::transform::BF16ComputeLegalize());
pass_list.push_back(tir::transform::NarrowDataType(32));
pass_list.push_back(tir::transform::Simplify());
@@ -605,6 +605,7 @@ transform::Sequential MixedModulePassManager(IRModule
mixed_mod, Target target)
} else {
mixed_pass_list.push_back(tir::transform::MakePackedAPI());
}
+ mixed_pass_list.push_back(tir::transform::BF16StorageLegalize());
mixed_pass_list.push_back(tir::transform::SplitHostDevice());
return transform::Sequential(mixed_pass_list);
diff --git a/src/meta_schedule/postproc/disallow_async_strided_mem_copy.cc
b/src/meta_schedule/postproc/disallow_async_strided_mem_copy.cc
index 2d1507c899..d654e467f1 100644
--- a/src/meta_schedule/postproc/disallow_async_strided_mem_copy.cc
+++ b/src/meta_schedule/postproc/disallow_async_strided_mem_copy.cc
@@ -138,7 +138,7 @@ class DisallowAsyncStridedMemCopyNode : public PostprocNode
{
pass_list.push_back(tir::transform::InjectSoftwarePipeline());
pass_list.push_back(tir::transform::LowerOpaqueBlock());
pass_list.push_back(tir::transform::FlattenBuffer());
- pass_list.push_back(tir::transform::BF16Legalize());
+ pass_list.push_back(tir::transform::BF16ComputeLegalize());
pass_list.push_back(tir::transform::NarrowDataType(32));
pass_list.push_back(tir::transform::Simplify());
pass_list.push_back(tir::transform::InjectVirtualThread());
diff --git a/src/meta_schedule/postproc/verify_gpu_code.cc
b/src/meta_schedule/postproc/verify_gpu_code.cc
index 0013106b09..3240496afe 100644
--- a/src/meta_schedule/postproc/verify_gpu_code.cc
+++ b/src/meta_schedule/postproc/verify_gpu_code.cc
@@ -169,7 +169,7 @@ class VerifyGPUCodeNode : public PostprocNode {
pass_list.push_back(tir::transform::InjectSoftwarePipeline());
pass_list.push_back(tir::transform::LowerOpaqueBlock());
pass_list.push_back(tir::transform::FlattenBuffer());
- pass_list.push_back(tir::transform::BF16Legalize());
+ pass_list.push_back(tir::transform::BF16ComputeLegalize());
pass_list.push_back(tir::transform::NarrowDataType(32));
pass_list.push_back(tir::transform::Simplify());
// Phase 2
diff --git a/src/target/codegen.cc b/src/target/codegen.cc
index f6b694cb7c..24dbfebe55 100644
--- a/src/target/codegen.cc
+++ b/src/target/codegen.cc
@@ -46,7 +46,6 @@ runtime::Module Build(IRModule mod, Target target) {
.value()) {
mod = tir::transform::SkipAssert()(mod);
}
-
auto target_attr_map =
tvm::TargetKind::GetAttrMap<FTVMTIRToRuntime>("TIRToRuntime");
if (target_attr_map.count(target->kind)) {
return target_attr_map[target->kind](mod, target);
diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc
index 365beb5f5d..7c32f3cfa1 100644
--- a/src/target/llvm/codegen_llvm.cc
+++ b/src/target/llvm/codegen_llvm.cc
@@ -828,6 +828,10 @@ void CodeGenLLVM::CreateSerialFor(llvm::Value* begin,
llvm::Value* end, llvm::Va
llvm::Value* CodeGenLLVM::CreateCast(DataType from, DataType to, llvm::Value*
value) {
llvm::Type* target = DTypeToLLVMType(to);
if (value->getType() == target) return value;
+  // TODO(tvm-team): consider adding native support
+ ICHECK(!from.is_bfloat16()) << "BF16 needs to be storaged lowered first";
+ ICHECK(!to.is_bfloat16()) << "BF16 needs to be storaged lowered first";
+
if (to.is_handle()) {
return builder_->CreateBitCast(value, target);
} else if (to.is_uint() && to.bits() == 1) {
diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc
index 8749925781..50dcd7402a 100644
--- a/src/target/llvm/llvm_module.cc
+++ b/src/target/llvm/llvm_module.cc
@@ -325,7 +325,6 @@ void LLVMModuleNode::Init(const IRModule& mod, const
Target& target) {
if (tm->getTargetTriple().isOSDarwin()) {
module_->addModuleFlag(llvm::Module::Override, "Dwarf Version", 2);
}
-
std::string verify_errors_storage;
llvm::raw_string_ostream verify_errors(verify_errors_storage);
LOG_IF(FATAL, llvm::verifyModule(*module_, &verify_errors))
diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc
index 828ab01083..4439a9c3d7 100644
--- a/src/tir/op/op.cc
+++ b/src/tir/op/op.cc
@@ -324,6 +324,8 @@ PrimExpr cast(const DataType& t, PrimExpr value, Span span)
{
// reinterpret
PrimExpr reinterpret(const DataType& t, PrimExpr value, Span span) {
if (value.dtype() == t) return value;
+ ICHECK(value.dtype().bits() * value.dtype().lanes() == t.bits() * t.lanes())
+ << "Bitcast requires size match " << t << " vs " << value.dtype();
return tir::Call(t, tir::builtin::reinterpret(), {value}, span);
}
diff --git a/src/tir/transforms/arg_binder.cc b/src/tir/transforms/arg_binder.cc
index 2fc3bd2dca..c785b732ab 100644
--- a/src/tir/transforms/arg_binder.cc
+++ b/src/tir/transforms/arg_binder.cc
@@ -184,7 +184,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const
PrimExpr& device_type,
TVMArrayGet(DataType::UInt(16), handle,
builtin::kArrTypeLanes) ==
IntImm(DataType::UInt(16), buffer->dtype.lanes()));
if (!(buffer->dtype == DataType::Int(1) || buffer->dtype == DataType::Int(4)
||
- buffer->dtype == DataType::UInt(4) || buffer->dtype ==
DataType::UInt(16))) {
+ buffer->dtype == DataType::UInt(4))) {
auto type_msg = tvm::tir::StringImm(type_err_msg.str());
asserts_.emplace_back(AssertStmt(a_ndim == v_ndim, msg, nop));
asserts_.emplace_back(AssertStmt(cond, type_msg, nop));
diff --git a/src/tir/transforms/bf16_legalize.cc
b/src/tir/transforms/bf16_legalize.cc
index 3b89558622..99fad558cf 100644
--- a/src/tir/transforms/bf16_legalize.cc
+++ b/src/tir/transforms/bf16_legalize.cc
@@ -25,173 +25,198 @@
#include <tvm/runtime/registry.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
+#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <cmath>
#include <tuple>
-#include "../../arith/ir_mutator_with_analyzer.h"
-#include "../../arith/ir_visitor_with_analyzer.h"
-
namespace tvm {
namespace tir {
-using arith::Analyzer;
-using arith::IRMutatorWithAnalyzer;
-
-class BF16PromoteRewriter : public StmtExprMutator {
+// NOTE: do not touch buffer on function boundary
+// remap internal bf16 buffer to f32 if they meet the following condition
+// - constant allocation size
+// - do not have raw pointer access to the buffer
+//
+// populate the buffer_remap and var_remap accordingly.
+class BF16ComputeLegalizePlanner : public StmtExprVisitor {
public:
- BF16PromoteRewriter() {}
-
- Stmt operator()(Stmt s) { return VisitStmt(s); }
-
- PrimExpr VisitExpr_(const AddNode* op) final;
- PrimExpr VisitExpr_(const SubNode* op) final;
- PrimExpr VisitExpr_(const MulNode* op) final;
- PrimExpr VisitExpr_(const DivNode* op) final;
- PrimExpr VisitExpr_(const MinNode* op) final;
- PrimExpr VisitExpr_(const MaxNode* op) final;
- PrimExpr VisitExpr_(const LTNode* op) final;
- PrimExpr VisitExpr_(const LENode* op) final;
- PrimExpr VisitExpr_(const GTNode* op) final;
- PrimExpr VisitExpr_(const GENode* op) final;
- PrimExpr VisitExpr_(const CallNode* op) final;
-};
+ BF16ComputeLegalizePlanner(
+ std::unordered_map<Buffer, Buffer, ObjectPtrHash, ObjectPtrEqual>*
buffer_remap,
+ std::unordered_map<Var, Var, ObjectPtrHash, ObjectPtrEqual>* var_remap)
+ : buffer_remap_(buffer_remap), var_remap_(var_remap) {}
+
+ // run planning to populate buffer remap and var remap.
+ void Plan(PrimFunc func) {
+ this->VisitStmt(func->body);
+ // if there are opaque var access, then we cannot
+ // do remap of var and buffer, post-hoc remove these items.
+ for (Var var : opaque_var_access_) {
+ auto it = var_remap_->find(var);
+ if (it != var_remap_->end()) {
+ var_remap_->erase(it);
+ }
+ }
+ Array<Buffer> drop_buffers;
+ for (auto kv : *buffer_remap_) {
+ if (opaque_var_access_.count(kv.first->data)) {
+ drop_buffers.push_back(kv.first);
+ }
+ }
+ for (Buffer buffer : drop_buffers) {
+ auto it = buffer_remap_->find(buffer);
+ ICHECK(it != buffer_remap_->end());
+ buffer_remap_->erase(it);
+ }
+ }
-#define DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(OP, FUNC, NEEDCAST)
\
- PrimExpr BF16PromoteRewriter::VisitExpr_(const OP* op) {
\
- PrimExpr origin_a = this->VisitExpr(op->a);
\
- PrimExpr origin_b = this->VisitExpr(op->b);
\
- bool a_is_bfloat16 = origin_a->dtype.is_bfloat16();
\
- bool b_is_bfloat16 = origin_b->dtype.is_bfloat16();
\
- bool both_bfloat16 = a_is_bfloat16 && b_is_bfloat16;
\
- bool none_bfloat16 = !(a_is_bfloat16 || b_is_bfloat16);
\
- if (none_bfloat16) {
\
- return GetRef<PrimExpr>(op);
\
- }
\
- DataType float32_dtype(kDLFloat, 32, 1);
\
- PrimExpr float32_a = a_is_bfloat16 ? Cast(float32_dtype, origin_a) :
origin_a; \
- PrimExpr float32_b = b_is_bfloat16 ? Cast(float32_dtype, origin_b) :
origin_b; \
- PrimExpr result = FUNC(float32_a, float32_b);
\
- DataType bfloat16_dtype(kDLBfloat, 16, 1);
\
- bool do_cast = both_bfloat16 && NEEDCAST;
\
- return do_cast ? Cast(bfloat16_dtype, result) : result;
\
- }
-
-DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(AddNode, operator+, true)
-DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(SubNode, operator-, true)
-DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MulNode, operator*, true)
-DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(DivNode, div, true)
-DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MinNode, min, true)
-DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MaxNode, max, true)
-DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(LTNode, operator<, false)
-DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(LENode, operator<=, false)
-DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(GTNode, operator>, false)
-DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(GENode, operator>=, false)
-
-PrimExpr BF16PromoteRewriter::VisitExpr_(const CallNode* op) {
- Array<PrimExpr> args;
- for (auto& arg : op->args) {
- PrimExpr x = this->VisitExpr(arg);
- if (x.dtype().is_bfloat16()) {
- DataType fp32_dtype(kDLFloat, 32, x.dtype().lanes());
- args.push_back(Cast(fp32_dtype, {x}, op->span));
- } else {
- args.push_back(x);
+ void VisitStmt_(const AllocateNode* op) final {
+    // remap all intermediate constant buffer to fp32
+ if (op->dtype.is_bfloat16() && op->ConstantAllocationSize() != 0) {
+ DataType dtype = DataType::Float(32, op->dtype.lanes());
+ Var buffer_var = Var(op->buffer_var->name_hint,
PointerType(PrimType(dtype)));
+ (*var_remap_)[op->buffer_var] = buffer_var;
}
+ return StmtExprVisitor::VisitStmt_(op);
}
- if (op->dtype.is_bfloat16()) {
- DataType fp32_dtype(kDLFloat, 32, op->dtype.lanes());
- PrimExpr result_fp32 = Call(fp32_dtype, op->op, args, op->span);
- return Cast(op->dtype, {result_fp32}, op->span);
- } else {
- return Call(op->dtype, op->op, args, op->span);
+
+ void VisitStmt_(const BufferStoreNode* op) final {
+ StmtExprVisitor::VisitStmt_(op);
+ this->PopulateBufferRemap(op->buffer);
}
-}
-/*
- * Eliminate verbose casting between fp32 and bf16
- * Checks if the AST has the pattern:
- * castto32(castto16(some_fp32_op(...)))
- * The verbose casting is generated by BF16Promote for multiple
- * bf16 Ops in a row. e.g.:
- * X[i] + Y[i] + T[i] =>
- * bf16((float32(bf16((float32(X[i]) + float32(Y[i])))) + float32(T[i])))
- * After this pass:
- * bf16(float32(X[i]) + float32(Y[i]) + float32(T[i]))
- */
-class BF16CastEliminationRewriter : public StmtExprMutator {
- public:
- BF16CastEliminationRewriter() {}
+ void VisitExpr_(const BufferLoadNode* op) final {
+ StmtExprVisitor::VisitExpr_(op);
+ this->PopulateBufferRemap(op->buffer);
+ }
- Stmt operator()(Stmt s) { return VisitStmt(s); }
+ void VisitStmt_(const DeclBufferNode* op) final {
+ StmtExprVisitor::VisitStmt_(op);
+ this->PopulateBufferRemap(op->buffer);
+ }
- PrimExpr VisitExpr_(const CastNode* op) final {
- auto op_val = StmtExprMutator::VisitExpr(op->value);
- if (op->dtype.is_float() && op->dtype.bits() == 32) {
- // if is cast_to_fp32, check if op->value is cast_to_fp16
- // and op->value->value is a float32
- if (auto innercast = op_val.as<CastNode>()) {
- if (innercast->dtype.is_bfloat16() &&
innercast->value->dtype.is_float() &&
- innercast->value->dtype.bits() == 32) {
- return innercast->value;
- }
- }
+ void VisitExpr_(const VarNode* op) final {
+ StmtExprVisitor::VisitExpr_(op);
+ Var buffer_var = GetRef<Var>(op);
+ if (buffer_var.dtype().is_handle()) {
+ opaque_var_access_.insert(buffer_var);
}
- if (op->value.same_as(op_val)) return GetRef<PrimExpr>(op);
- return Cast(op->dtype, op_val);
}
-};
-union FloatCaster {
- uint32_t u32;
- float f32;
+ private:
+ void PopulateBufferRemap(Buffer buf) {
+ auto var_it = var_remap_->find(buf->data);
+ if (var_it == var_remap_->end()) return;
+
+ Buffer new_buffer(var_it->second, DataType::Float(32, buf->dtype.lanes()),
buf->shape,
+ buf->strides, buf->elem_offset, buf->name,
buf->data_alignment,
+ buf->offset_factor, buf->buffer_type,
buf->axis_separators, buf->span);
+ (*buffer_remap_)[buf] = new_buffer;
+ }
+
+ std::unordered_map<Buffer, Buffer, ObjectPtrHash, ObjectPtrEqual>*
buffer_remap_;
+ std::unordered_map<Var, Var, ObjectPtrHash, ObjectPtrEqual>* var_remap_;
+ std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> opaque_var_access_;
};
-uint16_t RoundToNearestEven(float src) {
- if (std::isnan(src)) {
- return UINT16_C(0x7FC0);
- } else {
- FloatCaster caster;
- caster.f32 = src;
- uint32_t rounding_bias = ((caster.u32 >> 16) & 1) + UINT32_C(0x7FFF);
- return static_cast<uint16_t>((caster.u32 + rounding_bias) >> 16);
+#define DEFINE_BIOP_EXPR_LEGALIZE(OP, FUNC) \
+ PrimExpr VisitExpr_(const OP* op) final { \
+ PrimExpr origin_a = PromoteBF16ToF32(this->VisitExpr(op->a)); \
+ PrimExpr origin_b = PromoteBF16ToF32(this->VisitExpr(op->b)); \
+ \
+ if (origin_a.same_as(op->a) && origin_b.same_as(op->b)) { \
+ return GetRef<PrimExpr>(op); \
+ } else { \
+ return FUNC(origin_a, origin_b); \
+ } \
}
-}
-/*
- * Lower the bf16 type to int16
- * Lower cast between bf16 and fp32
- * Lower bf16 FloatImm to int16
- */
-class BF16LowerRewriter : public StmtExprMutator {
+// NOTE: Legalize the BF16 computations
+// to floating point computations and only keeps the
+// bf16 storage which can further be legalized by BF16StorageLegalizer
+// BF16StorageLegalizer will be called at a much later time
+// point in the TIR lowering phases.
+class BF16ComputeLegalizer : public StmtExprMutator {
public:
- BF16LowerRewriter() {}
-
- using StmtExprMutator::operator();
+ PrimFunc Legalize(PrimFunc func) {
+ BF16ComputeLegalizePlanner planner(&buffer_remap_, &var_remap_);
+ planner.Plan(func);
+ auto* n = func.CopyOnWrite();
+ n->body = this->VisitStmt(std::move(n->body));
+ return func;
+ }
+ protected:
PrimExpr VisitExpr_(const CastNode* op) final {
- PrimExpr op_val = StmtExprMutator::VisitExpr(op->value);
- DataType uint32_dtype(kDLUInt, 32, op_val->dtype.lanes());
- DataType float32_dtype(kDLFloat, 32, op_val->dtype.lanes());
- if (op->value->dtype.is_bfloat16()) { // cast from bf16
- PrimExpr uint32_v = Cast(uint32_dtype, op_val);
- PrimExpr float32_v = Call(float32_dtype, builtin::reinterpret(),
{uint32_v << 16});
- bool is_to_float32 = op->dtype.is_float() && op->dtype.bits() == 32;
- return is_to_float32 ? float32_v : Cast(op->dtype, float32_v);
- } else if (op->dtype.is_bfloat16()) { // cast to bf16
- bool is_from_float32 = op->value->dtype.is_float() &&
op->value->dtype.bits() == 32;
- PrimExpr float32_v = is_from_float32 ? op_val : Cast(float32_dtype,
op_val);
- PrimExpr uint32_v = Call(uint32_dtype, builtin::reinterpret(),
{float32_v});
- DataType uint16_dtype(kDLUInt, 16, op_val->dtype.lanes());
- /* the following TIR is equivalent to the C++ code below:
- uint32_t rounding_bias = ((U32 >> 16) & 1) + UINT32_C(0x7FFF);
- return static_cast<uint16_t>((U32 + rounding_bias) >> 16);*/
- PrimExpr rounding_bias = ((uint32_v >> 16) & 1) +
make_const(uint16_dtype, 0x7FFF);
- return Cast(uint16_dtype, {(uint32_v + rounding_bias) >> 16});
- }
- if (op->value.same_as(op_val)) return GetRef<PrimExpr>(op);
- return Cast(op->dtype, op_val);
+ auto op_val = PromoteBF16ToF32(this->VisitExpr(op->value));
+
+ // all casts to BF16 becomes f32
+ if (op->dtype.is_bfloat16()) {
+ return cast(DataType::Float(32, op->dtype.lanes()), op_val);
+ }
+
+ if (op_val.same_as(op->value)) {
+ return GetRef<PrimExpr>(op);
+ } else {
+ return cast(op->dtype, op_val);
+ }
+ }
+
+ PrimExpr VisitExpr_(const SelectNode* op) final {
+ PrimExpr condition = this->VisitExpr(op->condition);
+ PrimExpr true_value = PromoteBF16ToF32(this->VisitExpr(op->true_value));
+ PrimExpr false_value = PromoteBF16ToF32(this->VisitExpr(op->false_value));
+ if (condition.same_as(op->condition) && true_value.same_as(op->true_value)
&&
+ false_value.same_as(op->false_value)) {
+ return GetRef<PrimExpr>(op);
+ } else {
+ return Select(condition, true_value, false_value);
+ }
+ }
+
+ PrimExpr VisitExpr_(const BroadcastNode* op) final {
+ PrimExpr value = PromoteBF16ToF32(this->VisitExpr(op->value));
+ if (value.same_as(op->value)) {
+ return GetRef<PrimExpr>(op);
+ } else {
+ return Broadcast(value, op->lanes);
+ }
+ }
+
+ PrimExpr VisitExpr_(const ShuffleNode* op) final {
+ auto fexpr = [this](const PrimExpr& e) { return
PromoteBF16ToF32(this->VisitExpr(e)); };
+ auto vectors = op->vectors.Map(fexpr);
+ if (vectors.same_as(op->vectors)) {
+ return GetRef<PrimExpr>(op);
+ } else {
+ return Shuffle(vectors, op->indices);
+ }
+ }
+
+ PrimExpr VisitExpr_(const CallNode* op) final {
+    // preserve reinterpret<bf16>() behavior.
+ if (op->op.same_as(builtin::reinterpret())) {
+ return StmtExprMutator::VisitExpr_(op);
+ }
+ // update normal computations to return f32 instead.
+ auto fmutate = [this](const PrimExpr& e) { return
PromoteBF16ToF32(this->VisitExpr(e)); };
+ Array<PrimExpr> args = op->args.Map(fmutate);
+ if (op->dtype.is_bfloat16()) {
+ return Call(DataType::Float(32, op->dtype.lanes()), op->op, args);
+ }
+ if (args.same_as(op->args)) {
+ return GetRef<PrimExpr>(op);
+ } else {
+ return Call(op->dtype, op->op, args);
+ }
+ }
+
+ PrimExpr VisitExpr_(const FloatImmNode* op) final {
+ if (op->dtype.is_bfloat16()) {
+ return FloatImm(DataType::Float(32), op->value);
+ }
+ return GetRef<PrimExpr>(op);
}
PrimExpr VisitExpr_(const VarNode* op) final {
@@ -205,26 +230,73 @@ class BF16LowerRewriter : public StmtExprMutator {
}
}
- Stmt VisitStmt_(const AllocateNode* op) final {
- if (op->dtype.is_bfloat16()) {
- DataType dtype = DataType::UInt(16, op->dtype.lanes());
- Var buffer_var = Var(op->buffer_var->name_hint,
PointerType(PrimType(dtype)));
- var_remap_[op->buffer_var] = buffer_var;
- return VisitStmt(Allocate(buffer_var, dtype, op->extents, op->condition,
op->body));
+ PrimExpr VisitExpr_(const LetNode* op) final {
+ PrimExpr value = PromoteBF16ToF32(op->value);
+ Var var = op->var;
+ if (value.dtype() != op->value.dtype()) {
+ var = op->var.copy_with_dtype(op->value.dtype());
+ var_remap_[op->var] = var;
+ }
+
+ PrimExpr body = VisitExpr(op->body);
+
+ if (value.same_as(op->value) && var.same_as(op->var) &&
body.same_as(op->body)) {
+ return GetRef<PrimExpr>(op);
} else {
- return StmtExprMutator::VisitStmt_(op);
+ return Let(var, value, body);
+ }
+ }
+
+ DEFINE_BIOP_EXPR_LEGALIZE(AddNode, operator+);
+ DEFINE_BIOP_EXPR_LEGALIZE(SubNode, operator-);
+ DEFINE_BIOP_EXPR_LEGALIZE(MulNode, operator*);
+ DEFINE_BIOP_EXPR_LEGALIZE(DivNode, div);
+ DEFINE_BIOP_EXPR_LEGALIZE(MinNode, min);
+ DEFINE_BIOP_EXPR_LEGALIZE(MaxNode, max);
+ DEFINE_BIOP_EXPR_LEGALIZE(LTNode, operator<); // NOLINT(*)
+ DEFINE_BIOP_EXPR_LEGALIZE(LENode, operator<=);
+ DEFINE_BIOP_EXPR_LEGALIZE(GTNode, operator>); // NOLINT(*)
+ DEFINE_BIOP_EXPR_LEGALIZE(GENode, operator>=);
+ DEFINE_BIOP_EXPR_LEGALIZE(EQNode, operator==);
+ DEFINE_BIOP_EXPR_LEGALIZE(NENode, operator!=);
+
+ Stmt VisitStmt_(const LetStmtNode* op) final {
+ PrimExpr value = PromoteBF16ToF32(op->value);
+ Var var = op->var;
+ if (value.dtype() != op->value.dtype()) {
+ var = op->var.copy_with_dtype(op->value.dtype());
+ var_remap_[op->var] = var;
+ }
+ Stmt body = VisitStmt(op->body);
+
+ if (value.same_as(op->value) && var.same_as(op->var) &&
body.same_as(op->body)) {
+ return GetRef<Stmt>(op);
+ } else {
+ return LetStmt(var, value, body);
}
}
Stmt VisitStmt_(const BufferStoreNode* op) final {
- Stmt ret = StmtExprMutator::VisitStmt_(op);
- op = ret.as<BufferStoreNode>();
+ PrimExpr value = this->VisitExpr(op->value);
+ auto fmutate = [this](const PrimExpr& e) { return this->VisitExpr(e); };
+
+ Array<PrimExpr> indices = op->indices.Map(fmutate);
Buffer new_buf = GetRemappedBuffer(op->buffer);
- if (new_buf.same_as(op->buffer)) {
- return ret;
+
+ if (value.same_as(op->value) && indices.same_as(op->indices) &&
new_buf.same_as(op->buffer)) {
+ return GetRef<Stmt>(op);
} else {
- return BufferStore(new_buf, op->value, op->indices);
+ if (new_buf->dtype.is_bfloat16()) {
+ value = CastF32ToBF16(value);
+ }
+ if (value.dtype() != new_buf->dtype) {
+ // this happens when buffer get rewritten to f32
+ // but values remain as bf16
+ ICHECK(value.dtype().is_bfloat16());
+ value = cast(new_buf->dtype.with_lanes(value.dtype().lanes()), value);
+ }
+ return BufferStore(new_buf, value, indices);
}
}
@@ -258,6 +330,35 @@ class BF16LowerRewriter : public StmtExprMutator {
}
}
+ Stmt VisitStmt_(const DeclBufferNode* op) final {
+ Stmt ret = StmtExprMutator::VisitStmt_(op);
+ op = ret.as<DeclBufferNode>();
+
+ Buffer new_buf = GetRemappedBuffer(op->buffer);
+ if (new_buf.same_as(op->buffer)) {
+ return ret;
+ } else {
+ return DeclBuffer(new_buf, op->body);
+ }
+ }
+
+ Stmt VisitStmt_(const AllocateNode* op) final {
+ Stmt ret = StmtExprMutator::VisitStmt_(op);
+ op = ret.as<AllocateNode>();
+
+ auto it = var_remap_.find(op->buffer_var);
+ if (it != var_remap_.end()) {
+ Var remapped_var = it->second;
+ auto* ptr = remapped_var->type_annotation.as<PointerTypeNode>();
+ ICHECK(ptr);
+ auto* prim_type = ptr->element_type.as<PrimTypeNode>();
+ ICHECK(prim_type);
+ return Allocate(remapped_var, prim_type->dtype, op->extents,
op->condition, op->body);
+ } else {
+ return ret;
+ }
+ }
+
PrimExpr VisitExpr_(const BufferLoadNode* op) final {
PrimExpr ret = StmtExprMutator::VisitExpr_(op);
op = ret.as<BufferLoadNode>();
@@ -270,48 +371,244 @@ class BF16LowerRewriter : public StmtExprMutator {
}
}
- PrimExpr VisitExpr_(const FloatImmNode* op) final {
+ private:
+ /*!
+ * \brief promote BF16 to F32 and keep other values unchanged.
+ * \param value The input value.
+ * \return The converted value.
+ */
+ PrimExpr PromoteBF16ToF32(PrimExpr value) {
+ if (!value.dtype().is_bfloat16()) return value;
+ if (const CastNode* cast = value.as<CastNode>()) {
+ if (cast->value.dtype() == DataType::Float(32)) return cast->value;
+ }
+ DataType f32 = DataType::Float(32, value.dtype().lanes());
+ DataType u16 = DataType::UInt(16, value.dtype().lanes());
+ DataType u32 = DataType::UInt(32, value.dtype().lanes());
+ // reinterpret<f32>((cast<u32>(reinterpret<u16>(bf16_value)) << 16))
+ return reinterpret(f32, cast(u32, reinterpret(u16, value)) << 16);
+ }
+
+ /*!
+ * \brief Cast value to F32 to BF16 and keep other values unchanged.
+ * \param value The input value
+ * \return The converted value.
+ */
+ PrimExpr CastF32ToBF16(PrimExpr value) {
+ if (!value.dtype().is_float()) return value;
+ ICHECK_EQ(value.dtype().bits(), 32);
+ DataType bf16 = DataType::BFloat(16, value.dtype().lanes());
+ DataType u16 = DataType::UInt(16, value.dtype().lanes());
+ DataType u32 = DataType::UInt(32, value.dtype().lanes());
+ PrimExpr u32_val = reinterpret(u32, value);
+
+ if (round_to_even_) {
+ PrimExpr rounding_bias = ((u32_val >> 16) & 1) + make_const(u32, 0x7FFF);
+ u32_val = u32_val + rounding_bias;
+ }
+ // reinterpret<bf16>((cast<u16>(reinterpret<u32>(f32_value)) >> 16))
+ return reinterpret(bf16, cast(u16, u32_val >> 16));
+ }
+
+ Buffer GetRemappedBuffer(Buffer buf) {
+ auto buf_it = buffer_remap_.find(buf);
+ if (buf_it != buffer_remap_.end()) {
+ return buf_it->second;
+ }
+ return buf;
+ }
+
+ bool round_to_even_{true};
+
+ std::unordered_map<Buffer, Buffer, ObjectPtrHash, ObjectPtrEqual>
buffer_remap_;
+ std::unordered_map<Var, Var, ObjectPtrHash, ObjectPtrEqual> var_remap_;
+};
+
+/*!
+ * \brief This Pass legalizes remaining BF16 storages to u16
+ *
+ * This pass needs to happens after BF16ComputeLegalizer and serves
+ * as a way to support BF16 on platforms that do not have native support.
+ */
+class BF16StorageLegalizer : public StmtExprMutator {
+ public:
+ PrimFunc Legalize(PrimFunc func) {
+ ICHECK_EQ(func->buffer_map.size(), 0) << "This pass must be called after
MakePackedAPI";
+ auto* n = func.CopyOnWrite();
+ n->params = n->params.Map([this](Var var) { return this->RemapVarDef(var);
});
+ n->body = this->VisitStmt(std::move(n->body));
+ return func;
+ }
+
+ private:
+ PrimExpr VisitExpr_(const VarNode* op) final {
+ Var var = GetRef<Var>(op);
+ auto itr = var_remap_.find(var);
+ if (itr != var_remap_.end()) {
+ return itr->second;
+ } else {
+ return std::move(var);
+ }
+ }
+
+ Stmt VisitStmt_(const AllocateNode* op) final {
if (op->dtype.is_bfloat16()) {
- return IntImm(DataType::UInt(16, op->dtype.lanes()),
- RoundToNearestEven(static_cast<float>(op->value)));
+ DataType dtype = DataType::UInt(16, op->dtype.lanes());
+ Var buffer_var = Var(op->buffer_var->name_hint,
PointerType(PrimType(dtype)));
+ var_remap_[op->buffer_var] = buffer_var;
+ return VisitStmt(Allocate(buffer_var, dtype, op->extents, op->condition,
op->body));
+ } else {
+ return StmtExprMutator::VisitStmt_(op);
}
- return StmtExprMutator::VisitExpr_(op);
}
- void AlterBuffers(PrimFuncNode* op) {
- Map<Var, Buffer> new_buffer_map;
-
- for (auto& itr : op->buffer_map) {
- auto param_var = itr.first;
- auto oldbuf = itr.second;
- if (oldbuf->dtype.is_bfloat16()) {
- DataType dtype = DataType::UInt(16, oldbuf->dtype.lanes());
- Var buffer_var = Var(oldbuf->data->name_hint,
PointerType(PrimType(dtype)));
- auto newbuf = Buffer(buffer_var, dtype, oldbuf->shape,
oldbuf->strides, oldbuf->elem_offset,
- oldbuf->name, oldbuf->data_alignment,
oldbuf->offset_factor,
- oldbuf->buffer_type);
- buffer_remap_[oldbuf] = newbuf;
- var_remap_[oldbuf->data] = buffer_var;
- new_buffer_map.Set(param_var, newbuf);
- } else {
- new_buffer_map.Set(param_var, oldbuf);
+ Stmt VisitStmt_(const DeclBufferNode* op) final {
+ Buffer buf = GetRemappedBuffer(op->buffer);
+ // in a rare case the buffer didn't get remapped
+ // because the original var is not bfloat*
+ // force remap here
+ if (buf->dtype.is_bfloat16()) {
+ buf = Buffer(buf->data, DataType::UInt(16, buf->dtype.lanes()),
buf->shape, buf->strides,
+ buf->elem_offset, buf->name, buf->data_alignment,
buf->offset_factor,
+ buf->buffer_type, buf->axis_separators, buf->span);
+ buffer_remap_[op->buffer] = buf;
+ }
+ Stmt body = VisitStmt(op->body);
+ if (buf.same_as(op->buffer) && body.same_as(op->body)) {
+ return GetRef<Stmt>(op);
+ } else {
+ return DeclBuffer(buf, body, op->span);
+ }
+ }
+
+ PrimExpr VisitExpr_(const LetNode* op) final {
+ PrimExpr value = VisitExpr(op->value);
+ Var var = RemapVarDef(op->var);
+ PrimExpr body = VisitExpr(op->body);
+
+ if (value.same_as(op->value) && var.same_as(op->var) &&
body.same_as(op->body)) {
+ return GetRef<PrimExpr>(op);
+ } else {
+ return Let(var, value, body);
+ }
+ }
+
+ Stmt VisitStmt_(const LetStmtNode* op) final {
+ PrimExpr value = VisitExpr(op->value);
+ Var var = RemapVarDef(op->var);
+ Stmt body = VisitStmt(op->body);
+
+ if (value.same_as(op->value) && var.same_as(op->var) &&
body.same_as(op->body)) {
+ return GetRef<Stmt>(op);
+ } else {
+ return LetStmt(var, value, body);
+ }
+ }
+
+ Stmt VisitStmt_(const BufferStoreNode* op) final {
+ PrimExpr value = this->ChangeBF16ToU16(VisitExpr(op->value));
+ Buffer new_buf = GetRemappedBuffer(op->buffer);
+ auto indices = op->indices.Map([this](PrimExpr expr) { return
this->VisitExpr(expr); });
+ if (new_buf.same_as(op->buffer) && indices.same_as(op->indices) &&
value.same_as(op->value)) {
+ return GetRef<Stmt>(op);
+ } else {
+ if (op->value.dtype().is_bfloat16()) {
+ ICHECK(new_buf->dtype.is_uint());
+ }
+ return BufferStore(new_buf, value, indices);
+ }
+ }
+
+ Stmt VisitStmt_(const AttrStmtNode* op) final {
+ Stmt ret = StmtExprMutator::VisitStmt_(op);
+ op = ret.as<AttrStmtNode>();
+
+ if (auto* buffer = op->node.as<BufferNode>()) {
+ auto it = buffer_remap_.find(GetRef<Buffer>(buffer));
+ if (it != buffer_remap_.end()) {
+ return AttrStmt(it->second, op->attr_key, op->value, op->body);
+ }
+ } else if (auto* var = op->node.as<VarNode>()) {
+ auto it = var_remap_.find(GetRef<Var>(var));
+ if (it != var_remap_.end()) {
+ return AttrStmt(it->second, op->attr_key, op->value, op->body);
}
}
+ return ret;
+ }
- if (buffer_remap_.size() != 0) {
- op->buffer_map = new_buffer_map;
+ Stmt VisitStmt_(const BufferRealizeNode* op) final {
+ LOG(FATAL) << "Do not expect buffer realize";
+ }
+
+ PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+ PrimExpr ret = StmtExprMutator::VisitExpr_(op);
+ op = ret.as<BufferLoadNode>();
+ Buffer new_buf = GetRemappedBuffer(op->buffer);
+ if (new_buf.same_as(op->buffer)) {
+ return ret;
+ } else {
+ return BufferLoad(new_buf, op->indices);
}
}
+ PrimExpr VisitExpr_(const CallNode* op) final {
+  // remap reinterpret so unnecessary reinterpret can be skipped.
+ if (op->op.same_as(builtin::reinterpret())) {
+ PrimExpr value = VisitExpr(op->args[0]);
+ // sometimes the input dtype can change and we can skip.
+ if (value.dtype() == op->dtype) return value;
+ if (op->dtype.is_bfloat16()) {
+ return reinterpret(DataType::UInt(16, op->dtype.lanes()), value);
+ }
+ if (op->args[0].same_as(value)) {
+ return GetRef<PrimExpr>(op);
+ } else {
+ return reinterpret(op->dtype, value);
+ }
+ }
+ return StmtExprMutator::VisitExpr_(op);
+ }
+
private:
+ /*!
+ * \brief Change BF16 value to U16 value.
+ * \param value The input value.
+ * \return The converted value.
+ */
+ PrimExpr ChangeBF16ToU16(PrimExpr value) {
+ if (!value.dtype().is_bfloat16()) return value;
+ auto* call = value.as<CallNode>();
+ if (call && call->op.same_as(builtin::reinterpret())) {
+ return reinterpret(DataType::UInt(16, value.dtype().lanes()),
call->args[0]);
+ } else {
+ return value;
+ }
+ }
+
+ Var RemapVarDef(Var var) {
+ // remap the var
+ if (var.dtype().is_handle()) {
+ if (auto* ptr_type = var->type_annotation.as<PointerTypeNode>()) {
+ if (auto* elem_type = ptr_type->element_type.as<PrimTypeNode>()) {
+ if (elem_type->dtype.is_bfloat16()) {
+ Var new_var = Var(var->name_hint,
+ PointerType(PrimType(DataType::UInt(16,
elem_type->dtype.lanes()))));
+ var_remap_[var] = new_var;
+ return new_var;
+ }
+ }
+ }
+ }
+ return var;
+ }
+
Buffer GetRemappedBuffer(Buffer buf) {
auto buf_it = buffer_remap_.find(buf);
if (buf_it != buffer_remap_.end()) {
return buf_it->second;
}
-
Buffer new_buf = buf;
-
auto var_it = var_remap_.find(buf->data);
if (var_it != var_remap_.end()) {
DataType dtype =
@@ -319,6 +616,8 @@ class BF16LowerRewriter : public StmtExprMutator {
new_buf = Buffer(var_it->second, dtype, buf->shape, buf->strides,
buf->elem_offset, buf->name,
buf->data_alignment, buf->offset_factor,
buf->buffer_type,
buf->axis_separators, buf->span);
+ } else {
+ ICHECK(!buf->dtype.is_bfloat16()) << "Cannot find var remap for " << buf;
}
buffer_remap_[buf] = new_buf;
@@ -332,46 +631,25 @@ class BF16LowerRewriter : public StmtExprMutator {
namespace transform {
-Pass BF16Promote() {
+Pass BF16ComputeLegalize() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
- auto* n = f.CopyOnWrite();
- n->body = BF16PromoteRewriter()(std::move(n->body));
- return f;
+ // TODO(tvm-team): skip if the target supports bf16
+ return BF16ComputeLegalizer().Legalize(f);
};
- return CreatePrimFuncPass(pass_func, 0, "tir.BF16Promote", {});
+ return CreatePrimFuncPass(pass_func, 0, "tir.BF16ComputeLegalize", {});
}
-TVM_REGISTER_GLOBAL("tir.transform.BF16Promote").set_body_typed(BF16Promote);
+TVM_REGISTER_GLOBAL("tir.transform.BF16ComputeLegalize").set_body_typed(BF16ComputeLegalize);
-Pass BF16CastElimination() {
+Pass BF16StorageLegalize() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
- auto* n = f.CopyOnWrite();
- n->body = BF16CastEliminationRewriter()(std::move(n->body));
- return f;
+ // TODO(tvm-team): skip if the target supports bf16
+ return BF16StorageLegalizer().Legalize(f);
};
- return CreatePrimFuncPass(pass_func, 0, "tir.BF16CastElimination", {});
-}
-
-TVM_REGISTER_GLOBAL("tir.transform.BF16CastElimination").set_body_typed(BF16CastElimination);
-
-Pass BF16TypeLowering() {
- auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
- auto* n = f.CopyOnWrite();
- BF16LowerRewriter lowerer;
- lowerer.AlterBuffers(n);
- n->body = lowerer(std::move(n->body));
- return f;
- };
- return CreatePrimFuncPass(pass_func, 0, "tir.BF16TypeLowering", {});
-}
-
-TVM_REGISTER_GLOBAL("tir.transform.BF16TypeLowering").set_body_typed(BF16TypeLowering);
-
-Pass BF16Legalize() {
- return Sequential({BF16Promote(), BF16CastElimination(),
BF16TypeLowering()}, "tir.BF16Legalize");
+ return CreatePrimFuncPass(pass_func, 0, "tir.BF16StorageLegalize", {});
}
-TVM_REGISTER_GLOBAL("tir.transform.BF16Legalize").set_body_typed(BF16Legalize);
+TVM_REGISTER_GLOBAL("tir.transform.BF16StorageLegalize").set_body_typed(BF16StorageLegalize);
} // namespace transform
} // namespace tir
diff --git a/src/tir/transforms/storage_access.h
b/src/tir/transforms/storage_access.h
index 8fac0c302a..119e595f59 100644
--- a/src/tir/transforms/storage_access.h
+++ b/src/tir/transforms/storage_access.h
@@ -139,7 +139,6 @@ class StorageAccessVisitor : public StmtExprVisitor {
// The involving threads
Array<IterVar> env_threads_;
};
-
} // namespace tir
} // namespace tvm
#endif // TVM_TIR_TRANSFORMS_STORAGE_ACCESS_H_
diff --git a/tests/python/frontend/onnx/test_forward.py
b/tests/python/frontend/onnx/test_forward.py
index 116c023caa..34dda6cf6f 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -99,6 +99,14 @@ def get_tvm_output_with_vm(
freeze_params=freeze_params,
convert_config=convert_config,
)
+    # handle bfloat16 parameters by explicitly allocating
+    # bfloat16 arrays for the corresponding inputs
+ for i, param in enumerate(mod["main"].params):
+ if param.type_annotation.dtype == "bfloat16":
+ input_data[i] = tvm.nd.empty(input_data[i].shape,
"bfloat16").copyfrom(
+ input_data[i]
+ )
+
if validate_structural_equal:
with tvm.testing.enable_span_filling():
mod_with_span, _ = relay.frontend.from_onnx(
diff --git a/tests/python/unittest/test_target_codegen_llvm.py
b/tests/python/unittest/test_target_codegen_llvm.py
index 44f950c82a..3190115aa6 100644
--- a/tests/python/unittest/test_target_codegen_llvm.py
+++ b/tests/python/unittest/test_target_codegen_llvm.py
@@ -707,7 +707,7 @@ def np_float2tvm_bf16(arr):
"""Convert a numpy array of float to a TVM array
of bf16"""
nparr = np_float2np_bf16(arr)
- return tvm.nd.empty(nparr.shape, "uint16").copyfrom(nparr)
+ return tvm.nd.empty(nparr.shape, "bfloat16").copyfrom(nparr)
def np_bf162np_float(arr):
@@ -730,9 +730,9 @@ def test_llvm_bf16():
B = te.placeholder((32,), dtype="bfloat16")
d = te.compute((32,), lambda x: A[x] + B[x])
sch = te.create_schedule(d.op)
- print(tvm.lower(sch, [A, B, d]))
if do_vectorize:
sch[d].vectorize(d.op.axis[0])
+
module = tvm.build(sch, [A, B, d])
npa = np.random.rand(32).astype("float32")
npb = np.random.rand(32).astype("float32")
@@ -741,7 +741,7 @@ def test_llvm_bf16():
res = np_bf16_cast_and_cast_back(va + vb)
a_ = np_float2tvm_bf16(npa)
b_ = np_float2tvm_bf16(npb)
- c_ = tvm.nd.empty((32,), "uint16")
+ c_ = tvm.nd.empty((32,), "bfloat16")
module(a_, b_, c_)
tvm.testing.assert_allclose(np_bf162np_float(c_.numpy()), res)
diff --git a/tests/python/unittest/test_tir_transform_bf16_legalize.py
b/tests/python/unittest/test_tir_transform_bf16_legalize.py
index 1e3c8061e0..ababfd489a 100644
--- a/tests/python/unittest/test_tir_transform_bf16_legalize.py
+++ b/tests/python/unittest/test_tir_transform_bf16_legalize.py
@@ -15,164 +15,105 @@
# specific language governing permissions and limitations
# under the License.
import tvm
-from tvm import topi
-from tvm import te
-
-
-def lower_stmt(sche, params, passfunc):
- func = tvm.driver.build_module.schedule_to_module(sche, params, "main",
None)["main"]
- func = passfunc()(tvm.IRModule.from_expr(func))["main"]
- stmt = func.body
- return stmt
-
-
-def test_promote():
- def runpass(op, passfunc):
- a = te.placeholder((100,), dtype="bfloat16")
- b = te.placeholder((100,), dtype="bfloat16")
- c = te.compute((100,), lambda i: op(a[i], b[i]))
- s = te.create_schedule(c.op)
- return lower_stmt(s, [a, b, c], passfunc)
-
- def get_promoted(op):
- a = te.placeholder((100,), dtype="bfloat16")
- b = te.placeholder((100,), dtype="bfloat16")
- c = te.compute(
- (100,),
- lambda i: topi.cast(op(topi.cast(a[i], "float"), topi.cast(b[i],
"float")), "bfloat16"),
- )
- s = te.create_schedule(c.op)
- func = tvm.driver.build_module.schedule_to_module(s, [a, b, c],
"main", None)["main"]
- return func.body
-
- def test_promoted(op):
- stmt = runpass(op, tvm.tir.transform.BF16Promote)
- tvm.ir.assert_structural_equal(stmt, get_promoted(op))
-
- test_promoted(topi.add)
- test_promoted(topi.subtract)
- test_promoted(topi.multiply)
- test_promoted(topi.divide)
-
-
-def test_eliminate():
- def to32(v):
- return topi.cast(v, "float")
-
- def to16(v):
- return topi.cast(v, "bfloat16")
-
- def get_eliminated():
- a = te.placeholder((100,), dtype="bfloat16")
- b = te.placeholder((100,), dtype="bfloat16")
- c = te.compute(
- (100,),
- lambda i: to16(
- topi.add(
- to32(
- to16(
- topi.add(
- to32(a[i]),
- to32(b[i]),
- )
- )
- ),
- to32(
- to16(
- topi.add(
- to32(a[i]),
- to32(b[i]),
- )
- )
- ),
- )
- ),
- )
- s = te.create_schedule(c.op)
- stmt = lower_stmt(s, [a, b, c], tvm.tir.transform.BF16CastElimination)
- return stmt
-
- def get_target():
- a = te.placeholder((100,), dtype="bfloat16")
- b = te.placeholder((100,), dtype="bfloat16")
- c = te.compute(
- (100,),
- lambda i: to16(
- topi.add(
- topi.add(
- to32(a[i]),
- to32(b[i]),
- ),
- topi.add(
- to32(a[i]),
- to32(b[i]),
- ),
- )
- ),
- )
- s = te.create_schedule(c.op)
- func = tvm.driver.build_module.schedule_to_module(s, [a, b, c],
"main", None)["main"]
- return func.body
-
- tvm.ir.assert_structural_equal(get_eliminated(), get_target())
-
-
-def test_legalize():
- def to32(v):
- uint32_v = topi.cast(v, "uint32")
- uint32_v = tvm.tir.call_intrin(
- "uint32", "tir.shift_left", uint32_v, tvm.tir.const(16, "uint32")
- )
- return tvm.tir.call_intrin("float32", "tir.reinterpret", uint32_v)
-
- def to16(v):
- uint32_v = tvm.tir.call_intrin("uint32", "tir.reinterpret", v)
- rounding_bias = tvm.tir.call_intrin(
- "uint32", "tir.shift_right", uint32_v, tvm.tir.const(16, "uint32")
- )
- rounding_bias = tvm.tir.call_intrin(
- "uint32", "tir.bitwise_and", rounding_bias, tvm.tir.const(1,
"uint32")
- )
- rounding_bias = rounding_bias + tvm.tir.const(0x7FFF, "uint16")
- uint32_v = uint32_v + rounding_bias
- uint32_v = tvm.tir.call_intrin(
- "uint32", "tir.shift_right", uint32_v, tvm.tir.const(16, "uint32")
- )
- return topi.cast(uint32_v, "uint16")
-
- def check(fcompute_before, fcompute_after):
- a = te.placeholder((100,), dtype="bfloat16", name="A")
- b = te.placeholder((100,), dtype="bfloat16", name="B")
- c = te.compute((100,), fcompute_before(a, b), name="C")
- s = te.create_schedule(c.op)
- stmt = lower_stmt(s, [a, b, c], tvm.tir.transform.BF16Legalize)
-
- a = te.placeholder((100,), dtype="uint16", name="A")
- b = te.placeholder((100,), dtype="uint16", name="B")
- c = te.compute((100,), fcompute_after(a, b), name="C")
- s = te.create_schedule(c.op)
- func = tvm.driver.build_module.schedule_to_module(s, [a, b, c],
"main", None)["main"]
- tvm.ir.assert_structural_equal(stmt, func.body)
-
- def orig1(a, b):
- return lambda i: a[i] + b[i] + a[99 - i] + b[99 - i]
-
- def after1(a, b):
- return lambda i: to16(to32(a[i]) + to32(b[i]) + to32(a[99 - i]) +
to32(b[99 - i]))
-
- def orig2(a, b):
- return lambda i: a[i] * b[i] + a[99 - i] * b[99 - i] + a[i]
-
- def after2(a, b):
- return lambda i: to16(
- to32(a[i]) * to32(b[i]) + to32(a[99 - i]) * to32(b[99 - i]) +
to32(a[i])
- )
-
- check(orig1, after1)
- check(orig2, after2)
+import tvm.script
+from tvm.script import tir as T
+
+
+def get_before():
+ @tvm.script.ir_module
+ class Before:
+ @T.prim_func
+ def main(
+ Aptr: T.handle("bfloat16"), Bptr: T.handle("bfloat16"), Dptr:
T.handle("bfloat16")
+ ):
+ T.func_attr({"global_symbol": "main"})
+ A = T.decl_buffer((100,), "bfloat16", data=Aptr)
+ B = T.decl_buffer((100,), "bfloat16", data=Bptr)
+ D = T.decl_buffer((100,), "bfloat16", data=Dptr)
+ C = T.decl_buffer((100,), "bfloat16")
+ for i in T.grid(100):
+ C[i] = A[i] + B[i]
+ D[i] = T.exp(C[i])
+
+ return Before
+
+
+def u16tof32(v):
+ uint32_v = v.astype("uint32")
+ uint32_v = uint32_v << tvm.tir.const(16, "uint32")
+ return T.reinterpret("float32", uint32_v)
+
+
+def bf16tof32(v):
+ return u16tof32(T.reinterpret("uint16", v))
+
+
+def f32tou16(v):
+ uint32_v = T.reinterpret("uint32", v)
+ rounding_bias = (uint32_v >> tvm.tir.const(16, "uint32")) &
tvm.tir.const(1, "uint32")
+ rounding_bias += tvm.tir.const(0x7FFF, "uint32")
+ uint32_v = uint32_v + rounding_bias
+ return uint32_v >> tvm.tir.const(16, "uint32")
+
+
+def f32tobf16(v):
+ uint32_v = f32tou16(v)
+ return T.reinterpret("bfloat16", uint32_v.astype("uint16"))
+
+
+def get_after_compute_legalize():
+ @tvm.script.ir_module
+ class After:
+ @T.prim_func
+ def main(
+ Aptr: T.handle("bfloat16"), Bptr: T.handle("bfloat16"), Dptr:
T.handle("bfloat16")
+ ):
+ T.func_attr({"global_symbol": "main"})
+ A = T.decl_buffer((100,), "bfloat16", data=Aptr)
+ B = T.decl_buffer((100,), "bfloat16", data=Bptr)
+ D = T.decl_buffer((100,), "bfloat16", data=Dptr)
+ C = T.decl_buffer((100,), "float32")
+ for i in T.grid(100):
+ C[i] = bf16tof32(A[i]) + bf16tof32(B[i])
+ D[i] = f32tobf16(T.exp(C[i]))
+
+ return After
+
+
+def get_after_storage_legalize():
+ @tvm.script.ir_module
+ class After:
+ @T.prim_func
+ def main(Aptr: T.handle("uint16"), Bptr: T.handle("uint16"), Dptr:
T.handle("uint16")):
+ T.func_attr({"global_symbol": "main"})
+ A = T.decl_buffer((100,), "uint16", data=Aptr)
+ B = T.decl_buffer((100,), "uint16", data=Bptr)
+ D = T.decl_buffer((100,), "uint16", data=Dptr)
+ C = T.decl_buffer((100,), "float32")
+ for i in T.grid(100):
+ C[i] = u16tof32(A[i]) + u16tof32(B[i])
+ D[i] = f32tou16(T.exp(C[i]))
+
+ return After
+
+
+def test_bf16_compute_legalize():
+ before = get_before()
+ expected = get_after_compute_legalize()
+    # run the transform twice to ensure the pass is idempotent
+    # when applied repeatedly
+ after = tvm.tir.transform.BF16ComputeLegalize()(before)
+ after = tvm.tir.transform.BF16ComputeLegalize()(after)
+
+ tvm.ir.assert_structural_equal(after, expected)
+
+
+def test_bf16_storage_legalize():
+ before = get_after_compute_legalize()
+ after = tvm.tir.transform.BF16StorageLegalize()(before)
+ expected = get_after_storage_legalize()
+ tvm.ir.assert_structural_equal(after, expected)
if __name__ == "__main__":
- test_promote()
- test_eliminate()
- test_legalize()
+ test_bf16_storage_legalize()