[clang] f37e8b0 - [Clang][OpenMP] Infix OMPLoopTransformationDirective abstract class. NFC.

Michael Kruse via cfe-commits cfe-commits at lists.llvm.org
Wed Oct 6 08:49:15 PDT 2021


Author: Michael Kruse
Date: 2021-10-06T10:49:07-05:00
New Revision: f37e8b0b831e61d3b6033829fff05d6d193ab735

URL: https://github.com/llvm/llvm-project/commit/f37e8b0b831e61d3b6033829fff05d6d193ab735
DIFF: https://github.com/llvm/llvm-project/commit/f37e8b0b831e61d3b6033829fff05d6d193ab735.diff

LOG: [Clang][OpenMP] Infix OMPLoopTransformationDirective abstract class. NFC.

Insert the abstract class OMPLoopTransformationDirective between OMPLoopBasedDirective and the loop transformation directives OMPTileDirective and OMPUnrollDirective.
This simplifies the handling of loop transformations, because callers no longer need to distinguish between OMPTileDirective and OMPUnrollDirective.

Reviewed By: ABataev

Differential Revision: https://reviews.llvm.org/D111119

Added: 
    

Modified: 
    clang/include/clang/AST/StmtOpenMP.h
    clang/include/clang/Basic/StmtNodes.td
    clang/lib/AST/StmtOpenMP.cpp
    clang/lib/AST/StmtProfile.cpp
    clang/lib/CodeGen/CGStmtOpenMP.cpp
    clang/lib/Sema/SemaOpenMP.cpp
    clang/lib/Serialization/ASTReaderStmt.cpp
    clang/lib/Serialization/ASTWriterStmt.cpp
    clang/tools/libclang/CIndex.cpp

Removed: 
    


################################################################################
diff  --git a/clang/include/clang/AST/StmtOpenMP.h b/clang/include/clang/AST/StmtOpenMP.h
index f028c3b32398..285426d26e21 100644
--- a/clang/include/clang/AST/StmtOpenMP.h
+++ b/clang/include/clang/AST/StmtOpenMP.h
@@ -889,22 +889,23 @@ class OMPLoopBasedDirective : public OMPExecutableDirective {
 
   /// Calls the specified callback function for all the loops in \p CurStmt,
   /// from the outermost to the innermost.
-  static bool doForAllLoops(Stmt *CurStmt, bool TryImperfectlyNestedLoops,
-                            unsigned NumLoops,
-                            llvm::function_ref<bool(unsigned, Stmt *)> Callback,
-                            llvm::function_ref<void(OMPLoopBasedDirective *)>
-                                OnTransformationCallback);
+  static bool
+  doForAllLoops(Stmt *CurStmt, bool TryImperfectlyNestedLoops,
+                unsigned NumLoops,
+                llvm::function_ref<bool(unsigned, Stmt *)> Callback,
+                llvm::function_ref<void(OMPLoopTransformationDirective *)>
+                    OnTransformationCallback);
   static bool
   doForAllLoops(const Stmt *CurStmt, bool TryImperfectlyNestedLoops,
                 unsigned NumLoops,
                 llvm::function_ref<bool(unsigned, const Stmt *)> Callback,
-                llvm::function_ref<void(const OMPLoopBasedDirective *)>
+                llvm::function_ref<void(const OMPLoopTransformationDirective *)>
                     OnTransformationCallback) {
     auto &&NewCallback = [Callback](unsigned Cnt, Stmt *CurStmt) {
       return Callback(Cnt, CurStmt);
     };
     auto &&NewTransformCb =
-        [OnTransformationCallback](OMPLoopBasedDirective *A) {
+        [OnTransformationCallback](OMPLoopTransformationDirective *A) {
           OnTransformationCallback(A);
         };
     return doForAllLoops(const_cast<Stmt *>(CurStmt), TryImperfectlyNestedLoops,
@@ -917,7 +918,7 @@ class OMPLoopBasedDirective : public OMPExecutableDirective {
   doForAllLoops(Stmt *CurStmt, bool TryImperfectlyNestedLoops,
                 unsigned NumLoops,
                 llvm::function_ref<bool(unsigned, Stmt *)> Callback) {
-    auto &&TransformCb = [](OMPLoopBasedDirective *) {};
+    auto &&TransformCb = [](OMPLoopTransformationDirective *) {};
     return doForAllLoops(CurStmt, TryImperfectlyNestedLoops, NumLoops, Callback,
                          TransformCb);
   }
@@ -954,6 +955,38 @@ class OMPLoopBasedDirective : public OMPExecutableDirective {
   }
 };
 
+/// The base class for all loop transformation directives.
+class OMPLoopTransformationDirective : public OMPLoopBasedDirective {
+  friend class ASTStmtReader;
+
+protected:
+  explicit OMPLoopTransformationDirective(StmtClass SC,
+                                          OpenMPDirectiveKind Kind,
+                                          SourceLocation StartLoc,
+                                          SourceLocation EndLoc,
+                                          unsigned NumAssociatedLoops)
+      : OMPLoopBasedDirective(SC, Kind, StartLoc, EndLoc, NumAssociatedLoops) {}
+
+public:
+  /// Return the number of associated (consumed) loops.
+  unsigned getNumAssociatedLoops() const { return getLoopsNumber(); }
+
+  /// Get the de-sugared statements after the loop transformation.
+  ///
+  /// Might be nullptr if either the directive generates no loops and is handled
+  /// directly in CodeGen, or resolving a template-dependence context is
+  /// required.
+  Stmt *getTransformedStmt() const;
+
+  /// Return preinits statement.
+  Stmt *getPreInits() const;
+
+  static bool classof(const Stmt *T) {
+    return T->getStmtClass() == OMPTileDirectiveClass ||
+           T->getStmtClass() == OMPUnrollDirectiveClass;
+  }
+};
+
 /// This is a common base class for loop directives ('omp simd', 'omp
 /// for', 'omp for simd' etc.). It is responsible for the loop code generation.
 ///
@@ -5011,7 +5044,7 @@ class OMPTargetTeamsDistributeSimdDirective final : public OMPLoopDirective {
 };
 
 /// This represents the '#pragma omp tile' loop transformation directive.
-class OMPTileDirective final : public OMPLoopBasedDirective {
+class OMPTileDirective final : public OMPLoopTransformationDirective {
   friend class ASTStmtReader;
   friend class OMPExecutableDirective;
 
@@ -5023,8 +5056,9 @@ class OMPTileDirective final : public OMPLoopBasedDirective {
 
   explicit OMPTileDirective(SourceLocation StartLoc, SourceLocation EndLoc,
                             unsigned NumLoops)
-      : OMPLoopBasedDirective(OMPTileDirectiveClass, llvm::omp::OMPD_tile,
-                              StartLoc, EndLoc, NumLoops) {}
+      : OMPLoopTransformationDirective(OMPTileDirectiveClass,
+                                       llvm::omp::OMPD_tile, StartLoc, EndLoc,
+                                       NumLoops) {}
 
   void setPreInits(Stmt *PreInits) {
     Data->getChildren()[PreInitsOffset] = PreInits;
@@ -5061,8 +5095,6 @@ class OMPTileDirective final : public OMPLoopBasedDirective {
   static OMPTileDirective *CreateEmpty(const ASTContext &C, unsigned NumClauses,
                                        unsigned NumLoops);
 
-  unsigned getNumAssociatedLoops() const { return getLoopsNumber(); }
-
   /// Gets/sets the associated loops after tiling.
   ///
   /// This is in de-sugared format stored as a CompoundStmt.
@@ -5092,7 +5124,7 @@ class OMPTileDirective final : public OMPLoopBasedDirective {
 /// #pragma omp unroll
 /// for (int i = 0; i < 64; ++i)
 /// \endcode
-class OMPUnrollDirective final : public OMPLoopBasedDirective {
+class OMPUnrollDirective final : public OMPLoopTransformationDirective {
   friend class ASTStmtReader;
   friend class OMPExecutableDirective;
 
@@ -5103,8 +5135,9 @@ class OMPUnrollDirective final : public OMPLoopBasedDirective {
   };
 
   explicit OMPUnrollDirective(SourceLocation StartLoc, SourceLocation EndLoc)
-      : OMPLoopBasedDirective(OMPUnrollDirectiveClass, llvm::omp::OMPD_unroll,
-                              StartLoc, EndLoc, 1) {}
+      : OMPLoopTransformationDirective(OMPUnrollDirectiveClass,
+                                       llvm::omp::OMPD_unroll, StartLoc, EndLoc,
+                                       1) {}
 
   /// Set the pre-init statements.
   void setPreInits(Stmt *PreInits) {

diff  --git a/clang/include/clang/Basic/StmtNodes.td b/clang/include/clang/Basic/StmtNodes.td
index c5540a16056f..f002c3632829 100644
--- a/clang/include/clang/Basic/StmtNodes.td
+++ b/clang/include/clang/Basic/StmtNodes.td
@@ -224,8 +224,9 @@ def OMPLoopBasedDirective : StmtNode<OMPExecutableDirective, 1>;
 def OMPLoopDirective : StmtNode<OMPLoopBasedDirective, 1>;
 def OMPParallelDirective : StmtNode<OMPExecutableDirective>;
 def OMPSimdDirective : StmtNode<OMPLoopDirective>;
-def OMPTileDirective : StmtNode<OMPLoopBasedDirective>;
-def OMPUnrollDirective : StmtNode<OMPLoopBasedDirective>;
+def OMPLoopTransformationDirective : StmtNode<OMPLoopBasedDirective, 1>;
+def OMPTileDirective : StmtNode<OMPLoopTransformationDirective>;
+def OMPUnrollDirective : StmtNode<OMPLoopTransformationDirective>;
 def OMPForDirective : StmtNode<OMPLoopDirective>;
 def OMPForSimdDirective : StmtNode<OMPLoopDirective>;
 def OMPSectionsDirective : StmtNode<OMPExecutableDirective>;

diff  --git a/clang/lib/AST/StmtOpenMP.cpp b/clang/lib/AST/StmtOpenMP.cpp
index f461d9b65042..c615463f42da 100644
--- a/clang/lib/AST/StmtOpenMP.cpp
+++ b/clang/lib/AST/StmtOpenMP.cpp
@@ -125,28 +125,25 @@ OMPLoopBasedDirective::tryToFindNextInnerLoop(Stmt *CurStmt,
 bool OMPLoopBasedDirective::doForAllLoops(
     Stmt *CurStmt, bool TryImperfectlyNestedLoops, unsigned NumLoops,
     llvm::function_ref<bool(unsigned, Stmt *)> Callback,
-    llvm::function_ref<void(OMPLoopBasedDirective *)>
+    llvm::function_ref<void(OMPLoopTransformationDirective *)>
         OnTransformationCallback) {
   CurStmt = CurStmt->IgnoreContainers();
   for (unsigned Cnt = 0; Cnt < NumLoops; ++Cnt) {
     while (true) {
-      auto *OrigStmt = CurStmt;
-      if (auto *Dir = dyn_cast<OMPTileDirective>(OrigStmt)) {
-        OnTransformationCallback(Dir);
-        CurStmt = Dir->getTransformedStmt();
-      } else if (auto *Dir = dyn_cast<OMPUnrollDirective>(OrigStmt)) {
-        OnTransformationCallback(Dir);
-        CurStmt = Dir->getTransformedStmt();
-      } else {
+      auto *Dir = dyn_cast<OMPLoopTransformationDirective>(CurStmt);
+      if (!Dir)
         break;
-      }
 
-      if (!CurStmt) {
-        // May happen if the loop transformation does not result in a generated
-        // loop (such as full unrolling).
-        CurStmt = OrigStmt;
+      OnTransformationCallback(Dir);
+
+      Stmt *TransformedStmt = Dir->getTransformedStmt();
+      if (!TransformedStmt) {
+        // May happen if the loop transformation does not result in a
+        // generated loop (such as full unrolling).
         break;
       }
+
+      CurStmt = TransformedStmt;
     }
     if (auto *CanonLoop = dyn_cast<OMPCanonicalLoop>(CurStmt))
       CurStmt = CanonLoop->getLoopStmt();
@@ -363,6 +360,32 @@ OMPForDirective *OMPForDirective::Create(
   return Dir;
 }
 
+Stmt *OMPLoopTransformationDirective::getTransformedStmt() const {
+  switch (getStmtClass()) {
+#define STMT(CLASS, PARENT)
+#define ABSTRACT_STMT(CLASS)
+#define OMPLOOPTRANSFORMATIONDIRECTIVE(CLASS, PARENT)                          \
+  case Stmt::CLASS##Class:                                                     \
+    return static_cast<const CLASS *>(this)->getTransformedStmt();
+#include "clang/AST/StmtNodes.inc"
+  default:
+    llvm_unreachable("Not a loop transformation");
+  }
+}
+
+Stmt *OMPLoopTransformationDirective::getPreInits() const {
+  switch (getStmtClass()) {
+#define STMT(CLASS, PARENT)
+#define ABSTRACT_STMT(CLASS)
+#define OMPLOOPTRANSFORMATIONDIRECTIVE(CLASS, PARENT)                          \
+  case Stmt::CLASS##Class:                                                     \
+    return static_cast<const CLASS *>(this)->getPreInits();
+#include "clang/AST/StmtNodes.inc"
+  default:
+    llvm_unreachable("Not a loop transformation");
+  }
+}
+
 OMPForDirective *OMPForDirective::CreateEmpty(const ASTContext &C,
                                               unsigned NumClauses,
                                               unsigned CollapsedNum,

diff  --git a/clang/lib/AST/StmtProfile.cpp b/clang/lib/AST/StmtProfile.cpp
index 1afa3773c711..9c48491dca71 100644
--- a/clang/lib/AST/StmtProfile.cpp
+++ b/clang/lib/AST/StmtProfile.cpp
@@ -915,12 +915,17 @@ void StmtProfiler::VisitOMPSimdDirective(const OMPSimdDirective *S) {
   VisitOMPLoopDirective(S);
 }
 
-void StmtProfiler::VisitOMPTileDirective(const OMPTileDirective *S) {
+void StmtProfiler::VisitOMPLoopTransformationDirective(
+    const OMPLoopTransformationDirective *S) {
   VisitOMPLoopBasedDirective(S);
 }
 
+void StmtProfiler::VisitOMPTileDirective(const OMPTileDirective *S) {
+  VisitOMPLoopTransformationDirective(S);
+}
+
 void StmtProfiler::VisitOMPUnrollDirective(const OMPUnrollDirective *S) {
-  VisitOMPLoopBasedDirective(S);
+  VisitOMPLoopTransformationDirective(S);
 }
 
 void StmtProfiler::VisitOMPForDirective(const OMPForDirective *S) {

diff  --git a/clang/lib/CodeGen/CGStmtOpenMP.cpp b/clang/lib/CodeGen/CGStmtOpenMP.cpp
index 68ab1362a454..431cc62ef84f 100644
--- a/clang/lib/CodeGen/CGStmtOpenMP.cpp
+++ b/clang/lib/CodeGen/CGStmtOpenMP.cpp
@@ -1829,9 +1829,7 @@ static void emitBody(CodeGenFunction &CGF, const Stmt *S, const Stmt *NextLoop,
     return;
   }
   if (SimplifiedS == NextLoop) {
-    if (auto *Dir = dyn_cast<OMPTileDirective>(SimplifiedS))
-      SimplifiedS = Dir->getTransformedStmt();
-    if (auto *Dir = dyn_cast<OMPUnrollDirective>(SimplifiedS))
+    if (auto *Dir = dyn_cast<OMPLoopTransformationDirective>(SimplifiedS))
       SimplifiedS = Dir->getTransformedStmt();
     if (const auto *CanonLoop = dyn_cast<OMPCanonicalLoop>(SimplifiedS))
       SimplifiedS = CanonLoop->getLoopStmt();

diff  --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp
index 5e6fa2c8fc4e..af70a180b27c 100644
--- a/clang/lib/Sema/SemaOpenMP.cpp
+++ b/clang/lib/Sema/SemaOpenMP.cpp
@@ -3823,13 +3823,8 @@ class DSAAttrChecker final : public StmtVisitor<DSAAttrChecker, void> {
     VisitSubCaptures(S);
   }
 
-  void VisitOMPTileDirective(OMPTileDirective *S) {
-    // #pragma omp tile does not introduce data sharing.
-    VisitStmt(S);
-  }
-
-  void VisitOMPUnrollDirective(OMPUnrollDirective *S) {
-    // #pragma omp unroll does not introduce data sharing.
+  void VisitOMPLoopTransformationDirective(OMPLoopTransformationDirective *S) {
+    // Loop transformation directives do not introduce data sharing.
     VisitStmt(S);
   }
 
@@ -9050,15 +9045,8 @@ checkOpenMPLoop(OpenMPDirectiveKind DKind, Expr *CollapseLoopCountExpr,
             }
             return false;
           },
-          [&SemaRef, &Captures](OMPLoopBasedDirective *Transform) {
-            Stmt *DependentPreInits;
-            if (auto *Dir = dyn_cast<OMPTileDirective>(Transform)) {
-              DependentPreInits = Dir->getPreInits();
-            } else if (auto *Dir = dyn_cast<OMPUnrollDirective>(Transform)) {
-              DependentPreInits = Dir->getPreInits();
-            } else {
-              llvm_unreachable("Unexpected loop transformation");
-            }
+          [&SemaRef, &Captures](OMPLoopTransformationDirective *Transform) {
+            Stmt *DependentPreInits = Transform->getPreInits();
             if (!DependentPreInits)
               return;
             for (Decl *C : cast<DeclStmt>(DependentPreInits)->getDeclGroup()) {

diff  --git a/clang/lib/Serialization/ASTReaderStmt.cpp b/clang/lib/Serialization/ASTReaderStmt.cpp
index 6dcb5d57048f..34a58831e0d4 100644
--- a/clang/lib/Serialization/ASTReaderStmt.cpp
+++ b/clang/lib/Serialization/ASTReaderStmt.cpp
@@ -2324,12 +2324,17 @@ void ASTStmtReader::VisitOMPSimdDirective(OMPSimdDirective *D) {
   VisitOMPLoopDirective(D);
 }
 
-void ASTStmtReader::VisitOMPTileDirective(OMPTileDirective *D) {
+void ASTStmtReader::VisitOMPLoopTransformationDirective(
+    OMPLoopTransformationDirective *D) {
   VisitOMPLoopBasedDirective(D);
 }
 
+void ASTStmtReader::VisitOMPTileDirective(OMPTileDirective *D) {
+  VisitOMPLoopTransformationDirective(D);
+}
+
 void ASTStmtReader::VisitOMPUnrollDirective(OMPUnrollDirective *D) {
-  VisitOMPLoopBasedDirective(D);
+  VisitOMPLoopTransformationDirective(D);
 }
 
 void ASTStmtReader::VisitOMPForDirective(OMPForDirective *D) {

diff  --git a/clang/lib/Serialization/ASTWriterStmt.cpp b/clang/lib/Serialization/ASTWriterStmt.cpp
index a12c72674502..bf32294bc95f 100644
--- a/clang/lib/Serialization/ASTWriterStmt.cpp
+++ b/clang/lib/Serialization/ASTWriterStmt.cpp
@@ -2223,13 +2223,18 @@ void ASTStmtWriter::VisitOMPSimdDirective(OMPSimdDirective *D) {
   Code = serialization::STMT_OMP_SIMD_DIRECTIVE;
 }
 
-void ASTStmtWriter::VisitOMPTileDirective(OMPTileDirective *D) {
+void ASTStmtWriter::VisitOMPLoopTransformationDirective(
+    OMPLoopTransformationDirective *D) {
   VisitOMPLoopBasedDirective(D);
+}
+
+void ASTStmtWriter::VisitOMPTileDirective(OMPTileDirective *D) {
+  VisitOMPLoopTransformationDirective(D);
   Code = serialization::STMT_OMP_TILE_DIRECTIVE;
 }
 
 void ASTStmtWriter::VisitOMPUnrollDirective(OMPUnrollDirective *D) {
-  VisitOMPLoopBasedDirective(D);
+  VisitOMPLoopTransformationDirective(D);
   Code = serialization::STMT_OMP_UNROLL_DIRECTIVE;
 }
 

diff  --git a/clang/tools/libclang/CIndex.cpp b/clang/tools/libclang/CIndex.cpp
index 96a7cb4f3469..f76f04195286 100644
--- a/clang/tools/libclang/CIndex.cpp
+++ b/clang/tools/libclang/CIndex.cpp
@@ -2046,6 +2046,8 @@ class EnqueueVisitor : public ConstStmtVisitor<EnqueueVisitor, void> {
   void VisitOMPLoopDirective(const OMPLoopDirective *D);
   void VisitOMPParallelDirective(const OMPParallelDirective *D);
   void VisitOMPSimdDirective(const OMPSimdDirective *D);
+  void
+  VisitOMPLoopTransformationDirective(const OMPLoopTransformationDirective *D);
   void VisitOMPTileDirective(const OMPTileDirective *D);
   void VisitOMPUnrollDirective(const OMPUnrollDirective *D);
   void VisitOMPForDirective(const OMPForDirective *D);
@@ -2901,12 +2903,17 @@ void EnqueueVisitor::VisitOMPSimdDirective(const OMPSimdDirective *D) {
   VisitOMPLoopDirective(D);
 }
 
-void EnqueueVisitor::VisitOMPTileDirective(const OMPTileDirective *D) {
+void EnqueueVisitor::VisitOMPLoopTransformationDirective(
+    const OMPLoopTransformationDirective *D) {
   VisitOMPLoopBasedDirective(D);
 }
 
+void EnqueueVisitor::VisitOMPTileDirective(const OMPTileDirective *D) {
+  VisitOMPLoopTransformationDirective(D);
+}
+
 void EnqueueVisitor::VisitOMPUnrollDirective(const OMPUnrollDirective *D) {
-  VisitOMPLoopBasedDirective(D);
+  VisitOMPLoopTransformationDirective(D);
 }
 
 void EnqueueVisitor::VisitOMPForDirective(const OMPForDirective *D) {


        


More information about the cfe-commits mailing list