[Mlir-commits] [mlir] 17194ca - [mlir][linalg][bufferize][NFC] Clean up tensor op bufferization

Matthias Springer llvmlistbot at llvm.org
Mon Nov 15 18:18:13 PST 2021


Author: Matthias Springer
Date: 2021-11-16T11:17:42+09:00
New Revision: 17194ca96ab5a46ee90c656f7654d3f15f5d46c6

URL: https://github.com/llvm/llvm-project/commit/17194ca96ab5a46ee90c656f7654d3f15f5d46c6
DIFF: https://github.com/llvm/llvm-project/commit/17194ca96ab5a46ee90c656f7654d3f15f5d46c6.diff

LOG: [mlir][linalg][bufferize][NFC] Clean up tensor op bufferization

Differential Revision: https://reviews.llvm.org/D113730

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
index 6482d740f1515..697b894f89908 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
@@ -2244,23 +2244,22 @@ struct ExtractSliceOpInterface
   LogicalResult bufferize(Operation *op, OpBuilder &b,
                           BufferizationState &state) const {
     auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
+    LDBG("bufferize: " << *extractSliceOp << '\n');
 
     // Take a guard before anything else.
     OpBuilder::InsertionGuard g(b);
     b.setInsertionPoint(extractSliceOp);
 
-    LDBG("bufferize: " << *extractSliceOp << '\n');
-
     Location loc = extractSliceOp.getLoc();
-    // Bail if source was not bufferized.
     Value srcMemref = state.lookupBuffer(extractSliceOp.source());
     auto srcMemrefType = srcMemref.getType().cast<MemRefType>();
     auto dstTensorType =
         extractSliceOp.result().getType().cast<RankedTensorType>();
 
     // If not inplaceable, alloc.
+    bool inplace = state.aliasInfo.isInPlace(extractSliceOp->getResult(0));
     Value alloc;
-    if (!state.aliasInfo.isInPlace(extractSliceOp->getResult(0)))
+    if (!inplace)
       alloc = createNewAllocDeallocPairForShapedValue(
           b, loc, extractSliceOp.result(), state);
 
@@ -2278,7 +2277,7 @@ struct ExtractSliceOpInterface
     state.aliasInfo.insertNewBufferAlias(subView, srcMemref);
 
     /// If not inplaceable, copy.
-    if (alloc) {
+    if (!inplace) {
       // Do not copy if the copied data is never read.
       if (isValueRead(extractSliceOp.result()))
         state.allocationFns.memCpyFn(b, extractSliceOp.getLoc(), subView,
@@ -2374,34 +2373,23 @@ struct InsertSliceOpInterface
 
   LogicalResult bufferize(Operation *op, OpBuilder &b,
                           BufferizationState &state) const {
+    // insert_slice ops arise from tiling and bufferizing them out-of-place is
+    // generally a deal breaker. When used with loops, this ends up cloning the
+    // whole tensor on every single iteration and is a symptom of a
+    // catastrophically bad scheduling decision.
+    // TODO: be very loud about it or even consider failing the pass.
     auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
+    LDBG("bufferize: " << *insertSliceOp << '\n');
 
     // Take a guard before anything else.
     OpBuilder::InsertionGuard g(b);
     b.setInsertionPoint(insertSliceOp);
-
-    LDBG("bufferize: " << *insertSliceOp << '\n');
-
     Location loc = insertSliceOp.getLoc();
-    // Since insert_slice arise from tiling and introducing loops, this
-    // case is generally a deal breaker. When used with loops, this ends up
-    // cloning the whole tensor on every single iteration and is a symptom
-    // of a catastrophically bad scheduling decision.
-    // TODO: be very loud about it or even consider failing the pass.
-    // Alloc a copy for `insertSliceOp.dest()`, it will become the result
-    // buffer.
+
+    // When bufferizing out-of-place, `getResultBuffer` allocates.
     Value dstMemref = getResultBuffer(b, insertSliceOp->getResult(0), state);
     if (!dstMemref)
       return failure();
-    auto dstMemrefType = dstMemref.getType().cast<MemRefType>();
-
-    Value srcMemref = state.lookupBuffer(insertSliceOp.source());
-    auto subviewMemRefType =
-        memref::SubViewOp::inferRankReducedResultType(
-            insertSliceOp.getSourceType().getRank(), dstMemrefType,
-            insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
-            insertSliceOp.getMixedStrides())
-            .cast<MemRefType>();
 
     // A copy of the source buffer is needed if either:
     //   - The producer of `source` is not inplace. This is the case where a
@@ -2409,23 +2397,32 @@ struct InsertSliceOpInterface
     //   - The result is not inplace. This is the case where the whole tensor is
     //     cloned and the clone needs to be updated.
     // TODO: Is this necessary?
-    if (!isSourceEquivalentToAMatchingInplaceExtractSliceOp(state.aliasInfo,
-                                                            insertSliceOp) ||
-        !state.aliasInfo.isInPlace(insertSliceOp->getResult(0))) {
+    bool needCopy = !isSourceEquivalentToAMatchingInplaceExtractSliceOp(
+                        state.aliasInfo, insertSliceOp) ||
+                    !state.aliasInfo.isInPlace(insertSliceOp->getResult(0));
+    if (needCopy) {
       LDBG("insert_slice needs extra source copy: " << insertSliceOp.source()
                                                     << " -> copy\n");
       // Take a subview of the dst.
+      auto dstMemrefType = dstMemref.getType().cast<MemRefType>();
+      auto subviewMemRefType =
+          memref::SubViewOp::inferRankReducedResultType(
+              insertSliceOp.getSourceType().getRank(), dstMemrefType,
+              insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
+              insertSliceOp.getMixedStrides())
+              .cast<MemRefType>();
       Value subView = b.create<memref::SubViewOp>(
           loc, subviewMemRefType, dstMemref, insertSliceOp.getMixedOffsets(),
           insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
       // Insert new alias.
       state.aliasInfo.insertNewBufferAlias(subView, dstMemref);
+      // Copy tensor.
+      Value srcMemref = state.lookupBuffer(insertSliceOp.source());
       state.allocationFns.memCpyFn(b, insertSliceOp.getLoc(), srcMemref,
                                    subView);
     }
 
     state.mapBuffer(insertSliceOp.result(), dstMemref);
-
     return success();
   }
 };


        


More information about the Mlir-commits mailing list