[Mlir-commits] [mlir] bd1d87e - [mlir][bufferization][NFC] Remove layout post processing step

Matthias Springer llvmlistbot at llvm.org
Fri Apr 22 03:15:24 PDT 2022


Author: Matthias Springer
Date: 2022-04-22T18:49:47+09:00
New Revision: bd1d87e3d1806da893e5564add573e8ff69e5aa3

URL: https://github.com/llvm/llvm-project/commit/bd1d87e3d1806da893e5564add573e8ff69e5aa3
DIFF: https://github.com/llvm/llvm-project/commit/bd1d87e3d1806da893e5564add573e8ff69e5aa3.diff

LOG: [mlir][bufferization][NFC] Remove layout post processing step

The layout post-processing step was removed; its functionality is now part of FuncOp bufferization. If the user specified a layout map for a tensor function argument, that layout map is now used directly when bufferizing the function signature. Previously, the bufferization used a generic layout map for every tensor function argument and then updated function signatures and CallOps in a separate step.

Differential Revision: https://reviews.llvm.org/D122228

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
index b798728b33dd6..bf7dfe7240d15 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
@@ -6,7 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// Module Bufferization is an extension of Comprehensive Bufferize that
+// Module Bufferization is an extension of One-Shot Bufferize that
 // bufferizes function boundaries. It provides `BufferizableOpInterface`
 // implementations for FuncOp, CallOp and ReturnOp.
 //
@@ -357,14 +357,27 @@ static func::FuncOp getCalledFunction(CallOpInterface callOp) {
 }
 
 /// Return the index-th bufferized function argument type. This assumes that the
-/// specified argument is a tensor.
+/// specified argument is a tensor. If the tensor is ranked, a layout map may be
+/// specified by the user. If no layout map is specified, the layout is chosen
+/// by `getMemRefType` according to `options`.
 static BaseMemRefType
 getBufferizedFunctionArgType(func::FuncOp funcOp, int64_t index,
                              const BufferizationOptions &options) {
   auto tensorType =
       funcOp.getFunctionType().getInput(index).dyn_cast<TensorType>();
   assert(tensorType && "expected TensorType");
-  return getMemRefType(tensorType, options);
+  BaseMemRefType memrefType = getMemRefType(tensorType, options);
+
+  auto layoutAttr = funcOp.getArgAttrOfType<AffineMapAttr>(
+      index, BufferizableOpInterface::kBufferLayoutAttrName);
+  if (!layoutAttr)
+    return memrefType;
+
+  auto rankedMemrefType = memrefType.dyn_cast<MemRefType>();
+  assert(rankedMemrefType && "buffer layout not supported on unranked tensors");
+  return MemRefType::get(
+      rankedMemrefType.getShape(), rankedMemrefType.getElementType(),
+      layoutAttr.getValue(), rankedMemrefType.getMemorySpace());
 }
 
 /// Gather equivalence info of CallOps.
@@ -451,103 +464,6 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp,
   return success();
 }
 
-static void foreachCaller(const FuncCallerMap &callerMap, func::FuncOp callee,
-                          llvm::function_ref<void(Operation *)> doit) {
-  auto itCallers = callerMap.find(callee);
-  if (itCallers == callerMap.end())
-    return;
-  for (Operation *caller : itCallers->second)
-    doit(caller);
-}
-
-/// Postprocess the linalg.buffer_layout annotation across function boundaries.
-/// This is a purely mechanical process that may later become part of a
-/// separate pass with its own layout assignment heuristic.
-static void layoutPostProcessing(ModuleOp moduleOp) {
-  SmallVector<func::FuncOp> orderedFuncOps;
-  DenseMap<func::FuncOp, DenseSet<Operation *>> callerMap;
-  auto res = getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap);
-  (void)res;
-  assert(succeeded(res) && "unexpected getFuncOpsOrderedByCalls failure");
-
-  for (func::FuncOp funcOp : orderedFuncOps) {
-    DenseMap<Operation *, SmallVector<Value>> operandsPerCaller;
-    foreachCaller(callerMap, funcOp, [&](Operation *caller) {
-      operandsPerCaller.try_emplace(caller, SmallVector<Value>());
-    });
-
-    SmallVector<Type> argumentTypes;
-    // Iterate on each function argument and check it it was marked with a
-    // desired layout.
-    for (const auto &it :
-         llvm::enumerate(funcOp.getFunctionType().getInputs())) {
-      int argNumber = it.index();
-      Type inputType = it.value();
-      auto memrefType = inputType.dyn_cast<MemRefType>();
-      auto layoutAttr = funcOp.getArgAttrOfType<AffineMapAttr>(
-          argNumber, BufferizableOpInterface::kBufferLayoutAttrName);
-      AffineMap desiredLayoutMap =
-          layoutAttr ? layoutAttr.getValue() : AffineMap();
-      AffineMap currentLayoutMap =
-          memrefType ? getStridedLinearLayoutMap(memrefType) : AffineMap();
-      if (!memrefType || !layoutAttr || desiredLayoutMap == currentLayoutMap) {
-        argumentTypes.push_back(inputType);
-        foreachCaller(callerMap, funcOp, [&](Operation *caller) {
-          operandsPerCaller.find(caller)->getSecond().push_back(
-              caller->getOperand(argNumber));
-        });
-        continue;
-      }
-
-      // Compute the buffer type with desired layout and add to input argument
-      // types.
-      MemRefType desiredMemrefType = MemRefType::get(
-          memrefType.getShape(), memrefType.getElementType(), desiredLayoutMap);
-      argumentTypes.push_back(desiredMemrefType);
-
-      // If funcOp's body is not empty, change the bbArg type and propagate.
-      if (!funcOp.getBody().empty()) {
-        BlockArgument bbArg = funcOp.getArgument(argNumber);
-        bbArg.setType(desiredMemrefType);
-        OpBuilder b(bbArg.getContext());
-        b.setInsertionPointToStart(bbArg.getOwner());
-        assert(memref::CastOp::areCastCompatible(bbArg.getType(), memrefType) &&
-               "layoutPostProcessing: cast incompatible");
-        // Cast back to the original memrefType and let it canonicalize.
-        Value cast =
-            b.create<memref::CastOp>(funcOp.getLoc(), memrefType, bbArg);
-        bbArg.replaceAllUsesExcept(cast, cast.getDefiningOp());
-      }
-
-      // Cast to desired buffer type on all callers to `funcOp`.
-      // TODO: on the callee side, this may even have to trigger a copy to
-      // change the layout. For now let the memref::CastOp fail to verify in
-      // such cases.
-      auto castArg = [&](Operation *caller) {
-        OpBuilder b(caller);
-        assert(
-            memref::CastOp::areCastCompatible(
-                caller->getOperand(argNumber).getType(), desiredMemrefType) &&
-            "layoutPostProcessing.2: cast incompatible");
-        Value newOperand = b.create<memref::CastOp>(
-            funcOp.getLoc(), desiredMemrefType, caller->getOperand(argNumber));
-        operandsPerCaller.find(caller)->getSecond().push_back(newOperand);
-      };
-      foreachCaller(callerMap, funcOp, castArg);
-    }
-
-    // Set operands with cast buffer on all callers to `funcOp`.
-    foreachCaller(callerMap, funcOp, [&](Operation *caller) {
-      caller->setOperands(operandsPerCaller.lookup(caller));
-    });
-
-    // Finally set the funcOp type to update the arguments.
-    auto newFuncType = FunctionType::get(moduleOp.getContext(), argumentTypes,
-                                         funcOp.getFunctionType().getResults());
-    funcOp.setType(newFuncType);
-  }
-}
-
 namespace mlir {
 namespace linalg {
 namespace comprehensive_bufferize {
@@ -1111,10 +1027,6 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runModuleBufferize(
   if (failed(finalizeBuffers(moduleOp, options)))
     return failure();
 
-  // Perform a post-processing pass of layout modification at function boundary
-  // according to the kBufferLayoutAttrName.
-  layoutPostProcessing(moduleOp);
-
   // Post-pass cleanup of inplaceable and buffer_layout attributes.
   moduleOp.walk([&](func::FuncOp op) {
     for (BlockArgument bbArg : op.getArguments())


        


More information about the Mlir-commits mailing list