[Mlir-commits] [mlir] Enable custom alloc-like ops in `promoteBufferResultsToOutParams` (PR #120288)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Dec 19 09:47:33 PST 2024
https://github.com/srcarroll updated https://github.com/llvm/llvm-project/pull/120288
>From 1db8e794fc718ba3720789659ebba47c9e7e05a8 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Tue, 17 Dec 2024 12:57:55 -0600
Subject: [PATCH 1/2] Allow any `AllocationOpInterface` op to be hoisted with the
`hoistStaticAllocs` option
---
.../Bufferization/Transforms/BufferResultsToOutParams.cpp | 4 +++-
1 file changed, 3 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
index b7755b2be8483b..b4d2d6b0c5da8f 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
@@ -6,6 +6,7 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.h"
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -121,7 +122,8 @@ static LogicalResult updateReturnOps(func::FuncOp func,
OpBuilder builder(op);
for (auto [orig, arg] : llvm::zip(copyIntoOutParams, appendedEntryArgs)) {
if (hoistStaticAllocs &&
- isa_and_nonnull<memref::AllocOp>(orig.getDefiningOp()) &&
+ isa_and_nonnull<bufferization::AllocationOpInterface>(
+ orig.getDefiningOp()) &&
mlir::cast<MemRefType>(orig.getType()).hasStaticShape()) {
orig.replaceAllUsesWith(arg);
orig.getDefiningOp()->erase();
>From bf145cfeb45263d879eb998ce47cf6dd93aa2b06 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Thu, 19 Dec 2024 11:46:37 -0600
Subject: [PATCH 2/2] Add and use a custom allocation function option
---
.../Dialect/Bufferization/Transforms/Passes.h | 13 +++++++++-
.../Transforms/BufferResultsToOutParams.cpp | 24 ++++++++++++++-----
2 files changed, 30 insertions(+), 7 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
index fe43a05c81fdc3..966438956fc6cd 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
@@ -6,6 +6,7 @@
namespace mlir {
class FunctionOpInterface;
+class MemRefType;
class ModuleOp;
class RewritePatternSet;
class OpBuilder;
@@ -38,7 +39,7 @@ std::unique_ptr<Pass> createOwnershipBasedBufferDeallocationPass(
DeallocationOptions options = DeallocationOptions());
/// Creates a pass that finds all temporary allocations
-/// and attempts to move the deallocation after the last user/dependency
+/// and attempts to move the deallocation after the last user/dependency
/// of the allocation, thereby optimizing allocation liveness.
std::unique_ptr<Pass> createOptimizeAllocationLivenessPass();
@@ -157,6 +158,12 @@ std::unique_ptr<Pass> createBufferLoopHoistingPass();
// Options struct for BufferResultsToOutParams pass.
// Note: defined only here, not in tablegen.
struct BufferResultsToOutParamsOpts {
+ /// Allocator function: Generate a memref allocation with the given type.
+ /// Since `promoteBufferResultsToOutParams` doesn't allow dynamically shaped
+ /// results, we don't allow passing a range of values for dynamic dims.
+ using AllocationFn =
+ std::function<FailureOr<Value>(OpBuilder &, Location, MemRefType)>;
+
/// Memcpy function: Generate a memcpy between two memrefs.
using MemCpyFn =
std::function<LogicalResult(OpBuilder &, Location, Value, Value)>;
@@ -167,6 +174,10 @@ struct BufferResultsToOutParamsOpts {
return true;
};
+ /// Allocation function; used to allocate a memref.
+ /// If this is empty, memref.alloc is used.
+ std::optional<AllocationFn> allocationFn;
+
/// Memcpy function; used to create a copy between two memrefs.
/// If this is empty, memref.copy is used.
std::optional<MemCpyFn> memCpyFn;
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
index b4d2d6b0c5da8f..545b6ca009c03c 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
@@ -22,6 +22,7 @@ namespace bufferization {
} // namespace mlir
using namespace mlir;
+using AllocationFn = bufferization::BufferResultsToOutParamsOpts::AllocationFn;
using MemCpyFn = bufferization::BufferResultsToOutParamsOpts::MemCpyFn;
/// Return `true` if the given MemRef type has a fully dynamic layout.
@@ -141,9 +142,8 @@ static LogicalResult updateReturnOps(func::FuncOp func,
// Updates all CallOps in the scope of the given ModuleOp by allocating
// temporary buffers for newly introduced out params.
-static LogicalResult
-updateCalls(ModuleOp module,
- const bufferization::BufferResultsToOutParamsOpts &options) {
+static LogicalResult updateCalls(ModuleOp module, AllocationFn allocationFn,
+ std::function<bool(func::FuncOp *)> filterFn) {
bool didFail = false;
SymbolTable symtab(module);
module.walk([&](func::CallOp op) {
@@ -154,7 +154,7 @@ updateCalls(ModuleOp module,
didFail = true;
return;
}
- if (!options.filterFn(&callee))
+ if (!filterFn(&callee))
return;
SmallVector<Value, 6> replaceWithNewCallResults;
SmallVector<Value, 6> replaceWithOutParams;
@@ -177,7 +177,13 @@ updateCalls(ModuleOp module,
auto allocType =
MemRefType::get(memrefType.getShape(), memrefType.getElementType(),
AffineMap(), memrefType.getMemorySpace());
- Value outParam = builder.create<memref::AllocOp>(op.getLoc(), allocType);
+ auto maybeOutParam = allocationFn(builder, op.getLoc(), allocType);
+ if (failed(maybeOutParam)) {
+ op.emitError() << "failed to create allocation op";
+ didFail = true;
+ return;
+ }
+ Value outParam = maybeOutParam.value();
if (!hasStaticIdentityLayout(memrefType)) {
// Layout maps are already checked in `updateFuncOp`.
assert(hasFullyDynamicLayoutMap(memrefType) &&
@@ -226,7 +232,13 @@ LogicalResult mlir::bufferization::promoteBufferResultsToOutParams(
return failure();
}
}
- if (failed(updateCalls(module, options)))
+ auto defaultAllocationFn = [](OpBuilder &builder, Location loc,
+ MemRefType type) {
+ return builder.create<memref::AllocOp>(loc, type).getResult();
+ };
+ if (failed(updateCalls(module,
+ options.allocationFn.value_or(defaultAllocationFn),
+ options.filterFn)))
return failure();
return success();
}
More information about the Mlir-commits
mailing list