[Mlir-commits] [mlir] 060c8be - [mlir][OneShotModuleBufferize] Add a new flag: no-analysis-func-filter

Maya Amrami llvmlistbot at llvm.org
Tue Jan 31 01:26:23 PST 2023


Author: Maya Amrami
Date: 2023-01-31T11:26:15+02:00
New Revision: 060c8be51bdfd61734e06d9df713f37d354207a8

URL: https://github.com/llvm/llvm-project/commit/060c8be51bdfd61734e06d9df713f37d354207a8
DIFF: https://github.com/llvm/llvm-project/commit/060c8be51bdfd61734e06d9df713f37d354207a8.diff

LOG:  [mlir][OneShotModuleBufferize] Add a new flag: no-analysis-func-filter

OneShotModuleBufferize fails if the input IR cannot be analyzed.
One can set CopyBeforeWrite=true in order to skip the analysis.
In that case, a buffer copy is inserted before every write.
This leads to many copies, even in FuncOps that could have been analyzed.

This change aims to copy buffers only where necessary.
When running OneShotModuleBufferize with CopyBeforeWrite=false,
FuncOps whose names are specified in noAnalysisFuncFilter are not
analyzed, and neither are the ops nested inside them.
Those ops are bufferized with CopyBeforeWrite=true,
while all other ops are bufferized with CopyBeforeWrite=false.

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D142631

Added: 
    mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-force-copy-before-write.mlir

Modified: 
    mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h
    mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
    mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
    mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h
index a69e1c2da04ca..c6f3e6a646182 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h
@@ -9,6 +9,7 @@
 #ifndef MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_ONESHOTMODULEBUFFERIZE_H
 #define MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_ONESHOTMODULEBUFFERIZE_H
 
+#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
 namespace mlir {
 
 struct LogicalResult;
@@ -27,11 +28,16 @@ LogicalResult analyzeModuleOp(ModuleOp moduleOp, OneShotAnalysisState &state,
 /// Bufferize `op` and its nested ops that implement `BufferizableOpInterface`.
 ///
 /// Note: This function does not run One-Shot Analysis. No buffer copies are
-/// inserted unless `options.copyBeforeWrite` is set, in which case buffers are
-/// copied before every write.
-LogicalResult bufferizeModuleOp(ModuleOp moduleOp,
-                                const OneShotBufferizationOptions &options,
-                                BufferizationStatistics *statistics = nullptr);
+/// inserted except two cases:
+/// - `options.copyBeforeWrite` is set, in which case buffers are copied before
+///   every write.
+/// - `options.copyBeforeWrite` is not set and `analysisFilterFn` returns true
+///   for some FuncOps. These FuncOps were not analyzed. Buffer copies will be
+///   inserted only to these FuncOps.
+LogicalResult
+bufferizeModuleOp(ModuleOp moduleOp, const OneShotBufferizationOptions &options,
+                  BufferizationStatistics *statistics = nullptr,
+                  OpFilter::Entry::FilterFn analysisFilterFn = nullptr);
 
 /// Remove bufferization attributes on every FuncOp arguments in the ModuleOp.
 void removeBufferizationAttributesInModule(ModuleOp moduleOp);
@@ -43,7 +49,8 @@ void removeBufferizationAttributesInModule(ModuleOp moduleOp);
 LogicalResult runOneShotModuleBufferize(
     ModuleOp moduleOp,
     const bufferization::OneShotBufferizationOptions &options,
-    BufferizationStatistics *statistics = nullptr);
+    BufferizationStatistics *statistics = nullptr,
+    OpFilter::Entry::FilterFn analysisFilterFn = nullptr);
 
 } // namespace bufferization
 } // namespace mlir

diff  --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
index cffe3bcb5cfbf..20b35f89ae269 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
@@ -297,6 +297,9 @@ def OneShotBufferize : Pass<"one-shot-bufferize", "ModuleOp"> {
            "core bufferization passes.">,
     ListOption<"dialectFilter", "dialect-filter", "std::string",
                "Restrict bufferization to ops from these dialects.">,
+    ListOption<"noAnalysisFuncFilter", "no-analysis-func-filter", "std::string",
+               "Skip analysis of functions with these symbol names."
+               "Set copyBeforeWrite to true when bufferizing them.">,
     Option<"functionBoundaryTypeConversion",
            "function-boundary-type-conversion", "std::string",
            /*default=*/"\"infer-layout-map\"",

diff  --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index fbed0a77306cc..9ed6c52277279 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -249,11 +249,26 @@ struct OneShotBufferizePass
     BufferizationStatistics statistics;
     ModuleOp moduleOp = getOperation();
     if (opt.bufferizeFunctionBoundaries) {
-      if (failed(runOneShotModuleBufferize(moduleOp, opt, &statistics))) {
+      OpFilter::Entry::FilterFn analysisFilterFn = nullptr;
+      // FuncOps whose names are specified in noAnalysisFuncFilter will not be
+      // analyzed. Ops in these FuncOps will not be analyzed as well.
+      if (this->noAnalysisFuncFilter.hasValue())
+        analysisFilterFn = [=](Operation *op) {
+          auto func = dyn_cast<func::FuncOp>(op);
+          if (!func)
+            func = op->getParentOfType<func::FuncOp>();
+          if (func)
+            return llvm::is_contained(noAnalysisFuncFilter, func.getSymName());
+          return false;
+        };
+      if (failed(runOneShotModuleBufferize(moduleOp, opt, &statistics,
+                                           analysisFilterFn))) {
         signalPassFailure();
         return;
       }
     } else {
+      assert(!this->noAnalysisFuncFilter.hasValue() &&
+             "invalid combination of bufferization flags");
       if (failed(runOneShotBufferize(moduleOp, opt, &statistics))) {
         signalPassFailure();
         return;

diff  --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
index f1211bb878f2b..01818d749b95b 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -378,6 +378,9 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
 
   // Analyze ops.
   for (func::FuncOp funcOp : orderedFuncOps) {
+    if (!state.getOptions().isOpAllowed(funcOp))
+      continue;
+
     // Now analyzing function.
     funcState.startFunctionAnalysis(funcOp);
 
@@ -410,7 +413,8 @@ void mlir::bufferization::removeBufferizationAttributesInModule(
 
 LogicalResult mlir::bufferization::bufferizeModuleOp(
     ModuleOp moduleOp, const OneShotBufferizationOptions &options,
-    BufferizationStatistics *statistics) {
+    BufferizationStatistics *statistics,
+    OpFilter::Entry::FilterFn analysisFilterFn) {
   assert(options.bufferizeFunctionBoundaries &&
          "expected that function boundary bufferization is activated");
   IRRewriter rewriter(moduleOp.getContext());
@@ -428,7 +432,9 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
   for (func::FuncOp funcOp : orderedFuncOps) {
     // Note: It would be good to apply cleanups here but we cannot as aliasInfo
     // would be invalidated.
-    if (failed(bufferizeOp(funcOp, options, options.copyBeforeWrite,
+    bool copyBeforeWrite = options.copyBeforeWrite ||
+                           (analysisFilterFn && analysisFilterFn(funcOp));
+    if (failed(bufferizeOp(funcOp, options, copyBeforeWrite,
                            /*opFilter=*/nullptr, statistics)))
       return failure();
     // Change buffer return types to more precise layout maps.
@@ -445,18 +451,28 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
 
 LogicalResult mlir::bufferization::runOneShotModuleBufferize(
     ModuleOp moduleOp, const OneShotBufferizationOptions &options,
-    BufferizationStatistics *statistics) {
+    BufferizationStatistics *statistics,
+    OpFilter::Entry::FilterFn analysisFilterFn) {
   assert(options.bufferizeFunctionBoundaries &&
          "expected that function boundary bufferization is activated");
   assert(!(options.copyBeforeWrite && options.testAnalysisOnly) &&
          "invalid combination of bufferization flags");
   if (!options.copyBeforeWrite) {
-    if (failed(insertTensorCopies(moduleOp, options, statistics)))
-      return failure();
+    if (!analysisFilterFn) {
+      if (failed(insertTensorCopies(moduleOp, options, statistics)))
+        return failure();
+    } else {
+      OneShotBufferizationOptions updatedOptions(options);
+      updatedOptions.opFilter.denyOperation(analysisFilterFn);
+      if (failed(insertTensorCopies(moduleOp, updatedOptions, statistics)))
+        return failure();
+    }
   }
   if (options.testAnalysisOnly)
     return success();
-  if (failed(bufferizeModuleOp(moduleOp, options, statistics)))
+
+  if (failed(
+          bufferizeModuleOp(moduleOp, options, statistics, analysisFilterFn)))
     return failure();
   return success();
 }

diff  --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-force-copy-before-write.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-force-copy-before-write.mlir
new file mode 100644
index 0000000000000..e5723f978f4ea
--- /dev/null
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-force-copy-before-write.mlir
@@ -0,0 +1,39 @@
+// RUN: mlir-opt %s -one-shot-bufferize="bufferize-function-boundaries=1 no-analysis-func-filter=contains_to_memref_op" -drop-equivalent-buffer-results --split-input-file | FileCheck %s
+
+// ToMemref ops do not pass analysis step. CopyBeforeWrite will be true only for the
+// FuncOp "contains_to_memref_op" since it is specified in no-analysis-func-filter.
+
+module {
+  // CHECK-LABEL:   func.func @foo(
+  // CHECK-SAME:                   %[[arg0:.*]]: memref<?xf32, strided<[?], offset: ?>>) {
+  func.func @foo(%arg0: tensor<?xf32>) -> tensor<?xf32> {
+  // CHECK-NEXT:      %[[c0:.*]] = arith.constant 0 : index
+    %cst = arith.constant 1.000000e+00 : f32
+
+  // CHECK-NEXT:      %[[c1:.*]] = arith.constant 1.000000e+00 : f32
+    %c0 = arith.constant 0 : index
+
+  // CHECK-NEXT:      memref.store %[[c1]], %[[arg0]]{{\[}}%[[c0]]] : memref<?xf32, strided<[?], offset: ?>>
+    %inserted = tensor.insert %cst into %arg0[%c0] : tensor<?xf32>
+
+    return %inserted : tensor<?xf32>
+  }
+
+  // CHECK-LABEL:   func.func @contains_to_memref_op(
+  // CHECK-SAME:                                     %[[arg0:.*]]: memref<?xf32, strided<[?], offset: ?>>,
+  // CHECK-SAME:                                     %[[arg1:.*]]: index) -> vector<5xf32> {
+  func.func @contains_to_memref_op(%arg0: tensor<?xf32> {bufferization.writable = true}, %arg1: index) -> vector<5xf32> {
+
+    %0 = bufferization.to_memref %arg0 : memref<?xf32>
+
+    // CHECK:           %[[c0:.*]] = arith.constant 0 : index
+    %cst = arith.constant 0.000000e+00 : f32
+
+    // CHECK:           %[[dim:.*]] = memref.dim %[[arg0]], %[[c0]] : memref<?xf32, strided<[?], offset: ?>>
+    // CHECK:           %[[alloc:.*]] = memref.alloc(%[[dim]]) : memref<?xf32>
+    // CHECK:           memref.copy %[[arg0]], %[[alloc]] : memref<?xf32, strided<[?], offset: ?>> to memref<?xf32>
+    // CHECK:           vector.transfer_read
+    %1 = vector.transfer_read %0[%arg1], %cst : memref<?xf32>, vector<5xf32>
+    return %1 : vector<5xf32>
+  }
+}
\ No newline at end of file


        


More information about the Mlir-commits mailing list