This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push: new 8bdd54b2fd [TOPI] Fix SME conv2d schedule import and intrin argument (#17040) 8bdd54b2fd is described below commit 8bdd54b2fd652f064dc7b0f56a89688fb555bf1e Author: Luke Hutton <luke.hut...@arm.com> AuthorDate: Wed May 29 16:44:46 2024 +0100 [TOPI] Fix SME conv2d schedule import and intrin argument (#17040) Fixes a merge conflict between #16981 and #17003. Change-Id: Ifcc983ef0b8c00250568a048fd682933adfdcde4 --- python/tvm/topi/arm_cpu/conv2d.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/tvm/topi/arm_cpu/conv2d.py b/python/tvm/topi/arm_cpu/conv2d.py index 58c909301e..d0fe251e7e 100644 --- a/python/tvm/topi/arm_cpu/conv2d.py +++ b/python/tvm/topi/arm_cpu/conv2d.py @@ -729,7 +729,7 @@ def schedule_conv2d_NHWC_hybrid_TIR(sch: tvm.tir.Schedule): # pylint: disable=import-outside-toplevel from tvm.topi.arm_cpu.pstate_attributes import SMEAttributes from tvm.tir.tensor_intrin.arm_cpu import ( - ARM_SME_2SVLx2SVL_TRANSPOSE_INTERLEAVE, + ARM_SME_2SVLx2SVL_FP32_TRANSPOSE_INTERLEAVE, ARM_SME_2SVLx2SVL_GEMM_INTERLEAVED_MOPA, ARM_SME_INIT, get_sme_gemm_interleaved_mopa_2svlx2svl_intrin, @@ -743,7 +743,7 @@ def schedule_conv2d_NHWC_hybrid_TIR(sch: tvm.tir.Schedule): ko, ki = sch.split(k, factors=(None, tile_K), disable_predication=True) sch.parallel(b) sch.reorder(b, ko, mo, ki, mi) - sch.tensorize(ki, ARM_SME_2SVLx2SVL_TRANSPOSE_INTERLEAVE) + sch.tensorize(ki, ARM_SME_2SVLx2SVL_FP32_TRANSPOSE_INTERLEAVE) # Split and reorder the loops of the GeMM for tensorization b, m, n, k = sch.get_loops(gemm_block) @@ -760,7 +760,7 @@ def schedule_conv2d_NHWC_hybrid_TIR(sch: tvm.tir.Schedule): sme_gemm_interleaved_intrin_name = ARM_SME_2SVLx2SVL_GEMM_INTERLEAVED_MOPA + f"_{K_padded}" tvm.tir.TensorIntrin.register( sme_gemm_interleaved_intrin_name, - *get_sme_gemm_interleaved_mopa_2svlx2svl_intrin(K_padded), + *get_sme_gemm_interleaved_mopa_2svlx2svl_intrin(K_padded, dtype), override=True, ) sch.tensorize(mi, sme_gemm_interleaved_intrin_name)