[Openmp-commits] [clang] [flang] [llvm] [openmp] [Clang][OpenMP][LoopTransformations] Add support for "#pragma omp fuse" loop transformation directive and "looprange" clause (PR #139293)

Alexey Bataev via Openmp-commits openmp-commits at lists.llvm.org
Mon Sep 22 13:48:29 PDT 2025


================
@@ -14433,29 +14465,268 @@ bool SemaOpenMP::checkTransformableLoopNest(
         OriginalInits.emplace_back();
         return false;
       },
-      [&OriginalInits](OMPLoopBasedDirective *Transform) {
-        Stmt *DependentPreInits;
-        if (auto *Dir = dyn_cast<OMPTileDirective>(Transform))
-          DependentPreInits = Dir->getPreInits();
-        else if (auto *Dir = dyn_cast<OMPStripeDirective>(Transform))
-          DependentPreInits = Dir->getPreInits();
-        else if (auto *Dir = dyn_cast<OMPUnrollDirective>(Transform))
-          DependentPreInits = Dir->getPreInits();
-        else if (auto *Dir = dyn_cast<OMPReverseDirective>(Transform))
-          DependentPreInits = Dir->getPreInits();
-        else if (auto *Dir = dyn_cast<OMPInterchangeDirective>(Transform))
-          DependentPreInits = Dir->getPreInits();
-        else
-          llvm_unreachable("Unhandled loop transformation");
-
-        appendFlattenedStmtList(OriginalInits.back(), DependentPreInits);
+      [&OriginalInits](OMPLoopTransformationDirective *Transform) {
+        updatePreInits(Transform, OriginalInits.back());
       });
   assert(OriginalInits.back().empty() && "No preinit after innermost loop");
   OriginalInits.pop_back();
   return Result;
 }
 
-/// Add preinit statements that need to be propageted from the selected loop.
+/// Counts the total number of OpenMP canonical nested loops, including the
+/// outermost loop (the original loop). PRECONDITION: the visitor must be
+/// invoked on the original loop to be analyzed. Traversal does not descend
+/// into Decls or Exprs, because any loops they contain (lambdas, local class
+/// methods, statement expressions, ...) must not be counted. It also does not
+/// descend into statements other than CompoundStmt, ForStmt and
+/// CXXForRangeStmt (e.g. IfStmt, SwitchStmt), since those break perfect
+/// loop nesting.
+///
+/// Example AST structure for the code:
+///
+/// int main() {
+///     #pragma omp fuse
+///     {
+///         for (int i = 0; i < 100; i++) {    <-- Outer loop
+///             []() {
+///                 for(int j = 0; j < 100; j++) {}  <-- NOT A LOOP (1)
+///             };
+///             for(int j = 0; j < 5; ++j) {}    <-- Inner loop
+///         }
+///         for (int r = 0; r < 100; r++) {    <-- Outer loop
+///             struct LocalClass {
+///                 void bar() {
+///                     for(int j = 0; j < 100; j++) {}  <-- NOT A LOOP (2)
+///                 }
+///             };
+///             for(int k = 0; k < 10; ++k) {}    <-- Inner loop
+///             ({x = 5; for(k = 0; k < 10; ++k) x += k; x;}); <-- NOT A LOOP (3)
+///         }
+///     }
+/// }
+/// (1) because in a different function (here: a lambda)
+/// (2) because in a different function (here: class method)
+/// (3) because it sits inside a GNU statement expression, which is an Expr
+///     and is treated as intervening code of a non-perfectly nested loop
+/// Result: Loop 'i' contains 2 loops, Loop 'r' also contains 2 loops.
+class NestedLoopCounterVisitor final : public DynamicRecursiveASTVisitor {
+private:
+  unsigned NestedLoopCount = 0;
+
+public:
+  explicit NestedLoopCounterVisitor() = default;
+
+  /// Returns the number of ForStmt/CXXForRangeStmt nodes visited so far.
+  unsigned getNestedLoopCount() const { return NestedLoopCount; }
+
+  bool VisitForStmt(ForStmt *) override {
+    ++NestedLoopCount;
+    return true;
+  }
+
+  bool VisitCXXForRangeStmt(CXXForRangeStmt *) override {
+    ++NestedLoopCount;
+    return true;
+  }
+
+  bool TraverseStmt(Stmt *S) override {
+    if (!S)
+      return true;
+
+    // Skip all expressions, including special cases like LambdaExpr,
+    // StmtExpr, BlockExpr and RequiresExpr. They may contain statements (and
+    // even loops), but those are not part of the syntactic body of the
+    // surrounding loop structure and therefore must not be counted.
+    if (isa<Expr>(S))
+      return true;
+
+    // Only recurse into compound statements ({}) and loops.
+    if (isa<CompoundStmt, ForStmt, CXXForRangeStmt>(S))
+      return DynamicRecursiveASTVisitor::TraverseStmt(S);
+
+    // Any other statement (control flow such as IfStmt, SwitchStmt, ...)
+    // breaks perfect loop nesting, so do not look inside it.
+    return true;
+  }
+
+  bool TraverseDecl(Decl *) override {
+    // Declarations (e.g. CXXRecordDecl, RecordDecl, FunctionDecl) cannot
+    // contribute loops to this nest, so do not descend into them.
+    return true;
+  }
+};
+
+/// Matches the children of \p LoopSeqStmt against the canonical loop sequence
+/// grammar, recursing into nested sequences. For every single-loop nest found
+/// an entry is appended to \p SeqAnalysis.Loops, and \p SeqAnalysis.LoopSeqSize
+/// is increased by the number of top-level loops contributed. Emits a
+/// diagnostic and returns false on the first child that does not fit.
+bool SemaOpenMP::analyzeLoopSequence(Stmt *LoopSeqStmt,
+                                     LoopSequenceAnalysis &SeqAnalysis,
+                                     ASTContext &Context,
+                                     OpenMPDirectiveKind Kind) {
+  VarsWithInheritedDSAType TmpDSA;
+  // Stores the initialization statement and the loop statement itself for
+  // both ForStmt and CXXForRangeStmt.
+  auto StoreLoopStatements = [](LoopAnalysis &Analysis, Stmt *LoopStmt) {
+    if (auto *For = dyn_cast<ForStmt>(LoopStmt)) {
+      Analysis.OriginalInits.push_back(For->getInit());
+      Analysis.TheForStmt = For;
+    } else {
+      auto *CXXFor = cast<CXXForRangeStmt>(LoopStmt);
+      Analysis.OriginalInits.push_back(CXXFor->getBeginStmt());
+      Analysis.TheForStmt = CXXFor;
+    }
+  };
+
+  // Handles a loop-sequence-generating construct (a loop transformation
+  // directive). On success, SeqAnalysis.LoopSeqSize has been updated.
+  auto AnalyzeLoopGeneration = [&](Stmt *Child) {
+    auto *LoopTransform = cast<OMPLoopTransformationDirective>(Child);
+    Stmt *TransformedStmt = LoopTransform->getTransformedStmt();
+    unsigned NumGeneratedTopLevelLoops =
+        LoopTransform->getNumGeneratedTopLevelLoops();
+    // A transformation that generates no top-level loop (e.g. full unroll)
+    // cannot be part of a loop sequence, whether or not the transformed
+    // statement is available.
+    if (NumGeneratedTopLevelLoops == 0) {
+      Diag(Child->getBeginLoc(), diag::err_omp_not_for)
+          << 0 << getOpenMPDirectiveName(Kind);
+      return false;
+    }
+    // The transformed statement is not available in dependent contexts; only
+    // the number of generated loops can be accounted for.
+    if (!TransformedStmt) {
+      SeqAnalysis.LoopSeqSize += NumGeneratedTopLevelLoops;
+      return true;
+    }
+    // Transformations generating several top-level loops (such as split or
+    // loopranged fuse) produce a loop sequence of their own. Their preinits
+    // are not bound to a particular loop nest, so they are stored with the
+    // sequence, and the generated sequence is analyzed recursively.
+    if (NumGeneratedTopLevelLoops > 1) {
+      updatePreInits(LoopTransform, SeqAnalysis.LoopSequencePreInits);
+      return analyzeLoopSequence(TransformedStmt, SeqAnalysis, Context, Kind);
+    }
+    // Vast majority (tile, unroll, stripe, reverse, interchange, fuse of all
+    // loops): a single generated loop; analyze the transformed loop.
+    LoopAnalysis &NewTransformedSingleLoop =
+        SeqAnalysis.Loops.emplace_back(Child);
+    unsigned IsCanonical = checkOpenMPLoop(
+        Kind, nullptr, nullptr, TransformedStmt, SemaRef, *DSAStack, TmpDSA,
+        NewTransformedSingleLoop.HelperExprs);
+    if (!IsCanonical)
+      return false;
+
+    StoreLoopStatements(NewTransformedSingleLoop, TransformedStmt);
+    updatePreInits(LoopTransform, NewTransformedSingleLoop.TransformsPreInits);
+    ++SeqAnalysis.LoopSeqSize;
+    return true;
+  };
+
+  // Handles a regular canonical loop nest. On success,
+  // SeqAnalysis.LoopSeqSize has been updated.
+  auto AnalyzeRegularLoop = [&](Stmt *Child) {
+    LoopAnalysis &NewRegularLoop = SeqAnalysis.Loops.emplace_back(Child);
+    unsigned IsCanonical =
+        checkOpenMPLoop(Kind, nullptr, nullptr, Child, SemaRef, *DSAStack,
+                        TmpDSA, NewRegularLoop.HelperExprs);
+    if (!IsCanonical)
+      return false;
+
+    StoreLoopStatements(NewRegularLoop, Child);
+    ++SeqAnalysis.LoopSeqSize;
+    return true;
+  };
+
+  // High level grammar validation.
+  for (Stmt *Child : LoopSeqStmt->children()) {
+    if (!Child)
+      continue;
+    if (!LoopSequenceAnalysis::isLoopSequenceDerivation(Child)) {
+      // Look through container statements that are not part of the grammar
+      // themselves.
+      Child = Child->IgnoreContainers();
+      if (!Child)
+        continue;
+      // A nested canonical loop sequence ({ ... }) requires a recursive
+      // traversal; an empty one contributes no loops.
+      if (isa<CompoundStmt>(Child)) {
+        if (!analyzeLoopSequence(Child, SeqAnalysis, Context, Kind))
+          return false;
+        continue;
+      }
+      // Any other statement is invalid inside a canonical loop sequence.
+      if (!LoopSequenceAnalysis::isLoopSequenceDerivation(Child)) {
+        Diag(Child->getBeginLoc(), diag::err_omp_not_for)
+            << 0 << getOpenMPDirectiveName(Kind);
+        return false;
+      }
+    }
+    bool Analyzed = LoopAnalysis::isLoopTransformation(Child)
+                        ? AnalyzeLoopGeneration(Child)
+                        : AnalyzeRegularLoop(Child);
+    if (!Analyzed)
+      return false;
+  }
+  return true;
+}
+
+bool SemaOpenMP::checkTransformableLoopSequence(
+    OpenMPDirectiveKind Kind, Stmt *AStmt, LoopSequenceAnalysis &SeqAnalysis,
+    ASTContext &Context) {
+  // Following OpenMP 6.0 API Specification, a Canonical Loop Sequence follows
+  // the grammar:
+  //
+  // canonical-loop-sequence:
+  //  {
+  //    loop-sequence+
+  //  }
+  // where loop-sequence can be any of the following:
+  // 1. canonical-loop-sequence
+  // 2. loop-nest
+  // 3. loop-sequence-generating-construct (i.e OMPLoopTransformationDirective)
+  //
+  // To recognise and traverse this structure the helper function
+  // analyzeLoopSequence serves as the recurisve entry point
+  // and tries to match the input AST to the canonical loop sequence grammar
+  // structure. This function will perform both a semantic and syntactical
+  // analysis of the given statement according to OpenMP 6.0 definition of
+  // the aforementioned canonical loop sequence.
+
+  // We expect an outer compound statement.
+  if (!isa<CompoundStmt>(AStmt)) {
+    Diag(AStmt->getBeginLoc(), diag::err_omp_not_a_loop_sequence)
+        << getOpenMPDirectiveName(Kind);
+    return false;
+  }
+
+  // Recursive entry point to process the main loop sequence
+  if (!analyzeLoopSequence(AStmt, SeqAnalysis, Context, Kind))
+    return false;
+
+  // Diagnose an empty loop sequence.
+  if (SeqAnalysis.LoopSeqSize <= 0) {
----------------
alexey-bataev wrote:

`LoopSeqSize` is unsigned, so just `== 0` should be enough.

https://github.com/llvm/llvm-project/pull/139293


More information about the Openmp-commits mailing list