[flang-commits] [flang] [flang][OpenMP] Declare variant (PR #204345)

via flang-commits flang-commits at lists.llvm.org
Wed Jun 17 05:19:47 PDT 2026


llvmorg-github-actions[bot] wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-flang-semantics

Author: SunilKuravinakop

<details>
<summary>Changes</summary>

Support for declare variant.

This also fixes [#<!-- -->203833](https://github.com/llvm/llvm-project/issues/203833)

---

Patch is 31.30 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/204345.diff


8 Files Affected:

- (modified) flang/include/flang/Semantics/semantics.h (+24) 
- (modified) flang/lib/Lower/Bridge.cpp (+7-1) 
- (modified) flang/lib/Lower/OpenMP/OpenMP.cpp (+388-6) 
- (modified) flang/lib/Semantics/check-omp-variant.cpp (+8-4) 
- (added) flang/test/Integration/OpenMP/declare-variant.f90 (+112) 
- (removed) flang/test/Lower/OpenMP/Todo/declare-variant.f90 (-17) 
- (added) flang/test/Lower/OpenMP/declare-variant-construct.f90 (+137) 
- (modified) flang/test/Semantics/OpenMP/declare-variant-match.f90 (-1) 


``````````diff
diff --git a/flang/include/flang/Semantics/semantics.h b/flang/include/flang/Semantics/semantics.h
index 5e98378fc49be..aa38baa7a941f 100644
--- a/flang/include/flang/Semantics/semantics.h
+++ b/flang/include/flang/Semantics/semantics.h
@@ -49,6 +49,7 @@ struct SelectRankConstruct;
 struct SelectTypeConstruct;
 struct Variable;
 struct WhereConstruct;
+struct OmpDirectiveSpecification;
 } // namespace Fortran::parser
 
 namespace Fortran::semantics {
@@ -57,6 +58,13 @@ class Symbol;
 class CommonBlockMap;
 using CommonBlockList = std::vector<std::pair<SymbolRef, std::size_t>>;
 
+// Record of one OpenMP declare variant directive that applies to a base
+// procedure: which procedure is the variant, and the directive that named it.
+struct OmpDeclareVariantInfo {
+  // The variant procedure named by the directive.
+  SymbolRef variantProcedure;
+  // The directive specification (its clauses hold the match() selector).
+  // May be null; consumers are expected to check before dereferencing.
+  const parser::OmpDirectiveSpecification *spec{nullptr};
+};
+
 using ConstructNode = std::variant<const parser::AssociateConstruct *,
     const parser::BlockConstruct *, const parser::CaseConstruct *,
     const parser::ChangeTeamConstruct *, const parser::CriticalConstruct *,
@@ -350,6 +358,19 @@ class SemanticsContext {
 
   void DumpSymbols(llvm::raw_ostream &);
 
+  // OpenMP declare variant support.
+  // Append `info` to the list of declare variants recorded for `base`.
+  void AddOmpDeclareVariant(const Symbol &base, OmpDeclareVariantInfo info) {
+    auto &variantsOfBase{ompDeclareVariants_[&base]};
+    variantsOfBase.push_back(std::move(info));
+  }
+  // Return the declare variants recorded for `base` (queried during
+  // lowering), or nullptr when none were registered.
+  const std::vector<OmpDeclareVariantInfo> *GetOmpDeclareVariants(
+      const Symbol &base) const {
+    const auto found{ompDeclareVariants_.find(&base)};
+    return found == ompDeclareVariants_.end() ? nullptr : &found->second;
+  }
+
   // Top-level ProgramTrees are owned by the SemanticsContext for persistence.
   ProgramTree &SaveProgramTree(ProgramTree &&);
 
@@ -409,6 +430,9 @@ class SemanticsContext {
   UnorderedSymbolSet isUsed_;
   std::set<const parser::AccObject *> accObjectDuplicates_;
   std::list<ProgramTree> programTrees_;
+  // OpenMP declare variant: base symbol -> list of variant info
+  std::map<const Symbol *, std::vector<OmpDeclareVariantInfo>>
+      ompDeclareVariants_;
 };
 
 class Semantics {
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 138e0fbe0fde8..1d94ade62b736 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -4123,7 +4123,13 @@ class FirConverter : public Fortran::lower::AbstractConverter {
         ompDeferredDeclareTarget);
     genOpenMPDeclarativeConstruct(
         *this, localSymbols, bridge.getSemanticsContext(), getEval(), ompDecl);
-    builder->restoreInsertionPoint(insertPt);
+    // In case of declare variant related lowering with runtime user
+    // conditions (if/else) the original block is split. This is done
+    // by creating conditional branches that terminate the current block,
+    // making the saved insertion point invalid for subsequent insertions.
+    // Otherwise restore the insertion point.
+    if (builder->getInsertionBlock() == insertPt.getBlock())
+      builder->restoreInsertionPoint(insertPt);
   }
 
   /// Generate FIR for a SELECT CASE statement.
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index facca9867e4bb..79b8078d96a9c 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -1370,6 +1370,11 @@ struct OpWithBodyGenInfo {
   bool privatize = true;
 };
 
+// Forward declaration for variant-aware evaluation processing. The definition
+// appears later in this file; createBodyOfOp calls it for the constructs that
+// take part in declare variant resolution (parallel, teams, dispatch).
+static void genNestedEvaluationsWithVariantResolution(
+    lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx,
+    lower::pft::Evaluation &eval, llvm::omp::Directive constructDir);
+
 /// Create the body (block) for an OpenMP Operation.
 ///
 /// \param [in]   op  - the operation the body belongs to.
@@ -1466,7 +1471,19 @@ static void createBodyOfOp(mlir::Operation &op, const OpWithBodyGenInfo &info,
       firOpBuilder.setInsertionPointToEnd(&op.getRegion(0).back());
       auto *temp = lower::genOpenMPTerminator(firOpBuilder, &op, info.loc);
       firOpBuilder.setInsertionPointAfter(marker);
-      genNestedEvaluations(info.converter, info.eval);
+      // For constructs that support declare variant resolution (parallel,
+      // teams, dispatch), use variant-aware evaluation processing.
+      static constexpr llvm::omp::Directive variantResolutionDirs[] = {
+          llvm::omp::Directive::OMPD_parallel,
+          llvm::omp::Directive::OMPD_teams,
+          llvm::omp::Directive::OMPD_dispatch,
+      };
+      if (llvm::is_contained(variantResolutionDirs, info.dir)) {
+        genNestedEvaluationsWithVariantResolution(info.converter, info.semaCtx,
+                                                  info.eval, info.dir);
+      } else {
+        genNestedEvaluations(info.converter, info.eval);
+      }
       temp->erase();
     }
   }
@@ -4209,12 +4226,377 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
     TODO(converter.getCurrentLocation(), "OpenMP ASSUMES declaration");
 }
 
+// Return the context selector of the first match() clause of a declare
+// variant directive, or nullptr when the directive has no match() clause.
+static const parser::traits::OmpContextSelectorSpecification *
+getMatchSelectorInDeclareVariant(
+    const parser::OmpDirectiveSpecification &declareVariant) {
+  for (const parser::OmpClause &clause : declareVariant.Clauses().v) {
+    if (clause.Id() != llvm::omp::Clause::OMPC_match)
+      continue;
+    const auto &matchClause = std::get<parser::OmpClause::Match>(clause.u);
+    return &matchClause.v.v;
+  }
+  return nullptr;
+}
+
+// Return true when the match() clause of `declareVariantSpec` contains a
+// construct={...} trait set that lists `expectedConstructDir`
+// (e.g. construct={dispatch}). Returns false when there is no match() clause.
+static bool doesDeclareVariantMatchDirective(
+    const parser::OmpDirectiveSpecification &declareVariantSpec,
+    llvm::omp::Directive expectedConstructDir) {
+  using TraitSetName = parser::OmpTraitSetSelectorName;
+
+  const auto *selector = getMatchSelectorInDeclareVariant(declareVariantSpec);
+  if (selector == nullptr)
+    return false;
+
+  for (const auto &traitSet : selector->v) {
+    // Only construct={...} trait sets name directives.
+    if (std::get<TraitSetName>(traitSet.t).v != TraitSetName::Value::Construct)
+      continue;
+    for (const parser::OmpTraitSelector &trait :
+         std::get<std::list<parser::OmpTraitSelector>>(traitSet.t)) {
+      const auto &traitName = std::get<parser::OmpTraitSelectorName>(trait.t);
+      if (const auto *dir = std::get_if<llvm::omp::Directive>(&traitName.u);
+          dir != nullptr && *dir == expectedConstructDir)
+        return true;
+    }
+  }
+  return false;
+}
+
+// Return true when the match() clause of `declareVariantSpec` contains a
+// construct={...} trait set; false when it has none or there is no match().
+static bool hasAnyConstructTrait(
+    const parser::OmpDirectiveSpecification &declareVariantSpec) {
+  using SetName = parser::OmpTraitSetSelectorName;
+
+  if (const auto *selector =
+          getMatchSelectorInDeclareVariant(declareVariantSpec)) {
+    for (const auto &traitSet : selector->v)
+      if (std::get<SetName>(traitSet.t).v == SetName::Value::Construct)
+        return true;
+  }
+  return false;
+}
+
+// Return the `expr` of match(user={condition(expr)}) in a declare variant
+// directive, or nullptr when there is no such condition. Only the first
+// user={...} trait set, and within it the first condition(...) trait, are
+// inspected; the first ScalarExpr property of that trait is returned.
+static const parser::ScalarExpr *getDeclareVariantUserConditionExpr(
+    const parser::OmpDirectiveSpecification &declareVariantSpec) {
+  using TraitSetName = parser::OmpTraitSetSelectorName;
+  using TraitName = parser::OmpTraitSelectorName;
+
+  const auto *selector = getMatchSelectorInDeclareVariant(declareVariantSpec);
+  if (selector == nullptr)
+    return nullptr;
+
+  for (const auto &traitSet : selector->v) {
+    if (std::get<TraitSetName>(traitSet.t).v != TraitSetName::Value::User)
+      continue;
+
+    // First user={...} set found: look for its condition(...) trait.
+    for (const parser::OmpTraitSelector &trait :
+         std::get<std::list<parser::OmpTraitSelector>>(traitSet.t)) {
+      const auto *nameValue =
+          std::get_if<TraitName::Value>(&std::get<TraitName>(trait.t).u);
+      if (nameValue == nullptr || *nameValue != TraitName::Value::Condition)
+        continue;
+
+      const auto &props =
+          std::get<std::optional<parser::OmpTraitSelector::Properties>>(
+              trait.t);
+      if (!props)
+        return nullptr;
+      for (const parser::OmpTraitProperty &prop :
+           std::get<std::list<parser::OmpTraitProperty>>(props->t))
+        if (const auto *expr = std::get_if<parser::ScalarExpr>(&prop.u))
+          return expr;
+      return nullptr;
+    }
+    // The first user={...} set has no condition(...) trait.
+    return nullptr;
+  }
+  return nullptr;
+}
+
+// Return the ultimate symbol of the procedure named in a CALL statement.
+// Returns nullptr when the designator is not a plain name (e.g. a procedure
+// component reference) or the name has no resolved symbol.
+static const semantics::Symbol *
+getCallProcedureSymbol(const parser::CallStmt &callStmt) {
+  const parser::Call &call = std::get<parser::Call>(callStmt.t);
+  const auto &designator = std::get<parser::ProcedureDesignator>(call.t);
+  const auto *name = std::get_if<parser::Name>(&designator.u);
+  if (name == nullptr || name->symbol == nullptr)
+    return nullptr;
+  return &name->symbol->GetUltimate();
+}
+
+// Generate variant-aware nested evaluations for a construct region.
+// For each CallStmt nested directly in `eval`, collect the declare variant
+// directives recorded for the called procedure whose construct trait set
+// lists `constructDir`, and lower the call as a cascade:
+//  - a variant with match(user={condition(expr)}) is called in the
+//    then-branch of a fir.if on `expr`; later variants and the base call are
+//    generated in its else-branch;
+//  - the first variant without a user condition is called unconditionally
+//    and ends the cascade (the base call is not generated).
+// Evaluations that are not such calls are generated normally. Variant
+// procedures are called without actual arguments.
+static void genNestedEvaluationsWithVariantResolution(
+    lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx,
+    lower::pft::Evaluation &eval, llvm::omp::Directive constructDir) {
+  fir::FirOpBuilder &builder = converter.getFirOpBuilder();
+
+  // Emit a call to `variantSym` without arguments at the current insertion
+  // point, declaring it at module scope as a private `() -> ()` function if
+  // the module does not contain it yet.
+  auto genVariantCall = [&](const semantics::Symbol &variantSym,
+                            mlir::Location loc) {
+    std::string variantName = converter.mangleName(variantSym);
+    mlir::ModuleOp module = builder.getModule();
+    auto variantFunc = module.lookupSymbol<mlir::func::FuncOp>(variantName);
+    if (!variantFunc) {
+      // Use the detached FuncOp::create overload. The OpBuilder-based one
+      // would insert the func.func at the current insertion point (inside
+      // the region being lowered) before it is pushed into the module,
+      // linking the same operation into two blocks.
+      variantFunc = mlir::func::FuncOp::create(
+          loc, variantName, builder.getFunctionType({}, {}));
+      variantFunc.setVisibility(mlir::SymbolTable::Visibility::Private);
+      module.push_back(variantFunc);
+    } else if (variantFunc.getFunctionType().getNumInputs() != 0) {
+      // The call below passes no actual arguments.
+      TODO(loc, "OpenMP declare variant procedure with dummy arguments");
+    }
+    fir::CallOp::create(builder, loc, variantFunc);
+  };
+
+  for (lower::pft::Evaluation &nestedEval : eval.getNestedEvaluations()) {
+    const parser::CallStmt *callStmt = nestedEval.getIf<parser::CallStmt>();
+    const semantics::Symbol *baseSym =
+        callStmt ? getCallProcedureSymbol(*callStmt) : nullptr;
+    const auto *declareVariants =
+        baseSym ? semaCtx.GetOmpDeclareVariants(*baseSym) : nullptr;
+    if (!declareVariants) {
+      converter.genEval(nestedEval);
+      continue;
+    }
+
+    // Collect variants matching the current construct directive.
+    struct ConstructVariant {
+      const semantics::Symbol *variantSym;
+      const parser::ScalarExpr *userCond;
+    };
+    llvm::SmallVector<ConstructVariant, 4> matchingVariants;
+    for (const auto &variantInfo : *declareVariants)
+      if (variantInfo.spec &&
+          doesDeclareVariantMatchDirective(*variantInfo.spec, constructDir))
+        matchingVariants.push_back(
+            {&variantInfo.variantProcedure.get(),
+             getDeclareVariantUserConditionExpr(*variantInfo.spec)});
+
+    if (matchingVariants.empty()) {
+      converter.genEval(nestedEval);
+      continue;
+    }
+
+    // Locate generated code at the call statement itself.
+    mlir::Location loc = converter.genLocation(nestedEval.position);
+    lower::StatementContext stmtCtx;
+    fir::IfOp outermostIf;
+    bool baseCallReplaced = false;
+
+    for (const ConstructVariant &matchedVariant : matchingVariants) {
+      if (!matchedVariant.userCond) {
+        // No condition: the variant replaces the base call.
+        genVariantCall(*matchedVariant.variantSym, loc);
+        baseCallReplaced = true;
+        break;
+      }
+
+      const auto *condExpr =
+          semantics::GetExpr(semaCtx, *matchedVariant.userCond);
+      if (!condExpr)
+        continue;
+      mlir::Value condVal =
+          fir::getBase(converter.genExprValue(*condExpr, stmtCtx, &loc));
+      if (condVal.getType() != builder.getI1Type())
+        condVal = builder.createConvert(loc, builder.getI1Type(), condVal);
+      stmtCtx.finalizeAndReset();
+
+      auto ifOp = fir::IfOp::create(builder, loc, condVal,
+                                    /*withElseRegion=*/true);
+      if (!outermostIf)
+        outermostIf = ifOp;
+      builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+      genVariantCall(*matchedVariant.variantSym, loc);
+      // Remaining variants (and the base call) nest in the else-branch.
+      builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
+    }
+
+    if (!baseCallReplaced)
+      converter.genEval(nestedEval);
+
+    // Continue after the whole cascade. Without this, the evaluations that
+    // follow would be generated inside the innermost else-branch.
+    if (outermostIf)
+      builder.setInsertionPointAfter(outermostIf);
+  }
+}
+
+// Lower an OpenMP declare variant directive for the base procedure body.
+// Variants whose match() clause has a construct trait set are resolved at
+// the matching construct call sites instead. Otherwise, the variant replaces
+// the rest of the base procedure body:
+//  - with match(user={condition(expr)}), `expr` is evaluated; when true the
+//    variant is called and the procedure returns, otherwise the base body
+//    runs;
+//  - without a user condition, the variant is always called and the
+//    procedure returns, leaving the base body unreachable.
+// Other trait sets (device, implementation, ...) are not evaluated here, and
+// the variant is called without actual arguments.
+static void genDeclareVariant(
+    lower::AbstractConverter &converter, lower::SymMap &symTable,
+    semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
+    const parser::OmpDeclareVariantDirective &declareVariantDirective) {
+
+  const parser::OmpDirectiveSpecification &declareVariantSpec{
+      declareVariantDirective.v};
+  mlir::Location loc = converter.genLocation(declareVariantSpec.source);
+
+  // Variants with construct traits (dispatch, parallel, teams)
+  // are resolved at the corresponding construct call sites, not inside the
+  // base procedure body.
+  if (hasAnyConstructTrait(declareVariantSpec))
+    return;
+
+  // Extract the variant procedure symbol from the single argument, given
+  // either as `variant` or as `base:variant`.
+  const parser::OmpArgumentList &declareVariantArgs{
+      declareVariantSpec.Arguments()};
+  if (declareVariantArgs.v.size() != 1)
+    return;
+  const parser::OmpArgument &declareVariantArg{declareVariantArgs.v.front()};
+
+  // Extract a symbol from an OmpObject.
+  auto getDeclareVariantObjectSymbol =
+      [](const parser::OmpObject &obj) -> const semantics::Symbol * {
+    return common::visit(
+        common::visitors{
+            [](const parser::Designator &d) -> const semantics::Symbol * {
+              if (const auto *dr = std::get_if<parser::DataRef>(&d.u))
+                if (const auto *n = std::get_if<parser::Name>(&dr->u))
+                  return n->symbol;
+              return nullptr;
+            },
+            [](const parser::Name &n) -> const semantics::Symbol * {
+              return n.symbol;
+            },
+            [](const parser::OmpObject::Invalid &)
+                -> const semantics::Symbol * { return nullptr; },
+        },
+        obj.u);
+  };
+
+  // NOTE(review): in the `base:variant` form the base name is not compared
+  // with the procedure being lowered; confirm semantics guarantees they agree.
+  const semantics::Symbol *variantSym = nullptr;
+  common::visit(
+      common::visitors{
+          [&](const parser::OmpBaseVariantNames &y) {
+            if (const auto *sym =
+                    getDeclareVariantObjectSymbol(std::get<1>(y.t)))
+              variantSym = &sym->GetUltimate();
+          },
+          [&](const parser::OmpLocator &y) {
+            if (const auto *obj = std::get_if<parser::OmpObject>(&y.u))
+              if (const auto *sym = getDeclareVariantObjectSymbol(*obj))
+                variantSym = &sym->GetUltimate();
+          },
+          [](auto &&) {},
+      },
+      declareVariantArg.u);
+  if (!variantSym)
+    return;
+
+  // Get the user condition expression (if any).
+  const parser::ScalarExpr *userCond =
+      getDeclareVariantUserConditionExpr(declareVariantSpec);
+  const auto *condExpr =
+      userCond ? semantics::GetExpr(semaCtx, *userCond) : nullptr;
+  if (userCond && !condExpr)
+    return;
+
+  fir::FirOpBuilder &builder = converter.getFirOpBuilder();
+
+  // The variant path ends in a func.return without operands, which only
+  // verifies in a procedure that has no results.
+  if (builder.getFunction().getFunctionType().getNumResults() != 0)
+    TODO(loc, "OpenMP declare variant on a base procedure with results");
+
+  // Emit a call to `sym` without arguments, declaring it at module scope as
+  // a private `() -> ()` function if the module does not contain it yet.
+  auto genCall = [&](const semantics::Symbol &sym) {
+    std::string name = converter.mangleName(sym);
+    mlir::ModuleOp module = builder.getModule();
+    auto func = module.lookupSymbol<mlir::func::FuncOp>(name);
+    if (!func) {
+      // Use the detached FuncOp::create overload. The OpBuilder-based one
+      // would insert the func.func into the procedure body before it is
+      // pushed into the module, linking the operation into two blocks.
+      func = mlir::func::FuncOp::create(loc, name,
+                                        builder.getFunctionType({}, {}));
+      func.setVisibility(mlir::SymbolTable::Visibility::Private);
+      module.push_back(func);
+    } else if (func.getFunctionType().getNumInputs() != 0) {
+      // The call below passes no actual arguments.
+      TODO(loc, "OpenMP declare variant procedure with dummy arguments");
+    }
+    fir::CallOp::create(builder, loc, func);
+  };
+
+  // Evaluate the user condition (if any) before splitting, so its operations
+  // stay in the current block.
+  mlir::Value condVal;
+  if (condExpr) {
+    lower::StatementContext stmtCtx;
+    condVal = fir::getBase(converter.genExprValue(*condExpr, stmtCtx, &loc));
+    if (condVal.getType() != builder.getI1Type())
+      condVal = builder.createConvert(loc, builder.getI1Type(), condVal);
+    stmtCtx.finalizeAndReset();
+  }
+
+  // Split the current block at the insertion point: everything generated
+  // after this directive (the base body) goes into continueBlock.
+  mlir::Block *currentBlock = builder.getInsertionBlock();
+  mlir::Block *continueBlock =
+      currentBlock->splitBlock(builder.getInsertionPoint());
+
+  mlir::Block *variantBlock = currentBlock;
+  if (condVal) {
+    // currentBlock: cf.cond_br %cond, ^variant, ^continue
+    variantBlock = builder.createBlock(continueBlock);
+    builder.setInsertionPointToEnd(currentBlock);
+    mlir::cf::CondBranchOp::create(builder, loc, condVal, variantBlock,
+                                   mlir::ValueRange{}, continueBlock,
+                                   mlir::ValueRange{});
+  }
+
+  // Call the variant and return from the base procedure, so the base body is
+  // not executed as well. NOTE(review): this return bypasses the procedure's
+  // regular exit path generated by the bridge.
+  builder.setInsertionPointToEnd(variantBlock);
+  genCall(*variantSym);
+  mlir::func::ReturnOp::create(builder, loc);
+
+  // The rest of the base body is generated here. Without a user condition
+  // this block has no predecessors.
+  builder.setInsertionPointToStart(continueBlock);
+}
+
 static void
 genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
        semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
        const parser::OmpDeclareVariantDirective &declareVariantDirective) {
-  if (!semaCtx.langOptions().OpenMPSimd)
-    TODO(converter.getCurrentLocation(), "OmpDeclareVariantDirective");
+  // Lower the declare variant directive; all the work is in genDeclareVariant.
+  // NOTE(review): this now runs even when langOptions().OpenMPSimd is set
+  // (the removed check only fired the TODO outside simd-only mode); confirm
+  // that declare variant should be honored under -fopenmp-simd.
+  genDeclareVariant(converter, symTable, semaCtx, eval,
+                    declareVariantDirective);
 }
 
 static ReductionProcessor::GenCombinerCBTy processReductionCombiner(
@@ -5368,9 +5750,9 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
 static void genO...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/204345


More information about the flang-commits mailing list