[llvm] 818e554 - [MergeICmps] Adapt to non-eq comparisons, fix bug for cases that need to be split

via llvm-commits llvm-commits at lists.llvm.org
Thu Mar 9 07:54:25 PST 2023


Author: zhongyunde
Date: 2023-03-09T23:49:09+08:00
New Revision: 818e554e251c1e07f133aeed9fe0473502ebfdae

URL: https://github.com/llvm/llvm-project/commit/818e554e251c1e07f133aeed9fe0473502ebfdae
DIFF: https://github.com/llvm/llvm-project/commit/818e554e251c1e07f133aeed9fe0473502ebfdae.diff

LOG: [MergeICmps] Adapt to non-eq comparisons, fix bug for cases that need to be split

Fix the last remaining runtime issue, which occurs when a sequence of
comparisons needs to be split. For a chain that originally ends with an
equality comparison, each new split icmp chain still ends with an equality
comparison. For a chain that ends with a not-equal comparison, the
intermediate split icmp chains must also end with an equality comparison
(only the last link in the chain keeps the not-equal predicate), so this
case has to be handled carefully; see the test case partial_sequent_ne for
details.

Thanks to @glandium and @ayzhao for reporting the runtime issue and for
examining it carefully.
Fix https://github.com/llvm/llvm-project/issues/59740.

Reviewed By: vitalybuka
Differential Revision: https://reviews.llvm.org/D141188

Added: 
    llvm/test/Transforms/MergeICmps/X86/pr59740.ll

Modified: 
    llvm/lib/Transforms/Scalar/MergeICmps.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/Scalar/MergeICmps.cpp b/llvm/lib/Transforms/Scalar/MergeICmps.cpp
index bcedb05890af3..b85d4926e5472 100644
--- a/llvm/lib/Transforms/Scalar/MergeICmps.cpp
+++ b/llvm/lib/Transforms/Scalar/MergeICmps.cpp
@@ -330,10 +330,10 @@ std::optional<BCECmp> visitICmp(const ICmpInst *const CmpI,
 
 // Visit the given comparison block. If this is a comparison between two valid
 // BCE atoms, returns the comparison.
-std::optional<BCECmpBlock> visitCmpBlock(Value *const Val,
-                                         BasicBlock *const Block,
-                                         const BasicBlock *const PhiBlock,
-                                         BaseIdentifier &BaseId) {
+std::optional<BCECmpBlock>
+visitCmpBlock(Value *const Baseline, ICmpInst::Predicate &Predicate,
+              Value *const Val, BasicBlock *const Block,
+              const BasicBlock *const PhiBlock, BaseIdentifier &BaseId) {
   if (Block->empty())
     return std::nullopt;
   auto *const BranchI = dyn_cast<BranchInst>(Block->getTerminator());
@@ -348,15 +348,27 @@ std::optional<BCECmpBlock> visitCmpBlock(Value *const Val,
     // that this does not mean that this is the last incoming value, blocks
     // can be reordered).
     Cond = Val;
-    ExpectedPredicate = ICmpInst::ICMP_EQ;
+    const auto *const ConstBase = cast<ConstantInt>(Baseline);
+    assert(ConstBase->getType()->isIntegerTy(1) &&
+           "Select condition is not an i1?");
+    ExpectedPredicate =
+        ConstBase->isOne() ? ICmpInst::ICMP_NE : ICmpInst::ICMP_EQ;
+
+    // Remember the correct predicate.
+    Predicate = ExpectedPredicate;
   } else {
+    // All the incoming values must be consistent.
+    if (Baseline != Val)
+      return std::nullopt;
     // In this case, we expect a constant incoming value (the comparison is
     // chained).
     const auto *const Const = cast<ConstantInt>(Val);
-    LLVM_DEBUG(dbgs() << "const\n");
-    if (!Const->isZero())
+    assert(Const->getType()->isIntegerTy(1) &&
+           "Incoming value is not an i1?");
+    LLVM_DEBUG(dbgs() << "const i1 value\n");
+    if (!Const->isZero() && !Const->isOne())
       return std::nullopt;
-    LLVM_DEBUG(dbgs() << "false\n");
+    LLVM_DEBUG(dbgs() << *Const << "\n");
     assert(BranchI->getNumSuccessors() == 2 && "expecting a cond branch");
     BasicBlock *const FalseBlock = BranchI->getSuccessor(1);
     Cond = BranchI->getCondition();
@@ -417,6 +429,8 @@ class BCECmpChain {
   std::vector<ContiguousBlocks> MergedBlocks_;
   // The original entry block (before sorting);
   BasicBlock *EntryBlock_;
+  // Remember the predicate type of the chain.
+  ICmpInst::Predicate Predicate_;
 };
 
 static bool areContiguous(const BCECmpBlock &First, const BCECmpBlock &Second) {
@@ -475,10 +489,13 @@ BCECmpChain::BCECmpChain(const std::vector<BasicBlock *> &Blocks, PHINode &Phi,
   // Now look inside blocks to check for BCE comparisons.
   std::vector<BCECmpBlock> Comparisons;
   BaseIdentifier BaseId;
+  Value *const Baseline = Phi.getIncomingValueForBlock(Blocks[0]);
+  Predicate_ = CmpInst::BAD_ICMP_PREDICATE;
   for (BasicBlock *const Block : Blocks) {
     assert(Block && "invalid block");
-    std::optional<BCECmpBlock> Comparison = visitCmpBlock(
-        Phi.getIncomingValueForBlock(Block), Block, Phi.getParent(), BaseId);
+    std::optional<BCECmpBlock> Comparison =
+        visitCmpBlock(Baseline, Predicate_, Phi.getIncomingValueForBlock(Block),
+                      Block, Phi.getParent(), BaseId);
     if (!Comparison) {
       LLVM_DEBUG(dbgs() << "chain with invalid BCECmpBlock, no merge.\n");
       return;
@@ -602,7 +619,8 @@ static BasicBlock *mergeComparisons(ArrayRef<BCECmpBlock> Comparisons,
                                     BasicBlock *const InsertBefore,
                                     BasicBlock *const NextCmpBlock,
                                     PHINode &Phi, const TargetLibraryInfo &TLI,
-                                    AliasAnalysis &AA, DomTreeUpdater &DTU) {
+                                    AliasAnalysis &AA, DomTreeUpdater &DTU,
+                                    ICmpInst::Predicate Predicate) {
   assert(!Comparisons.empty() && "merging zero comparisons");
   LLVMContext &Context = NextCmpBlock->getContext();
   const BCECmpBlock &FirstCmp = Comparisons[0];
@@ -623,7 +641,7 @@ static BasicBlock *mergeComparisons(ArrayRef<BCECmpBlock> Comparisons,
   else
     Rhs = FirstCmp.Rhs().LoadI->getPointerOperand();
 
-  Value *IsEqual = nullptr;
+  Value *ICmpValue = nullptr;
   LLVM_DEBUG(dbgs() << "Merging " << Comparisons.size() << " comparisons -> "
                     << BB->getName() << "\n");
 
@@ -637,6 +655,14 @@ static BasicBlock *mergeComparisons(ArrayRef<BCECmpBlock> Comparisons,
     ToSplit->split(BB, AA);
   }
 
+  // For a Icmp chain, the Predicate is record the last link in the chain of
+  // comparisons. When we spilt the chain The new spilted chain of comparisons
+  // is end with ICMP_EQ.
+  // Only the last link in the chain is a unconditionla jmp.
+  BasicBlock *const TailBB = Comparisons[Comparisons.size() - 1].BB;
+  auto *const BranchI = dyn_cast<BranchInst>(TailBB->getTerminator());
+  ICmpInst::Predicate Pred =
+      BranchI->isUnconditional() ? Predicate : ICmpInst::ICMP_EQ;
   if (Comparisons.size() == 1) {
     LLVM_DEBUG(dbgs() << "Only one comparison, updating branches\n");
     Value *const LhsLoad =
@@ -644,7 +670,7 @@ static BasicBlock *mergeComparisons(ArrayRef<BCECmpBlock> Comparisons,
     Value *const RhsLoad =
         Builder.CreateLoad(FirstCmp.Rhs().LoadI->getType(), Rhs);
     // There are no blocks to merge, just do the comparison.
-    IsEqual = Builder.CreateICmpEQ(LhsLoad, RhsLoad);
+    ICmpValue = Builder.CreateICmp(Pred, LhsLoad, RhsLoad);
   } else {
     const unsigned TotalSizeBits = std::accumulate(
         Comparisons.begin(), Comparisons.end(), 0u,
@@ -660,8 +686,8 @@ static BasicBlock *mergeComparisons(ArrayRef<BCECmpBlock> Comparisons,
         Lhs, Rhs,
         ConstantInt::get(Builder.getIntNTy(SizeTBits), TotalSizeBits / 8),
         Builder, DL, &TLI);
-    IsEqual = Builder.CreateICmpEQ(
-        MemCmpCall, ConstantInt::get(Builder.getIntNTy(IntBits), 0));
+    ICmpValue = Builder.CreateICmp(
+        Pred, MemCmpCall, ConstantInt::get(Builder.getIntNTy(IntBits), 0));
   }
 
   BasicBlock *const PhiBB = Phi.getParent();
@@ -669,11 +695,11 @@ static BasicBlock *mergeComparisons(ArrayRef<BCECmpBlock> Comparisons,
   if (NextCmpBlock == PhiBB) {
     // Continue to phi, passing it the comparison result.
     Builder.CreateBr(PhiBB);
-    Phi.addIncoming(IsEqual, BB);
+    Phi.addIncoming(ICmpValue, BB);
     DTU.applyUpdates({{DominatorTree::Insert, BB, PhiBB}});
   } else {
     // Continue to next block if equal, exit to phi else.
-    Builder.CreateCondBr(IsEqual, NextCmpBlock, PhiBB);
+    Builder.CreateCondBr(ICmpValue, NextCmpBlock, PhiBB);
     Phi.addIncoming(ConstantInt::getFalse(Context), BB);
     DTU.applyUpdates({{DominatorTree::Insert, BB, NextCmpBlock},
                       {DominatorTree::Insert, BB, PhiBB}});
@@ -691,9 +717,11 @@ bool BCECmpChain::simplify(const TargetLibraryInfo &TLI, AliasAnalysis &AA,
   // so that the next block is always available to branch to.
   BasicBlock *InsertBefore = EntryBlock_;
   BasicBlock *NextCmpBlock = Phi_.getParent();
+  assert(Predicate_ != CmpInst::BAD_ICMP_PREDICATE &&
+         "Got the chain of comparisons");
   for (const auto &Blocks : reverse(MergedBlocks_)) {
     InsertBefore = NextCmpBlock = mergeComparisons(
-        Blocks, InsertBefore, NextCmpBlock, Phi_, TLI, AA, DTU);
+        Blocks, InsertBefore, NextCmpBlock, Phi_, TLI, AA, DTU, Predicate_);
   }
 
   // Replace the original cmp chain with the new cmp chain by pointing all

diff  --git a/llvm/test/Transforms/MergeICmps/X86/pr59740.ll b/llvm/test/Transforms/MergeICmps/X86/pr59740.ll
new file mode 100644
index 0000000000000..6930f63d9289e
--- /dev/null
+++ b/llvm/test/Transforms/MergeICmps/X86/pr59740.ll
@@ -0,0 +1,183 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt < %s -passes=mergeicmps -verify-dom-info -S -mtriple=x86_64-unknown-unknown | FileCheck %s
+
+%struct.S = type { i8, i8, i8, i8 }
+%struct1.S = type { ptr, ptr, ptr, i8 }
+
+define noundef i1 @full_sequent_ne(ptr nocapture readonly align 1 dereferenceable(4) %s0, ptr nocapture readonly align 1 dereferenceable(4) %s1) {
+; CHECK-LABEL: @full_sequent_ne(
+; CHECK-NEXT:  "bb0+bb1+bb2+bb3":
+; CHECK-NEXT:    [[MEMCMP:%.*]] = call i32 @memcmp(ptr [[S0:%.*]], ptr [[S1:%.*]], i64 4)
+; CHECK-NEXT:    [[TMP0:%.*]] = icmp ne i32 [[MEMCMP]], 0
+; CHECK-NEXT:    br label [[BB4:%.*]]
+; CHECK:       bb4:
+; CHECK-NEXT:    ret i1 [[TMP0]]
+;
+bb0:
+  %v0 = load i8, ptr %s0, align 1
+  %v1 = load i8, ptr %s1, align 1
+  %cmp0 = icmp eq i8 %v0, %v1
+  br i1 %cmp0, label %bb1, label %bb4
+
+bb1:                                              ; preds = %bb0
+  %s2 = getelementptr inbounds %struct.S, ptr %s0, i64 0, i32 1
+  %v2 = load i8, ptr %s2, align 1
+  %s3 = getelementptr inbounds %struct.S, ptr %s1, i64 0, i32 1
+  %v3 = load i8, ptr %s3, align 1
+  %cmp1 = icmp eq i8 %v2, %v3
+  br i1 %cmp1, label %bb2, label %bb4
+
+bb2:                                             ; preds = %bb1
+  %s4 = getelementptr inbounds %struct.S, ptr %s0, i64 0, i32 2
+  %v4 = load i8, ptr %s4, align 1
+  %s5 = getelementptr inbounds %struct.S, ptr %s1, i64 0, i32 2
+  %v5 = load i8, ptr %s5, align 1
+  %cmp2 = icmp eq i8 %v4, %v5
+  br i1 %cmp2, label %bb3, label %bb4
+
+bb3:                                               ; preds = %bb2
+  %s6 = getelementptr inbounds %struct.S, ptr %s0, i64 0, i32 3
+  %v6 = load i8, ptr %s6, align 1
+  %s7 = getelementptr inbounds %struct.S, ptr %s1, i64 0, i32 3
+  %v7 = load i8, ptr %s7, align 1
+  %cmp3 = icmp ne i8 %v6, %v7
+  br label %bb4
+
+bb4:                                               ; preds = %bb0, %bb1, %bb2, %bb3
+  %cmp = phi i1 [ true, %bb0 ], [ true, %bb1 ], [ true, %bb2 ], [ %cmp3, %bb3 ]
+  ret i1 %cmp
+}
+
+; Negative test: Incorrect const value in PHI node
+define noundef i1 @cmp_ne_incorrect_const(ptr nocapture readonly align 1 dereferenceable(4) %s0, ptr nocapture readonly align 1 dereferenceable(4) %s1) {
+; CHECK-LABEL: @cmp_ne_incorrect_const(
+; CHECK-NEXT:  bb0:
+; CHECK-NEXT:    [[V0:%.*]] = load i8, ptr [[S0:%.*]], align 1
+; CHECK-NEXT:    [[V1:%.*]] = load i8, ptr [[S1:%.*]], align 1
+; CHECK-NEXT:    [[CMP0:%.*]] = icmp eq i8 [[V0]], [[V1]]
+; CHECK-NEXT:    br i1 [[CMP0]], label [[BB1:%.*]], label [[BB2:%.*]]
+; CHECK:       bb1:
+; CHECK-NEXT:    [[S6:%.*]] = getelementptr inbounds [[STRUCT_S:%.*]], ptr [[S0]], i64 0, i32 1
+; CHECK-NEXT:    [[V6:%.*]] = load i8, ptr [[S6]], align 1
+; CHECK-NEXT:    [[S7:%.*]] = getelementptr inbounds [[STRUCT_S]], ptr [[S1]], i64 0, i32 1
+; CHECK-NEXT:    [[V7:%.*]] = load i8, ptr [[S7]], align 1
+; CHECK-NEXT:    [[CMP3:%.*]] = icmp ne i8 [[V6]], [[V7]]
+; CHECK-NEXT:    br label [[BB2]]
+; CHECK:       bb2:
+; CHECK-NEXT:    [[CMP:%.*]] = phi i1 [ false, [[BB0:%.*]] ], [ [[CMP3]], [[BB1]] ]
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+bb0:
+  %v0 = load i8, ptr %s0, align 1
+  %v1 = load i8, ptr %s1, align 1
+  %cmp0 = icmp eq i8 %v0, %v1
+  br i1 %cmp0, label %bb1, label %bb2
+
+bb1:                                               ; preds = %bb0
+  %s6 = getelementptr inbounds %struct.S, ptr %s0, i64 0, i32 1
+  %v6 = load i8, ptr %s6, align 1
+  %s7 = getelementptr inbounds %struct.S, ptr %s1, i64 0, i32 1
+  %v7 = load i8, ptr %s7, align 1
+  %cmp3 = icmp ne i8 %v6, %v7
+  br label %bb2
+
+bb2:                                               ; preds = %bb0, %bb1
+  %cmp = phi i1 [ false, %bb0 ], [ %cmp3, %bb1 ]
+  ret i1 %cmp
+}
+
+define noundef i1 @partial_sequent_eq() {
+; CHECK-LABEL: @partial_sequent_eq(
+; CHECK-NEXT:  bb01:
+; CHECK-NEXT:    [[VARS0:%.*]] = alloca [[STRUCT1_S:%.*]], align 8
+; CHECK-NEXT:    [[VARS1:%.*]] = alloca [[STRUCT1_S]], align 8
+; CHECK-NEXT:    [[TMP0:%.*]] = load ptr, ptr [[VARS0]], align 8
+; CHECK-NEXT:    [[TMP1:%.*]] = load ptr, ptr [[VARS0]], align 8
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp eq ptr [[TMP0]], [[TMP1]]
+; CHECK-NEXT:    br i1 [[TMP2]], label %"bb1+bb2", label [[BB3:%.*]]
+; CHECK:       "bb1+bb2":
+; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr inbounds [[STRUCT1_S]], ptr [[VARS0]], i64 0, i32 2
+; CHECK-NEXT:    [[TMP4:%.*]] = getelementptr inbounds [[STRUCT1_S]], ptr [[VARS1]], i64 0, i32 2
+; CHECK-NEXT:    [[MEMCMP:%.*]] = call i32 @memcmp(ptr [[TMP3]], ptr [[TMP4]], i64 9)
+; CHECK-NEXT:    [[TMP5:%.*]] = icmp eq i32 [[MEMCMP]], 0
+; CHECK-NEXT:    br label [[BB3]]
+; CHECK:       bb3:
+; CHECK-NEXT:    [[CMP:%.*]] = phi i1 [ [[TMP5]], %"bb1+bb2" ], [ false, [[BB01:%.*]] ]
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+bb0:
+  %VarS0 = alloca %struct1.S, align 8
+  %VarS1 = alloca %struct1.S, align 8
+  %v0 = load ptr, ptr %VarS0, align 8
+  %v1 = load ptr, ptr %VarS0, align 8
+  %cmp0 = icmp eq ptr %v0, %v1
+  br i1 %cmp0, label %bb1, label %bb3
+
+bb1:                                              ; preds = %bb0
+  %s2 = getelementptr inbounds %struct1.S, ptr %VarS0, i64 0, i32 2
+  %v2 = load ptr, ptr %s2, align 8
+  %s3 = getelementptr inbounds %struct1.S, ptr %VarS1, i64 0, i32 2
+  %v3 = load ptr, ptr %s3, align 8
+  %cmp1 = icmp eq ptr %v2, %v3
+  br i1 %cmp1, label %bb2, label %bb3
+
+bb2:                                               ; preds = %bb2
+  %s6 = getelementptr inbounds %struct1.S, ptr %VarS0, i64 0, i32 3
+  %v6 = load i8, ptr %s6, align 1
+  %s7 = getelementptr inbounds %struct1.S, ptr %VarS1, i64 0, i32 3
+  %v7 = load i8, ptr %s7, align 1
+  %cmp3 = icmp eq i8 %v6, %v7
+  br label %bb3
+
+bb3:                                               ; preds = %bb0, %bb1, %bb2
+  %cmp = phi i1 [ false, %bb0 ], [ false, %bb1 ], [ %cmp3, %bb2 ]
+  ret i1 %cmp
+}
+
+define noundef i1 @partial_sequent_ne() {
+; CHECK-LABEL: @partial_sequent_ne(
+; CHECK-NEXT:  bb01:
+; CHECK-NEXT:    [[VARS0:%.*]] = alloca [[STRUCT1_S:%.*]], align 8
+; CHECK-NEXT:    [[VARS1:%.*]] = alloca [[STRUCT1_S]], align 8
+; CHECK-NEXT:    [[TMP0:%.*]] = load ptr, ptr [[VARS0]], align 8
+; CHECK-NEXT:    [[TMP1:%.*]] = load ptr, ptr [[VARS0]], align 8
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp eq ptr [[TMP0]], [[TMP1]]
+; CHECK-NEXT:    br i1 [[TMP2]], label %"bb1+bb2", label [[BB3:%.*]]
+; CHECK:       "bb1+bb2":
+; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr inbounds [[STRUCT1_S]], ptr [[VARS0]], i64 0, i32 2
+; CHECK-NEXT:    [[TMP4:%.*]] = getelementptr inbounds [[STRUCT1_S]], ptr [[VARS1]], i64 0, i32 2
+; CHECK-NEXT:    [[MEMCMP:%.*]] = call i32 @memcmp(ptr [[TMP3]], ptr [[TMP4]], i64 9)
+; CHECK-NEXT:    [[TMP5:%.*]] = icmp ne i32 [[MEMCMP]], 0
+; CHECK-NEXT:    br label [[BB3]]
+; CHECK:       bb3:
+; CHECK-NEXT:    [[CMP:%.*]] = phi i1 [ [[TMP5]], %"bb1+bb2" ], [ false, [[BB01:%.*]] ]
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+bb0:
+  %VarS0 = alloca %struct1.S, align 8
+  %VarS1 = alloca %struct1.S, align 8
+  %v0 = load ptr, ptr %VarS0, align 8
+  %v1 = load ptr, ptr %VarS0, align 8
+  %cmp0 = icmp eq ptr %v0, %v1
+  br i1 %cmp0, label %bb1, label %bb3
+
+bb1:                                              ; preds = %bb0
+  %s2 = getelementptr inbounds %struct1.S, ptr %VarS0, i64 0, i32 2
+  %v2 = load ptr, ptr %s2, align 8
+  %s3 = getelementptr inbounds %struct1.S, ptr %VarS1, i64 0, i32 2
+  %v3 = load ptr, ptr %s3, align 8
+  %cmp1 = icmp eq ptr %v2, %v3
+  br i1 %cmp1, label %bb2, label %bb3
+
+bb2:                                               ; preds = %bb2
+  %s6 = getelementptr inbounds %struct1.S, ptr %VarS0, i64 0, i32 3
+  %v6 = load i8, ptr %s6, align 1
+  %s7 = getelementptr inbounds %struct1.S, ptr %VarS1, i64 0, i32 3
+  %v7 = load i8, ptr %s7, align 1
+  %cmp3 = icmp ne i8 %v6, %v7
+  br label %bb3
+
+bb3:                                               ; preds = %bb0, %bb1, %bb2
+  %cmp = phi i1 [ true, %bb0 ], [ true, %bb1 ], [ %cmp3, %bb2 ]
+  ret i1 %cmp
+}


        


More information about the llvm-commits mailing list