================
@@ -5387,6 +5387,192 @@ bool 
AMDGPULegalizerInfo::legalizeDSAtomicFPIntrinsic(LegalizerHelper &Helper,
   return true;
 }
 
+bool AMDGPULegalizerInfo::legalizeLaneOp(LegalizerHelper &Helper,
+                                         MachineInstr &MI,
+                                         Intrinsic::ID IID) const {
+
+  MachineIRBuilder &B = Helper.MIRBuilder;
+  MachineRegisterInfo &MRI = *B.getMRI();
+
+  Register DstReg = MI.getOperand(0).getReg();
+  Register Src0 = MI.getOperand(2).getReg();
+
+  auto createLaneOp = [&](Register Src0, Register Src1,
+                          Register Src2) -> Register {
+    auto LaneOp = B.buildIntrinsic(IID, {S32}).addUse(Src0);
+    switch (IID) {
+    case Intrinsic::amdgcn_readfirstlane:
+      return LaneOp.getReg(0);
+    case Intrinsic::amdgcn_readlane:
+      return LaneOp.addUse(Src1).getReg(0);
+    case Intrinsic::amdgcn_writelane:
+      return LaneOp.addUse(Src1).addUse(Src2).getReg(0);
+    default:
+      llvm_unreachable("unhandled lane op");
+    }
+  };
+
+  Register Src1, Src2;
+  if (IID == Intrinsic::amdgcn_readlane || IID == Intrinsic::amdgcn_writelane) 
{
+    Src1 = MI.getOperand(3).getReg();
+    if (IID == Intrinsic::amdgcn_writelane) {
+      Src2 = MI.getOperand(4).getReg();
+    }
+  }
+
+  LLT Ty = MRI.getType(DstReg);
+  unsigned Size = Ty.getSizeInBits();
+
+  if (Size == 32) {
+    // Already legal
+    return true;
+  }
+
+  if (Size < 32) {
+    Register Src0Cast = MRI.getType(Src0).isScalar()
+                            ? Src0
+                            : B.buildBitcast(LLT::scalar(Size), 
Src0).getReg(0);
+    Src0 = B.buildAnyExt(S32, Src0Cast).getReg(0);
+    if (Src2.isValid()) {
+      Register Src2Cast =
+          MRI.getType(Src2).isScalar()
+              ? Src2
+              : B.buildBitcast(LLT::scalar(Size), Src2).getReg(0);
+      Src2 = B.buildAnyExt(LLT::scalar(32), Src2Cast).getReg(0);
+    }
+
+    Register LaneOpDst = createLaneOp(Src0, Src1, Src2);
+    if (Ty.isScalar())
+      B.buildTrunc(DstReg, LaneOpDst);
+    else {
+      auto Trunc = B.buildTrunc(LLT::scalar(Size), LaneOpDst);
+      B.buildBitcast(DstReg, Trunc);
+    }
+
+    MI.eraseFromParent();
+    return true;
+  }
+
+  if ((Size % 32) == 0) {
+    SmallVector<Register, 2> PartialRes;
+    unsigned NumParts = Size / 32;
+    auto IsS16Vec = Ty.isVector() && Ty.getElementType() == S16;
+    MachineInstrBuilder Src0Parts;
+
+    if (Ty.isPointer()) {
+      auto PtrToInt = B.buildPtrToInt(LLT::scalar(Size), Src0);
+      Src0Parts = B.buildUnmerge(S32, PtrToInt);
+    } else if (Ty.isPointerVector()) {
+      LLT IntVecTy = Ty.changeElementType(
+          LLT::scalar(Ty.getElementType().getSizeInBits()));
+      auto PtrToInt = B.buildPtrToInt(IntVecTy, Src0);
+      Src0Parts = B.buildUnmerge(S32, PtrToInt);
+    } else
+      Src0Parts =
+          IsS16Vec ? B.buildUnmerge(V2S16, Src0) : B.buildUnmerge(S32, Src0);
+
+    switch (IID) {
+    case Intrinsic::amdgcn_readlane: {
+      Register Src1 = MI.getOperand(3).getReg();
+      for (unsigned i = 0; i < NumParts; ++i) {
+        Src0 = IsS16Vec ? B.buildBitcast(S32, Src0Parts.getReg(i)).getReg(0)
+                        : Src0Parts.getReg(i);
+        PartialRes.push_back(
+            (B.buildIntrinsic(Intrinsic::amdgcn_readlane, {S32})
+                 .addUse(Src0)
+                 .addUse(Src1))
+                .getReg(0));
+      }
+      break;
+    }
+    case Intrinsic::amdgcn_readfirstlane: {
+      for (unsigned i = 0; i < NumParts; ++i) {
+        Src0 = IsS16Vec ? B.buildBitcast(S32, Src0Parts.getReg(i)).getReg(0)
+                        : Src0Parts.getReg(i);
+        PartialRes.push_back(
+            (B.buildIntrinsic(Intrinsic::amdgcn_readfirstlane, {S32})
+                 .addUse(Src0)
+                 .getReg(0)));
+      }
+
+      break;
+    }
+    case Intrinsic::amdgcn_writelane: {
+      Register Src1 = MI.getOperand(3).getReg();
+      Register Src2 = MI.getOperand(4).getReg();
+      MachineInstrBuilder Src2Parts;
+
+      if (Ty.isPointer()) {
+        auto PtrToInt = B.buildPtrToInt(S64, Src2);
+        Src2Parts = B.buildUnmerge(S32, PtrToInt);
+      } else if (Ty.isPointerVector()) {
+        LLT IntVecTy = Ty.changeElementType(
+            LLT::scalar(Ty.getElementType().getSizeInBits()));
+        auto PtrToInt = B.buildPtrToInt(IntVecTy, Src2);
+        Src2Parts = B.buildUnmerge(S32, PtrToInt);
+      } else
+        Src2Parts =
+            IsS16Vec ? B.buildUnmerge(V2S16, Src2) : B.buildUnmerge(S32, Src2);
+
+      for (unsigned i = 0; i < NumParts; ++i) {
+        Src0 = IsS16Vec ? B.buildBitcast(S32, Src0Parts.getReg(i)).getReg(0)
+                        : Src0Parts.getReg(i);
+        Src2 = IsS16Vec ? B.buildBitcast(S32, Src2Parts.getReg(i)).getReg(0)
+                        : Src2Parts.getReg(i);
+        PartialRes.push_back(
+            (B.buildIntrinsic(Intrinsic::amdgcn_writelane, {S32})
+                 .addUse(Src0)
+                 .addUse(Src1)
+                 .addUse(Src2))
+                .getReg(0));
+      }
+
+      break;
+    }
+    }
+
+    if (Ty.isPointerVector()) {
+      unsigned PtrSize = Ty.getElementType().getSizeInBits();
+      SmallVector<Register, 2> PtrElements;
+      if (PtrSize == 32) {
+        // Handle 32 bit pointers
+        for (unsigned i = 0; i < NumParts; i++)
+          PtrElements.push_back(
+              B.buildIntToPtr(Ty.getElementType(), PartialRes[i]).getReg(0));
+      } else {
+        // Handle legalization of <? x [pointer type bigger than 32 bits]>
+        SmallVector<Register, 2> PtrParts;
+        unsigned NumS32Parts = PtrSize / 32;
+        unsigned PartIdx = 0;
+        for (unsigned i = 0, j = 1; i < NumParts; i += NumS32Parts, j++) {
+          // Merge S32 components of a pointer element first.
+          for (; PartIdx < (j * NumS32Parts); PartIdx++)
+            PtrParts.push_back(PartialRes[PartIdx]);
+
+          auto MergedPtr =
+              B.buildMergeLikeInstr(LLT::scalar(PtrSize), PtrParts);
+          PtrElements.push_back(
+              B.buildIntToPtr(Ty.getElementType(), MergedPtr).getReg(0));
+          PtrParts.clear();
+        }
+      }
+
+      B.buildMergeLikeInstr(DstReg, PtrElements);
+    } else {
+      if (IsS16Vec) {
+        for (unsigned i = 0; i < NumParts; i++)
+          PartialRes[i] = B.buildBitcast(V2S16, PartialRes[i]).getReg(0);
----------------
arsenm wrote:

You shouldn't need to create a bitcast here. You're directly consuming the `<2 x s16>` pieces in the selection pattern.

https://github.com/llvm/llvm-project/pull/89217
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to