[Mlir-commits] [mlir] e312fc4 - [mlir][Linalg] Add layout specification support to bufferization.

Nicolas Vasilache llvmlistbot at llvm.org
Tue Jul 13 03:22:32 PDT 2021


Author: Nicolas Vasilache
Date: 2021-07-13T10:22:18Z
New Revision: e312fc49ae1ec86999676edc9c02a4ac0bc39cec

URL: https://github.com/llvm/llvm-project/commit/e312fc49ae1ec86999676edc9c02a4ac0bc39cec
DIFF: https://github.com/llvm/llvm-project/commit/e312fc49ae1ec86999676edc9c02a4ac0bc39cec.diff

LOG: [mlir][Linalg] Add layout specification support to bufferization.

Previously, linalg bufferization always had to be conservative at function boundaries and assume the most dynamic strided memref layout.
This revision introduces a mechanism to specify a linalg.buffer_layout function argument attribute that carries an affine map used to set a less pessimistic layout.

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D105859

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
    mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
    mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
index 49ececc0790aa..ce36323dbb28a 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
@@ -48,6 +48,11 @@ def Linalg_Dialect : Dialect {
     constexpr const static ::llvm::StringLiteral
       kInplaceableAttrName = "linalg.inplaceable";
 
+    /// Attribute name used to mark the bufferization layout for region
+    /// arguments during linalg comprehensive bufferization.
+    constexpr const static ::llvm::StringLiteral
+      kBufferLayoutAttrName = "linalg.buffer_layout";
+
     using RegionBuilderFunType =
       llvm::function_ref<void(ImplicitLocOpBuilder &b, Block &)>;
     RegionBuilderFunType getRegionBuilder(StringRef name) {

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
index be39eec14a993..333a129b7efae 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
@@ -324,8 +324,10 @@ setInPlaceFuncArgument(BlockArgument bbArg,
 
 /// Remove the attribute that triggers inplace bufferization on a FuncOp
 /// argument `bbArg`.
-static void removeInPlaceFuncArgument(BlockArgument bbArg) {
+static void removeBufferizationFuncArguments(BlockArgument bbArg) {
   auto funcOp = cast<FuncOp>(bbArg.getOwner()->getParentOp());
+  funcOp.removeArgAttr(bbArg.getArgNumber(),
+                       LinalgDialect::kBufferLayoutAttrName);
   funcOp.removeArgAttr(bbArg.getArgNumber(),
                        LinalgDialect::kInplaceableAttrName);
 }
@@ -2608,6 +2610,96 @@ static void applyEnablingTransformations(ModuleOp moduleOp) {
   (void)applyPatternsAndFoldGreedily(moduleOp, std::move(patterns));
 }
 
+static void
+foreachCaller(const DenseMap<FuncOp, DenseSet<Operation *>> &callerMap,
+              FuncOp callee, llvm::function_ref<void(Operation *)> doit) {
+  auto itCallers = callerMap.find(callee);
+  if (itCallers == callerMap.end())
+    return;
+  for (Operation *caller : itCallers->second)
+    doit(caller);
+}
+
+/// Postprocess the linalg.buffer_layout annotation across function boundaries.
+/// This is a purely mechanical process that may later become part of a
+/// separate pass with its own layout assignment heuristic.
+static void layoutPostProcessing(ModuleOp moduleOp) {
+  SmallVector<FuncOp> orderedFuncOps;
+  DenseMap<FuncOp, DenseSet<Operation *>> callerMap;
+  auto res = getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap);
+  assert(succeeded(res) && "unexpected getFuncOpsOrderedByCalls failure");
+
+  for (FuncOp funcOp : orderedFuncOps) {
+    DenseMap<Operation *, SmallVector<Value>> operandsPerCaller;
+    foreachCaller(callerMap, funcOp, [&](Operation *caller) {
+      operandsPerCaller.try_emplace(caller, SmallVector<Value>());
+    });
+
+    SmallVector<Type> argumentTypes;
+    // Iterate on each function argument and check if it was marked with a
+    // desired layout.
+    for (auto it : llvm::enumerate(funcOp.getType().getInputs())) {
+      int argNumber = it.index();
+      Type inputType = it.value();
+      auto memrefType = inputType.dyn_cast<MemRefType>();
+      auto layoutAttr = funcOp.getArgAttrOfType<AffineMapAttr>(
+          argNumber, LinalgDialect::kBufferLayoutAttrName);
+      AffineMap desiredLayoutMap =
+          layoutAttr ? layoutAttr.getValue() : AffineMap();
+      AffineMap currentLayoutMap =
+          memrefType ? getStridedLinearLayoutMap(memrefType) : AffineMap();
+      if (!memrefType || !layoutAttr || desiredLayoutMap == currentLayoutMap) {
+        argumentTypes.push_back(inputType);
+        foreachCaller(callerMap, funcOp, [&](Operation *caller) {
+          operandsPerCaller.find(caller)->getSecond().push_back(
+              caller->getOperand(argNumber));
+        });
+        continue;
+      }
+
+      // Compute the buffer type with desired layout and add to input argument
+      // types.
+      MemRefType desiredMemrefType = MemRefType::get(
+          memrefType.getShape(), memrefType.getElementType(), desiredLayoutMap);
+      argumentTypes.push_back(desiredMemrefType);
+
+      // If funcOp's body is not empty, change the bbArg type and propagate.
+      if (!funcOp.body().empty()) {
+        BlockArgument bbArg = funcOp.getArgument(argNumber);
+        bbArg.setType(desiredMemrefType);
+        OpBuilder b(bbArg.getContext());
+        b.setInsertionPointToStart(bbArg.getOwner());
+        // Cast back to the original memrefType and let it canonicalize.
+        Value cast =
+            b.create<memref::CastOp>(funcOp.getLoc(), memrefType, bbArg);
+        bbArg.replaceAllUsesExcept(cast, cast.getDefiningOp());
+      }
+
+      // Cast to desired buffer type on all callers to `funcOp`.
+      // TODO: on the callee side, this may even have to trigger a copy to
+      // change the layout. For now let the memref::CastOp fail to verify in
+      // such cases.
+      auto castArg = [&](Operation *caller) {
+        OpBuilder b(caller);
+        Value newOperand = b.create<memref::CastOp>(
+            funcOp.getLoc(), desiredMemrefType, caller->getOperand(argNumber));
+        operandsPerCaller.find(caller)->getSecond().push_back(newOperand);
+      };
+      foreachCaller(callerMap, funcOp, castArg);
+    }
+
+    // Set operands with cast buffer on all callers to `funcOp`.
+    foreachCaller(callerMap, funcOp, [&](Operation *caller) {
+      caller->setOperands(operandsPerCaller.lookup(caller));
+    });
+
+    // Finally set the funcOp type to update the arguments.
+    auto newFuncType = FunctionType::get(moduleOp.getContext(), argumentTypes,
+                                         funcOp.getType().getResults());
+    funcOp.setType(newFuncType);
+  }
+}
+
 void LinalgComprehensiveModuleBufferize::runOnOperation() {
   ModuleOp moduleOp = getOperation();
   applyEnablingTransformations(moduleOp);
@@ -2672,12 +2764,16 @@ void LinalgComprehensiveModuleBufferize::runOnOperation() {
     }
   }
 
-  // Post-pass cleanup of inplaceable attributes.
+  // Perform a post-processing pass of layout modification at function boundary
+  // according to the kBufferLayoutAttrName.
+  layoutPostProcessing(moduleOp);
+
+  // Post-pass cleanup of inplaceable and buffer_layout attributes.
   moduleOp.walk(
       [&](Operation *op) { op->removeAttr(kInPlaceResultsAttrName); });
   moduleOp.walk([&](FuncOp op) {
     for (BlockArgument bbArg : op.getArguments())
-      removeInPlaceFuncArgument(bbArg);
+      removeBufferizationFuncArguments(bbArg);
   });
 
   OpPassManager cleanupPipeline(OpPassManager("module"));

diff  --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
index b29cf6e81f92c..56278ef1a5dae 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
@@ -555,3 +555,43 @@ func @tiled_dot(%A: tensor<?xf32>, %B: tensor<?xf32>, %c: tensor<f32> {linalg.in
   // CHECK-NOT: tensor
   return %1 : tensor<f32>
 }
+
+// -----
+
+// CHECK: #[[$DYNAMIC:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
+
+// CHECK: func private @external_func(memref<?xf32, #[[$DYNAMIC]]>)
+func private @external_func(tensor<?xf32>)
+
+//      CHECK: func @callee(
+// CHECK-SAME:   %[[A:[0-9a-zA-Z]*]]: memref<?xf32>
+// CHECK-SAME:   %[[B:[0-9a-zA-Z]*]]: memref<?xf32, #[[$DYNAMIC]]>
+// CHECK-SAME:   %[[C:[0-9a-zA-Z]*]]: memref<?xf32, #[[$DYNAMIC]]>
+func @callee(%A : tensor<?xf32> {linalg.buffer_layout = affine_map<(i)[s0, s1] -> (i)>},
+             %B : tensor<?xf32>,
+             %C : tensor<?xf32>) {
+// CHECK-NEXT: %[[CASTED:.*]] = memref.cast %[[A]] : memref<?xf32> to memref<?xf32, #[[$DYNAMIC]]>
+// CHECK-NEXT: call @external_func(%[[CASTED]]) : (memref<?xf32, #[[$DYNAMIC]]>) -> ()
+  call @external_func(%A) : (tensor<?xf32>) -> ()
+
+// CHECK-NEXT: call @external_func(%[[B]]) : (memref<?xf32, #[[$DYNAMIC]]>) -> ()
+  call @external_func(%B) : (tensor<?xf32>) -> ()
+
+// CHECK-NEXT: call @external_func(%[[C]]) : (memref<?xf32, #[[$DYNAMIC]]>) -> ()
+  call @external_func(%C) : (tensor<?xf32>) -> ()
+
+  return
+}
+
+//      CHECK: func @entry(
+// CHECK-SAME:   %[[A:[0-9a-zA-Z]*]]: memref<?xf32>
+// CHECK-SAME:   %[[B:[0-9a-zA-Z]*]]: memref<?xf32>
+// CHECK-SAME:   %[[C:[0-9a-zA-Z]*]]: memref<?xf32, #[[$DYNAMIC]]>
+func @entry(%A : tensor<?xf32> {linalg.buffer_layout = affine_map<(i)[s0, s1] -> (i)>},
+            %B : tensor<?xf32> {linalg.buffer_layout = affine_map<(i)[s0, s1] -> (i)>},
+            %C : tensor<?xf32>) {
+// CHECK-NEXT: %[[CASTED_B:.*]] = memref.cast %[[B]] : memref<?xf32> to memref<?xf32, #[[$DYNAMIC]]>
+// CHECK-NEXT: call @callee(%[[A]], %[[CASTED_B]], %[[C]])
+  call @callee(%A, %B, %C) : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> ()
+  return
+}


        


More information about the Mlir-commits mailing list