[llvm] Deduplication of cyclic PHI nodes (PR #86662)
Marek Sedláček via llvm-commits
llvm-commits at lists.llvm.org
Tue Apr 2 04:07:00 PDT 2024
https://github.com/mark-sed updated https://github.com/llvm/llvm-project/pull/86662
>From 17e9cf8940f9f249882e2fcfeec7ddffc0ac17c3 Mon Sep 17 00:00:00 2001
From: Marek Sedlacek <msedlacek at azul.com>
Date: Mon, 11 Mar 2024 13:46:41 +0100
Subject: [PATCH] Added cyclic phis deduplication
---
llvm/lib/Transforms/Utils/Local.cpp | 78 ++++++++++++++-----
llvm/unittests/Transforms/Utils/LocalTest.cpp | 72 +++++++++++++++++
2 files changed, 132 insertions(+), 18 deletions(-)
diff --git a/llvm/lib/Transforms/Utils/Local.cpp b/llvm/lib/Transforms/Utils/Local.cpp
index c4a8843f2840b3..bd3ae5b39da710 100644
--- a/llvm/lib/Transforms/Utils/Local.cpp
+++ b/llvm/lib/Transforms/Utils/Local.cpp
@@ -111,6 +111,9 @@ static cl::opt<unsigned> PHICSENumPHISmallSize(
"When the basic block contains not more than this number of PHI nodes, "
"perform a (faster!) exhaustive search instead of set-driven one."));
+// Budget for matchPhiStructures(). Despite the "depth" in the name, this caps
+// the number of distinct PHI pairs recorded while matching, not the nesting
+// depth of the recursion.
+static cl::opt<unsigned> DeduplicatePhisMaxDepth(
+    "deduplicate-phi-max-depth", cl::Hidden, cl::init(8),
+    cl::desc("Maximum number of PHI node pairs examined when checking whether "
+             "two PHI nodes are structurally identical during PHI node "
+             "deduplication."));
+
// Max recursion depth for collectBitParts used when detecting bswap and
// bitreverse idioms.
static const unsigned BitPartRecursionMaxDepth = 48;
@@ -1344,6 +1347,41 @@ bool llvm::TryToSimplifyUncondBranchFromEmptyBlock(BasicBlock *BB,
return true;
}
+// Returns true if P1 and P2, which must live in the same basic block, are
+// structurally identical. For every incoming block, the two incoming values
+// must either be the same Value, or be PHIs that share a parent block and
+// (recursively) match as well.
+//
+// Cycles are handled optimistically. A pair is recorded in MatchingPhis before
+// its operands are inspected, so reaching the same pair again is treated as a
+// match. Any mismatch anywhere makes the whole query fail, so such an
+// assumption only survives if the entire structure matches.
+//
+// The search gives up (returns false) once more than DeduplicatePhisMaxDepth
+// distinct pairs have been recorded.
+static bool
+matchPhiStructures(PHINode *P1, PHINode *P2,
+                   SmallSet<std::pair<PHINode *, PHINode *>, 8> &MatchingPhis) {
+  assert(P1->getParent() == P2->getParent() && "Must have the same parent!");
+  if (P1->getType() != P2->getType())
+    return false;
+
+  // Canonicalize the pair so that (A, B) and (B, A) share one entry. This
+  // also makes the result independent of argument order.
+  if (P2 > P1)
+    std::swap(P1, P2);
+  std::pair<PHINode *, PHINode *> Key(P1, P2);
+
+  // A pair that is already being (or has been) compared is assumed to match.
+  // This is what terminates the recursion on cyclic PHI structures. Check it
+  // before the budget, so that closing a cycle never fails spuriously.
+  if (MatchingPhis.count(Key))
+    return true;
+
+  // Don't analyze too complex phi structures.
+  if (MatchingPhis.size() > DeduplicatePhisMaxDepth)
+    return false;
+  MatchingPhis.insert(Key);
+
+  SmallDenseMap<BasicBlock *, Value *, 8> IncomingValues;
+  for (unsigned I = 0, E = P1->getNumIncomingValues(); I != E; ++I)
+    IncomingValues[P1->getIncomingBlock(I)] = P1->getIncomingValue(I);
+
+  for (unsigned I = 0, E = P2->getNumIncomingValues(); I != E; ++I) {
+    // Use find() rather than operator[]. An incoming block that P1 lacks must
+    // be a mismatch; default-inserting a null Value would later be passed to
+    // dyn_cast, which does not accept null.
+    auto It = IncomingValues.find(P2->getIncomingBlock(I));
+    if (It == IncomingValues.end())
+      return false;
+    Value *V1 = It->second;
+    Value *V2 = P2->getIncomingValue(I);
+    if (V1 == V2)
+      continue;
+    auto *V1Phi = dyn_cast<PHINode>(V1);
+    auto *V2Phi = dyn_cast<PHINode>(V2);
+    if (V1Phi && V2Phi && V1Phi->getParent() == V2Phi->getParent() &&
+        matchPhiStructures(V1Phi, V2Phi, MatchingPhis))
+      continue;
+    return false;
+  }
+
+  return true;
+}
+
static bool
EliminateDuplicatePHINodesNaiveImpl(BasicBlock *BB,
SmallPtrSetImpl<PHINode *> &ToRemove) {
@@ -1364,7 +1402,9 @@ EliminateDuplicatePHINodesNaiveImpl(BasicBlock *BB,
for (auto J = I; PHINode *DuplicatePN = dyn_cast<PHINode>(J); ++J) {
if (ToRemove.contains(DuplicatePN))
continue;
- if (!DuplicatePN->isIdenticalToWhenDefined(PN))
+ SmallSet<std::pair<PHINode *, PHINode *>, 8> MatchingPhis;
+ if (!DuplicatePN->isIdenticalToWhenDefined(PN) &&
+ !matchPhiStructures(PN, DuplicatePN, MatchingPhis))
continue;
// A duplicate. Replace this PHI with the base PHI.
++NumPHICSEs;
@@ -1400,15 +1440,16 @@ EliminateDuplicatePHINodesSetBasedImpl(BasicBlock *BB,
return PN == getEmptyKey() || PN == getTombstoneKey();
}
- // WARNING: this logic must be kept in sync with
- // Instruction::isIdenticalToWhenDefined()!
+    // WARNING: this logic must be kept in sync with matchPhiStructures()!
+    // DenseMap requires that PHIs which compare equal also hash equally.
+    // Incoming values that are PHIs are deliberately left out of the hash,
+    // because two matching PHIs may refer to different (but structurally
+    // matching) PHIs on the same edge. The per-entry terms are summed, so the
+    // result does not depend on the order of the incoming entries.
static unsigned getHashValueImpl(PHINode *PN) {
- // Compute a hash value on the operands. Instcombine will likely have
- // sorted them, which helps expose duplicates, but we have to check all
- // the operands to be safe in case instcombine hasn't run.
- return static_cast<unsigned>(hash_combine(
- hash_combine_range(PN->value_op_begin(), PN->value_op_end()),
- hash_combine_range(PN->block_begin(), PN->block_end())));
+      unsigned Result = 0;
+      for (unsigned I = 0, E = PN->getNumIncomingValues(); I != E; ++I) {
+        Value *Incoming = PN->getIncomingValue(I);
+        if (isa<PHINode>(Incoming))
+          continue;
+        // Mix the (block, value) pair in unsigned arithmetic; right-shifting
+        // a negative signed integer is implementation-defined.
+        uintptr_t Key = reinterpret_cast<uintptr_t>(PN->getIncomingBlock(I)) ^
+                        reinterpret_cast<uintptr_t>(Incoming);
+        Result += 1 + static_cast<unsigned>(Key >> 3);
+      }
+      return Result;
}
static unsigned getHashValue(PHINode *PN) {
@@ -1423,18 +1464,19 @@ EliminateDuplicatePHINodesSetBasedImpl(BasicBlock *BB,
return getHashValueImpl(PN);
}
- static bool isEqualImpl(PHINode *LHS, PHINode *RHS) {
- if (isSentinel(LHS) || isSentinel(RHS))
- return LHS == RHS;
- return LHS->isIdenticalTo(RHS);
- }
-
static bool isEqual(PHINode *LHS, PHINode *RHS) {
// These comparisons are nontrivial, so assert that equality implies
// hash equality (DenseMap demands this as an invariant).
- bool Result = isEqualImpl(LHS, RHS);
- assert(!Result || (isSentinel(LHS) && LHS == RHS) ||
- getHashValueImpl(LHS) == getHashValueImpl(RHS));
+      // Empty/tombstone keys are not real PHIs; compare them by identity.
+      if (isSentinel(LHS) || isSentinel(RHS))
+        return LHS == RHS;
+      SmallSet<std::pair<PHINode *, PHINode *>, 8> MatchingPhis;
+      bool Result = matchPhiStructures(LHS, RHS, MatchingPhis);
+#ifndef NDEBUG
+      // matchPhiStructures() canonicalizes pair order internally, so swapping
+      // the arguments must not change the answer.
+      SmallSet<std::pair<PHINode *, PHINode *>, 8> MatchingPhis2;
+      bool Result2 = matchPhiStructures(RHS, LHS, MatchingPhis2);
+      assert(Result2 == Result && "Must be symmetric");
+#endif
+      assert((!Result || getHashValueImpl(LHS) == getHashValueImpl(RHS)) &&
+             "PHIs that compare equal must have equal hashes");
return Result;
}
};
diff --git a/llvm/unittests/Transforms/Utils/LocalTest.cpp b/llvm/unittests/Transforms/Utils/LocalTest.cpp
index 82257741045754..6ab6742a6d6417 100644
--- a/llvm/unittests/Transforms/Utils/LocalTest.cpp
+++ b/llvm/unittests/Transforms/Utils/LocalTest.cpp
@@ -114,6 +114,78 @@ static std::unique_ptr<Module> parseIR(LLVMContext &C, const char *IR) {
return Mod;
}
+TEST(Local, RemoveDuplicateCyclicPHINodes) {
+  LLVMContext C;
+
+  // bb1, bb3, bb5 and bb7 each hold a pair of PHIs that differ only in
+  // referring to their own "chain" around the loop. Neither pair is
+  // identical on its own; they can only be merged by matching the cyclic
+  // structure as a whole.
+  std::unique_ptr<Module> M = parseIR(C,
+                                      R"(
+      declare i64 @f()
+      declare void @f.2()
+
+      define void @f.3(i1 %flag1, i1 %flag2, i1 %flag3, i1 %flag4) {
+      bb:
+        %0 = call i64 @f()
+        br label %bb1
+
+      bb1:
+        %1 = phi i64 [ %7, %bb7 ], [ 5, %bb ]
+        %2 = phi i64 [ %8, %bb7 ], [ 5, %bb ]
+        br i1 %flag1, label %bb2, label %bb3
+
+      bb2:
+        br label %bb3
+
+      bb3:
+        %3 = phi i64 [ 1, %bb2 ], [ %1, %bb1 ]
+        %4 = phi i64 [ 1, %bb2 ], [ %2, %bb1 ]
+        br i1 %flag2, label %bb4, label %bb5
+
+      bb4:
+        br label %bb5
+
+      bb5:
+        %5 = phi i64 [ 2, %bb4 ], [ %3, %bb3 ]
+        %6 = phi i64 [ 2, %bb4 ], [ %4, %bb3 ]
+        br i1 %flag3, label %bb6, label %bb7
+
+      bb6:
+        br label %bb7
+
+      bb7:
+        %7 = phi i64 [ 3, %bb6 ], [ %5, %bb5 ]
+        %8 = phi i64 [ 3, %bb6 ], [ %6, %bb5 ]
+        br i1 %flag4, label %bb1, label %bb8
+
+      bb8:
+        call void @f.2()
+        ret void
+      }
+  )");
+
+  Function *F = M->getFunction("f.3");
+  ASSERT_TRUE(F);
+
+  // Only PHIs are erased, never blocks, so a plain range-for is safe here.
+  bool Changed = false;
+  for (BasicBlock &BB : *F)
+    Changed |= EliminateDuplicatePHINodes(&BB);
+  EXPECT_TRUE(Changed);
+
+  // No block may keep more than one PHI. Exactly one must survive in each of
+  // bb1, bb3, bb5 and bb7, which also guards against over-elimination.
+  unsigned TotalPHIs = 0;
+  for (BasicBlock &BB : *F) {
+    unsigned NumPHIs = 0;
+    for (Instruction &I : BB)
+      if (isa<PHINode>(I))
+        ++NumPHIs;
+    EXPECT_LE(NumPHIs, 1u) << "in block " << BB.getName().str();
+    TotalPHIs += NumPHIs;
+  }
+  EXPECT_EQ(TotalPHIs, 4u);
+}
+
TEST(Local, ReplaceDbgDeclare) {
LLVMContext C;
More information about the llvm-commits
mailing list