[Mlir-commits] [mlir] 248e113 - [mlir][bufferize][NFC] Move helper functions to BufferizationOptions

Matthias Springer llvmlistbot at llvm.org
Wed May 11 07:23:33 PDT 2022


Author: Matthias Springer
Date: 2022-05-11T16:23:22+02:00
New Revision: 248e113e9f6e583ed93e52de621a89d098c6d79e

URL: https://github.com/llvm/llvm-project/commit/248e113e9f6e583ed93e52de621a89d098c6d79e
DIFF: https://github.com/llvm/llvm-project/commit/248e113e9f6e583ed93e52de621a89d098c6d79e.diff

LOG: [mlir][bufferize][NFC] Move helper functions to BufferizationOptions

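Move the helper functions that create allocations, deallocations, and memory copies (memcpys) into BufferizationOptions as const member functions (`createAlloc`, `createDealloc`, `createMemCpy`).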

Differential Revision: https://reviews.llvm.org/D125375

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
    mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
    mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
    mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
    mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index 7d80e472c90ea..421db92543f02 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -161,6 +161,19 @@ struct BufferizationOptions {
   Optional<DeallocationFn> deallocationFn;
   Optional<MemCpyFn> memCpyFn;
 
+  /// Create a memref allocation with the given type and dynamic extents.
+  FailureOr<Value> createAlloc(OpBuilder &b, Location loc, MemRefType type,
+                               ValueRange dynShape) const;
+
+  /// Creates a memref deallocation. The given memref buffer must have been
+  /// allocated using `createAlloc`.
+  LogicalResult createDealloc(OpBuilder &b, Location loc,
+                              Value allocatedBuffer) const;
+
+  /// Creates a memcpy between two given buffers.
+  LogicalResult createMemCpy(OpBuilder &b, Location loc, Value from,
+                             Value to) const;
+
   /// Specifies whether not bufferizable ops are allowed in the input. If so,
   /// bufferization.to_memref and bufferization.to_tensor ops are inserted at
   /// the boundaries.
@@ -514,15 +527,6 @@ BaseMemRefType getMemRefType(TensorType tensorType,
                              MemRefLayoutAttrInterface layout = {},
                              Attribute memorySpace = {});
 
-/// Creates a memref deallocation. The given memref buffer must have been
-/// allocated using `createAlloc`.
-LogicalResult createDealloc(OpBuilder &b, Location loc, Value allocatedBuffer,
-                            const BufferizationOptions &options);
-
-/// Creates a memcpy between two given buffers.
-LogicalResult createMemCpy(OpBuilder &b, Location loc, Value from, Value to,
-                           const BufferizationOptions &options);
-
 /// Try to hoist all new buffer allocations until the next hoisting barrier.
 LogicalResult hoistBufferAllocations(Operation *op,
                                      const BufferizationOptions &options);

diff  --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 8a3bc5c40b5c3..29d983bdffdf1 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -327,8 +327,7 @@ BufferizationState::getBuffer(RewriterBase &rewriter, OpOperand &opOperand,
     // The copy happens right before the op that is bufferized.
     rewriter.setInsertionPoint(op);
   }
-  if (failed(
-          createMemCpy(rewriter, loc, operandBuffer, *resultBuffer, options)))
+  if (failed(options.createMemCpy(rewriter, loc, operandBuffer, *resultBuffer)))
     return failure();
 
   return resultBuffer;
@@ -418,26 +417,24 @@ bool AlwaysCopyAnalysisState::isTensorYielded(Value tensor) const {
 //===----------------------------------------------------------------------===//
 
 /// Create a memref allocation with the given type and dynamic extents.
-static FailureOr<Value> createAlloc(OpBuilder &b, Location loc, MemRefType type,
-                                    ValueRange dynShape,
-                                    const BufferizationOptions &options) {
-  if (options.allocationFn)
-    return (*options.allocationFn)(b, loc, type, dynShape,
-                                   options.bufferAlignment);
+FailureOr<Value> BufferizationOptions::createAlloc(OpBuilder &b, Location loc,
+                                                   MemRefType type,
+                                                   ValueRange dynShape) const {
+  if (allocationFn)
+    return (*allocationFn)(b, loc, type, dynShape, bufferAlignment);
 
   // Default bufferallocation via AllocOp.
   Value allocated = b.create<memref::AllocOp>(
-      loc, type, dynShape, b.getI64IntegerAttr(options.bufferAlignment));
+      loc, type, dynShape, b.getI64IntegerAttr(bufferAlignment));
   return allocated;
 }
 
 /// Creates a memref deallocation. The given memref buffer must have been
 /// allocated using `createAlloc`.
-LogicalResult
-bufferization::createDealloc(OpBuilder &b, Location loc, Value allocatedBuffer,
-                             const BufferizationOptions &options) {
-  if (options.deallocationFn)
-    return (*options.deallocationFn)(b, loc, allocatedBuffer);
+LogicalResult BufferizationOptions::createDealloc(OpBuilder &b, Location loc,
+                                                  Value allocatedBuffer) const {
+  if (deallocationFn)
+    return (*deallocationFn)(b, loc, allocatedBuffer);
 
   // Default buffer deallocation via DeallocOp.
   b.create<memref::DeallocOp>(loc, allocatedBuffer);
@@ -523,11 +520,10 @@ FailureOr<Value> BufferizationState::createAlloc(OpBuilder &b, Location loc,
 }
 
 /// Create a memory copy between two memref buffers.
-LogicalResult bufferization::createMemCpy(OpBuilder &b, Location loc,
-                                          Value from, Value to,
-                                          const BufferizationOptions &options) {
-  if (options.memCpyFn)
-    return (*options.memCpyFn)(b, loc, from, to);
+LogicalResult BufferizationOptions::createMemCpy(OpBuilder &b, Location loc,
+                                                 Value from, Value to) const {
+  if (memCpyFn)
+    return (*memCpyFn)(b, loc, from, to);
 
   b.create<memref::CopyOp>(loc, from, to);
   return success();
@@ -557,8 +553,8 @@ bufferization::createAllocDeallocOps(Operation *op,
     Block *block = allocaOp->getBlock();
     rewriter.setInsertionPoint(allocaOp);
     FailureOr<Value> alloc =
-        createAlloc(rewriter, allocaOp->getLoc(), allocaOp.getType(),
-                    allocaOp.dynamicSizes(), options);
+        options.createAlloc(rewriter, allocaOp->getLoc(), allocaOp.getType(),
+                            allocaOp.dynamicSizes());
     if (failed(alloc))
       return WalkResult::interrupt();
     rewriter.replaceOp(allocaOp, *alloc);
@@ -571,7 +567,7 @@ bufferization::createAllocDeallocOps(Operation *op,
 
     // Create dealloc.
     rewriter.setInsertionPoint(block->getTerminator());
-    if (failed(createDealloc(rewriter, alloc->getLoc(), *alloc, options)))
+    if (failed(options.createDealloc(rewriter, alloc->getLoc(), *alloc)))
       return WalkResult::interrupt();
 
     return WalkResult::advance();

diff  --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index 25d3df2fac287..dda46141e5382 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -495,7 +495,7 @@ struct FuncOpInterface
           // Note: This copy will fold away. It must be inserted here to ensure
           // that `returnVal` still has at least one use and does not fold away.
           if (failed(
-                  createMemCpy(rewriter, loc, toMemrefOp, equivBbArg, options)))
+                  options.createMemCpy(rewriter, loc, toMemrefOp, equivBbArg)))
             return funcOp->emitError("could not generate copy for bbArg");
           continue;
         }

diff  --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index 39af7d337d054..337b0aa57ea3e 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -363,8 +363,8 @@ static Value getYieldedBuffer(RewriterBase &rewriter, Value tensor,
   // TODO: We should rollback, but for now just assume that this always
   // succeeds.
   assert(yieldedAlloc.hasValue() && "could not create alloc");
-  LogicalResult copyStatus = bufferization::createMemCpy(
-      rewriter, tensor.getLoc(), yieldedVal, *yieldedAlloc, state.getOptions());
+  LogicalResult copyStatus = state.getOptions().createMemCpy(
+      rewriter, tensor.getLoc(), yieldedVal, *yieldedAlloc);
   (void)copyStatus;
   assert(succeeded(copyStatus) && "could not create memcpy");
 

diff  --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index efd2de7978e94..3ec52e94423f6 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -320,8 +320,8 @@ struct ExtractSliceOpInterface
     if (!inplace) {
       // Do not copy if the copied data is never read.
       if (state.getAnalysisState().isValueRead(extractSliceOp.result()))
-        if (failed(createMemCpy(rewriter, extractSliceOp.getLoc(), subView,
-                                alloc, state.getOptions())))
+        if (failed(state.getOptions().createMemCpy(
+                rewriter, extractSliceOp.getLoc(), subView, alloc)))
           return failure();
       subView = alloc;
     }
@@ -718,8 +718,8 @@ struct InsertSliceOpInterface
     // tensor.extract_slice, the copy operation will eventually fold away.
     auto srcMemref =
         state.getBuffer(rewriter, insertSliceOp->getOpOperand(0) /*source*/);
-    if (failed(srcMemref) || failed(createMemCpy(rewriter, loc, *srcMemref,
-                                                 subView, state.getOptions())))
+    if (failed(srcMemref) || failed(state.getOptions().createMemCpy(
+                                 rewriter, loc, *srcMemref, subView)))
       return failure();
 
     replaceOpWithBufferizedValues(rewriter, op, *dstMemref);


        


More information about the Mlir-commits mailing list