[Mlir-commits] [mlir] 9cf9685 - [mlir][bufferize] Add noAnalysisFuncFilter to OneShotBufferizationOptions struct

Maya Amrami llvmlistbot at llvm.org
Tue Feb 7 02:38:36 PST 2023


Author: Maya Amrami
Date: 2023-02-07T12:38:30+02:00
New Revision: 9cf96850c3977adbb28cacdd2a1354616179c956

URL: https://github.com/llvm/llvm-project/commit/9cf96850c3977adbb28cacdd2a1354616179c956
DIFF: https://github.com/llvm/llvm-project/commit/9cf96850c3977adbb28cacdd2a1354616179c956.diff

LOG: [mlir][bufferize] Add noAnalysisFuncFilter to OneShotBufferizationOptions struct

This change makes it possible to set the flag when the pass is run programmatically rather than from the command line.
It also simplifies the signatures of some functions.

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D143416

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
    mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h
    mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
    mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
index 66e1e8e0355ff..63a8b389e0d5f 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
@@ -11,6 +11,7 @@
 
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
 #include "llvm/ADT/EquivalenceClasses.h"
+#include <string>
 
 namespace mlir {
 namespace bufferization {
@@ -33,6 +34,10 @@ struct OneShotBufferizationOptions : public BufferizationOptions {
   /// The heuristic controls the order in which ops are traversed during the
   /// analysis.
   AnalysisHeuristic analysisHeuristic = AnalysisHeuristic::BottomUp;
+
+  /// Specify the functions that should not be analyzed. copyBeforeWrite will be
+  /// set to true when bufferizing them.
+  llvm::ArrayRef<std::string> noAnalysisFuncFilter;
 };
 
 /// The BufferizationAliasInfo class maintains a list of buffer aliases and

diff  --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h
index c6f3e6a646182..cd4c009a05bc9 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h
@@ -9,7 +9,6 @@
 #ifndef MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_ONESHOTMODULEBUFFERIZE_H
 #define MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_ONESHOTMODULEBUFFERIZE_H
 
-#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
 namespace mlir {
 
 struct LogicalResult;
@@ -31,13 +30,12 @@ LogicalResult analyzeModuleOp(ModuleOp moduleOp, OneShotAnalysisState &state,
 /// inserted except two cases:
 /// - `options.copyBeforeWrite` is set, in which case buffers are copied before
 ///   every write.
-/// - `options.copyBeforeWrite` is not set and `analysisFilterFn` returns true
-///   for some FuncOps. These FuncOps were not analyzed. Buffer copies will be
-///   inserted only to these FuncOps.
-LogicalResult
-bufferizeModuleOp(ModuleOp moduleOp, const OneShotBufferizationOptions &options,
-                  BufferizationStatistics *statistics = nullptr,
-                  OpFilter::Entry::FilterFn analysisFilterFn = nullptr);
+/// - `options.copyBeforeWrite` is not set and `options.noAnalysisFuncFilter`
+///   is not empty. The FuncOps it contains were not analyzed. Buffer copies
+///   will be inserted only to these FuncOps.
+LogicalResult bufferizeModuleOp(ModuleOp moduleOp,
+                                const OneShotBufferizationOptions &options,
+                                BufferizationStatistics *statistics = nullptr);
 
 /// Remove bufferization attributes on every FuncOp arguments in the ModuleOp.
 void removeBufferizationAttributesInModule(ModuleOp moduleOp);
@@ -49,8 +47,7 @@ void removeBufferizationAttributesInModule(ModuleOp moduleOp);
 LogicalResult runOneShotModuleBufferize(
     ModuleOp moduleOp,
     const bufferization::OneShotBufferizationOptions &options,
-    BufferizationStatistics *statistics = nullptr,
-    OpFilter::Entry::FilterFn analysisFilterFn = nullptr);
+    BufferizationStatistics *statistics = nullptr);
 
 } // namespace bufferization
 } // namespace mlir

diff  --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index 9ed6c52277279..3ec037069c2c2 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -215,6 +215,7 @@ struct OneShotBufferizePass
       opt.printConflicts = printConflicts;
       opt.testAnalysisOnly = testAnalysisOnly;
       opt.bufferizeFunctionBoundaries = bufferizeFunctionBoundaries;
+      opt.noAnalysisFuncFilter = noAnalysisFuncFilter;
 
       // Configure type converter.
       LayoutMapOption unknownTypeConversionOption =
@@ -249,25 +250,12 @@ struct OneShotBufferizePass
     BufferizationStatistics statistics;
     ModuleOp moduleOp = getOperation();
     if (opt.bufferizeFunctionBoundaries) {
-      OpFilter::Entry::FilterFn analysisFilterFn = nullptr;
-      // FuncOps whose names are specified in noAnalysisFuncFilter will not be
-      // analyzed. Ops in these FuncOps will not be analyzed as well.
-      if (this->noAnalysisFuncFilter.hasValue())
-        analysisFilterFn = [=](Operation *op) {
-          auto func = dyn_cast<func::FuncOp>(op);
-          if (!func)
-            func = op->getParentOfType<func::FuncOp>();
-          if (func)
-            return llvm::is_contained(noAnalysisFuncFilter, func.getSymName());
-          return false;
-        };
-      if (failed(runOneShotModuleBufferize(moduleOp, opt, &statistics,
-                                           analysisFilterFn))) {
+      if (failed(runOneShotModuleBufferize(moduleOp, opt, &statistics))) {
         signalPassFailure();
         return;
       }
     } else {
-      assert(!this->noAnalysisFuncFilter.hasValue() &&
+      assert(opt.noAnalysisFuncFilter.empty() &&
              "invalid combination of bufferization flags");
       if (failed(runOneShotBufferize(moduleOp, opt, &statistics))) {
         signalPassFailure();

diff  --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
index 01818d749b95b..943efe8e18b70 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -413,8 +413,7 @@ void mlir::bufferization::removeBufferizationAttributesInModule(
 
 LogicalResult mlir::bufferization::bufferizeModuleOp(
     ModuleOp moduleOp, const OneShotBufferizationOptions &options,
-    BufferizationStatistics *statistics,
-    OpFilter::Entry::FilterFn analysisFilterFn) {
+    BufferizationStatistics *statistics) {
   assert(options.bufferizeFunctionBoundaries &&
          "expected that function boundary bufferization is activated");
   IRRewriter rewriter(moduleOp.getContext());
@@ -432,8 +431,9 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
   for (func::FuncOp funcOp : orderedFuncOps) {
     // Note: It would be good to apply cleanups here but we cannot as aliasInfo
     // would be invalidated.
-    bool copyBeforeWrite = options.copyBeforeWrite ||
-                           (analysisFilterFn && analysisFilterFn(funcOp));
+    bool copyBeforeWrite =
+        options.copyBeforeWrite ||
+        llvm::is_contained(options.noAnalysisFuncFilter, funcOp.getSymName());
     if (failed(bufferizeOp(funcOp, options, copyBeforeWrite,
                            /*opFilter=*/nullptr, statistics)))
       return failure();
@@ -451,17 +451,27 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
 
 LogicalResult mlir::bufferization::runOneShotModuleBufferize(
     ModuleOp moduleOp, const OneShotBufferizationOptions &options,
-    BufferizationStatistics *statistics,
-    OpFilter::Entry::FilterFn analysisFilterFn) {
+    BufferizationStatistics *statistics) {
   assert(options.bufferizeFunctionBoundaries &&
          "expected that function boundary bufferization is activated");
   assert(!(options.copyBeforeWrite && options.testAnalysisOnly) &&
          "invalid combination of bufferization flags");
   if (!options.copyBeforeWrite) {
-    if (!analysisFilterFn) {
+    if (options.noAnalysisFuncFilter.empty()) {
       if (failed(insertTensorCopies(moduleOp, options, statistics)))
         return failure();
     } else {
+      // FuncOps whose names are specified in options.noAnalysisFuncFilter will
+      // not be analyzed. Ops in these FuncOps will not be analyzed as well.
+      OpFilter::Entry::FilterFn analysisFilterFn = [=](Operation *op) {
+        auto func = dyn_cast<func::FuncOp>(op);
+        if (!func)
+          func = op->getParentOfType<func::FuncOp>();
+        if (func)
+          return llvm::is_contained(options.noAnalysisFuncFilter,
+                                    func.getSymName());
+        return false;
+      };
       OneShotBufferizationOptions updatedOptions(options);
       updatedOptions.opFilter.denyOperation(analysisFilterFn);
       if (failed(insertTensorCopies(moduleOp, updatedOptions, statistics)))
@@ -470,9 +480,7 @@ LogicalResult mlir::bufferization::runOneShotModuleBufferize(
   }
   if (options.testAnalysisOnly)
     return success();
-
-  if (failed(
-          bufferizeModuleOp(moduleOp, options, statistics, analysisFilterFn)))
+  if (failed(bufferizeModuleOp(moduleOp, options, statistics)))
     return failure();
   return success();
 }


        


More information about the Mlir-commits mailing list