[llvm] d881bac - [BasicAA] Consider 'nneg' flag when comparing CastedValues (#94129)
via llvm-commits
llvm-commits at lists.llvm.org
Tue Jun 4 08:33:00 PDT 2024
Author: Alex MacLean
Date: 2024-06-04T08:32:57-07:00
New Revision: d881bac6fa3b1d8d622d4fb651060cf7d6223080
URL: https://github.com/llvm/llvm-project/commit/d881bac6fa3b1d8d622d4fb651060cf7d6223080
DIFF: https://github.com/llvm/llvm-project/commit/d881bac6fa3b1d8d622d4fb651060cf7d6223080.diff
LOG: [BasicAA] Consider 'nneg' flag when comparing CastedValues (#94129)
Any of the `zext` bits in a `zext nneg` can be converted to `sext`, but
when checking whether casts are compatible, `BasicAA` fails to take
`nneg` into account. This change adds tracking of `nneg` to the `CastedValue`
struct and ensures that `sext` and `zext` bits are treated as
interchangeable when either `CastedValue` has a `nneg`. When
distributing casted values in `GetLinearExpression` we conservatively
discard the `nneg` from the `CastedValue`, except in the case of `shl
nsw`, where we know the sign has not changed to negative.
Added:
llvm/test/Analysis/BasicAA/zext-nneg.ll
Modified:
llvm/lib/Analysis/BasicAliasAnalysis.cpp
Removed:
################################################################################
diff --git a/llvm/lib/Analysis/BasicAliasAnalysis.cpp b/llvm/lib/Analysis/BasicAliasAnalysis.cpp
index 3f456db1c51ac..c110943ad0d58 100644
--- a/llvm/lib/Analysis/BasicAliasAnalysis.cpp
+++ b/llvm/lib/Analysis/BasicAliasAnalysis.cpp
@@ -268,31 +268,43 @@ struct CastedValue {
unsigned ZExtBits = 0;
unsigned SExtBits = 0;
unsigned TruncBits = 0;
+ /// Whether trunc(V) is non-negative.
+ bool IsNonNegative = false;
explicit CastedValue(const Value *V) : V(V) {}
explicit CastedValue(const Value *V, unsigned ZExtBits, unsigned SExtBits,
- unsigned TruncBits)
- : V(V), ZExtBits(ZExtBits), SExtBits(SExtBits), TruncBits(TruncBits) {}
+ unsigned TruncBits, bool IsNonNegative)
+ : V(V), ZExtBits(ZExtBits), SExtBits(SExtBits), TruncBits(TruncBits),
+ IsNonNegative(IsNonNegative) {}
unsigned getBitWidth() const {
return V->getType()->getPrimitiveSizeInBits() - TruncBits + ZExtBits +
SExtBits;
}
- CastedValue withValue(const Value *NewV) const {
- return CastedValue(NewV, ZExtBits, SExtBits, TruncBits);
+ CastedValue withValue(const Value *NewV, bool PreserveNonNeg) const {
+ return CastedValue(NewV, ZExtBits, SExtBits, TruncBits,
+ IsNonNegative && PreserveNonNeg);
}
/// Replace V with zext(NewV)
- CastedValue withZExtOfValue(const Value *NewV) const {
+ CastedValue withZExtOfValue(const Value *NewV, bool ZExtNonNegative) const {
unsigned ExtendBy = V->getType()->getPrimitiveSizeInBits() -
NewV->getType()->getPrimitiveSizeInBits();
if (ExtendBy <= TruncBits)
- return CastedValue(NewV, ZExtBits, SExtBits, TruncBits - ExtendBy);
+ // zext<nneg>(trunc(zext(NewV))) == zext<nneg>(trunc(NewV))
+ // The nneg can be preserved on the outer zext here.
+ return CastedValue(NewV, ZExtBits, SExtBits, TruncBits - ExtendBy,
+ IsNonNegative);
// zext(sext(zext(NewV))) == zext(zext(zext(NewV)))
ExtendBy -= TruncBits;
- return CastedValue(NewV, ZExtBits + SExtBits + ExtendBy, 0, 0);
+ // zext<nneg>(zext(NewV)) == zext(NewV)
+ // zext(zext<nneg>(NewV)) == zext<nneg>(NewV)
+ // The nneg can be preserved from the inner zext here but must be dropped
+ // from the outer.
+ return CastedValue(NewV, ZExtBits + SExtBits + ExtendBy, 0, 0,
+ ZExtNonNegative);
}
/// Replace V with sext(NewV)
@@ -300,11 +312,16 @@ struct CastedValue {
unsigned ExtendBy = V->getType()->getPrimitiveSizeInBits() -
NewV->getType()->getPrimitiveSizeInBits();
if (ExtendBy <= TruncBits)
- return CastedValue(NewV, ZExtBits, SExtBits, TruncBits - ExtendBy);
+ // zext<nneg>(trunc(sext(NewV))) == zext<nneg>(trunc(NewV))
+ // The nneg can be preserved on the outer zext here
+ return CastedValue(NewV, ZExtBits, SExtBits, TruncBits - ExtendBy,
+ IsNonNegative);
// zext(sext(sext(NewV)))
ExtendBy -= TruncBits;
- return CastedValue(NewV, ZExtBits, SExtBits + ExtendBy, 0);
+ // zext<nneg>(sext(sext(NewV))) = zext<nneg>(sext(NewV))
+ // The nneg can be preserved on the outer zext here
+ return CastedValue(NewV, ZExtBits, SExtBits + ExtendBy, 0, IsNonNegative);
}
APInt evaluateWith(APInt N) const {
@@ -333,8 +350,15 @@ struct CastedValue {
}
bool hasSameCastsAs(const CastedValue &Other) const {
- return ZExtBits == Other.ZExtBits && SExtBits == Other.SExtBits &&
- TruncBits == Other.TruncBits;
+ if (ZExtBits == Other.ZExtBits && SExtBits == Other.SExtBits &&
+ TruncBits == Other.TruncBits)
+ return true;
+ // If either CastedValue has a nneg zext then the sext/zext bits are
+ // interchangeable for that value.
+ if (IsNonNegative || Other.IsNonNegative)
+ return (ZExtBits + SExtBits == Other.ZExtBits + Other.SExtBits &&
+ TruncBits == Other.TruncBits);
+ return false;
}
};
@@ -410,21 +434,21 @@ static LinearExpression GetLinearExpression(
[[fallthrough]];
case Instruction::Add: {
- E = GetLinearExpression(Val.withValue(BOp->getOperand(0)), DL,
+ E = GetLinearExpression(Val.withValue(BOp->getOperand(0), false), DL,
Depth + 1, AC, DT);
E.Offset += RHS;
E.IsNSW &= NSW;
break;
}
case Instruction::Sub: {
- E = GetLinearExpression(Val.withValue(BOp->getOperand(0)), DL,
+ E = GetLinearExpression(Val.withValue(BOp->getOperand(0), false), DL,
Depth + 1, AC, DT);
E.Offset -= RHS;
E.IsNSW &= NSW;
break;
}
case Instruction::Mul:
- E = GetLinearExpression(Val.withValue(BOp->getOperand(0)), DL,
+ E = GetLinearExpression(Val.withValue(BOp->getOperand(0), false), DL,
Depth + 1, AC, DT)
.mul(RHS, NSW);
break;
@@ -437,7 +461,7 @@ static LinearExpression GetLinearExpression(
if (RHS.getLimitedValue() > Val.getBitWidth())
return Val;
- E = GetLinearExpression(Val.withValue(BOp->getOperand(0)), DL,
+ E = GetLinearExpression(Val.withValue(BOp->getOperand(0), NSW), DL,
Depth + 1, AC, DT);
E.Offset <<= RHS.getLimitedValue();
E.Scale <<= RHS.getLimitedValue();
@@ -448,10 +472,10 @@ static LinearExpression GetLinearExpression(
}
}
- if (isa<ZExtInst>(Val.V))
+ if (const auto *ZExt = dyn_cast<ZExtInst>(Val.V))
return GetLinearExpression(
- Val.withZExtOfValue(cast<CastInst>(Val.V)->getOperand(0)),
- DL, Depth + 1, AC, DT);
+ Val.withZExtOfValue(ZExt->getOperand(0), ZExt->hasNonNeg()), DL,
+ Depth + 1, AC, DT);
if (isa<SExtInst>(Val.V))
return GetLinearExpression(
@@ -673,7 +697,7 @@ BasicAAResult::DecomposeGEPExpression(const Value *V, const DataLayout &DL,
unsigned SExtBits = IndexSize > Width ? IndexSize - Width : 0;
unsigned TruncBits = IndexSize < Width ? Width - IndexSize : 0;
LinearExpression LE = GetLinearExpression(
- CastedValue(Index, 0, SExtBits, TruncBits), DL, 0, AC, DT);
+ CastedValue(Index, 0, SExtBits, TruncBits, false), DL, 0, AC, DT);
// Scale by the type size.
unsigned TypeSize = AllocTypeSize.getFixedValue();
diff --git a/llvm/test/Analysis/BasicAA/zext-nneg.ll b/llvm/test/Analysis/BasicAA/zext-nneg.ll
new file mode 100644
index 0000000000000..808bb1a8c9d96
--- /dev/null
+++ b/llvm/test/Analysis/BasicAA/zext-nneg.ll
@@ -0,0 +1,181 @@
+; RUN: opt < %s -aa-pipeline=basic-aa -passes=aa-eval -print-all-alias-modref-info -disable-output 2>&1 | FileCheck %s
+
+;; Simple case: a zext nneg can be replaced with a sext. Make sure BasicAA
+;; understands that.
+define void @t1(i32 %a, i32 %b) {
+; CHECK-LABEL: Function: t1
+; CHECK: NoAlias: float* %gep1, float* %gep2
+
+ %1 = alloca [8 x float], align 4
+ %or1 = or i32 %a, 1
+ %2 = sext i32 %or1 to i64
+ %gep1 = getelementptr inbounds float, ptr %1, i64 %2
+
+ %shl1 = shl i32 %b, 1
+ %3 = zext nneg i32 %shl1 to i64
+ %gep2 = getelementptr inbounds float, ptr %1, i64 %3
+
+ load float, ptr %gep1
+ load float, ptr %gep2
+ ret void
+}
+
+;; A (zext nneg (sext V)) is equivalent to a (zext (sext V)) as long as the
+;; total number of zext+sext bits is the same for both.
+define void @t2(i8 %a, i8 %b) {
+; CHECK-LABEL: Function: t2
+; CHECK: NoAlias: float* %gep1, float* %gep2
+ %1 = alloca [8 x float], align 4
+ %or1 = or i8 %a, 1
+ %2 = sext i8 %or1 to i32
+ %3 = zext i32 %2 to i64
+ %gep1 = getelementptr inbounds float, ptr %1, i64 %3
+
+ %shl1 = shl i8 %b, 1
+ %4 = sext i8 %shl1 to i16
+ %5 = zext nneg i16 %4 to i64
+ %gep2 = getelementptr inbounds float, ptr %1, i64 %5
+
+ load float, ptr %gep1
+ load float, ptr %gep2
+ ret void
+}
+
+;; Here %a and %b are known to be non-equal. In this case we can distribute
+;; the zext, preserving the nneg flag, through the shl because it has an nsw flag.
+define void @t3(i8 %v) {
+; CHECK-LABEL: Function: t3
+; CHECK: NoAlias: <2 x float>* %gep1, <2 x float>* %gep2
+ %a = or i8 %v, 1
+ %b = and i8 %v, 2
+
+ %1 = alloca [8 x float], align 4
+ %or1 = shl nuw nsw i8 %a, 1
+ %2 = zext nneg i8 %or1 to i64
+ %gep1 = getelementptr inbounds float, ptr %1, i64 %2
+
+ %m = mul nsw nuw i8 %b, 2
+ %3 = sext i8 %m to i16
+ %4 = zext i16 %3 to i64
+ %gep2 = getelementptr inbounds float, ptr %1, i64 %4
+
+ load <2 x float>, ptr %gep1
+ load <2 x float>, ptr %gep2
+ ret void
+}
+
+;; This is the same as above, but this time the shl does not have the nsw flag.
+;; The nneg cannot be kept on the zext.
+define void @t4(i8 %v) {
+; CHECK-LABEL: Function: t4
+; CHECK: MayAlias: <2 x float>* %gep1, <2 x float>* %gep2
+ %a = or i8 %v, 1
+ %b = and i8 %v, 2
+
+ %1 = alloca [8 x float], align 4
+ %or1 = shl nuw i8 %a, 1
+ %2 = zext nneg i8 %or1 to i64
+ %gep1 = getelementptr inbounds float, ptr %1, i64 %2
+
+ %m = mul nsw nuw i8 %b, 2
+ %3 = sext i8 %m to i16
+ %4 = zext i16 %3 to i64
+ %gep2 = getelementptr inbounds float, ptr %1, i64 %4
+
+ load <2 x float>, ptr %gep1
+ load <2 x float>, ptr %gep2
+ ret void
+}
+
+;; Verify a zext nneg and a zext are understood as the same
+define void @t5(ptr %p, i16 %i) {
+; CHECK-LABEL: Function: t5
+; CHECK: NoAlias: i32* %pi, i32* %pi.next
+ %i1 = zext nneg i16 %i to i32
+ %pi = getelementptr i32, ptr %p, i32 %i1
+
+ %i.next = add i16 %i, 1
+ %i.next2 = zext i16 %i.next to i32
+ %pi.next = getelementptr i32, ptr %p, i32 %i.next2
+
+ load i32, ptr %pi
+ load i32, ptr %pi.next
+ ret void
+}
+
+;; This is not very idiomatic, but still possible. Verify that the nneg is
+;; propagated outward and that NoAlias is correctly identified.
+define void @t6(i8 %a) {
+; CHECK-LABEL: Function: t6
+; CHECK: NoAlias: float* %gep1, float* %gep2
+ %1 = alloca [8 x float], align 4
+ %a.add = add i8 %a, 1
+ %2 = zext nneg i8 %a.add to i16
+ %3 = sext i16 %2 to i32
+ %4 = zext i32 %3 to i64
+ %gep1 = getelementptr inbounds float, ptr %1, i64 %4
+
+ %5 = sext i8 %a to i64
+ %gep2 = getelementptr inbounds float, ptr %1, i64 %5
+
+ load float, ptr %gep1
+ load float, ptr %gep2
+ ret void
+}
+
+;; This is even less idiomatic, but still possible. Verify that the nneg is not
+;; propagated inward and that MayAlias is correctly identified.
+define void @t7(i8 %a) {
+; CHECK-LABEL: Function: t7
+; CHECK: MayAlias: float* %gep1, float* %gep2
+ %1 = alloca [8 x float], align 4
+ %a.add = add i8 %a, 1
+ %2 = zext i8 %a.add to i16
+ %3 = sext i16 %2 to i32
+ %4 = zext nneg i32 %3 to i64
+ %gep1 = getelementptr inbounds float, ptr %1, i64 %4
+
+ %5 = sext i8 %a to i64
+ %gep2 = getelementptr inbounds float, ptr %1, i64 %5
+
+ load float, ptr %gep1
+ load float, ptr %gep2
+ ret void
+}
+
+;; Verify the nneg survives an implicit trunc of fewer bits than the zext.
+define void @t8(i8 %a) {
+; CHECK-LABEL: Function: t8
+; CHECK: NoAlias: float* %gep1, float* %gep2
+ %1 = alloca [8 x float], align 4
+ %a.add = add i8 %a, 1
+ %2 = zext nneg i8 %a.add to i128
+ %gep1 = getelementptr inbounds float, ptr %1, i128 %2
+
+ %3 = sext i8 %a to i64
+ %gep2 = getelementptr inbounds float, ptr %1, i64 %3
+
+ load float, ptr %gep1
+ load float, ptr %gep2
+ ret void
+}
+
+;; Ensure that the nneg is never propagated past this trunc and that these
+;; casted values are understood as non-equal.
+define void @t9(i8 %a) {
+; CHECK-LABEL: Function: t9
+; CHECK: MayAlias: float* %gep1, float* %gep2
+ %1 = alloca [8 x float], align 4
+ %a.add = add i8 %a, 1
+ %2 = zext i8 %a.add to i16
+ %3 = trunc i16 %2 to i1
+ %4 = zext nneg i1 %3 to i64
+ %gep1 = getelementptr inbounds float, ptr %1, i64 %4
+
+ %5 = sext i8 %a to i64
+ %gep2 = getelementptr inbounds float, ptr %1, i64 %5
+
+ load float, ptr %gep1
+ load float, ptr %gep2
+ ret void
+}
More information about the llvm-commits
mailing list