[Mlir-commits] [mlir] 6db2007 - [mlir][linalg][bufferize][NFC] Use same OpBuilder throughout bufferization

Matthias Springer llvmlistbot at llvm.org
Fri Dec 3 17:03:18 PST 2021


Author: Matthias Springer
Date: 2021-12-04T09:57:26+09:00
New Revision: 6db200736c51a61834fd2e192d8a5fd71e0874b4

URL: https://github.com/llvm/llvm-project/commit/6db200736c51a61834fd2e192d8a5fd71e0874b4
DIFF: https://github.com/llvm/llvm-project/commit/6db200736c51a61834fd2e192d8a5fd71e0874b4.diff

LOG: [mlir][linalg][bufferize][NFC] Use same OpBuilder throughout bufferization

Also set the insertion point right before calling `bufferize`, so an InsertionGuard is no longer needed.

Differential Revision: https://reviews.llvm.org/D114928
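
For readers skimming the diff: below is a minimal sketch of the pattern this
commit moves to. It is illustrative pseudocode in the spirit of the MLIR C++
API, not code from the commit; the ellipses and exact bodies are simplified
stand-ins.

// Before this commit: every `bufferize` implementation guarded and
// repositioned a caller-provided builder itself.
LogicalResult bufferize(Operation *op, OpBuilder &b,
                        BufferizationState &state) const {
  OpBuilder::InsertionGuard g(b); // take a guard before anything else
  b.setInsertionPoint(op);        // reposition manually per op
  // ... create replacement memref ops with `b` ...
  return success();
}

// After this commit: `BufferizationState` owns a single `OpBuilder`
// (`state.builder`), the driver sets the insertion point right before
// dispatching to `bufferize`, and helpers such as `getResultBuffer` are
// member functions that use that builder internally.
LogicalResult bufferize(Operation *op, OpBuilder &b,
                        BufferizationState &state) const {
  // `b` aliases `state.builder`; the insertion point is already at `op`.
  Value resultBuffer = state.getResultBuffer(op->getResult(0));
  if (!resultBuffer)
    return failure();
  // ... create replacement memref ops with `b` ...
  return success();
}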

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
index ca7711dd032f4..1b3b3ff2f12e1 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
@@ -297,7 +297,8 @@ struct DialectBufferizationState {
 /// the results of the analysis.
 struct BufferizationState {
   BufferizationState(ModuleOp moduleOp, const BufferizationOptions &options)
-      : aliasInfo(moduleOp), options(options) {}
+      : aliasInfo(moduleOp), options(options),
+        builder(moduleOp->getContext()) {}
 
   // BufferizationState should be passed as a reference.
   BufferizationState(const BufferizationState &) = delete;
@@ -321,6 +322,11 @@ struct BufferizationState {
   /// Return `true` if the given value is mapped.
   bool isMapped(Value value) const;
 
+  /// Return the result buffer (memref) for a given OpResult (tensor). Allocate
+  /// a new buffer and copy over data from the existing buffer if out-of-place
+  /// bufferization is necessary.
+  Value getResultBuffer(OpResult result);
+
   /// Mark `op` as obsolete, so that it is deleted after bufferization.
   void markOpObsolete(Operation *op);
 
@@ -349,12 +355,10 @@ struct BufferizationState {
 
   /// A reference to current bufferization options.
   const BufferizationOptions &options;
-};
 
-/// Return the result buffer (memref) for a given OpResult (tensor). Allocate
-/// a new buffer and copy over data from the existing buffer if out-of-place
-/// bufferization is necessary.
-Value getResultBuffer(OpBuilder &b, OpResult result, BufferizationState &state);
+  /// The OpBuilder used during bufferization.
+  OpBuilder builder;
+};
 
 /// Bufferize all ops in the given region.
 LogicalResult bufferize(Region *region, BufferizationState &state);

diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp
index 8fb8f919c5ea6..fa0c96275daf0 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp
@@ -26,21 +26,14 @@ struct ConstantOpInterface
   LogicalResult bufferize(Operation *op, OpBuilder &b,
                           BufferizationState &state) const {
     auto constantOp = cast<arith::ConstantOp>(op);
-    if (!constantOp.getResult().getType().isa<TensorType>())
-      return success();
     assert(constantOp.getType().dyn_cast<RankedTensorType>() &&
            "not a constant ranked tensor");
     auto moduleOp = constantOp->getParentOfType<ModuleOp>();
-    if (!moduleOp) {
+    if (!moduleOp)
       return constantOp.emitError(
           "cannot bufferize constants not within builtin.module op");
-    }
-    GlobalCreator globalCreator(moduleOp);
-
-    // Take a guard before anything else.
-    OpBuilder::InsertionGuard g(b);
-    b.setInsertionPoint(constantOp);
 
+    GlobalCreator globalCreator(moduleOp);
     auto globalMemref = globalCreator.getGlobalFor(constantOp);
     Value memref = b.create<memref::GetGlobalOp>(
         constantOp.getLoc(), globalMemref.type(), globalMemref.getName());

diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
index 4ed9ec9fee10a..4348fe4d5ad28 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
@@ -372,15 +372,15 @@ Value mlir::linalg::comprehensive_bufferize::findLastPrecedingWrite(
 /// Return the result buffer (memref) for a given OpResult (tensor). Allocate
 /// a new buffer and copy over data from the existing buffer if out-of-place
 /// bufferization is necessary.
-Value mlir::linalg::comprehensive_bufferize::getResultBuffer(
-    OpBuilder &b, OpResult result, BufferizationState &state) {
-  OpBuilder::InsertionGuard guard(b);
+Value mlir::linalg::comprehensive_bufferize::BufferizationState::
+    getResultBuffer(OpResult result) {
+  OpBuilder::InsertionGuard guard(builder);
   Operation *op = result.getOwner();
   SmallVector<OpOperand *> aliasingOperands = getAliasingOpOperand(result);
   assert(!aliasingOperands.empty() && "could not get aliasing OpOperand");
   OpOperand *opOperand = aliasingOperands.front();
   Value operand = opOperand->get();
-  Value operandBuffer = state.lookupBuffer(operand);
+  Value operandBuffer = lookupBuffer(operand);
   // Make sure that all OpOperands are the same buffer. If this is not the case,
   // we would have to materialize a memref value.
   // TODO: Should be looking for checking for "equivalent buffers" instead of
@@ -388,14 +388,14 @@ Value mlir::linalg::comprehensive_bufferize::getResultBuffer(
   // set up yet.
   if (aliasingOperands.size() > 1 &&
       !llvm::all_of(aliasingOperands, [&](OpOperand *o) {
-        return state.lookupBuffer(o->get()) == operandBuffer;
+        return lookupBuffer(o->get()) == operandBuffer;
       })) {
     op->emitError("result buffer is ambiguous");
     return Value();
   }
 
   // If bufferizing out-of-place, allocate a new buffer.
-  if (!state.aliasInfo.isInPlace(result)) {
+  if (!aliasInfo.isInPlace(result)) {
     // Ops with multiple aliasing operands can currently not bufferize
     // out-of-place.
     assert(
@@ -404,9 +404,9 @@ Value mlir::linalg::comprehensive_bufferize::getResultBuffer(
     Location loc = op->getLoc();
     // Move insertion point right after `operandBuffer`. That is where the
     // allocation should be inserted (in the absence of allocation hoisting).
-    setInsertionPointAfter(b, operandBuffer);
+    setInsertionPointAfter(builder, operandBuffer);
     // Allocate the result buffer.
-    Value resultBuffer = state.createAllocDeallocFn(b, loc, operandBuffer);
+    Value resultBuffer = createAllocDeallocFn(builder, loc, operandBuffer);
     bool skipCopy = false;
     // Do not copy if the last preceding write of `operand` is an op that does
     // not write (skipping ops that merely create aliases). E.g., InitTensorOp.
@@ -427,9 +427,9 @@ Value mlir::linalg::comprehensive_bufferize::getResultBuffer(
       skipCopy = true;
     if (!skipCopy) {
       // The copy happens right before the op that is bufferized.
-      b.setInsertionPoint(op);
-      state.options.allocationFns->memCpyFn(b, loc, operandBuffer,
-                                            resultBuffer);
+      builder.setInsertionPoint(op);
+      options.allocationFns->memCpyFn(builder, loc, operandBuffer,
+                                      resultBuffer);
     }
     return resultBuffer;
   }
@@ -459,7 +459,7 @@ mlir::linalg::comprehensive_bufferize::bufferize(Block *block,
 LogicalResult
 mlir::linalg::comprehensive_bufferize::bufferize(Operation *op,
                                                  BufferizationState &state) {
-  OpBuilder b(op->getContext());
+  OpBuilder &b = state.builder;
 
   // Check if op has tensor results or operands.
   auto isaTensor = [](Type t) { return t.isa<TensorType>(); };

diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
index 5f62c97ff2651..3f14975386a51 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
@@ -38,7 +38,7 @@ allocateBuffersForResults(OpBuilder &b, Location loc, LinalgOp op,
     OpResult opResult = cast<BufferizableOpInterface>(op.getOperation())
                             .getAliasingOpResult(*opOperand);
     assert(opResult && "could not find correspond OpResult");
-    Value resultBuffer = getResultBuffer(b, opResult, state);
+    Value resultBuffer = state.getResultBuffer(opResult);
     if (!resultBuffer)
       return failure();
     resultBuffers.push_back(resultBuffer);
@@ -158,10 +158,6 @@ struct InitTensorOpInterface
     if (initTensorOp->getUses().empty())
       return success();
 
-    // Take a guard before anything else.
-    OpBuilder::InsertionGuard g(b);
-    b.setInsertionPoint(initTensorOp);
-
     Value alloc = state.createAllocDeallocFn(b, initTensorOp->getLoc(),
                                              initTensorOp.result());
     state.mapBuffer(initTensorOp.result(), alloc);
@@ -250,7 +246,7 @@ struct TiledLoopOpInterface
 
       const OpResult &opResult = tiledLoopOp->getResult(resultIndex);
       OpOperand &yieldOperand = yieldOp->getOpOperand(resultIndex);
-      Value resultBuffer = getResultBuffer(b, opResult, state);
+      Value resultBuffer = state.getResultBuffer(opResult);
       if (!resultBuffer)
         return failure();
 
@@ -350,11 +346,6 @@ struct YieldOpInterface
                           BufferizationState &state) const {
     auto yieldOp = cast<linalg::YieldOp>(op);
 
-    // Take a guard before anything else.
-    OpBuilder::InsertionGuard g(b);
-    // Cannot create IR past a yieldOp.
-    b.setInsertionPoint(yieldOp);
-
     // No tensors -> success.
     if (!llvm::any_of(yieldOp.getOperandTypes(),
                       [](Type t) { return t.isa<TensorType>(); }))

diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
index 94b14bc5622e0..30f51d5d2ca38 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
@@ -469,10 +469,6 @@ struct CallOpInterface
            "expected Callop to a FuncOp");
     ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
 
-    // Take a guard before anything else.
-    OpBuilder::InsertionGuard g(b);
-    b.setInsertionPoint(callOp);
-
     // 1. Filter return types:
     //    - if the callee is bodiless / external, we cannot inspect it and we
     //      cannot assume anything. We can just assert that it does not return a
@@ -600,14 +596,9 @@ struct ReturnOpInterface
   LogicalResult bufferize(Operation *op, OpBuilder &b,
                           BufferizationState &state) const {
     auto returnOp = cast<ReturnOp>(op);
-
-    // Take a guard before anything else.
-    OpBuilder::InsertionGuard g(b);
-    // Cannot insert after returnOp.
-    b.setInsertionPoint(returnOp);
-
     assert(isa<FuncOp>(returnOp->getParentOp()) &&
            "only support FuncOp parent for ReturnOp");
+
     for (OpOperand &operand : returnOp->getOpOperands()) {
       auto tensorType = operand.get().getType().dyn_cast<TensorType>();
       if (!tensorType)
@@ -628,9 +619,6 @@ struct FuncOpInterface
   LogicalResult bufferize(Operation *op, OpBuilder &b,
                           BufferizationState &state) const {
     auto funcOp = cast<FuncOp>(op);
-
-    // Take a guard before anything else.
-    OpBuilder::InsertionGuard g(b);
     b.setInsertionPointToStart(&funcOp.body().front());
 
     // Create BufferCastOps for function args.

diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
index 8632c65ca7c66..5dc434335e1ef 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
@@ -138,7 +138,7 @@ struct IfOpInterface
       assert(opResult.getType().isa<RankedTensorType>() &&
              "unsupported unranked tensor");
 
-      Value resultBuffer = getResultBuffer(b, opResult, state);
+      Value resultBuffer = state.getResultBuffer(opResult);
       if (!resultBuffer)
         return failure();
 
@@ -204,7 +204,7 @@ struct ForOpInterface
              "unsupported unranked tensor");
 
       // TODO: More general: Matching bbArg does not bufferize to a read.
-      Value resultBuffer = getResultBuffer(b, opResult, state);
+      Value resultBuffer = state.getResultBuffer(opResult);
       if (!resultBuffer)
         return failure();
 

diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
index d944d81267dc8..11333807dd7a8 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
@@ -60,11 +60,7 @@ struct CastOpInterface
                           BufferizationState &state) const {
     auto castOp = cast<tensor::CastOp>(op);
 
-    // Take a guard before anything else.
-    OpBuilder::InsertionGuard g(b);
-    b.setInsertionPoint(castOp);
-
-    Value resultBuffer = getResultBuffer(b, castOp->getResult(0), state);
+    Value resultBuffer = state.getResultBuffer(castOp->getResult(0));
     if (!resultBuffer)
       return failure();
     Type sourceType = resultBuffer.getType();
@@ -107,11 +103,6 @@ struct DimOpInterface
   LogicalResult bufferize(Operation *op, OpBuilder &b,
                           BufferizationState &state) const {
     auto dimOp = cast<tensor::DimOp>(op);
-
-    // Take a guard before anything else.
-    OpBuilder::InsertionGuard g(b);
-    b.setInsertionPoint(dimOp);
-
     if (dimOp.source().getType().isa<RankedTensorType>()) {
       Value v = state.lookupBuffer(dimOp.source());
       dimOp.result().replaceAllUsesWith(
@@ -145,11 +136,6 @@ struct ExtractSliceOpInterface
   LogicalResult bufferize(Operation *op, OpBuilder &b,
                           BufferizationState &state) const {
     auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
-
-    // Take a guard before anything else.
-    OpBuilder::InsertionGuard g(b);
-    b.setInsertionPoint(extractSliceOp);
-
     Location loc = extractSliceOp.getLoc();
     Value srcMemref = state.lookupBuffer(extractSliceOp.source());
     auto srcMemrefType = srcMemref.getType().cast<MemRefType>();
@@ -207,11 +193,6 @@ struct ExtractOpInterface
   LogicalResult bufferize(Operation *op, OpBuilder &b,
                           BufferizationState &state) const {
     auto extractOp = cast<tensor::ExtractOp>(op);
-
-    // Take a guard before anything else.
-    OpBuilder::InsertionGuard g(b);
-    b.setInsertionPoint(extractOp);
-
     Location loc = extractOp.getLoc();
     Value srcMemref = state.lookupBuffer(extractOp.tensor());
     Value l = b.create<memref::LoadOp>(loc, srcMemref, extractOp.indices());
@@ -245,13 +226,8 @@ struct InsertOpInterface
   LogicalResult bufferize(Operation *op, OpBuilder &b,
                           BufferizationState &state) const {
     auto insertOp = cast<tensor::InsertOp>(op);
-
-    // Take a guard before anything else.
-    OpBuilder::InsertionGuard g(b);
-    b.setInsertionPoint(insertOp);
-
     Location loc = insertOp.getLoc();
-    Value destMemref = getResultBuffer(b, insertOp->getOpResult(0), state);
+    Value destMemref = state.getResultBuffer(insertOp->getOpResult(0));
     b.create<memref::StoreOp>(loc, insertOp.scalar(), destMemref,
                               insertOp.indices());
     state.mapBuffer(insertOp, destMemref);
@@ -419,15 +395,11 @@ struct InsertSliceOpInterface
     // catastrophically bad scheduling decision.
     // TODO: be very loud about it or even consider failing the pass.
     auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
-    TensorBufferizationState &tensorState = getTensorBufferizationState(state);
-
-    // Take a guard before anything else.
-    OpBuilder::InsertionGuard g(b);
-    b.setInsertionPoint(insertSliceOp);
     Location loc = insertSliceOp.getLoc();
+    TensorBufferizationState &tensorState = getTensorBufferizationState(state);
 
     // When bufferizing out-of-place, `getResultBuffer` allocates.
-    Value dstMemref = getResultBuffer(b, insertSliceOp->getResult(0), state);
+    Value dstMemref = state.getResultBuffer(insertSliceOp->getResult(0));
     if (!dstMemref)
       return failure();
 

diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp
index d582b88f1eb0a..3fafa75aa79bf 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp
@@ -39,14 +39,10 @@ struct TransferReadOpInterface
   LogicalResult bufferize(Operation *op, OpBuilder &b,
                           BufferizationState &state) const {
     auto transferReadOp = cast<vector::TransferReadOp>(op);
-
-    // Take a guard before anything else.
-    OpBuilder::InsertionGuard g(b);
-    b.setInsertionPoint(op);
-
-    // TransferReadOp always reads from the bufferized op.source().
     assert(transferReadOp.getShapedType().isa<TensorType>() &&
            "only tensor types expected");
+
+    // TransferReadOp always reads from the bufferized op.source().
     Value v = state.lookupBuffer(transferReadOp.source());
     transferReadOp.sourceMutable().assign(v);
     return success();
@@ -81,17 +77,13 @@ struct TransferWriteOpInterface
   LogicalResult bufferize(Operation *op, OpBuilder &b,
                           BufferizationState &state) const {
     auto writeOp = cast<vector::TransferWriteOp>(op);
-
-    // Take a guard before anything else.
-    OpBuilder::InsertionGuard g(b);
-    b.setInsertionPoint(op);
+    assert(writeOp.getShapedType().isa<TensorType>() &&
+           "only tensor types expected");
 
     // Create a new transfer_write on buffer that doesn't have a return value.
     // Leave the previous transfer_write to dead code as it still has uses at
     // this point.
-    assert(writeOp.getShapedType().isa<TensorType>() &&
-           "only tensor types expected");
-    Value resultBuffer = getResultBuffer(b, op->getResult(0), state);
+    Value resultBuffer = state.getResultBuffer(op->getResult(0));
     if (!resultBuffer)
       return failure();
     b.create<vector::TransferWriteOp>(

