[Mlir-commits] [mlir] [MLIR][Interfaces] Change MemorySlotInterface to use OpBuilder (PR #91341)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue May 7 07:34:47 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-llvm

Author: Christian Ulmann (Dinistro)

<details>
<summary>Changes</summary>

This commit changes the `MemorySlotInterface` back to using `OpBuilder` instead of a rewriter. The rewriter-based interface was originally introduced in https://reviews.llvm.org/D150432, but it has since been shown that patterns are a bad idea for both Mem2Reg and SROA.
Mem2Reg in particular suffers from using a rewriter because it is forced to create new basic blocks. This is an issue, as creating blocks invalidates the dominance information, which can be expensive to recompute.

---

Patch is 62.88 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/91341.diff


7 Files Affected:

- (modified) mlir/include/mlir/Interfaces/MemorySlotInterfaces.td (+27-36) 
- (modified) mlir/include/mlir/Transforms/Mem2Reg.h (+1-1) 
- (modified) mlir/include/mlir/Transforms/SROA.h (+1-1) 
- (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp (+121-127) 
- (modified) mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp (+24-28) 
- (modified) mlir/lib/Transforms/Mem2Reg.cpp (+30-60) 
- (modified) mlir/lib/Transforms/SROA.cpp (+12-14) 


``````````diff
diff --git a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
index 764fa6d547b2e..adf182ac7069d 100644
--- a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
+++ b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
@@ -40,42 +40,40 @@ def PromotableAllocationOpInterface
         Provides the default Value of this memory slot. The provided Value
         will be used as the reaching definition of loads done before any store.
         This Value must outlive the promotion and dominate all the uses of this
-        slot's pointer. The provided rewriter can be used to create the default
+        slot's pointer. The provided builder can be used to create the default
         value on the fly.
 
-        The rewriter is located at the beginning of the block where the slot
-        pointer is defined. All IR mutations must happen through the rewriter.
+        The builder is located at the beginning of the block where the slot
+        pointer is defined.
       }], "::mlir::Value", "getDefaultValue",
       (ins
         "const ::mlir::MemorySlot &":$slot,
-        "::mlir::RewriterBase &":$rewriter)
+        "::mlir::OpBuilder &":$builder)
     >,
     InterfaceMethod<[{
         Hook triggered for every new block argument added to a block.
         This will only be called for slots declared by this operation.
 
-        The rewriter is located at the beginning of the block on call. All IR
-        mutations must happen through the rewriter.
+        The builder is located at the beginning of the block on call. All IR
+        mutations must happen through the builder.
       }],
       "void", "handleBlockArgument",
       (ins
         "const ::mlir::MemorySlot &":$slot,
         "::mlir::BlockArgument":$argument,
-        "::mlir::RewriterBase &":$rewriter
+        "::mlir::OpBuilder &":$builder
       )
     >,
     InterfaceMethod<[{
         Hook triggered once the promotion of a slot is complete. This can
         also clean up the created default value if necessary.
         This will only be called for slots declared by this operation.
-
-        All IR mutations must happen through the rewriter.
       }],
       "void", "handlePromotionComplete",
       (ins
         "const ::mlir::MemorySlot &":$slot, 
         "::mlir::Value":$defaultValue,
-        "::mlir::RewriterBase &":$rewriter)
+        "::mlir::OpBuilder &":$builder)
     >,
   ];
 }
@@ -119,15 +117,14 @@ def PromotableMemOpInterface : OpInterface<"PromotableMemOpInterface"> {
         The returned value must dominate all operations dominated by the storing
         operation.
 
-        If IR must be mutated to extract a concrete value being stored, mutation
-        must happen through the provided rewriter. The rewriter is located
-        immediately after the memory operation on call. No IR deletion is
-        allowed in this method. IR mutations must not introduce new uses of the
-        memory slot. Existing control flow must not be modified.
+        The builder is located immediately after the memory operation on call.
+        No IR deletion is allowed in this method. IR mutations must not
+        introduce new uses of the memory slot. Existing control flow must not
+        be modified.
       }],
       "::mlir::Value", "getStored",
       (ins "const ::mlir::MemorySlot &":$slot,
-           "::mlir::RewriterBase &":$rewriter,
+           "::mlir::OpBuilder &":$builder,
            "::mlir::Value":$reachingDef,
            "const ::mlir::DataLayout &":$dataLayout)
     >,
@@ -166,14 +163,13 @@ def PromotableMemOpInterface : OpInterface<"PromotableMemOpInterface"> {
         have been done at the point of calling this method, but it will be done
         eventually.
 
-        The rewriter is located after the promotable operation on call. All IR
-        mutations must happen through the rewriter.
+        The builder is located after the promotable operation on call.
       }],
       "::mlir::DeletionKind",
       "removeBlockingUses",
       (ins "const ::mlir::MemorySlot &":$slot,
            "const ::llvm::SmallPtrSetImpl<mlir::OpOperand *> &":$blockingUses,
-           "::mlir::RewriterBase &":$rewriter,
+           "::mlir::OpBuilder &":$builder,
            "::mlir::Value":$reachingDefinition,
            "const ::mlir::DataLayout &":$dataLayout)
     >,
@@ -224,13 +220,12 @@ def PromotableOpInterface : OpInterface<"PromotableOpInterface"> {
         have been done at the point of calling this method, but it will be done
         eventually.
 
-        The rewriter is located after the promotable operation on call. All IR
-        mutations must happen through the rewriter.
+        The builder is located after the promotable operation on call.
       }],
       "::mlir::DeletionKind",
       "removeBlockingUses",
       (ins "const ::llvm::SmallPtrSetImpl<mlir::OpOperand *> &":$blockingUses,
-           "::mlir::RewriterBase &":$rewriter)
+           "::mlir::OpBuilder &":$builder)
     >,
     InterfaceMethod<[{
         This method allows the promoted operation to visit the SSA values used
@@ -254,13 +249,12 @@ def PromotableOpInterface : OpInterface<"PromotableOpInterface"> {
         scheduled for removal and if `requiresReplacedValues` returned
         true.
 
-        The rewriter is located after the promotable operation on call. All IR
-        mutations must happen through the rewriter. During the transformation,
-        *no operation should be deleted*.
+        The builder is located after the promotable operation on call. During
+        the transformation, *no operation should be deleted*.
       }],
       "void", "visitReplacedValues",
       (ins "::llvm::ArrayRef<std::pair<::mlir::Operation*, ::mlir::Value>>":$mutatedDefs,
-           "::mlir::RewriterBase &":$rewriter), [{}], [{ return; }]
+           "::mlir::OpBuilder &":$builder), [{}], [{ return; }]
     >,
   ];
 }
@@ -293,25 +287,23 @@ def DestructurableAllocationOpInterface
         at the end of this call. Only generates subslots for the indices found in
         `usedIndices` since all other subslots are unused.
 
-        The rewriter is located at the beginning of the block where the slot
-        pointer is defined. All IR mutations must happen through the rewriter.
+        The builder is located at the beginning of the block where the slot
+        pointer is defined.
       }],
       "::llvm::DenseMap<::mlir::Attribute, ::mlir::MemorySlot>",
       "destructure",
       (ins "const ::mlir::DestructurableMemorySlot &":$slot,
            "const ::llvm::SmallPtrSetImpl<::mlir::Attribute> &":$usedIndices,
-           "::mlir::RewriterBase &":$rewriter)
+           "::mlir::OpBuilder &":$builder)
     >,
     InterfaceMethod<[{
         Hook triggered once the destructuring of a slot is complete, meaning the
         original slot is no longer being refered to and could be deleted.
         This will only be called for slots declared by this operation.
-
-        All IR mutations must happen through the rewriter.
       }],
       "void", "handleDestructuringComplete",
       (ins "const ::mlir::DestructurableMemorySlot &":$slot,
-           "::mlir::RewriterBase &":$rewriter)
+           "::mlir::OpBuilder &":$builder)
     >,
   ];
 }
@@ -376,15 +368,14 @@ def DestructurableAccessorOpInterface
         Rewires the use of a slot to the generated subslots, without deleting
         any operation. Returns whether the accessor should be deleted.
 
-        All IR mutations must happen through the rewriter. Deletion of
-        operations is not allowed, only the accessor can be scheduled for
-        deletion by returning the appropriate value.
+        Deletion of operations is not allowed, only the accessor can be
+        scheduled for deletion by returning the appropriate value.
       }],
       "::mlir::DeletionKind",
       "rewire",
       (ins "const ::mlir::DestructurableMemorySlot &":$slot,
            "::llvm::DenseMap<::mlir::Attribute, ::mlir::MemorySlot> &":$subslots,
-           "::mlir::RewriterBase &":$rewriter,
+           "::mlir::OpBuilder &":$builder,
            "const ::mlir::DataLayout &":$dataLayout)
     >
   ];
diff --git a/mlir/include/mlir/Transforms/Mem2Reg.h b/mlir/include/mlir/Transforms/Mem2Reg.h
index ed10644e26a51..b4f939d654142 100644
--- a/mlir/include/mlir/Transforms/Mem2Reg.h
+++ b/mlir/include/mlir/Transforms/Mem2Reg.h
@@ -27,7 +27,7 @@ struct Mem2RegStatistics {
 /// at least one memory slot was promoted.
 LogicalResult
 tryToPromoteMemorySlots(ArrayRef<PromotableAllocationOpInterface> allocators,
-                        RewriterBase &rewriter, const DataLayout &dataLayout,
+                        OpBuilder &builder, const DataLayout &dataLayout,
                         Mem2RegStatistics statistics = {});
 
 } // namespace mlir
diff --git a/mlir/include/mlir/Transforms/SROA.h b/mlir/include/mlir/Transforms/SROA.h
index d09a7989edeab..fa84fb1eae73a 100644
--- a/mlir/include/mlir/Transforms/SROA.h
+++ b/mlir/include/mlir/Transforms/SROA.h
@@ -31,7 +31,7 @@ struct SROAStatistics {
 /// failure if no slot was destructured.
 LogicalResult tryToDestructureMemorySlots(
     ArrayRef<DestructurableAllocationOpInterface> allocators,
-    RewriterBase &rewriter, const DataLayout &dataLayout,
+    OpBuilder &builder, const DataLayout &dataLayout,
     SROAStatistics statistics = {});
 
 } // namespace mlir
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
index 6eeb13ebffb51..70102e1c81920 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
@@ -36,26 +36,26 @@ llvm::SmallVector<MemorySlot> LLVM::AllocaOp::getPromotableSlots() {
 }
 
 Value LLVM::AllocaOp::getDefaultValue(const MemorySlot &slot,
-                                      RewriterBase &rewriter) {
-  return rewriter.create<LLVM::UndefOp>(getLoc(), slot.elemType);
+                                      OpBuilder &builder) {
+  return builder.create<LLVM::UndefOp>(getLoc(), slot.elemType);
 }
 
 void LLVM::AllocaOp::handleBlockArgument(const MemorySlot &slot,
                                          BlockArgument argument,
-                                         RewriterBase &rewriter) {
+                                         OpBuilder &builder) {
   for (Operation *user : getOperation()->getUsers())
     if (auto declareOp = llvm::dyn_cast<LLVM::DbgDeclareOp>(user))
-      rewriter.create<LLVM::DbgValueOp>(declareOp.getLoc(), argument,
-                                        declareOp.getVarInfo(),
-                                        declareOp.getLocationExpr());
+      builder.create<LLVM::DbgValueOp>(declareOp.getLoc(), argument,
+                                       declareOp.getVarInfo(),
+                                       declareOp.getLocationExpr());
 }
 
 void LLVM::AllocaOp::handlePromotionComplete(const MemorySlot &slot,
                                              Value defaultValue,
-                                             RewriterBase &rewriter) {
+                                             OpBuilder &builder) {
   if (defaultValue && defaultValue.use_empty())
-    rewriter.eraseOp(defaultValue.getDefiningOp());
-  rewriter.eraseOp(*this);
+    defaultValue.getDefiningOp()->erase();
+  this->erase();
 }
 
 SmallVector<DestructurableMemorySlot> LLVM::AllocaOp::getDestructurableSlots() {
@@ -78,16 +78,16 @@ SmallVector<DestructurableMemorySlot> LLVM::AllocaOp::getDestructurableSlots() {
 DenseMap<Attribute, MemorySlot>
 LLVM::AllocaOp::destructure(const DestructurableMemorySlot &slot,
                             const SmallPtrSetImpl<Attribute> &usedIndices,
-                            RewriterBase &rewriter) {
+                            OpBuilder &builder) {
   assert(slot.ptr == getResult());
-  rewriter.setInsertionPointAfter(*this);
+  builder.setInsertionPointAfter(*this);
 
   auto destructurableType = cast<DestructurableTypeInterface>(getElemType());
   DenseMap<Attribute, MemorySlot> slotMap;
   for (Attribute index : usedIndices) {
     Type elemType = destructurableType.getTypeAtIndex(index);
     assert(elemType && "used index must exist");
-    auto subAlloca = rewriter.create<LLVM::AllocaOp>(
+    auto subAlloca = builder.create<LLVM::AllocaOp>(
         getLoc(), LLVM::LLVMPointerType::get(getContext()), elemType,
         getArraySize());
     slotMap.try_emplace<MemorySlot>(index, {subAlloca.getResult(), elemType});
@@ -97,9 +97,9 @@ LLVM::AllocaOp::destructure(const DestructurableMemorySlot &slot,
 }
 
 void LLVM::AllocaOp::handleDestructuringComplete(
-    const DestructurableMemorySlot &slot, RewriterBase &rewriter) {
+    const DestructurableMemorySlot &slot, OpBuilder &builder) {
   assert(slot.ptr == getResult());
-  rewriter.eraseOp(*this);
+  this->erase();
 }
 
 //===----------------------------------------------------------------------===//
@@ -112,7 +112,7 @@ bool LLVM::LoadOp::loadsFrom(const MemorySlot &slot) {
 
 bool LLVM::LoadOp::storesTo(const MemorySlot &slot) { return false; }
 
-Value LLVM::LoadOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
+Value LLVM::LoadOp::getStored(const MemorySlot &slot, OpBuilder &builder,
                               Value reachingDef, const DataLayout &dataLayout) {
   llvm_unreachable("getStored should not be called on LoadOp");
 }
@@ -175,7 +175,7 @@ static bool isBigEndian(const DataLayout &dataLayout) {
 
 /// Converts a value to an integer type of the same size.
 /// Assumes that the type can be converted.
-static Value castToSameSizedInt(RewriterBase &rewriter, Location loc, Value val,
+static Value castToSameSizedInt(OpBuilder &builder, Location loc, Value val,
                                 const DataLayout &dataLayout) {
   Type type = val.getType();
   assert(isSupportedTypeForConversion(type) &&
@@ -185,15 +185,15 @@ static Value castToSameSizedInt(RewriterBase &rewriter, Location loc, Value val,
     return val;
 
   uint64_t typeBitSize = dataLayout.getTypeSizeInBits(type);
-  IntegerType valueSizeInteger = rewriter.getIntegerType(typeBitSize);
+  IntegerType valueSizeInteger = builder.getIntegerType(typeBitSize);
 
   if (isa<LLVM::LLVMPointerType>(type))
-    return rewriter.createOrFold<LLVM::PtrToIntOp>(loc, valueSizeInteger, val);
-  return rewriter.createOrFold<LLVM::BitcastOp>(loc, valueSizeInteger, val);
+    return builder.createOrFold<LLVM::PtrToIntOp>(loc, valueSizeInteger, val);
+  return builder.createOrFold<LLVM::BitcastOp>(loc, valueSizeInteger, val);
 }
 
 /// Converts a value with an integer type to `targetType`.
-static Value castIntValueToSameSizedType(RewriterBase &rewriter, Location loc,
+static Value castIntValueToSameSizedType(OpBuilder &builder, Location loc,
                                          Value val, Type targetType) {
   assert(isa<IntegerType>(val.getType()) &&
          "expected value to have an integer type");
@@ -202,13 +202,13 @@ static Value castIntValueToSameSizedType(RewriterBase &rewriter, Location loc,
   if (val.getType() == targetType)
     return val;
   if (isa<LLVM::LLVMPointerType>(targetType))
-    return rewriter.createOrFold<LLVM::IntToPtrOp>(loc, targetType, val);
-  return rewriter.createOrFold<LLVM::BitcastOp>(loc, targetType, val);
+    return builder.createOrFold<LLVM::IntToPtrOp>(loc, targetType, val);
+  return builder.createOrFold<LLVM::BitcastOp>(loc, targetType, val);
 }
 
 /// Constructs operations that convert `srcValue` into a new value of type
 /// `targetType`. Assumes the types have the same bitsize.
-static Value castSameSizedTypes(RewriterBase &rewriter, Location loc,
+static Value castSameSizedTypes(OpBuilder &builder, Location loc,
                                 Value srcValue, Type targetType,
                                 const DataLayout &dataLayout) {
   Type srcType = srcValue.getType();
@@ -226,18 +226,18 @@ static Value castSameSizedTypes(RewriterBase &rewriter, Location loc,
   // provenance.
   if (isa<LLVM::LLVMPointerType>(targetType) &&
       isa<LLVM::LLVMPointerType>(srcType))
-    return rewriter.createOrFold<LLVM::AddrSpaceCastOp>(loc, targetType,
-                                                        srcValue);
+    return builder.createOrFold<LLVM::AddrSpaceCastOp>(loc, targetType,
+                                                       srcValue);
 
   // For all other castable types, casting through integers is necessary.
-  Value replacement = castToSameSizedInt(rewriter, loc, srcValue, dataLayout);
-  return castIntValueToSameSizedType(rewriter, loc, replacement, targetType);
+  Value replacement = castToSameSizedInt(builder, loc, srcValue, dataLayout);
+  return castIntValueToSameSizedType(builder, loc, replacement, targetType);
 }
 
 /// Constructs operations that convert `srcValue` into a new value of type
 /// `targetType`. Performs bit-level extraction if the source type is larger
 /// than the target type. Assumes that this conversion is possible.
-static Value createExtractAndCast(RewriterBase &rewriter, Location loc,
+static Value createExtractAndCast(OpBuilder &builder, Location loc,
                                   Value srcValue, Type targetType,
                                   const DataLayout &dataLayout) {
   // Get the types of the source and target values.
@@ -249,31 +249,31 @@ static Value createExtractAndCast(RewriterBase &rewriter, Location loc,
   uint64_t srcTypeSize = dataLayout.getTypeSizeInBits(srcType);
   uint64_t targetTypeSize = dataLayout.getTypeSizeInBits(targetType);
   if (srcTypeSize == targetTypeSize)
-    return castSameSizedTypes(rewriter, loc, srcValue, targetType, dataLayout);
+    return castSameSizedTypes(builder, loc, srcValue, targetType, dataLayout);
 
   // First, cast the value to a same-sized integer type.
-  Value replacement = castToSameSizedInt(rewriter, loc, srcValue, dataLayout);
+  Value replacement = castToSameSizedInt(builder, loc, srcValue, dataLayout);
 
   // Truncate the integer if the size of the target is less than the value.
   if (isBigEndian(dataLayout)) {
     uint64_t shiftAmount = srcTypeSize - targetTypeSize;
-    auto shiftConstant = rewriter.create<LLVM::ConstantOp>(
-        loc, rewriter.getIntegerAttr(srcType, shiftAmount));
+    auto shiftConstant = builder.create<LLVM::ConstantOp>(
+        loc, builder.getIntegerAttr(srcType, shiftAmount));
     replacement =
-        rewriter.createOrFold<LLVM::LShrOp>(loc, srcValue, shiftConstant);
+        builder.createOrFold<LLVM::LShrOp>(loc, srcValue, shiftConstant);
   }
 
-  replacement = rewriter.create<LLVM::TruncOp>(
-      loc, rewriter.getIntegerType(targetTypeSize), replacement);
+  replacement = builder.create<LLVM::TruncOp>(
+      loc, builder.getIntegerType(targetTypeSize), replacement);
 
   // Now cast the integer to the actual target type if required.
-  return castIntValueToSameSizedType(rewriter, loc, replacement, targetType);
+  return castIntValueToSameSizedType(builder, loc, replacement, targetType);
 }
 
 /// Constructs operations that insert the bits of `srcValue` into the
 /// "beginning" of `reachingDef` (beginning is endianness dependent).
 /// Assumes that this conversion is possible.
-static Value createInsertAndCast(RewriterBase &rewriter, Location loc,
+static Value createInsertAndCast(OpBuilder &builder, Location loc,
                                  Value srcValue, Value reachingDef,
                                  const DataLayout &dataLayout) {
 
@@ -284,27 +284,27 @@ static Value createInsertAndCast(RewriterBase &rewriter, Location loc,
   uint64_t valueTypeSize = dataLayout.getTypeSizeInBits(srcValue.getType());
   uint64_t slotTypeSize = dataLayout.getTypeSizeInBits(reachingDef.getType());
   if (slotTypeSize == valueTypeSize)
-    return castSameSizedTypes(rewriter, loc, srcValue, reachingDef.getType(),
+    return castSameSizedTypes(builder, loc, srcValue, reachingDef.getType(),
                               dataLayout);
 
   // In the case where the store only overwrites parts of the memory,
   // bit fiddling is required to construct the new value.
 
   // First convert both values to integers of the same size.
-  Value defAsInt = castToSameSizedInt(rewriter, loc, reachingDef, dataLayout);
-  Value valueAsInt = castToSameSizedInt(rewriter, loc, srcValue, dataLayout);
+  Value defAsInt = castToSameSizedInt(builder, loc, reachingDef, dataLayout);
...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/91341


More information about the Mlir-commits mailing list