This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 8bdd54b2fd [TOPI] Fix SME conv2d schedule import and intrin argument (#17040)
8bdd54b2fd is described below

commit 8bdd54b2fd652f064dc7b0f56a89688fb555bf1e
Author: Luke Hutton <luke.hut...@arm.com>
AuthorDate: Wed May 29 16:44:46 2024 +0100

    [TOPI] Fix SME conv2d schedule import and intrin argument (#17040)
    
    Fixes a merge conflict between #16981 and #17003.
    
    Change-Id: Ifcc983ef0b8c00250568a048fd682933adfdcde4
---
 python/tvm/topi/arm_cpu/conv2d.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/python/tvm/topi/arm_cpu/conv2d.py b/python/tvm/topi/arm_cpu/conv2d.py
index 58c909301e..d0fe251e7e 100644
--- a/python/tvm/topi/arm_cpu/conv2d.py
+++ b/python/tvm/topi/arm_cpu/conv2d.py
@@ -729,7 +729,7 @@ def schedule_conv2d_NHWC_hybrid_TIR(sch: tvm.tir.Schedule):
         # pylint: disable=import-outside-toplevel
         from tvm.topi.arm_cpu.pstate_attributes import SMEAttributes
         from tvm.tir.tensor_intrin.arm_cpu import (
-            ARM_SME_2SVLx2SVL_TRANSPOSE_INTERLEAVE,
+            ARM_SME_2SVLx2SVL_FP32_TRANSPOSE_INTERLEAVE,
             ARM_SME_2SVLx2SVL_GEMM_INTERLEAVED_MOPA,
             ARM_SME_INIT,
             get_sme_gemm_interleaved_mopa_2svlx2svl_intrin,
@@ -743,7 +743,7 @@ def schedule_conv2d_NHWC_hybrid_TIR(sch: tvm.tir.Schedule):
         ko, ki = sch.split(k, factors=(None, tile_K), disable_predication=True)
         sch.parallel(b)
         sch.reorder(b, ko, mo, ki, mi)
-        sch.tensorize(ki, ARM_SME_2SVLx2SVL_TRANSPOSE_INTERLEAVE)
+        sch.tensorize(ki, ARM_SME_2SVLx2SVL_FP32_TRANSPOSE_INTERLEAVE)
 
         # Split and reorder the loops of the GeMM for tensorization
         b, m, n, k = sch.get_loops(gemm_block)
@@ -760,7 +760,7 @@ def schedule_conv2d_NHWC_hybrid_TIR(sch: tvm.tir.Schedule):
         sme_gemm_interleaved_intrin_name = ARM_SME_2SVLx2SVL_GEMM_INTERLEAVED_MOPA + f"_{K_padded}"
         tvm.tir.TensorIntrin.register(
             sme_gemm_interleaved_intrin_name,
-            *get_sme_gemm_interleaved_mopa_2svlx2svl_intrin(K_padded),
+            *get_sme_gemm_interleaved_mopa_2svlx2svl_intrin(K_padded, dtype),
             override=True,
         )
         sch.tensorize(mi, sme_gemm_interleaved_intrin_name)

Reply via email to