[Openmp-commits] [clang] [flang] [llvm] [openmp] [Clang][OpenMP][LoopTransformations] Add support for "#pragma omp fuse" loop transformation directive and "looprange" clause (PR #139293)
Alexey Bataev via Openmp-commits
openmp-commits at lists.llvm.org
Fri May 9 11:12:09 PDT 2025
================
@@ -14175,27 +14222,350 @@ bool SemaOpenMP::checkTransformableLoopNest(
return false;
},
[&OriginalInits](OMPLoopBasedDirective *Transform) {
- Stmt *DependentPreInits;
- if (auto *Dir = dyn_cast<OMPTileDirective>(Transform))
- DependentPreInits = Dir->getPreInits();
- else if (auto *Dir = dyn_cast<OMPStripeDirective>(Transform))
- DependentPreInits = Dir->getPreInits();
- else if (auto *Dir = dyn_cast<OMPUnrollDirective>(Transform))
- DependentPreInits = Dir->getPreInits();
- else if (auto *Dir = dyn_cast<OMPReverseDirective>(Transform))
- DependentPreInits = Dir->getPreInits();
- else if (auto *Dir = dyn_cast<OMPInterchangeDirective>(Transform))
- DependentPreInits = Dir->getPreInits();
- else
- llvm_unreachable("Unhandled loop transformation");
-
- appendFlattenedStmtList(OriginalInits.back(), DependentPreInits);
+ updatePreInits(Transform, OriginalInits);
});
assert(OriginalInits.back().empty() && "No preinit after innermost loop");
OriginalInits.pop_back();
return Result;
}
+// Counts the total number of nested loops, including the outermost loop (the
+// original loop). PRECONDITION: this visitor must be invoked on the original
+// loop to be analyzed. Traversal stops at Decls and Exprs, because they may
+// contain inner loops (e.g. inside lambdas, local classes or statement
+// expressions) that must not be counted. Only CompoundStmts and loop
+// statements (ForStmt, CXXForRangeStmt) are descended into; any other
+// statement kind ends the traversal of that subtree.
+//
+// Example AST structure for the code:
+//
+// int main() {
+//   #pragma omp fuse
+//   {
+//     for (int i = 0; i < 100; i++) {                       <-- Outer loop
+//       []() {
+//         for (int j = 0; j < 100; j++) {}                  <-- NOT A LOOP
+//       };
+//       for (int j = 0; j < 5; ++j) {}                      <-- Inner loop
+//     }
+//     for (int r = 0; r < 100; r++) {                       <-- Outer loop
+//       struct LocalClass {
+//         void bar() {
+//           for (int j = 0; j < 100; j++) {}                <-- NOT A LOOP
+//         }
+//       };
+//       for (int k = 0; k < 10; ++k) {}                     <-- Inner loop
+//       ({ x = 5; for (k = 0; k < 10; ++k) x += k; x; });   <-- NOT A LOOP
+//     }
+//   }
+// }
+// Result: Loop 'i' contains 2 loops, Loop 'r' also contains 2 loops
+class NestedLoopCounterVisitor : public DynamicRecursiveASTVisitor {
+private:
+  unsigned NestedLoopCount = 0;
+
+public:
+  /// Returns the number of ForStmt / CXXForRangeStmt nodes visited so far.
+  unsigned getNestedLoopCount() const { return NestedLoopCount; }
+
+  bool VisitForStmt(ForStmt *) override {
+    ++NestedLoopCount;
+    return true;
+  }
+
+  bool VisitCXXForRangeStmt(CXXForRangeStmt *) override {
+    ++NestedLoopCount;
+    return true;
+  }
+
+  bool TraverseStmt(Stmt *S) override {
+    if (!S)
+      return true;
+
+    // Skip all expressions, including special cases such as LambdaExpr,
+    // StmtExpr, BlockExpr and RequiresExpr. They may contain statements (and
+    // even loops), but those are not part of the syntactic body of the
+    // surrounding loop structure, so they must not be counted.
+    if (isa<Expr>(S))
+      return true;
+
+    // Only recurse into compound statements ({ ... }) and loop statements.
+    if (isa<CompoundStmt, ForStmt, CXXForRangeStmt>(S))
+      return DynamicRecursiveASTVisitor::TraverseStmt(S);
+
+    // Any other statement (IfStmt, SwitchStmt, WhileStmt, ...) breaks perfect
+    // loop nesting; do not look inside it.
+    return true;
+  }
+
+  bool TraverseDecl(Decl *) override {
+    // Never descend into declarations (CXXRecordDecl, RecordDecl,
+    // FunctionDecl, ...): loops inside them are not part of the loop nest.
+    return true;
+  }
+};
+
+bool SemaOpenMP::analyzeLoopSequence(
+ Stmt *LoopSeqStmt, unsigned &LoopSeqSize, unsigned &NumLoops,
+ SmallVectorImpl<OMPLoopBasedDirective::HelperExprs> &LoopHelpers,
+ SmallVectorImpl<Stmt *> &ForStmts,
+ SmallVectorImpl<SmallVector<Stmt *, 0>> &OriginalInits,
+ SmallVectorImpl<SmallVector<Stmt *, 0>> &TransformsPreInits,
+ SmallVectorImpl<SmallVector<Stmt *, 0>> &LoopSequencePreInits,
+ SmallVectorImpl<OMPLoopCategory> &LoopCategories, ASTContext &Context,
+ OpenMPDirectiveKind Kind) {
+
+ VarsWithInheritedDSAType TmpDSA;
+ QualType BaseInductionVarType;
+ // Helper Lambda to handle storing initialization and body statements for both
+ // ForStmt and CXXForRangeStmt and checks for any possible mismatch between
+ // induction variables types
+ auto storeLoopStatements = [&OriginalInits, &ForStmts, &BaseInductionVarType,
+ this, &Context](Stmt *LoopStmt) {
+ if (auto *For = dyn_cast<ForStmt>(LoopStmt)) {
+ OriginalInits.back().push_back(For->getInit());
+ ForStmts.push_back(For);
+ // Extract induction variable
+ if (auto *InitStmt = dyn_cast_or_null<DeclStmt>(For->getInit())) {
+ if (auto *InitDecl = dyn_cast<VarDecl>(InitStmt->getSingleDecl())) {
+ QualType InductionVarType = InitDecl->getType().getCanonicalType();
+
+ // Compare with first loop type
+ if (BaseInductionVarType.isNull()) {
+ BaseInductionVarType = InductionVarType;
+ } else if (!Context.hasSameType(BaseInductionVarType,
+ InductionVarType)) {
+ Diag(InitDecl->getBeginLoc(),
+ diag::warn_omp_different_loop_ind_var_types)
+ << getOpenMPDirectiveName(OMPD_fuse) << BaseInductionVarType
+ << InductionVarType;
+ }
+ }
+ }
+ } else {
+ auto *CXXFor = cast<CXXForRangeStmt>(LoopStmt);
+ OriginalInits.back().push_back(CXXFor->getBeginStmt());
+ ForStmts.push_back(CXXFor);
+ }
+ };
+
+ // Helper lambda functions to encapsulate the processing of different
+ // derivations of the canonical loop sequence grammar
+ //
+ // Modularized code for handling loop generation and transformations
+ auto analyzeLoopGeneration = [&storeLoopStatements, &LoopHelpers,
+ &OriginalInits, &TransformsPreInits,
+ &LoopCategories, &LoopSeqSize, &NumLoops, Kind,
+ &TmpDSA, &ForStmts, &Context,
+ &LoopSequencePreInits, this](Stmt *Child) {
+ auto LoopTransform = dyn_cast<OMPLoopTransformationDirective>(Child);
+ Stmt *TransformedStmt = LoopTransform->getTransformedStmt();
+ unsigned NumGeneratedLoopNests = LoopTransform->getNumGeneratedLoopNests();
+ unsigned NumGeneratedLoops = LoopTransform->getNumGeneratedLoops();
+ // Handle the case where transformed statement is not available due to
+ // dependent contexts
+ if (!TransformedStmt) {
+ if (NumGeneratedLoopNests > 0) {
+ LoopSeqSize += NumGeneratedLoopNests;
+ NumLoops += NumGeneratedLoops;
+ return true;
+ }
+ // Unroll full (0 loops produced)
+ else {
+ Diag(Child->getBeginLoc(), diag::err_omp_not_for)
+ << 0 << getOpenMPDirectiveName(Kind);
+ return false;
+ }
+ }
+ // Handle loop transformations with multiple loop nests
+ // Unroll full
+ if (NumGeneratedLoopNests <= 0) {
+ Diag(Child->getBeginLoc(), diag::err_omp_not_for)
+ << 0 << getOpenMPDirectiveName(Kind);
+ return false;
+ }
+ // Loop transformatons such as split or loopranged fuse
+ else if (NumGeneratedLoopNests > 1) {
+ // Get the preinits related to this loop sequence generating
+ // loop transformation (i.e loopranged fuse, split...)
+ LoopSequencePreInits.emplace_back();
+ // These preinits differ slightly from regular inits/pre-inits related
+ // to single loop generating loop transformations (interchange, unroll)
+ // given that they are not bounded to a particular loop nest
+ // so they need to be treated independently
+ updatePreInits(LoopTransform, LoopSequencePreInits);
+ return analyzeLoopSequence(TransformedStmt, LoopSeqSize, NumLoops,
+ LoopHelpers, ForStmts, OriginalInits,
+ TransformsPreInits, LoopSequencePreInits,
+ LoopCategories, Context, Kind);
+ }
+ // Vast majority: (Tile, Unroll, Stripe, Reverse, Interchange, Fuse all)
+ else {
+ // Process the transformed loop statement
+ OriginalInits.emplace_back();
+ TransformsPreInits.emplace_back();
+ LoopHelpers.emplace_back();
+ LoopCategories.push_back(OMPLoopCategory::TransformSingleLoop);
+
+ unsigned IsCanonical =
+ checkOpenMPLoop(Kind, nullptr, nullptr, TransformedStmt, SemaRef,
+ *DSAStack, TmpDSA, LoopHelpers[LoopSeqSize]);
+
+ if (!IsCanonical) {
+ Diag(TransformedStmt->getBeginLoc(), diag::err_omp_not_canonical_loop)
+ << getOpenMPDirectiveName(Kind);
+ return false;
+ }
+ storeLoopStatements(TransformedStmt);
+ updatePreInits(LoopTransform, TransformsPreInits);
+
+ NumLoops += NumGeneratedLoops;
+ ++LoopSeqSize;
+ return true;
+ }
+ };
+
+ // Modularized code for handling regular canonical loops
+ auto analyzeRegularLoop = [&storeLoopStatements, &LoopHelpers, &OriginalInits,
+ &LoopSeqSize, &NumLoops, Kind, &TmpDSA,
+ &LoopCategories, this](Stmt *Child) {
+ OriginalInits.emplace_back();
+ LoopHelpers.emplace_back();
+ LoopCategories.push_back(OMPLoopCategory::RegularLoop);
+
+ unsigned IsCanonical =
+ checkOpenMPLoop(Kind, nullptr, nullptr, Child, SemaRef, *DSAStack,
+ TmpDSA, LoopHelpers[LoopSeqSize]);
+
+ if (!IsCanonical) {
+ Diag(Child->getBeginLoc(), diag::err_omp_not_canonical_loop)
+ << getOpenMPDirectiveName(Kind);
+ return false;
+ }
+
+ storeLoopStatements(Child);
+ auto NLCV = NestedLoopCounterVisitor();
+ NLCV.TraverseStmt(Child);
+ NumLoops += NLCV.getNestedLoopCount();
+ return true;
+ };
+
+ // Helper functions to validate canonical loop sequence grammar is valid
+ auto isLoopSequenceDerivation = [](auto *Child) {
+ return isa<ForStmt>(Child) || isa<CXXForRangeStmt>(Child) ||
+ isa<OMPLoopTransformationDirective>(Child);
+ };
+ auto isLoopGeneratingStmt = [](auto *Child) {
----------------
alexey-bataev wrote:
```suggestion
auto IsLoopGeneratingStmt = [](auto *Child) {
```
https://github.com/llvm/llvm-project/pull/139293
More information about the Openmp-commits mailing list