[Mlir-commits] [mlir] 084e2b5 - [MLIR][Interfaces] Change MemorySlotInterface to use OpBuilder (#91341)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue May 7 22:40:19 PDT 2024


Author: Christian Ulmann
Date: 2024-05-08T07:40:15+02:00
New Revision: 084e2b53d22c11e013b0a495b65d39aa7f934048

URL: https://github.com/llvm/llvm-project/commit/084e2b53d22c11e013b0a495b65d39aa7f934048
DIFF: https://github.com/llvm/llvm-project/commit/084e2b53d22c11e013b0a495b65d39aa7f934048.diff

LOG: [MLIR][Interfaces] Change MemorySlotInterface to use OpBuilder (#91341)

This commit changes the `MemorySlotInterface` back to using `OpBuilder`
instead of a rewriter. The rewriter was originally introduced in
https://reviews.llvm.org/D150432, but it was shown that patterns are a
bad idea for both Mem2Reg and SROA.
Mem2Reg suffers from the usage of a rewriter due to being forced to
create new basic blocks. This is an issue, as it leads to the
invalidation of the dominance information, which can be expensive to
recompute.

Added: 
    

Modified: 
    mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
    mlir/include/mlir/Transforms/Mem2Reg.h
    mlir/include/mlir/Transforms/SROA.h
    mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
    mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
    mlir/lib/Transforms/Mem2Reg.cpp
    mlir/lib/Transforms/SROA.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
index 764fa6d547b2e..adf182ac7069d 100644
--- a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
+++ b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
@@ -40,42 +40,40 @@ def PromotableAllocationOpInterface
         Provides the default Value of this memory slot. The provided Value
         will be used as the reaching definition of loads done before any store.
         This Value must outlive the promotion and dominate all the uses of this
-        slot's pointer. The provided rewriter can be used to create the default
+        slot's pointer. The provided builder can be used to create the default
         value on the fly.
 
-        The rewriter is located at the beginning of the block where the slot
-        pointer is defined. All IR mutations must happen through the rewriter.
+        The builder is located at the beginning of the block where the slot
+        pointer is defined.
       }], "::mlir::Value", "getDefaultValue",
       (ins
         "const ::mlir::MemorySlot &":$slot,
-        "::mlir::RewriterBase &":$rewriter)
+        "::mlir::OpBuilder &":$builder)
     >,
     InterfaceMethod<[{
         Hook triggered for every new block argument added to a block.
         This will only be called for slots declared by this operation.
 
-        The rewriter is located at the beginning of the block on call. All IR
-        mutations must happen through the rewriter.
+        The builder is located at the beginning of the block on call. All IR
+        mutations must happen through the builder.
       }],
       "void", "handleBlockArgument",
       (ins
         "const ::mlir::MemorySlot &":$slot,
         "::mlir::BlockArgument":$argument,
-        "::mlir::RewriterBase &":$rewriter
+        "::mlir::OpBuilder &":$builder
       )
     >,
     InterfaceMethod<[{
         Hook triggered once the promotion of a slot is complete. This can
         also clean up the created default value if necessary.
         This will only be called for slots declared by this operation.
-
-        All IR mutations must happen through the rewriter.
       }],
       "void", "handlePromotionComplete",
       (ins
         "const ::mlir::MemorySlot &":$slot, 
         "::mlir::Value":$defaultValue,
-        "::mlir::RewriterBase &":$rewriter)
+        "::mlir::OpBuilder &":$builder)
     >,
   ];
 }
@@ -119,15 +117,14 @@ def PromotableMemOpInterface : OpInterface<"PromotableMemOpInterface"> {
         The returned value must dominate all operations dominated by the storing
         operation.
 
-        If IR must be mutated to extract a concrete value being stored, mutation
-        must happen through the provided rewriter. The rewriter is located
-        immediately after the memory operation on call. No IR deletion is
-        allowed in this method. IR mutations must not introduce new uses of the
-        memory slot. Existing control flow must not be modified.
+        The builder is located immediately after the memory operation on call.
+        No IR deletion is allowed in this method. IR mutations must not
+        introduce new uses of the memory slot. Existing control flow must not
+        be modified.
       }],
       "::mlir::Value", "getStored",
       (ins "const ::mlir::MemorySlot &":$slot,
-           "::mlir::RewriterBase &":$rewriter,
+           "::mlir::OpBuilder &":$builder,
            "::mlir::Value":$reachingDef,
            "const ::mlir::DataLayout &":$dataLayout)
     >,
@@ -166,14 +163,13 @@ def PromotableMemOpInterface : OpInterface<"PromotableMemOpInterface"> {
         have been done at the point of calling this method, but it will be done
         eventually.
 
-        The rewriter is located after the promotable operation on call. All IR
-        mutations must happen through the rewriter.
+        The builder is located after the promotable operation on call.
       }],
       "::mlir::DeletionKind",
       "removeBlockingUses",
       (ins "const ::mlir::MemorySlot &":$slot,
            "const ::llvm::SmallPtrSetImpl<mlir::OpOperand *> &":$blockingUses,
-           "::mlir::RewriterBase &":$rewriter,
+           "::mlir::OpBuilder &":$builder,
            "::mlir::Value":$reachingDefinition,
            "const ::mlir::DataLayout &":$dataLayout)
     >,
@@ -224,13 +220,12 @@ def PromotableOpInterface : OpInterface<"PromotableOpInterface"> {
         have been done at the point of calling this method, but it will be done
         eventually.
 
-        The rewriter is located after the promotable operation on call. All IR
-        mutations must happen through the rewriter.
+        The builder is located after the promotable operation on call.
       }],
       "::mlir::DeletionKind",
       "removeBlockingUses",
       (ins "const ::llvm::SmallPtrSetImpl<mlir::OpOperand *> &":$blockingUses,
-           "::mlir::RewriterBase &":$rewriter)
+           "::mlir::OpBuilder &":$builder)
     >,
     InterfaceMethod<[{
         This method allows the promoted operation to visit the SSA values used
@@ -254,13 +249,12 @@ def PromotableOpInterface : OpInterface<"PromotableOpInterface"> {
         scheduled for removal and if `requiresReplacedValues` returned
         true.
 
-        The rewriter is located after the promotable operation on call. All IR
-        mutations must happen through the rewriter. During the transformation,
-        *no operation should be deleted*.
+        The builder is located after the promotable operation on call. During
+        the transformation, *no operation should be deleted*.
       }],
       "void", "visitReplacedValues",
       (ins "::llvm::ArrayRef<std::pair<::mlir::Operation*, ::mlir::Value>>":$mutatedDefs,
-           "::mlir::RewriterBase &":$rewriter), [{}], [{ return; }]
+           "::mlir::OpBuilder &":$builder), [{}], [{ return; }]
     >,
   ];
 }
@@ -293,25 +287,23 @@ def DestructurableAllocationOpInterface
         at the end of this call. Only generates subslots for the indices found in
         `usedIndices` since all other subslots are unused.
 
-        The rewriter is located at the beginning of the block where the slot
-        pointer is defined. All IR mutations must happen through the rewriter.
+        The builder is located at the beginning of the block where the slot
+        pointer is defined.
       }],
       "::llvm::DenseMap<::mlir::Attribute, ::mlir::MemorySlot>",
       "destructure",
       (ins "const ::mlir::DestructurableMemorySlot &":$slot,
            "const ::llvm::SmallPtrSetImpl<::mlir::Attribute> &":$usedIndices,
-           "::mlir::RewriterBase &":$rewriter)
+           "::mlir::OpBuilder &":$builder)
     >,
     InterfaceMethod<[{
         Hook triggered once the destructuring of a slot is complete, meaning the
         original slot is no longer being refered to and could be deleted.
         This will only be called for slots declared by this operation.
-
-        All IR mutations must happen through the rewriter.
       }],
       "void", "handleDestructuringComplete",
       (ins "const ::mlir::DestructurableMemorySlot &":$slot,
-           "::mlir::RewriterBase &":$rewriter)
+           "::mlir::OpBuilder &":$builder)
     >,
   ];
 }
@@ -376,15 +368,14 @@ def DestructurableAccessorOpInterface
         Rewires the use of a slot to the generated subslots, without deleting
         any operation. Returns whether the accessor should be deleted.
 
-        All IR mutations must happen through the rewriter. Deletion of
-        operations is not allowed, only the accessor can be scheduled for
-        deletion by returning the appropriate value.
+        Deletion of operations is not allowed, only the accessor can be
+        scheduled for deletion by returning the appropriate value.
       }],
       "::mlir::DeletionKind",
       "rewire",
       (ins "const ::mlir::DestructurableMemorySlot &":$slot,
            "::llvm::DenseMap<::mlir::Attribute, ::mlir::MemorySlot> &":$subslots,
-           "::mlir::RewriterBase &":$rewriter,
+           "::mlir::OpBuilder &":$builder,
            "const ::mlir::DataLayout &":$dataLayout)
     >
   ];

diff  --git a/mlir/include/mlir/Transforms/Mem2Reg.h b/mlir/include/mlir/Transforms/Mem2Reg.h
index ed10644e26a51..b4f939d654142 100644
--- a/mlir/include/mlir/Transforms/Mem2Reg.h
+++ b/mlir/include/mlir/Transforms/Mem2Reg.h
@@ -27,7 +27,7 @@ struct Mem2RegStatistics {
 /// at least one memory slot was promoted.
 LogicalResult
 tryToPromoteMemorySlots(ArrayRef<PromotableAllocationOpInterface> allocators,
-                        RewriterBase &rewriter, const DataLayout &dataLayout,
+                        OpBuilder &builder, const DataLayout &dataLayout,
                         Mem2RegStatistics statistics = {});
 
 } // namespace mlir

diff  --git a/mlir/include/mlir/Transforms/SROA.h b/mlir/include/mlir/Transforms/SROA.h
index d09a7989edeab..fa84fb1eae73a 100644
--- a/mlir/include/mlir/Transforms/SROA.h
+++ b/mlir/include/mlir/Transforms/SROA.h
@@ -31,7 +31,7 @@ struct SROAStatistics {
 /// failure if no slot was destructured.
 LogicalResult tryToDestructureMemorySlots(
     ArrayRef<DestructurableAllocationOpInterface> allocators,
-    RewriterBase &rewriter, const DataLayout &dataLayout,
+    OpBuilder &builder, const DataLayout &dataLayout,
     SROAStatistics statistics = {});
 
 } // namespace mlir

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
index 6eeb13ebffb51..70102e1c81920 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
@@ -36,26 +36,26 @@ llvm::SmallVector<MemorySlot> LLVM::AllocaOp::getPromotableSlots() {
 }
 
 Value LLVM::AllocaOp::getDefaultValue(const MemorySlot &slot,
-                                      RewriterBase &rewriter) {
-  return rewriter.create<LLVM::UndefOp>(getLoc(), slot.elemType);
+                                      OpBuilder &builder) {
+  return builder.create<LLVM::UndefOp>(getLoc(), slot.elemType);
 }
 
 void LLVM::AllocaOp::handleBlockArgument(const MemorySlot &slot,
                                          BlockArgument argument,
-                                         RewriterBase &rewriter) {
+                                         OpBuilder &builder) {
   for (Operation *user : getOperation()->getUsers())
     if (auto declareOp = llvm::dyn_cast<LLVM::DbgDeclareOp>(user))
-      rewriter.create<LLVM::DbgValueOp>(declareOp.getLoc(), argument,
-                                        declareOp.getVarInfo(),
-                                        declareOp.getLocationExpr());
+      builder.create<LLVM::DbgValueOp>(declareOp.getLoc(), argument,
+                                       declareOp.getVarInfo(),
+                                       declareOp.getLocationExpr());
 }
 
 void LLVM::AllocaOp::handlePromotionComplete(const MemorySlot &slot,
                                              Value defaultValue,
-                                             RewriterBase &rewriter) {
+                                             OpBuilder &builder) {
   if (defaultValue && defaultValue.use_empty())
-    rewriter.eraseOp(defaultValue.getDefiningOp());
-  rewriter.eraseOp(*this);
+    defaultValue.getDefiningOp()->erase();
+  this->erase();
 }
 
 SmallVector<DestructurableMemorySlot> LLVM::AllocaOp::getDestructurableSlots() {
@@ -78,16 +78,16 @@ SmallVector<DestructurableMemorySlot> LLVM::AllocaOp::getDestructurableSlots() {
 DenseMap<Attribute, MemorySlot>
 LLVM::AllocaOp::destructure(const DestructurableMemorySlot &slot,
                             const SmallPtrSetImpl<Attribute> &usedIndices,
-                            RewriterBase &rewriter) {
+                            OpBuilder &builder) {
   assert(slot.ptr == getResult());
-  rewriter.setInsertionPointAfter(*this);
+  builder.setInsertionPointAfter(*this);
 
   auto destructurableType = cast<DestructurableTypeInterface>(getElemType());
   DenseMap<Attribute, MemorySlot> slotMap;
   for (Attribute index : usedIndices) {
     Type elemType = destructurableType.getTypeAtIndex(index);
     assert(elemType && "used index must exist");
-    auto subAlloca = rewriter.create<LLVM::AllocaOp>(
+    auto subAlloca = builder.create<LLVM::AllocaOp>(
         getLoc(), LLVM::LLVMPointerType::get(getContext()), elemType,
         getArraySize());
     slotMap.try_emplace<MemorySlot>(index, {subAlloca.getResult(), elemType});
@@ -97,9 +97,9 @@ LLVM::AllocaOp::destructure(const DestructurableMemorySlot &slot,
 }
 
 void LLVM::AllocaOp::handleDestructuringComplete(
-    const DestructurableMemorySlot &slot, RewriterBase &rewriter) {
+    const DestructurableMemorySlot &slot, OpBuilder &builder) {
   assert(slot.ptr == getResult());
-  rewriter.eraseOp(*this);
+  this->erase();
 }
 
 //===----------------------------------------------------------------------===//
@@ -112,7 +112,7 @@ bool LLVM::LoadOp::loadsFrom(const MemorySlot &slot) {
 
 bool LLVM::LoadOp::storesTo(const MemorySlot &slot) { return false; }
 
-Value LLVM::LoadOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
+Value LLVM::LoadOp::getStored(const MemorySlot &slot, OpBuilder &builder,
                               Value reachingDef, const DataLayout &dataLayout) {
   llvm_unreachable("getStored should not be called on LoadOp");
 }
@@ -175,7 +175,7 @@ static bool isBigEndian(const DataLayout &dataLayout) {
 
 /// Converts a value to an integer type of the same size.
 /// Assumes that the type can be converted.
-static Value castToSameSizedInt(RewriterBase &rewriter, Location loc, Value val,
+static Value castToSameSizedInt(OpBuilder &builder, Location loc, Value val,
                                 const DataLayout &dataLayout) {
   Type type = val.getType();
   assert(isSupportedTypeForConversion(type) &&
@@ -185,15 +185,15 @@ static Value castToSameSizedInt(RewriterBase &rewriter, Location loc, Value val,
     return val;
 
   uint64_t typeBitSize = dataLayout.getTypeSizeInBits(type);
-  IntegerType valueSizeInteger = rewriter.getIntegerType(typeBitSize);
+  IntegerType valueSizeInteger = builder.getIntegerType(typeBitSize);
 
   if (isa<LLVM::LLVMPointerType>(type))
-    return rewriter.createOrFold<LLVM::PtrToIntOp>(loc, valueSizeInteger, val);
-  return rewriter.createOrFold<LLVM::BitcastOp>(loc, valueSizeInteger, val);
+    return builder.createOrFold<LLVM::PtrToIntOp>(loc, valueSizeInteger, val);
+  return builder.createOrFold<LLVM::BitcastOp>(loc, valueSizeInteger, val);
 }
 
 /// Converts a value with an integer type to `targetType`.
-static Value castIntValueToSameSizedType(RewriterBase &rewriter, Location loc,
+static Value castIntValueToSameSizedType(OpBuilder &builder, Location loc,
                                          Value val, Type targetType) {
   assert(isa<IntegerType>(val.getType()) &&
          "expected value to have an integer type");
@@ -202,13 +202,13 @@ static Value castIntValueToSameSizedType(RewriterBase &rewriter, Location loc,
   if (val.getType() == targetType)
     return val;
   if (isa<LLVM::LLVMPointerType>(targetType))
-    return rewriter.createOrFold<LLVM::IntToPtrOp>(loc, targetType, val);
-  return rewriter.createOrFold<LLVM::BitcastOp>(loc, targetType, val);
+    return builder.createOrFold<LLVM::IntToPtrOp>(loc, targetType, val);
+  return builder.createOrFold<LLVM::BitcastOp>(loc, targetType, val);
 }
 
 /// Constructs operations that convert `srcValue` into a new value of type
 /// `targetType`. Assumes the types have the same bitsize.
-static Value castSameSizedTypes(RewriterBase &rewriter, Location loc,
+static Value castSameSizedTypes(OpBuilder &builder, Location loc,
                                 Value srcValue, Type targetType,
                                 const DataLayout &dataLayout) {
   Type srcType = srcValue.getType();
@@ -226,18 +226,18 @@ static Value castSameSizedTypes(RewriterBase &rewriter, Location loc,
   // provenance.
   if (isa<LLVM::LLVMPointerType>(targetType) &&
       isa<LLVM::LLVMPointerType>(srcType))
-    return rewriter.createOrFold<LLVM::AddrSpaceCastOp>(loc, targetType,
-                                                        srcValue);
+    return builder.createOrFold<LLVM::AddrSpaceCastOp>(loc, targetType,
+                                                       srcValue);
 
   // For all other castable types, casting through integers is necessary.
-  Value replacement = castToSameSizedInt(rewriter, loc, srcValue, dataLayout);
-  return castIntValueToSameSizedType(rewriter, loc, replacement, targetType);
+  Value replacement = castToSameSizedInt(builder, loc, srcValue, dataLayout);
+  return castIntValueToSameSizedType(builder, loc, replacement, targetType);
 }
 
 /// Constructs operations that convert `srcValue` into a new value of type
 /// `targetType`. Performs bit-level extraction if the source type is larger
 /// than the target type. Assumes that this conversion is possible.
-static Value createExtractAndCast(RewriterBase &rewriter, Location loc,
+static Value createExtractAndCast(OpBuilder &builder, Location loc,
                                   Value srcValue, Type targetType,
                                   const DataLayout &dataLayout) {
   // Get the types of the source and target values.
@@ -249,31 +249,31 @@ static Value createExtractAndCast(RewriterBase &rewriter, Location loc,
   uint64_t srcTypeSize = dataLayout.getTypeSizeInBits(srcType);
   uint64_t targetTypeSize = dataLayout.getTypeSizeInBits(targetType);
   if (srcTypeSize == targetTypeSize)
-    return castSameSizedTypes(rewriter, loc, srcValue, targetType, dataLayout);
+    return castSameSizedTypes(builder, loc, srcValue, targetType, dataLayout);
 
   // First, cast the value to a same-sized integer type.
-  Value replacement = castToSameSizedInt(rewriter, loc, srcValue, dataLayout);
+  Value replacement = castToSameSizedInt(builder, loc, srcValue, dataLayout);
 
   // Truncate the integer if the size of the target is less than the value.
   if (isBigEndian(dataLayout)) {
     uint64_t shiftAmount = srcTypeSize - targetTypeSize;
-    auto shiftConstant = rewriter.create<LLVM::ConstantOp>(
-        loc, rewriter.getIntegerAttr(srcType, shiftAmount));
+    auto shiftConstant = builder.create<LLVM::ConstantOp>(
+        loc, builder.getIntegerAttr(srcType, shiftAmount));
     replacement =
-        rewriter.createOrFold<LLVM::LShrOp>(loc, srcValue, shiftConstant);
+        builder.createOrFold<LLVM::LShrOp>(loc, srcValue, shiftConstant);
   }
 
-  replacement = rewriter.create<LLVM::TruncOp>(
-      loc, rewriter.getIntegerType(targetTypeSize), replacement);
+  replacement = builder.create<LLVM::TruncOp>(
+      loc, builder.getIntegerType(targetTypeSize), replacement);
 
   // Now cast the integer to the actual target type if required.
-  return castIntValueToSameSizedType(rewriter, loc, replacement, targetType);
+  return castIntValueToSameSizedType(builder, loc, replacement, targetType);
 }
 
 /// Constructs operations that insert the bits of `srcValue` into the
 /// "beginning" of `reachingDef` (beginning is endianness dependent).
 /// Assumes that this conversion is possible.
-static Value createInsertAndCast(RewriterBase &rewriter, Location loc,
+static Value createInsertAndCast(OpBuilder &builder, Location loc,
                                  Value srcValue, Value reachingDef,
                                  const DataLayout &dataLayout) {
 
@@ -284,27 +284,27 @@ static Value createInsertAndCast(RewriterBase &rewriter, Location loc,
   uint64_t valueTypeSize = dataLayout.getTypeSizeInBits(srcValue.getType());
   uint64_t slotTypeSize = dataLayout.getTypeSizeInBits(reachingDef.getType());
   if (slotTypeSize == valueTypeSize)
-    return castSameSizedTypes(rewriter, loc, srcValue, reachingDef.getType(),
+    return castSameSizedTypes(builder, loc, srcValue, reachingDef.getType(),
                               dataLayout);
 
   // In the case where the store only overwrites parts of the memory,
   // bit fiddling is required to construct the new value.
 
   // First convert both values to integers of the same size.
-  Value defAsInt = castToSameSizedInt(rewriter, loc, reachingDef, dataLayout);
-  Value valueAsInt = castToSameSizedInt(rewriter, loc, srcValue, dataLayout);
+  Value defAsInt = castToSameSizedInt(builder, loc, reachingDef, dataLayout);
+  Value valueAsInt = castToSameSizedInt(builder, loc, srcValue, dataLayout);
   // Extend the value to the size of the reaching definition.
   valueAsInt =
-      rewriter.createOrFold<LLVM::ZExtOp>(loc, defAsInt.getType(), valueAsInt);
+      builder.createOrFold<LLVM::ZExtOp>(loc, defAsInt.getType(), valueAsInt);
   uint64_t sizeDifference = slotTypeSize - valueTypeSize;
   if (isBigEndian(dataLayout)) {
     // On big endian systems, a store to the base pointer overwrites the most
     // significant bits. To accomodate for this, the stored value needs to be
     // shifted into the according position.
-    Value bigEndianShift = rewriter.create<LLVM::ConstantOp>(
-        loc, rewriter.getIntegerAttr(defAsInt.getType(), sizeDifference));
+    Value bigEndianShift = builder.create<LLVM::ConstantOp>(
+        loc, builder.getIntegerAttr(defAsInt.getType(), sizeDifference));
     valueAsInt =
-        rewriter.createOrFold<LLVM::ShlOp>(loc, valueAsInt, bigEndianShift);
+        builder.createOrFold<LLVM::ShlOp>(loc, valueAsInt, bigEndianShift);
   }
 
   // Construct the mask that is used to erase the bits that are overwritten by
@@ -322,23 +322,23 @@ static Value createInsertAndCast(RewriterBase &rewriter, Location loc,
   }
 
   // Mask out the affected bits ...
-  Value mask = rewriter.create<LLVM::ConstantOp>(
-      loc, rewriter.getIntegerAttr(defAsInt.getType(), maskValue));
-  Value masked = rewriter.createOrFold<LLVM::AndOp>(loc, defAsInt, mask);
+  Value mask = builder.create<LLVM::ConstantOp>(
+      loc, builder.getIntegerAttr(defAsInt.getType(), maskValue));
+  Value masked = builder.createOrFold<LLVM::AndOp>(loc, defAsInt, mask);
 
   // ... and combine the result with the new value.
-  Value combined = rewriter.createOrFold<LLVM::OrOp>(loc, masked, valueAsInt);
+  Value combined = builder.createOrFold<LLVM::OrOp>(loc, masked, valueAsInt);
 
-  return castIntValueToSameSizedType(rewriter, loc, combined,
+  return castIntValueToSameSizedType(builder, loc, combined,
                                      reachingDef.getType());
 }
 
-Value LLVM::StoreOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
+Value LLVM::StoreOp::getStored(const MemorySlot &slot, OpBuilder &builder,
                                Value reachingDef,
                                const DataLayout &dataLayout) {
   assert(reachingDef && reachingDef.getType() == slot.elemType &&
          "expected the reaching definition's type to match the slot's type");
-  return createInsertAndCast(rewriter, getLoc(), getValue(), reachingDef,
+  return createInsertAndCast(builder, getLoc(), getValue(), reachingDef,
                              dataLayout);
 }
 
@@ -360,13 +360,13 @@ bool LLVM::LoadOp::canUsesBeRemoved(
 
 DeletionKind LLVM::LoadOp::removeBlockingUses(
     const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
-    RewriterBase &rewriter, Value reachingDefinition,
+    OpBuilder &builder, Value reachingDefinition,
     const DataLayout &dataLayout) {
   // `canUsesBeRemoved` checked this blocking use must be the loaded slot
   // pointer.
-  Value newResult = createExtractAndCast(rewriter, getLoc(), reachingDefinition,
+  Value newResult = createExtractAndCast(builder, getLoc(), reachingDefinition,
                                          getResult().getType(), dataLayout);
-  rewriter.replaceAllUsesWith(getResult(), newResult);
+  getResult().replaceAllUsesWith(newResult);
   return DeletionKind::Delete;
 }
 
@@ -390,7 +390,7 @@ bool LLVM::StoreOp::canUsesBeRemoved(
 
 DeletionKind LLVM::StoreOp::removeBlockingUses(
     const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
-    RewriterBase &rewriter, Value reachingDefinition,
+    OpBuilder &builder, Value reachingDefinition,
     const DataLayout &dataLayout) {
   return DeletionKind::Delete;
 }
@@ -452,14 +452,13 @@ bool LLVM::LoadOp::canRewire(const DestructurableMemorySlot &slot,
 
 DeletionKind LLVM::LoadOp::rewire(const DestructurableMemorySlot &slot,
                                   DenseMap<Attribute, MemorySlot> &subslots,
-                                  RewriterBase &rewriter,
+                                  OpBuilder &builder,
                                   const DataLayout &dataLayout) {
   auto index = IntegerAttr::get(IntegerType::get(getContext(), 32), 0);
   auto it = subslots.find(index);
   assert(it != subslots.end());
 
-  rewriter.modifyOpInPlace(
-      *this, [&]() { getAddrMutable().set(it->getSecond().ptr); });
+  getAddrMutable().set(it->getSecond().ptr);
   return DeletionKind::Keep;
 }
 
@@ -491,14 +490,13 @@ bool LLVM::StoreOp::canRewire(const DestructurableMemorySlot &slot,
 
 DeletionKind LLVM::StoreOp::rewire(const DestructurableMemorySlot &slot,
                                    DenseMap<Attribute, MemorySlot> &subslots,
-                                   RewriterBase &rewriter,
+                                   OpBuilder &builder,
                                    const DataLayout &dataLayout) {
   auto index = IntegerAttr::get(IntegerType::get(getContext(), 32), 0);
   auto it = subslots.find(index);
   assert(it != subslots.end());
 
-  rewriter.modifyOpInPlace(
-      *this, [&]() { getAddrMutable().set(it->getSecond().ptr); });
+  getAddrMutable().set(it->getSecond().ptr);
   return DeletionKind::Keep;
 }
 
@@ -523,7 +521,7 @@ bool LLVM::BitcastOp::canUsesBeRemoved(
 }
 
 DeletionKind LLVM::BitcastOp::removeBlockingUses(
-    const SmallPtrSetImpl<OpOperand *> &blockingUses, RewriterBase &rewriter) {
+    const SmallPtrSetImpl<OpOperand *> &blockingUses, OpBuilder &builder) {
   return DeletionKind::Delete;
 }
 
@@ -535,7 +533,7 @@ bool LLVM::AddrSpaceCastOp::canUsesBeRemoved(
 }
 
 DeletionKind LLVM::AddrSpaceCastOp::removeBlockingUses(
-    const SmallPtrSetImpl<OpOperand *> &blockingUses, RewriterBase &rewriter) {
+    const SmallPtrSetImpl<OpOperand *> &blockingUses, OpBuilder &builder) {
   return DeletionKind::Delete;
 }
 
@@ -547,7 +545,7 @@ bool LLVM::LifetimeStartOp::canUsesBeRemoved(
 }
 
 DeletionKind LLVM::LifetimeStartOp::removeBlockingUses(
-    const SmallPtrSetImpl<OpOperand *> &blockingUses, RewriterBase &rewriter) {
+    const SmallPtrSetImpl<OpOperand *> &blockingUses, OpBuilder &builder) {
   return DeletionKind::Delete;
 }
 
@@ -559,7 +557,7 @@ bool LLVM::LifetimeEndOp::canUsesBeRemoved(
 }
 
 DeletionKind LLVM::LifetimeEndOp::removeBlockingUses(
-    const SmallPtrSetImpl<OpOperand *> &blockingUses, RewriterBase &rewriter) {
+    const SmallPtrSetImpl<OpOperand *> &blockingUses, OpBuilder &builder) {
   return DeletionKind::Delete;
 }
 
@@ -571,7 +569,7 @@ bool LLVM::InvariantStartOp::canUsesBeRemoved(
 }
 
 DeletionKind LLVM::InvariantStartOp::removeBlockingUses(
-    const SmallPtrSetImpl<OpOperand *> &blockingUses, RewriterBase &rewriter) {
+    const SmallPtrSetImpl<OpOperand *> &blockingUses, OpBuilder &builder) {
   return DeletionKind::Delete;
 }
 
@@ -583,7 +581,7 @@ bool LLVM::InvariantEndOp::canUsesBeRemoved(
 }
 
 DeletionKind LLVM::InvariantEndOp::removeBlockingUses(
-    const SmallPtrSetImpl<OpOperand *> &blockingUses, RewriterBase &rewriter) {
+    const SmallPtrSetImpl<OpOperand *> &blockingUses, OpBuilder &builder) {
   return DeletionKind::Delete;
 }
 
@@ -595,7 +593,7 @@ bool LLVM::DbgDeclareOp::canUsesBeRemoved(
 }
 
 DeletionKind LLVM::DbgDeclareOp::removeBlockingUses(
-    const SmallPtrSetImpl<OpOperand *> &blockingUses, RewriterBase &rewriter) {
+    const SmallPtrSetImpl<OpOperand *> &blockingUses, OpBuilder &builder) {
   return DeletionKind::Delete;
 }
 
@@ -611,28 +609,27 @@ bool LLVM::DbgValueOp::canUsesBeRemoved(
 }
 
 DeletionKind LLVM::DbgValueOp::removeBlockingUses(
-    const SmallPtrSetImpl<OpOperand *> &blockingUses, RewriterBase &rewriter) {
-  // Rewriter by default is after '*this', but we need it before '*this'.
-  rewriter.setInsertionPoint(*this);
+    const SmallPtrSetImpl<OpOperand *> &blockingUses, OpBuilder &builder) {
+  // builder by default is after '*this', but we need it before '*this'.
+  builder.setInsertionPoint(*this);
 
   // Rather than dropping the debug value, replace it with undef to preserve the
   // debug local variable info. This allows the debugger to inform the user that
   // the variable has been optimized out.
   auto undef =
-      rewriter.create<UndefOp>(getValue().getLoc(), getValue().getType());
-  rewriter.modifyOpInPlace(*this, [&] { getValueMutable().assign(undef); });
+      builder.create<UndefOp>(getValue().getLoc(), getValue().getType());
+  getValueMutable().assign(undef);
   return DeletionKind::Keep;
 }
 
 bool LLVM::DbgDeclareOp::requiresReplacedValues() { return true; }
 
 void LLVM::DbgDeclareOp::visitReplacedValues(
-    ArrayRef<std::pair<Operation *, Value>> definitions,
-    RewriterBase &rewriter) {
+    ArrayRef<std::pair<Operation *, Value>> definitions, OpBuilder &builder) {
   for (auto [op, value] : definitions) {
-    rewriter.setInsertionPointAfter(op);
-    rewriter.create<LLVM::DbgValueOp>(getLoc(), value, getVarInfo(),
-                                      getLocationExpr());
+    builder.setInsertionPointAfter(op);
+    builder.create<LLVM::DbgValueOp>(getLoc(), value, getVarInfo(),
+                                     getLocationExpr());
   }
 }
 
@@ -658,7 +655,7 @@ bool LLVM::GEPOp::canUsesBeRemoved(
 }
 
 DeletionKind LLVM::GEPOp::removeBlockingUses(
-    const SmallPtrSetImpl<OpOperand *> &blockingUses, RewriterBase &rewriter) {
+    const SmallPtrSetImpl<OpOperand *> &blockingUses, OpBuilder &builder) {
   return DeletionKind::Delete;
 }
 
@@ -855,7 +852,7 @@ bool LLVM::GEPOp::canRewire(const DestructurableMemorySlot &slot,
 
 DeletionKind LLVM::GEPOp::rewire(const DestructurableMemorySlot &slot,
                                  DenseMap<Attribute, MemorySlot> &subslots,
-                                 RewriterBase &rewriter,
+                                 OpBuilder &builder,
                                  const DataLayout &dataLayout) {
   std::optional<SubslotAccessInfo> accessInfo =
       getSubslotAccessInfo(slot, dataLayout, *this);
@@ -864,11 +861,11 @@ DeletionKind LLVM::GEPOp::rewire(const DestructurableMemorySlot &slot,
       IntegerAttr::get(IntegerType::get(getContext(), 32), accessInfo->index);
   const MemorySlot &newSlot = subslots.at(indexAttr);
 
-  auto byteType = IntegerType::get(rewriter.getContext(), 8);
-  auto newPtr = rewriter.createOrFold<LLVM::GEPOp>(
+  auto byteType = IntegerType::get(builder.getContext(), 8);
+  auto newPtr = builder.createOrFold<LLVM::GEPOp>(
       getLoc(), getResult().getType(), byteType, newSlot.ptr,
       ArrayRef<GEPArg>(accessInfo->subslotOffset), getInbounds());
-  rewriter.replaceAllUsesWith(getResult(), newPtr);
+  getResult().replaceAllUsesWith(newPtr);
   return DeletionKind::Delete;
 }
 
@@ -938,7 +935,7 @@ bool LLVM::MemsetOp::storesTo(const MemorySlot &slot) {
   return getDst() == slot.ptr;
 }
 
-Value LLVM::MemsetOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
+Value LLVM::MemsetOp::getStored(const MemorySlot &slot, OpBuilder &builder,
                                 Value reachingDef,
                                 const DataLayout &dataLayout) {
   // TODO: Support non-integer types.
@@ -953,14 +950,14 @@ Value LLVM::MemsetOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
         // or-ing it with the previous value.
         uint64_t coveredBits = 8;
         Value currentValue =
-            rewriter.create<LLVM::ZExtOp>(getLoc(), intType, getVal());
+            builder.create<LLVM::ZExtOp>(getLoc(), intType, getVal());
         while (coveredBits < intType.getWidth()) {
           Value shiftBy =
-              rewriter.create<LLVM::ConstantOp>(getLoc(), intType, coveredBits);
+              builder.create<LLVM::ConstantOp>(getLoc(), intType, coveredBits);
           Value shifted =
-              rewriter.create<LLVM::ShlOp>(getLoc(), currentValue, shiftBy);
+              builder.create<LLVM::ShlOp>(getLoc(), currentValue, shiftBy);
           currentValue =
-              rewriter.create<LLVM::OrOp>(getLoc(), currentValue, shifted);
+              builder.create<LLVM::OrOp>(getLoc(), currentValue, shifted);
           coveredBits *= 2;
         }
 
@@ -994,7 +991,7 @@ bool LLVM::MemsetOp::canUsesBeRemoved(
 
 DeletionKind LLVM::MemsetOp::removeBlockingUses(
     const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
-    RewriterBase &rewriter, Value reachingDefinition,
+    OpBuilder &builder, Value reachingDefinition,
     const DataLayout &dataLayout) {
   return DeletionKind::Delete;
 }
@@ -1026,7 +1023,7 @@ bool LLVM::MemsetOp::canRewire(const DestructurableMemorySlot &slot,
 
 DeletionKind LLVM::MemsetOp::rewire(const DestructurableMemorySlot &slot,
                                     DenseMap<Attribute, MemorySlot> &subslots,
-                                    RewriterBase &rewriter,
+                                    OpBuilder &builder,
                                     const DataLayout &dataLayout) {
   std::optional<DenseMap<Attribute, Type>> types =
       cast<DestructurableTypeInterface>(slot.elemType).getSubelementIndexMap();
@@ -1063,15 +1060,14 @@ DeletionKind LLVM::MemsetOp::rewire(const DestructurableMemorySlot &slot,
       uint64_t newMemsetSize = std::min(memsetLen - covered, typeSize);
 
       Value newMemsetSizeValue =
-          rewriter
+          builder
               .create<LLVM::ConstantOp>(
                   getLen().getLoc(),
                   IntegerAttr::get(memsetLenAttr.getType(), newMemsetSize))
               .getResult();
 
-      rewriter.create<LLVM::MemsetOp>(getLoc(), subslots.at(index).ptr,
-                                      getVal(), newMemsetSizeValue,
-                                      getIsVolatile());
+      builder.create<LLVM::MemsetOp>(getLoc(), subslots.at(index).ptr, getVal(),
+                                     newMemsetSizeValue, getIsVolatile());
     }
 
     covered += typeSize;
@@ -1096,8 +1092,8 @@ static bool memcpyStoresTo(MemcpyLike op, const MemorySlot &slot) {
 
 template <class MemcpyLike>
 static Value memcpyGetStored(MemcpyLike op, const MemorySlot &slot,
-                             RewriterBase &rewriter) {
-  return rewriter.create<LLVM::LoadOp>(op.getLoc(), slot.elemType, op.getSrc());
+                             OpBuilder &builder) {
+  return builder.create<LLVM::LoadOp>(op.getLoc(), slot.elemType, op.getSrc());
 }
 
 template <class MemcpyLike>
@@ -1122,10 +1118,9 @@ template <class MemcpyLike>
 static DeletionKind
 memcpyRemoveBlockingUses(MemcpyLike op, const MemorySlot &slot,
                          const SmallPtrSetImpl<OpOperand *> &blockingUses,
-                         RewriterBase &rewriter, Value reachingDefinition) {
+                         OpBuilder &builder, Value reachingDefinition) {
   if (op.loadsFrom(slot))
-    rewriter.create<LLVM::StoreOp>(op.getLoc(), reachingDefinition,
-                                   op.getDst());
+    builder.create<LLVM::StoreOp>(op.getLoc(), reachingDefinition, op.getDst());
   return DeletionKind::Delete;
 }
 
@@ -1168,23 +1163,23 @@ static bool memcpyCanRewire(MemcpyLike op, const DestructurableMemorySlot &slot,
 namespace {
 
 template <class MemcpyLike>
-void createMemcpyLikeToReplace(RewriterBase &rewriter, const DataLayout &layout,
+void createMemcpyLikeToReplace(OpBuilder &builder, const DataLayout &layout,
                                MemcpyLike toReplace, Value dst, Value src,
                                Type toCpy, bool isVolatile) {
-  Value memcpySize = rewriter.create<LLVM::ConstantOp>(
+  Value memcpySize = builder.create<LLVM::ConstantOp>(
       toReplace.getLoc(), IntegerAttr::get(toReplace.getLen().getType(),
                                            layout.getTypeSize(toCpy)));
-  rewriter.create<MemcpyLike>(toReplace.getLoc(), dst, src, memcpySize,
-                              isVolatile);
+  builder.create<MemcpyLike>(toReplace.getLoc(), dst, src, memcpySize,
+                             isVolatile);
 }
 
 template <>
-void createMemcpyLikeToReplace(RewriterBase &rewriter, const DataLayout &layout,
+void createMemcpyLikeToReplace(OpBuilder &builder, const DataLayout &layout,
                                LLVM::MemcpyInlineOp toReplace, Value dst,
                                Value src, Type toCpy, bool isVolatile) {
   Type lenType = IntegerType::get(toReplace->getContext(),
                                   toReplace.getLen().getBitWidth());
-  rewriter.create<LLVM::MemcpyInlineOp>(
+  builder.create<LLVM::MemcpyInlineOp>(
       toReplace.getLoc(), dst, src,
       IntegerAttr::get(lenType, layout.getTypeSize(toCpy)), isVolatile);
 }
@@ -1196,7 +1191,7 @@ void createMemcpyLikeToReplace(RewriterBase &rewriter, const DataLayout &layout,
 template <class MemcpyLike>
 static DeletionKind
 memcpyRewire(MemcpyLike op, const DestructurableMemorySlot &slot,
-             DenseMap<Attribute, MemorySlot> &subslots, RewriterBase &rewriter,
+             DenseMap<Attribute, MemorySlot> &subslots, OpBuilder &builder,
              const DataLayout &dataLayout) {
   if (subslots.empty())
     return DeletionKind::Delete;
@@ -1226,12 +1221,12 @@ memcpyRewire(MemcpyLike op, const DestructurableMemorySlot &slot,
     SmallVector<LLVM::GEPArg> gepIndices{
         0, static_cast<int32_t>(
                cast<IntegerAttr>(index).getValue().getZExtValue())};
-    Value subslotPtrInOther = rewriter.create<LLVM::GEPOp>(
+    Value subslotPtrInOther = builder.create<LLVM::GEPOp>(
         op.getLoc(), LLVM::LLVMPointerType::get(op.getContext()), slot.elemType,
         isDst ? op.getSrc() : op.getDst(), gepIndices);
 
     // Then create a new memcpy out of this source pointer.
-    createMemcpyLikeToReplace(rewriter, dataLayout, op,
+    createMemcpyLikeToReplace(builder, dataLayout, op,
                               isDst ? subslot.ptr : subslotPtrInOther,
                               isDst ? subslotPtrInOther : subslot.ptr,
                               subslot.elemType, op.getIsVolatile());
@@ -1250,10 +1245,10 @@ bool LLVM::MemcpyOp::storesTo(const MemorySlot &slot) {
   return memcpyStoresTo(*this, slot);
 }
 
-Value LLVM::MemcpyOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
+Value LLVM::MemcpyOp::getStored(const MemorySlot &slot, OpBuilder &builder,
                                 Value reachingDef,
                                 const DataLayout &dataLayout) {
-  return memcpyGetStored(*this, slot, rewriter);
+  return memcpyGetStored(*this, slot, builder);
 }
 
 bool LLVM::MemcpyOp::canUsesBeRemoved(
@@ -1266,9 +1261,9 @@ bool LLVM::MemcpyOp::canUsesBeRemoved(
 
 DeletionKind LLVM::MemcpyOp::removeBlockingUses(
     const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
-    RewriterBase &rewriter, Value reachingDefinition,
+    OpBuilder &builder, Value reachingDefinition,
     const DataLayout &dataLayout) {
-  return memcpyRemoveBlockingUses(*this, slot, blockingUses, rewriter,
+  return memcpyRemoveBlockingUses(*this, slot, blockingUses, builder,
                                   reachingDefinition);
 }
 
@@ -1288,9 +1283,9 @@ bool LLVM::MemcpyOp::canRewire(const DestructurableMemorySlot &slot,
 
 DeletionKind LLVM::MemcpyOp::rewire(const DestructurableMemorySlot &slot,
                                     DenseMap<Attribute, MemorySlot> &subslots,
-                                    RewriterBase &rewriter,
+                                    OpBuilder &builder,
                                     const DataLayout &dataLayout) {
-  return memcpyRewire(*this, slot, subslots, rewriter, dataLayout);
+  return memcpyRewire(*this, slot, subslots, builder, dataLayout);
 }
 
 bool LLVM::MemcpyInlineOp::loadsFrom(const MemorySlot &slot) {
@@ -1302,9 +1297,9 @@ bool LLVM::MemcpyInlineOp::storesTo(const MemorySlot &slot) {
 }
 
 Value LLVM::MemcpyInlineOp::getStored(const MemorySlot &slot,
-                                      RewriterBase &rewriter, Value reachingDef,
+                                      OpBuilder &builder, Value reachingDef,
                                       const DataLayout &dataLayout) {
-  return memcpyGetStored(*this, slot, rewriter);
+  return memcpyGetStored(*this, slot, builder);
 }
 
 bool LLVM::MemcpyInlineOp::canUsesBeRemoved(
@@ -1317,9 +1312,9 @@ bool LLVM::MemcpyInlineOp::canUsesBeRemoved(
 
 DeletionKind LLVM::MemcpyInlineOp::removeBlockingUses(
     const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
-    RewriterBase &rewriter, Value reachingDefinition,
+    OpBuilder &builder, Value reachingDefinition,
     const DataLayout &dataLayout) {
-  return memcpyRemoveBlockingUses(*this, slot, blockingUses, rewriter,
+  return memcpyRemoveBlockingUses(*this, slot, blockingUses, builder,
                                   reachingDefinition);
 }
 
@@ -1341,9 +1336,8 @@ bool LLVM::MemcpyInlineOp::canRewire(
 DeletionKind
 LLVM::MemcpyInlineOp::rewire(const DestructurableMemorySlot &slot,
                              DenseMap<Attribute, MemorySlot> &subslots,
-                             RewriterBase &rewriter,
-                             const DataLayout &dataLayout) {
-  return memcpyRewire(*this, slot, subslots, rewriter, dataLayout);
+                             OpBuilder &builder, const DataLayout &dataLayout) {
+  return memcpyRewire(*this, slot, subslots, builder, dataLayout);
 }
 
 bool LLVM::MemmoveOp::loadsFrom(const MemorySlot &slot) {
@@ -1354,10 +1348,10 @@ bool LLVM::MemmoveOp::storesTo(const MemorySlot &slot) {
   return memcpyStoresTo(*this, slot);
 }
 
-Value LLVM::MemmoveOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
+Value LLVM::MemmoveOp::getStored(const MemorySlot &slot, OpBuilder &builder,
                                  Value reachingDef,
                                  const DataLayout &dataLayout) {
-  return memcpyGetStored(*this, slot, rewriter);
+  return memcpyGetStored(*this, slot, builder);
 }
 
 bool LLVM::MemmoveOp::canUsesBeRemoved(
@@ -1370,9 +1364,9 @@ bool LLVM::MemmoveOp::canUsesBeRemoved(
 
 DeletionKind LLVM::MemmoveOp::removeBlockingUses(
     const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
-    RewriterBase &rewriter, Value reachingDefinition,
+    OpBuilder &builder, Value reachingDefinition,
     const DataLayout &dataLayout) {
-  return memcpyRemoveBlockingUses(*this, slot, blockingUses, rewriter,
+  return memcpyRemoveBlockingUses(*this, slot, blockingUses, builder,
                                   reachingDefinition);
 }
 
@@ -1392,9 +1386,9 @@ bool LLVM::MemmoveOp::canRewire(const DestructurableMemorySlot &slot,
 
 DeletionKind LLVM::MemmoveOp::rewire(const DestructurableMemorySlot &slot,
                                      DenseMap<Attribute, MemorySlot> &subslots,
-                                     RewriterBase &rewriter,
+                                     OpBuilder &builder,
                                      const DataLayout &dataLayout) {
-  return memcpyRewire(*this, slot, subslots, rewriter, dataLayout);
+  return memcpyRewire(*this, slot, subslots, builder, dataLayout);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
index 958c5f0c8dbc7..dca07e84ea73c 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
@@ -83,30 +83,30 @@ SmallVector<MemorySlot> memref::AllocaOp::getPromotableSlots() {
 }
 
 Value memref::AllocaOp::getDefaultValue(const MemorySlot &slot,
-                                        RewriterBase &rewriter) {
+                                        OpBuilder &builder) {
   assert(isSupportedElementType(slot.elemType));
   // TODO: support more types.
   return TypeSwitch<Type, Value>(slot.elemType)
       .Case([&](MemRefType t) {
-        return rewriter.create<memref::AllocaOp>(getLoc(), t);
+        return builder.create<memref::AllocaOp>(getLoc(), t);
       })
       .Default([&](Type t) {
-        return rewriter.create<arith::ConstantOp>(getLoc(), t,
-                                                  rewriter.getZeroAttr(t));
+        return builder.create<arith::ConstantOp>(getLoc(), t,
+                                                 builder.getZeroAttr(t));
       });
 }
 
 void memref::AllocaOp::handlePromotionComplete(const MemorySlot &slot,
                                                Value defaultValue,
-                                               RewriterBase &rewriter) {
+                                               OpBuilder &builder) {
   if (defaultValue.use_empty())
-    rewriter.eraseOp(defaultValue.getDefiningOp());
-  rewriter.eraseOp(*this);
+    defaultValue.getDefiningOp()->erase();
+  this->erase();
 }
 
 void memref::AllocaOp::handleBlockArgument(const MemorySlot &slot,
                                            BlockArgument argument,
-                                           RewriterBase &rewriter) {}
+                                           OpBuilder &builder) {}
 
 SmallVector<DestructurableMemorySlot>
 memref::AllocaOp::getDestructurableSlots() {
@@ -127,8 +127,8 @@ memref::AllocaOp::getDestructurableSlots() {
 DenseMap<Attribute, MemorySlot>
 memref::AllocaOp::destructure(const DestructurableMemorySlot &slot,
                               const SmallPtrSetImpl<Attribute> &usedIndices,
-                              RewriterBase &rewriter) {
-  rewriter.setInsertionPointAfter(*this);
+                              OpBuilder &builder) {
+  builder.setInsertionPointAfter(*this);
 
   DenseMap<Attribute, MemorySlot> slotMap;
 
@@ -136,7 +136,7 @@ memref::AllocaOp::destructure(const DestructurableMemorySlot &slot,
   for (Attribute usedIndex : usedIndices) {
     Type elemType = memrefType.getTypeAtIndex(usedIndex);
     MemRefType elemPtr = MemRefType::get({}, elemType);
-    auto subAlloca = rewriter.create<memref::AllocaOp>(getLoc(), elemPtr);
+    auto subAlloca = builder.create<memref::AllocaOp>(getLoc(), elemPtr);
     slotMap.try_emplace<MemorySlot>(usedIndex,
                                     {subAlloca.getResult(), elemType});
   }
@@ -145,9 +145,9 @@ memref::AllocaOp::destructure(const DestructurableMemorySlot &slot,
 }
 
 void memref::AllocaOp::handleDestructuringComplete(
-    const DestructurableMemorySlot &slot, RewriterBase &rewriter) {
+    const DestructurableMemorySlot &slot, OpBuilder &builder) {
   assert(slot.ptr == getResult());
-  rewriter.eraseOp(*this);
+  this->erase();
 }
 
 //===----------------------------------------------------------------------===//
@@ -160,7 +160,7 @@ bool memref::LoadOp::loadsFrom(const MemorySlot &slot) {
 
 bool memref::LoadOp::storesTo(const MemorySlot &slot) { return false; }
 
-Value memref::LoadOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
+Value memref::LoadOp::getStored(const MemorySlot &slot, OpBuilder &builder,
                                 Value reachingDef,
                                 const DataLayout &dataLayout) {
   llvm_unreachable("getStored should not be called on LoadOp");
@@ -179,11 +179,11 @@ bool memref::LoadOp::canUsesBeRemoved(
 
 DeletionKind memref::LoadOp::removeBlockingUses(
     const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
-    RewriterBase &rewriter, Value reachingDefinition,
+    OpBuilder &builder, Value reachingDefinition,
     const DataLayout &dataLayout) {
   // `canUsesBeRemoved` checked this blocking use must be the loaded slot
   // pointer.
-  rewriter.replaceAllUsesWith(getResult(), reachingDefinition);
+  getResult().replaceAllUsesWith(reachingDefinition);
   return DeletionKind::Delete;
 }
 
@@ -224,15 +224,13 @@ bool memref::LoadOp::canRewire(const DestructurableMemorySlot &slot,
 
 DeletionKind memref::LoadOp::rewire(const DestructurableMemorySlot &slot,
                                     DenseMap<Attribute, MemorySlot> &subslots,
-                                    RewriterBase &rewriter,
+                                    OpBuilder &builder,
                                     const DataLayout &dataLayout) {
   Attribute index = getAttributeIndexFromIndexOperands(
       getContext(), getIndices(), getMemRefType());
   const MemorySlot &memorySlot = subslots.at(index);
-  rewriter.modifyOpInPlace(*this, [&]() {
-    setMemRef(memorySlot.ptr);
-    getIndicesMutable().clear();
-  });
+  setMemRef(memorySlot.ptr);
+  getIndicesMutable().clear();
   return DeletionKind::Keep;
 }
 
@@ -242,7 +240,7 @@ bool memref::StoreOp::storesTo(const MemorySlot &slot) {
   return getMemRef() == slot.ptr;
 }
 
-Value memref::StoreOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
+Value memref::StoreOp::getStored(const MemorySlot &slot, OpBuilder &builder,
                                  Value reachingDef,
                                  const DataLayout &dataLayout) {
   return getValue();
@@ -261,7 +259,7 @@ bool memref::StoreOp::canUsesBeRemoved(
 
 DeletionKind memref::StoreOp::removeBlockingUses(
     const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
-    RewriterBase &rewriter, Value reachingDefinition,
+    OpBuilder &builder, Value reachingDefinition,
     const DataLayout &dataLayout) {
   return DeletionKind::Delete;
 }
@@ -282,15 +280,13 @@ bool memref::StoreOp::canRewire(const DestructurableMemorySlot &slot,
 
 DeletionKind memref::StoreOp::rewire(const DestructurableMemorySlot &slot,
                                      DenseMap<Attribute, MemorySlot> &subslots,
-                                     RewriterBase &rewriter,
+                                     OpBuilder &builder,
                                      const DataLayout &dataLayout) {
   Attribute index = getAttributeIndexFromIndexOperands(
       getContext(), getIndices(), getMemRefType());
   const MemorySlot &memorySlot = subslots.at(index);
-  rewriter.modifyOpInPlace(*this, [&]() {
-    setMemRef(memorySlot.ptr);
-    getIndicesMutable().clear();
-  });
+  setMemRef(memorySlot.ptr);
+  getIndicesMutable().clear();
   return DeletionKind::Keep;
 }
 

diff  --git a/mlir/lib/Transforms/Mem2Reg.cpp b/mlir/lib/Transforms/Mem2Reg.cpp
index 71ba5bc076f0e..1d7ba4ca4f83e 100644
--- a/mlir/lib/Transforms/Mem2Reg.cpp
+++ b/mlir/lib/Transforms/Mem2Reg.cpp
@@ -164,7 +164,7 @@ class MemorySlotPromotionAnalyzer {
 class MemorySlotPromoter {
 public:
   MemorySlotPromoter(MemorySlot slot, PromotableAllocationOpInterface allocator,
-                     RewriterBase &rewriter, DominanceInfo &dominance,
+                     OpBuilder &builder, DominanceInfo &dominance,
                      const DataLayout &dataLayout, MemorySlotPromotionInfo info,
                      const Mem2RegStatistics &statistics);
 
@@ -195,7 +195,7 @@ class MemorySlotPromoter {
 
   MemorySlot slot;
   PromotableAllocationOpInterface allocator;
-  RewriterBase &rewriter;
+  OpBuilder &builder;
   /// Potentially non-initialized default value. Use `getOrCreateDefaultValue`
   /// to initialize it on demand.
   Value defaultValue;
@@ -213,12 +213,10 @@ class MemorySlotPromoter {
 
 MemorySlotPromoter::MemorySlotPromoter(
     MemorySlot slot, PromotableAllocationOpInterface allocator,
-    RewriterBase &rewriter, DominanceInfo &dominance,
-    const DataLayout &dataLayout, MemorySlotPromotionInfo info,
-    const Mem2RegStatistics &statistics)
-    : slot(slot), allocator(allocator), rewriter(rewriter),
-      dominance(dominance), dataLayout(dataLayout), info(std::move(info)),
-      statistics(statistics) {
+    OpBuilder &builder, DominanceInfo &dominance, const DataLayout &dataLayout,
+    MemorySlotPromotionInfo info, const Mem2RegStatistics &statistics)
+    : slot(slot), allocator(allocator), builder(builder), dominance(dominance),
+      dataLayout(dataLayout), info(std::move(info)), statistics(statistics) {
 #ifndef NDEBUG
   auto isResultOrNewBlockArgument = [&]() {
     if (BlockArgument arg = dyn_cast<BlockArgument>(slot.ptr))
@@ -236,9 +234,9 @@ Value MemorySlotPromoter::getOrCreateDefaultValue() {
   if (defaultValue)
     return defaultValue;
 
-  RewriterBase::InsertionGuard guard(rewriter);
-  rewriter.setInsertionPointToStart(slot.ptr.getParentBlock());
-  return defaultValue = allocator.getDefaultValue(slot, rewriter);
+  OpBuilder::InsertionGuard guard(builder);
+  builder.setInsertionPointToStart(slot.ptr.getParentBlock());
+  return defaultValue = allocator.getDefaultValue(slot, builder);
 }
 
 LogicalResult MemorySlotPromotionAnalyzer::computeBlockingUses(
@@ -437,8 +435,8 @@ Value MemorySlotPromoter::computeReachingDefInBlock(Block *block,
         reachingDefs.insert({memOp, reachingDef});
 
       if (memOp.storesTo(slot)) {
-        rewriter.setInsertionPointAfter(memOp);
-        Value stored = memOp.getStored(slot, rewriter, reachingDef, dataLayout);
+        builder.setInsertionPointAfter(memOp);
+        Value stored = memOp.getStored(slot, builder, reachingDef, dataLayout);
         assert(stored && "a memory operation storing to a slot must provide a "
                          "new definition of the slot");
         reachingDef = stored;
@@ -475,33 +473,10 @@ void MemorySlotPromoter::computeReachingDefInRegion(Region *region,
     Block *block = job.block->getBlock();
 
     if (info.mergePoints.contains(block)) {
-      // If the block is a merge point, we need to add a block argument to hold
-      // the selected reaching definition. This has to be a bit complicated
-      // because of RewriterBase limitations: we need to create a new block with
-      // the extra block argument, move the content of the block to the new
-      // block, and replace the block with the new block in the merge point set.
-      SmallVector<Type> argTypes;
-      SmallVector<Location> argLocs;
-      for (BlockArgument arg : block->getArguments()) {
-        argTypes.push_back(arg.getType());
-        argLocs.push_back(arg.getLoc());
-      }
-      argTypes.push_back(slot.elemType);
-      argLocs.push_back(slot.ptr.getLoc());
-      Block *newBlock = rewriter.createBlock(block, argTypes, argLocs);
-
-      info.mergePoints.erase(block);
-      info.mergePoints.insert(newBlock);
-
-      rewriter.replaceAllUsesWith(block, newBlock);
-      rewriter.mergeBlocks(block, newBlock,
-                           newBlock->getArguments().drop_back());
-
-      block = newBlock;
-
-      BlockArgument blockArgument = block->getArguments().back();
-      rewriter.setInsertionPointToStart(block);
-      allocator.handleBlockArgument(slot, blockArgument, rewriter);
+      BlockArgument blockArgument =
+          block->addArgument(slot.elemType, slot.ptr.getLoc());
+      builder.setInsertionPointToStart(block);
+      allocator.handleBlockArgument(slot, blockArgument, builder);
       job.reachingDef = blockArgument;
 
       if (statistics.newBlockArgumentAmount)
@@ -514,10 +489,8 @@ void MemorySlotPromoter::computeReachingDefInRegion(Region *region,
     if (auto terminator = dyn_cast<BranchOpInterface>(block->getTerminator())) {
       for (BlockOperand &blockOperand : terminator->getBlockOperands()) {
         if (info.mergePoints.contains(blockOperand.get())) {
-          rewriter.modifyOpInPlace(terminator, [&]() {
-            terminator.getSuccessorOperands(blockOperand.getOperandNumber())
-                .append(job.reachingDef);
-          });
+          terminator.getSuccessorOperands(blockOperand.getOperandNumber())
+              .append(job.reachingDef);
         }
       }
     }
@@ -569,9 +542,9 @@ void MemorySlotPromoter::removeBlockingUses() {
       if (!reachingDef)
         reachingDef = getOrCreateDefaultValue();
 
-      rewriter.setInsertionPointAfter(toPromote);
+      builder.setInsertionPointAfter(toPromote);
       if (toPromoteMemOp.removeBlockingUses(
-              slot, info.userToBlockingUses[toPromote], rewriter, reachingDef,
+              slot, info.userToBlockingUses[toPromote], builder, reachingDef,
               dataLayout) == DeletionKind::Delete)
         toErase.push_back(toPromote);
       if (toPromoteMemOp.storesTo(slot))
@@ -581,20 +554,20 @@ void MemorySlotPromoter::removeBlockingUses() {
     }
 
     auto toPromoteBasic = cast<PromotableOpInterface>(toPromote);
-    rewriter.setInsertionPointAfter(toPromote);
+    builder.setInsertionPointAfter(toPromote);
     if (toPromoteBasic.removeBlockingUses(info.userToBlockingUses[toPromote],
-                                          rewriter) == DeletionKind::Delete)
+                                          builder) == DeletionKind::Delete)
       toErase.push_back(toPromote);
     if (toPromoteBasic.requiresReplacedValues())
       toVisit.push_back(toPromoteBasic);
   }
   for (PromotableOpInterface op : toVisit) {
-    rewriter.setInsertionPointAfter(op);
-    op.visitReplacedValues(replacedValuesList, rewriter);
+    builder.setInsertionPointAfter(op);
+    op.visitReplacedValues(replacedValuesList, builder);
   }
 
   for (Operation *toEraseOp : toErase)
-    rewriter.eraseOp(toEraseOp);
+    toEraseOp->erase();
 
   assert(slot.ptr.use_empty() &&
          "after promotion, the slot pointer should not be used anymore");
@@ -617,8 +590,7 @@ void MemorySlotPromoter::promoteSlot() {
       assert(succOperands.size() == mergePoint->getNumArguments() ||
              succOperands.size() + 1 == mergePoint->getNumArguments());
       if (succOperands.size() + 1 == mergePoint->getNumArguments())
-        rewriter.modifyOpInPlace(
-            user, [&]() { succOperands.append(getOrCreateDefaultValue()); });
+        succOperands.append(getOrCreateDefaultValue());
     }
   }
 
@@ -628,13 +600,12 @@ void MemorySlotPromoter::promoteSlot() {
   if (statistics.promotedAmount)
     (*statistics.promotedAmount)++;
 
-  allocator.handlePromotionComplete(slot, defaultValue, rewriter);
+  allocator.handlePromotionComplete(slot, defaultValue, builder);
 }
 
 LogicalResult mlir::tryToPromoteMemorySlots(
-    ArrayRef<PromotableAllocationOpInterface> allocators,
-    RewriterBase &rewriter, const DataLayout &dataLayout,
-    Mem2RegStatistics statistics) {
+    ArrayRef<PromotableAllocationOpInterface> allocators, OpBuilder &builder,
+    const DataLayout &dataLayout, Mem2RegStatistics statistics) {
   bool promotedAny = false;
 
   for (PromotableAllocationOpInterface allocator : allocators) {
@@ -646,7 +617,7 @@ LogicalResult mlir::tryToPromoteMemorySlots(
       MemorySlotPromotionAnalyzer analyzer(slot, dominance, dataLayout);
       std::optional<MemorySlotPromotionInfo> info = analyzer.computeInfo();
       if (info) {
-        MemorySlotPromoter(slot, allocator, rewriter, dominance, dataLayout,
+        MemorySlotPromoter(slot, allocator, builder, dominance, dataLayout,
                            std::move(*info), statistics)
             .promoteSlot();
         promotedAny = true;
@@ -674,7 +645,6 @@ struct Mem2Reg : impl::Mem2RegBase<Mem2Reg> {
         continue;
 
       OpBuilder builder(&region.front(), region.front().begin());
-      IRRewriter rewriter(builder);
 
       // Promoting a slot can allow for further promotion of other slots,
       // promotion is tried until no promotion succeeds.
@@ -689,7 +659,7 @@ struct Mem2Reg : impl::Mem2RegBase<Mem2Reg> {
         const DataLayout &dataLayout = dataLayoutAnalysis.getAtOrAbove(scopeOp);
 
         // Attempt promoting until no promotion succeeds.
-        if (failed(tryToPromoteMemorySlots(allocators, rewriter, dataLayout,
+        if (failed(tryToPromoteMemorySlots(allocators, builder, dataLayout,
                                            statistics)))
           break;
 

diff  --git a/mlir/lib/Transforms/SROA.cpp b/mlir/lib/Transforms/SROA.cpp
index f24cbb7b1725c..4e28fa687ffd4 100644
--- a/mlir/lib/Transforms/SROA.cpp
+++ b/mlir/lib/Transforms/SROA.cpp
@@ -134,15 +134,14 @@ computeDestructuringInfo(DestructurableMemorySlot &slot,
 /// subslots as specified by its allocator.
 static void destructureSlot(DestructurableMemorySlot &slot,
                             DestructurableAllocationOpInterface allocator,
-                            RewriterBase &rewriter,
-                            const DataLayout &dataLayout,
+                            OpBuilder &builder, const DataLayout &dataLayout,
                             MemorySlotDestructuringInfo &info,
                             const SROAStatistics &statistics) {
-  RewriterBase::InsertionGuard guard(rewriter);
+  OpBuilder::InsertionGuard guard(builder);
 
-  rewriter.setInsertionPointToStart(slot.ptr.getParentBlock());
+  builder.setInsertionPointToStart(slot.ptr.getParentBlock());
   DenseMap<Attribute, MemorySlot> subslots =
-      allocator.destructure(slot, info.usedIndices, rewriter);
+      allocator.destructure(slot, info.usedIndices, builder);
 
   if (statistics.slotsWithMemoryBenefit &&
       slot.elementPtrs.size() != info.usedIndices.size())
@@ -160,9 +159,9 @@ static void destructureSlot(DestructurableMemorySlot &slot,
 
   llvm::SmallVector<Operation *> toErase;
   for (Operation *toRewire : llvm::reverse(usersToRewire)) {
-    rewriter.setInsertionPointAfter(toRewire);
+    builder.setInsertionPointAfter(toRewire);
     if (auto accessor = dyn_cast<DestructurableAccessorOpInterface>(toRewire)) {
-      if (accessor.rewire(slot, subslots, rewriter, dataLayout) ==
+      if (accessor.rewire(slot, subslots, builder, dataLayout) ==
           DeletionKind::Delete)
         toErase.push_back(accessor);
       continue;
@@ -170,12 +169,12 @@ static void destructureSlot(DestructurableMemorySlot &slot,
 
     auto promotable = cast<PromotableOpInterface>(toRewire);
     if (promotable.removeBlockingUses(info.userToBlockingUses[promotable],
-                                      rewriter) == DeletionKind::Delete)
+                                      builder) == DeletionKind::Delete)
       toErase.push_back(promotable);
   }
 
   for (Operation *toEraseOp : toErase)
-    rewriter.eraseOp(toEraseOp);
+    toEraseOp->erase();
 
   assert(slot.ptr.use_empty() && "after destructuring, the original slot "
                                  "pointer should no longer be used");
@@ -186,12 +185,12 @@ static void destructureSlot(DestructurableMemorySlot &slot,
   if (statistics.destructuredAmount)
     (*statistics.destructuredAmount)++;
 
-  allocator.handleDestructuringComplete(slot, rewriter);
+  allocator.handleDestructuringComplete(slot, builder);
 }
 
 LogicalResult mlir::tryToDestructureMemorySlots(
     ArrayRef<DestructurableAllocationOpInterface> allocators,
-    RewriterBase &rewriter, const DataLayout &dataLayout,
+    OpBuilder &builder, const DataLayout &dataLayout,
     SROAStatistics statistics) {
   bool destructuredAny = false;
 
@@ -202,7 +201,7 @@ LogicalResult mlir::tryToDestructureMemorySlots(
       if (!info)
         continue;
 
-      destructureSlot(slot, allocator, rewriter, dataLayout, *info, statistics);
+      destructureSlot(slot, allocator, builder, dataLayout, *info, statistics);
       destructuredAny = true;
     }
   }
@@ -230,7 +229,6 @@ struct SROA : public impl::SROABase<SROA> {
         continue;
 
       OpBuilder builder(&region.front(), region.front().begin());
-      IRRewriter rewriter(builder);
 
       // Destructuring a slot can allow for further destructuring of other
       // slots, destructuring is tried until no destructuring succeeds.
@@ -243,7 +241,7 @@ struct SROA : public impl::SROABase<SROA> {
           allocators.emplace_back(allocator);
         });
 
-        if (failed(tryToDestructureMemorySlots(allocators, rewriter, dataLayout,
+        if (failed(tryToDestructureMemorySlots(allocators, builder, dataLayout,
                                                statistics)))
           break;
 


        


More information about the Mlir-commits mailing list