[llvm-branch-commits] [flang] [flang] Move OpenMP-related code from `FirConverter` to `OpenMPMixin` (PR #74866)
Krzysztof Parzyszek via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Sat Dec 9 13:05:16 PST 2023
https://github.com/kparzysz updated https://github.com/llvm/llvm-project/pull/74866
>From 27fab0c65445893fb27baead5573bad2dd690dfc Mon Sep 17 00:00:00 2001
From: Krzysztof Parzyszek <Krzysztof.Parzyszek at amd.com>
Date: Fri, 8 Dec 2023 09:13:11 -0600
Subject: [PATCH] [flang] Move OpenMP-related code from `FirConverter` to
`OpenMPMixin`
This improves the separation of the generic Fortran lowering and the
lowering of OpenMP constructs.
The mixin is intended to be derived from via CRTP:
```
class FirConverter : public OpenMPMixin<FirConverter> ...
```
The primary goal of the mixin is to implement `genFIR` functions
that the derived converter can then call via
```
std::visit([this](auto &&s) { genFIR(s); });
```
The mixin also expects a handful of functions to be present in the
derived class, most importantly `genFIR(Evaluation&)`, plus getter
functions for the lowering bridge, op builder, current evaluation,
and symbol table.
The pre-existing PFT-lowering functionality is preserved.
---
flang/lib/Lower/Bridge.cpp | 84 +---------------------
flang/lib/Lower/ConverterMixin.h | 28 ++++++++
flang/lib/Lower/FirConverter.h | 38 +++++-----
flang/lib/Lower/OpenMP.cpp | 118 ++++++++++++++++++++++++++++++-
flang/lib/Lower/OpenMPMixin.h | 66 +++++++++++++++++
5 files changed, 235 insertions(+), 99 deletions(-)
create mode 100644 flang/lib/Lower/ConverterMixin.h
create mode 100644 flang/lib/Lower/OpenMPMixin.h
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 5aaba233744b2d..0a476a38d8d2de 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -170,7 +170,7 @@ void FirConverter::run(Fortran::lower::pft::Program &pft) {
});
finalizeOpenACCLowering();
- finalizeOpenMPLowering(globalOmpRequiresSymbol);
+ OpenMPBase::finalize(globalOmpRequiresSymbol);
}
/// Generate FIR for Evaluation \p eval.
@@ -977,70 +977,6 @@ void FirConverter::genFIR(const Fortran::parser::OpenACCRoutineConstruct &acc) {
// Handled by genFIR(const Fortran::parser::OpenACCDeclarativeConstruct &)
}
-void FirConverter::genFIR(const Fortran::parser::OpenMPConstruct &omp) {
- mlir::OpBuilder::InsertPoint insertPt = builder->saveInsertionPoint();
- localSymbols.pushScope();
- genOpenMPConstruct(*this, bridge.getSemanticsContext(), getEval(), omp);
-
- const Fortran::parser::OpenMPLoopConstruct *ompLoop =
- std::get_if<Fortran::parser::OpenMPLoopConstruct>(&omp.u);
- const Fortran::parser::OpenMPBlockConstruct *ompBlock =
- std::get_if<Fortran::parser::OpenMPBlockConstruct>(&omp.u);
-
- // If loop is part of an OpenMP Construct then the OpenMP dialect
- // workshare loop operation has already been created. Only the
- // body needs to be created here and the do_loop can be skipped.
- // Skip the number of collapsed loops, which is 1 when there is a
- // no collapse requested.
-
- Fortran::lower::pft::Evaluation *curEval = &getEval();
- const Fortran::parser::OmpClauseList *loopOpClauseList = nullptr;
- if (ompLoop) {
- loopOpClauseList = &std::get<Fortran::parser::OmpClauseList>(
- std::get<Fortran::parser::OmpBeginLoopDirective>(ompLoop->t).t);
- int64_t collapseValue = Fortran::lower::getCollapseValue(*loopOpClauseList);
-
- curEval = &curEval->getFirstNestedEvaluation();
- for (int64_t i = 1; i < collapseValue; i++) {
- curEval = &*std::next(curEval->getNestedEvaluations().begin());
- }
- }
-
- for (Fortran::lower::pft::Evaluation &e : curEval->getNestedEvaluations())
- genFIR(e);
-
- if (ompLoop) {
- genOpenMPReduction(*this, *loopOpClauseList);
- } else if (ompBlock) {
- const auto &blockStart =
- std::get<Fortran::parser::OmpBeginBlockDirective>(ompBlock->t);
- const auto &blockClauses =
- std::get<Fortran::parser::OmpClauseList>(blockStart.t);
- genOpenMPReduction(*this, blockClauses);
- }
-
- localSymbols.popScope();
- builder->restoreInsertionPoint(insertPt);
-
- // Register if a target region was found
- ompDeviceCodeFound =
- ompDeviceCodeFound || Fortran::lower::isOpenMPTargetConstruct(omp);
-}
-
-void FirConverter::genFIR(
- const Fortran::parser::OpenMPDeclarativeConstruct &ompDecl) {
- mlir::OpBuilder::InsertPoint insertPt = builder->saveInsertionPoint();
- // Register if a declare target construct intended for a target device was
- // found
- ompDeviceCodeFound =
- ompDeviceCodeFound ||
- Fortran::lower::isOpenMPDeviceDeclareTarget(*this, getEval(), ompDecl);
- genOpenMPDeclarativeConstruct(*this, getEval(), ompDecl);
- for (Fortran::lower::pft::Evaluation &e : getEval().getNestedEvaluations())
- genFIR(e);
- builder->restoreInsertionPoint(insertPt);
-}
-
void FirConverter::genFIR(const Fortran::parser::OpenStmt &stmt) {
mlir::Value iostat = genOpenStatement(*this, stmt);
genIoConditionBranches(getEval(), stmt.v, iostat);
@@ -3752,13 +3688,7 @@ void FirConverter::instantiateVar(const Fortran::lower::pft::Variable &var,
Fortran::lower::AggregateStoreMap &storeMap) {
Fortran::lower::instantiateVariable(*this, var, localSymbols, storeMap);
if (var.hasSymbol()) {
- if (var.getSymbol().test(
- Fortran::semantics::Symbol::Flag::OmpThreadprivate))
- Fortran::lower::genThreadprivateOp(*this, var);
-
- if (var.getSymbol().test(
- Fortran::semantics::Symbol::Flag::OmpDeclareTarget))
- Fortran::lower::genDeclareTargetIntGlobal(*this, var);
+ OpenMPBase::instantiateVariable(*this, var);
}
}
@@ -4443,16 +4373,6 @@ void FirConverter::finalizeOpenACCLowering() {
accRoutineInfos);
}
-/// Performing OpenMP lowering actions that were deferred to the end of
-/// lowering.
-void FirConverter::finalizeOpenMPLowering(
- const Fortran::semantics::Symbol *globalOmpRequiresSymbol) {
- // Set the module attribute related to OpenMP requires directives
- if (ompDeviceCodeFound)
- Fortran::lower::genOpenMPRequires(getModuleOp().getOperation(),
- globalOmpRequiresSymbol);
-}
-
} // namespace Fortran::lower
Fortran::evaluate::FoldingContext
diff --git a/flang/lib/Lower/ConverterMixin.h b/flang/lib/Lower/ConverterMixin.h
new file mode 100644
index 00000000000000..a873ff36d0f600
--- /dev/null
+++ b/flang/lib/Lower/ConverterMixin.h
@@ -0,0 +1,28 @@
+//===-- ConverterMixin.h --------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef FORTRAN_LOWER_CONVERTERMIXIN_H
+#define FORTRAN_LOWER_CONVERTERMIXIN_H
+
+namespace Fortran::lower {
+
+/// CRTP base for converter mixins: a mixin derives from
+/// ConverterMixinBase<ConverterT> and reaches the converter through This().
+template <typename DerivedT> class ConverterMixinBase {
+public:
+  DerivedT *This() { return static_cast<DerivedT *>(this); }
+  const DerivedT *This() const { return static_cast<const DerivedT *>(this); }
+};
+
+} // namespace Fortran::lower
+
+#endif // FORTRAN_LOWER_CONVERTERMIXIN_H
diff --git a/flang/lib/Lower/FirConverter.h b/flang/lib/Lower/FirConverter.h
index 45806eda5df227..b6f3044aa3777d 100644
--- a/flang/lib/Lower/FirConverter.h
+++ b/flang/lib/Lower/FirConverter.h
@@ -13,6 +13,9 @@
#ifndef FORTRAN_LOWER_FIRCONVERTER_H
#define FORTRAN_LOWER_FIRCONVERTER_H
+#include "ConverterMixin.h"
+#include "OpenMPMixin.h"
+
#include "flang/Common/Fortran.h"
#include "flang/Lower/AbstractConverter.h"
#include "flang/Lower/Bridge.h"
@@ -74,7 +77,11 @@
namespace Fortran::lower {
-class FirConverter : public Fortran::lower::AbstractConverter {
+class FirConverter : public Fortran::lower::AbstractConverter,
+ public OpenMPMixin<FirConverter> {
+ using OpenMPBase = OpenMPMixin<FirConverter>;
+ using OpenMPBase::genFIR;
+
public:
explicit FirConverter(Fortran::lower::LoweringBridge &bridge)
: Fortran::lower::AbstractConverter(bridge.getLoweringOptions()),
@@ -83,6 +90,20 @@ class FirConverter : public Fortran::lower::AbstractConverter {
void run(Fortran::lower::pft::Program &pft);
+public:
+ // The interface that mixin is expecting.
+
+  Fortran::lower::SymMap &getSymTable() { return localSymbols; }
+  Fortran::lower::pft::Evaluation &getEval() {
+    assert(evalPtr);
+    return *evalPtr;
+  }
+  fir::FirOpBuilder &getBuilder() {
+    assert(builder);
+    return *builder;
+  }
+  Fortran::lower::LoweringBridge &getBridge() { return bridge; }
+
/// The core of the conversion: take an evaluation and generate FIR for it.
/// The generation for each individual element of PFT is done via a specific
/// genFIR function (see below).
@@ -141,8 +162,6 @@ class FirConverter : public Fortran::lower::AbstractConverter {
void genFIR(const Fortran::parser::OpenACCConstruct &);
void genFIR(const Fortran::parser::OpenACCDeclarativeConstruct &);
void genFIR(const Fortran::parser::OpenACCRoutineConstruct &);
- void genFIR(const Fortran::parser::OpenMPConstruct &);
- void genFIR(const Fortran::parser::OpenMPDeclarativeConstruct &);
void genFIR(const Fortran::parser::OpenStmt &);
void genFIR(const Fortran::parser::PauseStmt &);
void genFIR(const Fortran::parser::PointerAssignmentStmt &);
@@ -194,7 +213,6 @@ class FirConverter : public Fortran::lower::AbstractConverter {
void genFIR(const Fortran::parser::IfStmt &) {} // nop
void genFIR(const Fortran::parser::IfThenStmt &) {} // nop
void genFIR(const Fortran::parser::NonLabelDoStmt &) {} // nop
- void genFIR(const Fortran::parser::OmpEndLoopDirective &) {} // nop
void genFIR(const Fortran::parser::SelectTypeStmt &) {} // nop
void genFIR(const Fortran::parser::TypeGuardStmt &) {} // nop
@@ -687,7 +705,6 @@ class FirConverter : public Fortran::lower::AbstractConverter {
mlir::Location toLocation();
void setCurrentEval(Fortran::lower::pft::Evaluation &eval);
- Fortran::lower::pft::Evaluation &getEval();
std::optional<Fortran::evaluate::Shape>
getShape(const Fortran::lower::SomeExpr &expr);
@@ -730,8 +747,6 @@ class FirConverter : public Fortran::lower::AbstractConverter {
mlir::Type eleTy);
void finalizeOpenACCLowering();
- void finalizeOpenMPLowering(
- const Fortran::semantics::Symbol *globalOmpRequiresSymbol);
//===--------------------------------------------------------------------===//
@@ -779,10 +794,6 @@ class FirConverter : public Fortran::lower::AbstractConverter {
/// Deferred OpenACC routine attachment.
Fortran::lower::AccRoutineInfoMappingList accRoutineInfos;
- /// Whether an OpenMP target region or declare target function/subroutine
- /// intended for device offloading has been detected
- bool ompDeviceCodeFound = false;
-
const Fortran::lower::ExprToValueMap *exprValueOverrides{nullptr};
};
@@ -1224,11 +1235,6 @@ FirConverter::setCurrentEval(Fortran::lower::pft::Evaluation &eval) {
evalPtr = &eval;
}
-inline Fortran::lower::pft::Evaluation &FirConverter::getEval() {
- assert(evalPtr);
- return *evalPtr;
-}
-
std::optional<Fortran::evaluate::Shape> inline FirConverter::getShape(
const Fortran::lower::SomeExpr &expr) {
return Fortran::evaluate::GetShape(foldingContext, expr);
diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
index eeba87fcd15116..5ca7be5da26a60 100644
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -10,12 +10,15 @@
//
//===----------------------------------------------------------------------===//
-#include "flang/Lower/OpenMP.h"
+#include "FirConverter.h"
+#include "OpenMPMixin.h"
+
#include "DirectivesCommon.h"
#include "flang/Common/idioms.h"
#include "flang/Lower/Bridge.h"
#include "flang/Lower/ConvertExpr.h"
#include "flang/Lower/ConvertVariable.h"
+#include "flang/Lower/OpenMP.h"
#include "flang/Lower/PFTBuilder.h"
#include "flang/Lower/StatementContext.h"
#include "flang/Optimizer/Builder/BoxValue.h"
@@ -41,6 +44,25 @@ using DeclareTargetCapturePair =
std::pair<mlir::omp::DeclareTargetCaptureClause,
Fortran::semantics::Symbol>;
+namespace Fortran::lower {
+// Forward OpenMPMixin's private accessors to the same-named FirConverter ones.
+template <>
+Fortran::lower::LoweringBridge &OpenMPMixin<FirConverter>::getBridge() {
+  return This()->FirConverter::getBridge();
+}
+template <> fir::FirOpBuilder &OpenMPMixin<FirConverter>::getBuilder() {
+  return This()->FirConverter::getBuilder();
+}
+template <>
+Fortran::lower::pft::Evaluation &OpenMPMixin<FirConverter>::getEval() {
+  return This()->FirConverter::getEval();
+}
+template <> Fortran::lower::SymMap &OpenMPMixin<FirConverter>::getSymTable() {
+  return This()->FirConverter::getSymTable();
+}
+
+} // namespace Fortran::lower
+
//===----------------------------------------------------------------------===//
// Common helper functions
//===----------------------------------------------------------------------===//
@@ -3860,3 +3882,97 @@ void Fortran::lower::genOpenMPRequires(
offloadMod.setRequires(mlirFlags);
}
}
+
+namespace Fortran::lower {
+
+/// Lower an OpenMP executable construct and the evaluations nested in it.
+template <typename ConverterT>
+void OpenMPMixin<ConverterT>::genFIR(
+    const Fortran::parser::OpenMPConstruct &omp) {
+  mlir::OpBuilder::InsertPoint insertPt = getBuilder().saveInsertionPoint();
+  getSymTable().pushScope();
+  genOpenMPConstruct(*this->This(), getBridge().getSemanticsContext(),
+                     getEval(), omp);
+
+  const Fortran::parser::OpenMPLoopConstruct *ompLoop =
+      std::get_if<Fortran::parser::OpenMPLoopConstruct>(&omp.u);
+  const Fortran::parser::OpenMPBlockConstruct *ompBlock =
+      std::get_if<Fortran::parser::OpenMPBlockConstruct>(&omp.u);
+
+  // If loop is part of an OpenMP Construct then the OpenMP dialect workshare
+  // loop operation has already been created. Only the body needs to be
+  // created here and the do_loop can be skipped. Skip the number of
+  // collapsed loops, which is 1 when no collapse is requested.
+  Fortran::lower::pft::Evaluation *curEval = &getEval();
+  const Fortran::parser::OmpClauseList *loopOpClauseList = nullptr;
+  if (ompLoop) {
+    loopOpClauseList = &std::get<Fortran::parser::OmpClauseList>(
+        std::get<Fortran::parser::OmpBeginLoopDirective>(ompLoop->t).t);
+    int64_t collapseValue = Fortran::lower::getCollapseValue(*loopOpClauseList);
+
+    curEval = &curEval->getFirstNestedEvaluation();
+    for (int64_t i = 1; i < collapseValue; i++)
+      curEval = &*std::next(curEval->getNestedEvaluations().begin());
+  }
+
+  for (Fortran::lower::pft::Evaluation &e : curEval->getNestedEvaluations())
+    this->This()->genFIR(e);
+
+  if (ompLoop) {
+    genOpenMPReduction(*this->This(), *loopOpClauseList);
+  } else if (ompBlock) {
+    const auto &blockStart =
+        std::get<Fortran::parser::OmpBeginBlockDirective>(ompBlock->t);
+    const auto &blockClauses =
+        std::get<Fortran::parser::OmpClauseList>(blockStart.t);
+    genOpenMPReduction(*this->This(), blockClauses);
+  }
+
+  getSymTable().popScope();
+  getBuilder().restoreInsertionPoint(insertPt);
+
+  // Register if a target region was found
+  ompDeviceCodeFound =
+      ompDeviceCodeFound || Fortran::lower::isOpenMPTargetConstruct(omp);
+}
+
+/// Lower an OpenMP declarative construct and its nested evaluations.
+template <typename ConverterT>
+void OpenMPMixin<ConverterT>::genFIR(
+    const Fortran::parser::OpenMPDeclarativeConstruct &ompDecl) {
+  mlir::OpBuilder::InsertPoint insertPt = getBuilder().saveInsertionPoint();
+  // Register if a declare target construct for a target device was found
+  ompDeviceCodeFound =
+      ompDeviceCodeFound || Fortran::lower::isOpenMPDeviceDeclareTarget(
+                                *this->This(), getEval(), ompDecl);
+  genOpenMPDeclarativeConstruct(*this->This(), getEval(), ompDecl);
+  for (Fortran::lower::pft::Evaluation &e : getEval().getNestedEvaluations())
+    this->This()->genFIR(e);
+  getBuilder().restoreInsertionPoint(insertPt);
+}
+
+/// Apply OpenMP threadprivate/declare-target handling to \p var (via This()).
+template <typename ConverterT>
+void OpenMPMixin<ConverterT>::instantiateVariable(
+    Fortran::lower::AbstractConverter &converter,
+    const Fortran::lower::pft::Variable &var) {
+  assert(var.hasSymbol() && "Expecting symbol");
+  if (var.getSymbol().test(Fortran::semantics::Symbol::Flag::OmpThreadprivate))
+    genThreadprivateOp(*this->This(), var);
+  if (var.getSymbol().test(Fortran::semantics::Symbol::Flag::OmpDeclareTarget))
+    genDeclareTargetIntGlobal(*this->This(), var);
+}
+
+template <typename ConverterT>
+void OpenMPMixin<ConverterT>::finalize(
+    const Fortran::semantics::Symbol *globalOmpRequiresSymbol) {
+  // Set the module attribute related to OpenMP requires directives
+  if (ompDeviceCodeFound)
+    genOpenMPRequires(this->This()->getModuleOp().getOperation(),
+                      globalOmpRequiresSymbol);
+}
+
+// Instantiate here: other TUs (e.g. Bridge.cpp) use these members.
+template class OpenMPMixin<FirConverter>;
+
+} // namespace Fortran::lower
diff --git a/flang/lib/Lower/OpenMPMixin.h b/flang/lib/Lower/OpenMPMixin.h
new file mode 100644
index 00000000000000..7339d9eb4fc61f
--- /dev/null
+++ b/flang/lib/Lower/OpenMPMixin.h
@@ -0,0 +1,66 @@
+//===-- OpenMPMixin.h -----------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef FORTRAN_LOWER_OPENMPMIXIN_H
+#define FORTRAN_LOWER_OPENMPMIXIN_H
+
+#include "ConverterMixin.h"
+#include "flang/Parser/parse-tree.h"
+
+namespace fir {
+class FirOpBuilder;
+}
+
+namespace Fortran::semantics {
+class Symbol;
+}
+
+namespace Fortran::lower {
+
+class AbstractConverter;
+class LoweringBridge;
+class SymMap;
+
+namespace pft {
+class Evaluation;
+class Variable;
+} // namespace pft
+
+template <typename ConverterT>
+class OpenMPMixin : public ConverterMixinBase<ConverterT> {
+public:
+  /// Lower OpenMP executable and declarative constructs.
+  void genFIR(const Fortran::parser::OpenMPConstruct &);
+  void genFIR(const Fortran::parser::OpenMPDeclarativeConstruct &);
+  void genFIR(const Fortran::parser::OmpEndLoopDirective &) {} // nop
+  /// OpenMP-specific processing of \p var when the converter instantiates it.
+  void instantiateVariable(Fortran::lower::AbstractConverter &converter,
+                           const Fortran::lower::pft::Variable &var);
+  /// Perform OpenMP lowering actions deferred to the end of lowering.
+  void finalize(const Fortran::semantics::Symbol *globalOmpRequiresSymbol);
+
+private:
+  // Forwarders to the ConverterT accessors of the same name. They are only
+  // declared here; OpenMP.cpp provides the definitions.
+  Fortran::lower::LoweringBridge &getBridge();
+  fir::FirOpBuilder &getBuilder();
+  Fortran::lower::pft::Evaluation &getEval();
+  Fortran::lower::SymMap &getSymTable();
+
+  /// Whether a target region or declare target function/subroutine
+  /// intended for device offloading has been detected.
+  bool ompDeviceCodeFound = false;
+};
+
+} // namespace Fortran::lower
+
+#endif // FORTRAN_LOWER_OPENMPMIXIN_H
More information about the llvm-branch-commits
mailing list