[llvm-branch-commits] [clang] [flang] [llvm] [openmp] [Clang][OpenMP] Add reverse and interchange directives (PR #92030)
Alexey Bataev via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Tue May 14 03:28:09 PDT 2024
================
@@ -15745,6 +15760,388 @@ StmtResult SemaOpenMP::ActOnOpenMPUnrollDirective(ArrayRef<OMPClause *> Clauses,
buildPreInits(Context, PreInits));
}
+StmtResult
+SemaOpenMP::ActOnOpenMPReverseDirective(ArrayRef<OMPClause *> Clauses,
+ Stmt *AStmt, SourceLocation StartLoc,
+ SourceLocation EndLoc) {
+ ASTContext &Context = getASTContext();
+ Scope *CurScope = SemaRef.getCurScope();
+ assert(Clauses.empty() && "reverse directive does not accept any clauses; "
+ "must have beed checked before");
+
+ // Handles the OpenMP reverse directive. An empty statement should only be possible if there already was an error.
+ if (!AStmt)
+ return StmtError();
+
+ constexpr unsigned NumLoops = 1;
+ Stmt *Body = nullptr;
+ SmallVector<OMPLoopBasedDirective::HelperExprs, NumLoops> LoopHelpers(
+ NumLoops);
+ SmallVector<SmallVector<Stmt *, 0>, NumLoops + 1> OriginalInits;
+ if (!checkTransformableLoopNest(OMPD_reverse, AStmt, NumLoops, LoopHelpers,
+ Body, OriginalInits))
+ return StmtError();
+
+ // Delay applying the transformation until the template is completely
+ // instantiated; no transformed statement or pre-inits are attached here.
+ if (SemaRef.CurContext->isDependentContext())
+ return OMPReverseDirective::Create(Context, StartLoc, EndLoc, Clauses,
+ AStmt, nullptr, nullptr);
+
+ assert(LoopHelpers.size() == NumLoops &&
+ "Expecting a single-dimensional loop iteration space");
+ assert(OriginalInits.size() == NumLoops &&
+ "Expecting a single-dimensional loop iteration space");
+ OMPLoopBasedDirective::HelperExprs &LoopHelper = LoopHelpers.front();
+
+ // Find the loop statement the directive applies to.
+ Stmt *LoopStmt = nullptr;
+ collectLoopStmts(AStmt, {LoopStmt});
+
+ // Determine the PreInit declarations.
+ SmallVector<Stmt *> PreInits;
+ addLoopPreInits(Context, LoopHelper, LoopStmt, OriginalInits[0], PreInits);
+
+ auto *IterationVarRef = cast<DeclRefExpr>(LoopHelper.IterationVarRef);
+ QualType IVTy = IterationVarRef->getType();
+ uint64_t IVWidth = Context.getTypeSize(IVTy);
+ auto *OrigVar = cast<DeclRefExpr>(LoopHelper.Counters.front());
+
+ // Source locations of the user's loop counter variable.
+ SourceLocation OrigVarLoc = OrigVar->getExprLoc();
+ SourceLocation OrigVarLocBegin = OrigVar->getBeginLoc();
+ SourceLocation OrigVarLocEnd = OrigVar->getEndLoc();
+
+ // Locations pointing to the transformation, i.e. the directive itself.
+ SourceLocation TransformLoc = StartLoc;
+ SourceLocation TransformLocBegin = StartLoc;
+ SourceLocation TransformLocEnd = EndLoc;
+
+ // Internal variable names, each derived from the user's loop counter name.
+ std::string OrigVarName = OrigVar->getNameInfo().getAsString();
+ std::string TripCountName = (Twine(".tripcount.") + OrigVarName).str();
+ std::string ForwardIVName = (Twine(".forward.iv.") + OrigVarName).str();
+ std::string ReversedIVName = (Twine(".reversed.iv.") + OrigVarName).str();
+
+ // LoopHelper.Updates will read the logical iteration number from
+ // LoopHelper.IterationVarRef, compute the value of the user loop counter of
+ // that logical iteration from it, then assign it to the user loop counter
+ // variable. We cannot directly use LoopHelper.IterationVarRef as the
+ // induction variable of the generated loop because it may cause an underflow:
+ // \code
+ // for (unsigned i = 0; i < n; ++i)
+ // body(i);
+ // \endcode
+ //
+ // Naive reversal (i >= 0 is always true for an unsigned i, so i wraps around):
+ // \code
+ // for (unsigned i = n-1; i >= 0; --i)
+ // body(i);
+ // \endcode
+ //
+ // Instead, we introduce a new iteration variable representing the logical
+ // iteration counter of the original loop, convert it to the logical iteration
+ // number of the reversed loop, then let LoopHelper.Updates compute the user's
+ // loop iteration variable from it.
+ // \code
+ // for (auto .forward.iv = 0; .forward.iv < n; ++.forward.iv) {
+ // auto .reversed.iv = n - 1 - .forward.iv;
+ // i = (.reversed.iv + 0) * 1; // LoopHelper.Updates
+ // body(i); // Body
+ // }
+ // \endcode
+
+ // Subexpressions with more than one use. One of the constraints of an AST is
+ // that every node object must appear at most once, hence we define a lambda
+ // that creates a new AST node at every use.
+ CaptureVars CopyTransformer(SemaRef);
+ auto MakeNumIterations = [&CopyTransformer, &LoopHelper]() -> Expr * {
+ return AssertSuccess(
+ CopyTransformer.TransformExpr(LoopHelper.NumIterations));
+ };
+
+ // Create the iteration variable for the forward loop (from 0 to n-1).
+ VarDecl *ForwardIVDecl =
+ buildVarDecl(SemaRef, {}, IVTy, ForwardIVName, nullptr, OrigVar);
+ auto MakeForwardRef = [&SemaRef = this->SemaRef, ForwardIVDecl, IVTy,
+ OrigVarLoc]() {
+ return buildDeclRefExpr(SemaRef, ForwardIVDecl, IVTy, OrigVarLoc);
+ };
+
+ // Iteration variable for the reversed induction variable (from n-1 downto 0):
+ // Reuse the iteration variable created by checkOpenMPLoop, renamed to ReversedIVName.
+ auto *ReversedIVDecl = cast<VarDecl>(IterationVarRef->getDecl());
+ ReversedIVDecl->setDeclName(
+ &SemaRef.PP.getIdentifierTable().get(ReversedIVName));
+
+ // For init-statement:
+ // \code
+ // auto .forward.iv = 0
+ // \endcode
+ IntegerLiteral *Zero =
+ IntegerLiteral::Create(Context, llvm::APInt::getZero(IVWidth),
+ ForwardIVDecl->getType(), OrigVarLoc);
+ SemaRef.AddInitializerToDecl(ForwardIVDecl, Zero, /*DirectInit=*/false);
+ StmtResult Init = new (Context)
+ DeclStmt(DeclGroupRef(ForwardIVDecl), OrigVarLocBegin, OrigVarLocEnd);
+ if (!Init.isUsable())
+ return StmtError();
+
+ // Forward iv cond-expression:
+ // \code
+ // .forward.iv < NumIterations
+ // \endcode
+ ExprResult Cond =
+ SemaRef.BuildBinOp(CurScope, LoopHelper.Cond->getExprLoc(), BO_LT,
+ MakeForwardRef(), MakeNumIterations());
+ if (!Cond.isUsable())
+ return StmtError();
+
+ // Forward incr-statement: ++.forward.iv
+ ExprResult Incr = SemaRef.BuildUnaryOp(CurScope, LoopHelper.Inc->getExprLoc(),
+ UO_PreInc, MakeForwardRef());
+ if (!Incr.isUsable())
+ return StmtError();
+
+ // Reverse the forward-iv, built as two subtractions:
+ // auto .reversed.iv = (NumIterations - 1) - .forward.iv
+ IntegerLiteral *One = IntegerLiteral::Create(Context, llvm::APInt(IVWidth, 1),
+ IVTy, TransformLoc);
+ ExprResult Minus = SemaRef.BuildBinOp(CurScope, TransformLoc, BO_Sub,
+ MakeNumIterations(), One);
+ if (!Minus.isUsable())
+ return StmtError();
+ Minus = SemaRef.BuildBinOp(CurScope, TransformLoc, BO_Sub, Minus.get(),
+ MakeForwardRef());
+ if (!Minus.isUsable())
+ return StmtError();
+ StmtResult InitReversed = new (Context) DeclStmt(
+ DeclGroupRef(ReversedIVDecl), TransformLocBegin, TransformLocEnd);
+ if (!InitReversed.isUsable())
+ return StmtError();
+ SemaRef.AddInitializerToDecl(ReversedIVDecl, Minus.get(),
+ /*DirectInit=*/false);
+
+ // The new loop body: reversed-iv declaration, counter updates, the range-for loop variable (if any), then the original body.
+ SmallVector<Stmt *> BodyStmts;
+ BodyStmts.push_back(InitReversed.get());
+ llvm::append_range(BodyStmts, LoopHelper.Updates);
+ if (auto *CXXRangeFor = dyn_cast<CXXForRangeStmt>(LoopStmt))
+ BodyStmts.push_back(CXXRangeFor->getLoopVarStmt());
+ BodyStmts.push_back(Body);
+ auto *ReversedBody =
+ CompoundStmt::Create(Context, BodyStmts, FPOptionsOverride(),
+ Body->getBeginLoc(), Body->getEndLoc());
+
+ // Finally create the reversed For-statement and attach it to the directive.
+ auto *ReversedFor = new (Context)
+ ForStmt(Context, Init.get(), Cond.get(), nullptr, Incr.get(),
+ ReversedBody, LoopHelper.Init->getBeginLoc(),
+ LoopHelper.Init->getBeginLoc(), LoopHelper.Inc->getEndLoc());
+ return OMPReverseDirective::Create(Context, StartLoc, EndLoc, Clauses, AStmt,
+ ReversedFor,
+ buildPreInits(Context, PreInits));
+}
+
+StmtResult SemaOpenMP::ActOnOpenMPInterchangeDirective(
+ ArrayRef<OMPClause *> Clauses, Stmt *AStmt, SourceLocation StartLoc,
+ SourceLocation EndLoc) {
+ ASTContext &Context = getASTContext();
+ DeclContext *CurContext = SemaRef.CurContext;
+ Scope *CurScope = SemaRef.getCurScope();
+
+ // Empty statement should only be possible if there already was an error.
+ if (!AStmt)
+ return StmtError();
+
+ // interchange without permutation clause swaps two loops.
+ const OMPPermutationClause *PermutationClause =
+ OMPExecutableDirective::getSingleClause<OMPPermutationClause>(Clauses);
+ size_t NumLoops = PermutationClause ? PermutationClause->getNumLoops() : 2;
+
+ // Verify and diagnose loop nest.
+ SmallVector<OMPLoopBasedDirective::HelperExprs, 4> LoopHelpers(NumLoops);
+ Stmt *Body = nullptr;
+ SmallVector<SmallVector<Stmt *, 0>, 2> OriginalInits;
+ if (!checkTransformableLoopNest(OMPD_interchange, AStmt, NumLoops,
+ LoopHelpers, Body, OriginalInits))
+ return StmtError();
+
+ // Delay interchange to when template is completely instantiated.
+ if (CurContext->isDependentContext())
+ return OMPInterchangeDirective::Create(Context, StartLoc, EndLoc, Clauses,
+ NumLoops, AStmt, nullptr, nullptr);
+
+ // An invalid expression in the permutation clause is set to nullptr in
+ // ActOnOpenMPPermutationClause.
+ if (PermutationClause && llvm::any_of(PermutationClause->getArgsRefs(),
+ [](Expr *E) { return !E; }))
+ return StmtError();
+
+ assert(LoopHelpers.size() == NumLoops &&
+ "Expecting loop iteration space dimensionaly to match number of "
+ "affected loops");
+ assert(OriginalInits.size() == NumLoops &&
+ "Expecting loop iteration space dimensionaly to match number of "
+ "affected loops");
+
+ // Decode the permutation clause.
+ SmallVector<uint64_t, 2> Permutation;
+ if (!PermutationClause) {
+ Permutation = {1, 0};
+ } else {
+ ArrayRef<Expr *> PermArgs = PermutationClause->getArgsRefs();
+ llvm::BitVector Flags(PermArgs.size());
+ for (Expr *PermArg : PermArgs) {
+ std::optional<llvm::APSInt> PermCstExpr =
+ PermArg->getIntegerConstantExpr(Context);
+ if (!PermCstExpr)
+ continue;
+ uint64_t PermInt = PermCstExpr->getZExtValue();
+ assert(1 <= PermInt && PermInt <= NumLoops &&
+ "Must be a permutation; diagnostic emitted in "
+ "ActOnOpenMPPermutationClause");
+ if (Flags[PermInt - 1]) {
+ SourceRange ExprRange(PermArg->getBeginLoc(), PermArg->getEndLoc());
+ Diag(PermArg->getExprLoc(),
+ diag::err_omp_interchange_permutation_value_repeated)
+ << PermInt << ExprRange;
+ continue;
+ }
+ Flags[PermInt - 1] = true;
+
+ Permutation.push_back(PermInt - 1);
+ }
+
+ if (Permutation.size() != NumLoops)
+ return StmtError();
+ }
+
+ // Nothing to transform with trivial permutation.
+ if (NumLoops <= 1 || llvm::all_of(llvm::enumerate(Permutation), [](auto p) {
+ auto [Idx, Arg] = p;
+ return Idx == Arg;
+ }))
+ return OMPInterchangeDirective::Create(Context, StartLoc, EndLoc, Clauses,
+ NumLoops, AStmt, AStmt, nullptr);
+
+ // Find the affected loops.
+ SmallVector<Stmt *> LoopStmts(NumLoops, nullptr);
+ collectLoopStmts(AStmt, LoopStmts);
+
+ // Collect pre-init statements in the order before the permutation.
+ SmallVector<Stmt *> PreInits;
+ for (auto I : llvm::seq<int>(NumLoops)) {
+ OMPLoopBasedDirective::HelperExprs &LoopHelper = LoopHelpers[I];
+
+ assert(LoopHelper.Counters.size() == 1 &&
+ "Single-dimensional loop iteration space expected");
+ auto *OrigCntVar = cast<DeclRefExpr>(LoopHelper.Counters.front());
+
+ std::string OrigVarName = OrigCntVar->getNameInfo().getAsString();
+ addLoopPreInits(Context, LoopHelper, LoopStmts[I], OriginalInits[I],
+ PreInits);
+ }
+
+ SmallVector<VarDecl *> PermutedIndVars;
+ PermutedIndVars.resize(NumLoops);
+ CaptureVars CopyTransformer(SemaRef);
+
+ // Create the permuted loops from the inside to the outside of the
+ // interchanged loop nest. Body of the innermost new loop is the original
+ // innermost body.
+ Stmt *Inner = Body;
+ for (auto TargetIdx : llvm::reverse(llvm::seq<int>(NumLoops))) {
+ // Get the original loop that belongs to this new position.
+ uint64_t SourceIdx = Permutation[TargetIdx];
+ OMPLoopBasedDirective::HelperExprs &SourceHelper = LoopHelpers[SourceIdx];
+ Stmt *SourceLoopStmt = LoopStmts[SourceIdx];
+ assert(SourceHelper.Counters.size() == 1 &&
+ "Single-dimensional loop iteration space expected");
+ auto *OrigCntVar = cast<DeclRefExpr>(SourceHelper.Counters.front());
+
+ // Normalized loop counter variable: From 0 to n-1, always an integer type.
+ DeclRefExpr *IterVarRef = cast<DeclRefExpr>(SourceHelper.IterationVarRef);
+ QualType IVTy = IterVarRef->getType();
+ assert(IVTy->isIntegerType() &&
+ "Expected the logical iteration counter to be an integer");
+
+ std::string OrigVarName = OrigCntVar->getNameInfo().getAsString();
+ SourceLocation OrigVarLoc = IterVarRef->getExprLoc();
+
+ // Make a copy of the NumIterations expression for each use: By the AST
+ // constraints, every expression object in a DeclContext must be unique.
+ auto MakeNumIterations = [&CopyTransformer, &SourceHelper]() -> Expr * {
+ return AssertSuccess(
+ CopyTransformer.TransformExpr(SourceHelper.NumIterations));
+ };
+
+ // Iteration variable for the permuted loop. Reuse the one from
+ // checkOpenMPLoop which will also be used to update the original loop
+ // variable.
+ std::string PermutedCntName =
+ (Twine(".permuted_") + llvm::utostr(TargetIdx) + ".iv." + OrigVarName)
+ .str();
+ auto *PermutedCntDecl = cast<VarDecl>(IterVarRef->getDecl());
+ PermutedCntDecl->setDeclName(
+ &SemaRef.PP.getIdentifierTable().get(PermutedCntName));
+ PermutedIndVars[TargetIdx] = PermutedCntDecl;
+ auto MakePermutedRef = [this, PermutedCntDecl, IVTy, OrigVarLoc]() {
+ return buildDeclRefExpr(SemaRef, PermutedCntDecl, IVTy, OrigVarLoc);
+ };
+
+ // For init-statement:
+ // \code{c}
+ // auto .permuted_{target}.iv = 0
+ // \endcode
+ ExprResult Zero = SemaRef.ActOnIntegerConstant(OrigVarLoc, 0);
+ if (!Zero.isUsable())
+ return StmtError();
+ SemaRef.AddInitializerToDecl(PermutedCntDecl, Zero.get(),
+ /*DirectInit=*/false);
+ StmtResult InitStmt = new (Context)
+ DeclStmt(DeclGroupRef(PermutedCntDecl), OrigCntVar->getBeginLoc(),
+ OrigCntVar->getEndLoc());
+ if (!InitStmt.isUsable())
+ return StmtError();
+
+ // For cond-expression:
+ // \code{c}
+ // .permuted_{target}.iv < NumIterations
+ // \endcode
+ ExprResult CondExpr =
+ SemaRef.BuildBinOp(CurScope, SourceHelper.Cond->getExprLoc(), BO_LT,
+ MakePermutedRef(), MakeNumIterations());
+ if (!CondExpr.isUsable())
+ return StmtError();
+
+ // For incr-statement:
+ // \code{c}
+ // ++.permuted_{target}.iv
+ // \endcode
+ ExprResult IncrStmt = SemaRef.BuildUnaryOp(
+ CurScope, SourceHelper.Inc->getExprLoc(), UO_PreInc, MakePermutedRef());
+ if (!IncrStmt.isUsable())
+ return StmtError();
+
+ SmallVector<Stmt *, 4> BodyParts;
+ llvm::append_range(BodyParts, SourceHelper.Updates);
----------------
alexey-bataev wrote:
```suggestion
SmallVector<Stmt *, 4> BodyParts(SourceHelper.Updates.begin(), SourceHelper.Updates.end());
```
https://github.com/llvm/llvm-project/pull/92030
More information about the llvm-branch-commits
mailing list