[Mlir-commits] [mlir] 6f90955 - [mlir][Linalg] Add support for subtensor_insert comprehensive bufferization (3/n)

Nicolas Vasilache llvmlistbot at llvm.org
Fri May 14 14:54:35 PDT 2021


Author: Nicolas Vasilache
Date: 2021-05-14T21:51:00Z
New Revision: 6f90955f6949397e20d359e4b914c2d46b35f863

URL: https://github.com/llvm/llvm-project/commit/6f90955f6949397e20d359e4b914c2d46b35f863
DIFF: https://github.com/llvm/llvm-project/commit/6f90955f6949397e20d359e4b914c2d46b35f863.diff

LOG: [mlir][Linalg] Add support for subtensor_insert comprehensive bufferization (3/n)

Differential revision: https://reviews.llvm.org/D102417
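
In short, subtensor_insert now bufferizes to a memref.subview of the
destination buffer followed by a linalg.copy of the source into that subview;
when the destination does not bufferize inplace, the destination buffer is
first cloned into a fresh allocation. The sketch below mirrors the first new
test in this patch; the standalone "after" function, its name, and its
identity layouts are illustrative only, not verbatim pass output.

    // Before bufferization: insert a tensor<4xf32> into an inplaceable tensor.
    func @subtensor_insert(%A : tensor<?xf32> {linalg.inplaceable = true},
                           %t : tensor<4xf32>) -> tensor<?xf32> {
      %r = subtensor_insert %t into %A[0][4][1]
        : tensor<4xf32> into tensor<?xf32>
      return %r : tensor<?xf32>
    }

    // After bufferization (inplace destination, layouts simplified):
    // no allocation, just a copy into a subview of the destination buffer.
    func @subtensor_insert_bufferized(%A : memref<?xf32>, %t : memref<4xf32>) {
      %sv = memref.subview %A[0] [4] [1] : memref<?xf32> to memref<4xf32>
      linalg.copy(%t, %sv) : memref<4xf32>, memref<4xf32>
      return
    }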

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
    mlir/test/Dialect/Linalg/comprehensive-func-bufferize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
index 758cde2880312..0f271c2f27983 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
@@ -128,6 +128,8 @@ OpResult getMatchingOpResult(LinalgOp linalgOp, OpOperand &opOperand) {
   return linalgOp->getResult(outputOperandIndex - numOutputBuffers);
 }
 
+/// Return the OpResult that matches an operand.
+/// Return null if no such result exists.
 OpResult getMatchingOpResult(VectorTransferOpInterface op,
                              OpOperand &opOperand) {
   if (opOperand.get() != op.source() ||
@@ -136,17 +138,25 @@ OpResult getMatchingOpResult(VectorTransferOpInterface op,
   return op->getResult(0);
 }
 
+/// Return the OpResult that matches an operand.
+/// Return null if no such result exists.
+OpResult getMatchingOpResult(SubTensorInsertOp op, OpOperand &opOperand) {
+  if (opOperand.get() != op.dest())
+    return OpResult();
+  return op->getResult(0);
+}
+
 /// Determine which results may be reused inplace by the bufferization
 /// patterns of `bufferizeFuncOpInternals`.
 /// The inplace analysis uses this information along with interfering read
 /// analysis to determine which op results reuse the same buffer as some
 /// operand.
 OpResult getMatchingOpResult(OpOperand &opOperand) {
-  OpResult res = llvm::TypeSwitch<Operation *, OpResult>(opOperand.getOwner())
-                     .Case<LinalgOp, VectorTransferOpInterface>([&](auto op) {
-                       return getMatchingOpResult(op, opOperand);
-                     })
-                     .Default([&](Operation *op) { return OpResult(); });
+  OpResult res =
+      llvm::TypeSwitch<Operation *, OpResult>(opOperand.getOwner())
+          .Case<LinalgOp, SubTensorInsertOp, VectorTransferOpInterface>(
+              [&](auto op) { return getMatchingOpResult(op, opOperand); })
+          .Default([&](Operation *op) { return OpResult(); });
   return res;
 }
 
@@ -644,8 +654,8 @@ static void finalizeBufferAllocation(OpBuilder &b, LinalgOp op,
 
 /// Generic conversion for any LinalgOp.
 /// Operate on mixed tensor + buffer Linalg ops for progressive bufferization.
-static LogicalResult convertAnyLinalgOp(OpBuilder &b, LinalgOp op,
-                                        BlockAndValueMapping &bvm) {
+static LogicalResult bufferize(OpBuilder &b, LinalgOp op,
+                               BlockAndValueMapping &bvm) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
 
@@ -668,16 +678,16 @@ static LogicalResult convertAnyLinalgOp(OpBuilder &b, LinalgOp op,
 
 /// DimOp tensor operand is modified inplace. This allows leaving dead tensors
 /// behind that will get DCE'd.
-static LogicalResult convertDimOp(OpBuilder &b, memref::DimOp dimOp,
-                                  BlockAndValueMapping &bvm) {
+static LogicalResult bufferize(OpBuilder &b, memref::DimOp dimOp,
+                               BlockAndValueMapping &bvm) {
   if (dimOp.memrefOrTensor().getType().isa<RankedTensorType>())
     dimOp.memrefOrTensorMutable().assign(lookup(bvm, dimOp.memrefOrTensor()));
   return success();
 }
 
 /// FuncOp always creates TensorToMemRef ops.
-static LogicalResult convertFuncOp(OpBuilder &b, FuncOp funcOp,
-                                   BlockAndValueMapping &bvm) {
+static LogicalResult bufferize(OpBuilder &b, FuncOp funcOp,
+                               BlockAndValueMapping &bvm) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
   b.setInsertionPointToStart(&funcOp.body().front());
@@ -699,8 +709,8 @@ static LogicalResult convertFuncOp(OpBuilder &b, FuncOp funcOp,
 }
 
 /// ReturnOp always creates memref::TensorLoadOp.
-static LogicalResult convertReturnOp(OpBuilder &b, ReturnOp returnOp,
-                                     BlockAndValueMapping &bvm) {
+static LogicalResult bufferize(OpBuilder &b, ReturnOp returnOp,
+                               BlockAndValueMapping &bvm) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
   b.setInsertionPoint(returnOp);
@@ -717,9 +727,69 @@ static LogicalResult convertReturnOp(OpBuilder &b, ReturnOp returnOp,
   return success();
 }
 
-static LogicalResult convertTransferOp(OpBuilder &b,
-                                       VectorTransferOpInterface op,
-                                       BlockAndValueMapping &bvm) {
+static LogicalResult bufferize(OpBuilder &b,
+                               SubTensorInsertOp subTensorInsertOp,
+                               BlockAndValueMapping &bvm) {
+  LLVM_DEBUG(DBGS() << "bufferize: " << *subTensorInsertOp << "\n");
+
+  // Take a guard before anything else.
+  OpBuilder::InsertionGuard g(b);
+  b.setInsertionPoint(subTensorInsertOp);
+  Location loc = subTensorInsertOp.getLoc();
+
+  Value dstMemref = lookup(bvm, subTensorInsertOp.dest());
+  auto inPlace = getInPlace(subTensorInsertOp->getResult(0));
+  if (inPlace != InPlaceSpec::True) {
+    // Since subtensor_insert ops arise from tiling and the introduction of
+    // loops, this case is generally a deal breaker. When used with loops, it
+    // ends up cloning the whole tensor on every single iteration and is a
+    // symptom of a catastrophically bad scheduling decision.
+    // TODO: be very loud about it or even consider failing the pass.
+    Value newDstMemref = createNewAllocDeallocPairForShapedValue(
+        b, loc, subTensorInsertOp.result());
+    b.setInsertionPointAfter(newDstMemref.getDefiningOp());
+    b.create<CopyOp>(subTensorInsertOp.getLoc(), dstMemref, newDstMemref);
+    dstMemref = newDstMemref;
+  }
+  auto dstMemrefType = dstMemref.getType().cast<MemRefType>();
+
+  Value srcMemref = lookup(bvm, subTensorInsertOp.source());
+  auto subviewMemRefType =
+      memref::SubViewOp::inferRankReducedResultType(
+          subTensorInsertOp.getSourceType().getRank(), dstMemrefType,
+          subTensorInsertOp.getMixedOffsets(),
+          subTensorInsertOp.getMixedSizes(),
+          subTensorInsertOp.getMixedStrides())
+          .cast<MemRefType>();
+
+  // A copy of the source buffer is needed if either:
+  //   - The producer of `source` is not inplace. This is the case where a
+  //     subtensor is computed out of place into the inplace full tensor.
+  //   - The result is not inplace. This is the case where the whole tensor is
+  //     cloned and the clone needs to be updated.
+  Value source = subTensorInsertOp.source();
+  InPlaceSpec inPlaceProducer = InPlaceSpec::None;
+  if (auto opResult = source.dyn_cast<OpResult>())
+    inPlaceProducer = getInPlace(opResult);
+  else
+    inPlaceProducer = getInPlace(source.cast<BlockArgument>());
+  if (inPlaceProducer != InPlaceSpec::True) {
+    LLVM_DEBUG(DBGS() << "subtensor_insert needs extra source copy: " << source
+                      << " -> copy\n");
+    // Take a subview of the dst.
+    Value subView = b.create<memref::SubViewOp>(
+        loc, subviewMemRefType, dstMemref, subTensorInsertOp.getMixedOffsets(),
+        subTensorInsertOp.getMixedSizes(), subTensorInsertOp.getMixedStrides());
+    b.create<CopyOp>(subTensorInsertOp.getLoc(), srcMemref, subView);
+  }
+
+  map(bvm, subTensorInsertOp.result(), dstMemref);
+
+  return success();
+}
+
+static LogicalResult bufferize(OpBuilder &b, VectorTransferOpInterface op,
+                               BlockAndValueMapping &bvm) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
   b.setInsertionPoint(op);
@@ -730,7 +800,7 @@ static LogicalResult convertTransferOp(OpBuilder &b,
 
   LLVM_DEBUG(DBGS() << "convert: " << *op << "\n");
 
-  /// transfer_read from buffer
+  /// transfer_read from buffer always reads from the bufferized op.source().
   if (auto readOp = dyn_cast<vector::TransferReadOp>(op.getOperation())) {
     readOp.sourceMutable().assign(lookup(bvm, op.source()));
     return success();
@@ -778,8 +848,8 @@ static LogicalResult bufferizeFuncOpInternals(
     FuncOp funcOp, BlockAndValueMapping &bvm,
     const DenseMap<FuncOp, SmallVector<int64_t>> &tiedResultsMap) {
   OpBuilder b(funcOp->getContext());
-  /// Start by converting `funcOp` arguments.
-  if (failed(convertFuncOp(b, funcOp, bvm)))
+  /// Start by bufferizing `funcOp` arguments.
+  if (failed(bufferize(b, funcOp, bvm)))
     return failure();
   WalkResult result = funcOp.walk<WalkOrder::PreOrder>([&](Operation *op) {
     LogicalResult status =
@@ -787,12 +857,9 @@ static LogicalResult bufferizeFuncOpInternals(
             // Skip BufferCast and TensorLoad ops.
             .Case<memref::BufferCastOp, memref::TensorLoadOp>(
                 [&](auto) { return success(); })
-            .Case([&](memref::DimOp op) { return convertDimOp(b, op, bvm); })
-            .Case([&](LinalgOp op) { return convertAnyLinalgOp(b, op, bvm); })
-            .Case([&](ReturnOp op) { return convertReturnOp(b, op, bvm); })
-            .Case([&](VectorTransferOpInterface op) {
-              return convertTransferOp(b, op, bvm);
-            })
+            .Case<memref::DimOp, LinalgOp, ReturnOp, SubTensorInsertOp,
+                  VectorTransferOpInterface>(
+                [&](auto op) { return bufferize(b, op, bvm); })
             .Default([&](Operation *op) {
               auto isaTensor = [](Type t) { return t.isa<TensorType>(); };
               if (llvm::any_of(op->getOperandTypes(), isaTensor) ||

diff  --git a/mlir/test/Dialect/Linalg/comprehensive-func-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-func-bufferize.mlir
index a5636329fd3a0..19599b99866bf 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-func-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-func-bufferize.mlir
@@ -96,14 +96,13 @@ func @vec_inplace(%A : tensor<?xf32> {linalg.inplaceable = true}, %vec : vector<
 // -----
 
 // CHECK-LABEL: func @vec_not_inplace
-//  CHECK-SAME:   %[[A:[a-zA-Z0-9]*]]: tensor<?xf32> {linalg.inplaceable = true}
 func @vec_not_inplace(%A : tensor<?xf32> {linalg.inplaceable = true}, %vec : vector<4xf32>)
     -> (tensor<?xf32>, tensor<?xf32>)
 {
   %c0 = constant 0 : index
   %c1 = constant 1 : index
 
-  //       CHECK: %[[BUFFER_CAST:.*]] = memref.buffer_cast %[[A]] : memref<?xf32, #[[$map_2d_dyn]]>
+  //       CHECK: %[[BUFFER_CAST:.*]] = memref.buffer_cast {{.*}} : memref<?xf32, #[[$map_2d_dyn]]>
 
   /// Cross-op multiple uses of %A, the first vector.transfer which has interfering reads must alloc.
   //      CHECK: %[[ALLOC:.*]] = memref.alloc
@@ -117,3 +116,105 @@ func @vec_not_inplace(%A : tensor<?xf32> {linalg.inplaceable = true}, %vec : vec
   return %r0, %r1: tensor<?xf32>, tensor<?xf32>
 }
 
+// -----
+
+// CHECK-LABEL: func @subtensor_insert_fun
+func @subtensor_insert_fun(%A : tensor<?xf32> {linalg.inplaceable = true}, %t : tensor<4xf32>)
+  ->  tensor<?xf32>
+{
+  //      CHECK: %[[BUFFER_CAST_A:.*]] = memref.buffer_cast {{.*}} : memref<?xf32
+  //      CHECK: %[[BUFFER_CAST_B:.*]] = memref.buffer_cast {{.*}} : memref<4xf32
+
+  //  CHECK-NOT: alloc
+  //      CHECK: %[[SV:.*]] = memref.subview %[[BUFFER_CAST_A]]
+  //      CHECK: linalg.copy(%[[BUFFER_CAST_B]], %[[SV]])
+  %r0 = subtensor_insert %t into %A[0][4][1] : tensor<4xf32> into tensor<?xf32>
+  return %r0: tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @subtensor_insert_fun
+func @subtensor_insert_fun(%A : tensor<?xf32> {linalg.inplaceable = true}, %t : tensor<4xf32>)
+  -> tensor<?xf32>
+{
+  %f0 = constant 0.0 : f32
+
+  //      CHECK: %[[BUFFER_CAST_A:.*]] = memref.buffer_cast {{.*}} : memref<?xf32
+  //      CHECK: %[[BUFFER_CAST_B:.*]] = memref.buffer_cast {{.*}} : memref<4xf32
+
+  //  CHECK-NOT: alloc
+  //      CHECK: %[[SV:.*]] = memref.subview %[[BUFFER_CAST_A]]
+  //      CHECK: linalg.copy(%[[BUFFER_CAST_B]], %[[SV]])
+  %r0 = subtensor_insert %t into %A[0][4][1] : tensor<4xf32> into tensor<?xf32>
+
+  /// Overwrite BUFFER_CAST_A inplace.
+  //      CHECK: linalg.fill(%[[BUFFER_CAST_A]]
+  %r1 = linalg.fill(%r0, %f0) : tensor<?xf32>, f32 -> tensor<?xf32>
+  return %r1: tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @subtensor_insert_fun
+func @subtensor_insert_fun(%A : tensor<?xf32> {linalg.inplaceable = true}, %t : tensor<4xf32>)
+  -> tensor<?xf32>
+{
+  %f0 = constant 0.0 : f32
+
+  //      CHECK: %[[BUFFER_CAST_A:.*]] = memref.buffer_cast {{.*}} : memref<?xf32
+  //      CHECK: %[[BUFFER_CAST_B:.*]] = memref.buffer_cast {{.*}} : memref<4xf32
+
+  //      CHECK: linalg.fill(%[[BUFFER_CAST_A]]
+  %r0 = linalg.fill(%A, %f0) : tensor<?xf32>, f32 -> tensor<?xf32>
+
+  //  CHECK-NOT: alloc
+  //      CHECK: %[[SV:.*]] = memref.subview %[[BUFFER_CAST_A]]
+  /// Overwrite BUFFER_CAST_A inplace by copying into the subview.
+  //      CHECK: linalg.copy(%[[BUFFER_CAST_B]], %[[SV]])
+  %r1 = subtensor_insert %t into %r0[0][4][1] : tensor<4xf32> into tensor<?xf32>
+
+  return %r1: tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @subtensor_insert_fun_not_inplace
+func @subtensor_insert_fun_not_inplace(%A : tensor<?xf32>, %t : tensor<4xf32>)
+  -> tensor<?xf32>
+{
+  //      CHECK: %[[BUFFER_CAST_A:.*]] = memref.buffer_cast {{.*}} : memref<?xf32
+  //      CHECK: %[[BUFFER_CAST_B:.*]] = memref.buffer_cast {{.*}} : memref<4xf32
+
+  //      CHECK: %[[ALLOC:.*]] = memref.alloc(%{{.*}}) : memref<?xf32>
+  //      CHECK: linalg.copy(%[[BUFFER_CAST_A]], %[[ALLOC]]) : memref<?xf32{{.*}}, memref<?xf32>
+  //      CHECK: %[[SV:.*]] = memref.subview %[[ALLOC]][0] [4] [1] : memref<?xf32> to memref<4xf32>
+  //      CHECK: linalg.copy(%[[BUFFER_CAST_B]], %[[SV]]) : memref<4xf32, #map>, memref<4xf32>
+  //      CHECK: memref.dealloc %[[ALLOC]] : memref<?xf32>
+  %r0 = subtensor_insert %t into %A[0][4][1] : tensor<4xf32> into tensor<?xf32>
+  return %r0: tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @subtensor_insert_fun_not_inplace
+func @subtensor_insert_fun_not_inplace(%A : tensor<?xf32> {linalg.inplaceable = true}, %t : tensor<4xf32>)
+  -> (tensor<?xf32>, tensor<?xf32>)
+{
+  %f0 = constant 0.0 : f32
+
+  //      CHECK: %[[BUFFER_CAST_A:.*]] = memref.buffer_cast {{.*}} : memref<?xf32
+  //      CHECK: %[[BUFFER_CAST_B:.*]] = memref.buffer_cast {{.*}} : memref<4xf32
+
+  //      CHECK: %[[ALLOC:.*]] = memref.alloc(%{{.*}}) : memref<?xf32>
+  //      CHECK: linalg.copy(%[[BUFFER_CAST_A]], %[[ALLOC]]) : memref<?xf32{{.*}}, memref<?xf32>
+  //      CHECK: %[[SV:.*]] = memref.subview %[[ALLOC]][0] [4] [1] : memref<?xf32> to memref<4xf32>
+  //      CHECK: linalg.copy(%[[BUFFER_CAST_B]], %[[SV]]) : memref<4xf32, #map>, memref<4xf32>
+  %r0 = subtensor_insert %t into %A[0][4][1] : tensor<4xf32> into tensor<?xf32>
+
+  // TODO: WAW optimization where result is overwritten without being read.
+  //      CHECK: linalg.fill(%[[BUFFER_CAST_A]]
+  //      CHECK: memref.dealloc %[[ALLOC]] : memref<?xf32>
+  %r1 = linalg.fill(%A, %f0) : tensor<?xf32>, f32 -> tensor<?xf32>
+  return %r0, %r1: tensor<?xf32>, tensor<?xf32>
+}
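
The last two tests exercise the out-of-place path: when the destination of
subtensor_insert does not bufferize inplace, the pass clones the destination
buffer into a new allocation, copies the source into a subview of that clone,
and deallocates the clone once it is no longer needed. A minimal sketch of
that buffer-level structure follows (hypothetical function and names, identity
layouts assumed; the real output carries the strided layouts of the
buffer_cast results and consumes the clone before deallocating it):

    func @subtensor_insert_out_of_place_sketch(%A : memref<?xf32>,
                                               %t : memref<4xf32>,
                                               %d : index) {
      // Clone the destination because it may not be overwritten inplace.
      %alloc = memref.alloc(%d) : memref<?xf32>
      linalg.copy(%A, %alloc) : memref<?xf32>, memref<?xf32>
      // Copy the inserted source into a subview of the clone.
      %sv = memref.subview %alloc[0] [4] [1] : memref<?xf32> to memref<4xf32>
      linalg.copy(%t, %sv) : memref<4xf32>, memref<4xf32>
      // Shown only to mirror the tests' dealloc of the temporary clone.
      memref.dealloc %alloc : memref<?xf32>
      return
    }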
