[flang-commits] [flang] [Flang][MLIR][OpenMP] Create a deferred declare target marking process for Bridge.cpp (PR #78502)

via flang-commits flang-commits at lists.llvm.org
Wed Jan 17 12:54:46 PST 2024


llvmbot wrote:



@llvm/pr-subscribers-flang-fir-hlfir

Author: None (agozillon)

<details>
<summary>Changes</summary>

This patch adds a step, run on module finalization for OpenMP, that re-checks a list of operations which had declare target directives applied to them but had not yet been generated when the original declare target directive was processed, and applies the appropriate declare target semantics to them.

This works by maintaining a vector of declare target related data inside the FIR converter: the symbol plus two unsigned integers encoding the capture clause and device type enumerators. Entries are added to this vector by a new function called from Bridge.cpp, gatherDeferredDeclareTargets (which forwards to the static helper insertDeferredDeclareTargets in OpenMP.cpp). It runs before the directive itself is processed (similarly to how getDeclareTargetFunctionDevice is currently used for requires), checks whether the operation the declare target directive is applied to currently exists, and, if it does not, appends it to the vector. This is a separate function from the processing of the declare target in the overloaded genOMP because we unfortunately do not have access to the list without passing it through every call: the AbstractConverter we pass does not give access to it (I've seen no other cases of casting it to a FirConverter, so I opted not to do that).
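
To make the gathering step more concrete, here is a minimal, self-contained C++ sketch of the idea. It is not the code in this patch: the enums `CaptureClause` and `DeviceType`, the `std::set` standing in for the module symbol table, and the function `gatherDeferred` are all hypothetical, and symbols are reduced to mangled-name strings. It only illustrates the shape of the step: look up each symbol named by one directive and, if no operation exists for it yet, append its capture clause, device type (stored as their underlying integers) and symbol to the deferred list.

```cpp
#include <cassert>
#include <cstdint>
#include <set>
#include <string>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>

// Hypothetical stand-ins for mlir::omp::DeclareTargetCaptureClause and
// mlir::omp::DeclareTargetDeviceType (not the real dialect enums).
enum class CaptureClause : uint32_t { to, link, enter };
enum class DeviceType : uint32_t { host, nohost, any };

// (capture clause, device type, symbol) as in the patch, except that the
// symbol is reduced to its mangled name for this sketch.
using DeferredEntry = std::tuple<uint32_t, uint32_t, std::string>;

// Defers every symbol named in one declare target directive that has no
// operation in the module yet; symbols that already exist are left to the
// regular (non-deferred) marking path.
void gatherDeferred(
    const std::set<std::string> &emittedSymbols,
    const std::vector<std::pair<CaptureClause, std::string>> &symbolAndClause,
    DeviceType devType, std::vector<DeferredEntry> &deferred) {
  for (const auto &[clause, name] : symbolAndClause) {
    if (emittedSymbols.count(name) != 0)
      continue;
    // Store the enumerators as their underlying integers, as the patch does.
    deferred.emplace_back(
        static_cast<std::underlying_type_t<CaptureClause>>(clause),
        static_cast<std::underlying_type_t<DeviceType>>(devType), name);
  }
}
```

Encoding the enumerators as plain integers mirrors the patch, where the header only names the MLIR enum types in comments next to the `uint32_t` fields.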

The list is then processed at the end of the module, in the finalizeOpenMPLowering function in Bridge.cpp, by calling a new function, markDelayedDeclareTargetFunctions, which marks the operations that were generated later. In certain cases some operations still will not have been generated: e.g. if an interface is declared and marked as declare target but has no definition or usage in the module, it will not be emitted to the module. Because of these cases we must silently ignore an operation that cannot be found via its symbol.
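
The finalization step can be sketched in a self-contained way as well. Again this is only an illustration under stated assumptions, not the patch's code: `Module`, `DeclareTargetAttr`, `DeferredEntry`, the two enums and `markDelayed` are hypothetical stand-ins for the module, the `omp.declare_target` attribute and the deferred tuples. The sketch follows the logic of markDelayedDeclareTargetFunctions: skip symbols that were never emitted, report whether any non-host entry was marked, and apply the merging rule of the new markDeclareTarget helper (an existing marking with a different device type becomes `any`). It omits the fatal error the real helper emits for operations that do not implement the declare target interface.

```cpp
#include <cassert>
#include <cstdint>
#include <map>
#include <optional>
#include <string>
#include <tuple>
#include <vector>

// Hypothetical stand-ins for the OpenMP dialect enums (not the real ones).
enum class CaptureClause : uint32_t { to, link, enter };
enum class DeviceType : uint32_t { host, nohost, any };

// (capture clause, device type, mangled symbol name), integers as stored.
using DeferredEntry = std::tuple<uint32_t, uint32_t, std::string>;

// Stand-in for the omp.declare_target attribute on an emitted function.
struct DeclareTargetAttr {
  DeviceType deviceType;
  CaptureClause captureClause;
};

// Stand-in for the module: emitted symbol name -> its marking, if any.
using Module = std::map<std::string, std::optional<DeclareTargetAttr>>;

// Marks deferred symbols that were emitted after their directive was lowered.
// Returns true if any marked symbol is not host-only.
bool markDelayed(Module &mod, const std::vector<DeferredEntry> &deferred) {
  bool deviceCodeFound = false;
  for (const auto &[clauseBits, devBits, name] : deferred) {
    auto it = mod.find(name);
    // Never emitted (e.g. an unused interface): silently skip it.
    if (it == mod.end())
      continue;
    auto devType = static_cast<DeviceType>(devBits);
    auto clause = static_cast<CaptureClause>(clauseBits);
    if (devType != DeviceType::host)
      deviceCodeFound = true;
    std::optional<DeclareTargetAttr> &attr = it->second;
    if (!attr)
      attr = DeclareTargetAttr{devType, clause};
    else if (attr->deviceType != devType)
      // Already marked (e.g. implicitly) with a different device type.
      attr = DeclareTargetAttr{DeviceType::any, clause};
  }
  return deviceCodeFound;
}
```

Returning a flag instead of writing to shared state matches how the patch combines the result with ompDeviceCodeFound in finalizeOpenMPLowering.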

The main use case for this (although I imagine there are others) is processing interfaces that have been declared in a module with a declare target directive but do not have their implementation defined in the same module, for example because it lives in a separate C++ module that will be linked in. Where the interface is called inside a target region, it will be marked as used on the device appropriately (although, realistically, a user should explicitly mark it to match the corresponding definition). However, where it is used less directly, e.g. through a function pointer passed to an external call, we require this explicit marking, which this patch adds support for (this case currently causes the compiler to crash).

An alternative approach may be to do this in the implicit marking pass, by searching for function pointers/addresses used in target regions and marking the corresponding FuncOps. But this doesn't seem like the right approach: we would be ignoring the user's explicit markings, if they exist, in favour of our own assumptions (we'd have no trivial way of saving that data and transferring it to the pass), and it would also not support other operation types that may need deferred marking.

---
Full diff: https://github.com/llvm/llvm-project/pull/78502.diff


4 Files Affected:

- (modified) flang/include/flang/Lower/OpenMP.h (+16) 
- (modified) flang/lib/Lower/Bridge.cpp (+23-1) 
- (modified) flang/lib/Lower/OpenMP.cpp (+128-34) 
- (added) flang/test/Lower/OpenMP/declare-target-deferred-marking.f90 (+60) 


``````````diff
diff --git a/flang/include/flang/Lower/OpenMP.h b/flang/include/flang/Lower/OpenMP.h
index 872b7d50c3e7a8b..6a97590c0f02b92 100644
--- a/flang/include/flang/Lower/OpenMP.h
+++ b/flang/include/flang/Lower/OpenMP.h
@@ -14,6 +14,9 @@
 #define FORTRAN_LOWER_OPENMP_H
 
 #include <cinttypes>
+#include <utility>
+
+#include "llvm/ADT/SmallVector.h"
 
 namespace mlir {
 class Value;
@@ -84,6 +87,19 @@ bool isOpenMPTargetConstruct(const parser::OpenMPConstruct &);
 bool isOpenMPDeviceDeclareTarget(Fortran::lower::AbstractConverter &,
                                  Fortran::lower::pft::Evaluation &,
                                  const parser::OpenMPDeclarativeConstruct &);
+void gatherDeferredDeclareTargets(
+    Fortran::lower::AbstractConverter &, Fortran::lower::pft::Evaluation &,
+    const parser::OpenMPDeclarativeConstruct &,
+    llvm::SmallVectorImpl<std::tuple<
+        uint32_t /*mlir::omp::DeclareTargetCaptureClause*/, uint32_t,
+        /*mlir::omp::DeclareTargetDeviceType*/ Fortran::semantics::Symbol>> &);
+bool markDelayedDeclareTargetFunctions(
+    mlir::Operation *,
+    llvm::SmallVectorImpl<
+        std::tuple<uint32_t /*mlir::omp::DeclareTargetCaptureClause*/,
+                   uint32_t, /*mlir::omp::DeclareTargetDeviceType*/
+                   Fortran::semantics::Symbol>> &,
+    AbstractConverter &);
 void genOpenMPRequires(mlir::Operation *, const Fortran::semantics::Symbol *);
 
 } // namespace lower
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 8006b9b426f4dc6..45c395d5f803bfd 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -2440,6 +2440,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
     ompDeviceCodeFound =
         ompDeviceCodeFound ||
         Fortran::lower::isOpenMPDeviceDeclareTarget(*this, getEval(), ompDecl);
+    Fortran::lower::gatherDeferredDeclareTargets(*this, getEval(), ompDecl,
+                                                 deferredDeclareTarget);
     genOpenMPDeclarativeConstruct(
         *this, localSymbols, bridge.getSemanticsContext(), getEval(), ompDecl);
     builder->restoreInsertionPoint(insertPt);
@@ -4656,8 +4658,9 @@ class FirConverter : public Fortran::lower::AbstractConverter {
 
   /// Lower functions contained in a module.
   void lowerMod(Fortran::lower::pft::ModuleLikeUnit &mod) {
-    for (Fortran::lower::pft::FunctionLikeUnit &f : mod.nestedFunctions)
+    for (Fortran::lower::pft::FunctionLikeUnit &f : mod.nestedFunctions) {
       lowerFunc(f);
+    }
   }
 
   void setCurrentPosition(const Fortran::parser::CharBlock &position) {
@@ -4959,6 +4962,14 @@ class FirConverter : public Fortran::lower::AbstractConverter {
   /// lowering.
   void finalizeOpenMPLowering(
       const Fortran::semantics::Symbol *globalOmpRequiresSymbol) {
+    if (!deferredDeclareTarget.empty()) {
+      bool deferredDeviceFuncFound =
+          Fortran::lower::markDelayedDeclareTargetFunctions(
+              getModuleOp().getOperation(), deferredDeclareTarget, *this);
+      if (!ompDeviceCodeFound)
+        ompDeviceCodeFound = deferredDeviceFuncFound;
+    }
+
     // Set the module attribute related to OpenMP requires directives
     if (ompDeviceCodeFound)
       Fortran::lower::genOpenMPRequires(getModuleOp().getOperation(),
@@ -5015,6 +5026,17 @@ class FirConverter : public Fortran::lower::AbstractConverter {
   /// intended for device offloading has been detected
   bool ompDeviceCodeFound = false;
 
+  /// Keeps track of symbols defined as declare target that could not be
+  /// processed at the time of lowering the declare target construct, such
+  /// as certain cases where interfaces are declared but not defined within
+  /// a module.
+  llvm::SmallVector<
+      std::tuple<uint32_t /*mlir::omp::DeclareTargetCaptureClause*/,
+                 uint32_t, /*mlir::omp::DeclareTargetDeviceType*/
+                 Fortran::semantics::Symbol>,
+      2>
+      deferredDeclareTarget;
+
   const Fortran::lower::ExprToValueMap *exprValueOverrides{nullptr};
 
   /// Stack of derived type under construction to avoid infinite loops when
diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
index c770d1c60718c35..3a848f299cbc16b 100644
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -2885,6 +2885,38 @@ static mlir::omp::DeclareTargetDeviceType getDeclareTargetInfo(
   return deviceType;
 }
 
+static void insertDeferredDeclareTargets(
+    Fortran::lower::AbstractConverter &converter,
+    Fortran::lower::pft::Evaluation &eval,
+    const Fortran::parser::OpenMPDeclareTargetConstruct &declareTargetConstruct,
+    llvm::SmallVectorImpl<
+        std::tuple<uint32_t /*mlir::omp::DeclareTargetCaptureClause*/,
+                   uint32_t, /*mlir::omp::DeclareTargetDeviceType*/
+                   Fortran::semantics::Symbol>> &deferredDeclareTarget) {
+  llvm::SmallVector<DeclareTargetCapturePair, 0> symbolAndClause;
+  mlir::omp::DeclareTargetDeviceType devType = getDeclareTargetInfo(
+      converter, eval, declareTargetConstruct, symbolAndClause);
+  // Defer every symbol whose operation has not been emitted to the module
+  // yet; these are re-checked when OpenMP lowering is finalized in Bridge.
+  mlir::ModuleOp mod = converter.getFirOpBuilder().getModule();
+
+  for (const DeclareTargetCapturePair &symClause : symbolAndClause) {
+    mlir::Operation *op = mod.lookupSymbol(
+        converter.mangleName(std::get<Fortran::semantics::Symbol>(symClause)));
+
+    if (!op) {
+      deferredDeclareTarget.push_back(std::make_tuple(
+          static_cast<
+              std::underlying_type_t<mlir::omp::DeclareTargetCaptureClause>>(
+              std::get<0>(symClause)),
+          static_cast<
+              std::underlying_type_t<mlir::omp::DeclareTargetDeviceType>>(
+              devType),
+          std::get<1>(symClause)));
+    }
+  }
+}
+
 static std::optional<mlir::omp::DeclareTargetDeviceType>
 getDeclareTargetFunctionDevice(
     Fortran::lower::AbstractConverter &converter,
@@ -2902,7 +2934,7 @@ getDeclareTargetFunctionDevice(
     mlir::Operation *op = mod.lookupSymbol(
         converter.mangleName(std::get<Fortran::semantics::Symbol>(symClause)));
 
-    if (mlir::isa<mlir::func::FuncOp>(op))
+    if (mlir::isa_and_nonnull<mlir::func::FuncOp>(op))
       return deviceType;
   }
 
@@ -3499,6 +3531,31 @@ genOMP(Fortran::lower::AbstractConverter &converter,
       atomicConstruct.u);
 }
 
+static void
+markDeclareTarget(mlir::Operation *op,
+                  Fortran::lower::AbstractConverter &converter,
+                  mlir::omp::DeclareTargetCaptureClause captureClause,
+                  mlir::omp::DeclareTargetDeviceType deviceType) {
+  auto declareTargetOp = llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(op);
+  if (!declareTargetOp)
+    fir::emitFatalError(
+        converter.getCurrentLocation(),
+        "Attempt to apply declare target on unsupported operation");
+
+  // The function or global already has a declare target applied to it, very
+  // likely through implicit capture (usage in another declare target
+  // function/subroutine). It should be marked as any if it has been assigned
+  // both host and nohost, else we skip, as there is no change
+  if (declareTargetOp.isDeclareTarget()) {
+    if (declareTargetOp.getDeclareTargetDeviceType() != deviceType)
+      declareTargetOp.setDeclareTarget(mlir::omp::DeclareTargetDeviceType::any,
+                                       captureClause);
+    return;
+  }
+
+  declareTargetOp.setDeclareTarget(deviceType, captureClause);
+}
+
 static void genOMP(Fortran::lower::AbstractConverter &converter,
                    Fortran::lower::SymMap &symTable,
                    Fortran::semantics::SemanticsContext &semanticsContext,
@@ -3513,42 +3570,16 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
   for (const DeclareTargetCapturePair &symClause : symbolAndClause) {
     mlir::Operation *op = mod.lookupSymbol(
         converter.mangleName(std::get<Fortran::semantics::Symbol>(symClause)));
-    // There's several cases this can currently be triggered and it could be
-    // one of the following:
-    // 1) Invalid argument passed to a declare target that currently isn't
-    // captured by a frontend semantic check
-    // 2) The symbol of a valid argument is not correctly updated by one of
-    // the prior passes, resulting in missing symbol information
-    // 3) It's a variable internal to a module or program, that is legal by
-    // Fortran OpenMP standards, but is currently unhandled as they do not
-    // appear in the symbol table as they are represented as allocas
+
+    // Some symbols are deferred until later in the module; these are handled
+    // upon finalization of the module for OpenMP inside of Bridge, so we simply
+    // skip them for now.
     if (!op)
-      TODO(converter.getCurrentLocation(),
-           "Missing symbol, possible case of currently unsupported use of "
-           "a program local variable in declare target or erroneous symbol "
-           "information ");
-
-    auto declareTargetOp =
-        llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(op);
-    if (!declareTargetOp)
-      fir::emitFatalError(
-          converter.getCurrentLocation(),
-          "Attempt to apply declare target on unsupported operation");
-
-    // The function or global already has a declare target applied to it, very
-    // likely through implicit capture (usage in another declare target
-    // function/subroutine). It should be marked as any if it has been assigned
-    // both host and nohost, else we skip, as there is no change
-    if (declareTargetOp.isDeclareTarget()) {
-      if (declareTargetOp.getDeclareTargetDeviceType() != deviceType)
-        declareTargetOp.setDeclareTarget(
-            mlir::omp::DeclareTargetDeviceType::any,
-            std::get<mlir::omp::DeclareTargetCaptureClause>(symClause));
       continue;
-    }
 
-    declareTargetOp.setDeclareTarget(
-        deviceType, std::get<mlir::omp::DeclareTargetCaptureClause>(symClause));
+    markDeclareTarget(
+        op, converter,
+        std::get<mlir::omp::DeclareTargetCaptureClause>(symClause), deviceType);
   }
 }
 
@@ -4015,6 +4046,25 @@ bool Fortran::lower::isOpenMPTargetConstruct(
   return llvm::omp::allTargetSet.test(dir);
 }
 
+void Fortran::lower::gatherDeferredDeclareTargets(
+    Fortran::lower::AbstractConverter &converter,
+    Fortran::lower::pft::Evaluation &eval,
+    const Fortran::parser::OpenMPDeclarativeConstruct &ompDecl,
+    llvm::SmallVectorImpl<
+        std::tuple<uint32_t /*mlir::omp::DeclareTargetCaptureClause*/,
+                   uint32_t, /*mlir::omp::DeclareTargetDeviceType*/
+                   Fortran::semantics::Symbol>> &deferredDeclareTarget) {
+  std::visit(
+      Fortran::common::visitors{
+          [&](const Fortran::parser::OpenMPDeclareTargetConstruct &ompReq) {
+            insertDeferredDeclareTargets(converter, eval, ompReq,
+                                         deferredDeclareTarget);
+          },
+          [&](const auto &) {},
+      },
+      ompDecl.u);
+}
+
 bool Fortran::lower::isOpenMPDeviceDeclareTarget(
     Fortran::lower::AbstractConverter &converter,
     Fortran::lower::pft::Evaluation &eval,
@@ -4032,6 +4082,50 @@ bool Fortran::lower::isOpenMPDeviceDeclareTarget(
       ompDecl.u);
 }
 
+// In certain cases such as subroutine or function interfaces which declare
+// but do not define or directly call the subroutine or function in the same
+// module, their lowering is delayed until after the declare target construct
+// itself is processed, so their symbol is not within the table.
+//
+// This function will also return true if we encounter any device declare
+// target cases, to satisfy checking if we require the requires attributes
+// on the module.
+bool Fortran::lower::markDelayedDeclareTargetFunctions(
+    mlir::Operation *mod,
+    llvm::SmallVectorImpl<
+        std::tuple<uint32_t /*mlir::omp::DeclareTargetCaptureClause*/,
+                   uint32_t, /*mlir::omp::DeclareTargetDeviceType*/
+                   Fortran::semantics::Symbol>> &deferredDeclareTargets,
+    AbstractConverter &converter) {
+  bool deviceCodeFound = false;
+  if (auto modOp = llvm::dyn_cast<mlir::ModuleOp>(mod)) {
+    for (const auto &sym : deferredDeclareTargets) {
+      mlir::Operation *op =
+          modOp.lookupSymbol(converter.mangleName(std::get<2>(sym)));
+
+      // Due to interfaces being optionally emitted on usage in a module,
+      // not finding an operation at this point cannot be a hard error, we
+      // simply ignore it for now.
+      if (!op)
+        continue;
+
+      auto devType =
+          static_cast<mlir::omp::DeclareTargetDeviceType>(std::get<1>(sym));
+      if (!deviceCodeFound &&
+          devType != mlir::omp::DeclareTargetDeviceType::host) {
+        deviceCodeFound = true;
+      }
+
+      markDeclareTarget(
+          op, converter,
+          static_cast<mlir::omp::DeclareTargetCaptureClause>(std::get<0>(sym)),
+          devType);
+    }
+  }
+
+  return deviceCodeFound;
+}
+
 void Fortran::lower::genOpenMPRequires(
     mlir::Operation *mod, const Fortran::semantics::Symbol *symbol) {
   using MlirRequires = mlir::omp::ClauseRequires;
diff --git a/flang/test/Lower/OpenMP/declare-target-deferred-marking.f90 b/flang/test/Lower/OpenMP/declare-target-deferred-marking.f90
new file mode 100644
index 000000000000000..1998c3da23af5fe
--- /dev/null
+++ b/flang/test/Lower/OpenMP/declare-target-deferred-marking.f90
@@ -0,0 +1,60 @@
+!RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - | FileCheck %s --check-prefixes ALL,HOST
+!RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-is-device %s -o - | FileCheck %s --check-prefixes ALL
+
+program main
+    use, intrinsic ::  iso_c_binding
+    implicit none
+    interface
+    subroutine any_interface()  bind(c,name="any_interface")
+        use, intrinsic :: iso_c_binding
+        implicit none
+    !$omp declare target enter(any_interface) device_type(any)
+    end subroutine any_interface
+
+    subroutine host_interface()  bind(c,name="host_interface")
+      use, intrinsic :: iso_c_binding
+      implicit none
+   !$omp declare target enter(host_interface) device_type(host)
+    end subroutine host_interface
+
+    subroutine device_interface()  bind(c,name="device_interface")
+        use, intrinsic :: iso_c_binding
+        implicit none
+    !$omp declare target enter(device_interface) device_type(nohost)
+    end subroutine device_interface
+
+    subroutine called_from_target_interface(f1, f2) bind(c,name="called_from_target_interface")
+        use, intrinsic :: iso_c_binding
+        implicit none
+        type(c_funptr),value :: f1
+        type(c_funptr),value :: f2
+    end subroutine called_from_target_interface
+
+    subroutine called_from_host_interface(f1) bind(c,name="called_from_host_interface")
+      use, intrinsic :: iso_c_binding
+      implicit none
+      type(c_funptr),value :: f1
+    end subroutine called_from_host_interface
+
+    subroutine unused_unemitted_interface()  bind(c,name="unused_unemitted_interface")
+      use, intrinsic :: iso_c_binding
+      implicit none
+    !$omp declare target enter(unused_unemitted_interface) device_type(nohost)
+    end subroutine unused_unemitted_interface
+
+    end interface
+
+    CALL called_from_host_interface(c_funloc(host_interface))
+!$omp target
+    CALL called_from_target_interface(c_funloc(any_interface), c_funloc(device_interface))
+!$omp end target
+ end program main
+
+!HOST-LABEL: func.func {{.*}} @host_interface()
+!HOST-SAME: {{.*}}, omp.declare_target = #omp.declaretarget<device_type = (host), capture_clause = (enter)>{{.*}}
+!ALL-LABEL: func.func {{.*}} @called_from_target_interface(!fir.ref<i64>, !fir.ref<i64>)
+!ALL-SAME: {{.*}}, omp.declare_target = #omp.declaretarget<device_type = (nohost), capture_clause = (to)>{{.*}}
+!ALL-LABEL: func.func {{.*}} @any_interface()
+!ALL-SAME: {{.*}}, omp.declare_target = #omp.declaretarget<device_type = (any), capture_clause = (enter)>{{.*}}
+!ALL-LABEL: func.func {{.*}} @device_interface()
+!ALL-SAME: {{.*}}, omp.declare_target = #omp.declaretarget<device_type = (nohost), capture_clause = (enter)>{{.*}}

``````````

</details>


https://github.com/llvm/llvm-project/pull/78502

