[Mlir-commits] [mlir] 2c5c5ca - [mlir][linalg][bufferize] Fix CallOp bufferization
Matthias Springer
llvmlistbot at llvm.org
Tue Jan 11 03:10:34 PST 2022
Author: Matthias Springer
Date: 2022-01-11T20:10:21+09:00
New Revision: 2c5c5ca8681a2788229cde61d09129316448508b
URL: https://github.com/llvm/llvm-project/commit/2c5c5ca8681a2788229cde61d09129316448508b
DIFF: https://github.com/llvm/llvm-project/commit/2c5c5ca8681a2788229cde61d09129316448508b.diff
LOG: [mlir][linalg][bufferize] Fix CallOp bufferization
Previously, CallOps did not have any aliasing OpResult/OpOperand pairs. Therefore, CallOps were mostly ignored by the analysis and buffer copies were not inserted when necessary.
This commit introduces the following changes:
* Function bbArgs are writable by default. A function can now be bufferized without inspecting its callers.
* Callers must introduce buffer copies of function arguments when necessary. If a function is external, the caller must conservatively assume that a function argument is modified by the callee after bufferization. If the function is not external, the caller inspects the callee to determine if a function argument is modified.
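The caller-side rule described above can be sketched in standalone C++ without any MLIR dependency. This is only an illustrative model, not the code in this commit: `ArgSummary`, `Callee`, `summarize` and `needsCopy` are hypothetical names, and the two boolean inputs (`readAfterCall`, `callerBufferWritable`) stand in for what the real analysis derives from alias information.

```cpp
#include <cstddef>
#include <optional>
#include <vector>

// Hypothetical per-argument summary, standing in for the read/written bbArg
// sets that the analysis records for each function.
struct ArgSummary {
  bool read;
  bool written;
};

// Hypothetical callee model: an empty optional means the function is
// external, i.e. it has no body that could be analyzed.
struct Callee {
  std::optional<std::vector<ArgSummary>> args;
};

// External functions are conservatively treated as reading and writing
// every argument. Otherwise the analyzed summary is used.
ArgSummary summarize(const Callee &callee, std::size_t idx) {
  if (!callee.args)
    return {true, true};
  return (*callee.args)[idx];
}

// Simplified decision (an assumption of this sketch): a copy is needed when
// the callee may write the argument and the caller either still needs the
// old contents afterwards or does not own a writable buffer.
bool needsCopy(const Callee &callee, std::size_t idx, bool readAfterCall,
               bool callerBufferWritable) {
  if (!summarize(callee, idx).written)
    return false;
  return readAfterCall || !callerBufferWritable;
}
```

Under this model, a call to an external function with a constant operand always needs a copy, while a call to an analyzed function that only reads its argument never does.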
Differential Revision: https://reviews.llvm.org/D116457
Added:
Modified:
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
index 2167908414319..fe5fc26c3d2ba 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
@@ -256,6 +256,7 @@ bool mlir::linalg::comprehensive_bufferize::BufferizationState::
/// themselves (e.g., ExtractSliceOp).
bool mlir::linalg::comprehensive_bufferize::BufferizationState::isValueRead(
Value value) const {
+ assert(value.getType().isa<TensorType>() && "expected TensorType");
SmallVector<OpOperand *> workingSet;
for (OpOperand &use : value.getUses())
workingSet.push_back(&use);
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
index 95ecb21cf8e96..5bf26365caa6e 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
@@ -6,87 +6,68 @@
//
//===----------------------------------------------------------------------===//
//
-// Module bufferization is an extension of Comprehensive Bufferize that
+// Module Bufferization is an extension of Comprehensive Bufferize that
// bufferizes function boundaries. It provides `BufferizableOpInterface`
-// implementations for FuncOp, CallOp and ReturnOp, along with a few helper
-// functions that control the order in which functions are bufferized.
+// implementations for FuncOp, CallOp and ReturnOp.
//
-// Three cases can occur during bufferization of FuncOps.
+// Module Bufferization is run via `runComprehensiveBufferize(ModuleOp, ...)`.
+// This function analyzed the given module and determines the order of
+// analysis and bufferization: Functions that are called are processed before
+// their respective callers.
//
-// i. inplaceable function arguments may be reused in place after the
-// function itself has been bufferized. This is encoded by IR resembling:
+// After analyzing a FuncOp, additional information about its bbArgs is
+// gathered through PostAnalysisSteps and stored in `ModuleBufferizationState`.
//
-// ```
-// #map = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
-// func @foo(%A: tensor<?xf32> {linalg.inplaceable = true})
-// -> tensor<?xf32> {
-// %0 = bufferization.to_memref %A : memref<?xf32, #map>
-// // ... uses of %0
-// %res = bufferization.to_tensor %0 : memref<?xf32, #map>
-// return %res : tensor<?xf32>
-// }
-// ```
+// * `EquivalentFuncOpBBArgsAnalysis` determines the equivalent bbArg for each
+// tensor return value (if any).
+// * `FuncOpBbArgReadWriteAnalysis` determines whether or not a tensor bbArg is
+// read/written.
//
-// this is the cue for the bufferization of the function foo (and calls
-// to it) may bufferize to `func @foo(%A: memref<?xf32, some_layout>)`.
-// To fully achieve bufferization, an additional analysis is needed to
-// determine whether function argument/operand pairs bufferize to a
-// single inplace buffer argument (i.e. functions may return tensors in
-// arbitrary order that may not match argument numbers).
+// Only tensors that are equivalent to some FuncOp bbArg may be returned.
+// Bufferization currently fails if other tensors (in particular tensors that
+// bufferize out-of-place and result in a new buffer allocation) are returned.
+// In the future, such allocations could be hoisted to the caller.
//
-// ii. results that don't map to an inplaceable function argument are
-// generally allocated. Since memref semantics wrt ownership of the
-// underlying memory region are not well-defined, comprehensive
-// bufferization chooses to perform allocations in a scoped fashion:
-// returning memrefs is always considered illegal.
-// Such scenarios are encoded by IR resembling:
+// Example: `foo` fails bufferization because %0 is not equivalent to any bbArg.
+// ```
+// func @foo() -> tensor<?xf32> {
+// %0 = linalg.init_tensor [...] : tensor<?xf32>
+// return %0 : tensor<?xf32>
+// }
+// ```
//
-// ```
-// #map = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
-// func @foo(%A: tensor<?xf32> {linalg.inplaceable = true})
-// -> tensor<?xf32> {
-// %0 = bufferization.to_memref %A : memref<?xf32, #map>
-// %1 = memref.dim %0, %c0 : memref<?xf32, #map>
-// %2 = memref.alloc(%1) : memref<?xf32>
-// %3 = memref.cast %2 : memref<?xf32> to memref<?xf32, #map>
-// // ... uses of %3
-// memref.dealloc %2 : memref<?xf32, #map>
-// %res = bufferization.to_tensor %3 : memref<?xf32, #map>
-// return %res : tensor<?xf32>
-// }
-// ```
+// Module Bufferization implements the following calling convention.
//
-// this is the cue for the bufferization of the function foo (and calls
-// to it) that it must bufferize to `func @foo(%A: memref<?xf32,
-// some_layout>,
-// %B: memref<?xf32, some_layout>)` (i.e. make a cloned
-// allocation of the result tensor)
-// To fully achieve bufferization, the alloc/dealloc pair must be lifted
-// out of the function at each call site.
+// * In the absence of conflicts within a FuncOp, the FuncOp's bbArgs may always
+// be written to in-place.
+// * If a tensor operand of a CallOp is read after the CallOp, the operand of
+// the CallOp must bufferize out-of-place.
//
-// iii. as an optimization over ii., it may be possible to reuse an argument
-// and only want to return a slice.
-// This may forego allocation by letting *all* callers decide whether to
-// pass a new *aliasing* memref function argument (i.e. a subview).
-// Without loss of generality, callers may agree to allocate a new buffer
-// to avoid this aliasing. Such scenarios are encoded by IR resembling:
+// Example: The tensor.insert op bufferizes in-place because it is allowed to
+// modify the buffer of `%t1` directly. The CallOp in `caller` must bufferize
+// out-of-place because `%t0` is modified by the callee but read by the
+// tensor.extract op. The analysis of CallOps decides whether an OpOperand must
+// bufferize out-of-place based on results of `FuncOpBbArgReadWriteAnalysis`.
+// ```
+// func @callee(%t1 : tensor<?xf32>) -> tensor<?xf32> {
+// %f = ... : f32
+// %0 = tensor.insert %f into %t1[...] : tensor<?xf32>
+// return %0 : tensor<?xf32>
+// }
//
-// ```
-// #map = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
-// func @foo(%arg0: tensor<?xf32> {linalg.inplaceable = true})
-// -> tensor<4xf32> {
-// %0 = bufferization.to_memref %arg0 : memref<?xf32, #map>
-// %1 = memref.subview %0[0] [4] [1] : memref<?xf32, #map> to
-// memref<4xf32, #map>
-// // ... inplace computes into %1
-// %3 = bufferization.to_tensor %1 : memref<4xf32, #map>
-// return %3 : tensor<4xf32>
-// }
-// ```
+// func @caller() -> () {
+// %t0 = ... : tensor<?xf32>
+// %1 = call @callee(%t0) : (tensor<?xf32>) -> (tensor<?xf32>)
+// %2 = tensor.extract %1[...] : tensor<?xf32>
+// }
+// ```
//
-// Note: In the future, it may be worthwhile to design special bufferization
-// ops to encode the desired semantics at function boundaries for i., ii. and
-// iii.
+// Note: If a function is external, `FuncOpBbArgReadWriteAnalysis` cannot
+// analyze the function body. In such a case, the CallOp analysis conservatively
+// assumes that each tensor OpOperand is both read and written.
+//
+// TODO: Add FuncOp attributes so that bbArgs of external FuncOps can be marked
+// as "not reading" and/or "not writing".
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h"
@@ -103,6 +84,9 @@ using namespace tensor;
using namespace comprehensive_bufferize;
namespace {
+/// The state of analysis of a FuncOp.
+enum class FuncOpAnalysisState { NotAnalyzed, InProgress, Analyzed };
+
/// Extra bufferization state that is required for bufferization of function
/// boundaries.
struct ModuleBufferizationState : public DialectBufferizationState {
@@ -110,8 +94,22 @@ struct ModuleBufferizationState : public DialectBufferizationState {
/// indices.
DenseMap<FuncOp, DenseMap<int64_t, int64_t>> equivalentFuncArgs;
+ /// A set of all read BlockArguments of FuncOps.
+ // Note: BlockArgument knows about its owner, so we do not need to store
+ // FuncOps here.
+ DenseSet<BlockArgument> readBbArgs;
+
+ /// A set of all written-to BlockArguments of FuncOps.
+ DenseSet<BlockArgument> writtenBbArgs;
+
+ /// Keep track of which FuncOps are fully analyzed or currently being
+ /// analyzed.
+ DenseMap<FuncOp, FuncOpAnalysisState> analyzedFuncOps;
+
+ // A list of functions in the order in which they are analyzed + bufferized.
SmallVector<FuncOp> orderedFuncOps;
+ // A mapping of FuncOps to their callers.
DenseMap<FuncOp, DenseSet<Operation *>> callerMap;
};
} // namespace
@@ -133,6 +131,17 @@ getModuleBufferizationState(BufferizationState &state) {
StandardOpsDialect::getDialectNamespace());
}
+/// Return the state (phase) of analysis of the FuncOp.
+static FuncOpAnalysisState
+getFuncOpAnalysisState(const BufferizationState &state, FuncOp funcOp) {
+ const ModuleBufferizationState &moduleState =
+ getModuleBufferizationState(state);
+ auto it = moduleState.analyzedFuncOps.find(funcOp);
+ if (it == moduleState.analyzedFuncOps.end())
+ return FuncOpAnalysisState::NotAnalyzed;
+ return it->second;
+}
+
/// Return the unique ReturnOp that terminates `funcOp`.
/// Return nullptr if there is no such unique ReturnOp.
static ReturnOp getAssumedUniqueReturnOp(FuncOp funcOp) {
@@ -197,6 +206,69 @@ struct EquivalentFuncOpBBArgsAnalysis : public PostAnalysisStep {
return success();
}
};
+
+/// Return true if the buffer of the given tensor value is written to. Must not
+/// be called for values inside not yet analyzed functions. (Post-analysis
+/// steps do not have to be run yet, i.e., "in progress" is also OK.)
+static bool isValueWritten(Value value, const BufferizationState &state,
+ const BufferizationAliasInfo &aliasInfo) {
+#ifndef NDEBUG
+ assert(value.getType().isa<TensorType>() && "expected TensorType");
+ FuncOp funcOp;
+ if (auto bbArg = value.dyn_cast<BlockArgument>()) {
+ Operation *owner = bbArg.getOwner()->getParentOp();
+ funcOp = isa<FuncOp>(owner) ? cast<FuncOp>(owner)
+ : owner->getParentOfType<FuncOp>();
+ } else {
+ funcOp = value.getDefiningOp()->getParentOfType<FuncOp>();
+ }
+ assert(getFuncOpAnalysisState(state, funcOp) !=
+ FuncOpAnalysisState::NotAnalyzed &&
+ "FuncOp must be fully analyzed or analysis in progress");
+#endif // NDEBUG
+
+ bool isWritten = false;
+ aliasInfo.applyOnAliases(value, [&](Value val) {
+ for (OpOperand &use : val.getUses())
+ if (state.isInPlace(use) && state.bufferizesToMemoryWrite(use))
+ isWritten = true;
+ });
+ return isWritten;
+}
+
+/// Determine which FuncOp bbArgs are read and which are written. If this
+/// PostAnalysisStep is run on a function with unknown ops, it will
+/// conservatively assume that such ops bufferize to a read + write.
+struct FuncOpBbArgReadWriteAnalysis : public PostAnalysisStep {
+ LogicalResult run(Operation *op, BufferizationState &state,
+ BufferizationAliasInfo &aliasInfo,
+ SmallVector<Operation *> &newOps) override {
+ ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
+ auto funcOp = cast<FuncOp>(op);
+
+ // If the function has no body, conservatively assume that all args are
+ // read + written.
+ if (funcOp.getBody().empty()) {
+ for (BlockArgument bbArg : funcOp.getArguments()) {
+ moduleState.readBbArgs.insert(bbArg);
+ moduleState.writtenBbArgs.insert(bbArg);
+ }
+
+ return success();
+ }
+
+ for (BlockArgument bbArg : funcOp.getArguments()) {
+ if (!bbArg.getType().isa<TensorType>())
+ continue;
+ if (state.isValueRead(bbArg))
+ moduleState.readBbArgs.insert(bbArg);
+ if (isValueWritten(bbArg, state, aliasInfo))
+ moduleState.writtenBbArgs.insert(bbArg);
+ }
+
+ return success();
+ }
+};
} // namespace
static bool isaTensor(Type t) { return t.isa<TensorType>(); }
@@ -575,43 +647,101 @@ namespace std_ext {
static Optional<int64_t>
getEquivalentFuncArgIdx(FuncOp funcOp, const ModuleBufferizationState &state,
int64_t returnValIdx) {
- if (!state.equivalentFuncArgs.count(funcOp))
+ auto funcOpIt = state.equivalentFuncArgs.find(funcOp);
+ if (funcOpIt == state.equivalentFuncArgs.end())
// No equivalence info stores for funcOp.
return None;
- const DenseMap<int64_t, int64_t> &equivFuncArgs =
- state.equivalentFuncArgs.lookup(funcOp);
- if (!equivFuncArgs.count(returnValIdx))
+ auto retValIt = funcOpIt->getSecond().find(returnValIdx);
+ if (retValIt == funcOpIt->getSecond().end())
// Return value has no equivalent bbArg.
return None;
- return equivFuncArgs.lookup(returnValIdx);
+ return retValIt->getSecond();
}
struct CallOpInterface
: public BufferizableOpInterface::ExternalModel<CallOpInterface, CallOp> {
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
const BufferizationState &state) const {
- // CallOpInterface alone doesn't bufferize to a memory read, one of the uses
- // of the matching bbArg may. It is the responsibility of the caller to
- // inspect bbArgs. In the absence of a BufferizationAliasInfo, we need to be
- // conservative.
- return true;
+ CallOp callOp = cast<CallOp>(op);
+ FuncOp funcOp = getCalledFunction(callOp);
+ assert(funcOp && "expected CallOp to a FuncOp");
+
+ const ModuleBufferizationState &moduleState =
+ getModuleBufferizationState(state);
+ if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
+ // FuncOp not analyzed yet. Assume that OpOperand is read.
+ return true;
+
+ return moduleState.readBbArgs.contains(
+ funcOp.getArgument(opOperand.getOperandNumber()));
}
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
const BufferizationState &state) const {
- return false;
+ CallOp callOp = cast<CallOp>(op);
+ FuncOp funcOp = getCalledFunction(callOp);
+ assert(funcOp && "expected CallOp to a FuncOp");
+
+ const ModuleBufferizationState &moduleState =
+ getModuleBufferizationState(state);
+ if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
+ // FuncOp not analyzed yet. Assume that OpOperand is written.
+ return true;
+
+ return moduleState.writtenBbArgs.contains(
+ funcOp.getArgument(opOperand.getOperandNumber()));
}
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
const BufferizationState &state) const {
- // CallOpInterface is special, it needs to wait for the callee to be
- // bufferized and needs to inspect the BufferAliasInfo object. It can't
- // make a proper determination by itself and needs to be conservative.
+ CallOp callOp = cast<CallOp>(op);
+ FuncOp funcOp = getCalledFunction(callOp);
+ assert(funcOp && "expected CallOp to a FuncOp");
+ const ModuleBufferizationState &moduleState =
+ getModuleBufferizationState(state);
+
+ for (int64_t resultIdx = 0; resultIdx < callOp->getNumResults();
+ ++resultIdx)
+ if (Optional<int64_t> maybeArgNumber =
+ getEquivalentFuncArgIdx(funcOp, moduleState, resultIdx))
+ if (*maybeArgNumber == opOperand.getOperandNumber())
+ return callOp->getOpResult(resultIdx);
+
+ // Note: Returning a non-equivalent tensor from a FuncOp is currently not
+ // supported an will fail bufferization. (Even if allow-return-memref, it
+ // will fail when the function is called.)
return OpResult();
}
+ SmallVector<OpOperand *>
+ getAliasingOpOperand(Operation *op, OpResult opResult,
+ const BufferizationState &state) const {
+ CallOp callOp = cast<CallOp>(op);
+ FuncOp funcOp = getCalledFunction(callOp);
+ assert(funcOp && "expected CallOp to a FuncOp");
+ const ModuleBufferizationState &moduleState =
+ getModuleBufferizationState(state);
+
+ // TODO: We should be looking for aliasing block arguments here. The current
+ // condition is actually stronger than neccesary. Once we check for aliasing
+ // block arguments, we may be multiple.
+ if (Optional<int64_t> maybeArgNumber = getEquivalentFuncArgIdx(
+ funcOp, moduleState, opResult.getResultNumber()))
+ return {&op->getOpOperand(*maybeArgNumber)};
+
+ // Note: Returning a non-equivalent tensor from a FuncOp is currently not
+ // supported an will fail bufferization.
+ return {};
+ }
+
+ BufferRelation bufferRelation(Operation *op, OpResult opResult,
+ const BufferizationAliasInfo &aliasInfo,
+ const BufferizationState &state) const {
+ return BufferRelation::Equivalent;
+ }
+
/// In a first approximation, all the function arguments of a FuncOp are
/// marked inplaceable. For now, it is the responsibility of the `callOp`
/// bufferization to allow FuncOp that are inplaceable to write inPlace.
@@ -667,11 +797,12 @@ struct CallOpInterface
getEquivalentFuncArgIdx(funcOp, moduleState, returnValIdx)) {
// Return operands that are equivalent to some bbArg, are not
// returned.
- Value buffer =
- *state.getBuffer(rewriter, callOp->getOpOperand(*bbArgIdx),
- /*forceInPlace=*/true);
- replacementValues[returnValIdx] = buffer;
- newOperands[*bbArgIdx] = buffer;
+ FailureOr<Value> bufferOrFailure =
+ state.getBuffer(rewriter, callOp->getOpOperand(*bbArgIdx));
+ if (failed(bufferOrFailure))
+ return failure();
+ replacementValues[returnValIdx] = *bufferOrFailure;
+ newOperands[*bbArgIdx] = *bufferOrFailure;
continue;
}
@@ -700,11 +831,15 @@ struct CallOpInterface
// Retrieve buffers for tensor operands. Tensor operand buffers, who's
// corresponding FuncOp bbArgs are equivalent to a returned tensor, were
// already stored in `newOperands` during Step 1.
- Value buffer = newOperands[idx] ? newOperands[idx]
- : *state.getBuffer(rewriter, opOperand,
- /*forceInPlace=*/true);
+ Value buffer = newOperands[idx];
+ if (!buffer) {
+ FailureOr<Value> bufferOrFailure = state.getBuffer(rewriter, opOperand);
+ if (failed(bufferOrFailure))
+ return failure();
+ buffer = *bufferOrFailure;
+ }
- // Caller / callee type mistmatch is handled with a CastOp.
+ // Caller / callee type mismatch is handled with a CastOp.
auto memRefType = bufferizedFuncType.getInput(idx);
// Since we don't yet have a clear layout story, to_memref may
// conservatively turn tensors into more dynamic memref than necessary.
@@ -782,8 +917,6 @@ struct FuncOpInterface
auto funcOp = cast<FuncOp>(op);
BlockArgument bbArg = value.dyn_cast<BlockArgument>();
assert(bbArg && "expected BlockArgument");
- const ModuleBufferizationState &moduleState =
- getModuleBufferizationState(state);
// "linalg.inplaceable" overrides other writability decisions. This is
// currently used for testing only.
@@ -792,16 +925,8 @@ struct FuncOpInterface
BufferizableOpInterface::kInplaceableAttrName))
return inplaceAttr.getValue();
- // In a first approximation:
- // =========================
- // If the function is called, we can allocate on the caller side which lets
- // us force inplace arguments at function boundaries.
- // TODO: do not rely on this behavior.
- if (moduleState.callerMap.find(funcOp) != moduleState.callerMap.end())
- return true;
-
- // All other function arguments are not writable.
- return false;
+ // All function arguments are writable by default.
+ return true;
}
bool isAllocationHoistingBarrier(Operation *op) const { return true; }
@@ -849,11 +974,11 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
moduleState.callerMap)))
return failure();
- // Interestingly, all function args that are not visible outside of a module
- // can be fully bufferized inplace by guaranteeing the CallOp is bufferized
- // inplace. Therefore, we just bufferize funcOp as if none of its results were
- // inplaceable, detect which operands are cloned internally and decide what to
- // do at call sites.
+ // Collect bbArg/return value information after the analysis.
+ options->postAnalysisSteps.emplace_back(
+ std::make_unique<EquivalentFuncOpBBArgsAnalysis>());
+ options->postAnalysisSteps.emplace_back(
+ std::make_unique<FuncOpBbArgReadWriteAnalysis>());
// Analyze ops.
for (FuncOp funcOp : moduleState.orderedFuncOps) {
@@ -861,17 +986,20 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
if (funcOp.body().empty())
continue;
- // Collect bbArg/return value information after the analysis.
- options->postAnalysisSteps.emplace_back(
- std::make_unique<EquivalentFuncOpBBArgsAnalysis>());
-
- // Gather equivalence info for CallOps.
- equivalenceAnalysis(funcOp, aliasInfo, moduleState);
+ // Now analyzing function.
+ moduleState.analyzedFuncOps[funcOp] = FuncOpAnalysisState::InProgress;
// Analyze funcOp.
if (failed(analyzeOp(funcOp, state)))
return failure();
+ // Gather equivalence info for CallOps.
+ // TODO: Make this a post-analysis step.
+ equivalenceAnalysis(funcOp, aliasInfo, moduleState);
+
+ // Mark op as fully analyzed.
+ moduleState.analyzedFuncOps[funcOp] = FuncOpAnalysisState::Analyzed;
+
// Add annotations to function arguments.
if (options->testAnalysisOnly)
annotateOpsWithBufferizationMarkers(funcOp, state);
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
index 96725d16bd16c..929fc150f8946 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
@@ -630,7 +630,7 @@ func @scf_for_deps(
// of %r1 is read.
// CHECK: scf.for
// CHECK-NEXT: call
- // CHECK-SAME: {__inplace_operands_attr__ = ["true"]}
+ // CHECK-SAME: {__inplace_operands_attr__ = ["false"]}
// CHECK-NEXT: scf.yield
// CHECK-SAME: {__inplace_operands_attr__ = ["true"]}
// CHECK: } {__inplace_operands_attr__ = ["none", "none", "none", "false"]}
@@ -642,7 +642,7 @@ func @scf_for_deps(
// %r1 bufferizes inplace fine.
// CHECK: scf.for
// CHECK-NEXT: call
- // CHECK-SAME: {__inplace_operands_attr__ = ["true"]}
+ // CHECK-SAME: {__inplace_operands_attr__ = ["false"]}
// CHECK-NEXT: scf.yield
// CHECK-SAME: {__inplace_operands_attr__ = ["true"]}
// CHECK: } {__inplace_operands_attr__ = ["none", "none", "none", "true"]}
@@ -655,7 +655,7 @@ func @scf_for_deps(
// of %r3 is read.
// CHECK: linalg.tiled_loop
// CHECK-NEXT: call
- // CHECK-SAME: {__inplace_operands_attr__ = ["true"]}
+ // CHECK-SAME: {__inplace_operands_attr__ = ["false"]}
// CHECK-NEXT: linalg.yield
// CHECK-SAME: {__inplace_operands_attr__ = ["true"]}
// CHECK: } {__inplace_operands_attr__ = ["none", "none", "none", "false"]}
@@ -669,7 +669,7 @@ func @scf_for_deps(
// %r3 bufferizes inplace fine.
// CHECK: linalg.tiled_loop
// CHECK-NEXT: call
- // CHECK-SAME: {__inplace_operands_attr__ = ["true"]}
+ // CHECK-SAME: {__inplace_operands_attr__ = ["false"]}
// CHECK-NEXT: linalg.yield
// CHECK-SAME: {__inplace_operands_attr__ = ["true"]}
// CHECK: } {__inplace_operands_attr__ = ["none", "none", "none", "true"]}
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
index 05c120bcf557d..a9c2bcba865e6 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
@@ -410,7 +410,9 @@ func @main() {
// CHECK: %[[A:.*]] = memref.get_global @__constant_4xi32 : memref<4xi32>
%A = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
-// CHECK: %[[B:.*]] = memref.cast %[[A]] : memref<4xi32> to memref<4xi32, #[[$DYN_1D_MAP]]>
+// CHECK: %[[alloc:.*]] = memref.alloc
+// CHECK: %[[B:.*]] = memref.cast %[[alloc]] : memref<4xi32> to memref<4xi32, #[[$DYN_1D_MAP]]>
+// CHECK: linalg.copy(%[[A]], %[[alloc]])
// CHECK: call @some_external_func(%[[B]]) : (memref<4xi32, #[[$DYN_1D_MAP]]>) -> ()
call @some_external_func(%A) : (tensor<4xi32>) -> ()
@@ -430,7 +432,9 @@ func @main() {
// CHECK: %[[A:.*]] = memref.get_global @__constant_4xi32 : memref<4xi32>
%A = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
-// CHECK: %[[B:.*]] = memref.cast %[[A]] : memref<4xi32> to memref<4xi32, #[[$DYN_1D_MAP]]>
+// CHECK: %[[alloc:.*]] = memref.alloc
+// CHECK: %[[B:.*]] = memref.cast %[[alloc]] : memref<4xi32> to memref<4xi32, #[[$DYN_1D_MAP]]>
+// CHECK: linalg.copy(%[[A]], %[[alloc]])
// CHECK: call @some_external_func_within_scf_execute(%[[B]]) : (memref<4xi32, #[[$DYN_1D_MAP]]>) -> ()
scf.execute_region {
call @some_external_func_within_scf_execute(%A) : (tensor<4xi32>) -> ()
@@ -488,16 +492,19 @@ func @bar(
%lb : index, %ub : index, %step : index)
-> (tensor<?xf32>, tensor<?xf32>)
{
-// CHECK-NEXT: call @scf_for_with_tensor_insert_slice(%[[A]], %[[B]], %[[C]]
+// CHECK: call @scf_for_with_tensor_insert_slice(%[[A]], %[[B]], %[[C]]
%r0:2 = call @scf_for_with_tensor_insert_slice(%A, %B, %C, %lb, %ub, %step) :
(tensor<?xf32>, tensor<?xf32>, tensor<4xf32>, index, index, index)
-> (tensor<?xf32>, tensor<?xf32>)
- // %r0#0 is actually %B after inplaceable results are swapped in the callee.
-// CHECK-NEXT: call @some_external_func(%[[B]]) : (memref<?xf32, #[[$DYN_1D_MAP]]>) -> ()
+ // %r0#0 requires a copy because we have no idea what the function is doing.
+// CHECK: %[[alloc:.*]] = memref.alloc
+// CHECK: %[[casted:.*]] = memref.cast %[[alloc]]
+// CHECK: linalg.copy(%[[B]], %[[alloc]])
+// CHECK-NEXT: call @some_external_func(%[[casted]]) : (memref<?xf32, #[[$DYN_1D_MAP]]>) -> ()
call @some_external_func(%r0#0) : (tensor<?xf32>) -> ()
-// CHECK-NEXT: return
+// CHECK: return
return %r0#0, %r0#1: tensor<?xf32>, tensor<?xf32>
}
@@ -745,8 +752,21 @@ func @callee(%A : tensor<?xf32> {linalg.buffer_layout = affine_map<(i)[s0, s1] -
func @entry(%A : tensor<?xf32> {linalg.buffer_layout = affine_map<(i)[s0, s1] -> (i)>, linalg.inplaceable = false},
%B : tensor<?xf32> {linalg.buffer_layout = affine_map<(i)[s0, s1] -> (i)>, linalg.inplaceable = false},
%C : tensor<?xf32> {linalg.inplaceable = false}) {
-// CHECK-NEXT: %[[CASTED_B:.*]] = memref.cast %[[B]] : memref<?xf32> to memref<?xf32, #[[$DYNAMIC]]>
-// CHECK-NEXT: call @callee(%[[A]], %[[CASTED_B]], %[[C]])
+// Note: `callee` does not write to its bbArg directly, but `external_func`
+// does. Inside `callee`, the writes via `external_func` do not cause a
+// conflict. However, inside `entry`, the writes do cause a conflict because
+// %A, %B and %C are not inplaceable. This test case shows that this kind of
+// conflict detection has a "transitive" nature.
+// CHECK: %[[ALLOC_C:.*]] = memref.alloc
+// CHECK: %[[CASTED_C:.*]] = memref.cast %[[ALLOC_C]]
+// CHECK: %[[ALLOC_B:.*]] = memref.alloc
+// CHECK: %[[CASTED_B:.*]] = memref.cast %[[ALLOC_B]]
+// CHECK: %[[ALLOC_A:.*]] = memref.alloc
+// CHECK: linalg.copy(%[[A]], %[[ALLOC_A]])
+// CHECK: linalg.copy(%[[B]], %[[ALLOC_B]])
+// CHECK: linalg.copy(%[[C]], %[[ALLOC_C]])
+// CHECK: %[[CASTED_A:.*]] = memref.cast %[[ALLOC_A]]
+// CHECK-NEXT: call @callee(%[[CASTED_A]], %[[CASTED_B]], %[[CASTED_C]])
call @callee(%A, %B, %C) : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> ()
return
}
@@ -992,9 +1012,10 @@ func @inner_func_2(%t: tensor<?xf32>) -> tensor<?xf32> {
func @equivalent_func_arg_2(%t0: tensor<?xf32> {linalg.inplaceable = true},
%c0: index, %c10: index, %c1: index) -> tensor<?xf32> {
%1 = scf.for %iv = %c0 to %c10 step %c1 iter_args(%t1 = %t0) -> (tensor<?xf32>) {
- // TODO: There should be a memory copy here. This is a bug in CallOp
- // bufferization.
- // CHECK: call @inner_func_2(%[[arg0]])
+ // CHECK: %[[alloc:.*]] = memref.alloc
+ // CHECK: %[[casted:.*]] = memref.cast %[[alloc]]
+ // CHECK: linalg.copy(%[[arg0]], %[[alloc]])
+ // CHECK: call @inner_func_2(%[[casted]])
%3 = call @inner_func_2(%t1) : (tensor<?xf32>) -> tensor<?xf32>
scf.yield %t1 : tensor<?xf32>
}