This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 4247433e33 [WebGPU] Add `tir.dp4a` (#17124)
4247433e33 is described below

commit 4247433e33dfeff9bc82521ed4c7e85605d94893
Author: Jiawei Shao <jiawei.s...@intel.com>
AuthorDate: Mon Jul 1 20:36:14 2024 +0800

    [WebGPU] Add `tir.dp4a` (#17124)
    
    * [WebGPU] Add `tir.dp4a`
    
    This patch adds `tir.dp4a` as a new TIR built-in operator, in
    preparation for supporting int8 computation with `dot4I8Packed`
    in the WebGPU backend.
    
    * Fix format issues
    
    * Fix format issue
    
    * Replace `accumulation` with `accumulator`
---
 include/tvm/tir/builtin.h                  |  5 +++++
 python/tvm/script/ir_builder/tir/ir.py     |  2 ++
 python/tvm/tir/__init__.py                 |  1 +
 python/tvm/tir/op.py                       | 25 +++++++++++++++++++++++++
 src/tir/op/builtin.cc                      |  5 +++++
 tests/python/tir-base/test_tir_op_types.py |  8 ++++++++
 6 files changed, 46 insertions(+)

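For readers skimming the patch, here is a minimal plain-Python sketch of the dot-product-accumulate semantics that `tir.dp4a` (like WGSL's `dot4I8Packed`) is meant to expose: multiply the four int8 lanes pairwise, sum the products, and add the 32-bit accumulator. This sketch is illustrative only and is not part of the commit.

    # Illustrative reference only -- not part of the patch.
    def dp4a_reference(vec1, vec2, acc=0):
        """Pairwise-multiply four int8 lanes, sum the products, add the accumulator."""
        assert len(vec1) == 4 and len(vec2) == 4
        for a, b in zip(vec1, vec2):
            assert -128 <= a <= 127 and -128 <= b <= 127  # int8 lanes
            acc += a * b  # products and sum accumulate in 32-bit
        return acc

    assert dp4a_reference([1, 2, 3, 4], [5, 6, 7, 8], acc=10) == 80

WGSL's `dot4I8Packed` takes the four lanes packed into a 32-bit word, but the arithmetic is the same.
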
diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h
index 120c1b71be..ea2d07903e 100644
--- a/include/tvm/tir/builtin.h
+++ b/include/tvm/tir/builtin.h
@@ -816,6 +816,11 @@ TVM_DLL const Op& vectorlow();
  */
 TVM_DLL const Op& vectorcombine();
 
+/*!
+ * \brief Dot product of two int8x4 vectors and add an optional accumulator
+ */
+TVM_DLL const Op& dp4a();
+
 /*!
  * \brief atomic add instruction, corresponding e.g. to atomicAdd in CUDA
  */
diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py
index caefc6a6bc..bdbd6e2cda 100644
--- a/python/tvm/script/ir_builder/tir/ir.py
+++ b/python/tvm/script/ir_builder/tir/ir.py
@@ -1932,6 +1932,7 @@ vectorlow = _dtype_forward(_tir_op.vectorlow)
 vectorhigh = _dtype_forward(_tir_op.vectorhigh)
 vectorcombine = _dtype_forward(_tir_op.vectorcombine)
 get_active_lane_mask = _dtype_forward(_tir_op.get_active_lane_mask)
+dp4a = _dtype_forward(_tir_op.dp4a)
 
 
 broadcast = Broadcast
@@ -2191,6 +2192,7 @@ __all__ = [
     "vectorlow",
     "vectorhigh",
     "vectorcombine",
+    "dp4a",
     "assume",
     "undef",
     "tvm_call_packed",
diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py
index 5360ab2b96..bcfbe6575d 100644
--- a/python/tvm/tir/__init__.py
+++ b/python/tvm/tir/__init__.py
@@ -95,6 +95,7 @@ from .op import q_multiply_shift, q_multiply_shift_per_axis, shift_left, shift_r
 from .op import TVMBackendAllocWorkspace, TVMBackendFreeWorkspace
 from .op import start_profile_intrinsic, end_profile_intrinsic
 from .op import vscale, get_active_lane_mask, get_vscale_expr
+from .op import dp4a
 from .generic import add, subtract, multiply
 
 from .schedule import StmtSRef, BlockScope, ScheduleState, Schedule, ScheduleError
diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py
index 81d6604259..0bc299e403 100644
--- a/python/tvm/tir/op.py
+++ b/python/tvm/tir/op.py
@@ -1813,6 +1813,31 @@ def vectorcombine(dtype, vec1, vec2):
     return call_intrin(dtype, "tir.vectorcombine", vec1, vec2)
 
 
+def dp4a(vec1, vec2, acc=0):
+    """Dot product of two int8x4 vectors and add an optional accumulator
+
+    Parameters
+    ----------
+    vec1 : int8x4
+       The input vector.
+
+    vec2 : int8x4
+       The input vector.
+
+    acc : int32
+       The accumulator.
+
+    Returns
+    -------
+    call : PrimExpr
+        The call expression.
+    """
+    vec1 = convert(vec1)
+    vec2 = convert(vec2)
+    acc = convert(acc)
+    return call_intrin("int32", "tir.dp4a", vec1, vec2, acc)
+
+
 def ret(val):
     """Create a tir return expression
 
diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc
index 0404fd2823..0d4a213a23 100644
--- a/src/tir/op/builtin.cc
+++ b/src/tir/op/builtin.cc
@@ -355,6 +355,11 @@ TIR_DEFINE_BUILTIN_FUNC(vectorcombine)
     .set_attr<TScriptDtypePrintLocation>("TScriptDtypePrintLocation",
                                          Integer(ScriptDtypePrintLocation::kFirst));
 
+TIR_DEFINE_BUILTIN_FUNC(dp4a)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure))
+    .set_attr<TScriptDtypePrintLocation>("TScriptDtypePrintLocation",
+                                         Integer(ScriptDtypePrintLocation::kFirst));
+
 TIR_DEFINE_BUILTIN_FUNC(atomic_add)
     .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
 
diff --git a/tests/python/tir-base/test_tir_op_types.py b/tests/python/tir-base/test_tir_op_types.py
index 7398ee781b..aefab62559 100644
--- a/tests/python/tir-base/test_tir_op_types.py
+++ b/tests/python/tir-base/test_tir_op_types.py
@@ -295,6 +295,14 @@ def test_tir_op_vectorhigh():
     assert expr.op.name == "tir.vectorhigh"
 
 
+def test_tir_op_dp4a():
+    vec1 = tir.Var("vec1", dtype="int8x4")
+    vec2 = tir.Var("vec2", dtype="int8x4")
+    acc = tir.Var("acc", dtype="int32")
+    expr = tir.dp4a(vec1, vec2, acc)
+    assert expr.op.name == "tir.dp4a"
+
+
 def test_tir_op_vectorcombine():
     buffer = tir.decl_buffer((4, 4), "int8", offset_factor=1)
     vec = buffer.vload([0, 0], dtype="int8x16")
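
As a quick usage sketch (assuming a TVM build that includes this patch), the new Python binding can be exercised the same way the added test does:

    from tvm import tir

    # int8x4 operands and an int32 accumulator, as in the new test case.
    vec1 = tir.Var("vec1", dtype="int8x4")
    vec2 = tir.Var("vec2", dtype="int8x4")
    acc = tir.Var("acc", dtype="int32")

    # Builds a call to the new built-in; the resulting expression has dtype int32.
    expr = tir.dp4a(vec1, vec2, acc)
    assert expr.op.name == "tir.dp4a"

Since the accumulator defaults to 0, `tir.dp4a(vec1, vec2)` is a plain four-lane dot product.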
