[Mlir-commits] [mlir] 3f6c0fb - [mlir][linalg][bufferize] Add MemCpyFn to AllocationCallbacks struct
Matthias Springer
llvmlistbot at llvm.org
Thu Nov 4 18:49:22 PDT 2021
Author: Matthias Springer
Date: 2021-11-05T10:44:12+09:00
New Revision: 3f6c0fb2ff750c9246aee41eb8ad086518752edf
URL: https://github.com/llvm/llvm-project/commit/3f6c0fb2ff750c9246aee41eb8ad086518752edf
DIFF: https://github.com/llvm/llvm-project/commit/3f6c0fb2ff750c9246aee41eb8ad086518752edf.diff
LOG: [mlir][linalg][bufferize] Add MemCpyFn to AllocationCallbacks struct
This is in preparation for decoupling BufferizableOpInterface, Comprehensive Bufferize, and dialects.
The goal of this CL is to make `getResultBuffer` (and other `bufferize` functions) independent of `LinalgOps`.
Differential Revision: https://reviews.llvm.org/D112907
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h
mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h b/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h
index e3b59d5daa60..94cb52b4bca5 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h
@@ -172,16 +172,28 @@ Optional<Value> defaultAllocationFn(OpBuilder &b, Location loc,
/// `defaultAllocationFn`.
void defaultDeallocationFn(OpBuilder &b, Location loc, Value allocatedBuffer);
+/// Default memory copy function that is used by the comprehensive bufferization
+/// pass. Creates a `linalg.copy` op.
+void defaultMemCpyFn(OpBuilder &b, Location loc, Value from, Value to);
+
/// Callback functions that are used by the comprehensive bufferization pass to
/// allocate/deallocate memory. These default to use the
/// `defaultAllocationFn`/`defaultDeallocationFn`, but can be overridden by the
/// caller. The `deallocationFn` is gauranteed to recieve the `Value` returned
/// by the `allocationFn`.
struct AllocationCallbacks {
- std::function<Optional<Value>(OpBuilder &b, Location loc, Value shapedValue)>
- allocationFn = defaultAllocationFn;
- std::function<void(OpBuilder &b, Location loc, Value v)> deallocationFn =
- defaultDeallocationFn;
+ using AllocationFn =
+ std::function<Optional<Value>(OpBuilder &, Location, Value)>;
+ using DeallocationFn = std::function<void(OpBuilder &, Location, Value)>;
+ using MemCpyFn = std::function<void(OpBuilder &, Location, Value, Value)>;
+
+ AllocationCallbacks(AllocationFn allocFn, DeallocationFn deallocFn,
+ MemCpyFn copyFn)
+ : allocationFn(allocFn), deallocationFn(deallocFn), memCpyFn(copyFn) {}
+
+ AllocationFn allocationFn;
+ DeallocationFn deallocationFn;
+ MemCpyFn memCpyFn;
};
/// Bufferize one particular op.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
index dad19a94080d..a6b6e0131d19 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
@@ -1274,7 +1274,7 @@ static Value getResultBuffer(OpBuilder &b, OpResult result,
if (!skipCopy) {
// Set insertion point now that potential alloc/dealloc are introduced.
b.setInsertionPoint(op);
- b.create<CopyOp>(loc, operandBuffer, resultBuffer);
+ allocationFns.memCpyFn(b, loc, operandBuffer, resultBuffer);
}
return resultBuffer;
}
@@ -1669,6 +1669,11 @@ void mlir::linalg::defaultDeallocationFn(OpBuilder &b, Location loc,
b.create<memref::DeallocOp>(loc, allocatedBuffer);
}
+void mlir::linalg::defaultMemCpyFn(OpBuilder &b, Location loc, Value from,
+ Value to) {
+ b.create<CopyOp>(loc, from, to);
+}
+
LogicalResult mlir::linalg::bufferizeOp(
Operation *op, BlockAndValueMapping &bvm, BufferizationAliasInfo &aliasInfo,
AllocationCallbacks allocationFns,
@@ -2258,11 +2263,13 @@ void LinalgComprehensiveModuleBufferize::runOnOperation() {
// command line option. So this is set up at the start of the pass.
if (useAlloca) {
AllocationCallbacks allocaAllocationFns = {
- allocationFnUsingAlloca, [](OpBuilder &b, Location loc, Value v) {}};
+ allocationFnUsingAlloca, [](OpBuilder &b, Location loc, Value v) {},
+ defaultMemCpyFn};
allocationFns =
std::make_unique<AllocationCallbacks>(std::move(allocaAllocationFns));
} else {
- allocationFns = std::make_unique<AllocationCallbacks>();
+ allocationFns = std::make_unique<AllocationCallbacks>(
+ defaultAllocationFn, defaultDeallocationFn, defaultMemCpyFn);
}
}
ModuleOp moduleOp = getOperation();
@@ -3222,7 +3229,7 @@ struct ExtractSliceOpInterface
if (alloc) {
// Do not copy if the copied data is never read.
if (isValueRead(extractSliceOp.result()))
- b.create<CopyOp>(extractSliceOp.getLoc(), subView, alloc);
+ allocationFn.memCpyFn(b, extractSliceOp.getLoc(), subView, alloc);
subView = alloc;
}
@@ -3344,7 +3351,7 @@ struct InsertSliceOpInterface
insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
// Insert new alias.
aliasInfo.insertNewBufferAlias(subView, dstMemref);
- b.create<CopyOp>(insertSliceOp.getLoc(), srcMemref, subView);
+ allocationFn.memCpyFn(b, insertSliceOp.getLoc(), srcMemref, subView);
}
map(bvm, insertSliceOp.result(), dstMemref);
More information about the Mlir-commits
mailing list