[llvm-branch-commits] [clang] [llvm] [openmp] [Clang][OpenMP] Add interchange directive (PR #93022)
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Wed May 22 04:44:30 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-clang-codegen
@llvm/pr-subscribers-flang-openmp
Author: Michael Kruse (Meinersbur)
<details>
<summary>Changes</summary>
Add the interchange directive, which will be introduced in the upcoming OpenMP 6.0 specification. A preview has been published in [Technical Report 12](https://www.openmp.org/wp-content/uploads/openmp-TR12.pdf).
---
Patch is 185.81 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/93022.diff
30 Files Affected:
- (modified) clang/include/clang-c/Index.h (+4)
- (modified) clang/include/clang/AST/RecursiveASTVisitor.h (+3)
- (modified) clang/include/clang/AST/StmtOpenMP.h (+75-1)
- (modified) clang/include/clang/Basic/StmtNodes.td (+1)
- (modified) clang/include/clang/Sema/SemaOpenMP.h (+6)
- (modified) clang/include/clang/Serialization/ASTBitCodes.h (+1)
- (modified) clang/lib/AST/StmtOpenMP.cpp (+20)
- (modified) clang/lib/AST/StmtPrinter.cpp (+5)
- (modified) clang/lib/AST/StmtProfile.cpp (+5)
- (modified) clang/lib/Basic/OpenMPKinds.cpp (+3-1)
- (modified) clang/lib/CodeGen/CGStmt.cpp (+3)
- (modified) clang/lib/CodeGen/CGStmtOpenMP.cpp (+10)
- (modified) clang/lib/CodeGen/CodeGenFunction.h (+1)
- (modified) clang/lib/Parse/ParseOpenMP.cpp (+2)
- (modified) clang/lib/Sema/SemaExceptionSpec.cpp (+1)
- (modified) clang/lib/Sema/SemaOpenMP.cpp (+162)
- (modified) clang/lib/Sema/TreeTransform.h (+11)
- (modified) clang/lib/Serialization/ASTReaderStmt.cpp (+11)
- (modified) clang/lib/Serialization/ASTWriterStmt.cpp (+5)
- (added) clang/test/OpenMP/interchange_ast_print.cpp (+135)
- (added) clang/test/OpenMP/interchange_codegen.cpp (+1990)
- (added) clang/test/OpenMP/interchange_messages.cpp (+77)
- (modified) clang/tools/libclang/CIndex.cpp (+8)
- (modified) clang/tools/libclang/CXCursor.cpp (+3)
- (modified) llvm/include/llvm/Frontend/OpenMP/OMP.td (+3)
- (added) openmp/runtime/test/transform/interchange/foreach.cpp (+216)
- (added) openmp/runtime/test/transform/interchange/intfor.c (+38)
- (added) openmp/runtime/test/transform/interchange/iterfor.cpp (+222)
- (added) openmp/runtime/test/transform/interchange/parallel-wsloop-collapse-foreach.cpp (+340)
- (added) openmp/runtime/test/transform/interchange/parallel-wsloop-collapse-intfor.cpp (+106)
``````````diff
diff --git a/clang/include/clang-c/Index.h b/clang/include/clang-c/Index.h
index c7d63818ece23..a79aafbf20222 100644
--- a/clang/include/clang-c/Index.h
+++ b/clang/include/clang-c/Index.h
@@ -2150,6 +2150,10 @@ enum CXCursorKind {
*/
CXCursor_OMPReverseDirective = 307,
+ /** OpenMP interchange directive.
+ */
+ CXCursor_OMPInterchangeDirective = 308,
+
/** OpenACC Compute Construct.
*/
CXCursor_OpenACCComputeConstruct = 320,
diff --git a/clang/include/clang/AST/RecursiveASTVisitor.h b/clang/include/clang/AST/RecursiveASTVisitor.h
index 06b29d59785f6..1bb167d7ddc3c 100644
--- a/clang/include/clang/AST/RecursiveASTVisitor.h
+++ b/clang/include/clang/AST/RecursiveASTVisitor.h
@@ -3024,6 +3024,9 @@ DEF_TRAVERSE_STMT(OMPUnrollDirective,
DEF_TRAVERSE_STMT(OMPReverseDirective,
{ TRY_TO(TraverseOMPExecutableDirective(S)); })
+DEF_TRAVERSE_STMT(OMPInterchangeDirective,
+ { TRY_TO(TraverseOMPExecutableDirective(S)); })
+
DEF_TRAVERSE_STMT(OMPForDirective,
{ TRY_TO(TraverseOMPExecutableDirective(S)); })
diff --git a/clang/include/clang/AST/StmtOpenMP.h b/clang/include/clang/AST/StmtOpenMP.h
index fb7f413162fad..01c8b8e1a9f5e 100644
--- a/clang/include/clang/AST/StmtOpenMP.h
+++ b/clang/include/clang/AST/StmtOpenMP.h
@@ -1009,7 +1009,7 @@ class OMPLoopTransformationDirective : public OMPLoopBasedDirective {
static bool classof(const Stmt *T) {
Stmt::StmtClass C = T->getStmtClass();
return C == OMPTileDirectiveClass || C == OMPUnrollDirectiveClass ||
- C == OMPReverseDirectiveClass;
+ C == OMPReverseDirectiveClass || C == OMPInterchangeDirectiveClass;
}
};
@@ -5777,6 +5777,80 @@ class OMPReverseDirective final : public OMPLoopTransformationDirective {
}
};
+/// Represents the '#pragma omp interchange' loop transformation directive.
+///
+/// \code{c}
+/// #pragma omp interchange
+/// for (int i = 0; i < m; ++i)
+/// for (int j = 0; j < n; ++j)
+/// ..
+/// \endcode
+class OMPInterchangeDirective final : public OMPLoopTransformationDirective {
+ friend class ASTStmtReader;
+ friend class OMPExecutableDirective;
+
+ /// Offsets of child members.
+ enum {
+ PreInitsOffset = 0,
+ TransformedStmtOffset,
+ };
+
+ explicit OMPInterchangeDirective(SourceLocation StartLoc,
+ SourceLocation EndLoc, unsigned NumLoops)
+ : OMPLoopTransformationDirective(OMPInterchangeDirectiveClass,
+ llvm::omp::OMPD_interchange, StartLoc,
+ EndLoc, NumLoops) {
+ setNumGeneratedLoops(3 * NumLoops);
+ }
+
+ void setPreInits(Stmt *PreInits) {
+ Data->getChildren()[PreInitsOffset] = PreInits;
+ }
+
+ void setTransformedStmt(Stmt *S) {
+ Data->getChildren()[TransformedStmtOffset] = S;
+ }
+
+public:
+ /// Create a new AST node representation for '#pragma omp interchange'.
+ ///
+ /// \param C Context of the AST.
+ /// \param StartLoc Location of the introducer (e.g. the 'omp' token).
+ /// \param EndLoc Location of the directive's end (e.g. the tok::eod).
+ /// \param Clauses The directive's clauses.
+ /// \param NumLoops Number of affected loops
+ /// (number of items in the 'permutation' clause if present).
+ /// \param AssociatedStmt The outermost associated loop.
+ /// \param TransformedStmt The loop nest after interchange, or nullptr in
+ /// dependent contexts.
+ /// \param PreInits Helper preinits statements for the loop nest.
+ static OMPInterchangeDirective *
+ Create(const ASTContext &C, SourceLocation StartLoc, SourceLocation EndLoc,
+ ArrayRef<OMPClause *> Clauses, unsigned NumLoops, Stmt *AssociatedStmt,
+ Stmt *TransformedStmt, Stmt *PreInits);
+
+ /// Build an empty '#pragma omp interchange' AST node for deserialization.
+ ///
+ /// \param C Context of the AST.
+ /// \param NumClauses Number of clauses to allocate.
+ /// \param NumLoops Number of associated loops to allocate.
+ static OMPInterchangeDirective *
+ CreateEmpty(const ASTContext &C, unsigned NumClauses, unsigned NumLoops);
+
+ /// Gets the associated loops after the transformation. This is the de-sugared
+ /// replacement or nullptr in dependent contexts.
+ Stmt *getTransformedStmt() const {
+ return Data->getChildren()[TransformedStmtOffset];
+ }
+
+ /// Return preinits statement.
+ Stmt *getPreInits() const { return Data->getChildren()[PreInitsOffset]; }
+
+ static bool classof(const Stmt *T) {
+ return T->getStmtClass() == OMPInterchangeDirectiveClass;
+ }
+};
+
/// This represents '#pragma omp scan' directive.
///
/// \code
diff --git a/clang/include/clang/Basic/StmtNodes.td b/clang/include/clang/Basic/StmtNodes.td
index b2e2be5c998bb..b445ea225eac5 100644
--- a/clang/include/clang/Basic/StmtNodes.td
+++ b/clang/include/clang/Basic/StmtNodes.td
@@ -230,6 +230,7 @@ def OMPLoopTransformationDirective : StmtNode<OMPLoopBasedDirective, 1>;
def OMPTileDirective : StmtNode<OMPLoopTransformationDirective>;
def OMPUnrollDirective : StmtNode<OMPLoopTransformationDirective>;
def OMPReverseDirective : StmtNode<OMPLoopTransformationDirective>;
+def OMPInterchangeDirective : StmtNode<OMPLoopTransformationDirective>;
def OMPForDirective : StmtNode<OMPLoopDirective>;
def OMPForSimdDirective : StmtNode<OMPLoopDirective>;
def OMPSectionsDirective : StmtNode<OMPExecutableDirective>;
diff --git a/clang/include/clang/Sema/SemaOpenMP.h b/clang/include/clang/Sema/SemaOpenMP.h
index ca91bffe24f6f..06376f173e8df 100644
--- a/clang/include/clang/Sema/SemaOpenMP.h
+++ b/clang/include/clang/Sema/SemaOpenMP.h
@@ -425,6 +425,12 @@ class SemaOpenMP : public SemaBase {
/// Called on well-formed '#pragma omp reverse'.
StmtResult ActOnOpenMPReverseDirective(Stmt *AStmt, SourceLocation StartLoc,
SourceLocation EndLoc);
+ /// Called on well-formed '#pragma omp interchange' after parsing of its
+ /// clauses and the associated statement.
+ StmtResult ActOnOpenMPInterchangeDirective(ArrayRef<OMPClause *> Clauses,
+ Stmt *AStmt,
+ SourceLocation StartLoc,
+ SourceLocation EndLoc);
/// Called on well-formed '\#pragma omp for' after parsing
/// of the associated statement.
StmtResult
diff --git a/clang/include/clang/Serialization/ASTBitCodes.h b/clang/include/clang/Serialization/ASTBitCodes.h
index dee0d073557cc..5fbdfd7a496fe 100644
--- a/clang/include/clang/Serialization/ASTBitCodes.h
+++ b/clang/include/clang/Serialization/ASTBitCodes.h
@@ -1857,6 +1857,7 @@ enum StmtCode {
STMT_OMP_TILE_DIRECTIVE,
STMT_OMP_UNROLL_DIRECTIVE,
STMT_OMP_REVERSE_DIRECTIVE,
+ STMT_OMP_INTERCHANGE_DIRECTIVE,
STMT_OMP_FOR_DIRECTIVE,
STMT_OMP_FOR_SIMD_DIRECTIVE,
STMT_OMP_SECTIONS_DIRECTIVE,
diff --git a/clang/lib/AST/StmtOpenMP.cpp b/clang/lib/AST/StmtOpenMP.cpp
index 83b8a08e9af73..24d8eb25c59ba 100644
--- a/clang/lib/AST/StmtOpenMP.cpp
+++ b/clang/lib/AST/StmtOpenMP.cpp
@@ -467,6 +467,26 @@ OMPReverseDirective *OMPReverseDirective::CreateEmpty(const ASTContext &C,
SourceLocation(), SourceLocation());
}
+OMPInterchangeDirective *OMPInterchangeDirective::Create(
+ const ASTContext &C, SourceLocation StartLoc, SourceLocation EndLoc,
+ ArrayRef<OMPClause *> Clauses, unsigned NumLoops, Stmt *AssociatedStmt,
+ Stmt *TransformedStmt, Stmt *PreInits) {
+ OMPInterchangeDirective *Dir = createDirective<OMPInterchangeDirective>(
+ C, Clauses, AssociatedStmt, TransformedStmtOffset + 1, StartLoc, EndLoc,
+ NumLoops);
+ Dir->setTransformedStmt(TransformedStmt);
+ Dir->setPreInits(PreInits);
+ return Dir;
+}
+
+OMPInterchangeDirective *
+OMPInterchangeDirective::CreateEmpty(const ASTContext &C, unsigned NumClauses,
+ unsigned NumLoops) {
+ return createEmptyDirective<OMPInterchangeDirective>(
+ C, NumClauses, /*HasAssociatedStmt=*/true, TransformedStmtOffset + 1,
+ SourceLocation(), SourceLocation(), NumLoops);
+}
+
OMPForSimdDirective *
OMPForSimdDirective::Create(const ASTContext &C, SourceLocation StartLoc,
SourceLocation EndLoc, unsigned CollapsedNum,
diff --git a/clang/lib/AST/StmtPrinter.cpp b/clang/lib/AST/StmtPrinter.cpp
index 64b481f680311..64bee75b205ae 100644
--- a/clang/lib/AST/StmtPrinter.cpp
+++ b/clang/lib/AST/StmtPrinter.cpp
@@ -768,6 +768,11 @@ void StmtPrinter::VisitOMPReverseDirective(OMPReverseDirective *Node) {
PrintOMPExecutableDirective(Node);
}
+void StmtPrinter::VisitOMPInterchangeDirective(OMPInterchangeDirective *Node) {
+ Indent() << "#pragma omp interchange";
+ PrintOMPExecutableDirective(Node);
+}
+
void StmtPrinter::VisitOMPForDirective(OMPForDirective *Node) {
Indent() << "#pragma omp for";
PrintOMPExecutableDirective(Node);
diff --git a/clang/lib/AST/StmtProfile.cpp b/clang/lib/AST/StmtProfile.cpp
index 7445e5519b972..1ae99d43575a7 100644
--- a/clang/lib/AST/StmtProfile.cpp
+++ b/clang/lib/AST/StmtProfile.cpp
@@ -989,6 +989,11 @@ void StmtProfiler::VisitOMPReverseDirective(const OMPReverseDirective *S) {
VisitOMPLoopTransformationDirective(S);
}
+void StmtProfiler::VisitOMPInterchangeDirective(
+ const OMPInterchangeDirective *S) {
+ VisitOMPLoopTransformationDirective(S);
+}
+
void StmtProfiler::VisitOMPForDirective(const OMPForDirective *S) {
VisitOMPLoopDirective(S);
}
diff --git a/clang/lib/Basic/OpenMPKinds.cpp b/clang/lib/Basic/OpenMPKinds.cpp
index 803808c38e2fe..ff5d5c8bdc981 100644
--- a/clang/lib/Basic/OpenMPKinds.cpp
+++ b/clang/lib/Basic/OpenMPKinds.cpp
@@ -684,7 +684,8 @@ bool clang::isOpenMPLoopBoundSharingDirective(OpenMPDirectiveKind Kind) {
}
bool clang::isOpenMPLoopTransformationDirective(OpenMPDirectiveKind DKind) {
- return DKind == OMPD_tile || DKind == OMPD_unroll || DKind == OMPD_reverse;
+ return DKind == OMPD_tile || DKind == OMPD_unroll || DKind == OMPD_reverse ||
+ DKind == OMPD_interchange;
}
bool clang::isOpenMPCombinedParallelADirective(OpenMPDirectiveKind DKind) {
@@ -809,6 +810,7 @@ void clang::getOpenMPCaptureRegions(
case OMPD_tile:
case OMPD_unroll:
case OMPD_reverse:
+ case OMPD_interchange:
// loop transformations do not introduce captures.
break;
case OMPD_threadprivate:
diff --git a/clang/lib/CodeGen/CGStmt.cpp b/clang/lib/CodeGen/CGStmt.cpp
index 93c2f8900dd12..ba7c52cc6ab7b 100644
--- a/clang/lib/CodeGen/CGStmt.cpp
+++ b/clang/lib/CodeGen/CGStmt.cpp
@@ -225,6 +225,9 @@ void CodeGenFunction::EmitStmt(const Stmt *S, ArrayRef<const Attr *> Attrs) {
case Stmt::OMPReverseDirectiveClass:
EmitOMPReverseDirective(cast<OMPReverseDirective>(*S));
break;
+ case Stmt::OMPInterchangeDirectiveClass:
+ EmitOMPInterchangeDirective(cast<OMPInterchangeDirective>(*S));
+ break;
case Stmt::OMPForDirectiveClass:
EmitOMPForDirective(cast<OMPForDirective>(*S));
break;
diff --git a/clang/lib/CodeGen/CGStmtOpenMP.cpp b/clang/lib/CodeGen/CGStmtOpenMP.cpp
index ad6c044aa483b..7a37e452fb559 100644
--- a/clang/lib/CodeGen/CGStmtOpenMP.cpp
+++ b/clang/lib/CodeGen/CGStmtOpenMP.cpp
@@ -189,6 +189,9 @@ class OMPLoopScope : public CodeGenFunction::RunCleanupsScope {
PreInits = Unroll->getPreInits();
} else if (const auto *Reverse = dyn_cast<OMPReverseDirective>(&S)) {
PreInits = Reverse->getPreInits();
+ } else if (const auto *Interchange =
+ dyn_cast<OMPInterchangeDirective>(&S)) {
+ PreInits = Interchange->getPreInits();
} else {
llvm_unreachable("Unknown loop-based directive kind.");
}
@@ -2770,6 +2773,13 @@ void CodeGenFunction::EmitOMPReverseDirective(const OMPReverseDirective &S) {
EmitStmt(S.getTransformedStmt());
}
+void CodeGenFunction::EmitOMPInterchangeDirective(
+ const OMPInterchangeDirective &S) {
+ // Emit the de-sugared statement.
+ OMPTransformDirectiveScopeRAII InterchangeScope(*this, &S);
+ EmitStmt(S.getTransformedStmt());
+}
+
void CodeGenFunction::EmitOMPUnrollDirective(const OMPUnrollDirective &S) {
bool UseOMPIRBuilder = CGM.getLangOpts().OpenMPIRBuilder;
diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h
index ac738e1e82886..c2a8e65ca2d0a 100644
--- a/clang/lib/CodeGen/CodeGenFunction.h
+++ b/clang/lib/CodeGen/CodeGenFunction.h
@@ -3808,6 +3808,7 @@ class CodeGenFunction : public CodeGenTypeCache {
void EmitOMPTileDirective(const OMPTileDirective &S);
void EmitOMPUnrollDirective(const OMPUnrollDirective &S);
void EmitOMPReverseDirective(const OMPReverseDirective &S);
+ void EmitOMPInterchangeDirective(const OMPInterchangeDirective &S);
void EmitOMPForDirective(const OMPForDirective &S);
void EmitOMPForSimdDirective(const OMPForSimdDirective &S);
void EmitOMPSectionsDirective(const OMPSectionsDirective &S);
diff --git a/clang/lib/Parse/ParseOpenMP.cpp b/clang/lib/Parse/ParseOpenMP.cpp
index 57fcf6ce520ac..0e3e604203c86 100644
--- a/clang/lib/Parse/ParseOpenMP.cpp
+++ b/clang/lib/Parse/ParseOpenMP.cpp
@@ -2385,6 +2385,7 @@ Parser::DeclGroupPtrTy Parser::ParseOpenMPDeclarativeDirectiveWithExtDecl(
case OMPD_tile:
case OMPD_unroll:
case OMPD_reverse:
+ case OMPD_interchange:
case OMPD_task:
case OMPD_taskyield:
case OMPD_barrier:
@@ -2804,6 +2805,7 @@ StmtResult Parser::ParseOpenMPDeclarativeOrExecutableDirective(
case OMPD_tile:
case OMPD_unroll:
case OMPD_reverse:
+ case OMPD_interchange:
case OMPD_for:
case OMPD_for_simd:
case OMPD_sections:
diff --git a/clang/lib/Sema/SemaExceptionSpec.cpp b/clang/lib/Sema/SemaExceptionSpec.cpp
index 4de7183cde281..5991f496d3a0f 100644
--- a/clang/lib/Sema/SemaExceptionSpec.cpp
+++ b/clang/lib/Sema/SemaExceptionSpec.cpp
@@ -1487,6 +1487,7 @@ CanThrowResult Sema::canThrow(const Stmt *S) {
case Stmt::OMPTileDirectiveClass:
case Stmt::OMPUnrollDirectiveClass:
case Stmt::OMPReverseDirectiveClass:
+ case Stmt::OMPInterchangeDirectiveClass:
case Stmt::OMPSingleDirectiveClass:
case Stmt::OMPTargetDataDirectiveClass:
case Stmt::OMPTargetDirectiveClass:
diff --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp
index c2fd4de933ae4..ef141003b7d61 100644
--- a/clang/lib/Sema/SemaOpenMP.cpp
+++ b/clang/lib/Sema/SemaOpenMP.cpp
@@ -39,6 +39,7 @@
#include "llvm/ADT/IndexedMap.h"
#include "llvm/ADT/PointerEmbeddedInt.h"
#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Frontend/OpenMP/OMPAssume.h"
@@ -4335,6 +4336,7 @@ void SemaOpenMP::ActOnOpenMPRegionStart(OpenMPDirectiveKind DKind,
case OMPD_tile:
case OMPD_unroll:
case OMPD_reverse:
+ case OMPD_interchange:
break;
case OMPD_loop:
// TODO: 'loop' may require additional parameters depending on the binding.
@@ -6552,6 +6554,10 @@ StmtResult SemaOpenMP::ActOnOpenMPExecutableDirective(
"reverse directive does not support any clauses");
Res = ActOnOpenMPReverseDirective(AStmt, StartLoc, EndLoc);
break;
+ case OMPD_interchange:
+ Res = ActOnOpenMPInterchangeDirective(ClausesWithImplicit, AStmt, StartLoc,
+ EndLoc);
+ break;
case OMPD_for:
Res = ActOnOpenMPForDirective(ClausesWithImplicit, AStmt, StartLoc, EndLoc,
VarsWithInheritedDSA);
@@ -15139,6 +15145,8 @@ bool SemaOpenMP::checkTransformableLoopNest(
DependentPreInits = Dir->getPreInits();
else if (auto *Dir = dyn_cast<OMPReverseDirective>(Transform))
DependentPreInits = Dir->getPreInits();
+ else if (auto *Dir = dyn_cast<OMPInterchangeDirective>(Transform))
+ DependentPreInits = Dir->getPreInits();
else
llvm_unreachable("Unhandled loop transformation");
@@ -15937,6 +15945,160 @@ StmtResult SemaOpenMP::ActOnOpenMPReverseDirective(Stmt *AStmt,
buildPreInits(Context, PreInits));
}
+StmtResult SemaOpenMP::ActOnOpenMPInterchangeDirective(
+ ArrayRef<OMPClause *> Clauses, Stmt *AStmt, SourceLocation StartLoc,
+ SourceLocation EndLoc) {
+ ASTContext &Context = getASTContext();
+ DeclContext *CurContext = SemaRef.CurContext;
+ Scope *CurScope = SemaRef.getCurScope();
+
+ // Empty statement should only be possible if there already was an error.
+ if (!AStmt)
+ return StmtError();
+
+ // interchange without permutation clause swaps two loops.
+ constexpr size_t NumLoops = 2;
+
+ // Verify and diagnose loop nest.
+ SmallVector<OMPLoopBasedDirective::HelperExprs, 4> LoopHelpers(NumLoops);
+ Stmt *Body = nullptr;
+ SmallVector<SmallVector<Stmt *, 0>, 2> OriginalInits;
+ if (!checkTransformableLoopNest(OMPD_interchange, AStmt, NumLoops,
+ LoopHelpers, Body, OriginalInits))
+ return StmtError();
+
+ // Delay interchange to when template is completely instantiated.
+ if (CurContext->isDependentContext())
+ return OMPInterchangeDirective::Create(Context, StartLoc, EndLoc, Clauses,
+ NumLoops, AStmt, nullptr, nullptr);
+
+ assert(LoopHelpers.size() == NumLoops &&
+ "Expecting loop iteration space dimensionaly to match number of "
+ "affected loops");
+ assert(OriginalInits.size() == NumLoops &&
+ "Expecting loop iteration space dimensionaly to match number of "
+ "affected loops");
+
+ // Decode the permutation clause.
+ constexpr uint64_t Permutation[] = {1, 0};
+
+ // Find the affected loops.
+ SmallVector<Stmt *> LoopStmts(NumLoops, nullptr);
+ collectLoopStmts(AStmt, LoopStmts);
+
+ // Collect pre-init statements in the order before the permutation.
+ SmallVector<Stmt *> PreInits;
+ for (auto I : llvm::seq<int>(NumLoops)) {
+ OMPLoopBasedDirective::HelperExprs &LoopHelper = LoopHelpers[I];
+
+ assert(LoopHelper.Counters.size() == 1 &&
+ "Single-dimensional loop iteration space expected");
+ auto *OrigCntVar = cast<DeclRefExpr>(LoopHelper.Counters.front());
+
+ std::string OrigVarName = OrigCntVar->getNameInfo().getAsString();
+ addLoopPreInits(Context, LoopHelper, LoopStmts[I], OriginalInits[I],
+ PreInits);
+ }
+
+ SmallVector<VarDecl *> PermutedIndVars;
+ PermutedIndVars.resize(NumLoops);
+ CaptureVars CopyTransformer(SemaRef);
+
+ // Create the permuted loops from the inside to the outside of the
+ // interchanged loop nest. Body of the innermost new loop is the original
+ // innermost body.
+ Stmt *Inner = Body;
+ for (auto TargetIdx : llvm::reverse(llvm::seq<int>(NumLoops))) {
+ // Get the original loop that belongs to this new position.
+ uint64_t SourceIdx = Permutation[TargetIdx];
+ OMPLoopBasedDirective::HelperExprs &SourceHelper = LoopHelpers[SourceIdx];
+ Stmt *SourceLoopStmt = LoopStmts[SourceIdx];
+ assert(SourceHelper.Counters.size() == 1 &&
+ "Single-dimensional loop iteration space expected");
+ auto *OrigCntVar = cast<DeclRefExpr>(SourceHelper.Counters.front());
+
+ // Normalized loop counter variable: From 0 to n-1, always an integer type.
+ DeclRefExpr *IterVarRef = cast<DeclRefExpr>(SourceHelper.IterationVarRef);
+ QualType IVTy = IterVarRef->getType();
+ assert(IVTy->isIntegerType() &&
+ "Expected the logical iteration counter to be an integer");
+
+ std::string OrigVarName = OrigCntVar->getNameInfo().getAsString();
+ SourceLocation OrigVarLoc = IterVarRef->getExprLoc();
+
+ // Make a copy of the NumIterations expression for each use: By the AST
+ // constraints, every expression object in a DeclContext must be unique.
+ auto MakeNumIterations = [&CopyTransformer, &SourceHelper]() -> Expr * {
+ return AssertSuccess(
+ CopyTransformer.TransformExpr(SourceHelper.NumIterations));
+ };
+
+ // Iteration variable for the permuted loop. Reuse the one from
+ // checkOpenMPLoop which will also be used to update the original loop
+ // variable.
+ std::string PermutedCntName =
+ (Twine(".permuted_") + llvm::utostr(TargetIdx) + ".iv." + OrigVarName)
+ .str();
+ auto *PermutedCntDecl = cast<VarDecl>(IterVarRef->getDecl());
+ PermutedCntDecl->setDeclName(
+ &SemaRef.P...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/93022
More information about the llvm-branch-commits
mailing list