[Mlir-commits] [mlir] cdb7675 - [mlir][bufferize][NFC] Make PostAnalysisSteps a function
Matthias Springer
llvmlistbot at llvm.org
Wed Feb 9 02:03:19 PST 2022
Author: Matthias Springer
Date: 2022-02-09T18:56:06+09:00
New Revision: cdb7675c2649ee91c4e97d84daa76c98cb93b9c4
URL: https://github.com/llvm/llvm-project/commit/cdb7675c2649ee91c4e97d84daa76c98cb93b9c4
DIFF: https://github.com/llvm/llvm-project/commit/cdb7675c2649ee91c4e97d84daa76c98cb93b9c4.diff
LOG: [mlir][bufferize][NFC] Make PostAnalysisSteps a function
They used to be classes with a virtual `run` function. This was inconvenient because post analysis steps are stored in BufferizationOptions. Because of this design choice, BufferizationOptions were not copyable.
Differential Revision: https://reviews.llvm.org/D119258
Added:
Modified:
mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h
mlir/include/mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h
mlir/lib/Dialect/Arithmetic/Transforms/Bufferize.cpp
mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp
mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index 3176d6fa337bc..609a1bb520c9d 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -47,9 +47,6 @@ struct BufferizationOptions {
BufferizationOptions();
- // BufferizationOptions cannot be copied.
- BufferizationOptions(const BufferizationOptions &other) = delete;
-
/// Return `true` if the op is allowed to be bufferized.
bool isOpAllowed(Operation *op) const {
if (!hasFilter)
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
index a56287995aa96..e2ede4b63d2f3 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
@@ -82,7 +82,7 @@ LogicalResult bufferizeOp(Operation *op, const BufferizationOptions &options);
void populateBufferizationPattern(const BufferizationState &state,
RewritePatternSet &patterns);
-std::unique_ptr<BufferizationOptions> getPartialBufferizationOptions();
+BufferizationOptions getPartialBufferizationOptions();
} // namespace bufferization
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
index 93b8be9c7c7aa..8e9e09663ec13 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
@@ -20,35 +20,25 @@ class AnalysisBufferizationState;
class BufferizationAliasInfo;
struct AnalysisBufferizationOptions;
-/// PostAnalysisSteps can be registered with `BufferizationOptions` and are
+/// PostAnalysisStepFns can be registered with `BufferizationOptions` and are
/// executed after the analysis, but before bufferization. They can be used to
-/// implement custom dialect-specific optimizations.
-struct PostAnalysisStep {
- virtual ~PostAnalysisStep() = default;
-
- /// Run the post analysis step. This function may modify the IR, but must keep
- /// `aliasInfo` consistent. Newly created operations and operations that
- /// should be re-analyzed must be added to `newOps`.
- virtual LogicalResult run(Operation *op, BufferizationState &state,
- BufferizationAliasInfo &aliasInfo,
- SmallVector<Operation *> &newOps) = 0;
-};
+/// implement custom dialect-specific optimizations. They may modify the IR, but
+/// must keep `aliasInfo` consistent. Newly created operations and operations
+/// that should be re-analyzed must be added to `newOps`.
+using PostAnalysisStepFn = std::function<LogicalResult(
+ Operation *, BufferizationState &, BufferizationAliasInfo &,
+ SmallVector<Operation *> &)>;
-using PostAnalysisStepList = std::vector<std::unique_ptr<PostAnalysisStep>>;
+using PostAnalysisStepList = SmallVector<PostAnalysisStepFn>;
/// Options for analysis-enabled bufferization.
struct AnalysisBufferizationOptions : public BufferizationOptions {
AnalysisBufferizationOptions() = default;
- // AnalysisBufferizationOptions cannot be copied.
- AnalysisBufferizationOptions(const AnalysisBufferizationOptions &) = delete;
-
/// Register a "post analysis" step. Such steps are executed after the
/// analysis, but before bufferization.
- template <typename Step, typename... Args>
- void addPostAnalysisStep(Args... args) {
- postAnalysisSteps.emplace_back(
- std::make_unique<Step>(std::forward<Args>(args)...));
+ void addPostAnalysisStep(PostAnalysisStepFn fn) {
+ postAnalysisSteps.push_back(fn);
}
/// Registered post analysis steps.
diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h
index 06145d028d4d1..010fd565faa92 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h
@@ -18,42 +18,38 @@ namespace linalg {
namespace comprehensive_bufferize {
namespace linalg_ext {
-struct InitTensorEliminationStep : public bufferization::PostAnalysisStep {
- /// A function that matches anchor OpOperands for InitTensorOp elimination.
- /// If an OpOperand is matched, the function should populate the SmallVector
- /// with all values that are needed during `RewriteFn` to produce the
- /// replacement value.
- using AnchorMatchFn = std::function<bool(OpOperand &, SmallVector<Value> &)>;
-
- /// A function that rewrites matched anchors.
- using RewriteFn = std::function<Value(OpBuilder &, Location, OpOperand &)>;
-
- /// Try to eliminate InitTensorOps inside `op`.
- ///
- /// * `rewriteFunc` generates the replacement for the InitTensorOp.
- /// * Only InitTensorOps that are anchored on a matching OpOperand as per
- /// `anchorMatchFunc` are considered. "Anchored" means that there is a path
- /// on the reverse SSA use-def chain, starting from the OpOperand and always
- /// following the aliasing OpOperand, that eventually ends at a single
- /// InitTensorOp.
- /// * The result of `rewriteFunc` must usually be analyzed for inplacability.
- /// This analysis can be skipped with `skipAnalysis`.
- LogicalResult
- eliminateInitTensors(Operation *op, bufferization::BufferizationState &state,
- bufferization::BufferizationAliasInfo &aliasInfo,
- AnchorMatchFn anchorMatchFunc, RewriteFn rewriteFunc,
- SmallVector<Operation *> &newOps);
-};
+/// A function that matches anchor OpOperands for InitTensorOp elimination.
+/// If an OpOperand is matched, the function should populate the SmallVector
+/// with all values that are needed during `RewriteFn` to produce the
+/// replacement value.
+using AnchorMatchFn = std::function<bool(OpOperand &, SmallVector<Value> &)>;
+
+/// A function that rewrites matched anchors.
+using RewriteFn = std::function<Value(OpBuilder &, Location, OpOperand &)>;
+
+/// Try to eliminate InitTensorOps inside `op`.
+///
+/// * `rewriteFunc` generates the replacement for the InitTensorOp.
+/// * Only InitTensorOps that are anchored on a matching OpOperand as per
+/// `anchorMatchFunc` are considered. "Anchored" means that there is a path
+/// on the reverse SSA use-def chain, starting from the OpOperand and always
+/// following the aliasing OpOperand, that eventually ends at a single
+/// InitTensorOp.
+/// * The result of `rewriteFunc` must usually be analyzed for inplacability.
+/// This analysis can be skipped with `skipAnalysis`.
+LogicalResult
+eliminateInitTensors(Operation *op, bufferization::BufferizationState &state,
+ bufferization::BufferizationAliasInfo &aliasInfo,
+ AnchorMatchFn anchorMatchFunc, RewriteFn rewriteFunc,
+ SmallVector<Operation *> &newOps);
/// Try to eliminate InitTensorOps inside `op` that are anchored on an
/// InsertSliceOp, i.e., if it is eventually inserted into another tensor
/// (and some other conditions are met).
-struct InsertSliceAnchoredInitTensorEliminationStep
- : public InitTensorEliminationStep {
- LogicalResult run(Operation *op, bufferization::BufferizationState &state,
- bufferization::BufferizationAliasInfo &aliasInfo,
- SmallVector<Operation *> &newOps) override;
-};
+LogicalResult insertSliceAnchoredInitTensorEliminationStep(
+ Operation *op, bufferization::BufferizationState &state,
+ bufferization::BufferizationAliasInfo &aliasInfo,
+ SmallVector<Operation *> &newOps);
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);
diff --git a/mlir/include/mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h b/mlir/include/mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h
index ea8969e004686..dfeb9514409fb 100644
--- a/mlir/include/mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h
+++ b/mlir/include/mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h
@@ -14,16 +14,21 @@
namespace mlir {
class DialectRegistry;
+namespace bufferization {
+class BufferizationState;
+class BufferizationAliasInfo;
+} // namespace bufferization
+
namespace scf {
/// Assert that yielded values of an scf.for op are aliasing their corresponding
/// bbArgs. This is required because the i-th OpResult of an scf.for op is
/// currently assumed to alias with the i-th iter_arg (in the absence of
/// conflicts).
-struct AssertScfForAliasingProperties : public bufferization::PostAnalysisStep {
- LogicalResult run(Operation *op, bufferization::BufferizationState &state,
- bufferization::BufferizationAliasInfo &aliasInfo,
- SmallVector<Operation *> &newOps) override;
-};
+LogicalResult
+assertScfForAliasingProperties(Operation *op,
+ bufferization::BufferizationState &state,
+ bufferization::BufferizationAliasInfo &aliasInfo,
+ SmallVector<Operation *> &newOps);
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);
} // namespace scf
diff --git a/mlir/lib/Dialect/Arithmetic/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/Bufferize.cpp
index 1fafd255d60d3..d88ee4a65e26c 100644
--- a/mlir/lib/Dialect/Arithmetic/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Arithmetic/Transforms/Bufferize.cpp
@@ -29,16 +29,15 @@ struct ArithmeticBufferizePass
}
void runOnOperation() override {
- std::unique_ptr<BufferizationOptions> options =
- getPartialBufferizationOptions();
+ BufferizationOptions options = getPartialBufferizationOptions();
if (constantOpOnly) {
- options->addToOperationFilter<arith::ConstantOp>();
+ options.addToOperationFilter<arith::ConstantOp>();
} else {
- options->addToDialectFilter<arith::ArithmeticDialect>();
+ options.addToDialectFilter<arith::ArithmeticDialect>();
}
- options->bufferAlignment = alignment;
+ options.bufferAlignment = alignment;
- if (failed(bufferizeOp(getOperation(), *options)))
+ if (failed(bufferizeOp(getOperation(), options)))
signalPassFailure();
}
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index 2b2d7cceeabb4..c7468da6132f5 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -253,12 +253,11 @@ void bufferization::populateBufferizationPattern(
patterns.add<BufferizationPattern>(patterns.getContext(), state);
}
-std::unique_ptr<BufferizationOptions>
-bufferization::getPartialBufferizationOptions() {
- auto options = std::make_unique<BufferizationOptions>();
- options->allowReturnMemref = true;
- options->allowUnknownOps = true;
- options->createDeallocs = false;
- options->fullyDynamicLayoutMaps = false;
+BufferizationOptions bufferization::getPartialBufferizationOptions() {
+ BufferizationOptions options;
+ options.allowReturnMemref = true;
+ options.allowUnknownOps = true;
+ options.createDeallocs = false;
+ options.fullyDynamicLayoutMaps = false;
return options;
}
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
index c21f7f9704b56..d03a287af5a0e 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
@@ -698,52 +698,51 @@ annotateOpsWithBufferizationMarkers(Operation *op,
// aliasing values, which is stricter than needed. We can currently not check
// for aliasing values because the analysis is a maybe-alias analysis and we
// need a must-alias analysis here.
-struct AssertDestinationPassingStyle : public PostAnalysisStep {
- LogicalResult run(Operation *op, BufferizationState &state,
- BufferizationAliasInfo &aliasInfo,
- SmallVector<Operation *> &newOps) override {
- LogicalResult status = success();
- DominanceInfo domInfo(op);
- op->walk([&](Operation *returnOp) {
- if (!isRegionReturnLike(returnOp))
- return WalkResult::advance();
-
- for (OpOperand &returnValOperand : returnOp->getOpOperands()) {
- Value returnVal = returnValOperand.get();
- // Skip non-tensor values.
- if (!returnVal.getType().isa<TensorType>())
- continue;
+static LogicalResult
+assertDestinationPassingStyle(Operation *op, BufferizationState &state,
+ BufferizationAliasInfo &aliasInfo,
+ SmallVector<Operation *> &newOps) {
+ LogicalResult status = success();
+ DominanceInfo domInfo(op);
+ op->walk([&](Operation *returnOp) {
+ if (!isRegionReturnLike(returnOp))
+ return WalkResult::advance();
- bool foundEquivValue = false;
- aliasInfo.applyOnEquivalenceClass(returnVal, [&](Value equivVal) {
- if (auto bbArg = equivVal.dyn_cast<BlockArgument>()) {
- Operation *definingOp = bbArg.getOwner()->getParentOp();
- if (definingOp->isProperAncestor(returnOp))
- foundEquivValue = true;
- return;
- }
+ for (OpOperand &returnValOperand : returnOp->getOpOperands()) {
+ Value returnVal = returnValOperand.get();
+ // Skip non-tensor values.
+ if (!returnVal.getType().isa<TensorType>())
+ continue;
- Operation *definingOp = equivVal.getDefiningOp();
- if (definingOp->getBlock()->findAncestorOpInBlock(
- *returnOp->getParentOp()))
- // Skip ops that happen after `returnOp` and parent ops.
- if (happensBefore(definingOp, returnOp, domInfo))
- foundEquivValue = true;
- });
-
- if (!foundEquivValue)
- status =
- returnOp->emitError()
- << "operand #" << returnValOperand.getOperandNumber()
- << " of ReturnLike op does not satisfy destination passing style";
- }
+ bool foundEquivValue = false;
+ aliasInfo.applyOnEquivalenceClass(returnVal, [&](Value equivVal) {
+ if (auto bbArg = equivVal.dyn_cast<BlockArgument>()) {
+ Operation *definingOp = bbArg.getOwner()->getParentOp();
+ if (definingOp->isProperAncestor(returnOp))
+ foundEquivValue = true;
+ return;
+ }
- return WalkResult::advance();
- });
+ Operation *definingOp = equivVal.getDefiningOp();
+ if (definingOp->getBlock()->findAncestorOpInBlock(
+ *returnOp->getParentOp()))
+ // Skip ops that happen after `returnOp` and parent ops.
+ if (happensBefore(definingOp, returnOp, domInfo))
+ foundEquivValue = true;
+ });
+
+ if (!foundEquivValue)
+ status =
+ returnOp->emitError()
+ << "operand #" << returnValOperand.getOperandNumber()
+ << " of ReturnLike op does not satisfy destination passing style";
+ }
- return status;
- }
-};
+ return WalkResult::advance();
+ });
+
+ return status;
+}
LogicalResult bufferization::analyzeOp(Operation *op,
AnalysisBufferizationState &state) {
@@ -761,12 +760,11 @@ LogicalResult bufferization::analyzeOp(Operation *op,
return failure();
equivalenceAnalysis(op, aliasInfo, state);
- for (const std::unique_ptr<PostAnalysisStep> &step :
- options.postAnalysisSteps) {
+ for (const PostAnalysisStepFn &fn : options.postAnalysisSteps) {
SmallVector<Operation *> newOps;
- if (failed(step->run(op, state, aliasInfo, newOps)))
+ if (failed(fn(op, state, aliasInfo, newOps)))
return failure();
- // Analyze ops that were created by the PostAnalysisStep.
+ // Analyze ops that were created by the PostAnalysisStepFn.
if (failed(inPlaceAnalysis(newOps, aliasInfo, state, domInfo)))
return failure();
equivalenceAnalysis(newOps, aliasInfo, state);
@@ -774,8 +772,7 @@ LogicalResult bufferization::analyzeOp(Operation *op,
if (!options.allowReturnMemref) {
SmallVector<Operation *> newOps;
- if (failed(
- AssertDestinationPassingStyle().run(op, state, aliasInfo, newOps)))
+ if (failed(assertDestinationPassingStyle(op, state, aliasInfo, newOps)))
return failure();
}
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
index 493044aa53aaa..984b42b59c7f9 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
@@ -524,11 +524,10 @@ findValidInsertionPoint(Operation *initTensorOp,
/// chain, starting from the OpOperand and always following the aliasing
/// OpOperand, that eventually ends at a single InitTensorOp.
LogicalResult
-mlir::linalg::comprehensive_bufferize::linalg_ext::InitTensorEliminationStep::
- eliminateInitTensors(Operation *op, BufferizationState &state,
- BufferizationAliasInfo &aliasInfo,
- AnchorMatchFn anchorMatchFunc, RewriteFn rewriteFunc,
- SmallVector<Operation *> &newOps) {
+mlir::linalg::comprehensive_bufferize::linalg_ext::eliminateInitTensors(
+ Operation *op, BufferizationState &state, BufferizationAliasInfo &aliasInfo,
+ AnchorMatchFn anchorMatchFunc, RewriteFn rewriteFunc,
+ SmallVector<Operation *> &newOps) {
OpBuilder b(op->getContext());
WalkResult status = op->walk([&](Operation *op) {
@@ -628,7 +627,7 @@ mlir::linalg::comprehensive_bufferize::linalg_ext::InitTensorEliminationStep::
/// Note that the newly inserted ExtractSliceOp may have to bufferize
/// out-of-place due to RaW conflicts.
LogicalResult mlir::linalg::comprehensive_bufferize::linalg_ext::
- InsertSliceAnchoredInitTensorEliminationStep::run(
+ insertSliceAnchoredInitTensorEliminationStep(
Operation *op, BufferizationState &state,
BufferizationAliasInfo &aliasInfo, SmallVector<Operation *> &newOps) {
return eliminateInitTensors(
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
index 63ffb09320076..6f04a2fd40c27 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
@@ -16,11 +16,12 @@
// their respective callers.
//
// After analyzing a FuncOp, additional information about its bbArgs is
-// gathered through PostAnalysisSteps and stored in `ModuleBufferizationState`.
+// gathered through PostAnalysisStepFns and stored in
+// `ModuleBufferizationState`.
//
-// * `EquivalentFuncOpBBArgsAnalysis` determines the equivalent bbArg for each
+// * `equivalentFuncOpBBArgsAnalysis` determines the equivalent bbArg for each
// tensor return value (if any).
-// * `FuncOpBbArgReadWriteAnalysis` determines whether or not a tensor bbArg is
+// * `funcOpBbArgReadWriteAnalysis` determines whether or not a tensor bbArg is
// read/written.
//
// Only tensors that are equivalent to some FuncOp bbArg may be returned.
@@ -47,7 +48,7 @@
// modify the buffer of `%t1` directly. The CallOp in `caller` must bufferize
// out-of-place because `%t0` is modified by the callee but read by the
// tensor.extract op. The analysis of CallOps decides whether an OpOperand must
-// bufferize out-of-place based on results of `FuncOpBbArgReadWriteAnalysis`.
+// bufferize out-of-place based on results of `funcOpBbArgReadWriteAnalysis`.
// ```
// func @callee(%t1 : tensor<?xf32>) -> tensor<?xf32> {
// %f = ... : f32
@@ -62,7 +63,7 @@
// }
// ```
//
-// Note: If a function is external, `FuncOpBbArgReadWriteAnalysis` cannot
+// Note: If a function is external, `funcOpBbArgReadWriteAnalysis` cannot
// analyze the function body. In such a case, the CallOp analysis conservatively
// assumes that each tensor OpOperand is both read and written.
//
@@ -159,55 +160,55 @@ static ReturnOp getAssumedUniqueReturnOp(FuncOp funcOp) {
}
namespace {
-/// Store function BlockArguments that are equivalent to a returned value in
-/// ModuleBufferizationState.
-struct EquivalentFuncOpBBArgsAnalysis : public PostAnalysisStep {
- /// Annotate IR with the results of the analysis. For testing purposes only.
- static void annotateReturnOp(OpOperand &returnVal, BlockArgument bbArg) {
- const char *kEquivalentArgsAttr = "__equivalent_func_args__";
- Operation *op = returnVal.getOwner();
-
- SmallVector<int64_t> equivBbArgs;
- if (op->hasAttr(kEquivalentArgsAttr)) {
- auto attr = op->getAttr(kEquivalentArgsAttr).cast<ArrayAttr>();
- equivBbArgs = llvm::to_vector<4>(llvm::map_range(attr, [](Attribute a) {
- return a.cast<IntegerAttr>().getValue().getSExtValue();
- }));
- } else {
- equivBbArgs.append(op->getNumOperands(), -1);
- }
- equivBbArgs[returnVal.getOperandNumber()] = bbArg.getArgNumber();
- OpBuilder b(op->getContext());
- op->setAttr(kEquivalentArgsAttr, b.getI64ArrayAttr(equivBbArgs));
+/// Annotate IR with the results of the analysis. For testing purposes only.
+static void annotateEquivalentReturnBbArg(OpOperand &returnVal,
+ BlockArgument bbArg) {
+ const char *kEquivalentArgsAttr = "__equivalent_func_args__";
+ Operation *op = returnVal.getOwner();
+
+ SmallVector<int64_t> equivBbArgs;
+ if (op->hasAttr(kEquivalentArgsAttr)) {
+ auto attr = op->getAttr(kEquivalentArgsAttr).cast<ArrayAttr>();
+ equivBbArgs = llvm::to_vector<4>(llvm::map_range(attr, [](Attribute a) {
+ return a.cast<IntegerAttr>().getValue().getSExtValue();
+ }));
+ } else {
+ equivBbArgs.append(op->getNumOperands(), -1);
}
+ equivBbArgs[returnVal.getOperandNumber()] = bbArg.getArgNumber();
- LogicalResult run(Operation *op, BufferizationState &state,
- BufferizationAliasInfo &aliasInfo,
- SmallVector<Operation *> &newOps) override {
- ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
+ OpBuilder b(op->getContext());
+ op->setAttr(kEquivalentArgsAttr, b.getI64ArrayAttr(equivBbArgs));
+}
- // Support only single return-terminated block in the function.
- auto funcOp = cast<FuncOp>(op);
- ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
- assert(returnOp && "expected func with single return op");
-
- for (OpOperand &returnVal : returnOp->getOpOperands())
- if (returnVal.get().getType().isa<RankedTensorType>())
- for (BlockArgument bbArg : funcOp.getArguments())
- if (bbArg.getType().isa<RankedTensorType>())
- if (aliasInfo.areEquivalentBufferizedValues(returnVal.get(),
- bbArg)) {
- moduleState
- .equivalentFuncArgs[funcOp][returnVal.getOperandNumber()] =
- bbArg.getArgNumber();
- if (state.getOptions().testAnalysisOnly)
- annotateReturnOp(returnVal, bbArg);
- }
+/// Store function BlockArguments that are equivalent to a returned value in
+/// ModuleBufferizationState.
+static LogicalResult
+equivalentFuncOpBBArgsAnalysis(Operation *op, BufferizationState &state,
+ BufferizationAliasInfo &aliasInfo,
+ SmallVector<Operation *> &newOps) {
+ ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
- return success();
- }
-};
+ // Support only single return-terminated block in the function.
+ auto funcOp = cast<FuncOp>(op);
+ ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
+ assert(returnOp && "expected func with single return op");
+
+ for (OpOperand &returnVal : returnOp->getOpOperands())
+ if (returnVal.get().getType().isa<RankedTensorType>())
+ for (BlockArgument bbArg : funcOp.getArguments())
+ if (bbArg.getType().isa<RankedTensorType>())
+ if (aliasInfo.areEquivalentBufferizedValues(returnVal.get(), bbArg)) {
+ moduleState
+ .equivalentFuncArgs[funcOp][returnVal.getOperandNumber()] =
+ bbArg.getArgNumber();
+ if (state.getOptions().testAnalysisOnly)
+ annotateEquivalentReturnBbArg(returnVal, bbArg);
+ }
+
+ return success();
+}
/// Return true if the buffer of the given tensor value is written to. Must not
/// be called for values inside not yet analyzed functions. (Post-analysis
@@ -239,38 +240,37 @@ static bool isValueWritten(Value value, const BufferizationState &state,
}
/// Determine which FuncOp bbArgs are read and which are written. If this
-/// PostAnalysisStep is run on a function with unknown ops, it will
+/// PostAnalysisStepFn is run on a function with unknown ops, it will
/// conservatively assume that such ops bufferize to a read + write.
-struct FuncOpBbArgReadWriteAnalysis : public PostAnalysisStep {
- LogicalResult run(Operation *op, BufferizationState &state,
- BufferizationAliasInfo &aliasInfo,
- SmallVector<Operation *> &newOps) override {
- ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
- auto funcOp = cast<FuncOp>(op);
-
- // If the function has no body, conservatively assume that all args are
- // read + written.
- if (funcOp.getBody().empty()) {
- for (BlockArgument bbArg : funcOp.getArguments()) {
- moduleState.readBbArgs.insert(bbArg);
- moduleState.writtenBbArgs.insert(bbArg);
- }
-
- return success();
- }
+static LogicalResult
+funcOpBbArgReadWriteAnalysis(Operation *op, BufferizationState &state,
+ BufferizationAliasInfo &aliasInfo,
+ SmallVector<Operation *> &newOps) {
+ ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
+ auto funcOp = cast<FuncOp>(op);
+ // If the function has no body, conservatively assume that all args are
+ // read + written.
+ if (funcOp.getBody().empty()) {
for (BlockArgument bbArg : funcOp.getArguments()) {
- if (!bbArg.getType().isa<TensorType>())
- continue;
- if (state.isValueRead(bbArg))
- moduleState.readBbArgs.insert(bbArg);
- if (isValueWritten(bbArg, state, aliasInfo))
- moduleState.writtenBbArgs.insert(bbArg);
+ moduleState.readBbArgs.insert(bbArg);
+ moduleState.writtenBbArgs.insert(bbArg);
}
return success();
}
-};
+
+ for (BlockArgument bbArg : funcOp.getArguments()) {
+ if (!bbArg.getType().isa<TensorType>())
+ continue;
+ if (state.isValueRead(bbArg))
+ moduleState.readBbArgs.insert(bbArg);
+ if (isValueWritten(bbArg, state, aliasInfo))
+ moduleState.writtenBbArgs.insert(bbArg);
+ }
+
+ return success();
+}
} // namespace
static bool isaTensor(Type t) { return t.isa<TensorType>(); }
@@ -983,10 +983,8 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
return failure();
// Collect bbArg/return value information after the analysis.
- options->postAnalysisSteps.emplace_back(
- std::make_unique<EquivalentFuncOpBBArgsAnalysis>());
- options->postAnalysisSteps.emplace_back(
- std::make_unique<FuncOpBbArgReadWriteAnalysis>());
+ options->postAnalysisSteps.push_back(equivalentFuncOpBBArgsAnalysis);
+ options->postAnalysisSteps.push_back(funcOpBbArgReadWriteAnalysis);
// Analyze ops.
for (FuncOp funcOp : moduleState.orderedFuncOps) {
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
index 7597ff8be1ff5..708db1e089072 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
@@ -125,12 +125,12 @@ void LinalgComprehensiveModuleBufferize::runOnOperation() {
// Enable InitTensorOp elimination.
if (initTensorElimination) {
- options->addPostAnalysisStep<
- linalg_ext::InsertSliceAnchoredInitTensorEliminationStep>();
+ options->addPostAnalysisStep(
+ linalg_ext::insertSliceAnchoredInitTensorEliminationStep);
}
// Only certain scf.for ops are supported by the analysis.
- options->addPostAnalysisStep<scf::AssertScfForAliasingProperties>();
+ options->addPostAnalysisStep(scf::assertScfForAliasingProperties);
ModuleOp moduleOp = getOperation();
applyEnablingTransformations(moduleOp);
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index f26f7b9ec890e..cc4147fdc2691 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -432,7 +432,7 @@ struct YieldOpInterface
} // namespace scf
} // namespace mlir
-LogicalResult mlir::scf::AssertScfForAliasingProperties::run(
+LogicalResult mlir::scf::assertScfForAliasingProperties(
Operation *op, BufferizationState &state, BufferizationAliasInfo &aliasInfo,
SmallVector<Operation *> &newOps) {
LogicalResult status = success();
diff --git a/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp
index 1d435c64d9287..7d8ee6ff3ee67 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp
@@ -30,11 +30,10 @@ using namespace bufferization;
namespace {
struct TensorBufferizePass : public TensorBufferizeBase<TensorBufferizePass> {
void runOnOperation() override {
- std::unique_ptr<BufferizationOptions> options =
- getPartialBufferizationOptions();
- options->addToDialectFilter<tensor::TensorDialect>();
+ BufferizationOptions options = getPartialBufferizationOptions();
+ options.addToDialectFilter<tensor::TensorDialect>();
- if (failed(bufferizeOp(getOperation(), *options)))
+ if (failed(bufferizeOp(getOperation(), options)))
signalPassFailure();
}
diff --git a/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp b/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp
index 230a412cd2c2c..28535ba8248a9 100644
--- a/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp
@@ -104,7 +104,7 @@ void TestComprehensiveFunctionBufferize::runOnOperation() {
auto options = std::make_unique<AnalysisBufferizationOptions>();
if (!allowReturnMemref)
- options->addPostAnalysisStep<scf::AssertScfForAliasingProperties>();
+ options->addPostAnalysisStep(scf::assertScfForAliasingProperties);
options->allowReturnMemref = allowReturnMemref;
options->allowUnknownOps = allowUnknownOps;
More information about the Mlir-commits
mailing list