[Mlir-commits] [mlir] [MLIR][SROA] Reuse allocators to avoid rewalking the IR (PR #91971)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon May 13 07:25:57 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-core
Author: Christian Ulmann (Dinistro)
<details>
<summary>Changes</summary>
This commit extends the SROA interfaces to ensure the interface instantiations can communicate newly created allocators to the algorithm. This ensures that the SROA implementation no longer needs to re-walk the IR to find new allocators.
---
Full diff: https://github.com/llvm/llvm-project/pull/91971.diff
8 Files Affected:
- (modified) mlir/include/mlir/Interfaces/MemorySlotInterfaces.td (+8-2)
- (modified) mlir/include/mlir/Transforms/SROA.h (+4-2)
- (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp (+8-5)
- (modified) mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp (+8-5)
- (modified) mlir/lib/Transforms/SROA.cpp (+56-30)
- (added) mlir/test/Transforms/sroa.mlir (+31)
- (modified) mlir/test/lib/Dialect/Test/TestOpDefs.cpp (+67-12)
- (modified) mlir/test/lib/Dialect/Test/TestOps.td (+3-2)
``````````diff
diff --git a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
index e2409cbec5fde..29960f457373c 100644
--- a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
+++ b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
@@ -298,14 +298,20 @@ def DestructurableAllocationOpInterface
"destructure",
(ins "const ::mlir::DestructurableMemorySlot &":$slot,
"const ::llvm::SmallPtrSetImpl<::mlir::Attribute> &":$usedIndices,
- "::mlir::OpBuilder &":$builder)
+ "::mlir::OpBuilder &":$builder,
+ "::mlir::SmallVectorImpl<::mlir::DestructurableAllocationOpInterface> &":
+ $newAllocators)
>,
InterfaceMethod<[{
Hook triggered once the destructuring of a slot is complete, meaning the
original slot is no longer being refered to and could be deleted.
This will only be called for slots declared by this operation.
+
+ Must return a new destructurable allocation op if this operation
+ produced multiple destructurable slots, nullopt otherwise.
}],
- "void", "handleDestructuringComplete",
+ "::std::optional<::mlir::DestructurableAllocationOpInterface>",
+ "handleDestructuringComplete",
(ins "const ::mlir::DestructurableMemorySlot &":$slot,
"::mlir::OpBuilder &":$builder)
>,
diff --git a/mlir/include/mlir/Transforms/SROA.h b/mlir/include/mlir/Transforms/SROA.h
index fa84fb1eae73a..d48f809d30766 100644
--- a/mlir/include/mlir/Transforms/SROA.h
+++ b/mlir/include/mlir/Transforms/SROA.h
@@ -27,8 +27,10 @@ struct SROAStatistics {
llvm::Statistic *maxSubelementAmount = nullptr;
};
-/// Attempts to destructure the slots of destructurable allocators. Returns
-/// failure if no slot was destructured.
+/// Attempts to destructure the slots of destructurable allocators. Iteratively
+/// retries the destructuring of all slots as destructuring one slot might
+/// enable subsequent destructuring. Returns failure if no slot was
+/// destructured.
LogicalResult tryToDestructureMemorySlots(
ArrayRef<DestructurableAllocationOpInterface> allocators,
OpBuilder &builder, const DataLayout &dataLayout,
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
index 4fdf847a559ce..3f1e5b1773bf7 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
@@ -77,10 +77,10 @@ SmallVector<DestructurableMemorySlot> LLVM::AllocaOp::getDestructurableSlots() {
*destructuredType}};
}
-DenseMap<Attribute, MemorySlot>
-LLVM::AllocaOp::destructure(const DestructurableMemorySlot &slot,
- const SmallPtrSetImpl<Attribute> &usedIndices,
- OpBuilder &builder) {
+DenseMap<Attribute, MemorySlot> LLVM::AllocaOp::destructure(
+ const DestructurableMemorySlot &slot,
+ const SmallPtrSetImpl<Attribute> &usedIndices, OpBuilder &builder,
+ SmallVectorImpl<DestructurableAllocationOpInterface> &newAllocators) {
assert(slot.ptr == getResult());
builder.setInsertionPointAfter(*this);
@@ -92,16 +92,19 @@ LLVM::AllocaOp::destructure(const DestructurableMemorySlot &slot,
auto subAlloca = builder.create<LLVM::AllocaOp>(
getLoc(), LLVM::LLVMPointerType::get(getContext()), elemType,
getArraySize());
+ newAllocators.push_back(subAlloca);
slotMap.try_emplace<MemorySlot>(index, {subAlloca.getResult(), elemType});
}
return slotMap;
}
-void LLVM::AllocaOp::handleDestructuringComplete(
+std::optional<DestructurableAllocationOpInterface>
+LLVM::AllocaOp::handleDestructuringComplete(
const DestructurableMemorySlot &slot, OpBuilder &builder) {
assert(slot.ptr == getResult());
this->erase();
+ return std::nullopt;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
index e30598e6878f4..631dee2d40538 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
@@ -126,10 +126,10 @@ memref::AllocaOp::getDestructurableSlots() {
DestructurableMemorySlot{{getMemref(), memrefType}, *destructuredType}};
}
-DenseMap<Attribute, MemorySlot>
-memref::AllocaOp::destructure(const DestructurableMemorySlot &slot,
- const SmallPtrSetImpl<Attribute> &usedIndices,
- OpBuilder &builder) {
+DenseMap<Attribute, MemorySlot> memref::AllocaOp::destructure(
+ const DestructurableMemorySlot &slot,
+ const SmallPtrSetImpl<Attribute> &usedIndices, OpBuilder &builder,
+ SmallVectorImpl<DestructurableAllocationOpInterface> &newAllocators) {
builder.setInsertionPointAfter(*this);
DenseMap<Attribute, MemorySlot> slotMap;
@@ -139,6 +139,7 @@ memref::AllocaOp::destructure(const DestructurableMemorySlot &slot,
Type elemType = memrefType.getTypeAtIndex(usedIndex);
MemRefType elemPtr = MemRefType::get({}, elemType);
auto subAlloca = builder.create<memref::AllocaOp>(getLoc(), elemPtr);
+ newAllocators.push_back(subAlloca);
slotMap.try_emplace<MemorySlot>(usedIndex,
{subAlloca.getResult(), elemType});
}
@@ -146,10 +147,12 @@ memref::AllocaOp::destructure(const DestructurableMemorySlot &slot,
return slotMap;
}
-void memref::AllocaOp::handleDestructuringComplete(
+std::optional<DestructurableAllocationOpInterface>
+memref::AllocaOp::handleDestructuringComplete(
const DestructurableMemorySlot &slot, OpBuilder &builder) {
assert(slot.ptr == getResult());
this->erase();
+ return std::nullopt;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Transforms/SROA.cpp b/mlir/lib/Transforms/SROA.cpp
index 4e28fa687ffd4..67cbade07bc94 100644
--- a/mlir/lib/Transforms/SROA.cpp
+++ b/mlir/lib/Transforms/SROA.cpp
@@ -132,16 +132,17 @@ computeDestructuringInfo(DestructurableMemorySlot &slot,
/// Performs the destructuring of a destructible slot given associated
/// destructuring information. The provided slot will be destructured in
/// subslots as specified by its allocator.
-static void destructureSlot(DestructurableMemorySlot &slot,
- DestructurableAllocationOpInterface allocator,
- OpBuilder &builder, const DataLayout &dataLayout,
- MemorySlotDestructuringInfo &info,
- const SROAStatistics &statistics) {
+static void destructureSlot(
+ DestructurableMemorySlot &slot,
+ DestructurableAllocationOpInterface allocator, OpBuilder &builder,
+ const DataLayout &dataLayout, MemorySlotDestructuringInfo &info,
+ SmallVectorImpl<DestructurableAllocationOpInterface> &newAllocators,
+ const SROAStatistics &statistics) {
OpBuilder::InsertionGuard guard(builder);
builder.setInsertionPointToStart(slot.ptr.getParentBlock());
DenseMap<Attribute, MemorySlot> subslots =
- allocator.destructure(slot, info.usedIndices, builder);
+ allocator.destructure(slot, info.usedIndices, builder, newAllocators);
if (statistics.slotsWithMemoryBenefit &&
slot.elementPtrs.size() != info.usedIndices.size())
@@ -185,7 +186,11 @@ static void destructureSlot(DestructurableMemorySlot &slot,
if (statistics.destructuredAmount)
(*statistics.destructuredAmount)++;
- allocator.handleDestructuringComplete(slot, builder);
+ std::optional<DestructurableAllocationOpInterface> newAllocator =
+ allocator.handleDestructuringComplete(slot, builder);
+ // Add newly created allocators to the worklist for further processing.
+ if (newAllocator)
+ newAllocators.push_back(*newAllocator);
}
LogicalResult mlir::tryToDestructureMemorySlots(
@@ -194,16 +199,44 @@ LogicalResult mlir::tryToDestructureMemorySlots(
SROAStatistics statistics) {
bool destructuredAny = false;
- for (DestructurableAllocationOpInterface allocator : allocators) {
- for (DestructurableMemorySlot slot : allocator.getDestructurableSlots()) {
- std::optional<MemorySlotDestructuringInfo> info =
- computeDestructuringInfo(slot, dataLayout);
- if (!info)
- continue;
+ SmallVector<DestructurableAllocationOpInterface> workList(allocators.begin(),
+ allocators.end());
+ SmallVector<DestructurableAllocationOpInterface> newWorkList;
+ newWorkList.reserve(allocators.size());
+ // Destructuring a slot can allow for further destructuring of other
+ // slots, destructuring is tried until no destructuring succeeds.
+ while (true) {
+ bool changesInThisRound = false;
+
+ for (DestructurableAllocationOpInterface allocator : workList) {
+ bool destructuredAnySlot = false;
+ for (DestructurableMemorySlot slot : allocator.getDestructurableSlots()) {
+ std::optional<MemorySlotDestructuringInfo> info =
+ computeDestructuringInfo(slot, dataLayout);
+ if (!info)
+ continue;
- destructureSlot(slot, allocator, builder, dataLayout, *info, statistics);
- destructuredAny = true;
+ destructureSlot(slot, allocator, builder, dataLayout, *info,
+ newWorkList, statistics);
+ destructuredAnySlot = true;
+
+ // A break is required, since destructuring a slot may invalidate the
+ // remaining slots of an allocator.
+ break;
+ }
+ if (!destructuredAnySlot)
+ newWorkList.push_back(allocator);
+ changesInThisRound |= destructuredAnySlot;
}
+
+ if (!changesInThisRound)
+ break;
+ destructuredAny |= changesInThisRound;
+
+ // Swap the vector's backing memory and clear the entries in newWorkList
+ // afterwards. This ensures that additional heap allocations can be avoided.
+ workList.swap(newWorkList);
+ newWorkList.clear();
}
return success(destructuredAny);
@@ -230,23 +263,16 @@ struct SROA : public impl::SROABase<SROA> {
OpBuilder builder(&region.front(), region.front().begin());
- // Destructuring a slot can allow for further destructuring of other
- // slots, destructuring is tried until no destructuring succeeds.
- while (true) {
- SmallVector<DestructurableAllocationOpInterface> allocators;
- // Build a list of allocators to attempt to destructure the slots of.
- // TODO: Update list on the fly to avoid repeated visiting of the same
- // allocators.
- region.walk([&](DestructurableAllocationOpInterface allocator) {
- allocators.emplace_back(allocator);
- });
-
- if (failed(tryToDestructureMemorySlots(allocators, builder, dataLayout,
- statistics)))
- break;
+ SmallVector<DestructurableAllocationOpInterface> allocators;
+ // Build a list of allocators to attempt to destructure the slots of.
+ region.walk([&](DestructurableAllocationOpInterface allocator) {
+ allocators.emplace_back(allocator);
+ });
+ // Attempt to destructure as many slots as possible.
+ if (succeeded(tryToDestructureMemorySlots(allocators, builder, dataLayout,
+ statistics)))
changed = true;
- }
}
if (!changed)
markAllAnalysesPreserved();
diff --git a/mlir/test/Transforms/sroa.mlir b/mlir/test/Transforms/sroa.mlir
new file mode 100644
index 0000000000000..c9e80a6cf8dd1
--- /dev/null
+++ b/mlir/test/Transforms/sroa.mlir
@@ -0,0 +1,31 @@
+// RUN: mlir-opt %s --pass-pipeline='builtin.module(func.func(sroa))' --split-input-file | FileCheck %s
+
+// Verifies that allocators with multiple slots are handled properly.
+
+// CHECK-LABEL: func.func @multi_slot_alloca
+func.func @multi_slot_alloca() -> (i32, i32) {
+ %0 = arith.constant 0 : index
+ %1, %2 = test.multi_slot_alloca : () -> (memref<2xi32>, memref<4xi32>)
+ // CHECK-COUNT-2: test.multi_slot_alloca : () -> memref<i32>
+ %3 = memref.load %1[%0] {first}: memref<2xi32>
+ %4 = memref.load %2[%0] {second} : memref<4xi32>
+ return %3, %4 : i32, i32
+}
+
+// -----
+
+// Verifies that a multi slot allocator can be partially destructured.
+
+func.func private @consumer(memref<2xi32>)
+
+// CHECK-LABEL: func.func @multi_slot_alloca_only_second
+func.func @multi_slot_alloca_only_second() -> (i32, i32) {
+ %0 = arith.constant 0 : index
+ // CHECK: test.multi_slot_alloca : () -> memref<2xi32>
+ // CHECK: test.multi_slot_alloca : () -> memref<i32>
+ %1, %2 = test.multi_slot_alloca : () -> (memref<2xi32>, memref<4xi32>)
+ func.call @consumer(%1) : (memref<2xi32>) -> ()
+ %3 = memref.load %1[%0] : memref<2xi32>
+ %4 = memref.load %2[%0] : memref<4xi32>
+ return %3, %4 : i32, i32
+}
diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
index d22d48b139a04..c8949e8d1af72 100644
--- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
+++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
@@ -1199,22 +1199,20 @@ void TestMultiSlotAlloca::handleBlockArgument(const MemorySlot &slot,
// Not relevant for testing.
}
-std::optional<PromotableAllocationOpInterface>
-TestMultiSlotAlloca::handlePromotionComplete(const MemorySlot &slot,
- Value defaultValue,
- OpBuilder &builder) {
- if (defaultValue && defaultValue.use_empty())
- defaultValue.getDefiningOp()->erase();
+/// Creates a new TestMultiSlotAlloca operation, just without the `slot`.
+static std::optional<TestMultiSlotAlloca>
+createNewMultiAllocaWithoutSlot(const MemorySlot &slot, OpBuilder &builder,
+ TestMultiSlotAlloca oldOp) {
- if (getNumResults() == 1) {
- erase();
+ if (oldOp.getNumResults() == 1) {
+ oldOp.erase();
return std::nullopt;
}
SmallVector<Type> newTypes;
SmallVector<Value> remainingValues;
- for (Value oldResult : getResults()) {
+ for (Value oldResult : oldOp.getResults()) {
if (oldResult == slot.ptr)
continue;
remainingValues.push_back(oldResult);
@@ -1222,12 +1220,69 @@ TestMultiSlotAlloca::handlePromotionComplete(const MemorySlot &slot,
}
OpBuilder::InsertionGuard guard(builder);
- builder.setInsertionPoint(*this);
- auto replacement = builder.create<TestMultiSlotAlloca>(getLoc(), newTypes);
+ builder.setInsertionPoint(oldOp);
+ auto replacement =
+ builder.create<TestMultiSlotAlloca>(oldOp->getLoc(), newTypes);
for (auto [oldResult, newResult] :
llvm::zip_equal(remainingValues, replacement.getResults()))
oldResult.replaceAllUsesWith(newResult);
- erase();
+ oldOp.erase();
return replacement;
}
+
+std::optional<PromotableAllocationOpInterface>
+TestMultiSlotAlloca::handlePromotionComplete(const MemorySlot &slot,
+ Value defaultValue,
+ OpBuilder &builder) {
+ if (defaultValue && defaultValue.use_empty())
+ defaultValue.getDefiningOp()->erase();
+ return createNewMultiAllocaWithoutSlot(slot, builder, *this);
+}
+
+SmallVector<DestructurableMemorySlot>
+TestMultiSlotAlloca::getDestructurableSlots() {
+ SmallVector<DestructurableMemorySlot> slots;
+ for (Value result : getResults()) {
+ auto memrefType = cast<MemRefType>(result.getType());
+ auto destructurable =
+ llvm::dyn_cast<DestructurableTypeInterface>(memrefType);
+ if (!destructurable)
+ continue;
+
+ std::optional<DenseMap<Attribute, Type>> destructuredType =
+ destructurable.getSubelementIndexMap();
+ if (!destructuredType)
+ continue;
+ slots.emplace_back(
+ DestructurableMemorySlot{{result, memrefType}, *destructuredType});
+ }
+ return slots;
+}
+
+DenseMap<Attribute, MemorySlot> TestMultiSlotAlloca::destructure(
+ const DestructurableMemorySlot &slot,
+ const SmallPtrSetImpl<Attribute> &usedIndices, OpBuilder &builder,
+ SmallVectorImpl<DestructurableAllocationOpInterface> &newAllocators) {
+ OpBuilder::InsertionGuard guard(builder);
+ builder.setInsertionPointAfter(*this);
+
+ DenseMap<Attribute, MemorySlot> slotMap;
+
+ for (Attribute usedIndex : usedIndices) {
+ Type elemType = slot.elementPtrs.lookup(usedIndex);
+ MemRefType elemPtr = MemRefType::get({}, elemType);
+ auto subAlloca = builder.create<TestMultiSlotAlloca>(getLoc(), elemPtr);
+ newAllocators.push_back(subAlloca);
+ slotMap.try_emplace<MemorySlot>(usedIndex,
+ {subAlloca.getResult(0), elemType});
+ }
+
+ return slotMap;
+}
+
+std::optional<DestructurableAllocationOpInterface>
+TestMultiSlotAlloca::handleDestructuringComplete(
+ const DestructurableMemorySlot &slot, OpBuilder &builder) {
+ return createNewMultiAllocaWithoutSlot(slot, builder, *this);
+}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index e16ea2407314e..7fc3d22d18958 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -3169,11 +3169,12 @@ def TestOpOptionallyImplementingInterface
}
//===----------------------------------------------------------------------===//
-// Test Mem2Reg
+// Test Mem2Reg & SROA
//===----------------------------------------------------------------------===//
def TestMultiSlotAlloca : TEST_Op<"multi_slot_alloca",
- [DeclareOpInterfaceMethods<PromotableAllocationOpInterface>]> {
+ [DeclareOpInterfaceMethods<PromotableAllocationOpInterface>,
+ DeclareOpInterfaceMethods<DestructurableAllocationOpInterface>]> {
let results = (outs Variadic<MemRefOf<[I32]>>:$results);
let assemblyFormat = "attr-dict `:` functional-type(operands, results)";
}
``````````
</details>
https://github.com/llvm/llvm-project/pull/91971
More information about the Mlir-commits
mailing list