llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT--> @llvm/pr-subscribers-clang Author: None (SpencerAbson) <details> <summary>Changes</summary> This patch implements the following intrinsics: 8-bit floating-point sum of outer products and accumulate. ``` c // Only if __ARM_FEATURE_SME_F8F16 != 0 void svmopa_za16[_mf8]_m_fpm(uint64_t tile, svbool_t pn, svbool_t pm, svmfloat8_t zn, svmfloat8_t zm, fpm_t fpm) __arm_streaming __arm_inout("za"); // Only if __ARM_FEATURE_SME_F8F32 != 0 void svmopa_za32[_mf8]_m_fpm(uint64_t tile, svbool_t pn, svbool_t pm, svmfloat8_t zn, svmfloat8_t zm, fpm_t fpm) __arm_streaming __arm_inout("za"); ``` In accordance with: https://github.com/ARM-software/acle/pull/323/ Co-authored-by: Momchil Velikov <momchil.velikov@arm.com> Co-authored-by: Marian Lukac <marian.lukac@arm.com> --- Patch is 20.77 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/118115.diff 10 Files Affected: - (modified) clang/include/clang/Basic/arm_sme.td (+10) - (modified) clang/lib/CodeGen/CGBuiltin.cpp (+6) - (added) clang/test/CodeGen/AArch64/fp8-intrinsics/acle_sme2_fp8_fmopa.c (+55) - (added) clang/test/Sema/aarch64-fp8-intrinsics/acle_sme2_fp8_imm.c (+18) - (added) clang/test/Sema/aarch64-fp8-intrinsics/acle_sme2_fp8_mopa.c (+13) - (modified) clang/utils/TableGen/SveEmitter.cpp (+14-1) - (modified) llvm/include/llvm/IR/IntrinsicsAArch64.td (+11) - (modified) llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td (+7-10) - (modified) llvm/lib/Target/AArch64/SMEInstrFormats.td (+24-2) - (added) llvm/test/CodeGen/AArch64/sme2-fp8-intrinsics-fmopa.ll (+22) ``````````diff diff --git a/clang/include/clang/Basic/arm_sme.td b/clang/include/clang/Basic/arm_sme.td index 0f689e82bdb742..71b2c7cdd04f93 100644 --- a/clang/include/clang/Basic/arm_sme.td +++ b/clang/include/clang/Basic/arm_sme.td @@ -824,4 +824,14 @@ let SMETargetGuard = "sme-lutv2" in { def SVLUTI4_ZT_X4 : SInst<"svluti4_zt_{d}_x4", "4i2.u", "cUc", MergeNone, "aarch64_sme_luti4_zt_x4", 
[IsStreaming, IsInZT0], [ImmCheck<0, ImmCheck0_0>]>; } +let SMETargetGuard = "sme-f8f32" in { + def SVMOPA_FP8_ZA32 : Inst<"svmopa_za32[_mf8]_m_fpm", "viPPdd>", "m", MergeNone, "aarch64_sme_fp8_fmopa_za32", + [IsStreaming, IsInOutZA, SetsFPMR, IsOverloadNone], [ImmCheck<0, ImmCheck0_3>]>; +} + +let SMETargetGuard = "sme-f8f16" in { + def SVMOPA_FP8_ZA16 : Inst<"svmopa_za16[_mf8]_m_fpm", "viPPdd>", "m", MergeNone, "aarch64_sme_fp8_fmopa_za16", + [IsStreaming, IsInOutZA, SetsFPMR, IsOverloadNone], [ImmCheck<0, ImmCheck0_1>]>; +} + } // let SVETargetGuard = InvalidMode diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp index cb9c23b8e0a0d0..56595bb4704e74 100644 --- a/clang/lib/CodeGen/CGBuiltin.cpp +++ b/clang/lib/CodeGen/CGBuiltin.cpp @@ -10183,6 +10183,8 @@ CodeGenFunction::getSVEType(const SVETypeFlags &TypeFlags) { case SVETypeFlags::EltTyInt64: return llvm::ScalableVectorType::get(Builder.getInt64Ty(), 2); + case SVETypeFlags::EltTyMFloat8: + return llvm::ScalableVectorType::get(Builder.getInt8Ty(), 16); case SVETypeFlags::EltTyFloat16: return llvm::ScalableVectorType::get(Builder.getHalfTy(), 8); case SVETypeFlags::EltTyBFloat16: @@ -11234,6 +11236,10 @@ Value *CodeGenFunction::EmitAArch64SMEBuiltinExpr(unsigned BuiltinID, BuiltinID == SME::BI__builtin_sme_svstr_za) return EmitSMELdrStr(TypeFlags, Ops, Builtin->LLVMIntrinsic); + // Emit set FPMR for intrinsics that require it + if (TypeFlags.setsFPMR()) + Builder.CreateCall(CGM.getIntrinsic(Intrinsic::aarch64_set_fpmr), + Ops.pop_back_val()); // Handle builtins which require their multi-vector operands to be swapped swapCommutativeSMEOperands(BuiltinID, Ops); diff --git a/clang/test/CodeGen/AArch64/fp8-intrinsics/acle_sme2_fp8_fmopa.c b/clang/test/CodeGen/AArch64/fp8-intrinsics/acle_sme2_fp8_fmopa.c new file mode 100644 index 00000000000000..95d6383ab30efe --- /dev/null +++ b/clang/test/CodeGen/AArch64/fp8-intrinsics/acle_sme2_fp8_fmopa.c @@ -0,0 +1,55 @@ +// NOTE: Assertions have 
been autogenerated by utils/update_cc_test_checks.py UTC_ARGS: --version 5 +// REQUIRES: aarch64-registered-target + +// RUN: %clang_cc1 -triple aarch64-none-linux-gnu -target-feature +sme -target-feature +sme-f8f16 -target-feature +sme-f8f32 -disable-O0-optnone -Werror -Wall -emit-llvm -o - %s | opt -S -passes=mem2reg,tailcallelim | FileCheck %s +// RUN: %clang_cc1 -triple aarch64-none-linux-gnu -target-feature +sme -target-feature +sme-f8f16 -target-feature +sme-f8f32 -disable-O0-optnone -Werror -Wall -emit-llvm -o - -x c++ %s | opt -S -passes=mem2reg,tailcallelim | FileCheck %s -check-prefix=CPP-CHECK +// RUN: %clang_cc1 -DSVE_OVERLOADED_FORMS -triple aarch64-none-linux-gnu -target-feature +sme -target-feature +sme-f8f16 -target-feature +sme-f8f32 -disable-O0-optnone -Werror -Wall -emit-llvm -o - %s | opt -S -passes=mem2reg,tailcallelim | FileCheck %s +// RUN: %clang_cc1 -DSVE_OVERLOADED_FORMS -triple aarch64-none-linux-gnu -target-feature +sme -target-feature +sme-f8f16 -target-feature +sme-f8f32 -disable-O0-optnone -Werror -Wall -emit-llvm -o - -x c++ %s | opt -S -passes=mem2reg,tailcallelim | FileCheck %s -check-prefix=CPP-CHECK +// RUN: %clang_cc1 -triple aarch64-none-linux-gnu -target-feature +sme -target-feature +sme-f8f16 -target-feature +sme-f8f32 -S -disable-O0-optnone -Werror -Wall -o /dev/null %s + +#include <arm_sme.h> + +#ifdef SVE_OVERLOADED_FORMS +#define SVE_ACLE_FUNC(A1,A2_UNUSED,A3) A1##A3 +#else +#define SVE_ACLE_FUNC(A1,A2,A3) A1##A2##A3 +#endif + + +// CHECK-LABEL: define dso_local void @test_svmopa_za16_mf8_m( +// CHECK-SAME: <vscale x 16 x i1> [[PN:%.*]], <vscale x 16 x i1> [[PM:%.*]], <vscale x 16 x i8> [[ZN:%.*]], <vscale x 16 x i8> [[ZM:%.*]], i64 noundef [[FPMR:%.*]]) #[[ATTR0:[0-9]+]] { +// CHECK-NEXT: [[ENTRY:.*:]] +// CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPMR]]) +// CHECK-NEXT: tail call void @llvm.aarch64.sme.fp8.fmopa.za16(i32 1, <vscale x 16 x i1> [[PN]], <vscale x 16 x i1> [[PM]], <vscale x 16 x i8> [[ZN]], 
<vscale x 16 x i8> [[ZM]]) +// CHECK-NEXT: ret void +// +// CPP-CHECK-LABEL: define dso_local void @_Z22test_svmopa_za16_mf8_mu10__SVBool_tS_u13__SVMfloat8_tS0_m( +// CPP-CHECK-SAME: <vscale x 16 x i1> [[PN:%.*]], <vscale x 16 x i1> [[PM:%.*]], <vscale x 16 x i8> [[ZN:%.*]], <vscale x 16 x i8> [[ZM:%.*]], i64 noundef [[FPMR:%.*]]) #[[ATTR0:[0-9]+]] { +// CPP-CHECK-NEXT: [[ENTRY:.*:]] +// CPP-CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPMR]]) +// CPP-CHECK-NEXT: tail call void @llvm.aarch64.sme.fp8.fmopa.za16(i32 1, <vscale x 16 x i1> [[PN]], <vscale x 16 x i1> [[PM]], <vscale x 16 x i8> [[ZN]], <vscale x 16 x i8> [[ZM]]) +// CPP-CHECK-NEXT: ret void +// +void test_svmopa_za16_mf8_m(svbool_t pn, svbool_t pm, svmfloat8_t zn, + svmfloat8_t zm, fpm_t fpmr) __arm_streaming __arm_inout("za") { + SVE_ACLE_FUNC(svmopa_za16,_mf8,_m_fpm)(1, pn, pm, zn, zm, fpmr); +} + +// CHECK-LABEL: define dso_local void @test_svmopa_za32_mf8_m( +// CHECK-SAME: <vscale x 16 x i1> [[PN:%.*]], <vscale x 16 x i1> [[PM:%.*]], <vscale x 16 x i8> [[ZN:%.*]], <vscale x 16 x i8> [[ZM:%.*]], i64 noundef [[FPMR:%.*]]) #[[ATTR0]] { +// CHECK-NEXT: [[ENTRY:.*:]] +// CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPMR]]) +// CHECK-NEXT: tail call void @llvm.aarch64.sme.fp8.fmopa.za32(i32 3, <vscale x 16 x i1> [[PN]], <vscale x 16 x i1> [[PM]], <vscale x 16 x i8> [[ZN]], <vscale x 16 x i8> [[ZM]]) +// CHECK-NEXT: ret void +// +// CPP-CHECK-LABEL: define dso_local void @_Z22test_svmopa_za32_mf8_mu10__SVBool_tS_u13__SVMfloat8_tS0_m( +// CPP-CHECK-SAME: <vscale x 16 x i1> [[PN:%.*]], <vscale x 16 x i1> [[PM:%.*]], <vscale x 16 x i8> [[ZN:%.*]], <vscale x 16 x i8> [[ZM:%.*]], i64 noundef [[FPMR:%.*]]) #[[ATTR0]] { +// CPP-CHECK-NEXT: [[ENTRY:.*:]] +// CPP-CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPMR]]) +// CPP-CHECK-NEXT: tail call void @llvm.aarch64.sme.fp8.fmopa.za32(i32 3, <vscale x 16 x i1> [[PN]], <vscale x 16 x i1> [[PM]], <vscale x 16 x i8> [[ZN]], <vscale x 16 x 
i8> [[ZM]]) +// CPP-CHECK-NEXT: ret void +// +void test_svmopa_za32_mf8_m(svbool_t pn, svbool_t pm, svmfloat8_t zn, + svmfloat8_t zm, fpm_t fpmr) __arm_streaming __arm_inout("za") { + SVE_ACLE_FUNC(svmopa_za32,_mf8,_m_fpm)(3, pn, pm, zn, zm, fpmr); +} diff --git a/clang/test/Sema/aarch64-fp8-intrinsics/acle_sme2_fp8_imm.c b/clang/test/Sema/aarch64-fp8-intrinsics/acle_sme2_fp8_imm.c new file mode 100644 index 00000000000000..62cad9cfa4c8fd --- /dev/null +++ b/clang/test/Sema/aarch64-fp8-intrinsics/acle_sme2_fp8_imm.c @@ -0,0 +1,18 @@ +// RUN: %clang_cc1 -triple aarch64 -target-feature +sme -target-feature +sme2 -target-feature +sme-f8f16 -target-feature +sme-f8f32 -fsyntax-only -verify %s + +// REQUIRES: aarch64-registered-target + +#include <arm_sme.h> + +void test_svmopa(svbool_t pn, svbool_t pm, svmfloat8_t zn, svmfloat8_t zm, + fpm_t fpmr) __arm_streaming __arm_inout("za") { + // expected-error@+1 {{argument value 18446744073709551615 is outside the valid range [0, 1]}} + svmopa_za16_mf8_m_fpm(-1, pn, pm, zn, zm, fpmr); + // expected-error@+1 {{argument value 2 is outside the valid range [0, 1]}} + svmopa_za16_mf8_m_fpm(2, pn, pm, zn, zm, fpmr); + + // expected-error@+1 {{argument value 18446744073709551615 is outside the valid range [0, 3]}} + svmopa_za32_mf8_m_fpm(-1, pn, pm, zn, zm, fpmr); + // expected-error@+1 {{argument value 4 is outside the valid range [0, 3]}} + svmopa_za32_mf8_m_fpm(4, pn, pm, zn, zm, fpmr); +} diff --git a/clang/test/Sema/aarch64-fp8-intrinsics/acle_sme2_fp8_mopa.c b/clang/test/Sema/aarch64-fp8-intrinsics/acle_sme2_fp8_mopa.c new file mode 100644 index 00000000000000..86426abcd43291 --- /dev/null +++ b/clang/test/Sema/aarch64-fp8-intrinsics/acle_sme2_fp8_mopa.c @@ -0,0 +1,13 @@ +// RUN: %clang_cc1 -triple aarch64 -target-feature +sme -verify -emit-llvm-only %s + +// REQUIRES: aarch64-registered-target + +#include <arm_sme.h> + +void test_features(svbool_t pn, svbool_t pm, svmfloat8_t zn, svmfloat8_t zm, + fpm_t fpmr) __arm_streaming 
__arm_inout("za") { + // expected-error@+1 {{'svmopa_za16_mf8_m_fpm' needs target feature sme,sme-f8f16}} + svmopa_za16_mf8_m_fpm(0, pn, pm, zn, zm, fpmr); + // expected-error@+1 {{'svmopa_za32_mf8_m_fpm' needs target feature sme,sme-f8f32}} + svmopa_za32_mf8_m_fpm(0, pn, pm, zn, zm, fpmr); +} diff --git a/clang/utils/TableGen/SveEmitter.cpp b/clang/utils/TableGen/SveEmitter.cpp index e9fa01ea98dced..e24e93e8f29d8f 100644 --- a/clang/utils/TableGen/SveEmitter.cpp +++ b/clang/utils/TableGen/SveEmitter.cpp @@ -587,7 +587,6 @@ void SVEType::applyTypespec(StringRef TS) { ElementBitwidth = 16; break; case 'm': - Signed = false; MFloat = true; Float = false; BFloat = false; @@ -702,6 +701,7 @@ void SVEType::applyModifier(char Mod) { Svcount = false; Float = false; BFloat = false; + MFloat = false; ElementBitwidth = Bitwidth = 64; NumVectors = 0; Signed = false; @@ -712,6 +712,7 @@ void SVEType::applyModifier(char Mod) { Svcount = false; Float = false; BFloat = false; + MFloat = false; ElementBitwidth = Bitwidth = 32; NumVectors = 0; Signed = true; @@ -723,6 +724,7 @@ void SVEType::applyModifier(char Mod) { Svcount = false; Float = false; BFloat = false; + MFloat = false; ElementBitwidth = Bitwidth = 32; NumVectors = 0; Signed = true; @@ -735,6 +737,7 @@ void SVEType::applyModifier(char Mod) { Signed = true; Float = false; BFloat = false; + MFloat = false; ElementBitwidth = Bitwidth = 32; NumVectors = 0; break; @@ -744,6 +747,7 @@ void SVEType::applyModifier(char Mod) { Signed = true; Float = false; BFloat = false; + MFloat = false; ElementBitwidth = Bitwidth = 64; NumVectors = 0; break; @@ -753,6 +757,7 @@ void SVEType::applyModifier(char Mod) { Signed = false; Float = false; BFloat = false; + MFloat = false; ElementBitwidth = Bitwidth = 32; NumVectors = 0; break; @@ -765,6 +770,7 @@ void SVEType::applyModifier(char Mod) { Signed = false; Float = false; BFloat = false; + MFloat = false; ElementBitwidth = Bitwidth = 64; NumVectors = 0; break; @@ -783,6 +789,7 @@ void 
SVEType::applyModifier(char Mod) { case 'g': Signed = false; Float = false; + MFloat = false; BFloat = false; ElementBitwidth = 64; break; @@ -790,18 +797,21 @@ void SVEType::applyModifier(char Mod) { Signed = false; Float = false; BFloat = false; + MFloat = false; ElementBitwidth = 8; break; case 't': Signed = true; Float = false; BFloat = false; + MFloat = false; ElementBitwidth = 32; break; case 'z': Signed = false; Float = false; BFloat = false; + MFloat = false; ElementBitwidth = 32; break; case 'O': @@ -815,6 +825,7 @@ void SVEType::applyModifier(char Mod) { Svcount = false; Float = true; BFloat = false; + MFloat = false; ElementBitwidth = 32; break; case 'N': @@ -922,6 +933,7 @@ void SVEType::applyModifier(char Mod) { Predicate = false; Svcount = false; Float = false; + MFloat = false; BFloat = true; ElementBitwidth = 16; break; @@ -932,6 +944,7 @@ void SVEType::applyModifier(char Mod) { NumVectors = 0; Float = false; BFloat = false; + MFloat = false; break; case '~': Float = false; diff --git a/llvm/include/llvm/IR/IntrinsicsAArch64.td b/llvm/include/llvm/IR/IntrinsicsAArch64.td index a91616b9556828..0fde957ecbba6e 100644 --- a/llvm/include/llvm/IR/IntrinsicsAArch64.td +++ b/llvm/include/llvm/IR/IntrinsicsAArch64.td @@ -2983,6 +2983,13 @@ let TargetPrefix = "aarch64" in { LLVMMatchType<0>, llvm_anyvector_ty], [ImmArg<ArgIndex<0>>]>; + class SME_FP8_OuterProduct_Intrinsic + : DefaultAttrsIntrinsic<[], + [llvm_i32_ty, + llvm_nxv16i1_ty, llvm_nxv16i1_ty, + llvm_nxv16i8_ty, llvm_nxv16i8_ty], + [ImmArg<ArgIndex<0>>, IntrInaccessibleMemOnly, IntrHasSideEffects]>; + def int_aarch64_sme_mopa : SME_OuterProduct_Intrinsic; def int_aarch64_sme_mops : SME_OuterProduct_Intrinsic; @@ -2998,6 +3005,10 @@ let TargetPrefix = "aarch64" in { def int_aarch64_sme_usmopa_wide : SME_OuterProduct_Intrinsic; def int_aarch64_sme_usmops_wide : SME_OuterProduct_Intrinsic; + // FP8 outer product + def int_aarch64_sme_fp8_fmopa_za16 : SME_FP8_OuterProduct_Intrinsic; + def 
int_aarch64_sme_fp8_fmopa_za32 : SME_FP8_OuterProduct_Intrinsic; + class SME_AddVectorToTile_Intrinsic : DefaultAttrsIntrinsic<[], [llvm_i32_ty, diff --git a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td index 37ac915d1d8808..9c657787d3492b 100644 --- a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td +++ b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td @@ -990,7 +990,7 @@ defm FDOT_VG2_M2ZZI_BtoH : sme2p1_multi_vec_array_vg2_index_f8f16<"fdot", 0b11 defm FDOT_VG4_M4ZZI_BtoH : sme2p1_multi_vec_array_vg4_index_f8f16<"fdot", 0b100, ZZZZ_b_mul_r, ZPR4b8>; defm FDOT_VG2_M2ZZ_BtoH : sme2_dot_mla_add_sub_array_vg24_single<"fdot", 0b0010001, MatrixOp16, ZZ_b, ZPR4b8>; defm FDOT_VG4_M4ZZ_BtoH : sme2_dot_mla_add_sub_array_vg24_single<"fdot", 0b0110001, MatrixOp16, ZZZZ_b, ZPR4b8>; -// TODO: Replace nxv16i8 by nxv16f8 + defm FDOT_VG2_M2Z2Z_BtoH : sme2_dot_mla_add_sub_array_vg2_multi<"fdot", 0b0100100, MatrixOp16, ZZ_b_mul_r, nxv16i8, null_frag>; defm FDOT_VG4_M4Z4Z_BtoH : sme2_dot_mla_add_sub_array_vg4_multi<"fdot", 0b0100100, MatrixOp16, ZZZZ_b_mul_r, nxv16i8, null_frag>; @@ -998,23 +998,22 @@ def FMLAL_MZZI_BtoH : sme2_mla_ll_array_index_16b<"fmlal", 0b11, 0b00>; defm FMLAL_VG2_M2ZZI_BtoH : sme2_multi_vec_array_vg2_index_16b<"fmlal", 0b10, 0b111>; defm FMLAL_VG4_M4ZZI_BtoH : sme2_multi_vec_array_vg4_index_16b<"fmlal", 0b10, 0b110>; def FMLAL_VG2_MZZ_BtoH : sme2_mla_long_array_single_16b<"fmlal">; -// TODO: Replace nxv16i8 by nxv16f8 + defm FMLAL_VG2_M2ZZ_BtoH : sme2_fp_mla_long_array_vg2_single<"fmlal", 0b001, MatrixOp16, ZZ_b, ZPR4b8, nxv16i8, null_frag>; defm FMLAL_VG4_M4ZZ_BtoH : sme2_fp_mla_long_array_vg4_single<"fmlal", 0b001, MatrixOp16, ZZZZ_b, ZPR4b8, nxv16i8, null_frag>; defm FMLAL_VG2_M2Z2Z_BtoH : sme2_fp_mla_long_array_vg2_multi<"fmlal", 0b100, MatrixOp16, ZZ_b_mul_r, nxv16i8, null_frag>; defm FMLAL_VG4_M4Z4Z_BtoH : sme2_fp_mla_long_array_vg4_multi<"fmlal", 0b100, MatrixOp16, ZZZZ_b_mul_r, nxv16i8, null_frag>; 
-defm FMOPA_MPPZZ_BtoH : sme2p1_fmop_tile_f8f16<"fmopa", 0b1, 0b0, 0b01>; - +defm FMOPA_MPPZZ_BtoH : sme2_fp8_fmopa_za16<"fmopa", int_aarch64_sme_fp8_fmopa_za16>; } //[HasSMEF8F16] let Predicates = [HasSMEF8F32] in { -// TODO : Replace nxv16i8 by nxv16f8 + defm FDOT_VG2_M2ZZI_BtoS : sme2_multi_vec_array_vg2_index_32b<"fdot", 0b01, 0b0111, ZZ_b_mul_r, ZPR4b8, nxv16i8, null_frag>; defm FDOT_VG4_M4ZZI_BtoS : sme2_multi_vec_array_vg4_index_32b<"fdot", 0b0001, ZZZZ_b_mul_r, ZPR4b8, nxv16i8, null_frag>; defm FDOT_VG2_M2ZZ_BtoS : sme2_dot_mla_add_sub_array_vg24_single<"fdot", 0b0010011, MatrixOp32, ZZ_b, ZPR4b8>; defm FDOT_VG4_M4ZZ_BtoS : sme2_dot_mla_add_sub_array_vg24_single<"fdot", 0b0110011, MatrixOp32, ZZZZ_b, ZPR4b8>; -// TODO : Replace nxv16i8 by nxv16f8 + defm FDOT_VG2_M2Z2Z_BtoS : sme2_dot_mla_add_sub_array_vg2_multi<"fdot", 0b0100110, MatrixOp32, ZZ_b_mul_r, nxv16i8, null_frag>; defm FDOT_VG4_M4Z4Z_BtoS : sme2_dot_mla_add_sub_array_vg4_multi<"fdot", 0b0100110, MatrixOp32, ZZZZ_b_mul_r, nxv16i8, null_frag>; @@ -1024,16 +1023,14 @@ def FVDOTT_VG4_M2ZZI_BtoS : sme2_fp8_multi_vec_array_vg4_index<"fvdott", 0b1>; defm FMLALL_MZZI_BtoS : sme2_mla_ll_array_index_32b<"fmlall", 0b01, 0b000, null_frag>; defm FMLALL_VG2_M2ZZI_BtoS : sme2_mla_ll_array_vg2_index_32b<"fmlall", 0b10, 0b100, null_frag>; defm FMLALL_VG4_M4ZZI_BtoS : sme2_mla_ll_array_vg4_index_32b<"fmlall", 0b00, 0b1000, null_frag>; -// TODO: Replace nxv16i8 by nxv16f8 + defm FMLALL_MZZ_BtoS : sme2_mla_ll_array_single<"fmlall", 0b01000, MatrixOp32, ZPR8, ZPR4b8, nxv16i8, null_frag>; defm FMLALL_VG2_M2ZZ_BtoS : sme2_mla_ll_array_vg24_single<"fmlall", 0b000001, MatrixOp32, ZZ_b, ZPR4b8>; defm FMLALL_VG4_M4ZZ_BtoS : sme2_mla_ll_array_vg24_single<"fmlall", 0b010001, MatrixOp32, ZZZZ_b, ZPR4b8>; defm FMLALL_VG2_M2Z2Z_BtoS : sme2_mla_ll_array_vg2_multi<"fmlall", 0b01000, MatrixOp32, ZZ_b_mul_r, nxv16i8, null_frag>; defm FMLALL_VG4_M4Z4Z_BtoS : sme2_mla_ll_array_vg4_multi<"fmlall", 0b01000, MatrixOp32, ZZZZ_b_mul_r, 
nxv16i8, null_frag>; - -defm FMOPA_MPPZZ_BtoS : sme_outer_product_fp32<0b0, 0b01, ZPR8, "fmopa", null_frag>; - +defm FMOPA_MPPZZ_BtoS : sme2_fp8_fmopa_za32<"fmopa", int_aarch64_sme_fp8_fmopa_za32>; } //[HasSMEF8F32] let Predicates = [HasSME2, HasSVEBFSCALE] in { diff --git a/llvm/lib/Target/AArch64/SMEInstrFormats.td b/llvm/lib/Target/AArch64/SMEInstrFormats.td index 776472e72af05a..e6535f957e2024 100644 --- a/llvm/lib/Target/AArch64/SMEInstrFormats.td +++ b/llvm/lib/Target/AArch64/SMEInstrFormats.td @@ -305,6 +305,21 @@ multiclass sme_outer_product_fp32<bit S, bits<2> sz, ZPRRegOp zpr_ty, string mne def : SME_ZA_Tile_TwoPred_TwoVec_Pat<NAME, op, timm32_0_3, nxv4i1, nxv4f32>; } +multiclass sme2_fp8_fmopa_za32<string mnemonic, SDPatternOperator intrinsic> { + def NAME : sme_fp_outer_product_inst<0, 0b01, 0b00, TileOp32, ZPR8, mnemonic>, SMEPseudo2Instr<NAME, 1> { + bits<2> ZAda; + let Inst{1-0} = ZAda; + let Inst{2} = 0b0; + + let Uses = [FPMR, FPCR]; + } + + let mayStore = 1, mayLoad = 1 in + def NAME # _PSEUDO : sme_outer_product_pseudo<ZPR8, SMEMatrixTileS>, SMEPseudo2Instr<NAME, 0>; + + def : SME_ZA_Tile_TwoPred_TwoVec_Pat<NAME, intrinsic, timm32_0_3, nxv16i1, nxv16i8>; +} + multiclass sme_outer_product_fp64<bit S, string mnemonic, SDPatternOperator op> { def NAME : sme_fp_outer_product_inst<S, 0b10, 0b00, TileOp64, ZPR64, mnemonic>, SMEPseudo2Instr<NAME, 1> { bits<3> ZAda; @@ -316,12 +331,19 @@ multiclass sme_outer_product_fp64<bit S, string mnemonic, SDPatternOperator op> def : SME_ZA_Tile_TwoPred_TwoVec_Pat<NAME, op, timm32_0_7, nxv2i1, nxv2f64>; } -multiclass sme2p1_fmop_tile_f8f16<string mnemonic, bit bf, bit s, bits<2> op> { - def NAME : sme_fp_outer_product_inst<s, {0,bf}, op, TileOp16, ZPR8, mnemonic> { +multiclass sme2_fp8_fmopa_za16<string mnemonic, SDPatternOperator intrinsic> { + def NAME : sme_fp_outer_product_inst<0, {0, 0b1}, 0b01, TileOp16, ZPR8, mnemonic>, SMEPseudo2Instr<NAME, 1> { bits<1> ZAda; let Inst{2-1} = 0b00; let Inst{0} = ZAda; + + let 
Uses = [FPMR, FPCR]; } + + let mayStore = 1, mayLoad = 1 in + def NAME # _PSEUDO : sme_outer_product_pseudo<ZPR8, SMEMatrixTileH>, SMEPseudo2Instr<NAME, 0>; + + def : SME_ZA_Tile_TwoPred_TwoVec_Pat<NAME, intrinsic, timm32_0_1, nxv16i1, nxv16i8>; } multiclass sme2p1_fmop_tile_fp16<string mnemonic, bit bf, bit s, ValueType vt, SDPatternOperator intrinsic = null_frag> { diff --git a/llvm/test/CodeGen/AArch64/sme2-fp8-intrinsics-fmopa.ll b/llvm/test/CodeGen/AArch64/sme2-fp8-intrinsics-fmopa.ll new file mode 100644 index 00000000000000..6e88cdf4e7fec3 --- /dev/null +++ b/llvm/test/CodeGen/AArch64/sme2-fp8-intrinsics-fmopa.ll @@ -0,0 +1,22 @@ +; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4 +; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sme-f8f16,+sme-f8f32 -force-streaming < %s | FileCheck %s + +define void @test_fmopa_16(<vscale x 16 x i1> %pn, <vscale x 16 x i1> %pm, <vscale x 16 x i8> %vn, <vscale x 16 x i8> %vm) { +; CHECK-LABEL: test_fmopa_16: +; CHEC... [truncated] `````````` </details> https://github.com/llvm/llvm-project/pull/118115 _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits