This is an automated email from the ASF dual-hosted git repository.

lukhut pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new f38dc146e4 [TOPI][Relay] Add conv2d NHWC hybrid schedule for `arm_cpu` (#16106)
f38dc146e4 is described below

commit f38dc146e489cb06db726950542807817d97b490
Author: Andrei Hutu <andrei.h...@arm.com>
AuthorDate: Fri Nov 24 09:53:47 2023 +0000

    [TOPI][Relay] Add conv2d NHWC hybrid schedule for `arm_cpu` (#16106)
    
    Implemented an `arm_cpu` conv2d NHWC schedule for fp32 using a hybrid
    GeMM approach, breaking the matrix multiplication down into a
    macro-kernel (partitioning it into fixed-size, tile-level subproblems)
    and a micro-kernel (dealing with each subproblem independently). After
    the im2col transformation, the input matrix is handled natively (not
    interleaved), while the weights matrix is tiled and interleaved at
    compile time.
    The micro-kernel uses 16 registers to accumulate the results of each
    4x16 output tile, cycling through the operands needed to compute them
    (from the input and weight matrices) in the remaining registers.
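    
    As a rough illustration of that decomposition (a NumPy sketch, not the
    actual TIR schedule; `hybrid_gemm`, `tile_M` and `tile_N` are names
    invented here, and M and N are assumed pre-padded to multiples of the
    tile sizes):
    
    ```python
    import numpy as np
    
    def hybrid_gemm(A, B, tile_M=4, tile_N=16):
        M, K = A.shape
        _, N = B.shape
        C = np.zeros((M, N), dtype=A.dtype)
        # Macro-kernel: partition C into fixed-size tile_M x tile_N subproblems.
        for m in range(0, M, tile_M):
            for n in range(0, N, tile_N):
                # Micro-kernel: accumulate one 4x16 output tile over the K axis
                # (the real schedule keeps these accumulators in 16 registers).
                acc = np.zeros((tile_M, tile_N), dtype=A.dtype)
                for k in range(K):
                    acc += np.outer(A[m:m + tile_M, k], B[k, n:n + tile_N])
                C[m:m + tile_M, n:n + tile_N] = acc
        return C
    ```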
    
    There are now two ways to transform the weights matrix for conv2d,
    which are detailed in `convolution.cc` and sketched after the list
    below:
    
    * for fp32: tile, interleave
    * for int8: tile, interleave, transpose
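    
    A minimal NumPy sketch of the two layouts (mirroring the new
    `te.compute` in `conv2d_gemm_weight_transform`; padding is omitted and
    `transform_weights` is an illustrative name, not an API from this
    patch):
    
    ```python
    import numpy as np
    
    def transform_weights(flat_B, tile_N, tile_K, transpose):
        # flat_B is the flattened K x N weight matrix (K = KW * KH * IC,
        # N = OC), assumed pre-padded to multiples of tile_K and tile_N.
        K, N = flat_B.shape
        tile_shape = (tile_N, tile_K) if transpose else (tile_K, tile_N)
        out = np.empty((N // tile_N, K // tile_K) + tile_shape, dtype=flat_B.dtype)
        for x in range(N // tile_N):
            for y in range(K // tile_K):
                block = flat_B[y * tile_K:(y + 1) * tile_K, x * tile_N:(x + 1) * tile_N]
                # int8: tile, interleave, transpose; fp32: tile, interleave only.
                out[x, y] = block.T if transpose else block
        return out
    ```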
    
    To keep naming consistent across both of these implementations
    (transposed vs not transposed), all mentions of `tile_rows_B` and
    `tile_cols_B` have been renamed to `tile_N` and `tile_K` respectively,
    denoting the tile size along each axis of the flattened B matrix. As
    usual, `N = out_channels` and `K = kernel_width * kernel_height *
    in_channels`.
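    
    For reference, the padding that goes with these tile sizes can be
    sketched as follows (simplified from `get_conv2d_weights_padding` in
    `arm_utils.py`; behaviour matches this patch):
    
    ```python
    def conv2d_weights_padding(N, K, tile_N, tile_K):
        # Pad N up to a multiple of tile_N.
        pad_N = (tile_N - N % tile_N) % tile_N
        # Tensorize consumes 4 tiles at once along K, so pad K up to a
        # multiple of 4 * tile_K.
        tile_K_multiplied = 4 * tile_K
        pad_K = (tile_K_multiplied - K % tile_K_multiplied) % tile_K_multiplied
        return pad_N, pad_K
    
    # e.g. fp32 tiling (tile_N=16, tile_K=4): N=20, K=27 -> pad_N=12, pad_K=5
    ```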
    
    I have also added a new conv2d NHWC fp32 test covering both the
    `conv2d_nhwc_spatial_pack` and `conv2d_NHWC_hybrid` schedules.
---
 include/tvm/relay/attrs/nn.h                       |  10 +-
 python/tvm/relay/op/nn/_nn.py                      |   2 +-
 python/tvm/relay/op/nn/nn.py                       |  12 +-
 python/tvm/relay/op/strategy/arm_cpu.py            | 123 +++++---
 python/tvm/topi/arm_cpu/arm_utils.py               | 105 ++++---
 python/tvm/topi/arm_cpu/conv2d.py                  | 111 +++++++
 python/tvm/topi/arm_cpu/conv2d_alter_op.py         |  57 ++--
 python/tvm/topi/arm_cpu/conv2d_gemm.py             | 345 +++++++++++++--------
 python/tvm/topi/arm_cpu/conv2d_int8.py             |  96 +-----
 python/tvm/topi/nn/conv2d.py                       |  30 +-
 src/relay/op/nn/convolution.cc                     |  70 +++--
 tests/python/integration/test_arm_aprofile.py      |   1 +
 .../relay/strategy/test_select_implementation.py   | 128 ++++++--
 tests/python/topi/test_topi_conv2d_nhwc.py         |  39 +++
 14 files changed, 721 insertions(+), 408 deletions(-)

diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h
index e58c73dc73..58edb9df8b 100644
--- a/include/tvm/relay/attrs/nn.h
+++ b/include/tvm/relay/attrs/nn.h
@@ -197,12 +197,14 @@ struct ConvWinogradWeightTransformAttrs : public tvm::AttrsNode<ConvWinogradWeig
 
 /*! \brief Attributes used in gemm weight transformation operators */
 struct ConvGemmWeightTransformAttrs : public tvm::AttrsNode<ConvGemmWeightTransformAttrs> {
-  int tile_rows;
-  int tile_cols;
+  int tile_N;
+  int tile_K;
 
   TVM_DECLARE_ATTRS(ConvGemmWeightTransformAttrs, "relay.attrs.ConvGemmWeightTransformAttrs") {
-    TVM_ATTR_FIELD(tile_rows).describe("Tile rows of the weight transformation for ConvGemm.");
-    TVM_ATTR_FIELD(tile_cols).describe("Tile columns of the weight transformation for ConvGemm.");
+    TVM_ATTR_FIELD(tile_N).describe(
+        "Tile size across N axis of the weight transformation for ConvGemm. (N = OC)");
+    TVM_ATTR_FIELD(tile_K).describe(
+        "Tile size across K axis of the weight transformation for ConvGemm. (K = KW * KH * IC)");
   }
 };
 
diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py
index 6acaf43fe7..a03907f071 100644
--- a/python/tvm/relay/op/nn/_nn.py
+++ b/python/tvm/relay/op/nn/_nn.py
@@ -798,7 +798,7 @@ reg.register_strategy(
 @reg.register_compute("nn.contrib_conv2d_gemm_weight_transform")
 def compute_contrib_conv2d_gemm_weight_transform(attrs, inputs, out_dtype):
     """Compute definition of contrib_conv2d_gemm_weight_transform"""
-    out = topi.nn.conv2d_gemm_weight_transform(inputs[0], attrs.tile_rows, attrs.tile_cols)
+    out = topi.nn.conv2d_gemm_weight_transform(inputs[0], attrs.tile_N, attrs.tile_K)
     return [out]
 
 
diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py
index 89953eb1df..8cb66ecaa9 100644
--- a/python/tvm/relay/op/nn/nn.py
+++ b/python/tvm/relay/op/nn/nn.py
@@ -2741,7 +2741,7 @@ def contrib_conv2d_winograd_weight_transform(weight, tile_size):
     return _make.contrib_conv2d_winograd_weight_transform(weight, tile_size)
 
 
-def contrib_conv2d_gemm_weight_transform(weights, tile_rows, tile_cols):
+def contrib_conv2d_gemm_weight_transform(weights, tile_N, tile_K):
     r"""Weight Transformation part for 2D convolution with gemm algorithm.
 
     We separate this as a single op to enable pre-compute for inference.
@@ -2751,17 +2751,17 @@ def contrib_conv2d_gemm_weight_transform(weights, tile_rows, tile_cols):
     ----------
     weights : tvm.relay.Expr
         The weight expressions.
-    tile_rows: int
-        Tile rows of the weight transformation for ConvGemm.
-    tile_cols: int
-       Tile columns of the weight transformation for ConvGemm.
+    tile_N: int
+        Tile size across N axis of the weight transformation for ConvGemm. (N = OC)
+    tile_K: int
+       Tile size across K axis of the weight transformation for ConvGemm. (K = KW * KH * IC)
 
     Returns
     -------
     result : tvm.relay.Expr
         The computed result.
     """
-    return _make.contrib_conv2d_gemm_weight_transform(weights, tile_rows, tile_cols)
+    return _make.contrib_conv2d_gemm_weight_transform(weights, tile_N, tile_K)
 
 
 def contrib_conv3d_winograd_weight_transform(weight, tile_size):
diff --git a/python/tvm/relay/op/strategy/arm_cpu.py b/python/tvm/relay/op/strategy/arm_cpu.py
index a23ccf8f69..1f9a6fc41e 100644
--- a/python/tvm/relay/op/strategy/arm_cpu.py
+++ b/python/tvm/relay/op/strategy/arm_cpu.py
@@ -211,37 +211,50 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
                 )
             elif kernel_layout == "HWIO":
                 is_aarch64 = target.features.is_aarch64
-                has_asimd = target.features.has_asimd
                 has_dot_prod = target.features.has_dotprod
                 has_matmul_i8 = target.features.has_matmul_i8
-
-                if data.dtype in ["int8", "uint8"]:
-                    if has_matmul_i8:
+                interleaved_compute = topi.arm_cpu.compute_conv2d_NHWC_quantized_interleaved
+                interleaved_schedule = topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved
+                native_compute = topi.arm_cpu.compute_conv2d_NHWC_quantized_native
+                native_schedule = topi.arm_cpu.schedule_conv2d_NHWC_quantized_native
+                # Quantized cases
+                if is_aarch64 and data.dtype in ["int8", "uint8"]:
+                    if has_matmul_i8 and has_dot_prod:
+                        strategy.add_implementation(
+                            wrap_compute_conv2d(interleaved_compute),
+                            wrap_topi_schedule(interleaved_schedule),
+                            name="conv2d_NHWC_quantized_interleaved.arm_cpu",
+                        )
                         strategy.add_implementation(
-                            wrap_compute_conv2d(
-                                topi.arm_cpu.compute_conv2d_NHWC_quantized_interleaved
-                            ),
-                            wrap_topi_schedule(
-                                topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved
-                            ),
+                            wrap_compute_conv2d(native_compute),
+                            wrap_topi_schedule(native_schedule),
+                            name="conv2d_NHWC_quantized_native.arm_cpu",
+                        )
+                    elif has_matmul_i8:
+                        strategy.add_implementation(
+                            wrap_compute_conv2d(interleaved_compute),
+                            wrap_topi_schedule(interleaved_schedule),
                             name="conv2d_NHWC_quantized_interleaved.arm_cpu",
                         )
-                    if has_dot_prod:
+                    elif has_dot_prod:
                         strategy.add_implementation(
-                            wrap_compute_conv2d(topi.arm_cpu.compute_conv2d_NHWC_quantized_native),
-                            wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_quantized_native),
+                            wrap_compute_conv2d(native_compute),
+                            wrap_topi_schedule(native_schedule),
                             name="conv2d_NHWC_quantized_native.arm_cpu",
                         )
-                    if is_aarch64 and has_asimd:
+                    else:
                         strategy.add_implementation(
-                            wrap_compute_conv2d(
-                                topi.arm_cpu.compute_conv2d_NHWC_quantized_interleaved
-                            ),
-                            wrap_topi_schedule(
-                                topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved
-                            ),
+                            wrap_compute_conv2d(interleaved_compute),
+                            wrap_topi_schedule(interleaved_schedule),
                             name="conv2d_NHWC_quantized_interleaved.arm_cpu",
                         )
+                # Non-quantized cases
+                if is_aarch64 and data.dtype in ["float32", "float16"]:
+                    strategy.add_implementation(
+                        wrap_compute_conv2d(topi.arm_cpu.compute_conv2d_NHWC_hybrid),
+                        wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_hybrid),
+                        name="conv2d_NHWC_hybrid.arm_cpu",
+                    )
                 if (not is_aarch64) or (data.dtype not in ["int8", "uint8"]):
                     # TODO(@giuseros)
                     # This strategy errors out for quantized data types when tuning.
@@ -250,6 +263,7 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
                         wrap_compute_conv2d(topi.arm_cpu.conv2d_nhwc_spatial_pack),
                         wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_nhwc_spatial_pack),
                         name="conv2d_nhwc_spatial_pack.arm_cpu",
+                        plevel=5,
                     )
             else:
                 raise RuntimeError(f"Unsupported kernel layout {kernel_layout} for conv2d NHWC")
@@ -485,40 +499,59 @@ def conv2d_gemm_without_weight_transform_strategy_arm_cpu(attrs, inputs, out_typ
     data = inputs[0]
     strategy = _op.OpStrategy()
     is_aarch64 = target.features.is_aarch64
-    has_asimd = target.features.has_asimd
     has_dot_prod = target.features.has_dotprod
     has_matmul_i8 = target.features.has_matmul_i8
 
     interleaved_compute = topi.arm_cpu.compute_conv2d_NHWC_quantized_interleaved_without_transform
+    interleaved_schedule = topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved_without_transform
     native_compute = topi.arm_cpu.compute_conv2d_NHWC_quantized_native_without_transform
-    if layout == "NHWC" and data.dtype in ["int8", "uint8"]:
-        if has_matmul_i8:
-            strategy.add_implementation(
-                wrap_compute_conv2d_gemm(interleaved_compute),
-                wrap_topi_schedule(
-                    topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved_without_transform
-                ),
-                name="conv2d_NHWC_quantized_interleaved_without_transform.arm_cpu",
-            )
-        if has_dot_prod:
-            strategy.add_implementation(
-                wrap_compute_conv2d_gemm(native_compute),
-                wrap_topi_schedule(
-                    topi.arm_cpu.schedule_conv2d_NHWC_quantized_native_without_transform
-                ),
-                name="conv2d_NHWC_quantized_native_without_transform.arm_cpu",
-            )
-        if is_aarch64 and has_asimd:
+    native_schedule = topi.arm_cpu.schedule_conv2d_NHWC_quantized_native_without_transform
+    if layout == "NHWC" and data.dtype in ["int8", "uint8", "float32", "float16"]:
+        # Non-AArch64 cases
+        if not is_aarch64:
+            raise RuntimeError("Unsupported non-AArch64 conv2d_NHWC_without_transform")
+        # AArch64 cases
+        if data.dtype in ["int8", "uint8"]:
+            # Quantized cases
+            if has_matmul_i8 and has_dot_prod:
+                strategy.add_implementation(
+                    wrap_compute_conv2d_gemm(interleaved_compute),
+                    wrap_topi_schedule(interleaved_schedule),
+                    name="conv2d_NHWC_quantized_interleaved_without_transform.arm_cpu",
+                )
+                strategy.add_implementation(
+                    wrap_compute_conv2d_gemm(native_compute),
+                    wrap_topi_schedule(native_schedule),
+                    name="conv2d_NHWC_quantized_native_without_transform.arm_cpu",
+                )
+            elif has_matmul_i8:
+                strategy.add_implementation(
+                    wrap_compute_conv2d_gemm(interleaved_compute),
+                    wrap_topi_schedule(interleaved_schedule),
+                    name="conv2d_NHWC_quantized_interleaved_without_transform.arm_cpu",
+                )
+            elif has_dot_prod:
+                strategy.add_implementation(
+                    wrap_compute_conv2d_gemm(native_compute),
+                    wrap_topi_schedule(native_schedule),
+                    name="conv2d_NHWC_quantized_native_without_transform.arm_cpu",
+                )
+            else:
+                strategy.add_implementation(
+                    wrap_compute_conv2d_gemm(interleaved_compute),
+                    wrap_topi_schedule(interleaved_schedule),
+                    name="conv2d_NHWC_quantized_interleaved_without_transform.arm_cpu",
+                )
+        elif data.dtype in ["float32", "float16"]:
+            # Non-quantized cases
             strategy.add_implementation(
-                wrap_compute_conv2d_gemm(interleaved_compute),
-                wrap_topi_schedule(
-                    topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved_without_transform
-                ),
-                name="conv2d_NHWC_quantized_interleaved_without_transform.arm_cpu",
+                wrap_compute_conv2d_gemm(topi.arm_cpu.compute_conv2d_NHWC_hybrid_without_transform),
+                wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_hybrid_without_transform),
+                name="conv2d_NHWC_hybrid_without_transform.arm_cpu",
             )
     else:
         raise RuntimeError(
-            f"Unsupported conv2d_NHWC_quantized_without_transform layout {layout}"
+            f"Unsupported conv2d_NHWC_without_transform layout {layout}"
             f"with datatype {data.dtype}"
         )
     return strategy
diff --git a/python/tvm/topi/arm_cpu/arm_utils.py b/python/tvm/topi/arm_cpu/arm_utils.py
index 9c519cbb93..0a67aa1c6b 100644
--- a/python/tvm/topi/arm_cpu/arm_utils.py
+++ b/python/tvm/topi/arm_cpu/arm_utils.py
@@ -20,9 +20,9 @@
 from tvm.target import Target
 
 
-def get_tiling_B_interleaved_t(interleave_A):
+def get_tiling_B_transformed(interleave_A, in_dtype):
     """Compute the tiling information for matrix B', where B'
-    is the transposed and interleaved version of matrix B in C=A*B.
+    is the tiled, interleaved (and transposed) version of matrix B in C=A*B.
 
     The tiling information is chosen to maximize register usage during the
     tile computation.
@@ -36,59 +36,68 @@ def get_tiling_B_interleaved_t(interleave_A):
 
     Parameters
     ----------
-    interleave_A: bool
-                  determines if A is expected to be interleaved
+    interleave_A : bool
+        determines if A is expected to be interleaved
+    in_dtype : str
+        input datatype
+
 
     Returns
     ----------
-    tile_rows_B: the output tile rows of B'
-    tile_cols_B: the output tile columns of B'
+    tile_N: the output tile size of B' on N axis (N = OC)
+    tile_K: the output tile size of B' on K axis (K = KW * KH * IC)
     """
     target = Target.current(allow_none=False)
-
-    if target.features.has_matmul_i8:
-        # If smmla/ummla is available,  A must be interleaved.
-        # Each load from B' will contain 8 elements
-        # and we are loading 12 rows of B' (i.e., 12 columns of B)
-        tile_rows_B = 12
-        tile_cols_B = 8
-    elif target.features.has_dotprod:
-        # The number of tile rows of B' vary depending on the
-        # strategy:
-        # * If we are interleaving A, then we select 12 columns from B'(i.e.,
-        #   12 rows from B).
-        # * If we are not interleaving A, then we select 16 columns from B'(i.e.,
-        #   16 rows from B).
-        tile_rows_B = 12 if interleave_A else 16
-
-        # Dot product instruction groups 2 (u)int16x8 vectors in
-        # groups of 4 and compute the dot product among those groups
-        # This means that the number of columns in a tile of B' (i.e.,  the
-        # rows of the original matrix B)  need to be 4.
-        tile_cols_B = 4
+    if in_dtype in ["int8", "uint8"]:
+        if target.features.has_matmul_i8:
+            # If smmla/ummla is available,  A must be interleaved.
+            # Each load from B' will contain 8 elements
+            # and we are loading 12 rows of B' (i.e., 12 columns of B)
+            tile_N = 12
+            tile_K = 8
+        elif target.features.has_dotprod:
+            # The number of tile rows of B' vary depending on the
+            # strategy:
+            # * If we are interleaving A, then we select 12 columns from B'(i.e.,
+            #   12 rows from B).
+            # * If we are not interleaving A, then we select 16 columns from B'(i.e.,
+            #   16 rows from B).
+            tile_N = 12 if interleave_A else 16
+
+            # Dot product instruction groups 2 (u)int16x8 vectors in
+            # groups of 4 and compute the dot product among those groups
+            # This means that the number of columns in a tile of B' (i.e.,  the
+            # rows of the original matrix B)  need to be 4.
+            tile_K = 4
+        else:
+            # If no acceleration is available, A must be interleaved. In this case
+            # we load 4 rows of B' (i.e., 4 columns of B). Each of them will contain 16 elements
+            tile_N = 4
+            tile_K = 16
     else:
-        # If no acceleration is available, A must be interleaved. In this case
-        # we load 4 rows of B' (i.e., 4 columns of B). Each of them will contain 16 elements
-        tile_rows_B = 4
-        tile_cols_B = 16
+        # In non-quantized cases, A is not interleaved.
+        # Each load from B' contains 16 elements (i.e. 16 columns from B)
+        # We are loading 4 rows from B', in the dimension of reduction (i.e. 4 rows from B)
+        tile_N = 16
+        tile_K = 4
 
-    return tile_rows_B, tile_cols_B
+    return tile_N, tile_K
 
 
-def get_conv2d_weights_padding(N, K, tile_rows, tile_cols):
+def get_conv2d_weights_padding(N, K, tile_N, tile_K):
     """Compute the necessary padding for matrix B', where B'
-    is the transposed and interleaved version of matrix B in C=A*B.
+    is the transformed version of matrix B in C=A*B.
 
     Parameters
     ----------
     N : int
-        Number of rows in B' = OC
+        Number of columns in B = OC
     K : int
-        Number of columns in B' = KW * KH * IC
-    tile_rows : int
-                tile rows of B'
-    tile_cols : int
-                tile columns of B'
+        Number of rows in B = KW * KH * IC
+    tile_N : int
+             tile size of B' on N axis
+    tile_K : int
+             tile size of B' on K axis
 
     Returns
     ----------
@@ -98,16 +107,16 @@ def get_conv2d_weights_padding(N, K, tile_rows, tile_cols):
     pad_N = 0
     pad_K = 0
 
-    if N % tile_rows != 0:
-        pad_N = tile_rows - (N % tile_rows)
+    if N % tile_N != 0:
+        pad_N = tile_N - (N % tile_N)
 
-    # Tensorize will later make use of 4 tiles at once across the columns so make sure we pad such
-    # that the columns is multiple of 4
-    column_multiplier = 4
-    tile_cols_multiplied = tile_cols * column_multiplier
-    K_misalignment = K % tile_cols_multiplied
+    # Tensorize will later make use of 4 tiles at once across the K axis so make sure we pad such
+    # that K is multiple of 4
+    K_multiplier = 4
+    tile_K_multiplied = tile_K * K_multiplier
+    K_misalignment = K % tile_K_multiplied
 
     if K_misalignment != 0:
-        pad_K = tile_cols_multiplied - K_misalignment
+        pad_K = tile_K_multiplied - K_misalignment
 
     return pad_N, pad_K
diff --git a/python/tvm/topi/arm_cpu/conv2d.py b/python/tvm/topi/arm_cpu/conv2d.py
index a478818084..90e199f36a 100644
--- a/python/tvm/topi/arm_cpu/conv2d.py
+++ b/python/tvm/topi/arm_cpu/conv2d.py
@@ -27,12 +27,18 @@ from ..utils import traverse_inline, get_const_tuple
 from .. import nn
 from ..nn.utils import get_const_int, get_pad_tuple
 from ..nn.winograd_util import winograd_transform_matrices
+from .arm_utils import get_tiling_B_transformed
 from .conv2d_spatial_pack import (
     conv2d_spatial_pack_nchw,
     conv2d_spatial_pack_nhwc,
     schedule_conv2d_spatial_pack_nchw,
     schedule_conv2d_spatial_pack_nhwc,
 )
+from .conv2d_gemm import (
+    compute_conv2d_gemm_without_weight_transform,
+    schedule_conv2d_gemm_interleaved,
+    schedule_conv2d_gemm_native,
+)
 from .mprofile.dsp.conv2d import conv2d_nhwc_dsp_compute, conv2d_nhwc_dsp_schedule
 
 
@@ -509,3 +515,108 @@ def conv2d_nhwc_dsp(cfg, data, kernel, strides, padding, dilation, out_dtype):
 def schedule_conv2d_nhwc_dsp(cfg, outs):
     """Create schedule for conv2d_nhwc_dsp"""
     return conv2d_nhwc_dsp_schedule(cfg, outs)
+
+
+def compute_conv2d_NHWC(cfg, data, kernel, strides, padding, dilation, out_dtype, interleave_A):
+    N, IH, IW, IC = get_const_tuple(data.shape)
+    KH, KW, _, OC = get_const_tuple(kernel.shape)
+    tile_N, tile_K = get_tiling_B_transformed(interleave_A, data.dtype)
+
+    kernel = nn.conv2d_gemm_weight_transform(kernel, tile_N, tile_K)
+    return compute_conv2d_gemm_without_weight_transform(
+        cfg, data, kernel, strides, padding, dilation, out_dtype, (KH, KW), OC, interleave_A
+    )
+
+
+def compute_conv2d_NHWC_without_transform(
+    cfg,
+    data,
+    B,
+    strides,
+    padding,
+    dilation,
+    out_dtype,
+    kernel_size=None,
+    output_channels=None,
+    interleave_A=False,
+):
+    """Compute conv2d NHWC without weight transform"""
+    return compute_conv2d_gemm_without_weight_transform(
+        cfg,
+        data,
+        B,
+        strides,
+        padding,
+        dilation,
+        out_dtype,
+        kernel_size,
+        output_channels,
+        interleave_A,
+    )
+
+
+def schedule_conv2d_NHWC(cfg, outs, interleave_A):
+    """Create schedule for tensors"""
+    s = te.create_schedule([x.op for x in outs])
+    # Vectorize the output and then inline all the rest
+    out = outs[0]
+    n, h, w, c = out.op.axis
+    n_h_fused = s[out].fuse(n, h)
+    _, inner = s[out].split(c, 4)
+    s[out].vectorize(inner)
+    s[out].parallel(n_h_fused)
+
+    def _callback(op):
+        """Traverse operators from computation graph"""
+        if op.name == "conv2d_gemm_output":
+            conv_out = op.output(0)
+            if interleave_A:
+                schedule_conv2d_gemm_interleaved(cfg, s, conv_out, out)
+            else:
+                schedule_conv2d_gemm_native(cfg, s, conv_out, out)
+            if out != conv_out:
+                s[conv_out].compute_at(s[out], inner)
+            else:
+                C = conv_out.op.input_tensors[0]
+                if interleave_A:
+                    s[C].compute_at(s[out], inner)
+
+    traverse_inline(s, outs[0].op, _callback)
+    return s
+
+
+@autotvm.register_topi_compute("conv2d_NHWC_hybrid.arm_cpu")
+def compute_conv2d_NHWC_hybrid(cfg, data, kernel, strides, padding, dilation, out_dtype):
+    """Interface for hybrid compute_conv2d_NHWC_hybrid"""
+    return compute_conv2d_NHWC(cfg, data, kernel, strides, padding, dilation, out_dtype, False)
+
+
+@autotvm.register_topi_compute("conv2d_NHWC_hybrid_without_transform.arm_cpu")
+def compute_conv2d_NHWC_hybrid_without_transform(
+    cfg, data, kernel, strides, padding, dilation, out_dtype, kernel_size, output_channels
+):
+    """Interface for hybrid compute_conv2d_NHWC_hybrid_without_transform"""
+    return compute_conv2d_NHWC_without_transform(
+        cfg,
+        data,
+        kernel,
+        strides,
+        padding,
+        dilation,
+        out_dtype,
+        kernel_size,
+        output_channels,
+        False,
+    )
+
+
+@autotvm.register_topi_schedule("conv2d_NHWC_hybrid.arm_cpu")
+def schedule_conv2d_NHWC_hybrid(cfg, outs):
+    """Interface for hybrid schedule_conv2d_NHWC_hybrid"""
+    return schedule_conv2d_NHWC(cfg, outs, False)
+
+
+@autotvm.register_topi_schedule("conv2d_NHWC_hybrid_without_transform.arm_cpu")
+def schedule_conv2d_NHWC_hybrid_without_transform(cfg, outs):
+    """Interface for hybrid schedule_conv2d_NHWC_hybrid"""
+    return schedule_conv2d_NHWC(cfg, outs, False)
diff --git a/python/tvm/topi/arm_cpu/conv2d_alter_op.py b/python/tvm/topi/arm_cpu/conv2d_alter_op.py
index 1c30e1f3b6..fe4569ceb1 100644
--- a/python/tvm/topi/arm_cpu/conv2d_alter_op.py
+++ b/python/tvm/topi/arm_cpu/conv2d_alter_op.py
@@ -32,15 +32,15 @@ from ..utils import get_const_tuple
 from ..x86.conv2d import _get_default_config as _get_x86_default_config
 from ..x86.conv2d_int8 import _get_default_config_int8
 from .conv2d_int8 import is_int8_hw_support
-from .arm_utils import get_tiling_B_interleaved_t, get_conv2d_weights_padding
+from .arm_utils import get_tiling_B_transformed, get_conv2d_weights_padding
 from ..generic.conv2d import conv2d_alter_int8_common
 from .mprofile.dsp.micro_kernel.common import num_simd_lanes_per_word
 
 logger = logging.getLogger("topi")
 
 
-def interleave_transpose_weights(inputs, data, kernel, interleave_A):
-    """Transform the weight matrix by reshaping, interleaving and transposing it
+def transform_weights(inputs, data, kernel, interleave_A):
+    """Transform the weight matrix by tiling, interleaving (and transposing it)
 
     Parameters
     ----------
@@ -59,29 +59,28 @@ def interleave_transpose_weights(inputs, data, kernel, interleave_A):
     new_kernel_expr : tvm.relay.Expr
                 The relay expression of the weights
     """
-    assert (
-        data.dtype == "int8"
-        and kernel.dtype == "int8"
-        or data.dtype == "uint8"
-        and kernel.dtype == "uint8"
-    )
 
     KH, KW, IC, OC = get_const_tuple(kernel.shape)
     K = KH * KW * IC
     N = OC
 
-    # Get tiling information for the interleaved transposed version of B
-    tile_rows_B, tile_cols_B = get_tiling_B_interleaved_t(interleave_A)
-    pad_N, pad_K = get_conv2d_weights_padding(N, K, tile_rows_B, tile_cols_B)
+    # Get tiling information for the transformed version of B
+    tile_N, tile_K = get_tiling_B_transformed(interleave_A, data.dtype)
+    pad_N, pad_K = get_conv2d_weights_padding(N, K, tile_N, tile_K)
 
     N_padded = N + pad_N
     K_padded = K + pad_K
-    new_kernel_expr = relay.nn.contrib_conv2d_gemm_weight_transform(
-        inputs[1], tile_rows_B, tile_cols_B
-    )
-    new_kernel = te.placeholder(
-        (N_padded // tile_rows_B, K_padded // tile_cols_B, tile_rows_B, tile_cols_B), kernel.dtype
-    )
+    new_kernel_expr = relay.nn.contrib_conv2d_gemm_weight_transform(inputs[1], tile_N, tile_K)
+    if data.dtype in ["int8", "uint8"]:
+        new_kernel = te.placeholder(
+            (N_padded // tile_N, K_padded // tile_K, tile_N, tile_K),
+            kernel.dtype,
+        )
+    else:
+        new_kernel = te.placeholder(
+            (N_padded // tile_N, K_padded // tile_K, tile_K, tile_N),
+            kernel.dtype,
+        )
     return new_kernel, new_kernel_expr
 
 
@@ -149,6 +148,20 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
             inputs[0], relay.Constant(tvm.nd.array(reshaped_new_kernel)), **new_attrs
         )
 
+    if topi_tmpl == "conv2d_NHWC_hybrid.arm_cpu":
+        assert data_layout == "NHWC" and kernel_layout == "HWIO"
+        KH, KW, _, OC = get_const_tuple(kernel.shape)
+        new_workload_name = "conv2d_NHWC_hybrid_without_transform.arm_cpu"
+        new_kernel, new_kernel_expr = transform_weights(inputs, data, kernel, interleave_A=False)
+        new_workload = autotvm.task.args_to_workload(
+            [data, new_kernel, strides, padding, dilation, out_dtype, (KH, KW), OC],
+            new_workload_name,
+        )
+        dispatch_ctx.update(target, new_workload, cfg)
+        return relay.nn.contrib_conv2d_gemm_without_weight_transform(
+            inputs[0], new_kernel_expr, **new_attrs
+        )
+
     # Only microTVM does layout alteration for NHWC layout with real data types
     if data_layout == "NHWC" and data_dtype not in ["uint8", "int8"]:
         return None
@@ -431,9 +444,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
         assert data_layout == "NHWC" and kernel_layout == "HWIO"
         KH, KW, _, OC = get_const_tuple(kernel.shape)
         new_workload_name = "conv2d_NHWC_quantized_interleaved_without_transform.arm_cpu"
-        new_kernel, new_kernel_expr = interleave_transpose_weights(
-            inputs, data, kernel, interleave_A=True
-        )
+        new_kernel, new_kernel_expr = transform_weights(inputs, data, kernel, interleave_A=True)
         new_workload = autotvm.task.args_to_workload(
             [data, new_kernel, strides, padding, dilation, out_dtype, (KH, KW), OC],
             new_workload_name,
@@ -447,9 +458,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
         assert data_layout == "NHWC" and kernel_layout == "HWIO"
         KH, KW, _, OC = get_const_tuple(kernel.shape)
         new_workload_name = "conv2d_NHWC_quantized_native_without_transform.arm_cpu"
-        new_kernel, new_kernel_expr = interleave_transpose_weights(
-            inputs, data, kernel, interleave_A=False
-        )
+        new_kernel, new_kernel_expr = transform_weights(inputs, data, kernel, interleave_A=False)
         new_workload = autotvm.task.args_to_workload(
             [data, new_kernel, strides, padding, dilation, out_dtype, (KH, KW), OC],
             new_workload_name,
diff --git a/python/tvm/topi/arm_cpu/conv2d_gemm.py b/python/tvm/topi/arm_cpu/conv2d_gemm.py
index 90e02c5ab0..e08775dcf3 100644
--- a/python/tvm/topi/arm_cpu/conv2d_gemm.py
+++ b/python/tvm/topi/arm_cpu/conv2d_gemm.py
@@ -70,6 +70,7 @@ def compute_conv2d_gemm_without_weight_transform(
     """Compute conv2d by transforming the input,
     executing GEMM and transforming the output back"""
     batches, IH, IW, IC = get_const_tuple(data.shape)
+    in_dtype = data.dtype
 
     KH, KW = get_const_tuple(kernel_size)
     OC = get_const_int(output_channels)
@@ -90,7 +91,7 @@ def compute_conv2d_gemm_without_weight_transform(
 
     OH = (IH + pad_top + pad_down - dilated_kernel_h) // HSTR + 1
     OW = (IW + pad_left + pad_right - dilated_kernel_w) // WSTR + 1
-    if pad_top or pad_left:
+    if pad_top or pad_left or pad_down or pad_right:
         data_pad = nn.pad(
             data, [0, pad_top, pad_left, 0], [0, pad_down, pad_right, 0], name="data_pad"
         )
@@ -119,8 +120,12 @@ def compute_conv2d_gemm_without_weight_transform(
 
     #  Pad if necessary
     N_transformed = B_interleaved_t.shape[0]
-    tile_rows_B = B_interleaved_t.shape[2]
-    tile_cols_B = B_interleaved_t.shape[3]
+    if in_dtype in ["int8", "uint8"]:
+        tile_N = B_interleaved_t.shape[2]
+        tile_K_B = B_interleaved_t.shape[3]
+    else:
+        tile_N = B_interleaved_t.shape[3]
+        tile_K_B = B_interleaved_t.shape[2]
 
     # Select the tiling strategy for A.
     # The tiling information is chosen to maximize register usage during
@@ -134,34 +139,41 @@ def compute_conv2d_gemm_without_weight_transform(
     # In order to have more information
     #
     target = Target.current(allow_none=False)
-    if target.features.has_matmul_i8:
-        # If smmla/ummla is enabled, we are loading 8 rows from A. Each row
-        # will contain 8 elements
-        tile_rows_A = 8
-        tile_cols_A = 8
-    elif target.features.has_dotprod and interleave_A:
-        # If dot product has been enabled, and we are interleaving A
-        # tile size should be 8x4
-        tile_rows_A = 8
-        tile_cols_A = 4
+    if in_dtype in ["int8", "uint8"]:
+        if target.features.has_matmul_i8:
+            # If smmla/ummla is enabled, we are loading 8 rows from A. Each row
+            # will contain 8 elements
+            tile_M = 8
+            tile_K_A = 8
+        elif target.features.has_dotprod and interleave_A:
+            # If dot product has been enabled, and we are interleaving A
+            # tile size should be 8x4
+            tile_M = 8
+            tile_K_A = 4
+        else:
+            # If either there is no dot product or if we are using a native strategy
+            # tile size should be 4x16
+            tile_M = 4
+            tile_K_A = 16
     else:
-        # If either there is no dot product or if we are using a native strategy
-        # tile size should be 4x16
-        tile_rows_A = 4
-        tile_cols_A = 16
+        # In non-quantized cases, A is not interleaved.
+        # We are loading 4 rows from A.
+        # Each row will contain 4 elements, along the dimension of reduction
+        tile_M = 4
+        tile_K_A = 4
 
     pad_M = 0
     pad_K = 0
 
-    if M % tile_rows_A != 0:
-        pad_M = tile_rows_A - (M % tile_rows_A)
+    if M % tile_M != 0:
+        pad_M = tile_M - (M % tile_M)
 
-    if K % tile_cols_A != 0:
-        pad_K = tile_cols_A - (K % tile_cols_A)
+    if K % tile_K_A != 0:
+        pad_K = tile_K_A - (K % tile_K_A)
 
     M_padded = M + pad_M
     K_padded = K + pad_K
-    N_padded = N_transformed * tile_rows_B
+    N_padded = N_transformed * tile_N
 
     pad_before = (0, 0, 0)
     pad_after = (0, pad_M, pad_K)
@@ -174,131 +186,160 @@ def compute_conv2d_gemm_without_weight_transform(
     idxm = tvm.tir.indexmod
     k = te.reduce_axis((0, K_padded), "k")
 
-    if interleave_A:
-        # Configuration space
-        configure_knobs(cfg, M_padded, K_padded, target)
+    if in_dtype in ["int8", "uint8"]:
+        if interleave_A:
+            # Configuration space
+            configure_knobs(cfg, M_padded, K_padded, target)
 
-        # Pack the input data
-        A_interleaved = te.compute(
-            (batches, M_padded // tile_rows_A, K_padded // tile_cols_A, tile_rows_A, tile_cols_A),
-            lambda b, x, y, z, w: A[b, z + tile_rows_A * x, w + tile_cols_A * y],
-            name="A_interleaved",
-        )
-        target = Target.current(allow_none=False)
-        if target.features.has_matmul_i8:
-            # Execute GEMM. In the case of mmla, we need to enforce the tiling
-            # from the compute. This is because mmla is doing a tiled computation
-            # as well. So we have a big 8x12 tile, with small 2x2 sub-tiles
-            # generated by mmla. In theory we could make the tile 2x2 and
-            # fuse and split during scheduling, but this would not work
-            # because of possible padding
-            C_interleaved = te.compute(
+            # Pack the input data
+            A_interleaved = te.compute(
                 (
                     batches,
-                    M_padded // tile_rows_A,
-                    N_transformed,
-                    tile_rows_A // 2,
-                    tile_rows_B // 2,
-                    2,
-                    2,
-                ),
-                lambda b, x, y, w, z, s, t: te.sum(
-                    A_interleaved[b, x, k // tile_cols_A, 2 * w + s, idxm(k, tile_cols_A)].astype(
-                        "int32"
-                    )
-                    * B_interleaved_t[y, k // tile_cols_B, 2 * z + t, idxm(k, tile_cols_B)].astype(
-                        "int32"
-                    ),
-                    axis=k,
+                    M_padded // tile_M,
+                    K_padded // tile_K_A,
+                    tile_M,
+                    tile_K_A,
                 ),
-                name="C_interleaved",
+                lambda b, x, y, z, w: A[b, z + tile_M * x, w + tile_K_A * y],
+                name="A_interleaved",
             )
-            # Ensure the padding needed for tensorize does not get removed during tir passes
-            # by adding a dummy reference to the specific padded area of the result
-            zero = (
-                tvm.tir.const(1, C_interleaved.dtype)
-                * C_interleaved[
-                    batches - 1,
-                    M // tile_rows_A,
-                    N_transformed - 1,
-                    idxm(M, tile_rows_A) // 2,
-                    tile_rows_B // 2 - 1,
-                    1,
-                    1,
-                ]
-                - tvm.tir.const(1, C_interleaved.dtype)
-                * C_interleaved[
-                    batches - 1,
-                    M // tile_rows_A,
-                    N_transformed - 1,
-                    idxm(M, tile_rows_A) // 2,
-                    tile_rows_B // 2 - 1,
-                    1,
-                    1,
-                ]
-            )
-            # Unpack the result
-            C = te.compute(
-                (batches, M, N),
-                lambda b, x, y: (
-                    C_interleaved[
-                        b,
-                        x // tile_rows_A,
-                        y // tile_rows_B,
-                        idxm(x, tile_rows_A) // 2,
-                        idxm(y, tile_rows_B) // 2,
-                        idxm(idxm(x, tile_rows_A), 2),
-                        idxm(idxm(y, tile_rows_B), 2),
+            target = Target.current(allow_none=False)
+            if target.features.has_matmul_i8:
+                # Execute GEMM. In the case of mmla, we need to enforce the tiling
+                # from the compute. This is because mmla is doing a tiled computation
+                # as well. So we have a big 8x12 tile, with small 2x2 sub-tiles
+                # generated by mmla. In theory we could make the tile 2x2 and
+                # fuse and split during scheduling, but this would not work
+                # because of possible padding
+                C_interleaved = te.compute(
+                    (
+                        batches,
+                        M_padded // tile_M,
+                        N_transformed,
+                        tile_M // 2,
+                        tile_N // 2,
+                        2,
+                        2,
+                    ),
+                    lambda b, x, y, w, z, s, t: te.sum(
+                        A_interleaved[b, x, k // tile_K_A, 2 * w + s, idxm(k, tile_K_A)].astype(
+                            "int32"
+                        )
+                        * B_interleaved_t[y, k // tile_K_B, 2 * z + t, idxm(k, tile_K_B)].astype(
+                            "int32"
+                        ),
+                        axis=k,
+                    ),
+                    name="C_interleaved",
+                )
+                # Ensure the padding needed for tensorize does not get removed during tir passes
+                # by adding a dummy reference to the specific padded area of the result
+                zero = (
+                    tvm.tir.const(1, C_interleaved.dtype)
+                    * C_interleaved[
+                        batches - 1,
+                        M // tile_M,
+                        N_transformed - 1,
+                        idxm(M, tile_M) // 2,
+                        tile_N // 2 - 1,
+                        1,
+                        1,
                     ]
-                    + zero
-                ).astype(out_dtype),
-                name="C",
-            )
+                    - tvm.tir.const(1, C_interleaved.dtype)
+                    * C_interleaved[
+                        batches - 1,
+                        M // tile_M,
+                        N_transformed - 1,
+                        idxm(M, tile_M) // 2,
+                        tile_N // 2 - 1,
+                        1,
+                        1,
+                    ]
+                )
+                # Unpack the result
+                C = te.compute(
+                    (batches, M, N),
+                    lambda b, x, y: (
+                        C_interleaved[
+                            b,
+                            x // tile_M,
+                            y // tile_N,
+                            idxm(x, tile_M) // 2,
+                            idxm(y, tile_N) // 2,
+                            idxm(idxm(x, tile_M), 2),
+                            idxm(idxm(y, tile_N), 2),
+                        ]
+                        + zero
+                    ).astype(out_dtype),
+                    name="C",
+                )
+            else:
+                # Execute GEMM
+                C_interleaved = te.compute(
+                    (batches, M_padded // tile_M, N_transformed, tile_M, tile_N),
+                    lambda b, x, y, w, z: te.sum(
+                        A_interleaved[b, x, k // tile_K_A, w, idxm(k, tile_K_A)].astype("int32")
+                        * B_interleaved_t[y, k // tile_K_B, z, idxm(k, tile_K_B)].astype("int32"),
+                        axis=k,
+                    ),
+                    name="C_interleaved",
+                )
+                # Unpack the result
+                C = te.compute(
+                    (batches, M, N),
+                    lambda b, x, y: C_interleaved[
+                        b,
+                        x // tile_M,
+                        y // tile_N,
+                        idxm(x, tile_M),
+                        idxm(y, tile_N),
+                    ].astype(out_dtype),
+                    name="C",
+                )
+            zero = tvm.tir.const(0)
         else:
-            # Execute GEMM
-            C_interleaved = te.compute(
-                (batches, M_padded // tile_rows_A, N_transformed, tile_rows_A, tile_rows_B),
-                lambda b, x, y, w, z: te.sum(
-                    A_interleaved[b, x, k // tile_cols_A, w, idxm(k, tile_cols_A)].astype("int32")
-                    * B_interleaved_t[y, k // tile_cols_B, z, idxm(k, tile_cols_B)].astype("int32"),
+            # No need to pack/unpack, execute GEMM directly
+            C = te.compute(
+                (batches, M_padded, N_padded),
+                lambda b, x, y: te.sum(
+                    A[b, x, k].astype("int32")
+                    * B_interleaved_t[
+                        y // tile_N,
+                        k // tile_K_B,
+                        idxm(y, tile_N),
+                        idxm(k, tile_K_B),
+                    ].astype("int32"),
                     axis=k,
                 ),
-                name="C_interleaved",
-            )
-            # Unpack the result
-            C = te.compute(
-                (batches, M, N),
-                lambda b, x, y: C_interleaved[
-                    b,
-                    x // tile_rows_A,
-                    y // tile_rows_B,
-                    idxm(x, tile_rows_A),
-                    idxm(y, tile_rows_B),
-                ].astype(out_dtype),
                 name="C",
             )
-        zero = tvm.tir.const(0)
+
+            # We need to ensure that infer bound pass does not remove the padding
+            # which is necessary for the tensorizations to work. So we need to
+            # add a dummy reference to the padding area of the result
+            zero = (
+                tvm.tir.const(1, C.dtype) * C[0, M_padded - 1, N_padded - 1]
+                - tvm.tir.const(1, C.dtype) * C[0, M_padded - 1, N_padded - 1]
+            )
     else:
-        # No need to pack/unpack, execute GEMM directly
+        # Configuration space
+        configure_knobs(cfg, M_padded, K_padded, target)
+
         C = te.compute(
             (batches, M_padded, N_padded),
             lambda b, x, y: te.sum(
-                A[b, x, k].astype("int32")
+                A[b, x, k].astype(in_dtype)
                 * B_interleaved_t[
-                    y // tile_rows_B, k // tile_cols_B, idxm(y, tile_rows_B), idxm(k, tile_cols_B)
-                ].astype("int32"),
+                    y // tile_N,
+                    k // tile_K_B,
+                    idxm(k, tile_K_B),
+                    idxm(y, tile_N),
+                ].astype(in_dtype),
                 axis=k,
             ),
             name="C",
         )
-
-        # We need to ensure that infer bound pass does not remove the padding
-        # which is necessary for the tensorizations to work. So we need to
-        # add a dummy reference to the padding area of the result
-        zero = (
-            tvm.tir.const(1, C.dtype) * C[0, M_padded - 1, N_padded - 1]
-            - tvm.tir.const(1, C.dtype) * C[0, M_padded - 1, N_padded - 1]
-        )
+        zero = tvm.tir.const(0)
 
     # Reshape the result into a convolution output
     out_shape = (batches, OH, OW, OC)
@@ -417,14 +458,35 @@ def schedule_conv2d_gemm_native(cfg, s, out, final_out):
     # Computation
     b, x, y = C.op.axis
     (k,) = C.op.reduce_axis
-    k_outer, k_inner = s[C].split(k, 16)
-    y_tile_size = 16
-    x_outer, y_outer, x_inner, y_inner = s[C].tile(x, y, x_factor=4, y_factor=y_tile_size)
-    s[C].reorder(b, x_outer, y_outer, k_outer, x_inner, y_inner, k_inner)
-    gemm_acc = gemm_acc_nx16_int8_int8_int32(in_type, rows=1)
-    s[C].unroll(x_inner)
-    s[C].tensorize(y_inner, gemm_acc)
-    s[C].parallel(x_outer)
+
+    if in_type in ["int8", "uint8"]:
+        k_outer, k_inner = s[C].split(k, 16)
+        y_tile_size = 16
+        x_outer, y_outer, x_inner, y_inner = s[C].tile(x, y, x_factor=4, y_factor=y_tile_size)
+        s[C].reorder(b, x_outer, y_outer, k_outer, x_inner, y_inner, k_inner)
+        gemm_acc = gemm_acc_nx16_int8_int8_int32(in_type, rows=1)
+        s[C].unroll(x_inner)
+        s[C].tensorize(y_inner, gemm_acc)
+        s[C].parallel(x_outer)
+    else:
+        k_outer, k_inner = s[C].split(k, 4)
+        y_tile_size = 16
+        x_outer, y_outer, x_inner, y_inner = s[C].tile(x, y, x_factor=4, y_factor=y_tile_size)
+        y_inner_outer, y_inner_inner = s[C].split(y_inner, 4)
+        b_x_outer_fused = s[C].fuse(b, x_outer)
+        s[C].parallel(b_x_outer_fused)
+        s[C].reorder(
+            b_x_outer_fused,
+            y_outer,
+            k_outer,
+            k_inner,
+            y_inner_outer,
+            x_inner,
+            y_inner_inner,
+        )
+        s[C].unroll(y_inner_outer)
+        s[C].unroll(x_inner)
+        s[C].vectorize(y_inner_inner)
 
     # Input transform
     if A.op.name == "A_padded_K" or A.op.name == "A_padded_M":
@@ -450,7 +512,11 @@ def schedule_conv2d_gemm_native(cfg, s, out, final_out):
 
         split_factor = 16
         n_size = data_im2col.shape[2]
-        if n_size % split_factor != 0:
+        if n_size % 16 == 0:
+            split_factor = 16
+        elif n_size % 8 == 0:
+            split_factor = 8
+        else:
             # Split by kernel area (KH * KW) to ensure proper vectorization
             ic = data_im2col.op.input_tensors[0].shape[3]
             split_factor = n_size // ic
@@ -466,6 +532,13 @@ def schedule_conv2d_gemm_native(cfg, s, out, final_out):
     else:
         s[data_im2col].compute_at(s[C], x_inner)
 
+    A_pad = data_im2col.op.input_tensors[0]
+    if A_pad.op.name == "data_pad":
+        n, h, w, c = A_pad.op.axis
+        n_h_fused = s[A_pad].fuse(n, h)
+        s[A_pad].parallel(n_h_fused)
+        s[A_pad].vectorize(c)
+
     # Output transform
     if out != final_out:
         n, h, w, c = out.op.axis
diff --git a/python/tvm/topi/arm_cpu/conv2d_int8.py b/python/tvm/topi/arm_cpu/conv2d_int8.py
index 6b2c9527a4..721385c189 100644
--- a/python/tvm/topi/arm_cpu/conv2d_int8.py
+++ b/python/tvm/topi/arm_cpu/conv2d_int8.py
@@ -25,12 +25,7 @@ from ..nn.conv2d import _get_workload as _get_conv2d_workload, unpack_NCHWc_to_n
 from ..x86.conv2d_int8 import _pack_data
 from ..nn.utils import get_pad_tuple
 from .tensor_intrin import dot_int8_int8_int32_neon_82, dot_int8_int8_int32_neon
-from .conv2d_gemm import (
-    compute_conv2d_gemm_without_weight_transform,
-    schedule_conv2d_gemm_interleaved,
-    schedule_conv2d_gemm_native,
-)
-from .arm_utils import get_tiling_B_interleaved_t
+from .conv2d import compute_conv2d_NHWC, compute_conv2d_NHWC_without_transform, schedule_conv2d_NHWC
 
 
 def _get_default_config(cfg, data, kernel, strides, padding, dilation, out_dtype):
@@ -208,75 +203,6 @@ def schedule_conv2d_nchw_int8(outs):
     return schedule_conv2d_NCHWc_int8(outs)
 
 
-def _compute_conv2d_NHWC_quantized(
-    cfg, data, kernel, strides, padding, dilation, out_dtype, interleave_A
-):
-    N, IH, IW, IC = get_const_tuple(data.shape)
-    KH, KW, _, OC = get_const_tuple(kernel.shape)
-    tile_rows_B, tile_cols_B = get_tiling_B_interleaved_t(interleave_A)
-
-    kernel = nn.conv2d_gemm_weight_transform(kernel, tile_rows_B, tile_cols_B)
-    return compute_conv2d_gemm_without_weight_transform(
-        cfg, data, kernel, strides, padding, dilation, out_dtype, (KH, KW), OC, interleave_A
-    )
-
-
-def _compute_conv2d_NHWC_quantized_without_transform(
-    cfg,
-    data,
-    B,
-    strides,
-    padding,
-    dilation,
-    out_dtype,
-    kernel_size=None,
-    output_channels=None,
-    interleave_A=False,
-):
-    return compute_conv2d_gemm_without_weight_transform(
-        cfg,
-        data,
-        B,
-        strides,
-        padding,
-        dilation,
-        out_dtype,
-        kernel_size,
-        output_channels,
-        interleave_A,
-    )
-
-
-def _schedule_conv2d_NHWC_quantized(cfg, outs, interleave_A):
-    """Create schedule for tensors"""
-    s = te.create_schedule([x.op for x in outs])
-    # Vectorize the output and then inline all the rest
-    out = outs[0]
-    n, h, w, c = out.op.axis
-    n_h_fused = s[out].fuse(n, h)
-    outer, inner = s[out].split(c, 4)
-    s[out].vectorize(inner)
-    s[out].parallel(n_h_fused)
-
-    def _callback(op):
-        """Traverse operators from computation graph"""
-        if op.name == "conv2d_gemm_output":
-            conv_out = op.output(0)
-            if interleave_A:
-                schedule_conv2d_gemm_interleaved(cfg, s, conv_out, out)
-            else:
-                schedule_conv2d_gemm_native(cfg, s, conv_out, out)
-            if out != conv_out:
-                s[conv_out].compute_at(s[out], inner)
-            else:
-                C = conv_out.op.input_tensors[0]
-                if interleave_A:
-                    s[C].compute_at(s[out], inner)
-
-    traverse_inline(s, outs[0].op, _callback)
-    return s
-
-
 # Interleaved schedules: those schedule will interleave the input data. The
 # weights are interleaved and transposed
 @autotvm.register_topi_compute("conv2d_NHWC_quantized_interleaved.arm_cpu")
@@ -284,9 +210,7 @@ def compute_conv2d_NHWC_quantized_interleaved(
     cfg, data, kernel, strides, padding, dilation, out_dtype
 ):
     """Interface for interleaved compute_conv2d_NHWC_quantized_interleaved"""
-    return _compute_conv2d_NHWC_quantized(
-        cfg, data, kernel, strides, padding, dilation, out_dtype, True
-    )
+    return compute_conv2d_NHWC(cfg, data, kernel, strides, padding, dilation, out_dtype, True)
 
 
 @autotvm.register_topi_compute("conv2d_NHWC_quantized_interleaved_without_transform.arm_cpu")
@@ -294,7 +218,7 @@ def compute_conv2d_NHWC_quantized_interleaved_without_transform(
     cfg, data, kernel, strides, padding, dilation, out_dtype, kernel_size, output_channels
 ):
     """Interface for interleaved 
compute_conv2d_NHWC_quantized_interleaved_without_transform"""
-    return _compute_conv2d_NHWC_quantized_without_transform(
+    return compute_conv2d_NHWC_without_transform(
         cfg, data, kernel, strides, padding, dilation, out_dtype, kernel_size, output_channels, True
     )
 
@@ -302,13 +226,13 @@ def compute_conv2d_NHWC_quantized_interleaved_without_transform(
 @autotvm.register_topi_schedule("conv2d_NHWC_quantized_interleaved.arm_cpu")
 def schedule_conv2d_NHWC_quantized_interleaved(cfg, outs):
     """Interface for interleaved schedule_conv2d_NHWC_quantized_interleaved"""
-    return _schedule_conv2d_NHWC_quantized(cfg, outs, True)
+    return schedule_conv2d_NHWC(cfg, outs, True)
 
 
 @autotvm.register_topi_schedule("conv2d_NHWC_quantized_interleaved_without_transform.arm_cpu")
 def schedule_conv2d_NHWC_quantized_interleaved_without_transform(cfg, outs):
     """Interface for interleaved schedule_conv2d_NHWC_quantized_interleaved"""
-    return _schedule_conv2d_NHWC_quantized(cfg, outs, True)
+    return schedule_conv2d_NHWC(cfg, outs, True)
 
 
 # Native schedules: those schedule won't interleave A (which is left in its native form).
@@ -316,9 +240,7 @@ def schedule_conv2d_NHWC_quantized_interleaved_without_transform(cfg, outs):
 @autotvm.register_topi_compute("conv2d_NHWC_quantized_native.arm_cpu")
 def compute_conv2d_NHWC_quantized_native(cfg, data, kernel, strides, padding, dilation, out_dtype):
     """Interface for native compute_conv2d_NHWC_quantized"""
-    return _compute_conv2d_NHWC_quantized(
-        cfg, data, kernel, strides, padding, dilation, out_dtype, False
-    )
+    return compute_conv2d_NHWC(cfg, data, kernel, strides, padding, dilation, out_dtype, False)
 
 
 @autotvm.register_topi_compute("conv2d_NHWC_quantized_native_without_transform.arm_cpu")
@@ -326,7 +248,7 @@ def compute_conv2d_NHWC_quantized_native_without_transform(
     cfg, data, kernel, strides, padding, dilation, out_dtype, kernel_size, output_channels
 ):
     """Interface for compute_conv2d_NHWC_quantized_native_without_transform"""
-    return _compute_conv2d_NHWC_quantized_without_transform(
+    return compute_conv2d_NHWC_without_transform(
         cfg,
         data,
         kernel,
@@ -343,10 +265,10 @@ def compute_conv2d_NHWC_quantized_native_without_transform(
 @autotvm.register_topi_schedule("conv2d_NHWC_quantized_native.arm_cpu")
 def schedule_conv2d_NHWC_quantized_native(cfg, outs):
     """Interface for native schedule_conv2d_NHWC_quantized"""
-    return _schedule_conv2d_NHWC_quantized(cfg, outs, False)
+    return schedule_conv2d_NHWC(cfg, outs, False)
 
 
 @autotvm.register_topi_schedule("conv2d_NHWC_quantized_native_without_transform.arm_cpu")
 def schedule_conv2d_NHWC_quantized_native_without_transform(cfg, outs):
     """Interface for native schedule_conv2d_NHWC_quantized"""
-    return _schedule_conv2d_NHWC_quantized(cfg, outs, False)
+    return schedule_conv2d_NHWC(cfg, outs, False)
diff --git a/python/tvm/topi/nn/conv2d.py b/python/tvm/topi/nn/conv2d.py
index 75f72ee93d..7516bff702 100644
--- a/python/tvm/topi/nn/conv2d.py
+++ b/python/tvm/topi/nn/conv2d.py
@@ -615,17 +615,17 @@ def conv2d_NCHWc_int8(
     )
 
 
-def conv2d_gemm_weight_transform(kernel, tile_rows, tile_cols):
+def conv2d_gemm_weight_transform(kernel, tile_N, tile_K):
     """Weight transformation for winograd
 
     Parameters
     ----------
     kernel: Tensor
         The raw kernel tensor with layout "NHWC".
-    tile_rows: int
-        Tile rows of the weight transformation for ConvGemm.
-    tile_cols: int
-        Tile columns of the weight transformation for ConvGemm.
+    tile_N: int
+        Tile size across N axis of the weight transformation for ConvGemm. (N = OC)
+    tile_K: int
+        Tile size across K axis of the weight transformation for ConvGemm. (K = KW * KH * IC)
 
     Returns
     -------
@@ -640,7 +640,7 @@ def conv2d_gemm_weight_transform(kernel, tile_rows, tile_cols):
         (K, N), lambda x, y: kernel[(x // IC) // KW, (x // IC) % KW, x % IC, y], "weight_flatten"
     )
 
-    pad_N, pad_K = tvm.topi.arm_cpu.arm_utils.get_conv2d_weights_padding(N, K, tile_rows, tile_cols)
+    pad_N, pad_K = tvm.topi.arm_cpu.arm_utils.get_conv2d_weights_padding(N, K, tile_N, tile_K)
 
     N_padded = N + pad_N
     K_padded = K + pad_K
@@ -650,11 +650,19 @@ def conv2d_gemm_weight_transform(kernel, tile_rows, tile_cols):
             kernel_flat, pad_before=(0, 0), pad_after=(pad_K, pad_N), name="weight_padding"
         )
 
-    return te.compute(
-        (N_padded // tile_rows, K_padded // tile_cols, tile_rows, tile_cols),
-        lambda x, y, z, w: kernel_flat[w + tile_cols * y, z + tile_rows * x],
-        name="weight_block_reshape",
-    )
+    if kernel.dtype in ["int8", "uint8"]:
+        B_inter_t = te.compute(
+            (N_padded // tile_N, K_padded // tile_K, tile_N, tile_K),
+            lambda x, y, z, w: kernel_flat[w + tile_K * y, z + tile_N * x],
+            name="weight_block_reshape",
+        )
+    else:
+        B_inter_t = te.compute(
+            (N_padded // tile_N, K_padded // tile_K, tile_K, tile_N),
+            lambda x, y, z, w: kernel_flat[z + tile_K * y, w + tile_N * x],
+            name="weight_block_reshape",
+        )
+    return B_inter_t
 
 
 def conv2d_winograd_weight_transform(kernel, tile_size):
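
For reference, the full transform implemented by conv2d_gemm_weight_transform can be reproduced in NumPy. This is a minimal sketch; it assumes get_conv2d_weights_padding simply rounds N and K up to the next multiple of the tile sizes (that helper is not shown in this diff), and the sizes are illustrative.

    import numpy as np

    # Hypothetical sizes; tile_N/tile_K are normally chosen per target.
    KH, KW, IC, OC = 3, 3, 8, 16
    tile_N, tile_K = 4, 16
    kernel = np.random.rand(KH, KW, IC, OC).astype("float32")  # HWIO

    # Flatten to (K, N), matching weight_flatten above (K = KH*KW*IC, N = OC).
    K, N = KH * KW * IC, OC
    kernel_flat = kernel.reshape(K, N)

    # Pad K and N up to tile multiples (assumed round-up behaviour).
    pad_K = (-K) % tile_K
    pad_N = (-N) % tile_N
    kernel_flat = np.pad(kernel_flat, ((0, pad_K), (0, pad_N)))
    K_padded, N_padded = K + pad_K, N + pad_N

    # fp32 branch: tile and interleave; each tile's interior is unchanged.
    W_interleaved = kernel_flat.reshape(
        K_padded // tile_K, tile_K, N_padded // tile_N, tile_N
    ).transpose(2, 0, 1, 3)
    assert W_interleaved.shape == (N_padded // tile_N, K_padded // tile_K, tile_K, tile_N)

    # int8/uint8 branch: additionally transpose inside each tile.
    W_interleaved_t = W_interleaved.transpose(0, 1, 3, 2)
    assert W_interleaved_t.shape == (N_padded // tile_N, K_padded // tile_K, tile_N, tile_K)
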
diff --git a/src/relay/op/nn/convolution.cc b/src/relay/op/nn/convolution.cc
index 75e28c8d07..547b533ccc 100644
--- a/src/relay/op/nn/convolution.cc
+++ b/src/relay/op/nn/convolution.cc
@@ -43,10 +43,10 @@ Expr MakeConvWinogradWeightTransform(Expr weight, int tile_size, std::string op_
   return Call(op, {weight}, Attrs(attrs), {});
 }
 
-Expr MakeConvGemmWeightTransform(Expr weight, int tile_rows, int tile_cols, std::string op_name) {
+Expr MakeConvGemmWeightTransform(Expr weight, int tile_N, int tile_K, std::string op_name) {
   auto attrs = make_object<ConvGemmWeightTransformAttrs>();
-  attrs->tile_rows = tile_rows;
-  attrs->tile_cols = tile_cols;
+  attrs->tile_N = tile_N;
+  attrs->tile_K = tile_K;
   const Op& op = Op::Get(op_name);
   return Call(op, {weight}, Attrs(attrs), {});
 }
@@ -1497,13 +1497,14 @@ RELAY_REGISTER_OP("nn.contrib_conv2d_gemm_without_weight_transform")
 TVM_REGISTER_NODE_TYPE(ConvGemmWeightTransformAttrs);
 
 // Gemm convolution shape relations
-// In order to run GEMM we need to block-transpose and interleave the K x N weights matrix W.
-// The high level idea is to subdivide W in tiles of tile_cols x tile_rows, and transpose and
-// interleave them. The final output is a [N//tile_rows, K//tile_cols, tile_rows, tile_cols]
+// In order to run GEMM we need to transform the K x N weights matrix W.
+//
+// For integer datatypes, the high level idea is to subdivide W in tiles of tile_K x tile_N, and
+// transpose and interleave them. The final output is a [N//tile_N, K//tile_K, tile_N, tile_K]
 // matrix that we call W_interleaved_t.
 //
-// In the following picture, we show how the first [tile_cols,tile_rows] block of W is transformed
-// for tile_rows = 4 and tile_cols = 16
+// In the following picture, we show how the first [tile_K,tile_N] block of W is transformed
+// for tile_N = 4 and tile_K = 16
 //
 //              W[0,0,:,:]                        W_interleaved_t[0,0,:,:]
 //  +-------------------------------+     +-----------------------------------+
@@ -1515,9 +1516,31 @@ TVM_REGISTER_NODE_TYPE(ConvGemmWeightTransformAttrs);
 //  |W[15,0] W[15,1] W[15,2] W[15,3]|
 //  +-------------------------------+
 //
-// Tile columns is usually the direction of the reduction. So, if our target can reduce k elements
-// at the time, we should set tile_cols = k.
-// Tile rows is connected with the number of registers available for the given target.
+// Alternatively, for floating point datatypes, we subdivide W in tiles of tile_K x tile_N size,
+// then interleave these tiles, without transposing. The final output is a [N//tile_N, K//tile_K,
+// tile_K, tile_N] matrix called W_interleaved.
+//
+// In the following illustration, we show how the tiles are interleaved.
+// Note that the inside of each tile is kept unchanged during this transformation.
+//
+//           W[:,:,:,:]               W_interleaved[:,:,:,:]
+//  +--------+--------+--------+       +--------+--------+
+//  |        |        |        |       |        |        |
+//  | tile_1 | tile_2 | tile_3 |       | tile_1 | tile_4 |
+//  |        |        |        |  --\  |        |        |
+//  +--------+--------+--------+  --/  +--------+--------+
+//  |        |        |        |       |        |        |
+//  | tile_4 | tile_5 | tile_6 |       | tile_2 | tile_5 |
+//  |        |        |        |       |        |        |
+//  +--------+--------+--------+       +--------+--------+
+//                                     |        |        |
+//                                     | tile_3 | tile_6 |
+//                                     |        |        |
+//                                     +--------+--------+
+//
+// Tile K is the direction of the reduction in both cases. So, if our target can reduce k elements
+// at a time, we should set tile_K = k.
+// Tile N is connected with the number of registers available for the given target.
 //
 bool Conv2DGemmWeightTransformRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                                   const TypeReporter& reporter) {
@@ -1527,8 +1550,8 @@ bool Conv2DGemmWeightTransformRel(const Array<Type>& types, int num_inputs, cons
 
   const ConvGemmWeightTransformAttrs* param = attrs.as<ConvGemmWeightTransformAttrs>();
   ICHECK(param != nullptr);
-  int n = param->tile_rows;
-  int k = param->tile_cols;
+  int n = param->tile_N;
+  int k = param->tile_K;
 
   ICHECK_EQ(weight->shape.size(), 4) << "Only support HWIO kernel layout";
 
@@ -1544,12 +1567,21 @@ bool Conv2DGemmWeightTransformRel(const Array<Type>& types, int num_inputs, cons
   const auto N_padded = N + pad_N;
   const auto K_padded = K + pad_K;
 
-  Array<IndexExpr> oshape{
-      indexdiv(N_padded, n),
-      indexdiv(K_padded, k),
-      n,
-      k,
-  };
+  Array<IndexExpr> oshape;
+  if (weight->dtype.bits() == 8 && (weight->dtype.is_int() || weight->dtype.is_uint()))
+    oshape = {
+        indexdiv(N_padded, n),
+        indexdiv(K_padded, k),
+        n,
+        k,
+    };
+  else
+    oshape = {
+        indexdiv(N_padded, n),
+        indexdiv(K_padded, k),
+        k,
+        n,
+    };
 
   reporter->Assign(types[1], TensorType(oshape, weight->dtype));
   return true;
diff --git a/tests/python/integration/test_arm_aprofile.py b/tests/python/integration/test_arm_aprofile.py
index c38217a1b1..006ad5f359 100644
--- a/tests/python/integration/test_arm_aprofile.py
+++ b/tests/python/integration/test_arm_aprofile.py
@@ -49,6 +49,7 @@ def test_conv2d(dtype):
         invar,
         weight,
         kernel_size=kernel_size,
+        channels=2,
         strides=(1, 1),
         padding=(0, 0),
         dilation=(1, 1),
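
With channels now supplied, the conv2d attributes are fully specified before layout rewriting. For context, a self-contained sketch of the kind of graph this test constructs (the shapes and variable names here are illustrative, not the test's actual values):

    import tvm
    from tvm import relay

    dtype = "float32"
    data = relay.var("data", shape=(1, 4, 4, 4), dtype=dtype)      # NHWC
    weight = relay.var("weight", shape=(3, 3, 4, 2), dtype=dtype)  # HWIO
    out = relay.nn.conv2d(
        data,
        weight,
        kernel_size=(3, 3),
        channels=2,  # explicit, matching the O axis of the HWIO weights
        strides=(1, 1),
        padding=(0, 0),
        dilation=(1, 1),
        data_layout="NHWC",
        kernel_layout="HWIO",
        out_dtype=dtype,
    )
    mod = tvm.IRModule.from_expr(relay.Function(relay.analysis.free_vars(out), out))
    mod = relay.transform.InferType()(mod)  # inferred output type: (1, 2, 2, 2) NHWC
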
diff --git a/tests/python/relay/strategy/test_select_implementation.py b/tests/python/relay/strategy/test_select_implementation.py
index d7dd0abbc4..f9b1a002a8 100644
--- a/tests/python/relay/strategy/test_select_implementation.py
+++ b/tests/python/relay/strategy/test_select_implementation.py
@@ -57,6 +57,39 @@ def test_concatenate(target, expected_implementation):
     assert impl.name == expected_implementation
 
 
+def _get_conv2d_impl(dtype, target):
+    """Returns selected conv2d implementation for a given datatype and 
target"""
+    data_shape = (1, 1, 1, 4)
+    weight_shape = (1, 1, 4, 4)
+    data_layout = "NHWC"
+    kernel_layout = "HWIO"
+    channels = 4
+    kernel_size = (1, 1)
+
+    out = relay.nn.conv2d(
+        relay.var("data", shape=data_shape, dtype=dtype),
+        relay.var("weight", shape=weight_shape, dtype=dtype),
+        kernel_size=kernel_size,
+        channels=channels,
+        data_layout=data_layout,
+        kernel_layout=kernel_layout,
+        out_dtype=dtype,
+    )
+
+    with target:
+        out = run_opt_pass(out, relay.transform.AlterOpLayout())
+        impl, _ = relay.backend.te_compiler.select_implementation(
+            out.op,
+            out.attrs,
+            [te.placeholder(data_shape, dtype), te.placeholder(weight_shape, 
dtype)],
+            out.checked_type,
+            target,
+            use_autotvm=False,
+        )
+
+    return impl.name
+
+
 @pytest.mark.parametrize(
     "target,expected_impl",
     [
@@ -93,37 +126,78 @@ def test_concatenate(target, expected_implementation):
 )
 def test_int8_conv2d(target, expected_impl):
     target = tvm.target.Target(target)
-
     dtype = "int8"
-    data_shape = (1, 1, 1, 4)
-    weight_shape = (1, 1, 4, 4)
-    data_layout = "NHWC"
-    kernel_layout = "HWIO"
-    channels = 4
-    kernel_size = (1, 1)
 
-    out = relay.nn.conv2d(
-        relay.var("data", shape=data_shape, dtype=dtype),
-        relay.var("weight", shape=weight_shape, dtype=dtype),
-        kernel_size=kernel_size,
-        channels=channels,
-        data_layout=data_layout,
-        kernel_layout=kernel_layout,
-        out_dtype=dtype,
-    )
+    selected_impl = _get_conv2d_impl(dtype, target)
+    assert selected_impl == expected_impl
 
-    with target:
-        out = run_opt_pass(out, relay.transform.AlterOpLayout())
-        impl, _ = relay.backend.te_compiler.select_implementation(
-            out.op,
-            out.attrs,
-            [te.placeholder(data_shape, dtype), te.placeholder(weight_shape, 
dtype)],
-            out.checked_type,
-            target,
-            use_autotvm=False,
-        )
 
-    assert impl.name == expected_impl
+@pytest.mark.parametrize(
+    "target,expected_impl",
+    [
+        ("llvm -device=arm_cpu", "conv2d_nhwc_spatial_pack.arm_cpu"),
+        (
+            "llvm -device=arm_cpu -mtriple=armv8l-linux-gnu -mattr=+neon",
+            "conv2d_nhwc_spatial_pack.arm_cpu",
+        ),
+        (
+            "llvm -device=arm_cpu -mtriple=aarch64-linux-gnu",
+            "conv2d_NHWC_hybrid_without_transform.arm_cpu",
+        ),
+        (
+            "llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+neon",
+            "conv2d_NHWC_hybrid_without_transform.arm_cpu",
+        ),
+        (
+            "llvm --device=arm_cpu --mtriple=aarch64-linux-gnu 
-mattr=+v8.2a,+neon",
+            "conv2d_NHWC_hybrid_without_transform.arm_cpu",
+        ),
+        (
+            "llvm --device=arm_cpu --mtriple=aarch64-linux-gnu -mattr=+v9a",
+            "conv2d_NHWC_hybrid_without_transform.arm_cpu",
+        ),
+    ],
+)
+def test_fp32_conv2d(target, expected_impl):
+    target = tvm.target.Target(target)
+    dtype = "float32"
+
+    selected_impl = _get_conv2d_impl(dtype, target)
+    assert selected_impl == expected_impl
+
+
+@pytest.mark.parametrize(
+    "target,expected_impl",
+    [
+        ("llvm -device=arm_cpu", "conv2d_nhwc_spatial_pack.arm_cpu"),
+        (
+            "llvm -device=arm_cpu -mtriple=armv8l-linux-gnu -mattr=+neon",
+            "conv2d_nhwc_spatial_pack.arm_cpu",
+        ),
+        (
+            "llvm -device=arm_cpu -mtriple=aarch64-linux-gnu",
+            "conv2d_NHWC_hybrid_without_transform.arm_cpu",
+        ),
+        (
+            "llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+neon",
+            "conv2d_NHWC_hybrid_without_transform.arm_cpu",
+        ),
+        (
+            "llvm --device=arm_cpu --mtriple=aarch64-linux-gnu 
-mattr=+v8.2a,+neon",
+            "conv2d_NHWC_hybrid_without_transform.arm_cpu",
+        ),
+        (
+            "llvm --device=arm_cpu --mtriple=aarch64-linux-gnu -mattr=+v9a",
+            "conv2d_NHWC_hybrid_without_transform.arm_cpu",
+        ),
+    ],
+)
+def test_fp16_conv2d(target, expected_impl):
+    target = tvm.target.Target(target)
+    dtype = "float16"
+
+    selected_impl = _get_conv2d_impl(dtype, target)
+    assert selected_impl == expected_impl
 
 
 @pytest.mark.parametrize(
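
The _get_conv2d_impl helper is also handy outside pytest. A hypothetical interactive use, with the target string taken from the parametrized cases above (assumes the definitions from this test module are in scope):

    import tvm

    impl = _get_conv2d_impl(
        "float32", tvm.target.Target("llvm -device=arm_cpu -mtriple=aarch64-linux-gnu")
    )
    assert impl == "conv2d_NHWC_hybrid_without_transform.arm_cpu"
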
diff --git a/tests/python/topi/test_topi_conv2d_nhwc.py b/tests/python/topi/test_topi_conv2d_nhwc.py
index e60cf12aa8..05f9cb9c05 100644
--- a/tests/python/topi/test_topi_conv2d_nhwc.py
+++ b/tests/python/topi/test_topi_conv2d_nhwc.py
@@ -16,6 +16,7 @@
 # under the License.
 """Example code to do convolution."""
 import os
+import platform
 import numpy as np
 import tvm
 from tvm import te
@@ -45,6 +46,19 @@ _conv2d_nhwc_implement = {
     "hls": (topi.nn.conv2d_nhwc, topi.hls.schedule_conv2d_nhwc),
 }
 
+device = tvm.testing.parameter(
+    (
+        "llvm --device arm_cpu --mtriple aarch64-linux-gnu",
+        topi.arm_cpu.conv2d_nhwc_spatial_pack,
+        topi.arm_cpu.schedule_conv2d_nhwc_spatial_pack,
+    ),
+    (
+        "llvm --device arm_cpu --mtriple aarch64-linux-gnu -mattr=+v8.2a",
+        topi.arm_cpu.compute_conv2d_NHWC_hybrid,
+        topi.arm_cpu.schedule_conv2d_NHWC_hybrid,
+    ),
+)
+
 dtype = tvm.testing.parameter("float32")
 
 batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation = tvm.testing.parameters(
@@ -77,6 +91,31 @@ def ref_data(dtype, batch, in_channel, in_size, num_filter, kernel, stride, padd
     return a_np, w_np, b_np
 
 
+def test_conv2d_nhwc_gemm_fp32(device, ref_data, dtype, stride, padding, dilation):
+    a_np, w_np, b_np = ref_data
+
+    A = te.placeholder(a_np.shape, name="A", dtype=dtype)
+    W = te.placeholder(w_np.shape, name="W", dtype=dtype)
+
+    target, compute, schedule = device
+    dev = tvm.device(target, 0)
+
+    with tvm.target.Target(target):
+        B = compute(A, W, stride, padding, dilation, dtype)
+        s = schedule([B])
+    a = tvm.nd.array(a_np, dev)
+    w = tvm.nd.array(w_np, dev)
+    b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), dev)
+    func = tvm.build(s, [A, W, B], target)
+
+    build_only = platform.machine() != "aarch64"
+    if build_only:
+        return
+
+    func(a, w, b)
+    tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-5)
+
+
 def test_conv2d_nhwc_hwio(target, dev, ref_data, dtype, stride, padding, dilation):
     a_np, w_np, b_np = ref_data
 
