[llvm-branch-commits] [mlir] [MLIR][Mem2Reg] Change API to always retry promotion after changes (PR #91464)

Christian Ulmann via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Wed May 8 05:10:04 PDT 2024


https://github.com/Dinistro created https://github.com/llvm/llvm-project/pull/91464

This commit modifies Mem2Reg's API to always attempt a full promotion of all the passed-in "allocators". This ensures that the pass does not require unnecessary walks over the regions and makes better use of caching.

>From c5a6fd716c09d3445db41337c1bfbc9d6626e4da Mon Sep 17 00:00:00 2001
From: Christian Ulmann <christian.ulmann at nextsilicon.com>
Date: Wed, 8 May 2024 12:03:56 +0000
Subject: [PATCH] [MLIR][Mem2Reg] Change API to always retry promotion after
 changes

This commit modifies Mem2Reg's API to always attempt a full promotion
of all the passed-in "allocators". This ensures that the pass does not
require unnecessary walks over the regions and makes better use of
caching.
---
 mlir/include/mlir/Transforms/Mem2Reg.h |  6 +--
 mlir/lib/Transforms/Mem2Reg.cpp        | 62 +++++++++++++++-----------
 2 files changed, 39 insertions(+), 29 deletions(-)

diff --git a/mlir/include/mlir/Transforms/Mem2Reg.h b/mlir/include/mlir/Transforms/Mem2Reg.h
index fee7fb312750..6986cad9ae12 100644
--- a/mlir/include/mlir/Transforms/Mem2Reg.h
+++ b/mlir/include/mlir/Transforms/Mem2Reg.h
@@ -9,7 +9,6 @@
 #ifndef MLIR_TRANSFORMS_MEM2REG_H
 #define MLIR_TRANSFORMS_MEM2REG_H
 
-#include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/MemorySlotInterfaces.h"
 #include "llvm/ADT/Statistic.h"
 
@@ -23,8 +22,9 @@ struct Mem2RegStatistics {
   llvm::Statistic *newBlockArgumentAmount = nullptr;
 };
 
-/// Attempts to promote the memory slots of the provided allocators. Succeeds if
-/// at least one memory slot was promoted.
+/// Attempts to promote the memory slots of the provided allocators. Iteratively
+/// retries the promotion of all slots as promoting one slot might enable
+/// subsequent promotions. Succeeds if at least one memory slot was promoted.
 LogicalResult
 tryToPromoteMemorySlots(ArrayRef<PromotableAllocationOpInterface> allocators,
                         OpBuilder &builder, const DataLayout &dataLayout,
diff --git a/mlir/lib/Transforms/Mem2Reg.cpp b/mlir/lib/Transforms/Mem2Reg.cpp
index 8adbbcd01cb4..390d2a3f54b6 100644
--- a/mlir/lib/Transforms/Mem2Reg.cpp
+++ b/mlir/lib/Transforms/Mem2Reg.cpp
@@ -636,20 +636,36 @@ LogicalResult mlir::tryToPromoteMemorySlots(
   // lazily and cached to avoid expensive recomputation.
   BlockIndexCache blockIndexCache;
 
-  for (PromotableAllocationOpInterface allocator : allocators) {
-    for (MemorySlot slot : allocator.getPromotableSlots()) {
-      if (slot.ptr.use_empty())
-        continue;
-
-      MemorySlotPromotionAnalyzer analyzer(slot, dominance, dataLayout);
-      std::optional<MemorySlotPromotionInfo> info = analyzer.computeInfo();
-      if (info) {
-        MemorySlotPromoter(slot, allocator, builder, dominance, dataLayout,
-                           std::move(*info), statistics, blockIndexCache)
-            .promoteSlot();
-        promotedAny = true;
+  SmallVector<PromotableAllocationOpInterface> workList(allocators.begin(),
+                                                        allocators.end());
+
+  SmallVector<PromotableAllocationOpInterface> newWorkList;
+  newWorkList.reserve(workList.size());
+  while (true) {
+    for (PromotableAllocationOpInterface allocator : workList) {
+      for (MemorySlot slot : allocator.getPromotableSlots()) {
+        if (slot.ptr.use_empty())
+          continue;
+
+        MemorySlotPromotionAnalyzer analyzer(slot, dominance, dataLayout);
+        std::optional<MemorySlotPromotionInfo> info = analyzer.computeInfo();
+        if (info) {
+          MemorySlotPromoter(slot, allocator, builder, dominance, dataLayout,
+                             std::move(*info), statistics, blockIndexCache)
+              .promoteSlot();
+          promotedAny = true;
+          continue;
+        }
+        newWorkList.push_back(allocator);
       }
     }
+    if (workList.size() == newWorkList.size())
+      break;
+
+    // Swap the vector's backing memory and clear the entries in newWorkList
+    // afterwards. This ensures that additional heap allocations can be avoided.
+    workList.swap(newWorkList);
+    newWorkList.clear();
   }
 
   return success(promotedAny);
@@ -677,22 +693,16 @@ struct Mem2Reg : impl::Mem2RegBase<Mem2Reg> {
 
       OpBuilder builder(&region.front(), region.front().begin());
 
-      // Promoting a slot can allow for further promotion of other slots,
-      // promotion is tried until no promotion succeeds.
-      while (true) {
-        SmallVector<PromotableAllocationOpInterface> allocators;
-        // Build a list of allocators to attempt to promote the slots of.
-        region.walk([&](PromotableAllocationOpInterface allocator) {
-          allocators.emplace_back(allocator);
-        });
-
-        // Attempt promoting until no promotion succeeds.
-        if (failed(tryToPromoteMemorySlots(allocators, builder, dataLayout,
-                                           dominance, statistics)))
-          break;
+      SmallVector<PromotableAllocationOpInterface> allocators;
+      // Build a list of allocators to attempt to promote the slots of.
+      region.walk([&](PromotableAllocationOpInterface allocator) {
+        allocators.emplace_back(allocator);
+      });
 
+      // Attempt promoting as many of the slots as possible.
+      if (succeeded(tryToPromoteMemorySlots(allocators, builder, dataLayout,
+                                            dominance, statistics)))
         changed = true;
-      }
     }
     if (!changed)
       markAllAnalysesPreserved();



More information about the llvm-branch-commits mailing list