This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new c0abab769f [TIR][DLight] Enable SimdGroup op for Metal (#17112)
c0abab769f is described below

commit c0abab769ff152d87f84963f18a98d2f7c9bdf31
Author: Siyuan Feng <hzfen...@sjtu.edu.cn>
AuthorDate: Mon Jun 24 21:24:32 2024 +0800

    [TIR][DLight] Enable SimdGroup op for Metal (#17112)
---
 include/tvm/tir/builtin.h                        |  44 ++-
 python/tvm/dlight/gpu/matmul.py                  | 145 ++++++++++
 python/tvm/script/ir_builder/tir/ir.py           |   8 +
 python/tvm/tir/__init__.py                       |   6 +
 python/tvm/tir/op.py                             | 191 ++++++++++++-
 python/tvm/tir/tensor_intrin/metal.py            | 350 +++++++++++++++++++++++
 src/runtime/thread_storage_scope.h               |   7 +
 src/target/source/codegen_metal.cc               |  82 +++++-
 src/target/source/codegen_metal.h                |   3 +
 src/tir/op/builtin.cc                            |  12 +
 tests/python/dlight/test_gpu_matmul_tensorize.py | 283 +++++++++++++++++-
 11 files changed, 1124 insertions(+), 7 deletions(-)

diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h
index 5836eb8ea9..120c1b71be 100644
--- a/include/tvm/tir/builtin.h
+++ b/include/tvm/tir/builtin.h
@@ -746,7 +746,7 @@ TVM_DLL const Op& create_barriers();
 TVM_DLL const Op& mma_store();
 
 /*!
- * \brief tvm intrinsic for zero-initalizing an MMA accumulation registor.
+ * \brief tvm intrinsic for zero-initializing an MMA accumulation register.
  *        For example, if each thread in a warp of size 32 has 8 elements from the A matrix in
  *        m16xn8xk16 MMA in its registers, this intrinsic can be used to zero-initialize its
  *        4 accumulation registers.
@@ -758,6 +758,48 @@ TVM_DLL const Op& mma_store();
  */
 TVM_DLL const Op& mma_fill();
 
+// Metal SimdGroup matrix intrinsics
+
+/*!
+ * \brief tvm intrinsic for initializing a simdgroup matrix with a given value.
+ * \note Only the 8x8 shape is supported by the Metal spec and TVM, but we still keep the
+ *       shape as parameters to stay close to the Metal spec interface.
+ *
+ * void make_filled_simdgroup_matrix(Var d, PrimExpr index, PrimExpr value,
+ *                                   int col = 8, int row = 8);
+ */
+TVM_DLL const Op& make_filled_simdgroup_matrix();
+
+/*!
+ * \brief tvm intrinsic for loading data from device memory or threadgroup memory to a simdgroup matrix.
+ * \note Only the 8x8 shape is supported by the Metal spec and TVM, but we still keep the
+ *       shape as parameters to stay close to the Metal spec interface.
+ *
+ * void simdgroup_load(Var d, PrimExpr index, PrimExpr ptr, PrimExpr stride,
+                       int col = 8, int row = 8, bool transpose_matrix = false);
+ */
+TVM_DLL const Op& simdgroup_load();
+
+/*!
+ * \brief tvm intrinsic for storing data from a simdgroup matrix to device memory or threadgroup memory.
+ * \note Only the 8x8 shape is supported by the Metal spec and TVM, but we still keep the
+ *       shape as parameters to stay close to the Metal spec interface.
+ *
+ * void simdgroup_store(Var d, PrimExpr index, PrimExpr ptr, PrimExpr stride,
+ *                      int col = 8, int row = 8, bool transpose_matrix = false);
+ */
+TVM_DLL const Op& simdgroup_store();
+
+/*!
+ * \brief tvm intrinsic for multiplying and accumulating two matrices in a simdgroup.
+ * \note Only the 8x8 shape is supported by the Metal spec and TVM, but we still keep the
+ *       shape as parameters to stay close to the Metal spec interface.
+ *
+ * void simdgroup_mma(Var d, PrimExpr index_d, Var a, PrimExpr index_a,
+ *                    Var b, PrimExpr index_b, Var c, PrimExpr index_c);
+ */
+TVM_DLL const Op& simdgroup_multiply_accumulate();
+
 // TODO(tvm-team) replace the usage of the vector operations by Shuffle.
 /*!
  * \brief Get the high level half of the vector
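
For orientation, a minimal TVMScript sketch of how these builtins are driven from Python (illustrative only; it mirrors the `impl` bodies registered in python/tvm/tir/tensor_intrin/metal.py below, and the buffer name and the fixed fragment index 0 are assumptions made for the example):

    from tvm.script import tir as T

    @T.prim_func
    def zero_fill(a: T.handle) -> None:
        # one 8x8 float16 fragment array in Metal simdgroup storage
        A = T.match_buffer(a, (8, 8), "float16", scope="metal.simdgroup", offset_factor=1)
        with T.block("root"):
            T.reads()
            T.writes(A[0:8, 0:8])
            # zero-initialize fragment 0 of the simdgroup matrix
            T.make_filled_simdgroup_matrix(A.data, index=0, value=T.float32(0), col=8, row=8)
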
diff --git a/python/tvm/dlight/gpu/matmul.py b/python/tvm/dlight/gpu/matmul.py
index f4ef1f5044..a5759941ca 100644
--- a/python/tvm/dlight/gpu/matmul.py
+++ b/python/tvm/dlight/gpu/matmul.py
@@ -313,6 +313,146 @@ def check_sm_version(arch: str) -> int:
     return int(sm_version) if sm_version.isdigit() else -1
 
 
+class MetalMatmul(GPUScheduleRule):
+    """
+    The schedule rule for Metal matmul computation.
+    """
+
+    def apply(  # pylint: disable=too-many-locals,missing-docstring
+        self,
+        func: tir.PrimFunc,
+        target: Target,
+        _: bool,
+    ) -> Optional[tir.Schedule]:
+        from tvm.tir.tensor_intrin.metal import (  # pylint: disable=import-outside-toplevel
+            get_simdgroup_intrin_group,
+        )
+
+        if not isinstance(func, tir.PrimFunc) or not self.is_target_available(target):
+            return None
+        sch = tir.Schedule(func)
+        root_block = analysis.get_root_block(sch)
+        blocks = sch.get_child_blocks(root_block)
+
+        reduction_blocks = get_reduction_blocks(sch, blocks)
+        if reduction_blocks is None:
+            return None
+
+        main_block = reduction_blocks[0]
+        block_stmt = sch.get(main_block)
+        index_maps = get_index_map(block_stmt)
+        if index_maps is None:
+            return None
+        matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps
+
+        # Step 0. Configs
+        block_size_x: int = 16
+        block_size_y: int = 16
+        block_size_k: int = 32
+        micro_size: int = 8
+        warp_size: int = 32
+        ty_len: int = 1
+        tz_len: int = 4
+        vector_size: int = 4
+
+        # Step 1. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K]
+        block = sch.reindex(main_block, ("read", 0))
+        sch.transform_layout(block, ("write", 0), a_index_map)
+        block = sch.reindex(main_block, ("read", 1))
+        sch.transform_layout(block, ("write", 0), b_index_map)
+        block = sch.reindex(main_block, ("write", 0))
+        sch.transform_layout(block, ("read", 0), c_index_map)
+        sch.transform_block_layout(main_block, matmul_index_map)
+
+        # Step 2. Padding for dynamic shape kernels
+        sch.pad_einsum(
+            main_block,
+            [
+                1,
+                ty_len * block_size_x,
+                tz_len * block_size_y,
+                block_size_k,
+            ],
+        )
+
+        # Step 3. Schedule matmul to use simdgroup intrinsics
+        batch, i, j, k = sch.get_loops(main_block)
+        bx, ty, i0, i1 = sch.split(i, [None, ty_len, block_size_x // micro_size, micro_size])
+        by, tz, j0, j1 = sch.split(j, [None, tz_len, block_size_y // micro_size, micro_size])
+        k0, k1, k2 = sch.split(k, [None, block_size_k // micro_size, micro_size])
+        sch.reorder(bx, by, ty, tz, k0, k1, i0, j0, i1, j1, k2)
+        sch.bind(bx, "blockIdx.x")
+        sch.bind(by, "blockIdx.y")
+        sch.bind(batch, "blockIdx.z")
+        sch.bind(ty, "threadIdx.y")
+        sch.bind(tz, "threadIdx.z")
+
+        def fetch_to_shared(block, idx):
+            block_read = sch.cache_read(block, idx, "shared")
+            sch.compute_at(block_read, k0, preserve_unit_loops=True)
+            fused = sch.fuse(*sch.get_loops(block_read)[-2:])
+            _, _tz, _ty, _tx, vec = sch.split(fused, [None, tz_len, ty_len, warp_size, vector_size])
+
+            sch.bind(_tz, "threadIdx.z")
+            sch.bind(_ty, "threadIdx.y")
+            sch.bind(_tx, "threadIdx.x")
+            sch.vectorize(vec)
+
+            return block_read
+
+        a_g2s = fetch_to_shared(main_block, 0)
+        b_g2s = fetch_to_shared(main_block, 1)
+
+        auto_inline_producers(sch, a_g2s)
+        auto_inline_producers(sch, b_g2s)
+
+        # create read cache to load matrices from shared memory into simdgroup fragments
+        A_simdgroup = sch.cache_read(main_block, 0, "metal.simdgroup")
+        B_simdgroup = sch.cache_read(main_block, 1, "metal.simdgroup")
+        sch.compute_at(A_simdgroup, k1)
+        sch.compute_at(B_simdgroup, k1)
+
+        C_simd2s = sch.cache_write(main_block, 0, "metal.simdgroup")
+        C_s2g = sch.cache_write(C_simd2s, 0, "shared")
+        sch.reverse_compute_at(C_simd2s, tz, preserve_unit_loops=True)
+        sch.reverse_compute_at(C_s2g, by, preserve_unit_loops=True)
+
+        intrin_group = get_simdgroup_intrin_group(
+            load_scope="shared",
+            store_scope="shared",
+            dtype="float16",
+            trans_a=False,
+            trans_b=True,
+        )
+        sch.transform_layout(B_simdgroup, ("write", 0), lambda s, i, j: (s, j, i))
+
+        def tensorize_block(block: tir.schedule.BlockRV, intrin: str):
+            *_, i, j = sch.get_loops(block)
+            io, ii = sch.split(i, [None, micro_size])
+            jo, ji = sch.split(j, [None, micro_size])
+            sch.reorder(io, jo, ii, ji)
+            sch.tensorize(ii, intrin)
+
+        C_init = sch.decompose_reduction(main_block, k0)
+        tensorize_block(A_simdgroup, intrin_group["load_a"])
+        tensorize_block(B_simdgroup, intrin_group["load_b"])
+        tensorize_block(C_simd2s, intrin_group["store"])
+        tensorize_block(C_init, intrin_group["init"])
+
+        *_, i, j, k = sch.get_loops(main_block)
+        sch.tensorize(i, intrin_group["compute"])
+
+        auto_inline_consumer_chain(sch, C_s2g)
+        fused = sch.fuse(*sch.get_loops(C_s2g)[-2:])
+        _, _tz, _ty, _tx, vec = sch.split(fused, [None, tz_len, ty_len, warp_size, vector_size])
+        sch.bind(_tz, "threadIdx.z")
+        sch.bind(_ty, "threadIdx.y")
+        sch.bind(_tx, "threadIdx.x")
+        sch.vectorize(vec)
+
+        return sch
+
+
 class MatmulTensorization(GPUScheduleRule):
     """
     The schedule rule for float16 tensor core matmul computation.
@@ -848,6 +988,11 @@ class Matmul(GPUScheduleRule):
                     tensorize_sch = MatmulTensorization().apply(func, target, _)
                 if tensorize_sch is not None:
                     return tensorize_sch
+        elif target.kind.name == "metal":
+            try:
+                return MetalMatmul().apply(func, target, _)
+            except:  # pylint: disable=bare-except
+                pass
 
         # Step 2. Get schedule config.
         config = self.get_configs(target)
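
The MetalMatmul rule is picked up automatically when the generic Matmul rule runs under a Metal target. A usage sketch, mirroring the test fixture added in tests/python/dlight/test_gpu_matmul_tensorize.py below (`mod` stands for any IRModule containing a float16 matmul):

    import tvm
    from tvm import dlight as dl
    from tvm.target import Target

    def schedule_for_metal(mod: tvm.IRModule) -> tvm.IRModule:
        # dl.gpu.Matmul() tries MetalMatmul first for metal targets and falls back
        # to the generic schedule if tensorization is not applicable
        with Target("metal"):
            return dl.ApplyDefaultSchedule(dl.gpu.Matmul())(mod)
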
diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py
index 18abc0ca5d..caefc6a6bc 100644
--- a/python/tvm/script/ir_builder/tir/ir.py
+++ b/python/tvm/script/ir_builder/tir/ir.py
@@ -1887,6 +1887,10 @@ ptx_init_barrier_thread_count = _op_wrapper(_tir_op.ptx_init_barrier_thread_coun
 ptx_arrive_barrier = _op_wrapper(_tir_op.ptx_arrive_barrier)
 ptx_arrive_barrier_expect_tx = _op_wrapper(_tir_op.ptx_arrive_barrier_expect_tx)
 ptx_wait_barrier = _op_wrapper(_tir_op.ptx_wait_barrier)
+make_filled_simdgroup_matrix = _op_wrapper(_tir_op.make_filled_simdgroup_matrix)
+simdgroup_load = _op_wrapper(_tir_op.simdgroup_load)
+simdgroup_store = _op_wrapper(_tir_op.simdgroup_store)
+simdgroup_multiply_accumulate = _op_wrapper(_tir_op.simdgroup_multiply_accumulate)
 create_barriers = _op_wrapper(_tir_op.create_barriers)
 assume = _op_wrapper(_tir_op.assume)
 undef = _op_wrapper(_tir_op.undef)
@@ -2177,6 +2181,10 @@ __all__ = [
     "ptx_arrive_barrier",
     "ptx_arrive_barrier_expect_tx",
     "ptx_wait_barrier",
+    "make_filled_simdgroup_matrix",
+    "simdgroup_load",
+    "simdgroup_store",
+    "simdgroup_multiply_accumulate",
     "create_barriers",
     "mma_store",
     "mma_fill",
diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py
index 0fee976eb1..5360ab2b96 100644
--- a/python/tvm/tir/__init__.py
+++ b/python/tvm/tir/__init__.py
@@ -73,6 +73,12 @@ from .op import (
     ptx_wait_barrier,
     create_barriers,
 )
+from .op import (
+    make_filled_simdgroup_matrix,
+    simdgroup_load,
+    simdgroup_multiply_accumulate,
+    simdgroup_store,
+)
 from .op import vectorlow, vectorhigh, vectorcombine
 from .op import infinity, reinterpret
 from .op import exp, exp2, exp10, log, log2, log10, log1p, ldexp, clz
diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py
index 95a85ab77d..81d6604259 100644
--- a/python/tvm/tir/op.py
+++ b/python/tvm/tir/op.py
@@ -14,7 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=redefined-builtin, invalid-name
+# pylint: disable=redefined-builtin, invalid-name, too-many-arguments
 """Operators used in TIR expression."""
 from typing import Any, Optional, Union
 
@@ -1567,6 +1567,195 @@ def create_barriers(barrier_count):
     return call_intrin("", "tir.create_barriers", barrier_count)
 
 
+def make_filled_simdgroup_matrix(
+    d: Var,
+    index: PrimExpr,
+    value: PrimExpr,
+    col: int = 8,
+    row: int = 8,
+):
+    """Create a filled SIMDGroup matrix
+
+    Parameters
+    ----------
+    d : Var
+        The simdgroup variable.
+
+    index : PrimExpr
+        The index of the matrix.
+
+    value : PrimExpr
+        The value to fill.
+
+    col : int
+        The number of columns.
+
+    row : int
+        The number of rows.
+
+    Returns
+    -------
+    call : PrimExpr
+        The call expression.
+    """
+    return call_intrin("handle", "tir.make_filled_simdgroup_matrix", d, index, 
value, col, row)
+
+
+def simdgroup_load(
+    d: Var,
+    index: PrimExpr,
+    ptr: PrimExpr,
+    stride: PrimExpr,
+    col: int = 8,
+    row: int = 8,
+    transpose_matrix: bool = False,
+):
+    """Load data from device memory or threadgroup memory to simdgroup
+
+    Parameters
+    ----------
+    d : Var
+        The simdgroup variable.
+
+    index : PrimExpr
+        The index of the matrix.
+
+    ptr : PrimExpr
+        The pointer.
+
+    stride : PrimExpr
+        The stride.
+
+    col : int
+        The number of columns.
+
+    row : int
+        The number of rows.
+
+    transpose_matrix : bool
+        Whether to transpose the matrix.
+
+    Returns
+    -------
+    call : PrimExpr
+        The call expression.
+    """
+    return call_intrin(
+        "handle",
+        "tir.simdgroup_load",
+        d,
+        index,
+        ptr,
+        stride,
+        col,
+        row,
+        transpose_matrix,
+    )
+
+
+def simdgroup_store(
+    d: PrimExpr,
+    index: PrimExpr,
+    ptr: PrimExpr,
+    stride: PrimExpr,
+    col: int = 8,
+    row: int = 8,
+    transpose_matrix: bool = False,
+):
+    """Store data from simdgroup to device memory or threadgroup memory
+
+    Parameters
+    ----------
+    d : PrimExpr
+        The SIMDGroup.
+
+    index : PrimExpr
+        The index of the matrix.
+
+    ptr : PrimExpr
+        The pointer.
+
+    stride : PrimExpr
+        The stride.
+
+    col : int
+        The number of columns.
+
+    row : int
+        The number of rows.
+
+    transpose_matrix : bool
+        Whether to transpose the matrix.
+
+    Returns
+    -------
+    call : PrimExpr
+        The call expression.
+    """
+    return call_intrin(
+        "handle", "tir.simdgroup_store", d, index, ptr, stride, col, row, 
transpose_matrix
+    )
+
+
+def simdgroup_multiply_accumulate(
+    d: Var,
+    index_d: PrimExpr,
+    a: Var,
+    index_a: PrimExpr,
+    b: Var,
+    index_b: PrimExpr,
+    c: Var,
+    index_c: PrimExpr,
+):
+    """Multiply and accumulate two matrices in simdgroup
+    i.e. d = a * b + c
+
+    Parameters
+    ----------
+    d : Var
+        The destination matrix.
+
+    index_d : PrimExpr
+        The index of the destination matrix.
+
+    a : Var
+        The first matrix.
+
+    index_a : PrimExpr
+        The index of the first matrix.
+
+    b : Var
+        The second matrix.
+
+    index_b : PrimExpr
+        The index of the second matrix.
+
+    c : Var
+        The third matrix.
+
+    index_c : PrimExpr
+        The index of the third matrix.
+
+    Returns
+    -------
+    call : PrimExpr
+        The call expression.
+    """
+    return call_intrin(
+        "handle",
+        "tir.simdgroup_multiply_accumulate",
+        d,
+        index_d,
+        a,
+        index_a,
+        b,
+        index_b,
+        c,
+        index_c,
+    )
+
+
 def vectorlow(dtype, vec):
     """Get the low level half of the vector
 
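
Each wrapper above simply builds a tir.Call to the corresponding builtin; a small standalone sketch (the variable name and the pointer type are assumptions made for illustration):

    import tvm
    from tvm import tir

    # a handle Var in the new "metal.simdgroup" storage scope (hypothetical name)
    frag = tir.Var("frag", tvm.ir.PointerType(tvm.ir.PrimType("float16"), "metal.simdgroup"))
    call = tir.make_filled_simdgroup_matrix(frag, 0, tir.const(0.0, "float32"))
    # `call` is a tir.Call to the builtin "tir.make_filled_simdgroup_matrix"
    # with arguments (frag, 0, 0.0, col=8, row=8)
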
diff --git a/python/tvm/tir/tensor_intrin/metal.py b/python/tvm/tir/tensor_intrin/metal.py
new file mode 100644
index 0000000000..be34a9e266
--- /dev/null
+++ b/python/tvm/tir/tensor_intrin/metal.py
@@ -0,0 +1,350 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name,missing-function-docstring,unused-variable
+"""Intrinsics for tensorization on Apple GPU."""
+from typing import Dict, Literal, Tuple
+
+from tvm.script import tir as T
+from tvm.tir import Buffer, PrimExpr, PrimFunc, TensorIntrin
+
+######## simdgroup matrix intrinsics ########
+
+
+def get_simdgroup_index(buffer: Buffer, stride: PrimExpr, col: int, row: int):
+    """Compute simdgroup index using elem_offset of the buffer"""
+
+    # NOTE: the exact roles of `col` and `row` need further checking.
+    # Currently, Metal only supports 8x8, so the values of `col` and `row` are the same.
+    frag_index_m = buffer.elem_offset // stride // col
+    frag_index_n = buffer.elem_offset % stride // row
+
+    num_fragments_per_row = stride // row
+    return frag_index_m * num_fragments_per_row + frag_index_n
+
+
+def get_make_filled_simdgroup_matrix_intrin(
+    dtype: str, col: int = 8, row: int = 8
+) -> Tuple[PrimFunc, PrimFunc]:
+    @T.prim_func
+    def desc(a: T.handle) -> None:
+        A = T.match_buffer(a, (col, row), dtype, scope="metal.simdgroup", offset_factor=1)
+        with T.block("root"):
+            T.reads()
+            T.writes(A[0:col, 0:row])
+            for i, j in T.grid(col, row):
+                with T.block("init"):
+                    vi, vj = T.axis.remap("SS", [i, j])
+                    A[vi, vj] = T.float32(0)
+
+    @T.prim_func
+    def impl(a: T.handle) -> None:
+        d0, d1 = T.int32(), T.int32()
+        A = T.match_buffer(
+            a, (col, row), dtype, scope="metal.simdgroup", strides=[d1, d0], offset_factor=1
+        )
+        with T.block("root"):
+            T.reads()
+            T.writes(A[0:col, 0:row])
+            T.make_filled_simdgroup_matrix(
+                A.data,
+                index=get_simdgroup_index(A, d1, col, row),
+                value=T.float32(0),
+                col=col,
+                row=row,
+            )
+
+    return desc, impl
+
+
+def get_simdgroup_load_intrin(
+    dtype: str,
+    scope: Literal["global", "shared"],
+    col: int = 8,
+    row: int = 8,
+    transpose_matrix: bool = False,
+) -> Tuple[PrimFunc, PrimFunc]:
+    align = col * row
+
+    @T.prim_func
+    def desc(a: T.handle, c: T.handle) -> None:
+        A = T.match_buffer(a, (col, row), dtype, align=align, scope=scope, offset_factor=1)
+        C = T.match_buffer(
+            c, (col, row), dtype, align=align, scope="metal.simdgroup", offset_factor=1
+        )
+        with T.block("root"):
+            T.reads(A[0:col, 0:row])
+            T.writes(C[0:col, 0:row])
+            for i, j in T.grid(col, row):
+                with T.block("load"):
+                    vii, vjj = T.axis.remap("SS", [i, j])
+                    if transpose_matrix:
+                        # C[vii, vjj] = A[vjj, vii]
+                        C[vjj, vii] = A[vii, vjj]
+                    else:
+                        C[vii, vjj] = A[vii, vjj]
+
+    @T.prim_func
+    def impl(a: T.handle, c: T.handle) -> None:
+        s0, s1, d0, d1 = T.int32(), T.int32(), T.int32(), T.int32()
+        A = T.match_buffer(
+            a,
+            (col, row),
+            dtype,
+            align=align,
+            scope=scope,
+            strides=[s1, s0],
+            offset_factor=1,
+        )
+        C = T.match_buffer(
+            c,
+            (col, row),
+            dtype,
+            align=align,
+            scope="metal.simdgroup",
+            strides=[d1, d0],
+            offset_factor=1,
+        )
+        with T.block("root"):
+            T.reads(A[0:col, 0:row])
+            T.writes(C[0:col, 0:row])
+            T.simdgroup_load(
+                C.data,
+                index=get_simdgroup_index(C, d1, col, row),
+                ptr=A.access_ptr("r"),
+                stride=s1,
+                col=col,
+                row=row,
+                transpose_matrix=transpose_matrix,
+            )
+
+    return desc, impl
+
+
+def get_simdgroup_store_intrin(
+    dtype: str,
+    scope: Literal["global", "shared"],
+    col: int = 8,
+    row: int = 8,
+    transpose_matrix: bool = False,
+) -> Tuple[PrimFunc, PrimFunc]:
+    align = col * row
+
+    @T.prim_func
+    def desc(a: T.handle, c: T.handle) -> None:
+        A = T.match_buffer(
+            a, (col, row), dtype, align=align, scope="metal.simdgroup", offset_factor=1
+        )
+        C = T.match_buffer(c, (col, row), dtype, align=align, scope=scope, offset_factor=1)
+        with T.block("root"):
+            T.reads(A[0:col, 0:row])
+            T.writes(C[0:col, 0:row])
+            for i, j in T.grid(col, row):
+                with T.block("store"):
+                    vii, vjj = T.axis.remap("SS", [i, j])
+                    if transpose_matrix:
+                        C[vjj, vii] = A[vii, vjj]
+                    else:
+                        C[vii, vjj] = A[vii, vjj]
+
+    @T.prim_func
+    def impl(a: T.handle, c: T.handle) -> None:
+        s0, s1, d0, d1 = T.int32(), T.int32(), T.int32(), T.int32()
+        A = T.match_buffer(
+            a,
+            (col, row),
+            dtype,
+            align=align,
+            scope="metal.simdgroup",
+            strides=[s1, s0],
+            offset_factor=1,
+        )
+        C = T.match_buffer(
+            c, (col, row), dtype, align=align, scope=scope, strides=[d1, d0], offset_factor=1
+        )
+        with T.block("root"):
+            T.reads(A[0:col, 0:row])
+            T.writes(C[0:col, 0:row])
+            T.simdgroup_store(
+                A.data,
+                index=get_simdgroup_index(A, s1, col, row),
+                ptr=C.access_ptr("w"),
+                stride=d1,
+                col=col,
+                row=row,
+                transpose_matrix=transpose_matrix,
+            )
+
+    return desc, impl
+
+
+def get_simdgroup_multiply_accumulate_intrin(
+    m_dim: int, n_dim: int, k_dim: int, dtype: str
+) -> Tuple[PrimFunc, PrimFunc]:
+    @T.prim_func
+    def desc(a: T.handle, b: T.handle, c: T.handle) -> None:
+        A = T.match_buffer(a, (m_dim, k_dim), dtype, scope="metal.simdgroup", offset_factor=1)
+        B = T.match_buffer(b, (k_dim, n_dim), dtype, scope="metal.simdgroup", offset_factor=1)
+        C = T.match_buffer(c, (m_dim, n_dim), dtype, scope="metal.simdgroup", offset_factor=1)
+        with T.block("root"):
+            T.reads(C[0:m_dim, 0:n_dim], A[0:m_dim, 0:k_dim], B[0:k_dim, 0:n_dim])
+            T.writes(C[0:m_dim, 0:n_dim])
+            for i, j, k in T.grid(m_dim, n_dim, k_dim):
+                with T.block(""):
+                    vii, vjj, vkk = T.axis.remap("SSR", [i, j, k])
+                    C[vii, vjj] += A[vii, vkk] * B[vkk, vjj]
+
+    @T.prim_func
+    def impl(a: T.handle, b: T.handle, c: T.handle) -> None:
+        a0, a1, b0, b1, c0, c1 = T.int32(), T.int32(), T.int32(), T.int32(), T.int32(), T.int32()
+        A = T.match_buffer(
+            a, (m_dim, k_dim), dtype, scope="metal.simdgroup", strides=[a1, a0], offset_factor=1
+        )
+        B = T.match_buffer(
+            b, (k_dim, n_dim), dtype, scope="metal.simdgroup", strides=[b1, b0], offset_factor=1
+        )
+        C = T.match_buffer(
+            c, (m_dim, n_dim), dtype, scope="metal.simdgroup", strides=[c1, c0], offset_factor=1
+        )
+        with T.block("root"):
+            T.reads(C[0:m_dim, 0:n_dim], A[0:m_dim, 0:k_dim], B[0:k_dim, 0:n_dim])
+            T.writes(C[0:m_dim, 0:n_dim])
+            T.simdgroup_multiply_accumulate(
+                C.data,
+                get_simdgroup_index(C, c1, m_dim, n_dim),
+                A.data,
+                get_simdgroup_index(A, a1, m_dim, k_dim),
+                B.data,
+                get_simdgroup_index(B, b1, k_dim, n_dim),
+                C.data,
+                get_simdgroup_index(C, c1, m_dim, n_dim),
+            )
+
+    return desc, impl
+
+
+# Make filled simdgroup matrix intrinsics
+
+SIMDGROUP_MAKE_FILLED_8x8x8_f16_INTRIN = "simdgroup_make_filled_8x8x8_f16"
+TensorIntrin.register(
+    SIMDGROUP_MAKE_FILLED_8x8x8_f16_INTRIN,
+    *get_make_filled_simdgroup_matrix_intrin("float16", 8, 8),
+)
+
+SIMDGROUP_FILLED_8x8x8_f32_INTRIN = "simdgroup_fill_8x8x8_f32"
+TensorIntrin.register(
+    SIMDGROUP_FILLED_8x8x8_f32_INTRIN, *get_make_filled_simdgroup_matrix_intrin("float32", 8, 8)
+)
+
+SIMDGROUP_FILLED_8x8x8_bf16_INTRIN = "simdgroup_fill_8x8x8_bf16"
+TensorIntrin.register(
+    SIMDGROUP_FILLED_8x8x8_bf16_INTRIN, *get_make_filled_simdgroup_matrix_intrin("bfloat16", 8, 8)
+)
+
+# Load intrinsics
+
+SIMDGROUP_LOAD_8x8x8_f16_SHARED_INTRIN = "simdgroup_load_8x8x8_f16_shared"
+TensorIntrin.register(
+    SIMDGROUP_LOAD_8x8x8_f16_SHARED_INTRIN,
+    *get_simdgroup_load_intrin("float16", "shared", 8, 8, False),
+)
+
+SIMDGROUP_LOAD_8x8x8_f16_SHARED_TRANS_INTRIN = "simdgroup_load_8x8x8_f16_shared_trans"
+TensorIntrin.register(
+    SIMDGROUP_LOAD_8x8x8_f16_SHARED_TRANS_INTRIN,
+    *get_simdgroup_load_intrin("float16", "shared", 8, 8, True),
+)
+
+# Store intrinsics
+
+SIMDGROUP_STORE_8x8x8_f16_GLOBAL_INTRIN = "simdgroup_store_8x8x8_f16_global"
+TensorIntrin.register(
+    SIMDGROUP_STORE_8x8x8_f16_GLOBAL_INTRIN,
+    *get_simdgroup_store_intrin("float16", "global", 8, 8, False),
+)
+
+SIMDGROUP_STORE_8x8x8_f16_SHARED_INTRIN = "simdgroup_store_8x8x8_f16_shared"
+TensorIntrin.register(
+    SIMDGROUP_STORE_8x8x8_f16_SHARED_INTRIN,
+    *get_simdgroup_store_intrin("float16", "shared", 8, 8, False),
+)
+# Multiply accumulate intrinsics
+
+SIMDGROUP_MULTI_ACC_8x8x8_f16_INTRIN = "simdgroup_multiply_accumulate_8x8x8_f16"
+TensorIntrin.register(
+    SIMDGROUP_MULTI_ACC_8x8x8_f16_INTRIN,
+    *get_simdgroup_multiply_accumulate_intrin(8, 8, 8, "float16"),
+)
+
+
+def get_simdgroup_intrin_group(
+    load_scope: Literal["shared"],
+    store_scope: Literal["global", "shared"],
+    dtype: str,
+    trans_a: bool = False,
+    trans_b: bool = False,
+) -> Dict[str, str]:
+    """Get a group of intrinsics for tensorization on Apple GPU.
+
+    Parameters
+    ----------
+    load_scope : Literal["shared"]
+        The memory scope of the input buffer.
+
+    store_scope : Literal["global", "shared"]
+        The memory scope of the result buffer.
+
+    dtype : str
+        The data type of the input and output buffers.
+
+    trans_a : bool
+        Whether the input matrix A is transposed.
+
+    trans_b : bool
+        Whether the input matrix B is transposed.
+
+    Returns
+    -------
+    ret : Dict[str, str]
+        A group of tensor intrinsics.
+    """
+    assert load_scope in ["shared"]
+    assert store_scope in ["global", "shared"]
+    assert dtype in ["float16", "bfloat16", "float32"]
+
+    shape = "8x8x8"
+    dtype = "f16" if dtype == "float16" else "bf16" if dtype == "bfloat16" 
else "f32"
+    trans_a = "_trans" if trans_a else ""
+    trans_b = "_trans" if trans_b else ""
+
+    # e.g. simdgroup_load_8x8x8_f16_shared
+    load_a_intrin = f"simdgroup_load_{shape}_{dtype}_{load_scope}{trans_a}"
+    # e.g. simdgroup_load_8x8x8_f16_shared_trans
+    load_b_intrin = f"simdgroup_load_{shape}_{dtype}_{load_scope}{trans_b}"
+    # e.g. simdgroup_multiply_accumulate_8x8x8_f16
+    compute_intrin = f"simdgroup_multiply_accumulate_{shape}_{dtype}"
+    # e.g. simdgroup_make_filled_8x8x8_f16
+    init_intrin = f"simdgroup_make_filled_{shape}_{dtype}"
+    # e.g. simdgroup_store_8x8x8_f16_global
+    store_intrin = f"simdgroup_store_{shape}_{dtype}_{store_scope}"
+
+    return {
+        "init": init_intrin,
+        "load_a": load_a_intrin,
+        "load_b": load_b_intrin,
+        "compute": compute_intrin,
+        "store": store_intrin,
+    }
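
For reference, the intrinsic names the MetalMatmul rule above resolves to. The call below is the same one made in python/tvm/dlight/gpu/matmul.py, and the comments spell out the returned names per the registrations in this file:

    from tvm.tir.tensor_intrin.metal import get_simdgroup_intrin_group

    intrin_group = get_simdgroup_intrin_group(
        load_scope="shared",
        store_scope="shared",
        dtype="float16",
        trans_a=False,
        trans_b=True,
    )
    # intrin_group["init"]    == "simdgroup_make_filled_8x8x8_f16"
    # intrin_group["load_a"]  == "simdgroup_load_8x8x8_f16_shared"
    # intrin_group["load_b"]  == "simdgroup_load_8x8x8_f16_shared_trans"
    # intrin_group["compute"] == "simdgroup_multiply_accumulate_8x8x8_f16"
    # intrin_group["store"]   == "simdgroup_store_8x8x8_f16_shared"
    # each name can be passed directly to Schedule.tensorize()
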
diff --git a/src/runtime/thread_storage_scope.h b/src/runtime/thread_storage_scope.h
index 747b905812..d1af2cb701 100644
--- a/src/runtime/thread_storage_scope.h
+++ b/src/runtime/thread_storage_scope.h
@@ -70,6 +70,8 @@ enum class StorageRank {
   kMMAMatrixB = 10,
   /*! \brief mma scope memory of accumulator */
   kMMAMatrixC = 11,
+  /*! \brief Metal SIMD group memory */
+  kMetalSimdGroup = 12,
 };
 
 /*!
@@ -126,6 +128,8 @@ struct StorageScope {
         return "m16n8k8.matrixB" + tag;
       case StorageRank::kMMAMatrixC:
         return "m16n8k8.matrixC" + tag;
+      case StorageRank::kMetalSimdGroup:
+        return "metal.simdgroup" + tag;
       default:
         LOG(FATAL) << "unknown storage scope";
     }
@@ -175,6 +179,9 @@ struct StorageScope {
     } else if (s.compare(0, 15, "m16n8k8.matrixC") == 0) {
       r.rank = StorageRank::kMMAMatrixC;
       r.tag = s.substr(15, std::string::npos);
+    } else if (s.compare(0, 15, "metal.simdgroup") == 0) {
+      r.rank = StorageRank::kMetalSimdGroup;
+      r.tag = s.substr(15, std::string::npos);
     } else {
       LOG(FATAL) << "unknown storage scope " << s;
     }
diff --git a/src/target/source/codegen_metal.cc b/src/target/source/codegen_metal.cc
index e729af417c..2908514988 100644
--- a/src/target/source/codegen_metal.cc
+++ b/src/target/source/codegen_metal.cc
@@ -25,10 +25,10 @@
 #include <tvm/tir/transform.h>
 
 #include <algorithm>
+#include <sstream>
 #include <string>
 #include <unordered_map>
 #include <utility>
-#include <vector>
 
 #include "../../runtime/metal/metal_module.h"
 #include "../../runtime/thread_storage_scope.h"
@@ -262,6 +262,9 @@ void CodeGenMetal::PrintType(DataType t, std::ostream& os) {  // NOLINT(*)
       os << lanes;
       return;
     }
+  } else if (t.is_bfloat16()) {
+    os << "bfloat";
+    return;
   }
   LOG(FATAL) << "Cannot convert type " << t << " to Metal type";
 }
@@ -296,9 +299,43 @@ void CodeGenMetal::PrintStorageScope(const std::string& scope, std::ostream& os)
     os << "device ";
   } else if (scope == "shared") {
     os << "threadgroup ";
-  } else {
+  } else if (scope == "local") {
     os << "thread ";
+  } else {
+    LOG(FATAL) << "Unknown storage scope `" << scope << "`";
+  }
+}
+
+void CodeGenMetal::VisitStmt_(const AllocateNode* op) {
+  ICHECK(!is_zero(op->condition));
+  std::string vid = AllocVarID(op->buffer_var.get());
+
+  this->PrintIndent();
+  size_t constant_size = op->ConstantAllocationSize();
+  ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation for now";
+
+  auto scope = GetPtrStorageScope(op->buffer_var);
+  alloc_storage_scope_[op->buffer_var.get()] = scope;
+  if (scope == "metal.simdgroup") {
+    ICHECK(op->dtype == DataType::Float(16) || op->dtype == DataType::Float(32) ||
+           op->dtype == DataType::BFloat(16))
+        << "Only float16, float32, and bfloat16 are supported, but got " << op->dtype;
+    ICHECK(constant_size % 64 == 0)
+        << "Only 8x8 matrix is supported, but got " << constant_size << " 
bytes\n";
+
+    std::ostringstream dtype_os;
+    PrintType(op->dtype, dtype_os);
+    std::string dtype_str = dtype_os.str();
+    simdgroup_dtype_[op->buffer_var.get()] = dtype_str;
+    stream << "simdgroup_" << dtype_str << "8x8 " << vid << '[' << 
constant_size / 64 << "];\n";
+  } else {
+    PrintStorageScope(scope, stream);
+    PrintType(op->dtype, stream);
+    stream << ' ' << vid << '[' << constant_size << "];\n";
   }
+
+  RegisterHandleType(op->buffer_var.get(), op->dtype);
+  this->PrintStmt(op->body);
 }
 
 void CodeGenMetal::VisitExpr_(const SelectNode* op, std::ostream& os) {  // NOLINT(*)
@@ -322,7 +359,46 @@ void CodeGenMetal::VisitExpr_(const CallNode* op, std::ostream& os) {  // NOLINT
   CHECK(!op->op.as<GlobalVarNode>())
       << "CodegenMetal does not support inter-function calls, "
       << "but expression " << GetRef<Call>(op) << " calls PrimFunc " << op->op;
-  if (op->op.same_as(builtin::reinterpret())) {
+  auto f_check_simdgroup_shape = [](PrimExpr col, PrimExpr row) {
+    ICHECK(col->IsInstance<IntImmNode>() && row->IsInstance<IntImmNode>())
+        << "Only constant shape is supported for simdgroup matrix, but got " 
<< col << "x" << row;
+    int col_val = col.as<IntImmNode>()->value;
+    int row_val = row.as<IntImmNode>()->value;
+    ICHECK(col_val == 8 && row_val == 8)
+        << "Only 8x8 matrix is supported, but got " << col_val << "x" << 
row_val;
+  };
+  if (op->op.same_as(builtin::make_filled_simdgroup_matrix())) {
+    ICHECK_EQ(op->args.size(), 5);
+    Var var = runtime::Downcast<Var>(op->args[0]);
+    // Get the data type of the simdgroup matrix
+    auto it = simdgroup_dtype_.find(var.get());
+    ICHECK(it != simdgroup_dtype_.end())
+        << "Cannot find variable allocation for simdgroup: " << var;
+    const std::string& dtype_str = it->second;
+    f_check_simdgroup_shape(op->args[3], op->args[4]);
+    os << PrintExpr(var) << "[" << PrintExpr(op->args[1]) << "] = make_filled_simdgroup_matrix<"
+       << dtype_str << ", " << PrintExpr(op->args[3]) << ", " << PrintExpr(op->args[4]) << ">("
+       << PrintExpr(op->args[2]) << ")";
+  } else if (op->op.same_as(builtin::simdgroup_load())) {
+    ICHECK_EQ(op->args.size(), 7);
+    f_check_simdgroup_shape(op->args[4], op->args[5]);
+    os << "simdgroup_load(" << PrintExpr(op->args[0]) << "[" << 
PrintExpr(op->args[1]) << "], "
+       << PrintExpr(op->args[2]) << ", " << PrintExpr(op->args[3]) << ", 0, "
+       << PrintExpr(op->args[6]) << ")";
+  } else if (op->op.same_as(builtin::simdgroup_store())) {
+    ICHECK_EQ(op->args.size(), 7);
+    f_check_simdgroup_shape(op->args[4], op->args[5]);
+    os << "simdgroup_store(" << PrintExpr(op->args[0]) << "[" << 
PrintExpr(op->args[1]) << "], "
+       << PrintExpr(op->args[2]) << ", " << PrintExpr(op->args[3]) << ", 0, "
+       << PrintExpr(op->args[6]) << ")";
+  } else if (op->op.same_as(builtin::simdgroup_multiply_accumulate())) {
+    ICHECK_EQ(op->args.size(), 8);
+    os << "simdgroup_multiply_accumulate("                                  //
+       << PrintExpr(op->args[0]) << "[" << PrintExpr(op->args[1]) << "], "  //
+       << PrintExpr(op->args[2]) << "[" << PrintExpr(op->args[3]) << "], "  //
+       << PrintExpr(op->args[4]) << "[" << PrintExpr(op->args[5]) << "], "  //
+       << PrintExpr(op->args[6]) << "[" << PrintExpr(op->args[7]) << "])";
+  } else if (op->op.same_as(builtin::reinterpret())) {
     // generate as_type<TYPE>(ARG)
     os << "(as_type<";
     this->PrintType(op->dtype, os);
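
The codegen changes above emit make_filled_simdgroup_matrix / simdgroup_load / simdgroup_store / simdgroup_multiply_accumulate calls into the Metal kernel source. A sketch for inspecting that output (assumes a TVM build with the Metal runtime enabled; `scheduled_mod` is an IRModule produced by the dlight rule above):

    import tvm

    def metal_source(scheduled_mod: tvm.IRModule) -> str:
        rt_mod = tvm.build(scheduled_mod, target="metal")
        # the imported device module carries the generated Metal source
        return rt_mod.imported_modules[0].get_source()
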
diff --git a/src/target/source/codegen_metal.h b/src/target/source/codegen_metal.h
index 9cff3211ce..9bc0e15d15 100644
--- a/src/target/source/codegen_metal.h
+++ b/src/target/source/codegen_metal.h
@@ -27,6 +27,7 @@
 #include <tvm/target/codegen.h>
 
 #include <string>
+#include <unordered_map>
 
 #include "codegen_c.h"
 
@@ -50,6 +51,7 @@ class CodeGenMetal final : public CodeGenC {
   // print store of single element.
   void PrintVecElemStore(const std::string& vec, DataType t, int i, const std::string& value) final;
   // overload visitor
+  void VisitStmt_(const AllocateNode* op) final;                     // NOLINT(*)
   void VisitExpr_(const SelectNode* op, std::ostream& os) final;     // NOLINT(*)
   void VisitExpr_(const BroadcastNode* op, std::ostream& os) final;  // NOLINT(*)
   void VisitExpr_(const CallNode* op, std::ostream& os) final;       // NOLINT(*)
@@ -59,6 +61,7 @@ class CodeGenMetal final : public CodeGenC {
   using CodeGenC::PrintType;
 
  private:
+  std::unordered_map<const VarNode*, std::string> simdgroup_dtype_;
   int thread_index_bits_{32};
   int thread_work_dim_{0};
   Target target_;
diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc
index 67d01aa923..0404fd2823 100644
--- a/src/tir/op/builtin.cc
+++ b/src/tir/op/builtin.cc
@@ -328,6 +328,18 @@ TIR_DEFINE_BUILTIN_FUNC(mma_fill)
     .set_attr<TScriptDtypePrintLocation>("TScriptDtypePrintLocation",
                                          Integer(ScriptDtypePrintLocation::kFirst));
 
+TIR_DEFINE_BUILTIN_FUNC(make_filled_simdgroup_matrix)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+
+TIR_DEFINE_BUILTIN_FUNC(simdgroup_load)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+
+TIR_DEFINE_BUILTIN_FUNC(simdgroup_store)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+
+TIR_DEFINE_BUILTIN_FUNC(simdgroup_multiply_accumulate)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+
 TIR_DEFINE_BUILTIN_FUNC(vectorhigh)
     .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure))
     .set_attr<TScriptDtypePrintLocation>("TScriptDtypePrintLocation",
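
A quick sanity sketch (illustrative, not part of this commit) to confirm the new builtins are visible from Python once registered here:

    import tvm

    for name in ("tir.make_filled_simdgroup_matrix", "tir.simdgroup_load",
                 "tir.simdgroup_store", "tir.simdgroup_multiply_accumulate"):
        op = tvm.ir.Op.get(name)  # raises an error if the op is not registered
        print(op.name, op.get_attr("TCallEffectKind"))
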
diff --git a/tests/python/dlight/test_gpu_matmul_tensorize.py b/tests/python/dlight/test_gpu_matmul_tensorize.py
index 095447766e..59ccfec55c 100644
--- a/tests/python/dlight/test_gpu_matmul_tensorize.py
+++ b/tests/python/dlight/test_gpu_matmul_tensorize.py
@@ -14,12 +14,12 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=missing-docstring
+# pylint: disable=missing-docstring, unused-variable, invalid-name
+# flake8: noqa: E501
 import pytest
 
 import tvm.testing
 from tvm import dlight as dl
-from tvm.script import ir as I
 from tvm.script import tir as T
 from tvm.target import Target
 
@@ -698,5 +698,284 @@ class TestMatmulInt8Tensorize3d2dDyn(BaseBeforeAfter):
     # fmt: on
 
 
+class MetalBeforeAfter(tvm.testing.CompareBeforeAfter):
+    @pytest.fixture
+    def transform(self):
+        def transform(mod):
+            with Target("metal"):
+                return dl.ApplyDefaultSchedule(dl.gpu.Matmul())(mod)
+
+        return transform
+
+
+class TestMatmulMetal(MetalBeforeAfter):
+    # fmt: off
+    @T.prim_func(private=True)
+    def before(
+        var_A: T.handle,
+        B: T.Buffer((28672, 4096), "float16"),
+        var_C: T.handle,
+    ):
+        batch_size = T.int32()
+        A = T.match_buffer(var_A, (batch_size, 1, 4096), "float16")
+        C = T.match_buffer(var_C, (batch_size, 1, 28672), "float16")
+        for i0, i1, i2, k in T.grid(batch_size, 1, 28672, 4096):
+            with T.block("C"):
+                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
+                T.writes(C[v_i0, v_i1, v_i2])
+                with T.init():
+                    C[v_i0, v_i1, v_i2] = T.float16(0)
+                C[v_i0, v_i1, v_i2] += A[v_i0, v_i1, v_k] * B[v_i2, v_k]
+
+    @T.prim_func
+    def expected(var_A: T.handle, B: T.Buffer((28672, 4096), "float16"), var_C: T.handle):
+        T.func_attr({"tir.is_scheduled": 1})
+        batch_size = T.int32()
+        A = T.match_buffer(var_A, (batch_size, 1, 4096), "float16")
+        C = T.match_buffer(var_C, (batch_size, 1, 28672), "float16")
+        # with T.block("root"):
+        A_reindex_pad_shared = T.alloc_buffer((1, (batch_size + 15) // 16 * 16, 4096), "float16", scope="shared")
+        B_reindex_shared = T.alloc_buffer((1, 28672, 4096), "float16", scope="shared")
+        A_reindex_pad_shared_metal_simdgroup = T.alloc_buffer((1, (batch_size + 15) // 16 * 16, 4096), "float16", scope="metal.simdgroup")
+        B_reindex_shared_metal_simdgroup = T.alloc_buffer((1, 4096, 28672), "float16", scope="metal.simdgroup")
+        C_reindex_pad_metal_simdgroup = T.alloc_buffer((1, (batch_size + 15) // 16 * 16, 28672), "float16", scope="metal.simdgroup")
+        C_reindex_pad_shared = T.alloc_buffer((1, (batch_size + 15) // 16 * 16, 28672), "float16", scope="shared")
+        for ax0 in T.thread_binding(1, thread="blockIdx.z"):
+            for ax1_0 in T.thread_binding((batch_size + 15) // 16, 
thread="blockIdx.x"):
+                for ax2_0 in T.thread_binding(448, thread="blockIdx.y"):
+                    for ax1_1 in T.thread_binding(1, thread="threadIdx.y"):
+                        for ax2_1 in T.thread_binding(4, thread="threadIdx.z"):
+                            for ax1_2_init, ax2_2_init, ax1_3_init_0, 
ax2_3_init_0 in T.grid(2, 2, 1, 1):
+                                with T.block("C_init_o"):
+                                    v0_o = T.axis.spatial(1, ax0)
+                                    v1_o = T.axis.spatial(2 * ((batch_size + 
15) // 16), ax1_0 * 2 + ax1_1 * 2 + ax1_2_init + ax1_3_init_0)
+                                    v2_o = T.axis.spatial(3584, ax2_0 * 8 + 
ax2_1 * 2 + ax2_2_init + ax2_3_init_0)
+                                    T.reads()
+                                    T.writes(C_reindex_pad_metal_simdgroup[0, 
v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8])
+                                    A_1 = 
T.match_buffer(C_reindex_pad_metal_simdgroup[0, v1_o * 8:v1_o * 8 + 8, v2_o * 
8:v2_o * 8 + 8], (8, 8), "float16", strides=("A_s0", "A_s1"), 
scope="metal.simdgroup", offset_factor=1)
+                                    T.make_filled_simdgroup_matrix(A_1.data, 
A_1.elem_offset // A_1.strides[0] // 8 * (A_1.strides[0] // 8) + 
A_1.elem_offset % A_1.strides[0] // 8, T.float32(0), 8, 8)
+                            for ax3_0 in range(128):
+                                for ax0_1, ax1_ax2_fused_0 in T.grid(1, 1):
+                                    for ax1_ax2_fused_1 in T.thread_binding(4, 
thread="threadIdx.z"):
+                                        for ax1_ax2_fused_2 in 
T.thread_binding(1, thread="threadIdx.y"):
+                                            for ax1_ax2_fused_3 in 
T.thread_binding(32, thread="threadIdx.x"):
+                                                for ax1_ax2_fused_4 in 
T.vectorized(4):
+                                                    with 
T.block("A_reindex_pad_shared"):
+                                                        v0 = T.axis.spatial(1, 
ax0_1)
+                                                        v1 = 
T.axis.spatial((batch_size + 15) // 16 * 16, ax1_0 * 16 + (ax1_ax2_fused_0 * 
512 + ax1_ax2_fused_1 * 128 + ax1_ax2_fused_2 * 128 + ax1_ax2_fused_3 * 4 + 
ax1_ax2_fused_4) // 32)
+                                                        v2 = 
T.axis.spatial(4096, ax3_0 * 32 + (ax1_ax2_fused_0 * 512 + ax1_ax2_fused_1 * 
128 + ax1_ax2_fused_2 * 128 + ax1_ax2_fused_3 * 4 + ax1_ax2_fused_4) % 32)
+                                                        T.reads(A[v1, 0, v2])
+                                                        
T.writes(A_reindex_pad_shared[v0, v1, v2])
+                                                        
A_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < batch_size, A[v1, 0, 
v2], T.float16(0))
+                                for ax0_1, ax1_ax2_fused_0 in T.grid(1, 4):
+                                    for ax1_ax2_fused_1 in T.thread_binding(4, 
thread="threadIdx.z"):
+                                        for ax1_ax2_fused_2 in 
T.thread_binding(1, thread="threadIdx.y"):
+                                            for ax1_ax2_fused_3 in 
T.thread_binding(32, thread="threadIdx.x"):
+                                                for ax1_ax2_fused_4 in 
T.vectorized(4):
+                                                    with 
T.block("B_reindex_shared"):
+                                                        v0 = T.axis.spatial(1, 
ax0_1)
+                                                        v1 = 
T.axis.spatial(28672, ax2_0 * 64 + (ax1_ax2_fused_0 * 512 + ax1_ax2_fused_1 * 
128 + ax1_ax2_fused_2 * 128 + ax1_ax2_fused_3 * 4 + ax1_ax2_fused_4) // 32)
+                                                        v2 = 
T.axis.spatial(4096, ax3_0 * 32 + (ax1_ax2_fused_0 * 512 + ax1_ax2_fused_1 * 
128 + ax1_ax2_fused_2 * 128 + ax1_ax2_fused_3 * 4 + ax1_ax2_fused_4) % 32)
+                                                        T.reads(B[v1, v2])
+                                                        
T.writes(B_reindex_shared[v0, v1, v2])
+                                                        B_reindex_shared[v0, 
v1, v2] = B[v1, v2]
+                                for ax3_1 in range(4):
+                                    for ax0_0, ax1_0_1 in T.grid(2, 1):
+                                        with 
T.block("A_reindex_pad_shared_metal.simdgroup_o"):
+                                            v0_o = T.axis.spatial(1, 0)
+                                            v1_o = T.axis.spatial(2 * 
((batch_size + 15) // 16), ax1_0 * 2 + ax0_0)
+                                            v2_o = T.axis.spatial(512, ax3_0 * 
4 + ax3_1 + ax1_0_1)
+                                            T.reads(A_reindex_pad_shared[v0_o, 
v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8])
+                                            
T.writes(A_reindex_pad_shared_metal_simdgroup[v0_o, v1_o * 8:v1_o * 8 + 8, v2_o 
* 8:v2_o * 8 + 8])
+                                            A_1 = 
T.match_buffer(A_reindex_pad_shared[v0_o, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o 
* 8 + 8], (8, 8), "float16", strides=("A_s0", "A_s1"), scope="shared", 
offset_factor=1)
+                                            C_1 = 
T.match_buffer(A_reindex_pad_shared_metal_simdgroup[v0_o, v1_o * 8:v1_o * 8 + 
8, v2_o * 8:v2_o * 8 + 8], (8, 8), "float16", strides=("C_s0", "C_s1"), 
scope="metal.simdgroup", offset_factor=1)
+                                            T.simdgroup_load(C_1.data, 
C_1.elem_offset // C_1.strides[0] // 8 * (C_1.strides[0] // 8) + 
C_1.elem_offset % C_1.strides[0] // 8, 
T.tvm_access_ptr(T.type_annotation("float16"), A_1.data, A_1.elem_offset, 
A_1.strides[0] * 8, 1), A_1.strides[0], 8, 8, T.bool(False))
+                                    for ax0_0, ax1_0_1 in T.grid(2, 1):
+                                        with 
T.block("B_reindex_shared_metal.simdgroup_o"):
+                                            v0_o = T.axis.spatial(1, 0)
+                                            v1_o = T.axis.spatial(3584, ax2_0 
* 8 + ax2_1 * 2 + ax0_0)
+                                            v2_o = T.axis.spatial(512, ax3_0 * 
4 + ax3_1 + ax1_0_1)
+                                            T.reads(B_reindex_shared[v0_o, 
v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8])
+                                            
T.writes(B_reindex_shared_metal_simdgroup[v0_o, v2_o * 8:v2_o * 8 + 8, v1_o * 
8:v1_o * 8 + 8])
+                                            A_1 = 
T.match_buffer(B_reindex_shared[v0_o, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 
+ 8], (8, 8), "float16", strides=("A_s0", "A_s1"), scope="shared", 
offset_factor=1)
+                                            C_1 = 
T.match_buffer(B_reindex_shared_metal_simdgroup[v0_o, v2_o * 8:v2_o * 8 + 8, 
v1_o * 8:v1_o * 8 + 8], (8, 8), "float16", strides=("C_s0", "C_s1"), 
scope="metal.simdgroup", offset_factor=1)
+                                            T.simdgroup_load(C_1.data, 
C_1.elem_offset // C_1.strides[0] // 8 * (C_1.strides[0] // 8) + 
C_1.elem_offset % C_1.strides[0] // 8, 
T.tvm_access_ptr(T.type_annotation("float16"), A_1.data, A_1.elem_offset, 
A_1.strides[0] * 8, 1), A_1.strides[0], 8, 8, T.bool(True))
+                                    for ax1_2, ax2_2 in T.grid(2, 2):
+                                        with T.block("C_update_o"):
+                                            v0_o = T.axis.spatial(1, ax0)
+                                            v1_o = T.axis.spatial(2 * 
((batch_size + 15) // 16), ax1_0 * 2 + ax1_1 * 2 + ax1_2)
+                                            v2_o = T.axis.spatial(3584, ax2_0 
* 8 + ax2_1 * 2 + ax2_2)
+                                            v3_o = T.axis.reduce(512, ax3_0 * 
4 + ax3_1)
+                                            
T.reads(C_reindex_pad_metal_simdgroup[0, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 
8 + 8], A_reindex_pad_shared_metal_simdgroup[0, v1_o * 8:v1_o * 8 + 8, v3_o * 
8:v3_o * 8 + 8], B_reindex_shared_metal_simdgroup[0, v3_o * 8:v3_o * 8 + 8, 
v2_o * 8:v2_o * 8 + 8])
+                                            
T.writes(C_reindex_pad_metal_simdgroup[0, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o 
* 8 + 8])
+                                            A_1 = 
T.match_buffer(A_reindex_pad_shared_metal_simdgroup[0, v1_o * 8:v1_o * 8 + 8, 
v3_o * 8:v3_o * 8 + 8], (8, 8), "float16", strides=("A_s0", "A_s1"), 
scope="metal.simdgroup", offset_factor=1)
+                                            B_1 = 
T.match_buffer(B_reindex_shared_metal_simdgroup[0, v3_o * 8:v3_o * 8 + 8, v2_o 
* 8:v2_o * 8 + 8], (8, 8), "float16", strides=("B_s0", "B_s1"), 
scope="metal.simdgroup", offset_factor=1)
+                                            C_1 = 
T.match_buffer(C_reindex_pad_metal_simdgroup[0, v1_o * 8:v1_o * 8 + 8, v2_o * 
8:v2_o * 8 + 8], (8, 8), "float16", strides=("C_s0", "C_s1"), 
scope="metal.simdgroup", offset_factor=1)
+                                            
T.simdgroup_multiply_accumulate(C_1.data, C_1.elem_offset // C_1.strides[0] // 
8 * (C_1.strides[0] // 8) + C_1.elem_offset % C_1.strides[0] // 8, A_1.data, 
A_1.elem_offset // A_1.strides[0] // 8 * (A_1.strides[0] // 8) + 
A_1.elem_offset % A_1.strides[0] // 8, B_1.data, B_1.elem_offset // 
B_1.strides[0] // 8 * (B_1.strides[0] // 8) + B_1.elem_offset % B_1.strides[0] 
// 8, C_1.data, C_1.elem_offset // C_1.strides[0] // 8 * (C_1.strides[0] // 8) 
+ [...]
+                            for ax0_1, ax1_0_1, ax2_0_1 in T.grid(1, 2, 2):
+                                with 
T.block("C_reindex_pad_metal.simdgroup_o"):
+                                    v0_o = T.axis.spatial(1, ax0_1)
+                                    v1_o = T.axis.spatial(2 * ((batch_size + 
15) // 16), ax1_0 * 2 + ax1_0_1)
+                                    v2_o = T.axis.spatial(3584, ax2_0 * 8 + 
ax2_1 * 2 + ax2_0_1)
+                                    
T.reads(C_reindex_pad_metal_simdgroup[v0_o, v1_o * 8:v1_o * 8 + 8, v2_o * 
8:v2_o * 8 + 8])
+                                    T.writes(C_reindex_pad_shared[v0_o, v1_o * 
8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8])
+                                    A_1 = 
T.match_buffer(C_reindex_pad_metal_simdgroup[v0_o, v1_o * 8:v1_o * 8 + 8, v2_o 
* 8:v2_o * 8 + 8], (8, 8), "float16", strides=("A_s0", "A_s1"), 
scope="metal.simdgroup", offset_factor=1)
+                                    C_1 = 
T.match_buffer(C_reindex_pad_shared[v0_o, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o 
* 8 + 8], (8, 8), "float16", strides=("C_s0", "C_s1"), scope="shared", 
offset_factor=1)
+                                    T.simdgroup_store(A_1.data, 
A_1.elem_offset // A_1.strides[0] // 8 * (A_1.strides[0] // 8) + 
A_1.elem_offset % A_1.strides[0] // 8, 
T.tvm_access_ptr(T.type_annotation("float16"), C_1.data, C_1.elem_offset, 
C_1.strides[0] * 8, 2), C_1.strides[0], 8, 8, T.bool(False))
+                    for ax0_1, ax1_ax2_fused_0 in T.grid(1, 2):
+                        for ax1_ax2_fused_1 in T.thread_binding(4, 
thread="threadIdx.z"):
+                            for ax1_ax2_fused_2 in T.thread_binding(1, 
thread="threadIdx.y"):
+                                for ax1_ax2_fused_3 in T.thread_binding(32, 
thread="threadIdx.x"):
+                                    for ax1_ax2_fused_4 in T.vectorized(4):
+                                        with T.block("C_reindex_pad_shared"):
+                                            v0 = T.axis.spatial(1, ax0_1)
+                                            v1 = T.axis.spatial((batch_size + 
15) // 16 * 16, ax1_0 * 16 + (ax1_ax2_fused_0 * 512 + ax1_ax2_fused_1 * 128 + 
ax1_ax2_fused_2 * 128 + ax1_ax2_fused_3 * 4 + ax1_ax2_fused_4) // 64)
+                                            v2 = T.axis.spatial(28672, ax2_0 * 
64 + (ax1_ax2_fused_0 * 512 + ax1_ax2_fused_1 * 128 + ax1_ax2_fused_2 * 128 + 
ax1_ax2_fused_3 * 4 + ax1_ax2_fused_4) % 64)
+                                            T.reads(C_reindex_pad_shared[v0, 
v1, v2])
+                                            T.writes(C[v1, 0, v2])
+                                            if v1 < batch_size:
+                                                C[v1, 0, v2] = 
C_reindex_pad_shared[v0, v1, v2]
+    # fmt: on
+
+
+class TestMatmulMetalInt4Quant(MetalBeforeAfter):
+    # fmt: off
+    @T.prim_func(private=True)
+    def before(
+        B0: T.Buffer((28672, 512), "uint32"),
+        B1: T.Buffer((28672, 128), "float16"),
+        var_A: T.handle,
+        var_C: T.handle
+    ):
+        batch_size = T.int32()
+        A = T.match_buffer(var_A, (batch_size, 1, 4096), "float16")
+        C = T.match_buffer(var_C, (batch_size, 1, 28672), "float16")
+        compute = T.alloc_buffer((28672, 4096), "float16")
+        B = T.alloc_buffer((28672, 4096), "float16")
+        for i0, i1 in T.grid(28672, 4096):
+            with T.block("compute"):
+                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
+                compute[v_i0, v_i1] = T.Cast("float16", T.bitwise_and(T.shift_right(B0[v_i0, v_i1 // 8], T.Cast("uint32", v_i1 % 8 * 4)), T.uint32(15)))
+        for i0, i1 in T.grid(28672, 4096):
+            with T.block("dequantize"):
+                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
+                B[v_i0, v_i1] = (compute[v_i0, v_i1] - T.float16(7)) * B1[v_i0, v_i1 // 32]
+        for i0, i1, i2, k in T.grid(batch_size, 1, 28672, 4096):
+            with T.block("NT_matmul"):
+                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
+                with T.init():
+                    C[v_i0, v_i1, v_i2] = T.float16(0)
+                C[v_i0, v_i1, v_i2] = C[v_i0, v_i1, v_i2] + A[v_i0, v_i1, v_k] * B[v_i2, v_k]
+
+    @T.prim_func(private=True)
+    def expected(B0: T.Buffer((28672, 512), "uint32"), B1: T.Buffer((28672, 128), "float16"), var_A: T.handle, var_C: T.handle):
+        T.func_attr({"tir.is_scheduled": 1})
+        batch_size = T.int32()
+        A = T.match_buffer(var_A, (batch_size, 1, 4096), "float16")
+        C = T.match_buffer(var_C, (batch_size, 1, 28672), "float16")
+        # with T.block("root"):
+        A_reindex_pad_shared = T.alloc_buffer((1, (batch_size + 15) // 16 * 16, 4096), "float16", scope="shared")
+        B_reindex_shared = T.alloc_buffer((1, 28672, 4096), "float16", scope="shared")
+        A_reindex_pad_shared_metal_simdgroup = T.alloc_buffer((1, (batch_size + 15) // 16 * 16, 4096), "float16", scope="metal.simdgroup")
+        B_reindex_shared_metal_simdgroup = T.alloc_buffer((1, 4096, 28672), "float16", scope="metal.simdgroup")
+        C_reindex_pad_metal_simdgroup = T.alloc_buffer((1, (batch_size + 15) // 16 * 16, 28672), "float16", scope="metal.simdgroup")
+        C_reindex_pad_shared = T.alloc_buffer((1, (batch_size + 15) // 16 * 16, 28672), "float16", scope="shared")
+        for ax0 in T.thread_binding(1, thread="blockIdx.z"):
+            for ax1_0 in T.thread_binding((batch_size + 15) // 16, 
thread="blockIdx.x"):
+                for ax2_0 in T.thread_binding(448, thread="blockIdx.y"):
+                    for ax1_1 in T.thread_binding(1, thread="threadIdx.y"):
+                        for ax2_1 in T.thread_binding(4, thread="threadIdx.z"):
+                            for ax1_2_init, ax2_2_init, ax1_3_init_0, ax2_3_init_0 in T.grid(2, 2, 1, 1):
+                                with T.block("NT_matmul_init_o"):
+                                    v0_o = T.axis.spatial(1, ax0)
+                                    v1_o = T.axis.spatial(2 * ((batch_size + 15) // 16), ax1_0 * 2 + ax1_1 * 2 + ax1_2_init + ax1_3_init_0)
+                                    v2_o = T.axis.spatial(3584, ax2_0 * 8 + ax2_1 * 2 + ax2_2_init + ax2_3_init_0)
+                                    T.reads()
+                                    T.writes(C_reindex_pad_metal_simdgroup[0, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8])
+                                    A_1 = T.match_buffer(C_reindex_pad_metal_simdgroup[0, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8], (8, 8), "float16", strides=("A_s0", "A_s1"), scope="metal.simdgroup", offset_factor=1)
+                                    T.make_filled_simdgroup_matrix(A_1.data, A_1.elem_offset // A_1.strides[0] // 8 * (A_1.strides[0] // 8) + A_1.elem_offset % A_1.strides[0] // 8, T.float32(0), 8, 8)
+                            for ax3_0 in range(128):
+                                for ax0_1, ax1_ax2_fused_0 in T.grid(1, 1):
+                                    for ax1_ax2_fused_1 in T.thread_binding(4, thread="threadIdx.z"):
+                                        for ax1_ax2_fused_2 in T.thread_binding(1, thread="threadIdx.y"):
+                                            for ax1_ax2_fused_3 in T.thread_binding(32, thread="threadIdx.x"):
+                                                for ax1_ax2_fused_4 in T.vectorized(4):
+                                                    with T.block("A_reindex_pad_shared"):
+                                                        v0 = T.axis.spatial(1, ax0_1)
+                                                        v1 = T.axis.spatial((batch_size + 15) // 16 * 16, ax1_0 * 16 + (ax1_ax2_fused_0 * 512 + ax1_ax2_fused_1 * 128 + ax1_ax2_fused_2 * 128 + ax1_ax2_fused_3 * 4 + ax1_ax2_fused_4) // 32)
+                                                        v2 = T.axis.spatial(4096, ax3_0 * 32 + (ax1_ax2_fused_0 * 512 + ax1_ax2_fused_1 * 128 + ax1_ax2_fused_2 * 128 + ax1_ax2_fused_3 * 4 + ax1_ax2_fused_4) % 32)
+                                                        T.reads(A[v1, 0, v2])
+                                                        T.writes(A_reindex_pad_shared[v0, v1, v2])
+                                                        A_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < batch_size, A[v1, 0, v2], T.float16(0))
+                                for ax0_1, ax1_ax2_fused_0 in T.grid(1, 4):
+                                    for ax1_ax2_fused_1 in T.thread_binding(4, thread="threadIdx.z"):
+                                        for ax1_ax2_fused_2 in T.thread_binding(1, thread="threadIdx.y"):
+                                            for ax1_ax2_fused_3 in T.thread_binding(32, thread="threadIdx.x"):
+                                                for ax1_ax2_fused_4 in T.vectorized(4):
+                                                    with T.block("B_reindex_shared"):
+                                                        v0 = T.axis.spatial(1, ax0_1)
+                                                        v1 = T.axis.spatial(28672, ax2_0 * 64 + (ax1_ax2_fused_0 * 512 + ax1_ax2_fused_1 * 128 + ax1_ax2_fused_2 * 128 + ax1_ax2_fused_3 * 4 + ax1_ax2_fused_4) // 32)
+                                                        v2 = T.axis.spatial(4096, ax3_0 * 32 + (ax1_ax2_fused_0 * 512 + ax1_ax2_fused_1 * 128 + ax1_ax2_fused_2 * 128 + ax1_ax2_fused_3 * 4 + ax1_ax2_fused_4) % 32)
+                                                        T.reads(B0[v1, v2 // 8], B1[v1, v2 // 32])
+                                                        T.writes(B_reindex_shared[v0, v1, v2])
+                                                        B_reindex_shared[v0, v1, v2] = (T.Cast("float16", T.bitwise_and(T.shift_right(B0[v1, v2 // 8], T.Cast("uint32", v2 % 8 * 4)), T.uint32(15))) - T.float16(7)) * B1[v1, v2 // 32]
+                                for ax3_1 in range(4):
+                                    for ax0_0, ax1_0_1 in T.grid(2, 1):
+                                        with T.block("A_reindex_pad_shared_metal.simdgroup_o"):
+                                            v0_o = T.axis.spatial(1, 0)
+                                            v1_o = T.axis.spatial(2 * ((batch_size + 15) // 16), ax1_0 * 2 + ax0_0)
+                                            v2_o = T.axis.spatial(512, ax3_0 * 4 + ax3_1 + ax1_0_1)
+                                            T.reads(A_reindex_pad_shared[v0_o, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8])
+                                            T.writes(A_reindex_pad_shared_metal_simdgroup[v0_o, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8])
+                                            A_1 = T.match_buffer(A_reindex_pad_shared[v0_o, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8], (8, 8), "float16", strides=("A_s0", "A_s1"), scope="shared", offset_factor=1)
+                                            C_1 = T.match_buffer(A_reindex_pad_shared_metal_simdgroup[v0_o, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8], (8, 8), "float16", strides=("C_s0", "C_s1"), scope="metal.simdgroup", offset_factor=1)
+                                            T.simdgroup_load(C_1.data, C_1.elem_offset // C_1.strides[0] // 8 * (C_1.strides[0] // 8) + C_1.elem_offset % C_1.strides[0] // 8, T.tvm_access_ptr(T.type_annotation("float16"), A_1.data, A_1.elem_offset, A_1.strides[0] * 8, 1), A_1.strides[0], 8, 8, T.bool(False))
+                                    for ax0_0, ax1_0_1 in T.grid(2, 1):
+                                        with T.block("B_reindex_shared_metal.simdgroup_o"):
+                                            v0_o = T.axis.spatial(1, 0)
+                                            v1_o = T.axis.spatial(3584, ax2_0 * 8 + ax2_1 * 2 + ax0_0)
+                                            v2_o = T.axis.spatial(512, ax3_0 * 4 + ax3_1 + ax1_0_1)
+                                            T.reads(B_reindex_shared[v0_o, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8])
+                                            T.writes(B_reindex_shared_metal_simdgroup[v0_o, v2_o * 8:v2_o * 8 + 8, v1_o * 8:v1_o * 8 + 8])
+                                            A_1 = T.match_buffer(B_reindex_shared[v0_o, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8], (8, 8), "float16", strides=("A_s0", "A_s1"), scope="shared", offset_factor=1)
+                                            C_1 = T.match_buffer(B_reindex_shared_metal_simdgroup[v0_o, v2_o * 8:v2_o * 8 + 8, v1_o * 8:v1_o * 8 + 8], (8, 8), "float16", strides=("C_s0", "C_s1"), scope="metal.simdgroup", offset_factor=1)
+                                            T.simdgroup_load(C_1.data, C_1.elem_offset // C_1.strides[0] // 8 * (C_1.strides[0] // 8) + C_1.elem_offset % C_1.strides[0] // 8, T.tvm_access_ptr(T.type_annotation("float16"), A_1.data, A_1.elem_offset, A_1.strides[0] * 8, 1), A_1.strides[0], 8, 8, T.bool(True))
+                                    for ax1_2, ax2_2 in T.grid(2, 2):
+                                        with T.block("NT_matmul_update_o"):
+                                            v0_o = T.axis.spatial(1, ax0)
+                                            v1_o = T.axis.spatial(2 * ((batch_size + 15) // 16), ax1_0 * 2 + ax1_1 * 2 + ax1_2)
+                                            v2_o = T.axis.spatial(3584, ax2_0 * 8 + ax2_1 * 2 + ax2_2)
+                                            v3_o = T.axis.reduce(512, ax3_0 * 4 + ax3_1)
+                                            T.reads(C_reindex_pad_metal_simdgroup[0, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8], A_reindex_pad_shared_metal_simdgroup[0, v1_o * 8:v1_o * 8 + 8, v3_o * 8:v3_o * 8 + 8], B_reindex_shared_metal_simdgroup[0, v3_o * 8:v3_o * 8 + 8, v2_o * 8:v2_o * 8 + 8])
+                                            T.writes(C_reindex_pad_metal_simdgroup[0, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8])
+                                            A_1 = T.match_buffer(A_reindex_pad_shared_metal_simdgroup[0, v1_o * 8:v1_o * 8 + 8, v3_o * 8:v3_o * 8 + 8], (8, 8), "float16", strides=("A_s0", "A_s1"), scope="metal.simdgroup", offset_factor=1)
+                                            B = T.match_buffer(B_reindex_shared_metal_simdgroup[0, v3_o * 8:v3_o * 8 + 8, v2_o * 8:v2_o * 8 + 8], (8, 8), "float16", strides=("B_s0", "B_s1"), scope="metal.simdgroup", offset_factor=1)
+                                            C_1 = T.match_buffer(C_reindex_pad_metal_simdgroup[0, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8], (8, 8), "float16", strides=("C_s0", "C_s1"), scope="metal.simdgroup", offset_factor=1)
+                                            T.simdgroup_multiply_accumulate(C_1.data, C_1.elem_offset // C_1.strides[0] // 8 * (C_1.strides[0] // 8) + C_1.elem_offset % C_1.strides[0] // 8, A_1.data, A_1.elem_offset // A_1.strides[0] // 8 * (A_1.strides[0] // 8) + A_1.elem_offset % A_1.strides[0] // 8, B.data, B.elem_offset // B.strides[0] // 8 * (B.strides[0] // 8) + B.elem_offset % B.strides[0] // 8, C_1.data, C_1.elem_offset // C_1.strides[0] // 8 * (C_1.strides[0] // 8) + C_1.elem_of [...]
+                            for ax0_1, ax1_0_1, ax2_0_1 in T.grid(1, 2, 2):
+                                with T.block("C_reindex_pad_metal.simdgroup_o"):
+                                    v0_o = T.axis.spatial(1, ax0_1)
+                                    v1_o = T.axis.spatial(2 * ((batch_size + 15) // 16), ax1_0 * 2 + ax1_0_1)
+                                    v2_o = T.axis.spatial(3584, ax2_0 * 8 + ax2_1 * 2 + ax2_0_1)
+                                    T.reads(C_reindex_pad_metal_simdgroup[v0_o, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8])
+                                    T.writes(C_reindex_pad_shared[v0_o, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8])
+                                    A_1 = T.match_buffer(C_reindex_pad_metal_simdgroup[v0_o, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8], (8, 8), "float16", strides=("A_s0", "A_s1"), scope="metal.simdgroup", offset_factor=1)
+                                    C_1 = T.match_buffer(C_reindex_pad_shared[v0_o, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8], (8, 8), "float16", strides=("C_s0", "C_s1"), scope="shared", offset_factor=1)
+                                    T.simdgroup_store(A_1.data, A_1.elem_offset // A_1.strides[0] // 8 * (A_1.strides[0] // 8) + A_1.elem_offset % A_1.strides[0] // 8, T.tvm_access_ptr(T.type_annotation("float16"), C_1.data, C_1.elem_offset, C_1.strides[0] * 8, 2), C_1.strides[0], 8, 8, T.bool(False))
+                    for ax0_1, ax1_ax2_fused_0 in T.grid(1, 2):
+                        for ax1_ax2_fused_1 in T.thread_binding(4, thread="threadIdx.z"):
+                            for ax1_ax2_fused_2 in T.thread_binding(1, thread="threadIdx.y"):
+                                for ax1_ax2_fused_3 in T.thread_binding(32, thread="threadIdx.x"):
+                                    for ax1_ax2_fused_4 in T.vectorized(4):
+                                        with T.block("C_reindex_pad_shared"):
+                                            v0 = T.axis.spatial(1, ax0_1)
+                                            v1 = T.axis.spatial((batch_size + 15) // 16 * 16, ax1_0 * 16 + (ax1_ax2_fused_0 * 512 + ax1_ax2_fused_1 * 128 + ax1_ax2_fused_2 * 128 + ax1_ax2_fused_3 * 4 + ax1_ax2_fused_4) // 64)
+                                            v2 = T.axis.spatial(28672, ax2_0 * 64 + (ax1_ax2_fused_0 * 512 + ax1_ax2_fused_1 * 128 + ax1_ax2_fused_2 * 128 + ax1_ax2_fused_3 * 4 + ax1_ax2_fused_4) % 64)
+                                            T.reads(C_reindex_pad_shared[v0, v1, v2])
+                                            T.writes(C[v1, 0, v2])
+                                            if v1 < batch_size:
+                                                C[v1, 0, v2] = C_reindex_pad_shared[v0, v1, v2]
+
+
 if __name__ == "__main__":
     tvm.testing.main()
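
For anyone who wants to exercise the new SimdGroup path outside of the test
suite, the following is a minimal sketch, assuming an IRModule `mod` whose
float16 matmul PrimFunc looks like the `before` function above:

    import tvm
    from tvm import dlight as dl

    # Assumed: `mod` holds a float16 matmul PrimFunc such as `before`.
    # Under a Metal target, the dlight GPU Matmul rule tensorizes with the
    # metal.simdgroup intrinsics registered in python/tvm/tir/tensor_intrin/metal.py.
    with tvm.target.Target("metal"):
        mod = dl.ApplyDefaultSchedule(dl.gpu.Matmul())(mod)

    # The scheduled PrimFunc should now carry make_filled_simdgroup_matrix,
    # simdgroup_load, simdgroup_multiply_accumulate and simdgroup_store calls,
    # which codegen_metal.cc lowers to Metal's 8x8 simdgroup_matrix operations.

The fragment index passed to these intrinsics in the expected functions above,
elem_offset // strides[0] // 8 * (strides[0] // 8) + elem_offset % strides[0] // 8,
works out to the linear index of the 8x8 tile inside the simdgroup buffer:
tile row times tiles per row, plus tile column.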
