[llvm] [LoopUnroll] Add CSE to remove redundant loads after unrolling. (PR #83860)

Florian Hahn via llvm-commits llvm-commits at lists.llvm.org
Wed May 1 02:30:05 PDT 2024


================
@@ -209,13 +212,142 @@ static bool isEpilogProfitable(Loop *L) {
   return false;
 }
 
+/// Value type stored in the ScopedHashTable of available loads (keyed by the
+/// SCEV of the pointer operand): an instruction paired with a generation
+/// number.
+struct LoadValue {
+  /// Instruction whose value may be reused. nullptr means "no entry";
+  /// getMatchingValue returns nullptr for such a LoadValue.
+  Instruction *DefI = nullptr;
+  /// Generation number stored alongside DefI; getMatchingValue compares it
+  /// against the current generation to decide whether a MemorySSA query is
+  /// needed.
+  unsigned Generation = 0;
+  LoadValue() = default;
+  LoadValue(Instruction *Inst, unsigned Generation)
+      : DefI(Inst), Generation(Generation) {}
+};
+
+/// One entry of the explicit work stack used by loadCSE to walk the dominator
+/// tree. Bundles a dominator-tree node with the iteration state over its
+/// children, the generation numbers for the node, and a scope object
+/// constructed on the shared table of available loads.
+class StackNode {
+  // Scope object constructed from the AvailableLoads table passed to the
+  // constructor; it lives exactly as long as this StackNode.
+  ScopedHashTable<const SCEV *, LoadValue>::ScopeTy LoadScope;
+  // Generation on entry to this node; never modified after construction.
+  unsigned CurrentGeneration;
+  // Generation to hand to children; starts equal to CurrentGeneration and can
+  // be updated through childGeneration(unsigned).
+  unsigned ChildGeneration;
+  DomTreeNode *Node;
+  // Next child to visit and the end of the child range.
+  DomTreeNode::const_iterator ChildIter;
+  DomTreeNode::const_iterator EndIter;
+  // Set once via process(); queried via isProcessed().
+  bool Processed = false;
+
+public:
+  /// \param AvailableLoads table on which this node's LoadScope is constructed.
+  /// \param cg initial value for both the current and the child generation.
+  /// \param N dominator-tree node this entry represents.
+  /// \param Child iterator to the first child to visit.
+  /// \param End iterator marking the end of the children.
+  StackNode(ScopedHashTable<const SCEV *, LoadValue> &AvailableLoads,
+            unsigned cg, DomTreeNode *N, DomTreeNode::const_iterator Child,
+            DomTreeNode::const_iterator End)
+      : LoadScope(AvailableLoads), CurrentGeneration(cg), ChildGeneration(cg),
+        Node(N), ChildIter(Child), EndIter(End) {}
+  // Accessors.
+  unsigned currentGeneration() const { return CurrentGeneration; }
+  unsigned childGeneration() const { return ChildGeneration; }
+  void childGeneration(unsigned generation) { ChildGeneration = generation; }
+  DomTreeNode *node() { return Node; }
+  DomTreeNode::const_iterator childIter() const { return ChildIter; }
+
+  /// Returns the child ChildIter currently points at and advances the
+  /// iterator. Performs no bounds check: the caller must ensure
+  /// childIter() != end() before calling.
+  DomTreeNode *nextChild() {
+    DomTreeNode *child = *ChildIter;
+    ++ChildIter;
+    return child;
+  }
+
+  DomTreeNode::const_iterator end() const { return EndIter; }
+  bool isProcessed() const { return Processed; }
+  /// Marks this node as processed.
+  void process() { Processed = true; }
+};
+
+/// Decides whether the value recorded in \p LV can stand in for the load
+/// \p LI.
+///
+/// \param LV candidate entry looked up for LI's pointer.
+/// \param LI the load that would be replaced.
+/// \param CurrentGeneration generation number at the point of \p LI.
+/// \param GetMSSA callback yielding a MemorySSA instance, or nullptr if none
+///        is available; only invoked when the generations differ.
+/// \returns LV.DefI if it is usable for \p LI, otherwise nullptr.
+Value *getMatchingValue(LoadValue LV, LoadInst *LI, unsigned CurrentGeneration,
+                        function_ref<MemorySSA *()> GetMSSA) {
+  // Empty entry: nothing was recorded for this pointer.
+  if (!LV.DefI)
+    return nullptr;
+  // Equal generations skip the MemorySSA query entirely. Otherwise fall back
+  // to MemorySSA, giving up if it is unavailable.
+  if (LV.Generation != CurrentGeneration) {
+    MemorySSA *MSSA = GetMSSA();
+    if (!MSSA)
+      return nullptr;
+    auto *EarlierMA = MSSA->getMemoryAccess(LV.DefI);
+    MemoryAccess *LaterDef = MSSA->getWalker()->getClobberingMemoryAccess(LI);
+    // Reuse is only allowed when LI's clobbering access dominates the memory
+    // access of the recorded instruction.
+    if (!MSSA->dominates(LaterDef, EarlierMA))
+      return nullptr;
+  }
+
+  // The recorded value must have exactly the type of the load it replaces.
+  if (LV.DefI->getType() != LI->getType())
+    return nullptr;
+  return LV.DefI;
+}
+
+void loadCSE(Loop *L, DominatorTree &DT, ScalarEvolution &SE, LoopInfo &LI,
+             function_ref<MemorySSA *()> GetMSSA) {
+  ScopedHashTable<const SCEV *, LoadValue> AvailableLoads;
+  SmallVector<std::unique_ptr<StackNode>> NodesToProcess;
+  DomTreeNode *HeaderD = DT.getNode(L->getHeader());
+  NodesToProcess.emplace_back(new StackNode(AvailableLoads, 0, HeaderD,
+                                            HeaderD->begin(), HeaderD->end()));
+
+  unsigned CurrentGeneration = 0;
+  while (!NodesToProcess.empty()) {
+    // Grab the first item off the stack. Set the current generation, remove
+    // the node from the stack, and process it.
+    StackNode *NodeToProcess = &*NodesToProcess.back();
+
+    // Initialize class members.
+    CurrentGeneration = NodeToProcess->currentGeneration();
+
+    if (!NodeToProcess->isProcessed()) {
+      // Process the node.
+
+      // If this block has a single predecessor, then the predecessor is the
+      // parent
+      // of the domtree node and all of the live out memory values are still
+      // current in this block.  If this block has multiple predecessors, then
+      // they could have invalidated the live-out memory values of our parent
+      // value.  For now, just be conservative and invalidate memory if this
+      // block has multiple predecessors.
+      if (!NodeToProcess->node()->getBlock()->getSinglePredecessor())
+        ++CurrentGeneration;
+      for (auto &I : make_early_inc_range(*NodeToProcess->node()->getBlock())) {
+
+        auto *Load = dyn_cast<LoadInst>(&I);
+        if (!Load || !Load->isSimple()) {
+          if (I.mayWriteToMemory())
+            CurrentGeneration++;
+          continue;
+        }
+
+        const SCEV *PtrSCEV = SE.getSCEV(Load->getPointerOperand());
+        LoadValue LV = AvailableLoads.lookup(PtrSCEV);
+        if (Value *M = getMatchingValue(LV, Load, CurrentGeneration, GetMSSA)) {
+
----------------
fhahn wrote:

Dropped, thanks!

https://github.com/llvm/llvm-project/pull/83860


More information about the llvm-commits mailing list