[Mlir-commits] [mlir] d6dab38 - [mlir][bufferize][NFC] Add function boundary bufferization flag to BufferizationOptions

Matthias Springer llvmlistbot at llvm.org
Fri Apr 22 09:11:48 PDT 2022


Author: Matthias Springer
Date: 2022-04-23T01:11:37+09:00
New Revision: d6dab38ae48a298b671107c9ebdd86de8f7cd482

URL: https://github.com/llvm/llvm-project/commit/d6dab38ae48a298b671107c9ebdd86de8f7cd482
DIFF: https://github.com/llvm/llvm-project/commit/d6dab38ae48a298b671107c9ebdd86de8f7cd482.diff

LOG: [mlir][bufferize][NFC] Add function boundary bufferization flag to BufferizationOptions

This makes the API easier to use. It also allows us to detect incorrect API usage, which makes debugging easier.

Differential Revision: https://reviews.llvm.org/D124265

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
    mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
    mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
    mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
    mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
    mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index c228e26d1198e..94504c1e9b104 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -78,23 +78,7 @@ struct BufferizationOptions {
   /// unless they are explicitly marked as DENY. If the filter has at least one
   /// ALLOW rule, ops are ignored by default and only bufferized if they match
   /// an ALLOW rule and no DENY rule.
-  bool isOpAllowed(Operation *op) const {
-    bool isAllowed = !filterHasAllowRule();
-    for (const OpFilterEntry &entry : opFilter) {
-      bool filterResult = entry.fn(op);
-      switch (entry.type) {
-      case OpFilterEntry::ALLOW:
-        isAllowed |= filterResult;
-        break;
-      case OpFilterEntry::DENY:
-        if (filterResult)
-          // DENY filter matches. This op is no allowed. (Even if other ALLOW
-          // filters may match.)
-          return false;
-      };
-    }
-    return isAllowed;
-  }
+  bool isOpAllowed(Operation *op) const;
 
   /// Allow the given dialects in the filter.
   ///
@@ -182,6 +166,10 @@ struct BufferizationOptions {
   /// the boundaries.
   bool allowUnknownOps = false;
 
+  /// Specifies whether function boundaries (ops in the func dialect) should be
+  /// bufferized or not.
+  bool bufferizeFunctionBoundaries = false;
+
   /// Specifies whether dealloc ops should be generated along with alloc ops. If
   /// not, new memory allocations will leak.
   bool createDeallocs = true;
@@ -356,6 +344,12 @@ class AnalysisState {
   /// any given tensor.
   virtual bool isTensorYielded(Value tensor) const = 0;
 
+  /// Return `true` if the given dialect state exists.
+  bool hasDialectState(StringRef name) const {
+    auto it = dialectState.find(name);
+    return it != dialectState.end();
+  }
+
   /// Return dialect-specific bufferization state.
   template <typename StateT>
   Optional<const StateT *> getDialectState(StringRef name) const {
@@ -369,7 +363,7 @@ class AnalysisState {
   template <typename StateT>
   StateT &getOrCreateDialectState(StringRef name) {
     // Create state if it does not exist yet.
-    if (!dialectState.count(name))
+    if (!hasDialectState(name))
       dialectState[name] = std::make_unique<StateT>();
     return static_cast<StateT &>(*dialectState[name]);
   }

diff  --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 921c6c80f6746..73da6c85e761f 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -51,6 +51,31 @@ static const char *kSkipDeallocAttr = "bufferization.skip_dealloc";
 // Default constructor for BufferizationOptions.
 BufferizationOptions::BufferizationOptions() = default;
 
+bool BufferizationOptions::isOpAllowed(Operation *op) const {
+  // Special case: If function boundary bufferization is deactivated, do not
+  // allow ops that belong to the `func` dialect.
+  bool isFuncBoundaryOp = isa_and_nonnull<func::FuncDialect>(op->getDialect());
+  if (!bufferizeFunctionBoundaries && isFuncBoundaryOp)
+    return false;
+
+  // All other ops: Allow/disallow according to filter.
+  bool isAllowed = !filterHasAllowRule();
+  for (const OpFilterEntry &entry : opFilter) {
+    bool filterResult = entry.fn(op);
+    switch (entry.type) {
+    case OpFilterEntry::ALLOW:
+      isAllowed |= filterResult;
+      break;
+    case OpFilterEntry::DENY:
+      if (filterResult)
+        // DENY filter matches. This op is no allowed. (Even if other ALLOW
+        // filters may match.)
+        return false;
+    };
+  }
+  return isAllowed;
+}
+
 BufferizableOpInterface
 BufferizationOptions::dynCastBufferizableOp(Operation *op) const {
   if (isOpAllowed(op))

diff  --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index 49aaeb76a0381..0936512df33a8 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -175,15 +175,10 @@ struct OneShotBufferizePass
       opt.fullyDynamicLayoutMaps = fullyDynamicLayoutMaps;
       opt.printConflicts = printConflicts;
       opt.testAnalysisOnly = testAnalysisOnly;
+      opt.bufferizeFunctionBoundaries = bufferizeFunctionBoundaries;
 
       BufferizationOptions::OpFilterEntry::FilterFn filterFn =
           [&](Operation *op) {
-            // Disallow non-func dialect ops. I.e., no ops related to function
-            // calls. (Unless explicitly activated.)
-            bool isFuncBoundaryOp =
-                isa_and_nonnull<func::FuncDialect>(op->getDialect());
-            if (!this->bufferizeFunctionBoundaries && isFuncBoundaryOp)
-              return false;
             // Filter may be specified via options.
             if (this->dialectFilter.hasValue())
               return llvm::find(this->dialectFilter,
@@ -198,7 +193,7 @@ struct OneShotBufferizePass
     }
 
     ModuleOp moduleOp = getOperation();
-    if (bufferizeFunctionBoundaries) {
+    if (opt.bufferizeFunctionBoundaries) {
       if (failed(runOneShotModuleBufferize(moduleOp, opt))) {
         signalPassFailure();
         return;
@@ -284,6 +279,12 @@ bufferization::finalizeBuffers(Operation *op,
 
 LogicalResult bufferization::bufferizeOp(Operation *op,
                                          const AnalysisState &analysisState) {
+  // Catch incorrect API usage.
+  assert((analysisState.hasDialectState(
+              func::FuncDialect::getDialectNamespace()) ||
+          !analysisState.getOptions().bufferizeFunctionBoundaries) &&
+         "must use ModuleBufferize to bufferize function boundaries");
+
   BufferizationState bufferizationState(analysisState);
   if (failed(bufferizeOp(op, bufferizationState)))
     return failure();

diff  --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
index 8ae5c1c8cace8..d1fbb70f889c8 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
@@ -46,6 +46,7 @@
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/IR/AsmState.h"
 #include "mlir/IR/Dominance.h"
@@ -864,6 +865,11 @@ LogicalResult bufferization::analyzeOp(Operation *op,
   const auto &options =
       static_cast<const OneShotBufferizationOptions &>(state.getOptions());
 
+  // Catch incorrect API usage.
+  assert((state.hasDialectState(func::FuncDialect::getDialectNamespace()) ||
+          !options.bufferizeFunctionBoundaries) &&
+         "must use ModuleBufferize to bufferize function boundaries");
+
   if (failed(checkAliasInfoConsistency(op, domInfo, state, aliasInfo)))
     return failure();
 

diff  --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
index 7bcde9fd1be43..6dc3432f46a63 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -417,6 +417,8 @@ static void foldMemRefCasts(func::FuncOp funcOp) {
 
 LogicalResult mlir::bufferization::runOneShotModuleBufferize(
     ModuleOp moduleOp, OneShotBufferizationOptions options) {
+  assert(options.bufferizeFunctionBoundaries &&
+         "expected that function boundary bufferization is activated");
   IRRewriter rewriter(moduleOp.getContext());
   OneShotAnalysisState analysisState(moduleOp, options);
   BufferizationState bufferizationState(analysisState);

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
index b9163c5646ee1..13b760c6dd959 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
@@ -99,6 +99,7 @@ void LinalgComprehensiveModuleBufferize::runOnOperation() {
     opt.printConflicts = printConflicts;
     opt.testAnalysisOnly = testAnalysisOnly;
     opt.alwaysAliasingWithDest = alwaysAliasingWithDest;
+    opt.bufferizeFunctionBoundaries = true;
     if (initTensorElimination) {
       opt.addPostAnalysisStep(insertSliceAnchoredInitTensorEliminationStep);
     }


        


More information about the Mlir-commits mailing list