[clang] [llvm] Assume (PR #97535)

via llvm-commits llvm-commits at lists.llvm.org
Wed Jul 3 12:13:02 PDT 2024


================
@@ -7323,6 +7324,69 @@ void SemaOpenMP::ActOnFinishedFunctionDefinitionInOpenMPAssumeScope(Decl *D) {
     FD->addAttr(AA);
 }
 
+namespace {
+
+/// Attaches the `omp assume` attributes currently in scope to the
+/// CapturedDecl of every CapturedStmt found while walking a statement tree.
+///
+/// The walk covers the whole tree: OpenMP directives nested in compound
+/// statements, loops, branches, expressions, or inside another directive's
+/// captured region are all reached. This follows the spec wording that the
+/// scope of an assume directive includes "any region that is nested in the
+/// corresponding region".
+class OMPAssumeStmtVisitor : public StmtVisitor<OMPAssumeStmtVisitor> {
+  /// Attributes to attach. Not owned by the visitor and only read here.
+  SmallVector<OMPAssumeAttr *, 4> *OMPAssumeScoped;
+
+  /// Visit every child of \p S (null children are skipped by Visit).
+  void VisitChildren(Stmt *S) {
+    for (Stmt *Child : S->children())
+      Visit(Child);
+  }
+
+public:
+  /// \param OMPAssumeScoped the assumption attributes to attach; must be
+  /// non-null and must outlive the visitor.
+  explicit OMPAssumeStmtVisitor(
+      SmallVector<OMPAssumeAttr *, 4> *OMPAssumeScoped)
+      : OMPAssumeScoped(OMPAssumeScoped) {}
+
+  /// Entry point: dispatch \p S to the matching Visit* method. Null
+  /// statements (e.g. a for-statement without an init clause) are skipped.
+  /// OpenMP directives need no special case: StmtVisitor routes them to
+  /// VisitStmt, which visits their associated CapturedStmt child and thus
+  /// reaches VisitCapturedStmt.
+  void Visit(Stmt *S) {
+    if (S)
+      StmtVisitor<OMPAssumeStmtVisitor>::Visit(S);
+  }
+
+  /// Attach the in-scope assumptions to the CapturedDecl of \p CS, then walk
+  /// the captured body so directives nested inside the region get them too.
+  void VisitCapturedStmt(CapturedStmt *CS) {
+    if (CapturedDecl *CD = CS->getCapturedDecl())
+      for (OMPAssumeAttr *AA : *OMPAssumeScoped)
+        CD->addAttr(AA);
+    // Fetch the body explicitly via getCapturedStmt() instead of relying on
+    // children(), which for CapturedStmt enumerates the capture initializers.
+    Visit(CS->getCapturedStmt());
+  }
+
+  void VisitCompoundStmt(CompoundStmt *CS) {
+    // Walk every sub-statement, not only expressions: OpenMP directives are
+    // statements, so a directive directly inside `{ ... }` must be visited.
+    VisitChildren(CS);
+  }
+
+  void VisitExpr(Expr *CE) { VisitChildren(CE); }
+
+  /// Fallback for all other statement kinds (loops, branches, OpenMP
+  /// directives, ...): recurse into the children.
+  void VisitStmt(Stmt *S) { VisitChildren(S); }
+};
+
+} // namespace
+
+StmtResult
+SemaOpenMP::ActOnFinishedStatementInOpenMPAssumeScope(Stmt *AssociatedStmt) {
+
+  if (AssociatedStmt) {
+    // Add the AssumeAttr to the Directive associated with the Assume Directive.
+    OMPAssumeStmtVisitor Visitor(&OMPAssumeScoped);
+    Visitor.Visit(AssociatedStmt);
----------------
SunilKuravinakop wrote:

If the example is:
 ```
// assume is for the "simd" region
  #pragma omp assume no_openmp
  #pragma omp simd
  for (int i = 0; i < N; ++i){
    A[i] += B[i];
  }
```
then in the `-ast-dump` output:
```
    |-OMPSimdDirective 0x15cc6a0 <line:22:3, col:19>
    |  `-CapturedStmt 0x15b0d98 <line:23:3, line:25:3>
    |    `-CapturedDecl 0x15b0960 <<invalid sloc>> <invalid sloc>
    |     |
```
the line
```
    |     |-OMPAssumeAttr 0x15b07b0 <line:21:15> "omp_no_openmp"
```
needs to be added under the `CapturedDecl` of the `simd` region. Also, according to the spec:
`The scope of the assume directive is the code executed in the corresponding region or in any region that is nested in the corresponding region. `
Is there a better way to add this attribute than having
`OMPAssumeStmtVisitor::VisitCapturedStmt()`?

https://github.com/llvm/llvm-project/pull/97535


More information about the llvm-commits mailing list