[Mlir-commits] [mlir] c6efcc9 - [MLIR][Mem2Reg] Improve performance by avoiding recomputations (#91444)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed May 8 06:53:18 PDT 2024


Author: Christian Ulmann
Date: 2024-05-08T15:53:14+02:00
New Revision: c6efcc925c9969d616bc463171c0423d6a9766af

URL: https://github.com/llvm/llvm-project/commit/c6efcc925c9969d616bc463171c0423d6a9766af
DIFF: https://github.com/llvm/llvm-project/commit/c6efcc925c9969d616bc463171c0423d6a9766af.diff

LOG: [MLIR][Mem2Reg] Improve performance by avoiding recomputations (#91444)

This commit ensures that Mem2Reg reuses the `DominanceInfo` as well as
block index maps to avoid expensive recomputations. Due to the recent
migration to `OpBuilder`, the promotion of a slot no longer replaces
blocks. Having stable blocks makes the `DominanceInfo` preservable and
additionally allows caching block index maps between different
promotions.

Performance measurements on very large functions show up to a 4x
speedup from these changes.

Added: 
    

Modified: 
    mlir/include/mlir/Transforms/Mem2Reg.h
    mlir/lib/Transforms/Mem2Reg.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Transforms/Mem2Reg.h b/mlir/include/mlir/Transforms/Mem2Reg.h
index b4f939d654142..fee7fb312750c 100644
--- a/mlir/include/mlir/Transforms/Mem2Reg.h
+++ b/mlir/include/mlir/Transforms/Mem2Reg.h
@@ -28,6 +28,7 @@ struct Mem2RegStatistics {
 LogicalResult
 tryToPromoteMemorySlots(ArrayRef<PromotableAllocationOpInterface> allocators,
                         OpBuilder &builder, const DataLayout &dataLayout,
+                        DominanceInfo &dominance,
                         Mem2RegStatistics statistics = {});
 
 } // namespace mlir

diff  --git a/mlir/lib/Transforms/Mem2Reg.cpp b/mlir/lib/Transforms/Mem2Reg.cpp
index 1d7ba4ca4f83e..8adbbcd01cb44 100644
--- a/mlir/lib/Transforms/Mem2Reg.cpp
+++ b/mlir/lib/Transforms/Mem2Reg.cpp
@@ -18,7 +18,6 @@
 #include "mlir/Transforms/Passes.h"
 #include "mlir/Transforms/RegionUtils.h"
 #include "llvm/ADT/STLExtras.h"
-#include "llvm/Support/Casting.h"
 #include "llvm/Support/GenericIteratedDominanceFrontier.h"
 
 namespace mlir {
@@ -158,6 +157,8 @@ class MemorySlotPromotionAnalyzer {
   const DataLayout &dataLayout;
 };
 
+using BlockIndexCache = DenseMap<Region *, DenseMap<Block *, size_t>>;
+
 /// The MemorySlotPromoter handles the state of promoting a memory slot. It
 /// wraps a slot and its associated allocator. This will perform the mutation of
 /// IR.
@@ -166,7 +167,8 @@ class MemorySlotPromoter {
   MemorySlotPromoter(MemorySlot slot, PromotableAllocationOpInterface allocator,
                      OpBuilder &builder, DominanceInfo &dominance,
                      const DataLayout &dataLayout, MemorySlotPromotionInfo info,
-                     const Mem2RegStatistics &statistics);
+                     const Mem2RegStatistics &statistics,
+                     BlockIndexCache &blockIndexCache);
 
   /// Actually promotes the slot by mutating IR. Promoting a slot DOES
   /// invalidate the MemorySlotPromotionInfo of other slots. Preparation of
@@ -207,6 +209,9 @@ class MemorySlotPromoter {
   const DataLayout &dataLayout;
   MemorySlotPromotionInfo info;
   const Mem2RegStatistics &statistics;
+
+  /// Shared cache of block indices of specific regions.
+  BlockIndexCache &blockIndexCache;
 };
 
 } // namespace
@@ -214,9 +219,11 @@ class MemorySlotPromoter {
 MemorySlotPromoter::MemorySlotPromoter(
     MemorySlot slot, PromotableAllocationOpInterface allocator,
     OpBuilder &builder, DominanceInfo &dominance, const DataLayout &dataLayout,
-    MemorySlotPromotionInfo info, const Mem2RegStatistics &statistics)
+    MemorySlotPromotionInfo info, const Mem2RegStatistics &statistics,
+    BlockIndexCache &blockIndexCache)
     : slot(slot), allocator(allocator), builder(builder), dominance(dominance),
-      dataLayout(dataLayout), info(std::move(info)), statistics(statistics) {
+      dataLayout(dataLayout), info(std::move(info)), statistics(statistics),
+      blockIndexCache(blockIndexCache) {
 #ifndef NDEBUG
   auto isResultOrNewBlockArgument = [&]() {
     if (BlockArgument arg = dyn_cast<BlockArgument>(slot.ptr))
@@ -500,15 +507,29 @@ void MemorySlotPromoter::computeReachingDefInRegion(Region *region,
   }
 }
 
+/// Gets or creates a block index mapping for `region`.
+static const DenseMap<Block *, size_t> &
+getOrCreateBlockIndices(BlockIndexCache &blockIndexCache, Region *region) {
+  auto [it, inserted] = blockIndexCache.try_emplace(region);
+  if (!inserted)
+    return it->second;
+
+  DenseMap<Block *, size_t> &blockIndices = it->second;
+  SetVector<Block *> topologicalOrder = getTopologicallySortedBlocks(*region);
+  for (auto [index, block] : llvm::enumerate(topologicalOrder))
+    blockIndices[block] = index;
+  return blockIndices;
+}
+
 /// Sorts `ops` according to dominance. Relies on the topological order of basic
-/// blocks to get a deterministic ordering.
-static void dominanceSort(SmallVector<Operation *> &ops, Region &region) {
+/// blocks to get a deterministic ordering. Uses `blockIndexCache` to avoid the
+/// potentially expensive recomputation of a block index map.
+static void dominanceSort(SmallVector<Operation *> &ops, Region &region,
+                          BlockIndexCache &blockIndexCache) {
   // Produce a topological block order and construct a map to lookup the indices
   // of blocks.
-  DenseMap<Block *, size_t> topoBlockIndices;
-  SetVector<Block *> topologicalOrder = getTopologicallySortedBlocks(region);
-  for (auto [index, block] : llvm::enumerate(topologicalOrder))
-    topoBlockIndices[block] = index;
+  const DenseMap<Block *, size_t> &topoBlockIndices =
+      getOrCreateBlockIndices(blockIndexCache, &region);
 
   // Combining the topological order of the basic blocks together with block
   // internal operation order guarantees a deterministic, dominance respecting
@@ -527,7 +548,8 @@ void MemorySlotPromoter::removeBlockingUses() {
       llvm::make_first_range(info.userToBlockingUses));
 
   // Sort according to dominance.
-  dominanceSort(usersToRemoveUses, *slot.ptr.getParentBlock()->getParent());
+  dominanceSort(usersToRemoveUses, *slot.ptr.getParentBlock()->getParent(),
+                blockIndexCache);
 
   llvm::SmallVector<Operation *> toErase;
   // List of all replaced values in the slot.
@@ -605,20 +627,25 @@ void MemorySlotPromoter::promoteSlot() {
 
 LogicalResult mlir::tryToPromoteMemorySlots(
     ArrayRef<PromotableAllocationOpInterface> allocators, OpBuilder &builder,
-    const DataLayout &dataLayout, Mem2RegStatistics statistics) {
+    const DataLayout &dataLayout, DominanceInfo &dominance,
+    Mem2RegStatistics statistics) {
   bool promotedAny = false;
 
+  // A cache that stores deterministic block indices which are used to determine
+  // a valid operation modification order. The block index maps are computed
+  // lazily and cached to avoid expensive recomputation.
+  BlockIndexCache blockIndexCache;
+
   for (PromotableAllocationOpInterface allocator : allocators) {
     for (MemorySlot slot : allocator.getPromotableSlots()) {
       if (slot.ptr.use_empty())
         continue;
 
-      DominanceInfo dominance;
       MemorySlotPromotionAnalyzer analyzer(slot, dominance, dataLayout);
       std::optional<MemorySlotPromotionInfo> info = analyzer.computeInfo();
       if (info) {
         MemorySlotPromoter(slot, allocator, builder, dominance, dataLayout,
-                           std::move(*info), statistics)
+                           std::move(*info), statistics, blockIndexCache)
             .promoteSlot();
         promotedAny = true;
       }
@@ -640,6 +667,10 @@ struct Mem2Reg : impl::Mem2RegBase<Mem2Reg> {
 
     bool changed = false;
 
+    auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
+    const DataLayout &dataLayout = dataLayoutAnalysis.getAtOrAbove(scopeOp);
+    auto &dominance = getAnalysis<DominanceInfo>();
+
     for (Region &region : scopeOp->getRegions()) {
       if (region.getBlocks().empty())
         continue;
@@ -655,16 +686,12 @@ struct Mem2Reg : impl::Mem2RegBase<Mem2Reg> {
           allocators.emplace_back(allocator);
         });
 
-        auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
-        const DataLayout &dataLayout = dataLayoutAnalysis.getAtOrAbove(scopeOp);
-
         // Attempt promoting until no promotion succeeds.
         if (failed(tryToPromoteMemorySlots(allocators, builder, dataLayout,
-                                           statistics)))
+                                           dominance, statistics)))
           break;
 
         changed = true;
-        getAnalysisManager().invalidate({});
       }
     }
     if (!changed)


        


More information about the Mlir-commits mailing list