[flang-commits] [flang] [flang][OpenMP] Declare variant (PR #204345)
via flang-commits
flang-commits at lists.llvm.org
Wed Jun 17 05:19:47 PDT 2026
llvmorg-github-actions[bot] wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-semantics
Author: SunilKuravinakop
<details>
<summary>Changes</summary>
Add support for the OpenMP DECLARE VARIANT directive.
This also fixes [#203833](https://github.com/llvm/llvm-project/issues/203833).
---
Patch is 31.30 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/204345.diff
8 Files Affected:
- (modified) flang/include/flang/Semantics/semantics.h (+24)
- (modified) flang/lib/Lower/Bridge.cpp (+7-1)
- (modified) flang/lib/Lower/OpenMP/OpenMP.cpp (+388-6)
- (modified) flang/lib/Semantics/check-omp-variant.cpp (+8-4)
- (added) flang/test/Integration/OpenMP/declare-variant.f90 (+112)
- (removed) flang/test/Lower/OpenMP/Todo/declare-variant.f90 (-17)
- (added) flang/test/Lower/OpenMP/declare-variant-construct.f90 (+137)
- (modified) flang/test/Semantics/OpenMP/declare-variant-match.f90 (-1)
``````````diff
diff --git a/flang/include/flang/Semantics/semantics.h b/flang/include/flang/Semantics/semantics.h
index 5e98378fc49be..aa38baa7a941f 100644
--- a/flang/include/flang/Semantics/semantics.h
+++ b/flang/include/flang/Semantics/semantics.h
@@ -49,6 +49,7 @@ struct SelectRankConstruct;
struct SelectTypeConstruct;
struct Variable;
struct WhereConstruct;
+struct OmpDirectiveSpecification;
} // namespace Fortran::parser
namespace Fortran::semantics {
@@ -57,6 +58,13 @@ class Symbol;
class CommonBlockMap;
using CommonBlockList = std::vector<std::pair<SymbolRef, std::size_t>>;
+// Records one OpenMP DECLARE VARIANT directive attached to a base
+// procedure: which procedure is the variant, and the directive it came from.
+struct OmpDeclareVariantInfo {
+  // The variant procedure named by the directive.
+  SymbolRef variantProcedure;
+  // The directive specification (its clauses include match()). May be null.
+  // Presumably points into the parse tree, which must outlive this entry --
+  // TODO confirm.
+  const parser::OmpDirectiveSpecification *spec{nullptr};
+};
+
using ConstructNode = std::variant<const parser::AssociateConstruct *,
const parser::BlockConstruct *, const parser::CaseConstruct *,
const parser::ChangeTeamConstruct *, const parser::CriticalConstruct *,
@@ -350,6 +358,19 @@ class SemanticsContext {
void DumpSymbols(llvm::raw_ostream &);
+ // OpenMP declare variant support: register variant info for a base
+ // procedure, and query it during lowering.
+  void AddOmpDeclareVariant(const Symbol &base, OmpDeclareVariantInfo info) {
+    // Entries for one base procedure are kept in registration order; the map
+    // is keyed by the address of `base`.
+    std::vector<OmpDeclareVariantInfo> &entries{ompDeclareVariants_[&base]};
+    entries.emplace_back(std::move(info));
+  }
+  // Returns the entries registered for `base`, or nullptr when there are
+  // none. Lookup is by the address of `base`, so callers must pass the same
+  // Symbol object that was registered.
+  const std::vector<OmpDeclareVariantInfo> *GetOmpDeclareVariants(
+      const Symbol &base) const {
+    const auto found{ompDeclareVariants_.find(&base)};
+    return found != ompDeclareVariants_.end() ? &found->second : nullptr;
+  }
+
// Top-level ProgramTrees are owned by the SemanticsContext for persistence.
ProgramTree &SaveProgramTree(ProgramTree &&);
@@ -409,6 +430,9 @@ class SemanticsContext {
UnorderedSymbolSet isUsed_;
std::set<const parser::AccObject *> accObjectDuplicates_;
std::list<ProgramTree> programTrees_;
+ // OpenMP declare variant: base symbol -> list of variant info
+ std::map<const Symbol *, std::vector<OmpDeclareVariantInfo>>
+ ompDeclareVariants_;
};
class Semantics {
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 138e0fbe0fde8..1d94ade62b736 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -4123,7 +4123,13 @@ class FirConverter : public Fortran::lower::AbstractConverter {
ompDeferredDeclareTarget);
genOpenMPDeclarativeConstruct(
*this, localSymbols, bridge.getSemanticsContext(), getEval(), ompDecl);
- builder->restoreInsertionPoint(insertPt);
+ // In case of declare variant related lowering with runtime user
+ // conditions (if/else) the original block is split. This is done
+ // by creating conditional branches that terminate the current block,
+ // making the saved insertion point invalid for subsequent insertions.
+ // Otherwise restore the insertion point.
+ if (builder->getInsertionBlock() == insertPt.getBlock())
+ builder->restoreInsertionPoint(insertPt);
}
/// Generate FIR for a SELECT CASE statement.
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index facca9867e4bb..79b8078d96a9c 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -1370,6 +1370,11 @@ struct OpWithBodyGenInfo {
bool privatize = true;
};
+// Forward declaration for variant-aware evaluation processing.
+static void genNestedEvaluationsWithVariantResolution(
+ lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx,
+ lower::pft::Evaluation &eval, llvm::omp::Directive constructDir);
+
/// Create the body (block) for an OpenMP Operation.
///
/// \param [in] op - the operation the body belongs to.
@@ -1466,7 +1471,19 @@ static void createBodyOfOp(mlir::Operation &op, const OpWithBodyGenInfo &info,
firOpBuilder.setInsertionPointToEnd(&op.getRegion(0).back());
auto *temp = lower::genOpenMPTerminator(firOpBuilder, &op, info.loc);
firOpBuilder.setInsertionPointAfter(marker);
- genNestedEvaluations(info.converter, info.eval);
+ // For constructs that support declare variant resolution (parallel,
+ // teams, dispatch), use variant-aware evaluation processing.
+ static constexpr llvm::omp::Directive variantResolutionDirs[] = {
+ llvm::omp::Directive::OMPD_parallel,
+ llvm::omp::Directive::OMPD_teams,
+ llvm::omp::Directive::OMPD_dispatch,
+ };
+ if (llvm::is_contained(variantResolutionDirs, info.dir)) {
+ genNestedEvaluationsWithVariantResolution(info.converter, info.semaCtx,
+ info.eval, info.dir);
+ } else {
+ genNestedEvaluations(info.converter, info.eval);
+ }
temp->erase();
}
}
@@ -4209,12 +4226,377 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
TODO(converter.getCurrentLocation(), "OpenMP ASSUMES declaration");
}
+// Return the context selector of the match() clause of a DECLARE VARIANT
+// directive, or nullptr when the directive has no match() clause. Only the
+// first match() clause is considered.
+static const parser::traits::OmpContextSelectorSpecification *
+getMatchSelectorInDeclareVariant(
+    const parser::OmpDirectiveSpecification &declareVariant) {
+  for (const parser::OmpClause &clause : declareVariant.Clauses().v) {
+    if (clause.Id() != llvm::omp::Clause::OMPC_match)
+      continue;
+    const auto &matchClause{std::get<parser::OmpClause::Match>(clause.u)};
+    return &matchClause.v.v;
+  }
+  return nullptr;
+}
+
+// Return true if the match() clause of `declareVariantSpec` contains a
+// construct={...} trait set that names `expectedConstructDir` (for example
+// dispatch, parallel or teams).
+static bool doesDeclareVariantMatchDirective(
+    const parser::OmpDirectiveSpecification &declareVariantSpec,
+    llvm::omp::Directive expectedConstructDir) {
+  const auto *selectorSpec{
+      getMatchSelectorInDeclareVariant(declareVariantSpec)};
+  if (selectorSpec == nullptr)
+    return false;
+
+  using SetName = parser::OmpTraitSetSelectorName;
+  for (const auto &traitSet : selectorSpec->v) {
+    // Only construct={...} trait sets can name a directive.
+    if (std::get<SetName>(traitSet.t).v != SetName::Value::Construct)
+      continue;
+    for (const parser::OmpTraitSelector &trait :
+         std::get<std::list<parser::OmpTraitSelector>>(traitSet.t)) {
+      const auto &name{std::get<parser::OmpTraitSelectorName>(trait.t)};
+      if (const auto *dir{std::get_if<llvm::omp::Directive>(&name.u)};
+          dir && *dir == expectedConstructDir)
+        return true;
+    }
+  }
+  return false;
+}
+
+// Return true if the match() clause of `declareVariantSpec` contains a
+// construct={...} trait set, regardless of which constructs it lists.
+static bool hasAnyConstructTrait(
+    const parser::OmpDirectiveSpecification &declareVariantSpec) {
+  const auto *selectorSpec{
+      getMatchSelectorInDeclareVariant(declareVariantSpec)};
+  if (!selectorSpec)
+    return false;
+
+  using SetName = parser::OmpTraitSetSelectorName;
+  for (const auto &traitSet : selectorSpec->v) {
+    if (std::get<SetName>(traitSet.t).v == SetName::Value::Construct)
+      return true;
+  }
+  return false;
+}
+
+// Return the scalar expression from match(user={condition(expr)}) of a
+// DECLARE VARIANT directive, or nullptr if there is no such condition. Only
+// the first user={...} trait set, its first condition(...) selector, and that
+// selector's first scalar-expression property are examined.
+static const parser::ScalarExpr *getDeclareVariantUserConditionExpr(
+    const parser::OmpDirectiveSpecification &declareVariantSpec) {
+  using TraitSetName = parser::OmpTraitSetSelectorName;
+  using TraitName = parser::OmpTraitSelectorName;
+
+  const auto *matchSelector{
+      getMatchSelectorInDeclareVariant(declareVariantSpec)};
+  if (!matchSelector)
+    return nullptr;
+
+  // First user={...} trait set.
+  const std::list<parser::OmpTraitSelector> *userTraits{nullptr};
+  for (const auto &traitSet : matchSelector->v) {
+    if (std::get<TraitSetName>(traitSet.t).v == TraitSetName::Value::User) {
+      userTraits = &std::get<std::list<parser::OmpTraitSelector>>(traitSet.t);
+      break;
+    }
+  }
+  if (!userTraits)
+    return nullptr;
+
+  // First condition(...) selector within that set.
+  const parser::OmpTraitSelector *conditionTrait{nullptr};
+  for (const parser::OmpTraitSelector &trait : *userTraits) {
+    const auto &traitName{std::get<TraitName>(trait.t)};
+    const auto *value{std::get_if<TraitName::Value>(&traitName.u)};
+    if (value && *value == TraitName::Value::Condition) {
+      conditionTrait = &trait;
+      break;
+    }
+  }
+  if (!conditionTrait)
+    return nullptr;
+
+  const auto &properties{
+      std::get<std::optional<parser::OmpTraitSelector::Properties>>(
+          conditionTrait->t)};
+  if (!properties)
+    return nullptr;
+
+  // First property that is a scalar expression.
+  for (const parser::OmpTraitProperty &property :
+       std::get<std::list<parser::OmpTraitProperty>>(properties->t)) {
+    if (const auto *expr{std::get_if<parser::ScalarExpr>(&property.u)})
+      return expr;
+  }
+  return nullptr;
+}
+
+// Return the ultimate symbol of the procedure named by a CALL statement, or
+// nullptr when the procedure designator is not a plain name or the name has
+// not been resolved to a symbol.
+static const semantics::Symbol *
+getCallProcedureSymbol(const parser::CallStmt &callStmt) {
+  const parser::Call &call{std::get<parser::Call>(callStmt.t)};
+  const auto &designator{std::get<parser::ProcedureDesignator>(call.t)};
+  const parser::Name *name{std::get_if<parser::Name>(&designator.u)};
+  if (name == nullptr || name->symbol == nullptr)
+    return nullptr;
+  return &name->symbol->GetUltimate();
+}
+
+// Generate variant-aware nested evaluations for a construct region.
+// For each CALL statement directly nested in `eval` (statements inside
+// nested constructs are not examined), collect the DECLARE VARIANT
+// directives of the called procedure whose construct={...} selector names
+// `constructDir`. If there are any, emit a fir.if cascade that calls the
+// first variant whose user condition holds -- a variant without a user
+// condition is called unconditionally and ends the cascade -- and otherwise
+// falls back to the base call. All other evaluations are lowered normally.
+// After each cascade the insertion point is placed after its outermost
+// fir.if, so later evaluations are not nested inside it.
+static void genNestedEvaluationsWithVariantResolution(
+    lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx,
+    lower::pft::Evaluation &eval, llvm::omp::Directive constructDir) {
+  fir::FirOpBuilder &builder = converter.getFirOpBuilder();
+
+  // Call `variantSym`, first declaring it at module scope if the module does
+  // not already contain a function with its mangled name.
+  // NOTE(review): the callee is declared as () -> () and called with no
+  // operands, so the base call's actual arguments are not forwarded.
+  // TODO: lower the variant call like the base call.
+  auto genVariantCall = [&](const semantics::Symbol &variantSym,
+                            mlir::Location loc) {
+    std::string variantName = converter.mangleName(variantSym);
+    auto variantFunc =
+        builder.getModule().lookupSymbol<mlir::func::FuncOp>(variantName);
+    if (!variantFunc) {
+      // FuncOp::create(builder, ...) inserts at the builder's insertion
+      // point, so temporarily move it to the module body; the declaration
+      // must not be created inside the current region.
+      mlir::OpBuilder::InsertionGuard guard(builder);
+      builder.setInsertionPointToEnd(builder.getModule().getBody());
+      variantFunc = mlir::func::FuncOp::create(
+          builder, loc, variantName, builder.getFunctionType({}, {}));
+      variantFunc.setVisibility(mlir::SymbolTable::Visibility::Private);
+    }
+    fir::CallOp::create(builder, loc, variantFunc);
+  };
+
+  // A variant of the called procedure that applies to `constructDir`.
+  struct ConstructVariant {
+    const semantics::Symbol *variantSym;
+    const parser::ScalarExpr *userCond; // null when there is no condition
+  };
+
+  for (lower::pft::Evaluation &nestedEval : eval.getNestedEvaluations()) {
+    const parser::CallStmt *callStmt = nestedEval.getIf<parser::CallStmt>();
+    const semantics::Symbol *baseSym =
+        callStmt ? getCallProcedureSymbol(*callStmt) : nullptr;
+    const auto *declareVariants =
+        baseSym ? semaCtx.GetOmpDeclareVariants(*baseSym) : nullptr;
+
+    // Collect variants matching the current construct directive.
+    llvm::SmallVector<ConstructVariant, 4> matchingVariants;
+    if (declareVariants) {
+      for (const auto &variantInfo : *declareVariants) {
+        if (variantInfo.spec &&
+            doesDeclareVariantMatchDirective(*variantInfo.spec, constructDir))
+          matchingVariants.push_back(
+              {&variantInfo.variantProcedure.get(),
+               getDeclareVariantUserConditionExpr(*variantInfo.spec)});
+      }
+    }
+    if (matchingVariants.empty()) {
+      converter.genEval(nestedEval);
+      continue;
+    }
+
+    mlir::Location loc = converter.getCurrentLocation();
+    lower::StatementContext stmtCtx;
+    // Outermost fir.if of the cascade; code for later evaluations must be
+    // emitted after it rather than inside its else region.
+    mlir::Operation *outermostIf = nullptr;
+    bool baseCallReplaced = false;
+    for (const ConstructVariant &matchedVariant : matchingVariants) {
+      if (!matchedVariant.userCond) {
+        // No condition: the variant is always selected, so neither the base
+        // call nor any later variant is reached.
+        genVariantCall(*matchedVariant.variantSym, loc);
+        baseCallReplaced = true;
+        break;
+      }
+      const auto *condExpr =
+          semantics::GetExpr(semaCtx, *matchedVariant.userCond);
+      if (!condExpr)
+        continue;
+      mlir::Value condVal =
+          fir::getBase(converter.genExprValue(*condExpr, stmtCtx, &loc));
+      if (condVal.getType() != builder.getI1Type())
+        condVal = builder.createConvert(loc, builder.getI1Type(), condVal);
+      stmtCtx.finalizeAndReset();
+
+      auto ifOp = fir::IfOp::create(builder, loc, condVal,
+                                    /*withElseRegion=*/true);
+      if (!outermostIf)
+        outermostIf = ifOp.getOperation();
+      builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+      genVariantCall(*matchedVariant.variantSym, loc);
+      // Remaining conditions, and finally the base call, nest in the else
+      // branch.
+      builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
+    }
+
+    // Base call: in the innermost else branch, or inline when no fir.if was
+    // emitted.
+    if (!baseCallReplaced)
+      converter.genEval(nestedEval);
+
+    // Resume after the whole cascade.
+    if (outermostIf)
+      builder.setInsertionPointAfter(outermostIf);
+  }
+}
+
+// Lower an OpenMP DECLARE VARIANT directive encountered while lowering the
+// specification part of a procedure (presumably the base procedure; the
+// `base` name of the `base:variant` form is not checked).
+//
+// Directives whose match() clause has a construct={...} trait set are
+// skipped here; they are resolved at call sites inside the matching
+// constructs. Otherwise the current block is split at the insertion point
+// and a conditional branch is emitted to a block that calls the variant and
+// returns, so the variant replaces the rest of the procedure body:
+//  - with match(user={condition(expr)}) the branch is taken when `expr` is
+//    true, and the base body runs otherwise;
+//  - without a user condition the variant is always selected.
+// On return, the insertion point is at the start of the block that holds the
+// rest of the base body.
+//
+// NOTE(review): other selectors (e.g. device={...}, implementation={...})
+// are not evaluated, and the variant is called without the base procedure's
+// arguments. TODO: handle both.
+static void genDeclareVariant(
+    lower::AbstractConverter &converter, lower::SymMap &symTable,
+    semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
+    const parser::OmpDeclareVariantDirective &declareVariantDirective) {
+  const parser::OmpDirectiveSpecification &declareVariantSpec{
+      declareVariantDirective.v};
+  mlir::Location loc = converter.genLocation(declareVariantSpec.source);
+
+  // Variants with construct traits (dispatch, parallel, teams) are resolved
+  // at the corresponding construct call sites, not inside the base body.
+  if (hasAnyConstructTrait(declareVariantSpec))
+    return;
+
+  // The single argument is either `variant` or `base:variant`.
+  const parser::OmpArgumentList &declareVariantArgs{
+      declareVariantSpec.Arguments()};
+  if (declareVariantArgs.v.size() != 1)
+    return;
+  const parser::OmpArgument &declareVariantArg{declareVariantArgs.v.front()};
+
+  // Extract a symbol from an OmpObject.
+  auto getDeclareVariantObjectSymbol =
+      [](const parser::OmpObject &obj) -> const semantics::Symbol * {
+    return common::visit(
+        common::visitors{
+            [](const parser::Designator &d) -> const semantics::Symbol * {
+              if (const auto *dr = std::get_if<parser::DataRef>(&d.u))
+                if (const auto *n = std::get_if<parser::Name>(&dr->u))
+                  return n->symbol;
+              return nullptr;
+            },
+            [](const parser::Name &n) -> const semantics::Symbol * {
+              return n.symbol;
+            },
+            [](const parser::OmpObject::Invalid &)
+                -> const semantics::Symbol * { return nullptr; },
+        },
+        obj.u);
+  };
+
+  const semantics::Symbol *variantSym = nullptr;
+  common::visit(
+      common::visitors{
+          [&](const parser::OmpBaseVariantNames &y) {
+            // `base:variant`: the variant is the second name.
+            if (const auto *sym =
+                    getDeclareVariantObjectSymbol(std::get<1>(y.t)))
+              variantSym = &sym->GetUltimate();
+          },
+          [&](const parser::OmpLocator &y) {
+            if (const auto *obj = std::get_if<parser::OmpObject>(&y.u))
+              if (const auto *sym = getDeclareVariantObjectSymbol(*obj))
+                variantSym = &sym->GetUltimate();
+          },
+          [](auto &&) {},
+      },
+      declareVariantArg.u);
+  if (!variantSym)
+    return;
+
+  fir::FirOpBuilder &builder = converter.getFirOpBuilder();
+
+  // The variant path ends in a func.return without operands, which is only
+  // valid when the enclosing function has no results.
+  if (builder.getFunction().getNumResults() != 0)
+    TODO(loc, "OpenMP DECLARE VARIANT in a procedure with results");
+
+  // Call `sym`, first declaring it at module scope if the module does not
+  // already contain a function with its mangled name.
+  auto genCall = [&](const semantics::Symbol &sym) {
+    std::string name = converter.mangleName(sym);
+    auto func = builder.getModule().lookupSymbol<mlir::func::FuncOp>(name);
+    if (!func) {
+      // FuncOp::create(builder, ...) inserts at the builder's insertion
+      // point, so temporarily move it to the module body; the declaration
+      // must not be created inside the current function.
+      mlir::OpBuilder::InsertionGuard guard(builder);
+      builder.setInsertionPointToEnd(builder.getModule().getBody());
+      func = mlir::func::FuncOp::create(builder, loc, name,
+                                        builder.getFunctionType({}, {}));
+      func.setVisibility(mlir::SymbolTable::Visibility::Private);
+    }
+    fir::CallOp::create(builder, loc, func);
+  };
+
+  // Selection condition: the user condition when present, otherwise a
+  // constant true (the variant is always selected).
+  mlir::Value condVal;
+  if (const parser::ScalarExpr *userCond =
+          getDeclareVariantUserConditionExpr(declareVariantSpec)) {
+    const auto *condExpr = semantics::GetExpr(semaCtx, *userCond);
+    if (!condExpr)
+      return;
+    lower::StatementContext stmtCtx;
+    condVal = fir::getBase(converter.genExprValue(*condExpr, stmtCtx, &loc));
+    if (condVal.getType() != builder.getI1Type())
+      condVal = builder.createConvert(loc, builder.getI1Type(), condVal);
+    stmtCtx.finalizeAndReset();
+  } else {
+    condVal = builder.createBool(loc, true);
+  }
+
+  // Split the current block: everything after the insertion point goes into
+  // continueBlock (the base body path).
+  mlir::Block *currentBlock = builder.getInsertionBlock();
+  mlir::Block *continueBlock =
+      currentBlock->splitBlock(builder.getInsertionPoint());
+
+  // Then-block for the variant call + return.
+  mlir::Block *thenBlock = builder.createBlock(continueBlock);
+
+  // Terminate the current block with a conditional branch.
+  builder.setInsertionPointToEnd(currentBlock);
+  mlir::cf::CondBranchOp::create(builder, loc, condVal, thenBlock,
+                                 mlir::ValueRange{}, continueBlock,
+                                 mlir::ValueRange{});
+
+  // Then block: call the variant and return from the function, so the base
+  // body does not run.
+  builder.setInsertionPointToEnd(thenBlock);
+  genCall(*variantSym);
+  mlir::func::ReturnOp::create(builder, loc);
+
+  // The rest of the base body is generated in the continue block.
+  builder.setInsertionPointToStart(continueBlock);
+}
+
static void
genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
const parser::OmpDeclareVariantDirective &declareVariantDirective) {
- if (!semaCtx.langOptions().OpenMPSimd)
- TODO(converter.getCurrentLocation(), "OmpDeclareVariantDirective");
+  // In simd-only mode (-fopenmp-simd) DECLARE VARIANT is not a SIMD
+  // directive and is ignored rather than lowered.
+  if (semaCtx.langOptions().OpenMPSimd)
+    return;
+  genDeclareVariant(converter, symTable, semaCtx, eval,
+                    declareVariantDirective);
}
static ReductionProcessor::GenCombinerCBTy processReductionCombiner(
@@ -5368,9 +5750,9 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
static void genO...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/204345
More information about the flang-commits mailing list