[Mlir-commits] [mlir] eeafc9d - [MLIR][Mem2Reg] Fix multi slot handling & move retry handling (#91464)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun May 12 22:37:45 PDT 2024
Author: Christian Ulmann
Date: 2024-05-13T07:37:41+02:00
New Revision: eeafc9daa15d2d022bcdd456d4b8bafd23f5f121
URL: https://github.com/llvm/llvm-project/commit/eeafc9daa15d2d022bcdd456d4b8bafd23f5f121
DIFF: https://github.com/llvm/llvm-project/commit/eeafc9daa15d2d022bcdd456d4b8bafd23f5f121.diff
LOG: [MLIR][Mem2Reg] Fix multi slot handling & move retry handling (#91464)
This commit fixes Mem2Reg's multi-slot allocator handling and extends the
test dialect to cover it.
Additionally, this modifies Mem2Reg's API to always attempt a full
promotion of all the passed-in "allocators". This ensures that the pass
does not perform unnecessary walks over the regions and makes better use
of cached data.
Added:
mlir/test/Transforms/mem2reg.mlir
Modified:
mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
mlir/include/mlir/Transforms/Mem2Reg.h
mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
mlir/lib/Transforms/Mem2Reg.cpp
mlir/test/lib/Dialect/Test/TestOpDefs.cpp
mlir/test/lib/Dialect/Test/TestOps.h
mlir/test/lib/Dialect/Test/TestOps.td
Removed:
################################################################################
diff --git a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
index adf182ac7069d..e2409cbec5fde 100644
--- a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
+++ b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
@@ -68,8 +68,12 @@ def PromotableAllocationOpInterface
Hook triggered once the promotion of a slot is complete. This can
also clean up the created default value if necessary.
This will only be called for slots declared by this operation.
+
+ Must return a new promotable allocation op if this operation produced
+ multiple promotable slots, nullopt otherwise.
}],
- "void", "handlePromotionComplete",
+ "::std::optional<::mlir::PromotableAllocationOpInterface>",
+ "handlePromotionComplete",
(ins
"const ::mlir::MemorySlot &":$slot,
"::mlir::Value":$defaultValue,
diff --git a/mlir/include/mlir/Transforms/Mem2Reg.h b/mlir/include/mlir/Transforms/Mem2Reg.h
index fee7fb312750c..6986cad9ae120 100644
--- a/mlir/include/mlir/Transforms/Mem2Reg.h
+++ b/mlir/include/mlir/Transforms/Mem2Reg.h
@@ -9,7 +9,6 @@
#ifndef MLIR_TRANSFORMS_MEM2REG_H
#define MLIR_TRANSFORMS_MEM2REG_H
-#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/MemorySlotInterfaces.h"
#include "llvm/ADT/Statistic.h"
@@ -23,8 +22,9 @@ struct Mem2RegStatistics {
llvm::Statistic *newBlockArgumentAmount = nullptr;
};
-/// Attempts to promote the memory slots of the provided allocators. Succeeds if
-/// at least one memory slot was promoted.
+/// Attempts to promote the memory slots of the provided allocators. Iteratively
+/// retries the promotion of all slots as promoting one slot might enable
+/// subsequent promotions. Succeeds if at least one memory slot was promoted.
LogicalResult
tryToPromoteMemorySlots(ArrayRef<PromotableAllocationOpInterface> allocators,
OpBuilder &builder, const DataLayout &dataLayout,
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
index 70102e1c81920..4fdf847a559ce 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
@@ -50,12 +50,14 @@ void LLVM::AllocaOp::handleBlockArgument(const MemorySlot &slot,
declareOp.getLocationExpr());
}
-void LLVM::AllocaOp::handlePromotionComplete(const MemorySlot &slot,
- Value defaultValue,
- OpBuilder &builder) {
+std::optional<PromotableAllocationOpInterface>
+LLVM::AllocaOp::handlePromotionComplete(const MemorySlot &slot,
+ Value defaultValue,
+ OpBuilder &builder) {
if (defaultValue && defaultValue.use_empty())
defaultValue.getDefiningOp()->erase();
this->erase();
+ return std::nullopt;
}
SmallVector<DestructurableMemorySlot> LLVM::AllocaOp::getDestructurableSlots() {
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
index dca07e84ea73c..e30598e6878f4 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
@@ -96,12 +96,14 @@ Value memref::AllocaOp::getDefaultValue(const MemorySlot &slot,
});
}
-void memref::AllocaOp::handlePromotionComplete(const MemorySlot &slot,
- Value defaultValue,
- OpBuilder &builder) {
+std::optional<PromotableAllocationOpInterface>
+memref::AllocaOp::handlePromotionComplete(const MemorySlot &slot,
+ Value defaultValue,
+ OpBuilder &builder) {
if (defaultValue.use_empty())
defaultValue.getDefiningOp()->erase();
this->erase();
+ return std::nullopt;
}
void memref::AllocaOp::handleBlockArgument(const MemorySlot &slot,
diff --git a/mlir/lib/Transforms/Mem2Reg.cpp b/mlir/lib/Transforms/Mem2Reg.cpp
index 8adbbcd01cb44..e096747741c0a 100644
--- a/mlir/lib/Transforms/Mem2Reg.cpp
+++ b/mlir/lib/Transforms/Mem2Reg.cpp
@@ -173,7 +173,9 @@ class MemorySlotPromoter {
/// Actually promotes the slot by mutating IR. Promoting a slot DOES
/// invalidate the MemorySlotPromotionInfo of other slots. Preparation of
/// promotion info should NOT be performed in batches.
- void promoteSlot();
+ /// Returns a promotable allocation op if a new allocator was created, nullopt
+ /// otherwise.
+ std::optional<PromotableAllocationOpInterface> promoteSlot();
private:
/// Computes the reaching definition for all the operations that require
@@ -595,7 +597,8 @@ void MemorySlotPromoter::removeBlockingUses() {
"after promotion, the slot pointer should not be used anymore");
}
-void MemorySlotPromoter::promoteSlot() {
+std::optional<PromotableAllocationOpInterface>
+MemorySlotPromoter::promoteSlot() {
computeReachingDefInRegion(slot.ptr.getParentRegion(),
getOrCreateDefaultValue());
@@ -622,7 +625,7 @@ void MemorySlotPromoter::promoteSlot() {
if (statistics.promotedAmount)
(*statistics.promotedAmount)++;
- allocator.handlePromotionComplete(slot, defaultValue, builder);
+ return allocator.handlePromotionComplete(slot, defaultValue, builder);
}
LogicalResult mlir::tryToPromoteMemorySlots(
@@ -636,20 +639,50 @@ LogicalResult mlir::tryToPromoteMemorySlots(
// lazily and cached to avoid expensive recomputation.
BlockIndexCache blockIndexCache;
- for (PromotableAllocationOpInterface allocator : allocators) {
- for (MemorySlot slot : allocator.getPromotableSlots()) {
- if (slot.ptr.use_empty())
- continue;
-
- MemorySlotPromotionAnalyzer analyzer(slot, dominance, dataLayout);
- std::optional<MemorySlotPromotionInfo> info = analyzer.computeInfo();
- if (info) {
- MemorySlotPromoter(slot, allocator, builder, dominance, dataLayout,
- std::move(*info), statistics, blockIndexCache)
- .promoteSlot();
- promotedAny = true;
+ SmallVector<PromotableAllocationOpInterface> workList(allocators.begin(),
+ allocators.end());
+
+ SmallVector<PromotableAllocationOpInterface> newWorkList;
+ newWorkList.reserve(workList.size());
+ while (true) {
+ bool changesInThisRound = false;
+ for (PromotableAllocationOpInterface allocator : workList) {
+ bool changedAllocator = false;
+ for (MemorySlot slot : allocator.getPromotableSlots()) {
+ if (slot.ptr.use_empty())
+ continue;
+
+ MemorySlotPromotionAnalyzer analyzer(slot, dominance, dataLayout);
+ std::optional<MemorySlotPromotionInfo> info = analyzer.computeInfo();
+ if (info) {
+ std::optional<PromotableAllocationOpInterface> newAllocator =
+ MemorySlotPromoter(slot, allocator, builder, dominance,
+ dataLayout, std::move(*info), statistics,
+ blockIndexCache)
+ .promoteSlot();
+ changedAllocator = true;
+ // Add newly created allocators to the worklist for further
+ // processing.
+ if (newAllocator)
+ newWorkList.push_back(*newAllocator);
+
+ // A break is required, since promoting a slot may invalidate the
+ // remaining slots of an allocator.
+ break;
+ }
}
+ if (!changedAllocator)
+ newWorkList.push_back(allocator);
+ changesInThisRound |= changedAllocator;
}
+ if (!changesInThisRound)
+ break;
+ promotedAny = true;
+
+ // Swap the vector's backing memory and clear the entries in newWorkList
+ // afterwards. This ensures that additional heap allocations can be avoided.
+ workList.swap(newWorkList);
+ newWorkList.clear();
}
return success(promotedAny);
@@ -677,22 +710,16 @@ struct Mem2Reg : impl::Mem2RegBase<Mem2Reg> {
OpBuilder builder(®ion.front(), region.front().begin());
- // Promoting a slot can allow for further promotion of other slots,
- // promotion is tried until no promotion succeeds.
- while (true) {
- SmallVector<PromotableAllocationOpInterface> allocators;
- // Build a list of allocators to attempt to promote the slots of.
- region.walk([&](PromotableAllocationOpInterface allocator) {
- allocators.emplace_back(allocator);
- });
-
- // Attempt promoting until no promotion succeeds.
- if (failed(tryToPromoteMemorySlots(allocators, builder, dataLayout,
- dominance, statistics)))
- break;
+ SmallVector<PromotableAllocationOpInterface> allocators;
+ // Build a list of allocators to attempt to promote the slots of.
+ region.walk([&](PromotableAllocationOpInterface allocator) {
+ allocators.emplace_back(allocator);
+ });
+ // Attempt promoting as many of the slots as possible.
+ if (succeeded(tryToPromoteMemorySlots(allocators, builder, dataLayout,
+ dominance, statistics)))
changed = true;
- }
}
if (!changed)
markAllAnalysesPreserved();
diff --git a/mlir/test/Transforms/mem2reg.mlir b/mlir/test/Transforms/mem2reg.mlir
new file mode 100644
index 0000000000000..daeaa2da07634
--- /dev/null
+++ b/mlir/test/Transforms/mem2reg.mlir
@@ -0,0 +1,28 @@
+// RUN: mlir-opt %s --pass-pipeline='builtin.module(func.func(mem2reg))' --split-input-file | FileCheck %s
+
+// Verifies that allocators with multiple slots are handled properly.
+
+// CHECK-LABEL: func.func @multi_slot_alloca
+func.func @multi_slot_alloca() -> (i32, i32) {
+ // CHECK-NOT: test.multi_slot_alloca
+ %1, %2 = test.multi_slot_alloca : () -> (memref<i32>, memref<i32>)
+ %3 = memref.load %1[] : memref<i32>
+ %4 = memref.load %2[] : memref<i32>
+ return %3, %4 : i32, i32
+}
+
+// -----
+
+// Verifies that a multi-slot allocator can be partially promoted.
+
+func.func private @consumer(memref<i32>)
+
+// CHECK-LABEL: func.func @multi_slot_alloca_only_second
+func.func @multi_slot_alloca_only_second() -> (i32, i32) {
+ // CHECK: %{{[[:alnum:]]+}} = test.multi_slot_alloca
+ %1, %2 = test.multi_slot_alloca : () -> (memref<i32>, memref<i32>)
+ func.call @consumer(%1) : (memref<i32>) -> ()
+ %3 = memref.load %1[] : memref<i32>
+ %4 = memref.load %2[] : memref<i32>
+ return %3, %4 : i32, i32
+}
diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
index 08df2e5e12286..d22d48b139a04 100644
--- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
+++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
@@ -11,6 +11,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Interfaces/FunctionImplementation.h"
+#include "mlir/Interfaces/MemorySlotInterfaces.h"
using namespace mlir;
using namespace test;
@@ -1172,3 +1173,61 @@ void TestOpWithVersionedProperties::writeToMlirBytecode(
writer.writeVarInt(prop.value1);
writer.writeVarInt(prop.value2);
}
+
+//===----------------------------------------------------------------------===//
+// TestMultiSlotAlloca
+//===----------------------------------------------------------------------===//
+
+llvm::SmallVector<MemorySlot> TestMultiSlotAlloca::getPromotableSlots() {
+ SmallVector<MemorySlot> slots;
+ for (Value result : getResults()) { // Each memref result is exposed as its own slot.
+ slots.push_back(MemorySlot{
+ result, cast<MemRefType>(result.getType()).getElementType()}); // Slot type = memref element type.
+ }
+ return slots;
+}
+
+Value TestMultiSlotAlloca::getDefaultValue(const MemorySlot &slot,
+ OpBuilder &builder) {
+ return builder.create<TestOpConstant>(getLoc(), slot.elemType, // Default: a test constant of the slot type.
+ builder.getI32IntegerAttr(42)); // Fixed value 42; results are restricted to memref<i32> by the op definition.
+}
+
+void TestMultiSlotAlloca::handleBlockArgument(const MemorySlot &slot,
+ BlockArgument argument,
+ OpBuilder &builder) {
+ // Intentionally a no-op: block-argument bookkeeping is not relevant for testing.
+}
+
+std::optional<PromotableAllocationOpInterface>
+TestMultiSlotAlloca::handlePromotionComplete(const MemorySlot &slot,
+ Value defaultValue,
+ OpBuilder &builder) {
+ if (defaultValue && defaultValue.use_empty()) // Drop the default value if promotion left it unused.
+ defaultValue.getDefiningOp()->erase();
+ // If the promoted slot was the only result, nothing is left to allocate: erase the op.
+ if (getNumResults() == 1) {
+ erase();
+ return std::nullopt;
+ }
+ // Otherwise rebuild the op with one result per slot that was not promoted.
+ SmallVector<Type> newTypes;
+ SmallVector<Value> remainingValues;
+ // Keep every result except the pointer of the slot that was just promoted.
+ for (Value oldResult : getResults()) {
+ if (oldResult == slot.ptr)
+ continue;
+ remainingValues.push_back(oldResult);
+ newTypes.push_back(oldResult.getType());
+ }
+ // Create the replacement right before this op and redirect the surviving results' uses.
+ OpBuilder::InsertionGuard guard(builder);
+ builder.setInsertionPoint(*this);
+ auto replacement = builder.create<TestMultiSlotAlloca>(getLoc(), newTypes);
+ for (auto [oldResult, newResult] :
+ llvm::zip_equal(remainingValues, replacement.getResults()))
+ oldResult.replaceAllUsesWith(newResult);
+ // Return the new allocator so Mem2Reg can keep promoting its remaining slots.
+ erase();
+ return replacement;
+}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.h b/mlir/test/lib/Dialect/Test/TestOps.h
index f9925855bb9db..837ccca565926 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.h
+++ b/mlir/test/lib/Dialect/Test/TestOps.h
@@ -36,6 +36,7 @@
#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
+#include "mlir/Interfaces/MemorySlotInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 5352d574ac394..e16ea2407314e 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -28,6 +28,7 @@ include "mlir/Interfaces/DestinationStyleOpInterface.td"
include "mlir/Interfaces/InferIntRangeInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/LoopLikeInterface.td"
+include "mlir/Interfaces/MemorySlotInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
@@ -3167,4 +3168,14 @@ def TestOpOptionallyImplementingInterface
let arguments = (ins BoolAttr:$implementsInterface);
}
+//===----------------------------------------------------------------------===//
+// Test Mem2Reg
+//===----------------------------------------------------------------------===//
+
+def TestMultiSlotAlloca : TEST_Op<"multi_slot_alloca",
+ [DeclareOpInterfaceMethods<PromotableAllocationOpInterface>]> {
+ let results = (outs Variadic<MemRefOf<[I32]>>:$results); // One promotable i32 slot per result.
+ let assemblyFormat = "attr-dict `:` functional-type(operands, results)";
+}
+
#endif // TEST_OPS
More information about the Mlir-commits
mailing list