[Mlir-commits] [mlir] be8742b - [mlir][linalg][bufferize][NFC] Merge AllocationCallbacks into BufferizationOptions

Matthias Springer llvmlistbot at llvm.org
Wed Jan 19 01:43:19 PST 2022


Author: Matthias Springer
Date: 2022-01-19T18:36:34+09:00
New Revision: be8742b6c9c7d919db33212a27dc230632dbb1d3

URL: https://github.com/llvm/llvm-project/commit/be8742b6c9c7d919db33212a27dc230632dbb1d3
DIFF: https://github.com/llvm/llvm-project/commit/be8742b6c9c7d919db33212a27dc230632dbb1d3.diff

LOG: [mlir][linalg][bufferize][NFC] Merge AllocationCallbacks into BufferizationOptions

Also move `createAlloc` and related helper functions out of `BufferizationState`. The goal is to make `BufferizationState` as small as possible (code cleanup).

Differential Revision: https://reviews.llvm.org/D117476

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
    mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
index 807d86331401..8fb05125deeb 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
@@ -38,34 +38,6 @@ struct BufferizationOptions;
 class BufferizationState;
 struct PostAnalysisStep;
 
-/// Callback functions that are used to allocate/deallocate/copy memory buffers.
-/// Comprehensive Bufferize provides default implementations of these functions.
-// TODO: Could be replaced with a "bufferization strategy" object with virtual
-// functions in the future.
-struct AllocationCallbacks {
-  using AllocationFn = std::function<FailureOr<Value>(
-      OpBuilder &, Location, MemRefType, ArrayRef<Value>)>;
-  using DeallocationFn = std::function<void(OpBuilder &, Location, Value)>;
-  using MemCpyFn = std::function<void(OpBuilder &, Location, Value, Value)>;
-
-  AllocationCallbacks(AllocationFn allocFn, DeallocationFn deallocFn,
-                      MemCpyFn copyFn)
-      : allocationFn(std::move(allocFn)), deallocationFn(std::move(deallocFn)),
-        memCpyFn(std::move(copyFn)) {}
-
-  /// A function that allocates memory.
-  AllocationFn allocationFn;
-
-  /// A function that deallocated memory. Must be allocated by `allocationFn`.
-  DeallocationFn deallocationFn;
-
-  /// A function that copies memory between two allocations.
-  MemCpyFn memCpyFn;
-};
-
-/// Return default allocation callbacks.
-std::unique_ptr<AllocationCallbacks> defaultAllocationCallbacks();
-
 /// PostAnalysisSteps can be registered with `BufferizationOptions` and are
 /// executed after the analysis, but before bufferization. They can be used to
 /// implement custom dialect-specific optimizations.
@@ -84,6 +56,13 @@ using PostAnalysisStepList = std::vector<std::unique_ptr<PostAnalysisStep>>;
 
 /// Options for ComprehensiveBufferize.
 struct BufferizationOptions {
+  using AllocationFn = std::function<FailureOr<Value>(
+      OpBuilder &, Location, MemRefType, ArrayRef<Value>)>;
+  using DeallocationFn =
+      std::function<LogicalResult(OpBuilder &, Location, Value)>;
+  using MemCpyFn =
+      std::function<LogicalResult(OpBuilder &, Location, Value, Value)>;
+
   BufferizationOptions();
 
   // BufferizationOptions cannot be copied.
@@ -126,7 +105,9 @@ struct BufferizationOptions {
   BufferizableOpInterface dynCastBufferizableOp(Value value) const;
 
   /// Helper functions for allocation, deallocation, memory copying.
-  std::unique_ptr<AllocationCallbacks> allocationFns;
+  Optional<AllocationFn> allocationFn;
+  Optional<DeallocationFn> deallocationFn;
+  Optional<MemCpyFn> memCpyFn;
 
   /// Specifies whether returning newly allocated memrefs should be allowed.
   /// Otherwise, a pass failure is triggered.
@@ -362,24 +343,6 @@ class BufferizationState {
   /// is returned regardless of whether it is a memory write or not.
   SetVector<Value> findLastPrecedingWrite(Value value) const;
 
-  /// Creates a memref allocation.
-  FailureOr<Value> createAlloc(OpBuilder &b, Location loc, MemRefType type,
-                               ArrayRef<Value> dynShape) const;
-
-  /// Creates a memref allocation for the given shaped value. This function may
-  /// perform additional optimizations such as buffer allocation hoisting. If
-  /// `createDealloc`, a deallocation op is inserted at the point where the
-  /// allocation goes out of scope.
-  FailureOr<Value> createAlloc(OpBuilder &b, Location loc, Value shapedValue,
-                               bool deallocMemref) const;
-
-  /// Creates a memref deallocation. The given memref buffer must have been
-  /// allocated using `createAlloc`.
-  void createDealloc(OpBuilder &b, Location loc, Value allocatedBuffer) const;
-
-  /// Creates a memcpy between two given buffers.
-  void createMemCpy(OpBuilder &b, Location loc, Value from, Value to) const;
-
   /// Return `true` if the given OpResult has been decided to bufferize inplace.
   bool isInPlace(OpOperand &opOperand) const;
 
@@ -458,6 +421,28 @@ UnrankedMemRefType getUnrankedMemRefType(Type elementType,
 MemRefType getDynamicMemRefType(RankedTensorType tensorType,
                                 unsigned addressSpace = 0);
 
+/// Creates a memref allocation.
+FailureOr<Value> createAlloc(OpBuilder &b, Location loc, MemRefType type,
+                             ArrayRef<Value> dynShape,
+                             const BufferizationOptions &options);
+
+/// Creates a memref allocation for the given shaped value. This function may
+/// perform additional optimizations such as buffer allocation hoisting. If
+/// `deallocMemref` is true, a deallocation op is inserted at the point where
+/// the allocation goes out of scope.
+FailureOr<Value> createAlloc(OpBuilder &b, Location loc, Value shapedValue,
+                             bool deallocMemref,
+                             const BufferizationOptions &options);
+
+/// Creates a memref deallocation. The given memref buffer must have been
+/// allocated using `createAlloc`.
+LogicalResult createDealloc(OpBuilder &b, Location loc, Value allocatedBuffer,
+                            const BufferizationOptions &options);
+
+/// Creates a memcpy between two given buffers.
+LogicalResult createMemCpy(OpBuilder &b, Location loc, Value from, Value to,
+                           const BufferizationOptions &options);
+
 } // namespace comprehensive_bufferize
 } // namespace linalg
 } // namespace mlir

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
index 048a4a39111f..885a70f56c64 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
@@ -39,40 +39,8 @@ using namespace linalg::comprehensive_bufferize;
 // BufferizationOptions
 //===----------------------------------------------------------------------===//
 
-/// Default allocation function that is used by the comprehensive bufferization
-/// pass. The default currently creates a ranked memref using `memref.alloc`.
-static FailureOr<Value> defaultAllocationFn(OpBuilder &b, Location loc,
-                                            MemRefType type,
-                                            ArrayRef<Value> dynShape) {
-  Value allocated = b.create<memref::AllocOp>(
-      loc, type, dynShape, b.getI64IntegerAttr(kBufferAlignments));
-  return allocated;
-}
-
-/// Default deallocation function that is used by the comprehensive
-/// bufferization pass. It expects to recieve back the value called from the
-/// `defaultAllocationFn`.
-static void defaultDeallocationFn(OpBuilder &b, Location loc,
-                                  Value allocatedBuffer) {
-  b.create<memref::DeallocOp>(loc, allocatedBuffer);
-}
-
-/// Default memory copy function that is used by the comprehensive bufferization
-/// pass. Creates a `memref.copy` op.
-static void defaultMemCpyFn(OpBuilder &b, Location loc, Value from, Value to) {
-  b.create<memref::CopyOp>(loc, from, to);
-}
-
-std::unique_ptr<AllocationCallbacks>
-mlir::linalg::comprehensive_bufferize::defaultAllocationCallbacks() {
-  return std::make_unique<AllocationCallbacks>(
-      defaultAllocationFn, defaultDeallocationFn, defaultMemCpyFn);
-}
-
-// Default constructor for BufferizationOptions that sets all allocation
-// callbacks to their default functions.
-BufferizationOptions::BufferizationOptions()
-    : allocationFns(defaultAllocationCallbacks()) {}
+// Default constructor for BufferizationOptions.
+BufferizationOptions::BufferizationOptions() {}
 
 BufferizableOpInterface mlir::linalg::comprehensive_bufferize::
     BufferizationOptions::dynCastBufferizableOp(Operation *op) const {
@@ -393,8 +361,8 @@ mlir::linalg::comprehensive_bufferize::BufferizationState::getBuffer(
   // allocation should be inserted (in the absence of allocation hoisting).
   setInsertionPointAfter(rewriter, operandBuffer);
   // Allocate the result buffer.
-  FailureOr<Value> resultBuffer =
-      createAlloc(rewriter, loc, operandBuffer, options.createDeallocs);
+  FailureOr<Value> resultBuffer = createAlloc(rewriter, loc, operandBuffer,
+                                              options.createDeallocs, options);
   if (failed(resultBuffer))
     return failure();
   // Do not copy if the last preceding writes of `operand` are ops that do
@@ -425,7 +393,9 @@ mlir::linalg::comprehensive_bufferize::BufferizationState::getBuffer(
     // The copy happens right before the op that is bufferized.
     rewriter.setInsertionPoint(op);
   }
-  createMemCpy(rewriter, loc, operandBuffer, *resultBuffer);
+  if (failed(
+          createMemCpy(rewriter, loc, operandBuffer, *resultBuffer, options)))
+    return failure();
 
   return resultBuffer;
 }
@@ -545,9 +515,9 @@ static MemRefType getAllocationTypeAndShape(OpBuilder &b, Location loc,
 /// Create an AllocOp/DeallocOp pair, where the AllocOp is after
 /// `shapedValue.getDefiningOp` (or at the top of the block in case of a
 /// bbArg) and the DeallocOp is at the end of the block.
-FailureOr<Value>
-mlir::linalg::comprehensive_bufferize::BufferizationState::createAlloc(
-    OpBuilder &b, Location loc, Value shapedValue, bool deallocMemref) const {
+FailureOr<Value> mlir::linalg::comprehensive_bufferize::createAlloc(
+    OpBuilder &b, Location loc, Value shapedValue, bool deallocMemref,
+    const BufferizationOptions &options) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
 
@@ -558,7 +528,8 @@ mlir::linalg::comprehensive_bufferize::BufferizationState::createAlloc(
   // Note: getAllocationTypeAndShape also sets the insertion point.
   MemRefType allocMemRefType =
       getAllocationTypeAndShape(b, loc, shapedValue, dynShape);
-  FailureOr<Value> allocated = createAlloc(b, loc, allocMemRefType, dynShape);
+  FailureOr<Value> allocated =
+      createAlloc(b, loc, allocMemRefType, dynShape, options);
   if (failed(allocated))
     return failure();
   Value casted = allocated.getValue();
@@ -572,30 +543,47 @@ mlir::linalg::comprehensive_bufferize::BufferizationState::createAlloc(
   if (deallocMemref) {
     // 2. Create memory deallocation.
     b.setInsertionPoint(allocated.getValue().getParentBlock()->getTerminator());
-    createDealloc(b, loc, allocated.getValue());
+    if (failed(createDealloc(b, loc, allocated.getValue(), options)))
+      return failure();
   }
 
   return casted;
 }
 
 /// Create a memref allocation.
-FailureOr<Value>
-mlir::linalg::comprehensive_bufferize::BufferizationState::createAlloc(
-    OpBuilder &b, Location loc, MemRefType type,
-    ArrayRef<Value> dynShape) const {
-  return options.allocationFns->allocationFn(b, loc, type, dynShape);
+FailureOr<Value> mlir::linalg::comprehensive_bufferize::createAlloc(
+    OpBuilder &b, Location loc, MemRefType type, ArrayRef<Value> dynShape,
+    const BufferizationOptions &options) {
+  if (options.allocationFn)
+    return (*options.allocationFn)(b, loc, type, dynShape);
+
+  // Default buffer allocation via AllocOp.
+  Value allocated = b.create<memref::AllocOp>(
+      loc, type, dynShape, b.getI64IntegerAttr(kBufferAlignments));
+  return allocated;
 }
 
 /// Create a memref deallocation.
-void mlir::linalg::comprehensive_bufferize::BufferizationState::createDealloc(
-    OpBuilder &b, Location loc, Value allocatedBuffer) const {
-  return options.allocationFns->deallocationFn(b, loc, allocatedBuffer);
+LogicalResult mlir::linalg::comprehensive_bufferize::createDealloc(
+    OpBuilder &b, Location loc, Value allocatedBuffer,
+    const BufferizationOptions &options) {
+  if (options.deallocationFn)
+    return (*options.deallocationFn)(b, loc, allocatedBuffer);
+
+  // Default buffer deallocation via DeallocOp.
+  b.create<memref::DeallocOp>(loc, allocatedBuffer);
+  return success();
 }
 
 /// Create a memory copy between two memref buffers.
-void mlir::linalg::comprehensive_bufferize::BufferizationState::createMemCpy(
-    OpBuilder &b, Location loc, Value from, Value to) const {
-  return options.allocationFns->memCpyFn(b, loc, from, to);
+LogicalResult mlir::linalg::comprehensive_bufferize::createMemCpy(
+    OpBuilder &b, Location loc, Value from, Value to,
+    const BufferizationOptions &options) {
+  if (options.memCpyFn)
+    return (*options.memCpyFn)(b, loc, from, to);
+
+  b.create<memref::CopyOp>(loc, from, to);
+  return success();
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
index 64bc6920da07..6d153522811d 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
@@ -221,9 +221,9 @@ struct InitTensorOpInterface
     if (initTensorOp->getUses().empty())
       return success();
 
-    FailureOr<Value> alloc = state.createAlloc(
-        rewriter, initTensorOp->getLoc(), initTensorOp.result(),
-        state.getOptions().createDeallocs);
+    FailureOr<Value> alloc =
+        createAlloc(rewriter, initTensorOp->getLoc(), initTensorOp.result(),
+                    state.getOptions().createDeallocs, state.getOptions());
     if (failed(alloc))
       return failure();
     replaceOpWithBufferizedValues(rewriter, op, *alloc);
@@ -367,7 +367,9 @@ struct TiledLoopOpInterface
       Value output = std::get<1>(it);
       Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
           newTerminator.getLoc(), output.getType(), std::get<0>(it));
-      state.createMemCpy(rewriter, newTerminator.getLoc(), toMemrefOp, output);
+      if (failed(createMemCpy(rewriter, newTerminator.getLoc(), toMemrefOp,
+                              output, state.getOptions())))
+        return failure();
     }
 
     // Erase old terminator.

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
index 620328799712..a31832452ea8 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
@@ -158,8 +158,8 @@ struct ExtractSliceOpInterface
     Value alloc;
     if (!inplace) {
       FailureOr<Value> allocOrFailure =
-          state.createAlloc(rewriter, loc, extractSliceOp.result(),
-                            state.getOptions().createDeallocs);
+          createAlloc(rewriter, loc, extractSliceOp.result(),
+                      state.getOptions().createDeallocs, state.getOptions());
       if (failed(allocOrFailure))
         return failure();
       alloc = *allocOrFailure;
@@ -191,7 +191,9 @@ struct ExtractSliceOpInterface
     if (!inplace) {
       // Do not copy if the copied data is never read.
       if (state.isValueRead(extractSliceOp.result()))
-        state.createMemCpy(rewriter, extractSliceOp.getLoc(), subView, alloc);
+        if (failed(createMemCpy(rewriter, extractSliceOp.getLoc(), subView,
+                                alloc, state.getOptions())))
+          return failure();
       subView = alloc;
     }
 
@@ -461,7 +463,9 @@ struct InsertSliceOpInterface
     // tensor.extract_slice, the copy operation will eventually fold away.
     Value srcMemref =
         *state.getBuffer(rewriter, insertSliceOp->getOpOperand(0) /*source*/);
-    state.createMemCpy(rewriter, loc, srcMemref, subView);
+    if (failed(createMemCpy(rewriter, loc, srcMemref, subView,
+                            state.getOptions())))
+      return failure();
 
     replaceOpWithBufferizedValues(rewriter, op, *dstMemref);
     return success();

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
index f4e9476727f0..36da9e52ea3c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
@@ -77,9 +77,10 @@ static FailureOr<Value> allocationFnUsingAlloca(OpBuilder &b, Location loc,
 void LinalgComprehensiveModuleBufferize::runOnOperation() {
   auto options = std::make_unique<BufferizationOptions>();
   if (useAlloca) {
-    options->allocationFns->allocationFn = allocationFnUsingAlloca;
-    options->allocationFns->deallocationFn = [](OpBuilder &b, Location loc,
-                                                Value v) {};
+    options->allocationFn = allocationFnUsingAlloca;
+    options->deallocationFn = [](OpBuilder &b, Location loc, Value v) {
+      return success();
+    };
   }
 
   options->allowReturnMemref = allowReturnMemref;


        


More information about the Mlir-commits mailing list