[clang] [flang] [llvm] [openmp] [Clang][OpenMP][LoopTransformations] Add support for "#pragma omp fuse" loop transformation directive and "looprange" clause (PR #139293)
Alexey Bataev via llvm-commits
llvm-commits at lists.llvm.org
Sat Sep 20 07:35:14 PDT 2025
================
@@ -15716,6 +15987,489 @@ StmtResult SemaOpenMP::ActOnOpenMPInterchangeDirective(
buildPreInits(Context, PreInits));
}
+StmtResult SemaOpenMP::ActOnOpenMPFuseDirective(ArrayRef<OMPClause *> Clauses,
+ Stmt *AStmt,
+ SourceLocation StartLoc,
+ SourceLocation EndLoc) {
+
+ ASTContext &Context = getASTContext();
+ DeclContext *CurrContext = SemaRef.CurContext;
+ Scope *CurScope = SemaRef.getCurScope();
+ CaptureVars CopyTransformer(SemaRef);
+
+ // Ensure the structured block is not empty
+ if (!AStmt)
+ return StmtError();
+
+ // Defer transformation in dependent contexts
+ // The NumLoopNests argument is set to a placeholder 1 (even though
+ // using looprange fuse could yield up to 3 top level loop nests)
+ // because a dependent context could prevent determining its true value
+ if (CurrContext->isDependentContext()) {
+ return OMPFuseDirective::Create(Context, StartLoc, EndLoc, Clauses,
+ /* NumLoops */ 1, /* LoopSeqSize */ 1,
+ AStmt, nullptr, nullptr);
+ }
+
+ // Validate that the potential loop sequence is transformable for fusion
+ // Also collect the HelperExprs, Loop Stmts, Inits, and Number of loops
+ LoopSequenceAnalysis SeqAnalysis;
+ if (!checkTransformableLoopSequence(OMPD_fuse, AStmt, SeqAnalysis, Context))
+ return StmtError();
+
+ // SeqAnalysis.LoopSeqSize exists mostly to handle dependent contexts,
+ // otherwise it must be the same as SeqAnalysis.Loops.size().
+ assert(SeqAnalysis.LoopSeqSize == SeqAnalysis.Loops.size());
+
+ // Handle clauses, which can be any of the following: [looprange, apply]
+ const OMPLoopRangeClause *LRC =
+ OMPExecutableDirective::getSingleClause<OMPLoopRangeClause>(Clauses);
+
+ // The clause arguments are invalidated if any error arises
+ // such as non-constant or non-positive arguments
+ if (LRC && (!LRC->getFirst() || !LRC->getCount()))
+ return StmtError();
+
+ // Delayed semantic check of LoopRange constraint
+ // Evaluates the loop range arguments and returns the first and count values
+ auto EvaluateLoopRangeArguments = [&Context](Expr *First, Expr *Count,
+ uint64_t &FirstVal,
+ uint64_t &CountVal) {
+ llvm::APSInt FirstInt = First->EvaluateKnownConstInt(Context);
+ llvm::APSInt CountInt = Count->EvaluateKnownConstInt(Context);
+ FirstVal = FirstInt.getZExtValue();
+ CountVal = CountInt.getZExtValue();
+ };
+
+ // OpenMP [6.0, Restrictions]
+ // first + count - 1 must not evaluate to a value greater than the
+ // loop sequence length of the associated canonical loop sequence.
+ auto ValidLoopRange = [](uint64_t FirstVal, uint64_t CountVal,
+ unsigned NumLoops) -> bool {
+ return FirstVal + CountVal - 1 <= NumLoops;
+ };
+ uint64_t FirstVal = 1, CountVal = 0, LastVal = SeqAnalysis.LoopSeqSize;
+
+ // Validates the loop range after evaluating the semantic information
+ // and ensures that the range is valid for the given loop sequence size.
+ // Expressions are evaluated at compile time to obtain constant values.
+ if (LRC) {
+ EvaluateLoopRangeArguments(LRC->getFirst(), LRC->getCount(), FirstVal,
+ CountVal);
+ if (CountVal == 1)
+ SemaRef.Diag(LRC->getCountLoc(), diag::warn_omp_redundant_fusion)
+ << getOpenMPDirectiveName(OMPD_fuse);
+
+ if (!ValidLoopRange(FirstVal, CountVal, SeqAnalysis.LoopSeqSize)) {
+ SemaRef.Diag(LRC->getFirstLoc(), diag::err_omp_invalid_looprange)
+ << getOpenMPDirectiveName(OMPD_fuse) << FirstVal
+ << (FirstVal + CountVal - 1) << SeqAnalysis.LoopSeqSize;
+ return StmtError();
+ }
+
+ LastVal = FirstVal + CountVal - 1;
+ }
+
+ // Complete fusion generates a single canonical loop nest
+ // However looprange clause may generate several loop nests
+ unsigned NumGeneratedTopLevelLoops =
+ LRC ? SeqAnalysis.LoopSeqSize - CountVal + 1 : 1;
+
+ // Emit a warning for redundant loop fusion when the sequence contains only
+ // one loop.
+ if (SeqAnalysis.LoopSeqSize == 1)
+ SemaRef.Diag(AStmt->getBeginLoc(), diag::warn_omp_redundant_fusion)
+ << getOpenMPDirectiveName(OMPD_fuse);
+
+ // Select the type with the largest bit width among all induction variables
+ QualType IVType =
+ SeqAnalysis.Loops[FirstVal - 1].HelperExprs.IterationVarRef->getType();
+ for (unsigned I = FirstVal; I < LastVal; ++I) {
+ QualType CurrentIVType =
+ SeqAnalysis.Loops[I].HelperExprs.IterationVarRef->getType();
+ if (Context.getTypeSize(CurrentIVType) > Context.getTypeSize(IVType)) {
+ IVType = CurrentIVType;
+ }
+ }
+ uint64_t IVBitWidth = Context.getIntWidth(IVType);
+
+ // Create pre-init declarations for all loops lower bounds, upper bounds,
+ // strides and num-iterations for every top level loop in the fusion
+ SmallVector<VarDecl *, 4> LBVarDecls;
+ SmallVector<VarDecl *, 4> STVarDecls;
+ SmallVector<VarDecl *, 4> NIVarDecls;
+ SmallVector<VarDecl *, 4> UBVarDecls;
+ SmallVector<VarDecl *, 4> IVVarDecls;
+
+ // Helper lambda to create variables for bounds, strides, and other
+ // expressions. Generates both the variable declaration and the corresponding
+ // initialization statement.
+ auto CreateHelperVarAndStmt =
+ [&, &SemaRef = SemaRef](Expr *ExprToCopy, const std::string &BaseName,
+ unsigned I, bool NeedsNewVD = false) {
+ Expr *TransformedExpr =
+ AssertSuccess(CopyTransformer.TransformExpr(ExprToCopy));
+ if (!TransformedExpr)
+ return std::pair<VarDecl *, StmtResult>(nullptr, StmtError());
+
+ auto Name = (Twine(".omp.") + BaseName + std::to_string(I)).str();
+
+ VarDecl *VD;
+ if (NeedsNewVD) {
+ VD = buildVarDecl(SemaRef, SourceLocation(), IVType, Name);
+ SemaRef.AddInitializerToDecl(VD, TransformedExpr, false);
+
+ } else {
+ // Create a unique variable name
+ DeclRefExpr *DRE = cast<DeclRefExpr>(TransformedExpr);
+ VD = cast<VarDecl>(DRE->getDecl());
+ VD->setDeclName(&SemaRef.PP.getIdentifierTable().get(Name));
+ }
+ // Create the corresponding declaration statement
+ StmtResult DeclStmt = new (Context) class DeclStmt(
+ DeclGroupRef(VD), SourceLocation(), SourceLocation());
+ return std::make_pair(VD, DeclStmt);
+ };
+
+ // PreInits hold a sequence of variable declarations that must be executed
+ // before the fused loop begins. These include bounds, strides, and other
+ // helper variables required for the transformation. Other loop transforms
+ // also contain their own preinits
+ SmallVector<Stmt *> PreInits;
+
+ // Update the general preinits using the preinits generated by loop sequence
+ // generating loop transformations. These preinits differ slightly from
+ // single-loop transformation preinits, as they can be detached from a
+ // specific loop inside multiple generated loop nests. This happens
+ // because certain helper variables, like '.omp.fuse.max', are introduced to
+ // handle fused iteration spaces and may not be directly tied to a single
+ // original loop. The preinit structure must ensure that hidden variables
+ // like '.omp.fuse.max' are still properly handled.
+ // Transformations that apply this concept: Loopranged Fuse, Split
+ if (!SeqAnalysis.LoopSequencePreInits.empty()) {
+ llvm::append_range(PreInits, SeqAnalysis.LoopSequencePreInits);
+ }
+
+ // Process each single loop to generate and collect declarations
+ // and statements for all helper expressions related to
+ // particular single loop nests
+
+ // Also In the case of the fused loops, we keep track of their original
+ // inits by appending them to their preinits statement, and in the case of
+ // transformations, also append their preinits (which contain the original
+ // loop initialization statement or other statements)
+
+ // Firstly we need to set TransformIndex to match the begining of the
+ // looprange section
+ unsigned int TransformIndex = 0;
+ for (unsigned I : llvm::seq<unsigned>(FirstVal - 1)) {
+ if (SeqAnalysis.Loops[I].isLoopTransformation())
+ ++TransformIndex;
+ }
+
+ for (unsigned int I = FirstVal - 1, J = 0; I < LastVal; ++I, ++J) {
+ if (SeqAnalysis.Loops[I].isRegularLoop()) {
+ addLoopPreInits(Context, SeqAnalysis.Loops[I].HelperExprs,
+ SeqAnalysis.Loops[I].ForStmt,
+ SeqAnalysis.Loops[I].OriginalInits, PreInits);
+ } else if (SeqAnalysis.Loops[I].isLoopTransformation()) {
+ // For transformed loops, insert both pre-inits and original inits.
+ // Order matters: pre-inits may define variables used in the original
+ // inits such as upper bounds...
+ SmallVector<Stmt *> &TransformPreInit =
+ SeqAnalysis.Loops[TransformIndex++].TransformsPreInits;
+ if (!TransformPreInit.empty())
+ llvm::append_range(PreInits, TransformPreInit);
+
+ addLoopPreInits(Context, SeqAnalysis.Loops[I].HelperExprs,
+ SeqAnalysis.Loops[I].ForStmt,
+ SeqAnalysis.Loops[I].OriginalInits, PreInits);
+ }
+ auto [UBVD, UBDStmt] =
+ CreateHelperVarAndStmt(SeqAnalysis.Loops[I].HelperExprs.UB, "ub", J);
+ auto [LBVD, LBDStmt] =
+ CreateHelperVarAndStmt(SeqAnalysis.Loops[I].HelperExprs.LB, "lb", J);
+ auto [STVD, STDStmt] =
+ CreateHelperVarAndStmt(SeqAnalysis.Loops[I].HelperExprs.ST, "st", J);
+ auto [NIVD, NIDStmt] = CreateHelperVarAndStmt(
+ SeqAnalysis.Loops[I].HelperExprs.NumIterations, "ni", J, true);
+ auto [IVVD, IVDStmt] = CreateHelperVarAndStmt(
+ SeqAnalysis.Loops[I].HelperExprs.IterationVarRef, "iv", J);
+
+ assert(LBVD && STVD && NIVD && IVVD &&
+ "OpenMP Fuse Helper variables creation failed");
+
+ UBVarDecls.push_back(UBVD);
+ LBVarDecls.push_back(LBVD);
+ STVarDecls.push_back(STVD);
+ NIVarDecls.push_back(NIVD);
+ IVVarDecls.push_back(IVVD);
+
+ PreInits.push_back(LBDStmt.get());
+ PreInits.push_back(STDStmt.get());
+ PreInits.push_back(NIDStmt.get());
+ PreInits.push_back(IVDStmt.get());
+ }
+
+ auto MakeVarDeclRef = [&SemaRef = this->SemaRef](VarDecl *VD) {
+ return buildDeclRefExpr(SemaRef, VD, VD->getType(), VD->getLocation(),
+ false);
+ };
+
+ // Following up the creation of the final fused loop will be performed
+ // which has the following shape (considering the selected loops):
+ //
+ // for (fuse.index = 0; fuse.index < max(ni0, ni1..., nik); ++fuse.index) {
+ // if (fuse.index < ni0){
+ // iv0 = lb0 + st0 * fuse.index;
+ // original.index0 = iv0
+ // body(0);
+ // }
+ // if (fuse.index < ni1){
+ // iv1 = lb1 + st1 * fuse.index;
+ // original.index1 = iv1
+ // body(1);
+ // }
+ //
+ // ...
+ //
+ // if (fuse.index < nik){
+ // ivk = lbk + stk * fuse.index;
+ // original.indexk = ivk
+ // body(k); Expr *InitVal = IntegerLiteral::Create(Context,
+ // llvm::APInt(IVWidth, 0),
+ // }
+
+ // 1. Create the initialized fuse index
+ StringRef IndexName = ".omp.fuse.index";
+ Expr *InitVal = IntegerLiteral::Create(Context, llvm::APInt(IVBitWidth, 0),
+ IVType, SourceLocation());
+ VarDecl *IndexDecl =
+ buildVarDecl(SemaRef, {}, IVType, IndexName, nullptr, nullptr);
+ SemaRef.AddInitializerToDecl(IndexDecl, InitVal, false);
+ StmtResult InitStmt = new (Context)
+ DeclStmt(DeclGroupRef(IndexDecl), SourceLocation(), SourceLocation());
+
+ if (!InitStmt.isUsable())
+ return StmtError();
+
+ auto MakeIVRef = [&SemaRef = this->SemaRef, IndexDecl, IVType,
+ Loc = InitVal->getExprLoc()]() {
+ return buildDeclRefExpr(SemaRef, IndexDecl, IVType, Loc, false);
+ };
+
+ // 2. Iteratively compute the max number of logical iterations Max(NI_1, NI_2,
+ // ..., NI_k)
+ //
+ // This loop accumulates the maximum value across multiple expressions,
+ // ensuring each step constructs a unique AST node for correctness. By using
+ // intermediate temporary variables and conditional operators, we maintain
+ // distinct nodes and avoid duplicating subtrees, For instance, max(a,b,c):
+ // omp.temp0 = max(a, b)
+ // omp.temp1 = max(omp.temp0, c)
+ // omp.fuse.max = max(omp.temp1, omp.temp0)
+
+ ExprResult MaxExpr;
+ // I is the true
----------------
alexey-bataev wrote:
What does it mean?
https://github.com/llvm/llvm-project/pull/139293
More information about the llvm-commits
mailing list