[Mlir-commits] [mlir] 3135548 - [mlir][linalg][bufferize][NFC] Split analysis-related code from BufferizationState/Options
Matthias Springer
llvmlistbot at llvm.org
Wed Jan 19 05:06:33 PST 2022
Author: Matthias Springer
Date: 2022-01-19T21:59:23+09:00
New Revision: 31355482e5189c489efa98607d86215480d17452
URL: https://github.com/llvm/llvm-project/commit/31355482e5189c489efa98607d86215480d17452
DIFF: https://github.com/llvm/llvm-project/commit/31355482e5189c489efa98607d86215480d17452.diff
LOG: [mlir][linalg][bufferize][NFC] Split analysis-related code from BufferizationState/Options
This separates the analysis (and its helpers/data structures) more clearly from the rest of the bufferization.
Differential Revision: https://reviews.llvm.org/D117477
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td
mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h
mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp
mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
index 8fb05125deeb5..86a58001ac735 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
@@ -68,14 +68,6 @@ struct BufferizationOptions {
// BufferizationOptions cannot be copied.
BufferizationOptions(const BufferizationOptions &other) = delete;
- /// Register a "post analysis" step. Such steps are executed after the
- /// analysis, but before bufferization.
- template <typename Step, typename... Args>
- void addPostAnalysisStep(Args... args) {
- postAnalysisSteps.emplace_back(
- std::make_unique<Step>(std::forward<Args>(args)...));
- }
-
/// Return `true` if the op is allowed to be bufferized.
bool isOpAllowed(Operation *op) const {
if (!dialectFilter.hasValue())
@@ -134,9 +126,6 @@ struct BufferizationOptions {
/// For debugging only. Should be used together with `testAnalysisOnly`.
bool printConflicts = false;
- /// Registered post analysis steps.
- PostAnalysisStepList postAnalysisSteps;
-
/// Only bufferize ops from dialects that are allowed-listed by the filter.
/// All other ops are ignored. This option controls the scope of partial
/// bufferization.
@@ -157,6 +146,25 @@ struct BufferizationOptions {
}
};
+/// Options for analysis-enabled bufferization.
+struct AnalysisBufferizationOptions : public BufferizationOptions {
+ AnalysisBufferizationOptions() = default;
+
+ // AnalysisBufferizationOptions cannot be copied.
+ AnalysisBufferizationOptions(const AnalysisBufferizationOptions &) = delete;
+
+ /// Register a "post analysis" step. Such steps are executed after the
+ /// analysis, but before bufferization. `args` are forwarded to Step's ctor.
+ template <typename Step, typename... Args>
+ void addPostAnalysisStep(Args &&...args) {
+ postAnalysisSteps.emplace_back(
+ std::make_unique<Step>(std::forward<Args>(args)...));
+ }
+
+ /// Registered post analysis steps.
+ PostAnalysisStepList postAnalysisSteps;
+};
+
/// Specify fine-grain relationship between buffers to enable more analysis.
enum class BufferRelation {
None,
@@ -198,11 +206,6 @@ class BufferizationAliasInfo {
return equivalentInfo.isEquivalent(v1, v2);
}
- /// Return true if `v1` and `v2` bufferize to aliasing buffers.
- bool areAliasingBufferizedValues(Value v1, Value v2) const {
- return aliasInfo.isEquivalent(v1, v2);
- }
-
/// Union the alias sets of `v1` and `v2`.
void unionAliasSets(Value v1, Value v2) { aliasInfo.unionSets(v1, v2); }
@@ -276,11 +279,6 @@ struct DialectBufferizationState {
/// tensor values and memref buffers.
class BufferizationState {
public:
- BufferizationState(Operation *op, const BufferizationOptions &options);
-
- // BufferizationState should be passed as a reference.
- BufferizationState(const BufferizationState &) = delete;
-
/// Determine which OpOperand* will alias with `result` if the op is
/// bufferized in place. Return an empty vector if the op is not bufferizable.
SmallVector<OpOperand *> getAliasingOpOperand(OpResult result) const;
@@ -344,7 +342,10 @@ class BufferizationState {
SetVector<Value> findLastPrecedingWrite(Value value) const;
/// Return `true` if the given OpResult has been decided to bufferize inplace.
- bool isInPlace(OpOperand &opOperand) const;
+ virtual bool isInPlace(OpOperand &opOperand) const = 0;
+
+ /// Return true if `v1` and `v2` bufferize to equivalent buffers.
+ virtual bool areEquivalentBufferizedValues(Value v1, Value v2) const = 0;
/// Return the buffer (memref) for a given OpOperand (tensor). Allocate
/// a new buffer and copy over data from the existing buffer if out-of-place
@@ -374,14 +375,15 @@ class BufferizationState {
/// Return a reference to the BufferizationOptions.
const BufferizationOptions &getOptions() const { return options; }
- /// Return a reference to the BufferizationAliasInfo.
- BufferizationAliasInfo &getAliasInfo() { return aliasInfo; }
+protected:
+ BufferizationState(const BufferizationOptions &options);
-private:
- /// `aliasInfo` keeps track of aliasing and equivalent values. Only internal
- /// functions and `runComprehensiveBufferize` may access this object.
- BufferizationAliasInfo aliasInfo;
+ // BufferizationState should be passed as a reference.
+ BufferizationState(const BufferizationState &) = delete;
+
+ ~BufferizationState() = default;
+private:
/// Dialect-specific bufferization state.
DenseMap<StringRef, std::unique_ptr<DialectBufferizationState>> dialectState;
@@ -389,6 +391,33 @@ class BufferizationState {
const BufferizationOptions &options;
};
+/// State for analysis-enabled bufferization. This class keeps track of aliases
+/// (via BufferizationAliasInfo) to decide if tensor OpOperands should bufferize
+/// in-place.
+class AnalysisBufferizationState : public BufferizationState {
+public:
+ AnalysisBufferizationState(Operation *op,
+ const AnalysisBufferizationOptions &options);
+
+ AnalysisBufferizationState(const AnalysisBufferizationState &) = delete;
+
+ virtual ~AnalysisBufferizationState() = default;
+
+ /// Return a reference to the BufferizationAliasInfo.
+ BufferizationAliasInfo &getAliasInfo() { return aliasInfo; }
+
+ /// Return `true` if the given OpOperand has been decided to bufferize inplace.
+ bool isInPlace(OpOperand &opOperand) const override;
+
+ /// Return true if `v1` and `v2` bufferize to equivalent buffers.
+ bool areEquivalentBufferizedValues(Value v1, Value v2) const override;
+
+private:
+ /// `aliasInfo` keeps track of aliasing and equivalent values. Only internal
+ /// functions and `runComprehensiveBufferize` may access this object.
+ BufferizationAliasInfo aliasInfo;
+};
+
/// Replace an op with replacement values. The op is deleted. Tensor OpResults
/// must be replaced with memref values.
void replaceOpWithBufferizedValues(RewriterBase &rewriter, Operation *op,
@@ -483,7 +512,6 @@ struct AllocationHoistingBarrierOnly
}
BufferRelation bufferRelation(Operation *op, OpResult opResult,
- const BufferizationAliasInfo &aliasInfo,
const BufferizationState &state) const {
return BufferRelation::None;
}
diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td
index 0f642b6eedeef..6569b0df68123 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td
@@ -183,7 +183,6 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
/*retType=*/"BufferRelation",
/*methodName=*/"bufferRelation",
/*args=*/(ins "OpResult":$opResult,
- "const BufferizationAliasInfo &":$aliasInfo,
"const BufferizationState &":$state),
/*methodBody=*/"",
/*defaultImplementation=*/[{
@@ -284,8 +283,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
/*methodName=*/"isNotConflicting",
/*args=*/(ins "OpOperand *":$uRead,
"OpOperand *":$uWrite,
- "const BufferizationState &":$state,
- "const BufferizationAliasInfo &":$aliasInfo),
+ "const BufferizationState &":$state),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return false;
diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h
index bd1cd17875c98..6a53295babcd8 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h
@@ -16,22 +16,22 @@ namespace mlir {
namespace linalg {
namespace comprehensive_bufferize {
+class AnalysisBufferizationState;
class BufferizationAliasInfo;
-struct BufferizationOptions;
+struct AnalysisBufferizationOptions;
class BufferizationState;
/// Analyze `op` and its nested ops. Bufferization decisions are stored in
/// `state`.
-LogicalResult analyzeOp(Operation *op, BufferizationState &state);
+LogicalResult analyzeOp(Operation *op, AnalysisBufferizationState &state);
/// Bufferize `op` and its nested ops. Bufferization decisions are stored in
/// `state`.
LogicalResult bufferizeOp(Operation *op, const BufferizationState &state);
/// Run Comprehensive Bufferize on the given op: Analysis + Bufferization
-LogicalResult
-runComprehensiveBufferize(Operation *op,
- std::unique_ptr<BufferizationOptions> options);
+LogicalResult runComprehensiveBufferize(
+ Operation *op, std::unique_ptr<AnalysisBufferizationOptions> options);
} // namespace comprehensive_bufferize
} // namespace linalg
diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h
index cde14939e3fe2..6b4039f5283ff 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h
@@ -20,14 +20,13 @@ class ModuleOp;
namespace linalg {
namespace comprehensive_bufferize {
-struct BufferizationOptions;
+struct AnalysisBufferizationOptions;
/// Run Module Bufferization on the given module. Performs a simple function
/// call analysis to determine which function arguments are inplaceable. Then
/// analyzes and bufferizes FuncOps one-by-one with Comprehensive Bufferization.
-LogicalResult
-runComprehensiveBufferize(ModuleOp moduleOp,
- std::unique_ptr<BufferizationOptions> options);
+LogicalResult runComprehensiveBufferize(
+ ModuleOp moduleOp, std::unique_ptr<AnalysisBufferizationOptions> options);
namespace std_ext {
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
index 885a70f56c64b..be9e919fbb628 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
@@ -288,8 +288,13 @@ llvm::SetVector<Value> mlir::linalg::comprehensive_bufferize::
}
mlir::linalg::comprehensive_bufferize::BufferizationState::BufferizationState(
- Operation *op, const BufferizationOptions &options)
- : aliasInfo(op), options(options) {
+ const BufferizationOptions &options)
+ : options(options) {}
+
+mlir::linalg::comprehensive_bufferize::AnalysisBufferizationState::
+ AnalysisBufferizationState(Operation *op,
+ const AnalysisBufferizationOptions &options)
+ : BufferizationState(options), aliasInfo(op) {
// Set up alias sets for OpResults that must bufferize in-place. This should
// be done before making any other bufferization decisions.
op->walk([&](BufferizableOpInterface bufferizableOp) {
@@ -353,7 +358,7 @@ mlir::linalg::comprehensive_bufferize::BufferizationState::getBuffer(
Value operand = opOperand.get();
Value operandBuffer = lookupBuffer(rewriter, operand);
- if (forceInPlace || aliasInfo.isInPlace(opOperand))
+ if (forceInPlace || isInPlace(opOperand))
return operandBuffer;
// Bufferizing out-of-place: Allocate a new buffer.
@@ -597,11 +602,16 @@ bool mlir::linalg::comprehensive_bufferize::isFunctionArgument(Value value) {
return isa<FuncOp>(bbArg.getOwner()->getParentOp());
}
-bool mlir::linalg::comprehensive_bufferize::BufferizationState::isInPlace(
- OpOperand &opOperand) const {
+bool mlir::linalg::comprehensive_bufferize::AnalysisBufferizationState::
+ isInPlace(OpOperand &opOperand) const {
return aliasInfo.isInPlace(opOperand);
}
+bool mlir::linalg::comprehensive_bufferize::AnalysisBufferizationState::
+ areEquivalentBufferizedValues(Value v1, Value v2) const {
+ return aliasInfo.areEquivalentBufferizedValues(v1, v2);
+}
+
MemRefType mlir::linalg::comprehensive_bufferize::getContiguousMemRefType(
ShapedType shapedType, MemRefLayoutAttrInterface layout,
Attribute memorySpace) {
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
index 7955f7b35b61b..67fd483641669 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
@@ -252,15 +252,13 @@ static bool hasReadAfterWriteInterference(
// No conflict if the op interface says so.
if (auto bufferizableOp = options.dynCastBufferizableOp(readingOp))
- if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite, state,
- aliasInfo))
+ if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite, state))
continue;
if (conflictingWritingOp != readingOp)
if (auto bufferizableOp =
options.dynCastBufferizableOp(conflictingWritingOp))
- if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite, state,
- aliasInfo))
+ if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite, state))
continue;
// Ops are not conflicting if they are in mutually exclusive regions.
@@ -496,7 +494,7 @@ static void equivalenceAnalysis(SmallVector<Operation *> &ops,
for (OpOperand *opOperand :
bufferizableOp.getAliasingOpOperand(opResult, state))
if (state.isInPlace(*opOperand))
- if (bufferizableOp.bufferRelation(opResult, aliasInfo, state) ==
+ if (bufferizableOp.bufferRelation(opResult, state) ==
BufferRelation::Equivalent)
aliasInfo.unionEquivalenceClasses(opResult, opOperand->get());
}
@@ -687,12 +685,12 @@ checkBufferizationResult(Operation *op, const BufferizationOptions &options) {
return success();
}
-LogicalResult
-mlir::linalg::comprehensive_bufferize::analyzeOp(Operation *op,
- BufferizationState &state) {
+LogicalResult mlir::linalg::comprehensive_bufferize::analyzeOp(
+ Operation *op, AnalysisBufferizationState &state) {
DominanceInfo domInfo(op);
BufferizationAliasInfo &aliasInfo = state.getAliasInfo();
- const BufferizationOptions &options = state.getOptions();
+ const auto &options =
+ static_cast<const AnalysisBufferizationOptions &>(state.getOptions());
if (failed(checkAliasInfoConsistency(op, domInfo, state, aliasInfo)))
return failure();
@@ -740,8 +738,8 @@ LogicalResult mlir::linalg::comprehensive_bufferize::bufferizeOp(
}
LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
- Operation *op, std::unique_ptr<BufferizationOptions> options) {
- BufferizationState state(op, *options);
+ Operation *op, std::unique_ptr<AnalysisBufferizationOptions> options) {
+ AnalysisBufferizationState state(op, *options);
if (failed(analyzeOp(op, state)))
return failure();
if (options->testAnalysisOnly)
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
index 6d153522811d7..6cce30d165a34 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
@@ -193,7 +193,6 @@ struct LinalgOpInterface
}
BufferRelation bufferRelation(Operation *op, OpResult opResult,
- const BufferizationAliasInfo &aliasInfo,
const BufferizationState &state) const {
return BufferRelation::Equivalent;
}
@@ -264,7 +263,6 @@ struct TiledLoopOpInterface
}
BufferRelation bufferRelation(Operation *op, OpResult opResult,
- const BufferizationAliasInfo &aliasInfo,
const BufferizationState &state) const {
return BufferRelation::Equivalent;
}
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
index ec2d6d758e0e0..fae2682c239fb 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
@@ -737,7 +737,6 @@ struct CallOpInterface
}
BufferRelation bufferRelation(Operation *op, OpResult opResult,
- const BufferizationAliasInfo &aliasInfo,
const BufferizationState &state) const {
return BufferRelation::Equivalent;
}
@@ -964,9 +963,9 @@ annotateOpsWithBufferizationMarkers(FuncOp funcOp,
}
LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
- ModuleOp moduleOp, std::unique_ptr<BufferizationOptions> options) {
+ ModuleOp moduleOp, std::unique_ptr<AnalysisBufferizationOptions> options) {
IRRewriter rewriter(moduleOp.getContext());
- BufferizationState state(moduleOp, *options);
+ AnalysisBufferizationState state(moduleOp, *options);
ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
BufferizationAliasInfo &aliasInfo = state.getAliasInfo();
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
index 6337cd023dd22..97cba18592c93 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
@@ -124,7 +124,6 @@ struct ExecuteRegionOpInterface
}
BufferRelation bufferRelation(Operation *op, OpResult opResult,
- const BufferizationAliasInfo &aliasInfo,
const BufferizationState &state) const {
return BufferRelation::Equivalent;
}
@@ -247,7 +246,6 @@ struct IfOpInterface
}
BufferRelation bufferRelation(Operation *op, OpResult opResult,
- const BufferizationAliasInfo &aliasInfo,
const BufferizationState &state) const {
// IfOp results are equivalent to their corresponding yield values if both
// yield values are equivalent to each other.
@@ -255,7 +253,7 @@ struct IfOpInterface
SmallVector<OpOperand *> yieldValues =
bufferizableOp.getAliasingOpOperand(opResult, state);
assert(yieldValues.size() == 2 && "expected 2 yield values");
- bool equivalentYields = aliasInfo.areEquivalentBufferizedValues(
+ bool equivalentYields = state.areEquivalentBufferizedValues(
yieldValues[0]->get(), yieldValues[1]->get());
return equivalentYields ? BufferRelation::Equivalent : BufferRelation::None;
}
@@ -291,7 +289,6 @@ struct ForOpInterface
}
BufferRelation bufferRelation(Operation *op, OpResult opResult,
- const BufferizationAliasInfo &aliasInfo,
const BufferizationState &state) const {
// ForOp results are equivalent to their corresponding init_args if the
// corresponding iter_args and yield values are equivalent.
@@ -299,7 +296,7 @@ struct ForOpInterface
OpOperand &forOperand = forOp.getOpOperandForResult(opResult);
auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
auto yieldOp = cast<scf::YieldOp>(&forOp.getLoopBody().front().back());
- bool equivalentYield = aliasInfo.areEquivalentBufferizedValues(
+ bool equivalentYield = state.areEquivalentBufferizedValues(
bbArg, yieldOp->getOperand(opResult.getResultNumber()));
return equivalentYield ? BufferRelation::Equivalent : BufferRelation::None;
}
@@ -408,7 +405,9 @@ mlir::linalg::comprehensive_bufferize::scf_ext::AssertScfForAliasingProperties::
OpOperand &forOperand = forOp.getOpOperandForResult(
forOp->getResult(operand.getOperandNumber()));
auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
- if (!aliasInfo.areAliasingBufferizedValues(operand.get(), bbArg)) {
+ // Note: This is overly strict. We should check for aliasing bufferized
+ // values. But we don't have a "must-alias" analysis yet.
+ if (!aliasInfo.areEquivalentBufferizedValues(operand.get(), bbArg)) {
// TODO: this could get resolved with copies but it can also turn into
// swaps so we need to be careful about order of copies.
status =
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.cpp
index 1f8cee5cf7adb..5e603b3765763 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.cpp
@@ -62,7 +62,6 @@ struct SelectOpInterface
}
BufferRelation bufferRelation(Operation *op, OpResult opResult,
- const BufferizationAliasInfo &aliasInfo,
const BufferizationState &state) const {
return BufferRelation::None;
}
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
index a31832452ea84..dc0742b99022d 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
@@ -42,7 +42,6 @@ struct CastOpInterface
}
BufferRelation bufferRelation(Operation *op, OpResult opResult,
- const BufferizationAliasInfo &aliasInfo,
const BufferizationState &state) const {
return BufferRelation::Equivalent;
}
@@ -137,7 +136,6 @@ struct ExtractSliceOpInterface
}
BufferRelation bufferRelation(Operation *op, OpResult opResult,
- const BufferizationAliasInfo &aliasInfo,
const BufferizationState &state) const {
return BufferRelation::None;
}
@@ -273,7 +271,6 @@ struct InsertOpInterface
}
BufferRelation bufferRelation(Operation *op, OpResult opResult,
- const BufferizationAliasInfo &aliasInfo,
const BufferizationState &state) const {
return BufferRelation::Equivalent;
}
@@ -285,12 +282,12 @@ struct InsertOpInterface
/// This is one particular type of relationship between ops on tensors that
/// reduce to an equivalence on buffers. This should be generalized and
/// exposed as interfaces on the proper types.
-static bool
-areEquivalentExtractSliceOps(const BufferizationAliasInfo &aliasInfo,
- ExtractSliceOp st, InsertSliceOp sti) {
+static bool areEquivalentExtractSliceOps(const BufferizationState &state,
+ ExtractSliceOp st, InsertSliceOp sti) {
if (!st || !sti)
return false;
- if (!aliasInfo.areEquivalentBufferizedValues(st.source(), sti.dest()))
+ if (!state.areEquivalentBufferizedValues(st.source(),
+ sti.dest()))
return false;
if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
return false;
@@ -299,12 +296,11 @@ areEquivalentExtractSliceOps(const BufferizationAliasInfo &aliasInfo,
/// Return true if `value` is originating from an ExtractSliceOp that matches
/// the given InsertSliceOp.
-static bool hasMatchingExtractSliceOp(const BufferizationAliasInfo &aliasInfo,
- const BufferizationState &state,
+static bool hasMatchingExtractSliceOp(const BufferizationState &state,
Value value, InsertSliceOp insertOp) {
auto condition = [&](Value val) {
if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
- if (areEquivalentExtractSliceOps(aliasInfo, extractOp, insertOp))
+ if (areEquivalentExtractSliceOps(state, extractOp, insertOp))
return true;
return false;
};
@@ -336,15 +332,13 @@ struct InsertSliceOpInterface
}
BufferRelation bufferRelation(Operation *op, OpResult opResult,
- const BufferizationAliasInfo &aliasInfo,
const BufferizationState &state) const {
return BufferRelation::Equivalent;
}
bool isNotConflicting(Operation *op, OpOperand *uRead,
OpOperand *uConflictingWrite,
- const BufferizationState &state,
- const BufferizationAliasInfo &aliasInfo) const {
+ const BufferizationState &state) const {
Operation *readingOp = uRead->getOwner();
Operation *conflictingWritingOp = uConflictingWrite->getOwner();
@@ -360,7 +354,7 @@ struct InsertSliceOpInterface
// TODO: Use insertSliceOp.getDestOpOperand etc. when available.
if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
- hasMatchingExtractSliceOp(aliasInfo, state, uConflictingWrite->get(),
+ hasMatchingExtractSliceOp(state, uConflictingWrite->get(),
insertSliceOp))
// Case 1: The main insight is that InsertSliceOp reads only part of
// the destination tensor. The overwritten area is not read. If
@@ -378,8 +372,7 @@ struct InsertSliceOpInterface
if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
- hasMatchingExtractSliceOp(aliasInfo, state, uRead->get(),
- insertSliceOp))
+ hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp))
// Case 2: The read of the source tensor and the write to the dest
// tensor via an InsertSliceOp is not a conflict if the read is
// reading exactly that part of an equivalent tensor that the
@@ -410,9 +403,9 @@ struct InsertSliceOpInterface
// memory segment of %1 with the exact same data. (Effectively, there
// is no memory write here.)
if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
- aliasInfo.areEquivalentBufferizedValues(uRead->get(),
- insertSliceOp.source()) &&
- hasMatchingExtractSliceOp(aliasInfo, state, insertSliceOp.source(),
+ state.areEquivalentBufferizedValues(uRead->get(),
+ insertSliceOp.source()) &&
+ hasMatchingExtractSliceOp(state, insertSliceOp.source(),
insertSliceOp))
return true;
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp
index 0b3d8ff6d2663..4dd7fdcebc034 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp
@@ -85,7 +85,6 @@ struct TransferWriteOpInterface
}
BufferRelation bufferRelation(Operation *op, OpResult opResult,
- const BufferizationAliasInfo &aliasInfo,
const BufferizationState &state) const {
return BufferRelation::Equivalent;
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
index 36da9e52ea3c7..dfc25bae2a475 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
@@ -75,7 +75,7 @@ static FailureOr<Value> allocationFnUsingAlloca(OpBuilder &b, Location loc,
}
void LinalgComprehensiveModuleBufferize::runOnOperation() {
- auto options = std::make_unique<BufferizationOptions>();
+ auto options = std::make_unique<AnalysisBufferizationOptions>();
if (useAlloca) {
options->allocationFn = allocationFnUsingAlloca;
options->deallocationFn = [](OpBuilder &b, Location loc, Value v) {
diff --git a/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp b/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp
index d0852e9c65189..87677ef2383de 100644
--- a/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp
@@ -101,7 +101,7 @@ struct TestComprehensiveFunctionBufferize
} // namespace
void TestComprehensiveFunctionBufferize::runOnOperation() {
- auto options = std::make_unique<BufferizationOptions>();
+ auto options = std::make_unique<AnalysisBufferizationOptions>();
if (!allowReturnMemref)
options->addPostAnalysisStep<scf_ext::AssertScfForAliasingProperties>();
More information about the Mlir-commits
mailing list