[clang] [clang][SME] Emit error for OpenMP captured regions in SME functions (PR #124590)

Benjamin Maxwell via cfe-commits cfe-commits at lists.llvm.org
Tue Jan 28 03:28:34 PST 2025


https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/124590

>From f3083be395aa9cfb1e6d44f00a32faaee347468a Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Mon, 27 Jan 2025 16:43:58 +0000
Subject: [PATCH 1/3] [clang][SME] Emit error for OpenMP captured regions in
 SME functions

Currently, OpenMP captured regions in streaming functions, or in functions
with ZA/ZT0 state, generate incorrect code, as the streaming/SME attributes
are not propagated to the outlined function. As we've yet to work on
mixing OpenMP and streaming functions (and determine how they should
interact with OpenMP's runtime), we think it is best to disallow this
for now.
---
 clang/include/clang/AST/Decl.h                |  6 ++
 .../clang/Basic/DiagnosticSemaKinds.td        |  3 +
 clang/lib/AST/Decl.cpp                        | 14 ++++
 clang/lib/Sema/SemaARM.cpp                    | 14 ----
 clang/lib/Sema/SemaStmt.cpp                   | 22 ++++++
 ...aarch64-sme-attrs-openmp-captured-region.c | 68 +++++++++++++++++++
 6 files changed, 113 insertions(+), 14 deletions(-)
 create mode 100644 clang/test/Sema/aarch64-sme-attrs-openmp-captured-region.c

diff --git a/clang/include/clang/AST/Decl.h b/clang/include/clang/AST/Decl.h
index 16403774e72b31..9593bab576412a 100644
--- a/clang/include/clang/AST/Decl.h
+++ b/clang/include/clang/AST/Decl.h
@@ -5139,6 +5139,12 @@ static constexpr StringRef getOpenMPVariantManglingSeparatorStr() {
 bool IsArmStreamingFunction(const FunctionDecl *FD,
                             bool IncludeLocallyStreaming);
 
+/// Returns whether the given FunctionDecl has Arm ZA state.
+bool hasArmZAState(const FunctionDecl *FD);
+
+/// Returns whether the given FunctionDecl has Arm ZT0 state.
+bool hasArmZT0State(const FunctionDecl *FD);
+
 } // namespace clang
 
 #endif // LLVM_CLANG_AST_DECL_H
diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index 774e5484cfa0e7..66ad02c7e0a7f8 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -3864,6 +3864,9 @@ def err_sme_definition_using_za_in_non_sme_target : Error<
   "function using ZA state requires 'sme'">;
 def err_sme_definition_using_zt0_in_non_sme2_target : Error<
   "function using ZT0 state requires 'sme2'">;
+def err_sme_openmp_captured_region : Error<
+  "OpenMP captured regions are not yet supported in "
+  "%select{streaming functions|functions with ZA state|functions with ZT0 state}0">;
 def warn_sme_streaming_pass_return_vl_to_non_streaming : Warning<
   "%select{returning|passing}0 a VL-dependent argument %select{from|to}0 a function with a different"
   " streaming-mode is undefined behaviour when the streaming and non-streaming vector lengths are different at runtime">,
diff --git a/clang/lib/AST/Decl.cpp b/clang/lib/AST/Decl.cpp
index 728556614e632f..14a18111123cb6 100644
--- a/clang/lib/AST/Decl.cpp
+++ b/clang/lib/AST/Decl.cpp
@@ -5844,3 +5844,17 @@ bool clang::IsArmStreamingFunction(const FunctionDecl *FD,
 
   return false;
 }
+
+bool clang::hasArmZAState(const FunctionDecl *FD) {
+  const auto *T = FD->getType()->getAs<FunctionProtoType>();
+  return (T && FunctionType::getArmZAState(T->getAArch64SMEAttributes()) !=
+                   FunctionType::ARM_None) ||
+         (FD->hasAttr<ArmNewAttr>() && FD->getAttr<ArmNewAttr>()->isNewZA());
+}
+
+bool clang::hasArmZT0State(const FunctionDecl *FD) {
+  const auto *T = FD->getType()->getAs<FunctionProtoType>();
+  return (T && FunctionType::getArmZT0State(T->getAArch64SMEAttributes()) !=
+                   FunctionType::ARM_None) ||
+         (FD->hasAttr<ArmNewAttr>() && FD->getAttr<ArmNewAttr>()->isNewZT0());
+}
diff --git a/clang/lib/Sema/SemaARM.cpp b/clang/lib/Sema/SemaARM.cpp
index 2620bbc97ba02a..9fbe8358f716b3 100644
--- a/clang/lib/Sema/SemaARM.cpp
+++ b/clang/lib/Sema/SemaARM.cpp
@@ -624,20 +624,6 @@ static bool checkArmStreamingBuiltin(Sema &S, CallExpr *TheCall,
   return true;
 }
 
-static bool hasArmZAState(const FunctionDecl *FD) {
-  const auto *T = FD->getType()->getAs<FunctionProtoType>();
-  return (T && FunctionType::getArmZAState(T->getAArch64SMEAttributes()) !=
-                   FunctionType::ARM_None) ||
-         (FD->hasAttr<ArmNewAttr>() && FD->getAttr<ArmNewAttr>()->isNewZA());
-}
-
-static bool hasArmZT0State(const FunctionDecl *FD) {
-  const auto *T = FD->getType()->getAs<FunctionProtoType>();
-  return (T && FunctionType::getArmZT0State(T->getAArch64SMEAttributes()) !=
-                   FunctionType::ARM_None) ||
-         (FD->hasAttr<ArmNewAttr>() && FD->getAttr<ArmNewAttr>()->isNewZT0());
-}
-
 static ArmSMEState getSMEState(unsigned BuiltinID) {
   switch (BuiltinID) {
   default:
diff --git a/clang/lib/Sema/SemaStmt.cpp b/clang/lib/Sema/SemaStmt.cpp
index 25a07d0315eac1..6eaa597e07436e 100644
--- a/clang/lib/Sema/SemaStmt.cpp
+++ b/clang/lib/Sema/SemaStmt.cpp
@@ -4568,9 +4568,28 @@ buildCapturedStmtCaptureList(Sema &S, CapturedRegionScopeInfo *RSI,
   return false;
 }
 
+static std::optional<int>
+isOpenMPCapturedRegionInArmSMEFunction(Sema const &S, CapturedRegionKind Kind) {
+  if (!S.getLangOpts().OpenMP || Kind != CR_OpenMP)
+    return false;
+  FunctionDecl *FD = S.getCurFunctionDecl(/*AllowLambda=*/true);
+  if (!FD)
+    return false;
+  if (IsArmStreamingFunction(FD, /*IncludeLocallyStreaming=*/true))
+    return /* in streaming functions */ 0;
+  if (hasArmZAState(FD))
+    return /* in functions with ZA state */ 1;
+  if (hasArmZT0State(FD))
+    return /* in fuctions with ZT0 state */ 2;
+  return {};
+}
+
 void Sema::ActOnCapturedRegionStart(SourceLocation Loc, Scope *CurScope,
                                     CapturedRegionKind Kind,
                                     unsigned NumParams) {
+  if (auto ErrorIndex = isOpenMPCapturedRegionInArmSMEFunction(*this, Kind))
+    Diag(Loc, diag::err_sme_openmp_captured_region) << *ErrorIndex;
+
   CapturedDecl *CD = nullptr;
   RecordDecl *RD = CreateCapturedStmtRecordDecl(CD, Loc, NumParams);
 
@@ -4602,6 +4621,9 @@ void Sema::ActOnCapturedRegionStart(SourceLocation Loc, Scope *CurScope,
                                     CapturedRegionKind Kind,
                                     ArrayRef<CapturedParamNameType> Params,
                                     unsigned OpenMPCaptureLevel) {
+  if (auto ErrorIndex = isOpenMPCapturedRegionInArmSMEFunction(*this, Kind))
+    Diag(Loc, diag::err_sme_openmp_captured_region) << *ErrorIndex;
+
   CapturedDecl *CD = nullptr;
   RecordDecl *RD = CreateCapturedStmtRecordDecl(CD, Loc, Params.size());
 
diff --git a/clang/test/Sema/aarch64-sme-attrs-openmp-captured-region.c b/clang/test/Sema/aarch64-sme-attrs-openmp-captured-region.c
new file mode 100644
index 00000000000000..9fee61cfe26a49
--- /dev/null
+++ b/clang/test/Sema/aarch64-sme-attrs-openmp-captured-region.c
@@ -0,0 +1,68 @@
+// RUN: %clang_cc1 -triple aarch64-none-linux-gnu -target-feature +sme -target-feature +sme2 -fopenmp -fsyntax-only -verify %s
+// RUN: %clang_cc1 -triple aarch64-none-linux-gnu -target-feature +sme -target-feature +sme2 -fopenmp -fsyntax-only -verify=expected-cpp -x c++ %s
+
+int compute(int);
+
+void streaming_openmp_captured_region(int* out) __arm_streaming
+{
+  // expected-error at +2 {{OpenMP captured regions are not yet supported in streaming functions}}
+  // expected-cpp-error at +1 {{OpenMP captured regions are not yet supported in streaming functions}}
+  #pragma omp parallel for num_threads(32)
+  for(int ci =0;ci< 8;ci++)
+  {
+    out[ci] =compute(ci);
+  }
+}
+
+__arm_locally_streaming void locally_streaming_openmp_captured_region(int* out)
+{
+  // expected-error at +2 {{OpenMP captured regions are not yet supported in streaming functions}}
+  // expected-cpp-error at +1 {{OpenMP captured regions are not yet supported in streaming functions}}
+  #pragma omp parallel for num_threads(32)
+  for(int ci =0;ci< 8;ci++)
+  {
+    out[ci] = compute(ci);
+  }
+}
+
+void za_state_captured_region(int* out) __arm_inout("za")
+{
+  // expected-error at +2 {{OpenMP captured regions are not yet supported in functions with ZA state}}
+  // expected-cpp-error at +1 {{OpenMP captured regions are not yet supported in functions with ZA state}}
+  #pragma omp parallel for num_threads(32)
+  for(int ci =0;ci< 8;ci++)
+  {
+    out[ci] =compute(ci);
+  }
+}
+
+void zt0_state_openmp_captured_region(int* out) __arm_inout("zt0")
+{
+  // expected-error at +2 {{OpenMP captured regions are not yet supported in functions with ZT0 state}}
+  // expected-cpp-error at +1 {{OpenMP captured regions are not yet supported in functions with ZT0 state}}
+  #pragma omp parallel for num_threads(32)
+  for(int ci =0;ci< 8;ci++)
+  {
+    out[ci] = compute(ci);
+  }
+}
+
+/// OpenMP directives that don't create a captured region are okay:
+
+void streaming_function_openmp(int* out) __arm_streaming __arm_inout("za", "zt0")
+{
+  #pragma omp unroll full
+  for(int ci =0;ci< 8;ci++)
+  {
+    out[ci] =compute(ci);
+  }
+}
+
+__arm_locally_streaming void locally_streaming_openmp(int* out) __arm_inout("za", "zt0")
+{
+  #pragma omp unroll full
+  for(int ci =0;ci< 8;ci++)
+  {
+    out[ci] = compute(ci);
+  }
+}

>From 71e6ba2262d68d2daadd4523479b9c78ac8b3247 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Mon, 27 Jan 2025 20:38:03 +0000
Subject: [PATCH 2/3] Update clang/lib/Sema/SemaStmt.cpp

---
 clang/lib/Sema/SemaStmt.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/clang/lib/Sema/SemaStmt.cpp b/clang/lib/Sema/SemaStmt.cpp
index 6eaa597e07436e..09bafa5bae19da 100644
--- a/clang/lib/Sema/SemaStmt.cpp
+++ b/clang/lib/Sema/SemaStmt.cpp
@@ -4571,10 +4571,10 @@ buildCapturedStmtCaptureList(Sema &S, CapturedRegionScopeInfo *RSI,
 static std::optional<int>
 isOpenMPCapturedRegionInArmSMEFunction(Sema const &S, CapturedRegionKind Kind) {
   if (!S.getLangOpts().OpenMP || Kind != CR_OpenMP)
-    return false;
+    return {};
   FunctionDecl *FD = S.getCurFunctionDecl(/*AllowLambda=*/true);
   if (!FD)
-    return false;
+    return {};
   if (IsArmStreamingFunction(FD, /*IncludeLocallyStreaming=*/true))
     return /* in streaming functions */ 0;
   if (hasArmZAState(FD))

>From 4db99bea227ac9181705554bf94b85cfaa6174ed Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Tue, 28 Jan 2025 11:27:28 +0000
Subject: [PATCH 3/3] Fixups

---
 clang/lib/Sema/SemaStmt.cpp                   | 17 +++--
 ...aarch64-sme-attrs-openmp-captured-region.c | 67 +++++++++++--------
 2 files changed, 48 insertions(+), 36 deletions(-)

diff --git a/clang/lib/Sema/SemaStmt.cpp b/clang/lib/Sema/SemaStmt.cpp
index 09bafa5bae19da..947651d514b3b0 100644
--- a/clang/lib/Sema/SemaStmt.cpp
+++ b/clang/lib/Sema/SemaStmt.cpp
@@ -4572,15 +4572,14 @@ static std::optional<int>
 isOpenMPCapturedRegionInArmSMEFunction(Sema const &S, CapturedRegionKind Kind) {
   if (!S.getLangOpts().OpenMP || Kind != CR_OpenMP)
     return {};
-  FunctionDecl *FD = S.getCurFunctionDecl(/*AllowLambda=*/true);
-  if (!FD)
-    return {};
-  if (IsArmStreamingFunction(FD, /*IncludeLocallyStreaming=*/true))
-    return /* in streaming functions */ 0;
-  if (hasArmZAState(FD))
-    return /* in functions with ZA state */ 1;
-  if (hasArmZT0State(FD))
-    return /* in fuctions with ZT0 state */ 2;
+  if (const FunctionDecl *FD = S.getCurFunctionDecl(/*AllowLambda=*/true)) {
+    if (IsArmStreamingFunction(FD, /*IncludeLocallyStreaming=*/true))
+      return /* in streaming functions */ 0;
+    if (hasArmZAState(FD))
+      return /* in functions with ZA state */ 1;
+    if (hasArmZT0State(FD))
+      return /* in functions with ZT0 state */ 2;
+  }
   return {};
 }
 
diff --git a/clang/test/Sema/aarch64-sme-attrs-openmp-captured-region.c b/clang/test/Sema/aarch64-sme-attrs-openmp-captured-region.c
index 9fee61cfe26a49..6fb7c60d02cd71 100644
--- a/clang/test/Sema/aarch64-sme-attrs-openmp-captured-region.c
+++ b/clang/test/Sema/aarch64-sme-attrs-openmp-captured-region.c
@@ -3,66 +3,79 @@
 
 int compute(int);
 
-void streaming_openmp_captured_region(int* out) __arm_streaming
-{
+void streaming_openmp_captured_region(int * out) __arm_streaming {
   // expected-error at +2 {{OpenMP captured regions are not yet supported in streaming functions}}
   // expected-cpp-error at +1 {{OpenMP captured regions are not yet supported in streaming functions}}
   #pragma omp parallel for num_threads(32)
-  for(int ci =0;ci< 8;ci++)
-  {
-    out[ci] =compute(ci);
+  for (int ci = 0; ci < 8; ci++) {
+    out[ci] = compute(ci);
   }
 }
 
-__arm_locally_streaming void locally_streaming_openmp_captured_region(int* out)
-{
+__arm_locally_streaming void locally_streaming_openmp_captured_region(int * out) {
   // expected-error at +2 {{OpenMP captured regions are not yet supported in streaming functions}}
   // expected-cpp-error at +1 {{OpenMP captured regions are not yet supported in streaming functions}}
   #pragma omp parallel for num_threads(32)
-  for(int ci =0;ci< 8;ci++)
-  {
+  for (int ci = 0; ci < 8; ci++) {
     out[ci] = compute(ci);
   }
 }
 
-void za_state_captured_region(int* out) __arm_inout("za")
-{
+void za_state_captured_region(int * out) __arm_inout("za") {
   // expected-error at +2 {{OpenMP captured regions are not yet supported in functions with ZA state}}
   // expected-cpp-error at +1 {{OpenMP captured regions are not yet supported in functions with ZA state}}
   #pragma omp parallel for num_threads(32)
-  for(int ci =0;ci< 8;ci++)
-  {
-    out[ci] =compute(ci);
+  for (int ci = 0; ci < 8; ci++) {
+    out[ci] = compute(ci);
+  }
+}
+
+__arm_new("za") void new_za_state_captured_region(int * out) {
+  // expected-error at +2 {{OpenMP captured regions are not yet supported in functions with ZA state}}
+  // expected-cpp-error at +1 {{OpenMP captured regions are not yet supported in functions with ZA state}}
+  #pragma omp parallel for num_threads(32)
+  for (int ci = 0; ci < 8; ci++) {
+    out[ci] = compute(ci);
+  }
+}
+
+void zt0_state_openmp_captured_region(int * out) __arm_inout("zt0") {
+  // expected-error at +2 {{OpenMP captured regions are not yet supported in functions with ZT0 state}}
+  // expected-cpp-error at +1 {{OpenMP captured regions are not yet supported in functions with ZT0 state}}
+  #pragma omp parallel for num_threads(32)
+  for (int ci = 0; ci < 8; ci++) {
+    out[ci] = compute(ci);
   }
 }
 
-void zt0_state_openmp_captured_region(int* out) __arm_inout("zt0")
-{
+__arm_new("zt0") void new_zt0_state_openmp_captured_region(int * out) {
   // expected-error at +2 {{OpenMP captured regions are not yet supported in functions with ZT0 state}}
   // expected-cpp-error at +1 {{OpenMP captured regions are not yet supported in functions with ZT0 state}}
   #pragma omp parallel for num_threads(32)
-  for(int ci =0;ci< 8;ci++)
-  {
+  for (int ci = 0; ci < 8; ci++) {
     out[ci] = compute(ci);
   }
 }
 
 /// OpenMP directives that don't create a captured region are okay:
 
-void streaming_function_openmp(int* out) __arm_streaming __arm_inout("za", "zt0")
-{
+void streaming_function_openmp(int * out) __arm_streaming __arm_inout("za", "zt0") {
   #pragma omp unroll full
-  for(int ci =0;ci< 8;ci++)
-  {
-    out[ci] =compute(ci);
+  for (int ci = 0; ci < 8; ci++) {
+    out[ci] = compute(ci);
+  }
+}
+
+__arm_locally_streaming void locally_streaming_openmp(int * out) __arm_inout("za", "zt0") {
+  #pragma omp unroll full
+  for (int ci = 0; ci < 8; ci++) {
+    out[ci] = compute(ci);
   }
 }
 
-__arm_locally_streaming void locally_streaming_openmp(int* out) __arm_inout("za", "zt0")
-{
+__arm_new("za", "zt0") void arm_new_openmp(int * out) {
   #pragma omp unroll full
-  for(int ci =0;ci< 8;ci++)
-  {
+  for (int ci = 0; ci < 8; ci++) {
     out[ci] = compute(ci);
   }
 }



More information about the cfe-commits mailing list