[flang-commits] [flang] [Flang][MLIR][OpenMP] Create a deferred declare target marking process for Bridge.cpp (PR #78502)

via flang-commits flang-commits at lists.llvm.org
Wed Jan 17 12:56:11 PST 2024


https://github.com/agozillon updated https://github.com/llvm/llvm-project/pull/78502

From 5c0493c485a532d5bf30bccc871ca110a8f2340f Mon Sep 17 00:00:00 2001
From: Andrew Gozillon <Andrew.Gozillon at amd.com>
Date: Wed, 17 Jan 2024 14:35:55 -0600
Subject: [PATCH 1/2] [Flang][MLIR][OpenMP] Create a deferred declare target
 marking process for Bridge.cpp

This patch seeks to create a process that runs upon module
finalisation for OpenMP, in which operations that had declare
target directives applied to them, but had not yet been generated
when the original declare target directive was processed, are
re-checked so that the appropriate declare target semantics are
applied.
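As a rough illustration of the semantics being applied, the sketch below is a minimal, self-contained C++ model of the merge rule that this patch factors out into markDeclareTarget: an unmarked operation takes the requested marking, an operation already marked (e.g. through implicit capture) with the same device type is left unchanged, and one marked with a different device type is widened to any. ModelOp, DeclareTargetAttr and the two enums are hypothetical stand-ins for illustration only, not the MLIR DeclareTargetInterface API.

```cpp
#include <cassert>
#include <optional>

// Hypothetical stand-ins for mlir::omp::DeclareTargetDeviceType and
// mlir::omp::DeclareTargetCaptureClause (not the real MLIR enums).
enum class DeviceType { any, host, nohost };
enum class CaptureClause { to, link, enter };

struct DeclareTargetAttr {
  DeviceType deviceType;
  CaptureClause captureClause;
};

// Hypothetical model of an operation implementing DeclareTargetInterface;
// an empty optional means "not declare target".
struct ModelOp {
  std::optional<DeclareTargetAttr> declareTarget;
};

// Models the rule in markDeclareTarget: an operation already marked with a
// different device type becomes `any`; an identical device type is left
// untouched; an unmarked operation takes the requested marking.
void markDeclareTarget(ModelOp &op, CaptureClause clause, DeviceType devType) {
  if (op.declareTarget) {
    if (op.declareTarget->deviceType != devType)
      op.declareTarget = DeclareTargetAttr{DeviceType::any, clause};
    return;
  }
  op.declareTarget = DeclareTargetAttr{devType, clause};
}
```

For example, marking an operation nohost (implicit capture) and then host (explicit directive) leaves it marked any with the capture clause of the later call.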

This works by maintaining a vector of declare target related data
inside of the FIR converter: the symbol and two unsigned integers
holding the underlying values of the capture clause and device type
enumerators. The vector is populated via a new function called from
Bridge.cpp, gatherDeferredDeclareTargets (which forwards to the
static helper insertDeferredDeclareTargets in OpenMP.cpp). It runs
prior to the processing of the directive (similarly to
getDeclareTargetFunctionDevice, which is currently used for
requires) and checks whether the operation the declare target
directive is applied to currently exists; if it doesn't, the symbol
is appended to the vector. This is a separate function from the
processing of the declare target via the overloaded genOMP because
we unfortunately do not have access to the list without passing it
through every call, as the AbstractConverter we pass does not allow
access to it (I've seen no other cases of casting it to a
FirConverter, so I opted not to do that).
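A minimal sketch of the gathering step, assuming a std::map from mangled name to an operation id stands in for the module's symbol table (the real code uses mlir::ModuleOp::lookupSymbol and stores a Fortran::semantics::Symbol). The enums, DeferredEntry, ModuleSymbols, toUnderlying and gatherDeferred are hypothetical names for illustration. It shows the enumerators being stored by their underlying integer values as the patch does, presumably so that OpenMP.h does not need to name the MLIR enum types.

```cpp
#include <cassert>
#include <cstdint>
#include <map>
#include <string>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>

// Hypothetical stand-ins for the MLIR enums (values are illustrative).
enum class CaptureClause : uint32_t { to, link, enter };
enum class DeviceType : uint32_t { any, host, nohost };

// (capture clause, device type, symbol); the patch stores a
// Fortran::semantics::Symbol, modelled here by its mangled name.
using DeferredEntry = std::tuple<uint32_t, uint32_t, std::string>;

// Hypothetical stand-in for the module symbol table: mangled name -> op id.
using ModuleSymbols = std::map<std::string, int>;

template <typename E>
constexpr std::underlying_type_t<E> toUnderlying(E e) {
  return static_cast<std::underlying_type_t<E>>(e);
}

// Models insertDeferredDeclareTargets: any symbol named by the directive
// whose operation is not in the module yet is recorded for later; symbols
// that already exist are left to the regular genOMP path.
void gatherDeferred(
    const ModuleSymbols &mod,
    const std::vector<std::pair<CaptureClause, std::string>> &symbolAndClause,
    DeviceType devType, std::vector<DeferredEntry> &deferred) {
  for (const auto &[clause, name] : symbolAndClause)
    if (mod.find(name) == mod.end())
      deferred.emplace_back(toUnderlying(clause), toUnderlying(devType), name);
}
```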

The list is then processed at the end of the module in the
finalizeOpenMPLowering function in Bridge by calling a new
function, markDelayedDeclareTargetFunctions, which marks the
latently generated operations. In certain cases some operations
will still not have been generated: e.g. if an interface is
declared and marked as declare target, but has no definition or
usage in the module, then it will not be emitted to the module.
Due to these cases we must silently ignore any operation that
cannot be found via its symbol.
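The finalization step can be modelled the same way. This sketch, again using hypothetical stand-in types (ModelOp, ModuleSymbols, markDeferred) rather than the MLIR API, looks up each deferred entry, silently skips symbols that were never emitted, applies the same merge rule as markDeclareTarget, and reports whether any found entry requested a non-host device type, which is the value finalizeOpenMPLowering folds into ompDeviceCodeFound.

```cpp
#include <cassert>
#include <cstdint>
#include <map>
#include <optional>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

// Hypothetical stand-ins for the MLIR enums (values are illustrative).
enum class CaptureClause : uint32_t { to, link, enter };
enum class DeviceType : uint32_t { any, host, nohost };

// (capture clause, device type, mangled symbol name), as gathered earlier.
using DeferredEntry = std::tuple<uint32_t, uint32_t, std::string>;

// Hypothetical model of an operation and of the finished module.
struct ModelOp {
  std::optional<std::pair<DeviceType, CaptureClause>> declareTarget;
};
using ModuleSymbols = std::map<std::string, ModelOp>;

// Models markDelayedDeclareTargetFunctions: returns true if any deferred
// entry that was found in the module requests a non-host device type.
bool markDeferred(ModuleSymbols &mod,
                  const std::vector<DeferredEntry> &deferred) {
  bool deviceCodeFound = false;
  for (const auto &[rawClause, rawDevType, name] : deferred) {
    auto it = mod.find(name);
    if (it == mod.end())
      continue; // Never emitted (e.g. an unused interface): skip silently.

    auto clause = static_cast<CaptureClause>(rawClause);
    auto devType = static_cast<DeviceType>(rawDevType);
    if (devType != DeviceType::host)
      deviceCodeFound = true;

    // Same merge rule as markDeclareTarget: differing device types -> any.
    ModelOp &op = it->second;
    if (!op.declareTarget)
      op.declareTarget = std::make_pair(devType, clause);
    else if (op.declareTarget->first != devType)
      op.declareTarget = std::make_pair(DeviceType::any, clause);
  }
  return deviceCodeFound;
}
```

Because the lookup uses find rather than operator[], an entry such as unused_unemitted_interface that was never emitted leaves the module untouched.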

The main use case for this (although I imagine there are others)
is processing interfaces that have been declared in a module with
a declare target directive but do not have their implementation
defined in the same module, for example, when it lives in a
separate C++ module that will be linked in. In cases where the
interface is called inside of a target region it'll be marked as
used on device appropriately (although, realistically, a user
should explicitly mark it to match the corresponding definition).
However, in cases where it's used less directly, e.g. through a
function pointer passed to an external call, we require this
explicit marking, which this patch adds support for (such cases
currently cause the compiler to crash).

An alternative to this approach may be to do it in the implicit
marking pass, by searching for function pointers/addresses used in
target regions and marking the FuncOps appropriately. However, this
doesn't seem like the right approach: we would be ignoring the
user's explicit markings, if they exist, in favour of our own
assumptions (we'd have no trivial way of saving the data and
transferring it to the pass), and it would also not support other
operation types that may need deferred marking.
---
 flang/include/flang/Lower/OpenMP.h            |  16 ++
 flang/lib/Lower/Bridge.cpp                    |  24 ++-
 flang/lib/Lower/OpenMP.cpp                    | 162 ++++++++++++++----
 .../declare-target-deferred-marking.f90       |  60 +++++++
 4 files changed, 227 insertions(+), 35 deletions(-)
 create mode 100644 flang/test/Lower/OpenMP/declare-target-deferred-marking.f90

diff --git a/flang/include/flang/Lower/OpenMP.h b/flang/include/flang/Lower/OpenMP.h
index 872b7d50c3e7a8b..6a97590c0f02b92 100644
--- a/flang/include/flang/Lower/OpenMP.h
+++ b/flang/include/flang/Lower/OpenMP.h
@@ -14,6 +14,9 @@
 #define FORTRAN_LOWER_OPENMP_H
 
 #include <cinttypes>
+#include <tuple>
+
+#include "llvm/ADT/SmallVector.h"
 
 namespace mlir {
 class Value;
@@ -84,6 +87,19 @@ bool isOpenMPTargetConstruct(const parser::OpenMPConstruct &);
 bool isOpenMPDeviceDeclareTarget(Fortran::lower::AbstractConverter &,
                                  Fortran::lower::pft::Evaluation &,
                                  const parser::OpenMPDeclarativeConstruct &);
+void gatherDeferredDeclareTargets(
+    Fortran::lower::AbstractConverter &, Fortran::lower::pft::Evaluation &,
+    const parser::OpenMPDeclarativeConstruct &,
+    llvm::SmallVectorImpl<std::tuple<
+        uint32_t /*mlir::omp::DeclareTargetCaptureClause*/, uint32_t,
+        /*mlir::omp::DeclareTargetDeviceType*/ Fortran::semantics::Symbol>> &);
+bool markDelayedDeclareTargetFunctions(
+    mlir::Operation *,
+    llvm::SmallVectorImpl<
+        std::tuple<uint32_t /*mlir::omp::DeclareTargetCaptureClause*/,
+                   uint32_t, /*mlir::omp::DeclareTargetDeviceType*/
+                   Fortran::semantics::Symbol>> &,
+    AbstractConverter &);
 void genOpenMPRequires(mlir::Operation *, const Fortran::semantics::Symbol *);
 
 } // namespace lower
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 8006b9b426f4dc6..45c395d5f803bfd 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -2440,6 +2440,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
     ompDeviceCodeFound =
         ompDeviceCodeFound ||
         Fortran::lower::isOpenMPDeviceDeclareTarget(*this, getEval(), ompDecl);
+    Fortran::lower::gatherDeferredDeclareTargets(*this, getEval(), ompDecl,
+                                                 deferredDeclareTarget);
     genOpenMPDeclarativeConstruct(
         *this, localSymbols, bridge.getSemanticsContext(), getEval(), ompDecl);
     builder->restoreInsertionPoint(insertPt);
@@ -4656,8 +4658,9 @@ class FirConverter : public Fortran::lower::AbstractConverter {
 
   /// Lower functions contained in a module.
   void lowerMod(Fortran::lower::pft::ModuleLikeUnit &mod) {
-    for (Fortran::lower::pft::FunctionLikeUnit &f : mod.nestedFunctions)
+    for (Fortran::lower::pft::FunctionLikeUnit &f : mod.nestedFunctions) {
       lowerFunc(f);
+    }
   }
 
   void setCurrentPosition(const Fortran::parser::CharBlock &position) {
@@ -4959,6 +4962,14 @@ class FirConverter : public Fortran::lower::AbstractConverter {
   /// lowering.
   void finalizeOpenMPLowering(
       const Fortran::semantics::Symbol *globalOmpRequiresSymbol) {
+    if (!deferredDeclareTarget.empty()) {
+      bool deferredDeviceFuncFound =
+          Fortran::lower::markDelayedDeclareTargetFunctions(
+              getModuleOp().getOperation(), deferredDeclareTarget, *this);
+      if (!ompDeviceCodeFound)
+        ompDeviceCodeFound = deferredDeviceFuncFound;
+    }
+
     // Set the module attribute related to OpenMP requires directives
     if (ompDeviceCodeFound)
       Fortran::lower::genOpenMPRequires(getModuleOp().getOperation(),
@@ -5015,6 +5026,17 @@ class FirConverter : public Fortran::lower::AbstractConverter {
   /// intended for device offloading has been detected
   bool ompDeviceCodeFound = false;
 
+  /// Keeps track of symbols defined as declare target that could not be
+  /// processed at the time of lowering the declare target construct, such
+  /// as certain cases where interfaces are declared but not defined within
+  /// a module.
+  llvm::SmallVector<
+      std::tuple<uint32_t /*mlir::omp::DeclareTargetCaptureClause*/,
+                 uint32_t, /*mlir::omp::DeclareTargetDeviceType*/
+                 Fortran::semantics::Symbol>,
+      2>
+      deferredDeclareTarget;
+
   const Fortran::lower::ExprToValueMap *exprValueOverrides{nullptr};
 
   /// Stack of derived type under construction to avoid infinite loops when
diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
index c770d1c60718c35..3a848f299cbc16b 100644
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -2885,6 +2885,38 @@ static mlir::omp::DeclareTargetDeviceType getDeclareTargetInfo(
   return deviceType;
 }
 
+static void insertDeferredDeclareTargets(
+    Fortran::lower::AbstractConverter &converter,
+    Fortran::lower::pft::Evaluation &eval,
+    const Fortran::parser::OpenMPDeclareTargetConstruct &declareTargetConstruct,
+    llvm::SmallVectorImpl<
+        std::tuple<uint32_t /*mlir::omp::DeclareTargetCaptureClause*/,
+                   uint32_t, /*mlir::omp::DeclareTargetDeviceType*/
+                   Fortran::semantics::Symbol>> &deferredDeclareTarget) {
+  llvm::SmallVector<DeclareTargetCapturePair, 0> symbolAndClause;
+  mlir::omp::DeclareTargetDeviceType devType = getDeclareTargetInfo(
+      converter, eval, declareTargetConstruct, symbolAndClause);
+  // Defer any symbol whose operation has not been emitted into the module
+  // yet; it is marked later, upon finalization of the module in Bridge.
+  mlir::ModuleOp mod = converter.getFirOpBuilder().getModule();
+
+  for (const DeclareTargetCapturePair &symClause : symbolAndClause) {
+    mlir::Operation *op = mod.lookupSymbol(
+        converter.mangleName(std::get<Fortran::semantics::Symbol>(symClause)));
+
+    if (!op) {
+      deferredDeclareTarget.push_back(std::make_tuple(
+          static_cast<
+              std::underlying_type_t<mlir::omp::DeclareTargetCaptureClause>>(
+              std::get<0>(symClause)),
+          static_cast<
+              std::underlying_type_t<mlir::omp::DeclareTargetDeviceType>>(
+              devType),
+          std::get<1>(symClause)));
+    }
+  }
+}
+
 static std::optional<mlir::omp::DeclareTargetDeviceType>
 getDeclareTargetFunctionDevice(
     Fortran::lower::AbstractConverter &converter,
@@ -2902,7 +2934,7 @@ getDeclareTargetFunctionDevice(
     mlir::Operation *op = mod.lookupSymbol(
         converter.mangleName(std::get<Fortran::semantics::Symbol>(symClause)));
 
-    if (mlir::isa<mlir::func::FuncOp>(op))
+    if (mlir::isa_and_nonnull<mlir::func::FuncOp>(op))
       return deviceType;
   }
 
@@ -3499,6 +3531,31 @@ genOMP(Fortran::lower::AbstractConverter &converter,
       atomicConstruct.u);
 }
 
+static void
+markDeclareTarget(mlir::Operation *op,
+                  Fortran::lower::AbstractConverter &converter,
+                  mlir::omp::DeclareTargetCaptureClause captureClause,
+                  mlir::omp::DeclareTargetDeviceType deviceType) {
+  auto declareTargetOp = llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(op);
+  if (!declareTargetOp)
+    fir::emitFatalError(
+        converter.getCurrentLocation(),
+        "Attempt to apply declare target on unsupported operation");
+
+  // The function or global already has a declare target applied to it, very
+  // likely through implicit capture (usage in another declare target
+  // function/subroutine). It should be marked as any if it has been assigned
+  // both host and nohost, else we skip, as there is no change
+  if (declareTargetOp.isDeclareTarget()) {
+    if (declareTargetOp.getDeclareTargetDeviceType() != deviceType)
+      declareTargetOp.setDeclareTarget(mlir::omp::DeclareTargetDeviceType::any,
+                                       captureClause);
+    return;
+  }
+
+  declareTargetOp.setDeclareTarget(deviceType, captureClause);
+}
+
 static void genOMP(Fortran::lower::AbstractConverter &converter,
                    Fortran::lower::SymMap &symTable,
                    Fortran::semantics::SemanticsContext &semanticsContext,
@@ -3513,42 +3570,16 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
   for (const DeclareTargetCapturePair &symClause : symbolAndClause) {
     mlir::Operation *op = mod.lookupSymbol(
         converter.mangleName(std::get<Fortran::semantics::Symbol>(symClause)));
-    // There's several cases this can currently be triggered and it could be
-    // one of the following:
-    // 1) Invalid argument passed to a declare target that currently isn't
-    // captured by a frontend semantic check
-    // 2) The symbol of a valid argument is not correctly updated by one of
-    // the prior passes, resulting in missing symbol information
-    // 3) It's a variable internal to a module or program, that is legal by
-    // Fortran OpenMP standards, but is currently unhandled as they do not
-    // appear in the symbol table as they are represented as allocas
+
+    // Some symbols are not emitted until later in the module; these are
+    // handled upon finalization of the module for OpenMP inside of Bridge,
+    // so we simply skip them for now.
     if (!op)
-      TODO(converter.getCurrentLocation(),
-           "Missing symbol, possible case of currently unsupported use of "
-           "a program local variable in declare target or erroneous symbol "
-           "information ");
-
-    auto declareTargetOp =
-        llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(op);
-    if (!declareTargetOp)
-      fir::emitFatalError(
-          converter.getCurrentLocation(),
-          "Attempt to apply declare target on unsupported operation");
-
-    // The function or global already has a declare target applied to it, very
-    // likely through implicit capture (usage in another declare target
-    // function/subroutine). It should be marked as any if it has been assigned
-    // both host and nohost, else we skip, as there is no change
-    if (declareTargetOp.isDeclareTarget()) {
-      if (declareTargetOp.getDeclareTargetDeviceType() != deviceType)
-        declareTargetOp.setDeclareTarget(
-            mlir::omp::DeclareTargetDeviceType::any,
-            std::get<mlir::omp::DeclareTargetCaptureClause>(symClause));
       continue;
-    }
 
-    declareTargetOp.setDeclareTarget(
-        deviceType, std::get<mlir::omp::DeclareTargetCaptureClause>(symClause));
+    markDeclareTarget(
+        op, converter,
+        std::get<mlir::omp::DeclareTargetCaptureClause>(symClause), deviceType);
   }
 }
 
@@ -4015,6 +4046,25 @@ bool Fortran::lower::isOpenMPTargetConstruct(
   return llvm::omp::allTargetSet.test(dir);
 }
 
+void Fortran::lower::gatherDeferredDeclareTargets(
+    Fortran::lower::AbstractConverter &converter,
+    Fortran::lower::pft::Evaluation &eval,
+    const Fortran::parser::OpenMPDeclarativeConstruct &ompDecl,
+    llvm::SmallVectorImpl<
+        std::tuple<uint32_t /*mlir::omp::DeclareTargetCaptureClause*/,
+                   uint32_t, /*mlir::omp::DeclareTargetDeviceType*/
+                   Fortran::semantics::Symbol>> &deferredDeclareTarget) {
+  std::visit(
+      Fortran::common::visitors{
+          [&](const Fortran::parser::OpenMPDeclareTargetConstruct &ompReq) {
+            insertDeferredDeclareTargets(converter, eval, ompReq,
+                                         deferredDeclareTarget);
+          },
+          [&](const auto &) {},
+      },
+      ompDecl.u);
+}
+
 bool Fortran::lower::isOpenMPDeviceDeclareTarget(
     Fortran::lower::AbstractConverter &converter,
     Fortran::lower::pft::Evaluation &eval,
@@ -4032,6 +4082,50 @@ bool Fortran::lower::isOpenMPDeviceDeclareTarget(
       ompDecl.u);
 }
 
+// In certain cases, such as subroutine or function interfaces that are
+// declared but not defined or directly called in the same module, their
+// lowering is delayed until after the declare target construct itself is
+// processed, so their symbol is not yet within the symbol table.
+//
+// This function also returns true if it encounters any device declare
+// target cases (device type other than host), so the caller can determine
+// whether the requires attributes are needed on the module.
+bool Fortran::lower::markDelayedDeclareTargetFunctions(
+    mlir::Operation *mod,
+    llvm::SmallVectorImpl<
+        std::tuple<uint32_t /*mlir::omp::DeclareTargetCaptureClause*/,
+                   uint32_t, /*mlir::omp::DeclareTargetDeviceType*/
+                   Fortran::semantics::Symbol>> &deferredDeclareTargets,
+    AbstractConverter &converter) {
+  bool deviceCodeFound = false;
+  if (auto modOp = llvm::dyn_cast<mlir::ModuleOp>(mod)) {
+    for (const auto &sym : deferredDeclareTargets) {
+      mlir::Operation *op =
+          modOp.lookupSymbol(converter.mangleName(std::get<2>(sym)));
+
+      // Interfaces are only emitted into the module when they are used, so
+      // not finding an operation at this point cannot be a hard error; we
+      // simply ignore it for now.
+      if (!op)
+        continue;
+
+      auto devType =
+          static_cast<mlir::omp::DeclareTargetDeviceType>(std::get<1>(sym));
+      if (!deviceCodeFound &&
+          devType != mlir::omp::DeclareTargetDeviceType::host) {
+        deviceCodeFound = true;
+      }
+
+      markDeclareTarget(
+          op, converter,
+          static_cast<mlir::omp::DeclareTargetCaptureClause>(std::get<0>(sym)),
+          devType);
+    }
+  }
+
+  return deviceCodeFound;
+}
+
 void Fortran::lower::genOpenMPRequires(
     mlir::Operation *mod, const Fortran::semantics::Symbol *symbol) {
   using MlirRequires = mlir::omp::ClauseRequires;
diff --git a/flang/test/Lower/OpenMP/declare-target-deferred-marking.f90 b/flang/test/Lower/OpenMP/declare-target-deferred-marking.f90
new file mode 100644
index 000000000000000..1998c3da23af5fe
--- /dev/null
+++ b/flang/test/Lower/OpenMP/declare-target-deferred-marking.f90
@@ -0,0 +1,60 @@
+!RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - | FileCheck %s --check-prefixes ALL,HOST
+!RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-is-device %s -o - | FileCheck %s --check-prefixes ALL
+
+program main
+    use, intrinsic ::  iso_c_binding
+    implicit none
+    interface
+    subroutine any_interface()  bind(c,name="any_interface")
+        use, intrinsic :: iso_c_binding
+        implicit none
+    !$omp declare target enter(any_interface) device_type(any)
+    end subroutine any_interface
+
+    subroutine host_interface()  bind(c,name="host_interface")
+      use, intrinsic :: iso_c_binding
+      implicit none
+   !$omp declare target enter(host_interface) device_type(host)
+    end subroutine host_interface
+
+    subroutine device_interface()  bind(c,name="device_interface")
+        use, intrinsic :: iso_c_binding
+        implicit none
+    !$omp declare target enter(device_interface) device_type(nohost)
+    end subroutine device_interface
+
+    subroutine called_from_target_interface(f1, f2) bind(c,name="called_from_target_interface")
+        use, intrinsic :: iso_c_binding
+        implicit none
+        type(c_funptr),value :: f1
+        type(c_funptr),value :: f2
+    end subroutine called_from_target_interface
+
+    subroutine called_from_host_interface(f1) bind(c,name="called_from_host_interface")
+      use, intrinsic :: iso_c_binding
+      implicit none
+      type(c_funptr),value :: f1
+    end subroutine called_from_host_interface
+
+    subroutine unused_unemitted_interface()  bind(c,name="unused_unemitted_interface")
+      use, intrinsic :: iso_c_binding
+      implicit none
+    !$omp declare target enter(unused_unemitted_interface) device_type(nohost)
+    end subroutine unused_unemitted_interface
+
+    end interface
+
+    CALL called_from_host_interface(c_funloc(host_interface))
+!$omp target
+    CALL called_from_target_interface(c_funloc(any_interface), c_funloc(device_interface))
+!$omp end target
+ end program main
+
+!HOST-LABEL: func.func {{.*}} @host_interface()
+!HOST-SAME: {{.*}}, omp.declare_target = #omp.declaretarget<device_type = (host), capture_clause = (enter)>{{.*}}
+!ALL-LABEL: func.func {{.*}} @called_from_target_interface(!fir.ref<i64>, !fir.ref<i64>)
+!ALL-SAME: {{.*}}, omp.declare_target = #omp.declaretarget<device_type = (nohost), capture_clause = (to)>{{.*}}
+!ALL-LABEL: func.func {{.*}} @any_interface()
+!ALL-SAME: {{.*}}, omp.declare_target = #omp.declaretarget<device_type = (any), capture_clause = (enter)>{{.*}}
+!ALL-LABEL: func.func {{.*}} @device_interface()
+!ALL-SAME: {{.*}}, omp.declare_target = #omp.declaretarget<device_type = (nohost), capture_clause = (enter)>{{.*}}

From e4f5723d3739c4955ff4c7e589d3f4171b92472d Mon Sep 17 00:00:00 2001
From: Andrew Gozillon <Andrew.Gozillon at amd.com>
Date: Wed, 17 Jan 2024 14:55:49 -0600
Subject: [PATCH 2/2] [NFC] Remove some extra braces clang-format may or may
 not have added

---
 flang/lib/Lower/Bridge.cpp | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 45c395d5f803bfd..df09bd5a0e34c63 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -4658,9 +4658,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
 
   /// Lower functions contained in a module.
   void lowerMod(Fortran::lower::pft::ModuleLikeUnit &mod) {
-    for (Fortran::lower::pft::FunctionLikeUnit &f : mod.nestedFunctions) {
+    for (Fortran::lower::pft::FunctionLikeUnit &f : mod.nestedFunctions)
       lowerFunc(f);
-    }
   }
 
   void setCurrentPosition(const Fortran::parser::CharBlock &position) {


