[Mlir-commits] [mlir] 1eeffcd - [mlir][linalg][bufferize] Support custom insertion point for buffer copies

Matthias Springer llvmlistbot at llvm.org
Fri Jan 14 05:53:35 PST 2022


Author: Matthias Springer
Date: 2022-01-14T22:47:20+09:00
New Revision: 1eeffcdb7a11b0f9e9c28dc9a40fa4f08e737939

URL: https://github.com/llvm/llvm-project/commit/1eeffcdb7a11b0f9e9c28dc9a40fa4f08e737939
DIFF: https://github.com/llvm/llvm-project/commit/1eeffcdb7a11b0f9e9c28dc9a40fa4f08e737939.diff

LOG: [mlir][linalg][bufferize] Support custom insertion point for buffer copies

By default, copies are inserted right before the tensor OpOperand use. With this change, `bufferize` implementations can specify a different insertion point. This is needed for some ops where it would be illegal to insert a copy right before the use.

Differential Revision: https://reviews.llvm.org/D117291

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
index 9d865350f2313..807d86331401b 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
@@ -386,8 +386,10 @@ class BufferizationState {
   /// Return the buffer (memref) for a given OpOperand (tensor). Allocate
   /// a new buffer and copy over data from the existing buffer if out-of-place
   /// bufferization was decided.
-  FailureOr<Value> getBuffer(RewriterBase &rewriter, OpOperand &opOperand,
-                             bool forceInPlace = false) const;
+  FailureOr<Value>
+  getBuffer(RewriterBase &rewriter, OpOperand &opOperand,
+            bool forceInPlace = false,
+            Optional<Operation *> customCopyInsertionPoint = None) const;
 
   /// Return dialect-specific bufferization state.
   template <typename StateT>

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
index aa9c7bd6806c7..048a4a39111f0 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
@@ -377,7 +377,8 @@ static Value lookupBuffer(RewriterBase &rewriter, Value tensor) {
 /// bufferization is necessary.
 FailureOr<Value>
 mlir::linalg::comprehensive_bufferize::BufferizationState::getBuffer(
-    RewriterBase &rewriter, OpOperand &opOperand, bool forceInPlace) const {
+    RewriterBase &rewriter, OpOperand &opOperand, bool forceInPlace,
+    Optional<Operation *> customCopyInsertionPoint) const {
   OpBuilder::InsertionGuard guard(rewriter);
   Operation *op = opOperand.getOwner();
   Location loc = op->getLoc();
@@ -418,9 +419,14 @@ mlir::linalg::comprehensive_bufferize::BufferizationState::getBuffer(
   if (bufferizesToMemoryWrite(opOperand) && !bufferizesToMemoryRead(opOperand))
     return resultBuffer;
 
-  // The copy happens right before the op that is bufferized.
-  rewriter.setInsertionPoint(op);
+  if (customCopyInsertionPoint) {
+    rewriter.setInsertionPoint(*customCopyInsertionPoint);
+  } else {
+    // The copy happens right before the op that is bufferized.
+    rewriter.setInsertionPoint(op);
+  }
   createMemCpy(rewriter, loc, operandBuffer, *resultBuffer);
+
   return resultBuffer;
 }
 


        


More information about the Mlir-commits mailing list