[flang-commits] [flang] bfd1944 - [flang] de-duplicate AbstractResult pass (#88867)
via flang-commits
flang-commits at lists.llvm.org
Mon Apr 22 02:11:14 PDT 2024
Author: Tom Eccles
Date: 2024-04-22T10:11:09+01:00
New Revision: bfd19445c38a2ad6a1def7ee9a1f8ff26a159caf
URL: https://github.com/llvm/llvm-project/commit/bfd19445c38a2ad6a1def7ee9a1f8ff26a159caf
DIFF: https://github.com/llvm/llvm-project/commit/bfd19445c38a2ad6a1def7ee9a1f8ff26a159caf.diff
LOG: [flang] de-duplicate AbstractResult pass (#88867)
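This is the first proof of concept of modifying FIR codegen to fully
support a variety of top-level operations (beyond just func.func), as
proposed in
https://discourse.llvm.org/t/rfc-add-an-interface-for-top-level-container-operations
The separate AbstractResultOnFuncOpt and AbstractResultOnGlobalOpt passes are
merged into a single AbstractResultOpt pass (--abstract-result) that the
codegen pipeline schedules on every top-level operation kind.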
Added:
Modified:
flang/include/flang/Optimizer/Transforms/Passes.h
flang/include/flang/Optimizer/Transforms/Passes.td
flang/include/flang/Tools/CLOptions.inc
flang/lib/Optimizer/Transforms/AbstractResult.cpp
flang/test/Driver/mlir-debug-pass-pipeline.f90
flang/test/Driver/mlir-pass-pipeline.f90
flang/test/Fir/abstract-result-2.fir
flang/test/Fir/abstract-results.fir
flang/test/Fir/basic-program.fir
flang/test/Fir/non-trivial-procedure-binding-description.f90
Removed:
################################################################################
diff --git a/flang/include/flang/Optimizer/Transforms/Passes.h b/flang/include/flang/Optimizer/Transforms/Passes.h
index d8840d9e967b48..4d290d87d4cc95 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.h
+++ b/flang/include/flang/Optimizer/Transforms/Passes.h
@@ -31,8 +31,7 @@ namespace fir {
// Passes defined in Passes.td
//===----------------------------------------------------------------------===//
-#define GEN_PASS_DECL_ABSTRACTRESULTONFUNCOPT
-#define GEN_PASS_DECL_ABSTRACTRESULTONGLOBALOPT
+#define GEN_PASS_DECL_ABSTRACTRESULTOPT
#define GEN_PASS_DECL_AFFINEDIALECTPROMOTION
#define GEN_PASS_DECL_AFFINEDIALECTDEMOTION
#define GEN_PASS_DECL_ANNOTATECONSTANTOPERANDS
@@ -50,8 +49,6 @@ namespace fir {
#define GEN_PASS_DECL_OPENACCDATAOPERANDCONVERSION
#include "flang/Optimizer/Transforms/Passes.h.inc"
-std::unique_ptr<mlir::Pass> createAbstractResultOnFuncOptPass();
-std::unique_ptr<mlir::Pass> createAbstractResultOnGlobalOptPass();
std::unique_ptr<mlir::Pass> createAffineDemotionPass();
std::unique_ptr<mlir::Pass>
createArrayValueCopyPass(fir::ArrayValueCopyOptions options = {});
diff --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td
index bfc0db8124af21..467b7e1c472ec0 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.td
+++ b/flang/include/flang/Optimizer/Transforms/Passes.td
@@ -16,8 +16,8 @@
include "mlir/Pass/PassBase.td"
-class AbstractResultOptBase<string optExt, string operation>
- : Pass<"abstract-result-on-" # optExt # "-opt", operation> {
+def AbstractResultOpt
+ : Pass<"abstract-result"> {
let summary = "Convert fir.array, fir.box and fir.rec function result to "
"function argument";
let description = [{
@@ -35,14 +35,6 @@ class AbstractResultOptBase<string optExt, string operation>
];
}
-def AbstractResultOnFuncOpt : AbstractResultOptBase<"func", "mlir::func::FuncOp"> {
- let constructor = "::fir::createAbstractResultOnFuncOptPass()";
-}
-
-def AbstractResultOnGlobalOpt : AbstractResultOptBase<"global", "fir::GlobalOp"> {
- let constructor = "::fir::createAbstractResultOnGlobalOptPass()";
-}
-
def AffineDialectPromotion : Pass<"promote-to-affine", "::mlir::func::FuncOp"> {
let summary = "Promotes `fir.{do_loop,if}` to `affine.{for,if}`.";
let description = [{
diff --git a/flang/include/flang/Tools/CLOptions.inc b/flang/include/flang/Tools/CLOptions.inc
index ea297fb337a2c8..44ff2b3f70ff68 100644
--- a/flang/include/flang/Tools/CLOptions.inc
+++ b/flang/include/flang/Tools/CLOptions.inc
@@ -19,6 +19,7 @@
#include "flang/Optimizer/Transforms/Passes.h"
#include "llvm/Passes/OptimizationLevel.h"
#include "llvm/Support/CommandLine.h"
+#include <type_traits>
#define DisableOption(DOName, DOOption, DODescription) \
static llvm::cl::opt<bool> disable##DOName("disable-" DOOption, \
@@ -86,6 +87,29 @@ DisableOption(BoxedProcedureRewrite, "boxed-procedure-rewrite",
DisableOption(ExternalNameConversion, "external-name-interop",
"convert names with external convention");
+// Helpers for scheduling one pass on each kind of top-level container
+// operation (func.func, omp.declare_reduction, fir.global).
+// TODO: remove this guard once these are used for non-codegen passes.
+#if !defined(FLANG_EXCLUDE_CODEGEN)
+/// Signature of a factory returning a fresh pass instance, such as
+/// `fir::createAbstractResultOpt`.
+using PassConstructor = std::unique_ptr<mlir::Pass>();
+
+/// Nest a pass built by \p ctor under \p pm for every operation type in
+/// \p OPS, in the order listed. addNestedPass takes ownership of the pass it
+/// is given, so \p ctor is invoked once per operation type.
+template <typename... OPS>
+inline void addNestedPassToOps(mlir::PassManager &pm, PassConstructor ctor) {
+  (pm.addNestedPass<OPS>(ctor()), ...);
+}
+
+/// Schedule the pass built by \p ctor on func.func, omp.declare_reduction
+/// and fir.global operations. Declared `inline` because this file is a
+/// header-style .inc: a plain external definition here would be emitted by
+/// every translation unit that includes it.
+inline void addNestedPassToAllTopLevelOperations(
+    mlir::PassManager &pm, PassConstructor ctor) {
+  addNestedPassToOps<mlir::func::FuncOp, mlir::omp::DeclareReductionOp,
+      fir::GlobalOp>(pm, ctor);
+}
+#endif
+
/// Generic for adding a pass to the pass manager if it is not disabled.
template <typename F>
void addPassConditionally(
@@ -304,9 +328,7 @@ inline void createDebugPasses(
inline void createDefaultFIRCodeGenPassPipeline(
mlir::PassManager &pm, MLIRToLLVMPassPipelineConfig config) {
fir::addBoxedProcedurePass(pm);
- pm.addNestedPass<mlir::func::FuncOp>(
- fir::createAbstractResultOnFuncOptPass());
- pm.addNestedPass<fir::GlobalOp>(fir::createAbstractResultOnGlobalOptPass());
+ addNestedPassToAllTopLevelOperations(pm, fir::createAbstractResultOpt);
fir::addCodeGenRewritePass(pm);
fir::addTargetRewritePass(pm);
fir::addExternalNameConversionPass(pm, config.Underscoring);
diff --git a/flang/lib/Optimizer/Transforms/AbstractResult.cpp b/flang/lib/Optimizer/Transforms/AbstractResult.cpp
index dd1ddd16f2ded5..eb4dd637bb167e 100644
--- a/flang/lib/Optimizer/Transforms/AbstractResult.cpp
+++ b/flang/lib/Optimizer/Transforms/AbstractResult.cpp
@@ -16,13 +16,12 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/DialectConversion.h"
-#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/TypeSwitch.h"
namespace fir {
-#define GEN_PASS_DEF_ABSTRACTRESULTONFUNCOPT
-#define GEN_PASS_DEF_ABSTRACTRESULTONGLOBALOPT
+#define GEN_PASS_DEF_ABSTRACTRESULTOPT
#include "flang/Optimizer/Transforms/Passes.h.inc"
} // namespace fir
@@ -285,59 +284,12 @@ class AddrOfOpConversion : public mlir::OpRewritePattern<fir::AddrOfOp> {
bool shouldBoxResult;
};
-/// @brief Base CRTP class for AbstractResult pass family.
-/// Contains common logic for abstract result conversion in a reusable fashion.
-/// @tparam Pass target class that implements operation-specific logic.
-/// @tparam PassBase base class template for the pass generated by TableGen.
-/// The `Pass` class must define runOnSpecificOperation(OpTy, bool,
-/// mlir::RewritePatternSet&, mlir::ConversionTarget&) member function.
-/// This function should implement operation-specific functionality.
-template <typename Pass, template <typename> class PassBase>
-class AbstractResultOptTemplate : public PassBase<Pass> {
+class AbstractResultOpt
+ : public fir::impl::AbstractResultOptBase<AbstractResultOpt> {
public:
- void runOnOperation() override {
- auto *context = &this->getContext();
- auto op = this->getOperation();
-
- mlir::RewritePatternSet patterns(context);
- mlir::ConversionTarget target = *context;
- const bool shouldBoxResult = this->passResultAsBox.getValue();
-
- auto &self = static_cast<Pass &>(*this);
- self.runOnSpecificOperation(op, shouldBoxResult, patterns, target);
-
- // Convert the calls and, if needed, the ReturnOp in the function body.
- target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect,
- mlir::func::FuncDialect>();
- target.addIllegalOp<fir::SaveResultOp>();
- target.addDynamicallyLegalOp<fir::CallOp>([](fir::CallOp call) {
- return !hasAbstractResult(call.getFunctionType());
- });
- target.addDynamicallyLegalOp<fir::AddrOfOp>([](fir::AddrOfOp addrOf) {
- if (auto funTy = addrOf.getType().dyn_cast<mlir::FunctionType>())
- return !hasAbstractResult(funTy);
- return true;
- });
- target.addDynamicallyLegalOp<fir::DispatchOp>([](fir::DispatchOp dispatch) {
- return !hasAbstractResult(dispatch.getFunctionType());
- });
-
- patterns.insert<CallConversion<fir::CallOp>>(context, shouldBoxResult);
- patterns.insert<CallConversion<fir::DispatchOp>>(context, shouldBoxResult);
- patterns.insert<SaveResultOpConversion>(context);
- patterns.insert<AddrOfOpConversion>(context, shouldBoxResult);
- if (mlir::failed(
- mlir::applyPartialConversion(op, target, std::move(patterns)))) {
- mlir::emitError(op.getLoc(), "error in converting abstract results\n");
- this->signalPassFailure();
- }
- }
-};
+ using fir::impl::AbstractResultOptBase<
+ AbstractResultOpt>::AbstractResultOptBase;
-class AbstractResultOnFuncOpt
- : public AbstractResultOptTemplate<AbstractResultOnFuncOpt,
- fir::impl::AbstractResultOnFuncOptBase> {
-public:
void runOnSpecificOperation(mlir::func::FuncOp func, bool shouldBoxResult,
mlir::RewritePatternSet &patterns,
mlir::ConversionTarget &target) {
@@ -386,25 +338,20 @@ class AbstractResultOnFuncOpt
}
}
}
-};
-inline static bool containsFunctionTypeWithAbstractResult(mlir::Type type) {
- return mlir::TypeSwitch<mlir::Type, bool>(type)
- .Case([](fir::BoxProcType boxProc) {
- return fir::hasAbstractResult(
- boxProc.getEleTy().cast<mlir::FunctionType>());
- })
- .Case([](fir::PointerType pointer) {
- return fir::hasAbstractResult(
- pointer.getEleTy().cast<mlir::FunctionType>());
- })
- .Default([](auto &&) { return false; });
-}
+  /// Return true if \p type is a fir::BoxProcType or fir::PointerType whose
+  /// element type is a function type with an abstract result.
+  ///
+  /// An element type that is not an mlir::FunctionType (e.g. a pointer to
+  /// data) yields false instead of failing a checked cast.
+  static bool containsFunctionTypeWithAbstractResult(mlir::Type type) {
+    auto isAbstractResultFunction = [](mlir::Type eleTy) {
+      auto funcTy = mlir::dyn_cast<mlir::FunctionType>(eleTy);
+      return funcTy && fir::hasAbstractResult(funcTy);
+    };
+    return mlir::TypeSwitch<mlir::Type, bool>(type)
+        .Case([&](fir::BoxProcType boxProc) {
+          return isAbstractResultFunction(boxProc.getEleTy());
+        })
+        .Case([&](fir::PointerType pointer) {
+          return isAbstractResultFunction(pointer.getEleTy());
+        })
+        .Default([](auto &&) { return false; });
+  }
-class AbstractResultOnGlobalOpt
- : public AbstractResultOptTemplate<
- AbstractResultOnGlobalOpt, fir::impl::AbstractResultOnGlobalOptBase> {
-public:
void runOnSpecificOperation(fir::GlobalOp global, bool,
mlir::RewritePatternSet &,
mlir::ConversionTarget &) {
@@ -412,14 +359,77 @@ class AbstractResultOnGlobalOpt
TODO(global->getLoc(), "support for procedure pointers");
}
}
-};
-} // end anonymous namespace
-} // namespace fir
-std::unique_ptr<mlir::Pass> fir::createAbstractResultOnFuncOptPass() {
- return std::make_unique<AbstractResultOnFuncOpt>();
-}
+  /// Run the pass on a ModuleOp. This makes fir-opt --abstract-result work.
+  ///
+  /// The conversion is written to be scoped to a single func.func or
+  /// fir.global, so we cannot simply call runOnSpecificOperation here.
+  /// Instead, a nested pipeline holding a copy of this pass (with the same
+  /// option values) is run on every operation directly nested in the module.
+  /// The first failure emits an error, marks the pass as failed and stops.
+  void runOnModule() {
+    mlir::ModuleOp mod = mlir::cast<mlir::ModuleOp>(getOperation());
+
+    auto pass = std::make_unique<AbstractResultOpt>();
+    pass->copyOptionValuesFrom(this);
+    mlir::OpPassManager pipeline;
+    pipeline.addPass(std::move(pass));
+
+    for (mlir::Region &region : mod->getRegions()) {
+      for (mlir::Block &block : region.getBlocks()) {
+        for (mlir::Operation &op : block.getOperations()) {
+          if (mlir::failed(runPipeline(pipeline, &op))) {
+            mlir::emitError(op.getLoc(), "Failed to run abstract result pass");
+            signalPassFailure();
+            return;
+          }
+        }
+      }
+    }
+  }
-std::unique_ptr<mlir::Pass> fir::createAbstractResultOnGlobalOptPass() {
- return std::make_unique<AbstractResultOnGlobalOpt>();
-}
+  /// Convert abstract results in the anchored operation. A ModuleOp anchor is
+  /// delegated to runOnModule; any other operation gets its
+  /// operation-specific setup (func.func / fir.global only) followed by the
+  /// shared call, dispatch, save_result and address-of rewrites.
+  void runOnOperation() override {
+    mlir::Operation *op = getOperation();
+    if (mlir::isa<mlir::ModuleOp>(op)) {
+      runOnModule();
+      return;
+    }
+
+    mlir::MLIRContext *ctx = &getContext();
+    const bool boxResult = passResultAsBox.getValue();
+    mlir::RewritePatternSet patterns(ctx);
+    mlir::ConversionTarget target(*ctx);
+
+    // Only func.func and fir.global have a runOnSpecificOperation overload;
+    // other operations receive just the generic rewrites registered below.
+    mlir::TypeSwitch<mlir::Operation *, void>(op)
+        .Case<mlir::func::FuncOp, fir::GlobalOp>([&](auto specificOp) {
+          runOnSpecificOperation(specificOp, boxResult, patterns, target);
+        });
+
+    // Convert the calls and, if needed, the ReturnOp in the function body.
+    target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect,
+                           mlir::func::FuncDialect>();
+    target.addIllegalOp<fir::SaveResultOp>();
+    target.addDynamicallyLegalOp<fir::CallOp>([](fir::CallOp call) {
+      return !hasAbstractResult(call.getFunctionType());
+    });
+    target.addDynamicallyLegalOp<fir::AddrOfOp>([](fir::AddrOfOp addrOf) {
+      // Non-function addresses are always legal.
+      auto funTy = mlir::dyn_cast<mlir::FunctionType>(addrOf.getType());
+      return !funTy || !hasAbstractResult(funTy);
+    });
+    target.addDynamicallyLegalOp<fir::DispatchOp>([](fir::DispatchOp dispatch) {
+      return !hasAbstractResult(dispatch.getFunctionType());
+    });
+
+    patterns.insert<CallConversion<fir::CallOp>>(ctx, boxResult);
+    patterns.insert<CallConversion<fir::DispatchOp>>(ctx, boxResult);
+    patterns.insert<SaveResultOpConversion>(ctx);
+    patterns.insert<AddrOfOpConversion>(ctx, boxResult);
+
+    if (mlir::succeeded(
+            mlir::applyPartialConversion(op, target, std::move(patterns))))
+      return;
+    mlir::emitError(op->getLoc(), "error in converting abstract results\n");
+    signalPassFailure();
+  }
+};
+
+} // end anonymous namespace
+} // namespace fir
\ No newline at end of file
diff --git a/flang/test/Driver/mlir-debug-pass-pipeline.f90 b/flang/test/Driver/mlir-debug-pass-pipeline.f90
index 04d432f854ca35..ef84cb80ecf1db 100644
--- a/flang/test/Driver/mlir-debug-pass-pipeline.f90
+++ b/flang/test/Driver/mlir-debug-pass-pipeline.f90
@@ -72,11 +72,13 @@
! ALL-NEXT: (S) 0 num-dce'd - Number of operations DCE'd
! ALL-NEXT: BoxedProcedurePass
-! ALL-NEXT: Pipeline Collection : ['fir.global', 'func.func']
+! ALL-NEXT: Pipeline Collection : ['fir.global', 'func.func', 'omp.declare_reduction']
! ALL-NEXT: 'fir.global' Pipeline
-! ALL-NEXT: AbstractResultOnGlobalOpt
+! ALL-NEXT: AbstractResultOpt
! ALL-NEXT: 'func.func' Pipeline
-! ALL-NEXT: AbstractResultOnFuncOpt
+! ALL-NEXT: AbstractResultOpt
+! ALL-NEXT: 'omp.declare_reduction' Pipeline
+! ALL-NEXT: AbstractResultOpt
! ALL-NEXT: CodeGenRewrite
! ALL-NEXT: (S) 0 num-dce'd - Number of operations eliminated
diff --git a/flang/test/Driver/mlir-pass-pipeline.f90 b/flang/test/Driver/mlir-pass-pipeline.f90
index cfa0de63cde5e8..d1ff2869b0a6a9 100644
--- a/flang/test/Driver/mlir-pass-pipeline.f90
+++ b/flang/test/Driver/mlir-pass-pipeline.f90
@@ -67,11 +67,13 @@
! ALL-NEXT: (S) 0 num-dce'd - Number of operations DCE'd
! ALL-NEXT: BoxedProcedurePass
-! ALL-NEXT: Pipeline Collection : ['fir.global', 'func.func']
+! ALL-NEXT: Pipeline Collection : ['fir.global', 'func.func', 'omp.declare_reduction']
! ALL-NEXT: 'fir.global' Pipeline
-! ALL-NEXT: AbstractResultOnGlobalOpt
+! ALL-NEXT: AbstractResultOpt
! ALL-NEXT: 'func.func' Pipeline
-! ALL-NEXT: AbstractResultOnFuncOpt
+! ALL-NEXT: AbstractResultOpt
+! ALL-NEXT: 'omp.declare_reduction' Pipeline
+! ALL-NEXT: AbstractResultOpt
! ALL-NEXT: CodeGenRewrite
! ALL-NEXT: (S) 0 num-dce'd - Number of operations eliminated
diff --git a/flang/test/Fir/abstract-result-2.fir b/flang/test/Fir/abstract-result-2.fir
index 08b723b8305936..af13d57476e8c0 100644
--- a/flang/test/Fir/abstract-result-2.fir
+++ b/flang/test/Fir/abstract-result-2.fir
@@ -1,4 +1,4 @@
-// RUN: fir-opt %s --abstract-result-on-func-opt | FileCheck %s
+// RUN: fir-opt %s --abstract-result | FileCheck %s
// Check that the attributes are shifted along with their corresponding arguments
diff --git a/flang/test/Fir/abstract-results.fir b/flang/test/Fir/abstract-results.fir
index 42ff2a5c8eb2a8..82f1cd33073fd3 100644
--- a/flang/test/Fir/abstract-results.fir
+++ b/flang/test/Fir/abstract-results.fir
@@ -1,10 +1,10 @@
// Test rewrite of functions that return fir.array<>, fir.type<>, fir.box<> to
// functions that take an additional argument for the result.
-// RUN: fir-opt %s --abstract-result-on-func-opt | FileCheck %s --check-prefix=FUNC-REF
-// RUN: fir-opt %s --abstract-result-on-func-opt=abstract-result-as-box | FileCheck %s --check-prefix=FUNC-BOX
-// RUN: fir-opt %s --abstract-result-on-global-opt | FileCheck %s --check-prefix=GLOBAL-REF
-// RUN: fir-opt %s --abstract-result-on-global-opt=abstract-result-as-box | FileCheck %s --check-prefix=GLOBAL-BOX
+// RUN: fir-opt %s --abstract-result | FileCheck %s --check-prefix=FUNC-REF
+// RUN: fir-opt %s --abstract-result=abstract-result-as-box | FileCheck %s --check-prefix=FUNC-BOX
+// RUN: fir-opt %s --abstract-result | FileCheck %s --check-prefix=GLOBAL-REF
+// RUN: fir-opt %s --abstract-result=abstract-result-as-box | FileCheck %s --check-prefix=GLOBAL-BOX
// ----------------------- Test declaration rewrite ----------------------------
diff --git a/flang/test/Fir/basic-program.fir b/flang/test/Fir/basic-program.fir
index 80d3520bc7f7d4..28c597fc918cd7 100644
--- a/flang/test/Fir/basic-program.fir
+++ b/flang/test/Fir/basic-program.fir
@@ -74,11 +74,13 @@ func.func @_QQmain() {
// PASSES-NEXT: (S) 0 num-dce'd - Number of operations DCE'd
// PASSES-NEXT: BoxedProcedurePass
-// PASSES-NEXT: Pipeline Collection : ['fir.global', 'func.func']
+// PASSES-NEXT: Pipeline Collection : ['fir.global', 'func.func', 'omp.declare_reduction']
// PASSES-NEXT: 'fir.global' Pipeline
-// PASSES-NEXT: AbstractResultOnGlobalOpt
+// PASSES-NEXT: AbstractResultOpt
// PASSES-NEXT: 'func.func' Pipeline
-// PASSES-NEXT: AbstractResultOnFuncOpt
+// PASSES-NEXT: AbstractResultOpt
+// PASSES-NEXT: 'omp.declare_reduction' Pipeline
+// PASSES-NEXT: AbstractResultOpt
// PASSES-NEXT: CodeGenRewrite
// PASSES-NEXT: (S) 0 num-dce'd - Number of operations eliminated
diff --git a/flang/test/Fir/non-trivial-procedure-binding-description.f90 b/flang/test/Fir/non-trivial-procedure-binding-description.f90
index 695d7fdfe232d3..668928600157b1 100644
--- a/flang/test/Fir/non-trivial-procedure-binding-description.f90
+++ b/flang/test/Fir/non-trivial-procedure-binding-description.f90
@@ -1,5 +1,5 @@
! RUN: %flang_fc1 -emit-mlir %s -o - | FileCheck %s --check-prefix=BEFORE
-! RUN: %flang_fc1 -emit-mlir %s -o - | fir-opt --abstract-result-on-global-opt | FileCheck %s --check-prefix=AFTER
+! RUN: %flang_fc1 -emit-mlir %s -o - | fir-opt --abstract-result | FileCheck %s --check-prefix=AFTER
module a
type f
contains
More information about the flang-commits
mailing list