[llvm] r208910 - Teach the inliner how to preserve musttail invariants

Reid Kleckner reid at kleckner.net
Thu May 15 13:11:29 PDT 2014


Author: rnk
Date: Thu May 15 15:11:28 2014
New Revision: 208910

URL: http://llvm.org/viewvc/llvm-project?rev=208910&view=rev
Log:
Teach the inliner how to preserve musttail invariants

The interesting case is what happens when you inline a musttail call
through a musttail call site.  In this case, we can't break perfect
forwarding or allow any stack growth.
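
For a concrete picture, take two mutually recursive functions (a
hand-written sketch, not one of the tests added in this commit):

  define void @f() {
    musttail call void @g()
    ret void
  }
  define internal void @g() {
    musttail call void @f()
    ret void
  }

If inlining @g into @f demoted the cloned "musttail call void @f()" to a
plain call, every trip around the recursion would consume a stack frame,
so the cloned call has to stay musttail.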

Instead of merging control flow from the inlined return instruction
after a musttail call into the body of the caller, leave the inlined
return instruction in the caller so that the musttail call stays in the
tail position.
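
Concretely, for the test_musttail_basic functions in the updated test
below, inlining @test_musttail_basic_b into @test_musttail_basic_a should
leave the cloned call and its return in tail position:

  define void @test_musttail_basic_a(i32* %p) {
    musttail call void @test_musttail_basic_c(i32* %p)
    ret void
  }

rather than turning the inlined ret into a branch to a merge block after
the call site.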

More work is required in http://reviews.llvm.org/D3630 to handle the
case where the inlined function has dynamic allocas or byval arguments.

Reviewers: chandlerc

Differential Revision: http://reviews.llvm.org/D3491

Modified:
    llvm/trunk/lib/Transforms/Utils/InlineFunction.cpp
    llvm/trunk/test/Transforms/Inline/inline-tail.ll

Modified: llvm/trunk/lib/Transforms/Utils/InlineFunction.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Transforms/Utils/InlineFunction.cpp?rev=208910&r1=208909&r2=208910&view=diff
==============================================================================
--- llvm/trunk/lib/Transforms/Utils/InlineFunction.cpp (original)
+++ llvm/trunk/lib/Transforms/Utils/InlineFunction.cpp Thu May 15 15:11:28 2014
@@ -19,6 +19,7 @@
 #include "llvm/Analysis/InstructionSimplify.h"
 #include "llvm/IR/Attributes.h"
 #include "llvm/IR/CallSite.h"
+#include "llvm/IR/CFG.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/DataLayout.h"
 #include "llvm/IR/DebugInfo.h"
@@ -478,6 +479,33 @@ static void fixupLineNumbers(Function *F
   }
 }
 
+/// Returns a musttail call instruction if one immediately precedes the given
+/// return instruction with an optional bitcast instruction between them.
+static CallInst *getPrecedingMustTailCall(ReturnInst *RI) {
+  Instruction *Prev = RI->getPrevNode();
+  if (!Prev)
+    return nullptr;
+
+  if (Value *RV = RI->getReturnValue()) {
+    if (RV != Prev)
+      return nullptr;
+
+    // Look through the optional bitcast.
+    if (auto *BI = dyn_cast<BitCastInst>(Prev)) {
+      RV = BI->getOperand(0);
+      Prev = BI->getPrevNode();
+      if (!Prev || RV != Prev)
+        return nullptr;
+    }
+  }
+
+  if (auto *CI = dyn_cast<CallInst>(Prev)) {
+    if (CI->isMustTailCall())
+      return CI;
+  }
+  return nullptr;
+}
+
 /// InlineFunction - This function inlines the called function into the basic
 /// block of the caller.  This returns false if it is not possible to inline
 /// this call.  The program is still in a well defined state if this occurs
@@ -503,8 +531,9 @@ bool llvm::InlineFunction(CallSite CS, I
 
   // If the call to the callee is not a tail call, we must clear the 'tail'
   // flags on any calls that we inline.
-  bool MustClearTailCallFlags =
-    !(isa<CallInst>(TheCall) && cast<CallInst>(TheCall)->isTailCall());
+  CallInst::TailCallKind CallSiteTailKind = CallInst::TCK_None;
+  if (CallInst *CI = dyn_cast<CallInst>(TheCall))
+    CallSiteTailKind = CI->getTailCallKind();
 
   // If the call to the callee cannot throw, set the 'nounwind' flag on any
   // calls that we inline.
@@ -661,6 +690,41 @@ bool llvm::InlineFunction(CallSite CS, I
     }
   }
 
+  bool InlinedMustTailCalls = false;
+  if (InlinedFunctionInfo.ContainsCalls) {
+    for (Function::iterator BB = FirstNewBlock, E = Caller->end(); BB != E;
+         ++BB) {
+      for (Instruction &I : *BB) {
+        CallInst *CI = dyn_cast<CallInst>(&I);
+        if (!CI)
+          continue;
+
+        // We need to reduce the strength of any inlined tail calls.  For
+        // musttail, we have to avoid introducing potential unbounded stack
+        // growth.  For example, if functions 'f' and 'g' are mutually recursive
+        // with musttail, we can inline 'g' into 'f' so long as we preserve
+        // musttail on the cloned call to 'f'.  If either the inlined call site
+        // or the cloned call site is *not* musttail, the program already has
+        // one frame of stack growth, so it's safe to remove musttail.  Here is
+        // a table of example transformations:
+        //
+        //    f -> musttail g -> musttail f  ==>  f -> musttail f
+        //    f -> musttail g ->     tail f  ==>  f ->     tail f
+        //    f ->          g -> musttail f  ==>  f ->          f
+        //    f ->          g ->     tail f  ==>  f ->          f
+        CallInst::TailCallKind ChildTCK = CI->getTailCallKind();
+        ChildTCK = std::min(CallSiteTailKind, ChildTCK);
+        CI->setTailCallKind(ChildTCK);
+        InlinedMustTailCalls |= CI->isMustTailCall();
+
+        // Calls inlined through a 'nounwind' call site should be marked
+        // 'nounwind'.
+        if (MarkNoUnwind)
+          CI->setDoesNotThrow();
+      }
+    }
+  }
+
   // Leave lifetime markers for the static alloca's, scoping them to the
   // function we just inlined.
   if (InsertLifetime && !IFI.StaticAllocas.empty()) {
@@ -693,10 +757,8 @@ bool llvm::InlineFunction(CallSite CS, I
       }
 
       builder.CreateLifetimeStart(AI, AllocaSize);
-      for (unsigned ri = 0, re = Returns.size(); ri != re; ++ri) {
-        IRBuilder<> builder(Returns[ri]);
-        builder.CreateLifetimeEnd(AI, AllocaSize);
-      }
+      for (ReturnInst *RI : Returns)
+        IRBuilder<>(RI).CreateLifetimeEnd(AI, AllocaSize);
     }
   }
 
@@ -714,26 +776,8 @@ bool llvm::InlineFunction(CallSite CS, I
 
     // Insert a call to llvm.stackrestore before any return instructions in the
     // inlined function.
-    for (unsigned i = 0, e = Returns.size(); i != e; ++i) {
-      IRBuilder<>(Returns[i]).CreateCall(StackRestore, SavedPtr);
-    }
-  }
-
-  // If we are inlining tail call instruction through a call site that isn't
-  // marked 'tail', we must remove the tail marker for any calls in the inlined
-  // code.  Also, calls inlined through a 'nounwind' call site should be marked
-  // 'nounwind'.
-  if (InlinedFunctionInfo.ContainsCalls &&
-      (MustClearTailCallFlags || MarkNoUnwind)) {
-    for (Function::iterator BB = FirstNewBlock, E = Caller->end();
-         BB != E; ++BB)
-      for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E; ++I)
-        if (CallInst *CI = dyn_cast<CallInst>(I)) {
-          if (MustClearTailCallFlags)
-            CI->setTailCall(false);
-          if (MarkNoUnwind)
-            CI->setDoesNotThrow();
-        }
+    for (ReturnInst *RI : Returns)
+      IRBuilder<>(RI).CreateCall(StackRestore, SavedPtr);
   }
 
   // If we are inlining for an invoke instruction, we must make sure to rewrite
@@ -741,6 +785,42 @@ bool llvm::InlineFunction(CallSite CS, I
   if (InvokeInst *II = dyn_cast<InvokeInst>(TheCall))
     HandleInlinedInvoke(II, FirstNewBlock, InlinedFunctionInfo);
 
+  // Handle any inlined musttail call sites.  In order for a new call site to be
+  // musttail, the source of the clone and the inlined call site must have been
+  // musttail.  Therefore it's safe to return without merging control into the
+  // phi below.
+  if (InlinedMustTailCalls) {
+    // Check if we need to bitcast the result of any musttail calls.
+    Type *NewRetTy = Caller->getReturnType();
+    bool NeedBitCast = !TheCall->use_empty() && TheCall->getType() != NewRetTy;
+
+    // Handle the returns preceded by musttail calls separately.
+    SmallVector<ReturnInst *, 8> NormalReturns;
+    for (ReturnInst *RI : Returns) {
+      CallInst *ReturnedMustTail = getPrecedingMustTailCall(RI);
+      if (!ReturnedMustTail) {
+        NormalReturns.push_back(RI);
+        continue;
+      }
+      if (!NeedBitCast)
+        continue;
+
+      // Delete the old return and any preceding bitcast.
+      BasicBlock *CurBB = RI->getParent();
+      auto *OldCast = dyn_cast_or_null<BitCastInst>(RI->getReturnValue());
+      RI->eraseFromParent();
+      if (OldCast)
+        OldCast->eraseFromParent();
+
+      // Insert a new bitcast and return with the right type.
+      IRBuilder<> Builder(CurBB);
+      Builder.CreateRet(Builder.CreateBitCast(ReturnedMustTail, NewRetTy));
+    }
+
+    // Leave behind the normal returns so we can merge control flow.
+    std::swap(Returns, NormalReturns);
+  }
+
   // If we cloned in _exactly one_ basic block, and if that block ends in a
   // return instruction, we splice the body of the inlined callee directly into
   // the calling basic block.
@@ -896,6 +976,11 @@ bool llvm::InlineFunction(CallSite CS, I
   // Since we are now done with the Call/Invoke, we can delete it.
   TheCall->eraseFromParent();
 
+  // If we inlined any musttail calls and the original return is now
+  // unreachable, delete it.  It can only contain a bitcast and ret.
+  if (InlinedMustTailCalls && pred_begin(AfterCallBB) == pred_end(AfterCallBB))
+    AfterCallBB->eraseFromParent();
+
   // We should always be able to fold the entry block of the function into the
   // single predecessor of the block...
   assert(cast<BranchInst>(Br)->isUnconditional() && "splitBasicBlock broken!");

Modified: llvm/trunk/test/Transforms/Inline/inline-tail.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/Transforms/Inline/inline-tail.ll?rev=208910&r1=208909&r2=208910&view=diff
==============================================================================
--- llvm/trunk/test/Transforms/Inline/inline-tail.ll (original)
+++ llvm/trunk/test/Transforms/Inline/inline-tail.ll Thu May 15 15:11:28 2014
@@ -1,15 +1,146 @@
-; RUN: opt < %s -inline -S | not grep tail
+; RUN: opt < %s -inline -S | FileCheck %s
 
-declare void @bar(i32*)
+; The TailCallKind applied to a cloned call is the weaker (less restrictive)
+; of the inlined call site's kind and the cloned call's own kind.
 
-define internal void @foo(i32* %P) {
-        tail call void @bar( i32* %P )
-        ret void
+; No tail marker after inlining, since test_capture_c captures an alloca.
+; CHECK: define void @test_capture_a(
+; CHECK-NOT: tail
+; CHECK: call void @test_capture_c(
+
+declare void @test_capture_c(i32*)
+define internal void @test_capture_b(i32* %P) {
+  tail call void @test_capture_c(i32* %P)
+  ret void
+}
+define void @test_capture_a() {
+  %A = alloca i32  		; captured by test_capture_b
+  call void @test_capture_b(i32* %A)
+  ret void
+}
+
+; No musttail marker after inlining, since the prototypes don't match.
+; CHECK: define void @test_proto_mismatch_a(
+; CHECK-NOT: musttail
+; CHECK: call void @test_proto_mismatch_c(
+
+declare void @test_proto_mismatch_c(i32*)
+define internal void @test_proto_mismatch_b(i32* %p) {
+  musttail call void @test_proto_mismatch_c(i32* %p)
+  ret void
+}
+define void @test_proto_mismatch_a() {
+  call void @test_proto_mismatch_b(i32* null)
+  ret void
 }
 
-define void @caller() {
-        %A = alloca i32         ; <i32*> [#uses=1]
-        call void @foo( i32* %A )
-        ret void
+; After inlining through a musttail call site, we need to keep musttail markers
+; to prevent unbounded stack growth.
+; CHECK: define void @test_musttail_basic_a(
+; CHECK: musttail call void @test_musttail_basic_c(
+
+declare void @test_musttail_basic_c(i32* %p)
+define internal void @test_musttail_basic_b(i32* %p) {
+  musttail call void @test_musttail_basic_c(i32* %p)
+  ret void
+}
+define void @test_musttail_basic_a(i32* %p) {
+  musttail call void @test_musttail_basic_b(i32* %p)
+  ret void
+}
+
+; We can't merge the returns.
+; CHECK: define void @test_multiret_a(
+; CHECK: musttail call void @test_multiret_c(
+; CHECK-NEXT: ret void
+; CHECK: musttail call void @test_multiret_d(
+; CHECK-NEXT: ret void
+
+declare void @test_multiret_c(i1 zeroext %b)
+declare void @test_multiret_d(i1 zeroext %b)
+define internal void @test_multiret_b(i1 zeroext %b) {
+  br i1 %b, label %c, label %d
+c:
+  musttail call void @test_multiret_c(i1 zeroext %b)
+  ret void
+d:
+  musttail call void @test_multiret_d(i1 zeroext %b)
+  ret void
+}
+define void @test_multiret_a(i1 zeroext %b) {
+  musttail call void @test_multiret_b(i1 zeroext %b)
+  ret void
 }
 
+; We have to avoid bitcast chains.
+; CHECK: define i32* @test_retptr_a(
+; CHECK: musttail call i8* @test_retptr_c(
+; CHECK-NEXT: bitcast i8* {{.*}} to i32*
+; CHECK-NEXT: ret i32*
+
+declare i8* @test_retptr_c()
+define internal i16* @test_retptr_b() {
+  %rv = musttail call i8* @test_retptr_c()
+  %v = bitcast i8* %rv to i16*
+  ret i16* %v
+}
+define i32* @test_retptr_a() {
+  %rv = musttail call i16* @test_retptr_b()
+  %v = bitcast i16* %rv to i32*
+  ret i32* %v
+}
+
+; Combine the last two cases: multiple returns with pointer bitcasts.
+; CHECK: define i32* @test_multiptrret_a(
+; CHECK: musttail call i8* @test_multiptrret_c(
+; CHECK-NEXT: bitcast i8* {{.*}} to i32*
+; CHECK-NEXT: ret i32*
+; CHECK: musttail call i8* @test_multiptrret_d(
+; CHECK-NEXT: bitcast i8* {{.*}} to i32*
+; CHECK-NEXT: ret i32*
+
+declare i8* @test_multiptrret_c(i1 zeroext %b)
+declare i8* @test_multiptrret_d(i1 zeroext %b)
+define internal i16* @test_multiptrret_b(i1 zeroext %b) {
+  br i1 %b, label %c, label %d
+c:
+  %c_rv = musttail call i8* @test_multiptrret_c(i1 zeroext %b)
+  %c_v = bitcast i8* %c_rv to i16*
+  ret i16* %c_v
+d:
+  %d_rv = musttail call i8* @test_multiptrret_d(i1 zeroext %b)
+  %d_v = bitcast i8* %d_rv to i16*
+  ret i16* %d_v
+}
+define i32* @test_multiptrret_a(i1 zeroext %b) {
+  %rv = musttail call i16* @test_multiptrret_b(i1 zeroext %b)
+  %v = bitcast i16* %rv to i32*
+  ret i32* %v
+}
+
+; Inline a musttail call site which contains a normal return and a musttail call.
+; CHECK: define i32 @test_mixedret_a(
+; CHECK: br i1 %b
+; CHECK: musttail call i32 @test_mixedret_c(
+; CHECK-NEXT: ret i32
+; CHECK: call i32 @test_mixedret_d(i1 zeroext %b)
+; CHECK: add i32 1,
+; CHECK-NOT: br
+; CHECK: ret i32
+
+declare i32 @test_mixedret_c(i1 zeroext %b)
+declare i32 @test_mixedret_d(i1 zeroext %b)
+define internal i32 @test_mixedret_b(i1 zeroext %b) {
+  br i1 %b, label %c, label %d
+c:
+  %c_rv = musttail call i32 @test_mixedret_c(i1 zeroext %b)
+  ret i32 %c_rv
+d:
+  %d_rv = call i32 @test_mixedret_d(i1 zeroext %b)
+  %d_rv1 = add i32 1, %d_rv
+  ret i32 %d_rv1
+}
+define i32 @test_mixedret_a(i1 zeroext %b) {
+  %rv = musttail call i32 @test_mixedret_b(i1 zeroext %b)
+  ret i32 %rv
+}