[clang] [llvm] Draft/wip/preview of using leaf constructs in clang (PR #84817)

Krzysztof Parzyszek via llvm-commits llvm-commits at lists.llvm.org
Mon Mar 11 12:49:33 PDT 2024


https://github.com/kparzysz created https://github.com/llvm/llvm-project/pull/84817

This does not handle clauses yet. It only hooks up the infrastructure for handling leaf constructs, and it does not attempt to cover all cases.

>From 45859d3b1fb914d26831086fb7f7e6decdec9010 Mon Sep 17 00:00:00 2001
From: Krzysztof Parzyszek <Krzysztof.Parzyszek at amd.com>
Date: Mon, 11 Mar 2024 12:55:38 -0500
Subject: [PATCH 1/2] [Frontend][OpenMP] Add isCompositeConstruct and
 isCombinedConstruct

Implement helper functions to identify composite and combined constructs.
---
 llvm/include/llvm/Frontend/OpenMP/OMP.h |  5 +++++
 llvm/lib/Frontend/OpenMP/OMP.cpp        | 21 +++++++++++++++++++++
 2 files changed, 26 insertions(+)

diff --git a/llvm/include/llvm/Frontend/OpenMP/OMP.h b/llvm/include/llvm/Frontend/OpenMP/OMP.h
index a85cd9d344c6d7..59737417b4b014 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMP.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMP.h
@@ -15,4 +15,14 @@
 
 #include "llvm/Frontend/OpenMP/OMP.h.inc"
 
+namespace llvm::omp {
+/// Return true if \p D is a compound construct with more than one
+/// loop-associated leaf construct. Returns false for leaf constructs.
+bool isCompositeConstruct(Directive D);
+
+/// Return true if \p D is a compound construct (it has leaf constructs) that
+/// is not composite in the sense of isCompositeConstruct().
+bool isCombinedConstruct(Directive D);
+} // namespace llvm::omp
+
 #endif // LLVM_FRONTEND_OPENMP_OMP_H
diff --git a/llvm/lib/Frontend/OpenMP/OMP.cpp b/llvm/lib/Frontend/OpenMP/OMP.cpp
index 4f2f95392648b3..19327277d22ad8 100644
--- a/llvm/lib/Frontend/OpenMP/OMP.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMP.cpp
@@ -8,6 +8,7 @@
 
 #include "llvm/Frontend/OpenMP/OMP.h"
 
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/ADT/StringSwitch.h"
 #include "llvm/Support/ErrorHandling.h"
@@ -17,3 +18,28 @@ using namespace omp;
 
 #define GEN_DIRECTIVES_IMPL
 #include "llvm/Frontend/OpenMP/OMP.inc"
+
+namespace llvm::omp {
+bool isCompositeConstruct(Directive D) {
+  // OpenMP Spec 5.2: [17.3, 8-9]
+  // If directive-name-A and directive-name-B both correspond to loop-
+  // associated constructs then directive-name is a composite construct.
+  //
+  // NOTE(review): this approximates the rule by counting loop-associated
+  // leaves over the whole leaf list, so e.g. "parallel for simd" is
+  // presumably reported as composite even though its first leaf is not
+  // loop-associated. Code that expands combined constructs into their
+  // leaves may rely on this to avoid splitting "for simd"; confirm.
+  size_t NumLoopConstructs =
+      llvm::count_if(getLeafConstructs(D), [](Directive L) {
+        return getDirectiveAssociation(L) == Association::Loop;
+      });
+  return NumLoopConstructs > 1;
+}
+
+bool isCombinedConstruct(Directive D) {
+  // OpenMP Spec 5.2: [17.3, 9-10]
+  // Otherwise directive-name is a combined construct.
+  return !getLeafConstructs(D).empty() && !isCompositeConstruct(D);
+}
+} // namespace llvm::omp

>From 0f4bf060d080b156708ff6e9d3fa7987e7652b9c Mon Sep 17 00:00:00 2001
From: Krzysztof Parzyszek <Krzysztof.Parzyszek at amd.com>
Date: Thu, 7 Mar 2024 07:10:20 -0600
Subject: [PATCH 2/2] Draft/wip/preview of using leaf constructs

---
 clang/include/clang/Sema/Sema.h |  3 ++
 clang/lib/Basic/OpenMPKinds.cpp | 35 ++++++++++++---------
 clang/lib/Parse/ParseOpenMP.cpp | 28 ++++++++++++++---
 clang/lib/Sema/SemaOpenMP.cpp   | 54 ++++++++++++++++++++-------------
 4 files changed, 79 insertions(+), 41 deletions(-)

diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h
index 267c79cc057cba..e5e62d5ea5e8ac 100644
--- a/clang/include/clang/Sema/Sema.h
+++ b/clang/include/clang/Sema/Sema.h
@@ -13718,6 +13718,12 @@ class Sema final {
   ///
   /// \returns Statement for finished OpenMP region.
   StmtResult ActOnOpenMPRegionEnd(StmtResult S, ArrayRef<OMPClause *> Clauses);
+  /// Like the overload above, but after each captured region of the current
+  /// directive is closed, \p Callback is applied to the resulting statement
+  /// and its return value replaces that statement.
+  StmtResult ActOnOpenMPRegionEnd(
+      StmtResult S, ArrayRef<OMPClause *> Clauses,
+      std::function<StmtResult(StmtResult)> Callback);
   StmtResult ActOnOpenMPExecutableDirective(
       OpenMPDirectiveKind Kind, const DeclarationNameInfo &DirName,
       OpenMPDirectiveKind CancelRegion, ArrayRef<OMPClause *> Clauses,
diff --git a/clang/lib/Basic/OpenMPKinds.cpp b/clang/lib/Basic/OpenMPKinds.cpp
index b3e9affbb3e58a..289da880400c80 100644
--- a/clang/lib/Basic/OpenMPKinds.cpp
+++ b/clang/lib/Basic/OpenMPKinds.cpp
@@ -706,6 +706,13 @@ void clang::getOpenMPCaptureRegions(
     SmallVectorImpl<OpenMPDirectiveKind> &CaptureRegions,
     OpenMPDirectiveKind DKind) {
   assert(unsigned(DKind) < llvm::omp::Directive_enumSize);
+  // A combined construct gets the capture regions of its leaf constructs, in
+  // leaf order. Composite constructs are handled by the switch below.
+  if (isCombinedConstruct(DKind)) {
+    for (OpenMPDirectiveKind Leaf : getLeafConstructs(DKind))
+      getOpenMPCaptureRegions(CaptureRegions, Leaf);
+    return;
+  }
   switch (DKind) {
   case OMPD_metadirective:
     CaptureRegions.push_back(OMPD_metadirective);
@@ -713,15 +720,15 @@ void clang::getOpenMPCaptureRegions(
   case OMPD_parallel:
   case OMPD_parallel_for:
   case OMPD_parallel_for_simd:
-  case OMPD_parallel_master:
-  case OMPD_parallel_masked:
-  case OMPD_parallel_sections:
+  // case OMPD_parallel_master:
+  // case OMPD_parallel_masked:
+  // case OMPD_parallel_sections:
   case OMPD_distribute_parallel_for:
   case OMPD_distribute_parallel_for_simd:
-  case OMPD_parallel_loop:
+  // case OMPD_parallel_loop:
     CaptureRegions.push_back(OMPD_parallel);
     break;
-  case OMPD_target_teams:
+  // case OMPD_target_teams:
   case OMPD_target_teams_distribute:
   case OMPD_target_teams_distribute_simd:
     CaptureRegions.push_back(OMPD_task);
@@ -729,8 +736,10 @@ void clang::getOpenMPCaptureRegions(
     CaptureRegions.push_back(OMPD_teams);
     break;
   case OMPD_teams:
-  case OMPD_teams_distribute:
+  // case OMPD_teams_distribute:
+  // teams_distribute_simd has two loop-associated leaves (distribute and
+  // simd), so it is composite rather than combined and keeps its case here.
   case OMPD_teams_distribute_simd:
     CaptureRegions.push_back(OMPD_teams);
     break;
   case OMPD_target:
@@ -738,16 +747,16 @@ void clang::getOpenMPCaptureRegions(
     CaptureRegions.push_back(OMPD_task);
     CaptureRegions.push_back(OMPD_target);
     break;
-  case OMPD_teams_loop:
+  // case OMPD_teams_loop:
   case OMPD_teams_distribute_parallel_for:
   case OMPD_teams_distribute_parallel_for_simd:
     CaptureRegions.push_back(OMPD_teams);
     CaptureRegions.push_back(OMPD_parallel);
     break;
-  case OMPD_target_parallel:
+  // case OMPD_target_parallel:
   case OMPD_target_parallel_for:
   case OMPD_target_parallel_for_simd:
-  case OMPD_target_parallel_loop:
+  // case OMPD_target_parallel_loop:
     CaptureRegions.push_back(OMPD_task);
     CaptureRegions.push_back(OMPD_target);
     CaptureRegions.push_back(OMPD_parallel);
@@ -760,20 +769,20 @@ void clang::getOpenMPCaptureRegions(
     break;
   case OMPD_taskloop:
   case OMPD_taskloop_simd:
-  case OMPD_master_taskloop:
+  // case OMPD_master_taskloop:
   case OMPD_master_taskloop_simd:
-  case OMPD_masked_taskloop:
+  // case OMPD_masked_taskloop:
   case OMPD_masked_taskloop_simd:
     CaptureRegions.push_back(OMPD_taskloop);
     break;
-  case OMPD_parallel_masked_taskloop:
+  // case OMPD_parallel_masked_taskloop:
   case OMPD_parallel_masked_taskloop_simd:
-  case OMPD_parallel_master_taskloop:
+  // case OMPD_parallel_master_taskloop:
   case OMPD_parallel_master_taskloop_simd:
     CaptureRegions.push_back(OMPD_parallel);
     CaptureRegions.push_back(OMPD_taskloop);
     break;
-  case OMPD_target_teams_loop:
+  // case OMPD_target_teams_loop:
   case OMPD_target_teams_distribute_parallel_for:
   case OMPD_target_teams_distribute_parallel_for_simd:
     CaptureRegions.push_back(OMPD_task);
diff --git a/clang/lib/Parse/ParseOpenMP.cpp b/clang/lib/Parse/ParseOpenMP.cpp
index 814126e321d3bc..711c3fb2ba9c1a 100644
--- a/clang/lib/Parse/ParseOpenMP.cpp
+++ b/clang/lib/Parse/ParseOpenMP.cpp
@@ -2952,8 +2952,14 @@ StmtResult Parser::ParseOpenMPDeclarativeOrExecutableDirective(
 
     StmtResult AssociatedStmt;
     if (HasAssociatedStatement) {
+      ArrayRef<OpenMPDirectiveKind> Leaves = getLeafConstructs(DKind);
       // The body is a block scope like in Lambdas and Blocks.
-      Actions.ActOnOpenMPRegionStart(DKind, getCurScope());
+      if (isCombinedConstruct(DKind)) {
+        for (OpenMPDirectiveKind Leaf : Leaves)
+          Actions.ActOnOpenMPRegionStart(Leaf, getCurScope());
+      } else {
+        Actions.ActOnOpenMPRegionStart(DKind, getCurScope());
+      }
       // FIXME: We create a bogus CompoundStmt scope to hold the contents of
       // the captured region. Code elsewhere assumes that any FunctionScopeInfo
       // should have at least one compound statement scope within it.
@@ -2966,7 +2972,37 @@ StmtResult Parser::ParseOpenMPDeclarativeOrExecutableDirective(
             getLangOpts().OpenMPIRBuilder)
           AssociatedStmt = Actions.ActOnOpenMPLoopnest(AssociatedStmt.get());
       }
-      AssociatedStmt = Actions.ActOnOpenMPRegionEnd(AssociatedStmt, Clauses);
+      if (isCombinedConstruct(DKind)) {
+        // A leaf construct may own several capture regions (e.g. "target"
+        // owns a "task" and a "target" region), so record the owning leaf
+        // of every region, in the order getOpenMPCaptureRegions lists them.
+        SmallVector<size_t, 8> RegionLeaf;
+        for (size_t I = 0, E = Leaves.size(); I != E; ++I) {
+          SmallVector<OpenMPDirectiveKind, 4> LeafRegions;
+          getOpenMPCaptureRegions(LeafRegions, Leaves[I]);
+          RegionLeaf.append(LeafRegions.size(), I);
+        }
+        // The callback runs once per closed capture region, presumably
+        // innermost first. Build a leaf's directive once its outermost
+        // region is closed; pass the leaf's inner regions through unchanged.
+        size_t Pos = RegionLeaf.size();
+        AssociatedStmt = Actions.ActOnOpenMPRegionEnd(
+            AssociatedStmt, Clauses, [&](StmtResult S) -> StmtResult {
+              assert(Pos > 0 && "unexpected number of capture regions");
+              if (Pos == 0)
+                return S;
+              --Pos;
+              bool IsOutermostOfLeaf =
+                  Pos == 0 || RegionLeaf[Pos - 1] != RegionLeaf[Pos];
+              if (S.isInvalid() || !IsOutermostOfLeaf)
+                return S;
+              return Actions.ActOnOpenMPExecutableDirective(
+                  Leaves[RegionLeaf[Pos]], DirName, CancelRegion, Clauses,
+                  S.get(), Loc, EndLoc);
+            });
+      } else {
+        AssociatedStmt = Actions.ActOnOpenMPRegionEnd(AssociatedStmt, Clauses);
+      }
     } else if (DKind == OMPD_target_update || DKind == OMPD_target_enter_data ||
                DKind == OMPD_target_exit_data) {
       Actions.ActOnOpenMPRegionStart(DKind, getCurScope());
@@ -2975,9 +3011,15 @@ StmtResult Parser::ParseOpenMPDeclarativeOrExecutableDirective(
                                                   /*isStmtExpr=*/false));
       AssociatedStmt = Actions.ActOnOpenMPRegionEnd(AssociatedStmt, Clauses);
     }
-    Directive = Actions.ActOnOpenMPExecutableDirective(
-        DKind, DirName, CancelRegion, Clauses, AssociatedStmt.get(), Loc,
-        EndLoc);
+    if (isCombinedConstruct(DKind)) {
+      // The leaf directives were already built by the region-end callback;
+      // the outermost one stands for the whole combined construct.
+      Directive = AssociatedStmt;
+    } else {
+      Directive = Actions.ActOnOpenMPExecutableDirective(
+          DKind, DirName, CancelRegion, Clauses, AssociatedStmt.get(), Loc,
+          EndLoc);
+    }
 
     // Exit scope.
     Actions.EndOpenMPDSABlock(Directive.get());
diff --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp
index 0cc0cbacb37548..f9edacb7cc782b 100644
--- a/clang/lib/Sema/SemaOpenMP.cpp
+++ b/clang/lib/Sema/SemaOpenMP.cpp
@@ -4194,16 +4194,21 @@ static void handleDeclareVariantConstructTrait(DSAStackTy *Stack,
 }
 
 void Sema::ActOnOpenMPRegionStart(OpenMPDirectiveKind DKind, Scope *CurScope) {
+  if (isCombinedConstruct(DKind)) { // Open each leaf's regions in turn.
+    for (OpenMPDirectiveKind Leaf : getLeafConstructs(DKind))
+      ActOnOpenMPRegionStart(Leaf, CurScope);
+    return;
+  }
   switch (DKind) {
   case OMPD_parallel:
-  case OMPD_parallel_for:
+  // case OMPD_parallel_for:
   case OMPD_parallel_for_simd:
-  case OMPD_parallel_sections:
-  case OMPD_parallel_master:
-  case OMPD_parallel_masked:
-  case OMPD_parallel_loop:
+  // case OMPD_parallel_sections:
+  // case OMPD_parallel_master:
+  // case OMPD_parallel_masked:
+  // case OMPD_parallel_loop:
   case OMPD_teams:
-  case OMPD_teams_distribute:
+  // case OMPD_teams_distribute:
   case OMPD_teams_distribute_simd: {
     QualType KmpInt32Ty = Context.getIntTypeForBitwidth(32, 1).withConst();
     QualType KmpInt32PtrTy =
@@ -4217,12 +4222,12 @@ void Sema::ActOnOpenMPRegionStart(OpenMPDirectiveKind DKind, Scope *CurScope) {
                              Params);
     break;
   }
-  case OMPD_target_teams:
-  case OMPD_target_parallel:
-  case OMPD_target_parallel_for:
+  // case OMPD_target_teams:
+  // case OMPD_target_parallel:
+  // case OMPD_target_parallel_for:
   case OMPD_target_parallel_for_simd:
-  case OMPD_target_parallel_loop:
-  case OMPD_target_teams_distribute:
+  // case OMPD_target_parallel_loop:
+  // case OMPD_target_teams_distribute:
   case OMPD_target_teams_distribute_simd: {
     QualType KmpInt32Ty = Context.getIntTypeForBitwidth(32, 1).withConst();
     QualType VoidPtrTy = Context.VoidPtrTy.withConst().withRestrict();
@@ -4269,8 +4274,8 @@ void Sema::ActOnOpenMPRegionStart(OpenMPDirectiveKind DKind, Scope *CurScope) {
                              ParamsTeamsOrParallel, /*OpenMPCaptureLevel=*/2);
     break;
   }
-  case OMPD_target:
-  case OMPD_target_simd: {
+  case OMPD_target: {
+  // case OMPD_target_simd: {
     QualType KmpInt32Ty = Context.getIntTypeForBitwidth(32, 1).withConst();
     QualType VoidPtrTy = Context.VoidPtrTy.withConst().withRestrict();
     QualType KmpInt32PtrTy =
@@ -4366,8 +4371,8 @@ void Sema::ActOnOpenMPRegionStart(OpenMPDirectiveKind DKind, Scope *CurScope) {
   }
   case OMPD_taskloop:
   case OMPD_taskloop_simd:
-  case OMPD_master_taskloop:
-  case OMPD_masked_taskloop:
+  // case OMPD_master_taskloop:
+  // case OMPD_masked_taskloop:
   case OMPD_masked_taskloop_simd:
   case OMPD_master_taskloop_simd: {
     QualType KmpInt32Ty =
@@ -4410,9 +4415,9 @@ void Sema::ActOnOpenMPRegionStart(OpenMPDirectiveKind DKind, Scope *CurScope) {
             Context, {}, AlwaysInlineAttr::Keyword_forceinline));
     break;
   }
-  case OMPD_parallel_masked_taskloop:
+  // case OMPD_parallel_masked_taskloop:
   case OMPD_parallel_masked_taskloop_simd:
-  case OMPD_parallel_master_taskloop:
+  // case OMPD_parallel_master_taskloop:
   case OMPD_parallel_master_taskloop_simd: {
     QualType KmpInt32Ty =
         Context.getIntTypeForBitwidth(/*DestWidth=*/32, /*Signed=*/1)
@@ -4478,7 +4483,7 @@ void Sema::ActOnOpenMPRegionStart(OpenMPDirectiveKind DKind, Scope *CurScope) {
                              Params);
     break;
   }
-  case OMPD_target_teams_loop:
+  // case OMPD_target_teams_loop:
   case OMPD_target_teams_distribute_parallel_for:
   case OMPD_target_teams_distribute_parallel_for_simd: {
     QualType KmpInt32Ty = Context.getIntTypeForBitwidth(32, 1).withConst();
@@ -4539,7 +4544,7 @@ void Sema::ActOnOpenMPRegionStart(OpenMPDirectiveKind DKind, Scope *CurScope) {
     break;
   }
 
-  case OMPD_teams_loop:
+  // case OMPD_teams_loop:
   case OMPD_teams_distribute_parallel_for:
   case OMPD_teams_distribute_parallel_for_simd: {
     QualType KmpInt32Ty = Context.getIntTypeForBitwidth(32, 1).withConst();
@@ -4802,8 +4807,18 @@ static bool checkOrderedOrderSpecified(Sema &S,
   return false;
 }
 
-StmtResult Sema::ActOnOpenMPRegionEnd(StmtResult S,
-                                      ArrayRef<OMPClause *> Clauses) {
+// Close the current OpenMP region with no per-region post-processing.
+StmtResult
+Sema::ActOnOpenMPRegionEnd(StmtResult S, ArrayRef<OMPClause *> Clauses) {
+  return ActOnOpenMPRegionEnd(S, Clauses, [](StmtResult R) { return R; });
+}
+
+// Close the current OpenMP region. After each captured region is closed,
+// Callback is applied to the resulting statement and its return value
+// replaces that statement.
+StmtResult
+Sema::ActOnOpenMPRegionEnd(StmtResult S, ArrayRef<OMPClause *> Clauses,
+                           std::function<StmtResult(StmtResult)> Callback) {
   handleDeclareVariantConstructTrait(DSAStack, DSAStack->getCurrentDirective(),
                                      /* ScopeEntry */ false);
   if (DSAStack->getCurrentDirective() == OMPD_atomic ||
@@ -4974,6 +4989,8 @@ StmtResult Sema::ActOnOpenMPRegionEnd(StmtResult S,
     if (++CompletedRegions == CaptureRegions.size())
       DSAStack->setBodyComplete();
     SR = ActOnCapturedRegionEnd(SR.get());
+    // Give the caller a chance to post-process the just-closed region.
+    SR = Callback(SR);
   }
   return SR;
 }



More information about the llvm-commits mailing list