[Mlir-commits] [mlir] eeafc9d - [MLIR][Mem2Reg] Fix multi slot handling & move retry handling (#91464)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun May 12 22:37:45 PDT 2024


Author: Christian Ulmann
Date: 2024-05-13T07:37:41+02:00
New Revision: eeafc9daa15d2d022bcdd456d4b8bafd23f5f121

URL: https://github.com/llvm/llvm-project/commit/eeafc9daa15d2d022bcdd456d4b8bafd23f5f121
DIFF: https://github.com/llvm/llvm-project/commit/eeafc9daa15d2d022bcdd456d4b8bafd23f5f121.diff

LOG: [MLIR][Mem2Reg] Fix multi slot handling & move retry handling (#91464)

This commit fixes Mem2Reg's multi-slot allocator handling and extends the
test dialect to test this.

Additionally, this modifies Mem2Reg's API to always attempt a full
promotion on all the passed in "allocators". This ensures that the pass
does not require unnecessary walks over the regions and improves caching
benefits.

Added: 
    mlir/test/Transforms/mem2reg.mlir

Modified: 
    mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
    mlir/include/mlir/Transforms/Mem2Reg.h
    mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
    mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
    mlir/lib/Transforms/Mem2Reg.cpp
    mlir/test/lib/Dialect/Test/TestOpDefs.cpp
    mlir/test/lib/Dialect/Test/TestOps.h
    mlir/test/lib/Dialect/Test/TestOps.td

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
index adf182ac7069d..e2409cbec5fde 100644
--- a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
+++ b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
@@ -68,8 +68,12 @@ def PromotableAllocationOpInterface
         Hook triggered once the promotion of a slot is complete. This can
         also clean up the created default value if necessary.
         This will only be called for slots declared by this operation.
+
+        Must return a new promotable allocation op if this operation produced
+        multiple promotable slots, nullopt otherwise.
       }],
-      "void", "handlePromotionComplete",
+      "::std::optional<::mlir::PromotableAllocationOpInterface>",
+        "handlePromotionComplete",
       (ins
         "const ::mlir::MemorySlot &":$slot, 
         "::mlir::Value":$defaultValue,

diff  --git a/mlir/include/mlir/Transforms/Mem2Reg.h b/mlir/include/mlir/Transforms/Mem2Reg.h
index fee7fb312750c..6986cad9ae120 100644
--- a/mlir/include/mlir/Transforms/Mem2Reg.h
+++ b/mlir/include/mlir/Transforms/Mem2Reg.h
@@ -9,7 +9,6 @@
 #ifndef MLIR_TRANSFORMS_MEM2REG_H
 #define MLIR_TRANSFORMS_MEM2REG_H
 
-#include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/MemorySlotInterfaces.h"
 #include "llvm/ADT/Statistic.h"
 
@@ -23,8 +22,9 @@ struct Mem2RegStatistics {
   llvm::Statistic *newBlockArgumentAmount = nullptr;
 };
 
-/// Attempts to promote the memory slots of the provided allocators. Succeeds if
-/// at least one memory slot was promoted.
+/// Attempts to promote the memory slots of the provided allocators. Iteratively
+/// retries the promotion of all slots as promoting one slot might enable
+/// subsequent promotions. Succeeds if at least one memory slot was promoted.
 LogicalResult
 tryToPromoteMemorySlots(ArrayRef<PromotableAllocationOpInterface> allocators,
                         OpBuilder &builder, const DataLayout &dataLayout,

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
index 70102e1c81920..4fdf847a559ce 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
@@ -50,12 +50,14 @@ void LLVM::AllocaOp::handleBlockArgument(const MemorySlot &slot,
                                        declareOp.getLocationExpr());
 }
 
-void LLVM::AllocaOp::handlePromotionComplete(const MemorySlot &slot,
-                                             Value defaultValue,
-                                             OpBuilder &builder) {
+std::optional<PromotableAllocationOpInterface>
+LLVM::AllocaOp::handlePromotionComplete(const MemorySlot &slot,
+                                        Value defaultValue,
+                                        OpBuilder &builder) {
   if (defaultValue && defaultValue.use_empty())
     defaultValue.getDefiningOp()->erase();
   this->erase();
+  return std::nullopt;
 }
 
 SmallVector<DestructurableMemorySlot> LLVM::AllocaOp::getDestructurableSlots() {

diff  --git a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
index dca07e84ea73c..e30598e6878f4 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
@@ -96,12 +96,14 @@ Value memref::AllocaOp::getDefaultValue(const MemorySlot &slot,
       });
 }
 
-void memref::AllocaOp::handlePromotionComplete(const MemorySlot &slot,
-                                               Value defaultValue,
-                                               OpBuilder &builder) {
+std::optional<PromotableAllocationOpInterface>
+memref::AllocaOp::handlePromotionComplete(const MemorySlot &slot,
+                                          Value defaultValue,
+                                          OpBuilder &builder) {
   if (defaultValue.use_empty())
     defaultValue.getDefiningOp()->erase();
   this->erase();
+  return std::nullopt;
 }
 
 void memref::AllocaOp::handleBlockArgument(const MemorySlot &slot,

diff  --git a/mlir/lib/Transforms/Mem2Reg.cpp b/mlir/lib/Transforms/Mem2Reg.cpp
index 8adbbcd01cb44..e096747741c0a 100644
--- a/mlir/lib/Transforms/Mem2Reg.cpp
+++ b/mlir/lib/Transforms/Mem2Reg.cpp
@@ -173,7 +173,9 @@ class MemorySlotPromoter {
   /// Actually promotes the slot by mutating IR. Promoting a slot DOES
   /// invalidate the MemorySlotPromotionInfo of other slots. Preparation of
   /// promotion info should NOT be performed in batches.
-  void promoteSlot();
+  /// Returns a promotable allocation op if a new allocator was created, nullopt
+  /// otherwise.
+  std::optional<PromotableAllocationOpInterface> promoteSlot();
 
 private:
   /// Computes the reaching definition for all the operations that require
@@ -595,7 +597,8 @@ void MemorySlotPromoter::removeBlockingUses() {
          "after promotion, the slot pointer should not be used anymore");
 }
 
-void MemorySlotPromoter::promoteSlot() {
+std::optional<PromotableAllocationOpInterface>
+MemorySlotPromoter::promoteSlot() {
   computeReachingDefInRegion(slot.ptr.getParentRegion(),
                              getOrCreateDefaultValue());
 
@@ -622,7 +625,7 @@ void MemorySlotPromoter::promoteSlot() {
   if (statistics.promotedAmount)
     (*statistics.promotedAmount)++;
 
-  allocator.handlePromotionComplete(slot, defaultValue, builder);
+  return allocator.handlePromotionComplete(slot, defaultValue, builder);
 }
 
 LogicalResult mlir::tryToPromoteMemorySlots(
@@ -636,20 +639,50 @@ LogicalResult mlir::tryToPromoteMemorySlots(
   // lazily and cached to avoid expensive recomputation.
   BlockIndexCache blockIndexCache;
 
-  for (PromotableAllocationOpInterface allocator : allocators) {
-    for (MemorySlot slot : allocator.getPromotableSlots()) {
-      if (slot.ptr.use_empty())
-        continue;
-
-      MemorySlotPromotionAnalyzer analyzer(slot, dominance, dataLayout);
-      std::optional<MemorySlotPromotionInfo> info = analyzer.computeInfo();
-      if (info) {
-        MemorySlotPromoter(slot, allocator, builder, dominance, dataLayout,
-                           std::move(*info), statistics, blockIndexCache)
-            .promoteSlot();
-        promotedAny = true;
+  SmallVector<PromotableAllocationOpInterface> workList(allocators.begin(),
+                                                        allocators.end());
+
+  SmallVector<PromotableAllocationOpInterface> newWorkList;
+  newWorkList.reserve(workList.size());
+  while (true) {
+    bool changesInThisRound = false;
+    for (PromotableAllocationOpInterface allocator : workList) {
+      bool changedAllocator = false;
+      for (MemorySlot slot : allocator.getPromotableSlots()) {
+        if (slot.ptr.use_empty())
+          continue;
+
+        MemorySlotPromotionAnalyzer analyzer(slot, dominance, dataLayout);
+        std::optional<MemorySlotPromotionInfo> info = analyzer.computeInfo();
+        if (info) {
+          std::optional<PromotableAllocationOpInterface> newAllocator =
+              MemorySlotPromoter(slot, allocator, builder, dominance,
+                                 dataLayout, std::move(*info), statistics,
+                                 blockIndexCache)
+                  .promoteSlot();
+          changedAllocator = true;
+          // Add newly created allocators to the worklist for further
+          // processing.
+          if (newAllocator)
+            newWorkList.push_back(*newAllocator);
+
+          // A break is required, since promoting a slot may invalidate the
+          // remaining slots of an allocator.
+          break;
+        }
       }
+      if (!changedAllocator)
+        newWorkList.push_back(allocator);
+      changesInThisRound |= changedAllocator;
     }
+    if (!changesInThisRound)
+      break;
+    promotedAny = true;
+
+    // Swap the vector's backing memory and clear the entries in newWorkList
+    // afterwards. This ensures that additional heap allocations can be avoided.
+    workList.swap(newWorkList);
+    newWorkList.clear();
   }
 
   return success(promotedAny);
@@ -677,22 +710,16 @@ struct Mem2Reg : impl::Mem2RegBase<Mem2Reg> {
 
       OpBuilder builder(&region.front(), region.front().begin());
 
-      // Promoting a slot can allow for further promotion of other slots,
-      // promotion is tried until no promotion succeeds.
-      while (true) {
-        SmallVector<PromotableAllocationOpInterface> allocators;
-        // Build a list of allocators to attempt to promote the slots of.
-        region.walk([&](PromotableAllocationOpInterface allocator) {
-          allocators.emplace_back(allocator);
-        });
-
-        // Attempt promoting until no promotion succeeds.
-        if (failed(tryToPromoteMemorySlots(allocators, builder, dataLayout,
-                                           dominance, statistics)))
-          break;
+      SmallVector<PromotableAllocationOpInterface> allocators;
+      // Build a list of allocators to attempt to promote the slots of.
+      region.walk([&](PromotableAllocationOpInterface allocator) {
+        allocators.emplace_back(allocator);
+      });
 
+      // Attempt promoting as many of the slots as possible.
+      if (succeeded(tryToPromoteMemorySlots(allocators, builder, dataLayout,
+                                            dominance, statistics)))
         changed = true;
-      }
     }
     if (!changed)
       markAllAnalysesPreserved();

diff  --git a/mlir/test/Transforms/mem2reg.mlir b/mlir/test/Transforms/mem2reg.mlir
new file mode 100644
index 0000000000000..daeaa2da07634
--- /dev/null
+++ b/mlir/test/Transforms/mem2reg.mlir
@@ -0,0 +1,28 @@
+// RUN: mlir-opt %s --pass-pipeline='builtin.module(func.func(mem2reg))' --split-input-file | FileCheck %s
+
+// Verifies that allocators with multiple slots are handled properly.
+
+// CHECK-LABEL: func.func @multi_slot_alloca
+func.func @multi_slot_alloca() -> (i32, i32) {
+  // CHECK-NOT: test.multi_slot_alloca
+  %1, %2 = test.multi_slot_alloca : () -> (memref<i32>, memref<i32>)
+  %3 = memref.load %1[] : memref<i32>
+  %4 = memref.load %2[] : memref<i32>
+  return %3, %4 : i32, i32
+}
+
+// -----
+
+// Verifies that a multi slot allocator can be partially promoted.
+
+func.func private @consumer(memref<i32>)
+
+// CHECK-LABEL: func.func @multi_slot_alloca_only_second
+func.func @multi_slot_alloca_only_second() -> (i32, i32) {
+  // CHECK: %{{[[:alnum:]]+}} = test.multi_slot_alloca
+  %1, %2 = test.multi_slot_alloca : () -> (memref<i32>, memref<i32>)
+  func.call @consumer(%1) : (memref<i32>) -> ()
+  %3 = memref.load %1[] : memref<i32>
+  %4 = memref.load %2[] : memref<i32>
+  return %3, %4 : i32, i32
+}

diff  --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
index 08df2e5e12286..d22d48b139a04 100644
--- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
+++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
@@ -11,6 +11,7 @@
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/Verifier.h"
 #include "mlir/Interfaces/FunctionImplementation.h"
+#include "mlir/Interfaces/MemorySlotInterfaces.h"
 
 using namespace mlir;
 using namespace test;
@@ -1172,3 +1173,61 @@ void TestOpWithVersionedProperties::writeToMlirBytecode(
   writer.writeVarInt(prop.value1);
   writer.writeVarInt(prop.value2);
 }
+
+//===----------------------------------------------------------------------===//
+// TestMultiSlotAlloca
+//===----------------------------------------------------------------------===//
+
+llvm::SmallVector<MemorySlot> TestMultiSlotAlloca::getPromotableSlots() {
+  SmallVector<MemorySlot> slots;
+  for (Value result : getResults()) {
+    slots.push_back(MemorySlot{
+        result, cast<MemRefType>(result.getType()).getElementType()});
+  }
+  return slots;
+}
+
+Value TestMultiSlotAlloca::getDefaultValue(const MemorySlot &slot,
+                                           OpBuilder &builder) {
+  return builder.create<TestOpConstant>(getLoc(), slot.elemType,
+                                        builder.getI32IntegerAttr(42));
+}
+
+void TestMultiSlotAlloca::handleBlockArgument(const MemorySlot &slot,
+                                              BlockArgument argument,
+                                              OpBuilder &builder) {
+  // Not relevant for testing.
+}
+
+std::optional<PromotableAllocationOpInterface>
+TestMultiSlotAlloca::handlePromotionComplete(const MemorySlot &slot,
+                                             Value defaultValue,
+                                             OpBuilder &builder) {
+  if (defaultValue && defaultValue.use_empty())
+    defaultValue.getDefiningOp()->erase();
+
+  if (getNumResults() == 1) {
+    erase();
+    return std::nullopt;
+  }
+
+  SmallVector<Type> newTypes;
+  SmallVector<Value> remainingValues;
+
+  for (Value oldResult : getResults()) {
+    if (oldResult == slot.ptr)
+      continue;
+    remainingValues.push_back(oldResult);
+    newTypes.push_back(oldResult.getType());
+  }
+
+  OpBuilder::InsertionGuard guard(builder);
+  builder.setInsertionPoint(*this);
+  auto replacement = builder.create<TestMultiSlotAlloca>(getLoc(), newTypes);
+  for (auto [oldResult, newResult] :
+       llvm::zip_equal(remainingValues, replacement.getResults()))
+    oldResult.replaceAllUsesWith(newResult);
+
+  erase();
+  return replacement;
+}

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.h b/mlir/test/lib/Dialect/Test/TestOps.h
index f9925855bb9db..837ccca565926 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.h
+++ b/mlir/test/lib/Dialect/Test/TestOps.h
@@ -36,6 +36,7 @@
 #include "mlir/Interfaces/InferIntRangeInterface.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/LoopLikeInterface.h"
+#include "mlir/Interfaces/MemorySlotInterfaces.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
 #include "mlir/Interfaces/ViewLikeInterface.h"

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 5352d574ac394..e16ea2407314e 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -28,6 +28,7 @@ include "mlir/Interfaces/DestinationStyleOpInterface.td"
 include "mlir/Interfaces/InferIntRangeInterface.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/LoopLikeInterface.td"
+include "mlir/Interfaces/MemorySlotInterfaces.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 
 
@@ -3167,4 +3168,14 @@ def TestOpOptionallyImplementingInterface
   let arguments = (ins BoolAttr:$implementsInterface);
 }
 
+//===----------------------------------------------------------------------===//
+// Test Mem2Reg
+//===----------------------------------------------------------------------===//
+
+def TestMultiSlotAlloca : TEST_Op<"multi_slot_alloca",
+    [DeclareOpInterfaceMethods<PromotableAllocationOpInterface>]> {
+  let results = (outs Variadic<MemRefOf<[I32]>>:$results);
+  let assemblyFormat = "attr-dict `:` functional-type(operands, results)";
+}
+
 #endif // TEST_OPS


        


More information about the Mlir-commits mailing list