[llvm] b8a8ef3 - [SafeStack] Make sure SafeStack does not break musttail call contract

Xun Li via llvm-commits llvm-commits at lists.llvm.org
Tue Nov 10 20:46:17 PST 2020


Author: Xun Li
Date: 2020-11-10T20:46:05-08:00
New Revision: b8a8ef32762bfe02b10495595e578002b29c8dc8

URL: https://github.com/llvm/llvm-project/commit/b8a8ef32762bfe02b10495595e578002b29c8dc8
DIFF: https://github.com/llvm/llvm-project/commit/b8a8ef32762bfe02b10495595e578002b29c8dc8.diff

LOG: [SafeStack] Make sure SafeStack does not break musttail call contract

SafeStack instrumentation must not insert anything between a musttail call and the return instruction.
For every ReturnInst that needs to be instrumented, the insertion point is adjusted to the preceding musttail call, if one exists.

Differential Revision: https://reviews.llvm.org/D90702
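
As a minimal sketch of the pattern (not the patch itself; the helper name below is
hypothetical, but BasicBlock::getTerminatingMustTailCall() is the LLVM API the
change relies on):

  // Sketch: choose where epilogue instrumentation may go so that nothing
  // lands between a musttail call and its ret.
  #include "llvm/IR/BasicBlock.h"
  #include "llvm/IR/Instructions.h"
  using namespace llvm;

  // Hypothetical helper mirroring the adjustment made in findInsts().
  static Instruction *epilogueInsertionPoint(ReturnInst *RI) {
    // getTerminatingMustTailCall() returns the musttail call that precedes
    // this block's ret (possibly through a bitcast), or nullptr if none.
    if (CallInst *CI = RI->getParent()->getTerminatingMustTailCall())
      return CI; // instrument before the musttail call
    return RI;   // otherwise instrument directly before the ret
  }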

Added: 
    llvm/test/Transforms/SafeStack/X86/musttail.ll

Modified: 
    llvm/lib/CodeGen/SafeStack.cpp

Removed: 
    


################################################################################
diff --git a/llvm/lib/CodeGen/SafeStack.cpp b/llvm/lib/CodeGen/SafeStack.cpp
index 55478c232dd7..52d3178c4d20 100644
--- a/llvm/lib/CodeGen/SafeStack.cpp
+++ b/llvm/lib/CodeGen/SafeStack.cpp
@@ -151,7 +151,7 @@ class SafeStack {
   Value *getStackGuard(IRBuilder<> &IRB, Function &F);
 
   /// Load stack guard from the frame and check if it has changed.
-  void checkStackGuard(IRBuilder<> &IRB, Function &F, ReturnInst &RI,
+  void checkStackGuard(IRBuilder<> &IRB, Function &F, Instruction &RI,
                        AllocaInst *StackGuardSlot, Value *StackGuard);
 
   /// Find all static allocas, dynamic allocas, return instructions and
@@ -160,7 +160,7 @@ class SafeStack {
   void findInsts(Function &F, SmallVectorImpl<AllocaInst *> &StaticAllocas,
                  SmallVectorImpl<AllocaInst *> &DynamicAllocas,
                  SmallVectorImpl<Argument *> &ByValArguments,
-                 SmallVectorImpl<ReturnInst *> &Returns,
+                 SmallVectorImpl<Instruction *> &Returns,
                  SmallVectorImpl<Instruction *> &StackRestorePoints);
 
   /// Calculate the allocation size of a given alloca. Returns 0 if the
@@ -168,15 +168,13 @@ class SafeStack {
   uint64_t getStaticAllocaAllocationSize(const AllocaInst* AI);
 
   /// Allocate space for all static allocas in \p StaticAllocas,
-  /// replace allocas with pointers into the unsafe stack and generate code to
-  /// restore the stack pointer before all return instructions in \p Returns.
+  /// replace allocas with pointers into the unsafe stack.
   ///
   /// \returns A pointer to the top of the unsafe stack after all unsafe static
   /// allocas are allocated.
   Value *moveStaticAllocasToUnsafeStack(IRBuilder<> &IRB, Function &F,
                                         ArrayRef<AllocaInst *> StaticAllocas,
                                         ArrayRef<Argument *> ByValArguments,
-                                        ArrayRef<ReturnInst *> Returns,
                                         Instruction *BasePointer,
                                         AllocaInst *StackGuardSlot);
 
@@ -383,7 +381,7 @@ void SafeStack::findInsts(Function &F,
                           SmallVectorImpl<AllocaInst *> &StaticAllocas,
                           SmallVectorImpl<AllocaInst *> &DynamicAllocas,
                           SmallVectorImpl<Argument *> &ByValArguments,
-                          SmallVectorImpl<ReturnInst *> &Returns,
+                          SmallVectorImpl<Instruction *> &Returns,
                           SmallVectorImpl<Instruction *> &StackRestorePoints) {
   for (Instruction &I : instructions(&F)) {
     if (auto AI = dyn_cast<AllocaInst>(&I)) {
@@ -401,7 +399,10 @@ void SafeStack::findInsts(Function &F,
         DynamicAllocas.push_back(AI);
       }
     } else if (auto RI = dyn_cast<ReturnInst>(&I)) {
-      Returns.push_back(RI);
+      if (CallInst *CI = I.getParent()->getTerminatingMustTailCall())
+        Returns.push_back(CI);
+      else
+        Returns.push_back(RI);
     } else if (auto CI = dyn_cast<CallInst>(&I)) {
       // setjmps require stack restore.
       if (CI->getCalledFunction() && CI->canReturnTwice())
@@ -465,7 +466,7 @@ SafeStack::createStackRestorePoints(IRBuilder<> &IRB, Function &F,
   return DynamicTop;
 }
 
-void SafeStack::checkStackGuard(IRBuilder<> &IRB, Function &F, ReturnInst &RI,
+void SafeStack::checkStackGuard(IRBuilder<> &IRB, Function &F, Instruction &RI,
                                 AllocaInst *StackGuardSlot, Value *StackGuard) {
   Value *V = IRB.CreateLoad(StackPtrTy, StackGuardSlot);
   Value *Cmp = IRB.CreateICmpNE(StackGuard, V);
@@ -490,8 +491,8 @@ void SafeStack::checkStackGuard(IRBuilder<> &IRB, Function &F, ReturnInst &RI,
 /// prologue into a local variable and restore it in the epilogue.
 Value *SafeStack::moveStaticAllocasToUnsafeStack(
     IRBuilder<> &IRB, Function &F, ArrayRef<AllocaInst *> StaticAllocas,
-    ArrayRef<Argument *> ByValArguments, ArrayRef<ReturnInst *> Returns,
-    Instruction *BasePointer, AllocaInst *StackGuardSlot) {
+    ArrayRef<Argument *> ByValArguments, Instruction *BasePointer,
+    AllocaInst *StackGuardSlot) {
   if (StaticAllocas.empty() && ByValArguments.empty())
     return BasePointer;
 
@@ -759,7 +760,7 @@ bool SafeStack::run() {
   SmallVector<AllocaInst *, 16> StaticAllocas;
   SmallVector<AllocaInst *, 4> DynamicAllocas;
   SmallVector<Argument *, 4> ByValArguments;
-  SmallVector<ReturnInst *, 4> Returns;
+  SmallVector<Instruction *, 4> Returns;
 
   // Collect all points where stack gets unwound and needs to be restored
   // This is only necessary because the runtime (setjmp and unwind code) is
@@ -812,7 +813,7 @@ bool SafeStack::run() {
     StackGuardSlot = IRB.CreateAlloca(StackPtrTy, nullptr);
     IRB.CreateStore(StackGuard, StackGuardSlot);
 
-    for (ReturnInst *RI : Returns) {
+    for (Instruction *RI : Returns) {
       IRBuilder<> IRBRet(RI);
       checkStackGuard(IRBRet, F, *RI, StackGuardSlot, StackGuard);
     }
@@ -820,9 +821,8 @@ bool SafeStack::run() {
 
   // The top of the unsafe stack after all unsafe static allocas are
   // allocated.
-  Value *StaticTop =
-      moveStaticAllocasToUnsafeStack(IRB, F, StaticAllocas, ByValArguments,
-                                     Returns, BasePointer, StackGuardSlot);
+  Value *StaticTop = moveStaticAllocasToUnsafeStack(
+      IRB, F, StaticAllocas, ByValArguments, BasePointer, StackGuardSlot);
 
   // Safe stack object that stores the current unsafe stack top. It is updated
   // as unsafe dynamic (non-constant-sized) allocas are allocated and freed.
@@ -838,7 +838,7 @@ bool SafeStack::run() {
                                   DynamicAllocas);
 
   // Restore the unsafe stack pointer before each return.
-  for (ReturnInst *RI : Returns) {
+  for (Instruction *RI : Returns) {
     IRB.SetInsertPoint(RI);
     IRB.CreateStore(BasePointer, UnsafeStackPtr);
   }

diff --git a/llvm/test/Transforms/SafeStack/X86/musttail.ll b/llvm/test/Transforms/SafeStack/X86/musttail.ll
new file mode 100644
index 000000000000..d209b2ff2ab4
--- /dev/null
+++ b/llvm/test/Transforms/SafeStack/X86/musttail.ll
@@ -0,0 +1,46 @@
+; To test that safestack does not break the musttail call contract.
+;
+; RUN: opt < %s --safe-stack -S | FileCheck %s
+
+target triple = "x86_64-unknown-linux-gnu"
+
+declare i32 @foo(i32* %p)
+declare void @alloca_test_use([10 x i8]*)
+
+define i32 @call_foo(i32* %a) safestack {
+; CHECK-LABEL: @call_foo(
+; CHECK-NEXT:    [[UNSAFE_STACK_PTR:%.*]] = load i8*, i8** @__safestack_unsafe_stack_ptr, align 8
+; CHECK-NEXT:    [[UNSAFE_STACK_STATIC_TOP:%.*]] = getelementptr i8, i8* [[UNSAFE_STACK_PTR]], i32 -16
+; CHECK-NEXT:    store i8* [[UNSAFE_STACK_STATIC_TOP]], i8** @__safestack_unsafe_stack_ptr, align 8
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr i8, i8* [[UNSAFE_STACK_PTR]], i32 -10
+; CHECK-NEXT:    [[X_UNSAFE:%.*]] = bitcast i8* [[TMP1]] to [10 x i8]*
+; CHECK-NEXT:    call void @alloca_test_use([10 x i8]* [[X_UNSAFE]])
+; CHECK-NEXT:    store i8* [[UNSAFE_STACK_PTR]], i8** @__safestack_unsafe_stack_ptr, align 8
+; CHECK-NEXT:    [[R:%.*]] = musttail call i32 @foo(i32* [[A:%.*]])
+; CHECK-NEXT:    ret i32 [[R]]
+;
+  %x = alloca [10 x i8], align 1
+  call void @alloca_test_use([10 x i8]* %x)
+  %r = musttail call i32 @foo(i32* %a)
+  ret i32 %r
+}
+
+define i32 @call_foo_cast(i32* %a) safestack {
+; CHECK-LABEL: @call_foo_cast(
+; CHECK-NEXT:    [[UNSAFE_STACK_PTR:%.*]] = load i8*, i8** @__safestack_unsafe_stack_ptr, align 8
+; CHECK-NEXT:    [[UNSAFE_STACK_STATIC_TOP:%.*]] = getelementptr i8, i8* [[UNSAFE_STACK_PTR]], i32 -16
+; CHECK-NEXT:    store i8* [[UNSAFE_STACK_STATIC_TOP]], i8** @__safestack_unsafe_stack_ptr, align 8
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr i8, i8* [[UNSAFE_STACK_PTR]], i32 -10
+; CHECK-NEXT:    [[X_UNSAFE:%.*]] = bitcast i8* [[TMP1]] to [10 x i8]*
+; CHECK-NEXT:    call void @alloca_test_use([10 x i8]* [[X_UNSAFE]])
+; CHECK-NEXT:    store i8* [[UNSAFE_STACK_PTR]], i8** @__safestack_unsafe_stack_ptr, align 8
+; CHECK-NEXT:    [[R:%.*]] = musttail call i32 @foo(i32* [[A:%.*]])
+; CHECK-NEXT:    [[T:%.*]] = bitcast i32 [[R]] to i32
+; CHECK-NEXT:    ret i32 [[T]]
+;
+  %x = alloca [10 x i8], align 1
+  call void @alloca_test_use([10 x i8]* %x)
+  %r = musttail call i32 @foo(i32* %a)
+  %t = bitcast i32 %r to i32
+  ret i32 %t
+}


        

