https://github.com/minglotus-6 created https://github.com/llvm/llvm-project/pull/81378
None >From ac5dc1bf77b67cbf0aa5e2c8fb6a7ce0080fb501 Mon Sep 17 00:00:00 2001 From: mingmingl <mingmi...@google.com> Date: Sat, 10 Feb 2024 12:03:25 -0800 Subject: [PATCH] [CallPromotionUtils]Implement conditional indirect call promotion with vtable-based comparison --- .../Transforms/Utils/CallPromotionUtils.h | 50 ++++++- .../Transforms/Utils/CallPromotionUtils.cpp | 64 ++++++++- .../Utils/CallPromotionUtilsTest.cpp | 127 ++++++++++++++++++ 3 files changed, 233 insertions(+), 8 deletions(-) diff --git a/llvm/include/llvm/Transforms/Utils/CallPromotionUtils.h b/llvm/include/llvm/Transforms/Utils/CallPromotionUtils.h index fcb384ec361339..5f3a71206876c6 100644 --- a/llvm/include/llvm/Transforms/Utils/CallPromotionUtils.h +++ b/llvm/include/llvm/Transforms/Utils/CallPromotionUtils.h @@ -14,10 +14,17 @@ #ifndef LLVM_TRANSFORMS_UTILS_CALLPROMOTIONUTILS_H #define LLVM_TRANSFORMS_UTILS_CALLPROMOTIONUTILS_H +#include <cstdint> + +#include "llvm/ADT/ArrayRef.h" + namespace llvm { +class Constant; class CallBase; class CastInst; class Function; +class GlobalVariable; +class Instruction; class MDNode; class Value; @@ -41,7 +48,9 @@ bool isLegalToPromote(const CallBase &CB, Function *Callee, CallBase &promoteCall(CallBase &CB, Function *Callee, CastInst **RetBitCast = nullptr); -/// Promote the given indirect call site to conditionally call \p Callee. +/// Promote the given indirect call site to conditionally call \p Callee. The +/// promoted direct call instruction is predicated on `CB.getCalledOperand() == +/// Callee`. /// /// This function creates an if-then-else structure at the location of the call /// site. The original call site is moved into the "else" block. 
A clone of the @@ -51,6 +60,31 @@ CallBase &promoteCall(CallBase &CB, Function *Callee, CallBase &promoteCallWithIfThenElse(CallBase &CB, Function *Callee, MDNode *BranchWeights = nullptr); +/// This is similar to `promoteCallWithIfThenElse` except that the condition to +/// promote a virtual call is that \p VPtr is the same as any of \p +/// AddressPoints. +/// +/// This function is expected to be used on virtual calls (a subset of indirect +/// calls). \p VPtr is the virtual table address stored in the objects, and +/// \p AddressPoints contains address points of vtables to be compared with. +/// +/// It's the caller's responsibility to guarantee the correctness of the +/// transformation (by specifying \p VPtr and \p AddressPoints properly). +/// +/// This function doesn't sink the instructions that compute the indirect +/// callee into the indirect-call fallback block. Subsequent passes (e.g. +/// instcombine) are expected to sink them when possible, together with any +/// related debug intrinsics. +CallBase &promoteCallWithVTableCmp(CallBase &CB, Instruction *VPtr, + Function *Callee, + ArrayRef<Constant *> AddressPoints, + MDNode *BranchWeights); + +/// Returns a constant GEP into \p VTable that addresses the vtable address +/// point at byte offset \p AddressPointOffset; the offset must be in bounds. +Constant *getVTableAddressPointOffset(GlobalVariable *VTable, + uint32_t AddressPointOffset); + /// Try to promote (devirtualize) a virtual call on an Alloca. Return true on /// success. /// @@ -74,13 +108,17 @@ CallBase &promoteCallWithIfThenElse(CallBase &CB, Function *Callee, /// bool tryPromoteCall(CallBase &CB); +/// Same as versionCallSite, but predicated on the given condition \p Cond. +CallBase &versionCallSiteWithCond(CallBase &CB, Value *Cond, + MDNode *BranchWeights); + /// Predicate and clone the given call site. /// -/// This function creates an if-then-else structure at the location of the call -/// site. 
The "if" condition compares the call site's called value to the given -/// callee. The original call site is moved into the "else" block, and a clone -/// of the call site is placed in the "then" block. The cloned instruction is -/// returned. +/// This function creates an if-then-else structure at the location of the +/// call site. The "if" condition compares the call site's called value to +/// the given callee. The original call site is moved into the "else" block, +/// and a clone of the call site is placed in the "then" block. The cloned +/// instruction is returned. CallBase &versionCallSite(CallBase &CB, Value *Callee, MDNode *BranchWeights); } // end namespace llvm diff --git a/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp b/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp index d0cf0792eface0..ea855b9a4d8416 100644 --- a/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp +++ b/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp @@ -12,9 +12,11 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Utils/CallPromotionUtils.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/Analysis/Loads.h" #include "llvm/Analysis/TypeMetadataUtils.h" #include "llvm/IR/AttributeMask.h" +#include "llvm/IR/Constant.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Instructions.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" @@ -185,6 +187,24 @@ static void createRetBitCast(CallBase &CB, Type *RetTy, CastInst **RetBitCast) { U->replaceUsesOfWith(&CB, Cast); } +// Returns the OR of all values in \p ICmps, combined pairwise.
+static Value *getOrResult(const SmallVector<Value *, 2> &ICmps, + IRBuilder<> &Builder) { + assert(!ICmps.empty() && "Must have at least one icmp instructions"); + if (ICmps.size() == 1) + return ICmps[0]; + // OR adjacent pairs, then recurse on the roughly half-sized result. + SmallVector<Value *, 2> OrResults; + int i = 0, NumICmp = ICmps.size(); + for (i = 0; i + 1 < NumICmp; i += 2) + OrResults.push_back(Builder.CreateOr(ICmps[i], ICmps[i + 1], "icmp-or")); + // With an odd count, the last value carries over to the next level. + if (i < NumICmp) + OrResults.push_back(ICmps[i]); + + return getOrResult(OrResults, Builder); +} + /// Predicate and clone the given call site. /// /// This function creates an if-then-else structure at the location of the call @@ -276,8 +296,8 @@ static void createRetBitCast(CallBase &CB, Type *RetTy, CastInst **RetBitCast) { /// ; The original call instruction stays in its original block. /// %t0 = musttail call i32 %ptr() /// ret %t0 -static CallBase &versionCallSiteWithCond(CallBase &CB, Value *Cond, - MDNode *BranchWeights) { +CallBase &llvm::versionCallSiteWithCond(CallBase &CB, Value *Cond, + MDNode *BranchWeights) { IRBuilder<> Builder(&CB); CallBase *OrigInst = &CB; @@ -565,6 +585,46 @@ CallBase &llvm::promoteCallWithIfThenElse(CallBase &CB, Function *Callee, return promoteCall(NewInst, Callee); } +Constant *llvm::getVTableAddressPointOffset(GlobalVariable *VTable, + uint32_t AddressPointOffset) { + Module &M = *VTable->getParent(); + const DataLayout &DL = M.getDataLayout(); + LLVMContext &Context = M.getContext(); + Type *VTableType = VTable->getValueType(); + assert(AddressPointOffset < DL.getTypeAllocSize(VTableType) && + "Out-of-bound access"); + APInt AddressPointOffsetAPInt(32, AddressPointOffset, false); + SmallVector<APInt> Indices = + DL.getGEPIndicesForOffset(VTableType, AddressPointOffsetAPInt); + SmallVector<llvm::Constant *> GEPIndices; + for (const auto &Index : Indices) + GEPIndices.push_back(llvm::ConstantInt::get(Type::getInt32Ty(Context), + Index.getZExtValue())); + // NOTE(review): AddressPointOffsetAPInt is never checked for a remainder. 
+ return ConstantExpr::getInBoundsGetElementPtr(VTable->getValueType(),
VTable, + GEPIndices); +} + +CallBase &llvm::promoteCallWithVTableCmp(CallBase &CB, Instruction *VPtr, + Function *Callee, + ArrayRef<Constant *> AddressPoints, + MDNode *BranchWeights) { + assert(!AddressPoints.empty() && "Caller should guarantee"); + IRBuilder<> Builder(&CB); + SmallVector<Value *, 2> ICmps; + for (auto &AddressPoint : AddressPoints) + ICmps.push_back(Builder.CreateICmpEQ(VPtr, AddressPoint)); + // Take the direct call if VPtr equals any of the address points. + Value *Cond = getOrResult(ICmps, Builder); + + // Version the indirect call site. If Cond is true, 'NewInst' will be + // executed, otherwise the original call site will be executed. + CallBase &NewInst = versionCallSiteWithCond(CB, Cond, BranchWeights); + + // Promote 'NewInst' so that it directly calls the desired function. + return promoteCall(NewInst, Callee); +} + bool llvm::tryPromoteCall(CallBase &CB) { assert(!CB.getCalledFunction()); Module *M = CB.getCaller()->getParent(); diff --git a/llvm/unittests/Transforms/Utils/CallPromotionUtilsTest.cpp b/llvm/unittests/Transforms/Utils/CallPromotionUtilsTest.cpp index eff8e27d36d641..c57abb54e46849 100644 --- a/llvm/unittests/Transforms/Utils/CallPromotionUtilsTest.cpp +++ b/llvm/unittests/Transforms/Utils/CallPromotionUtilsTest.cpp @@ -8,9 +8,12 @@ #include "llvm/Transforms/Utils/CallPromotionUtils.h" #include "llvm/AsmParser/Parser.h" +#include "llvm/IR/IRBuilder.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/LLVMContext.h" +#include "llvm/IR/MDBuilder.h" #include "llvm/IR/Module.h" +#include "llvm/IR/NoFolder.h" #include "llvm/Support/SourceMgr.h" #include "gtest/gtest.h" @@ -368,3 +371,127 @@ declare %struct2 @_ZN4Impl3RunEv(%class.Impl* %this) bool IsPromoted = tryPromoteCall(*CI); EXPECT_FALSE(IsPromoted); } + +TEST(CallPromotionUtilsTest, getVTableAddressPointOffset) { + LLVMContext C; + std::unique_ptr<Module> M = parseIR(C, + R"IR( +target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128" +target triple = "x86_64-unknown-linux-gnu" + +@_ZTV8Derived2
= constant { [3 x ptr], [3 x ptr], [4 x ptr] } { [3 x ptr] [ptr null, ptr null, ptr @_ZN5Base35func3Ev], [3 x ptr] [ptr inttoptr (i64 -8 to ptr), ptr null, ptr @_ZN5Base25func2Ev], [4 x ptr] [ptr inttoptr (i64 -16 to ptr), ptr null, ptr @_ZN5Base15func0Ev, ptr @_ZN5Base15func1Ev] } + +declare i32 @_ZN5Base15func1Ev(ptr) +declare i32 @_ZN5Base25func2Ev(ptr) +declare i32 @_ZN5Base15func0Ev(ptr) +declare void @_ZN5Base35func3Ev(ptr) +)IR"); + GlobalVariable *GV = M->getGlobalVariable("_ZTV8Derived2"); + + for (auto [AddressPointOffset, Index] : + {std::pair{16, 0}, {40, 1}, {64, 2}}) { + Constant *AddressPoint = + getVTableAddressPointOffset(GV, AddressPointOffset); + // Each offset selects element 2 (a function pointer) of sub-array Index. + ConstantExpr *GEP = dyn_cast<ConstantExpr>(AddressPoint); + ASSERT_TRUE(GEP); + SmallVector<Constant *> Indices = { + llvm::ConstantInt::get(Type::getInt32Ty(C), 0U), + llvm::ConstantInt::get(Type::getInt32Ty(C), Index), + llvm::ConstantInt::get(Type::getInt32Ty(C), 2U)}; + EXPECT_EQ(GEP, ConstantExpr::getInBoundsGetElementPtr(GV->getValueType(), + GV, Indices)); + } +} + +TEST(CallPromotionUtilsTest, promoteCallWithVTableCmp) { + LLVMContext C; + std::unique_ptr<Module> M = parseIR(C, + R"IR( +target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128" +target triple = "x86_64-unknown-linux-gnu" + +@_ZTV5Base1 = constant { [4 x ptr] } { [4 x ptr] [ptr null, ptr null, ptr @_ZN5Base15func0Ev, ptr @_ZN5Base15func1Ev] }, !type !0 +@_ZTV8Derived1 = constant { [4 x ptr], [3 x ptr] } { [4 x ptr] [ptr inttoptr (i64 -8 to ptr), ptr null, ptr @_ZN5Base15func0Ev, ptr @_ZN5Base15func1Ev], [3 x ptr] [ptr null, ptr null, ptr @_ZN5Base25func2Ev] }, !type !1, !type !2, !type !3 +@_ZTV5Base2 = constant { [3 x ptr] } { [3 x ptr] [ptr null, ptr null, ptr @_ZN5Base25func2Ev] }, !type !2 +@_ZTV8Derived2 = constant { [3 x ptr], [3 x ptr], [4 x ptr] } { [3 x ptr] [ptr null, ptr null, ptr @_ZN5Base35func3Ev], [3 x ptr] [ptr inttoptr (i64 -8 to ptr), ptr null, ptr @_ZN5Base25func2Ev], [4 x
ptr] [ptr inttoptr (i64 -16 to ptr), ptr null, ptr @_ZN5Base15func0Ev, ptr @_ZN5Base15func1Ev] }, !type !4, !type !5, !type !6, !type !7 +@_ZTV5Base3 = constant { [3 x ptr] } { [3 x ptr] [ptr null, ptr null, ptr @_ZN5Base35func3Ev] }, !type !6 + +define i32 @testfunc(ptr %d) { +entry: + %vtable = load ptr, ptr %d, !prof !8 + %0 = tail call i1 @llvm.type.test(ptr %vtable, metadata !"_ZTS5Base1") + tail call void @llvm.assume(i1 %0) + %vfn = getelementptr inbounds ptr, ptr %vtable, i64 1 + %1 = load ptr, ptr %vfn + %call = tail call i32 %1(ptr %d), !prof !9 + ret i32 %call +} + +define i32 @_ZN5Base15func1Ev(ptr %this) { +entry: + ret i32 2 +} + + +declare i1 @llvm.type.test(ptr, metadata) +declare void @llvm.assume(i1) +declare i32 @_ZN5Base25func2Ev(ptr) +declare i32 @_ZN5Base15func0Ev(ptr) +declare void @_ZN5Base35func3Ev(ptr) + +!0 = !{i64 16, !"_ZTS5Base1"} +!1 = !{i64 16, !"_ZTS5Base1"} +!2 = !{i64 48, !"_ZTS5Base2"} +!3 = !{i64 16, !"_ZTS8Derived1"} +!4 = !{i64 64, !"_ZTS5Base1"} +!5 = !{i64 40, !"_ZTS5Base2"} +!6 = !{i64 16, !"_ZTS5Base3"} +!7 = !{i64 16, !"_ZTS8Derived2"} +!8 = !{!"VP", i32 2, i64 1600, i64 -9064381665493407289, i64 800, i64 5035968517245772950, i64 500, i64 3215870116411581797, i64 300} +!9 = !{!"VP", i32 0, i64 1600, i64 6804820478065511155, i64 1600})IR"); + + Function *F = M->getFunction("testfunc"); + ASSERT_TRUE(F); + CallInst *CI = dyn_cast<CallInst>(&*std::next(F->front().rbegin())); + ASSERT_TRUE(CI && CI->isIndirectCall()); + + LoadInst *FuncPtr = dyn_cast<LoadInst>(CI->getCalledOperand()); + ASSERT_TRUE(FuncPtr); + + GetElementPtrInst *GEP = + dyn_cast<GetElementPtrInst>(FuncPtr->getPointerOperand()); + ASSERT_TRUE(GEP); + + LoadInst *VPtr = dyn_cast<LoadInst>(&*F->front().begin()); + + Function *Callee = M->getFunction("_ZN5Base15func1Ev"); + + // Create the vtable address points to compare with and the branch weights. + SmallVector<Constant *, 3> VTableAddressPoints; + + for (auto &[VTableName, AddressPointOffset] : {std::pair{"_ZTV5Base1", 16}, +
{"_ZTV8Derived1", 16}, + {"_ZTV8Derived2", 64}}) + VTableAddressPoints.push_back(getVTableAddressPointOffset( + M->getGlobalVariable(VTableName), AddressPointOffset)); + + MDBuilder MDB(C); + MDNode *BranchWeights = MDB.createBranchWeights(1600, 0); + + size_t OrigEntryBBSize = F->front().size(); + + // Tests that the promoted direct call is returned. + CallBase &DirectCB = promoteCallWithVTableCmp( + *CI, VPtr, Callee, VTableAddressPoints, BranchWeights); + EXPECT_EQ(DirectCB.getCalledOperand(), Callee); + + // Tests that GEP and FuncPtr are not sunk: they stay in the entry block. + BasicBlock *EntryBB = &F->front(); + EXPECT_EQ(EntryBB, GEP->getParent()); + EXPECT_EQ(EntryBB, FuncPtr->getParent()); + + // Promotion inserts 3 icmp instructions and 2 or instructions, and removes + // 1 call instruction from the entry block. + EXPECT_EQ(F->front().size(), OrigEntryBBSize + 4); +} _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits