[llvm] [AArch64][SME] Split SMECallAttrs out of SMEAttrs (NFC) (PR #137239)

Benjamin Maxwell via llvm-commits llvm-commits at lists.llvm.org
Fri Apr 25 05:57:28 PDT 2025


https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/137239

>From 0b9495b81e7373a661e0b0bcc1abcaef56f7c033 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Fri, 25 Apr 2025 07:58:43 +0000
Subject: [PATCH] [AArch64][SME] Split SMECallAttrs out of SMEAttrs (NFC)

SMECallAttrs is a new helper class that holds all the SMEAttrs for a
call. The interfaces to query actions needed for the call (e.g. change
streaming mode) have been moved to the SMECallAttrs class.

The main motivation for this change is to make the split between caller,
callee, and callsite attributes more apparent. Places that previously
implicitly checked callsite attributes have been updated to make these
checks explicit. Similarly, places known to only check callee or
callsite attributes have also been updated to make this clear.
---
 .../Target/AArch64/AArch64ISelLowering.cpp    |  74 ++++++------
 .../AArch64/AArch64TargetTransformInfo.cpp    |  25 +++--
 .../AArch64/Utils/AArch64SMEAttributes.cpp    |  58 +++++-----
 .../AArch64/Utils/AArch64SMEAttributes.h      | 104 ++++++++++++-----
 .../Target/AArch64/SMEAttributesTest.cpp      | 106 +++++++++---------
 5 files changed, 206 insertions(+), 161 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 713f814121aa3..1b0026d5409a4 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -8653,6 +8653,16 @@ static void analyzeCallOperands(const AArch64TargetLowering &TLI,
   }
 }
 
+static SMECallAttrs
+getSMECallAttrs(const Function &Function,
+                const TargetLowering::CallLoweringInfo &CLI) {
+  if (CLI.CB)
+    return SMECallAttrs(*CLI.CB);
+  if (auto *ES = dyn_cast<ExternalSymbolSDNode>(CLI.Callee))
+    return SMECallAttrs(SMEAttrs(Function), SMEAttrs(ES->getSymbol()));
+  return SMECallAttrs(SMEAttrs(Function), SMEAttrs(SMEAttrs::Normal));
+}
+
 bool AArch64TargetLowering::isEligibleForTailCallOptimization(
     const CallLoweringInfo &CLI) const {
   CallingConv::ID CalleeCC = CLI.CallConv;
@@ -8671,12 +8681,10 @@ bool AArch64TargetLowering::isEligibleForTailCallOptimization(
 
   // SME Streaming functions are not eligible for TCO as they may require
   // the streaming mode or ZA to be restored after returning from the call.
-  SMEAttrs CallerAttrs(MF.getFunction());
-  auto CalleeAttrs = CLI.CB ? SMEAttrs(*CLI.CB) : SMEAttrs(SMEAttrs::Normal);
-  if (CallerAttrs.requiresSMChange(CalleeAttrs) ||
-      CallerAttrs.requiresLazySave(CalleeAttrs) ||
-      CallerAttrs.requiresPreservingAllZAState(CalleeAttrs) ||
-      CallerAttrs.hasStreamingBody())
+  SMECallAttrs CallAttrs = getSMECallAttrs(CallerF, CLI);
+  if (CallAttrs.requiresSMChange() || CallAttrs.requiresLazySave() ||
+      CallAttrs.requiresPreservingAllZAState() ||
+      CallAttrs.caller().hasStreamingBody())
     return false;
 
   // Functions using the C or Fast calling convention that have an SVE signature
@@ -8968,14 +8976,13 @@ static SDValue emitSMEStateSaveRestore(const AArch64TargetLowering &TLI,
   return TLI.LowerCallTo(CLI).second;
 }
 
-static unsigned getSMCondition(const SMEAttrs &CallerAttrs,
-                               const SMEAttrs &CalleeAttrs) {
-  if (!CallerAttrs.hasStreamingCompatibleInterface() ||
-      CallerAttrs.hasStreamingBody())
+static unsigned getSMCondition(const SMECallAttrs &CallAttrs) {
+  if (!CallAttrs.caller().hasStreamingCompatibleInterface() ||
+      CallAttrs.caller().hasStreamingBody())
     return AArch64SME::Always;
-  if (CalleeAttrs.hasNonStreamingInterface())
+  if (CallAttrs.calleeOrCallsite().hasNonStreamingInterface())
     return AArch64SME::IfCallerIsStreaming;
-  if (CalleeAttrs.hasStreamingInterface())
+  if (CallAttrs.calleeOrCallsite().hasStreamingInterface())
     return AArch64SME::IfCallerIsNonStreaming;
 
   llvm_unreachable("Unsupported attributes");
@@ -9108,11 +9115,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
   }
 
   // Determine whether we need any streaming mode changes.
-  SMEAttrs CalleeAttrs, CallerAttrs(MF.getFunction());
-  if (CLI.CB)
-    CalleeAttrs = SMEAttrs(*CLI.CB);
-  else if (auto *ES = dyn_cast<ExternalSymbolSDNode>(CLI.Callee))
-    CalleeAttrs = SMEAttrs(ES->getSymbol());
+  SMECallAttrs CallAttrs = getSMECallAttrs(MF.getFunction(), CLI);
 
   auto DescribeCallsite =
       [&](OptimizationRemarkAnalysis &R) -> OptimizationRemarkAnalysis & {
@@ -9127,9 +9130,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
     return R;
   };
 
-  bool RequiresLazySave = CallerAttrs.requiresLazySave(CalleeAttrs);
-  bool RequiresSaveAllZA =
-      CallerAttrs.requiresPreservingAllZAState(CalleeAttrs);
+  bool RequiresLazySave = CallAttrs.requiresLazySave();
+  bool RequiresSaveAllZA = CallAttrs.requiresPreservingAllZAState();
   if (RequiresLazySave) {
     const TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
     MachinePointerInfo MPI =
@@ -9157,18 +9159,18 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
       return DescribeCallsite(R) << " sets up a lazy save for ZA";
     });
   } else if (RequiresSaveAllZA) {
-    assert(!CalleeAttrs.hasSharedZAInterface() &&
+    assert(!CallAttrs.calleeOrCallsite().hasSharedZAInterface() &&
            "Cannot share state that may not exist");
     Chain = emitSMEStateSaveRestore(*this, DAG, FuncInfo, DL, Chain,
                                     /*IsSave=*/true);
   }
 
   SDValue PStateSM;
-  bool RequiresSMChange = CallerAttrs.requiresSMChange(CalleeAttrs);
+  bool RequiresSMChange = CallAttrs.requiresSMChange();
   if (RequiresSMChange) {
-    if (CallerAttrs.hasStreamingInterfaceOrBody())
+    if (CallAttrs.caller().hasStreamingInterfaceOrBody())
       PStateSM = DAG.getConstant(1, DL, MVT::i64);
-    else if (CallerAttrs.hasNonStreamingInterface())
+    else if (CallAttrs.caller().hasNonStreamingInterface())
       PStateSM = DAG.getConstant(0, DL, MVT::i64);
     else
       PStateSM = getRuntimePStateSM(DAG, Chain, DL, MVT::i64);
@@ -9185,7 +9187,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
 
   SDValue ZTFrameIdx;
   MachineFrameInfo &MFI = MF.getFrameInfo();
-  bool ShouldPreserveZT0 = CallerAttrs.requiresPreservingZT0(CalleeAttrs);
+  bool ShouldPreserveZT0 = CallAttrs.requiresPreservingZT0();
 
   // If the caller has ZT0 state which will not be preserved by the callee,
   // spill ZT0 before the call.
@@ -9201,7 +9203,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
 
   // If caller shares ZT0 but the callee is not shared ZA, we need to stop
   // PSTATE.ZA before the call if there is no lazy-save active.
-  bool DisableZA = CallerAttrs.requiresDisablingZABeforeCall(CalleeAttrs);
+  bool DisableZA = CallAttrs.requiresDisablingZABeforeCall();
   assert((!DisableZA || !RequiresLazySave) &&
          "Lazy-save should have PSTATE.SM=1 on entry to the function");
 
@@ -9484,8 +9486,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
     }
 
     SDValue NewChain = changeStreamingMode(
-        DAG, DL, CalleeAttrs.hasStreamingInterface(), Chain, InGlue,
-        getSMCondition(CallerAttrs, CalleeAttrs), PStateSM);
+        DAG, DL, CallAttrs.calleeOrCallsite().hasStreamingInterface(), Chain,
+        InGlue, getSMCondition(CallAttrs), PStateSM);
     Chain = NewChain.getValue(0);
     InGlue = NewChain.getValue(1);
   }
@@ -9664,8 +9666,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
   if (RequiresSMChange) {
     assert(PStateSM && "Expected a PStateSM to be set");
     Result = changeStreamingMode(
-        DAG, DL, !CalleeAttrs.hasStreamingInterface(), Result, InGlue,
-        getSMCondition(CallerAttrs, CalleeAttrs), PStateSM);
+        DAG, DL, !CallAttrs.calleeOrCallsite().hasStreamingInterface(), Result,
+        InGlue, getSMCondition(CallAttrs), PStateSM);
 
     if (!Subtarget->isTargetDarwin() || Subtarget->hasSVE()) {
       InGlue = Result.getValue(1);
@@ -9675,7 +9677,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
     }
   }
 
-  if (CallerAttrs.requiresEnablingZAAfterCall(CalleeAttrs))
+  if (CallAttrs.requiresEnablingZAAfterCall())
     // Unconditionally resume ZA.
     Result = DAG.getNode(
         AArch64ISD::SMSTART, DL, MVT::Other, Result,
@@ -28552,12 +28554,10 @@ bool AArch64TargetLowering::fallBackToDAGISel(const Instruction &Inst) const {
 
   // Checks to allow the use of SME instructions
   if (auto *Base = dyn_cast<CallBase>(&Inst)) {
-    auto CallerAttrs = SMEAttrs(*Inst.getFunction());
-    auto CalleeAttrs = SMEAttrs(*Base);
-    if (CallerAttrs.requiresSMChange(CalleeAttrs) ||
-        CallerAttrs.requiresLazySave(CalleeAttrs) ||
-        CallerAttrs.requiresPreservingZT0(CalleeAttrs) ||
-        CallerAttrs.requiresPreservingAllZAState(CalleeAttrs))
+    auto CallAttrs = SMECallAttrs(*Base);
+    if (CallAttrs.requiresSMChange() || CallAttrs.requiresLazySave() ||
+        CallAttrs.requiresPreservingZT0() ||
+        CallAttrs.requiresPreservingAllZAState())
       return true;
   }
   return false;
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index fcc5eb1c05ba0..287dcb367733c 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -268,22 +268,21 @@ const FeatureBitset AArch64TTIImpl::InlineInverseFeatures = {
 
 bool AArch64TTIImpl::areInlineCompatible(const Function *Caller,
                                          const Function *Callee) const {
-  SMEAttrs CallerAttrs(*Caller), CalleeAttrs(*Callee);
+  SMECallAttrs CallAttrs(*Caller, *Callee);
 
   // When inlining, we should consider the body of the function, not the
   // interface.
-  if (CalleeAttrs.hasStreamingBody()) {
-    CalleeAttrs.set(SMEAttrs::SM_Compatible, false);
-    CalleeAttrs.set(SMEAttrs::SM_Enabled, true);
+  if (CallAttrs.callee().hasStreamingBody()) {
+    CallAttrs.callee().set(SMEAttrs::SM_Compatible, false);
+    CallAttrs.callee().set(SMEAttrs::SM_Enabled, true);
   }
 
-  if (CalleeAttrs.isNewZA() || CalleeAttrs.isNewZT0())
+  if (CallAttrs.callee().isNewZA() || CallAttrs.callee().isNewZT0())
     return false;
 
-  if (CallerAttrs.requiresLazySave(CalleeAttrs) ||
-      CallerAttrs.requiresSMChange(CalleeAttrs) ||
-      CallerAttrs.requiresPreservingZT0(CalleeAttrs) ||
-      CallerAttrs.requiresPreservingAllZAState(CalleeAttrs)) {
+  if (CallAttrs.requiresLazySave() || CallAttrs.requiresSMChange() ||
+      CallAttrs.requiresPreservingZT0() ||
+      CallAttrs.requiresPreservingAllZAState()) {
     if (hasPossibleIncompatibleOps(Callee))
       return false;
   }
@@ -349,12 +348,14 @@ AArch64TTIImpl::getInlineCallPenalty(const Function *F, const CallBase &Call,
   // streaming-mode change, and the call to G from F would also require a
   // streaming-mode change, then there is benefit to do the streaming-mode
   // change only once and avoid inlining of G into F.
+
   SMEAttrs FAttrs(*F);
-  SMEAttrs CalleeAttrs(Call);
-  if (FAttrs.requiresSMChange(CalleeAttrs)) {
+  SMECallAttrs CallAttrs(Call);
+
+  if (SMECallAttrs(FAttrs, CallAttrs.calleeOrCallsite()).requiresSMChange()) {
     if (F == Call.getCaller()) // (1)
       return CallPenaltyChangeSM * DefaultCallPenalty;
-    if (FAttrs.requiresSMChange(SMEAttrs(*Call.getCaller()))) // (2)
+    if (SMECallAttrs(FAttrs, CallAttrs.caller()).requiresSMChange()) // (2)
       return InlineCallPenaltyChangeSM * DefaultCallPenalty;
   }
 
diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
index 76d2ac6a601e5..1085d618116eb 100644
--- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
+++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
@@ -27,15 +27,14 @@ void SMEAttrs::set(unsigned M, bool Enable) {
          "ZA_New and SME_ABI_Routine are mutually exclusive");
 
   assert(
-      (!sharesZA() ||
-       (isNewZA() ^ isInZA() ^ isInOutZA() ^ isOutZA() ^ isPreservesZA())) &&
+      (isNewZA() + isInZA() + isOutZA() + isInOutZA() + isPreservesZA()) <= 1 &&
       "Attributes 'aarch64_new_za', 'aarch64_in_za', 'aarch64_out_za', "
       "'aarch64_inout_za' and 'aarch64_preserves_za' are mutually exclusive");
 
   // ZT0 Attrs
   assert(
-      (!sharesZT0() || (isNewZT0() ^ isInZT0() ^ isInOutZT0() ^ isOutZT0() ^
-                        isPreservesZT0())) &&
+      (isNewZT0() + isInZT0() + isOutZT0() + isInOutZT0() + isPreservesZT0()) <=
+          1 &&
       "Attributes 'aarch64_new_zt0', 'aarch64_in_zt0', 'aarch64_out_zt0', "
       "'aarch64_inout_zt0' and 'aarch64_preserves_zt0' are mutually exclusive");
 
@@ -44,27 +43,6 @@ void SMEAttrs::set(unsigned M, bool Enable) {
          "interface");
 }
 
-SMEAttrs::SMEAttrs(const CallBase &CB) {
-  *this = SMEAttrs(CB.getAttributes());
-  if (auto *F = CB.getCalledFunction()) {
-    set(SMEAttrs(*F).Bitmask | SMEAttrs(F->getName()).Bitmask);
-  }
-}
-
-SMEAttrs::SMEAttrs(StringRef FuncName) : Bitmask(0) {
-  if (FuncName == "__arm_tpidr2_save" || FuncName == "__arm_sme_state")
-    Bitmask |= (SMEAttrs::SM_Compatible | SMEAttrs::SME_ABI_Routine);
-  if (FuncName == "__arm_tpidr2_restore")
-    Bitmask |= SMEAttrs::SM_Compatible | encodeZAState(StateValue::In) |
-               SMEAttrs::SME_ABI_Routine;
-  if (FuncName == "__arm_sc_memcpy" || FuncName == "__arm_sc_memset" ||
-      FuncName == "__arm_sc_memmove" || FuncName == "__arm_sc_memchr")
-    Bitmask |= SMEAttrs::SM_Compatible;
-  if (FuncName == "__arm_sme_save" || FuncName == "__arm_sme_restore" ||
-      FuncName == "__arm_sme_state_size")
-    Bitmask |= SMEAttrs::SM_Compatible | SMEAttrs::SME_ABI_Routine;
-}
-
 SMEAttrs::SMEAttrs(const AttributeList &Attrs) {
   Bitmask = 0;
   if (Attrs.hasFnAttr("aarch64_pstate_sm_enabled"))
@@ -99,17 +77,39 @@ SMEAttrs::SMEAttrs(const AttributeList &Attrs) {
     Bitmask |= encodeZT0State(StateValue::New);
 }
 
-bool SMEAttrs::requiresSMChange(const SMEAttrs &Callee) const {
-  if (Callee.hasStreamingCompatibleInterface())
+void SMEAttrs::addKnownFunctionAttrs(StringRef FuncName) {
+  unsigned KnownAttrs = SMEAttrs::Normal;
+  if (FuncName == "__arm_tpidr2_save" || FuncName == "__arm_sme_state")
+    KnownAttrs |= (SMEAttrs::SM_Compatible | SMEAttrs::SME_ABI_Routine);
+  if (FuncName == "__arm_tpidr2_restore")
+    KnownAttrs |= SMEAttrs::SM_Compatible | encodeZAState(StateValue::In) |
+                  SMEAttrs::SME_ABI_Routine;
+  if (FuncName == "__arm_sc_memcpy" || FuncName == "__arm_sc_memset" ||
+      FuncName == "__arm_sc_memmove" || FuncName == "__arm_sc_memchr")
+    KnownAttrs |= SMEAttrs::SM_Compatible;
+  if (FuncName == "__arm_sme_save" || FuncName == "__arm_sme_restore" ||
+      FuncName == "__arm_sme_state_size")
+    KnownAttrs |= SMEAttrs::SM_Compatible | SMEAttrs::SME_ABI_Routine;
+  set(KnownAttrs, /*Enable=*/true);
+}
+
+bool SMECallAttrs::requiresSMChange() const {
+  if ((Callsite | Callee).hasStreamingCompatibleInterface())
     return false;
 
   // Both non-streaming
-  if (hasNonStreamingInterfaceAndBody() && Callee.hasNonStreamingInterface())
+  if (Caller.hasNonStreamingInterfaceAndBody() &&
+      (Callsite | Callee).hasNonStreamingInterface())
     return false;
 
   // Both streaming
-  if (hasStreamingInterfaceOrBody() && Callee.hasStreamingInterface())
+  if (Caller.hasStreamingInterfaceOrBody() &&
+      (Callsite | Callee).hasStreamingInterface())
     return false;
 
   return true;
 }
+
+SMECallAttrs::SMECallAttrs(const CallBase &CB)
+    : SMECallAttrs(*CB.getFunction(), CB.getCalledFunction(),
+                   CB.getAttributes()) {}
diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
index 1691d4fec8b68..791bb891a18b7 100644
--- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
+++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
@@ -18,12 +18,9 @@ class CallBase;
 class AttributeList;
 
 /// SMEAttrs is a utility class to parse the SME ACLE attributes on functions.
-/// It helps determine a function's requirements for PSTATE.ZA and PSTATE.SM. It
-/// has interfaces to query whether a streaming mode change or lazy-save
-/// mechanism is required when going from one function to another (e.g. through
-/// a call).
+/// It helps determine a function's requirements for PSTATE.ZA and PSTATE.SM.
 class SMEAttrs {
-  unsigned Bitmask;
+  unsigned Bitmask = Normal;
 
 public:
   enum class StateValue {
@@ -43,18 +40,23 @@ class SMEAttrs {
     SM_Body = 1 << 2,         // aarch64_pstate_sm_body
     SME_ABI_Routine = 1 << 3, // Used for SME ABI routines to avoid lazy saves
     ZA_State_Agnostic = 1 << 4,
-    ZT0_Undef = 1 << 5,       // Use to mark ZT0 as undef to avoid spills
+    ZT0_Undef = 1 << 5, // Use to mark ZT0 as undef to avoid spills
     ZA_Shift = 6,
     ZA_Mask = 0b111 << ZA_Shift,
     ZT0_Shift = 9,
     ZT0_Mask = 0b111 << ZT0_Shift
   };
 
-  SMEAttrs(unsigned Mask = Normal) : Bitmask(0) { set(Mask); }
-  SMEAttrs(const Function &F) : SMEAttrs(F.getAttributes()) {}
-  SMEAttrs(const CallBase &CB);
+  SMEAttrs() = default;
+  SMEAttrs(unsigned Mask) { set(Mask); }
+  SMEAttrs(const Function *F)
+      : SMEAttrs(F ? F->getAttributes() : AttributeList()) {
+    if (F)
+      addKnownFunctionAttrs(F->getName());
+  }
+  SMEAttrs(const Function &F) : SMEAttrs(&F) {}
   SMEAttrs(const AttributeList &L);
-  SMEAttrs(StringRef FuncName);
+  SMEAttrs(StringRef FuncName) { addKnownFunctionAttrs(FuncName); };
 
   void set(unsigned M, bool Enable = true);
 
@@ -74,10 +76,6 @@ class SMEAttrs {
     return hasNonStreamingInterface() && !hasStreamingBody();
   }
 
-  /// \return true if a call from Caller -> Callee requires a change in
-  /// streaming mode.
-  bool requiresSMChange(const SMEAttrs &Callee) const;
-
   // Interfaces to query ZA
   static StateValue decodeZAState(unsigned Bitmask) {
     return static_cast<StateValue>((Bitmask & ZA_Mask) >> ZA_Shift);
@@ -104,10 +102,7 @@ class SMEAttrs {
     return !hasSharedZAInterface() && !hasAgnosticZAInterface();
   }
   bool hasZAState() const { return isNewZA() || sharesZA(); }
-  bool requiresLazySave(const SMEAttrs &Callee) const {
-    return hasZAState() && Callee.hasPrivateZAInterface() &&
-           !(Callee.Bitmask & SME_ABI_Routine);
-  }
+  bool isSMEABIRoutine() const { return Bitmask & SME_ABI_Routine; }
 
   // Interfaces to query ZT0 State
   static StateValue decodeZT0State(unsigned Bitmask) {
@@ -126,27 +121,76 @@ class SMEAttrs {
   bool isPreservesZT0() const {
     return decodeZT0State(Bitmask) == StateValue::Preserved;
   }
-  bool isUndefZT0() const { return Bitmask & ZT0_Undef; }
+  bool hasUndefZT0() const { return Bitmask & ZT0_Undef; }
   bool sharesZT0() const {
     StateValue State = decodeZT0State(Bitmask);
     return State == StateValue::In || State == StateValue::Out ||
            State == StateValue::InOut || State == StateValue::Preserved;
   }
   bool hasZT0State() const { return isNewZT0() || sharesZT0(); }
-  bool requiresPreservingZT0(const SMEAttrs &Callee) const {
-    return hasZT0State() && !Callee.isUndefZT0() && !Callee.sharesZT0() &&
-           !Callee.hasAgnosticZAInterface();
+
+  SMEAttrs operator|(SMEAttrs Other) const {
+    SMEAttrs Merged(*this);
+    Merged.set(Other.Bitmask, /*Enable=*/true);
+    return Merged;
   }
-  bool requiresDisablingZABeforeCall(const SMEAttrs &Callee) const {
-    return hasZT0State() && !hasZAState() && Callee.hasPrivateZAInterface() &&
-           !(Callee.Bitmask & SME_ABI_Routine);
+
+private:
+  void addKnownFunctionAttrs(StringRef FuncName);
+};
+
+/// SMECallAttrs is a utility class to hold the SMEAttrs for a callsite. It has
+/// interfaces to query whether a streaming mode change or lazy-save mechanism
+/// is required when going from one function to another (e.g. through a call).
+class SMECallAttrs {
+  SMEAttrs Caller;
+  SMEAttrs Callee;
+  SMEAttrs Callsite;
+
+public:
+  SMECallAttrs(SMEAttrs Caller, SMEAttrs Callee,
+               SMEAttrs Callsite = SMEAttrs::Normal)
+      : Caller(Caller), Callee(Callee), Callsite(Callsite) {}
+
+  SMECallAttrs(const CallBase &CB);
+
+  SMEAttrs &caller() { return Caller; }
+  SMEAttrs &callee() { return Callee; }
+  SMEAttrs &callsite() { return Callsite; }
+  SMEAttrs const &caller() const { return Caller; }
+  SMEAttrs const &callee() const { return Callee; }
+  SMEAttrs const &callsite() const { return Callsite; }
+  SMEAttrs calleeOrCallsite() const { return Callsite | Callee; }
+
+  /// \return true if a call from Caller -> Callee requires a change in
+  /// streaming mode.
+  bool requiresSMChange() const;
+
+  bool requiresLazySave() const {
+    return Caller.hasZAState() && (Callsite | Callee).hasPrivateZAInterface() &&
+           !Callee.isSMEABIRoutine();
   }
-  bool requiresEnablingZAAfterCall(const SMEAttrs &Callee) const {
-    return requiresLazySave(Callee) || requiresDisablingZABeforeCall(Callee);
+
+  bool requiresPreservingZT0() const {
+    return Caller.hasZT0State() && !Callsite.hasUndefZT0() &&
+           !(Callsite | Callee).sharesZT0() &&
+           !(Callsite | Callee).hasAgnosticZAInterface();
   }
-  bool requiresPreservingAllZAState(const SMEAttrs &Callee) const {
-    return hasAgnosticZAInterface() && !Callee.hasAgnosticZAInterface() &&
-           !(Callee.Bitmask & SME_ABI_Routine);
+
+  bool requiresDisablingZABeforeCall() const {
+    return Caller.hasZT0State() && !Caller.hasZAState() &&
+           (Callsite | Callee).hasPrivateZAInterface() &&
+           !Callee.isSMEABIRoutine();
+  }
+
+  bool requiresEnablingZAAfterCall() const {
+    return requiresLazySave() || requiresDisablingZABeforeCall();
+  }
+
+  bool requiresPreservingAllZAState() const {
+    return Caller.hasAgnosticZAInterface() &&
+           !(Callsite | Callee).hasAgnosticZAInterface() &&
+           !Callee.isSMEABIRoutine();
   }
 };
 
diff --git a/llvm/unittests/Target/AArch64/SMEAttributesTest.cpp b/llvm/unittests/Target/AArch64/SMEAttributesTest.cpp
index f8c77fcba19cf..f13252f3a4c28 100644
--- a/llvm/unittests/Target/AArch64/SMEAttributesTest.cpp
+++ b/llvm/unittests/Target/AArch64/SMEAttributesTest.cpp
@@ -9,6 +9,7 @@
 
 using namespace llvm;
 using SA = SMEAttrs;
+using CA = SMECallAttrs;
 
 std::unique_ptr<Module> parseIR(const char *IR) {
   static LLVMContext C;
@@ -70,15 +71,14 @@ TEST(SMEAttributes, Constructors) {
   ASSERT_TRUE(SA(*parseIR("declare void @foo() \"aarch64_new_zt0\"")
                       ->getFunction("foo"))
                   .isNewZT0());
-  ASSERT_TRUE(
-      SA(cast<CallBase>((parseIR("declare void @callee()\n"
-                                 "define void @foo() {"
-                                 "call void @callee() \"aarch64_zt0_undef\"\n"
-                                 "ret void\n}")
-                             ->getFunction("foo")
-                             ->begin()
-                             ->front())))
-          .isUndefZT0());
+
+  auto CallModule = parseIR("declare void @callee()\n"
+                            "define void @foo() {"
+                            "call void @callee() \"aarch64_zt0_undef\"\n"
+                            "ret void\n}");
+  CallBase &Call =
+      cast<CallBase>((CallModule->getFunction("foo")->begin()->front()));
+  ASSERT_TRUE(SMECallAttrs(Call).callsite().hasUndefZT0());
 
   // Invalid combinations.
   EXPECT_DEBUG_DEATH(SA(SA::SM_Enabled | SA::SM_Compatible),
@@ -235,7 +235,7 @@ TEST(SMEAttributes, Basics) {
   ASSERT_TRUE(ZT0_Undef.hasZT0State());
   ASSERT_FALSE(ZT0_Undef.hasSharedZAInterface());
   ASSERT_TRUE(ZT0_Undef.hasPrivateZAInterface());
-  ASSERT_TRUE(ZT0_Undef.isUndefZT0());
+  ASSERT_TRUE(ZT0_Undef.hasUndefZT0());
 
   ASSERT_FALSE(SA(SA::Normal).isInZT0());
   ASSERT_FALSE(SA(SA::Normal).isOutZT0());
@@ -248,59 +248,57 @@ TEST(SMEAttributes, Basics) {
 
 TEST(SMEAttributes, Transitions) {
   // Normal -> Normal
-  ASSERT_FALSE(SA(SA::Normal).requiresSMChange(SA(SA::Normal)));
-  ASSERT_FALSE(SA(SA::Normal).requiresPreservingZT0(SA(SA::Normal)));
-  ASSERT_FALSE(SA(SA::Normal).requiresDisablingZABeforeCall(SA(SA::Normal)));
-  ASSERT_FALSE(SA(SA::Normal).requiresEnablingZAAfterCall(SA(SA::Normal)));
+  ASSERT_FALSE(CA(SA::Normal, SA::Normal).requiresSMChange());
+  ASSERT_FALSE(CA(SA::Normal, SA::Normal).requiresPreservingZT0());
+  ASSERT_FALSE(CA(SA::Normal, SA::Normal).requiresDisablingZABeforeCall());
+  ASSERT_FALSE(CA(SA::Normal, SA::Normal).requiresEnablingZAAfterCall());
   // Normal -> Normal + LocallyStreaming
-  ASSERT_FALSE(SA(SA::Normal).requiresSMChange(SA(SA::Normal | SA::SM_Body)));
+  ASSERT_FALSE(CA(SA::Normal, SA::Normal | SA::SM_Body).requiresSMChange());
 
   // Normal -> Streaming
-  ASSERT_TRUE(SA(SA::Normal).requiresSMChange(SA(SA::SM_Enabled)));
+  ASSERT_TRUE(CA(SA::Normal, SA::SM_Enabled).requiresSMChange());
   // Normal -> Streaming + LocallyStreaming
-  ASSERT_TRUE(
-      SA(SA::Normal).requiresSMChange(SA(SA::SM_Enabled | SA::SM_Body)));
+  ASSERT_TRUE(CA(SA::Normal, SA::SM_Enabled | SA::SM_Body).requiresSMChange());
 
   // Normal -> Streaming-compatible
-  ASSERT_FALSE(SA(SA::Normal).requiresSMChange(SA(SA::SM_Compatible)));
+  ASSERT_FALSE(CA(SA::Normal, SA::SM_Compatible).requiresSMChange());
   // Normal -> Streaming-compatible + LocallyStreaming
   ASSERT_FALSE(
-      SA(SA::Normal).requiresSMChange(SA(SA::SM_Compatible | SA::SM_Body)));
+      CA(SA::Normal, SA::SM_Compatible | SA::SM_Body).requiresSMChange());
 
   // Streaming -> Normal
-  ASSERT_TRUE(SA(SA::SM_Enabled).requiresSMChange(SA(SA::Normal)));
+  ASSERT_TRUE(CA(SA::SM_Enabled, SA::Normal).requiresSMChange());
   // Streaming -> Normal + LocallyStreaming
-  ASSERT_TRUE(
-      SA(SA::SM_Enabled).requiresSMChange(SA(SA::Normal | SA::SM_Body)));
+  ASSERT_TRUE(CA(SA::SM_Enabled, SA::Normal | SA::SM_Body).requiresSMChange());
 
   // Streaming -> Streaming
-  ASSERT_FALSE(SA(SA::SM_Enabled).requiresSMChange(SA(SA::SM_Enabled)));
+  ASSERT_FALSE(CA(SA::SM_Enabled, SA::SM_Enabled).requiresSMChange());
   // Streaming -> Streaming + LocallyStreaming
   ASSERT_FALSE(
-      SA(SA::SM_Enabled).requiresSMChange(SA(SA::SM_Enabled | SA::SM_Body)));
+      CA(SA::SM_Enabled, SA::SM_Enabled | SA::SM_Body).requiresSMChange());
 
   // Streaming -> Streaming-compatible
-  ASSERT_FALSE(SA(SA::SM_Enabled).requiresSMChange(SA(SA::SM_Compatible)));
+  ASSERT_FALSE(CA(SA::SM_Enabled, SA::SM_Compatible).requiresSMChange());
   // Streaming -> Streaming-compatible + LocallyStreaming
   ASSERT_FALSE(
-      SA(SA::SM_Enabled).requiresSMChange(SA(SA::SM_Compatible | SA::SM_Body)));
+      CA(SA::SM_Enabled, SA::SM_Compatible | SA::SM_Body).requiresSMChange());
 
   // Streaming-compatible -> Normal
-  ASSERT_TRUE(SA(SA::SM_Compatible).requiresSMChange(SA(SA::Normal)));
+  ASSERT_TRUE(CA(SA::SM_Compatible, SA::Normal).requiresSMChange());
   ASSERT_TRUE(
-      SA(SA::SM_Compatible).requiresSMChange(SA(SA::Normal | SA::SM_Body)));
+      CA(SA::SM_Compatible, SA::Normal | SA::SM_Body).requiresSMChange());
 
   // Streaming-compatible -> Streaming
-  ASSERT_TRUE(SA(SA::SM_Compatible).requiresSMChange(SA(SA::SM_Enabled)));
+  ASSERT_TRUE(CA(SA::SM_Compatible, SA::SM_Enabled).requiresSMChange());
   // Streaming-compatible -> Streaming + LocallyStreaming
   ASSERT_TRUE(
-      SA(SA::SM_Compatible).requiresSMChange(SA(SA::SM_Enabled | SA::SM_Body)));
+      CA(SA::SM_Compatible, SA::SM_Enabled | SA::SM_Body).requiresSMChange());
 
   // Streaming-compatible -> Streaming-compatible
-  ASSERT_FALSE(SA(SA::SM_Compatible).requiresSMChange(SA(SA::SM_Compatible)));
+  ASSERT_FALSE(CA(SA::SM_Compatible, SA::SM_Compatible).requiresSMChange());
   // Streaming-compatible -> Streaming-compatible + LocallyStreaming
-  ASSERT_FALSE(SA(SA::SM_Compatible)
-                   .requiresSMChange(SA(SA::SM_Compatible | SA::SM_Body)));
+  ASSERT_FALSE(CA(SA::SM_Compatible, SA::SM_Compatible | SA::SM_Body)
+                   .requiresSMChange());
 
   SA Private_ZA = SA(SA::Normal);
   SA ZA_Shared = SA(SA::encodeZAState(SA::StateValue::In));
@@ -310,37 +308,39 @@ TEST(SMEAttributes, Transitions) {
   SA Undef_ZT0 = SA(SA::ZT0_Undef);
 
   // Shared ZA -> Private ZA Interface
-  ASSERT_FALSE(ZA_Shared.requiresDisablingZABeforeCall(Private_ZA));
-  ASSERT_TRUE(ZA_Shared.requiresEnablingZAAfterCall(Private_ZA));
+  ASSERT_FALSE(CA(ZA_Shared, Private_ZA).requiresDisablingZABeforeCall());
+  ASSERT_TRUE(CA(ZA_Shared, Private_ZA).requiresEnablingZAAfterCall());
 
   // Shared ZT0 -> Private ZA Interface
-  ASSERT_TRUE(ZT0_Shared.requiresDisablingZABeforeCall(Private_ZA));
-  ASSERT_TRUE(ZT0_Shared.requiresPreservingZT0(Private_ZA));
-  ASSERT_TRUE(ZT0_Shared.requiresEnablingZAAfterCall(Private_ZA));
+  ASSERT_TRUE(CA(ZT0_Shared, Private_ZA).requiresDisablingZABeforeCall());
+  ASSERT_TRUE(CA(ZT0_Shared, Private_ZA).requiresPreservingZT0());
+  ASSERT_TRUE(CA(ZT0_Shared, Private_ZA).requiresEnablingZAAfterCall());
 
   // Shared Undef ZT0 -> Private ZA Interface
   // Note: "Undef ZT0" is a callsite attribute that means ZT0 is undefined at
   // the point of the call.
-  ASSERT_TRUE(ZT0_Shared.requiresDisablingZABeforeCall(Undef_ZT0));
-  ASSERT_FALSE(ZT0_Shared.requiresPreservingZT0(Undef_ZT0));
-  ASSERT_TRUE(ZT0_Shared.requiresEnablingZAAfterCall(Undef_ZT0));
+  ASSERT_TRUE(
+      CA(ZT0_Shared, Private_ZA, Undef_ZT0).requiresDisablingZABeforeCall());
+  ASSERT_FALSE(CA(ZT0_Shared, Private_ZA, Undef_ZT0).requiresPreservingZT0());
+  ASSERT_TRUE(
+      CA(ZT0_Shared, Private_ZA, Undef_ZT0).requiresEnablingZAAfterCall());
 
   // Shared ZA & ZT0 -> Private ZA Interface
-  ASSERT_FALSE(ZA_ZT0_Shared.requiresDisablingZABeforeCall(Private_ZA));
-  ASSERT_TRUE(ZA_ZT0_Shared.requiresPreservingZT0(Private_ZA));
-  ASSERT_TRUE(ZA_ZT0_Shared.requiresEnablingZAAfterCall(Private_ZA));
+  ASSERT_FALSE(CA(ZA_ZT0_Shared, Private_ZA).requiresDisablingZABeforeCall());
+  ASSERT_TRUE(CA(ZA_ZT0_Shared, Private_ZA).requiresPreservingZT0());
+  ASSERT_TRUE(CA(ZA_ZT0_Shared, Private_ZA).requiresEnablingZAAfterCall());
 
   // Shared ZA -> Shared ZA Interface
-  ASSERT_FALSE(ZA_Shared.requiresDisablingZABeforeCall(ZT0_Shared));
-  ASSERT_FALSE(ZA_Shared.requiresEnablingZAAfterCall(ZT0_Shared));
+  ASSERT_FALSE(CA(ZA_Shared, ZT0_Shared).requiresDisablingZABeforeCall());
+  ASSERT_FALSE(CA(ZA_Shared, ZT0_Shared).requiresEnablingZAAfterCall());
 
   // Shared ZT0 -> Shared ZA Interface
-  ASSERT_FALSE(ZT0_Shared.requiresDisablingZABeforeCall(ZT0_Shared));
-  ASSERT_FALSE(ZT0_Shared.requiresPreservingZT0(ZT0_Shared));
-  ASSERT_FALSE(ZT0_Shared.requiresEnablingZAAfterCall(ZT0_Shared));
+  ASSERT_FALSE(CA(ZT0_Shared, ZT0_Shared).requiresDisablingZABeforeCall());
+  ASSERT_FALSE(CA(ZT0_Shared, ZT0_Shared).requiresPreservingZT0());
+  ASSERT_FALSE(CA(ZT0_Shared, ZT0_Shared).requiresEnablingZAAfterCall());
 
   // Shared ZA & ZT0 -> Shared ZA Interface
-  ASSERT_FALSE(ZA_ZT0_Shared.requiresDisablingZABeforeCall(ZT0_Shared));
-  ASSERT_FALSE(ZA_ZT0_Shared.requiresPreservingZT0(ZT0_Shared));
-  ASSERT_FALSE(ZA_ZT0_Shared.requiresEnablingZAAfterCall(ZT0_Shared));
+  ASSERT_FALSE(CA(ZA_ZT0_Shared, ZT0_Shared).requiresDisablingZABeforeCall());
+  ASSERT_FALSE(CA(ZA_ZT0_Shared, ZT0_Shared).requiresPreservingZT0());
+  ASSERT_FALSE(CA(ZA_ZT0_Shared, ZT0_Shared).requiresEnablingZAAfterCall());
 }



More information about the llvm-commits mailing list