[llvm-branch-commits] [llvm] [CallPromotionUtils]Implement conditional indirect call promotion with vtable-based comparison (PR #81378)
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Mon Feb 12 10:07:52 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-transforms
Author: Mingming Liu (minglotus-6)
<details>
<summary>Changes</summary>
* Given the code sequence
```
bb:
%vtable = load ptr, ptr %d, !prof !8
%vfn = getelementptr inbounds ptr, ptr %vtable, i64 1
%1 = load ptr, ptr %vfn
%call = tail call i32 %1(ptr %d), !prof !9
```
The transformation looks like
```
bb:
%vtable = load ptr, ptr %d, align 8
%vfn = getelementptr inbounds i8, ptr %vtable, i64 8 <-- Inst 1
%func-addr = load ptr, ptr %vfn, align 8 <-- Inst 2
# compare loaded pointers with address point of vtables
%1 = icmp eq ptr %vtable, getelementptr inbounds ({ [4 x ptr] }, ptr @_ZTV<VTable>, i64 0, i32 0, i64 2)
br i1 %1, label %if.true.direct_targ, label %if.false.orig_indirect, !prof !18
if.true.direct_targ: ; preds = %entry
%2 = tail call i32 @<direct-call>(ptr nonnull %d)
br label %if.end.icp
if.false.orig_indirect: ; preds = %entry
%call = tail call i32 %func-addr(ptr nonnull %d)
br label %if.end.icp
if.end.icp: ; preds = %if.false.orig_indirect, %if.true.direct_targ
%4 = phi i32 [ %call, %if.false.orig_indirect ], [ %2, %if.true.direct_targ ]
```
It's intentional that `Inst 1` and `Inst 2` remain in `bb` (not in `if.false.orig_indirect`). Subsequent passes (notably `instcombine`) would sink them (and handle debug intrinsics properly) if possible.
* The parent patch is https://github.com/llvm/llvm-project/pull/81181
---
Full diff: https://github.com/llvm/llvm-project/pull/81378.diff
3 Files Affected:
- (modified) llvm/include/llvm/Transforms/Utils/CallPromotionUtils.h (+43-6)
- (modified) llvm/lib/Transforms/Utils/CallPromotionUtils.cpp (+62-2)
- (modified) llvm/unittests/Transforms/Utils/CallPromotionUtilsTest.cpp (+119)
``````````diff
diff --git a/llvm/include/llvm/Transforms/Utils/CallPromotionUtils.h b/llvm/include/llvm/Transforms/Utils/CallPromotionUtils.h
index fcb384ec361339..32b252d132c04c 100644
--- a/llvm/include/llvm/Transforms/Utils/CallPromotionUtils.h
+++ b/llvm/include/llvm/Transforms/Utils/CallPromotionUtils.h
@@ -14,10 +14,16 @@
#ifndef LLVM_TRANSFORMS_UTILS_CALLPROMOTIONUTILS_H
#define LLVM_TRANSFORMS_UTILS_CALLPROMOTIONUTILS_H
+#include <cstdint>
+
namespace llvm {
+template <typename T> class ArrayRef;
+class Constant;
class CallBase;
class CastInst;
class Function;
+class GlobalVariable;
+class Instruction;
class MDNode;
class Value;
@@ -41,7 +47,9 @@ bool isLegalToPromote(const CallBase &CB, Function *Callee,
CallBase &promoteCall(CallBase &CB, Function *Callee,
CastInst **RetBitCast = nullptr);
-/// Promote the given indirect call site to conditionally call \p Callee.
+/// Promote the given indirect call site to conditionally call \p Callee. The
+/// promoted direct call instruction is predicated on `CB.getCalledOperand() ==
+/// Callee`.
///
/// This function creates an if-then-else structure at the location of the call
/// site. The original call site is moved into the "else" block. A clone of the
@@ -51,6 +59,31 @@ CallBase &promoteCall(CallBase &CB, Function *Callee,
CallBase &promoteCallWithIfThenElse(CallBase &CB, Function *Callee,
MDNode *BranchWeights = nullptr);
+/// This is similar to `promoteCallWithIfThenElse` except that the condition to
+/// promote a virtual call is that \p VPtr is the same as any of \p
+/// AddressPoints.
+///
+/// This function is expected to be used on virtual calls (a subset of indirect
+/// calls). \p VPtr is the virtual table address stored in the objects, and
+/// \p AddressPoints contains address points of vtables to be compared with.
+///
+/// It's the responsibility of caller to guarantee the transformation
+/// correctness (by specifying \p VPtr and \p AddressPoints properly).
+///
+/// This function doesn't sink the address-calculation instructions of indirect
+/// callee to the indirect call fallback. The subsequent passes (e.g.
+/// inst-combine) should sink them if possible and handle the sink of debug
+/// intrinsics together.
+CallBase &promoteCallWithVTableCmp(CallBase &CB, Instruction *VPtr,
+ Function *Callee,
+ ArrayRef<Constant *> AddressPoints,
+ MDNode *BranchWeights);
+
+/// Returns a constant representing the vtable's address point specified by the
+/// offset. Caller should ensure \p AddressPointOffset is valid.
+Constant *getVTableAddressPointOffset(GlobalVariable *VTable,
+ uint32_t AddressPointOffset);
+
/// Try to promote (devirtualize) a virtual call on an Alloca. Return true on
/// success.
///
@@ -74,13 +107,17 @@ CallBase &promoteCallWithIfThenElse(CallBase &CB, Function *Callee,
///
bool tryPromoteCall(CallBase &CB);
+/// Predicate and clone the given call site using the given condition.
+CallBase &versionCallSiteWithCond(CallBase &CB, Value *Cond,
+ MDNode *BranchWeights);
+
/// Predicate and clone the given call site.
///
-/// This function creates an if-then-else structure at the location of the call
-/// site. The "if" condition compares the call site's called value to the given
-/// callee. The original call site is moved into the "else" block, and a clone
-/// of the call site is placed in the "then" block. The cloned instruction is
-/// returned.
+/// This function creates an if-then-else structure at the location of the
+/// call site. The "if" condition compares the call site's called value to
+/// the given callee. The original call site is moved into the "else" block,
+/// and a clone of the call site is placed in the "then" block. The cloned
+/// instruction is returned.
CallBase &versionCallSite(CallBase &CB, Value *Callee, MDNode *BranchWeights);
} // end namespace llvm
diff --git a/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp b/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp
index d0cf0792eface0..ea855b9a4d8416 100644
--- a/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp
+++ b/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp
@@ -12,9 +12,11 @@
//===----------------------------------------------------------------------===//
#include "llvm/Transforms/Utils/CallPromotionUtils.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/Analysis/Loads.h"
#include "llvm/Analysis/TypeMetadataUtils.h"
#include "llvm/IR/AttributeMask.h"
+#include "llvm/IR/Constant.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
@@ -185,6 +187,24 @@ static void createRetBitCast(CallBase &CB, Type *RetTy, CastInst **RetBitCast) {
U->replaceUsesOfWith(&CB, Cast);
}
+// Returns the or result of all icmp instructions.
+static Value *getOrResult(const SmallVector<Value *, 2> &ICmps,
+ IRBuilder<> &Builder) {
+ assert(!ICmps.empty() && "Must have at least one icmp instructions");
+ if (ICmps.size() == 1)
+ return ICmps[0];
+
+ SmallVector<Value *, 2> OrResults;
+ int i = 0, NumICmp = ICmps.size();
+ for (i = 0; i + 1 < NumICmp; i += 2)
+ OrResults.push_back(Builder.CreateOr(ICmps[i], ICmps[i + 1], "icmp-or"));
+
+ if (i < NumICmp)
+ OrResults.push_back(ICmps[i]);
+
+ return getOrResult(OrResults, Builder);
+}
+
/// Predicate and clone the given call site.
///
/// This function creates an if-then-else structure at the location of the call
@@ -276,8 +296,8 @@ static void createRetBitCast(CallBase &CB, Type *RetTy, CastInst **RetBitCast) {
/// ; The original call instruction stays in its original block.
/// %t0 = musttail call i32 %ptr()
/// ret %t0
-static CallBase &versionCallSiteWithCond(CallBase &CB, Value *Cond,
- MDNode *BranchWeights) {
+CallBase &llvm::versionCallSiteWithCond(CallBase &CB, Value *Cond,
+ MDNode *BranchWeights) {
IRBuilder<> Builder(&CB);
CallBase *OrigInst = &CB;
@@ -565,6 +585,46 @@ CallBase &llvm::promoteCallWithIfThenElse(CallBase &CB, Function *Callee,
return promoteCall(NewInst, Callee);
}
+Constant *llvm::getVTableAddressPointOffset(GlobalVariable *VTable,
+ uint32_t AddressPointOffset) {
+ Module &M = *VTable->getParent();
+ const DataLayout &DL = M.getDataLayout();
+ LLVMContext &Context = M.getContext();
+ Type *VTableType = VTable->getValueType();
+ assert(AddressPointOffset < DL.getTypeAllocSize(VTableType) &&
+ "Out-of-bound access");
+ APInt AddressPointOffsetAPInt(32, AddressPointOffset, false);
+ SmallVector<APInt> Indices =
+ DL.getGEPIndicesForOffset(VTableType, AddressPointOffsetAPInt);
+ SmallVector<llvm::Constant *> GEPIndices;
+ for (const auto &Index : Indices)
+ GEPIndices.push_back(llvm::ConstantInt::get(Type::getInt32Ty(Context),
+ Index.getZExtValue()));
+
+ return ConstantExpr::getInBoundsGetElementPtr(VTable->getValueType(), VTable,
+ GEPIndices);
+}
+
+CallBase &llvm::promoteCallWithVTableCmp(CallBase &CB, Instruction *VPtr,
+ Function *Callee,
+ ArrayRef<Constant *> AddressPoints,
+ MDNode *BranchWeights) {
+ assert(!AddressPoints.empty() && "Caller should guarantee");
+ IRBuilder<> Builder(&CB);
+ SmallVector<Value *, 2> ICmps;
+ for (auto &AddressPoint : AddressPoints)
+ ICmps.push_back(Builder.CreateICmpEQ(VPtr, AddressPoint));
+
+ Value *Cond = getOrResult(ICmps, Builder);
+
+ // Version the indirect call site. If Cond is true, 'NewInst' will be
+ // executed, otherwise the original call site will be executed.
+ CallBase &NewInst = versionCallSiteWithCond(CB, Cond, BranchWeights);
+
+ // Promote 'NewInst' so that it directly calls the desired function.
+ return promoteCall(NewInst, Callee);
+}
+
bool llvm::tryPromoteCall(CallBase &CB) {
assert(!CB.getCalledFunction());
Module *M = CB.getCaller()->getParent();
diff --git a/llvm/unittests/Transforms/Utils/CallPromotionUtilsTest.cpp b/llvm/unittests/Transforms/Utils/CallPromotionUtilsTest.cpp
index eff8e27d36d641..227156378369b5 100644
--- a/llvm/unittests/Transforms/Utils/CallPromotionUtilsTest.cpp
+++ b/llvm/unittests/Transforms/Utils/CallPromotionUtilsTest.cpp
@@ -8,9 +8,12 @@
#include "llvm/Transforms/Utils/CallPromotionUtils.h"
#include "llvm/AsmParser/Parser.h"
+#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/Module.h"
+#include "llvm/IR/NoFolder.h"
#include "llvm/Support/SourceMgr.h"
#include "gtest/gtest.h"
@@ -368,3 +371,119 @@ declare %struct2 @_ZN4Impl3RunEv(%class.Impl* %this)
bool IsPromoted = tryPromoteCall(*CI);
EXPECT_FALSE(IsPromoted);
}
+
+TEST(CallPromotionUtilsTest, getVTableAddressPointOffset) {
+ LLVMContext C;
+ std::unique_ptr<Module> M = parseIR(C,
+ R"IR(
+target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
+target triple = "x86_64-unknown-linux-gnu"
+
+@_ZTV8Derived2 = constant { [3 x ptr], [3 x ptr], [4 x ptr] } { [3 x ptr] [ptr null, ptr null, ptr @_ZN5Base35func3Ev], [3 x ptr] [ptr inttoptr (i64 -8 to ptr), ptr null, ptr @_ZN5Base25func2Ev], [4 x ptr] [ptr inttoptr (i64 -16 to ptr), ptr null, ptr @_ZN5Base15func0Ev, ptr @_ZN5Base15func1Ev] }
+
+declare i32 @_ZN5Base15func1Ev(ptr)
+declare i32 @_ZN5Base25func2Ev(ptr)
+declare i32 @_ZN5Base15func0Ev(ptr)
+declare void @_ZN5Base35func3Ev(ptr)
+)IR");
+ GlobalVariable *GV = M->getGlobalVariable("_ZTV8Derived2");
+
+ for (auto [AddressPointOffset, Index] :
+ {std::pair{16, 0}, {40, 1}, {64, 2}}) {
+ Constant *AddressPoint =
+ getVTableAddressPointOffset(GV, AddressPointOffset);
+
+ ConstantExpr *GEP = dyn_cast<ConstantExpr>(AddressPoint);
+ ASSERT_TRUE(GEP);
+ SmallVector<Constant *> Indices = {
+ llvm::ConstantInt::get(Type::getInt32Ty(C), 0U),
+ llvm::ConstantInt::get(Type::getInt32Ty(C), Index),
+ llvm::ConstantInt::get(Type::getInt32Ty(C), 2U)};
+ EXPECT_EQ(GEP, ConstantExpr::getInBoundsGetElementPtr(GV->getValueType(),
+ GV, Indices));
+ }
+}
+
+TEST(CallPromotionUtilsTest, promoteCallWithVTableCmp) {
+ LLVMContext C;
+ std::unique_ptr<Module> M = parseIR(C,
+ R"IR(
+target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
+target triple = "x86_64-unknown-linux-gnu"
+
+@_ZTV5Base1 = constant { [4 x ptr] } { [4 x ptr] [ptr null, ptr null, ptr @_ZN5Base15func0Ev, ptr @_ZN5Base15func1Ev] }, !type !0
+@_ZTV8Derived1 = constant { [4 x ptr], [3 x ptr] } { [4 x ptr] [ptr inttoptr (i64 -8 to ptr), ptr null, ptr @_ZN5Base15func0Ev, ptr @_ZN5Base15func1Ev], [3 x ptr] [ptr null, ptr null, ptr @_ZN5Base25func2Ev] }, !type !0, !type !1, !type !2
+@_ZTV8Derived2 = constant { [3 x ptr], [3 x ptr], [4 x ptr] } { [3 x ptr] [ptr null, ptr null, ptr @_ZN5Base35func3Ev], [3 x ptr] [ptr inttoptr (i64 -8 to ptr), ptr null, ptr @_ZN5Base25func2Ev], [4 x ptr] [ptr inttoptr (i64 -16 to ptr), ptr null, ptr @_ZN5Base15func0Ev, ptr @_ZN5Base15func1Ev] }, !type !3, !type !4, !type !5, !type !6
+
+define i32 @testfunc(ptr %d) {
+entry:
+ %vtable = load ptr, ptr %d, !prof !7
+ %vfn = getelementptr inbounds ptr, ptr %vtable, i64 1
+ %0 = load ptr, ptr %vfn
+ %call = tail call i32 %0(ptr %d), !prof !8
+ ret i32 %call
+}
+
+define i32 @_ZN5Base15func1Ev(ptr %this) {
+entry:
+ ret i32 2
+}
+
+declare i32 @_ZN5Base25func2Ev(ptr)
+declare i32 @_ZN5Base15func0Ev(ptr)
+declare void @_ZN5Base35func3Ev(ptr)
+
+!0 = !{i64 16, !"_ZTS5Base1"}
+!1 = !{i64 48, !"_ZTS5Base2"}
+!2 = !{i64 16, !"_ZTS8Derived1"}
+!3 = !{i64 64, !"_ZTS5Base1"}
+!4 = !{i64 40, !"_ZTS5Base2"}
+!5 = !{i64 16, !"_ZTS5Base3"}
+!6 = !{i64 16, !"_ZTS8Derived2"}
+!7 = !{!"VP", i32 2, i64 1600, i64 -9064381665493407289, i64 800, i64 5035968517245772950, i64 500, i64 3215870116411581797, i64 300}
+!8 = !{!"VP", i32 0, i64 1600, i64 6804820478065511155, i64 1600})IR");
+
+ Function *F = M->getFunction("testfunc");
+ ASSERT_TRUE(F);
+ CallInst *CI = dyn_cast<CallInst>(&*std::next(F->front().rbegin()));
+ ASSERT_TRUE(CI && CI->isIndirectCall());
+
+ LoadInst *FuncPtr = dyn_cast<LoadInst>(CI->getCalledOperand());
+ ASSERT_TRUE(FuncPtr);
+
+ GetElementPtrInst *GEP =
+ dyn_cast<GetElementPtrInst>(FuncPtr->getPointerOperand());
+ ASSERT_TRUE(GEP);
+
+ // Create the constant and the branch weights
+ SmallVector<Constant *, 3> VTableAddressPoints;
+
+ for (auto &[VTableName, AddressPointOffset] : {std::pair{"_ZTV5Base1", 16},
+ {"_ZTV8Derived1", 16},
+ {"_ZTV8Derived2", 64}})
+ VTableAddressPoints.push_back(getVTableAddressPointOffset(
+ M->getGlobalVariable(VTableName), AddressPointOffset));
+
+ MDBuilder MDB(C);
+ MDNode *BranchWeights = MDB.createBranchWeights(1600, 0);
+
+ size_t OrigEntryBBSize = F->front().size();
+
+ LoadInst *VPtr = dyn_cast<LoadInst>(&*F->front().begin());
+
+ Function *Callee = M->getFunction("_ZN5Base15func1Ev");
+ // Tests that promoted direct call is returned.
+ CallBase &DirectCB = promoteCallWithVTableCmp(
+ *CI, VPtr, Callee, VTableAddressPoints, BranchWeights);
+ EXPECT_EQ(DirectCB.getCalledOperand(), Callee);
+
+ // GEP and FuncPtr remains in the original block. `promoteCallWithVTableCmp`
+ // doesn't sink them to the basic block of indirect fallback.
+ BasicBlock *EntryBB = &F->front();
+ EXPECT_EQ(EntryBB, GEP->getParent());
+ EXPECT_EQ(EntryBB, FuncPtr->getParent());
+
+ // Promotion inserts 3 icmp instructions and 2 or instructions, and removes
+ // 1 call instruction from the entry block.
+ EXPECT_EQ(F->front().size(), OrigEntryBBSize + 4);
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/81378
More information about the llvm-branch-commits
mailing list