https://github.com/nickitat updated https://github.com/llvm/llvm-project/pull/185087
>From 915a7cb2288f585791aae3261258185198455f80 Mon Sep 17 00:00:00 2001 From: Nikita Taranov <[email protected]> Date: Sat, 24 Jan 2026 23:15:22 +0100 Subject: [PATCH 1/2] impl --- clang/lib/CodeGen/CGExprScalar.cpp | 23 +++ .../CodeGen/devirt-downcast-type-test.cpp | 52 ++++++ .../Transforms/Utils/CallPromotionUtils.cpp | 160 ++++++++++++++---- ...ual_interface_calls_through_static_cast.ll | 129 ++++++++++++++ 4 files changed, 334 insertions(+), 30 deletions(-) create mode 100644 clang/test/CodeGen/devirt-downcast-type-test.cpp create mode 100644 llvm/test/Transforms/Inline/devirtualize_non_virtual_interface_calls_through_static_cast.ll diff --git a/clang/lib/CodeGen/CGExprScalar.cpp b/clang/lib/CodeGen/CGExprScalar.cpp index 06eadb6c07507..935c3a9e5c0c7 100644 --- a/clang/lib/CodeGen/CGExprScalar.cpp +++ b/clang/lib/CodeGen/CGExprScalar.cpp @@ -2827,6 +2827,29 @@ Value *ScalarExprEmitter::VisitCastExpr(CastExpr *CE) { CodeGenFunction::CFITCK_DerivedCast, CE->getBeginLoc()); + // Propagate static_cast<Derived*> type information to the middle-end via + // llvm.type.test + llvm.assume. The programmer's downcast asserts (with UB + // if violated) that the object is of the derived type, so we record this as + // a type assumption that devirtualization passes can exploit. + // + // We use the BASE/SOURCE object pointer (not the vtable pointer) as the + // type.test argument so that tryPromoteCall can find it immediately after + // inlining the callee: after inlining, the vtable is loaded from the same + // SSA value (the original object pointer), making the type.test findable by + // scanning uses of the object pointer. + if (DerivedClassDecl->isPolymorphic() && + DerivedClassDecl->isEffectivelyFinal()) { + llvm::Value *BasePtr = Base.emitRawPointer(CGF); + CanQualType Ty = CGF.CGM.getContext().getCanonicalTagType(DerivedClassDecl); + llvm::Metadata *MD = CGF.CGM.CreateMetadataIdentifierForType(Ty); + llvm::Value *TypeId = + llvm::MetadataAsValue::get(CGF.CGM.getLLVMContext(), MD); + llvm::Value *TypeTest = CGF.Builder.CreateCall( + CGF.CGM.getIntrinsic(llvm::Intrinsic::type_test), {BasePtr, TypeId}); + CGF.Builder.CreateCall(CGF.CGM.getIntrinsic(llvm::Intrinsic::assume), + TypeTest); + } + return CGF.getAsNaturalPointerTo(Derived, CE->getType()->getPointeeType()); } case CK_UncheckedDerivedToBase: diff --git a/clang/test/CodeGen/devirt-downcast-type-test.cpp b/clang/test/CodeGen/devirt-downcast-type-test.cpp new file mode 100644 index 0000000000000..877c1dc140f70 --- /dev/null +++ b/clang/test/CodeGen/devirt-downcast-type-test.cpp @@ -0,0 +1,52 @@ +// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -std=c++17 -emit-llvm -o - %s | FileCheck %s +// +// Test that Clang emits llvm.type.test+llvm.assume on the object pointer at +// CK_BaseToDerived (static_cast<Derived*>) cast sites when the derived class +// is polymorphic and effectively final. This annotation allows the LLVM inliner +// (tryPromoteCall) to devirtualize virtual calls through the downcast pointer +// without requiring a visible vtable store. + +struct Base { + virtual void doFoo(); + void foo() { doFoo(); } +}; + +struct Derived final : Base { + void doFoo() override; +}; + +// static_cast to a final polymorphic derived class: type.test must be emitted. +void f(Base *b) { + static_cast<Derived *>(b)->foo(); +} + +// CHECK-LABEL: define {{.*}} @_Z1fP4Base( +// CHECK: [[LOADED:%[0-9]+]] = load ptr, ptr %b.addr +// CHECK-NEXT: [[TT:%[0-9]+]] = call i1 @llvm.type.test(ptr [[LOADED]], metadata !"_ZTS7Derived") +// CHECK-NEXT: call void @llvm.assume(i1 [[TT]]) + +struct NonPolyBase {}; +struct NonPolyDerived : NonPolyBase {}; + +// static_cast to a non-polymorphic derived class: no type.test should be emitted. +NonPolyDerived *g(NonPolyBase *b) { + return static_cast<NonPolyDerived *>(b); +} + +// CHECK-LABEL: define {{.*}} @_Z1gP11NonPolyBase( +// CHECK-NOT: llvm.type.test +// CHECK: ret ptr + +struct NonFinalDerived : Base { + void doFoo() override; +}; + +// static_cast to a non-final polymorphic derived class: no type.test should be +// emitted (the object could be a further-derived subclass with a different vtable). +void h(Base *b) { + static_cast<NonFinalDerived *>(b)->foo(); +} + +// CHECK-LABEL: define {{.*}} @_Z1hP4Base( +// CHECK-NOT: llvm.type.test +// CHECK: ret void diff --git a/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp b/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp index f0f9add09bf82..f9f5bd3c95b44 100644 --- a/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp +++ b/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp @@ -17,6 +17,7 @@ #include "llvm/Analysis/TypeMetadataUtils.h" #include "llvm/IR/AttributeMask.h" #include "llvm/IR/Constant.h" +#include "llvm/IR/Dominators.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" @@ -682,60 +683,159 @@ CallBase &llvm::promoteCallWithVTableCmp(CallBase &CB, Instruction *VPtr, return promoteCall(NewInst, Callee); } +// Try to devirtualize an indirect virtual call using +// llvm.type.test + llvm.assume pairs that were emitted by Clang, +// e.g., at static_cast<Derived*> downcast sites. +// Such a cast is a programmer assertion (UB if wrong) that the object is of +// type Derived. Clang records this as: +// %tt = call i1 @llvm.type.test(ptr %base_obj, metadata !"_ZTS4Derived") +// call void @llvm.assume(i1 %tt) +// +// After the callee is inlined into the caller, the vtable is loaded from the +// same object pointer (%base_obj). By scanning uses of Object (the vtable-load +// source) for such type.test calls, we can determine the concrete vtable and +// resolve the virtual call to a direct call. +static bool tryDevirtualizeViaTypeTestAssume(CallBase &CB, Value *Object, + APInt VTableOffset, + const DataLayout &DL, Module &M) { + // Build a dominator tree so we can verify that all execution paths to the + // call (CB) must go through the assume that uses the type.test result. + DominatorTree DT(*CB.getFunction()); + + for (User *U : Object->users()) { + auto *TypeTestCI = dyn_cast<CallInst>(U); + if (!TypeTestCI || TypeTestCI->getIntrinsicID() != Intrinsic::type_test) + continue; + // The type.test must use Object as its pointer argument. + if (TypeTestCI->getArgOperand(0) != Object) + continue; + + // There must be a dominating llvm.assume consuming the type.test result. + bool HasDominatingAssume = false; + for (User *TU : TypeTestCI->users()) { + if (auto *Assume = dyn_cast<AssumeInst>(TU); + Assume && DT.dominates(Assume, &CB)) { + HasDominatingAssume = true; + break; + } + } + if (!HasDominatingAssume) + continue; + + // Extract the type metadata identifier, e.g. MDString "_ZTS4Impl". + Metadata *TypeId = + cast<MetadataAsValue>(TypeTestCI->getArgOperand(1))->getMetadata(); + + // Vtable lookup via !type metadata. + // We require exactly one matching vtable — if multiple vtables carry the + // same type ID the type is not effectively final and we cannot safely + // devirtualize (the object could be a further-derived subclass). + GlobalVariable *MatchedVTable = nullptr; + uint64_t MatchedAddrPointOffset = 0; + bool Ambiguous = false; + for (GlobalVariable &GV : M.globals()) { + if (!GV.isConstant() || !GV.hasDefinitiveInitializer()) + continue; + SmallVector<MDNode *, 2> Types; + GV.getMetadata(LLVMContext::MD_type, Types); + for (MDNode *TypeMD : Types) { + if (TypeMD->getNumOperands() < 2) + continue; + if (TypeMD->getOperand(1).get() != TypeId) + continue; + auto *OffsetCmd = + dyn_cast<ConstantAsMetadata>(TypeMD->getOperand(0)); + if (!OffsetCmd) + continue; + if (MatchedVTable) { + Ambiguous = true; + break; + } + MatchedVTable = &GV; + MatchedAddrPointOffset = + cast<ConstantInt>(OffsetCmd->getValue())->getZExtValue(); + } + if (Ambiguous) + break; + } + if (MatchedVTable && !Ambiguous) { + if (VTableOffset.getActiveBits() > 64) + continue; + uint64_t TotalOffset = + MatchedAddrPointOffset + VTableOffset.getZExtValue(); + auto [DirectCallee, _] = + getFunctionAtVTableOffset(MatchedVTable, TotalOffset, M); + if (DirectCallee && isLegalToPromote(CB, DirectCallee)) { + promoteCall(CB, DirectCallee); + return true; + } + } + } + return false; +} + bool llvm::tryPromoteCall(CallBase &CB) { assert(!CB.getCalledFunction()); Module *M = CB.getCaller()->getParent(); const DataLayout &DL = M->getDataLayout(); Value *Callee = CB.getCalledOperand(); + // We expect the indirect callee to be a function pointer loaded from a vtable + // slot, which is itself a getelementptr into the vtable, which is loaded from + // the object's vptr field. The chain is: + // %obj = ... (alloca or argument) + // %vtable = load ptr, ptr %obj (VTablePtrLoad) + // %vfn_slot = GEP ptr %vtable, i64 N (VTableEntryPtr, VTableOffset) + // %fn = load ptr, ptr %vfn_slot (VTableEntryLoad) LoadInst *VTableEntryLoad = dyn_cast<LoadInst>(Callee); if (!VTableEntryLoad) - return false; // Not a vtable entry load. + return false; Value *VTableEntryPtr = VTableEntryLoad->getPointerOperand(); APInt VTableOffset(DL.getIndexTypeSizeInBits(VTableEntryPtr->getType()), 0); Value *VTableBasePtr = VTableEntryPtr->stripAndAccumulateConstantOffsets( DL, VTableOffset, /* AllowNonInbounds */ true); LoadInst *VTablePtrLoad = dyn_cast<LoadInst>(VTableBasePtr); if (!VTablePtrLoad) - return false; // Not a vtable load. + return false; Value *Object = VTablePtrLoad->getPointerOperand(); APInt ObjectOffset(DL.getIndexTypeSizeInBits(Object->getType()), 0); Value *ObjectBase = Object->stripAndAccumulateConstantOffsets( DL, ObjectOffset, /* AllowNonInbounds */ true); - if (!(isa<AllocaInst>(ObjectBase) && ObjectOffset == 0)) - // Not an Alloca or the offset isn't zero. - return false; - // Look for the vtable pointer store into the object by the ctor. - BasicBlock::iterator BBI(VTablePtrLoad); - Value *VTablePtr = FindAvailableLoadedValue( - VTablePtrLoad, VTablePtrLoad->getParent(), BBI, 0, nullptr, nullptr); - if (!VTablePtr || !VTablePtr->getType()->isPointerTy()) - return false; // No vtable found. - APInt VTableOffsetGVBase(DL.getIndexTypeSizeInBits(VTablePtr->getType()), 0); - Value *VTableGVBase = VTablePtr->stripAndAccumulateConstantOffsets( - DL, VTableOffsetGVBase, /* AllowNonInbounds */ true); - GlobalVariable *GV = dyn_cast<GlobalVariable>(VTableGVBase); - if (!(GV && GV->isConstant() && GV->hasDefinitiveInitializer())) - // Not in the form of a global constant variable with an initializer. + if (ObjectOffset != 0) return false; - APInt VTableGVOffset = VTableOffsetGVBase + VTableOffset; - if (!(VTableGVOffset.getActiveBits() <= 64)) - return false; // Out of range. + if (isa<AllocaInst>(ObjectBase)) { + // Look for a store of a concrete vtable pointer to the vptr field; + // this is set by the copy/move constructor when the object was materialised + // locally. + BasicBlock::iterator BBI(VTablePtrLoad); + Value *VTablePtr = FindAvailableLoadedValue( + VTablePtrLoad, VTablePtrLoad->getParent(), BBI, 0, nullptr, nullptr); + if (!VTablePtr || !VTablePtr->getType()->isPointerTy()) + return false; - Function *DirectCallee = nullptr; - std::tie(DirectCallee, std::ignore) = - getFunctionAtVTableOffset(GV, VTableGVOffset.getZExtValue(), *M); - if (!DirectCallee) - return false; // No function pointer found. + APInt VTableOffsetGVBase(DL.getIndexTypeSizeInBits(VTablePtr->getType()), + 0); + Value *VTableGVBase = VTablePtr->stripAndAccumulateConstantOffsets( + DL, VTableOffsetGVBase, /* AllowNonInbounds */ true); + GlobalVariable *GV = dyn_cast<GlobalVariable>(VTableGVBase); + if (!(GV && GV->isConstant() && GV->hasDefinitiveInitializer())) + return false; - if (!isLegalToPromote(CB, DirectCallee)) - return false; + APInt VTableGVOffset = VTableOffsetGVBase + VTableOffset; + if (VTableGVOffset.getActiveBits() > 64) + return false; - // Success. - promoteCall(CB, DirectCallee); - return true; + auto [DirectCallee, _] = + getFunctionAtVTableOffset(GV, VTableGVOffset.getZExtValue(), *M); + if (!DirectCallee || !isLegalToPromote(CB, DirectCallee)) + return false; + + promoteCall(CB, DirectCallee); + return true; + } + return tryDevirtualizeViaTypeTestAssume(CB, Object, VTableOffset, DL, *M); } #undef DEBUG_TYPE diff --git a/llvm/test/Transforms/Inline/devirtualize_non_virtual_interface_calls_through_static_cast.ll b/llvm/test/Transforms/Inline/devirtualize_non_virtual_interface_calls_through_static_cast.ll new file mode 100644 index 0000000000000..c325b550ed811 --- /dev/null +++ b/llvm/test/Transforms/Inline/devirtualize_non_virtual_interface_calls_through_static_cast.ll @@ -0,0 +1,129 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 6 +; RUN: opt < %s -passes='cgscc(inline),function(sroa,gvn)' -S | FileCheck %s +; +; Test: devirtualization of virtual calls through static_cast downcast pointers +; via llvm.type.test+assume infrastructure. +; +; At the static_cast<Impl*>(intf) site, Clang emits: +; %tt = llvm.type.test(ptr %intf, !"_ZTS4Impl") +; llvm.assume(%tt) +; +; After the inliner inlines Intf::foo(), tryPromoteCall finds the type.test on +; the object pointer, matches !"_ZTS4Impl" against !type metadata on @_ZTV4Impl, +; resolves the vtable slot, and promotes the indirect call to @_ZN4Impl5doFooEv. +; +; Generated from the following C++ source with: +; clang++ -O0 -flto -fwhole-program-vtables -S -emit-llvm file.cc +; then hand-simplified. +; +; -flto is required for -fwhole-program-vtables, which causes Clang to emit +; !type metadata on vtable globals. The vtable definition must be in the same +; module for tryPromoteCall to resolve it; in a multi-TU build this happens at +; LTO link time when all modules are merged. +; +; C++ source: +; +; int glob = 0; +; int secretValue = 42; +; +; struct Intf { +; void foo() { this->doFoo(); } +; virtual void doFoo(); +; }; +; +; struct Impl final : Intf { +; void doFoo() override { glob = secretValue; } +; }; +; +; void f(Intf *intf) { static_cast<Impl *>(intf)->foo(); } + + +%struct.Impl = type { %struct.Intf } +%struct.Intf = type { ptr } + +@glob = dso_local global i32 0, align 4 +@secretValue = dso_local global i32 42, align 4 +@_ZTV4Impl = linkonce_odr unnamed_addr constant { [3 x ptr] } { [3 x ptr] [ptr null, ptr null, ptr @_ZN4Impl5doFooEv] }, align 8, !type !0 + +; f(Intf *intf) +define dso_local void @_Z1fP4Intf(ptr noundef %intf) { +; CHECK-LABEL: define dso_local void @_Z1fP4Intf( +; CHECK-SAME: ptr noundef [[INTF:%.*]]) { +; CHECK-NEXT: [[ENTRY:.*:]] +; CHECK-NEXT: [[TMP0:%.*]] = call i1 @llvm.type.test(ptr [[INTF]], metadata !"_ZTS4Impl") +; CHECK-NEXT: call void @llvm.assume(i1 [[TMP0]]) +; CHECK-NEXT: [[TMP1:%.*]] = load i32, ptr @secretValue, align 4 +; CHECK-NEXT: store i32 [[TMP1]], ptr @glob, align 4 +; CHECK-NEXT: ret void +; +entry: + %0 = call i1 @llvm.type.test(ptr %intf, metadata !"_ZTS4Impl") + call void @llvm.assume(i1 %0) + call void @_ZN4Intf3fooEv(ptr noundef nonnull align 8 dereferenceable(8) %intf) + ret void +} + +; Negative test: the assume does NOT dominate the call to foo() because it is +; only on one side of a branch. tryPromoteCall must not devirtualize here. +define dso_local void @non_dominating_assume(ptr noundef %intf, i1 %cond) { +; CHECK-LABEL: define dso_local void @non_dominating_assume( +; CHECK-SAME: ptr noundef [[INTF:%.*]], i1 [[COND:%.*]]) { +; CHECK-NEXT: [[ENTRY:.*:]] +; CHECK-NEXT: [[TMP0:%.*]] = call i1 @llvm.type.test(ptr [[INTF]], metadata !"_ZTS4Impl") +; CHECK-NEXT: br i1 [[COND]], label %[[THEN:.*]], label %[[ELSE:.*]] +; CHECK: [[THEN]]: +; CHECK-NEXT: call void @llvm.assume(i1 [[TMP0]]) +; CHECK-NEXT: br label %[[MERGE:.*]] +; CHECK: [[ELSE]]: +; CHECK-NEXT: br label %[[MERGE]] +; CHECK: [[MERGE]]: +; CHECK-NEXT: [[VTABLE_I:%.*]] = load ptr, ptr [[INTF]], align 8 +; CHECK-NEXT: [[TMP1:%.*]] = load ptr, ptr [[VTABLE_I]], align 8 +; CHECK-NEXT: call void [[TMP1]](ptr noundef nonnull align 8 dereferenceable(8) [[INTF]]) +; CHECK-NEXT: ret void +; +entry: + %0 = call i1 @llvm.type.test(ptr %intf, metadata !"_ZTS4Impl") + br i1 %cond, label %then, label %else + +then: + call void @llvm.assume(i1 %0) + br label %merge + +else: + br label %merge + +merge: + call void @_ZN4Intf3fooEv(ptr noundef nonnull align 8 dereferenceable(8) %intf) + ret void +} + +; Intf::foo() - non-virtual wrapper that makes the virtual call +define linkonce_odr void @_ZN4Intf3fooEv(ptr noundef nonnull align 8 dereferenceable(8) %this) align 2 { +entry: + %vtable = load ptr, ptr %this, align 8 + %vfn = getelementptr inbounds ptr, ptr %vtable, i64 0 + %0 = load ptr, ptr %vfn, align 8 + call void %0(ptr noundef nonnull align 8 dereferenceable(8) %this) + ret void +} + +; Impl::doFoo() +define linkonce_odr void @_ZN4Impl5doFooEv(ptr noundef nonnull align 8 dereferenceable(8) %this) unnamed_addr align 2 { +; CHECK-LABEL: define linkonce_odr void @_ZN4Impl5doFooEv( +; CHECK-SAME: ptr noundef nonnull align 8 dereferenceable(8) [[THIS:%.*]]) unnamed_addr align 2 { +; CHECK-NEXT: [[ENTRY:.*:]] +; CHECK-NEXT: [[TMP0:%.*]] = load i32, ptr @secretValue, align 4 +; CHECK-NEXT: store i32 [[TMP0]], ptr @glob, align 4 +; CHECK-NEXT: ret void +; +entry: + %0 = load i32, ptr @secretValue, align 4 + store i32 %0, ptr @glob, align 4 + ret void +} + +; !type metadata: maps type ID !"_ZTS4Impl" to the vtable address point at byte +; offset 16 (past offset-to-top and RTTI pointer). Emitted by Clang with +; -fwhole-program-vtables. +!0 = !{i64 16, !"_ZTS4Impl"} >From d2f65f300c9231c5ab509c997ea6b184f5c53a2a Mon Sep 17 00:00:00 2001 From: Nikita Taranov <[email protected]> Date: Fri, 6 Mar 2026 19:46:34 +0000 Subject: [PATCH 2/2] fix style --- clang/lib/CodeGen/CGExprScalar.cpp | 3 ++- llvm/lib/Transforms/Utils/CallPromotionUtils.cpp | 3 +-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/clang/lib/CodeGen/CGExprScalar.cpp b/clang/lib/CodeGen/CGExprScalar.cpp index 935c3a9e5c0c7..dea027561119a 100644 --- a/clang/lib/CodeGen/CGExprScalar.cpp +++ b/clang/lib/CodeGen/CGExprScalar.cpp @@ -2840,7 +2840,8 @@ Value *ScalarExprEmitter::VisitCastExpr(CastExpr *CE) { if (DerivedClassDecl->isPolymorphic() && DerivedClassDecl->isEffectivelyFinal()) { llvm::Value *BasePtr = Base.emitRawPointer(CGF); - CanQualType Ty = CGF.CGM.getContext().getCanonicalTagType(DerivedClassDecl); + CanQualType Ty = + CGF.CGM.getContext().getCanonicalTagType(DerivedClassDecl); llvm::Metadata *MD = CGF.CGM.CreateMetadataIdentifierForType(Ty); llvm::Value *TypeId = llvm::MetadataAsValue::get(CGF.CGM.getLLVMContext(), MD); diff --git a/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp b/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp index f9f5bd3c95b44..d2358a0b11338 100644 --- a/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp +++ b/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp @@ -743,8 +743,7 @@ static bool tryDevirtualizeViaTypeTestAssume(CallBase &CB, Value *Object, continue; if (TypeMD->getOperand(1).get() != TypeId) continue; - auto *OffsetCmd = - dyn_cast<ConstantAsMetadata>(TypeMD->getOperand(0)); + auto *OffsetCmd = dyn_cast<ConstantAsMetadata>(TypeMD->getOperand(0)); if (!OffsetCmd) continue; if (MatchedVTable) { _______________________________________________ cfe-commits mailing list [email protected] https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits
