[clang] [llvm] [Clang][OpenMP] Support for dispatch construct (Sema & Codegen) support (PR #117904)

Shilei Tian via llvm-commits llvm-commits at lists.llvm.org
Tue Dec 10 10:46:26 PST 2024


================
@@ -5965,6 +5967,266 @@ static bool teamsLoopCanBeParallelFor(Stmt *AStmt, Sema &SemaRef) {
   return Checker.teamsLoopCanBeParallelFor();
 }
 
+/// Looks through parentheses and implicit casts on \p Cond and, if what
+/// remains is a reference to an OMPCapturedExprDecl, returns that
+/// declaration's initializer. Returns nullptr for any other expression.
+static Expr *getInitialExprFromCapturedExpr(Expr *Cond) {
+  auto *Ref = dyn_cast<DeclRefExpr>(Cond->IgnoreParenImpCasts());
+  if (!Ref)
+    return nullptr;
+
+  // The initializer of the captured-expression declaration is the expression
+  // that was originally captured.
+  auto *Captured = dyn_cast<OMPCapturedExprDecl>(Ref->getDecl());
+  return Captured ? Captured->getInit() : nullptr;
+}
+
+// Defined below; forward-declared because cloneAssociatedStmt() calls it.
+static Expr *replaceWithNewTraitsOrDirectCall(const ASTContext &Context, Expr *,
+                                              SemaOpenMP *, bool);
+
+/// Clones the statement associated with a directive and rewrites the clone,
+/// so that the original associated statement is left unmodified.
+///
+/// \param Context   AST context used to allocate the new nodes.
+/// \param StmtP     The associated statement; only a CapturedStmt whose
+///                  captured statement is an expression is handled.
+/// \param SemaPtr   OpenMP Sema instance, forwarded to
+///                  replaceWithNewTraitsOrDirectCall().
+/// \param NoContext Forwarded to replaceWithNewTraitsOrDirectCall().
+/// \returns A new CapturedStmt that shares the captures, capture
+///          initializers, region kind and record declaration of \p StmtP but
+///          owns a fresh CapturedDecl whose body is the rewritten expression.
+///          Returns a null statement when \p StmtP cannot be handled.
+static StmtResult cloneAssociatedStmt(const ASTContext &Context, Stmt *StmtP,
+                                      SemaOpenMP *SemaPtr, bool NoContext) {
+  auto *AssocStmt = dyn_cast<CapturedStmt>(StmtP);
+  if (!AssocStmt)
+    return static_cast<Stmt *>(nullptr);
+
+  // Only an expression can be rewritten. Bail out instead of handing a null
+  // pointer to replaceWithNewTraitsOrDirectCall(), which casts its argument
+  // unconditionally.
+  auto *AssocExpr = dyn_cast<Expr>(AssocStmt->getCapturedStmt());
+  if (!AssocExpr)
+    return static_cast<Stmt *>(nullptr);
+
+  Expr *NewCallOrPseudoObjOrBinExpr = replaceWithNewTraitsOrDirectCall(
+      Context, AssocExpr, SemaPtr, NoContext);
+
+  // Copy the current CapturedDecl into a new CapturedDecl that carries the
+  // rewritten expression as its body. Parameters are shared with the
+  // original declaration; the context parameter keeps its position.
+  CapturedDecl *CDecl = AssocStmt->getCapturedDecl();
+  CapturedDecl *NewDecl =
+      CapturedDecl::Create(const_cast<ASTContext &>(Context),
+                           CDecl->getDeclContext(), CDecl->getNumParams());
+  NewDecl->setBody(NewCallOrPseudoObjOrBinExpr);
+  for (unsigned I : llvm::seq<unsigned>(CDecl->getNumParams())) {
+    if (I != CDecl->getContextParamPosition())
+      NewDecl->setParam(I, CDecl->getParam(I));
+    else
+      NewDecl->setContextParam(I, CDecl->getContextParam());
+  }
+
+  // Create a new CapturedStmt that refers to the new CapturedDecl.
+  SmallVector<CapturedStmt::Capture, 4> Captures(AssocStmt->captures());
+  SmallVector<Expr *, 4> CaptureInits(AssocStmt->capture_inits());
+  // NOTE(review): the statement stored in the new CapturedStmt is still the
+  // original captured statement; only NewDecl's body holds the rewritten
+  // expression. Confirm that consumers read the body from the CapturedDecl.
+  auto *NewStmt = CapturedStmt::Create(
+      Context, AssocStmt->getCapturedStmt(),
+      AssocStmt->getCapturedRegionKind(), Captures, CaptureInits, NewDecl,
+      const_cast<RecordDecl *>(AssocStmt->getCapturedRecordDecl()));
+
+  return NewStmt;
+}
+
+/// Rewrites the expression associated with a dispatch construct.
+///
+/// \p AssocExpr is either the call itself or a binary operator whose
+/// right-hand side is the call. A binary operator is shallow-copied so the
+/// original node is not mutated. If the call is wrapped in a
+/// PseudoObjectExpr (i.e. call traits are attached), the wrapper is replaced
+/// by its syntactic form, which is the direct call. For the "nocontext"
+/// clause only (\p NoContext is true), the direct call is then passed through
+/// SemaOpenMP::ActOnOpenMPCall() so that it receives the call traits of a
+/// non-dispatch variant.
+///
+/// \returns The rewritten call, or the copied binary operator with the
+///          rewritten call as its right-hand side. Returns nullptr when
+///          \p AssocExpr is null.
+static Expr *replaceWithNewTraitsOrDirectCall(const ASTContext &Context,
+                                              Expr *AssocExpr,
+                                              SemaOpenMP *SemaPtr,
+                                              bool NoContext) {
+  if (!AssocExpr)
+    return nullptr;
+
+  BinaryOperator *BinaryCopyOpr = nullptr;
+  Expr *PseudoObjExprOrCall = AssocExpr;
+  if (auto *BinOprExpr = dyn_cast<BinaryOperator>(AssocExpr)) {
+    // Copy the operator, including any floating-point feature overrides
+    // stored on the original, so that only the copy's RHS is replaced.
+    BinaryCopyOpr = BinaryOperator::Create(
+        Context, BinOprExpr->getLHS(), BinOprExpr->getRHS(),
+        BinOprExpr->getOpcode(), BinOprExpr->getType(),
+        BinOprExpr->getValueKind(), BinOprExpr->getObjectKind(),
+        BinOprExpr->getOperatorLoc(), BinOprExpr->getFPFeatures());
+    PseudoObjExprOrCall = BinaryCopyOpr->getRHS();
+  }
+
+  // Change a PseudoObjectExpr to the direct call it was built from.
+  Expr *CallWithoutVariants = PseudoObjExprOrCall;
+  if (auto *PseudoObjExpr = dyn_cast<PseudoObjectExpr>(PseudoObjExprOrCall))
+    CallWithoutVariants = PseudoObjExpr->getSyntacticForm();
+
+  Expr *FinalCall = CallWithoutVariants; // Used as-is for the novariants clause.
+  if (NoContext) {
+    // Example to explain the changes done for the "nocontext" clause:
+    //
+    // #pragma omp declare variant(foo_variant_dispatch)
+    //                     match(construct = {dispatch})
+    // #pragma omp declare variant(foo_variant_allCond)
+    //                     match(user = {condition(1)})
+    // ...
+    //     #pragma omp dispatch nocontext(cond_true)
+    //         foo(i, j); // with traits: CodeGen call to
+    //                    // foo_variant_dispatch(i,j)
+    // The dispatch construct is changed to:
+    // if (cond_true) {
+    //    foo(i,j) // with traits: CodeGen call to foo_variant_allCond(i,j)
+    // } else {
+    //   #pragma omp dispatch
+    //   foo(i,j)  // with traits: CodeGen call to foo_variant_dispatch(i,j)
+    // }
+
+    // ActOnOpenMPCall() adds traits to a simple function call, e.g. variant
+    // function call traits to "foo(i,j)", if they are present.
+    // NOTE(review): a call that is still wrapped (e.g. in an implicit cast)
+    // is left untouched here rather than asserting; confirm whether such
+    // operands need to be unwrapped and rebuilt.
+    if (auto *CallExprWithinStmt = dyn_cast<CallExpr>(CallWithoutVariants)) {
+      ExprResult ER = SemaPtr->ActOnOpenMPCall(
+          CallExprWithinStmt, SemaPtr->SemaRef.getCurScope(),
+          CallExprWithinStmt->getBeginLoc(),
+          MultiExprArg(CallExprWithinStmt->getArgs(),
+                       CallExprWithinStmt->getNumArgs()),
+          CallExprWithinStmt->getRParenLoc(), static_cast<Expr *>(nullptr));
+      // Keep the direct call if no usable replacement was produced, so a
+      // null expression is never installed in the result.
+      if (ER.isUsable())
+        FinalCall = ER.get();
+    }
+  }
+
+  if (BinaryCopyOpr) {
+    BinaryCopyOpr->setRHS(FinalCall);
+    return BinaryCopyOpr;
+  }
+
+  return FinalCall;
+}
+
+/// Wraps \p FirstStmt followed by \p SecondStmt in a newly created
+/// CompoundStmt with default floating-point options and invalid (empty)
+/// brace locations, and returns it as a StmtResult.
+static StmtResult combine2Stmts(ASTContext &Context, Stmt *FirstStmt,
+                                Stmt *SecondStmt) {
+  Stmt *Parts[] = {FirstStmt, SecondStmt};
+  return StmtResult(CompoundStmt::Create(
+      Context, llvm::ArrayRef<Stmt *>(Parts), FPOptionsOverride(),
+      SourceLocation(), SourceLocation()));
+}
+
+template <typename SpecificClause>
+static bool hasClausesOfKind(ArrayRef<OMPClause *> Clauses) {
----------------
shiltian wrote:

I'm surprised that we don't already have an existing utility function for this.

https://github.com/llvm/llvm-project/pull/117904


More information about the llvm-commits mailing list