This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push: new 2829b59e1c [TVMScript] Add parser and printer support for e4m3/e5m2 fp8 (#16864) 2829b59e1c is described below commit 2829b59e1c78796da273b650f006628bca64cfcc Author: Wuwei Lin <wu...@apache.org> AuthorDate: Wed Apr 10 05:22:41 2024 -0700 [TVMScript] Add parser and printer support for e4m3/e5m2 fp8 (#16864) * [TVMScript] Add parser and printer support for e4m3/e5m2 fp8 * remove unrelated --- include/tvm/script/ir_builder/tir/ir.h | 12 +++++++ python/tvm/script/ir_builder/tir/ir.py | 39 +++++++++++++++------- src/script/ir_builder/tir/ir.cc | 5 +++ .../python/tvmscript/test_tvmscript_printer_tir.py | 31 +++++++++++++++++ 4 files changed, 75 insertions(+), 12 deletions(-) diff --git a/include/tvm/script/ir_builder/tir/ir.h b/include/tvm/script/ir_builder/tir/ir.h index 735d5ba6c0..c4ba44f673 100644 --- a/include/tvm/script/ir_builder/tir/ir.h +++ b/include/tvm/script/ir_builder/tir/ir.h @@ -489,6 +489,18 @@ TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES(Int, DataType::Int); TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(Float, DataType::Float); TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(UInt, DataType::UInt); TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(Int, DataType::Int); + +#define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(DType, FDType) \ + TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType, FDType(1)); \ + TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##x4, FDType(4)); \ + TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##x8, FDType(8)); \ + TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##x16, FDType(16)); \ + TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##x32, FDType(32)); \ + TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##x64, FDType(64)); + +TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(E4M3Float8, DataType::NVFloat8E4M3); +TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(E5M2Float8, DataType::NVFloat8E5M2); + TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Boolean, DataType::Bool()); TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Void, DataType::Void()); diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index a5c09cf1a3..127d2a4356 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -1408,30 +1408,39 @@ uint16x64 = func_gen(("UInt16x64")) uint32x64 = func_gen(("UInt32x64")) uint64x64 = func_gen(("UInt64x64")) -float8 = func_gen(("Float8")) float16 = func_gen(("Float16")) float32 = func_gen(("Float32")) float64 = func_gen(("Float64")) -float8x4 = func_gen(("Float8x4")) float16x4 = func_gen(("Float16x4")) float32x4 = func_gen(("Float32x4")) float64x4 = func_gen(("Float64x4")) -float8x8 = func_gen(("Float8x8")) float16x8 = func_gen(("Float16x8")) float32x8 = func_gen(("Float32x8")) float64x8 = func_gen(("Float64x8")) -float8x16 = func_gen(("Float8x16")) float16x16 = func_gen(("Float16x16")) float32x16 = func_gen(("Float32x16")) float64x16 = func_gen(("Float64x16")) -float8x32 = func_gen(("Float8x32")) float16x32 = func_gen(("Float16x32")) float32x32 = func_gen(("Float32x32")) float64x32 = func_gen(("Float64x32")) -float8x64 = func_gen(("Float8x64")) float16x64 = func_gen(("Float16x64")) float32x64 = func_gen(("Float32x64")) float64x64 = func_gen(("Float64x64")) + +e4m3_float8 = func_gen(("E4M3Float8")) +e4m3_float8x4 = func_gen(("E4M3Float8x4")) +e4m3_float8x8 = func_gen(("E4M3Float8x8")) +e4m3_float8x16 = func_gen(("E4M3Float8x16")) +e4m3_float8x32 = func_gen(("E4M3Float8x32")) +e4m3_float8x64 = func_gen(("E4M3Float8x64")) + +e5m2_float8 = func_gen(("E5M2Float8")) +e5m2_float8x4 = func_gen(("E5M2Float8x4")) +e5m2_float8x8 = func_gen(("E5M2Float8x8")) +e5m2_float8x16 = func_gen(("E5M2Float8x16")) +e5m2_float8x32 = func_gen(("E5M2Float8x32")) +e5m2_float8x64 = func_gen(("E5M2Float8x64")) + # pylint: enable=invalid-name @@ -1954,27 +1963,33 @@ __all__ = [ "uint16x64", "uint32x64", "uint64x64", - "float8", + "e4m3_float8", + "e5m2_float8", "float16", "float32", "float64", - "float8x4", + "e4m3_float8x4", + "e5m2_float8x4", "float16x4", "float32x4", "float64x4", - "float8x8", + "e4m3_float8x8", + "e5m2_float8x8", "float16x8", "float32x8", "float64x8", - "float8x16", + "e4m3_float8x16", + "e5m2_float8x16", "float16x16", "float32x16", "float64x16", - "float8x32", + "e4m3_float8x32", + "e5m2_float8x32", "float16x32", "float32x32", "float64x32", - "float8x64", + "e4m3_float8x64", + "e5m2_float8x64", "float16x64", "float32x64", "float64x64", diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc index 1ae1051d25..ccb5a8b57b 100644 --- a/src/script/ir_builder/tir/ir.cc +++ b/src/script/ir_builder/tir/ir.cc @@ -751,6 +751,11 @@ TVM_REGISTER_GLOBAL_SIZES_LANES("script.ir_builder.tir.Float", Float); TVM_REGISTER_GLOBAL_SIZES_LANES("script.ir_builder.tir.UInt", UInt); TVM_REGISTER_GLOBAL_SIZES_LANES("script.ir_builder.tir.Int", Int); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.E4M3Float8").set_body_typed(E4M3Float8); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.E5M2Float8").set_body_typed(E5M2Float8); +TVM_REGISTER_GLOBAL_LANES("script.ir_builder.tir.E4M3Float8", E4M3Float8); +TVM_REGISTER_GLOBAL_LANES("script.ir_builder.tir.E5M2Float8", E5M2Float8); + TVM_REGISTER_GLOBAL("script.ir_builder.tir.Boolean").set_body_typed(Boolean); TVM_REGISTER_GLOBAL("script.ir_builder.tir.Handle").set_body_typed(Handle); TVM_REGISTER_GLOBAL("script.ir_builder.tir.Void").set_body_typed(Void); diff --git a/tests/python/tvmscript/test_tvmscript_printer_tir.py b/tests/python/tvmscript/test_tvmscript_printer_tir.py index 97a6b889c0..edc6da3163 100644 --- a/tests/python/tvmscript/test_tvmscript_printer_tir.py +++ b/tests/python/tvmscript/test_tvmscript_printer_tir.py @@ -917,5 +917,36 @@ def func(): _assert_print(func, expected_output) +@pytest.mark.parametrize("dtype", ["e4m3_float8", "e5m2_float8"]) +def test_float8(dtype): + from tvm.script import tir as T + + def get_func(dtype): + if dtype == "e4m3_float8": + + @T.prim_func + def func(): + T.evaluate(T.e4m3_float8(0.0)) + + return func + elif dtype == "e5m2_float8": + + @T.prim_func + def func(): + T.evaluate(T.e5m2_float8(0.0)) + + return func + + expected_output = f""" +# from tvm.script import tir as T + +@T.prim_func +def func(): + T.evaluate(T.{dtype}(0)) + """ + func = get_func(dtype) + _assert_print(func, expected_output) + + if __name__ == "__main__": tvm.testing.main()