llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT--> @llvm/pr-subscribers-backend-risc-v Author: Pengcheng Wang (wangpc-pp) <details> <summary>Changes</summary> Note that we only support SEW=8/16 for `vwabdacc(u)`. --- Full diff: https://github.com/llvm/llvm-project/pull/180162.diff 4 Files Affected: - (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+44) - (modified) llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td (+14-2) - (modified) llvm/lib/Target/RISCV/RISCVInstrInfoZvabd.td (+21-1) - (modified) llvm/test/CodeGen/RISCV/rvv/fixed-vectors-sad.ll (+14-10) ``````````diff diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index d46cb575c54c5..171fc391a7aa8 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -18770,6 +18770,48 @@ static SDValue combineVWADDSUBWSelect(SDNode *N, SelectionDAG &DAG) { N->getFlags()); } +// vwaddu C (vabd A B) -> vwabda(A B C) +// vwaddu C (vabdu A B) -> vwabdau(A B C) +static SDValue performVWABDACombine(SDNode *N, SelectionDAG &DAG, + const RISCVSubtarget &Subtarget) { + if (!Subtarget.hasStdExtZvabd()) + return SDValue(); + + MVT VT = N->getSimpleValueType(0); + if (VT.getVectorElementType() != MVT::i8 && + VT.getVectorElementType() != MVT::i16) + return SDValue(); + + SDValue Op0 = N->getOperand(0); + SDValue Op1 = N->getOperand(1); + SDValue Passthru = N->getOperand(2); + if (!Passthru->isUndef()) + return SDValue(); + + SDValue Mask = N->getOperand(3); + SDValue VL = N->getOperand(4); + auto IsABD = [](SDValue Op) { + if (Op->getOpcode() != RISCVISD::ABDS_VL && + Op->getOpcode() != RISCVISD::ABDU_VL) + return SDValue(); + return Op; + }; + + SDValue Diff = IsABD(Op0); + Diff = Diff ? Diff : IsABD(Op1); + if (!Diff) + return SDValue(); + SDValue Acc = Diff == Op0 ? Op1 : Op0; + + SDLoc DL(N); + Acc = DAG.getNode(RISCVISD::VZEXT_VL, DL, VT, Acc, Mask, VL); + SDValue Result = DAG.getNode( + Diff.getOpcode() == RISCVISD::ABDS_VL ? 
RISCVISD::VWABDA_VL + : RISCVISD::VWABDAU_VL, + DL, VT, Diff.getOperand(0), Diff.getOperand(1), Acc, Mask, VL); + return Result; +} + static SDValue performVWADDSUBW_VLCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, const RISCVSubtarget &Subtarget) { @@ -21681,6 +21723,8 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N, if (SDValue V = combineVqdotAccum(N, DAG, Subtarget)) return V; return combineToVWMACC(N, DAG, Subtarget); + case RISCVISD::VWADDU_VL: + return performVWABDACombine(N, DAG, Subtarget); case RISCVISD::VWADD_W_VL: case RISCVISD::VWADDU_W_VL: case RISCVISD::VWSUB_W_VL: diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td index 46dd45876a384..d1bcaffdeac5b 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td @@ -1750,8 +1750,9 @@ multiclass VPatMultiplyAddVL_VV_VX<SDNode op, string instruction_name> { } } -multiclass VPatWidenMultiplyAddVL_VV_VX<SDNode vwmacc_op, string instr_name> { - foreach vtiTowti = AllWidenableIntVectors in { +multiclass VPatWidenMultiplyAddVL_VV<SDNode vwmacc_op, string instr_name, + list<VTypeInfoToWide> vtilist = AllWidenableIntVectors> { + foreach vtiTowti = vtilist in { defvar vti = vtiTowti.Vti; defvar wti = vtiTowti.Wti; let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates, @@ -1763,6 +1764,17 @@ multiclass VPatWidenMultiplyAddVL_VV_VX<SDNode vwmacc_op, string instr_name> { (!cast<Instruction>(instr_name#"_VV_"#vti.LMul.MX#"_MASK") wti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2, (vti.Mask VMV0:$vm), GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>; + } + } +} + +multiclass VPatWidenMultiplyAddVL_VV_VX<SDNode vwmacc_op, string instr_name> + : VPatWidenMultiplyAddVL_VV<vwmacc_op, instr_name> { + foreach vtiTowti = AllWidenableIntVectors in { + defvar vti = vtiTowti.Vti; + defvar wti = vtiTowti.Wti; + let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates, + 
GetVTypePredicates<wti>.Predicates) in { def : Pat<(vwmacc_op (SplatPat XLenVT:$rs1), (vti.Vector vti.RegClass:$rs2), (wti.Vector wti.RegClass:$rd), diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoZvabd.td b/llvm/lib/Target/RISCV/RISCVInstrInfoZvabd.td index 139372b70e590..46261d83711cc 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfoZvabd.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoZvabd.td @@ -29,7 +29,6 @@ let Predicates = [HasStdExtZvabd] in { //===----------------------------------------------------------------------===// // Pseudos //===----------------------------------------------------------------------===// - multiclass PseudoVABS { foreach m = MxList in { defvar mx = m.MX; @@ -44,10 +43,23 @@ multiclass PseudoVABS { } } +multiclass VPseudoVWABD_VV { + foreach m = MxListW in { + defvar mx = m.MX; + defm "" : VPseudoTernaryW_VV<m, Commutable = 1>, + SchedTernary<"WriteVIWMulAddV", "ReadVIWMulAddV", + "ReadVIWMulAddV", "ReadVIWMulAddV", mx>; + } +} + let Predicates = [HasStdExtZvabd] in { defm PseudoVABS : PseudoVABS; defm PseudoVABD : VPseudoVALU_VV<Commutable = 1>; defm PseudoVABDU : VPseudoVALU_VV<Commutable = 1>; + let IsRVVWideningReduction = 1 in { + defm PseudoVWABDA : VPseudoVWABD_VV; + defm PseudoVWABDAU : VPseudoVWABD_VV; + } // IsRVVWideningReduction = 1 } // Predicates = [HasStdExtZvabd] //===----------------------------------------------------------------------===// @@ -57,12 +69,17 @@ let HasPassthruOp = true, HasMaskOp = true in { def riscv_abs_vl : RVSDNode<"ABS_VL", SDT_RISCVIntUnOp_VL>; def riscv_abds_vl : RVSDNode<"ABDS_VL", SDT_RISCVIntBinOp_VL, [SDNPCommutative]>; def riscv_abdu_vl : RVSDNode<"ABDU_VL", SDT_RISCVIntBinOp_VL, [SDNPCommutative]>; +def rvv_vwabda_vl : RVSDNode<"VWABDA_VL", SDT_RISCVVWIntTernOp_VL, [SDNPCommutative]>; +def rvv_vwabdau_vl : RVSDNode<"VWABDAU_VL", SDT_RISCVVWIntTernOp_VL, [SDNPCommutative]>; } // let HasPassthruOp = true, HasMaskOp = true // These instructions are defined for SEW=8 and SEW=16, otherwise 
the instruction // encoding is reserved. defvar ABDIntVectors = !filter(vti, AllIntegerVectors, !or(!eq(vti.SEW, 8), !eq(vti.SEW, 16))); +defvar ABDAIntVectors = !filter(vtiTowti, AllWidenableIntVectors, + !or(!eq(vtiTowti.Vti.SEW, 8), + !eq(vtiTowti.Vti.SEW, 16))); let Predicates = [HasStdExtZvabd] in { defm : VPatBinarySDNode_VV<abds, "PseudoVABD", ABDIntVectors>; @@ -79,4 +96,7 @@ foreach vti = AllIntegerVectors in { } defm : VPatUnaryVL_V<riscv_abs_vl, "PseudoVABS", HasStdExtZvabd>; + +defm : VPatWidenMultiplyAddVL_VV<rvv_vwabda_vl, "PseudoVWABDA", ABDAIntVectors>; +defm : VPatWidenMultiplyAddVL_VV<rvv_vwabdau_vl, "PseudoVWABDAU", ABDAIntVectors>; } // Predicates = [HasStdExtZvabd] diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-sad.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-sad.ll index 9f6c34cb052ff..dcb8b31c682b3 100644 --- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-sad.ll +++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-sad.ll @@ -199,16 +199,18 @@ define signext i32 @sad_2block_16xi8_as_i32(ptr %a, ptr %b, i32 signext %stridea ; ZVABD-NEXT: vle8.v v15, (a1) ; ZVABD-NEXT: add a0, a0, a2 ; ZVABD-NEXT: add a1, a1, a3 +; ZVABD-NEXT: vle8.v v16, (a0) +; ZVABD-NEXT: vle8.v v17, (a1) ; ZVABD-NEXT: vabdu.vv v8, v8, v9 -; ZVABD-NEXT: vle8.v v9, (a0) -; ZVABD-NEXT: vabdu.vv v10, v10, v11 -; ZVABD-NEXT: vle8.v v11, (a1) -; ZVABD-NEXT: vwaddu.vv v12, v10, v8 +; ZVABD-NEXT: vsetvli zero, zero, e16, m2, ta, ma +; ZVABD-NEXT: vzext.vf2 v12, v8 +; ZVABD-NEXT: vsetvli zero, zero, e8, m1, ta, ma +; ZVABD-NEXT: vwabdau.vv v12, v10, v11 ; ZVABD-NEXT: vabdu.vv v8, v14, v15 ; ZVABD-NEXT: vsetvli zero, zero, e16, m2, ta, ma ; ZVABD-NEXT: vzext.vf2 v14, v8 ; ZVABD-NEXT: vsetvli zero, zero, e8, m1, ta, ma -; ZVABD-NEXT: vabdu.vv v16, v9, v11 +; ZVABD-NEXT: vabdu.vv v16, v16, v17 ; ZVABD-NEXT: vsetvli zero, zero, e16, m2, ta, ma ; ZVABD-NEXT: vwaddu.vv v8, v14, v12 ; ZVABD-NEXT: vzext.vf2 v12, v16 @@ -320,16 +322,18 @@ define signext i32 @sadu_2block_16xi8_as_i32(ptr %a, 
ptr %b, i32 signext %stride ; ZVABD-NEXT: vle8.v v15, (a1) ; ZVABD-NEXT: add a0, a0, a2 ; ZVABD-NEXT: add a1, a1, a3 +; ZVABD-NEXT: vle8.v v16, (a0) +; ZVABD-NEXT: vle8.v v17, (a1) ; ZVABD-NEXT: vabd.vv v8, v8, v9 -; ZVABD-NEXT: vle8.v v9, (a0) -; ZVABD-NEXT: vabd.vv v10, v10, v11 -; ZVABD-NEXT: vle8.v v11, (a1) -; ZVABD-NEXT: vwaddu.vv v12, v10, v8 +; ZVABD-NEXT: vsetvli zero, zero, e16, m2, ta, ma +; ZVABD-NEXT: vzext.vf2 v12, v8 +; ZVABD-NEXT: vsetvli zero, zero, e8, m1, ta, ma +; ZVABD-NEXT: vwabda.vv v12, v10, v11 ; ZVABD-NEXT: vabd.vv v8, v14, v15 ; ZVABD-NEXT: vsetvli zero, zero, e16, m2, ta, ma ; ZVABD-NEXT: vzext.vf2 v14, v8 ; ZVABD-NEXT: vsetvli zero, zero, e8, m1, ta, ma -; ZVABD-NEXT: vabd.vv v16, v9, v11 +; ZVABD-NEXT: vabd.vv v16, v16, v17 ; ZVABD-NEXT: vsetvli zero, zero, e16, m2, ta, ma ; ZVABD-NEXT: vwaddu.vv v8, v14, v12 ; ZVABD-NEXT: vzext.vf2 v12, v16 `````````` </details> https://github.com/llvm/llvm-project/pull/180162 _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
