[clang] [llvm] [openmp] [Clang][OpenMP] Implement Loop splitting `#pragma omp split` directive (PR #190397)

via llvm-commits llvm-commits at lists.llvm.org
Fri Apr 3 13:12:20 PDT 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-clang-codegen

Author: Amit Tiwari (amitamd7)

<details>
<summary>Changes</summary>

Implement the loop-splitting `#pragma omp split` construct with its `counts` clause.
This PR is being posted after the earlier PR ([#<!-- -->183261](https://github.com/llvm/llvm-project/pull/183261)) was reverted.

---

Patch is 218.00 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/190397.diff


80 Files Affected:

- (modified) clang/bindings/python/clang/cindex.py (+3) 
- (modified) clang/include/clang-c/Index.h (+4) 
- (modified) clang/include/clang/AST/OpenMPClause.h (+101) 
- (modified) clang/include/clang/AST/RecursiveASTVisitor.h (+10) 
- (modified) clang/include/clang/AST/StmtOpenMP.h (+78) 
- (modified) clang/include/clang/ASTMatchers/ASTMatchers.h (+20) 
- (modified) clang/include/clang/Basic/DiagnosticSemaKinds.td (+2) 
- (modified) clang/include/clang/Basic/StmtNodes.td (+1) 
- (modified) clang/include/clang/Parse/Parser.h (+3) 
- (modified) clang/include/clang/Sema/SemaOpenMP.h (+12) 
- (modified) clang/include/clang/Serialization/ASTBitCodes.h (+1) 
- (modified) clang/lib/AST/OpenMPClause.cpp (+35) 
- (modified) clang/lib/AST/StmtOpenMP.cpp (+21) 
- (modified) clang/lib/AST/StmtPrinter.cpp (+5) 
- (modified) clang/lib/AST/StmtProfile.cpp (+10) 
- (modified) clang/lib/ASTMatchers/ASTMatchersInternal.cpp (+4) 
- (modified) clang/lib/ASTMatchers/Dynamic/Registry.cpp (+2) 
- (modified) clang/lib/Basic/OpenMPKinds.cpp (+4-1) 
- (modified) clang/lib/CodeGen/CGStmt.cpp (+3) 
- (modified) clang/lib/CodeGen/CGStmtOpenMP.cpp (+8) 
- (modified) clang/lib/CodeGen/CodeGenFunction.h (+1) 
- (modified) clang/lib/Parse/ParseOpenMP.cpp (+59) 
- (modified) clang/lib/Sema/SemaExceptionSpec.cpp (+1) 
- (modified) clang/lib/Sema/SemaOpenMP.cpp (+271) 
- (modified) clang/lib/Sema/TreeTransform.h (+44) 
- (modified) clang/lib/Serialization/ASTReader.cpp (+15) 
- (modified) clang/lib/Serialization/ASTReaderStmt.cpp (+11) 
- (modified) clang/lib/Serialization/ASTWriter.cpp (+11) 
- (modified) clang/lib/Serialization/ASTWriterStmt.cpp (+5) 
- (modified) clang/lib/StaticAnalyzer/Core/ExprEngine.cpp (+1) 
- (added) clang/test/AST/ast-dump-openmp-split.c (+19) 
- (added) clang/test/Index/openmp-split.c (+11) 
- (added) clang/test/OpenMP/split_analyze.c (+10) 
- (added) clang/test/OpenMP/split_ast_print.cpp (+71) 
- (added) clang/test/OpenMP/split_codegen.cpp (+1986) 
- (added) clang/test/OpenMP/split_composition.cpp (+17) 
- (added) clang/test/OpenMP/split_compound_associated.cpp (+13) 
- (added) clang/test/OpenMP/split_counts_constexpr.cpp (+19) 
- (added) clang/test/OpenMP/split_counts_ice.c (+56) 
- (added) clang/test/OpenMP/split_counts_verify.c (+123) 
- (added) clang/test/OpenMP/split_diag_errors.c (+61) 
- (added) clang/test/OpenMP/split_distribute_inner_split.cpp (+14) 
- (added) clang/test/OpenMP/split_driver_smoke.c (+12) 
- (added) clang/test/OpenMP/split_iv_types.c (+24) 
- (added) clang/test/OpenMP/split_loop_styles.cpp (+14) 
- (added) clang/test/OpenMP/split_member_ctor.cpp (+20) 
- (added) clang/test/OpenMP/split_messages.cpp (+108) 
- (added) clang/test/OpenMP/split_nested_outer_only.c (+12) 
- (added) clang/test/OpenMP/split_offload_codegen.cpp (+27) 
- (added) clang/test/OpenMP/split_omp_fill.c (+36) 
- (added) clang/test/OpenMP/split_openmp_version.cpp (+22) 
- (added) clang/test/OpenMP/split_opts_simd_debug.cpp (+30) 
- (added) clang/test/OpenMP/split_parallel_split.cpp (+15) 
- (added) clang/test/OpenMP/split_pch_codegen.cpp (+43) 
- (added) clang/test/OpenMP/split_range_for_diag.cpp (+25) 
- (added) clang/test/OpenMP/split_serialize_module.cpp (+24) 
- (added) clang/test/OpenMP/split_teams_nesting.cpp (+13) 
- (added) clang/test/OpenMP/split_template_nttp.cpp (+15) 
- (added) clang/test/OpenMP/split_templates.cpp (+30) 
- (added) clang/test/OpenMP/split_trip_volatile.c (+14) 
- (modified) clang/tools/libclang/CIndex.cpp (+7) 
- (modified) clang/tools/libclang/CXCursor.cpp (+3) 
- (modified) clang/unittests/ASTMatchers/ASTMatchersNodeTest.cpp (+62) 
- (modified) clang/unittests/ASTMatchers/ASTMatchersTest.h (+14) 
- (modified) llvm/include/llvm/Frontend/OpenMP/OMP.td (+11-10) 
- (added) openmp/runtime/test/transform/split/fill_first.c (+23) 
- (added) openmp/runtime/test/transform/split/foreach.cpp (+24) 
- (added) openmp/runtime/test/transform/split/intfor.c (+26) 
- (added) openmp/runtime/test/transform/split/intfor_negstart.c (+27) 
- (added) openmp/runtime/test/transform/split/iterfor.cpp (+139) 
- (added) openmp/runtime/test/transform/split/leq_bound.c (+22) 
- (added) openmp/runtime/test/transform/split/lit.local.cfg (+5) 
- (added) openmp/runtime/test/transform/split/negative_incr.c (+22) 
- (added) openmp/runtime/test/transform/split/nonconstant_incr.c (+22) 
- (added) openmp/runtime/test/transform/split/parallel-split-intfor.c (+27) 
- (added) openmp/runtime/test/transform/split/single_fill.c (+23) 
- (added) openmp/runtime/test/transform/split/three_segments.c (+26) 
- (added) openmp/runtime/test/transform/split/trip_one.c (+32) 
- (added) openmp/runtime/test/transform/split/unsigned_iv.c (+24) 
- (added) openmp/runtime/test/transform/split/zero_first_segment.c (+21) 


``````````diff
diff --git a/clang/bindings/python/clang/cindex.py b/clang/bindings/python/clang/cindex.py
index b71f9ed2275e0..a90d48cf6d481 100644
--- a/clang/bindings/python/clang/cindex.py
+++ b/clang/bindings/python/clang/cindex.py
@@ -1453,6 +1453,9 @@ def is_unexposed(self):
     # OpenMP fuse directive.
     OMP_FUSE_DIRECTIVE = 311
 
+    # OpenMP split directive.
+    OMP_SPLIT_DIRECTIVE = 312
+
     # OpenACC Compute Construct.
     OPEN_ACC_COMPUTE_DIRECTIVE = 320
 
diff --git a/clang/include/clang-c/Index.h b/clang/include/clang-c/Index.h
index dcf1f4f1b4258..119bd68ff9814 100644
--- a/clang/include/clang-c/Index.h
+++ b/clang/include/clang-c/Index.h
@@ -2166,6 +2166,10 @@ enum CXCursorKind {
    */
   CXCursor_OMPFuseDirective = 311,
 
+  /** OpenMP split directive.
+   */
+  CXCursor_OMPSplitDirective = 312,
+
   /** OpenACC Compute Construct.
    */
   CXCursor_OpenACCComputeConstruct = 320,
diff --git a/clang/include/clang/AST/OpenMPClause.h b/clang/include/clang/AST/OpenMPClause.h
index af5d3f4698eda..ccf2c40bc5efa 100644
--- a/clang/include/clang/AST/OpenMPClause.h
+++ b/clang/include/clang/AST/OpenMPClause.h
@@ -39,6 +39,7 @@
 #include "llvm/Support/Compiler.h"
 #include "llvm/Support/TrailingObjects.h"
 #include <cassert>
+#include <climits>
 #include <cstddef>
 #include <iterator>
 #include <utility>
@@ -1023,6 +1024,106 @@ class OMPSizesClause final
   }
 };
 
+/// This represents the 'counts' clause in the '#pragma omp split' directive.
+///
+/// \code
+/// #pragma omp split counts(3, omp_fill, 2)
+/// for (int i = 0; i < n; ++i) { ... }
+/// \endcode
+class OMPCountsClause final
+    : public OMPClause,
+      private llvm::TrailingObjects<OMPCountsClause, Expr *> {
+  friend class OMPClauseReader;
+  friend class llvm::TrailingObjects<OMPCountsClause, Expr *>;
+
+  /// Location of '('.
+  SourceLocation LParenLoc;
+
+  /// Number of count expressions in the clause.
+  unsigned NumCounts = 0;
+
+  /// 0-based index of the omp_fill list item.
+  std::optional<unsigned> OmpFillIndex;
+
+  /// Source location of the omp_fill keyword.
+  SourceLocation OmpFillLoc;
+
+  /// Build an empty clause.
+  explicit OMPCountsClause(int NumCounts)
+      : OMPClause(llvm::omp::OMPC_counts, SourceLocation(), SourceLocation()),
+        NumCounts(NumCounts) {}
+
+  /// Sets the location of '('.
+  void setLParenLoc(SourceLocation Loc) { LParenLoc = Loc; }
+  void setOmpFillIndex(std::optional<unsigned> Idx) { OmpFillIndex = Idx; }
+  void setOmpFillLoc(SourceLocation Loc) { OmpFillLoc = Loc; }
+
+  /// Sets the count expressions.
+  void setCountsRefs(ArrayRef<Expr *> VL) {
+    assert(VL.size() == NumCounts);
+    llvm::copy(VL, getCountsRefs().begin());
+  }
+
+public:
+  /// Build a 'counts' AST node.
+  ///
+  /// \param C         Context of the AST.
+  /// \param StartLoc  Location of the 'counts' identifier.
+  /// \param LParenLoc Location of '('.
+  /// \param EndLoc    Location of ')'.
+  /// \param Counts    Content of the clause.
+  static OMPCountsClause *Create(const ASTContext &C, SourceLocation StartLoc,
+                                 SourceLocation LParenLoc,
+                                 SourceLocation EndLoc, ArrayRef<Expr *> Counts,
+                                 std::optional<unsigned> FillIdx,
+                                 SourceLocation FillLoc);
+
+  /// Build an empty 'counts' AST node for deserialization.
+  ///
+  /// \param C          Context of the AST.
+  /// \param NumCounts   Number of items in the clause.
+  static OMPCountsClause *CreateEmpty(const ASTContext &C, unsigned NumCounts);
+
+  /// Returns the location of '('.
+  SourceLocation getLParenLoc() const { return LParenLoc; }
+
+  /// Returns the number of list items.
+  unsigned getNumCounts() const { return NumCounts; }
+
+  std::optional<unsigned> getOmpFillIndex() const { return OmpFillIndex; }
+  SourceLocation getOmpFillLoc() const { return OmpFillLoc; }
+  bool hasOmpFill() const { return OmpFillIndex.has_value(); }
+
+  /// Returns the count expressions.
+  MutableArrayRef<Expr *> getCountsRefs() {
+    return getTrailingObjects(NumCounts);
+  }
+  ArrayRef<Expr *> getCountsRefs() const {
+    return getTrailingObjects(NumCounts);
+  }
+
+  child_range children() {
+    MutableArrayRef<Expr *> Counts = getCountsRefs();
+    return child_range(reinterpret_cast<Stmt **>(Counts.begin()),
+                       reinterpret_cast<Stmt **>(Counts.end()));
+  }
+  const_child_range children() const {
+    ArrayRef<Expr *> Counts = getCountsRefs();
+    return const_child_range(reinterpret_cast<Stmt *const *>(Counts.begin()),
+                             reinterpret_cast<Stmt *const *>(Counts.end()));
+  }
+  child_range used_children() {
+    return child_range(child_iterator(), child_iterator());
+  }
+  const_child_range used_children() const {
+    return const_child_range(const_child_iterator(), const_child_iterator());
+  }
+
+  static bool classof(const OMPClause *T) {
+    return T->getClauseKind() == llvm::omp::OMPC_counts;
+  }
+};
+
 /// This class represents the 'permutation' clause in the
 /// '#pragma omp interchange' directive.
 ///
diff --git a/clang/include/clang/AST/RecursiveASTVisitor.h b/clang/include/clang/AST/RecursiveASTVisitor.h
index ce6ad723191e0..1a14dd2c666b5 100644
--- a/clang/include/clang/AST/RecursiveASTVisitor.h
+++ b/clang/include/clang/AST/RecursiveASTVisitor.h
@@ -3202,6 +3202,9 @@ DEF_TRAVERSE_STMT(OMPFuseDirective,
 DEF_TRAVERSE_STMT(OMPInterchangeDirective,
                   { TRY_TO(TraverseOMPExecutableDirective(S)); })
 
+DEF_TRAVERSE_STMT(OMPSplitDirective,
+                  { TRY_TO(TraverseOMPExecutableDirective(S)); })
+
 DEF_TRAVERSE_STMT(OMPForDirective,
                   { TRY_TO(TraverseOMPExecutableDirective(S)); })
 
@@ -3503,6 +3506,13 @@ bool RecursiveASTVisitor<Derived>::VisitOMPSizesClause(OMPSizesClause *C) {
   return true;
 }
 
+template <typename Derived>
+bool RecursiveASTVisitor<Derived>::VisitOMPCountsClause(OMPCountsClause *C) {
+  for (Expr *E : C->getCountsRefs())
+    TRY_TO(TraverseStmt(E));
+  return true;
+}
+
 template <typename Derived>
 bool RecursiveASTVisitor<Derived>::VisitOMPPermutationClause(
     OMPPermutationClause *C) {
diff --git a/clang/include/clang/AST/StmtOpenMP.h b/clang/include/clang/AST/StmtOpenMP.h
index bc6aeaa8d143c..dbc76e7df8ecd 100644
--- a/clang/include/clang/AST/StmtOpenMP.h
+++ b/clang/include/clang/AST/StmtOpenMP.h
@@ -6065,6 +6065,84 @@ class OMPFuseDirective final
   }
 };
 
+/// Represents the '#pragma omp split' loop transformation directive.
+///
+/// \code{.c}
+///   #pragma omp split counts(3, omp_fill, 2)
+///   for (int i = 0; i < n; ++i)
+///     ...
+/// \endcode
+///
+/// This directive transforms a single loop into multiple loops based on
+/// index ranges. The transformation splits the iteration space of the loop
+/// into multiple contiguous ranges. The \c counts clause is required and
+/// exactly one list item must be \c omp_fill.
+class OMPSplitDirective final
+    : public OMPCanonicalLoopNestTransformationDirective {
+  friend class ASTStmtReader;
+  friend class OMPExecutableDirective;
+
+  /// Offsets of child members.
+  enum {
+    PreInitsOffset = 0,
+    TransformedStmtOffset,
+  };
+
+  explicit OMPSplitDirective(SourceLocation StartLoc, SourceLocation EndLoc,
+                             unsigned NumLoops)
+      : OMPCanonicalLoopNestTransformationDirective(
+            OMPSplitDirectiveClass, llvm::omp::OMPD_split, StartLoc, EndLoc,
+            NumLoops) {}
+
+  void setPreInits(Stmt *PreInits) {
+    Data->getChildren()[PreInitsOffset] = PreInits;
+  }
+
+  void setTransformedStmt(Stmt *S) {
+    Data->getChildren()[TransformedStmtOffset] = S;
+  }
+
+public:
+  /// Create a new AST node representation for '#pragma omp split'.
+  ///
+  /// \param C         Context of the AST.
+  /// \param StartLoc  Location of the introducer (e.g. the 'omp' token).
+  /// \param EndLoc    Location of the directive's end (e.g. the tok::eod).
+  /// \param Clauses   The directive's clauses (e.g. the required \c counts
+  ///                  clause).
+  /// \param NumLoops  Number of affected loops (should be 1 for split).
+  /// \param AssociatedStmt  The outermost associated loop.
+  /// \param TransformedStmt The loop nest after splitting, or nullptr in
+  ///                        dependent contexts.
+  /// \param PreInits   Helper preinits statements for the loop nest.
+  static OMPSplitDirective *Create(const ASTContext &C, SourceLocation StartLoc,
+                                   SourceLocation EndLoc,
+                                   ArrayRef<OMPClause *> Clauses,
+                                   unsigned NumLoops, Stmt *AssociatedStmt,
+                                   Stmt *TransformedStmt, Stmt *PreInits);
+
+  /// Build an empty '#pragma omp split' AST node for deserialization.
+  ///
+  /// \param C          Context of the AST.
+  /// \param NumClauses Number of clauses to allocate.
+  /// \param NumLoops   Number of associated loops to allocate.
+  static OMPSplitDirective *CreateEmpty(const ASTContext &C,
+                                        unsigned NumClauses, unsigned NumLoops);
+
+  /// Gets/sets the associated loops after the transformation, i.e. after
+  /// de-sugaring.
+  Stmt *getTransformedStmt() const {
+    return Data->getChildren()[TransformedStmtOffset];
+  }
+
+  /// Return preinits statement.
+  Stmt *getPreInits() const { return Data->getChildren()[PreInitsOffset]; }
+
+  static bool classof(const Stmt *T) {
+    return T->getStmtClass() == OMPSplitDirectiveClass;
+  }
+};
+
 /// This represents '#pragma omp scan' directive.
 ///
 /// \code
diff --git a/clang/include/clang/ASTMatchers/ASTMatchers.h b/clang/include/clang/ASTMatchers/ASTMatchers.h
index e8e7643e0dddd..87b6dbefa7a62 100644
--- a/clang/include/clang/ASTMatchers/ASTMatchers.h
+++ b/clang/include/clang/ASTMatchers/ASTMatchers.h
@@ -8781,6 +8781,26 @@ extern const internal::VariadicDynCastAllOfMatcher<Stmt,
                                                    OMPTargetUpdateDirective>
     ompTargetUpdateDirective;
 
+/// Matches any ``#pragma omp split`` executable directive.
+///
+/// Given
+///
+/// \code
+///   #pragma omp split counts(2, omp_fill)
+///   for (int i = 0; i < n; ++i) {}
+/// \endcode
+///
+/// ``ompSplitDirective()`` matches the split directive.
+extern const internal::VariadicDynCastAllOfMatcher<Stmt, OMPSplitDirective>
+    ompSplitDirective;
+
+/// Matches OpenMP ``counts`` clause used by ``#pragma omp split``.
+///
+/// Given ``#pragma omp split counts(1, 2, omp_fill)``, ``ompCountsClause()``
+/// matches the ``counts`` clause node.
+extern const internal::VariadicDynCastAllOfMatcher<OMPClause, OMPCountsClause>
+    ompCountsClause;
+
 /// Matches OpenMP ``default`` clause.
 ///
 /// Given
diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index d5904bd1d6f26..71d504c659cc2 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -11176,6 +11176,8 @@ def err_omp_bind_required_on_loop : Error<
   "construct">;
 def err_omp_loop_reduction_clause : Error<
   "'reduction' clause not allowed with '#pragma omp loop bind(teams)'">;
+def err_omp_split_counts_not_one_omp_fill : Error<
+  "exactly one 'omp_fill' must appear in the 'counts' clause">;
 def warn_break_binds_to_switch : Warning<
   "'break' is bound to loop, GCC binds it to switch">,
   InGroup<GccCompat>;
diff --git a/clang/include/clang/Basic/StmtNodes.td b/clang/include/clang/Basic/StmtNodes.td
index 61d76bafdfcde..e166894ea024b 100644
--- a/clang/include/clang/Basic/StmtNodes.td
+++ b/clang/include/clang/Basic/StmtNodes.td
@@ -244,6 +244,7 @@ def OMPTileDirective : StmtNode<OMPCanonicalLoopNestTransformationDirective>;
 def OMPStripeDirective : StmtNode<OMPCanonicalLoopNestTransformationDirective>;
 def OMPUnrollDirective : StmtNode<OMPCanonicalLoopNestTransformationDirective>;
 def OMPReverseDirective : StmtNode<OMPCanonicalLoopNestTransformationDirective>;
+def OMPSplitDirective : StmtNode<OMPCanonicalLoopNestTransformationDirective>;
 def OMPInterchangeDirective
     : StmtNode<OMPCanonicalLoopNestTransformationDirective>;
 def OMPCanonicalLoopSequenceTransformationDirective
diff --git a/clang/include/clang/Parse/Parser.h b/clang/include/clang/Parse/Parser.h
index 08a3d88ee6a36..bd313d37cc4b5 100644
--- a/clang/include/clang/Parse/Parser.h
+++ b/clang/include/clang/Parse/Parser.h
@@ -6812,6 +6812,9 @@ class Parser : public CodeCompletionHandler {
   /// Parses the 'sizes' clause of a '#pragma omp tile' directive.
   OMPClause *ParseOpenMPSizesClause();
 
+  /// Parses the 'counts' clause of a '#pragma omp split' directive.
+  OMPClause *ParseOpenMPCountsClause();
+
   /// Parses the 'permutation' clause of a '#pragma omp interchange' directive.
   OMPClause *ParseOpenMPPermutationClause();
 
diff --git a/clang/include/clang/Sema/SemaOpenMP.h b/clang/include/clang/Sema/SemaOpenMP.h
index 7853f29f98c25..3621ce96b8724 100644
--- a/clang/include/clang/Sema/SemaOpenMP.h
+++ b/clang/include/clang/Sema/SemaOpenMP.h
@@ -42,6 +42,7 @@ class FunctionScopeInfo;
 
 class DeclContext;
 class DeclGroupRef;
+class EnumConstantDecl;
 class ParsedAttr;
 class Scope;
 
@@ -457,6 +458,11 @@ class SemaOpenMP : public SemaBase {
   /// Called on well-formed '#pragma omp reverse'.
   StmtResult ActOnOpenMPReverseDirective(Stmt *AStmt, SourceLocation StartLoc,
                                          SourceLocation EndLoc);
+  /// Called on well-formed '#pragma omp split' after parsing of its
+  /// associated statement.
+  StmtResult ActOnOpenMPSplitDirective(ArrayRef<OMPClause *> Clauses,
+                                       Stmt *AStmt, SourceLocation StartLoc,
+                                       SourceLocation EndLoc);
   /// Called on well-formed '#pragma omp interchange' after parsing of its
   /// clauses and the associated statement.
   StmtResult ActOnOpenMPInterchangeDirective(ArrayRef<OMPClause *> Clauses,
@@ -911,6 +917,12 @@ class SemaOpenMP : public SemaBase {
                                     SourceLocation StartLoc,
                                     SourceLocation LParenLoc,
                                     SourceLocation EndLoc);
+  /// Called on well-formed 'counts' clause after parsing its arguments.
+  OMPClause *
+  ActOnOpenMPCountsClause(ArrayRef<Expr *> CountExprs, SourceLocation StartLoc,
+                          SourceLocation LParenLoc, SourceLocation EndLoc,
+                          std::optional<unsigned> FillIdx,
+                          SourceLocation FillLoc, unsigned FillCount);
   /// Called on well-form 'permutation' clause after parsing its arguments.
   OMPClause *ActOnOpenMPPermutationClause(ArrayRef<Expr *> PermExprs,
                                           SourceLocation StartLoc,
diff --git a/clang/include/clang/Serialization/ASTBitCodes.h b/clang/include/clang/Serialization/ASTBitCodes.h
index 783cd82895a90..9b798ed484454 100644
--- a/clang/include/clang/Serialization/ASTBitCodes.h
+++ b/clang/include/clang/Serialization/ASTBitCodes.h
@@ -1965,6 +1965,7 @@ enum StmtCode {
   STMP_OMP_STRIPE_DIRECTIVE,
   STMT_OMP_UNROLL_DIRECTIVE,
   STMT_OMP_REVERSE_DIRECTIVE,
+  STMT_OMP_SPLIT_DIRECTIVE,
   STMT_OMP_INTERCHANGE_DIRECTIVE,
   STMT_OMP_FUSE_DIRECTIVE,
   STMT_OMP_FOR_DIRECTIVE,
diff --git a/clang/lib/AST/OpenMPClause.cpp b/clang/lib/AST/OpenMPClause.cpp
index d4826c3c6edca..3a35e17aff40b 100644
--- a/clang/lib/AST/OpenMPClause.cpp
+++ b/clang/lib/AST/OpenMPClause.cpp
@@ -15,10 +15,12 @@
 #include "clang/AST/Attr.h"
 #include "clang/AST/Decl.h"
 #include "clang/AST/DeclOpenMP.h"
+#include "clang/AST/Expr.h"
 #include "clang/AST/ExprOpenMP.h"
 #include "clang/Basic/LLVM.h"
 #include "clang/Basic/OpenMPKinds.h"
 #include "clang/Basic/TargetInfo.h"
+#include "llvm/ADT/Sequence.h"
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/Support/ErrorHandling.h"
 #include <algorithm>
@@ -986,6 +988,26 @@ OMPSizesClause *OMPSizesClause::CreateEmpty(const ASTContext &C,
   return new (Mem) OMPSizesClause(NumSizes);
 }
 
+OMPCountsClause *OMPCountsClause::Create(
+    const ASTContext &C, SourceLocation StartLoc, SourceLocation LParenLoc,
+    SourceLocation EndLoc, ArrayRef<Expr *> Counts,
+    std::optional<unsigned> FillIdx, SourceLocation FillLoc) {
+  OMPCountsClause *Clause = CreateEmpty(C, Counts.size());
+  Clause->setLocStart(StartLoc);
+  Clause->setLParenLoc(LParenLoc);
+  Clause->setLocEnd(EndLoc);
+  Clause->setCountsRefs(Counts);
+  Clause->setOmpFillIndex(FillIdx);
+  Clause->setOmpFillLoc(FillLoc);
+  return Clause;
+}
+
+OMPCountsClause *OMPCountsClause::CreateEmpty(const ASTContext &C,
+                                              unsigned NumCounts) {
+  void *Mem = C.Allocate(totalSizeToAlloc<Expr *>(NumCounts));
+  return new (Mem) OMPCountsClause(NumCounts);
+}
+
 OMPPermutationClause *OMPPermutationClause::Create(const ASTContext &C,
                                                    SourceLocation StartLoc,
                                                    SourceLocation LParenLoc,
@@ -1984,6 +2006,19 @@ void OMPClausePrinter::VisitOMPSizesClause(OMPSizesClause *Node) {
   OS << ")";
 }
 
+void OMPClausePrinter::VisitOMPCountsClause(OMPCountsClause *Node) {
+  OS << "counts(";
+  std::optional<unsigned> FillIdx = Node->getOmpFillIndex();
+  ArrayRef<Expr *> Refs = Node->getCountsRefs();
+  llvm::interleaveComma(llvm::seq<unsigned>(Refs.size()), OS, [&](unsigned I) {
+    if (FillIdx && I == *FillIdx)
+      OS << "omp_fill";
+    else
+      Refs[I]->printPretty(OS, nullptr, Policy, 0);
+  });
+  OS << ")";
+}
+
 void OMPClausePrinter::VisitOMPPermutationClause(OMPPermutationClause *Node) {
   OS << "permutation(";
   llvm::interleaveComma(Node->getArgsRefs(), OS, [&](const Expr *E) {
diff --git a/clang/lib/AST/StmtOpenMP.cpp b/clang/lib/AST/StmtOpenMP.cpp
index a5b0cd3786a28..9d6b315effb41 100644
--- a/clang/lib/AST/StmtOpenMP.cpp
+++ b/clang/lib/AST/StmtOpenMP.cpp
@@ -552,6 +552,27 @@ OMPInterchangeDirective::CreateEmpty(const ASTContext &C, unsigned NumClauses,
       SourceLocation(), SourceLocation(), NumLoops);
 }
 
+OMPSplitDirective *
+OMPSplitDirective::Create(const ASTContext &C, SourceLocation StartLoc,
+                          SourceLocation EndLoc, ArrayRef<OMPClause *> Clauses,
+                          unsigned NumLoops, Stmt *AssociatedStmt,
+                          Stmt *TransformedStmt, Stmt *PreInits) {
+  OMPSplitDirective *Dir = createDirective<OMPSplitDirective>(
+      C, Clauses, AssociatedStmt, TransformedStmtOffset + 1, StartLoc, EndLoc,
+      NumLoops);
+  Dir->setTransformedStmt(TransformedStmt);
+  Dir->setPreInits(PreInits);
+  return Dir;
+}
+
+OMPSplitDirective *OMPSplitDirective::CreateEmpty(const ASTContext &C,
+                                                  unsigned NumClauses,
+                                                  unsigned NumLoops) {
+  return createEmptyDirective<OMPSplitDirective>(
+      C, NumClauses, /*HasAssociatedStmt=*/true, TransformedStmtOffset + 1,
+      SourceLocation(), SourceLocation(), NumLoops);
+}
+
 OMPFuseDirective *OMPFuseDirective::Create(
     const ASTContext &C, SourceLocation StartLoc, SourceLocation EndLoc,
     ArrayRef<OMPClause *> Clauses, unsigned NumGeneratedTopLevelLoops,
diff --git a/clang/lib/AST/StmtPrinter.cpp b/clang/lib/AST/StmtPrinter.cpp
index 4d364fdcd5502..e0b930ba0a21a 100644
--- a/clang/lib/AST/StmtPrinter.cpp
+++ b/clang/lib/AST/StmtPrinter.cpp
@@ -800,6 +800,11 @@ void StmtPrinter::VisitOMPInterchangeDirective(OMPInterchangeDirective *Node) {
   PrintOMPExecutableDirective(Node);
 }
 
+void StmtPrinter::VisitOMPSplitDirective(OMPSplitDirective *Node) {
+  Indent() << "#pragma omp split";
+  PrintOMPExecutableDirective(Node);
+}
+
 void StmtPrinter::VisitOMPFuseDirective(OMPFuseDirective *Node) {
   Indent() << "#pragma omp fuse";
   PrintOMPExecutableDirective(Node);
diff --git a/clang/lib/AST/StmtProfile.cpp b/clang/lib/AST/StmtProfile.cpp
index e8c1f8a8ecb5f..c75652e5c1dd3 100644
--- a/clang/lib/AST/StmtProfile.cpp
+++ b/clang/lib/AST/StmtProfile.cpp
@@ -498,6 +498,12 @@ void OMPClauseProfiler::VisitOMPSizesClause(const OMPSizesClause *C) {
       Profiler->VisitExpr(E);
 }
 
+void OMPClauseProfiler::VisitOMPCountsClause(const OMPCountsClause *C) {
+  for (auto *E : C->getCountsRefs())
+    if ...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/190397


More information about the llvm-commits mailing list