kmclaughlin created this revision. kmclaughlin added reviewers: huntergr, rovka, greened. Herald added subscribers: psnobl, rkruppe, hiraditya, kristof.beyls, tschuett. Herald added a project: LLVM. kmclaughlin added a parent revision: D47775: [AArch64][SVE] Add SPLAT_VECTOR ISD Node.
Adds support for codegen of masked loads, with non-extending, zero-extending and sign-extending variants. Depends on the changes in D47775 <https://reviews.llvm.org/D47775> for isConstantSplatVectorMaskForType. Repository: rG LLVM Github Monorepo https://reviews.llvm.org/D68877 Files: llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp llvm/lib/CodeGen/TargetLoweringBase.cpp llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp llvm/lib/Target/AArch64/AArch64InstrInfo.td llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h llvm/lib/Target/AArch64/SVEInstrFormats.td llvm/test/CodeGen/AArch64/sve-masked-ldst-nonext.ll llvm/test/CodeGen/AArch64/sve-masked-ldst-sext.ll llvm/test/CodeGen/AArch64/sve-masked-ldst-zext.ll
Index: llvm/test/CodeGen/AArch64/sve-masked-ldst-zext.ll =================================================================== --- /dev/null +++ llvm/test/CodeGen/AArch64/sve-masked-ldst-zext.ll @@ -0,0 +1,72 @@ +; RUN: llc -mtriple=aarch64--linux-gnu -mattr=+sve < %s | FileCheck %s + +; +; Masked Loads +; + +define <vscale x 2 x i64> @masked_zload_nxv2i8(<vscale x 2 x i8>* %src, <vscale x 2 x i1> %mask) { +; CHECK-LABEL: masked_zload_nxv2i8: +; CHECK-NOT: ld1sb +; CHECK: ld1b { [[IN:z[0-9]+]].d }, [[PG:p[0-9]+]]/z, [x0] +; CHECK-NEXT: ret + %load = call <vscale x 2 x i8> @llvm.masked.load.nxv2i8(<vscale x 2 x i8>* %src, i32 1, <vscale x 2 x i1> %mask, <vscale x 2 x i8> undef) + %ext = zext <vscale x 2 x i8> %load to <vscale x 2 x i64> + ret <vscale x 2 x i64> %ext +} + +define <vscale x 2 x i64> @masked_zload_nxv2i16(<vscale x 2 x i16>* %src, <vscale x 2 x i1> %mask) { +; CHECK-LABEL: masked_zload_nxv2i16: +; CHECK-NOT: ld1sh +; CHECK: ld1h { [[IN:z[0-9]+]].d }, [[PG:p[0-9]+]]/z, [x0] +; CHECK-NEXT: ret + %load = call <vscale x 2 x i16> @llvm.masked.load.nxv2i16(<vscale x 2 x i16>* %src, i32 1, <vscale x 2 x i1> %mask, <vscale x 2 x i16> undef) + %ext = zext <vscale x 2 x i16> %load to <vscale x 2 x i64> + ret <vscale x 2 x i64> %ext +} + +define <vscale x 2 x i64> @masked_zload_nxv2i32(<vscale x 2 x i32>* %src, <vscale x 2 x i1> %mask) { +; CHECK-LABEL: masked_zload_nxv2i32: +; CHECK-NOT: ld1sw +; CHECK: ld1w { [[IN:z[0-9]+]].d }, [[PG:p[0-9]+]]/z, [x0] +; CHECK-NEXT: ret + %load = call <vscale x 2 x i32> @llvm.masked.load.nxv2i32(<vscale x 2 x i32>* %src, i32 1, <vscale x 2 x i1> %mask, <vscale x 2 x i32> undef) + %ext = zext <vscale x 2 x i32> %load to <vscale x 2 x i64> + ret <vscale x 2 x i64> %ext +} + +define <vscale x 4 x i32> @masked_zload_nxv4i8(<vscale x 4 x i8>* %src, <vscale x 4 x i1> %mask) { +; CHECK-LABEL: masked_zload_nxv4i8: +; CHECK-NOT: ld1sb +; CHECK: ld1b { [[IN:z[0-9]+]].s }, [[PG:p[0-9]+]]/z, [x0] +; CHECK-NEXT: ret + %load = call <vscale x 4 
x i8> @llvm.masked.load.nxv4i8(<vscale x 4 x i8>* %src, i32 1, <vscale x 4 x i1> %mask, <vscale x 4 x i8> undef) + %ext = zext <vscale x 4 x i8> %load to <vscale x 4 x i32> + ret <vscale x 4 x i32> %ext +} + +define <vscale x 4 x i32> @masked_zload_nxv4i16(<vscale x 4 x i16>* %src, <vscale x 4 x i1> %mask) { +; CHECK-LABEL: masked_zload_nxv4i16: +; CHECK-NOT: ld1sh +; CHECK: ld1h { [[IN:z[0-9]+]].s }, [[PG:p[0-9]+]]/z, [x0] +; CHECK-NEXT: ret + %load = call <vscale x 4 x i16> @llvm.masked.load.nxv4i16(<vscale x 4 x i16>* %src, i32 1, <vscale x 4 x i1> %mask, <vscale x 4 x i16> undef) + %ext = zext <vscale x 4 x i16> %load to <vscale x 4 x i32> + ret <vscale x 4 x i32> %ext +} + +define <vscale x 8 x i16> @masked_zload_nxv8i8(<vscale x 8 x i8>* %src, <vscale x 8 x i1> %mask) { +; CHECK-LABEL: masked_zload_nxv8i8: +; CHECK-NOT: ld1sb +; CHECK: ld1b { [[IN:z[0-9]+]].h }, [[PG:p[0-9]+]]/z, [x0] +; CHECK-NEXT: ret + %load = call <vscale x 8 x i8> @llvm.masked.load.nxv8i8(<vscale x 8 x i8>* %src, i32 1, <vscale x 8 x i1> %mask, <vscale x 8 x i8> undef) + %ext = zext <vscale x 8 x i8> %load to <vscale x 8 x i16> + ret <vscale x 8 x i16> %ext +} + +declare <vscale x 2 x i8> @llvm.masked.load.nxv2i8(<vscale x 2 x i8>*, i32, <vscale x 2 x i1>, <vscale x 2 x i8>) +declare <vscale x 2 x i16> @llvm.masked.load.nxv2i16(<vscale x 2 x i16>*, i32, <vscale x 2 x i1>, <vscale x 2 x i16>) +declare <vscale x 2 x i32> @llvm.masked.load.nxv2i32(<vscale x 2 x i32>*, i32, <vscale x 2 x i1>, <vscale x 2 x i32>) +declare <vscale x 4 x i8> @llvm.masked.load.nxv4i8(<vscale x 4 x i8>*, i32, <vscale x 4 x i1>, <vscale x 4 x i8>) +declare <vscale x 4 x i16> @llvm.masked.load.nxv4i16(<vscale x 4 x i16>*, i32, <vscale x 4 x i1>, <vscale x 4 x i16>) +declare <vscale x 8 x i8> @llvm.masked.load.nxv8i8(<vscale x 8 x i8>*, i32, <vscale x 8 x i1>, <vscale x 8 x i8>) Index: llvm/test/CodeGen/AArch64/sve-masked-ldst-sext.ll =================================================================== --- /dev/null 
+++ llvm/test/CodeGen/AArch64/sve-masked-ldst-sext.ll @@ -0,0 +1,66 @@ +; RUN: llc -mtriple=aarch64--linux-gnu -mattr=+sve < %s | FileCheck %s + +; +; Masked Loads +; + +define <vscale x 2 x i64> @masked_sload_nxv2i8(<vscale x 2 x i8> *%a, <vscale x 2 x i1> %mask) { +; CHECK-LABEL: masked_sload_nxv2i8: +; CHECK: ld1sb { [[IN:z[0-9]+]].d }, [[PG:p[0-9]+]]/z, [x0] +; CHECK-NEXT: ret + %load = call <vscale x 2 x i8> @llvm.masked.load.nxv2i8(<vscale x 2 x i8> *%a, i32 1, <vscale x 2 x i1> %mask, <vscale x 2 x i8> undef) + %ext = sext <vscale x 2 x i8> %load to <vscale x 2 x i64> + ret <vscale x 2 x i64> %ext +} + +define <vscale x 2 x i64> @masked_sload_nxv2i16(<vscale x 2 x i16> *%a, <vscale x 2 x i1> %mask) { +; CHECK-LABEL: masked_sload_nxv2i16: +; CHECK: ld1sh { [[IN:z[0-9]+]].d }, [[PG:p[0-9]+]]/z, [x0] +; CHECK-NEXT: ret + %load = call <vscale x 2 x i16> @llvm.masked.load.nxv2i16(<vscale x 2 x i16> *%a, i32 1, <vscale x 2 x i1> %mask, <vscale x 2 x i16> undef) + %ext = sext <vscale x 2 x i16> %load to <vscale x 2 x i64> + ret <vscale x 2 x i64> %ext +} + +define <vscale x 2 x i64> @masked_sload_nxv2i32(<vscale x 2 x i32> *%a, <vscale x 2 x i1> %mask) { +; CHECK-LABEL: masked_sload_nxv2i32: +; CHECK: ld1sw { [[IN:z[0-9]+]].d }, [[PG:p[0-9]+]]/z, [x0] +; CHECK-NEXT: ret + %load = call <vscale x 2 x i32> @llvm.masked.load.nxv2i32(<vscale x 2 x i32> *%a, i32 1, <vscale x 2 x i1> %mask, <vscale x 2 x i32> undef) + %ext = sext <vscale x 2 x i32> %load to <vscale x 2 x i64> + ret <vscale x 2 x i64> %ext +} + +define <vscale x 4 x i32> @masked_sload_nxv4i8(<vscale x 4 x i8> *%a, <vscale x 4 x i1> %mask) { +; CHECK-LABEL: masked_sload_nxv4i8: +; CHECK: ld1sb { [[IN:z[0-9]+]].s }, [[PG:p[0-9]+]]/z, [x0] +; CHECK-NEXT: ret + %load = call <vscale x 4 x i8> @llvm.masked.load.nxv4i8(<vscale x 4 x i8> *%a, i32 1, <vscale x 4 x i1> %mask, <vscale x 4 x i8> undef) + %ext = sext <vscale x 4 x i8> %load to <vscale x 4 x i32> + ret <vscale x 4 x i32> %ext +} + +define <vscale x 4 x 
i32> @masked_sload_nxv4i16(<vscale x 4 x i16> *%a, <vscale x 4 x i1> %mask) { +; CHECK-LABEL: masked_sload_nxv4i16: +; CHECK: ld1sh { [[IN:z[0-9]+]].s }, [[PG:p[0-9]+]]/z, [x0] +; CHECK-NEXT: ret + %load = call <vscale x 4 x i16> @llvm.masked.load.nxv4i16(<vscale x 4 x i16> *%a, i32 1, <vscale x 4 x i1> %mask, <vscale x 4 x i16> undef) + %ext = sext <vscale x 4 x i16> %load to <vscale x 4 x i32> + ret <vscale x 4 x i32> %ext +} + +define <vscale x 8 x i16> @masked_sload_nxv8i8(<vscale x 8 x i8> *%a, <vscale x 8 x i1> %mask) { +; CHECK-LABEL: masked_sload_nxv8i8: +; CHECK: ld1sb { [[IN:z[0-9]+]].h }, [[PG:p[0-9]+]]/z, [x0] +; CHECK-NEXT: ret + %load = call <vscale x 8 x i8> @llvm.masked.load.nxv8i8(<vscale x 8 x i8> *%a, i32 1, <vscale x 8 x i1> %mask, <vscale x 8 x i8> undef) + %ext = sext <vscale x 8 x i8> %load to <vscale x 8 x i16> + ret <vscale x 8 x i16> %ext +} + +declare <vscale x 2 x i8> @llvm.masked.load.nxv2i8(<vscale x 2 x i8>*, i32, <vscale x 2 x i1>, <vscale x 2 x i8>) +declare <vscale x 2 x i16> @llvm.masked.load.nxv2i16(<vscale x 2 x i16>*, i32, <vscale x 2 x i1>, <vscale x 2 x i16>) +declare <vscale x 2 x i32> @llvm.masked.load.nxv2i32(<vscale x 2 x i32>*, i32, <vscale x 2 x i1>, <vscale x 2 x i32>) +declare <vscale x 4 x i8> @llvm.masked.load.nxv4i8(<vscale x 4 x i8>*, i32, <vscale x 4 x i1>, <vscale x 4 x i8>) +declare <vscale x 4 x i16> @llvm.masked.load.nxv4i16(<vscale x 4 x i16>*, i32, <vscale x 4 x i1>, <vscale x 4 x i16>) +declare <vscale x 8 x i8> @llvm.masked.load.nxv8i8(<vscale x 8 x i8>*, i32, <vscale x 8 x i1>, <vscale x 8 x i8>) Index: llvm/test/CodeGen/AArch64/sve-masked-ldst-nonext.ll =================================================================== --- /dev/null +++ llvm/test/CodeGen/AArch64/sve-masked-ldst-nonext.ll @@ -0,0 +1,87 @@ +; RUN: llc -mtriple=aarch64--linux-gnu -mattr=+sve < %s | FileCheck %s + +; +; Masked Loads +; + +define <vscale x 2 x i64> @masked_load_nxv2i64(<vscale x 2 x i64> *%a, <vscale x 2 x i1> %mask) { +; 
CHECK-LABEL: masked_load_nxv2i64: +; CHECK: ld1d { [[IN:z[0-9]+]].d }, [[PG:p[0-9]+]]/z, [x0] + %load = call <vscale x 2 x i64> @llvm.masked.load.nxv2i64(<vscale x 2 x i64> *%a, i32 8, <vscale x 2 x i1> %mask, <vscale x 2 x i64> undef) + ret <vscale x 2 x i64> %load +} + +define <vscale x 4 x i32> @masked_load_nxv4i32(<vscale x 4 x i32> *%a, <vscale x 4 x i1> %mask) { +; CHECK-LABEL: masked_load_nxv4i32: +; CHECK: ld1w { [[IN:z[0-9]+]].s }, [[PG:p[0-9]+]]/z, [x0] + %load = call <vscale x 4 x i32> @llvm.masked.load.nxv4i32(<vscale x 4 x i32> *%a, i32 4, <vscale x 4 x i1> %mask, <vscale x 4 x i32> undef) + ret <vscale x 4 x i32> %load +} + +define <vscale x 8 x i16> @masked_load_nxv8i16(<vscale x 8 x i16> *%a, <vscale x 8 x i1> %mask) { +; CHECK-LABEL: masked_load_nxv8i16: +; CHECK: ld1h { [[IN:z[0-9]+]].h }, [[PG:p[0-9]+]]/z, [x0] + %load = call <vscale x 8 x i16> @llvm.masked.load.nxv8i16(<vscale x 8 x i16> *%a, i32 2, <vscale x 8 x i1> %mask, <vscale x 8 x i16> undef) + ret <vscale x 8 x i16> %load +} + +define <vscale x 16 x i8> @masked_load_nxv16i8(<vscale x 16 x i8> *%a, <vscale x 16 x i1> %mask) { +; CHECK-LABEL: masked_load_nxv16i8: +; CHECK: ld1b { [[IN:z[0-9]+]].b }, [[PG:p[0-9]+]]/z, [x0] + %load = call <vscale x 16 x i8> @llvm.masked.load.nxv16i8(<vscale x 16 x i8> *%a, i32 1, <vscale x 16 x i1> %mask, <vscale x 16 x i8> undef) + ret <vscale x 16 x i8> %load +} + +define <vscale x 2 x double> @masked_load_nxv2f64(<vscale x 2 x double> *%a, <vscale x 2 x i1> %mask) { +; CHECK-LABEL: masked_load_nxv2f64: +; CHECK: ld1d { [[IN:z[0-9]+]].d }, [[PG:p[0-9]+]]/z, [x0] + %load = call <vscale x 2 x double> @llvm.masked.load.nxv2f64(<vscale x 2 x double> *%a, i32 8, <vscale x 2 x i1> %mask, <vscale x 2 x double> undef) + ret <vscale x 2 x double> %load +} + +define <vscale x 2 x float> @masked_load_nxv2f32(<vscale x 2 x float> *%a, <vscale x 2 x i1> %mask) { +; CHECK-LABEL: masked_load_nxv2f32: +; CHECK: ld1w { [[IN:z[0-9]+]].d }, [[PG:p[0-9]+]]/z, [x0] + %load = 
call <vscale x 2 x float> @llvm.masked.load.nxv2f32(<vscale x 2 x float> *%a, i32 4, <vscale x 2 x i1> %mask, <vscale x 2 x float> undef) + ret <vscale x 2 x float> %load +} + +define <vscale x 2 x half> @masked_load_nxv2f16(<vscale x 2 x half> *%a, <vscale x 2 x i1> %mask) { +; CHECK-LABEL: masked_load_nxv2f16: +; CHECK: ld1h { [[IN:z[0-9]+]].d }, [[PG:p[0-9]+]]/z, [x0] + %load = call <vscale x 2 x half> @llvm.masked.load.nxv2f16(<vscale x 2 x half> *%a, i32 2, <vscale x 2 x i1> %mask, <vscale x 2 x half> undef) + ret <vscale x 2 x half> %load +} + +define <vscale x 4 x float> @masked_load_nxv4f32(<vscale x 4 x float> *%a, <vscale x 4 x i1> %mask) { +; CHECK-LABEL: masked_load_nxv4f32: +; CHECK: ld1w { [[IN:z[0-9]+]].s }, [[PG:p[0-9]+]]/z, [x0] + %load = call <vscale x 4 x float> @llvm.masked.load.nxv4f32(<vscale x 4 x float> *%a, i32 4, <vscale x 4 x i1> %mask, <vscale x 4 x float> undef) + ret <vscale x 4 x float> %load +} + +define <vscale x 4 x half> @masked_load_nxv4f16(<vscale x 4 x half> *%a, <vscale x 4 x i1> %mask) { +; CHECK-LABEL: masked_load_nxv4f16: +; CHECK: ld1h { [[IN:z[0-9]+]].s }, [[PG:p[0-9]+]]/z, [x0] + %load = call <vscale x 4 x half> @llvm.masked.load.nxv4f16(<vscale x 4 x half> *%a, i32 2, <vscale x 4 x i1> %mask, <vscale x 4 x half> undef) + ret <vscale x 4 x half> %load +} + +define <vscale x 8 x half> @masked_load_nxv8f16(<vscale x 8 x half> *%a, <vscale x 8 x i1> %mask) { +; CHECK-LABEL: masked_load_nxv8f16: +; CHECK: ld1h { [[IN:z[0-9]+]].h }, [[PG:p[0-9]+]]/z, [x0] + %load = call <vscale x 8 x half> @llvm.masked.load.nxv8f16(<vscale x 8 x half> *%a, i32 2, <vscale x 8 x i1> %mask, <vscale x 8 x half> undef) + ret <vscale x 8 x half> %load +} + +declare <vscale x 2 x i64> @llvm.masked.load.nxv2i64(<vscale x 2 x i64>*, i32, <vscale x 2 x i1>, <vscale x 2 x i64>) +declare <vscale x 4 x i32> @llvm.masked.load.nxv4i32(<vscale x 4 x i32>*, i32, <vscale x 4 x i1>, <vscale x 4 x i32>) +declare <vscale x 8 x i16> 
@llvm.masked.load.nxv8i16(<vscale x 8 x i16>*, i32, <vscale x 8 x i1>, <vscale x 8 x i16>) +declare <vscale x 16 x i8> @llvm.masked.load.nxv16i8(<vscale x 16 x i8>*, i32, <vscale x 16 x i1>, <vscale x 16 x i8>) + +declare <vscale x 2 x double> @llvm.masked.load.nxv2f64(<vscale x 2 x double>*, i32, <vscale x 2 x i1>, <vscale x 2 x double>) +declare <vscale x 2 x float> @llvm.masked.load.nxv2f32(<vscale x 2 x float>*, i32, <vscale x 2 x i1>, <vscale x 2 x float>) +declare <vscale x 2 x half> @llvm.masked.load.nxv2f16(<vscale x 2 x half>*, i32, <vscale x 2 x i1>, <vscale x 2 x half>) +declare <vscale x 4 x float> @llvm.masked.load.nxv4f32(<vscale x 4 x float>*, i32, <vscale x 4 x i1>, <vscale x 4 x float>) +declare <vscale x 4 x half> @llvm.masked.load.nxv4f16(<vscale x 4 x half>*, i32, <vscale x 4 x i1>, <vscale x 4 x half>) +declare <vscale x 8 x half> @llvm.masked.load.nxv8f16(<vscale x 8 x half>*, i32, <vscale x 8 x i1>, <vscale x 8 x half>) Index: llvm/lib/Target/AArch64/SVEInstrFormats.td =================================================================== --- llvm/lib/Target/AArch64/SVEInstrFormats.td +++ llvm/lib/Target/AArch64/SVEInstrFormats.td @@ -293,6 +293,8 @@ : Pat<(vtd (op vt1:$Op1, vt2:$Op2, vt3:$Op3)), (inst $Op1, $Op2, $Op3)>; +def SVEUndef : ComplexPattern<i64, 0, "SelectUndef", []>; + //===----------------------------------------------------------------------===// // SVE Predicate Misc Group //===----------------------------------------------------------------------===// @@ -4732,6 +4734,13 @@ (!cast<Instruction>(NAME # _REAL) zprty:$Zt, PPR3bAny:$Pg, GPR64sp:$Rn, simm4s1:$imm4), 0>; def : InstAlias<asm # "\t$Zt, $Pg/z, [$Rn]", (!cast<Instruction>(NAME # _REAL) listty:$Zt, PPR3bAny:$Pg, GPR64sp:$Rn, 0), 1>; + + // We need a layer of indirection because early machine code passes balk at + // physical register (i.e. FFR) uses that have no previous definition. 
+ let hasSideEffects = 1, hasNoSchedulingInfo = 1, mayLoad = 1 in { + def "" : Pseudo<(outs listty:$Zt), (ins PPR3bAny:$Pg, GPR64sp:$Rn, simm4s1:$imm4), []>, + PseudoInstExpansion<(!cast<Instruction>(NAME # _REAL) listty:$Zt, PPR3bAny:$Pg, GPR64sp:$Rn, simm4s1:$imm4)>; + } } multiclass sve_mem_cld_si<bits<4> dtype, string asm, RegisterOperand listty, Index: llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h =================================================================== --- llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h +++ llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h @@ -147,6 +147,13 @@ bool getTgtMemIntrinsic(IntrinsicInst *Inst, MemIntrinsicInfo &Info); + bool isLegalMaskedLoad(Type *DataType) { + return ST->hasSVE(); + } + bool isLegalMaskedStore(Type *DataType) { + return ST->hasSVE(); + } + int getInterleavedMemoryOpCost(unsigned Opcode, Type *VecTy, unsigned Factor, ArrayRef<unsigned> Indices, unsigned Alignment, unsigned AddressSpace, Index: llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td =================================================================== --- llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td +++ llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td @@ -1070,6 +1070,44 @@ def : Pat<(nxv2f64 (bitconvert (nxv8f16 ZPR:$src))), (nxv2f64 ZPR:$src)>; def : Pat<(nxv2f64 (bitconvert (nxv4f32 ZPR:$src))), (nxv2f64 ZPR:$src)>; + // Add more complex addressing modes here as required + multiclass pred_load<ValueType Ty, ValueType PredTy, SDPatternOperator Load, + Instruction RegImmInst> { + + def _default_z : Pat<(Ty (Load GPR64:$base, (PredTy PPR:$gp), (SVEUndef))), + (RegImmInst PPR:$gp, GPR64:$base, (i64 0))>; + } + + // 2-element contiguous loads + defm : pred_load<nxv2i64, nxv2i1, zext_masked_load_i8, LD1B_D_IMM>; + defm : pred_load<nxv2i64, nxv2i1, asext_masked_load_i8, LD1SB_D_IMM>; + defm : pred_load<nxv2i64, nxv2i1, zext_masked_load_i16, LD1H_D_IMM>; + defm : pred_load<nxv2i64, nxv2i1, asext_masked_load_i16, LD1SH_D_IMM>; + defm : 
pred_load<nxv2i64, nxv2i1, zext_masked_load_i32, LD1W_D_IMM>; + defm : pred_load<nxv2i64, nxv2i1, asext_masked_load_i32, LD1SW_D_IMM>; + defm : pred_load<nxv2i64, nxv2i1, nonext_masked_load, LD1D_IMM>; + defm : pred_load<nxv2f16, nxv2i1, nonext_masked_load, LD1H_D_IMM>; + defm : pred_load<nxv2f32, nxv2i1, nonext_masked_load, LD1W_D_IMM>; + defm : pred_load<nxv2f64, nxv2i1, nonext_masked_load, LD1D_IMM>; + + // 4-element contiguous loads + defm : pred_load<nxv4i32, nxv4i1, zext_masked_load_i8, LD1B_S_IMM>; + defm : pred_load<nxv4i32, nxv4i1, asext_masked_load_i8, LD1SB_S_IMM>; + defm : pred_load<nxv4i32, nxv4i1, zext_masked_load_i16, LD1H_S_IMM>; + defm : pred_load<nxv4i32, nxv4i1, asext_masked_load_i16, LD1SH_S_IMM>; + defm : pred_load<nxv4i32, nxv4i1, nonext_masked_load, LD1W_IMM>; + defm : pred_load<nxv4f16, nxv4i1, nonext_masked_load, LD1H_S_IMM>; + defm : pred_load<nxv4f32, nxv4i1, nonext_masked_load, LD1W_IMM>; + + // 8-element contiguous loads + defm : pred_load<nxv8i16, nxv8i1, zext_masked_load_i8, LD1B_H_IMM>; + defm : pred_load<nxv8i16, nxv8i1, asext_masked_load_i8, LD1SB_H_IMM>; + defm : pred_load<nxv8i16, nxv8i1, nonext_masked_load, LD1H_IMM>; + defm : pred_load<nxv8f16, nxv8i1, nonext_masked_load, LD1H_IMM>; + + // 16-element contiguous loads + defm : pred_load<nxv16i8, nxv16i1, nonext_masked_load, LD1B_IMM>; + } let Predicates = [HasSVE2] in { Index: llvm/lib/Target/AArch64/AArch64InstrInfo.td =================================================================== --- llvm/lib/Target/AArch64/AArch64InstrInfo.td +++ llvm/lib/Target/AArch64/AArch64InstrInfo.td @@ -256,6 +256,55 @@ SDTCisSameAs<1, 2>, SDTCisSameAs<1, 3>, SDTCisSameAs<1, 4>]>; +// non-extending masked load fragment. +def nonext_masked_load : + PatFrag<(ops node:$ptr, node:$pred, node:$def), + (masked_ld node:$ptr, node:$pred, node:$def), [{ + return cast<MaskedLoadSDNode>(N)->getExtensionType() == ISD::NON_EXTLOAD; +}]>; +// sign extending masked load fragments. 
+def asext_masked_load : + PatFrag<(ops node:$ptr, node:$pred, node:$def), + (masked_ld node:$ptr, node:$pred, node:$def),[{ + return cast<MaskedLoadSDNode>(N)->getExtensionType() == ISD::EXTLOAD || + cast<MaskedLoadSDNode>(N)->getExtensionType() == ISD::SEXTLOAD; +}]>; +def asext_masked_load_i8 : + PatFrag<(ops node:$ptr, node:$pred, node:$def), + (asext_masked_load node:$ptr, node:$pred, node:$def), [{ + return cast<MaskedLoadSDNode>(N)->getMemoryVT().getScalarType() == MVT::i8; +}]>; +def asext_masked_load_i16 : + PatFrag<(ops node:$ptr, node:$pred, node:$def), + (asext_masked_load node:$ptr, node:$pred, node:$def), [{ + return cast<MaskedLoadSDNode>(N)->getMemoryVT().getScalarType() == MVT::i16; +}]>; +def asext_masked_load_i32 : + PatFrag<(ops node:$ptr, node:$pred, node:$def), + (asext_masked_load node:$ptr, node:$pred, node:$def), [{ + return cast<MaskedLoadSDNode>(N)->getMemoryVT().getScalarType() == MVT::i32; +}]>; +// zero extending masked load fragments. +def zext_masked_load : + PatFrag<(ops node:$ptr, node:$pred, node:$def), + (masked_ld node:$ptr, node:$pred, node:$def), [{ + return cast<MaskedLoadSDNode>(N)->getExtensionType() == ISD::ZEXTLOAD; +}]>; +def zext_masked_load_i8 : + PatFrag<(ops node:$ptr, node:$pred, node:$def), + (zext_masked_load node:$ptr, node:$pred, node:$def), [{ + return cast<MaskedLoadSDNode>(N)->getMemoryVT().getScalarType() == MVT::i8; +}]>; +def zext_masked_load_i16 : + PatFrag<(ops node:$ptr, node:$pred, node:$def), + (zext_masked_load node:$ptr, node:$pred, node:$def), [{ + return cast<MaskedLoadSDNode>(N)->getMemoryVT().getScalarType() == MVT::i16; +}]>; +def zext_masked_load_i32 : + PatFrag<(ops node:$ptr, node:$pred, node:$def), + (zext_masked_load node:$ptr, node:$pred, node:$def), [{ + return cast<MaskedLoadSDNode>(N)->getMemoryVT().getScalarType() == MVT::i32; +}]>; // Node definitions. 
def AArch64adrp : SDNode<"AArch64ISD::ADRP", SDTIntUnaryOp, []>; Index: llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp =================================================================== --- llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp +++ llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp @@ -140,6 +140,11 @@ return SelectAddrModeXRO(N, Width / 8, Base, Offset, SignExtend, DoShift); } + bool SelectUndef(SDValue N) { + if (N->getOpcode() == ISD::UNDEF) + return true; + return false; + } /// Form sequences of consecutive 64/128-bit registers for use in NEON /// instructions making use of a vector-list (e.g. ldN, tbl). Vecs must have @@ -204,7 +209,7 @@ bool SelectAddrModeXRO(SDValue N, unsigned Size, SDValue &Base, SDValue &Offset, SDValue &SignExtend, SDValue &DoShift); - bool isWorthFolding(SDValue V) const; + bool isWorthFolding(SDValue V, unsigned MaxUses = 1) const; bool SelectExtendedSHL(SDValue N, unsigned Size, bool WantExtend, SDValue &Offset, SDValue &SignExtend); @@ -375,7 +380,7 @@ } /// Determine whether it is worth to fold V into an extended register. -bool AArch64DAGToDAGISel::isWorthFolding(SDValue V) const { +bool AArch64DAGToDAGISel::isWorthFolding(SDValue V, unsigned MaxUses) const { // Trivial if we are optimizing for code size or if there is only // one use of the value. if (ForCodeSize || V.hasOneUse()) @@ -394,6 +399,18 @@ return true; } + // If it has more than one use, check they're all loads/stores + // from/to the same memory type (e.g. if you can fold for one + // addressing mode, you can fold for the others as well). + EVT VT; + for (auto *Use : V.getNode()->uses()) + if (auto *MemNode = dyn_cast<MemSDNode>(Use)) + if (MemNode->getMemoryVT() != VT && VT != EVT()) + return false; + + if (V.getNode()->use_size() <= MaxUses) + return true; + // It hurts otherwise, since the value will be reused. 
return false; } Index: llvm/lib/CodeGen/TargetLoweringBase.cpp =================================================================== --- llvm/lib/CodeGen/TargetLoweringBase.cpp +++ llvm/lib/CodeGen/TargetLoweringBase.cpp @@ -1265,18 +1265,23 @@ MVT EltVT = VT.getVectorElementType(); unsigned NElts = VT.getVectorNumElements(); bool IsLegalWiderType = false; + bool IsScalable = VT.isScalableVector(); LegalizeTypeAction PreferredAction = getPreferredVectorAction(VT); switch (PreferredAction) { - case TypePromoteInteger: + case TypePromoteInteger: { + MVT::SimpleValueType EndVT = IsScalable ? + MVT::LAST_INTEGER_SCALABLE_VECTOR_VALUETYPE : + MVT::LAST_INTEGER_FIXEDLEN_VECTOR_VALUETYPE; // Try to promote the elements of integer vectors. If no legal // promotion was found, fall through to the widen-vector method. for (unsigned nVT = i + 1; - nVT <= MVT::LAST_INTEGER_FIXEDLEN_VECTOR_VALUETYPE; ++nVT) { + (MVT::SimpleValueType) nVT <= EndVT; ++nVT) { MVT SVT = (MVT::SimpleValueType) nVT; // Promote vectors of integers to vectors with the same number // of elements, with a wider element type. 
if (SVT.getScalarSizeInBits() > EltVT.getSizeInBits() && - SVT.getVectorNumElements() == NElts && isTypeLegal(SVT)) { + SVT.getVectorNumElements() == NElts && isTypeLegal(SVT) && + SVT.isScalableVector() == IsScalable && isTypeLegal(SVT)) { TransformToType[i] = SVT; RegisterTypeForVT[i] = SVT; NumRegistersForVT[i] = 1; @@ -1288,6 +1293,7 @@ if (IsLegalWiderType) break; LLVM_FALLTHROUGH; + } case TypeWidenVector: if (isPowerOf2_32(NElts)) { @@ -1295,6 +1301,7 @@ for (unsigned nVT = i + 1; nVT <= MVT::LAST_VECTOR_VALUETYPE; ++nVT) { MVT SVT = (MVT::SimpleValueType) nVT; if (SVT.getVectorElementType() == EltVT + && SVT.isScalableVector() == IsScalable && SVT.getVectorNumElements() > NElts && isTypeLegal(SVT)) { TransformToType[i] = SVT; RegisterTypeForVT[i] = SVT; Index: llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp =================================================================== --- llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp +++ llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp @@ -4447,12 +4447,15 @@ const MDNode *Ranges = I.getMetadata(LLVMContext::MD_range); // Do not serialize masked loads of constant memory with anything. - bool AddToChain = - !AA || !AA->pointsToConstantMemory(MemoryLocation( - PtrOperand, - LocationSize::precise( - DAG.getDataLayout().getTypeStoreSize(I.getType())), - AAInfo)); + bool AddToChain = false; + if (!VT.isScalableVector()) + AddToChain = + !AA || !AA->pointsToConstantMemory(MemoryLocation( + PtrOperand, + LocationSize::precise( + DAG.getDataLayout().getTypeStoreSize(I.getType())), + AAInfo)); + SDValue InChain = AddToChain ? 
DAG.getRoot() : DAG.getEntryNode(); MachineMemOperand *MMO = Index: llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp =================================================================== --- llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -838,6 +838,34 @@ return false; } +static bool isConstantSplatVectorMaskForType(SDNode *N, EVT ScalarTy) { + if (!ScalarTy.isSimple()) + return false; + + uint64_t MaskForTy = 0ull; + switch(ScalarTy.getSimpleVT().SimpleTy) { + case MVT::i8: + MaskForTy = 0xffull; + break; + case MVT::i16: + MaskForTy = 0xffffull; + break; + case MVT::i32: + MaskForTy = 0xffffffffull; + break; + default: + return false; + break; + } + + APInt Val; + if (ISD::isConstantSplatVector(N, Val)) { + return Val.getLimitedValue() == MaskForTy; + } + + return false; +} + // Returns the SDNode if it is a constant float BuildVector // or constant float. static SDNode *isConstantFPBuildVectorOrConstantFP(SDValue N) { @@ -5221,6 +5249,24 @@ } } + if (auto *LN0 = dyn_cast<MaskedLoadSDNode>(N0)) { + EVT MemVT = LN0->getMemoryVT(); + EVT ScalarVT = MemVT.getScalarType(); + if (SDValue(LN0, 0).hasOneUse() + && isConstantSplatVectorMaskForType(N1.getNode(), ScalarVT) + && TLI.isLoadExtLegal(ISD::ZEXTLOAD, VT, MemVT)) { + SDValue ZExtLoad = DAG.getMaskedLoad(VT, SDLoc(N), LN0->getChain(), + LN0->getBasePtr(), LN0->getMask(), + LN0->getPassThru(), MemVT, + LN0->getMemOperand(), ISD::ZEXTLOAD); + CombineTo(N, ZExtLoad); + CombineTo(N0.getNode(), ZExtLoad, ZExtLoad.getValue(1)); + AddToWorklist(ZExtLoad.getNode()); + // Avoid recheck of N. 
+ return SDValue(N, 0); + } + } + // fold (and (load x), 255) -> (zextload x, i8) // fold (and (extload x, i16), 255) -> (zextload x, i8) // fold (and (any_ext (extload x, i16)), 255) -> (zextload x, i8) @@ -9043,6 +9089,9 @@ if (!TLI.isLoadExtLegalOrCustom(ExtType, SplitDstVT, SplitSrcVT)) return SDValue(); + if (DstVT.isScalableVector()) + return SDValue(); + SDLoc DL(N); const unsigned NumSplits = DstVT.getVectorNumElements() / SplitDstVT.getVectorNumElements(); @@ -10337,6 +10386,21 @@ AddToWorklist(ExtLoad.getNode()); return SDValue(N, 0); // Return N so it doesn't get rechecked! } + // fold (sext_inreg (masked_load x)) -> (sext_masked_load x) + if (isa<MaskedLoadSDNode>(N0) && + EVT == cast<MaskedLoadSDNode>(N0)->getMemoryVT() && + ((!LegalOperations && !cast<MaskedLoadSDNode>(N0)->isVolatile()) || + TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, EVT))) { + MaskedLoadSDNode *LN0 = cast<MaskedLoadSDNode>(N0); + SDValue ExtLoad = DAG.getMaskedLoad(VT, SDLoc(N), LN0->getChain(), + LN0->getBasePtr(), LN0->getMask(), + LN0->getPassThru(), LN0->getMemoryVT(), + LN0->getMemOperand(), ISD::SEXTLOAD); + CombineTo(N, ExtLoad); + CombineTo(N0.getNode(), ExtLoad, ExtLoad.getValue(1)); + AddToWorklist(ExtLoad.getNode()); + return SDValue(N, 0); // Return N so it doesn't get rechecked! + } // fold (sext_inreg (zextload x)) -> (sextload x) iff load has one use if (ISD::isZEXTLoad(N0.getNode()) && ISD::isUNINDEXEDLoad(N0.getNode()) && N0.hasOneUse() &&
_______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits