[llvm-branch-commits] [clang] [llvm] [openmp] [Clang][OpenMP] Add interchange directive (PR #93022)

via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Wed May 22 04:44:30 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-clang-codegen

@llvm/pr-subscribers-flang-openmp

Author: Michael Kruse (Meinersbur)

<details>
<summary>Changes</summary>

Add the interchange directive which will be introduced in the upcoming OpenMP 6.0 specification. A preview has been published in [Technical Report 12](https://www.openmp.org/wp-content/uploads/openmp-TR12.pdf).

---

Patch is 185.81 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/93022.diff


30 Files Affected:

- (modified) clang/include/clang-c/Index.h (+4) 
- (modified) clang/include/clang/AST/RecursiveASTVisitor.h (+3) 
- (modified) clang/include/clang/AST/StmtOpenMP.h (+75-1) 
- (modified) clang/include/clang/Basic/StmtNodes.td (+1) 
- (modified) clang/include/clang/Sema/SemaOpenMP.h (+6) 
- (modified) clang/include/clang/Serialization/ASTBitCodes.h (+1) 
- (modified) clang/lib/AST/StmtOpenMP.cpp (+20) 
- (modified) clang/lib/AST/StmtPrinter.cpp (+5) 
- (modified) clang/lib/AST/StmtProfile.cpp (+5) 
- (modified) clang/lib/Basic/OpenMPKinds.cpp (+3-1) 
- (modified) clang/lib/CodeGen/CGStmt.cpp (+3) 
- (modified) clang/lib/CodeGen/CGStmtOpenMP.cpp (+10) 
- (modified) clang/lib/CodeGen/CodeGenFunction.h (+1) 
- (modified) clang/lib/Parse/ParseOpenMP.cpp (+2) 
- (modified) clang/lib/Sema/SemaExceptionSpec.cpp (+1) 
- (modified) clang/lib/Sema/SemaOpenMP.cpp (+162) 
- (modified) clang/lib/Sema/TreeTransform.h (+11) 
- (modified) clang/lib/Serialization/ASTReaderStmt.cpp (+11) 
- (modified) clang/lib/Serialization/ASTWriterStmt.cpp (+5) 
- (added) clang/test/OpenMP/interchange_ast_print.cpp (+135) 
- (added) clang/test/OpenMP/interchange_codegen.cpp (+1990) 
- (added) clang/test/OpenMP/interchange_messages.cpp (+77) 
- (modified) clang/tools/libclang/CIndex.cpp (+8) 
- (modified) clang/tools/libclang/CXCursor.cpp (+3) 
- (modified) llvm/include/llvm/Frontend/OpenMP/OMP.td (+3) 
- (added) openmp/runtime/test/transform/interchange/foreach.cpp (+216) 
- (added) openmp/runtime/test/transform/interchange/intfor.c (+38) 
- (added) openmp/runtime/test/transform/interchange/iterfor.cpp (+222) 
- (added) openmp/runtime/test/transform/interchange/parallel-wsloop-collapse-foreach.cpp (+340) 
- (added) openmp/runtime/test/transform/interchange/parallel-wsloop-collapse-intfor.cpp (+106) 


``````````diff
diff --git a/clang/include/clang-c/Index.h b/clang/include/clang-c/Index.h
index c7d63818ece23..a79aafbf20222 100644
--- a/clang/include/clang-c/Index.h
+++ b/clang/include/clang-c/Index.h
@@ -2150,6 +2150,10 @@ enum CXCursorKind {
    */
   CXCursor_OMPReverseDirective = 307,
 
+  /** OpenMP interchange directive.
+   */
+  CXCursor_OMPInterchangeDirective = 308,
+
   /** OpenACC Compute Construct.
    */
   CXCursor_OpenACCComputeConstruct = 320,
diff --git a/clang/include/clang/AST/RecursiveASTVisitor.h b/clang/include/clang/AST/RecursiveASTVisitor.h
index 06b29d59785f6..1bb167d7ddc3c 100644
--- a/clang/include/clang/AST/RecursiveASTVisitor.h
+++ b/clang/include/clang/AST/RecursiveASTVisitor.h
@@ -3024,6 +3024,9 @@ DEF_TRAVERSE_STMT(OMPUnrollDirective,
 DEF_TRAVERSE_STMT(OMPReverseDirective,
                   { TRY_TO(TraverseOMPExecutableDirective(S)); })
 
+DEF_TRAVERSE_STMT(OMPInterchangeDirective,
+                  { TRY_TO(TraverseOMPExecutableDirective(S)); })
+
 DEF_TRAVERSE_STMT(OMPForDirective,
                   { TRY_TO(TraverseOMPExecutableDirective(S)); })
 
diff --git a/clang/include/clang/AST/StmtOpenMP.h b/clang/include/clang/AST/StmtOpenMP.h
index fb7f413162fad..01c8b8e1a9f5e 100644
--- a/clang/include/clang/AST/StmtOpenMP.h
+++ b/clang/include/clang/AST/StmtOpenMP.h
@@ -1009,7 +1009,7 @@ class OMPLoopTransformationDirective : public OMPLoopBasedDirective {
   static bool classof(const Stmt *T) {
     Stmt::StmtClass C = T->getStmtClass();
     return C == OMPTileDirectiveClass || C == OMPUnrollDirectiveClass ||
-           C == OMPReverseDirectiveClass;
+           C == OMPReverseDirectiveClass || C == OMPInterchangeDirectiveClass;
   }
 };
 
@@ -5777,6 +5777,80 @@ class OMPReverseDirective final : public OMPLoopTransformationDirective {
   }
 };
 
+/// Represents the '#pragma omp interchange' loop transformation directive.
+///
+/// \code{c}
+///   #pragma omp interchange
+///   for (int i = 0; i < m; ++i)
+///     for (int j = 0; j < n; ++j)
+///       ..
+/// \endcode
+class OMPInterchangeDirective final : public OMPLoopTransformationDirective {
+  friend class ASTStmtReader;
+  friend class OMPExecutableDirective;
+
+  /// Offsets of child members.
+  enum {
+    PreInitsOffset = 0,
+    TransformedStmtOffset,
+  };
+
+  explicit OMPInterchangeDirective(SourceLocation StartLoc,
+                                   SourceLocation EndLoc, unsigned NumLoops)
+      : OMPLoopTransformationDirective(OMPInterchangeDirectiveClass,
+                                       llvm::omp::OMPD_interchange, StartLoc,
+                                       EndLoc, NumLoops) {
+    setNumGeneratedLoops(3 * NumLoops);
+  }
+
+  void setPreInits(Stmt *PreInits) {
+    Data->getChildren()[PreInitsOffset] = PreInits;
+  }
+
+  void setTransformedStmt(Stmt *S) {
+    Data->getChildren()[TransformedStmtOffset] = S;
+  }
+
+public:
+  /// Create a new AST node representation for '#pragma omp interchange'.
+  ///
+  /// \param C         Context of the AST.
+  /// \param StartLoc  Location of the introducer (e.g. the 'omp' token).
+  /// \param EndLoc    Location of the directive's end (e.g. the tok::eod).
+  /// \param Clauses   The directive's clauses.
+  /// \param NumLoops  Number of affected loops
+  ///                  (number of items in the 'permutation' clause if present).
+  /// \param AssociatedStmt  The outermost associated loop.
+  /// \param TransformedStmt The loop nest after interchange, or nullptr in
+  ///                        dependent contexts.
+  /// \param PreInits  Helper preinits statements for the loop nest.
+  static OMPInterchangeDirective *
+  Create(const ASTContext &C, SourceLocation StartLoc, SourceLocation EndLoc,
+         ArrayRef<OMPClause *> Clauses, unsigned NumLoops, Stmt *AssociatedStmt,
+         Stmt *TransformedStmt, Stmt *PreInits);
+
+  /// Build an empty '#pragma omp interchange' AST node for deserialization.
+  ///
+  /// \param C          Context of the AST.
+  /// \param NumClauses Number of clauses to allocate.
+  /// \param NumLoops   Number of associated loops to allocate.
+  static OMPInterchangeDirective *
+  CreateEmpty(const ASTContext &C, unsigned NumClauses, unsigned NumLoops);
+
+  /// Gets the associated loops after the transformation. This is the de-sugared
+  /// replacement or nullptr in dependent contexts.
+  Stmt *getTransformedStmt() const {
+    return Data->getChildren()[TransformedStmtOffset];
+  }
+
+  /// Return preinits statement.
+  Stmt *getPreInits() const { return Data->getChildren()[PreInitsOffset]; }
+
+  static bool classof(const Stmt *T) {
+    return T->getStmtClass() == OMPInterchangeDirectiveClass;
+  }
+};
+
 /// This represents '#pragma omp scan' directive.
 ///
 /// \code
diff --git a/clang/include/clang/Basic/StmtNodes.td b/clang/include/clang/Basic/StmtNodes.td
index b2e2be5c998bb..b445ea225eac5 100644
--- a/clang/include/clang/Basic/StmtNodes.td
+++ b/clang/include/clang/Basic/StmtNodes.td
@@ -230,6 +230,7 @@ def OMPLoopTransformationDirective : StmtNode<OMPLoopBasedDirective, 1>;
 def OMPTileDirective : StmtNode<OMPLoopTransformationDirective>;
 def OMPUnrollDirective : StmtNode<OMPLoopTransformationDirective>;
 def OMPReverseDirective : StmtNode<OMPLoopTransformationDirective>;
+def OMPInterchangeDirective : StmtNode<OMPLoopTransformationDirective>;
 def OMPForDirective : StmtNode<OMPLoopDirective>;
 def OMPForSimdDirective : StmtNode<OMPLoopDirective>;
 def OMPSectionsDirective : StmtNode<OMPExecutableDirective>;
diff --git a/clang/include/clang/Sema/SemaOpenMP.h b/clang/include/clang/Sema/SemaOpenMP.h
index ca91bffe24f6f..06376f173e8df 100644
--- a/clang/include/clang/Sema/SemaOpenMP.h
+++ b/clang/include/clang/Sema/SemaOpenMP.h
@@ -425,6 +425,12 @@ class SemaOpenMP : public SemaBase {
   /// Called on well-formed '#pragma omp reverse'.
   StmtResult ActOnOpenMPReverseDirective(Stmt *AStmt, SourceLocation StartLoc,
                                          SourceLocation EndLoc);
+  /// Called on well-formed '#pragma omp interchange' after parsing of its
+  /// clauses and the associated statement.
+  StmtResult ActOnOpenMPInterchangeDirective(ArrayRef<OMPClause *> Clauses,
+                                             Stmt *AStmt,
+                                             SourceLocation StartLoc,
+                                             SourceLocation EndLoc);
   /// Called on well-formed '\#pragma omp for' after parsing
   /// of the associated statement.
   StmtResult
diff --git a/clang/include/clang/Serialization/ASTBitCodes.h b/clang/include/clang/Serialization/ASTBitCodes.h
index dee0d073557cc..5fbdfd7a496fe 100644
--- a/clang/include/clang/Serialization/ASTBitCodes.h
+++ b/clang/include/clang/Serialization/ASTBitCodes.h
@@ -1857,6 +1857,7 @@ enum StmtCode {
   STMT_OMP_TILE_DIRECTIVE,
   STMT_OMP_UNROLL_DIRECTIVE,
   STMT_OMP_REVERSE_DIRECTIVE,
+  STMT_OMP_INTERCHANGE_DIRECTIVE,
   STMT_OMP_FOR_DIRECTIVE,
   STMT_OMP_FOR_SIMD_DIRECTIVE,
   STMT_OMP_SECTIONS_DIRECTIVE,
diff --git a/clang/lib/AST/StmtOpenMP.cpp b/clang/lib/AST/StmtOpenMP.cpp
index 83b8a08e9af73..24d8eb25c59ba 100644
--- a/clang/lib/AST/StmtOpenMP.cpp
+++ b/clang/lib/AST/StmtOpenMP.cpp
@@ -467,6 +467,26 @@ OMPReverseDirective *OMPReverseDirective::CreateEmpty(const ASTContext &C,
       SourceLocation(), SourceLocation());
 }
 
+OMPInterchangeDirective *OMPInterchangeDirective::Create(
+    const ASTContext &C, SourceLocation StartLoc, SourceLocation EndLoc,
+    ArrayRef<OMPClause *> Clauses, unsigned NumLoops, Stmt *AssociatedStmt,
+    Stmt *TransformedStmt, Stmt *PreInits) {
+  OMPInterchangeDirective *Dir = createDirective<OMPInterchangeDirective>(
+      C, Clauses, AssociatedStmt, TransformedStmtOffset + 1, StartLoc, EndLoc,
+      NumLoops);
+  Dir->setTransformedStmt(TransformedStmt);
+  Dir->setPreInits(PreInits);
+  return Dir;
+}
+
+OMPInterchangeDirective *
+OMPInterchangeDirective::CreateEmpty(const ASTContext &C, unsigned NumClauses,
+                                     unsigned NumLoops) {
+  return createEmptyDirective<OMPInterchangeDirective>(
+      C, NumClauses, /*HasAssociatedStmt=*/true, TransformedStmtOffset + 1,
+      SourceLocation(), SourceLocation(), NumLoops);
+}
+
 OMPForSimdDirective *
 OMPForSimdDirective::Create(const ASTContext &C, SourceLocation StartLoc,
                             SourceLocation EndLoc, unsigned CollapsedNum,
diff --git a/clang/lib/AST/StmtPrinter.cpp b/clang/lib/AST/StmtPrinter.cpp
index 64b481f680311..64bee75b205ae 100644
--- a/clang/lib/AST/StmtPrinter.cpp
+++ b/clang/lib/AST/StmtPrinter.cpp
@@ -768,6 +768,11 @@ void StmtPrinter::VisitOMPReverseDirective(OMPReverseDirective *Node) {
   PrintOMPExecutableDirective(Node);
 }
 
+void StmtPrinter::VisitOMPInterchangeDirective(OMPInterchangeDirective *Node) {
+  Indent() << "#pragma omp interchange";
+  PrintOMPExecutableDirective(Node);
+}
+
 void StmtPrinter::VisitOMPForDirective(OMPForDirective *Node) {
   Indent() << "#pragma omp for";
   PrintOMPExecutableDirective(Node);
diff --git a/clang/lib/AST/StmtProfile.cpp b/clang/lib/AST/StmtProfile.cpp
index 7445e5519b972..1ae99d43575a7 100644
--- a/clang/lib/AST/StmtProfile.cpp
+++ b/clang/lib/AST/StmtProfile.cpp
@@ -989,6 +989,11 @@ void StmtProfiler::VisitOMPReverseDirective(const OMPReverseDirective *S) {
   VisitOMPLoopTransformationDirective(S);
 }
 
+void StmtProfiler::VisitOMPInterchangeDirective(
+    const OMPInterchangeDirective *S) {
+  VisitOMPLoopTransformationDirective(S);
+}
+
 void StmtProfiler::VisitOMPForDirective(const OMPForDirective *S) {
   VisitOMPLoopDirective(S);
 }
diff --git a/clang/lib/Basic/OpenMPKinds.cpp b/clang/lib/Basic/OpenMPKinds.cpp
index 803808c38e2fe..ff5d5c8bdc981 100644
--- a/clang/lib/Basic/OpenMPKinds.cpp
+++ b/clang/lib/Basic/OpenMPKinds.cpp
@@ -684,7 +684,8 @@ bool clang::isOpenMPLoopBoundSharingDirective(OpenMPDirectiveKind Kind) {
 }
 
 bool clang::isOpenMPLoopTransformationDirective(OpenMPDirectiveKind DKind) {
-  return DKind == OMPD_tile || DKind == OMPD_unroll || DKind == OMPD_reverse;
+  return DKind == OMPD_tile || DKind == OMPD_unroll || DKind == OMPD_reverse ||
+         DKind == OMPD_interchange;
 }
 
 bool clang::isOpenMPCombinedParallelADirective(OpenMPDirectiveKind DKind) {
@@ -809,6 +810,7 @@ void clang::getOpenMPCaptureRegions(
   case OMPD_tile:
   case OMPD_unroll:
   case OMPD_reverse:
+  case OMPD_interchange:
     // loop transformations do not introduce captures.
     break;
   case OMPD_threadprivate:
diff --git a/clang/lib/CodeGen/CGStmt.cpp b/clang/lib/CodeGen/CGStmt.cpp
index 93c2f8900dd12..ba7c52cc6ab7b 100644
--- a/clang/lib/CodeGen/CGStmt.cpp
+++ b/clang/lib/CodeGen/CGStmt.cpp
@@ -225,6 +225,9 @@ void CodeGenFunction::EmitStmt(const Stmt *S, ArrayRef<const Attr *> Attrs) {
   case Stmt::OMPReverseDirectiveClass:
     EmitOMPReverseDirective(cast<OMPReverseDirective>(*S));
     break;
+  case Stmt::OMPInterchangeDirectiveClass:
+    EmitOMPInterchangeDirective(cast<OMPInterchangeDirective>(*S));
+    break;
   case Stmt::OMPForDirectiveClass:
     EmitOMPForDirective(cast<OMPForDirective>(*S));
     break;
diff --git a/clang/lib/CodeGen/CGStmtOpenMP.cpp b/clang/lib/CodeGen/CGStmtOpenMP.cpp
index ad6c044aa483b..7a37e452fb559 100644
--- a/clang/lib/CodeGen/CGStmtOpenMP.cpp
+++ b/clang/lib/CodeGen/CGStmtOpenMP.cpp
@@ -189,6 +189,9 @@ class OMPLoopScope : public CodeGenFunction::RunCleanupsScope {
       PreInits = Unroll->getPreInits();
     } else if (const auto *Reverse = dyn_cast<OMPReverseDirective>(&S)) {
       PreInits = Reverse->getPreInits();
+    } else if (const auto *Interchange =
+                   dyn_cast<OMPInterchangeDirective>(&S)) {
+      PreInits = Interchange->getPreInits();
     } else {
       llvm_unreachable("Unknown loop-based directive kind.");
     }
@@ -2770,6 +2773,13 @@ void CodeGenFunction::EmitOMPReverseDirective(const OMPReverseDirective &S) {
   EmitStmt(S.getTransformedStmt());
 }
 
+void CodeGenFunction::EmitOMPInterchangeDirective(
+    const OMPInterchangeDirective &S) {
+  // Emit the de-sugared statement.
+  OMPTransformDirectiveScopeRAII InterchangeScope(*this, &S);
+  EmitStmt(S.getTransformedStmt());
+}
+
 void CodeGenFunction::EmitOMPUnrollDirective(const OMPUnrollDirective &S) {
   bool UseOMPIRBuilder = CGM.getLangOpts().OpenMPIRBuilder;
 
diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h
index ac738e1e82886..c2a8e65ca2d0a 100644
--- a/clang/lib/CodeGen/CodeGenFunction.h
+++ b/clang/lib/CodeGen/CodeGenFunction.h
@@ -3808,6 +3808,7 @@ class CodeGenFunction : public CodeGenTypeCache {
   void EmitOMPTileDirective(const OMPTileDirective &S);
   void EmitOMPUnrollDirective(const OMPUnrollDirective &S);
   void EmitOMPReverseDirective(const OMPReverseDirective &S);
+  void EmitOMPInterchangeDirective(const OMPInterchangeDirective &S);
   void EmitOMPForDirective(const OMPForDirective &S);
   void EmitOMPForSimdDirective(const OMPForSimdDirective &S);
   void EmitOMPSectionsDirective(const OMPSectionsDirective &S);
diff --git a/clang/lib/Parse/ParseOpenMP.cpp b/clang/lib/Parse/ParseOpenMP.cpp
index 57fcf6ce520ac..0e3e604203c86 100644
--- a/clang/lib/Parse/ParseOpenMP.cpp
+++ b/clang/lib/Parse/ParseOpenMP.cpp
@@ -2385,6 +2385,7 @@ Parser::DeclGroupPtrTy Parser::ParseOpenMPDeclarativeDirectiveWithExtDecl(
   case OMPD_tile:
   case OMPD_unroll:
   case OMPD_reverse:
+  case OMPD_interchange:
   case OMPD_task:
   case OMPD_taskyield:
   case OMPD_barrier:
@@ -2804,6 +2805,7 @@ StmtResult Parser::ParseOpenMPDeclarativeOrExecutableDirective(
   case OMPD_tile:
   case OMPD_unroll:
   case OMPD_reverse:
+  case OMPD_interchange:
   case OMPD_for:
   case OMPD_for_simd:
   case OMPD_sections:
diff --git a/clang/lib/Sema/SemaExceptionSpec.cpp b/clang/lib/Sema/SemaExceptionSpec.cpp
index 4de7183cde281..5991f496d3a0f 100644
--- a/clang/lib/Sema/SemaExceptionSpec.cpp
+++ b/clang/lib/Sema/SemaExceptionSpec.cpp
@@ -1487,6 +1487,7 @@ CanThrowResult Sema::canThrow(const Stmt *S) {
   case Stmt::OMPTileDirectiveClass:
   case Stmt::OMPUnrollDirectiveClass:
   case Stmt::OMPReverseDirectiveClass:
+  case Stmt::OMPInterchangeDirectiveClass:
   case Stmt::OMPSingleDirectiveClass:
   case Stmt::OMPTargetDataDirectiveClass:
   case Stmt::OMPTargetDirectiveClass:
diff --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp
index c2fd4de933ae4..ef141003b7d61 100644
--- a/clang/lib/Sema/SemaOpenMP.cpp
+++ b/clang/lib/Sema/SemaOpenMP.cpp
@@ -39,6 +39,7 @@
 #include "llvm/ADT/IndexedMap.h"
 #include "llvm/ADT/PointerEmbeddedInt.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/Sequence.h"
 #include "llvm/ADT/SmallSet.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/Frontend/OpenMP/OMPAssume.h"
@@ -4335,6 +4336,7 @@ void SemaOpenMP::ActOnOpenMPRegionStart(OpenMPDirectiveKind DKind,
   case OMPD_tile:
   case OMPD_unroll:
   case OMPD_reverse:
+  case OMPD_interchange:
     break;
   case OMPD_loop:
     // TODO: 'loop' may require additional parameters depending on the binding.
@@ -6552,6 +6554,10 @@ StmtResult SemaOpenMP::ActOnOpenMPExecutableDirective(
            "reverse directive does not support any clauses");
     Res = ActOnOpenMPReverseDirective(AStmt, StartLoc, EndLoc);
     break;
+  case OMPD_interchange:
+    Res = ActOnOpenMPInterchangeDirective(ClausesWithImplicit, AStmt, StartLoc,
+                                          EndLoc);
+    break;
   case OMPD_for:
     Res = ActOnOpenMPForDirective(ClausesWithImplicit, AStmt, StartLoc, EndLoc,
                                   VarsWithInheritedDSA);
@@ -15139,6 +15145,8 @@ bool SemaOpenMP::checkTransformableLoopNest(
           DependentPreInits = Dir->getPreInits();
         else if (auto *Dir = dyn_cast<OMPReverseDirective>(Transform))
           DependentPreInits = Dir->getPreInits();
+        else if (auto *Dir = dyn_cast<OMPInterchangeDirective>(Transform))
+          DependentPreInits = Dir->getPreInits();
         else
           llvm_unreachable("Unhandled loop transformation");
 
@@ -15937,6 +15945,160 @@ StmtResult SemaOpenMP::ActOnOpenMPReverseDirective(Stmt *AStmt,
                                      buildPreInits(Context, PreInits));
 }
 
+StmtResult SemaOpenMP::ActOnOpenMPInterchangeDirective(
+    ArrayRef<OMPClause *> Clauses, Stmt *AStmt, SourceLocation StartLoc,
+    SourceLocation EndLoc) {
+  ASTContext &Context = getASTContext();
+  DeclContext *CurContext = SemaRef.CurContext;
+  Scope *CurScope = SemaRef.getCurScope();
+
+  // Empty statement should only be possible if there already was an error.
+  if (!AStmt)
+    return StmtError();
+
+  // interchange without permutation clause swaps two loops.
+  constexpr size_t NumLoops = 2;
+
+  // Verify and diagnose loop nest.
+  SmallVector<OMPLoopBasedDirective::HelperExprs, 4> LoopHelpers(NumLoops);
+  Stmt *Body = nullptr;
+  SmallVector<SmallVector<Stmt *, 0>, 2> OriginalInits;
+  if (!checkTransformableLoopNest(OMPD_interchange, AStmt, NumLoops,
+                                  LoopHelpers, Body, OriginalInits))
+    return StmtError();
+
+  // Delay interchange to when template is completely instantiated.
+  if (CurContext->isDependentContext())
+    return OMPInterchangeDirective::Create(Context, StartLoc, EndLoc, Clauses,
+                                           NumLoops, AStmt, nullptr, nullptr);
+
+  assert(LoopHelpers.size() == NumLoops &&
+         "Expecting loop iteration space dimensionaly to match number of "
+         "affected loops");
+  assert(OriginalInits.size() == NumLoops &&
+         "Expecting loop iteration space dimensionaly to match number of "
+         "affected loops");
+
+  // Decode the permutation clause.
+  constexpr uint64_t Permutation[] = {1, 0};
+
+  // Find the affected loops.
+  SmallVector<Stmt *> LoopStmts(NumLoops, nullptr);
+  collectLoopStmts(AStmt, LoopStmts);
+
+  // Collect pre-init statements in the order before the permutation.
+  SmallVector<Stmt *> PreInits;
+  for (auto I : llvm::seq<int>(NumLoops)) {
+    OMPLoopBasedDirective::HelperExprs &LoopHelper = LoopHelpers[I];
+
+    assert(LoopHelper.Counters.size() == 1 &&
+           "Single-dimensional loop iteration space expected");
+    auto *OrigCntVar = cast<DeclRefExpr>(LoopHelper.Counters.front());
+
+    std::string OrigVarName = OrigCntVar->getNameInfo().getAsString();
+    addLoopPreInits(Context, LoopHelper, LoopStmts[I], OriginalInits[I],
+                    PreInits);
+  }
+
+  SmallVector<VarDecl *> PermutedIndVars;
+  PermutedIndVars.resize(NumLoops);
+  CaptureVars CopyTransformer(SemaRef);
+
+  // Create the permuted loops from the inside to the outside of the
+  // interchanged loop nest. Body of the innermost new loop is the original
+  // innermost body.
+  Stmt *Inner = Body;
+  for (auto TargetIdx : llvm::reverse(llvm::seq<int>(NumLoops))) {
+    // Get the original loop that belongs to this new position.
+    uint64_t SourceIdx = Permutation[TargetIdx];
+    OMPLoopBasedDirective::HelperExprs &SourceHelper = LoopHelpers[SourceIdx];
+    Stmt *SourceLoopStmt = LoopStmts[SourceIdx];
+    assert(SourceHelper.Counters.size() == 1 &&
+           "Single-dimensional loop iteration space expected");
+    auto *OrigCntVar = cast<DeclRefExpr>(SourceHelper.Counters.front());
+
+    // Normalized loop counter variable: From 0 to n-1, always an integer type.
+    DeclRefExpr *IterVarRef = cast<DeclRefExpr>(SourceHelper.IterationVarRef);
+    QualType IVTy = IterVarRef->getType();
+    assert(IVTy->isIntegerType() &&
+           "Expected the logical iteration counter to be an integer");
+
+    std::string OrigVarName = OrigCntVar->getNameInfo().getAsString();
+    SourceLocation OrigVarLoc = IterVarRef->getExprLoc();
+
+    // Make a copy of the NumIterations expression for each use: By the AST
+    // constraints, every expression object in a DeclContext must be unique.
+    auto MakeNumIterations = [&CopyTransformer, &SourceHelper]() -> Expr * {
+      return AssertSuccess(
+          CopyTransformer.TransformExpr(SourceHelper.NumIterations));
+    };
+
+    // Iteration variable for the permuted loop. Reuse the one from
+    // checkOpenMPLoop which will also be used to update the original loop
+    // variable.
+    std::string PermutedCntName =
+        (Twine(".permuted_") + llvm::utostr(TargetIdx) + ".iv." + OrigVarName)
+            .str();
+    auto *PermutedCntDecl = cast<VarDecl>(IterVarRef->getDecl());
+    PermutedCntDecl->setDeclName(
+        &SemaRef.P...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/93022


More information about the llvm-branch-commits mailing list