This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch revert-17003-sme-conv2d-fp32
in repository https://gitbox.apache.org/repos/asf/tvm.git
commit b71a9a3827d81ac17da5f5bc608583f1a02bd0d8
Author: Tianqi Chen <tqc...@users.noreply.github.com>
AuthorDate: Tue May 28 19:56:33 2024 -0400

    Revert "[SME][TOPI] Add conv2d NHWC SME fp32 schedule (#17003)"
    
    This reverts commit cab54e0dee82f84d94cd65f8fe0432ee1c2f2e22.
---
 python/tvm/relay/op/strategy/arm_cpu.py            |  15 --
 python/tvm/testing/utils.py                        |   7 -
 python/tvm/topi/arm_cpu/arm_utils.py               |  18 +-
 python/tvm/topi/arm_cpu/conv2d.py                  | 238 +--------------------
 python/tvm/topi/arm_cpu/conv2d_gemm.py             |  12 +-
 python/tvm/topi/nn/conv2d.py                       |   6 +-
 src/arith/scalable_expression.cc                   |   7 +
 tests/python/arith/test_arith_simplify.py          |  10 +
 .../python/codegen/test_target_codegen_aarch64.py  |  69 +-----
 tests/python/relay/strategy/arm_cpu/test_conv2d.py | 138 +-----------
 .../relay/strategy/test_select_implementation.py   |   8 -
 tests/python/topi/test_topi_conv2d_nhwc.py         |  52 +----
 12 files changed, 45 insertions(+), 535 deletions(-)

diff --git a/python/tvm/relay/op/strategy/arm_cpu.py b/python/tvm/relay/op/strategy/arm_cpu.py
index 12f19462f7..5e94b38772 100644
--- a/python/tvm/relay/op/strategy/arm_cpu.py
+++ b/python/tvm/relay/op/strategy/arm_cpu.py
@@ -253,18 +253,6 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
             )
         # Non-quantized cases
         if is_aarch64 and data.dtype in ["float32", "float16"]:
-            if (
-                target.features.has_sme
-                and data.dtype in ["float32"]
-                and kernel.dtype in ["float32"]
-                and out_type.dtype in ["float32"]
-            ):
-                strategy.add_implementation(
-                    wrap_compute_conv2d(topi.arm_cpu.compute_conv2d_NHWC_hybrid_SME),
-                    lambda: None,
-                    name="conv2d_NHWC_hybrid_SME.arm_cpu",
-                    plevel=12,
-                )
            if target.features.has_sve:
                 # This strategy is currently suboptimal because of LLVM's limited support
                 # for scalable vector alias analysis, which causes redundant loads / stores
@@ -818,9 +806,6 @@ def arm_cpu_tir_strategy(sch: tir.Schedule) -> bool:
     if matmul_block and sch.get(matmul_block).annotations.get("schedule_type", "") == "sme":
         topi.arm_cpu.matmul.tir_schedule_matmul_sme(sch)
         return True
-    elif has_block(sch, "conv2d_gemm_output"):
-        topi.arm_cpu.schedule_conv2d_NHWC_hybrid_TIR(sch)
-        return True
 
     # Fallback to TE schedule for operators we have not written a special TIR schedule for
     return False
diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py
index a208459dd8..84b631cf38 100644
--- a/python/tvm/testing/utils.py
+++ b/python/tvm/testing/utils.py
@@ -1071,13 +1071,6 @@ requires_aarch64_sve = Feature(
 )
 
 
-requires_aarch64_sme = Feature(
-    "arm_sme",
-    "AArch64 SME",
-    run_time_check=lambda: _has_cpu_feat("sme"),
-)
-
-
 requires_x86_vnni = Feature(
     "x86_vnni",
     "x86 VNNI Extensions",
diff --git a/python/tvm/topi/arm_cpu/arm_utils.py b/python/tvm/topi/arm_cpu/arm_utils.py
index 5c4b3c0456..f2e01c5aef 100644
--- a/python/tvm/topi/arm_cpu/arm_utils.py
+++ b/python/tvm/topi/arm_cpu/arm_utils.py
@@ -22,7 +22,7 @@ from tvm.target import Target
 from tvm.tir.expr import PrimExpr
 
 
-def get_tiling_A(interleave_A, in_dtype, use_sme=False):
+def get_tiling_A(interleave_A, in_dtype):
     """Compute the tiling information for matrix A in C=A*B,
     which corresponds to the im2col-transformed input matrix.
@@ -42,8 +42,6 @@ def get_tiling_A(interleave_A, in_dtype, use_sme=False):
         determines if A is expected to be interleaved
     in_dtype : str
         input datatype
-    use_sme : bool
-        determines if SME operations on scalable vectors are expected
 
     Returns
     ----------
@@ -67,11 +65,8 @@ def get_tiling_A(interleave_A, in_dtype, use_sme=False):
             # tile size should be 4x16
             tile_M = 4
             tile_K = 16
-    elif use_sme:
-        tile_M = 2 * 4 * tvm.tir.vscale()
-        tile_K = 2 * 4 * tvm.tir.vscale()
     else:
-        # In non-SME, non-quantized cases, A is not interleaved.
+        # In non-quantized cases, A is not interleaved.
         # We are loading 4 rows from A.
         # Each row will contain 4 elements, along the dimension of reduction
         tile_M = 4
         tile_K = 4
 
     return tile_M, tile_K
 
 
-def get_tiling_B_transformed(interleave_A, in_dtype, use_scalable_vectors=False, use_sme=False):
+def get_tiling_B_transformed(interleave_A, in_dtype, use_scalable_vectors=False):
     """Compute the tiling information for matrix B', where B'
     is the tiled, interleaved (and transposed) version of matrix B in C=A*B.
 
@@ -102,8 +97,6 @@ def get_tiling_B_transformed(interleave_A, in_dtype, use_scalable_vectors=False,
         input datatype
     use_scalable_vectors : bool
         determines if operations on scalable vectors are expected
-    use_sme : bool
-        determines if SME operations on scalable vectors are expected
 
 
     Returns
@@ -138,10 +131,7 @@ def get_tiling_B_transformed(interleave_A, in_dtype, use_scalable_vectors=False,
         # we load 4 rows of B' (i.e., 4 columns of B). Each of them will contain 16 elements
         tile_N = 4
         tile_K = 16
-    elif use_sme:
-        tile_N = 2 * 4 * tvm.tir.vscale()
-        tile_K = 2 * 4 * tvm.tir.vscale()
-    # In non-SME, non-quantized cases, A is not interleaved.
+    # In non-quantized cases, A is not interleaved.
     elif use_scalable_vectors:
         if in_dtype == "float16":
             # Each load from B' contains 32 * vscale elements (i.e. 32 * vscale columns from B)
diff --git a/python/tvm/topi/arm_cpu/conv2d.py b/python/tvm/topi/arm_cpu/conv2d.py
index 58c909301e..44c4f7f76f 100644
--- a/python/tvm/topi/arm_cpu/conv2d.py
+++ b/python/tvm/topi/arm_cpu/conv2d.py
@@ -21,15 +21,13 @@ from __future__ import absolute_import as _abs
 import tvm
 from tvm import te
 from tvm import autotvm
-from tvm.script import tir as T
 import tvm.contrib.nnpack
-from tvm.tir.schedule.analysis import has_block
 
 from ..utils import traverse_inline, get_const_tuple
 from .. import nn
 from ..nn.utils import get_const_int, get_pad_tuple
 from ..nn.winograd_util import winograd_transform_matrices
-from .arm_utils import get_tiling_A, get_tiling_B_transformed
+from .arm_utils import get_tiling_B_transformed
 from .conv2d_spatial_pack import (
     conv2d_spatial_pack_nchw,
     conv2d_spatial_pack_nhwc,
@@ -529,16 +527,13 @@ def compute_conv2d_NHWC(
     out_dtype,
     interleave_A,
     use_scalable_vectors=False,
-    use_sme=False,
 ):
     """Compute definition for conv2d NHWC"""
     N, IH, IW, IC = get_const_tuple(data.shape)
     KH, KW, _, OC = get_const_tuple(kernel.shape)
-    tile_N, tile_K = get_tiling_B_transformed(
-        interleave_A, data.dtype, use_scalable_vectors, use_sme
-    )
+    tile_N, tile_K = get_tiling_B_transformed(interleave_A, data.dtype, use_scalable_vectors)
 
-    kernel = nn.conv2d_gemm_weight_transform(kernel, tile_N, tile_K, use_scalable_vectors, use_sme)
+    kernel = nn.conv2d_gemm_weight_transform(kernel, tile_N, tile_K, use_scalable_vectors)
     return compute_conv2d_gemm_without_weight_transform(
         cfg,
         data,
@@ -551,7 +546,6 @@ def compute_conv2d_NHWC(
         OC,
         interleave_A,
         use_scalable_vectors,
-        use_sme,
     )
 
 
@@ -661,229 +655,3 @@ def compute_conv2d_NHWC_hybrid_SVE(cfg, data, kernel, strides, padding, dilation
 def schedule_conv2d_NHWC_hybrid_SVE(cfg, outs):
     """Interface for hybrid schedule_conv2d_NHWC_hybrid_SVE"""
     return schedule_conv2d_NHWC(cfg, outs, False)
-
-
-@autotvm.register_topi_compute("conv2d_NHWC_hybrid_SME.arm_cpu")
-def compute_conv2d_NHWC_hybrid_SME(cfg, data, kernel, strides, padding, dilation, out_dtype):
-    """Interface for hybrid compute_conv2d_NHWC_hybrid_SME"""
-    return compute_conv2d_NHWC(
-        cfg,
-        data,
-        kernel,
-        strides,
-        padding,
-        dilation,
-        out_dtype,
-        False,
-        True,
-        True,
-    )
-
-
-def schedule_conv2d_NHWC_hybrid_TIR(sch: tvm.tir.Schedule):
-    """
-    Perform TIR scheduling for conv2d NHWC.
- """ - # Get ordered buffer list - primfunc = sch.mod["main"] - buffer_names = primfunc.params - buffer_list = [primfunc.buffer_map[buf] for buf in buffer_names] - dtype = buffer_list[0].dtype - - # Determine PrimFunc blocks - block_list = [ - "data_pad", - "data_im2col", - "T_reshape", - "A_padded_K", - "A_padded_M", - "weight_flatten", - "C", - "conv2d_gemm_output", - ] - func_blocks = {} - for block in block_list: - func_blocks[block] = sch.get_block(block) if has_block(sch, block) else None - - gemm_block = func_blocks["C"] - b, m, n, k = sch.get_loops(gemm_block) - - # Get tiling information - use_scalable_vectors = sch.get(func_blocks["conv2d_gemm_output"]).annotations[ - "use_scalable_vectors" - ] - use_sme = sch.get(func_blocks["conv2d_gemm_output"]).annotations["use_sme"] - M_padded = sch.get(m).extent - N_padded = sch.get(n).extent - K_padded = sch.get(k).extent - tile_M, tile_K = get_tiling_A(False, dtype, use_sme) - tile_N, _ = get_tiling_B_transformed(False, dtype, use_scalable_vectors, use_sme) - tile_M = T.cast(tile_M, M_padded.dtype) - tile_N = T.cast(tile_N, N_padded.dtype) - tile_K = T.cast(tile_K, K_padded.dtype) - - # GeMM - # Compute each tile_M x tile_N tile - # By summing up K outer products - if use_sme: - # pylint: disable=import-outside-toplevel - from tvm.topi.arm_cpu.pstate_attributes import SMEAttributes - from tvm.tir.tensor_intrin.arm_cpu import ( - ARM_SME_2SVLx2SVL_TRANSPOSE_INTERLEAVE, - ARM_SME_2SVLx2SVL_GEMM_INTERLEAVED_MOPA, - ARM_SME_INIT, - get_sme_gemm_interleaved_mopa_2svlx2svl_intrin, - ) - - # Interleave the padded im2col matrix utilizing the matrix tile - interleave_t_A_block = sch.cache_read(gemm_block, 0, "global") - sch.transform_layout(interleave_t_A_block, ("write", 0), lambda b, m, k: (b, k, m)) - b, m, k = sch.get_loops(interleave_t_A_block) - mo, mi = sch.split(m, factors=(None, tile_M), disable_predication=True) - ko, ki = sch.split(k, factors=(None, tile_K), disable_predication=True) - sch.parallel(b) - sch.reorder(b, ko, mo, ki, mi) - sch.tensorize(ki, ARM_SME_2SVLx2SVL_TRANSPOSE_INTERLEAVE) - - # Split and reorder the loops of the GeMM for tensorization - b, m, n, k = sch.get_loops(gemm_block) - mo, mi = sch.split(m, factors=(None, tile_M), disable_predication=True) - no, ni = sch.split(n, factors=(None, tile_N), disable_predication=True) - sch.parallel(b) - sch.reorder(b, mo, no, mi, ni, k) - - # Tensorize the GeMM output matrix initialization to zero - init_block = sch.decompose_reduction(gemm_block, mi) - sch.tensorize(sch.get_loops(init_block)[-2], ARM_SME_INIT) - - # Tensorize the GeMM update - sme_gemm_interleaved_intrin_name = ARM_SME_2SVLx2SVL_GEMM_INTERLEAVED_MOPA + f"_{K_padded}" - tvm.tir.TensorIntrin.register( - sme_gemm_interleaved_intrin_name, - *get_sme_gemm_interleaved_mopa_2svlx2svl_intrin(K_padded), - override=True, - ) - sch.tensorize(mi, sme_gemm_interleaved_intrin_name) - - # Add pstate annotations - root_block = sch.get_block("root") - sch.annotate( - root_block, SMEAttributes.STREAMING_MODE, SMEAttributes.StreamingModeValues.ENABLED - ) - sch.annotate(root_block, SMEAttributes.ZA_STORAGE, SMEAttributes.ZAStorageValues.NEW) - elif use_scalable_vectors: - mo, mi = sch.split(m, [None, tile_M]) - no, ni = sch.split(n, [None, tile_N], disable_predication=True) - ko, ki = sch.split(k, [None, tile_K]) - b_mo_fused = sch.fuse(b, mo) - sch.parallel(b_mo_fused) - sch.reorder( - b_mo_fused, - no, - ko, - ki, - mi, - ni, - ) - sch.vectorize(ni) - sch.unroll(mi) - - # GeMM - Init - # Initialise an entire GeMM tile at once - 
-        sch.decompose_reduction(gemm_block, ko)
-    else:
-        mo, mi = sch.split(m, [None, tile_M])
-        no, ni = sch.split(n, [None, tile_N])
-        ko, ki = sch.split(k, [None, tile_K])
-        ni_outer, ni_inner = sch.split(ni, [4, None])
-        b_mo_fused = sch.fuse(b, mo)
-        sch.parallel(b_mo_fused)
-        sch.reorder(
-            b_mo_fused,
-            no,
-            ko,
-            ki,
-            ni_outer,
-            mi,
-            ni_inner,
-        )
-        sch.vectorize(ni_inner)
-        sch.unroll(mi)
-        sch.unroll(ni_outer)
-
-        # GeMM - Init
-        # Initialise an entire GeMM tile at once
-        sch.decompose_reduction(gemm_block, ko)
-
-    # Input padding
-    if func_blocks["data_pad"]:
-        input_padding_block = func_blocks["data_pad"]
-        b, h, w, ic = sch.get_loops(input_padding_block)
-        b_h_fused = sch.fuse(b, h)
-        sch.parallel(b_h_fused)
-
-    # Im2col + padding to tile size
-    # Computed outside GeMM
-    if func_blocks["data_im2col"]:
-        im2col_block = func_blocks["data_im2col"]
-        b1, m1, k1 = sch.get_loops(im2col_block)
-        b_m_fused_1 = sch.fuse(b1, m1)
-        if func_blocks["A_padded_K"]:
-            im2col_pad_K_block = func_blocks["A_padded_K"]
-            b2, m2, k2 = sch.get_loops(im2col_pad_K_block)
-            b_m_fused_2 = sch.fuse(b2, m2)
-            sch.parallel(b_m_fused_2)
-            sch.compute_at(im2col_block, b_m_fused_2)
-            _, k1 = sch.get_loops(sch.get_block("data_im2col"))
-        elif func_blocks["A_padded_M"]:
-            im2col_pad_M_block = func_blocks["A_padded_M"]
-            b2, m2, k2 = sch.get_loops(im2col_pad_M_block)
-            b_m_fused_2 = sch.fuse(b2, m2)
-            sch.parallel(b_m_fused_1)
-            sch.parallel(b_m_fused_2)
-        else:
-            sch.parallel(b_m_fused_1)
-
-        K = sch.get(k1).extent.value
-        if K % 16 == 0:
-            split_factor = 16
-        elif K % 8 == 0:
-            split_factor = 8
-        else:
-            IC = buffer_list[0].shape[3]
-            split_factor = IC
-        k_outer, k_inner = sch.split(k1, [None, split_factor])
-        sch.vectorize(k_inner)
-        sch.unroll(k_outer)
-
-    # Reshape + padding to tile size
-    # Computed inside GeMM
-    elif func_blocks["T_reshape"]:
-        reshape_block = func_blocks["T_reshape"]
-        A_pad_block = func_blocks["A_padded_K"] if func_blocks["A_padded_K"] else None
-        A_pad_block = func_blocks["A_padded_M"] if func_blocks["A_padded_M"] else A_pad_block
-        if use_sme:
-            sch.compute_inline(reshape_block)
-        elif A_pad_block:
-            sch.compute_inline(reshape_block)
-            b, m, k = sch.get_loops(A_pad_block)
-            _, k_inner = sch.split(k, [None, tile_N])
-            sch.vectorize(k_inner)
-            sch.compute_at(A_pad_block, mi)
-        else:
-            sch.compute_at(reshape_block, mi)
-
-    # Weight flattening
-    if func_blocks["weight_flatten"]:
-        weight_flatten_block = func_blocks["weight_flatten"]
-        sch.compute_inline(weight_flatten_block)
-
-    # Conv2d output block
-    output_block = func_blocks["conv2d_gemm_output"]
-    n, h, w, c = sch.get_loops(output_block)
-    n_h_fused = sch.fuse(n, h)
-    _, inner = sch.split(c, [None, 4])
-    sch.vectorize(inner)
-    sch.parallel(n_h_fused)
-
-    return sch
diff --git a/python/tvm/topi/arm_cpu/conv2d_gemm.py b/python/tvm/topi/arm_cpu/conv2d_gemm.py
index 0c3908bb70..5ff2ccb2c1 100644
--- a/python/tvm/topi/arm_cpu/conv2d_gemm.py
+++ b/python/tvm/topi/arm_cpu/conv2d_gemm.py
@@ -68,7 +68,6 @@ def compute_conv2d_gemm_without_weight_transform(
     output_channels,
     interleave_A,
     use_scalable_vectors=False,
-    use_sme=False,
 ):
     """Compute conv2d by transforming the input,
     executing GEMM and transforming the output back"""
@@ -124,12 +123,9 @@ def compute_conv2d_gemm_without_weight_transform(
     )
 
     # Select the tiling strategy for A and B
-    tile_M, tile_K_A = arm_utils.get_tiling_A(interleave_A, in_dtype, use_sme)
+    tile_M, tile_K_A = arm_utils.get_tiling_A(interleave_A, in_dtype)
     tile_N, tile_K_B = arm_utils.get_tiling_B_transformed(
-        interleave_A,
-        in_dtype,
-        use_scalable_vectors,
-        use_sme,
+        interleave_A, in_dtype, use_scalable_vectors
     )
 
     # Pad to tiles (if necessary)
@@ -289,7 +285,7 @@ def compute_conv2d_gemm_without_weight_transform(
             tvm.tir.const(1, C.dtype) * C[0, M_padded - 1, N_padded - 1]
             - tvm.tir.const(1, C.dtype) * C[0, M_padded - 1, N_padded - 1]
         )
-    elif use_scalable_vectors or use_sme:
+    elif use_scalable_vectors:
         assert len(B_interleaved_t.shape) == 2
         C = te.compute(
             (batches, M_padded, N_padded),
@@ -337,7 +333,7 @@ def compute_conv2d_gemm_without_weight_transform(
         out_shape,
         lambda b, x, y, z: (C(b, y + OW * x, z) + zero).astype(out_dtype),
         name="conv2d_gemm_output",
-        attrs={"use_scalable_vectors": use_scalable_vectors, "use_sme": use_sme},
+        attrs={"use_scalable_vectors": use_scalable_vectors},
     )
     return out
diff --git a/python/tvm/topi/nn/conv2d.py b/python/tvm/topi/nn/conv2d.py
index 8d61c62250..e21c0bd4e1 100644
--- a/python/tvm/topi/nn/conv2d.py
+++ b/python/tvm/topi/nn/conv2d.py
@@ -615,7 +615,7 @@ def conv2d_NCHWc_int8(
     )
 
 
-def conv2d_gemm_weight_transform(kernel, tile_N, tile_K, use_scalable_vectors=False, use_sme=False):
+def conv2d_gemm_weight_transform(kernel, tile_N, tile_K, use_scalable_vectors=False):
     """Weight transformation for winograd
 
     Parameters
@@ -628,8 +628,6 @@ def conv2d_gemm_weight_transform(kernel, tile_N, tile_K, use_scalable_vectors=Fa
         Tile size across K axis of the weight transformation for ConvGemm. (K = KW * KH * IC)
     use_scalable_vectors : bool
         determines if operations on scalable vectors are expected
-    use_sme : bool
-        determines if SME operations on scalable vectors are expected
 
     Returns
     -------
@@ -654,7 +652,7 @@ def conv2d_gemm_weight_transform(kernel, tile_N, tile_K, use_scalable_vectors=Fa
         kernel_flat, pad_before=(0, 0), pad_after=(pad_K, pad_N), name="weight_padding"
     )
 
-    if use_sme or use_scalable_vectors:
+    if use_scalable_vectors:
         return kernel_flat
 
     if kernel.dtype in ["int8", "uint8"]:
diff --git a/src/arith/scalable_expression.cc b/src/arith/scalable_expression.cc
index 5e3a65438d..e5f3bc28ba 100644
--- a/src/arith/scalable_expression.cc
+++ b/src/arith/scalable_expression.cc
@@ -71,8 +71,15 @@ std::optional<int> ExtractVscaleFactor(const PrimExpr& lanes) {
   }
 }
 
+bool IsComparison(const PrimExpr& expr) {
+  return expr->IsInstance<tir::LENode>() || expr->IsInstance<tir::LTNode>() ||
+         expr->IsInstance<tir::GENode>() || expr->IsInstance<tir::GTNode>() ||
+         expr->IsInstance<tir::EQNode>() || expr->IsInstance<tir::NENode>();
+}
+
 bool CanProveVscaleExpressionFromKnownValues(arith::Analyzer* analyzer, const PrimExpr& expr,
                                              const std::vector<unsigned int>& vscale_values) {
+  ICHECK(IsComparison(expr)) << "Expected comparison but got: " << expr;
   bool can_prove_expr = true;
   for (const unsigned int vscale_value : vscale_values) {
     PrimExpr result = SubstituteVScaleWithKnownValue(expr, vscale_value);
diff --git a/tests/python/arith/test_arith_simplify.py b/tests/python/arith/test_arith_simplify.py
index 1a876548af..fd8316d1e0 100644
--- a/tests/python/arith/test_arith_simplify.py
+++ b/tests/python/arith/test_arith_simplify.py
@@ -90,6 +90,16 @@ def test_simplify_vscale_comparison_without_sve_target(capfd):
     assert warning_msg in capture
 
 
+def test_simplify_vscale_non_comparison():
+    ana = tvm.arith.Analyzer()
+    vs = tvm.tir.vscale()
+
+    err_msg = r".*Expected comparison but got: T.vscale\(\) \* 4"
+    with pytest.raises(tvm.TVMError, match=err_msg):
+        with tvm.target.Target("llvm -mtriple=aarch64-linux-gnu -mattr=+sve"):
+            ana.can_prove(vs * 4)
+
+
 def test_regression_simplify_inf_recursion():
     ana = tvm.arith.Analyzer()
     cond = tir.Var("cond", "int32")
diff --git a/tests/python/codegen/test_target_codegen_aarch64.py b/tests/python/codegen/test_target_codegen_aarch64.py
index 77c22761a9..d5446b0b1c 100644
--- a/tests/python/codegen/test_target_codegen_aarch64.py
+++ b/tests/python/codegen/test_target_codegen_aarch64.py
@@ -731,36 +731,20 @@ def test_unsupported_multiple_function_attributes(attr_key, attr_value):
     llvm_version_major() < 15, reason="Test requires an LLVM version of at least 15 to target SVE"
 )
 @pytest.mark.parametrize("dtype", ["float16", "float32"])
-@pytest.mark.parametrize(
-    "conv2d_impl",
-    [
-        (
-            tvm.topi.arm_cpu.compute_conv2d_NHWC_hybrid_SVE,
-            tvm.topi.arm_cpu.schedule_conv2d_NHWC_hybrid_SVE,
-            False,
-        ),
-        (
-            tvm.topi.arm_cpu.compute_conv2d_NHWC_hybrid_SVE,
-            tvm.topi.arm_cpu.schedule_conv2d_NHWC_hybrid_TIR,
-            True,
-        ),
-    ],
-)
-def test_conv2d_sve(dtype, conv2d_impl):
+def test_conv2d_sve(dtype):
     target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve"
 
-    def check_correct_assembly(dtype, compute, schedule, use_tir_schedule):
+    def check_correct_assembly(dtype):
         A = te.placeholder((1, 32, 32, 3), dtype=dtype, name="A")
         W = te.placeholder((3, 3, 3, 8), dtype=dtype, name="B")
         stride = padding = dilation = 1
+
+        compute = tvm.topi.arm_cpu.compute_conv2d_NHWC_hybrid_SVE
+        schedule = tvm.topi.arm_cpu.schedule_conv2d_NHWC_hybrid_SVE
         B = compute(A, W, stride, padding, dilation, dtype)
-        if use_tir_schedule:
-            func = te.create_prim_func([A, W, B])
-            sch = schedule(tvm.tir.Schedule(func))
-            f = tvm.build(sch.mod["main"], target)
-        else:
-            s = schedule([B])
-            f = tvm.build(s, [A, W, B], target)
+        s = schedule([B])
+
+        f = tvm.build(s, [A, W, B], target)
         assembly = f.get_source("asm")
 
         loads = re.findall(r"ld1[r]?[q]?[whdb]\t{\s?z", assembly)
@@ -774,43 +758,6 @@ def test_conv2d_sve(dtype, conv2d_impl):
         assert len(compute_ops) > 0
         assert len(stores) > 0
 
-    with tvm.target.Target(target):
-        check_correct_assembly(dtype, *conv2d_impl)
-
-
-@pytest.mark.skipif(
-    llvm_version_major() < 16, reason="Test requires an LLVM version of at least 16 to target SME"
-)
-@pytest.mark.parametrize("dtype", ["float32"])
-def test_conv2d_sme(dtype):
-    target = "llvm -mtriple=aarch64-linux-gnu -mattr=+v9a,+sme"
-
-    def check_correct_assembly(dtype):
-        A = te.placeholder((1, 32, 32, 3), dtype=dtype, name="A")
-        W = te.placeholder((3, 3, 3, 8), dtype=dtype, name="B")
-        stride = padding = dilation = 1
-
-        B = tvm.topi.arm_cpu.compute_conv2d_NHWC_hybrid_SME(A, W, stride, padding, dilation, dtype)
-        func = te.create_prim_func([A, W, B])
-        sch = tvm.topi.arm_cpu.schedule_conv2d_NHWC_hybrid_TIR(tvm.tir.Schedule(func))
-        f = tvm.build(sch.mod["main"], target)
-
-        assembly = f.get_source("asm")
-        smstart = re.findall(r"smstart\t(sm|za)", assembly)
-        loads = re.findall(r"ld1[whdb]\t{\s?za", assembly)
-        mopa = re.findall(
-            r"fmopa\tza[0-9].[shdb],( p[0-9]/[zm],)?( p[0-9]/[zm],)? z[0-9].[shdb], z[0-9].[shdb]",
z[0-9].[shdb], z[0-9].[shdb]", - assembly, - ) - stores = re.findall(r"st1[whdb]\t{\s?za", assembly) - smstop = re.findall(r"smstop\t(sm|za)", assembly) - - assert len(smstart) > 0 - assert len(loads) > 0 - assert len(mopa) > 0 - assert len(stores) > 0 - assert len(smstop) > 0 - with tvm.target.Target(target): check_correct_assembly(dtype=dtype) diff --git a/tests/python/relay/strategy/arm_cpu/test_conv2d.py b/tests/python/relay/strategy/arm_cpu/test_conv2d.py index 2708094afb..1b9c1a5e2e 100644 --- a/tests/python/relay/strategy/arm_cpu/test_conv2d.py +++ b/tests/python/relay/strategy/arm_cpu/test_conv2d.py @@ -16,21 +16,8 @@ # under the License. """Tests for arm_cpu schedules for regular conv2d.""" -import pytest -import numpy as np - -import tvm -import tvm.topi.testing -from tvm import relay from test_generalized_conv2d import GeneralizedConv2dTests from tvm.testing import fixture, main, parameter, parameters -from tvm.topi.nn.utils import get_pad_tuple -from tvm.topi.utils import get_const_tuple -from tvm.target.codegen import llvm_version_major -from tvm.testing.aot import AOTTestModel, AOTCompiledTestModel, run_and_check, generate_ref_data -from tvm.micro.testing.aot_test_utils import AOT_APROFILE_AEM_RUNNER -from tvm.relay.op.strategy.arm_cpu import arm_cpu_tir_strategy -from scalable_utils import calculate_extra_workspace_size_from_scalable_extents class Conv2dTests(GeneralizedConv2dTests): @@ -120,128 +107,5 @@ class TestConv2d_NCHW_Spatial_Pack(Conv2dTests): schedule_name = parameter("conv2d_nchw_spatial_pack.arm_cpu") -dtype = tvm.testing.parameter("float32") - -batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation = tvm.testing.parameters( - # Pad M, N, K - (1, 1, 1, 1, 1, 1, "SAME", 1), - (1, 1, 3, 15, 1, 1, "SAME", 1), - # Pad M, K - (1, 3, 9, 16, 3, 1, "SAME", 1), - # Pad M, N - (1, 2, 9, 15, 4, 1, "SAME", 1), - # Pad K, N - (1, 7, 4, 15, 3, 1, "SAME", 1), - # Pad M - (1, 2, 9, 16, 4, 1, "SAME", 1), - # Pad K - (1, 7, 4, 16, 3, 1, "SAME", 1), - # Pad N - (1, 2, 4, 15, 4, 1, "SAME", 1), - (1, 2, 4, 20, 1, 1, "SAME", 1), - # Large workloads - (1, 128, 32, 128, 3, 1, "SAME", 1), - (4, 64, 16, 64, 5, 2, "SAME", 1), - (1, 128, 32, 128, 3, 1, "VALID", 1), - (4, 64, 16, 64, 5, 2, "VALID", 1), - (1, 64, 16, 64, 3, 2, (0, 0, 1, 1), 1), - (1, 64, 16, 64, 3, 2, (1, 1, 2, 2), 1), - (1, 64, 16, 64, 5, 2, (3, 3, 2, 2), 1), - (1, 64, 16, 64, 3, 2, (0, 1, 2, 3), 1), - (1, 64, 32, 64, 3, 1, "SAME", 2), - (1, 64, 32, 64, 3, 1, (1, 1, 2, 2), 2), -) - - -@tvm.testing.fixture() -def ref_data(dtype, batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation): - np.random.seed(0) - in_height = in_width = in_size - a_shape = (batch, in_height, in_width, in_channel) - w_shape = (kernel, kernel, in_channel, num_filter) - - a_np = np.random.uniform(size=a_shape).astype(dtype) - w_np = np.random.uniform(size=w_shape).astype(dtype) - return a_np, w_np - - -@pytest.mark.skipif( - llvm_version_major() < 16, reason="SME is not supported in earlier versions of LLVM" -) -@tvm.testing.requires_aprofile_aem_fvp -def test_conv2d_fp32(target, ref_data, dtype, stride, padding, dilation): - a_np, w_np = ref_data - dw_np = tvm.topi.testing.dilate_python(w_np, (dilation, dilation, 1, 1)) - - kernel_size = get_const_tuple(w_np.shape[:2]) - out_channels = w_np.shape[3] - - x = relay.var("data", shape=a_np.shape, dtype=dtype) - weight = relay.const(w_np, dtype=dtype) - conv2d = relay.nn.conv2d( - x, - weight, - channels=out_channels, - kernel_size=kernel_size, - strides=stride, - 
-        dilation=dilation,
-        padding=get_pad_tuple(padding, dw_np.shape[:2]),
-        data_layout="NHWC",
-        kernel_layout="HWIO",
-        out_dtype=dtype,
-    )
-
-    func = relay.Function(relay.analysis.free_vars(conv2d), conv2d)
-
-    ir_mod = tvm.IRModule.from_expr(func)
-    ir_mod = tvm.relay.transform.InferType()(ir_mod)
-
-    inputs = {"data": a_np}
-    params = {}
-    ref_outputs = generate_ref_data(ir_mod, inputs, params)
-
-    target = tvm.target.Target("llvm -mtriple=aarch64-none-elf -mattr=+v9.2a,+sme")
-    runtime = tvm.relay.backend.Runtime("crt", {"system-lib": True})
-    executor = tvm.relay.backend.Executor(
-        "aot",
-        {
-            "interface-api": "packed",
-            "unpacked-api": False,
-        },
-    )
-
-    with tvm.transform.PassContext(
-        opt_level=3, config=AOT_APROFILE_AEM_RUNNER.pass_config
-    ), target, tvm.meta_schedule.database.ScheduleFnDatabase(arm_cpu_tir_strategy):
-        executor_factory = tvm.relay.build(
-            ir_mod,
-            target=target,
-            executor=executor,
-            runtime=runtime,
-            params=params,
-        )
-    generated_func = executor_factory.lowered_ir_mods.items()[0][1][
-        "tvmgen_default_fused_nn_conv2d"
-    ]
-    extra_memory_in_bytes = calculate_extra_workspace_size_from_scalable_extents(generated_func, 4)
-
-    test_model = AOTTestModel(
-        ir_mod, inputs, ref_outputs, params=params, extra_memory_in_bytes=extra_memory_in_bytes
-    )
-    compiled = AOTCompiledTestModel(test_model, executor_factory)
-
-    assembly = (
-        compiled.executor_factory.module.imported_modules[0].imported_modules[0].get_source("asm")
-    )
-    assert "fmopa" in assembly
-
-    assert run_and_check(
-        models=[compiled],
-        interface_api="packed",
-        runner=AOT_APROFILE_AEM_RUNNER,
-        print_output_on_mismatch=True,
-    )
-
-
 if __name__ == "__main__":
-    tvm.testing.main()
+    main()
diff --git a/tests/python/relay/strategy/test_select_implementation.py b/tests/python/relay/strategy/test_select_implementation.py
index 01a914e793..71dd688e29 100644
--- a/tests/python/relay/strategy/test_select_implementation.py
+++ b/tests/python/relay/strategy/test_select_implementation.py
@@ -161,10 +161,6 @@ def test_int8_conv2d(target, expected_impl):
             "llvm --device=arm_cpu --mtriple=aarch64-linux-gnu -mattr=+v9a",
             "conv2d_NHWC_hybrid_without_transform.arm_cpu",
         ),
-        (
-            "llvm --device=arm_cpu --mtriple=aarch64-linux-gnu -mattr=+v9.2a,+sme",
-            "conv2d_NHWC_hybrid_SME.arm_cpu",
-        ),
     ],
 )
 def test_fp32_conv2d(target, expected_impl):
@@ -201,10 +197,6 @@ def test_fp32_conv2d(target, expected_impl):
             "llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+v9a",
             "conv2d_NHWC_hybrid_without_transform.arm_cpu",
         ),
-        (
-            "llvm --device=arm_cpu --mtriple=aarch64-linux-gnu -mattr=+v9.2a,+sme",
-            "conv2d_NHWC_hybrid_without_transform.arm_cpu",
-        ),
     ],
 )
 def test_fp16_conv2d(target, expected_impl):
diff --git a/tests/python/topi/test_topi_conv2d_nhwc.py b/tests/python/topi/test_topi_conv2d_nhwc.py
index 02f16b59c0..b5c9518d34 100644
--- a/tests/python/topi/test_topi_conv2d_nhwc.py
+++ b/tests/python/topi/test_topi_conv2d_nhwc.py
@@ -17,12 +17,10 @@
 """Example code to do convolution."""
 import os
 import platform
-import pytest
 import numpy as np
 import tvm
 from tvm import te
 from tvm import topi
-from tvm.target.codegen import llvm_version_major
 import tvm.topi.testing
 from tvm.contrib.pickle_memoize import memoize
 from tvm.topi.utils import get_const_tuple
@@ -53,37 +51,16 @@ device = tvm.testing.parameter(
         "llvm --device arm_cpu --mtriple aarch64-linux-gnu",
         topi.arm_cpu.conv2d_nhwc_spatial_pack,
         topi.arm_cpu.schedule_conv2d_nhwc_spatial_pack,
-        False,
     ),
     (
        "llvm --device arm_cpu --mtriple aarch64-linux-gnu -mattr=+v8.2a,+fullfp16",
-mattr=+v8.2a,+fullfp16", topi.arm_cpu.compute_conv2d_NHWC_hybrid, topi.arm_cpu.schedule_conv2d_NHWC_hybrid, - False, ), ( "llvm --device arm_cpu --mtriple aarch64-linux-gnu -mattr=+v8.6a,+sve", topi.arm_cpu.compute_conv2d_NHWC_hybrid_SVE, topi.arm_cpu.schedule_conv2d_NHWC_hybrid_SVE, - False, - ), - ( - "llvm --device arm_cpu --mtriple aarch64-linux-gnu -mattr=+v8.2a", - topi.arm_cpu.compute_conv2d_NHWC_hybrid, - topi.arm_cpu.schedule_conv2d_NHWC_hybrid_TIR, - True, - ), - ( - "llvm --device arm_cpu --mtriple aarch64-linux-gnu -mattr=+v8.6a,+sve", - topi.arm_cpu.compute_conv2d_NHWC_hybrid_SVE, - topi.arm_cpu.schedule_conv2d_NHWC_hybrid_TIR, - True, - ), - ( - "llvm --device arm_cpu --mtriple aarch64-linux-gnu -mattr=+v9a,+sme", - topi.arm_cpu.compute_conv2d_NHWC_hybrid_SME, - topi.arm_cpu.schedule_conv2d_NHWC_hybrid_TIR, - True, ), ) @@ -91,7 +68,6 @@ dtype = tvm.testing.parameter("float16", "float32") batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation = tvm.testing.parameters( # Pad M, N, K - (1, 1, 1, 1, 1, 1, "SAME", 1), (1, 1, 3, 15, 1, 1, "SAME", 1), # Pad M, K (1, 3, 9, 16, 3, 1, "SAME", 1), @@ -163,31 +139,16 @@ def test_conv2d_nhwc_gemm(device, ref_data, dtype, stride, padding, dilation): A = te.placeholder(a_np.shape, name="A", dtype=dtype) W = te.placeholder(w_np.shape, name="W", dtype=dtype) - target_string, compute, schedule, use_tir_schedule = device - dev = tvm.device(target_string, 0) - target = tvm.target.Target(target_string) - - if target.features.has_sve and llvm_version_major() < 15: - pytest.skip(f"LLVM {llvm_version_major()} does not support targetting SVE.") - - if target.features.has_sme and llvm_version_major() < 16: - pytest.skip(f"LLVM {llvm_version_major()} does not support targetting SME.") - - if target.features.has_sme and dtype == "float16": - pytest.skip(f"Conv2d fp16 targetting SME not implemented.") + target, compute, schedule = device + dev = tvm.device(target, 0) - with target: + with tvm.target.Target(target) as target: + B = compute(A, W, stride, padding, dilation, dtype) + s = schedule([B]) a = tvm.nd.array(a_np, dev) w = tvm.nd.array(w_np, dev) - B = compute(A, W, stride, padding, dilation, dtype) b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), dev) - if use_tir_schedule: - primfunc = te.create_prim_func([A, W, B]) - sch = schedule(tvm.tir.Schedule(primfunc)) - func = tvm.build(sch.mod["main"], target) - else: - s = schedule([B]) - func = tvm.build(s, [A, W, B], target) + func = tvm.build(s, [A, W, B], target) # Run only on AArch64 devices # Do not run SVE schedules on non-SVE devices @@ -199,7 +160,6 @@ def test_conv2d_nhwc_gemm(device, ref_data, dtype, stride, padding, dilation): and target.features.has_fp16_simd and not tvm.testing.requires_arm_fp16.run_time_check() ) - or (target.features.has_sme and not tvm.testing.requires_aarch64_sme.run_time_check()) ) if build_only: return