[flang-commits] [flang] [flang] Added extension point callbacks to default FIR optimizer pipeline. (PR #90674)

Vijay Kandiah via flang-commits flang-commits at lists.llvm.org
Tue Apr 30 14:58:29 PDT 2024


https://github.com/VijayKandiah created https://github.com/llvm/llvm-project/pull/90674

This change inserts a few extension point callbacks in the DefaultFIROptimizerPassPipeline. As an example usage of callbacks in the FIR optimizer pipeline, the FIRInlinerCallback is now used to register the default MLIR inliner pass in the flang-new, tco, and bbc compilation flows. Other compilation flows can use these callbacks to add extra passes at different points of the pass pipeline.

>From 48b1f89a28bfc70ce910d7e725cd505bb0fa9d01 Mon Sep 17 00:00:00 2001
From: Vijay Kandiah <vkandiah at sky6.pgi.net>
Date: Tue, 30 Apr 2024 14:17:27 -0700
Subject: [PATCH] [flang] Adding extension point callbacks to default FIR
 optimizer pipeline.

---
 flang/include/flang/Tools/CLOptions.inc      | 30 ++++++++---
 flang/include/flang/Tools/CrossToolHelpers.h | 57 +++++++++++++++++++-
 flang/lib/Frontend/FrontendActions.cpp       |  1 +
 flang/tools/bbc/bbc.cpp                      |  5 +-
 flang/tools/tco/tco.cpp                      |  1 +
 5 files changed, 84 insertions(+), 10 deletions(-)

diff --git a/flang/include/flang/Tools/CLOptions.inc b/flang/include/flang/Tools/CLOptions.inc
index bd60c66b61ad24..d1602cde9ee5fe 100644
--- a/flang/include/flang/Tools/CLOptions.inc
+++ b/flang/include/flang/Tools/CLOptions.inc
@@ -232,12 +232,27 @@ inline void addExternalNameConversionPass(
   });
 }
 
+// Use inliner extension point callback to register the default inliner pass.
+inline void registerDefaultInlinerPass(MLIRToLLVMPassPipelineConfig &config) {
+  config.registerFIRInlinerCallback(
+      [](mlir::PassManager &pm, llvm::OptimizationLevel level) {
+        llvm::StringMap<mlir::OpPassManager> pipelines;
+        // The default inliner pass adds the canonicalizer pass with the default
+        // configuration. Create the inliner pass with tco config.
+        pm.addPass(mlir::createInlinerPass(
+            pipelines, addCanonicalizerPassWithoutRegionSimplification));
+      });
+}
+
 /// Create a pass pipeline for running default optimization passes for
 /// incremental conversion of FIR.
 ///
 /// \param pm - MLIR pass manager that will hold the pipeline definition
 inline void createDefaultFIROptimizerPassPipeline(
-    mlir::PassManager &pm, const MLIRToLLVMPassPipelineConfig &pc) {
+    mlir::PassManager &pm, MLIRToLLVMPassPipelineConfig &pc) {
+  // Early Optimizer EP Callback
+  pc.invokeFIROptEarlyEPCallbacks(pm, pc.OptLevel);
+
   // simplify the IR
   mlir::GreedyRewriteConfig config;
   config.enableRegionSimplification = false;
@@ -262,11 +277,9 @@ inline void createDefaultFIROptimizerPassPipeline(
   else
     fir::addMemoryAllocationOpt(pm);
 
-  // The default inliner pass adds the canonicalizer pass with the default
-  // configuration. Create the inliner pass with tco config.
-  llvm::StringMap<mlir::OpPassManager> pipelines;
-  pm.addPass(mlir::createInlinerPass(
-      pipelines, addCanonicalizerPassWithoutRegionSimplification));
+  // FIR Inliner Callback
+  pc.invokeFIRInlinerCallback(pm, pc.OptLevel);
+
   pm.addPass(fir::createSimplifyRegionLite());
   pm.addPass(mlir::createCSEPass());
 
@@ -283,6 +296,9 @@ inline void createDefaultFIROptimizerPassPipeline(
   pm.addPass(mlir::createCanonicalizerPass(config));
   pm.addPass(fir::createSimplifyRegionLite());
   pm.addPass(mlir::createCSEPass());
+
+  // Last Optimizer EP Callback
+  pc.invokeFIROptLastEPCallbacks(pm, pc.OptLevel);
 }
 
 /// Create a pass pipeline for lowering from HLFIR to FIR
@@ -375,7 +391,7 @@ inline void createDefaultFIRCodeGenPassPipeline(mlir::PassManager &pm,
 /// \param optLevel - optimization level used for creating FIR optimization
 ///   passes pipeline
 inline void createMLIRToLLVMPassPipeline(mlir::PassManager &pm,
-    const MLIRToLLVMPassPipelineConfig &config,
+    MLIRToLLVMPassPipelineConfig &config,
     llvm::StringRef inputFilename = {}) {
   fir::createHLFIRToFIRPassPipeline(pm, config.OptLevel);
 
diff --git a/flang/include/flang/Tools/CrossToolHelpers.h b/flang/include/flang/Tools/CrossToolHelpers.h
index cebdd6d181c364..f79520707714d7 100644
--- a/flang/include/flang/Tools/CrossToolHelpers.h
+++ b/flang/include/flang/Tools/CrossToolHelpers.h
@@ -20,11 +20,66 @@
 
 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
 #include "mlir/IR/BuiltinOps.h"
+#include "mlir/Pass/PassRegistry.h"
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/Frontend/Debug/Options.h"
 #include "llvm/Passes/OptimizationLevel.h"
 
+// Flang Extension Point Callbacks
+class FlangEPCallBacks {
+public:
+  void registerFIROptEarlyEPCallbacks(
+      const std::function<void(mlir::PassManager &, llvm::OptimizationLevel)>
+          &C) {
+    FIROptEarlyEPCallbacks.push_back(C);
+  }
+
+  void registerFIRInlinerCallback(
+      const std::function<void(mlir::PassManager &, llvm::OptimizationLevel)>
+          &C) {
+    FIRInlinerCallback.push_back(C);
+  }
+
+  void registerFIROptLastEPCallbacks(
+      const std::function<void(mlir::PassManager &, llvm::OptimizationLevel)>
+          &C) {
+    FIROptLastEPCallbacks.push_back(C);
+  }
+
+  void invokeFIROptEarlyEPCallbacks(
+      mlir::PassManager &pm, llvm::OptimizationLevel optLevel) {
+    for (auto &C : FIROptEarlyEPCallbacks)
+      C(pm, optLevel);
+  };
+
+  void invokeFIRInlinerCallback(
+      mlir::PassManager &pm, llvm::OptimizationLevel optLevel) {
+    for (auto &C : FIRInlinerCallback)
+      C(pm, optLevel);
+  };
+
+  void invokeFIROptLastEPCallbacks(
+      mlir::PassManager &pm, llvm::OptimizationLevel optLevel) {
+    for (auto &C : FIROptLastEPCallbacks)
+      C(pm, optLevel);
+  };
+
+private:
+  llvm::SmallVector<
+      std::function<void(mlir::PassManager &, llvm::OptimizationLevel)>, 1>
+      FIROptEarlyEPCallbacks;
+
+  llvm::SmallVector<
+      std::function<void(mlir::PassManager &, llvm::OptimizationLevel)>, 1>
+      FIRInlinerCallback;
+
+  llvm::SmallVector<
+      std::function<void(mlir::PassManager &, llvm::OptimizationLevel)>, 1>
+      FIROptLastEPCallbacks;
+};
+
 /// Configuriation for the MLIR to LLVM pass pipeline.
-struct MLIRToLLVMPassPipelineConfig {
+struct MLIRToLLVMPassPipelineConfig : public FlangEPCallBacks {
   explicit MLIRToLLVMPassPipelineConfig(llvm::OptimizationLevel level) {
     OptLevel = level;
   }
diff --git a/flang/lib/Frontend/FrontendActions.cpp b/flang/lib/Frontend/FrontendActions.cpp
index 531616e7926aca..36536c77499989 100644
--- a/flang/lib/Frontend/FrontendActions.cpp
+++ b/flang/lib/Frontend/FrontendActions.cpp
@@ -802,6 +802,7 @@ void CodeGenAction::generateLLVMIR() {
   pm.enableVerifier(/*verifyPasses=*/true);
 
   MLIRToLLVMPassPipelineConfig config(level, opts, mathOpts);
+  fir::registerDefaultInlinerPass(config);
 
   if (auto vsr = getVScaleRange(ci)) {
     config.VScaleMin = vsr->first;
diff --git a/flang/tools/bbc/bbc.cpp b/flang/tools/bbc/bbc.cpp
index a0870d3649c2e4..4c986fb816f06e 100644
--- a/flang/tools/bbc/bbc.cpp
+++ b/flang/tools/bbc/bbc.cpp
@@ -424,8 +424,9 @@ static mlir::LogicalResult convertFortranSourceToMLIR(
     pm.addPass(std::make_unique<Fortran::lower::VerifierPass>());
 
     // Add O2 optimizer pass pipeline.
-    fir::createDefaultFIROptimizerPassPipeline(
-        pm, MLIRToLLVMPassPipelineConfig(llvm::OptimizationLevel::O2));
+    MLIRToLLVMPassPipelineConfig config(llvm::OptimizationLevel::O2);
+    fir::registerDefaultInlinerPass(config);
+    fir::createDefaultFIROptimizerPassPipeline(pm, config);
   }
 
   if (mlir::succeeded(pm.run(mlirModule))) {
diff --git a/flang/tools/tco/tco.cpp b/flang/tools/tco/tco.cpp
index 45284e74f58456..399ea1362fda47 100644
--- a/flang/tools/tco/tco.cpp
+++ b/flang/tools/tco/tco.cpp
@@ -140,6 +140,7 @@ compileFIR(const mlir::PassPipelineCLParser &passPipeline) {
       fir::createDefaultFIRCodeGenPassPipeline(pm, config);
     } else {
       // Run tco with O2 by default.
+      fir::registerDefaultInlinerPass(config);
       fir::createMLIRToLLVMPassPipeline(pm, config);
     }
     fir::addLLVMDialectToLLVMPass(pm, out.os());



More information about the flang-commits mailing list