[clang] [llvm] [Clang][OpenMP] Support for dispatch construct (Sema & Codegen) support (PR #117904)
via llvm-commits
llvm-commits at lists.llvm.org
Sun Dec 15 10:31:30 PST 2024
================
@@ -5965,6 +5967,264 @@ static bool teamsLoopCanBeParallelFor(Stmt *AStmt, Sema &SemaRef) {
return Checker.teamsLoopCanBeParallelFor();
}
+/// Looks through parentheses and implicit casts on \p Cond. If what remains
+/// is a DeclRefExpr naming an OMPCapturedExprDecl, returns that declaration's
+/// initializer expression; otherwise returns nullptr.
+static Expr *getInitialExprFromCapturedExpr(Expr *Cond) {
+  auto *Ref = dyn_cast<DeclRefExpr>(Cond->IgnoreParenImpCasts());
+  if (!Ref)
+    return nullptr;
+
+  auto *Captured = dyn_cast<OMPCapturedExprDecl>(Ref->getDecl());
+  return Captured ? Captured->getInit() : nullptr;
+}
+
+// Declared ahead of its definition because cloneAssociatedStmt(), defined
+// next, calls it.
+static Expr *replaceWithNewTraitsOrDirectCall(const ASTContext &Context,
+                                              Expr *AssocExpr,
+                                              SemaOpenMP *SemaPtr,
+                                              bool NoContext);
+
+/// cloneAssociatedStmt() clones the associated statement of a directive and
+/// rewrites the clone (via replaceWithNewTraitsOrDirectCall()), so that the
+/// original associated statement is never modified.
+///
+/// \param Context   AST context used to allocate the new nodes.
+/// \param StmtP     the directive's associated statement; only a
+///                  CapturedStmt is cloned.
+/// \param SemaPtr   OpenMP Sema, forwarded to the call rewriter.
+/// \param NoContext true when handling the "nocontext" clause.
+/// \returns the new CapturedStmt; an empty (null, not invalid) result when
+/// \p StmtP is not a CapturedStmt; StmtError() if the captured statement is
+/// not an expression or could not be rewritten.
+static StmtResult cloneAssociatedStmt(const ASTContext &Context, Stmt *StmtP,
+                                      SemaOpenMP *SemaPtr, bool NoContext) {
+  auto *AssocStmt = dyn_cast<CapturedStmt>(StmtP);
+  if (!AssocStmt)
+    return StmtResult();
+
+  CapturedDecl *CDecl = AssocStmt->getCapturedDecl();
+
+  // The rewriter only understands expressions (a call, or a binary operator
+  // whose RHS is the call). Passing it a null Expr would trip dyn_cast.
+  auto *AssocExpr = dyn_cast<Expr>(AssocStmt->getCapturedStmt());
+  if (!AssocExpr)
+    return StmtError();
+
+  Expr *NewExpr =
+      replaceWithNewTraitsOrDirectCall(Context, AssocExpr, SemaPtr, NoContext);
+  if (!NewExpr)
+    return StmtError();
+
+  // Copy the current CapturedDecl into a new one whose body is the rewritten
+  // expression, reusing the original parameters (including the context one).
+  CapturedDecl *NewDecl =
+      CapturedDecl::Create(const_cast<ASTContext &>(Context),
+                           CDecl->getDeclContext(), CDecl->getNumParams());
+  NewDecl->setBody(NewExpr);
+  for (unsigned I : llvm::seq<unsigned>(CDecl->getNumParams())) {
+    if (I != CDecl->getContextParamPosition())
+      NewDecl->setParam(I, CDecl->getParam(I));
+    else
+      NewDecl->setContextParam(I, CDecl->getContextParam());
+  }
+
+  // The new CapturedStmt reuses the original captures and capture record.
+  SmallVector<CapturedStmt::Capture, 4> Captures(AssocStmt->captures());
+  SmallVector<Expr *, 4> CaptureInits(AssocStmt->capture_inits());
+
+  // Wrap the rewritten expression, the same statement used as NewDecl's body.
+  // Wrapping the original statement here would leave the CapturedStmt and
+  // its CapturedDecl disagreeing about what the region contains.
+  return CapturedStmt::Create(
+      Context, NewExpr, AssocStmt->getCapturedRegionKind(), Captures,
+      CaptureInits, NewDecl,
+      const_cast<RecordDecl *>(AssocStmt->getCapturedRecordDecl()));
+}
+
+/// replaceWithNewTraitsOrDirectCall() transforms the call traits of the
+/// expression associated with a dispatch construct.
+///
+/// Call traits attached to the function call are removed, and the call is
+/// replaced with a direct call. For the "nocontext" clause only (\p NoContext
+/// is true), the direct call is passed through ActOnOpenMPCall() again so
+/// that it gets call traits for a non-dispatch variant.
+///
+/// If \p AssocExpr is a BinaryOperator (e.g. "x = foo(i, j)"), the operator is
+/// copied and only the copy's RHS is rewritten; \p AssocExpr itself is left
+/// unmodified.
+///
+/// \returns the rewritten expression, or nullptr if ActOnOpenMPCall() does
+/// not produce a usable result for the "nocontext" case.
+static Expr *replaceWithNewTraitsOrDirectCall(const ASTContext &Context,
+                                              Expr *AssocExpr,
+                                              SemaOpenMP *SemaPtr,
+                                              bool NoContext) {
+  BinaryOperator *BinaryCopyOpr = nullptr;
+  Expr *PseudoObjExprOrCall = AssocExpr;
+  if (auto *BinOprExpr = dyn_cast<BinaryOperator>(AssocExpr)) {
+    // Copy the operator, keeping the original's FP options, so that its RHS
+    // can be replaced without touching the original expression.
+    BinaryCopyOpr = BinaryOperator::Create(
+        Context, BinOprExpr->getLHS(), BinOprExpr->getRHS(),
+        BinOprExpr->getOpcode(), BinOprExpr->getType(),
+        BinOprExpr->getValueKind(), BinOprExpr->getObjectKind(),
+        BinOprExpr->getOperatorLoc(), BinOprExpr->getFPFeatures());
+    PseudoObjExprOrCall = BinaryCopyOpr->getRHS();
+  }
+
+  // Change a PseudoObjectExpr to a direct call: its syntactic form is the
+  // call as written, without the attached call traits.
+  Expr *DirectCall = PseudoObjExprOrCall;
+  if (auto *PseudoObjExpr = dyn_cast<PseudoObjectExpr>(PseudoObjExprOrCall))
+    DirectCall = PseudoObjExpr->getSyntacticForm();
+
+  // Without "nocontext" (the novariants case) the direct call is used as-is.
+  Expr *FinalCall = DirectCall;
+  if (NoContext) {
+    // example to explain the changes done for "nocontext" clause:
+    //
+    // #pragma omp declare variant(foo_variant_dispatch)
+    //                     match(construct = {dispatch})
+    // #pragma omp declare variant(foo_variant_allCond)
+    //                     match(user = {condition(1)})
+    // ...
+    // #pragma omp dispatch nocontext(cond_true)
+    //   foo(i, j); // with traits: CodeGen call to
+    //              // foo_variant_dispatch(i,j)
+    // dispatch construct is changed to:
+    // if (cond_true) {
+    //   foo(i,j) // with traits: CodeGen call to foo_variant_allCond(i,j)
+    // } else {
+    //   #pragma omp dispatch
+    //   foo(i,j) // with traits: CodeGen call to foo_variant_dispatch(i,j)
+    // }
+
+    // NOTE(review): cast<> asserts if the (RHS) expression is not a bare
+    // CallExpr, e.g. when an assignment's RHS is wrapped in an implicit
+    // conversion -- confirm the dispatch Sema checks rule that out.
+    auto *CallExprWithinStmt = cast<CallExpr>(DirectCall);
+
+    // ActOnOpenMPCall() adds traits to a simple function call, e.g. variant
+    // call traits to "foo(i,j)", if they are present.
+    ExprResult ER = SemaPtr->ActOnOpenMPCall(
+        CallExprWithinStmt, SemaPtr->SemaRef.getCurScope(),
+        CallExprWithinStmt->getBeginLoc(),
+        MultiExprArg(CallExprWithinStmt->getArgs(),
+                     CallExprWithinStmt->getNumArgs()),
+        CallExprWithinStmt->getRParenLoc(), /*ExecConfig=*/nullptr);
+    // Never build a BinaryOperator with a null RHS; report failure instead.
+    if (!ER.isUsable())
+      return nullptr;
+    FinalCall = ER.get();
+  }
+
+  if (BinaryCopyOpr) {
+    BinaryCopyOpr->setRHS(FinalCall);
+    return BinaryCopyOpr;
+  }
+
+  return FinalCall;
+}
+
+/// Builds a new CompoundStmt containing \p FirstStmt followed by
+/// \p SecondStmt. The compound has default FP options and no source
+/// locations.
+static StmtResult combine2Stmts(ASTContext &Context, Stmt *FirstStmt,
+                                Stmt *SecondStmt) {
+  Stmt *const Body[] = {FirstStmt, SecondStmt};
+  return CompoundStmt::Create(Context, Body, FPOptionsOverride(),
+                              SourceLocation(), SourceLocation());
+}
+
+/// Returns true if at least one clause in \p Clauses is a \p SpecificClause.
+template <typename SpecificClause>
+static bool hasClausesOfKind(ArrayRef<OMPClause *> Clauses) {
+  return llvm::any_of(Clauses, [](const OMPClause *Clause) {
+    return isa<SpecificClause>(Clause);
+  });
+}
+
+StmtResult SemaOpenMP::transformDispatchDirective(
+ OpenMPDirectiveKind Kind, const DeclarationNameInfo &DirName,
+ OpenMPDirectiveKind CancelRegion, ArrayRef<OMPClause *> Clauses,
+ Stmt *AStmt, SourceLocation StartLoc, SourceLocation EndLoc) {
+
+ StmtResult RetValue;
+ llvm::SmallVector<OMPClause *, 8> DependClauseVector;
+ for (const OMPDependClause *ConstDependClause :
+ OMPExecutableDirective::getClausesOfKind<OMPDependClause>(Clauses)) {
+ auto *DependClause = const_cast<OMPDependClause *>(ConstDependClause);
+ DependClauseVector.push_back(DependClause);
+ }
+
+ // #pragma omp dispatch depend() is changed to #pragma omp taskwait depend()
+ // This is done by calling ActOnOpenMPExecutableDirective() for the
+ // new taskwait directive.
+ StmtResult DispatchDepend2taskwait = ActOnOpenMPExecutableDirective(
+ OMPD_taskwait, DirName, CancelRegion, DependClauseVector, NULL, StartLoc,
+ EndLoc);
+
+ if (OMPExecutableDirective::getSingleClause<OMPNovariantsClause>(Clauses)) {
+
+ if (OMPExecutableDirective::getSingleClause<OMPNocontextClause>(Clauses)) {
+ Diag(StartLoc, diag::warn_omp_dispatch_clause_novariants_nocontext);
+ }
+
+ const OMPNovariantsClause *NoVariantsC =
+ OMPExecutableDirective::getSingleClause<OMPNovariantsClause>(Clauses);
+ // #pragma omp dispatch novariants(c2) depend(out: x)
+ // foo();
+ // becomes:
+ // #pragma omp taskwait depend(out: x)
+ // if (c2) {
+ // foo();
+ // } else {
+ // #pragma omp dispatch
+ // foo(); <--- foo() is replaced with foo_variant() in CodeGen
+ // }
+ Expr *Cond = getInitialExprFromCapturedExpr(NoVariantsC->getCondition());
----------------
SunilKuravinakop wrote:
I have modified the comments to give a proper understanding.
https://github.com/llvm/llvm-project/pull/117904
More information about the llvm-commits
mailing list