[clang] [llvm] [openmp] [Clang][OpenMP] Implement Loop splitting `#pragma omp split` directive (PR #190397)
via llvm-commits
llvm-commits at lists.llvm.org
Fri Apr 3 13:12:20 PDT 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-clang-codegen
Author: Amit Tiwari (amitamd7)
<details>
<summary>Changes</summary>
Implement the loop-splitting `#pragma omp split` construct with the `counts` clause.
Posting this PR after the revert of PR [#183261](https://github.com/llvm/llvm-project/pull/183261).
---
Patch is 218.00 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/190397.diff
80 Files Affected:
- (modified) clang/bindings/python/clang/cindex.py (+3)
- (modified) clang/include/clang-c/Index.h (+4)
- (modified) clang/include/clang/AST/OpenMPClause.h (+101)
- (modified) clang/include/clang/AST/RecursiveASTVisitor.h (+10)
- (modified) clang/include/clang/AST/StmtOpenMP.h (+78)
- (modified) clang/include/clang/ASTMatchers/ASTMatchers.h (+20)
- (modified) clang/include/clang/Basic/DiagnosticSemaKinds.td (+2)
- (modified) clang/include/clang/Basic/StmtNodes.td (+1)
- (modified) clang/include/clang/Parse/Parser.h (+3)
- (modified) clang/include/clang/Sema/SemaOpenMP.h (+12)
- (modified) clang/include/clang/Serialization/ASTBitCodes.h (+1)
- (modified) clang/lib/AST/OpenMPClause.cpp (+35)
- (modified) clang/lib/AST/StmtOpenMP.cpp (+21)
- (modified) clang/lib/AST/StmtPrinter.cpp (+5)
- (modified) clang/lib/AST/StmtProfile.cpp (+10)
- (modified) clang/lib/ASTMatchers/ASTMatchersInternal.cpp (+4)
- (modified) clang/lib/ASTMatchers/Dynamic/Registry.cpp (+2)
- (modified) clang/lib/Basic/OpenMPKinds.cpp (+4-1)
- (modified) clang/lib/CodeGen/CGStmt.cpp (+3)
- (modified) clang/lib/CodeGen/CGStmtOpenMP.cpp (+8)
- (modified) clang/lib/CodeGen/CodeGenFunction.h (+1)
- (modified) clang/lib/Parse/ParseOpenMP.cpp (+59)
- (modified) clang/lib/Sema/SemaExceptionSpec.cpp (+1)
- (modified) clang/lib/Sema/SemaOpenMP.cpp (+271)
- (modified) clang/lib/Sema/TreeTransform.h (+44)
- (modified) clang/lib/Serialization/ASTReader.cpp (+15)
- (modified) clang/lib/Serialization/ASTReaderStmt.cpp (+11)
- (modified) clang/lib/Serialization/ASTWriter.cpp (+11)
- (modified) clang/lib/Serialization/ASTWriterStmt.cpp (+5)
- (modified) clang/lib/StaticAnalyzer/Core/ExprEngine.cpp (+1)
- (added) clang/test/AST/ast-dump-openmp-split.c (+19)
- (added) clang/test/Index/openmp-split.c (+11)
- (added) clang/test/OpenMP/split_analyze.c (+10)
- (added) clang/test/OpenMP/split_ast_print.cpp (+71)
- (added) clang/test/OpenMP/split_codegen.cpp (+1986)
- (added) clang/test/OpenMP/split_composition.cpp (+17)
- (added) clang/test/OpenMP/split_compound_associated.cpp (+13)
- (added) clang/test/OpenMP/split_counts_constexpr.cpp (+19)
- (added) clang/test/OpenMP/split_counts_ice.c (+56)
- (added) clang/test/OpenMP/split_counts_verify.c (+123)
- (added) clang/test/OpenMP/split_diag_errors.c (+61)
- (added) clang/test/OpenMP/split_distribute_inner_split.cpp (+14)
- (added) clang/test/OpenMP/split_driver_smoke.c (+12)
- (added) clang/test/OpenMP/split_iv_types.c (+24)
- (added) clang/test/OpenMP/split_loop_styles.cpp (+14)
- (added) clang/test/OpenMP/split_member_ctor.cpp (+20)
- (added) clang/test/OpenMP/split_messages.cpp (+108)
- (added) clang/test/OpenMP/split_nested_outer_only.c (+12)
- (added) clang/test/OpenMP/split_offload_codegen.cpp (+27)
- (added) clang/test/OpenMP/split_omp_fill.c (+36)
- (added) clang/test/OpenMP/split_openmp_version.cpp (+22)
- (added) clang/test/OpenMP/split_opts_simd_debug.cpp (+30)
- (added) clang/test/OpenMP/split_parallel_split.cpp (+15)
- (added) clang/test/OpenMP/split_pch_codegen.cpp (+43)
- (added) clang/test/OpenMP/split_range_for_diag.cpp (+25)
- (added) clang/test/OpenMP/split_serialize_module.cpp (+24)
- (added) clang/test/OpenMP/split_teams_nesting.cpp (+13)
- (added) clang/test/OpenMP/split_template_nttp.cpp (+15)
- (added) clang/test/OpenMP/split_templates.cpp (+30)
- (added) clang/test/OpenMP/split_trip_volatile.c (+14)
- (modified) clang/tools/libclang/CIndex.cpp (+7)
- (modified) clang/tools/libclang/CXCursor.cpp (+3)
- (modified) clang/unittests/ASTMatchers/ASTMatchersNodeTest.cpp (+62)
- (modified) clang/unittests/ASTMatchers/ASTMatchersTest.h (+14)
- (modified) llvm/include/llvm/Frontend/OpenMP/OMP.td (+11-10)
- (added) openmp/runtime/test/transform/split/fill_first.c (+23)
- (added) openmp/runtime/test/transform/split/foreach.cpp (+24)
- (added) openmp/runtime/test/transform/split/intfor.c (+26)
- (added) openmp/runtime/test/transform/split/intfor_negstart.c (+27)
- (added) openmp/runtime/test/transform/split/iterfor.cpp (+139)
- (added) openmp/runtime/test/transform/split/leq_bound.c (+22)
- (added) openmp/runtime/test/transform/split/lit.local.cfg (+5)
- (added) openmp/runtime/test/transform/split/negative_incr.c (+22)
- (added) openmp/runtime/test/transform/split/nonconstant_incr.c (+22)
- (added) openmp/runtime/test/transform/split/parallel-split-intfor.c (+27)
- (added) openmp/runtime/test/transform/split/single_fill.c (+23)
- (added) openmp/runtime/test/transform/split/three_segments.c (+26)
- (added) openmp/runtime/test/transform/split/trip_one.c (+32)
- (added) openmp/runtime/test/transform/split/unsigned_iv.c (+24)
- (added) openmp/runtime/test/transform/split/zero_first_segment.c (+21)
``````````diff
diff --git a/clang/bindings/python/clang/cindex.py b/clang/bindings/python/clang/cindex.py
index b71f9ed2275e0..a90d48cf6d481 100644
--- a/clang/bindings/python/clang/cindex.py
+++ b/clang/bindings/python/clang/cindex.py
@@ -1453,6 +1453,9 @@ def is_unexposed(self):
# OpenMP fuse directive.
OMP_FUSE_DIRECTIVE = 311
+ # OpenMP split directive.
+ OMP_SPLIT_DIRECTIVE = 312
+
# OpenACC Compute Construct.
OPEN_ACC_COMPUTE_DIRECTIVE = 320
diff --git a/clang/include/clang-c/Index.h b/clang/include/clang-c/Index.h
index dcf1f4f1b4258..119bd68ff9814 100644
--- a/clang/include/clang-c/Index.h
+++ b/clang/include/clang-c/Index.h
@@ -2166,6 +2166,10 @@ enum CXCursorKind {
*/
CXCursor_OMPFuseDirective = 311,
+ /** OpenMP split directive.
+ */
+ CXCursor_OMPSplitDirective = 312,
+
/** OpenACC Compute Construct.
*/
CXCursor_OpenACCComputeConstruct = 320,
diff --git a/clang/include/clang/AST/OpenMPClause.h b/clang/include/clang/AST/OpenMPClause.h
index af5d3f4698eda..ccf2c40bc5efa 100644
--- a/clang/include/clang/AST/OpenMPClause.h
+++ b/clang/include/clang/AST/OpenMPClause.h
@@ -39,6 +39,7 @@
#include "llvm/Support/Compiler.h"
#include "llvm/Support/TrailingObjects.h"
#include <cassert>
+#include <climits>
#include <cstddef>
#include <iterator>
#include <utility>
@@ -1023,6 +1024,106 @@ class OMPSizesClause final
}
};
+/// This represents the 'counts' clause in the '#pragma omp split' directive.
+///
+/// \code
+/// #pragma omp split counts(3, omp_fill, 2)
+/// for (int i = 0; i < n; ++i) { ... }
+/// \endcode
+///
+/// The list items are stored as trailing \c Expr* slots, one per item in list
+/// order. The position of the \c omp_fill item, if any, is recorded
+/// separately in \c OmpFillIndex together with the keyword's location.
+class OMPCountsClause final
+ : public OMPClause,
+ private llvm::TrailingObjects<OMPCountsClause, Expr *> {
+ friend class OMPClauseReader;
+ friend class llvm::TrailingObjects<OMPCountsClause, Expr *>;
+
+ /// Location of '('.
+ SourceLocation LParenLoc;
+
+ /// Number of count expression slots (trailing objects) in the clause,
+ /// including the slot at the omp_fill position.
+ unsigned NumCounts = 0;
+
+ /// 0-based index of the omp_fill list item, or std::nullopt if the clause
+ /// records no omp_fill item.
+ std::optional<unsigned> OmpFillIndex;
+
+ /// Source location of the omp_fill keyword.
+ SourceLocation OmpFillLoc;
+
+ /// Build an empty clause with room for \p NumCounts trailing expressions.
+ /// Locations, list items and the fill index are set afterwards through the
+ /// private setters.
+ explicit OMPCountsClause(int NumCounts)
+ : OMPClause(llvm::omp::OMPC_counts, SourceLocation(), SourceLocation()),
+ NumCounts(NumCounts) {}
+
+ /// Sets the location of '('.
+ void setLParenLoc(SourceLocation Loc) { LParenLoc = Loc; }
+ /// Sets the 0-based index of the omp_fill item (std::nullopt for none).
+ void setOmpFillIndex(std::optional<unsigned> Idx) { OmpFillIndex = Idx; }
+ /// Sets the source location of the omp_fill keyword.
+ void setOmpFillLoc(SourceLocation Loc) { OmpFillLoc = Loc; }
+
+ /// Sets the count expressions. \p VL must hold exactly NumCounts entries
+ /// (asserted); they are copied into the trailing storage in order.
+ void setCountsRefs(ArrayRef<Expr *> VL) {
+ assert(VL.size() == NumCounts);
+ llvm::copy(VL, getCountsRefs().begin());
+ }
+
+public:
+ /// Build a 'counts' AST node.
+ ///
+ /// \param C Context of the AST.
+ /// \param StartLoc Location of the 'counts' identifier.
+ /// \param LParenLoc Location of '('.
+ /// \param EndLoc Location of ')'.
+ /// \param Counts Content of the clause, one entry per list item (including
+ /// the omp_fill position).
+ /// \param FillIdx 0-based index of the omp_fill item within \p Counts, or
+ /// std::nullopt if there is none.
+ /// \param FillLoc Location of the omp_fill keyword.
+ static OMPCountsClause *Create(const ASTContext &C, SourceLocation StartLoc,
+ SourceLocation LParenLoc,
+ SourceLocation EndLoc, ArrayRef<Expr *> Counts,
+ std::optional<unsigned> FillIdx,
+ SourceLocation FillLoc);
+
+ /// Build an empty 'counts' AST node for deserialization.
+ ///
+ /// \param C Context of the AST.
+ /// \param NumCounts Number of items in the clause.
+ static OMPCountsClause *CreateEmpty(const ASTContext &C, unsigned NumCounts);
+
+ /// Returns the location of '('.
+ SourceLocation getLParenLoc() const { return LParenLoc; }
+
+ /// Returns the number of list items.
+ unsigned getNumCounts() const { return NumCounts; }
+
+ /// Returns the 0-based index of the omp_fill item, or std::nullopt.
+ std::optional<unsigned> getOmpFillIndex() const { return OmpFillIndex; }
+ /// Returns the source location of the omp_fill keyword.
+ SourceLocation getOmpFillLoc() const { return OmpFillLoc; }
+ /// Returns true if the clause records an omp_fill item.
+ bool hasOmpFill() const { return OmpFillIndex.has_value(); }
+
+ /// Returns all NumCounts count-expression slots, in list order.
+ MutableArrayRef<Expr *> getCountsRefs() {
+ return getTrailingObjects(NumCounts);
+ }
+ ArrayRef<Expr *> getCountsRefs() const {
+ return getTrailingObjects(NumCounts);
+ }
+
+ /// Children are every trailing count slot, including the one at the
+ /// omp_fill index. NOTE(review): what that slot holds is not visible here;
+ /// child visitors must tolerate it (e.g. a null placeholder) — confirm.
+ child_range children() {
+ MutableArrayRef<Expr *> Counts = getCountsRefs();
+ return child_range(reinterpret_cast<Stmt **>(Counts.begin()),
+ reinterpret_cast<Stmt **>(Counts.end()));
+ }
+ const_child_range children() const {
+ ArrayRef<Expr *> Counts = getCountsRefs();
+ return const_child_range(reinterpret_cast<Stmt *const *>(Counts.begin()),
+ reinterpret_cast<Stmt *const *>(Counts.end()));
+ }
+ /// No children are reported as "used" by this clause.
+ child_range used_children() {
+ return child_range(child_iterator(), child_iterator());
+ }
+ const_child_range used_children() const {
+ return const_child_range(const_child_iterator(), const_child_iterator());
+ }
+
+ static bool classof(const OMPClause *T) {
+ return T->getClauseKind() == llvm::omp::OMPC_counts;
+ }
+};
+
/// This class represents the 'permutation' clause in the
/// '#pragma omp interchange' directive.
///
diff --git a/clang/include/clang/AST/RecursiveASTVisitor.h b/clang/include/clang/AST/RecursiveASTVisitor.h
index ce6ad723191e0..1a14dd2c666b5 100644
--- a/clang/include/clang/AST/RecursiveASTVisitor.h
+++ b/clang/include/clang/AST/RecursiveASTVisitor.h
@@ -3202,6 +3202,9 @@ DEF_TRAVERSE_STMT(OMPFuseDirective,
DEF_TRAVERSE_STMT(OMPInterchangeDirective,
{ TRY_TO(TraverseOMPExecutableDirective(S)); })
+// '#pragma omp split' is traversed through the generic executable-directive
+// path, like the neighbouring loop transformation directives.
+DEF_TRAVERSE_STMT(OMPSplitDirective,
+ { TRY_TO(TraverseOMPExecutableDirective(S)); })
+
DEF_TRAVERSE_STMT(OMPForDirective,
{ TRY_TO(TraverseOMPExecutableDirective(S)); })
@@ -3503,6 +3506,13 @@ bool RecursiveASTVisitor<Derived>::VisitOMPSizesClause(OMPSizesClause *C) {
return true;
}
+template <typename Derived>
+bool RecursiveASTVisitor<Derived>::VisitOMPCountsClause(OMPCountsClause *C) {
+ // Traverse every count slot in list order, including the slot at the
+ // omp_fill index (its contents are not visible from here).
+ for (Expr *E : C->getCountsRefs())
+ TRY_TO(TraverseStmt(E));
+ return true;
+}
+
template <typename Derived>
bool RecursiveASTVisitor<Derived>::VisitOMPPermutationClause(
OMPPermutationClause *C) {
diff --git a/clang/include/clang/AST/StmtOpenMP.h b/clang/include/clang/AST/StmtOpenMP.h
index bc6aeaa8d143c..dbc76e7df8ecd 100644
--- a/clang/include/clang/AST/StmtOpenMP.h
+++ b/clang/include/clang/AST/StmtOpenMP.h
@@ -6065,6 +6065,84 @@ class OMPFuseDirective final
}
};
+/// Represents the '#pragma omp split' loop transformation directive.
+///
+/// \code{.c}
+/// #pragma omp split counts(3, omp_fill, 2)
+/// for (int i = 0; i < n; ++i)
+/// ...
+/// \endcode
+///
+/// This directive transforms a single loop into multiple loops based on
+/// index ranges. The transformation splits the iteration space of the loop
+/// into multiple contiguous ranges. The \c counts clause is required and
+/// exactly one list item must be \c omp_fill.
+class OMPSplitDirective final
+ : public OMPCanonicalLoopNestTransformationDirective {
+ friend class ASTStmtReader;
+ friend class OMPExecutableDirective;
+
+ /// Offsets of child members. Create/CreateEmpty allocate
+ /// TransformedStmtOffset + 1 child slots, one for each entry below.
+ enum {
+ PreInitsOffset = 0,
+ TransformedStmtOffset,
+ };
+
+ /// Build a split directive with the given source range, associated with
+ /// \p NumLoops loops. Private; instances are created via Create/CreateEmpty.
+ explicit OMPSplitDirective(SourceLocation StartLoc, SourceLocation EndLoc,
+ unsigned NumLoops)
+ : OMPCanonicalLoopNestTransformationDirective(
+ OMPSplitDirectiveClass, llvm::omp::OMPD_split, StartLoc, EndLoc,
+ NumLoops) {}
+
+ /// Stores the helper pre-init statements in the PreInits child slot.
+ void setPreInits(Stmt *PreInits) {
+ Data->getChildren()[PreInitsOffset] = PreInits;
+ }
+
+ /// Stores the de-sugared loop nest in the TransformedStmt child slot.
+ void setTransformedStmt(Stmt *S) {
+ Data->getChildren()[TransformedStmtOffset] = S;
+ }
+
+public:
+ /// Create a new AST node representation for '#pragma omp split'.
+ ///
+ /// \param C Context of the AST.
+ /// \param StartLoc Location of the introducer (e.g. the 'omp' token).
+ /// \param EndLoc Location of the directive's end (e.g. the tok::eod).
+ /// \param Clauses The directive's clauses (e.g. the required \c counts
+ /// clause).
+ /// \param NumLoops Number of affected loops (should be 1 for split).
+ /// \param AssociatedStmt The outermost associated loop.
+ /// \param TransformedStmt The loop nest after splitting, or nullptr in
+ /// dependent contexts.
+ /// \param PreInits Helper preinits statements for the loop nest.
+ static OMPSplitDirective *Create(const ASTContext &C, SourceLocation StartLoc,
+ SourceLocation EndLoc,
+ ArrayRef<OMPClause *> Clauses,
+ unsigned NumLoops, Stmt *AssociatedStmt,
+ Stmt *TransformedStmt, Stmt *PreInits);
+
+ /// Build an empty '#pragma omp split' AST node for deserialization.
+ ///
+ /// \param C Context of the AST.
+ /// \param NumClauses Number of clauses to allocate.
+ /// \param NumLoops Number of associated loops to allocate.
+ static OMPSplitDirective *CreateEmpty(const ASTContext &C,
+ unsigned NumClauses, unsigned NumLoops);
+
+ /// Returns the associated loops after the transformation, i.e. after
+ /// de-sugaring, or nullptr when none was provided (e.g. in dependent
+ /// contexts).
+ Stmt *getTransformedStmt() const {
+ return Data->getChildren()[TransformedStmtOffset];
+ }
+
+ /// Returns the helper pre-init statements that were passed to Create.
+ Stmt *getPreInits() const { return Data->getChildren()[PreInitsOffset]; }
+
+ static bool classof(const Stmt *T) {
+ return T->getStmtClass() == OMPSplitDirectiveClass;
+ }
+};
+
/// This represents '#pragma omp scan' directive.
///
/// \code
diff --git a/clang/include/clang/ASTMatchers/ASTMatchers.h b/clang/include/clang/ASTMatchers/ASTMatchers.h
index e8e7643e0dddd..87b6dbefa7a62 100644
--- a/clang/include/clang/ASTMatchers/ASTMatchers.h
+++ b/clang/include/clang/ASTMatchers/ASTMatchers.h
@@ -8781,6 +8781,26 @@ extern const internal::VariadicDynCastAllOfMatcher<Stmt,
OMPTargetUpdateDirective>
ompTargetUpdateDirective;
+/// Matches any ``#pragma omp split`` executable directive.
+///
+/// Given
+///
+/// \code
+/// #pragma omp split counts(2, omp_fill)
+/// for (int i = 0; i < n; ++i) {}
+/// \endcode
+///
+/// ``ompSplitDirective()`` matches the split directive.
+extern const internal::VariadicDynCastAllOfMatcher<Stmt, OMPSplitDirective>
+ ompSplitDirective;
+
+/// Matches OpenMP ``counts`` clause used by ``#pragma omp split``.
+///
+/// Given ``#pragma omp split counts(1, 2, omp_fill)``, ``ompCountsClause()``
+/// matches the ``counts`` clause node.
+extern const internal::VariadicDynCastAllOfMatcher<OMPClause, OMPCountsClause>
+ ompCountsClause;
+
/// Matches OpenMP ``default`` clause.
///
/// Given
diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index d5904bd1d6f26..71d504c659cc2 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -11176,6 +11176,8 @@ def err_omp_bind_required_on_loop : Error<
"construct">;
def err_omp_loop_reduction_clause : Error<
"'reduction' clause not allowed with '#pragma omp loop bind(teams)'">;
+def err_omp_split_counts_not_one_omp_fill : Error<
+ "exactly one 'omp_fill' must appear in the 'counts' clause">;
def warn_break_binds_to_switch : Warning<
"'break' is bound to loop, GCC binds it to switch">,
InGroup<GccCompat>;
diff --git a/clang/include/clang/Basic/StmtNodes.td b/clang/include/clang/Basic/StmtNodes.td
index 61d76bafdfcde..e166894ea024b 100644
--- a/clang/include/clang/Basic/StmtNodes.td
+++ b/clang/include/clang/Basic/StmtNodes.td
@@ -244,6 +244,7 @@ def OMPTileDirective : StmtNode<OMPCanonicalLoopNestTransformationDirective>;
def OMPStripeDirective : StmtNode<OMPCanonicalLoopNestTransformationDirective>;
def OMPUnrollDirective : StmtNode<OMPCanonicalLoopNestTransformationDirective>;
def OMPReverseDirective : StmtNode<OMPCanonicalLoopNestTransformationDirective>;
+def OMPSplitDirective : StmtNode<OMPCanonicalLoopNestTransformationDirective>;
def OMPInterchangeDirective
: StmtNode<OMPCanonicalLoopNestTransformationDirective>;
def OMPCanonicalLoopSequenceTransformationDirective
diff --git a/clang/include/clang/Parse/Parser.h b/clang/include/clang/Parse/Parser.h
index 08a3d88ee6a36..bd313d37cc4b5 100644
--- a/clang/include/clang/Parse/Parser.h
+++ b/clang/include/clang/Parse/Parser.h
@@ -6812,6 +6812,9 @@ class Parser : public CodeCompletionHandler {
/// Parses the 'sizes' clause of a '#pragma omp tile' directive.
OMPClause *ParseOpenMPSizesClause();
+ /// Parses the 'counts' clause of a '#pragma omp split' directive.
+ OMPClause *ParseOpenMPCountsClause();
+
/// Parses the 'permutation' clause of a '#pragma omp interchange' directive.
OMPClause *ParseOpenMPPermutationClause();
diff --git a/clang/include/clang/Sema/SemaOpenMP.h b/clang/include/clang/Sema/SemaOpenMP.h
index 7853f29f98c25..3621ce96b8724 100644
--- a/clang/include/clang/Sema/SemaOpenMP.h
+++ b/clang/include/clang/Sema/SemaOpenMP.h
@@ -42,6 +42,7 @@ class FunctionScopeInfo;
class DeclContext;
class DeclGroupRef;
+class EnumConstantDecl;
class ParsedAttr;
class Scope;
@@ -457,6 +458,11 @@ class SemaOpenMP : public SemaBase {
/// Called on well-formed '#pragma omp reverse'.
StmtResult ActOnOpenMPReverseDirective(Stmt *AStmt, SourceLocation StartLoc,
SourceLocation EndLoc);
+ /// Called on well-formed '#pragma omp split' after parsing of its
+ /// associated statement.
+ StmtResult ActOnOpenMPSplitDirective(ArrayRef<OMPClause *> Clauses,
+ Stmt *AStmt, SourceLocation StartLoc,
+ SourceLocation EndLoc);
/// Called on well-formed '#pragma omp interchange' after parsing of its
/// clauses and the associated statement.
StmtResult ActOnOpenMPInterchangeDirective(ArrayRef<OMPClause *> Clauses,
@@ -911,6 +917,12 @@ class SemaOpenMP : public SemaBase {
SourceLocation StartLoc,
SourceLocation LParenLoc,
SourceLocation EndLoc);
+ /// Called on well-formed 'counts' clause after parsing its arguments.
+ OMPClause *
+ ActOnOpenMPCountsClause(ArrayRef<Expr *> CountExprs, SourceLocation StartLoc,
+ SourceLocation LParenLoc, SourceLocation EndLoc,
+ std::optional<unsigned> FillIdx,
+ SourceLocation FillLoc, unsigned FillCount);
/// Called on well-form 'permutation' clause after parsing its arguments.
OMPClause *ActOnOpenMPPermutationClause(ArrayRef<Expr *> PermExprs,
SourceLocation StartLoc,
diff --git a/clang/include/clang/Serialization/ASTBitCodes.h b/clang/include/clang/Serialization/ASTBitCodes.h
index 783cd82895a90..9b798ed484454 100644
--- a/clang/include/clang/Serialization/ASTBitCodes.h
+++ b/clang/include/clang/Serialization/ASTBitCodes.h
@@ -1965,6 +1965,7 @@ enum StmtCode {
STMP_OMP_STRIPE_DIRECTIVE,
STMT_OMP_UNROLL_DIRECTIVE,
STMT_OMP_REVERSE_DIRECTIVE,
+ STMT_OMP_SPLIT_DIRECTIVE,
STMT_OMP_INTERCHANGE_DIRECTIVE,
STMT_OMP_FUSE_DIRECTIVE,
STMT_OMP_FOR_DIRECTIVE,
diff --git a/clang/lib/AST/OpenMPClause.cpp b/clang/lib/AST/OpenMPClause.cpp
index d4826c3c6edca..3a35e17aff40b 100644
--- a/clang/lib/AST/OpenMPClause.cpp
+++ b/clang/lib/AST/OpenMPClause.cpp
@@ -15,10 +15,12 @@
#include "clang/AST/Attr.h"
#include "clang/AST/Decl.h"
#include "clang/AST/DeclOpenMP.h"
+#include "clang/AST/Expr.h"
#include "clang/AST/ExprOpenMP.h"
#include "clang/Basic/LLVM.h"
#include "clang/Basic/OpenMPKinds.h"
#include "clang/Basic/TargetInfo.h"
+#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/ErrorHandling.h"
#include <algorithm>
@@ -986,6 +988,26 @@ OMPSizesClause *OMPSizesClause::CreateEmpty(const ASTContext &C,
return new (Mem) OMPSizesClause(NumSizes);
}
+// Allocates a clause with one trailing slot per entry of Counts, then records
+// the source locations, the list items and the omp_fill bookkeeping.
+// NOTE(review): FillIdx is stored as given and is not range-checked against
+// Counts.size() here; presumably Sema has already validated it — confirm.
+OMPCountsClause *OMPCountsClause::Create(
+ const ASTContext &C, SourceLocation StartLoc, SourceLocation LParenLoc,
+ SourceLocation EndLoc, ArrayRef<Expr *> Counts,
+ std::optional<unsigned> FillIdx, SourceLocation FillLoc) {
+ OMPCountsClause *Clause = CreateEmpty(C, Counts.size());
+ Clause->setLocStart(StartLoc);
+ Clause->setLParenLoc(LParenLoc);
+ Clause->setLocEnd(EndLoc);
+ Clause->setCountsRefs(Counts);
+ Clause->setOmpFillIndex(FillIdx);
+ Clause->setOmpFillLoc(FillLoc);
+ return Clause;
+}
+
+// Allocates the clause object plus NumCounts trailing Expr* slots in the AST
+// context. The trailing slots are not initialized by this function; callers
+// (Create, or the deserializer) are expected to fill them in.
+OMPCountsClause *OMPCountsClause::CreateEmpty(const ASTContext &C,
+ unsigned NumCounts) {
+ void *Mem = C.Allocate(totalSizeToAlloc<Expr *>(NumCounts));
+ return new (Mem) OMPCountsClause(NumCounts);
+}
+
OMPPermutationClause *OMPPermutationClause::Create(const ASTContext &C,
SourceLocation StartLoc,
SourceLocation LParenLoc,
@@ -1984,6 +2006,19 @@ void OMPClausePrinter::VisitOMPSizesClause(OMPSizesClause *Node) {
OS << ")";
}
+// Prints "counts(" followed by the comma-separated list items and ")". The
+// item at the recorded omp_fill index is printed as the keyword "omp_fill"
+// instead of from its stored expression slot.
+void OMPClausePrinter::VisitOMPCountsClause(OMPCountsClause *Node) {
+ OS << "counts(";
+ std::optional<unsigned> FillIdx = Node->getOmpFillIndex();
+ ArrayRef<Expr *> Refs = Node->getCountsRefs();
+ // Iterate over indices rather than expressions so the omp_fill position can
+ // be recognized; the slot at that index is never dereferenced.
+ // NOTE(review): every other slot is dereferenced unconditionally, so this
+ // assumes all non-fill count expressions are non-null.
+ llvm::interleaveComma(llvm::seq<unsigned>(Refs.size()), OS, [&](unsigned I) {
+ if (FillIdx && I == *FillIdx)
+ OS << "omp_fill";
+ else
+ Refs[I]->printPretty(OS, nullptr, Policy, 0);
+ });
+ OS << ")";
+}
+
void OMPClausePrinter::VisitOMPPermutationClause(OMPPermutationClause *Node) {
OS << "permutation(";
llvm::interleaveComma(Node->getArgsRefs(), OS, [&](const Expr *E) {
diff --git a/clang/lib/AST/StmtOpenMP.cpp b/clang/lib/AST/StmtOpenMP.cpp
index a5b0cd3786a28..9d6b315effb41 100644
--- a/clang/lib/AST/StmtOpenMP.cpp
+++ b/clang/lib/AST/StmtOpenMP.cpp
@@ -552,6 +552,27 @@ OMPInterchangeDirective::CreateEmpty(const ASTContext &C, unsigned NumClauses,
SourceLocation(), SourceLocation(), NumLoops);
}
+// Both factories reserve TransformedStmtOffset + 1 child slots (PreInits and
+// TransformedStmt), matching the child offsets declared in OMPSplitDirective.
+OMPSplitDirective *
+OMPSplitDirective::Create(const ASTContext &C, SourceLocation StartLoc,
+ SourceLocation EndLoc, ArrayRef<OMPClause *> Clauses,
+ unsigned NumLoops, Stmt *AssociatedStmt,
+ Stmt *TransformedStmt, Stmt *PreInits) {
+ OMPSplitDirective *Dir = createDirective<OMPSplitDirective>(
+ C, Clauses, AssociatedStmt, TransformedStmtOffset + 1, StartLoc, EndLoc,
+ NumLoops);
+ Dir->setTransformedStmt(TransformedStmt);
+ Dir->setPreInits(PreInits);
+ return Dir;
+}
+
+// The empty node always reserves an associated-statement slot; its clauses,
+// associated statement and children are presumably populated afterwards by
+// ASTStmtReader (a friend of OMPSplitDirective).
+OMPSplitDirective *OMPSplitDirective::CreateEmpty(const ASTContext &C,
+ unsigned NumClauses,
+ unsigned NumLoops) {
+ return createEmptyDirective<OMPSplitDirective>(
+ C, NumClauses, /*HasAssociatedStmt=*/true, TransformedStmtOffset + 1,
+ SourceLocation(), SourceLocation(), NumLoops);
+}
+
OMPFuseDirective *OMPFuseDirective::Create(
const ASTContext &C, SourceLocation StartLoc, SourceLocation EndLoc,
ArrayRef<OMPClause *> Clauses, unsigned NumGeneratedTopLevelLoops,
diff --git a/clang/lib/AST/StmtPrinter.cpp b/clang/lib/AST/StmtPrinter.cpp
index 4d364fdcd5502..e0b930ba0a21a 100644
--- a/clang/lib/AST/StmtPrinter.cpp
+++ b/clang/lib/AST/StmtPrinter.cpp
@@ -800,6 +800,11 @@ void StmtPrinter::VisitOMPInterchangeDirective(OMPInterchangeDirective *Node) {
PrintOMPExecutableDirective(Node);
}
+// Prints the directive name; the remainder (presumably the clauses and the
+// associated statement) is emitted by PrintOMPExecutableDirective, the same
+// way as for the interchange and fuse directives.
+void StmtPrinter::VisitOMPSplitDirective(OMPSplitDirective *Node) {
+ Indent() << "#pragma omp split";
+ PrintOMPExecutableDirective(Node);
+}
+
void StmtPrinter::VisitOMPFuseDirective(OMPFuseDirective *Node) {
Indent() << "#pragma omp fuse";
PrintOMPExecutableDirective(Node);
diff --git a/clang/lib/AST/StmtProfile.cpp b/clang/lib/AST/StmtProfile.cpp
index e8c1f8a8ecb5f..c75652e5c1dd3 100644
--- a/clang/lib/AST/StmtProfile.cpp
+++ b/clang/lib/AST/StmtProfile.cpp
@@ -498,6 +498,12 @@ void OMPClauseProfiler::VisitOMPSizesClause(const OMPSizesClause *C) {
Profiler->VisitExpr(E);
}
+void OMPClauseProfiler::VisitOMPCountsClause(const OMPCountsClause *C) {
+ for (auto *E : C->getCountsRefs())
+ if ...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/190397
More information about the llvm-commits
mailing list