[flang-commits] [flang] bfd1944 - [flang] de-duplicate AbstractResult pass (#88867)

via flang-commits flang-commits at lists.llvm.org
Mon Apr 22 02:11:14 PDT 2024


Author: Tom Eccles
Date: 2024-04-22T10:11:09+01:00
New Revision: bfd19445c38a2ad6a1def7ee9a1f8ff26a159caf

URL: https://github.com/llvm/llvm-project/commit/bfd19445c38a2ad6a1def7ee9a1f8ff26a159caf
DIFF: https://github.com/llvm/llvm-project/commit/bfd19445c38a2ad6a1def7ee9a1f8ff26a159caf.diff

LOG: [flang] de-duplicate AbstractResult pass (#88867)

This is the first proof of concept for modifying FIR codegen to fully
support a variety of top-level operations (beyond just func.func), as
proposed in
https://discourse.llvm.org/t/rfc-add-an-interface-for-top-level-container-operations

Added: 
    

Modified: 
    flang/include/flang/Optimizer/Transforms/Passes.h
    flang/include/flang/Optimizer/Transforms/Passes.td
    flang/include/flang/Tools/CLOptions.inc
    flang/lib/Optimizer/Transforms/AbstractResult.cpp
    flang/test/Driver/mlir-debug-pass-pipeline.f90
    flang/test/Driver/mlir-pass-pipeline.f90
    flang/test/Fir/abstract-result-2.fir
    flang/test/Fir/abstract-results.fir
    flang/test/Fir/basic-program.fir
    flang/test/Fir/non-trivial-procedure-binding-description.f90

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Optimizer/Transforms/Passes.h b/flang/include/flang/Optimizer/Transforms/Passes.h
index d8840d9e967b48..4d290d87d4cc95 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.h
+++ b/flang/include/flang/Optimizer/Transforms/Passes.h
@@ -31,8 +31,7 @@ namespace fir {
 // Passes defined in Passes.td
 //===----------------------------------------------------------------------===//
 
-#define GEN_PASS_DECL_ABSTRACTRESULTONFUNCOPT
-#define GEN_PASS_DECL_ABSTRACTRESULTONGLOBALOPT
+#define GEN_PASS_DECL_ABSTRACTRESULTOPT
 #define GEN_PASS_DECL_AFFINEDIALECTPROMOTION
 #define GEN_PASS_DECL_AFFINEDIALECTDEMOTION
 #define GEN_PASS_DECL_ANNOTATECONSTANTOPERANDS
@@ -50,8 +49,6 @@ namespace fir {
 #define GEN_PASS_DECL_OPENACCDATAOPERANDCONVERSION
 #include "flang/Optimizer/Transforms/Passes.h.inc"
 
-std::unique_ptr<mlir::Pass> createAbstractResultOnFuncOptPass();
-std::unique_ptr<mlir::Pass> createAbstractResultOnGlobalOptPass();
 std::unique_ptr<mlir::Pass> createAffineDemotionPass();
 std::unique_ptr<mlir::Pass>
 createArrayValueCopyPass(fir::ArrayValueCopyOptions options = {});

diff  --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td
index bfc0db8124af21..467b7e1c472ec0 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.td
+++ b/flang/include/flang/Optimizer/Transforms/Passes.td
@@ -16,8 +16,8 @@
 
 include "mlir/Pass/PassBase.td"
 
-class AbstractResultOptBase<string optExt, string operation> 
-  : Pass<"abstract-result-on-" # optExt # "-opt", operation> {
+def AbstractResultOpt
+  : Pass<"abstract-result"> {
   let summary = "Convert fir.array, fir.box and fir.rec function result to "
                 "function argument";
   let description = [{
@@ -35,14 +35,6 @@ class AbstractResultOptBase<string optExt, string operation>
   ];
 }
 
-def AbstractResultOnFuncOpt : AbstractResultOptBase<"func", "mlir::func::FuncOp"> {
-  let constructor = "::fir::createAbstractResultOnFuncOptPass()";
-}
-
-def AbstractResultOnGlobalOpt : AbstractResultOptBase<"global", "fir::GlobalOp"> {
-  let constructor = "::fir::createAbstractResultOnGlobalOptPass()";
-}
-
 def AffineDialectPromotion : Pass<"promote-to-affine", "::mlir::func::FuncOp"> {
   let summary = "Promotes `fir.{do_loop,if}` to `affine.{for,if}`.";
   let description = [{

diff  --git a/flang/include/flang/Tools/CLOptions.inc b/flang/include/flang/Tools/CLOptions.inc
index ea297fb337a2c8..44ff2b3f70ff68 100644
--- a/flang/include/flang/Tools/CLOptions.inc
+++ b/flang/include/flang/Tools/CLOptions.inc
@@ -19,6 +19,7 @@
 #include "flang/Optimizer/Transforms/Passes.h"
 #include "llvm/Passes/OptimizationLevel.h"
 #include "llvm/Support/CommandLine.h"
+#include <type_traits>
 
 #define DisableOption(DOName, DOOption, DODescription) \
   static llvm::cl::opt<bool> disable##DOName("disable-" DOOption, \
@@ -86,6 +87,29 @@ DisableOption(BoxedProcedureRewrite, "boxed-procedure-rewrite",
 DisableOption(ExternalNameConversion, "external-name-interop",
     "convert names with external convention");
 
+// TODO: remove this guard once these helpers are used for non-codegen passes
+#if !defined(FLANG_EXCLUDE_CODEGEN)
+using PassConstructor = std::unique_ptr<mlir::Pass>();
+
+template <typename OP>
+void addNestedPassToOps(mlir::PassManager &pm, PassConstructor ctor) {
+  pm.addNestedPass<OP>(ctor());
+}
+
+template <typename OP, typename... OPS,
+    typename = std::enable_if_t<sizeof...(OPS) != 0>>
+void addNestedPassToOps(mlir::PassManager &pm, PassConstructor ctor) {
+  addNestedPassToOps<OP>(pm, ctor);
+  addNestedPassToOps<OPS...>(pm, ctor);
+}
+
+inline void addNestedPassToAllTopLevelOperations(
+    mlir::PassManager &pm, PassConstructor ctor) {
+  addNestedPassToOps<mlir::func::FuncOp, mlir::omp::DeclareReductionOp,
+      fir::GlobalOp>(pm, ctor);
+}
+#endif
+
 /// Generic for adding a pass to the pass manager if it is not disabled.
 template <typename F>
 void addPassConditionally(
@@ -304,9 +328,7 @@ inline void createDebugPasses(
 inline void createDefaultFIRCodeGenPassPipeline(
     mlir::PassManager &pm, MLIRToLLVMPassPipelineConfig config) {
   fir::addBoxedProcedurePass(pm);
-  pm.addNestedPass<mlir::func::FuncOp>(
-      fir::createAbstractResultOnFuncOptPass());
-  pm.addNestedPass<fir::GlobalOp>(fir::createAbstractResultOnGlobalOptPass());
+  addNestedPassToAllTopLevelOperations(pm, fir::createAbstractResultOpt);
   fir::addCodeGenRewritePass(pm);
   fir::addTargetRewritePass(pm);
   fir::addExternalNameConversionPass(pm, config.Underscoring);

diff  --git a/flang/lib/Optimizer/Transforms/AbstractResult.cpp b/flang/lib/Optimizer/Transforms/AbstractResult.cpp
index dd1ddd16f2ded5..eb4dd637bb167e 100644
--- a/flang/lib/Optimizer/Transforms/AbstractResult.cpp
+++ b/flang/lib/Optimizer/Transforms/AbstractResult.cpp
@@ -16,13 +16,12 @@
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
 #include "mlir/Transforms/DialectConversion.h"
-#include "mlir/Transforms/Passes.h"
 #include "llvm/ADT/TypeSwitch.h"
 
 namespace fir {
-#define GEN_PASS_DEF_ABSTRACTRESULTONFUNCOPT
-#define GEN_PASS_DEF_ABSTRACTRESULTONGLOBALOPT
+#define GEN_PASS_DEF_ABSTRACTRESULTOPT
 #include "flang/Optimizer/Transforms/Passes.h.inc"
 } // namespace fir
 
@@ -285,59 +284,12 @@ class AddrOfOpConversion : public mlir::OpRewritePattern<fir::AddrOfOp> {
   bool shouldBoxResult;
 };
 
-/// @brief Base CRTP class for AbstractResult pass family.
-/// Contains common logic for abstract result conversion in a reusable fashion.
-/// @tparam Pass target class that implements operation-specific logic.
-/// @tparam PassBase base class template for the pass generated by TableGen.
-/// The `Pass` class must define runOnSpecificOperation(OpTy, bool,
-/// mlir::RewritePatternSet&, mlir::ConversionTarget&) member function.
-/// This function should implement operation-specific functionality.
-template <typename Pass, template <typename> class PassBase>
-class AbstractResultOptTemplate : public PassBase<Pass> {
+class AbstractResultOpt
+    : public fir::impl::AbstractResultOptBase<AbstractResultOpt> {
 public:
-  void runOnOperation() override {
-    auto *context = &this->getContext();
-    auto op = this->getOperation();
-
-    mlir::RewritePatternSet patterns(context);
-    mlir::ConversionTarget target = *context;
-    const bool shouldBoxResult = this->passResultAsBox.getValue();
-
-    auto &self = static_cast<Pass &>(*this);
-    self.runOnSpecificOperation(op, shouldBoxResult, patterns, target);
-
-    // Convert the calls and, if needed,  the ReturnOp in the function body.
-    target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect,
-                           mlir::func::FuncDialect>();
-    target.addIllegalOp<fir::SaveResultOp>();
-    target.addDynamicallyLegalOp<fir::CallOp>([](fir::CallOp call) {
-      return !hasAbstractResult(call.getFunctionType());
-    });
-    target.addDynamicallyLegalOp<fir::AddrOfOp>([](fir::AddrOfOp addrOf) {
-      if (auto funTy = addrOf.getType().dyn_cast<mlir::FunctionType>())
-        return !hasAbstractResult(funTy);
-      return true;
-    });
-    target.addDynamicallyLegalOp<fir::DispatchOp>([](fir::DispatchOp dispatch) {
-      return !hasAbstractResult(dispatch.getFunctionType());
-    });
-
-    patterns.insert<CallConversion<fir::CallOp>>(context, shouldBoxResult);
-    patterns.insert<CallConversion<fir::DispatchOp>>(context, shouldBoxResult);
-    patterns.insert<SaveResultOpConversion>(context);
-    patterns.insert<AddrOfOpConversion>(context, shouldBoxResult);
-    if (mlir::failed(
-            mlir::applyPartialConversion(op, target, std::move(patterns)))) {
-      mlir::emitError(op.getLoc(), "error in converting abstract results\n");
-      this->signalPassFailure();
-    }
-  }
-};
+  using fir::impl::AbstractResultOptBase<
+      AbstractResultOpt>::AbstractResultOptBase;
 
-class AbstractResultOnFuncOpt
-    : public AbstractResultOptTemplate<AbstractResultOnFuncOpt,
-                                       fir::impl::AbstractResultOnFuncOptBase> {
-public:
   void runOnSpecificOperation(mlir::func::FuncOp func, bool shouldBoxResult,
                               mlir::RewritePatternSet &patterns,
                               mlir::ConversionTarget &target) {
@@ -386,25 +338,20 @@ class AbstractResultOnFuncOpt
       }
     }
   }
-};
 
-inline static bool containsFunctionTypeWithAbstractResult(mlir::Type type) {
-  return mlir::TypeSwitch<mlir::Type, bool>(type)
-      .Case([](fir::BoxProcType boxProc) {
-        return fir::hasAbstractResult(
-            boxProc.getEleTy().cast<mlir::FunctionType>());
-      })
-      .Case([](fir::PointerType pointer) {
-        return fir::hasAbstractResult(
-            pointer.getEleTy().cast<mlir::FunctionType>());
-      })
-      .Default([](auto &&) { return false; });
-}
+  static bool containsFunctionTypeWithAbstractResult(mlir::Type type) {
+    return mlir::TypeSwitch<mlir::Type, bool>(type)
+        .Case([](fir::BoxProcType boxProc) {
+          auto funTy = mlir::dyn_cast<mlir::FunctionType>(boxProc.getEleTy());
+          return funTy && fir::hasAbstractResult(funTy);
+        })
+        .Case([](fir::PointerType pointer) {
+          auto funTy = mlir::dyn_cast<mlir::FunctionType>(pointer.getEleTy());
+          return funTy && fir::hasAbstractResult(funTy);
+        })
+        .Default([](auto &&) { return false; });
+  }
 
-class AbstractResultOnGlobalOpt
-    : public AbstractResultOptTemplate<
-          AbstractResultOnGlobalOpt, fir::impl::AbstractResultOnGlobalOptBase> {
-public:
   void runOnSpecificOperation(fir::GlobalOp global, bool,
                               mlir::RewritePatternSet &,
                               mlir::ConversionTarget &) {
@@ -412,14 +359,77 @@ class AbstractResultOnGlobalOpt
       TODO(global->getLoc(), "support for procedure pointers");
     }
   }
-};
-} // end anonymous namespace
-} // namespace fir
 
-std::unique_ptr<mlir::Pass> fir::createAbstractResultOnFuncOptPass() {
-  return std::make_unique<AbstractResultOnFuncOpt>();
-}
+  /// Run the pass on a ModuleOp. This makes fir-opt --abstract-result work.
+  void runOnModule() {
+    mlir::ModuleOp mod = mlir::cast<mlir::ModuleOp>(getOperation());
+
+    auto pass = std::make_unique<AbstractResultOpt>();
+    pass->copyOptionValuesFrom(this);
+    mlir::OpPassManager pipeline;
+    pipeline.addPass(std::move(pass));
+
+    // Run the pass on all operations directly nested inside of the ModuleOp.
+    // We can't just call runOnSpecificOperation here because the pass
+    // implementation only works when scoped to a particular func.func or
+    // fir.global.
+    for (mlir::Region &region : mod->getRegions()) {
+      for (mlir::Block &block : region.getBlocks()) {
+        for (mlir::Operation &op : block.getOperations()) {
+          if (mlir::failed(runPipeline(pipeline, &op))) {
+            mlir::emitError(op.getLoc(), "Failed to run abstract result pass");
+            signalPassFailure();
+            return;
+          }
+        }
+      }
+    }
+  }
 
-std::unique_ptr<mlir::Pass> fir::createAbstractResultOnGlobalOptPass() {
-  return std::make_unique<AbstractResultOnGlobalOpt>();
-}
+  void runOnOperation() override {
+    auto *context = &getContext();
+    mlir::Operation *op = getOperation();
+    if (mlir::isa<mlir::ModuleOp>(op)) {
+      runOnModule();
+      return;
+    }
+
+    mlir::RewritePatternSet patterns(context);
+    mlir::ConversionTarget target = *context;
+    const bool shouldBoxResult = passResultAsBox.getValue();
+
+    mlir::TypeSwitch<mlir::Operation *, void>(op)
+        .Case<mlir::func::FuncOp, fir::GlobalOp>([&](auto typedOp) {
+          runOnSpecificOperation(typedOp, shouldBoxResult, patterns, target);
+        });
+
+    // Convert the calls and, if needed, the ReturnOp in the function body.
+    target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect,
+                           mlir::func::FuncDialect>();
+    target.addIllegalOp<fir::SaveResultOp>();
+    target.addDynamicallyLegalOp<fir::CallOp>([](fir::CallOp call) {
+      return !hasAbstractResult(call.getFunctionType());
+    });
+    target.addDynamicallyLegalOp<fir::AddrOfOp>([](fir::AddrOfOp addrOf) {
+      if (auto funTy = addrOf.getType().dyn_cast<mlir::FunctionType>())
+        return !hasAbstractResult(funTy);
+      return true;
+    });
+    target.addDynamicallyLegalOp<fir::DispatchOp>([](fir::DispatchOp dispatch) {
+      return !hasAbstractResult(dispatch.getFunctionType());
+    });
+
+    patterns.insert<CallConversion<fir::CallOp>>(context, shouldBoxResult);
+    patterns.insert<CallConversion<fir::DispatchOp>>(context, shouldBoxResult);
+    patterns.insert<SaveResultOpConversion>(context);
+    patterns.insert<AddrOfOpConversion>(context, shouldBoxResult);
+    if (mlir::failed(
+            mlir::applyPartialConversion(op, target, std::move(patterns)))) {
+      mlir::emitError(op->getLoc(), "error in converting abstract results\n");
+      signalPassFailure();
+    }
+  }
+};
+
+} // end anonymous namespace
+} // namespace fir

diff  --git a/flang/test/Driver/mlir-debug-pass-pipeline.f90 b/flang/test/Driver/mlir-debug-pass-pipeline.f90
index 04d432f854ca35..ef84cb80ecf1db 100644
--- a/flang/test/Driver/mlir-debug-pass-pipeline.f90
+++ b/flang/test/Driver/mlir-debug-pass-pipeline.f90
@@ -72,11 +72,13 @@
 ! ALL-NEXT:   (S) 0 num-dce'd - Number of operations DCE'd
 ! ALL-NEXT: BoxedProcedurePass
 
-! ALL-NEXT: Pipeline Collection : ['fir.global', 'func.func']
+! ALL-NEXT: Pipeline Collection : ['fir.global', 'func.func', 'omp.declare_reduction']
 ! ALL-NEXT:   'fir.global' Pipeline
-! ALL-NEXT:   AbstractResultOnGlobalOpt
+! ALL-NEXT:     AbstractResultOpt
 ! ALL-NEXT:   'func.func' Pipeline
-! ALL-NEXT:   AbstractResultOnFuncOpt
+! ALL-NEXT:     AbstractResultOpt
+! ALL-NEXT:   'omp.declare_reduction' Pipeline
+! ALL-NEXT:     AbstractResultOpt
 
 ! ALL-NEXT: CodeGenRewrite
 ! ALL-NEXT:   (S) 0 num-dce'd - Number of operations eliminated

diff  --git a/flang/test/Driver/mlir-pass-pipeline.f90 b/flang/test/Driver/mlir-pass-pipeline.f90
index cfa0de63cde5e8..d1ff2869b0a6a9 100644
--- a/flang/test/Driver/mlir-pass-pipeline.f90
+++ b/flang/test/Driver/mlir-pass-pipeline.f90
@@ -67,11 +67,13 @@
 ! ALL-NEXT:   (S) 0 num-dce'd - Number of operations DCE'd
 ! ALL-NEXT: BoxedProcedurePass
 
-! ALL-NEXT: Pipeline Collection : ['fir.global', 'func.func']
+! ALL-NEXT: Pipeline Collection : ['fir.global', 'func.func', 'omp.declare_reduction']
 ! ALL-NEXT:   'fir.global' Pipeline
-! ALL-NEXT:    AbstractResultOnGlobalOpt
+! ALL-NEXT:    AbstractResultOpt
 ! ALL-NEXT:  'func.func' Pipeline
-! ALL-NEXT:    AbstractResultOnFuncOpt
+! ALL-NEXT:    AbstractResultOpt
+! ALL-NEXT:  'omp.declare_reduction' Pipeline
+! ALL-NEXT:    AbstractResultOpt
 
 ! ALL-NEXT: CodeGenRewrite
 ! ALL-NEXT:   (S) 0 num-dce'd - Number of operations eliminated

diff  --git a/flang/test/Fir/abstract-result-2.fir b/flang/test/Fir/abstract-result-2.fir
index 08b723b8305936..af13d57476e8c0 100644
--- a/flang/test/Fir/abstract-result-2.fir
+++ b/flang/test/Fir/abstract-result-2.fir
@@ -1,4 +1,4 @@
-// RUN: fir-opt %s --abstract-result-on-func-opt | FileCheck %s 
+// RUN: fir-opt %s --abstract-result | FileCheck %s
 
 // Check that the attributes are shifted along with their corresponding arguments
 

diff  --git a/flang/test/Fir/abstract-results.fir b/flang/test/Fir/abstract-results.fir
index 42ff2a5c8eb2a8..82f1cd33073fd3 100644
--- a/flang/test/Fir/abstract-results.fir
+++ b/flang/test/Fir/abstract-results.fir
@@ -1,10 +1,10 @@
 // Test rewrite of functions that return fir.array<>, fir.type<>, fir.box<> to
 // functions that take an additional argument for the result.
 
-// RUN: fir-opt %s --abstract-result-on-func-opt | FileCheck %s --check-prefix=FUNC-REF
-// RUN: fir-opt %s --abstract-result-on-func-opt=abstract-result-as-box | FileCheck %s --check-prefix=FUNC-BOX
-// RUN: fir-opt %s --abstract-result-on-global-opt | FileCheck %s --check-prefix=GLOBAL-REF
-// RUN: fir-opt %s --abstract-result-on-global-opt=abstract-result-as-box | FileCheck %s --check-prefix=GLOBAL-BOX
+// RUN: fir-opt %s --abstract-result | FileCheck %s --check-prefix=FUNC-REF
+// RUN: fir-opt %s --abstract-result=abstract-result-as-box | FileCheck %s --check-prefix=FUNC-BOX
+// RUN: fir-opt %s --abstract-result | FileCheck %s --check-prefix=GLOBAL-REF
+// RUN: fir-opt %s --abstract-result=abstract-result-as-box | FileCheck %s --check-prefix=GLOBAL-BOX
 
 // ----------------------- Test declaration rewrite ----------------------------
 

diff  --git a/flang/test/Fir/basic-program.fir b/flang/test/Fir/basic-program.fir
index 80d3520bc7f7d4..28c597fc918cd7 100644
--- a/flang/test/Fir/basic-program.fir
+++ b/flang/test/Fir/basic-program.fir
@@ -74,11 +74,13 @@ func.func @_QQmain() {
 // PASSES-NEXT:   (S) 0 num-dce'd - Number of operations DCE'd
 // PASSES-NEXT: BoxedProcedurePass
 
-// PASSES-NEXT: Pipeline Collection : ['fir.global', 'func.func']
+// PASSES-NEXT: Pipeline Collection : ['fir.global', 'func.func', 'omp.declare_reduction']
 // PASSES-NEXT:   'fir.global' Pipeline
-// PASSES-NEXT:    AbstractResultOnGlobalOpt
+// PASSES-NEXT:    AbstractResultOpt
 // PASSES-NEXT:  'func.func' Pipeline
-// PASSES-NEXT:    AbstractResultOnFuncOpt
+// PASSES-NEXT:    AbstractResultOpt
+// PASSES-NEXT:  'omp.declare_reduction' Pipeline
+// PASSES-NEXT:    AbstractResultOpt
 
 // PASSES-NEXT: CodeGenRewrite
 // PASSES-NEXT:   (S) 0 num-dce'd - Number of operations eliminated

diff  --git a/flang/test/Fir/non-trivial-procedure-binding-description.f90 b/flang/test/Fir/non-trivial-procedure-binding-description.f90
index 695d7fdfe232d3..668928600157b1 100644
--- a/flang/test/Fir/non-trivial-procedure-binding-description.f90
+++ b/flang/test/Fir/non-trivial-procedure-binding-description.f90
@@ -1,5 +1,5 @@
 ! RUN: %flang_fc1 -emit-mlir %s -o - | FileCheck %s --check-prefix=BEFORE
-! RUN: %flang_fc1 -emit-mlir %s -o - | fir-opt --abstract-result-on-global-opt | FileCheck %s --check-prefix=AFTER
+! RUN: %flang_fc1 -emit-mlir %s -o - | fir-opt --abstract-result | FileCheck %s --check-prefix=AFTER
 module a
   type f
   contains


        


More information about the flang-commits mailing list