[clang] [Clang][AArch64] Change SME attributes for shared/new/preserved state. (PR #76971)

via cfe-commits cfe-commits at lists.llvm.org
Thu Jan 4 08:41:09 PST 2024


llvmbot wrote:



@llvm/pr-subscribers-clang-codegen

Author: Sander de Smalen (sdesmalen-arm)

<details>
<summary>Changes</summary>

This patch replaces the `__arm_new_za`, `__arm_shared_za` and `__arm_preserves_za` attributes with:
* `__arm_new("za")`
* `__arm_in("za")`
* `__arm_out("za")`
* `__arm_inout("za")`
* `__arm_preserves("za")`

As described in https://github.com/ARM-software/acle/pull/276.
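Each new spelling is a keyword plus one or more state-name strings. The sketch below is a rough, plain-C++ model of two steps this implies:

* mapping each type keyword to the `ArmStateValue` enumerators that this patch adds to `FunctionType`;
* validating the string arguments.

The helper names `stateForKeyword` and `validateStateArgs` are hypothetical. The validation rules (reject an empty argument list, accept only `"za"`) are an assumption inferred from the new `err_missing_arm_state` and `err_unknown_arm_state` diagnostics, not Clang's actual code.

```cpp
#include <cassert>
#include <optional>
#include <string>
#include <string_view>
#include <vector>

// Same enumerators as FunctionType::ArmStateValue in this patch.
enum ArmStateValue : unsigned {
  ARM_None = 0,
  ARM_Preserves = 1,
  ARM_In = 2,
  ARM_Out = 3,
  ARM_InOut = 4,
};

// Hypothetical helper: maps a type keyword to its ArmStateValue.
// __arm_new is a declaration attribute, not part of the function type,
// so it has no ArmStateValue and yields std::nullopt here.
std::optional<ArmStateValue> stateForKeyword(std::string_view Keyword) {
  if (Keyword == "__arm_preserves")
    return ARM_Preserves;
  if (Keyword == "__arm_in")
    return ARM_In;
  if (Keyword == "__arm_out")
    return ARM_Out;
  if (Keyword == "__arm_inout")
    return ARM_InOut;
  return std::nullopt;
}

// Hypothetical validation of the string arguments (assumed rules):
// at least one state must be named, and "za" is the only state accepted
// before the follow-up "zt0" patch.
bool validateStateArgs(const std::vector<std::string> &Args,
                       std::string &Error) {
  if (Args.empty()) {
    Error = "missing state";
    return false;
  }
  for (const std::string &S : Args) {
    if (S != "za") {
      Error = "unknown state '" + S + "'";
      return false;
    }
  }
  return true;
}
```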

One behavioural change is that `__arm_in/out/inout/preserves(S)` are now mutually exclusive for a given state S, whereas previously it was valid to write `__arm_shared_za __arm_preserves_za`. That combination is now expressed as `__arm_in("za")`.
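The sketch below illustrates the exclusivity rule and how the removed keywords carry over. `migrateOldZAKeywords` and `combineZAStates` are hypothetical helpers. Only the `__arm_shared_za __arm_preserves_za` to `__arm_in("za")` case is stated in the summary. The other mappings are assumptions: `__arm_shared_za` alone becoming `__arm_inout("za")`, and `__arm_preserves_za` alone becoming `__arm_preserves("za")`. Tolerating an exact repeat of the same keyword is also an assumption.

```cpp
#include <cassert>
#include <optional>
#include <vector>

// Same enumerators as FunctionType::ArmStateValue in this patch.
enum ArmStateValue : unsigned {
  ARM_None = 0,
  ARM_Preserves = 1,
  ARM_In = 2,
  ARM_Out = 3,
  ARM_InOut = 4,
};

// Hypothetical: maps the removed keywords onto the new ZA state.
ArmStateValue migrateOldZAKeywords(bool SharedZA, bool PreservesZA) {
  if (SharedZA && PreservesZA)
    return ARM_In;        // stated in the patch summary
  if (SharedZA)
    return ARM_InOut;     // assumption: plain shared ZA may be read and written
  if (PreservesZA)
    return ARM_Preserves; // assumption: one-to-one rename
  return ARM_None;
}

// Hypothetical: folds the ZA states requested by the keywords on one
// function type. Two different states for "za" conflict (cf. the new
// err_conflicting_attributes_sme_state); an exact repeat is assumed harmless.
std::optional<ArmStateValue>
combineZAStates(const std::vector<ArmStateValue> &Requested) {
  ArmStateValue Result = ARM_None;
  for (ArmStateValue S : Requested) {
    if (Result != ARM_None && Result != S)
      return std::nullopt; // conflicting states for "za"
    Result = S;
  }
  return Result;
}
```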

The implementation still uses the existing LLVM attributes under the hood. `__arm_in/out/inout` are all variations of "shared ZA", so they map onto the existing `aarch64_pstate_za_shared` attribute in LLVM, with `__arm_in` additionally marked `aarch64_pstate_za_preserved`.
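The sketch below shows that mapping, assuming the ZA state has already been decoded into an `ArmStateValue`. It follows two parts of the diff below: the ZA branch of `AddAttributesFromFunctionProtoType` in `CGCall.cpp`, and the `isNewZA()` check in `CodeGenModule.cpp`. The `llvmZAAttributes` function itself is hypothetical. In Clang, `aarch64_pstate_za_new` is only added when emitting a function definition; the `IsNewZA` flag stands in for that case here.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Same enumerators as FunctionType::ArmStateValue in this patch.
enum ArmStateValue : unsigned {
  ARM_None = 0,
  ARM_Preserves = 1,
  ARM_In = 2,
  ARM_Out = 3,
  ARM_InOut = 4,
};

// Hypothetical helper returning the LLVM function attribute strings that
// the patch emits for a given ZA state (and for __arm_new("za")).
std::vector<std::string> llvmZAAttributes(ArmStateValue ZAState,
                                          bool IsNewZA) {
  std::vector<std::string> Attrs;
  switch (ZAState) {
  case ARM_None:
    break;
  case ARM_Preserves:
    Attrs.push_back("aarch64_pstate_za_preserved");
    break;
  case ARM_In: // shared, and ZA is unchanged on return
    Attrs.push_back("aarch64_pstate_za_shared");
    Attrs.push_back("aarch64_pstate_za_preserved");
    break;
  case ARM_Out:
  case ARM_InOut:
    Attrs.push_back("aarch64_pstate_za_shared");
    break;
  }
  if (IsNewZA) // only for definitions in CodeGenModule.cpp
    Attrs.push_back("aarch64_pstate_za_new");
  return Attrs;
}
```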

A future patch will add support for the new "zt0" state as introduced with SME2.

---

Patch is 634.32 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/76971.diff


68 Files Affected:

- (modified) clang/include/clang/AST/Type.h (+19-4) 
- (modified) clang/include/clang/Basic/Attr.td (+35-14) 
- (modified) clang/include/clang/Basic/AttrDocs.td (+65-17) 
- (modified) clang/include/clang/Basic/DiagnosticSemaKinds.td (+6) 
- (modified) clang/lib/AST/TypePrinter.cpp (+15-16) 
- (modified) clang/lib/CodeGen/CGCall.cpp (+12-7) 
- (modified) clang/lib/CodeGen/CodeGenModule.cpp (+4-2) 
- (modified) clang/lib/Parse/ParseDecl.cpp (+3) 
- (modified) clang/lib/Parse/ParseDeclCXX.cpp (+13-5) 
- (modified) clang/lib/Parse/ParseTentative.cpp (-2) 
- (modified) clang/lib/Sema/SemaChecking.cpp (+24-9) 
- (modified) clang/lib/Sema/SemaDecl.cpp (+7-2) 
- (modified) clang/lib/Sema/SemaDeclAttr.cpp (+79-15) 
- (modified) clang/lib/Sema/SemaExpr.cpp (+9-4) 
- (modified) clang/lib/Sema/SemaOverload.cpp (+8-7) 
- (modified) clang/lib/Sema/SemaType.cpp (+71-9) 
- (modified) clang/test/AST/ast-dump-sme-attributes.cpp (+6-6) 
- (modified) clang/test/CodeGen/aarch64-sme-intrinsics/aarch64-sme-attrs.cpp (+14-14) 
- (modified) clang/test/CodeGen/aarch64-sme-intrinsics/acle_sme_add-i32.c (+8-8) 
- (modified) clang/test/CodeGen/aarch64-sme-intrinsics/acle_sme_add-i64.c (+8-8) 
- (modified) clang/test/CodeGen/aarch64-sme-intrinsics/acle_sme_ld1.c (+10-10) 
- (modified) clang/test/CodeGen/aarch64-sme-intrinsics/acle_sme_ld1_vnum.c (+10-10) 
- (modified) clang/test/CodeGen/aarch64-sme-intrinsics/acle_sme_ldr.c (+5-5) 
- (modified) clang/test/CodeGen/aarch64-sme-intrinsics/acle_sme_mopa-za32.c (+7-7) 
- (modified) clang/test/CodeGen/aarch64-sme-intrinsics/acle_sme_mopa-za64.c (+5-5) 
- (modified) clang/test/CodeGen/aarch64-sme-intrinsics/acle_sme_mops-za32.c (+7-7) 
- (modified) clang/test/CodeGen/aarch64-sme-intrinsics/acle_sme_mops-za64.c (+5-5) 
- (modified) clang/test/CodeGen/aarch64-sme-intrinsics/acle_sme_read.c (+96-96) 
- (modified) clang/test/CodeGen/aarch64-sme-intrinsics/acle_sme_st1.c (+10-10) 
- (modified) clang/test/CodeGen/aarch64-sme-intrinsics/acle_sme_st1_vnum.c (+10-10) 
- (modified) clang/test/CodeGen/aarch64-sme-intrinsics/acle_sme_state_funs.c (+1-1) 
- (modified) clang/test/CodeGen/aarch64-sme-intrinsics/acle_sme_str.c (+5-5) 
- (modified) clang/test/CodeGen/aarch64-sme-intrinsics/acle_sme_write.c (+96-96) 
- (modified) clang/test/CodeGen/aarch64-sme-intrinsics/acle_sme_zero.c (+4-4) 
- (modified) clang/test/CodeGen/aarch64-sme2-intrinsics/acle_sme2_add.c (+28-28) 
- (modified) clang/test/CodeGen/aarch64-sme2-intrinsics/acle_sme2_bmop.c (+4-4) 
- (modified) clang/test/CodeGen/aarch64-sme2-intrinsics/acle_sme2_fp_dots.c (+12-12) 
- (modified) clang/test/CodeGen/aarch64-sme2-intrinsics/acle_sme2_int_dots.c (+48-48) 
- (modified) clang/test/CodeGen/aarch64-sme2-intrinsics/acle_sme2_ldr_str_zt.c (+2-2) 
- (modified) clang/test/CodeGen/aarch64-sme2-intrinsics/acle_sme2_luti2_lane_zt.c (+9-9) 
- (modified) clang/test/CodeGen/aarch64-sme2-intrinsics/acle_sme2_luti2_lane_zt_x2.c (+9-9) 
- (modified) clang/test/CodeGen/aarch64-sme2-intrinsics/acle_sme2_luti2_lane_zt_x4.c (+9-9) 
- (modified) clang/test/CodeGen/aarch64-sme2-intrinsics/acle_sme2_luti4_lane_zt.c (+9-9) 
- (modified) clang/test/CodeGen/aarch64-sme2-intrinsics/acle_sme2_luti4_lane_zt_x2.c (+9-9) 
- (modified) clang/test/CodeGen/aarch64-sme2-intrinsics/acle_sme2_luti4_lane_zt_x4.c (+7-7) 
- (modified) clang/test/CodeGen/aarch64-sme2-intrinsics/acle_sme2_mla.c (+12-12) 
- (modified) clang/test/CodeGen/aarch64-sme2-intrinsics/acle_sme2_mlal.c (+32-32) 
- (modified) clang/test/CodeGen/aarch64-sme2-intrinsics/acle_sme2_mlall.c (+80-80) 
- (modified) clang/test/CodeGen/aarch64-sme2-intrinsics/acle_sme2_mls.c (+12-12) 
- (modified) clang/test/CodeGen/aarch64-sme2-intrinsics/acle_sme2_mlsl.c (+32-32) 
- (modified) clang/test/CodeGen/aarch64-sme2-intrinsics/acle_sme2_mop.c (+4-4) 
- (modified) clang/test/CodeGen/aarch64-sme2-intrinsics/acle_sme2_read.c (+72-72) 
- (modified) clang/test/CodeGen/aarch64-sme2-intrinsics/acle_sme2_sub.c (+28-28) 
- (modified) clang/test/CodeGen/aarch64-sme2-intrinsics/acle_sme2_vdot.c (+10-10) 
- (modified) clang/test/CodeGen/aarch64-sme2-intrinsics/acle_sme2_write.c (+72-72) 
- (modified) clang/test/CodeGen/aarch64-sme2-intrinsics/acle_sme2_zero_zt.c (+1-1) 
- (modified) clang/test/Modules/aarch64-sme-keywords.cppm (+3-3) 
- (modified) clang/test/Parser/c2x-attribute-keywords.c (+67-67) 
- (modified) clang/test/Parser/c2x-attribute-keywords.m (+2-2) 
- (modified) clang/test/Parser/cxx0x-keyword-attributes.cpp (+181-181) 
- (modified) clang/test/Sema/aarch64-incompat-sm-builtin-calls.c (+2-2) 
- (modified) clang/test/Sema/aarch64-sme-func-attrs-without-target-feature.cpp (+6-6) 
- (modified) clang/test/Sema/aarch64-sme-func-attrs.c (+134-64) 
- (modified) clang/test/Sema/aarch64-sme-intrinsics/acle_sme_imm.cpp (+7-7) 
- (modified) clang/test/Sema/aarch64-sme-intrinsics/acle_sme_target.c (+4-4) 
- (modified) clang/test/Sema/aarch64-sme2-intrinsics/acle_sme2_imm.cpp (+15-15) 
- (modified) clang/utils/TableGen/ClangAttrEmitter.cpp (+1-6) 
- (modified) clang/utils/TableGen/SveEmitter.cpp (+1-1) 


``````````diff
diff --git a/clang/include/clang/AST/Type.h b/clang/include/clang/AST/Type.h
index 1afa693672860f..fe22ea94137007 100644
--- a/clang/include/clang/AST/Type.h
+++ b/clang/include/clang/AST/Type.h
@@ -4033,12 +4033,27 @@ class FunctionType : public Type {
     SME_NormalFunction = 0,
     SME_PStateSMEnabledMask = 1 << 0,
     SME_PStateSMCompatibleMask = 1 << 1,
-    SME_PStateZASharedMask = 1 << 2,
-    SME_PStateZAPreservedMask = 1 << 3,
-    SME_AttributeMask = 0b111'111 // We only support maximum 6 bits because of the
-                                  // bitmask in FunctionTypeExtraBitfields.
+
+    // Describes the value of the state using ArmStateValue.
+    SME_PstateZAShift = 2,
+    SME_PStateZAMask = 0b111 << SME_PstateZAShift,
+
+    SME_AttributeMask = 0b111'111 // We only support maximum 6 bits because of
+                                  // the bitmask in FunctionTypeExtraBitfields.
+  };
+
+  enum ArmStateValue : unsigned {
+    ARM_None = 0,
+    ARM_Preserves = 1,
+    ARM_In = 2,
+    ARM_Out = 3,
+    ARM_InOut = 4,
   };
 
+  static ArmStateValue getArmZAState(unsigned AttrBits) {
+    return (ArmStateValue)((AttrBits & SME_PStateZAMask) >> SME_PstateZAShift);
+  }
+
   /// A simple holder for various uncommon bits which do not fit in
   /// FunctionTypeBitfields. Aligned to alignof(void *) to maintain the
   /// alignment of subsequent objects in TrailingObjects.
diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td
index db17211747b17d..31f9c9f938a849 100644
--- a/clang/include/clang/Basic/Attr.td
+++ b/clang/include/clang/Basic/Attr.td
@@ -2525,16 +2525,45 @@ def ArmStreamingCompatible : TypeAttr, TargetSpecificAttr<TargetAArch64> {
   let Documentation = [ArmSmeStreamingCompatibleDocs];
 }
 
-def ArmSharedZA : TypeAttr, TargetSpecificAttr<TargetAArch64> {
-  let Spellings = [RegularKeyword<"__arm_shared_za">];
+def ArmNew : InheritableAttr, TargetSpecificAttr<TargetAArch64> {
+  let Spellings = [RegularKeyword<"__arm_new">];
+  let Args = [VariadicStringArgument<"NewArgs">];
+  let Subjects = SubjectList<[Function], ErrorDiag>;
+  let Documentation = [ArmNewDocs];
+
+  let AdditionalMembers = [{
+    bool isNewZA() const {
+      return llvm::is_contained(newArgs(), "za");
+    }
+  }];
+}
+
+def ArmIn : TypeAttr, TargetSpecificAttr<TargetAArch64> {
+  let Spellings = [RegularKeyword<"__arm_in">];
+  let Args = [VariadicStringArgument<"InArgs">];
+  let Subjects = SubjectList<[HasFunctionProto], ErrorDiag>;
+  let Documentation = [ArmInDocs];
+}
+
+def ArmOut : TypeAttr, TargetSpecificAttr<TargetAArch64> {
+  let Spellings = [RegularKeyword<"__arm_out">];
+  let Args = [VariadicStringArgument<"OutArgs">];
+  let Subjects = SubjectList<[HasFunctionProto], ErrorDiag>;
+  let Documentation = [ArmOutDocs];
+}
+
+def ArmInOut : TypeAttr, TargetSpecificAttr<TargetAArch64> {
+  let Spellings = [RegularKeyword<"__arm_inout">];
+  let Args = [VariadicStringArgument<"InOutArgs">];
   let Subjects = SubjectList<[HasFunctionProto], ErrorDiag>;
-  let Documentation = [ArmSmeSharedZADocs];
+  let Documentation = [ArmInOutDocs];
 }
 
-def ArmPreservesZA : TypeAttr, TargetSpecificAttr<TargetAArch64> {
-  let Spellings = [RegularKeyword<"__arm_preserves_za">];
+def ArmPreserves : TypeAttr, TargetSpecificAttr<TargetAArch64> {
+  let Spellings = [RegularKeyword<"__arm_preserves">];
+  let Args = [VariadicStringArgument<"PreserveArgs">];
   let Subjects = SubjectList<[HasFunctionProto], ErrorDiag>;
-  let Documentation = [ArmSmePreservesZADocs];
+  let Documentation = [ArmPreservesDocs];
 }
 
 def ArmLocallyStreaming : InheritableAttr, TargetSpecificAttr<TargetAArch64> {
@@ -2543,14 +2572,6 @@ def ArmLocallyStreaming : InheritableAttr, TargetSpecificAttr<TargetAArch64> {
   let Documentation = [ArmSmeLocallyStreamingDocs];
 }
 
-def ArmNewZA : InheritableAttr, TargetSpecificAttr<TargetAArch64> {
-  let Spellings = [RegularKeyword<"__arm_new_za">];
-  let Subjects = SubjectList<[Function], ErrorDiag>;
-  let Documentation = [ArmSmeNewZADocs];
-}
-def : MutualExclusions<[ArmNewZA, ArmSharedZA]>;
-def : MutualExclusions<[ArmNewZA, ArmPreservesZA]>;
-
 
 def Pure : InheritableAttr {
   let Spellings = [GCC<"pure">];
diff --git a/clang/include/clang/Basic/AttrDocs.td b/clang/include/clang/Basic/AttrDocs.td
index 98a7ecc7fd7df3..f57035f9d139e1 100644
--- a/clang/include/clang/Basic/AttrDocs.td
+++ b/clang/include/clang/Basic/AttrDocs.td
@@ -6852,30 +6852,73 @@ without changing modes.
   }];
 }
 
-def ArmSmeSharedZADocs : Documentation {
+def ArmInDocs : Documentation {
   let Category = DocCatArmSmeAttributes;
   let Content = [{
-The ``__arm_shared_za`` keyword applies to prototyped function types and specifies
-that the function shares SME's matrix storage (ZA) with its caller.  This
-means that:
+The ``__arm_in`` keyword applies to prototyped function types and specifies
+that the function shares state S with its caller.  For ``__arm_in``, the
+function takes the state S as input and returns with the state S unchanged.
 
-* the function requires that the processor implements the Scalable Matrix
-  Extension (SME).
+The attribute takes string arguments to instruct the compiler which state
+is shared.  The supported states for S are:
 
-* the function enters with ZA in an active state.
+* ``"za"`` for Matrix Storage (requires SME)
 
-* the function returns with ZA in an active state.
+The attributes ``__arm_in(S)``, ``__arm_out(S)``, ``__arm_inout(S)`` and
+``__arm_preserves(S)`` are all mutually exclusive for the same state S.
   }];
 }
 
-def ArmSmePreservesZADocs : Documentation {
+def ArmOutDocs : Documentation {
   let Category = DocCatArmSmeAttributes;
   let Content = [{
-The ``__arm_preserves_za`` keyword applies to prototyped function types and
-specifies that the function does not modify ZA state.
+The ``__arm_out`` keyword applies to prototyped function types and specifies
+that the function shares state S with its caller.  For ``__arm_out``, the
+function ignores the incoming state for S and returns new state for S.
+
+The attribute takes string arguments to instruct the compiler which state
+is shared.  The supported states for S are:
+
+* ``"za"`` for Matrix Storage (requires SME)
+
+The attributes ``__arm_in(S)``, ``__arm_out(S)``, ``__arm_inout(S)`` and
+``__arm_preserves(S)`` are all mutually exclusive for the same state S.
   }];
 }
 
+def ArmInOutDocs : Documentation {
+  let Category = DocCatArmSmeAttributes;
+  let Content = [{
+The ``__arm_inout`` keyword applies to prototyped function types and specifies
+that the function shares state S with its caller.  For ``__arm_inout``, the
+function takes the state S as input and returns new state for S.
+
+The attribute takes string arguments to instruct the compiler which state
+is shared.  The supported states for S are:
+
+* ``"za"`` for Matrix Storage (requires SME)
+
+The attributes ``__arm_in(S)``, ``__arm_out(S)``, ``__arm_inout(S)`` and
+``__arm_preserves(S)`` are all mutually exclusive for the same state S.
+  }];
+}
+
+def ArmPreservesDocs : Documentation {
+  let Category = DocCatArmSmeAttributes;
+  let Content = [{
+The ``__arm_preserves`` keyword applies to prototyped function types and
+specifies that the function does not read the incoming state S and returns with
+the state S unchanged.
+
+The attribute takes string arguments to instruct the compiler which state
+is shared.  The supported states for S are:
+
+* ``"za"`` for Matrix Storage (requires SME)
+
+The attributes ``__arm_in(S)``, ``__arm_out(S)``, ``__arm_inout(S)`` and
+``__arm_preserves(S)`` are all mutually exclusive for the same state S.
+  }];
+}
 
 def ArmSmeLocallyStreamingDocs : Documentation {
   let Category = DocCatArmSmeAttributes;
@@ -6898,13 +6941,18 @@ at the end of the function.
   }];
 }
 
-def ArmSmeNewZADocs : Documentation {
+def ArmNewDocs : Documentation {
   let Category = DocCatArmSmeAttributes;
   let Content = [{
-The ``__arm_new_za`` keyword applies to function declarations and specifies
-that the function will be set up with a fresh ZA context.
+The ``__arm_new`` keyword applies to function declarations and specifies
+that the function will create a new scope for state S.
+
+The attribute takes string arguments to instruct the compiler for which state
+to create a new scope.  The supported states for S are:
+
+* ``"za"`` for Matrix Storage (requires SME)
 
-This means that:
+For state ``"za"``, this means that:
 
 * the function requires that the target processor implements the Scalable Matrix
   Extension (SME).
@@ -6915,8 +6963,8 @@ This means that:
 
 * the function will disable PSTATE.ZA (by setting it to 0) before returning.
 
-For ``__arm_new_za`` functions Clang will set up the ZA context automatically
-on entry to the function, and disable it before returning. For example, if ZA is
+For ``__arm_new("za")`` functions Clang will set up the ZA context automatically
+on entry to the function and disable it before returning. For example, if ZA is
 in a dormant state Clang will generate the code to commit a lazy-save and set up
 a new ZA state before executing user code.
   }];
diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index e54f969c19039d..ac6053b55cc753 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -3696,6 +3696,12 @@ def err_sme_definition_using_sm_in_non_sme_target : Error<
   "function executed in streaming-SVE mode requires 'sme'">;
 def err_sme_definition_using_za_in_non_sme_target : Error<
   "function using ZA state requires 'sme'">;
+def err_conflicting_attributes_sme_state : Error<
+  "conflicting attributes for state '%0'">;
+def err_unknown_arm_state : Error<
+  "unknown state '%0'">;
+def err_missing_arm_state : Error<
+  "missing state for %0">;
 def err_cconv_change : Error<
   "function declared '%0' here was previously declared "
   "%select{'%2'|without calling convention}1">;
diff --git a/clang/lib/AST/TypePrinter.cpp b/clang/lib/AST/TypePrinter.cpp
index f6941242927367..1baf895ebaec2c 100644
--- a/clang/lib/AST/TypePrinter.cpp
+++ b/clang/lib/AST/TypePrinter.cpp
@@ -937,15 +937,20 @@ void TypePrinter::printFunctionProtoAfter(const FunctionProtoType *T,
   OS << ')';
 
   FunctionType::ExtInfo Info = T->getExtInfo();
+  unsigned SMEBits = T->getAArch64SMEAttributes();
 
-  if ((T->getAArch64SMEAttributes() & FunctionType::SME_PStateSMCompatibleMask))
+  if (SMEBits & FunctionType::SME_PStateSMCompatibleMask)
     OS << " __arm_streaming_compatible";
-  if ((T->getAArch64SMEAttributes() & FunctionType::SME_PStateSMEnabledMask))
+  if (SMEBits & FunctionType::SME_PStateSMEnabledMask)
     OS << " __arm_streaming";
-  if ((T->getAArch64SMEAttributes() & FunctionType::SME_PStateZASharedMask))
-    OS << " __arm_shared_za";
-  if ((T->getAArch64SMEAttributes() & FunctionType::SME_PStateZAPreservedMask))
-    OS << " __arm_preserves_za";
+  if (FunctionType::getArmZAState(SMEBits) == FunctionType::ARM_Preserves)
+    OS << " __arm_preserves(\"za\")";
+  if (FunctionType::getArmZAState(SMEBits) == FunctionType::ARM_In)
+    OS << " __arm_in(\"za\")";
+  if (FunctionType::getArmZAState(SMEBits) == FunctionType::ARM_Out)
+    OS << " __arm_out(\"za\")";
+  if (FunctionType::getArmZAState(SMEBits) == FunctionType::ARM_InOut)
+    OS << " __arm_inout(\"za\")";
 
   printFunctionAfter(Info, OS);
 
@@ -1788,14 +1793,6 @@ void TypePrinter::printAttributedAfter(const AttributedType *T,
     OS << "__arm_streaming_compatible";
     return;
   }
-  if (T->getAttrKind() == attr::ArmSharedZA) {
-    OS << "__arm_shared_za";
-    return;
-  }
-  if (T->getAttrKind() == attr::ArmPreservesZA) {
-    OS << "__arm_preserves_za";
-    return;
-  }
 
   OS << " __attribute__((";
   switch (T->getAttrKind()) {
@@ -1839,8 +1836,10 @@ void TypePrinter::printAttributedAfter(const AttributedType *T,
   case attr::WebAssemblyFuncref:
   case attr::ArmStreaming:
   case attr::ArmStreamingCompatible:
-  case attr::ArmSharedZA:
-  case attr::ArmPreservesZA:
+  case attr::ArmIn:
+  case attr::ArmOut:
+  case attr::ArmInOut:
+  case attr::ArmPreserves:
     llvm_unreachable("This attribute should have been handled already");
 
   case attr::NSReturnsRetained:
diff --git a/clang/lib/CodeGen/CGCall.cpp b/clang/lib/CodeGen/CGCall.cpp
index 51a43b5f85b3cc..7e1389cbd8953c 100644
--- a/clang/lib/CodeGen/CGCall.cpp
+++ b/clang/lib/CodeGen/CGCall.cpp
@@ -1767,14 +1767,22 @@ static void AddAttributesFromFunctionProtoType(ASTContext &Ctx,
       FPT->isNothrow())
     FuncAttrs.addAttribute(llvm::Attribute::NoUnwind);
 
-  if (FPT->getAArch64SMEAttributes() & FunctionType::SME_PStateSMEnabledMask)
+  unsigned SMEBits = FPT->getAArch64SMEAttributes();
+  if (SMEBits & FunctionType::SME_PStateSMEnabledMask)
     FuncAttrs.addAttribute("aarch64_pstate_sm_enabled");
-  if (FPT->getAArch64SMEAttributes() & FunctionType::SME_PStateSMCompatibleMask)
+  if (SMEBits & FunctionType::SME_PStateSMCompatibleMask)
     FuncAttrs.addAttribute("aarch64_pstate_sm_compatible");
-  if (FPT->getAArch64SMEAttributes() & FunctionType::SME_PStateZASharedMask)
+
+  // ZA
+  if (FunctionType::getArmZAState(SMEBits) == FunctionType::ARM_Preserves)
+    FuncAttrs.addAttribute("aarch64_pstate_za_preserved");
+  if (FunctionType::getArmZAState(SMEBits) == FunctionType::ARM_Out ||
+      FunctionType::getArmZAState(SMEBits) == FunctionType::ARM_InOut)
+    FuncAttrs.addAttribute("aarch64_pstate_za_shared");
+  if (FunctionType::getArmZAState(SMEBits) == FunctionType::ARM_In) {
     FuncAttrs.addAttribute("aarch64_pstate_za_shared");
-  if (FPT->getAArch64SMEAttributes() & FunctionType::SME_PStateZAPreservedMask)
     FuncAttrs.addAttribute("aarch64_pstate_za_preserved");
+  }
 }
 
 static void AddAttributesFromAssumes(llvm::AttrBuilder &FuncAttrs,
@@ -2446,9 +2454,6 @@ void CodeGenModule::ConstructAttributeList(StringRef Name,
 
     if (TargetDecl->hasAttr<ArmLocallyStreamingAttr>())
       FuncAttrs.addAttribute("aarch64_pstate_sm_body");
-
-    if (TargetDecl->hasAttr<ArmNewZAAttr>())
-      FuncAttrs.addAttribute("aarch64_pstate_za_new");
   }
 
   // Attach "no-builtins" attributes to:
diff --git a/clang/lib/CodeGen/CodeGenModule.cpp b/clang/lib/CodeGen/CodeGenModule.cpp
index d78f2594a23764..3c67a5ce72d123 100644
--- a/clang/lib/CodeGen/CodeGenModule.cpp
+++ b/clang/lib/CodeGen/CodeGenModule.cpp
@@ -2378,8 +2378,10 @@ void CodeGenModule::SetLLVMFunctionAttributesForDefinition(const Decl *D,
   if (D->hasAttr<ArmLocallyStreamingAttr>())
     B.addAttribute("aarch64_pstate_sm_body");
 
-  if (D->hasAttr<ArmNewZAAttr>())
-    B.addAttribute("aarch64_pstate_za_new");
+  if (auto *Attr = D->getAttr<ArmNewAttr>()) {
+    if (Attr->isNewZA())
+      B.addAttribute("aarch64_pstate_za_new");
+  }
 
   // Track whether we need to add the optnone LLVM attribute,
   // starting with the default for this optimization level.
diff --git a/clang/lib/Parse/ParseDecl.cpp b/clang/lib/Parse/ParseDecl.cpp
index ed006f9d67de45..2c79f58e198933 100644
--- a/clang/lib/Parse/ParseDecl.cpp
+++ b/clang/lib/Parse/ParseDecl.cpp
@@ -6787,6 +6787,9 @@ void Parser::ParseDirectDeclarator(Declarator &D) {
       // For consistency with attribute parsing.
       Diag(Tok, diag::err_keyword_not_allowed) << Tok.getIdentifierInfo();
       ConsumeToken();
+      BalancedDelimiterTracker T(*this, tok::l_paren);
+      if (!T.consumeOpen())
+        T.skipToEnd();
     } else if (Tok.is(tok::kw_requires) && D.hasGroupingParens()) {
       // This declarator is declaring a function, but the requires clause is
       // in the wrong place:
diff --git a/clang/lib/Parse/ParseDeclCXX.cpp b/clang/lib/Parse/ParseDeclCXX.cpp
index 910112ecae964c..329c4729740bf7 100644
--- a/clang/lib/Parse/ParseDeclCXX.cpp
+++ b/clang/lib/Parse/ParseDeclCXX.cpp
@@ -1891,6 +1891,9 @@ void Parser::ParseClassSpecifier(tok::TokenKind TagTokKind,
           break;
       } else if (Tok.isRegularKeywordAttribute()) {
         ConsumeToken();
+        BalancedDelimiterTracker T(*this, tok::l_paren);
+        if (!T.consumeOpen())
+          T.skipToEnd();
       } else {
         break;
       }
@@ -4537,8 +4540,15 @@ void Parser::ParseCXX11AttributeSpecifierInternal(ParsedAttributes &Attrs,
   if (Tok.isRegularKeywordAttribute()) {
     SourceLocation Loc = Tok.getLocation();
     IdentifierInfo *AttrName = Tok.getIdentifierInfo();
-    Attrs.addNew(AttrName, Loc, nullptr, Loc, nullptr, 0, Tok.getKind());
+    ParsedAttr::Form Form = ParsedAttr::Form(Tok.getKind());
     ConsumeToken();
+    if (Tok.is(tok::l_paren)) {
+      const LangOptions &LO = getLangOpts();
+      unsigned NumArgs = ParseAttributeArgsCommon(AttrName, Loc, Attrs, EndLoc,
+                                                  /*ScopeName*/ nullptr,
+                                                  /*ScopeLoc*/ Loc, Form);
+    } else
+      Attrs.addNew(AttrName, Loc, nullptr, Loc, nullptr, 0, Form);
     return;
   }
 
@@ -4704,11 +4714,9 @@ SourceLocation Parser::SkipCXX11Attributes() {
       T.consumeOpen();
       T.skipToEnd();
       EndLoc = T.getCloseLocation();
-    } else if (Tok.isRegularKeywordAttribute()) {
-      EndLoc = Tok.getLocation();
-      ConsumeToken();
     } else {
-      assert(Tok.is(tok::kw_alignas) && "not an attribute specifier");
+      assert((Tok.is(tok::kw_alignas) || Tok.isRegularKeywordAttribute()) &&
+             "not an attribute specifier");
       ConsumeToken();
       BalancedDelimiterTracker T(*this, tok::l_paren);
       if (!T.consumeOpen())
diff --git a/clang/lib/Parse/ParseTentative.cpp b/clang/lib/Parse/ParseTentative.cpp
index 242741c15b5ffa..d7b21fbe2b57bc 100644
--- a/clang/lib/Parse/ParseTentative.cpp
+++ b/clang/lib/Parse/ParseTentative.cpp
@@ -894,8 +894,6 @@ bool Parser::TrySkipAttributes() {
       // Note that explicitly checking for `[[` and `]]` allows to fail as
       // expected in the case of the Objective-C message send syntax.
       ConsumeBracket();
-    } else if (Tok.isRegularKeywordAttribute()) {
-      ConsumeToken();
     } else {
       ConsumeToken();
       if (Tok.isNot(tok::l_paren))
diff --git a/clang/lib/Sema/SemaChecking.cpp b/clang/lib/Sema/SemaChecking.cpp
index 3168d38dd66c36..0f08042ac51668 100644
--- a/clang/lib/Sema/SemaChecking.cpp
+++ b/clang/lib/Sema/SemaChecking.cpp
@@ -3175,11 +3175,16 @@ static void checkArmStreamingBuiltin(Sema &S, CallExpr *TheCall,
 }
 
 static bool hasSMEZAState(const FunctionDecl *FD) {
-  if (FD->hasAttr<ArmNewZAAttr>())
-    return true;
-  if (const auto *T = FD->getType()->getAs<FunctionProtoType>())
-    if (T->getAArch64SMEAttributes() & FunctionType::SME_PStateZASharedMask)
+  if (auto *Attr = FD->getAttr<ArmNewAttr>())
+    if (Attr->isNewZA())
+      return true;
+  if (const auto *T = FD->getType()->getAs<FunctionProtoType>()) {
+    FunctionType::ArmStateValue State =
+        FunctionType::getArmZAState(T->getAArch64SMEAttributes());
+    if (State == FunctionType::ARM_In || State == FunctionType::ARM_Out ||
+        State == FunctionType::ARM_InOut)
       return true;
+  }
   return false;
 }
 
@@ -7507,14 +7512,24 @@ void Sema::checkCall(NamedDecl *FDecl, const FunctionProtoType *Proto,
 
     // If the callee uses AArch64 SME ZA state but the caller doesn't define
     // any, then this is an error.
-    if (ExtInfo.AArch64SMEAttributes & FunctionType::SME_PStateZASharedMask) {
+    FunctionType::ArmStateValue ArmZAState =
+        FunctionType::getArmZAState(ExtInfo.AArch64SMEAttributes);
+    if (ArmZAState == FunctionType::ARM_In ||
+        ArmZAState == FunctionType::ARM_Out ||
+        ArmZAState == FunctionType::ARM_InOut) {
       bool CallerHasZAState = false;
       if (const auto *CallerFD = dyn_cast<FunctionDecl>(CurContext)) {
-        if (CallerFD->hasAttr<ArmNewZAAttr>())
+        auto *Attr = CallerFD->getAttr<ArmNewAttr>();
+        if (Attr && Attr->isNewZA())
           CallerHasZAState = true;
-        else if (const auto *FPT = CallerFD->getType()->getAs<FunctionProtoType>())
-          CallerHasZAState = FPT->getExtProtoInfo().AArch64SMEAttributes &
-                             FunctionType::SME_PStateZASharedMask;
+        else if (const auto *FPT =
+                     CallerFD->getType()->getAs<FunctionProtoType>()) {
+          ArmZAState = FunctionType::getArmZAState(
+              FPT->getExtProtoInfo().AArch64SMEAttributes);
+          CallerHasZAStat...
[truncated]

``````````
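The `Type.h` hunk packs the ZA state into a 3-bit field (bits 2 to 4) of the existing 6-bit SME attribute mask, next to the two streaming-mode flags. The sketch below reproduces the enumerators and `getArmZAState` from that hunk. The enclosing enum name is assumed. `encodeArmZAState` is a hypothetical inverse added for illustration and is not part of the patch.

```cpp
#include <cassert>

// Enumerators copied from the Type.h hunk; the enum name is assumed.
enum AArch64SMETypeAttributes : unsigned {
  SME_NormalFunction = 0,
  SME_PStateSMEnabledMask = 1 << 0,
  SME_PStateSMCompatibleMask = 1 << 1,
  SME_PstateZAShift = 2,
  SME_PStateZAMask = 0b111 << SME_PstateZAShift, // bits 2..4
  SME_AttributeMask = 0b111'111,                 // 6 bits in total
};

enum ArmStateValue : unsigned {
  ARM_None = 0,
  ARM_Preserves = 1,
  ARM_In = 2,
  ARM_Out = 3,
  ARM_InOut = 4,
};

// The ZA field must fit inside the 6-bit FunctionTypeExtraBitfields mask.
static_assert((SME_PStateZAMask & ~unsigned(SME_AttributeMask)) == 0,
              "ZA field must fit in the SME attribute mask");

// Same logic as FunctionType::getArmZAState in the patch.
ArmStateValue getArmZAState(unsigned AttrBits) {
  return static_cast<ArmStateValue>((AttrBits & SME_PStateZAMask) >>
                                    SME_PstateZAShift);
}

// Hypothetical inverse: replaces the ZA field and keeps the
// streaming-mode bits (0 and 1) untouched.
unsigned encodeArmZAState(unsigned AttrBits, ArmStateValue State) {
  unsigned Bits = (AttrBits & ~unsigned(SME_PStateZAMask)) |
                  (unsigned(State) << SME_PstateZAShift);
  assert((Bits & ~unsigned(SME_AttributeMask)) == 0 && "must fit in 6 bits");
  return Bits;
}
```

For example, a streaming function with `__arm_inout("za")` would be encoded as `0b10001`: the `SME_PStateSMEnabledMask` bit plus `ARM_InOut` (4) shifted left by 2.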

</details>
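In the Sema changes, `hasSMEZAState` in `SemaChecking.cpp` decides whether a function has ZA state. That holds if the function is declared `__arm_new("za")` or its type shares ZA via `__arm_in`, `__arm_out` or `__arm_inout`; `__arm_preserves("za")` alone does not count. The callee-side condition in `Sema::checkCall` uses the same "shares ZA" test.

The sketch below models those two predicates. The `ZAInfo` struct and the function names are hypothetical. The caller-side half of the `checkCall` logic is truncated in this excerpt, so it is deliberately not modelled.

```cpp
#include <cassert>

// Same enumerators as FunctionType::ArmStateValue in this patch.
enum ArmStateValue : unsigned {
  ARM_None = 0,
  ARM_Preserves = 1,
  ARM_In = 2,
  ARM_Out = 3,
  ARM_InOut = 4,
};

// Hypothetical summary of the ZA-related attributes on one function.
struct ZAInfo {
  bool NewZA;          // __arm_new("za"), a declaration attribute
  ArmStateValue State; // __arm_in/out/inout/preserves("za") on the type
};

// True if the function type shares ZA with its caller. In checkCall, a
// callee for which this holds triggers the "caller must have ZA state" check.
bool sharesZA(ArmStateValue State) {
  return State == ARM_In || State == ARM_Out || State == ARM_InOut;
}

// Mirrors hasSMEZAState: __arm_new("za") or a ZA-sharing function type.
bool hasZAState(const ZAInfo &F) { return F.NewZA || sharesZA(F.State); }
```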


https://github.com/llvm/llvm-project/pull/76971

