[clang] 2130117 - [Clang][OpenMP] Allow loop-transformations with template parameters.
Michael Kruse via cfe-commits
cfe-commits at lists.llvm.org
Wed Oct 6 10:21:12 PDT 2021
Author: Michael Kruse
Date: 2021-10-06T12:21:04-05:00
New Revision: 2130117f92e51df73ac8c4b7e37f7f89178a89f2
URL: https://github.com/llvm/llvm-project/commit/2130117f92e51df73ac8c4b7e37f7f89178a89f2
DIFF: https://github.com/llvm/llvm-project/commit/2130117f92e51df73ac8c4b7e37f7f89178a89f2.diff
LOG: [Clang][OpenMP] Allow loop-transformations with template parameters.
Clang would reject
#pragma omp for
#pragma omp tile sizes(P)
for (int i = 0; i < 128; ++i) {}
where P is a template parameter, but the loop itself is not
template-dependent. Because P is context-dependent, the TransformedStmt
cannot be generated and therefore is nullptr (until the template is
instantiated by TreeTransform). The OMPForDirective would still expect
to find an associated loop and would therefore trigger an error.
Fix by introducing a NumGeneratedLoops field to
OMPLoopTransformationDirective. It is used to distinguish the case where
no TransformedStmt will be generated at all (e.g. #pragma omp unroll full)
from the case where template instantiation is needed. In the latter case,
resolving the iteration space is delayed until template instantiation,
just as when the for-loop itself is template-dependent.
A more radical solution would always delay the iteration space analysis
until template instantiation, but would also break many test cases.
Reviewed By: ABataev
Differential Revision: https://reviews.llvm.org/D111124
Added:
Modified:
clang/include/clang/AST/StmtOpenMP.h
clang/lib/AST/StmtOpenMP.cpp
clang/lib/Sema/SemaOpenMP.cpp
clang/lib/Serialization/ASTReaderStmt.cpp
clang/lib/Serialization/ASTWriterStmt.cpp
clang/test/OpenMP/tile_ast_print.cpp
clang/test/OpenMP/unroll_ast_print.cpp
Removed:
################################################################################
diff --git a/clang/include/clang/AST/StmtOpenMP.h b/clang/include/clang/AST/StmtOpenMP.h
index 285426d26e21..60d47b93ba79 100644
--- a/clang/include/clang/AST/StmtOpenMP.h
+++ b/clang/include/clang/AST/StmtOpenMP.h
@@ -959,6 +959,9 @@ class OMPLoopBasedDirective : public OMPExecutableDirective {
class OMPLoopTransformationDirective : public OMPLoopBasedDirective {
friend class ASTStmtReader;
+ /// Number of loops generated by this loop transformation.
+ unsigned NumGeneratedLoops = 0;
+
protected:
explicit OMPLoopTransformationDirective(StmtClass SC,
OpenMPDirectiveKind Kind,
@@ -967,10 +970,16 @@ class OMPLoopTransformationDirective : public OMPLoopBasedDirective {
unsigned NumAssociatedLoops)
: OMPLoopBasedDirective(SC, Kind, StartLoc, EndLoc, NumAssociatedLoops) {}
+ /// Set the number of loops generated by this loop transformation.
+ void setNumGeneratedLoops(unsigned Num) { NumGeneratedLoops = Num; }
+
public:
/// Return the number of associated (consumed) loops.
unsigned getNumAssociatedLoops() const { return getLoopsNumber(); }
+ /// Return the number of loops generated by this loop transformation.
+ unsigned getNumGeneratedLoops() { return NumGeneratedLoops; }
+
/// Get the de-sugared statements after after the loop transformation.
///
/// Might be nullptr if either the directive generates no loops and is handled
@@ -5058,7 +5067,9 @@ class OMPTileDirective final : public OMPLoopTransformationDirective {
unsigned NumLoops)
: OMPLoopTransformationDirective(OMPTileDirectiveClass,
llvm::omp::OMPD_tile, StartLoc, EndLoc,
- NumLoops) {}
+ NumLoops) {
+ setNumGeneratedLoops(3 * NumLoops);
+ }
void setPreInits(Stmt *PreInits) {
Data->getChildren()[PreInitsOffset] = PreInits;
@@ -5163,7 +5174,7 @@ class OMPUnrollDirective final : public OMPLoopTransformationDirective {
static OMPUnrollDirective *
Create(const ASTContext &C, SourceLocation StartLoc, SourceLocation EndLoc,
ArrayRef<OMPClause *> Clauses, Stmt *AssociatedStmt,
- Stmt *TransformedStmt, Stmt *PreInits);
+ unsigned NumGeneratedLoops, Stmt *TransformedStmt, Stmt *PreInits);
/// Build an empty '#pragma omp unroll' AST node for deserialization.
///
diff --git a/clang/lib/AST/StmtOpenMP.cpp b/clang/lib/AST/StmtOpenMP.cpp
index c615463f42da..014274f46cae 100644
--- a/clang/lib/AST/StmtOpenMP.cpp
+++ b/clang/lib/AST/StmtOpenMP.cpp
@@ -138,9 +138,18 @@ bool OMPLoopBasedDirective::doForAllLoops(
Stmt *TransformedStmt = Dir->getTransformedStmt();
if (!TransformedStmt) {
- // May happen if the loop transformation does not result in a
- // generated loop (such as full unrolling).
- break;
+ unsigned NumGeneratedLoops = Dir->getNumGeneratedLoops();
+ if (NumGeneratedLoops == 0) {
+ // May happen if the loop transformation does not result in a
+ // generated loop (such as full unrolling).
+ break;
+ }
+ if (NumGeneratedLoops > 0) {
+ // The loop transformation construct has generated loops, but these
+ // may not have been generated yet due to being in a dependent
+ // context.
+ return true;
+ }
}
CurStmt = TransformedStmt;
@@ -419,10 +428,13 @@ OMPTileDirective *OMPTileDirective::CreateEmpty(const ASTContext &C,
OMPUnrollDirective *
OMPUnrollDirective::Create(const ASTContext &C, SourceLocation StartLoc,
SourceLocation EndLoc, ArrayRef<OMPClause *> Clauses,
- Stmt *AssociatedStmt, Stmt *TransformedStmt,
- Stmt *PreInits) {
+ Stmt *AssociatedStmt, unsigned NumGeneratedLoops,
+ Stmt *TransformedStmt, Stmt *PreInits) {
+ assert(NumGeneratedLoops <= 1 && "Unrolling generates at most one loop");
+
auto *Dir = createDirective<OMPUnrollDirective>(
C, Clauses, AssociatedStmt, TransformedStmtOffset + 1, StartLoc, EndLoc);
+ Dir->setNumGeneratedLoops(NumGeneratedLoops);
Dir->setTransformedStmt(TransformedStmt);
Dir->setPreInits(PreInits);
return Dir;
diff --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp
index af70a180b27c..850b6f162d72 100644
--- a/clang/lib/Sema/SemaOpenMP.cpp
+++ b/clang/lib/Sema/SemaOpenMP.cpp
@@ -12919,10 +12919,12 @@ StmtResult Sema::ActOnOpenMPUnrollDirective(ArrayRef<OMPClause *> Clauses,
Body, OriginalInits))
return StmtError();
+ unsigned NumGeneratedLoops = PartialClause ? 1 : 0;
+
// Delay unrolling to when template is completely instantiated.
if (CurContext->isDependentContext())
return OMPUnrollDirective::Create(Context, StartLoc, EndLoc, Clauses, AStmt,
- nullptr, nullptr);
+ NumGeneratedLoops, nullptr, nullptr);
OMPLoopBasedDirective::HelperExprs &LoopHelper = LoopHelpers.front();
@@ -12941,9 +12943,9 @@ StmtResult Sema::ActOnOpenMPUnrollDirective(ArrayRef<OMPClause *> Clauses,
// The generated loop may only be passed to other loop-associated directive
// when a partial clause is specified. Without the requirement it is
// sufficient to generate loop unroll metadata at code-generation.
- if (!PartialClause)
+ if (NumGeneratedLoops == 0)
return OMPUnrollDirective::Create(Context, StartLoc, EndLoc, Clauses, AStmt,
- nullptr, nullptr);
+ NumGeneratedLoops, nullptr, nullptr);
// Otherwise, we need to provide a de-sugared/transformed AST that can be
// associated with another loop directive.
@@ -13164,7 +13166,8 @@ StmtResult Sema::ActOnOpenMPUnrollDirective(ArrayRef<OMPClause *> Clauses,
LoopHelper.Init->getBeginLoc(), LoopHelper.Inc->getEndLoc());
return OMPUnrollDirective::Create(Context, StartLoc, EndLoc, Clauses, AStmt,
- OuterFor, buildPreInits(Context, PreInits));
+ NumGeneratedLoops, OuterFor,
+ buildPreInits(Context, PreInits));
}
OMPClause *Sema::ActOnOpenMPSingleExprClause(OpenMPClauseKind Kind, Expr *Expr,
diff --git a/clang/lib/Serialization/ASTReaderStmt.cpp b/clang/lib/Serialization/ASTReaderStmt.cpp
index 34a58831e0d4..4e6eaf77ff56 100644
--- a/clang/lib/Serialization/ASTReaderStmt.cpp
+++ b/clang/lib/Serialization/ASTReaderStmt.cpp
@@ -2327,6 +2327,7 @@ void ASTStmtReader::VisitOMPSimdDirective(OMPSimdDirective *D) {
void ASTStmtReader::VisitOMPLoopTransformationDirective(
OMPLoopTransformationDirective *D) {
VisitOMPLoopBasedDirective(D);
+ D->setNumGeneratedLoops(Record.readUInt32());
}
void ASTStmtReader::VisitOMPTileDirective(OMPTileDirective *D) {
diff --git a/clang/lib/Serialization/ASTWriterStmt.cpp b/clang/lib/Serialization/ASTWriterStmt.cpp
index bf32294bc95f..000bf808d32b 100644
--- a/clang/lib/Serialization/ASTWriterStmt.cpp
+++ b/clang/lib/Serialization/ASTWriterStmt.cpp
@@ -2226,6 +2226,7 @@ void ASTStmtWriter::VisitOMPSimdDirective(OMPSimdDirective *D) {
void ASTStmtWriter::VisitOMPLoopTransformationDirective(
OMPLoopTransformationDirective *D) {
VisitOMPLoopBasedDirective(D);
+ Record.writeUInt32(D->getNumGeneratedLoops());
}
void ASTStmtWriter::VisitOMPTileDirective(OMPTileDirective *D) {
diff --git a/clang/test/OpenMP/tile_ast_print.cpp b/clang/test/OpenMP/tile_ast_print.cpp
index 37791f0a8475..14f064358d8d 100644
--- a/clang/test/OpenMP/tile_ast_print.cpp
+++ b/clang/test/OpenMP/tile_ast_print.cpp
@@ -162,4 +162,25 @@ void tfoo6() {
}
+// PRINT-LABEL: template <int Tile> void foo7(int start, int stop, int step) {
+// DUMP-LABEL: FunctionTemplateDecl {{.*}} foo7
+template <int Tile>
+void foo7(int start, int stop, int step) {
+ // PRINT: #pragma omp tile sizes(Tile)
+ // DUMP: OMPTileDirective
+ // DUMP-NEXT: OMPSizesClause
+ // DUMP-NEXT: DeclRefExpr {{.*}} 'Tile' 'int'
+ #pragma omp tile sizes(Tile)
+ // PRINT-NEXT: for (int i = start; i < stop; i += step)
+ // DUMP-NEXT: ForStmt
+ for (int i = start; i < stop; i += step)
+ // PRINT-NEXT: body(i);
+ // DUMP: CallExpr
+ body(i);
+}
+void tfoo7() {
+ foo7<5>(0, 42, 2);
+}
+
+
#endif
diff --git a/clang/test/OpenMP/unroll_ast_print.cpp b/clang/test/OpenMP/unroll_ast_print.cpp
index 63e7b1dbe6ed..4d858284877f 100644
--- a/clang/test/OpenMP/unroll_ast_print.cpp
+++ b/clang/test/OpenMP/unroll_ast_print.cpp
@@ -124,4 +124,26 @@ void unroll_template() {
unroll_templated<int,0,1024,1,4>();
}
+
+// PRINT-LABEL: template <int Factor> void unroll_templated_factor(int start, int stop, int step) {
+// DUMP-LABEL: FunctionTemplateDecl {{.*}} unroll_templated_factor
+template <int Factor>
+void unroll_templated_factor(int start, int stop, int step) {
+ // PRINT: #pragma omp unroll partial(Factor)
+ // DUMP: OMPUnrollDirective
+ // DUMP-NEXT: OMPPartialClause
+ // DUMP-NEXT: DeclRefExpr {{.*}} 'Factor' 'int'
+ #pragma omp unroll partial(Factor)
+ // PRINT-NEXT: for (int i = start; i < stop; i += step)
+ // DUMP-NEXT: ForStmt
+ for (int i = start; i < stop; i += step)
+ // PRINT-NEXT: body(i);
+ // DUMP: CallExpr
+ body(i);
+}
+void unroll_template_factor() {
+ unroll_templated_factor<4>(0, 42, 2);
+}
+
+
#endif
More information about the cfe-commits
mailing list