[clang] [llvm] [RFC] Better devirtualization for non-virtual interface (PR #185087)
Nikita Taranov via cfe-commits
cfe-commits at lists.llvm.org
Sat Mar 7 09:07:17 PST 2026
https://github.com/nickitat updated https://github.com/llvm/llvm-project/pull/185087
>From 915a7cb2288f585791aae3261258185198455f80 Mon Sep 17 00:00:00 2001
From: Nikita Taranov <nikita.taranov at clickhouse.com>
Date: Sat, 24 Jan 2026 23:15:22 +0100
Subject: [PATCH 1/3] impl
---
clang/lib/CodeGen/CGExprScalar.cpp | 23 +++
.../CodeGen/devirt-downcast-type-test.cpp | 52 ++++++
.../Transforms/Utils/CallPromotionUtils.cpp | 160 ++++++++++++++----
...ual_interface_calls_through_static_cast.ll | 129 ++++++++++++++
4 files changed, 334 insertions(+), 30 deletions(-)
create mode 100644 clang/test/CodeGen/devirt-downcast-type-test.cpp
create mode 100644 llvm/test/Transforms/Inline/devirtualize_non_virtual_interface_calls_through_static_cast.ll
diff --git a/clang/lib/CodeGen/CGExprScalar.cpp b/clang/lib/CodeGen/CGExprScalar.cpp
index 06eadb6c07507..935c3a9e5c0c7 100644
--- a/clang/lib/CodeGen/CGExprScalar.cpp
+++ b/clang/lib/CodeGen/CGExprScalar.cpp
@@ -2827,6 +2827,29 @@ Value *ScalarExprEmitter::VisitCastExpr(CastExpr *CE) {
CodeGenFunction::CFITCK_DerivedCast,
CE->getBeginLoc());
+ // Propagate static_cast<Derived*> type information to the middle-end via
+ // llvm.type.test + llvm.assume. The programmer's downcast asserts (with UB
+ // if violated) that the object is of the derived type, so we record this as
+ // a type assumption that devirtualization passes can exploit.
+ //
+ // We use the BASE/SOURCE object pointer (not the vtable pointer) as the
+ // type.test argument so that tryPromoteCall can find it immediately after
+ // inlining the callee: after inlining, the vtable is loaded from the same
+ // SSA value (the original object pointer), making the type.test findable by
+ // scanning uses of the object pointer.
+ if (DerivedClassDecl->isPolymorphic() &&
+ DerivedClassDecl->isEffectivelyFinal()) {
+ llvm::Value *BasePtr = Base.emitRawPointer(CGF);
+ CanQualType Ty = CGF.CGM.getContext().getCanonicalTagType(DerivedClassDecl);
+ llvm::Metadata *MD = CGF.CGM.CreateMetadataIdentifierForType(Ty);
+ llvm::Value *TypeId =
+ llvm::MetadataAsValue::get(CGF.CGM.getLLVMContext(), MD);
+ llvm::Value *TypeTest = CGF.Builder.CreateCall(
+ CGF.CGM.getIntrinsic(llvm::Intrinsic::type_test), {BasePtr, TypeId});
+ CGF.Builder.CreateCall(CGF.CGM.getIntrinsic(llvm::Intrinsic::assume),
+ TypeTest);
+ }
+
return CGF.getAsNaturalPointerTo(Derived, CE->getType()->getPointeeType());
}
case CK_UncheckedDerivedToBase:
diff --git a/clang/test/CodeGen/devirt-downcast-type-test.cpp b/clang/test/CodeGen/devirt-downcast-type-test.cpp
new file mode 100644
index 0000000000000..877c1dc140f70
--- /dev/null
+++ b/clang/test/CodeGen/devirt-downcast-type-test.cpp
@@ -0,0 +1,52 @@
+// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -std=c++17 -emit-llvm -o - %s | FileCheck %s
+//
+// Test that Clang emits llvm.type.test+llvm.assume on the object pointer at
+// CK_BaseToDerived (static_cast<Derived*>) cast sites when the derived class
+// is polymorphic and effectively final. This annotation allows the LLVM inliner
+// (tryPromoteCall) to devirtualize virtual calls through the downcast pointer
+// without requiring a visible vtable store.
+
+struct Base {
+ virtual void doFoo();
+ void foo() { doFoo(); }
+};
+
+struct Derived final : Base {
+ void doFoo() override;
+};
+
+// static_cast to a final polymorphic derived class: type.test must be emitted.
+void f(Base *b) {
+ static_cast<Derived *>(b)->foo();
+}
+
+// CHECK-LABEL: define {{.*}} @_Z1fP4Base(
+// CHECK: [[LOADED:%[0-9]+]] = load ptr, ptr %b.addr
+// CHECK-NEXT: [[TT:%[0-9]+]] = call i1 @llvm.type.test(ptr [[LOADED]], metadata !"_ZTS7Derived")
+// CHECK-NEXT: call void @llvm.assume(i1 [[TT]])
+
+struct NonPolyBase {};
+struct NonPolyDerived : NonPolyBase {};
+
+// static_cast to a non-polymorphic derived class: no type.test should be emitted.
+NonPolyDerived *g(NonPolyBase *b) {
+ return static_cast<NonPolyDerived *>(b);
+}
+
+// CHECK-LABEL: define {{.*}} @_Z1gP11NonPolyBase(
+// CHECK-NOT: llvm.type.test
+// CHECK: ret ptr
+
+struct NonFinalDerived : Base {
+ void doFoo() override;
+};
+
+// static_cast to a non-final polymorphic derived class: no type.test should be
+// emitted (the object could be a further-derived subclass with a different vtable).
+void h(Base *b) {
+ static_cast<NonFinalDerived *>(b)->foo();
+}
+
+// CHECK-LABEL: define {{.*}} @_Z1hP4Base(
+// CHECK-NOT: llvm.type.test
+// CHECK: ret void
diff --git a/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp b/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp
index f0f9add09bf82..f9f5bd3c95b44 100644
--- a/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp
+++ b/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp
@@ -17,6 +17,7 @@
#include "llvm/Analysis/TypeMetadataUtils.h"
#include "llvm/IR/AttributeMask.h"
#include "llvm/IR/Constant.h"
+#include "llvm/IR/Dominators.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
@@ -682,60 +683,159 @@ CallBase &llvm::promoteCallWithVTableCmp(CallBase &CB, Instruction *VPtr,
return promoteCall(NewInst, Callee);
}
+// Try to devirtualize an indirect virtual call using
+// llvm.type.test + llvm.assume pairs that were emitted by Clang,
+// e.g., at static_cast<Derived*> downcast sites.
+// Such a cast is a programmer assertion (UB if wrong) that the object is of
+// type Derived. Clang records this as:
+// %tt = call i1 @llvm.type.test(ptr %base_obj, metadata !"_ZTS4Derived")
+// call void @llvm.assume(i1 %tt)
+//
+// After the callee is inlined into the caller, the vtable is loaded from the
+// same object pointer (%base_obj). By scanning uses of Object (the vtable-load
+// source) for such type.test calls, we can determine the concrete vtable and
+// resolve the virtual call to a direct call.
+static bool tryDevirtualizeViaTypeTestAssume(CallBase &CB, Value *Object,
+ APInt VTableOffset,
+ const DataLayout &DL, Module &M) {
+ // Build a dominator tree so we can verify that all execution paths to the
+ // call (CB) must go through the assume that uses the type.test result.
+ DominatorTree DT(*CB.getFunction());
+
+ for (User *U : Object->users()) {
+ auto *TypeTestCI = dyn_cast<CallInst>(U);
+ if (!TypeTestCI || TypeTestCI->getIntrinsicID() != Intrinsic::type_test)
+ continue;
+ // The type.test must use Object as its pointer argument.
+ if (TypeTestCI->getArgOperand(0) != Object)
+ continue;
+
+ // There must be a dominating llvm.assume consuming the type.test result.
+ bool HasDominatingAssume = false;
+ for (User *TU : TypeTestCI->users()) {
+ if (auto *Assume = dyn_cast<AssumeInst>(TU);
+ Assume && DT.dominates(Assume, &CB)) {
+ HasDominatingAssume = true;
+ break;
+ }
+ }
+ if (!HasDominatingAssume)
+ continue;
+
+ // Extract the type metadata identifier, e.g. MDString "_ZTS4Impl".
+ Metadata *TypeId =
+ cast<MetadataAsValue>(TypeTestCI->getArgOperand(1))->getMetadata();
+
+ // Vtable lookup via !type metadata.
+ // We require exactly one matching vtable — if multiple vtables carry the
+ // same type ID the type is not effectively final and we cannot safely
+ // devirtualize (the object could be a further-derived subclass).
+ GlobalVariable *MatchedVTable = nullptr;
+ uint64_t MatchedAddrPointOffset = 0;
+ bool Ambiguous = false;
+ for (GlobalVariable &GV : M.globals()) {
+ if (!GV.isConstant() || !GV.hasDefinitiveInitializer())
+ continue;
+ SmallVector<MDNode *, 2> Types;
+ GV.getMetadata(LLVMContext::MD_type, Types);
+ for (MDNode *TypeMD : Types) {
+ if (TypeMD->getNumOperands() < 2)
+ continue;
+ if (TypeMD->getOperand(1).get() != TypeId)
+ continue;
+ auto *OffsetCmd =
+ dyn_cast<ConstantAsMetadata>(TypeMD->getOperand(0));
+ if (!OffsetCmd)
+ continue;
+ if (MatchedVTable) {
+ Ambiguous = true;
+ break;
+ }
+ MatchedVTable = &GV;
+ MatchedAddrPointOffset =
+ cast<ConstantInt>(OffsetCmd->getValue())->getZExtValue();
+ }
+ if (Ambiguous)
+ break;
+ }
+ if (MatchedVTable && !Ambiguous) {
+ if (VTableOffset.getActiveBits() > 64)
+ continue;
+ uint64_t TotalOffset =
+ MatchedAddrPointOffset + VTableOffset.getZExtValue();
+ auto [DirectCallee, _] =
+ getFunctionAtVTableOffset(MatchedVTable, TotalOffset, M);
+ if (DirectCallee && isLegalToPromote(CB, DirectCallee)) {
+ promoteCall(CB, DirectCallee);
+ return true;
+ }
+ }
+ }
+ return false;
+}
+
bool llvm::tryPromoteCall(CallBase &CB) {
assert(!CB.getCalledFunction());
Module *M = CB.getCaller()->getParent();
const DataLayout &DL = M->getDataLayout();
Value *Callee = CB.getCalledOperand();
+ // We expect the indirect callee to be a function pointer loaded from a vtable
+ // slot, which is itself a getelementptr into the vtable, which is loaded from
+ // the object's vptr field. The chain is:
+ // %obj = ... (alloca or argument)
+ // %vtable = load ptr, ptr %obj (VTablePtrLoad)
+ // %vfn_slot = GEP ptr %vtable, i64 N (VTableEntryPtr, VTableOffset)
+ // %fn = load ptr, ptr %vfn_slot (VTableEntryLoad)
LoadInst *VTableEntryLoad = dyn_cast<LoadInst>(Callee);
if (!VTableEntryLoad)
- return false; // Not a vtable entry load.
+ return false;
Value *VTableEntryPtr = VTableEntryLoad->getPointerOperand();
APInt VTableOffset(DL.getIndexTypeSizeInBits(VTableEntryPtr->getType()), 0);
Value *VTableBasePtr = VTableEntryPtr->stripAndAccumulateConstantOffsets(
DL, VTableOffset, /* AllowNonInbounds */ true);
LoadInst *VTablePtrLoad = dyn_cast<LoadInst>(VTableBasePtr);
if (!VTablePtrLoad)
- return false; // Not a vtable load.
+ return false;
Value *Object = VTablePtrLoad->getPointerOperand();
APInt ObjectOffset(DL.getIndexTypeSizeInBits(Object->getType()), 0);
Value *ObjectBase = Object->stripAndAccumulateConstantOffsets(
DL, ObjectOffset, /* AllowNonInbounds */ true);
- if (!(isa<AllocaInst>(ObjectBase) && ObjectOffset == 0))
- // Not an Alloca or the offset isn't zero.
- return false;
- // Look for the vtable pointer store into the object by the ctor.
- BasicBlock::iterator BBI(VTablePtrLoad);
- Value *VTablePtr = FindAvailableLoadedValue(
- VTablePtrLoad, VTablePtrLoad->getParent(), BBI, 0, nullptr, nullptr);
- if (!VTablePtr || !VTablePtr->getType()->isPointerTy())
- return false; // No vtable found.
- APInt VTableOffsetGVBase(DL.getIndexTypeSizeInBits(VTablePtr->getType()), 0);
- Value *VTableGVBase = VTablePtr->stripAndAccumulateConstantOffsets(
- DL, VTableOffsetGVBase, /* AllowNonInbounds */ true);
- GlobalVariable *GV = dyn_cast<GlobalVariable>(VTableGVBase);
- if (!(GV && GV->isConstant() && GV->hasDefinitiveInitializer()))
- // Not in the form of a global constant variable with an initializer.
+ if (ObjectOffset != 0)
return false;
- APInt VTableGVOffset = VTableOffsetGVBase + VTableOffset;
- if (!(VTableGVOffset.getActiveBits() <= 64))
- return false; // Out of range.
+ if (isa<AllocaInst>(ObjectBase)) {
+ // Look for a store of a concrete vtable pointer to the vptr field;
+ // this is set by the copy/move constructor when the object was materialized
+ // locally.
+ BasicBlock::iterator BBI(VTablePtrLoad);
+ Value *VTablePtr = FindAvailableLoadedValue(
+ VTablePtrLoad, VTablePtrLoad->getParent(), BBI, 0, nullptr, nullptr);
+ if (!VTablePtr || !VTablePtr->getType()->isPointerTy())
+ return false;
- Function *DirectCallee = nullptr;
- std::tie(DirectCallee, std::ignore) =
- getFunctionAtVTableOffset(GV, VTableGVOffset.getZExtValue(), *M);
- if (!DirectCallee)
- return false; // No function pointer found.
+ APInt VTableOffsetGVBase(DL.getIndexTypeSizeInBits(VTablePtr->getType()),
+ 0);
+ Value *VTableGVBase = VTablePtr->stripAndAccumulateConstantOffsets(
+ DL, VTableOffsetGVBase, /* AllowNonInbounds */ true);
+ GlobalVariable *GV = dyn_cast<GlobalVariable>(VTableGVBase);
+ if (!(GV && GV->isConstant() && GV->hasDefinitiveInitializer()))
+ return false;
- if (!isLegalToPromote(CB, DirectCallee))
- return false;
+ APInt VTableGVOffset = VTableOffsetGVBase + VTableOffset;
+ if (VTableGVOffset.getActiveBits() > 64)
+ return false;
- // Success.
- promoteCall(CB, DirectCallee);
- return true;
+ auto [DirectCallee, _] =
+ getFunctionAtVTableOffset(GV, VTableGVOffset.getZExtValue(), *M);
+ if (!DirectCallee || !isLegalToPromote(CB, DirectCallee))
+ return false;
+
+ promoteCall(CB, DirectCallee);
+ return true;
+ }
+ return tryDevirtualizeViaTypeTestAssume(CB, Object, VTableOffset, DL, *M);
}
#undef DEBUG_TYPE
diff --git a/llvm/test/Transforms/Inline/devirtualize_non_virtual_interface_calls_through_static_cast.ll b/llvm/test/Transforms/Inline/devirtualize_non_virtual_interface_calls_through_static_cast.ll
new file mode 100644
index 0000000000000..c325b550ed811
--- /dev/null
+++ b/llvm/test/Transforms/Inline/devirtualize_non_virtual_interface_calls_through_static_cast.ll
@@ -0,0 +1,129 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 6
+; RUN: opt < %s -passes='cgscc(inline),function(sroa,gvn)' -S | FileCheck %s
+;
+; Test: devirtualization of virtual calls through static_cast downcast pointers
+; via llvm.type.test+assume infrastructure.
+;
+; At the static_cast<Impl*>(intf) site, Clang emits:
+; %tt = llvm.type.test(ptr %intf, !"_ZTS4Impl")
+; llvm.assume(%tt)
+;
+; After the inliner inlines Intf::foo(), tryPromoteCall finds the type.test on
+; the object pointer, matches !"_ZTS4Impl" against !type metadata on @_ZTV4Impl,
+; resolves the vtable slot, and promotes the indirect call to @_ZN4Impl5doFooEv.
+;
+; Generated from the following C++ source with:
+; clang++ -O0 -flto -fwhole-program-vtables -S -emit-llvm file.cc
+; then hand-simplified.
+;
+; -flto is required for -fwhole-program-vtables, which causes Clang to emit
+; !type metadata on vtable globals. The vtable definition must be in the same
+; module for tryPromoteCall to resolve it; in a multi-TU build this happens at
+; LTO link time when all modules are merged.
+;
+; C++ source:
+;
+; int glob = 0;
+; int secretValue = 42;
+;
+; struct Intf {
+; void foo() { this->doFoo(); }
+; virtual void doFoo();
+; };
+;
+; struct Impl final : Intf {
+; void doFoo() override { glob = secretValue; }
+; };
+;
+; void f(Intf *intf) { static_cast<Impl *>(intf)->foo(); }
+
+
+%struct.Impl = type { %struct.Intf }
+%struct.Intf = type { ptr }
+
+@glob = dso_local global i32 0, align 4
+@secretValue = dso_local global i32 42, align 4
+@_ZTV4Impl = linkonce_odr unnamed_addr constant { [3 x ptr] } { [3 x ptr] [ptr null, ptr null, ptr @_ZN4Impl5doFooEv] }, align 8, !type !0
+
+; f(Intf *intf)
+define dso_local void @_Z1fP4Intf(ptr noundef %intf) {
+; CHECK-LABEL: define dso_local void @_Z1fP4Intf(
+; CHECK-SAME: ptr noundef [[INTF:%.*]]) {
+; CHECK-NEXT: [[ENTRY:.*:]]
+; CHECK-NEXT: [[TMP0:%.*]] = call i1 @llvm.type.test(ptr [[INTF]], metadata !"_ZTS4Impl")
+; CHECK-NEXT: call void @llvm.assume(i1 [[TMP0]])
+; CHECK-NEXT: [[TMP1:%.*]] = load i32, ptr @secretValue, align 4
+; CHECK-NEXT: store i32 [[TMP1]], ptr @glob, align 4
+; CHECK-NEXT: ret void
+;
+entry:
+ %0 = call i1 @llvm.type.test(ptr %intf, metadata !"_ZTS4Impl")
+ call void @llvm.assume(i1 %0)
+ call void @_ZN4Intf3fooEv(ptr noundef nonnull align 8 dereferenceable(8) %intf)
+ ret void
+}
+
+; Negative test: the assume does NOT dominate the call to foo() because it is
+; only on one side of a branch. tryPromoteCall must not devirtualize here.
+define dso_local void @non_dominating_assume(ptr noundef %intf, i1 %cond) {
+; CHECK-LABEL: define dso_local void @non_dominating_assume(
+; CHECK-SAME: ptr noundef [[INTF:%.*]], i1 [[COND:%.*]]) {
+; CHECK-NEXT: [[ENTRY:.*:]]
+; CHECK-NEXT: [[TMP0:%.*]] = call i1 @llvm.type.test(ptr [[INTF]], metadata !"_ZTS4Impl")
+; CHECK-NEXT: br i1 [[COND]], label %[[THEN:.*]], label %[[ELSE:.*]]
+; CHECK: [[THEN]]:
+; CHECK-NEXT: call void @llvm.assume(i1 [[TMP0]])
+; CHECK-NEXT: br label %[[MERGE:.*]]
+; CHECK: [[ELSE]]:
+; CHECK-NEXT: br label %[[MERGE]]
+; CHECK: [[MERGE]]:
+; CHECK-NEXT: [[VTABLE_I:%.*]] = load ptr, ptr [[INTF]], align 8
+; CHECK-NEXT: [[TMP1:%.*]] = load ptr, ptr [[VTABLE_I]], align 8
+; CHECK-NEXT: call void [[TMP1]](ptr noundef nonnull align 8 dereferenceable(8) [[INTF]])
+; CHECK-NEXT: ret void
+;
+entry:
+ %0 = call i1 @llvm.type.test(ptr %intf, metadata !"_ZTS4Impl")
+ br i1 %cond, label %then, label %else
+
+then:
+ call void @llvm.assume(i1 %0)
+ br label %merge
+
+else:
+ br label %merge
+
+merge:
+ call void @_ZN4Intf3fooEv(ptr noundef nonnull align 8 dereferenceable(8) %intf)
+ ret void
+}
+
+; Intf::foo() - non-virtual wrapper that makes the virtual call
+define linkonce_odr void @_ZN4Intf3fooEv(ptr noundef nonnull align 8 dereferenceable(8) %this) align 2 {
+entry:
+ %vtable = load ptr, ptr %this, align 8
+ %vfn = getelementptr inbounds ptr, ptr %vtable, i64 0
+ %0 = load ptr, ptr %vfn, align 8
+ call void %0(ptr noundef nonnull align 8 dereferenceable(8) %this)
+ ret void
+}
+
+; Impl::doFoo()
+define linkonce_odr void @_ZN4Impl5doFooEv(ptr noundef nonnull align 8 dereferenceable(8) %this) unnamed_addr align 2 {
+; CHECK-LABEL: define linkonce_odr void @_ZN4Impl5doFooEv(
+; CHECK-SAME: ptr noundef nonnull align 8 dereferenceable(8) [[THIS:%.*]]) unnamed_addr align 2 {
+; CHECK-NEXT: [[ENTRY:.*:]]
+; CHECK-NEXT: [[TMP0:%.*]] = load i32, ptr @secretValue, align 4
+; CHECK-NEXT: store i32 [[TMP0]], ptr @glob, align 4
+; CHECK-NEXT: ret void
+;
+entry:
+ %0 = load i32, ptr @secretValue, align 4
+ store i32 %0, ptr @glob, align 4
+ ret void
+}
+
+; !type metadata: maps type ID !"_ZTS4Impl" to the vtable address point at byte
+; offset 16 (past offset-to-top and RTTI pointer). Emitted by Clang with
+; -fwhole-program-vtables.
+!0 = !{i64 16, !"_ZTS4Impl"}
>From d2f65f300c9231c5ab509c997ea6b184f5c53a2a Mon Sep 17 00:00:00 2001
From: Nikita Taranov <nikita.taranov at clickhouse.com>
Date: Fri, 6 Mar 2026 19:46:34 +0000
Subject: [PATCH 2/3] fix style
---
clang/lib/CodeGen/CGExprScalar.cpp | 3 ++-
llvm/lib/Transforms/Utils/CallPromotionUtils.cpp | 3 +--
2 files changed, 3 insertions(+), 3 deletions(-)
diff --git a/clang/lib/CodeGen/CGExprScalar.cpp b/clang/lib/CodeGen/CGExprScalar.cpp
index 935c3a9e5c0c7..dea027561119a 100644
--- a/clang/lib/CodeGen/CGExprScalar.cpp
+++ b/clang/lib/CodeGen/CGExprScalar.cpp
@@ -2840,7 +2840,8 @@ Value *ScalarExprEmitter::VisitCastExpr(CastExpr *CE) {
if (DerivedClassDecl->isPolymorphic() &&
DerivedClassDecl->isEffectivelyFinal()) {
llvm::Value *BasePtr = Base.emitRawPointer(CGF);
- CanQualType Ty = CGF.CGM.getContext().getCanonicalTagType(DerivedClassDecl);
+ CanQualType Ty =
+ CGF.CGM.getContext().getCanonicalTagType(DerivedClassDecl);
llvm::Metadata *MD = CGF.CGM.CreateMetadataIdentifierForType(Ty);
llvm::Value *TypeId =
llvm::MetadataAsValue::get(CGF.CGM.getLLVMContext(), MD);
diff --git a/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp b/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp
index f9f5bd3c95b44..d2358a0b11338 100644
--- a/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp
+++ b/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp
@@ -743,8 +743,7 @@ static bool tryDevirtualizeViaTypeTestAssume(CallBase &CB, Value *Object,
continue;
if (TypeMD->getOperand(1).get() != TypeId)
continue;
- auto *OffsetCmd =
- dyn_cast<ConstantAsMetadata>(TypeMD->getOperand(0));
+ auto *OffsetCmd = dyn_cast<ConstantAsMetadata>(TypeMD->getOperand(0));
if (!OffsetCmd)
continue;
if (MatchedVTable) {
>From 610f41093c9886a30a915e06408a99ba531eb5b5 Mon Sep 17 00:00:00 2001
From: Nikita Taranov <nikita.taranov at clickhouse.com>
Date: Sat, 7 Mar 2026 15:54:19 +0000
Subject: [PATCH 3/3] account for possible vptr clobber
---
.../Transforms/Utils/CallPromotionUtils.h | 3 +-
llvm/lib/Transforms/IPO/Inliner.cpp | 3 +-
llvm/lib/Transforms/IPO/ModuleInliner.cpp | 3 +-
.../Transforms/Utils/CallPromotionUtils.cpp | 40 +++++--
...ual_interface_calls_through_static_cast.ll | 100 +++++++++++++++++-
.../Utils/CallPromotionUtilsTest.cpp | 52 +++++++--
6 files changed, 178 insertions(+), 23 deletions(-)
diff --git a/llvm/include/llvm/Transforms/Utils/CallPromotionUtils.h b/llvm/include/llvm/Transforms/Utils/CallPromotionUtils.h
index e9660ac25bc89..82d37b4f05c09 100644
--- a/llvm/include/llvm/Transforms/Utils/CallPromotionUtils.h
+++ b/llvm/include/llvm/Transforms/Utils/CallPromotionUtils.h
@@ -17,6 +17,7 @@
#include "llvm/Analysis/CtxProfAnalysis.h"
#include "llvm/Support/Compiler.h"
namespace llvm {
+class AAResults;
template <typename T> class ArrayRef;
class Constant;
class CallBase;
@@ -98,7 +99,7 @@ LLVM_ABI CallBase &promoteCallWithVTableCmp(CallBase &CB, Instruction *VPtr,
/// [i8* null, i8* bitcast ({ i8*, i8*, i8* }* @_ZTI4Impl to i8*),
/// i8* bitcast (void (%class.Impl*)* @_ZN4Impl3RunEv to i8*)] }
///
-LLVM_ABI bool tryPromoteCall(CallBase &CB);
+LLVM_ABI bool tryPromoteCall(CallBase &CB, AAResults &AA);
/// Predicate and clone the given call site.
///
diff --git a/llvm/lib/Transforms/IPO/Inliner.cpp b/llvm/lib/Transforms/IPO/Inliner.cpp
index fb376562f6781..aa020e34a5df5 100644
--- a/llvm/lib/Transforms/IPO/Inliner.cpp
+++ b/llvm/lib/Transforms/IPO/Inliner.cpp
@@ -415,7 +415,8 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC,
// the post-inline cleanup and the next DevirtSCCRepeatedPass
// iteration because the next iteration may not happen and we may
// miss inlining it.
- if (tryPromoteCall(*ICB))
+ if (tryPromoteCall(
+ *ICB, FAM.getResult<AAManager>(*ICB->getCaller())))
NewCallee = ICB->getCalledFunction();
}
if (NewCallee) {
diff --git a/llvm/lib/Transforms/IPO/ModuleInliner.cpp b/llvm/lib/Transforms/IPO/ModuleInliner.cpp
index 3e0bb6d1432b2..db72f6f831d58 100644
--- a/llvm/lib/Transforms/IPO/ModuleInliner.cpp
+++ b/llvm/lib/Transforms/IPO/ModuleInliner.cpp
@@ -261,7 +261,8 @@ PreservedAnalyses ModuleInlinerPass::run(Module &M,
// miss inlining it.
// FIXME: enable for ctxprof.
if (CtxProf.isInSpecializedModule())
- if (tryPromoteCall(*ICB))
+ if (tryPromoteCall(
+ *ICB, FAM.getResult<AAManager>(*ICB->getCaller())))
NewCallee = ICB->getCalledFunction();
}
if (NewCallee)
diff --git a/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp b/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp
index d2358a0b11338..93d71ebe2def6 100644
--- a/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp
+++ b/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp
@@ -12,6 +12,7 @@
//===----------------------------------------------------------------------===//
#include "llvm/Transforms/Utils/CallPromotionUtils.h"
+#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/CtxProfAnalysis.h"
#include "llvm/Analysis/Loads.h"
#include "llvm/Analysis/TypeMetadataUtils.h"
@@ -697,7 +698,8 @@ CallBase &llvm::promoteCallWithVTableCmp(CallBase &CB, Instruction *VPtr,
// resolve the virtual call to a direct call.
static bool tryDevirtualizeViaTypeTestAssume(CallBase &CB, Value *Object,
APInt VTableOffset,
- const DataLayout &DL, Module &M) {
+ const DataLayout &DL, Module &M,
+ AAResults &AA) {
// Build a dominator tree so we can verify that all execution paths to the
// call (CB) must go through the assume that uses the type.test result.
DominatorTree DT(*CB.getFunction());
@@ -710,16 +712,36 @@ static bool tryDevirtualizeViaTypeTestAssume(CallBase &CB, Value *Object,
if (TypeTestCI->getArgOperand(0) != Object)
continue;
- // There must be a dominating llvm.assume consuming the type.test result.
- bool HasDominatingAssume = false;
+ // There must be a dominating llvm.assume consuming the type.test result,
+ // and no instruction between the assume and the indirect call may clobber
+ // the vptr (e.g. via placement new or destructor). We use alias analysis to
+ // check specifically whether an intervening write could alias the vptr slot.
+ // We require the assume and the call to be in the same basic block.
+ MemoryLocation VptrLoc =
+ MemoryLocation(Object, LocationSize::precise(DL.getPointerSize()));
+ bool HasValidAssume = false;
for (User *TU : TypeTestCI->users()) {
- if (auto *Assume = dyn_cast<AssumeInst>(TU);
- Assume && DT.dominates(Assume, &CB)) {
- HasDominatingAssume = true;
+ auto *Assume = dyn_cast<AssumeInst>(TU);
+ if (!Assume || !DT.dominates(Assume, &CB))
+ continue;
+ // Require assume and CB in the same basic block so we can walk between.
+ if (Assume->getParent() != CB.getParent())
+ continue;
+ // Walk from the assume to CB and check for instructions that could
+ // clobber the vptr (e.g. destructor + placement new).
+ bool Clobbered = false;
+ for (auto It = std::next(Assume->getIterator()); &*It != &CB; ++It) {
+ if (isModSet(AA.getModRefInfo(&*It, VptrLoc))) {
+ Clobbered = true;
+ break;
+ }
+ }
+ if (!Clobbered) {
+ HasValidAssume = true;
break;
}
}
- if (!HasDominatingAssume)
+ if (!HasValidAssume)
continue;
// Extract the type metadata identifier, e.g. MDString "_ZTS4Impl".
@@ -773,7 +795,7 @@ static bool tryDevirtualizeViaTypeTestAssume(CallBase &CB, Value *Object,
return false;
}
-bool llvm::tryPromoteCall(CallBase &CB) {
+bool llvm::tryPromoteCall(CallBase &CB, AAResults &AA) {
assert(!CB.getCalledFunction());
Module *M = CB.getCaller()->getParent();
const DataLayout &DL = M->getDataLayout();
@@ -834,7 +856,7 @@ bool llvm::tryPromoteCall(CallBase &CB) {
promoteCall(CB, DirectCallee);
return true;
}
- return tryDevirtualizeViaTypeTestAssume(CB, Object, VTableOffset, DL, *M);
+ return tryDevirtualizeViaTypeTestAssume(CB, Object, VTableOffset, DL, *M, AA);
}
#undef DEBUG_TYPE
diff --git a/llvm/test/Transforms/Inline/devirtualize_non_virtual_interface_calls_through_static_cast.ll b/llvm/test/Transforms/Inline/devirtualize_non_virtual_interface_calls_through_static_cast.ll
index c325b550ed811..01cf9f7f8712c 100644
--- a/llvm/test/Transforms/Inline/devirtualize_non_virtual_interface_calls_through_static_cast.ll
+++ b/llvm/test/Transforms/Inline/devirtualize_non_virtual_interface_calls_through_static_cast.ll
@@ -65,8 +65,8 @@ entry:
; Negative test: the assume does NOT dominate the call to foo() because it is
; only on one side of a branch. tryPromoteCall must not devirtualize here.
-define dso_local void @non_dominating_assume(ptr noundef %intf, i1 %cond) {
-; CHECK-LABEL: define dso_local void @non_dominating_assume(
+define dso_local void @NonDominatingAssume(ptr noundef %intf, i1 %cond) {
+; CHECK-LABEL: define dso_local void @NonDominatingAssume(
; CHECK-SAME: ptr noundef [[INTF:%.*]], i1 [[COND:%.*]]) {
; CHECK-NEXT: [[ENTRY:.*:]]
; CHECK-NEXT: [[TMP0:%.*]] = call i1 @llvm.type.test(ptr [[INTF]], metadata !"_ZTS4Impl")
@@ -98,6 +98,102 @@ merge:
ret void
}
+; Negative test: a call between the assume and foo() could clobber the vptr
+; (e.g. destructor + placement new). tryPromoteCall must not devirtualize.
+declare void @MayClobberVptr(ptr)
+
+define dso_local void @VptrClobbered(ptr noundef %intf) {
+; CHECK-LABEL: define dso_local void @VptrClobbered(
+; CHECK-SAME: ptr noundef [[INTF:%.*]]) {
+; CHECK-NEXT: [[ENTRY:.*:]]
+; CHECK-NEXT: [[TMP0:%.*]] = call i1 @llvm.type.test(ptr [[INTF]], metadata !"_ZTS4Impl")
+; CHECK-NEXT: call void @llvm.assume(i1 [[TMP0]])
+; CHECK-NEXT: call void @MayClobberVptr(ptr [[INTF]])
+; CHECK-NEXT: [[VTABLE_I:%.*]] = load ptr, ptr [[INTF]], align 8
+; CHECK-NEXT: [[TMP1:%.*]] = load ptr, ptr [[VTABLE_I]], align 8
+; CHECK-NEXT: call void [[TMP1]](ptr noundef nonnull align 8 dereferenceable(8) [[INTF]])
+; CHECK-NEXT: ret void
+;
+entry:
+ %0 = call i1 @llvm.type.test(ptr %intf, metadata !"_ZTS4Impl")
+ call void @llvm.assume(i1 %0)
+ call void @MayClobberVptr(ptr %intf)
+ call void @_ZN4Intf3fooEv(ptr noundef nonnull align 8 dereferenceable(8) %intf)
+ ret void
+}
+
+; Positive test: a store to an unrelated global between the assume and foo()
+; cannot alias the vptr slot. With alias analysis, tryPromoteCall can still
+; devirtualize.
+@unrelated_global = dso_local global i32 0, align 4
+
+define dso_local void @UnrelatedWriteNoAlias(ptr noundef %intf) {
+; CHECK-LABEL: define dso_local void @UnrelatedWriteNoAlias(
+; CHECK-SAME: ptr noundef [[INTF:%.*]]) {
+; CHECK-NEXT: [[ENTRY:.*:]]
+; CHECK-NEXT: [[TMP0:%.*]] = call i1 @llvm.type.test(ptr [[INTF]], metadata !"_ZTS4Impl")
+; CHECK-NEXT: call void @llvm.assume(i1 [[TMP0]])
+; CHECK-NEXT: store i32 99, ptr @unrelated_global, align 4
+; CHECK-NEXT: [[TMP1:%.*]] = load i32, ptr @secretValue, align 4
+; CHECK-NEXT: store i32 [[TMP1]], ptr @glob, align 4
+; CHECK-NEXT: ret void
+;
+entry:
+ %0 = call i1 @llvm.type.test(ptr %intf, metadata !"_ZTS4Impl")
+ call void @llvm.assume(i1 %0)
+ store i32 99, ptr @unrelated_global, align 4
+ call void @_ZN4Intf3fooEv(ptr noundef nonnull align 8 dereferenceable(8) %intf)
+ ret void
+}
+
+; Negative test: an unknown call on an alias of %intf could clobber the vptr.
+; AA detects that %alias may-alias %intf, so devirtualization is blocked.
+define dso_local void @VptrClobberedViaAlias(ptr noundef %intf) {
+; CHECK-LABEL: define dso_local void @VptrClobberedViaAlias(
+; CHECK-SAME: ptr noundef [[INTF:%.*]]) {
+; CHECK-NEXT: [[ENTRY:.*:]]
+; CHECK-NEXT: [[TMP0:%.*]] = call i1 @llvm.type.test(ptr [[INTF]], metadata !"_ZTS4Impl")
+; CHECK-NEXT: call void @llvm.assume(i1 [[TMP0]])
+; CHECK-NEXT: call void @MayClobberVptr(ptr [[INTF]])
+; CHECK-NEXT: [[VTABLE_I:%.*]] = load ptr, ptr [[INTF]], align 8
+; CHECK-NEXT: [[TMP1:%.*]] = load ptr, ptr [[VTABLE_I]], align 8
+; CHECK-NEXT: call void [[TMP1]](ptr noundef nonnull align 8 dereferenceable(8) [[INTF]])
+; CHECK-NEXT: ret void
+;
+entry:
+ %alias = getelementptr i8, ptr %intf, i64 0
+ %0 = call i1 @llvm.type.test(ptr %intf, metadata !"_ZTS4Impl")
+ call void @llvm.assume(i1 %0)
+ call void @MayClobberVptr(ptr %alias)
+ call void @_ZN4Intf3fooEv(ptr noundef nonnull align 8 dereferenceable(8) %intf)
+ ret void
+}
+
+; Negative test: pointer laundered through ptrtoint/inttoptr. AA must still
+; conservatively block devirtualization because the result may alias %intf.
+define dso_local void @VptrClobberedViaIntToPtr(ptr noundef %intf) {
+; CHECK-LABEL: define dso_local void @VptrClobberedViaIntToPtr(
+; CHECK-SAME: ptr noundef [[INTF:%.*]]) {
+; CHECK-NEXT: [[ENTRY:.*:]]
+; CHECK-NEXT: [[INT:%.*]] = ptrtoint ptr [[INTF]] to i64
+; CHECK-NEXT: [[TMP0:%.*]] = call i1 @llvm.type.test(ptr [[INTF]], metadata !"_ZTS4Impl")
+; CHECK-NEXT: call void @llvm.assume(i1 [[TMP0]])
+; CHECK-NEXT: call void @MayClobberVptr(ptr [[INTF]])
+; CHECK-NEXT: [[VTABLE_I:%.*]] = load ptr, ptr [[INTF]], align 8
+; CHECK-NEXT: [[TMP1:%.*]] = load ptr, ptr [[VTABLE_I]], align 8
+; CHECK-NEXT: call void [[TMP1]](ptr noundef nonnull align 8 dereferenceable(8) [[INTF]])
+; CHECK-NEXT: ret void
+;
+entry:
+ %int = ptrtoint ptr %intf to i64
+ %laundered = inttoptr i64 %int to ptr
+ %0 = call i1 @llvm.type.test(ptr %intf, metadata !"_ZTS4Impl")
+ call void @llvm.assume(i1 %0)
+ call void @MayClobberVptr(ptr %laundered)
+ call void @_ZN4Intf3fooEv(ptr noundef nonnull align 8 dereferenceable(8) %intf)
+ ret void
+}
+
; Intf::foo() - non-virtual wrapper that makes the virtual call
define linkonce_odr void @_ZN4Intf3fooEv(ptr noundef nonnull align 8 dereferenceable(8) %this) align 2 {
entry:
diff --git a/llvm/unittests/Transforms/Utils/CallPromotionUtilsTest.cpp b/llvm/unittests/Transforms/Utils/CallPromotionUtilsTest.cpp
index c859b0c799f08..a37a1f4950886 100644
--- a/llvm/unittests/Transforms/Utils/CallPromotionUtilsTest.cpp
+++ b/llvm/unittests/Transforms/Utils/CallPromotionUtilsTest.cpp
@@ -7,7 +7,9 @@
//===----------------------------------------------------------------------===//
#include "llvm/Transforms/Utils/CallPromotionUtils.h"
+#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/CtxProfAnalysis.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
@@ -74,6 +76,10 @@ define void @f() {
declare void @_ZN4Impl3RunEv(ptr %this)
)IR");
+ TargetLibraryInfoImpl TLII(M->getTargetTriple());
+ TargetLibraryInfo TLI(TLII);
+ AAResults AA(TLI);
+
auto *GV = M->getNamedValue("f");
ASSERT_TRUE(GV);
auto *F = dyn_cast<Function>(GV);
@@ -85,7 +91,7 @@ declare void @_ZN4Impl3RunEv(ptr %this)
auto *CI = dyn_cast<CallInst>(Inst);
ASSERT_TRUE(CI);
ASSERT_FALSE(CI->getCalledFunction());
- bool IsPromoted = tryPromoteCall(*CI);
+ bool IsPromoted = tryPromoteCall(*CI, AA);
EXPECT_TRUE(IsPromoted);
GV = M->getNamedValue("_ZN4Impl3RunEv");
ASSERT_TRUE(GV);
@@ -107,6 +113,10 @@ define void @f(ptr %fp, ptr nonnull %base.i) {
}
)IR");
+ TargetLibraryInfoImpl TLII(M->getTargetTriple());
+ TargetLibraryInfo TLI(TLII);
+ AAResults AA(TLI);
+
auto *GV = M->getNamedValue("f");
ASSERT_TRUE(GV);
auto *F = dyn_cast<Function>(GV);
@@ -115,7 +125,7 @@ define void @f(ptr %fp, ptr nonnull %base.i) {
auto *CI = dyn_cast<CallInst>(Inst);
ASSERT_TRUE(CI);
ASSERT_FALSE(CI->getCalledFunction());
- bool IsPromoted = tryPromoteCall(*CI);
+ bool IsPromoted = tryPromoteCall(*CI, AA);
EXPECT_FALSE(IsPromoted);
}
@@ -134,6 +144,10 @@ define void @f(ptr %vtable.i, ptr nonnull %base.i) {
}
)IR");
+ TargetLibraryInfoImpl TLII(M->getTargetTriple());
+ TargetLibraryInfo TLI(TLII);
+ AAResults AA(TLI);
+
auto *GV = M->getNamedValue("f");
ASSERT_TRUE(GV);
auto *F = dyn_cast<Function>(GV);
@@ -142,7 +156,7 @@ define void @f(ptr %vtable.i, ptr nonnull %base.i) {
auto *CI = dyn_cast<CallInst>(Inst);
ASSERT_TRUE(CI);
ASSERT_FALSE(CI->getCalledFunction());
- bool IsPromoted = tryPromoteCall(*CI);
+ bool IsPromoted = tryPromoteCall(*CI, AA);
EXPECT_FALSE(IsPromoted);
}
@@ -168,6 +182,10 @@ define void @f() {
declare void @_ZN4Impl3RunEv(ptr %this)
)IR");
+ TargetLibraryInfoImpl TLII(M->getTargetTriple());
+ TargetLibraryInfo TLI(TLII);
+ AAResults AA(TLI);
+
auto *GV = M->getNamedValue("f");
ASSERT_TRUE(GV);
auto *F = dyn_cast<Function>(GV);
@@ -176,7 +194,7 @@ declare void @_ZN4Impl3RunEv(ptr %this)
auto *CI = dyn_cast<CallInst>(Inst);
ASSERT_TRUE(CI);
ASSERT_FALSE(CI->getCalledFunction());
- bool IsPromoted = tryPromoteCall(*CI);
+ bool IsPromoted = tryPromoteCall(*CI, AA);
EXPECT_FALSE(IsPromoted);
}
@@ -206,6 +224,10 @@ define void @f() {
declare void @_ZN4Impl3RunEv(ptr %this)
)IR");
+ TargetLibraryInfoImpl TLII(M->getTargetTriple());
+ TargetLibraryInfo TLI(TLII);
+ AAResults AA(TLI);
+
auto *GV = M->getNamedValue("f");
ASSERT_TRUE(GV);
auto *F = dyn_cast<Function>(GV);
@@ -217,7 +239,7 @@ declare void @_ZN4Impl3RunEv(ptr %this)
auto *CI = dyn_cast<CallInst>(Inst);
ASSERT_TRUE(CI);
ASSERT_FALSE(CI->getCalledFunction());
- bool IsPromoted = tryPromoteCall(*CI);
+ bool IsPromoted = tryPromoteCall(*CI, AA);
EXPECT_FALSE(IsPromoted);
}
@@ -247,6 +269,10 @@ define void @f() {
declare void @_ZN4Impl3RunEv(ptr %this)
)IR");
+ TargetLibraryInfoImpl TLII(M->getTargetTriple());
+ TargetLibraryInfo TLI(TLII);
+ AAResults AA(TLI);
+
auto *GV = M->getNamedValue("f");
ASSERT_TRUE(GV);
auto *F = dyn_cast<Function>(GV);
@@ -258,7 +284,7 @@ declare void @_ZN4Impl3RunEv(ptr %this)
auto *CI = dyn_cast<CallInst>(Inst);
ASSERT_TRUE(CI);
ASSERT_FALSE(CI->getCalledFunction());
- bool IsPromoted = tryPromoteCall(*CI);
+ bool IsPromoted = tryPromoteCall(*CI, AA);
EXPECT_FALSE(IsPromoted);
}
@@ -298,6 +324,10 @@ declare i32 @_ZN1A3vf1Ev(ptr %this)
declare i32 @_ZN1A3vf2Ev(ptr %this)
)IR");
+ TargetLibraryInfoImpl TLII(M->getTargetTriple());
+ TargetLibraryInfo TLI(TLII);
+ AAResults AA(TLI);
+
auto *GV = M->getNamedValue("_Z2g1v");
ASSERT_TRUE(GV);
auto *F = dyn_cast<Function>(GV);
@@ -309,7 +339,7 @@ declare i32 @_ZN1A3vf2Ev(ptr %this)
auto *CI = dyn_cast<CallInst>(Inst);
ASSERT_TRUE(CI);
ASSERT_FALSE(CI->getCalledFunction());
- bool IsPromoted1 = tryPromoteCall(*CI);
+ bool IsPromoted1 = tryPromoteCall(*CI, AA);
EXPECT_TRUE(IsPromoted1);
GV = M->getNamedValue("_ZN1A3vf1Ev");
ASSERT_TRUE(GV);
@@ -327,7 +357,7 @@ declare i32 @_ZN1A3vf2Ev(ptr %this)
CI = dyn_cast<CallInst>(Inst);
ASSERT_TRUE(CI);
ASSERT_FALSE(CI->getCalledFunction());
- bool IsPromoted2 = tryPromoteCall(*CI);
+ bool IsPromoted2 = tryPromoteCall(*CI, AA);
EXPECT_TRUE(IsPromoted2);
GV = M->getNamedValue("_ZN1A3vf2Ev");
ASSERT_TRUE(GV);
@@ -365,6 +395,10 @@ define %struct1 @f() {
declare %struct2 @_ZN4Impl3RunEv(ptr %this)
)IR");
+ TargetLibraryInfoImpl TLII(M->getTargetTriple());
+ TargetLibraryInfo TLI(TLII);
+ AAResults AA(TLI);
+
auto *GV = M->getNamedValue("f");
ASSERT_TRUE(GV);
auto *F = dyn_cast<Function>(GV);
@@ -376,7 +410,7 @@ declare %struct2 @_ZN4Impl3RunEv(ptr %this)
auto *CI = dyn_cast<CallInst>(Inst);
ASSERT_TRUE(CI);
ASSERT_FALSE(CI->getCalledFunction());
- bool IsPromoted = tryPromoteCall(*CI);
+ bool IsPromoted = tryPromoteCall(*CI, AA);
EXPECT_FALSE(IsPromoted);
}
More information about the cfe-commits
mailing list