[llvm] [AArch64][SME] Propagate desired ZA states in the MachineSMEABIPass (PR #149510)

Benjamin Maxwell via llvm-commits llvm-commits at lists.llvm.org
Fri Sep 12 04:17:49 PDT 2025


https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/149510

>From 1a80766cfe163175e5d277378e12c86f6c15ec23 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Tue, 9 Sep 2025 13:45:26 +0000
Subject: [PATCH 1/2] Pass CodeGenOptLevel to MachineSMEABI pass (part of
 #149065)

Change-Id: Idef5b1e2a45585f97897fc11c4f237996edb7c8b
---
 llvm/lib/Target/AArch64/AArch64.h                |  2 +-
 llvm/lib/Target/AArch64/AArch64TargetMachine.cpp |  6 +++---
 llvm/lib/Target/AArch64/MachineSMEABIPass.cpp    | 10 ++++++++--
 3 files changed, 12 insertions(+), 6 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64.h b/llvm/lib/Target/AArch64/AArch64.h
index 8d0ff41fc8c08..139684172f1bb 100644
--- a/llvm/lib/Target/AArch64/AArch64.h
+++ b/llvm/lib/Target/AArch64/AArch64.h
@@ -60,7 +60,7 @@ FunctionPass *createAArch64CleanupLocalDynamicTLSPass();
 FunctionPass *createAArch64CollectLOHPass();
 FunctionPass *createSMEABIPass();
 FunctionPass *createSMEPeepholeOptPass();
-FunctionPass *createMachineSMEABIPass();
+FunctionPass *createMachineSMEABIPass(CodeGenOptLevel);
 ModulePass *createSVEIntrinsicOptsPass();
 InstructionSelector *
 createAArch64InstructionSelector(const AArch64TargetMachine &,
diff --git a/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp b/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp
index dde1d88403bfe..6606a731d357b 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp
@@ -770,8 +770,8 @@ bool AArch64PassConfig::addGlobalInstructionSelect() {
 }
 
 void AArch64PassConfig::addMachineSSAOptimization() {
-  if (EnableNewSMEABILowering && TM->getOptLevel() != CodeGenOptLevel::None)
-    addPass(createMachineSMEABIPass());
+  if (TM->getOptLevel() != CodeGenOptLevel::None && EnableNewSMEABILowering)
+    addPass(createMachineSMEABIPass(TM->getOptLevel()));
 
   if (TM->getOptLevel() != CodeGenOptLevel::None && EnableSMEPeepholeOpt)
     addPass(createSMEPeepholeOptPass());
@@ -804,7 +804,7 @@ bool AArch64PassConfig::addILPOpts() {
 
 void AArch64PassConfig::addPreRegAlloc() {
   if (TM->getOptLevel() == CodeGenOptLevel::None && EnableNewSMEABILowering)
-    addPass(createMachineSMEABIPass());
+    addPass(createMachineSMEABIPass(CodeGenOptLevel::None));
 
   // Change dead register definitions to refer to the zero register.
   if (TM->getOptLevel() != CodeGenOptLevel::None &&
diff --git a/llvm/lib/Target/AArch64/MachineSMEABIPass.cpp b/llvm/lib/Target/AArch64/MachineSMEABIPass.cpp
index cced0faa28889..7779bd0952f87 100644
--- a/llvm/lib/Target/AArch64/MachineSMEABIPass.cpp
+++ b/llvm/lib/Target/AArch64/MachineSMEABIPass.cpp
@@ -238,7 +238,8 @@ getZAStateBeforeInst(const TargetRegisterInfo &TRI, MachineInstr &MI,
 struct MachineSMEABI : public MachineFunctionPass {
   inline static char ID = 0;
 
-  MachineSMEABI() : MachineFunctionPass(ID) {}
+  MachineSMEABI(CodeGenOptLevel OptLevel = CodeGenOptLevel::Default)
+      : MachineFunctionPass(ID), OptLevel(OptLevel) {}
 
   bool runOnMachineFunction(MachineFunction &MF) override;
 
@@ -329,12 +330,15 @@ struct MachineSMEABI : public MachineFunctionPass {
                          MachineBasicBlock::iterator MBBI, DebugLoc DL);
 
 private:
+  CodeGenOptLevel OptLevel = CodeGenOptLevel::Default;
+
   MachineFunction *MF = nullptr;
   const AArch64Subtarget *Subtarget = nullptr;
   const AArch64RegisterInfo *TRI = nullptr;
   const AArch64FunctionInfo *AFI = nullptr;
   const TargetInstrInfo *TII = nullptr;
   MachineRegisterInfo *MRI = nullptr;
+  MachineLoopInfo *MLI = nullptr;
 };
 
 FunctionInfo MachineSMEABI::collectNeededZAStates(SMEAttrs SMEFnAttrs) {
@@ -877,4 +881,6 @@ bool MachineSMEABI::runOnMachineFunction(MachineFunction &MF) {
   return true;
 }
 
-FunctionPass *llvm::createMachineSMEABIPass() { return new MachineSMEABI(); }
+FunctionPass *llvm::createMachineSMEABIPass(CodeGenOptLevel OptLevel) {
+  return new MachineSMEABI(OptLevel);
+}

>From 6266ecb5654406372060ac858421fce8bd1f91e0 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Tue, 15 Jul 2025 17:00:04 +0000
Subject: [PATCH 2/2] [AArch64][SME] Propagate desired ZA states in the
 MachineSMEABIPass

This patch adds a propagation step to the MachineSMEABIPass that
propagates desired ZA states forwards/backwards (from predecessors to
successors, or vice versa).

The aim of this is to pick better ZA states for edge bundles, as when
many (or all) blocks in a bundle do not have a preferred ZA state, the
ZA state assigned to a bundle can be less than ideal.

An important case is nested loops, where only the inner loop has a
preferred ZA state. Here we'd like to propagate the ZA state up from the
inner loop to the outer loops (to avoid saves/restores in any loop).

Change-Id: I39f9c7d7608e2fa070be2fb88351b4d1d0079041
---
 llvm/lib/Target/AArch64/MachineSMEABIPass.cpp | 132 +++++++-
 llvm/test/CodeGen/AArch64/sme-agnostic-za.ll  |   9 +-
 .../CodeGen/AArch64/sme-za-control-flow.ll    | 107 +++----
 .../test/CodeGen/AArch64/sme-za-exceptions.ll |  36 +--
 .../sme-za-function-with-many-blocks.ll       | 296 ++++++++++++++++++
 .../AArch64/sme-za-lazy-save-buffer.ll        | 110 +++----
 6 files changed, 504 insertions(+), 186 deletions(-)
 create mode 100644 llvm/test/CodeGen/AArch64/sme-za-function-with-many-blocks.ll

diff --git a/llvm/lib/Target/AArch64/MachineSMEABIPass.cpp b/llvm/lib/Target/AArch64/MachineSMEABIPass.cpp
index 7779bd0952f87..50134be691665 100644
--- a/llvm/lib/Target/AArch64/MachineSMEABIPass.cpp
+++ b/llvm/lib/Target/AArch64/MachineSMEABIPass.cpp
@@ -121,8 +121,10 @@ struct InstInfo {
 /// Contains the needed ZA state for each instruction in a block. Instructions
 /// that do not require a ZA state are not recorded.
 struct BlockInfo {
-  ZAState FixedEntryState{ZAState::ANY};
   SmallVector<InstInfo> Insts;
+  ZAState FixedEntryState{ZAState::ANY};
+  ZAState DesiredIncomingState{ZAState::ANY};
+  ZAState DesiredOutgoingState{ZAState::ANY};
   LiveRegs PhysLiveRegsAtEntry = LiveRegs::None;
   LiveRegs PhysLiveRegsAtExit = LiveRegs::None;
 };
@@ -268,6 +270,11 @@ struct MachineSMEABI : public MachineFunctionPass {
                           const EdgeBundles &Bundles,
                           ArrayRef<ZAState> BundleStates);
 
+  /// Propagates desired states forwards (from predecessors -> successors) if
+  /// \p Forwards, otherwise, propagates backwards (from successors ->
+  /// predecessors).
+  void propagateDesiredStates(FunctionInfo &FnInfo, bool Forwards = true);
+
   // Emission routines for private and shared ZA functions (using lazy saves).
   void emitNewZAPrologue(MachineBasicBlock &MBB,
                          MachineBasicBlock::iterator MBBI);
@@ -411,12 +418,70 @@ FunctionInfo MachineSMEABI::collectNeededZAStates(SMEAttrs SMEFnAttrs) {
 
     // Reverse vector (as we had to iterate backwards for liveness).
     std::reverse(Block.Insts.begin(), Block.Insts.end());
+
+    // Record the desired states on entry/exit of this block. These are the
+    // states that would not incur a state transition.
+    if (!Block.Insts.empty()) {
+      Block.DesiredIncomingState = Block.Insts.front().NeededState;
+      Block.DesiredOutgoingState = Block.Insts.back().NeededState;
+    }
   }
 
   return FunctionInfo{std::move(Blocks), AfterSMEProloguePt,
                       PhysLiveRegsAfterSMEPrologue};
 }
 
+void MachineSMEABI::propagateDesiredStates(FunctionInfo &FnInfo,
+                                           bool Forwards) {
+  // When `Forwards`, a block's desired incoming state is derived from the
+  // desired outgoing states of its predecessors; otherwise, its desired
+  // outgoing state is derived from its successors' desired incoming states.
+  auto GetBlockState = [](BlockInfo &Block, bool Incoming) -> ZAState & {
+    return Incoming ? Block.DesiredIncomingState : Block.DesiredOutgoingState;
+  };
+
+  SmallVector<MachineBasicBlock *> Worklist;
+  for (auto [BlockID, BlockInfo] : enumerate(FnInfo.Blocks)) {
+    if (!isLegalEdgeBundleZAState(GetBlockState(BlockInfo, Forwards)))
+      Worklist.push_back(MF->getBlockNumbered(BlockID));
+  }
+
+  while (!Worklist.empty()) {
+    MachineBasicBlock *MBB = Worklist.pop_back_val();
+    auto &BlockInfo = FnInfo.Blocks[MBB->getNumber()];
+
+    // Count the legal edge bundle states desired by the predecessors (or
+    // successors); the most frequent one is propagated to this block.
+    int StateCounts[ZAState::NUM_ZA_STATE] = {0};
+    for (MachineBasicBlock *PredOrSucc :
+         Forwards ? predecessors(MBB) : successors(MBB)) {
+      auto &PredOrSuccBlockInfo = FnInfo.Blocks[PredOrSucc->getNumber()];
+      auto ZAState = GetBlockState(PredOrSuccBlockInfo, !Forwards);
+      if (isLegalEdgeBundleZAState(ZAState))
+        StateCounts[ZAState]++;
+    }
+
+    ZAState PropagatedState = ZAState(max_element(StateCounts) - StateCounts);
+    auto &CurrentState = GetBlockState(BlockInfo, Forwards);
+    if (PropagatedState != CurrentState) {
+      CurrentState = PropagatedState;
+      auto &OtherState = GetBlockState(BlockInfo, !Forwards);
+      // Also set the block's opposite-direction state if that is still "ANY".
+      if (OtherState == ZAState::ANY)
+        OtherState = PropagatedState;
+      // This block changed, so revisit the successors (or predecessors) that
+      // do not yet have a legal edge bundle state in this direction.
+      for (MachineBasicBlock *SuccOrPred :
+           Forwards ? successors(MBB) : predecessors(MBB)) {
+        auto &SuccOrPredBlockInfo = FnInfo.Blocks[SuccOrPred->getNumber()];
+        if (!isLegalEdgeBundleZAState(
+                GetBlockState(SuccOrPredBlockInfo, Forwards)))
+          Worklist.push_back(SuccOrPred);
+      }
+    }
+  }
+}
+
 /// Assigns each edge bundle a ZA state based on the needed states of blocks
 /// that have incoming or outgoing edges in that bundle.
 SmallVector<ZAState>
@@ -429,40 +494,36 @@ MachineSMEABI::assignBundleZAStates(const EdgeBundles &Bundles,
     // Attempt to assign a ZA state for this bundle that minimizes state
     // transitions. Edges within loops are given a higher weight as we assume
     // they will be executed more than once.
-    // TODO: We should propagate desired incoming/outgoing states through blocks
-    // that have the "ANY" state first to make better global decisions.
     int EdgeStateCounts[ZAState::NUM_ZA_STATE] = {0};
     for (unsigned BlockID : Bundles.getBlocks(I)) {
       LLVM_DEBUG(dbgs() << "- bb." << BlockID);
 
       const BlockInfo &Block = FnInfo.Blocks[BlockID];
-      if (Block.Insts.empty()) {
-        LLVM_DEBUG(dbgs() << " (no state preference)\n");
-        continue;
-      }
       bool InEdge = Bundles.getBundle(BlockID, /*Out=*/false) == I;
       bool OutEdge = Bundles.getBundle(BlockID, /*Out=*/true) == I;
 
-      ZAState DesiredIncomingState = Block.Insts.front().NeededState;
-      if (InEdge && isLegalEdgeBundleZAState(DesiredIncomingState)) {
-        EdgeStateCounts[DesiredIncomingState]++;
+      bool LegalInEdge =
+          InEdge && isLegalEdgeBundleZAState(Block.DesiredIncomingState);
+      bool LegalOutEgde =
+          OutEdge && isLegalEdgeBundleZAState(Block.DesiredOutgoingState);
+      if (LegalInEdge) {
         LLVM_DEBUG(dbgs() << " DesiredIncomingState: "
-                          << getZAStateString(DesiredIncomingState));
+                          << getZAStateString(Block.DesiredIncomingState));
+        EdgeStateCounts[Block.DesiredIncomingState]++;
       }
-      ZAState DesiredOutgoingState = Block.Insts.back().NeededState;
-      if (OutEdge && isLegalEdgeBundleZAState(DesiredOutgoingState)) {
-        EdgeStateCounts[DesiredOutgoingState]++;
+      if (LegalOutEgde) {
         LLVM_DEBUG(dbgs() << " DesiredOutgoingState: "
-                          << getZAStateString(DesiredOutgoingState));
+                          << getZAStateString(Block.DesiredOutgoingState));
+        EdgeStateCounts[Block.DesiredOutgoingState]++;
       }
+      if (!LegalInEdge && !LegalOutEgde)
+        LLVM_DEBUG(dbgs() << " (no state preference)");
       LLVM_DEBUG(dbgs() << '\n');
     }
 
     ZAState BundleState =
         ZAState(max_element(EdgeStateCounts) - EdgeStateCounts);
 
-    // Force ZA to be active in bundles that don't have a preferred state.
-    // TODO: Something better here (to avoid extra mode switches).
     if (BundleState == ZAState::ANY)
       BundleState = ZAState::ACTIVE;
 
@@ -858,6 +919,43 @@ bool MachineSMEABI::runOnMachineFunction(MachineFunction &MF) {
       getAnalysis<EdgeBundlesWrapperLegacy>().getEdgeBundles();
 
   FunctionInfo FnInfo = collectNeededZAStates(SMEFnAttrs);
+
+  if (OptLevel != CodeGenOptLevel::None) {
+    // Propagate desired states forwards then backwards. We propagate forwards
+    // first as this propagates desired states from inner to outer loops.
+    // Backwards propagation is then used to fill in any gaps. Note: Doing both
+    // in one step can give poor results. For example:
+    //
+    //    ┌─────┐
+    //  ┌─┤ BB0 ◄───┐
+    //  │ └─┬───┘   │
+    //  │ ┌─▼───◄──┐│
+    //  │ │ BB1 │  ││
+    //  │ └─┬┬──┘  ││
+    //  │   │└─────┘│
+    //  │ ┌─▼───┐   │
+    //  │ │ BB2 ├───┘
+    //  │ └─┬───┘
+    //  │ ┌─▼───┐
+    //  └─► BB3 │
+    //    └─────┘
+    //
+    // If:
+    // - "BB0" and "BB2" (outer loop) have no state preference
+    // - "BB1" (inner loop) desires the ACTIVE state on entry/exit
+    // - "BB3" desires the LOCAL_SAVED state on entry
+    //
+    // If we propagate forwards first, ACTIVE is propagated from BB1 to BB2,
+    // then from BB2 to BB0, which results in the inner and outer loops having
+    // the "ACTIVE" state. This avoids any state changes in the loops.
+    //
+    // If we propagate backwards first, we _could_ propagate LOCAL_SAVED from
+    // BB3 to BB0, which would result in a transition from ACTIVE -> LOCAL_SAVED
+    // in the outer loop.
+    for (bool Forwards : {true, false})
+      propagateDesiredStates(FnInfo, Forwards);
+  }
+
   SmallVector<ZAState> BundleStates = assignBundleZAStates(Bundles, FnInfo);
 
   EmitContext Context;
diff --git a/llvm/test/CodeGen/AArch64/sme-agnostic-za.ll b/llvm/test/CodeGen/AArch64/sme-agnostic-za.ll
index a0a14f2ffae3f..077e9b5ced624 100644
--- a/llvm/test/CodeGen/AArch64/sme-agnostic-za.ll
+++ b/llvm/test/CodeGen/AArch64/sme-agnostic-za.ll
@@ -361,7 +361,6 @@ define i64  @test_many_callee_arguments(
   ret i64 %ret
 }
 
-; FIXME: The new lowering should avoid saves/restores in the probing loop.
 define void @agnostic_za_buffer_alloc_with_stack_probes() nounwind "aarch64_za_state_agnostic" "probe-stack"="inline-asm" "stack-probe-size"="65536"{
 ; CHECK-LABEL: agnostic_za_buffer_alloc_with_stack_probes:
 ; CHECK:       // %bb.0:
@@ -399,18 +398,14 @@ define void @agnostic_za_buffer_alloc_with_stack_probes() nounwind "aarch64_za_s
 ; CHECK-NEWLOWERING-NEXT:    bl __arm_sme_state_size
 ; CHECK-NEWLOWERING-NEXT:    mov x8, sp
 ; CHECK-NEWLOWERING-NEXT:    sub x19, x8, x0
+; CHECK-NEWLOWERING-NEXT:    mov x0, x19
+; CHECK-NEWLOWERING-NEXT:    bl __arm_sme_save
 ; CHECK-NEWLOWERING-NEXT:  .LBB7_1: // =>This Inner Loop Header: Depth=1
 ; CHECK-NEWLOWERING-NEXT:    sub sp, sp, #16, lsl #12 // =65536
 ; CHECK-NEWLOWERING-NEXT:    cmp sp, x19
-; CHECK-NEWLOWERING-NEXT:    mov x0, x19
-; CHECK-NEWLOWERING-NEXT:    mrs x8, NZCV
-; CHECK-NEWLOWERING-NEXT:    bl __arm_sme_save
-; CHECK-NEWLOWERING-NEXT:    msr NZCV, x8
 ; CHECK-NEWLOWERING-NEXT:    b.le .LBB7_3
 ; CHECK-NEWLOWERING-NEXT:  // %bb.2: // in Loop: Header=BB7_1 Depth=1
-; CHECK-NEWLOWERING-NEXT:    mov x0, x19
 ; CHECK-NEWLOWERING-NEXT:    str xzr, [sp]
-; CHECK-NEWLOWERING-NEXT:    bl __arm_sme_restore
 ; CHECK-NEWLOWERING-NEXT:    b .LBB7_1
 ; CHECK-NEWLOWERING-NEXT:  .LBB7_3:
 ; CHECK-NEWLOWERING-NEXT:    mov sp, x19
diff --git a/llvm/test/CodeGen/AArch64/sme-za-control-flow.ll b/llvm/test/CodeGen/AArch64/sme-za-control-flow.ll
index 18ea07e38fe89..c753e9c569d22 100644
--- a/llvm/test/CodeGen/AArch64/sme-za-control-flow.ll
+++ b/llvm/test/CodeGen/AArch64/sme-za-control-flow.ll
@@ -228,65 +228,34 @@ exit:
   ret void
 }
 
-; FIXME: The codegen for this case could be improved (by tuning weights).
-; Here the ZA save has been hoisted out of the conditional, but would be better
-; to sink it.
 define void @cond_private_za_call(i1 %cond) "aarch64_inout_za" nounwind {
-; CHECK-LABEL: cond_private_za_call:
-; CHECK:       // %bb.0:
-; CHECK-NEXT:    stp x29, x30, [sp, #-16]! // 16-byte Folded Spill
-; CHECK-NEXT:    mov x29, sp
-; CHECK-NEXT:    sub sp, sp, #16
-; CHECK-NEXT:    rdsvl x8, #1
-; CHECK-NEXT:    mov x9, sp
-; CHECK-NEXT:    msub x9, x8, x8, x9
-; CHECK-NEXT:    mov sp, x9
-; CHECK-NEXT:    stp x9, x8, [x29, #-16]
-; CHECK-NEXT:    tbz w0, #0, .LBB3_4
-; CHECK-NEXT:  // %bb.1: // %private_za_call
-; CHECK-NEXT:    sub x8, x29, #16
-; CHECK-NEXT:    msr TPIDR2_EL0, x8
-; CHECK-NEXT:    bl private_za_call
-; CHECK-NEXT:    smstart za
-; CHECK-NEXT:    mrs x8, TPIDR2_EL0
-; CHECK-NEXT:    sub x0, x29, #16
-; CHECK-NEXT:    cbnz x8, .LBB3_3
-; CHECK-NEXT:  // %bb.2: // %private_za_call
-; CHECK-NEXT:    bl __arm_tpidr2_restore
-; CHECK-NEXT:  .LBB3_3: // %private_za_call
-; CHECK-NEXT:    msr TPIDR2_EL0, xzr
-; CHECK-NEXT:  .LBB3_4: // %exit
-; CHECK-NEXT:    mov sp, x29
-; CHECK-NEXT:    ldp x29, x30, [sp], #16 // 16-byte Folded Reload
-; CHECK-NEXT:    b shared_za_call
-;
-; CHECK-NEWLOWERING-LABEL: cond_private_za_call:
-; CHECK-NEWLOWERING:       // %bb.0:
-; CHECK-NEWLOWERING-NEXT:    stp x29, x30, [sp, #-16]! // 16-byte Folded Spill
-; CHECK-NEWLOWERING-NEXT:    mov x29, sp
-; CHECK-NEWLOWERING-NEXT:    sub sp, sp, #16
-; CHECK-NEWLOWERING-NEXT:    rdsvl x8, #1
-; CHECK-NEWLOWERING-NEXT:    mov x9, sp
-; CHECK-NEWLOWERING-NEXT:    msub x9, x8, x8, x9
-; CHECK-NEWLOWERING-NEXT:    mov sp, x9
-; CHECK-NEWLOWERING-NEXT:    sub x10, x29, #16
-; CHECK-NEWLOWERING-NEXT:    stp x9, x8, [x29, #-16]
-; CHECK-NEWLOWERING-NEXT:    msr TPIDR2_EL0, x10
-; CHECK-NEWLOWERING-NEXT:    tbz w0, #0, .LBB3_2
-; CHECK-NEWLOWERING-NEXT:  // %bb.1: // %private_za_call
-; CHECK-NEWLOWERING-NEXT:    bl private_za_call
-; CHECK-NEWLOWERING-NEXT:  .LBB3_2: // %exit
-; CHECK-NEWLOWERING-NEXT:    smstart za
-; CHECK-NEWLOWERING-NEXT:    mrs x8, TPIDR2_EL0
-; CHECK-NEWLOWERING-NEXT:    sub x0, x29, #16
-; CHECK-NEWLOWERING-NEXT:    cbnz x8, .LBB3_4
-; CHECK-NEWLOWERING-NEXT:  // %bb.3: // %exit
-; CHECK-NEWLOWERING-NEXT:    bl __arm_tpidr2_restore
-; CHECK-NEWLOWERING-NEXT:  .LBB3_4: // %exit
-; CHECK-NEWLOWERING-NEXT:    msr TPIDR2_EL0, xzr
-; CHECK-NEWLOWERING-NEXT:    mov sp, x29
-; CHECK-NEWLOWERING-NEXT:    ldp x29, x30, [sp], #16 // 16-byte Folded Reload
-; CHECK-NEWLOWERING-NEXT:    b shared_za_call
+; CHECK-COMMON-LABEL: cond_private_za_call:
+; CHECK-COMMON:       // %bb.0:
+; CHECK-COMMON-NEXT:    stp x29, x30, [sp, #-16]! // 16-byte Folded Spill
+; CHECK-COMMON-NEXT:    mov x29, sp
+; CHECK-COMMON-NEXT:    sub sp, sp, #16
+; CHECK-COMMON-NEXT:    rdsvl x8, #1
+; CHECK-COMMON-NEXT:    mov x9, sp
+; CHECK-COMMON-NEXT:    msub x9, x8, x8, x9
+; CHECK-COMMON-NEXT:    mov sp, x9
+; CHECK-COMMON-NEXT:    stp x9, x8, [x29, #-16]
+; CHECK-COMMON-NEXT:    tbz w0, #0, .LBB3_4
+; CHECK-COMMON-NEXT:  // %bb.1: // %private_za_call
+; CHECK-COMMON-NEXT:    sub x8, x29, #16
+; CHECK-COMMON-NEXT:    msr TPIDR2_EL0, x8
+; CHECK-COMMON-NEXT:    bl private_za_call
+; CHECK-COMMON-NEXT:    smstart za
+; CHECK-COMMON-NEXT:    mrs x8, TPIDR2_EL0
+; CHECK-COMMON-NEXT:    sub x0, x29, #16
+; CHECK-COMMON-NEXT:    cbnz x8, .LBB3_3
+; CHECK-COMMON-NEXT:  // %bb.2: // %private_za_call
+; CHECK-COMMON-NEXT:    bl __arm_tpidr2_restore
+; CHECK-COMMON-NEXT:  .LBB3_3: // %private_za_call
+; CHECK-COMMON-NEXT:    msr TPIDR2_EL0, xzr
+; CHECK-COMMON-NEXT:  .LBB3_4: // %exit
+; CHECK-COMMON-NEXT:    mov sp, x29
+; CHECK-COMMON-NEXT:    ldp x29, x30, [sp], #16 // 16-byte Folded Reload
+; CHECK-COMMON-NEXT:    b shared_za_call
   br i1 %cond, label %private_za_call, label %exit
 
 private_za_call:
@@ -910,7 +879,7 @@ define void @loop_with_external_entry(i1 %c1, i1 %c2) "aarch64_inout_za" nounwin
 ; CHECK-NEWLOWERING-LABEL: loop_with_external_entry:
 ; CHECK-NEWLOWERING:       // %bb.0: // %entry
 ; CHECK-NEWLOWERING-NEXT:    stp x29, x30, [sp, #-32]! // 16-byte Folded Spill
-; CHECK-NEWLOWERING-NEXT:    str x19, [sp, #16] // 8-byte Folded Spill
+; CHECK-NEWLOWERING-NEXT:    stp x20, x19, [sp, #16] // 16-byte Folded Spill
 ; CHECK-NEWLOWERING-NEXT:    mov x29, sp
 ; CHECK-NEWLOWERING-NEXT:    sub sp, sp, #16
 ; CHECK-NEWLOWERING-NEXT:    rdsvl x8, #1
@@ -923,23 +892,27 @@ define void @loop_with_external_entry(i1 %c1, i1 %c2) "aarch64_inout_za" nounwin
 ; CHECK-NEWLOWERING-NEXT:  // %bb.1: // %init
 ; CHECK-NEWLOWERING-NEXT:    bl shared_za_call
 ; CHECK-NEWLOWERING-NEXT:  .LBB11_2: // %loop.preheader
-; CHECK-NEWLOWERING-NEXT:    sub x8, x29, #16
-; CHECK-NEWLOWERING-NEXT:    msr TPIDR2_EL0, x8
+; CHECK-NEWLOWERING-NEXT:    sub x20, x29, #16
+; CHECK-NEWLOWERING-NEXT:    b .LBB11_4
 ; CHECK-NEWLOWERING-NEXT:  .LBB11_3: // %loop
+; CHECK-NEWLOWERING-NEXT:    // in Loop: Header=BB11_4 Depth=1
+; CHECK-NEWLOWERING-NEXT:    msr TPIDR2_EL0, xzr
+; CHECK-NEWLOWERING-NEXT:    tbz w19, #0, .LBB11_6
+; CHECK-NEWLOWERING-NEXT:  .LBB11_4: // %loop
 ; CHECK-NEWLOWERING-NEXT:    // =>This Inner Loop Header: Depth=1
+; CHECK-NEWLOWERING-NEXT:    msr TPIDR2_EL0, x20
 ; CHECK-NEWLOWERING-NEXT:    bl private_za_call
-; CHECK-NEWLOWERING-NEXT:    tbnz w19, #0, .LBB11_3
-; CHECK-NEWLOWERING-NEXT:  // %bb.4: // %exit
 ; CHECK-NEWLOWERING-NEXT:    smstart za
 ; CHECK-NEWLOWERING-NEXT:    mrs x8, TPIDR2_EL0
 ; CHECK-NEWLOWERING-NEXT:    sub x0, x29, #16
-; CHECK-NEWLOWERING-NEXT:    cbnz x8, .LBB11_6
-; CHECK-NEWLOWERING-NEXT:  // %bb.5: // %exit
+; CHECK-NEWLOWERING-NEXT:    cbnz x8, .LBB11_3
+; CHECK-NEWLOWERING-NEXT:  // %bb.5: // %loop
+; CHECK-NEWLOWERING-NEXT:    // in Loop: Header=BB11_4 Depth=1
 ; CHECK-NEWLOWERING-NEXT:    bl __arm_tpidr2_restore
+; CHECK-NEWLOWERING-NEXT:    b .LBB11_3
 ; CHECK-NEWLOWERING-NEXT:  .LBB11_6: // %exit
-; CHECK-NEWLOWERING-NEXT:    msr TPIDR2_EL0, xzr
 ; CHECK-NEWLOWERING-NEXT:    mov sp, x29
-; CHECK-NEWLOWERING-NEXT:    ldr x19, [sp, #16] // 8-byte Folded Reload
+; CHECK-NEWLOWERING-NEXT:    ldp x20, x19, [sp, #16] // 16-byte Folded Reload
 ; CHECK-NEWLOWERING-NEXT:    ldp x29, x30, [sp], #32 // 16-byte Folded Reload
 ; CHECK-NEWLOWERING-NEXT:    ret
 entry:
diff --git a/llvm/test/CodeGen/AArch64/sme-za-exceptions.ll b/llvm/test/CodeGen/AArch64/sme-za-exceptions.ll
index bb88142efa592..506974a14c3be 100644
--- a/llvm/test/CodeGen/AArch64/sme-za-exceptions.ll
+++ b/llvm/test/CodeGen/AArch64/sme-za-exceptions.ll
@@ -56,31 +56,23 @@ define void @za_with_raii(i1 %fail) "aarch64_inout_za" personality ptr @__gxx_pe
 ; CHECK-NEXT:    adrp x8, .L.str
 ; CHECK-NEXT:    add x8, x8, :lo12:.L.str
 ; CHECK-NEXT:    str x8, [x0]
-; CHECK-NEXT:  .Ltmp0:
+; CHECK-NEXT:  .Ltmp0: // EH_LABEL
 ; CHECK-NEXT:    adrp x1, :got:typeinfo_for_char_const_ptr
 ; CHECK-NEXT:    mov x2, xzr
 ; CHECK-NEXT:    ldr x1, [x1, :got_lo12:typeinfo_for_char_const_ptr]
 ; CHECK-NEXT:    bl __cxa_throw
-; CHECK-NEXT:  .Ltmp1:
-; CHECK-NEXT:    smstart za
-; CHECK-NEXT:    mrs x8, TPIDR2_EL0
-; CHECK-NEXT:    sub x0, x29, #16
-; CHECK-NEXT:    cbnz x8, .LBB0_4
-; CHECK-NEXT:  // %bb.3: // %throw_exception
-; CHECK-NEXT:    bl __arm_tpidr2_restore
-; CHECK-NEXT:  .LBB0_4: // %throw_exception
-; CHECK-NEXT:    msr TPIDR2_EL0, xzr
-; CHECK-NEXT:  // %bb.5: // %throw_fail
-; CHECK-NEXT:  .LBB0_6: // %unwind_dtors
-; CHECK-NEXT:  .Ltmp2:
+; CHECK-NEXT:  .Ltmp1: // EH_LABEL
+; CHECK-NEXT:  // %bb.3: // %throw_fail
+; CHECK-NEXT:  .LBB0_4: // %unwind_dtors
+; CHECK-NEXT:  .Ltmp2: // EH_LABEL
 ; CHECK-NEXT:    mov x19, x0
 ; CHECK-NEXT:    smstart za
 ; CHECK-NEXT:    mrs x8, TPIDR2_EL0
 ; CHECK-NEXT:    sub x0, x29, #16
-; CHECK-NEXT:    cbnz x8, .LBB0_8
-; CHECK-NEXT:  // %bb.7: // %unwind_dtors
+; CHECK-NEXT:    cbnz x8, .LBB0_6
+; CHECK-NEXT:  // %bb.5: // %unwind_dtors
 ; CHECK-NEXT:    bl __arm_tpidr2_restore
-; CHECK-NEXT:  .LBB0_8: // %unwind_dtors
+; CHECK-NEXT:  .LBB0_6: // %unwind_dtors
 ; CHECK-NEXT:    msr TPIDR2_EL0, xzr
 ; CHECK-NEXT:    bl shared_za_call
 ; CHECK-NEXT:    sub x8, x29, #16
@@ -142,11 +134,11 @@ define dso_local void @try_catch() "aarch64_inout_za" personality ptr @__gxx_per
 ; CHECK-NEXT:    msub x9, x8, x8, x9
 ; CHECK-NEXT:    mov sp, x9
 ; CHECK-NEXT:    stp x9, x8, [x29, #-16]
-; CHECK-NEXT:  .Ltmp3:
+; CHECK-NEXT:  .Ltmp3: // EH_LABEL
 ; CHECK-NEXT:    sub x8, x29, #16
 ; CHECK-NEXT:    msr TPIDR2_EL0, x8
 ; CHECK-NEXT:    bl may_throw
-; CHECK-NEXT:  .Ltmp4:
+; CHECK-NEXT:  .Ltmp4: // EH_LABEL
 ; CHECK-NEXT:  .LBB1_1: // %after_catch
 ; CHECK-NEXT:    smstart za
 ; CHECK-NEXT:    mrs x8, TPIDR2_EL0
@@ -160,7 +152,7 @@ define dso_local void @try_catch() "aarch64_inout_za" personality ptr @__gxx_per
 ; CHECK-NEXT:    ldp x29, x30, [sp], #16 // 16-byte Folded Reload
 ; CHECK-NEXT:    b shared_za_call
 ; CHECK-NEXT:  .LBB1_4: // %catch
-; CHECK-NEXT:  .Ltmp5:
+; CHECK-NEXT:  .Ltmp5: // EH_LABEL
 ; CHECK-NEXT:    bl __cxa_begin_catch
 ; CHECK-NEXT:    smstart za
 ; CHECK-NEXT:    mrs x8, TPIDR2_EL0
@@ -235,16 +227,16 @@ define void @try_catch_shared_za_callee() "aarch64_new_za" personality ptr @__gx
 ; CHECK-NEXT:    zero {za}
 ; CHECK-NEXT:  .LBB2_2:
 ; CHECK-NEXT:    smstart za
-; CHECK-NEXT:  .Ltmp6:
+; CHECK-NEXT:  .Ltmp6: // EH_LABEL
 ; CHECK-NEXT:    bl shared_za_call
-; CHECK-NEXT:  .Ltmp7:
+; CHECK-NEXT:  .Ltmp7: // EH_LABEL
 ; CHECK-NEXT:  .LBB2_3: // %exit
 ; CHECK-NEXT:    smstop za
 ; CHECK-NEXT:    mov sp, x29
 ; CHECK-NEXT:    ldp x29, x30, [sp], #16 // 16-byte Folded Reload
 ; CHECK-NEXT:    ret
 ; CHECK-NEXT:  .LBB2_4: // %catch
-; CHECK-NEXT:  .Ltmp8:
+; CHECK-NEXT:  .Ltmp8: // EH_LABEL
 ; CHECK-NEXT:    bl __cxa_begin_catch
 ; CHECK-NEXT:    smstart za
 ; CHECK-NEXT:    mrs x8, TPIDR2_EL0
diff --git a/llvm/test/CodeGen/AArch64/sme-za-function-with-many-blocks.ll b/llvm/test/CodeGen/AArch64/sme-za-function-with-many-blocks.ll
new file mode 100644
index 0000000000000..0306b27cb17e1
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sme-za-function-with-many-blocks.ll
@@ -0,0 +1,296 @@
+; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sme2 -aarch64-new-sme-abi < %s | FileCheck %s
+
+; This test case was generated by lowering mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul.mlir to LLVM IR.
+; The actual contents of the function are not that important. The main interesting quality here is that many blocks
+; don't directly use ZA. The only blocks that require ZA are the MOPA (and load/stores) in the inner loop, and the
+; `printMemrefF32()` call in the exit block.
+;
+; If ZA states are not propagated in the MachineSMEABIPass, block %48 (which is within the outer loop) will
+; have an edge to block %226 (the exit block), which requires ZA in the "saved" state, and an edge to block %51
+; (which has no preference on ZA state). This means block %48 will also end up in the locally saved state.
+; This is not really what we want, as it means we will save/restore ZA in the outer loop. We can fix this by
+; propagating the "active" state from the inner loop through basic blocks with no preference, to ensure the outer
+; loop is in the "active" state too.
+;
+; If done correctly, the only ZA save/restore should be in the exit block (with all other blocks in the active state).
+
+define void @matmul(ptr %0, ptr %1, i64 %2, i64 %3, i64 %4, i64 %5, i64 %6, ptr %7, ptr %8, i64 %9, i64 %10, i64 %11, i64 %12, i64 %13, ptr %14, ptr %15, i64 %16, i64 %17, i64 %18, i64 %19, i64 %20) #0 {
+; Check for a ZA zero in the entry block, then no uses of TPIDR2_EL0 (for ZA saves/restore)
+; until the exit block (which contains the call to printMemrefF32).
+;
+; CHECK-LABEL: matmul:
+; CHECK:      zero {za}
+; CHECK-NOT:  TPIDR2_EL0
+; CHECK:      msr TPIDR2_EL0, x{{.*}}
+; CHECK-NOT:  .LBB{{.*}}
+; CHECK:      bl printMemrefF32
+  %22 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } poison, ptr %14, 0
+  %23 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %22, ptr %15, 1
+  %24 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %23, i64 %16, 2
+  %25 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %24, i64 %17, 3, 0
+  %26 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %25, i64 %19, 4, 0
+  %27 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %26, i64 %18, 3, 1
+  %28 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %27, i64 %20, 4, 1
+  %29 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } poison, ptr %7, 0
+  %30 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %29, ptr %8, 1
+  %31 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %30, i64 %9, 2
+  %32 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %31, i64 %10, 3, 0
+  %33 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %32, i64 %12, 4, 0
+  %34 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %33, i64 %11, 3, 1
+  %35 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %34, i64 %13, 4, 1
+  %36 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } poison, ptr %0, 0
+  %37 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %36, ptr %1, 1
+  %38 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %37, i64 %2, 2
+  %39 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %38, i64 %3, 3, 0
+  %40 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %39, i64 %5, 4, 0
+  %41 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %40, i64 %4, 3, 1
+  %42 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %41, i64 %6, 4, 1
+  %43 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %42, 3, 0
+  %44 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %42, 3, 1
+  %45 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %35, 3, 1
+  %46 = call i64 @llvm.vscale.i64()
+  %47 = mul i64 %46, 4
+  br label %48
+
+48:                                               ; preds = %224, %21
+  %49 = phi i64 [ %225, %224 ], [ 0, %21 ]
+  %50 = icmp slt i64 %49, %43
+  br i1 %50, label %51, label %226
+
+51:                                               ; preds = %48
+  %52 = sub i64 %43, %49
+  %53 = call i64 @llvm.smin.i64(i64 %47, i64 %52)
+  %54 = call <vscale x 4 x i32> @llvm.stepvector.nxv4i32()
+  %55 = trunc i64 %53 to i32
+  %56 = insertelement <vscale x 4 x i32> poison, i32 %55, i32 0
+  %57 = shufflevector <vscale x 4 x i32> %56, <vscale x 4 x i32> poison, <vscale x 4 x i32> zeroinitializer
+  %58 = icmp slt <vscale x 4 x i32> %54, %57
+  br label %59
+
+59:                                               ; preds = %222, %51
+  %60 = phi i64 [ %223, %222 ], [ 0, %51 ]
+  %61 = icmp slt i64 %60, %45
+  br i1 %61, label %62, label %224
+
+62:                                               ; preds = %59
+  %63 = sub i64 %45, %60
+  %64 = call i64 @llvm.smin.i64(i64 %47, i64 %63)
+  %65 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %28, 0
+  %66 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %28, 1
+  %67 = insertvalue { ptr, ptr, i64 } poison, ptr %65, 0
+  %68 = insertvalue { ptr, ptr, i64 } %67, ptr %66, 1
+  %69 = insertvalue { ptr, ptr, i64 } %68, i64 0, 2
+  %70 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %28, 2
+  %71 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %28, 3, 0
+  %72 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %28, 3, 1
+  %73 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %28, 4, 0
+  %74 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %28, 4, 1
+  %75 = mul nsw i64 %49, %73
+  %76 = add i64 %70, %75
+  %77 = mul nsw i64 %60, %74
+  %78 = add i64 %76, %77
+  %79 = extractvalue { ptr, ptr, i64 } %69, 0
+  %80 = extractvalue { ptr, ptr, i64 } %69, 1
+  %81 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } poison, ptr %79, 0
+  %82 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %81, ptr %80, 1
+  %83 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %82, i64 %78, 2
+  %84 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %83, i64 %53, 3, 0
+  %85 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %84, i64 %73, 4, 0
+  %86 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %85, i64 %64, 3, 1
+  %87 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %86, i64 %74, 4, 1
+  %88 = call <vscale x 4 x i32> @llvm.stepvector.nxv4i32()
+  %89 = trunc i64 %64 to i32
+  %90 = insertelement <vscale x 4 x i32> poison, i32 %89, i32 0
+  %91 = shufflevector <vscale x 4 x i32> %90, <vscale x 4 x i32> poison, <vscale x 4 x i32> zeroinitializer
+  %92 = icmp slt <vscale x 4 x i32> %88, %91
+  br label %93
+
+93:                                               ; preds = %220, %62
+  %94 = phi i64 [ %221, %220 ], [ 0, %62 ]
+  %95 = icmp slt i64 %94, %44
+  br i1 %95, label %96, label %222
+
+96:                                               ; preds = %93
+  %97 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %42, 0
+  %98 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %42, 1
+  %99 = insertvalue { ptr, ptr, i64 } poison, ptr %97, 0
+  %100 = insertvalue { ptr, ptr, i64 } %99, ptr %98, 1
+  %101 = insertvalue { ptr, ptr, i64 } %100, i64 0, 2
+  %102 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %42, 2
+  %103 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %42, 3, 0
+  %104 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %42, 3, 1
+  %105 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %42, 4, 0
+  %106 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %42, 4, 1
+  %107 = mul nsw i64 %49, %105
+  %108 = add i64 %102, %107
+  %109 = mul nsw i64 %94, %106
+  %110 = add i64 %108, %109
+  %111 = extractvalue { ptr, ptr, i64 } %101, 0
+  %112 = extractvalue { ptr, ptr, i64 } %101, 1
+  %113 = insertvalue { ptr, ptr, i64, [1 x i64], [1 x i64] } poison, ptr %111, 0
+  %114 = insertvalue { ptr, ptr, i64, [1 x i64], [1 x i64] } %113, ptr %112, 1
+  %115 = insertvalue { ptr, ptr, i64, [1 x i64], [1 x i64] } %114, i64 %110, 2
+  %116 = insertvalue { ptr, ptr, i64, [1 x i64], [1 x i64] } %115, i64 %53, 3, 0
+  %117 = insertvalue { ptr, ptr, i64, [1 x i64], [1 x i64] } %116, i64 %105, 4, 0
+  br label %118
+
+118:                                              ; preds = %133, %96
+  %119 = phi i64 [ %135, %133 ], [ 0, %96 ]
+  %120 = phi <vscale x 4 x float> [ %134, %133 ], [ poison, %96 ]
+  %121 = icmp slt i64 %119, %47
+  br i1 %121, label %122, label %136
+
+122:                                              ; preds = %118
+  %123 = extractelement <vscale x 4 x i1> %58, i64 %119
+  br i1 %123, label %124, label %133
+
+124:                                              ; preds = %122
+  %125 = extractvalue { ptr, ptr, i64, [1 x i64], [1 x i64] } %117, 1
+  %126 = extractvalue { ptr, ptr, i64, [1 x i64], [1 x i64] } %117, 2
+  %127 = getelementptr float, ptr %125, i64 %126
+  %128 = extractvalue { ptr, ptr, i64, [1 x i64], [1 x i64] } %117, 4, 0
+  %129 = mul nuw nsw i64 %119, %128
+  %130 = getelementptr inbounds nuw float, ptr %127, i64 %129
+  %131 = load float, ptr %130, align 4
+  %132 = insertelement <vscale x 4 x float> %120, float %131, i64 %119
+  br label %133
+
+133:                                              ; preds = %124, %122
+  %134 = phi <vscale x 4 x float> [ %132, %124 ], [ %120, %122 ]
+  %135 = add i64 %119, 1
+  br label %118
+
+136:                                              ; preds = %118
+  %137 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %35, 0
+  %138 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %35, 1
+  %139 = insertvalue { ptr, ptr, i64 } poison, ptr %137, 0
+  %140 = insertvalue { ptr, ptr, i64 } %139, ptr %138, 1
+  %141 = insertvalue { ptr, ptr, i64 } %140, i64 0, 2
+  %142 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %35, 2
+  %143 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %35, 3, 0
+  %144 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %35, 3, 1
+  %145 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %35, 4, 0
+  %146 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %35, 4, 1
+  %147 = mul nsw i64 %94, %145
+  %148 = add i64 %142, %147
+  %149 = mul nsw i64 %60, %146
+  %150 = add i64 %148, %149
+  %151 = extractvalue { ptr, ptr, i64 } %141, 0
+  %152 = extractvalue { ptr, ptr, i64 } %141, 1
+  %153 = insertvalue { ptr, ptr, i64, [1 x i64], [1 x i64] } poison, ptr %151, 0
+  %154 = insertvalue { ptr, ptr, i64, [1 x i64], [1 x i64] } %153, ptr %152, 1
+  %155 = insertvalue { ptr, ptr, i64, [1 x i64], [1 x i64] } %154, i64 %150, 2
+  %156 = insertvalue { ptr, ptr, i64, [1 x i64], [1 x i64] } %155, i64 %64, 3, 0
+  %157 = insertvalue { ptr, ptr, i64, [1 x i64], [1 x i64] } %156, i64 %146, 4, 0
+  br label %158
+
+158:                                              ; preds = %173, %136
+  %159 = phi i64 [ %175, %173 ], [ 0, %136 ]
+  %160 = phi <vscale x 4 x float> [ %174, %173 ], [ poison, %136 ]
+  %161 = icmp slt i64 %159, %47
+  br i1 %161, label %162, label %176
+
+162:                                              ; preds = %158
+  %163 = extractelement <vscale x 4 x i1> %92, i64 %159
+  br i1 %163, label %164, label %173
+
+164:                                              ; preds = %162
+  %165 = extractvalue { ptr, ptr, i64, [1 x i64], [1 x i64] } %157, 1
+  %166 = extractvalue { ptr, ptr, i64, [1 x i64], [1 x i64] } %157, 2
+  %167 = getelementptr float, ptr %165, i64 %166
+  %168 = extractvalue { ptr, ptr, i64, [1 x i64], [1 x i64] } %157, 4, 0
+  %169 = mul nuw nsw i64 %159, %168
+  %170 = getelementptr inbounds nuw float, ptr %167, i64 %169
+  %171 = load float, ptr %170, align 4
+  %172 = insertelement <vscale x 4 x float> %160, float %171, i64 %159
+  br label %173
+
+173:                                              ; preds = %164, %162
+  %174 = phi <vscale x 4 x float> [ %172, %164 ], [ %160, %162 ]
+  %175 = add i64 %159, 1
+  br label %158
+
+176:                                              ; preds = %158
+  %177 = trunc i64 %64 to i32
+  br label %178
+
+178:                                              ; preds = %181, %176
+  %179 = phi i64 [ %202, %181 ], [ 0, %176 ]
+  %180 = icmp slt i64 %179, %47
+  br i1 %180, label %181, label %203
+
+181:                                              ; preds = %178
+  %182 = icmp ult i64 %179, %53
+  %183 = sext i1 %182 to i32
+  %184 = and i32 %183, %177
+  %185 = sext i32 %184 to i64
+  %186 = call <vscale x 4 x i32> @llvm.stepvector.nxv4i32()
+  %187 = trunc i64 %185 to i32
+  %188 = insertelement <vscale x 4 x i32> poison, i32 %187, i32 0
+  %189 = shufflevector <vscale x 4 x i32> %188, <vscale x 4 x i32> poison, <vscale x 4 x i32> zeroinitializer
+  %190 = icmp slt <vscale x 4 x i32> %186, %189
+  %191 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %87, 1
+  %192 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %87, 2
+  %193 = getelementptr float, ptr %191, i64 %192
+  %194 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %87, 4, 0
+  %195 = mul i64 %179, %194
+  %196 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %87, 4, 1
+  %197 = mul i64 0, %196
+  %198 = add i64 %195, %197
+  %199 = getelementptr float, ptr %193, i64 %198
+  %200 = call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr %199, i32 4, <vscale x 4 x i1> %190, <vscale x 4 x float> poison)
+  %201 = trunc i64 %179 to i32
+  call void @llvm.aarch64.sme.write.horiz.nxv4f32(i32 0, i32 %201, <vscale x 4 x i1> splat (i1 true), <vscale x 4 x float> %200)
+  %202 = add i64 %179, 1
+  br label %178
+
+203:                                              ; preds = %178
+  call void @llvm.aarch64.sme.mopa.nxv4f32(i32 0, <vscale x 4 x i1> %58, <vscale x 4 x i1> %92, <vscale x 4 x float> %120, <vscale x 4 x float> %160)
+  %204 = call i64 @llvm.smin.i64(i64 %53, i64 %47)
+  br label %205
+
+205:                                              ; preds = %208, %203
+  %206 = phi i64 [ %219, %208 ], [ 0, %203 ]
+  %207 = icmp slt i64 %206, %204
+  br i1 %207, label %208, label %220
+
+208:                                              ; preds = %205
+  %209 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %87, 1
+  %210 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %87, 2
+  %211 = getelementptr float, ptr %209, i64 %210
+  %212 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %87, 4, 0
+  %213 = mul i64 %206, %212
+  %214 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %87, 4, 1
+  %215 = mul i64 0, %214
+  %216 = add i64 %213, %215
+  %217 = getelementptr float, ptr %211, i64 %216
+  %218 = trunc i64 %206 to i32
+  call void @llvm.aarch64.sme.st1w.horiz(<vscale x 4 x i1> %92, ptr %217, i32 0, i32 %218)
+  %219 = add i64 %206, 1
+  br label %205
+
+220:                                              ; preds = %205
+  %221 = add i64 %94, 1
+  br label %93
+
+222:                                              ; preds = %93
+  %223 = add i64 %60, %47
+  br label %59
+
+224:                                              ; preds = %59
+  %225 = add i64 %49, %47
+  br label %48
+
+226:                                              ; preds = %48
+  %227 = alloca { ptr, ptr, i64, [2 x i64], [2 x i64] }, i64 1, align 8
+  store { ptr, ptr, i64, [2 x i64], [2 x i64] } %28, ptr %227, align 8
+  %228 = insertvalue { i64, ptr } { i64 2, ptr poison }, ptr %227, 1
+  %229 = extractvalue { i64, ptr } %228, 0
+  %230 = extractvalue { i64, ptr } %228, 1
+  call void @printMemrefF32(i64 %229, ptr %230)
+  ret void
+}
+
+declare void @printMemrefF32(i64, ptr)
+
+attributes #0 = { "aarch64_new_za" "aarch64_pstate_sm_body" }
diff --git a/llvm/test/CodeGen/AArch64/sme-za-lazy-save-buffer.ll b/llvm/test/CodeGen/AArch64/sme-za-lazy-save-buffer.ll
index 066ee3b040469..afd56d198d0d3 100644
--- a/llvm/test/CodeGen/AArch64/sme-za-lazy-save-buffer.ll
+++ b/llvm/test/CodeGen/AArch64/sme-za-lazy-save-buffer.ll
@@ -12,77 +12,41 @@ entry:
 }
 
 define float @multi_bb_stpidr2_save_required(i32 %a, float %b, float %c) "aarch64_inout_za" {
-; CHECK-LABEL: multi_bb_stpidr2_save_required:
-; CHECK:       // %bb.0:
-; CHECK-NEXT:    stp x29, x30, [sp, #-16]! // 16-byte Folded Spill
-; CHECK-NEXT:    mov x29, sp
-; CHECK-NEXT:    sub sp, sp, #16
-; CHECK-NEXT:    .cfi_def_cfa w29, 16
-; CHECK-NEXT:    .cfi_offset w30, -8
-; CHECK-NEXT:    .cfi_offset w29, -16
-; CHECK-NEXT:    rdsvl x8, #1
-; CHECK-NEXT:    mov x9, sp
-; CHECK-NEXT:    msub x9, x8, x8, x9
-; CHECK-NEXT:    mov sp, x9
-; CHECK-NEXT:    stp x9, x8, [x29, #-16]
-; CHECK-NEXT:    cbz w0, .LBB1_2
-; CHECK-NEXT:  // %bb.1: // %use_b
-; CHECK-NEXT:    fmov s1, #4.00000000
-; CHECK-NEXT:    fadd s0, s0, s1
-; CHECK-NEXT:    b .LBB1_5
-; CHECK-NEXT:  .LBB1_2: // %use_c
-; CHECK-NEXT:    fmov s0, s1
-; CHECK-NEXT:    sub x8, x29, #16
-; CHECK-NEXT:    msr TPIDR2_EL0, x8
-; CHECK-NEXT:    bl cosf
-; CHECK-NEXT:    smstart za
-; CHECK-NEXT:    mrs x8, TPIDR2_EL0
-; CHECK-NEXT:    sub x0, x29, #16
-; CHECK-NEXT:    cbnz x8, .LBB1_4
-; CHECK-NEXT:  // %bb.3: // %use_c
-; CHECK-NEXT:    bl __arm_tpidr2_restore
-; CHECK-NEXT:  .LBB1_4: // %use_c
-; CHECK-NEXT:    msr TPIDR2_EL0, xzr
-; CHECK-NEXT:  .LBB1_5: // %exit
-; CHECK-NEXT:    mov sp, x29
-; CHECK-NEXT:    ldp x29, x30, [sp], #16 // 16-byte Folded Reload
-; CHECK-NEXT:    ret
-;
-; CHECK-NEWLOWERING-LABEL: multi_bb_stpidr2_save_required:
-; CHECK-NEWLOWERING:       // %bb.0:
-; CHECK-NEWLOWERING-NEXT:    stp x29, x30, [sp, #-16]! // 16-byte Folded Spill
-; CHECK-NEWLOWERING-NEXT:    mov x29, sp
-; CHECK-NEWLOWERING-NEXT:    sub sp, sp, #16
-; CHECK-NEWLOWERING-NEXT:    .cfi_def_cfa w29, 16
-; CHECK-NEWLOWERING-NEXT:    .cfi_offset w30, -8
-; CHECK-NEWLOWERING-NEXT:    .cfi_offset w29, -16
-; CHECK-NEWLOWERING-NEXT:    rdsvl x8, #1
-; CHECK-NEWLOWERING-NEXT:    mov x9, sp
-; CHECK-NEWLOWERING-NEXT:    msub x9, x8, x8, x9
-; CHECK-NEWLOWERING-NEXT:    mov sp, x9
-; CHECK-NEWLOWERING-NEXT:    sub x10, x29, #16
-; CHECK-NEWLOWERING-NEXT:    stp x9, x8, [x29, #-16]
-; CHECK-NEWLOWERING-NEXT:    msr TPIDR2_EL0, x10
-; CHECK-NEWLOWERING-NEXT:    cbz w0, .LBB1_2
-; CHECK-NEWLOWERING-NEXT:  // %bb.1: // %use_b
-; CHECK-NEWLOWERING-NEXT:    fmov s1, #4.00000000
-; CHECK-NEWLOWERING-NEXT:    fadd s0, s0, s1
-; CHECK-NEWLOWERING-NEXT:    b .LBB1_3
-; CHECK-NEWLOWERING-NEXT:  .LBB1_2: // %use_c
-; CHECK-NEWLOWERING-NEXT:    fmov s0, s1
-; CHECK-NEWLOWERING-NEXT:    bl cosf
-; CHECK-NEWLOWERING-NEXT:  .LBB1_3: // %exit
-; CHECK-NEWLOWERING-NEXT:    smstart za
-; CHECK-NEWLOWERING-NEXT:    mrs x8, TPIDR2_EL0
-; CHECK-NEWLOWERING-NEXT:    sub x0, x29, #16
-; CHECK-NEWLOWERING-NEXT:    cbnz x8, .LBB1_5
-; CHECK-NEWLOWERING-NEXT:  // %bb.4: // %exit
-; CHECK-NEWLOWERING-NEXT:    bl __arm_tpidr2_restore
-; CHECK-NEWLOWERING-NEXT:  .LBB1_5: // %exit
-; CHECK-NEWLOWERING-NEXT:    msr TPIDR2_EL0, xzr
-; CHECK-NEWLOWERING-NEXT:    mov sp, x29
-; CHECK-NEWLOWERING-NEXT:    ldp x29, x30, [sp], #16 // 16-byte Folded Reload
-; CHECK-NEWLOWERING-NEXT:    ret
+; CHECK-COMMON-LABEL: multi_bb_stpidr2_save_required:
+; CHECK-COMMON:       // %bb.0:
+; CHECK-COMMON-NEXT:    stp x29, x30, [sp, #-16]! // 16-byte Folded Spill
+; CHECK-COMMON-NEXT:    mov x29, sp
+; CHECK-COMMON-NEXT:    sub sp, sp, #16
+; CHECK-COMMON-NEXT:    .cfi_def_cfa w29, 16
+; CHECK-COMMON-NEXT:    .cfi_offset w30, -8
+; CHECK-COMMON-NEXT:    .cfi_offset w29, -16
+; CHECK-COMMON-NEXT:    rdsvl x8, #1
+; CHECK-COMMON-NEXT:    mov x9, sp
+; CHECK-COMMON-NEXT:    msub x9, x8, x8, x9
+; CHECK-COMMON-NEXT:    mov sp, x9
+; CHECK-COMMON-NEXT:    stp x9, x8, [x29, #-16]
+; CHECK-COMMON-NEXT:    cbz w0, .LBB1_2
+; CHECK-COMMON-NEXT:  // %bb.1: // %use_b
+; CHECK-COMMON-NEXT:    fmov s1, #4.00000000
+; CHECK-COMMON-NEXT:    fadd s0, s0, s1
+; CHECK-COMMON-NEXT:    b .LBB1_5
+; CHECK-COMMON-NEXT:  .LBB1_2: // %use_c
+; CHECK-COMMON-NEXT:    fmov s0, s1
+; CHECK-COMMON-NEXT:    sub x8, x29, #16
+; CHECK-COMMON-NEXT:    msr TPIDR2_EL0, x8
+; CHECK-COMMON-NEXT:    bl cosf
+; CHECK-COMMON-NEXT:    smstart za
+; CHECK-COMMON-NEXT:    mrs x8, TPIDR2_EL0
+; CHECK-COMMON-NEXT:    sub x0, x29, #16
+; CHECK-COMMON-NEXT:    cbnz x8, .LBB1_4
+; CHECK-COMMON-NEXT:  // %bb.3: // %use_c
+; CHECK-COMMON-NEXT:    bl __arm_tpidr2_restore
+; CHECK-COMMON-NEXT:  .LBB1_4: // %use_c
+; CHECK-COMMON-NEXT:    msr TPIDR2_EL0, xzr
+; CHECK-COMMON-NEXT:  .LBB1_5: // %exit
+; CHECK-COMMON-NEXT:    mov sp, x29
+; CHECK-COMMON-NEXT:    ldp x29, x30, [sp], #16 // 16-byte Folded Reload
+; CHECK-COMMON-NEXT:    ret
   %cmp = icmp ne i32 %a, 0
   br i1 %cmp, label %use_b, label %use_c
 
@@ -155,7 +119,9 @@ define float @multi_bb_stpidr2_save_required_stackprobe(i32 %a, float %b, float
 ; CHECK-NEWLOWERING-NEXT:    .cfi_offset w29, -16
 ; CHECK-NEWLOWERING-NEXT:    rdsvl x8, #1
 ; CHECK-NEWLOWERING-NEXT:    mov x9, sp
+; CHECK-NEWLOWERING-NEXT:    sub x10, x29, #16
 ; CHECK-NEWLOWERING-NEXT:    msub x9, x8, x8, x9
+; CHECK-NEWLOWERING-NEXT:    msr TPIDR2_EL0, x10
 ; CHECK-NEWLOWERING-NEXT:  .LBB2_1: // =>This Inner Loop Header: Depth=1
 ; CHECK-NEWLOWERING-NEXT:    sub sp, sp, #16, lsl #12 // =65536
 ; CHECK-NEWLOWERING-NEXT:    cmp sp, x9
@@ -166,9 +132,7 @@ define float @multi_bb_stpidr2_save_required_stackprobe(i32 %a, float %b, float
 ; CHECK-NEWLOWERING-NEXT:  .LBB2_3:
 ; CHECK-NEWLOWERING-NEXT:    mov sp, x9
 ; CHECK-NEWLOWERING-NEXT:    ldr xzr, [sp]
-; CHECK-NEWLOWERING-NEXT:    sub x10, x29, #16
 ; CHECK-NEWLOWERING-NEXT:    stp x9, x8, [x29, #-16]
-; CHECK-NEWLOWERING-NEXT:    msr TPIDR2_EL0, x10
 ; CHECK-NEWLOWERING-NEXT:    cbz w0, .LBB2_5
 ; CHECK-NEWLOWERING-NEXT:  // %bb.4: // %use_b
 ; CHECK-NEWLOWERING-NEXT:    fmov s1, #4.00000000



More information about the llvm-commits mailing list