[llvm] 6c06d8e - [stack-safety] Check SCEV constraints at memory instructions.

Florian Mayer via llvm-commits llvm-commits at lists.llvm.org
Tue Nov 23 15:29:33 PST 2021


Author: Florian Mayer
Date: 2021-11-23T15:29:23-08:00
New Revision: 6c06d8e310bd926f8c9ed63118c38b28075f4de3

URL: https://github.com/llvm/llvm-project/commit/6c06d8e310bd926f8c9ed63118c38b28075f4de3
DIFF: https://github.com/llvm/llvm-project/commit/6c06d8e310bd926f8c9ed63118c38b28075f4de3.diff

LOG: [stack-safety] Check SCEV constraints at memory instructions.

Reviewed By: vitalybuka

Differential Revision: https://reviews.llvm.org/D113160

Added: 
    

Modified: 
    llvm/lib/Analysis/StackSafetyAnalysis.cpp
    llvm/test/Analysis/StackSafetyAnalysis/local.ll
    llvm/test/Analysis/StackSafetyAnalysis/memintrin.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Analysis/StackSafetyAnalysis.cpp b/llvm/lib/Analysis/StackSafetyAnalysis.cpp
index 74cc39b7f2c0d..54f3605ee0333 100644
--- a/llvm/lib/Analysis/StackSafetyAnalysis.cpp
+++ b/llvm/lib/Analysis/StackSafetyAnalysis.cpp
@@ -14,12 +14,14 @@
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/Statistic.h"
 #include "llvm/Analysis/ModuleSummaryAnalysis.h"
+#include "llvm/Analysis/ScalarEvolution.h"
 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
 #include "llvm/Analysis/StackLifetime.h"
 #include "llvm/IR/ConstantRange.h"
 #include "llvm/IR/DerivedTypes.h"
 #include "llvm/IR/GlobalValue.h"
 #include "llvm/IR/InstIterator.h"
+#include "llvm/IR/Instruction.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/ModuleSummaryIndex.h"
@@ -117,7 +119,7 @@ template <typename CalleeTy> struct UseInfo {
   // Access range if the address (alloca or parameters).
   // It is allowed to be empty-set when there are no known accesses.
   ConstantRange Range;
-  std::map<const Instruction *, ConstantRange> Accesses;
+  std::set<const Instruction *> UnsafeAccesses;
 
   // List of calls which pass address as an argument.
   // Value is offset range of address from base address (alloca or calling
@@ -131,10 +133,9 @@ template <typename CalleeTy> struct UseInfo {
   UseInfo(unsigned PointerSize) : Range{PointerSize, false} {}
 
   void updateRange(const ConstantRange &R) { Range = unionNoWrap(Range, R); }
-  void addRange(const Instruction *I, const ConstantRange &R) {
-    auto Ins = Accesses.emplace(I, R);
-    if (!Ins.second)
-      Ins.first->second = unionNoWrap(Ins.first->second, R);
+  void addRange(const Instruction *I, const ConstantRange &R, bool IsSafe) {
+    if (!IsSafe)
+      UnsafeAccesses.insert(I);
     updateRange(R);
   }
 };
@@ -230,7 +231,7 @@ struct StackSafetyInfo::InfoTy {
 struct StackSafetyGlobalInfo::InfoTy {
   GVToSSI Info;
   SmallPtrSet<const AllocaInst *, 8> SafeAllocas;
-  std::map<const Instruction *, bool> AccessIsUnsafe;
+  std::set<const Instruction *> UnsafeAccesses;
 };
 
 namespace {
@@ -253,6 +254,11 @@ class StackSafetyLocalAnalysis {
   void analyzeAllUses(Value *Ptr, UseInfo<GlobalValue> &AS,
                       const StackLifetime &SL);
 
+
+  bool isSafeAccess(const Use &U, AllocaInst *AI, const SCEV *AccessSize);
+  bool isSafeAccess(const Use &U, AllocaInst *AI, Value *V);
+  bool isSafeAccess(const Use &U, AllocaInst *AI, TypeSize AccessSize);
+
 public:
   StackSafetyLocalAnalysis(Function &F, ScalarEvolution &SE)
       : F(F), DL(F.getParent()->getDataLayout()), SE(SE),
@@ -333,6 +339,56 @@ ConstantRange StackSafetyLocalAnalysis::getMemIntrinsicAccessRange(
   return getAccessRange(U, Base, SizeRange);
 }
 
+/// Checks an access whose byte length is given by the IR value \p V (e.g. the
+/// length operand of a memory intrinsic) by forwarding its SCEV.
+bool StackSafetyLocalAnalysis::isSafeAccess(const Use &U, AllocaInst *AI,
+                                            Value *V) {
+  return isSafeAccess(U, AI, SE.getSCEV(V));
+}
+
+/// Checks an access of fixed byte size \p TS (e.g. a load/store type size).
+/// Scalable sizes are not modelled and are reported as unsafe.
+bool StackSafetyLocalAnalysis::isSafeAccess(const Use &U, AllocaInst *AI,
+                                            TypeSize TS) {
+  if (TS.isScalable())
+    return false;
+  auto *CalculationTy = IntegerType::getIntNTy(SE.getContext(), PointerSize);
+  const SCEV *SV = SE.getConstant(CalculationTy, TS.getFixedSize());
+  return isSafeAccess(U, AI, SV);
+}
+
+/// Returns true if the access performed by the user of \p U, starting at the
+/// address U.get() and spanning \p AccessSize bytes, can be proven (at the
+/// program point of that user) to lie within the static size of \p AI.
+///
+/// When \p AI is null the analysed address is not an alloca (e.g. a function
+/// argument) and this check does not judge it, so it returns true. Whenever
+/// safety cannot be proven, it returns false, i.e. "treat as unsafe".
+bool StackSafetyLocalAnalysis::isSafeAccess(const Use &U, AllocaInst *AI,
+                                            const SCEV *AccessSize) {
+  if (!AI)
+    return true;
+  if (isa<SCEVCouldNotCompute>(AccessSize))
+    return false;
+
+  const auto *I = cast<Instruction>(U.getUser());
+
+  // Express V as an address-space-0 i8* SCEV so two addresses can be
+  // subtracted. Non-pointer values (e.g. the result of a ptrtoint) are
+  // extended/truncated to that width. Pointers in another address space may
+  // differ in width from address-space-0 pointers, and ScalarEvolution cannot
+  // extend or truncate pointer-typed expressions, so such pointers are not
+  // modelled: nullptr makes the caller report the access as unsafe.
+  auto ToCharPtr = [&](Value *V) -> const SCEV * {
+    Type *Ty = V->getType();
+    if (Ty->isPointerTy() && Ty->getPointerAddressSpace() != 0)
+      return nullptr;
+    auto *PtrTy = IntegerType::getInt8PtrTy(SE.getContext());
+    return SE.getTruncateOrZeroExtend(SE.getSCEV(V), PtrTy);
+  };
+
+  const SCEV *AddrExp = ToCharPtr(U.get());
+  const SCEV *BaseExp = ToCharPtr(AI);
+  if (!AddrExp || !BaseExp)
+    return false;
+
+  // Byte offset of the accessed address from the start of the alloca.
+  const SCEV *Diff = SE.getMinusSCEV(AddrExp, BaseExp);
+  if (isa<SCEVCouldNotCompute>(Diff))
+    return false;
+
+  auto Size = getStaticAllocaSizeRange(*AI);
+
+  auto *CalculationTy = IntegerType::getIntNTy(SE.getContext(), PointerSize);
+  auto ToDiffTy = [&](const SCEV *V) {
+    return SE.getTruncateOrZeroExtend(V, CalculationTy);
+  };
+  const SCEV *Len = ToDiffTy(AccessSize);
+  const SCEV *Min = ToDiffTy(SE.getConstant(Size.getLower()));
+  const SCEV *Max =
+      SE.getMinusSCEV(ToDiffTy(SE.getConstant(Size.getUpper())), Len);
+
+  // True only if SCEV can prove LHS Pred RHS at the accessing instruction.
+  auto Holds = [&](ICmpInst::Predicate Pred, const SCEV *LHS,
+                   const SCEV *RHS) {
+    return SE.evaluatePredicateAt(Pred, LHS, RHS, I).getValueOr(false);
+  };
+
+  // The length must be provably non-negative. A "negative" length is a huge
+  // unsigned byte count, yet Upper - Len wraps to a small positive bound
+  // (e.g. a length of -5 gives Max = Upper + 5), which would otherwise let
+  // the offset check below accept an access far outside the alloca.
+  return Holds(ICmpInst::Predicate::ICMP_SGE, Len,
+               SE.getZero(CalculationTy)) &&
+         Holds(ICmpInst::Predicate::ICMP_SGE, Diff, Min) &&
+         Holds(ICmpInst::Predicate::ICMP_SLE, Diff, Max);
+}
+
 /// The function analyzes all local uses of Ptr (alloca or argument) and
 /// calculates local access range and all function calls where it was used.
 void StackSafetyLocalAnalysis::analyzeAllUses(Value *Ptr,
@@ -341,7 +397,7 @@ void StackSafetyLocalAnalysis::analyzeAllUses(Value *Ptr,
   SmallPtrSet<const Value *, 16> Visited;
   SmallVector<const Value *, 8> WorkList;
   WorkList.push_back(Ptr);
-  const AllocaInst *AI = dyn_cast<AllocaInst>(Ptr);
+  AllocaInst *AI = dyn_cast<AllocaInst>(Ptr);
 
   // A DFS search through all uses of the alloca in bitcasts/PHI/GEPs/etc.
   while (!WorkList.empty()) {
@@ -356,11 +412,13 @@ void StackSafetyLocalAnalysis::analyzeAllUses(Value *Ptr,
       switch (I->getOpcode()) {
       case Instruction::Load: {
         if (AI && !SL.isAliveAfter(AI, I)) {
-          US.addRange(I, UnknownRange);
+          US.addRange(I, UnknownRange, /*IsSafe=*/false);
           break;
         }
-        US.addRange(I,
-                    getAccessRange(UI, Ptr, DL.getTypeStoreSize(I->getType())));
+        auto TypeSize = DL.getTypeStoreSize(I->getType());
+        auto AccessRange = getAccessRange(UI, Ptr, TypeSize);
+        bool Safe = isSafeAccess(UI, AI, TypeSize);
+        US.addRange(I, AccessRange, Safe);
         break;
       }
 
@@ -370,16 +428,17 @@ void StackSafetyLocalAnalysis::analyzeAllUses(Value *Ptr,
       case Instruction::Store: {
         if (V == I->getOperand(0)) {
           // Stored the pointer - conservatively assume it may be unsafe.
-          US.addRange(I, UnknownRange);
+          US.addRange(I, UnknownRange, /*IsSafe=*/false);
           break;
         }
         if (AI && !SL.isAliveAfter(AI, I)) {
-          US.addRange(I, UnknownRange);
+          US.addRange(I, UnknownRange, /*IsSafe=*/false);
           break;
         }
-        US.addRange(
-            I, getAccessRange(
-                   UI, Ptr, DL.getTypeStoreSize(I->getOperand(0)->getType())));
+        auto TypeSize = DL.getTypeStoreSize(I->getOperand(0)->getType());
+        auto AccessRange = getAccessRange(UI, Ptr, TypeSize);
+        bool Safe = isSafeAccess(UI, AI, TypeSize);
+        US.addRange(I, AccessRange, Safe);
         break;
       }
 
@@ -387,7 +446,7 @@ void StackSafetyLocalAnalysis::analyzeAllUses(Value *Ptr,
         // Information leak.
         // FIXME: Process parameters correctly. This is a leak only if we return
         // alloca.
-        US.addRange(I, UnknownRange);
+        US.addRange(I, UnknownRange, /*IsSafe=*/false);
         break;
 
       case Instruction::Call:
@@ -396,12 +455,20 @@ void StackSafetyLocalAnalysis::analyzeAllUses(Value *Ptr,
           break;
 
         if (AI && !SL.isAliveAfter(AI, I)) {
-          US.addRange(I, UnknownRange);
+          US.addRange(I, UnknownRange, /*IsSafe=*/false);
           break;
         }
-
         if (const MemIntrinsic *MI = dyn_cast<MemIntrinsic>(I)) {
-          US.addRange(I, getMemIntrinsicAccessRange(MI, UI, Ptr));
+          auto AccessRange = getMemIntrinsicAccessRange(MI, UI, Ptr);
+          bool Safe = false;
+          if (const auto *MTI = dyn_cast<MemTransferInst>(MI)) {
+            if (MTI->getRawSource() != UI && MTI->getRawDest() != UI)
+              Safe = true;
+          } else if (MI->getRawDest() != UI) {
+            Safe = true;
+          }
+          Safe = Safe || isSafeAccess(UI, AI, MI->getLength());
+          US.addRange(I, AccessRange, Safe);
           break;
         }
 
@@ -412,15 +479,16 @@ void StackSafetyLocalAnalysis::analyzeAllUses(Value *Ptr,
         }
 
         if (!CB.isArgOperand(&UI)) {
-          US.addRange(I, UnknownRange);
+          US.addRange(I, UnknownRange, /*IsSafe=*/false);
           break;
         }
 
         unsigned ArgNo = CB.getArgOperandNo(&UI);
         if (CB.isByValArgument(ArgNo)) {
-          US.addRange(I, getAccessRange(
-                             UI, Ptr,
-                             DL.getTypeStoreSize(CB.getParamByValType(ArgNo))));
+          auto TypeSize = DL.getTypeStoreSize(CB.getParamByValType(ArgNo));
+          auto AccessRange = getAccessRange(UI, Ptr, TypeSize);
+          bool Safe = isSafeAccess(UI, AI, TypeSize);
+          US.addRange(I, AccessRange, Safe);
           break;
         }
 
@@ -430,7 +498,7 @@ void StackSafetyLocalAnalysis::analyzeAllUses(Value *Ptr,
         const GlobalValue *Callee =
             dyn_cast<GlobalValue>(CB.getCalledOperand()->stripPointerCasts());
         if (!Callee) {
-          US.addRange(I, UnknownRange);
+          US.addRange(I, UnknownRange, /*IsSafe=*/false);
           break;
         }
 
@@ -827,8 +895,8 @@ const StackSafetyGlobalInfo::InfoTy &StackSafetyGlobalInfo::getInfo() const {
           Info->SafeAllocas.insert(AI);
           ++NumAllocaStackSafe;
         }
-        for (const auto &A : KV.second.Accesses)
-          Info->AccessIsUnsafe[A.first] |= !AIRange.contains(A.second);
+        Info->UnsafeAccesses.insert(KV.second.UnsafeAccesses.begin(),
+                                    KV.second.UnsafeAccesses.end());
       }
     }
 
@@ -903,11 +971,7 @@ bool StackSafetyGlobalInfo::isSafe(const AllocaInst &AI) const {
 
 bool StackSafetyGlobalInfo::stackAccessIsSafe(const Instruction &I) const {
   const auto &Info = getInfo();
-  auto It = Info.AccessIsUnsafe.find(&I);
-  if (It == Info.AccessIsUnsafe.end()) {
-    return true;
-  }
-  return !It->second;
+  return Info.UnsafeAccesses.find(&I) == Info.UnsafeAccesses.end();
 }
 
 void StackSafetyGlobalInfo::print(raw_ostream &O) const {

diff  --git a/llvm/test/Analysis/StackSafetyAnalysis/local.ll b/llvm/test/Analysis/StackSafetyAnalysis/local.ll
index f764fe3e84098..24ffc1b69574c 100644
--- a/llvm/test/Analysis/StackSafetyAnalysis/local.ll
+++ b/llvm/test/Analysis/StackSafetyAnalysis/local.ll
@@ -44,6 +44,53 @@ entry:
   ret void
 }
 
+define void @StoreInBoundsCond(i64 %i) {
+; CHECK-LABEL: @StoreInBoundsCond dso_preemptable{{$}}
+; CHECK-NEXT: args uses:
+; CHECK-NEXT: allocas uses:
+; CHECK-NEXT: x[4]: full-set{{$}}
+; GLOBAL-NEXT: safe accesses:
+; GLOBAL-NEXT: store i8 0, i8* %x2, align 1
+; CHECK-EMPTY:
+entry:
+  %x = alloca i32, align 4
+  %x1 = bitcast i32* %x to i8*
+  %c1 = icmp sge i64 %i, 0
+  %c2 = icmp slt i64 %i, 4
+  br i1 %c1, label %c1.true, label %false
+
+c1.true:
+  br i1 %c2, label %c2.true, label %false
+
+c2.true:
+  %x2 = getelementptr i8, i8* %x1, i64 %i
+  store i8 0, i8* %x2, align 1
+  br label %false
+
+false:
+  ret void
+}
+
+define void @StoreInBoundsMinMax(i64 %i) {
+; CHECK-LABEL: @StoreInBoundsMinMax dso_preemptable{{$}}
+; CHECK-NEXT: args uses:
+; CHECK-NEXT: allocas uses:
+; CHECK-NEXT: x[4]: [0,4){{$}}
+; GLOBAL-NEXT: safe accesses:
+; GLOBAL-NEXT: store i8 0, i8* %x2, align 1
+; CHECK-EMPTY:
+entry:
+  %x = alloca i32, align 4
+  %x1 = bitcast i32* %x to i8*
+  %c1 = icmp sge i64 %i, 0
+  %i1 = select i1 %c1, i64 %i, i64 0
+  %c2 = icmp slt i64 %i1, 3
+  %i2 = select i1 %c2, i64 %i1, i64 3
+  %x2 = getelementptr i8, i8* %x1, i64 %i2
+  store i8 0, i8* %x2, align 1
+  ret void
+}
+
 define void @StoreInBounds2() {
 ; CHECK-LABEL: @StoreInBounds2 dso_preemptable{{$}}
 ; CHECK-NEXT: args uses:
@@ -157,6 +204,54 @@ entry:
   ret void
 }
 
+define void @StoreOutOfBoundsCond(i64 %i) {
+; CHECK-LABEL: @StoreOutOfBoundsCond dso_preemptable{{$}}
+; CHECK-NEXT: args uses:
+; CHECK-NEXT: allocas uses:
+; CHECK-NEXT: x[4]: full-set{{$}}
+; GLOBAL-NEXT: safe accesses:
+; CHECK-EMPTY:
+entry:
+  %x = alloca i32, align 4
+  %x1 = bitcast i32* %x to i8*
+  %c1 = icmp sge i64 %i, 0
+  %c2 = icmp slt i64 %i, 5
+  br i1 %c1, label %c1.true, label %false
+
+c1.true:
+  br i1 %c2, label %c2.true, label %false
+
+c2.true:
+  %x2 = getelementptr i8, i8* %x1, i64 %i
+  store i8 0, i8* %x2, align 1
+  br label %false
+
+false:
+  ret void
+}
+
+define void @StoreOutOfBoundsCond2(i64 %i) {
+; CHECK-LABEL: @StoreOutOfBoundsCond2 dso_preemptable{{$}}
+; CHECK-NEXT: args uses:
+; CHECK-NEXT: allocas uses:
+; CHECK-NEXT: x[4]: full-set{{$}}
+; GLOBAL-NEXT: safe accesses:
+; CHECK-EMPTY:
+entry:
+  %x = alloca i32, align 4
+  %x1 = bitcast i32* %x to i8*
+  %c2 = icmp slt i64 %i, 5
+  br i1 %c2, label %c2.true, label %false
+
+c2.true:
+  %x2 = getelementptr i8, i8* %x1, i64 %i
+  store i8 0, i8* %x2, align 1
+  br label %false
+
+false:
+  ret void
+}
+
 define void @StoreOutOfBounds2() {
 ; CHECK-LABEL: @StoreOutOfBounds2 dso_preemptable{{$}}
 ; CHECK-NEXT: args uses:

diff  --git a/llvm/test/Analysis/StackSafetyAnalysis/memintrin.ll b/llvm/test/Analysis/StackSafetyAnalysis/memintrin.ll
index a19e3ed46efb0..9e1c49cb90be1 100644
--- a/llvm/test/Analysis/StackSafetyAnalysis/memintrin.ll
+++ b/llvm/test/Analysis/StackSafetyAnalysis/memintrin.ll
@@ -233,3 +233,42 @@ entry:
   call void @llvm.memmove.p0i8.p0i8.i32(i8* %x1, i8* %x2, i32 9, i1 false)
   ret void
 }
+
+define void @MemsetInBoundsCast() {
+; CHECK-LABEL: MemsetInBoundsCast dso_preemptable{{$}}
+; CHECK-NEXT: args uses:
+; CHECK-NEXT: allocas uses:
+; CHECK-NEXT: x[4]: [0,4){{$}}
+; CHECK-NEXT: y[1]: empty-set{{$}}
+; GLOBAL-NEXT: safe accesses:
+; GLOBAL-NEXT: call void @llvm.memset.p0i8.i32(i8* %x1, i8 %yint, i32 4, i1 false)
+; CHECK-EMPTY:
+entry:
+  %x = alloca i32, align 4
+  %y = alloca i8, align 1
+  %x1 = bitcast i32* %x to i8*
+  %yint = ptrtoint i8* %y to i8
+  call void @llvm.memset.p0i8.i32(i8* %x1, i8 %yint, i32 4, i1 false)
+  ret void
+}
+
+define void @MemcpyInBoundsCast2(i8 %zint8) {
+; CHECK-LABEL: MemcpyInBoundsCast2 dso_preemptable{{$}}
+; CHECK-NEXT: args uses:
+; CHECK-NEXT: allocas uses:
+; CHECK-NEXT: x[256]: [0,255){{$}}
+; CHECK-NEXT: y[256]: [0,255){{$}}
+; CHECK-NEXT: z[1]: empty-set{{$}}
+; GLOBAL-NEXT: safe accesses:
+; GLOBAL-NEXT: call void @llvm.memcpy.p0i8.p0i8.i32(i8* %x1, i8* %y1, i32 %zint32, i1 false)
+; CHECK-EMPTY:
+entry:
+  %x = alloca [256 x i8], align 4
+  %y = alloca [256 x i8], align 4
+  %z = alloca i8, align 1
+  %x1 = bitcast [256 x i8]* %x to i8*
+  %y1 = bitcast [256 x i8]* %y to i8*
+  %zint32 = zext i8 %zint8 to i32
+  call void @llvm.memcpy.p0i8.p0i8.i32(i8* %x1, i8* %y1, i32 %zint32, i1 false)
+  ret void
+}


        


More information about the llvm-commits mailing list