[Mlir-commits] [mlir] 05e0495 - [mlir][bufferize][NFC] Deallocate all buffers at the end of bufferization

Matthias Springer llvmlistbot at llvm.org
Tue Mar 15 01:58:09 PDT 2022


Author: Matthias Springer
Date: 2022-03-15T17:53:53+09:00
New Revision: 05e0495f1d0c6386c8ee30df15d53a7cf4d6bfdc

URL: https://github.com/llvm/llvm-project/commit/05e0495f1d0c6386c8ee30df15d53a7cf4d6bfdc
DIFF: https://github.com/llvm/llvm-project/commit/05e0495f1d0c6386c8ee30df15d53a7cf4d6bfdc.diff

LOG: [mlir][bufferize][NFC] Deallocate all buffers at the end of bufferization

This makes bufferization more modular. This is in preparation for future refactorings.

Differential Revision: https://reviews.llvm.org/D121362

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
    mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
    mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
    mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
    mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
    mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index c7cfc7241a509..500f7e89b1758 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -415,13 +415,22 @@ struct BufferizationState {
   BufferizationState(const AnalysisState &analysisState)
       : analysisState(analysisState) {}
 
+  /// Creates a memref allocation with the given type and dynamic extents.
+  FailureOr<Value> createAlloc(OpBuilder &b, Location loc, MemRefType type,
+                               ValueRange dynShape);
+
+  /// Creates a memref allocation for the given shaped value. This function may
+  /// perform additional optimizations such as buffer allocation hoisting.
+  // TODO: Allocation hoisting should be a cleanup pass.
+  FailureOr<Value> createAlloc(OpBuilder &b, Location loc, Value shapedValue);
+
   /// Return the buffer (memref) for a given OpOperand (tensor). Allocate
   /// a new buffer and copy over data from the existing buffer if out-of-place
   /// bufferization was decided.
   FailureOr<Value>
   getBuffer(RewriterBase &rewriter, OpOperand &opOperand,
             bool forceInPlace = false,
-            Optional<Operation *> customCopyInsertionPoint = None) const;
+            Optional<Operation *> customCopyInsertionPoint = None);
 
   /// Return a reference to the BufferizationOptions.
   const BufferizationOptions &getOptions() const {
@@ -477,27 +486,6 @@ BaseMemRefType getMemRefType(TensorType tensorType,
                              MemRefLayoutAttrInterface layout = {},
                              Attribute memorySpace = {});
 
-/// Creates a memref allocation with the given type and dynamic extents.
-FailureOr<Value> createAlloc(OpBuilder &b, Location loc, MemRefType type,
-                             ValueRange dynShape,
-                             const BufferizationOptions &options);
-
-/// Creates a memref allocation with the given type and dynamic extents. If
-/// `createDealloc`, a deallocation op is inserted at the point where the
-/// allocation goes out of scope.
-FailureOr<Value> createAlloc(OpBuilder &b, Location loc, MemRefType type,
-                             ValueRange dynShape, bool deallocMemref,
-                             const BufferizationOptions &options);
-
-/// Creates a memref allocation for the given shaped value. This function may
-/// perform additional optimizations such as buffer allocation hoisting. If
-/// `createDealloc`, a deallocation op is inserted at the point where the
-/// allocation goes out of scope.
-// TODO: Allocation hoisting should be a cleanup pass.
-FailureOr<Value> createAlloc(OpBuilder &b, Location loc, Value shapedValue,
-                             bool deallocMemref,
-                             const BufferizationOptions &options);
-
 /// Creates a memref deallocation. The given memref buffer must have been
 /// allocated using `createAlloc`.
 LogicalResult createDealloc(OpBuilder &b, Location loc, Value allocatedBuffer,
@@ -507,6 +495,10 @@ LogicalResult createDealloc(OpBuilder &b, Location loc, Value allocatedBuffer,
 LogicalResult createMemCpy(OpBuilder &b, Location loc, Value from, Value to,
                            const BufferizationOptions &options);
 
+/// Finalize all buffer allocations, i.e., create alloc ops as specified in the
+/// bufferization options and deallocate all buffers.
+LogicalResult finalizeBuffers(Operation *op,
+                              const BufferizationOptions &options);
 } // namespace bufferization
 } // namespace mlir
 

diff  --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
index 13b48d09d33ff..85a70a42d283a 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
@@ -70,13 +70,6 @@ void populateEliminateBufferizeMaterializationsPatterns(
 // TODO: Extract `options` from `state` and pass as separate argument.
 LogicalResult bufferizeOp(Operation *op, const AnalysisState &analysisState);
 
-/// Bufferize `op` and its nested ops that implement `BufferizableOpInterface`.
-/// Reuse an existing `BufferizationState`.
-///
-/// Note: This function overload is useful for extending the bufferization.
-LogicalResult bufferizeOp(Operation *op,
-                          BufferizationState &bufferizationState);
-
 /// Bufferize `op` and its nested ops that implement `BufferizableOpInterface`.
 /// Buffers are duplicated and copied before any tensor use that bufferizes to
 /// a memory write.
@@ -87,6 +80,16 @@ LogicalResult bufferizeOp(Operation *op, const BufferizationOptions &options);
 
 BufferizationOptions getPartialBufferizationOptions();
 
+//===----------------------------------------------------------------------===//
+// Helper functions for extending Bufferization
+//===----------------------------------------------------------------------===//
+
+/// Bufferize `op` and its nested ops that implement `BufferizableOpInterface`.
+/// Reuse an existing `BufferizationState`.
+///
+/// Note: This function overload is useful for extending the bufferization.
+LogicalResult bufferizeOp(Operation *op,
+                          BufferizationState &bufferizationState);
 } // namespace bufferization
 } // namespace mlir
 

diff  --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index d2b8f1de5628f..8a3dbfc960e9b 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -42,6 +42,8 @@ constexpr const ::llvm::StringLiteral
 constexpr const ::llvm::StringLiteral
     bufferization::BufferizableOpInterface::kInplaceableAttrName;
 
+static const char *kBufferAllocationAttr = "bufferization.allocation";
+
 //===----------------------------------------------------------------------===//
 // BufferizationOptions
 //===----------------------------------------------------------------------===//
@@ -243,9 +245,10 @@ Value mlir::bufferization::lookupBuffer(RewriterBase &rewriter, Value tensor,
 /// Return the result buffer (memref) for a given OpResult (tensor). Allocate
 /// a new buffer and copy over data from the existing buffer if out-of-place
 /// bufferization is necessary.
-FailureOr<Value> BufferizationState::getBuffer(
-    RewriterBase &rewriter, OpOperand &opOperand, bool forceInPlace,
-    Optional<Operation *> customCopyInsertionPoint) const {
+FailureOr<Value>
+BufferizationState::getBuffer(RewriterBase &rewriter, OpOperand &opOperand,
+                              bool forceInPlace,
+                              Optional<Operation *> customCopyInsertionPoint) {
   const BufferizationOptions &options = analysisState.getOptions();
   OpBuilder::InsertionGuard guard(rewriter);
   Operation *op = opOperand.getOwner();
@@ -261,8 +264,7 @@ FailureOr<Value> BufferizationState::getBuffer(
   // allocation should be inserted (in the absence of allocation hoisting).
   setInsertionPointAfter(rewriter, operandBuffer);
   // Allocate the result buffer.
-  FailureOr<Value> resultBuffer = createAlloc(rewriter, loc, operandBuffer,
-                                              options.createDeallocs, options);
+  FailureOr<Value> resultBuffer = createAlloc(rewriter, loc, operandBuffer);
   if (failed(resultBuffer))
     return failure();
   // Do not copy if the last preceding writes of `operand` are ops that do
@@ -358,6 +360,33 @@ bool AlwaysCopyAnalysisState::areEquivalentBufferizedValues(Value v1,
 // Bufferization-specific scoped alloc/dealloc insertion support.
 //===----------------------------------------------------------------------===//
 
+/// Create a memref allocation with the given type and dynamic extents.
+static FailureOr<Value> createAlloc(OpBuilder &b, Location loc, MemRefType type,
+                                    ValueRange dynShape,
+                                    const BufferizationOptions &options) {
+  if (options.allocationFn)
+    return (*options.allocationFn)(b, loc, type, dynShape,
+                                   options.bufferAlignment);
+
+  // Default buffer allocation via AllocOp.
+  Value allocated = b.create<memref::AllocOp>(
+      loc, type, dynShape, b.getI64IntegerAttr(options.bufferAlignment));
+  return allocated;
+}
+
+/// Creates a memref deallocation. The given memref buffer must have been
+/// allocated using `createAlloc`.
+LogicalResult
+bufferization::createDealloc(OpBuilder &b, Location loc, Value allocatedBuffer,
+                             const BufferizationOptions &options) {
+  if (options.deallocationFn)
+    return (*options.deallocationFn)(b, loc, allocatedBuffer);
+
+  // Default buffer deallocation via DeallocOp.
+  b.create<memref::DeallocOp>(loc, allocatedBuffer);
+  return success();
+}
+
 /// Move the insertion point of the given builder to the beginning of a
 /// surrounding block as much as possible, while not crossing any allocation
 /// hoisting barriers.
@@ -436,92 +465,39 @@ static MemRefType getAllocationTypeAndShape(OpBuilder &b, Location loc,
   return allocMemRefType;
 }
 
-/// Create an AllocOp/DeallocOp pair, where the AllocOp is after
-/// `shapedValue.getDefiningOp` (or at the top of the block in case of a
-/// bbArg) and the DeallocOp is at the end of the block.
-FailureOr<Value>
-bufferization::createAlloc(OpBuilder &b, Location loc, Value shapedValue,
-                           bool deallocMemref,
-                           const BufferizationOptions &options) {
+static Value createBufferAllocation(OpBuilder &b, Location loc, MemRefType type,
+                                    ValueRange dynShape) {
+  auto allocaOp = b.create<memref::AllocaOp>(loc, type, dynShape);
+  allocaOp->setAttr(kBufferAllocationAttr, b.getUnitAttr());
+  return allocaOp.getResult();
+}
+
+/// Create an allocation after `shapedValue.getDefiningOp` (or at the top of the
+/// block in case of a bbArg).
+FailureOr<Value> BufferizationState::createAlloc(OpBuilder &b, Location loc,
+                                                 Value shapedValue) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
-
-  // 1. Create memory allocation.
   assert(shapedValue.getType().isa<ShapedType>());
   MemRefType memRefType = shapedValue.getType().dyn_cast<MemRefType>();
   SmallVector<Value> dynShape;
   // Note: getAllocationTypeAndShape also sets the insertion point.
   MemRefType allocMemRefType =
       getAllocationTypeAndShape(b, loc, shapedValue, dynShape);
-  FailureOr<Value> allocated =
-      createAlloc(b, loc, allocMemRefType, dynShape, options);
-  if (failed(allocated))
-    return failure();
-  Value casted = allocated.getValue();
+  Value alloc = createBufferAllocation(b, loc, allocMemRefType, dynShape);
   if (memRefType && memRefType != allocMemRefType) {
-    assert(memref::CastOp::areCastCompatible(allocated.getValue().getType(),
-                                             memRefType) &&
+    assert(memref::CastOp::areCastCompatible(alloc.getType(), memRefType) &&
            "createAlloc: cast incompatible");
-    casted = b.create<memref::CastOp>(loc, memRefType, allocated.getValue());
-  }
-
-  if (deallocMemref) {
-    // 2. Create memory deallocation.
-    b.setInsertionPoint(allocated.getValue().getParentBlock()->getTerminator());
-    if (failed(createDealloc(b, loc, allocated.getValue(), options)))
-      return failure();
-  }
-
-  return casted;
-}
-
-/// Create a memref allocation with the given type and dynamic extents.
-FailureOr<Value>
-bufferization::createAlloc(OpBuilder &b, Location loc, MemRefType type,
-                           ValueRange dynShape,
-                           const BufferizationOptions &options) {
-  if (options.allocationFn)
-    return (*options.allocationFn)(b, loc, type, dynShape,
-                                   options.bufferAlignment);
-
-  // Default bufferallocation via AllocOp.
-  Value allocated = b.create<memref::AllocOp>(
-      loc, type, dynShape, b.getI64IntegerAttr(options.bufferAlignment));
-  return allocated;
-}
-
-/// Create a memref allocation with the given type and dynamic extents. May also
-/// deallocate the memref again.
-FailureOr<Value>
-bufferization::createAlloc(OpBuilder &b, Location loc, MemRefType type,
-                           ValueRange dynShape, bool deallocMemref,
-                           const BufferizationOptions &options) {
-  OpBuilder::InsertionGuard g(b);
-
-  FailureOr<Value> alloc = createAlloc(b, loc, type, dynShape, options);
-  if (failed(alloc))
-    return failure();
-
-  if (deallocMemref) {
-    // Dealloc at the end of the block.
-    b.setInsertionPoint(alloc.getValue().getParentBlock()->getTerminator());
-    if (failed(createDealloc(b, loc, *alloc, options)))
-      return failure();
+    alloc = b.create<memref::CastOp>(loc, memRefType, alloc);
   }
-
   return alloc;
 }
 
-/// Create a memref deallocation.
-LogicalResult
-bufferization::createDealloc(OpBuilder &b, Location loc, Value allocatedBuffer,
-                             const BufferizationOptions &options) {
-  if (options.deallocationFn)
-    return (*options.deallocationFn)(b, loc, allocatedBuffer);
-
-  // Default buffer deallocation via DeallocOp.
-  b.create<memref::DeallocOp>(loc, allocatedBuffer);
-  return success();
+/// Create a memref allocation with the given type and dynamic extents.
+FailureOr<Value> BufferizationState::createAlloc(OpBuilder &b, Location loc,
+                                                 MemRefType type,
+                                                 ValueRange dynShape) {
+  return createBufferAllocation(b, loc, type, dynShape);
 }
 
 /// Create a memory copy between two memref buffers.
@@ -535,6 +511,41 @@ LogicalResult bufferization::createMemCpy(OpBuilder &b, Location loc,
   return success();
 }
 
+LogicalResult
+bufferization::finalizeBuffers(Operation *op,
+                               const BufferizationOptions &options) {
+  IRRewriter rewriter(op->getContext());
+
+  // Bufferization creates memref.alloca ops. After bufferization, these must be
+  // rewritten to alloc/dealloc ops as specified in the bufferization options.
+  WalkResult status = op->walk([&](memref::AllocaOp allocaOp) {
+    // Ignore memref.alloca ops that were not created by the bufferization.
+    if (!allocaOp->hasAttr(kBufferAllocationAttr))
+      return WalkResult::skip();
+
+    Block *block = allocaOp->getBlock();
+    rewriter.setInsertionPoint(allocaOp);
+    FailureOr<Value> alloc =
+        createAlloc(rewriter, allocaOp->getLoc(), allocaOp.getType(),
+                    allocaOp.dynamicSizes(), options);
+    if (failed(alloc))
+      return WalkResult::interrupt();
+    rewriter.replaceOp(allocaOp, *alloc);
+
+    // Stop here if deallocations are deactivated.
+    if (!options.createDeallocs)
+      return WalkResult::advance();
+
+    rewriter.setInsertionPoint(block->getTerminator());
+    if (failed(createDealloc(rewriter, alloc->getLoc(), *alloc, options)))
+      return WalkResult::interrupt();
+
+    return WalkResult::advance();
+  });
+
+  return success(!status.wasInterrupted());
+}
+
 //===----------------------------------------------------------------------===//
 // Bufferization-specific BlockAndValueMapping support with debugging.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index ba6cd37751119..f237cb7a6a70e 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -302,7 +302,11 @@ checkBufferizationResult(Operation *op, const BufferizationOptions &options) {
 LogicalResult bufferization::bufferizeOp(Operation *op,
                                          const AnalysisState &analysisState) {
   BufferizationState bufferizationState(analysisState);
-  return bufferizeOp(op, bufferizationState);
+  if (failed(bufferizeOp(op, bufferizationState)))
+    return failure();
+  if (failed(finalizeBuffers(op, analysisState.getOptions())))
+    return failure();
+  return success();
 }
 
 LogicalResult
@@ -332,7 +336,10 @@ bufferization::bufferizeOp(Operation *op,
   if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
     return failure();
 
-  return checkBufferizationResult(op, bufferizationState.getOptions());
+  if (failed(checkBufferizationResult(op, bufferizationState.getOptions())))
+    return failure();
+
+  return success();
 }
 
 namespace {

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
index 85c425623eacc..333c874761049 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
@@ -1054,6 +1054,10 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runModuleBufferize(
     }
   }
 
+  // Finalize all buffers.
+  if (failed(finalizeBuffers(moduleOp, options)))
+    return failure();
+
   // Perform a post-processing pass of layout modification at function boundary
   // according to the kBufferLayoutAttrName.
   layoutPostProcessing(moduleOp);

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
index 9153bce176fa0..2cbed7cebb1e9 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -235,9 +235,8 @@ struct InitTensorOpInterface
     if (initTensorOp->getUses().empty())
       return success();
 
-    FailureOr<Value> alloc =
-        createAlloc(rewriter, initTensorOp->getLoc(), initTensorOp.result(),
-                    state.getOptions().createDeallocs, state.getOptions());
+    FailureOr<Value> alloc = state.createAlloc(rewriter, initTensorOp->getLoc(),
+                                               initTensorOp.result());
     if (failed(alloc))
       return failure();
     replaceOpWithBufferizedValues(rewriter, op, *alloc);

diff  --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index f068010a50895..67e28c46f3969 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -228,8 +228,7 @@ struct ExtractSliceOpInterface
     Value alloc;
     if (!inplace) {
       FailureOr<Value> allocOrFailure =
-          createAlloc(rewriter, loc, extractSliceOp.result(),
-                      state.getOptions().createDeallocs, state.getOptions());
+          state.createAlloc(rewriter, loc, extractSliceOp.result());
       if (failed(allocOrFailure))
         return failure();
       alloc = *allocOrFailure;
@@ -338,9 +337,7 @@ struct FromElementsOpInterface
     auto shape = tensorType.getShape();
     MemRefType resultType = getContiguousMemRefType(tensorType);
     FailureOr<Value> maybeBuffer =
-        createAlloc(rewriter, loc, resultType, {},
-                    /*deallocMemref=*/state.getOptions().createDeallocs,
-                    state.getOptions());
+        state.createAlloc(rewriter, loc, resultType, {});
     if (failed(maybeBuffer))
       return failure();
     Value buffer = *maybeBuffer;
@@ -389,10 +386,8 @@ struct GenerateOpInterface
     Location loc = op->getLoc();
     MemRefType memrefType =
         getContiguousMemRefType(generateOp.getType().cast<RankedTensorType>());
-    FailureOr<Value> maybeResult =
-        createAlloc(rewriter, loc, memrefType, generateOp.dynamicExtents(),
-                    /*deallocMemref=*/state.getOptions().createDeallocs,
-                    state.getOptions());
+    FailureOr<Value> maybeResult = state.createAlloc(
+        rewriter, loc, memrefType, generateOp.dynamicExtents());
     if (failed(maybeResult))
       return failure();
     Value result = *maybeResult;


        


More information about the Mlir-commits mailing list