[llvm] Deduplication of cyclic PHI nodes (PR #86662)

Marek Sedláček via llvm-commits llvm-commits at lists.llvm.org
Tue Mar 26 06:56:09 PDT 2024


https://github.com/mark-sed created https://github.com/llvm/llvm-project/pull/86662

The EliminateDuplicatePHINodes function used in the simplifycfg pass does not handle cyclic PHIs. This patch adds detection of such PHIs; the search can be bounded with the deduplicate-phi-max-depth option.

This is a continuation of a PR from https://reviews.llvm.org/D153014. There was a request by @nikic to benchmark the compile time for this change, which was done here: https://llvm-compile-time-tracker.com/index.php?config=Overview&stat=instructions%3Au&remote=mark-sed

However, as I have not worked with this tool or with compile-time benchmarks before, I am not sure how to interpret these results.

>From 17e9cf8940f9f249882e2fcfeec7ddffc0ac17c3 Mon Sep 17 00:00:00 2001
From: Marek Sedlacek <msedlacek at azul.com>
Date: Mon, 11 Mar 2024 13:46:41 +0100
Subject: [PATCH] Added cyclic phis deduplication

---
 llvm/lib/Transforms/Utils/Local.cpp           | 78 ++++++++++++++-----
 llvm/unittests/Transforms/Utils/LocalTest.cpp | 72 +++++++++++++++++
 2 files changed, 132 insertions(+), 18 deletions(-)

diff --git a/llvm/lib/Transforms/Utils/Local.cpp b/llvm/lib/Transforms/Utils/Local.cpp
index c4a8843f2840b3..bd3ae5b39da710 100644
--- a/llvm/lib/Transforms/Utils/Local.cpp
+++ b/llvm/lib/Transforms/Utils/Local.cpp
@@ -111,6 +111,9 @@ static cl::opt<unsigned> PHICSENumPHISmallSize(
         "When the basic block contains not more than this number of PHI nodes, "
         "perform a (faster!) exhaustive search instead of set-driven one."));
 
+static cl::opt<unsigned> DeduplicatePhisMaxDepth("deduplicate-phi-max-depth",
+    cl::Hidden, cl::init(8), cl::desc("Max PHI pairs matched during PHI CSE"));
+
 // Max recursion depth for collectBitParts used when detecting bswap and
 // bitreverse idioms.
 static const unsigned BitPartRecursionMaxDepth = 48;
@@ -1344,6 +1347,41 @@ bool llvm::TryToSimplifyUncondBranchFromEmptyBlock(BasicBlock *BB,
   return true;
 }
 
+/// Return true if PHIs \p P1 and \p P2 of the same block agree on every
+/// incoming block, looking recursively through pairs of PHIs of a common
+/// block. \p MatchingPhis holds pairs assumed equal so far, so cycles end.
+static bool
+matchPhiStructures(PHINode *P1, PHINode *P2,
+                   SmallSet<std::pair<PHINode *, PHINode *>, 8> &MatchingPhis) {
+  assert(P1->getParent() == P2->getParent() && "Must have the same parent!");
+  if (P1->getType() != P2->getType())
+    return false;
+  // Canonicalize the order so (P1, P2) and (P2, P1) share one set entry.
+  if (P2 > P1)
+    std::swap(P1, P2);
+  if (MatchingPhis.contains({P1, P2}))
+    return true; // Already assumed equal in this query.
+  // Don't analyze too complex phi structures.
+  if (MatchingPhis.size() > DeduplicatePhisMaxDepth)
+    return false;
+  MatchingPhis.insert({P1, P2});
+  SmallDenseMap<BasicBlock *, Value *, 8> IncomingValues;
+  for (unsigned I = 0, E = P1->getNumIncomingValues(); I != E; ++I)
+    IncomingValues[P1->getIncomingBlock(I)] = P1->getIncomingValue(I);
+  for (unsigned I = 0, E = P2->getNumIncomingValues(); I != E; ++I) {
+    Value *V1 = IncomingValues.lookup(P2->getIncomingBlock(I));
+    Value *V2 = P2->getIncomingValue(I);
+    if (V1 == V2)
+      continue;
+    auto *Phi1 = dyn_cast_or_null<PHINode>(V1);
+    auto *Phi2 = dyn_cast<PHINode>(V2);
+    if (!Phi1 || !Phi2 || Phi1->getParent() != Phi2->getParent() ||
+        !matchPhiStructures(Phi1, Phi2, MatchingPhis))
+      return false;
+  }
+  return true;
+}
+
 static bool
 EliminateDuplicatePHINodesNaiveImpl(BasicBlock *BB,
                                     SmallPtrSetImpl<PHINode *> &ToRemove) {
@@ -1364,7 +1402,9 @@ EliminateDuplicatePHINodesNaiveImpl(BasicBlock *BB,
     for (auto J = I; PHINode *DuplicatePN = dyn_cast<PHINode>(J); ++J) {
       if (ToRemove.contains(DuplicatePN))
         continue;
-      if (!DuplicatePN->isIdenticalToWhenDefined(PN))
+      SmallSet<std::pair<PHINode *, PHINode *>, 8> AssumedEqual;
+      if (!(DuplicatePN->isIdenticalToWhenDefined(PN) ||
+            matchPhiStructures(PN, DuplicatePN, AssumedEqual)))
         continue;
       // A duplicate. Replace this PHI with the base PHI.
       ++NumPHICSEs;
@@ -1400,15 +1440,16 @@ EliminateDuplicatePHINodesSetBasedImpl(BasicBlock *BB,
       return PN == getEmptyKey() || PN == getTombstoneKey();
     }
 
-    // WARNING: this logic must be kept in sync with
-    //          Instruction::isIdenticalToWhenDefined()!
     static unsigned getHashValueImpl(PHINode *PN) {
-      // Compute a hash value on the operands. Instcombine will likely have
-      // sorted them, which helps expose duplicates, but we have to check all
-      // the operands to be safe in case instcombine hasn't run.
-      return static_cast<unsigned>(hash_combine(
-          hash_combine_range(PN->value_op_begin(), PN->value_op_end()),
-          hash_combine_range(PN->block_begin(), PN->block_end())));
+      // Must agree with matchPhiStructures(): order-free, PHIs not hashed.
+      unsigned Result = 0;
+      for (unsigned I = 0, E = PN->getNumIncomingValues(); I != E; ++I) {
+        Value *Incoming = PN->getIncomingValue(I);
+        if (!isa<PHINode>(Incoming))
+          Result += 1 + ((intptr_t(PN->getIncomingBlock(I)) ^
+                          intptr_t(Incoming)) >> 3);
+      }
+      return Result;
     }
 
     static unsigned getHashValue(PHINode *PN) {
@@ -1423,18 +1464,19 @@ EliminateDuplicatePHINodesSetBasedImpl(BasicBlock *BB,
       return getHashValueImpl(PN);
     }
 
-    static bool isEqualImpl(PHINode *LHS, PHINode *RHS) {
-      if (isSentinel(LHS) || isSentinel(RHS))
-        return LHS == RHS;
-      return LHS->isIdenticalTo(RHS);
-    }
-
     static bool isEqual(PHINode *LHS, PHINode *RHS) {
       // These comparisons are nontrivial, so assert that equality implies
       // hash equality (DenseMap demands this as an invariant).
-      bool Result = isEqualImpl(LHS, RHS);
-      assert(!Result || (isSentinel(LHS) && LHS == RHS) ||
-             getHashValueImpl(LHS) == getHashValueImpl(RHS));
+      if (isSentinel(LHS) || isSentinel(RHS))
+        return LHS == RHS;
+      SmallSet<std::pair<PHINode *, PHINode *>, 8> MatchingPhis;
+      bool Result = matchPhiStructures(LHS, RHS, MatchingPhis);
+#ifndef NDEBUG
+      SmallSet<std::pair<PHINode *, PHINode *>, 8> MatchingPhis2;
+      bool Result2 = matchPhiStructures(RHS, LHS, MatchingPhis2);
+      assert(Result2 == Result && "Must be symmetric");
+      assert(!Result || getHashValueImpl(LHS) == getHashValueImpl(RHS));
+#endif
       return Result;
     }
   };
diff --git a/llvm/unittests/Transforms/Utils/LocalTest.cpp b/llvm/unittests/Transforms/Utils/LocalTest.cpp
index 82257741045754..6ab6742a6d6417 100644
--- a/llvm/unittests/Transforms/Utils/LocalTest.cpp
+++ b/llvm/unittests/Transforms/Utils/LocalTest.cpp
@@ -114,6 +114,78 @@ static std::unique_ptr<Module> parseIR(LLVMContext &C, const char *IR) {
   return Mod;
 }
 
+TEST(Local, RemoveDuplicateCyclicPHINodes) {
+  LLVMContext C;
+
+  std::unique_ptr<Module> M = parseIR(C,
+                                      R"(
+      declare i64 @f()
+      declare void @f.2()
+
+      define void @f.3(i1 %flag1, i1 %flag2, i1 %flag3, i1 %flag4) {
+      bb:
+        %0 = call i64 @f()
+        br label %bb1
+
+      bb1:
+        %1 = phi i64 [ %7, %bb7 ], [ 5, %bb ]
+        %2 = phi i64 [ %8, %bb7 ], [ 5, %bb ]
+        br i1 %flag1, label %bb2, label %bb3
+
+      bb2:
+        br label %bb3
+
+      bb3:
+        %3 = phi i64 [ 1, %bb2 ], [ %1, %bb1 ]
+        %4 = phi i64 [ 1, %bb2 ], [ %2, %bb1 ]
+        br i1 %flag2, label %bb4, label %bb5
+
+      bb4:
+        br label %bb5
+
+      bb5:
+        %5 = phi i64 [ 2, %bb4 ], [ %3, %bb3 ]
+        %6 = phi i64 [ 2, %bb4 ], [ %4, %bb3 ]
+        br i1 %flag3, label %bb6, label %bb7
+
+      bb6:
+        br label %bb7
+
+      bb7:
+        %7 = phi i64 [ 3, %bb6 ], [ %5, %bb5 ]
+        %8 = phi i64 [ 3, %bb6 ], [ %6, %bb5 ]
+        br i1 %flag4, label %bb1, label %bb8
+
+      bb8:
+        call void @f.2()
+        ret void
+      }
+      )");
+
+  auto *GV = M->getNamedValue("f.3");
+  ASSERT_TRUE(GV);
+  auto *F = dyn_cast<Function>(GV);
+  ASSERT_TRUE(F);
+
+  // bb1, bb3, bb5 and bb7 each hold two PHIs that differ only in which PHI
+  // of the neighbouring pair they use, so each pair is a cyclic duplicate.
+  bool Changed = false;
+  for (BasicBlock &BB : *F)
+    Changed |= EliminateDuplicatePHINodes(&BB);
+  EXPECT_TRUE(Changed);
+
+  // No block may keep more than one PHI, and each of the four blocks above
+  // must keep exactly one.
+  unsigned TotalPHIs = 0;
+  for (BasicBlock &BB : *F) {
+    unsigned PHICount =
+        count_if(BB, [](const Instruction &I) { return isa<PHINode>(I); });
+    EXPECT_LE(PHICount, 1u) << "duplicate PHIs left in " << BB.getName().str();
+    TotalPHIs += PHICount;
+  }
+  EXPECT_EQ(TotalPHIs, 4u);
+}
+
 TEST(Local, ReplaceDbgDeclare) {
   LLVMContext C;
 



More information about the llvm-commits mailing list