[flang-commits] [flang] [flang] de-duplicate AbstractResult pass (PR #88867)

Tom Eccles via flang-commits flang-commits at lists.llvm.org
Fri Apr 19 10:11:07 PDT 2024


https://github.com/tblah updated https://github.com/llvm/llvm-project/pull/88867

>From d736b2561aec2462c4fdee96517cbfca19aff466 Mon Sep 17 00:00:00 2001
From: Tom Eccles <tom.eccles at arm.com>
Date: Mon, 15 Apr 2024 13:09:37 +0000
Subject: [PATCH 1/4] [flang] de-duplicate AbstractResult pass

This is the first proof of concept of the modification of FIR lowering
to fully support a variety of top level operations (beyond just
func.func) proposed in
https://discourse.llvm.org/t/rfc-add-an-interface-for-top-level-container-operations

One unfortunate side-effect of this is that the new AbstractResult pass
cannot be scheduled on a builtin.module operation and so we can't use
  fir-opt --abstract-result < file.fir

I tried adding support for operating on a module to the pass, but this
wasn't straightforward. Operating at module scope means that conversions
added for return operations run on every return operation in the module
rather than just in the current function and this violates assumptions
in the pass, producing incorrect results. This doesn't affect normal
operation because the pass manager will always run the pass on a
specific top level operation not on a whole module. I have worked around
this by specifying the pass pipeline more specifically in the tests.

I expect most other passes will be able to keep their old fir-opt
interface.
---
 .../flang/Optimizer/Dialect/FIROpsSupport.h   |   7 +
 .../flang/Optimizer/Transforms/Passes.h       |   6 +-
 .../flang/Optimizer/Transforms/Passes.td      |  13 +-
 flang/include/flang/Tools/CLOptions.inc       |  30 +++-
 flang/lib/Optimizer/Dialect/FIROps.cpp        |  13 ++
 .../Optimizer/Transforms/AbstractResult.cpp   | 139 ++++++++----------
 .../test/Driver/mlir-debug-pass-pipeline.f90  |   8 +-
 flang/test/Driver/mlir-pass-pipeline.f90      |   8 +-
 flang/test/Fir/abstract-result-2.fir          |   2 +-
 flang/test/Fir/abstract-results.fir           |   8 +-
 flang/test/Fir/basic-program.fir              |   8 +-
 ...-trivial-procedure-binding-description.f90 |   2 +-
 12 files changed, 135 insertions(+), 109 deletions(-)

diff --git a/flang/include/flang/Optimizer/Dialect/FIROpsSupport.h b/flang/include/flang/Optimizer/Dialect/FIROpsSupport.h
index 3266ea3aa7fdc6..44f2985e573785 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROpsSupport.h
+++ b/flang/include/flang/Optimizer/Dialect/FIROpsSupport.h
@@ -173,6 +173,13 @@ inline mlir::NamedAttribute getAdaptToByRefAttr(Builder &builder) {
           builder.getUnitAttr()};
 }
 
+/// Returns true if the operation name is for a container operation expected to
+/// contain (HL)FIR operations which need to be lowered by FIR passes. The
+/// simplest example of this is func.func.
+/// This operates on mlir::RegisteredOperationName so that it can be used to
+/// implement mlir::Pass::canScheduleOn.
+bool isa_toplevel(mlir::RegisteredOperationName opName);
+
 } // namespace fir
 
 #endif // FORTRAN_OPTIMIZER_DIALECT_FIROPSSUPPORT_H
diff --git a/flang/include/flang/Optimizer/Transforms/Passes.h b/flang/include/flang/Optimizer/Transforms/Passes.h
index d8840d9e967b48..8520324e5491e1 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.h
+++ b/flang/include/flang/Optimizer/Transforms/Passes.h
@@ -31,8 +31,7 @@ namespace fir {
 // Passes defined in Passes.td
 //===----------------------------------------------------------------------===//
 
-#define GEN_PASS_DECL_ABSTRACTRESULTONFUNCOPT
-#define GEN_PASS_DECL_ABSTRACTRESULTONGLOBALOPT
+#define GEN_PASS_DECL_ABSTRACTRESULTOPT
 #define GEN_PASS_DECL_AFFINEDIALECTPROMOTION
 #define GEN_PASS_DECL_AFFINEDIALECTDEMOTION
 #define GEN_PASS_DECL_ANNOTATECONSTANTOPERANDS
@@ -50,8 +49,7 @@ namespace fir {
 #define GEN_PASS_DECL_OPENACCDATAOPERANDCONVERSION
 #include "flang/Optimizer/Transforms/Passes.h.inc"
 
-std::unique_ptr<mlir::Pass> createAbstractResultOnFuncOptPass();
-std::unique_ptr<mlir::Pass> createAbstractResultOnGlobalOptPass();
+std::unique_ptr<mlir::Pass> createAbstractResultOptPass();
 std::unique_ptr<mlir::Pass> createAffineDemotionPass();
 std::unique_ptr<mlir::Pass>
 createArrayValueCopyPass(fir::ArrayValueCopyOptions options = {});
diff --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td
index 187796d77cf5c1..06887091a1d3ac 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.td
+++ b/flang/include/flang/Optimizer/Transforms/Passes.td
@@ -16,8 +16,8 @@
 
 include "mlir/Pass/PassBase.td"
 
-class AbstractResultOptBase<string optExt, string operation> 
-  : Pass<"abstract-result-on-" # optExt # "-opt", operation> {
+def AbstractResultOpt
+  : Pass<"abstract-result"> {
   let summary = "Convert fir.array, fir.box and fir.rec function result to "
                 "function argument";
   let description = [{
@@ -33,14 +33,7 @@ class AbstractResultOptBase<string optExt, string operation>
            "Pass fir.array<T> result as fir.box<fir.array<T>> argument instead"
            " of fir.ref<fir.array<T>>.">
   ];
-}
-
-def AbstractResultOnFuncOpt : AbstractResultOptBase<"func", "mlir::func::FuncOp"> {
-  let constructor = "::fir::createAbstractResultOnFuncOptPass()";
-}
-
-def AbstractResultOnGlobalOpt : AbstractResultOptBase<"global", "fir::GlobalOp"> {
-  let constructor = "::fir::createAbstractResultOnGlobalOptPass()";
+  let constructor = "::fir::createAbstractResultOptPass()";
 }
 
 def AffineDialectPromotion : Pass<"promote-to-affine", "::mlir::func::FuncOp"> {
diff --git a/flang/include/flang/Tools/CLOptions.inc b/flang/include/flang/Tools/CLOptions.inc
index 268d00b5a60535..2735a0944e8e9e 100644
--- a/flang/include/flang/Tools/CLOptions.inc
+++ b/flang/include/flang/Tools/CLOptions.inc
@@ -19,6 +19,7 @@
 #include "flang/Optimizer/Transforms/Passes.h"
 #include "llvm/Passes/OptimizationLevel.h"
 #include "llvm/Support/CommandLine.h"
+#include <type_traits>
 
 #define DisableOption(DOName, DOOption, DODescription) \
   static llvm::cl::opt<bool> disable##DOName("disable-" DOOption, \
@@ -86,6 +87,31 @@ DisableOption(BoxedProcedureRewrite, "boxed-procedure-rewrite",
 DisableOption(ExternalNameConversion, "external-name-interop",
     "convert names with external convention");
 
+// TODO: remove once these are used for non-codegen passes
+#if !defined(FLANG_EXCLUDE_CODEGEN)
+using PassConstructor = std::function<std::unique_ptr<mlir::Pass>()>;
+
+template <typename OP>
+void addNestedPassToOps(mlir::PassManager &pm, PassConstructor ctor) {
+  pm.addNestedPass<OP>(ctor());
+}
+
+template <typename OP, typename... OPS,
+    typename = std::enable_if_t<sizeof...(OPS) != 0>>
+void addNestedPassToOps(mlir::PassManager &pm, PassConstructor ctor) {
+  addNestedPassToOps<OP>(pm, ctor);
+  addNestedPassToOps<OPS...>(pm, ctor);
+}
+
+void addNestedPassToAllTopLevelOperations(
+    mlir::PassManager &pm, PassConstructor ctor) {
+  // TODO: add more operations that might need full lowering support
+  // any operations also need to be added to fir::isa_toplevel
+  addNestedPassToOps<mlir::func::FuncOp, mlir::omp::DeclareReductionOp,
+      fir::GlobalOp>(pm, ctor);
+}
+#endif
+
 /// Generic for adding a pass to the pass manager if it is not disabled.
 template <typename F>
 void addPassConditionally(
@@ -304,9 +330,7 @@ inline void createDebugPasses(
 inline void createDefaultFIRCodeGenPassPipeline(
     mlir::PassManager &pm, MLIRToLLVMPassPipelineConfig config) {
   fir::addBoxedProcedurePass(pm);
-  pm.addNestedPass<mlir::func::FuncOp>(
-      fir::createAbstractResultOnFuncOptPass());
-  pm.addNestedPass<fir::GlobalOp>(fir::createAbstractResultOnGlobalOptPass());
+  addNestedPassToAllTopLevelOperations(pm, fir::createAbstractResultOptPass);
   fir::addCodeGenRewritePass(pm);
   fir::addTargetRewritePass(pm);
   fir::addExternalNameConversionPass(pm, config.Underscoring);
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index 88710880174d21..0bbbf59dbb352a 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -3846,6 +3846,19 @@ std::optional<std::int64_t> fir::getIntIfConstant(mlir::Value value) {
   return {};
 }
 
+bool fir::isa_toplevel(mlir::RegisteredOperationName opName) {
+  const std::initializer_list<llvm::StringLiteral> topLevelOps{
+      fir::GlobalOp::getOperationName(),
+      mlir::func::FuncOp::getOperationName(),
+      mlir::omp::DeclareReductionOp::getOperationName(),
+  };
+
+  llvm::StringRef opStr = opName.getStringRef();
+  return llvm::any_of(topLevelOps, [&](const llvm::StringRef &topLevelOp) {
+    return opStr == topLevelOp;
+  });
+}
+
 mlir::Type fir::applyPathToType(mlir::Type eleTy, mlir::ValueRange path) {
   for (auto i = path.begin(), end = path.end(); eleTy && i < end;) {
     eleTy = llvm::TypeSwitch<mlir::Type, mlir::Type>(eleTy)
diff --git a/flang/lib/Optimizer/Transforms/AbstractResult.cpp b/flang/lib/Optimizer/Transforms/AbstractResult.cpp
index dd1ddd16f2ded5..e295694f84d3fc 100644
--- a/flang/lib/Optimizer/Transforms/AbstractResult.cpp
+++ b/flang/lib/Optimizer/Transforms/AbstractResult.cpp
@@ -21,8 +21,7 @@
 #include "llvm/ADT/TypeSwitch.h"
 
 namespace fir {
-#define GEN_PASS_DEF_ABSTRACTRESULTONFUNCOPT
-#define GEN_PASS_DEF_ABSTRACTRESULTONGLOBALOPT
+#define GEN_PASS_DEF_ABSTRACTRESULTOPT
 #include "flang/Optimizer/Transforms/Passes.h.inc"
 } // namespace fir
 
@@ -285,58 +284,8 @@ class AddrOfOpConversion : public mlir::OpRewritePattern<fir::AddrOfOp> {
   bool shouldBoxResult;
 };
 
-/// @brief Base CRTP class for AbstractResult pass family.
-/// Contains common logic for abstract result conversion in a reusable fashion.
-/// @tparam Pass target class that implements operation-specific logic.
-/// @tparam PassBase base class template for the pass generated by TableGen.
-/// The `Pass` class must define runOnSpecificOperation(OpTy, bool,
-/// mlir::RewritePatternSet&, mlir::ConversionTarget&) member function.
-/// This function should implement operation-specific functionality.
-template <typename Pass, template <typename> class PassBase>
-class AbstractResultOptTemplate : public PassBase<Pass> {
-public:
-  void runOnOperation() override {
-    auto *context = &this->getContext();
-    auto op = this->getOperation();
-
-    mlir::RewritePatternSet patterns(context);
-    mlir::ConversionTarget target = *context;
-    const bool shouldBoxResult = this->passResultAsBox.getValue();
-
-    auto &self = static_cast<Pass &>(*this);
-    self.runOnSpecificOperation(op, shouldBoxResult, patterns, target);
-
-    // Convert the calls and, if needed,  the ReturnOp in the function body.
-    target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect,
-                           mlir::func::FuncDialect>();
-    target.addIllegalOp<fir::SaveResultOp>();
-    target.addDynamicallyLegalOp<fir::CallOp>([](fir::CallOp call) {
-      return !hasAbstractResult(call.getFunctionType());
-    });
-    target.addDynamicallyLegalOp<fir::AddrOfOp>([](fir::AddrOfOp addrOf) {
-      if (auto funTy = addrOf.getType().dyn_cast<mlir::FunctionType>())
-        return !hasAbstractResult(funTy);
-      return true;
-    });
-    target.addDynamicallyLegalOp<fir::DispatchOp>([](fir::DispatchOp dispatch) {
-      return !hasAbstractResult(dispatch.getFunctionType());
-    });
-
-    patterns.insert<CallConversion<fir::CallOp>>(context, shouldBoxResult);
-    patterns.insert<CallConversion<fir::DispatchOp>>(context, shouldBoxResult);
-    patterns.insert<SaveResultOpConversion>(context);
-    patterns.insert<AddrOfOpConversion>(context, shouldBoxResult);
-    if (mlir::failed(
-            mlir::applyPartialConversion(op, target, std::move(patterns)))) {
-      mlir::emitError(op.getLoc(), "error in converting abstract results\n");
-      this->signalPassFailure();
-    }
-  }
-};
-
-class AbstractResultOnFuncOpt
-    : public AbstractResultOptTemplate<AbstractResultOnFuncOpt,
-                                       fir::impl::AbstractResultOnFuncOptBase> {
+class AbstractResultOpt
+    : public fir::impl::AbstractResultOptBase<AbstractResultOpt> {
 public:
   void runOnSpecificOperation(mlir::func::FuncOp func, bool shouldBoxResult,
                               mlir::RewritePatternSet &patterns,
@@ -386,25 +335,20 @@ class AbstractResultOnFuncOpt
       }
     }
   }
-};
 
-inline static bool containsFunctionTypeWithAbstractResult(mlir::Type type) {
-  return mlir::TypeSwitch<mlir::Type, bool>(type)
-      .Case([](fir::BoxProcType boxProc) {
-        return fir::hasAbstractResult(
-            boxProc.getEleTy().cast<mlir::FunctionType>());
-      })
-      .Case([](fir::PointerType pointer) {
-        return fir::hasAbstractResult(
-            pointer.getEleTy().cast<mlir::FunctionType>());
-      })
-      .Default([](auto &&) { return false; });
-}
+  inline static bool containsFunctionTypeWithAbstractResult(mlir::Type type) {
+    return mlir::TypeSwitch<mlir::Type, bool>(type)
+        .Case([](fir::BoxProcType boxProc) {
+          return fir::hasAbstractResult(
+              boxProc.getEleTy().cast<mlir::FunctionType>());
+        })
+        .Case([](fir::PointerType pointer) {
+          return fir::hasAbstractResult(
+              pointer.getEleTy().cast<mlir::FunctionType>());
+        })
+        .Default([](auto &&) { return false; });
+  }
 
-class AbstractResultOnGlobalOpt
-    : public AbstractResultOptTemplate<
-          AbstractResultOnGlobalOpt, fir::impl::AbstractResultOnGlobalOptBase> {
-public:
   void runOnSpecificOperation(fir::GlobalOp global, bool,
                               mlir::RewritePatternSet &,
                               mlir::ConversionTarget &) {
@@ -412,14 +356,55 @@ class AbstractResultOnGlobalOpt
       TODO(global->getLoc(), "support for procedure pointers");
     }
   }
+
+  virtual bool canScheduleOn(RegisteredOperationName opName) const override {
+    return fir::isa_toplevel(opName);
+  }
+
+  void runOnOperation() override {
+    auto *context = &this->getContext();
+    mlir::Operation *op = this->getOperation();
+
+    mlir::RewritePatternSet patterns(context);
+    mlir::ConversionTarget target = *context;
+    const bool shouldBoxResult = this->passResultAsBox.getValue();
+
+    mlir::TypeSwitch<mlir::Operation *, void>(op)
+        .Case<mlir::func::FuncOp, fir::GlobalOp>([&](auto op) {
+          runOnSpecificOperation(op, shouldBoxResult, patterns, target);
+        });
+
+    // Convert the calls and, if needed,  the ReturnOp in the function body.
+    target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect,
+                           mlir::func::FuncDialect>();
+    target.addIllegalOp<fir::SaveResultOp>();
+    target.addDynamicallyLegalOp<fir::CallOp>([](fir::CallOp call) {
+      return !hasAbstractResult(call.getFunctionType());
+    });
+    target.addDynamicallyLegalOp<fir::AddrOfOp>([](fir::AddrOfOp addrOf) {
+      if (auto funTy = addrOf.getType().dyn_cast<mlir::FunctionType>())
+        return !hasAbstractResult(funTy);
+      return true;
+    });
+    target.addDynamicallyLegalOp<fir::DispatchOp>([](fir::DispatchOp dispatch) {
+      return !hasAbstractResult(dispatch.getFunctionType());
+    });
+
+    patterns.insert<CallConversion<fir::CallOp>>(context, shouldBoxResult);
+    patterns.insert<CallConversion<fir::DispatchOp>>(context, shouldBoxResult);
+    patterns.insert<SaveResultOpConversion>(context);
+    patterns.insert<AddrOfOpConversion>(context, shouldBoxResult);
+    if (mlir::failed(
+            mlir::applyPartialConversion(op, target, std::move(patterns)))) {
+      mlir::emitError(op->getLoc(), "error in converting abstract results\n");
+      this->signalPassFailure();
+    }
+  }
 };
+
 } // end anonymous namespace
 } // namespace fir
 
-std::unique_ptr<mlir::Pass> fir::createAbstractResultOnFuncOptPass() {
-  return std::make_unique<AbstractResultOnFuncOpt>();
-}
-
-std::unique_ptr<mlir::Pass> fir::createAbstractResultOnGlobalOptPass() {
-  return std::make_unique<AbstractResultOnGlobalOpt>();
+std::unique_ptr<mlir::Pass> fir::createAbstractResultOptPass() {
+  return std::make_unique<AbstractResultOpt>();
 }
diff --git a/flang/test/Driver/mlir-debug-pass-pipeline.f90 b/flang/test/Driver/mlir-debug-pass-pipeline.f90
index 04d432f854ca35..ef84cb80ecf1db 100644
--- a/flang/test/Driver/mlir-debug-pass-pipeline.f90
+++ b/flang/test/Driver/mlir-debug-pass-pipeline.f90
@@ -72,11 +72,13 @@
 ! ALL-NEXT:   (S) 0 num-dce'd - Number of operations DCE'd
 ! ALL-NEXT: BoxedProcedurePass
 
-! ALL-NEXT: Pipeline Collection : ['fir.global', 'func.func']
+! ALL-NEXT: Pipeline Collection : ['fir.global', 'func.func', 'omp.declare_reduction']
 ! ALL-NEXT:   'fir.global' Pipeline
-! ALL-NEXT:   AbstractResultOnGlobalOpt
+! ALL-NEXT:     AbstractResultOpt
 ! ALL-NEXT:   'func.func' Pipeline
-! ALL-NEXT:   AbstractResultOnFuncOpt
+! ALL-NEXT:     AbstractResultOpt
+! ALL-NEXT:   'omp.declare_reduction' Pipeline
+! ALL-NEXT:     AbstractResultOpt
 
 ! ALL-NEXT: CodeGenRewrite
 ! ALL-NEXT:   (S) 0 num-dce'd - Number of operations eliminated
diff --git a/flang/test/Driver/mlir-pass-pipeline.f90 b/flang/test/Driver/mlir-pass-pipeline.f90
index cfa0de63cde5e8..d1ff2869b0a6a9 100644
--- a/flang/test/Driver/mlir-pass-pipeline.f90
+++ b/flang/test/Driver/mlir-pass-pipeline.f90
@@ -67,11 +67,13 @@
 ! ALL-NEXT:   (S) 0 num-dce'd - Number of operations DCE'd
 ! ALL-NEXT: BoxedProcedurePass
 
-! ALL-NEXT: Pipeline Collection : ['fir.global', 'func.func']
+! ALL-NEXT: Pipeline Collection : ['fir.global', 'func.func', 'omp.declare_reduction']
 ! ALL-NEXT:   'fir.global' Pipeline
-! ALL-NEXT:    AbstractResultOnGlobalOpt
+! ALL-NEXT:    AbstractResultOpt
 ! ALL-NEXT:  'func.func' Pipeline
-! ALL-NEXT:    AbstractResultOnFuncOpt
+! ALL-NEXT:    AbstractResultOpt
+! ALL-NEXT:  'omp.declare_reduction' Pipeline
+! ALL-NEXT:    AbstractResultOpt
 
 ! ALL-NEXT: CodeGenRewrite
 ! ALL-NEXT:   (S) 0 num-dce'd - Number of operations eliminated
diff --git a/flang/test/Fir/abstract-result-2.fir b/flang/test/Fir/abstract-result-2.fir
index 08b723b8305936..d0cba7a9a63431 100644
--- a/flang/test/Fir/abstract-result-2.fir
+++ b/flang/test/Fir/abstract-result-2.fir
@@ -1,4 +1,4 @@
-// RUN: fir-opt %s --abstract-result-on-func-opt | FileCheck %s 
+// RUN: fir-opt %s --pass-pipeline="builtin.module(func.func(abstract-result))" | FileCheck %s
 
 // Check that the attributes are shifted along with their corresponding arguments
 
diff --git a/flang/test/Fir/abstract-results.fir b/flang/test/Fir/abstract-results.fir
index 42ff2a5c8eb2a8..4aac7f70d21039 100644
--- a/flang/test/Fir/abstract-results.fir
+++ b/flang/test/Fir/abstract-results.fir
@@ -1,10 +1,10 @@
 // Test rewrite of functions that return fir.array<>, fir.type<>, fir.box<> to
 // functions that take an additional argument for the result.
 
-// RUN: fir-opt %s --abstract-result-on-func-opt | FileCheck %s --check-prefix=FUNC-REF
-// RUN: fir-opt %s --abstract-result-on-func-opt=abstract-result-as-box | FileCheck %s --check-prefix=FUNC-BOX
-// RUN: fir-opt %s --abstract-result-on-global-opt | FileCheck %s --check-prefix=GLOBAL-REF
-// RUN: fir-opt %s --abstract-result-on-global-opt=abstract-result-as-box | FileCheck %s --check-prefix=GLOBAL-BOX
+// RUN: fir-opt %s --pass-pipeline="builtin.module(func.func(abstract-result))" | FileCheck %s --check-prefix=FUNC-REF
+// RUN: fir-opt %s --pass-pipeline="builtin.module(func.func(abstract-result{abstract-result-as-box}))" | FileCheck %s --check-prefix=FUNC-BOX
+// RUN: fir-opt %s --pass-pipeline="builtin.module(fir.global(abstract-result))" | FileCheck %s --check-prefix=GLOBAL-REF
+// RUN: fir-opt %s --pass-pipeline="builtin.module(fir.global(abstract-result{abstract-result-as-box}))" | FileCheck %s --check-prefix=GLOBAL-BOX
 
 // ----------------------- Test declaration rewrite ----------------------------
 
diff --git a/flang/test/Fir/basic-program.fir b/flang/test/Fir/basic-program.fir
index 80d3520bc7f7d4..28c597fc918cd7 100644
--- a/flang/test/Fir/basic-program.fir
+++ b/flang/test/Fir/basic-program.fir
@@ -74,11 +74,13 @@ func.func @_QQmain() {
 // PASSES-NEXT:   (S) 0 num-dce'd - Number of operations DCE'd
 // PASSES-NEXT: BoxedProcedurePass
 
-// PASSES-NEXT: Pipeline Collection : ['fir.global', 'func.func']
+// PASSES-NEXT: Pipeline Collection : ['fir.global', 'func.func', 'omp.declare_reduction']
 // PASSES-NEXT:   'fir.global' Pipeline
-// PASSES-NEXT:    AbstractResultOnGlobalOpt
+// PASSES-NEXT:    AbstractResultOpt
 // PASSES-NEXT:  'func.func' Pipeline
-// PASSES-NEXT:    AbstractResultOnFuncOpt
+// PASSES-NEXT:    AbstractResultOpt
+// PASSES-NEXT:  'omp.declare_reduction' Pipeline
+// PASSES-NEXT:    AbstractResultOpt
 
 // PASSES-NEXT: CodeGenRewrite
 // PASSES-NEXT:   (S) 0 num-dce'd - Number of operations eliminated
diff --git a/flang/test/Fir/non-trivial-procedure-binding-description.f90 b/flang/test/Fir/non-trivial-procedure-binding-description.f90
index 695d7fdfe232d3..f59248961d2ea1 100644
--- a/flang/test/Fir/non-trivial-procedure-binding-description.f90
+++ b/flang/test/Fir/non-trivial-procedure-binding-description.f90
@@ -1,5 +1,5 @@
 ! RUN: %flang_fc1 -emit-mlir %s -o - | FileCheck %s --check-prefix=BEFORE
-! RUN: %flang_fc1 -emit-mlir %s -o - | fir-opt --abstract-result-on-global-opt | FileCheck %s --check-prefix=AFTER
+! RUN: %flang_fc1 -emit-mlir %s -o - | fir-opt --pass-pipeline="builtin.module(fir.global(abstract-result))" | FileCheck %s --check-prefix=AFTER
 module a
   type f
   contains

>From d6210ea9188644c4dbfb9c2f34eee52b41894f9c Mon Sep 17 00:00:00 2001
From: Tom Eccles <t at freedommail.info>
Date: Tue, 16 Apr 2024 16:55:00 +0100
Subject: [PATCH 2/4] Use llvm::function_ref
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Markus Böck <markus.boeck02 at gmail.com>
---
 flang/include/flang/Tools/CLOptions.inc | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/flang/include/flang/Tools/CLOptions.inc b/flang/include/flang/Tools/CLOptions.inc
index 2735a0944e8e9e..b63da792279127 100644
--- a/flang/include/flang/Tools/CLOptions.inc
+++ b/flang/include/flang/Tools/CLOptions.inc
@@ -89,7 +89,7 @@ DisableOption(ExternalNameConversion, "external-name-interop",
 
 // TODO: remove once these are used for non-codegen passes
 #if !defined(FLANG_EXCLUDE_CODEGEN)
-using PassConstructor = std::function<std::unique_ptr<mlir::Pass>()>;
+using PassConstructor = llvm::function_ref<std::unique_ptr<mlir::Pass>()>;
 
 template <typename OP>
 void addNestedPassToOps(mlir::PassManager &pm, PassConstructor ctor) {

>From d699cf8a292a628bbb813dbebddd7cd9fd916487 Mon Sep 17 00:00:00 2001
From: Tom Eccles <tom.eccles at arm.com>
Date: Fri, 19 Apr 2024 16:13:19 +0000
Subject: [PATCH 3/4] Support running the pass on modules

---
 .../flang/Optimizer/Dialect/FIROpsSupport.h   |  7 ----
 flang/include/flang/Tools/CLOptions.inc       |  2 --
 flang/lib/Optimizer/Dialect/FIROps.cpp        | 13 --------
 .../Optimizer/Transforms/AbstractResult.cpp   | 32 +++++++++++++++++--
 flang/test/Fir/abstract-result-2.fir          |  2 +-
 flang/test/Fir/abstract-results.fir           |  8 ++---
 ...-trivial-procedure-binding-description.f90 |  2 +-
 7 files changed, 35 insertions(+), 31 deletions(-)

diff --git a/flang/include/flang/Optimizer/Dialect/FIROpsSupport.h b/flang/include/flang/Optimizer/Dialect/FIROpsSupport.h
index 44f2985e573785..3266ea3aa7fdc6 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROpsSupport.h
+++ b/flang/include/flang/Optimizer/Dialect/FIROpsSupport.h
@@ -173,13 +173,6 @@ inline mlir::NamedAttribute getAdaptToByRefAttr(Builder &builder) {
           builder.getUnitAttr()};
 }
 
-/// Returns true if the operation name is for a container operation expected to
-/// contain (HL)FIR operations which need to be lowered by FIR passes. The
-/// simplest example of this is func.func.
-/// This operates on mlir::RegisteredOperationName so that it can be used to
-/// implement mlir::Pass::canScheduleOn.
-bool isa_toplevel(mlir::RegisteredOperationName opName);
-
 } // namespace fir
 
 #endif // FORTRAN_OPTIMIZER_DIALECT_FIROPSSUPPORT_H
diff --git a/flang/include/flang/Tools/CLOptions.inc b/flang/include/flang/Tools/CLOptions.inc
index b63da792279127..bb30b5cfb4a850 100644
--- a/flang/include/flang/Tools/CLOptions.inc
+++ b/flang/include/flang/Tools/CLOptions.inc
@@ -105,8 +105,6 @@ void addNestedPassToOps(mlir::PassManager &pm, PassConstructor ctor) {
 
 void addNestedPassToAllTopLevelOperations(
     mlir::PassManager &pm, PassConstructor ctor) {
-  // TODO: add more operations that might need full lowering support
-  // any operations also need to be added to fir::isa_toplevel
   addNestedPassToOps<mlir::func::FuncOp, mlir::omp::DeclareReductionOp,
       fir::GlobalOp>(pm, ctor);
 }
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index 0bbbf59dbb352a..88710880174d21 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -3846,19 +3846,6 @@ std::optional<std::int64_t> fir::getIntIfConstant(mlir::Value value) {
   return {};
 }
 
-bool fir::isa_toplevel(mlir::RegisteredOperationName opName) {
-  const std::initializer_list<llvm::StringLiteral> topLevelOps{
-      fir::GlobalOp::getOperationName(),
-      mlir::func::FuncOp::getOperationName(),
-      mlir::omp::DeclareReductionOp::getOperationName(),
-  };
-
-  llvm::StringRef opStr = opName.getStringRef();
-  return llvm::any_of(topLevelOps, [&](const llvm::StringRef &topLevelOp) {
-    return opStr == topLevelOp;
-  });
-}
-
 mlir::Type fir::applyPathToType(mlir::Type eleTy, mlir::ValueRange path) {
   for (auto i = path.begin(), end = path.end(); eleTy && i < end;) {
     eleTy = llvm::TypeSwitch<mlir::Type, mlir::Type>(eleTy)
diff --git a/flang/lib/Optimizer/Transforms/AbstractResult.cpp b/flang/lib/Optimizer/Transforms/AbstractResult.cpp
index e295694f84d3fc..636c8aabc7af32 100644
--- a/flang/lib/Optimizer/Transforms/AbstractResult.cpp
+++ b/flang/lib/Optimizer/Transforms/AbstractResult.cpp
@@ -16,8 +16,8 @@
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
 #include "mlir/Transforms/DialectConversion.h"
-#include "mlir/Transforms/Passes.h"
 #include "llvm/ADT/TypeSwitch.h"
 
 namespace fir {
@@ -357,13 +357,39 @@ class AbstractResultOpt
     }
   }
 
-  virtual bool canScheduleOn(RegisteredOperationName opName) const override {
-    return fir::isa_toplevel(opName);
+  /// Run the pass on a ModuleOp. This makes fir-opt --abstract-result work.
+  void runOnModule() {
+    mlir::ModuleOp mod = mlir::cast<mlir::ModuleOp>(getOperation());
+
+    auto pass = std::make_unique<AbstractResultOpt>();
+    pass->copyOptionValuesFrom(this);
+    mlir::OpPassManager pipeline;
+    pipeline.addPass(std::unique_ptr<mlir::Pass>{pass.release()});
+
+    // Run the pass on all operations directly nested inside of the ModuleOp
+    // we can't just call runOnSpecificOperation here because the pass
+    // implementation only works when scoped to a particular func.func or
+    // fir.global
+    for (mlir::Region &region : mod->getRegions()) {
+      for (mlir::Block &block : region.getBlocks()) {
+        for (mlir::Operation &op : block.getOperations()) {
+          if (mlir::failed(runPipeline(pipeline, &op))) {
+            mlir::emitError(op.getLoc(), "Failed to run abstract result pass");
+            signalPassFailure();
+            return;
+          }
+        }
+      }
+    }
   }
 
   void runOnOperation() override {
     auto *context = &this->getContext();
     mlir::Operation *op = this->getOperation();
+    if (mlir::isa<mlir::ModuleOp>(op)) {
+      runOnModule();
+      return;
+    }
 
     mlir::RewritePatternSet patterns(context);
     mlir::ConversionTarget target = *context;
diff --git a/flang/test/Fir/abstract-result-2.fir b/flang/test/Fir/abstract-result-2.fir
index d0cba7a9a63431..af13d57476e8c0 100644
--- a/flang/test/Fir/abstract-result-2.fir
+++ b/flang/test/Fir/abstract-result-2.fir
@@ -1,4 +1,4 @@
-// RUN: fir-opt %s --pass-pipeline="builtin.module(func.func(abstract-result))" | FileCheck %s
+// RUN: fir-opt %s --abstract-result | FileCheck %s
 
 // Check that the attributes are shifted along with their corresponding arguments
 
diff --git a/flang/test/Fir/abstract-results.fir b/flang/test/Fir/abstract-results.fir
index 4aac7f70d21039..82f1cd33073fd3 100644
--- a/flang/test/Fir/abstract-results.fir
+++ b/flang/test/Fir/abstract-results.fir
@@ -1,10 +1,10 @@
 // Test rewrite of functions that return fir.array<>, fir.type<>, fir.box<> to
 // functions that take an additional argument for the result.
 
-// RUN: fir-opt %s --pass-pipeline="builtin.module(func.func(abstract-result))" | FileCheck %s --check-prefix=FUNC-REF
-// RUN: fir-opt %s --pass-pipeline="builtin.module(func.func(abstract-result{abstract-result-as-box}))" | FileCheck %s --check-prefix=FUNC-BOX
-// RUN: fir-opt %s --pass-pipeline="builtin.module(fir.global(abstract-result))" | FileCheck %s --check-prefix=GLOBAL-REF
-// RUN: fir-opt %s --pass-pipeline="builtin.module(fir.global(abstract-result{abstract-result-as-box}))" | FileCheck %s --check-prefix=GLOBAL-BOX
+// RUN: fir-opt %s --abstract-result | FileCheck %s --check-prefix=FUNC-REF
+// RUN: fir-opt %s --abstract-result=abstract-result-as-box | FileCheck %s --check-prefix=FUNC-BOX
+// RUN: fir-opt %s --abstract-result | FileCheck %s --check-prefix=GLOBAL-REF
+// RUN: fir-opt %s --abstract-result=abstract-result-as-box | FileCheck %s --check-prefix=GLOBAL-BOX
 
 // ----------------------- Test declaration rewrite ----------------------------
 
diff --git a/flang/test/Fir/non-trivial-procedure-binding-description.f90 b/flang/test/Fir/non-trivial-procedure-binding-description.f90
index f59248961d2ea1..668928600157b1 100644
--- a/flang/test/Fir/non-trivial-procedure-binding-description.f90
+++ b/flang/test/Fir/non-trivial-procedure-binding-description.f90
@@ -1,5 +1,5 @@
 ! RUN: %flang_fc1 -emit-mlir %s -o - | FileCheck %s --check-prefix=BEFORE
-! RUN: %flang_fc1 -emit-mlir %s -o - | fir-opt --pass-pipeline="builtin.module(fir.global(abstract-result))" | FileCheck %s --check-prefix=AFTER
+! RUN: %flang_fc1 -emit-mlir %s -o - | fir-opt --abstract-result | FileCheck %s --check-prefix=AFTER
 module a
   type f
   contains

>From a8c28942c0125ac30fe50afddeab07a38eb651a8 Mon Sep 17 00:00:00 2001
From: Tom Eccles <tom.eccles at arm.com>
Date: Fri, 19 Apr 2024 17:02:51 +0000
Subject: [PATCH 4/4] Don't use tablegen to make the factory function

---
 flang/include/flang/Optimizer/Transforms/Passes.h  | 1 -
 flang/include/flang/Optimizer/Transforms/Passes.td | 1 -
 flang/include/flang/Tools/CLOptions.inc            | 4 ++--
 flang/lib/Optimizer/Transforms/AbstractResult.cpp  | 9 ++++-----
 4 files changed, 6 insertions(+), 9 deletions(-)

diff --git a/flang/include/flang/Optimizer/Transforms/Passes.h b/flang/include/flang/Optimizer/Transforms/Passes.h
index 8520324e5491e1..4d290d87d4cc95 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.h
+++ b/flang/include/flang/Optimizer/Transforms/Passes.h
@@ -49,7 +49,6 @@ namespace fir {
 #define GEN_PASS_DECL_OPENACCDATAOPERANDCONVERSION
 #include "flang/Optimizer/Transforms/Passes.h.inc"
 
-std::unique_ptr<mlir::Pass> createAbstractResultOptPass();
 std::unique_ptr<mlir::Pass> createAffineDemotionPass();
 std::unique_ptr<mlir::Pass>
 createArrayValueCopyPass(fir::ArrayValueCopyOptions options = {});
diff --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td
index 06887091a1d3ac..f3ef6ceeb8f8fa 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.td
+++ b/flang/include/flang/Optimizer/Transforms/Passes.td
@@ -33,7 +33,6 @@ def AbstractResultOpt
            "Pass fir.array<T> result as fir.box<fir.array<T>> argument instead"
            " of fir.ref<fir.array<T>>.">
   ];
-  let constructor = "::fir::createAbstractResultOptPass()";
 }
 
 def AffineDialectPromotion : Pass<"promote-to-affine", "::mlir::func::FuncOp"> {
diff --git a/flang/include/flang/Tools/CLOptions.inc b/flang/include/flang/Tools/CLOptions.inc
index bb30b5cfb4a850..c043fcfb55fce2 100644
--- a/flang/include/flang/Tools/CLOptions.inc
+++ b/flang/include/flang/Tools/CLOptions.inc
@@ -89,7 +89,7 @@ DisableOption(ExternalNameConversion, "external-name-interop",
 
 // TODO: remove once these are used for non-codegen passes
 #if !defined(FLANG_EXCLUDE_CODEGEN)
-using PassConstructor = llvm::function_ref<std::unique_ptr<mlir::Pass>()>;
+using PassConstructor = std::unique_ptr<mlir::Pass>();
 
 template <typename OP>
 void addNestedPassToOps(mlir::PassManager &pm, PassConstructor ctor) {
@@ -328,7 +328,7 @@ inline void createDebugPasses(
 inline void createDefaultFIRCodeGenPassPipeline(
     mlir::PassManager &pm, MLIRToLLVMPassPipelineConfig config) {
   fir::addBoxedProcedurePass(pm);
-  addNestedPassToAllTopLevelOperations(pm, fir::createAbstractResultOptPass);
+  addNestedPassToAllTopLevelOperations(pm, fir::createAbstractResultOpt);
   fir::addCodeGenRewritePass(pm);
   fir::addTargetRewritePass(pm);
   fir::addExternalNameConversionPass(pm, config.Underscoring);
diff --git a/flang/lib/Optimizer/Transforms/AbstractResult.cpp b/flang/lib/Optimizer/Transforms/AbstractResult.cpp
index 636c8aabc7af32..eb4dd637bb167e 100644
--- a/flang/lib/Optimizer/Transforms/AbstractResult.cpp
+++ b/flang/lib/Optimizer/Transforms/AbstractResult.cpp
@@ -287,6 +287,9 @@ class AddrOfOpConversion : public mlir::OpRewritePattern<fir::AddrOfOp> {
 class AbstractResultOpt
     : public fir::impl::AbstractResultOptBase<AbstractResultOpt> {
 public:
+  using fir::impl::AbstractResultOptBase<
+      AbstractResultOpt>::AbstractResultOptBase;
+
   void runOnSpecificOperation(mlir::func::FuncOp func, bool shouldBoxResult,
                               mlir::RewritePatternSet &patterns,
                               mlir::ConversionTarget &target) {
@@ -429,8 +432,4 @@ class AbstractResultOpt
 };
 
 } // end anonymous namespace
-} // namespace fir
-
-std::unique_ptr<mlir::Pass> fir::createAbstractResultOptPass() {
-  return std::make_unique<AbstractResultOpt>();
-}
+} // namespace fir
\ No newline at end of file



More information about the flang-commits mailing list