[Mlir-commits] [mlir] Enable custom alloc-like ops in `promoteBufferResultsToOutParams` (PR #120288)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Dec 20 11:40:53 PST 2024
https://github.com/srcarroll updated https://github.com/llvm/llvm-project/pull/120288
>From 1db8e794fc718ba3720789659ebba47c9e7e05a8 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Tue, 17 Dec 2024 12:57:55 -0600
Subject: [PATCH 1/3] Enable any `AllocationOpInterface` with
`hoistStaticAllocs` option
---
.../Bufferization/Transforms/BufferResultsToOutParams.cpp | 4 +++-
1 file changed, 3 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
index b7755b2be8483b..b4d2d6b0c5da8f 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
@@ -6,6 +6,7 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.h"
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -121,7 +122,8 @@ static LogicalResult updateReturnOps(func::FuncOp func,
OpBuilder builder(op);
for (auto [orig, arg] : llvm::zip(copyIntoOutParams, appendedEntryArgs)) {
if (hoistStaticAllocs &&
- isa_and_nonnull<memref::AllocOp>(orig.getDefiningOp()) &&
+ isa_and_nonnull<bufferization::AllocationOpInterface>(
+ orig.getDefiningOp()) &&
mlir::cast<MemRefType>(orig.getType()).hasStaticShape()) {
orig.replaceAllUsesWith(arg);
orig.getDefiningOp()->erase();
>From bf145cfeb45263d879eb998ce47cf6dd93aa2b06 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Thu, 19 Dec 2024 11:46:37 -0600
Subject: [PATCH 2/3] add and use custom allocation function
---
.../Dialect/Bufferization/Transforms/Passes.h | 13 +++++++++-
.../Transforms/BufferResultsToOutParams.cpp | 24 ++++++++++++++-----
2 files changed, 30 insertions(+), 7 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
index fe43a05c81fdc3..966438956fc6cd 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
@@ -6,6 +6,7 @@
namespace mlir {
class FunctionOpInterface;
+class MemRefType;
class ModuleOp;
class RewritePatternSet;
class OpBuilder;
@@ -38,7 +39,7 @@ std::unique_ptr<Pass> createOwnershipBasedBufferDeallocationPass(
DeallocationOptions options = DeallocationOptions());
/// Creates a pass that finds all temporary allocations
-/// and attempts to move the deallocation after the last user/dependency
+/// and attempts to move the deallocation after the last user/dependency
/// of the allocation, thereby optimizing allocation liveness.
std::unique_ptr<Pass> createOptimizeAllocationLivenessPass();
@@ -157,6 +158,12 @@ std::unique_ptr<Pass> createBufferLoopHoistingPass();
// Options struct for BufferResultsToOutParams pass.
// Note: defined only here, not in tablegen.
struct BufferResultsToOutParamsOpts {
+ /// Allocator function: Generate a memref allocation with the given type.
+ /// Since `promoteBufferResultsToOutParams` doesn't allow dynamically shaped
+ /// results, we don't allow passing a range of values for dynamic dims.
+ using AllocationFn =
+ std::function<FailureOr<Value>(OpBuilder &, Location, MemRefType)>;
+
/// Memcpy function: Generate a memcpy between two memrefs.
using MemCpyFn =
std::function<LogicalResult(OpBuilder &, Location, Value, Value)>;
@@ -167,6 +174,10 @@ struct BufferResultsToOutParamsOpts {
return true;
};
+ /// Allocation function; used to allocate a memref.
+ /// If this is empty, memref.alloc is used
+ std::optional<AllocationFn> allocationFn;
+
/// Memcpy function; used to create a copy between two memrefs.
/// If this is empty, memref.copy is used.
std::optional<MemCpyFn> memCpyFn;
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
index b4d2d6b0c5da8f..545b6ca009c03c 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
@@ -22,6 +22,7 @@ namespace bufferization {
} // namespace mlir
using namespace mlir;
+using AllocationFn = bufferization::BufferResultsToOutParamsOpts::AllocationFn;
using MemCpyFn = bufferization::BufferResultsToOutParamsOpts::MemCpyFn;
/// Return `true` if the given MemRef type has a fully dynamic layout.
@@ -141,9 +142,8 @@ static LogicalResult updateReturnOps(func::FuncOp func,
// Updates all CallOps in the scope of the given ModuleOp by allocating
// temporary buffers for newly introduced out params.
-static LogicalResult
-updateCalls(ModuleOp module,
- const bufferization::BufferResultsToOutParamsOpts &options) {
+static LogicalResult updateCalls(ModuleOp module, AllocationFn allocationFn,
+ std::function<bool(func::FuncOp *)> filterFn) {
bool didFail = false;
SymbolTable symtab(module);
module.walk([&](func::CallOp op) {
@@ -154,7 +154,7 @@ updateCalls(ModuleOp module,
didFail = true;
return;
}
- if (!options.filterFn(&callee))
+ if (!filterFn(&callee))
return;
SmallVector<Value, 6> replaceWithNewCallResults;
SmallVector<Value, 6> replaceWithOutParams;
@@ -177,7 +177,13 @@ updateCalls(ModuleOp module,
auto allocType =
MemRefType::get(memrefType.getShape(), memrefType.getElementType(),
AffineMap(), memrefType.getMemorySpace());
- Value outParam = builder.create<memref::AllocOp>(op.getLoc(), allocType);
+ auto maybeOutParam = allocationFn(builder, op.getLoc(), allocType);
+ if (failed(maybeOutParam)) {
+ op.emitError() << "failed to create allocation op";
+ didFail = true;
+ return;
+ }
+ Value outParam = maybeOutParam.value();
if (!hasStaticIdentityLayout(memrefType)) {
// Layout maps are already checked in `updateFuncOp`.
assert(hasFullyDynamicLayoutMap(memrefType) &&
@@ -226,7 +232,13 @@ LogicalResult mlir::bufferization::promoteBufferResultsToOutParams(
return failure();
}
}
- if (failed(updateCalls(module, options)))
+ auto defaultAllocationFn = [](OpBuilder &builder, Location loc,
+ MemRefType type) {
+ return builder.create<memref::AllocOp>(loc, type).getResult();
+ };
+ if (failed(updateCalls(module,
+ options.allocationFn.value_or(defaultAllocationFn),
+ options.filterFn)))
return failure();
return success();
}
>From cc788e0b16ac112c0ea214c04a1de943d8101f32 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Fri, 20 Dec 2024 10:14:42 -0600
Subject: [PATCH 3/3] Move default alloc/copy fns to
`BufferResultsToOutParamsOpts` struct
---
.../Dialect/Bufferization/Transforms/Passes.h | 16 ++++++--
.../Transforms/BufferResultsToOutParams.cpp | 38 +++++++------------
2 files changed, 25 insertions(+), 29 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
index 966438956fc6cd..c8e456a1d7e380 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
@@ -2,6 +2,7 @@
#define MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_PASSES_H
#include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Pass/Pass.h"
namespace mlir {
@@ -175,12 +176,19 @@ struct BufferResultsToOutParamsOpts {
};
/// Allocation function; used to allocate a memref.
- /// If this is empty, memref.alloc is used
- std::optional<AllocationFn> allocationFn;
+ /// By default, memref.alloc is used.
+ AllocationFn allocationFn = [](OpBuilder &builder, Location loc,
+ MemRefType type) {
+ return builder.create<memref::AllocOp>(loc, type).getResult();
+ };
/// Memcpy function; used to create a copy between two memrefs.
- /// If this is empty, memref.copy is used.
- std::optional<MemCpyFn> memCpyFn;
+ /// By default, memref.copy is used.
+ MemCpyFn memCpyFn = [](OpBuilder &builder, Location loc, Value from,
+ Value to) {
+ builder.create<memref::CopyOp>(loc, from, to);
+ return success();
+ };
/// If true, the pass adds a "bufferize.result" attribute to each output
/// parameter.
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
index 545b6ca009c03c..2502744cb3f580 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
@@ -107,10 +107,9 @@ updateFuncOp(func::FuncOp func,
// Updates all ReturnOps in the scope of the given func::FuncOp by either
// keeping them as return values or copying the associated buffer contents into
// the given out-params.
-static LogicalResult updateReturnOps(func::FuncOp func,
- ArrayRef<BlockArgument> appendedEntryArgs,
- MemCpyFn memCpyFn,
- bool hoistStaticAllocs) {
+static LogicalResult
+updateReturnOps(func::FuncOp func, ArrayRef<BlockArgument> appendedEntryArgs,
+ const bufferization::BufferResultsToOutParamsOpts &options) {
auto res = func.walk([&](func::ReturnOp op) {
SmallVector<Value, 6> copyIntoOutParams;
SmallVector<Value, 6> keepAsReturnOperands;
@@ -122,14 +121,14 @@ static LogicalResult updateReturnOps(func::FuncOp func,
}
OpBuilder builder(op);
for (auto [orig, arg] : llvm::zip(copyIntoOutParams, appendedEntryArgs)) {
- if (hoistStaticAllocs &&
+ if (options.hoistStaticAllocs &&
isa_and_nonnull<bufferization::AllocationOpInterface>(
orig.getDefiningOp()) &&
mlir::cast<MemRefType>(orig.getType()).hasStaticShape()) {
orig.replaceAllUsesWith(arg);
orig.getDefiningOp()->erase();
} else {
- if (failed(memCpyFn(builder, op.getLoc(), orig, arg)))
+ if (failed(options.memCpyFn(builder, op.getLoc(), orig, arg)))
return WalkResult::interrupt();
}
}
@@ -142,8 +141,9 @@ static LogicalResult updateReturnOps(func::FuncOp func,
// Updates all CallOps in the scope of the given ModuleOp by allocating
// temporary buffers for newly introduced out params.
-static LogicalResult updateCalls(ModuleOp module, AllocationFn allocationFn,
- std::function<bool(func::FuncOp *)> filterFn) {
+static LogicalResult
+updateCalls(ModuleOp module,
+ const bufferization::BufferResultsToOutParamsOpts &options) {
bool didFail = false;
SymbolTable symtab(module);
module.walk([&](func::CallOp op) {
@@ -154,7 +154,7 @@ static LogicalResult updateCalls(ModuleOp module, AllocationFn allocationFn,
didFail = true;
return;
}
- if (!filterFn(&callee))
+ if (!options.filterFn(&callee))
return;
SmallVector<Value, 6> replaceWithNewCallResults;
SmallVector<Value, 6> replaceWithOutParams;
@@ -177,7 +177,8 @@ static LogicalResult updateCalls(ModuleOp module, AllocationFn allocationFn,
auto allocType =
MemRefType::get(memrefType.getShape(), memrefType.getElementType(),
AffineMap(), memrefType.getMemorySpace());
- auto maybeOutParam = allocationFn(builder, op.getLoc(), allocType);
+ auto maybeOutParam =
+ options.allocationFn(builder, op.getLoc(), allocType);
if (failed(maybeOutParam)) {
op.emitError() << "failed to create allocation op";
didFail = true;
@@ -221,24 +222,11 @@ LogicalResult mlir::bufferization::promoteBufferResultsToOutParams(
return failure();
if (func.isExternal())
continue;
- auto defaultMemCpyFn = [](OpBuilder &builder, Location loc, Value from,
- Value to) {
- builder.create<memref::CopyOp>(loc, from, to);
- return success();
- };
- if (failed(updateReturnOps(func, appendedEntryArgs,
- options.memCpyFn.value_or(defaultMemCpyFn),
- options.hoistStaticAllocs))) {
+ if (failed(updateReturnOps(func, appendedEntryArgs, options))) {
return failure();
}
}
- auto defaultAllocationFn = [](OpBuilder &builder, Location loc,
- MemRefType type) {
- return builder.create<memref::AllocOp>(loc, type).getResult();
- };
- if (failed(updateCalls(module,
- options.allocationFn.value_or(defaultAllocationFn),
- options.filterFn)))
+ if (failed(updateCalls(module, options)))
return failure();
return success();
}
More information about the Mlir-commits
mailing list