================ @@ -0,0 +1,726 @@ + +//===- AArch64LoopIdiomTransform.cpp - Loop idiom recognition -------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "AArch64LoopIdiomTransform.h" +#include "llvm/Analysis/DomTreeUpdater.h" +#include "llvm/Analysis/LoopPass.h" +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/MDBuilder.h" +#include "llvm/IR/PatternMatch.h" +#include "llvm/InitializePasses.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" + +using namespace llvm; + +#define DEBUG_TYPE "aarch64-lit" + +static cl::opt<bool> + DisableAll("disable-aarch64-lit-all", cl::Hidden, cl::init(false), + cl::desc("Disable AArch64 Loop Idiom Transform Pass.")); + +static cl::opt<bool> DisableByteCmp( + "disable-aarch64-lit-bytecmp", cl::Hidden, cl::init(false), + cl::desc("Proceed with AArch64 Loop Idiom Transform Pass, but do " + "not convert byte-compare loop(s).")); + +namespace llvm { + +void initializeAArch64LoopIdiomTransformLegacyPassPass(PassRegistry &); +Pass *createAArch64LoopIdiomTransformPass(); + +} // end namespace llvm + +namespace { + +class AArch64LoopIdiomTransform { + Loop *CurLoop = nullptr; + DominatorTree *DT; + LoopInfo *LI; + const TargetTransformInfo *TTI; + const DataLayout *DL; + +public: + explicit AArch64LoopIdiomTransform(DominatorTree *DT, LoopInfo *LI, + const TargetTransformInfo *TTI, + const DataLayout *DL) + : DT(DT), LI(LI), TTI(TTI), DL(DL) {} + + bool run(Loop *L); + +private: + /// \name Countable Loop Idiom Handling + /// @{ + + bool runOnCountableLoop(); + bool runOnLoopBlock(BasicBlock *BB, const SCEV *BECount, + SmallVectorImpl<BasicBlock *> 
&ExitBlocks); + + bool recognizeByteCompare(); + Value *expandFindMismatch(IRBuilder<> &Builder, GetElementPtrInst *GEPA, + GetElementPtrInst *GEPB, Value *Start, + Value *MaxLen); + void transformByteCompare(GetElementPtrInst *GEPA, GetElementPtrInst *GEPB, + Value *MaxLen, Value *Index, Value *Start, + bool IncIdx, BasicBlock *FoundBB, + BasicBlock *EndBB); + /// @} +}; + +class AArch64LoopIdiomTransformLegacyPass : public LoopPass { +public: + static char ID; + + explicit AArch64LoopIdiomTransformLegacyPass() : LoopPass(ID) { + initializeAArch64LoopIdiomTransformLegacyPassPass( + *PassRegistry::getPassRegistry()); + } + + StringRef getPassName() const override { + return "Recognize AArch64-specific loop idioms"; + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<LoopInfoWrapperPass>(); + AU.addRequired<DominatorTreeWrapperPass>(); + AU.addRequired<TargetTransformInfoWrapperPass>(); + } + + bool runOnLoop(Loop *L, LPPassManager &LPM) override; +}; + +bool AArch64LoopIdiomTransformLegacyPass::runOnLoop(Loop *L, + LPPassManager &LPM) { + + if (skipLoop(L)) + return false; + + auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + auto *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); + auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI( + *L->getHeader()->getParent()); + return AArch64LoopIdiomTransform( + DT, LI, &TTI, &L->getHeader()->getModule()->getDataLayout()) + .run(L); +} + +} // end anonymous namespace + +char AArch64LoopIdiomTransformLegacyPass::ID = 0; + +INITIALIZE_PASS_BEGIN( + AArch64LoopIdiomTransformLegacyPass, "aarch64-lit", + "Transform specific loop idioms into optimised vector forms", false, false) +INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(LoopSimplify) +INITIALIZE_PASS_DEPENDENCY(LCSSAWrapperPass) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) +INITIALIZE_PASS_END( + 
AArch64LoopIdiomTransformLegacyPass, "aarch64-lit", + "Transform specific loop idioms into optimised vector forms", false, false) + +Pass *llvm::createAArch64LoopIdiomTransformPass() { + return new AArch64LoopIdiomTransformLegacyPass(); +} + +PreservedAnalyses +AArch64LoopIdiomTransformPass::run(Loop &L, LoopAnalysisManager &AM, + LoopStandardAnalysisResults &AR, + LPMUpdater &) { + if (DisableAll) + return PreservedAnalyses::all(); + + const auto *DL = &L.getHeader()->getModule()->getDataLayout(); + + AArch64LoopIdiomTransform LIT(&AR.DT, &AR.LI, &AR.TTI, DL); + if (!LIT.run(&L)) + return PreservedAnalyses::all(); + + return PreservedAnalyses::none(); +} + +//===----------------------------------------------------------------------===// +// +// Implementation of AArch64LoopIdiomTransform +// +//===----------------------------------------------------------------------===// + +bool AArch64LoopIdiomTransform::run(Loop *L) { + CurLoop = L; + + if (DisableAll) + return false; + + // If the loop could not be converted to canonical form, it must have an + // indirectbr in it, just give up. + if (!L->getLoopPreheader()) + return false; + + LLVM_DEBUG(dbgs() << DEBUG_TYPE " Scanning: F[" + << CurLoop->getHeader()->getParent()->getName() + << "] Loop %" << CurLoop->getHeader()->getName() << "\n"); + + return recognizeByteCompare(); +} + +/// Match loop-invariant value. +template <typename SubPattern_t> struct match_LoopInvariant { + SubPattern_t SubPattern; + const Loop *L; + + match_LoopInvariant(const SubPattern_t &SP, const Loop *L) + : SubPattern(SP), L(L) {} + + template <typename ITy> bool match(ITy *V) { + return L->isLoopInvariant(V) && SubPattern.match(V); + } +}; + +/// Matches if the value is loop-invariant. 
+template <typename Ty> +inline match_LoopInvariant<Ty> m_LoopInvariant(const Ty &M, const Loop *L) { + return match_LoopInvariant<Ty>(M, L); +} + +bool AArch64LoopIdiomTransform::recognizeByteCompare() { + if (!TTI->supportsScalableVectors() || !TTI->getMinPageSize().has_value() || + DisableByteCmp) + return false; + + BasicBlock *Header = CurLoop->getHeader(); + BasicBlock *PH = CurLoop->getLoopPreheader(); + + // In AArch64LoopIdiomTransform::run we have already checked that the loop + // has a preheader so we can assume it's in a canonical form. + auto *EntryBI = cast<BranchInst>(PH->getTerminator()); + + if (CurLoop->getNumBackEdges() != 1 || CurLoop->getNumBlocks() != 2) + return false; + + PHINode *PN = dyn_cast<PHINode>(&Header->front()); + if (!PN || PN->getNumIncomingValues() != 2) + return false; + + auto LoopBlocks = CurLoop->getBlocks(); + // The first block in the loop should contain only 4 instructions, e.g. + // + // while.cond: + // %res.phi = phi i32 [ %start, %ph ], [ %inc, %while.body ] + // %inc = add i32 %res.phi, 1 + // %cmp.not = icmp eq i32 %inc, %n + // br i1 %cmp.not, label %while.end, label %while.body + // + auto CondBBInsts = LoopBlocks[0]->instructionsWithoutDebug(); + if (std::distance(CondBBInsts.begin(), CondBBInsts.end()) > 4) ---------------- david-arm wrote:
This logic is designed to prevent recognising loops such as `@compare_bytes_simple2`, i.e.

```
while.cond:
  ...
  br i1 %cmp.not, label %while.end, label %while.body

while.body:
  ...
  br i1 %cmp.not2, label %while.cond, label %while.end

while.end:
  ...
  %final_ptr = phi ptr [ %c, %while.body ], [ %d, %while.cond ]
```

where the found and end blocks are the same, and in that block there is a PHI node that has a unique incoming value for each block in the loop. However, neither incoming value is defined in the loop itself. Currently we don't handle this case in `transformByteCompare` because it would require extra work in the `byte.compare` block to select between these two incoming values. Rather than let `transformByteCompare` generate incorrect code and lead to bugs, I've explicitly disabled these cases for now. I haven't encountered any examples of such loops in the benchmarks I have looked at, so these are not critical to us. I only discovered this accidentally while writing more test cases.

https://github.com/llvm/llvm-project/pull/72273
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits