[Mlir-commits] [mlir] 695c341 - [mlir][bufferize] Generalize filtering mechanism in BufferizationOptions

Matthias Springer llvmlistbot at llvm.org
Tue Feb 15 02:17:42 PST 2022


Author: Matthias Springer
Date: 2022-02-15T19:17:33+09:00
New Revision: 695c341b84d19459b8f71319344380e4d7025b39

URL: https://github.com/llvm/llvm-project/commit/695c341b84d19459b8f71319344380e4d7025b39
DIFF: https://github.com/llvm/llvm-project/commit/695c341b84d19459b8f71319344380e4d7025b39.diff

LOG: [mlir][bufferize] Generalize filtering mechanism in BufferizationOptions

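Support both ALLOW and DENY filters: when a filter is active, an op is bufferized only if at least one ALLOW filter matches it and no DENY filter does. This is needed for compatibility with existing code that specifies more complex op filters. The helpers `addToDialectFilter`/`addToOperationFilter` are renamed to `allowDialectInFilter`/`allowOperationInFilter`.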

Differential Revision: https://reviews.llvm.org/D119820

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
    mlir/lib/Dialect/Arithmetic/Transforms/Bufferize.cpp
    mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp
    mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index 714f75d09b965..257f67a2db440 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -45,32 +45,93 @@ struct BufferizationOptions {
   using MemCpyFn =
       std::function<LogicalResult(OpBuilder &, Location, Value, Value)>;
 
+  /// An op filter entry. Filters can be used to specify which ops should be
+  /// processed by the bufferization.
+  struct OpFilterEntry {
+    /// If the filter function evaluates to `true`, the filter matches.
+    using FilterFn = std::function<bool(Operation *)>;
+
+    /// Filter type: A filter can either be a DENY filter or an ALLOW filter.
+    enum FilterType : int8_t { DENY = 0, ALLOW = 1 };
+
+    FilterFn fn;
+    FilterType type;
+  };
+
   BufferizationOptions();
 
-  /// Return `true` if the op is allowed to be bufferized.
+  /// Return whether the op should be bufferized or not.
+  ///
+  /// If no filter is specified (`hasFilter` = false), every op will be
+  /// bufferized. Otherwise, an op is bufferized if:
+  ///
+  /// - At least one ALLOW filter says `true`.
+  /// - And, no DENY filter says `true`.
   bool isOpAllowed(Operation *op) const {
     if (!hasFilter)
       return true;
-    return dialectFilter.contains(op->getDialect()->getNamespace()) ||
-           operationFilter.contains(op->getName().getStringRef());
+    bool isAllowed = false;
+    for (const OpFilterEntry &entry : opFilter) {
+      bool filterResult = entry.fn(op);
+      switch (entry.type) {
+      case OpFilterEntry::ALLOW:
+        isAllowed |= filterResult;
+        break;
+      case OpFilterEntry::DENY:
+        if (filterResult)
+          // A DENY filter matches: this op is not allowed, even if some ALLOW
+          // filters match as well.
+          return false;
+      }
+    }
+    return isAllowed;
   }
 
   /// Allow the given dialects and activate the filter (`hasFilter`).
+  ///
+  /// This function adds one or multiple ALLOW filters.
   template <typename... DialectTs>
-  void addToDialectFilter() {
-    // The following expands a call to addToDialectFilterImpl for each dialect
+  void allowDialectInFilter() {
+    // The following expands a call to allowDialectInFilterImpl for each dialect
     // in 'DialectTs'. This magic is necessary due to a limitation in the places
     // that a parameter pack can be expanded in c++11.
     // FIXME: In c++17 this can be simplified by using 'fold expressions'.
     (void)std::initializer_list<int>{
-        0, (addToDialectFilterImpl<DialectTs>(), 0)...};
+        0, (allowDialectInFilterImpl<DialectTs>(), 0)...};
+  }
+
+  /// Allow the given dialect and activate the filter (`hasFilter`).
+  ///
+  /// This function adds an ALLOW filter (the namespace string is copied).
+  void allowDialectInFilter(StringRef dialectNamespace) {
+    hasFilter = true;
+    std::string dialectNamespaceStr = dialectNamespace.str();
+    OpFilterEntry::FilterFn filterFn = [=](Operation *op) {
+      return op->getName().getDialectNamespace() == dialectNamespaceStr;
+    };
+    opFilter.push_back(OpFilterEntry{filterFn, OpFilterEntry::ALLOW});
   }
 
   /// Allow the given ops and activate the filter (`hasFilter`).
-  template <typename... OpTys> void addToOperationFilter() {
+  ///
+  /// This function adds one or multiple ALLOW filters.
+  template <typename... OpTys>
+  void allowOperationInFilter() {
     // FIXME: In c++17 this can be simplified by using 'fold expressions'.
-    (void)std::initializer_list<int>{0,
-                                     (addToOperationFilterImpl<OpTys>(), 0)...};
+    (void)std::initializer_list<int>{
+        0, (allowOperationInFilterImpl<OpTys>(), 0)...};
+  }
+
+  /// Allow the given op and activate the filter (`hasFilter`).
+  ///
+  /// This function adds an ALLOW filter (the op name string is copied).
+  void allowOperationInFilter(StringRef opName) {
+    hasFilter = true;
+    std::string opNameStr = opName.str();
+    OpFilterEntry::FilterFn filterFn = [=](Operation *op) {
+      return op->getName().getStringRef() == opNameStr;
+    };
+    opFilter.push_back(OpFilterEntry{filterFn, OpFilterEntry::ALLOW});
   }
 
   /// Try to cast the given op to BufferizableOpInterface if the op is allow
@@ -118,33 +179,26 @@ struct BufferizationOptions {
   /// Buffer alignment for new memory allocations.
   unsigned int bufferAlignment = 128;
 
-  /// If set to `true`, only ops that belong to a filtered dialect
-  /// (`dialectFilter`) and filtered ops (`operationFilter`) are processed. All
-  /// other ops are ignored. If set to `false`, all ops are bufferized (as long
-  /// as they implement BufferizableOpInterface).
-  ///
-  /// If a filter is specified, `allowUnknownOps` should be enabled. Otherwise,
-  /// bufferization would fail when encountering a non-filtered op.
+  /// If set to `false`, all ops are bufferized (as long as they implement
+  /// BufferizableOpInterface). Otherwise, only filtered ops are bufferized.
   bool hasFilter = false;
 
-  /// A set of allowed dialects.
-  DenseSet<StringRef> dialectFilter;
-
-  /// A set of allowed ops.
-  DenseSet<StringRef> operationFilter;
+  /// A list of op filters that determine whether an op should be processed or
+  /// ignored by the bufferization. If `hasFilter`, only ops that are not
+  /// DENY-filtered and have at least one matching ALLOW filter are processed.
+  SmallVector<OpFilterEntry> opFilter;
 
 private:
   /// Allow a dialect.
   template <typename DialectT>
-  void addToDialectFilterImpl() {
-    hasFilter = true;
-    dialectFilter.insert(DialectT::getDialectNamespace());
+  void allowDialectInFilterImpl() {
+    allowDialectInFilter(DialectT::getDialectNamespace());
   }
 
   /// Allow an op.
-  template <typename OpTy> void addToOperationFilterImpl() {
-    hasFilter = true;
-    operationFilter.insert(OpTy::getOperationName());
+  template <typename OpTy>
+  void allowOperationInFilterImpl() {
+    allowOperationInFilter(OpTy::getOperationName());
   }
 };
 

diff  --git a/mlir/lib/Dialect/Arithmetic/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/Bufferize.cpp
index d88ee4a65e26c..bc32daf75bc3f 100644
--- a/mlir/lib/Dialect/Arithmetic/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Arithmetic/Transforms/Bufferize.cpp
@@ -31,9 +31,9 @@ struct ArithmeticBufferizePass
   void runOnOperation() override {
     BufferizationOptions options = getPartialBufferizationOptions();
     if (constantOpOnly) {
-      options.addToOperationFilter<arith::ConstantOp>();
+      options.allowOperationInFilter<arith::ConstantOp>();
     } else {
-      options.addToDialectFilter<arith::ArithmeticDialect>();
+      options.allowDialectInFilter<arith::ArithmeticDialect>();
     }
     options.bufferAlignment = alignment;
 

diff  --git a/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp
index 7d8ee6ff3ee67..41b11ea197b22 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp
@@ -31,7 +31,7 @@ namespace {
 struct TensorBufferizePass : public TensorBufferizeBase<TensorBufferizePass> {
   void runOnOperation() override {
     BufferizationOptions options = getPartialBufferizationOptions();
-    options.addToDialectFilter<tensor::TensorDialect>();
+    options.allowDialectInFilter<tensor::TensorDialect>();
 
     if (failed(bufferizeOp(getOperation(), options)))
       signalPassFailure();

diff  --git a/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp b/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp
index 28535ba8248a9..f0b6b0e669ec4 100644
--- a/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp
@@ -116,7 +116,7 @@ void TestComprehensiveFunctionBufferize::runOnOperation() {
   if (dialectFilter.hasValue()) {
     options->hasFilter = true;
     for (const std::string &dialectNamespace : dialectFilter)
-      options->dialectFilter.insert(dialectNamespace);
+      options->allowDialectInFilter(dialectNamespace);
   }
 
   Operation *op = getOperation();


        


More information about the Mlir-commits mailing list