[clang] [llvm] Assume (PR #97535)
via llvm-commits
llvm-commits at lists.llvm.org
Wed Jul 3 12:13:02 PDT 2024
================
@@ -7323,6 +7324,69 @@ void SemaOpenMP::ActOnFinishedFunctionDefinitionInOpenMPAssumeScope(Decl *D) {
FD->addAttr(AA);
}
+class OMPAssumeStmtVisitor : public StmtVisitor<OMPAssumeStmtVisitor> {
+ SmallVector<OMPAssumeAttr *, 4> *OMPAssumeScoped;
+
+public:
+ OMPAssumeStmtVisitor(SmallVector<OMPAssumeAttr *, 4> *OMPAssumeScoped) {
+ this->OMPAssumeScoped = OMPAssumeScoped;
+ }
+
+ void VisitCapturedStmt(CapturedStmt *CS) {
+ // To find the CaptureDecl for the CaptureStmt
+ CapturedDecl *CD = CS->getCapturedDecl();
+ if (CD) {
+ for (OMPAssumeAttr *AA : *OMPAssumeScoped)
+ CD->addAttr(AA);
+ }
+ }
+
+ void VisitCompoundStmt(CompoundStmt *CS) {
+ // Handle CompoundStmt
+ // Visit each statement in the CompoundStmt
+ for (Stmt *SubStmt : CS->body()) {
+ if (Expr *CE = dyn_cast<Expr>(SubStmt)) {
+ // If the statement is a Expr, process it
+ VisitExpr(CE);
+ }
+ }
+ }
+
+ void VisitExpr(Expr *CE) {
+ // Handle all Expr
+ for (auto *Child : CE->children()) {
+ Visit(Child);
+ }
+ }
+
+ void Visit(Stmt *S) {
+ const char *CName = S->getStmtClassName();
+ if ((strstr(CName, "OMP") != NULL) &&
+ (strstr(CName, "Directive") != NULL)) {
+ for (Stmt *Child : S->children()) {
+ auto *CS = dyn_cast<CapturedStmt>(Child);
+ if (CS)
+ VisitCapturedStmt(CS);
+ else
+ StmtVisitor<OMPAssumeStmtVisitor>::Visit(Child);
+ }
+ } else {
+ StmtVisitor<OMPAssumeStmtVisitor>::Visit(S);
+ }
+ }
+};
+
+StmtResult
+SemaOpenMP::ActOnFinishedStatementInOpenMPAssumeScope(Stmt *AssociatedStmt) {
+
+ if (AssociatedStmt) {
+ // Add the AssumeAttr to the Directive associated with the Assume Directive.
+ OMPAssumeStmtVisitor Visitor(&OMPAssumeScoped);
+ Visitor.Visit(AssociatedStmt);
----------------
SunilKuravinakop wrote:
If the example is:
```
// assume is for the "simd" region
#pragma omp assume no_openmp
#pragma omp simd
for (int i = 0; i < N; ++i){
A[i] += B[i];
}
```
then in the `-ast-dump` output, under:
```
|-OMPSimdDirective 0x15cc6a0 <line:22:3, col:19>
| `-CapturedStmt 0x15b0d98 <line:23:3, line:25:3>
| `-CapturedDecl 0x15b0960 <<invalid sloc>> <invalid sloc>
| |
```
```
| |-OMPAssumeAttr 0x15b07b0 <line:21:15> "omp_no_openmp"
```
needs to be added under the `CapturedDecl`. Also, according to the spec:
`The scope of the assume directive is the code executed in the corresponding region or in any region that is nested in the corresponding region. `
Is there a better way to add this than having
`OMPAssumeStmtVisitor::VisitCapturedStmt()`?
https://github.com/llvm/llvm-project/pull/97535
More information about the llvm-commits
mailing list