[clang] 2130117 - [Clang][OpenMP] Allow loop-transformations with template parameters.

Michael Kruse via cfe-commits cfe-commits at lists.llvm.org
Wed Oct 6 10:21:12 PDT 2021


Author: Michael Kruse
Date: 2021-10-06T12:21:04-05:00
New Revision: 2130117f92e51df73ac8c4b7e37f7f89178a89f2

URL: https://github.com/llvm/llvm-project/commit/2130117f92e51df73ac8c4b7e37f7f89178a89f2
DIFF: https://github.com/llvm/llvm-project/commit/2130117f92e51df73ac8c4b7e37f7f89178a89f2.diff

LOG: [Clang][OpenMP] Allow loop-transformations with template parameters.

Clang would reject

    #pragma omp for
    #pragma omp tile sizes(P)
    for (int i = 0; i < 128; ++i) {}

where P is a template parameter, but the loop itself is not
template-dependent. Because P is context-dependent, the TransformedStmt
cannot be generated and therefore is nullptr (until the template is
instantiated by TreeTransform). The OMPForDirective would still expect
a loop in the dependent context and trigger an error.

Fix by introducing a NumGeneratedLoops field to
OMPLoopTransformationDirective. This is used to distinguish the case
where no TransformedStmt will be generated at all (e.g. #pragma omp
unroll full) from the case where template instantiation is needed. In
the latter case, delay resolving the iteration space until template
instantiation, like when the for-loop itself is template-dependent.

A more radical solution would always delay the iteration space analysis
until template instantiation, but would also break many test cases.

Reviewed By: ABataev

Differential Revision: https://reviews.llvm.org/D111124

Added: 
    

Modified: 
    clang/include/clang/AST/StmtOpenMP.h
    clang/lib/AST/StmtOpenMP.cpp
    clang/lib/Sema/SemaOpenMP.cpp
    clang/lib/Serialization/ASTReaderStmt.cpp
    clang/lib/Serialization/ASTWriterStmt.cpp
    clang/test/OpenMP/tile_ast_print.cpp
    clang/test/OpenMP/unroll_ast_print.cpp

Removed: 
    


################################################################################
diff  --git a/clang/include/clang/AST/StmtOpenMP.h b/clang/include/clang/AST/StmtOpenMP.h
index 285426d26e21..60d47b93ba79 100644
--- a/clang/include/clang/AST/StmtOpenMP.h
+++ b/clang/include/clang/AST/StmtOpenMP.h
@@ -959,6 +959,9 @@ class OMPLoopBasedDirective : public OMPExecutableDirective {
 class OMPLoopTransformationDirective : public OMPLoopBasedDirective {
   friend class ASTStmtReader;
 
+  /// Number of loops generated by this loop transformation.
+  unsigned NumGeneratedLoops = 0;
+
 protected:
   explicit OMPLoopTransformationDirective(StmtClass SC,
                                           OpenMPDirectiveKind Kind,
@@ -967,10 +970,16 @@ class OMPLoopTransformationDirective : public OMPLoopBasedDirective {
                                           unsigned NumAssociatedLoops)
       : OMPLoopBasedDirective(SC, Kind, StartLoc, EndLoc, NumAssociatedLoops) {}
 
+  /// Set the number of loops generated by this loop transformation.
+  void setNumGeneratedLoops(unsigned Num) { NumGeneratedLoops = Num; }
+
 public:
   /// Return the number of associated (consumed) loops.
   unsigned getNumAssociatedLoops() const { return getLoopsNumber(); }
 
+  /// Return the number of loops generated by this loop transformation.
+  unsigned getNumGeneratedLoops() { return NumGeneratedLoops; }
+
   /// Get the de-sugared statements after after the loop transformation.
   ///
   /// Might be nullptr if either the directive generates no loops and is handled
@@ -5058,7 +5067,9 @@ class OMPTileDirective final : public OMPLoopTransformationDirective {
                             unsigned NumLoops)
       : OMPLoopTransformationDirective(OMPTileDirectiveClass,
                                        llvm::omp::OMPD_tile, StartLoc, EndLoc,
-                                       NumLoops) {}
+                                       NumLoops) {
+    setNumGeneratedLoops(3 * NumLoops);
+  }
 
   void setPreInits(Stmt *PreInits) {
     Data->getChildren()[PreInitsOffset] = PreInits;
@@ -5163,7 +5174,7 @@ class OMPUnrollDirective final : public OMPLoopTransformationDirective {
   static OMPUnrollDirective *
   Create(const ASTContext &C, SourceLocation StartLoc, SourceLocation EndLoc,
          ArrayRef<OMPClause *> Clauses, Stmt *AssociatedStmt,
-         Stmt *TransformedStmt, Stmt *PreInits);
+         unsigned NumGeneratedLoops, Stmt *TransformedStmt, Stmt *PreInits);
 
   /// Build an empty '#pragma omp unroll' AST node for deserialization.
   ///

diff  --git a/clang/lib/AST/StmtOpenMP.cpp b/clang/lib/AST/StmtOpenMP.cpp
index c615463f42da..014274f46cae 100644
--- a/clang/lib/AST/StmtOpenMP.cpp
+++ b/clang/lib/AST/StmtOpenMP.cpp
@@ -138,9 +138,18 @@ bool OMPLoopBasedDirective::doForAllLoops(
 
       Stmt *TransformedStmt = Dir->getTransformedStmt();
       if (!TransformedStmt) {
-        // May happen if the loop transformation does not result in a
-        // generated loop (such as full unrolling).
-        break;
+        unsigned NumGeneratedLoops = Dir->getNumGeneratedLoops();
+        if (NumGeneratedLoops == 0) {
+          // May happen if the loop transformation does not result in a
+          // generated loop (such as full unrolling).
+          break;
+        }
+        if (NumGeneratedLoops > 0) {
+          // The loop transformation construct has generated loops, but these
+          // may not have been generated yet due to being in a dependent
+          // context.
+          return true;
+        }
       }
 
       CurStmt = TransformedStmt;
@@ -419,10 +428,13 @@ OMPTileDirective *OMPTileDirective::CreateEmpty(const ASTContext &C,
 OMPUnrollDirective *
 OMPUnrollDirective::Create(const ASTContext &C, SourceLocation StartLoc,
                            SourceLocation EndLoc, ArrayRef<OMPClause *> Clauses,
-                           Stmt *AssociatedStmt, Stmt *TransformedStmt,
-                           Stmt *PreInits) {
+                           Stmt *AssociatedStmt, unsigned NumGeneratedLoops,
+                           Stmt *TransformedStmt, Stmt *PreInits) {
+  assert(NumGeneratedLoops <= 1 && "Unrolling generates at most one loop");
+
   auto *Dir = createDirective<OMPUnrollDirective>(
       C, Clauses, AssociatedStmt, TransformedStmtOffset + 1, StartLoc, EndLoc);
+  Dir->setNumGeneratedLoops(NumGeneratedLoops);
   Dir->setTransformedStmt(TransformedStmt);
   Dir->setPreInits(PreInits);
   return Dir;

diff  --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp
index af70a180b27c..850b6f162d72 100644
--- a/clang/lib/Sema/SemaOpenMP.cpp
+++ b/clang/lib/Sema/SemaOpenMP.cpp
@@ -12919,10 +12919,12 @@ StmtResult Sema::ActOnOpenMPUnrollDirective(ArrayRef<OMPClause *> Clauses,
                                   Body, OriginalInits))
     return StmtError();
 
+  unsigned NumGeneratedLoops = PartialClause ? 1 : 0;
+
   // Delay unrolling to when template is completely instantiated.
   if (CurContext->isDependentContext())
     return OMPUnrollDirective::Create(Context, StartLoc, EndLoc, Clauses, AStmt,
-                                      nullptr, nullptr);
+                                      NumGeneratedLoops, nullptr, nullptr);
 
   OMPLoopBasedDirective::HelperExprs &LoopHelper = LoopHelpers.front();
 
@@ -12941,9 +12943,9 @@ StmtResult Sema::ActOnOpenMPUnrollDirective(ArrayRef<OMPClause *> Clauses,
   // The generated loop may only be passed to other loop-associated directive
   // when a partial clause is specified. Without the requirement it is
   // sufficient to generate loop unroll metadata at code-generation.
-  if (!PartialClause)
+  if (NumGeneratedLoops == 0)
     return OMPUnrollDirective::Create(Context, StartLoc, EndLoc, Clauses, AStmt,
-                                      nullptr, nullptr);
+                                      NumGeneratedLoops, nullptr, nullptr);
 
   // Otherwise, we need to provide a de-sugared/transformed AST that can be
   // associated with another loop directive.
@@ -13164,7 +13166,8 @@ StmtResult Sema::ActOnOpenMPUnrollDirective(ArrayRef<OMPClause *> Clauses,
               LoopHelper.Init->getBeginLoc(), LoopHelper.Inc->getEndLoc());
 
   return OMPUnrollDirective::Create(Context, StartLoc, EndLoc, Clauses, AStmt,
-                                    OuterFor, buildPreInits(Context, PreInits));
+                                    NumGeneratedLoops, OuterFor,
+                                    buildPreInits(Context, PreInits));
 }
 
 OMPClause *Sema::ActOnOpenMPSingleExprClause(OpenMPClauseKind Kind, Expr *Expr,

diff  --git a/clang/lib/Serialization/ASTReaderStmt.cpp b/clang/lib/Serialization/ASTReaderStmt.cpp
index 34a58831e0d4..4e6eaf77ff56 100644
--- a/clang/lib/Serialization/ASTReaderStmt.cpp
+++ b/clang/lib/Serialization/ASTReaderStmt.cpp
@@ -2327,6 +2327,7 @@ void ASTStmtReader::VisitOMPSimdDirective(OMPSimdDirective *D) {
 void ASTStmtReader::VisitOMPLoopTransformationDirective(
     OMPLoopTransformationDirective *D) {
   VisitOMPLoopBasedDirective(D);
+  D->setNumGeneratedLoops(Record.readUInt32());
 }
 
 void ASTStmtReader::VisitOMPTileDirective(OMPTileDirective *D) {

diff  --git a/clang/lib/Serialization/ASTWriterStmt.cpp b/clang/lib/Serialization/ASTWriterStmt.cpp
index bf32294bc95f..000bf808d32b 100644
--- a/clang/lib/Serialization/ASTWriterStmt.cpp
+++ b/clang/lib/Serialization/ASTWriterStmt.cpp
@@ -2226,6 +2226,7 @@ void ASTStmtWriter::VisitOMPSimdDirective(OMPSimdDirective *D) {
 void ASTStmtWriter::VisitOMPLoopTransformationDirective(
     OMPLoopTransformationDirective *D) {
   VisitOMPLoopBasedDirective(D);
+  Record.writeUInt32(D->getNumGeneratedLoops());
 }
 
 void ASTStmtWriter::VisitOMPTileDirective(OMPTileDirective *D) {

diff  --git a/clang/test/OpenMP/tile_ast_print.cpp b/clang/test/OpenMP/tile_ast_print.cpp
index 37791f0a8475..14f064358d8d 100644
--- a/clang/test/OpenMP/tile_ast_print.cpp
+++ b/clang/test/OpenMP/tile_ast_print.cpp
@@ -162,4 +162,25 @@ void tfoo6() {
 }
 
 
+// PRINT-LABEL: template <int Tile> void foo7(int start, int stop, int step) {
+// DUMP-LABEL: FunctionTemplateDecl {{.*}} foo7
+template <int Tile>
+void foo7(int start, int stop, int step) {
+  // PRINT: #pragma omp tile sizes(Tile)
+  // DUMP:      OMPTileDirective
+  // DUMP-NEXT:   OMPSizesClause
+  // DUMP-NEXT:     DeclRefExpr {{.*}} 'Tile' 'int'
+  #pragma omp tile sizes(Tile)
+    // PRINT-NEXT:  for (int i = start; i < stop; i += step)
+    // DUMP-NEXT: ForStmt
+    for (int i = start; i < stop; i += step)
+      // PRINT-NEXT: body(i);
+      // DUMP:  CallExpr
+      body(i);
+}
+void tfoo7() {
+  foo7<5>(0, 42, 2);
+}
+
+
 #endif

diff  --git a/clang/test/OpenMP/unroll_ast_print.cpp b/clang/test/OpenMP/unroll_ast_print.cpp
index 63e7b1dbe6ed..4d858284877f 100644
--- a/clang/test/OpenMP/unroll_ast_print.cpp
+++ b/clang/test/OpenMP/unroll_ast_print.cpp
@@ -124,4 +124,26 @@ void unroll_template() {
   unroll_templated<int,0,1024,1,4>();
 }
 
+
+// PRINT-LABEL: template <int Factor> void unroll_templated_factor(int start, int stop, int step) {
+// DUMP-LABEL:  FunctionTemplateDecl {{.*}} unroll_templated_factor
+template <int Factor>
+void unroll_templated_factor(int start, int stop, int step) {
+  // PRINT: #pragma omp unroll partial(Factor)
+  // DUMP:      OMPUnrollDirective
+  // DUMP-NEXT: OMPPartialClause
+  // DUMP-NEXT:   DeclRefExpr {{.*}} 'Factor' 'int'
+  #pragma omp unroll partial(Factor)
+    // PRINT-NEXT: for (int i = start; i < stop; i += step)
+    // DUMP-NEXT:  ForStmt
+    for (int i = start; i < stop; i += step)
+      // PRINT-NEXT: body(i);
+      // DUMP:  CallExpr
+      body(i);
+}
+void unroll_template_factor() {
+  unroll_templated_factor<4>(0, 42, 2);
+}
+
+
 #endif


        


More information about the cfe-commits mailing list