[llvm] 9075864 - [BasicAA] Refactor linear expression decomposition
Nikita Popov via llvm-commits
llvm-commits at lists.llvm.org
Sat Mar 27 15:32:09 PDT 2021
Author: Nikita Popov
Date: 2021-03-27T23:31:58+01:00
New Revision: 9075864b7375b4eb8f2c0663caa575c0c667de7c
URL: https://github.com/llvm/llvm-project/commit/9075864b7375b4eb8f2c0663caa575c0c667de7c
DIFF: https://github.com/llvm/llvm-project/commit/9075864b7375b4eb8f2c0663caa575c0c667de7c.diff
LOG: [BasicAA] Refactor linear expression decomposition
The current linear expression decomposition handles zext/sext by
decomposing the cast operand and then checking the NUW/NSW flags
to determine whether the extension can be distributed. This has
some disadvantages:
First, it is not possible to perform a partial decomposition. If
we have zext((x + C1) +<nuw> C2), then we will fail to decompose
the expression entirely, even though it would be safe and
profitable to decompose it to zext(x + C1) +<nuw> zext(C2).
Second, we may end up performing unnecessary decompositions,
which will later be discarded because they lack the nowrap flags
required to distribute the extensions over them.
Third, the correctness of the code is not entirely obvious. At a high
level, we encounter zext(x -<nuw> C) in the form of a zext on the
linear expression x + (-C) with the nuw flag set. Notably, this case
must be treated as zext(x) + -zext(C) rather than zext(x) + zext(-C).
The code handles this correctly by speculatively zexting constants
to the final bitwidth and performing an additional fixup if the
actual extension turns out to be an sext. This was not immediately
obvious to me.
This patch inverts the approach: an ExtendedValue represents
zext(sext(V)), and linear expression decomposition will try to
decompose V further, either by absorbing another sext/zext into the
ExtendedValue, or by distributing zext(sext(x op C)) over a binary
operator that carries the appropriate nsw/nuw flags. At each step we
can determine whether distribution is legal and, if it is not, stop
with a partial decomposition. We also know which extensions we need
to apply to constants, so we don't need to speculate or fix up.
Added:
Modified:
llvm/include/llvm/Analysis/BasicAliasAnalysis.h
llvm/lib/Analysis/BasicAliasAnalysis.cpp
llvm/test/Analysis/BasicAA/zext.ll
Removed:
################################################################################
diff --git a/llvm/include/llvm/Analysis/BasicAliasAnalysis.h b/llvm/include/llvm/Analysis/BasicAliasAnalysis.h
index 7ed2074badef..f13cf0c313c1 100644
--- a/llvm/include/llvm/Analysis/BasicAliasAnalysis.h
+++ b/llvm/include/llvm/Analysis/BasicAliasAnalysis.h
@@ -180,12 +180,6 @@ class BasicAAResult : public AAResultBase<BasicAAResult> {
/// Tracks instructions visited by pointsToConstantMemory.
SmallPtrSet<const Value *, 16> Visited;
- static const Value *
- GetLinearExpression(const Value *V, APInt &Scale, APInt &Offset,
- unsigned &ZExtBits, unsigned &SExtBits,
- const DataLayout &DL, unsigned Depth, AssumptionCache *AC,
- DominatorTree *DT, bool &NSW, bool &NUW);
-
static DecomposedGEP
DecomposeGEPExpression(const Value *V, const DataLayout &DL,
AssumptionCache *AC, DominatorTree *DT);
diff --git a/llvm/lib/Analysis/BasicAliasAnalysis.cpp b/llvm/lib/Analysis/BasicAliasAnalysis.cpp
index 02ccd7769695..9594f7b43f24 100644
--- a/llvm/lib/Analysis/BasicAliasAnalysis.cpp
+++ b/llvm/lib/Analysis/BasicAliasAnalysis.cpp
@@ -222,172 +222,159 @@ static bool isObjectSize(const Value *V, uint64_t Size, const DataLayout &DL,
// GetElementPtr Instruction Decomposition and Analysis
//===----------------------------------------------------------------------===//
-static const Value *extendLinearExpression(
- bool SignExt, unsigned NewWidth, const Value *CastOp, const Value *Result,
- APInt &Scale, APInt &Offset, unsigned &ZExtBits, unsigned &SExtBits,
- bool &NSW, bool &NUW) {
- unsigned SmallWidth = CastOp->getType()->getPrimitiveSizeInBits();
-
- // zext(zext(%x)) == zext(%x), and similarly for sext; we'll handle this
- // by just incrementing the number of bits we've extended by.
- unsigned ExtendedBy = NewWidth - SmallWidth;
-
- if (SignExt && ZExtBits == 0) {
- // sext(sext(%x, a), b) == sext(%x, a + b)
-
- if (NSW) {
- // We haven't sign-wrapped, so it's valid to decompose sext(%x + c)
- // into sext(%x) + sext(c). We'll sext the Offset ourselves:
- unsigned OldWidth = Offset.getBitWidth();
- Offset = Offset.truncOrSelf(SmallWidth).sext(NewWidth).zextOrSelf(OldWidth);
- } else {
- // We may have signed-wrapped, so don't decompose sext(%x + c) into
- // sext(%x) + sext(c)
- Scale = 1;
- Offset = 0;
- Result = CastOp;
- ZExtBits = 0;
- SExtBits = 0;
- }
- SExtBits += ExtendedBy;
- } else {
- // sext(zext(%x, a), b) = zext(zext(%x, a), b) = zext(%x, a + b)
-
- if (!NUW) {
- // We may have unsigned-wrapped, so don't decompose zext(%x + c) into
- // zext(%x) + zext(c)
- Scale = 1;
- Offset = 0;
- Result = CastOp;
- ZExtBits = 0;
- SExtBits = 0;
- }
- ZExtBits += ExtendedBy;
+namespace {
+/// Represents zext(sext(V)).
+struct ExtendedValue {
+ const Value *V;
+ unsigned ZExtBits;
+ unsigned SExtBits;
+
+ explicit ExtendedValue(const Value *V, unsigned ZExtBits = 0,
+ unsigned SExtBits = 0)
+ : V(V), ZExtBits(ZExtBits), SExtBits(SExtBits) {}
+
+ unsigned getBitWidth() const {
+ return V->getType()->getPrimitiveSizeInBits() + ZExtBits + SExtBits;
}
- return Result;
+ ExtendedValue withValue(const Value *NewV) const {
+ return ExtendedValue(NewV, ZExtBits, SExtBits);
+ }
+
+ ExtendedValue withZExtOfValue(const Value *NewV) const {
+ unsigned ExtendBy = V->getType()->getPrimitiveSizeInBits() -
+ NewV->getType()->getPrimitiveSizeInBits();
+ // zext(sext(zext(NewV))) == zext(zext(zext(NewV)))
+ return ExtendedValue(NewV, ZExtBits + SExtBits + ExtendBy, 0);
+ }
+
+ ExtendedValue withSExtOfValue(const Value *NewV) const {
+ unsigned ExtendBy = V->getType()->getPrimitiveSizeInBits() -
+ NewV->getType()->getPrimitiveSizeInBits();
+ // zext(sext(sext(NewV)))
+ return ExtendedValue(NewV, ZExtBits, SExtBits + ExtendBy);
+ }
+
+ APInt evaluateWith(APInt N) const {
+ assert(N.getBitWidth() == V->getType()->getPrimitiveSizeInBits() &&
+ "Incompatible bit width");
+ if (SExtBits) N = N.sext(N.getBitWidth() + SExtBits);
+ if (ZExtBits) N = N.zext(N.getBitWidth() + ZExtBits);
+ return N;
+ }
+
+ bool canDistributeOver(bool NUW, bool NSW) const {
+ // zext(x op<nuw> y) == zext(x) op<nuw> zext(y)
+ // sext(x op<nsw> y) == sext(x) op<nsw> sext(y)
+ return (!ZExtBits || NUW) && (!SExtBits || NSW);
+ }
+};
+
+/// Represents zext(sext(V)) * Scale + Offset.
+struct LinearExpression {
+ ExtendedValue Val;
+ APInt Scale;
+ APInt Offset;
+
+ LinearExpression(const ExtendedValue &Val, const APInt &Scale,
+ const APInt &Offset)
+ : Val(Val), Scale(Scale), Offset(Offset) {}
+
+ LinearExpression(const ExtendedValue &Val) : Val(Val) {
+ unsigned BitWidth = Val.getBitWidth();
+ Scale = APInt(BitWidth, 1);
+ Offset = APInt(BitWidth, 0);
+ }
+};
}
/// Analyzes the specified value as a linear expression: "A*V + B", where A and
/// B are constant integers.
-///
-/// Returns the scale and offset values as APInts and return V as a Value*, and
-/// return whether we looked through any sign or zero extends. The incoming
-/// Value is known to have IntegerType, and it may already be sign or zero
-/// extended.
-///
-/// Note that this looks through extends, so the high bits may not be
-/// represented in the result.
-/*static*/ const Value *BasicAAResult::GetLinearExpression(
- const Value *V, APInt &Scale, APInt &Offset, unsigned &ZExtBits,
- unsigned &SExtBits, const DataLayout &DL, unsigned Depth,
- AssumptionCache *AC, DominatorTree *DT, bool &NSW, bool &NUW) {
- assert(V->getType()->isIntegerTy() && "Not an integer value");
- assert(Scale == 0 && Offset == 0 && ZExtBits == 0 && SExtBits == 0 &&
- NSW == true && NUW == true && "Incorrect default values");
-
+static LinearExpression GetLinearExpression(
+ const ExtendedValue &Val, const DataLayout &DL, unsigned Depth,
+ AssumptionCache *AC, DominatorTree *DT) {
// Limit our recursion depth.
- if (Depth == 6) {
- Scale = 1;
- Offset = 0;
- return V;
- }
+ if (Depth == 6)
+ return Val;
- if (const ConstantInt *Const = dyn_cast<ConstantInt>(V)) {
- // If it's a constant, just convert it to an offset and remove the variable.
- // If we've been called recursively, the Offset bit width will be greater
- // than the constant's (the Offset's always as wide as the outermost call),
- // so we'll zext here and process any extension in the isa<SExtInst> &
- // isa<ZExtInst> cases below.
- Offset = Const->getValue().zextOrSelf(Offset.getBitWidth());
- assert(Scale == 0 && "Constant values don't have a scale");
- return V;
- }
+ if (const ConstantInt *Const = dyn_cast<ConstantInt>(Val.V))
+ return LinearExpression(Val, APInt(Val.getBitWidth(), 0),
+ Val.evaluateWith(Const->getValue()));
- if (const BinaryOperator *BOp = dyn_cast<BinaryOperator>(V)) {
+ if (const BinaryOperator *BOp = dyn_cast<BinaryOperator>(Val.V)) {
if (ConstantInt *RHSC = dyn_cast<ConstantInt>(BOp->getOperand(1))) {
- // If we've been called recursively, then Offset and Scale will be wider
- // than the BOp operands. We'll always zext it here as we'll process sign
- // extensions below (see the isa<SExtInst> / isa<ZExtInst> cases).
- APInt RHS = RHSC->getValue().zextOrSelf(Offset.getBitWidth());
+ APInt RHS = Val.evaluateWith(RHSC->getValue());
+ // The only non-OBO case we deal with is or, and only limited to the
+ // case where it is both nuw and nsw.
+ bool NUW = true, NSW = true;
+ if (isa<OverflowingBinaryOperator>(BOp)) {
+ NUW &= BOp->hasNoUnsignedWrap();
+ NSW &= BOp->hasNoSignedWrap();
+ }
+ if (!Val.canDistributeOver(NUW, NSW))
+ return Val;
switch (BOp->getOpcode()) {
default:
// We don't understand this instruction, so we can't decompose it any
// further.
- Scale = 1;
- Offset = 0;
- return V;
+ return Val;
case Instruction::Or:
// X|C == X+C if all the bits in C are unset in X. Otherwise we can't
// analyze it.
if (!MaskedValueIsZero(BOp->getOperand(0), RHSC->getValue(), DL, 0, AC,
- BOp, DT)) {
- Scale = 1;
- Offset = 0;
- return V;
- }
+ BOp, DT))
+ return Val;
+
LLVM_FALLTHROUGH;
- case Instruction::Add:
- V = GetLinearExpression(BOp->getOperand(0), Scale, Offset, ZExtBits,
- SExtBits, DL, Depth + 1, AC, DT, NSW, NUW);
- Offset += RHS;
- break;
- case Instruction::Sub:
- V = GetLinearExpression(BOp->getOperand(0), Scale, Offset, ZExtBits,
- SExtBits, DL, Depth + 1, AC, DT, NSW, NUW);
- Offset -= RHS;
- break;
- case Instruction::Mul:
- V = GetLinearExpression(BOp->getOperand(0), Scale, Offset, ZExtBits,
- SExtBits, DL, Depth + 1, AC, DT, NSW, NUW);
- Offset *= RHS;
- Scale *= RHS;
- break;
+ case Instruction::Add: {
+ LinearExpression E = GetLinearExpression(
+ Val.withValue(BOp->getOperand(0)), DL, Depth + 1, AC, DT);
+ E.Offset += RHS;
+ return E;
+ }
+ case Instruction::Sub: {
+ LinearExpression E = GetLinearExpression(
+ Val.withValue(BOp->getOperand(0)), DL, Depth + 1, AC, DT);
+ E.Offset -= RHS;
+ return E;
+ }
+ case Instruction::Mul: {
+ LinearExpression E = GetLinearExpression(
+ Val.withValue(BOp->getOperand(0)), DL, Depth + 1, AC, DT);
+ E.Offset *= RHS;
+ E.Scale *= RHS;
+ return E;
+ }
case Instruction::Shl:
// We're trying to linearize an expression of the kind:
// shl i8 -128, 36
// where the shift count exceeds the bitwidth of the type.
// We can't decompose this further (the expression would return
// a poison value).
- if (Offset.getBitWidth() < RHS.getLimitedValue() ||
- Scale.getBitWidth() < RHS.getLimitedValue()) {
- Scale = 1;
- Offset = 0;
- return V;
- }
-
- V = GetLinearExpression(BOp->getOperand(0), Scale, Offset, ZExtBits,
- SExtBits, DL, Depth + 1, AC, DT, NSW, NUW);
- Offset <<= RHS.getLimitedValue();
- Scale <<= RHS.getLimitedValue();
- break;
+ if (RHS.getLimitedValue() > Val.getBitWidth())
+ return Val;
+
+ LinearExpression E = GetLinearExpression(
+ Val.withValue(BOp->getOperand(0)), DL, Depth + 1, AC, DT);
+ E.Offset <<= RHS.getLimitedValue();
+ E.Scale <<= RHS.getLimitedValue();
+ return E;
}
-
- if (isa<OverflowingBinaryOperator>(BOp)) {
- NUW &= BOp->hasNoUnsignedWrap();
- NSW &= BOp->hasNoSignedWrap();
- }
- return V;
}
}
- // Since GEP indices are sign extended anyway, we don't care about the high
- // bits of a sign or zero extended value - just scales and offsets. The
- // extensions have to be consistent though.
- if (isa<SExtInst>(V) || isa<ZExtInst>(V)) {
- const Value *CastOp = cast<CastInst>(V)->getOperand(0);
- const Value *Result =
- GetLinearExpression(CastOp, Scale, Offset, ZExtBits, SExtBits, DL,
- Depth + 1, AC, DT, NSW, NUW);
- return extendLinearExpression(
- isa<SExtInst>(V), V->getType()->getPrimitiveSizeInBits(),
- CastOp, Result, Scale, Offset, ZExtBits, SExtBits, NSW, NUW);
- }
+ if (isa<ZExtInst>(Val.V))
+ return GetLinearExpression(
+ Val.withZExtOfValue(cast<CastInst>(Val.V)->getOperand(0)),
+ DL, Depth + 1, AC, DT);
+
+ if (isa<SExtInst>(Val.V))
+ return GetLinearExpression(
+ Val.withSExtOfValue(cast<CastInst>(Val.V)->getOperand(0)),
+ DL, Depth + 1, AC, DT);
- Scale = 1;
- Offset = 0;
- return V;
+ return Val;
}
/// To ensure a pointer offset fits in an integer of size PointerSize
@@ -537,21 +524,12 @@ BasicAAResult::DecomposeGEPExpression(const Value *V, const DataLayout &DL,
APInt Scale(MaxPointerSize,
DL.getTypeAllocSize(GTI.getIndexedType()).getFixedSize());
- // Use GetLinearExpression to decompose the index into a C1*V+C2 form.
- unsigned Width = Index->getType()->getIntegerBitWidth();
- APInt IndexScale(Width, 0), IndexOffset(Width, 0);
- unsigned ZExtBits = 0, SExtBits = 0;
- bool NSW = true, NUW = true;
- const Value *OrigIndex = Index;
- Index = GetLinearExpression(Index, IndexScale, IndexOffset, ZExtBits,
- SExtBits, DL, 0, AC, DT, NSW, NUW);
-
// If the integer type is smaller than the pointer size, it is implicitly
// sign extended to pointer size.
- if (PointerSize > Width)
- Index = extendLinearExpression(
- /* SignExt */ true, PointerSize, OrigIndex, Index, IndexScale,
- IndexOffset, ZExtBits, SExtBits, NSW, NUW);
+ unsigned Width = Index->getType()->getIntegerBitWidth();
+ unsigned SExtBits = PointerSize > Width ? PointerSize - Width : 0;
+ LinearExpression LE = GetLinearExpression(
+ ExtendedValue(Index, 0, SExtBits), DL, 0, AC, DT);
// The GEP index scale ("Scale") scales C1*V+C2, yielding (C1*V+C2)*Scale.
// This gives us an aggregate computation of (C1*Scale)*V + C2*Scale.
@@ -564,19 +542,13 @@ BasicAAResult::DecomposeGEPExpression(const Value *V, const DataLayout &DL,
// (C1*Scale)*V+C2*Scale can also overflow. We should check for this
// possibility.
bool Overflow;
- APInt ScaledOffset = IndexOffset.sextOrTrunc(MaxPointerSize)
+ APInt ScaledOffset = LE.Offset.sextOrTrunc(MaxPointerSize)
.smul_ov(Scale, Overflow);
if (Overflow) {
- Index = OrigIndex;
- IndexScale = 1;
- IndexOffset = 0;
-
- ZExtBits = SExtBits = 0;
- if (PointerSize > Width)
- SExtBits += PointerSize - Width;
+ LE = LinearExpression(ExtendedValue(Index, 0, SExtBits));
} else {
Decomposed.Offset += ScaledOffset;
- Scale *= IndexScale.sextOrTrunc(MaxPointerSize);
+ Scale *= LE.Scale.sextOrTrunc(MaxPointerSize);
}
// If we already had an occurrence of this index variable, merge this
@@ -584,9 +556,9 @@ BasicAAResult::DecomposeGEPExpression(const Value *V, const DataLayout &DL,
// A[x][x] -> x*16 + x*4 -> x*20
// This also ensures that 'x' only appears in the index list once.
for (unsigned i = 0, e = Decomposed.VarIndices.size(); i != e; ++i) {
- if (Decomposed.VarIndices[i].V == Index &&
- Decomposed.VarIndices[i].ZExtBits == ZExtBits &&
- Decomposed.VarIndices[i].SExtBits == SExtBits) {
+ if (Decomposed.VarIndices[i].V == LE.Val.V &&
+ Decomposed.VarIndices[i].ZExtBits == LE.Val.ZExtBits &&
+ Decomposed.VarIndices[i].SExtBits == LE.Val.SExtBits) {
Scale += Decomposed.VarIndices[i].Scale;
Decomposed.VarIndices.erase(Decomposed.VarIndices.begin() + i);
break;
@@ -598,7 +570,8 @@ BasicAAResult::DecomposeGEPExpression(const Value *V, const DataLayout &DL,
Scale = adjustToPointerSize(Scale, PointerSize);
if (!!Scale) {
- VariableGEPIndex Entry = {Index, ZExtBits, SExtBits, Scale, CxtI};
+ VariableGEPIndex Entry = {LE.Val.V, LE.Val.ZExtBits, LE.Val.SExtBits,
+ Scale, CxtI};
Decomposed.VarIndices.push_back(Entry);
}
}
@@ -1746,25 +1719,17 @@ bool BasicAAResult::constantOffsetHeuristic(
Var0.Scale != -Var1.Scale)
return false;
- unsigned Width = Var1.V->getType()->getIntegerBitWidth();
-
// We'll strip off the Extensions of Var0 and Var1 and do another round
// of GetLinearExpression decomposition. In the example above, if Var0
// is zext(%x + 1) we should get V1 == %x and V1Offset == 1.
- APInt V0Scale(Width, 0), V0Offset(Width, 0), V1Scale(Width, 0),
- V1Offset(Width, 0);
- bool NSW = true, NUW = true;
- unsigned V0ZExtBits = 0, V0SExtBits = 0, V1ZExtBits = 0, V1SExtBits = 0;
- const Value *V0 = GetLinearExpression(Var0.V, V0Scale, V0Offset, V0ZExtBits,
- V0SExtBits, DL, 0, AC, DT, NSW, NUW);
- NSW = true;
- NUW = true;
- const Value *V1 = GetLinearExpression(Var1.V, V1Scale, V1Offset, V1ZExtBits,
- V1SExtBits, DL, 0, AC, DT, NSW, NUW);
-
- if (V0Scale != V1Scale || V0ZExtBits != V1ZExtBits ||
- V0SExtBits != V1SExtBits || !isValueEqualInPotentialCycles(V0, V1))
+ LinearExpression E0 =
+ GetLinearExpression(ExtendedValue(Var0.V), DL, 0, AC, DT);
+ LinearExpression E1 =
+ GetLinearExpression(ExtendedValue(Var1.V), DL, 0, AC, DT);
+ if (E0.Scale != E1.Scale || E0.Val.ZExtBits != E1.Val.ZExtBits ||
+ E0.Val.SExtBits != E1.Val.SExtBits ||
+ !isValueEqualInPotentialCycles(E0.Val.V, E1.Val.V))
return false;
// We have a hit - Var0 and Var1 only differ by a constant offset!
@@ -1774,7 +1739,7 @@ bool BasicAAResult::constantOffsetHeuristic(
// minimum difference between the two. The minimum distance may occur due to
// wrapping; consider "add i3 %i, 5": if %i == 7 then 7 + 5 mod 8 == 4, and so
// the minimum distance between %i and %i + 5 is 3.
- APInt MinDiff = V0Offset - V1Offset, Wrapped = -MinDiff;
+ APInt MinDiff = E0.Offset - E1.Offset, Wrapped = -MinDiff;
MinDiff = APIntOps::umin(MinDiff, Wrapped);
APInt MinDiffBytes =
MinDiff.zextOrTrunc(Var0.Scale.getBitWidth()) * Var0.Scale.abs();
diff --git a/llvm/test/Analysis/BasicAA/zext.ll b/llvm/test/Analysis/BasicAA/zext.ll
index a1fc10a48a3b..8e5d5fe0cf2c 100644
--- a/llvm/test/Analysis/BasicAA/zext.ll
+++ b/llvm/test/Analysis/BasicAA/zext.ll
@@ -275,5 +275,16 @@ define void @test_implicit_sext(i8* %p, i32 %x) {
ret void
}
+; CHECK-LABEL: Function: test_partial_decomposition
+; CHECK: MustAlias: i8* %p.1, i8* %p.2
+define void @test_partial_decomposition(i8* %p, i32 %x) {
+ %add = add i32 %x, 1
+ %add.1 = add nsw i32 %add, 1
+ %add.2 = add nsw i32 %add, 1
+ %p.1 = getelementptr i8, i8* %p, i32 %add.1
+ %p.2 = getelementptr i8, i8* %p, i32 %add.2
+ ret void
+}
+
; Function Attrs: nounwind
declare noalias i8* @malloc(i64)
More information about the llvm-commits
mailing list