[Mlir-commits] [mlir] 4170141 - [mlir][linalg][bufferize] Replace remaining bvm usage with new API

Matthias Springer llvmlistbot at llvm.org
Wed Dec 15 06:25:01 PST 2021


Author: Matthias Springer
Date: 2021-12-15T23:21:39+09:00
New Revision: 417014170bd581186277c5e899a2338c482db69d

URL: https://github.com/llvm/llvm-project/commit/417014170bd581186277c5e899a2338c482db69d
DIFF: https://github.com/llvm/llvm-project/commit/417014170bd581186277c5e899a2338c482db69d.diff

LOG: [mlir][linalg][bufferize] Replace remaining bvm usage with new API

* Call `replaceOp` instead of `mapBuffer`.
* Remove bvm (the `BlockAndValueMapping` that mapped tensors to buffers) and all helper functions around it.
* Simplify FuncOp bufferization and rely on existing functionality to generate ToMemrefOps for function BlockArguments.

Differential Revision: https://reviews.llvm.org/D115515

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp
    mlir/test/Dialect/Linalg/comprehensive-function-bufferize.mlir
    mlir/test/Dialect/Linalg/comprehensive-module-bufferize-partial.mlir
    mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
index 53c1840b18c6d..35955af49efaa 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
@@ -268,6 +268,9 @@ class BufferizationAliasInfo {
   llvm::EquivalenceClasses<Value, ValueComparator> equivalentInfo;
 };
 
+/// Return `true` if the given value is a BlockArgument of a FuncOp.
+bool isFunctionArgument(Value value);
+
 /// Determine which OpOperand* will alias with `result` if the op is bufferized
 /// in place. Return an empty vector if the op is not bufferizable.
 SmallVector<OpOperand *> getAliasingOpOperand(OpResult result);
@@ -342,18 +345,18 @@ struct DialectBufferizationState {
   DialectBufferizationState(const DialectBufferizationState &) = delete;
 };
 
-/// BufferizationState keeps track of memory buffers and provides a variety of
-/// helper functions for dealing with them. In particular,
+/// BufferizationState provides a variety of helper functions for dealing with
+/// tensor values and memref buffers. In particular,
 /// `BufferizableOpInterface::bufferize` implementation should utilize the
 /// following helper functions.
 ///
 /// * `createAlloc` / `createDealloc` / `createAllocDeallocPair` creates ops
 ///   that allocate and/or deallocate memref buffers.
-/// * `mapBuffer` maps a tensor value to a memref buffer during bufferization.
-/// * `lookupBuffer` returns the mapped memref buffer of a given tensor value.
+/// * `lookupBuffer` returns the memref buffer of a given tensor value.
 /// * `getResultBuffer` returns the memref buffer for a given tensor OpResult.
 ///   Based on inplace bufferization decisions of the analysis, it may either
 ///   directly return a mapped buffer or allocate a new brand new buffer.
+/// * `replaceOp` replaces an op with new values.
 class BufferizationState {
 public:
   BufferizationState(Operation *op, const BufferizationOptions &options)
@@ -378,16 +381,19 @@ class BufferizationState {
   /// Creates a memcpy between two given buffers.
   void createMemCpy(OpBuilder &b, Location loc, Value from, Value to);
 
-  /// Replace an op with replacement values. The op is deleted.
+  /// Replace an op with replacement values. The op is deleted. Tensor OpResults
+  /// must be replaced with memref values.
   void replaceOp(Operation *op, ValueRange values);
 
-  /// Map tensor values to memref buffers.
-  // TODO: Deprecated. Remove all uses of this op. Use `replaceOp` instead.
-  void mapBuffer(ValueRange tensors, ValueRange buffers);
-
-  /// Map a tensor value to a memref buffer.
-  // TODO: Deprecated. Remove all uses of this op. Use `replaceOp` instead.
-  void mapBuffer(Value tensor, Value buffer);
+  /// Replace an op with a new op. Tensor OpResults must be replaced with memref
+  /// values.
+  template <typename OpTy, typename... Args>
+  OpTy replaceOpWithNewOp(OpBuilder &b, Operation *op, Args &&...args) {
+    Operation *newOp =
+        b.create<OpTy>(op->getLoc(), std::forward<Args>(args)...);
+    replaceOp(op, newOp->getResults());
+    return cast<OpTy>(newOp);
+  }
 
   /// Lookup the memref buffer that is associated to the given tensor value.
   /// Asserts if no buffer is associated.
@@ -396,23 +402,11 @@ class BufferizationState {
   /// Return `true` if the given OpResult has been decided to bufferize inplace.
   bool isInPlace(OpResult opResult) const;
 
-  /// Return `true` if the given value is mapped.
-  // TODO: Deprecated. Remove all uses of this op.
-  bool isMapped(Value value) const;
-
   /// Return the result buffer (memref) for a given OpResult (tensor). Allocate
   /// a new buffer and copy over data from the existing buffer if out-of-place
   /// bufferization is necessary.
   Value getResultBuffer(OpResult result);
 
-  /// Mark `op` as obsolete, so that it is deleted after bufferization.
-  // TODO: Deprecated. Remove all uses of this op.
-  void markOpObsolete(Operation *op);
-
-  /// Erase all ops that were marked obsolete.
-  // TODO: Deprecated. Remove all uses of this op.
-  void eraseObsoleteOps();
-
   /// Return dialect-specific bufferization state.
   template <typename StateT> StateT &getDialectState(StringRef name) {
     // Create state if it does not exist yet.
@@ -441,12 +435,6 @@ class BufferizationState {
   /// functions and `runComprehensiveBufferize` may access this object.
   BufferizationAliasInfo aliasInfo;
 
-  /// The mapping of tensors to buffers.
-  BlockAndValueMapping mapping;
-
-  /// Obsolete ops that should be deleted after bufferization.
-  SmallVector<Operation *> obsoleteOps;
-
   /// Dialect-specific bufferization state.
   DenseMap<StringRef, std::unique_ptr<DialectBufferizationState>> dialectState;
 

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp
index 1f3eaab91d2c5..e370d3f430421 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp
@@ -35,10 +35,8 @@ struct ConstantOpInterface
 
     GlobalCreator globalCreator(moduleOp);
     auto globalMemref = globalCreator.getGlobalFor(constantOp);
-    Value memref = b.create<memref::GetGlobalOp>(
-        constantOp.getLoc(), globalMemref.type(), globalMemref.getName());
-    state.mapBuffer(constantOp, memref);
-
+    state.replaceOpWithNewOp<memref::GetGlobalOp>(b, op, globalMemref.type(),
+                                                  globalMemref.getName());
     return success();
   }
 

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
index e33562f3b6f9f..d6d3d28a1022f 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
@@ -498,19 +498,6 @@ mlir::linalg::comprehensive_bufferize::bufferize(Operation *op,
   if (!state.getOptions().allowUnknownOps)
     return op->emitError() << "unsupported op with tensors";
 
-  // Replace all OpOperands with "to-tensor casted" bufferized values.
-  for (OpOperand &operand : op->getOpOperands()) {
-    if (operand.get().getType().isa<TensorType>() &&
-        state.isMapped(operand.get())) {
-      assert(state.getOptions().allowUnknownOps &&
-             "unsupported op error should have been emitted earlier");
-      b.setInsertionPoint(op);
-      Value toTensorOp = b.create<bufferization::ToTensorOp>(
-          op->getLoc(), state.lookupBuffer(operand.get()));
-      operand.set(toTensorOp);
-    }
-  }
-
   // Bufferize all regions.
   for (Region &region : op->getRegions())
     if (failed(bufferize(&region, state)))
@@ -654,38 +641,13 @@ void mlir::linalg::comprehensive_bufferize::BufferizationState::createMemCpy(
 // Bufferization-specific BlockAndValueMapping support with debugging.
 //===----------------------------------------------------------------------===//
 
-/// Wrapper for better debugging.
-void mlir::linalg::comprehensive_bufferize::BufferizationState::mapBuffer(
-    ValueRange tensors, ValueRange buffers) {
-  assert(!tensors.empty() && "unexpected empty tensors");
-#ifndef NDEBUG
-  for (Value tensor : tensors) {
-    assert(tensor && "unexpected empty tensor");
-    assert(tensor.getType().isa<TensorType>() && "unexpected non-tensor type");
-  }
-  for (Value buffer : buffers) {
-    assert(buffer && "unexpected empty buffer");
-    assert((buffer.getType().isa<MemRefType>() ||
-            buffer.getType().isa<UnrankedMemRefType>()) &&
-           "expected that tensor is mapped to memref");
-  }
-#endif // NDEBUG
-  return mapping.map(tensors, buffers);
-}
-
-/// Wrapper for better debugging.
-void mlir::linalg::comprehensive_bufferize::BufferizationState::mapBuffer(
-    Value tensor, Value buffer) {
-  assert(tensor && "unexpected empty tensor");
-  assert(tensor.getType().isa<TensorType>() && "unexpected non-tensor type");
-  assert(buffer && "unexpected empty buffer");
-  assert((buffer.getType().isa<MemRefType>() ||
-          buffer.getType().isa<UnrankedMemRefType>()) &&
-         "expected that tensor is mapped to memref");
-  return mapping.map(tensor, buffer);
+bool mlir::linalg::comprehensive_bufferize::isFunctionArgument(Value value) {
+  auto bbArg = value.dyn_cast<BlockArgument>();
+  if (!bbArg)
+    return false;
+  return isa<FuncOp>(bbArg.getOwner()->getParentOp());
 }
 
-/// Wrapper for better debugging.
 Value mlir::linalg::comprehensive_bufferize::BufferizationState::lookupBuffer(
     Value tensor) {
   assert(tensor.getType().isa<TensorType>() && "unexpected non-tensor type");
@@ -694,37 +656,29 @@ Value mlir::linalg::comprehensive_bufferize::BufferizationState::lookupBuffer(
   if (auto toTensorOp = tensor.getDefiningOp<bufferization::ToTensorOp>())
     return toTensorOp.memref();
 
-  Value buffer = mapping.lookupOrNull(tensor);
-  if (!buffer) {
-    if (options.allowUnknownOps) {
-      // `tensor` was not bufferized yet. This should never happen with
-      // bufferizable ops.
-      assert(!options.dynCastBufferizableOp(tensor) && "tensor is not mapped");
-      // Insert to_memref op.
-      OpBuilder b(tensor.getContext());
-      setInsertionPointAfter(b, tensor);
-      return b.create<bufferization::ToMemrefOp>(
-          tensor.getLoc(),
-          getDynamicMemRefType(tensor.getType().cast<RankedTensorType>()),
-          tensor);
+  if (!isFunctionArgument(tensor)) {
+    if (static_cast<bool>(options.dynCastBufferizableOp(tensor))) {
+      // Dump tensor for easier debugging.
+      tensor.dump();
+      llvm_unreachable("op is known, but has not been bufferized yet");
+      return Value();
+    }
+    if (!options.allowUnknownOps) {
+      // Dump tensor for easier debugging.
+      tensor.dump();
+      // Note: An assertion should already have failed earlier.
+      llvm_unreachable("unknown ops are not allowed");
+      return Value();
     }
-
-    // Dump tensor for easier debugging.
-    tensor.dump();
-    llvm_unreachable("tensor is not mapped");
-    return Value();
   }
 
-  assert((buffer.getType().isa<MemRefType>() ||
-          buffer.getType().isa<UnrankedMemRefType>()) &&
-         "expected that tensor is mapped to memref");
-  return buffer;
-}
-
-bool mlir::linalg::comprehensive_bufferize::BufferizationState::isMapped(
-    Value value) const {
-  assert(value.getType().isa<TensorType>() && "unexpected non-tensor type");
-  return mapping.contains(value);
+  // Insert to_memref op.
+  OpBuilder &b = getBuilder();
+  OpBuilder::InsertionGuard g(b);
+  setInsertionPointAfter(b, tensor);
+  return b.create<bufferization::ToMemrefOp>(
+      tensor.getLoc(),
+      getDynamicMemRefType(tensor.getType().cast<RankedTensorType>()), tensor);
 }
 
 bool mlir::linalg::comprehensive_bufferize::BufferizationState::isInPlace(
@@ -732,18 +686,6 @@ bool mlir::linalg::comprehensive_bufferize::BufferizationState::isInPlace(
   return aliasInfo.isInPlace(opResult);
 }
 
-void mlir::linalg::comprehensive_bufferize::BufferizationState::markOpObsolete(
-    Operation *op) {
-  obsoleteOps.push_back(op);
-}
-
-void mlir::linalg::comprehensive_bufferize::BufferizationState::
-    eraseObsoleteOps() {
-  for (Operation *op : obsoleteOps)
-    op->erase();
-  obsoleteOps.clear();
-}
-
 MemRefType mlir::linalg::comprehensive_bufferize::getContiguousMemRefType(
     ShapedType shapedType, MemRefLayoutAttrInterface layout,
     Attribute memorySpace) {

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp
index 6836a02dfd3db..97345350835a9 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp
@@ -63,21 +63,9 @@ struct ToMemrefOpInterface
     // If a ToMemrefOp's tensor operand has not been bufferized yet, the op
     // remains unchanged. All IR up to this ToMemrefOp has already been
     // bufferized, unless there were unknown ops that could be bufferized.
-    if (!state.isMapped(toMemrefOp.tensor())) {
-      assert(state.getOptions().allowUnknownOps &&
-             "expected that tensor is mapped");
-      return success();
-    }
-
-    // If a ToMemrefOp's tensor operand has been bufferized, the op can be
-    // removed.
-    Value memref = state.lookupBuffer(toMemrefOp.tensor());
-    // Do not replace a ToMemrefOp with itself. E.g., when bufferizing a
-    // function body, ToMemrefOps were inserted before starting bufferization of
-    // the function body. Such ToMemrefOps are replaced in a separate step after
-    // the function body has been bufferized.
-    if (toMemrefOp.getResult() != memref)
-      toMemrefOp.replaceAllUsesWith(memref);
+    assert((isFunctionArgument(toMemrefOp.tensor()) ||
+            state.getOptions().allowUnknownOps) &&
+           "expected that tensor is mapped");
 
     return success();
   }
@@ -98,8 +86,6 @@ struct ToTensorOpInterface
                                                     bufferization::ToTensorOp> {
   LogicalResult bufferize(Operation *op, OpBuilder &b,
                           BufferizationState &state) const {
-    auto tensorLoadOp = cast<bufferization::ToTensorOp>(op);
-    state.mapBuffer(tensorLoadOp.result(), tensorLoadOp.memref());
     return success();
   }
 

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
index 028be806236c4..b9dde90e63ee2 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
@@ -699,8 +699,5 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
   if (failed(bufferize(op, state)))
     return failure();
 
-  // Erase all obsolete ops.
-  state.eraseObsoleteOps();
-
   return success();
 }

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
index 451923d768fbb..9984ae1ad122c 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
@@ -56,7 +56,6 @@ static LogicalResult bufferizeLinalgOp(OpBuilder &b, LinalgOp op,
     if (!resultBuffer)
       return failure();
     newOutputBuffers.push_back(resultBuffer);
-    state.mapBuffer(opResult, resultBuffer);
   }
 
   // Clone the newly bufferized op.
@@ -68,7 +67,9 @@ static LogicalResult bufferizeLinalgOp(OpBuilder &b, LinalgOp op,
   auto bufferizedOp = cast<LinalgOp>(
       op.clone(b, op.getLoc(), /*resultTypes=*/TypeRange{}, newOperands));
 
-  // The original op will be DCE'd away later.
+  // Replace the results of the old op with the new output buffers.
+  state.replaceOp(op, newOutputBuffers);
+
   return comprehensive_bufferize::bufferize(bufferizedOp.getBlock(), state);
 }
 
@@ -194,7 +195,7 @@ struct InitTensorOpInterface
 
     Value alloc = state.createAllocDeallocPair(b, initTensorOp->getLoc(),
                                                initTensorOp.result());
-    state.mapBuffer(initTensorOp.result(), alloc);
+    state.replaceOp(op, alloc);
     return success();
   }
 };

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
index cb0d75cb366f7..0e391a9a4f04a 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
@@ -551,9 +551,6 @@ struct CallOpInterface
                   .equivalentFuncArgs[funcOp][returnOperand.getOperandNumber()];
           Value oldRes = callOp->getResult(returnOperand.getOperandNumber());
           Value buffer = state.lookupBuffer(callOp->getOperand(idx));
-          // Add CallOp operand/result equivalence: this is interprocedural
-          // info.
-          state.mapBuffer(oldRes, buffer);
           // Add a ToTensorOp to kill all uses of the CallOp return.
           // Replace all uses of the CallOp results so we can erase the CallOp.
           // This ToTensorOp must fold/DCE away or bufferization should be
@@ -561,8 +558,6 @@ struct CallOpInterface
           Value toTensorOp =
               b.create<bufferization::ToTensorOp>(callOp.getLoc(), buffer);
           oldRes.replaceAllUsesWith(toTensorOp);
-          // Add new op equivalence info.
-          state.mapBuffer(toTensorOp, buffer);
           continue;
         }
 
@@ -603,8 +598,6 @@ struct CallOpInterface
       if (buffer.getType() != memRefType) {
         Value castBuffer =
             b.create<memref::CastOp>(callOp.getLoc(), memRefType, buffer);
-        // Add new op equivalence info.
-        state.mapBuffer(tensorOperand, castBuffer);
         buffer = castBuffer;
       }
       newOperands.push_back(buffer);
@@ -616,7 +609,7 @@ struct CallOpInterface
     newCallOp->setAttrs(callOp->getAttrs());
 
     // 5. Delete the op at the end of bufferization.
-    state.markOpObsolete(callOp);
+    callOp->erase();
 
     return success();
   }
@@ -651,7 +644,6 @@ struct ReturnOpInterface
       Value returnTensor = b.create<bufferization::ToTensorOp>(
           returnOp.getLoc(), v);
       operand.set(returnTensor);
-      state.mapBuffer(returnTensor, v);
     }
     return success();
   }
@@ -662,23 +654,6 @@ struct FuncOpInterface
   LogicalResult bufferize(Operation *op, OpBuilder &b,
                           BufferizationState &state) const {
     auto funcOp = cast<FuncOp>(op);
-    b.setInsertionPointToStart(&funcOp.body().front());
-
-    // Create BufferCastOps for function args.
-    for (auto bbArg : funcOp.getArguments()) {
-      auto tensorType = bbArg.getType().dyn_cast<TensorType>();
-      if (!tensorType)
-        continue;
-      auto rankedTensorType = tensorType.dyn_cast<RankedTensorType>();
-      // Cast the tensor to the most dynamic buffer possible. Further
-      // canonicalizations will clean up.
-      Type memRefType = rankedTensorType
-                            ? getDynamicMemRefType(rankedTensorType)
-                            : getContiguousOrUnrankedMemRefType(tensorType);
-      Value bufferCast = b.create<bufferization::ToMemrefOp>(funcOp.getLoc(),
-                                                             memRefType, bbArg);
-      state.mapBuffer(bbArg, bufferCast);
-    }
 
     // Bufferize function body.
     return comprehensive_bufferize::bufferize(&funcOp.body(), state);

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
index cfc04be793b12..7558d792facf6 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
@@ -78,9 +78,7 @@ struct CastOpInterface
             : MemRefLayoutAttrInterface();
     Type memRefType = getContiguousOrUnrankedMemRefType(
         castOp.getResult().getType(), layout, memorySpace);
-    Value res =
-        b.create<memref::CastOp>(castOp.getLoc(), memRefType, resultBuffer);
-    state.mapBuffer(castOp.getResult(), res);
+    state.replaceOpWithNewOp<memref::CastOp>(b, op, memRefType, resultBuffer);
     return success();
   }
 };
@@ -103,11 +101,10 @@ struct DimOpInterface
   LogicalResult bufferize(Operation *op, OpBuilder &b,
                           BufferizationState &state) const {
     auto dimOp = cast<tensor::DimOp>(op);
-    if (dimOp.source().getType().isa<RankedTensorType>()) {
-      Value v = state.lookupBuffer(dimOp.source());
-      dimOp.result().replaceAllUsesWith(
-          b.create<memref::DimOp>(dimOp.getLoc(), v, dimOp.index()));
-    }
+    if (!dimOp.source().getType().isa<RankedTensorType>())
+      return dimOp.emitError("unranked tensor not supported");
+    Value v = state.lookupBuffer(dimOp.source());
+    state.replaceOpWithNewOp<memref::DimOp>(b, op, v, dimOp.index());
     return success();
   }
 };
@@ -168,7 +165,7 @@ struct ExtractSliceOpInterface
       subView = alloc;
     }
 
-    state.mapBuffer(extractSliceOp.result(), subView);
+    state.replaceOp(op, subView);
     return success();
   }
 };
@@ -191,10 +188,9 @@ struct ExtractOpInterface
   LogicalResult bufferize(Operation *op, OpBuilder &b,
                           BufferizationState &state) const {
     auto extractOp = cast<tensor::ExtractOp>(op);
-    Location loc = extractOp.getLoc();
     Value srcMemref = state.lookupBuffer(extractOp.tensor());
-    Value l = b.create<memref::LoadOp>(loc, srcMemref, extractOp.indices());
-    extractOp.replaceAllUsesWith(l);
+    state.replaceOpWithNewOp<memref::LoadOp>(b, op, srcMemref,
+                                             extractOp.indices());
     return success();
   }
 };
@@ -228,7 +224,7 @@ struct InsertOpInterface
     Value destMemref = state.getResultBuffer(insertOp->getOpResult(0));
     b.create<memref::StoreOp>(loc, insertOp.scalar(), destMemref,
                               insertOp.indices());
-    state.mapBuffer(insertOp, destMemref);
+    state.replaceOp(op, destMemref);
     return success();
   }
 
@@ -423,7 +419,7 @@ struct InsertSliceOpInterface
       state.createMemCpy(b, insertSliceOp.getLoc(), srcMemref, subView);
     }
 
-    state.mapBuffer(insertSliceOp.result(), dstMemref);
+    state.replaceOp(op, dstMemref);
     return success();
   }
 };

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp
index c2f33b876fff3..3ccfb5065ed21 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp
@@ -38,13 +38,17 @@ struct TransferReadOpInterface
 
   LogicalResult bufferize(Operation *op, OpBuilder &b,
                           BufferizationState &state) const {
-    auto transferReadOp = cast<vector::TransferReadOp>(op);
-    assert(transferReadOp.getShapedType().isa<TensorType>() &&
+    auto readOp = cast<vector::TransferReadOp>(op);
+    assert(readOp.getShapedType().isa<TensorType>() &&
            "only tensor types expected");
 
     // TransferReadOp always reads from the bufferized op.source().
-    Value v = state.lookupBuffer(transferReadOp.source());
-    transferReadOp.sourceMutable().assign(v);
+    Value buffer = state.lookupBuffer(readOp.source());
+    Value read = b.create<vector::TransferReadOp>(
+        readOp.getLoc(), readOp.getVectorType(), buffer, readOp.indices(),
+        readOp.permutation_map(), readOp.padding(), readOp.mask(),
+        readOp.in_boundsAttr());
+    state.replaceOp(op, read);
     return success();
   }
 };
@@ -90,7 +94,7 @@ struct TransferWriteOpInterface
     b.create<vector::TransferWriteOp>(
         writeOp.getLoc(), writeOp.vector(), resultBuffer, writeOp.indices(),
         writeOp.permutation_mapAttr(), writeOp.in_boundsAttr());
-    state.mapBuffer(op->getResult(0), resultBuffer);
+    state.replaceOp(op, resultBuffer);
 
     return success();
   }

diff  --git a/mlir/test/Dialect/Linalg/comprehensive-function-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-function-bufferize.mlir
index 24f8de7aac0e6..971d6c6e88a2f 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-function-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-function-bufferize.mlir
@@ -30,11 +30,11 @@ func @return_tensor(%A : tensor<?xf32>, %v : vector<4xf32>) -> (tensor<?xf32>) {
   // CHECK: %[[dim:.*]] = tensor.dim %[[A]]
   // CHECK: %[[alloc:.*]] = memref.alloc(%[[dim]])
   // CHECK: %[[casted:.*]] = memref.cast %[[alloc]]
+  // CHECK: %[[res_tensor:.*]] = bufferization.to_tensor %[[casted]]
   // CHECK: memref.copy %[[A_memref]], %[[casted]]
   // CHECK: vector.transfer_write %{{.*}}, %[[alloc]]
   %0 = vector.transfer_write %v, %A[%c0] : vector<4xf32>, tensor<?xf32>
 
-  // CHECK: %[[res_tensor:.*]] = bufferization.to_tensor %[[casted]]
   // CHECK: return %[[res_tensor]]
   return %0 : tensor<?xf32>
 }

diff  --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-partial.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-partial.mlir
index de8717f22d72d..967b231d73134 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-partial.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-partial.mlir
@@ -52,10 +52,10 @@ func @use_of_unknown_op_3(%t1: tensor<?xf32> {linalg.inplaceable = true})
     -> (vector<5xf32>, vector<5xf32>) {
   %idx = arith.constant 0 : index
   %cst = arith.constant 0.0 : f32
+  // CHECK: %[[m1_tensor:.*]] = bufferization.to_tensor %[[m1]]
   // CHECK: %[[v1:.*]] = vector.transfer_read %[[m1]]
   %1 = vector.transfer_read %t1[%idx], %cst : tensor<?xf32>, vector<5xf32>
 
-  // CHECK: %[[m1_tensor:.*]] = bufferization.to_tensor %[[m1]]
   // CHECK: %[[dummy:.*]] = "test.dummy_op"(%[[m1_tensor]])
   %0 = "test.dummy_op"(%t1) : (tensor<?xf32>) -> tensor<?xf32>
   // CHECK: %[[dummy_memref:.*]] = bufferization.to_memref %[[dummy]]
@@ -114,11 +114,11 @@ func @use_of_bufferizable_op_in_unbufferizable_op(
 func @unused_unknown_op(%t1 : tensor<?xf32>) -> vector<5xf32> {
   %idx = arith.constant 0 : index
   %cst = arith.constant 0.0 : f32
+  // ToTensorOp is inserted to pass in the result of the above bufferized op.
+  // CHECK: %[[m1_tensor:.*]] = bufferization.to_tensor %[[m1]]
   // CHECK: vector.transfer_read %[[m1]]
   %1 = vector.transfer_read %t1[%idx], %cst : tensor<?xf32>, vector<5xf32>
 
-  // ToTensorOp is inserted to pass in the result of the above bufferized op.
-  // CHECK: %[[m1_tensor:.*]] = bufferization.to_tensor %[[m1]]
   // CHECK: "test.dummy_op"(%[[m1_tensor]])
   "test.dummy_op"(%t1) : (tensor<?xf32>) -> ()
 
@@ -158,10 +158,10 @@ func @simple_tensor_test(%t1 : tensor<?xf32>, %f : f32) -> tensor<?xf32> {
   %c0 = arith.constant 0 : index
   // CHECK-TENSOR: %[[alloc:.*]] = memref.alloc
   // CHECK-TENSOR: %[[casted:.*]] = memref.cast %[[alloc]]
+  // CHECK-TENSOR: %[[casted_tensor:.*]] = bufferization.to_tensor %[[casted]]
   // CHECK-TENSOR: memref.copy %[[t1_memref]], %[[casted]]
   // CHECK-TENSOR: memref.store %{{.*}}, %[[alloc]]
   %0 = tensor.insert %f into %t1[%c0] : tensor<?xf32>
-  // CHECK-TENSOR: %[[casted_tensor:.*]] = bufferization.to_tensor %[[casted]]
   // CHECK-TENSOR: return %[[casted_tensor]]
   return %0 : tensor<?xf32>
 }

diff  --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
index 1094c21ed0537..f55f3008f2bd2 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
@@ -168,25 +168,25 @@ func @insert_slice_fun(%A0 : tensor<?xf32>,
   ->  (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>)
 {
   // Hoisted allocs.
-  //      CHECK: %[[REALLOC_A0_2:.*]] = memref.alloc
-  //      CHECK: %[[REALLOC_A0:.*]] = memref.alloc
-  //      CHECK: %[[REALLOC_A1:.*]] = memref.alloc
+  //      CHECK: %[[REALLOC1:.*]] = memref.alloc
+  //      CHECK: %[[REALLOC2:.*]] = memref.alloc
+  //      CHECK: %[[REALLOC3:.*]] = memref.alloc
 
   // Alloc and copy the whole result tensor. Copy the tensor.extract_slice.
-  //      CHECK: linalg.copy(%[[A0]], %[[REALLOC_A0]]
-  //      CHECK: %[[SV_A0:.*]] = memref.subview %[[REALLOC_A0]]
+  //      CHECK: linalg.copy(%[[A0]], %[[REALLOC3]]
+  //      CHECK: %[[SV_A0:.*]] = memref.subview %[[REALLOC3]]
   //      CHECK: linalg.copy(%[[t0]], %[[SV_A0]])
   %r0 = tensor.insert_slice %t0 into %A0[0][4][1] : tensor<4xf32> into tensor<?xf32>
 
   // Alloc and copy the whole result tensor. Copy the tensor.extract_slice.
   //      CHECK: linalg.copy(%[[A0]]
-  //      CHECK: %[[SV_A0_2:.*]] = memref.subview %[[REALLOC_A0_2]]
+  //      CHECK: %[[SV_A0_2:.*]] = memref.subview %[[REALLOC2]]
   //      CHECK: linalg.copy(%[[t1]], %[[SV_A0_2]])
   %r1 = tensor.insert_slice %t1 into %A0[0][4][1] : tensor<4xf32> into tensor<?xf32>
 
   //  Still alloc the large tensor because %A1 is read after. Copy the tensor.extract_slice.
   //      CHECK: linalg.copy(%[[A1]]
-  //      CHECK: %[[SV_A1:.*]] = memref.subview %[[REALLOC_A1]]
+  //      CHECK: %[[SV_A1:.*]] = memref.subview %[[REALLOC1]]
   //      CHECK: linalg.copy(%[[t0]], %[[SV_A1]])
   %r2 = tensor.insert_slice %t0 into %A1[0][4][1] : tensor<4xf32> into tensor<?xf32>
 
@@ -196,7 +196,7 @@ func @insert_slice_fun(%A0 : tensor<?xf32>,
   //      CHECK: linalg.copy(%[[t1]], %[[SV_A1_2]])
   %r3 = tensor.insert_slice %t1 into %A1[0][4][1] : tensor<4xf32> into tensor<?xf32>
 
-  //      CHECK: return %[[REALLOC_A0]], %[[REALLOC_A0_2]], %[[REALLOC_A1]] :
+  //      CHECK: return %[[REALLOC3]], %[[REALLOC2]], %[[REALLOC1]] :
   // CHECK-SAME:   memref<?xf32>, memref<?xf32>, memref<?xf32>
   return %r0, %r1, %r2, %r3: tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>
 }


        


More information about the Mlir-commits mailing list