[llvm] 6c06d8e - [stack-safety] Check SCEV constraints at memory instructions.
Florian Mayer via llvm-commits
llvm-commits at lists.llvm.org
Tue Nov 23 15:29:33 PST 2021
Author: Florian Mayer
Date: 2021-11-23T15:29:23-08:00
New Revision: 6c06d8e310bd926f8c9ed63118c38b28075f4de3
URL: https://github.com/llvm/llvm-project/commit/6c06d8e310bd926f8c9ed63118c38b28075f4de3
DIFF: https://github.com/llvm/llvm-project/commit/6c06d8e310bd926f8c9ed63118c38b28075f4de3.diff
LOG: [stack-safety] Check SCEV constraints at memory instructions.
Reviewed By: vitalybuka
Differential Revision: https://reviews.llvm.org/D113160
Added:
Modified:
llvm/lib/Analysis/StackSafetyAnalysis.cpp
llvm/test/Analysis/StackSafetyAnalysis/local.ll
llvm/test/Analysis/StackSafetyAnalysis/memintrin.ll
Removed:
################################################################################
diff --git a/llvm/lib/Analysis/StackSafetyAnalysis.cpp b/llvm/lib/Analysis/StackSafetyAnalysis.cpp
index 74cc39b7f2c0d..54f3605ee0333 100644
--- a/llvm/lib/Analysis/StackSafetyAnalysis.cpp
+++ b/llvm/lib/Analysis/StackSafetyAnalysis.cpp
@@ -14,12 +14,14 @@
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/ModuleSummaryAnalysis.h"
+#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/Analysis/StackLifetime.h"
#include "llvm/IR/ConstantRange.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/GlobalValue.h"
#include "llvm/IR/InstIterator.h"
+#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/ModuleSummaryIndex.h"
@@ -117,7 +119,7 @@ template <typename CalleeTy> struct UseInfo {
// Access range if the address (alloca or parameters).
// It is allowed to be empty-set when there are no known accesses.
ConstantRange Range;
- std::map<const Instruction *, ConstantRange> Accesses;
+ std::set<const Instruction *> UnsafeAccesses;
// List of calls which pass address as an argument.
// Value is offset range of address from base address (alloca or calling
@@ -131,10 +133,9 @@ template <typename CalleeTy> struct UseInfo {
UseInfo(unsigned PointerSize) : Range{PointerSize, false} {}
void updateRange(const ConstantRange &R) { Range = unionNoWrap(Range, R); }
- void addRange(const Instruction *I, const ConstantRange &R) {
- auto Ins = Accesses.emplace(I, R);
- if (!Ins.second)
- Ins.first->second = unionNoWrap(Ins.first->second, R);
+ void addRange(const Instruction *I, const ConstantRange &R, bool IsSafe) {
+ if (!IsSafe)
+ UnsafeAccesses.insert(I);
updateRange(R);
}
};
@@ -230,7 +231,7 @@ struct StackSafetyInfo::InfoTy {
struct StackSafetyGlobalInfo::InfoTy {
GVToSSI Info;
SmallPtrSet<const AllocaInst *, 8> SafeAllocas;
- std::map<const Instruction *, bool> AccessIsUnsafe;
+ std::set<const Instruction *> UnsafeAccesses;
};
namespace {
@@ -253,6 +254,11 @@ class StackSafetyLocalAnalysis {
void analyzeAllUses(Value *Ptr, UseInfo<GlobalValue> &AS,
const StackLifetime &SL);
+
+ bool isSafeAccess(const Use &U, AllocaInst *AI, const SCEV *AccessSize);
+ bool isSafeAccess(const Use &U, AllocaInst *AI, Value *V);
+ bool isSafeAccess(const Use &U, AllocaInst *AI, TypeSize AccessSize);
+
public:
StackSafetyLocalAnalysis(Function &F, ScalarEvolution &SE)
: F(F), DL(F.getParent()->getDataLayout()), SE(SE),
@@ -333,6 +339,56 @@ ConstantRange StackSafetyLocalAnalysis::getMemIntrinsicAccessRange(
return getAccessRange(U, Base, SizeRange);
}
+/// Checks an access whose size is given by the IR value \p V: the value is
+/// turned into a SCEV and the SCEV-based overload does the actual check.
+bool StackSafetyLocalAnalysis::isSafeAccess(const Use &U, AllocaInst *AI,
+                                            Value *V) {
+  const SCEV *SizeExpr = SE.getSCEV(V);
+  return isSafeAccess(U, AI, SizeExpr);
+}
+
+/// Checks an access of a statically typed size. Scalable sizes are not
+/// handled, so such accesses are never reported safe; fixed sizes become a
+/// PointerSize-bit SCEV constant for the SCEV-based overload.
+bool StackSafetyLocalAnalysis::isSafeAccess(const Use &U, AllocaInst *AI,
+                                            TypeSize TS) {
+  if (!TS.isScalable()) {
+    auto *SizeTy = IntegerType::getIntNTy(SE.getContext(), PointerSize);
+    return isSafeAccess(U, AI, SE.getConstant(SizeTy, TS.getFixedSize()));
+  }
+  return false;
+}
+
+/// Returns true if the access made through use \p U, of \p AccessSize bytes,
+/// can be proven to stay within the static size of alloca \p AI at the
+/// accessing instruction. Returns true when \p AI is null: only accesses
+/// based on an alloca are judged here.
+bool StackSafetyLocalAnalysis::isSafeAccess(const Use &U, AllocaInst *AI,
+                                            const SCEV *AccessSize) {
+  if (!AI)
+    return true; // Not an alloca-based access; nothing to judge here.
+  if (isa<SCEVCouldNotCompute>(AccessSize))
+    return false;
+
+  const auto *I = cast<Instruction>(U.getUser());
+
+  auto ToCharPtr = [&](const SCEV *V) {
+    auto *PtrTy = IntegerType::getInt8PtrTy(SE.getContext());
+    return SE.getTruncateOrZeroExtend(V, PtrTy);
+  };
+
+  // Byte offset of the accessed address from the start of the alloca.
+  const SCEV *AddrExp = ToCharPtr(SE.getSCEV(U.get()));
+  const SCEV *BaseExp = ToCharPtr(SE.getSCEV(AI));
+  const SCEV *Diff = SE.getMinusSCEV(AddrExp, BaseExp);
+  if (isa<SCEVCouldNotCompute>(Diff))
+    return false;
+
+  auto Size = getStaticAllocaSizeRange(*AI);
+
+  auto *CalculationTy = IntegerType::getIntNTy(SE.getContext(), PointerSize);
+  auto ToDiffTy = [&](const SCEV *V) {
+    return SE.getTruncateOrZeroExtend(V, CalculationTy);
+  };
+  const SCEV *Len = ToDiffTy(AccessSize);
+  // The bound checks below use signed predicates. A length that may have its
+  // sign bit set (e.g. an i64 memset length of -1, i.e. a huge unsigned size)
+  // would make Max = Upper - Len exceed the alloca size and let the access
+  // pass, so the length must first be proven non-negative at I.
+  if (!SE.evaluatePredicateAt(ICmpInst::Predicate::ICMP_SGE, Len,
+                              SE.getZero(CalculationTy), I)
+           .getValueOr(false))
+    return false;
+  // Safe iff Lower <= Diff <= Upper - Len, i.e. the accessed bytes
+  // [Diff, Diff + Len) lie within [Lower, Upper).
+  const SCEV *Min = ToDiffTy(SE.getConstant(Size.getLower()));
+  const SCEV *Max =
+      SE.getMinusSCEV(ToDiffTy(SE.getConstant(Size.getUpper())), Len);
+  return SE.evaluatePredicateAt(ICmpInst::Predicate::ICMP_SGE, Diff, Min, I)
+             .getValueOr(false) &&
+         SE.evaluatePredicateAt(ICmpInst::Predicate::ICMP_SLE, Diff, Max, I)
+             .getValueOr(false);
+}
+
/// The function analyzes all local uses of Ptr (alloca or argument) and
/// calculates local access range and all function calls where it was used.
void StackSafetyLocalAnalysis::analyzeAllUses(Value *Ptr,
@@ -341,7 +397,7 @@ void StackSafetyLocalAnalysis::analyzeAllUses(Value *Ptr,
SmallPtrSet<const Value *, 16> Visited;
SmallVector<const Value *, 8> WorkList;
WorkList.push_back(Ptr);
- const AllocaInst *AI = dyn_cast<AllocaInst>(Ptr);
+ AllocaInst *AI = dyn_cast<AllocaInst>(Ptr);
// A DFS search through all uses of the alloca in bitcasts/PHI/GEPs/etc.
while (!WorkList.empty()) {
@@ -356,11 +412,13 @@ void StackSafetyLocalAnalysis::analyzeAllUses(Value *Ptr,
switch (I->getOpcode()) {
case Instruction::Load: {
if (AI && !SL.isAliveAfter(AI, I)) {
- US.addRange(I, UnknownRange);
+ US.addRange(I, UnknownRange, /*IsSafe=*/false);
break;
}
- US.addRange(I,
- getAccessRange(UI, Ptr, DL.getTypeStoreSize(I->getType())));
+ auto TypeSize = DL.getTypeStoreSize(I->getType());
+ auto AccessRange = getAccessRange(UI, Ptr, TypeSize);
+ bool Safe = isSafeAccess(UI, AI, TypeSize);
+ US.addRange(I, AccessRange, Safe);
break;
}
@@ -370,16 +428,17 @@ void StackSafetyLocalAnalysis::analyzeAllUses(Value *Ptr,
case Instruction::Store: {
if (V == I->getOperand(0)) {
// Stored the pointer - conservatively assume it may be unsafe.
- US.addRange(I, UnknownRange);
+ US.addRange(I, UnknownRange, /*IsSafe=*/false);
break;
}
if (AI && !SL.isAliveAfter(AI, I)) {
- US.addRange(I, UnknownRange);
+ US.addRange(I, UnknownRange, /*IsSafe=*/false);
break;
}
- US.addRange(
- I, getAccessRange(
- UI, Ptr, DL.getTypeStoreSize(I->getOperand(0)->getType())));
+ auto TypeSize = DL.getTypeStoreSize(I->getOperand(0)->getType());
+ auto AccessRange = getAccessRange(UI, Ptr, TypeSize);
+ bool Safe = isSafeAccess(UI, AI, TypeSize);
+ US.addRange(I, AccessRange, Safe);
break;
}
@@ -387,7 +446,7 @@ void StackSafetyLocalAnalysis::analyzeAllUses(Value *Ptr,
// Information leak.
// FIXME: Process parameters correctly. This is a leak only if we return
// alloca.
- US.addRange(I, UnknownRange);
+ US.addRange(I, UnknownRange, /*IsSafe=*/false);
break;
case Instruction::Call:
@@ -396,12 +455,20 @@ void StackSafetyLocalAnalysis::analyzeAllUses(Value *Ptr,
break;
if (AI && !SL.isAliveAfter(AI, I)) {
- US.addRange(I, UnknownRange);
+ US.addRange(I, UnknownRange, /*IsSafe=*/false);
break;
}
-
if (const MemIntrinsic *MI = dyn_cast<MemIntrinsic>(I)) {
- US.addRange(I, getMemIntrinsicAccessRange(MI, UI, Ptr));
+ auto AccessRange = getMemIntrinsicAccessRange(MI, UI, Ptr);
+ bool Safe = false;
+ if (const auto *MTI = dyn_cast<MemTransferInst>(MI)) {
+ if (MTI->getRawSource() != UI && MTI->getRawDest() != UI)
+ Safe = true;
+ } else if (MI->getRawDest() != UI) {
+ Safe = true;
+ }
+ Safe = Safe || isSafeAccess(UI, AI, MI->getLength());
+ US.addRange(I, AccessRange, Safe);
break;
}
@@ -412,15 +479,16 @@ void StackSafetyLocalAnalysis::analyzeAllUses(Value *Ptr,
}
if (!CB.isArgOperand(&UI)) {
- US.addRange(I, UnknownRange);
+ US.addRange(I, UnknownRange, /*IsSafe=*/false);
break;
}
unsigned ArgNo = CB.getArgOperandNo(&UI);
if (CB.isByValArgument(ArgNo)) {
- US.addRange(I, getAccessRange(
- UI, Ptr,
- DL.getTypeStoreSize(CB.getParamByValType(ArgNo))));
+ auto TypeSize = DL.getTypeStoreSize(CB.getParamByValType(ArgNo));
+ auto AccessRange = getAccessRange(UI, Ptr, TypeSize);
+ bool Safe = isSafeAccess(UI, AI, TypeSize);
+ US.addRange(I, AccessRange, Safe);
break;
}
@@ -430,7 +498,7 @@ void StackSafetyLocalAnalysis::analyzeAllUses(Value *Ptr,
const GlobalValue *Callee =
dyn_cast<GlobalValue>(CB.getCalledOperand()->stripPointerCasts());
if (!Callee) {
- US.addRange(I, UnknownRange);
+ US.addRange(I, UnknownRange, /*IsSafe=*/false);
break;
}
@@ -827,8 +895,8 @@ const StackSafetyGlobalInfo::InfoTy &StackSafetyGlobalInfo::getInfo() const {
Info->SafeAllocas.insert(AI);
++NumAllocaStackSafe;
}
- for (const auto &A : KV.second.Accesses)
- Info->AccessIsUnsafe[A.first] |= !AIRange.contains(A.second);
+ Info->UnsafeAccesses.insert(KV.second.UnsafeAccesses.begin(),
+ KV.second.UnsafeAccesses.end());
}
}
@@ -903,11 +971,7 @@ bool StackSafetyGlobalInfo::isSafe(const AllocaInst &AI) const {
bool StackSafetyGlobalInfo::stackAccessIsSafe(const Instruction &I) const {
const auto &Info = getInfo();
- auto It = Info.AccessIsUnsafe.find(&I);
- if (It == Info.AccessIsUnsafe.end()) {
- return true;
- }
- return !It->second;
+ return Info.UnsafeAccesses.find(&I) == Info.UnsafeAccesses.end();
}
void StackSafetyGlobalInfo::print(raw_ostream &O) const {
diff --git a/llvm/test/Analysis/StackSafetyAnalysis/local.ll b/llvm/test/Analysis/StackSafetyAnalysis/local.ll
index f764fe3e84098..24ffc1b69574c 100644
--- a/llvm/test/Analysis/StackSafetyAnalysis/local.ll
+++ b/llvm/test/Analysis/StackSafetyAnalysis/local.ll
@@ -44,6 +44,53 @@ entry:
ret void
}
+; The store is only reached when 0 <= %i < 4. The alloca range stays
+; full-set, but the store itself is expected in the safe accesses list
+; because the per-access check is evaluated at the store.
+define void @StoreInBoundsCond(i64 %i) {
+; CHECK-LABEL: @StoreInBoundsCond dso_preemptable{{$}}
+; CHECK-NEXT: args uses:
+; CHECK-NEXT: allocas uses:
+; CHECK-NEXT: x[4]: full-set{{$}}
+; GLOBAL-NEXT: safe accesses:
+; GLOBAL-NEXT: store i8 0, i8* %x2, align 1
+; CHECK-EMPTY:
+entry:
+ %x = alloca i32, align 4
+ %x1 = bitcast i32* %x to i8*
+ %c1 = icmp sge i64 %i, 0
+ %c2 = icmp slt i64 %i, 4
+ br i1 %c1, label %c1.true, label %false
+
+c1.true:
+ br i1 %c2, label %c2.true, label %false
+
+c2.true:
+ %x2 = getelementptr i8, i8* %x1, i64 %i
+ store i8 0, i8* %x2, align 1
+ br label %false
+
+false:
+ ret void
+}
+
+; %i is clamped to [0, 3] by two selects, so the 1-byte store stays inside
+; the 4-byte alloca: the range is [0,4) and the store is reported safe.
+define void @StoreInBoundsMinMax(i64 %i) {
+; CHECK-LABEL: @StoreInBoundsMinMax dso_preemptable{{$}}
+; CHECK-NEXT: args uses:
+; CHECK-NEXT: allocas uses:
+; CHECK-NEXT: x[4]: [0,4){{$}}
+; GLOBAL-NEXT: safe accesses:
+; GLOBAL-NEXT: store i8 0, i8* %x2, align 1
+; CHECK-EMPTY:
+entry:
+ %x = alloca i32, align 4
+ %x1 = bitcast i32* %x to i8*
+ %c1 = icmp sge i64 %i, 0
+ %i1 = select i1 %c1, i64 %i, i64 0
+ %c2 = icmp slt i64 %i1, 3
+ %i2 = select i1 %c2, i64 %i1, i64 3
+ %x2 = getelementptr i8, i8* %x1, i64 %i2
+ store i8 0, i8* %x2, align 1
+ ret void
+}
+
define void @StoreInBounds2() {
; CHECK-LABEL: @StoreInBounds2 dso_preemptable{{$}}
; CHECK-NEXT: args uses:
@@ -157,6 +204,54 @@ entry:
ret void
}
+; The guard %i < 5 still admits %i == 4, one byte past the 4-byte alloca,
+; so the store must not appear in the safe accesses list.
+define void @StoreOutOfBoundsCond(i64 %i) {
+; CHECK-LABEL: @StoreOutOfBoundsCond dso_preemptable{{$}}
+; CHECK-NEXT: args uses:
+; CHECK-NEXT: allocas uses:
+; CHECK-NEXT: x[4]: full-set{{$}}
+; GLOBAL-NEXT: safe accesses:
+; CHECK-EMPTY:
+entry:
+ %x = alloca i32, align 4
+ %x1 = bitcast i32* %x to i8*
+ %c1 = icmp sge i64 %i, 0
+ %c2 = icmp slt i64 %i, 5
+ br i1 %c1, label %c1.true, label %false
+
+c1.true:
+ br i1 %c2, label %c2.true, label %false
+
+c2.true:
+ %x2 = getelementptr i8, i8* %x1, i64 %i
+ store i8 0, i8* %x2, align 1
+ br label %false
+
+false:
+ ret void
+}
+
+; Only an upper bound on %i is checked; a negative %i would store before
+; the start of the alloca, so the store must not be reported safe.
+define void @StoreOutOfBoundsCond2(i64 %i) {
+; CHECK-LABEL: @StoreOutOfBoundsCond2 dso_preemptable{{$}}
+; CHECK-NEXT: args uses:
+; CHECK-NEXT: allocas uses:
+; CHECK-NEXT: x[4]: full-set{{$}}
+; GLOBAL-NEXT: safe accesses:
+; CHECK-EMPTY:
+entry:
+ %x = alloca i32, align 4
+ %x1 = bitcast i32* %x to i8*
+ %c2 = icmp slt i64 %i, 5
+ br i1 %c2, label %c2.true, label %false
+
+c2.true:
+ %x2 = getelementptr i8, i8* %x1, i64 %i
+ store i8 0, i8* %x2, align 1
+ br label %false
+
+false:
+ ret void
+}
+
define void @StoreOutOfBounds2() {
; CHECK-LABEL: @StoreOutOfBounds2 dso_preemptable{{$}}
; CHECK-NEXT: args uses:
diff --git a/llvm/test/Analysis/StackSafetyAnalysis/memintrin.ll b/llvm/test/Analysis/StackSafetyAnalysis/memintrin.ll
index a19e3ed46efb0..9e1c49cb90be1 100644
--- a/llvm/test/Analysis/StackSafetyAnalysis/memintrin.ll
+++ b/llvm/test/Analysis/StackSafetyAnalysis/memintrin.ll
@@ -233,3 +233,42 @@ entry:
call void @llvm.memmove.p0i8.p0i8.i32(i8* %x1, i8* %x2, i32 9, i1 false)
ret void
}
+
+; A 4-byte memset into the 4-byte %x is reported safe. %y is only used,
+; through ptrtoint, as the fill value, not as the destination pointer.
+define void @MemsetInBoundsCast() {
+; CHECK-LABEL: MemsetInBoundsCast dso_preemptable{{$}}
+; CHECK-NEXT: args uses:
+; CHECK-NEXT: allocas uses:
+; CHECK-NEXT: x[4]: [0,4){{$}}
+; CHECK-NEXT: y[1]: empty-set{{$}}
+; GLOBAL-NEXT: safe accesses:
+; GLOBAL-NEXT: call void @llvm.memset.p0i8.i32(i8* %x1, i8 %yint, i32 4, i1 false)
+; CHECK-EMPTY:
+entry:
+ %x = alloca i32, align 4
+ %y = alloca i8, align 1
+ %x1 = bitcast i32* %x to i8*
+ %yint = ptrtoint i8* %y to i8
+ call void @llvm.memset.p0i8.i32(i8* %x1, i8 %yint, i32 4, i1 false)
+ ret void
+}
+
+; The memcpy length is a zext of an i8, so it is at most 255 and fits both
+; 256-byte allocas; the memcpy is expected in the safe accesses list.
+define void @MemcpyInBoundsCast2(i8 %zint8) {
+; CHECK-LABEL: MemcpyInBoundsCast2 dso_preemptable{{$}}
+; CHECK-NEXT: args uses:
+; CHECK-NEXT: allocas uses:
+; CHECK-NEXT: x[256]: [0,255){{$}}
+; CHECK-NEXT: y[256]: [0,255){{$}}
+; CHECK-NEXT: z[1]: empty-set{{$}}
+; GLOBAL-NEXT: safe accesses:
+; GLOBAL-NEXT: call void @llvm.memcpy.p0i8.p0i8.i32(i8* %x1, i8* %y1, i32 %zint32, i1 false)
+; CHECK-EMPTY:
+entry:
+ %x = alloca [256 x i8], align 4
+ %y = alloca [256 x i8], align 4
+ %z = alloca i8, align 1
+ %x1 = bitcast [256 x i8]* %x to i8*
+ %y1 = bitcast [256 x i8]* %y to i8*
+ %zint32 = zext i8 %zint8 to i32
+ call void @llvm.memcpy.p0i8.p0i8.i32(i8* %x1, i8* %y1, i32 %zint32, i1 false)
+ ret void
+}
More information about the llvm-commits
mailing list