[Mlir-commits] [mlir] 7e133eb - [mlir][bufferize] Add filterFn option to BufferResultsToOutParams

Emilio Cota llvmlistbot at llvm.org
Thu Nov 3 10:09:28 PDT 2022


Author: Emilio Cota
Date: 2022-11-03T13:01:39-04:00
New Revision: 7e133eb49b35b1648de786f21f38db084f597b7f

URL: https://github.com/llvm/llvm-project/commit/7e133eb49b35b1648de786f21f38db084f597b7f
DIFF: https://github.com/llvm/llvm-project/commit/7e133eb49b35b1648de786f21f38db084f597b7f.diff

LOG: [mlir][bufferize] Add filterFn option to BufferResultsToOutParams

This allows users to restrict the transformation to a
subset of the functions in a module.

For example, a user might want to apply the transformation to
a module's entry point, but not to the functions it calls (or to
their call sites), because those functions might be external C
functions outside of the user's control.

Reviewed By: springerm, nicolasvasilache
Differential Revision: https://reviews.llvm.org/D137264

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
    mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
index aa3f6423407c7..445430ac21a00 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
@@ -35,12 +35,25 @@ std::unique_ptr<Pass> createBufferHoistingPass();
 /// reallocations inside of loops.
 std::unique_ptr<Pass> createBufferLoopHoistingPass();
 
+// Options for the BufferResultsToOutParams pass (defined here, not tablegen).
+struct BufferResultsToOutParamsOptions {
+  // Returns true if the function should be converted; default converts all.
+  // Owning std::function, not llvm::function_ref: the pass stores a copy of
+  // these options, so a non-owning reference to the callable could dangle.
+  std::function<bool(func::FuncOp *)> filterFn = [](func::FuncOp *) {
+    return true;
+  };
+};
+
 /// Creates a pass that converts memref function results to out-params.
-std::unique_ptr<Pass> createBufferResultsToOutParamsPass();
+std::unique_ptr<Pass> createBufferResultsToOutParamsPass(
+    const BufferResultsToOutParamsOptions &options = {});

 /// Replace buffers that are returned from a function with an out parameter.
 /// Also update all call sites.
-LogicalResult promoteBufferResultsToOutParams(ModuleOp module);
+LogicalResult promoteBufferResultsToOutParams(
+    ModuleOp module,
+    const BufferResultsToOutParamsOptions &options = {});
 
 /// Creates a pass that drops memref function results that are equivalent to a
 /// function argument.

diff  --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
index 996e7b729c373..bff3b664ede55 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
@@ -119,9 +119,21 @@ static void updateReturnOps(func::FuncOp func,
 
 // Updates all CallOps in the scope of the given ModuleOp by allocating
 // temporary buffers for newly introduced out params.
-static LogicalResult updateCalls(ModuleOp module) {
+static LogicalResult
+updateCalls(ModuleOp module,
+            const bufferization::BufferResultsToOutParamsOptions &options) {
   bool didFail = false;
+  SymbolTable symtab(module);
   module.walk([&](func::CallOp op) {
+    auto callee = symtab.lookup<func::FuncOp>(op.getCallee());
+    if (!callee) {
+      op.emitError() << "cannot find callee '" << op.getCallee() << "' in "
+                     << "symbol table";
+      didFail = true;
+      return;
+    }
+    if (!options.filterFn(&callee))
+      return;
     SmallVector<Value, 6> replaceWithNewCallResults;
     SmallVector<Value, 6> replaceWithOutParams;
     for (OpResult result : op.getResults()) {
@@ -169,9 +181,12 @@ static LogicalResult updateCalls(ModuleOp module) {
   return failure(didFail);
 }
 
-LogicalResult
-mlir::bufferization::promoteBufferResultsToOutParams(ModuleOp module) {
+LogicalResult mlir::bufferization::promoteBufferResultsToOutParams(
+    ModuleOp module,
+    const bufferization::BufferResultsToOutParamsOptions &options) {
   for (auto func : module.getOps<func::FuncOp>()) {
+    if (!options.filterFn(&func))
+      continue;
     SmallVector<BlockArgument, 6> appendedEntryArgs;
     if (failed(updateFuncOp(func, appendedEntryArgs)))
       return failure();
@@ -179,7 +194,7 @@ mlir::bufferization::promoteBufferResultsToOutParams(ModuleOp module) {
       continue;
     updateReturnOps(func, appendedEntryArgs);
   }
-  if (failed(updateCalls(module)))
+  if (failed(updateCalls(module, options)))
     return failure();
   return success();
 }
@@ -188,14 +203,24 @@ namespace {
 struct BufferResultsToOutParamsPass
     : bufferization::impl::BufferResultsToOutParamsBase<
           BufferResultsToOutParamsPass> {
+  /// Builds the pass, keeping a private copy of the user-supplied options.
+  explicit BufferResultsToOutParamsPass(
+      const bufferization::BufferResultsToOutParamsOptions &opts)
+      : options(opts) {}
+
   void runOnOperation() override {
-    if (failed(bufferization::promoteBufferResultsToOutParams(getOperation())))
+    auto module = getOperation();
+    if (failed(bufferization::promoteBufferResultsToOutParams(module, options)))
       return signalPassFailure();
   }
+
+private:
+  /// Options (including the function filter) forwarded to the transform.
+  bufferization::BufferResultsToOutParamsOptions options;
 };
 } // namespace

-std::unique_ptr<Pass>
-mlir::bufferization::createBufferResultsToOutParamsPass() {
-  return std::make_unique<BufferResultsToOutParamsPass>();
+std::unique_ptr<Pass> mlir::bufferization::createBufferResultsToOutParamsPass(
+    const bufferization::BufferResultsToOutParamsOptions &options) {
+  return std::make_unique<BufferResultsToOutParamsPass>(options);
 }


        


More information about the Mlir-commits mailing list