[Mlir-commits] [mlir] bd1d87e - [mlir][bufferization][NFC] Remove layout post processing step
Matthias Springer
llvmlistbot at llvm.org
Fri Apr 22 03:15:24 PDT 2022
Author: Matthias Springer
Date: 2022-04-22T18:49:47+09:00
New Revision: bd1d87e3d1806da893e5564add573e8ff69e5aa3
URL: https://github.com/llvm/llvm-project/commit/bd1d87e3d1806da893e5564add573e8ff69e5aa3
DIFF: https://github.com/llvm/llvm-project/commit/bd1d87e3d1806da893e5564add573e8ff69e5aa3.diff
LOG: [mlir][bufferization][NFC] Remove layout post processing step
The layout post-processing step has been removed; it is now part of the FuncOp bufferization. If the user specified a layout map for a tensor function arg, that layout map is now used directly when bufferizing the function signature. Previously, the bufferization used a fully dynamic layout map for every tensor function arg and then updated function signatures and CallOps in a separate post-processing step. A hypothetical example is sketched below.
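For illustration, a minimal sketch of the new behavior (the function name @f, the shapes, and the layout map are made up for this example; only the linalg.buffer_layout attribute name comes from the code):

  // A tensor function arg annotated with a desired buffer layout:
  func.func @f(%arg0: tensor<4x8xf32>
      {linalg.buffer_layout = affine_map<(d0, d1) -> (d1, d0)>}) {
    ...
  }

  // The signature is now bufferized directly with that layout map,
  // with no separate post-processing of the signature or its callers:
  func.func @f(%arg0: memref<4x8xf32, affine_map<(d0, d1) -> (d1, d0)>>) {
    ...
  }

Without the annotation, a ranked tensor arg is bufferized with a fully dynamic layout map.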
Differential Revision: https://reviews.llvm.org/D122228
Added:
Modified:
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
index b798728b33dd6..bf7dfe7240d15 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
@@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//
//
-// Module Bufferization is an extension of Comprehensive Bufferize that
+// Module Bufferization is an extension of One-Shot Bufferize that
// bufferizes function boundaries. It provides `BufferizableOpInterface`
// implementations for FuncOp, CallOp and ReturnOp.
//
@@ -357,14 +357,27 @@ static func::FuncOp getCalledFunction(CallOpInterface callOp) {
}
/// Return the index-th bufferized function argument type. This assumes that the
-/// specified argument is a tensor.
+/// specified argument is a tensor. If the tensor is ranked, a layout map may be
+/// specified by the user. If no layout map is specified, a fully dynamic map is
+/// used.
static BaseMemRefType
getBufferizedFunctionArgType(func::FuncOp funcOp, int64_t index,
const BufferizationOptions &options) {
auto tensorType =
funcOp.getFunctionType().getInput(index).dyn_cast<TensorType>();
assert(tensorType && "expected TensorType");
- return getMemRefType(tensorType, options);
+ BaseMemRefType memrefType = getMemRefType(tensorType, options);
+
+ auto layoutAttr = funcOp.getArgAttrOfType<AffineMapAttr>(
+ index, BufferizableOpInterface::kBufferLayoutAttrName);
+ if (!layoutAttr)
+ return memrefType;
+
+ auto rankedMemrefType = memrefType.dyn_cast<MemRefType>();
+ assert(rankedMemrefType && "buffer layout not supported on unranked tensors");
+ return MemRefType::get(
+ rankedMemrefType.getShape(), rankedMemrefType.getElementType(),
+ layoutAttr.getValue(), rankedMemrefType.getMemorySpaceAsInt());
}
/// Gather equivalence info of CallOps.
@@ -451,103 +464,6 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp,
return success();
}
-static void foreachCaller(const FuncCallerMap &callerMap, func::FuncOp callee,
- llvm::function_ref<void(Operation *)> doit) {
- auto itCallers = callerMap.find(callee);
- if (itCallers == callerMap.end())
- return;
- for (Operation *caller : itCallers->second)
- doit(caller);
-}
-
-/// Postprocess the linalg.buffer_layout annotation across function boundaries.
-/// This is a purely mechanical process that may later become part of a
-/// separate pass with its own layout assignment heuristic.
-static void layoutPostProcessing(ModuleOp moduleOp) {
- SmallVector<func::FuncOp> orderedFuncOps;
- DenseMap<func::FuncOp, DenseSet<Operation *>> callerMap;
- auto res = getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap);
- (void)res;
- assert(succeeded(res) && "unexpected getFuncOpsOrderedByCalls failure");
-
- for (func::FuncOp funcOp : orderedFuncOps) {
- DenseMap<Operation *, SmallVector<Value>> operandsPerCaller;
- foreachCaller(callerMap, funcOp, [&](Operation *caller) {
- operandsPerCaller.try_emplace(caller, SmallVector<Value>());
- });
-
- SmallVector<Type> argumentTypes;
- // Iterate on each function argument and check if it was marked with a
- // desired layout.
- for (const auto &it :
- llvm::enumerate(funcOp.getFunctionType().getInputs())) {
- int argNumber = it.index();
- Type inputType = it.value();
- auto memrefType = inputType.dyn_cast<MemRefType>();
- auto layoutAttr = funcOp.getArgAttrOfType<AffineMapAttr>(
- argNumber, BufferizableOpInterface::kBufferLayoutAttrName);
- AffineMap desiredLayoutMap =
- layoutAttr ? layoutAttr.getValue() : AffineMap();
- AffineMap currentLayoutMap =
- memrefType ? getStridedLinearLayoutMap(memrefType) : AffineMap();
- if (!memrefType || !layoutAttr || desiredLayoutMap == currentLayoutMap) {
- argumentTypes.push_back(inputType);
- foreachCaller(callerMap, funcOp, [&](Operation *caller) {
- operandsPerCaller.find(caller)->getSecond().push_back(
- caller->getOperand(argNumber));
- });
- continue;
- }
-
- // Compute the buffer type with desired layout and add to input argument
- // types.
- MemRefType desiredMemrefType = MemRefType::get(
- memrefType.getShape(), memrefType.getElementType(), desiredLayoutMap);
- argumentTypes.push_back(desiredMemrefType);
-
- // If funcOp's body is not empty, change the bbArg type and propagate.
- if (!funcOp.getBody().empty()) {
- BlockArgument bbArg = funcOp.getArgument(argNumber);
- bbArg.setType(desiredMemrefType);
- OpBuilder b(bbArg.getContext());
- b.setInsertionPointToStart(bbArg.getOwner());
- assert(memref::CastOp::areCastCompatible(bbArg.getType(), memrefType) &&
- "layoutPostProcessing: cast incompatible");
- // Cast back to the original memrefType and let it canonicalize.
- Value cast =
- b.create<memref::CastOp>(funcOp.getLoc(), memrefType, bbArg);
- bbArg.replaceAllUsesExcept(cast, cast.getDefiningOp());
- }
-
- // Cast to desired buffer type on all callers to `funcOp`.
- // TODO: on the callee side, this may even have to trigger a copy to
- // change the layout. For now let the memref::CastOp fail to verify in
- // such cases.
- auto castArg = [&](Operation *caller) {
- OpBuilder b(caller);
- assert(
- memref::CastOp::areCastCompatible(
- caller->getOperand(argNumber).getType(), desiredMemrefType) &&
- "layoutPostProcessing.2: cast incompatible");
- Value newOperand = b.create<memref::CastOp>(
- funcOp.getLoc(), desiredMemrefType, caller->getOperand(argNumber));
- operandsPerCaller.find(caller)->getSecond().push_back(newOperand);
- };
- foreachCaller(callerMap, funcOp, castArg);
- }
-
- // Set operands with cast buffer on all callers to `funcOp`.
- foreachCaller(callerMap, funcOp, [&](Operation *caller) {
- caller->setOperands(operandsPerCaller.lookup(caller));
- });
-
- // Finally set the funcOp type to update the arguments.
- auto newFuncType = FunctionType::get(moduleOp.getContext(), argumentTypes,
- funcOp.getFunctionType().getResults());
- funcOp.setType(newFuncType);
- }
-}
-
namespace mlir {
namespace linalg {
namespace comprehensive_bufferize {
@@ -1111,10 +1027,6 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runModuleBufferize(
if (failed(finalizeBuffers(moduleOp, options)))
return failure();
- // Perform a post-processing pass of layout modification at function boundary
- // according to the kBufferLayoutAttrName.
- layoutPostProcessing(moduleOp);
-
// Post-pass cleanup of inplaceable and buffer_layout attributes.
moduleOp.walk([&](func::FuncOp op) {
for (BlockArgument bbArg : op.getArguments())