[flang-commits] [flang] [flang][openacc] Check trip count invariance with other IVs (PR #79906)

Valentin Clement バレンタイン クレメン via flang-commits flang-commits at lists.llvm.org
Mon Jan 29 14:18:49 PST 2024


https://github.com/clementval updated https://github.com/llvm/llvm-project/pull/79906

>From 662f1436fef0ba19fe1d6d538b879621b6870700 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Thu, 25 Jan 2024 15:15:22 -0800
Subject: [PATCH 1/2] [flang][openacc] Check trip count invariance with other
 IVs

OpenACC specification, section 2.9.1: "The trip count for all loops associated
with the collapse clause must be computable and invariant in all the loops."

This patch checks that the bounds of the loops that are part of a collapse nest
do not depend on the induction variables of the enclosing outer loops.

The check is also applied to combined constructs that contain a loop.
---
 flang/lib/Semantics/resolve-directives.cpp | 51 +++++++++++++++++++---
 flang/test/Semantics/OpenACC/acc-loop.f90  | 40 ++++++++++++++++-
 2 files changed, 83 insertions(+), 8 deletions(-)

diff --git a/flang/lib/Semantics/resolve-directives.cpp b/flang/lib/Semantics/resolve-directives.cpp
index e19f68eefa28672..37821b874170f2a 100644
--- a/flang/lib/Semantics/resolve-directives.cpp
+++ b/flang/lib/Semantics/resolve-directives.cpp
@@ -13,6 +13,7 @@
 #include "resolve-names-utils.h"
 #include "flang/Common/idioms.h"
 #include "flang/Evaluate/fold.h"
+#include "flang/Evaluate/tools.h"
 #include "flang/Evaluate/type.h"
 #include "flang/Parser/parse-tree-visitor.h"
 #include "flang/Parser/parse-tree.h"
@@ -266,7 +267,7 @@ class AccAttributeVisitor : DirectiveAttributeVisitor<llvm::acc::Directive> {
       Symbol::Flag::AccDevicePtr, Symbol::Flag::AccDeviceResident,
       Symbol::Flag::AccLink, Symbol::Flag::AccPresent};
 
-  void CheckAssociatedLoopIndex(const parser::OpenACCLoopConstruct &);
+  void CheckAssociatedLoop(const parser::DoConstruct &);
   void ResolveAccObjectList(const parser::AccObjectList &, Symbol::Flag);
   void ResolveAccObject(const parser::AccObject &, Symbol::Flag);
   Symbol *ResolveAcc(const parser::Name &, Symbol::Flag, Scope &);
@@ -882,7 +883,8 @@ bool AccAttributeVisitor::Pre(const parser::OpenACCLoopConstruct &x) {
   }
   ClearDataSharingAttributeObjects();
   SetContextAssociatedLoopLevel(GetAssociatedLoopLevelFromClauses(clauseList));
-  CheckAssociatedLoopIndex(x);
+  const auto &outer{std::get<std::optional<parser::DoConstruct>>(x.t)};
+  CheckAssociatedLoop(*outer);
   return true;
 }
 
@@ -1087,6 +1089,10 @@ bool AccAttributeVisitor::Pre(const parser::OpenACCCombinedConstruct &x) {
   default:
     break;
   }
+  const auto &clauseList{std::get<parser::AccClauseList>(beginBlockDir.t)};
+  SetContextAssociatedLoopLevel(GetAssociatedLoopLevelFromClauses(clauseList));
+  const auto &outer{std::get<std::optional<parser::DoConstruct>>(x.t)};
+  CheckAssociatedLoop(*outer);
   ClearDataSharingAttributeObjects();
   return true;
 }
@@ -1218,8 +1224,8 @@ std::int64_t AccAttributeVisitor::GetAssociatedLoopLevelFromClauses(
   return 1; // default is outermost loop
 }
 
-void AccAttributeVisitor::CheckAssociatedLoopIndex(
-    const parser::OpenACCLoopConstruct &x) {
+void AccAttributeVisitor::CheckAssociatedLoop(
+    const parser::DoConstruct &outerDoConstruct) {
   std::int64_t level{GetContext().associatedLoopLevel};
   if (level <= 0) { // collapse value was negative or 0
     return;
@@ -1250,10 +1256,41 @@ void AccAttributeVisitor::CheckAssociatedLoopIndex(
     return nullptr;
   };
 
-  const auto &outer{std::get<std::optional<parser::DoConstruct>>(x.t)};
-  for (const parser::DoConstruct *loop{&*outer}; loop && level > 0;) {
+  auto checkExprHasSymbols = [&](llvm::SmallVector<Symbol *> &ivs,
+                                 semantics::UnorderedSymbolSet &symbols) {
+    for (auto iv : ivs) {
+      if (symbols.count(*iv) != 0) {
+        context_.Say(GetContext().directiveSource,
+            "Trip count must be computable and invariant"_err_en_US);
+      }
+    }
+  };
+
+  Symbol::Flag flag;
+  llvm::SmallVector<Symbol*> ivs;
+  using Bounds = parser::LoopControl::Bounds;
+  for (const parser::DoConstruct *loop{&outerDoConstruct}; loop && level > 0;) {
     // Go through all nested loops to ensure index variable exists.
-    GetLoopIndex(*loop);
+    if (const parser::Name *ivName{GetLoopIndex(*loop)}) {
+      if (auto *symbol{ResolveAcc(*ivName, flag, currScope())}) {
+        if (auto &control = loop->GetLoopControl()) {
+          if (const Bounds * b{std::get_if<Bounds>(&control->u)}) {
+            if (auto lowerExpr = semantics::AnalyzeExpr(context_, b->lower)) {
+              semantics::UnorderedSymbolSet lowerSyms =
+                  evaluate::CollectSymbols(*lowerExpr);
+              checkExprHasSymbols(ivs, lowerSyms);
+            }
+            if (auto upperExpr = semantics::AnalyzeExpr(context_, b->upper)) {
+              semantics::UnorderedSymbolSet upperSyms =
+                  evaluate::CollectSymbols(*upperExpr);
+              checkExprHasSymbols(ivs, upperSyms);
+            }
+          }
+        }
+        ivs.push_back(symbol);
+      }
+    }
+
     const auto &block{std::get<parser::Block>(loop->t)};
     --level;
     loop = getNextDoConstruct(block, level);
diff --git a/flang/test/Semantics/OpenACC/acc-loop.f90 b/flang/test/Semantics/OpenACC/acc-loop.f90
index fde836852c51ed5..859cf3feec0d670 100644
--- a/flang/test/Semantics/OpenACC/acc-loop.f90
+++ b/flang/test/Semantics/OpenACC/acc-loop.f90
@@ -10,9 +10,10 @@ program openacc_loop_validity
   type atype
     real(8), dimension(10) :: arr
     real(8) :: s
+    integer :: n
   end type atype
 
-  integer :: i, j, b, gang_size, vector_size, worker_size
+  integer :: i, j, k, b, gang_size, vector_size, worker_size
   integer, parameter :: N = 256
   integer, dimension(N) :: c
   logical, dimension(N) :: d, e
@@ -317,4 +318,41 @@ program openacc_loop_validity
     END DO
   END DO
 
+  !ERROR: Trip count must be computable and invariant
+  !$acc loop collapse(2)
+  DO i = 1, n
+    DO j = 1, c(i)
+    END DO
+  END DO
+
+  !ERROR: Trip count must be computable and invariant
+  !$acc loop collapse(2)
+  DO i = 1, n
+    DO j = 1, i
+    END DO
+  END DO
+
+  !ERROR: Trip count must be computable and invariant
+  !$acc loop collapse(2)
+  DO i = 1, n
+    DO j = 1, ta(i)%n
+    END DO
+  END DO
+
+  !ERROR: Trip count must be computable and invariant
+  !$acc parallel loop collapse(2)
+  DO i = 1, n
+    DO j = 1, ta(i)%n
+    END DO
+  END DO
+
+  !ERROR: Trip count must be computable and invariant
+  !$acc loop collapse(3)
+  DO i = 1, n
+    DO j = 1, n
+      DO k = 1, i
+      END DO
+    END DO
+  END DO
+
 end program openacc_loop_validity

>From 634988a173a1f9dc1722b2e8797d86f7c71cdfde Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Mon, 29 Jan 2024 14:18:38 -0800
Subject: [PATCH 2/2] clang-format

---
 flang/lib/Semantics/resolve-directives.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/flang/lib/Semantics/resolve-directives.cpp b/flang/lib/Semantics/resolve-directives.cpp
index 37821b874170f2a..dd9d79466e98db0 100644
--- a/flang/lib/Semantics/resolve-directives.cpp
+++ b/flang/lib/Semantics/resolve-directives.cpp
@@ -1267,11 +1267,11 @@ void AccAttributeVisitor::CheckAssociatedLoop(
   };
 
   Symbol::Flag flag;
-  llvm::SmallVector<Symbol*> ivs;
+  llvm::SmallVector<Symbol *> ivs;
   using Bounds = parser::LoopControl::Bounds;
   for (const parser::DoConstruct *loop{&outerDoConstruct}; loop && level > 0;) {
     // Go through all nested loops to ensure index variable exists.
-    if (const parser::Name *ivName{GetLoopIndex(*loop)}) {
+    if (const parser::Name * ivName{GetLoopIndex(*loop)}) {
       if (auto *symbol{ResolveAcc(*ivName, flag, currScope())}) {
         if (auto &control = loop->GetLoopControl()) {
           if (const Bounds * b{std::get_if<Bounds>(&control->u)}) {



More information about the flang-commits mailing list