[Mlir-commits] [mlir] 6fc11d4 - [mlir][bufferize] Add BufferizationState initializers
Matthias Springer
llvmlistbot at llvm.org
Fri Mar 4 12:20:20 PST 2022
Author: Matthias Springer
Date: 2022-03-05T05:20:11+09:00
New Revision: 6fc11d4d3ea08f2a9e6adf1c1a99c8798904f385
URL: https://github.com/llvm/llvm-project/commit/6fc11d4d3ea08f2a9e6adf1c1a99c8798904f385
DIFF: https://github.com/llvm/llvm-project/commit/6fc11d4d3ea08f2a9e6adf1c1a99c8798904f385.diff
LOG: [mlir][bufferize] Add BufferizationState initializers
BufferizationState initializer functions can be enqueued in `BufferizationOptions`. They are run when a `BufferizationState` is constructed and can be used to set up dialect-specific bufferization state.
Differential Revision: https://reviews.llvm.org/D120985
Added:
Modified:
mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index 593057d039b3c..1e7587221b6dd 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -28,8 +28,8 @@ class FuncOp;
namespace bufferization {
class BufferizableOpInterface;
-struct BufferizationOptions;
class BufferizationState;
+struct DialectBufferizationState;
/// Options for ComprehensiveBufferize.
struct BufferizationOptions {
@@ -44,6 +44,11 @@ struct BufferizationOptions {
/// Memcpy function: Generate a memcpy between two buffers.
using MemCpyFn =
std::function<LogicalResult(OpBuilder &, Location, Value, Value)>;
+ /// Initializer function for bufferization state.
+ using BufferizationStateInitFn = std::function<void(BufferizationState &)>;
+ /// Initializer function for dialect-specific bufferization state.
+ using DialectStateInitFn =
+ std::function<std::unique_ptr<DialectBufferizationState>()>;
/// An op filter entry. Filters can be used to specify which ops should be
/// processed by the bufferization.
@@ -228,6 +233,14 @@ struct BufferizationOptions {
/// DENY-filtered and have at least one matching ALLOW filter are processed.
SmallVector<OpFilterEntry> opFilter;
+ /// Initializer functions for bufferization state. These can be used to
+ /// initialize dialect-specific bufferization state.
+ SmallVector<BufferizationStateInitFn> stateInitializers;
+
+ /// Add a bufferization state initializer that initializes the specified
+ /// dialect-specific bufferization state.
+ void addDialectStateInitializer(StringRef name, DialectStateInitFn fn);
+
private:
/// Allow a dialect.
template <typename DialectT>
@@ -362,6 +375,12 @@ class BufferizationState {
return static_cast<StateT &>(*dialectState[name]);
}
+ void insertDialectState(StringRef name,
+ std::unique_ptr<DialectBufferizationState> state) {
+ assert(!dialectState.count(name) && "dialect state already initialized");
+ dialectState[name] = std::move(state);
+ }
+
/// Return a reference to the BufferizationOptions.
const BufferizationOptions &getOptions() const { return options; }
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index e5d94872ae586..3e0e89a890c06 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -64,6 +64,12 @@ BufferizationOptions::dynCastBufferizableOp(Value value) const {
return nullptr;
}
+void BufferizationOptions::addDialectStateInitializer(StringRef name,
+ DialectStateInitFn fn) {
+ stateInitializers.push_back(
+ [=](BufferizationState &state) { state.insertDialectState(name, fn()); });
+}
+
//===----------------------------------------------------------------------===//
// Helper functions for BufferizableOpInterface
//===----------------------------------------------------------------------===//
@@ -200,7 +206,11 @@ BufferizationState::findLastPrecedingWrite(Value value) const {
}
BufferizationState::BufferizationState(const BufferizationOptions &options)
- : options(options) {}
+ : options(options) {
+ for (const BufferizationOptions::BufferizationStateInitFn &fn :
+ options.stateInitializers)
+ fn(*this);
+}
// bufferization.to_memref is not allowed to change the rank.
static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
More information about the Mlir-commits
mailing list