[Mlir-commits] [mlir] ead8e9d - [mlir] [mem2reg] Adapt to be pattern-friendly.

Tobias Gysi llvmlistbot at llvm.org
Tue May 16 01:39:51 PDT 2023


Author: Théo Degioanni
Date: 2023-05-16T08:35:13Z
New Revision: ead8e9d7953e817c52fdfaf7196dfeb2199dab26

URL: https://github.com/llvm/llvm-project/commit/ead8e9d7953e817c52fdfaf7196dfeb2199dab26
DIFF: https://github.com/llvm/llvm-project/commit/ead8e9d7953e817c52fdfaf7196dfeb2199dab26.diff

LOG: [mlir] [mem2reg] Adapt to be pattern-friendly.

This revision modifies the mem2reg interfaces and algorithm to be more
comfortable to use as a pattern. The motivation behind this is that
currently the pattern needs to be applied to the scope op of the region
in which allocators should be promoted. However, a more natural way to
apply the pattern would be to apply it on the allocator directly. This
is not only clearer but easier to parallelize.

This revision changes the mem2reg pattern to operate this way. This
required restricting the interfaces to only mutate IR through
RewriterBase, as the previously used escape hatch is not granular enough
to match only on the region that is modified. This has the unfortunate
cost of preventing batched allocator promotion and making the
block-argument-adding logic more complex. Because batching no longer made
any sense, I made the internal analyzer/promoter decoupling private again.
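
As a rough illustration of why adding a block argument gets more involved
when every mutation has to go through a rewriter, here is a toy C++ model.
ToyBlock, ToyRewriter and addArgumentViaRewriter are hypothetical stand-ins
invented for this sketch, not MLIR classes. The sketch only mirrors the
shape of the new logic in the diff: create a block with one extra argument,
move the old block's contents into it, and drop the old block.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <vector>

// Toy stand-in (NOT mlir::Block): a block is a list of argument types plus
// a list of operation names.
struct ToyBlock {
  std::vector<std::string> argTypes;
  std::vector<std::string> ops;
};

// Toy stand-in (NOT mlir::RewriterBase): owns the blocks and logs every
// mutation, the way a driver listening on a rewriter would be notified.
struct ToyRewriter {
  std::vector<std::unique_ptr<ToyBlock>> blocks;
  std::vector<std::string> log;

  // Creates a new, empty block with the given argument types.
  ToyBlock *createBlock(std::vector<std::string> argTypes) {
    blocks.push_back(std::make_unique<ToyBlock>());
    blocks.back()->argTypes = std::move(argTypes);
    log.push_back("createBlock");
    return blocks.back().get();
  }

  // Moves all operations of `source` to the end of `dest`, then erases
  // `source`. `source` must not be used after this call.
  void mergeBlocks(ToyBlock *source, ToyBlock *dest) {
    dest->ops.insert(dest->ops.end(), source->ops.begin(), source->ops.end());
    for (auto it = blocks.begin(); it != blocks.end(); ++it) {
      if (it->get() == source) {
        blocks.erase(it);
        break;
      }
    }
    log.push_back("mergeBlocks");
  }
};

// Adds one argument to `block` without mutating it in place: a replacement
// block carrying the extra argument is created, and the old contents are
// moved over through the rewriter so that each step is observable.
ToyBlock *addArgumentViaRewriter(ToyRewriter &rewriter, ToyBlock *block,
                                 const std::string &newArgType) {
  std::vector<std::string> argTypes = block->argTypes;
  argTypes.push_back(newArgType);
  ToyBlock *newBlock = rewriter.createBlock(std::move(argTypes));
  rewriter.mergeBlocks(block, newBlock);
  return newBlock;
}
```

Calling addArgumentViaRewriter(rewriter, block, "f32") on a block with one
argument and two operations should return a new block with two arguments
and the same two operations, with both steps recorded in the rewriter's
log. Callers must then switch to the returned pointer, which is why the
diff also swaps the block in the merge point set.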

This also adds statistics to the mem2reg infrastructure.

Reviewed By: gysit

Differential Revision: https://reviews.llvm.org/D150432

Added: 
    mlir/test/Dialect/MemRef/mem2reg-statistics.mlir

Modified: 
    mlir/include/mlir/Interfaces/MemorySlotInterfaces.h
    mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
    mlir/include/mlir/Transforms/Mem2Reg.h
    mlir/include/mlir/Transforms/Passes.td
    mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
    mlir/lib/Dialect/MemRef/IR/MemRefMem2Reg.cpp
    mlir/lib/Transforms/Mem2Reg.cpp
    mlir/test/Dialect/LLVMIR/mem2reg-dbginfo.mlir
    mlir/test/Dialect/LLVMIR/mem2reg.mlir
    mlir/test/Dialect/MemRef/mem2reg.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.h b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.h
index be761ae427acc..c0f8b2f8ee9ce 100644
--- a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.h
+++ b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.h
@@ -11,6 +11,7 @@
 
 #include "mlir/IR/Dominance.h"
 #include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
 
 namespace mlir {
 

diff  --git a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
index f98bdbabc4c0f..73061f79521af 100644
--- a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
+++ b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
@@ -31,6 +31,8 @@ def PromotableAllocationOpInterface
 
         Promotion of the slot will lead to the slot pointer no longer being
         used, leaving the content of the memory slot unreachable.
+
+        No IR mutation is allowed in this method.
       }], "::llvm::SmallVector<::mlir::MemorySlot>", "getPromotableSlots",
       (ins)
     >,
@@ -38,34 +40,42 @@ def PromotableAllocationOpInterface
         Provides the default Value of this memory slot. The provided Value
         will be used as the reaching definition of loads done before any store.
         This Value must outlive the promotion and dominate all the uses of this
-        slot's pointer. The provided builder can be used to create the default
+        slot's pointer. The provided rewriter can be used to create the default
         value on the fly.
 
-        The builder is located at the beginning of the block where the slot
-        pointer is defined.
+        The rewriter is located at the beginning of the block where the slot
+        pointer is defined. All IR mutations must happen through the rewriter.
       }], "::mlir::Value", "getDefaultValue",
-      (ins "const ::mlir::MemorySlot &":$slot, "::mlir::OpBuilder &":$builder)
+      (ins
+        "const ::mlir::MemorySlot &":$slot,
+        "::mlir::RewriterBase &":$rewriter)
     >,
     InterfaceMethod<[{
         Hook triggered for every new block argument added to a block.
         This will only be called for slots declared by this operation.
 
-        The builder is located at the beginning of the block on call.
+        The rewriter is located at the beginning of the block on call. All IR
+        mutations must happen through the rewriter.
       }],
       "void", "handleBlockArgument",
       (ins
         "const ::mlir::MemorySlot &":$slot,
         "::mlir::BlockArgument":$argument,
-        "::mlir::OpBuilder &":$builder
+        "::mlir::RewriterBase &":$rewriter
       )
     >,
     InterfaceMethod<[{
         Hook triggered once the promotion of a slot is complete. This can
         also clean up the created default value if necessary.
         This will only be called for slots declared by this operation.
+
+        All IR mutations must happen through the rewriter.
       }],
       "void", "handlePromotionComplete",
-      (ins "const ::mlir::MemorySlot &":$slot, "::mlir::Value":$defaultValue)
+      (ins
+        "const ::mlir::MemorySlot &":$slot, 
+        "::mlir::Value":$defaultValue,
+        "::mlir::RewriterBase &":$rewriter)
     >,
   ];
 }
@@ -87,6 +97,8 @@ def PromotableMemOpInterface : OpInterface<"PromotableMemOpInterface"> {
   let methods = [
     InterfaceMethod<[{
         Gets whether this operation loads from the specified slot.
+
+        No IR mutation is allowed in this method.
       }],
       "bool", "loadsFrom",
       (ins "const ::mlir::MemorySlot &":$slot)
@@ -96,6 +108,8 @@ def PromotableMemOpInterface : OpInterface<"PromotableMemOpInterface"> {
         value if this operation does not store to this slot. An operation
         storing a value to a slot must always be able to provide the value it
         stores. This method is only called on operations that use the slot.
+
+        No IR mutation is allowed in this method.
       }],
       "::mlir::Value", "getStored",
       (ins "const ::mlir::MemorySlot &":$slot)
@@ -107,6 +121,8 @@ def PromotableMemOpInterface : OpInterface<"PromotableMemOpInterface"> {
         If the removal procedure of the use will require that other uses get
         removed, that dependency should be added to the `newBlockingUses`
         argument. Dependent uses must only be uses of results of this operation.
+
+        No IR mutation is allowed in this method.
       }], "bool", "canUsesBeRemoved",
       (ins "const ::mlir::MemorySlot &":$slot,
            "const ::llvm::SmallPtrSetImpl<::mlir::OpOperand *> &":$blockingUses,
@@ -132,13 +148,14 @@ def PromotableMemOpInterface : OpInterface<"PromotableMemOpInterface"> {
         have been done at the point of calling this method, but it will be done
         eventually.
 
-        The builder is located after the promotable operation on call.
+        The rewriter is located after the promotable operation on call. All IR
+        mutations must happen through the rewriter.
       }],
       "::mlir::DeletionKind",
       "removeBlockingUses",
       (ins "const ::mlir::MemorySlot &":$slot,
            "const ::llvm::SmallPtrSetImpl<mlir::OpOperand *> &":$blockingUses,
-           "::mlir::OpBuilder &":$builder,
+           "::mlir::RewriterBase &":$rewriter,
            "::mlir::Value":$reachingDefinition)
     >,
   ];
@@ -160,6 +177,8 @@ def PromotableOpInterface : OpInterface<"PromotableOpInterface"> {
         If the removal procedure of the use will require that other uses get
         removed, that dependency should be added to the `newBlockingUses`
         argument. Dependent uses must only be uses of results of this operation.
+
+        No IR mutation is allowed in this method.
       }], "bool", "canUsesBeRemoved",
       (ins "const ::llvm::SmallPtrSetImpl<::mlir::OpOperand *> &":$blockingUses,
            "::llvm::SmallVectorImpl<::mlir::OpOperand *> &":$newBlockingUses)
@@ -185,12 +204,13 @@ def PromotableOpInterface : OpInterface<"PromotableOpInterface"> {
         have been done at the point of calling this method, but it will be done
         eventually.
 
-        The builder is located after the promotable operation on call.
+        The rewriter is located after the promotable operation on call. All IR
+        mutations must happen through the rewriter.
       }],
       "::mlir::DeletionKind",
       "removeBlockingUses",
       (ins "const ::llvm::SmallPtrSetImpl<mlir::OpOperand *> &":$blockingUses,
-           "::mlir::OpBuilder &":$builder)
+           "::mlir::RewriterBase &":$rewriter)
     >,
   ];
 }

diff  --git a/mlir/include/mlir/Transforms/Mem2Reg.h b/mlir/include/mlir/Transforms/Mem2Reg.h
index a34ea68e750bf..46b2a1f56d21e 100644
--- a/mlir/include/mlir/Transforms/Mem2Reg.h
+++ b/mlir/include/mlir/Transforms/Mem2Reg.h
@@ -13,129 +13,39 @@
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/MemorySlotInterfaces.h"
+#include "llvm/ADT/Statistic.h"
 
 namespace mlir {
 
-/// Information computed during promotion analysis used to perform actual
-/// promotion.
-struct MemorySlotPromotionInfo {
-  /// Blocks for which at least two definitions of the slot values clash.
-  SmallPtrSet<Block *, 8> mergePoints;
-  /// Contains, for each operation, which uses must be eliminated by promotion.
-  /// This is a DAG structure because if an operation must eliminate some of
-  /// its uses, it is because the defining ops of the blocking uses requested
-  /// it. The defining ops therefore must also have blocking uses or be the
-  /// starting point of the bloccking uses.
-  DenseMap<Operation *, SmallPtrSet<OpOperand *, 4>> userToBlockingUses;
-};
-
-/// Computes information for basic slot promotion. This will check that direct
-/// slot promotion can be performed, and provide the information to execute the
-/// promotion. This does not mutate IR.
-class MemorySlotPromotionAnalyzer {
-public:
-  MemorySlotPromotionAnalyzer(MemorySlot slot, DominanceInfo &dominance)
-      : slot(slot), dominance(dominance) {}
-
-  /// Computes the information for slot promotion if promotion is possible,
-  /// returns nothing otherwise.
-  std::optional<MemorySlotPromotionInfo> computeInfo();
-
-private:
-  /// Computes the transitive uses of the slot that block promotion. This finds
-  /// uses that would block the promotion, checks that the operation has a
-  /// solution to remove the blocking use, and potentially forwards the analysis
-  /// if the operation needs further blocking uses resolved to resolve its own
-  /// uses (typically, removing its users because it will delete itself to
-  /// resolve its own blocking uses). This will fail if one of the transitive
-  /// users cannot remove a requested use, and should prevent promotion.
-  LogicalResult computeBlockingUses(
-      DenseMap<Operation *, SmallPtrSet<OpOperand *, 4>> &userToBlockingUses);
-
-  /// Computes in which blocks the value stored in the slot is actually used,
-  /// meaning blocks leading to a load. This method uses `definingBlocks`, the
-  /// set of blocks containing a store to the slot (defining the value of the
-  /// slot).
-  SmallPtrSet<Block *, 16>
-  computeSlotLiveIn(SmallPtrSetImpl<Block *> &definingBlocks);
-
-  /// Computes the points in which multiple re-definitions of the slot's value
-  /// (stores) may conflict.
-  void computeMergePoints(SmallPtrSetImpl<Block *> &mergePoints);
-
-  /// Ensures predecessors of merge points can properly provide their current
-  /// definition of the value stored in the slot to the merge point. This can
-  /// notably be an issue if the terminator used does not have the ability to
-  /// forward values through block operands.
-  bool areMergePointsUsable(SmallPtrSetImpl<Block *> &mergePoints);
-
-  MemorySlot slot;
-  DominanceInfo &dominance;
-};
-
-/// The MemorySlotPromoter handles the state of promoting a memory slot. It
-/// wraps a slot and its associated allocator. This will perform the mutation of
-/// IR.
-class MemorySlotPromoter {
-public:
-  MemorySlotPromoter(MemorySlot slot, PromotableAllocationOpInterface allocator,
-                     OpBuilder &builder, DominanceInfo &dominance,
-                     MemorySlotPromotionInfo info);
-
-  /// Actually promotes the slot by mutating IR. Promoting a slot does not
-  /// invalidate the MemorySlotPromotionInfo of other slots.
-  void promoteSlot();
-
-private:
-  /// Computes the reaching definition for all the operations that require
-  /// promotion. `reachingDef` is the value the slot should contain at the
-  /// beginning of the block. This method returns the reached definition at the
-  /// end of the block.
-  Value computeReachingDefInBlock(Block *block, Value reachingDef);
-
-  /// Computes the reaching definition for all the operations that require
-  /// promotion. `reachingDef` corresponds to the initial value the
-  /// slot will contain before any write, typically a poison value.
-  void computeReachingDefInRegion(Region *region, Value reachingDef);
-
-  /// Removes the blocking uses of the slot, in topological order.
-  void removeBlockingUses();
-
-  /// Lazily-constructed default value representing the content of the slot when
-  /// no store has been executed. This function may mutate IR.
-  Value getLazyDefaultValue();
-
-  MemorySlot slot;
-  PromotableAllocationOpInterface allocator;
-  OpBuilder &builder;
-  /// Potentially non-initialized default value. Use `getLazyDefaultValue` to
-  /// initialize it on demand.
-  Value defaultValue;
-  /// Contains the reaching definition at this operation. Reaching definitions
-  /// are only computed for promotable memory operations with blocking uses.
-  DenseMap<PromotableMemOpInterface, Value> reachingDefs;
-  DominanceInfo &dominance;
-  MemorySlotPromotionInfo info;
+struct Mem2RegStatistics {
+  llvm::Statistic *promotedAmount = nullptr;
+  llvm::Statistic *newBlockArgumentAmount = nullptr;
 };
 
 /// Pattern applying mem2reg to the regions of the operations on which it
 /// matches.
-class Mem2RegPattern : public RewritePattern {
+class Mem2RegPattern
+    : public OpInterfaceRewritePattern<PromotableAllocationOpInterface> {
 public:
-  using RewritePattern::RewritePattern;
+  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
 
-  Mem2RegPattern(MLIRContext *ctx, PatternBenefit benefit = 1)
-      : RewritePattern(MatchAnyOpTypeTag(), benefit, ctx) {}
+  Mem2RegPattern(MLIRContext *context, Mem2RegStatistics statistics = {},
+                 PatternBenefit benefit = 1)
+      : OpInterfaceRewritePattern(context, benefit), statistics(statistics) {}
 
-  LogicalResult matchAndRewrite(Operation *op,
+  LogicalResult matchAndRewrite(PromotableAllocationOpInterface allocator,
                                 PatternRewriter &rewriter) const override;
+
+private:
+  Mem2RegStatistics statistics;
 };
 
 /// Attempts to promote the memory slots of the provided allocators. Succeeds if
 /// at least one memory slot was promoted.
 LogicalResult
 tryToPromoteMemorySlots(ArrayRef<PromotableAllocationOpInterface> allocators,
-                        OpBuilder &builder, DominanceInfo &dominance);
+                        RewriterBase &rewriter,
+                        Mem2RegStatistics statistics = {});
 
 } // namespace mlir
 

diff  --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index 1cc357ca1f9f4..62b8dd075f21f 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -189,6 +189,21 @@ def Mem2Reg : Pass<"mem2reg"> {
     This pass only supports unstructured control-flow. Promotion of operations
     within subregions will not happen.
   }];
+
+  let options = [
+    Option<"enableRegionSimplification", "region-simplify", "bool",
+       /*default=*/"true",
+       "Perform control flow optimizations to the region tree">,
+  ];
+
+  let statistics = [
+    Statistic<"promotedAmount",
+              "promoted slots",
+              "Number of promoted memory slot">,
+    Statistic<"newBlockArgumentAmount",
+              "new block args",
+              "Total number of block arguments added">,
+  ];
 }
 
 def PrintOpStats : Pass<"print-op-stats"> {

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
index e4fd2a755d90f..51c49892d71c9 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
@@ -35,24 +35,25 @@ llvm::SmallVector<MemorySlot> LLVM::AllocaOp::getPromotableSlots() {
 }
 
 Value LLVM::AllocaOp::getDefaultValue(const MemorySlot &slot,
-                                      OpBuilder &builder) {
-  return builder.create<LLVM::UndefOp>(getLoc(), slot.elemType);
+                                      RewriterBase &rewriter) {
+  return rewriter.create<LLVM::UndefOp>(getLoc(), slot.elemType);
 }
 
 void LLVM::AllocaOp::handleBlockArgument(const MemorySlot &slot,
                                          BlockArgument argument,
-                                         OpBuilder &builder) {
+                                         RewriterBase &rewriter) {
   for (Operation *user : getOperation()->getUsers())
     if (auto declareOp = llvm::dyn_cast<LLVM::DbgDeclareOp>(user))
-      builder.create<LLVM::DbgValueOp>(declareOp.getLoc(), argument,
-                                       declareOp.getVarInfo());
+      rewriter.create<LLVM::DbgValueOp>(declareOp.getLoc(), argument,
+                                        declareOp.getVarInfo());
 }
 
 void LLVM::AllocaOp::handlePromotionComplete(const MemorySlot &slot,
-                                             Value defaultValue) {
+                                             Value defaultValue,
+                                             RewriterBase &rewriter) {
   if (defaultValue && defaultValue.use_empty())
-    defaultValue.getDefiningOp()->erase();
-  erase();
+    rewriter.eraseOp(defaultValue.getDefiningOp());
+  rewriter.eraseOp(*this);
 }
 
 //===----------------------------------------------------------------------===//
@@ -87,10 +88,10 @@ bool LLVM::LoadOp::canUsesBeRemoved(
 
 DeletionKind LLVM::LoadOp::removeBlockingUses(
     const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
-    OpBuilder &builder, Value reachingDefinition) {
+    RewriterBase &rewriter, Value reachingDefinition) {
   // `canUsesBeRemoved` checked this blocking use must be the loaded slot
   // pointer.
-  getResult().replaceAllUsesWith(reachingDefinition);
+  rewriter.replaceAllUsesWith(getResult(), reachingDefinition);
   return DeletionKind::Delete;
 }
 
@@ -110,13 +111,13 @@ bool LLVM::StoreOp::canUsesBeRemoved(
 
 DeletionKind LLVM::StoreOp::removeBlockingUses(
     const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
-    OpBuilder &builder, Value reachingDefinition) {
+    RewriterBase &rewriter, Value reachingDefinition) {
   // `canUsesBeRemoved` checked this blocking use must be the stored slot
   // pointer.
   for (Operation *user : slot.ptr.getUsers())
     if (auto declareOp = llvm::dyn_cast<LLVM::DbgDeclareOp>(user))
-      builder.create<LLVM::DbgValueOp>(declareOp->getLoc(), getValue(),
-                                       declareOp.getVarInfo());
+      rewriter.create<LLVM::DbgValueOp>(declareOp->getLoc(), getValue(),
+                                        declareOp.getVarInfo());
   return DeletionKind::Delete;
 }
 
@@ -140,7 +141,7 @@ bool LLVM::BitcastOp::canUsesBeRemoved(
 }
 
 DeletionKind LLVM::BitcastOp::removeBlockingUses(
-    const SmallPtrSetImpl<OpOperand *> &blockingUses, OpBuilder &builder) {
+    const SmallPtrSetImpl<OpOperand *> &blockingUses, RewriterBase &rewriter) {
   return DeletionKind::Delete;
 }
 
@@ -151,7 +152,7 @@ bool LLVM::AddrSpaceCastOp::canUsesBeRemoved(
 }
 
 DeletionKind LLVM::AddrSpaceCastOp::removeBlockingUses(
-    const SmallPtrSetImpl<OpOperand *> &blockingUses, OpBuilder &builder) {
+    const SmallPtrSetImpl<OpOperand *> &blockingUses, RewriterBase &rewriter) {
   return DeletionKind::Delete;
 }
 
@@ -162,7 +163,7 @@ bool LLVM::LifetimeStartOp::canUsesBeRemoved(
 }
 
 DeletionKind LLVM::LifetimeStartOp::removeBlockingUses(
-    const SmallPtrSetImpl<OpOperand *> &blockingUses, OpBuilder &builder) {
+    const SmallPtrSetImpl<OpOperand *> &blockingUses, RewriterBase &rewriter) {
   return DeletionKind::Delete;
 }
 
@@ -173,7 +174,7 @@ bool LLVM::LifetimeEndOp::canUsesBeRemoved(
 }
 
 DeletionKind LLVM::LifetimeEndOp::removeBlockingUses(
-    const SmallPtrSetImpl<OpOperand *> &blockingUses, OpBuilder &builder) {
+    const SmallPtrSetImpl<OpOperand *> &blockingUses, RewriterBase &rewriter) {
   return DeletionKind::Delete;
 }
 
@@ -184,7 +185,7 @@ bool LLVM::DbgDeclareOp::canUsesBeRemoved(
 }
 
 DeletionKind LLVM::DbgDeclareOp::removeBlockingUses(
-    const SmallPtrSetImpl<OpOperand *> &blockingUses, OpBuilder &builder) {
+    const SmallPtrSetImpl<OpOperand *> &blockingUses, RewriterBase &rewriter) {
   return DeletionKind::Delete;
 }
 
@@ -209,6 +210,6 @@ bool LLVM::GEPOp::canUsesBeRemoved(
 }
 
 DeletionKind LLVM::GEPOp::removeBlockingUses(
-    const SmallPtrSetImpl<OpOperand *> &blockingUses, OpBuilder &builder) {
+    const SmallPtrSetImpl<OpOperand *> &blockingUses, RewriterBase &rewriter) {
   return DeletionKind::Delete;
 }

diff  --git a/mlir/lib/Dialect/MemRef/IR/MemRefMem2Reg.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefMem2Reg.cpp
index b5f5272d64212..12d9ebd5a02ad 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefMem2Reg.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefMem2Reg.cpp
@@ -40,29 +40,30 @@ SmallVector<MemorySlot> memref::AllocaOp::getPromotableSlots() {
 }
 
 Value memref::AllocaOp::getDefaultValue(const MemorySlot &slot,
-                                        OpBuilder &builder) {
+                                        RewriterBase &rewriter) {
   assert(isSupportedElementType(slot.elemType));
   // TODO: support more types.
   return TypeSwitch<Type, Value>(slot.elemType)
       .Case([&](MemRefType t) {
-        return builder.create<memref::AllocaOp>(getLoc(), t);
+        return rewriter.create<memref::AllocaOp>(getLoc(), t);
       })
       .Default([&](Type t) {
-        return builder.create<arith::ConstantOp>(getLoc(), t,
-                                                 builder.getZeroAttr(t));
+        return rewriter.create<arith::ConstantOp>(getLoc(), t,
+                                                  rewriter.getZeroAttr(t));
       });
 }
 
 void memref::AllocaOp::handlePromotionComplete(const MemorySlot &slot,
-                                               Value defaultValue) {
+                                               Value defaultValue,
+                                               RewriterBase &rewriter) {
   if (defaultValue.use_empty())
-    defaultValue.getDefiningOp()->erase();
-  erase();
+    rewriter.eraseOp(defaultValue.getDefiningOp());
+  rewriter.eraseOp(*this);
 }
 
 void memref::AllocaOp::handleBlockArgument(const MemorySlot &slot,
                                            BlockArgument argument,
-                                           OpBuilder &builder) {}
+                                           RewriterBase &rewriter) {}
 
 //===----------------------------------------------------------------------===//
 //  LoadOp/StoreOp interfaces
@@ -86,10 +87,10 @@ bool memref::LoadOp::canUsesBeRemoved(
 
 DeletionKind memref::LoadOp::removeBlockingUses(
     const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
-    OpBuilder &builder, Value reachingDefinition) {
+    RewriterBase &rewriter, Value reachingDefinition) {
   // `canUsesBeRemoved` checked this blocking use must be the loaded slot
   // pointer.
-  getResult().replaceAllUsesWith(reachingDefinition);
+  rewriter.replaceAllUsesWith(getResult(), reachingDefinition);
   return DeletionKind::Delete;
 }
 
@@ -113,6 +114,6 @@ bool memref::StoreOp::canUsesBeRemoved(
 
 DeletionKind memref::StoreOp::removeBlockingUses(
     const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
-    OpBuilder &builder, Value reachingDefinition) {
+    RewriterBase &rewriter, Value reachingDefinition) {
   return DeletionKind::Delete;
 }

diff  --git a/mlir/lib/Transforms/Mem2Reg.cpp b/mlir/lib/Transforms/Mem2Reg.cpp
index 45d6f7d0c1ed8..3b303f9836cf5 100644
--- a/mlir/lib/Transforms/Mem2Reg.cpp
+++ b/mlir/lib/Transforms/Mem2Reg.cpp
@@ -10,6 +10,8 @@
 #include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/Dominance.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Value.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/MemorySlotInterfaces.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -92,11 +94,121 @@ using namespace mlir;
 /// [1]: Rastello F. & Bouchez Tichadou F., SSA-based Compiler Design (2022),
 ///      Springer.
 
+namespace {
+
+/// Information computed during promotion analysis used to perform actual
+/// promotion.
+struct MemorySlotPromotionInfo {
+  /// Blocks for which at least two definitions of the slot values clash.
+  SmallPtrSet<Block *, 8> mergePoints;
+  /// Contains, for each operation, which uses must be eliminated by promotion.
+  /// This is a DAG structure because if an operation must eliminate some of
+  /// its uses, it is because the defining ops of the blocking uses requested
+  /// it. The defining ops therefore must also have blocking uses or be the
+  /// starting point of the bloccking uses.
+  DenseMap<Operation *, SmallPtrSet<OpOperand *, 4>> userToBlockingUses;
+};
+
+/// Computes information for basic slot promotion. This will check that direct
+/// slot promotion can be performed, and provide the information to execute the
+/// promotion. This does not mutate IR.
+class MemorySlotPromotionAnalyzer {
+public:
+  MemorySlotPromotionAnalyzer(MemorySlot slot, DominanceInfo &dominance)
+      : slot(slot), dominance(dominance) {}
+
+  /// Computes the information for slot promotion if promotion is possible,
+  /// returns nothing otherwise.
+  std::optional<MemorySlotPromotionInfo> computeInfo();
+
+private:
+  /// Computes the transitive uses of the slot that block promotion. This finds
+  /// uses that would block the promotion, checks that the operation has a
+  /// solution to remove the blocking use, and potentially forwards the analysis
+  /// if the operation needs further blocking uses resolved to resolve its own
+  /// uses (typically, removing its users because it will delete itself to
+  /// resolve its own blocking uses). This will fail if one of the transitive
+  /// users cannot remove a requested use, and should prevent promotion.
+  LogicalResult computeBlockingUses(
+      DenseMap<Operation *, SmallPtrSet<OpOperand *, 4>> &userToBlockingUses);
+
+  /// Computes in which blocks the value stored in the slot is actually used,
+  /// meaning blocks leading to a load. This method uses `definingBlocks`, the
+  /// set of blocks containing a store to the slot (defining the value of the
+  /// slot).
+  SmallPtrSet<Block *, 16>
+  computeSlotLiveIn(SmallPtrSetImpl<Block *> &definingBlocks);
+
+  /// Computes the points in which multiple re-definitions of the slot's value
+  /// (stores) may conflict.
+  void computeMergePoints(SmallPtrSetImpl<Block *> &mergePoints);
+
+  /// Ensures predecessors of merge points can properly provide their current
+  /// definition of the value stored in the slot to the merge point. This can
+  /// notably be an issue if the terminator used does not have the ability to
+  /// forward values through block operands.
+  bool areMergePointsUsable(SmallPtrSetImpl<Block *> &mergePoints);
+
+  MemorySlot slot;
+  DominanceInfo &dominance;
+};
+
+/// The MemorySlotPromoter handles the state of promoting a memory slot. It
+/// wraps a slot and its associated allocator. This will perform the mutation of
+/// IR.
+class MemorySlotPromoter {
+public:
+  MemorySlotPromoter(MemorySlot slot, PromotableAllocationOpInterface allocator,
+                     RewriterBase &rewriter, DominanceInfo &dominance,
+                     MemorySlotPromotionInfo info,
+                     const Mem2RegStatistics &statistics);
+
+  /// Actually promotes the slot by mutating IR. Promoting a slot DOES
+  /// invalidate the MemorySlotPromotionInfo of other slots. Preparation of
+  /// promotion info should NOT be performed in batches.
+  void promoteSlot();
+
+private:
+  /// Computes the reaching definition for all the operations that require
+  /// promotion. `reachingDef` is the value the slot should contain at the
+  /// beginning of the block. This method returns the reached definition at the
+  /// end of the block.
+  Value computeReachingDefInBlock(Block *block, Value reachingDef);
+
+  /// Computes the reaching definition for all the operations that require
+  /// promotion. `reachingDef` corresponds to the initial value the
+  /// slot will contain before any write, typically a poison value.
+  void computeReachingDefInRegion(Region *region, Value reachingDef);
+
+  /// Removes the blocking uses of the slot, in topological order.
+  void removeBlockingUses();
+
+  /// Lazily-constructed default value representing the content of the slot when
+  /// no store has been executed. This function may mutate IR.
+  Value getLazyDefaultValue();
+
+  MemorySlot slot;
+  PromotableAllocationOpInterface allocator;
+  RewriterBase &rewriter;
+  /// Potentially non-initialized default value. Use `getLazyDefaultValue` to
+  /// initialize it on demand.
+  Value defaultValue;
+  /// Contains the reaching definition at this operation. Reaching definitions
+  /// are only computed for promotable memory operations with blocking uses.
+  DenseMap<PromotableMemOpInterface, Value> reachingDefs;
+  DominanceInfo &dominance;
+  MemorySlotPromotionInfo info;
+  const Mem2RegStatistics &statistics;
+};
+
+} // namespace
+
 MemorySlotPromoter::MemorySlotPromoter(
     MemorySlot slot, PromotableAllocationOpInterface allocator,
-    OpBuilder &builder, DominanceInfo &dominance, MemorySlotPromotionInfo info)
-    : slot(slot), allocator(allocator), builder(builder), dominance(dominance),
-      info(std::move(info)) {
+    RewriterBase &rewriter, DominanceInfo &dominance,
+    MemorySlotPromotionInfo info, const Mem2RegStatistics &statistics)
+    : slot(slot), allocator(allocator), rewriter(rewriter),
+      dominance(dominance), info(std::move(info)), statistics(statistics) {
 #ifndef NDEBUG
   auto isResultOrNewBlockArgument = [&]() {
     if (BlockArgument arg = dyn_cast<BlockArgument>(slot.ptr))
@@ -114,9 +226,9 @@ Value MemorySlotPromoter::getLazyDefaultValue() {
   if (defaultValue)
     return defaultValue;
 
-  OpBuilder::InsertionGuard guard(builder);
-  builder.setInsertionPointToStart(slot.ptr.getParentBlock());
-  return defaultValue = allocator.getDefaultValue(slot, builder);
+  RewriterBase::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPointToStart(slot.ptr.getParentBlock());
+  return defaultValue = allocator.getDefaultValue(slot, rewriter);
 }
 
 LogicalResult MemorySlotPromotionAnalyzer::computeBlockingUses(
@@ -341,11 +453,37 @@ void MemorySlotPromoter::computeReachingDefInRegion(Region *region,
     Block *block = job.block->getBlock();
 
     if (info.mergePoints.contains(block)) {
-      BlockArgument blockArgument =
-          block->addArgument(slot.elemType, slot.ptr.getLoc());
-      builder.setInsertionPointToStart(block);
-      allocator.handleBlockArgument(slot, blockArgument, builder);
+      // If the block is a merge point, we need to add a block argument to hold
+      // the selected reaching definition. This has to be a bit complicated
+      // because of RewriterBase limitations: we need to create a new block with
+      // the extra block argument, move the content of the block to the new
+      // block, and replace the block with the new block in the merge point set.
+      SmallVector<Type> argTypes;
+      SmallVector<Location> argLocs;
+      for (BlockArgument arg : block->getArguments()) {
+        argTypes.push_back(arg.getType());
+        argLocs.push_back(arg.getLoc());
+      }
+      argTypes.push_back(slot.elemType);
+      argLocs.push_back(slot.ptr.getLoc());
+      Block *newBlock = rewriter.createBlock(block, argTypes, argLocs);
+
+      info.mergePoints.erase(block);
+      info.mergePoints.insert(newBlock);
+
+      rewriter.replaceAllUsesWith(block, newBlock);
+      rewriter.mergeBlocks(block, newBlock,
+                           newBlock->getArguments().drop_back());
+
+      block = newBlock;
+
+      BlockArgument blockArgument = block->getArguments().back();
+      rewriter.setInsertionPointToStart(block);
+      allocator.handleBlockArgument(slot, blockArgument, rewriter);
       job.reachingDef = blockArgument;
+
+      if (statistics.newBlockArgumentAmount)
+        (*statistics.newBlockArgumentAmount)++;
     }
 
     job.reachingDef = computeReachingDefInBlock(block, job.reachingDef);
@@ -355,8 +493,10 @@ void MemorySlotPromoter::computeReachingDefInRegion(Region *region,
         if (info.mergePoints.contains(blockOperand.get())) {
           if (!job.reachingDef)
             job.reachingDef = getLazyDefaultValue();
-          terminator.getSuccessorOperands(blockOperand.getOperandNumber())
-              .append(job.reachingDef);
+          rewriter.updateRootInPlace(terminator, [&]() {
+            terminator.getSuccessorOperands(blockOperand.getOperandNumber())
+                .append(job.reachingDef);
+          });
         }
       }
     }
@@ -382,24 +522,24 @@ void MemorySlotPromoter::removeBlockingUses() {
       if (!reachingDef)
         reachingDef = getLazyDefaultValue();
 
-      builder.setInsertionPointAfter(toPromote);
+      rewriter.setInsertionPointAfter(toPromote);
       if (toPromoteMemOp.removeBlockingUses(
-              slot, info.userToBlockingUses[toPromote], builder, reachingDef) ==
-          DeletionKind::Delete)
+              slot, info.userToBlockingUses[toPromote], rewriter,
+              reachingDef) == DeletionKind::Delete)
         toErase.push_back(toPromote);
 
       continue;
     }
 
     auto toPromoteBasic = cast<PromotableOpInterface>(toPromote);
-    builder.setInsertionPointAfter(toPromote);
+    rewriter.setInsertionPointAfter(toPromote);
     if (toPromoteBasic.removeBlockingUses(info.userToBlockingUses[toPromote],
-                                          builder) == DeletionKind::Delete)
+                                          rewriter) == DeletionKind::Delete)
       toErase.push_back(toPromote);
   }
 
   for (Operation *toEraseOp : toErase)
-    toEraseOp->erase();
+    rewriter.eraseOp(toEraseOp);
 
   assert(slot.ptr.use_empty() &&
          "after promotion, the slot pointer should not be used anymore");
@@ -421,87 +561,73 @@ void MemorySlotPromoter::promoteSlot() {
       assert(succOperands.size() == mergePoint->getNumArguments() ||
              succOperands.size() + 1 == mergePoint->getNumArguments());
       if (succOperands.size() + 1 == mergePoint->getNumArguments())
-        succOperands.append(getLazyDefaultValue());
+        rewriter.updateRootInPlace(
+            user, [&]() { succOperands.append(getLazyDefaultValue()); });
     }
   }
 
   LLVM_DEBUG(llvm::dbgs() << "[mem2reg] Promoted memory slot: " << slot.ptr
                           << "\n");
 
-  allocator.handlePromotionComplete(slot, defaultValue);
+  if (statistics.promotedAmount)
+    (*statistics.promotedAmount)++;
+
+  allocator.handlePromotionComplete(slot, defaultValue, rewriter);
 }
 
 LogicalResult mlir::tryToPromoteMemorySlots(
-    ArrayRef<PromotableAllocationOpInterface> allocators, OpBuilder &builder,
-    DominanceInfo &dominance) {
-  // Actual promotion may invalidate the dominance analysis, so slot promotion
-  // is prepated in batches.
-  SmallVector<MemorySlotPromoter> toPromote;
+    ArrayRef<PromotableAllocationOpInterface> allocators,
+    RewriterBase &rewriter, Mem2RegStatistics statistics) {
+  DominanceInfo dominance;
+
+  bool promotedAny = false;
+
   for (PromotableAllocationOpInterface allocator : allocators) {
     for (MemorySlot slot : allocator.getPromotableSlots()) {
       if (slot.ptr.use_empty())
         continue;
 
+      DominanceInfo dominance;
       MemorySlotPromotionAnalyzer analyzer(slot, dominance);
       std::optional<MemorySlotPromotionInfo> info = analyzer.computeInfo();
-      if (info)
-        toPromote.emplace_back(slot, allocator, builder, dominance,
-                               std::move(*info));
+      if (info) {
+        MemorySlotPromoter(slot, allocator, rewriter, dominance,
+                           std::move(*info), statistics)
+            .promoteSlot();
+        promotedAny = true;
+      }
     }
   }
 
-  for (MemorySlotPromoter &promoter : toPromote)
-    promoter.promoteSlot();
-
-  return success(!toPromote.empty());
+  return success(promotedAny);
 }
 
-LogicalResult Mem2RegPattern::matchAndRewrite(Operation *op,
-                                              PatternRewriter &rewriter) const {
+LogicalResult
+Mem2RegPattern::matchAndRewrite(PromotableAllocationOpInterface allocator,
+                                PatternRewriter &rewriter) const {
   hasBoundedRewriteRecursion();
-
-  if (op->getNumRegions() == 0)
-    return failure();
-
-  DominanceInfo dominance;
-
-  SmallVector<PromotableAllocationOpInterface> allocators;
-  // Build a list of allocators to attempt to promote the slots of.
-  for (Region &region : op->getRegions())
-    for (auto allocator : region.getOps<PromotableAllocationOpInterface>())
-      allocators.emplace_back(allocator);
-
-  // Because pattern rewriters are normally not expressive enough to support a
-  // transformation like mem2reg, this uses an escape hatch to mark modified
-  // operations manually and operate outside of its context.
-  rewriter.startRootUpdate(op);
-
-  OpBuilder builder(rewriter.getContext());
-
-  if (failed(tryToPromoteMemorySlots(allocators, builder, dominance))) {
-    rewriter.cancelRootUpdate(op);
-    return failure();
-  }
-
-  rewriter.finalizeRootUpdate(op);
-  return success();
+  return tryToPromoteMemorySlots({allocator}, rewriter, statistics);
 }
 
 namespace {
 
 struct Mem2Reg : impl::Mem2RegBase<Mem2Reg> {
+  using impl::Mem2RegBase<Mem2Reg>::Mem2RegBase;
+
   void runOnOperation() override {
     Operation *scopeOp = getOperation();
-    bool changed = false;
+
+    Mem2RegStatistics statictics{&promotedAmount, &newBlockArgumentAmount};
+
+    GreedyRewriteConfig config;
+    config.enableRegionSimplification = enableRegionSimplification;
 
     RewritePatternSet rewritePatterns(&getContext());
-    rewritePatterns.add<Mem2RegPattern>(&getContext());
+    rewritePatterns.add<Mem2RegPattern>(&getContext(), statictics);
     FrozenRewritePatternSet frozen(std::move(rewritePatterns));
-    (void)applyOpPatternsAndFold({scopeOp}, frozen, GreedyRewriteConfig(),
-                                 &changed);
 
-    if (!changed)
-      markAllAnalysesPreserved();
+    if (failed(applyPatternsAndFoldGreedily(scopeOp, frozen, config)))
+      signalPassFailure();
   }
 };
 

diff  --git a/mlir/test/Dialect/LLVMIR/mem2reg-dbginfo.mlir b/mlir/test/Dialect/LLVMIR/mem2reg-dbginfo.mlir
index d8d04dfcfec51..0c1908ec8fdce 100644
--- a/mlir/test/Dialect/LLVMIR/mem2reg-dbginfo.mlir
+++ b/mlir/test/Dialect/LLVMIR/mem2reg-dbginfo.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --pass-pipeline='builtin.module(llvm.func(mem2reg))' | FileCheck %s
+// RUN: mlir-opt %s --pass-pipeline='builtin.module(llvm.func(mem2reg{region-simplify=false}))' | FileCheck %s
 
 llvm.func @use(i64)
 llvm.func @use_ptr(!llvm.ptr)

diff  --git a/mlir/test/Dialect/LLVMIR/mem2reg.mlir b/mlir/test/Dialect/LLVMIR/mem2reg.mlir
index 090f9133f7a96..fc696c5073c30 100644
--- a/mlir/test/Dialect/LLVMIR/mem2reg.mlir
+++ b/mlir/test/Dialect/LLVMIR/mem2reg.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --pass-pipeline="builtin.module(llvm.func(mem2reg))" --split-input-file | FileCheck %s
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(llvm.func(mem2reg{region-simplify=false}))" --split-input-file | FileCheck %s
 
 // CHECK-LABEL: llvm.func @default_value
 llvm.func @default_value() -> i32 {

diff  --git a/mlir/test/Dialect/MemRef/mem2reg-statistics.mlir b/mlir/test/Dialect/MemRef/mem2reg-statistics.mlir
new file mode 100644
index 0000000000000..29ca51194ffd3
--- /dev/null
+++ b/mlir/test/Dialect/MemRef/mem2reg-statistics.mlir
@@ -0,0 +1,60 @@
+// RUN: mlir-opt %s --pass-pipeline='builtin.module(func.func(mem2reg))' --split-input-file --mlir-pass-statistics 2>&1 >/dev/null | FileCheck %s
+
+// CHECK: Mem2Reg
+// CHECK-NEXT: (S) 0 new block args
+// CHECK-NEXT: (S) 1 promoted slots
+func.func @basic() -> i32 {
+  %0 = arith.constant 5 : i32
+  %1 = memref.alloca() : memref<i32>
+  memref.store %0, %1[] : memref<i32>
+  %2 = memref.load %1[] : memref<i32>
+  return %2 : i32
+}
+
+// -----
+
+// CHECK: Mem2Reg
+// CHECK-NEXT: (S) 0 new block args
+// CHECK-NEXT: (S) 0 promoted slots
+func.func @no_alloca() -> i32 {
+  %0 = arith.constant 5 : i32
+  return %0 : i32
+}
+
+// -----
+
+// CHECK: Mem2Reg
+// CHECK-NEXT: (S) 2 new block args
+// CHECK-NEXT: (S) 1 promoted slots
+func.func @cycle(%arg0: i64, %arg1: i1, %arg2: i64) {
+  %alloca = memref.alloca() : memref<i64>
+  memref.store %arg2, %alloca[] : memref<i64>
+  cf.cond_br %arg1, ^bb1, ^bb2
+^bb1:
+  %use = memref.load %alloca[] : memref<i64>
+  call @use(%use) : (i64) -> ()
+  memref.store %arg0, %alloca[] : memref<i64>
+  cf.br ^bb2
+^bb2:
+  cf.br ^bb1
+}
+
+func.func @use(%arg: i64) { return }
+
+// -----
+
+// CHECK: Mem2Reg
+// CHECK-NEXT: (S) 0 new block args
+// CHECK-NEXT: (S) 3 promoted slots
+func.func @recursive(%arg: i64) -> i64 {
+  %alloca0 = memref.alloca() : memref<memref<memref<i64>>>
+  %alloca1 = memref.alloca() : memref<memref<i64>>
+  %alloca2 = memref.alloca() : memref<i64>
+  memref.store %arg, %alloca2[] : memref<i64>
+  memref.store %alloca2, %alloca1[] : memref<memref<i64>>
+  memref.store %alloca1, %alloca0[] : memref<memref<memref<i64>>>
+  %load0 = memref.load %alloca0[] : memref<memref<memref<i64>>>
+  %load1 = memref.load %load0[] : memref<memref<i64>>
+  %load2 = memref.load %load1[] : memref<i64>
+  return %load2 : i64
+}

diff  --git a/mlir/test/Dialect/MemRef/mem2reg.mlir b/mlir/test/Dialect/MemRef/mem2reg.mlir
index 86707ac0b4971..d300699f6f342 100644
--- a/mlir/test/Dialect/MemRef/mem2reg.mlir
+++ b/mlir/test/Dialect/MemRef/mem2reg.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --pass-pipeline='builtin.module(func.func(mem2reg))' --split-input-file | FileCheck %s
+// RUN: mlir-opt %s --pass-pipeline='builtin.module(func.func(mem2reg{region-simplify=false}))' --split-input-file | FileCheck %s
 
 // CHECK-LABEL: func.func @basic
 func.func @basic() -> i32 {
@@ -148,20 +148,18 @@ func.func @deny_store_of_alloca(%arg: memref<memref<i32>>) -> i32 {
 
 // CHECK-LABEL: func.func @promotable_nonpromotable_intertwined
 func.func @promotable_nonpromotable_intertwined() -> i32 {
-  // CHECK: %[[VAL:.*]] = arith.constant 5 : i32
-  %0 = arith.constant 5 : i32
   // CHECK: %[[NON_PROMOTED:.*]] = memref.alloca() : memref<i32>
-  %1 = memref.alloca() : memref<i32>
+  %0 = memref.alloca() : memref<i32>
   // CHECK-NOT: = memref.alloca() : memref<memref<i32>>
-  %2 = memref.alloca() : memref<memref<i32>>
-  memref.store %1, %2[] : memref<memref<i32>>
-  %3 = memref.load %2[] : memref<memref<i32>>
+  %1 = memref.alloca() : memref<memref<i32>>
+  memref.store %0, %1[] : memref<memref<i32>>
+  %2 = memref.load %1[] : memref<memref<i32>>
   // CHECK: call @use(%[[NON_PROMOTED]])
-  call @use(%1) : (memref<i32>) -> ()
+  call @use(%0) : (memref<i32>) -> ()
   // CHECK: %[[RES:.*]] = memref.load %[[NON_PROMOTED]][]
-  %4 = memref.load %1[] : memref<i32>
+  %3 = memref.load %0[] : memref<i32>
   // CHECK: return %[[RES]] : i32
-  return %4 : i32
+  return %3 : i32
 }
 
 func.func @use(%arg: memref<i32>) { return }
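
Note on the statistics plumbing in the patch above: the promoter only bumps a counter when the corresponding pointer is non-null (`if (statistics.promotedAmount) (*statistics.promotedAmount)++;`), which appears intended to let callers that do not collect statistics pass a default-constructed `Mem2RegStatistics`. A minimal standalone sketch of that pattern follows. It does not use MLIR: `PromotionStatistics`, `recordPromotion` and `exampleUsage` are hypothetical names, and plain `std::int64_t` counters stand in for whatever counter type the real struct holds, which this part of the diff does not show.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for Mem2RegStatistics. Each counter is optional:
// a null pointer means "the caller is not collecting this statistic".
struct PromotionStatistics {
  std::int64_t *promotedAmount = nullptr;
  std::int64_t *newBlockArgumentAmount = nullptr;
};

// Hypothetical stand-in for promoting one slot that required
// `mergePoints` new block arguments. Counters are updated only if attached.
void recordPromotion(const PromotionStatistics &statistics,
                     unsigned mergePoints) {
  for (unsigned i = 0; i < mergePoints; ++i)
    if (statistics.newBlockArgumentAmount)
      ++*statistics.newBlockArgumentAmount;
  if (statistics.promotedAmount)
    ++*statistics.promotedAmount;
}

// Usage mirroring the @cycle test case: one slot, two new block arguments.
void exampleUsage() {
  std::int64_t promoted = 0, blockArgs = 0;
  PromotionStatistics stats{&promoted, &blockArgs};
  recordPromotion(stats, 2);
  assert(promoted == 1 && blockArgs == 2);

  // No counters attached: the call is a no-op for statistics.
  recordPromotion(PromotionStatistics{}, 3);
  assert(promoted == 1 && blockArgs == 2);
}
```

Passing the struct by value or const reference keeps the promotion code free of any dependency on the pass that owns the counters, at the cost of a null check at every increment.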


        

