[Mlir-commits] [mlir] d6dab38 - [mlir][bufferize][NFC] Add function boundary bufferization flag to BufferizationOptions
Matthias Springer
llvmlistbot at llvm.org
Fri Apr 22 09:11:48 PDT 2022
Author: Matthias Springer
Date: 2022-04-23T01:11:37+09:00
New Revision: d6dab38ae48a298b671107c9ebdd86de8f7cd482
URL: https://github.com/llvm/llvm-project/commit/d6dab38ae48a298b671107c9ebdd86de8f7cd482
DIFF: https://github.com/llvm/llvm-project/commit/d6dab38ae48a298b671107c9ebdd86de8f7cd482.diff
LOG: [mlir][bufferize][NFC] Add function boundary bufferization flag to BufferizationOptions
This makes the API easier to use. It also allows us to detect incorrect API usage, which makes debugging easier.
Differential Revision: https://reviews.llvm.org/D124265
Added:
Modified:
mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index c228e26d1198e..94504c1e9b104 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -78,23 +78,7 @@ struct BufferizationOptions {
/// unless they are explicitly marked as DENY. If the filter has at least one
/// ALLOW rule, ops are ignored by default and only bufferized if they match
/// an ALLOW rule and no DENY rule.
- bool isOpAllowed(Operation *op) const {
- bool isAllowed = !filterHasAllowRule();
- for (const OpFilterEntry &entry : opFilter) {
- bool filterResult = entry.fn(op);
- switch (entry.type) {
- case OpFilterEntry::ALLOW:
- isAllowed |= filterResult;
- break;
- case OpFilterEntry::DENY:
- if (filterResult)
- // DENY filter matches. This op is no allowed. (Even if other ALLOW
- // filters may match.)
- return false;
- };
- }
- return isAllowed;
- }
+ bool isOpAllowed(Operation *op) const;
/// Allow the given dialects in the filter.
///
@@ -182,6 +166,10 @@ struct BufferizationOptions {
/// the boundaries.
bool allowUnknownOps = false;
+ /// Specifies whether function boundaries (ops in the func dialect) should be
+ /// bufferized or not.
+ bool bufferizeFunctionBoundaries = false;
+
/// Specifies whether dealloc ops should be generated along with alloc ops. If
/// not, new memory allocations will leak.
bool createDeallocs = true;
@@ -356,6 +344,12 @@ class AnalysisState {
/// any given tensor.
virtual bool isTensorYielded(Value tensor) const = 0;
+ /// Return `true` if the given dialect state exists.
+ bool hasDialectState(StringRef name) const {
+ auto it = dialectState.find(name);
+ return it != dialectState.end();
+ }
+
/// Return dialect-specific bufferization state.
template <typename StateT>
Optional<const StateT *> getDialectState(StringRef name) const {
@@ -369,7 +363,7 @@ class AnalysisState {
template <typename StateT>
StateT &getOrCreateDialectState(StringRef name) {
// Create state if it does not exist yet.
- if (!dialectState.count(name))
+ if (!hasDialectState(name))
dialectState[name] = std::make_unique<StateT>();
return static_cast<StateT &>(*dialectState[name]);
}
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 921c6c80f6746..73da6c85e761f 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -51,6 +51,31 @@ static const char *kSkipDeallocAttr = "bufferization.skip_dealloc";
// Default constructor for BufferizationOptions.
BufferizationOptions::BufferizationOptions() = default;
+bool BufferizationOptions::isOpAllowed(Operation *op) const {
+ // Special case: If function boundary bufferization is deactivated, do not
+ // allow ops that belong to the `func` dialect.
+ bool isFuncBoundaryOp = isa_and_nonnull<func::FuncDialect>(op->getDialect());
+ if (!bufferizeFunctionBoundaries && isFuncBoundaryOp)
+ return false;
+
+ // All other ops: Allow/disallow according to filter.
+ bool isAllowed = !filterHasAllowRule();
+ for (const OpFilterEntry &entry : opFilter) {
+ bool filterResult = entry.fn(op);
+ switch (entry.type) {
+ case OpFilterEntry::ALLOW:
+ isAllowed |= filterResult;
+ break;
+ case OpFilterEntry::DENY:
+ if (filterResult)
+ // DENY filter matches. This op is no allowed. (Even if other ALLOW
+ // filters may match.)
+ return false;
+ };
+ }
+ return isAllowed;
+}
+
BufferizableOpInterface
BufferizationOptions::dynCastBufferizableOp(Operation *op) const {
if (isOpAllowed(op))
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index 49aaeb76a0381..0936512df33a8 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -175,15 +175,10 @@ struct OneShotBufferizePass
opt.fullyDynamicLayoutMaps = fullyDynamicLayoutMaps;
opt.printConflicts = printConflicts;
opt.testAnalysisOnly = testAnalysisOnly;
+ opt.bufferizeFunctionBoundaries = bufferizeFunctionBoundaries;
BufferizationOptions::OpFilterEntry::FilterFn filterFn =
[&](Operation *op) {
- // Disallow non-func dialect ops. I.e., no ops related to function
- // calls. (Unless explicitly activated.)
- bool isFuncBoundaryOp =
- isa_and_nonnull<func::FuncDialect>(op->getDialect());
- if (!this->bufferizeFunctionBoundaries && isFuncBoundaryOp)
- return false;
// Filter may be specified via options.
if (this->dialectFilter.hasValue())
return llvm::find(this->dialectFilter,
@@ -198,7 +193,7 @@ struct OneShotBufferizePass
}
ModuleOp moduleOp = getOperation();
- if (bufferizeFunctionBoundaries) {
+ if (opt.bufferizeFunctionBoundaries) {
if (failed(runOneShotModuleBufferize(moduleOp, opt))) {
signalPassFailure();
return;
@@ -284,6 +279,12 @@ bufferization::finalizeBuffers(Operation *op,
LogicalResult bufferization::bufferizeOp(Operation *op,
const AnalysisState &analysisState) {
+ // Catch incorrect API usage.
+ assert((analysisState.hasDialectState(
+ func::FuncDialect::getDialectNamespace()) ||
+ !analysisState.getOptions().bufferizeFunctionBoundaries) &&
+ "must use ModuleBufferize to bufferize function boundaries");
+
BufferizationState bufferizationState(analysisState);
if (failed(bufferizeOp(op, bufferizationState)))
return failure();
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
index 8ae5c1c8cace8..d1fbb70f889c8 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
@@ -46,6 +46,7 @@
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/Dominance.h"
@@ -864,6 +865,11 @@ LogicalResult bufferization::analyzeOp(Operation *op,
const auto &options =
static_cast<const OneShotBufferizationOptions &>(state.getOptions());
+ // Catch incorrect API usage.
+ assert((state.hasDialectState(func::FuncDialect::getDialectNamespace()) ||
+ !options.bufferizeFunctionBoundaries) &&
+ "must use ModuleBufferize to bufferize function boundaries");
+
if (failed(checkAliasInfoConsistency(op, domInfo, state, aliasInfo)))
return failure();
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
index 7bcde9fd1be43..6dc3432f46a63 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -417,6 +417,8 @@ static void foldMemRefCasts(func::FuncOp funcOp) {
LogicalResult mlir::bufferization::runOneShotModuleBufferize(
ModuleOp moduleOp, OneShotBufferizationOptions options) {
+ assert(options.bufferizeFunctionBoundaries &&
+ "expected that function boundary bufferization is activated");
IRRewriter rewriter(moduleOp.getContext());
OneShotAnalysisState analysisState(moduleOp, options);
BufferizationState bufferizationState(analysisState);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
index b9163c5646ee1..13b760c6dd959 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
@@ -99,6 +99,7 @@ void LinalgComprehensiveModuleBufferize::runOnOperation() {
opt.printConflicts = printConflicts;
opt.testAnalysisOnly = testAnalysisOnly;
opt.alwaysAliasingWithDest = alwaysAliasingWithDest;
+ opt.bufferizeFunctionBoundaries = true;
if (initTensorElimination) {
opt.addPostAnalysisStep(insertSliceAnchoredInitTensorEliminationStep);
}
More information about the Mlir-commits mailing list