[Mlir-commits] [mlir] 1534177 - [mlir][bufferization][NFC] Move OpFilter out of BufferizationOptions
Matthias Springer
llvmlistbot at llvm.org
Fri May 27 16:51:04 PDT 2022
Author: Matthias Springer
Date: 2022-05-28T01:47:39+02:00
New Revision: 1534177f8f7edd83083ceda7c14d6d40cc872c6e
URL: https://github.com/llvm/llvm-project/commit/1534177f8f7edd83083ceda7c14d6d40cc872c6e
DIFF: https://github.com/llvm/llvm-project/commit/1534177f8f7edd83083ceda7c14d6d40cc872c6e.diff
LOG: [mlir][bufferization][NFC] Move OpFilter out of BufferizationOptions
Differential Revision: https://reviews.llvm.org/D126568
Added:
Modified:
mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
mlir/lib/Dialect/Arithmetic/Transforms/Bufferize.cpp
mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
mlir/lib/Dialect/Shape/Transforms/Bufferize.cpp
mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp
mlir/lib/Dialect/Vector/Transforms/Bufferize.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index eb3cff0722b04..94fc30bafef9e 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -23,28 +23,11 @@ class AnalysisState;
class BufferizableOpInterface;
struct DialectAnalysisState;
-/// Options for BufferizableOpInterface-based bufferization.
-struct BufferizationOptions {
- /// Allocator function: Generate a memref allocation with the given type,
- /// dynamic extents and alignment.
- using AllocationFn = std::function<FailureOr<Value>(
- OpBuilder &, Location, MemRefType, ValueRange, unsigned int)>;
- /// Deallocator function: Deallocate a buffer that was allocated with
- /// AllocatorFn.
- using DeallocationFn =
- std::function<LogicalResult(OpBuilder &, Location, Value)>;
- /// Memcpy function: Generate a memcpy between two buffers.
- using MemCpyFn =
- std::function<LogicalResult(OpBuilder &, Location, Value, Value)>;
- /// Initializer function for analysis state.
- using AnalysisStateInitFn = std::function<void(AnalysisState &)>;
- /// Initializer function for dialect-specific analysis state.
- using DialectStateInitFn =
- std::function<std::unique_ptr<DialectAnalysisState>()>;
-
+class OpFilter {
+public:
/// An op filter entry. Filters can be used to specify which ops should be
/// processed by the bufferization.
- struct OpFilterEntry {
+ struct Entry {
/// If the filter function evaluates to `true`, the filter matches.
using FilterFn = std::function<bool(Operation *)>;
@@ -55,116 +38,156 @@ struct BufferizationOptions {
FilterType type;
};
- enum class LayoutMapOption : int8_t {
- InferLayoutMap = 0,
- IdentityLayoutMap = 1,
- FullyDynamicLayoutMap = 2
- };
-
- BufferizationOptions();
-
- /// Return `true` if the filter has at least one ALLOW rule.
- bool filterHasAllowRule() const {
- for (const OpFilterEntry &e : opFilter)
- if (e.type == OpFilterEntry::FilterType::ALLOW)
- return true;
- return false;
- }
-
- /// Return whether the op should be bufferized or not.
+ /// Return whether the op is allowed or not.
///
- /// If the filter does not have an ALLOW rule, ops are bufferized by default,
+ /// If the filter does not have an ALLOW rule, ops are allowed by default,
/// unless they are explicitly marked as DENY. If the filter has at least one
- /// ALLOW rule, ops are ignored by default and only bufferized if they match
+ /// ALLOW rule, ops are denied by default and only allowed if they match
/// an ALLOW rule and no DENY rule.
bool isOpAllowed(Operation *op) const;
- /// Allow the given dialects in the filter.
+ /// Allow the given dialects.
///
- /// This function adds one or multiple ALLOW filters.
- template <typename... DialectTs>
- void allowDialectInFilter() {
- // The following expands a call to allowDialectInFilterImpl for each dialect
+ /// This function adds one or multiple ALLOW entries.
+ template <typename... DialectTs> void allowDialect() {
+ // The following expands a call to allowDialectImpl for each dialect
// in 'DialectTs'. This magic is necessary due to a limitation in the places
// that a parameter pack can be expanded in c++11.
// FIXME: In c++17 this can be simplified by using 'fold expressions'.
- (void)std::initializer_list<int>{
- 0, (allowDialectInFilterImpl<DialectTs>(), 0)...};
+ (void)std::initializer_list<int>{0, (allowDialectImpl<DialectTs>(), 0)...};
}
- /// Deny the given dialects in the filter.
+ /// Deny the given dialects.
///
- /// This function adds one or multiple DENY filters.
- template <typename... DialectTs> void denyDialectInFilter() {
+ /// This function adds one or multiple DENY entries.
+ template <typename... DialectTs> void denyDialect() {
// FIXME: In c++17 this can be simplified by using 'fold expressions'.
- (void)std::initializer_list<int>{
- 0, (denyDialectInFilterImpl<DialectTs>(), 0)...};
+ (void)std::initializer_list<int>{0, (denyDialectImpl<DialectTs>(), 0)...};
}
- /// Allow the given dialect in the filter.
+ /// Allow the given dialect.
///
- /// This function adds an ALLOW filter.
- void allowDialectInFilter(StringRef dialectNamespace) {
- OpFilterEntry::FilterFn filterFn = [=](Operation *op) {
+ /// This function adds an ALLOW entry.
+ void allowDialect(StringRef dialectNamespace) {
+ Entry::FilterFn filterFn = [=](Operation *op) {
return op->getDialect()->getNamespace() == dialectNamespace;
};
- opFilter.push_back(
- OpFilterEntry{filterFn, OpFilterEntry::FilterType::ALLOW});
+ entries.push_back(Entry{filterFn, Entry::FilterType::ALLOW});
}
- /// Allow the given ops in the filter.
+ /// Allow the given ops.
///
- /// This function adds one or multiple ALLOW filters.
- template <typename... OpTys>
- void allowOperationInFilter() {
+ /// This function adds one or multiple ALLOW entries.
+ template <typename... OpTys> void allowOperation() {
// FIXME: In c++17 this can be simplified by using 'fold expressions'.
- (void)std::initializer_list<int>{
- 0, (allowOperationInFilterImpl<OpTys>(), 0)...};
+ (void)std::initializer_list<int>{0, (allowOperationImpl<OpTys>(), 0)...};
}
- /// Deny the given ops in the filter.
+ /// Deny the given ops.
///
- /// This function adds one or multiple DENY filters.
- template <typename... OpTys> void denyOperationInFilter() {
+ /// This function adds one or multiple DENY entries.
+ template <typename... OpTys> void denyOperation() {
// FIXME: In c++17 this can be simplified by using 'fold expressions'.
- (void)std::initializer_list<int>{
- 0, (denyOperationInFilterImpl<OpTys>(), 0)...};
+ (void)std::initializer_list<int>{0, (denyOperationImpl<OpTys>(), 0)...};
}
- /// Allow the given op in the filter.
+ /// Allow the given op.
///
- /// This function adds an ALLOW filter.
- void allowOperationInFilter(StringRef opName) {
- OpFilterEntry::FilterFn filterFn = [=](Operation *op) {
+ /// This function adds an ALLOW entry.
+ void allowOperation(StringRef opName) {
+ Entry::FilterFn filterFn = [=](Operation *op) {
return op->getName().getStringRef() == opName;
};
- allowOperationInFilter(filterFn);
+ allowOperation(filterFn);
}
- /// Deny the given op in the filter.
+ /// Deny the given op.
///
- /// This function adds a DENY filter.
- void denyOperationInFilter(StringRef opName) {
- OpFilterEntry::FilterFn filterFn = [=](Operation *op) {
+ /// This function adds a DENY entry.
+ void denyOperation(StringRef opName) {
+ Entry::FilterFn filterFn = [=](Operation *op) {
return op->getName().getStringRef() == opName;
};
- denyOperationInFilter(filterFn);
+ denyOperation(filterFn);
}
- /// Allow ops that are matched by `fn` in the filter.
+ /// Allow ops that are matched by `fn`.
///
- /// This function adds an ALLOW filter.
- void allowOperationInFilter(OpFilterEntry::FilterFn fn) {
- opFilter.push_back(OpFilterEntry{fn, OpFilterEntry::FilterType::ALLOW});
+ /// This function adds an ALLOW entry.
+ void allowOperation(Entry::FilterFn fn) {
+ entries.push_back(Entry{fn, Entry::FilterType::ALLOW});
}
- /// Deny ops that are matched by `fn` in the filter.
+ /// Deny ops that are matched by `fn`.
///
- /// This function adds a DENY filter.
- void denyOperationInFilter(OpFilterEntry::FilterFn fn) {
- opFilter.push_back(OpFilterEntry{fn, OpFilterEntry::FilterType::DENY});
+ /// This function adds a DENY entry.
+ void denyOperation(Entry::FilterFn fn) {
+ entries.push_back(Entry{fn, Entry::FilterType::DENY});
+ }
+
+private:
+ /// Return `true` if the filter has at least one ALLOW rule.
+ bool hasAllowRule() const {
+ for (const Entry &e : entries)
+ if (e.type == Entry::FilterType::ALLOW)
+ return true;
+ return false;
+ }
+
+ /// Allow a dialect.
+ template <typename DialectT> void allowDialectImpl() {
+ allowDialect(DialectT::getDialectNamespace());
+ }
+
+ /// Deny a dialect.
+ template <typename DialectT> void denyDialectImpl() {
+ denyDialect(DialectT::getDialectNamespace());
+ }
+
+ /// Allow an op.
+ template <typename OpTy> void allowOperationImpl() {
+ allowOperation(OpTy::getOperationName());
}
+ /// Deny an op.
+ template <typename OpTy> void denyOperationImpl() {
+ denyOperation(OpTy::getOperationName());
+ }
+
+ /// A list of filter entries that determine whether an op should be allowed or
+ /// denied. If the filter has an ALLOW rule, only ops that are allowed and not
+ /// denied are allowed. If the filter does not have an ALLOW rule, only ops
+ /// that are not denied are allowed.
+ SmallVector<Entry> entries;
+};
+
+/// Options for BufferizableOpInterface-based bufferization.
+struct BufferizationOptions {
+ /// Allocator function: Generate a memref allocation with the given type,
+ /// dynamic extents and alignment.
+ using AllocationFn = std::function<FailureOr<Value>(
+ OpBuilder &, Location, MemRefType, ValueRange, unsigned int)>;
+ /// Deallocator function: Deallocate a buffer that was allocated with
+ /// AllocatorFn.
+ using DeallocationFn =
+ std::function<LogicalResult(OpBuilder &, Location, Value)>;
+ /// Memcpy function: Generate a memcpy between two buffers.
+ using MemCpyFn =
+ std::function<LogicalResult(OpBuilder &, Location, Value, Value)>;
+ /// Initializer function for analysis state.
+ using AnalysisStateInitFn = std::function<void(AnalysisState &)>;
+ /// Initializer function for dialect-specific analysis state.
+ using DialectStateInitFn =
+ std::function<std::unique_ptr<DialectAnalysisState>()>;
+
+ enum class LayoutMapOption : int8_t {
+ InferLayoutMap = 0,
+ IdentityLayoutMap = 1,
+ FullyDynamicLayoutMap = 2
+ };
+
+ BufferizationOptions();
+
/// Try to cast the given op to BufferizableOpInterface if the op is allow
/// listed.
BufferizableOpInterface dynCastBufferizableOp(Operation *op) const;
@@ -173,6 +196,13 @@ struct BufferizationOptions {
/// listed.
BufferizableOpInterface dynCastBufferizableOp(Value value) const;
+ /// A filter that specifies which ops should be bufferized and which ops
+ /// should be ignored.
+ OpFilter opFilter;
+
+ /// Return `true` if the given op should be bufferized.
+ bool isOpAllowed(Operation *op) const;
+
/// Helper functions for allocation, deallocation, memory copying.
Optional<AllocationFn> allocationFn;
Optional<DeallocationFn> deallocationFn;
@@ -276,12 +306,6 @@ struct BufferizationOptions {
/// Buffer alignment for new memory allocations.
unsigned int bufferAlignment = 128;
- /// A list of op filters that determine whether an op should be processed or
- /// ignored by the bufferization. If the filter has an ALLOW rule, only ops
- /// that are allowed and not denied are bufferized. If the filter does not
- /// have an ALLOW rule, only ops that are not denied are bufferized.
- SmallVector<OpFilterEntry> opFilter;
-
/// Initializer functions for analysis state. These can be used to
/// initialize dialect-specific analysis state.
SmallVector<AnalysisStateInitFn> stateInitializers;
@@ -289,29 +313,6 @@ struct BufferizationOptions {
/// Add a analysis state initializer that initializes the specified
/// dialect-specific analysis state.
void addDialectStateInitializer(StringRef name, const DialectStateInitFn &fn);
-
-private:
- /// Allow a dialect.
- template <typename DialectT>
- void allowDialectInFilterImpl() {
- allowDialectInFilter(DialectT::getDialectNamespace());
- }
-
- /// Deny a dialect.
- template <typename DialectT> void denyDialectInFilterImpl() {
- denyDialectInFilter(DialectT::getDialectNamespace());
- }
-
- /// Allow an op.
- template <typename OpTy>
- void allowOperationInFilterImpl() {
- allowOperationInFilter(OpTy::getOperationName());
- }
-
- /// Deny an op.
- template <typename OpTy> void denyOperationInFilterImpl() {
- denyOperationInFilter(OpTy::getOperationName());
- }
};
/// Specify fine-grain relationship between buffers to enable more analysis.
diff --git a/mlir/lib/Dialect/Arithmetic/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/Bufferize.cpp
index 4fb3d0e03ff1e..3237b1a880441 100644
--- a/mlir/lib/Dialect/Arithmetic/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Arithmetic/Transforms/Bufferize.cpp
@@ -32,9 +32,9 @@ struct ArithmeticBufferizePass
void runOnOperation() override {
BufferizationOptions options = getPartialBufferizationOptions();
if (constantOpOnly) {
- options.allowOperationInFilter<arith::ConstantOp>();
+ options.opFilter.allowOperation<arith::ConstantOp>();
} else {
- options.allowDialectInFilter<arith::ArithmeticDialect>();
+ options.opFilter.allowDialect<arith::ArithmeticDialect>();
}
options.bufferAlignment = alignment;
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 79b499545ca9f..635e7d770e60b 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -45,28 +45,19 @@ static const char *kBufferAllocationAttr = "bufferization.allocation";
static const char *kSkipDeallocAttr = "bufferization.skip_dealloc";
//===----------------------------------------------------------------------===//
-// BufferizationOptions
+// OpFilter
//===----------------------------------------------------------------------===//
-// Default constructor for BufferizationOptions.
-BufferizationOptions::BufferizationOptions() = default;
-
-bool BufferizationOptions::isOpAllowed(Operation *op) const {
- // Special case: If function boundary bufferization is deactivated, do not
- // allow ops that belong to the `func` dialect.
- bool isFuncBoundaryOp = isa_and_nonnull<func::FuncDialect>(op->getDialect());
- if (!bufferizeFunctionBoundaries && isFuncBoundaryOp)
- return false;
-
+bool OpFilter::isOpAllowed(Operation *op) const {
// All other ops: Allow/disallow according to filter.
- bool isAllowed = !filterHasAllowRule();
- for (const OpFilterEntry &entry : opFilter) {
+ bool isAllowed = !hasAllowRule();
+ for (const Entry &entry : entries) {
bool filterResult = entry.fn(op);
switch (entry.type) {
- case OpFilterEntry::ALLOW:
+ case Entry::ALLOW:
isAllowed |= filterResult;
break;
- case OpFilterEntry::DENY:
+ case Entry::DENY:
if (filterResult)
// DENY filter matches. This op is no allowed. (Even if other ALLOW
// filters may match.)
@@ -76,6 +67,23 @@ bool BufferizationOptions::isOpAllowed(Operation *op) const {
return isAllowed;
}
+//===----------------------------------------------------------------------===//
+// BufferizationOptions
+//===----------------------------------------------------------------------===//
+
+// Default constructor for BufferizationOptions.
+BufferizationOptions::BufferizationOptions() = default;
+
+bool BufferizationOptions::isOpAllowed(Operation *op) const {
+ // Special case: If function boundary bufferization is deactivated, do not
+ // allow ops that belong to the `func` dialect.
+ bool isFuncBoundaryOp = isa_and_nonnull<func::FuncDialect>(op->getDialect());
+ if (!bufferizeFunctionBoundaries && isFuncBoundaryOp)
+ return false;
+
+ return opFilter.isOpAllowed(op);
+}
+
BufferizableOpInterface
BufferizationOptions::dynCastBufferizableOp(Operation *op) const {
auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op);
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index 4f621979a5e35..d4b12bad308f0 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -194,7 +194,7 @@ struct OneShotBufferizePass
opt.promoteBufferResultsToOutParams = promoteBufferResultsToOutParams;
opt.unknownTypeConversion = parseLayoutMapOption(unknownTypeConversion);
- BufferizationOptions::OpFilterEntry::FilterFn filterFn =
+ OpFilter::Entry::FilterFn filterFn =
[&](Operation *op) {
// Filter may be specified via options.
if (this->dialectFilter.hasValue())
@@ -204,7 +204,7 @@ struct OneShotBufferizePass
// No filter specified: All other ops are allowed.
return true;
};
- opt.allowOperationInFilter(filterFn);
+ opt.opFilter.allowOperation(filterFn);
} else {
opt = *options;
}
@@ -242,7 +242,7 @@ struct BufferizationBufferizePass
: public BufferizationBufferizeBase<BufferizationBufferizePass> {
void runOnOperation() override {
BufferizationOptions options = getPartialBufferizationOptions();
- options.allowDialectInFilter<BufferizationDialect>();
+ options.opFilter.allowDialect<BufferizationDialect>();
if (failed(bufferizeOp(getOperation(), options)))
signalPassFailure();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
index 143cef1ce5cf0..63433ef50fe2e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
@@ -28,7 +28,7 @@ namespace {
struct LinalgBufferizePass : public LinalgBufferizeBase<LinalgBufferizePass> {
void runOnOperation() override {
BufferizationOptions options = getPartialBufferizationOptions();
- options.allowDialectInFilter<linalg::LinalgDialect>();
+ options.opFilter.allowDialect<linalg::LinalgDialect>();
if (failed(bufferizeOp(getOperation(), options)))
signalPassFailure();
diff --git a/mlir/lib/Dialect/Shape/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Shape/Transforms/Bufferize.cpp
index b84c3b7e8a470..373552801673c 100644
--- a/mlir/lib/Dialect/Shape/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Shape/Transforms/Bufferize.cpp
@@ -23,7 +23,7 @@ namespace {
struct ShapeBufferizePass : public ShapeBufferizeBase<ShapeBufferizePass> {
void runOnOperation() override {
BufferizationOptions options = getPartialBufferizationOptions();
- options.allowDialectInFilter<shape::ShapeDialect>();
+ options.opFilter.allowDialect<shape::ShapeDialect>();
if (failed(bufferizeOp(getOperation(), options)))
signalPassFailure();
diff --git a/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp
index c923398996912..75bfc878b9226 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp
@@ -30,7 +30,7 @@ namespace {
struct TensorBufferizePass : public TensorBufferizeBase<TensorBufferizePass> {
void runOnOperation() override {
BufferizationOptions options = getPartialBufferizationOptions();
- options.allowDialectInFilter<tensor::TensorDialect>();
+ options.opFilter.allowDialect<tensor::TensorDialect>();
if (failed(bufferizeOp(getOperation(), options)))
signalPassFailure();
diff --git a/mlir/lib/Dialect/Vector/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Vector/Transforms/Bufferize.cpp
index 4ed2dd629c1b9..f98eeda1bde53 100644
--- a/mlir/lib/Dialect/Vector/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/Bufferize.cpp
@@ -27,7 +27,7 @@ namespace {
struct VectorBufferizePass : public VectorBufferizeBase<VectorBufferizePass> {
void runOnOperation() override {
BufferizationOptions options = getPartialBufferizationOptions();
- options.allowDialectInFilter<vector::VectorDialect>();
+ options.opFilter.allowDialect<vector::VectorDialect>();
if (failed(bufferizeOp(getOperation(), options)))
signalPassFailure();
More information about the Mlir-commits mailing list