[clang] [AArch64][Clang] Add support for __arm_agnostic("sme_za_state") (PR #121788)

Sander de Smalen via cfe-commits cfe-commits at lists.llvm.org
Mon Jan 6 08:10:13 PST 2025


https://github.com/sdesmalen-arm created https://github.com/llvm/llvm-project/pull/121788

This adds support for parsing the attribute and codegen to map it to the "aarch64_za_state_agnostic" LLVM IR attribute.

This attribute is described in the Arm C Language Extensions (ACLE) document:

  https://github.com/ARM-software/acle/blob/main/main/acle.md#__arm_agnostic

>From 159bc2ccdab5457d997c9fc8ed679e4607db0b79 Mon Sep 17 00:00:00 2001
From: Sander de Smalen <sander.desmalen at arm.com>
Date: Mon, 9 Sep 2024 15:20:26 +0100
Subject: [PATCH] [AArch64][Clang] Add support for
 __arm_agnostic("sme_za_state")

This adds support for parsing the attribute and codegen to
map it to the "aarch64_za_state_agnostic" LLVM IR attribute.

This attribute is described in the Arm C Language Extensions
(ACLE) document:

  https://github.com/ARM-software/acle/blob/main/main/acle.md#__arm_agnostic
---
 clang/include/clang/AST/Type.h                | 13 ++++--
 clang/include/clang/Basic/Attr.td             |  7 +++
 clang/include/clang/Basic/AttrDocs.td         | 20 ++++++++
 .../clang/Basic/DiagnosticSemaKinds.td        |  5 ++
 clang/lib/AST/TypePrinter.cpp                 |  2 +
 clang/lib/CodeGen/CGCall.cpp                  |  2 +
 clang/lib/Sema/SemaDecl.cpp                   |  8 ++++
 clang/lib/Sema/SemaType.cpp                   | 46 ++++++++++++++++++-
 .../sme-intrinsics/aarch64-sme-attrs.cpp      | 24 ++++++++++
 clang/test/Sema/aarch64-sme-func-attrs.c      | 21 +++++++++
 10 files changed, 143 insertions(+), 5 deletions(-)

diff --git a/clang/include/clang/AST/Type.h b/clang/include/clang/AST/Type.h
index 09c98f642852fc..49c1681d84dd52 100644
--- a/clang/include/clang/AST/Type.h
+++ b/clang/include/clang/AST/Type.h
@@ -4593,9 +4593,14 @@ class FunctionType : public Type {
     SME_ZT0Shift = 5,
     SME_ZT0Mask = 0b111 << SME_ZT0Shift,
 
+    // A bit to tell whether a function is agnostic about sme ZA state.
+    SME_AgnosticZAStateShift = 8,
+    SME_AgnosticZAStateMask = 1 << SME_AgnosticZAStateShift,
+
     SME_AttributeMask =
-        0b111'111'11 // We can't support more than 8 bits because of
-                     // the bitmask in FunctionTypeExtraBitfields.
+        0b1'111'111'11 // We can't support more than 9 bits because of
+                       // the bitmask in FunctionTypeArmAttributes
+                       // and ExtProtoInfo.
   };
 
   enum ArmStateValue : unsigned {
@@ -4620,7 +4625,7 @@ class FunctionType : public Type {
   struct alignas(void *) FunctionTypeArmAttributes {
     /// Any AArch64 SME ACLE type attributes that need to be propagated
     /// on declarations and function pointers.
-    unsigned AArch64SMEAttributes : 8;
+    unsigned AArch64SMEAttributes : 9;
 
     FunctionTypeArmAttributes() : AArch64SMEAttributes(SME_NormalFunction) {}
   };
@@ -5188,7 +5193,7 @@ class FunctionProtoType final
     FunctionType::ExtInfo ExtInfo;
     unsigned Variadic : 1;
     unsigned HasTrailingReturn : 1;
-    unsigned AArch64SMEAttributes : 8;
+    unsigned AArch64SMEAttributes : 9;
     Qualifiers TypeQuals;
     RefQualifierKind RefQualifier = RQ_None;
     ExceptionSpecInfo ExceptionSpec;
diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td
index 52ad72eb608c31..de8c951f66f416 100644
--- a/clang/include/clang/Basic/Attr.td
+++ b/clang/include/clang/Basic/Attr.td
@@ -2866,6 +2866,13 @@ def ArmPreserves : TypeAttr, TargetSpecificAttr<TargetAArch64> {
   let Documentation = [ArmPreservesDocs];
 }
 
+def ArmAgnostic : TypeAttr, TargetSpecificAttr<TargetAArch64> {
+  let Spellings = [RegularKeyword<"__arm_agnostic">];
+  let Args = [VariadicStringArgument<"AgnosticArgs">];
+  let Subjects = SubjectList<[HasFunctionProto], ErrorDiag>;
+  let Documentation = [ArmAgnosticDocs];
+}
+
 def ArmLocallyStreaming : InheritableAttr, TargetSpecificAttr<TargetAArch64> {
   let Spellings = [RegularKeyword<"__arm_locally_streaming">];
   let Subjects = SubjectList<[Function], ErrorDiag>;
diff --git a/clang/include/clang/Basic/AttrDocs.td b/clang/include/clang/Basic/AttrDocs.td
index fdad4c9a3ea191..2b2c6f0691b0ce 100644
--- a/clang/include/clang/Basic/AttrDocs.td
+++ b/clang/include/clang/Basic/AttrDocs.td
@@ -7559,6 +7559,26 @@ The attributes ``__arm_in(S)``, ``__arm_out(S)``, ``__arm_inout(S)`` and
   }];
 }
 
+def ArmAgnosticDocs : Documentation {
+  let Category = DocCatArmSmeAttributes;
+  let Content = [{
+The ``__arm_agnostic`` keyword applies to prototyped function types and
+specifies that the function is agnostic about the given state S and
+returns with state S unchanged if state S exists.
+
+The attribute takes string arguments to instruct the compiler which state
+the function is agnostic about.  The supported states for S are:
+
+* ``"sme_za_state"`` for any state enabled by PSTATE.ZA (including the
+  bit itself)
+
+The attribute ``__arm_agnostic("sme_za_state")`` cannot be used in conjunction
+with ``__arm_in(S)``, ``__arm_out(S)``, ``__arm_inout(S)`` or
+``__arm_preserves(S)`` where state S describes state enabled by PSTATE.ZA,
+such as "za" or "zt0".
+  }];
+}
+
 def ArmSmeLocallyStreamingDocs : Documentation {
   let Category = DocCatArmSmeAttributes;
   let Content = [{
diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index 03fb7ca9bc3c3b..7c69b0fc95874f 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -3835,6 +3835,9 @@ def err_sme_unimplemented_za_save_restore : Error<
   "call to a function that shares state other than 'za' from a "
   "function that has live 'za' state requires a spill/fill of ZA, which is not yet "
   "implemented">;
+def err_sme_unimplemented_agnostic_new : Error<
+  "support to handle __arm_agnostic(\"sme_za_state\") together with "
+  "__arm_new(\"za\") or __arm_new(\"zt0\") is not yet implemented">;
 def note_sme_use_preserves_za : Note<
   "add '__arm_preserves(\"za\")' to the callee if it preserves ZA">;
 def err_sme_definition_using_sm_in_non_sme_target : Error<
@@ -3851,6 +3854,8 @@ def warn_sme_locally_streaming_has_vl_args_returns : Warning<
   "%select{returning|passing}0 a VL-dependent argument %select{from|to}0 a locally streaming function is undefined"
   " behaviour when the streaming and non-streaming vector lengths are different at runtime">,
   InGroup<AArch64SMEAttributes>, DefaultIgnore;
+def err_conflicting_attributes_arm_agnostic : Error<
+  "__arm_agnostic(\"sme_za_state\") cannot share ZA state with its caller">;
 def err_conflicting_attributes_arm_state : Error<
   "conflicting attributes for state '%0'">;
 def err_unknown_arm_state : Error<
diff --git a/clang/lib/AST/TypePrinter.cpp b/clang/lib/AST/TypePrinter.cpp
index 7caebfb061a50b..9590145ffbd5ae 100644
--- a/clang/lib/AST/TypePrinter.cpp
+++ b/clang/lib/AST/TypePrinter.cpp
@@ -1000,6 +1000,8 @@ void TypePrinter::printFunctionProtoAfter(const FunctionProtoType *T,
     OS << " __arm_streaming_compatible";
   if (SMEBits & FunctionType::SME_PStateSMEnabledMask)
     OS << " __arm_streaming";
+  if (SMEBits & FunctionType::SME_AgnosticZAStateMask)
+    OS << "__arm_agnostic(\"sme_za_state\")";
   if (FunctionType::getArmZAState(SMEBits) == FunctionType::ARM_Preserves)
     OS << " __arm_preserves(\"za\")";
   if (FunctionType::getArmZAState(SMEBits) == FunctionType::ARM_In)
diff --git a/clang/lib/CodeGen/CGCall.cpp b/clang/lib/CodeGen/CGCall.cpp
index f139c30f3dfd44..e631134c972b83 100644
--- a/clang/lib/CodeGen/CGCall.cpp
+++ b/clang/lib/CodeGen/CGCall.cpp
@@ -1779,6 +1779,8 @@ static void AddAttributesFromFunctionProtoType(ASTContext &Ctx,
     FuncAttrs.addAttribute("aarch64_pstate_sm_enabled");
   if (SMEBits & FunctionType::SME_PStateSMCompatibleMask)
     FuncAttrs.addAttribute("aarch64_pstate_sm_compatible");
+  if (SMEBits & FunctionType::SME_AgnosticZAStateMask)
+    FuncAttrs.addAttribute("aarch64_za_state_agnostic");
 
   // ZA
   if (FunctionType::getArmZAState(SMEBits) == FunctionType::ARM_Preserves)
diff --git a/clang/lib/Sema/SemaDecl.cpp b/clang/lib/Sema/SemaDecl.cpp
index 4001c4d263f1d2..a816e675260ea4 100644
--- a/clang/lib/Sema/SemaDecl.cpp
+++ b/clang/lib/Sema/SemaDecl.cpp
@@ -12295,6 +12295,14 @@ bool Sema::CheckFunctionDeclaration(Scope *S, FunctionDecl *NewFD,
     bool UsesZA = Attr && Attr->isNewZA();
     bool UsesZT0 = Attr && Attr->isNewZT0();
 
+    if (UsesZA || UsesZT0) {
+      if (const auto *FPT = NewFD->getType()->getAs<FunctionProtoType>()) {
+        FunctionProtoType::ExtProtoInfo EPI = FPT->getExtProtoInfo();
+        if (EPI.AArch64SMEAttributes & FunctionType::SME_AgnosticZAStateMask)
+          Diag(NewFD->getLocation(), diag::err_sme_unimplemented_agnostic_new);
+      }
+    }
+
     if (NewFD->hasAttr<ArmLocallyStreamingAttr>()) {
       if (NewFD->getReturnType()->isSizelessVectorType())
         Diag(NewFD->getLocation(),
diff --git a/clang/lib/Sema/SemaType.cpp b/clang/lib/Sema/SemaType.cpp
index 83464c50b4b238..0688b97d79ad88 100644
--- a/clang/lib/Sema/SemaType.cpp
+++ b/clang/lib/Sema/SemaType.cpp
@@ -161,6 +161,7 @@ static void diagnoseBadTypeAttribute(Sema &S, const ParsedAttr &attr,
   case ParsedAttr::AT_ArmIn:                                                   \
   case ParsedAttr::AT_ArmOut:                                                  \
   case ParsedAttr::AT_ArmInOut:                                                \
+  case ParsedAttr::AT_ArmAgnostic:                                             \
   case ParsedAttr::AT_AnyX86NoCallerSavedRegisters:                            \
   case ParsedAttr::AT_AnyX86NoCfCheck:                                         \
     CALLING_CONV_ATTRS_CASELIST
@@ -7745,6 +7746,38 @@ static bool checkMutualExclusion(TypeProcessingState &state,
   return true;
 }
 
+static bool handleArmAgnosticAttribute(Sema &S,
+                                       FunctionProtoType::ExtProtoInfo &EPI,
+                                       ParsedAttr &Attr) {
+  if (!Attr.getNumArgs()) {
+    S.Diag(Attr.getLoc(), diag::err_missing_arm_state) << Attr;
+    Attr.setInvalid();
+    return true;
+  }
+
+  for (unsigned I = 0; I < Attr.getNumArgs(); ++I) {
+    StringRef StateName;
+    SourceLocation LiteralLoc;
+    if (!S.checkStringLiteralArgumentAttr(Attr, I, StateName, &LiteralLoc))
+      return true;
+
+    if (StateName == "sme_za_state") {
+      if (EPI.AArch64SMEAttributes &
+          (FunctionType::SME_ZAMask | FunctionType::SME_ZT0Mask)) {
+        S.Diag(Attr.getLoc(), diag::err_conflicting_attributes_arm_agnostic);
+        Attr.setInvalid();
+        return true;
+      }
+      EPI.setArmSMEAttribute(FunctionType::SME_AgnosticZAStateMask);
+    } else {
+      S.Diag(LiteralLoc, diag::err_unknown_arm_state) << StateName;
+      Attr.setInvalid();
+      return true;
+    }
+  }
+  return false;
+}
+
 static bool handleArmStateAttribute(Sema &S,
                                     FunctionProtoType::ExtProtoInfo &EPI,
                                     ParsedAttr &Attr,
@@ -7775,6 +7808,12 @@ static bool handleArmStateAttribute(Sema &S,
       return true;
     }
 
+    if (EPI.AArch64SMEAttributes & FunctionType::SME_AgnosticZAStateMask) {
+      S.Diag(LiteralLoc, diag::err_conflicting_attributes_arm_agnostic);
+      Attr.setInvalid();
+      return true;
+    }
+
     // __arm_in(S), __arm_out(S), __arm_inout(S) and __arm_preserves(S)
     // are all mutually exclusive for the same S, so check if there are
     // conflicting attributes.
@@ -7925,7 +7964,8 @@ static bool handleFunctionTypeAttr(TypeProcessingState &state, ParsedAttr &attr,
       attr.getKind() == ParsedAttr::AT_ArmPreserves ||
       attr.getKind() == ParsedAttr::AT_ArmIn ||
       attr.getKind() == ParsedAttr::AT_ArmOut ||
-      attr.getKind() == ParsedAttr::AT_ArmInOut) {
+      attr.getKind() == ParsedAttr::AT_ArmInOut ||
+      attr.getKind() == ParsedAttr::AT_ArmAgnostic) {
     if (S.CheckAttrTarget(attr))
       return true;
 
@@ -7976,6 +8016,10 @@ static bool handleFunctionTypeAttr(TypeProcessingState &state, ParsedAttr &attr,
       if (handleArmStateAttribute(S, EPI, attr, FunctionType::ARM_InOut))
         return true;
       break;
+    case ParsedAttr::AT_ArmAgnostic:
+      if (handleArmAgnosticAttribute(S, EPI, attr))
+        return true;
+      break;
     default:
       llvm_unreachable("Unsupported attribute");
     }
diff --git a/clang/test/CodeGen/AArch64/sme-intrinsics/aarch64-sme-attrs.cpp b/clang/test/CodeGen/AArch64/sme-intrinsics/aarch64-sme-attrs.cpp
index 9885ac45e6a0e0..54762c8b414124 100644
--- a/clang/test/CodeGen/AArch64/sme-intrinsics/aarch64-sme-attrs.cpp
+++ b/clang/test/CodeGen/AArch64/sme-intrinsics/aarch64-sme-attrs.cpp
@@ -15,6 +15,7 @@ int streaming_compatible_decl(void) __arm_streaming_compatible;
 int shared_za_decl(void) __arm_inout("za");
 int preserves_za_decl(void) __arm_preserves("za");
 int private_za_decl(void);
+int agnostic_za_decl(void) __arm_agnostic("sme_za_state");
 
 // == FUNCTION DEFINITIONS ==
 
@@ -130,6 +131,27 @@ __arm_new("za") int new_za_callee() {
 
 // CHECK: declare i32 @private_za_decl()
 
+// CHECK-LABEL: @agnostic_za_caller()
+// CHECK-SAME: #[[ZA_AGNOSTIC:[0-9]+]]
+// CHECK: call i32 @normal_callee()
+//
+int agnostic_za_caller() __arm_agnostic("sme_za_state") {
+  return normal_callee();
+}
+
+// CHECK-LABEL: @agnostic_za_callee()
+// CHECK: call i32 @agnostic_za_decl() #[[ZA_AGNOSTIC_CALL:[0-9]+]]
+//
+int agnostic_za_callee() {
+  return agnostic_za_decl();
+}
+
+// CHECK-LABEL: @agnostic_za_callee_live_za()
+// CHECK: call i32 @agnostic_za_decl() #[[ZA_AGNOSTIC_CALL]]
+//
+int agnostic_za_callee_live_za() __arm_inout("za") {
+  return agnostic_za_decl();
+}
 
 // Ensure that the attributes are correctly propagated to function types
 // and also to callsites.
@@ -289,12 +311,14 @@ int test_variadic_template() __arm_inout("za") {
 // CHECK: attributes #[[ZA_PRESERVED]] = { mustprogress noinline nounwind "aarch64_preserves_za" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-features"="+bf16,+sme" }
 // CHECK: attributes #[[ZA_PRESERVED_DECL]] = { "aarch64_preserves_za" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-features"="+bf16,+sme" }
 // CHECK: attributes #[[ZA_NEW]] = { mustprogress noinline nounwind "aarch64_new_za" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-features"="+bf16,+sme" }
+// CHECK: attributes #[[ZA_AGNOSTIC]] = { mustprogress noinline nounwind "aarch64_za_state_agnostic" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-features"="+bf16,+sme" }
 // CHECK: attributes #[[NORMAL_DEF]] = { mustprogress noinline nounwind "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-features"="+bf16,+sme" }
 // CHECK: attributes #[[SM_ENABLED_CALL]] = { "aarch64_pstate_sm_enabled" }
 // CHECK: attributes #[[SM_COMPATIBLE_CALL]] = { "aarch64_pstate_sm_compatible" }
 // CHECK: attributes #[[SM_BODY_CALL]] = { "aarch64_pstate_sm_body" }
 // CHECK: attributes #[[ZA_SHARED_CALL]] = { "aarch64_inout_za" }
 // CHECK: attributes #[[ZA_PRESERVED_CALL]] = { "aarch64_preserves_za" }
+// CHECK: attributes #[[ZA_AGNOSTIC_CALL]] = { "aarch64_za_state_agnostic" }
 // CHECK: attributes #[[NOUNWIND_CALL]] = { nounwind }
 // CHECK: attributes #[[NOUNWIND_SM_ENABLED_CALL]] = { nounwind "aarch64_pstate_sm_enabled" }
 // CHECK: attributes #[[NOUNWIND_SM_COMPATIBLE_CALL]] = { nounwind "aarch64_pstate_sm_compatible" }
diff --git a/clang/test/Sema/aarch64-sme-func-attrs.c b/clang/test/Sema/aarch64-sme-func-attrs.c
index 0c263eb2610cfb..65628bea579fab 100644
--- a/clang/test/Sema/aarch64-sme-func-attrs.c
+++ b/clang/test/Sema/aarch64-sme-func-attrs.c
@@ -9,6 +9,7 @@ void sme_arm_streaming_compatible(void) __arm_streaming_compatible;
 __arm_new("za") void sme_arm_new_za(void) {}
 void sme_arm_shared_za(void) __arm_inout("za");
 void sme_arm_preserves_za(void) __arm_preserves("za");
+void sme_arm_agnostic(void) __arm_agnostic("sme_za_state");
 
 __arm_new("za") void sme_arm_streaming_new_za(void) __arm_streaming {}
 void sme_arm_streaming_shared_za(void) __arm_streaming __arm_inout("za");
@@ -88,6 +89,26 @@ fptrty7 invalid_streaming_func() { return streaming_ptr_invalid; }
 // expected-error at +1 {{'__arm_streaming' only applies to function types; type here is 'void ()'}}
 void function_no_prototype() __arm_streaming;
 
+// expected-cpp-error at +2 {{__arm_agnostic("sme_za_state") cannot share ZA state with its caller}}
+// expected-error at +1 {{__arm_agnostic("sme_za_state") cannot share ZA state with its caller}}
+void sme_arm_agnostic_shared_za_zt0(void) __arm_agnostic("sme_za_state") __arm_inout("zt0") {}
+
+// expected-cpp-error at +2 {{__arm_agnostic("sme_za_state") cannot share ZA state with its caller}}
+// expected-error at +1 {{__arm_agnostic("sme_za_state") cannot share ZA state with its caller}}
+void sme_arm_agnostic_shared_za_za(void) __arm_agnostic("sme_za_state") __arm_inout("za") {}
+
+// expected-cpp-error at +2 {{__arm_agnostic("sme_za_state") cannot share ZA state with its caller}}
+// expected-error at +1 {{__arm_agnostic("sme_za_state") cannot share ZA state with its caller}}
+void sme_arm_agnostic_shared_za_za_rev(void) __arm_inout("za") __arm_agnostic("sme_za_state") {}
+
+// expected-cpp-error at +2 {{support to handle __arm_agnostic("sme_za_state") together with __arm_new("za") or __arm_new("zt0") is not yet implemented}}
+// expected-error at +1 {{support to handle __arm_agnostic("sme_za_state") together with __arm_new("za") or __arm_new("zt0") is not yet implemented}}
+__arm_new("zt0") void sme_arm_agnostic_arm_new_zt0(void) __arm_agnostic("sme_za_state") {}
+
+// expected-cpp-error at +2 {{support to handle __arm_agnostic("sme_za_state") together with __arm_new("za") or __arm_new("zt0") is not yet implemented}}
+// expected-error at +1 {{support to handle __arm_agnostic("sme_za_state") together with __arm_new("za") or __arm_new("zt0") is not yet implemented}}
+__arm_new("za") void sme_arm_agnostic_arm_new_za(void) __arm_agnostic("sme_za_state") {}
+
 //
 // Check for incorrect conversions of function pointers with the attributes
 //



More information about the cfe-commits mailing list