[flang-commits] [flang] [flang][OpenMP] Add structure checks for DECLARE VARIANT (PR #198799)

Abid Qadeer via flang-commits flang-commits at lists.llvm.org
Fri May 29 08:27:48 PDT 2026


https://github.com/abidh updated https://github.com/llvm/llvm-project/pull/198799

>From 223c1ea6399bdf6759cdc2d34f60ee8968adb877 Mon Sep 17 00:00:00 2001
From: Abid Qadeer <haqadeer at amd.com>
Date: Tue, 19 May 2026 18:29:52 +0100
Subject: [PATCH 1/2] [flang][OpenMP] Add structure checks for DECLARE VARIANT
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Add declare-variant structure checking in a dedicated source file:

- Validate the [base:]variant argument and treat a single-name form as the
  variant, with the base procedure implied by the enclosing subprogram
  (OpenMP 5.1 §2.3.5).
- Require a MATCH clause and reject duplicate (base, variant) pairs.
- Reject disallowed clauses and duplicate MATCH clauses via the structure checker.
- Apply shared context-selector validation from metadirective to MATCH.
- Require the USER condition in MATCH to be a constant expression, for
  DECLARE VARIANT only (support for dynamic selectors is deferred).

Refactor metadirective support:

- Extract CheckContextSelectorSpecification for reuse.
- Reject SCORE on trait sets that do not allow it (also affects metadirective).

Co-authored-by: Cursor <cursoragent at cursor.com>
---
 flang/lib/Semantics/CMakeLists.txt            |   1 +
 .../Semantics/check-omp-declare-variant.cpp   | 183 ++++++++++++++++++
 .../lib/Semantics/check-omp-metadirective.cpp |  16 +-
 flang/lib/Semantics/check-omp-structure.cpp   |  54 +-----
 flang/lib/Semantics/check-omp-structure.h     |   6 +
 .../OpenMP/declare-variant-match.f90          | 118 +++++++++++
 .../test/Semantics/OpenMP/declare-variant.f90 |  30 +++
 7 files changed, 351 insertions(+), 57 deletions(-)
 create mode 100644 flang/lib/Semantics/check-omp-declare-variant.cpp
 create mode 100644 flang/test/Semantics/OpenMP/declare-variant-match.f90

diff --git a/flang/lib/Semantics/CMakeLists.txt b/flang/lib/Semantics/CMakeLists.txt
index 44e6dfb4dd09f..42396a5b3b639 100644
--- a/flang/lib/Semantics/CMakeLists.txt
+++ b/flang/lib/Semantics/CMakeLists.txt
@@ -23,6 +23,7 @@ add_flang_library(FortranSemantics
   check-omp-atomic.cpp
   check-omp-loop.cpp
   check-omp-metadirective.cpp
+  check-omp-declare-variant.cpp
   check-omp-structure.cpp
   check-purity.cpp
   check-return.cpp
diff --git a/flang/lib/Semantics/check-omp-declare-variant.cpp b/flang/lib/Semantics/check-omp-declare-variant.cpp
new file mode 100644
index 0000000000000..f986dfc710aae
--- /dev/null
+++ b/flang/lib/Semantics/check-omp-declare-variant.cpp
@@ -0,0 +1,183 @@
+//===-- lib/Semantics/check-omp-declare-variant.cpp -----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Structure checks for DECLARE VARIANT.
+//
+//===----------------------------------------------------------------------===//
+
+#include "check-omp-structure.h"
+
+#include "flang/Common/idioms.h"
+#include "flang/Common/visit.h"
+#include "flang/Evaluate/check-expression.h"
+#include "flang/Parser/parse-tree.h"
+#include "flang/Semantics/openmp-utils.h"
+#include "flang/Semantics/symbol.h"
+#include "flang/Semantics/tools.h"
+
+#include "llvm/Frontend/OpenMP/OMP.h"
+
+namespace Fortran::semantics {
+
+using namespace Fortran::semantics::omp;
+
+static const parser::traits::OmpContextSelectorSpecification *
+getMatchClauseContextSelector(const parser::OmpDirectiveSpecification &spec) {
+  for (const parser::OmpClause &clause : spec.Clauses().v) {
+    if (clause.Id() == llvm::omp::Clause::OMPC_match)
+      return &std::get<parser::OmpClause::Match>(clause.u).v.v;
+  }
+  return nullptr;
+}
+
+void OmpStructureChecker::CheckDeclareVariantUserConditions(
+    const parser::OmpContextSelector &ctx) {
+  using SetName = parser::OmpTraitSetSelectorName;
+  using TraitName = parser::OmpTraitSelectorName;
+
+  for (const parser::OmpTraitSetSelector &traitSet : ctx.v) {
+    if (std::get<SetName>(traitSet.t).v != SetName::Value::User) {
+      continue;
+    }
+    for (const parser::OmpTraitSelector &trait :
+        std::get<std::list<parser::OmpTraitSelector>>(traitSet.t)) {
+      const auto &traitName{std::get<TraitName>(trait.t)};
+      if (!std::holds_alternative<TraitName::Value>(traitName.u) ||
+          std::get<TraitName::Value>(traitName.u) !=
+              TraitName::Value::Condition) {
+        continue;
+      }
+      const auto &maybeProps{
+          std::get<std::optional<parser::OmpTraitSelector::Properties>>(
+              trait.t)};
+      if (!maybeProps) {
+        continue;
+      }
+      const auto &properties{
+          std::get<std::list<parser::OmpTraitProperty>>(maybeProps->t)};
+      if (properties.size() != 1) {
+        continue;
+      }
+      const parser::OmpTraitProperty &property{properties.front()};
+      const parser::ScalarExpr &scalarExpr{
+          std::get<parser::ScalarExpr>(property.u)};
+      auto maybeType{GetDynamicType(scalarExpr.thing.value())};
+      if (!maybeType || maybeType->category() != TypeCategory::Logical) {
+        continue;
+      }
+      if (const auto *expr{GetExpr(scalarExpr)}) {
+        if (!IsConstantExpr(*expr, &context_.foldingContext())) {
+          context_.Say(property.source,
+              "USER condition in the MATCH clause must be a constant expression"_err_en_US);
+        }
+      }
+    }
+  }
+}
+
+void OmpStructureChecker::CheckOmpDeclareVariantDirective(
+    const parser::OmpDeclareVariantDirective &x) {
+  const parser::OmpDirectiveSpecification &spec{x.v};
+  const parser::OmpArgumentList &args{spec.Arguments()};
+
+  if (args.v.size() != 1) {
+    context_.Say(args.source,
+        "DECLARE_VARIANT directive should have a single argument"_err_en_US);
+    return;
+  }
+
+  auto InvalidArgument{[&](parser::CharBlock source) {
+    context_.Say(source,
+        "The argument to the DECLARE_VARIANT directive should be [base-name:]variant-name"_err_en_US);
+  }};
+
+  auto CheckProcedureSymbol{[&](const Symbol *sym, parser::CharBlock source) {
+    if (sym) {
+      if (!IsProcedure(*sym) && !IsFunction(*sym)) {
+        auto &msg{context_.Say(source,
+            "The name '%s' should refer to a procedure"_err_en_US,
+            sym->name())};
+        if (sym->test(Symbol::Flag::Implicit)) {
+          msg.Attach(source, "The name '%s' has been implicitly declared"_en_US,
+              sym->name());
+        }
+      }
+    } else {
+      InvalidArgument(source);
+    }
+  }};
+
+  const Symbol *base{nullptr};
+  const Symbol *variant{nullptr};
+  const parser::OmpArgument &arg{args.v.front()};
+  common::visit( //
+      common::visitors{
+          [&](const parser::OmpBaseVariantNames &y) {
+            base = GetObjectSymbol(std::get<0>(y.t));
+            variant = GetObjectSymbol(std::get<1>(y.t));
+            CheckProcedureSymbol(base, arg.source);
+            CheckProcedureSymbol(variant, arg.source);
+          },
+          [&](const parser::OmpLocator &y) {
+            variant = GetArgumentSymbol(arg);
+            CheckProcedureSymbol(variant, arg.source);
+            // OpenMP 5.1 [2.3.5, declare variant directive, Restrictions]:
+            // "If base-proc-name is omitted then the declare variant directive
+            // must appear in an interface block or the specification part of a
+            // procedure." The same section requires the directive to appear in
+            // the specification part of the subprogram or interface body to
+            // which it applies. Infer the base procedure from that program
+            // unit.
+            const Scope &containingScope{context_.FindScope(x.source)};
+            if (const Symbol *host{
+                    GetProgramUnitContaining(containingScope).symbol()}) {
+              base = host;
+            }
+          },
+          [&](auto &&y) { InvalidArgument(arg.source); },
+      },
+      arg.u);
+
+  if (base && variant) {
+    base = &base->GetUltimate();
+    variant = &variant->GetUltimate();
+    if (base == variant) {
+      context_.Say(arg.source,
+          "The variant procedure must differ from the base procedure"_err_en_US);
+    } else if (!declareVariantPairs_.emplace(base, variant).second) {
+      context_.Say(arg.source,
+          "Variant '%s' was already specified for '%s' in another DECLARE VARIANT directive"_err_en_US,
+          variant->name(), base->name());
+    }
+  }
+
+  const parser::traits::OmpContextSelectorSpecification *matchSelector{
+      getMatchClauseContextSelector(spec)};
+  if (!matchSelector) {
+    context_.Say(x.source,
+        "DECLARE_VARIANT directive requires a MATCH clause"_err_en_US);
+    return;
+  }
+
+  EnterDirectiveNest(ContextSelectorNest);
+  CheckContextSelectorSpecification(*matchSelector);
+  CheckDeclareVariantUserConditions(*matchSelector);
+  ExitDirectiveNest(ContextSelectorNest);
+}
+
+void OmpStructureChecker::Enter(const parser::OmpDeclareVariantDirective &x) {
+  const parser::OmpDirectiveName &dirName{x.v.DirName()};
+  PushContextAndClauseSets(dirName.source, dirName.v);
+  CheckOmpDeclareVariantDirective(x);
+}
+
+void OmpStructureChecker::Leave(const parser::OmpDeclareVariantDirective &) {
+  dirContext_.pop_back();
+}
+
+} // namespace Fortran::semantics
diff --git a/flang/lib/Semantics/check-omp-metadirective.cpp b/flang/lib/Semantics/check-omp-metadirective.cpp
index c8c19e4ac7dac..d308c2ee7cac5 100644
--- a/flang/lib/Semantics/check-omp-metadirective.cpp
+++ b/flang/lib/Semantics/check-omp-metadirective.cpp
@@ -43,9 +43,8 @@ void OmpStructureChecker::Enter(const parser::OmpClause::When &x) {
       x.v, llvm::omp::OMPC_when, GetContext().clauseSource, context_);
 }
 
-void OmpStructureChecker::Enter(const parser::OmpContextSelector &ctx) {
-  EnterDirectiveNest(ContextSelectorNest);
-
+void OmpStructureChecker::CheckContextSelectorSpecification(
+    const parser::OmpContextSelector &ctx) {
   using SetName = parser::OmpTraitSetSelectorName;
   std::map<SetName::Value, const SetName *> visited;
 
@@ -66,6 +65,11 @@ void OmpStructureChecker::Enter(const parser::OmpContextSelector &ctx) {
   }
 }
 
+void OmpStructureChecker::Enter(const parser::OmpContextSelector &ctx) {
+  EnterDirectiveNest(ContextSelectorNest);
+  CheckContextSelectorSpecification(ctx);
+}
+
 void OmpStructureChecker::Leave(const parser::OmpContextSelector &) {
   ExitDirectiveNest(ContextSelectorNest);
 }
@@ -240,7 +244,11 @@ void OmpStructureChecker::CheckTraitSetSelector(
       if (maybeProps) {
         auto &[maybeScore, _]{maybeProps->t};
         if (maybeScore) {
-          CheckTraitScore(*maybeScore);
+          if (!config.allowsScore)
+            context_.Say(maybeScore->source,
+                "SCORE is not allowed for %s trait set"_err_en_US, usn);
+          else
+            CheckTraitScore(*maybeScore);
         }
       }
 
diff --git a/flang/lib/Semantics/check-omp-structure.cpp b/flang/lib/Semantics/check-omp-structure.cpp
index ff41f49d88b32..72e4a5c591306 100644
--- a/flang/lib/Semantics/check-omp-structure.cpp
+++ b/flang/lib/Semantics/check-omp-structure.cpp
@@ -71,6 +71,7 @@ OmpStructureChecker::OmpStructureChecker(SemanticsContext &context)
 
 void OmpStructureChecker::Enter(const parser::ProgramUnit &) { //
   ClearLabels();
+  declareVariantPairs_.clear();
 }
 
 void OmpStructureChecker::Enter(const parser::MainProgram &x) {
@@ -1714,59 +1715,6 @@ void OmpStructureChecker::Leave(const parser::OmpDeclareSimdDirective &) {
   dirContext_.pop_back();
 }
 
-void OmpStructureChecker::Enter(const parser::OmpDeclareVariantDirective &x) {
-  const parser::OmpDirectiveName &dirName{x.v.DirName()};
-  PushContextAndClauseSets(dirName.source, dirName.v);
-
-  const parser::OmpArgumentList &args{x.v.Arguments()};
-  if (args.v.size() != 1) {
-    context_.Say(args.source,
-        "DECLARE_VARIANT directive should have a single argument"_err_en_US);
-    return;
-  }
-
-  auto InvalidArgument{[&](parser::CharBlock source) {
-    context_.Say(source,
-        "The argument to the DECLARE_VARIANT directive should be [base-name:]variant-name"_err_en_US);
-  }};
-
-  auto CheckSymbol{[&](const Symbol *sym, parser::CharBlock source) {
-    if (sym) {
-      if (!IsProcedure(*sym) && !IsFunction(*sym)) {
-        auto &msg{context_.Say(source,
-            "The name '%s' should refer to a procedure"_err_en_US,
-            sym->name())};
-        if (sym->test(Symbol::Flag::Implicit)) {
-          msg.Attach(source, "The name '%s' has been implicitly declared"_en_US,
-              sym->name());
-        }
-      }
-    } else {
-      InvalidArgument(source);
-    }
-  }};
-
-  const parser::OmpArgument &arg{args.v.front()};
-  common::visit( //
-      common::visitors{
-          [&](const parser::OmpBaseVariantNames &y) {
-            CheckSymbol(GetObjectSymbol(std::get<0>(y.t), /*ultimate=*/true),
-                arg.source);
-            CheckSymbol(GetObjectSymbol(std::get<1>(y.t), /*ultimate=*/true),
-                arg.source);
-          },
-          [&](const parser::OmpLocator &y) {
-            CheckSymbol(GetArgumentSymbol(arg, /*ultimate=*/true), arg.source);
-          },
-          [&](auto &&y) { InvalidArgument(arg.source); },
-      },
-      arg.u);
-}
-
-void OmpStructureChecker::Leave(const parser::OmpDeclareVariantDirective &) {
-  dirContext_.pop_back();
-}
-
 void OmpStructureChecker::CheckInitOnDepobj(
     const parser::OpenMPDepobjConstruct &depobj,
     const parser::OmpClause &initClause) {
diff --git a/flang/lib/Semantics/check-omp-structure.h b/flang/lib/Semantics/check-omp-structure.h
index b3f58335027de..3aa1af192884e 100644
--- a/flang/lib/Semantics/check-omp-structure.h
+++ b/flang/lib/Semantics/check-omp-structure.h
@@ -127,6 +127,9 @@ class OmpStructureChecker : public OmpStructureCheckerBase {
 
   void Enter(const parser::OmpDeclareVariantDirective &);
   void Leave(const parser::OmpDeclareVariantDirective &);
+  void CheckOmpDeclareVariantDirective(
+      const parser::OmpDeclareVariantDirective &);
+  void CheckDeclareVariantUserConditions(const parser::OmpContextSelector &);
   void Enter(const parser::OmpDeclareSimdDirective &);
   void Leave(const parser::OmpDeclareSimdDirective &);
   void Enter(const parser::OmpAllocateDirective &);
@@ -265,6 +268,7 @@ class OmpStructureChecker : public OmpStructureCheckerBase {
       const parser::OmpTraitProperty &);
 
   void CheckTraitSelectorList(const std::list<parser::OmpTraitSelector> &);
+  void CheckContextSelectorSpecification(const parser::OmpContextSelector &);
   void CheckTraitSetSelector(const parser::OmpTraitSetSelector &);
   void CheckTraitScore(const parser::OmpTraitScore &);
   bool VerifyTraitPropertyLists(
@@ -417,6 +421,8 @@ class OmpStructureChecker : public OmpStructureCheckerBase {
   };
   int directiveNest_[LastType + 1] = {0};
 
+  std::set<std::pair<const Symbol *, const Symbol *>> declareVariantPairs_;
+
   int allocateDirectiveLevel_{0};
   parser::CharBlock visitedAtomicSource_;
 
diff --git a/flang/test/Semantics/OpenMP/declare-variant-match.f90 b/flang/test/Semantics/OpenMP/declare-variant-match.f90
new file mode 100644
index 0000000000000..37f6543dd89da
--- /dev/null
+++ b/flang/test/Semantics/OpenMP/declare-variant-match.f90
@@ -0,0 +1,118 @@
+! RUN: %python %S/../test_errors.py %s %flang -fopenmp -fopenmp-version=52
+
+! MATCH clause checks for DECLARE VARIANT: required/duplicate clause and
+! context-selector validation (shared with METADIRECTIVE).
+
+subroutine f00
+  !$omp declare variant (sub:vsub) &
+  !$omp & match (implementation={vendor("this")}, &
+!ERROR: Repeated trait set name IMPLEMENTATION in a context specifier
+  !$omp &       implementation={requires(unified_shared_memory)})
+contains
+  subroutine vsub
+  end subroutine
+  subroutine sub
+  end subroutine
+end subroutine
+
+subroutine f01
+  !$omp declare variant (sub:vsub) &
+!ERROR: Repeated trait name ISA in a trait set
+  !$omp & match (device={isa("this"), isa("that")})
+contains
+  subroutine vsub
+  end subroutine
+  subroutine sub
+  end subroutine
+end subroutine
+
+subroutine f02
+  !$omp declare variant (sub:vsub) &
+!ERROR: SCORE expression must be a non-negative constant integer expression
+  !$omp & match (user={condition(score(-2): .true.)})
+contains
+  subroutine vsub
+  end subroutine
+  subroutine sub
+  end subroutine
+end subroutine
+
+subroutine f03(x)
+  integer :: x
+  !$omp declare variant (sub:vsub) &
+!ERROR: SCORE expression must be a non-negative constant integer expression
+  !$omp & match (user={condition(score(x): .true.)})
+contains
+  subroutine vsub
+  end subroutine
+  subroutine sub
+  end subroutine
+end subroutine
+
+subroutine f04
+  !$omp declare variant (sub:vsub) &
+!ERROR: Trait property should be a scalar expression
+!ERROR: More invalid properties are present
+  !$omp & match (target_device={device_num("device", "foo"(1))})
+contains
+  subroutine vsub
+  end subroutine
+  subroutine sub
+  end subroutine
+end subroutine
+
+subroutine f05(x)
+  integer :: x
+  !$omp declare variant (sub:vsub) &
+  !$omp & match (user={ &
+!ERROR: CONDITION trait requires a single LOGICAL expression
+  !$omp & condition(score(2): x)})
+contains
+  subroutine vsub
+  end subroutine
+  subroutine sub
+  end subroutine
+end subroutine
+
+subroutine f06(x)
+  integer :: x
+!ERROR: USER condition in the MATCH clause must be a constant expression
+  !$omp declare variant (sub:vsub) match (user={condition(x > 0)})
+contains
+  subroutine vsub
+  end subroutine
+  subroutine sub
+  end subroutine
+end subroutine
+
+subroutine f07
+  !$omp declare variant (sub:vsub) &
+!ERROR: SCORE is not allowed for DEVICE trait set
+  !$omp & match (device={kind(score(1): host)})
+contains
+  subroutine vsub
+  end subroutine
+  subroutine sub
+  end subroutine
+end subroutine
+
+subroutine f08
+!ERROR: DECLARE_VARIANT directive requires a MATCH clause
+  !$omp declare variant (sub:vsub)
+contains
+  subroutine vsub
+  end subroutine
+  subroutine sub
+  end subroutine
+end subroutine
+
+subroutine f09
+  !$omp declare variant (sub:vsub) match (construct={parallel}) &
+!ERROR: At most one MATCH clause can appear on the DECLARE VARIANT directive
+  !$omp & match (construct={teams})
+contains
+  subroutine vsub
+  end subroutine
+  subroutine sub
+  end subroutine
+end subroutine
diff --git a/flang/test/Semantics/OpenMP/declare-variant.f90 b/flang/test/Semantics/OpenMP/declare-variant.f90
index 6fc94a4fb837f..443f767f73244 100644
--- a/flang/test/Semantics/OpenMP/declare-variant.f90
+++ b/flang/test/Semantics/OpenMP/declare-variant.f90
@@ -12,3 +12,33 @@ subroutine vsub
   subroutine sub ()
   end subroutine
 end subroutine
+
+subroutine same_base_variant
+!ERROR: The variant procedure must differ from the base procedure
+  !$omp declare variant (sub:sub) match (construct={parallel})
+contains
+  subroutine sub
+  end subroutine
+end subroutine
+
+subroutine duplicate_variant
+  !$omp declare variant (sub:vsub) match (construct={parallel})
+!ERROR: Variant 'vsub' was already specified for 'sub' in another DECLARE VARIANT directive
+  !$omp declare variant (sub:vsub) match (construct={teams})
+contains
+  subroutine vsub
+  end subroutine
+  subroutine sub
+  end subroutine
+end subroutine
+
+subroutine invalid_clause
+!ERROR: PRIVATE clause is not allowed on the DECLARE VARIANT directive
+  !$omp declare variant (sub:vsub) match (construct={parallel}) private(x)
+contains
+  subroutine vsub
+  end subroutine
+  subroutine sub
+    integer :: x
+  end subroutine
+end subroutine

>From 425704fe8772142e51a51f9143f94f6c03247501 Mon Sep 17 00:00:00 2001
From: Abid Qadeer <haqadeer at amd.com>
Date: Fri, 29 May 2026 16:25:15 +0100
Subject: [PATCH 2/2] Handle review comments.

1. Remove lib/Semantics/check-omp-declare-variant.cpp and move the checks into check-omp-metadirective.cpp.

2. Move some OmpStructureChecker member functions from public to private.

3. Update the error message for a dynamic (non-constant) USER condition.
---
 flang/lib/Semantics/CMakeLists.txt            |   1 -
 .../Semantics/check-omp-declare-variant.cpp   | 183 ------------------
 .../lib/Semantics/check-omp-metadirective.cpp | 151 ++++++++++++++-
 flang/lib/Semantics/check-omp-structure.h     |   6 +-
 .../OpenMP/declare-variant-match.f90          |   2 +-
 5 files changed, 154 insertions(+), 189 deletions(-)
 delete mode 100644 flang/lib/Semantics/check-omp-declare-variant.cpp

diff --git a/flang/lib/Semantics/CMakeLists.txt b/flang/lib/Semantics/CMakeLists.txt
index 42396a5b3b639..44e6dfb4dd09f 100644
--- a/flang/lib/Semantics/CMakeLists.txt
+++ b/flang/lib/Semantics/CMakeLists.txt
@@ -23,7 +23,6 @@ add_flang_library(FortranSemantics
   check-omp-atomic.cpp
   check-omp-loop.cpp
   check-omp-metadirective.cpp
-  check-omp-declare-variant.cpp
   check-omp-structure.cpp
   check-purity.cpp
   check-return.cpp
diff --git a/flang/lib/Semantics/check-omp-declare-variant.cpp b/flang/lib/Semantics/check-omp-declare-variant.cpp
deleted file mode 100644
index f986dfc710aae..0000000000000
--- a/flang/lib/Semantics/check-omp-declare-variant.cpp
+++ /dev/null
@@ -1,183 +0,0 @@
-//===-- lib/Semantics/check-omp-declare-variant.cpp -----------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// Structure checks for DECLARE VARIANT.
-//
-//===----------------------------------------------------------------------===//
-
-#include "check-omp-structure.h"
-
-#include "flang/Common/idioms.h"
-#include "flang/Common/visit.h"
-#include "flang/Evaluate/check-expression.h"
-#include "flang/Parser/parse-tree.h"
-#include "flang/Semantics/openmp-utils.h"
-#include "flang/Semantics/symbol.h"
-#include "flang/Semantics/tools.h"
-
-#include "llvm/Frontend/OpenMP/OMP.h"
-
-namespace Fortran::semantics {
-
-using namespace Fortran::semantics::omp;
-
-static const parser::traits::OmpContextSelectorSpecification *
-getMatchClauseContextSelector(const parser::OmpDirectiveSpecification &spec) {
-  for (const parser::OmpClause &clause : spec.Clauses().v) {
-    if (clause.Id() == llvm::omp::Clause::OMPC_match)
-      return &std::get<parser::OmpClause::Match>(clause.u).v.v;
-  }
-  return nullptr;
-}
-
-void OmpStructureChecker::CheckDeclareVariantUserConditions(
-    const parser::OmpContextSelector &ctx) {
-  using SetName = parser::OmpTraitSetSelectorName;
-  using TraitName = parser::OmpTraitSelectorName;
-
-  for (const parser::OmpTraitSetSelector &traitSet : ctx.v) {
-    if (std::get<SetName>(traitSet.t).v != SetName::Value::User) {
-      continue;
-    }
-    for (const parser::OmpTraitSelector &trait :
-        std::get<std::list<parser::OmpTraitSelector>>(traitSet.t)) {
-      const auto &traitName{std::get<TraitName>(trait.t)};
-      if (!std::holds_alternative<TraitName::Value>(traitName.u) ||
-          std::get<TraitName::Value>(traitName.u) !=
-              TraitName::Value::Condition) {
-        continue;
-      }
-      const auto &maybeProps{
-          std::get<std::optional<parser::OmpTraitSelector::Properties>>(
-              trait.t)};
-      if (!maybeProps) {
-        continue;
-      }
-      const auto &properties{
-          std::get<std::list<parser::OmpTraitProperty>>(maybeProps->t)};
-      if (properties.size() != 1) {
-        continue;
-      }
-      const parser::OmpTraitProperty &property{properties.front()};
-      const parser::ScalarExpr &scalarExpr{
-          std::get<parser::ScalarExpr>(property.u)};
-      auto maybeType{GetDynamicType(scalarExpr.thing.value())};
-      if (!maybeType || maybeType->category() != TypeCategory::Logical) {
-        continue;
-      }
-      if (const auto *expr{GetExpr(scalarExpr)}) {
-        if (!IsConstantExpr(*expr, &context_.foldingContext())) {
-          context_.Say(property.source,
-              "USER condition in the MATCH clause must be a constant expression"_err_en_US);
-        }
-      }
-    }
-  }
-}
-
-void OmpStructureChecker::CheckOmpDeclareVariantDirective(
-    const parser::OmpDeclareVariantDirective &x) {
-  const parser::OmpDirectiveSpecification &spec{x.v};
-  const parser::OmpArgumentList &args{spec.Arguments()};
-
-  if (args.v.size() != 1) {
-    context_.Say(args.source,
-        "DECLARE_VARIANT directive should have a single argument"_err_en_US);
-    return;
-  }
-
-  auto InvalidArgument{[&](parser::CharBlock source) {
-    context_.Say(source,
-        "The argument to the DECLARE_VARIANT directive should be [base-name:]variant-name"_err_en_US);
-  }};
-
-  auto CheckProcedureSymbol{[&](const Symbol *sym, parser::CharBlock source) {
-    if (sym) {
-      if (!IsProcedure(*sym) && !IsFunction(*sym)) {
-        auto &msg{context_.Say(source,
-            "The name '%s' should refer to a procedure"_err_en_US,
-            sym->name())};
-        if (sym->test(Symbol::Flag::Implicit)) {
-          msg.Attach(source, "The name '%s' has been implicitly declared"_en_US,
-              sym->name());
-        }
-      }
-    } else {
-      InvalidArgument(source);
-    }
-  }};
-
-  const Symbol *base{nullptr};
-  const Symbol *variant{nullptr};
-  const parser::OmpArgument &arg{args.v.front()};
-  common::visit( //
-      common::visitors{
-          [&](const parser::OmpBaseVariantNames &y) {
-            base = GetObjectSymbol(std::get<0>(y.t));
-            variant = GetObjectSymbol(std::get<1>(y.t));
-            CheckProcedureSymbol(base, arg.source);
-            CheckProcedureSymbol(variant, arg.source);
-          },
-          [&](const parser::OmpLocator &y) {
-            variant = GetArgumentSymbol(arg);
-            CheckProcedureSymbol(variant, arg.source);
-            // OpenMP 5.1 [2.3.5, declare variant directive, Restrictions]:
-            // "If base-proc-name is omitted then the declare variant directive
-            // must appear in an interface block or the specification part of a
-            // procedure." The same section requires the directive to appear in
-            // the specification part of the subprogram or interface body to
-            // which it applies. Infer the base procedure from that program
-            // unit.
-            const Scope &containingScope{context_.FindScope(x.source)};
-            if (const Symbol *host{
-                    GetProgramUnitContaining(containingScope).symbol()}) {
-              base = host;
-            }
-          },
-          [&](auto &&y) { InvalidArgument(arg.source); },
-      },
-      arg.u);
-
-  if (base && variant) {
-    base = &base->GetUltimate();
-    variant = &variant->GetUltimate();
-    if (base == variant) {
-      context_.Say(arg.source,
-          "The variant procedure must differ from the base procedure"_err_en_US);
-    } else if (!declareVariantPairs_.emplace(base, variant).second) {
-      context_.Say(arg.source,
-          "Variant '%s' was already specified for '%s' in another DECLARE VARIANT directive"_err_en_US,
-          variant->name(), base->name());
-    }
-  }
-
-  const parser::traits::OmpContextSelectorSpecification *matchSelector{
-      getMatchClauseContextSelector(spec)};
-  if (!matchSelector) {
-    context_.Say(x.source,
-        "DECLARE_VARIANT directive requires a MATCH clause"_err_en_US);
-    return;
-  }
-
-  EnterDirectiveNest(ContextSelectorNest);
-  CheckContextSelectorSpecification(*matchSelector);
-  CheckDeclareVariantUserConditions(*matchSelector);
-  ExitDirectiveNest(ContextSelectorNest);
-}
-
-void OmpStructureChecker::Enter(const parser::OmpDeclareVariantDirective &x) {
-  const parser::OmpDirectiveName &dirName{x.v.DirName()};
-  PushContextAndClauseSets(dirName.source, dirName.v);
-  CheckOmpDeclareVariantDirective(x);
-}
-
-void OmpStructureChecker::Leave(const parser::OmpDeclareVariantDirective &) {
-  dirContext_.pop_back();
-}
-
-} // namespace Fortran::semantics
diff --git a/flang/lib/Semantics/check-omp-metadirective.cpp b/flang/lib/Semantics/check-omp-metadirective.cpp
index d308c2ee7cac5..163252c57695a 100644
--- a/flang/lib/Semantics/check-omp-metadirective.cpp
+++ b/flang/lib/Semantics/check-omp-metadirective.cpp
@@ -6,7 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// Semantic checks for METADIRECTIVE and related constructs/clauses.
+// Semantic checks for METADIRECTIVE, DECLARE VARIANT, and related constructs.
 //
 //===----------------------------------------------------------------------===//
 
@@ -15,11 +15,13 @@
 #include "flang/Common/idioms.h"
 #include "flang/Common/indirection.h"
 #include "flang/Common/visit.h"
+#include "flang/Evaluate/check-expression.h"
 #include "flang/Parser/characters.h"
 #include "flang/Parser/message.h"
 #include "flang/Parser/parse-tree.h"
 #include "flang/Semantics/openmp-modifiers.h"
 #include "flang/Semantics/openmp-utils.h"
+#include "flang/Semantics/symbol.h"
 #include "flang/Semantics/tools.h"
 
 #include "llvm/Frontend/OpenMP/OMP.h"
@@ -602,4 +604,189 @@ void OmpStructureChecker::Leave(
   dirContext_.pop_back();
 }
 
+// Returns the context selector of the first MATCH clause in `spec`, or
+// nullptr if the directive has no MATCH clause.
+static const parser::traits::OmpContextSelectorSpecification *
+GetMatchClauseContextSelector(const parser::OmpDirectiveSpecification &spec) {
+  for (const parser::OmpClause &clause : spec.Clauses().v) {
+    if (clause.Id() == llvm::omp::Clause::OMPC_match) {
+      return &std::get<parser::OmpClause::Match>(clause.u).v.v;
+    }
+  }
+  return nullptr;
+}
+
+// DECLARE VARIANT only: diagnose USER={CONDITION(expr)} selectors whose
+// expression is not a constant expression, since dynamic (run-time) user
+// conditions are not supported yet. Conditions that are malformed (not
+// exactly one property, a property that is not an expression, or a
+// non-LOGICAL type) are skipped here; they are presumably reported by the
+// shared context-selector checks (CheckContextSelectorSpecification).
+void OmpStructureChecker::CheckDeclareVariantUserConditions(
+    const parser::OmpContextSelector &ctx) {
+  using SetName = parser::OmpTraitSetSelectorName;
+  using TraitName = parser::OmpTraitSelectorName;
+
+  for (const parser::OmpTraitSetSelector &traitSet : ctx.v) {
+    if (std::get<SetName>(traitSet.t).v != SetName::Value::User) {
+      continue;
+    }
+    for (const parser::OmpTraitSelector &trait :
+        std::get<std::list<parser::OmpTraitSelector>>(traitSet.t)) {
+      const auto &traitName{std::get<TraitName>(trait.t)};
+      if (!std::holds_alternative<TraitName::Value>(traitName.u) ||
+          std::get<TraitName::Value>(traitName.u) !=
+              TraitName::Value::Condition) {
+        continue;
+      }
+      const auto &maybeProps{
+          std::get<std::optional<parser::OmpTraitSelector::Properties>>(
+              trait.t)};
+      if (!maybeProps) {
+        continue;
+      }
+      const auto &properties{
+          std::get<std::list<parser::OmpTraitProperty>>(maybeProps->t)};
+      if (properties.size() != 1) {
+        continue;
+      }
+      const parser::OmpTraitProperty &property{properties.front()};
+      // Use get_if rather than get: if the property did not parse as an
+      // expression, std::get would fail with bad_variant_access instead of
+      // leaving the diagnosis to the generic checks.
+      const auto *scalarExpr{std::get_if<parser::ScalarExpr>(&property.u)};
+      if (!scalarExpr) {
+        continue;
+      }
+      auto maybeType{GetDynamicType(scalarExpr->thing.value())};
+      if (!maybeType || maybeType->category() != TypeCategory::Logical) {
+        continue;
+      }
+      if (const auto *expr{GetExpr(*scalarExpr)}) {
+        if (!IsConstantExpr(*expr, &context_.foldingContext())) {
+          context_.Say(property.source,
+              "Run-time USER condition in the MATCH clause is not yet implemented"_err_en_US);
+        }
+      }
+    }
+  }
+}
+
+// Structural checks for DECLARE VARIANT:
+//  - the directive takes exactly one argument, [base-name:]variant-name,
+//    and the named entities must be procedures;
+//  - in the single-name form the base is the enclosing procedure;
+//  - the base and variant must differ, and a (base, variant) pair may only
+//    be declared once;
+//  - a MATCH clause is required, and its context selector is validated
+//    (see also CheckDeclareVariantUserConditions).
+void OmpStructureChecker::CheckOmpDeclareVariantDirective(
+    const parser::OmpDeclareVariantDirective &x) {
+  const parser::OmpDirectiveSpecification &spec{x.v};
+  const parser::OmpArgumentList &args{spec.Arguments()};
+
+  if (args.v.size() != 1) {
+    context_.Say(args.source,
+        "DECLARE_VARIANT directive should have a single argument"_err_en_US);
+    return;
+  }
+
+  auto InvalidArgument{[&](parser::CharBlock source) {
+    context_.Say(source,
+        "The argument to the DECLARE_VARIANT directive should be [base-name:]variant-name"_err_en_US);
+  }};
+
+  // A null symbol means the name could not be resolved, which is reported
+  // as a malformed argument; otherwise the symbol must be a procedure.
+  auto CheckProcedureSymbol{[&](const Symbol *sym, parser::CharBlock source) {
+    if (sym) {
+      if (!IsProcedure(*sym) && !IsFunction(*sym)) {
+        auto &msg{context_.Say(source,
+            "The name '%s' should refer to a procedure"_err_en_US,
+            sym->name())};
+        if (sym->test(Symbol::Flag::Implicit)) {
+          msg.Attach(source, "The name '%s' has been implicitly declared"_en_US,
+              sym->name());
+        }
+      }
+    } else {
+      InvalidArgument(source);
+    }
+  }};
+
+  const Symbol *base{nullptr};
+  const Symbol *variant{nullptr};
+  const parser::OmpArgument &arg{args.v.front()};
+  common::visit( //
+      common::visitors{
+          [&](const parser::OmpBaseVariantNames &y) {
+            base = GetObjectSymbol(std::get<0>(y.t));
+            variant = GetObjectSymbol(std::get<1>(y.t));
+            CheckProcedureSymbol(base, arg.source);
+            CheckProcedureSymbol(variant, arg.source);
+          },
+          [&](const parser::OmpLocator &) {
+            // Single-name form: the argument is the variant, and the base is
+            // the program unit containing the directive.
+            variant = GetArgumentSymbol(arg);
+            CheckProcedureSymbol(variant, arg.source);
+            const Scope &containingScope{context_.FindScope(x.source)};
+            if (const Symbol *
+                host{GetProgramUnitContaining(containingScope).symbol()}) {
+              // Only a procedure can be an implicit base. For any other
+              // program unit (e.g. a module), leave `base` unset so that no
+              // (base, variant) pair is recorded for it.
+              if (IsProcedure(*host) || IsFunction(*host)) {
+                base = host;
+              }
+            }
+          },
+          [&](auto &&) { InvalidArgument(arg.source); },
+      },
+      arg.u);
+
+  if (base && variant) {
+    // Compare ultimate symbols so that different names for the same
+    // procedure (e.g. via use association) are recognized as equal.
+    base = &base->GetUltimate();
+    variant = &variant->GetUltimate();
+    if (base == variant) {
+      context_.Say(arg.source,
+          "The variant procedure must differ from the base procedure"_err_en_US);
+    } else if (!declareVariantPairs_.emplace(base, variant).second) {
+      context_.Say(arg.source,
+          "Variant '%s' was already specified for '%s' in another DECLARE VARIANT directive"_err_en_US,
+          variant->name(), base->name());
+    }
+  }
+
+  const parser::traits::OmpContextSelectorSpecification *matchSelector{
+      GetMatchClauseContextSelector(spec)};
+  if (!matchSelector) {
+    context_.Say(x.source,
+        "DECLARE_VARIANT directive requires a MATCH clause"_err_en_US);
+    return;
+  }
+
+  // Check the selector inside ContextSelectorNest, then apply the
+  // DECLARE VARIANT-specific restriction on USER conditions.
+  EnterDirectiveNest(ContextSelectorNest);
+  CheckContextSelectorSpecification(*matchSelector);
+  CheckDeclareVariantUserConditions(*matchSelector);
+  ExitDirectiveNest(ContextSelectorNest);
+}
+
+// Push the directive context and its allowed-clause sets, then run the
+// DECLARE VARIANT structure checks.
+void OmpStructureChecker::Enter(const parser::OmpDeclareVariantDirective &x) {
+  const parser::OmpDirectiveName &dirName{x.v.DirName()};
+  PushContextAndClauseSets(dirName.source, dirName.v);
+  CheckOmpDeclareVariantDirective(x);
+}
+
+// Pop the directive context pushed by the matching Enter.
+void OmpStructureChecker::Leave(const parser::OmpDeclareVariantDirective &) {
+  dirContext_.pop_back();
+}
+
 } // namespace Fortran::semantics
diff --git a/flang/lib/Semantics/check-omp-structure.h b/flang/lib/Semantics/check-omp-structure.h
index 3aa1af192884e..2e8db65630fa3 100644
--- a/flang/lib/Semantics/check-omp-structure.h
+++ b/flang/lib/Semantics/check-omp-structure.h
@@ -127,9 +127,6 @@ class OmpStructureChecker : public OmpStructureCheckerBase {
 
   void Enter(const parser::OmpDeclareVariantDirective &);
   void Leave(const parser::OmpDeclareVariantDirective &);
-  void CheckOmpDeclareVariantDirective(
-      const parser::OmpDeclareVariantDirective &);
-  void CheckDeclareVariantUserConditions(const parser::OmpContextSelector &);
   void Enter(const parser::OmpDeclareSimdDirective &);
   void Leave(const parser::OmpDeclareSimdDirective &);
   void Enter(const parser::OmpAllocateDirective &);
@@ -262,6 +259,9 @@ class OmpStructureChecker : public OmpStructureCheckerBase {
   void CheckDistLinear(const parser::OpenMPLoopConstruct &x);
 
   // check-omp-metadirective.cpp
+  void CheckOmpDeclareVariantDirective(
+      const parser::OmpDeclareVariantDirective &);
+  void CheckDeclareVariantUserConditions(const parser::OmpContextSelector &);
   const std::list<parser::OmpTraitProperty> &GetTraitPropertyList(
       const parser::OmpTraitSelector &);
   std::optional<llvm::omp::Clause> GetClauseFromProperty(
diff --git a/flang/test/Semantics/OpenMP/declare-variant-match.f90 b/flang/test/Semantics/OpenMP/declare-variant-match.f90
index 37f6543dd89da..199d05f1750ee 100644
--- a/flang/test/Semantics/OpenMP/declare-variant-match.f90
+++ b/flang/test/Semantics/OpenMP/declare-variant-match.f90
@@ -76,7 +76,7 @@ subroutine sub
 
 subroutine f06(x)
   integer :: x
-!ERROR: USER condition in the MATCH clause must be a constant expression
+!ERROR: Run-time USER condition in the MATCH clause is not yet implemented
   !$omp declare variant (sub:vsub) match (user={condition(x > 0)})
 contains
   subroutine vsub



More information about the flang-commits mailing list