[flang-commits] [flang] [flang][OpenMP] Implement `CheckReductionObjects` for all reduction clauses (PR #118689)
via flang-commits
flang-commits at lists.llvm.org
Wed Dec 4 11:42:26 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-openmp
Author: Krzysztof Parzyszek (kparzysz)
<details>
<summary>Changes</summary>
…lauses
Currently we only do semantic checks for REDUCTION. There are two other clauses, IN_REDUCTION and TASK_REDUCTION, which will also need those checks. Implement a function that checks the common list-item requirements for all of those clauses.
---
Patch is 21.50 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/118689.diff
7 Files Affected:
- (modified) flang/lib/Semantics/check-omp-structure.cpp (+164-82)
- (modified) flang/lib/Semantics/check-omp-structure.h (+10-4)
- (modified) flang/test/Semantics/OpenMP/reduction04.f90 (+1-2)
- (modified) flang/test/Semantics/OpenMP/reduction06.f90 (+3-3)
- (modified) flang/test/Semantics/OpenMP/reduction12.f90 (+1-1)
- (modified) flang/test/Semantics/OpenMP/use_device_addr1.f90 (+1-1)
- (modified) flang/test/Semantics/OpenMP/use_device_ptr1.f90 (+1-1)
``````````diff
diff --git a/flang/lib/Semantics/check-omp-structure.cpp b/flang/lib/Semantics/check-omp-structure.cpp
index f09542e5b5df14..d4d4a677ce58ff 100644
--- a/flang/lib/Semantics/check-omp-structure.cpp
+++ b/flang/lib/Semantics/check-omp-structure.cpp
@@ -51,6 +51,30 @@ std::string TryVersion(unsigned version) {
return "try -fopenmp-version=" + std::to_string(version);
}
+static const parser::Designator *GetDesignatorFromObj(
+ const parser::OmpObject &object) {
+ return std::get_if<parser::Designator>(&object.u);
+}
+
+static const parser::DataRef *GetDataRefFromObj(
+ const parser::OmpObject &object) {
+ if (auto *desg{GetDesignatorFromObj(object)}) {
+ return std::get_if<parser::DataRef>(&desg->u);
+ }
+ return nullptr;
+}
+
+static const parser::ArrayElement *GetArrayElementFromObj(
+ const parser::OmpObject &object) {
+ if (auto *dataRef{GetDataRefFromObj(object)}) {
+ using ElementIndirection = common::Indirection<parser::ArrayElement>;
+ if (auto *ind{std::get_if<ElementIndirection>(&dataRef->u)}) {
+ return &ind->value();
+ }
+ }
+ return nullptr;
+}
+
// 'OmpWorkshareBlockChecker' is used to check the validity of the assignment
// statements and the expressions enclosed in an OpenMP Workshare construct
class OmpWorkshareBlockChecker {
@@ -222,6 +246,10 @@ bool OmpStructureChecker::CheckAllowedClause(llvmOmpClause clause) {
return CheckAllowed(clause);
}
+bool OmpStructureChecker::IsCommonBlock(const Symbol &sym) {
+ return sym.detailsIf<CommonBlockDetails>() != nullptr;
+}
+
bool OmpStructureChecker::IsVariableListItem(const Symbol &sym) {
return evaluate::IsVariable(sym) || sym.attrs().test(Attr::POINTER);
}
@@ -2895,6 +2923,8 @@ void OmpStructureChecker::Enter(const parser::OmpClause::Reduction &x) {
CheckReductionModifier(*maybeModifier);
}
}
+ CheckReductionObjects(std::get<parser::OmpObjectList>(x.v.t),
+ llvm::omp::Clause::OMPC_reduction);
}
bool OmpStructureChecker::CheckReductionOperators(
@@ -2963,6 +2993,69 @@ bool OmpStructureChecker::CheckIntrinsicOperator(
return false;
}
+/// Check restrictions on objects that are common to all reduction clauses.
+void OmpStructureChecker::CheckReductionObjects(
+ const parser::OmpObjectList &objects, llvm::omp::Clause clauseId) {
+ unsigned version{context_.langOptions().OpenMPVersion};
+ SymbolSourceMap symbols;
+ GetSymbolsInObjectList(objects, symbols);
+
+ // Array sections must be contiguous and have non-zero length.
+ for (const parser::OmpObject &object : objects.v) {
+ CheckIfContiguous(object);
+ }
+ CheckReductionArraySection(objects);
+ // An object must be definable.
+ CheckDefinableObjects(symbols, clauseId);
+ // Procedure pointers are not allowed.
+ CheckProcedurePointer(symbols, clauseId);
+ // Pointers must not have INTENT(IN).
+ CheckIntentInPointer(symbols, clauseId);
+
+ // Disallow common blocks.
+ // Iterate on objects because `GetSymbolsInObjectList` expands common block
+ // names into the lists of their members.
+ for (const parser::OmpObject &object : objects.v) {
+ auto *symbol{GetObjectSymbol(object)};
+ assert(symbol);
+ if (IsCommonBlock(*symbol)) {
+ auto source{GetObjectSource(object)};
+ context_.Say(source ? *source : GetContext().clauseSource,
+ "Common block names are not allowed in %s clause"_err_en_US,
+ parser::ToUpperCaseLetters(getClauseName(clauseId).str()));
+ }
+ }
+
+ if (version >= 50) {
+ // Object cannot be a part of another object (except array elements)
+ CheckStructureComponent(objects, clauseId);
+ // If object is an array section or element, the base expression must be
+ // a language identifier.
+ for (const parser::OmpObject &object : objects.v) {
+ if (auto *elem{GetArrayElementFromObj(object)}) {
+ const parser::DataRef &base = elem->base;
+ if (!std::holds_alternative<parser::Name>(base.u)) {
+ auto source{GetObjectSource(object)};
+ context_.Say(source ? *source : GetContext().clauseSource,
+ "The base expression of an array element in %s clause must be an identifier"_err_en_US,
+ parser::ToUpperCaseLetters(getClauseName(clauseId).str()));
+ }
+ }
+ }
+ // Type parameter inquiries are not allowed.
+ for (const parser::OmpObject &object : objects.v) {
+ if (auto *dataRef{GetDataRefFromObj(object)}) {
+ if (IsDataRefTypeParamInquiry(dataRef)) {
+ auto source{GetObjectSource(object)};
+ context_.Say(source ? *source : GetContext().clauseSource,
+ "Type parameter inquiry is not permitted in %s clause"_err_en_US,
+ parser::ToUpperCaseLetters(getClauseName(clauseId).str()));
+ }
+ }
+ }
+ }
+}
+
static bool IsReductionAllowedForType(
const parser::OmpClause::Reduction &x, const DeclTypeSpec &type) {
auto &modifiers{OmpGetModifiers(x.v)};
@@ -3052,8 +3145,9 @@ static bool IsReductionAllowedForType(
void OmpStructureChecker::CheckReductionTypeList(
const parser::OmpClause::Reduction &x) {
const auto &ompObjectList{std::get<parser::OmpObjectList>(x.v.t)};
- CheckIntentInPointerAndDefinable(
- ompObjectList, llvm::omp::Clause::OMPC_reduction);
+ SymbolSourceMap symbols;
+ GetSymbolsInObjectList(ompObjectList, symbols);
+
CheckReductionArraySection(ompObjectList);
// If this is a worksharing construct then ensure the reduction variable
// is not private in the parallel region that it binds to.
@@ -3061,17 +3155,15 @@ void OmpStructureChecker::CheckReductionTypeList(
CheckSharedBindingInOuterContext(ompObjectList);
}
- SymbolSourceMap symbols;
- GetSymbolsInObjectList(ompObjectList, symbols);
for (auto &[symbol, source] : symbols) {
- if (IsProcedurePointer(*symbol)) {
- context_.Say(source,
- "A procedure pointer '%s' must not appear in a REDUCTION clause."_err_en_US,
- symbol->name());
- } else if (!IsReductionAllowedForType(x, DEREF(symbol->GetType()))) {
- context_.Say(source,
- "The type of '%s' is incompatible with the reduction operator."_err_en_US,
- symbol->name());
+ if (auto *type{symbol->GetType()}) {
+ if (!IsReductionAllowedForType(x, *type)) {
+ context_.Say(source,
+ "The type of '%s' is incompatible with the reduction operator."_err_en_US,
+ symbol->name());
+ }
+ } else {
+ assert(IsProcedurePointer(*symbol) && "Unexpected symbol properties");
}
}
}
@@ -3127,43 +3219,14 @@ void OmpStructureChecker::CheckReductionModifier(
}
}
-void OmpStructureChecker::CheckIntentInPointerAndDefinable(
- const parser::OmpObjectList &objectList, const llvm::omp::Clause clause) {
- for (const auto &ompObject : objectList.v) {
- if (const auto *name{parser::Unwrap<parser::Name>(ompObject)}) {
- if (const auto *symbol{name->symbol}) {
- if (IsPointer(symbol->GetUltimate()) &&
- IsIntentIn(symbol->GetUltimate())) {
- context_.Say(GetContext().clauseSource,
- "Pointer '%s' with the INTENT(IN) attribute may not appear "
- "in a %s clause"_err_en_US,
- symbol->name(),
- parser::ToUpperCaseLetters(getClauseName(clause).str()));
- } else if (auto msg{WhyNotDefinable(name->source,
- context_.FindScope(name->source), DefinabilityFlags{},
- *symbol)}) {
- context_
- .Say(GetContext().clauseSource,
- "Variable '%s' on the %s clause is not definable"_err_en_US,
- symbol->name(),
- parser::ToUpperCaseLetters(getClauseName(clause).str()))
- .Attach(std::move(msg->set_severity(parser::Severity::Because)));
- }
- }
- }
- }
-}
-
void OmpStructureChecker::CheckReductionArraySection(
const parser::OmpObjectList &ompObjectList) {
for (const auto &ompObject : ompObjectList.v) {
if (const auto *dataRef{parser::Unwrap<parser::DataRef>(ompObject)}) {
if (const auto *arrayElement{
parser::Unwrap<parser::ArrayElement>(ompObject)}) {
- if (arrayElement) {
- CheckArraySection(*arrayElement, GetLastName(*dataRef),
- llvm::omp::Clause::OMPC_reduction);
- }
+ CheckArraySection(*arrayElement, GetLastName(*dataRef),
+ llvm::omp::Clause::OMPC_reduction);
}
}
}
@@ -3232,9 +3295,11 @@ void OmpStructureChecker::Enter(const parser::OmpClause::Shared &x) {
CheckIsVarPartOfAnotherVar(GetContext().clauseSource, x.v, "SHARED");
}
void OmpStructureChecker::Enter(const parser::OmpClause::Private &x) {
+ SymbolSourceMap symbols;
+ GetSymbolsInObjectList(x.v, symbols);
CheckAllowedClause(llvm::omp::Clause::OMPC_private);
CheckIsVarPartOfAnotherVar(GetContext().clauseSource, x.v, "PRIVATE");
- CheckIntentInPointer(x.v, llvm::omp::Clause::OMPC_private);
+ CheckIntentInPointer(symbols, llvm::omp::Clause::OMPC_private);
}
void OmpStructureChecker::Enter(const parser::OmpClause::Nowait &x) {
@@ -3891,11 +3956,11 @@ void OmpStructureChecker::CheckCopyingPolymorphicAllocatable(
void OmpStructureChecker::Enter(const parser::OmpClause::Copyprivate &x) {
CheckAllowedClause(llvm::omp::Clause::OMPC_copyprivate);
- CheckIntentInPointer(x.v, llvm::omp::Clause::OMPC_copyprivate);
- SymbolSourceMap currSymbols;
- GetSymbolsInObjectList(x.v, currSymbols);
+ SymbolSourceMap symbols;
+ GetSymbolsInObjectList(x.v, symbols);
+ CheckIntentInPointer(symbols, llvm::omp::Clause::OMPC_copyprivate);
CheckCopyingPolymorphicAllocatable(
- currSymbols, llvm::omp::Clause::OMPC_copyprivate);
+ symbols, llvm::omp::Clause::OMPC_copyprivate);
if (GetContext().directive == llvm::omp::Directive::OMPD_single) {
context_.Say(GetContext().clauseSource,
"%s clause is not allowed on the OMP %s directive,"
@@ -3945,29 +4010,26 @@ void OmpStructureChecker::Enter(const parser::OmpClause::Copyin &x) {
currSymbols, llvm::omp::Clause::OMPC_copyin);
}
-void OmpStructureChecker::CheckStructureElement(
- const parser::OmpObjectList &ompObjectList,
- const llvm::omp::Clause clause) {
- for (const auto &ompObject : ompObjectList.v) {
+void OmpStructureChecker::CheckStructureComponent(
+ const parser::OmpObjectList &objects, llvm::omp::Clause clauseId) {
+ auto CheckComponent{[&](const parser::Designator &designator) {
+ if (auto *desg{std::get_if<parser::DataRef>(&designator.u)}) {
+ if (auto *comp{parser::Unwrap<parser::StructureComponent>(*desg)}) {
+ context_.Say(comp->component.source,
+ "A variable that is part of another variable cannot appear on the %s clause"_err_en_US,
+ parser::ToUpperCaseLetters(getClauseName(clauseId).str()));
+ }
+ }
+ }};
+
+ for (const auto &object : objects.v) {
common::visit(
common::visitors{
- [&](const parser::Designator &designator) {
- if (std::get_if<parser::DataRef>(&designator.u)) {
- if (parser::Unwrap<parser::StructureComponent>(ompObject)) {
- context_.Say(GetContext().clauseSource,
- "A variable that is part of another variable "
- "(structure element) cannot appear on the %s "
- "%s clause"_err_en_US,
- ContextDirectiveAsFortran(),
- parser::ToUpperCaseLetters(getClauseName(clause).str()));
- }
- }
- },
+ CheckComponent,
[&](const parser::Name &name) {},
},
- ompObject.u);
+ object.u);
}
- return;
}
void OmpStructureChecker::Enter(const parser::OmpClause::Update &x) {
@@ -4009,7 +4071,7 @@ void OmpStructureChecker::Enter(const parser::OmpClause::Update &x) {
}
void OmpStructureChecker::Enter(const parser::OmpClause::UseDevicePtr &x) {
- CheckStructureElement(x.v, llvm::omp::Clause::OMPC_use_device_ptr);
+ CheckStructureComponent(x.v, llvm::omp::Clause::OMPC_use_device_ptr);
CheckAllowedClause(llvm::omp::Clause::OMPC_use_device_ptr);
SymbolSourceMap currSymbols;
GetSymbolsInObjectList(x.v, currSymbols);
@@ -4038,7 +4100,7 @@ void OmpStructureChecker::Enter(const parser::OmpClause::UseDevicePtr &x) {
}
void OmpStructureChecker::Enter(const parser::OmpClause::UseDeviceAddr &x) {
- CheckStructureElement(x.v, llvm::omp::Clause::OMPC_use_device_addr);
+ CheckStructureComponent(x.v, llvm::omp::Clause::OMPC_use_device_addr);
CheckAllowedClause(llvm::omp::Clause::OMPC_use_device_addr);
SymbolSourceMap currSymbols;
GetSymbolsInObjectList(x.v, currSymbols);
@@ -4214,6 +4276,26 @@ llvm::StringRef OmpStructureChecker::getDirectiveName(
return llvm::omp::getOpenMPDirectiveName(directive);
}
+const Symbol *OmpStructureChecker::GetObjectSymbol(
+ const parser::OmpObject &object) {
+ if (auto *name{std::get_if<parser::Name>(&object.u)}) {
+ return &name->symbol->GetUltimate();
+ } else if (auto *desg{std::get_if<parser::Designator>(&object.u)}) {
+ return &GetLastName(*desg).symbol->GetUltimate();
+ }
+ return nullptr;
+}
+
+std::optional<parser::CharBlock> OmpStructureChecker::GetObjectSource(
+ const parser::OmpObject &object) {
+ if (auto *name{std::get_if<parser::Name>(&object.u)}) {
+ return name->source;
+ } else if (auto *desg{std::get_if<parser::Designator>(&object.u)}) {
+ return GetLastName(*desg).source;
+ }
+ return std::nullopt;
+}
+
void OmpStructureChecker::CheckDependList(const parser::DataRef &d) {
common::visit(
common::visitors{
@@ -4267,15 +4349,6 @@ void OmpStructureChecker::CheckArraySection(
"DEPEND "
"clause"_err_en_US);
}
- const auto stride{GetIntValue(strideExpr)};
- if ((stride && stride != 1)) {
- context_.Say(GetContext().clauseSource,
- "A list item that appears in a REDUCTION clause"
- " should have a contiguous storage array "
- "section."_err_en_US,
- ContextDirectiveAsFortran());
- break;
- }
}
}
}
@@ -4286,14 +4359,23 @@ void OmpStructureChecker::CheckArraySection(
}
void OmpStructureChecker::CheckIntentInPointer(
- const parser::OmpObjectList &objectList, const llvm::omp::Clause clause) {
- SymbolSourceMap symbols;
- GetSymbolsInObjectList(objectList, symbols);
+ SymbolSourceMap &symbols, llvm::omp::Clause clauseId) {
for (auto &[symbol, source] : symbols) {
if (IsPointer(*symbol) && IsIntentIn(*symbol)) {
context_.Say(source,
- "Pointer '%s' with the INTENT(IN) attribute may not appear "
- "in a %s clause"_err_en_US,
+ "Pointer '%s' with the INTENT(IN) attribute may not appear in a %s clause"_err_en_US,
+ symbol->name(),
+ parser::ToUpperCaseLetters(getClauseName(clauseId).str()));
+ }
+ }
+}
+
+void OmpStructureChecker::CheckProcedurePointer(
+ SymbolSourceMap &symbols, llvm::omp::Clause clause) {
+ for (const auto &[symbol, source] : symbols) {
+ if (IsProcedurePointer(*symbol)) {
+ context_.Say(source,
+ "Procedure pointer '%s' may not appear in a %s clause"_err_en_US,
symbol->name(),
parser::ToUpperCaseLetters(getClauseName(clause).str()));
}
diff --git a/flang/lib/Semantics/check-omp-structure.h b/flang/lib/Semantics/check-omp-structure.h
index 1411a9271d4665..3d4a30b994a70e 100644
--- a/flang/lib/Semantics/check-omp-structure.h
+++ b/flang/lib/Semantics/check-omp-structure.h
@@ -146,13 +146,14 @@ class OmpStructureChecker
bool CheckAllowedClause(llvmOmpClause clause);
bool IsVariableListItem(const Symbol &sym);
bool IsExtendedListItem(const Symbol &sym);
+ bool IsCommonBlock(const Symbol &sym);
std::optional<bool> IsContiguous(const parser::OmpObject &object);
void CheckMultipleOccurrence(semantics::UnorderedSymbolSet &listVars,
const std::list<parser::Name> &nameList, const parser::CharBlock &item,
const std::string &clauseName);
void CheckMultListItems();
- void CheckStructureElement(const parser::OmpObjectList &ompObjectList,
- const llvm::omp::Clause clause);
+ void CheckStructureComponent(const parser::OmpObjectList &objects,
+ llvm::omp::Clause clauseId);
bool HasInvalidWorksharingNesting(
const parser::CharBlock &, const OmpDirectiveSet &);
bool IsCloselyNestedRegion(const OmpDirectiveSet &set);
@@ -171,6 +172,9 @@ class OmpStructureChecker
typename IterTy = decltype(std::declval<RangeTy>().begin())>
std::optional<IterTy> FindDuplicate(RangeTy &&);
+ const Symbol *GetObjectSymbol(const parser::OmpObject &object);
+ std::optional<parser::CharBlock> GetObjectSource(
+ const parser::OmpObject &object);
void CheckDependList(const parser::DataRef &);
void CheckDependArraySection(
const common::Indirection<parser::ArrayElement> &, const parser::Name &);
@@ -182,8 +186,8 @@ class OmpStructureChecker
const parser::OmpObjectList &objList);
void CheckSymbolNames(
const parser::CharBlock &source, const parser::OmpObjectList &objList);
- void CheckIntentInPointer(
- const parser::OmpObjectList &, const llvm::omp::Clause);
+ void CheckIntentInPointer(SymbolSourceMap &, const llvm::omp::Clause);
+ void CheckProcedurePointer(SymbolSourceMap &, const llvm::omp::Clause);
void GetSymbolsInObjectList(const parser::OmpObjectList &, SymbolSourceMap &);
void CheckDefinableObjects(SymbolSourceMap &, const llvm::omp::Clause);
void CheckCopyingPolymorphicAllocatable(
@@ -220,6 +224,8 @@ class OmpStructureChecker
void CheckCancellationNest(
const parser::CharBlock &source, const parser::OmpCancelType::Type &type);
std::int64_t GetOrdCollapseLevel(const parser::OpenMPLoopConstruct &x);
+ void CheckReductionObjects(const parser::OmpObjectList &objects,
+ llvm::omp::Clause clauseId);
bool CheckReductionOperators(const parser::OmpClause::Reduction &);
bool CheckIntrinsicOperator(
const parser::DefinedOperator::IntrinsicOperator &);
diff --git a/flang/test/Semantics/OpenMP/reduction04.f90 b/flang/test/Semantics/OpenMP/reduction04.f90
index 319ed9f245abe8..e86b67b6efa961 100644
--- a/flang/test/Semantics/OpenMP/reduction04.f90
+++ b/flang/test/Semantics/OpenMP/reduction04.f90
@@ -14,8 +14,7 @@ program omp_Reduction
end do
!$omp end parallel do
- !ERROR: Variable 'c' on the REDUCTION clause is not definable
- !BECAUSE: 'c' is not a variable
+ !ERROR: Common block names are not allowed in REDUCTION clause
!$omp parallel do reduction(*:/c/)
do i = 1, 10
l = k + 1
diff --git a/flang/test/Semantics/OpenMP/reduction06.f90 b/flang/test/Semantics/OpenMP/reduction06.f90
index 58290c61cae860..7d2a37310a8893 100644
--- a/flang/test/Semantics/OpenMP/reduction06.f90
+++ b/flang/test/Semantics/OpenMP/reduction06.f90
@@ -8,21 +8,21 @@ program omp_reduction
integer :: k = 10
integer :: a(10), b(10,10,10)
- !ERROR: A list item that appears in a REDUCTION clause should have a contiguous storage array section.
+ !ERROR: Reference to a must be a contiguous object
!$omp parallel do reduction(+:a(1:10:3))
do i = 1, 10
k = k + 1
end do
!$omp end parallel do
- !ERROR: A list item that appears in a REDUCTION clause should have a contiguous storage array section.
+ !ERROR: Reference to b must be a contiguous object
!$omp parallel do reduction(+:b(1:10:3,1:8:1,1:5:1))
do i = 1, 10
k = k + 1
end do
!$omp end parallel do
- !ERROR: A list item that appears in a REDUCTION clause should have a contiguous storage array section.
+ !ERROR: Reference to b must be a contiguous object
!$omp parallel do reduction(+:b(1:10:1,1:8:2,1:5:1))
do i = 1, 10
k = k + 1
diff --git a/flang/test/Semantics/OpenMP/reduction12.f90 b/flang/test/Semantics/OpenMP/reduction12.f90
index f896ca4aa60b67..f7e3baf0465624 100644
--- a/flang/test/Semantics/OpenMP/reduction12.f90
+++ b/flang/test/Semantics/OpenMP/reduction12.f90
@@ -6,7 +6,7 @@
procedure(foo), pointer :: ptr
integer :: i
ptr => foo
-!ERROR: A procedure pointer 'ptr' must not appear in a REDUCTION clause.
+!ERROR: Proce...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/118689
More information about the flang-commits mailing list