[Mlir-commits] [mlir] 7ce427e - [mlir][linalg][bufferize][NFC] Clean up BufferizationState

Matthias Springer llvmlistbot at llvm.org
Mon Dec 6 17:08:40 PST 2021


Author: Matthias Springer
Date: 2021-12-07T10:05:39+09:00
New Revision: 7ce427e3bc0bcf50cb2f3b5944852219be03db9e

URL: https://github.com/llvm/llvm-project/commit/7ce427e3bc0bcf50cb2f3b5944852219be03db9e
DIFF: https://github.com/llvm/llvm-project/commit/7ce427e3bc0bcf50cb2f3b5944852219be03db9e.diff

LOG: [mlir][linalg][bufferize][NFC] Clean up BufferizationState

Make fields private and clean up the interface. In particular, BufferizableOpInterface::bufferize no longer has access to `aliasInfo`. That access was potentially dangerous because some of the ops registered in BufferizationAliasInfo may already have been deleted during bufferization.

Differential Revision: https://reviews.llvm.org/D114931

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
    mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h
    mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h
    mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
index a76007ad9a97a..1e0c96f0114a0 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
@@ -29,7 +29,10 @@ namespace comprehensive_bufferize {
 // TODO: from some HW description.
 static constexpr int64_t kBufferAlignments = 128;
 
-struct BufferizationState;
+class BufferizationAliasInfo;
+struct BufferizationOptions;
+class BufferizationState;
+struct PostAnalysisStep;
 
 /// Callback functions that are used to allocate/deallocate/copy memory buffers.
 /// Comprehensive Bufferize provides default implementations of these functions.
@@ -68,6 +71,7 @@ struct PostAnalysisStep {
   /// `aliasInfo` (inside `state`) consistent. Newly created operations and
   /// operations that should be re-analyzed must be stored in `newOps`.
   virtual LogicalResult run(FuncOp funcOp, BufferizationState &state,
+                            BufferizationAliasInfo &aliasInfo,
                             SmallVector<Operation *> &newOps) = 0;
 };
 
@@ -281,9 +285,20 @@ struct DialectBufferizationState {
   virtual ~DialectBufferizationState() = default;
 };
 
-/// BufferizationState keeps track of bufferization state and provides access to
-/// the results of the analysis.
-struct BufferizationState {
+/// BufferizationState keeps track of memory buffers and provides a variety of
+/// helper functions for dealing with them. In particular,
+/// `BufferizableOpInterface::bufferize` implementations should use the
+/// following helper functions.
+///
+/// * `createAlloc` / `createDealloc` / `createAllocDeallocPair` create ops
+///   that allocate and/or deallocate memref buffers.
+/// * `mapBuffer` maps a tensor value to a memref buffer during bufferization.
+/// * `lookupBuffer` returns the mapped memref buffer of a given tensor value.
+/// * `getResultBuffer` returns the memref buffer for a given tensor OpResult.
+///   Based on inplace bufferization decisions of the analysis, it may either
+///   directly return a mapped buffer or allocate a brand new buffer.
+class BufferizationState {
+public:
   BufferizationState(ModuleOp moduleOp, const BufferizationOptions &options)
       : aliasInfo(moduleOp), options(options),
         builder(moduleOp->getContext()) {}
@@ -291,11 +306,21 @@ struct BufferizationState {
   // BufferizationState should be passed as a reference.
   BufferizationState(const BufferizationState &) = delete;
 
-  /// A function that creates an alloc-dealloc pair. This function may perform
-  /// additional optimizations such as buffer allocation hoisting. This function
-  /// calls `allocationFn` and `deallocationFn` to create (de)allocations.
-  Value createAllocDeallocFn(OpBuilder &builder, Location loc,
-                             Value shapedValue);
+  /// Creates a memref allocation.
+  Optional<Value> createAlloc(OpBuilder &b, Location loc, MemRefType type,
+                              ArrayRef<Value> dynShape);
+
+  /// Creates an alloc-dealloc pair. This function may perform additional
+  /// optimizations such as buffer allocation hoisting.
+  Value createAllocDeallocPair(OpBuilder &builder, Location loc,
+                               Value shapedValue);
+
+  /// Creates a memref deallocation. The given memref buffer must have been
+  /// allocated using `createAlloc`.
+  void createDealloc(OpBuilder &b, Location loc, Value allocatedBuffer);
+
+  /// Creates a memcpy between two given buffers.
+  void createMemCpy(OpBuilder &b, Location loc, Value from, Value to);
 
   /// Map tensor values to memref buffers.
   void mapBuffer(ValueRange tensors, ValueRange buffers);
@@ -307,6 +332,9 @@ struct BufferizationState {
   /// Asserts if no buffer is associated.
   Value lookupBuffer(Value tensor);
 
+  /// Return `true` if the analysis decided to bufferize `opResult` in place.
+  bool isInPlace(OpResult opResult) const;
+
   /// Return `true` if the given value is mapped.
   bool isMapped(Value value) const;
 
@@ -329,7 +357,24 @@ struct BufferizationState {
     return static_cast<StateT &>(*dialectState[name]);
   }
 
-  /// `aliasInfo` keeps track of aliasing and equivalent values.
+  /// Return a reference to the BufferizationOptions.
+  const BufferizationOptions &getOptions() const { return options; }
+
+  /// Return a reference to the OpBuilder.
+  OpBuilder &getBuilder() { return builder; }
+
+private:
+  friend LogicalResult
+  runComprehensiveBufferize(FuncOp funcOp, const BufferizationOptions &options,
+                            BufferizationState &state,
+                            const PostAnalysisStepList &extraSteps);
+
+  friend LogicalResult
+  runComprehensiveBufferize(ModuleOp moduleOp,
+                            const BufferizationOptions &options);
+
+  /// `aliasInfo` keeps track of aliasing and equivalent values. Only internal
+  /// functions and `runComprehensiveBufferize` may access this object.
   BufferizationAliasInfo aliasInfo;
 
   /// The mapping of tensors to buffers.
@@ -428,7 +473,7 @@ struct AllocationHoistingBarrierOnly
     auto isaTensor = [](Type t) { return t.isa<TensorType>(); };
     if (any_of(op->getOperandTypes(), isaTensor) ||
         any_of(op->getResultTypes(), isaTensor))
-      if (!state.options.allowUnknownOps)
+      if (!state.getOptions().allowUnknownOps)
         return op->emitError() << "unsupported op with tensors";
 
     for (Region &region : op->getRegions())

diff  --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h
index 185878701d710..9b7cb9421dd98 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h
@@ -35,6 +35,7 @@ struct InitTensorEliminationStep : public PostAnalysisStep {
   ///   This analysis can be skipped with `skipAnalysis`.
   LogicalResult eliminateInitTensors(
       FuncOp funcOp, BufferizationState &state,
+      BufferizationAliasInfo &aliasInfo,
       std::function<bool(OpOperand &)> anchorMatchFunc,
       std::function<Value(OpBuilder &, Location, OpOperand &)> rewriteFunc,
       SmallVector<Operation *> &newOps);
@@ -46,6 +47,7 @@ struct InitTensorEliminationStep : public PostAnalysisStep {
 struct InsertSliceAnchoredInitTensorEliminationStep
     : public InitTensorEliminationStep {
   LogicalResult run(FuncOp funcOp, BufferizationState &state,
+                    BufferizationAliasInfo &aliasInfo,
                     SmallVector<Operation *> &newOps) override;
 };
 

diff  --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h
index 3ab5cc3525fc3..3e557e962b2aa 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h
@@ -23,6 +23,7 @@ namespace scf_ext {
 /// equivalent to their corresponding loop yield values.
 struct AssertDestinationPassingStyle : public PostAnalysisStep {
   LogicalResult run(FuncOp funcOp, BufferizationState &state,
+                    BufferizationAliasInfo &aliasInfo,
                     SmallVector<Operation *> &newOps) override;
 };
 

diff  --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h
index dbda53743d9b0..61b1f9356d545 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h
@@ -21,6 +21,7 @@ namespace tensor_ext {
 
 struct InplaceInsertSliceOpAnalysis : public PostAnalysisStep {
   LogicalResult run(FuncOp funcOp, BufferizationState &state,
+                    BufferizationAliasInfo &aliasInfo,
                     SmallVector<Operation *> &newOps) override;
 };
 

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
index 7682c4ae49393..4cb4c5bd38166 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
@@ -367,7 +367,7 @@ Value mlir::linalg::comprehensive_bufferize::BufferizationState::
     // allocation should be inserted (in the absence of allocation hoisting).
     setInsertionPointAfter(builder, operandBuffer);
     // Allocate the result buffer.
-    Value resultBuffer = createAllocDeallocFn(builder, loc, operandBuffer);
+    Value resultBuffer = createAllocDeallocPair(builder, loc, operandBuffer);
     bool skipCopy = false;
     // Do not copy if the last preceding write of `operand` is an op that does
     // not write (skipping ops that merely create aliases). E.g., InitTensorOp.
@@ -389,8 +389,7 @@ Value mlir::linalg::comprehensive_bufferize::BufferizationState::
     if (!skipCopy) {
       // The copy happens right before the op that is bufferized.
       builder.setInsertionPoint(op);
-      options.allocationFns->memCpyFn(builder, loc, operandBuffer,
-                                      resultBuffer);
+      createMemCpy(builder, loc, operandBuffer, resultBuffer);
     }
     return resultBuffer;
   }
@@ -420,7 +419,7 @@ mlir::linalg::comprehensive_bufferize::bufferize(Block *block,
 LogicalResult
 mlir::linalg::comprehensive_bufferize::bufferize(Operation *op,
                                                  BufferizationState &state) {
-  OpBuilder &b = state.builder;
+  OpBuilder &b = state.getBuilder();
 
   // Check if op has tensor results or operands.
   auto isaTensor = [](Type t) { return t.isa<TensorType>(); };
@@ -443,7 +442,7 @@ mlir::linalg::comprehensive_bufferize::bufferize(Operation *op,
   }
 
   // `op` is an unbufferizable tensor op.
-  if (!state.options.allowUnknownOps)
+  if (!state.getOptions().allowUnknownOps)
     return op->emitError() << "unsupported op with tensors";
 
   // Replace all OpOperands with "to-tensor casted" bufferized values.
@@ -550,7 +549,7 @@ static MemRefType getAllocationTypeAndShape(OpBuilder &b, Location loc,
 /// `shapedValue.getDefiningOp` (or at the top of the block in case of a
 /// bbArg) and the DeallocOp is at the end of the block.
 Value mlir::linalg::comprehensive_bufferize::BufferizationState::
-    createAllocDeallocFn(OpBuilder &b, Location loc, Value shapedValue) {
+    createAllocDeallocPair(OpBuilder &b, Location loc, Value shapedValue) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
 
@@ -561,8 +560,7 @@ Value mlir::linalg::comprehensive_bufferize::BufferizationState::
   // Note: getAllocationTypeAndShape also sets the insertion point.
   MemRefType allocMemRefType =
       getAllocationTypeAndShape(b, loc, shapedValue, dynShape);
-  Optional<Value> allocated =
-      options.allocationFns->allocationFn(b, loc, allocMemRefType, dynShape);
+  Optional<Value> allocated = createAlloc(b, loc, allocMemRefType, dynShape);
   // TODO: For now just assert the value is returned. Eventually need to
   // error-propagate.
   assert(allocated && "allocation failed");
@@ -573,10 +571,29 @@ Value mlir::linalg::comprehensive_bufferize::BufferizationState::
 
   // 2. Create memory deallocation.
   b.setInsertionPoint(allocated.getValue().getParentBlock()->getTerminator());
-  options.allocationFns->deallocationFn(b, loc, allocated.getValue());
+  createDealloc(b, loc, allocated.getValue());
   return casted;
 }
 
+/// Create a memref allocation.
+Optional<Value>
+mlir::linalg::comprehensive_bufferize::BufferizationState::createAlloc(
+    OpBuilder &b, Location loc, MemRefType type, ArrayRef<Value> dynShape) {
+  return options.allocationFns->allocationFn(b, loc, type, dynShape);
+}
+
+/// Create a memref deallocation.
+void mlir::linalg::comprehensive_bufferize::BufferizationState::createDealloc(
+    OpBuilder &b, Location loc, Value allocatedBuffer) {
+  return options.allocationFns->deallocationFn(b, loc, allocatedBuffer);
+}
+
+/// Create a memory copy between two memref buffers.
+void mlir::linalg::comprehensive_bufferize::BufferizationState::createMemCpy(
+    OpBuilder &b, Location loc, Value from, Value to) {
+  return options.allocationFns->memCpyFn(b, loc, from, to);
+}
+
 //===----------------------------------------------------------------------===//
 // Bufferization-specific BlockAndValueMapping support with debugging.
 //===----------------------------------------------------------------------===//
@@ -648,9 +665,15 @@ Value mlir::linalg::comprehensive_bufferize::BufferizationState::lookupBuffer(
 
 bool mlir::linalg::comprehensive_bufferize::BufferizationState::isMapped(
     Value value) const {
+  assert(value.getType().isa<TensorType>() && "unexpected non-tensor type");
   return mapping.contains(value);
 }
 
+bool mlir::linalg::comprehensive_bufferize::BufferizationState::isInPlace(
+    OpResult opResult) const {
+  return aliasInfo.isInPlace(opResult);
+}
+
 void mlir::linalg::comprehensive_bufferize::BufferizationState::markOpObsolete(
     Operation *op) {
   obsoleteOps.push_back(op);

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
index 6cbbda3f97146..19b6361027514 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
@@ -732,7 +732,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
   auto runPostAnalysisSteps = [&](const PostAnalysisStepList &steps) {
     for (const std::unique_ptr<PostAnalysisStep> &step : steps) {
       SmallVector<Operation *> newOps;
-      if (failed(step->run(funcOp, state, newOps)))
+      if (failed(step->run(funcOp, state, aliasInfo, newOps)))
         return failure();
       // Analyze ops that were created by the PostAnalysisStep.
       if (failed(inPlaceAnalysis(newOps, aliasInfo, state, domInfo)))

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
index d9231c0445164..4d66f44724a3f 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
@@ -159,8 +159,8 @@ struct InitTensorOpInterface
     if (initTensorOp->getUses().empty())
       return success();
 
-    Value alloc = state.createAllocDeallocFn(b, initTensorOp->getLoc(),
-                                             initTensorOp.result());
+    Value alloc = state.createAllocDeallocPair(b, initTensorOp->getLoc(),
+                                               initTensorOp.result());
     state.mapBuffer(initTensorOp.result(), alloc);
     return success();
   }
@@ -379,11 +379,11 @@ struct LinalgOpInterfaceHelper<> {
 LogicalResult mlir::linalg::comprehensive_bufferize::linalg_ext::
     InitTensorEliminationStep::eliminateInitTensors(
         FuncOp funcOp, BufferizationState &state,
+        BufferizationAliasInfo &aliasInfo,
         std::function<bool(OpOperand &)> anchorMatchFunc,
         std::function<Value(OpBuilder &, Location, OpOperand &)> rewriteFunc,
         SmallVector<Operation *> &newOps) {
   OpBuilder b(funcOp->getContext());
-  BufferizationAliasInfo &aliasInfo = state.aliasInfo;
 
   WalkResult status = funcOp->walk([&](Operation *op) {
     for (OpOperand &operand : op->getOpOperands()) {
@@ -474,16 +474,16 @@ LogicalResult mlir::linalg::comprehensive_bufferize::linalg_ext::
 LogicalResult mlir::linalg::comprehensive_bufferize::linalg_ext::
     InsertSliceAnchoredInitTensorEliminationStep::run(
         FuncOp funcOp, BufferizationState &state,
-        SmallVector<Operation *> &newOps) {
+        BufferizationAliasInfo &aliasInfo, SmallVector<Operation *> &newOps) {
   return eliminateInitTensors(
-      funcOp, state,
+      funcOp, state, aliasInfo,
       [&](OpOperand &operand) {
         auto insertSliceOp =
             dyn_cast<tensor::InsertSliceOp>(operand.getOwner());
         if (!insertSliceOp)
           return false;
         // Only inplace bufferized InsertSliceOps are eligible.
-        if (!state.aliasInfo.isInPlace(insertSliceOp->getOpResult(0)))
+        if (!aliasInfo.isInPlace(insertSliceOp->getOpResult(0)))
           return false;
         return &operand == &insertSliceOp->getOpOperand(0) /*source*/;
       },

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
index ebe7a5feb8d00..3a64ae711b9fc 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
@@ -88,6 +88,7 @@ struct EquivalentFuncOpBBArgsAnalysis : public PostAnalysisStep {
   }
 
   LogicalResult run(FuncOp funcOp, BufferizationState &state,
+                    BufferizationAliasInfo &aliasInfo,
                     SmallVector<Operation *> &newOps) override {
     ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
 
@@ -99,12 +100,12 @@ struct EquivalentFuncOpBBArgsAnalysis : public PostAnalysisStep {
       if (returnVal.get().getType().isa<RankedTensorType>())
         for (BlockArgument bbArg : funcOp.getArguments())
           if (bbArg.getType().isa<RankedTensorType>())
-            if (state.aliasInfo.areEquivalentBufferizedValues(returnVal.get(),
-                                                              bbArg)) {
+            if (aliasInfo.areEquivalentBufferizedValues(returnVal.get(),
+                                                        bbArg)) {
               moduleState
                   .equivalentFuncArgs[funcOp][returnVal.getOperandNumber()] =
                   bbArg.getArgNumber();
-              if (state.options.testAnalysisOnly)
+              if (state.getOptions().testAnalysisOnly)
                 annotateReturnOp(returnVal, bbArg);
             }
 

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
index 55156f949635d..db70cfd571523 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
@@ -265,6 +265,7 @@ struct ForOpInterface
 
 LogicalResult mlir::linalg::comprehensive_bufferize::scf_ext::
     AssertDestinationPassingStyle::run(FuncOp funcOp, BufferizationState &state,
+                                       BufferizationAliasInfo &aliasInfo,
                                        SmallVector<Operation *> &newOps) {
   LogicalResult status = success();
   funcOp->walk([&](scf::YieldOp yieldOp) {
@@ -280,8 +281,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::scf_ext::
       OpOperand &forOperand = forOp.getOpOperandForResult(
           forOp->getResult(operand.getOperandNumber()));
       auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
-      if (!state.aliasInfo.areEquivalentBufferizedValues(operand.get(),
-                                                         bbArg)) {
+      if (!aliasInfo.areEquivalentBufferizedValues(operand.get(), bbArg)) {
         // TODO: this could get resolved with copies but it can also turn into
         // swaps so we need to be careful about order of copies.
         status =

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
index 7f1bdb703d18d..21872d8407bea 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
@@ -144,10 +144,10 @@ struct ExtractSliceOpInterface
         extractSliceOp.result().getType().cast<RankedTensorType>();
 
     // If not inplaceable, alloc.
-    bool inplace = state.aliasInfo.isInPlace(extractSliceOp->getResult(0));
+    bool inplace = state.isInPlace(extractSliceOp->getResult(0));
     Value alloc;
     if (!inplace)
-      alloc = state.createAllocDeallocFn(b, loc, extractSliceOp.result());
+      alloc = state.createAllocDeallocPair(b, loc, extractSliceOp.result());
 
     // Bufferize to subview.
     auto subviewMemRefType =
@@ -159,15 +159,12 @@ struct ExtractSliceOpInterface
     Value subView = b.create<memref::SubViewOp>(
         loc, subviewMemRefType, srcMemref, extractSliceOp.getMixedOffsets(),
         extractSliceOp.getMixedSizes(), extractSliceOp.getMixedStrides());
-    // Insert new alias.
-    state.aliasInfo.insertNewBufferAlias(subView, srcMemref);
 
     /// If not inplaceable, copy.
     if (!inplace) {
       // Do not copy if the copied data is never read.
       if (isValueRead(extractSliceOp.result()))
-        state.options.allocationFns->memCpyFn(b, extractSliceOp.getLoc(),
-                                              subView, alloc);
+        state.createMemCpy(b, extractSliceOp.getLoc(), subView, alloc);
       subView = alloc;
     }
 
@@ -421,8 +418,7 @@ struct InsertSliceOpInterface
           insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
       // Copy tensor.
       Value srcMemref = state.lookupBuffer(insertSliceOp.source());
-      state.options.allocationFns->memCpyFn(b, insertSliceOp.getLoc(),
-                                            srcMemref, subView);
+      state.createMemCpy(b, insertSliceOp.getLoc(), srcMemref, subView);
     }
 
     state.mapBuffer(insertSliceOp.result(), dstMemref);
@@ -437,6 +433,7 @@ struct InsertSliceOpInterface
 
 LogicalResult mlir::linalg::comprehensive_bufferize::tensor_ext::
     InplaceInsertSliceOpAnalysis::run(FuncOp funcOp, BufferizationState &state,
+                                      BufferizationAliasInfo &aliasInfo,
                                       SmallVector<Operation *> &newOps) {
   auto &tensorState = getTensorBufferizationState(state);
   funcOp.walk([&](InsertSliceOp insertSliceOp) {
@@ -445,9 +442,9 @@ LogicalResult mlir::linalg::comprehensive_bufferize::tensor_ext::
     //     slice is computed out of place into the inplace full tensor.
     //   - The result is not inplace. This is the case where the whole tensor is
     //     cloned and the clone needs to be updated.
-    if (isSourceEquivalentToAMatchingInplaceExtractSliceOp(state.aliasInfo,
+    if (isSourceEquivalentToAMatchingInplaceExtractSliceOp(aliasInfo,
                                                            insertSliceOp) &&
-        state.aliasInfo.isInPlace(insertSliceOp->getResult(0)))
+        state.isInPlace(insertSliceOp->getResult(0)))
       tensorState.insertSliceOpsWithoutCopy.insert(insertSliceOp);
   });
   return success();


        


More information about the Mlir-commits mailing list