[Mlir-commits] [mlir] 6fc11d4 - [mlir][bufferize] Add BufferizationState initializers

Matthias Springer llvmlistbot at llvm.org
Fri Mar 4 12:20:20 PST 2022


Author: Matthias Springer
Date: 2022-03-05T05:20:11+09:00
New Revision: 6fc11d4d3ea08f2a9e6adf1c1a99c8798904f385

URL: https://github.com/llvm/llvm-project/commit/6fc11d4d3ea08f2a9e6adf1c1a99c8798904f385
DIFF: https://github.com/llvm/llvm-project/commit/6fc11d4d3ea08f2a9e6adf1c1a99c8798904f385.diff

LOG: [mlir][bufferize] Add BufferizationState initializers

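`BufferizationState` initializer functions can be enqueued in `BufferizationOptions` (`stateInitializers`). They are invoked, in order, when a `BufferizationState` is constructed, and can be used to set up dialect-specific bufferization state.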

Differential Revision: https://reviews.llvm.org/D120985

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
    mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index 593057d039b3c..1e7587221b6dd 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -28,8 +28,8 @@ class FuncOp;
 namespace bufferization {
 
 class BufferizableOpInterface;
-struct BufferizationOptions;
 class BufferizationState;
+struct DialectBufferizationState;
 
 /// Options for ComprehensiveBufferize.
 struct BufferizationOptions {
@@ -44,6 +44,11 @@ struct BufferizationOptions {
   /// Memcpy function: Generate a memcpy between two buffers.
   using MemCpyFn =
       std::function<LogicalResult(OpBuilder &, Location, Value, Value)>;
+  /// Initializer function for bufferization state.
+  using BufferizationStateInitFn = std::function<void(BufferizationState &)>;
+  /// Initializer function for dialect-specific bufferization state.
+  using DialectStateInitFn =
+      std::function<std::unique_ptr<DialectBufferizationState>()>;
 
   /// An op filter entry. Filters can be used to specify which ops should be
   /// processed by the bufferization.
@@ -228,6 +233,14 @@ struct BufferizationOptions {
   /// DENY-filtered and have at least one matching ALLOW filter are processed.
   SmallVector<OpFilterEntry> opFilter;
 
+  /// Initializer functions for bufferization state. These can be used to
+  /// initialize dialect-specific bufferization state.
+  SmallVector<BufferizationStateInitFn> stateInitializers;
+
+  /// Add a bufferization state initializer that initializes the specified
+  /// dialect-specific bufferization state.
+  void addDialectStateInitializer(StringRef name, DialectStateInitFn fn);
+
 private:
   /// Allow a dialect.
   template <typename DialectT>
@@ -362,6 +375,12 @@ class BufferizationState {
     return static_cast<StateT &>(*dialectState[name]);
   }
 
+  void insertDialectState(StringRef name,
+                          std::unique_ptr<DialectBufferizationState> state) {
+    assert(!dialectState.count(name) && "dialect state already initialized");
+    dialectState[name] = std::move(state);
+  }
+
   /// Return a reference to the BufferizationOptions.
   const BufferizationOptions &getOptions() const { return options; }
 

diff  --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index e5d94872ae586..3e0e89a890c06 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -64,6 +64,12 @@ BufferizationOptions::dynCastBufferizableOp(Value value) const {
   return nullptr;
 }
 
+void BufferizationOptions::addDialectStateInitializer(StringRef name,
+                                                      DialectStateInitFn fn) {
+  stateInitializers.push_back(
+      [=](BufferizationState &state) { state.insertDialectState(name, fn()); });
+}
+
 //===----------------------------------------------------------------------===//
 // Helper functions for BufferizableOpInterface
 //===----------------------------------------------------------------------===//
@@ -200,7 +206,11 @@ BufferizationState::findLastPrecedingWrite(Value value) const {
 }
 
 BufferizationState::BufferizationState(const BufferizationOptions &options)
-    : options(options) {}
+    : options(options) {
+  for (const BufferizationOptions::BufferizationStateInitFn &fn :
+       options.stateInitializers)
+    fn(*this);
+}
 
 // bufferization.to_memref is not allowed to change the rank.
 static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {


        


More information about the Mlir-commits mailing list