[Mlir-commits] [mlir] 9e24f0f - [mlir][bufferize] Do not deallocate allocs that are returned from a block

Matthias Springer llvmlistbot at llvm.org
Wed Mar 16 02:59:38 PDT 2022


Author: Matthias Springer
Date: 2022-03-16T18:59:27+09:00
New Revision: 9e24f0f4589dfdbc405f72eddd174af7511b2ff3

URL: https://github.com/llvm/llvm-project/commit/9e24f0f4589dfdbc405f72eddd174af7511b2ff3
DIFF: https://github.com/llvm/llvm-project/commit/9e24f0f4589dfdbc405f72eddd174af7511b2ff3.diff

LOG: [mlir][bufferize] Do not deallocate allocs that are returned from a block

Such IR is rejected by default, but can be allowed with `allow-return-memref`. In preparation for future refactorings, do not deallocate such buffers.

One-Shot Analysis now gathers information about yielded tensors, so that we know during the actual bufferization whether a newly allocated buffer should be deallocated again. (Buffers that are not deallocated currently leak. This will be addressed in a subsequent commit that also makes `allow-return-memref` a non-experimental flag.)

As a cleanup, `allow-return-memref` is now part of OneShotBufferizationOptions. (It was previously ignored by AlwaysCopyAnalysisState.) Moreover, AlwaysCopyAnalysisState now asserts that `create-deallocs` is deactivated to prevent surprising behavior.

Differential Revision: https://reviews.llvm.org/D121521

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
    mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
    mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
    mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
    mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
    mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
    mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
    mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index 6860bec2386ab..68136fda97384 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -177,10 +177,6 @@ struct BufferizationOptions {
   Optional<DeallocationFn> deallocationFn;
   Optional<MemCpyFn> memCpyFn;
 
-  /// Specifies whether returning newly allocated memrefs should be allowed.
-  /// Otherwise, a pass failure is triggered.
-  bool allowReturnMemref = false;
-
   /// Specifies whether not bufferizable ops are allowed in the input. If so,
   /// bufferization.to_memref and bufferization.to_tensor ops are inserted at
   /// the boundaries.
@@ -356,7 +352,14 @@ class AnalysisState {
   /// Return true if `v1` and `v2` bufferize to equivalent buffers.
   virtual bool areEquivalentBufferizedValues(Value v1, Value v2) const = 0;
 
-  /// Return dialect-specific analysis state.
+  /// Return true if the given tensor (or an aliasing tensor) is yielded from
+  /// the containing block. Also include all aliasing tensors in the same block.
+  ///
+  /// Note: In the absence of an analysis, an implementation may return true for
+  /// any given tensor.
+  virtual bool isTensorYielded(Value tensor) const = 0;
+
+  /// Return dialect-specific bufferization state.
   template <typename StateT>
   Optional<const StateT *> getDialectState(StringRef name) const {
     auto it = dialectState.find(name);
@@ -415,6 +418,10 @@ class AlwaysCopyAnalysisState : public AnalysisState {
 
   /// Return true if `v1` and `v2` bufferize to equivalent buffers.
   bool areEquivalentBufferizedValues(Value v1, Value v2) const override;
+
+  /// Return true if the given tensor (or an aliasing tensor) is yielded from
+  /// the containing block. Also include all aliasing tensors in the same block.
+  bool isTensorYielded(Value tensor) const override;
 };
 
 /// BufferizationState provides helper functions for performing bufferization
@@ -423,14 +430,20 @@ struct BufferizationState {
   BufferizationState(const AnalysisState &analysisState)
       : analysisState(analysisState) {}
 
-  /// Creates a memref allocation with the given type and dynamic extents.
-  FailureOr<Value> createAlloc(OpBuilder &b, Location loc, MemRefType type,
-                               ValueRange dynShape);
-
-  /// Creates a memref allocation for the given shaped value. This function may
-  /// perform additional optimizations such as buffer allocation hoisting.
-  // TODO: Allocation hoisting should be a cleanup pass.
-  FailureOr<Value> createAlloc(OpBuilder &b, Location loc, Value shapedValue);
+  /// Creates a memref allocation for the given shaped value. `dealloc`
+  /// indicates whether the buffer should be deallocated or not. When `dealloc`
+  /// is `false`, this would create a memory leak, unless the buffer is
+  /// deallocated through some other mechanism.
+  ///
+  /// `dealloc` is optional. By default, this function will figure out by itself
+  /// if it is safe to deallocate the buffer. In essence, when returning the
+  /// buffer from a block, it is not safe to deallocate the buffer. This
+  /// information is queried via `AnalysisState::isTensorYielded`.
+  ///
+  /// Note: `shapedValue` is typically a tensor value. However, if it is a
+  /// memref value, `dealloc` is no longer optional and must be specified.
+  FailureOr<Value> createAlloc(OpBuilder &b, Location loc, Value shapedValue,
+                               Optional<bool> dealloc = None);
 
   /// Return the buffer (memref) for a given OpOperand (tensor). Allocate
   /// a new buffer and copy over data from the existing buffer if out-of-place

diff  --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
index de555988dd549..2a954f3ea1036 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
@@ -43,6 +43,10 @@ struct OneShotBufferizationOptions : public BufferizationOptions {
 
   /// Registered post analysis steps.
   PostAnalysisStepList postAnalysisSteps;
+
+  /// Specifies whether returning newly allocated memrefs should be allowed.
+  /// Otherwise, a pass failure is triggered.
+  bool allowReturnMemref = false;
 };
 
 /// The BufferizationAliasInfo class maintains a list of buffer aliases and
@@ -153,10 +157,22 @@ class OneShotAnalysisState : public AnalysisState {
   /// Return true if `v1` and `v2` bufferize to equivalent buffers.
   bool areEquivalentBufferizedValues(Value v1, Value v2) const override;
 
+  /// Return true if the given tensor (or an aliasing tensor) is yielded from
+  /// the containing block. Also include all aliasing tensors in the same block.
+  bool isTensorYielded(Value tensor) const override;
+
+  /// Find all tensors that are yielded/returned from a block and store them in
+  /// `yieldedTensors`. Also include all aliasing tensors in the same block.
+  void gatherYieldedTensors(Operation *op);
+
 private:
   /// `aliasInfo` keeps track of aliasing and equivalent values. Only internal
   /// functions and `runOneShotBufferize` may access this object.
   BufferizationAliasInfo aliasInfo;
+
+  /// A set of all tensors (and maybe aliasing tensors) that are yielded from
+  /// a block.
+  DenseSet<Value> yieldedTensors;
 };
 
 /// Analyze `op` and its nested ops. Bufferization decisions are stored in

diff  --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index cc697487b07fb..7d21c76d58502 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -42,8 +42,12 @@ constexpr const ::llvm::StringLiteral
 constexpr const ::llvm::StringLiteral
     bufferization::BufferizableOpInterface::kInplaceableAttrName;
 
+/// Attribute name used to mark allocs that are created by the bufferization.
 static const char *kBufferAllocationAttr = "bufferization.allocation";
 
+/// Attribute name used to mark allocs that should not be deallocated.
+static const char *kSkipDeallocAttr = "bufferization.skip_dealloc";
+
 //===----------------------------------------------------------------------===//
 // BufferizationOptions
 //===----------------------------------------------------------------------===//
@@ -253,6 +257,8 @@ BufferizationState::getBuffer(RewriterBase &rewriter, OpOperand &opOperand,
   OpBuilder::InsertionGuard guard(rewriter);
   Operation *op = opOperand.getOwner();
   Location loc = op->getLoc();
+  SmallVector<OpResult> aliasingOpResults =
+      analysisState.getAliasingOpResult(opOperand);
   Value operand = opOperand.get();
   Value operandBuffer = lookupBuffer(rewriter, operand, options);
 
@@ -263,8 +269,13 @@ BufferizationState::getBuffer(RewriterBase &rewriter, OpOperand &opOperand,
   // Move insertion point right after `operandBuffer`. That is where the
   // allocation should be inserted (in the absence of allocation hoisting).
   setInsertionPointAfter(rewriter, operandBuffer);
-  // Allocate the result buffer.
-  FailureOr<Value> resultBuffer = createAlloc(rewriter, loc, operandBuffer);
+  // Allocate the result buffer. The buffer should be deallocated if the tensor
+  // is not yielded and deallocs are enabled in general.
+  bool dealloc = llvm::none_of(aliasingOpResults, [&](Value v) {
+    return getAnalysisState().isTensorYielded(v);
+  });
+  FailureOr<Value> resultBuffer = createAlloc(
+      rewriter, loc, operandBuffer, dealloc && getOptions().createDeallocs);
   if (failed(resultBuffer))
     return failure();
   // Do not copy if the last preceding writes of `operand` are ops that do
@@ -281,8 +292,6 @@ BufferizationState::getBuffer(RewriterBase &rewriter, OpOperand &opOperand,
       }))
     return resultBuffer;
   // Do not copy if the copied data is never read.
-  SmallVector<OpResult> aliasingOpResults =
-      analysisState.getAliasingOpResult(opOperand);
   if (!aliasingOpResults.empty() &&
       !analysisState.bufferizesToMemoryRead(opOperand) &&
       llvm::none_of(aliasingOpResults, [&](OpResult opResult) {
@@ -339,7 +348,12 @@ void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
 
 AlwaysCopyAnalysisState::AlwaysCopyAnalysisState(
     const BufferizationOptions &options)
-    : AnalysisState(options) {}
+    : AnalysisState(options) {
+  // Note: Allocations must be deallocated with a subsequent run of the buffer
+  // deallocation pass.
+  assert(!options.createDeallocs &&
+         "cannot create deallocs with AlwaysCopyBufferizationState");
+}
 
 /// Return `true` if the given OpResult has been decided to bufferize inplace.
 bool AlwaysCopyAnalysisState::isInPlace(OpOperand &opOperand) const {
@@ -356,6 +370,13 @@ bool AlwaysCopyAnalysisState::areEquivalentBufferizedValues(Value v1,
   return false;
 }
 
+/// Return true if the given tensor (or an aliasing tensor) is yielded from
+/// the containing block. Also include all aliasing tensors in the same block.
+bool AlwaysCopyAnalysisState::isTensorYielded(Value tensor) const {
+  // There is no analysis, so conservatively answer "true".
+  return true;
+}
+
 //===----------------------------------------------------------------------===//
 // Bufferization-specific scoped alloc/dealloc insertion support.
 //===----------------------------------------------------------------------===//
@@ -426,37 +447,54 @@ static MemRefType getAllocationTypeAndShape(OpBuilder &b, Location loc,
 }
 
 static Value createBufferAllocation(OpBuilder &b, Location loc, MemRefType type,
-                                    ValueRange dynShape) {
+                                    ValueRange dynShape, bool skipDealloc) {
   auto allocaOp = b.create<memref::AllocaOp>(loc, type, dynShape);
   allocaOp->setAttr(kBufferAllocationAttr, b.getUnitAttr());
+  if (skipDealloc)
+    allocaOp->setAttr(kSkipDeallocAttr, b.getUnitAttr());
   return allocaOp.getResult();
 }
 
 /// Create an allocation after `shapedValue.getDefiningOp` (or at the top of the
 /// block in case of a bbArg).
 FailureOr<Value> BufferizationState::createAlloc(OpBuilder &b, Location loc,
-                                                 Value shapedValue) {
+                                                 Value shapedValue,
+                                                 Optional<bool> dealloc) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
+
+  // Compute allocation memref type.
   assert(shapedValue.getType().isa<ShapedType>());
   MemRefType memRefType = shapedValue.getType().dyn_cast<MemRefType>();
   SmallVector<Value> dynShape;
   MemRefType allocMemRefType =
       getAllocationTypeAndShape(b, loc, shapedValue, dynShape);
-  Value alloc = createBufferAllocation(b, loc, allocMemRefType, dynShape);
+
+  // Should the buffer be deallocated again or should we let it leak?
+  bool skipDealloc;
+  if (dealloc) {
+    skipDealloc = !dealloc.getValue();
+  } else {
+    assert(shapedValue.getType().isa<TensorType>() &&
+           "must specify `dealloc` if non-tensor value is passed");
+    // The buffer should not be deallocated if deallocs are generally
+    // deactivated or if the tensor is yielded from a block.
+    skipDealloc = !getOptions().createDeallocs ||
+                  getAnalysisState().isTensorYielded(shapedValue);
+  }
+
+  // Create the buffer allocation.
+  Value alloc =
+      createBufferAllocation(b, loc, allocMemRefType, dynShape, skipDealloc);
+
+  // Insert a cast if a different type was requested.
   if (memRefType && memRefType != allocMemRefType) {
-    assert(memref::CastOp::areCastCompatible(alloc.getType(), memRefType) &&
+    assert(memref::CastOp::areCastCompatible(allocMemRefType, memRefType) &&
            "createAlloc: cast incompatible");
     alloc = b.create<memref::CastOp>(loc, memRefType, alloc);
   }
-  return alloc;
-}
 
-/// Create a memref allocation with the given type and dynamic extents.
-FailureOr<Value> BufferizationState::createAlloc(OpBuilder &b, Location loc,
-                                                 MemRefType type,
-                                                 ValueRange dynShape) {
-  return createBufferAllocation(b, loc, type, dynShape);
+  return alloc;
 }
 
 /// Create a memory copy between two memref buffers.
@@ -480,7 +518,9 @@ createAllocDeallocOps(Operation *op, const BufferizationOptions &options) {
     // Ignore memref.alloca ops that were not created by the bufferization.
     if (!allocaOp->hasAttr(kBufferAllocationAttr))
       return WalkResult::skip();
+    bool skipDealloc = allocaOp->hasAttr(kSkipDeallocAttr);
 
+    // Create alloc.
     Block *block = allocaOp->getBlock();
     rewriter.setInsertionPoint(allocaOp);
     FailureOr<Value> alloc =
@@ -490,10 +530,11 @@ createAllocDeallocOps(Operation *op, const BufferizationOptions &options) {
       return WalkResult::interrupt();
     rewriter.replaceOp(allocaOp, *alloc);
 
-    // Stop here if deallocations are deactivated.
-    if (!options.createDeallocs)
+    // Stop here if the buffer should not be deallocated.
+    if (skipDealloc)
       return WalkResult::advance();
 
+    // Create dealloc.
     rewriter.setInsertionPoint(block->getTerminator());
     if (failed(createDealloc(rewriter, alloc->getLoc(), *alloc, options)))
       return WalkResult::interrupt();

diff  --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index f237cb7a6a70e..85a5c5c120d2b 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -379,7 +379,6 @@ LogicalResult bufferization::bufferizeOp(Operation *op,
 
 BufferizationOptions bufferization::getPartialBufferizationOptions() {
   BufferizationOptions options;
-  options.allowReturnMemref = true;
   options.allowUnknownOps = true;
   options.createDeallocs = false;
   options.fullyDynamicLayoutMaps = false;

diff  --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
index 706072d7b9c10..b0bad7a32f2fb 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
@@ -215,6 +215,43 @@ bool OneShotAnalysisState::areEquivalentBufferizedValues(Value v1,
   return aliasInfo.areEquivalentBufferizedValues(v1, v2);
 }
 
+// Gather yielded tensors in `yieldedTensors` by querying all aliases. This is
+// to ensure that such information is available during bufferization time.
+// Alias information can no longer be queried through BufferizationAliasInfo
+// once we have started modifying the IR.
+void OneShotAnalysisState::gatherYieldedTensors(Operation *op) {
+  op->walk([&](Operation *returnOp) {
+    if (!isRegionReturnLike(returnOp) || !getOptions().isOpAllowed(returnOp))
+      return WalkResult::advance();
+
+    for (OpOperand &returnValOperand : returnOp->getOpOperands()) {
+      Value returnVal = returnValOperand.get();
+      // Skip non-tensor values.
+      if (!returnVal.getType().isa<TensorType>())
+        continue;
+
+      // Add all aliases of the returned value. But only the ones that are in
+      // the same block.
+      aliasInfo.applyOnAliases(returnVal, [&](Value v) {
+        if (auto bbArg = v.dyn_cast<BlockArgument>()) {
+          if (bbArg.getOwner()->getParentOp() == returnOp->getParentOp())
+            yieldedTensors.insert(bbArg);
+          return;
+        }
+        Operation *definingOp = v.getDefiningOp();
+        if (definingOp->getParentOp() == returnOp->getParentOp())
+          yieldedTensors.insert(v);
+      });
+    }
+
+    return WalkResult::advance();
+  });
+}
+
+bool OneShotAnalysisState::isTensorYielded(Value tensor) const {
+  return yieldedTensors.contains(tensor);
+}
+
 //===----------------------------------------------------------------------===//
 // Bufferization-specific alias analysis.
 //===----------------------------------------------------------------------===//
@@ -780,6 +817,9 @@ LogicalResult bufferization::analyzeOp(Operation *op,
         failed(assertDestinationPassingStyle(op, state, aliasInfo, newOps));
   }
 
+  // Gather all yielded tensors.
+  state.gatherYieldedTensors(op);
+
   // Analysis verification: After setting up alias/equivalence sets, each op
   // can check for expected invariants/limitations and fail the analysis if
   // necessary.

diff  --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 67e28c46f3969..0efebdfc9d41a 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -335,9 +335,8 @@ struct FromElementsOpInterface
     Location loc = op->getLoc();
     auto tensorType = fromElementsOp.getType().cast<RankedTensorType>();
     auto shape = tensorType.getShape();
-    MemRefType resultType = getContiguousMemRefType(tensorType);
     FailureOr<Value> maybeBuffer =
-        state.createAlloc(rewriter, loc, resultType, {});
+        state.createAlloc(rewriter, loc, fromElementsOp.result());
     if (failed(maybeBuffer))
       return failure();
     Value buffer = *maybeBuffer;
@@ -386,8 +385,8 @@ struct GenerateOpInterface
     Location loc = op->getLoc();
     MemRefType memrefType =
         getContiguousMemRefType(generateOp.getType().cast<RankedTensorType>());
-    FailureOr<Value> maybeResult = state.createAlloc(
-        rewriter, loc, memrefType, generateOp.dynamicExtents());
+    FailureOr<Value> maybeResult =
+        state.createAlloc(rewriter, loc, generateOp.result());
     if (failed(maybeResult))
       return failure();
     Value result = *maybeResult;

diff  --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
index 0ea283fc9f6cc..f0fe50c522b32 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
@@ -68,4 +68,67 @@ func @empty_func() -> () {
   return
 }
 
+// -----
+
+// CHECK-LABEL: func @read_after_write_conflict(
+func @read_after_write_conflict(%cst : f32, %idx : index, %idx2 : index)
+    -> (f32, f32) {
+  // CHECK-DAG: %[[alloc:.*]] = memref.alloc
+  // CHECK-DAG: %[[dummy:.*]] = "test.dummy_op"
+  // CHECK-DAG: %[[dummy_m:.*]] = bufferization.to_memref %[[dummy]]
+  %t = "test.dummy_op"() : () -> (tensor<10xf32>)
+
+  // CHECK: memref.copy %[[dummy_m]], %[[alloc]]
+  // CHECK: memref.store %{{.*}}, %[[alloc]]
+  %write = tensor.insert %cst into %t[%idx2] : tensor<10xf32>
+
+  // CHECK: %[[read:.*]] = "test.some_use"(%[[dummy]])
+  %read = "test.some_use"(%t) : (tensor<10xf32>) -> (f32)
+  // CHECK: %[[read2:.*]] = memref.load %[[alloc]]
+  %read2 = tensor.extract %write[%idx] : tensor<10xf32>
+
+  // CHECK: memref.dealloc %[[alloc]]
+  // CHECK: return %[[read]], %[[read2]]
+  return %read, %read2 : f32, f32
+}
+
+// -----
+
+// CHECK-LABEL: func @copy_deallocated(
+func @copy_deallocated() -> tensor<10xf32> {
+  // CHECK: %[[alloc:.*]] = memref.alloc()
+  %0 = linalg.init_tensor[10] : tensor<10xf32>
+  // CHECK: %[[alloc_tensor:.*]] = bufferization.to_tensor %[[alloc]]
+  // CHECK: memref.dealloc %[[alloc]]
+  // CHECK: return %[[alloc_tensor]]
+  return %0 : tensor<10xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @buffer_not_deallocated(
+//  CHECK-SAME:     %[[t:.*]]: tensor<?xf32>
+func @buffer_not_deallocated(%t : tensor<?xf32>, %c : i1) -> tensor<?xf32> {
+  // CHECK: %[[r:.*]] = scf.if %{{.*}} {
+  %r = scf.if %c -> tensor<?xf32> {
+    // CHECK: %[[some_op:.*]] = "test.some_op"
+    // CHECK: %[[alloc:.*]] = memref.alloc(%[[some_op]])
+    // CHECK: %[[casted:.*]] = memref.cast %[[alloc]]
+    // CHECK-NOT: dealloc
+    // CHECK: scf.yield %[[casted]]
+    %sz = "test.some_op"() : () -> (index)
+    %0 = linalg.init_tensor[%sz] : tensor<?xf32>
+    scf.yield %0 : tensor<?xf32>
+  } else {
+  // CHECK: } else {
+    // CHECK: %[[m:.*]] = bufferization.to_memref %[[t]]
+    // CHECK: scf.yield %[[m]]
+    scf.yield %t : tensor<?xf32>
+  }
+  // CHECK: }
+  // CHECK-NOT: dealloc
+  // CHECK: %[[r_tensor:.*]] = bufferization.to_tensor %[[r]]
+  // CHECK: return %[[r_tensor]]
+  return %r : tensor<?xf32>
+}
 

diff  --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
index 53c3a603ca03d..a39fb207aa050 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
@@ -81,7 +81,6 @@ func @not_inplace(
   //     CHECK: linalg.fill ins(%[[F0]] : f32) outs(%[[ALLOC]] : memref<?xf32>)
   %r = linalg.fill ins(%f0 : f32) outs(%A : tensor<?xf32>) -> tensor<?xf32>
 
-  //     CHECK:  dealloc %[[ALLOC]] : memref<?xf32>
   //     CHECK:  return %[[ALLOC]] : memref<?xf32>
   return %r: tensor<?xf32>
 }
@@ -292,7 +291,6 @@ func @insert_slice_fun_not_inplace(
   //      CHECK: memref.copy %[[A]], %[[ALLOC]] : memref<?xf32{{.*}} to memref<?xf32>
   //      CHECK: %[[SV:.*]] = memref.subview %[[ALLOC]][0] [4] [1] : memref<?xf32> to memref<4xf32>
   //      CHECK: memref.copy %[[t]], %[[SV]] : memref<4xf32, #map> to memref<4xf32>
-  //      CHECK: memref.dealloc %[[ALLOC]] : memref<?xf32>
   %r0 = tensor.insert_slice %t into %A[0][4][1] : tensor<4xf32> into tensor<?xf32>
 
   //     CHECK: return %{{.*}} : memref<?xf32>
@@ -329,7 +327,6 @@ func @scf_for_yield_only(%A : tensor<?xf32> {linalg.inplaceable = false},
     scf.yield %t : tensor<?xf32>
   }
 
-  //     CHECK:   memref.dealloc %[[ALLOC_FOR_A]] : memref<?xf32>
   //     CHECK:   return %[[CASTED]] : memref<?xf32, #[[$map_1d_dyn]]>
   return %r0, %r1: tensor<?xf32>, tensor<?xf32>
 }
@@ -395,7 +392,6 @@ func @scf_for_with_tensor.insert_slice(
     scf.yield %ttA, %ttB : tensor<?xf32>, tensor<?xf32>
   }
 
-  //     CHECK:  memref.dealloc %[[ALLOC_FOR_A]] : memref<?xf32>
   //     CHECK:  return %[[CASTED]] : memref<?xf32, #[[$map_1d_dyn]]>
   return %r0#0, %r0#1: tensor<?xf32>, tensor<?xf32>
 }


        


More information about the Mlir-commits mailing list