[Mlir-commits] [mlir] [MLIR][Mem2Reg] Change API to always retry promotion after changes (PR #91464)

Christian Ulmann llvmlistbot at llvm.org
Wed May 8 08:58:28 PDT 2024


https://github.com/Dinistro updated https://github.com/llvm/llvm-project/pull/91464

>From 0f0628b295931ee6e0e2cb210071a3ba09cb58c3 Mon Sep 17 00:00:00 2001
From: Christian Ulmann <christian.ulmann at nextsilicon.com>
Date: Wed, 8 May 2024 12:03:56 +0000
Subject: [PATCH 1/2] [MLIR][Mem2Reg] Change API to always retry promotion
 after changes

This commit modifies Mem2Reg's API to always attempt a full promotion
of all the passed-in "allocators". This ensures that the pass does not
require unnecessary walks over the regions and improves caching
benefits.
---
 mlir/include/mlir/Transforms/Mem2Reg.h |  6 +--
 mlir/lib/Transforms/Mem2Reg.cpp        | 62 +++++++++++++++-----------
 2 files changed, 39 insertions(+), 29 deletions(-)

diff --git a/mlir/include/mlir/Transforms/Mem2Reg.h b/mlir/include/mlir/Transforms/Mem2Reg.h
index fee7fb312750..6986cad9ae12 100644
--- a/mlir/include/mlir/Transforms/Mem2Reg.h
+++ b/mlir/include/mlir/Transforms/Mem2Reg.h
@@ -9,7 +9,6 @@
 #ifndef MLIR_TRANSFORMS_MEM2REG_H
 #define MLIR_TRANSFORMS_MEM2REG_H
 
-#include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/MemorySlotInterfaces.h"
 #include "llvm/ADT/Statistic.h"
 
@@ -23,8 +22,11 @@ struct Mem2RegStatistics {
   llvm::Statistic *newBlockArgumentAmount = nullptr;
 };
 
-/// Attempts to promote the memory slots of the provided allocators. Succeeds if
-/// at least one memory slot was promoted.
+/// Attempts to promote the memory slots of the provided allocators. Iteratively
+/// retries the promotion of all slots as promoting one slot might enable
+/// subsequent promotions. Succeeds if at least one memory slot was promoted.
+/// Promoting a slot may erase its allocator, so the operations referenced by
+/// `allocators` must not be used after this call.
 LogicalResult
 tryToPromoteMemorySlots(ArrayRef<PromotableAllocationOpInterface> allocators,
                         OpBuilder &builder, const DataLayout &dataLayout,
diff --git a/mlir/lib/Transforms/Mem2Reg.cpp b/mlir/lib/Transforms/Mem2Reg.cpp
index 8adbbcd01cb4..390d2a3f54b6 100644
--- a/mlir/lib/Transforms/Mem2Reg.cpp
+++ b/mlir/lib/Transforms/Mem2Reg.cpp
@@ -636,20 +636,36 @@ LogicalResult mlir::tryToPromoteMemorySlots(
   // lazily and cached to avoid expensive recomputation.
   BlockIndexCache blockIndexCache;
 
-  for (PromotableAllocationOpInterface allocator : allocators) {
-    for (MemorySlot slot : allocator.getPromotableSlots()) {
-      if (slot.ptr.use_empty())
-        continue;
-
-      MemorySlotPromotionAnalyzer analyzer(slot, dominance, dataLayout);
-      std::optional<MemorySlotPromotionInfo> info = analyzer.computeInfo();
-      if (info) {
-        MemorySlotPromoter(slot, allocator, builder, dominance, dataLayout,
-                           std::move(*info), statistics, blockIndexCache)
-            .promoteSlot();
-        promotedAny = true;
+  SmallVector<PromotableAllocationOpInterface> workList(allocators.begin(),
+                                                        allocators.end());
+
+  SmallVector<PromotableAllocationOpInterface> newWorkList;
+  newWorkList.reserve(workList.size());
+  while (true) {
+    for (PromotableAllocationOpInterface allocator : workList) {
+      for (MemorySlot slot : allocator.getPromotableSlots()) {
+        if (slot.ptr.use_empty())
+          continue;
+
+        MemorySlotPromotionAnalyzer analyzer(slot, dominance, dataLayout);
+        std::optional<MemorySlotPromotionInfo> info = analyzer.computeInfo();
+        if (info) {
+          MemorySlotPromoter(slot, allocator, builder, dominance, dataLayout,
+                             std::move(*info), statistics, blockIndexCache)
+              .promoteSlot();
+          promotedAny = true;
+          continue;
+        }
+        newWorkList.push_back(allocator);
       }
     }
+    if (workList.size() == newWorkList.size())
+      break;
+
+    // Swap the vector's backing memory and clear the entries in newWorkList
+    // afterwards. This ensures that additional heap allocations can be avoided.
+    workList.swap(newWorkList);
+    newWorkList.clear();
   }
 
   return success(promotedAny);
@@ -677,22 +693,16 @@ struct Mem2Reg : impl::Mem2RegBase<Mem2Reg> {
 
       OpBuilder builder(&region.front(), region.front().begin());
 
-      // Promoting a slot can allow for further promotion of other slots,
-      // promotion is tried until no promotion succeeds.
-      while (true) {
-        SmallVector<PromotableAllocationOpInterface> allocators;
-        // Build a list of allocators to attempt to promote the slots of.
-        region.walk([&](PromotableAllocationOpInterface allocator) {
-          allocators.emplace_back(allocator);
-        });
-
-        // Attempt promoting until no promotion succeeds.
-        if (failed(tryToPromoteMemorySlots(allocators, builder, dataLayout,
-                                           dominance, statistics)))
-          break;
+      SmallVector<PromotableAllocationOpInterface> allocators;
+      // Build a list of allocators to attempt to promote the slots of.
+      region.walk([&](PromotableAllocationOpInterface allocator) {
+        allocators.emplace_back(allocator);
+      });
 
+      // Attempt promoting as many of the slots as possible.
+      if (succeeded(tryToPromoteMemorySlots(allocators, builder, dataLayout,
+                                            dominance, statistics)))
         changed = true;
-      }
     }
     if (!changed)
       markAllAnalysesPreserved();

>From b7326dfdc64c6c7941964b61ffdebec5110084ab Mon Sep 17 00:00:00 2001
From: Christian Ulmann <christian.ulmann at nextsilicon.com>
Date: Wed, 8 May 2024 15:57:56 +0000
Subject: [PATCH 2/2] extend with fix for multi slot allocators

---
 .../mlir/Interfaces/MemorySlotInterfaces.td   |  6 +-
 mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp |  8 ++-
 .../Dialect/MemRef/IR/MemRefMemorySlot.cpp    |  8 ++-
 mlir/lib/Transforms/Mem2Reg.cpp               | 32 ++++++++---
 mlir/test/Transforms/mem2reg.mlir             | 12 ++++
 mlir/test/lib/Dialect/Test/TestOpDefs.cpp     | 57 +++++++++++++++++++
 mlir/test/lib/Dialect/Test/TestOps.h          |  1 +
 mlir/test/lib/Dialect/Test/TestOps.td         | 11 ++++
 8 files changed, 119 insertions(+), 16 deletions(-)
 create mode 100644 mlir/test/Transforms/mem2reg.mlir

diff --git a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
index adf182ac7069..762a97ac546a 100644
--- a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
+++ b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
@@ -68,8 +68,14 @@ def PromotableAllocationOpInterface
         Hook triggered once the promotion of a slot is complete. This can
         also clean up the created default value if necessary.
         This will only be called for slots declared by this operation.
+
+        If the operation gets replaced by a new promotable allocation op that
+        holds its remaining, not yet promoted slots, this new op must be
+        returned so that it can be considered for further promotion. Must
+        return nullopt otherwise.
       }],
-      "void", "handlePromotionComplete",
+      "std::optional<::mlir::PromotableAllocationOpInterface>",
+        "handlePromotionComplete",
       (ins
         "const ::mlir::MemorySlot &":$slot, 
         "::mlir::Value":$defaultValue,
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
index 70102e1c8192..4fdf847a559c 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
@@ -50,12 +50,14 @@ void LLVM::AllocaOp::handleBlockArgument(const MemorySlot &slot,
                                        declareOp.getLocationExpr());
 }
 
-void LLVM::AllocaOp::handlePromotionComplete(const MemorySlot &slot,
-                                             Value defaultValue,
-                                             OpBuilder &builder) {
+std::optional<PromotableAllocationOpInterface>
+LLVM::AllocaOp::handlePromotionComplete(const MemorySlot &slot,
+                                        Value defaultValue,
+                                        OpBuilder &builder) {
   if (defaultValue && defaultValue.use_empty())
     defaultValue.getDefiningOp()->erase();
   this->erase();
+  return std::nullopt;
 }
 
 SmallVector<DestructurableMemorySlot> LLVM::AllocaOp::getDestructurableSlots() {
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
index dca07e84ea73..e30598e6878f 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
@@ -96,12 +96,14 @@ Value memref::AllocaOp::getDefaultValue(const MemorySlot &slot,
       });
 }
 
-void memref::AllocaOp::handlePromotionComplete(const MemorySlot &slot,
-                                               Value defaultValue,
-                                               OpBuilder &builder) {
+std::optional<PromotableAllocationOpInterface>
+memref::AllocaOp::handlePromotionComplete(const MemorySlot &slot,
+                                          Value defaultValue,
+                                          OpBuilder &builder) {
   if (defaultValue.use_empty())
     defaultValue.getDefiningOp()->erase();
   this->erase();
+  return std::nullopt;
 }
 
 void memref::AllocaOp::handleBlockArgument(const MemorySlot &slot,
diff --git a/mlir/lib/Transforms/Mem2Reg.cpp b/mlir/lib/Transforms/Mem2Reg.cpp
index 390d2a3f54b6..0d90e6820e3c 100644
--- a/mlir/lib/Transforms/Mem2Reg.cpp
+++ b/mlir/lib/Transforms/Mem2Reg.cpp
@@ -173,7 +173,9 @@ class MemorySlotPromoter {
   /// Actually promotes the slot by mutating IR. Promoting a slot DOES
   /// invalidate the MemorySlotPromotionInfo of other slots. Preparation of
   /// promotion info should NOT be performed in batches.
-  void promoteSlot();
+  /// Returns a promotable allocation op if a new allocator was created, nullopt
+  /// otherwise.
+  std::optional<PromotableAllocationOpInterface> promoteSlot();
 
 private:
   /// Computes the reaching definition for all the operations that require
@@ -595,7 +597,8 @@ void MemorySlotPromoter::removeBlockingUses() {
          "after promotion, the slot pointer should not be used anymore");
 }
 
-void MemorySlotPromoter::promoteSlot() {
+std::optional<PromotableAllocationOpInterface>
+MemorySlotPromoter::promoteSlot() {
   computeReachingDefInRegion(slot.ptr.getParentRegion(),
                              getOrCreateDefaultValue());
 
@@ -622,7 +625,7 @@ void MemorySlotPromoter::promoteSlot() {
   if (statistics.promotedAmount)
     (*statistics.promotedAmount)++;
 
-  allocator.handlePromotionComplete(slot, defaultValue, builder);
+  return allocator.handlePromotionComplete(slot, defaultValue, builder);
 }
 
 LogicalResult mlir::tryToPromoteMemorySlots(
@@ -642,6 +645,8 @@ LogicalResult mlir::tryToPromoteMemorySlots(
   SmallVector<PromotableAllocationOpInterface> newWorkList;
   newWorkList.reserve(workList.size());
   while (true) {
+    bool changesInThisRound = false;
     for (PromotableAllocationOpInterface allocator : workList) {
+      bool changedAllocator = false;
       for (MemorySlot slot : allocator.getPromotableSlots()) {
         if (slot.ptr.use_empty())
@@ -650,17 +655,32 @@ LogicalResult mlir::tryToPromoteMemorySlots(
         MemorySlotPromotionAnalyzer analyzer(slot, dominance, dataLayout);
         std::optional<MemorySlotPromotionInfo> info = analyzer.computeInfo();
         if (info) {
-          MemorySlotPromoter(slot, allocator, builder, dominance, dataLayout,
-                             std::move(*info), statistics, blockIndexCache)
-              .promoteSlot();
-          promotedAny = true;
-          continue;
+          std::optional<PromotableAllocationOpInterface> newAllocator =
+              MemorySlotPromoter(slot, allocator, builder, dominance,
+                                 dataLayout, std::move(*info), statistics,
+                                 blockIndexCache)
+                  .promoteSlot();
+          changedAllocator = true;
+          // Add newly created allocators to the worklist for further
+          // processing.
+          if (newAllocator)
+            newWorkList.push_back(*newAllocator);
+
+          // Breaking is required, as a modification to an allocator might have
+          // removed it, making the other slots invalid.
+          break;
         }
-        newWorkList.push_back(allocator);
       }
+      // A promoted slot may have erased or replaced the allocator, so it must
+      // not be queued again. Otherwise, queue it exactly once, no matter how
+      // many of its slots failed to promote.
+      if (!changedAllocator)
+        newWorkList.push_back(allocator);
+      changesInThisRound |= changedAllocator;
     }
-    if (workList.size() == newWorkList.size())
+    if (!changesInThisRound)
       break;
+    promotedAny = true;
 
     // Swap the vector's backing memory and clear the entries in newWorkList
     // afterwards. This ensures that additional heap allocations can be avoided.
diff --git a/mlir/test/Transforms/mem2reg.mlir b/mlir/test/Transforms/mem2reg.mlir
new file mode 100644
index 000000000000..894cbec010fa
--- /dev/null
+++ b/mlir/test/Transforms/mem2reg.mlir
@@ -0,0 +1,26 @@
+// RUN: mlir-opt %s --pass-pipeline='builtin.module(func.func(mem2reg))' --split-input-file | FileCheck %s
+
+// Verifies that allocators with multiple slots are handled properly.
+
+// CHECK-LABEL: func.func @multi_slot_alloca
+func.func @multi_slot_alloca() -> (i32, i32) {
+  // CHECK-NOT: test.multi_slot_alloca
+  %1, %2 = test.multi_slot_alloca : () -> (memref<i32>, memref<i32>)
+  %3 = memref.load %1[] : memref<i32>
+  %4 = memref.load %2[] : memref<i32>
+  return %3, %4 : i32, i32
+}
+
+// -----
+
+// Verifies that an allocator replaced by the promotion of one of its slots is
+// not revisited, even when one of its earlier slots could not be promoted.
+
+// CHECK-LABEL: func.func @partially_promotable_multi_slot_alloca
+func.func @partially_promotable_multi_slot_alloca() -> (memref<i32>, i32) {
+  // CHECK: test.multi_slot_alloca : () -> memref<i32>
+  // CHECK-NOT: test.multi_slot_alloca
+  %1, %2 = test.multi_slot_alloca : () -> (memref<i32>, memref<i32>)
+  %3 = memref.load %2[] : memref<i32>
+  return %1, %3 : memref<i32>, i32
+}
diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
index 08df2e5e1228..c9f0f43fa2ec 100644
--- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
+++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
@@ -11,6 +11,7 @@
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/Verifier.h"
 #include "mlir/Interfaces/FunctionImplementation.h"
+#include "mlir/Interfaces/MemorySlotInterfaces.h"
 
 using namespace mlir;
 using namespace test;
@@ -1172,3 +1173,64 @@ void TestOpWithVersionedProperties::writeToMlirBytecode(
   writer.writeVarInt(prop.value1);
   writer.writeVarInt(prop.value2);
 }
+
+//===----------------------------------------------------------------------===//
+// TestMultiSlotAlloca
+//===----------------------------------------------------------------------===//
+
+llvm::SmallVector<MemorySlot> TestMultiSlotAlloca::getPromotableSlots() {
+  SmallVector<MemorySlot> slots;
+  for (Value result : getResults()) {
+    slots.push_back(MemorySlot{
+        result, cast<MemRefType>(result.getType()).getElementType()});
+  }
+  return slots;
+}
+
+Value TestMultiSlotAlloca::getDefaultValue(const MemorySlot &slot,
+                                           OpBuilder &builder) {
+  return builder.create<TestOpConstant>(getLoc(), slot.elemType,
+                                        builder.getI32IntegerAttr(42));
+}
+
+void TestMultiSlotAlloca::handleBlockArgument(const MemorySlot &slot,
+                                              BlockArgument argument,
+                                              OpBuilder &builder) {
+  // Not relevant for testing.
+}
+
+std::optional<PromotableAllocationOpInterface>
+TestMultiSlotAlloca::handlePromotionComplete(const MemorySlot &slot,
+                                             Value defaultValue,
+                                             OpBuilder &builder) {
+  if (defaultValue && defaultValue.use_empty())
+    defaultValue.getDefiningOp()->erase();
+
+  if (getNumResults() == 1) {
+    erase();
+    return std::nullopt;
+  }
+
+  SmallVector<Type> newTypes;
+  SmallVector<Value> remainingValues;
+
+  for (Value oldResult : getResults()) {
+    if (oldResult == slot.ptr)
+      continue;
+    remainingValues.push_back(oldResult);
+    newTypes.push_back(oldResult.getType());
+  }
+
+  // Nothing guarantees where the builder currently points. Create the
+  // replacement right in front of this op so that it dominates all uses of the
+  // remaining results that get redirected to it below.
+  OpBuilder::InsertionGuard guard(builder);
+  builder.setInsertionPoint(getOperation());
+  auto replacement = builder.create<TestMultiSlotAlloca>(getLoc(), newTypes);
+  for (auto [oldResult, newResult] :
+       llvm::zip_equal(remainingValues, replacement.getResults()))
+    oldResult.replaceAllUsesWith(newResult);
+
+  erase();
+  return replacement;
+}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.h b/mlir/test/lib/Dialect/Test/TestOps.h
index f9925855bb9d..837ccca56592 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.h
+++ b/mlir/test/lib/Dialect/Test/TestOps.h
@@ -36,6 +36,7 @@
 #include "mlir/Interfaces/InferIntRangeInterface.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/LoopLikeInterface.h"
+#include "mlir/Interfaces/MemorySlotInterfaces.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
 #include "mlir/Interfaces/ViewLikeInterface.h"
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 5352d574ac39..e16ea2407314 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -28,6 +28,7 @@ include "mlir/Interfaces/DestinationStyleOpInterface.td"
 include "mlir/Interfaces/InferIntRangeInterface.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/LoopLikeInterface.td"
+include "mlir/Interfaces/MemorySlotInterfaces.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 
 
@@ -3167,4 +3168,14 @@ def TestOpOptionallyImplementingInterface
   let arguments = (ins BoolAttr:$implementsInterface);
 }
 
+//===----------------------------------------------------------------------===//
+// Test Mem2Reg
+//===----------------------------------------------------------------------===//
+
+def TestMultiSlotAlloca : TEST_Op<"multi_slot_alloca",
+    [DeclareOpInterfaceMethods<PromotableAllocationOpInterface>]> {
+  let results = (outs Variadic<MemRefOf<[I32]>>:$results);
+  let assemblyFormat = "attr-dict `:` functional-type(operands, results)";
+}
+
 #endif // TEST_OPS



More information about the Mlir-commits mailing list