[Mlir-commits] [mlir] [MLIR][SROA] Reuse allocators to avoid rewalking the IR (PR #91971)
Christian Ulmann
llvmlistbot at llvm.org
Mon May 13 07:25:20 PDT 2024
https://github.com/Dinistro created https://github.com/llvm/llvm-project/pull/91971
This commit extends the SROA interfaces so that interface implementations can communicate newly created allocators to the algorithm. As a result, the SROA implementation no longer needs to re-walk the IR to find new allocators.
>From c60b96b7e377569cbec608cde9ee9733062ea9aa Mon Sep 17 00:00:00 2001
From: Christian Ulmann <christian.ulmann at nextsilicon.com>
Date: Mon, 13 May 2024 14:19:28 +0000
Subject: [PATCH] [MLIR][SROA] Reuse allocators to avoid rewalking the IR
This commit extends the SROA interfaces so that the implementations can
communicate newly created allocators to the algorithm. As a result, the
SROA implementation no longer needs to re-walk the IR to find new
allocators.
---
.../mlir/Interfaces/MemorySlotInterfaces.td | 10 ++-
mlir/include/mlir/Transforms/SROA.h | 6 +-
mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp | 13 +--
.../Dialect/MemRef/IR/MemRefMemorySlot.cpp | 13 +--
mlir/lib/Transforms/SROA.cpp | 86 ++++++++++++-------
mlir/test/Transforms/sroa.mlir | 31 +++++++
mlir/test/lib/Dialect/Test/TestOpDefs.cpp | 79 ++++++++++++++---
mlir/test/lib/Dialect/Test/TestOps.td | 5 +-
8 files changed, 185 insertions(+), 58 deletions(-)
create mode 100644 mlir/test/Transforms/sroa.mlir
diff --git a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
index e2409cbec5fde..29960f457373c 100644
--- a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
+++ b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
@@ -298,14 +298,20 @@ def DestructurableAllocationOpInterface
"destructure",
(ins "const ::mlir::DestructurableMemorySlot &":$slot,
"const ::llvm::SmallPtrSetImpl<::mlir::Attribute> &":$usedIndices,
- "::mlir::OpBuilder &":$builder)
+ "::mlir::OpBuilder &":$builder,
+ "::mlir::SmallVectorImpl<::mlir::DestructurableAllocationOpInterface> &":
+ $newAllocators)
>,
InterfaceMethod<[{
Hook triggered once the destructuring of a slot is complete, meaning the
original slot is no longer being refered to and could be deleted.
This will only be called for slots declared by this operation.
+
+ If this operation declares slots besides `slot`, must return the new
+ allocation op that replaces it and holds those slots, nullopt otherwise.
}],
- "void", "handleDestructuringComplete",
+ "::std::optional<::mlir::DestructurableAllocationOpInterface>",
+ "handleDestructuringComplete",
(ins "const ::mlir::DestructurableMemorySlot &":$slot,
"::mlir::OpBuilder &":$builder)
>,
diff --git a/mlir/include/mlir/Transforms/SROA.h b/mlir/include/mlir/Transforms/SROA.h
index fa84fb1eae73a..d48f809d30766 100644
--- a/mlir/include/mlir/Transforms/SROA.h
+++ b/mlir/include/mlir/Transforms/SROA.h
@@ -27,8 +27,10 @@ struct SROAStatistics {
llvm::Statistic *maxSubelementAmount = nullptr;
};
-/// Attempts to destructure the slots of destructurable allocators. Returns
-/// failure if no slot was destructured.
+/// Attempts to destructure the slots of destructurable allocators. Iteratively
+/// retries the destructuring of all slots as destructuring one slot might
+/// enable subsequent destructuring. Returns failure if no slot was
+/// destructured.
LogicalResult tryToDestructureMemorySlots(
ArrayRef<DestructurableAllocationOpInterface> allocators,
OpBuilder &builder, const DataLayout &dataLayout,
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
index 4fdf847a559ce..3f1e5b1773bf7 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
@@ -77,10 +77,10 @@ SmallVector<DestructurableMemorySlot> LLVM::AllocaOp::getDestructurableSlots() {
*destructuredType}};
}
-DenseMap<Attribute, MemorySlot>
-LLVM::AllocaOp::destructure(const DestructurableMemorySlot &slot,
- const SmallPtrSetImpl<Attribute> &usedIndices,
- OpBuilder &builder) {
+DenseMap<Attribute, MemorySlot> LLVM::AllocaOp::destructure(
+ const DestructurableMemorySlot &slot,
+ const SmallPtrSetImpl<Attribute> &usedIndices, OpBuilder &builder,
+ SmallVectorImpl<DestructurableAllocationOpInterface> &newAllocators) {
assert(slot.ptr == getResult());
builder.setInsertionPointAfter(*this);
@@ -92,16 +92,19 @@ LLVM::AllocaOp::destructure(const DestructurableMemorySlot &slot,
auto subAlloca = builder.create<LLVM::AllocaOp>(
getLoc(), LLVM::LLVMPointerType::get(getContext()), elemType,
getArraySize());
+ newAllocators.push_back(subAlloca);
slotMap.try_emplace<MemorySlot>(index, {subAlloca.getResult(), elemType});
}
return slotMap;
}
-void LLVM::AllocaOp::handleDestructuringComplete(
+std::optional<DestructurableAllocationOpInterface>
+LLVM::AllocaOp::handleDestructuringComplete(
const DestructurableMemorySlot &slot, OpBuilder &builder) {
assert(slot.ptr == getResult());
this->erase();
+ return std::nullopt;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
index e30598e6878f4..631dee2d40538 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
@@ -126,10 +126,10 @@ memref::AllocaOp::getDestructurableSlots() {
DestructurableMemorySlot{{getMemref(), memrefType}, *destructuredType}};
}
-DenseMap<Attribute, MemorySlot>
-memref::AllocaOp::destructure(const DestructurableMemorySlot &slot,
- const SmallPtrSetImpl<Attribute> &usedIndices,
- OpBuilder &builder) {
+DenseMap<Attribute, MemorySlot> memref::AllocaOp::destructure(
+ const DestructurableMemorySlot &slot,
+ const SmallPtrSetImpl<Attribute> &usedIndices, OpBuilder &builder,
+ SmallVectorImpl<DestructurableAllocationOpInterface> &newAllocators) {
builder.setInsertionPointAfter(*this);
DenseMap<Attribute, MemorySlot> slotMap;
@@ -139,6 +139,7 @@ memref::AllocaOp::destructure(const DestructurableMemorySlot &slot,
Type elemType = memrefType.getTypeAtIndex(usedIndex);
MemRefType elemPtr = MemRefType::get({}, elemType);
auto subAlloca = builder.create<memref::AllocaOp>(getLoc(), elemPtr);
+ newAllocators.push_back(subAlloca);
slotMap.try_emplace<MemorySlot>(usedIndex,
{subAlloca.getResult(), elemType});
}
@@ -146,10 +147,12 @@ memref::AllocaOp::destructure(const DestructurableMemorySlot &slot,
return slotMap;
}
-void memref::AllocaOp::handleDestructuringComplete(
+std::optional<DestructurableAllocationOpInterface>
+memref::AllocaOp::handleDestructuringComplete(
const DestructurableMemorySlot &slot, OpBuilder &builder) {
assert(slot.ptr == getResult());
this->erase();
+ return std::nullopt;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Transforms/SROA.cpp b/mlir/lib/Transforms/SROA.cpp
index 4e28fa687ffd4..67cbade07bc94 100644
--- a/mlir/lib/Transforms/SROA.cpp
+++ b/mlir/lib/Transforms/SROA.cpp
@@ -132,16 +132,17 @@ computeDestructuringInfo(DestructurableMemorySlot &slot,
/// Performs the destructuring of a destructible slot given associated
/// destructuring information. The provided slot will be destructured in
/// subslots as specified by its allocator.
-static void destructureSlot(DestructurableMemorySlot &slot,
- DestructurableAllocationOpInterface allocator,
- OpBuilder &builder, const DataLayout &dataLayout,
- MemorySlotDestructuringInfo &info,
- const SROAStatistics &statistics) {
+static void destructureSlot(
+ DestructurableMemorySlot &slot,
+ DestructurableAllocationOpInterface allocator, OpBuilder &builder,
+ const DataLayout &dataLayout, MemorySlotDestructuringInfo &info,
+ SmallVectorImpl<DestructurableAllocationOpInterface> &newAllocators,
+ const SROAStatistics &statistics) {
OpBuilder::InsertionGuard guard(builder);
builder.setInsertionPointToStart(slot.ptr.getParentBlock());
DenseMap<Attribute, MemorySlot> subslots =
- allocator.destructure(slot, info.usedIndices, builder);
+ allocator.destructure(slot, info.usedIndices, builder, newAllocators);
if (statistics.slotsWithMemoryBenefit &&
slot.elementPtrs.size() != info.usedIndices.size())
@@ -185,7 +186,11 @@ static void destructureSlot(DestructurableMemorySlot &slot,
if (statistics.destructuredAmount)
(*statistics.destructuredAmount)++;
- allocator.handleDestructuringComplete(slot, builder);
+ std::optional<DestructurableAllocationOpInterface> newAllocator =
+ allocator.handleDestructuringComplete(slot, builder);
+ // Add newly created allocators to the worklist for further processing.
+ if (newAllocator)
+ newAllocators.push_back(*newAllocator);
}
LogicalResult mlir::tryToDestructureMemorySlots(
@@ -194,16 +199,44 @@ LogicalResult mlir::tryToDestructureMemorySlots(
SROAStatistics statistics) {
bool destructuredAny = false;
- for (DestructurableAllocationOpInterface allocator : allocators) {
- for (DestructurableMemorySlot slot : allocator.getDestructurableSlots()) {
- std::optional<MemorySlotDestructuringInfo> info =
- computeDestructuringInfo(slot, dataLayout);
- if (!info)
- continue;
+ SmallVector<DestructurableAllocationOpInterface> workList(allocators.begin(),
+ allocators.end());
+ SmallVector<DestructurableAllocationOpInterface> newWorkList;
+ newWorkList.reserve(allocators.size());
+ // Destructuring a slot can allow for further destructuring of other
+ // slots, destructuring is tried until no destructuring succeeds.
+ while (true) {
+ bool changesInThisRound = false;
+
+ for (DestructurableAllocationOpInterface allocator : workList) {
+ bool destructuredAnySlot = false;
+ for (DestructurableMemorySlot slot : allocator.getDestructurableSlots()) {
+ std::optional<MemorySlotDestructuringInfo> info =
+ computeDestructuringInfo(slot, dataLayout);
+ if (!info)
+ continue;
- destructureSlot(slot, allocator, builder, dataLayout, *info, statistics);
- destructuredAny = true;
+ destructureSlot(slot, allocator, builder, dataLayout, *info,
+ newWorkList, statistics);
+ destructuredAnySlot = true;
+
+ // Stop after one slot: destructuring may erase or replace the allocator,
+ // invalidating its remaining slots (a replacement is re-queued).
+ break;
+ }
+ if (!destructuredAnySlot)
+ newWorkList.push_back(allocator);
+ changesInThisRound |= destructuredAnySlot;
}
+
+ if (!changesInThisRound)
+ break;
+ destructuredAny |= changesInThisRound;
+
+ // Swap the vector's backing memory and clear the entries in newWorkList
+ // afterwards. This ensures that additional heap allocations can be avoided.
+ workList.swap(newWorkList);
+ newWorkList.clear();
}
return success(destructuredAny);
@@ -230,23 +263,16 @@ struct SROA : public impl::SROABase<SROA> {
OpBuilder builder(®ion.front(), region.front().begin());
- // Destructuring a slot can allow for further destructuring of other
- // slots, destructuring is tried until no destructuring succeeds.
- while (true) {
- SmallVector<DestructurableAllocationOpInterface> allocators;
- // Build a list of allocators to attempt to destructure the slots of.
- // TODO: Update list on the fly to avoid repeated visiting of the same
- // allocators.
- region.walk([&](DestructurableAllocationOpInterface allocator) {
- allocators.emplace_back(allocator);
- });
-
- if (failed(tryToDestructureMemorySlots(allocators, builder, dataLayout,
- statistics)))
- break;
+ SmallVector<DestructurableAllocationOpInterface> allocators;
+ // Build a list of allocators to attempt to destructure the slots of.
+ region.walk([&](DestructurableAllocationOpInterface allocator) {
+ allocators.emplace_back(allocator);
+ });
+ // Attempt to destructure as many slots as possible.
+ if (succeeded(tryToDestructureMemorySlots(allocators, builder, dataLayout,
+ statistics)))
changed = true;
- }
}
if (!changed)
markAllAnalysesPreserved();
diff --git a/mlir/test/Transforms/sroa.mlir b/mlir/test/Transforms/sroa.mlir
new file mode 100644
index 0000000000000..c9e80a6cf8dd1
--- /dev/null
+++ b/mlir/test/Transforms/sroa.mlir
@@ -0,0 +1,31 @@
+// RUN: mlir-opt %s --pass-pipeline='builtin.module(func.func(sroa))' --split-input-file | FileCheck %s
+
+// Verifies that allocators with multiple slots are handled properly.
+
+// CHECK-LABEL: func.func @multi_slot_alloca
+func.func @multi_slot_alloca() -> (i32, i32) {
+ %0 = arith.constant 0 : index
+ %1, %2 = test.multi_slot_alloca : () -> (memref<2xi32>, memref<4xi32>)
+ // CHECK-COUNT-2: test.multi_slot_alloca : () -> memref<i32>
+ %3 = memref.load %1[%0] {first}: memref<2xi32>
+ %4 = memref.load %2[%0] {second} : memref<4xi32>
+ return %3, %4 : i32, i32
+}
+
+// -----
+
+// Verifies that a multi slot allocator can be partially destructured.
+
+func.func private @consumer(memref<2xi32>)
+
+// CHECK-LABEL: func.func @multi_slot_alloca_only_second
+func.func @multi_slot_alloca_only_second() -> (i32, i32) {
+ %0 = arith.constant 0 : index
+ // CHECK: test.multi_slot_alloca : () -> memref<2xi32>
+ // CHECK: test.multi_slot_alloca : () -> memref<i32>
+ %1, %2 = test.multi_slot_alloca : () -> (memref<2xi32>, memref<4xi32>)
+ func.call @consumer(%1) : (memref<2xi32>) -> ()
+ %3 = memref.load %1[%0] : memref<2xi32>
+ %4 = memref.load %2[%0] : memref<4xi32>
+ return %3, %4 : i32, i32
+}
diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
index d22d48b139a04..c8949e8d1af72 100644
--- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
+++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
@@ -1199,22 +1199,20 @@ void TestMultiSlotAlloca::handleBlockArgument(const MemorySlot &slot,
// Not relevant for testing.
}
-std::optional<PromotableAllocationOpInterface>
-TestMultiSlotAlloca::handlePromotionComplete(const MemorySlot &slot,
- Value defaultValue,
- OpBuilder &builder) {
- if (defaultValue && defaultValue.use_empty())
- defaultValue.getDefiningOp()->erase();
+/// Replaces `oldOp` by an op without `slot`; erases it if no slots remain.
+static std::optional<TestMultiSlotAlloca>
+createNewMultiAllocaWithoutSlot(const MemorySlot &slot, OpBuilder &builder,
+ TestMultiSlotAlloca oldOp) {
- if (getNumResults() == 1) {
- erase();
+ if (oldOp.getNumResults() == 1) {
+ oldOp.erase();
return std::nullopt;
}
SmallVector<Type> newTypes;
SmallVector<Value> remainingValues;
- for (Value oldResult : getResults()) {
+ for (Value oldResult : oldOp.getResults()) {
if (oldResult == slot.ptr)
continue;
remainingValues.push_back(oldResult);
@@ -1222,12 +1220,69 @@ TestMultiSlotAlloca::handlePromotionComplete(const MemorySlot &slot,
}
OpBuilder::InsertionGuard guard(builder);
- builder.setInsertionPoint(*this);
- auto replacement = builder.create<TestMultiSlotAlloca>(getLoc(), newTypes);
+ builder.setInsertionPoint(oldOp);
+ auto replacement =
+ builder.create<TestMultiSlotAlloca>(oldOp->getLoc(), newTypes);
for (auto [oldResult, newResult] :
llvm::zip_equal(remainingValues, replacement.getResults()))
oldResult.replaceAllUsesWith(newResult);
- erase();
+ oldOp.erase();
return replacement;
}
+
+std::optional<PromotableAllocationOpInterface>
+TestMultiSlotAlloca::handlePromotionComplete(const MemorySlot &slot,
+ Value defaultValue,
+ OpBuilder &builder) {
+ if (defaultValue && defaultValue.use_empty())
+ defaultValue.getDefiningOp()->erase();
+ return createNewMultiAllocaWithoutSlot(slot, builder, *this);
+}
+
+SmallVector<DestructurableMemorySlot>
+TestMultiSlotAlloca::getDestructurableSlots() {
+ SmallVector<DestructurableMemorySlot> slots;
+ for (Value result : getResults()) {
+ auto memrefType = cast<MemRefType>(result.getType());
+ auto destructurable =
+ llvm::dyn_cast<DestructurableTypeInterface>(memrefType);
+ if (!destructurable)
+ continue;
+
+ std::optional<DenseMap<Attribute, Type>> destructuredType =
+ destructurable.getSubelementIndexMap();
+ if (!destructuredType)
+ continue;
+ slots.emplace_back(
+ DestructurableMemorySlot{{result, memrefType}, *destructuredType});
+ }
+ return slots;
+}
+
+DenseMap<Attribute, MemorySlot> TestMultiSlotAlloca::destructure(
+ const DestructurableMemorySlot &slot,
+ const SmallPtrSetImpl<Attribute> &usedIndices, OpBuilder &builder,
+ SmallVectorImpl<DestructurableAllocationOpInterface> &newAllocators) {
+ OpBuilder::InsertionGuard guard(builder);
+ builder.setInsertionPointAfter(*this);
+
+ DenseMap<Attribute, MemorySlot> slotMap;
+
+ for (Attribute usedIndex : usedIndices) {
+ Type elemType = slot.elementPtrs.lookup(usedIndex);
+ MemRefType elemPtr = MemRefType::get({}, elemType);
+ auto subAlloca = builder.create<TestMultiSlotAlloca>(getLoc(), elemPtr);
+ newAllocators.push_back(subAlloca);
+ slotMap.try_emplace<MemorySlot>(usedIndex,
+ {subAlloca.getResult(0), elemType});
+ }
+
+ return slotMap;
+}
+
+std::optional<DestructurableAllocationOpInterface>
+TestMultiSlotAlloca::handleDestructuringComplete(
+ const DestructurableMemorySlot &slot, OpBuilder &builder) {
+ return createNewMultiAllocaWithoutSlot(slot, builder, *this);
+}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index e16ea2407314e..7fc3d22d18958 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -3169,11 +3169,12 @@ def TestOpOptionallyImplementingInterface
}
//===----------------------------------------------------------------------===//
-// Test Mem2Reg
+// Test Mem2Reg & SROA
//===----------------------------------------------------------------------===//
def TestMultiSlotAlloca : TEST_Op<"multi_slot_alloca",
- [DeclareOpInterfaceMethods<PromotableAllocationOpInterface>]> {
+ [DeclareOpInterfaceMethods<PromotableAllocationOpInterface>,
+ DeclareOpInterfaceMethods<DestructurableAllocationOpInterface>]> {
let results = (outs Variadic<MemRefOf<[I32]>>:$results);
let assemblyFormat = "attr-dict `:` functional-type(operands, results)";
}
More information about the Mlir-commits
mailing list