[flang-commits] [flang] [flang][openacc] Check trip count invariance with other IVs (PR #79906)

Valentin Clement バレンタイン クレメン via flang-commits flang-commits at lists.llvm.org
Tue Jan 30 11:15:06 PST 2024


Valentin Clement (バレンタイン クレメン)
Message-ID:
In-Reply-To: <llvm.org/llvm/llvm-project/pull/79906 at github.com>


https://github.com/clementval updated https://github.com/llvm/llvm-project/pull/79906

>From 662f1436fef0ba19fe1d6d538b879621b6870700 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Thu, 25 Jan 2024 15:15:22 -0800
Subject: [PATCH 1/4] [flang][openacc] Check trip count invariance with other
 IVs

OpenACC specification, section 2.9.1: "The trip count for all loops associated
with the collapse clause must be computable and invariant in all the loops."

This patch checks that loops that are part of a collapse nest do not have lower
or upper bounds that depend on the induction variables of enclosing loops in
the nest.

The check is also applied to combined constructs that include a loop (e.g.
`!$acc parallel loop`).
---
 flang/lib/Semantics/resolve-directives.cpp | 51 +++++++++++++++++++---
 flang/test/Semantics/OpenACC/acc-loop.f90  | 40 ++++++++++++++++-
 2 files changed, 83 insertions(+), 8 deletions(-)

diff --git a/flang/lib/Semantics/resolve-directives.cpp b/flang/lib/Semantics/resolve-directives.cpp
index e19f68eefa28..37821b874170 100644
--- a/flang/lib/Semantics/resolve-directives.cpp
+++ b/flang/lib/Semantics/resolve-directives.cpp
@@ -13,6 +13,7 @@
 #include "resolve-names-utils.h"
 #include "flang/Common/idioms.h"
 #include "flang/Evaluate/fold.h"
+#include "flang/Evaluate/tools.h"
 #include "flang/Evaluate/type.h"
 #include "flang/Parser/parse-tree-visitor.h"
 #include "flang/Parser/parse-tree.h"
@@ -266,7 +267,7 @@ class AccAttributeVisitor : DirectiveAttributeVisitor<llvm::acc::Directive> {
       Symbol::Flag::AccDevicePtr, Symbol::Flag::AccDeviceResident,
       Symbol::Flag::AccLink, Symbol::Flag::AccPresent};
 
-  void CheckAssociatedLoopIndex(const parser::OpenACCLoopConstruct &);
+  void CheckAssociatedLoop(const parser::DoConstruct &);
   void ResolveAccObjectList(const parser::AccObjectList &, Symbol::Flag);
   void ResolveAccObject(const parser::AccObject &, Symbol::Flag);
   Symbol *ResolveAcc(const parser::Name &, Symbol::Flag, Scope &);
@@ -882,7 +883,8 @@ bool AccAttributeVisitor::Pre(const parser::OpenACCLoopConstruct &x) {
   }
   ClearDataSharingAttributeObjects();
   SetContextAssociatedLoopLevel(GetAssociatedLoopLevelFromClauses(clauseList));
-  CheckAssociatedLoopIndex(x);
+  const auto &outer{std::get<std::optional<parser::DoConstruct>>(x.t)};
+  CheckAssociatedLoop(*outer);
   return true;
 }
 
@@ -1087,6 +1089,10 @@ bool AccAttributeVisitor::Pre(const parser::OpenACCCombinedConstruct &x) {
   default:
     break;
   }
+  const auto &clauseList{std::get<parser::AccClauseList>(beginBlockDir.t)};
+  SetContextAssociatedLoopLevel(GetAssociatedLoopLevelFromClauses(clauseList));
+  const auto &outer{std::get<std::optional<parser::DoConstruct>>(x.t)};
+  CheckAssociatedLoop(*outer);
   ClearDataSharingAttributeObjects();
   return true;
 }
@@ -1218,8 +1224,8 @@ std::int64_t AccAttributeVisitor::GetAssociatedLoopLevelFromClauses(
   return 1; // default is outermost loop
 }
 
-void AccAttributeVisitor::CheckAssociatedLoopIndex(
-    const parser::OpenACCLoopConstruct &x) {
+void AccAttributeVisitor::CheckAssociatedLoop(
+    const parser::DoConstruct &outerDoConstruct) {
   std::int64_t level{GetContext().associatedLoopLevel};
   if (level <= 0) { // collapse value was negative or 0
     return;
@@ -1250,10 +1256,41 @@ void AccAttributeVisitor::CheckAssociatedLoopIndex(
     return nullptr;
   };
 
-  const auto &outer{std::get<std::optional<parser::DoConstruct>>(x.t)};
-  for (const parser::DoConstruct *loop{&*outer}; loop && level > 0;) {
+  auto checkExprHasSymbols = [&](llvm::SmallVector<Symbol *> &ivs,
+                                 semantics::UnorderedSymbolSet &symbols) {
+    for (auto iv : ivs) {
+      if (symbols.count(*iv) != 0) {
+        context_.Say(GetContext().directiveSource,
+            "Trip count must be computable and invariant"_err_en_US);
+      }
+    }
+  };
+
+  Symbol::Flag flag;
+  llvm::SmallVector<Symbol*> ivs;
+  using Bounds = parser::LoopControl::Bounds;
+  for (const parser::DoConstruct *loop{&outerDoConstruct}; loop && level > 0;) {
     // Go through all nested loops to ensure index variable exists.
-    GetLoopIndex(*loop);
+    if (const parser::Name *ivName{GetLoopIndex(*loop)}) {
+      if (auto *symbol{ResolveAcc(*ivName, flag, currScope())}) {
+        if (auto &control = loop->GetLoopControl()) {
+          if (const Bounds * b{std::get_if<Bounds>(&control->u)}) {
+            if (auto lowerExpr = semantics::AnalyzeExpr(context_, b->lower)) {
+              semantics::UnorderedSymbolSet lowerSyms =
+                  evaluate::CollectSymbols(*lowerExpr);
+              checkExprHasSymbols(ivs, lowerSyms);
+            }
+            if (auto upperExpr = semantics::AnalyzeExpr(context_, b->upper)) {
+              semantics::UnorderedSymbolSet upperSyms =
+                  evaluate::CollectSymbols(*upperExpr);
+              checkExprHasSymbols(ivs, upperSyms);
+            }
+          }
+        }
+        ivs.push_back(symbol);
+      }
+    }
+
     const auto &block{std::get<parser::Block>(loop->t)};
     --level;
     loop = getNextDoConstruct(block, level);
diff --git a/flang/test/Semantics/OpenACC/acc-loop.f90 b/flang/test/Semantics/OpenACC/acc-loop.f90
index fde836852c51..859cf3feec0d 100644
--- a/flang/test/Semantics/OpenACC/acc-loop.f90
+++ b/flang/test/Semantics/OpenACC/acc-loop.f90
@@ -10,9 +10,10 @@ program openacc_loop_validity
   type atype
     real(8), dimension(10) :: arr
     real(8) :: s
+    integer :: n
   end type atype
 
-  integer :: i, j, b, gang_size, vector_size, worker_size
+  integer :: i, j, k, b, gang_size, vector_size, worker_size
   integer, parameter :: N = 256
   integer, dimension(N) :: c
   logical, dimension(N) :: d, e
@@ -317,4 +318,41 @@ program openacc_loop_validity
     END DO
   END DO
 
+  !ERROR: Trip count must be computable and invariant
+  !$acc loop collapse(2)
+  DO i = 1, n
+    DO j = 1, c(i)
+    END DO
+  END DO
+
+  !ERROR: Trip count must be computable and invariant
+  !$acc loop collapse(2)
+  DO i = 1, n
+    DO j = 1, i
+    END DO
+  END DO
+
+  !ERROR: Trip count must be computable and invariant
+  !$acc loop collapse(2)
+  DO i = 1, n
+    DO j = 1, ta(i)%n
+    END DO
+  END DO
+
+  !ERROR: Trip count must be computable and invariant
+  !$acc parallel loop collapse(2)
+  DO i = 1, n
+    DO j = 1, ta(i)%n
+    END DO
+  END DO
+
+  !ERROR: Trip count must be computable and invariant
+  !$acc loop collapse(3)
+  DO i = 1, n
+    DO j = 1, n
+      DO k = 1, i
+      END DO
+    END DO
+  END DO
+
 end program openacc_loop_validity

>From 634988a173a1f9dc1722b2e8797d86f7c71cdfde Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Mon, 29 Jan 2024 14:18:38 -0800
Subject: [PATCH 2/4] clang-format

---
 flang/lib/Semantics/resolve-directives.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/flang/lib/Semantics/resolve-directives.cpp b/flang/lib/Semantics/resolve-directives.cpp
index 37821b874170..dd9d79466e98 100644
--- a/flang/lib/Semantics/resolve-directives.cpp
+++ b/flang/lib/Semantics/resolve-directives.cpp
@@ -1267,11 +1267,11 @@ void AccAttributeVisitor::CheckAssociatedLoop(
   };
 
   Symbol::Flag flag;
-  llvm::SmallVector<Symbol*> ivs;
+  llvm::SmallVector<Symbol *> ivs;
   using Bounds = parser::LoopControl::Bounds;
   for (const parser::DoConstruct *loop{&outerDoConstruct}; loop && level > 0;) {
     // Go through all nested loops to ensure index variable exists.
-    if (const parser::Name *ivName{GetLoopIndex(*loop)}) {
+    if (const parser::Name * ivName{GetLoopIndex(*loop)}) {
       if (auto *symbol{ResolveAcc(*ivName, flag, currScope())}) {
         if (auto &control = loop->GetLoopControl()) {
           if (const Bounds * b{std::get_if<Bounds>(&control->u)}) {

>From 9e64db21a3a0f4630ec53b78bb35bf441243841f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Valentin=20Clement=20=28=E3=83=90=E3=83=AC=E3=83=B3?=
 =?UTF-8?q?=E3=82=BF=E3=82=A4=E3=83=B3=20=E3=82=AF=E3=83=AC=E3=83=A1?=
 =?UTF-8?q?=E3=83=B3=29?= <clementval at gmail.com>
Date: Tue, 30 Jan 2024 11:14:30 -0800
Subject: [PATCH 3/4] Update flang/lib/Semantics/resolve-directives.cpp

---
 flang/lib/Semantics/resolve-directives.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/flang/lib/Semantics/resolve-directives.cpp b/flang/lib/Semantics/resolve-directives.cpp
index dd9d79466e98..a3c7a946985d 100644
--- a/flang/lib/Semantics/resolve-directives.cpp
+++ b/flang/lib/Semantics/resolve-directives.cpp
@@ -1273,7 +1273,7 @@ void AccAttributeVisitor::CheckAssociatedLoop(
     // Go through all nested loops to ensure index variable exists.
     if (const parser::Name * ivName{GetLoopIndex(*loop)}) {
       if (auto *symbol{ResolveAcc(*ivName, flag, currScope())}) {
-        if (auto &control = loop->GetLoopControl()) {
+        if (auto &control{loop->GetLoopControl()}) {
           if (const Bounds * b{std::get_if<Bounds>(&control->u)}) {
             if (auto lowerExpr = semantics::AnalyzeExpr(context_, b->lower)) {
               semantics::UnorderedSymbolSet lowerSyms =

>From 3061c8a95edd3c9503e479ce7f52cd30a8298428 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Valentin=20Clement=20=28=E3=83=90=E3=83=AC=E3=83=B3?=
 =?UTF-8?q?=E3=82=BF=E3=82=A4=E3=83=B3=20=E3=82=AF=E3=83=AC=E3=83=A1?=
 =?UTF-8?q?=E3=83=B3=29?= <clementval at gmail.com>
Date: Tue, 30 Jan 2024 11:14:57 -0800
Subject: [PATCH 4/4] Use brace init

---
 flang/lib/Semantics/resolve-directives.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/flang/lib/Semantics/resolve-directives.cpp b/flang/lib/Semantics/resolve-directives.cpp
index a3c7a946985d..e9c7fb2c5ce1 100644
--- a/flang/lib/Semantics/resolve-directives.cpp
+++ b/flang/lib/Semantics/resolve-directives.cpp
@@ -1275,7 +1275,7 @@ void AccAttributeVisitor::CheckAssociatedLoop(
       if (auto *symbol{ResolveAcc(*ivName, flag, currScope())}) {
         if (auto &control{loop->GetLoopControl()}) {
           if (const Bounds * b{std::get_if<Bounds>(&control->u)}) {
-            if (auto lowerExpr = semantics::AnalyzeExpr(context_, b->lower)) {
+            if (auto lowerExpr{semantics::AnalyzeExpr(context_, b->lower)}) {
               semantics::UnorderedSymbolSet lowerSyms =
                   evaluate::CollectSymbols(*lowerExpr);
               checkExprHasSymbols(ivs, lowerSyms);



More information about the flang-commits mailing list