[llvm] 3ac2b3a - [MergeICmps] Adapt to non-eq comparisons
via llvm-commits
llvm-commits at lists.llvm.org
Wed Jan 11 17:47:55 PST 2023
Author: zhongyunde
Date: 2023-01-12T09:47:02+08:00
New Revision: 3ac2b3a4f9effc9f79822e770f209fd70ff66362
URL: https://github.com/llvm/llvm-project/commit/3ac2b3a4f9effc9f79822e770f209fd70ff66362
DIFF: https://github.com/llvm/llvm-project/commit/3ac2b3a4f9effc9f79822e770f209fd70ff66362.diff
LOG: [MergeICmps] Adapt to non-eq comparisons
Fix https://github.com/llvm/llvm-project/issues/59740.
Reviewed By: courbet, nikic
Differential Revision: https://reviews.llvm.org/D141188
Added:
llvm/test/Transforms/MergeICmps/X86/pr59740.ll
Modified:
llvm/lib/Transforms/Scalar/MergeICmps.cpp
Removed:
################################################################################
diff --git a/llvm/lib/Transforms/Scalar/MergeICmps.cpp b/llvm/lib/Transforms/Scalar/MergeICmps.cpp
index bcedb05890af3..bcb95f80325d5 100644
--- a/llvm/lib/Transforms/Scalar/MergeICmps.cpp
+++ b/llvm/lib/Transforms/Scalar/MergeICmps.cpp
@@ -330,10 +330,10 @@ std::optional<BCECmp> visitICmp(const ICmpInst *const CmpI,
// Visit the given comparison block. If this is a comparison between two valid
// BCE atoms, returns the comparison.
-std::optional<BCECmpBlock> visitCmpBlock(Value *const Val,
- BasicBlock *const Block,
- const BasicBlock *const PhiBlock,
- BaseIdentifier &BaseId) {
+std::optional<BCECmpBlock>
+visitCmpBlock(Value *const Baseline, ICmpInst::Predicate &Predicate,
+ Value *const Val, BasicBlock *const Block,
+ const BasicBlock *const PhiBlock, BaseIdentifier &BaseId) {
if (Block->empty())
return std::nullopt;
auto *const BranchI = dyn_cast<BranchInst>(Block->getTerminator());
@@ -348,15 +348,27 @@ std::optional<BCECmpBlock> visitCmpBlock(Value *const Val,
// that this does not mean that this is the last incoming value, blocks
// can be reordered).
Cond = Val;
- ExpectedPredicate = ICmpInst::ICMP_EQ;
+ const auto *const ConstBase = cast<ConstantInt>(Baseline);
+ assert(ConstBase->getType()->isIntegerTy(1) &&
+ "Select condition is not an i1?");
+ ExpectedPredicate =
+ ConstBase->isOne() ? ICmpInst::ICMP_NE : ICmpInst::ICMP_EQ;
+
+ // Remember the correct predicate.
+ Predicate = ExpectedPredicate;
} else {
+ // All the incoming values must be consistent.
+ if (Baseline != Val)
+ return std::nullopt;
// In this case, we expect a constant incoming value (the comparison is
// chained).
const auto *const Const = cast<ConstantInt>(Val);
+ assert(Const->getType()->isIntegerTy(1) &&
+ "Incoming value is not an i1?");
LLVM_DEBUG(dbgs() << "const\n");
- if (!Const->isZero())
+ if (!Const->isZero() && !Const->isOne())
return std::nullopt;
- LLVM_DEBUG(dbgs() << "false\n");
+ LLVM_DEBUG(dbgs() << *Const << "\n");
assert(BranchI->getNumSuccessors() == 2 && "expecting a cond branch");
BasicBlock *const FalseBlock = BranchI->getSuccessor(1);
Cond = BranchI->getCondition();
@@ -417,6 +429,8 @@ class BCECmpChain {
std::vector<ContiguousBlocks> MergedBlocks_;
// The original entry block (before sorting);
BasicBlock *EntryBlock_;
+ // Remember the predicate type of the chain.
+ ICmpInst::Predicate Predicate_;
};
static bool areContiguous(const BCECmpBlock &First, const BCECmpBlock &Second) {
@@ -475,10 +489,13 @@ BCECmpChain::BCECmpChain(const std::vector<BasicBlock *> &Blocks, PHINode &Phi,
// Now look inside blocks to check for BCE comparisons.
std::vector<BCECmpBlock> Comparisons;
BaseIdentifier BaseId;
+ Value *const Baseline = Phi.getIncomingValueForBlock(Blocks[0]);
+ Predicate_ = CmpInst::BAD_ICMP_PREDICATE;
for (BasicBlock *const Block : Blocks) {
assert(Block && "invalid block");
- std::optional<BCECmpBlock> Comparison = visitCmpBlock(
- Phi.getIncomingValueForBlock(Block), Block, Phi.getParent(), BaseId);
+ std::optional<BCECmpBlock> Comparison =
+ visitCmpBlock(Baseline, Predicate_, Phi.getIncomingValueForBlock(Block),
+ Block, Phi.getParent(), BaseId);
if (!Comparison) {
LLVM_DEBUG(dbgs() << "chain with invalid BCECmpBlock, no merge.\n");
return;
@@ -602,7 +619,8 @@ static BasicBlock *mergeComparisons(ArrayRef<BCECmpBlock> Comparisons,
BasicBlock *const InsertBefore,
BasicBlock *const NextCmpBlock,
PHINode &Phi, const TargetLibraryInfo &TLI,
- AliasAnalysis &AA, DomTreeUpdater &DTU) {
+ AliasAnalysis &AA, DomTreeUpdater &DTU,
+ ICmpInst::Predicate Predicate) {
assert(!Comparisons.empty() && "merging zero comparisons");
LLVMContext &Context = NextCmpBlock->getContext();
const BCECmpBlock &FirstCmp = Comparisons[0];
@@ -623,7 +641,7 @@ static BasicBlock *mergeComparisons(ArrayRef<BCECmpBlock> Comparisons,
else
Rhs = FirstCmp.Rhs().LoadI->getPointerOperand();
- Value *IsEqual = nullptr;
+ Value *ICmpValue = nullptr;
LLVM_DEBUG(dbgs() << "Merging " << Comparisons.size() << " comparisons -> "
<< BB->getName() << "\n");
@@ -644,7 +662,7 @@ static BasicBlock *mergeComparisons(ArrayRef<BCECmpBlock> Comparisons,
Value *const RhsLoad =
Builder.CreateLoad(FirstCmp.Rhs().LoadI->getType(), Rhs);
// There are no blocks to merge, just do the comparison.
- IsEqual = Builder.CreateICmpEQ(LhsLoad, RhsLoad);
+ ICmpValue = Builder.CreateICmp(Predicate, LhsLoad, RhsLoad);
} else {
const unsigned TotalSizeBits = std::accumulate(
Comparisons.begin(), Comparisons.end(), 0u,
@@ -660,8 +678,8 @@ static BasicBlock *mergeComparisons(ArrayRef<BCECmpBlock> Comparisons,
Lhs, Rhs,
ConstantInt::get(Builder.getIntNTy(SizeTBits), TotalSizeBits / 8),
Builder, DL, &TLI);
- IsEqual = Builder.CreateICmpEQ(
- MemCmpCall, ConstantInt::get(Builder.getIntNTy(IntBits), 0));
+ ICmpValue = Builder.CreateICmp(
+ Predicate, MemCmpCall, ConstantInt::get(Builder.getIntNTy(IntBits), 0));
}
BasicBlock *const PhiBB = Phi.getParent();
@@ -669,11 +687,11 @@ static BasicBlock *mergeComparisons(ArrayRef<BCECmpBlock> Comparisons,
if (NextCmpBlock == PhiBB) {
// Continue to phi, passing it the comparison result.
Builder.CreateBr(PhiBB);
- Phi.addIncoming(IsEqual, BB);
+ Phi.addIncoming(ICmpValue, BB);
DTU.applyUpdates({{DominatorTree::Insert, BB, PhiBB}});
} else {
// Continue to next block if equal, exit to phi else.
- Builder.CreateCondBr(IsEqual, NextCmpBlock, PhiBB);
+ Builder.CreateCondBr(ICmpValue, NextCmpBlock, PhiBB);
Phi.addIncoming(ConstantInt::getFalse(Context), BB);
DTU.applyUpdates({{DominatorTree::Insert, BB, NextCmpBlock},
{DominatorTree::Insert, BB, PhiBB}});
@@ -691,9 +709,11 @@ bool BCECmpChain::simplify(const TargetLibraryInfo &TLI, AliasAnalysis &AA,
// so that the next block is always available to branch to.
BasicBlock *InsertBefore = EntryBlock_;
BasicBlock *NextCmpBlock = Phi_.getParent();
+ assert(Predicate_ != CmpInst::BAD_ICMP_PREDICATE &&
+ "Got the chain of comparisons");
for (const auto &Blocks : reverse(MergedBlocks_)) {
InsertBefore = NextCmpBlock = mergeComparisons(
- Blocks, InsertBefore, NextCmpBlock, Phi_, TLI, AA, DTU);
+ Blocks, InsertBefore, NextCmpBlock, Phi_, TLI, AA, DTU, Predicate_);
}
// Replace the original cmp chain with the new cmp chain by pointing all
diff --git a/llvm/test/Transforms/MergeICmps/X86/pr59740.ll b/llvm/test/Transforms/MergeICmps/X86/pr59740.ll
new file mode 100644
index 0000000000000..6b46325447a50
--- /dev/null
+++ b/llvm/test/Transforms/MergeICmps/X86/pr59740.ll
@@ -0,0 +1,86 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt < %s -passes=mergeicmps -verify-dom-info -S -mtriple=x86_64-unknown-unknown | FileCheck %s
+
+%struct.S = type { i8, i8, i8, i8 }
+
+define noundef i1 @_Z2neR1SS0_(ptr nocapture readonly align 1 dereferenceable(4) %s0, ptr nocapture readonly align 1 dereferenceable(4) %s1) {
+; CHECK-LABEL: @_Z2neR1SS0_(
+; CHECK-NEXT: "bb0+bb1+bb2+bb3":
+; CHECK-NEXT: [[MEMCMP:%.*]] = call i32 @memcmp(ptr [[S0:%.*]], ptr [[S1:%.*]], i64 4)
+; CHECK-NEXT: [[TMP0:%.*]] = icmp ne i32 [[MEMCMP]], 0
+; CHECK-NEXT: br label [[BB4:%.*]]
+; CHECK: bb4:
+; CHECK-NEXT: ret i1 [[TMP0]]
+;
+bb0:
+ %v0 = load i8, ptr %s0, align 1
+ %v1 = load i8, ptr %s1, align 1
+ %cmp0 = icmp eq i8 %v0, %v1
+ br i1 %cmp0, label %bb1, label %bb4
+
+bb1: ; preds = %bb0
+ %s2 = getelementptr inbounds %struct.S, ptr %s0, i64 0, i32 1
+ %v2 = load i8, ptr %s2, align 1
+ %s3 = getelementptr inbounds %struct.S, ptr %s1, i64 0, i32 1
+ %v3 = load i8, ptr %s3, align 1
+ %cmp1 = icmp eq i8 %v2, %v3
+ br i1 %cmp1, label %bb2, label %bb4
+
+bb2: ; preds = %bb1
+ %s4 = getelementptr inbounds %struct.S, ptr %s0, i64 0, i32 2
+ %v4 = load i8, ptr %s4, align 1
+ %s5 = getelementptr inbounds %struct.S, ptr %s1, i64 0, i32 2
+ %v5 = load i8, ptr %s5, align 1
+ %cmp2 = icmp eq i8 %v4, %v5
+ br i1 %cmp2, label %bb3, label %bb4
+
+bb3: ; preds = %bb2
+ %s6 = getelementptr inbounds %struct.S, ptr %s0, i64 0, i32 3
+ %v6 = load i8, ptr %s6, align 1
+ %s7 = getelementptr inbounds %struct.S, ptr %s1, i64 0, i32 3
+ %v7 = load i8, ptr %s7, align 1
+ %cmp3 = icmp ne i8 %v6, %v7
+ br label %bb4
+
+bb4: ; preds = %bb0, %bb1, %bb2, %bb3
+ %cmp = phi i1 [ true, %bb0 ], [ true, %bb1 ], [ true, %bb2 ], [ %cmp3, %bb3 ]
+ ret i1 %cmp
+}
+
+; Negative test: Incorrect const value in PHI node
+define noundef i1 @cmp_ne_incorrect_const(ptr nocapture readonly align 1 dereferenceable(4) %s0, ptr nocapture readonly align 1 dereferenceable(4) %s1) {
+; CHECK-LABEL: @cmp_ne_incorrect_const(
+; CHECK-NEXT: bb0:
+; CHECK-NEXT: [[V0:%.*]] = load i8, ptr [[S0:%.*]], align 1
+; CHECK-NEXT: [[V1:%.*]] = load i8, ptr [[S1:%.*]], align 1
+; CHECK-NEXT: [[CMP0:%.*]] = icmp eq i8 [[V0]], [[V1]]
+; CHECK-NEXT: br i1 [[CMP0]], label [[BB1:%.*]], label [[BB2:%.*]]
+; CHECK: bb1:
+; CHECK-NEXT: [[S6:%.*]] = getelementptr inbounds [[STRUCT_S:%.*]], ptr [[S0]], i64 0, i32 1
+; CHECK-NEXT: [[V6:%.*]] = load i8, ptr [[S6]], align 1
+; CHECK-NEXT: [[S7:%.*]] = getelementptr inbounds [[STRUCT_S]], ptr [[S1]], i64 0, i32 1
+; CHECK-NEXT: [[V7:%.*]] = load i8, ptr [[S7]], align 1
+; CHECK-NEXT: [[CMP3:%.*]] = icmp ne i8 [[V6]], [[V7]]
+; CHECK-NEXT: br label [[BB2]]
+; CHECK: bb2:
+; CHECK-NEXT: [[CMP:%.*]] = phi i1 [ false, [[BB0:%.*]] ], [ [[CMP3]], [[BB1]] ]
+; CHECK-NEXT: ret i1 [[CMP]]
+;
+bb0:
+ %v0 = load i8, ptr %s0, align 1
+ %v1 = load i8, ptr %s1, align 1
+ %cmp0 = icmp eq i8 %v0, %v1
+ br i1 %cmp0, label %bb1, label %bb2
+
+bb1: ; preds = %bb0
+ %s6 = getelementptr inbounds %struct.S, ptr %s0, i64 0, i32 1
+ %v6 = load i8, ptr %s6, align 1
+ %s7 = getelementptr inbounds %struct.S, ptr %s1, i64 0, i32 1
+ %v7 = load i8, ptr %s7, align 1
+ %cmp3 = icmp ne i8 %v6, %v7
+ br label %bb2
+
+bb2: ; preds = %bb0, %bb1
+ %cmp = phi i1 [ false, %bb0 ], [ %cmp3, %bb1 ]
+ ret i1 %cmp
+}
More information about the llvm-commits mailing list