[clang] [flang] [llvm] [openmp] [Clang][OpenMP][LoopTransformations] Add support for "#pragma omp fuse" loop transformation directive and "looprange" clause (PR #139293)

Michael Kruse via llvm-commits llvm-commits at lists.llvm.org
Fri Jun 20 05:06:13 PDT 2025


================
@@ -14219,27 +14265,320 @@ bool SemaOpenMP::checkTransformableLoopNest(
         return false;
       },
       [&OriginalInits](OMPLoopBasedDirective *Transform) {
-        Stmt *DependentPreInits;
-        if (auto *Dir = dyn_cast<OMPTileDirective>(Transform))
-          DependentPreInits = Dir->getPreInits();
-        else if (auto *Dir = dyn_cast<OMPStripeDirective>(Transform))
-          DependentPreInits = Dir->getPreInits();
-        else if (auto *Dir = dyn_cast<OMPUnrollDirective>(Transform))
-          DependentPreInits = Dir->getPreInits();
-        else if (auto *Dir = dyn_cast<OMPReverseDirective>(Transform))
-          DependentPreInits = Dir->getPreInits();
-        else if (auto *Dir = dyn_cast<OMPInterchangeDirective>(Transform))
-          DependentPreInits = Dir->getPreInits();
-        else
-          llvm_unreachable("Unhandled loop transformation");
-
-        appendFlattenedStmtList(OriginalInits.back(), DependentPreInits);
+        updatePreInits(Transform, OriginalInits);
       });
   assert(OriginalInits.back().empty() && "No preinit after innermost loop");
   OriginalInits.pop_back();
   return Result;
 }
 
+/// Counts the total number of nested loops, including the outermost loop (the
+/// original loop). PRECONDITION: the visitor must be invoked on the original
+/// loop to be analyzed. Traversal stops at Decls and Exprs because they may
+/// contain inner loops that must not be counted.
+///
+/// Example AST structure for the code:
+///
+/// int main() {
+///     #pragma omp fuse
+///     {
+///         for (int i = 0; i < 100; i++) {    <-- Outer loop
+///             []() {
+///                 for(int j = 0; j < 100; j++) {}  <-- NOT A LOOP
+///             };
+///             for(int j = 0; j < 5; ++j) {}    <-- Inner loop
+///         }
+///         for (int r = 0; r < 100; r++) {    <-- Outer loop
+///             struct LocalClass {
+///                 void bar() {
+///                     for(int j = 0; j < 100; j++) {}  <-- NOT A LOOP
+///                 }
+///             };
+///             for(int k = 0; k < 10; ++k) {}    <-- Inner loop
+///             ({x = 5; for(k = 0; k < 10; ++k) x += k; x;}); <-- NOT A LOOP
+///         }
+///     }
+/// }
+/// Result: Loop 'i' contains 2 loops, Loop 'r' also contains 2 loops
+class NestedLoopCounterVisitor : public DynamicRecursiveASTVisitor {
+private:
+  unsigned NestedLoopCount = 0;
+
+public:
+  explicit NestedLoopCounterVisitor() = default;
+
+  unsigned getNestedLoopCount() const { return NestedLoopCount; }
+
+  bool VisitForStmt(ForStmt *FS) override {
+    ++NestedLoopCount;
+    return true;
+  }
+
+  bool VisitCXXForRangeStmt(CXXForRangeStmt *FRS) override {
+    ++NestedLoopCount;
+    return true;
+  }
+
+  bool TraverseStmt(Stmt *S) override {
+    if (!S)
+      return true;
+
+    // Skip traversal of all expressions, including special cases like
+    // LambdaExpr, StmtExpr, BlockExpr, and RequiresExpr. These expressions
+    // may contain inner statements (and even loops), but they are not part
+    // of the syntactic body of the surrounding loop structure and therefore
+    // must not be counted.
+    if (isa<Expr>(S))
+      return true;
+
+    // Only recurse into CompoundStmt (block {}) and loop bodies
+    if (isa<CompoundStmt, ForStmt, CXXForRangeStmt>(S)) {
+      return DynamicRecursiveASTVisitor::TraverseStmt(S);
+    }
+
+    // Do not descend into any other kind of statement: such statements break
+    // perfect loop nesting (control flow such as IfStmt, SwitchStmt, ...).
+    return true;
+  }
+
+  bool TraverseDecl(Decl *D) override {
+    // Never descend into declarations: loops inside them (for example in a
+    // CXXRecordDecl, RecordDecl or FunctionDecl) are not nested loops of the
+    // loop being analyzed.
+    return true;
+  }
+};
+
+bool SemaOpenMP::analyzeLoopSequence(
+    Stmt *LoopSeqStmt, unsigned &LoopSeqSize, unsigned &NumLoops,
+    SmallVectorImpl<OMPLoopBasedDirective::HelperExprs> &LoopHelpers,
+    SmallVectorImpl<Stmt *> &ForStmts,
+    SmallVectorImpl<SmallVector<Stmt *>> &OriginalInits,
+    SmallVectorImpl<SmallVector<Stmt *>> &TransformsPreInits,
+    SmallVectorImpl<SmallVector<Stmt *>> &LoopSequencePreInits,
+    SmallVectorImpl<OMPLoopCategory> &LoopCategories, ASTContext &Context,
+    OpenMPDirectiveKind Kind) {
+
+  VarsWithInheritedDSAType TmpDSA;
+  /// Helper Lambda to handle storing initialization and body statements for
+  /// both ForStmt and CXXForRangeStmt
+  auto StoreLoopStatements = [&](Stmt *LoopStmt) {
+    if (auto *For = dyn_cast<ForStmt>(LoopStmt)) {
+      OriginalInits.back().push_back(For->getInit());
+      ForStmts.push_back(For);
+    } else {
+      auto *CXXFor = cast<CXXForRangeStmt>(LoopStmt);
+      OriginalInits.back().push_back(CXXFor->getBeginStmt());
+      ForStmts.push_back(CXXFor);
+    }
+  };
+
+  /// Helper lambda functions to encapsulate the processing of different
+  /// derivations of the canonical loop sequence grammar
+  /// Modularized code for handling loop generation and transformations
+  auto AnalyzeLoopGeneration = [&](Stmt *Child) {
+    auto *LoopTransform = dyn_cast<OMPLoopTransformationDirective>(Child);
+    Stmt *TransformedStmt = LoopTransform->getTransformedStmt();
+    unsigned NumGeneratedLoopNests = LoopTransform->getNumGeneratedLoopNests();
+    unsigned NumGeneratedLoops = LoopTransform->getNumGeneratedLoops();
+    // Handle the case where transformed statement is not available due to
+    // dependent contexts
+    if (!TransformedStmt) {
+      if (NumGeneratedLoopNests > 0) {
+        LoopSeqSize += NumGeneratedLoopNests;
+        NumLoops += NumGeneratedLoops;
+        return true;
+      } else {
+        // Unroll full (0 loops produced)
+        Diag(Child->getBeginLoc(), diag::err_omp_not_for)
+            << 0 << getOpenMPDirectiveName(Kind);
+        return false;
+      }
+    }
+    // Handle loop transformations with multiple loop nests
+    // Unroll full
+    if (NumGeneratedLoopNests <= 0) {
+      Diag(Child->getBeginLoc(), diag::err_omp_not_for)
+          << 0 << getOpenMPDirectiveName(Kind);
+      return false;
+    }
+    // Loop transformations such as split or loopranged fuse
+    else if (NumGeneratedLoopNests > 1) {
+      // Get the preinits related to this loop sequence generating
+      // loop transformation (i.e loopranged fuse, split...)
+      LoopSequencePreInits.emplace_back();
+      // These preinits differ slightly from regular inits/pre-inits related
+      // to single loop generating loop transformations (interchange, unroll)
+      // given that they are not bounded to a particular loop nest
+      // so they need to be treated independently
+      updatePreInits(LoopTransform, LoopSequencePreInits);
+      return analyzeLoopSequence(TransformedStmt, LoopSeqSize, NumLoops,
+                                 LoopHelpers, ForStmts, OriginalInits,
+                                 TransformsPreInits, LoopSequencePreInits,
+                                 LoopCategories, Context, Kind);
+    } else {
+      // Vast majority: (Tile, Unroll, Stripe, Reverse, Interchange, Fuse all)
+      // Process the transformed loop statement
+      OriginalInits.emplace_back();
+      TransformsPreInits.emplace_back();
+      LoopHelpers.emplace_back();
+      LoopCategories.push_back(OMPLoopCategory::TransformSingleLoop);
+
+      unsigned IsCanonical =
+          checkOpenMPLoop(Kind, nullptr, nullptr, TransformedStmt, SemaRef,
+                          *DSAStack, TmpDSA, LoopHelpers[LoopSeqSize]);
+
+      if (!IsCanonical) {
+        Diag(TransformedStmt->getBeginLoc(), diag::err_omp_not_canonical_loop)
+            << getOpenMPDirectiveName(Kind);
+        return false;
+      }
+      StoreLoopStatements(TransformedStmt);
+      updatePreInits(LoopTransform, TransformsPreInits);
+
+      NumLoops += NumGeneratedLoops;
+      ++LoopSeqSize;
+      return true;
+    }
+  };
+
+  /// Modularized code for handling regular canonical loops
+  auto AnalyzeRegularLoop = [&](Stmt *Child) {
+    OriginalInits.emplace_back();
+    LoopHelpers.emplace_back();
+    LoopCategories.push_back(OMPLoopCategory::RegularLoop);
+
+    unsigned IsCanonical =
+        checkOpenMPLoop(Kind, nullptr, nullptr, Child, SemaRef, *DSAStack,
+                        TmpDSA, LoopHelpers[LoopSeqSize]);
+
+    if (!IsCanonical) {
+      Diag(Child->getBeginLoc(), diag::err_omp_not_canonical_loop)
+          << getOpenMPDirectiveName(Kind);
+      return false;
+    }
+
+    StoreLoopStatements(Child);
+    auto NLCV = NestedLoopCounterVisitor();
+    NLCV.TraverseStmt(Child);
+    NumLoops += NLCV.getNestedLoopCount();
+    return true;
+  };
+
+  /// Helper functions to validate loop sequence grammar derivations
+  auto IsLoopSequenceDerivation = [](auto *Child) {
+    return isa<ForStmt, CXXForRangeStmt, OMPLoopTransformationDirective>(Child);
+  };
+  /// Helper functions to validate loop generating grammar derivations
+  auto IsLoopGeneratingStmt = [](auto *Child) {
+    return isa<OMPLoopTransformationDirective>(Child);
+  };
+
+  // High level grammar validation
+  for (auto *Child : LoopSeqStmt->children()) {
----------------
Meinersbur wrote:

```suggestion
  for (Stmt *Child : LoopSeqStmt->children()) {
```
[Don’t “almost always” use auto](https://llvm.org/docs/CodingStandards.html#use-auto-type-deduction-to-make-code-more-readable)

Applies to lambda arguments as well

https://github.com/llvm/llvm-project/pull/139293


More information about the llvm-commits mailing list