[Mlir-commits] [mlir] 8906b7b - Enable custom alloc-like ops in `promoteBufferResultsToOutParams` (#120288)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Dec 26 09:32:55 PST 2024
Author: srcarroll
Date: 2024-12-26T11:32:51-06:00
New Revision: 8906b7be918be653d3c5f2ef3dbd923561603969
URL: https://github.com/llvm/llvm-project/commit/8906b7be918be653d3c5f2ef3dbd923561603969
DIFF: https://github.com/llvm/llvm-project/commit/8906b7be918be653d3c5f2ef3dbd923561603969.diff
LOG: Enable custom alloc-like ops in `promoteBufferResultsToOutParams` (#120288)
In `buffer-results-to-out-params`, when the `hoist-static-allocs` option is
enabled, the pass looked specifically for `memref.alloc` ops in order to
avoid copies where possible. This made it impossible to extend to external
ops that have allocation-like properties. This patch simply changes the
check from `memref::AllocOp` to `AllocationOpInterface`, enabling it for
any allocation op.
Moreover, when updating function calls, users can now set an allocation
callback (`allocationFn`) in `BufferResultsToOutParamsOpts` to emit their
own alloc-like op instead of the default `memref.alloc`.
Added:
Modified:
mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
index fe43a05c81fdc3..c8e456a1d7e380 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
@@ -2,10 +2,12 @@
#define MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_PASSES_H
#include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Pass/Pass.h"
namespace mlir {
class FunctionOpInterface;
+class MemRefType;
class ModuleOp;
class RewritePatternSet;
class OpBuilder;
@@ -38,7 +40,7 @@ std::unique_ptr<Pass> createOwnershipBasedBufferDeallocationPass(
DeallocationOptions options = DeallocationOptions());
/// Creates a pass that finds all temporary allocations
-/// and attempts to move the deallocation after the last user/dependency
+/// and attempts to move the deallocation after the last user/dependency
/// of the allocation, thereby optimizing allocation liveness.
std::unique_ptr<Pass> createOptimizeAllocationLivenessPass();
@@ -157,6 +159,12 @@ std::unique_ptr<Pass> createBufferLoopHoistingPass();
// Options struct for BufferResultsToOutParams pass.
// Note: defined only here, not in tablegen.
struct BufferResultsToOutParamsOpts {
+ /// Allocator function: Generate a memref allocation with the given type.
+ /// Since `promoteBufferResultsToOutParams` doesn't allow dynamically shaped
+ /// results, we don't allow passing a range of values for dynamic dims.
+ using AllocationFn =
+ std::function<FailureOr<Value>(OpBuilder &, Location, MemRefType)>;
+
/// Memcpy function: Generate a memcpy between two memrefs.
using MemCpyFn =
std::function<LogicalResult(OpBuilder &, Location, Value, Value)>;
@@ -167,9 +175,20 @@ struct BufferResultsToOutParamsOpts {
return true;
};
+ /// Allocation function; used to allocate a memref.
+ /// Default memref.alloc is used
+ AllocationFn allocationFn = [](OpBuilder &builder, Location loc,
+ MemRefType type) {
+ return builder.create<memref::AllocOp>(loc, type).getResult();
+ };
+
/// Memcpy function; used to create a copy between two memrefs.
- /// If this is empty, memref.copy is used.
- std::optional<MemCpyFn> memCpyFn;
+ /// Default memref.copy is used.
+ MemCpyFn memCpyFn = [](OpBuilder &builder, Location loc, Value from,
+ Value to) {
+ builder.create<memref::CopyOp>(loc, from, to);
+ return success();
+ };
/// If true, the pass adds a "bufferize.result" attribute to each output
/// parameter.
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
index b7755b2be8483b..2502744cb3f580 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
@@ -6,6 +6,7 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.h"
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -21,6 +22,7 @@ namespace bufferization {
} // namespace mlir
using namespace mlir;
+using AllocationFn = bufferization::BufferResultsToOutParamsOpts::AllocationFn;
using MemCpyFn = bufferization::BufferResultsToOutParamsOpts::MemCpyFn;
/// Return `true` if the given MemRef type has a fully dynamic layout.
@@ -105,10 +107,9 @@ updateFuncOp(func::FuncOp func,
// Updates all ReturnOps in the scope of the given func::FuncOp by either
// keeping them as return values or copying the associated buffer contents into
// the given out-params.
-static LogicalResult updateReturnOps(func::FuncOp func,
- ArrayRef<BlockArgument> appendedEntryArgs,
- MemCpyFn memCpyFn,
- bool hoistStaticAllocs) {
+static LogicalResult
+updateReturnOps(func::FuncOp func, ArrayRef<BlockArgument> appendedEntryArgs,
+ const bufferization::BufferResultsToOutParamsOpts &options) {
auto res = func.walk([&](func::ReturnOp op) {
SmallVector<Value, 6> copyIntoOutParams;
SmallVector<Value, 6> keepAsReturnOperands;
@@ -120,13 +121,14 @@ static LogicalResult updateReturnOps(func::FuncOp func,
}
OpBuilder builder(op);
for (auto [orig, arg] : llvm::zip(copyIntoOutParams, appendedEntryArgs)) {
- if (hoistStaticAllocs &&
- isa_and_nonnull<memref::AllocOp>(orig.getDefiningOp()) &&
+ if (options.hoistStaticAllocs &&
+ isa_and_nonnull<bufferization::AllocationOpInterface>(
+ orig.getDefiningOp()) &&
mlir::cast<MemRefType>(orig.getType()).hasStaticShape()) {
orig.replaceAllUsesWith(arg);
orig.getDefiningOp()->erase();
} else {
- if (failed(memCpyFn(builder, op.getLoc(), orig, arg)))
+ if (failed(options.memCpyFn(builder, op.getLoc(), orig, arg)))
return WalkResult::interrupt();
}
}
@@ -175,7 +177,14 @@ updateCalls(ModuleOp module,
auto allocType =
MemRefType::get(memrefType.getShape(), memrefType.getElementType(),
AffineMap(), memrefType.getMemorySpace());
- Value outParam = builder.create<memref::AllocOp>(op.getLoc(), allocType);
+ auto maybeOutParam =
+ options.allocationFn(builder, op.getLoc(), allocType);
+ if (failed(maybeOutParam)) {
+ op.emitError() << "failed to create allocation op";
+ didFail = true;
+ return;
+ }
+ Value outParam = maybeOutParam.value();
if (!hasStaticIdentityLayout(memrefType)) {
// Layout maps are already checked in `updateFuncOp`.
assert(hasFullyDynamicLayoutMap(memrefType) &&
@@ -213,14 +222,7 @@ LogicalResult mlir::bufferization::promoteBufferResultsToOutParams(
return failure();
if (func.isExternal())
continue;
- auto defaultMemCpyFn = [](OpBuilder &builder, Location loc, Value from,
- Value to) {
- builder.create<memref::CopyOp>(loc, from, to);
- return success();
- };
- if (failed(updateReturnOps(func, appendedEntryArgs,
- options.memCpyFn.value_or(defaultMemCpyFn),
- options.hoistStaticAllocs))) {
+ if (failed(updateReturnOps(func, appendedEntryArgs, options))) {
return failure();
}
}
More information about the Mlir-commits
mailing list