This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push: new 4247433e33 [WebGPU] Add `tir.dp4a` (#17124) 4247433e33 is described below commit 4247433e33dfeff9bc82521ed4c7e85605d94893 Author: Jiawei Shao <jiawei.s...@intel.com> AuthorDate: Mon Jul 1 20:36:14 2024 +0800 [WebGPU] Add `tir.dp4a` (#17124) * [WebGPU] Add `tir.dp4a` This patch adds `tir.dp4a` as a new TIR built-in operator as a preparation of supporting int8 computation with `dot4I8Packed` in WebGPU backend. * Fix format issues * Fix format issue * Replace `accumulation` with `accumulator` --- include/tvm/tir/builtin.h | 5 +++++ python/tvm/script/ir_builder/tir/ir.py | 2 ++ python/tvm/tir/__init__.py | 1 + python/tvm/tir/op.py | 25 +++++++++++++++++++++++++ src/tir/op/builtin.cc | 5 +++++ tests/python/tir-base/test_tir_op_types.py | 8 ++++++++ 6 files changed, 46 insertions(+) diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h index 120c1b71be..ea2d07903e 100644 --- a/include/tvm/tir/builtin.h +++ b/include/tvm/tir/builtin.h @@ -816,6 +816,11 @@ TVM_DLL const Op& vectorlow(); */ TVM_DLL const Op& vectorcombine(); +/*! + * \brief Dot product of two int8x4 vectors and add an optional accumulator + */ +TVM_DLL const Op& dp4a(); + /*! * \brief atomic add instruction, corresponding e.g. to atomicAdd in CUDA */ diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index caefc6a6bc..bdbd6e2cda 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -1932,6 +1932,7 @@ vectorlow = _dtype_forward(_tir_op.vectorlow) vectorhigh = _dtype_forward(_tir_op.vectorhigh) vectorcombine = _dtype_forward(_tir_op.vectorcombine) get_active_lane_mask = _dtype_forward(_tir_op.get_active_lane_mask) +dp4a = _dtype_forward(_tir_op.dp4a) broadcast = Broadcast @@ -2191,6 +2192,7 @@ __all__ = [ "vectorlow", "vectorhigh", "vectorcombine", + "dp4a", "assume", "undef", "tvm_call_packed", diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index 5360ab2b96..bcfbe6575d 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -95,6 +95,7 @@ from .op import q_multiply_shift, q_multiply_shift_per_axis, shift_left, shift_r from .op import TVMBackendAllocWorkspace, TVMBackendFreeWorkspace from .op import start_profile_intrinsic, end_profile_intrinsic from .op import vscale, get_active_lane_mask, get_vscale_expr +from .op import dp4a from .generic import add, subtract, multiply from .schedule import StmtSRef, BlockScope, ScheduleState, Schedule, ScheduleError diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index 81d6604259..0bc299e403 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -1813,6 +1813,31 @@ def vectorcombine(dtype, vec1, vec2): return call_intrin(dtype, "tir.vectorcombine", vec1, vec2) +def dp4a(vec1, vec2, acc=0): + """Dot product of two int8x4 vectors and add an optional accumulator + + Parameters + ---------- + vec1 : int8x4 + The input vector. + + vec2 : int8x4 + The input vector. + + acc : int32 + The accumulator. + + Returns + ------- + call : PrimExpr + The call expression. + """ + vec1 = convert(vec1) + vec2 = convert(vec2) + acc = convert(acc) + return call_intrin("int32", "tir.dp4a", vec1, vec2, acc) + + def ret(val): """Create a tir return expression diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc index 0404fd2823..0d4a213a23 100644 --- a/src/tir/op/builtin.cc +++ b/src/tir/op/builtin.cc @@ -355,6 +355,11 @@ TIR_DEFINE_BUILTIN_FUNC(vectorcombine) .set_attr<TScriptDtypePrintLocation>("TScriptDtypePrintLocation", Integer(ScriptDtypePrintLocation::kFirst)); +TIR_DEFINE_BUILTIN_FUNC(dp4a) + .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure)) + .set_attr<TScriptDtypePrintLocation>("TScriptDtypePrintLocation", + Integer(ScriptDtypePrintLocation::kFirst)); + TIR_DEFINE_BUILTIN_FUNC(atomic_add) .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque)); diff --git a/tests/python/tir-base/test_tir_op_types.py b/tests/python/tir-base/test_tir_op_types.py index 7398ee781b..aefab62559 100644 --- a/tests/python/tir-base/test_tir_op_types.py +++ b/tests/python/tir-base/test_tir_op_types.py @@ -295,6 +295,14 @@ def test_tir_op_vectorhigh(): assert expr.op.name == "tir.vectorhigh" +def test_tir_op_dp4a(): + vec1 = tir.Var("vec1", dtype="int8x4") + vec2 = tir.Var("vec2", dtype="int8x4") + acc = tir.Var("acc", dtype="int32") + expr = tir.dp4a(vec1, vec2, acc) + assert expr.op.name == "tir.dp4a" + + def test_tir_op_vectorcombine(): buffer = tir.decl_buffer((4, 4), "int8", offset_factor=1) vec = buffer.vload([0, 0], dtype="int8x16")