[llvm] r318028 - [PartialInliner] Inline vararg functions that forward varargs.

Florian Hahn via llvm-commits llvm-commits at lists.llvm.org
Mon Nov 13 02:35:52 PST 2017


Author: fhahn
Date: Mon Nov 13 02:35:52 2017
New Revision: 318028

URL: http://llvm.org/viewvc/llvm-project?rev=318028&view=rev
Log:
[PartialInliner] Inline vararg functions that forward varargs.

Summary:
This patch extends the partial inliner to support inlining parts of
vararg functions, if the vararg handling is done in the outlined part.

It adds a `ForwardVarArgsTo` argument to InlineFunction. If it is
non-null, all varargs passed to the inlined function will be added to
all calls to `ForwardVarArgsTo`.

The partial inliner takes care to only pass `ForwardVarArgsTo` if the
varargs handling is done in the outlined function. It checks that vastart
is not part of the function to be inlined.

`test/Transforms/CodeExtractor/PartialInlineNoInline.ll` (already part
of the repo) checks we do not do partial inlining if vastart is used in
a basic block that will be inlined.

Reviewers: davide, davidxl, grosser

Reviewed By: davide, davidxl, grosser

Subscribers: gyiu, grosser, eraman, llvm-commits

Differential Revision: https://reviews.llvm.org/D39607

Added:
    llvm/trunk/test/Transforms/CodeExtractor/PartialInlineVarArg.ll
Modified:
    llvm/trunk/include/llvm/Transforms/Utils/Cloning.h
    llvm/trunk/include/llvm/Transforms/Utils/CodeExtractor.h
    llvm/trunk/lib/Transforms/IPO/PartialInlining.cpp
    llvm/trunk/lib/Transforms/Utils/CodeExtractor.cpp
    llvm/trunk/lib/Transforms/Utils/InlineFunction.cpp

Modified: llvm/trunk/include/llvm/Transforms/Utils/Cloning.h
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/include/llvm/Transforms/Utils/Cloning.h?rev=318028&r1=318027&r2=318028&view=diff
==============================================================================
--- llvm/trunk/include/llvm/Transforms/Utils/Cloning.h (original)
+++ llvm/trunk/include/llvm/Transforms/Utils/Cloning.h Mon Nov 13 02:35:52 2017
@@ -227,12 +227,18 @@ public:
 /// *inlined* code to minimize the actual inserted code, it must not delete
 /// code in the caller as users of this routine may have pointers to
 /// instructions in the caller that need to remain stable.
+///
+/// If ForwardVarArgsTo is passed, inlining a function with varargs is allowed
+/// and all varargs at the callsite will be passed to any calls to
+/// ForwardVarArgsTo. The caller of InlineFunction has to make sure any varargs
+/// are only used by ForwardVarArgsTo.
 bool InlineFunction(CallInst *C, InlineFunctionInfo &IFI,
                     AAResults *CalleeAAR = nullptr, bool InsertLifetime = true);
 bool InlineFunction(InvokeInst *II, InlineFunctionInfo &IFI,
                     AAResults *CalleeAAR = nullptr, bool InsertLifetime = true);
 bool InlineFunction(CallSite CS, InlineFunctionInfo &IFI,
-                    AAResults *CalleeAAR = nullptr, bool InsertLifetime = true);
+                    AAResults *CalleeAAR = nullptr, bool InsertLifetime = true,
+                    Function *ForwardVarArgsTo = nullptr);
 
 /// \brief Clones a loop \p OrigLoop.  Returns the loop and the blocks in \p
 /// Blocks.

Modified: llvm/trunk/include/llvm/Transforms/Utils/CodeExtractor.h
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/include/llvm/Transforms/Utils/CodeExtractor.h?rev=318028&r1=318027&r2=318028&view=diff
==============================================================================
--- llvm/trunk/include/llvm/Transforms/Utils/CodeExtractor.h (original)
+++ llvm/trunk/include/llvm/Transforms/Utils/CodeExtractor.h Mon Nov 13 02:35:52 2017
@@ -56,6 +56,9 @@ class Value;
     BlockFrequencyInfo *BFI;
     BranchProbabilityInfo *BPI;
 
+    // If true, varargs functions can be extracted.
+    bool AllowVarArgs;
+
     // Bits of intermediate state computed at various phases of extraction.
     SetVector<BasicBlock *> Blocks;
     unsigned NumExitBlocks = std::numeric_limits<unsigned>::max();
@@ -67,10 +70,13 @@ class Value;
     /// Given a sequence of basic blocks where the first block in the sequence
     /// dominates the rest, prepare a code extractor object for pulling this
     /// sequence out into its new function. When a DominatorTree is also given,
-    /// extra checking and transformations are enabled.
+    /// extra checking and transformations are enabled. If AllowVarArgs is true,
+    /// vararg functions can be extracted. This is safe, if all vararg handling
+    /// code is extracted, including vastart.
     CodeExtractor(ArrayRef<BasicBlock *> BBs, DominatorTree *DT = nullptr,
                   bool AggregateArgs = false, BlockFrequencyInfo *BFI = nullptr,
-                  BranchProbabilityInfo *BPI = nullptr);
+                  BranchProbabilityInfo *BPI = nullptr,
+                  bool AllowVarArgs = false);
 
     /// \brief Create a code extractor for a loop body.
     ///
@@ -82,8 +88,11 @@ class Value;
 
     /// \brief Check to see if a block is valid for extraction.
     ///
-    /// Blocks containing EHPads, allocas, invokes, or vastarts are not valid.
-    static bool isBlockValidForExtraction(const BasicBlock &BB);
+    /// Blocks containing EHPads, allocas and invokes are not valid. If
+    /// AllowVarArgs is true, blocks with vastart can be extracted. This is
+    /// safe, if all vararg handling code is extracted, including vastart.
+    static bool isBlockValidForExtraction(const BasicBlock &BB,
+                                          bool AllowVarArgs);
 
     /// \brief Perform the extraction, returning the new function.
     ///

Modified: llvm/trunk/lib/Transforms/IPO/PartialInlining.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Transforms/IPO/PartialInlining.cpp?rev=318028&r1=318027&r2=318028&view=diff
==============================================================================
--- llvm/trunk/lib/Transforms/IPO/PartialInlining.cpp (original)
+++ llvm/trunk/lib/Transforms/IPO/PartialInlining.cpp Mon Nov 13 02:35:52 2017
@@ -149,7 +149,12 @@ struct PartialInlinerImpl {
     // the return block.
     void NormalizeReturnBlock();
 
-    // Do function outlining:
+    // Do function outlining.
+    // NOTE: For vararg functions that do the vararg handling in the outlined
+    //       function, we temporarily generate IR that does not properly
+    //       forward varargs to the outlined function. Calling InlineFunction
+    //       will update calls to the outlined functions to properly forward
+    //       the varargs.
     Function *doFunctionOutlining();
 
     Function *OrigFunc = nullptr;
@@ -813,7 +818,8 @@ Function *PartialInlinerImpl::FunctionCl
 
   // Extract the body of the if.
   OutlinedFunc = CodeExtractor(ToExtract, &DT, /*AggregateArgs*/ false,
-                               ClonedFuncBFI.get(), &BPI)
+                               ClonedFuncBFI.get(), &BPI,
+                               /* AllowVarargs */ true)
                      .extractCodeRegion();
 
   if (OutlinedFunc) {
@@ -938,7 +944,7 @@ bool PartialInlinerImpl::tryPartialInlin
        << ore::NV("Caller", CS.getCaller());
 
     InlineFunctionInfo IFI(nullptr, GetAssumptionCache, PSI);
-    if (!InlineFunction(CS, IFI))
+    if (!InlineFunction(CS, IFI, nullptr, true, Cloner.OutlinedFunc))
       continue;
 
     ORE.emit(OR);

Modified: llvm/trunk/lib/Transforms/Utils/CodeExtractor.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Transforms/Utils/CodeExtractor.cpp?rev=318028&r1=318027&r2=318028&view=diff
==============================================================================
--- llvm/trunk/lib/Transforms/Utils/CodeExtractor.cpp (original)
+++ llvm/trunk/lib/Transforms/Utils/CodeExtractor.cpp Mon Nov 13 02:35:52 2017
@@ -78,7 +78,8 @@ AggregateArgsOpt("aggregate-extracted-ar
                  cl::desc("Aggregate arguments to code-extracted functions"));
 
 /// \brief Test whether a block is valid for extraction.
-bool CodeExtractor::isBlockValidForExtraction(const BasicBlock &BB) {
+bool CodeExtractor::isBlockValidForExtraction(const BasicBlock &BB,
+                                              bool AllowVarArgs) {
   // Landing pads must be in the function where they were inserted for cleanup.
   if (BB.isEHPad())
     return false;
@@ -110,14 +111,19 @@ bool CodeExtractor::isBlockValidForExtra
     }
   }
 
-  // Don't hoist code containing allocas, invokes, or vastarts.
+  // Don't hoist code containing allocas or invokes. If explicitly requested,
+  // allow vastart.
   for (BasicBlock::const_iterator I = BB.begin(), E = BB.end(); I != E; ++I) {
     if (isa<AllocaInst>(I) || isa<InvokeInst>(I))
       return false;
     if (const CallInst *CI = dyn_cast<CallInst>(I))
       if (const Function *F = CI->getCalledFunction())
-        if (F->getIntrinsicID() == Intrinsic::vastart)
-          return false;
+        if (F->getIntrinsicID() == Intrinsic::vastart) {
+          if (AllowVarArgs)
+            continue;
+          else
+            return false;
+        }
   }
 
   return true;
@@ -125,7 +131,8 @@ bool CodeExtractor::isBlockValidForExtra
 
 /// \brief Build a set of blocks to extract if the input blocks are viable.
 static SetVector<BasicBlock *>
-buildExtractionBlockSet(ArrayRef<BasicBlock *> BBs, DominatorTree *DT) {
+buildExtractionBlockSet(ArrayRef<BasicBlock *> BBs, DominatorTree *DT,
+                        bool AllowVarArgs) {
   assert(!BBs.empty() && "The set of blocks to extract must be non-empty");
   SetVector<BasicBlock *> Result;
 
@@ -138,7 +145,7 @@ buildExtractionBlockSet(ArrayRef<BasicBl
 
     if (!Result.insert(BB))
       llvm_unreachable("Repeated basic blocks in extraction input");
-    if (!CodeExtractor::isBlockValidForExtraction(*BB)) {
+    if (!CodeExtractor::isBlockValidForExtraction(*BB, AllowVarArgs)) {
       Result.clear();
       return Result;
     }
@@ -160,15 +167,17 @@ buildExtractionBlockSet(ArrayRef<BasicBl
 
 CodeExtractor::CodeExtractor(ArrayRef<BasicBlock *> BBs, DominatorTree *DT,
                              bool AggregateArgs, BlockFrequencyInfo *BFI,
-                             BranchProbabilityInfo *BPI)
+                             BranchProbabilityInfo *BPI, bool AllowVarArgs)
     : DT(DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI),
-      BPI(BPI), Blocks(buildExtractionBlockSet(BBs, DT)) {}
+      BPI(BPI), AllowVarArgs(AllowVarArgs),
+      Blocks(buildExtractionBlockSet(BBs, DT, AllowVarArgs)) {}
 
 CodeExtractor::CodeExtractor(DominatorTree &DT, Loop &L, bool AggregateArgs,
                              BlockFrequencyInfo *BFI,
                              BranchProbabilityInfo *BPI)
     : DT(&DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI),
-      BPI(BPI), Blocks(buildExtractionBlockSet(L.getBlocks(), &DT)) {}
+      BPI(BPI), Blocks(buildExtractionBlockSet(L.getBlocks(), &DT,
+                                               /* AllowVarArgs */ false)) {}
 
 /// definedInRegion - Return true if the specified value is defined in the
 /// extracted region.
@@ -594,7 +603,8 @@ Function *CodeExtractor::constructFuncti
     paramTy.push_back(PointerType::getUnqual(StructTy));
   }
   FunctionType *funcType =
-                  FunctionType::get(RetTy, paramTy, false);
+                  FunctionType::get(RetTy, paramTy,
+                                    AllowVarArgs && oldFunction->isVarArg());
 
   // Create the new function
   Function *newFunction = Function::Create(funcType,
@@ -957,12 +967,31 @@ Function *CodeExtractor::extractCodeRegi
   if (!isEligible())
     return nullptr;
 
-  ValueSet inputs, outputs, SinkingCands, HoistingCands;
-  BasicBlock *CommonExit = nullptr;
-
   // Assumption: this is a single-entry code region, and the header is the first
   // block in the region.
   BasicBlock *header = *Blocks.begin();
+  Function *oldFunction = header->getParent();
+
+  // For functions with varargs, check that varargs handling is only done in the
+  // outlined function, i.e vastart and vaend are only used in outlined blocks.
+  if (AllowVarArgs && oldFunction->getFunctionType()->isVarArg()) {
+    auto containsVarArgIntrinsic = [](Instruction &I) {
+      if (const CallInst *CI = dyn_cast<CallInst>(&I))
+        if (const Function *F = CI->getCalledFunction())
+          return F->getIntrinsicID() == Intrinsic::vastart ||
+                 F->getIntrinsicID() == Intrinsic::vaend;
+      return false;
+    };
+
+    for (auto &BB : *oldFunction) {
+      if (Blocks.count(&BB))
+        continue;
+      if (llvm::any_of(BB, containsVarArgIntrinsic))
+        return nullptr;
+    }
+  }
+  ValueSet inputs, outputs, SinkingCands, HoistingCands;
+  BasicBlock *CommonExit = nullptr;
 
   // Calculate the entry frequency of the new function before we change the root
   //   block.
@@ -984,8 +1013,6 @@ Function *CodeExtractor::extractCodeRegi
   // that the return is not in the region.
   splitReturnBlocks();
 
-  Function *oldFunction = header->getParent();
-
   // This takes place of the original loop
   BasicBlock *codeReplacer = BasicBlock::Create(header->getContext(), 
                                                 "codeRepl", oldFunction,

Modified: llvm/trunk/lib/Transforms/Utils/InlineFunction.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Transforms/Utils/InlineFunction.cpp?rev=318028&r1=318027&r2=318028&view=diff
==============================================================================
--- llvm/trunk/lib/Transforms/Utils/InlineFunction.cpp (original)
+++ llvm/trunk/lib/Transforms/Utils/InlineFunction.cpp Mon Nov 13 02:35:52 2017
@@ -1490,7 +1490,8 @@ static void updateCalleeCount(BlockFrequ
 /// exists in the instruction stream.  Similarly this will inline a recursive
 /// function by one level.
 bool llvm::InlineFunction(CallSite CS, InlineFunctionInfo &IFI,
-                          AAResults *CalleeAAR, bool InsertLifetime) {
+                          AAResults *CalleeAAR, bool InsertLifetime,
+                          Function *ForwardVarArgsTo) {
   Instruction *TheCall = CS.getInstruction();
   assert(TheCall->getParent() && TheCall->getFunction()
          && "Instruction not in function!");
@@ -1500,8 +1501,9 @@ bool llvm::InlineFunction(CallSite CS, I
 
   Function *CalledFunc = CS.getCalledFunction();
   if (!CalledFunc ||              // Can't inline external function or indirect
-      CalledFunc->isDeclaration() || // call, or call to a vararg function!
-      CalledFunc->getFunctionType()->isVarArg()) return false;
+      CalledFunc->isDeclaration() ||
+      (!ForwardVarArgsTo && CalledFunc->isVarArg())) // call, or call to a vararg function!
+      return false;
 
   // The inliner does not know how to inline through calls with operand bundles
   // in general ...
@@ -1628,8 +1630,8 @@ bool llvm::InlineFunction(CallSite CS, I
 
     auto &DL = Caller->getParent()->getDataLayout();
 
-    assert(CalledFunc->arg_size() == CS.arg_size() &&
-           "No varargs calls can be inlined!");
+    assert((CalledFunc->arg_size() == CS.arg_size() || ForwardVarArgsTo) &&
+           "Varargs calls can only be inlined if the Varargs are forwarded!");
 
     // Calculate the vector of arguments to pass into the function cloner, which
     // matches up the formal to the actual argument values.
@@ -1811,6 +1813,11 @@ bool llvm::InlineFunction(CallSite CS, I
       replaceDbgDeclareForAlloca(AI, AI, DIB, /*Deref=*/false);
   }
 
+  SmallVector<Value*,4> VarArgsToForward;
+  for (unsigned i = CalledFunc->getFunctionType()->getNumParams();
+       i < CS.getNumArgOperands(); i++)
+    VarArgsToForward.push_back(CS.getArgOperand(i));
+
   bool InlinedMustTailCalls = false, InlinedDeoptimizeCalls = false;
   if (InlinedFunctionInfo.ContainsCalls) {
     CallInst::TailCallKind CallSiteTailKind = CallInst::TCK_None;
@@ -1819,7 +1826,8 @@ bool llvm::InlineFunction(CallSite CS, I
 
     for (Function::iterator BB = FirstNewBlock, E = Caller->end(); BB != E;
          ++BB) {
-      for (Instruction &I : *BB) {
+      for (auto II = BB->begin(); II != BB->end();) {
+        Instruction &I = *II++;
         CallInst *CI = dyn_cast<CallInst>(&I);
         if (!CI)
           continue;
@@ -1850,6 +1858,14 @@ bool llvm::InlineFunction(CallSite CS, I
         // 'nounwind'.
         if (MarkNoUnwind)
           CI->setDoesNotThrow();
+
+        if (ForwardVarArgsTo && CI->getCalledFunction() == ForwardVarArgsTo) {
+          SmallVector<Value*, 6> Params(CI->arg_operands());
+          Params.append(VarArgsToForward.begin(), VarArgsToForward.end());
+          CallInst *Call = CallInst::Create(CI->getCalledFunction(), Params, "", CI);
+          CI->replaceAllUsesWith(Call);
+          CI->eraseFromParent();
+        }
       }
     }
   }

Added: llvm/trunk/test/Transforms/CodeExtractor/PartialInlineVarArg.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/Transforms/CodeExtractor/PartialInlineVarArg.ll?rev=318028&view=auto
==============================================================================
--- llvm/trunk/test/Transforms/CodeExtractor/PartialInlineVarArg.ll (added)
+++ llvm/trunk/test/Transforms/CodeExtractor/PartialInlineVarArg.ll Mon Nov 13 02:35:52 2017
@@ -0,0 +1,83 @@
+; RUN: opt < %s -partial-inliner -S -skip-partial-inlining-cost-analysis | FileCheck %s
+; RUN: opt < %s -passes=partial-inliner -S -skip-partial-inlining-cost-analysis | FileCheck %s
+
+@stat = external global i32, align 4
+
+define i32 @vararg(i32 %count, ...) {
+entry:
+  %vargs = alloca i8*, align 8
+  %stat1 = load i32, i32* @stat, align 4
+  %cmp = icmp slt i32 %stat1, 0
+  br i1 %cmp, label %bb2, label %bb1
+
+bb1:                                              ; preds = %entry
+  %vg1 = add nsw i32 %stat1, 1
+  store i32 %vg1, i32* @stat, align 4
+  %vargs1 = bitcast i8** %vargs to i8*
+  call void @llvm.va_start(i8* %vargs1)
+  %va1 = va_arg i8** %vargs, i32
+  call void @foo(i32 %count, i32 %va1) #2
+  call void @llvm.va_end(i8* %vargs1)
+  br label %bb2
+
+bb2:                                              ; preds = %bb1, %entry
+  %res = phi i32 [ 1, %bb1 ], [ 0, %entry ]
+  ret i32 %res
+}
+
+declare void @foo(i32, i32)
+declare void @llvm.va_start(i8*)
+declare void @llvm.va_end(i8*)
+
+define i32 @caller1(i32 %arg) {
+bb:
+  %tmp = tail call i32 (i32, ...) @vararg(i32 %arg)
+  ret i32 %tmp
+}
+; CHECK-LABEL: @caller1
+; CHECK: codeRepl.i:
+; CHECK-NEXT:  call void (i32, i8**, i32, ...) @vararg.2_bb1(i32 %stat1.i, i8** %vargs.i, i32 %arg)
+
+define i32 @caller2(i32 %arg, float %arg2) {
+bb:
+  %tmp = tail call i32 (i32, ...) @vararg(i32 %arg, i32 10, float %arg2)
+  ret i32 %tmp
+}
+
+; CHECK-LABEL: @caller2
+; CHECK: codeRepl.i:
+; CHECK-NEXT:  call void (i32, i8**, i32, ...) @vararg.2_bb1(i32 %stat1.i, i8** %vargs.i, i32 %arg, i32 10, float %arg2)
+
+; Test case to check that we do not extract a vararg function, if va_end is in
+; a block that is not outlined.
+define i32 @vararg_not_legal(i32 %count, ...) {
+entry:
+  %vargs = alloca i8*, align 8
+  %vargs0 = bitcast i8** %vargs to i8*
+  %stat1 = load i32, i32* @stat, align 4
+  %cmp = icmp slt i32 %stat1, 0
+  br i1 %cmp, label %bb2, label %bb1
+
+bb1:                                              ; preds = %entry
+  %vg1 = add nsw i32 %stat1, 1
+  store i32 %vg1, i32* @stat, align 4
+  %vargs1 = bitcast i8** %vargs to i8*
+  call void @llvm.va_start(i8* %vargs1)
+  %va1 = va_arg i8** %vargs, i32
+  call void @foo(i32 %count, i32 %va1)
+  br label %bb2
+
+bb2:                                              ; preds = %bb1, %entry
+  %res = phi i32 [ 1, %bb1 ], [ 0, %entry ]
+  %ptr = phi i8* [ %vargs1, %bb1 ], [ %vargs0, %entry]
+  call void @llvm.va_end(i8* %ptr)
+  ret i32 %res
+}
+
+; CHECK-LABEL: @caller3
+; CHECK: tail call i32 (i32, ...) @vararg_not_legal(i32 %arg, i32 %arg)
+define i32 @caller3(i32 %arg) {
+bb:
+  %res = tail call i32 (i32, ...) @vararg_not_legal(i32 %arg, i32 %arg)
+  ret i32 %res
+}




More information about the llvm-commits mailing list