[flang-commits] [flang] [flang] Do not traverse selectors in FindImpureCall and HasVectorSubscript (PR #84041)

via flang-commits flang-commits at lists.llvm.org
Tue Mar 5 09:05:25 PST 2024


https://github.com/jeanPerier created https://github.com/llvm/llvm-project/pull/84041

In the presence of symbols with AssocEntityDetails in an expression, `Traverse`, `AnyTraverse`, `AllTraverse`, and `SetTraverse` automatically visit the selector expression or variable.

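This is usually the desired behavior, but it can be surprising, and it was not correct for FindImpureCall and HasVectorSubscript. For example, an impure function referenced in an ASSOCIATE selector is evaluated once, when the construct is entered. It is not referenced again each time the associate name is used inside a FORALL, so it must not be reported as an impure call in that FORALL.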

Add a template parameter, defaulting to true, that makes this behavior explicit to anyone using the Traverse helpers for a new utility, and set it to false for FindImpureCall and HasVectorSubscript.

>From 13aa6c548a78e0a104600ca01db7b24642efedbc Mon Sep 17 00:00:00 2001
From: Jean Perier <jperier at nvidia.com>
Date: Tue, 5 Mar 2024 08:54:07 -0800
Subject: [PATCH] [flang] Do not traverse selectors in FindImpureCall and
 HasVectorSubscript

In the presence of symbols with AssocEntityDetails in an expression,
Traverse, AnyTraverse, AllTraverse, and SetTraverse automatically visit
the selector.

This is usually the desired behavior, but it can be surprising,
and it was not correct for FindImpureCall and HasVectorSubscript.

Add a template parameter, defaulting to true, that makes this behavior
explicit to anyone using the Traverse helpers for a new utility, and set
it to false for FindImpureCall and HasVectorSubscript.
---
 flang/include/flang/Evaluate/traverse.h | 24 +++++++++++++++---------
 flang/lib/Evaluate/tools.cpp            | 11 +++++++----
 flang/test/Semantics/forall01.f90       | 11 +++++++++++
 flang/test/Semantics/selecttype03.f90   |  2 +-
 4 files changed, 34 insertions(+), 14 deletions(-)

diff --git a/flang/include/flang/Evaluate/traverse.h b/flang/include/flang/Evaluate/traverse.h
index 8d75cc2df7247b..7f4a67d97e64e7 100644
--- a/flang/include/flang/Evaluate/traverse.h
+++ b/flang/include/flang/Evaluate/traverse.h
@@ -45,7 +45,9 @@
 #include <type_traits>
 
 namespace Fortran::evaluate {
-template <typename Visitor, typename Result> class Traverse {
+template <typename Visitor, typename Result,
+    bool TraverseAssocEntityDetails = true>
+class Traverse {
 public:
   explicit Traverse(Visitor &v) : visitor_{v} {}
 
@@ -108,12 +110,13 @@ template <typename Visitor, typename Result> class Traverse {
   }
   Result operator()(const Symbol &symbol) const {
     const Symbol &ultimate{symbol.GetUltimate()};
-    if (const auto *assoc{
-            ultimate.detailsIf<semantics::AssocEntityDetails>()}) {
-      return visitor_(assoc->expr());
-    } else {
-      return visitor_.Default();
+    if constexpr (TraverseAssocEntityDetails) {
+      if (const auto *assoc{
+              ultimate.detailsIf<semantics::AssocEntityDetails>()}) {
+        return visitor_(assoc->expr());
+      }
     }
+    return visitor_.Default();
   }
   Result operator()(const StaticDataObject &) const {
     return visitor_.Default();
@@ -284,7 +287,8 @@ template <typename Visitor, typename Result> class Traverse {
 // For validity checks across an expression: if any operator() result is
 // false, so is the overall result.
 template <typename Visitor, bool DefaultValue,
-    typename Base = Traverse<Visitor, bool>>
+    bool TraverseAssocEntityDetails = true,
+    typename Base = Traverse<Visitor, bool, TraverseAssocEntityDetails>>
 struct AllTraverse : public Base {
   explicit AllTraverse(Visitor &v) : Base{v} {}
   using Base::operator();
@@ -296,7 +300,8 @@ struct AllTraverse : public Base {
 // is truthful is the final result.  Works for Booleans, pointers,
 // and std::optional<>.
 template <typename Visitor, typename Result = bool,
-    typename Base = Traverse<Visitor, Result>>
+    bool TraverseAssocEntityDetails = true,
+    typename Base = Traverse<Visitor, Result, TraverseAssocEntityDetails>>
 class AnyTraverse : public Base {
 public:
   explicit AnyTraverse(Visitor &v) : Base{v} {}
@@ -315,7 +320,8 @@ class AnyTraverse : public Base {
 };
 
 template <typename Visitor, typename Set,
-    typename Base = Traverse<Visitor, Set>>
+    bool TraverseAssocEntityDetails = true,
+    typename Base = Traverse<Visitor, Set, TraverseAssocEntityDetails>>
 struct SetTraverse : public Base {
   explicit SetTraverse(Visitor &v) : Base{v} {}
   using Base::operator();
diff --git a/flang/lib/Evaluate/tools.cpp b/flang/lib/Evaluate/tools.cpp
index e7fc651b9173fe..f514a25b010241 100644
--- a/flang/lib/Evaluate/tools.cpp
+++ b/flang/lib/Evaluate/tools.cpp
@@ -995,8 +995,10 @@ template semantics::UnorderedSymbolSet CollectSymbols(
     const Expr<SubscriptInteger> &);
 
 // HasVectorSubscript()
-struct HasVectorSubscriptHelper : public AnyTraverse<HasVectorSubscriptHelper> {
-  using Base = AnyTraverse<HasVectorSubscriptHelper>;
+struct HasVectorSubscriptHelper
+    : public AnyTraverse<HasVectorSubscriptHelper, bool,
+          /*TraverseAssocEntityDetails=*/false> {
+  using Base = AnyTraverse<HasVectorSubscriptHelper, bool, false>;
   HasVectorSubscriptHelper() : Base{*this} {}
   using Base::operator();
   bool operator()(const Subscript &ss) const {
@@ -1045,9 +1047,10 @@ parser::Message *AttachDeclaration(
 }
 
 class FindImpureCallHelper
-    : public AnyTraverse<FindImpureCallHelper, std::optional<std::string>> {
+    : public AnyTraverse<FindImpureCallHelper, std::optional<std::string>,
+          /*TraverseAssocEntityDetails=*/false> {
   using Result = std::optional<std::string>;
-  using Base = AnyTraverse<FindImpureCallHelper, Result>;
+  using Base = AnyTraverse<FindImpureCallHelper, Result, false>;
 
 public:
   explicit FindImpureCallHelper(FoldingContext &c) : Base{*this}, context_{c} {}
diff --git a/flang/test/Semantics/forall01.f90 b/flang/test/Semantics/forall01.f90
index a81eb9621e77c6..72ad9ecd39471a 100644
--- a/flang/test/Semantics/forall01.f90
+++ b/flang/test/Semantics/forall01.f90
@@ -135,3 +135,14 @@ subroutine forall7(x)
     end forall
   end select
 end subroutine
+
+subroutine forall8(x)
+  real :: x(10)
+  real, external :: foo
+  !ERROR: Impure procedure 'foo' may not be referenced in a FORALL
+  forall(i=1:10) x(i) = foo() + i
+  !OK
+  associate(y => foo())
+    forall (i=1:10) x(i) = y + i
+  end associate
+end subroutine
diff --git a/flang/test/Semantics/selecttype03.f90 b/flang/test/Semantics/selecttype03.f90
index eb343c4ccc5300..c440960f404a3e 100644
--- a/flang/test/Semantics/selecttype03.f90
+++ b/flang/test/Semantics/selecttype03.f90
@@ -65,7 +65,7 @@
     b%i  = 1 !VDC
   type is (t2)
     !ERROR: Actual argument associated with INTENT(IN OUT) dummy argument 'z=' is not definable
-    !BECAUSE: Variable 'b' has a vector subscript
+    !BECAUSE: Construct association 'b' has a vector subscript
     call sub_with_in_and_inout_param_vector(b,b) !VDC
 end select
 select type(b =>  foo(1) )



More information about the flang-commits mailing list