[flang-commits] [flang] [Flang][Semantics] Allow declare target to be used on functions external to the declare targets scope (PR #122546)

via flang-commits flang-commits at lists.llvm.org
Fri Jan 10 15:06:33 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-flang-semantics

Author: None (agozillon)

<details>
<summary>Changes</summary>

Whilst a little contrived, OpenMP allows you to use declare target in the scope of one function to mark another function as declare target; currently this leads to a semantic error.

This appears to happen because, when we process a declare target directive in the scope of one function that refers to another function, we search only the current scope for a prior definition rather than also searching the enclosing scopes. When IMPLICIT NONE is not specified, this leads us to implicitly declare a new variable with that name, use it, and then error out; when IMPLICIT NONE is specified, we error out earlier. This patch addresses this behaviour by first looking for a matching function in the enclosing scopes and using it, falling back to the prior behaviour otherwise.

---
Full diff: https://github.com/llvm/llvm-project/pull/122546.diff


2 Files Affected:

- (modified) flang/lib/Semantics/resolve-names.cpp (+86-2) 
- (added) flang/test/Semantics/OpenMP/declare-target08.f90 (+41) 


``````````diff
diff --git a/flang/lib/Semantics/resolve-names.cpp b/flang/lib/Semantics/resolve-names.cpp
index 51e7c5960dc2ef..44bd540a4228b0 100644
--- a/flang/lib/Semantics/resolve-names.cpp
+++ b/flang/lib/Semantics/resolve-names.cpp
@@ -736,6 +736,8 @@ class ScopeHandler : public ImplicitRulesVisitor {
     std::vector<const std::list<parser::EquivalenceObject> *> equivalenceSets;
     // Names of all common block objects in the scope
     std::set<SourceName> commonBlockObjects;
+    // Names of all objects listed in declare target directives in the scope
+    std::set<SourceName> declareTargetNames;
     // Info about SAVE statements and attributes in current scope
     struct {
       std::optional<SourceName> saveAll; // "SAVE" without entity list
@@ -1223,6 +1225,7 @@ class DeclarationVisitor : public ArraySpecVisitor,
   const parser::Name *FindComponent(const parser::Name *, const parser::Name &);
   void Initialization(const parser::Name &, const parser::Initialization &,
       bool inComponentDecl);
+  bool FindAndMarkDeclareTargetSymbol(const parser::Name &);
   bool PassesLocalityChecks(
       const parser::Name &name, Symbol &symbol, Symbol::Flag flag);
   bool CheckForHostAssociatedImplicit(const parser::Name &);
@@ -1524,7 +1527,44 @@ class OmpVisitor : public virtual DeclarationVisitor {
     return true;
   }
   void Post(const parser::OpenMPThreadprivate &) { SkipImplicitTyping(false); }
-  bool Pre(const parser::OpenMPDeclareTargetConstruct &) {
+  bool Pre(const parser::OpenMPDeclareTargetConstruct &x) {
+    // Record every object named by this directive, either directly or through
+    // TO/LINK/ENTER clauses, for FindAndMarkDeclareTargetSymbol to consult.
+    const auto &spec{std::get<parser::OmpDeclareTargetSpecifier>(x.t)};
+    auto recordNames{[this](const parser::OmpObjectList &objectList) {
+      for (const auto &ompObject : objectList.v) {
+        common::visit(
+            common::visitors{
+                [&](const parser::Designator &designator) {
+                  if (const auto *name{
+                          semantics::getDesignatorNameIfDataRef(designator)}) {
+                    specPartState_.declareTargetNames.insert(name->source);
+                  }
+                },
+                [&](const parser::Name &name) {
+                  specPartState_.declareTargetNames.insert(name.source);
+                }},
+            ompObject.u);
+      }
+    }};
+
+    if (const auto *objectList{parser::Unwrap<parser::OmpObjectList>(spec.u)}) {
+      recordNames(*objectList);
+    } else if (const auto *clauseList{
+                   parser::Unwrap<parser::OmpClauseList>(spec.u)}) {
+      for (const auto &clause : clauseList->v) {
+        if (const auto *toClause{
+                std::get_if<parser::OmpClause::To>(&clause.u)}) {
+          recordNames(std::get<parser::OmpObjectList>(toClause->v.t));
+        } else if (const auto *linkClause{
+                       std::get_if<parser::OmpClause::Link>(&clause.u)}) {
+          recordNames(linkClause->v);
+        } else if (const auto *enterClause{
+                       std::get_if<parser::OmpClause::Enter>(&clause.u)}) {
+          recordNames(enterClause->v);
+        }
+      }
+    }
     SkipImplicitTyping(true);
     return true;
   }
@@ -8114,7 +8154,12 @@ const parser::Name *DeclarationVisitor::ResolveDataRef(
 // If implicit types are allowed, ensure name is in the symbol table.
 // Otherwise, report an error if it hasn't been declared.
 const parser::Name *DeclarationVisitor::ResolveName(const parser::Name &name) {
-  FindSymbol(name);
+  if (!FindSymbol(name)) {
+    if (FindAndMarkDeclareTargetSymbol(name)) {
+      return &name;
+    }
+  }
+
   if (CheckForHostAssociatedImplicit(name)) {
     NotePossibleBadForwardRef(name);
     return &name;
@@ -8157,6 +8202,7 @@ const parser::Name *DeclarationVisitor::ResolveName(const parser::Name &name) {
         "Implied DO index '%s' uses itself in its own bounds expressions"_err_en_US,
         name.source);
   }
+
   MakeSymbol(InclusiveScope(), name.source, Attrs{});
   auto *symbol{FindSymbol(name)};
   if (!symbol) {
@@ -8164,6 +8210,7 @@ const parser::Name *DeclarationVisitor::ResolveName(const parser::Name &name) {
         "'%s' from host scoping unit is not accessible due to IMPORT"_err_en_US);
     return nullptr;
   }
+
   ConvertToObjectEntity(*symbol);
   ApplyImplicitRules(*symbol);
   NotePossibleBadForwardRef(name);
@@ -8298,6 +8345,43 @@ const parser::Name *DeclarationVisitor::FindComponent(
   return nullptr;
 }
 
+// Called from ResolveName() when `name` has no symbol in the current scope.
+// If `name` was recorded in specPartState_.declareTargetNames, search the
+// enclosing scopes for a subroutine or function with that name. If one is
+// found, add a HostAssocDetails symbol for it to the current scope, copy its
+// Subroutine/Function flag, bind `name` to that symbol and return true.
+// Otherwise return false so ResolveName() continues as before.
+bool DeclarationVisitor::FindAndMarkDeclareTargetSymbol(
+    const parser::Name &name) {
+  const auto &targetNames{specPartState_.declareTargetNames};
+  if (targetNames.find(name.source) == targetNames.end() ||
+      currScope().IsTopLevel()) {
+    return false;
+  }
+  // Skip the current scope (ResolveName() has already looked the name up
+  // there) and walk outward. parent() asserts on a top-level scope rather
+  // than returning null, so stop once a top-level scope has been searched.
+  for (auto *scope{&currScope().parent()};; scope = &scope->parent()) {
+    Symbol *symbol{scope->FindSymbol(name.source)};
+    if (symbol &&
+        (symbol->test(Symbol::Flag::Subroutine) ||
+            symbol->test(Symbol::Flag::Function))) {
+      const auto pair{currScope().try_emplace(
+          symbol->name(), Attrs{}, HostAssocDetails{*symbol})};
+      name.symbol = &*pair.first->second;
+      if (symbol->test(Symbol::Flag::Subroutine)) {
+        name.symbol->set(Symbol::Flag::Subroutine);
+      } else {
+        name.symbol->set(Symbol::Flag::Function);
+      }
+      return true;
+    }
+    if (scope->IsTopLevel()) {
+      return false;
+    }
+  }
+}
+
 void DeclarationVisitor::Initialization(const parser::Name &name,
     const parser::Initialization &init, bool inComponentDecl) {
   // Traversal of the initializer was deferred to here so that the
diff --git a/flang/test/Semantics/OpenMP/declare-target08.f90 b/flang/test/Semantics/OpenMP/declare-target08.f90
new file mode 100644
index 00000000000000..1438d79d373482
--- /dev/null
+++ b/flang/test/Semantics/OpenMP/declare-target08.f90
@@ -0,0 +1,41 @@
+! RUN: %flang_fc1 -fopenmp -fdebug-dump-symbols %s | FileCheck %s
+! DECLARE TARGET may name a procedure defined outside the current scope.
+
+subroutine bar(i, a)
+    !$omp declare target
+    real :: a
+    integer :: i
+    a = a - i
+end subroutine
+
+function baz(a)
+    !$omp declare target
+    real, intent(in) :: a
+    baz = a
+end function baz
+
+program main
+    real a
+    !CHECK: bar (Subroutine, OmpDeclareTarget): HostAssoc
+    !CHECK: baz (Function, OmpDeclareTarget): HostAssoc
+    !$omp declare target(bar)
+    !$omp declare target(baz)
+
+    a = baz(a)
+    call bar(2,a)
+    call foo(a)
+end
+
+subroutine foo(a)
+    real a
+    integer i
+    !CHECK: bar (Subroutine, OmpDeclareTarget): HostAssoc
+    !CHECK: baz (Function, OmpDeclareTarget): HostAssoc
+    !$omp declare target(bar)
+    !$omp declare target(baz)
+    !$omp target
+        a = baz(a)
+        call bar(i,a)
+    !$omp end target
+    return
+end

``````````

</details>


https://github.com/llvm/llvm-project/pull/122546


More information about the flang-commits mailing list