kmclaughlin updated this revision to Diff 225900.
kmclaughlin edited the summary of this revision.
kmclaughlin added a comment.

- Rebased patch; removed the extra sext & zext combines from DAGCombine, which are no longer necessary (see the reduced example below)
- Added isVectorLoadExtDesirable to AArch64ISelLowering
- Added more checks to isLegalMaskedLoad
- Renamed //SVEUndef// to //SVEDup0Undef//, which matches either undef or an all-zeros splat
- Changed SelectionDAG::getConstant to return SPLAT_VECTOR instead of BUILD_VECTOR for scalable types
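
For reference, a reduced example of the masked extending loads this patch now selects, adapted from the new sve-masked-ldst-zext.ll test (the function name here is just illustrative). With -mattr=+sve, the zero-extending masked load of nxv2i8 should select a single unsigned load such as ld1b { z0.d }, p0/z, [x0], rather than an ld1sb, which the test guards against with CHECK-NOT:

  define <vscale x 2 x i64> @zload_example(<vscale x 2 x i8>* %src, <vscale x 2 x i1> %mask) {
    ; Expected codegen per the test's CHECK lines: ld1b { z0.d }, p0/z, [x0]
    ; (register numbers may vary)
    %load = call <vscale x 2 x i8> @llvm.masked.load.nxv2i8(<vscale x 2 x i8>* %src, i32 1, <vscale x 2 x i1> %mask, <vscale x 2 x i8> undef)
    %ext = zext <vscale x 2 x i8> %load to <vscale x 2 x i64>
    ret <vscale x 2 x i64> %ext
  }

  declare <vscale x 2 x i8> @llvm.masked.load.nxv2i8(<vscale x 2 x i8>*, i32, <vscale x 2 x i1>, <vscale x 2 x i8>)

This can be run with llc -mtriple=aarch64--linux-gnu -mattr=+sve, as in the RUN lines of the added tests.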


CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D68877/new/

https://reviews.llvm.org/D68877

Files:
  llvm/include/llvm/CodeGen/SelectionDAG.h
  llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
  llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
  llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
  llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
  llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
  llvm/lib/Target/AArch64/AArch64ISelLowering.h
  llvm/lib/Target/AArch64/AArch64InstrInfo.td
  llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
  llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
  llvm/lib/Target/AArch64/SVEInstrFormats.td
  llvm/test/CodeGen/AArch64/sve-masked-ldst-nonext.ll
  llvm/test/CodeGen/AArch64/sve-masked-ldst-sext.ll
  llvm/test/CodeGen/AArch64/sve-masked-ldst-zext.ll

Index: llvm/test/CodeGen/AArch64/sve-masked-ldst-zext.ll
===================================================================
--- /dev/null
+++ llvm/test/CodeGen/AArch64/sve-masked-ldst-zext.ll
@@ -0,0 +1,72 @@
+; RUN: llc -mtriple=aarch64--linux-gnu -mattr=+sve < %s | FileCheck %s
+
+;
+; Masked Loads
+;
+
+define <vscale x 2 x i64> @masked_zload_nxv2i8(<vscale x 2 x i8>* %src, <vscale x 2 x i1> %mask) {
+; CHECK-LABEL: masked_zload_nxv2i8:
+; CHECK-NOT: ld1sb
+; CHECK: ld1b { [[IN:z[0-9]+]].d }, [[PG:p[0-9]+]]/z, [x0]
+; CHECK-NEXT: ret
+  %load = call <vscale x 2 x i8> @llvm.masked.load.nxv2i8(<vscale x 2 x i8>* %src, i32 1, <vscale x 2 x i1> %mask, <vscale x 2 x i8> undef)
+  %ext = zext <vscale x 2 x i8> %load to <vscale x 2 x i64>
+  ret <vscale x 2 x i64> %ext
+}
+
+define <vscale x 2 x i64> @masked_zload_nxv2i16(<vscale x 2 x i16>* %src, <vscale x 2 x i1> %mask) {
+; CHECK-LABEL: masked_zload_nxv2i16:
+; CHECK-NOT: ld1sh
+; CHECK: ld1h { [[IN:z[0-9]+]].d }, [[PG:p[0-9]+]]/z, [x0]
+; CHECK-NEXT: ret
+  %load = call <vscale x 2 x i16> @llvm.masked.load.nxv2i16(<vscale x 2 x i16>* %src, i32 1, <vscale x 2 x i1> %mask, <vscale x 2 x i16> undef)
+  %ext = zext <vscale x 2 x i16> %load to <vscale x 2 x i64>
+  ret <vscale x 2 x i64> %ext
+}
+
+define <vscale x 2 x i64> @masked_zload_nxv2i32(<vscale x 2 x i32>* %src, <vscale x 2 x i1> %mask) {
+; CHECK-LABEL: masked_zload_nxv2i32:
+; CHECK-NOT: ld1sw
+; CHECK: ld1w { [[IN:z[0-9]+]].d }, [[PG:p[0-9]+]]/z, [x0]
+; CHECK-NEXT: ret
+  %load = call <vscale x 2 x i32> @llvm.masked.load.nxv2i32(<vscale x 2 x i32>* %src, i32 1, <vscale x 2 x i1> %mask, <vscale x 2 x i32> undef)
+  %ext = zext <vscale x 2 x i32> %load to <vscale x 2 x i64>
+  ret <vscale x 2 x i64> %ext
+}
+
+define <vscale x 4 x i32> @masked_zload_nxv4i8(<vscale x 4 x i8>* %src, <vscale x 4 x i1> %mask) {
+; CHECK-LABEL: masked_zload_nxv4i8:
+; CHECK-NOT: ld1sb
+; CHECK: ld1b { [[IN:z[0-9]+]].s }, [[PG:p[0-9]+]]/z, [x0]
+; CHECK-NEXT: ret
+  %load = call <vscale x 4 x i8> @llvm.masked.load.nxv4i8(<vscale x 4 x i8>* %src, i32 1, <vscale x 4 x i1> %mask, <vscale x 4 x i8> undef)
+  %ext = zext <vscale x 4 x i8> %load to <vscale x 4 x i32>
+  ret <vscale x 4 x i32> %ext
+}
+
+define <vscale x 4 x i32> @masked_zload_nxv4i16(<vscale x 4 x i16>* %src, <vscale x 4 x i1> %mask) {
+; CHECK-LABEL: masked_zload_nxv4i16:
+; CHECK-NOT: ld1sh
+; CHECK: ld1h { [[IN:z[0-9]+]].s }, [[PG:p[0-9]+]]/z, [x0]
+; CHECK-NEXT: ret
+  %load = call <vscale x 4 x i16> @llvm.masked.load.nxv4i16(<vscale x 4 x i16>* %src, i32 1, <vscale x 4 x i1> %mask, <vscale x 4 x i16> undef)
+  %ext = zext <vscale x 4 x i16> %load to <vscale x 4 x i32>
+  ret <vscale x 4 x i32> %ext
+}
+
+define <vscale x 8 x i16> @masked_zload_nxv8i8(<vscale x 8 x i8>* %src, <vscale x 8 x i1> %mask) {
+; CHECK-LABEL: masked_zload_nxv8i8:
+; CHECK-NOT: ld1sb
+; CHECK: ld1b { [[IN:z[0-9]+]].h }, [[PG:p[0-9]+]]/z, [x0]
+; CHECK-NEXT: ret
+  %load = call <vscale x 8 x i8> @llvm.masked.load.nxv8i8(<vscale x 8 x i8>* %src, i32 1, <vscale x 8 x i1> %mask, <vscale x 8 x i8> undef)
+  %ext = zext <vscale x 8 x i8> %load to <vscale x 8 x i16>
+  ret <vscale x 8 x i16> %ext
+}
+
+declare <vscale x 2 x i8> @llvm.masked.load.nxv2i8(<vscale x 2 x i8>*, i32, <vscale x 2 x i1>, <vscale x 2 x i8>)
+declare <vscale x 2 x i16> @llvm.masked.load.nxv2i16(<vscale x 2 x i16>*, i32, <vscale x 2 x i1>, <vscale x 2 x i16>)
+declare <vscale x 2 x i32> @llvm.masked.load.nxv2i32(<vscale x 2 x i32>*, i32, <vscale x 2 x i1>, <vscale x 2 x i32>)
+declare <vscale x 4 x i8> @llvm.masked.load.nxv4i8(<vscale x 4 x i8>*, i32, <vscale x 4 x i1>, <vscale x 4 x i8>)
+declare <vscale x 4 x i16> @llvm.masked.load.nxv4i16(<vscale x 4 x i16>*, i32, <vscale x 4 x i1>, <vscale x 4 x i16>)
+declare <vscale x 8 x i8> @llvm.masked.load.nxv8i8(<vscale x 8 x i8>*, i32, <vscale x 8 x i1>, <vscale x 8 x i8>)
Index: llvm/test/CodeGen/AArch64/sve-masked-ldst-sext.ll
===================================================================
--- /dev/null
+++ llvm/test/CodeGen/AArch64/sve-masked-ldst-sext.ll
@@ -0,0 +1,66 @@
+; RUN: llc -mtriple=aarch64--linux-gnu -mattr=+sve < %s | FileCheck %s
+
+;
+; Masked Loads
+;
+
+define <vscale x 2 x i64> @masked_sload_nxv2i8(<vscale x 2 x i8> *%a, <vscale x 2 x i1> %mask) {
+; CHECK-LABEL: masked_sload_nxv2i8:
+; CHECK: ld1sb { [[IN:z[0-9]+]].d }, [[PG:p[0-9]+]]/z, [x0]
+; CHECK-NEXT: ret
+  %load = call <vscale x 2 x i8> @llvm.masked.load.nxv2i8(<vscale x 2 x i8> *%a, i32 1, <vscale x 2 x i1> %mask, <vscale x 2 x i8> undef)
+  %ext = sext <vscale x 2 x i8> %load to <vscale x 2 x i64>
+  ret <vscale x 2 x i64> %ext
+}
+
+define <vscale x 2 x i64> @masked_sload_nxv2i16(<vscale x 2 x i16> *%a, <vscale x 2 x i1> %mask) {
+; CHECK-LABEL: masked_sload_nxv2i16:
+; CHECK: ld1sh { [[IN:z[0-9]+]].d }, [[PG:p[0-9]+]]/z, [x0]
+; CHECK-NEXT: ret
+  %load = call <vscale x 2 x i16> @llvm.masked.load.nxv2i16(<vscale x 2 x i16> *%a, i32 1, <vscale x 2 x i1> %mask, <vscale x 2 x i16> undef)
+  %ext = sext <vscale x 2 x i16> %load to <vscale x 2 x i64>
+  ret <vscale x 2 x i64> %ext
+}
+
+define <vscale x 2 x i64> @masked_sload_nxv2i32(<vscale x 2 x i32> *%a, <vscale x 2 x i1> %mask) {
+; CHECK-LABEL: masked_sload_nxv2i32:
+; CHECK: ld1sw { [[IN:z[0-9]+]].d }, [[PG:p[0-9]+]]/z, [x0]
+; CHECK-NEXT: ret
+  %load = call <vscale x 2 x i32> @llvm.masked.load.nxv2i32(<vscale x 2 x i32> *%a, i32 1, <vscale x 2 x i1> %mask, <vscale x 2 x i32> undef)
+  %ext = sext <vscale x 2 x i32> %load to <vscale x 2 x i64>
+  ret <vscale x 2 x i64> %ext
+}
+
+define <vscale x 4 x i32> @masked_sload_nxv4i8(<vscale x 4 x i8> *%a, <vscale x 4 x i1> %mask) {
+; CHECK-LABEL: masked_sload_nxv4i8:
+; CHECK: ld1sb { [[IN:z[0-9]+]].s }, [[PG:p[0-9]+]]/z, [x0]
+; CHECK-NEXT: ret
+  %load = call <vscale x 4 x i8> @llvm.masked.load.nxv4i8(<vscale x 4 x i8> *%a, i32 1, <vscale x 4 x i1> %mask, <vscale x 4 x i8> undef)
+  %ext = sext <vscale x 4 x i8> %load to <vscale x 4 x i32>
+  ret <vscale x 4 x i32> %ext
+}
+
+define <vscale x 4 x i32> @masked_sload_nxv4i16(<vscale x 4 x i16> *%a, <vscale x 4 x i1> %mask) {
+; CHECK-LABEL: masked_sload_nxv4i16:
+; CHECK: ld1sh { [[IN:z[0-9]+]].s }, [[PG:p[0-9]+]]/z, [x0]
+; CHECK-NEXT: ret
+  %load = call <vscale x 4 x i16> @llvm.masked.load.nxv4i16(<vscale x 4 x i16> *%a, i32 1, <vscale x 4 x i1> %mask, <vscale x 4 x i16> undef)
+  %ext = sext <vscale x 4 x i16> %load to <vscale x 4 x i32>
+  ret <vscale x 4 x i32> %ext
+}
+
+define <vscale x 8 x i16> @masked_sload_nxv8i8(<vscale x 8 x i8> *%a, <vscale x 8 x i1> %mask) {
+; CHECK-LABEL: masked_sload_nxv8i8:
+; CHECK: ld1sb { [[IN:z[0-9]+]].h }, [[PG:p[0-9]+]]/z, [x0]
+; CHECK-NEXT: ret
+  %load = call <vscale x 8 x i8> @llvm.masked.load.nxv8i8(<vscale x 8 x i8> *%a, i32 1, <vscale x 8 x i1> %mask, <vscale x 8 x i8> undef)
+  %ext = sext <vscale x 8 x i8> %load to <vscale x 8 x i16>
+  ret <vscale x 8 x i16> %ext
+}
+
+declare <vscale x 2 x i8> @llvm.masked.load.nxv2i8(<vscale x 2 x i8>*, i32, <vscale x 2 x i1>, <vscale x 2 x i8>)
+declare <vscale x 2 x i16> @llvm.masked.load.nxv2i16(<vscale x 2 x i16>*, i32, <vscale x 2 x i1>, <vscale x 2 x i16>)
+declare <vscale x 2 x i32> @llvm.masked.load.nxv2i32(<vscale x 2 x i32>*, i32, <vscale x 2 x i1>, <vscale x 2 x i32>)
+declare <vscale x 4 x i8> @llvm.masked.load.nxv4i8(<vscale x 4 x i8>*, i32, <vscale x 4 x i1>, <vscale x 4 x i8>)
+declare <vscale x 4 x i16> @llvm.masked.load.nxv4i16(<vscale x 4 x i16>*, i32, <vscale x 4 x i1>, <vscale x 4 x i16>)
+declare <vscale x 8 x i8> @llvm.masked.load.nxv8i8(<vscale x 8 x i8>*, i32, <vscale x 8 x i1>, <vscale x 8 x i8>)
Index: llvm/test/CodeGen/AArch64/sve-masked-ldst-nonext.ll
===================================================================
--- /dev/null
+++ llvm/test/CodeGen/AArch64/sve-masked-ldst-nonext.ll
@@ -0,0 +1,87 @@
+; RUN: llc -mtriple=aarch64--linux-gnu -mattr=+sve < %s | FileCheck %s
+
+;
+; Masked Loads
+;
+
+define <vscale x 2 x i64> @masked_load_nxv2i64(<vscale x 2 x i64> *%a, <vscale x 2 x i1> %mask) {
+; CHECK-LABEL: masked_load_nxv2i64:
+; CHECK: ld1d { [[IN:z[0-9]+]].d }, [[PG:p[0-9]+]]/z, [x0]
+  %load = call <vscale x 2 x i64> @llvm.masked.load.nxv2i64(<vscale x 2 x i64> *%a, i32 8, <vscale x 2 x i1> %mask, <vscale x 2 x i64> undef)
+  ret <vscale x 2 x i64> %load
+}
+
+define <vscale x 4 x i32> @masked_load_nxv4i32(<vscale x 4 x i32> *%a, <vscale x 4 x i1> %mask) {
+; CHECK-LABEL: masked_load_nxv4i32:
+; CHECK: ld1w { [[IN:z[0-9]+]].s }, [[PG:p[0-9]+]]/z, [x0]
+  %load = call <vscale x 4 x i32> @llvm.masked.load.nxv4i32(<vscale x 4 x i32> *%a, i32 4, <vscale x 4 x i1> %mask, <vscale x 4 x i32> undef)
+  ret <vscale x 4 x i32> %load
+}
+
+define <vscale x 8 x i16> @masked_load_nxv8i16(<vscale x 8 x i16> *%a, <vscale x 8 x i1> %mask) {
+; CHECK-LABEL: masked_load_nxv8i16:
+; CHECK: ld1h { [[IN:z[0-9]+]].h }, [[PG:p[0-9]+]]/z, [x0]
+  %load = call <vscale x 8 x i16> @llvm.masked.load.nxv8i16(<vscale x 8 x i16> *%a, i32 2, <vscale x 8 x i1> %mask, <vscale x 8 x i16> undef)
+  ret <vscale x 8 x i16> %load
+}
+
+define <vscale x 16 x i8> @masked_load_nxv16i8(<vscale x 16 x i8> *%a, <vscale x 16 x i1> %mask) {
+; CHECK-LABEL: masked_load_nxv16i8:
+; CHECK: ld1b { [[IN:z[0-9]+]].b }, [[PG:p[0-9]+]]/z, [x0]
+  %load = call <vscale x 16 x i8> @llvm.masked.load.nxv16i8(<vscale x 16 x i8> *%a, i32 1, <vscale x 16 x i1> %mask, <vscale x 16 x i8> undef)
+  ret <vscale x 16 x i8> %load
+}
+
+define <vscale x 2 x double> @masked_load_nxv2f64(<vscale x 2 x double> *%a, <vscale x 2 x i1> %mask) {
+; CHECK-LABEL: masked_load_nxv2f64:
+; CHECK: ld1d { [[IN:z[0-9]+]].d }, [[PG:p[0-9]+]]/z, [x0]
+  %load = call <vscale x 2 x double> @llvm.masked.load.nxv2f64(<vscale x 2 x double> *%a, i32 8, <vscale x 2 x i1> %mask, <vscale x 2 x double> undef)
+  ret <vscale x 2 x double> %load
+}
+
+define <vscale x 2 x float> @masked_load_nxv2f32(<vscale x 2 x float> *%a, <vscale x 2 x i1> %mask) {
+; CHECK-LABEL: masked_load_nxv2f32:
+; CHECK: ld1w { [[IN:z[0-9]+]].d }, [[PG:p[0-9]+]]/z, [x0]
+  %load = call <vscale x 2 x float> @llvm.masked.load.nxv2f32(<vscale x 2 x float> *%a, i32 4, <vscale x 2 x i1> %mask, <vscale x 2 x float> undef)
+  ret <vscale x 2 x float> %load
+}
+
+define <vscale x 2 x half> @masked_load_nxv2f16(<vscale x 2 x half> *%a, <vscale x 2 x i1> %mask) {
+; CHECK-LABEL: masked_load_nxv2f16:
+; CHECK: ld1h { [[IN:z[0-9]+]].d }, [[PG:p[0-9]+]]/z, [x0]
+  %load = call <vscale x 2 x half> @llvm.masked.load.nxv2f16(<vscale x 2 x half> *%a, i32 2, <vscale x 2 x i1> %mask, <vscale x 2 x half> undef)
+  ret <vscale x 2 x half> %load
+}
+
+define <vscale x 4 x float> @masked_load_nxv4f32(<vscale x 4 x float> *%a, <vscale x 4 x i1> %mask) {
+; CHECK-LABEL: masked_load_nxv4f32:
+; CHECK: ld1w { [[IN:z[0-9]+]].s }, [[PG:p[0-9]+]]/z, [x0]
+  %load = call <vscale x 4 x float> @llvm.masked.load.nxv4f32(<vscale x 4 x float> *%a, i32 4, <vscale x 4 x i1> %mask, <vscale x 4 x float> undef)
+  ret <vscale x 4 x float> %load
+}
+
+define <vscale x 4 x half> @masked_load_nxv4f16(<vscale x 4 x half> *%a, <vscale x 4 x i1> %mask) {
+; CHECK-LABEL: masked_load_nxv4f16:
+; CHECK: ld1h { [[IN:z[0-9]+]].s }, [[PG:p[0-9]+]]/z, [x0]
+  %load = call <vscale x 4 x half> @llvm.masked.load.nxv4f16(<vscale x 4 x half> *%a, i32 2, <vscale x 4 x i1> %mask, <vscale x 4 x half> undef)
+  ret <vscale x 4 x half> %load
+}
+
+define <vscale x 8 x half> @masked_load_nxv8f16(<vscale x 8 x half> *%a, <vscale x 8 x i1> %mask) {
+; CHECK-LABEL: masked_load_nxv8f16:
+; CHECK: ld1h { [[IN:z[0-9]+]].h }, [[PG:p[0-9]+]]/z, [x0]
+  %load = call <vscale x 8 x half> @llvm.masked.load.nxv8f16(<vscale x 8 x half> *%a, i32 2, <vscale x 8 x i1> %mask, <vscale x 8 x half> undef)
+  ret <vscale x 8 x half> %load
+}
+
+declare <vscale x 2 x i64> @llvm.masked.load.nxv2i64(<vscale x 2 x i64>*, i32, <vscale x 2 x i1>, <vscale x 2 x i64>)
+declare <vscale x 4 x i32> @llvm.masked.load.nxv4i32(<vscale x 4 x i32>*, i32, <vscale x 4 x i1>, <vscale x 4 x i32>)
+declare <vscale x 8 x i16> @llvm.masked.load.nxv8i16(<vscale x 8 x i16>*, i32, <vscale x 8 x i1>, <vscale x 8 x i16>)
+declare <vscale x 16 x i8> @llvm.masked.load.nxv16i8(<vscale x 16 x i8>*, i32, <vscale x 16 x i1>, <vscale x 16 x i8>)
+
+declare <vscale x 2 x double> @llvm.masked.load.nxv2f64(<vscale x 2 x double>*, i32, <vscale x 2 x i1>, <vscale x 2 x double>)
+declare <vscale x 2 x float> @llvm.masked.load.nxv2f32(<vscale x 2 x float>*, i32, <vscale x 2 x i1>, <vscale x 2 x float>)
+declare <vscale x 2 x half> @llvm.masked.load.nxv2f16(<vscale x 2 x half>*, i32, <vscale x 2 x i1>, <vscale x 2 x half>)
+declare <vscale x 4 x float> @llvm.masked.load.nxv4f32(<vscale x 4 x float>*, i32, <vscale x 4 x i1>, <vscale x 4 x float>)
+declare <vscale x 4 x half> @llvm.masked.load.nxv4f16(<vscale x 4 x half>*, i32, <vscale x 4 x i1>, <vscale x 4 x half>)
+declare <vscale x 8 x half> @llvm.masked.load.nxv8f16(<vscale x 8 x half>*, i32, <vscale x 8 x i1>, <vscale x 8 x half>)
Index: llvm/lib/Target/AArch64/SVEInstrFormats.td
===================================================================
--- llvm/lib/Target/AArch64/SVEInstrFormats.td
+++ llvm/lib/Target/AArch64/SVEInstrFormats.td
@@ -293,6 +293,8 @@
 : Pat<(vtd (op vt1:$Op1, vt2:$Op2, vt3:$Op3)),
       (inst $Op1, $Op2, $Op3)>;
 
+def SVEDup0Undef : ComplexPattern<i64, 0, "SelectDupZeroOrUndef", []>;
+
 //===----------------------------------------------------------------------===//
 // SVE Predicate Misc Group
 //===----------------------------------------------------------------------===//
@@ -4744,6 +4746,13 @@
                   (!cast<Instruction>(NAME # _REAL) zprty:$Zt, PPR3bAny:$Pg, GPR64sp:$Rn, simm4s1:$imm4), 0>;
   def : InstAlias<asm # "\t$Zt, $Pg/z, [$Rn]",
                   (!cast<Instruction>(NAME # _REAL) listty:$Zt, PPR3bAny:$Pg, GPR64sp:$Rn, 0), 1>;
+
+  // We need a layer of indirection because early machine code passes balk at
+  // physical register (i.e. FFR) uses that have no previous definition.
+  let hasSideEffects = 1, hasNoSchedulingInfo = 1, mayLoad = 1 in {
+  def "" : Pseudo<(outs listty:$Zt), (ins PPR3bAny:$Pg, GPR64sp:$Rn, simm4s1:$imm4), []>,
+           PseudoInstExpansion<(!cast<Instruction>(NAME # _REAL) listty:$Zt, PPR3bAny:$Pg, GPR64sp:$Rn, simm4s1:$imm4)>;
+  }
 }
 
 multiclass sve_mem_cld_si<bits<4> dtype, string asm, RegisterOperand listty,
Index: llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
===================================================================
--- llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -147,6 +147,21 @@
 
   bool getTgtMemIntrinsic(IntrinsicInst *Inst, MemIntrinsicInfo &Info);
 
+  bool isLegalMaskedLoad(Type *DataType, MaybeAlign Alignment) {
+    if (!isa<VectorType>(DataType) || !ST->hasSVE())
+      return false;
+
+    Type *Ty = DataType->getVectorElementType();
+    if (Ty->isHalfTy() || Ty->isFloatTy() || Ty->isDoubleTy())
+      return true;
+
+    if (Ty->isIntegerTy(8) || Ty->isIntegerTy(16) ||
+        Ty->isIntegerTy(32) || Ty->isIntegerTy(64))
+      return true;
+
+    return false;
+  }
+
   int getInterleavedMemoryOpCost(unsigned Opcode, Type *VecTy, unsigned Factor,
                                  ArrayRef<unsigned> Indices, unsigned Alignment,
                                  unsigned AddressSpace,
Index: llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
===================================================================
--- llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -1070,6 +1070,44 @@
   def : Pat<(nxv2f64 (bitconvert (nxv8f16 ZPR:$src))), (nxv2f64 ZPR:$src)>;
   def : Pat<(nxv2f64 (bitconvert (nxv4f32 ZPR:$src))), (nxv2f64 ZPR:$src)>;
 
+  // Add more complex addressing modes here as required
+  multiclass pred_load<ValueType Ty, ValueType PredTy, SDPatternOperator Load,
+                        Instruction RegImmInst> {
+
+    def _default_z : Pat<(Ty (Load  GPR64:$base, (PredTy PPR:$gp), (SVEDup0Undef))),
+                         (RegImmInst PPR:$gp, GPR64:$base, (i64 0))>;
+  }
+
+  // 2-element contiguous loads
+  defm : pred_load<nxv2i64, nxv2i1, zext_masked_load_i8,   LD1B_D_IMM>;
+  defm : pred_load<nxv2i64, nxv2i1, asext_masked_load_i8,  LD1SB_D_IMM>;
+  defm : pred_load<nxv2i64, nxv2i1, zext_masked_load_i16,  LD1H_D_IMM>;
+  defm : pred_load<nxv2i64, nxv2i1, asext_masked_load_i16, LD1SH_D_IMM>;
+  defm : pred_load<nxv2i64, nxv2i1, zext_masked_load_i32,  LD1W_D_IMM>;
+  defm : pred_load<nxv2i64, nxv2i1, asext_masked_load_i32, LD1SW_D_IMM>;
+  defm : pred_load<nxv2i64, nxv2i1, nonext_masked_load,    LD1D_IMM>;
+  defm : pred_load<nxv2f16, nxv2i1, nonext_masked_load,    LD1H_D_IMM>;
+  defm : pred_load<nxv2f32, nxv2i1, nonext_masked_load,    LD1W_D_IMM>;
+  defm : pred_load<nxv2f64, nxv2i1, nonext_masked_load,    LD1D_IMM>;
+
+  // 4-element contiguous loads
+  defm : pred_load<nxv4i32, nxv4i1, zext_masked_load_i8,   LD1B_S_IMM>;
+  defm : pred_load<nxv4i32, nxv4i1, asext_masked_load_i8,  LD1SB_S_IMM>;
+  defm : pred_load<nxv4i32, nxv4i1, zext_masked_load_i16,  LD1H_S_IMM>;
+  defm : pred_load<nxv4i32, nxv4i1, asext_masked_load_i16, LD1SH_S_IMM>;
+  defm : pred_load<nxv4i32, nxv4i1, nonext_masked_load,    LD1W_IMM>;
+  defm : pred_load<nxv4f16, nxv4i1, nonext_masked_load,    LD1H_S_IMM>;
+  defm : pred_load<nxv4f32, nxv4i1, nonext_masked_load,    LD1W_IMM>;
+
+  // 8-element contiguous loads
+  defm : pred_load<nxv8i16, nxv8i1, zext_masked_load_i8,  LD1B_H_IMM>;
+  defm : pred_load<nxv8i16, nxv8i1, asext_masked_load_i8, LD1SB_H_IMM>;
+  defm : pred_load<nxv8i16, nxv8i1, nonext_masked_load,   LD1H_IMM>;
+  defm : pred_load<nxv8f16, nxv8i1, nonext_masked_load,   LD1H_IMM>;
+
+  // 16-element contiguous loads
+  defm : pred_load<nxv16i8, nxv16i1, nonext_masked_load, LD1B_IMM>;
+
 }
 
 let Predicates = [HasSVE2] in {
Index: llvm/lib/Target/AArch64/AArch64InstrInfo.td
===================================================================
--- llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -259,6 +259,55 @@
                                          SDTCisSameAs<1, 2>, SDTCisSameAs<1, 3>,
                                          SDTCisSameAs<1, 4>]>;
 
+// non-extending masked load fragment.
+def nonext_masked_load :
+  PatFrag<(ops node:$ptr, node:$pred, node:$def),
+          (masked_ld node:$ptr, node:$pred, node:$def), [{
+  return cast<MaskedLoadSDNode>(N)->getExtensionType() == ISD::NON_EXTLOAD;
+}]>;
+// sign extending masked load fragments.
+def asext_masked_load :
+  PatFrag<(ops node:$ptr, node:$pred, node:$def),
+          (masked_ld node:$ptr, node:$pred, node:$def),[{
+  return cast<MaskedLoadSDNode>(N)->getExtensionType() == ISD::EXTLOAD ||
+         cast<MaskedLoadSDNode>(N)->getExtensionType() == ISD::SEXTLOAD;
+}]>;
+def asext_masked_load_i8 :
+  PatFrag<(ops node:$ptr, node:$pred, node:$def),
+          (asext_masked_load node:$ptr, node:$pred, node:$def), [{
+  return cast<MaskedLoadSDNode>(N)->getMemoryVT().getScalarType() == MVT::i8;
+}]>;
+def asext_masked_load_i16 :
+  PatFrag<(ops node:$ptr, node:$pred, node:$def),
+          (asext_masked_load node:$ptr, node:$pred, node:$def), [{
+  return cast<MaskedLoadSDNode>(N)->getMemoryVT().getScalarType() == MVT::i16;
+}]>;
+def asext_masked_load_i32 :
+  PatFrag<(ops node:$ptr, node:$pred, node:$def),
+          (asext_masked_load node:$ptr, node:$pred, node:$def), [{
+  return cast<MaskedLoadSDNode>(N)->getMemoryVT().getScalarType() == MVT::i32;
+}]>;
+// zero extending masked load fragments.
+def zext_masked_load :
+  PatFrag<(ops node:$ptr, node:$pred, node:$def),
+          (masked_ld node:$ptr, node:$pred, node:$def), [{
+  return cast<MaskedLoadSDNode>(N)->getExtensionType() == ISD::ZEXTLOAD;
+}]>;
+def zext_masked_load_i8 :
+  PatFrag<(ops node:$ptr, node:$pred, node:$def),
+          (zext_masked_load node:$ptr, node:$pred, node:$def), [{
+  return cast<MaskedLoadSDNode>(N)->getMemoryVT().getScalarType() == MVT::i8;
+}]>;
+def zext_masked_load_i16 :
+  PatFrag<(ops node:$ptr, node:$pred, node:$def),
+          (zext_masked_load node:$ptr, node:$pred, node:$def), [{
+  return cast<MaskedLoadSDNode>(N)->getMemoryVT().getScalarType() == MVT::i16;
+}]>;
+def zext_masked_load_i32 :
+  PatFrag<(ops node:$ptr, node:$pred, node:$def),
+          (zext_masked_load node:$ptr, node:$pred, node:$def), [{
+  return cast<MaskedLoadSDNode>(N)->getMemoryVT().getScalarType() == MVT::i32;
+}]>;
 
 // Node definitions.
 def AArch64adrp          : SDNode<"AArch64ISD::ADRP", SDTIntUnaryOp, []>;
Index: llvm/lib/Target/AArch64/AArch64ISelLowering.h
===================================================================
--- llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -741,6 +741,7 @@
     return TargetLowering::getInlineAsmMemConstraint(ConstraintCode);
   }
 
+  bool isVectorLoadExtDesirable(SDValue ExtVal) const override;
   bool isUsedByReturnOnly(SDNode *N, SDValue &Chain) const override;
   bool mayBeEmittedAsTailCall(const CallInst *CI) const override;
   bool getIndexedAddressParts(SDNode *Op, SDValue &Base, SDValue &Offset,
Index: llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
===================================================================
--- llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -2886,6 +2886,10 @@
   }
 }
 
+bool AArch64TargetLowering::isVectorLoadExtDesirable(SDValue ExtVal) const {
+  return ExtVal.getValueType().isScalableVector();
+}
+
 // Custom lower trunc store for v4i8 vectors, since it is promoted to v4i16.
 static SDValue LowerTruncateVectorStore(SDLoc DL, StoreSDNode *ST,
                                         EVT VT, EVT MemVT,
Index: llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
===================================================================
--- llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
+++ llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
@@ -140,6 +140,26 @@
     return SelectAddrModeXRO(N, Width / 8, Base, Offset, SignExtend, DoShift);
   }
 
+  bool SelectDupZeroOrUndef(SDValue N) {
+    switch (N->getOpcode()) {
+    case ISD::UNDEF:
+      return true;
+    case AArch64ISD::DUP:
+    case ISD::SPLAT_VECTOR: {
+      auto Opnd0 = N->getOperand(0);
+      if (auto CN = dyn_cast<ConstantSDNode>(Opnd0))
+        if (CN->isNullValue())
+          return true;
+      if (auto CN = dyn_cast<ConstantFPSDNode>(Opnd0))
+        if (CN->isZero())
+          return true;
+    }
+    default:
+      break;
+    }
+
+    return false;
+  }
 
   /// Form sequences of consecutive 64/128-bit registers for use in NEON
   /// instructions making use of a vector-list (e.g. ldN, tbl). Vecs must have
Index: llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
===================================================================
--- llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -4462,12 +4462,15 @@
   const MDNode *Ranges = I.getMetadata(LLVMContext::MD_range);
 
   // Do not serialize masked loads of constant memory with anything.
-  bool AddToChain =
-      !AA || !AA->pointsToConstantMemory(MemoryLocation(
-                 PtrOperand,
-                 LocationSize::precise(
-                     DAG.getDataLayout().getTypeStoreSize(I.getType())),
-                 AAInfo));
+  MemoryLocation ML;
+  if (VT.isScalableVector())
+    ML = MemoryLocation(PtrOperand);
+  else
+    ML = MemoryLocation(PtrOperand, LocationSize::precise(
+                           DAG.getDataLayout().getTypeStoreSize(I.getType())),
+                           AAInfo);
+  bool AddToChain = !AA || !AA->pointsToConstantMemory(ML);
+
   SDValue InChain = AddToChain ? DAG.getRoot() : DAG.getEntryNode();
 
   MachineMemOperand *MMO =
Index: llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
===================================================================
--- llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -1279,7 +1279,9 @@
   }
 
   SDValue Result(N, 0);
-  if (VT.isVector())
+  if (VT.isScalableVector())
+    Result = getSplatVector(VT, DL, Result);
+  else if (VT.isVector())
     Result = getSplatBuildVector(VT, DL, Result);
 
   return Result;
Index: llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
===================================================================
--- llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -9108,6 +9108,8 @@
   if (!TLI.isLoadExtLegalOrCustom(ExtType, SplitDstVT, SplitSrcVT))
     return SDValue();
 
+  assert(!DstVT.isScalableVector() && "Unexpected scalable vector type");
+
   SDLoc DL(N);
   const unsigned NumSplits =
       DstVT.getVectorNumElements() / SplitDstVT.getVectorNumElements();
Index: llvm/include/llvm/CodeGen/SelectionDAG.h
===================================================================
--- llvm/include/llvm/CodeGen/SelectionDAG.h
+++ llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -779,6 +779,20 @@
     return getNode(ISD::BUILD_VECTOR, DL, VT, Ops);
   }
 
+  // Return a splat ISD::SPLAT_VECTOR node, consisting of Op splatted to all
+  // elements.
+  SDValue getSplatVector(EVT VT, const SDLoc &DL, SDValue Op) {
+    if (Op.getOpcode() == ISD::UNDEF) {
+      assert((VT.getVectorElementType() == Op.getValueType() ||
+              (VT.isInteger() &&
+               VT.getVectorElementType().bitsLE(Op.getValueType()))) &&
+             "A splatted value must have a width equal or (for integers) "
+             "greater than the vector element type!");
+      return getNode(ISD::UNDEF, SDLoc(), VT);
+    }
+    return getNode(ISD::SPLAT_VECTOR, DL, VT, Op);
+  }
+
   /// Returns an ISD::VECTOR_SHUFFLE node semantically equivalent to
   /// the shuffle node in input but with swapped operands.
   ///