This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new c309e4ea5f [TIR] Add cooperative_tensor builtins and
metal.cooperative_tensor storage scope (#19423)
c309e4ea5f is described below
commit c309e4ea5f6f5d7bae6ab9753995269d56548263
Author: Yichen Yan <[email protected]>
AuthorDate: Mon May 11 21:16:29 2026 +0800
[TIR] Add cooperative_tensor builtins and metal.cooperative_tensor storage
scope (#19423)
part of https://github.com/tile-ai/tilelang/pull/1869
## Summary
Add TIR builtins and storage scope for Metal cooperative_tensor
operations (MetalPerformancePrimitives / Metal 4).
## Motivation
Apple Metal 4 introduces MetalPerformancePrimitives (MPP) with
`matmul2d` using `cooperative_tensor` operands. On M5, this routes to
NAX tensor cores; on M1-M4, it falls back to simdgroup matrix
instructions. These TIR primitives enable backend codegen to emit MPP
calls.
## Changes
### New TIR builtins
- `cooperative_tensor_fill(d, index, value, rows, cols)`
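- `cooperative_tensor_load(d, index, ptr, stride, rows, cols, transpose_matrix, mma_M, mma_N, mma_K, operand_role)`
- `cooperative_tensor_store(d, index, ptr, stride, rows, cols, transpose_matrix, mma_M, mma_N, mma_K, operand_role)`
- `cooperative_tensor_multiply_accumulate(d, index_d, a, index_a, b, index_b, c, index_c, M, N, K, transpose_a, transpose_b)`

For load/store, `mma_M`/`mma_N`/`mma_K` carry the matmul2d dimensions and `operand_role` selects the operand (0 = left/A, 1 = right/B, 2 = destination/C).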
### New storage scope
- `metal.cooperative_tensor` (`StorageRank::kMetalCooperativeTensor`)
### Files changed
- `include/tvm/tirx/builtin.h` — Op declarations
- `src/tirx/op/builtin.cc` — Op registrations
- `python/tvm/tirx/op.py` — Python wrappers
- `python/tvm/tirx/script/builder/ir.py` — TVMScript builder exports
- `src/runtime/thread_storage_scope.h` — new `StorageRank` value plus scope parsing and printing
These builtins mirror the existing `simdgroup_*` builtins for the older
Metal simdgroup matrix API. They extend them with M/N/K dimension parameters
for the matmul2d descriptor and, on load/store, an operand-role argument.
---
include/tvm/tirx/builtin.h | 45 +++++++++++++++
python/tvm/tirx/op.py | 104 +++++++++++++++++++++++++++++++++++
python/tvm/tirx/script/builder/ir.py | 8 +++
src/runtime/thread_storage_scope.h | 7 +++
src/tirx/op/builtin.cc | 12 ++++
5 files changed, 176 insertions(+)
diff --git a/include/tvm/tirx/builtin.h b/include/tvm/tirx/builtin.h
index 3339b1aa49..2e69ce80d2 100644
--- a/include/tvm/tirx/builtin.h
+++ b/include/tvm/tirx/builtin.h
@@ -782,6 +782,51 @@ TVM_DLL const Op& simdgroup_store();
  */
 TVM_DLL const Op& simdgroup_multiply_accumulate();

+// Metal cooperative_tensor intrinsics (MetalPerformancePrimitives / Metal 4)
+
+/*!
+ * \brief Fill a cooperative_tensor with a given value.
+ *
+ * void cooperative_tensor_fill(Var d, PrimExpr index, PrimExpr value,
+ *                              int rows, int cols);
+ */
+TVM_DLL const Op& cooperative_tensor_fill();
+
+/*!
+ * \brief Load data from device or threadgroup memory into a cooperative_tensor.
+ *
+ * void cooperative_tensor_load(Var d, PrimExpr index, PrimExpr ptr,
+ *                              PrimExpr stride, int rows, int cols,
+ *                              bool transpose_matrix,
+ *                              int mma_M, int mma_N, int mma_K,
+ *                              int operand_role);
+ * operand_role: 0=left(A), 1=right(B), 2=destination(C)
+ */
+TVM_DLL const Op& cooperative_tensor_load();
+
+/*!
+ * \brief Store data from a cooperative_tensor to device or threadgroup memory.
+ *
+ * void cooperative_tensor_store(Var d, PrimExpr index, PrimExpr ptr,
+ *                               PrimExpr stride, int rows, int cols,
+ *                               bool transpose_matrix,
+ *                               int mma_M, int mma_N, int mma_K,
+ *                               int operand_role);
+ * operand_role: 0=left(A), 1=right(B), 2=destination(C)
+ */
+TVM_DLL const Op& cooperative_tensor_store();
+
+/*!
+ * \brief Multiply and accumulate two matrices using cooperative_tensor
+ * (MetalPerformancePrimitives matmul2d).
+ *
+ * void cooperative_tensor_multiply_accumulate(
+ *     Var d, PrimExpr index_d, Var a, PrimExpr index_a,
+ *     Var b, PrimExpr index_b, Var c, PrimExpr index_c,
+ *     int M, int N, int K, bool transpose_a, bool transpose_b);
+ */
+TVM_DLL const Op& cooperative_tensor_multiply_accumulate();
+
 // TODO(tvm-team) replace the usage of the vector operations by Shuffle.
 /*!
  * \brief Get the high level half of the vector
diff --git a/python/tvm/tirx/op.py b/python/tvm/tirx/op.py
index 566f5d905b..2cdc6f0b36 100644
--- a/python/tvm/tirx/op.py
+++ b/python/tvm/tirx/op.py
@@ -1793,6 +1793,110 @@ def simdgroup_multiply_accumulate(
     )


+def cooperative_tensor_fill(
+    d: Var,
+    index: PrimExpr,
+    value: PrimExpr,
+    rows: int,
+    cols: int,
+):
+    return call_intrin("handle", "tirx.cooperative_tensor_fill", d, index, value, rows, cols)
+
+
+def cooperative_tensor_load(
+    d: Var,
+    index: PrimExpr,
+    ptr: PrimExpr,
+    stride: PrimExpr,
+    rows: int,
+    cols: int,
+    transpose_matrix: bool = False,
+    mma_M: int = 0,
+    mma_N: int = 0,
+    mma_K: int = 0,
+    operand_role: int = 0,
+):
+    return call_intrin(
+        "handle",
+        "tirx.cooperative_tensor_load",
+        d,
+        index,
+        ptr,
+        stride,
+        rows,
+        cols,
+        transpose_matrix,
+        mma_M,
+        mma_N,
+        mma_K,
+        operand_role,
+    )
+
+
+def cooperative_tensor_store(
+    d: PrimExpr,
+    index: PrimExpr,
+    ptr: PrimExpr,
+    stride: PrimExpr,
+    rows: int,
+    cols: int,
+    transpose_matrix: bool = False,
+    mma_M: int = 0,
+    mma_N: int = 0,
+    mma_K: int = 0,
+    operand_role: int = 0,
+):
+    return call_intrin(
+        "handle",
+        "tirx.cooperative_tensor_store",
+        d,
+        index,
+        ptr,
+        stride,
+        rows,
+        cols,
+        transpose_matrix,
+        mma_M,
+        mma_N,
+        mma_K,
+        operand_role,
+    )
+
+
+def cooperative_tensor_multiply_accumulate(
+    d: Var,
+    index_d: PrimExpr,
+    a: Var,
+    index_a: PrimExpr,
+    b: Var,
+    index_b: PrimExpr,
+    c: Var,
+    index_c: PrimExpr,
+    M: int,
+    N: int,
+    K: int,
+    transpose_a: bool = False,
+    transpose_b: bool = False,
+):
+    return call_intrin(
+        "handle",
+        "tirx.cooperative_tensor_multiply_accumulate",
+        d,
+        index_d,
+        a,
+        index_a,
+        b,
+        index_b,
+        c,
+        index_c,
+        M,
+        N,
+        K,
+        transpose_a,
+        transpose_b,
+    )
+
+
 def vectorlow(dtype, vec):
     """Get the low level half of the vector

diff --git a/python/tvm/tirx/script/builder/ir.py b/python/tvm/tirx/script/builder/ir.py
index 76f0397a8e..95f1fbea80 100644
--- a/python/tvm/tirx/script/builder/ir.py
+++ b/python/tvm/tirx/script/builder/ir.py
@@ -1965,6 +1965,10 @@ make_filled_simdgroup_matrix = _op_wrapper(_tir_op.make_filled_simdgroup_matrix)
 simdgroup_load = _op_wrapper(_tir_op.simdgroup_load)
 simdgroup_store = _op_wrapper(_tir_op.simdgroup_store)
 simdgroup_multiply_accumulate = _op_wrapper(_tir_op.simdgroup_multiply_accumulate)
+cooperative_tensor_fill = _op_wrapper(_tir_op.cooperative_tensor_fill)
+cooperative_tensor_load = _op_wrapper(_tir_op.cooperative_tensor_load)
+cooperative_tensor_store = _op_wrapper(_tir_op.cooperative_tensor_store)
+cooperative_tensor_multiply_accumulate = _op_wrapper(_tir_op.cooperative_tensor_multiply_accumulate)
 create_barriers = _op_wrapper(_tir_op.create_barriers)
 assume = _op_wrapper(_tir_op.assume)
 undef = _op_wrapper(_tir_op.undef)
@@ -2255,6 +2259,10 @@ __all__ = float_types + [
     "simdgroup_load",
     "simdgroup_store",
     "simdgroup_multiply_accumulate",
+    "cooperative_tensor_fill",
+    "cooperative_tensor_load",
+    "cooperative_tensor_store",
+    "cooperative_tensor_multiply_accumulate",
     "create_barriers",
     "mma_store",
     "mma_fill",
diff --git a/src/runtime/thread_storage_scope.h b/src/runtime/thread_storage_scope.h
index 313e4cfe48..0155aa1ffd 100644
--- a/src/runtime/thread_storage_scope.h
+++ b/src/runtime/thread_storage_scope.h
@@ -71,6 +71,8 @@ enum class StorageRank {
   kMMAMatrixC = 11,
   /*! \brief Metal SIMD group memory */
   kMetalSimdGroup = 12,
+  /*! \brief Metal cooperative_tensor memory (MetalPerformancePrimitives) */
+  kMetalCooperativeTensor = 13,
 };

 /*!
@@ -129,6 +131,8 @@ struct StorageScope {
         return "m16n8k8.matrixC" + tag;
       case StorageRank::kMetalSimdGroup:
         return "metal.simdgroup" + tag;
+      case StorageRank::kMetalCooperativeTensor:
+        return "metal.cooperative_tensor" + tag;
       default:
         TVM_FFI_THROW(InternalError) << "unknown storage scope";
         return "";
@@ -182,6 +186,9 @@ struct StorageScope {
     } else if (s.compare(0, 15, "metal.simdgroup") == 0) {
       r.rank = StorageRank::kMetalSimdGroup;
       r.tag = s.substr(15, std::string::npos);
+    } else if (s.compare(0, 24, "metal.cooperative_tensor") == 0) {
+      r.rank = StorageRank::kMetalCooperativeTensor;
+      r.tag = s.substr(24, std::string::npos);
     } else {
       TVM_FFI_THROW(InternalError) << "unknown storage scope " << s;
     }
diff --git a/src/tirx/op/builtin.cc b/src/tirx/op/builtin.cc
index 4355583d79..7ac487144f 100644
--- a/src/tirx/op/builtin.cc
+++ b/src/tirx/op/builtin.cc
@@ -345,6 +345,18 @@ TIR_DEFINE_BUILTIN_FUNC(simdgroup_store)
 TIR_DEFINE_BUILTIN_FUNC(simdgroup_multiply_accumulate)
     .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));

+TIR_DEFINE_BUILTIN_FUNC(cooperative_tensor_fill)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+
+TIR_DEFINE_BUILTIN_FUNC(cooperative_tensor_load)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+
+TIR_DEFINE_BUILTIN_FUNC(cooperative_tensor_store)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+
+TIR_DEFINE_BUILTIN_FUNC(cooperative_tensor_multiply_accumulate)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+
 TIR_DEFINE_BUILTIN_FUNC(vectorhigh)
     .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure))
     .set_attr<TScriptDtypePrintLocation>("TScriptDtypePrintLocation",