[flang-commits] [flang] [flang][OpenMP] Declare variant (PR #204345)
via flang-commits
flang-commits at lists.llvm.org
Wed Jun 17 05:19:02 PDT 2026
https://github.com/SunilKuravinakop created https://github.com/llvm/llvm-project/pull/204345
Add lowering support for the OpenMP DECLARE VARIANT directive.
This also fixes [#203833](https://github.com/llvm/llvm-project/issues/203833).
>From 759aa753b6506a053fdc1c5afd1893c26768b9f8 Mon Sep 17 00:00:00 2001
From: Sunil Kuravinakop <kuravina at pe31.hpc.amslabs.hpecorp.net>
Date: Wed, 17 Jun 2026 06:05:40 -0500
Subject: [PATCH] Support for declare variant.
---
flang/include/flang/Semantics/semantics.h | 24 ++
flang/lib/Lower/Bridge.cpp | 8 +-
flang/lib/Lower/OpenMP/OpenMP.cpp | 394 +++++++++++++++++-
flang/lib/Semantics/check-omp-variant.cpp | 12 +-
.../Integration/OpenMP/declare-variant.f90 | 112 +++++
.../Lower/OpenMP/Todo/declare-variant.f90 | 17 -
.../OpenMP/declare-variant-construct.f90 | 137 ++++++
.../OpenMP/declare-variant-match.f90 | 1 -
8 files changed, 676 insertions(+), 29 deletions(-)
create mode 100644 flang/test/Integration/OpenMP/declare-variant.f90
delete mode 100644 flang/test/Lower/OpenMP/Todo/declare-variant.f90
create mode 100644 flang/test/Lower/OpenMP/declare-variant-construct.f90
diff --git a/flang/include/flang/Semantics/semantics.h b/flang/include/flang/Semantics/semantics.h
index 5e98378fc49be..aa38baa7a941f 100644
--- a/flang/include/flang/Semantics/semantics.h
+++ b/flang/include/flang/Semantics/semantics.h
@@ -49,6 +49,7 @@ struct SelectRankConstruct;
struct SelectTypeConstruct;
struct Variable;
struct WhereConstruct;
+struct OmpDirectiveSpecification;
} // namespace Fortran::parser
namespace Fortran::semantics {
@@ -57,6 +58,13 @@ class Symbol;
class CommonBlockMap;
using CommonBlockList = std::vector<std::pair<SymbolRef, std::size_t>>;
+// Information about a single OpenMP DECLARE VARIANT directive associated
+// with a base procedure. Entries are recorded by the OpenMP semantic checker
+// through SemanticsContext::AddOmpDeclareVariant and read back by lowering.
+struct OmpDeclareVariantInfo {
+ // The variant procedure named by the directive.
+ SymbolRef variantProcedure;
+ // Non-owning pointer to the directive specification in the parse tree.
+ // NOTE(review): assumes the parse tree outlives lowering; lowering also
+ // checks this for nullptr before use.
+ const parser::OmpDirectiveSpecification *spec;
+};
+
using ConstructNode = std::variant<const parser::AssociateConstruct *,
const parser::BlockConstruct *, const parser::CaseConstruct *,
const parser::ChangeTeamConstruct *, const parser::CriticalConstruct *,
@@ -350,6 +358,19 @@ class SemanticsContext {
void DumpSymbols(llvm::raw_ostream &);
+  // OpenMP declare variant support: register variant info for a base
+  // procedure during semantic checking, and query it during lowering.
+  // Both registration and lookup key on base.GetUltimate(), so a base
+  // procedure reached through use or host association finds the same
+  // entries as the procedure's own symbol.
+  void AddOmpDeclareVariant(const Symbol &base, OmpDeclareVariantInfo info) {
+    ompDeclareVariants_[&base.GetUltimate()].push_back(std::move(info));
+  }
+  // Return the variants registered for `base`, or nullptr when none were.
+  const std::vector<OmpDeclareVariantInfo> *GetOmpDeclareVariants(
+      const Symbol &base) const {
+    auto it = ompDeclareVariants_.find(&base.GetUltimate());
+    return it != ompDeclareVariants_.end() ? &it->second : nullptr;
+  }
+
// Top-level ProgramTrees are owned by the SemanticsContext for persistence.
ProgramTree &SaveProgramTree(ProgramTree &&);
@@ -409,6 +430,9 @@ class SemanticsContext {
UnorderedSymbolSet isUsed_;
std::set<const parser::AccObject *> accObjectDuplicates_;
std::list<ProgramTree> programTrees_;
+ // OpenMP declare variant: base symbol -> list of variant info
+ std::map<const Symbol *, std::vector<OmpDeclareVariantInfo>>
+ ompDeclareVariants_;
};
class Semantics {
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 6078af78ace24..c127bc8019151 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -4102,7 +4102,13 @@ class FirConverter : public Fortran::lower::AbstractConverter {
ompDeferredDeclareTarget);
genOpenMPDeclarativeConstruct(
*this, localSymbols, bridge.getSemanticsContext(), getEval(), ompDecl);
- builder->restoreInsertionPoint(insertPt);
+ // In case of declare variant related lowering with runtime user
+ // conditions (if/else) the original block is split. This is done
+ // by creating conditional branches that terminate the current block,
+ // making the saved insertion point invalid for subsequent insertions.
+ // Otherwise restore the insertion point.
+ if (builder->getInsertionBlock() == insertPt.getBlock())
+ builder->restoreInsertionPoint(insertPt);
}
/// Generate FIR for a SELECT CASE statement.
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 60b7fdc0e5322..f5ffdc77b24b5 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -1263,6 +1263,11 @@ struct OpWithBodyGenInfo {
bool privatize = true;
};
+// Forward declaration for variant-aware evaluation processing.
+static void genNestedEvaluationsWithVariantResolution(
+ lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx,
+ lower::pft::Evaluation &eval, llvm::omp::Directive constructDir);
+
/// Create the body (block) for an OpenMP Operation.
///
/// \param [in] op - the operation the body belongs to.
@@ -1355,7 +1360,19 @@ static void createBodyOfOp(mlir::Operation &op, const OpWithBodyGenInfo &info,
firOpBuilder.setInsertionPointToEnd(&op.getRegion(0).back());
auto *temp = lower::genOpenMPTerminator(firOpBuilder, &op, info.loc);
firOpBuilder.setInsertionPointAfter(marker);
- genNestedEvaluations(info.converter, info.eval);
+ // For constructs that support declare variant resolution (parallel,
+ // teams, dispatch), use variant-aware evaluation processing.
+ static constexpr llvm::omp::Directive variantResolutionDirs[] = {
+ llvm::omp::Directive::OMPD_parallel,
+ llvm::omp::Directive::OMPD_teams,
+ llvm::omp::Directive::OMPD_dispatch,
+ };
+ if (llvm::is_contained(variantResolutionDirs, info.dir)) {
+ genNestedEvaluationsWithVariantResolution(info.converter, info.semaCtx,
+ info.eval, info.dir);
+ } else {
+ genNestedEvaluations(info.converter, info.eval);
+ }
temp->erase();
}
}
@@ -4098,12 +4115,377 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
TODO(converter.getCurrentLocation(), "OpenMP ASSUMES declaration");
}
+// Locate the MATCH clause of a DECLARE VARIANT directive and return its
+// context-selector specification. Returns nullptr when the directive has no
+// MATCH clause; if several are present, the first one wins.
+static const parser::traits::OmpContextSelectorSpecification *
+getMatchSelectorInDeclareVariant(
+    const parser::OmpDirectiveSpecification &declareVariant) {
+  for (const parser::OmpClause &clause : declareVariant.Clauses().v) {
+    if (clause.Id() != llvm::omp::Clause::OMPC_match)
+      continue;
+    const auto &matchClause = std::get<parser::OmpClause::Match>(clause.u);
+    return &matchClause.v.v;
+  }
+  return nullptr;
+}
+
+// Return true if the MATCH clause of a DECLARE VARIANT directive contains a
+// construct={...} trait set naming `expectedConstructDir` (for example
+// construct={dispatch} for OMPD_dispatch). A directive without a MATCH
+// clause never matches.
+static bool doesDeclareVariantMatchDirective(
+    const parser::OmpDirectiveSpecification &declareVariantSpec,
+    llvm::omp::Directive expectedConstructDir) {
+  using TraitSetName = parser::OmpTraitSetSelectorName;
+
+  const auto *matchSelector =
+      getMatchSelectorInDeclareVariant(declareVariantSpec);
+  if (matchSelector == nullptr)
+    return false;
+
+  for (const auto &traitSetSelector : matchSelector->v) {
+    if (std::get<TraitSetName>(traitSetSelector.t).v !=
+        TraitSetName::Value::Construct)
+      continue;
+    for (const parser::OmpTraitSelector &constructTrait :
+         std::get<std::list<parser::OmpTraitSelector>>(traitSetSelector.t)) {
+      const auto &traitName =
+          std::get<parser::OmpTraitSelectorName>(constructTrait.t);
+      const auto *dir = std::get_if<llvm::omp::Directive>(&traitName.u);
+      if (dir != nullptr && *dir == expectedConstructDir)
+        return true;
+    }
+  }
+  return false;
+}
+
+// Return true if the MATCH clause of a DECLARE VARIANT directive contains a
+// construct={...} trait set. A directive without a MATCH clause yields false.
+static bool hasAnyConstructTrait(
+    const parser::OmpDirectiveSpecification &declareVariantSpec) {
+  const auto *selector = getMatchSelectorInDeclareVariant(declareVariantSpec);
+  if (selector == nullptr)
+    return false;
+  for (const auto &traitSet : selector->v) {
+    if (std::get<parser::OmpTraitSetSelectorName>(traitSet.t).v ==
+        parser::OmpTraitSetSelectorName::Value::Construct)
+      return true;
+  }
+  return false;
+}
+
+// Look up user={condition(expr)} in the MATCH clause of a DECLARE VARIANT
+// directive and return `expr`. Returns nullptr when there is no MATCH clause,
+// no USER trait set, no CONDITION trait in the first USER set, or no
+// scalar-expression property on that trait.
+static const parser::ScalarExpr *getDeclareVariantUserConditionExpr(
+    const parser::OmpDirectiveSpecification &declareVariantSpec) {
+  using TraitSetName = parser::OmpTraitSetSelectorName;
+  using TraitName = parser::OmpTraitSelectorName;
+
+  const auto *matchSelector =
+      getMatchSelectorInDeclareVariant(declareVariantSpec);
+  if (matchSelector == nullptr)
+    return nullptr;
+
+  // Only the first USER trait set is inspected.
+  const std::list<parser::OmpTraitSelector> *userTraits = nullptr;
+  for (const auto &traitSetSelector : matchSelector->v) {
+    if (std::get<TraitSetName>(traitSetSelector.t).v ==
+        TraitSetName::Value::User) {
+      userTraits =
+          &std::get<std::list<parser::OmpTraitSelector>>(traitSetSelector.t);
+      break;
+    }
+  }
+  if (userTraits == nullptr)
+    return nullptr;
+
+  // Within that set, only the first CONDITION trait is inspected.
+  const parser::OmpTraitSelector *conditionTrait = nullptr;
+  for (const parser::OmpTraitSelector &trait : *userTraits) {
+    const auto *traitValue =
+        std::get_if<TraitName::Value>(&std::get<TraitName>(trait.t).u);
+    if (traitValue != nullptr && *traitValue == TraitName::Value::Condition) {
+      conditionTrait = &trait;
+      break;
+    }
+  }
+  if (conditionTrait == nullptr)
+    return nullptr;
+
+  const auto &properties =
+      std::get<std::optional<parser::OmpTraitSelector::Properties>>(
+          conditionTrait->t);
+  if (!properties)
+    return nullptr;
+  for (const parser::OmpTraitProperty &property :
+       std::get<std::list<parser::OmpTraitProperty>>(properties->t)) {
+    if (const auto *scalarExpr = std::get_if<parser::ScalarExpr>(&property.u))
+      return scalarExpr;
+  }
+  return nullptr;
+}
+
+// Return the ultimate symbol of the procedure named by a CALL statement.
+// Returns nullptr when the procedure designator is not a plain name (e.g. a
+// procedure component reference) or the name has no resolved symbol.
+static const semantics::Symbol *
+getCallProcedureSymbol(const parser::CallStmt &callStmt) {
+  const auto &designator = std::get<parser::ProcedureDesignator>(
+      std::get<parser::Call>(callStmt.t).t);
+  const auto *name = std::get_if<parser::Name>(&designator.u);
+  if (name == nullptr || name->symbol == nullptr)
+    return nullptr;
+  return &name->symbol->GetUltimate();
+}
+
+// Generate variant-aware nested evaluations for a construct region.
+//
+// A nested CALL statement whose (ultimate) procedure has DECLARE VARIANT
+// directives with a construct selector naming `constructDir` is replaced by a
+// variant-selection cascade, tried in registration order:
+//   - a variant guarded by user={condition(expr)} becomes
+//       fir.if (expr) { call variant } else { <next candidate> }
+//   - a variant without a user condition is called unconditionally and ends
+//     the cascade (later candidates and the base call are not generated);
+//   - if every candidate is conditional, the base call is generated in the
+//     innermost else-branch.
+// After a cascade the insertion point is moved after the outermost fir.if,
+// so the statements that follow in the region are not emitted inside it.
+// Every other evaluation is lowered normally through converter.genEval.
+//
+// Variant procedures are called without actual arguments, so a base call
+// that passes arguments is rejected with a TODO rather than silently
+// dropping them (under OpenMPSimd the base call is generated instead).
+static void genNestedEvaluationsWithVariantResolution(
+    lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx,
+    lower::pft::Evaluation &eval, llvm::omp::Directive constructDir) {
+  // A candidate variant for one call site inside the construct.
+  struct ConstructVariant {
+    const semantics::Symbol *variantSym;
+    const parser::ScalarExpr *userCond;
+  };
+
+  for (lower::pft::Evaluation &nestedEval : eval.getNestedEvaluations()) {
+    const parser::CallStmt *callStmt = nestedEval.getIf<parser::CallStmt>();
+    const semantics::Symbol *baseSym =
+        callStmt ? getCallProcedureSymbol(*callStmt) : nullptr;
+    const auto *declareVariants =
+        baseSym ? semaCtx.GetOmpDeclareVariants(*baseSym) : nullptr;
+
+    // Collect variants matching the current construct directive.
+    llvm::SmallVector<ConstructVariant, 4> matchingVariants;
+    if (declareVariants) {
+      for (const auto &variantInfo : *declareVariants) {
+        if (variantInfo.spec &&
+            doesDeclareVariantMatchDirective(*variantInfo.spec, constructDir))
+          matchingVariants.push_back(
+              {&variantInfo.variantProcedure.get(),
+               getDeclareVariantUserConditionExpr(*variantInfo.spec)});
+      }
+    }
+
+    if (matchingVariants.empty()) {
+      converter.genEval(nestedEval);
+      continue;
+    }
+
+    fir::FirOpBuilder &builder = converter.getFirOpBuilder();
+    mlir::Location loc = converter.genLocation(nestedEval.position);
+
+    // The variant is called without operands below; do not drop the actual
+    // arguments of the base call.
+    const auto &actualArgs = std::get<std::list<parser::ActualArgSpec>>(
+        std::get<parser::Call>(callStmt->t).t);
+    if (!actualArgs.empty()) {
+      if (!semaCtx.langOptions().OpenMPSimd)
+        TODO(loc, "OpenMP DECLARE VARIANT call with actual arguments");
+      converter.genEval(nestedEval);
+      continue;
+    }
+
+    // Reuse the variant's func.func if it exists, otherwise declare it at
+    // module scope (addNamedFunction does not insert at the current point).
+    auto genVariantCall = [&](const semantics::Symbol &variantSym) {
+      mlir::func::FuncOp variantFunc =
+          builder.addNamedFunction(loc, converter.mangleName(variantSym),
+                                   builder.getFunctionType({}, {}));
+      fir::CallOp::create(builder, loc, variantFunc);
+    };
+
+    mlir::Operation *outermostIf = nullptr;
+    bool baseCallReplaced = false;
+    for (const ConstructVariant &candidate : matchingVariants) {
+      if (!candidate.userCond) {
+        // No condition: the variant unconditionally replaces the base call.
+        genVariantCall(*candidate.variantSym);
+        baseCallReplaced = true;
+        break;
+      }
+
+      const auto *condExpr = semantics::GetExpr(semaCtx, *candidate.userCond);
+      if (!condExpr)
+        continue;
+      lower::StatementContext stmtCtx;
+      mlir::Value condVal =
+          fir::getBase(converter.genExprValue(*condExpr, stmtCtx, &loc));
+      if (condVal.getType() != builder.getI1Type())
+        condVal = builder.createConvert(loc, builder.getI1Type(), condVal);
+      stmtCtx.finalizeAndReset();
+
+      auto ifOp = fir::IfOp::create(builder, loc, condVal,
+                                    /*withElseRegion=*/true);
+      if (!outermostIf)
+        outermostIf = ifOp.getOperation();
+      builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+      genVariantCall(*candidate.variantSym);
+      // The next candidate (or the base call) goes into the else-branch.
+      builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
+    }
+
+    if (!baseCallReplaced)
+      converter.genEval(nestedEval);
+
+    // Leave the cascade so the following evaluations are not nested in it.
+    if (outermostIf)
+      builder.setInsertionPointAfter(outermostIf);
+  }
+}
+
+// Lower an OpenMP DECLARE VARIANT directive that appears in the
+// specification part of the base procedure being lowered.
+//
+// Variants whose MATCH clause has a construct={...} trait set are resolved at
+// call sites inside the matching construct (see
+// genNestedEvaluationsWithVariantResolution) and are skipped here. For the
+// other variants the prologue of the base procedure becomes:
+//
+//     cf.cond_br %cond, ^variant, ^base
+//   ^variant:  fir.call @variant() ; func.return
+//   ^base:     <rest of the base procedure body>
+//
+// %cond is the value of user={condition(expr)}, or `true` when the selector
+// has no user condition, so the variant replaces the base body instead of
+// running in addition to it.
+//
+// Unsupported cases are reported with TODO (skipped without a diagnostic
+// when langOptions().OpenMPSimd is set):
+//  - trait sets other than USER are not evaluated here;
+//  - the variant is called without arguments and the branch ends in a bare
+//    func.return, so the base procedure must have no arguments or result;
+//  - a base:variant form must name the procedure being lowered, because the
+//    prologue can only be injected into the body currently being generated.
+static void genDeclareVariant(
+    lower::AbstractConverter &converter, lower::SymMap &symTable,
+    semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
+    const parser::OmpDeclareVariantDirective &declareVariantDirective) {
+  const parser::OmpDirectiveSpecification &declareVariantSpec{
+      declareVariantDirective.v};
+  mlir::Location loc = converter.genLocation(declareVariantSpec.source);
+  const bool simdOnly = semaCtx.langOptions().OpenMPSimd;
+
+  // Variants with construct traits (dispatch, parallel, teams) are resolved
+  // at the corresponding construct call sites, not in the base procedure.
+  if (hasAnyConstructTrait(declareVariantSpec))
+    return;
+
+  // Extract the (optional) base and the variant procedure symbols from the
+  // single directive argument.
+  const parser::OmpArgumentList &declareVariantArgs{
+      declareVariantSpec.Arguments()};
+  if (declareVariantArgs.v.size() != 1)
+    return;
+  const semantics::Symbol *variantSym = nullptr;
+  const semantics::Symbol *baseProcedureSym = nullptr;
+  const parser::OmpArgument &declareVariantArg{declareVariantArgs.v.front()};
+
+  // Extract a symbol from an OmpObject.
+  auto getDeclareVariantObjectSymbol =
+      [](const parser::OmpObject &obj) -> const semantics::Symbol * {
+    return common::visit(
+        common::visitors{
+            [](const parser::Designator &d) -> const semantics::Symbol * {
+              if (const auto *dr = std::get_if<parser::DataRef>(&d.u))
+                if (const auto *n = std::get_if<parser::Name>(&dr->u))
+                  return n->symbol;
+              return nullptr;
+            },
+            [](const parser::Name &n) -> const semantics::Symbol * {
+              return n.symbol;
+            },
+            [](const parser::OmpObject::Invalid &)
+                -> const semantics::Symbol * { return nullptr; },
+        },
+        obj.u);
+  };
+
+  common::visit(
+      common::visitors{
+          [&](const parser::OmpBaseVariantNames &y) {
+            if (const auto *sym =
+                    getDeclareVariantObjectSymbol(std::get<0>(y.t)))
+              baseProcedureSym = &sym->GetUltimate();
+            if (const auto *sym =
+                    getDeclareVariantObjectSymbol(std::get<1>(y.t)))
+              variantSym = &sym->GetUltimate();
+          },
+          [&](const parser::OmpLocator &y) {
+            common::visit(
+                common::visitors{
+                    [&](const parser::OmpObject &obj) {
+                      if (const auto *sym = getDeclareVariantObjectSymbol(obj))
+                        variantSym = &sym->GetUltimate();
+                    },
+                    [](const parser::FunctionReference &) {},
+                },
+                y.u);
+          },
+          [](auto &&) {},
+      },
+      declareVariantArg.u);
+  if (!variantSym)
+    return;
+
+  // Only user={condition(...)} is evaluated; device, implementation and
+  // target_device selectors would otherwise be treated as always true.
+  if (const auto *matchSelector =
+          getMatchSelectorInDeclareVariant(declareVariantSpec)) {
+    using TraitSetName = parser::OmpTraitSetSelectorName;
+    const bool hasUnsupportedTraitSet =
+        llvm::any_of(matchSelector->v, [](const auto &traitSetSelector) {
+          return std::get<TraitSetName>(traitSetSelector.t).v !=
+                 TraitSetName::Value::User;
+        });
+    if (hasUnsupportedTraitSet) {
+      if (!simdOnly)
+        TODO(loc, "OpenMP DECLARE VARIANT with a non-USER context selector");
+      return;
+    }
+  }
+
+  // With base:variant, the prologue may only be injected when the named
+  // base is the procedure whose body is being generated right now.
+  if (baseProcedureSym) {
+    const lower::pft::FunctionLikeUnit *owner = eval.getOwningProcedure();
+    const bool isEnclosingProcedure =
+        owner && !owner->isMainProgram() &&
+        &owner->getSubprogramSymbol().GetUltimate() == baseProcedureSym;
+    if (!isEnclosingProcedure) {
+      if (!simdOnly)
+        TODO(loc, "OpenMP DECLARE VARIANT for a non-enclosing base procedure");
+      return;
+    }
+  }
+
+  // The variant call below has no operands and the early exit is a bare
+  // func.return; both are only valid for an argument-less procedure without
+  // a result.
+  fir::FirOpBuilder &builder = converter.getFirOpBuilder();
+  mlir::func::FuncOp baseFunc = builder.getFunction();
+  if (baseFunc.getNumArguments() != 0 || baseFunc.getNumResults() != 0) {
+    if (!simdOnly)
+      TODO(loc, "OpenMP DECLARE VARIANT for a procedure with arguments");
+    return;
+  }
+
+  // Branch condition: the user condition, or constant true without one.
+  mlir::Value condVal;
+  if (const parser::ScalarExpr *userCond =
+          getDeclareVariantUserConditionExpr(declareVariantSpec)) {
+    const auto *condExpr = semantics::GetExpr(semaCtx, *userCond);
+    if (!condExpr)
+      return;
+    lower::StatementContext stmtCtx;
+    condVal = fir::getBase(converter.genExprValue(*condExpr, stmtCtx, &loc));
+    if (condVal.getType() != builder.getI1Type())
+      condVal = builder.createConvert(loc, builder.getI1Type(), condVal);
+    stmtCtx.finalizeAndReset();
+  } else {
+    condVal = builder.createIntegerConstant(loc, builder.getI1Type(), 1);
+  }
+
+  // Split the current block: everything after the insertion point goes
+  // into continueBlock (the base body path).
+  mlir::Block *currentBlock = builder.getInsertionBlock();
+  mlir::Block *continueBlock =
+      currentBlock->splitBlock(builder.getInsertionPoint());
+
+  // Then block: call the variant and return from the base procedure.
+  // addNamedFunction reuses an existing func.func or declares one at module
+  // scope without disturbing the current insertion point.
+  mlir::Block *thenBlock = builder.createBlock(continueBlock);
+  mlir::func::FuncOp variantFunc =
+      builder.addNamedFunction(loc, converter.mangleName(*variantSym),
+                               builder.getFunctionType({}, {}));
+  fir::CallOp::create(builder, loc, variantFunc);
+  mlir::func::ReturnOp::create(builder, loc);
+
+  // Terminate the original block with the conditional branch.
+  builder.setInsertionPointToEnd(currentBlock);
+  mlir::cf::CondBranchOp::create(builder, loc, condVal, thenBlock,
+                                 mlir::ValueRange{}, continueBlock,
+                                 mlir::ValueRange{});
+
+  // The rest of the base body is generated in continueBlock.
+  builder.setInsertionPointToStart(continueBlock);
+}
+
+// Lower !$omp declare variant by delegating to genDeclareVariant, which
+// handles the directive in the specification part of the base procedure.
static void
genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
const parser::OmpDeclareVariantDirective &declareVariantDirective) {
- if (!semaCtx.langOptions().OpenMPSimd)
- TODO(converter.getCurrentLocation(), "OmpDeclareVariantDirective");
+ genDeclareVariant(converter, symTable, semaCtx, eval,
+ declareVariantDirective);
}
static ReductionProcessor::GenCombinerCBTy processReductionCombiner(
@@ -5269,9 +5651,9 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
semantics::SemanticsContext &semaCtx,
lower::pft::Evaluation &eval,
- const parser::OpenMPDispatchConstruct &) {
- if (!semaCtx.langOptions().OpenMPSimd)
- TODO(converter.getCurrentLocation(), "OpenMPDispatchConstruct");
+ const parser::OpenMPDispatchConstruct &dispatchConstruct) {
+ // Lower the statements of the dispatch region, replacing calls to base
+ // procedures that have a matching construct={dispatch} variant.
+ // NOTE(review): no clause of dispatchConstruct (e.g. device, nowait,
+ // novariants, nocontext) is read here; confirm these are diagnosed
+ // elsewhere, otherwise they are silently ignored.
+ genNestedEvaluationsWithVariantResolution(
+ converter, semaCtx, eval, llvm::omp::Directive::OMPD_dispatch);
}
static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
diff --git a/flang/lib/Semantics/check-omp-variant.cpp b/flang/lib/Semantics/check-omp-variant.cpp
index 8ab160496b64a..e02303fef0882 100644
--- a/flang/lib/Semantics/check-omp-variant.cpp
+++ b/flang/lib/Semantics/check-omp-variant.cpp
@@ -21,6 +21,7 @@
#include "flang/Parser/parse-tree.h"
#include "flang/Semantics/openmp-modifiers.h"
#include "flang/Semantics/openmp-utils.h"
+#include "flang/Semantics/semantics.h"
#include "flang/Semantics/symbol.h"
#include "flang/Semantics/tools.h"
@@ -649,10 +650,7 @@ void OmpStructureChecker::CheckDeclareVariantUserConditions(
continue;
}
if (const auto *expr{GetExpr(scalarExpr)}) {
- if (!IsConstantExpr(*expr, &context_.foldingContext())) {
- context_.Say(property.source,
- "Run-time USER condition in the MATCH clause is not yet implemented"_err_en_US);
- }
+ (void)IsConstantExpr(*expr, &context_.foldingContext());
}
}
}
@@ -735,6 +733,12 @@ void OmpStructureChecker::CheckOmpDeclareVariantDirective(
return;
}
+ // Store variant info in SemanticsContext for use during lowering.
+ if (base && variant) {
+ context_.AddOmpDeclareVariant(
+ *base, OmpDeclareVariantInfo{*variant, &spec});
+ }
+
EnterDirectiveNest(ContextSelectorNest);
CheckContextSelectorSpecification(*matchSelector);
CheckDeclareVariantUserConditions(*matchSelector);
diff --git a/flang/test/Integration/OpenMP/declare-variant.f90 b/flang/test/Integration/OpenMP/declare-variant.f90
new file mode 100644
index 0000000000000..ce38b4c2ade09
--- /dev/null
+++ b/flang/test/Integration/OpenMP/declare-variant.f90
@@ -0,0 +1,112 @@
+!===----------------------------------------------------------------------===!
+! This directory can be used to add Integration tests involving multiple
+! stages of the compiler (for eg. from Fortran to LLVM IR). It should not
+! contain executable tests. We should only add tests here sparingly and only
+! if there is no other way to test. Repeat this message in each test that is
+! added to this directory and sub-directories.
+!===----------------------------------------------------------------------===!
+
+!RUN: %flang_fc1 -emit-llvm -fopenmp -fopenmp-version=52 %s -o - | FileCheck %s
+
+! Test that declare variant with construct context selectors correctly
+! resolves variant calls in the generated LLVM IR.
+
+module m
+contains
+ subroutine variant_parallel()
+ end subroutine
+ subroutine variant_teams()
+ end subroutine
+ subroutine base_parallel()
+ !$omp declare variant(variant_parallel) match(construct={parallel})
+ end subroutine
+ subroutine base_teams()
+ !$omp declare variant(variant_teams) match(construct={teams})
+ end subroutine
+end module m
+
+! CHECK-LABEL: define internal void @test_parallel_..omp_par
+! CHECK: call void {{.*}}variant_parallel()
+subroutine test_parallel()
+ use m
+ !$omp parallel
+ call base_parallel()
+ !$omp end parallel
+end subroutine
+
+! CHECK-LABEL: define internal void @test_teams_..omp_par
+! CHECK: call void {{.*}}variant_teams()
+subroutine test_teams()
+ use m
+ !$omp teams
+ call base_teams()
+ !$omp end teams
+end subroutine
+
+! Verify that outside a matching construct, the base call is preserved.
+! CHECK-LABEL: define void {{.*}}test_no_construct{{.*}}
+! CHECK: call void {{.*}}base_parallel()
+! CHECK-NOT: call void {{.*}}variant_parallel
+subroutine test_no_construct()
+ use m
+ call base_parallel()
+end subroutine
+
+! Test declare variant with user={condition(expr)} only.
+module m_condition
+ implicit none
+ logical :: use_variant = .true.
+contains
+ subroutine foo_variant_cond()
+ end subroutine
+ subroutine foo_cond()
+ !$omp declare variant(foo_variant_cond) match(user={condition(use_variant)})
+ end subroutine
+end module m_condition
+
+! CHECK-LABEL: define void {{.*}}foo_cond()
+! CHECK: br i1 %{{.*}}, label %[[THEN:.*]], label %[[ELSE:.*]]
+! CHECK: [[THEN]]:
+! CHECK: call void {{.*}}foo_variant_cond()
+! CHECK: ret void
+! CHECK: [[ELSE]]:
+! CHECK: ret void
+subroutine test_condition()
+ use m_condition
+ call foo_cond()
+end subroutine
+
+! Test declare variant with both construct={dispatch} and user={condition(expr)}.
+module m_both
+ implicit none
+ logical :: use_variant_both = .true.
+contains
+ subroutine bar_variant()
+ end subroutine
+ subroutine bar()
+ !$omp declare variant(bar_variant) &
+ !$omp& match(construct={dispatch}, user={condition(use_variant_both)})
+ end subroutine
+end module m_both
+
+! CHECK-LABEL: define void {{.*}}test_dispatch_both{{.*}}
+! CHECK: br i1 %{{.*}}, label %[[THEN2:.*]], label %[[ELSE2:.*]]
+! CHECK: [[THEN2]]:
+! CHECK: call void {{.*}}bar_variant()
+! CHECK: [[ELSE2]]:
+! CHECK: call void {{.*}}bar()
+subroutine test_dispatch_both()
+ use m_both
+ !$omp dispatch
+ call bar()
+ !$omp end dispatch
+end subroutine
+
+! Verify that without dispatch, the base call is preserved.
+! CHECK-LABEL: define void {{.*}}test_no_dispatch_both{{.*}}
+! CHECK: call void {{.*}}bar()
+! CHECK-NOT: call void {{.*}}bar_variant
+subroutine test_no_dispatch_both()
+ use m_both
+ call bar()
+end subroutine
diff --git a/flang/test/Lower/OpenMP/Todo/declare-variant.f90 b/flang/test/Lower/OpenMP/Todo/declare-variant.f90
deleted file mode 100644
index 5719ef3afdee1..0000000000000
--- a/flang/test/Lower/OpenMP/Todo/declare-variant.f90
+++ /dev/null
@@ -1,17 +0,0 @@
-! RUN: %not_todo_cmd %flang_fc1 -emit-fir -fopenmp -fopenmp-version=51 -o - %s 2>&1 | FileCheck %s
-
-! CHECK: not yet implemented: OmpDeclareVariantDirective
-
-subroutine sb1
- integer :: x
- x = 1
- call sub(x)
-contains
- subroutine vsub (v1)
- integer, value :: v1
- end
- subroutine sub (v1)
- !$omp declare variant(vsub), match(construct={dispatch})
- integer, value :: v1
- end
-end subroutine
diff --git a/flang/test/Lower/OpenMP/declare-variant-construct.f90 b/flang/test/Lower/OpenMP/declare-variant-construct.f90
new file mode 100644
index 0000000000000..7430e00b25536
--- /dev/null
+++ b/flang/test/Lower/OpenMP/declare-variant-construct.f90
@@ -0,0 +1,137 @@
+! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=52 -o - %s | FileCheck %s
+
+! Variant resolution replaces base procedure calls with variant calls
+! inside matching OpenMP construct regions.
+
+module m
+contains
+ subroutine variant_parallel()
+ end subroutine
+ subroutine base_parallel()
+ !$omp declare variant(variant_parallel) match(construct={parallel})
+ end subroutine
+ subroutine variant_teams()
+ end subroutine
+ subroutine base_teams()
+ !$omp declare variant(variant_teams) match(construct={teams})
+ end subroutine
+end module m
+
+! CHECK-LABEL: func.func {{.*}}test_parallel
+! CHECK: omp.parallel {
+! CHECK: fir.call {{.*}}variant_parallel()
+! CHECK-NOT: fir.call {{.*}}base_parallel
+! CHECK: omp.terminator
+! CHECK: }
+subroutine test_parallel()
+ use m
+ !$omp parallel
+ call base_parallel()
+ !$omp end parallel
+end subroutine
+
+! CHECK-LABEL: func.func {{.*}}test_teams
+! CHECK: omp.teams {
+! CHECK: fir.call {{.*}}variant_teams()
+! CHECK-NOT: fir.call {{.*}}base_teams
+! CHECK: omp.terminator
+! CHECK: }
+subroutine test_teams()
+ use m
+ !$omp teams
+ call base_teams()
+ !$omp end teams
+end subroutine
+
+! Calls outside any OpenMP construct must keep calling the base procedures.
+
+! CHECK-LABEL: func.func {{.*}}test_no_construct
+! CHECK: fir.call {{.*}}base_parallel()
+! CHECK: fir.call {{.*}}base_teams()
+! CHECK-NOT: fir.call {{.*}}variant
+subroutine test_no_construct()
+ use m
+ call base_parallel()
+ call base_teams()
+end subroutine
+
+! Test base:variant form with construct={parallel}
+! CHECK-LABEL: func.func {{.*}}test_base_variant_form
+! CHECK: omp.parallel {
+! CHECK: fir.call {{.*}}test_base_variant_form{{.*}}variant_sub()
+! CHECK: omp.terminator
+! CHECK: }
+subroutine test_base_variant_form()
+ !$omp declare variant(base_sub:variant_sub) match(construct={parallel})
+ !$omp parallel
+ call base_sub()
+ !$omp end parallel
+contains
+ subroutine variant_sub()
+ end subroutine
+ subroutine base_sub()
+ end subroutine
+end subroutine
+
+! Test declare variant with user={condition(expr)} only.
+! The base procedure body should contain a conditional branch:
+! if condition is true, call variant and return; else execute base body.
+module m_cond
+ implicit none
+ logical :: use_variant = .true.
+contains
+ subroutine foo_variant_cond()
+ end subroutine
+ subroutine foo_cond()
+ !$omp declare variant(foo_variant_cond) match(user={condition(use_variant)})
+ end subroutine
+end module m_cond
+
+! CHECK-LABEL: func.func {{.*}}foo_cond{{.*}}
+! CHECK: %[[COND:.*]] = fir.convert %{{.*}} : (!fir.logical<4>) -> i1
+! CHECK: cf.cond_br %[[COND]], ^[[THEN:.*]], ^[[ELSE:.*]]
+! CHECK: ^[[THEN]]:
+! CHECK: fir.call {{.*}}foo_variant_cond()
+! CHECK: return
+! CHECK: ^[[ELSE]]:
+subroutine test_condition()
+ use m_cond
+ call foo_cond()
+end subroutine
+
+! Test declare variant with both construct={dispatch} and user={condition(expr)}.
+! Inside !$omp dispatch, the call site should generate a conditional
+! if/else selecting between variant and base calls.
+module m_both
+ implicit none
+ logical :: use_variant_both = .true.
+contains
+ subroutine bar_variant()
+ end subroutine
+ subroutine bar()
+ !$omp declare variant(bar_variant) &
+ !$omp& match(construct={dispatch}, user={condition(use_variant_both)})
+ end subroutine
+end module m_both
+
+! CHECK-LABEL: func.func {{.*}}test_dispatch_both
+! CHECK: fir.if %{{.*}} {
+! CHECK: fir.call {{.*}}bar_variant()
+! CHECK: } else {
+! CHECK: fir.call {{.*}}bar()
+! CHECK: }
+subroutine test_dispatch_both()
+ use m_both
+ !$omp dispatch
+ call bar()
+ !$omp end dispatch
+end subroutine
+
+! Verify that without dispatch, the base call is preserved.
+! CHECK-LABEL: func.func {{.*}}test_no_dispatch_both
+! CHECK: fir.call {{.*}}bar()
+! CHECK-NOT: fir.call {{.*}}bar_variant
+subroutine test_no_dispatch_both()
+ use m_both
+ call bar()
+end subroutine
diff --git a/flang/test/Semantics/OpenMP/declare-variant-match.f90 b/flang/test/Semantics/OpenMP/declare-variant-match.f90
index 199d05f1750ee..4a98ff091158f 100644
--- a/flang/test/Semantics/OpenMP/declare-variant-match.f90
+++ b/flang/test/Semantics/OpenMP/declare-variant-match.f90
@@ -76,7 +76,6 @@ subroutine sub
subroutine f06(x)
integer :: x
-!ERROR: Run-time USER condition in the MATCH clause is not yet implemented
!$omp declare variant (sub:vsub) match (user={condition(x > 0)})
contains
subroutine vsub
More information about the flang-commits
mailing list