[llvm] [BasicAA] Consider 'nneg' flag when comparing CastedValues (PR #94129)

via llvm-commits llvm-commits at lists.llvm.org
Sat Jun 1 17:39:28 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-analysis

Author: Alex MacLean (AlexMaclean)

<details>
<summary>Changes</summary>

The extension bits of a `zext nneg` can equivalently be treated as `sext` bits, but when checking whether two casts are compatible, `BasicAA` does not take `nneg` into account. This change adds tracking of `nneg` to the `CastedValue` struct and ensures that `sext` and `zext` bits are treated as interchangeable when either `CastedValue` has the `nneg` flag. When distributing casted values in `GetLinearExpression`, we conservatively discard the `nneg` from the `CastedValue`, except in the case of `shl nsw`, where the shifted operand is known to have the same sign as the (non-negative) result.

---
Full diff: https://github.com/llvm/llvm-project/pull/94129.diff


2 Files Affected:

- (modified) llvm/lib/Analysis/BasicAliasAnalysis.cpp (+41-17) 
- (added) llvm/test/Analysis/BasicAA/zext-nneg.ll (+181) 


``````````diff
diff --git a/llvm/lib/Analysis/BasicAliasAnalysis.cpp b/llvm/lib/Analysis/BasicAliasAnalysis.cpp
index 3f456db1c51ac..826706d1306a9 100644
--- a/llvm/lib/Analysis/BasicAliasAnalysis.cpp
+++ b/llvm/lib/Analysis/BasicAliasAnalysis.cpp
@@ -268,31 +268,42 @@ struct CastedValue {
   unsigned ZExtBits = 0;
   unsigned SExtBits = 0;
   unsigned TruncBits = 0;
+  bool IsNonNegative = false;
 
   explicit CastedValue(const Value *V) : V(V) {}
   explicit CastedValue(const Value *V, unsigned ZExtBits, unsigned SExtBits,
-                       unsigned TruncBits)
-      : V(V), ZExtBits(ZExtBits), SExtBits(SExtBits), TruncBits(TruncBits) {}
+                       unsigned TruncBits, bool IsNonNegative)
+      : V(V), ZExtBits(ZExtBits), SExtBits(SExtBits), TruncBits(TruncBits),
+        IsNonNegative(IsNonNegative) {}
 
   unsigned getBitWidth() const {
     return V->getType()->getPrimitiveSizeInBits() - TruncBits + ZExtBits +
            SExtBits;
   }
 
-  CastedValue withValue(const Value *NewV) const {
-    return CastedValue(NewV, ZExtBits, SExtBits, TruncBits);
+  CastedValue withValue(const Value *NewV, bool PreserveNonNeg) const {
+    return CastedValue(NewV, ZExtBits, SExtBits, TruncBits,
+                       IsNonNegative && PreserveNonNeg);
   }
 
   /// Replace V with zext(NewV)
-  CastedValue withZExtOfValue(const Value *NewV) const {
+  CastedValue withZExtOfValue(const Value *NewV, bool ZExtNonNegative) const {
     unsigned ExtendBy = V->getType()->getPrimitiveSizeInBits() -
                         NewV->getType()->getPrimitiveSizeInBits();
     if (ExtendBy <= TruncBits)
-      return CastedValue(NewV, ZExtBits, SExtBits, TruncBits - ExtendBy);
+      // zext<nneg>(trunc(zext(NewV))) == zext<nneg>(trunc(NewV))
+      // The nneg can be preserved on the outer zext here
+      return CastedValue(NewV, ZExtBits, SExtBits, TruncBits - ExtendBy,
+                         IsNonNegative);
 
     // zext(sext(zext(NewV))) == zext(zext(zext(NewV)))
     ExtendBy -= TruncBits;
-    return CastedValue(NewV, ZExtBits + SExtBits + ExtendBy, 0, 0);
+    // zext<nneg>(zext(NewV)) == zext(NewV)
+    // zext(zext<nneg>(NewV)) == zext<nneg>(NewV)
+    // The nneg can be preserved from the inner zext here but must be dropped
+    // from the outer.
+    return CastedValue(NewV, ZExtBits + SExtBits + ExtendBy, 0, 0,
+                       ZExtNonNegative);
   }
 
   /// Replace V with sext(NewV)
@@ -300,11 +311,16 @@ struct CastedValue {
     unsigned ExtendBy = V->getType()->getPrimitiveSizeInBits() -
                         NewV->getType()->getPrimitiveSizeInBits();
     if (ExtendBy <= TruncBits)
-      return CastedValue(NewV, ZExtBits, SExtBits, TruncBits - ExtendBy);
+      // zext<nneg>(trunc(sext(NewV))) == zext<nneg>(trunc(NewV))
+      // The nneg can be preserved on the outer zext here
+      return CastedValue(NewV, ZExtBits, SExtBits, TruncBits - ExtendBy,
+                         IsNonNegative);
 
     // zext(sext(sext(NewV)))
     ExtendBy -= TruncBits;
-    return CastedValue(NewV, ZExtBits, SExtBits + ExtendBy, 0);
+    // zext<nneg>(sext(sext(NewV))) = zext<nneg>(sext(NewV))
+    // The nneg can be preserved on the outer zext here
+    return CastedValue(NewV, ZExtBits, SExtBits + ExtendBy, 0, IsNonNegative);
   }
 
   APInt evaluateWith(APInt N) const {
@@ -333,8 +349,15 @@ struct CastedValue {
   }
 
   bool hasSameCastsAs(const CastedValue &Other) const {
-    return ZExtBits == Other.ZExtBits && SExtBits == Other.SExtBits &&
-           TruncBits == Other.TruncBits;
+    if (ZExtBits == Other.ZExtBits && SExtBits == Other.SExtBits &&
+        TruncBits == Other.TruncBits)
+      return true;
+    // If either CastedValue has a nneg zext, then the sext/zext bits are
+    // interchangeable for that value.
+    if (IsNonNegative || Other.IsNonNegative)
+      return (ZExtBits + SExtBits == Other.ZExtBits + Other.SExtBits &&
+              TruncBits == Other.TruncBits);
+    return false;
   }
 };
 
@@ -410,21 +433,21 @@ static LinearExpression GetLinearExpression(
 
         [[fallthrough]];
       case Instruction::Add: {
-        E = GetLinearExpression(Val.withValue(BOp->getOperand(0)), DL,
+        E = GetLinearExpression(Val.withValue(BOp->getOperand(0), false), DL,
                                 Depth + 1, AC, DT);
         E.Offset += RHS;
         E.IsNSW &= NSW;
         break;
       }
       case Instruction::Sub: {
-        E = GetLinearExpression(Val.withValue(BOp->getOperand(0)), DL,
+        E = GetLinearExpression(Val.withValue(BOp->getOperand(0), false), DL,
                                 Depth + 1, AC, DT);
         E.Offset -= RHS;
         E.IsNSW &= NSW;
         break;
       }
       case Instruction::Mul:
-        E = GetLinearExpression(Val.withValue(BOp->getOperand(0)), DL,
+        E = GetLinearExpression(Val.withValue(BOp->getOperand(0), false), DL,
                                 Depth + 1, AC, DT)
                 .mul(RHS, NSW);
         break;
@@ -437,7 +460,7 @@ static LinearExpression GetLinearExpression(
         if (RHS.getLimitedValue() > Val.getBitWidth())
           return Val;
 
-        E = GetLinearExpression(Val.withValue(BOp->getOperand(0)), DL,
+        E = GetLinearExpression(Val.withValue(BOp->getOperand(0), NSW), DL,
                                 Depth + 1, AC, DT);
         E.Offset <<= RHS.getLimitedValue();
         E.Scale <<= RHS.getLimitedValue();
@@ -450,7 +473,8 @@ static LinearExpression GetLinearExpression(
 
   if (isa<ZExtInst>(Val.V))
     return GetLinearExpression(
-        Val.withZExtOfValue(cast<CastInst>(Val.V)->getOperand(0)),
+        Val.withZExtOfValue(cast<CastInst>(Val.V)->getOperand(0),
+                            cast<ZExtInst>(Val.V)->hasNonNeg()),
         DL, Depth + 1, AC, DT);
 
   if (isa<SExtInst>(Val.V))
@@ -673,7 +697,7 @@ BasicAAResult::DecomposeGEPExpression(const Value *V, const DataLayout &DL,
       unsigned SExtBits = IndexSize > Width ? IndexSize - Width : 0;
       unsigned TruncBits = IndexSize < Width ? Width - IndexSize : 0;
       LinearExpression LE = GetLinearExpression(
-          CastedValue(Index, 0, SExtBits, TruncBits), DL, 0, AC, DT);
+          CastedValue(Index, 0, SExtBits, TruncBits, false), DL, 0, AC, DT);
 
       // Scale by the type size.
       unsigned TypeSize = AllocTypeSize.getFixedValue();
diff --git a/llvm/test/Analysis/BasicAA/zext-nneg.ll b/llvm/test/Analysis/BasicAA/zext-nneg.ll
new file mode 100644
index 0000000000000..808bb1a8c9d96
--- /dev/null
+++ b/llvm/test/Analysis/BasicAA/zext-nneg.ll
@@ -0,0 +1,181 @@
+; RUN: opt < %s -aa-pipeline=basic-aa -passes=aa-eval -print-all-alias-modref-info -disable-output 2>&1 | FileCheck %s
+
+;; Simple case: a zext nneg can be replaced with a sext. Make sure BasicAA
+;; understands that.
+define void @t1(i32 %a, i32 %b) {
+; CHECK-LABEL: Function: t1
+; CHECK: NoAlias: float* %gep1, float* %gep2
+
+  %1 = alloca [8 x float], align 4
+  %or1 = or i32 %a, 1
+  %2 = sext i32 %or1 to i64
+  %gep1 = getelementptr inbounds float, ptr %1, i64 %2
+
+  %shl1 = shl i32 %b, 1
+  %3 = zext nneg i32 %shl1 to i64
+  %gep2 = getelementptr inbounds float, ptr %1, i64 %3
+
+  load float, ptr %gep1
+  load float, ptr %gep2
+  ret void
+}
+
+;; A (zext nneg (sext V)) is equivalent to a (zext (sext V)) as long as the
+;; total number of zext+sext bits is the same for both.
+define void @t2(i8 %a, i8 %b) {
+; CHECK-LABEL: Function: t2
+; CHECK: NoAlias: float* %gep1, float* %gep2
+  %1 = alloca [8 x float], align 4
+  %or1 = or i8 %a, 1
+  %2 = sext i8 %or1 to i32
+  %3 = zext i32 %2 to i64
+  %gep1 = getelementptr inbounds float, ptr %1, i64 %3
+
+  %shl1 = shl i8 %b, 1
+  %4 = sext i8 %shl1 to i16
+  %5 = zext nneg i16 %4 to i64
+  %gep2 = getelementptr inbounds float, ptr %1, i64 %5
+
+  load float, ptr %gep1
+  load float, ptr %gep2
+  ret void
+}
+
+;; Here %a and %b are known to be non-equal. In this case we can distribute
+;; the zext through the shl, preserving the nneg flag, since the shl is nsw.
+define void @t3(i8 %v) {
+; CHECK-LABEL: Function: t3
+; CHECK: NoAlias: <2 x float>* %gep1, <2 x float>* %gep2
+  %a = or i8 %v, 1
+  %b = and i8 %v, 2
+
+  %1 = alloca [8 x float], align 4
+  %or1 = shl nuw nsw i8 %a, 1
+  %2 = zext nneg i8 %or1 to i64
+  %gep1 = getelementptr inbounds float, ptr %1, i64 %2
+
+  %m = mul nsw nuw i8 %b, 2
+  %3 = sext i8 %m to i16
+  %4 = zext i16 %3 to i64
+  %gep2 = getelementptr inbounds float, ptr %1, i64 %4
+
+  load <2 x float>, ptr %gep1
+  load <2 x float>, ptr %gep2
+  ret void
+}
+
+;; This is the same as above, but this time the shl does not have the nsw
+;; flag, so the nneg cannot be kept on the zext.
+define void @t4(i8 %v) {
+; CHECK-LABEL: Function: t4
+; CHECK: MayAlias: <2 x float>* %gep1, <2 x float>* %gep2
+  %a = or i8 %v, 1
+  %b = and i8 %v, 2
+
+  %1 = alloca [8 x float], align 4
+  %or1 = shl nuw i8 %a, 1
+  %2 = zext nneg i8 %or1 to i64
+  %gep1 = getelementptr inbounds float, ptr %1, i64 %2
+
+  %m = mul nsw nuw i8 %b, 2
+  %3 = sext i8 %m to i16
+  %4 = zext i16 %3 to i64
+  %gep2 = getelementptr inbounds float, ptr %1, i64 %4
+
+  load <2 x float>, ptr %gep1
+  load <2 x float>, ptr %gep2
+  ret void
+}
+
+;; Verify that a zext nneg and a plain zext are treated as the same cast.
+define void @t5(ptr %p, i16 %i) {
+; CHECK-LABEL: Function: t5
+; CHECK: NoAlias: i32* %pi, i32* %pi.next
+  %i1 = zext nneg i16 %i to i32
+  %pi = getelementptr i32, ptr %p, i32 %i1
+
+  %i.next = add i16 %i, 1
+  %i.next2 = zext i16 %i.next to i32
+  %pi.next = getelementptr i32, ptr %p, i32 %i.next2
+
+  load i32, ptr %pi
+  load i32, ptr %pi.next
+  ret void
+}
+
+;; This is not very idiomatic, but still possible. Verify that the nneg is
+;; propagated outward and that NoAlias is correctly identified.
+define void @t6(i8 %a) {
+; CHECK-LABEL: Function: t6
+; CHECK: NoAlias: float* %gep1, float* %gep2
+  %1 = alloca [8 x float], align 4
+  %a.add = add i8 %a, 1
+  %2 = zext nneg i8 %a.add to i16
+  %3 = sext i16 %2 to i32
+  %4 = zext i32 %3 to i64
+  %gep1 = getelementptr inbounds float, ptr %1, i64 %4
+
+  %5 = sext i8 %a to i64
+  %gep2 = getelementptr inbounds float, ptr %1, i64 %5
+
+  load float, ptr %gep1
+  load float, ptr %gep2
+  ret void
+}
+
+;; This is even less idiomatic, but still possible. Verify that the nneg is not
+;; propagated inward and that MayAlias is correctly identified.
+define void @t7(i8 %a) {
+; CHECK-LABEL: Function: t7
+; CHECK: MayAlias: float* %gep1, float* %gep2
+  %1 = alloca [8 x float], align 4
+  %a.add = add i8 %a, 1
+  %2 = zext i8 %a.add to i16
+  %3 = sext i16 %2 to i32
+  %4 = zext nneg i32 %3 to i64
+  %gep1 = getelementptr inbounds float, ptr %1, i64 %4
+
+  %5 = sext i8 %a to i64
+  %gep2 = getelementptr inbounds float, ptr %1, i64 %5
+
+  load float, ptr %gep1
+  load float, ptr %gep2
+  ret void
+}
+
+;; Verify the nneg survives an implicit trunc of fewer bits than the zext.
+define void @t8(i8 %a) {
+; CHECK-LABEL: Function: t8
+; CHECK: NoAlias: float* %gep1, float* %gep2
+  %1 = alloca [8 x float], align 4
+  %a.add = add i8 %a, 1
+  %2 = zext nneg i8 %a.add to i128
+  %gep1 = getelementptr inbounds float, ptr %1, i128 %2
+
+  %3 = sext i8 %a to i64
+  %gep2 = getelementptr inbounds float, ptr %1, i64 %3
+
+  load float, ptr %gep1
+  load float, ptr %gep2
+  ret void
+}
+
+;; Ensure that the nneg is never propagated past this trunc and that these
+;; casted values are not treated as equivalent.
+define void @t9(i8 %a) {
+; CHECK-LABEL: Function: t9
+; CHECK: MayAlias: float* %gep1, float* %gep2
+  %1 = alloca [8 x float], align 4
+  %a.add = add i8 %a, 1
+  %2 = zext i8 %a.add to i16
+  %3 = trunc i16 %2 to i1
+  %4 = zext nneg i1 %3 to i64
+  %gep1 = getelementptr inbounds float, ptr %1, i64 %4
+
+  %5 = sext i8 %a to i64
+  %gep2 = getelementptr inbounds float, ptr %1, i64 %5
+
+  load float, ptr %gep1
+  load float, ptr %gep2
+  ret void
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/94129


More information about the llvm-commits mailing list