[Mlir-commits] [mlir] 248e113 - [mlir][bufferize][NFC] Move helper functions to BufferizationOptions
Matthias Springer
llvmlistbot at llvm.org
Wed May 11 07:23:33 PDT 2022
Author: Matthias Springer
Date: 2022-05-11T16:23:22+02:00
New Revision: 248e113e9f6e583ed93e52de621a89d098c6d79e
URL: https://github.com/llvm/llvm-project/commit/248e113e9f6e583ed93e52de621a89d098c6d79e
DIFF: https://github.com/llvm/llvm-project/commit/248e113e9f6e583ed93e52de621a89d098c6d79e.diff
LOG: [mlir][bufferize][NFC] Move helper functions to BufferizationOptions
Move helper functions for creating allocs/deallocs/memcpys to BufferizationOptions.
Differential Revision: https://reviews.llvm.org/D125375
Added:
Modified:
mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index 7d80e472c90ea..421db92543f02 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -161,6 +161,19 @@ struct BufferizationOptions {
Optional<DeallocationFn> deallocationFn;
Optional<MemCpyFn> memCpyFn;
+ /// Create a memref allocation with the given type and dynamic extents.
+ FailureOr<Value> createAlloc(OpBuilder &b, Location loc, MemRefType type,
+ ValueRange dynShape) const;
+
+ /// Creates a memref deallocation. The given memref buffer must have been
+ /// allocated using `createAlloc`.
+ LogicalResult createDealloc(OpBuilder &b, Location loc,
+ Value allocatedBuffer) const;
+
+ /// Creates a memcpy between two given buffers.
+ LogicalResult createMemCpy(OpBuilder &b, Location loc, Value from,
+ Value to) const;
+
/// Specifies whether non-bufferizable ops are allowed in the input. If so,
/// bufferization.to_memref and bufferization.to_tensor ops are inserted at
/// the boundaries.
@@ -514,15 +527,6 @@ BaseMemRefType getMemRefType(TensorType tensorType,
MemRefLayoutAttrInterface layout = {},
Attribute memorySpace = {});
-/// Creates a memref deallocation. The given memref buffer must have been
-/// allocated using `createAlloc`.
-LogicalResult createDealloc(OpBuilder &b, Location loc, Value allocatedBuffer,
- const BufferizationOptions &options);
-
-/// Creates a memcpy between two given buffers.
-LogicalResult createMemCpy(OpBuilder &b, Location loc, Value from, Value to,
- const BufferizationOptions &options);
-
/// Try to hoist all new buffer allocations until the next hoisting barrier.
LogicalResult hoistBufferAllocations(Operation *op,
const BufferizationOptions &options);
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 8a3bc5c40b5c3..29d983bdffdf1 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -327,8 +327,7 @@ BufferizationState::getBuffer(RewriterBase &rewriter, OpOperand &opOperand,
// The copy happens right before the op that is bufferized.
rewriter.setInsertionPoint(op);
}
- if (failed(
- createMemCpy(rewriter, loc, operandBuffer, *resultBuffer, options)))
+ if (failed(options.createMemCpy(rewriter, loc, operandBuffer, *resultBuffer)))
return failure();
return resultBuffer;
@@ -418,26 +417,24 @@ bool AlwaysCopyAnalysisState::isTensorYielded(Value tensor) const {
//===----------------------------------------------------------------------===//
/// Create a memref allocation with the given type and dynamic extents.
-static FailureOr<Value> createAlloc(OpBuilder &b, Location loc, MemRefType type,
- ValueRange dynShape,
- const BufferizationOptions &options) {
- if (options.allocationFn)
- return (*options.allocationFn)(b, loc, type, dynShape,
- options.bufferAlignment);
+FailureOr<Value> BufferizationOptions::createAlloc(OpBuilder &b, Location loc,
+ MemRefType type,
+ ValueRange dynShape) const {
+ if (allocationFn)
+ return (*allocationFn)(b, loc, type, dynShape, bufferAlignment);
// Default buffer allocation via AllocOp.
Value allocated = b.create<memref::AllocOp>(
- loc, type, dynShape, b.getI64IntegerAttr(options.bufferAlignment));
+ loc, type, dynShape, b.getI64IntegerAttr(bufferAlignment));
return allocated;
}
/// Creates a memref deallocation. The given memref buffer must have been
/// allocated using `createAlloc`.
-LogicalResult
-bufferization::createDealloc(OpBuilder &b, Location loc, Value allocatedBuffer,
- const BufferizationOptions &options) {
- if (options.deallocationFn)
- return (*options.deallocationFn)(b, loc, allocatedBuffer);
+LogicalResult BufferizationOptions::createDealloc(OpBuilder &b, Location loc,
+ Value allocatedBuffer) const {
+ if (deallocationFn)
+ return (*deallocationFn)(b, loc, allocatedBuffer);
// Default buffer deallocation via DeallocOp.
b.create<memref::DeallocOp>(loc, allocatedBuffer);
@@ -523,11 +520,10 @@ FailureOr<Value> BufferizationState::createAlloc(OpBuilder &b, Location loc,
}
/// Create a memory copy between two memref buffers.
-LogicalResult bufferization::createMemCpy(OpBuilder &b, Location loc,
- Value from, Value to,
- const BufferizationOptions &options) {
- if (options.memCpyFn)
- return (*options.memCpyFn)(b, loc, from, to);
+LogicalResult BufferizationOptions::createMemCpy(OpBuilder &b, Location loc,
+ Value from, Value to) const {
+ if (memCpyFn)
+ return (*memCpyFn)(b, loc, from, to);
b.create<memref::CopyOp>(loc, from, to);
return success();
@@ -557,8 +553,8 @@ bufferization::createAllocDeallocOps(Operation *op,
Block *block = allocaOp->getBlock();
rewriter.setInsertionPoint(allocaOp);
FailureOr<Value> alloc =
- createAlloc(rewriter, allocaOp->getLoc(), allocaOp.getType(),
- allocaOp.dynamicSizes(), options);
+ options.createAlloc(rewriter, allocaOp->getLoc(), allocaOp.getType(),
+ allocaOp.dynamicSizes());
if (failed(alloc))
return WalkResult::interrupt();
rewriter.replaceOp(allocaOp, *alloc);
@@ -571,7 +567,7 @@ bufferization::createAllocDeallocOps(Operation *op,
// Create dealloc.
rewriter.setInsertionPoint(block->getTerminator());
- if (failed(createDealloc(rewriter, alloc->getLoc(), *alloc, options)))
+ if (failed(options.createDealloc(rewriter, alloc->getLoc(), *alloc)))
return WalkResult::interrupt();
return WalkResult::advance();
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index 25d3df2fac287..dda46141e5382 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -495,7 +495,7 @@ struct FuncOpInterface
// Note: This copy will fold away. It must be inserted here to ensure
// that `returnVal` still has at least one use and does not fold away.
if (failed(
- createMemCpy(rewriter, loc, toMemrefOp, equivBbArg, options)))
+ options.createMemCpy(rewriter, loc, toMemrefOp, equivBbArg)))
return funcOp->emitError("could not generate copy for bbArg");
continue;
}
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index 39af7d337d054..337b0aa57ea3e 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -363,8 +363,8 @@ static Value getYieldedBuffer(RewriterBase &rewriter, Value tensor,
// TODO: We should rollback, but for now just assume that this always
// succeeds.
assert(yieldedAlloc.hasValue() && "could not create alloc");
- LogicalResult copyStatus = bufferization::createMemCpy(
- rewriter, tensor.getLoc(), yieldedVal, *yieldedAlloc, state.getOptions());
+ LogicalResult copyStatus = state.getOptions().createMemCpy(
+ rewriter, tensor.getLoc(), yieldedVal, *yieldedAlloc);
(void)copyStatus;
assert(succeeded(copyStatus) && "could not create memcpy");
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index efd2de7978e94..3ec52e94423f6 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -320,8 +320,8 @@ struct ExtractSliceOpInterface
if (!inplace) {
// Do not copy if the copied data is never read.
if (state.getAnalysisState().isValueRead(extractSliceOp.result()))
- if (failed(createMemCpy(rewriter, extractSliceOp.getLoc(), subView,
- alloc, state.getOptions())))
+ if (failed(state.getOptions().createMemCpy(
+ rewriter, extractSliceOp.getLoc(), subView, alloc)))
return failure();
subView = alloc;
}
@@ -718,8 +718,8 @@ struct InsertSliceOpInterface
// tensor.extract_slice, the copy operation will eventually fold away.
auto srcMemref =
state.getBuffer(rewriter, insertSliceOp->getOpOperand(0) /*source*/);
- if (failed(srcMemref) || failed(createMemCpy(rewriter, loc, *srcMemref,
- subView, state.getOptions())))
+ if (failed(srcMemref) || failed(state.getOptions().createMemCpy(
+ rewriter, loc, *srcMemref, subView)))
return failure();
replaceOpWithBufferizedValues(rewriter, op, *dstMemref);
More information about the Mlir-commits
mailing list