[Mlir-commits] [mlir] cc45a13 - [mlir][linalg][bufferize] LinalgOps can bufferize inplace with input args

Matthias Springer llvmlistbot at llvm.org
Thu Dec 9 04:59:34 PST 2021


Author: Matthias Springer
Date: 2021-12-09T21:54:54+09:00
New Revision: cc45a13422ca78ea750cce0fc502ac44b784d811

URL: https://github.com/llvm/llvm-project/commit/cc45a13422ca78ea750cce0fc502ac44b784d811
DIFF: https://github.com/llvm/llvm-project/commit/cc45a13422ca78ea750cce0fc502ac44b784d811.diff

LOG: [mlir][linalg][bufferize] LinalgOps can bufferize inplace with input args

LinalgOp results usually bufferize inplace with output args. With this change, they may bufferize inplace with input args if the value of the output arg is not used in the computation.

Differential Revision: https://reviews.llvm.org/D115022

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
    mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
    mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
index 81d865042d7f..5abfbe559d0a 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
@@ -21,40 +21,12 @@ namespace {
 
 // TODO: Ops in the linalg dialect can directly implement this interface.
 
-/// Helper function for LinalgOp bufferization.
-/// When allocating a new buffer, analyze whether `op` wants to read form that
-/// buffer. Only in that case, a copy of the result buffer may be needed.
-static LogicalResult
-allocateBuffersForResults(OpBuilder &b, Location loc, LinalgOp op,
-                          SmallVectorImpl<Value> &resultBuffers,
-                          BufferizationState &state) {
-  // Take a guard before anything else.
-  OpBuilder::InsertionGuard g(b);
-  b.setInsertionPoint(op);
-
-  // TODO: provide the proper interface to iterate on OpResults and get the
-  // matching OpOperands.
-  for (OpOperand *opOperand : op.getOutputOperands()) {
-    OpResult opResult = cast<BufferizableOpInterface>(op.getOperation())
-                            .getAliasingOpResult(*opOperand);
-    assert(opResult && "could not find correspond OpResult");
-    Value resultBuffer = state.getResultBuffer(opResult);
-    if (!resultBuffer)
-      return failure();
-    resultBuffers.push_back(resultBuffer);
-  }
-
-  if (op->getNumResults())
-    state.mapBuffer(op->getResults(), resultBuffers);
-
-  return success();
-}
-
 /// Generic conversion for any LinalgOp on tensors.
 static LogicalResult bufferizeLinalgOp(OpBuilder &b, LinalgOp op,
                                        BufferizationState &state) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
+  b.setInsertionPoint(op);
 
   // Nothing to do. This op is already bufferized.
   if (op.hasBufferSemantics())
@@ -65,7 +37,6 @@ static LogicalResult bufferizeLinalgOp(OpBuilder &b, LinalgOp op,
   if (!op.hasTensorSemantics())
     return op->emitError() << "op does not have tensor semantics";
 
-  Location loc = op.getLoc();
   SmallVector<Value> newInputBuffers;
   newInputBuffers.reserve(op.getNumInputs());
   for (OpOperand *opOperand : op.getInputOperands()) {
@@ -75,10 +46,17 @@ static LogicalResult bufferizeLinalgOp(OpBuilder &b, LinalgOp op,
     }
     newInputBuffers.push_back(state.lookupBuffer(opOperand->get()));
   }
+
   SmallVector<Value> newOutputBuffers;
-  // Try to allocate new buffers depending on op's inplace semantics.
-  if (failed(allocateBuffersForResults(b, loc, op, newOutputBuffers, state)))
-    return failure();
+  for (OpOperand *opOperand : op.getOutputOperands()) {
+    OpResult opResult = op.getTiedOpResult(opOperand);
+    assert(opResult && "could not find correspond OpResult");
+    Value resultBuffer = state.getResultBuffer(opResult);
+    if (!resultBuffer)
+      return failure();
+    newOutputBuffers.push_back(resultBuffer);
+    state.mapBuffer(opResult, resultBuffer);
+  }
 
   // Clone the newly bufferized op.
   SmallVector<Value> newOperands = newInputBuffers;
@@ -87,53 +65,103 @@ static LogicalResult bufferizeLinalgOp(OpBuilder &b, LinalgOp op,
   // Set insertion point now that potential alloc/dealloc are introduced.
   b.setInsertionPoint(op);
   auto bufferizedOp = cast<LinalgOp>(
-      op.clone(b, loc, /*resultTypes=*/TypeRange{}, newOperands));
-
-  // Replace the results of the old op with the new output buffers.
-  if (op->getNumResults())
-    state.mapBuffer(op->getResults(), newOutputBuffers);
+      op.clone(b, op.getLoc(), /*resultTypes=*/TypeRange{}, newOperands));
 
   // The original op will be DCE'd away later.
   return comprehensive_bufferize::bufferize(bufferizedOp.getBlock(), state);
 }
 
+/// Linalg OpResults usually bufferize inplace with their tied (output)
+/// OpOperands. However, if an output OpOperand is not used in the computation,
+/// it is better to bufferize inplace with an actually used input OpOperand;
+/// less memory will be touched that way.
+///
+/// Example:
+/// O(i, j) = A(i, j) + B(j)  --> bufferizes inplace to:  A(i, j) += B(j)
+///
+/// O(i, j) = A(j, i) + B(j)  --> cannot bufferize inplace with A because
+///                               indexing maps are not identical
+///
+/// O(i, j) += A(i, j) + B(j) --> Output is used in computation.
+/// This could bufferize inplace with A:
+/// A(i, j) += O(i, j) + B(j)
+/// However, we choose to bufferize inplace with O here, as there is no clear
+/// benefit of choosing A. TODO: We may want to consider both options and make
+/// an informed decision during analysis in the future.
+static DenseMap<OpOperand *, OpResult> computeAliasingPairs(LinalgOp op) {
+  DenseMap<OpOperand *, OpResult> mapping;
+  for (OpResult opResult : op->getOpResults()) {
+    OpOperand *tiedOperand =
+        op.getOutputTensorOperands()[opResult.getResultNumber()];
+    AffineMap outputIndexingMap = op.getTiedIndexingMap(tiedOperand);
+    bool onlyParallelIterators = op.getNumParallelLoops() == op.getNumLoops();
+    bool tiedOperandUsed = op.payloadUsesValueFromOperand(tiedOperand);
+
+    // If the output arg is used in the computation or at least one iterator is
+    // not parallel, try to bufferize inplace with the corresponding output
+    // tensor.
+    if (tiedOperandUsed || !onlyParallelIterators) {
+      mapping[tiedOperand] = opResult;
+      continue;
+    }
+
+    // Otherwise, try to bufferize inplace with one of the inputs.
+    OpOperand *chosenOperand = nullptr;
+    for (OpOperand *opOperand : op.getInputTensorOperands()) {
+      if (opOperand->get().getType() != opResult.getType())
+        continue;
+      if (!op.payloadUsesValueFromOperand(opOperand))
+        continue;
+      if (op.getTiedIndexingMap(opOperand) != outputIndexingMap)
+        continue;
+      // No other OpResult may already alias with this OpOperand.
+      if (mapping.count(opOperand))
+        continue;
+      assert(op.getTiedIndexingMap(opOperand).isProjectedPermutation() &&
+             "expected projected permutation");
+      chosenOperand = opOperand;
+      break;
+    }
+
+    // No suitable input tensor found. Use output tensor.
+    // TODO: This operand could bufferize inplace with OpOperands that have the
+    // correct type, even if they are not used inside the computation.
+    if (!chosenOperand)
+      chosenOperand = tiedOperand;
+
+    mapping[chosenOperand] = opResult;
+  }
+  return mapping;
+}
+
 template <typename OpTy>
 struct LinalgOpInterface
     : public BufferizableOpInterface::ExternalModel<LinalgOpInterface<OpTy>,
                                                     OpTy> {
   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
     auto genericOp = cast<linalg::LinalgOp>(op);
-    return (genericOp.isInputTensor(&opOperand) ||
-            genericOp.isInitTensor(&opOperand)) &&
-           genericOp.payloadUsesValueFromOperand(&opOperand);
+    return genericOp.payloadUsesValueFromOperand(&opOperand);
   }
 
   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
-    auto genericOp = cast<linalg::LinalgOp>(op);
-    return genericOp.isOutputTensor(&opOperand);
+    auto bufferizableOp = cast<BufferizableOpInterface>(op);
+    return static_cast<bool>(bufferizableOp.getAliasingOpResult(opOperand));
   }
 
   SmallVector<OpOperand *> getAliasingOpOperand(Operation *op,
                                                 OpResult opResult) const {
     auto genericOp = cast<linalg::LinalgOp>(op);
-    return {genericOp.getOutputTensorOperands()[opResult.getResultNumber()]};
+    DenseMap<OpOperand *, OpResult> pairs = computeAliasingPairs(genericOp);
+    for (OpOperand *opOperand : genericOp.getInputAndOutputOperands())
+      if (pairs[opOperand] == opResult)
+        return {opOperand};
+    return {};
   }
 
   OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
     auto genericOp = cast<linalg::LinalgOp>(op);
-    if (!opOperand.get().getType().isa<RankedTensorType>())
-      return OpResult();
-    // For now assume inputs are never inplaceable.
-    // TODO: refine this.
-    if (opOperand.getOperandNumber() < genericOp.getNumInputs())
-      return OpResult();
-    int64_t outputOperandIndex =
-        opOperand.getOperandNumber() - genericOp.getNumInputs();
-    int64_t numOutputBuffers = 0;
-    for (unsigned idx = 0; idx < outputOperandIndex; ++idx)
-      if (!genericOp.getOutputOperand(idx)->get().getType().isa<TensorType>())
-        ++numOutputBuffers;
-    return genericOp->getResult(outputOperandIndex - numOutputBuffers);
+    DenseMap<OpOperand *, OpResult> pairs = computeAliasingPairs(genericOp);
+    return pairs[&opOperand];
   }
 
   BufferRelation bufferRelation(Operation *op, OpResult opResult,

diff  --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
index 0e1e7602b2af..f2fa7ce3e4bf 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
@@ -940,7 +940,7 @@ func @linalg_op_same_out_tensors(
     %t2: tensor<?xf32> {linalg.inplaceable = true}) -> (tensor<?xf32>, tensor<?xf32>){
 
   //      CHECK: linalg.generic
-  // CHECK-SAME: {__inplace_results_attr__ = ["true", "false"]
+  // CHECK-SAME: {__inplace_results_attr__ = ["true", "true"]
   %o:2 = linalg.generic #trait ins(%t1 : tensor<?xf32>)
                                outs (%t2, %t2 : tensor<?xf32>, tensor<?xf32>) {
       ^bb(%0: f32, %1: f32, %2 : f32) :
@@ -948,12 +948,45 @@ func @linalg_op_same_out_tensors(
     } -> (tensor<?xf32>, tensor<?xf32>)
 
   //      CHECK: return
-  // CHECK-SAME: {__equivalent_func_args__ = [1, -1]}
+  // CHECK-SAME: {__equivalent_func_args__ = [0, 1]}
   return %o#0, %o#1 : tensor<?xf32>, tensor<?xf32>
 }
 
 // -----
 
+#accesses = [
+  affine_map<(i) -> (i)>,
+  affine_map<(i) -> (i)>,
+  affine_map<(i) -> (i)>,
+  affine_map<(i) -> (i)>
+]
+#trait = {
+  indexing_maps = #accesses,
+  iterator_types = ["parallel"]
+}
+
+// CHECK-LABEL: func @linalg_op_same_out_tensors_2
+func @linalg_op_same_out_tensors_2(
+    %t1: tensor<?xf32> {linalg.inplaceable = true},
+    %t2: tensor<?xf32> {linalg.inplaceable = true})
+        -> (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>){
+
+  //      CHECK: linalg.generic
+  // CHECK-SAME: {__inplace_results_attr__ = ["true", "true", "false"]
+  %o:3 = linalg.generic #trait
+          ins(%t1 : tensor<?xf32>)
+          outs (%t2, %t2, %t2 : tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) {
+      ^bb(%0: f32, %1: f32, %2 : f32, %3 : f32) :
+        linalg.yield %0, %0, %0 : f32, f32, f32
+    } -> (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>)
+
+  //      CHECK: return
+  // CHECK-SAME: {__equivalent_func_args__ = [0, 1, -1]}
+  return %o#0, %o#1, %o#2 : tensor<?xf32>, tensor<?xf32>, tensor<?xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @double_insert_slice_into_alias
 func @double_insert_slice_into_alias(
     %v1: vector<32x90xf32>,

diff  --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
index bd5f73f29df7..79fde5f26256 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
@@ -1036,3 +1036,75 @@ func @gather_like(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?xi32>,
 //  CHECK-SAME:       outs(%[[ARG2]] :
 //       CHECK:     %[[YIELD:.+]] = memref.load %[[ARG0]]
 //       CHECK:     linalg.yield %[[YIELD]]
+
+// -----
+
+// CHECK-LABEL: func @linalg_op_bufferizes_inplace_with_input
+//  CHECK-SAME:     %[[t1:.*]]: memref<?x?xf32, #{{.*}}>, %[[t2:.*]]: memref<?xf32, #{{.*}}>, %[[t3:.*]]: memref<?x?xf32, #{{.*}}>
+func @linalg_op_bufferizes_inplace_with_input(
+    %t1: tensor<?x?xf32> {linalg.inplaceable = true},
+    %t2: tensor<?xf32>, %t3: tensor<?x?xf32>,
+    %s1: index, %s2: index, %cst: f32) -> tensor<?x?xf32> {
+  // CHECK: linalg.generic {{.*}} ins(%[[t1]], %[[t2]] : {{.*}}) outs(%[[t1]] : {{.*}})
+  %r = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                     affine_map<(d0, d1) -> (d1)>,
+                     affine_map<(d0, d1)-> (d0, d1)>],
+    iterator_types = ["parallel", "parallel"]}
+    ins(%t1, %t2 : tensor<?x?xf32>, tensor<?xf32>)
+    outs(%t3 : tensor<?x?xf32>) {
+      ^bb0(%arg0 : f32, %arg1 : f32, %arg2 : f32) :
+        %add = arith.addf %arg0, %arg1 : f32
+        linalg.yield %add : f32
+    } -> tensor<?x?xf32>
+  return %r : tensor<?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @linalg_op_bufferizes_out_of_place_with_input
+//  CHECK-SAME:     %[[t1:.*]]: memref<?x?xf32, #{{.*}}>, %[[t2:.*]]: memref<?xf32, #{{.*}}>, %[[t3:.*]]: memref<?x?xf32, #{{.*}}>
+func @linalg_op_bufferizes_out_of_place_with_input(
+    %t1: tensor<?x?xf32>, %t2: tensor<?xf32>, %t3: tensor<?x?xf32>,
+    %s1: index, %s2: index, %cst: f32) -> tensor<?x?xf32> {
+  // CHECK: %[[alloc:.*]] = memref.alloc
+  // CHECK: linalg.copy(%[[t1]], %[[alloc]])
+  // CHECK: linalg.generic {{.*}} ins(%[[t1]], %[[t2]] : {{.*}}) outs(%[[alloc]] : {{.*}})
+  %r = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                     affine_map<(d0, d1) -> (d1)>,
+                     affine_map<(d0, d1)-> (d0, d1)>],
+    iterator_types = ["parallel", "parallel"]}
+    ins(%t1, %t2 : tensor<?x?xf32>, tensor<?xf32>)
+    outs(%t3 : tensor<?x?xf32>) {
+      ^bb0(%arg0 : f32, %arg1 : f32, %arg2 : f32) :
+        %add = arith.addf %arg0, %arg1 : f32
+        linalg.yield %add : f32
+    } -> tensor<?x?xf32>
+  // CHECK: return %[[alloc]]
+  return %r : tensor<?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @linalg_op_output_cannot_alias_with_input
+//  CHECK-SAME:     %[[t1:.*]]: memref<?x?xf32, #{{.*}}>, %[[t2:.*]]: memref<?xf32, #{{.*}}>, %[[t3:.*]]: memref<?x?xf32, #{{.*}}>
+func @linalg_op_output_cannot_alias_with_input(
+    %t1: tensor<?x?xf32> {linalg.inplaceable = true},
+    %t2: tensor<?xf32>, %t3: tensor<?x?xf32> {linalg.inplaceable = true},
+    %s1: index, %s2: index, %cst: f32) -> tensor<?x?xf32> {
+  // CHECK: linalg.generic {{.*}} ins(%[[t1]], %[[t2]] : {{.*}}) outs(%[[t3]] : {{.*}})
+  %r = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>,
+                     affine_map<(d0, d1) -> (d1)>,
+                     affine_map<(d0, d1)-> (d0, d1)>],
+    iterator_types = ["parallel", "parallel"]}
+    ins(%t1, %t2 : tensor<?x?xf32>, tensor<?xf32>)
+    outs(%t3 : tensor<?x?xf32>) {
+      ^bb0(%arg0 : f32, %arg1 : f32, %arg2 : f32) :
+        %add = arith.addf %arg0, %arg1 : f32
+        linalg.yield %add : f32
+    } -> tensor<?x?xf32>
+  return %r : tensor<?x?xf32>
+}
+


        


More information about the Mlir-commits mailing list