[Mlir-commits] [mlir] 9e24f0f - [mlir][bufferize] Do not deallocate allocs that are returned from a block
Matthias Springer
llvmlistbot at llvm.org
Wed Mar 16 02:59:38 PDT 2022
Author: Matthias Springer
Date: 2022-03-16T18:59:27+09:00
New Revision: 9e24f0f4589dfdbc405f72eddd174af7511b2ff3
URL: https://github.com/llvm/llvm-project/commit/9e24f0f4589dfdbc405f72eddd174af7511b2ff3
DIFF: https://github.com/llvm/llvm-project/commit/9e24f0f4589dfdbc405f72eddd174af7511b2ff3.diff
LOG: [mlir][bufferize] Do not deallocate allocs that are returned from a block
Such IR (returning a newly allocated buffer from a block) is rejected by default, but can be allowed with `allow-return-memref`. In preparation for future refactorings, do not deallocate such buffers.
One-Shot Analysis now gathers information about yielded tensors, so that the actual bufferization knows whether a newly allocated buffer should be deallocated again. (Buffers that are not deallocated currently leak. This will be addressed in a subsequent commit that also makes `allow-return-memref` a non-experimental flag.)
As a cleanup, `allow-return-memref` is now part of OneShotBufferizationOptions. (It was previously ignored by AlwaysCopyBufferizationState.) Moreover, AlwaysCopyBufferizationState now asserts that `create-deallocs` is deactivated to prevent surprising behavior.
Differential Revision: https://reviews.llvm.org/D121521
Added:
Modified:
mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index 6860bec2386ab..68136fda97384 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -177,10 +177,6 @@ struct BufferizationOptions {
Optional<DeallocationFn> deallocationFn;
Optional<MemCpyFn> memCpyFn;
- /// Specifies whether returning newly allocated memrefs should be allowed.
- /// Otherwise, a pass failure is triggered.
- bool allowReturnMemref = false;
-
/// Specifies whether not bufferizable ops are allowed in the input. If so,
/// bufferization.to_memref and bufferization.to_tensor ops are inserted at
/// the boundaries.
@@ -356,7 +352,14 @@ class AnalysisState {
/// Return true if `v1` and `v2` bufferize to equivalent buffers.
virtual bool areEquivalentBufferizedValues(Value v1, Value v2) const = 0;
- /// Return dialect-specific analysis state.
+ /// Return true if the given tensor (or an aliasing tensor) is yielded from
+ /// the containing block. Also include all aliasing tensors in the same block.
+ ///
+ /// Note: In the absence of an analysis, an implementation may return true for
+ /// any given tensor.
+ virtual bool isTensorYielded(Value tensor) const = 0;
+
+ /// Return dialect-specific bufferization state.
template <typename StateT>
Optional<const StateT *> getDialectState(StringRef name) const {
auto it = dialectState.find(name);
@@ -415,6 +418,10 @@ class AlwaysCopyAnalysisState : public AnalysisState {
/// Return true if `v1` and `v2` bufferize to equivalent buffers.
bool areEquivalentBufferizedValues(Value v1, Value v2) const override;
+
+ /// Return true if the given tensor (or an aliasing tensor) is yielded from
+ /// the containing block. Also include all aliasing tensors in the same block.
+ bool isTensorYielded(Value tensor) const override;
};
/// BufferizationState provides helper functions for performing bufferization
@@ -423,14 +430,20 @@ struct BufferizationState {
BufferizationState(const AnalysisState &analysisState)
: analysisState(analysisState) {}
- /// Creates a memref allocation with the given type and dynamic extents.
- FailureOr<Value> createAlloc(OpBuilder &b, Location loc, MemRefType type,
- ValueRange dynShape);
-
- /// Creates a memref allocation for the given shaped value. This function may
- /// perform additional optimizations such as buffer allocation hoisting.
- // TODO: Allocation hoisting should be a cleanup pass.
- FailureOr<Value> createAlloc(OpBuilder &b, Location loc, Value shapedValue);
+ /// Creates a memref allocation for the given shaped value. `dealloc`
+ /// indicates whether the buffer should be deallocated or not. When `dealloc`
+ /// is `false`, this would create a memory leak, unless the buffer is
+ /// deallocated through some other mechanism.
+ ///
+ /// `dealloc` is optional. By default, this function will figure out by itself
+ /// if it is safe to deallocate the buffer. In essence, when returning the
+ /// buffer from a block, it is not safe to deallocate the buffer. This
+ /// information is queried via `AnalysisState::isTensorYielded`.
+ ///
+ /// Note: `shapedValue` is typically a tensor value. However, if it is a
+ /// memref value, `dealloc` is no longer optional and must be specified.
+ FailureOr<Value> createAlloc(OpBuilder &b, Location loc, Value shapedValue,
+ Optional<bool> dealloc = None);
/// Return the buffer (memref) for a given OpOperand (tensor). Allocate
/// a new buffer and copy over data from the existing buffer if out-of-place
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
index de555988dd549..2a954f3ea1036 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
@@ -43,6 +43,10 @@ struct OneShotBufferizationOptions : public BufferizationOptions {
/// Registered post analysis steps.
PostAnalysisStepList postAnalysisSteps;
+
+ /// Specifies whether returning newly allocated memrefs should be allowed.
+ /// Otherwise, a pass failure is triggered.
+ bool allowReturnMemref = false;
};
/// The BufferizationAliasInfo class maintains a list of buffer aliases and
@@ -153,10 +157,22 @@ class OneShotAnalysisState : public AnalysisState {
/// Return true if `v1` and `v2` bufferize to equivalent buffers.
bool areEquivalentBufferizedValues(Value v1, Value v2) const override;
+ /// Return true if the given tensor (or an aliasing tensor) is yielded from
+ /// the containing block. Also include all aliasing tensors in the same block.
+ bool isTensorYielded(Value tensor) const override;
+
+ /// Find all tensors that are yielded/returned from a block and store them in
+ /// `yieldedTensors`. Also include all aliasing tensors in the same block.
+ void gatherYieldedTensors(Operation *op);
+
private:
/// `aliasInfo` keeps track of aliasing and equivalent values. Only internal
/// functions and `runOneShotBufferize` may access this object.
BufferizationAliasInfo aliasInfo;
+
+ /// A set of all tensors (and maybe aliasing tensors) that yielded from a
+ /// block.
+ DenseSet<Value> yieldedTensors;
};
/// Analyze `op` and its nested ops. Bufferization decisions are stored in
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index cc697487b07fb..7d21c76d58502 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -42,8 +42,12 @@ constexpr const ::llvm::StringLiteral
constexpr const ::llvm::StringLiteral
bufferization::BufferizableOpInterface::kInplaceableAttrName;
+/// Attribute name used to mark allocs that are created by the bufferization.
static const char *kBufferAllocationAttr = "bufferization.allocation";
+/// Attribute name used to mark allocs that should not be deallocated.
+static const char *kSkipDeallocAttr = "bufferization.skip_dealloc";
+
//===----------------------------------------------------------------------===//
// BufferizationOptions
//===----------------------------------------------------------------------===//
@@ -253,6 +257,8 @@ BufferizationState::getBuffer(RewriterBase &rewriter, OpOperand &opOperand,
OpBuilder::InsertionGuard guard(rewriter);
Operation *op = opOperand.getOwner();
Location loc = op->getLoc();
+ SmallVector<OpResult> aliasingOpResults =
+ analysisState.getAliasingOpResult(opOperand);
Value operand = opOperand.get();
Value operandBuffer = lookupBuffer(rewriter, operand, options);
@@ -263,8 +269,13 @@ BufferizationState::getBuffer(RewriterBase &rewriter, OpOperand &opOperand,
// Move insertion point right after `operandBuffer`. That is where the
// allocation should be inserted (in the absence of allocation hoisting).
setInsertionPointAfter(rewriter, operandBuffer);
- // Allocate the result buffer.
- FailureOr<Value> resultBuffer = createAlloc(rewriter, loc, operandBuffer);
+ // Allocate the result buffer. The buffer should be deallocated if the tensor
+ // is not yielded and deallocs are enabled in general.
+ bool dealloc = llvm::none_of(aliasingOpResults, [&](Value v) {
+ return getAnalysisState().isTensorYielded(v);
+ });
+ FailureOr<Value> resultBuffer = createAlloc(
+ rewriter, loc, operandBuffer, dealloc && getOptions().createDeallocs);
if (failed(resultBuffer))
return failure();
// Do not copy if the last preceding writes of `operand` are ops that do
@@ -281,8 +292,6 @@ BufferizationState::getBuffer(RewriterBase &rewriter, OpOperand &opOperand,
}))
return resultBuffer;
// Do not copy if the copied data is never read.
- SmallVector<OpResult> aliasingOpResults =
- analysisState.getAliasingOpResult(opOperand);
if (!aliasingOpResults.empty() &&
!analysisState.bufferizesToMemoryRead(opOperand) &&
llvm::none_of(aliasingOpResults, [&](OpResult opResult) {
@@ -339,7 +348,12 @@ void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
AlwaysCopyAnalysisState::AlwaysCopyAnalysisState(
const BufferizationOptions &options)
- : AnalysisState(options) {}
+ : AnalysisState(options) {
+ // Note: Allocations must be deallocated with a subsequent run of the buffer
+ // deallocation pass.
+ assert(!options.createDeallocs &&
+ "cannot create deallocs with AlwaysCopyBufferizationState");
+}
/// Return `true` if the given OpResult has been decided to bufferize inplace.
bool AlwaysCopyAnalysisState::isInPlace(OpOperand &opOperand) const {
@@ -356,6 +370,13 @@ bool AlwaysCopyAnalysisState::areEquivalentBufferizedValues(Value v1,
return false;
}
+/// Return true if the given tensor (or an aliasing tensor) is yielded from
+/// the containing block. Also include all aliasing tensors in the same block.
+bool AlwaysCopyAnalysisState::isTensorYielded(Value tensor) const {
+ // There is no analysis, so conservatively answer "true".
+ return true;
+}
+
//===----------------------------------------------------------------------===//
// Bufferization-specific scoped alloc/dealloc insertion support.
//===----------------------------------------------------------------------===//
@@ -426,37 +447,54 @@ static MemRefType getAllocationTypeAndShape(OpBuilder &b, Location loc,
}
static Value createBufferAllocation(OpBuilder &b, Location loc, MemRefType type,
- ValueRange dynShape) {
+ ValueRange dynShape, bool skipDealloc) {
auto allocaOp = b.create<memref::AllocaOp>(loc, type, dynShape);
allocaOp->setAttr(kBufferAllocationAttr, b.getUnitAttr());
+ if (skipDealloc)
+ allocaOp->setAttr(kSkipDeallocAttr, b.getUnitAttr());
return allocaOp.getResult();
}
/// Create an allocation after `shapedValue.getDefiningOp` (or at the top of the
/// block in case of a bbArg).
FailureOr<Value> BufferizationState::createAlloc(OpBuilder &b, Location loc,
- Value shapedValue) {
+ Value shapedValue,
+ Optional<bool> dealloc) {
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
+
+ // Compute allocation memref type.
assert(shapedValue.getType().isa<ShapedType>());
MemRefType memRefType = shapedValue.getType().dyn_cast<MemRefType>();
SmallVector<Value> dynShape;
MemRefType allocMemRefType =
getAllocationTypeAndShape(b, loc, shapedValue, dynShape);
- Value alloc = createBufferAllocation(b, loc, allocMemRefType, dynShape);
+
+ // Should be the buffer be deallocated again or should we let it leak?
+ bool skipDealloc;
+ if (dealloc) {
+ skipDealloc = !dealloc.getValue();
+ } else {
+ assert(shapedValue.getType().isa<TensorType>() &&
+ "must specify `dealloc` if non-tensor value is passed");
+ // Buffer should be not be deallocated if deallocs are generally deactivated
+ // or if the tensor is yielded from a block.
+ skipDealloc = !getOptions().createDeallocs ||
+ getAnalysisState().isTensorYielded(shapedValue);
+ }
+
+ // Create the buffer allocation.
+ Value alloc =
+ createBufferAllocation(b, loc, allocMemRefType, dynShape, skipDealloc);
+
+ // Insert a cast if a different type was requested.
if (memRefType && memRefType != allocMemRefType) {
- assert(memref::CastOp::areCastCompatible(alloc.getType(), memRefType) &&
+ assert(memref::CastOp::areCastCompatible(allocMemRefType, memRefType) &&
"createAlloc: cast incompatible");
alloc = b.create<memref::CastOp>(loc, memRefType, alloc);
}
- return alloc;
-}
-/// Create a memref allocation with the given type and dynamic extents.
-FailureOr<Value> BufferizationState::createAlloc(OpBuilder &b, Location loc,
- MemRefType type,
- ValueRange dynShape) {
- return createBufferAllocation(b, loc, type, dynShape);
+ return alloc;
}
/// Create a memory copy between two memref buffers.
@@ -480,7 +518,9 @@ createAllocDeallocOps(Operation *op, const BufferizationOptions &options) {
// Ignore memref.alloca ops that were not created by the bufferization.
if (!allocaOp->hasAttr(kBufferAllocationAttr))
return WalkResult::skip();
+ bool skipDealloc = allocaOp->hasAttr(kSkipDeallocAttr);
+ // Create alloc.
Block *block = allocaOp->getBlock();
rewriter.setInsertionPoint(allocaOp);
FailureOr<Value> alloc =
@@ -490,10 +530,11 @@ createAllocDeallocOps(Operation *op, const BufferizationOptions &options) {
return WalkResult::interrupt();
rewriter.replaceOp(allocaOp, *alloc);
- // Stop here if deallocations are deactivated.
- if (!options.createDeallocs)
+ // Stop here if the buffer should not be deallocated.
+ if (skipDealloc)
return WalkResult::advance();
+ // Create dealloc.
rewriter.setInsertionPoint(block->getTerminator());
if (failed(createDealloc(rewriter, alloc->getLoc(), *alloc, options)))
return WalkResult::interrupt();
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index f237cb7a6a70e..85a5c5c120d2b 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -379,7 +379,6 @@ LogicalResult bufferization::bufferizeOp(Operation *op,
BufferizationOptions bufferization::getPartialBufferizationOptions() {
BufferizationOptions options;
- options.allowReturnMemref = true;
options.allowUnknownOps = true;
options.createDeallocs = false;
options.fullyDynamicLayoutMaps = false;
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
index 706072d7b9c10..b0bad7a32f2fb 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
@@ -215,6 +215,43 @@ bool OneShotAnalysisState::areEquivalentBufferizedValues(Value v1,
return aliasInfo.areEquivalentBufferizedValues(v1, v2);
}
+// Gather yielded tensors in `yieldedTensors` by querying all aliases. This is
+// to ensure that such information is available during bufferization time.
+// Alias information can no longer be queried through BufferizationAliasInfo
+// once we have started modifying the IR.
+void OneShotAnalysisState::gatherYieldedTensors(Operation *op) {
+ op->walk([&](Operation *returnOp) {
+ if (!isRegionReturnLike(returnOp) || !getOptions().isOpAllowed(returnOp))
+ return WalkResult::advance();
+
+ for (OpOperand &returnValOperand : returnOp->getOpOperands()) {
+ Value returnVal = returnValOperand.get();
+ // Skip non-tensor values.
+ if (!returnVal.getType().isa<TensorType>())
+ continue;
+
+ // Add all aliases of the returned value. But only the ones that are in
+ // the same block.
+ aliasInfo.applyOnAliases(returnVal, [&](Value v) {
+ if (auto bbArg = v.dyn_cast<BlockArgument>()) {
+ if (bbArg.getOwner()->getParentOp() == returnOp->getParentOp())
+ yieldedTensors.insert(bbArg);
+ return;
+ }
+ Operation *definingOp = v.getDefiningOp();
+ if (definingOp->getParentOp() == returnOp->getParentOp())
+ yieldedTensors.insert(v);
+ });
+ }
+
+ return WalkResult::advance();
+ });
+}
+
+bool OneShotAnalysisState::isTensorYielded(Value tensor) const {
+ return yieldedTensors.contains(tensor);
+}
+
//===----------------------------------------------------------------------===//
// Bufferization-specific alias analysis.
//===----------------------------------------------------------------------===//
@@ -780,6 +817,9 @@ LogicalResult bufferization::analyzeOp(Operation *op,
failed(assertDestinationPassingStyle(op, state, aliasInfo, newOps));
}
+ // Gather all yielded tensors.
+ state.gatherYieldedTensors(op);
+
// Analysis verification: After setting up alias/equivalence sets, each op
// can check for expected invariants/limitations and fail the analysis if
// necessary.
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 67e28c46f3969..0efebdfc9d41a 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -335,9 +335,8 @@ struct FromElementsOpInterface
Location loc = op->getLoc();
auto tensorType = fromElementsOp.getType().cast<RankedTensorType>();
auto shape = tensorType.getShape();
- MemRefType resultType = getContiguousMemRefType(tensorType);
FailureOr<Value> maybeBuffer =
- state.createAlloc(rewriter, loc, resultType, {});
+ state.createAlloc(rewriter, loc, fromElementsOp.result());
if (failed(maybeBuffer))
return failure();
Value buffer = *maybeBuffer;
@@ -386,8 +385,8 @@ struct GenerateOpInterface
Location loc = op->getLoc();
MemRefType memrefType =
getContiguousMemRefType(generateOp.getType().cast<RankedTensorType>());
- FailureOr<Value> maybeResult = state.createAlloc(
- rewriter, loc, memrefType, generateOp.dynamicExtents());
+ FailureOr<Value> maybeResult =
+ state.createAlloc(rewriter, loc, generateOp.result());
if (failed(maybeResult))
return failure();
Value result = *maybeResult;
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
index 0ea283fc9f6cc..f0fe50c522b32 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
@@ -68,4 +68,67 @@ func @empty_func() -> () {
return
}
+// -----
+
+// CHECK-LABEL: func @read_after_write_conflict(
+func @read_after_write_conflict(%cst : f32, %idx : index, %idx2 : index)
+ -> (f32, f32) {
+ // CHECK-DAG: %[[alloc:.*]] = memref.alloc
+ // CHECK-DAG: %[[dummy:.*]] = "test.dummy_op"
+ // CHECK-DAG: %[[dummy_m:.*]] = bufferization.to_memref %[[dummy]]
+ %t = "test.dummy_op"() : () -> (tensor<10xf32>)
+
+ // CHECK: memref.copy %[[dummy_m]], %[[alloc]]
+ // CHECK: memref.store %{{.*}}, %[[alloc]]
+ %write = tensor.insert %cst into %t[%idx2] : tensor<10xf32>
+
+ // CHECK: %[[read:.*]] = "test.some_use"(%[[dummy]])
+ %read = "test.some_use"(%t) : (tensor<10xf32>) -> (f32)
+ // CHECK: %[[read2:.*]] = memref.load %[[alloc]]
+ %read2 = tensor.extract %write[%idx] : tensor<10xf32>
+
+ // CHECK: memref.dealloc %[[alloc]]
+ // CHECK: return %[[read]], %[[read2]]
+ return %read, %read2 : f32, f32
+}
+
+// -----
+
+// CHECK-LABEL: func @copy_deallocated(
+func @copy_deallocated() -> tensor<10xf32> {
+ // CHECK: %[[alloc:.*]] = memref.alloc()
+ %0 = linalg.init_tensor[10] : tensor<10xf32>
+ // CHECK: %[[alloc_tensor:.*]] = bufferization.to_tensor %[[alloc]]
+ // CHECK: memref.dealloc %[[alloc]]
+ // CHECK: return %[[alloc_tensor]]
+ return %0 : tensor<10xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @buffer_not_deallocated(
+// CHECK-SAME: %[[t:.*]]: tensor<?xf32>
+func @buffer_not_deallocated(%t : tensor<?xf32>, %c : i1) -> tensor<?xf32> {
+ // CHECK: %[[r:.*]] = scf.if %{{.*}} {
+ %r = scf.if %c -> tensor<?xf32> {
+ // CHECK: %[[some_op:.*]] = "test.some_op"
+ // CHECK: %[[alloc:.*]] = memref.alloc(%[[some_op]])
+ // CHECK: %[[casted:.*]] = memref.cast %[[alloc]]
+ // CHECK-NOT: dealloc
+ // CHECK: scf.yield %[[casted]]
+ %sz = "test.some_op"() : () -> (index)
+ %0 = linalg.init_tensor[%sz] : tensor<?xf32>
+ scf.yield %0 : tensor<?xf32>
+ } else {
+ // CHECK: } else {
+ // CHECK: %[[m:.*]] = bufferization.to_memref %[[t]]
+ // CHECK: scf.yield %[[m]]
+ scf.yield %t : tensor<?xf32>
+ }
+ // CHECK: }
+ // CHECK-NOT: dealloc
+ // CHECK: %[[r_tensor:.*]] = bufferization.to_tensor %[[r]]
+ // CHECK: return %[[r_tensor]]
+ return %r : tensor<?xf32>
+}
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
index 53c3a603ca03d..a39fb207aa050 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
@@ -81,7 +81,6 @@ func @not_inplace(
// CHECK: linalg.fill ins(%[[F0]] : f32) outs(%[[ALLOC]] : memref<?xf32>)
%r = linalg.fill ins(%f0 : f32) outs(%A : tensor<?xf32>) -> tensor<?xf32>
- // CHECK: dealloc %[[ALLOC]] : memref<?xf32>
// CHECK: return %[[ALLOC]] : memref<?xf32>
return %r: tensor<?xf32>
}
@@ -292,7 +291,6 @@ func @insert_slice_fun_not_inplace(
// CHECK: memref.copy %[[A]], %[[ALLOC]] : memref<?xf32{{.*}} to memref<?xf32>
// CHECK: %[[SV:.*]] = memref.subview %[[ALLOC]][0] [4] [1] : memref<?xf32> to memref<4xf32>
// CHECK: memref.copy %[[t]], %[[SV]] : memref<4xf32, #map> to memref<4xf32>
- // CHECK: memref.dealloc %[[ALLOC]] : memref<?xf32>
%r0 = tensor.insert_slice %t into %A[0][4][1] : tensor<4xf32> into tensor<?xf32>
// CHECK: return %{{.*}} : memref<?xf32>
@@ -329,7 +327,6 @@ func @scf_for_yield_only(%A : tensor<?xf32> {linalg.inplaceable = false},
scf.yield %t : tensor<?xf32>
}
- // CHECK: memref.dealloc %[[ALLOC_FOR_A]] : memref<?xf32>
// CHECK: return %[[CASTED]] : memref<?xf32, #[[$map_1d_dyn]]>
return %r0, %r1: tensor<?xf32>, tensor<?xf32>
}
@@ -395,7 +392,6 @@ func @scf_for_with_tensor.insert_slice(
scf.yield %ttA, %ttB : tensor<?xf32>, tensor<?xf32>
}
- // CHECK: memref.dealloc %[[ALLOC_FOR_A]] : memref<?xf32>
// CHECK: return %[[CASTED]] : memref<?xf32, #[[$map_1d_dyn]]>
return %r0#0, %r0#1: tensor<?xf32>, tensor<?xf32>
}
More information about the Mlir-commits
mailing list