[llvm] 703b479 - Revert "[AArch64][SME] Split SMECallAttrs out of SMEAttrs" (#138664)

via llvm-commits <llvm-commits at lists.llvm.org>
Tue May 6 02:28:16 PDT 2025


Author: Benjamin Maxwell
Date: 2025-05-06T10:28:13+01:00
New Revision: 703b479f16b9657a9c0d3a3d992278ad9c555166

URL: https://github.com/llvm/llvm-project/commit/703b479f16b9657a9c0d3a3d992278ad9c555166
DIFF: https://github.com/llvm/llvm-project/commit/703b479f16b9657a9c0d3a3d992278ad9c555166.diff

LOG: Revert "[AArch64][SME] Split SMECallAttrs out of SMEAttrs" (#138664)

Reverts llvm/llvm-project#137239

This broke the implementation of SME ABI routines in C/C++ (used for some
stubs); see: https://lab.llvm.org/buildbot/#/builders/94/builds/6859
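
"Implementing SME ABI routines in C/C++" refers to defining routines such
as __arm_sc_memset directly in C, where the backend recognises the symbol
name and infers SME attributes from it (see the SMEAttrs(StringRef)
constructor restored below). A minimal, hypothetical sketch of such a stub
(not the actual compiler-rt source; assumes clang's SME ACLE keyword
attributes are available):

/* Hypothetical C stub for a name-special-cased SME ABI routine. The
   backend treats __arm_sc_memset as streaming-compatible based purely on
   its name, matching the __arm_streaming_compatible annotation here. */
#include <stddef.h>

void *__arm_sc_memset(void *dest, int c, size_t n)
    __arm_streaming_compatible {
  unsigned char *p = (unsigned char *)dest;
  for (size_t i = 0; i < n; ++i)
    p[i] = (unsigned char)c;
  return dest;
}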

Added: 
    

Modified: 
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
    llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
    llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
    llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
    llvm/test/CodeGen/AArch64/sme-peephole-opts.ll
    llvm/test/CodeGen/AArch64/sme-vg-to-stack.ll
    llvm/test/CodeGen/AArch64/sme-zt0-state.ll
    llvm/unittests/Target/AArch64/SMEAttributesTest.cpp

Removed: 
    


################################################################################
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index ab577e130ad9c..1c889d67c81e0 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -8636,16 +8636,6 @@ static void analyzeCallOperands(const AArch64TargetLowering &TLI,
   }
 }
 
-static SMECallAttrs
-getSMECallAttrs(const Function &Function,
-                const TargetLowering::CallLoweringInfo &CLI) {
-  if (CLI.CB)
-    return SMECallAttrs(*CLI.CB);
-  if (auto *ES = dyn_cast<ExternalSymbolSDNode>(CLI.Callee))
-    return SMECallAttrs(SMEAttrs(Function), SMEAttrs(ES->getSymbol()));
-  return SMECallAttrs(SMEAttrs(Function), SMEAttrs(SMEAttrs::Normal));
-}
-
 bool AArch64TargetLowering::isEligibleForTailCallOptimization(
     const CallLoweringInfo &CLI) const {
   CallingConv::ID CalleeCC = CLI.CallConv;
@@ -8664,10 +8654,12 @@ bool AArch64TargetLowering::isEligibleForTailCallOptimization(
 
   // SME Streaming functions are not eligible for TCO as they may require
   // the streaming mode or ZA to be restored after returning from the call.
-  SMECallAttrs CallAttrs = getSMECallAttrs(CallerF, CLI);
-  if (CallAttrs.requiresSMChange() || CallAttrs.requiresLazySave() ||
-      CallAttrs.requiresPreservingAllZAState() ||
-      CallAttrs.caller().hasStreamingBody())
+  SMEAttrs CallerAttrs(MF.getFunction());
+  auto CalleeAttrs = CLI.CB ? SMEAttrs(*CLI.CB) : SMEAttrs(SMEAttrs::Normal);
+  if (CallerAttrs.requiresSMChange(CalleeAttrs) ||
+      CallerAttrs.requiresLazySave(CalleeAttrs) ||
+      CallerAttrs.requiresPreservingAllZAState(CalleeAttrs) ||
+      CallerAttrs.hasStreamingBody())
     return false;
 
   // Functions using the C or Fast calling convention that have an SVE signature
@@ -8959,13 +8951,14 @@ static SDValue emitSMEStateSaveRestore(const AArch64TargetLowering &TLI,
   return TLI.LowerCallTo(CLI).second;
 }
 
-static unsigned getSMCondition(const SMECallAttrs &CallAttrs) {
-  if (!CallAttrs.caller().hasStreamingCompatibleInterface() ||
-      CallAttrs.caller().hasStreamingBody())
+static unsigned getSMCondition(const SMEAttrs &CallerAttrs,
+                               const SMEAttrs &CalleeAttrs) {
+  if (!CallerAttrs.hasStreamingCompatibleInterface() ||
+      CallerAttrs.hasStreamingBody())
     return AArch64SME::Always;
-  if (CallAttrs.callee().hasNonStreamingInterface())
+  if (CalleeAttrs.hasNonStreamingInterface())
     return AArch64SME::IfCallerIsStreaming;
-  if (CallAttrs.callee().hasStreamingInterface())
+  if (CalleeAttrs.hasStreamingInterface())
     return AArch64SME::IfCallerIsNonStreaming;
 
   llvm_unreachable("Unsupported attributes");
@@ -9098,7 +9091,11 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
   }
 
   // Determine whether we need any streaming mode changes.
-  SMECallAttrs CallAttrs = getSMECallAttrs(MF.getFunction(), CLI);
+  SMEAttrs CalleeAttrs, CallerAttrs(MF.getFunction());
+  if (CLI.CB)
+    CalleeAttrs = SMEAttrs(*CLI.CB);
+  else if (auto *ES = dyn_cast<ExternalSymbolSDNode>(CLI.Callee))
+    CalleeAttrs = SMEAttrs(ES->getSymbol());
 
   auto DescribeCallsite =
       [&](OptimizationRemarkAnalysis &R) -> OptimizationRemarkAnalysis & {
@@ -9113,8 +9110,9 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
     return R;
   };
 
-  bool RequiresLazySave = CallAttrs.requiresLazySave();
-  bool RequiresSaveAllZA = CallAttrs.requiresPreservingAllZAState();
+  bool RequiresLazySave = CallerAttrs.requiresLazySave(CalleeAttrs);
+  bool RequiresSaveAllZA =
+      CallerAttrs.requiresPreservingAllZAState(CalleeAttrs);
   if (RequiresLazySave) {
     const TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
     MachinePointerInfo MPI =
@@ -9142,18 +9140,18 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
       return DescribeCallsite(R) << " sets up a lazy save for ZA";
     });
   } else if (RequiresSaveAllZA) {
-    assert(!CallAttrs.callee().hasSharedZAInterface() &&
+    assert(!CalleeAttrs.hasSharedZAInterface() &&
            "Cannot share state that may not exist");
     Chain = emitSMEStateSaveRestore(*this, DAG, FuncInfo, DL, Chain,
                                     /*IsSave=*/true);
   }
 
   SDValue PStateSM;
-  bool RequiresSMChange = CallAttrs.requiresSMChange();
+  bool RequiresSMChange = CallerAttrs.requiresSMChange(CalleeAttrs);
   if (RequiresSMChange) {
-    if (CallAttrs.caller().hasStreamingInterfaceOrBody())
+    if (CallerAttrs.hasStreamingInterfaceOrBody())
       PStateSM = DAG.getConstant(1, DL, MVT::i64);
-    else if (CallAttrs.caller().hasNonStreamingInterface())
+    else if (CallerAttrs.hasNonStreamingInterface())
       PStateSM = DAG.getConstant(0, DL, MVT::i64);
     else
       PStateSM = getRuntimePStateSM(DAG, Chain, DL, MVT::i64);
@@ -9170,7 +9168,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
 
   SDValue ZTFrameIdx;
   MachineFrameInfo &MFI = MF.getFrameInfo();
-  bool ShouldPreserveZT0 = CallAttrs.requiresPreservingZT0();
+  bool ShouldPreserveZT0 = CallerAttrs.requiresPreservingZT0(CalleeAttrs);
 
   // If the caller has ZT0 state which will not be preserved by the callee,
   // spill ZT0 before the call.
@@ -9186,7 +9184,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
 
   // If caller shares ZT0 but the callee is not shared ZA, we need to stop
   // PSTATE.ZA before the call if there is no lazy-save active.
-  bool DisableZA = CallAttrs.requiresDisablingZABeforeCall();
+  bool DisableZA = CallerAttrs.requiresDisablingZABeforeCall(CalleeAttrs);
   assert((!DisableZA || !RequiresLazySave) &&
          "Lazy-save should have PSTATE.SM=1 on entry to the function");
 
@@ -9468,9 +9466,9 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
       InGlue = Chain.getValue(1);
     }
 
-    SDValue NewChain =
-        changeStreamingMode(DAG, DL, CallAttrs.callee().hasStreamingInterface(),
-                            Chain, InGlue, getSMCondition(CallAttrs), PStateSM);
+    SDValue NewChain = changeStreamingMode(
+        DAG, DL, CalleeAttrs.hasStreamingInterface(), Chain, InGlue,
+        getSMCondition(CallerAttrs, CalleeAttrs), PStateSM);
     Chain = NewChain.getValue(0);
     InGlue = NewChain.getValue(1);
   }
@@ -9649,8 +9647,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
   if (RequiresSMChange) {
     assert(PStateSM && "Expected a PStateSM to be set");
     Result = changeStreamingMode(
-        DAG, DL, !CallAttrs.callee().hasStreamingInterface(), Result, InGlue,
-        getSMCondition(CallAttrs), PStateSM);
+        DAG, DL, !CalleeAttrs.hasStreamingInterface(), Result, InGlue,
+        getSMCondition(CallerAttrs, CalleeAttrs), PStateSM);
 
     if (!Subtarget->isTargetDarwin() || Subtarget->hasSVE()) {
       InGlue = Result.getValue(1);
@@ -9660,7 +9658,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
     }
   }
 
-  if (CallAttrs.requiresEnablingZAAfterCall())
+  if (CallerAttrs.requiresEnablingZAAfterCall(CalleeAttrs))
     // Unconditionally resume ZA.
     Result = DAG.getNode(
         AArch64ISD::SMSTART, DL, MVT::Other, Result,
@@ -28520,10 +28518,12 @@ bool AArch64TargetLowering::fallBackToDAGISel(const Instruction &Inst) const {
 
   // Checks to allow the use of SME instructions
   if (auto *Base = dyn_cast<CallBase>(&Inst)) {
-    auto CallAttrs = SMECallAttrs(*Base);
-    if (CallAttrs.requiresSMChange() || CallAttrs.requiresLazySave() ||
-        CallAttrs.requiresPreservingZT0() ||
-        CallAttrs.requiresPreservingAllZAState())
+    auto CallerAttrs = SMEAttrs(*Inst.getFunction());
+    auto CalleeAttrs = SMEAttrs(*Base);
+    if (CallerAttrs.requiresSMChange(CalleeAttrs) ||
+        CallerAttrs.requiresLazySave(CalleeAttrs) ||
+        CallerAttrs.requiresPreservingZT0(CalleeAttrs) ||
+        CallerAttrs.requiresPreservingAllZAState(CalleeAttrs))
       return true;
   }
   return false;

diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 34fb76e66ec00..5b3a1df9dfd7c 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -268,21 +268,22 @@ const FeatureBitset AArch64TTIImpl::InlineInverseFeatures = {
 
 bool AArch64TTIImpl::areInlineCompatible(const Function *Caller,
                                          const Function *Callee) const {
-  SMECallAttrs CallAttrs(*Caller, *Callee);
+  SMEAttrs CallerAttrs(*Caller), CalleeAttrs(*Callee);
 
   // When inlining, we should consider the body of the function, not the
   // interface.
-  if (CallAttrs.callee().hasStreamingBody()) {
-    CallAttrs.callee().set(SMEAttrs::SM_Compatible, false);
-    CallAttrs.callee().set(SMEAttrs::SM_Enabled, true);
+  if (CalleeAttrs.hasStreamingBody()) {
+    CalleeAttrs.set(SMEAttrs::SM_Compatible, false);
+    CalleeAttrs.set(SMEAttrs::SM_Enabled, true);
   }
 
-  if (CallAttrs.callee().isNewZA() || CallAttrs.callee().isNewZT0())
+  if (CalleeAttrs.isNewZA() || CalleeAttrs.isNewZT0())
     return false;
 
-  if (CallAttrs.requiresLazySave() || CallAttrs.requiresSMChange() ||
-      CallAttrs.requiresPreservingZT0() ||
-      CallAttrs.requiresPreservingAllZAState()) {
+  if (CallerAttrs.requiresLazySave(CalleeAttrs) ||
+      CallerAttrs.requiresSMChange(CalleeAttrs) ||
+      CallerAttrs.requiresPreservingZT0(CalleeAttrs) ||
+      CallerAttrs.requiresPreservingAllZAState(CalleeAttrs)) {
     if (hasPossibleIncompatibleOps(Callee))
       return false;
   }
@@ -348,14 +349,12 @@ AArch64TTIImpl::getInlineCallPenalty(const Function *F, const CallBase &Call,
   // streaming-mode change, and the call to G from F would also require a
   // streaming-mode change, then there is benefit to do the streaming-mode
   // change only once and avoid inlining of G into F.
-
   SMEAttrs FAttrs(*F);
-  SMECallAttrs CallAttrs(Call);
-
-  if (SMECallAttrs(FAttrs, CallAttrs.callee()).requiresSMChange()) {
+  SMEAttrs CalleeAttrs(Call);
+  if (FAttrs.requiresSMChange(CalleeAttrs)) {
     if (F == Call.getCaller()) // (1)
       return CallPenaltyChangeSM * DefaultCallPenalty;
-    if (SMECallAttrs(FAttrs, CallAttrs.caller()).requiresSMChange()) // (2)
+    if (FAttrs.requiresSMChange(SMEAttrs(*Call.getCaller()))) // (2)
       return InlineCallPenaltyChangeSM * DefaultCallPenalty;
   }
 

diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
index 16ae5434e596a..76d2ac6a601e5 100644
--- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
+++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
@@ -27,14 +27,15 @@ void SMEAttrs::set(unsigned M, bool Enable) {
          "ZA_New and SME_ABI_Routine are mutually exclusive");
 
   assert(
-      (isNewZA() + isInZA() + isOutZA() + isInOutZA() + isPreservesZA()) <= 1 &&
+      (!sharesZA() ||
+       (isNewZA() ^ isInZA() ^ isInOutZA() ^ isOutZA() ^ isPreservesZA())) &&
       "Attributes 'aarch64_new_za', 'aarch64_in_za', 'aarch64_out_za', "
       "'aarch64_inout_za' and 'aarch64_preserves_za' are mutually exclusive");
 
   // ZT0 Attrs
   assert(
-      (isNewZT0() + isInZT0() + isOutZT0() + isInOutZT0() + isPreservesZT0()) <=
-          1 &&
+      (!sharesZT0() || (isNewZT0() ^ isInZT0() ^ isInOutZT0() ^ isOutZT0() ^
+                        isPreservesZT0())) &&
       "Attributes 'aarch64_new_zt0', 'aarch64_in_zt0', 'aarch64_out_zt0', "
       "'aarch64_inout_zt0' and 'aarch64_preserves_zt0' are mutually exclusive");
 
@@ -43,6 +44,27 @@ void SMEAttrs::set(unsigned M, bool Enable) {
          "interface");
 }
 
+SMEAttrs::SMEAttrs(const CallBase &CB) {
+  *this = SMEAttrs(CB.getAttributes());
+  if (auto *F = CB.getCalledFunction()) {
+    set(SMEAttrs(*F).Bitmask | SMEAttrs(F->getName()).Bitmask);
+  }
+}
+
+SMEAttrs::SMEAttrs(StringRef FuncName) : Bitmask(0) {
+  if (FuncName == "__arm_tpidr2_save" || FuncName == "__arm_sme_state")
+    Bitmask |= (SMEAttrs::SM_Compatible | SMEAttrs::SME_ABI_Routine);
+  if (FuncName == "__arm_tpidr2_restore")
+    Bitmask |= SMEAttrs::SM_Compatible | encodeZAState(StateValue::In) |
+               SMEAttrs::SME_ABI_Routine;
+  if (FuncName == "__arm_sc_memcpy" || FuncName == "__arm_sc_memset" ||
+      FuncName == "__arm_sc_memmove" || FuncName == "__arm_sc_memchr")
+    Bitmask |= SMEAttrs::SM_Compatible;
+  if (FuncName == "__arm_sme_save" || FuncName == "__arm_sme_restore" ||
+      FuncName == "__arm_sme_state_size")
+    Bitmask |= SMEAttrs::SM_Compatible | SMEAttrs::SME_ABI_Routine;
+}
+
 SMEAttrs::SMEAttrs(const AttributeList &Attrs) {
   Bitmask = 0;
   if (Attrs.hasFnAttr("aarch64_pstate_sm_enabled"))
@@ -77,45 +99,17 @@ SMEAttrs::SMEAttrs(const AttributeList &Attrs) {
     Bitmask |= encodeZT0State(StateValue::New);
 }
 
-void SMEAttrs::addKnownFunctionAttrs(StringRef FuncName) {
-  unsigned KnownAttrs = SMEAttrs::Normal;
-  if (FuncName == "__arm_tpidr2_save" || FuncName == "__arm_sme_state")
-    KnownAttrs |= (SMEAttrs::SM_Compatible | SMEAttrs::SME_ABI_Routine);
-  if (FuncName == "__arm_tpidr2_restore")
-    KnownAttrs |= SMEAttrs::SM_Compatible | encodeZAState(StateValue::In) |
-                  SMEAttrs::SME_ABI_Routine;
-  if (FuncName == "__arm_sc_memcpy" || FuncName == "__arm_sc_memset" ||
-      FuncName == "__arm_sc_memmove" || FuncName == "__arm_sc_memchr")
-    KnownAttrs |= SMEAttrs::SM_Compatible;
-  if (FuncName == "__arm_sme_save" || FuncName == "__arm_sme_restore" ||
-      FuncName == "__arm_sme_state_size")
-    KnownAttrs |= SMEAttrs::SM_Compatible | SMEAttrs::SME_ABI_Routine;
-  set(KnownAttrs, /*Enable=*/true);
-}
-
-bool SMECallAttrs::requiresSMChange() const {
-  if (callee().hasStreamingCompatibleInterface())
+bool SMEAttrs::requiresSMChange(const SMEAttrs &Callee) const {
+  if (Callee.hasStreamingCompatibleInterface())
     return false;
 
   // Both non-streaming
-  if (caller().hasNonStreamingInterfaceAndBody() &&
-      callee().hasNonStreamingInterface())
+  if (hasNonStreamingInterfaceAndBody() && Callee.hasNonStreamingInterface())
     return false;
 
   // Both streaming
-  if (caller().hasStreamingInterfaceOrBody() &&
-      callee().hasStreamingInterface())
+  if (hasStreamingInterfaceOrBody() && Callee.hasStreamingInterface())
     return false;
 
   return true;
 }
-
-SMECallAttrs::SMECallAttrs(const CallBase &CB)
-    : CallerFn(*CB.getFunction()), CalledFn(CB.getCalledFunction()),
-      Callsite(CB.getAttributes()), IsIndirect(CB.isIndirectCall()) {
-  // FIXME: We probably should not allow SME attributes on direct calls but
-  // clang duplicates streaming mode attributes at each callsite.
-  assert((IsIndirect ||
-          ((Callsite.withoutPerCallsiteFlags() | CalledFn) == CalledFn)) &&
-         "SME attributes at callsite do not match declaration");
-}

diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
index c4f132ba6ddf1..1691d4fec8b68 100644
--- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
+++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
@@ -18,9 +18,12 @@ class CallBase;
 class AttributeList;
 
 /// SMEAttrs is a utility class to parse the SME ACLE attributes on functions.
-/// It helps determine a function's requirements for PSTATE.ZA and PSTATE.SM.
+/// It helps determine a function's requirements for PSTATE.ZA and PSTATE.SM. It
+/// has interfaces to query whether a streaming mode change or lazy-save
+/// mechanism is required when going from one function to another (e.g. through
+/// a call).
 class SMEAttrs {
-  unsigned Bitmask = Normal;
+  unsigned Bitmask;
 
 public:
   enum class StateValue {
@@ -40,24 +43,18 @@ class SMEAttrs {
     SM_Body = 1 << 2,         // aarch64_pstate_sm_body
     SME_ABI_Routine = 1 << 3, // Used for SME ABI routines to avoid lazy saves
     ZA_State_Agnostic = 1 << 4,
-    ZT0_Undef = 1 << 5, // Use to mark ZT0 as undef to avoid spills
+    ZT0_Undef = 1 << 5,       // Use to mark ZT0 as undef to avoid spills
     ZA_Shift = 6,
     ZA_Mask = 0b111 << ZA_Shift,
     ZT0_Shift = 9,
-    ZT0_Mask = 0b111 << ZT0_Shift,
-    Callsite_Flags = ZT0_Undef
+    ZT0_Mask = 0b111 << ZT0_Shift
   };
 
-  SMEAttrs() = default;
-  SMEAttrs(unsigned Mask) { set(Mask); }
-  SMEAttrs(const Function *F)
-      : SMEAttrs(F ? F->getAttributes() : AttributeList()) {
-    if (F)
-      addKnownFunctionAttrs(F->getName());
-  }
-  SMEAttrs(const Function &F) : SMEAttrs(&F) {}
+  SMEAttrs(unsigned Mask = Normal) : Bitmask(0) { set(Mask); }
+  SMEAttrs(const Function &F) : SMEAttrs(F.getAttributes()) {}
+  SMEAttrs(const CallBase &CB);
   SMEAttrs(const AttributeList &L);
-  SMEAttrs(StringRef FuncName) { addKnownFunctionAttrs(FuncName); };
+  SMEAttrs(StringRef FuncName);
 
   void set(unsigned M, bool Enable = true);
 
@@ -77,6 +74,10 @@ class SMEAttrs {
     return hasNonStreamingInterface() && !hasStreamingBody();
   }
 
+  /// \return true if a call from Caller -> Callee requires a change in
+  /// streaming mode.
+  bool requiresSMChange(const SMEAttrs &Callee) const;
+
   // Interfaces to query ZA
   static StateValue decodeZAState(unsigned Bitmask) {
     return static_cast<StateValue>((Bitmask & ZA_Mask) >> ZA_Shift);
@@ -103,7 +104,10 @@ class SMEAttrs {
     return !hasSharedZAInterface() && !hasAgnosticZAInterface();
   }
   bool hasZAState() const { return isNewZA() || sharesZA(); }
-  bool isSMEABIRoutine() const { return Bitmask & SME_ABI_Routine; }
+  bool requiresLazySave(const SMEAttrs &Callee) const {
+    return hasZAState() && Callee.hasPrivateZAInterface() &&
+           !(Callee.Bitmask & SME_ABI_Routine);
+  }
 
   // Interfaces to query ZT0 State
   static StateValue decodeZT0State(unsigned Bitmask) {
@@ -122,83 +126,27 @@ class SMEAttrs {
   bool isPreservesZT0() const {
     return decodeZT0State(Bitmask) == StateValue::Preserved;
   }
-  bool hasUndefZT0() const { return Bitmask & ZT0_Undef; }
+  bool isUndefZT0() const { return Bitmask & ZT0_Undef; }
   bool sharesZT0() const {
     StateValue State = decodeZT0State(Bitmask);
     return State == StateValue::In || State == StateValue::Out ||
            State == StateValue::InOut || State == StateValue::Preserved;
   }
   bool hasZT0State() const { return isNewZT0() || sharesZT0(); }
-
-  SMEAttrs operator|(SMEAttrs Other) const {
-    SMEAttrs Merged(*this);
-    Merged.set(Other.Bitmask, /*Enable=*/true);
-    return Merged;
-  }
-
-  SMEAttrs withoutPerCallsiteFlags() const {
-    return (Bitmask & ~Callsite_Flags);
-  }
-
-  bool operator==(SMEAttrs const &Other) const {
-    return Bitmask == Other.Bitmask;
-  }
-
-private:
-  void addKnownFunctionAttrs(StringRef FuncName);
-};
-
-/// SMECallAttrs is a utility class to hold the SMEAttrs for a callsite. It has
-/// interfaces to query whether a streaming mode change or lazy-save mechanism
-/// is required when going from one function to another (e.g. through a call).
-class SMECallAttrs {
-  SMEAttrs CallerFn;
-  SMEAttrs CalledFn;
-  SMEAttrs Callsite;
-  bool IsIndirect = false;
-
-public:
-  SMECallAttrs(SMEAttrs Caller, SMEAttrs Callee,
-               SMEAttrs Callsite = SMEAttrs::Normal)
-      : CallerFn(Caller), CalledFn(Callee), Callsite(Callsite) {}
-
-  SMECallAttrs(const CallBase &CB);
-
-  SMEAttrs &caller() { return CallerFn; }
-  SMEAttrs &callee() { return IsIndirect ? Callsite : CalledFn; }
-  SMEAttrs &callsite() { return Callsite; }
-  SMEAttrs const &caller() const { return CallerFn; }
-  SMEAttrs const &callee() const {
-    return const_cast<SMECallAttrs *>(this)->callee();
-  }
-  SMEAttrs const &callsite() const { return Callsite; }
-
-  /// \return true if a call from Caller -> Callee requires a change in
-  /// streaming mode.
-  bool requiresSMChange() const;
-
-  bool requiresLazySave() const {
-    return caller().hasZAState() && callee().hasPrivateZAInterface() &&
-           !callee().isSMEABIRoutine();
+  bool requiresPreservingZT0(const SMEAttrs &Callee) const {
+    return hasZT0State() && !Callee.isUndefZT0() && !Callee.sharesZT0() &&
+           !Callee.hasAgnosticZAInterface();
   }
-
-  bool requiresPreservingZT0() const {
-    return caller().hasZT0State() && !callsite().hasUndefZT0() &&
-           !callee().sharesZT0() && !callee().hasAgnosticZAInterface();
+  bool requiresDisablingZABeforeCall(const SMEAttrs &Callee) const {
+    return hasZT0State() && !hasZAState() && Callee.hasPrivateZAInterface() &&
+           !(Callee.Bitmask & SME_ABI_Routine);
   }
-
-  bool requiresDisablingZABeforeCall() const {
-    return caller().hasZT0State() && !caller().hasZAState() &&
-           callee().hasPrivateZAInterface() && !callee().isSMEABIRoutine();
+  bool requiresEnablingZAAfterCall(const SMEAttrs &Callee) const {
+    return requiresLazySave(Callee) || requiresDisablingZABeforeCall(Callee);
   }
-
-  bool requiresEnablingZAAfterCall() const {
-    return requiresLazySave() || requiresDisablingZABeforeCall();
-  }
-
-  bool requiresPreservingAllZAState() const {
-    return caller().hasAgnosticZAInterface() &&
-           !callee().hasAgnosticZAInterface() && !callee().isSMEABIRoutine();
+  bool requiresPreservingAllZAState(const SMEAttrs &Callee) const {
+    return hasAgnosticZAInterface() && !Callee.hasAgnosticZAInterface() &&
+           !(Callee.Bitmask & SME_ABI_Routine);
   }
 };
 

diff --git a/llvm/test/CodeGen/AArch64/sme-peephole-opts.ll b/llvm/test/CodeGen/AArch64/sme-peephole-opts.ll
index 130a316bcc2ba..6ea2267cd22e6 100644
--- a/llvm/test/CodeGen/AArch64/sme-peephole-opts.ll
+++ b/llvm/test/CodeGen/AArch64/sme-peephole-opts.ll
@@ -2,12 +2,11 @@
 ; RUN: llc -mtriple=aarch64-linux-gnu -aarch64-streaming-hazard-size=0 -mattr=+sve,+sme2 < %s | FileCheck %s
 
 declare void @callee()
-declare void @callee_sm() "aarch64_pstate_sm_enabled"
 declare void @callee_farg(float)
 declare float @callee_farg_fret(float)
 
 ; normal caller -> streaming callees
-define void @test0(ptr %callee) nounwind {
+define void @test0() nounwind {
 ; CHECK-LABEL: test0:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    stp d15, d14, [sp, #-80]! // 16-byte Folded Spill
@@ -17,8 +16,8 @@ define void @test0(ptr %callee) nounwind {
 ; CHECK-NEXT:    stp d9, d8, [sp, #48] // 16-byte Folded Spill
 ; CHECK-NEXT:    stp x30, x9, [sp, #64] // 16-byte Folded Spill
 ; CHECK-NEXT:    smstart sm
-; CHECK-NEXT:    bl callee_sm
-; CHECK-NEXT:    bl callee_sm
+; CHECK-NEXT:    bl callee
+; CHECK-NEXT:    bl callee
 ; CHECK-NEXT:    smstop sm
 ; CHECK-NEXT:    ldp d9, d8, [sp, #48] // 16-byte Folded Reload
 ; CHECK-NEXT:    ldr x30, [sp, #64] // 8-byte Folded Reload
@@ -26,8 +25,8 @@ define void @test0(ptr %callee) nounwind {
 ; CHECK-NEXT:    ldp d13, d12, [sp, #16] // 16-byte Folded Reload
 ; CHECK-NEXT:    ldp d15, d14, [sp], #80 // 16-byte Folded Reload
 ; CHECK-NEXT:    ret
-  call void @callee_sm()
-  call void @callee_sm()
+  call void @callee() "aarch64_pstate_sm_enabled"
+  call void @callee() "aarch64_pstate_sm_enabled"
   ret void
 }
 
@@ -119,7 +118,7 @@ define void @test3() nounwind "aarch64_pstate_sm_compatible" {
 ; CHECK-NEXT:  // %bb.1:
 ; CHECK-NEXT:    smstart sm
 ; CHECK-NEXT:  .LBB3_2:
-; CHECK-NEXT:    bl callee_sm
+; CHECK-NEXT:    bl callee
 ; CHECK-NEXT:    tbnz w19, #0, .LBB3_4
 ; CHECK-NEXT:  // %bb.3:
 ; CHECK-NEXT:    smstop sm
@@ -141,7 +140,7 @@ define void @test3() nounwind "aarch64_pstate_sm_compatible" {
 ; CHECK-NEXT:  // %bb.9:
 ; CHECK-NEXT:    smstart sm
 ; CHECK-NEXT:  .LBB3_10:
-; CHECK-NEXT:    bl callee_sm
+; CHECK-NEXT:    bl callee
 ; CHECK-NEXT:    tbnz w19, #0, .LBB3_12
 ; CHECK-NEXT:  // %bb.11:
 ; CHECK-NEXT:    smstop sm
@@ -153,9 +152,9 @@ define void @test3() nounwind "aarch64_pstate_sm_compatible" {
 ; CHECK-NEXT:    ldp d13, d12, [sp, #16] // 16-byte Folded Reload
 ; CHECK-NEXT:    ldp d15, d14, [sp], #96 // 16-byte Folded Reload
 ; CHECK-NEXT:    ret
-  call void @callee_sm()
+  call void @callee() "aarch64_pstate_sm_enabled"
   call void @callee()
-  call void @callee_sm()
+  call void @callee() "aarch64_pstate_sm_enabled"
   ret void
 }
 
@@ -343,7 +342,7 @@ define void @test10() "aarch64_pstate_sm_body" {
 ; CHECK-NEXT:    bl callee
 ; CHECK-NEXT:    smstart sm
 ; CHECK-NEXT:    .cfi_restore vg
-; CHECK-NEXT:    bl callee_sm
+; CHECK-NEXT:    bl callee
 ; CHECK-NEXT:    .cfi_offset vg, -24
 ; CHECK-NEXT:    smstop sm
 ; CHECK-NEXT:    bl callee
@@ -364,7 +363,7 @@ define void @test10() "aarch64_pstate_sm_body" {
 ; CHECK-NEXT:    .cfi_restore b15
 ; CHECK-NEXT:    ret
   call void @callee()
-  call void @callee_sm()
+  call void @callee() "aarch64_pstate_sm_enabled"
   call void @callee()
   ret void
 }

diff --git a/llvm/test/CodeGen/AArch64/sme-vg-to-stack.ll b/llvm/test/CodeGen/AArch64/sme-vg-to-stack.ll
index 0853325e449af..17d689d2c9eb5 100644
--- a/llvm/test/CodeGen/AArch64/sme-vg-to-stack.ll
+++ b/llvm/test/CodeGen/AArch64/sme-vg-to-stack.ll
@@ -1098,11 +1098,11 @@ define void @test_rdsvl_right_after_prologue(i64 %x0) nounwind {
 ; NO-SVE-CHECK-NEXT: ret
   %some_alloc = alloca i64, align 8
   %rdsvl = tail call i64 @llvm.aarch64.sme.cntsd()
-  call void @bar(i64 %rdsvl, i64 %x0)
+  call void @bar(i64 %rdsvl, i64 %x0) "aarch64_pstate_sm_enabled"
   ret void
 }
 
-declare void @bar(i64, i64) "aarch64_pstate_sm_enabled"
+declare void @bar(i64, i64)
 
 ; Ensure we still emit async unwind information with -fno-asynchronous-unwind-tables
 ; if the function contains a streaming-mode change.

diff --git a/llvm/test/CodeGen/AArch64/sme-zt0-state.ll b/llvm/test/CodeGen/AArch64/sme-zt0-state.ll
index 63577e4d217a8..7361e850d713e 100644
--- a/llvm/test/CodeGen/AArch64/sme-zt0-state.ll
+++ b/llvm/test/CodeGen/AArch64/sme-zt0-state.ll
@@ -1,13 +1,15 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
 ; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sme2 -start-after=simplifycfg -enable-tail-merge=false -verify-machineinstrs < %s | FileCheck %s
 
+declare void @callee();
+
 ;
 ; Private-ZA Callee
 ;
 
 ; Expect spill & fill of ZT0 around call
 ; Expect smstop/smstart za around call
-define void @zt0_in_caller_no_state_callee(ptr %callee) "aarch64_in_zt0" nounwind {
+define void @zt0_in_caller_no_state_callee() "aarch64_in_zt0" nounwind {
 ; CHECK-LABEL: zt0_in_caller_no_state_callee:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    sub sp, sp, #80
@@ -15,20 +17,20 @@ define void @zt0_in_caller_no_state_callee(ptr %callee) "aarch64_in_zt0" nounwin
 ; CHECK-NEXT:    mov x19, sp
 ; CHECK-NEXT:    str zt0, [x19]
 ; CHECK-NEXT:    smstop za
-; CHECK-NEXT:    blr x0
+; CHECK-NEXT:    bl callee
 ; CHECK-NEXT:    smstart za
 ; CHECK-NEXT:    ldr zt0, [x19]
 ; CHECK-NEXT:    ldp x30, x19, [sp, #64] // 16-byte Folded Reload
 ; CHECK-NEXT:    add sp, sp, #80
 ; CHECK-NEXT:    ret
-  call void %callee();
+  call void @callee();
   ret void;
 }
 
 ; Expect spill & fill of ZT0 around call
 ; Expect setup and restore lazy-save around call
 ; Expect smstart za after call
-define void @za_zt0_shared_caller_no_state_callee(ptr %callee) "aarch64_inout_za" "aarch64_in_zt0" nounwind {
+define void @za_zt0_shared_caller_no_state_callee() "aarch64_inout_za" "aarch64_in_zt0" nounwind {
 ; CHECK-LABEL: za_zt0_shared_caller_no_state_callee:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    stp x29, x30, [sp, #-32]! // 16-byte Folded Spill
@@ -47,7 +49,7 @@ define void @za_zt0_shared_caller_no_state_callee(ptr %callee) "aarch64_inout_za
 ; CHECK-NEXT:    sturh w8, [x29, #-8]
 ; CHECK-NEXT:    msr TPIDR2_EL0, x9
 ; CHECK-NEXT:    str zt0, [x19]
-; CHECK-NEXT:    blr x0
+; CHECK-NEXT:    bl callee
 ; CHECK-NEXT:    smstart za
 ; CHECK-NEXT:    ldr zt0, [x19]
 ; CHECK-NEXT:    mrs x8, TPIDR2_EL0
@@ -61,7 +63,7 @@ define void @za_zt0_shared_caller_no_state_callee(ptr %callee) "aarch64_inout_za
 ; CHECK-NEXT:    ldr x19, [sp, #16] // 8-byte Folded Reload
 ; CHECK-NEXT:    ldp x29, x30, [sp], #32 // 16-byte Folded Reload
 ; CHECK-NEXT:    ret
-  call void %callee();
+  call void @callee();
   ret void;
 }
 
@@ -70,43 +72,43 @@ define void @za_zt0_shared_caller_no_state_callee(ptr %callee) "aarch64_inout_za
 ;
 
 ; Caller and callee have shared ZT0 state, no spill/fill of ZT0 required
-define void @zt0_shared_caller_zt0_shared_callee(ptr %callee) "aarch64_in_zt0" nounwind {
+define void @zt0_shared_caller_zt0_shared_callee() "aarch64_in_zt0" nounwind {
 ; CHECK-LABEL: zt0_shared_caller_zt0_shared_callee:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
-; CHECK-NEXT:    blr x0
+; CHECK-NEXT:    bl callee
 ; CHECK-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
 ; CHECK-NEXT:    ret
-  call void %callee() "aarch64_in_zt0";
+  call void @callee() "aarch64_in_zt0";
   ret void;
 }
 
 ; Expect spill & fill of ZT0 around call
-define void @za_zt0_shared_caller_za_shared_callee(ptr %callee) "aarch64_inout_za" "aarch64_in_zt0" nounwind {
+define void @za_zt0_shared_caller_za_shared_callee() "aarch64_inout_za" "aarch64_in_zt0" nounwind {
 ; CHECK-LABEL: za_zt0_shared_caller_za_shared_callee:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    sub sp, sp, #80
 ; CHECK-NEXT:    stp x30, x19, [sp, #64] // 16-byte Folded Spill
 ; CHECK-NEXT:    mov x19, sp
 ; CHECK-NEXT:    str zt0, [x19]
-; CHECK-NEXT:    blr x0
+; CHECK-NEXT:    bl callee
 ; CHECK-NEXT:    ldr zt0, [x19]
 ; CHECK-NEXT:    ldp x30, x19, [sp, #64] // 16-byte Folded Reload
 ; CHECK-NEXT:    add sp, sp, #80
 ; CHECK-NEXT:    ret
-  call void %callee() "aarch64_inout_za";
+  call void @callee() "aarch64_inout_za";
   ret void;
 }
 
 ; Caller and callee have shared ZA & ZT0
-define void @za_zt0_shared_caller_za_zt0_shared_callee(ptr %callee) "aarch64_inout_za" "aarch64_in_zt0" nounwind {
+define void @za_zt0_shared_caller_za_zt0_shared_callee() "aarch64_inout_za" "aarch64_in_zt0" nounwind {
 ; CHECK-LABEL: za_zt0_shared_caller_za_zt0_shared_callee:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
-; CHECK-NEXT:    blr x0
+; CHECK-NEXT:    bl callee
 ; CHECK-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
 ; CHECK-NEXT:    ret
-  call void %callee() "aarch64_inout_za" "aarch64_in_zt0";
+  call void @callee() "aarch64_inout_za" "aarch64_in_zt0";
   ret void;
 }
 
@@ -114,7 +116,7 @@ define void @za_zt0_shared_caller_za_zt0_shared_callee(ptr %callee) "aarch64_ino
 
 ; Expect spill & fill of ZT0 around call
 ; Expect smstop/smstart za around call
-define void @zt0_in_caller_zt0_new_callee(ptr %callee) "aarch64_in_zt0" nounwind {
+define void @zt0_in_caller_zt0_new_callee() "aarch64_in_zt0" nounwind {
 ; CHECK-LABEL: zt0_in_caller_zt0_new_callee:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    sub sp, sp, #80
@@ -122,13 +124,13 @@ define void @zt0_in_caller_zt0_new_callee(ptr %callee) "aarch64_in_zt0" nounwind
 ; CHECK-NEXT:    mov x19, sp
 ; CHECK-NEXT:    str zt0, [x19]
 ; CHECK-NEXT:    smstop za
-; CHECK-NEXT:    blr x0
+; CHECK-NEXT:    bl callee
 ; CHECK-NEXT:    smstart za
 ; CHECK-NEXT:    ldr zt0, [x19]
 ; CHECK-NEXT:    ldp x30, x19, [sp, #64] // 16-byte Folded Reload
 ; CHECK-NEXT:    add sp, sp, #80
 ; CHECK-NEXT:    ret
-  call void %callee() "aarch64_new_zt0";
+  call void @callee() "aarch64_new_zt0";
   ret void;
 }
 
@@ -138,7 +140,7 @@ define void @zt0_in_caller_zt0_new_callee(ptr %callee) "aarch64_in_zt0" nounwind
 ; Expect smstart ZA & clear ZT0
 ; Expect spill & fill of ZT0 around call
 ; Before return, expect smstop ZA
-define void @zt0_new_caller_zt0_new_callee(ptr %callee) "aarch64_new_zt0" nounwind {
+define void @zt0_new_caller_zt0_new_callee() "aarch64_new_zt0" nounwind {
 ; CHECK-LABEL: zt0_new_caller_zt0_new_callee:
 ; CHECK:       // %bb.0: // %prelude
 ; CHECK-NEXT:    sub sp, sp, #80
@@ -154,14 +156,14 @@ define void @zt0_new_caller_zt0_new_callee(ptr %callee) "aarch64_new_zt0" nounwi
 ; CHECK-NEXT:    mov x19, sp
 ; CHECK-NEXT:    str zt0, [x19]
 ; CHECK-NEXT:    smstop za
-; CHECK-NEXT:    blr x0
+; CHECK-NEXT:    bl callee
 ; CHECK-NEXT:    smstart za
 ; CHECK-NEXT:    ldr zt0, [x19]
 ; CHECK-NEXT:    smstop za
 ; CHECK-NEXT:    ldp x30, x19, [sp, #64] // 16-byte Folded Reload
 ; CHECK-NEXT:    add sp, sp, #80
 ; CHECK-NEXT:    ret
-  call void %callee() "aarch64_new_zt0";
+  call void @callee() "aarch64_new_zt0";
   ret void;
 }
 
@@ -205,7 +207,7 @@ declare {i64, i64} @__arm_sme_state()
 ; Expect commit of lazy-save if ZA is dormant
 ; Expect smstart ZA & clear ZT0
 ; Before return, expect smstop ZA
-define void @zt0_new_caller(ptr %callee) "aarch64_new_zt0" nounwind {
+define void @zt0_new_caller() "aarch64_new_zt0" nounwind {
 ; CHECK-LABEL: zt0_new_caller:
 ; CHECK:       // %bb.0: // %prelude
 ; CHECK-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
@@ -217,18 +219,18 @@ define void @zt0_new_caller(ptr %callee) "aarch64_new_zt0" nounwind {
 ; CHECK-NEXT:  .LBB8_2:
 ; CHECK-NEXT:    smstart za
 ; CHECK-NEXT:    zero { zt0 }
-; CHECK-NEXT:    blr x0
+; CHECK-NEXT:    bl callee
 ; CHECK-NEXT:    smstop za
 ; CHECK-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
 ; CHECK-NEXT:    ret
-  call void %callee() "aarch64_in_zt0";
+  call void @callee() "aarch64_in_zt0";
   ret void;
 }
 
 ; Expect commit of lazy-save if ZA is dormant
 ; Expect smstart ZA, clear ZA & clear ZT0
 ; Before return, expect smstop ZA
-define void @new_za_zt0_caller(ptr %callee) "aarch64_new_za" "aarch64_new_zt0" nounwind {
+define void @new_za_zt0_caller() "aarch64_new_za" "aarch64_new_zt0" nounwind {
 ; CHECK-LABEL: new_za_zt0_caller:
 ; CHECK:       // %bb.0: // %prelude
 ; CHECK-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
@@ -241,36 +243,36 @@ define void @new_za_zt0_caller(ptr %callee) "aarch64_new_za" "aarch64_new_zt0" n
 ; CHECK-NEXT:    smstart za
 ; CHECK-NEXT:    zero {za}
 ; CHECK-NEXT:    zero { zt0 }
-; CHECK-NEXT:    blr x0
+; CHECK-NEXT:    bl callee
 ; CHECK-NEXT:    smstop za
 ; CHECK-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
 ; CHECK-NEXT:    ret
-  call void %callee() "aarch64_inout_za" "aarch64_in_zt0";
+  call void @callee() "aarch64_inout_za" "aarch64_in_zt0";
   ret void;
 }
 
 ; Expect clear ZA on entry
-define void @new_za_shared_zt0_caller(ptr %callee) "aarch64_new_za" "aarch64_in_zt0" nounwind {
+define void @new_za_shared_zt0_caller() "aarch64_new_za" "aarch64_in_zt0" nounwind {
 ; CHECK-LABEL: new_za_shared_zt0_caller:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
 ; CHECK-NEXT:    zero {za}
-; CHECK-NEXT:    blr x0
+; CHECK-NEXT:    bl callee
 ; CHECK-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
 ; CHECK-NEXT:    ret
-  call void %callee() "aarch64_inout_za" "aarch64_in_zt0";
+  call void @callee() "aarch64_inout_za" "aarch64_in_zt0";
   ret void;
 }
 
 ; Expect clear ZT0 on entry
-define void @shared_za_new_zt0(ptr %callee) "aarch64_inout_za" "aarch64_new_zt0" nounwind {
+define void @shared_za_new_zt0() "aarch64_inout_za" "aarch64_new_zt0" nounwind {
 ; CHECK-LABEL: shared_za_new_zt0:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
 ; CHECK-NEXT:    zero { zt0 }
-; CHECK-NEXT:    blr x0
+; CHECK-NEXT:    bl callee
 ; CHECK-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
 ; CHECK-NEXT:    ret
-  call void %callee() "aarch64_inout_za" "aarch64_in_zt0";
+  call void @callee() "aarch64_inout_za" "aarch64_in_zt0";
   ret void;
 }

diff --git a/llvm/unittests/Target/AArch64/SMEAttributesTest.cpp b/llvm/unittests/Target/AArch64/SMEAttributesTest.cpp
index f13252f3a4c28..f8c77fcba19cf 100644
--- a/llvm/unittests/Target/AArch64/SMEAttributesTest.cpp
+++ b/llvm/unittests/Target/AArch64/SMEAttributesTest.cpp
@@ -9,7 +9,6 @@
 
 using namespace llvm;
 using SA = SMEAttrs;
-using CA = SMECallAttrs;
 
 std::unique_ptr<Module> parseIR(const char *IR) {
   static LLVMContext C;
@@ -71,14 +70,15 @@ TEST(SMEAttributes, Constructors) {
   ASSERT_TRUE(SA(*parseIR("declare void @foo() \"aarch64_new_zt0\"")
                       ->getFunction("foo"))
                   .isNewZT0());
-
-  auto CallModule = parseIR("declare void @callee()\n"
-                            "define void @foo() {"
-                            "call void @callee() \"aarch64_zt0_undef\"\n"
-                            "ret void\n}");
-  CallBase &Call =
-      cast<CallBase>((CallModule->getFunction("foo")->begin()->front()));
-  ASSERT_TRUE(SMECallAttrs(Call).callsite().hasUndefZT0());
+  ASSERT_TRUE(
+      SA(cast<CallBase>((parseIR("declare void @callee()\n"
+                                 "define void @foo() {"
+                                 "call void @callee() \"aarch64_zt0_undef\"\n"
+                                 "ret void\n}")
+                             ->getFunction("foo")
+                             ->begin()
+                             ->front())))
+          .isUndefZT0());
 
   // Invalid combinations.
   EXPECT_DEBUG_DEATH(SA(SA::SM_Enabled | SA::SM_Compatible),
@@ -235,7 +235,7 @@ TEST(SMEAttributes, Basics) {
   ASSERT_TRUE(ZT0_Undef.hasZT0State());
   ASSERT_FALSE(ZT0_Undef.hasSharedZAInterface());
   ASSERT_TRUE(ZT0_Undef.hasPrivateZAInterface());
-  ASSERT_TRUE(ZT0_Undef.hasUndefZT0());
+  ASSERT_TRUE(ZT0_Undef.isUndefZT0());
 
   ASSERT_FALSE(SA(SA::Normal).isInZT0());
   ASSERT_FALSE(SA(SA::Normal).isOutZT0());
@@ -248,57 +248,59 @@ TEST(SMEAttributes, Basics) {
 
 TEST(SMEAttributes, Transitions) {
   // Normal -> Normal
-  ASSERT_FALSE(CA(SA::Normal, SA::Normal).requiresSMChange());
-  ASSERT_FALSE(CA(SA::Normal, SA::Normal).requiresPreservingZT0());
-  ASSERT_FALSE(CA(SA::Normal, SA::Normal).requiresDisablingZABeforeCall());
-  ASSERT_FALSE(CA(SA::Normal, SA::Normal).requiresEnablingZAAfterCall());
+  ASSERT_FALSE(SA(SA::Normal).requiresSMChange(SA(SA::Normal)));
+  ASSERT_FALSE(SA(SA::Normal).requiresPreservingZT0(SA(SA::Normal)));
+  ASSERT_FALSE(SA(SA::Normal).requiresDisablingZABeforeCall(SA(SA::Normal)));
+  ASSERT_FALSE(SA(SA::Normal).requiresEnablingZAAfterCall(SA(SA::Normal)));
   // Normal -> Normal + LocallyStreaming
-  ASSERT_FALSE(CA(SA::Normal, SA::Normal | SA::SM_Body).requiresSMChange());
+  ASSERT_FALSE(SA(SA::Normal).requiresSMChange(SA(SA::Normal | SA::SM_Body)));
 
   // Normal -> Streaming
-  ASSERT_TRUE(CA(SA::Normal, SA::SM_Enabled).requiresSMChange());
+  ASSERT_TRUE(SA(SA::Normal).requiresSMChange(SA(SA::SM_Enabled)));
   // Normal -> Streaming + LocallyStreaming
-  ASSERT_TRUE(CA(SA::Normal, SA::SM_Enabled | SA::SM_Body).requiresSMChange());
+  ASSERT_TRUE(
+      SA(SA::Normal).requiresSMChange(SA(SA::SM_Enabled | SA::SM_Body)));
 
   // Normal -> Streaming-compatible
-  ASSERT_FALSE(CA(SA::Normal, SA::SM_Compatible).requiresSMChange());
+  ASSERT_FALSE(SA(SA::Normal).requiresSMChange(SA(SA::SM_Compatible)));
   // Normal -> Streaming-compatible + LocallyStreaming
   ASSERT_FALSE(
-      CA(SA::Normal, SA::SM_Compatible | SA::SM_Body).requiresSMChange());
+      SA(SA::Normal).requiresSMChange(SA(SA::SM_Compatible | SA::SM_Body)));
 
   // Streaming -> Normal
-  ASSERT_TRUE(CA(SA::SM_Enabled, SA::Normal).requiresSMChange());
+  ASSERT_TRUE(SA(SA::SM_Enabled).requiresSMChange(SA(SA::Normal)));
   // Streaming -> Normal + LocallyStreaming
-  ASSERT_TRUE(CA(SA::SM_Enabled, SA::Normal | SA::SM_Body).requiresSMChange());
+  ASSERT_TRUE(
+      SA(SA::SM_Enabled).requiresSMChange(SA(SA::Normal | SA::SM_Body)));
 
   // Streaming -> Streaming
-  ASSERT_FALSE(CA(SA::SM_Enabled, SA::SM_Enabled).requiresSMChange());
+  ASSERT_FALSE(SA(SA::SM_Enabled).requiresSMChange(SA(SA::SM_Enabled)));
   // Streaming -> Streaming + LocallyStreaming
   ASSERT_FALSE(
-      CA(SA::SM_Enabled, SA::SM_Enabled | SA::SM_Body).requiresSMChange());
+      SA(SA::SM_Enabled).requiresSMChange(SA(SA::SM_Enabled | SA::SM_Body)));
 
   // Streaming -> Streaming-compatible
-  ASSERT_FALSE(CA(SA::SM_Enabled, SA::SM_Compatible).requiresSMChange());
+  ASSERT_FALSE(SA(SA::SM_Enabled).requiresSMChange(SA(SA::SM_Compatible)));
   // Streaming -> Streaming-compatible + LocallyStreaming
   ASSERT_FALSE(
-      CA(SA::SM_Enabled, SA::SM_Compatible | SA::SM_Body).requiresSMChange());
+      SA(SA::SM_Enabled).requiresSMChange(SA(SA::SM_Compatible | SA::SM_Body)));
 
   // Streaming-compatible -> Normal
-  ASSERT_TRUE(CA(SA::SM_Compatible, SA::Normal).requiresSMChange());
+  ASSERT_TRUE(SA(SA::SM_Compatible).requiresSMChange(SA(SA::Normal)));
   ASSERT_TRUE(
-      CA(SA::SM_Compatible, SA::Normal | SA::SM_Body).requiresSMChange());
+      SA(SA::SM_Compatible).requiresSMChange(SA(SA::Normal | SA::SM_Body)));
 
   // Streaming-compatible -> Streaming
-  ASSERT_TRUE(CA(SA::SM_Compatible, SA::SM_Enabled).requiresSMChange());
+  ASSERT_TRUE(SA(SA::SM_Compatible).requiresSMChange(SA(SA::SM_Enabled)));
   // Streaming-compatible -> Streaming + LocallyStreaming
   ASSERT_TRUE(
-      CA(SA::SM_Compatible, SA::SM_Enabled | SA::SM_Body).requiresSMChange());
+      SA(SA::SM_Compatible).requiresSMChange(SA(SA::SM_Enabled | SA::SM_Body)));
 
   // Streaming-compatible -> Streaming-compatible
-  ASSERT_FALSE(CA(SA::SM_Compatible, SA::SM_Compatible).requiresSMChange());
+  ASSERT_FALSE(SA(SA::SM_Compatible).requiresSMChange(SA(SA::SM_Compatible)));
   // Streaming-compatible -> Streaming-compatible + LocallyStreaming
-  ASSERT_FALSE(CA(SA::SM_Compatible, SA::SM_Compatible | SA::SM_Body)
-                   .requiresSMChange());
+  ASSERT_FALSE(SA(SA::SM_Compatible)
+                   .requiresSMChange(SA(SA::SM_Compatible | SA::SM_Body)));
 
   SA Private_ZA = SA(SA::Normal);
   SA ZA_Shared = SA(SA::encodeZAState(SA::StateValue::In));
@@ -308,39 +310,37 @@ TEST(SMEAttributes, Transitions) {
   SA Undef_ZT0 = SA(SA::ZT0_Undef);
 
   // Shared ZA -> Private ZA Interface
-  ASSERT_FALSE(CA(ZA_Shared, Private_ZA).requiresDisablingZABeforeCall());
-  ASSERT_TRUE(CA(ZA_Shared, Private_ZA).requiresEnablingZAAfterCall());
+  ASSERT_FALSE(ZA_Shared.requiresDisablingZABeforeCall(Private_ZA));
+  ASSERT_TRUE(ZA_Shared.requiresEnablingZAAfterCall(Private_ZA));
 
   // Shared ZT0 -> Private ZA Interface
-  ASSERT_TRUE(CA(ZT0_Shared, Private_ZA).requiresDisablingZABeforeCall());
-  ASSERT_TRUE(CA(ZT0_Shared, Private_ZA).requiresPreservingZT0());
-  ASSERT_TRUE(CA(ZT0_Shared, Private_ZA).requiresEnablingZAAfterCall());
+  ASSERT_TRUE(ZT0_Shared.requiresDisablingZABeforeCall(Private_ZA));
+  ASSERT_TRUE(ZT0_Shared.requiresPreservingZT0(Private_ZA));
+  ASSERT_TRUE(ZT0_Shared.requiresEnablingZAAfterCall(Private_ZA));
 
   // Shared Undef ZT0 -> Private ZA Interface
   // Note: "Undef ZT0" is a callsite attribute that means ZT0 is undefined at
   // point the of the call.
-  ASSERT_TRUE(
-      CA(ZT0_Shared, Private_ZA, Undef_ZT0).requiresDisablingZABeforeCall());
-  ASSERT_FALSE(CA(ZT0_Shared, Private_ZA, Undef_ZT0).requiresPreservingZT0());
-  ASSERT_TRUE(
-      CA(ZT0_Shared, Private_ZA, Undef_ZT0).requiresEnablingZAAfterCall());
+  ASSERT_TRUE(ZT0_Shared.requiresDisablingZABeforeCall(Undef_ZT0));
+  ASSERT_FALSE(ZT0_Shared.requiresPreservingZT0(Undef_ZT0));
+  ASSERT_TRUE(ZT0_Shared.requiresEnablingZAAfterCall(Undef_ZT0));
 
   // Shared ZA & ZT0 -> Private ZA Interface
-  ASSERT_FALSE(CA(ZA_ZT0_Shared, Private_ZA).requiresDisablingZABeforeCall());
-  ASSERT_TRUE(CA(ZA_ZT0_Shared, Private_ZA).requiresPreservingZT0());
-  ASSERT_TRUE(CA(ZA_ZT0_Shared, Private_ZA).requiresEnablingZAAfterCall());
+  ASSERT_FALSE(ZA_ZT0_Shared.requiresDisablingZABeforeCall(Private_ZA));
+  ASSERT_TRUE(ZA_ZT0_Shared.requiresPreservingZT0(Private_ZA));
+  ASSERT_TRUE(ZA_ZT0_Shared.requiresEnablingZAAfterCall(Private_ZA));
 
   // Shared ZA -> Shared ZA Interface
-  ASSERT_FALSE(CA(ZA_Shared, ZT0_Shared).requiresDisablingZABeforeCall());
-  ASSERT_FALSE(CA(ZA_Shared, ZT0_Shared).requiresEnablingZAAfterCall());
+  ASSERT_FALSE(ZA_Shared.requiresDisablingZABeforeCall(ZT0_Shared));
+  ASSERT_FALSE(ZA_Shared.requiresEnablingZAAfterCall(ZT0_Shared));
 
   // Shared ZT0 -> Shared ZA Interface
-  ASSERT_FALSE(CA(ZT0_Shared, ZT0_Shared).requiresDisablingZABeforeCall());
-  ASSERT_FALSE(CA(ZT0_Shared, ZT0_Shared).requiresPreservingZT0());
-  ASSERT_FALSE(CA(ZT0_Shared, ZT0_Shared).requiresEnablingZAAfterCall());
+  ASSERT_FALSE(ZT0_Shared.requiresDisablingZABeforeCall(ZT0_Shared));
+  ASSERT_FALSE(ZT0_Shared.requiresPreservingZT0(ZT0_Shared));
+  ASSERT_FALSE(ZT0_Shared.requiresEnablingZAAfterCall(ZT0_Shared));
 
   // Shared ZA & ZT0 -> Shared ZA Interface
-  ASSERT_FALSE(CA(ZA_ZT0_Shared, ZT0_Shared).requiresDisablingZABeforeCall());
-  ASSERT_FALSE(CA(ZA_ZT0_Shared, ZT0_Shared).requiresPreservingZT0());
-  ASSERT_FALSE(CA(ZA_ZT0_Shared, ZT0_Shared).requiresEnablingZAAfterCall());
+  ASSERT_FALSE(ZA_ZT0_Shared.requiresDisablingZABeforeCall(ZT0_Shared));
+  ASSERT_FALSE(ZA_ZT0_Shared.requiresPreservingZT0(ZT0_Shared));
+  ASSERT_FALSE(ZA_ZT0_Shared.requiresEnablingZAAfterCall(ZT0_Shared));
 }

