================
@@ -5387,6 +5387,192 @@ bool AMDGPULegalizerInfo::legalizeDSAtomicFPIntrinsic(LegalizerHelper &Helper,
   return true;
 }
 
+bool AMDGPULegalizerInfo::legalizeLaneOp(LegalizerHelper &Helper,
+                                         MachineInstr &MI,
+                                         Intrinsic::ID IID) const {
+
+  MachineIRBuilder &B = Helper.MIRBuilder;
+  MachineRegisterInfo &MRI = *B.getMRI();
+
+  Register DstReg = MI.getOperand(0).getReg();
+  Register Src0 = MI.getOperand(2).getReg();
+
+  auto createLaneOp = [&](Register Src0, Register Src1,
+                          Register Src2) -> Register {
+    auto LaneOp = B.buildIntrinsic(IID, {S32}).addUse(Src0);
+    switch (IID) {
+    case Intrinsic::amdgcn_readfirstlane:
+      return LaneOp.getReg(0);
+    case Intrinsic::amdgcn_readlane:
+      return LaneOp.addUse(Src1).getReg(0);
+    case Intrinsic::amdgcn_writelane:
+      return LaneOp.addUse(Src1).addUse(Src2).getReg(0);
+    default:
+      llvm_unreachable("unhandled lane op");
+    }
+  };
+
+  Register Src1, Src2;
+  if (IID == Intrinsic::amdgcn_readlane || IID == Intrinsic::amdgcn_writelane) {
+    Src1 = MI.getOperand(3).getReg();
+    if (IID == Intrinsic::amdgcn_writelane) {
+      Src2 = MI.getOperand(4).getReg();
+    }
+  }
+
+  LLT Ty = MRI.getType(DstReg);
+  unsigned Size = Ty.getSizeInBits();
+
+  if (Size == 32) {
+    // Already legal
+    return true;
+  }
+
+  if (Size < 32) {
+    Register Src0Cast = MRI.getType(Src0).isScalar()
+                            ? Src0
+                            : B.buildBitcast(LLT::scalar(Size), Src0).getReg(0);
+    Src0 = B.buildAnyExt(S32, Src0Cast).getReg(0);
+    if (Src2.isValid()) {
+      Register Src2Cast =
+          MRI.getType(Src2).isScalar()
+              ? Src2
+              : B.buildBitcast(LLT::scalar(Size), Src2).getReg(0);
+      Src2 = B.buildAnyExt(LLT::scalar(32), Src2Cast).getReg(0);
+    }
+
+    Register LaneOpDst = createLaneOp(Src0, Src1, Src2);
+    if (Ty.isScalar())
+      B.buildTrunc(DstReg, LaneOpDst);
+    else {
+      auto Trunc = B.buildTrunc(LLT::scalar(Size), LaneOpDst);
+      B.buildBitcast(DstReg, Trunc);
+    }
+
+    MI.eraseFromParent();
+    return true;
+  }
+
+  if ((Size % 32) == 0) {
+    SmallVector<Register, 2> PartialRes;
+    unsigned NumParts = Size / 32;
+    auto IsS16Vec = Ty.isVector() && Ty.getElementType() == S16;
+    MachineInstrBuilder Src0Parts;
+
+    if (Ty.isPointer()) {
+      auto PtrToInt = B.buildPtrToInt(LLT::scalar(Size), Src0);
+      Src0Parts = B.buildUnmerge(S32, PtrToInt);
+    } else if (Ty.isPointerVector()) {
+      LLT IntVecTy = Ty.changeElementType(
+          LLT::scalar(Ty.getElementType().getSizeInBits()));
+      auto PtrToInt = B.buildPtrToInt(IntVecTy, Src0);
+      Src0Parts = B.buildUnmerge(S32, PtrToInt);
+    } else
+      Src0Parts =
+          IsS16Vec ? B.buildUnmerge(V2S16, Src0) : B.buildUnmerge(S32, Src0);
+
+    switch (IID) {
+    case Intrinsic::amdgcn_readlane: {
+      Register Src1 = MI.getOperand(3).getReg();
+      for (unsigned i = 0; i < NumParts; ++i) {
+        Src0 = IsS16Vec ? B.buildBitcast(S32, Src0Parts.getReg(i)).getReg(0)
+                        : Src0Parts.getReg(i);
+        PartialRes.push_back(
+            (B.buildIntrinsic(Intrinsic::amdgcn_readlane, {S32})
+                 .addUse(Src0)
+                 .addUse(Src1))
+                .getReg(0));
+      }
+      break;
+    }
+    case Intrinsic::amdgcn_readfirstlane: {
+      for (unsigned i = 0; i < NumParts; ++i) {
+        Src0 = IsS16Vec ? B.buildBitcast(S32, Src0Parts.getReg(i)).getReg(0)
+                        : Src0Parts.getReg(i);
+        PartialRes.push_back(
+            (B.buildIntrinsic(Intrinsic::amdgcn_readfirstlane, {S32})
+                 .addUse(Src0)
+                 .getReg(0)));
+      }
+
+      break;
+    }
+    case Intrinsic::amdgcn_writelane: {
+      Register Src1 = MI.getOperand(3).getReg();
+      Register Src2 = MI.getOperand(4).getReg();
+      MachineInstrBuilder Src2Parts;
+
+      if (Ty.isPointer()) {
+        auto PtrToInt = B.buildPtrToInt(S64, Src2);
+        Src2Parts = B.buildUnmerge(S32, PtrToInt);
+      } else if (Ty.isPointerVector()) {
+        LLT IntVecTy = Ty.changeElementType(
+            LLT::scalar(Ty.getElementType().getSizeInBits()));
+        auto PtrToInt = B.buildPtrToInt(IntVecTy, Src2);
+        Src2Parts = B.buildUnmerge(S32, PtrToInt);
+      } else
+        Src2Parts =
+            IsS16Vec ? B.buildUnmerge(V2S16, Src2) : B.buildUnmerge(S32, Src2);
+
+      for (unsigned i = 0; i < NumParts; ++i) {
+        Src0 = IsS16Vec ? B.buildBitcast(S32, Src0Parts.getReg(i)).getReg(0)
+                        : Src0Parts.getReg(i);
+        Src2 = IsS16Vec ? B.buildBitcast(S32, Src2Parts.getReg(i)).getReg(0)
+                        : Src2Parts.getReg(i);
+        PartialRes.push_back(
+            (B.buildIntrinsic(Intrinsic::amdgcn_writelane, {S32})
+                 .addUse(Src0)
+                 .addUse(Src1)
+                 .addUse(Src2))
+                .getReg(0));
+      }
+
+      break;
+    }
+    }
+
+    if (Ty.isPointerVector()) {
+      unsigned PtrSize = Ty.getElementType().getSizeInBits();
+      SmallVector<Register, 2> PtrElements;
+      if (PtrSize == 32) {
+        // Handle 32 bit pointers
+        for (unsigned i = 0; i < NumParts; i++)
+          PtrElements.push_back(
+              B.buildIntToPtr(Ty.getElementType(), PartialRes[i]).getReg(0));
+      } else {
+        // Handle legalization of <? x [pointer type bigger than 32 bits]>
+        SmallVector<Register, 2> PtrParts;
+        unsigned NumS32Parts = PtrSize / 32;
+        unsigned PartIdx = 0;
+        for (unsigned i = 0, j = 1; i < NumParts; i += NumS32Parts, j++) {
+          // Merge S32 components of a pointer element first.
+          for (; PartIdx < (j * NumS32Parts); PartIdx++)
+            PtrParts.push_back(PartialRes[PartIdx]);
+
+          auto MergedPtr =
+              B.buildMergeLikeInstr(LLT::scalar(PtrSize), PtrParts);
+          PtrElements.push_back(
+              B.buildIntToPtr(Ty.getElementType(), MergedPtr).getReg(0));
+          PtrParts.clear();
+        }
+      }
+
+      B.buildMergeLikeInstr(DstReg, PtrElements);
+    } else {
+      if (IsS16Vec) {
+        for (unsigned i = 0; i < NumParts; i++)
+          PartialRes[i] = B.buildBitcast(V2S16, PartialRes[i]).getReg(0);

----------------
arsenm wrote:
You shouldn't need to create a bitcast here. You're directly consuming the `<2 x s16>` pieces in the selection pattern

https://github.com/llvm/llvm-project/pull/89217
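For context, a minimal sketch of the simplification the reviewer appears to be suggesting for the `IsS16Vec` path, assuming the selection patterns really do accept `<2 x s16>` operands and results directly (names such as `Src0Parts`, `PartialRes`, `NumParts`, `V2S16`, and `DstReg` are taken from the patch above; this is illustrative, not the author's code):

```cpp
// Hypothetical rework of the readfirstlane loop: operate on the unmerged
// <2 x s16> pieces directly and drop the round-trip bitcasts through s32
// on both the operand and result sides.
for (unsigned i = 0; i < NumParts; ++i)
  PartialRes.push_back(
      B.buildIntrinsic(Intrinsic::amdgcn_readfirstlane, {V2S16})
          .addUse(Src0Parts.getReg(i))
          .getReg(0));
// The <2 x s16> partial results can then feed the final merge as-is,
// with no bitcast back from s32 needed.
B.buildMergeLikeInstr(DstReg, PartialRes);
```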