This is an automated email from the ASF dual-hosted git repository.

lukhut pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new b49468ddf1 [SME] Introduce scalable fp32 dense schedule (#16921)
b49468ddf1 is described below

commit b49468ddf11a1103d82f11009a0b3253a49705aa
Author: Luke Hutton <luke.hut...@arm.com>
AuthorDate: Wed May 15 11:28:16 2024 +0100

    [SME] Introduce scalable fp32 dense schedule (#16921)
    
    This commit adds a new scalable fp32 dense schedule that calls SME
    intrinsics according to the SME RFC:
    https://github.com/apache/tvm-rfcs/pull/107.
    
    Currently the schedule does not make use of predication, meaning the
    output from the matmul compute must be copied in a subsequent compute
    stage. This will be removed once support for predication is added.
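    
    For reference, a minimal sketch of how the new schedule can be exercised
    directly, mirroring the codegen test added in this commit (the 32x32 fp32
    shapes are illustrative only, and an SME-capable LLVM is assumed):
    
        import tvm
        from tvm import te
        
        target = "llvm -mtriple=aarch64-linux-gnu -mattr=+v9a,+sme"
        A = te.placeholder((32, 32), dtype="float32", name="A")
        B = te.placeholder((32, 32), dtype="float32", name="B")
        
        with tvm.target.Target(target):
            C = tvm.topi.arm_cpu.matmul.compute_matmul_sme(A, B, None, "float32", False, False)
            sch = tvm.tir.Schedule(te.create_prim_func([A, B, C]))
            tvm.topi.arm_cpu.matmul.tir_schedule_matmul_sme(sch)
            # The resulting assembly contains SME instructions such as "fmopa"
            f = tvm.build(sch.mod, target=target)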
---
 python/tvm/micro/testing/aot_test_utils.py         |  10 +
 python/tvm/relay/op/strategy/arm_cpu.py            |  69 +++-
 python/tvm/testing/utils.py                        |  17 +
 python/tvm/tir/tensor_intrin/__init__.py           |   1 -
 python/tvm/tir/tensor_intrin/arm_cpu.py            | 362 ++++++++++++++++++++-
 python/tvm/topi/arm_cpu/__init__.py                |   5 +-
 python/tvm/topi/arm_cpu/arm_utils.py               |  26 ++
 python/tvm/topi/arm_cpu/dense.py                   |  10 +-
 python/tvm/topi/arm_cpu/dense_alter_op.py          |  75 +++++
 python/tvm/topi/arm_cpu/matmul.py                  | 124 +++++++
 python/tvm/topi/x86/dense_alter_op.py              |   2 +-
 src/arith/const_int_bound.cc                       |   2 +-
 src/relay/backend/te_compiler_cache.cc             |   4 +-
 src/relay/op/nn/nn.cc                              |   1 +
 src/tir/schedule/ir_comparator.cc                  |   6 +-
 .../python/codegen/test_target_codegen_aarch64.py  |  46 ++-
 tests/python/integration/test_arm_aprofile.py      |  94 ------
 ...est_meta_schedule_postproc_rewrite_tensorize.py |   2 +-
 .../relay/strategy/arm_cpu/scalable_utils.py       |  53 +++
 .../arm_cpu/{test_dense_dsp.py => test_dense.py}   |  91 +++++-
 tests/python/relay/strategy/arm_cpu/test_matmul.py | 118 +++++++
 .../relay/strategy/test_select_implementation.py   |  55 +++-
 tests/python/relay/test_pass_alter_op_layout.py    |  56 ++++
 tests/python/topi/test_topi_matmul.py              |  20 +-
 24 files changed, 1127 insertions(+), 122 deletions(-)

diff --git a/python/tvm/micro/testing/aot_test_utils.py b/python/tvm/micro/testing/aot_test_utils.py
index 06cd0f1c9e..991a3f0ddb 100644
--- a/python/tvm/micro/testing/aot_test_utils.py
+++ b/python/tvm/micro/testing/aot_test_utils.py
@@ -65,6 +65,16 @@ AOT_USMP_CORSTONE300_RUNNER = AOTTestRunner(
     },
 )
 
+AOT_APROFILE_AEM_RUNNER = AOTTestRunner(
+    makefile="aprofile_aem",
+    includes=[],
+    pass_config={
+        "tir.usmp.enable": False,
+        # AOT test infra generates 'fake' tensor inputs which fail asserts
+        "tir.disable_assert": True,
+    },
+)
+
 
 def parametrize_aot_options(test):
     """Parametrize over valid option combinations"""
diff --git a/python/tvm/relay/op/strategy/arm_cpu.py b/python/tvm/relay/op/strategy/arm_cpu.py
index 2fc148c3ef..9974d2691d 100644
--- a/python/tvm/relay/op/strategy/arm_cpu.py
+++ b/python/tvm/relay/op/strategy/arm_cpu.py
@@ -21,7 +21,9 @@ import logging
 # pylint: disable=invalid-name,unused-argument,wildcard-import,unused-wildcard-import
 import re
 
+import tvm
 from tvm import relay, topi, tir
+from tvm.tir.schedule.analysis import has_block
 
 from ....auto_scheduler import is_auto_scheduler_enabled
 from ....meta_schedule import is_meta_schedule_enabled
@@ -639,7 +641,7 @@ def schedule_bitserial_dense_arm_cpu(attrs, inputs, out_type, target):
 def schedule_dense_arm_cpu(attrs, inputs, out_type, target):
     """dense arm cpu strategy"""
     strategy = _op.OpStrategy()
-    data, _ = inputs
+    data, weight = inputs
 
     if target.features.has_dsp and data.dtype in ["int8", "int16"]:
         strategy.add_implementation(
@@ -680,6 +682,23 @@ def schedule_dense_arm_cpu(attrs, inputs, out_type, target):
             plevel=11,
         )
 
+    if (
+        target.features.has_sme
+        and data.dtype in ["float32"]
+        and weight.dtype in ["float32"]
+        and out_type.dtype in ["float32"]
+        # The schedule uses tensorization which does not work when the
+        # reduction axis has unit iters. See
+        # https://github.com/apache/tvm/issues/16566
+        and data.shape[1] > 1
+    ):
+        strategy.add_implementation(
+            wrap_compute_dense(topi.arm_cpu.compute_matmul_sme),
+            lambda: None,
+            name="matmul.arm_cpu.sme",
+            plevel=12,
+        )
+
     # Fallback to x86 schedules as there is currently no arm_cpu schedule for dense
     strategy.add_implementation(
         wrap_compute_dense(topi.x86.dense_nopack),
@@ -697,6 +716,40 @@ def schedule_dense_arm_cpu(attrs, inputs, out_type, target):
     return strategy
 
 
+@matmul_strategy.register("arm_cpu")
+def matmul_strategy_arm_cpu(attrs, inputs, out_type, target):
+    """matmul arm cpu strategy"""
+    strategy = _op.OpStrategy()
+    data, weight = inputs
+
+    if (
+        target.features.has_sme
+        and data.dtype in ["float32"]
+        and weight.dtype in ["float32"]
+        and out_type.dtype in ["float32"]
+        and not (attrs.transpose_a or attrs.transpose_b)
+        and len(data.shape) == 2
+        # The schedule uses tensorization which does not work when the
+        # reduction axis has unit iters. See
+        # https://github.com/apache/tvm/issues/16566
+        and data.shape[1] > 1
+    ):
+        # Ideally we should check that weight is a Relay constant, but strategy functions
+        # don't have access to the data needed to check this.
+        strategy.add_implementation(
+            wrap_compute_matmul(topi.arm_cpu.compute_matmul_sme),
+            lambda: None,
+            name="matmul.arm_cpu.sme",
+        )
+        return strategy
+
+    logger.warning("matmul is not optimized for arm cpu.")
+    strategy.add_implementation(
+        wrap_compute_matmul(topi.nn.matmul), naive_schedule, name="matmul.generic"
+    )
+    return strategy
+
+
 @conv1d_strategy.register("arm_cpu")
 def conv1d_strategy_arm_cpu(attrs, inputs, out_type, target):
     """conv1d strategy"""
@@ -737,3 +790,17 @@ def conv1d_strategy_arm_cpu(attrs, inputs, out_type, target):
             f"Unsupported kernel layout {kernel_layout} for conv1d {layout} for arm cpu."
         )
     return strategy
+
+
+def arm_cpu_tir_strategy(sch: tir.Schedule) -> bool:
+    """
+    Strategy for arm_cpu STIR schedules.
+    """
+    current_target = tvm.target.Target.current()
+
+    if current_target.features.has_sme and has_block(sch, "matmul_sme_gemm"):
+        topi.arm_cpu.matmul.tir_schedule_matmul_sme(sch)
+        return True
+
+    # Fallback to TE schedule for operators we have not written a special TIR schedule for
+    return False
diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py
index ac22af2823..38b39b5fc2 100644
--- a/python/tvm/testing/utils.py
+++ b/python/tvm/testing/utils.py
@@ -1023,6 +1023,19 @@ requires_corstone300 = Feature(
     parent_features="cmsisnn",
 )
 
+
+def _aprofile_aem_fvp_compile_time_check():
+    if shutil.which("FVP_Base_RevC-2xAEMvA") is None:
+        return "AProfile AEM is not available"
+    return True
+
+
+requires_aprofile_aem_fvp = Feature(
+    "aprofile-aem-fvp",
+    "AProfile AEM FVP",
+    compile_time_check=_aprofile_aem_fvp_compile_time_check,
+)
+
 # Mark a test as requiring Vitis AI to run
 requires_vitis_ai = Feature("vitis_ai", "Vitis AI", cmake_flag="USE_VITIS_AI")
 
@@ -1205,6 +1218,10 @@ def skip_if_32bit(reason):
     return decorator
 
 
+def skip_if_no_reference_system(func):
+    return skip_if_32bit(reason="Reference system unavailable in i386 container")(func)
+
+
 def requires_package(*packages):
     """Mark a test as requiring python packages to run.
 
diff --git a/python/tvm/tir/tensor_intrin/__init__.py b/python/tvm/tir/tensor_intrin/__init__.py
index 7e5a26bdeb..d127335e82 100644
--- a/python/tvm/tir/tensor_intrin/__init__.py
+++ b/python/tvm/tir/tensor_intrin/__init__.py
@@ -16,4 +16,3 @@
 # under the License.
 # pylint: disable=unused-import
 """Intrinsics for tensorization."""
-from . import arm_cpu, cuda, rocm, x86, hexagon
diff --git a/python/tvm/tir/tensor_intrin/arm_cpu.py b/python/tvm/tir/tensor_intrin/arm_cpu.py
index a5003d41a8..90af1e05b1 100644
--- a/python/tvm/tir/tensor_intrin/arm_cpu.py
+++ b/python/tvm/tir/tensor_intrin/arm_cpu.py
@@ -17,6 +17,10 @@
 # pylint: disable=invalid-name,missing-function-docstring,unused-import
 """Intrinsics for ARM tensorization."""
 from tvm.script import tir as T
+from tvm.script.ir_builder import IRBuilder
+from tvm.script.ir_builder.tir import prim_func as build_prim_func
+from tvm.target.codegen import llvm_version_major
+
 from .. import TensorIntrin
 from .dot_product_common import (
     DP4A_S8S8S32_INTRIN,
@@ -163,15 +167,367 @@ def get_dotprod_intrin(in_dtype, out_dtype):
     return dot_prod_desc, dot_prod_impl
 
 
+def get_sme_transpose_interleave_2svlx2svl_intrin():
+    """
+    Transpose a matrix of size 2SVL x 2SVL (where 'SVL' is the Scalable Vector Length) using
+    the Scalable Matrix Extension (SME).
+
+    This is completed by loading rows of the input matrix into the accumulator tile,
+    then storing the columns. The SME accumulator tile is divided into a series of sub-tiles
+    which must be loaded to / stored from independently.
+
+    Note: currently only supports the fp32 datatype.
+
+    Example
+    -------
+    An example case for float32. In this instance the accumulator tile is divided into 4
+    sub-tiles of size SVLxSVL numbered 0-3. We start by loading rows of A, each SVL in length,
+    into each of the sub-tiles. In the diagram below, each load for a sub-tile is sequenced by
+    a, b, ... till the tile is full.
+
+    The columns of each sub-tile are then stored into A_t. Note that to perform a transpose,
+    the contents of sub-tiles 1 and 2 are stored in opposite locations - see the diagram
+    below.
+
+    A:                                  Accumulator tile:                      A_t:
+                2SVL                                2SVL                               2SVL
+         +----------------+                 +-----------------+                +-------------------+
+         | --0a--  --1a-- |                 |                 |                | |  |     |  |     |
+         | --0b--  --1b-- |                 |    0       1    |                | 0a 0b .. 2a 2b .. |
+         |   ...     ...  | ld1w.horiz      |                 | st1w.vert      | |  |     |  |     |
+    2SVL | --2a--  --3a-- |   ====>    2SVL |                 |   ====>   2SVL | |  |     |  |     |
+         | --2b--  --3b-- |                 |    2       3    |                | 1a 1b .. 3a 3b .. |
+         |   ...     ...  |                 |                 |                | |  |     |  |     |
+         +----------------+                 +-----------------+                +-------------------+
+
+    Returns
+    -------
+    intrin : TensorIntrin
+        The SME TensorIntrin that can be used in tensorizing a schedule.
+
+    """
+    SVF = 4 * T.vscale()
+    SVF2 = 2 * SVF
+
+    @T.prim_func
+    def desc(a: T.handle, a_t: T.handle) -> None:
+        A = T.match_buffer(a, (SVF2, SVF2), dtype="float32", offset_factor=1)
+        A_t = T.match_buffer(a_t, (SVF2, SVF2), dtype="float32", offset_factor=1)
+        with T.block("root"):
+            T.reads(A[0:SVF2, 0:SVF2])
+            T.writes(A_t[0:SVF2, 0:SVF2])
+            for k, m in T.grid(SVF2, SVF2):
+                with T.block("transpose"):
+                    v_m, v_k = T.axis.remap("SS", [m, k])
+                    A_t[v_k, v_m] = A[v_m, v_k]
+
+    def impl():
+        # Accumulation sub-tile count. For fp32 it is 4
+        sub_tile_count = 4
+
+        with IRBuilder() as ib:
+            with build_prim_func():
+                a = T.arg("a", T.handle())
+                a_t = T.arg("a_t", T.handle())
+
+                A = T.match_buffer(
+                    a, (SVF2, SVF2), "float32", offset_factor=1, strides=[T.int32(), 1]
+                )
+                A_t = T.match_buffer(
+                    a_t,
+                    (SVF2, SVF2),
+                    "float32",
+                    offset_factor=1,
+                    strides=[T.int32(), 1],
+                )
+
+                # Disable predication
+                ptrue = T.broadcast(T.IntImm("int1", 1), T.vscale() * 4)
+
+                with T.block("root"):
+                    T.reads(A[0:SVF2, 0:SVF2])
+                    T.writes(A_t[0:SVF2, 0:SVF2])
+
+                    # Load rows of the input matrix
+                    with T.serial(0, SVF) as slice_idx:
+                        for sub_tile_idx in range(0, sub_tile_count):
+                            row_offset = SVF if sub_tile_idx >= (sub_tile_count // 2) else 0
+                            col_offset = SVF if sub_tile_idx % 2 else 0
+                            offset = (slice_idx + row_offset) * A.strides[0] + col_offset
+
+                            input_ptr = A.access_ptr("r", offset=offset)
+                            sub_tile = T.int32(sub_tile_idx)
+                            T.evaluate(
+                                T.call_llvm_intrin(
+                                    "void",
+                                    "llvm.aarch64.sme.ld1w.horiz",
+                                    T.uint32(4),
+                                    ptrue,
+                                    input_ptr,
+                                    sub_tile,
+                                    slice_idx,
+                                )
+                            )
+
+                    # Store columns to the output matrix
+                    with T.serial(0, SVF) as slice_idx:
+                        for sub_tile_idx in range(0, sub_tile_count):
+                            col_offset = SVF if sub_tile_idx >= (sub_tile_count // 2) else 0
+                            row_offset = SVF if sub_tile_idx % 2 else 0
+                            offset = (slice_idx + row_offset) * A_t.strides[0] + col_offset
+
+                            output_ptr = A_t.access_ptr("w", offset=offset)
+                            sub_tile = T.int32(sub_tile_idx)
+                            T.evaluate(
+                                T.call_llvm_intrin(
+                                    "void",
+                                    "llvm.aarch64.sme.st1w.vert",
+                                    T.uint32(4),
+                                    ptrue,
+                                    output_ptr,
+                                    sub_tile,
+                                    slice_idx,
+                                )
+                            )
+
+        return ib.get()
+
+    return desc, impl()
+
+
+def get_sme_gemm_interleaved_mopa_2svlx2svl_intrin(K):
+    """
+    Compute a GEMM of size 2SVL x 2SVL (where 'SVL' is the Scalable Vector Length) using
+    outer product operations from the Scalable Matrix Extension (SME).
+
+    The inputs A and B are expected to be of size K x 2SVL and produce a result C of
+    size 2SVL x 2SVL.
+
+    The SME accumulator tile is divided into sub-tiles, each of which is utilized to
+    calculate the outer-product using columns / rows of A and B respectively. For each
+    sub-tile, elements in the first column of input matrix A (accessed sequentially due
+    to being transpose-interleaved) and first row of input matrix B are used to calculate
+    an outer-product. This is then accumulated with the result of performing an
+    outer-product on the second column and row of A and B respectively. This process is
+    repeated K times. Finally, the results of the accumulation are stored.
+
+    Note: The input tensor 'A' must be transpose-interleaved.
+    Note: Currently only supports the fp32 datatype.
+
+    Example
+    -------
+
+    Diagram showing outer-product performed on each of the accumulator sub-tiles
+    for the fp32 datatype:
+
+                       SVL           SVL
+                +----------------------------+
+                |       l     |       h      | K
+            K   +----------------------------+
+         +---+  +----------------------------+
+         |   |  |  0:            1:          |-+
+         |   |  |  mopa(l, l)    mopa(l, h)  | |-+
+       l |   |  |                            | | |
+         |   |  |                            | | |
+         |---|  |                            | | |
+         |   |  |  2:            3:          | | |
+       h |   |  |  mopa(h, l)    mopa(h, h)  | | |
+         |   |  |                            | | |
+         |   |  |                            | | |
+         +---+  +----------------------------+ | |
+                  +----------------------------+ |
+                     +---------------------------+
+                                    (accumulate K times)
+
+    Pseudo code computing 2SVL x 2SVL GEMM for fp32 inputs:
+
+    .. code-block:: c
+
+        // Number of fp32 elements in a scalable vector
+        int SVF = SVL / 32;
+
+        // Reset the accumulator tile
+        sme.zero();
+
+        // Calculate outer products and accumulate
+        for (k = 0; k < K; k++) {
+            float32xSVF A_row_0 = A[k][0];
+            float32xSVF A_row_1 = A[k][SVF];
+            float32xSVF B_row_0 = B[k][0];
+            float32xSVF B_row_1 = B[k][SVF];
+
+            float32xSVFxSVF sub_tile_0 += sme.mopa(A_row_0, B_row_0);
+            float32xSVFxSVF sub_tile_1 += sme.mopa(A_row_0, B_row_1);
+            float32xSVFxSVF sub_tile_2 += sme.mopa(A_row_1, B_row_0);
+            float32xSVFxSVF sub_tile_3 += sme.mopa(A_row_1, B_row_1);
+        }
+
+        // Store the results of accumulation
+        for (i = 0; i < SVF; i++) {
+            C[i][0] = sme.horiz(sub_tile_0[i]);
+            C[i][SVF] = sme.horiz(sub_tile_1[i]);
+            C[i + SVF][0] = sme.horiz(sub_tile_2[i]);
+            C[i + SVF][SVF] = sme.horiz(sub_tile_3[i]);
+        }
+
+    Notes:
+    - Recall that A has been transposed beforehand such that each column is now accessed
+      by row.
+    - 'sme.zero' resets the accumulator tile to contain all zeros.
+    - 'sme.mopa' is the outer product and accumulate intrinsic.
+    - 'sme.horiz' stores rows of an accumulator sub-tile to memory.
+
+    Returns
+    -------
+    intrin : TensorIntrin
+        The SME TensorIntrin that can be used in tensorizing a schedule.
+
+    """
+    SVF = 4 * T.vscale()
+    SVF2 = 2 * SVF
+
+    @T.prim_func
+    def desc(a: T.handle, b: T.handle, c: T.handle):
+        A = T.match_buffer(a, (K, SVF2), dtype="float32", offset_factor=1)
+        B = T.match_buffer(b, (K, SVF2), dtype="float32", offset_factor=1)
+        C = T.match_buffer(c, (SVF2, SVF2), dtype="float32", offset_factor=1)
+
+        with T.block("root"):
+            T.reads(C[0:SVF2, 0:SVF2], A[0:K, 0:SVF2], B[0:K, 0:SVF2])
+            T.writes(C[0:SVF2, 0:SVF2])
+            for m, n, k in T.grid(SVF2, SVF2, K):
+                with T.block("gemm"):
+                    v_m, v_n, v_k = T.axis.remap("SSR", [m, n, k])
+                    C[v_m, v_n] += A[v_k, v_m] * B[v_k, v_n]
+
+    def impl():
+        # Accumulation sub-tile count. For fp32 it is 4
+        sub_tile_count = 4
+
+        with IRBuilder() as ib:
+            with build_prim_func():
+                a = T.arg("a", T.handle())
+                b = T.arg("b", T.handle())
+                c = T.arg("c", T.handle())
+
+                A = T.match_buffer(a, (K, SVF2), "float32", offset_factor=1, strides=[T.int32(), 1])
+                B = T.match_buffer(b, (K, SVF2), "float32", offset_factor=1, strides=[T.int32(), 1])
+                C = T.match_buffer(
+                    c, (SVF2, SVF2), "float32", offset_factor=1, strides=[T.int32(), 1]
+                )
+
+                ptrue = T.broadcast(T.IntImm("int1", 1), T.vscale() * 4)
+
+                with T.block("root"):
+                    T.reads(C[0:SVF2, 0:SVF2], A[0:K, 0:SVF2], B[0:K, 0:SVF2])
+                    T.writes(C[0:SVF2, 0:SVF2])
+
+                    # Iterate over the reduction axis applying outer product and accumulate
+                    with T.serial(K) as k:
+                        a_low = T.BufferLoad(A, [k, T.Ramp(0, 1, T.vscale() * 4)])
+                        a_high = T.BufferLoad(A, [k, T.Ramp(SVF, 1, T.vscale() * 4)])
+                        b_low = T.BufferLoad(B, [k, T.Ramp(0, 1, T.vscale() * 4)])
+                        b_high = T.BufferLoad(B, [k, T.Ramp(SVF, 1, T.vscale() * 4)])
+
+                        input_combinations = [
+                            (a_low, b_low),
+                            (a_low, b_high),
+                            (a_high, b_low),
+                            (a_high, b_high),
+                        ]
+                        for sub_tile_idx in range(0, sub_tile_count):
+                            sub_tile = T.int32(sub_tile_idx)
+                            input_1 = input_combinations[sub_tile_idx][0]
+                            input_2 = input_combinations[sub_tile_idx][1]
+
+                            T.evaluate(
+                                T.call_llvm_intrin(
+                                    "void",
+                                    "llvm.aarch64.sme.mopa.nxv4f32",
+                                    T.uint32(5),
+                                    sub_tile,
+                                    ptrue,
+                                    ptrue,
+                                    input_1,
+                                    input_2,
+                                )
+                            )
+
+                    # Store the accumulated tile results
+                    with T.serial(SVF) as slice_idx:
+                        for sub_tile_idx in range(sub_tile_count):
+                            vert_offset = SVF if sub_tile_idx >= (sub_tile_count // 2) else 0
+                            horiz_offset = SVF if sub_tile_idx % 2 else 0
+                            local_offset = (slice_idx + vert_offset) * C.strides[0] + horiz_offset
+                            output_ptr = C.access_ptr("w", offset=local_offset, extent=SVF)
+
+                            T.evaluate(
+                                T.call_llvm_intrin(
+                                    "void",
+                                    "llvm.aarch64.sme.st1w.horiz",
+                                    T.uint32(4),
+                                    ptrue,
+                                    output_ptr,
+                                    T.int32(sub_tile_idx),
+                                    T.int32(slice_idx),
+                                )
+                            )
+
+            return ib.get()
+
+    return desc, impl()
+
+
+def get_sme_init_intrin():
+    """
+    Reset the entire matrix tile storage to 0.
+    """
+    SVF2 = 2 * 4 * T.vscale()
+
+    @T.prim_func
+    def desc(c: T.handle) -> None:
+        C = T.match_buffer(c, (SVF2, SVF2), "float32", offset_factor=1)
+        with T.block("root"):
+            T.reads()
+            T.writes(C[0:SVF2, 0:SVF2])
+            for m, n in T.grid(SVF2, SVF2):
+                with T.block("init"):
+                    v_m, v_n = T.axis.remap("SS", [m, n])
+                    C[v_m, v_n] = T.float32(0)
+
+    @T.prim_func
+    def impl(c: T.handle) -> None:
+        C = T.match_buffer(c, (SVF2, SVF2), "float32", offset_factor=1)
+        with T.block("root"):
+            T.reads()
+            T.writes(C[0:SVF2, 0:SVF2])
+            clear_all_tiles = T.int32(255)
+            T.evaluate(
+                T.call_llvm_intrin("void", "llvm.aarch64.sme.zero", T.uint32(1), clear_all_tiles)
+            )
+
+    return desc, impl
+
+
 ARM_DOT_4x4_i8_NEON_INTRIN = "dot_4x4_i8i8s32_neon"
 ARM_DOT_4x4_i8_SDOT_INTRIN = "dot_4x4_i8i8s32_sdot"
 ARM_DOT_4x4_u8_UDOT_INTRIN = "dot_4x4_u8u8u32_udot"
 ARM_DOT_4x4_u8_HDOT_INTRIN = "dot_4x4_u8u8i32_hdot"
 
 TensorIntrin.register(ARM_DOT_4x4_i8_NEON_INTRIN, neon_4x4_i8i8i32_desc, neon_4x4_i8i8i32_impl)
-
 TensorIntrin.register(ARM_DOT_4x4_i8_SDOT_INTRIN, *get_dotprod_intrin("int8", "int32"))
-
 TensorIntrin.register(ARM_DOT_4x4_u8_UDOT_INTRIN, *get_dotprod_intrin("uint8", "uint32"))
-
 TensorIntrin.register(ARM_DOT_4x4_u8_HDOT_INTRIN, *get_dotprod_intrin("uint8", "int32"))
+
+ARM_SME_INIT = "sme_init"
+ARM_SME_2SVLx2SVL_TRANSPOSE_INTERLEAVE = "sme_2svlx2svl_transpose_interleave"
+ARM_SME_2SVLx2SVL_GEMM_INTERLEAVED_MOPA = "sme_2svlx2svl_gemm_interleaved_mopa"
+
+# The following tensor intrinsics use LLVM intrinsics that are only available
+# in versions of LLVM >= 15. Installations with older versions of LLVM will
+# not be able to use them.
+if llvm_version_major() >= 15:
+    TensorIntrin.register(
+        ARM_SME_2SVLx2SVL_TRANSPOSE_INTERLEAVE, *get_sme_transpose_interleave_2svlx2svl_intrin()
+    )
+    TensorIntrin.register(ARM_SME_INIT, *get_sme_init_intrin())
diff --git a/python/tvm/topi/arm_cpu/__init__.py b/python/tvm/topi/arm_cpu/__init__.py
index 054103f43b..5484adaa64 100644
--- a/python/tvm/topi/arm_cpu/__init__.py
+++ b/python/tvm/topi/arm_cpu/__init__.py
@@ -22,13 +22,16 @@ from .conv2d import *
 from .depthwise_conv2d import *
 from .conv2d_transpose import *
 from .conv2d_int8 import *
-from . import conv2d_alter_op
 from .bitserial_conv2d import *
 from .bitserial_dense import *
 from .injective import *
 from .group_conv2d import *
 from .pooling import *
 from .dense import *
+from .matmul import *
 from .qnn import *
+
+from . import conv2d_alter_op
+from . import dense_alter_op
 from . import qnn_alter_op
 from . import qnn_legalize
diff --git a/python/tvm/topi/arm_cpu/arm_utils.py b/python/tvm/topi/arm_cpu/arm_utils.py
index c350b87167..f2e01c5aef 100644
--- a/python/tvm/topi/arm_cpu/arm_utils.py
+++ b/python/tvm/topi/arm_cpu/arm_utils.py
@@ -19,6 +19,7 @@
 
 import tvm
 from tvm.target import Target
+from tvm.tir.expr import PrimExpr
 
 
 def get_tiling_A(interleave_A, in_dtype):
@@ -186,6 +187,31 @@ def get_conv2d_im2col_padding(M, K, tile_M, tile_K):
     return pad_M, pad_K
 
 
+def pad_dim_to_multiple(dim: PrimExpr, multiple: PrimExpr):
+    """
+    Compute the padding required to reach specified multiple.
+
+    Parameters
+    ----------
+    dim : PrimExpr
+        Current size of the dim.
+    multiple : PrimExpr
+        Multiple to pad up to.
+
+    Returns
+    -------
+    padded_dim : PrimExpr
+        The new dim size.
+    pad_value : PrimExpr
+        The padding required.
+    """
+    pad_value = 0
+    if dim % multiple != 0:
+        pad_value = multiple - (dim % multiple)
+    padded_dim = dim + pad_value
+    return padded_dim, pad_value
+
+
 def get_conv2d_weights_padding(N, K, tile_N, tile_K):
     """Compute the necessary padding for matrix B', where B'
     is the transformed version of matrix B in C=A*B.
diff --git a/python/tvm/topi/arm_cpu/dense.py b/python/tvm/topi/arm_cpu/dense.py
index dd66b0d531..6a44cc89b0 100644
--- a/python/tvm/topi/arm_cpu/dense.py
+++ b/python/tvm/topi/arm_cpu/dense.py
@@ -14,16 +14,18 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=invalid-name, unused-variable, no-else-return, unused-argument, import-outside-toplevel
 """Dense schedule for ARM CPU"""
-
 from tvm import autotvm
-from .mprofile.dsp.dense import dense_dsp_schedule, dense_dsp_compute
+
+from .mprofile.dsp.dense import (
+    dense_dsp_schedule,
+    dense_dsp_compute,
+)
 
 
 @autotvm.register_topi_compute("dense_dsp.arm_cpu")
 def dense_dsp(cfg, data, weight, bias, out_dtype):
-    """Compute conv2d_nhwc with v7e-m DSP instructions."""
+    """Compute dense_dsp with v7e-m DSP instructions."""
     return dense_dsp_compute(cfg, data, weight, bias=bias, out_dtype=out_dtype)
 
 
diff --git a/python/tvm/topi/arm_cpu/dense_alter_op.py b/python/tvm/topi/arm_cpu/dense_alter_op.py
new file mode 100644
index 0000000000..208b923e68
--- /dev/null
+++ b/python/tvm/topi/arm_cpu/dense_alter_op.py
@@ -0,0 +1,75 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Dense alter op definitions for the `arm_cpu` device key."""
+
+import tvm
+from tvm import relay
+from tvm import autotvm
+from tvm import te
+
+from ..nn import dense_alter_layout
+
+
+@dense_alter_layout.register("arm_cpu")
+def _alter_dense(attrs, inputs, tinfos, out_type):
+    target = tvm.target.Target.current(allow_none=False)
+    dispatch_ctx = autotvm.task.DispatchContext.current
+
+    _, outs = relay.backend.te_compiler.select_implementation(
+        relay.op.get("nn.dense"),
+        attrs,
+        tinfos,
+        out_type,
+        target,
+    )
+    workload = autotvm.task.get_workload(outs)
+    if workload is None:
+        # The best implementation is not an AutoTVM template,
+        # we then assume it's not necessary to alter this op.
+        return None
+
+    cfg = dispatch_ctx.query(target, workload)
+    topi_impl = workload[0]
+    if topi_impl == "matmul.arm_cpu.sme":
+        # Pre-compute transposed weights and convert to a matmul
+        assert isinstance(
+            inputs[1], relay.Constant
+        ), "matmul_sme.arm_cpu requires weights be a Relay Constant"
+
+        weight_dtype = tinfos[1].dtype
+        weight_data = inputs[1].data.numpy()
+        interleaved = weight_data.transpose()
+        encoded_weight = relay.const(interleaved, weight_dtype)
+
+        new_weight = te.placeholder((weight_data.shape), dtype=weight_dtype)
+        new_workload = autotvm.task.args_to_workload(
+            [tinfos[0], new_weight, None, out_type.dtype], topi_impl
+        )
+        dispatch_ctx.update(target, new_workload, cfg)
+
+        return relay.nn.matmul(
+            inputs[0],
+            encoded_weight,
+            units=attrs.units,
+            out_dtype=attrs.out_dtype,
+            transpose_a=False,
+            transpose_b=False,
+        )
+
+    # x86 schedules are used as a fallback
+    return tvm.topi.x86.dense_alter_op._alter_dense_layout(attrs, inputs, tinfos, out_type)
diff --git a/python/tvm/topi/arm_cpu/matmul.py b/python/tvm/topi/arm_cpu/matmul.py
new file mode 100644
index 0000000000..ea8b27cabc
--- /dev/null
+++ b/python/tvm/topi/arm_cpu/matmul.py
@@ -0,0 +1,124 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name,unused-argument
+
+"""Matmul schedules for the `arm_cpu` device key."""
+
+import tvm
+from tvm import te
+from tvm import autotvm
+from tvm.script import tir as T
+from tvm.topi import nn
+from tvm.topi.utils import get_const_tuple
+from tvm.topi.arm_cpu.pstate_attributes import SMEAttributes
+from tvm.topi.arm_cpu.arm_utils import pad_dim_to_multiple
+
+
+@autotvm.register_topi_compute("matmul.arm_cpu.sme")
+def compute_matmul_sme(cfg, data_a, data_b, _, out_dtype, transpose_a=False, transpose_b=False):
+    """
+    SME Matmul compute definition.
+    """
+    assert (
+        transpose_a == transpose_b == False
+    ), "Compute definition currently does not support transposed inputs."
+
+    M, K = get_const_tuple(data_a.shape)
+    N = get_const_tuple(data_b.shape)[1]
+
+    if not out_dtype:
+        out_dtype = data_a.dtype
+
+    tile_m = 2 * 4 * tvm.tir.vscale()
+    tile_n = 2 * 4 * tvm.tir.vscale()
+
+    M_padded, pad_M = pad_dim_to_multiple(M, tile_m)
+    N_padded, pad_N = pad_dim_to_multiple(N, tile_n)
+    if pad_M != 0:
+        data_a = nn.pad(data_a, pad_before=(0, 0), pad_after=(pad_M, 0))
+    if pad_N != 0:
+        data_b = nn.pad(data_b, pad_before=(0, 0), pad_after=(0, pad_N))
+
+    k = te.reduce_axis((0, K), name="k")
+    C = te.compute(
+        (M_padded, N_padded),
+        lambda m, n: te.sum(
+            data_a[m, k].astype(data_a.dtype) * data_b[k, n].astype(data_b.dtype),
+            axis=k,
+        ).astype(out_dtype),
+        name="matmul_sme_gemm",
+    )
+    C = te.compute((M, N), lambda m, n: C[m, n])
+    return C
+
+
+def tir_schedule_matmul_sme(sch):
+    """
+    SME STIR Matmul schedule.
+    """
+    # pylint: disable=import-outside-toplevel
+    from tvm.tir.tensor_intrin.arm_cpu import (
+        ARM_SME_2SVLx2SVL_TRANSPOSE_INTERLEAVE,
+        ARM_SME_2SVLx2SVL_GEMM_INTERLEAVED_MOPA,
+        ARM_SME_INIT,
+        get_sme_gemm_interleaved_mopa_2svlx2svl_intrin,
+    )
+
+    gemm_block = sch.get_block("matmul_sme_gemm")
+    m, n, k = sch.get_loops(gemm_block)
+
+    extent_m = sch.get(m).extent
+    extent_k = sch.get(k).extent
+
+    tile_m = T.cast(2 * 4 * T.vscale(), extent_m.dtype)
+    tile_k = T.cast(2 * 4 * T.vscale(), extent_k.dtype)
+    tile_n = T.cast(2 * 4 * T.vscale(), sch.get(n).extent.dtype)
+
+    # Interleave the input utilizing the matrix tile
+    interleave_a_block = sch.cache_read(gemm_block, 0, "global")
+    sch.transform_layout(interleave_a_block, ("write", 0), lambda m, k: (k, m))
+    m, k = sch.get_loops(interleave_a_block)
+    outer_m, inner_m = sch.split(m, factors=(None, tile_m), disable_predication=True)
+    outer_k, inner_k = sch.split(k, factors=(None, tile_k), disable_predication=True)
+    sch.reorder(outer_k, outer_m, inner_k, inner_m)
+    sch.tensorize(inner_k, ARM_SME_2SVLx2SVL_TRANSPOSE_INTERLEAVE)
+
+    # Split and reorder the loops of the GeMM for tensorization
+    m, n, k = sch.get_loops(gemm_block)
+    outer_m, inner_m = sch.split(m, factors=(None, tile_m), disable_predication=True)
+    outer_n, inner_n = sch.split(n, factors=(None, tile_n), disable_predication=True)
+    sch.reorder(outer_m, outer_n, inner_m, inner_n, k)
+
+    # Tensorize the GeMM initialization
+    init_block = sch.decompose_reduction(gemm_block, inner_m)
+    sch.tensorize(sch.get_loops(init_block)[-2], ARM_SME_INIT)
+
+    # Tensorize the GeMM update
+    sme_gemm_interleaved_intrin_name = ARM_SME_2SVLx2SVL_GEMM_INTERLEAVED_MOPA + f"_{extent_k}"
+    tvm.tir.TensorIntrin.register(
+        sme_gemm_interleaved_intrin_name,
+        *get_sme_gemm_interleaved_mopa_2svlx2svl_intrin(extent_k),
+        override=True,
+    )
+    sch.tensorize(inner_m, sme_gemm_interleaved_intrin_name)
+
+    # Add pstate annotations
+    root_block = sch.get_block("root")
+    sch.annotate(
+        root_block, SMEAttributes.STREAMING_MODE, SMEAttributes.StreamingModeValues.ENABLED
+    )
+    sch.annotate(root_block, SMEAttributes.ZA_STORAGE, SMEAttributes.ZAStorageValues.NEW)
diff --git a/python/tvm/topi/x86/dense_alter_op.py b/python/tvm/topi/x86/dense_alter_op.py
index 0e9b1f7b65..10b1248c6a 100644
--- a/python/tvm/topi/x86/dense_alter_op.py
+++ b/python/tvm/topi/x86/dense_alter_op.py
@@ -39,7 +39,7 @@ def check_int8_applicable(x, y, allow_padding=False):
     )
 
 
-@dense_alter_layout.register(["cpu", "arm_cpu"])
+@dense_alter_layout.register(["cpu"])
 def _alter_dense_layout(attrs, inputs, tinfos, out_type):
     target = tvm.target.Target.current(allow_none=False)
     dispatch_ctx = autotvm.task.DispatchContext.current
diff --git a/src/arith/const_int_bound.cc b/src/arith/const_int_bound.cc
index 57dd024a27..76c97c5ad5 100644
--- a/src/arith/const_int_bound.cc
+++ b/src/arith/const_int_bound.cc
@@ -371,7 +371,7 @@ class ConstIntBoundAnalyzer::Impl
     } else if (op->op.same_as(tir::builtin::bitwise_and())) {
       return VisitBitwiseAnd(op);
     } else if (op->op.same_as(tir::builtin::vscale()) && TargetHasSVE()) {
-      return MakeBound(1, 16);
+      return MakeBound(1, kAArch64VScaleValues.size());
     } else {
       return Everything(op->dtype);
     }
diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc
index b747855bff..2655cf6671 100644
--- a/src/relay/backend/te_compiler_cache.cc
+++ b/src/relay/backend/te_compiler_cache.cc
@@ -476,12 +476,10 @@ class ScheduleBuilder : public ExprVisitor {
         mod_eq_structural_(meta_schedule::ModuleEquality::Create("ignore-ndarray")) {
     // Whether to use auto_scheduler schedule.
     use_auto_scheduler_ = backend::IsAutoSchedulerEnabled();
+    database_ = meta_schedule::Database::Current();
     if (backend::IsMetaScheduleEnabled()) {
-      database_ = meta_schedule::Database::Current();
       CHECK(database_.defined()) << "ValueError: `use_meta_schedule` is enabled in Relay "
                                     "build, but no `meta_schedule.Database` context is provided. ";
-    } else {
-      database_ = NullOpt;
     }
   }
 
diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc
index 9e2fe63b00..ccc9734855 100644
--- a/src/relay/op/nn/nn.cc
+++ b/src/relay/op/nn/nn.cc
@@ -193,6 +193,7 @@ RELAY_REGISTER_OP("nn.matmul")
     .add_argument("tensor_a", "nD Tensor", "The first input Tensor.")
     .add_argument("tensor_b", "2D Tensor", "The second input Tensor.")
     .set_support_level(1)
+    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", DenseInferCorrectLayout)
     .add_type_rel("Matmul", MatmulRel<MatmulAttrs>)
     .set_attr<TOpPattern>("TOpPattern", kOutEWiseFusable);
 
diff --git a/src/tir/schedule/ir_comparator.cc b/src/tir/schedule/ir_comparator.cc
index 00e573eaf6..a97cda266f 100644
--- a/src/tir/schedule/ir_comparator.cc
+++ b/src/tir/schedule/ir_comparator.cc
@@ -18,6 +18,8 @@
  */
 #include "./ir_comparator.h"
 
+#include "../../arith/scalable_expression.h"
+
 namespace tvm {
 
 namespace tir {
@@ -74,7 +76,9 @@ bool TensorizeComparator::VisitStmt(const Stmt& n, const Stmt& other) {
 bool TensorizeComparator::VisitExpr(const PrimExpr& n, const PrimExpr& other) {
   bool equal = n.same_as(other) ||
                ((n->type_index() == other->type_index()) &&
-                n.dtype().code() == other.dtype().code() && ExprComparator::VisitExpr(n, other));
+                n.dtype().code() == other.dtype().code() && ExprComparator::VisitExpr(n, other)) ||
+               (tvm::arith::ContainsVscaleCall(n) && analyzer_.CanProveEqual(n, other));
+
   if (!equal && assert_mode_) {
     std::ostringstream os;
     os << "Expression mismatch: " << n << " vs " << other;
diff --git a/tests/python/codegen/test_target_codegen_aarch64.py b/tests/python/codegen/test_target_codegen_aarch64.py
index 9726f79d7a..f73d96e7c9 100644
--- a/tests/python/codegen/test_target_codegen_aarch64.py
+++ b/tests/python/codegen/test_target_codegen_aarch64.py
@@ -15,15 +15,17 @@
 # specific language governing permissions and limitations
 # under the License.
 
-import re
+"""
+Codegen tests for AArch64
+"""
 
+import re
 import pytest
 
 import tvm
 from tvm import te
 from tvm.script import tir as T
 from tvm.topi.arm_cpu.pstate_attributes import SMEAttributes
-
 from tvm.target.codegen import llvm_version_major
 
 
@@ -496,6 +498,46 @@ def test_codegen_vscale():
     assert re.findall(r"llvm.vscale.i32", llvm), "No vscale in generated LLVM."
 
 
+@pytest.mark.skipif(
+    llvm_version_major() < 16, reason="SME is not supported in earlier versions of LLVM"
+)
+@pytest.mark.parametrize("dtype", ["float32"])
+def test_matmul_sme(dtype):
+    target = "llvm -mtriple=aarch64-linux-gnu -mattr=+v9a,+sme"
+
+    def check_correct_assembly(dtype):
+        A = te.placeholder((32, 32), dtype=dtype, name="A")
+        B = te.placeholder((32, 32), dtype=dtype, name="B")
+
+        with tvm.target.Target(target):
+            C = tvm.topi.arm_cpu.matmul.compute_matmul_sme(A, B, None, dtype, False, False)
+            prim_func = te.create_prim_func([A, B, C])
+
+            sch = tvm.tir.Schedule(prim_func)
+            tvm.topi.arm_cpu.matmul.tir_schedule_matmul_sme(sch)
+            prim_func = sch.mod
+
+            f = tvm.build(prim_func, target=target)
+
+        assembly = f.get_source("asm")
+        smstart = re.findall(r"smstart\t(sm|za)", assembly)
+        loads = re.findall(r"ld1[whdb]\t{\s?za", assembly)
+        mopa = re.findall(
+            r"fmopa\tza[0-9].[shdb],( p[0-9]/[zm],)?( p[0-9]/[zm],)? z[0-9].[shdb], z[0-9].[shdb]",
+            assembly,
+        )
+        stores = re.findall(r"st1[whdb]\t{\s?za", assembly)
+        smstop = re.findall(r"smstop\t(sm|za)", assembly)
+
+        assert len(smstart) > 0
+        assert len(loads) > 0
+        assert len(mopa) > 0
+        assert len(stores) > 0
+        assert len(smstop) > 0
+
+    check_correct_assembly(dtype=dtype)
+
+
 @pytest.mark.skipif(
     llvm_version_major() < 11, reason="Vscale is not supported in earlier versions of LLVM"
 )
diff --git a/tests/python/integration/test_arm_aprofile.py b/tests/python/integration/test_arm_aprofile.py
index af35a14297..d32fed00af 100644
--- a/tests/python/integration/test_arm_aprofile.py
+++ b/tests/python/integration/test_arm_aprofile.py
@@ -16,7 +16,6 @@
 # under the License.
 """Tests for Arm(R) A-Profile Architecture."""
 import os
-import subprocess
 
 import numpy as np
 import pytest
@@ -26,8 +25,6 @@ import tvm.testing
 from tvm import relay
 from tvm.relay.transform import ToMixedPrecision, FoldConstant
 from tvm.relay.build_module import bind_params_by_name
-from tvm.testing.aot import AOTTestModel, AOTTestRunner, generate_ref_data, compile_and_run
-from tvm.contrib import utils
 
 
 def get_mattr(dtype):
@@ -80,96 +77,5 @@ def test_conv2d(dtype):
         lib.export_library(lib_path, cc="aarch64-linux-gnu-gcc")
 
 
-# AOT Test Runner using the AArch64 Architecture Envelope Model (AEM)
-# Fixed Virtual Platform (FVP) reference system.
-# See: https://developer.arm.com/Tools%20and%20Software/Fixed%20Virtual%20Platforms
-AOT_APROFILE_AEM_RUNNER = AOTTestRunner(
-    makefile="aprofile_aem",
-    pass_config={
-        "tir.usmp.enable": False,
-        "tir.disable_assert": True,  # AOT test infra creates 'fake' inputs that fail asserts
-    },
-)
-
-
-@tvm.testing.requires_x86
-@tvm.testing.skip_if_32bit
-def test_aem_simple_addition():
-    """Tests a simple addition running on the AArch64 AEM."""
-    inp = relay.var("data", shape=(1, 2, 4, 4))
-    add = relay.add(inp, relay.const(np.ones((1, 2, 4, 4))))
-    func = relay.Function([inp], add)
-    ir_mod = tvm.IRModule.from_expr(func)
-    ir_mod = tvm.relay.transform.InferType()(ir_mod)
-
-    main_func = ir_mod["main"]
-    shape_dict = {p.name_hint: p.checked_type.concrete_shape for p in main_func.params}
-    type_dict = {p.name_hint: p.checked_type.dtype for p in main_func.params}
-
-    input_data = np.random.uniform(size=shape_dict["data"]).astype(type_dict["data"])
-    params = {}
-    inputs = {"data": input_data}
-    ref_outputs = generate_ref_data(ir_mod, inputs, params)
-
-    compile_and_run(
-        AOTTestModel(module=ir_mod, inputs=inputs, outputs=ref_outputs, params=params),
-        target=tvm.target.Target("llvm -mtriple=aarch64-none-elf"),
-        runtime=tvm.relay.backend.Runtime("crt", {"system-lib": True}),
-        interface_api="packed",
-        use_unpacked_api=False,
-        runner=AOT_APROFILE_AEM_RUNNER,
-    )
-
-
-@tvm.testing.requires_x86
-@tvm.testing.skip_if_32bit
-def test_aem_asm_sme():
-    """
-    Tests SME assembly runs on the AArch64 AEM. This test is used as a simple
-    sanity check until the TVM schedules are able to produce SME.
-    """
-    c_code = """
-    #include <stdio.h>
-
-    int main(void) {
-        __asm volatile(
-            "smstart\\n"
-            "smstop\\n"
-        );
-        printf("EXITTHESIM\\n");
-        return 0;
-    }
-    """
-    runner = AOT_APROFILE_AEM_RUNNER
-
-    tmpdir = utils.tempdir()
-    build_path = os.path.join(tmpdir.path, "build")
-    os.makedirs(build_path, exist_ok=True)
-
-    with open(build_path + "/test.c", "w") as f:
-        f.write(c_code)
-
-    file_dir = os.path.dirname(os.path.abspath(__file__))
-    makefile_dir = os.path.join(file_dir, "../../../tests/python/relay/aot")
-    makefile = os.path.join(makefile_dir, f"{runner.makefile}.mk")
-
-    make_command = (
-        f"make -f {makefile} build_dir={build_path}"
-        + f" TVM_ROOT={file_dir}/../../.."
-        + f" AOT_TEST_ROOT={makefile_dir}"
-        + " FVP_DIR=/opt/arm/fvp/Base_RevC_AEMvA_pkg/models/Linux64_GCC-9.3/"
-    )
-
-    compile_command = f"{make_command} aot_test_runner"
-    popen = subprocess.Popen(compile_command, cwd=build_path, shell=True, stdout=subprocess.PIPE)
-    return_code = popen.wait()
-    assert not return_code, "Failed to compile"
-
-    run_command = f"{make_command} run"
-    popen = subprocess.Popen(run_command, cwd=build_path, shell=True, stdout=subprocess.PIPE)
-    return_code = popen.wait()
-    assert not return_code, "Failed to run"
-
-
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/meta_schedule/test_meta_schedule_postproc_rewrite_tensorize.py b/tests/python/meta_schedule/test_meta_schedule_postproc_rewrite_tensorize.py
index 8cc1c7c7aa..1272b35451 100644
--- a/tests/python/meta_schedule/test_meta_schedule_postproc_rewrite_tensorize.py
+++ b/tests/python/meta_schedule/test_meta_schedule_postproc_rewrite_tensorize.py
@@ -18,7 +18,7 @@
 import tvm
 from tvm import meta_schedule as ms
 from tvm.script import tir as T
-from tvm.tir.tensor_intrin import arm_cpu, cuda, rocm, x86
+from tvm.tir.tensor_intrin import cuda, rocm, x86
 
 
 @tvm.script.ir_module
diff --git a/tests/python/relay/strategy/arm_cpu/scalable_utils.py b/tests/python/relay/strategy/arm_cpu/scalable_utils.py
new file mode 100644
index 0000000000..ad16a47612
--- /dev/null
+++ b/tests/python/relay/strategy/arm_cpu/scalable_utils.py
@@ -0,0 +1,53 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+from tvm.tir.stmt_functor import post_order_visit, ir_transform
+
+
+def calculate_extra_workspace_size_from_scalable_extents(func, known_vscale_value):
+    """
+    The AOT executor needs to know the size of the workspace ahead of time, but this
+    isn't possible when some allocations are scalable (vscale is not known at compile-time).
+    If we know the target hardware, we can reason about the value of vscale ahead of time.
+    This function will calculate an upper-bound for the extra workspace bytes required by the
+    AOT executor given a TIR function and a known value for vscale.
+    """
+    extra_workspace_bytes = 0
+    is_scalable_extent = False
+    ana = tvm.arith.Analyzer()
+
+    def replace_vscale_with_known_value(stmt):
+        nonlocal is_scalable_extent
+        if isinstance(stmt, tvm.tir.expr.Call) and stmt.op.name == "tir.vscale":
+            is_scalable_extent = True
+            return tvm.tir.IntImm(stmt.dtype, known_vscale_value)
+
+    def calculate_workspace_bytes(stmt):
+        nonlocal extra_workspace_bytes, is_scalable_extent
+        if isinstance(stmt, tvm.tir.stmt.Allocate):
+            for extent in stmt.extents:
+                extent_stmt = tvm.tir.Evaluate(extent)
+                is_scalable_extent = False
+                mutated_extent = ir_transform(extent_stmt, replace_vscale_with_known_value, None)
+                # Non scalable extents are already included in the calculation by AOT
+                if is_scalable_extent:
+                    alloc_bytes = ana.simplify(mutated_extent.value) * tvm.DataType(stmt.dtype).bits
+                    extra_workspace_bytes += alloc_bytes
+
+    post_order_visit(func.body, calculate_workspace_bytes)
+    return extra_workspace_bytes
diff --git a/tests/python/relay/strategy/arm_cpu/test_dense_dsp.py b/tests/python/relay/strategy/arm_cpu/test_dense.py
similarity index 50%
rename from tests/python/relay/strategy/arm_cpu/test_dense_dsp.py
rename to tests/python/relay/strategy/arm_cpu/test_dense.py
index abd3ac4a3f..b9384e532e 100644
--- a/tests/python/relay/strategy/arm_cpu/test_dense_dsp.py
+++ b/tests/python/relay/strategy/arm_cpu/test_dense.py
@@ -14,14 +14,24 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+import pytest
 import numpy as np
+
 import tvm
 import tvm.testing
 from tvm import relay
-from tvm.testing.aot import AOTTestModel, compile_and_run, generate_ref_data
-from tvm.micro.testing.aot_test_utils import (
-    AOT_CORSTONE300_RUNNER,
+from tvm import meta_schedule
+from tvm.testing.aot import (
+    AOTTestModel,
+    AOTCompiledTestModel,
+    compile_and_run,
+    run_and_check,
+    generate_ref_data,
 )
+from tvm.micro.testing.aot_test_utils import AOT_CORSTONE300_RUNNER, AOT_APROFILE_AEM_RUNNER
+from tvm.target.codegen import llvm_version_major
+from tvm.relay.op.strategy.arm_cpu import arm_cpu_tir_strategy
+from scalable_utils import calculate_extra_workspace_size_from_scalable_extents
 
 
 class BasicDenseTests:
@@ -84,5 +94,80 @@ class TestDense(BasicDenseTests):
     enable_bias = tvm.testing.parameter(False, True)
 
 
+@pytest.mark.skipif(
+    llvm_version_major() < 17, reason="SME is not supported in earlier versions of LLVM"
+)
+@tvm.testing.requires_aprofile_aem_fvp
+@pytest.mark.parametrize(
+    "data_shape,weight_shape",
+    [
+        ((32, 32), (32, 32)),
+        ((2, 35), (6, 35)),
+        ((3, 3), (68, 3)),
+        ((79, 65), (152, 65)),
+    ],
+)
+@pytest.mark.parametrize("dtype", ["float32"])
+def test_sme_dense(data_shape, weight_shape, dtype):
+    np.random.seed(0)
+
+    input_data = np.random.uniform(size=data_shape).astype(dtype)
+    inp = relay.var("data", shape=data_shape, dtype=dtype)
+    weight_data = np.random.uniform(size=weight_shape).astype(dtype)
+    weight = relay.const(weight_data, dtype=dtype)
+
+    dense = relay.nn.dense(inp, weight)
+    func = relay.Function(relay.analysis.free_vars(dense), dense)
+
+    ir_mod = tvm.IRModule.from_expr(func)
+    ir_mod = tvm.relay.transform.InferType()(ir_mod)
+
+    inputs = {"data": input_data}
+    params = {}
+    ref_outputs = generate_ref_data(ir_mod, inputs, params)
+
+    target = tvm.target.Target("llvm -mtriple=aarch64-none-elf -mattr=+v9.2a,+sme")
+    runtime = tvm.relay.backend.Runtime("crt", {"system-lib": True})
+    executor = tvm.relay.backend.Executor(
+        "aot",
+        {
+            "interface-api": "packed",
+            "unpacked-api": False,
+        },
+    )
+
+    with tvm.transform.PassContext(
+        opt_level=3, config=AOT_APROFILE_AEM_RUNNER.pass_config
+    ), meta_schedule.database.ScheduleFnDatabase(arm_cpu_tir_strategy):
+        executor_factory = tvm.relay.build(
+            ir_mod,
+            target=target,
+            executor=executor,
+            runtime=runtime,
+            params=params,
+        )
+    generated_func = executor_factory.lowered_ir_mods.items()[0][1][
+        "tvmgen_default_fused_nn_matmul"
+    ]
+    extra_memory_in_bytes = calculate_extra_workspace_size_from_scalable_extents(generated_func, 4)
+
+    test_model = AOTTestModel(
+        ir_mod, inputs, ref_outputs, params=params, extra_memory_in_bytes=extra_memory_in_bytes
+    )
+    compiled = AOTCompiledTestModel(test_model, executor_factory)
+
+    assembly = (
+        compiled.executor_factory.module.imported_modules[0].imported_modules[0].get_source("asm")
+    )
+    assert "fmopa" in assembly
+
+    assert run_and_check(
+        models=[compiled],
+        interface_api="packed",
+        runner=AOT_APROFILE_AEM_RUNNER,
+        print_output_on_mismatch=True,
+    )
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/relay/strategy/arm_cpu/test_matmul.py b/tests/python/relay/strategy/arm_cpu/test_matmul.py
new file mode 100644
index 0000000000..3b46c8019a
--- /dev/null
+++ b/tests/python/relay/strategy/arm_cpu/test_matmul.py
@@ -0,0 +1,118 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import pytest
+import numpy as np
+
+import tvm
+from tvm import relay
+from tvm import meta_schedule
+from tvm.testing.aot import (
+    AOTTestModel,
+    AOTCompiledTestModel,
+    run_and_check,
+    generate_ref_data,
+)
+from tvm.micro.testing.aot_test_utils import AOT_APROFILE_AEM_RUNNER
+from tvm.target.codegen import llvm_version_major
+from tvm.relay.op.strategy.arm_cpu import arm_cpu_tir_strategy
+from scalable_utils import calculate_extra_workspace_size_from_scalable_extents
+
+
+@pytest.mark.skipif(
+    llvm_version_major() < 17, reason="SME is not supported in earlier versions of LLVM"
+)
+@tvm.testing.requires_aprofile_aem_fvp
+@pytest.mark.parametrize(
+    "data_shape,weight_shape,transpose_a,transpose_b",
+    [
+        ((4, 63), (63, 10), False, False),
+        ((64, 32), (32, 32), False, True),
+        ((96, 64), (64, 32), False, False),
+        ((62, 3), (3, 3), False, False),
+        ((4, 5), (79, 5), False, True),
+        ((134, 36), (36, 111), False, False),
+        ((3, 10), (10, 72), False, False),
+        # Tensorization does not work when the reduction axis has unit iters.
+        # See https://github.com/apache/tvm/issues/16566
+        # ((5, 1), (1, 5), False, False),
+    ],
+)
+@pytest.mark.parametrize("dtype", ["float32"])
+def test_sme_matmul_with_const_b(data_shape, weight_shape, transpose_a, transpose_b, dtype):
+    """
+    Execution tests for matmul Scalable Matrix Extension (SME) schedule.
+    """
+    np.random.seed(0)
+
+    input_data = np.random.uniform(size=data_shape).astype(dtype)
+    inp = relay.var("data", shape=data_shape, dtype=dtype)
+    weight_data = np.random.uniform(size=weight_shape).astype(dtype)
+    weight = relay.const(weight_data, dtype=dtype)
+
+    matmul = relay.nn.matmul(inp, weight, transpose_a=transpose_a, transpose_b=transpose_b)
+    func = relay.Function(relay.analysis.free_vars(matmul), matmul)
+
+    ir_mod = tvm.IRModule.from_expr(func)
+    ir_mod = tvm.relay.transform.InferType()(ir_mod)
+
+    inputs = {"data": input_data}
+    params = {}
+    ref_outputs = generate_ref_data(ir_mod, inputs, params)
+
+    target = tvm.target.Target("llvm -mtriple=aarch64-none-elf 
-mattr=+v9.2a,+sme")
+    runtime = tvm.relay.backend.Runtime("crt", {"system-lib": True})
+    executor = tvm.relay.backend.Executor(
+        "aot",
+        {
+            "interface-api": "packed",
+            "unpacked-api": False,
+        },
+    )
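+    # Build with arm_cpu_tir_strategy supplied as the schedule database so the
+    # SME TIR schedule is applied during lowering.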
+    with tvm.transform.PassContext(
+        opt_level=3, config=AOT_APROFILE_AEM_RUNNER.pass_config
+    ), meta_schedule.database.ScheduleFnDatabase(arm_cpu_tir_strategy):
+        executor_factory = tvm.relay.build(
+            ir_mod,
+            target=target,
+            executor=executor,
+            runtime=runtime,
+            params=params,
+        )
+    generated_func = executor_factory.lowered_ir_mods.items()[0][1][
+        "tvmgen_default_fused_nn_matmul"
+    ]
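+    # The extra workspace is sized assuming vscale = 4, i.e. the 512-bit
+    # vector length the AEM FVP is expected to be configured with.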
+    extra_memory_in_bytes = calculate_extra_workspace_size_from_scalable_extents(generated_func, 4)
+
+    test_model = AOTTestModel(
+        ir_mod, inputs, ref_outputs, params=params, extra_memory_in_bytes=extra_memory_in_bytes
+    )
+    compiled = AOTCompiledTestModel(test_model, executor_factory)
+
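+    # Inspect the generated assembly for the SME outer product instruction
+    # before running on the FVP.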
+    assembly = executor_factory.module.imported_modules[0].imported_modules[0].get_source("asm")
+    assert "fmopa" in assembly
+
+    assert run_and_check(
+        models=[compiled],
+        interface_api="packed",
+        runner=AOT_APROFILE_AEM_RUNNER,
+        print_output_on_mismatch=True,
+    )
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
diff --git a/tests/python/relay/strategy/test_select_implementation.py b/tests/python/relay/strategy/test_select_implementation.py
index d0767175d3..71dd688e29 100644
--- a/tests/python/relay/strategy/test_select_implementation.py
+++ b/tests/python/relay/strategy/test_select_implementation.py
@@ -258,18 +258,23 @@ def test_int8_depthwise_conv2d(target, expected_impl):
 
 @pytest.mark.parametrize(
     "target,expected_valid_impl,expected_impl",
-    [("llvm -device=arm_cpu", ["dense_pack.x86", "dense_nopack.x86"], 
"dense_pack.x86")],
+    [
+        (
+            "llvm -device=arm_cpu",
+            ["dense_pack.x86", "dense_nopack.x86"],
+            "dense_pack.x86",
+        ),
+    ],
 )
 def test_dense(target, expected_valid_impl, expected_impl):
     target = tvm.target.Target(target)
-
     data_shape = (30, 40)
     weight_shape = (30, 40)
     dtype = "float32"
 
     out = relay.nn.dense(
         relay.var("data", shape=data_shape, dtype=dtype),
-        relay.var("weight", shape=weight_shape, dtype=dtype),
+        relay.const(np.zeros(weight_shape).astype(dtype)),
         out_dtype=dtype,
     )
     out = run_infer_type(out)
@@ -284,7 +289,51 @@ def test_dense(target, expected_valid_impl, expected_impl):
         ]
         valid_impl = relay.backend.te_compiler.get_valid_implementations(*args)
        selected_impl, _ = relay.backend.te_compiler.select_implementation(*args, use_autotvm=False)
+    assert len(valid_impl) == len(expected_valid_impl)
+    for impl in valid_impl:
+        assert impl.name in expected_valid_impl
+    assert selected_impl.name == expected_impl
 
+
+@pytest.mark.skipif(llvm_version_major() < 15, reason="Older versions of LLVM don't support SME.")
+@pytest.mark.parametrize(
+    "shape,expected_valid_impl,expected_impl",
+    [
+        (
+            (30, 40),
+            ["matmul.arm_cpu.sme", "dense_pack.x86", "dense_nopack.x86"],
+            "matmul.arm_cpu.sme",
+        ),
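+        # The reduction axis has unit iters here, so the SME matmul schedule
+        # is not valid (see https://github.com/apache/tvm/issues/16566) and
+        # only the x86 dense implementations remain.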
+        (
+            (5, 1),
+            ["dense_pack.x86", "dense_nopack.x86"],
+            "dense_pack.x86",
+        ),
+    ],
+)
+def test_dense_with_sme_target(shape, expected_valid_impl, expected_impl):
+    target = tvm.target.Target("llvm -mtriple=aarch64-linux-gnu 
-mattr=+v9.2a,+sme")
+    data_shape = shape
+    weight_shape = shape
+    dtype = "float32"
+
+    out = relay.nn.dense(
+        relay.var("data", shape=data_shape, dtype=dtype),
+        relay.const(np.zeros(weight_shape).astype(dtype)),
+        out_dtype=dtype,
+    )
+    out = run_infer_type(out)
+
+    with target:
+        args = [
+            out.op,
+            out.attrs,
+            [te.placeholder(data_shape, dtype), te.placeholder(weight_shape, dtype)],
+            out.checked_type,
+            target,
+        ]
+        valid_impl = relay.backend.te_compiler.get_valid_implementations(*args)
+        selected_impl, _ = relay.backend.te_compiler.select_implementation(*args, use_autotvm=False)
     assert len(valid_impl) == len(expected_valid_impl)
     for impl in valid_impl:
         assert impl.name in expected_valid_impl
diff --git a/tests/python/relay/test_pass_alter_op_layout.py b/tests/python/relay/test_pass_alter_op_layout.py
index 831070299f..f74b31157a 100644
--- a/tests/python/relay/test_pass_alter_op_layout.py
+++ b/tests/python/relay/test_pass_alter_op_layout.py
@@ -23,6 +23,7 @@ from tvm import relay, topi
 from tvm.relay import transform, analysis
 from tvm.relay.testing.temp_op_attr import TempOpAttr
 from tvm.relay.testing import run_infer_type
+from tvm.target.codegen import llvm_version_major
 import numpy as np
 import tvm.testing
 from tvm.relay import testing
@@ -1451,6 +1452,61 @@ def test_alter_op_dense_packed_data():
             assert tvm.ir.structural_equal(a, b)
 
 
+@pytest.mark.skipif(
+    llvm_version_major() < 17, reason="SME is not supported in earlier versions of LLVM"
+)
+def test_alter_op_dense_arm_cpu_sme():
+    np.random.seed(0)
+    y_data = np.random.uniform(size=(64, 32)).astype("float32")
+
+    def before():
+        x = relay.var("x", shape=(32, 32), dtype="float32")
+        y = relay.const(y_data, dtype="float32")
+        dense = relay.nn.dense(x, y)
+        return relay.Function(analysis.free_vars(dense), dense)
+
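+    # nn.dense implies a transposed weight matrix; the alter op is expected to
+    # fold that transpose into the constant and emit a plain nn.matmul.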
+    def expected():
+        x = relay.var("x", shape=(32, 32), dtype="float32")
+        y = relay.const(y_data.transpose(), dtype="float32")
+        matmul = relay.nn.matmul(x, y)
+        return relay.Function(analysis.free_vars(matmul), matmul)
+
+    with tvm.target.Target("llvm -mtriple=aarch64-linux-gnu 
-mattr=+v9.2a,+sme"):
+        with TempOpAttr("nn.dense", "FTVMAlterOpLayout", 
topi.arm_cpu.dense_alter_op._alter_dense):
+            a = run_opt_pass(before(), transform.AlterOpLayout())
+            b = run_opt_pass(expected(), transform.InferType())
+            assert tvm.ir.structural_equal(a, b)
+
+
+@pytest.mark.skipif(
+    llvm_version_major() < 17, reason="SME is not supported in earlier versions of LLVM"
+)
+@pytest.mark.parametrize(
+    "transpose_b,transform_b", [(False, lambda x: x), (True, lambda x: 
x.transpose())]
+)
+def test_alter_op_matmul_arm_cpu_sme(transpose_b, transform_b):
+    np.random.seed(0)
+    y_data = np.random.uniform(size=(64, 32)).astype("float32")
+
+    def before():
+        x = relay.var("x", shape=(96, 32), dtype="float32")
+        y = relay.const(y_data, dtype="float32")
+        matmul = relay.nn.matmul(x, y, transpose_a=False, transpose_b=transpose_b)
+        return relay.Function(analysis.free_vars(matmul), matmul)
+
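+    # When transpose_b=True, the transpose is expected to be folded into the
+    # constant so that the resulting matmul has both transpose flags unset.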
+    def expected():
+        x = relay.var("x", shape=(96, 32), dtype="float32")
+        y = relay.const(transform_b(y_data), dtype="float32")
+        matmul = relay.nn.matmul(x, y)
+        return relay.Function(analysis.free_vars(matmul), matmul)
+
+    with tvm.target.Target("llvm -mtriple=aarch64-linux-gnu 
-mattr=+v9.2a,+sme"):
+        with TempOpAttr("nn.dense", "FTVMAlterOpLayout", 
topi.arm_cpu.dense_alter_op._alter_dense):
+            a = run_opt_pass(before(), transform.AlterOpLayout())
+            b = run_opt_pass(expected(), transform.InferType())
+            assert tvm.ir.structural_equal(a, b)
+
+
 def test_conv2d_strided_slice_packed_to_unpacked():
     """We do not support propagating through packed to unpacked layout"""
     x_shape = (1, 1, 1, 1, 4)
diff --git a/tests/python/topi/test_topi_matmul.py b/tests/python/topi/test_topi_matmul.py
index 4b05dd3813..a7b3965aee 100644
--- a/tests/python/topi/test_topi_matmul.py
+++ b/tests/python/topi/test_topi_matmul.py
@@ -14,12 +14,16 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+
+import pytest
 import numpy as np
+
 import tvm
 import tvm.testing
 from tvm import te
 from tvm import topi
 from tvm.topi.utils import get_const_tuple
+from tvm.topi.arm_cpu.matmul import compute_matmul_sme
 
 
 def with_tvm(lam, *args):
@@ -148,7 +152,17 @@ def test_tensordot():
     verify_tensordot((4, 3, 2, 2), (2, 4, 3, 5), ((1, 2, 0), (2, 0, 1)))
 
 
+@pytest.mark.parametrize("transpose_a,transpose_b", [(True, False), (False, 
True)])
+def test_unsupported_sme_matmul_compute_transpose(transpose_a, transpose_b):
+    """
+    SME matmul compute does not support transposed inputs for now.
+    """
+    err_msg = "Compute definition currently does not support transposed 
inputs."
+    with pytest.raises(AssertionError, match=err_msg):
+        compute_matmul_sme(
+            te.placeholder((32, 32)), te.placeholder((32, 32)), None, None, transpose_a, transpose_b
+        )
+
+
 if __name__ == "__main__":
-    test_nn_matmul()
-    test_matmul()
-    test_tensordot()
+    tvm.testing.main()
