This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch revert-17003-sme-conv2d-fp32
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit b71a9a3827d81ac17da5f5bc608583f1a02bd0d8
Author: Tianqi Chen <tqc...@users.noreply.github.com>
AuthorDate: Tue May 28 19:56:33 2024 -0400

    Revert "[SME][TOPI] Add conv2d NHWC SME fp32 schedule (#17003)"
    
    This reverts commit cab54e0dee82f84d94cd65f8fe0432ee1c2f2e22.
---
 python/tvm/relay/op/strategy/arm_cpu.py            |  15 --
 python/tvm/testing/utils.py                        |   7 -
 python/tvm/topi/arm_cpu/arm_utils.py               |  18 +-
 python/tvm/topi/arm_cpu/conv2d.py                  | 238 +--------------------
 python/tvm/topi/arm_cpu/conv2d_gemm.py             |  12 +-
 python/tvm/topi/nn/conv2d.py                       |   6 +-
 src/arith/scalable_expression.cc                   |   7 +
 tests/python/arith/test_arith_simplify.py          |  10 +
 .../python/codegen/test_target_codegen_aarch64.py  |  69 +-----
 tests/python/relay/strategy/arm_cpu/test_conv2d.py | 138 +-----------
 .../relay/strategy/test_select_implementation.py   |   8 -
 tests/python/topi/test_topi_conv2d_nhwc.py         |  52 +----
 12 files changed, 45 insertions(+), 535 deletions(-)

diff --git a/python/tvm/relay/op/strategy/arm_cpu.py b/python/tvm/relay/op/strategy/arm_cpu.py
index 12f19462f7..5e94b38772 100644
--- a/python/tvm/relay/op/strategy/arm_cpu.py
+++ b/python/tvm/relay/op/strategy/arm_cpu.py
@@ -253,18 +253,6 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
                         )
                 # Non-quantized cases
                 if is_aarch64 and data.dtype in ["float32", "float16"]:
-                    if (
-                        target.features.has_sme
-                        and data.dtype in ["float32"]
-                        and kernel.dtype in ["float32"]
-                        and out_type.dtype in ["float32"]
-                    ):
-                        strategy.add_implementation(
-                            wrap_compute_conv2d(topi.arm_cpu.compute_conv2d_NHWC_hybrid_SME),
-                            lambda: None,
-                            name="conv2d_NHWC_hybrid_SME.arm_cpu",
-                            plevel=12,
-                        )
                     if target.features.has_sve:
                         # This strategy is currently suboptimal because of LLVM's limited support
                         # for scalable vector alias analysis, which causes redundant loads / stores
@@ -818,9 +806,6 @@ def arm_cpu_tir_strategy(sch: tir.Schedule) -> bool:
     if matmul_block and sch.get(matmul_block).annotations.get("schedule_type", "") == "sme":
         topi.arm_cpu.matmul.tir_schedule_matmul_sme(sch)
         return True
-    elif has_block(sch, "conv2d_gemm_output"):
-        topi.arm_cpu.schedule_conv2d_NHWC_hybrid_TIR(sch)
-        return True
 
     # Fallback to TE schedule for operators we have not written a special TIR schedule for
     return False
diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py
index a208459dd8..84b631cf38 100644
--- a/python/tvm/testing/utils.py
+++ b/python/tvm/testing/utils.py
@@ -1071,13 +1071,6 @@ requires_aarch64_sve = Feature(
 )
 
 
-requires_aarch64_sme = Feature(
-    "arm_sme",
-    "AArch64 SME",
-    run_time_check=lambda: _has_cpu_feat("sme"),
-)
-
-
 requires_x86_vnni = Feature(
     "x86_vnni",
     "x86 VNNI Extensions",
diff --git a/python/tvm/topi/arm_cpu/arm_utils.py b/python/tvm/topi/arm_cpu/arm_utils.py
index 5c4b3c0456..f2e01c5aef 100644
--- a/python/tvm/topi/arm_cpu/arm_utils.py
+++ b/python/tvm/topi/arm_cpu/arm_utils.py
@@ -22,7 +22,7 @@ from tvm.target import Target
 from tvm.tir.expr import PrimExpr
 
 
-def get_tiling_A(interleave_A, in_dtype, use_sme=False):
+def get_tiling_A(interleave_A, in_dtype):
     """Compute the tiling information for matrix A in C=A*B,
     which corresponds to the im2col-transformed input matrix.
 
@@ -42,8 +42,6 @@ def get_tiling_A(interleave_A, in_dtype, use_sme=False):
         determines if A is expected to be interleaved
     in_dtype : str
         input datatype
-    use_sme : bool
-        determines if SME operations on scalable vectors are expected
 
     Returns
     ----------
@@ -67,11 +65,8 @@ def get_tiling_A(interleave_A, in_dtype, use_sme=False):
             # tile size should be 4x16
             tile_M = 4
             tile_K = 16
-    elif use_sme:
-        tile_M = 2 * 4 * tvm.tir.vscale()
-        tile_K = 2 * 4 * tvm.tir.vscale()
     else:
-        # In non-SME, non-quantized cases, A is not interleaved.
+        # In non-quantized cases, A is not interleaved.
         # We are loading 4 rows from A.
         # Each row will contain 4 elements, along the dimension of reduction
         tile_M = 4
@@ -80,7 +75,7 @@ def get_tiling_A(interleave_A, in_dtype, use_sme=False):
     return tile_M, tile_K
 
 
-def get_tiling_B_transformed(interleave_A, in_dtype, use_scalable_vectors=False, use_sme=False):
+def get_tiling_B_transformed(interleave_A, in_dtype, use_scalable_vectors=False):
     """Compute the tiling information for matrix B', where B'
     is the tiled, interleaved (and transposed) version of matrix B in C=A*B.
 
@@ -102,8 +97,6 @@ def get_tiling_B_transformed(interleave_A, in_dtype, use_scalable_vectors=False,
         input datatype
     use_scalable_vectors : bool
         determines if operations on scalable vectors are expected
-    use_sme : bool
-        determines if SME operations on scalable vectors are expected
 
 
     Returns
@@ -138,10 +131,7 @@ def get_tiling_B_transformed(interleave_A, in_dtype, use_scalable_vectors=False,
             # we load 4 rows of B' (i.e., 4 columns of B). Each of them will contain 16 elements
             tile_N = 4
             tile_K = 16
-    elif use_sme:
-        tile_N = 2 * 4 * tvm.tir.vscale()
-        tile_K = 2 * 4 * tvm.tir.vscale()
-    # In non-SME, non-quantized cases, A is not interleaved.
+    # In non-quantized cases, A is not interleaved.
     elif use_scalable_vectors:
         if in_dtype == "float16":
             # Each load from B' contains 32 * vscale elements (i.e. 32 * vscale columns from B)
diff --git a/python/tvm/topi/arm_cpu/conv2d.py b/python/tvm/topi/arm_cpu/conv2d.py
index 58c909301e..44c4f7f76f 100644
--- a/python/tvm/topi/arm_cpu/conv2d.py
+++ b/python/tvm/topi/arm_cpu/conv2d.py
@@ -21,15 +21,13 @@ from __future__ import absolute_import as _abs
 import tvm
 from tvm import te
 from tvm import autotvm
-from tvm.script import tir as T
 import tvm.contrib.nnpack
-from tvm.tir.schedule.analysis import has_block
 
 from ..utils import traverse_inline, get_const_tuple
 from .. import nn
 from ..nn.utils import get_const_int, get_pad_tuple
 from ..nn.winograd_util import winograd_transform_matrices
-from .arm_utils import get_tiling_A, get_tiling_B_transformed
+from .arm_utils import get_tiling_B_transformed
 from .conv2d_spatial_pack import (
     conv2d_spatial_pack_nchw,
     conv2d_spatial_pack_nhwc,
@@ -529,16 +527,13 @@ def compute_conv2d_NHWC(
     out_dtype,
     interleave_A,
     use_scalable_vectors=False,
-    use_sme=False,
 ):
     """Compute definition for conv2d NHWC"""
     N, IH, IW, IC = get_const_tuple(data.shape)
     KH, KW, _, OC = get_const_tuple(kernel.shape)
-    tile_N, tile_K = get_tiling_B_transformed(
-        interleave_A, data.dtype, use_scalable_vectors, use_sme
-    )
+    tile_N, tile_K = get_tiling_B_transformed(interleave_A, data.dtype, use_scalable_vectors)
 
-    kernel = nn.conv2d_gemm_weight_transform(kernel, tile_N, tile_K, use_scalable_vectors, use_sme)
+    kernel = nn.conv2d_gemm_weight_transform(kernel, tile_N, tile_K, use_scalable_vectors)
     return compute_conv2d_gemm_without_weight_transform(
         cfg,
         data,
@@ -551,7 +546,6 @@ def compute_conv2d_NHWC(
         OC,
         interleave_A,
         use_scalable_vectors,
-        use_sme,
     )
 
 
@@ -661,229 +655,3 @@ def compute_conv2d_NHWC_hybrid_SVE(cfg, data, kernel, strides, padding, dilation
 def schedule_conv2d_NHWC_hybrid_SVE(cfg, outs):
     """Interface for hybrid schedule_conv2d_NHWC_hybrid_SVE"""
     return schedule_conv2d_NHWC(cfg, outs, False)
-
-
-@autotvm.register_topi_compute("conv2d_NHWC_hybrid_SME.arm_cpu")
-def compute_conv2d_NHWC_hybrid_SME(cfg, data, kernel, strides, padding, dilation, out_dtype):
-    """Interface for hybrid compute_conv2d_NHWC_hybrid_SME"""
-    return compute_conv2d_NHWC(
-        cfg,
-        data,
-        kernel,
-        strides,
-        padding,
-        dilation,
-        out_dtype,
-        False,
-        True,
-        True,
-    )
-
-
-def schedule_conv2d_NHWC_hybrid_TIR(sch: tvm.tir.Schedule):
-    """
-    Perform TIR scheduling for conv2d NHWC.
-    """
-    # Get ordered buffer list
-    primfunc = sch.mod["main"]
-    buffer_names = primfunc.params
-    buffer_list = [primfunc.buffer_map[buf] for buf in buffer_names]
-    dtype = buffer_list[0].dtype
-
-    # Determine PrimFunc blocks
-    block_list = [
-        "data_pad",
-        "data_im2col",
-        "T_reshape",
-        "A_padded_K",
-        "A_padded_M",
-        "weight_flatten",
-        "C",
-        "conv2d_gemm_output",
-    ]
-    func_blocks = {}
-    for block in block_list:
-        func_blocks[block] = sch.get_block(block) if has_block(sch, block) else None
-
-    gemm_block = func_blocks["C"]
-    b, m, n, k = sch.get_loops(gemm_block)
-
-    # Get tiling information
-    use_scalable_vectors = sch.get(func_blocks["conv2d_gemm_output"]).annotations[
-        "use_scalable_vectors"
-    ]
-    use_sme = sch.get(func_blocks["conv2d_gemm_output"]).annotations["use_sme"]
-    M_padded = sch.get(m).extent
-    N_padded = sch.get(n).extent
-    K_padded = sch.get(k).extent
-    tile_M, tile_K = get_tiling_A(False, dtype, use_sme)
-    tile_N, _ = get_tiling_B_transformed(False, dtype, use_scalable_vectors, use_sme)
-    tile_M = T.cast(tile_M, M_padded.dtype)
-    tile_N = T.cast(tile_N, N_padded.dtype)
-    tile_K = T.cast(tile_K, K_padded.dtype)
-
-    # GeMM
-    # Compute each tile_M x tile_N tile
-    # By summing up K outer products
-    if use_sme:
-        # pylint: disable=import-outside-toplevel
-        from tvm.topi.arm_cpu.pstate_attributes import SMEAttributes
-        from tvm.tir.tensor_intrin.arm_cpu import (
-            ARM_SME_2SVLx2SVL_TRANSPOSE_INTERLEAVE,
-            ARM_SME_2SVLx2SVL_GEMM_INTERLEAVED_MOPA,
-            ARM_SME_INIT,
-            get_sme_gemm_interleaved_mopa_2svlx2svl_intrin,
-        )
-
-        # Interleave the padded im2col matrix utilizing the matrix tile
-        interleave_t_A_block = sch.cache_read(gemm_block, 0, "global")
-        sch.transform_layout(interleave_t_A_block, ("write", 0), lambda b, m, k: (b, k, m))
-        b, m, k = sch.get_loops(interleave_t_A_block)
-        mo, mi = sch.split(m, factors=(None, tile_M), disable_predication=True)
-        ko, ki = sch.split(k, factors=(None, tile_K), disable_predication=True)
-        sch.parallel(b)
-        sch.reorder(b, ko, mo, ki, mi)
-        sch.tensorize(ki, ARM_SME_2SVLx2SVL_TRANSPOSE_INTERLEAVE)
-
-        # Split and reorder the loops of the GeMM for tensorization
-        b, m, n, k = sch.get_loops(gemm_block)
-        mo, mi = sch.split(m, factors=(None, tile_M), disable_predication=True)
-        no, ni = sch.split(n, factors=(None, tile_N), disable_predication=True)
-        sch.parallel(b)
-        sch.reorder(b, mo, no, mi, ni, k)
-
-        # Tensorize the GeMM output matrix initialization to zero
-        init_block = sch.decompose_reduction(gemm_block, mi)
-        sch.tensorize(sch.get_loops(init_block)[-2], ARM_SME_INIT)
-
-        # Tensorize the GeMM update
-        sme_gemm_interleaved_intrin_name = ARM_SME_2SVLx2SVL_GEMM_INTERLEAVED_MOPA + f"_{K_padded}"
-        tvm.tir.TensorIntrin.register(
-            sme_gemm_interleaved_intrin_name,
-            *get_sme_gemm_interleaved_mopa_2svlx2svl_intrin(K_padded),
-            override=True,
-        )
-        sch.tensorize(mi, sme_gemm_interleaved_intrin_name)
-
-        # Add pstate annotations
-        root_block = sch.get_block("root")
-        sch.annotate(
-            root_block, SMEAttributes.STREAMING_MODE, SMEAttributes.StreamingModeValues.ENABLED
-        )
-        sch.annotate(root_block, SMEAttributes.ZA_STORAGE, SMEAttributes.ZAStorageValues.NEW)
-    elif use_scalable_vectors:
-        mo, mi = sch.split(m, [None, tile_M])
-        no, ni = sch.split(n, [None, tile_N], disable_predication=True)
-        ko, ki = sch.split(k, [None, tile_K])
-        b_mo_fused = sch.fuse(b, mo)
-        sch.parallel(b_mo_fused)
-        sch.reorder(
-            b_mo_fused,
-            no,
-            ko,
-            ki,
-            mi,
-            ni,
-        )
-        sch.vectorize(ni)
-        sch.unroll(mi)
-
-        # GeMM - Init
-        # Initialise an entire GeMM tile at once
-        sch.decompose_reduction(gemm_block, ko)
-    else:
-        mo, mi = sch.split(m, [None, tile_M])
-        no, ni = sch.split(n, [None, tile_N])
-        ko, ki = sch.split(k, [None, tile_K])
-        ni_outer, ni_inner = sch.split(ni, [4, None])
-        b_mo_fused = sch.fuse(b, mo)
-        sch.parallel(b_mo_fused)
-        sch.reorder(
-            b_mo_fused,
-            no,
-            ko,
-            ki,
-            ni_outer,
-            mi,
-            ni_inner,
-        )
-        sch.vectorize(ni_inner)
-        sch.unroll(mi)
-        sch.unroll(ni_outer)
-
-        # GeMM - Init
-        # Initialise an entire GeMM tile at once
-        sch.decompose_reduction(gemm_block, ko)
-
-    # Input padding
-    if func_blocks["data_pad"]:
-        input_padding_block = func_blocks["data_pad"]
-        b, h, w, ic = sch.get_loops(input_padding_block)
-        b_h_fused = sch.fuse(b, h)
-        sch.parallel(b_h_fused)
-
-    # Im2col + padding to tile size
-    # Computed outside GeMM
-    if func_blocks["data_im2col"]:
-        im2col_block = func_blocks["data_im2col"]
-        b1, m1, k1 = sch.get_loops(im2col_block)
-        b_m_fused_1 = sch.fuse(b1, m1)
-        if func_blocks["A_padded_K"]:
-            im2col_pad_K_block = func_blocks["A_padded_K"]
-            b2, m2, k2 = sch.get_loops(im2col_pad_K_block)
-            b_m_fused_2 = sch.fuse(b2, m2)
-            sch.parallel(b_m_fused_2)
-            sch.compute_at(im2col_block, b_m_fused_2)
-            _, k1 = sch.get_loops(sch.get_block("data_im2col"))
-        elif func_blocks["A_padded_M"]:
-            im2col_pad_M_block = func_blocks["A_padded_M"]
-            b2, m2, k2 = sch.get_loops(im2col_pad_M_block)
-            b_m_fused_2 = sch.fuse(b2, m2)
-            sch.parallel(b_m_fused_1)
-            sch.parallel(b_m_fused_2)
-        else:
-            sch.parallel(b_m_fused_1)
-
-        K = sch.get(k1).extent.value
-        if K % 16 == 0:
-            split_factor = 16
-        elif K % 8 == 0:
-            split_factor = 8
-        else:
-            IC = buffer_list[0].shape[3]
-            split_factor = IC
-        k_outer, k_inner = sch.split(k1, [None, split_factor])
-        sch.vectorize(k_inner)
-        sch.unroll(k_outer)
-
-    # Reshape + padding to tile size
-    # Computed inside GeMM
-    elif func_blocks["T_reshape"]:
-        reshape_block = func_blocks["T_reshape"]
-        A_pad_block = func_blocks["A_padded_K"] if func_blocks["A_padded_K"] else None
-        A_pad_block = func_blocks["A_padded_M"] if func_blocks["A_padded_M"] else A_pad_block
-        if use_sme:
-            sch.compute_inline(reshape_block)
-        elif A_pad_block:
-            sch.compute_inline(reshape_block)
-            b, m, k = sch.get_loops(A_pad_block)
-            _, k_inner = sch.split(k, [None, tile_N])
-            sch.vectorize(k_inner)
-            sch.compute_at(A_pad_block, mi)
-        else:
-            sch.compute_at(reshape_block, mi)
-
-    # Weight flattening
-    if func_blocks["weight_flatten"]:
-        weight_flatten_block = func_blocks["weight_flatten"]
-        sch.compute_inline(weight_flatten_block)
-
-    # Conv2d output block
-    output_block = func_blocks["conv2d_gemm_output"]
-    n, h, w, c = sch.get_loops(output_block)
-    n_h_fused = sch.fuse(n, h)
-    _, inner = sch.split(c, [None, 4])
-    sch.vectorize(inner)
-    sch.parallel(n_h_fused)
-
-    return sch
diff --git a/python/tvm/topi/arm_cpu/conv2d_gemm.py b/python/tvm/topi/arm_cpu/conv2d_gemm.py
index 0c3908bb70..5ff2ccb2c1 100644
--- a/python/tvm/topi/arm_cpu/conv2d_gemm.py
+++ b/python/tvm/topi/arm_cpu/conv2d_gemm.py
@@ -68,7 +68,6 @@ def compute_conv2d_gemm_without_weight_transform(
     output_channels,
     interleave_A,
     use_scalable_vectors=False,
-    use_sme=False,
 ):
     """Compute conv2d by transforming the input,
     executing GEMM and transforming the output back"""
@@ -124,12 +123,9 @@ def compute_conv2d_gemm_without_weight_transform(
         )
 
     # Select the tiling strategy for A and B
-    tile_M, tile_K_A = arm_utils.get_tiling_A(interleave_A, in_dtype, use_sme)
+    tile_M, tile_K_A = arm_utils.get_tiling_A(interleave_A, in_dtype)
     tile_N, tile_K_B = arm_utils.get_tiling_B_transformed(
-        interleave_A,
-        in_dtype,
-        use_scalable_vectors,
-        use_sme,
+        interleave_A, in_dtype, use_scalable_vectors
     )
 
     # Pad to tiles (if necessary)
@@ -289,7 +285,7 @@ def compute_conv2d_gemm_without_weight_transform(
                 tvm.tir.const(1, C.dtype) * C[0, M_padded - 1, N_padded - 1]
                 - tvm.tir.const(1, C.dtype) * C[0, M_padded - 1, N_padded - 1]
             )
-    elif use_scalable_vectors or use_sme:
+    elif use_scalable_vectors:
         assert len(B_interleaved_t.shape) == 2
         C = te.compute(
             (batches, M_padded, N_padded),
@@ -337,7 +333,7 @@ def compute_conv2d_gemm_without_weight_transform(
         out_shape,
         lambda b, x, y, z: (C(b, y + OW * x, z) + zero).astype(out_dtype),
         name="conv2d_gemm_output",
-        attrs={"use_scalable_vectors": use_scalable_vectors, "use_sme": use_sme},
+        attrs={"use_scalable_vectors": use_scalable_vectors},
     )
     return out
 
diff --git a/python/tvm/topi/nn/conv2d.py b/python/tvm/topi/nn/conv2d.py
index 8d61c62250..e21c0bd4e1 100644
--- a/python/tvm/topi/nn/conv2d.py
+++ b/python/tvm/topi/nn/conv2d.py
@@ -615,7 +615,7 @@ def conv2d_NCHWc_int8(
     )
 
 
-def conv2d_gemm_weight_transform(kernel, tile_N, tile_K, use_scalable_vectors=False, use_sme=False):
+def conv2d_gemm_weight_transform(kernel, tile_N, tile_K, use_scalable_vectors=False):
     """Weight transformation for winograd
 
     Parameters
@@ -628,8 +628,6 @@ def conv2d_gemm_weight_transform(kernel, tile_N, tile_K, use_scalable_vectors=Fa
         Tile size across K axis of the weight transformation for ConvGemm. (K = KW * KH * IC)
     use_scalable_vectors : bool
         determines if operations on scalable vectors are expected
-    use_sme : bool
-        determines if SME operations on scalable vectors are expected
 
     Returns
     -------
@@ -654,7 +652,7 @@ def conv2d_gemm_weight_transform(kernel, tile_N, tile_K, use_scalable_vectors=Fa
             kernel_flat, pad_before=(0, 0), pad_after=(pad_K, pad_N), name="weight_padding"
         )
 
-    if use_sme or use_scalable_vectors:
+    if use_scalable_vectors:
         return kernel_flat
 
     if kernel.dtype in ["int8", "uint8"]:
diff --git a/src/arith/scalable_expression.cc b/src/arith/scalable_expression.cc
index 5e3a65438d..e5f3bc28ba 100644
--- a/src/arith/scalable_expression.cc
+++ b/src/arith/scalable_expression.cc
@@ -71,8 +71,15 @@ std::optional<int> ExtractVscaleFactor(const PrimExpr& lanes) {
   }
 }
 
+bool IsComparison(const PrimExpr& expr) {
+  return expr->IsInstance<tir::LENode>() || expr->IsInstance<tir::LTNode>() ||
+         expr->IsInstance<tir::GENode>() || expr->IsInstance<tir::GTNode>() ||
+         expr->IsInstance<tir::EQNode>() || expr->IsInstance<tir::NENode>();
+}
+
 bool CanProveVscaleExpressionFromKnownValues(arith::Analyzer* analyzer, const PrimExpr& expr,
                                              const std::vector<unsigned int>& vscale_values) {
+  ICHECK(IsComparison(expr)) << "Expected comparison but got: " << expr;
   bool can_prove_expr = true;
   for (const unsigned int vscale_value : vscale_values) {
     PrimExpr result = SubstituteVScaleWithKnownValue(expr, vscale_value);
diff --git a/tests/python/arith/test_arith_simplify.py b/tests/python/arith/test_arith_simplify.py
index 1a876548af..fd8316d1e0 100644
--- a/tests/python/arith/test_arith_simplify.py
+++ b/tests/python/arith/test_arith_simplify.py
@@ -90,6 +90,16 @@ def test_simplify_vscale_comparison_without_sve_target(capfd):
     assert warning_msg in capture
 
 
+def test_simplify_vscale_non_comparison():
+    ana = tvm.arith.Analyzer()
+    vs = tvm.tir.vscale()
+
+    err_msg = r".*Expected comparison but got: T.vscale\(\) \* 4"
+    with pytest.raises(tvm.TVMError, match=err_msg):
+        with tvm.target.Target("llvm -mtriple=aarch64-linux-gnu -mattr=+sve"):
+            ana.can_prove(vs * 4)
+
+
 def test_regression_simplify_inf_recursion():
     ana = tvm.arith.Analyzer()
     cond = tir.Var("cond", "int32")
diff --git a/tests/python/codegen/test_target_codegen_aarch64.py b/tests/python/codegen/test_target_codegen_aarch64.py
index 77c22761a9..d5446b0b1c 100644
--- a/tests/python/codegen/test_target_codegen_aarch64.py
+++ b/tests/python/codegen/test_target_codegen_aarch64.py
@@ -731,36 +731,20 @@ def test_unsupported_multiple_function_attributes(attr_key, attr_value):
     llvm_version_major() < 15, reason="Test requires an LLVM version of at least 15 to target SVE"
 )
 @pytest.mark.parametrize("dtype", ["float16", "float32"])
-@pytest.mark.parametrize(
-    "conv2d_impl",
-    [
-        (
-            tvm.topi.arm_cpu.compute_conv2d_NHWC_hybrid_SVE,
-            tvm.topi.arm_cpu.schedule_conv2d_NHWC_hybrid_SVE,
-            False,
-        ),
-        (
-            tvm.topi.arm_cpu.compute_conv2d_NHWC_hybrid_SVE,
-            tvm.topi.arm_cpu.schedule_conv2d_NHWC_hybrid_TIR,
-            True,
-        ),
-    ],
-)
-def test_conv2d_sve(dtype, conv2d_impl):
+def test_conv2d_sve(dtype):
     target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve"
 
-    def check_correct_assembly(dtype, compute, schedule, use_tir_schedule):
+    def check_correct_assembly(dtype):
         A = te.placeholder((1, 32, 32, 3), dtype=dtype, name="A")
         W = te.placeholder((3, 3, 3, 8), dtype=dtype, name="B")
         stride = padding = dilation = 1
+
+        compute = tvm.topi.arm_cpu.compute_conv2d_NHWC_hybrid_SVE
+        schedule = tvm.topi.arm_cpu.schedule_conv2d_NHWC_hybrid_SVE
         B = compute(A, W, stride, padding, dilation, dtype)
-        if use_tir_schedule:
-            func = te.create_prim_func([A, W, B])
-            sch = schedule(tvm.tir.Schedule(func))
-            f = tvm.build(sch.mod["main"], target)
-        else:
-            s = schedule([B])
-            f = tvm.build(s, [A, W, B], target)
+        s = schedule([B])
+
+        f = tvm.build(s, [A, W, B], target)
         assembly = f.get_source("asm")
 
         loads = re.findall(r"ld1[r]?[q]?[whdb]\t{\s?z", assembly)
@@ -774,43 +758,6 @@ def test_conv2d_sve(dtype, conv2d_impl):
         assert len(compute_ops) > 0
         assert len(stores) > 0
 
-    with tvm.target.Target(target):
-        check_correct_assembly(dtype, *conv2d_impl)
-
-
-@pytest.mark.skipif(
-    llvm_version_major() < 16, reason="Test requires an LLVM version of at least 16 to target SME"
-)
-@pytest.mark.parametrize("dtype", ["float32"])
-def test_conv2d_sme(dtype):
-    target = "llvm -mtriple=aarch64-linux-gnu -mattr=+v9a,+sme"
-
-    def check_correct_assembly(dtype):
-        A = te.placeholder((1, 32, 32, 3), dtype=dtype, name="A")
-        W = te.placeholder((3, 3, 3, 8), dtype=dtype, name="B")
-        stride = padding = dilation = 1
-
-        B = tvm.topi.arm_cpu.compute_conv2d_NHWC_hybrid_SME(A, W, stride, padding, dilation, dtype)
-        func = te.create_prim_func([A, W, B])
-        sch = tvm.topi.arm_cpu.schedule_conv2d_NHWC_hybrid_TIR(tvm.tir.Schedule(func))
-        f = tvm.build(sch.mod["main"], target)
-
-        assembly = f.get_source("asm")
-        smstart = re.findall(r"smstart\t(sm|za)", assembly)
-        loads = re.findall(r"ld1[whdb]\t{\s?za", assembly)
-        mopa = re.findall(
-            r"fmopa\tza[0-9].[shdb],( p[0-9]/[zm],)?( p[0-9]/[zm],)? z[0-9].[shdb], z[0-9].[shdb]",
-            assembly,
-        )
-        stores = re.findall(r"st1[whdb]\t{\s?za", assembly)
-        smstop = re.findall(r"smstop\t(sm|za)", assembly)
-
-        assert len(smstart) > 0
-        assert len(loads) > 0
-        assert len(mopa) > 0
-        assert len(stores) > 0
-        assert len(smstop) > 0
-
     with tvm.target.Target(target):
         check_correct_assembly(dtype=dtype)
 
diff --git a/tests/python/relay/strategy/arm_cpu/test_conv2d.py b/tests/python/relay/strategy/arm_cpu/test_conv2d.py
index 2708094afb..1b9c1a5e2e 100644
--- a/tests/python/relay/strategy/arm_cpu/test_conv2d.py
+++ b/tests/python/relay/strategy/arm_cpu/test_conv2d.py
@@ -16,21 +16,8 @@
 # under the License.
 """Tests for arm_cpu schedules for regular conv2d."""
 
-import pytest
-import numpy as np
-
-import tvm
-import tvm.topi.testing
-from tvm import relay
 from test_generalized_conv2d import GeneralizedConv2dTests
 from tvm.testing import fixture, main, parameter, parameters
-from tvm.topi.nn.utils import get_pad_tuple
-from tvm.topi.utils import get_const_tuple
-from tvm.target.codegen import llvm_version_major
-from tvm.testing.aot import AOTTestModel, AOTCompiledTestModel, run_and_check, generate_ref_data
-from tvm.micro.testing.aot_test_utils import AOT_APROFILE_AEM_RUNNER
-from tvm.relay.op.strategy.arm_cpu import arm_cpu_tir_strategy
-from scalable_utils import calculate_extra_workspace_size_from_scalable_extents
 
 
 class Conv2dTests(GeneralizedConv2dTests):
@@ -120,128 +107,5 @@ class TestConv2d_NCHW_Spatial_Pack(Conv2dTests):
     schedule_name = parameter("conv2d_nchw_spatial_pack.arm_cpu")
 
 
-dtype = tvm.testing.parameter("float32")
-
-batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation = tvm.testing.parameters(
-    # Pad M, N, K
-    (1, 1, 1, 1, 1, 1, "SAME", 1),
-    (1, 1, 3, 15, 1, 1, "SAME", 1),
-    # Pad M, K
-    (1, 3, 9, 16, 3, 1, "SAME", 1),
-    # Pad M, N
-    (1, 2, 9, 15, 4, 1, "SAME", 1),
-    # Pad K, N
-    (1, 7, 4, 15, 3, 1, "SAME", 1),
-    # Pad M
-    (1, 2, 9, 16, 4, 1, "SAME", 1),
-    # Pad K
-    (1, 7, 4, 16, 3, 1, "SAME", 1),
-    # Pad N
-    (1, 2, 4, 15, 4, 1, "SAME", 1),
-    (1, 2, 4, 20, 1, 1, "SAME", 1),
-    # Large workloads
-    (1, 128, 32, 128, 3, 1, "SAME", 1),
-    (4, 64, 16, 64, 5, 2, "SAME", 1),
-    (1, 128, 32, 128, 3, 1, "VALID", 1),
-    (4, 64, 16, 64, 5, 2, "VALID", 1),
-    (1, 64, 16, 64, 3, 2, (0, 0, 1, 1), 1),
-    (1, 64, 16, 64, 3, 2, (1, 1, 2, 2), 1),
-    (1, 64, 16, 64, 5, 2, (3, 3, 2, 2), 1),
-    (1, 64, 16, 64, 3, 2, (0, 1, 2, 3), 1),
-    (1, 64, 32, 64, 3, 1, "SAME", 2),
-    (1, 64, 32, 64, 3, 1, (1, 1, 2, 2), 2),
-)
-
-
-@tvm.testing.fixture()
-def ref_data(dtype, batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation):
-    np.random.seed(0)
-    in_height = in_width = in_size
-    a_shape = (batch, in_height, in_width, in_channel)
-    w_shape = (kernel, kernel, in_channel, num_filter)
-
-    a_np = np.random.uniform(size=a_shape).astype(dtype)
-    w_np = np.random.uniform(size=w_shape).astype(dtype)
-    return a_np, w_np
-
-
-@pytest.mark.skipif(
-    llvm_version_major() < 16, reason="SME is not supported in earlier versions of LLVM"
-)
-@tvm.testing.requires_aprofile_aem_fvp
-def test_conv2d_fp32(target, ref_data, dtype, stride, padding, dilation):
-    a_np, w_np = ref_data
-    dw_np = tvm.topi.testing.dilate_python(w_np, (dilation, dilation, 1, 1))
-
-    kernel_size = get_const_tuple(w_np.shape[:2])
-    out_channels = w_np.shape[3]
-
-    x = relay.var("data", shape=a_np.shape, dtype=dtype)
-    weight = relay.const(w_np, dtype=dtype)
-    conv2d = relay.nn.conv2d(
-        x,
-        weight,
-        channels=out_channels,
-        kernel_size=kernel_size,
-        strides=stride,
-        dilation=dilation,
-        padding=get_pad_tuple(padding, dw_np.shape[:2]),
-        data_layout="NHWC",
-        kernel_layout="HWIO",
-        out_dtype=dtype,
-    )
-
-    func = relay.Function(relay.analysis.free_vars(conv2d), conv2d)
-
-    ir_mod = tvm.IRModule.from_expr(func)
-    ir_mod = tvm.relay.transform.InferType()(ir_mod)
-
-    inputs = {"data": a_np}
-    params = {}
-    ref_outputs = generate_ref_data(ir_mod, inputs, params)
-
-    target = tvm.target.Target("llvm -mtriple=aarch64-none-elf -mattr=+v9.2a,+sme")
-    runtime = tvm.relay.backend.Runtime("crt", {"system-lib": True})
-    executor = tvm.relay.backend.Executor(
-        "aot",
-        {
-            "interface-api": "packed",
-            "unpacked-api": False,
-        },
-    )
-
-    with tvm.transform.PassContext(
-        opt_level=3, config=AOT_APROFILE_AEM_RUNNER.pass_config
-    ), target, tvm.meta_schedule.database.ScheduleFnDatabase(arm_cpu_tir_strategy):
-        executor_factory = tvm.relay.build(
-            ir_mod,
-            target=target,
-            executor=executor,
-            runtime=runtime,
-            params=params,
-        )
-    generated_func = executor_factory.lowered_ir_mods.items()[0][1][
-        "tvmgen_default_fused_nn_conv2d"
-    ]
-    extra_memory_in_bytes = calculate_extra_workspace_size_from_scalable_extents(generated_func, 4)
-
-    test_model = AOTTestModel(
-        ir_mod, inputs, ref_outputs, params=params, extra_memory_in_bytes=extra_memory_in_bytes
-    )
-    compiled = AOTCompiledTestModel(test_model, executor_factory)
-
-    assembly = (
-        compiled.executor_factory.module.imported_modules[0].imported_modules[0].get_source("asm")
-    )
-    assert "fmopa" in assembly
-
-    assert run_and_check(
-        models=[compiled],
-        interface_api="packed",
-        runner=AOT_APROFILE_AEM_RUNNER,
-        print_output_on_mismatch=True,
-    )
-
-
 if __name__ == "__main__":
-    tvm.testing.main()
+    main()
diff --git a/tests/python/relay/strategy/test_select_implementation.py b/tests/python/relay/strategy/test_select_implementation.py
index 01a914e793..71dd688e29 100644
--- a/tests/python/relay/strategy/test_select_implementation.py
+++ b/tests/python/relay/strategy/test_select_implementation.py
@@ -161,10 +161,6 @@ def test_int8_conv2d(target, expected_impl):
             "llvm --device=arm_cpu --mtriple=aarch64-linux-gnu -mattr=+v9a",
             "conv2d_NHWC_hybrid_without_transform.arm_cpu",
         ),
-        (
-            "llvm --device=arm_cpu --mtriple=aarch64-linux-gnu -mattr=+v9.2a,+sme",
-            "conv2d_NHWC_hybrid_SME.arm_cpu",
-        ),
     ],
 )
 def test_fp32_conv2d(target, expected_impl):
@@ -201,10 +197,6 @@ def test_fp32_conv2d(target, expected_impl):
             "llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+v9a",
             "conv2d_NHWC_hybrid_without_transform.arm_cpu",
         ),
-        (
-            "llvm --device=arm_cpu --mtriple=aarch64-linux-gnu -mattr=+v9.2a,+sme",
-            "conv2d_NHWC_hybrid_without_transform.arm_cpu",
-        ),
     ],
 )
 def test_fp16_conv2d(target, expected_impl):
diff --git a/tests/python/topi/test_topi_conv2d_nhwc.py b/tests/python/topi/test_topi_conv2d_nhwc.py
index 02f16b59c0..b5c9518d34 100644
--- a/tests/python/topi/test_topi_conv2d_nhwc.py
+++ b/tests/python/topi/test_topi_conv2d_nhwc.py
@@ -17,12 +17,10 @@
 """Example code to do convolution."""
 import os
 import platform
-import pytest
 import numpy as np
 import tvm
 from tvm import te
 from tvm import topi
-from tvm.target.codegen import llvm_version_major
 import tvm.topi.testing
 from tvm.contrib.pickle_memoize import memoize
 from tvm.topi.utils import get_const_tuple
@@ -53,37 +51,16 @@ device = tvm.testing.parameter(
         "llvm --device arm_cpu --mtriple aarch64-linux-gnu",
         topi.arm_cpu.conv2d_nhwc_spatial_pack,
         topi.arm_cpu.schedule_conv2d_nhwc_spatial_pack,
-        False,
     ),
     (
         "llvm --device arm_cpu --mtriple aarch64-linux-gnu -mattr=+v8.2a,+fullfp16",
         topi.arm_cpu.compute_conv2d_NHWC_hybrid,
         topi.arm_cpu.schedule_conv2d_NHWC_hybrid,
-        False,
     ),
     (
         "llvm --device arm_cpu --mtriple aarch64-linux-gnu -mattr=+v8.6a,+sve",
         topi.arm_cpu.compute_conv2d_NHWC_hybrid_SVE,
         topi.arm_cpu.schedule_conv2d_NHWC_hybrid_SVE,
-        False,
-    ),
-    (
-        "llvm --device arm_cpu --mtriple aarch64-linux-gnu -mattr=+v8.2a",
-        topi.arm_cpu.compute_conv2d_NHWC_hybrid,
-        topi.arm_cpu.schedule_conv2d_NHWC_hybrid_TIR,
-        True,
-    ),
-    (
-        "llvm --device arm_cpu --mtriple aarch64-linux-gnu -mattr=+v8.6a,+sve",
-        topi.arm_cpu.compute_conv2d_NHWC_hybrid_SVE,
-        topi.arm_cpu.schedule_conv2d_NHWC_hybrid_TIR,
-        True,
-    ),
-    (
-        "llvm --device arm_cpu --mtriple aarch64-linux-gnu -mattr=+v9a,+sme",
-        topi.arm_cpu.compute_conv2d_NHWC_hybrid_SME,
-        topi.arm_cpu.schedule_conv2d_NHWC_hybrid_TIR,
-        True,
     ),
 )
 
@@ -91,7 +68,6 @@ dtype = tvm.testing.parameter("float16", "float32")
 
 batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation = tvm.testing.parameters(
     # Pad M, N, K
-    (1, 1, 1, 1, 1, 1, "SAME", 1),
     (1, 1, 3, 15, 1, 1, "SAME", 1),
     # Pad M, K
     (1, 3, 9, 16, 3, 1, "SAME", 1),
@@ -163,31 +139,16 @@ def test_conv2d_nhwc_gemm(device, ref_data, dtype, stride, padding, dilation):
     A = te.placeholder(a_np.shape, name="A", dtype=dtype)
     W = te.placeholder(w_np.shape, name="W", dtype=dtype)
 
-    target_string, compute, schedule, use_tir_schedule = device
-    dev = tvm.device(target_string, 0)
-    target = tvm.target.Target(target_string)
-
-    if target.features.has_sve and llvm_version_major() < 15:
-        pytest.skip(f"LLVM {llvm_version_major()} does not support targetting SVE.")
-
-    if target.features.has_sme and llvm_version_major() < 16:
-        pytest.skip(f"LLVM {llvm_version_major()} does not support targetting SME.")
-
-    if target.features.has_sme and dtype == "float16":
-        pytest.skip(f"Conv2d fp16 targetting SME not implemented.")
+    target, compute, schedule = device
+    dev = tvm.device(target, 0)
 
-    with target:
+    with tvm.target.Target(target) as target:
+        B = compute(A, W, stride, padding, dilation, dtype)
+        s = schedule([B])
         a = tvm.nd.array(a_np, dev)
         w = tvm.nd.array(w_np, dev)
-        B = compute(A, W, stride, padding, dilation, dtype)
         b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), dev)
-        if use_tir_schedule:
-            primfunc = te.create_prim_func([A, W, B])
-            sch = schedule(tvm.tir.Schedule(primfunc))
-            func = tvm.build(sch.mod["main"], target)
-        else:
-            s = schedule([B])
-            func = tvm.build(s, [A, W, B], target)
+        func = tvm.build(s, [A, W, B], target)
 
         # Run only on AArch64 devices
         # Do not run SVE schedules on non-SVE devices
@@ -199,7 +160,6 @@ def test_conv2d_nhwc_gemm(device, ref_data, dtype, stride, padding, dilation):
                 and target.features.has_fp16_simd
                 and not tvm.testing.requires_arm_fp16.run_time_check()
             )
-            or (target.features.has_sme and not tvm.testing.requires_aarch64_sme.run_time_check())
         )
         if build_only:
             return


Reply via email to