[llvm] [GVN] Refactor the LeaderTable structure into a properly encapsulated data structure (PR #88347)

Owen Anderson via llvm-commits llvm-commits at lists.llvm.org
Thu Apr 25 14:38:41 PDT 2024


https://github.com/resistor updated https://github.com/llvm/llvm-project/pull/88347

>From b85502acaa72444bba311a089327c2eb265cb64d Mon Sep 17 00:00:00 2001
From: Owen Anderson <resistor at mac.com>
Date: Wed, 10 Apr 2024 22:00:49 -0600
Subject: [PATCH 1/2] Refactor the LeaderTable structure in GVN into a properly
 encapsulated data structure.

Hide the details of the one-off linked list used to implement the leader lists
behind an iterator interface, and then use those iterators to reimplement a
number of traversals with standard algorithms and range-based for-loops.

No functional change intended.
---
 llvm/include/llvm/Transforms/Scalar/GVN.h | 111 ++++++++++--------
 llvm/lib/Transforms/Scalar/GVN.cpp        | 135 ++++++++++++++--------
 2 files changed, 149 insertions(+), 97 deletions(-)

diff --git a/llvm/include/llvm/Transforms/Scalar/GVN.h b/llvm/include/llvm/Transforms/Scalar/GVN.h
index 4ba9b74ccb005d..10396141f11204 100644
--- a/llvm/include/llvm/Transforms/Scalar/GVN.h
+++ b/llvm/include/llvm/Transforms/Scalar/GVN.h
@@ -232,13 +232,67 @@ class GVNPass : public PassInfoMixin<GVNPass> {
 
   /// A mapping from value numbers to lists of Value*'s that
   /// have that value number.  Use findLeader to query it.
-  struct LeaderTableEntry {
-    Value *Val;
-    const BasicBlock *BB;
-    LeaderTableEntry *Next;
+  class LeaderMap {
+  public:
+    struct LeaderTableEntry {
+      Value *Val;
+      const BasicBlock *BB;
+    };
+
+  private:
+    struct LeaderListNode {
+      LeaderTableEntry Entry;
+      LeaderListNode *Next;
+    };
+    DenseMap<uint32_t, LeaderListNode> NumToLeaders;
+    BumpPtrAllocator TableAllocator;
+
+  public:
+    class leader_iterator {
+      const LeaderListNode *current;
+
+    public:
+      using iterator_category = std::forward_iterator_tag;
+      using value_type = const LeaderTableEntry;
+      using difference_type = std::ptrdiff_t;
+      using pointer = value_type *;
+      using reference = value_type &;
+
+      leader_iterator(const LeaderListNode *ptr) : current(ptr) {}
+      leader_iterator &operator++() {
+        assert(current && "Dereferenced end of leader list!");
+        current = current->Next;
+        return *this;
+      }
+      bool operator==(const leader_iterator &other) const {
+        return current == other.current;
+      }
+      bool operator!=(const leader_iterator &other) const {
+        return current != other.current;
+      }
+      reference operator*() const { return current->Entry; }
+    };
+
+    iterator_range<leader_iterator> getLeaders(uint32_t N) {
+      auto I = NumToLeaders.find(N);
+      // A head node with a null Val is an empty list (see erase()).
+      if (I == NumToLeaders.end() || !I->second.Entry.Val)
+        return iterator_range(leader_iterator(nullptr),
+                              leader_iterator(nullptr));
+
+      return iterator_range(leader_iterator(&I->second),
+                            leader_iterator(nullptr));
+    }
+
+    void insert(uint32_t N, Value *V, const BasicBlock *BB);
+    void erase(uint32_t N, Instruction *I, BasicBlock *BB);
+    void verifyRemoved(const Value *Inst) const;
+    void clear() {
+      NumToLeaders.clear();
+      TableAllocator.Reset();
+    }
   };
-  DenseMap<uint32_t, LeaderTableEntry> LeaderTable;
-  BumpPtrAllocator TableAllocator;
+  LeaderMap LeaderTable;
 
   // Block-local map of equivalent values to their leader, does not
   // propagate to any successors. Entries added mid-block are applied
@@ -264,51 +318,6 @@ class GVNPass : public PassInfoMixin<GVNPass> {
                MemoryDependenceResults *RunMD, LoopInfo &LI,
                OptimizationRemarkEmitter *ORE, MemorySSA *MSSA = nullptr);
 
-  /// Push a new Value to the LeaderTable onto the list for its value number.
-  void addToLeaderTable(uint32_t N, Value *V, const BasicBlock *BB) {
-    LeaderTableEntry &Curr = LeaderTable[N];
-    if (!Curr.Val) {
-      Curr.Val = V;
-      Curr.BB = BB;
-      return;
-    }
-
-    LeaderTableEntry *Node = TableAllocator.Allocate<LeaderTableEntry>();
-    Node->Val = V;
-    Node->BB = BB;
-    Node->Next = Curr.Next;
-    Curr.Next = Node;
-  }
-
-  /// Scan the list of values corresponding to a given
-  /// value number, and remove the given instruction if encountered.
-  void removeFromLeaderTable(uint32_t N, Instruction *I, BasicBlock *BB) {
-    LeaderTableEntry *Prev = nullptr;
-    LeaderTableEntry *Curr = &LeaderTable[N];
-
-    while (Curr && (Curr->Val != I || Curr->BB != BB)) {
-      Prev = Curr;
-      Curr = Curr->Next;
-    }
-
-    if (!Curr)
-      return;
-
-    if (Prev) {
-      Prev->Next = Curr->Next;
-    } else {
-      if (!Curr->Next) {
-        Curr->Val = nullptr;
-        Curr->BB = nullptr;
-      } else {
-        LeaderTableEntry *Next = Curr->Next;
-        Curr->Val = Next->Val;
-        Curr->BB = Next->BB;
-        Curr->Next = Next->Next;
-      }
-    }
-  }
-
   // List of critical edges to be split between iterations.
   SmallVector<std::pair<Instruction *, unsigned>, 4> toSplit;
 
diff --git a/llvm/lib/Transforms/Scalar/GVN.cpp b/llvm/lib/Transforms/Scalar/GVN.cpp
index 67fb2a5da3bb71..586dbdfe977f5a 100644
--- a/llvm/lib/Transforms/Scalar/GVN.cpp
+++ b/llvm/lib/Transforms/Scalar/GVN.cpp
@@ -724,6 +724,68 @@ void GVNPass::ValueTable::verifyRemoved(const Value *V) const {
          "Inst still occurs in value numbering map!");
 }
 
+//===----------------------------------------------------------------------===//
+//                     LeaderMap External Functions
+//===----------------------------------------------------------------------===//
+
+/// Push a new Value to the LeaderTable onto the list for its value number.
+void GVNPass::LeaderMap::insert(uint32_t N, Value *V, const BasicBlock *BB) {
+  LeaderListNode &Curr = NumToLeaders[N];
+  if (!Curr.Entry.Val) {
+    Curr.Entry.Val = V;
+    Curr.Entry.BB = BB;
+    return;
+  }
+
+  LeaderListNode *Node = TableAllocator.Allocate<LeaderListNode>();
+  Node->Entry.Val = V;
+  Node->Entry.BB = BB;
+  Node->Next = Curr.Next;
+  Curr.Next = Node;
+}
+
+/// Scan the list of values corresponding to a given
+/// value number, and remove the given instruction if encountered.
+void GVNPass::LeaderMap::erase(uint32_t N, Instruction *I, BasicBlock *BB) {
+  LeaderListNode *Prev = nullptr;
+  LeaderListNode *Curr = &NumToLeaders[N];
+
+  while (Curr && (Curr->Entry.Val != I || Curr->Entry.BB != BB)) {
+    Prev = Curr;
+    Curr = Curr->Next;
+  }
+
+  if (!Curr)
+    return;
+
+  if (Prev) {
+    Prev->Next = Curr->Next;
+  } else {
+    if (!Curr->Next) {
+      Curr->Entry.Val = nullptr;
+      Curr->Entry.BB = nullptr;
+    } else {
+      LeaderListNode *Next = Curr->Next;
+      Curr->Entry.Val = Next->Entry.Val;
+      Curr->Entry.BB = Next->Entry.BB;
+      Curr->Next = Next->Next;
+    }
+  }
+}
+
+void GVNPass::LeaderMap::verifyRemoved(const Value *V) const {
+  // Walk through the value number scope to make sure the instruction isn't
+  // ferreted away in it. The none_of walk starts at each list's head node,
+  // so it already covers the head entry; no separate head check is needed.
+  for (const auto &I : NumToLeaders) {
+    (void)I; // Only referenced inside assert(); silences NDEBUG warnings.
+    assert(
+        std::none_of(leader_iterator(&I.second), leader_iterator(nullptr),
+                     [=](const LeaderTableEntry &E) { return E.Val == V; }) &&
+        "Inst still in value numbering scope!");
+  }
+}
+
 //===----------------------------------------------------------------------===//
 //                                GVN Pass
 //===----------------------------------------------------------------------===//
@@ -1466,7 +1528,7 @@ void GVNPass::eliminatePartiallyRedundantLoad(
         OldLoad->replaceAllUsesWith(NewLoad);
         replaceValuesPerBlockEntry(ValuesPerBlock, OldLoad, NewLoad);
         if (uint32_t ValNo = VN.lookup(OldLoad, false))
-          removeFromLeaderTable(ValNo, OldLoad, OldLoad->getParent());
+          LeaderTable.erase(ValNo, OldLoad, OldLoad->getParent());
         VN.erase(OldLoad);
         removeInstruction(OldLoad);
       }
@@ -2203,10 +2265,10 @@ GVNPass::ValueTable::assignExpNewValueNum(Expression &Exp) {
 /// defined in \p BB.
 bool GVNPass::ValueTable::areAllValsInBB(uint32_t Num, const BasicBlock *BB,
                                          GVNPass &Gvn) {
-  LeaderTableEntry *Vals = &Gvn.LeaderTable[Num];
-  while (Vals && Vals->BB == BB)
-    Vals = Vals->Next;
-  return !Vals;
+  auto I = Gvn.LeaderTable.getLeaders(Num);
+  return std::all_of(
+      I.begin(), I.end(),
+      [=](const LeaderMap::LeaderTableEntry &L) { return L.BB == BB; });
 }
 
 /// Wrap phiTranslateImpl to provide caching functionality.
@@ -2228,12 +2290,11 @@ bool GVNPass::ValueTable::areCallValsEqual(uint32_t Num, uint32_t NewNum,
                                            const BasicBlock *PhiBlock,
                                            GVNPass &Gvn) {
   CallInst *Call = nullptr;
-  LeaderTableEntry *Vals = &Gvn.LeaderTable[Num];
-  while (Vals) {
-    Call = dyn_cast<CallInst>(Vals->Val);
+  auto Leaders = Gvn.LeaderTable.getLeaders(Num);
+  for (auto Entry : Leaders) {
+    Call = dyn_cast<CallInst>(Entry.Val);
     if (Call && Call->getParent() == PhiBlock)
       break;
-    Vals = Vals->Next;
   }
 
   if (AA->doesNotAccessMemory(Call))
@@ -2326,23 +2387,17 @@ void GVNPass::ValueTable::eraseTranslateCacheEntry(
 // question.  This is fast because dominator tree queries consist of only
 // a few comparisons of DFS numbers.
 Value *GVNPass::findLeader(const BasicBlock *BB, uint32_t num) {
-  LeaderTableEntry Vals = LeaderTable[num];
-  if (!Vals.Val) return nullptr;
+  auto Leaders = LeaderTable.getLeaders(num);
+  if (Leaders.empty())
+    return nullptr;
 
   Value *Val = nullptr;
-  if (DT->dominates(Vals.BB, BB)) {
-    Val = Vals.Val;
-    if (isa<Constant>(Val)) return Val;
-  }
-
-  LeaderTableEntry* Next = Vals.Next;
-  while (Next) {
-    if (DT->dominates(Next->BB, BB)) {
-      if (isa<Constant>(Next->Val)) return Next->Val;
-      if (!Val) Val = Next->Val;
+  for (auto Entry : Leaders) {
+    if (DT->dominates(Entry.BB, BB)) {
+      Val = Entry.Val;
+      if (isa<Constant>(Val))
+        return Val;
     }
-
-    Next = Next->Next;
   }
 
   return Val;
@@ -2446,7 +2501,7 @@ bool GVNPass::propagateEquality(Value *LHS, Value *RHS,
     // The leader table only tracks basic blocks, not edges. Only add to if we
     // have the simple case where the edge dominates the end.
     if (RootDominatesEnd && !isa<Instruction>(RHS))
-      addToLeaderTable(LVN, RHS, Root.getEnd());
+      LeaderTable.insert(LVN, RHS, Root.getEnd());
 
     // Replace all occurrences of 'LHS' with 'RHS' everywhere in the scope.  As
     // LHS always has at least one use that is not dominated by Root, this will
@@ -2532,7 +2587,7 @@ bool GVNPass::propagateEquality(Value *LHS, Value *RHS,
       // The leader table only tracks basic blocks, not edges. Only add to if we
       // have the simple case where the edge dominates the end.
       if (RootDominatesEnd)
-        addToLeaderTable(Num, NotVal, Root.getEnd());
+        LeaderTable.insert(Num, NotVal, Root.getEnd());
 
       continue;
     }
@@ -2582,7 +2637,7 @@ bool GVNPass::processInstruction(Instruction *I) {
       return true;
 
     unsigned Num = VN.lookupOrAdd(Load);
-    addToLeaderTable(Num, Load, Load->getParent());
+    LeaderTable.insert(Num, Load, Load->getParent());
     return false;
   }
 
@@ -2650,7 +2705,7 @@ bool GVNPass::processInstruction(Instruction *I) {
   // Allocations are always uniquely numbered, so we can save time and memory
   // by fast failing them.
   if (isa<AllocaInst>(I) || I->isTerminator() || isa<PHINode>(I)) {
-    addToLeaderTable(Num, I, I->getParent());
+    LeaderTable.insert(Num, I, I->getParent());
     return false;
   }
 
@@ -2658,7 +2713,7 @@ bool GVNPass::processInstruction(Instruction *I) {
   // need to do a lookup to see if the number already exists
   // somewhere in the domtree: it can't!
   if (Num >= NextNum) {
-    addToLeaderTable(Num, I, I->getParent());
+    LeaderTable.insert(Num, I, I->getParent());
     return false;
   }
 
@@ -2667,7 +2722,7 @@ bool GVNPass::processInstruction(Instruction *I) {
   Value *Repl = findLeader(I->getParent(), Num);
   if (!Repl) {
     // Failure, just remember this instance for future use.
-    addToLeaderTable(Num, I, I->getParent());
+    LeaderTable.insert(Num, I, I->getParent());
     return false;
   }
 
@@ -2861,7 +2916,7 @@ bool GVNPass::performScalarPREInsertion(Instruction *Instr, BasicBlock *Pred,
   VN.add(Instr, Num);
 
   // Update the availability map to include the new instruction.
-  addToLeaderTable(Num, Instr, Pred);
+  LeaderTable.insert(Num, Instr, Pred);
   return true;
 }
 
@@ -3012,13 +3067,13 @@ bool GVNPass::performScalarPRE(Instruction *CurInst) {
   // After creating a new PHI for ValNo, the phi translate result for ValNo will
   // be changed, so erase the related stale entries in phi translate cache.
   VN.eraseTranslateCacheEntry(ValNo, *CurrentBlock);
-  addToLeaderTable(ValNo, Phi, CurrentBlock);
+  LeaderTable.insert(ValNo, Phi, CurrentBlock);
   Phi->setDebugLoc(CurInst->getDebugLoc());
   CurInst->replaceAllUsesWith(Phi);
   if (MD && Phi->getType()->isPtrOrPtrVectorTy())
     MD->invalidateCachedPointerInfo(Phi);
   VN.erase(CurInst);
-  removeFromLeaderTable(ValNo, CurInst, CurrentBlock);
+  LeaderTable.erase(ValNo, CurInst, CurrentBlock);
 
   LLVM_DEBUG(dbgs() << "GVN PRE removed: " << *CurInst << '\n');
   removeInstruction(CurInst);
@@ -3112,7 +3167,6 @@ void GVNPass::cleanupGlobalSets() {
   VN.clear();
   LeaderTable.clear();
   BlockRPONumber.clear();
-  TableAllocator.Reset();
   ICF->clear();
   InvalidBlockRPONumbers = true;
 }
@@ -3132,18 +3186,7 @@ void GVNPass::removeInstruction(Instruction *I) {
 /// internal data structures.
 void GVNPass::verifyRemoved(const Instruction *Inst) const {
   VN.verifyRemoved(Inst);
-
-  // Walk through the value number scope to make sure the instruction isn't
-  // ferreted away in it.
-  for (const auto &I : LeaderTable) {
-    const LeaderTableEntry *Node = &I.second;
-    assert(Node->Val != Inst && "Inst still in value numbering scope!");
-
-    while (Node->Next) {
-      Node = Node->Next;
-      assert(Node->Val != Inst && "Inst still in value numbering scope!");
-    }
-  }
+  LeaderTable.verifyRemoved(Inst);
 }
 
 /// BB is declared dead, which implied other blocks become dead as well. This
@@ -3270,7 +3313,7 @@ void GVNPass::assignValNumForDeadCode() {
   for (BasicBlock *BB : DeadBlocks) {
     for (Instruction &Inst : *BB) {
       unsigned ValNum = VN.lookupOrAdd(&Inst);
-      addToLeaderTable(ValNum, &Inst, BB);
+      LeaderTable.insert(ValNum, &Inst, BB);
     }
   }
 }

>From 3c2eb407bb8ac257263a52c26648b1bb2670dfd3 Mon Sep 17 00:00:00 2001
From: Owen Anderson <resistor at mac.com>
Date: Thu, 25 Apr 2024 01:07:53 -0600
Subject: [PATCH 2/2] Update for review feedback.

---
 llvm/include/llvm/Transforms/Scalar/GVN.h | 16 ++++++++--------
 llvm/lib/Transforms/Scalar/GVN.cpp        | 12 ++++++------
 2 files changed, 14 insertions(+), 14 deletions(-)

diff --git a/llvm/include/llvm/Transforms/Scalar/GVN.h b/llvm/include/llvm/Transforms/Scalar/GVN.h
index 10396141f11204..2d87a2f6c85154 100644
--- a/llvm/include/llvm/Transforms/Scalar/GVN.h
+++ b/llvm/include/llvm/Transforms/Scalar/GVN.h
@@ -249,7 +249,7 @@ class GVNPass : public PassInfoMixin<GVNPass> {
 
   public:
     class leader_iterator {
-      const LeaderListNode *current;
+      const LeaderListNode *Current;
 
     public:
       using iterator_category = std::forward_iterator_tag;
@@ -258,19 +258,19 @@ class GVNPass : public PassInfoMixin<GVNPass> {
       using pointer = value_type *;
       using reference = value_type &;
 
-      leader_iterator(const LeaderListNode *ptr) : current(ptr) {}
+      leader_iterator(const LeaderListNode *C) : Current(C) {}
       leader_iterator &operator++() {
-        assert(current && "Dereferenced end of leader list!");
-        current = current->Next;
+        assert(Current && "Incremented past the end of the leader list!");
+        Current = Current->Next;
         return *this;
       }
       bool operator==(const leader_iterator &other) const {
-        return current == other.current;
+        return Current == other.Current;
       }
       bool operator!=(const leader_iterator &other) const {
-        return current != other.current;
+        return Current != other.Current;
       }
-      reference operator*() const { return current->Entry; }
+      reference operator*() const { return Current->Entry; }
     };
 
     iterator_range<leader_iterator> getLeaders(uint32_t N) {
@@ -285,7 +285,7 @@ class GVNPass : public PassInfoMixin<GVNPass> {
     }
 
     void insert(uint32_t N, Value *V, const BasicBlock *BB);
-    void erase(uint32_t N, Instruction *I, BasicBlock *BB);
+    void erase(uint32_t N, Instruction *I, const BasicBlock *BB);
     void verifyRemoved(const Value *Inst) const;
     void clear() {
       NumToLeaders.clear();
diff --git a/llvm/lib/Transforms/Scalar/GVN.cpp b/llvm/lib/Transforms/Scalar/GVN.cpp
index 586dbdfe977f5a..d73cd2f003b494 100644
--- a/llvm/lib/Transforms/Scalar/GVN.cpp
+++ b/llvm/lib/Transforms/Scalar/GVN.cpp
@@ -746,7 +746,8 @@ void GVNPass::LeaderMap::insert(uint32_t N, Value *V, const BasicBlock *BB) {
 
 /// Scan the list of values corresponding to a given
 /// value number, and remove the given instruction if encountered.
-void GVNPass::LeaderMap::erase(uint32_t N, Instruction *I, BasicBlock *BB) {
+void GVNPass::LeaderMap::erase(uint32_t N, Instruction *I,
+                               const BasicBlock *BB) {
   LeaderListNode *Prev = nullptr;
   LeaderListNode *Curr = &NumToLeaders[N];
 
@@ -2266,9 +2267,10 @@ GVNPass::ValueTable::assignExpNewValueNum(Expression &Exp) {
 bool GVNPass::ValueTable::areAllValsInBB(uint32_t Num, const BasicBlock *BB,
                                          GVNPass &Gvn) {
   auto I = Gvn.LeaderTable.getLeaders(Num);
-  return std::all_of(
-      I.begin(), I.end(),
-      [=](const LeaderMap::LeaderTableEntry &L) { return L.BB == BB; });
+  // With no leaders, answer false as the pre-refactoring list walk did.
+  return !I.empty() && all_of(I, [=](const LeaderMap::LeaderTableEntry &L) {
+           return L.BB == BB;
+         });
 }
 
 /// Wrap phiTranslateImpl to provide caching functionality.
@@ -2291,7 +2291,7 @@ bool GVNPass::ValueTable::areCallValsEqual(uint32_t Num, uint32_t NewNum,
                                            GVNPass &Gvn) {
   CallInst *Call = nullptr;
   auto Leaders = Gvn.LeaderTable.getLeaders(Num);
-  for (auto Entry : Leaders) {
+  for (const auto &Entry : Leaders) {
     Call = dyn_cast<CallInst>(Entry.Val);
     if (Call && Call->getParent() == PhiBlock)
       break;
@@ -2392,9 +2392,12 @@ Value *GVNPass::findLeader(const BasicBlock *BB, uint32_t num) {
     return nullptr;
 
   Value *Val = nullptr;
-  for (auto Entry : Leaders) {
+  for (const auto &Entry : Leaders) {
     if (DT->dominates(Entry.BB, BB)) {
-      Val = Entry.Val;
-      if (isa<Constant>(Val))
-        return Val;
+      // A dominating constant wins outright; otherwise keep the *first*
+      // dominating leader, as the pre-refactoring loop did.
+      if (isa<Constant>(Entry.Val))
+        return Entry.Val;
+      if (!Val)
+        Val = Entry.Val;
     }



More information about the llvm-commits mailing list