[clang] [llvm] Support for dispatch construct (Sema & Codegen) support. (PR #117904)
Alexey Bataev via llvm-commits
llvm-commits at lists.llvm.org
Wed Nov 27 09:22:36 PST 2024
================
@@ -5965,6 +5967,244 @@ static bool teamsLoopCanBeParallelFor(Stmt *AStmt, Sema &SemaRef) {
return Checker.teamsLoopCanBeParallelFor();
}
+static Expr *getInitialExprFromCapturedExpr(Expr *Cond) {
+
+ Expr *SubExpr = Cond->IgnoreParenImpCasts();
+
+ if (auto *DeclRef = dyn_cast<DeclRefExpr>(SubExpr)) {
+ if (auto *CapturedExprDecl =
+ dyn_cast<OMPCapturedExprDecl>(DeclRef->getDecl())) {
+
+ // Retrieve the initial expression from the captured expression
+ return CapturedExprDecl->getInit();
+ }
+ }
+ return nullptr;
+}
+
+static Expr *copy_removePseudoObjectExpr(const ASTContext &Context, Expr *E,
+ SemaOpenMP *SemaPtr, bool NoContext) {
+
+ BinaryOperator *BinaryCopyOpr = NULL;
+ bool BinaryOp = false;
+ if (E->getStmtClass() == Stmt::BinaryOperatorClass) {
+ BinaryOp = true;
+ BinaryOperator *E_BinOpr = static_cast<BinaryOperator *>(E);
+ BinaryCopyOpr = BinaryOperator::Create(
+ Context, E_BinOpr->getLHS(), E_BinOpr->getRHS(), E_BinOpr->getOpcode(),
+ E_BinOpr->getType(), E_BinOpr->getValueKind(),
+ E_BinOpr->getObjectKind(), E_BinOpr->getOperatorLoc(),
+ FPOptionsOverride());
+ E = BinaryCopyOpr->getRHS();
+ }
+
+ // Change PseudoObjectExpr to a direct call
+ if (PseudoObjectExpr *PO = dyn_cast<PseudoObjectExpr>(E))
+ E = *((PO->semantics_begin()) - 1);
+
+ // Add new Traits to direct call to convert it to new PseudoObjectExpr
+ // This converts Traits for the function call from under "dispatch" to traits
+ // of direct function call not under "dispatch".
+ if (NoContext) {
+ // Convert StmtResult to a CallExpr before calling ActOnOpenMPCall()
+ CallExpr *CallExprWithinStmt = dyn_cast<CallExpr>(E);
+ int NumArgs = CallExprWithinStmt->getNumArgs();
+ clang::Expr **Args = CallExprWithinStmt->getArgs();
+ ExprResult er = SemaPtr->ActOnOpenMPCall(
+ CallExprWithinStmt, SemaPtr->SemaRef.getCurScope(),
+ CallExprWithinStmt->getBeginLoc(), MultiExprArg(Args, NumArgs),
+ CallExprWithinStmt->getRParenLoc(), static_cast<Expr *>(nullptr));
+ E = er.get();
+ }
+
+ if (BinaryOp) {
+ BinaryCopyOpr->setRHS(E);
+ return BinaryCopyOpr;
+ }
+
+ return E;
+}
+
+static StmtResult combine2Stmts(ASTContext &context, Stmt *first,
+ Stmt *second) {
+
+ llvm::SmallVector<Stmt *, 2> newCombinedStmts;
+ newCombinedStmts.push_back(first);
+ newCombinedStmts.push_back(second); // Adding foo();
+ llvm::ArrayRef<Stmt *> ar(newCombinedStmts);
+ CompoundStmt *CombinedStmt = CompoundStmt::Create(
+ context, ar, FPOptionsOverride(), SourceLocation(), SourceLocation());
+ StmtResult FinalStmts(CombinedStmt);
+ return FinalStmts;
+}
+
+template <typename SpecificClause>
+static bool hasClausesOfKind(ArrayRef<OMPClause *> Clauses) {
+ auto ClausesOfKind =
+ OMPExecutableDirective::getClausesOfKind<SpecificClause>(Clauses);
+ return ClausesOfKind.begin() != ClausesOfKind.end();
+}
+
+// Get a CapturedStmt with direct call to function.
+// If there is a PseudoObjectExpr under the CapturedDecl
+// choose the first call under it for the direct call to function
+static StmtResult CloneNewCapturedStmtForDirectCall(const ASTContext &Context,
+ Stmt *StmtP,
+ SemaOpenMP *SemaPtr,
+ bool NoContext) {
+ if (StmtP->getStmtClass() == Stmt::CapturedStmtClass) {
+ CapturedStmt *AStmt = static_cast<CapturedStmt *>(StmtP);
+ CapturedDecl *CDecl = AStmt->getCapturedDecl();
+ Stmt *S = cast<CapturedStmt>(AStmt)->getCapturedStmt();
+ auto *E = dyn_cast<Expr>(S);
+ E = copy_removePseudoObjectExpr(Context, E, SemaPtr, NoContext);
+
+ // Copy Current Captured Decl to a New Captured Decl for noting the
+ // Annotation
+ CapturedDecl *NewDecl =
+ CapturedDecl::Create(const_cast<ASTContext &>(Context),
+ CDecl->getDeclContext(), CDecl->getNumParams());
+ NewDecl->setBody(static_cast<Stmt *>(E));
+ for (unsigned i = 0; i < CDecl->getNumParams(); ++i) {
+ if (i != CDecl->getContextParamPosition())
+ NewDecl->setParam(i, CDecl->getParam(i));
+ else
+ NewDecl->setContextParam(i, CDecl->getContextParam());
+ }
+
+ // Create a New Captured Stmt containing the New Captured Decl
+ SmallVector<CapturedStmt::Capture, 4> Captures;
+ SmallVector<Expr *, 4> CaptureInits;
+ for (auto capture : AStmt->captures())
+ Captures.push_back(capture);
+ for (auto capture_init : AStmt->capture_inits())
+ CaptureInits.push_back(capture_init);
+ CapturedStmt *NewStmt = CapturedStmt::Create(
+ Context, AStmt->getCapturedStmt(), AStmt->getCapturedRegionKind(),
+ Captures, CaptureInits, NewDecl,
+ const_cast<RecordDecl *>(AStmt->getCapturedRecordDecl()));
+
+ return NewStmt;
+ }
+ return static_cast<Stmt *>(NULL);
+}
+
+StmtResult SemaOpenMP::transformDispatchDirective(
+ OpenMPDirectiveKind Kind, const DeclarationNameInfo &DirName,
+ OpenMPDirectiveKind CancelRegion, ArrayRef<OMPClause *> Clauses,
+ Stmt *AStmt, SourceLocation StartLoc, SourceLocation EndLoc) {
+
+ StmtResult RetValue;
+ std::vector<OMPClause *> DependVector;
----------------
alexey-bataev wrote:
Use SmallVector
https://github.com/llvm/llvm-project/pull/117904
More information about the llvm-commits mailing list