[Mlir-commits] [mlir] [MLIR][Mem2Reg] Change API to always retry promotion after changes (PR #91464)
Christian Ulmann
llvmlistbot at llvm.org
Wed May 8 08:58:28 PDT 2024
https://github.com/Dinistro updated https://github.com/llvm/llvm-project/pull/91464
From 0f0628b295931ee6e0e2cb210071a3ba09cb58c3 Mon Sep 17 00:00:00 2001
From: Christian Ulmann <christian.ulmann at nextsilicon.com>
Date: Wed, 8 May 2024 12:03:56 +0000
Subject: [PATCH 1/2] [MLIR][Mem2Reg] Change API to always retry promotion
after changes
This commit modifies the Mem2Reg's API to always attempt a full
promotion on all the passed in "allocators". This ensures that the pass
does not require unnecessary walks over the regions and improves caching
benefits.
---
mlir/include/mlir/Transforms/Mem2Reg.h | 6 +--
mlir/lib/Transforms/Mem2Reg.cpp | 62 +++++++++++++++-----------
2 files changed, 39 insertions(+), 29 deletions(-)
diff --git a/mlir/include/mlir/Transforms/Mem2Reg.h b/mlir/include/mlir/Transforms/Mem2Reg.h
index fee7fb312750..6986cad9ae12 100644
--- a/mlir/include/mlir/Transforms/Mem2Reg.h
+++ b/mlir/include/mlir/Transforms/Mem2Reg.h
@@ -9,7 +9,6 @@
#ifndef MLIR_TRANSFORMS_MEM2REG_H
#define MLIR_TRANSFORMS_MEM2REG_H
-#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/MemorySlotInterfaces.h"
#include "llvm/ADT/Statistic.h"
@@ -23,8 +22,9 @@ struct Mem2RegStatistics {
llvm::Statistic *newBlockArgumentAmount = nullptr;
};
-/// Attempts to promote the memory slots of the provided allocators. Succeeds if
-/// at least one memory slot was promoted.
+/// Attempts to promote the memory slots of the provided allocators. Iteratively
+/// retries the promotion of all slots as promoting one slot might enable
+/// subsequent promotions. Succeeds if at least one memory slot was promoted.
LogicalResult
tryToPromoteMemorySlots(ArrayRef<PromotableAllocationOpInterface> allocators,
OpBuilder &builder, const DataLayout &dataLayout,
diff --git a/mlir/lib/Transforms/Mem2Reg.cpp b/mlir/lib/Transforms/Mem2Reg.cpp
index 8adbbcd01cb4..390d2a3f54b6 100644
--- a/mlir/lib/Transforms/Mem2Reg.cpp
+++ b/mlir/lib/Transforms/Mem2Reg.cpp
@@ -636,20 +636,36 @@ LogicalResult mlir::tryToPromoteMemorySlots(
// lazily and cached to avoid expensive recomputation.
BlockIndexCache blockIndexCache;
- for (PromotableAllocationOpInterface allocator : allocators) {
- for (MemorySlot slot : allocator.getPromotableSlots()) {
- if (slot.ptr.use_empty())
- continue;
-
- MemorySlotPromotionAnalyzer analyzer(slot, dominance, dataLayout);
- std::optional<MemorySlotPromotionInfo> info = analyzer.computeInfo();
- if (info) {
- MemorySlotPromoter(slot, allocator, builder, dominance, dataLayout,
- std::move(*info), statistics, blockIndexCache)
- .promoteSlot();
- promotedAny = true;
+ SmallVector<PromotableAllocationOpInterface> workList(allocators.begin(),
+ allocators.end());
+
+ SmallVector<PromotableAllocationOpInterface> newWorkList;
+ newWorkList.reserve(workList.size());
+ while (true) {
+ for (PromotableAllocationOpInterface allocator : workList) {
+ for (MemorySlot slot : allocator.getPromotableSlots()) {
+ if (slot.ptr.use_empty())
+ continue;
+
+ MemorySlotPromotionAnalyzer analyzer(slot, dominance, dataLayout);
+ std::optional<MemorySlotPromotionInfo> info = analyzer.computeInfo();
+ if (info) {
+ MemorySlotPromoter(slot, allocator, builder, dominance, dataLayout,
+ std::move(*info), statistics, blockIndexCache)
+ .promoteSlot();
+ promotedAny = true;
+ continue;
+ }
+ newWorkList.push_back(allocator);
}
}
+ if (workList.size() == newWorkList.size())
+ break;
+
+ // Swap the vector's backing memory and clear the entries in newWorkList
+ // afterwards. This ensures that additional heap allocations can be avoided.
+ workList.swap(newWorkList);
+ newWorkList.clear();
}
return success(promotedAny);
@@ -677,22 +693,16 @@ struct Mem2Reg : impl::Mem2RegBase<Mem2Reg> {
OpBuilder builder(&region.front(), region.front().begin());
- // Promoting a slot can allow for further promotion of other slots,
- // promotion is tried until no promotion succeeds.
- while (true) {
- SmallVector<PromotableAllocationOpInterface> allocators;
- // Build a list of allocators to attempt to promote the slots of.
- region.walk([&](PromotableAllocationOpInterface allocator) {
- allocators.emplace_back(allocator);
- });
-
- // Attempt promoting until no promotion succeeds.
- if (failed(tryToPromoteMemorySlots(allocators, builder, dataLayout,
- dominance, statistics)))
- break;
+ SmallVector<PromotableAllocationOpInterface> allocators;
+ // Build a list of allocators to attempt to promote the slots of.
+ region.walk([&](PromotableAllocationOpInterface allocator) {
+ allocators.emplace_back(allocator);
+ });
+ // Attempt promoting as many of the slots as possible.
+ if (succeeded(tryToPromoteMemorySlots(allocators, builder, dataLayout,
+ dominance, statistics)))
changed = true;
- }
}
if (!changed)
markAllAnalysesPreserved();
From b7326dfdc64c6c7941964b61ffdebec5110084ab Mon Sep 17 00:00:00 2001
From: Christian Ulmann <christian.ulmann at nextsilicon.com>
Date: Wed, 8 May 2024 15:57:56 +0000
Subject: [PATCH 2/2] extend with fix for multi slot allocators
---
.../mlir/Interfaces/MemorySlotInterfaces.td | 6 +-
mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp | 8 ++-
.../Dialect/MemRef/IR/MemRefMemorySlot.cpp | 8 ++-
mlir/lib/Transforms/Mem2Reg.cpp | 32 ++++++++---
mlir/test/Transforms/mem2reg.mlir | 12 ++++
mlir/test/lib/Dialect/Test/TestOpDefs.cpp | 57 +++++++++++++++++++
mlir/test/lib/Dialect/Test/TestOps.h | 1 +
mlir/test/lib/Dialect/Test/TestOps.td | 11 ++++
8 files changed, 119 insertions(+), 16 deletions(-)
create mode 100644 mlir/test/Transforms/mem2reg.mlir
diff --git a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
index adf182ac7069..762a97ac546a 100644
--- a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
+++ b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
@@ -68,8 +68,12 @@ def PromotableAllocationOpInterface
Hook triggered once the promotion of a slot is complete. This can
also clean up the created default value if necessary.
This will only be called for slots declared by this operation.
+
+ Must return a new promotable allocation op if this operation produced
+ multiple promotable slots, nullopt otherwise.
}],
- "void", "handlePromotionComplete",
+ "std::optional<::mlir::PromotableAllocationOpInterface>",
+ "handlePromotionComplete",
(ins
"const ::mlir::MemorySlot &":$slot,
"::mlir::Value":$defaultValue,
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
index 70102e1c8192..4fdf847a559c 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
@@ -50,12 +50,14 @@ void LLVM::AllocaOp::handleBlockArgument(const MemorySlot &slot,
declareOp.getLocationExpr());
}
-void LLVM::AllocaOp::handlePromotionComplete(const MemorySlot &slot,
- Value defaultValue,
- OpBuilder &builder) {
+std::optional<PromotableAllocationOpInterface>
+LLVM::AllocaOp::handlePromotionComplete(const MemorySlot &slot,
+ Value defaultValue,
+ OpBuilder &builder) {
if (defaultValue && defaultValue.use_empty())
defaultValue.getDefiningOp()->erase();
this->erase();
+ return std::nullopt;
}
SmallVector<DestructurableMemorySlot> LLVM::AllocaOp::getDestructurableSlots() {
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
index dca07e84ea73..e30598e6878f 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
@@ -96,12 +96,14 @@ Value memref::AllocaOp::getDefaultValue(const MemorySlot &slot,
});
}
-void memref::AllocaOp::handlePromotionComplete(const MemorySlot &slot,
- Value defaultValue,
- OpBuilder &builder) {
+std::optional<PromotableAllocationOpInterface>
+memref::AllocaOp::handlePromotionComplete(const MemorySlot &slot,
+ Value defaultValue,
+ OpBuilder &builder) {
if (defaultValue.use_empty())
defaultValue.getDefiningOp()->erase();
this->erase();
+ return std::nullopt;
}
void memref::AllocaOp::handleBlockArgument(const MemorySlot &slot,
diff --git a/mlir/lib/Transforms/Mem2Reg.cpp b/mlir/lib/Transforms/Mem2Reg.cpp
index 390d2a3f54b6..0d90e6820e3c 100644
--- a/mlir/lib/Transforms/Mem2Reg.cpp
+++ b/mlir/lib/Transforms/Mem2Reg.cpp
@@ -173,7 +173,9 @@ class MemorySlotPromoter {
/// Actually promotes the slot by mutating IR. Promoting a slot DOES
/// invalidate the MemorySlotPromotionInfo of other slots. Preparation of
/// promotion info should NOT be performed in batches.
- void promoteSlot();
+ /// Returns a promotable allocation op if a new allocator was created, nullopt
+ /// otherwise.
+ std::optional<PromotableAllocationOpInterface> promoteSlot();
private:
/// Computes the reaching definition for all the operations that require
@@ -595,7 +597,8 @@ void MemorySlotPromoter::removeBlockingUses() {
"after promotion, the slot pointer should not be used anymore");
}
-void MemorySlotPromoter::promoteSlot() {
+std::optional<PromotableAllocationOpInterface>
+MemorySlotPromoter::promoteSlot() {
computeReachingDefInRegion(slot.ptr.getParentRegion(),
getOrCreateDefaultValue());
@@ -622,7 +625,7 @@ void MemorySlotPromoter::promoteSlot() {
if (statistics.promotedAmount)
(*statistics.promotedAmount)++;
- allocator.handlePromotionComplete(slot, defaultValue, builder);
+ return allocator.handlePromotionComplete(slot, defaultValue, builder);
}
LogicalResult mlir::tryToPromoteMemorySlots(
@@ -642,6 +645,7 @@ LogicalResult mlir::tryToPromoteMemorySlots(
SmallVector<PromotableAllocationOpInterface> newWorkList;
newWorkList.reserve(workList.size());
while (true) {
+ bool changesInThisRound = false;
for (PromotableAllocationOpInterface allocator : workList) {
for (MemorySlot slot : allocator.getPromotableSlots()) {
if (slot.ptr.use_empty())
@@ -650,17 +654,27 @@ LogicalResult mlir::tryToPromoteMemorySlots(
MemorySlotPromotionAnalyzer analyzer(slot, dominance, dataLayout);
std::optional<MemorySlotPromotionInfo> info = analyzer.computeInfo();
if (info) {
- MemorySlotPromoter(slot, allocator, builder, dominance, dataLayout,
- std::move(*info), statistics, blockIndexCache)
- .promoteSlot();
- promotedAny = true;
- continue;
+ std::optional<PromotableAllocationOpInterface> newAllocator =
+ MemorySlotPromoter(slot, allocator, builder, dominance,
+ dataLayout, std::move(*info), statistics,
+ blockIndexCache)
+ .promoteSlot();
+ changesInThisRound = true;
+ // Add newly created allocators to the worklist for further
+ // processing.
+ if (newAllocator)
+ newWorkList.push_back(*newAllocator);
+
+ // Breaking is required, as a modification to an allocator might have
+ // removed it, making the other slots invalid.
+ break;
}
newWorkList.push_back(allocator);
}
}
- if (workList.size() == newWorkList.size())
+ if (!changesInThisRound)
break;
+ promotedAny = true;
// Swap the vector's backing memory and clear the entries in newWorkList
// afterwards. This ensures that additional heap allocations can be avoided.
diff --git a/mlir/test/Transforms/mem2reg.mlir b/mlir/test/Transforms/mem2reg.mlir
new file mode 100644
index 000000000000..894cbec010fa
--- /dev/null
+++ b/mlir/test/Transforms/mem2reg.mlir
@@ -0,0 +1,12 @@
+// RUN: mlir-opt %s --pass-pipeline='builtin.module(func.func(mem2reg))' --split-input-file | FileCheck %s
+
+// Verifies that allocators with multiple slots are handled properly.
+
+// CHECK-LABEL: func.func @multi_slot_alloca
+func.func @multi_slot_alloca() -> (i32, i32) {
+ // CHECK-NOT: test.multi_slot_alloca
+ %1, %2 = test.multi_slot_alloca : () -> (memref<i32>, memref<i32>)
+ %3 = memref.load %1[] : memref<i32>
+ %4 = memref.load %2[] : memref<i32>
+ return %3, %4 : i32, i32
+}
diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
index 08df2e5e1228..c9f0f43fa2ec 100644
--- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
+++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
@@ -11,6 +11,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Interfaces/FunctionImplementation.h"
+#include "mlir/Interfaces/MemorySlotInterfaces.h"
using namespace mlir;
using namespace test;
@@ -1172,3 +1173,59 @@ void TestOpWithVersionedProperties::writeToMlirBytecode(
writer.writeVarInt(prop.value1);
writer.writeVarInt(prop.value2);
}
+
+//===----------------------------------------------------------------------===//
+// TestMultiSlotAlloca
+//===----------------------------------------------------------------------===//
+
+llvm::SmallVector<MemorySlot> TestMultiSlotAlloca::getPromotableSlots() {
+ SmallVector<MemorySlot> slots;
+ for (Value result : getResults()) {
+ slots.push_back(MemorySlot{
+ result, cast<MemRefType>(result.getType()).getElementType()});
+ }
+ return slots;
+}
+
+Value TestMultiSlotAlloca::getDefaultValue(const MemorySlot &slot,
+ OpBuilder &builder) {
+ return builder.create<TestOpConstant>(getLoc(), slot.elemType,
+ builder.getI32IntegerAttr(42));
+}
+
+void TestMultiSlotAlloca::handleBlockArgument(const MemorySlot &slot,
+ BlockArgument argument,
+ OpBuilder &builder) {
+ // Not relevant for testing.
+}
+
+std::optional<PromotableAllocationOpInterface>
+TestMultiSlotAlloca::handlePromotionComplete(const MemorySlot &slot,
+ Value defaultValue,
+ OpBuilder &builder) {
+ if (defaultValue && defaultValue.use_empty())
+ defaultValue.getDefiningOp()->erase();
+
+ if (getNumResults() == 1) {
+ erase();
+ return std::nullopt;
+ }
+
+ SmallVector<Type> newTypes;
+ SmallVector<Value> remainingValues;
+
+ for (Value oldResult : getResults()) {
+ if (oldResult == slot.ptr)
+ continue;
+ remainingValues.push_back(oldResult);
+ newTypes.push_back(oldResult.getType());
+ }
+
+ auto replacement = builder.create<TestMultiSlotAlloca>(getLoc(), newTypes);
+ for (auto [oldResult, newResult] :
+ llvm::zip_equal(remainingValues, replacement.getResults()))
+ oldResult.replaceAllUsesWith(newResult);
+
+ erase();
+ return replacement;
+}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.h b/mlir/test/lib/Dialect/Test/TestOps.h
index f9925855bb9d..837ccca56592 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.h
+++ b/mlir/test/lib/Dialect/Test/TestOps.h
@@ -36,6 +36,7 @@
#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
+#include "mlir/Interfaces/MemorySlotInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 5352d574ac39..e16ea2407314 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -28,6 +28,7 @@ include "mlir/Interfaces/DestinationStyleOpInterface.td"
include "mlir/Interfaces/InferIntRangeInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/LoopLikeInterface.td"
+include "mlir/Interfaces/MemorySlotInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
@@ -3167,4 +3168,14 @@ def TestOpOptionallyImplementingInterface
let arguments = (ins BoolAttr:$implementsInterface);
}
+//===----------------------------------------------------------------------===//
+// Test Mem2Reg
+//===----------------------------------------------------------------------===//
+
+def TestMultiSlotAlloca : TEST_Op<"multi_slot_alloca",
+ [DeclareOpInterfaceMethods<PromotableAllocationOpInterface>]> {
+ let results = (outs Variadic<MemRefOf<[I32]>>:$results);
+ let assemblyFormat = "attr-dict `:` functional-type(operands, results)";
+}
+
#endif // TEST_OPS
More information about the Mlir-commits
mailing list