[llvm] d881bac - [BasicAA] Consider 'nneg' flag when comparing CastedValues (#94129)

via llvm-commits llvm-commits at lists.llvm.org
Tue Jun 4 08:33:00 PDT 2024


Author: Alex MacLean
Date: 2024-06-04T08:32:57-07:00
New Revision: d881bac6fa3b1d8d622d4fb651060cf7d6223080

URL: https://github.com/llvm/llvm-project/commit/d881bac6fa3b1d8d622d4fb651060cf7d6223080
DIFF: https://github.com/llvm/llvm-project/commit/d881bac6fa3b1d8d622d4fb651060cf7d6223080.diff

LOG: [BasicAA] Consider 'nneg' flag when comparing CastedValues (#94129)

Any of the `zext` bits in a `zext nneg` can be converted to `sext` but
when checking if casts are compatible `BasicAA` fails to take into
account `nneg`. This change adds tracking of `nneg` to the `CastedValue`
struct and ensures that `sext` and `zext` bits are treated as
interchangeable when either `CastedValue` has a `nneg`. When
distributing casted values in `GetLinearExpression` we conservatively
discard the `nneg` from the `CastedValue`, except in the case of `shl
nsw`, where we know the sign has not changed to negative.

Added: 
    llvm/test/Analysis/BasicAA/zext-nneg.ll

Modified: 
    llvm/lib/Analysis/BasicAliasAnalysis.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Analysis/BasicAliasAnalysis.cpp b/llvm/lib/Analysis/BasicAliasAnalysis.cpp
index 3f456db1c51ac..c110943ad0d58 100644
--- a/llvm/lib/Analysis/BasicAliasAnalysis.cpp
+++ b/llvm/lib/Analysis/BasicAliasAnalysis.cpp
@@ -268,31 +268,43 @@ struct CastedValue {
   unsigned ZExtBits = 0;
   unsigned SExtBits = 0;
   unsigned TruncBits = 0;
+  /// Whether trunc(V) is non-negative.
+  bool IsNonNegative = false;
 
   explicit CastedValue(const Value *V) : V(V) {}
   explicit CastedValue(const Value *V, unsigned ZExtBits, unsigned SExtBits,
-                       unsigned TruncBits)
-      : V(V), ZExtBits(ZExtBits), SExtBits(SExtBits), TruncBits(TruncBits) {}
+                       unsigned TruncBits, bool IsNonNegative)
+      : V(V), ZExtBits(ZExtBits), SExtBits(SExtBits), TruncBits(TruncBits),
+        IsNonNegative(IsNonNegative) {}
 
   unsigned getBitWidth() const {
     return V->getType()->getPrimitiveSizeInBits() - TruncBits + ZExtBits +
            SExtBits;
   }
 
-  CastedValue withValue(const Value *NewV) const {
-    return CastedValue(NewV, ZExtBits, SExtBits, TruncBits);
+  CastedValue withValue(const Value *NewV, bool PreserveNonNeg) const {
+    return CastedValue(NewV, ZExtBits, SExtBits, TruncBits,
+                       IsNonNegative && PreserveNonNeg);
   }
 
   /// Replace V with zext(NewV)
-  CastedValue withZExtOfValue(const Value *NewV) const {
+  CastedValue withZExtOfValue(const Value *NewV, bool ZExtNonNegative) const {
     unsigned ExtendBy = V->getType()->getPrimitiveSizeInBits() -
                         NewV->getType()->getPrimitiveSizeInBits();
     if (ExtendBy <= TruncBits)
-      return CastedValue(NewV, ZExtBits, SExtBits, TruncBits - ExtendBy);
+      // zext<nneg>(trunc(zext(NewV))) == zext<nneg>(trunc(NewV))
+      // The nneg can be preserved on the outer zext here.
+      return CastedValue(NewV, ZExtBits, SExtBits, TruncBits - ExtendBy,
+                         IsNonNegative);
 
     // zext(sext(zext(NewV))) == zext(zext(zext(NewV)))
     ExtendBy -= TruncBits;
-    return CastedValue(NewV, ZExtBits + SExtBits + ExtendBy, 0, 0);
+    // zext<nneg>(zext(NewV)) == zext(NewV)
+    // zext(zext<nneg>(NewV)) == zext<nneg>(NewV)
+    // The nneg can be preserved from the inner zext here but must be dropped
+    // from the outer.
+    return CastedValue(NewV, ZExtBits + SExtBits + ExtendBy, 0, 0,
+                       ZExtNonNegative);
   }
 
   /// Replace V with sext(NewV)
@@ -300,11 +312,16 @@ struct CastedValue {
     unsigned ExtendBy = V->getType()->getPrimitiveSizeInBits() -
                         NewV->getType()->getPrimitiveSizeInBits();
     if (ExtendBy <= TruncBits)
-      return CastedValue(NewV, ZExtBits, SExtBits, TruncBits - ExtendBy);
+      // zext<nneg>(trunc(sext(NewV))) == zext<nneg>(trunc(NewV))
+      // The nneg can be preserved on the outer zext here
+      return CastedValue(NewV, ZExtBits, SExtBits, TruncBits - ExtendBy,
+                         IsNonNegative);
 
     // zext(sext(sext(NewV)))
     ExtendBy -= TruncBits;
-    return CastedValue(NewV, ZExtBits, SExtBits + ExtendBy, 0);
+    // zext<nneg>(sext(sext(NewV))) = zext<nneg>(sext(NewV))
+    // The nneg can be preserved on the outer zext here
+    return CastedValue(NewV, ZExtBits, SExtBits + ExtendBy, 0, IsNonNegative);
   }
 
   APInt evaluateWith(APInt N) const {
@@ -333,8 +350,15 @@ struct CastedValue {
   }
 
   bool hasSameCastsAs(const CastedValue &Other) const {
-    return ZExtBits == Other.ZExtBits && SExtBits == Other.SExtBits &&
-           TruncBits == Other.TruncBits;
+    if (ZExtBits == Other.ZExtBits && SExtBits == Other.SExtBits &&
+        TruncBits == Other.TruncBits)
+      return true;
+    // If either CastedValue has a nneg zext then the sext/zext bits are
+    // interchangeable for that value.
+    if (IsNonNegative || Other.IsNonNegative)
+      return (ZExtBits + SExtBits == Other.ZExtBits + Other.SExtBits &&
+              TruncBits == Other.TruncBits);
+    return false;
   }
 };
 
@@ -410,21 +434,21 @@ static LinearExpression GetLinearExpression(
 
         [[fallthrough]];
       case Instruction::Add: {
-        E = GetLinearExpression(Val.withValue(BOp->getOperand(0)), DL,
+        E = GetLinearExpression(Val.withValue(BOp->getOperand(0), false), DL,
                                 Depth + 1, AC, DT);
         E.Offset += RHS;
         E.IsNSW &= NSW;
         break;
       }
       case Instruction::Sub: {
-        E = GetLinearExpression(Val.withValue(BOp->getOperand(0)), DL,
+        E = GetLinearExpression(Val.withValue(BOp->getOperand(0), false), DL,
                                 Depth + 1, AC, DT);
         E.Offset -= RHS;
         E.IsNSW &= NSW;
         break;
       }
       case Instruction::Mul:
-        E = GetLinearExpression(Val.withValue(BOp->getOperand(0)), DL,
+        E = GetLinearExpression(Val.withValue(BOp->getOperand(0), false), DL,
                                 Depth + 1, AC, DT)
                 .mul(RHS, NSW);
         break;
@@ -437,7 +461,7 @@ static LinearExpression GetLinearExpression(
         if (RHS.getLimitedValue() > Val.getBitWidth())
           return Val;
 
-        E = GetLinearExpression(Val.withValue(BOp->getOperand(0)), DL,
+        E = GetLinearExpression(Val.withValue(BOp->getOperand(0), NSW), DL,
                                 Depth + 1, AC, DT);
         E.Offset <<= RHS.getLimitedValue();
         E.Scale <<= RHS.getLimitedValue();
@@ -448,10 +472,10 @@ static LinearExpression GetLinearExpression(
     }
   }
 
-  if (isa<ZExtInst>(Val.V))
+  if (const auto *ZExt = dyn_cast<ZExtInst>(Val.V))
     return GetLinearExpression(
-        Val.withZExtOfValue(cast<CastInst>(Val.V)->getOperand(0)),
-        DL, Depth + 1, AC, DT);
+        Val.withZExtOfValue(ZExt->getOperand(0), ZExt->hasNonNeg()), DL,
+        Depth + 1, AC, DT);
 
   if (isa<SExtInst>(Val.V))
     return GetLinearExpression(
@@ -673,7 +697,7 @@ BasicAAResult::DecomposeGEPExpression(const Value *V, const DataLayout &DL,
       unsigned SExtBits = IndexSize > Width ? IndexSize - Width : 0;
       unsigned TruncBits = IndexSize < Width ? Width - IndexSize : 0;
       LinearExpression LE = GetLinearExpression(
-          CastedValue(Index, 0, SExtBits, TruncBits), DL, 0, AC, DT);
+          CastedValue(Index, 0, SExtBits, TruncBits, false), DL, 0, AC, DT);
 
       // Scale by the type size.
       unsigned TypeSize = AllocTypeSize.getFixedValue();

diff  --git a/llvm/test/Analysis/BasicAA/zext-nneg.ll b/llvm/test/Analysis/BasicAA/zext-nneg.ll
new file mode 100644
index 0000000000000..808bb1a8c9d96
--- /dev/null
+++ b/llvm/test/Analysis/BasicAA/zext-nneg.ll
@@ -0,0 +1,181 @@
+; RUN: opt < %s -aa-pipeline=basic-aa -passes=aa-eval -print-all-alias-modref-info -disable-output 2>&1 | FileCheck %s
+
+;; Simple case: a zext nneg can be replaced with a sext. Make sure BasicAA
+;; understands that.
+define void @t1(i32 %a, i32 %b) {
+; CHECK-LABEL: Function: t1
+; CHECK: NoAlias: float* %gep1, float* %gep2
+
+  %1 = alloca [8 x float], align 4
+  %or1 = or i32 %a, 1
+  %2 = sext i32 %or1 to i64
+  %gep1 = getelementptr inbounds float, ptr %1, i64 %2
+
+  %shl1 = shl i32 %b, 1
+  %3 = zext nneg i32 %shl1 to i64
+  %gep2 = getelementptr inbounds float, ptr %1, i64 %3
+
+  load float, ptr %gep1
+  load float, ptr %gep2
+  ret void
+}
+
+;; A (zext nneg (sext V)) is equivalent to a (zext (sext V)) as long as the
+;; total number of zext+sext bits is the same for both.
+define void @t2(i8 %a, i8 %b) {
+; CHECK-LABEL: Function: t2
+; CHECK: NoAlias: float* %gep1, float* %gep2
+  %1 = alloca [8 x float], align 4
+  %or1 = or i8 %a, 1
+  %2 = sext i8 %or1 to i32
+  %3 = zext i32 %2 to i64
+  %gep1 = getelementptr inbounds float, ptr %1, i64 %3
+
+  %shl1 = shl i8 %b, 1
+  %4 = sext i8 %shl1 to i16
+  %5 = zext nneg i16 %4 to i64
+  %gep2 = getelementptr inbounds float, ptr %1, i64 %5
+
+  load float, ptr %gep1
+  load float, ptr %gep2
+  ret void
+}
+
+;; Here the %a and %b are knowably non-equal. In this case we can distribute
+;; the zext, preserving the nneg flag, through the shl because it has a nsw flag
+define void @t3(i8 %v) {
+; CHECK-LABEL: Function: t3
+; CHECK: NoAlias: <2 x float>* %gep1, <2 x float>* %gep2
+  %a = or i8 %v, 1
+  %b = and i8 %v, 2
+
+  %1 = alloca [8 x float], align 4
+  %or1 = shl nuw nsw i8 %a, 1
+  %2 = zext nneg i8 %or1 to i64
+  %gep1 = getelementptr inbounds float, ptr %1, i64 %2
+
+  %m = mul nsw nuw i8 %b, 2
+  %3 = sext i8 %m to i16
+  %4 = zext i16 %3 to i64
+  %gep2 = getelementptr inbounds float, ptr %1, i64 %4
+
+  load <2 x float>, ptr %gep1
+  load <2 x float>, ptr %gep2
+  ret void
+}
+
+;; This is the same as above, but this time the shl does not have the nsw flag,
+;; so the nneg cannot be kept on the zext.
+define void @t4(i8 %v) {
+; CHECK-LABEL: Function: t4
+; CHECK: MayAlias: <2 x float>* %gep1, <2 x float>* %gep2
+  %a = or i8 %v, 1
+  %b = and i8 %v, 2
+
+  %1 = alloca [8 x float], align 4
+  %or1 = shl nuw i8 %a, 1
+  %2 = zext nneg i8 %or1 to i64
+  %gep1 = getelementptr inbounds float, ptr %1, i64 %2
+
+  %m = mul nsw nuw i8 %b, 2
+  %3 = sext i8 %m to i16
+  %4 = zext i16 %3 to i64
+  %gep2 = getelementptr inbounds float, ptr %1, i64 %4
+
+  load <2 x float>, ptr %gep1
+  load <2 x float>, ptr %gep2
+  ret void
+}
+
+;; Verify a zext nneg and a zext are understood as the same
+define void @t5(ptr %p, i16 %i) {
+; CHECK-LABEL: Function: t5
+; CHECK: NoAlias: i32* %pi, i32* %pi.next
+  %i1 = zext nneg i16 %i to i32
+  %pi = getelementptr i32, ptr %p, i32 %i1
+
+  %i.next = add i16 %i, 1
+  %i.next2 = zext i16 %i.next to i32
+  %pi.next = getelementptr i32, ptr %p, i32 %i.next2
+
+  load i32, ptr %pi
+  load i32, ptr %pi.next
+  ret void
+}
+
+;; This is not very idiomatic, but still possible; verify the nneg is propagated
+;; outward and that NoAlias is correctly identified.
+define void @t6(i8 %a) {
+; CHECK-LABEL: Function: t6
+; CHECK: NoAlias: float* %gep1, float* %gep2
+  %1 = alloca [8 x float], align 4
+  %a.add = add i8 %a, 1
+  %2 = zext nneg i8 %a.add to i16
+  %3 = sext i16 %2 to i32
+  %4 = zext i32 %3 to i64
+  %gep1 = getelementptr inbounds float, ptr %1, i64 %4
+
+  %5 = sext i8 %a to i64
+  %gep2 = getelementptr inbounds float, ptr %1, i64 %5
+
+  load float, ptr %gep1
+  load float, ptr %gep2
+  ret void
+}
+
+;; This is even less idiomatic, but still possible; verify the nneg is not
+;; propagated inward and that MayAlias is correctly identified.
+define void @t7(i8 %a) {
+; CHECK-LABEL: Function: t7
+; CHECK: MayAlias: float* %gep1, float* %gep2
+  %1 = alloca [8 x float], align 4
+  %a.add = add i8 %a, 1
+  %2 = zext i8 %a.add to i16
+  %3 = sext i16 %2 to i32
+  %4 = zext nneg i32 %3 to i64
+  %gep1 = getelementptr inbounds float, ptr %1, i64 %4
+
+  %5 = sext i8 %a to i64
+  %gep2 = getelementptr inbounds float, ptr %1, i64 %5
+
+  load float, ptr %gep1
+  load float, ptr %gep2
+  ret void
+}
+
+;; Verify the nneg survives an implicit trunc of fewer bits than the zext.
+define void @t8(i8 %a) {
+; CHECK-LABEL: Function: t8
+; CHECK: NoAlias: float* %gep1, float* %gep2
+  %1 = alloca [8 x float], align 4
+  %a.add = add i8 %a, 1
+  %2 = zext nneg i8 %a.add to i128
+  %gep1 = getelementptr inbounds float, ptr %1, i128 %2
+
+  %3 = sext i8 %a to i64
+  %gep2 = getelementptr inbounds float, ptr %1, i64 %3
+
+  load float, ptr %gep1
+  load float, ptr %gep2
+  ret void
+}
+
+;; Ensure that the nneg is never propagated past this trunc and that these
+;; casted values are understood as non-equal.
+define void @t9(i8 %a) {
+; CHECK-LABEL: Function: t9
+; CHECK: MayAlias: float* %gep1, float* %gep2
+  %1 = alloca [8 x float], align 4
+  %a.add = add i8 %a, 1
+  %2 = zext i8 %a.add to i16
+  %3 = trunc i16 %2 to i1
+  %4 = zext nneg i1 %3 to i64
+  %gep1 = getelementptr inbounds float, ptr %1, i64 %4
+
+  %5 = sext i8 %a to i64
+  %gep2 = getelementptr inbounds float, ptr %1, i64 %5
+
+  load float, ptr %gep1
+  load float, ptr %gep2
+  ret void
+}


        


More information about the llvm-commits mailing list