llvmbot (https://github.com/llvmbot) created pull request https://github.com/llvm/llvm-project/pull/149778.
Backport 8a307ae61963a3f967052f7ea3c89aafa56934cf Requested by: @heiher >From e781da54605c515b1d061a1368aba3e73c8a6bd5 Mon Sep 17 00:00:00 2001 From: hev <wang...@loongson.cn> Date: Mon, 21 Jul 2025 16:36:49 +0800 Subject: [PATCH] [LoongArch] Fix failure to widen operand for `[X]VMSK{LT,GE,NE}Z` (#149442) Reported-by: tangyan <tangya...@loongson.cn> (cherry picked from commit 8a307ae61963a3f967052f7ea3c89aafa56934cf) --- .../LoongArch/LoongArchISelLowering.cpp | 221 ++++++++++-------- llvm/test/CodeGen/LoongArch/lsx/vmskcond.ll | 15 ++ 2 files changed, 139 insertions(+), 97 deletions(-) diff --git a/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp b/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp index c47987fbf683b..12cf04bbbab56 100644 --- a/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp +++ b/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp @@ -4563,6 +4563,80 @@ static SDValue signExtendBitcastSrcVector(SelectionDAG &DAG, EVT SExtVT, llvm_unreachable("Unexpected node type for vXi1 sign extension"); } +static SDValue +performSETCC_BITCASTCombine(SDNode *N, SelectionDAG &DAG, + TargetLowering::DAGCombinerInfo &DCI, + const LoongArchSubtarget &Subtarget) { + SDLoc DL(N); + EVT VT = N->getValueType(0); + SDValue Src = N->getOperand(0); + EVT SrcVT = Src.getValueType(); + + if (Src.getOpcode() != ISD::SETCC || !Src.hasOneUse()) + return SDValue(); + + bool UseLASX; + unsigned Opc = ISD::DELETED_NODE; + EVT CmpVT = Src.getOperand(0).getValueType(); + EVT EltVT = CmpVT.getVectorElementType(); + + if (Subtarget.hasExtLSX() && CmpVT.getSizeInBits() == 128) + UseLASX = false; + else if (Subtarget.has32S() && Subtarget.hasExtLASX() && + CmpVT.getSizeInBits() == 256) + UseLASX = true; + else + return SDValue(); + + SDValue SrcN1 = Src.getOperand(1); + switch (cast<CondCodeSDNode>(Src.getOperand(2))->get()) { + default: + break; + case ISD::SETEQ: + // x == 0 => not (vmsknez.b x) + if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) && EltVT == MVT::i8) + Opc = 
UseLASX ? LoongArchISD::XVMSKEQZ : LoongArchISD::VMSKEQZ; + break; + case ISD::SETGT: + // x > -1 => vmskgez.b x + if (ISD::isBuildVectorAllOnes(SrcN1.getNode()) && EltVT == MVT::i8) + Opc = UseLASX ? LoongArchISD::XVMSKGEZ : LoongArchISD::VMSKGEZ; + break; + case ISD::SETGE: + // x >= 0 => vmskgez.b x + if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) && EltVT == MVT::i8) + Opc = UseLASX ? LoongArchISD::XVMSKGEZ : LoongArchISD::VMSKGEZ; + break; + case ISD::SETLT: + // x < 0 => vmskltz.{b,h,w,d} x + if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) && + (EltVT == MVT::i8 || EltVT == MVT::i16 || EltVT == MVT::i32 || + EltVT == MVT::i64)) + Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ; + break; + case ISD::SETLE: + // x <= -1 => vmskltz.{b,h,w,d} x + if (ISD::isBuildVectorAllOnes(SrcN1.getNode()) && + (EltVT == MVT::i8 || EltVT == MVT::i16 || EltVT == MVT::i32 || + EltVT == MVT::i64)) + Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ; + break; + case ISD::SETNE: + // x != 0 => vmsknez.b x + if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) && EltVT == MVT::i8) + Opc = UseLASX ? 
LoongArchISD::XVMSKNEZ : LoongArchISD::VMSKNEZ; + break; + } + + if (Opc == ISD::DELETED_NODE) + return SDValue(); + + SDValue V = DAG.getNode(Opc, DL, MVT::i64, Src.getOperand(0)); + EVT T = EVT::getIntegerVT(*DAG.getContext(), SrcVT.getVectorNumElements()); + V = DAG.getZExtOrTrunc(V, DL, T); + return DAG.getBitcast(VT, V); +} + static SDValue performBITCASTCombine(SDNode *N, SelectionDAG &DAG, TargetLowering::DAGCombinerInfo &DCI, const LoongArchSubtarget &Subtarget) { @@ -4577,110 +4651,63 @@ static SDValue performBITCASTCombine(SDNode *N, SelectionDAG &DAG, if (!SrcVT.isSimple() || SrcVT.getScalarType() != MVT::i1) return SDValue(); - unsigned Opc = ISD::DELETED_NODE; // Combine SETCC and BITCAST into [X]VMSK{LT,GE,NE} when possible + SDValue Res = performSETCC_BITCASTCombine(N, DAG, DCI, Subtarget); + if (Res) + return Res; + + // Generate vXi1 using [X]VMSKLTZ + MVT SExtVT; + unsigned Opc; + bool UseLASX = false; + bool PropagateSExt = false; + if (Src.getOpcode() == ISD::SETCC && Src.hasOneUse()) { - bool UseLASX; EVT CmpVT = Src.getOperand(0).getValueType(); - EVT EltVT = CmpVT.getVectorElementType(); - - if (Subtarget.hasExtLSX() && CmpVT.getSizeInBits() <= 128) - UseLASX = false; - else if (Subtarget.has32S() && Subtarget.hasExtLASX() && - CmpVT.getSizeInBits() <= 256) - UseLASX = true; - else + if (CmpVT.getSizeInBits() > 256) return SDValue(); - - SDValue SrcN1 = Src.getOperand(1); - switch (cast<CondCodeSDNode>(Src.getOperand(2))->get()) { - default: - break; - case ISD::SETEQ: - // x == 0 => not (vmsknez.b x) - if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) && EltVT == MVT::i8) - Opc = UseLASX ? LoongArchISD::XVMSKEQZ : LoongArchISD::VMSKEQZ; - break; - case ISD::SETGT: - // x > -1 => vmskgez.b x - if (ISD::isBuildVectorAllOnes(SrcN1.getNode()) && EltVT == MVT::i8) - Opc = UseLASX ? 
LoongArchISD::XVMSKGEZ : LoongArchISD::VMSKGEZ; - break; - case ISD::SETGE: - // x >= 0 => vmskgez.b x - if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) && EltVT == MVT::i8) - Opc = UseLASX ? LoongArchISD::XVMSKGEZ : LoongArchISD::VMSKGEZ; - break; - case ISD::SETLT: - // x < 0 => vmskltz.{b,h,w,d} x - if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) && - (EltVT == MVT::i8 || EltVT == MVT::i16 || EltVT == MVT::i32 || - EltVT == MVT::i64)) - Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ; - break; - case ISD::SETLE: - // x <= -1 => vmskltz.{b,h,w,d} x - if (ISD::isBuildVectorAllOnes(SrcN1.getNode()) && - (EltVT == MVT::i8 || EltVT == MVT::i16 || EltVT == MVT::i32 || - EltVT == MVT::i64)) - Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ; - break; - case ISD::SETNE: - // x != 0 => vmsknez.b x - if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) && EltVT == MVT::i8) - Opc = UseLASX ? LoongArchISD::XVMSKNEZ : LoongArchISD::VMSKNEZ; - break; - } } - // Generate vXi1 using [X]VMSKLTZ - if (Opc == ISD::DELETED_NODE) { - MVT SExtVT; - bool UseLASX = false; - bool PropagateSExt = false; - switch (SrcVT.getSimpleVT().SimpleTy) { - default: - return SDValue(); - case MVT::v2i1: - SExtVT = MVT::v2i64; - break; - case MVT::v4i1: - SExtVT = MVT::v4i32; - if (Subtarget.hasExtLASX() && checkBitcastSrcVectorSize(Src, 256, 0)) { - SExtVT = MVT::v4i64; - UseLASX = true; - PropagateSExt = true; - } - break; - case MVT::v8i1: - SExtVT = MVT::v8i16; - if (Subtarget.hasExtLASX() && checkBitcastSrcVectorSize(Src, 256, 0)) { - SExtVT = MVT::v8i32; - UseLASX = true; - PropagateSExt = true; - } - break; - case MVT::v16i1: - SExtVT = MVT::v16i8; - if (Subtarget.hasExtLASX() && checkBitcastSrcVectorSize(Src, 256, 0)) { - SExtVT = MVT::v16i16; - UseLASX = true; - PropagateSExt = true; - } - break; - case MVT::v32i1: - SExtVT = MVT::v32i8; + switch (SrcVT.getSimpleVT().SimpleTy) { + default: + return SDValue(); + case MVT::v2i1: + SExtVT = MVT::v2i64; + break; + case 
MVT::v4i1: + SExtVT = MVT::v4i32; + if (Subtarget.hasExtLASX() && checkBitcastSrcVectorSize(Src, 256, 0)) { + SExtVT = MVT::v4i64; UseLASX = true; - break; - }; - if (UseLASX && !Subtarget.has32S() && !Subtarget.hasExtLASX()) - return SDValue(); - Src = PropagateSExt ? signExtendBitcastSrcVector(DAG, SExtVT, Src, DL) - : DAG.getNode(ISD::SIGN_EXTEND, DL, SExtVT, Src); - Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ; - } else { - Src = Src.getOperand(0); - } + PropagateSExt = true; + } + break; + case MVT::v8i1: + SExtVT = MVT::v8i16; + if (Subtarget.hasExtLASX() && checkBitcastSrcVectorSize(Src, 256, 0)) { + SExtVT = MVT::v8i32; + UseLASX = true; + PropagateSExt = true; + } + break; + case MVT::v16i1: + SExtVT = MVT::v16i8; + if (Subtarget.hasExtLASX() && checkBitcastSrcVectorSize(Src, 256, 0)) { + SExtVT = MVT::v16i16; + UseLASX = true; + PropagateSExt = true; + } + break; + case MVT::v32i1: + SExtVT = MVT::v32i8; + UseLASX = true; + break; + }; + if (UseLASX && !(Subtarget.has32S() && Subtarget.hasExtLASX())) + return SDValue(); + Src = PropagateSExt ? signExtendBitcastSrcVector(DAG, SExtVT, Src, DL) + : DAG.getNode(ISD::SIGN_EXTEND, DL, SExtVT, Src); + Opc = UseLASX ? 
LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ; SDValue V = DAG.getNode(Opc, DL, MVT::i64, Src); EVT T = EVT::getIntegerVT(*DAG.getContext(), SrcVT.getVectorNumElements()); diff --git a/llvm/test/CodeGen/LoongArch/lsx/vmskcond.ll b/llvm/test/CodeGen/LoongArch/lsx/vmskcond.ll index 0ee30120f77a6..ad57bbf9ee5c0 100644 --- a/llvm/test/CodeGen/LoongArch/lsx/vmskcond.ll +++ b/llvm/test/CodeGen/LoongArch/lsx/vmskcond.ll @@ -588,3 +588,18 @@ define i2 @vmsk_trunc_i64(<2 x i64> %a) { %res = bitcast <2 x i1> %y to i2 ret i2 %res } + +define i4 @vmsk_eq_allzeros_v4i8(<4 x i8> %a) { +; CHECK-LABEL: vmsk_eq_allzeros_v4i8: +; CHECK: # %bb.0: +; CHECK-NEXT: vseqi.b $vr0, $vr0, 0 +; CHECK-NEXT: vilvl.b $vr0, $vr0, $vr0 +; CHECK-NEXT: vilvl.h $vr0, $vr0, $vr0 +; CHECK-NEXT: vslli.w $vr0, $vr0, 24 +; CHECK-NEXT: vmskltz.w $vr0, $vr0 +; CHECK-NEXT: vpickve2gr.hu $a0, $vr0, 0 +; CHECK-NEXT: ret + %1 = icmp eq <4 x i8> %a, zeroinitializer + %2 = bitcast <4 x i1> %1 to i4 + ret i4 %2 +} _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits