This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 2829b59e1c [TVMScript] Add parser and printer support for e4m3/e5m2 fp8 (#16864)
2829b59e1c is described below

commit 2829b59e1c78796da273b650f006628bca64cfcc
Author: Wuwei Lin <wu...@apache.org>
AuthorDate: Wed Apr 10 05:22:41 2024 -0700

    [TVMScript] Add parser and printer support for e4m3/e5m2 fp8 (#16864)
    
    * [TVMScript] Add parser and printer support for e4m3/e5m2 fp8
    
    * remove unrelated
---
 include/tvm/script/ir_builder/tir/ir.h             | 12 +++++++
 python/tvm/script/ir_builder/tir/ir.py             | 39 +++++++++++++++-------
 src/script/ir_builder/tir/ir.cc                    |  5 +++
 .../python/tvmscript/test_tvmscript_printer_tir.py | 31 +++++++++++++++++
 4 files changed, 75 insertions(+), 12 deletions(-)

diff --git a/include/tvm/script/ir_builder/tir/ir.h b/include/tvm/script/ir_builder/tir/ir.h
index 735d5ba6c0..c4ba44f673 100644
--- a/include/tvm/script/ir_builder/tir/ir.h
+++ b/include/tvm/script/ir_builder/tir/ir.h
@@ -489,6 +489,18 @@ TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES(Int, DataType::Int);
 TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(Float, DataType::Float);
 TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(UInt, DataType::UInt);
 TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(Int, DataType::Int);
+
+#define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(DType, FDType) \
+  TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType, FDType(1));                    \
+  TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##x4, FDType(4));                \
+  TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##x8, FDType(8));                \
+  TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##x16, FDType(16));              \
+  TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##x32, FDType(32));              \
+  TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##x64, FDType(64));
+
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(E4M3Float8, DataType::NVFloat8E4M3);
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(E5M2Float8, DataType::NVFloat8E5M2);
+
 TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Boolean, DataType::Bool());
 TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Void, DataType::Void());
 
diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py
index a5c09cf1a3..127d2a4356 100644
--- a/python/tvm/script/ir_builder/tir/ir.py
+++ b/python/tvm/script/ir_builder/tir/ir.py
@@ -1408,30 +1408,39 @@ uint16x64 = func_gen(("UInt16x64"))
 uint32x64 = func_gen(("UInt32x64"))
 uint64x64 = func_gen(("UInt64x64"))
 
-float8 = func_gen(("Float8"))
 float16 = func_gen(("Float16"))
 float32 = func_gen(("Float32"))
 float64 = func_gen(("Float64"))
-float8x4 = func_gen(("Float8x4"))
 float16x4 = func_gen(("Float16x4"))
 float32x4 = func_gen(("Float32x4"))
 float64x4 = func_gen(("Float64x4"))
-float8x8 = func_gen(("Float8x8"))
 float16x8 = func_gen(("Float16x8"))
 float32x8 = func_gen(("Float32x8"))
 float64x8 = func_gen(("Float64x8"))
-float8x16 = func_gen(("Float8x16"))
 float16x16 = func_gen(("Float16x16"))
 float32x16 = func_gen(("Float32x16"))
 float64x16 = func_gen(("Float64x16"))
-float8x32 = func_gen(("Float8x32"))
 float16x32 = func_gen(("Float16x32"))
 float32x32 = func_gen(("Float32x32"))
 float64x32 = func_gen(("Float64x32"))
-float8x64 = func_gen(("Float8x64"))
 float16x64 = func_gen(("Float16x64"))
 float32x64 = func_gen(("Float32x64"))
 float64x64 = func_gen(("Float64x64"))
+
+e4m3_float8 = func_gen(("E4M3Float8"))
+e4m3_float8x4 = func_gen(("E4M3Float8x4"))
+e4m3_float8x8 = func_gen(("E4M3Float8x8"))
+e4m3_float8x16 = func_gen(("E4M3Float8x16"))
+e4m3_float8x32 = func_gen(("E4M3Float8x32"))
+e4m3_float8x64 = func_gen(("E4M3Float8x64"))
+
+e5m2_float8 = func_gen(("E5M2Float8"))
+e5m2_float8x4 = func_gen(("E5M2Float8x4"))
+e5m2_float8x8 = func_gen(("E5M2Float8x8"))
+e5m2_float8x16 = func_gen(("E5M2Float8x16"))
+e5m2_float8x32 = func_gen(("E5M2Float8x32"))
+e5m2_float8x64 = func_gen(("E5M2Float8x64"))
+
 # pylint: enable=invalid-name
 
 
@@ -1954,27 +1963,33 @@ __all__ = [
     "uint16x64",
     "uint32x64",
     "uint64x64",
-    "float8",
+    "e4m3_float8",
+    "e5m2_float8",
     "float16",
     "float32",
     "float64",
-    "float8x4",
+    "e4m3_float8x4",
+    "e5m2_float8x4",
     "float16x4",
     "float32x4",
     "float64x4",
-    "float8x8",
+    "e4m3_float8x8",
+    "e5m2_float8x8",
     "float16x8",
     "float32x8",
     "float64x8",
-    "float8x16",
+    "e4m3_float8x16",
+    "e5m2_float8x16",
     "float16x16",
     "float32x16",
     "float64x16",
-    "float8x32",
+    "e4m3_float8x32",
+    "e5m2_float8x32",
     "float16x32",
     "float32x32",
     "float64x32",
-    "float8x64",
+    "e4m3_float8x64",
+    "e5m2_float8x64",
     "float16x64",
     "float32x64",
     "float64x64",
diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc
index 1ae1051d25..ccb5a8b57b 100644
--- a/src/script/ir_builder/tir/ir.cc
+++ b/src/script/ir_builder/tir/ir.cc
@@ -751,6 +751,11 @@ TVM_REGISTER_GLOBAL_SIZES_LANES("script.ir_builder.tir.Float", Float);
 TVM_REGISTER_GLOBAL_SIZES_LANES("script.ir_builder.tir.UInt", UInt);
 TVM_REGISTER_GLOBAL_SIZES_LANES("script.ir_builder.tir.Int", Int);
 
+TVM_REGISTER_GLOBAL("script.ir_builder.tir.E4M3Float8").set_body_typed(E4M3Float8);
+TVM_REGISTER_GLOBAL("script.ir_builder.tir.E5M2Float8").set_body_typed(E5M2Float8);
+TVM_REGISTER_GLOBAL_LANES("script.ir_builder.tir.E4M3Float8", E4M3Float8);
+TVM_REGISTER_GLOBAL_LANES("script.ir_builder.tir.E5M2Float8", E5M2Float8);
+
 TVM_REGISTER_GLOBAL("script.ir_builder.tir.Boolean").set_body_typed(Boolean);
 TVM_REGISTER_GLOBAL("script.ir_builder.tir.Handle").set_body_typed(Handle);
 TVM_REGISTER_GLOBAL("script.ir_builder.tir.Void").set_body_typed(Void);
diff --git a/tests/python/tvmscript/test_tvmscript_printer_tir.py b/tests/python/tvmscript/test_tvmscript_printer_tir.py
index 97a6b889c0..edc6da3163 100644
--- a/tests/python/tvmscript/test_tvmscript_printer_tir.py
+++ b/tests/python/tvmscript/test_tvmscript_printer_tir.py
@@ -917,5 +917,36 @@ def func():
     _assert_print(func, expected_output)
 
 
+@pytest.mark.parametrize("dtype", ["e4m3_float8", "e5m2_float8"])
+def test_float8(dtype):
+    from tvm.script import tir as T
+
+    def get_func(dtype):
+        if dtype == "e4m3_float8":
+
+            @T.prim_func
+            def func():
+                T.evaluate(T.e4m3_float8(0.0))
+
+            return func
+        elif dtype == "e5m2_float8":
+
+            @T.prim_func
+            def func():
+                T.evaluate(T.e5m2_float8(0.0))
+
+            return func
+
+    expected_output = f"""
+# from tvm.script import tir as T
+
+@T.prim_func
+def func():
+    T.evaluate(T.{dtype}(0))
+    """
+    func = get_func(dtype)
+    _assert_print(func, expected_output)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to