[llvm-branch-commits] [flang] [flang] Move OpenMP-related code from `FirConverter` to `OpenMPMixin` (PR #74866)

Krzysztof Parzyszek via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Fri Dec 8 08:56:32 PST 2023


https://github.com/kparzysz created https://github.com/llvm/llvm-project/pull/74866

This improves the separation of the generic Fortran lowering and the lowering of OpenMP constructs.

The mixin is intended to be derived from via CRTP:
```
  class FirConverter : public OpenMPMixin<FirConverter> ...
```

The primary goal of the mixin is to implement `genFIR` functions that the derived converter can then call via
```
  std::visit([this](auto &&s) { genFIR(s); }, node.u);
```

The mixin also expects a handful of functions to be present in the derived class, most importantly `genFIR(Evaluation&)`, plus getter functions for the lowering bridge, op builder, current evaluation, and symbol table.

The pre-existing PFT-lowering functionality is preserved.

>From e754d86e1c35d4417dfae8d270de7dee318f783e Mon Sep 17 00:00:00 2001
From: Krzysztof Parzyszek <Krzysztof.Parzyszek at amd.com>
Date: Fri, 8 Dec 2023 09:13:11 -0600
Subject: [PATCH] [flang] Move OpenMP-related code from `FirConverter` to
 `OpenMPMixin`

This improves the separation of the generic Fortran lowering and the
lowering of OpenMP constructs.

The mixin is intended to be derived from via CRTP:
```
  class FirConverter : public OpenMPMixin<FirConverter> ...
```

The primary goal of the mixin is to implement `genFIR` functions
that the derived converter can then call via
```
  std::visit([this](auto &&s) { genFIR(s); }, node.u);
```

The mixin also expects a handful of functions to be present
in the derived class, most importantly `genFIR(Evaluation&)`, plus
getter functions for the lowering bridge, op builder, current
evaluation, and symbol table.

The pre-existing PFT-lowering functionality is preserved.
---
 flang/lib/Lower/Bridge.cpp       |  84 +---------------------
 flang/lib/Lower/ConverterMixin.h |  28 ++++++++
 flang/lib/Lower/FirConverter.h   |  42 ++++++-----
 flang/lib/Lower/OpenMP.cpp       | 118 ++++++++++++++++++++++++++++++-
 flang/lib/Lower/OpenMPMixin.h    |  66 +++++++++++++++++
 5 files changed, 237 insertions(+), 101 deletions(-)
 create mode 100644 flang/lib/Lower/ConverterMixin.h
 create mode 100644 flang/lib/Lower/OpenMPMixin.h

diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 885c9307b8caf..061f9f29ffb00 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -170,7 +170,7 @@ void FirConverter::run(Fortran::lower::pft::Program &pft) {
     });

   finalizeOpenACCLowering();
-  finalizeOpenMPLowering(globalOmpRequiresSymbol);
+  OpenMPBase::finalize(globalOmpRequiresSymbol); // Deferred OpenMP work.
 }

 /// Generate FIR for Evaluation \p eval.
@@ -977,70 +977,6 @@ void FirConverter::genFIR(const Fortran::parser::OpenACCRoutineConstruct &acc) {
   // Handled by genFIR(const Fortran::parser::OpenACCDeclarativeConstruct &)
 }

-void FirConverter::genFIR(const Fortran::parser::OpenMPConstruct &omp) {
-  mlir::OpBuilder::InsertPoint insertPt = builder->saveInsertionPoint();
-  localSymbols.pushScope();
-  genOpenMPConstruct(*this, bridge.getSemanticsContext(), getEval(), omp);
-
-  const Fortran::parser::OpenMPLoopConstruct *ompLoop =
-      std::get_if<Fortran::parser::OpenMPLoopConstruct>(&omp.u);
-  const Fortran::parser::OpenMPBlockConstruct *ompBlock =
-      std::get_if<Fortran::parser::OpenMPBlockConstruct>(&omp.u);
-
-  // If loop is part of an OpenMP Construct then the OpenMP dialect
-  // workshare loop operation has already been created. Only the
-  // body needs to be created here and the do_loop can be skipped.
-  // Skip the number of collapsed loops, which is 1 when there is a
-  // no collapse requested.
-
-  Fortran::lower::pft::Evaluation *curEval = &getEval();
-  const Fortran::parser::OmpClauseList *loopOpClauseList = nullptr;
-  if (ompLoop) {
-    loopOpClauseList = &std::get<Fortran::parser::OmpClauseList>(
-        std::get<Fortran::parser::OmpBeginLoopDirective>(ompLoop->t).t);
-    int64_t collapseValue = Fortran::lower::getCollapseValue(*loopOpClauseList);
-
-    curEval = &curEval->getFirstNestedEvaluation();
-    for (int64_t i = 1; i < collapseValue; i++) {
-      curEval = &*std::next(curEval->getNestedEvaluations().begin());
-    }
-  }
-
-  for (Fortran::lower::pft::Evaluation &e : curEval->getNestedEvaluations())
-    genFIR(e);
-
-  if (ompLoop) {
-    genOpenMPReduction(*this, *loopOpClauseList);
-  } else if (ompBlock) {
-    const auto &blockStart =
-        std::get<Fortran::parser::OmpBeginBlockDirective>(ompBlock->t);
-    const auto &blockClauses =
-        std::get<Fortran::parser::OmpClauseList>(blockStart.t);
-    genOpenMPReduction(*this, blockClauses);
-  }
-
-  localSymbols.popScope();
-  builder->restoreInsertionPoint(insertPt);
-
-  // Register if a target region was found
-  ompDeviceCodeFound =
-      ompDeviceCodeFound || Fortran::lower::isOpenMPTargetConstruct(omp);
-}
-
-void FirConverter::genFIR(
-    const Fortran::parser::OpenMPDeclarativeConstruct &ompDecl) {
-  mlir::OpBuilder::InsertPoint insertPt = builder->saveInsertionPoint();
-  // Register if a declare target construct intended for a target device was
-  // found
-  ompDeviceCodeFound =
-      ompDeviceCodeFound ||
-      Fortran::lower::isOpenMPDeviceDeclareTarget(*this, getEval(), ompDecl);
-  genOpenMPDeclarativeConstruct(*this, getEval(), ompDecl);
-  for (Fortran::lower::pft::Evaluation &e : getEval().getNestedEvaluations())
-    genFIR(e);
-  builder->restoreInsertionPoint(insertPt);
-}
-
 void FirConverter::genFIR(const Fortran::parser::OpenStmt &stmt) {
   mlir::Value iostat = genOpenStatement(*this, stmt);
   genIoConditionBranches(getEval(), stmt.v, iostat);
@@ -3752,13 +3688,7 @@ void FirConverter::instantiateVar(const Fortran::lower::pft::Variable &var,
                                    Fortran::lower::AggregateStoreMap &storeMap) {
   Fortran::lower::instantiateVariable(*this, var, localSymbols, storeMap);
   if (var.hasSymbol()) {
-    if (var.getSymbol().test(
-            Fortran::semantics::Symbol::Flag::OmpThreadprivate))
-      Fortran::lower::genThreadprivateOp(*this, var);
-
-    if (var.getSymbol().test(
-            Fortran::semantics::Symbol::Flag::OmpDeclareTarget))
-      Fortran::lower::genDeclareTargetIntGlobal(*this, var);
+    OpenMPBase::instantiateVariable(*this, var); // Handle OMP symbol flags.
   }
 }

@@ -4443,16 +4373,6 @@ void FirConverter::finalizeOpenACCLowering() {
                                                    accRoutineInfos);
 }

-/// Performing OpenMP lowering actions that were deferred to the end of
-/// lowering.
-void FirConverter::finalizeOpenMPLowering(
-    const Fortran::semantics::Symbol *globalOmpRequiresSymbol) {
-  // Set the module attribute related to OpenMP requires directives
-  if (ompDeviceCodeFound)
-    Fortran::lower::genOpenMPRequires(getModuleOp().getOperation(),
-                                      globalOmpRequiresSymbol);
-}
-
 } // namespace Fortran::lower

 Fortran::evaluate::FoldingContext
diff --git a/flang/lib/Lower/ConverterMixin.h b/flang/lib/Lower/ConverterMixin.h
new file mode 100644
index 0000000000000..a873ff36d0f60
--- /dev/null
+++ b/flang/lib/Lower/ConverterMixin.h
@@ -0,0 +1,28 @@
+//===-- ConverterMixin.h -- CRTP base for converter mixins ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef FORTRAN_LOWER_CONVERTERMIXIN_H
+#define FORTRAN_LOWER_CONVERTERMIXIN_H
+
+namespace Fortran::lower {
+
+template <typename FirConverterT> class ConverterMixinBase {
+public:
+  FirConverterT *This() { return static_cast<FirConverterT *>(this); }
+  const FirConverterT *This() const {
+    return static_cast<const FirConverterT *>(this);
+  }
+};
+
+} // namespace Fortran::lower
+
+#endif // FORTRAN_LOWER_CONVERTERMIXIN_H
diff --git a/flang/lib/Lower/FirConverter.h b/flang/lib/Lower/FirConverter.h
index 51b8bd4fa0b38..0214ea88b1e5b 100644
--- a/flang/lib/Lower/FirConverter.h
+++ b/flang/lib/Lower/FirConverter.h
@@ -13,6 +13,9 @@
 #ifndef FORTRAN_LOWER_FIRCONVERTER_H
 #define FORTRAN_LOWER_FIRCONVERTER_H

+#include "ConverterMixin.h"
+#include "OpenMPMixin.h"
+
 #include "flang/Common/Fortran.h"
 #include "flang/Lower/AbstractConverter.h"
 #include "flang/Lower/Bridge.h"
@@ -74,7 +77,11 @@

 namespace Fortran::lower {

-class FirConverter : public Fortran::lower::AbstractConverter {
+class FirConverter : public Fortran::lower::AbstractConverter,
+                     public OpenMPMixin<FirConverter> {
+  using OpenMPBase = OpenMPMixin<FirConverter>;
+  using OpenMPBase::genFIR; // Expose the mixin's genFIR overloads.
+
 public:
   explicit FirConverter(Fortran::lower::LoweringBridge &bridge)
       : Fortran::lower::AbstractConverter(bridge.getLoweringOptions()),
@@ -83,6 +90,20 @@ class FirConverter : public Fortran::lower::AbstractConverter {

   void run(Fortran::lower::pft::Program &pft);

+public:
+  // Accessors that the OpenMPMixin base expects the converter to provide.
+
+  Fortran::lower::LoweringBridge &getBridge() { return bridge; }
+  fir::FirOpBuilder &getBuilder() {
+    assert(builder);
+    return *builder;
+  }
+  Fortran::lower::pft::Evaluation &getEval() {
+    assert(evalPtr);
+    return *evalPtr;
+  }
+  Fortran::lower::SymMap &getSymTable() { return localSymbols; }
+
   /// The core of the conversion: take an evaluation and generate FIR for it.
   /// The generation for each individual element of PFT is done via a specific
   /// genFIR function (see below).
@@ -141,8 +162,6 @@ class FirConverter : public Fortran::lower::AbstractConverter {
   void genFIR(const Fortran::parser::OpenACCConstruct &);
   void genFIR(const Fortran::parser::OpenACCDeclarativeConstruct &);
   void genFIR(const Fortran::parser::OpenACCRoutineConstruct &);
-  void genFIR(const Fortran::parser::OpenMPConstruct &);
-  void genFIR(const Fortran::parser::OpenMPDeclarativeConstruct &);
   void genFIR(const Fortran::parser::OpenStmt &);
   void genFIR(const Fortran::parser::PauseStmt &);
   void genFIR(const Fortran::parser::PointerAssignmentStmt &);
@@ -194,7 +213,6 @@ class FirConverter : public Fortran::lower::AbstractConverter {
   void genFIR(const Fortran::parser::IfStmt &) {}              // nop
   void genFIR(const Fortran::parser::IfThenStmt &) {}          // nop
   void genFIR(const Fortran::parser::NonLabelDoStmt &) {}      // nop
-  void genFIR(const Fortran::parser::OmpEndLoopDirective &) {} // nop
   void genFIR(const Fortran::parser::SelectTypeStmt &) {}      // nop
   void genFIR(const Fortran::parser::TypeGuardStmt &) {}       // nop

@@ -684,7 +702,6 @@ class FirConverter : public Fortran::lower::AbstractConverter {
   mlir::Location toLocation();

   void setCurrentEval(Fortran::lower::pft::Evaluation &eval);
-  Fortran::lower::pft::Evaluation &getEval();

   std::optional<Fortran::evaluate::Shape>
   getShape(const Fortran::lower::SomeExpr &expr);
@@ -707,8 +724,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
   void analyzeExplicitSpace(const Fortran::parser::PointerAssignmentStmt &s);
   void analyzeExplicitSpace(const Fortran::parser::WhereBodyConstruct &body);
   void analyzeExplicitSpace(const Fortran::parser::WhereConstruct &c);
-  void analyzeExplicitSpace(
-      const Fortran::parser::WhereConstruct::Elsewhere *ew);
+  void
+  analyzeExplicitSpace(const Fortran::parser::WhereConstruct::Elsewhere *ew);
   void analyzeExplicitSpace(
       const Fortran::parser::WhereConstruct::MaskedElsewhere &ew);
   void analyzeExplicitSpace(const Fortran::parser::WhereConstructStmt &ws);
@@ -727,8 +744,6 @@ class FirConverter : public Fortran::lower::AbstractConverter {
                                           mlir::Type eleTy);

   void finalizeOpenACCLowering();
-  void finalizeOpenMPLowering(
-      const Fortran::semantics::Symbol *globalOmpRequiresSymbol);

   //===--------------------------------------------------------------------===//

@@ -776,10 +791,6 @@ class FirConverter : public Fortran::lower::AbstractConverter {
   /// Deferred OpenACC routine attachment.
   Fortran::lower::AccRoutineInfoMappingList accRoutineInfos;

-  /// Whether an OpenMP target region or declare target function/subroutine
-  /// intended for device offloading has been detected
-  bool ompDeviceCodeFound = false;
-
   const Fortran::lower::ExprToValueMap *exprValueOverrides{nullptr};
 };

@@ -1220,11 +1231,6 @@ FirConverter::setCurrentEval(Fortran::lower::pft::Evaluation &eval) {
   evalPtr = &eval;
 }

-inline Fortran::lower::pft::Evaluation &FirConverter::getEval() {
-  assert(evalPtr);
-  return *evalPtr;
-}
-
 std::optional<Fortran::evaluate::Shape> inline FirConverter::getShape(
     const Fortran::lower::SomeExpr &expr) {
   return Fortran::evaluate::GetShape(foldingContext, expr);
diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
index eeba87fcd1511..5ca7be5da26a6 100644
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -10,12 +10,15 @@
 //
 //===----------------------------------------------------------------------===//

-#include "flang/Lower/OpenMP.h"
+#include "FirConverter.h"
+#include "OpenMPMixin.h"
+
 #include "DirectivesCommon.h"
 #include "flang/Common/idioms.h"
 #include "flang/Lower/Bridge.h"
 #include "flang/Lower/ConvertExpr.h"
 #include "flang/Lower/ConvertVariable.h"
+#include "flang/Lower/OpenMP.h"
 #include "flang/Lower/PFTBuilder.h"
 #include "flang/Lower/StatementContext.h"
 #include "flang/Optimizer/Builder/BoxValue.h"
@@ -41,6 +44,25 @@ using DeclareTargetCapturePair =
     std::pair<mlir::omp::DeclareTargetCaptureClause,
               Fortran::semantics::Symbol>;

+namespace Fortran::lower {
+// OpenMPMixin accessors: forward to the FirConverter public getters.
+template <>
+Fortran::lower::LoweringBridge &OpenMPMixin<FirConverter>::getBridge() {
+  return This()->FirConverter::getBridge();
+}
+template <> fir::FirOpBuilder &OpenMPMixin<FirConverter>::getBuilder() {
+  return This()->FirConverter::getBuilder();
+}
+template <>
+Fortran::lower::pft::Evaluation &OpenMPMixin<FirConverter>::getEval() {
+  return This()->FirConverter::getEval();
+}
+template <> Fortran::lower::SymMap &OpenMPMixin<FirConverter>::getSymTable() {
+  return This()->FirConverter::getSymTable();
+}
+
+} // namespace Fortran::lower
+
 //===----------------------------------------------------------------------===//
 // Common helper functions
 //===----------------------------------------------------------------------===//
@@ -3860,3 +3882,97 @@ void Fortran::lower::genOpenMPRequires(
     offloadMod.setRequires(mlirFlags);
   }
 }
+
+namespace Fortran::lower {
+// NOTE(review): consider declaring these specializations in a header.
+template <>
+void OpenMPMixin<FirConverter>::genFIR(
+    const Fortran::parser::OpenMPConstruct &omp) {
+  mlir::OpBuilder::InsertPoint insertPt = getBuilder().saveInsertionPoint();
+  getSymTable().pushScope();
+  genOpenMPConstruct(*This(), getBridge().getSemanticsContext(), getEval(),
+                     omp);
+
+  const Fortran::parser::OpenMPLoopConstruct *ompLoop =
+      std::get_if<Fortran::parser::OpenMPLoopConstruct>(&omp.u);
+  const Fortran::parser::OpenMPBlockConstruct *ompBlock =
+      std::get_if<Fortran::parser::OpenMPBlockConstruct>(&omp.u);
+
+  // If loop is part of an OpenMP Construct then the OpenMP dialect
+  // workshare loop operation has already been created. Only the
+  // body needs to be created here and the do_loop can be skipped.
+  // Skip the number of collapsed loops, which is 1 when no collapse is
+  // requested.
+
+  Fortran::lower::pft::Evaluation *curEval = &getEval();
+  const Fortran::parser::OmpClauseList *loopOpClauseList = nullptr;
+  if (ompLoop) {
+    loopOpClauseList = &std::get<Fortran::parser::OmpClauseList>(
+        std::get<Fortran::parser::OmpBeginLoopDirective>(ompLoop->t).t);
+    int64_t collapseValue = Fortran::lower::getCollapseValue(*loopOpClauseList);
+
+    curEval = &curEval->getFirstNestedEvaluation();
+    for (int64_t i = 1; i < collapseValue; i++) {
+      curEval = &*std::next(curEval->getNestedEvaluations().begin());
+    }
+  }
+
+  for (Fortran::lower::pft::Evaluation &e : curEval->getNestedEvaluations())
+    This()->genFIR(e);
+
+  if (ompLoop) {
+    genOpenMPReduction(*This(), *loopOpClauseList);
+  } else if (ompBlock) {
+    const auto &blockStart =
+        std::get<Fortran::parser::OmpBeginBlockDirective>(ompBlock->t);
+    const auto &blockClauses =
+        std::get<Fortran::parser::OmpClauseList>(blockStart.t);
+    genOpenMPReduction(*This(), blockClauses);
+  }
+
+  getSymTable().popScope();
+  getBuilder().restoreInsertionPoint(insertPt);
+
+  // Register if a target region was found
+  ompDeviceCodeFound =
+      ompDeviceCodeFound || Fortran::lower::isOpenMPTargetConstruct(omp);
+}
+
+template <>
+void OpenMPMixin<FirConverter>::genFIR(
+    const Fortran::parser::OpenMPDeclarativeConstruct &ompDecl) {
+  mlir::OpBuilder::InsertPoint insertPt = getBuilder().saveInsertionPoint();
+  // Register if a declare target construct intended for a target device was
+  // found
+  ompDeviceCodeFound =
+      ompDeviceCodeFound ||
+      Fortran::lower::isOpenMPDeviceDeclareTarget(*This(), getEval(), ompDecl);
+  genOpenMPDeclarativeConstruct(*This(), getEval(), ompDecl);
+  for (Fortran::lower::pft::Evaluation &e : getEval().getNestedEvaluations())
+    This()->genFIR(e);
+  getBuilder().restoreInsertionPoint(insertPt);
+}
+
+template <>
+void OpenMPMixin<FirConverter>::instantiateVariable(
+    Fortran::lower::AbstractConverter &converter, // NOTE(review): unused.
+    const Fortran::lower::pft::Variable &var) {
+  assert(var.hasSymbol() && "Expecting symbol");
+  if (var.getSymbol().test(Fortran::semantics::Symbol::Flag::OmpThreadprivate))
+    genThreadprivateOp(*This(), var);
+
+  if (var.getSymbol().test(Fortran::semantics::Symbol::Flag::OmpDeclareTarget))
+    genDeclareTargetIntGlobal(*This(), var);
+}
+
+template <>
+void OpenMPMixin<FirConverter>::finalize(
+    const Fortran::semantics::Symbol *globalOmpRequiresSymbol) {
+  // Set the module attribute related to OpenMP requires directives
+  if (ompDeviceCodeFound) {
+    genOpenMPRequires(This()->getModuleOp().getOperation(),
+                      globalOmpRequiresSymbol);
+  }
+}
+
+} // namespace Fortran::lower
diff --git a/flang/lib/Lower/OpenMPMixin.h b/flang/lib/Lower/OpenMPMixin.h
new file mode 100644
index 0000000000000..7339d9eb4fc61
--- /dev/null
+++ b/flang/lib/Lower/OpenMPMixin.h
@@ -0,0 +1,66 @@
+//===-- OpenMPMixin.h -- OpenMP lowering mixin for converters -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef FORTRAN_LOWER_OPENMPMIXIN_H
+#define FORTRAN_LOWER_OPENMPMIXIN_H
+
+#include "ConverterMixin.h"
+#include "flang/Parser/parse-tree.h"
+
+namespace fir {
+class FirOpBuilder;
+}
+
+namespace Fortran::semantics {
+class Symbol;
+}
+
+namespace Fortran::lower {
+
+class AbstractConverter;
+class LoweringBridge;
+class SymMap;
+
+namespace pft {
+class Evaluation;
+class Variable;
+} // namespace pft
+
+template <typename ConverterT> // CRTP: ConverterT derives from this mixin.
+class OpenMPMixin : public ConverterMixinBase<ConverterT> {
+public:
+  void genFIR(const Fortran::parser::OpenMPConstruct &);
+  void genFIR(const Fortran::parser::OpenMPDeclarativeConstruct &);
+
+  void genFIR(const Fortran::parser::OmpEndLoopDirective &) {} // nop
+  // Invoked from the converter's instantiateVar() and run(), respectively.
+  void instantiateVariable(Fortran::lower::AbstractConverter &converter,
+                           const Fortran::lower::pft::Variable &var);
+  void finalize(const Fortran::semantics::Symbol *globalOmpRequiresSymbol);
+
+private:
+  // Shortcuts to ConverterT's accessors. ConverterT is incomplete here, so
+  // they are defined via explicit specializations in OpenMP.cpp.
+  Fortran::lower::LoweringBridge &getBridge();
+  fir::FirOpBuilder &getBuilder();
+  Fortran::lower::pft::Evaluation &getEval();
+  Fortran::lower::SymMap &getSymTable();
+
+private:
+  /// Whether a target region or a declare target function/subroutine
+  /// intended for device offloading has been detected.
+  bool ompDeviceCodeFound = false;
+};
+
+} // namespace Fortran::lower
+
+#endif // FORTRAN_LOWER_OPENMPMIXIN_H



More information about the llvm-branch-commits mailing list