[Mlir-commits] [mlir] 060c8be - [mlir][OneShotModuleBufferize] Add a new flag: no-analysis-func-filter
Maya Amrami
llvmlistbot at llvm.org
Tue Jan 31 01:26:23 PST 2023
Author: Maya Amrami
Date: 2023-01-31T11:26:15+02:00
New Revision: 060c8be51bdfd61734e06d9df713f37d354207a8
URL: https://github.com/llvm/llvm-project/commit/060c8be51bdfd61734e06d9df713f37d354207a8
DIFF: https://github.com/llvm/llvm-project/commit/060c8be51bdfd61734e06d9df713f37d354207a8.diff
LOG: [mlir][OneShotModuleBufferize] Add a new flag: no-analysis-func-filter
OneShotModuleBufferize fails if the input IR cannot be analyzed.
One can set CopyBeforeWrite=true in order to skip analysis.
In that case, a buffer copy is inserted on every write.
This leads to many unnecessary copies, including in FuncOps that could
have been analyzed.
This change aims to copy buffers only when necessary.
When running OneShotModuleBufferize with CopyBeforeWrite=false,
FuncOps whose names are specified in noAnalysisFuncFilter will not be
analyzed. Ops nested in these FuncOps will not be analyzed either.
They will be bufferized with CopyBeforeWrite=true,
while the other ops will be bufferized with CopyBeforeWrite=false.
Reviewed By: springerm
Differential Revision: https://reviews.llvm.org/D142631
Added:
mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-force-copy-before-write.mlir
Modified:
mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h
mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h
index a69e1c2da04ca..c6f3e6a646182 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h
@@ -9,6 +9,7 @@
#ifndef MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_ONESHOTMODULEBUFFERIZE_H
#define MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_ONESHOTMODULEBUFFERIZE_H
+#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
namespace mlir {
struct LogicalResult;
@@ -27,11 +28,16 @@ LogicalResult analyzeModuleOp(ModuleOp moduleOp, OneShotAnalysisState &state,
/// Bufferize `op` and its nested ops that implement `BufferizableOpInterface`.
///
/// Note: This function does not run One-Shot Analysis. No buffer copies are
-/// inserted unless `options.copyBeforeWrite` is set, in which case buffers are
-/// copied before every write.
-LogicalResult bufferizeModuleOp(ModuleOp moduleOp,
- const OneShotBufferizationOptions &options,
- BufferizationStatistics *statistics = nullptr);
+/// inserted except in two cases:
+/// - `options.copyBeforeWrite` is set, in which case buffers are copied before
+/// every write.
+/// - `options.copyBeforeWrite` is not set and `analysisFilterFn` returns true
+/// for some FuncOps. These FuncOps are not analyzed, and buffer copies are
+/// inserted only into these FuncOps.
+LogicalResult
+bufferizeModuleOp(ModuleOp moduleOp, const OneShotBufferizationOptions &options,
+ BufferizationStatistics *statistics = nullptr,
+ OpFilter::Entry::FilterFn analysisFilterFn = nullptr);
/// Remove bufferization attributes on every FuncOp arguments in the ModuleOp.
void removeBufferizationAttributesInModule(ModuleOp moduleOp);
@@ -43,7 +49,8 @@ void removeBufferizationAttributesInModule(ModuleOp moduleOp);
LogicalResult runOneShotModuleBufferize(
ModuleOp moduleOp,
const bufferization::OneShotBufferizationOptions &options,
- BufferizationStatistics *statistics = nullptr);
+ BufferizationStatistics *statistics = nullptr,
+ OpFilter::Entry::FilterFn analysisFilterFn = nullptr);
} // namespace bufferization
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
index cffe3bcb5cfbf..20b35f89ae269 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
@@ -297,6 +297,9 @@ def OneShotBufferize : Pass<"one-shot-bufferize", "ModuleOp"> {
"core bufferization passes.">,
ListOption<"dialectFilter", "dialect-filter", "std::string",
"Restrict bufferization to ops from these dialects.">,
+ ListOption<"noAnalysisFuncFilter", "no-analysis-func-filter", "std::string",
+ "Skip analysis of functions with these symbol names. "
+ "Set copyBeforeWrite to true when bufferizing them.">,
Option<"functionBoundaryTypeConversion",
"function-boundary-type-conversion", "std::string",
/*default=*/"\"infer-layout-map\"",
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index fbed0a77306cc..9ed6c52277279 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -249,11 +249,26 @@ struct OneShotBufferizePass
BufferizationStatistics statistics;
ModuleOp moduleOp = getOperation();
if (opt.bufferizeFunctionBoundaries) {
- if (failed(runOneShotModuleBufferize(moduleOp, opt, &statistics))) {
+ OpFilter::Entry::FilterFn analysisFilterFn = nullptr;
+ // FuncOps whose names are specified in noAnalysisFuncFilter will not be
+ // analyzed. Ops nested in these FuncOps will not be analyzed either.
+ if (this->noAnalysisFuncFilter.hasValue())
+ analysisFilterFn = [=](Operation *op) {
+ auto func = dyn_cast<func::FuncOp>(op);
+ if (!func)
+ func = op->getParentOfType<func::FuncOp>();
+ if (func)
+ return llvm::is_contained(noAnalysisFuncFilter, func.getSymName());
+ return false;
+ };
+ if (failed(runOneShotModuleBufferize(moduleOp, opt, &statistics,
+ analysisFilterFn))) {
signalPassFailure();
return;
}
} else {
+ assert(!this->noAnalysisFuncFilter.hasValue() &&
+ "invalid combination of bufferization flags");
if (failed(runOneShotBufferize(moduleOp, opt, &statistics))) {
signalPassFailure();
return;
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
index f1211bb878f2b..01818d749b95b 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -378,6 +378,9 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
// Analyze ops.
for (func::FuncOp funcOp : orderedFuncOps) {
+ if (!state.getOptions().isOpAllowed(funcOp))
+ continue;
+
// Now analyzing function.
funcState.startFunctionAnalysis(funcOp);
@@ -410,7 +413,8 @@ void mlir::bufferization::removeBufferizationAttributesInModule(
LogicalResult mlir::bufferization::bufferizeModuleOp(
ModuleOp moduleOp, const OneShotBufferizationOptions &options,
- BufferizationStatistics *statistics) {
+ BufferizationStatistics *statistics,
+ OpFilter::Entry::FilterFn analysisFilterFn) {
assert(options.bufferizeFunctionBoundaries &&
"expected that function boundary bufferization is activated");
IRRewriter rewriter(moduleOp.getContext());
@@ -428,7 +432,9 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
for (func::FuncOp funcOp : orderedFuncOps) {
// Note: It would be good to apply cleanups here but we cannot as aliasInfo
// would be invalidated.
- if (failed(bufferizeOp(funcOp, options, options.copyBeforeWrite,
+ bool copyBeforeWrite = options.copyBeforeWrite ||
+ (analysisFilterFn && analysisFilterFn(funcOp));
+ if (failed(bufferizeOp(funcOp, options, copyBeforeWrite,
/*opFilter=*/nullptr, statistics)))
return failure();
// Change buffer return types to more precise layout maps.
@@ -445,18 +451,28 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
LogicalResult mlir::bufferization::runOneShotModuleBufferize(
ModuleOp moduleOp, const OneShotBufferizationOptions &options,
- BufferizationStatistics *statistics) {
+ BufferizationStatistics *statistics,
+ OpFilter::Entry::FilterFn analysisFilterFn) {
assert(options.bufferizeFunctionBoundaries &&
"expected that function boundary bufferization is activated");
assert(!(options.copyBeforeWrite && options.testAnalysisOnly) &&
"invalid combination of bufferization flags");
if (!options.copyBeforeWrite) {
- if (failed(insertTensorCopies(moduleOp, options, statistics)))
- return failure();
+ if (!analysisFilterFn) {
+ if (failed(insertTensorCopies(moduleOp, options, statistics)))
+ return failure();
+ } else {
+ OneShotBufferizationOptions updatedOptions(options);
+ updatedOptions.opFilter.denyOperation(analysisFilterFn);
+ if (failed(insertTensorCopies(moduleOp, updatedOptions, statistics)))
+ return failure();
+ }
}
if (options.testAnalysisOnly)
return success();
- if (failed(bufferizeModuleOp(moduleOp, options, statistics)))
+
+ if (failed(
+ bufferizeModuleOp(moduleOp, options, statistics, analysisFilterFn)))
return failure();
return success();
}
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-force-copy-before-write.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-force-copy-before-write.mlir
new file mode 100644
index 0000000000000..e5723f978f4ea
--- /dev/null
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-force-copy-before-write.mlir
@@ -0,0 +1,39 @@
+// RUN: mlir-opt %s -one-shot-bufferize="bufferize-function-boundaries=1 no-analysis-func-filter=contains_to_memref_op" -drop-equivalent-buffer-results --split-input-file | FileCheck %s
+
+// ToMemref ops do not pass the analysis step. CopyBeforeWrite will be true only
+// for the FuncOp "contains_to_memref_op", since it is listed in no-analysis-func-filter.
+
+module {
+ // CHECK-LABEL: func.func @foo(
+ // CHECK-SAME: %[[arg0:.*]]: memref<?xf32, strided<[?], offset: ?>>) {
+ func.func @foo(%arg0: tensor<?xf32>) -> tensor<?xf32> {
+ // CHECK-NEXT: %[[c0:.*]] = arith.constant 0 : index
+ %cst = arith.constant 1.000000e+00 : f32
+
+ // CHECK-NEXT: %[[c1:.*]] = arith.constant 1.000000e+00 : f32
+ %c0 = arith.constant 0 : index
+
+ // CHECK-NEXT: memref.store %[[c1]], %[[arg0]]{{\[}}%[[c0]]] : memref<?xf32, strided<[?], offset: ?>>
+ %inserted = tensor.insert %cst into %arg0[%c0] : tensor<?xf32>
+
+ return %inserted : tensor<?xf32>
+ }
+
+ // CHECK-LABEL: func.func @contains_to_memref_op(
+ // CHECK-SAME: %[[arg0:.*]]: memref<?xf32, strided<[?], offset: ?>>,
+ // CHECK-SAME: %[[arg1:.*]]: index) -> vector<5xf32> {
+ func.func @contains_to_memref_op(%arg0: tensor<?xf32> {bufferization.writable = true}, %arg1: index) -> vector<5xf32> {
+
+ %0 = bufferization.to_memref %arg0 : memref<?xf32>
+
+ // CHECK: %[[c0:.*]] = arith.constant 0 : index
+ %cst = arith.constant 0.000000e+00 : f32
+
+ // CHECK: %[[dim:.*]] = memref.dim %[[arg0]], %[[c0]] : memref<?xf32, strided<[?], offset: ?>>
+ // CHECK: %[[alloc:.*]] = memref.alloc(%[[dim]]) : memref<?xf32>
+ // CHECK: memref.copy %[[arg0]], %[[alloc]] : memref<?xf32, strided<[?], offset: ?>> to memref<?xf32>
+ // CHECK: vector.transfer_read
+ %1 = vector.transfer_read %0[%arg1], %cst : memref<?xf32>, vector<5xf32>
+ return %1 : vector<5xf32>
+ }
+}
\ No newline at end of file
More information about the Mlir-commits
mailing list