[flang-commits] [flang] [flang][OpenMP] Implement `CheckReductionObjects` for all reduction c… (PR #118689)

via flang-commits flang-commits at lists.llvm.org
Wed Dec 4 11:42:26 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-flang-openmp

Author: Krzysztof Parzyszek (kparzysz)

<details>
<summary>Changes</summary>

[flang][OpenMP] Implement `CheckReductionObjects` for all reduction clauses

Currently we only do semantic checks for REDUCTION. There are two other clauses, IN_REDUCTION and TASK_REDUCTION, which will also need those checks. Implement a function that checks the common list-item requirements for all those clauses.

---

Patch is 21.50 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/118689.diff


7 Files Affected:

- (modified) flang/lib/Semantics/check-omp-structure.cpp (+164-82) 
- (modified) flang/lib/Semantics/check-omp-structure.h (+10-4) 
- (modified) flang/test/Semantics/OpenMP/reduction04.f90 (+1-2) 
- (modified) flang/test/Semantics/OpenMP/reduction06.f90 (+3-3) 
- (modified) flang/test/Semantics/OpenMP/reduction12.f90 (+1-1) 
- (modified) flang/test/Semantics/OpenMP/use_device_addr1.f90 (+1-1) 
- (modified) flang/test/Semantics/OpenMP/use_device_ptr1.f90 (+1-1) 


``````````diff
diff --git a/flang/lib/Semantics/check-omp-structure.cpp b/flang/lib/Semantics/check-omp-structure.cpp
index f09542e5b5df14..d4d4a677ce58ff 100644
--- a/flang/lib/Semantics/check-omp-structure.cpp
+++ b/flang/lib/Semantics/check-omp-structure.cpp
@@ -51,6 +51,30 @@ std::string TryVersion(unsigned version) {
   return "try -fopenmp-version=" + std::to_string(version);
 }
 
+static const parser::Designator *GetDesignatorFromObj(
+    const parser::OmpObject &object) {
+  return std::get_if<parser::Designator>(&object.u);
+}
+
+static const parser::DataRef *GetDataRefFromObj(
+    const parser::OmpObject &object) {
+  if (auto *desg{GetDesignatorFromObj(object)}) {
+    return std::get_if<parser::DataRef>(&desg->u);
+  }
+  return nullptr;
+}
+
+static const parser::ArrayElement *GetArrayElementFromObj(
+    const parser::OmpObject &object) {
+  if (auto *dataRef{GetDataRefFromObj(object)}) {
+    using ElementIndirection = common::Indirection<parser::ArrayElement>;
+    if (auto *ind{std::get_if<ElementIndirection>(&dataRef->u)}) {
+      return &ind->value();
+    }
+  }
+  return nullptr;
+}
+
 // 'OmpWorkshareBlockChecker' is used to check the validity of the assignment
 // statements and the expressions enclosed in an OpenMP Workshare construct
 class OmpWorkshareBlockChecker {
@@ -222,6 +246,10 @@ bool OmpStructureChecker::CheckAllowedClause(llvmOmpClause clause) {
   return CheckAllowed(clause);
 }
 
+bool OmpStructureChecker::IsCommonBlock(const Symbol &sym) {
+  return sym.detailsIf<CommonBlockDetails>() != nullptr;
+}
+
 bool OmpStructureChecker::IsVariableListItem(const Symbol &sym) {
   return evaluate::IsVariable(sym) || sym.attrs().test(Attr::POINTER);
 }
@@ -2895,6 +2923,8 @@ void OmpStructureChecker::Enter(const parser::OmpClause::Reduction &x) {
       CheckReductionModifier(*maybeModifier);
     }
   }
+  CheckReductionObjects(std::get<parser::OmpObjectList>(x.v.t),
+      llvm::omp::Clause::OMPC_reduction);
 }
 
 bool OmpStructureChecker::CheckReductionOperators(
@@ -2963,6 +2993,69 @@ bool OmpStructureChecker::CheckIntrinsicOperator(
   return false;
 }
 
+/// Check restrictions on objects that are common to all reduction clauses.
+void OmpStructureChecker::CheckReductionObjects(
+    const parser::OmpObjectList &objects, llvm::omp::Clause clauseId) {
+  unsigned version{context_.langOptions().OpenMPVersion};
+  SymbolSourceMap symbols;
+  GetSymbolsInObjectList(objects, symbols);
+
+  // Array sections must be a contiguous storage, have non-zero length.
+  for (const parser::OmpObject &object : objects.v) {
+    CheckIfContiguous(object);
+  }
+  CheckReductionArraySection(objects);
+  // An object must be definable.
+  CheckDefinableObjects(symbols, clauseId);
+  // Procedure pointers are not allowed.
+  CheckProcedurePointer(symbols, clauseId);
+  // Pointers must not have INTENT(IN).
+  CheckIntentInPointer(symbols, clauseId);
+
+  // Disallow common blocks.
+  // Iterate on objects because `GetSymbolsInObjectList` expands common block
+  // names into the lists of their members.
+  for (const parser::OmpObject &object : objects.v) {
+    auto *symbol{GetObjectSymbol(object)};
+    assert(symbol);
+    if (IsCommonBlock(*symbol)) {
+      auto source{GetObjectSource(object)};
+      context_.Say(source ? *source : GetContext().clauseSource,
+          "Common block names are not allowed in %s clause"_err_en_US,
+          parser::ToUpperCaseLetters(getClauseName(clauseId).str()));
+    }
+  }
+
+  if (version >= 50) {
+    // Object cannot be a part of another object (except array elements)
+    CheckStructureComponent(objects, clauseId);
+    // If object is an array section or element, the base expression must be
+    // a language identifier.
+    for (const parser::OmpObject &object : objects.v) {
+      if (auto *elem{GetArrayElementFromObj(object)}) {
+        const parser::DataRef &base = elem->base;
+        if (!std::holds_alternative<parser::Name>(base.u)) {
+          auto source{GetObjectSource(object)};
+          context_.Say(source ? *source : GetContext().clauseSource,
+              "The base expression of an array element in %s clause must be an identifier"_err_en_US,
+              parser::ToUpperCaseLetters(getClauseName(clauseId).str()));
+        }
+      }
+    }
+    // Type parameter inquiries are not allowed.
+    for (const parser::OmpObject &object : objects.v) {
+      if (auto *dataRef{GetDataRefFromObj(object)}) {
+        if (IsDataRefTypeParamInquiry(dataRef)) {
+          auto source{GetObjectSource(object)};
+          context_.Say(source ? *source : GetContext().clauseSource,
+              "Type parameter inquiry is not permitted in %s clause"_err_en_US,
+              parser::ToUpperCaseLetters(getClauseName(clauseId).str()));
+        }
+      }
+    }
+  }
+}
+
 static bool IsReductionAllowedForType(
     const parser::OmpClause::Reduction &x, const DeclTypeSpec &type) {
   auto &modifiers{OmpGetModifiers(x.v)};
@@ -3052,8 +3145,9 @@ static bool IsReductionAllowedForType(
 void OmpStructureChecker::CheckReductionTypeList(
     const parser::OmpClause::Reduction &x) {
   const auto &ompObjectList{std::get<parser::OmpObjectList>(x.v.t)};
-  CheckIntentInPointerAndDefinable(
-      ompObjectList, llvm::omp::Clause::OMPC_reduction);
+  SymbolSourceMap symbols;
+  GetSymbolsInObjectList(ompObjectList, symbols);
+
   CheckReductionArraySection(ompObjectList);
   // If this is a worksharing construct then ensure the reduction variable
   // is not private in the parallel region that it binds to.
@@ -3061,17 +3155,15 @@ void OmpStructureChecker::CheckReductionTypeList(
     CheckSharedBindingInOuterContext(ompObjectList);
   }
 
-  SymbolSourceMap symbols;
-  GetSymbolsInObjectList(ompObjectList, symbols);
   for (auto &[symbol, source] : symbols) {
-    if (IsProcedurePointer(*symbol)) {
-      context_.Say(source,
-          "A procedure pointer '%s' must not appear in a REDUCTION clause."_err_en_US,
-          symbol->name());
-    } else if (!IsReductionAllowedForType(x, DEREF(symbol->GetType()))) {
-      context_.Say(source,
-          "The type of '%s' is incompatible with the reduction operator."_err_en_US,
-          symbol->name());
+    if (auto *type{symbol->GetType()}) {
+      if (!IsReductionAllowedForType(x, *type)) {
+        context_.Say(source,
+            "The type of '%s' is incompatible with the reduction operator."_err_en_US,
+            symbol->name());
+      }
+    } else {
+      assert(IsProcedurePointer(*symbol) && "Unexpected symbol properties");
     }
   }
 }
@@ -3127,43 +3219,14 @@ void OmpStructureChecker::CheckReductionModifier(
   }
 }
 
-void OmpStructureChecker::CheckIntentInPointerAndDefinable(
-    const parser::OmpObjectList &objectList, const llvm::omp::Clause clause) {
-  for (const auto &ompObject : objectList.v) {
-    if (const auto *name{parser::Unwrap<parser::Name>(ompObject)}) {
-      if (const auto *symbol{name->symbol}) {
-        if (IsPointer(symbol->GetUltimate()) &&
-            IsIntentIn(symbol->GetUltimate())) {
-          context_.Say(GetContext().clauseSource,
-              "Pointer '%s' with the INTENT(IN) attribute may not appear "
-              "in a %s clause"_err_en_US,
-              symbol->name(),
-              parser::ToUpperCaseLetters(getClauseName(clause).str()));
-        } else if (auto msg{WhyNotDefinable(name->source,
-                       context_.FindScope(name->source), DefinabilityFlags{},
-                       *symbol)}) {
-          context_
-              .Say(GetContext().clauseSource,
-                  "Variable '%s' on the %s clause is not definable"_err_en_US,
-                  symbol->name(),
-                  parser::ToUpperCaseLetters(getClauseName(clause).str()))
-              .Attach(std::move(msg->set_severity(parser::Severity::Because)));
-        }
-      }
-    }
-  }
-}
-
 void OmpStructureChecker::CheckReductionArraySection(
     const parser::OmpObjectList &ompObjectList) {
   for (const auto &ompObject : ompObjectList.v) {
     if (const auto *dataRef{parser::Unwrap<parser::DataRef>(ompObject)}) {
       if (const auto *arrayElement{
               parser::Unwrap<parser::ArrayElement>(ompObject)}) {
-        if (arrayElement) {
-          CheckArraySection(*arrayElement, GetLastName(*dataRef),
-              llvm::omp::Clause::OMPC_reduction);
-        }
+        CheckArraySection(*arrayElement, GetLastName(*dataRef),
+            llvm::omp::Clause::OMPC_reduction);
       }
     }
   }
@@ -3232,9 +3295,11 @@ void OmpStructureChecker::Enter(const parser::OmpClause::Shared &x) {
   CheckIsVarPartOfAnotherVar(GetContext().clauseSource, x.v, "SHARED");
 }
 void OmpStructureChecker::Enter(const parser::OmpClause::Private &x) {
+  SymbolSourceMap symbols;
+  GetSymbolsInObjectList(x.v, symbols);
   CheckAllowedClause(llvm::omp::Clause::OMPC_private);
   CheckIsVarPartOfAnotherVar(GetContext().clauseSource, x.v, "PRIVATE");
-  CheckIntentInPointer(x.v, llvm::omp::Clause::OMPC_private);
+  CheckIntentInPointer(symbols, llvm::omp::Clause::OMPC_private);
 }
 
 void OmpStructureChecker::Enter(const parser::OmpClause::Nowait &x) {
@@ -3891,11 +3956,11 @@ void OmpStructureChecker::CheckCopyingPolymorphicAllocatable(
 
 void OmpStructureChecker::Enter(const parser::OmpClause::Copyprivate &x) {
   CheckAllowedClause(llvm::omp::Clause::OMPC_copyprivate);
-  CheckIntentInPointer(x.v, llvm::omp::Clause::OMPC_copyprivate);
-  SymbolSourceMap currSymbols;
-  GetSymbolsInObjectList(x.v, currSymbols);
+  SymbolSourceMap symbols;
+  GetSymbolsInObjectList(x.v, symbols);
+  CheckIntentInPointer(symbols, llvm::omp::Clause::OMPC_copyprivate);
   CheckCopyingPolymorphicAllocatable(
-      currSymbols, llvm::omp::Clause::OMPC_copyprivate);
+      symbols, llvm::omp::Clause::OMPC_copyprivate);
   if (GetContext().directive == llvm::omp::Directive::OMPD_single) {
     context_.Say(GetContext().clauseSource,
         "%s clause is not allowed on the OMP %s directive,"
@@ -3945,29 +4010,26 @@ void OmpStructureChecker::Enter(const parser::OmpClause::Copyin &x) {
       currSymbols, llvm::omp::Clause::OMPC_copyin);
 }
 
-void OmpStructureChecker::CheckStructureElement(
-    const parser::OmpObjectList &ompObjectList,
-    const llvm::omp::Clause clause) {
-  for (const auto &ompObject : ompObjectList.v) {
+void OmpStructureChecker::CheckStructureComponent(
+    const parser::OmpObjectList &objects, llvm::omp::Clause clauseId) {
+  auto CheckComponent{[&](const parser::Designator &designator) {
+    if (auto *desg{std::get_if<parser::DataRef>(&designator.u)}) {
+      if (auto *comp{parser::Unwrap<parser::StructureComponent>(*desg)}) {
+        context_.Say(comp->component.source,
+            "A variable that is part of another variable cannot appear on the %s clause"_err_en_US,
+            parser::ToUpperCaseLetters(getClauseName(clauseId).str()));
+      }
+    }
+  }};
+
+  for (const auto &object : objects.v) {
     common::visit(
         common::visitors{
-            [&](const parser::Designator &designator) {
-              if (std::get_if<parser::DataRef>(&designator.u)) {
-                if (parser::Unwrap<parser::StructureComponent>(ompObject)) {
-                  context_.Say(GetContext().clauseSource,
-                      "A variable that is part of another variable "
-                      "(structure element) cannot appear on the %s "
-                      "%s clause"_err_en_US,
-                      ContextDirectiveAsFortran(),
-                      parser::ToUpperCaseLetters(getClauseName(clause).str()));
-                }
-              }
-            },
+            CheckComponent,
             [&](const parser::Name &name) {},
         },
-        ompObject.u);
+        object.u);
   }
-  return;
 }
 
 void OmpStructureChecker::Enter(const parser::OmpClause::Update &x) {
@@ -4009,7 +4071,7 @@ void OmpStructureChecker::Enter(const parser::OmpClause::Update &x) {
 }
 
 void OmpStructureChecker::Enter(const parser::OmpClause::UseDevicePtr &x) {
-  CheckStructureElement(x.v, llvm::omp::Clause::OMPC_use_device_ptr);
+  CheckStructureComponent(x.v, llvm::omp::Clause::OMPC_use_device_ptr);
   CheckAllowedClause(llvm::omp::Clause::OMPC_use_device_ptr);
   SymbolSourceMap currSymbols;
   GetSymbolsInObjectList(x.v, currSymbols);
@@ -4038,7 +4100,7 @@ void OmpStructureChecker::Enter(const parser::OmpClause::UseDevicePtr &x) {
 }
 
 void OmpStructureChecker::Enter(const parser::OmpClause::UseDeviceAddr &x) {
-  CheckStructureElement(x.v, llvm::omp::Clause::OMPC_use_device_addr);
+  CheckStructureComponent(x.v, llvm::omp::Clause::OMPC_use_device_addr);
   CheckAllowedClause(llvm::omp::Clause::OMPC_use_device_addr);
   SymbolSourceMap currSymbols;
   GetSymbolsInObjectList(x.v, currSymbols);
@@ -4214,6 +4276,26 @@ llvm::StringRef OmpStructureChecker::getDirectiveName(
   return llvm::omp::getOpenMPDirectiveName(directive);
 }
 
+const Symbol *OmpStructureChecker::GetObjectSymbol(
+    const parser::OmpObject &object) {
+  if (auto *name{std::get_if<parser::Name>(&object.u)}) {
+    return &name->symbol->GetUltimate();
+  } else if (auto *desg{std::get_if<parser::Designator>(&object.u)}) {
+    return &GetLastName(*desg).symbol->GetUltimate();
+  }
+  return nullptr;
+}
+
+std::optional<parser::CharBlock> OmpStructureChecker::GetObjectSource(
+    const parser::OmpObject &object) {
+  if (auto *name{std::get_if<parser::Name>(&object.u)}) {
+    return name->source;
+  } else if (auto *desg{std::get_if<parser::Designator>(&object.u)}) {
+    return GetLastName(*desg).source;
+  }
+  return std::nullopt;
+}
+
 void OmpStructureChecker::CheckDependList(const parser::DataRef &d) {
   common::visit(
       common::visitors{
@@ -4267,15 +4349,6 @@ void OmpStructureChecker::CheckArraySection(
                       "DEPEND "
                       "clause"_err_en_US);
                 }
-                const auto stride{GetIntValue(strideExpr)};
-                if ((stride && stride != 1)) {
-                  context_.Say(GetContext().clauseSource,
-                      "A list item that appears in a REDUCTION clause"
-                      " should have a contiguous storage array "
-                      "section."_err_en_US,
-                      ContextDirectiveAsFortran());
-                  break;
-                }
               }
             }
           }
@@ -4286,14 +4359,23 @@ void OmpStructureChecker::CheckArraySection(
 }
 
 void OmpStructureChecker::CheckIntentInPointer(
-    const parser::OmpObjectList &objectList, const llvm::omp::Clause clause) {
-  SymbolSourceMap symbols;
-  GetSymbolsInObjectList(objectList, symbols);
+    SymbolSourceMap &symbols, llvm::omp::Clause clauseId) {
   for (auto &[symbol, source] : symbols) {
     if (IsPointer(*symbol) && IsIntentIn(*symbol)) {
       context_.Say(source,
-          "Pointer '%s' with the INTENT(IN) attribute may not appear "
-          "in a %s clause"_err_en_US,
+          "Pointer '%s' with the INTENT(IN) attribute may not appear in a %s clause"_err_en_US,
+          symbol->name(),
+          parser::ToUpperCaseLetters(getClauseName(clauseId).str()));
+    }
+  }
+}
+
+void OmpStructureChecker::CheckProcedurePointer(
+    SymbolSourceMap &symbols, llvm::omp::Clause clause) {
+  for (const auto &[symbol, source] : symbols) {
+    if (IsProcedurePointer(*symbol)) {
+      context_.Say(source,
+          "Procedure pointer '%s' may not appear in a %s clause"_err_en_US,
           symbol->name(),
           parser::ToUpperCaseLetters(getClauseName(clause).str()));
     }
diff --git a/flang/lib/Semantics/check-omp-structure.h b/flang/lib/Semantics/check-omp-structure.h
index 1411a9271d4665..3d4a30b994a70e 100644
--- a/flang/lib/Semantics/check-omp-structure.h
+++ b/flang/lib/Semantics/check-omp-structure.h
@@ -146,13 +146,14 @@ class OmpStructureChecker
   bool CheckAllowedClause(llvmOmpClause clause);
   bool IsVariableListItem(const Symbol &sym);
   bool IsExtendedListItem(const Symbol &sym);
+  bool IsCommonBlock(const Symbol &sym);
   std::optional<bool> IsContiguous(const parser::OmpObject &object);
   void CheckMultipleOccurrence(semantics::UnorderedSymbolSet &listVars,
       const std::list<parser::Name> &nameList, const parser::CharBlock &item,
       const std::string &clauseName);
   void CheckMultListItems();
-  void CheckStructureElement(const parser::OmpObjectList &ompObjectList,
-      const llvm::omp::Clause clause);
+  void CheckStructureComponent(const parser::OmpObjectList &objects,
+      llvm::omp::Clause clauseId);
   bool HasInvalidWorksharingNesting(
       const parser::CharBlock &, const OmpDirectiveSet &);
   bool IsCloselyNestedRegion(const OmpDirectiveSet &set);
@@ -171,6 +172,9 @@ class OmpStructureChecker
       typename IterTy = decltype(std::declval<RangeTy>().begin())>
   std::optional<IterTy> FindDuplicate(RangeTy &&);
 
+  const Symbol *GetObjectSymbol(const parser::OmpObject &object);
+  std::optional<parser::CharBlock> GetObjectSource(
+      const parser::OmpObject &object);
   void CheckDependList(const parser::DataRef &);
   void CheckDependArraySection(
       const common::Indirection<parser::ArrayElement> &, const parser::Name &);
@@ -182,8 +186,8 @@ class OmpStructureChecker
       const parser::OmpObjectList &objList);
   void CheckSymbolNames(
       const parser::CharBlock &source, const parser::OmpObjectList &objList);
-  void CheckIntentInPointer(
-      const parser::OmpObjectList &, const llvm::omp::Clause);
+  void CheckIntentInPointer(SymbolSourceMap &, const llvm::omp::Clause);
+  void CheckProcedurePointer(SymbolSourceMap &, const llvm::omp::Clause);
   void GetSymbolsInObjectList(const parser::OmpObjectList &, SymbolSourceMap &);
   void CheckDefinableObjects(SymbolSourceMap &, const llvm::omp::Clause);
   void CheckCopyingPolymorphicAllocatable(
@@ -220,6 +224,8 @@ class OmpStructureChecker
   void CheckCancellationNest(
       const parser::CharBlock &source, const parser::OmpCancelType::Type &type);
   std::int64_t GetOrdCollapseLevel(const parser::OpenMPLoopConstruct &x);
+  void CheckReductionObjects(const parser::OmpObjectList &objects,
+      llvm::omp::Clause clauseId);
   bool CheckReductionOperators(const parser::OmpClause::Reduction &);
   bool CheckIntrinsicOperator(
       const parser::DefinedOperator::IntrinsicOperator &);
diff --git a/flang/test/Semantics/OpenMP/reduction04.f90 b/flang/test/Semantics/OpenMP/reduction04.f90
index 319ed9f245abe8..e86b67b6efa961 100644
--- a/flang/test/Semantics/OpenMP/reduction04.f90
+++ b/flang/test/Semantics/OpenMP/reduction04.f90
@@ -14,8 +14,7 @@ program omp_Reduction
   end do
   !$omp end parallel do
 
-  !ERROR: Variable 'c' on the REDUCTION clause is not definable
-  !BECAUSE: 'c' is not a variable
+  !ERROR: Common block names are not allowed in REDUCTION clause
   !$omp parallel do reduction(*:/c/)
   do i = 1, 10
     l = k + 1
diff --git a/flang/test/Semantics/OpenMP/reduction06.f90 b/flang/test/Semantics/OpenMP/reduction06.f90
index 58290c61cae860..7d2a37310a8893 100644
--- a/flang/test/Semantics/OpenMP/reduction06.f90
+++ b/flang/test/Semantics/OpenMP/reduction06.f90
@@ -8,21 +8,21 @@ program omp_reduction
   integer :: k = 10
   integer :: a(10), b(10,10,10)
 
-  !ERROR: A list item that appears in a REDUCTION clause should have a contiguous storage array section.
+  !ERROR: Reference to a must be a contiguous object
   !$omp parallel do reduction(+:a(1:10:3))
   do i = 1, 10
     k = k + 1
   end do
   !$omp end parallel do
 
-  !ERROR: A list item that appears in a REDUCTION clause should have a contiguous storage array section.
+  !ERROR: Reference to b must be a contiguous object
   !$omp parallel do reduction(+:b(1:10:3,1:8:1,1:5:1))
   do i = 1, 10
     k = k + 1
   end do
   !$omp end parallel do
 
-  !ERROR: A list item that appears in a REDUCTION clause should have a contiguous storage array section.
+  !ERROR: Reference to b must be a contiguous object
   !$omp parallel do reduction(+:b(1:10:1,1:8:2,1:5:1))
   do i = 1, 10
     k = k + 1
diff --git a/flang/test/Semantics/OpenMP/reduction12.f90 b/flang/test/Semantics/OpenMP/reduction12.f90
index f896ca4aa60b67..f7e3baf0465624 100644
--- a/flang/test/Semantics/OpenMP/reduction12.f90
+++ b/flang/test/Semantics/OpenMP/reduction12.f90
@@ -6,7 +6,7 @@
   procedure(foo), pointer :: ptr
   integer :: i
   ptr => foo
-!ERROR: A procedure pointer 'ptr' must not appear in a REDUCTION clause.
+!ERROR: Proce...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/118689


More information about the flang-commits mailing list