[Mlir-commits] [mlir] 695c341 - [mlir][bufferize] Generalize filtering mechanism in BufferizationOptions
Matthias Springer
llvmlistbot at llvm.org
Tue Feb 15 02:17:42 PST 2022
Author: Matthias Springer
Date: 2022-02-15T19:17:33+09:00
New Revision: 695c341b84d19459b8f71319344380e4d7025b39
URL: https://github.com/llvm/llvm-project/commit/695c341b84d19459b8f71319344380e4d7025b39
DIFF: https://github.com/llvm/llvm-project/commit/695c341b84d19459b8f71319344380e4d7025b39.diff
LOG: [mlir][bufferize] Generalize filtering mechanism in BufferizationOptions
Support ALLOW filters and DENY filters. This is needed for compatibility with existing code that specifies more complex op filters.
Differential Revision: https://reviews.llvm.org/D119820
Added:
Modified:
mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
mlir/lib/Dialect/Arithmetic/Transforms/Bufferize.cpp
mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp
mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index 714f75d09b965..257f67a2db440 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -45,32 +45,93 @@ struct BufferizationOptions {
using MemCpyFn =
std::function<LogicalResult(OpBuilder &, Location, Value, Value)>;
+ /// An op filter entry. Filters can be used to specify which ops should be
+ /// processed by the bufferization.
+ struct OpFilterEntry {
+ /// If the filter function evaluates to `true`, the filter matches.
+ using FilterFn = std::function<bool(Operation *)>;
+
+ /// Filter type: A filter can either be a DENY filter or an ALLOW filter.
+ enum FilterType : int8_t { DENY = 0, ALLOW = 1 };
+
+ FilterFn fn;
+ FilterType type;
+ };
+
BufferizationOptions();
- /// Return `true` if the op is allowed to be bufferized.
+ /// Return whether the op should be bufferized or not.
+ ///
+ /// If no filter is specified (`hasFilter` = false), every op will be
+ /// bufferized. Otherwise, an op is bufferized if:
+ ///
+ /// - At least one ALLOW filter says `true`.
+ /// - And, no DENY filter says `true`.
bool isOpAllowed(Operation *op) const {
if (!hasFilter)
return true;
- return dialectFilter.contains(op->getDialect()->getNamespace()) ||
- operationFilter.contains(op->getName().getStringRef());
+ bool isAllowed = false;
+ for (const OpFilterEntry &entry : opFilter) {
+ bool filterResult = entry.fn(op);
+ switch (entry.type) {
+ case OpFilterEntry::ALLOW:
+ isAllowed |= filterResult;
+ break;
+ case OpFilterEntry::DENY:
+ if (filterResult)
+ // DENY filter matches. This op is not allowed, even if some ALLOW
+ // filters match.
+ return false;
+ };
+ }
+ return isAllowed;
}
/// Allow the given dialects and activate the filter (`hasFilter`).
+ ///
+ /// This function adds one or multiple ALLOW filters.
template <typename... DialectTs>
- void addToDialectFilter() {
- // The following expands a call to addToDialectFilterImpl for each dialect
+ void allowDialectInFilter() {
+ // The following expands a call to allowDialectInFilterImpl for each dialect
// in 'DialectTs'. This magic is necessary due to a limitation in the places
// that a parameter pack can be expanded in c++11.
// FIXME: In c++17 this can be simplified by using 'fold expressions'.
(void)std::initializer_list<int>{
- 0, (addToDialectFilterImpl<DialectTs>(), 0)...};
+ 0, (allowDialectInFilterImpl<DialectTs>(), 0)...};
+ }
+
+ /// Allow the given dialect and activate the filter (`hasFilter`).
+ ///
+ /// This function adds an ALLOW filter.
+ void allowDialectInFilter(StringRef dialectNamespace) {
+ hasFilter = true;
+ OpFilterEntry::FilterFn filterFn = [=](Operation *op) {
+ return op->getDialect()->getNamespace() == dialectNamespace;
+ };
+ opFilter.push_back(
+ OpFilterEntry{filterFn, OpFilterEntry::FilterType::ALLOW});
}
/// Allow the given ops and activate the filter (`hasFilter`).
- template <typename... OpTys> void addToOperationFilter() {
+ ///
+ /// This function adds one or multiple ALLOW filters.
+ template <typename... OpTys>
+ void allowOperationInFilter() {
// FIXME: In c++17 this can be simplified by using 'fold expressions'.
- (void)std::initializer_list<int>{0,
- (addToOperationFilterImpl<OpTys>(), 0)...};
+ (void)std::initializer_list<int>{
+ 0, (allowOperationInFilterImpl<OpTys>(), 0)...};
+ }
+
+ /// Allow the given op and activate the filter (`hasFilter`).
+ ///
+ /// This function adds an ALLOW filter.
+ void allowOperationInFilter(StringRef opName) {
+ hasFilter = true;
+ OpFilterEntry::FilterFn filterFn = [=](Operation *op) {
+ return op->getName().getStringRef() == opName;
+ };
+ opFilter.push_back(
+ OpFilterEntry{filterFn, OpFilterEntry::FilterType::ALLOW});
}
/// Try to cast the given op to BufferizableOpInterface if the op is allow
@@ -118,33 +179,26 @@ struct BufferizationOptions {
/// Buffer alignment for new memory allocations.
unsigned int bufferAlignment = 128;
- /// If set to `true`, only ops that belong to a filtered dialect
- /// (`dialectFilter`) and filtered ops (`operationFilter`) are processed. All
- /// other ops are ignored. If set to `false`, all ops are bufferized (as long
- /// as they implement BufferizableOpInterface).
- ///
- /// If a filter is specified, `allowUnknownOps` should be enabled. Otherwise,
- /// bufferization would fail when encountering a non-filtered op.
+ /// If set to `false`, all ops are bufferized (as long as they implement
+ /// BufferizableOpInterface). Otherwise, only filtered ops are bufferized.
bool hasFilter = false;
- /// A set of allowed dialects.
- DenseSet<StringRef> dialectFilter;
-
- /// A set of allowed ops.
- DenseSet<StringRef> operationFilter;
+ /// A list of op filters that determine whether an op should be processed or
+ /// ignored by the bufferization. If `hasFilter`, only ops that are not
+ /// DENY-filtered and have at least one matching ALLOW filter are processed.
+ SmallVector<OpFilterEntry> opFilter;
private:
/// Allow a dialect.
template <typename DialectT>
- void addToDialectFilterImpl() {
- hasFilter = true;
- dialectFilter.insert(DialectT::getDialectNamespace());
+ void allowDialectInFilterImpl() {
+ allowDialectInFilter(DialectT::getDialectNamespace());
}
/// Allow an op.
- template <typename OpTy> void addToOperationFilterImpl() {
- hasFilter = true;
- operationFilter.insert(OpTy::getOperationName());
+ template <typename OpTy>
+ void allowOperationInFilterImpl() {
+ allowOperationInFilter(OpTy::getOperationName());
}
};
diff --git a/mlir/lib/Dialect/Arithmetic/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/Bufferize.cpp
index d88ee4a65e26c..bc32daf75bc3f 100644
--- a/mlir/lib/Dialect/Arithmetic/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Arithmetic/Transforms/Bufferize.cpp
@@ -31,9 +31,9 @@ struct ArithmeticBufferizePass
void runOnOperation() override {
BufferizationOptions options = getPartialBufferizationOptions();
if (constantOpOnly) {
- options.addToOperationFilter<arith::ConstantOp>();
+ options.allowOperationInFilter<arith::ConstantOp>();
} else {
- options.addToDialectFilter<arith::ArithmeticDialect>();
+ options.allowDialectInFilter<arith::ArithmeticDialect>();
}
options.bufferAlignment = alignment;
diff --git a/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp
index 7d8ee6ff3ee67..41b11ea197b22 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp
@@ -31,7 +31,7 @@ namespace {
struct TensorBufferizePass : public TensorBufferizeBase<TensorBufferizePass> {
void runOnOperation() override {
BufferizationOptions options = getPartialBufferizationOptions();
- options.addToDialectFilter<tensor::TensorDialect>();
+ options.allowDialectInFilter<tensor::TensorDialect>();
if (failed(bufferizeOp(getOperation(), options)))
signalPassFailure();
diff --git a/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp b/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp
index 28535ba8248a9..f0b6b0e669ec4 100644
--- a/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp
@@ -116,7 +116,7 @@ void TestComprehensiveFunctionBufferize::runOnOperation() {
if (dialectFilter.hasValue()) {
options->hasFilter = true;
for (const std::string &dialectNamespace : dialectFilter)
- options->dialectFilter.insert(dialectNamespace);
+ options->allowDialectInFilter(dialectNamespace);
}
Operation *op = getOperation();
More information about the Mlir-commits
mailing list