[llvm] 91fa356 - [BasicAA] Be more careful with modulo ops on VariableGEPIndex.

Florian Hahn via llvm-commits llvm-commits at lists.llvm.org
Tue Jun 29 01:24:19 PDT 2021


Author: Florian Hahn
Date: 2021-06-29T09:22:36+01:00
New Revision: 91fa3565da16f77e07270e5323874abc22661cb0

URL: https://github.com/llvm/llvm-project/commit/91fa3565da16f77e07270e5323874abc22661cb0
DIFF: https://github.com/llvm/llvm-project/commit/91fa3565da16f77e07270e5323874abc22661cb0.diff

LOG: [BasicAA] Be more careful with modulo ops on VariableGEPIndex.

(V * Scale) % X may not produce the same result for every possible value
of V, e.g. if the multiplication overflows. This means we currently
incorrectly determine NoAlias in some cases.

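This patch updates LinearExpression to track whether all operations in
the expression are NSW, and uses that to adjust the scale used for the
GCD-based alias checks: if an index is not known to be NSW, only the
largest power-of-two factor of its scale is used, since that factor is
preserved even when the arithmetic wraps.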

Reviewed By: nikic

Differential Revision: https://reviews.llvm.org/D99424

Added: 
    

Modified: 
    llvm/include/llvm/Analysis/BasicAliasAnalysis.h
    llvm/lib/Analysis/BasicAliasAnalysis.cpp
    llvm/test/Analysis/BasicAA/gep-modulo.ll

Removed: 
    


################################################################################
diff --git a/llvm/include/llvm/Analysis/BasicAliasAnalysis.h b/llvm/include/llvm/Analysis/BasicAliasAnalysis.h
index 1468ad89c333f..991c0cbb642aa 100644
--- a/llvm/include/llvm/Analysis/BasicAliasAnalysis.h
+++ b/llvm/include/llvm/Analysis/BasicAliasAnalysis.h
@@ -116,6 +116,9 @@ class BasicAAResult : public AAResultBase<BasicAAResult> {
     // Context instruction to use when querying information about this index.
     const Instruction *CxtI;
 
+    /// True if all operations in this expression are NSW.
+    bool IsNSW;
+
     void dump() const {
       print(dbgs());
       dbgs() << "\n";

diff --git a/llvm/lib/Analysis/BasicAliasAnalysis.cpp b/llvm/lib/Analysis/BasicAliasAnalysis.cpp
index 356259fe5a7a8..da489b8d457fb 100644
--- a/llvm/lib/Analysis/BasicAliasAnalysis.cpp
+++ b/llvm/lib/Analysis/BasicAliasAnalysis.cpp
@@ -284,11 +284,14 @@ struct LinearExpression {
   APInt Scale;
   APInt Offset;
 
+  /// True if all operations in this expression are NSW.
+  bool IsNSW;
+
   LinearExpression(const ExtendedValue &Val, const APInt &Scale,
-                   const APInt &Offset)
-      : Val(Val), Scale(Scale), Offset(Offset) {}
+                   const APInt &Offset, bool IsNSW)
+      : Val(Val), Scale(Scale), Offset(Offset), IsNSW(IsNSW) {}
 
-  LinearExpression(const ExtendedValue &Val) : Val(Val) {
+  LinearExpression(const ExtendedValue &Val) : Val(Val), IsNSW(true) {
     unsigned BitWidth = Val.getBitWidth();
     Scale = APInt(BitWidth, 1);
     Offset = APInt(BitWidth, 0);
@@ -307,7 +310,7 @@ static LinearExpression GetLinearExpression(
 
   if (const ConstantInt *Const = dyn_cast<ConstantInt>(Val.V))
     return LinearExpression(Val, APInt(Val.getBitWidth(), 0),
-                            Val.evaluateWith(Const->getValue()));
+                            Val.evaluateWith(Const->getValue()), true);
 
   if (const BinaryOperator *BOp = dyn_cast<BinaryOperator>(Val.V)) {
     if (ConstantInt *RHSC = dyn_cast<ConstantInt>(BOp->getOperand(1))) {
@@ -322,6 +325,7 @@ static LinearExpression GetLinearExpression(
       if (!Val.canDistributeOver(NUW, NSW))
         return Val;
 
+      LinearExpression E(Val);
       switch (BOp->getOpcode()) {
       default:
         // We don't understand this instruction, so we can't decompose it any
@@ -336,23 +340,26 @@ static LinearExpression GetLinearExpression(
 
         LLVM_FALLTHROUGH;
       case Instruction::Add: {
-        LinearExpression E = GetLinearExpression(
-            Val.withValue(BOp->getOperand(0)), DL, Depth + 1, AC, DT);
+        E = GetLinearExpression(Val.withValue(BOp->getOperand(0)), DL,
+                                Depth + 1, AC, DT);
         E.Offset += RHS;
-        return E;
+        E.IsNSW &= NSW;
+        break;
       }
       case Instruction::Sub: {
-        LinearExpression E = GetLinearExpression(
-            Val.withValue(BOp->getOperand(0)), DL, Depth + 1, AC, DT);
+        E = GetLinearExpression(Val.withValue(BOp->getOperand(0)), DL,
+                                Depth + 1, AC, DT);
         E.Offset -= RHS;
-        return E;
+        E.IsNSW &= NSW;
+        break;
       }
       case Instruction::Mul: {
-        LinearExpression E = GetLinearExpression(
-            Val.withValue(BOp->getOperand(0)), DL, Depth + 1, AC, DT);
+        E = GetLinearExpression(Val.withValue(BOp->getOperand(0)), DL,
+                                Depth + 1, AC, DT);
         E.Offset *= RHS;
         E.Scale *= RHS;
-        return E;
+        E.IsNSW &= NSW;
+        break;
       }
       case Instruction::Shl:
         // We're trying to linearize an expression of the kind:
@@ -363,12 +370,14 @@ static LinearExpression GetLinearExpression(
         if (RHS.getLimitedValue() > Val.getBitWidth())
           return Val;
 
-        LinearExpression E = GetLinearExpression(
-            Val.withValue(BOp->getOperand(0)), DL, Depth + 1, AC, DT);
+        E = GetLinearExpression(Val.withValue(BOp->getOperand(0)), DL,
+                                Depth + 1, AC, DT);
         E.Offset <<= RHS.getLimitedValue();
         E.Scale <<= RHS.getLimitedValue();
-        return E;
+        E.IsNSW &= NSW;
+        break;
       }
+      return E;
     }
   }
 
@@ -578,8 +587,8 @@ BasicAAResult::DecomposeGEPExpression(const Value *V, const DataLayout &DL,
       Scale = adjustToPointerSize(Scale, PointerSize);
 
       if (!!Scale) {
-        VariableGEPIndex Entry = {LE.Val.V, LE.Val.ZExtBits, LE.Val.SExtBits,
-                                  Scale, CxtI};
+        VariableGEPIndex Entry = {
+            LE.Val.V, LE.Val.ZExtBits, LE.Val.SExtBits, Scale, CxtI, LE.IsNSW};
         Decomposed.VarIndices.push_back(Entry);
       }
     }
@@ -1138,7 +1147,11 @@ AliasResult BasicAAResult::aliasGEP(
     bool AllNonNegative = DecompGEP1.Offset.isNonNegative();
     bool AllNonPositive = DecompGEP1.Offset.isNonPositive();
     for (unsigned i = 0, e = DecompGEP1.VarIndices.size(); i != e; ++i) {
-      const APInt &Scale = DecompGEP1.VarIndices[i].Scale;
+      APInt Scale = DecompGEP1.VarIndices[i].Scale;
+      if (!DecompGEP1.VarIndices[i].IsNSW)
+        Scale = APInt::getOneBitSet(Scale.getBitWidth(),
+                                    Scale.countTrailingZeros());
+
       if (i == 0)
         GCD = Scale.abs();
       else
@@ -1701,9 +1714,10 @@ void BasicAAResult::GetIndexDifference(
 
       // If we found it, subtract off Scale V's from the entry in Dest.  If it
       // goes to zero, remove the entry.
-      if (Dest[j].Scale != Scale)
+      if (Dest[j].Scale != Scale) {
         Dest[j].Scale -= Scale;
-      else
+        Dest[j].IsNSW = false;
+      } else
         Dest.erase(Dest.begin() + j);
       Scale = 0;
       break;
@@ -1711,7 +1725,8 @@ void BasicAAResult::GetIndexDifference(
 
     // If we didn't consume this entry, add it to the end of the Dest list.
     if (!!Scale) {
-      VariableGEPIndex Entry = {V, ZExtBits, SExtBits, -Scale, Src[i].CxtI};
+      VariableGEPIndex Entry = {V,      ZExtBits,    SExtBits,
+                                -Scale, Src[i].CxtI, Src[i].IsNSW};
       Dest.push_back(Entry);
     }
   }

diff --git a/llvm/test/Analysis/BasicAA/gep-modulo.ll b/llvm/test/Analysis/BasicAA/gep-modulo.ll
index 79782fad44872..e009ce498b06b 100644
--- a/llvm/test/Analysis/BasicAA/gep-modulo.ll
+++ b/llvm/test/Analysis/BasicAA/gep-modulo.ll
@@ -70,7 +70,7 @@ define void @may_overflow_mul_sub_i64([16 x i8]* %ptr, i64 %idx) {
 ; CHECK-LABEL: Function: may_overflow_mul_sub_i64: 3 pointers, 0 call sites
 ; CHECK-NEXT:    MayAlias:  [16 x i8]* %ptr, i8* %gep.idx
 ; CHECK-NEXT:    PartialAlias (off 3): [16 x i8]* %ptr, i8* %gep.3
-; CHECK-NEXT:    NoAlias:  i8* %gep.3, i8* %gep.idx
+; CHECK-NEXT:    MayAlias:  i8* %gep.3, i8* %gep.idx
 ;
   %mul = mul i64 %idx, 5
   %sub = sub i64 %mul, 1
@@ -115,7 +115,7 @@ define void @only_nuw_mul_sub_i64([16 x i8]* %ptr, i64 %idx) {
 ; CHECK-LABEL: Function: only_nuw_mul_sub_i64: 3 pointers, 0 call sites
 ; CHECK-NEXT:    MayAlias:  [16 x i8]* %ptr, i8* %gep.idx
 ; CHECK-NEXT:    PartialAlias (off 3): [16 x i8]* %ptr, i8* %gep.3
-; CHECK-NEXT:    NoAlias:  i8* %gep.3, i8* %gep.idx
+; CHECK-NEXT:    MayAlias:  i8* %gep.3, i8* %gep.idx
 ;
   %mul = mul nuw i64 %idx, 5
   %sub = sub nuw i64 %mul, 1
@@ -126,6 +126,8 @@ define void @only_nuw_mul_sub_i64([16 x i8]* %ptr, i64 %idx) {
   ret void
 }
 
+; Even though the mul and sub may overflow, %gep.idx and %gep.3 cannot alias
+; because we multiply by a power-of-2.
 define void @may_overflow_mul_pow2_sub_i64([16 x i8]* %ptr, i64 %idx) {
 ; CHECK-LABEL: Function: may_overflow_mul_pow2_sub_i64: 3 pointers, 0 call sites
 ; CHECK-NEXT:    MayAlias:  [16 x i8]* %ptr, i8* %gep.idx
@@ -259,7 +261,7 @@ define void @may_overflow_pointer_diff([16 x i8]* %ptr, i64 %idx) {
 ; CHECK-LABEL: Function: may_overflow_pointer_diff: 3 pointers, 0 call sites
 ; CHECK-NEXT:  MayAlias: [16 x i8]* %ptr, i8* %gep.mul.1
 ; CHECK-NEXT:  MayAlias: [16 x i8]* %ptr, i8* %gep.sub.2
-; CHECK-NEXT:  NoAlias:  i8* %gep.mul.1, i8* %gep.sub.2
+; CHECK-NEXT:  MayAlias:  i8* %gep.mul.1, i8* %gep.sub.2
 ;
   %mul.1 = mul i64 %idx, 6148914691236517207
   %gep.mul.1  = getelementptr [16 x i8], [16 x i8]* %ptr, i32 0, i64 %mul.1


        


More information about the llvm-commits mailing list