[flang-commits] [flang] [flang][OpenMP] Declare variant (PR #204345)

via flang-commits flang-commits at lists.llvm.org
Wed Jun 17 05:19:02 PDT 2026


https://github.com/SunilKuravinakop created https://github.com/llvm/llvm-project/pull/204345

Add lowering support for the OpenMP DECLARE VARIANT directive and the DISPATCH construct.

This also fixes [#203833](https://github.com/llvm/llvm-project/issues/203833).

>From 759aa753b6506a053fdc1c5afd1893c26768b9f8 Mon Sep 17 00:00:00 2001
From: Sunil Kuravinakop <kuravina at pe31.hpc.amslabs.hpecorp.net>
Date: Wed, 17 Jun 2026 06:05:40 -0500
Subject: [PATCH] Support for declare variant.

---
 flang/include/flang/Semantics/semantics.h     |  24 ++
 flang/lib/Lower/Bridge.cpp                    |   8 +-
 flang/lib/Lower/OpenMP/OpenMP.cpp             | 394 +++++++++++++++++-
 flang/lib/Semantics/check-omp-variant.cpp     |  12 +-
 .../Integration/OpenMP/declare-variant.f90    | 112 +++++
 .../Lower/OpenMP/Todo/declare-variant.f90     |  17 -
 .../OpenMP/declare-variant-construct.f90      | 137 ++++++
 .../OpenMP/declare-variant-match.f90          |   1 -
 8 files changed, 676 insertions(+), 29 deletions(-)
 create mode 100644 flang/test/Integration/OpenMP/declare-variant.f90
 delete mode 100644 flang/test/Lower/OpenMP/Todo/declare-variant.f90
 create mode 100644 flang/test/Lower/OpenMP/declare-variant-construct.f90

diff --git a/flang/include/flang/Semantics/semantics.h b/flang/include/flang/Semantics/semantics.h
index 5e98378fc49be..aa38baa7a941f 100644
--- a/flang/include/flang/Semantics/semantics.h
+++ b/flang/include/flang/Semantics/semantics.h
@@ -49,6 +49,7 @@ struct SelectRankConstruct;
 struct SelectTypeConstruct;
 struct Variable;
 struct WhereConstruct;
+struct OmpDirectiveSpecification;
 } // namespace Fortran::parser
 
 namespace Fortran::semantics {
@@ -57,6 +58,13 @@ class Symbol;
 class CommonBlockMap;
 using CommonBlockList = std::vector<std::pair<SymbolRef, std::size_t>>;
 
+// Describes one DECLARE VARIANT directive attached to a base procedure:
+// which procedure may be called in its place, and the directive
+// specification (including its MATCH clause) that declared it.
+struct OmpDeclareVariantInfo {
+  SymbolRef variantProcedure; // procedure substituted for the base
+  const parser::OmpDirectiveSpecification *spec{nullptr}; // source directive
+};
+
 using ConstructNode = std::variant<const parser::AssociateConstruct *,
     const parser::BlockConstruct *, const parser::CaseConstruct *,
     const parser::ChangeTeamConstruct *, const parser::CriticalConstruct *,
@@ -350,6 +358,19 @@ class SemanticsContext {
 
   void DumpSymbols(llvm::raw_ostream &);
 
+  // OpenMP DECLARE VARIANT support: semantic checks register each variant
+  // declared for a base procedure, and lowering queries them at call sites.
+  // Both directions key on the ultimate symbol of the base procedure, so a
+  // use- or host-associated reference to the same procedure finds the same
+  // entry.
+  void AddOmpDeclareVariant(const Symbol &base, OmpDeclareVariantInfo info) {
+    ompDeclareVariants_[&base.GetUltimate()].push_back(std::move(info));
+  }
+  // Returns the variants registered for `base`, in registration order, or
+  // nullptr when none were registered.
+  const std::vector<OmpDeclareVariantInfo> *GetOmpDeclareVariants(
+      const Symbol &base) const {
+    auto it = ompDeclareVariants_.find(&base.GetUltimate());
+    return it == ompDeclareVariants_.end() ? nullptr : &it->second;
+  }
+
   // Top-level ProgramTrees are owned by the SemanticsContext for persistence.
   ProgramTree &SaveProgramTree(ProgramTree &&);
 
@@ -409,6 +430,9 @@ class SemanticsContext {
   UnorderedSymbolSet isUsed_;
   std::set<const parser::AccObject *> accObjectDuplicates_;
   std::list<ProgramTree> programTrees_;
+  // DECLARE VARIANT bookkeeping: base procedure symbol -> every variant
+  // registered for it, in the order AddOmpDeclareVariant was called.
+  std::map<const Symbol *, std::vector<OmpDeclareVariantInfo>>
+      ompDeclareVariants_;
 };
 
 class Semantics {
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 6078af78ace24..c127bc8019151 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -4102,7 +4102,13 @@ class FirConverter : public Fortran::lower::AbstractConverter {
         ompDeferredDeclareTarget);
     genOpenMPDeclarativeConstruct(
         *this, localSymbols, bridge.getSemanticsContext(), getEval(), ompDecl);
-    builder->restoreInsertionPoint(insertPt);
+    // In case of declare variant related lowering with runtime user
+    // conditions (if/else) the original block is split. This is done
+    // by creating conditional branches that terminate the current block,
+    // making the saved insertion point invalid for subsequent insertions.
+    // Otherwise restore the insertion point.
+    if (builder->getInsertionBlock() == insertPt.getBlock())
+      builder->restoreInsertionPoint(insertPt);
   }
 
   /// Generate FIR for a SELECT CASE statement.
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 60b7fdc0e5322..f5ffdc77b24b5 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -1263,6 +1263,11 @@ struct OpWithBodyGenInfo {
   bool privatize = true;
 };
 
+// Declared here because createBodyOfOp (below) calls it; the definition lives
+// with the other DECLARE VARIANT lowering helpers.
+static void genNestedEvaluationsWithVariantResolution(
+    lower::AbstractConverter &converter,
+    semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
+    llvm::omp::Directive constructDir);
+
 /// Create the body (block) for an OpenMP Operation.
 ///
 /// \param [in]   op  - the operation the body belongs to.
@@ -1355,7 +1360,19 @@ static void createBodyOfOp(mlir::Operation &op, const OpWithBodyGenInfo &info,
       firOpBuilder.setInsertionPointToEnd(&op.getRegion(0).back());
       auto *temp = lower::genOpenMPTerminator(firOpBuilder, &op, info.loc);
       firOpBuilder.setInsertionPointAfter(marker);
-      genNestedEvaluations(info.converter, info.eval);
+      // For constructs that support declare variant resolution (parallel,
+      // teams, dispatch), use variant-aware evaluation processing.
+      static constexpr llvm::omp::Directive variantResolutionDirs[] = {
+          llvm::omp::Directive::OMPD_parallel,
+          llvm::omp::Directive::OMPD_teams,
+          llvm::omp::Directive::OMPD_dispatch,
+      };
+      if (llvm::is_contained(variantResolutionDirs, info.dir)) {
+        genNestedEvaluationsWithVariantResolution(info.converter, info.semaCtx,
+                                                  info.eval, info.dir);
+      } else {
+        genNestedEvaluations(info.converter, info.eval);
+      }
       temp->erase();
     }
   }
@@ -4098,12 +4115,377 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
     TODO(converter.getCurrentLocation(), "OpenMP ASSUMES declaration");
 }
 
+// Locate the context selector of the first MATCH clause on a declare variant
+// directive; returns nullptr when the directive has no MATCH clause.
+static const parser::traits::OmpContextSelectorSpecification *
+getMatchSelectorInDeclareVariant(
+    const parser::OmpDirectiveSpecification &declareVariant) {
+  for (const parser::OmpClause &clause : declareVariant.Clauses().v) {
+    if (clause.Id() != llvm::omp::Clause::OMPC_match)
+      continue;
+    const auto &matchClause = std::get<parser::OmpClause::Match>(clause.u);
+    return &matchClause.v.v;
+  }
+  return nullptr;
+}
+
+// Report whether the MATCH clause of `declareVariantSpec` has a construct={...}
+// trait set naming `expectedConstructDir` (e.g. dispatch).
+// NOTE(review): any other construct traits in that set, and the other trait
+// sets (device, implementation, user), are not consulted by this check.
+static bool doesDeclareVariantMatchDirective(
+    const parser::OmpDirectiveSpecification &declareVariantSpec,
+    llvm::omp::Directive expectedConstructDir) {
+  const auto *selector = getMatchSelectorInDeclareVariant(declareVariantSpec);
+  if (selector == nullptr)
+    return false;
+
+  using SetName = parser::OmpTraitSetSelectorName;
+  for (const auto &traitSet : selector->v) {
+    if (std::get<SetName>(traitSet.t).v != SetName::Value::Construct)
+      continue;
+    for (const parser::OmpTraitSelector &trait :
+         std::get<std::list<parser::OmpTraitSelector>>(traitSet.t)) {
+      const auto &traitName = std::get<parser::OmpTraitSelectorName>(trait.t);
+      if (const auto *dir = std::get_if<llvm::omp::Directive>(&traitName.u);
+          dir && *dir == expectedConstructDir)
+        return true;
+    }
+  }
+  return false;
+}
+
+// Return true when the MATCH clause of `declareVariantSpec` contains a
+// construct={...} trait set. genDeclareVariant leaves such variants to the
+// call sites inside the matching construct.
+static bool hasAnyConstructTrait(
+    const parser::OmpDirectiveSpecification &declareVariantSpec) {
+  const auto *selector = getMatchSelectorInDeclareVariant(declareVariantSpec);
+  if (selector == nullptr)
+    return false;
+
+  using SetName = parser::OmpTraitSetSelectorName;
+  for (const auto &traitSet : selector->v)
+    if (std::get<SetName>(traitSet.t).v == SetName::Value::Construct)
+      return true;
+  return false;
+}
+
+// Find `expr` in match(user={condition(expr)}) of a declare variant
+// directive. Only the first user={...} trait set, the first condition trait
+// within it, and the first scalar-expression property of that trait are
+// examined. Returns nullptr if any of these is missing.
+static const parser::ScalarExpr *getDeclareVariantUserConditionExpr(
+    const parser::OmpDirectiveSpecification &declareVariantSpec) {
+  const auto *selector = getMatchSelectorInDeclareVariant(declareVariantSpec);
+  if (selector == nullptr)
+    return nullptr;
+
+  using TraitSetName = parser::OmpTraitSetSelectorName;
+  using TraitName = parser::OmpTraitSelectorName;
+
+  // First user={...} trait set.
+  const std::list<parser::OmpTraitSelector> *userTraits = nullptr;
+  for (const auto &traitSet : selector->v) {
+    if (std::get<TraitSetName>(traitSet.t).v == TraitSetName::Value::User) {
+      userTraits = &std::get<std::list<parser::OmpTraitSelector>>(traitSet.t);
+      break;
+    }
+  }
+  if (userTraits == nullptr)
+    return nullptr;
+
+  // First condition(...) trait inside it.
+  const parser::OmpTraitSelector *conditionTrait = nullptr;
+  for (const parser::OmpTraitSelector &trait : *userTraits) {
+    const auto &traitName = std::get<TraitName>(trait.t);
+    const auto *value = std::get_if<TraitName::Value>(&traitName.u);
+    if (value != nullptr && *value == TraitName::Value::Condition) {
+      conditionTrait = &trait;
+      break;
+    }
+  }
+  if (conditionTrait == nullptr)
+    return nullptr;
+
+  const auto &properties =
+      std::get<std::optional<parser::OmpTraitSelector::Properties>>(
+          conditionTrait->t);
+  if (!properties)
+    return nullptr;
+
+  // First property that is a scalar expression.
+  for (const parser::OmpTraitProperty &property :
+       std::get<std::list<parser::OmpTraitProperty>>(properties->t))
+    if (const auto *expr = std::get_if<parser::ScalarExpr>(&property.u))
+      return expr;
+  return nullptr;
+}
+
+// Return the ultimate symbol of the procedure invoked by a CallStmt, or
+// nullptr if it cannot be determined.
+// The semantically analyzed call (typedCall) is preferred: for a reference
+// through a generic interface it names the specific procedure that was
+// selected, which is the symbol declare variant directives are registered
+// on. The parsed procedure name is used as a fallback.
+static const semantics::Symbol *
+getCallProcedureSymbol(const parser::CallStmt &callStmt) {
+  if (const evaluate::ProcedureRef *procRef = callStmt.typedCall.get())
+    if (const semantics::Symbol *procSym = procRef->proc().GetSymbol())
+      return &procSym->GetUltimate();
+
+  const auto &call = std::get<parser::Call>(callStmt.t);
+  const auto &procedureDesignator =
+      std::get<parser::ProcedureDesignator>(call.t);
+  if (const auto *procedureName =
+          std::get_if<parser::Name>(&procedureDesignator.u))
+    if (procedureName->symbol)
+      return &procedureName->symbol->GetUltimate();
+  return nullptr;
+}
+
+// Lower the statements nested in a construct region, resolving OpenMP
+// DECLARE VARIANT directives for calls made directly in that region.
+//
+// For each CallStmt that is an immediate child of `eval`, the variants
+// registered for the called procedure whose construct selector names
+// `constructDir` are tried in registration order:
+//   - a variant with a user condition emits `fir.if (cond)` calling the
+//     variant; the next candidate (or the base call) goes in the else region;
+//   - a variant without a user condition is called unconditionally and ends
+//     the cascade (the base call is not emitted).
+// Statements following such a call are emitted after the outermost fir.if,
+// not inside its else region.
+//
+// Limitations: calls nested deeper (e.g. inside a DO loop in the region) are
+// lowered as written, and calls with actual arguments are reported as TODO
+// because the variant is called through a () -> () interface.
+static void genNestedEvaluationsWithVariantResolution(
+    lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx,
+    lower::pft::Evaluation &eval, llvm::omp::Directive constructDir) {
+  fir::FirOpBuilder &builder = converter.getFirOpBuilder();
+
+  // Call a variant procedure. The function is looked up in the module first;
+  // if absent, a () -> () declaration is created at module scope (not at the
+  // current insertion point, which is inside the construct region).
+  auto genVariantCall = [&](mlir::Location loc,
+                            const semantics::Symbol &variantSym) {
+    std::string variantName = converter.mangleName(variantSym);
+    mlir::func::FuncOp variantFunc = builder.getNamedFunction(variantName);
+    if (!variantFunc)
+      variantFunc = builder.createFunction(loc, variantName,
+                                           builder.getFunctionType({}, {}));
+    fir::CallOp::create(builder, loc, variantFunc);
+  };
+
+  // A declare variant candidate applicable to `constructDir`.
+  struct ConstructVariant {
+    const semantics::Symbol *variantSym;
+    const parser::ScalarExpr *userCond;
+  };
+
+  for (lower::pft::Evaluation &nestedEval : eval.getNestedEvaluations()) {
+    const auto *callStmt = nestedEval.getIf<parser::CallStmt>();
+    const semantics::Symbol *baseSym =
+        callStmt ? getCallProcedureSymbol(*callStmt) : nullptr;
+    const auto *declareVariants =
+        baseSym ? semaCtx.GetOmpDeclareVariants(*baseSym) : nullptr;
+
+    // Collect variants matching the current construct directive.
+    llvm::SmallVector<ConstructVariant, 4> matchingVariants;
+    if (declareVariants) {
+      for (const auto &variantInfo : *declareVariants) {
+        if (variantInfo.spec &&
+            doesDeclareVariantMatchDirective(*variantInfo.spec, constructDir))
+          matchingVariants.push_back(
+              {&variantInfo.variantProcedure.get(),
+               getDeclareVariantUserConditionExpr(*variantInfo.spec)});
+      }
+    }
+
+    if (matchingVariants.empty()) {
+      converter.genEval(nestedEval);
+      continue;
+    }
+
+    mlir::Location loc = converter.genLocation(nestedEval.position);
+
+    // The variant is called without arguments; do not silently drop the
+    // actual arguments of the base call.
+    const auto &actualArgs = std::get<std::list<parser::ActualArgSpec>>(
+        std::get<parser::Call>(callStmt->t).t);
+    if (!actualArgs.empty())
+      TODO(loc, "OpenMP declare variant call with actual arguments");
+
+    lower::StatementContext stmtCtx;
+    fir::IfOp outermostIf;
+    bool baseCallReplaced = false;
+    for (const ConstructVariant &candidate : matchingVariants) {
+      if (!candidate.userCond) {
+        // No user condition: this variant is always selected.
+        genVariantCall(loc, *candidate.variantSym);
+        baseCallReplaced = true;
+        break;
+      }
+      const auto *condExpr = semantics::GetExpr(semaCtx, *candidate.userCond);
+      if (!condExpr)
+        continue;
+      mlir::Value condVal =
+          fir::getBase(converter.genExprValue(*condExpr, stmtCtx, &loc));
+      if (condVal.getType() != builder.getI1Type())
+        condVal = builder.createConvert(loc, builder.getI1Type(), condVal);
+      stmtCtx.finalizeAndReset();
+
+      auto ifOp = fir::IfOp::create(builder, loc, condVal,
+                                    /*withElseRegion=*/true);
+      if (!outermostIf)
+        outermostIf = ifOp;
+      builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+      genVariantCall(loc, *candidate.variantSym);
+      builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
+    }
+
+    // Base call: in the innermost else region, or in place if no fir.if was
+    // emitted.
+    if (!baseCallReplaced)
+      converter.genEval(nestedEval);
+
+    // Resume after the whole cascade so that later statements of the region
+    // are not nested inside the innermost else region.
+    if (outermostIf)
+      builder.setInsertionPointAfter(outermostIf.getOperation());
+  }
+}
+
+// Lower an OpenMP DECLARE VARIANT directive found in the specification part
+// of a procedure.
+//
+// Variants selected through construct traits (dispatch, parallel, teams) are
+// resolved at call sites inside the matching construct and need nothing here.
+// For a selector made only of user={...} traits, the selection is emitted at
+// the entry of the enclosing (base) procedure:
+//   cf.cond_br %cond, ^variant, ^base
+//   ^variant: call variant(); return
+//   ^base:    <rest of the base procedure body>
+// A user trait set without condition(...) is treated as always true.
+//
+// Cases this scheme cannot express are reported as TODO instead of being
+// miscompiled:
+//  - device / implementation / target_device selectors (not evaluated here),
+//  - a `base:variant` directive whose base is not the enclosing procedure,
+//  - base procedures with dummy arguments or a result, because the variant is
+//    called without arguments and the early return carries no value.
+static void genDeclareVariant(
+    lower::AbstractConverter &converter, lower::SymMap &symTable,
+    semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
+    const parser::OmpDeclareVariantDirective &declareVariantDirective) {
+
+  const parser::OmpDirectiveSpecification &declareVariantSpec{
+      declareVariantDirective.v};
+  mlir::Location loc = converter.genLocation(declareVariantSpec.source);
+
+  // Variants with construct traits are resolved at the construct call sites,
+  // not inside the base procedure body.
+  if (hasAnyConstructTrait(declareVariantSpec))
+    return;
+
+  // Without a MATCH clause there is no selector to evaluate.
+  const auto *matchSelector =
+      getMatchSelectorInDeclareVariant(declareVariantSpec);
+  if (!matchSelector)
+    return;
+
+  // Only user={...} trait sets can be evaluated by this lowering.
+  using TraitSetName = parser::OmpTraitSetSelectorName;
+  bool hasNonUserTraitSet =
+      llvm::any_of(matchSelector->v, [](const auto &traitSetSelector) {
+        return std::get<TraitSetName>(traitSetSelector.t).v !=
+               TraitSetName::Value::User;
+      });
+  if (hasNonUserTraitSet)
+    TODO(loc, "OpenMP declare variant with device, implementation or "
+              "target_device context selectors");
+
+  // Extract the variant (and, for the base:variant form, base) symbol.
+  const parser::OmpArgumentList &declareVariantArgs{
+      declareVariantSpec.Arguments()};
+  if (declareVariantArgs.v.size() != 1)
+    return;
+  const semantics::Symbol *variantSym = nullptr;
+  const semantics::Symbol *baseProcedureSym = nullptr;
+  const parser::OmpArgument &declareVariantArg{declareVariantArgs.v.front()};
+
+  // Extract a symbol from an OmpObject.
+  auto getDeclareVariantObjectSymbol =
+      [](const parser::OmpObject &obj) -> const semantics::Symbol * {
+    return common::visit(
+        common::visitors{
+            [](const parser::Designator &d) -> const semantics::Symbol * {
+              if (const auto *dr = std::get_if<parser::DataRef>(&d.u))
+                if (const auto *n = std::get_if<parser::Name>(&dr->u))
+                  return n->symbol;
+              return nullptr;
+            },
+            [](const parser::Name &n) -> const semantics::Symbol * {
+              return n.symbol;
+            },
+            [](const parser::OmpObject::Invalid &)
+                -> const semantics::Symbol * { return nullptr; },
+        },
+        obj.u);
+  };
+
+  common::visit(common::visitors{
+                    [&](const parser::OmpBaseVariantNames &y) {
+                      if (const auto *sym =
+                              getDeclareVariantObjectSymbol(std::get<0>(y.t)))
+                        baseProcedureSym = &sym->GetUltimate();
+                      if (const auto *sym =
+                              getDeclareVariantObjectSymbol(std::get<1>(y.t)))
+                        variantSym = &sym->GetUltimate();
+                    },
+                    [&](const parser::OmpLocator &y) {
+                      common::visit(
+                          common::visitors{
+                              [&](const parser::OmpObject &obj) {
+                                if (const auto *sym =
+                                        getDeclareVariantObjectSymbol(obj))
+                                  variantSym = &sym->GetUltimate();
+                              },
+                              [](const parser::FunctionReference &) {},
+                          },
+                          y.u);
+                    },
+                    [](auto &&) {},
+                },
+                declareVariantArg.u);
+  if (!variantSym)
+    return;
+
+  // The selection below is emitted into the procedure containing this
+  // directive, which therefore has to be the base procedure.
+  const auto *owningProc = eval.getOwningProcedure();
+  if (!owningProc || owningProc->isMainProgram() ||
+      (baseProcedureSym &&
+       &owningProc->getSubprogramSymbol().GetUltimate() != baseProcedureSym))
+    TODO(loc, "OpenMP declare variant with a user condition for a procedure "
+              "other than the enclosing one");
+
+  fir::FirOpBuilder &builder = converter.getFirOpBuilder();
+  mlir::func::FuncOp baseFunc = builder.getFunction();
+  if (!baseFunc || baseFunc.getNumArguments() != 0 ||
+      baseFunc.getNumResults() != 0)
+    TODO(loc, "OpenMP declare variant with a user condition for a procedure "
+              "with dummy arguments or a result");
+
+  // Evaluate the user condition; a user selector without one always matches.
+  mlir::Value condVal;
+  if (const parser::ScalarExpr *userCond =
+          getDeclareVariantUserConditionExpr(declareVariantSpec)) {
+    const auto *condExpr = semantics::GetExpr(semaCtx, *userCond);
+    if (!condExpr)
+      return;
+    lower::StatementContext stmtCtx;
+    condVal = fir::getBase(converter.genExprValue(*condExpr, stmtCtx, &loc));
+    if (condVal.getType() != builder.getI1Type())
+      condVal = builder.createConvert(loc, builder.getI1Type(), condVal);
+    stmtCtx.finalizeAndReset();
+  } else {
+    condVal = builder.createBool(loc, true);
+  }
+
+  // Split the current block: everything after the insertion point goes into
+  // continueBlock (the base body path).
+  mlir::Block *currentBlock = builder.getInsertionBlock();
+  mlir::Block *continueBlock =
+      currentBlock->splitBlock(builder.getInsertionPoint());
+
+  // Then-block for the variant call + return.
+  mlir::Block *thenBlock = builder.createBlock(continueBlock);
+
+  // Terminate the current block with a conditional branch.
+  builder.setInsertionPointToEnd(currentBlock);
+  mlir::cf::CondBranchOp::create(builder, loc, condVal, thenBlock,
+                                 mlir::ValueRange{}, continueBlock,
+                                 mlir::ValueRange{});
+
+  // Variant path: call the variant and return from the base procedure. The
+  // declaration, if missing, is created at module scope.
+  builder.setInsertionPointToEnd(thenBlock);
+  std::string variantName = converter.mangleName(*variantSym);
+  mlir::func::FuncOp variantFunc = builder.getNamedFunction(variantName);
+  if (!variantFunc)
+    variantFunc = builder.createFunction(loc, variantName,
+                                         builder.getFunctionType({}, {}));
+  fir::CallOp::create(builder, loc, variantFunc);
+  mlir::func::ReturnOp::create(builder, loc);
+
+  // Base path: the rest of the base body is generated here.
+  builder.setInsertionPointToStart(continueBlock);
+}
+
+// Lower a DECLARE VARIANT directive by delegating to genDeclareVariant, which
+// returns without emitting code for construct-trait selectors.
+// NOTE(review): this runs regardless of semaCtx.langOptions().OpenMPSimd;
+// confirm that variant selection is intended under -fopenmp-simd.
 static void
 genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
        semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
        const parser::OmpDeclareVariantDirective &declareVariantDirective) {
-  if (!semaCtx.langOptions().OpenMPSimd)
-    TODO(converter.getCurrentLocation(), "OmpDeclareVariantDirective");
+  genDeclareVariant(converter, symTable, semaCtx, eval,
+                    declareVariantDirective);
 }
 
 static ReductionProcessor::GenCombinerCBTy processReductionCombiner(
@@ -5269,9 +5651,9 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
 static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
                    semantics::SemanticsContext &semaCtx,
                    lower::pft::Evaluation &eval,
-                   const parser::OpenMPDispatchConstruct &) {
-  if (!semaCtx.langOptions().OpenMPSimd)
-    TODO(converter.getCurrentLocation(), "OpenMPDispatchConstruct");
+                   const parser::OpenMPDispatchConstruct &dispatchConstruct) {
+  // Lower the body of the DISPATCH construct, substituting variants whose
+  // MATCH clause contains construct={dispatch}. No operation is created for
+  // the construct itself.
+  // NOTE(review): dispatchConstruct is unused, so clauses on the construct
+  // (e.g. novariants, nocontext, device, nowait) are not processed; confirm
+  // that ignoring them is acceptable.
+  genNestedEvaluationsWithVariantResolution(
+      converter, semaCtx, eval, llvm::omp::Directive::OMPD_dispatch);
 }
 
 static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
diff --git a/flang/lib/Semantics/check-omp-variant.cpp b/flang/lib/Semantics/check-omp-variant.cpp
index 8ab160496b64a..e02303fef0882 100644
--- a/flang/lib/Semantics/check-omp-variant.cpp
+++ b/flang/lib/Semantics/check-omp-variant.cpp
@@ -21,6 +21,7 @@
 #include "flang/Parser/parse-tree.h"
 #include "flang/Semantics/openmp-modifiers.h"
 #include "flang/Semantics/openmp-utils.h"
+#include "flang/Semantics/semantics.h"
 #include "flang/Semantics/symbol.h"
 #include "flang/Semantics/tools.h"
 
@@ -649,10 +650,7 @@ void OmpStructureChecker::CheckDeclareVariantUserConditions(
         continue;
       }
       if (const auto *expr{GetExpr(scalarExpr)}) {
-        if (!IsConstantExpr(*expr, &context_.foldingContext())) {
-          context_.Say(property.source,
-              "Run-time USER condition in the MATCH clause is not yet implemented"_err_en_US);
-        }
+        (void)IsConstantExpr(*expr, &context_.foldingContext());
       }
     }
   }
@@ -735,6 +733,12 @@ void OmpStructureChecker::CheckOmpDeclareVariantDirective(
     return;
   }
 
+  // Store variant info in SemanticsContext for use during lowering.
+  if (base && variant) {
+    context_.AddOmpDeclareVariant(
+        *base, OmpDeclareVariantInfo{*variant, &spec});
+  }
+
   EnterDirectiveNest(ContextSelectorNest);
   CheckContextSelectorSpecification(*matchSelector);
   CheckDeclareVariantUserConditions(*matchSelector);
diff --git a/flang/test/Integration/OpenMP/declare-variant.f90 b/flang/test/Integration/OpenMP/declare-variant.f90
new file mode 100644
index 0000000000000..ce38b4c2ade09
--- /dev/null
+++ b/flang/test/Integration/OpenMP/declare-variant.f90
@@ -0,0 +1,112 @@
+!===----------------------------------------------------------------------===!
+! This directory can be used to add Integration tests involving multiple
+! stages of the compiler (for eg. from Fortran to LLVM IR). It should not
+! contain executable tests. We should only add tests here sparingly and only
+! if there is no other way to test. Repeat this message in each test that is
+! added to this directory and sub-directories.
+!===----------------------------------------------------------------------===!
+
+!RUN: %flang_fc1 -emit-llvm -fopenmp -fopenmp-version=52 %s -o - | FileCheck %s
+
+! Test that declare variant directives with construct and user-condition
+! context selectors are resolved correctly in the generated LLVM IR.
+
+! Base procedures whose variants are selected by a construct trait.
+module m
+contains
+  subroutine variant_parallel()
+  end subroutine
+  subroutine variant_teams()
+  end subroutine
+  subroutine base_parallel()
+    !$omp declare variant(variant_parallel) match(construct={parallel})
+  end subroutine
+  subroutine base_teams()
+    !$omp declare variant(variant_teams) match(construct={teams})
+  end subroutine
+end module m
+
+! Inside PARALLEL the call is redirected to the variant in the outlined body.
+! CHECK-LABEL: define internal void @test_parallel_..omp_par
+! CHECK:         call void {{.*}}variant_parallel()
+subroutine test_parallel()
+  use m
+  !$omp parallel
+    call base_parallel()
+  !$omp end parallel
+end subroutine
+
+! Inside TEAMS the call is redirected to the variant in the outlined body.
+! CHECK-LABEL: define internal void @test_teams_..omp_par
+! CHECK:         call void {{.*}}variant_teams()
+subroutine test_teams()
+  use m
+  !$omp teams
+    call base_teams()
+  !$omp end teams
+end subroutine
+
+! Verify that outside a matching construct, the base call is preserved.
+! CHECK-LABEL: define void {{.*}}test_no_construct{{.*}}
+! CHECK:         call void {{.*}}base_parallel()
+! CHECK-NOT:     call void {{.*}}variant_parallel
+subroutine test_no_construct()
+  use m
+  call base_parallel()
+end subroutine
+
+! Test declare variant with user={condition(expr)} only.
+module m_condition
+  implicit none
+  logical :: use_variant = .true.
+contains
+  subroutine foo_variant_cond()
+  end subroutine
+  subroutine foo_cond()
+    !$omp declare variant(foo_variant_cond) match(user={condition(use_variant)})
+  end subroutine
+end module m_condition
+
+! The directives below match the body of foo_cond itself (from m_condition):
+! a true condition calls the variant and returns, otherwise the base body
+! runs. test_condition only provides a caller.
+! CHECK-LABEL: define void {{.*}}foo_cond()
+! CHECK:         br i1 %{{.*}}, label %[[THEN:.*]], label %[[ELSE:.*]]
+! CHECK:       [[THEN]]:
+! CHECK:         call void {{.*}}foo_variant_cond()
+! CHECK:         ret void
+! CHECK:       [[ELSE]]:
+! CHECK:         ret void
+subroutine test_condition()
+  use m_condition
+  call foo_cond()
+end subroutine
+
+! Test declare variant with both construct={dispatch} and user={condition(expr)}.
+module m_both
+  implicit none
+  logical :: use_variant_both = .true.
+contains
+  subroutine bar_variant()
+  end subroutine
+  subroutine bar()
+    !$omp declare variant(bar_variant) &
+    !$omp& match(construct={dispatch}, user={condition(use_variant_both)})
+  end subroutine
+end module m_both
+
+! Inside DISPATCH the call site branches between variant and base at run time.
+! CHECK-LABEL: define void {{.*}}test_dispatch_both{{.*}}
+! CHECK:         br i1 %{{.*}}, label %[[THEN2:.*]], label %[[ELSE2:.*]]
+! CHECK:       [[THEN2]]:
+! CHECK:         call void {{.*}}bar_variant()
+! CHECK:       [[ELSE2]]:
+! CHECK:         call void {{.*}}bar()
+subroutine test_dispatch_both()
+  use m_both
+  !$omp dispatch
+  call bar()
+  !$omp end dispatch
+end subroutine
+
+! Verify that without dispatch, the base call is preserved.
+! CHECK-LABEL: define void {{.*}}test_no_dispatch_both{{.*}}
+! CHECK:         call void {{.*}}bar()
+! CHECK-NOT:     call void {{.*}}bar_variant
+subroutine test_no_dispatch_both()
+  use m_both
+  call bar()
+end subroutine
diff --git a/flang/test/Lower/OpenMP/Todo/declare-variant.f90 b/flang/test/Lower/OpenMP/Todo/declare-variant.f90
deleted file mode 100644
index 5719ef3afdee1..0000000000000
--- a/flang/test/Lower/OpenMP/Todo/declare-variant.f90
+++ /dev/null
@@ -1,17 +0,0 @@
-! RUN: %not_todo_cmd %flang_fc1 -emit-fir -fopenmp -fopenmp-version=51 -o - %s 2>&1 | FileCheck %s
-
-! CHECK: not yet implemented: OmpDeclareVariantDirective
-
-subroutine sb1
-  integer :: x
-  x = 1
-  call sub(x)
-contains
-  subroutine vsub (v1)
-    integer, value :: v1
-  end
-  subroutine sub (v1)
-    !$omp declare variant(vsub), match(construct={dispatch})
-    integer, value :: v1
-  end
-end subroutine
diff --git a/flang/test/Lower/OpenMP/declare-variant-construct.f90 b/flang/test/Lower/OpenMP/declare-variant-construct.f90
new file mode 100644
index 0000000000000..7430e00b25536
--- /dev/null
+++ b/flang/test/Lower/OpenMP/declare-variant-construct.f90
@@ -0,0 +1,137 @@
+! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=52 -o - %s | FileCheck %s
+
+! Variant resolution replaces base procedure calls with variant calls
+! inside matching OpenMP construct regions, and guards variants with a
+! user={condition(...)} selector by a run-time check.
+
+! Base procedures whose variants are selected by a construct trait.
+module m
+contains
+  subroutine variant_parallel()
+  end subroutine
+  subroutine base_parallel()
+    !$omp declare variant(variant_parallel) match(construct={parallel})
+  end subroutine
+  subroutine variant_teams()
+  end subroutine
+  subroutine base_teams()
+    !$omp declare variant(variant_teams) match(construct={teams})
+  end subroutine
+end module m
+
+! Inside PARALLEL the call is replaced by the variant call.
+! CHECK-LABEL: func.func {{.*}}test_parallel
+! CHECK:         omp.parallel {
+! CHECK:           fir.call {{.*}}variant_parallel()
+! CHECK-NOT:       fir.call {{.*}}base_parallel
+! CHECK:           omp.terminator
+! CHECK:         }
+subroutine test_parallel()
+  use m
+  !$omp parallel
+    call base_parallel()
+  !$omp end parallel
+end subroutine
+
+! Inside TEAMS the call is replaced by the variant call.
+! CHECK-LABEL: func.func {{.*}}test_teams
+! CHECK:         omp.teams {
+! CHECK:           fir.call {{.*}}variant_teams()
+! CHECK-NOT:       fir.call {{.*}}base_teams
+! CHECK:           omp.terminator
+! CHECK:         }
+subroutine test_teams()
+  use m
+  !$omp teams
+    call base_teams()
+  !$omp end teams
+end subroutine
+
+! Calls outside any OpenMP construct keep the base procedure.
+
+! CHECK-LABEL: func.func {{.*}}test_no_construct
+! CHECK:         fir.call {{.*}}base_parallel()
+! CHECK:         fir.call {{.*}}base_teams()
+! CHECK-NOT:     fir.call {{.*}}variant
+subroutine test_no_construct()
+  use m
+  call base_parallel()
+  call base_teams()
+end subroutine
+
+! Test base:variant form with construct={parallel}
+! CHECK-LABEL: func.func {{.*}}test_base_variant_form
+! CHECK:         omp.parallel {
+! CHECK:           fir.call {{.*}}test_base_variant_form{{.*}}variant_sub()
+! CHECK:           omp.terminator
+! CHECK:         }
+subroutine test_base_variant_form()
+  !$omp declare variant(base_sub:variant_sub) match(construct={parallel})
+  !$omp parallel
+    call base_sub()
+  !$omp end parallel
+contains
+  subroutine variant_sub()
+  end subroutine
+  subroutine base_sub()
+  end subroutine
+end subroutine
+
+! Test declare variant with user={condition(expr)} only.
+! The base procedure body should contain a conditional branch:
+! if condition is true, call variant and return; else execute base body.
+module m_cond
+  implicit none
+  logical :: use_variant = .true.
+contains
+  subroutine foo_variant_cond()
+  end subroutine
+  subroutine foo_cond()
+    !$omp declare variant(foo_variant_cond) match(user={condition(use_variant)})
+  end subroutine
+end module m_cond
+
+! The directives below match the body of foo_cond (from m_cond);
+! test_condition only provides a caller.
+! CHECK-LABEL: func.func {{.*}}foo_cond{{.*}}
+! CHECK:         %[[COND:.*]] = fir.convert %{{.*}} : (!fir.logical<4>) -> i1
+! CHECK:         cf.cond_br %[[COND]], ^[[THEN:.*]], ^[[ELSE:.*]]
+! CHECK:       ^[[THEN]]:
+! CHECK:         fir.call {{.*}}foo_variant_cond()
+! CHECK:         return
+! CHECK:       ^[[ELSE]]:
+subroutine test_condition()
+  use m_cond
+  call foo_cond()
+end subroutine
+
+! Test declare variant with both construct={dispatch} and user={condition(expr)}.
+! Inside !$omp dispatch, the call site should generate a conditional
+! if/else selecting between variant and base calls.
+module m_both
+  implicit none
+  logical :: use_variant_both = .true.
+contains
+  subroutine bar_variant()
+  end subroutine
+  subroutine bar()
+    !$omp declare variant(bar_variant) &
+    !$omp& match(construct={dispatch}, user={condition(use_variant_both)})
+  end subroutine
+end module m_both
+
+! CHECK-LABEL: func.func {{.*}}test_dispatch_both
+! CHECK:         fir.if %{{.*}} {
+! CHECK:           fir.call {{.*}}bar_variant()
+! CHECK:         } else {
+! CHECK:           fir.call {{.*}}bar()
+! CHECK:         }
+subroutine test_dispatch_both()
+  use m_both
+  !$omp dispatch
+  call bar()
+  !$omp end dispatch
+end subroutine
+
+! Verify that without dispatch, the base call is preserved.
+! CHECK-LABEL: func.func {{.*}}test_no_dispatch_both
+! CHECK:         fir.call {{.*}}bar()
+! CHECK-NOT:     fir.call {{.*}}bar_variant
+subroutine test_no_dispatch_both()
+  use m_both
+  call bar()
+end subroutine
diff --git a/flang/test/Semantics/OpenMP/declare-variant-match.f90 b/flang/test/Semantics/OpenMP/declare-variant-match.f90
index 199d05f1750ee..4a98ff091158f 100644
--- a/flang/test/Semantics/OpenMP/declare-variant-match.f90
+++ b/flang/test/Semantics/OpenMP/declare-variant-match.f90
@@ -76,7 +76,6 @@ subroutine sub
 
 subroutine f06(x)
   integer :: x
-!ERROR: Run-time USER condition in the MATCH clause is not yet implemented
   !$omp declare variant (sub:vsub) match (user={condition(x > 0)})
 contains
   subroutine vsub



More information about the flang-commits mailing list