[Mlir-commits] [mlir] 7e133eb - [mlir][bufferize] Add filterFn option to BufferResultsToOutParams
Emilio Cota
llvmlistbot at llvm.org
Thu Nov 3 10:09:28 PDT 2022
Author: Emilio Cota
Date: 2022-11-03T13:01:39-04:00
New Revision: 7e133eb49b35b1648de786f21f38db084f597b7f
URL: https://github.com/llvm/llvm-project/commit/7e133eb49b35b1648de786f21f38db084f597b7f
DIFF: https://github.com/llvm/llvm-project/commit/7e133eb49b35b1648de786f21f38db084f597b7f.diff
LOG: [mlir][bufferize] Add filterFn option to BufferResultsToOutParams
This allows users to restrict the transformation to a
subset of the functions in a module.
For example, a user might want to apply the transformation to
a module's entry point, but not to the calls in the module
because those calls might refer to external C functions outside
of their control.
Reviewed By: springerm, nicolasvasilache
Differential Revision: https://reviews.llvm.org/D137264
Added:
Modified:
mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
index aa3f6423407c7..445430ac21a00 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
@@ -35,12 +35,25 @@ std::unique_ptr<Pass> createBufferHoistingPass();
/// reallocations inside of loops.
std::unique_ptr<Pass> createBufferLoopHoistingPass();
+// Options struct for BufferResultsToOutParams pass.
+// Note: defined only here, not in tablegen.
+struct BufferResultsToOutParamsOptions {
+ // Filter function; returns true if the function should be converted.
+ // Defaults to a filter that always returns true, i.e. all functions are
+ // converted.
+ //
+ // Held in an owning std::function rather than llvm::function_ref:
+ // function_ref only refers to its callable, so it would dangle as soon as
+ // the default lambda temporary (or a caller's lambda) is destroyed, whereas
+ // this struct is copied into the pass and the filter is invoked later.
+ std::function<bool(func::FuncOp *)> filterFn = [](func::FuncOp *func) {
+ return true;
+ };
+};
+
/// Creates a pass that converts memref function results to out-params.
-std::unique_ptr<Pass> createBufferResultsToOutParamsPass();
+/// `options.filterFn` selects which functions are converted; when `options`
+/// is omitted, default-constructed options are used and every function is
+/// converted.
+std::unique_ptr<Pass> createBufferResultsToOutParamsPass(
+ const BufferResultsToOutParamsOptions &options = {});
/// Replace buffers that are returned from a function with an out parameter.
/// Also update all call sites.
-LogicalResult promoteBufferResultsToOutParams(ModuleOp module);
+/// Functions for which `options.filterFn` returns false are left unchanged,
+/// and so are the calls to them. Returns failure if a call's callee cannot be
+/// resolved to a func::FuncOp in the module's symbol table.
+LogicalResult
+promoteBufferResultsToOutParams(ModuleOp module,
+ const BufferResultsToOutParamsOptions &options);
/// Creates a pass that drops memref function results that are equivalent to a
/// function argument.
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
index 996e7b729c373..bff3b664ede55 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
@@ -119,9 +119,21 @@ static void updateReturnOps(func::FuncOp func,
// Updates all CallOps in the scope of the given ModuleOp by allocating
// temporary buffers for newly introduced out params.
-static LogicalResult updateCalls(ModuleOp module) {
+static LogicalResult
+updateCalls(ModuleOp module,
+ const bufferization::BufferResultsToOutParamsOptions &options) {
bool didFail = false;
+ SymbolTable symtab(module);
module.walk([&](func::CallOp op) {
+ auto callee = symtab.lookup<func::FuncOp>(op.getCallee());
+ if (!callee) {
+ op.emitError() << "cannot find callee '" << op.getCallee() << "' in "
+ << "symbol table";
+ didFail = true;
+ return;
+ }
+ if (!options.filterFn(&callee))
+ return;
SmallVector<Value, 6> replaceWithNewCallResults;
SmallVector<Value, 6> replaceWithOutParams;
for (OpResult result : op.getResults()) {
@@ -169,9 +181,12 @@ static LogicalResult updateCalls(ModuleOp module) {
return failure(didFail);
}
-LogicalResult
-mlir::bufferization::promoteBufferResultsToOutParams(ModuleOp module) {
+LogicalResult mlir::bufferization::promoteBufferResultsToOutParams(
+ ModuleOp module,
+ const bufferization::BufferResultsToOutParamsOptions &options) {
for (auto func : module.getOps<func::FuncOp>()) {
+ if (!options.filterFn(&func))
+ continue;
SmallVector<BlockArgument, 6> appendedEntryArgs;
if (failed(updateFuncOp(func, appendedEntryArgs)))
return failure();
@@ -179,7 +194,7 @@ mlir::bufferization::promoteBufferResultsToOutParams(ModuleOp module) {
continue;
updateReturnOps(func, appendedEntryArgs);
}
- if (failed(updateCalls(module)))
+ if (failed(updateCalls(module, options)))
return failure();
return success();
}
@@ -188,14 +203,22 @@ namespace {
+/// Pass that runs bufferization::promoteBufferResultsToOutParams on the
+/// operation it is scheduled on, using the options given at construction.
struct BufferResultsToOutParamsPass
: bufferization::impl::BufferResultsToOutParamsBase<
BufferResultsToOutParamsPass> {
+ /// Stores a copy of `options` in the pass; that copy is what
+ /// runOnOperation() passes along.
+ explicit BufferResultsToOutParamsPass(
+ const bufferization::BufferResultsToOutParamsOptions &options)
+ : options(options) {}
+
void runOnOperation() override {
- if (failed(bufferization::promoteBufferResultsToOutParams(getOperation())))
+ // Mark the pass as failed when the promotion reports failure.
+ if (failed(bufferization::promoteBufferResultsToOutParams(getOperation(),
+ options)))
return signalPassFailure();
}
+
+private:
+ /// Options handed to promoteBufferResultsToOutParams on each run.
+ bufferization::BufferResultsToOutParamsOptions options;
};
} // namespace
-std::unique_ptr<Pass>
-mlir::bufferization::createBufferResultsToOutParamsPass() {
- return std::make_unique<BufferResultsToOutParamsPass>();
+/// Factory for BufferResultsToOutParamsPass; `options` is forwarded to the
+/// pass constructor, which keeps its own copy.
+std::unique_ptr<Pass> mlir::bufferization::createBufferResultsToOutParamsPass(
+ const bufferization::BufferResultsToOutParamsOptions &options) {
+ return std::make_unique<BufferResultsToOutParamsPass>(options);
}
More information about the Mlir-commits
mailing list