kmclaughlin created this revision.
kmclaughlin added reviewers: huntergr, rovka, greened.
Herald added subscribers: psnobl, rkruppe, hiraditya, kristof.beyls, tschuett.
Herald added a project: LLVM.
kmclaughlin added a parent revision: D47775: [AArch64][SVE] Add SPLAT_VECTOR ISD Node.

Adds support for codegen of masked loads, with non-extending,
zero-extending and sign-extending variants.

Depends on the changes in D47775 <https://reviews.llvm.org/D47775> for
isConstantSplatVectorMaskForType.
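
For example (illustrative, mirroring the included tests), a zero-extending
masked load of <vscale x 2 x i8> data:

  %load = call <vscale x 2 x i8> @llvm.masked.load.nxv2i8(<vscale x 2 x i8>* %src, i32 1, <vscale x 2 x i1> %mask, <vscale x 2 x i8> undef)
  %ext = zext <vscale x 2 x i8> %load to <vscale x 2 x i64>

now selects to a single predicated unsigned load (register allocation may
vary):

  ld1b { z0.d }, p0/z, [x0]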


Repository:
  rG LLVM Github Monorepo

https://reviews.llvm.org/D68877

Files:
  llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
  llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
  llvm/lib/CodeGen/TargetLoweringBase.cpp
  llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
  llvm/lib/Target/AArch64/AArch64InstrInfo.td
  llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
  llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
  llvm/lib/Target/AArch64/SVEInstrFormats.td
  llvm/test/CodeGen/AArch64/sve-masked-ldst-nonext.ll
  llvm/test/CodeGen/AArch64/sve-masked-ldst-sext.ll
  llvm/test/CodeGen/AArch64/sve-masked-ldst-zext.ll

Index: llvm/test/CodeGen/AArch64/sve-masked-ldst-zext.ll
===================================================================
--- /dev/null
+++ llvm/test/CodeGen/AArch64/sve-masked-ldst-zext.ll
@@ -0,0 +1,72 @@
+; RUN: llc -mtriple=aarch64--linux-gnu -mattr=+sve < %s | FileCheck %s
+
+;
+; Masked Loads
+;
+
+define <vscale x 2 x i64> @masked_zload_nxv2i8(<vscale x 2 x i8>* %src, <vscale x 2 x i1> %mask) {
+; CHECK-LABEL: masked_zload_nxv2i8:
+; CHECK-NOT: ld1sb
+; CHECK: ld1b { [[IN:z[0-9]+]].d }, [[PG:p[0-9]+]]/z, [x0]
+; CHECK-NEXT: ret
+  %load = call <vscale x 2 x i8> @llvm.masked.load.nxv2i8(<vscale x 2 x i8>* %src, i32 1, <vscale x 2 x i1> %mask, <vscale x 2 x i8> undef)
+  %ext = zext <vscale x 2 x i8> %load to <vscale x 2 x i64>
+  ret <vscale x 2 x i64> %ext
+}
+
+define <vscale x 2 x i64> @masked_zload_nxv2i16(<vscale x 2 x i16>* %src, <vscale x 2 x i1> %mask) {
+; CHECK-LABEL: masked_zload_nxv2i16:
+; CHECK-NOT: ld1sh
+; CHECK: ld1h { [[IN:z[0-9]+]].d }, [[PG:p[0-9]+]]/z, [x0]
+; CHECK-NEXT: ret
+  %load = call <vscale x 2 x i16> @llvm.masked.load.nxv2i16(<vscale x 2 x i16>* %src, i32 1, <vscale x 2 x i1> %mask, <vscale x 2 x i16> undef)
+  %ext = zext <vscale x 2 x i16> %load to <vscale x 2 x i64>
+  ret <vscale x 2 x i64> %ext
+}
+
+define <vscale x 2 x i64> @masked_zload_nxv2i32(<vscale x 2 x i32>* %src, <vscale x 2 x i1> %mask) {
+; CHECK-LABEL: masked_zload_nxv2i32:
+; CHECK-NOT: ld1sw
+; CHECK: ld1w { [[IN:z[0-9]+]].d }, [[PG:p[0-9]+]]/z, [x0]
+; CHECK-NEXT: ret
+  %load = call <vscale x 2 x i32> @llvm.masked.load.nxv2i32(<vscale x 2 x i32>* %src, i32 1, <vscale x 2 x i1> %mask, <vscale x 2 x i32> undef)
+  %ext = zext <vscale x 2 x i32> %load to <vscale x 2 x i64>
+  ret <vscale x 2 x i64> %ext
+}
+
+define <vscale x 4 x i32> @masked_zload_nxv4i8(<vscale x 4 x i8>* %src, <vscale x 4 x i1> %mask) {
+; CHECK-LABEL: masked_zload_nxv4i8:
+; CHECK-NOT: ld1sb
+; CHECK: ld1b { [[IN:z[0-9]+]].s }, [[PG:p[0-9]+]]/z, [x0]
+; CHECK-NEXT: ret
+  %load = call <vscale x 4 x i8> @llvm.masked.load.nxv4i8(<vscale x 4 x i8>* %src, i32 1, <vscale x 4 x i1> %mask, <vscale x 4 x i8> undef)
+  %ext = zext <vscale x 4 x i8> %load to <vscale x 4 x i32>
+  ret <vscale x 4 x i32> %ext
+}
+
+define <vscale x 4 x i32> @masked_zload_nxv4i16(<vscale x 4 x i16>* %src, <vscale x 4 x i1> %mask) {
+; CHECK-LABEL: masked_zload_nxv4i16:
+; CHECK-NOT: ld1sh
+; CHECK: ld1h { [[IN:z[0-9]+]].s }, [[PG:p[0-9]+]]/z, [x0]
+; CHECK-NEXT: ret
+  %load = call <vscale x 4 x i16> @llvm.masked.load.nxv4i16(<vscale x 4 x i16>* %src, i32 1, <vscale x 4 x i1> %mask, <vscale x 4 x i16> undef)
+  %ext = zext <vscale x 4 x i16> %load to <vscale x 4 x i32>
+  ret <vscale x 4 x i32> %ext
+}
+
+define <vscale x 8 x i16> @masked_zload_nxv8i8(<vscale x 8 x i8>* %src, <vscale x 8 x i1> %mask) {
+; CHECK-LABEL: masked_zload_nxv8i8:
+; CHECK-NOT: ld1sb
+; CHECK: ld1b { [[IN:z[0-9]+]].h }, [[PG:p[0-9]+]]/z, [x0]
+; CHECK-NEXT: ret
+  %load = call <vscale x 8 x i8> @llvm.masked.load.nxv8i8(<vscale x 8 x i8>* %src, i32 1, <vscale x 8 x i1> %mask, <vscale x 8 x i8> undef)
+  %ext = zext <vscale x 8 x i8> %load to <vscale x 8 x i16>
+  ret <vscale x 8 x i16> %ext
+}
+
+declare <vscale x 2 x i8> @llvm.masked.load.nxv2i8(<vscale x 2 x i8>*, i32, <vscale x 2 x i1>, <vscale x 2 x i8>)
+declare <vscale x 2 x i16> @llvm.masked.load.nxv2i16(<vscale x 2 x i16>*, i32, <vscale x 2 x i1>, <vscale x 2 x i16>)
+declare <vscale x 2 x i32> @llvm.masked.load.nxv2i32(<vscale x 2 x i32>*, i32, <vscale x 2 x i1>, <vscale x 2 x i32>)
+declare <vscale x 4 x i8> @llvm.masked.load.nxv4i8(<vscale x 4 x i8>*, i32, <vscale x 4 x i1>, <vscale x 4 x i8>)
+declare <vscale x 4 x i16> @llvm.masked.load.nxv4i16(<vscale x 4 x i16>*, i32, <vscale x 4 x i1>, <vscale x 4 x i16>)
+declare <vscale x 8 x i8> @llvm.masked.load.nxv8i8(<vscale x 8 x i8>*, i32, <vscale x 8 x i1>, <vscale x 8 x i8>)
Index: llvm/test/CodeGen/AArch64/sve-masked-ldst-sext.ll
===================================================================
--- /dev/null
+++ llvm/test/CodeGen/AArch64/sve-masked-ldst-sext.ll
@@ -0,0 +1,66 @@
+; RUN: llc -mtriple=aarch64--linux-gnu -mattr=+sve < %s | FileCheck %s
+
+;
+; Masked Loads
+;
+
+define <vscale x 2 x i64> @masked_sload_nxv2i8(<vscale x 2 x i8> *%a, <vscale x 2 x i1> %mask) {
+; CHECK-LABEL: masked_sload_nxv2i8:
+; CHECK: ld1sb { [[IN:z[0-9]+]].d }, [[PG:p[0-9]+]]/z, [x0]
+; CHECK-NEXT: ret
+  %load = call <vscale x 2 x i8> @llvm.masked.load.nxv2i8(<vscale x 2 x i8> *%a, i32 1, <vscale x 2 x i1> %mask, <vscale x 2 x i8> undef)
+  %ext = sext <vscale x 2 x i8> %load to <vscale x 2 x i64>
+  ret <vscale x 2 x i64> %ext
+}
+
+define <vscale x 2 x i64> @masked_sload_nxv2i16(<vscale x 2 x i16> *%a, <vscale x 2 x i1> %mask) {
+; CHECK-LABEL: masked_sload_nxv2i16:
+; CHECK: ld1sh { [[IN:z[0-9]+]].d }, [[PG:p[0-9]+]]/z, [x0]
+; CHECK-NEXT: ret
+  %load = call <vscale x 2 x i16> @llvm.masked.load.nxv2i16(<vscale x 2 x i16> *%a, i32 1, <vscale x 2 x i1> %mask, <vscale x 2 x i16> undef)
+  %ext = sext <vscale x 2 x i16> %load to <vscale x 2 x i64>
+  ret <vscale x 2 x i64> %ext
+}
+
+define <vscale x 2 x i64> @masked_sload_nxv2i32(<vscale x 2 x i32> *%a, <vscale x 2 x i1> %mask) {
+; CHECK-LABEL: masked_sload_nxv2i32:
+; CHECK: ld1sw { [[IN:z[0-9]+]].d }, [[PG:p[0-9]+]]/z, [x0]
+; CHECK-NEXT: ret
+  %load = call <vscale x 2 x i32> @llvm.masked.load.nxv2i32(<vscale x 2 x i32> *%a, i32 1, <vscale x 2 x i1> %mask, <vscale x 2 x i32> undef)
+  %ext = sext <vscale x 2 x i32> %load to <vscale x 2 x i64>
+  ret <vscale x 2 x i64> %ext
+}
+
+define <vscale x 4 x i32> @masked_sload_nxv4i8(<vscale x 4 x i8> *%a, <vscale x 4 x i1> %mask) {
+; CHECK-LABEL: masked_sload_nxv4i8:
+; CHECK: ld1sb { [[IN:z[0-9]+]].s }, [[PG:p[0-9]+]]/z, [x0]
+; CHECK-NEXT: ret
+  %load = call <vscale x 4 x i8> @llvm.masked.load.nxv4i8(<vscale x 4 x i8> *%a, i32 1, <vscale x 4 x i1> %mask, <vscale x 4 x i8> undef)
+  %ext = sext <vscale x 4 x i8> %load to <vscale x 4 x i32>
+  ret <vscale x 4 x i32> %ext
+}
+
+define <vscale x 4 x i32> @masked_sload_nxv4i16(<vscale x 4 x i16> *%a, <vscale x 4 x i1> %mask) {
+; CHECK-LABEL: masked_sload_nxv4i16:
+; CHECK: ld1sh { [[IN:z[0-9]+]].s }, [[PG:p[0-9]+]]/z, [x0]
+; CHECK-NEXT: ret
+  %load = call <vscale x 4 x i16> @llvm.masked.load.nxv4i16(<vscale x 4 x i16> *%a, i32 1, <vscale x 4 x i1> %mask, <vscale x 4 x i16> undef)
+  %ext = sext <vscale x 4 x i16> %load to <vscale x 4 x i32>
+  ret <vscale x 4 x i32> %ext
+}
+
+define <vscale x 8 x i16> @masked_sload_nxv8i8(<vscale x 8 x i8> *%a, <vscale x 8 x i1> %mask) {
+; CHECK-LABEL: masked_sload_nxv8i8:
+; CHECK: ld1sb { [[IN:z[0-9]+]].h }, [[PG:p[0-9]+]]/z, [x0]
+; CHECK-NEXT: ret
+  %load = call <vscale x 8 x i8> @llvm.masked.load.nxv8i8(<vscale x 8 x i8> *%a, i32 1, <vscale x 8 x i1> %mask, <vscale x 8 x i8> undef)
+  %ext = sext <vscale x 8 x i8> %load to <vscale x 8 x i16>
+  ret <vscale x 8 x i16> %ext
+}
+
+declare <vscale x 2 x i8> @llvm.masked.load.nxv2i8(<vscale x 2 x i8>*, i32, <vscale x 2 x i1>, <vscale x 2 x i8>)
+declare <vscale x 2 x i16> @llvm.masked.load.nxv2i16(<vscale x 2 x i16>*, i32, <vscale x 2 x i1>, <vscale x 2 x i16>)
+declare <vscale x 2 x i32> @llvm.masked.load.nxv2i32(<vscale x 2 x i32>*, i32, <vscale x 2 x i1>, <vscale x 2 x i32>)
+declare <vscale x 4 x i8> @llvm.masked.load.nxv4i8(<vscale x 4 x i8>*, i32, <vscale x 4 x i1>, <vscale x 4 x i8>)
+declare <vscale x 4 x i16> @llvm.masked.load.nxv4i16(<vscale x 4 x i16>*, i32, <vscale x 4 x i1>, <vscale x 4 x i16>)
+declare <vscale x 8 x i8> @llvm.masked.load.nxv8i8(<vscale x 8 x i8>*, i32, <vscale x 8 x i1>, <vscale x 8 x i8>)
Index: llvm/test/CodeGen/AArch64/sve-masked-ldst-nonext.ll
===================================================================
--- /dev/null
+++ llvm/test/CodeGen/AArch64/sve-masked-ldst-nonext.ll
@@ -0,0 +1,87 @@
+; RUN: llc -mtriple=aarch64--linux-gnu -mattr=+sve < %s | FileCheck %s
+
+;
+; Masked Loads
+;
+
+define <vscale x 2 x i64> @masked_load_nxv2i64(<vscale x 2 x i64> *%a, <vscale x 2 x i1> %mask) {
+; CHECK-LABEL: masked_load_nxv2i64:
+; CHECK: ld1d { [[IN:z[0-9]+]].d }, [[PG:p[0-9]+]]/z, [x0]
+  %load = call <vscale x 2 x i64> @llvm.masked.load.nxv2i64(<vscale x 2 x i64> *%a, i32 8, <vscale x 2 x i1> %mask, <vscale x 2 x i64> undef)
+  ret <vscale x 2 x i64> %load
+}
+
+define <vscale x 4 x i32> @masked_load_nxv4i32(<vscale x 4 x i32> *%a, <vscale x 4 x i1> %mask) {
+; CHECK-LABEL: masked_load_nxv4i32:
+; CHECK: ld1w { [[IN:z[0-9]+]].s }, [[PG:p[0-9]+]]/z, [x0]
+  %load = call <vscale x 4 x i32> @llvm.masked.load.nxv4i32(<vscale x 4 x i32> *%a, i32 4, <vscale x 4 x i1> %mask, <vscale x 4 x i32> undef)
+  ret <vscale x 4 x i32> %load
+}
+
+define <vscale x 8 x i16> @masked_load_nxv8i16(<vscale x 8 x i16> *%a, <vscale x 8 x i1> %mask) {
+; CHECK-LABEL: masked_load_nxv8i16:
+; CHECK: ld1h { [[IN:z[0-9]+]].h }, [[PG:p[0-9]+]]/z, [x0]
+  %load = call <vscale x 8 x i16> @llvm.masked.load.nxv8i16(<vscale x 8 x i16> *%a, i32 2, <vscale x 8 x i1> %mask, <vscale x 8 x i16> undef)
+  ret <vscale x 8 x i16> %load
+}
+
+define <vscale x 16 x i8> @masked_load_nxv16i8(<vscale x 16 x i8> *%a, <vscale x 16 x i1> %mask) {
+; CHECK-LABEL: masked_load_nxv16i8:
+; CHECK: ld1b { [[IN:z[0-9]+]].b }, [[PG:p[0-9]+]]/z, [x0]
+  %load = call <vscale x 16 x i8> @llvm.masked.load.nxv16i8(<vscale x 16 x i8> *%a, i32 1, <vscale x 16 x i1> %mask, <vscale x 16 x i8> undef)
+  ret <vscale x 16 x i8> %load
+}
+
+define <vscale x 2 x double> @masked_load_nxv2f64(<vscale x 2 x double> *%a, <vscale x 2 x i1> %mask) {
+; CHECK-LABEL: masked_load_nxv2f64:
+; CHECK: ld1d { [[IN:z[0-9]+]].d }, [[PG:p[0-9]+]]/z, [x0]
+  %load = call <vscale x 2 x double> @llvm.masked.load.nxv2f64(<vscale x 2 x double> *%a, i32 8, <vscale x 2 x i1> %mask, <vscale x 2 x double> undef)
+  ret <vscale x 2 x double> %load
+}
+
+define <vscale x 2 x float> @masked_load_nxv2f32(<vscale x 2 x float> *%a, <vscale x 2 x i1> %mask) {
+; CHECK-LABEL: masked_load_nxv2f32:
+; CHECK: ld1w { [[IN:z[0-9]+]].d }, [[PG:p[0-9]+]]/z, [x0]
+  %load = call <vscale x 2 x float> @llvm.masked.load.nxv2f32(<vscale x 2 x float> *%a, i32 4, <vscale x 2 x i1> %mask, <vscale x 2 x float> undef)
+  ret <vscale x 2 x float> %load
+}
+
+define <vscale x 2 x half> @masked_load_nxv2f16(<vscale x 2 x half> *%a, <vscale x 2 x i1> %mask) {
+; CHECK-LABEL: masked_load_nxv2f16:
+; CHECK: ld1h { [[IN:z[0-9]+]].d }, [[PG:p[0-9]+]]/z, [x0]
+  %load = call <vscale x 2 x half> @llvm.masked.load.nxv2f16(<vscale x 2 x half> *%a, i32 2, <vscale x 2 x i1> %mask, <vscale x 2 x half> undef)
+  ret <vscale x 2 x half> %load
+}
+
+define <vscale x 4 x float> @masked_load_nxv4f32(<vscale x 4 x float> *%a, <vscale x 4 x i1> %mask) {
+; CHECK-LABEL: masked_load_nxv4f32:
+; CHECK: ld1w { [[IN:z[0-9]+]].s }, [[PG:p[0-9]+]]/z, [x0]
+  %load = call <vscale x 4 x float> @llvm.masked.load.nxv4f32(<vscale x 4 x float> *%a, i32 4, <vscale x 4 x i1> %mask, <vscale x 4 x float> undef)
+  ret <vscale x 4 x float> %load
+}
+
+define <vscale x 4 x half> @masked_load_nxv4f16(<vscale x 4 x half> *%a, <vscale x 4 x i1> %mask) {
+; CHECK-LABEL: masked_load_nxv4f16:
+; CHECK: ld1h { [[IN:z[0-9]+]].s }, [[PG:p[0-9]+]]/z, [x0]
+  %load = call <vscale x 4 x half> @llvm.masked.load.nxv4f16(<vscale x 4 x half> *%a, i32 2, <vscale x 4 x i1> %mask, <vscale x 4 x half> undef)
+  ret <vscale x 4 x half> %load
+}
+
+define <vscale x 8 x half> @masked_load_nxv8f16(<vscale x 8 x half> *%a, <vscale x 8 x i1> %mask) {
+; CHECK-LABEL: masked_load_nxv8f16:
+; CHECK: ld1h { [[IN:z[0-9]+]].h }, [[PG:p[0-9]+]]/z, [x0]
+  %load = call <vscale x 8 x half> @llvm.masked.load.nxv8f16(<vscale x 8 x half> *%a, i32 2, <vscale x 8 x i1> %mask, <vscale x 8 x half> undef)
+  ret <vscale x 8 x half> %load
+}
+
+declare <vscale x 2 x i64> @llvm.masked.load.nxv2i64(<vscale x 2 x i64>*, i32, <vscale x 2 x i1>, <vscale x 2 x i64>)
+declare <vscale x 4 x i32> @llvm.masked.load.nxv4i32(<vscale x 4 x i32>*, i32, <vscale x 4 x i1>, <vscale x 4 x i32>)
+declare <vscale x 8 x i16> @llvm.masked.load.nxv8i16(<vscale x 8 x i16>*, i32, <vscale x 8 x i1>, <vscale x 8 x i16>)
+declare <vscale x 16 x i8> @llvm.masked.load.nxv16i8(<vscale x 16 x i8>*, i32, <vscale x 16 x i1>, <vscale x 16 x i8>)
+
+declare <vscale x 2 x double> @llvm.masked.load.nxv2f64(<vscale x 2 x double>*, i32, <vscale x 2 x i1>, <vscale x 2 x double>)
+declare <vscale x 2 x float> @llvm.masked.load.nxv2f32(<vscale x 2 x float>*, i32, <vscale x 2 x i1>, <vscale x 2 x float>)
+declare <vscale x 2 x half> @llvm.masked.load.nxv2f16(<vscale x 2 x half>*, i32, <vscale x 2 x i1>, <vscale x 2 x half>)
+declare <vscale x 4 x float> @llvm.masked.load.nxv4f32(<vscale x 4 x float>*, i32, <vscale x 4 x i1>, <vscale x 4 x float>)
+declare <vscale x 4 x half> @llvm.masked.load.nxv4f16(<vscale x 4 x half>*, i32, <vscale x 4 x i1>, <vscale x 4 x half>)
+declare <vscale x 8 x half> @llvm.masked.load.nxv8f16(<vscale x 8 x half>*, i32, <vscale x 8 x i1>, <vscale x 8 x half>)
Index: llvm/lib/Target/AArch64/SVEInstrFormats.td
===================================================================
--- llvm/lib/Target/AArch64/SVEInstrFormats.td
+++ llvm/lib/Target/AArch64/SVEInstrFormats.td
@@ -293,6 +293,8 @@
 : Pat<(vtd (op vt1:$Op1, vt2:$Op2, vt3:$Op3)),
       (inst $Op1, $Op2, $Op3)>;
 
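+// Matches an explicit undef operand, e.g. the passthru of a masked load
+// whose inactive lanes may take any value.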
+def SVEUndef : ComplexPattern<i64, 0, "SelectUndef", []>;
+
 //===----------------------------------------------------------------------===//
 // SVE Predicate Misc Group
 //===----------------------------------------------------------------------===//
@@ -4732,6 +4734,13 @@
                   (!cast<Instruction>(NAME # _REAL) zprty:$Zt, PPR3bAny:$Pg, GPR64sp:$Rn, simm4s1:$imm4), 0>;
   def : InstAlias<asm # "\t$Zt, $Pg/z, [$Rn]",
                   (!cast<Instruction>(NAME # _REAL) listty:$Zt, PPR3bAny:$Pg, GPR64sp:$Rn, 0), 1>;
+
+  // We need a layer of indirection because early machine code passes balk at
+  // physical register (i.e. FFR) uses that have no previous definition.
+  let hasSideEffects = 1, hasNoSchedulingInfo = 1, mayLoad = 1 in {
+  def "" : Pseudo<(outs listty:$Zt), (ins PPR3bAny:$Pg, GPR64sp:$Rn, simm4s1:$imm4), []>,
+           PseudoInstExpansion<(!cast<Instruction>(NAME # _REAL) listty:$Zt, PPR3bAny:$Pg, GPR64sp:$Rn, simm4s1:$imm4)>;
+  }
 }
 
 multiclass sve_mem_cld_si<bits<4> dtype, string asm, RegisterOperand listty,
Index: llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
===================================================================
--- llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -147,6 +147,13 @@
 
   bool getTgtMemIntrinsic(IntrinsicInst *Inst, MemIntrinsicInfo &Info);
 
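+  // SVE provides predicated (masked) contiguous loads and stores, so masked
+  // memory operations are legal whenever SVE is available.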
+  bool isLegalMaskedLoad(Type *DataType) {
+    return ST->hasSVE();
+  }
+  bool isLegalMaskedStore(Type *DataType) {
+    return ST->hasSVE();
+  }
+
   int getInterleavedMemoryOpCost(unsigned Opcode, Type *VecTy, unsigned Factor,
                                  ArrayRef<unsigned> Indices, unsigned Alignment,
                                  unsigned AddressSpace,
Index: llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
===================================================================
--- llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -1070,6 +1070,44 @@
   def : Pat<(nxv2f64 (bitconvert (nxv8f16 ZPR:$src))), (nxv2f64 ZPR:$src)>;
   def : Pat<(nxv2f64 (bitconvert (nxv4f32 ZPR:$src))), (nxv2f64 ZPR:$src)>;
 
+  // Add more complex addressing modes here as required.
+  multiclass pred_load<ValueType Ty, ValueType PredTy, SDPatternOperator Load,
+                       Instruction RegImmInst> {
+    def _default_z : Pat<(Ty (Load GPR64:$base, (PredTy PPR:$gp), (SVEUndef))),
+                         (RegImmInst PPR:$gp, GPR64:$base, (i64 0))>;
+  }
+
+  // 2-element contiguous loads
+  defm : pred_load<nxv2i64, nxv2i1, zext_masked_load_i8,   LD1B_D_IMM>;
+  defm : pred_load<nxv2i64, nxv2i1, asext_masked_load_i8,  LD1SB_D_IMM>;
+  defm : pred_load<nxv2i64, nxv2i1, zext_masked_load_i16,  LD1H_D_IMM>;
+  defm : pred_load<nxv2i64, nxv2i1, asext_masked_load_i16, LD1SH_D_IMM>;
+  defm : pred_load<nxv2i64, nxv2i1, zext_masked_load_i32,  LD1W_D_IMM>;
+  defm : pred_load<nxv2i64, nxv2i1, asext_masked_load_i32, LD1SW_D_IMM>;
+  defm : pred_load<nxv2i64, nxv2i1, nonext_masked_load,    LD1D_IMM>;
+  defm : pred_load<nxv2f16, nxv2i1, nonext_masked_load,    LD1H_D_IMM>;
+  defm : pred_load<nxv2f32, nxv2i1, nonext_masked_load,    LD1W_D_IMM>;
+  defm : pred_load<nxv2f64, nxv2i1, nonext_masked_load,    LD1D_IMM>;
+
+  // 4-element contiguous loads
+  defm : pred_load<nxv4i32, nxv4i1, zext_masked_load_i8,   LD1B_S_IMM>;
+  defm : pred_load<nxv4i32, nxv4i1, asext_masked_load_i8,  LD1SB_S_IMM>;
+  defm : pred_load<nxv4i32, nxv4i1, zext_masked_load_i16,  LD1H_S_IMM>;
+  defm : pred_load<nxv4i32, nxv4i1, asext_masked_load_i16, LD1SH_S_IMM>;
+  defm : pred_load<nxv4i32, nxv4i1, nonext_masked_load,    LD1W_IMM>;
+  defm : pred_load<nxv4f16, nxv4i1, nonext_masked_load,    LD1H_S_IMM>;
+  defm : pred_load<nxv4f32, nxv4i1, nonext_masked_load,    LD1W_IMM>;
+
+  // 8-element contiguous loads
+  defm : pred_load<nxv8i16, nxv8i1, zext_masked_load_i8,  LD1B_H_IMM>;
+  defm : pred_load<nxv8i16, nxv8i1, asext_masked_load_i8, LD1SB_H_IMM>;
+  defm : pred_load<nxv8i16, nxv8i1, nonext_masked_load,   LD1H_IMM>;
+  defm : pred_load<nxv8f16, nxv8i1, nonext_masked_load,   LD1H_IMM>;
+
+  // 16-element contiguous loads
+  defm : pred_load<nxv16i8, nxv16i1, nonext_masked_load, LD1B_IMM>;
 }
 
 let Predicates = [HasSVE2] in {
Index: llvm/lib/Target/AArch64/AArch64InstrInfo.td
===================================================================
--- llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -256,6 +256,55 @@
                                          SDTCisSameAs<1, 2>, SDTCisSameAs<1, 3>,
                                          SDTCisSameAs<1, 4>]>;
 
+// Non-extending masked load fragment.
+def nonext_masked_load :
+  PatFrag<(ops node:$ptr, node:$pred, node:$def),
+          (masked_ld node:$ptr, node:$pred, node:$def), [{
+  return cast<MaskedLoadSDNode>(N)->getExtensionType() == ISD::NON_EXTLOAD;
+}]>;
+// Sign-extending masked load fragments.
+def asext_masked_load :
+  PatFrag<(ops node:$ptr, node:$pred, node:$def),
+          (masked_ld node:$ptr, node:$pred, node:$def), [{
+  return cast<MaskedLoadSDNode>(N)->getExtensionType() == ISD::EXTLOAD ||
+         cast<MaskedLoadSDNode>(N)->getExtensionType() == ISD::SEXTLOAD;
+}]>;
+def asext_masked_load_i8 :
+  PatFrag<(ops node:$ptr, node:$pred, node:$def),
+          (asext_masked_load node:$ptr, node:$pred, node:$def), [{
+  return cast<MaskedLoadSDNode>(N)->getMemoryVT().getScalarType() == MVT::i8;
+}]>;
+def asext_masked_load_i16 :
+  PatFrag<(ops node:$ptr, node:$pred, node:$def),
+          (asext_masked_load node:$ptr, node:$pred, node:$def), [{
+  return cast<MaskedLoadSDNode>(N)->getMemoryVT().getScalarType() == MVT::i16;
+}]>;
+def asext_masked_load_i32 :
+  PatFrag<(ops node:$ptr, node:$pred, node:$def),
+          (asext_masked_load node:$ptr, node:$pred, node:$def), [{
+  return cast<MaskedLoadSDNode>(N)->getMemoryVT().getScalarType() == MVT::i32;
+}]>;
+// Zero-extending masked load fragments.
+def zext_masked_load :
+  PatFrag<(ops node:$ptr, node:$pred, node:$def),
+          (masked_ld node:$ptr, node:$pred, node:$def), [{
+  return cast<MaskedLoadSDNode>(N)->getExtensionType() == ISD::ZEXTLOAD;
+}]>;
+def zext_masked_load_i8 :
+  PatFrag<(ops node:$ptr, node:$pred, node:$def),
+          (zext_masked_load node:$ptr, node:$pred, node:$def), [{
+  return cast<MaskedLoadSDNode>(N)->getMemoryVT().getScalarType() == MVT::i8;
+}]>;
+def zext_masked_load_i16 :
+  PatFrag<(ops node:$ptr, node:$pred, node:$def),
+          (zext_masked_load node:$ptr, node:$pred, node:$def), [{
+  return cast<MaskedLoadSDNode>(N)->getMemoryVT().getScalarType() == MVT::i16;
+}]>;
+def zext_masked_load_i32 :
+  PatFrag<(ops node:$ptr, node:$pred, node:$def),
+          (zext_masked_load node:$ptr, node:$pred, node:$def), [{
+  return cast<MaskedLoadSDNode>(N)->getMemoryVT().getScalarType() == MVT::i32;
+}]>;
 
 // Node definitions.
 def AArch64adrp          : SDNode<"AArch64ISD::ADRP", SDTIntUnaryOp, []>;
Index: llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
===================================================================
--- llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
+++ llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
@@ -140,6 +140,11 @@
     return SelectAddrModeXRO(N, Width / 8, Base, Offset, SignExtend, DoShift);
   }
 
+  bool SelectUndef(SDValue N) {
+    // Used by the SVEUndef ComplexPattern to match the undef passthru
+    // operand of a masked load.
+    return N->isUndef();
+  }
 
   /// Form sequences of consecutive 64/128-bit registers for use in NEON
   /// instructions making use of a vector-list (e.g. ldN, tbl). Vecs must have
@@ -204,7 +209,7 @@
   bool SelectAddrModeXRO(SDValue N, unsigned Size, SDValue &Base,
                          SDValue &Offset, SDValue &SignExtend,
                          SDValue &DoShift);
-  bool isWorthFolding(SDValue V) const;
+  bool isWorthFolding(SDValue V, unsigned MaxUses = 1) const;
   bool SelectExtendedSHL(SDValue N, unsigned Size, bool WantExtend,
                          SDValue &Offset, SDValue &SignExtend);
 
@@ -375,7 +380,7 @@
 }
 
 /// Determine whether it is worth to fold V into an extended register.
-bool AArch64DAGToDAGISel::isWorthFolding(SDValue V) const {
+bool AArch64DAGToDAGISel::isWorthFolding(SDValue V, unsigned MaxUses) const {
   // Trivial if we are optimizing for code size or if there is only
   // one use of the value.
   if (ForCodeSize || V.hasOneUse())
@@ -394,6 +399,18 @@
       return true;
   }
 
+  // If it has more than one use, check that they're all loads/stores
+  // from/to the same memory type (e.g. if you can fold for one
+  // addressing mode, you can fold for the others as well).
+  EVT VT;
+  for (auto *Use : V.getNode()->uses())
+    if (auto *MemNode = dyn_cast<MemSDNode>(Use)) {
+      if (VT == EVT())
+        VT = MemNode->getMemoryVT();
+      else if (MemNode->getMemoryVT() != VT)
+        return false;
+    }
+
+  if (V.getNode()->use_size() <= MaxUses)
+    return true;
+
   // It hurts otherwise, since the value will be reused.
   return false;
 }
Index: llvm/lib/CodeGen/TargetLoweringBase.cpp
===================================================================
--- llvm/lib/CodeGen/TargetLoweringBase.cpp
+++ llvm/lib/CodeGen/TargetLoweringBase.cpp
@@ -1265,18 +1265,23 @@
     MVT EltVT = VT.getVectorElementType();
     unsigned NElts = VT.getVectorNumElements();
     bool IsLegalWiderType = false;
+    bool IsScalable = VT.isScalableVector();
     LegalizeTypeAction PreferredAction = getPreferredVectorAction(VT);
     switch (PreferredAction) {
-    case TypePromoteInteger:
+    case TypePromoteInteger: {
+      MVT::SimpleValueType EndVT = IsScalable ?
+                                   MVT::LAST_INTEGER_SCALABLE_VECTOR_VALUETYPE :
+                                   MVT::LAST_INTEGER_FIXEDLEN_VECTOR_VALUETYPE;
       // Try to promote the elements of integer vectors. If no legal
       // promotion was found, fall through to the widen-vector method.
       for (unsigned nVT = i + 1;
-           nVT <= MVT::LAST_INTEGER_FIXEDLEN_VECTOR_VALUETYPE; ++nVT) {
+           (MVT::SimpleValueType) nVT <= EndVT; ++nVT) {
         MVT SVT = (MVT::SimpleValueType) nVT;
         // Promote vectors of integers to vectors with the same number
         // of elements, with a wider element type.
         if (SVT.getScalarSizeInBits() > EltVT.getSizeInBits() &&
-            SVT.getVectorNumElements() == NElts && isTypeLegal(SVT)) {
+            SVT.getVectorNumElements() == NElts &&
+            SVT.isScalableVector() == IsScalable && isTypeLegal(SVT)) {
           TransformToType[i] = SVT;
           RegisterTypeForVT[i] = SVT;
           NumRegistersForVT[i] = 1;
@@ -1288,6 +1293,7 @@
       if (IsLegalWiderType)
         break;
       LLVM_FALLTHROUGH;
+    }
 
     case TypeWidenVector:
       if (isPowerOf2_32(NElts)) {
@@ -1295,6 +1301,7 @@
         for (unsigned nVT = i + 1; nVT <= MVT::LAST_VECTOR_VALUETYPE; ++nVT) {
           MVT SVT = (MVT::SimpleValueType) nVT;
           if (SVT.getVectorElementType() == EltVT
+              && SVT.isScalableVector() == IsScalable
               && SVT.getVectorNumElements() > NElts && isTypeLegal(SVT)) {
             TransformToType[i] = SVT;
             RegisterTypeForVT[i] = SVT;
Index: llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
===================================================================
--- llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -4447,12 +4447,15 @@
   const MDNode *Ranges = I.getMetadata(LLVMContext::MD_range);
 
   // Do not serialize masked loads of constant memory with anything.
-  bool AddToChain =
-      !AA || !AA->pointsToConstantMemory(MemoryLocation(
-                 PtrOperand,
-                 LocationSize::precise(
-                     DAG.getDataLayout().getTypeStoreSize(I.getType())),
-                 AAInfo));
+  // We cannot compute a precise store size for scalable vectors, so
+  // conservatively keep such loads on the chain.
+  bool AddToChain = true;
+  if (!VT.isScalableVector())
+    AddToChain =
+        !AA || !AA->pointsToConstantMemory(MemoryLocation(
+                   PtrOperand,
+                   LocationSize::precise(
+                       DAG.getDataLayout().getTypeStoreSize(I.getType())),
+                   AAInfo));
+
   SDValue InChain = AddToChain ? DAG.getRoot() : DAG.getEntryNode();
 
   MachineMemOperand *MMO =
Index: llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
===================================================================
--- llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -838,6 +838,34 @@
   return false;
 }
 
+static bool isConstantSplatVectorMaskForType(SDNode *N, EVT ScalarTy) {
+  if (!ScalarTy.isSimple())
+    return false;
+
+  uint64_t MaskForTy = 0ull;
+  switch (ScalarTy.getSimpleVT().SimpleTy) {
+  case MVT::i8:
+    MaskForTy = 0xffull;
+    break;
+  case MVT::i16:
+    MaskForTy = 0xffffull;
+    break;
+  case MVT::i32:
+    MaskForTy = 0xffffffffull;
+    break;
+  default:
+    return false;
+  }
+
+  APInt Val;
+  if (ISD::isConstantSplatVector(N, Val))
+    return Val.getLimitedValue() == MaskForTy;
+
+  return false;
+}
+
 // Returns the SDNode if it is a constant float BuildVector
 // or constant float.
 static SDNode *isConstantFPBuildVectorOrConstantFP(SDValue N) {
@@ -5221,6 +5249,24 @@
     }
   }
 
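+  // fold (and (masked_load x), (splat 0xff)) -> (zext_masked_load x), when
+  // the splat value is the mask for the loaded (narrower) scalar type.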
+  if (auto *LN0 = dyn_cast<MaskedLoadSDNode>(N0)) {
+    EVT MemVT = LN0->getMemoryVT();
+    EVT ScalarVT = MemVT.getScalarType();
+    if (SDValue(LN0, 0).hasOneUse() &&
+        isConstantSplatVectorMaskForType(N1.getNode(), ScalarVT) &&
+        TLI.isLoadExtLegal(ISD::ZEXTLOAD, VT, MemVT)) {
+      SDValue ZExtLoad = DAG.getMaskedLoad(VT, SDLoc(N), LN0->getChain(),
+                                           LN0->getBasePtr(), LN0->getMask(),
+                                           LN0->getPassThru(), MemVT,
+                                           LN0->getMemOperand(), ISD::ZEXTLOAD);
+      CombineTo(N, ZExtLoad);
+      CombineTo(N0.getNode(), ZExtLoad, ZExtLoad.getValue(1));
+      AddToWorklist(ZExtLoad.getNode());
+      // Avoid recheck of N.
+      return SDValue(N, 0);
+    }
+  }
+
   // fold (and (load x), 255) -> (zextload x, i8)
   // fold (and (extload x, i16), 255) -> (zextload x, i8)
   // fold (and (any_ext (extload x, i16)), 255) -> (zextload x, i8)
@@ -9043,6 +9089,9 @@
   if (!TLI.isLoadExtLegalOrCustom(ExtType, SplitDstVT, SplitSrcVT))
     return SDValue();
 
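+  // The split below is based on a fixed element count, which does not apply
+  // to scalable vectors.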
+  if (DstVT.isScalableVector())
+    return SDValue();
+
   SDLoc DL(N);
   const unsigned NumSplits =
       DstVT.getVectorNumElements() / SplitDstVT.getVectorNumElements();
@@ -10337,6 +10386,21 @@
     AddToWorklist(ExtLoad.getNode());
     return SDValue(N, 0);   // Return N so it doesn't get rechecked!
   }
+  // fold (sext_inreg (masked_load x)) -> (sext_masked_load x)
+  if (isa<MaskedLoadSDNode>(N0) &&
+      EVT == cast<MaskedLoadSDNode>(N0)->getMemoryVT() &&
+      ((!LegalOperations && !cast<MaskedLoadSDNode>(N0)->isVolatile()) ||
+       TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, EVT))) {
+    MaskedLoadSDNode *LN0 = cast<MaskedLoadSDNode>(N0);
+    SDValue ExtLoad = DAG.getMaskedLoad(VT, SDLoc(N), LN0->getChain(),
+                                        LN0->getBasePtr(), LN0->getMask(),
+                                        LN0->getPassThru(), LN0->getMemoryVT(),
+                                        LN0->getMemOperand(), ISD::SEXTLOAD);
+    CombineTo(N, ExtLoad);
+    CombineTo(N0.getNode(), ExtLoad, ExtLoad.getValue(1));
+    AddToWorklist(ExtLoad.getNode());
+    return SDValue(N, 0);   // Return N so it doesn't get rechecked!
+  }
   // fold (sext_inreg (zextload x)) -> (sextload x) iff load has one use
   if (ISD::isZEXTLoad(N0.getNode()) && ISD::isUNINDEXEDLoad(N0.getNode()) &&
       N0.hasOneUse() &&