[llvm] a759745 - [InstCombine] Support multiple comparisons in foldAllocaCmp()

Nikita Popov via llvm-commits llvm-commits at lists.llvm.org
Fri Apr 14 02:35:52 PDT 2023


Author: Nikita Popov
Date: 2023-04-14T11:32:58+02:00
New Revision: a75974516903a91897b042d065c968ba96d745d7

URL: https://github.com/llvm/llvm-project/commit/a75974516903a91897b042d065c968ba96d745d7
DIFF: https://github.com/llvm/llvm-project/commit/a75974516903a91897b042d065c968ba96d745d7.diff

LOG: [InstCombine] Support multiple comparisons in foldAllocaCmp()

foldAllocaCmp() needs to fold all comparisons of an alloca at the
same time, to ensure that there is a consistent view of the alloca
address. Currently, it folds "all" comparisons by limiting to the
case where there is only one. This patch switches the algorithm to
instead actually collect and fold all comparisons.

Something we need to be careful about here is that there may be
comparisons where both sides of the icmp are based on the alloca.
Such comparisons are comparing offsets of the alloca, and as such
can be ignored here, but shouldn't be folded to false.

Differential Revision: https://reviews.llvm.org/D144492

Added: 
    

Modified: 
    llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
    llvm/lib/Transforms/InstCombine/InstCombineInternal.h
    llvm/test/Transforms/InstCombine/compare-alloca.ll

Removed: 
    


################################################################################
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 646a130d612fe..5c2f65eca28f6 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -899,10 +899,7 @@ Instruction *InstCombinerImpl::foldGEPICmp(GEPOperator *GEPLHS, Value *RHS,
   return transformToIndexedCompare(GEPLHS, RHS, Cond, DL);
 }
 
-Instruction *InstCombinerImpl::foldAllocaCmp(ICmpInst &ICI,
-                                             const AllocaInst *Alloca) {
-  assert(ICI.isEquality() && "Cannot fold non-equality comparison.");
-
+bool InstCombinerImpl::foldAllocaCmp(AllocaInst *Alloca) {
   // It would be tempting to fold away comparisons between allocas and any
   // pointer not based on that alloca (e.g. an argument). However, even
   // though such pointers cannot alias, they can still compare equal.
@@ -911,21 +908,34 @@ Instruction *InstCombinerImpl::foldAllocaCmp(ICmpInst &ICI,
   // doesn't escape we can argue that it's impossible to guess its value, and we
   // can therefore act as if any such guesses are wrong.
   //
-  // The code below checks that the alloca doesn't escape, and that it's only
-  // used in a comparison once (the current instruction). The
-  // single-comparison-use condition ensures that we're trivially folding all
-  // comparisons against the alloca consistently, and avoids the risk of
-  // erroneously folding a comparison of the pointer with itself.
+  // However, we need to ensure that this folding is consistent: We can't fold
+  // one comparison to false, and then leave a different comparison against the
+  // same value alone (as it might evaluate to true at runtime, leading to a
+  // contradiction). As such, this code ensures that all comparisons are folded
+  // at the same time, and there are no other escapes.
 
   struct CmpCaptureTracker : public CaptureTracker {
+    AllocaInst *Alloca;
     bool Captured = false;
-    unsigned NumCmps = 0;
+    /// The value of the map is a bit mask of which icmp operands the alloca is
+    /// used in.
+    SmallMapVector<ICmpInst *, unsigned, 4> ICmps;
+
+    CmpCaptureTracker(AllocaInst *Alloca) : Alloca(Alloca) {}
 
     void tooManyUses() override { Captured = true; }
 
     bool captured(const Use *U) override {
-      if (isa<ICmpInst>(U->getUser()) && ++NumCmps == 1) {
-        // Ignore one icmp capture.
+      auto *ICmp = dyn_cast<ICmpInst>(U->getUser());
+      // We need to check that U is based *only* on the alloca, and doesn't
+      // have other contributions from a select/phi operand.
+      // TODO: We could check whether getUnderlyingObjects() reduces to one
+      // object, which would allow looking through phi nodes.
+      if (ICmp && ICmp->isEquality() && getUnderlyingObject(*U) == Alloca) {
+        // Collect equality icmps of the alloca, and don't treat them as
+        // captures.
+        auto Res = ICmps.insert({ICmp, 0});
+        Res.first->second |= 1u << U->getOperandNo();
         return false;
       }
 
@@ -934,14 +944,36 @@ Instruction *InstCombinerImpl::foldAllocaCmp(ICmpInst &ICI,
     }
   };
 
-  CmpCaptureTracker Tracker;
+  CmpCaptureTracker Tracker(Alloca);
   PointerMayBeCaptured(Alloca, &Tracker);
   if (Tracker.Captured)
-    return nullptr;
+    return false;
+
+  bool Changed = false;
+  for (auto [ICmp, Operands] : Tracker.ICmps) {
+    switch (Operands) {
+    case 1:
+    case 2: {
+      // The alloca is only used in one icmp operand. Assume that the
+      // equality is false.
+      auto *Res = ConstantInt::get(
+          ICmp->getType(), ICmp->getPredicate() == ICmpInst::ICMP_NE);
+      replaceInstUsesWith(*ICmp, Res);
+      eraseInstFromFunction(*ICmp);
+      Changed = true;
+      break;
+    }
+    case 3:
+      // Both icmp operands are based on the alloca, so this is comparing
+      // pointer offsets, without leaking any information about the address
+      // of the alloca. Ignore such comparisons.
+      break;
+    default:
+      llvm_unreachable("Cannot happen");
+    }
+  }
 
-  auto *Res = ConstantInt::get(ICI.getType(),
-                               !CmpInst::isTrueWhenEqual(ICI.getPredicate()));
-  return replaceInstUsesWith(ICI, Res);
+  return Changed;
 }
 
 /// Fold "icmp pred (X+C), X".
@@ -6500,11 +6532,11 @@ Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) {
   if (Op0->getType()->isPointerTy() && I.isEquality()) {
     assert(Op1->getType()->isPointerTy() && "Comparing pointer with non-pointer?");
     if (auto *Alloca = dyn_cast<AllocaInst>(getUnderlyingObject(Op0)))
-      if (Instruction *New = foldAllocaCmp(I, Alloca))
-        return New;
+      if (foldAllocaCmp(Alloca))
+        return nullptr;
     if (auto *Alloca = dyn_cast<AllocaInst>(getUnderlyingObject(Op1)))
-      if (Instruction *New = foldAllocaCmp(I, Alloca))
-        return New;
+      if (foldAllocaCmp(Alloca))
+        return nullptr;
   }
 
   if (Instruction *Res = foldICmpBitCast(I))

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index 804b77ae78d9e..03362db5defaa 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -549,7 +549,7 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
                            ICmpInst::Predicate Cond, Instruction &I);
   Instruction *foldSelectICmp(ICmpInst::Predicate Pred, SelectInst *SI,
                               Value *RHS, const ICmpInst &I);
-  Instruction *foldAllocaCmp(ICmpInst &ICI, const AllocaInst *Alloca);
+  bool foldAllocaCmp(AllocaInst *Alloca);
   Instruction *foldCmpLoadFromIndexedGlobal(LoadInst *LI,
                                             GetElementPtrInst *GEP,
                                             GlobalVariable *GV, CmpInst &ICI,

diff --git a/llvm/test/Transforms/InstCombine/compare-alloca.ll b/llvm/test/Transforms/InstCombine/compare-alloca.ll
index 164398382e9e0..9ca8dc866e565 100644
--- a/llvm/test/Transforms/InstCombine/compare-alloca.ll
+++ b/llvm/test/Transforms/InstCombine/compare-alloca.ll
@@ -58,12 +58,7 @@ define i1 @alloca_argument_compare_escaped_alloca(ptr %arg) {
 declare void @check_compares(i1, i1)
 define void @alloca_argument_compare_two_compares(ptr %p) {
 ; CHECK-LABEL: @alloca_argument_compare_two_compares(
-; CHECK-NEXT:    [[Q1:%.*]] = alloca [8 x i64], align 8
-; CHECK-NEXT:    [[R:%.*]] = getelementptr i64, ptr [[P:%.*]], i32 1
-; CHECK-NEXT:    [[S:%.*]] = getelementptr inbounds i64, ptr [[Q1]], i32 2
-; CHECK-NEXT:    [[CMP1:%.*]] = icmp eq ptr [[Q1]], [[P]]
-; CHECK-NEXT:    [[CMP2:%.*]] = icmp eq ptr [[R]], [[S]]
-; CHECK-NEXT:    call void @check_compares(i1 [[CMP1]], i1 [[CMP2]])
+; CHECK-NEXT:    call void @check_compares(i1 false, i1 false)
 ; CHECK-NEXT:    ret void
 ;
   %q = alloca i64, i64 8
@@ -154,13 +149,10 @@ define i1 @offset_single_cmp() {
 
 declare void @witness(i1, i1)
 
-define void @neg_consistent_fold1() {
-; CHECK-LABEL: @neg_consistent_fold1(
-; CHECK-NEXT:    [[M1:%.*]] = alloca [4 x i8], align 1
+define void @consistent_fold1() {
+; CHECK-LABEL: @consistent_fold1(
 ; CHECK-NEXT:    [[RHS2:%.*]] = call ptr @hidden_inttoptr()
-; CHECK-NEXT:    [[CMP1:%.*]] = icmp eq ptr [[M1]], inttoptr (i64 2048 to ptr)
-; CHECK-NEXT:    [[CMP2:%.*]] = icmp eq ptr [[M1]], [[RHS2]]
-; CHECK-NEXT:    call void @witness(i1 [[CMP1]], i1 [[CMP2]])
+; CHECK-NEXT:    call void @witness(i1 false, i1 false)
 ; CHECK-NEXT:    ret void
 ;
   %m = alloca i8, i32 4
@@ -172,15 +164,11 @@ define void @neg_consistent_fold1() {
   ret void
 }
 
-define void @neg_consistent_fold2() {
-; CHECK-LABEL: @neg_consistent_fold2(
-; CHECK-NEXT:    [[M1:%.*]] = alloca [4 x i8], align 1
+define void @consistent_fold2() {
+; CHECK-LABEL: @consistent_fold2(
 ; CHECK-NEXT:    [[N2:%.*]] = alloca [4 x i8], align 1
-; CHECK-NEXT:    [[RHS:%.*]] = getelementptr inbounds i8, ptr [[N2]], i32 4
 ; CHECK-NEXT:    [[RHS2:%.*]] = call ptr @hidden_offset(ptr nonnull [[N2]])
-; CHECK-NEXT:    [[CMP1:%.*]] = icmp eq ptr [[M1]], [[RHS]]
-; CHECK-NEXT:    [[CMP2:%.*]] = icmp eq ptr [[M1]], [[RHS2]]
-; CHECK-NEXT:    call void @witness(i1 [[CMP1]], i1 [[CMP2]])
+; CHECK-NEXT:    call void @witness(i1 false, i1 false)
 ; CHECK-NEXT:    ret void
 ;
   %m = alloca i8, i32 4
@@ -193,14 +181,10 @@ define void @neg_consistent_fold2() {
   ret void
 }
 
-define void @neg_consistent_fold3() {
-; CHECK-LABEL: @neg_consistent_fold3(
-; CHECK-NEXT:    [[M1:%.*]] = alloca [4 x i8], align 1
-; CHECK-NEXT:    [[LGP:%.*]] = load ptr, ptr @gp, align 8
+define void @consistent_fold3() {
+; CHECK-LABEL: @consistent_fold3(
 ; CHECK-NEXT:    [[RHS2:%.*]] = call ptr @hidden_inttoptr()
-; CHECK-NEXT:    [[CMP1:%.*]] = icmp eq ptr [[M1]], [[LGP]]
-; CHECK-NEXT:    [[CMP2:%.*]] = icmp eq ptr [[M1]], [[RHS2]]
-; CHECK-NEXT:    call void @witness(i1 [[CMP1]], i1 [[CMP2]])
+; CHECK-NEXT:    call void @witness(i1 false, i1 false)
 ; CHECK-NEXT:    ret void
 ;
   %m = alloca i8, i32 4


        


More information about the llvm-commits mailing list