[clang] [openmp] [Clang][OpenMP] Fix tile/unroll on iterator- and foreach-loops. (PR #91459)

Alexey Bataev via cfe-commits cfe-commits at lists.llvm.org
Tue May 21 10:47:55 PDT 2024


================
@@ -15095,16 +15136,70 @@ bool SemaOpenMP::checkTransformableLoopNest(
           DependentPreInits = Dir->getPreInits();
         else
           llvm_unreachable("Unhandled loop transformation");
-        if (!DependentPreInits)
-          return;
-        llvm::append_range(OriginalInits.back(),
-                           cast<DeclStmt>(DependentPreInits)->getDeclGroup());
+
+        appendFlattendedStmtList(OriginalInits.back(), DependentPreInits);
       });
   assert(OriginalInits.back().empty() && "No preinit after innermost loop");
   OriginalInits.pop_back();
   return Result;
 }
 
+/// Add preinit statements that need to be propagated from the selected loop.
+static void addLoopPreInits(ASTContext &Context,
+                            OMPLoopBasedDirective::HelperExprs &LoopHelper,
+                            Stmt *LoopStmt, ArrayRef<Stmt *> OriginalInit,
+                            SmallVectorImpl<Stmt *> &PreInits) {
+
+  // For range-based for-statements, ensure that the statements hidden behind
+  // their syntactic sugar are executed by adding them as pre-init statements.
+  if (auto *CXXRangeFor = dyn_cast<CXXForRangeStmt>(LoopStmt)) {
+    Stmt *RangeInit = CXXRangeFor->getInit();
+    if (RangeInit)
+      PreInits.push_back(RangeInit);
+
+    DeclStmt *RangeStmt = CXXRangeFor->getRangeStmt();
+    PreInits.push_back(new (Context) DeclStmt(RangeStmt->getDeclGroup(),
+                                              RangeStmt->getBeginLoc(),
+                                              RangeStmt->getEndLoc()));
+
+    DeclStmt *RangeEnd = CXXRangeFor->getEndStmt();
+    PreInits.push_back(new (Context) DeclStmt(RangeEnd->getDeclGroup(),
+                                              RangeEnd->getBeginLoc(),
+                                              RangeEnd->getEndLoc()));
+  }
+
+  llvm::append_range(PreInits, OriginalInit);
+
+  // List of OMPCapturedExprDecl, for __begin, __end, and NumIterations
+  if (auto *PI = cast_or_null<DeclStmt>(LoopHelper.PreInits)) {
+    PreInits.push_back(new (Context) DeclStmt(
+        PI->getDeclGroup(), PI->getBeginLoc(), PI->getEndLoc()));
+  }
+
+  // Gather declarations for the data members used as counters.
+  for (Expr *CounterRef : LoopHelper.Counters) {
+    auto *CounterDecl = cast<DeclRefExpr>(CounterRef)->getDecl();
+    if (isa<OMPCapturedExprDecl>(CounterDecl))
+      PreInits.push_back(new (Context) DeclStmt(
+          DeclGroupRef(CounterDecl), SourceLocation(), SourceLocation()));
+  }
+}
+
+/// Collect the loop statements (ForStmt or CXXForRangeStmt) of the affected
+/// loops of a construct.
+static void collectLoopStmts(Stmt *AStmt, MutableArrayRef<Stmt *> LoopStmts) {
+  size_t NumLoops = LoopStmts.size();
+  OMPLoopBasedDirective::doForAllLoops(
+      AStmt, /*TryImperfectlyNestedLoops=*/false, NumLoops,
+      [LoopStmts](unsigned Cnt, Stmt *CurStmt) {
+        assert(!LoopStmts[Cnt] && "Loop statement must not yet be assigned");
+        LoopStmts[Cnt] = CurStmt;
+        return false;
+      });
+  assert(llvm::all_of(LoopStmts, [](Stmt *LoopStmt) { return LoopStmt; }) &&
+         "Expecting a loop statement for each affected loop");
----------------
alexey-bataev wrote:

```suggestion
  assert(!is_contained(LoopStmts, nullptr) &&
         "Expecting a loop statement for each affected loop");
```

https://github.com/llvm/llvm-project/pull/91459


More information about the cfe-commits mailing list