[Mlir-commits] [llvm] [clang] [mlir] [AArch64] Replace LLVM IR function attributes for PSTATE.ZA. (PR #79166)

Sander de Smalen llvmlistbot at llvm.org
Wed Jan 31 07:04:13 PST 2024


https://github.com/sdesmalen-arm updated https://github.com/llvm/llvm-project/pull/79166

>From c055495729fe35c4c49a05fdc64a780b2cf72d9e Mon Sep 17 00:00:00 2001
From: Sander de Smalen <sander.desmalen at arm.com>
Date: Mon, 22 Jan 2024 16:50:41 +0100
Subject: [PATCH 1/4] [AArch64] Replace LLVM IR function attributes for
 PSTATE.ZA.

Since https://github.com/ARM-software/acle/pull/276 the ACLE
defines attributes to better describe the use of a given SME state.

Previously the attributes merely described the possibility of it being
'shared' or 'preserved', whereas the new attributes have more semantics
and also describe how the data flows through the program.

For ZT0 we already had to add new LLVM IR attributes:
* aarch64_new_zt0
* aarch64_in_zt0
* aarch64_out_zt0
* aarch64_inout_zt0
* aarch64_preserves_zt0

We have now done the same for ZA, such that we add:
* aarch64_new_za       (previously `aarch64_pstate_za_new`)
* aarch64_in_za        (more specific variation of `aarch64_pstate_za_shared`)
* aarch64_out_za       (more specific variation of `aarch64_pstate_za_shared`)
* aarch64_inout_za     (more specific variation of `aarch64_pstate_za_shared`)
* aarch64_preserves_za (previously `aarch64_pstate_za_shared, aarch64_pstate_za_preserved`)

This explicitly removes 'pstate' from the name, because with SME2 and
the new ACLE attributes there is a difference between "sharing ZA" (sharing
the ZA matrix register with the caller) and "sharing PSTATE.ZA" (sharing
either the ZA or ZT0 register, both part of PSTATE.ZA, with the caller).
---
 clang/lib/CodeGen/CGBuiltin.cpp               |   6 +-
 clang/lib/CodeGen/CGCall.cpp                  |  16 +--
 clang/lib/CodeGen/CodeGenModule.cpp           |   2 +-
 .../aarch64-sme-attrs.cpp                     |  18 +--
 .../aarch64-sme-intrinsics/acle_sme_zero.c    |   4 +-
 clang/test/Modules/aarch64-sme-keywords.cppm  |  10 +-
 llvm/docs/AArch64SME.rst                      |  37 +++---
 llvm/lib/IR/Verifier.cpp                      |  19 ++-
 .../AArch64/AArch64TargetTransformInfo.cpp    |   2 +-
 llvm/lib/Target/AArch64/SMEABIPass.cpp        |   9 +-
 .../AArch64/Utils/AArch64SMEAttributes.cpp    |  32 +++--
 .../AArch64/Utils/AArch64SMEAttributes.h      |  33 +++--
 .../AArch64/sme-disable-gisel-fisel.ll        |  10 +-
 .../AArch64/sme-lazy-save-call-remarks.ll     |   6 +-
 .../CodeGen/AArch64/sme-lazy-save-call.ll     |   8 +-
 .../CodeGen/AArch64/sme-new-za-function.ll    |   8 +-
 .../AArch64/sme-shared-za-interface.ll        |   4 +-
 llvm/test/CodeGen/AArch64/sme-zt0-state.ll    |  22 ++--
 .../Inline/AArch64/sme-pstateza-attrs.ll      |  22 ++--
 llvm/test/Verifier/sme-attributes.ll          |  32 ++++-
 .../Target/AArch64/SMEAttributesTest.cpp      | 116 ++++++++++++------
 .../mlir/Dialect/ArmSME/Transforms/Passes.td  |  35 ++++--
 mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td   |   4 +-
 mlir/lib/Target/LLVMIR/ModuleImport.cpp       |  21 ++--
 mlir/lib/Target/LLVMIR/ModuleTranslation.cpp  |  12 +-
 mlir/test/Dialect/ArmSME/enable-arm-za.mlir   |  16 ++-
 .../LLVMIR/Import/function-attributes.ll      |  26 ++--
 mlir/test/Target/LLVMIR/llvmir.mlir           |  24 +++-
 28 files changed, 348 insertions(+), 206 deletions(-)

diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index f3ab5ad7b08ec..196be813a4896 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -10676,10 +10676,8 @@ Value *CodeGenFunction::EmitAArch64BuiltinExpr(unsigned BuiltinID,
         llvm::FunctionType::get(StructType::get(CGM.Int64Ty, CGM.Int64Ty), {},
                                 false),
         "__arm_sme_state"));
-    auto Attrs =
-        AttributeList()
-            .addFnAttribute(getLLVMContext(), "aarch64_pstate_sm_compatible")
-            .addFnAttribute(getLLVMContext(), "aarch64_pstate_za_preserved");
+    auto Attrs = AttributeList().addFnAttribute(getLLVMContext(),
+                                                "aarch64_pstate_sm_compatible");
     CI->setAttributes(Attrs);
     CI->setCallingConv(
         llvm::CallingConv::
diff --git a/clang/lib/CodeGen/CGCall.cpp b/clang/lib/CodeGen/CGCall.cpp
index 28c211aa631e4..657666c9bda4e 100644
--- a/clang/lib/CodeGen/CGCall.cpp
+++ b/clang/lib/CodeGen/CGCall.cpp
@@ -1774,14 +1774,14 @@ static void AddAttributesFromFunctionProtoType(ASTContext &Ctx,
     FuncAttrs.addAttribute("aarch64_pstate_sm_compatible");
 
   // ZA
-  if (FunctionType::getArmZAState(SMEBits) == FunctionType::ARM_Out ||
-      FunctionType::getArmZAState(SMEBits) == FunctionType::ARM_InOut)
-    FuncAttrs.addAttribute("aarch64_pstate_za_shared");
-  if (FunctionType::getArmZAState(SMEBits) == FunctionType::ARM_Preserves ||
-      FunctionType::getArmZAState(SMEBits) == FunctionType::ARM_In) {
-    FuncAttrs.addAttribute("aarch64_pstate_za_shared");
-    FuncAttrs.addAttribute("aarch64_pstate_za_preserved");
-  }
+  if (FunctionType::getArmZAState(SMEBits) == FunctionType::ARM_Preserves)
+    FuncAttrs.addAttribute("aarch64_preserves_za");
+  if (FunctionType::getArmZAState(SMEBits) == FunctionType::ARM_In)
+    FuncAttrs.addAttribute("aarch64_in_za");
+  if (FunctionType::getArmZAState(SMEBits) == FunctionType::ARM_Out)
+    FuncAttrs.addAttribute("aarch64_out_za");
+  if (FunctionType::getArmZAState(SMEBits) == FunctionType::ARM_InOut)
+    FuncAttrs.addAttribute("aarch64_inout_za");
 
   // ZT0
   if (FunctionType::getArmZT0State(SMEBits) == FunctionType::ARM_Preserves)
diff --git a/clang/lib/CodeGen/CodeGenModule.cpp b/clang/lib/CodeGen/CodeGenModule.cpp
index 6ec54cc01c923..c63e4ecc3dcba 100644
--- a/clang/lib/CodeGen/CodeGenModule.cpp
+++ b/clang/lib/CodeGen/CodeGenModule.cpp
@@ -2414,7 +2414,7 @@ void CodeGenModule::SetLLVMFunctionAttributesForDefinition(const Decl *D,
 
   if (auto *Attr = D->getAttr<ArmNewAttr>()) {
     if (Attr->isNewZA())
-      B.addAttribute("aarch64_pstate_za_new");
+      B.addAttribute("aarch64_new_za");
     if (Attr->isNewZT0())
       B.addAttribute("aarch64_new_zt0");
   }
diff --git a/clang/test/CodeGen/aarch64-sme-intrinsics/aarch64-sme-attrs.cpp b/clang/test/CodeGen/aarch64-sme-intrinsics/aarch64-sme-attrs.cpp
index f69703a8a7d89..fdd2de11365dd 100644
--- a/clang/test/CodeGen/aarch64-sme-intrinsics/aarch64-sme-attrs.cpp
+++ b/clang/test/CodeGen/aarch64-sme-intrinsics/aarch64-sme-attrs.cpp
@@ -284,20 +284,20 @@ int test_variadic_template() __arm_inout("za") {
 // CHECK: attributes #[[SM_COMPATIBLE]] = { mustprogress noinline nounwind "aarch64_pstate_sm_compatible" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-features"="+bf16,+sme" }
 // CHECK: attributes #[[SM_COMPATIBLE_DECL]] = { "aarch64_pstate_sm_compatible" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-features"="+bf16,+sme" }
 // CHECK: attributes #[[SM_BODY]] = { mustprogress noinline nounwind "aarch64_pstate_sm_body" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-features"="+bf16,+sme" }
-// CHECK: attributes #[[ZA_SHARED]] = { mustprogress noinline nounwind "aarch64_pstate_za_shared" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-features"="+bf16,+sme" }
-// CHECK: attributes #[[ZA_SHARED_DECL]] = { "aarch64_pstate_za_shared" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-features"="+bf16,+sme" }
-// CHECK: attributes #[[ZA_PRESERVED]] = { mustprogress noinline nounwind "aarch64_pstate_za_preserved" "aarch64_pstate_za_shared" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-features"="+bf16,+sme" }
-// CHECK: attributes #[[ZA_PRESERVED_DECL]] = { "aarch64_pstate_za_preserved" "aarch64_pstate_za_shared" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-features"="+bf16,+sme" }
-// CHECK: attributes #[[ZA_NEW]] = { mustprogress noinline nounwind "aarch64_pstate_za_new" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-features"="+bf16,+sme" }
+// CHECK: attributes #[[ZA_SHARED]] = { mustprogress noinline nounwind "aarch64_inout_za" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-features"="+bf16,+sme" }
+// CHECK: attributes #[[ZA_SHARED_DECL]] = { "aarch64_inout_za" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-features"="+bf16,+sme" }
+// CHECK: attributes #[[ZA_PRESERVED]] = { mustprogress noinline nounwind "aarch64_preserves_za" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-features"="+bf16,+sme" }
+// CHECK: attributes #[[ZA_PRESERVED_DECL]] = { "aarch64_preserves_za" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-features"="+bf16,+sme" }
+// CHECK: attributes #[[ZA_NEW]] = { mustprogress noinline nounwind "aarch64_new_za" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-features"="+bf16,+sme" }
 // CHECK: attributes #[[NORMAL_DEF]] = { mustprogress noinline nounwind "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-features"="+bf16,+sme" }
 // CHECK: attributes #[[SM_ENABLED_CALL]] = { "aarch64_pstate_sm_enabled" }
 // CHECK: attributes #[[SM_COMPATIBLE_CALL]] = { "aarch64_pstate_sm_compatible" }
 // CHECK: attributes #[[SM_BODY_CALL]] = { "aarch64_pstate_sm_body" }
-// CHECK: attributes #[[ZA_SHARED_CALL]] = { "aarch64_pstate_za_shared" }
-// CHECK: attributes #[[ZA_PRESERVED_CALL]] = { "aarch64_pstate_za_preserved" "aarch64_pstate_za_shared" }
+// CHECK: attributes #[[ZA_SHARED_CALL]] = { "aarch64_inout_za" }
+// CHECK: attributes #[[ZA_PRESERVED_CALL]] = { "aarch64_preserves_za" }
 // CHECK: attributes #[[NOUNWIND_CALL]] = { nounwind }
 // CHECK: attributes #[[NOUNWIND_SM_ENABLED_CALL]] = { nounwind "aarch64_pstate_sm_enabled" }
 // CHECK: attributes #[[NOUNWIND_SM_COMPATIBLE_CALL]] = { nounwind "aarch64_pstate_sm_compatible" }
-// CHECK: attributes #[[NOUNWIND_ZA_SHARED_CALL]] = { nounwind "aarch64_pstate_za_shared" }
-// CHECK: attributes #[[NOUNWIND_ZA_PRESERVED_CALL]] = { nounwind "aarch64_pstate_za_preserved" "aarch64_pstate_za_shared" }
+// CHECK: attributes #[[NOUNWIND_ZA_SHARED_CALL]] = { nounwind "aarch64_inout_za" }
+// CHECK: attributes #[[NOUNWIND_ZA_PRESERVED_CALL]] = { nounwind "aarch64_preserves_za" }
 
diff --git a/clang/test/CodeGen/aarch64-sme-intrinsics/acle_sme_zero.c b/clang/test/CodeGen/aarch64-sme-intrinsics/acle_sme_zero.c
index 7f56941108828..9963c0e48b8e7 100644
--- a/clang/test/CodeGen/aarch64-sme-intrinsics/acle_sme_zero.c
+++ b/clang/test/CodeGen/aarch64-sme-intrinsics/acle_sme_zero.c
@@ -55,13 +55,13 @@ void test_svzero_mask_za_2(void) __arm_inout("za") {
 }
 
 // CHECK-C-LABEL: define dso_local void @test_svzero_za(
-// CHECK-C-SAME: ) local_unnamed_addr #[[ATTR0]] {
+// CHECK-C-SAME: ) local_unnamed_addr #[[ATTR2:[0-9]+]] {
 // CHECK-C-NEXT:  entry:
 // CHECK-C-NEXT:    tail call void @llvm.aarch64.sme.zero(i32 255)
 // CHECK-C-NEXT:    ret void
 //
 // CHECK-CXX-LABEL: define dso_local void @_Z14test_svzero_zav(
-// CHECK-CXX-SAME: ) local_unnamed_addr #[[ATTR0]] {
+// CHECK-CXX-SAME: ) local_unnamed_addr #[[ATTR2:[0-9]+]] {
 // CHECK-CXX-NEXT:  entry:
 // CHECK-CXX-NEXT:    tail call void @llvm.aarch64.sme.zero(i32 255)
 // CHECK-CXX-NEXT:    ret void
diff --git a/clang/test/Modules/aarch64-sme-keywords.cppm b/clang/test/Modules/aarch64-sme-keywords.cppm
index df4dd32b16cff..759701a633ceb 100644
--- a/clang/test/Modules/aarch64-sme-keywords.cppm
+++ b/clang/test/Modules/aarch64-sme-keywords.cppm
@@ -43,14 +43,14 @@ import A;
 //
 // CHECK:declare void @_ZW1A22f_streaming_compatiblev() #[[STREAMING_COMPATIBLE_DECL:[0-9]+]]
 //
-// CHECK-DAG: attributes #[[SHARED_ZA_DEF]] = {{{.*}} "aarch64_pstate_za_shared" {{.*}}}
-// CHECK-DAG: attributes #[[SHARED_ZA_DECL]] = {{{.*}} "aarch64_pstate_za_shared" {{.*}}}
-// CHECK-DAG: attributes #[[PRESERVES_ZA_DECL]] = {{{.*}} "aarch64_pstate_za_preserved" {{.*}}}
+// CHECK-DAG: attributes #[[SHARED_ZA_DEF]] = {{{.*}} "aarch64_inout_za" {{.*}}}
+// CHECK-DAG: attributes #[[SHARED_ZA_DECL]] = {{{.*}} "aarch64_inout_za" {{.*}}}
+// CHECK-DAG: attributes #[[PRESERVES_ZA_DECL]] = {{{.*}} "aarch64_preserves_za" {{.*}}}
 // CHECK-DAG: attributes #[[NORMAL_DEF]] = {{{.*}}}
 // CHECK-DAG: attributes #[[STREAMING_DECL]] = {{{.*}} "aarch64_pstate_sm_enabled" {{.*}}}
 // CHECK-DAG: attributes #[[STREAMING_COMPATIBLE_DECL]] = {{{.*}} "aarch64_pstate_sm_compatible" {{.*}}}
-// CHECK-DAG: attributes #[[SHARED_ZA_USE]] = { "aarch64_pstate_za_shared" }
-// CHECK-DAG: attributes #[[PRESERVES_ZA_USE]] = { "aarch64_pstate_za_preserved" "aarch64_pstate_za_shared" }
+// CHECK-DAG: attributes #[[SHARED_ZA_USE]] = { "aarch64_inout_za" }
+// CHECK-DAG: attributes #[[PRESERVES_ZA_USE]] = { "aarch64_preserves_za" }
 // CHECK-DAG: attributes #[[STREAMING_USE]] = { "aarch64_pstate_sm_enabled" }
 // CHECK-DAG: attributes #[[STREAMING_COMPATIBLE_USE]] = { "aarch64_pstate_sm_compatible" }
 
diff --git a/llvm/docs/AArch64SME.rst b/llvm/docs/AArch64SME.rst
index 63573bf91eacb..4b6fa5e10f880 100644
--- a/llvm/docs/AArch64SME.rst
+++ b/llvm/docs/AArch64SME.rst
@@ -22,26 +22,32 @@ Below we describe the LLVM IR attributes and their relation to the C/C++
 level ACLE attributes:
 
 ``aarch64_pstate_sm_enabled``
-    is used for functions with ``__attribute__((arm_streaming))``
+    is used for functions with ``__arm_streaming``
 
 ``aarch64_pstate_sm_compatible``
-    is used for functions with ``__attribute__((arm_streaming_compatible))``
+    is used for functions with ``__arm_streaming_compatible``
 
 ``aarch64_pstate_sm_body``
-  is used for functions with ``__attribute__((arm_locally_streaming))`` and is
+  is used for functions with ``__arm_locally_streaming`` and is
   only valid on function definitions (not declarations)
 
-``aarch64_pstate_za_new``
-  is used for functions with ``__attribute__((arm_new_za))``
+``aarch64_new_za``
+  is used for functions with ``__arm_new("za")``
 
-``aarch64_pstate_za_shared``
-  is used for functions with ``__attribute__((arm_shared_za))``
+``aarch64_in_za``
+  is used for functions with ``__arm_in("za")``
 
-``aarch64_pstate_za_preserved``
-  is used for functions with ``__attribute__((arm_preserves_za))``
+``aarch64_out_za``
+  is used for functions with ``__arm_out("za")``
+
+``aarch64_inout_za``
+  is used for functions with ``__arm_inout("za")``
+
+``aarch64_preserves_za``
+  is used for functions with ``__arm_preserves("za")``
 
 ``aarch64_expanded_pstate_za``
-  is used for functions with ``__attribute__((arm_new_za))``
+  is used for functions with ``__arm_new("za")``
 
 Clang must ensure that the above attributes are added both to the
 function's declaration/definition as well as to their call-sites. This is
@@ -89,11 +95,10 @@ Restrictions on attributes
 * It is not allowed for a function to be decorated with both
   ``aarch64_pstate_sm_compatible`` and ``aarch64_pstate_sm_enabled``.
 
-* It is not allowed for a function to be decorated with both
-  ``aarch64_pstate_za_new`` and ``aarch64_pstate_za_preserved``.
-
-* It is not allowed for a function to be decorated with both
-  ``aarch64_pstate_za_new`` and ``aarch64_pstate_za_shared``.
+* It is not allowed for a function to be decorated with more than one of the
+  following attributes:
+  ``aarch64_new_za``, ``aarch64_in_za``, ``aarch64_out_za``, ``aarch64_inout_za``,
+  ``aarch64_preserves_za``.
 
 These restrictions also apply in the higher level SME ACLE, which means we can
 emit diagnostics in Clang to signal users about incorrect behaviour.
@@ -426,7 +431,7 @@ to toggle PSTATE.ZA using intrinsics. This also makes it simpler to setup a
 lazy-save mechanism for calls to private-ZA functions (i.e. functions that may
 either directly or indirectly clobber ZA state).
 
-For the purpose of handling functions marked with ``aarch64_pstate_za_new``,
+For the purpose of handling functions marked with ``aarch64_new_za``,
 we have introduced a new LLVM IR pass (SMEABIPass) that is run just before
 SelectionDAG. Any such functions dealt with by this pass are marked with
 ``aarch64_expanded_pstate_za``.
diff --git a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp
index 91cf91fbc788b..53d923ced8d2e 100644
--- a/llvm/lib/IR/Verifier.cpp
+++ b/llvm/lib/IR/Verifier.cpp
@@ -2155,17 +2155,14 @@ void Verifier::verifyFunctionAttrs(FunctionType *FT, AttributeList Attrs,
            V);
   }
 
-  if (Attrs.hasFnAttr("aarch64_pstate_za_new")) {
-    Check(!Attrs.hasFnAttr("aarch64_pstate_za_preserved"),
-           "Attributes 'aarch64_pstate_za_new and aarch64_pstate_za_preserved' "
-           "are incompatible!",
-           V);
-
-    Check(!Attrs.hasFnAttr("aarch64_pstate_za_shared"),
-           "Attributes 'aarch64_pstate_za_new and aarch64_pstate_za_shared' "
-           "are incompatible!",
-           V);
-  }
+  Check(
+      (Attrs.hasFnAttr("aarch64_new_za") + Attrs.hasFnAttr("aarch64_in_za") +
+       Attrs.hasFnAttr("aarch64_inout_za") +
+       Attrs.hasFnAttr("aarch64_out_za") +
+       Attrs.hasFnAttr("aarch64_preserves_za")) <= 1,
+      "Attributes 'aarch64_new_za', 'aarch64_in_za', 'aarch64_out_za', "
+      "'aarch64_inout_za' and 'aarch64_preserves_za' are mutually exclusive",
+      V);
 
   Check(
       (Attrs.hasFnAttr("aarch64_new_zt0") + Attrs.hasFnAttr("aarch64_in_zt0") +
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 992b11da7eeee..cdd2750521d2c 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -242,7 +242,7 @@ bool AArch64TTIImpl::areInlineCompatible(const Function *Caller,
     CalleeAttrs.set(SMEAttrs::SM_Enabled, true);
   }
 
-  if (CalleeAttrs.hasNewZABody())
+  if (CalleeAttrs.isNewZA())
     return false;
 
   if (CallerAttrs.requiresLazySave(CalleeAttrs) ||
diff --git a/llvm/lib/Target/AArch64/SMEABIPass.cpp b/llvm/lib/Target/AArch64/SMEABIPass.cpp
index 0247488ce93f1..bee5d63c5a749 100644
--- a/llvm/lib/Target/AArch64/SMEABIPass.cpp
+++ b/llvm/lib/Target/AArch64/SMEABIPass.cpp
@@ -62,8 +62,7 @@ void emitTPIDR2Save(Module *M, IRBuilder<> &Builder) {
       FunctionType::get(Builder.getVoidTy(), {}, /*IsVarArgs=*/false);
   auto Attrs =
       AttributeList()
-          .addFnAttribute(M->getContext(), "aarch64_pstate_sm_compatible")
-          .addFnAttribute(M->getContext(), "aarch64_pstate_za_preserved");
+          .addFnAttribute(M->getContext(), "aarch64_pstate_sm_compatible");
   FunctionCallee Callee =
       M->getOrInsertFunction("__arm_tpidr2_save", TPIDR2SaveTy, Attrs);
   CallInst *Call = Builder.CreateCall(Callee);
@@ -78,7 +77,7 @@ void emitTPIDR2Save(Module *M, IRBuilder<> &Builder) {
 }
 
 /// This function generates code at the beginning and end of a function marked
-/// with either `aarch64_pstate_za_new` or `aarch64_new_zt0`.
+/// with either `aarch64_new_za` or `aarch64_new_zt0`.
 /// At the beginning of the function, the following code is generated:
 ///  - Commit lazy-save if active   [Private-ZA Interface*]
 ///  - Enable PSTATE.ZA             [Private-ZA Interface]
@@ -133,7 +132,7 @@ bool SMEABI::updateNewStateFunctions(Module *M, Function *F,
     Builder.CreateCall(EnableZAIntr->getFunctionType(), EnableZAIntr);
   }
 
-  if (FnAttrs.hasNewZABody()) {
+  if (FnAttrs.isNewZA()) {
     Function *ZeroIntr =
         Intrinsic::getDeclaration(M, Intrinsic::aarch64_sme_zero);
     Builder.CreateCall(ZeroIntr->getFunctionType(), ZeroIntr,
@@ -174,7 +173,7 @@ bool SMEABI::runOnFunction(Function &F) {
 
   bool Changed = false;
   SMEAttrs FnAttrs(F);
-  if (FnAttrs.hasNewZABody() || FnAttrs.isNewZT0())
+  if (FnAttrs.isNewZA() || FnAttrs.isNewZT0())
     Changed |= updateNewStateFunctions(M, &F, Builder, FnAttrs);
 
   return Changed;
diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
index 3ee54e5df0a13..ac07274a36c23 100644
--- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
+++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
@@ -23,13 +23,15 @@ void SMEAttrs::set(unsigned M, bool Enable) {
          "SM_Enabled and SM_Compatible are mutually exclusive");
 
   // ZA Attrs
-  assert(!(hasNewZABody() && sharesZA()) &&
-         "ZA_New and ZA_Shared are mutually exclusive");
-  assert(!(hasNewZABody() && preservesZA()) &&
-         "ZA_New and ZA_Preserved are mutually exclusive");
-  assert(!(hasNewZABody() && (Bitmask & SME_ABI_Routine)) &&
+  assert(!(isNewZA() && (Bitmask & SME_ABI_Routine)) &&
          "ZA_New and SME_ABI_Routine are mutually exclusive");
 
+  assert(
+      (!sharesZA() ||
+       (isNewZA() ^ isInZA() ^ isInOutZA() ^ isOutZA() ^ isPreservesZA())) &&
+      "Attributes 'aarch64_new_za', 'aarch64_in_za', 'aarch64_out_za', "
+      "'aarch64_inout_za' and 'aarch64_preserves_za' are mutually exclusive");
+
   // ZT0 Attrs
   assert(
       (!sharesZT0() || (isNewZT0() ^ isInZT0() ^ isInOutZT0() ^ isOutZT0() ^
@@ -49,8 +51,8 @@ SMEAttrs::SMEAttrs(StringRef FuncName) : Bitmask(0) {
   if (FuncName == "__arm_tpidr2_save" || FuncName == "__arm_sme_state")
     Bitmask |= (SMEAttrs::SM_Compatible | SMEAttrs::SME_ABI_Routine);
   if (FuncName == "__arm_tpidr2_restore")
-    Bitmask |= (SMEAttrs::SM_Compatible | SMEAttrs::ZA_Shared |
-                SMEAttrs::SME_ABI_Routine);
+    Bitmask |= SMEAttrs::SM_Compatible | encodeZAState(StateValue::In) |
+                SMEAttrs::SME_ABI_Routine;
 }
 
 SMEAttrs::SMEAttrs(const AttributeList &Attrs) {
@@ -61,12 +63,16 @@ SMEAttrs::SMEAttrs(const AttributeList &Attrs) {
     Bitmask |= SM_Compatible;
   if (Attrs.hasFnAttr("aarch64_pstate_sm_body"))
     Bitmask |= SM_Body;
-  if (Attrs.hasFnAttr("aarch64_pstate_za_shared"))
-    Bitmask |= ZA_Shared;
-  if (Attrs.hasFnAttr("aarch64_pstate_za_new"))
-    Bitmask |= ZA_New;
-  if (Attrs.hasFnAttr("aarch64_pstate_za_preserved"))
-    Bitmask |= ZA_Preserved;
+  if (Attrs.hasFnAttr("aarch64_in_za"))
+    Bitmask |= encodeZAState(StateValue::In);
+  if (Attrs.hasFnAttr("aarch64_out_za"))
+    Bitmask |= encodeZAState(StateValue::Out);
+  if (Attrs.hasFnAttr("aarch64_inout_za"))
+    Bitmask |= encodeZAState(StateValue::InOut);
+  if (Attrs.hasFnAttr("aarch64_preserves_za"))
+    Bitmask |= encodeZAState(StateValue::Preserved);
+  if (Attrs.hasFnAttr("aarch64_new_za"))
+    Bitmask |= encodeZAState(StateValue::New);
   if (Attrs.hasFnAttr("aarch64_in_zt0"))
     Bitmask |= encodeZT0State(StateValue::In);
   if (Attrs.hasFnAttr("aarch64_out_zt0"))
diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
index 27b7075a0944f..4c7c1c9b07953 100644
--- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
+++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
@@ -41,10 +41,9 @@ class SMEAttrs {
     SM_Enabled = 1 << 0,      // aarch64_pstate_sm_enabled
     SM_Compatible = 1 << 1,   // aarch64_pstate_sm_compatible
     SM_Body = 1 << 2,         // aarch64_pstate_sm_body
-    ZA_Shared = 1 << 3,       // aarch64_pstate_sm_shared
-    ZA_New = 1 << 4,          // aarch64_pstate_sm_new
-    ZA_Preserved = 1 << 5,    // aarch64_pstate_sm_preserved
-    SME_ABI_Routine = 1 << 6, // Used for SME ABI routines to avoid lazy saves
+    SME_ABI_Routine = 1 << 3, // Used for SME ABI routines to avoid lazy saves
+    ZA_Shift = 4,
+    ZA_Mask = 0b111 << ZA_Shift,
     ZT0_Shift = 7,
     ZT0_Mask = 0b111 << ZT0_Shift
   };
@@ -77,13 +76,29 @@ class SMEAttrs {
   /// streaming mode.
   bool requiresSMChange(const SMEAttrs &Callee) const;
 
-  // Interfaces to query PSTATE.ZA
-  bool hasNewZABody() const { return Bitmask & ZA_New; }
-  bool sharesZA() const { return Bitmask & ZA_Shared; }
+  // Interfaces to query ZA
+  static StateValue decodeZAState(unsigned Bitmask) {
+    return static_cast<StateValue>((Bitmask & ZA_Mask) >> ZA_Shift);
+  }
+  static unsigned encodeZAState(StateValue S) {
+    return static_cast<unsigned>(S) << ZA_Shift;
+  }
+
+  bool isNewZA() const { return decodeZAState(Bitmask) == StateValue::New; }
+  bool isInZA() const { return decodeZAState(Bitmask) == StateValue::In; }
+  bool isOutZA() const { return decodeZAState(Bitmask) == StateValue::Out; }
+  bool isInOutZA() const { return decodeZAState(Bitmask) == StateValue::InOut; }
+  bool isPreservesZA() const {
+    return decodeZAState(Bitmask) == StateValue::Preserved;
+  }
+  bool sharesZA() const {
+    StateValue State = decodeZAState(Bitmask);
+    return State == StateValue::In || State == StateValue::Out ||
+           State == StateValue::InOut || State == StateValue::Preserved;
+  }
   bool hasSharedZAInterface() const { return sharesZA() || sharesZT0(); }
   bool hasPrivateZAInterface() const { return !hasSharedZAInterface(); }
-  bool preservesZA() const { return Bitmask & ZA_Preserved; }
-  bool hasZAState() const { return hasNewZABody() || sharesZA(); }
+  bool hasZAState() const { return isNewZA() || sharesZA(); }
   bool requiresLazySave(const SMEAttrs &Callee) const {
     return hasZAState() && Callee.hasPrivateZAInterface() &&
            !(Callee.Bitmask & SME_ABI_Routine);
diff --git a/llvm/test/CodeGen/AArch64/sme-disable-gisel-fisel.ll b/llvm/test/CodeGen/AArch64/sme-disable-gisel-fisel.ll
index 381091b453943..2a78012045ff4 100644
--- a/llvm/test/CodeGen/AArch64/sme-disable-gisel-fisel.ll
+++ b/llvm/test/CodeGen/AArch64/sme-disable-gisel-fisel.ll
@@ -209,9 +209,9 @@ define void @normal_call_to_streaming_callee_ptr(ptr %p) nounwind noinline optno
 ; Check ZA state
 ;
 
-declare double @za_shared_callee(double) "aarch64_pstate_za_shared"
+declare double @za_shared_callee(double) "aarch64_inout_za"
 
-define double  @za_new_caller_to_za_shared_callee(double %x) nounwind noinline optnone "aarch64_pstate_za_new"{
+define double  @za_new_caller_to_za_shared_callee(double %x) nounwind noinline optnone "aarch64_new_za"{
 ; CHECK-COMMON-LABEL: za_new_caller_to_za_shared_callee:
 ; CHECK-COMMON:       // %bb.0: // %prelude
 ; CHECK-COMMON-NEXT:    stp x29, x30, [sp, #-16]! // 16-byte Folded Spill
@@ -248,7 +248,7 @@ entry:
   ret double %add;
 }
 
-define double  @za_shared_caller_to_za_none_callee(double %x) nounwind noinline optnone "aarch64_pstate_za_shared"{
+define double  @za_shared_caller_to_za_none_callee(double %x) nounwind noinline optnone "aarch64_inout_za"{
 ; CHECK-COMMON-LABEL: za_shared_caller_to_za_none_callee:
 ; CHECK-COMMON:       // %bb.0: // %entry
 ; CHECK-COMMON-NEXT:    stp x29, x30, [sp, #-16]! // 16-byte Folded Spill
@@ -288,7 +288,7 @@ entry:
 }
 
 ; Ensure we set up and restore the lazy save correctly for instructions which are lowered to lib calls.
-define fp128 @f128_call_za(fp128 %a, fp128 %b) "aarch64_pstate_za_shared" nounwind {
+define fp128 @f128_call_za(fp128 %a, fp128 %b) "aarch64_inout_za" nounwind {
 ; CHECK-COMMON-LABEL: f128_call_za:
 ; CHECK-COMMON:       // %bb.0:
 ; CHECK-COMMON-NEXT:    stp x29, x30, [sp, #-16]! // 16-byte Folded Spill
@@ -350,7 +350,7 @@ define fp128 @f128_call_sm(fp128 %a, fp128 %b) "aarch64_pstate_sm_enabled" nounw
 }
 
 ; As above this should use Selection DAG to make sure the libcall call is lowered correctly.
-define double @frem_call_za(double %a, double %b) "aarch64_pstate_za_shared" nounwind {
+define double @frem_call_za(double %a, double %b) "aarch64_inout_za" nounwind {
 ; CHECK-COMMON-LABEL: frem_call_za:
 ; CHECK-COMMON:       // %bb.0:
 ; CHECK-COMMON-NEXT:    stp x29, x30, [sp, #-16]! // 16-byte Folded Spill
diff --git a/llvm/test/CodeGen/AArch64/sme-lazy-save-call-remarks.ll b/llvm/test/CodeGen/AArch64/sme-lazy-save-call-remarks.ll
index d999311301f94..65e50842d5d78 100644
--- a/llvm/test/CodeGen/AArch64/sme-lazy-save-call-remarks.ll
+++ b/llvm/test/CodeGen/AArch64/sme-lazy-save-call-remarks.ll
@@ -4,13 +4,13 @@
 declare void @private_za_callee()
 declare float @llvm.cos.f32(float)
 
-define void @test_lazy_save_1_callee() nounwind "aarch64_pstate_za_shared" {
+define void @test_lazy_save_1_callee() nounwind "aarch64_inout_za" {
 ; CHECK: remark: <unknown>:0:0: call from 'test_lazy_save_1_callee' to 'private_za_callee' sets up a lazy save for ZA
   call void @private_za_callee()
   ret void
 }
 
-define void @test_lazy_save_2_callees() nounwind "aarch64_pstate_za_shared" {
+define void @test_lazy_save_2_callees() nounwind "aarch64_inout_za" {
 ; CHECK: remark: <unknown>:0:0: call from 'test_lazy_save_2_callees' to 'private_za_callee' sets up a lazy save for ZA
   call void @private_za_callee()
 ; CHECK: remark: <unknown>:0:0: call from 'test_lazy_save_2_callees' to 'private_za_callee' sets up a lazy save for ZA
@@ -18,7 +18,7 @@ define void @test_lazy_save_2_callees() nounwind "aarch64_pstate_za_shared" {
   ret void
 }
 
-define float @test_lazy_save_expanded_intrinsic(float %a) nounwind "aarch64_pstate_za_shared" {
+define float @test_lazy_save_expanded_intrinsic(float %a) nounwind "aarch64_inout_za" {
 ; CHECK: remark: <unknown>:0:0: call from 'test_lazy_save_expanded_intrinsic' to 'cosf' sets up a lazy save for ZA
   %res = call float @llvm.cos.f32(float %a)
   ret float %res
diff --git a/llvm/test/CodeGen/AArch64/sme-lazy-save-call.ll b/llvm/test/CodeGen/AArch64/sme-lazy-save-call.ll
index 9625e139bd0bc..9d635f0b88f19 100644
--- a/llvm/test/CodeGen/AArch64/sme-lazy-save-call.ll
+++ b/llvm/test/CodeGen/AArch64/sme-lazy-save-call.ll
@@ -5,7 +5,7 @@ declare void @private_za_callee()
 declare float @llvm.cos.f32(float)
 
 ; Test lazy-save mechanism for a single callee.
-define void @test_lazy_save_1_callee() nounwind "aarch64_pstate_za_shared" {
+define void @test_lazy_save_1_callee() nounwind "aarch64_inout_za" {
 ; CHECK-LABEL: test_lazy_save_1_callee:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    stp x29, x30, [sp, #-16]! // 16-byte Folded Spill
@@ -38,7 +38,7 @@ define void @test_lazy_save_1_callee() nounwind "aarch64_pstate_za_shared" {
 }
 
 ; Test lazy-save mechanism for multiple callees.
-define void @test_lazy_save_2_callees() nounwind "aarch64_pstate_za_shared" {
+define void @test_lazy_save_2_callees() nounwind "aarch64_inout_za" {
 ; CHECK-LABEL: test_lazy_save_2_callees:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    stp x29, x30, [sp, #-32]! // 16-byte Folded Spill
@@ -85,7 +85,7 @@ define void @test_lazy_save_2_callees() nounwind "aarch64_pstate_za_shared" {
 }
 
 ; Test a call of an intrinsic that gets expanded to a library call.
-define float @test_lazy_save_expanded_intrinsic(float %a) nounwind "aarch64_pstate_za_shared" {
+define float @test_lazy_save_expanded_intrinsic(float %a) nounwind "aarch64_inout_za" {
 ; CHECK-LABEL: test_lazy_save_expanded_intrinsic:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    stp x29, x30, [sp, #-16]! // 16-byte Folded Spill
@@ -118,7 +118,7 @@ define float @test_lazy_save_expanded_intrinsic(float %a) nounwind "aarch64_psta
 }
 
 ; Test a combination of streaming-compatible -> normal call with lazy-save.
-define void @test_lazy_save_and_conditional_smstart() nounwind "aarch64_pstate_za_shared" "aarch64_pstate_sm_compatible" {
+define void @test_lazy_save_and_conditional_smstart() nounwind "aarch64_inout_za" "aarch64_pstate_sm_compatible" {
 ; CHECK-LABEL: test_lazy_save_and_conditional_smstart:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    stp d15, d14, [sp, #-96]! // 16-byte Folded Spill
diff --git a/llvm/test/CodeGen/AArch64/sme-new-za-function.ll b/llvm/test/CodeGen/AArch64/sme-new-za-function.ll
index 0cee26dbb349e..04d26902c536a 100644
--- a/llvm/test/CodeGen/AArch64/sme-new-za-function.ll
+++ b/llvm/test/CodeGen/AArch64/sme-new-za-function.ll
@@ -1,9 +1,9 @@
 ; RUN: opt -S -mtriple=aarch64-linux-gnu -aarch64-sme-abi %s | FileCheck %s
 ; RUN: opt -S -mtriple=aarch64-linux-gnu -aarch64-sme-abi -aarch64-sme-abi %s | FileCheck %s
 
-declare void @shared_za_callee() "aarch64_pstate_za_shared"
+declare void @shared_za_callee() "aarch64_inout_za"
 
-define void @private_za() "aarch64_pstate_za_new" {
+define void @private_za() "aarch64_new_za" {
 ; CHECK-LABEL: @private_za(
 ; CHECK-NEXT:  prelude:
 ; CHECK-NEXT:    [[TPIDR2:%.*]] = call i64 @llvm.aarch64.sme.get.tpidr2()
@@ -24,7 +24,7 @@ define void @private_za() "aarch64_pstate_za_new" {
   ret void
 }
 
-define i32 @private_za_multiple_exit(i32 %a, i32 %b, i64 %cond) "aarch64_pstate_za_new" {
+define i32 @private_za_multiple_exit(i32 %a, i32 %b, i64 %cond) "aarch64_new_za" {
 ; CHECK-LABEL: @private_za_multiple_exit(
 ; CHECK-NEXT:  prelude:
 ; CHECK-NEXT:    [[TPIDR2:%.*]] = call i64 @llvm.aarch64.sme.get.tpidr2()
@@ -62,4 +62,4 @@ if.end:
 }
 
 ; CHECK: declare void @__arm_tpidr2_save() #[[ATTR:[0-9]+]]
-; CHECK: attributes #[[ATTR]] = { "aarch64_pstate_sm_compatible" "aarch64_pstate_za_preserved" }
+; CHECK: attributes #[[ATTR]] = { "aarch64_pstate_sm_compatible" }
diff --git a/llvm/test/CodeGen/AArch64/sme-shared-za-interface.ll b/llvm/test/CodeGen/AArch64/sme-shared-za-interface.ll
index a2e20013d94ff..cd7460b177c4b 100644
--- a/llvm/test/CodeGen/AArch64/sme-shared-za-interface.ll
+++ b/llvm/test/CodeGen/AArch64/sme-shared-za-interface.ll
@@ -4,7 +4,7 @@
 declare void @private_za_callee()
 
 ; Ensure that we don't use tail call optimization when a lazy-save is required.
-define void @disable_tailcallopt() "aarch64_pstate_za_shared" nounwind {
+define void @disable_tailcallopt() "aarch64_inout_za" nounwind {
 ; CHECK-LABEL: disable_tailcallopt:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    stp x29, x30, [sp, #-16]! // 16-byte Folded Spill
@@ -37,7 +37,7 @@ define void @disable_tailcallopt() "aarch64_pstate_za_shared" nounwind {
 }
 
 ; Ensure we set up and restore the lazy save correctly for instructions which are lowered to lib calls
-define fp128 @f128_call_za(fp128 %a, fp128 %b) "aarch64_pstate_za_shared" nounwind {
+define fp128 @f128_call_za(fp128 %a, fp128 %b) "aarch64_inout_za" nounwind {
 ; CHECK-LABEL: f128_call_za:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    stp x29, x30, [sp, #-16]! // 16-byte Folded Spill
diff --git a/llvm/test/CodeGen/AArch64/sme-zt0-state.ll b/llvm/test/CodeGen/AArch64/sme-zt0-state.ll
index 18d1e40bf4d0f..7f40b5e7e1344 100644
--- a/llvm/test/CodeGen/AArch64/sme-zt0-state.ll
+++ b/llvm/test/CodeGen/AArch64/sme-zt0-state.ll
@@ -30,7 +30,7 @@ define void @zt0_in_caller_no_state_callee() "aarch64_in_zt0" nounwind {
 ; Expect spill & fill of ZT0 around call
 ; Expect setup and restore lazy-save around call
 ; Expect smstart za after call
-define void @za_zt0_shared_caller_no_state_callee() "aarch64_pstate_za_shared" "aarch64_in_zt0" nounwind {
+define void @za_zt0_shared_caller_no_state_callee() "aarch64_inout_za" "aarch64_in_zt0" nounwind {
 ; CHECK-LABEL: za_zt0_shared_caller_no_state_callee:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    stp x29, x30, [sp, #-32]! // 16-byte Folded Spill
@@ -84,7 +84,7 @@ define void @zt0_shared_caller_zt0_shared_callee() "aarch64_in_zt0" nounwind {
 }
 
 ; Expect spill & fill of ZT0 around call
-define void @za_zt0_shared_caller_za_shared_callee() "aarch64_pstate_za_shared" "aarch64_in_zt0" nounwind {
+define void @za_zt0_shared_caller_za_shared_callee() "aarch64_inout_za" "aarch64_in_zt0" nounwind {
 ; CHECK-LABEL: za_zt0_shared_caller_za_shared_callee:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    stp x29, x30, [sp, #-32]! // 16-byte Folded Spill
@@ -106,12 +106,12 @@ define void @za_zt0_shared_caller_za_shared_callee() "aarch64_pstate_za_shared"
 ; CHECK-NEXT:    ldr x19, [sp, #16] // 8-byte Folded Reload
 ; CHECK-NEXT:    ldp x29, x30, [sp], #32 // 16-byte Folded Reload
 ; CHECK-NEXT:    ret
-  call void @callee() "aarch64_pstate_za_shared";
+  call void @callee() "aarch64_inout_za";
   ret void;
 }
 
 ; Caller and callee have shared ZA & ZT0
-define void @za_zt0_shared_caller_za_zt0_shared_callee() "aarch64_pstate_za_shared" "aarch64_in_zt0" nounwind {
+define void @za_zt0_shared_caller_za_zt0_shared_callee() "aarch64_inout_za" "aarch64_in_zt0" nounwind {
 ; CHECK-LABEL: za_zt0_shared_caller_za_zt0_shared_callee:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    stp x29, x30, [sp, #-16]! // 16-byte Folded Spill
@@ -128,7 +128,7 @@ define void @za_zt0_shared_caller_za_zt0_shared_callee() "aarch64_pstate_za_shar
 ; CHECK-NEXT:    mov sp, x29
 ; CHECK-NEXT:    ldp x29, x30, [sp], #16 // 16-byte Folded Reload
 ; CHECK-NEXT:    ret
-  call void @callee() "aarch64_pstate_za_shared" "aarch64_in_zt0";
+  call void @callee() "aarch64_inout_za" "aarch64_in_zt0";
   ret void;
 }
 
@@ -189,7 +189,7 @@ define void @zt0_new_caller() "aarch64_new_zt0" nounwind {
 ; Expect commit of lazy-save if ZA is dormant
 ; Expect smstart ZA, clear ZA & clear ZT0
 ; Before return, expect smstop ZA
-define void @new_za_zt0_caller() "aarch64_pstate_za_new" "aarch64_new_zt0" nounwind {
+define void @new_za_zt0_caller() "aarch64_new_za" "aarch64_new_zt0" nounwind {
 ; CHECK-LABEL: new_za_zt0_caller:
 ; CHECK:       // %bb.0: // %prelude
 ; CHECK-NEXT:    stp x29, x30, [sp, #-16]! // 16-byte Folded Spill
@@ -219,12 +219,12 @@ define void @new_za_zt0_caller() "aarch64_pstate_za_new" "aarch64_new_zt0" nounw
 ; CHECK-NEXT:    mov sp, x29
 ; CHECK-NEXT:    ldp x29, x30, [sp], #16 // 16-byte Folded Reload
 ; CHECK-NEXT:    ret
-  call void @callee() "aarch64_pstate_za_shared" "aarch64_in_zt0";
+  call void @callee() "aarch64_inout_za" "aarch64_in_zt0";
   ret void;
 }
 
 ; Expect clear ZA on entry
-define void @new_za_shared_zt0_caller() "aarch64_pstate_za_new" "aarch64_in_zt0" nounwind {
+define void @new_za_shared_zt0_caller() "aarch64_new_za" "aarch64_in_zt0" nounwind {
 ; CHECK-LABEL: new_za_shared_zt0_caller:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    stp x29, x30, [sp, #-16]! // 16-byte Folded Spill
@@ -242,12 +242,12 @@ define void @new_za_shared_zt0_caller() "aarch64_pstate_za_new" "aarch64_in_zt0"
 ; CHECK-NEXT:    mov sp, x29
 ; CHECK-NEXT:    ldp x29, x30, [sp], #16 // 16-byte Folded Reload
 ; CHECK-NEXT:    ret
-  call void @callee() "aarch64_pstate_za_shared" "aarch64_in_zt0";
+  call void @callee() "aarch64_inout_za" "aarch64_in_zt0";
   ret void;
 }
 
 ; Expect clear ZT0 on entry
-define void @shared_za_new_zt0() "aarch64_pstate_za_shared" "aarch64_new_zt0" nounwind {
+define void @shared_za_new_zt0() "aarch64_inout_za" "aarch64_new_zt0" nounwind {
 ; CHECK-LABEL: shared_za_new_zt0:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    stp x29, x30, [sp, #-16]! // 16-byte Folded Spill
@@ -265,6 +265,6 @@ define void @shared_za_new_zt0() "aarch64_pstate_za_shared" "aarch64_new_zt0" no
 ; CHECK-NEXT:    mov sp, x29
 ; CHECK-NEXT:    ldp x29, x30, [sp], #16 // 16-byte Folded Reload
 ; CHECK-NEXT:    ret
-  call void @callee() "aarch64_pstate_za_shared" "aarch64_in_zt0";
+  call void @callee() "aarch64_inout_za" "aarch64_in_zt0";
   ret void;
 }
diff --git a/llvm/test/Transforms/Inline/AArch64/sme-pstateza-attrs.ll b/llvm/test/Transforms/Inline/AArch64/sme-pstateza-attrs.ll
index 7fca45b1e43f6..816492768cc0f 100644
--- a/llvm/test/Transforms/Inline/AArch64/sme-pstateza-attrs.ll
+++ b/llvm/test/Transforms/Inline/AArch64/sme-pstateza-attrs.ll
@@ -22,7 +22,7 @@ entry:
   ret void
 }
 
-define void @shared_za_callee() "aarch64_pstate_za_shared" {
+define void @shared_za_callee() "aarch64_inout_za" {
 ; CHECK-LABEL: define void @shared_za_callee
 ; CHECK-SAME: () #[[ATTR1:[0-9]+]] {
 ; CHECK-NEXT:  entry:
@@ -34,7 +34,7 @@ entry:
   ret void
 }
 
-define void @new_za_callee() "aarch64_pstate_za_new" {
+define void @new_za_callee() "aarch64_new_za" {
 ; CHECK-LABEL: define void @new_za_callee
 ; CHECK-SAME: () #[[ATTR2:[0-9]+]] {
 ; CHECK-NEXT:    call void @inlined_body()
@@ -84,7 +84,7 @@ entry:
 ; [x] Z -> N
 ; [ ] Z -> S
 ; [ ] Z -> Z
-define void @new_za_caller_nonza_callee_inline() "aarch64_pstate_za_new" {
+define void @new_za_caller_nonza_callee_inline() "aarch64_new_za" {
 ; CHECK-LABEL: define void @new_za_caller_nonza_callee_inline
 ; CHECK-SAME: () #[[ATTR2]] {
 ; CHECK-NEXT:  entry:
@@ -99,7 +99,7 @@ entry:
 ; [ ] Z -> N
 ; [x] Z -> S
 ; [ ] Z -> Z
-define void @new_za_caller_shared_za_callee_inline() "aarch64_pstate_za_new" {
+define void @new_za_caller_shared_za_callee_inline() "aarch64_new_za" {
 ; CHECK-LABEL: define void @new_za_caller_shared_za_callee_inline
 ; CHECK-SAME: () #[[ATTR2]] {
 ; CHECK-NEXT:  entry:
@@ -114,7 +114,7 @@ entry:
 ; [ ] Z -> N
 ; [ ] Z -> S
 ; [x] Z -> Z
-define void @new_za_caller_new_za_callee_dont_inline() "aarch64_pstate_za_new" {
+define void @new_za_caller_new_za_callee_dont_inline() "aarch64_new_za" {
 ; CHECK-LABEL: define void @new_za_caller_new_za_callee_dont_inline
 ; CHECK-SAME: () #[[ATTR2]] {
 ; CHECK-NEXT:  entry:
@@ -129,7 +129,7 @@ entry:
 ; [x] Z -> N
 ; [ ] Z -> S
 ; [ ] Z -> Z
-define void @shared_za_caller_nonza_callee_inline() "aarch64_pstate_za_shared" {
+define void @shared_za_caller_nonza_callee_inline() "aarch64_inout_za" {
 ; CHECK-LABEL: define void @shared_za_caller_nonza_callee_inline
 ; CHECK-SAME: () #[[ATTR1]] {
 ; CHECK-NEXT:  entry:
@@ -144,7 +144,7 @@ entry:
 ; [ ] S -> N
 ; [x] S -> Z
 ; [ ] S -> S
-define void @shared_za_caller_new_za_callee_dont_inline() "aarch64_pstate_za_shared" {
+define void @shared_za_caller_new_za_callee_dont_inline() "aarch64_inout_za" {
 ; CHECK-LABEL: define void @shared_za_caller_new_za_callee_dont_inline
 ; CHECK-SAME: () #[[ATTR1]] {
 ; CHECK-NEXT:  entry:
@@ -159,7 +159,7 @@ entry:
 ; [ ] S -> N
 ; [ ] S -> Z
 ; [x] S -> S
-define void @shared_za_caller_shared_za_callee_inline() "aarch64_pstate_za_shared" {
+define void @shared_za_caller_shared_za_callee_inline() "aarch64_inout_za" {
 ; CHECK-LABEL: define void @shared_za_caller_shared_za_callee_inline
 ; CHECK-SAME: () #[[ATTR1]] {
 ; CHECK-NEXT:  entry:
@@ -181,7 +181,7 @@ define void @private_za_callee_call_za_disable() {
   ret void
 }
 
-define void @shared_za_caller_private_za_callee_call_za_disable() "aarch64_pstate_za_shared" {
+define void @shared_za_caller_private_za_callee_call_za_disable() "aarch64_inout_za" {
 ; CHECK-LABEL: define void @shared_za_caller_private_za_callee_call_za_disable
 ; CHECK-SAME: () #[[ATTR1]] {
 ; CHECK-NEXT:    call void @private_za_callee_call_za_disable()
@@ -201,7 +201,7 @@ define void @private_za_callee_call_tpidr2_save() {
   ret void
 }
 
-define void @shared_za_caller_private_za_callee_call_tpidr2_save_dont_inline() "aarch64_pstate_za_shared" {
+define void @shared_za_caller_private_za_callee_call_tpidr2_save_dont_inline() "aarch64_inout_za" {
 ; CHECK-LABEL: define void @shared_za_caller_private_za_callee_call_tpidr2_save_dont_inline
 ; CHECK-SAME: () #[[ATTR1]] {
 ; CHECK-NEXT:    call void @private_za_callee_call_tpidr2_save()
@@ -221,7 +221,7 @@ define void @private_za_callee_call_tpidr2_restore(ptr %ptr) {
   ret void
 }
 
-define void @shared_za_caller_private_za_callee_call_tpidr2_restore_dont_inline(ptr %ptr) "aarch64_pstate_za_shared" {
+define void @shared_za_caller_private_za_callee_call_tpidr2_restore_dont_inline(ptr %ptr) "aarch64_inout_za" {
 ; CHECK-LABEL: define void @shared_za_caller_private_za_callee_call_tpidr2_restore_dont_inline
 ; CHECK-SAME: (ptr [[PTR:%.*]]) #[[ATTR1]] {
 ; CHECK-NEXT:    call void @private_za_callee_call_tpidr2_restore(ptr [[PTR]])
diff --git a/llvm/test/Verifier/sme-attributes.ll b/llvm/test/Verifier/sme-attributes.ll
index 2b949951dc1bb..3d01613ebf2fe 100644
--- a/llvm/test/Verifier/sme-attributes.ll
+++ b/llvm/test/Verifier/sme-attributes.ll
@@ -3,11 +3,35 @@
 declare void @sm_attrs() "aarch64_pstate_sm_enabled" "aarch64_pstate_sm_compatible";
 ; CHECK: Attributes 'aarch64_pstate_sm_enabled and aarch64_pstate_sm_compatible' are incompatible!
 
-declare void @za_preserved() "aarch64_pstate_za_new" "aarch64_pstate_za_preserved";
-; CHECK: Attributes 'aarch64_pstate_za_new and aarch64_pstate_za_preserved' are incompatible!
+declare void @za_new_preserved() "aarch64_new_za" "aarch64_preserves_za";
+; CHECK: Attributes 'aarch64_new_za', 'aarch64_in_za', 'aarch64_out_za', 'aarch64_inout_za' and 'aarch64_preserves_za' are mutually exclusive
 
-declare void @za_shared() "aarch64_pstate_za_new" "aarch64_pstate_za_shared";
-; CHECK: Attributes 'aarch64_pstate_za_new and aarch64_pstate_za_shared' are incompatible!
+declare void @za_new_in() "aarch64_new_za" "aarch64_in_za";
+; CHECK: Attributes 'aarch64_new_za', 'aarch64_in_za', 'aarch64_out_za', 'aarch64_inout_za' and 'aarch64_preserves_za' are mutually exclusive
+
+declare void @za_new_inout() "aarch64_new_za" "aarch64_inout_za";
+; CHECK: Attributes 'aarch64_new_za', 'aarch64_in_za', 'aarch64_out_za', 'aarch64_inout_za' and 'aarch64_preserves_za' are mutually exclusive
+
+declare void @za_new_out() "aarch64_new_za" "aarch64_out_za";
+; CHECK: Attributes 'aarch64_new_za', 'aarch64_in_za', 'aarch64_out_za', 'aarch64_inout_za' and 'aarch64_preserves_za' are mutually exclusive
+
+declare void @za_preserved_in() "aarch64_preserves_za" "aarch64_in_za";
+; CHECK: Attributes 'aarch64_new_za', 'aarch64_in_za', 'aarch64_out_za', 'aarch64_inout_za' and 'aarch64_preserves_za' are mutually exclusive
+
+declare void @za_preserved_inout() "aarch64_preserves_za" "aarch64_inout_za";
+; CHECK: Attributes 'aarch64_new_za', 'aarch64_in_za', 'aarch64_out_za', 'aarch64_inout_za' and 'aarch64_preserves_za' are mutually exclusive
+
+declare void @za_preserved_out() "aarch64_preserves_za" "aarch64_out_za";
+; CHECK: Attributes 'aarch64_new_za', 'aarch64_in_za', 'aarch64_out_za', 'aarch64_inout_za' and 'aarch64_preserves_za' are mutually exclusive
+
+declare void @za_in_inout() "aarch64_in_za" "aarch64_inout_za";
+; CHECK: Attributes 'aarch64_new_za', 'aarch64_in_za', 'aarch64_out_za', 'aarch64_inout_za' and 'aarch64_preserves_za' are mutually exclusive
+
+declare void @za_in_out() "aarch64_in_za" "aarch64_out_za";
+; CHECK: Attributes 'aarch64_new_za', 'aarch64_in_za', 'aarch64_out_za', 'aarch64_inout_za' and 'aarch64_preserves_za' are mutually exclusive
+
+declare void @za_inout_out() "aarch64_inout_za" "aarch64_out_za";
+; CHECK: Attributes 'aarch64_new_za', 'aarch64_in_za', 'aarch64_out_za', 'aarch64_inout_za' and 'aarch64_preserves_za' are mutually exclusive
 
 declare void @zt0_new_preserved() "aarch64_new_zt0" "aarch64_preserves_zt0";
 ; CHECK: Attributes 'aarch64_new_zt0', 'aarch64_in_zt0', 'aarch64_out_zt0', 'aarch64_inout_zt0' and 'aarch64_preserves_zt0' are mutually exclusive
diff --git a/llvm/unittests/Target/AArch64/SMEAttributesTest.cpp b/llvm/unittests/Target/AArch64/SMEAttributesTest.cpp
index 044de72449ec8..2c1c92dfa602a 100644
--- a/llvm/unittests/Target/AArch64/SMEAttributesTest.cpp
+++ b/llvm/unittests/Target/AArch64/SMEAttributesTest.cpp
@@ -38,21 +38,22 @@ TEST(SMEAttributes, Constructors) {
               ->getFunction("foo"))
           .hasStreamingCompatibleInterface());
 
-  ASSERT_TRUE(SA(*parseIR("declare void @foo() \"aarch64_pstate_za_shared\"")
-                      ->getFunction("foo"))
-                  .sharesZA());
 
-  ASSERT_TRUE(SA(*parseIR("declare void @foo() \"aarch64_pstate_za_shared\"")
+  ASSERT_TRUE(
+      SA(*parseIR("declare void @foo() \"aarch64_in_za\"")->getFunction("foo"))
+          .isInZA());
+  ASSERT_TRUE(SA(*parseIR("declare void @foo() \"aarch64_out_za\"")
                       ->getFunction("foo"))
-                  .hasSharedZAInterface());
-
-  ASSERT_TRUE(SA(*parseIR("declare void @foo() \"aarch64_pstate_za_new\"")
+                  .isOutZA());
+  ASSERT_TRUE(SA(*parseIR("declare void @foo() \"aarch64_inout_za\"")
                       ->getFunction("foo"))
-                  .hasNewZABody());
-
-  ASSERT_TRUE(SA(*parseIR("declare void @foo() \"aarch64_pstate_za_preserved\"")
+                  .isInOutZA());
+  ASSERT_TRUE(SA(*parseIR("declare void @foo() \"aarch64_preserves_za\"")
+                      ->getFunction("foo"))
+                  .isPreservesZA());
+  ASSERT_TRUE(SA(*parseIR("declare void @foo() \"aarch64_new_za\"")
                       ->getFunction("foo"))
-                  .preservesZA());
+                  .isNewZA());
 
   ASSERT_TRUE(
       SA(*parseIR("declare void @foo() \"aarch64_in_zt0\"")->getFunction("foo"))
@@ -73,10 +74,6 @@ TEST(SMEAttributes, Constructors) {
   // Invalid combinations.
   EXPECT_DEBUG_DEATH(SA(SA::SM_Enabled | SA::SM_Compatible),
                      "SM_Enabled and SM_Compatible are mutually exclusive");
-  EXPECT_DEBUG_DEATH(SA(SA::ZA_New | SA::ZA_Shared),
-                     "ZA_New and ZA_Shared are mutually exclusive");
-  EXPECT_DEBUG_DEATH(SA(SA::ZA_New | SA::ZA_Preserved),
-                     "ZA_New and ZA_Preserved are mutually exclusive");
 
   // Test that the set() methods equally check validity.
   EXPECT_DEBUG_DEATH(SA(SA::SM_Enabled).set(SA::SM_Compatible),
@@ -99,29 +96,69 @@ TEST(SMEAttributes, Basics) {
   ASSERT_TRUE(SA(SA::SM_Compatible | SA::SM_Body).hasStreamingBody());
   ASSERT_FALSE(SA(SA::SM_Compatible | SA::SM_Body).hasNonStreamingInterface());
 
-  // Test PSTATE.ZA interfaces.
-  ASSERT_FALSE(SA(SA::ZA_Shared).hasPrivateZAInterface());
-  ASSERT_TRUE(SA(SA::ZA_Shared).hasSharedZAInterface());
-  ASSERT_TRUE(SA(SA::ZA_Shared).sharesZA());
-  ASSERT_TRUE(SA(SA::ZA_Shared).hasZAState());
-  ASSERT_FALSE(SA(SA::ZA_Shared).preservesZA());
-  ASSERT_TRUE(SA(SA::ZA_Shared | SA::ZA_Preserved).preservesZA());
-  ASSERT_FALSE(SA(SA::ZA_Shared).sharesZT0());
-  ASSERT_FALSE(SA(SA::ZA_Shared).hasZT0State());
-
-  ASSERT_TRUE(SA(SA::ZA_New).hasPrivateZAInterface());
-  ASSERT_FALSE(SA(SA::ZA_New).hasSharedZAInterface());
-  ASSERT_TRUE(SA(SA::ZA_New).hasNewZABody());
-  ASSERT_TRUE(SA(SA::ZA_New).hasZAState());
-  ASSERT_FALSE(SA(SA::ZA_New).preservesZA());
-  ASSERT_FALSE(SA(SA::ZA_New).sharesZT0());
-  ASSERT_FALSE(SA(SA::ZA_New).hasZT0State());
-
-  ASSERT_TRUE(SA(SA::Normal).hasPrivateZAInterface());
-  ASSERT_FALSE(SA(SA::Normal).hasSharedZAInterface());
-  ASSERT_FALSE(SA(SA::Normal).hasNewZABody());
+  // Test ZA State interfaces
+  SA ZA_In = SA(SA::encodeZAState(SA::StateValue::In));
+  ASSERT_TRUE(ZA_In.isInZA());
+  ASSERT_FALSE(ZA_In.isOutZA());
+  ASSERT_FALSE(ZA_In.isInOutZA());
+  ASSERT_FALSE(ZA_In.isPreservesZA());
+  ASSERT_FALSE(ZA_In.isNewZA());
+  ASSERT_TRUE(ZA_In.sharesZA());
+  ASSERT_TRUE(ZA_In.hasZAState());
+  ASSERT_TRUE(ZA_In.hasSharedZAInterface());
+  ASSERT_FALSE(ZA_In.hasPrivateZAInterface());
+
+  SA ZA_Out = SA(SA::encodeZAState(SA::StateValue::Out));
+  ASSERT_TRUE(ZA_Out.isOutZA());
+  ASSERT_FALSE(ZA_Out.isInZA());
+  ASSERT_FALSE(ZA_Out.isInOutZA());
+  ASSERT_FALSE(ZA_Out.isPreservesZA());
+  ASSERT_FALSE(ZA_Out.isNewZA());
+  ASSERT_TRUE(ZA_Out.sharesZA());
+  ASSERT_TRUE(ZA_Out.hasZAState());
+  ASSERT_TRUE(ZA_Out.hasSharedZAInterface());
+  ASSERT_FALSE(ZA_Out.hasPrivateZAInterface());
+
+  SA ZA_InOut = SA(SA::encodeZAState(SA::StateValue::InOut));
+  ASSERT_TRUE(ZA_InOut.isInOutZA());
+  ASSERT_FALSE(ZA_InOut.isInZA());
+  ASSERT_FALSE(ZA_InOut.isOutZA());
+  ASSERT_FALSE(ZA_InOut.isPreservesZA());
+  ASSERT_FALSE(ZA_InOut.isNewZA());
+  ASSERT_TRUE(ZA_InOut.sharesZA());
+  ASSERT_TRUE(ZA_InOut.hasZAState());
+  ASSERT_TRUE(ZA_InOut.hasSharedZAInterface());
+  ASSERT_FALSE(ZA_InOut.hasPrivateZAInterface());
+
+  SA ZA_Preserved = SA(SA::encodeZAState(SA::StateValue::Preserved));
+  ASSERT_TRUE(ZA_Preserved.isPreservesZA());
+  ASSERT_FALSE(ZA_Preserved.isInZA());
+  ASSERT_FALSE(ZA_Preserved.isOutZA());
+  ASSERT_FALSE(ZA_Preserved.isInOutZA());
+  ASSERT_FALSE(ZA_Preserved.isNewZA());
+  ASSERT_TRUE(ZA_Preserved.sharesZA());
+  ASSERT_TRUE(ZA_Preserved.hasZAState());
+  ASSERT_TRUE(ZA_Preserved.hasSharedZAInterface());
+  ASSERT_FALSE(ZA_Preserved.hasPrivateZAInterface());
+
+  SA ZA_New = SA(SA::encodeZAState(SA::StateValue::New));
+  ASSERT_TRUE(ZA_New.isNewZA());
+  ASSERT_FALSE(ZA_New.isInZA());
+  ASSERT_FALSE(ZA_New.isOutZA());
+  ASSERT_FALSE(ZA_New.isInOutZA());
+  ASSERT_FALSE(ZA_New.isPreservesZA());
+  ASSERT_FALSE(ZA_New.sharesZA());
+  ASSERT_TRUE(ZA_New.hasZAState());
+  ASSERT_FALSE(ZA_New.hasSharedZAInterface());
+  ASSERT_TRUE(ZA_New.hasPrivateZAInterface());
+
+  ASSERT_FALSE(SA(SA::Normal).isInZA());
+  ASSERT_FALSE(SA(SA::Normal).isOutZA());
+  ASSERT_FALSE(SA(SA::Normal).isInOutZA());
+  ASSERT_FALSE(SA(SA::Normal).isPreservesZA());
+  ASSERT_FALSE(SA(SA::Normal).isNewZA());
+  ASSERT_FALSE(SA(SA::Normal).sharesZA());
   ASSERT_FALSE(SA(SA::Normal).hasZAState());
-  ASSERT_FALSE(SA(SA::Normal).preservesZA());
 
   // Test ZT0 State interfaces
   SA ZT0_In = SA(SA::encodeZT0State(SA::StateValue::In));
@@ -245,9 +282,10 @@ TEST(SMEAttributes, Transitions) {
                    .requiresSMChange(SA(SA::SM_Compatible | SA::SM_Body)));
 
   SA Private_ZA = SA(SA::Normal);
-  SA ZA_Shared = SA(SA::ZA_Shared);
+  SA ZA_Shared = SA(SA::encodeZAState(SA::StateValue::In));
   SA ZT0_Shared = SA(SA::encodeZT0State(SA::StateValue::In));
-  SA ZA_ZT0_Shared = SA(SA::ZA_Shared | SA::encodeZT0State(SA::StateValue::In));
+  SA ZA_ZT0_Shared = SA(SA::encodeZT0State(SA::StateValue::In) |
+                        SA::encodeZT0State(SA::StateValue::In));
 
   // Shared ZA -> Private ZA Interface
   ASSERT_FALSE(ZA_Shared.requiresDisablingZABeforeCall(Private_ZA));
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
index 66027c5ba77bd..7959d291e8926 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
@@ -43,10 +43,16 @@ def ArmZaMode : I32EnumAttr<"ArmZaMode", "Armv9 ZA storage mode",
       I32EnumAttrCase<"Disabled", 0, "disabled">,
       // A function's ZA state is created on entry and destroyed on exit.
       I32EnumAttrCase<"NewZA", 1, "arm_new_za">,
-      // A function that preserves ZA state.
-      I32EnumAttrCase<"PreservesZA", 2, "arm_preserves_za">,
-      // A function that uses ZA state as input and/or output
-      I32EnumAttrCase<"SharedZA", 3, "arm_shared_za">,
+      // A function with a Shared-ZA interface that takes ZA as input.
+      I32EnumAttrCase<"InZA", 2, "arm_in_za">,
+      // A function with a Shared-ZA interface that returns ZA as output.
+      I32EnumAttrCase<"OutZA", 3, "arm_out_za">,
+      // A function with a Shared-ZA interface that takes ZA as input and
+      // returns ZA as output.
+      I32EnumAttrCase<"InOutZA", 4, "arm_inout_za">,
+      // A function with a Shared-ZA interface that does not read ZA and
+      // returns with ZA unchanged.
+      I32EnumAttrCase<"PreservesZA", 5, "arm_preserves_za">,
     ]>{
   let cppNamespace = "mlir::arm_sme";
   let genSpecializedAttr = 0;
@@ -92,14 +98,23 @@ def EnableArmStreaming
                             "new-za",
                             "The function has ZA state. The ZA state is "
                             "created on entry and destroyed on exit."),
+                 clEnumValN(mlir::arm_sme::ArmZaMode::InZA,
+                            "in-za",
+                            "The function uses ZA state. The ZA state may "
+                            "be used for input."),
+                 clEnumValN(mlir::arm_sme::ArmZaMode::OutZA,
+                            "out-za",
+                            "The function uses ZA state. The ZA state may "
+                            "be used for output."),
+                 clEnumValN(mlir::arm_sme::ArmZaMode::InOutZA,
+                            "inout-za",
+                            "The function uses ZA state. The ZA state may "
+                            "be used for input and/or output."),
                  clEnumValN(mlir::arm_sme::ArmZaMode::PreservesZA,
                             "preserves-za",
-                            "The function preserves ZA state. The ZA state is "
-                            "saved on entry and restored on exit."),
-                 clEnumValN(mlir::arm_sme::ArmZaMode::SharedZA,
-                            "shared-za",
-                            "The function uses ZA state. The ZA state may "
-                            "be used for input and/or output.")
+                            "The function shares ZA state. The ZA state may "
+                            "not be used for input and/or output and the "
+                            "function must return with ZA unchanged")
            )}]>,
     Option<"onlyIfRequiredByOps", "only-if-required-by-ops", "bool",
            /*default=*/"false",
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index ad67fba5a81cf..d9b130bdf18cb 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -1420,8 +1420,10 @@ def LLVM_LLVMFuncOp : LLVM_Op<"func", [
     OptionalAttr<UnitAttr>:$arm_locally_streaming,
     OptionalAttr<UnitAttr>:$arm_streaming_compatible,
     OptionalAttr<UnitAttr>:$arm_new_za,
+    OptionalAttr<UnitAttr>:$arm_in_za,
+    OptionalAttr<UnitAttr>:$arm_out_za,
+    OptionalAttr<UnitAttr>:$arm_inout_za,
     OptionalAttr<UnitAttr>:$arm_preserves_za,
-    OptionalAttr<UnitAttr>:$arm_shared_za,
     OptionalAttr<StrAttr>:$section,
     OptionalAttr<UnnamedAddr>:$unnamed_addr,
     OptionalAttr<I64Attr>:$alignment,
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index 5ca4a9fd68d65..97ccb2b29f3ae 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -1640,9 +1640,11 @@ static constexpr std::array ExplicitAttributes{
     StringLiteral("aarch64_pstate_sm_enabled"),
     StringLiteral("aarch64_pstate_sm_body"),
     StringLiteral("aarch64_pstate_sm_compatible"),
-    StringLiteral("aarch64_pstate_za_new"),
-    StringLiteral("aarch64_pstate_za_preserved"),
-    StringLiteral("aarch64_pstate_za_shared"),
+    StringLiteral("aarch64_new_za"),
+    StringLiteral("aarch64_preserves_za"),
+    StringLiteral("aarch64_in_za"),
+    StringLiteral("aarch64_out_za"),
+    StringLiteral("aarch64_inout_za"),
     StringLiteral("vscale_range"),
     StringLiteral("frame-pointer"),
     StringLiteral("target-features"),
@@ -1722,12 +1724,15 @@ void ModuleImport::processFunctionAttributes(llvm::Function *func,
   else if (func->hasFnAttribute("aarch64_pstate_sm_compatible"))
     funcOp.setArmStreamingCompatible(true);
 
-  if (func->hasFnAttribute("aarch64_pstate_za_new"))
+  if (func->hasFnAttribute("aarch64_new_za"))
     funcOp.setArmNewZa(true);
-  else if (func->hasFnAttribute("aarch64_pstate_za_shared"))
-    funcOp.setArmSharedZa(true);
-  // PreservedZA can be used with either NewZA or SharedZA.
-  if (func->hasFnAttribute("aarch64_pstate_za_preserved"))
+  else if (func->hasFnAttribute("aarch64_in_za"))
+    funcOp.setArmInZa(true);
+  else if (func->hasFnAttribute("aarch64_out_za"))
+    funcOp.setArmOutZa(true);
+  else if (func->hasFnAttribute("aarch64_inout_za"))
+    funcOp.setArmInoutZa(true);
+  else if (func->hasFnAttribute("aarch64_preserves_za"))
     funcOp.setArmPreservesZa(true);
 
   llvm::Attribute attr = func->getFnAttribute(llvm::Attribute::VScaleRange);
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 6364cacbd1924..d254925db7a0c 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -1198,11 +1198,15 @@ LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) {
     llvmFunc->addFnAttr("aarch64_pstate_sm_compatible");
 
   if (func.getArmNewZa())
-    llvmFunc->addFnAttr("aarch64_pstate_za_new");
-  else if (func.getArmSharedZa())
-    llvmFunc->addFnAttr("aarch64_pstate_za_shared");
+    llvmFunc->addFnAttr("aarch64_new_za");
+  else if (func.getArmInZa())
+    llvmFunc->addFnAttr("aarch64_in_za");
+  else if (func.getArmOutZa())
+    llvmFunc->addFnAttr("aarch64_out_za");
+  else if (func.getArmInoutZa())
+    llvmFunc->addFnAttr("aarch64_inout_za");
   if (func.getArmPreservesZa())
-    llvmFunc->addFnAttr("aarch64_pstate_za_preserved");
+    llvmFunc->addFnAttr("aarch64_preserves_za");
 
   if (auto targetCpu = func.getTargetCpu())
     llvmFunc->addFnAttr("target-cpu", *targetCpu);
diff --git a/mlir/test/Dialect/ArmSME/enable-arm-za.mlir b/mlir/test/Dialect/ArmSME/enable-arm-za.mlir
index 0aa00f75c3a56..a20203d7e5579 100644
--- a/mlir/test/Dialect/ArmSME/enable-arm-za.mlir
+++ b/mlir/test/Dialect/ArmSME/enable-arm-za.mlir
@@ -1,6 +1,8 @@
 // RUN: mlir-opt %s -enable-arm-streaming=za-mode=new-za -convert-arm-sme-to-llvm | FileCheck %s -check-prefix=ENABLE-ZA
 // RUN: mlir-opt %s -enable-arm-streaming -convert-arm-sme-to-llvm | FileCheck %s -check-prefix=DISABLE-ZA
-// RUN: mlir-opt %s -enable-arm-streaming=za-mode=shared-za -convert-arm-sme-to-llvm | FileCheck %s -check-prefix=SHARED-ZA
+// RUN: mlir-opt %s -enable-arm-streaming=za-mode=in-za -convert-arm-sme-to-llvm | FileCheck %s -check-prefix=IN-ZA
+// RUN: mlir-opt %s -enable-arm-streaming=za-mode=out-za -convert-arm-sme-to-llvm | FileCheck %s -check-prefix=OUT-ZA
+// RUN: mlir-opt %s -enable-arm-streaming=za-mode=inout-za -convert-arm-sme-to-llvm | FileCheck %s -check-prefix=INOUT-ZA
 // RUN: mlir-opt %s -enable-arm-streaming=za-mode=preserves-za -convert-arm-sme-to-llvm | FileCheck %s -check-prefix=PRESERVES-ZA
 // RUN: mlir-opt %s -convert-arm-sme-to-llvm | FileCheck %s -check-prefix=NO-ARM-STREAMING
 
@@ -9,8 +11,12 @@ func.func private @declaration()
 
 // ENABLE-ZA-LABEL: @arm_new_za
 // ENABLE-ZA-SAME: attributes {arm_new_za, arm_streaming}
-// SHARED-ZA-LABEL: @arm_new_za
-// SHARED-ZA-SAME: attributes {arm_shared_za, arm_streaming}
+// IN-ZA-LABEL: @arm_new_za
+// IN-ZA-SAME: attributes {arm_in_za, arm_streaming}
+// OUT-ZA-LABEL: @arm_new_za
+// OUT-ZA-SAME: attributes {arm_out_za, arm_streaming}
+// INOUT-ZA-LABEL: @arm_new_za
+// INOUT-ZA-SAME: attributes {arm_inout_za, arm_streaming}
 // PRESERVES-ZA-LABEL: @arm_new_za
 // PRESERVES-ZA-SAME: attributes {arm_preserves_za, arm_streaming}
 // DISABLE-ZA-LABEL: @arm_new_za
@@ -19,6 +25,8 @@ func.func private @declaration()
 // NO-ARM-STREAMING-LABEL: @arm_new_za
 // NO-ARM-STREAMING-NOT: arm_new_za
 // NO-ARM-STREAMING-NOT: arm_streaming
-// NO-ARM-STREAMING-NOT: arm_shared_za
+// NO-ARM-STREAMING-NOT: arm_in_za
+// NO-ARM-STREAMING-NOT: arm_out_za
+// NO-ARM-STREAMING-NOT: arm_inout_za
 // NO-ARM-STREAMING-NOT: arm_preserves_za
 func.func @arm_new_za() { return }
diff --git a/mlir/test/Target/LLVMIR/Import/function-attributes.ll b/mlir/test/Target/LLVMIR/Import/function-attributes.ll
index c46db5e346434..f5fb06df49487 100644
--- a/mlir/test/Target/LLVMIR/Import/function-attributes.ll
+++ b/mlir/test/Target/LLVMIR/Import/function-attributes.ll
@@ -222,20 +222,32 @@ define void @streaming_compatible_func() "aarch64_pstate_sm_compatible" {
 
 ; CHECK-LABEL: @arm_new_za_func
 ; CHECK-SAME: attributes {arm_new_za}
-define void @arm_new_za_func() "aarch64_pstate_za_new" {
+define void @arm_new_za_func() "aarch64_new_za" {
   ret void
 }
 
 
-; CHECK-LABEL: @arm_preserves_za_func
-; CHECK-SAME: attributes {arm_preserves_za}
-define void @arm_preserves_za_func() "aarch64_pstate_za_preserved" {
+; CHECK-LABEL: @arm_in_za_func
+; CHECK-SAME: attributes {arm_in_za}
+define void @arm_in_za_func() "aarch64_in_za" {
+  ret void
+}
+
+; CHECK-LABEL: @arm_out_za_func
+; CHECK-SAME: attributes {arm_out_za}
+define void @arm_out_za_func() "aarch64_out_za" {
   ret void
 }
 
-; CHECK-LABEL: @arm_shared_za_func
-; CHECK-SAME: attributes {arm_shared_za}
-define void @arm_shared_za_func() "aarch64_pstate_za_shared" {
+; CHECK-LABEL: @arm_inout_za_func
+; CHECK-SAME: attributes {arm_inout_za}
+define void @arm_inout_za_func() "aarch64_inout_za" {
+  ret void
+}
+
+; CHECK-LABEL: @arm_preserves_za_func
+; CHECK-SAME: attributes {arm_preserves_za}
+define void @arm_preserves_za_func() "aarch64_preserves_za" {
   ret void
 }
 
diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir
index 448aa3a5d85d7..63774bf0baf68 100644
--- a/mlir/test/Target/LLVMIR/llvmir.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir.mlir
@@ -2358,21 +2358,35 @@ llvm.func @streaming_compatible_func() attributes {arm_streaming_compatible} {
 llvm.func @new_za_func() attributes {arm_new_za} {
   llvm.return
 }
-// CHECK #[[ATTR]] = { "aarch64_pstate_za_new" }
+// CHECK #[[ATTR]] = { "aarch64_new_za" }
 
-// CHECK-LABEL: @shared_za_func
+// CHECK-LABEL: @in_za_func
 // CHECK: #[[ATTR:[0-9]*]]
-llvm.func @shared_za_func() attributes {arm_shared_za } {
+llvm.func @in_za_func() attributes {arm_in_za } {
   llvm.return
 }
-// CHECK #[[ATTR]] = { "aarch64_pstate_za_shared" }
+// CHECK #[[ATTR]] = { "aarch64_in_za" }
+
+// CHECK-LABEL: @out_za_func
+// CHECK: #[[ATTR:[0-9]*]]
+llvm.func @out_za_func() attributes {arm_out_za } {
+  llvm.return
+}
+// CHECK #[[ATTR]] = { "aarch64_out_za" }
+
+// CHECK-LABEL: @inout_za_func
+// CHECK: #[[ATTR:[0-9]*]]
+llvm.func @inout_za_func() attributes {arm_inout_za } {
+  llvm.return
+}
+// CHECK #[[ATTR]] = { "aarch64_inout_za" }
 
 // CHECK-LABEL: @preserves_za_func
 // CHECK: #[[ATTR:[0-9]*]]
 llvm.func @preserves_za_func() attributes {arm_preserves_za} {
   llvm.return
 }
-// CHECK #[[ATTR]] = { "aarch64_pstate_za_preserved" }
+// CHECK #[[ATTR]] = { "aarch64_preserves_za" }
 
 // -----
 

>From ed88e6e6935e7117e18f42a1d1e3f4c5619c6f79 Mon Sep 17 00:00:00 2001
From: Sander de Smalen <sander.desmalen at arm.com>
Date: Mon, 29 Jan 2024 11:49:05 +0000
Subject: [PATCH 2/4] Address comments [to squash]

---
 llvm/docs/AArch64SME.rst                     | 2 +-
 mlir/lib/Target/LLVMIR/ModuleTranslation.cpp | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/llvm/docs/AArch64SME.rst b/llvm/docs/AArch64SME.rst
index 4b6fa5e10f880..b5a01cb204b81 100644
--- a/llvm/docs/AArch64SME.rst
+++ b/llvm/docs/AArch64SME.rst
@@ -31,7 +31,7 @@ level ACLE attributes:
   is used for functions with ``__arm_locally_streaming`` and is
   only valid on function definitions (not declarations)
 
-``aarch64_pstate_new_za``
+``aarch64_new_za``
   is used for functions with ``__arm_new("za")``
 
 ``aarch64_in_za``
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index d254925db7a0c..a54221580b28b 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -1205,7 +1205,7 @@ LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) {
     llvmFunc->addFnAttr("aarch64_out_za");
   else if (func.getArmInoutZa())
     llvmFunc->addFnAttr("aarch64_inout_za");
-  if (func.getArmPreservesZa())
+  else if (func.getArmPreservesZa())
     llvmFunc->addFnAttr("aarch64_preserves_za");
 
   if (auto targetCpu = func.getTargetCpu())

>From 9c71c908aec138992c66dc376c16b22a9373f02a Mon Sep 17 00:00:00 2001
From: Sander de Smalen <sander.desmalen at arm.com>
Date: Wed, 31 Jan 2024 14:27:22 +0000
Subject: [PATCH 3/4] Fix SMEAttributesTest [to squash]

---
 llvm/unittests/Target/AArch64/SMEAttributesTest.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llvm/unittests/Target/AArch64/SMEAttributesTest.cpp b/llvm/unittests/Target/AArch64/SMEAttributesTest.cpp
index 2c1c92dfa602a..0f330b3a17e1a 100644
--- a/llvm/unittests/Target/AArch64/SMEAttributesTest.cpp
+++ b/llvm/unittests/Target/AArch64/SMEAttributesTest.cpp
@@ -284,7 +284,7 @@ TEST(SMEAttributes, Transitions) {
   SA Private_ZA = SA(SA::Normal);
   SA ZA_Shared = SA(SA::encodeZAState(SA::StateValue::In));
   SA ZT0_Shared = SA(SA::encodeZT0State(SA::StateValue::In));
-  SA ZA_ZT0_Shared = SA(SA::encodeZT0State(SA::StateValue::In) |
+  SA ZA_ZT0_Shared = SA(SA::encodeZAState(SA::StateValue::In) |
                         SA::encodeZT0State(SA::StateValue::In));
 
   // Shared ZA -> Private ZA Interface

>From 3f288591a22b12a962249988e91bce3b39fd250b Mon Sep 17 00:00:00 2001
From: Sander de Smalen <sander.desmalen at arm.com>
Date: Wed, 31 Jan 2024 15:02:19 +0000
Subject: [PATCH 4/4] Run clang-format [to squash]

---
 llvm/lib/IR/Verifier.cpp                          | 15 +++++++--------
 llvm/lib/Target/AArch64/SMEABIPass.cpp            |  5 ++---
 .../Target/AArch64/Utils/AArch64SMEAttributes.cpp |  2 +-
 .../Target/AArch64/SMEAttributesTest.cpp          | 13 ++++++-------
 4 files changed, 16 insertions(+), 19 deletions(-)

diff --git a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp
index 53d923ced8d2e..b04d39c700a8f 100644
--- a/llvm/lib/IR/Verifier.cpp
+++ b/llvm/lib/IR/Verifier.cpp
@@ -2155,14 +2155,13 @@ void Verifier::verifyFunctionAttrs(FunctionType *FT, AttributeList Attrs,
            V);
   }
 
-  Check(
-      (Attrs.hasFnAttr("aarch64_new_za") + Attrs.hasFnAttr("aarch64_in_za") +
-       Attrs.hasFnAttr("aarch64_inout_za") +
-       Attrs.hasFnAttr("aarch64_out_za") +
-       Attrs.hasFnAttr("aarch64_preserves_za")) <= 1,
-      "Attributes 'aarch64_new_za', 'aarch64_in_za', 'aarch64_out_za', "
-      "'aarch64_inout_za' and 'aarch64_preserves_za' are mutually exclusive",
-      V);
+  Check((Attrs.hasFnAttr("aarch64_new_za") + Attrs.hasFnAttr("aarch64_in_za") +
+         Attrs.hasFnAttr("aarch64_inout_za") +
+         Attrs.hasFnAttr("aarch64_out_za") +
+         Attrs.hasFnAttr("aarch64_preserves_za")) <= 1,
+        "Attributes 'aarch64_new_za', 'aarch64_in_za', 'aarch64_out_za', "
+        "'aarch64_inout_za' and 'aarch64_preserves_za' are mutually exclusive",
+        V);
 
   Check(
       (Attrs.hasFnAttr("aarch64_new_zt0") + Attrs.hasFnAttr("aarch64_in_zt0") +
diff --git a/llvm/lib/Target/AArch64/SMEABIPass.cpp b/llvm/lib/Target/AArch64/SMEABIPass.cpp
index bee5d63c5a749..23b3cc9ec6215 100644
--- a/llvm/lib/Target/AArch64/SMEABIPass.cpp
+++ b/llvm/lib/Target/AArch64/SMEABIPass.cpp
@@ -60,9 +60,8 @@ FunctionPass *llvm::createSMEABIPass() { return new SMEABI(); }
 void emitTPIDR2Save(Module *M, IRBuilder<> &Builder) {
   auto *TPIDR2SaveTy =
       FunctionType::get(Builder.getVoidTy(), {}, /*IsVarArgs=*/false);
-  auto Attrs =
-      AttributeList()
-          .addFnAttribute(M->getContext(), "aarch64_pstate_sm_compatible");
+  auto Attrs = AttributeList().addFnAttribute(M->getContext(),
+                                              "aarch64_pstate_sm_compatible");
   FunctionCallee Callee =
       M->getOrInsertFunction("__arm_tpidr2_save", TPIDR2SaveTy, Attrs);
   CallInst *Call = Builder.CreateCall(Callee);
diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
index ac07274a36c23..d399e0ac0794f 100644
--- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
+++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
@@ -52,7 +52,7 @@ SMEAttrs::SMEAttrs(StringRef FuncName) : Bitmask(0) {
     Bitmask |= (SMEAttrs::SM_Compatible | SMEAttrs::SME_ABI_Routine);
   if (FuncName == "__arm_tpidr2_restore")
     Bitmask |= SMEAttrs::SM_Compatible | encodeZAState(StateValue::In) |
-                SMEAttrs::SME_ABI_Routine;
+               SMEAttrs::SME_ABI_Routine;
 }
 
 SMEAttrs::SMEAttrs(const AttributeList &Attrs) {
diff --git a/llvm/unittests/Target/AArch64/SMEAttributesTest.cpp b/llvm/unittests/Target/AArch64/SMEAttributesTest.cpp
index 0f330b3a17e1a..3af5e24168c8c 100644
--- a/llvm/unittests/Target/AArch64/SMEAttributesTest.cpp
+++ b/llvm/unittests/Target/AArch64/SMEAttributesTest.cpp
@@ -38,22 +38,21 @@ TEST(SMEAttributes, Constructors) {
               ->getFunction("foo"))
           .hasStreamingCompatibleInterface());
 
-
   ASSERT_TRUE(
       SA(*parseIR("declare void @foo() \"aarch64_in_za\"")->getFunction("foo"))
           .isInZA());
-  ASSERT_TRUE(SA(*parseIR("declare void @foo() \"aarch64_out_za\"")
-                      ->getFunction("foo"))
-                  .isOutZA());
+  ASSERT_TRUE(
+      SA(*parseIR("declare void @foo() \"aarch64_out_za\"")->getFunction("foo"))
+          .isOutZA());
   ASSERT_TRUE(SA(*parseIR("declare void @foo() \"aarch64_inout_za\"")
                       ->getFunction("foo"))
                   .isInOutZA());
   ASSERT_TRUE(SA(*parseIR("declare void @foo() \"aarch64_preserves_za\"")
                       ->getFunction("foo"))
                   .isPreservesZA());
-  ASSERT_TRUE(SA(*parseIR("declare void @foo() \"aarch64_new_za\"")
-                      ->getFunction("foo"))
-                  .isNewZA());
+  ASSERT_TRUE(
+      SA(*parseIR("declare void @foo() \"aarch64_new_za\"")->getFunction("foo"))
+          .isNewZA());
 
   ASSERT_TRUE(
       SA(*parseIR("declare void @foo() \"aarch64_in_zt0\"")->getFunction("foo"))



More information about the Mlir-commits mailing list