[llvm] b9bba6c - [BasicAA] Track nuw through decomposed expressions (#106512)
via llvm-commits
llvm-commits at lists.llvm.org
Mon Sep 2 03:11:06 PDT 2024
Author: Nikita Popov
Date: 2024-09-02T12:11:03+02:00
New Revision: b9bba6ca9fc62c5ae3ee402196b11a523a500fdc
URL: https://github.com/llvm/llvm-project/commit/b9bba6ca9fc62c5ae3ee402196b11a523a500fdc
DIFF: https://github.com/llvm/llvm-project/commit/b9bba6ca9fc62c5ae3ee402196b11a523a500fdc.diff
LOG: [BasicAA] Track nuw through decomposed expressions (#106512)
When we decompose the GEP offset expression, and the arithmetic is not
performed using nuw operations, we cannot retain the nuw flag on the
decomposed GEP.
For example, if we have `gep nuw p, (a-1)`, this is not at all the same
as `gep nuw (gep nuw p, a), -1`.
Fix this by tracking NUW through linear expression decomposition,
similarly to what we already do for the NSW flag.
This fixes the miscompilation reported in
https://github.com/llvm/llvm-project/pull/105496#issuecomment-2315322220.
Added:
Modified:
llvm/lib/Analysis/BasicAliasAnalysis.cpp
llvm/test/Analysis/BasicAA/gep-nuw-alias.ll
Removed:
################################################################################
diff --git a/llvm/lib/Analysis/BasicAliasAnalysis.cpp b/llvm/lib/Analysis/BasicAliasAnalysis.cpp
index 72db28929c0c37..a00ed7530ebc4c 100644
--- a/llvm/lib/Analysis/BasicAliasAnalysis.cpp
+++ b/llvm/lib/Analysis/BasicAliasAnalysis.cpp
@@ -375,24 +375,28 @@ struct LinearExpression {
APInt Scale;
APInt Offset;
+ /// True if all operations in this expression are NUW.
+ bool IsNUW;
/// True if all operations in this expression are NSW.
bool IsNSW;
LinearExpression(const CastedValue &Val, const APInt &Scale,
- const APInt &Offset, bool IsNSW)
- : Val(Val), Scale(Scale), Offset(Offset), IsNSW(IsNSW) {}
+ const APInt &Offset, bool IsNUW, bool IsNSW)
+ : Val(Val), Scale(Scale), Offset(Offset), IsNUW(IsNUW), IsNSW(IsNSW) {}
- LinearExpression(const CastedValue &Val) : Val(Val), IsNSW(true) {
+ LinearExpression(const CastedValue &Val)
+ : Val(Val), IsNUW(true), IsNSW(true) {
unsigned BitWidth = Val.getBitWidth();
Scale = APInt(BitWidth, 1);
Offset = APInt(BitWidth, 0);
}
- LinearExpression mul(const APInt &Other, bool MulIsNSW) const {
+ LinearExpression mul(const APInt &Other, bool MulIsNUW, bool MulIsNSW) const {
// The check for zero offset is necessary, because generally
// (X +nsw Y) *nsw Z does not imply (X *nsw Z) +nsw (Y *nsw Z).
bool NSW = IsNSW && (Other.isOne() || (MulIsNSW && Offset.isZero()));
- return LinearExpression(Val, Scale * Other, Offset * Other, NSW);
+ bool NUW = IsNUW && (Other.isOne() || MulIsNUW);
+ return LinearExpression(Val, Scale * Other, Offset * Other, NUW, NSW);
}
};
}
@@ -408,7 +412,7 @@ static LinearExpression GetLinearExpression(
if (const ConstantInt *Const = dyn_cast<ConstantInt>(Val.V))
return LinearExpression(Val, APInt(Val.getBitWidth(), 0),
- Val.evaluateWith(Const->getValue()), true);
+ Val.evaluateWith(Const->getValue()), true, true);
if (const BinaryOperator *BOp = dyn_cast<BinaryOperator>(Val.V)) {
if (ConstantInt *RHSC = dyn_cast<ConstantInt>(BOp->getOperand(1))) {
@@ -444,6 +448,7 @@ static LinearExpression GetLinearExpression(
E = GetLinearExpression(Val.withValue(BOp->getOperand(0), false), DL,
Depth + 1, AC, DT);
E.Offset += RHS;
+ E.IsNUW &= NUW;
E.IsNSW &= NSW;
break;
}
@@ -451,13 +456,14 @@ static LinearExpression GetLinearExpression(
E = GetLinearExpression(Val.withValue(BOp->getOperand(0), false), DL,
Depth + 1, AC, DT);
E.Offset -= RHS;
+ E.IsNUW = false; // sub nuw x, y is not add nuw x, -y.
E.IsNSW &= NSW;
break;
}
case Instruction::Mul:
E = GetLinearExpression(Val.withValue(BOp->getOperand(0), false), DL,
Depth + 1, AC, DT)
- .mul(RHS, NSW);
+ .mul(RHS, NUW, NSW);
break;
case Instruction::Shl:
// We're trying to linearize an expression of the kind:
@@ -472,6 +478,7 @@ static LinearExpression GetLinearExpression(
Depth + 1, AC, DT);
E.Offset <<= RHS.getLimitedValue();
E.Scale <<= RHS.getLimitedValue();
+ E.IsNUW &= NUW;
E.IsNSW &= NSW;
break;
}
@@ -697,7 +704,8 @@ BasicAAResult::DecomposeGEPExpression(const Value *V, const DataLayout &DL,
// If the integer type is smaller than the index size, it is implicitly
// sign extended or truncated to index size.
bool NUSW = GEPOp->hasNoUnsignedSignedWrap();
- bool NonNeg = NUSW && GEPOp->hasNoUnsignedWrap();
+ bool NUW = GEPOp->hasNoUnsignedWrap();
+ bool NonNeg = NUSW && NUW;
unsigned Width = Index->getType()->getIntegerBitWidth();
unsigned SExtBits = IndexSize > Width ? IndexSize - Width : 0;
unsigned TruncBits = IndexSize < Width ? Width - IndexSize : 0;
@@ -706,9 +714,11 @@ BasicAAResult::DecomposeGEPExpression(const Value *V, const DataLayout &DL,
// Scale by the type size.
unsigned TypeSize = AllocTypeSize.getFixedValue();
- LE = LE.mul(APInt(IndexSize, TypeSize), NUSW);
+ LE = LE.mul(APInt(IndexSize, TypeSize), NUW, NUSW);
Decomposed.Offset += LE.Offset.sext(MaxIndexSize);
APInt Scale = LE.Scale.sext(MaxIndexSize);
+ if (!LE.IsNUW)
+ Decomposed.NWFlags = Decomposed.NWFlags.withoutNoUnsignedWrap();
// If we already had an occurrence of this index variable, merge this
// scale into it. For example, we want to handle:
@@ -719,7 +729,8 @@ BasicAAResult::DecomposeGEPExpression(const Value *V, const DataLayout &DL,
areBothVScale(Decomposed.VarIndices[i].Val.V, LE.Val.V)) &&
Decomposed.VarIndices[i].Val.hasSameCastsAs(LE.Val)) {
Scale += Decomposed.VarIndices[i].Scale;
- LE.IsNSW = false; // We cannot guarantee nsw for the merge.
+ // We cannot guarantee no-wrap for the merge.
+ LE.IsNSW = LE.IsNUW = false;
Decomposed.VarIndices.erase(Decomposed.VarIndices.begin() + i);
break;
}
diff --git a/llvm/test/Analysis/BasicAA/gep-nuw-alias.ll b/llvm/test/Analysis/BasicAA/gep-nuw-alias.ll
index b80a457f85176c..a5f1c1c747cc3f 100644
--- a/llvm/test/Analysis/BasicAA/gep-nuw-alias.ll
+++ b/llvm/test/Analysis/BasicAA/gep-nuw-alias.ll
@@ -212,3 +212,106 @@ define void @both_var_idx(ptr %p, i64 %i, i64 %j) {
ret void
}
+
+; CHECK-LABEL: add_no_nuw
+; CHECK: MayAlias: i8* %gep, i8* %p
+; The index is a plain (wrapping) add: for %n == -1 it evaluates to 0, so
+; %gep may equal %p even though the GEP itself is nuw. Decomposing the index
+; into %n + 1 must therefore drop the GEP's nuw flag.
+define i8 @add_no_nuw(ptr %p, i64 %n) {
+ store i8 3, ptr %p
+
+ %add = add i64 %n, 1
+ %gep = getelementptr nuw i8, ptr %p, i64 %add
+ %val = load i8, ptr %gep
+ ret i8 %val
+}
+
+; CHECK-LABEL: add_nuw
+; CHECK: NoAlias: i8* %gep, i8* %p
+; With `add nuw` the index is at least 1 as an unsigned value, and the nuw
+; GEP cannot wrap, so %gep lies strictly above the 1-byte store at %p.
+define i8 @add_nuw(ptr %p, i64 %n) {
+ store i8 3, ptr %p
+
+ %add = add nuw i64 %n, 1
+ %gep = getelementptr nuw i8, ptr %p, i64 %add
+ %val = load i8, ptr %gep
+ ret i8 %val
+}
+
+; CHECK-LABEL: add_no_nuw_scale
+; CHECK: MayAlias: i8* %gep, i16* %p
+; The index is a plain (wrapping) add scaled by the 2-byte element size: for
+; %n == -1 the index is 0, so %gep may equal %p despite the nuw GEP.
+define i8 @add_no_nuw_scale(ptr %p, i64 %n) {
+ store i16 3, ptr %p
+
+ %add = add i64 %n, 1
+ %gep = getelementptr nuw i16, ptr %p, i64 %add
+ %val = load i8, ptr %gep
+ ret i8 %val
+}
+
+; CHECK-LABEL: add_nuw_scale
+; CHECK: NoAlias: i8* %gep, i16* %p
+; `add nuw` makes the index >= 1, so with the nuw GEP the byte offset is >= 2
+; and the 1-byte load lies past the 2-byte store at %p.
+define i8 @add_nuw_scale(ptr %p, i64 %n) {
+ store i16 3, ptr %p
+
+ %add = add nuw i64 %n, 1
+ %gep = getelementptr nuw i16, ptr %p, i64 %add
+ %val = load i8, ptr %gep
+ ret i8 %val
+}
+
+; CHECK-LABEL: sub_nuw
+; CHECK: MayAlias: i8* %gep, i8* %p
+; `sub nuw %n, 1` only implies %n >= 1; the result itself can be 0, so %gep
+; may equal %p. (sub nuw x, y is not equivalent to add nuw x, -y.)
+define i8 @sub_nuw(ptr %p, i64 %n) {
+ store i8 3, ptr %p
+
+ %sub = sub nuw i64 %n, 1
+ %gep = getelementptr nuw i8, ptr %p, i64 %sub
+ %val = load i8, ptr %gep
+ ret i8 %val
+}
+
+; CHECK-LABEL: mul_no_nuw
+; CHECK: MayAlias: i8* %gep, i16* %p
+; The add is nuw (index >= 1), but the plain `mul` by 2 can wrap to 0
+; (e.g. %add == 2^63), so %gep may equal %p.
+define i8 @mul_no_nuw(ptr %p, i64 %n) {
+ store i16 3, ptr %p
+
+ %add = add nuw i64 %n, 1
+ %mul = mul i64 %add, 2
+ %gep = getelementptr nuw i8, ptr %p, i64 %mul
+ %val = load i8, ptr %gep
+ ret i8 %val
+}
+
+; CHECK-LABEL: mul_nuw
+; CHECK: NoAlias: i8* %gep, i16* %p
+; Both the add and the mul are nuw, so the offset is >= 2 and the 1-byte
+; load lies past the 2-byte store at %p.
+define i8 @mul_nuw(ptr %p, i64 %n) {
+ store i16 3, ptr %p
+
+ %add = add nuw i64 %n, 1
+ %mul = mul nuw i64 %add, 2
+ %gep = getelementptr nuw i8, ptr %p, i64 %mul
+ %val = load i8, ptr %gep
+ ret i8 %val
+}
+
+; CHECK-LABEL: shl_no_nuw
+; CHECK: MayAlias: i8* %gep, i16* %p
+; The add is nuw (index >= 1), but the plain `shl` by 1 can wrap to 0
+; (e.g. %add == 2^63), so %gep may equal %p.
+define i8 @shl_no_nuw(ptr %p, i64 %n) {
+ store i16 3, ptr %p
+
+ %add = add nuw i64 %n, 1
+ %shl = shl i64 %add, 1
+ %gep = getelementptr nuw i8, ptr %p, i64 %shl
+ %val = load i8, ptr %gep
+ ret i8 %val
+}
+
+; CHECK-LABEL: shl_nuw
+; CHECK: NoAlias: i8* %gep, i16* %p
+; Both the add and the shl are nuw, so the offset is >= 2 and the 1-byte
+; load lies past the 2-byte store at %p.
+define i8 @shl_nuw(ptr %p, i64 %n) {
+ store i16 3, ptr %p
+
+ %add = add nuw i64 %n, 1
+ %shl = shl nuw i64 %add, 1
+ %gep = getelementptr nuw i8, ptr %p, i64 %shl
+ %val = load i8, ptr %gep
+ ret i8 %val
+}
More information about the llvm-commits
mailing list