[clang] [OpenACC] Implement 'reduction' sema for compute constructs (PR #92808)

via cfe-commits cfe-commits at lists.llvm.org
Mon May 20 11:56:49 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-clang-modules

Author: Erich Keane (erichkeane)

<details>
<summary>Changes</summary>

'reduction' has a few restrictions over normal 'var-list' clauses:

1- On 'parallel', a 'num_gangs' clause can only have 1 argument when combined with 'reduction'. However, these two clauses cannot be combined on any of the other compute constructs.

2- The vars all must be 'numerical data types' of some sort, or a 'composite of numerical data types'.  A list of types is given in the standard as a minimum, so we choose 'isScalar', which covers all of these types and keeps types that are actually numeric. Other compilers don't seem to implement the 'composite of numerical data types', though we do.

3- Because of the above restrictions, member-of-composite is not allowed, so any access via a memberexpr is disallowed. Array-element and sub-arrays (aka array sections) are both permitted, so long as they meet the requirements of #<!-- -->2.

This patch implements all of these for compute constructs.

---

Patch is 107.15 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/92808.diff


39 Files Affected:

- (modified) clang/include/clang/AST/OpenACCClause.h (+29) 
- (modified) clang/include/clang/Basic/DiagnosticSemaKinds.td (+16-2) 
- (modified) clang/include/clang/Basic/OpenACCClauses.def (+1) 
- (modified) clang/include/clang/Basic/OpenACCKinds.h (+36) 
- (modified) clang/include/clang/Parse/Parser.h (+2-2) 
- (modified) clang/include/clang/Sema/SemaOpenACC.h (+27-2) 
- (modified) clang/lib/AST/OpenACCClause.cpp (+19-1) 
- (modified) clang/lib/AST/StmtProfile.cpp (+6) 
- (modified) clang/lib/AST/TextNodeDumper.cpp (+4) 
- (modified) clang/lib/Parse/ParseOpenACC.cpp (+16-14) 
- (modified) clang/lib/Sema/SemaOpenACC.cpp (+147-10) 
- (modified) clang/lib/Sema/TreeTransform.h (+20-1) 
- (modified) clang/lib/Serialization/ASTReader.cpp (+7-1) 
- (modified) clang/lib/Serialization/ASTWriter.cpp (+7-1) 
- (modified) clang/test/AST/ast-print-openacc-compute-construct.cpp (+28) 
- (modified) clang/test/ParserOpenACC/parse-clauses.c (+6-20) 
- (modified) clang/test/SemaOpenACC/compute-construct-attach-clause.c (+1-1) 
- (modified) clang/test/SemaOpenACC/compute-construct-clause-ast.cpp (+248) 
- (modified) clang/test/SemaOpenACC/compute-construct-copy-clause.c (+4-4) 
- (modified) clang/test/SemaOpenACC/compute-construct-copy-clause.cpp (+8-8) 
- (modified) clang/test/SemaOpenACC/compute-construct-copyin-clause.c (+5-5) 
- (modified) clang/test/SemaOpenACC/compute-construct-copyin-clause.cpp (+8-8) 
- (modified) clang/test/SemaOpenACC/compute-construct-copyout-clause.c (+5-5) 
- (modified) clang/test/SemaOpenACC/compute-construct-copyout-clause.cpp (+8-8) 
- (modified) clang/test/SemaOpenACC/compute-construct-create-clause.c (+5-5) 
- (modified) clang/test/SemaOpenACC/compute-construct-create-clause.cpp (+8-8) 
- (modified) clang/test/SemaOpenACC/compute-construct-device_type-clause.c (+1-1) 
- (modified) clang/test/SemaOpenACC/compute-construct-deviceptr-clause.c (+1-1) 
- (modified) clang/test/SemaOpenACC/compute-construct-firstprivate-clause.c (+4-4) 
- (modified) clang/test/SemaOpenACC/compute-construct-firstprivate-clause.cpp (+8-8) 
- (modified) clang/test/SemaOpenACC/compute-construct-no_create-clause.c (+4-4) 
- (modified) clang/test/SemaOpenACC/compute-construct-no_create-clause.cpp (+8-8) 
- (modified) clang/test/SemaOpenACC/compute-construct-present-clause.c (+4-4) 
- (modified) clang/test/SemaOpenACC/compute-construct-present-clause.cpp (+8-8) 
- (modified) clang/test/SemaOpenACC/compute-construct-private-clause.c (+5-5) 
- (modified) clang/test/SemaOpenACC/compute-construct-private-clause.cpp (+8-8) 
- (added) clang/test/SemaOpenACC/compute-construct-reduction-clause.c (+107) 
- (added) clang/test/SemaOpenACC/compute-construct-reduction-clause.cpp (+175) 
- (modified) clang/tools/libclang/CIndex.cpp (+4) 


``````````diff
diff --git a/clang/include/clang/AST/OpenACCClause.h b/clang/include/clang/AST/OpenACCClause.h
index 607a2b9d65367..28ff8c44bd256 100644
--- a/clang/include/clang/AST/OpenACCClause.h
+++ b/clang/include/clang/AST/OpenACCClause.h
@@ -677,6 +677,35 @@ class OpenACCCreateClause final
          ArrayRef<Expr *> VarList, SourceLocation EndLoc);
 };
 
+class OpenACCReductionClause final
+    : public OpenACCClauseWithVarList,
+      public llvm::TrailingObjects<OpenACCReductionClause, Expr *> {
+  OpenACCReductionOperator Op;
+
+  OpenACCReductionClause(SourceLocation BeginLoc, SourceLocation LParenLoc,
+                         OpenACCReductionOperator Operator,
+                         ArrayRef<Expr *> VarList, SourceLocation EndLoc)
+      : OpenACCClauseWithVarList(OpenACCClauseKind::Reduction, BeginLoc,
+                                 LParenLoc, EndLoc),
+        Op(Operator) {
+    std::uninitialized_copy(VarList.begin(), VarList.end(),
+                            getTrailingObjects<Expr *>());
+    setExprs(MutableArrayRef(getTrailingObjects<Expr *>(), VarList.size()));
+  }
+
+public:
+  static bool classof(const OpenACCClause *C) {
+    return C->getClauseKind() == OpenACCClauseKind::Reduction;
+  }
+
+  static OpenACCReductionClause *
+  Create(const ASTContext &C, SourceLocation BeginLoc, SourceLocation LParenLoc,
+         OpenACCReductionOperator Operator, ArrayRef<Expr *> VarList,
+         SourceLocation EndLoc);
+
+  OpenACCReductionOperator getReductionOp() const { return Op; }
+};
+
 template <class Impl> class OpenACCClauseVisitor {
   Impl &getDerived() { return static_cast<Impl &>(*this); }
 
diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index e3c65cba4886a..c7dea1d54d063 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -12343,7 +12343,8 @@ def err_acc_num_gangs_num_args
             "provided}0">;
 def err_acc_not_a_var_ref
     : Error<"OpenACC variable is not a valid variable name, sub-array, array "
-            "element, or composite variable member">;
+            "element,%select{| member of a composite variable,}0 or composite "
+            "variable member">;
 def err_acc_typecheck_subarray_value
     : Error<"OpenACC sub-array subscripted value is not an array or pointer">;
 def err_acc_subarray_function_type
@@ -12374,5 +12375,18 @@ def note_acc_expected_pointer_var : Note<"expected variable of pointer type">;
 def err_acc_clause_after_device_type
     : Error<"OpenACC clause '%0' may not follow a '%1' clause in a "
             "compute construct">;
-
+def err_acc_reduction_num_gangs_conflict
+    : Error<
+          "OpenACC 'reduction' clause may not appear on a 'parallel' construct "
+          "with a 'num_gangs' clause with more than 1 argument, have %0">;
+def err_acc_reduction_type
+    : Error<"OpenACC 'reduction' variable must be of scalar type, sub-array, or a "
+            "composite of scalar types;%select{| sub-array base}1 type is %0">;
+def err_acc_reduction_composite_type
+    : Error<"OpenACC 'reduction' variable must be a composite of scalar types; "
+            "%1 %select{is not a class or struct|is incomplete|is not an "
+            "aggregate}0">;
+def err_acc_reduction_composite_member_type :Error<
+    "OpenACC 'reduction' composite variable must not have non-scalar field">;
+def note_acc_reduction_composite_member_loc : Note<"invalid field is here">;
 } // end of sema component.
diff --git a/clang/include/clang/Basic/OpenACCClauses.def b/clang/include/clang/Basic/OpenACCClauses.def
index 7ecc51799468c..3e464abaafd92 100644
--- a/clang/include/clang/Basic/OpenACCClauses.def
+++ b/clang/include/clang/Basic/OpenACCClauses.def
@@ -46,6 +46,7 @@ VISIT_CLAUSE(NumGangs)
 VISIT_CLAUSE(NumWorkers)
 VISIT_CLAUSE(Present)
 VISIT_CLAUSE(Private)
+VISIT_CLAUSE(Reduction)
 VISIT_CLAUSE(Self)
 VISIT_CLAUSE(VectorLength)
 VISIT_CLAUSE(Wait)
diff --git a/clang/include/clang/Basic/OpenACCKinds.h b/clang/include/clang/Basic/OpenACCKinds.h
index 0e38a04e7164b..7b9d619a8aec6 100644
--- a/clang/include/clang/Basic/OpenACCKinds.h
+++ b/clang/include/clang/Basic/OpenACCKinds.h
@@ -514,6 +514,42 @@ enum class OpenACCReductionOperator {
   /// Invalid Reduction Clause Kind.
   Invalid,
 };
+
+template <typename StreamTy>
+inline StreamTy &printOpenACCReductionOperator(StreamTy &Out,
+                                               OpenACCReductionOperator Op) {
+  switch (Op) {
+  case OpenACCReductionOperator::Addition:
+    return Out << "+";
+  case OpenACCReductionOperator::Multiplication:
+    return Out << "*";
+  case OpenACCReductionOperator::Max:
+    return Out << "max";
+  case OpenACCReductionOperator::Min:
+    return Out << "min";
+  case OpenACCReductionOperator::BitwiseAnd:
+    return Out << "&";
+  case OpenACCReductionOperator::BitwiseOr:
+    return Out << "|";
+  case OpenACCReductionOperator::BitwiseXOr:
+    return Out << "^";
+  case OpenACCReductionOperator::And:
+    return Out << "&&";
+  case OpenACCReductionOperator::Or:
+    return Out << "||";
+  case OpenACCReductionOperator::Invalid:
+    return Out << "<invalid>";
+  }
+  llvm_unreachable("Unknown reduction operator kind");
+}
+inline const StreamingDiagnostic &operator<<(const StreamingDiagnostic &Out,
+                                             OpenACCReductionOperator Op) {
+  return printOpenACCReductionOperator(Out, Op);
+}
+inline llvm::raw_ostream &operator<<(llvm::raw_ostream &Out,
+                                     OpenACCReductionOperator Op) {
+  return printOpenACCReductionOperator(Out, Op);
+}
 } // namespace clang
 
 #endif // LLVM_CLANG_BASIC_OPENACCKINDS_H
diff --git a/clang/include/clang/Parse/Parser.h b/clang/include/clang/Parse/Parser.h
index 5f04664141d29..3c4ab649e3b4c 100644
--- a/clang/include/clang/Parse/Parser.h
+++ b/clang/include/clang/Parse/Parser.h
@@ -3686,9 +3686,9 @@ class Parser : public CodeCompletionHandler {
 
   using OpenACCVarParseResult = std::pair<ExprResult, OpenACCParseCanContinue>;
   /// Parses a single variable in a variable list for OpenACC.
-  OpenACCVarParseResult ParseOpenACCVar();
+  OpenACCVarParseResult ParseOpenACCVar(OpenACCClauseKind CK);
   /// Parses the variable list for the variety of places that take a var-list.
-  llvm::SmallVector<Expr *> ParseOpenACCVarList();
+  llvm::SmallVector<Expr *> ParseOpenACCVarList(OpenACCClauseKind CK);
   /// Parses any parameters for an OpenACC Clause, including required/optional
   /// parens.
   OpenACCClauseParseResult
diff --git a/clang/include/clang/Sema/SemaOpenACC.h b/clang/include/clang/Sema/SemaOpenACC.h
index f838fa97d33a2..6f69fa08939b8 100644
--- a/clang/include/clang/Sema/SemaOpenACC.h
+++ b/clang/include/clang/Sema/SemaOpenACC.h
@@ -66,9 +66,14 @@ class SemaOpenACC : public SemaBase {
     struct DeviceTypeDetails {
       SmallVector<DeviceTypeArgument> Archs;
     };
+    struct ReductionDetails {
+      OpenACCReductionOperator Op;
+      SmallVector<Expr *> VarList;
+    };
 
     std::variant<std::monostate, DefaultDetails, ConditionDetails,
-                 IntExprDetails, VarListDetails, WaitDetails, DeviceTypeDetails>
+                 IntExprDetails, VarListDetails, WaitDetails, DeviceTypeDetails,
+                 ReductionDetails>
         Details = std::monostate{};
 
   public:
@@ -170,6 +175,10 @@ class SemaOpenACC : public SemaBase {
       return const_cast<OpenACCParsedClause *>(this)->getIntExprs();
     }
 
+    OpenACCReductionOperator getReductionOp() const {
+      return std::get<ReductionDetails>(Details).Op;
+    }
+
     ArrayRef<Expr *> getVarList() {
       assert((ClauseKind == OpenACCClauseKind::Private ||
               ClauseKind == OpenACCClauseKind::NoCreate ||
@@ -188,8 +197,13 @@ class SemaOpenACC : public SemaBase {
               ClauseKind == OpenACCClauseKind::PresentOrCreate ||
               ClauseKind == OpenACCClauseKind::Attach ||
               ClauseKind == OpenACCClauseKind::DevicePtr ||
+              ClauseKind == OpenACCClauseKind::Reduction ||
               ClauseKind == OpenACCClauseKind::FirstPrivate) &&
              "Parsed clause kind does not have a var-list");
+
+      if (ClauseKind == OpenACCClauseKind::Reduction)
+        return std::get<ReductionDetails>(Details).VarList;
+
       return std::get<VarListDetails>(Details).VarList;
     }
 
@@ -334,6 +348,13 @@ class SemaOpenACC : public SemaBase {
       Details = VarListDetails{std::move(VarList), IsReadOnly, IsZero};
     }
 
+    void setReductionDetails(OpenACCReductionOperator Op,
+                             llvm::SmallVector<Expr *> &&VarList) {
+      assert(ClauseKind == OpenACCClauseKind::Reduction &&
+             "reduction details only valid on reduction");
+      Details = ReductionDetails{Op, std::move(VarList)};
+    }
+
     void setWaitDetails(Expr *DevNum, SourceLocation QueuesLoc,
                         llvm::SmallVector<Expr *> &&IntExprs) {
       assert(ClauseKind == OpenACCClauseKind::Wait &&
@@ -394,7 +415,11 @@ class SemaOpenACC : public SemaBase {
 
   /// Called when encountering a 'var' for OpenACC, ensures it is actually a
   /// declaration reference to a variable of the correct type.
-  ExprResult ActOnVar(Expr *VarExpr);
+  ExprResult ActOnVar(OpenACCClauseKind CK, Expr *VarExpr);
+
+  /// Called while semantically analyzing the reduction clause, ensuring the var
+  /// is the correct kind of reference.
+  ExprResult CheckReductionVar(Expr *VarExpr);
 
   /// Called to check the 'var' type is a variable of pointer type, necessary
   /// for 'deviceptr' and 'attach' clauses. Returns true on success.
diff --git a/clang/lib/AST/OpenACCClause.cpp b/clang/lib/AST/OpenACCClause.cpp
index 8ff6dabcbc48e..cb2c7f98be75c 100644
--- a/clang/lib/AST/OpenACCClause.cpp
+++ b/clang/lib/AST/OpenACCClause.cpp
@@ -35,7 +35,7 @@ bool OpenACCClauseWithVarList::classof(const OpenACCClause *C) {
          OpenACCAttachClause::classof(C) || OpenACCNoCreateClause::classof(C) ||
          OpenACCPresentClause::classof(C) || OpenACCCopyClause::classof(C) ||
          OpenACCCopyInClause::classof(C) || OpenACCCopyOutClause::classof(C) ||
-         OpenACCCreateClause::classof(C);
+         OpenACCReductionClause::classof(C) || OpenACCCreateClause::classof(C);
 }
 bool OpenACCClauseWithCondition::classof(const OpenACCClause *C) {
   return OpenACCIfClause::classof(C) || OpenACCSelfClause::classof(C);
@@ -310,6 +310,16 @@ OpenACCDeviceTypeClause *OpenACCDeviceTypeClause::Create(
       OpenACCDeviceTypeClause(K, BeginLoc, LParenLoc, Archs, EndLoc);
 }
 
+OpenACCReductionClause *OpenACCReductionClause::Create(
+    const ASTContext &C, SourceLocation BeginLoc, SourceLocation LParenLoc,
+    OpenACCReductionOperator Operator, ArrayRef<Expr *> VarList,
+    SourceLocation EndLoc) {
+  void *Mem = C.Allocate(
+      OpenACCReductionClause::totalSizeToAlloc<Expr *>(VarList.size()));
+  return new (Mem)
+      OpenACCReductionClause(BeginLoc, LParenLoc, Operator, VarList, EndLoc);
+}
+
 //===----------------------------------------------------------------------===//
 //  OpenACC clauses printing methods
 //===----------------------------------------------------------------------===//
@@ -445,6 +455,14 @@ void OpenACCClausePrinter::VisitCreateClause(const OpenACCCreateClause &C) {
   OS << ")";
 }
 
+void OpenACCClausePrinter::VisitReductionClause(
+    const OpenACCReductionClause &C) {
+  OS << "reduction(" << C.getReductionOp() << ": ";
+  llvm::interleaveComma(C.getVarList(), OS,
+                        [&](const Expr *E) { printExpr(E); });
+  OS << ")";
+}
+
 void OpenACCClausePrinter::VisitWaitClause(const OpenACCWaitClause &C) {
   OS << "wait";
   if (!C.getLParenLoc().isInvalid()) {
diff --git a/clang/lib/AST/StmtProfile.cpp b/clang/lib/AST/StmtProfile.cpp
index caab4ab0ef160..00b8c43af035c 100644
--- a/clang/lib/AST/StmtProfile.cpp
+++ b/clang/lib/AST/StmtProfile.cpp
@@ -2588,6 +2588,12 @@ void OpenACCClauseProfiler::VisitWaitClause(const OpenACCWaitClause &Clause) {
 /// Nothing to do here, there are no sub-statements.
 void OpenACCClauseProfiler::VisitDeviceTypeClause(
     const OpenACCDeviceTypeClause &Clause) {}
+
+void OpenACCClauseProfiler::VisitReductionClause(
+    const OpenACCReductionClause &Clause) {
+  for (auto *E : Clause.getVarList())
+    Profiler.VisitStmt(E);
+}
 } // namespace
 
 void StmtProfiler::VisitOpenACCComputeConstruct(
diff --git a/clang/lib/AST/TextNodeDumper.cpp b/clang/lib/AST/TextNodeDumper.cpp
index efcd74717a4e2..4a1e94ffe283b 100644
--- a/clang/lib/AST/TextNodeDumper.cpp
+++ b/clang/lib/AST/TextNodeDumper.cpp
@@ -457,6 +457,10 @@ void TextNodeDumper::Visit(const OpenACCClause *C) {
           });
       OS << ")";
       break;
+    case OpenACCClauseKind::Reduction:
+      OS << " clause Operator: "
+         << cast<OpenACCReductionClause>(C)->getReductionOp();
+      break;
     default:
       // Nothing to do here.
       break;
diff --git a/clang/lib/Parse/ParseOpenACC.cpp b/clang/lib/Parse/ParseOpenACC.cpp
index 5db3036b00030..e9c60f76165b6 100644
--- a/clang/lib/Parse/ParseOpenACC.cpp
+++ b/clang/lib/Parse/ParseOpenACC.cpp
@@ -920,7 +920,8 @@ Parser::OpenACCClauseParseResult Parser::ParseOpenACCClauseParams(
     case OpenACCClauseKind::PresentOrCopyIn: {
       bool IsReadOnly = tryParseAndConsumeSpecialTokenKind(
           *this, OpenACCSpecialTokenKind::ReadOnly, ClauseKind);
-      ParsedClause.setVarListDetails(ParseOpenACCVarList(), IsReadOnly,
+      ParsedClause.setVarListDetails(ParseOpenACCVarList(ClauseKind),
+                                     IsReadOnly,
                                      /*IsZero=*/false);
       break;
     }
@@ -932,16 +933,17 @@ Parser::OpenACCClauseParseResult Parser::ParseOpenACCClauseParams(
     case OpenACCClauseKind::PresentOrCopyOut: {
       bool IsZero = tryParseAndConsumeSpecialTokenKind(
           *this, OpenACCSpecialTokenKind::Zero, ClauseKind);
-      ParsedClause.setVarListDetails(ParseOpenACCVarList(),
+      ParsedClause.setVarListDetails(ParseOpenACCVarList(ClauseKind),
                                      /*IsReadOnly=*/false, IsZero);
       break;
     }
-    case OpenACCClauseKind::Reduction:
+    case OpenACCClauseKind::Reduction: {
       // If we're missing a clause-kind (or it is invalid), see if we can parse
       // the var-list anyway.
-      ParseReductionOperator(*this);
-      ParseOpenACCVarList();
+      OpenACCReductionOperator Op = ParseReductionOperator(*this);
+      ParsedClause.setReductionDetails(Op, ParseOpenACCVarList(ClauseKind));
       break;
+    }
     case OpenACCClauseKind::Self:
       // The 'self' clause is a var-list instead of a 'condition' in the case of
       // the 'update' clause, so we have to handle it here.  U se an assert to
@@ -955,11 +957,11 @@ Parser::OpenACCClauseParseResult Parser::ParseOpenACCClauseParams(
     case OpenACCClauseKind::Host:
     case OpenACCClauseKind::Link:
     case OpenACCClauseKind::UseDevice:
-      ParseOpenACCVarList();
+      ParseOpenACCVarList(ClauseKind);
       break;
     case OpenACCClauseKind::Attach:
     case OpenACCClauseKind::DevicePtr:
-      ParsedClause.setVarListDetails(ParseOpenACCVarList(),
+      ParsedClause.setVarListDetails(ParseOpenACCVarList(ClauseKind),
                                      /*IsReadOnly=*/false, /*IsZero=*/false);
       break;
     case OpenACCClauseKind::Copy:
@@ -969,7 +971,7 @@ Parser::OpenACCClauseParseResult Parser::ParseOpenACCClauseParams(
     case OpenACCClauseKind::NoCreate:
     case OpenACCClauseKind::Present:
     case OpenACCClauseKind::Private:
-      ParsedClause.setVarListDetails(ParseOpenACCVarList(),
+      ParsedClause.setVarListDetails(ParseOpenACCVarList(ClauseKind),
                                      /*IsReadOnly=*/false, /*IsZero=*/false);
       break;
     case OpenACCClauseKind::Collapse: {
@@ -1278,7 +1280,7 @@ ExprResult Parser::ParseOpenACCBindClauseArgument() {
 /// - an array element
 /// - a member of a composite variable
 /// - a common block name between slashes (fortran only)
-Parser::OpenACCVarParseResult Parser::ParseOpenACCVar() {
+Parser::OpenACCVarParseResult Parser::ParseOpenACCVar(OpenACCClauseKind CK) {
   OpenACCArraySectionRAII ArraySections(*this);
 
   ExprResult Res = ParseAssignmentExpression();
@@ -1289,15 +1291,15 @@ Parser::OpenACCVarParseResult Parser::ParseOpenACCVar() {
   if (!Res.isUsable())
     return {Res, OpenACCParseCanContinue::Can};
 
-  Res = getActions().OpenACC().ActOnVar(Res.get());
+  Res = getActions().OpenACC().ActOnVar(CK, Res.get());
 
   return {Res, OpenACCParseCanContinue::Can};
 }
 
-llvm::SmallVector<Expr *> Parser::ParseOpenACCVarList() {
+llvm::SmallVector<Expr *> Parser::ParseOpenACCVarList(OpenACCClauseKind CK) {
   llvm::SmallVector<Expr *> Vars;
 
-  auto [Res, CanContinue] = ParseOpenACCVar();
+  auto [Res, CanContinue] = ParseOpenACCVar(CK);
   if (Res.isUsable()) {
     Vars.push_back(Res.get());
   } else if (CanContinue == OpenACCParseCanContinue::Cannot) {
@@ -1308,7 +1310,7 @@ llvm::SmallVector<Expr *> Parser::ParseOpenACCVarList() {
   while (!getCurToken().isOneOf(tok::r_paren, tok::annot_pragma_openacc_end)) {
     ExpectAndConsume(tok::comma);
 
-    auto [Res, CanContinue] = ParseOpenACCVar();
+    auto [Res, CanContinue] = ParseOpenACCVar(CK);
 
     if (Res.isUsable()) {
       Vars.push_back(Res.get());
@@ -1342,7 +1344,7 @@ void Parser::ParseOpenACCCacheVarList() {
 
   // ParseOpenACCVarList should leave us before a r-paren, so no need to skip
   // anything here.
-  ParseOpenACCVarList();
+  ParseOpenACCVarList(OpenACCClauseKind::Invalid);
 }
 
 Parser::OpenACCDirectiveParseInfo Parser::ParseOpenACCDirective() {
diff --git a/clang/lib/Sema/SemaOpenACC.cpp b/clang/lib/Sema/SemaOpenACC.cpp
index f174b2fa63c6a..49847bd997161 100644
--- a/clang/lib/Sema/SemaOpenACC.cpp
+++ b/clang/lib/Sema/SemaOpenACC.cpp
@@ -233,6 +233,19 @@ bool doesClauseApplyToDirective(OpenACCDirectiveKind DirectiveKind,
       return false;
     }
 
+  case OpenACCClauseKind::Reduction:
+    switch (DirectiveKind) {
+    case OpenACCDirectiveKind::Parallel:
+    case OpenACCDirectiveKind::Serial:
+    case OpenACCDirectiveKind::Loop:
+    case OpenACCDirectiveKind::ParallelLoop:
+    case OpenACCDirectiveKind::SerialLoop:
+    case OpenACCDirectiveKind::KernelsLoop:
+      return true;
+    default:
+      return false;
+    }
+
   default:
     // Do nothing so we can go to the 'unimplemented' diagnostic instead.
     return true;
@@ -281,7 +294,6 @@ bool checkValidAfterDeviceType(
     return true;
   }
 }
-
 } // namespace
 
 SemaOpenACC::SemaOpenACC(Sema &S) : SemaBase(S) {}
@@ -426,6 +438,24 @@ SemaOpenACC::ActOnClause(ArrayRef<const OpenACCClause *> ExistingClauses,
           << /*NoArgs=*/1 << Clause.getDirectiveKind() << MaxArgs
           << Clause.getIntExprs().size();
 
+    // OpenACC 3.3 Section 2.5.4:
+    // A reduction clause may not appear on a parallel construct with a
+    // num_gangs clause that has more than one argument.
+    if (Clause.getDirectiveKind() == OpenACCDirectiveKind::Parallel &&
+        Clause.getIntExprs().size() > 1) {
+      auto *Parallel =
+          llvm::find_if(ExistingClauses, [](const OpenACCClause *C) {
+            return C->getClauseKind() == OpenACCClauseKind::Reduction;
+          });
+
+      if (Parallel != ExistingClauses.end()) {
+        Diag(Clause.getBeginLoc(), diag::err_acc_reduction_num_gangs_conflict)
+            << Clause.getIntExprs().size();
+        Diag((*Parallel)->getBeginLoc(), diag::note_acc_previous_clause_here);
+        return nullptr;
+      }
+    }
+
     // Create the AST node for the clause even if the number of expressions is
     // incorrect.
     return OpenACCNumGangsClause::Create(
@@ -706,6 +736,48 @@ SemaOpenACC::ActOnClause(ArrayRef<const OpenACCClause *> ExistingClauses,
         Clause.getLParenLoc(), Clause.getDeviceTypeArchitectures(),
         Clause.getEndLoc());
   }
+  case OpenACCClauseKind::Reduction: {
+    // Restrictions only properly implemented on 'compute' constructs, and
+    // 'compute' constructs are the only construct that can do anything with
+    // this yet, so skip/treat as unimplemented in this case.
+    if (!isOpenACCComputeDirectiveKind(Clause.getDirectiveKind()))
+      break;
+
+    // OpenACC 3.3 Section 2.5.4:
+    // A reduction clause may not appear on a parallel construct with a
+...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/92808


More information about the cfe-commits mailing list