[llvm] 5d3f296 - [CallPromotionUtils]Implement conditional indirect call promotion with vtable-based comparison (#81378)

via llvm-commits llvm-commits at lists.llvm.org
Sun May 19 16:33:21 PDT 2024


Author: Mingming Liu
Date: 2024-05-19T16:33:17-07:00
New Revision: 5d3f296733b66281a53dd451a983e69ae0bb482f

URL: https://github.com/llvm/llvm-project/commit/5d3f296733b66281a53dd451a983e69ae0bb482f
DIFF: https://github.com/llvm/llvm-project/commit/5d3f296733b66281a53dd451a983e69ae0bb482f.diff

LOG: [CallPromotionUtils]Implement conditional indirect call promotion with vtable-based comparison (#81378)

* Given the code sequence
   ```
   bb:
     %vtable = load ptr, ptr %d, !prof !8
     %vfn = getelementptr inbounds ptr, ptr %vtable, i64 1
     %1 = load ptr, ptr %vfn
     %call = tail call i32 %1(ptr %d), !prof !9
  ```
   The transformation looks like

   ```
   bb:
    %vtable = load ptr, ptr %d, align 8
    %vfn = getelementptr inbounds i8, ptr %vtable, i64 8  <-- Inst 1
    %func-addr = load ptr, ptr %vfn, align 8  <-- Inst 2
    # compare loaded pointers with address point of vtables
    %1 = icmp eq ptr %vtable, getelementptr inbounds (i8, ptr @_ZTV<VTable>, i32 16)
    br i1 %1, label %if.true.direct_targ, label %if.false.orig_indirect, !prof !18

  if.true.direct_targ:                              ; preds = %bb
    %2 = tail call i32 @<direct-call>(ptr nonnull %d)
    br label %if.end.icp

  if.false.orig_indirect:                           ; preds = %bb
    %call = tail call i32 %func-addr(ptr nonnull %d)
    br label %if.end.icp

  if.end.icp:                                       ; preds = %if.false.orig_indirect, %if.true.direct_targ
    %4 = phi i32 [ %call, %if.false.orig_indirect ], [ %2, %if.true.direct_targ ]

   ```
It's intentional that `Inst 1` and `Inst 2` remain in `bb` (not in
`if.false.orig_indirect`). A follow-up patch will implement code to sink
them, similar to how `instcombine`
[sinks](https://github.com/llvm/llvm-project/blob/2fcfc9754a16805b81e541dc8222a8b5cf17a121/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp#L4293)
instructions along with their
[debug intrinsics](https://github.com/llvm/llvm-project/blob/2fcfc9754a16805b81e541dc8222a8b5cf17a121/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp#L4356-L4368),
if possible.

* The parent patch is https://github.com/llvm/llvm-project/pull/81181

Added: 
    

Modified: 
    llvm/include/llvm/Transforms/Utils/CallPromotionUtils.h
    llvm/lib/Transforms/Utils/CallPromotionUtils.cpp
    llvm/unittests/Transforms/Utils/CallPromotionUtilsTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Transforms/Utils/CallPromotionUtils.h b/llvm/include/llvm/Transforms/Utils/CallPromotionUtils.h
index fcb384ec36133..385831f457038 100644
--- a/llvm/include/llvm/Transforms/Utils/CallPromotionUtils.h
+++ b/llvm/include/llvm/Transforms/Utils/CallPromotionUtils.h
@@ -15,9 +15,12 @@
 #define LLVM_TRANSFORMS_UTILS_CALLPROMOTIONUTILS_H
 
 namespace llvm {
+template <typename T> class ArrayRef;
+class Constant;
 class CallBase;
 class CastInst;
 class Function;
+class Instruction;
 class MDNode;
 class Value;
 
@@ -41,7 +44,9 @@ bool isLegalToPromote(const CallBase &CB, Function *Callee,
 CallBase &promoteCall(CallBase &CB, Function *Callee,
                       CastInst **RetBitCast = nullptr);
 
-/// Promote the given indirect call site to conditionally call \p Callee.
+/// Promote the given indirect call site to conditionally call \p Callee. The
+/// promoted direct call instruction is predicated on `CB.getCalledOperand() ==
+/// Callee`.
 ///
 /// This function creates an if-then-else structure at the location of the call
 /// site. The original call site is moved into the "else" block. A clone of the
@@ -51,6 +56,22 @@ CallBase &promoteCall(CallBase &CB, Function *Callee,
 CallBase &promoteCallWithIfThenElse(CallBase &CB, Function *Callee,
                                     MDNode *BranchWeights = nullptr);
 
+/// This is similar to `promoteCallWithIfThenElse` except that the condition to
+/// promote a virtual call is that \p VPtr is the same as any of \p
+/// AddressPoints. Returns the promoted direct call to \p Callee.
+///
+/// This function is expected to be used on virtual calls (a subset of indirect
+/// calls). \p VPtr is the virtual table address stored in the objects, and the
+/// non-empty \p AddressPoints holds vtable address points, i.e. the locations
+/// inside vtables that are referenced by the vpointer in C++ objects.
+///
+/// TODO: sink the address-calculation instructions of indirect callee to the
+/// indirect call fallback after transformation.
+CallBase &promoteCallWithVTableCmp(CallBase &CB, Instruction *VPtr,
+                                   Function *Callee,
+                                   ArrayRef<Constant *> AddressPoints,
+                                   MDNode *BranchWeights);
+
 /// Try to promote (devirtualize) a virtual call on an Alloca. Return true on
 /// success.
 ///
@@ -76,11 +97,11 @@ bool tryPromoteCall(CallBase &CB);
 
 /// Predicate and clone the given call site.
 ///
-/// This function creates an if-then-else structure at the location of the call
-/// site. The "if" condition compares the call site's called value to the given
-/// callee. The original call site is moved into the "else" block, and a clone
-/// of the call site is placed in the "then" block. The cloned instruction is
-/// returned.
+/// This function creates an if-then-else structure at the location of the
+/// call site. The "if" condition compares the call site's called value to
+/// the given callee. The original call site is moved into the "else" block,
+/// and a clone of the call site is placed in the "then" block. The cloned
+/// instruction is returned.
 CallBase &versionCallSite(CallBase &CB, Value *Callee, MDNode *BranchWeights);
 
 } // end namespace llvm

diff  --git a/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp b/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp
index 9ca9aaf9ee9df..dda80d419999d 100644
--- a/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp
+++ b/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp
@@ -12,9 +12,11 @@
 //===----------------------------------------------------------------------===//
 
 #include "llvm/Transforms/Utils/CallPromotionUtils.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/Analysis/Loads.h"
 #include "llvm/Analysis/TypeMetadataUtils.h"
 #include "llvm/IR/AttributeMask.h"
+#include "llvm/IR/Constant.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
@@ -188,9 +190,9 @@ static void createRetBitCast(CallBase &CB, Type *RetTy, CastInst **RetBitCast) {
 /// Predicate and clone the given call site.
 ///
 /// This function creates an if-then-else structure at the location of the call
-/// site. The "if" condition is specified by `Cond`. The original call site is
-/// moved into the "else" block, and a clone of the call site is placed in the
-/// "then" block. The cloned instruction is returned.
+/// site. The "if" condition is specified by `Cond`.
+/// The original call site is moved into the "else" block, and a clone of the
+/// call site is placed in the "then" block. The cloned instruction is returned.
 ///
 /// For example, the call instruction below:
 ///
@@ -518,7 +520,8 @@ CallBase &llvm::promoteCall(CallBase &CB, Function *Callee,
     Type *FormalTy = CalleeType->getParamType(ArgNo);
     Type *ActualTy = Arg->getType();
     if (FormalTy != ActualTy) {
-      auto *Cast = CastInst::CreateBitOrPointerCast(Arg, FormalTy, "", CB.getIterator());
+      auto *Cast =
+          CastInst::CreateBitOrPointerCast(Arg, FormalTy, "", CB.getIterator());
       CB.setArgOperand(ArgNo, Cast);
 
       // Remove any incompatible attributes for the argument.
@@ -568,6 +571,27 @@ CallBase &llvm::promoteCallWithIfThenElse(CallBase &CB, Function *Callee,
   return promoteCall(NewInst, Callee);
 }
 
+CallBase &llvm::promoteCallWithVTableCmp(CallBase &CB, Instruction *VPtr,
+                                         Function *Callee,
+                                         ArrayRef<Constant *> AddressPoints,
+                                         MDNode *BranchWeights) {
+  assert(!AddressPoints.empty() && "Caller should guarantee");
+  // Emit one `VPtr == AddressPoint` comparison per candidate at CB's position.
+  IRBuilder<> Builder(&CB);
+  SmallVector<Value *, 2> VTableMatches;
+  for (Constant *AddressPoint : AddressPoints)
+    VTableMatches.push_back(Builder.CreateICmpEQ(VPtr, AddressPoint));
+
+  // OR the comparisons together: any matching vtable selects the direct call.
+  // TODO: Perform tree height reduction if the number of ICmps is high.
+  Value *AnyMatch = Builder.CreateOr(VTableMatches);
+
+  // Version the call site on AnyMatch: the returned clone runs when it is
+  // true, the original indirect call otherwise. Bind the clone to Callee.
+  return promoteCall(versionCallSiteWithCond(CB, AnyMatch, BranchWeights),
+                     Callee);
+}
+
 bool llvm::tryPromoteCall(CallBase &CB) {
   assert(!CB.getCalledFunction());
   Module *M = CB.getCaller()->getParent();

diff  --git a/llvm/unittests/Transforms/Utils/CallPromotionUtilsTest.cpp b/llvm/unittests/Transforms/Utils/CallPromotionUtilsTest.cpp
index 0e9641c5846f3..2d457eb3b678a 100644
--- a/llvm/unittests/Transforms/Utils/CallPromotionUtilsTest.cpp
+++ b/llvm/unittests/Transforms/Utils/CallPromotionUtilsTest.cpp
@@ -8,9 +8,12 @@
 
 #include "llvm/Transforms/Utils/CallPromotionUtils.h"
 #include "llvm/AsmParser/Parser.h"
+#include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/MDBuilder.h"
 #include "llvm/IR/Module.h"
+#include "llvm/IR/NoFolder.h"
 #include "llvm/Support/SourceMgr.h"
 #include "gtest/gtest.h"
 
@@ -24,6 +27,21 @@ static std::unique_ptr<Module> parseIR(LLVMContext &C, const char *IR) {
   return Mod;
 }
 
+// Builds the constant `getelementptr inbounds (i8, VTable, AddressPointOffset)`,
+// i.e. the address point located that many bytes into the vtable global.
+static Constant *getVTableAddressPointOffset(GlobalVariable *VTable,
+                                             uint32_t AddressPointOffset) {
+  Module *ParentModule = VTable->getParent();
+  assert(AddressPointOffset < ParentModule->getDataLayout().getTypeAllocSize(
+                                  VTable->getValueType()) &&
+         "Out-of-bound access");
+  LLVMContext &Ctx = ParentModule->getContext();
+  Constant *ByteOffset =
+      llvm::ConstantInt::get(Type::getInt32Ty(Ctx), AddressPointOffset);
+  return ConstantExpr::getInBoundsGetElementPtr(Type::getInt8Ty(Ctx), VTable,
+                                                ByteOffset);
+}
+
 TEST(CallPromotionUtilsTest, TryPromoteCall) {
   LLVMContext C;
   std::unique_ptr<Module> M = parseIR(C,
@@ -368,3 +386,73 @@ declare %struct2 @_ZN4Impl3RunEv(%class.Impl* %this)
   bool IsPromoted = tryPromoteCall(*CI);
   EXPECT_FALSE(IsPromoted);
 }
+
+TEST(CallPromotionUtilsTest, promoteCallWithVTableCmp) {
+  LLVMContext C;
+  std::unique_ptr<Module> M = parseIR(C,
+                                      R"IR(
+target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
+target triple = "x86_64-unknown-linux-gnu"
+
+@_ZTV5Base1 = constant { [4 x ptr] } { [4 x ptr] [ptr null, ptr null, ptr @_ZN5Base15func0Ev, ptr @_ZN5Base15func1Ev] }, !type !0
+@_ZTV8Derived1 = constant { [4 x ptr], [3 x ptr] } { [4 x ptr] [ptr inttoptr (i64 -8 to ptr), ptr null, ptr @_ZN5Base15func0Ev, ptr @_ZN5Base15func1Ev], [3 x ptr] [ptr null, ptr null, ptr @_ZN5Base25func2Ev] }, !type !0, !type !1, !type !2
+@_ZTV8Derived2 = constant { [3 x ptr], [3 x ptr], [4 x ptr] } { [3 x ptr] [ptr null, ptr null, ptr @_ZN5Base35func3Ev], [3 x ptr] [ptr inttoptr (i64 -8 to ptr), ptr null, ptr @_ZN5Base25func2Ev], [4 x ptr] [ptr inttoptr (i64 -16 to ptr), ptr null, ptr @_ZN5Base15func0Ev, ptr @_ZN5Base15func1Ev] }, !type !3, !type !4, !type !5, !type !6
+
+define i32 @testfunc(ptr %d) {
+entry:
+  %vtable = load ptr, ptr %d, !prof !7
+  %vfn = getelementptr inbounds ptr, ptr %vtable, i64 1
+  %0 = load ptr, ptr %vfn
+  %call = tail call i32 %0(ptr %d), !prof !8
+  ret i32 %call
+}
+
+define i32 @_ZN5Base15func1Ev(ptr %this) {
+entry:
+  ret i32 2
+}
+
+declare i32 @_ZN5Base25func2Ev(ptr)
+declare i32 @_ZN5Base15func0Ev(ptr)
+declare void @_ZN5Base35func3Ev(ptr)
+
+!0 = !{i64 16, !"_ZTS5Base1"}
+!1 = !{i64 48, !"_ZTS5Base2"}
+!2 = !{i64 16, !"_ZTS8Derived1"}
+!3 = !{i64 64, !"_ZTS5Base1"}
+!4 = !{i64 40, !"_ZTS5Base2"}
+!5 = !{i64 16, !"_ZTS5Base3"}
+!6 = !{i64 16, !"_ZTS8Derived2"}
+!7 = !{!"VP", i32 2, i64 1600, i64 -9064381665493407289, i64 800, i64 5035968517245772950, i64 500, i64 3215870116411581797, i64 300}
+!8 = !{!"VP", i32 0, i64 1600, i64 6804820478065511155, i64 1600})IR");
+  // The indirect call is the instruction right before the entry block's ret.
+  Function *F = M->getFunction("testfunc");
+  CallInst *CI = dyn_cast<CallInst>(&*std::next(F->front().rbegin()));
+  ASSERT_TRUE(CI && CI->isIndirectCall());
+
+  // Create the vtable address points (byte offsets into each vtable global).
+  SmallVector<Constant *, 3> VTableAddressPoints;
+
+  for (auto &[VTableName, AddressPointOffset] : {std::pair{"_ZTV5Base1", 16},
+                                                 {"_ZTV8Derived1", 16},
+                                                 {"_ZTV8Derived2", 64}})
+    VTableAddressPoints.push_back(getVTableAddressPointOffset(
+        M->getGlobalVariable(VTableName), AddressPointOffset));
+
+  MDBuilder MDB(C);
+  MDNode *BranchWeights = MDB.createBranchWeights(1600, 0);
+
+  size_t OrigEntryBBSize = F->front().size();
+
+  LoadInst *VPtr = dyn_cast<LoadInst>(&*F->front().begin());
+  ASSERT_TRUE(VPtr);
+  Function *Callee = M->getFunction("_ZN5Base15func1Ev");
+  // Tests that promoted direct call is returned.
+  CallBase &DirectCB = promoteCallWithVTableCmp(
+      *CI, VPtr, Callee, VTableAddressPoints, BranchWeights);
+  EXPECT_EQ(DirectCB.getCalledOperand(), Callee);
+
+  // Promotion inserts 3 icmp instructions and 2 or instructions, and removes
+  // 1 call instruction from the entry block.
+  EXPECT_EQ(F->front().size(), OrigEntryBBSize + 4);
+}


        


More information about the llvm-commits mailing list