[Mlir-commits] [mlir] 1534177 - [mlir][bufferization][NFC] Move OpFilter out of BufferizationOptions

Matthias Springer llvmlistbot at llvm.org
Fri May 27 16:51:04 PDT 2022


Author: Matthias Springer
Date: 2022-05-28T01:47:39+02:00
New Revision: 1534177f8f7edd83083ceda7c14d6d40cc872c6e

URL: https://github.com/llvm/llvm-project/commit/1534177f8f7edd83083ceda7c14d6d40cc872c6e
DIFF: https://github.com/llvm/llvm-project/commit/1534177f8f7edd83083ceda7c14d6d40cc872c6e.diff

LOG: [mlir][bufferization][NFC] Move OpFilter out of BufferizationOptions

Differential Revision: https://reviews.llvm.org/D126568

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
    mlir/lib/Dialect/Arithmetic/Transforms/Bufferize.cpp
    mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
    mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
    mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
    mlir/lib/Dialect/Shape/Transforms/Bufferize.cpp
    mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp
    mlir/lib/Dialect/Vector/Transforms/Bufferize.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index eb3cff0722b04..94fc30bafef9e 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -23,28 +23,11 @@ class AnalysisState;
 class BufferizableOpInterface;
 struct DialectAnalysisState;
 
-/// Options for BufferizableOpInterface-based bufferization.
-struct BufferizationOptions {
-  /// Allocator function: Generate a memref allocation with the given type,
-  /// dynamic extents and alignment.
-  using AllocationFn = std::function<FailureOr<Value>(
-      OpBuilder &, Location, MemRefType, ValueRange, unsigned int)>;
-  /// Deallocator function: Deallocate a buffer that was allocated with
-  /// AllocatorFn.
-  using DeallocationFn =
-      std::function<LogicalResult(OpBuilder &, Location, Value)>;
-  /// Memcpy function: Generate a memcpy between two buffers.
-  using MemCpyFn =
-      std::function<LogicalResult(OpBuilder &, Location, Value, Value)>;
-  /// Initializer function for analysis state.
-  using AnalysisStateInitFn = std::function<void(AnalysisState &)>;
-  /// Initializer function for dialect-specific analysis state.
-  using DialectStateInitFn =
-      std::function<std::unique_ptr<DialectAnalysisState>()>;
-
+class OpFilter {
+public:
   /// An op filter entry. Filters can be used to specify which ops should be
   /// processed by the bufferization.
-  struct OpFilterEntry {
+  struct Entry {
     /// If the filter function evaluates to `true`, the filter matches.
     using FilterFn = std::function<bool(Operation *)>;
 
@@ -55,116 +38,156 @@ struct BufferizationOptions {
     FilterType type;
   };
 
-  enum class LayoutMapOption : int8_t {
-    InferLayoutMap = 0,
-    IdentityLayoutMap = 1,
-    FullyDynamicLayoutMap = 2
-  };
-
-  BufferizationOptions();
-
-  /// Return `true` if the filter has at least one ALLOW rule.
-  bool filterHasAllowRule() const {
-    for (const OpFilterEntry &e : opFilter)
-      if (e.type == OpFilterEntry::FilterType::ALLOW)
-        return true;
-    return false;
-  }
-
-  /// Return whether the op should be bufferized or not.
+  /// Return whether the op is allowed or not.
   ///
-  /// If the filter does not have an ALLOW rule, ops are bufferized by default,
+  /// If the filter does not have an ALLOW rule, ops are allowed by default,
   /// unless they are explicitly marked as DENY. If the filter has at least one
-  /// ALLOW rule, ops are ignored by default and only bufferized if they match
+  /// ALLOW rule, ops are denied by default and only allowed if they match
   /// an ALLOW rule and no DENY rule.
   bool isOpAllowed(Operation *op) const;
 
-  /// Allow the given dialects in the filter.
+  /// Allow the given dialects.
   ///
-  /// This function adds one or multiple ALLOW filters.
-  template <typename... DialectTs>
-  void allowDialectInFilter() {
-    // The following expands a call to allowDialectInFilterImpl for each dialect
+  /// This function adds one or multiple ALLOW entries.
+  template <typename... DialectTs> void allowDialect() {
+    // The following expands a call to allowDialectImpl for each dialect
     // in 'DialectTs'. This magic is necessary due to a limitation in the places
     // that a parameter pack can be expanded in c++11.
     // FIXME: In c++17 this can be simplified by using 'fold expressions'.
-    (void)std::initializer_list<int>{
-        0, (allowDialectInFilterImpl<DialectTs>(), 0)...};
+    (void)std::initializer_list<int>{0, (allowDialectImpl<DialectTs>(), 0)...};
   }
 
-  /// Deny the given dialects in the filter.
+  /// Deny the given dialects.
   ///
-  /// This function adds one or multiple DENY filters.
-  template <typename... DialectTs> void denyDialectInFilter() {
+  /// This function adds one or multiple DENY entries.
+  template <typename... DialectTs> void denyDialect() {
     // FIXME: In c++17 this can be simplified by using 'fold expressions'.
-    (void)std::initializer_list<int>{
-        0, (denyDialectInFilterImpl<DialectTs>(), 0)...};
+    (void)std::initializer_list<int>{0, (denyDialectImpl<DialectTs>(), 0)...};
   }
 
-  /// Allow the given dialect in the filter.
+  /// Allow the given dialect.
   ///
-  /// This function adds an ALLOW filter.
-  void allowDialectInFilter(StringRef dialectNamespace) {
-    OpFilterEntry::FilterFn filterFn = [=](Operation *op) {
+  /// This function adds an ALLOW entry.
+  void allowDialect(StringRef dialectNamespace) {
+    Entry::FilterFn filterFn = [=](Operation *op) {
       return op->getDialect()->getNamespace() == dialectNamespace;
     };
-    opFilter.push_back(
-        OpFilterEntry{filterFn, OpFilterEntry::FilterType::ALLOW});
+    entries.push_back(Entry{filterFn, Entry::FilterType::ALLOW});
   }
 
-  /// Allow the given ops in the filter.
+  /// Allow the given ops.
   ///
-  /// This function adds one or multiple ALLOW filters.
-  template <typename... OpTys>
-  void allowOperationInFilter() {
+  /// This function adds one or multiple ALLOW entries.
+  template <typename... OpTys> void allowOperation() {
     // FIXME: In c++17 this can be simplified by using 'fold expressions'.
-    (void)std::initializer_list<int>{
-        0, (allowOperationInFilterImpl<OpTys>(), 0)...};
+    (void)std::initializer_list<int>{0, (allowOperationImpl<OpTys>(), 0)...};
   }
 
-  /// Deny the given ops in the filter.
+  /// Deny the given ops.
   ///
-  /// This function adds one or multiple DENY filters.
-  template <typename... OpTys> void denyOperationInFilter() {
+  /// This function adds one or multiple DENY entries.
+  template <typename... OpTys> void denyOperation() {
     // FIXME: In c++17 this can be simplified by using 'fold expressions'.
-    (void)std::initializer_list<int>{
-        0, (denyOperationInFilterImpl<OpTys>(), 0)...};
+    (void)std::initializer_list<int>{0, (denyOperationImpl<OpTys>(), 0)...};
   }
 
-  /// Allow the given op in the filter.
+  /// Allow the given op.
   ///
-  /// This function adds an ALLOW filter.
-  void allowOperationInFilter(StringRef opName) {
-    OpFilterEntry::FilterFn filterFn = [=](Operation *op) {
+  /// This function adds an ALLOW entry.
+  void allowOperation(StringRef opName) {
+    Entry::FilterFn filterFn = [=](Operation *op) {
       return op->getName().getStringRef() == opName;
     };
-    allowOperationInFilter(filterFn);
+    allowOperation(filterFn);
   }
 
-  /// Deny the given op in the filter.
+  /// Deny the given op.
   ///
-  /// This function adds a DENY filter.
-  void denyOperationInFilter(StringRef opName) {
-    OpFilterEntry::FilterFn filterFn = [=](Operation *op) {
+  /// This function adds a DENY entry.
+  void denyOperation(StringRef opName) {
+    Entry::FilterFn filterFn = [=](Operation *op) {
       return op->getName().getStringRef() == opName;
     };
-    denyOperationInFilter(filterFn);
+    denyOperation(filterFn);
   }
 
-  /// Allow ops that are matched by `fn` in the filter.
+  /// Allow ops that are matched by `fn`.
   ///
-  /// This function adds an ALLOW filter.
-  void allowOperationInFilter(OpFilterEntry::FilterFn fn) {
-    opFilter.push_back(OpFilterEntry{fn, OpFilterEntry::FilterType::ALLOW});
+  /// This function adds an ALLOW entry.
+  void allowOperation(Entry::FilterFn fn) {
+    entries.push_back(Entry{fn, Entry::FilterType::ALLOW});
   }
 
-  /// Deny ops that are matched by `fn` in the filter.
+  /// Deny ops that are matched by `fn`.
   ///
-  /// This function adds a DENY filter.
-  void denyOperationInFilter(OpFilterEntry::FilterFn fn) {
-    opFilter.push_back(OpFilterEntry{fn, OpFilterEntry::FilterType::DENY});
+  /// This function adds a DENY entry.
+  void denyOperation(Entry::FilterFn fn) {
+    entries.push_back(Entry{fn, Entry::FilterType::DENY});
+  }
+
+private:
+  /// Return `true` if the filter has at least one ALLOW rule.
+  bool hasAllowRule() const {
+    for (const Entry &e : entries)
+      if (e.type == Entry::FilterType::ALLOW)
+        return true;
+    return false;
+  }
+
+  /// Allow a dialect.
+  template <typename DialectT> void allowDialectImpl() {
+    allowDialect(DialectT::getDialectNamespace());
+  }
+
+  /// Deny a dialect.
+  template <typename DialectT> void denyDialectImpl() {
+    denyDialect(DialectT::getDialectNamespace());
+  }
+
+  /// Allow an op.
+  template <typename OpTy> void allowOperationImpl() {
+    allowOperation(OpTy::getOperationName());
   }
 
+  /// Deny an op.
+  template <typename OpTy> void denyOperationImpl() {
+    denyOperation(OpTy::getOperationName());
+  }
+
+  /// A list of filter entries that determine whether an op should be allowed or
+  /// denied. If the filter has an ALLOW rule, only ops that are allowed and not
+  /// denied are allowed. If the filter does not have an ALLOW rule, only ops
+  /// that are not denied are allowed.
+  SmallVector<Entry> entries;
+};
+
+/// Options for BufferizableOpInterface-based bufferization.
+struct BufferizationOptions {
+  /// Allocator function: Generate a memref allocation with the given type,
+  /// dynamic extents and alignment.
+  using AllocationFn = std::function<FailureOr<Value>(
+      OpBuilder &, Location, MemRefType, ValueRange, unsigned int)>;
+  /// Deallocator function: Deallocate a buffer that was allocated with
+  /// AllocatorFn.
+  using DeallocationFn =
+      std::function<LogicalResult(OpBuilder &, Location, Value)>;
+  /// Memcpy function: Generate a memcpy between two buffers.
+  using MemCpyFn =
+      std::function<LogicalResult(OpBuilder &, Location, Value, Value)>;
+  /// Initializer function for analysis state.
+  using AnalysisStateInitFn = std::function<void(AnalysisState &)>;
+  /// Initializer function for dialect-specific analysis state.
+  using DialectStateInitFn =
+      std::function<std::unique_ptr<DialectAnalysisState>()>;
+
+  enum class LayoutMapOption : int8_t {
+    InferLayoutMap = 0,
+    IdentityLayoutMap = 1,
+    FullyDynamicLayoutMap = 2
+  };
+
+  BufferizationOptions();
+
   /// Try to cast the given op to BufferizableOpInterface if the op is allow
   /// listed.
   BufferizableOpInterface dynCastBufferizableOp(Operation *op) const;
@@ -173,6 +196,13 @@ struct BufferizationOptions {
   /// listed.
   BufferizableOpInterface dynCastBufferizableOp(Value value) const;
 
+  /// A filter that specifies which ops should be bufferized and which ops
+  /// should be ignored.
+  OpFilter opFilter;
+
+  /// Return `true` if the given op should be bufferized.
+  bool isOpAllowed(Operation *op) const;
+
   /// Helper functions for allocation, deallocation, memory copying.
   Optional<AllocationFn> allocationFn;
   Optional<DeallocationFn> deallocationFn;
@@ -276,12 +306,6 @@ struct BufferizationOptions {
   /// Buffer alignment for new memory allocations.
   unsigned int bufferAlignment = 128;
 
-  /// A list of op filters that determine whether an op should be processed or
-  /// ignored by the bufferization. If the filter has an ALLOW rule, only ops
-  /// that are allowed and not denied are bufferized. If the filter does not
-  /// have an ALLOW rule, only ops that are not denied are bufferized.
-  SmallVector<OpFilterEntry> opFilter;
-
   /// Initializer functions for analysis state. These can be used to
   /// initialize dialect-specific analysis state.
   SmallVector<AnalysisStateInitFn> stateInitializers;
@@ -289,29 +313,6 @@ struct BufferizationOptions {
   /// Add a analysis state initializer that initializes the specified
   /// dialect-specific analysis state.
   void addDialectStateInitializer(StringRef name, const DialectStateInitFn &fn);
-
-private:
-  /// Allow a dialect.
-  template <typename DialectT>
-  void allowDialectInFilterImpl() {
-    allowDialectInFilter(DialectT::getDialectNamespace());
-  }
-
-  /// Deny a dialect.
-  template <typename DialectT> void denyDialectInFilterImpl() {
-    denyDialectInFilter(DialectT::getDialectNamespace());
-  }
-
-  /// Allow an op.
-  template <typename OpTy>
-  void allowOperationInFilterImpl() {
-    allowOperationInFilter(OpTy::getOperationName());
-  }
-
-  /// Deny an op.
-  template <typename OpTy> void denyOperationInFilterImpl() {
-    denyOperationInFilter(OpTy::getOperationName());
-  }
 };
 
 /// Specify fine-grain relationship between buffers to enable more analysis.

diff  --git a/mlir/lib/Dialect/Arithmetic/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/Bufferize.cpp
index 4fb3d0e03ff1e..3237b1a880441 100644
--- a/mlir/lib/Dialect/Arithmetic/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Arithmetic/Transforms/Bufferize.cpp
@@ -32,9 +32,9 @@ struct ArithmeticBufferizePass
   void runOnOperation() override {
     BufferizationOptions options = getPartialBufferizationOptions();
     if (constantOpOnly) {
-      options.allowOperationInFilter<arith::ConstantOp>();
+      options.opFilter.allowOperation<arith::ConstantOp>();
     } else {
-      options.allowDialectInFilter<arith::ArithmeticDialect>();
+      options.opFilter.allowDialect<arith::ArithmeticDialect>();
     }
     options.bufferAlignment = alignment;
 

diff  --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 79b499545ca9f..635e7d770e60b 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -45,28 +45,19 @@ static const char *kBufferAllocationAttr = "bufferization.allocation";
 static const char *kSkipDeallocAttr = "bufferization.skip_dealloc";
 
 //===----------------------------------------------------------------------===//
-// BufferizationOptions
+// OpFilter
 //===----------------------------------------------------------------------===//
 
-// Default constructor for BufferizationOptions.
-BufferizationOptions::BufferizationOptions() = default;
-
-bool BufferizationOptions::isOpAllowed(Operation *op) const {
-  // Special case: If function boundary bufferization is deactivated, do not
-  // allow ops that belong to the `func` dialect.
-  bool isFuncBoundaryOp = isa_and_nonnull<func::FuncDialect>(op->getDialect());
-  if (!bufferizeFunctionBoundaries && isFuncBoundaryOp)
-    return false;
-
+bool OpFilter::isOpAllowed(Operation *op) const {
   // All other ops: Allow/disallow according to filter.
-  bool isAllowed = !filterHasAllowRule();
-  for (const OpFilterEntry &entry : opFilter) {
+  bool isAllowed = !hasAllowRule();
+  for (const Entry &entry : entries) {
     bool filterResult = entry.fn(op);
     switch (entry.type) {
-    case OpFilterEntry::ALLOW:
+    case Entry::ALLOW:
       isAllowed |= filterResult;
       break;
-    case OpFilterEntry::DENY:
+    case Entry::DENY:
       if (filterResult)
         // DENY filter matches. This op is no allowed. (Even if other ALLOW
         // filters may match.)
@@ -76,6 +67,23 @@ bool BufferizationOptions::isOpAllowed(Operation *op) const {
   return isAllowed;
 }
 
+//===----------------------------------------------------------------------===//
+// BufferizationOptions
+//===----------------------------------------------------------------------===//
+
+// Default constructor for BufferizationOptions.
+BufferizationOptions::BufferizationOptions() = default;
+
+bool BufferizationOptions::isOpAllowed(Operation *op) const {
+  // Special case: If function boundary bufferization is deactivated, do not
+  // allow ops that belong to the `func` dialect.
+  bool isFuncBoundaryOp = isa_and_nonnull<func::FuncDialect>(op->getDialect());
+  if (!bufferizeFunctionBoundaries && isFuncBoundaryOp)
+    return false;
+
+  return opFilter.isOpAllowed(op);
+}
+
 BufferizableOpInterface
 BufferizationOptions::dynCastBufferizableOp(Operation *op) const {
   auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op);

diff  --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index 4f621979a5e35..d4b12bad308f0 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -194,7 +194,7 @@ struct OneShotBufferizePass
       opt.promoteBufferResultsToOutParams = promoteBufferResultsToOutParams;
       opt.unknownTypeConversion = parseLayoutMapOption(unknownTypeConversion);
 
-      BufferizationOptions::OpFilterEntry::FilterFn filterFn =
+      OpFilter::Entry::FilterFn filterFn =
           [&](Operation *op) {
             // Filter may be specified via options.
             if (this->dialectFilter.hasValue())
@@ -204,7 +204,7 @@ struct OneShotBufferizePass
             // No filter specified: All other ops are allowed.
             return true;
           };
-      opt.allowOperationInFilter(filterFn);
+      opt.opFilter.allowOperation(filterFn);
     } else {
       opt = *options;
     }
@@ -242,7 +242,7 @@ struct BufferizationBufferizePass
     : public BufferizationBufferizeBase<BufferizationBufferizePass> {
   void runOnOperation() override {
     BufferizationOptions options = getPartialBufferizationOptions();
-    options.allowDialectInFilter<BufferizationDialect>();
+    options.opFilter.allowDialect<BufferizationDialect>();
 
     if (failed(bufferizeOp(getOperation(), options)))
       signalPassFailure();

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
index 143cef1ce5cf0..63433ef50fe2e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
@@ -28,7 +28,7 @@ namespace {
 struct LinalgBufferizePass : public LinalgBufferizeBase<LinalgBufferizePass> {
   void runOnOperation() override {
     BufferizationOptions options = getPartialBufferizationOptions();
-    options.allowDialectInFilter<linalg::LinalgDialect>();
+    options.opFilter.allowDialect<linalg::LinalgDialect>();
 
     if (failed(bufferizeOp(getOperation(), options)))
       signalPassFailure();

diff  --git a/mlir/lib/Dialect/Shape/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Shape/Transforms/Bufferize.cpp
index b84c3b7e8a470..373552801673c 100644
--- a/mlir/lib/Dialect/Shape/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Shape/Transforms/Bufferize.cpp
@@ -23,7 +23,7 @@ namespace {
 struct ShapeBufferizePass : public ShapeBufferizeBase<ShapeBufferizePass> {
   void runOnOperation() override {
     BufferizationOptions options = getPartialBufferizationOptions();
-    options.allowDialectInFilter<shape::ShapeDialect>();
+    options.opFilter.allowDialect<shape::ShapeDialect>();
 
     if (failed(bufferizeOp(getOperation(), options)))
       signalPassFailure();

diff  --git a/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp
index c923398996912..75bfc878b9226 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp
@@ -30,7 +30,7 @@ namespace {
 struct TensorBufferizePass : public TensorBufferizeBase<TensorBufferizePass> {
   void runOnOperation() override {
     BufferizationOptions options = getPartialBufferizationOptions();
-    options.allowDialectInFilter<tensor::TensorDialect>();
+    options.opFilter.allowDialect<tensor::TensorDialect>();
 
     if (failed(bufferizeOp(getOperation(), options)))
       signalPassFailure();

diff  --git a/mlir/lib/Dialect/Vector/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Vector/Transforms/Bufferize.cpp
index 4ed2dd629c1b9..f98eeda1bde53 100644
--- a/mlir/lib/Dialect/Vector/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/Bufferize.cpp
@@ -27,7 +27,7 @@ namespace {
 struct VectorBufferizePass : public VectorBufferizeBase<VectorBufferizePass> {
   void runOnOperation() override {
     BufferizationOptions options = getPartialBufferizationOptions();
-    options.allowDialectInFilter<vector::VectorDialect>();
+    options.opFilter.allowDialect<vector::VectorDialect>();
 
     if (failed(bufferizeOp(getOperation(), options)))
       signalPassFailure();


        


More information about the Mlir-commits mailing list