[flang-commits] [flang] [flang][OpenMP] Add optional SemanticsContext parameter to loop utilities (PR #191231)

via flang-commits flang-commits at lists.llvm.org
Thu Apr 9 09:08:13 PDT 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-flang-semantics

Author: Krzysztof Parzyszek (kparzysz)

<details>
<summary>Changes</summary>

[flang][OpenMP] Add optional SemanticsContext parameter to loop utilities

Some of these utilities may be used during symbol resolution, which runs before expression analysis. In that situation, the typedExpr normally stored in a parser::Expr may not be available yet. To obtain the numeric values of such expressions, the expression analyzer has to be invoked directly, which requires a SemanticsContext. The new parameter is optional and defaults to nullptr. In that case, the utilities behave as before and use only the already-analyzed typedExpr.

---
Full diff: https://github.com/llvm/llvm-project/pull/191231.diff


2 Files Affected:

- (modified) flang/include/flang/Semantics/openmp-utils.h (+28-11) 
- (modified) flang/lib/Semantics/openmp-utils.cpp (+46-28) 


``````````diff
diff --git a/flang/include/flang/Semantics/openmp-utils.h b/flang/include/flang/Semantics/openmp-utils.h
index b741c8eac3248..fad6f0db34f3a 100644
--- a/flang/include/flang/Semantics/openmp-utils.h
+++ b/flang/include/flang/Semantics/openmp-utils.h
@@ -118,6 +118,17 @@ std::optional<evaluate::DynamicType> GetDynamicType(
     const parser::Expr &parserExpr);
 
 std::optional<bool> GetLogicalValue(const SomeExpr &expr);
+std::optional<int64_t> GetIntValueFromExpr(
+    const parser::Expr &parserExpr, SemanticsContext *semaCtx = nullptr);
+
+template <typename T>
+std::optional<int64_t> GetIntValueFromExpr(
+    const T &wrappedExpr, SemanticsContext *semaCtx = nullptr) {
+  if (auto *parserExpr{parser::Unwrap<parser::Expr>(wrappedExpr)}) {
+    return GetIntValueFromExpr(*parserExpr, semaCtx);
+  }
+  return std::nullopt;
+}
 
 std::optional<bool> IsContiguous(
     SemanticsContext &semaCtx, const parser::OmpObject &object);
@@ -194,25 +205,29 @@ template <typename T> struct WithReason {
 
 WithReason<int64_t> GetArgumentValueWithReason(
     const parser::OmpDirectiveSpecification &spec, llvm::omp::Clause clauseId,
-    unsigned version);
+    unsigned version, SemanticsContext *semaCtx = nullptr);
 WithReason<int64_t> GetNumArgumentsWithReason(
     const parser::OmpDirectiveSpecification &spec, llvm::omp::Clause clauseId,
-    unsigned version);
+    unsigned version, SemanticsContext *semaCtx = nullptr);
 WithReason<int64_t> GetHeightWithReason(
-    const parser::OmpDirectiveSpecification &spec, unsigned version);
+    const parser::OmpDirectiveSpecification &spec, unsigned version,
+    SemanticsContext *semaCtx = nullptr);
 
 // Return the depth of the affected nests:
 //   {affected-depth, reason, must-be-perfect-nest}.
 std::pair<WithReason<int64_t>, bool> GetAffectedNestDepthWithReason(
-    const parser::OmpDirectiveSpecification &spec, unsigned version);
+    const parser::OmpDirectiveSpecification &spec, unsigned version,
+    SemanticsContext *semaCtx = nullptr);
 // Return the range of the affected nests in the sequence:
 //   {first, count, reason}.
 // If the range is "the whole sequence", the return value will be {1, -1, ...}.
 WithReason<std::pair<int64_t, int64_t>> GetAffectedLoopRangeWithReason(
-    const parser::OmpDirectiveSpecification &spec, unsigned version);
+    const parser::OmpDirectiveSpecification &spec, unsigned version,
+    SemanticsContext *semaCtx = nullptr);
 /// Return the depth in which all loops must be rectangular.
 WithReason<int64_t> GetRectangularNestDepthWithReason(
-    const parser::OmpDirectiveSpecification &spec, unsigned version);
+    const parser::OmpDirectiveSpecification &spec, unsigned version,
+    SemanticsContext *semaCtx = nullptr);
 
 // Count the required loop count from range. If count == -1, return -1,
 // indicating all loops in the sequence.
@@ -223,11 +238,12 @@ std::optional<int64_t> GetRequiredCount(
 
 struct LoopSequence {
   LoopSequence(const parser::ExecutionPartConstruct &root, unsigned version,
-      bool allowAllLoops = false);
+      bool allowAllLoops = false, SemanticsContext *semaCtx = nullptr);
 
   template <typename R, typename = std::enable_if_t<is_range_v<R>>>
-  LoopSequence(const R &range, unsigned version, bool allowAllLoops = false)
-      : version_(version), allowAllLoops_(allowAllLoops) {
+  LoopSequence(const R &range, unsigned version, bool allowAllLoops = false,
+      SemanticsContext *semaCtx = nullptr)
+      : version_(version), allowAllLoops_(allowAllLoops), semaCtx_(semaCtx) {
     entry_ = std::make_unique<Construct>(range, nullptr);
     createChildrenFromRange(entry_->location);
     precalculate();
@@ -265,8 +281,8 @@ struct LoopSequence {
 private:
   using Construct = ExecutionPartIterator::Construct;
 
-  LoopSequence(
-      std::unique_ptr<Construct> entry, unsigned version, bool allowAllLoops);
+  LoopSequence(std::unique_ptr<Construct> entry, unsigned version,
+      bool allowAllLoops, SemanticsContext *semaCtx = nullptr);
 
   template <typename R, typename = std::enable_if_t<is_range_v<R>>>
   void createChildrenFromRange(const R &range) {
@@ -318,6 +334,7 @@ struct LoopSequence {
   bool allowAllLoops_;
   std::unique_ptr<Construct> entry_;
   std::vector<LoopSequence> children_;
+  SemanticsContext *semaCtx_{nullptr};
 };
 } // namespace omp
 } // namespace Fortran::semantics
diff --git a/flang/lib/Semantics/openmp-utils.cpp b/flang/lib/Semantics/openmp-utils.cpp
index 13004fb4bab6b..a36cadc04ee65 100644
--- a/flang/lib/Semantics/openmp-utils.cpp
+++ b/flang/lib/Semantics/openmp-utils.cpp
@@ -342,6 +342,19 @@ std::optional<bool> GetLogicalValue(const SomeExpr &expr) {
   return LogicalConstantVistor{}(expr);
 }
 
+std::optional<int64_t> GetIntValueFromExpr(
+    const parser::Expr &parserExpr, SemanticsContext *semaCtx) {
+  if (auto value{GetIntValue(parserExpr)}) {
+    return value;
+  }
+  if (semaCtx) {
+    if (auto expr{evaluate::ExpressionAnalyzer{*semaCtx}.Analyze(parserExpr)}) {
+      return evaluate::ToInt64(expr);
+    }
+  }
+  return std::nullopt;
+}
+
 namespace {
 struct ContiguousHelper {
   ContiguousHelper(SemanticsContext &context)
@@ -691,10 +704,10 @@ static SymbolVector SelectUsedSymbols(
 
 WithReason<int64_t> GetArgumentValueWithReason(
     const parser::OmpDirectiveSpecification &spec, llvm::omp::Clause clauseId,
-    unsigned version) {
+    unsigned version, SemanticsContext *semaCtx) {
   if (auto *clause{parser::omp::FindClause(spec, clauseId)}) {
     if (auto *expr{parser::Unwrap<parser::Expr>(clause->u)}) {
-      if (auto value{GetIntValue(*expr)}) {
+      if (auto value{GetIntValueFromExpr(*expr, semaCtx)}) {
         std::string name{GetUpperName(clauseId, version)};
         Reason reason;
         reason.Say(clause->source,
@@ -723,7 +736,7 @@ static WithReason<int64_t> GetNumArgumentsWithReasonForType(
 
 WithReason<int64_t> GetNumArgumentsWithReason(
     const parser::OmpDirectiveSpecification &spec, llvm::omp::Clause clauseId,
-    unsigned version) {
+    unsigned version, SemanticsContext *semaCtx) {
   if (auto *clause{parser::omp::FindClause(spec, clauseId)}) {
     std::string name{GetUpperName(clauseId, version)};
     // Try the types used for list items.
@@ -744,7 +757,8 @@ WithReason<int64_t> GetNumArgumentsWithReason(
 }
 
 WithReason<int64_t> GetHeightWithReason(
-    const parser::OmpDirectiveSpecification &spec, unsigned version) {
+    const parser::OmpDirectiveSpecification &spec, unsigned version,
+    SemanticsContext *semaCtx) {
   bool isFullUnroll{IsFullUnroll(spec)};
 
   if (!isFullUnroll && !IsTransformableLoop(spec)) {
@@ -756,7 +770,7 @@ WithReason<int64_t> GetHeightWithReason(
   switch (spec.DirName().v) {
   case llvm::omp::Directive::OMPD_flatten:
     if (auto &&value{GetArgumentValueWithReason(
-            spec, llvm::omp::Clause::OMPC_depth, version)}) {
+            spec, llvm::omp::Clause::OMPC_depth, version, semaCtx)}) {
       // FLATTEN DEPTH(n) replaces n loops with 1.
       return {int64_t(1) - *value.value, std::move(value.reason)};
     } else {
@@ -775,7 +789,7 @@ WithReason<int64_t> GetHeightWithReason(
   case llvm::omp::Directive::OMPD_stripe:
   case llvm::omp::Directive::OMPD_tile:
     return GetNumArgumentsWithReason(
-        spec, llvm::omp::Clause::OMPC_sizes, version);
+        spec, llvm::omp::Clause::OMPC_sizes, version, semaCtx);
   case llvm::omp::Directive::OMPD_unroll:
     if (isFullUnroll) {
       Reason reason;
@@ -955,7 +969,8 @@ WithReason<T> operator+(T a, const WithReason<T> &b) {
 // Return the depth of the affected nests:
 //   {affected-depth, must-be-perfect-nest}.
 std::pair<WithReason<int64_t>, bool> GetAffectedNestDepthWithReason(
-    const parser::OmpDirectiveSpecification &spec, unsigned version) {
+    const parser::OmpDirectiveSpecification &spec, unsigned version,
+    SemanticsContext *semaCtx) {
   llvm::omp::Directive dir{spec.DirId()};
   bool allowsCollapse{llvm::omp::isAllowedClauseForDirective(
       dir, llvm::omp::Clause::OMPC_collapse, version)};
@@ -964,9 +979,9 @@ std::pair<WithReason<int64_t>, bool> GetAffectedNestDepthWithReason(
 
   if (allowsCollapse || allowsOrdered) {
     auto [ccount, creason]{GetArgumentValueWithReason(
-        spec, llvm::omp::Clause::OMPC_collapse, version)};
+        spec, llvm::omp::Clause::OMPC_collapse, version, semaCtx)};
     auto [ocount, oreason]{GetArgumentValueWithReason(
-        spec, llvm::omp::Clause::OMPC_ordered, version)};
+        spec, llvm::omp::Clause::OMPC_ordered, version, semaCtx)};
     // Ignore invalid arguments.
     if (ccount <= 0) {
       ccount = std::nullopt;
@@ -989,7 +1004,7 @@ std::pair<WithReason<int64_t>, bool> GetAffectedNestDepthWithReason(
       // Get the length of the argument list to PERMUTATION.
       if (parser::omp::FindClause(spec, llvm::omp::Clause::OMPC_permutation)) {
         auto [num, reason]{GetNumArgumentsWithReason(
-            spec, llvm::omp::Clause::OMPC_permutation, version)};
+            spec, llvm::omp::Clause::OMPC_permutation, version, semaCtx)};
         return {{num, std::move(reason)}, true};
       }
       // PERMUTATION not specified, assume PERMUTATION(2, 1).
@@ -1004,14 +1019,14 @@ std::pair<WithReason<int64_t>, bool> GetAffectedNestDepthWithReason(
     case llvm::omp::Directive::OMPD_tile: {
       // Get the length of the argument list to SIZES.
       auto [num, reason]{GetNumArgumentsWithReason(
-          spec, llvm::omp::Clause::OMPC_sizes, version)};
+          spec, llvm::omp::Clause::OMPC_sizes, version, semaCtx)};
       return {{num, std::move(reason)}, true};
     }
     case llvm::omp::Directive::OMPD_fuse: {
       // Get the value from the argument to DEPTH.
       if (parser::omp::FindClause(spec, llvm::omp::Clause::OMPC_depth)) {
         auto [count, reason]{GetArgumentValueWithReason(
-            spec, llvm::omp::Clause::OMPC_depth, version)};
+            spec, llvm::omp::Clause::OMPC_depth, version, semaCtx)};
         return {{count, std::move(reason)}, true};
       }
       std::string name{GetUpperName(llvm::omp::Clause::OMPC_depth, version)};
@@ -1035,7 +1050,8 @@ std::pair<WithReason<int64_t>, bool> GetAffectedNestDepthWithReason(
 // Return the range of the affected nests in the sequence:
 //   {first, count, std::move(reason)}.
 WithReason<std::pair<int64_t, int64_t>> GetAffectedLoopRangeWithReason(
-    const parser::OmpDirectiveSpecification &spec, unsigned version) {
+    const parser::OmpDirectiveSpecification &spec, unsigned version,
+    SemanticsContext *semaCtx) {
   llvm::omp::Directive dir{spec.DirId()};
 
   if (dir == llvm::omp::Directive::OMPD_fuse) {
@@ -1043,8 +1059,8 @@ WithReason<std::pair<int64_t, int64_t>> GetAffectedLoopRangeWithReason(
     if (auto *clause{
             parser::omp::FindClause(spec, llvm::omp::Clause::OMPC_looprange)}) {
       auto &range{DEREF(parser::Unwrap<parser::OmpLooprangeClause>(clause->u))};
-      std::optional<int64_t> first{GetIntValue(std::get<0>(range.t))};
-      std::optional<int64_t> count{GetIntValue(std::get<1>(range.t))};
+      auto first{GetIntValueFromExpr(std::get<0>(range.t), semaCtx)};
+      auto count{GetIntValueFromExpr(std::get<1>(range.t), semaCtx)};
       if (!first || !count || *first <= 0 || *count <= 0) {
         return {};
       }
@@ -1071,8 +1087,9 @@ WithReason<std::pair<int64_t, int64_t>> GetAffectedLoopRangeWithReason(
 }
 
 WithReason<int64_t> GetRectangularNestDepthWithReason(
-    const parser::OmpDirectiveSpecification &spec, unsigned version) {
-  auto [depth, _]{GetAffectedNestDepthWithReason(spec, version)};
+    const parser::OmpDirectiveSpecification &spec, unsigned version,
+    SemanticsContext *semaCtx) {
+  auto [depth, _]{GetAffectedNestDepthWithReason(spec, version, semaCtx)};
   if (!depth) {
     return {};
   }
@@ -1195,8 +1212,8 @@ static_assert(HasSourceT<parser::ExecutionPartConstruct>::value);
 #endif // EXPENSIVE_CHECKS
 
 LoopSequence::LoopSequence(const parser::ExecutionPartConstruct &root,
-    unsigned version, bool allowAllLoops)
-    : version_(version), allowAllLoops_(allowAllLoops) {
+    unsigned version, bool allowAllLoops, SemanticsContext *semaCtx)
+    : version_(version), allowAllLoops_(allowAllLoops), semaCtx_(semaCtx) {
   entry_ = createConstructEntry(root);
   assert(entry_ && "Expecting loop like code");
 
@@ -1204,10 +1221,10 @@ LoopSequence::LoopSequence(const parser::ExecutionPartConstruct &root,
   precalculate();
 }
 
-LoopSequence::LoopSequence(
-    std::unique_ptr<Construct> entry, unsigned version, bool allowAllLoops)
+LoopSequence::LoopSequence(std::unique_ptr<Construct> entry, unsigned version,
+    bool allowAllLoops, SemanticsContext *semaCtx)
     : version_(version), allowAllLoops_(allowAllLoops),
-      entry_(std::move(entry)) {
+      entry_(std::move(entry)), semaCtx_(semaCtx) {
   createChildrenFromRange(entry_->location);
   precalculate();
 }
@@ -1240,7 +1257,7 @@ void LoopSequence::createChildrenFromRange(
   for (auto &code : BlockRange(begin, end, BlockRange::Step::Over)) {
     if (auto entry{createConstructEntry(code)}) {
       children_.push_back(
-          LoopSequence(std::move(entry), version_, allowAllLoops_));
+          LoopSequence(std::move(entry), version_, allowAllLoops_, semaCtx_));
       // Even when DO WHILE et al are allowed to have entries, still treat
       // them as invalid intervening code.
       // Give it priority over other kinds of invalid interveninig code.
@@ -1346,7 +1363,7 @@ WithReason<int64_t> LoopSequence::calculateLength() const {
     }
 
     auto *loopRange{parser::Unwrap<parser::OmpLooprangeClause>(*clause)};
-    std::optional<int64_t> count{GetIntValue(std::get<1>(loopRange->t))};
+    auto count{GetIntValueFromExpr(std::get<1>(loopRange->t), semaCtx_)};
     if (!count || *count <= 0) {
       return {};
     }
@@ -1456,11 +1473,12 @@ LoopSequence::Depth LoopSequence::calculateDepths() const {
     if (auto *clause{parser::omp::FindClause(
             beginSpec, llvm::omp::Clause::OMPC_depth)}) {
       auto &expr{parser::UnwrapRef<parser::Expr>(clause->u)};
-      auto value{GetIntValue(expr)};
+      auto value{GetIntValueFromExpr(expr, semaCtx_)};
       // The result is a perfect nest only if all loop in the sequence
       // are fused.
       if (value && nestedLength.value) {
-        auto range{GetAffectedLoopRangeWithReason(beginSpec, version_)};
+        auto range{
+            GetAffectedLoopRangeWithReason(beginSpec, version_, semaCtx_)};
         if (auto required{GetRequiredCount(range.value)}) {
           if (*required == -1 || *required == *nestedLength.value) {
             return Depth{value, value};
@@ -1508,7 +1526,7 @@ LoopSequence::Depth LoopSequence::calculateDepths() const {
             beginSpec, llvm::omp::Clause::OMPC_partial)}) {
       std::optional<int64_t> factor;
       if (auto *expr{parser::Unwrap<parser::Expr>(clause->u)}) {
-        factor = GetIntValue(*expr);
+        factor = GetIntValueFromExpr(*expr, semaCtx_);
       }
       // If it's a partial unroll, and the unroll count is 1, then this
       // construct is a no-op.
@@ -1555,7 +1573,7 @@ WithReason<int64_t> LoopSequence::calculateHeight() const {
   if (auto *omp{parser::Unwrap<parser::OpenMPLoopConstruct>(*entry_->owner)}) {
     const parser::OmpDirectiveSpecification &beginSpec{omp->BeginDir()};
     if (IsLoopTransforming(beginSpec.DirId())) {
-      return GetHeightWithReason(beginSpec, version_);
+      return GetHeightWithReason(beginSpec, version_, semaCtx_);
     }
     return {0, Reason()};
   }

``````````

</details>


https://github.com/llvm/llvm-project/pull/191231


More information about the flang-commits mailing list