[flang-commits] [flang] 29aa749 - [OpenMP][Flang][MLIR] Lowering of OpenMP requires directive from parse tree to MLIR

Sergio Afonso via flang-commits flang-commits at lists.llvm.org
Thu Sep 14 02:35:05 PDT 2023


Author: Sergio Afonso
Date: 2023-09-14T10:34:54+01:00
New Revision: 29aa749087be38d3e5a3a37e0b8e8ab74e9f79aa

URL: https://github.com/llvm/llvm-project/commit/29aa749087be38d3e5a3a37e0b8e8ab74e9f79aa
DIFF: https://github.com/llvm/llvm-project/commit/29aa749087be38d3e5a3a37e0b8e8ab74e9f79aa.diff

LOG: [OpenMP][Flang][MLIR] Lowering of OpenMP requires directive from parse tree to MLIR

This patch implements the lowering of the OpenMP 'requires' directive
from Flang parse tree to MLIR attributes attached to the top-level
module.

Target-related 'requires' clauses are gathered and combined for each top-level
unit during semantics. Lastly, a single module-level `omp.requires` attribute
is attached to the MLIR module with that information at the end of the process.

The `atomic_default_mem_order` clause is not addressed by this patch, but
rather it will come as a separate patch and follow a different approach.

Depends on D147214, D150328, D150329 and D157983.

Differential Revision: https://reviews.llvm.org/D147218

Added: 
    flang/test/Lower/OpenMP/Todo/requires-unnamed-common.f90
    flang/test/Lower/OpenMP/requires-common.f90
    flang/test/Lower/OpenMP/requires-notarget.f90
    flang/test/Lower/OpenMP/requires.f90

Modified: 
    flang/include/flang/Lower/OpenMP.h
    flang/lib/Lower/Bridge.cpp
    flang/lib/Lower/OpenMP.cpp

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Lower/OpenMP.h b/flang/include/flang/Lower/OpenMP.h
index 08863c79264c259..3563de7d091e894 100644
--- a/flang/include/flang/Lower/OpenMP.h
+++ b/flang/include/flang/Lower/OpenMP.h
@@ -34,6 +34,10 @@ struct OmpEndLoopDirective;
 struct OmpClauseList;
 } // namespace parser
 
+namespace semantics {
+class Symbol;
+} // namespace semantics
+
 namespace lower {
 
 class AbstractConverter;
@@ -62,6 +66,13 @@ fir::ConvertOp getConvertFromReductionOp(mlir::Operation *, mlir::Value);
 void updateReduction(mlir::Operation *, fir::FirOpBuilder &, mlir::Value,
                      mlir::Value, fir::ConvertOp * = nullptr);
 void removeStoreOp(mlir::Operation *, mlir::Value);
+
+bool isOpenMPTargetConstruct(const parser::OpenMPConstruct &);
+bool isOpenMPDeviceDeclareTarget(Fortran::lower::AbstractConverter &,
+                                 Fortran::lower::pft::Evaluation &,
+                                 const parser::OpenMPDeclarativeConstruct &);
+void genOpenMPRequires(mlir::Operation *, const Fortran::semantics::Symbol *);
+
 } // namespace lower
 } // namespace Fortran
 

diff  --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index cbe108194dd212f..b06497e3f8ee823 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -50,6 +50,7 @@
 #include "flang/Parser/parse-tree.h"
 #include "flang/Runtime/iostat.h"
 #include "flang/Semantics/runtime-type-info.h"
+#include "flang/Semantics/symbol.h"
 #include "flang/Semantics/tools.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/IR/PatternMatch.h"
@@ -294,12 +295,15 @@ class FirConverter : public Fortran::lower::AbstractConverter {
     //    that they are available before lowering any function that may use
     //    them.
     bool hasMainProgram = false;
+    const Fortran::semantics::Symbol *globalOmpRequiresSymbol = nullptr;
     for (Fortran::lower::pft::Program::Units &u : pft.getUnits()) {
       std::visit(Fortran::common::visitors{
                      [&](Fortran::lower::pft::FunctionLikeUnit &f) {
                        if (f.isMainProgram())
                          hasMainProgram = true;
                        declareFunction(f);
+                       if (!globalOmpRequiresSymbol)
+                         globalOmpRequiresSymbol = f.getScope().symbol();
                      },
                      [&](Fortran::lower::pft::ModuleLikeUnit &m) {
                        lowerModuleDeclScope(m);
@@ -307,7 +311,10 @@ class FirConverter : public Fortran::lower::AbstractConverter {
                             m.nestedFunctions)
                          declareFunction(f);
                      },
-                     [&](Fortran::lower::pft::BlockDataUnit &b) {},
+                     [&](Fortran::lower::pft::BlockDataUnit &b) {
+                       if (!globalOmpRequiresSymbol)
+                         globalOmpRequiresSymbol = b.symTab.symbol();
+                     },
                      [&](Fortran::lower::pft::CompilerDirectiveUnit &d) {},
                  },
                  u);
@@ -352,6 +359,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
       });
 
     finalizeOpenACCLowering();
+    finalizeOpenMPLowering(globalOmpRequiresSymbol);
   }
 
   /// Declare a function.
@@ -2347,10 +2355,19 @@ class FirConverter : public Fortran::lower::AbstractConverter {
 
     localSymbols.popScope();
     builder->restoreInsertionPoint(insertPt);
+
+    // Register if a target region was found
+    ompDeviceCodeFound =
+        ompDeviceCodeFound || Fortran::lower::isOpenMPTargetConstruct(omp);
   }
 
   void genFIR(const Fortran::parser::OpenMPDeclarativeConstruct &ompDecl) {
     mlir::OpBuilder::InsertPoint insertPt = builder->saveInsertionPoint();
+    // Register if a declare target construct intended for a target device was
+    // found
+    ompDeviceCodeFound =
+        ompDeviceCodeFound ||
+        Fortran::lower::isOpenMPDeviceDeclareTarget(*this, getEval(), ompDecl);
     genOpenMPDeclarativeConstruct(*this, getEval(), ompDecl);
     for (Fortran::lower::pft::Evaluation &e : getEval().getNestedEvaluations())
       genFIR(e);
@@ -4758,6 +4775,16 @@ class FirConverter : public Fortran::lower::AbstractConverter {
                                                      accRoutineInfos);
   }
 
+  /// Perform the OpenMP lowering actions that were deferred to the end of
+  /// lowering. \p globalOmpRequiresSymbol is forwarded as is and may be null.
+  void finalizeOpenMPLowering(
+      const Fortran::semantics::Symbol *globalOmpRequiresSymbol) {
+    // Only set the module's OpenMP requires attribute if device code was found
+    if (ompDeviceCodeFound)
+      Fortran::lower::genOpenMPRequires(getModuleOp().getOperation(),
+                                        globalOmpRequiresSymbol);
+  }
+
   //===--------------------------------------------------------------------===//
 
   Fortran::lower::LoweringBridge &bridge;
@@ -4804,6 +4831,10 @@ class FirConverter : public Fortran::lower::AbstractConverter {
 
   /// Deferred OpenACC routine attachment.
   Fortran::lower::AccRoutineInfoMappingList accRoutineInfos;
+
+  /// Whether an OpenMP target region or declare target function/subroutine
+  /// intended for device offloading has been detected
+  bool ompDeviceCodeFound = false;
 };
 
 } // namespace

diff  --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
index be9c0dcdbdbf485..bc18898426aa85b 100644
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -78,9 +78,7 @@ static void genObjectList(const Fortran::parser::OmpObjectList &objectList,
 static void gatherFuncAndVarSyms(
     const Fortran::parser::OmpObjectList &objList,
     mlir::omp::DeclareTargetCaptureClause clause,
-    llvm::SmallVectorImpl<std::pair<mlir::omp::DeclareTargetCaptureClause,
-                                    Fortran::semantics::Symbol>>
-        &symbolAndClause) {
+    llvm::SmallVectorImpl<DeclareTargetCapturePair> &symbolAndClause) {
   for (const Fortran::parser::OmpObject &ompObject : objList.v) {
     Fortran::common::visit(
         Fortran::common::visitors{
@@ -2453,6 +2451,71 @@ genTeamsOp(Fortran::lower::AbstractConverter &converter,
                                  reductionDeclSymbols));
 }
 
+/// Collect in \p symbolAndClause the symbols named by the given 'declare
+/// target' directive and return the device type requested for them.
+static mlir::omp::DeclareTargetDeviceType getDeclareTargetInfo(
+    Fortran::lower::AbstractConverter &converter,
+    Fortran::lower::pft::Evaluation &eval,
+    const Fortran::parser::OpenMPDeclareTargetConstruct &declareTargetConstruct,
+    llvm::SmallVectorImpl<DeclareTargetCapturePair> &symbolAndClause) {
+
+  // Default device type, returned unless processDeviceType() below changes it
+  mlir::omp::DeclareTargetDeviceType deviceType =
+      mlir::omp::DeclareTargetDeviceType::any;
+  const auto &spec = std::get<Fortran::parser::OmpDeclareTargetSpecifier>(
+      declareTargetConstruct.t);
+
+  if (const auto *objectList{
+          Fortran::parser::Unwrap<Fortran::parser::OmpObjectList>(spec.u)}) {
+    // Case: declare target(func, var1, var2), gathered with the 'to' clause
+    gatherFuncAndVarSyms(*objectList, mlir::omp::DeclareTargetCaptureClause::to,
+                         symbolAndClause);
+  } else if (const auto *clauseList{
+                 Fortran::parser::Unwrap<Fortran::parser::OmpClauseList>(
+                     spec.u)}) {
+    if (clauseList->v.empty()) {
+      // Case: bare declare target, implicitly capturing the owning procedure
+      symbolAndClause.emplace_back(
+          mlir::omp::DeclareTargetCaptureClause::to,
+          eval.getOwningProcedure()->getSubprogramSymbol());
+    }
+
+    ClauseProcessor cp(converter, *clauseList);
+    cp.processTo(symbolAndClause);
+    cp.processLink(symbolAndClause);
+    cp.processDeviceType(deviceType);
+    cp.processTODO<Fortran::parser::OmpClause::Indirect>(
+        converter.getCurrentLocation(),
+        llvm::omp::Directive::OMPD_declare_target);
+  }
+
+  return deviceType;
+}
+
+static std::optional<mlir::omp::DeclareTargetDeviceType>
+getDeclareTargetFunctionDevice(
+    Fortran::lower::AbstractConverter &converter,
+    Fortran::lower::pft::Evaluation &eval,
+    const Fortran::parser::OpenMPDeclareTargetConstruct
+        &declareTargetConstruct) {
+  llvm::SmallVector<DeclareTargetCapturePair, 0> symbolAndClause;
+  mlir::omp::DeclareTargetDeviceType deviceType = getDeclareTargetInfo(
+      converter, eval, declareTargetConstruct, symbolAndClause);
+
+  // Return the device type only if at least one of the targets for the
+  // directive is a function or subroutine; otherwise return std::nullopt
+  mlir::ModuleOp mod = converter.getFirOpBuilder().getModule();
+  for (const DeclareTargetCapturePair &symClause : symbolAndClause) {
+    mlir::Operation *op = mod.lookupSymbol(
+        converter.mangleName(std::get<Fortran::semantics::Symbol>(symClause)));
+    // The lookup yields null when the module has no operation for the symbol
+    if (mlir::isa_and_nonnull<mlir::func::FuncOp>(op))
+      return deviceType;
+  }
+
+  return std::nullopt;
+}
+
 //===----------------------------------------------------------------------===//
 // genOMP() Code generation helper functions
 //===----------------------------------------------------------------------===//
@@ -2973,35 +3036,8 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
                        &declareTargetConstruct) {
   llvm::SmallVector<DeclareTargetCapturePair, 0> symbolAndClause;
   mlir::ModuleOp mod = converter.getFirOpBuilder().getModule();
-
-  // The default capture type
-  mlir::omp::DeclareTargetDeviceType deviceType =
-      mlir::omp::DeclareTargetDeviceType::any;
-  const auto &spec = std::get<Fortran::parser::OmpDeclareTargetSpecifier>(
-      declareTargetConstruct.t);
-  if (const auto *objectList{
-          Fortran::parser::Unwrap<Fortran::parser::OmpObjectList>(spec.u)}) {
-    // Case: declare target(func, var1, var2)
-    gatherFuncAndVarSyms(*objectList, mlir::omp::DeclareTargetCaptureClause::to,
-                         symbolAndClause);
-  } else if (const auto *clauseList{
-                 Fortran::parser::Unwrap<Fortran::parser::OmpClauseList>(
-                     spec.u)}) {
-    if (clauseList->v.empty()) {
-      // Case: declare target, implicit capture of function
-      symbolAndClause.emplace_back(
-          mlir::omp::DeclareTargetCaptureClause::to,
-          eval.getOwningProcedure()->getSubprogramSymbol());
-    }
-
-    ClauseProcessor cp(converter, *clauseList);
-    cp.processTo(symbolAndClause);
-    cp.processLink(symbolAndClause);
-    cp.processDeviceType(deviceType);
-    cp.processTODO<Fortran::parser::OmpClause::Indirect>(
-        converter.getCurrentLocation(),
-        llvm::omp::Directive::OMPD_declare_target);
-  }
+  mlir::omp::DeclareTargetDeviceType deviceType = getDeclareTargetInfo(
+      converter, eval, declareTargetConstruct, symbolAndClause);
 
   for (const DeclareTargetCapturePair &symClause : symbolAndClause) {
     mlir::Operation *op = mod.lookupSymbol(
@@ -3130,7 +3166,10 @@ void Fortran::lower::genOpenMPDeclarativeConstruct(
           },
           [&](const Fortran::parser::OpenMPRequiresConstruct
                   &requiresConstruct) {
-            TODO(converter.getCurrentLocation(), "OpenMPRequiresConstruct");
+            // Requires directives are gathered and processed in semantics and
+            // then combined in the lowering bridge before triggering codegen
+            // just once. Hence, there is no need to lower each individual
+            // occurrence here.
           },
           [&](const Fortran::parser::OpenMPThreadprivate &threadprivate) {
             // The directive is lowered when instantiating the variable to
@@ -3444,3 +3483,72 @@ void Fortran::lower::removeStoreOp(mlir::Operation *reductionOp,
     }
   }
 }
+
+bool Fortran::lower::isOpenMPTargetConstruct(
+    const Fortran::parser::OpenMPConstruct &omp) {
+  namespace parser = Fortran::parser;
+  // Anything but a block or loop construct is tested as OMPD_unknown.
+  llvm::omp::Directive directive = llvm::omp::Directive::OMPD_unknown;
+  if (const auto *blockOmp =
+          std::get_if<parser::OpenMPBlockConstruct>(&omp.u)) {
+    const auto &start = std::get<parser::OmpBeginBlockDirective>(blockOmp->t);
+    directive = std::get<parser::OmpBlockDirective>(start.t).v;
+  } else if (const auto *loopOmp =
+                 std::get_if<parser::OpenMPLoopConstruct>(&omp.u)) {
+    const auto &start = std::get<parser::OmpBeginLoopDirective>(loopOmp->t);
+    directive = std::get<parser::OmpLoopDirective>(start.t).v;
+  }
+  return llvm::omp::allTargetSet.test(directive);
+}
+
+bool Fortran::lower::isOpenMPDeviceDeclareTarget(
+    Fortran::lower::AbstractConverter &converter,
+    Fortran::lower::pft::Evaluation &eval,
+    const Fortran::parser::OpenMPDeclarativeConstruct &ompDecl) {
+  using DeviceType = mlir::omp::DeclareTargetDeviceType;
+  // Every other declarative construct is reported as not being device code.
+  const auto *declareTarget =
+      std::get_if<Fortran::parser::OpenMPDeclareTargetConstruct>(&ompDecl.u);
+  if (!declareTarget)
+    return false;
+  // An empty result (no function or subroutine among the directive's
+  // targets) is treated the same as a host-only device type.
+  std::optional<DeviceType> functionDevice =
+      getDeclareTargetFunctionDevice(converter, eval, *declareTarget);
+  return functionDevice && *functionDevice != DeviceType::host;
+}
+
+void Fortran::lower::genOpenMPRequires(
+    mlir::Operation *mod, const Fortran::semantics::Symbol *symbol) {
+  using MlirRequires = mlir::omp::ClauseRequires;
+  using SemaDecl = Fortran::semantics::WithOmpDeclarative;
+  using SemaRequires = SemaDecl::RequiresFlag;
+  // Operations that are not an OffloadModuleInterface are left untouched.
+  auto offloadMod = llvm::dyn_cast<mlir::omp::OffloadModuleInterface>(mod);
+  if (!offloadMod)
+    return;
+
+  SemaDecl::RequiresFlags semaFlags;
+  if (symbol) {
+    Fortran::common::visit(
+        [&](const auto &details) {
+          using Details = std::decay_t<decltype(details)>;
+          if constexpr (std::is_base_of_v<SemaDecl, Details>) {
+            if (details.has_ompRequires())
+              semaFlags = *details.ompRequires();
+          }
+        },
+        symbol->details());
+  }
+
+  MlirRequires mlirFlags = MlirRequires::none;
+  if (semaFlags.test(SemaRequires::ReverseOffload))
+    mlirFlags = mlirFlags | MlirRequires::reverse_offload;
+  if (semaFlags.test(SemaRequires::UnifiedAddress))
+    mlirFlags = mlirFlags | MlirRequires::unified_address;
+  if (semaFlags.test(SemaRequires::UnifiedSharedMemory))
+    mlirFlags = mlirFlags | MlirRequires::unified_shared_memory;
+  if (semaFlags.test(SemaRequires::DynamicAllocators))
+    mlirFlags = mlirFlags | MlirRequires::dynamic_allocators;
+  offloadMod.setRequires(mlirFlags);
+}

diff  --git a/flang/test/Lower/OpenMP/Todo/requires-unnamed-common.f90 b/flang/test/Lower/OpenMP/Todo/requires-unnamed-common.f90
new file mode 100644
index 000000000000000..578f71a55e43b68
--- /dev/null
+++ b/flang/test/Lower/OpenMP/Todo/requires-unnamed-common.f90
@@ -0,0 +1,23 @@
+! This test checks the lowering of REQUIRES inside of an unnamed BLOCK DATA.
+! The symbol of the `symTab` scope of the `BlockDataUnit` PFT node is null in
+! this case, resulting in the inability to store the REQUIRES flags gathered in
+! it.
+
+! RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | FileCheck %s
+! RUN: %flang_fc1 -emit-fir -fopenmp -fopenmp-is-target-device %s -o - | FileCheck %s
+! RUN: bbc -fopenmp -emit-fir %s -o - | FileCheck %s
+! RUN: bbc -fopenmp -fopenmp-is-target-device -emit-fir %s -o - | FileCheck %s
+! XFAIL: *
+
+!CHECK:         module attributes {
+!CHECK-SAME:    omp.requires = #omp<clause_requires unified_shared_memory>
+block data
+  !$omp requires unified_shared_memory
+  integer :: x
+  common /block/ x
+  data x / 10 /
+end
+
+subroutine f
+  !$omp declare target
+end subroutine f

diff  --git a/flang/test/Lower/OpenMP/requires-common.f90 b/flang/test/Lower/OpenMP/requires-common.f90
new file mode 100644
index 000000000000000..2e112d72de3fdfc
--- /dev/null
+++ b/flang/test/Lower/OpenMP/requires-common.f90
@@ -0,0 +1,19 @@
+! RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | FileCheck %s
+! RUN: %flang_fc1 -emit-fir -fopenmp -fopenmp-is-target-device %s -o - | FileCheck %s
+! RUN: bbc -fopenmp -emit-fir %s -o - | FileCheck %s
+! RUN: bbc -fopenmp -fopenmp-is-target-device -emit-fir %s -o - | FileCheck %s
+
+! This test checks the lowering of requires into MLIR
+
+!CHECK:      module attributes {
+!CHECK-SAME: omp.requires = #omp<clause_requires unified_shared_memory>
+block data init
+  !$omp requires unified_shared_memory
+  integer :: x
+  common /block/ x
+  data x / 10 /
+end
+
+subroutine f
+  !$omp declare target
+end subroutine f

diff  --git a/flang/test/Lower/OpenMP/requires-notarget.f90 b/flang/test/Lower/OpenMP/requires-notarget.f90
new file mode 100644
index 000000000000000..bfa509208428879
--- /dev/null
+++ b/flang/test/Lower/OpenMP/requires-notarget.f90
@@ -0,0 +1,14 @@
+! RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | FileCheck %s
+! RUN: %flang_fc1 -emit-fir -fopenmp -fopenmp-is-target-device %s -o - | FileCheck %s
+! RUN: bbc -fopenmp -emit-fir %s -o - | FileCheck %s
+! RUN: bbc -fopenmp -fopenmp-is-target-device -emit-fir %s -o - | FileCheck %s
+
+! This test checks that requires lowering into MLIR skips creating the
+! omp.requires attribute with target-related clauses if there are no device
+! functions in the compilation unit
+
+!CHECK:      module attributes {
+!CHECK-NOT:  omp.requires
+program requires
+  !$omp requires unified_shared_memory reverse_offload atomic_default_mem_order(seq_cst)
+end program requires

diff  --git a/flang/test/Lower/OpenMP/requires.f90 b/flang/test/Lower/OpenMP/requires.f90
new file mode 100644
index 000000000000000..bc53931b9f240cf
--- /dev/null
+++ b/flang/test/Lower/OpenMP/requires.f90
@@ -0,0 +1,14 @@
+! RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | FileCheck %s
+! RUN: %flang_fc1 -emit-fir -fopenmp -fopenmp-is-target-device %s -o - | FileCheck %s
+! RUN: bbc -fopenmp -emit-fir %s -o - | FileCheck %s
+! RUN: bbc -fopenmp -fopenmp-is-target-device -emit-fir %s -o - | FileCheck %s
+
+! This test checks the lowering of requires into MLIR
+
+!CHECK:      module attributes {
+!CHECK-SAME: omp.requires = #omp<clause_requires reverse_offload|unified_shared_memory>
+program requires
+  !$omp requires unified_shared_memory reverse_offload atomic_default_mem_order(seq_cst)
+  !$omp target
+  !$omp end target
+end program requires


        


More information about the flang-commits mailing list