[llvm] [AArch64][SME] Avoid ZA save state changes in loops in MachineSMEABIPass (PR #149065)

Benjamin Maxwell via llvm-commits llvm-commits at lists.llvm.org
Tue Oct 28 03:18:21 PDT 2025


https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/149065

From 72b89567ebf09c7340dbb692e8bd1b2673a9769d Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Fri, 5 Sep 2025 12:15:28 +0000
Subject: [PATCH] [AArch64][SME] Avoid ZA save state changes in loops in
 MachineSMEABIPass

This patch uses MachineLoopInfo to give blocks within loops a higher
weight when choosing a bundle's ZA state. MachineLoopInfo does not
provide loop trip counts, so this uses an arbitrary weight (default 10),
which can be configured with the `-aarch64-sme-abi-loop-edge-weight`
flag.

This makes MachineSMEABIPass more likely to pick a bundle state that
matches the loop's entry/exit state, avoiding state changes inside the
loop (which we assume will execute more than once).

This requires some extra analysis, so it is only enabled at -O1 and
above.
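
For illustration, a minimal self-contained sketch of the weighted vote
(the `BlockVote` struct and `pickBundleState` helper are hypothetical
simplifications, not code from this patch; the real pass accumulates
weights per edge bundle over the legal desired incoming/outgoing states
of each block):

  // Hedged sketch, not the pass's actual code: pick a bundle's ZA
  // state by a weighted vote, where blocks inside loops vote more
  // strongly than blocks outside loops.
  #include <array>
  #include <vector>

  enum ZAState { ANY, ACTIVE, LOCAL_SAVED, NUM_ZA_STATE };

  struct BlockVote {
    ZAState Desired; // a block's desired incoming or outgoing state
    bool InLoop;     // MLI && MLI->getLoopFor(MBB) in the real pass
  };

  static ZAState pickBundleState(const std::vector<BlockVote> &Votes,
                                 int LoopEdgeWeight = 10) {
    std::array<int, NUM_ZA_STATE> Counts{};
    for (const BlockVote &V : Votes)
      Counts[V.Desired] += V.InLoop ? LoopEdgeWeight : 1;
    // Assign the state with the highest weighted count; loop blocks
    // dominate the vote, so any save/restore transitions tend to be
    // placed outside the loop.
    int Best = ANY;
    for (int S = ANY; S != NUM_ZA_STATE; ++S)
      if (Counts[S] > Counts[Best])
        Best = S;
    return static_cast<ZAState>(Best);
  }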

Change-Id: If318c809d2f7cc1fca144fbe424ba2a2ca7fb19f
---
 llvm/lib/Target/AArch64/MachineSMEABIPass.cpp |  25 +++-
 .../CodeGen/AArch64/sme-lazy-save-in-loop.ll  | 113 ++++++++++++++++++
 .../CodeGen/AArch64/sme-za-control-flow.ll    |  55 ++++-----
 3 files changed, 158 insertions(+), 35 deletions(-)
 create mode 100644 llvm/test/CodeGen/AArch64/sme-lazy-save-in-loop.ll

diff --git a/llvm/lib/Target/AArch64/MachineSMEABIPass.cpp b/llvm/lib/Target/AArch64/MachineSMEABIPass.cpp
index 7cb500394cec2..03570dba09caa 100644
--- a/llvm/lib/Target/AArch64/MachineSMEABIPass.cpp
+++ b/llvm/lib/Target/AArch64/MachineSMEABIPass.cpp
@@ -63,6 +63,7 @@
 #include "llvm/CodeGen/LivePhysRegs.h"
 #include "llvm/CodeGen/MachineBasicBlock.h"
 #include "llvm/CodeGen/MachineFunctionPass.h"
+#include "llvm/CodeGen/MachineLoopInfo.h"
 #include "llvm/CodeGen/MachineRegisterInfo.h"
 #include "llvm/CodeGen/TargetRegisterInfo.h"
 
@@ -70,6 +71,12 @@ using namespace llvm;
 
 #define DEBUG_TYPE "aarch64-machine-sme-abi"
 
+static cl::opt<int>
+    LoopEdgeWeight("aarch64-sme-abi-loop-edge-weight", cl::ReallyHidden,
+                   cl::init(10),
+                   cl::desc("Edge weight for basic blocks within loops (used "
+                            "for placing ZA saves/restores)"));
+
 namespace {
 
 enum ZAState {
@@ -255,6 +262,9 @@ struct MachineSMEABI : public MachineFunctionPass {
   void getAnalysisUsage(AnalysisUsage &AU) const override {
     AU.setPreservesCFG();
     AU.addRequired<EdgeBundlesWrapperLegacy>();
+    // Only analyse loops at -O1 and above.
+    if (OptLevel != CodeGenOptLevel::None)
+      AU.addRequired<MachineLoopInfoWrapperPass>();
     AU.addPreservedID(MachineLoopInfoID);
     AU.addPreservedID(MachineDominatorsID);
     MachineFunctionPass::getAnalysisUsage(AU);
@@ -516,24 +526,31 @@ MachineSMEABI::assignBundleZAStates(const EdgeBundles &Bundles,
     int EdgeStateCounts[ZAState::NUM_ZA_STATE] = {0};
     for (unsigned BlockID : Bundles.getBlocks(I)) {
       LLVM_DEBUG(dbgs() << "- bb." << BlockID);
-
       const BlockInfo &Block = FnInfo.Blocks[BlockID];
+      bool IsLoop = MLI && MLI->getLoopFor(MF->getBlockNumbered(BlockID));
       bool InEdge = Bundles.getBundle(BlockID, /*Out=*/false) == I;
       bool OutEdge = Bundles.getBundle(BlockID, /*Out=*/true) == I;
 
+      // TODO: Use MachineBranchProbabilityInfo for edge weights?
+      int EdgeWeight = IsLoop ? LoopEdgeWeight : 1;
+      if (IsLoop)
+        LLVM_DEBUG(dbgs() << " IsLoop");
+
       bool LegalInEdge =
           InEdge && isLegalEdgeBundleZAState(Block.DesiredIncomingState);
       bool LegalOutEgde =
           OutEdge && isLegalEdgeBundleZAState(Block.DesiredOutgoingState);
+
+      LLVM_DEBUG(dbgs() << " (EdgeWeight: " << EdgeWeight << ')');
       if (LegalInEdge) {
         LLVM_DEBUG(dbgs() << " DesiredIncomingState: "
                           << getZAStateString(Block.DesiredIncomingState));
-        EdgeStateCounts[Block.DesiredIncomingState]++;
+        EdgeStateCounts[Block.DesiredIncomingState] += EdgeWeight;
       }
       if (LegalOutEgde) {
         LLVM_DEBUG(dbgs() << " DesiredOutgoingState: "
                           << getZAStateString(Block.DesiredOutgoingState));
-        EdgeStateCounts[Block.DesiredOutgoingState]++;
+        EdgeStateCounts[Block.DesiredOutgoingState] += EdgeWeight;
       }
       if (!LegalInEdge && !LegalOutEgde)
         LLVM_DEBUG(dbgs() << " (no state preference)");
@@ -982,6 +999,8 @@ bool MachineSMEABI::runOnMachineFunction(MachineFunction &MF) {
   TII = Subtarget->getInstrInfo();
   TRI = Subtarget->getRegisterInfo();
   MRI = &MF.getRegInfo();
+  if (OptLevel != CodeGenOptLevel::None)
+    MLI = &getAnalysis<MachineLoopInfoWrapperPass>().getLI();
 
   const EdgeBundles &Bundles =
       getAnalysis<EdgeBundlesWrapperLegacy>().getEdgeBundles();
diff --git a/llvm/test/CodeGen/AArch64/sme-lazy-save-in-loop.ll b/llvm/test/CodeGen/AArch64/sme-lazy-save-in-loop.ll
new file mode 100644
index 0000000000000..1abc87080e2c7
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sme-lazy-save-in-loop.ll
@@ -0,0 +1,113 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -O0 -mtriple=aarch64-linux-gnu -mattr=+sme -aarch64-new-sme-abi < %s | FileCheck %s --check-prefix=CHECK-O0
+; RUN: llc -O1 -mtriple=aarch64-linux-gnu -mattr=+sme -aarch64-new-sme-abi < %s | FileCheck %s --check-prefix=CHECK-O1
+
+declare void @private_za_call()
+declare void @shared_za_call() "aarch64_inout_za"
+
+; This test checks that at -O0 we don't attempt to optimize lazy save state
+; changes in loops, and that at -O1 (and above) we attempt to push state
+; changes out of loops.
+
+define void @private_za_loop_active_entry_and_exit(i32 %n) "aarch64_inout_za" nounwind {
+; CHECK-O0-LABEL: private_za_loop_active_entry_and_exit:
+; CHECK-O0:       // %bb.0: // %entry
+; CHECK-O0-NEXT:    stp x29, x30, [sp, #-16]! // 16-byte Folded Spill
+; CHECK-O0-NEXT:    mov x29, sp
+; CHECK-O0-NEXT:    sub sp, sp, #32
+; CHECK-O0-NEXT:    rdsvl x9, #1
+; CHECK-O0-NEXT:    mov x8, sp
+; CHECK-O0-NEXT:    msub x8, x9, x9, x8
+; CHECK-O0-NEXT:    mov sp, x8
+; CHECK-O0-NEXT:    stp x8, x9, [x29, #-16]
+; CHECK-O0-NEXT:    stur w0, [x29, #-24] // 4-byte Folded Spill
+; CHECK-O0-NEXT:    bl shared_za_call
+; CHECK-O0-NEXT:    ldur w0, [x29, #-24] // 4-byte Folded Reload
+; CHECK-O0-NEXT:    mov w8, wzr
+; CHECK-O0-NEXT:    subs w9, w0, #1
+; CHECK-O0-NEXT:    stur w8, [x29, #-20] // 4-byte Folded Spill
+; CHECK-O0-NEXT:    b.lt .LBB0_4
+; CHECK-O0-NEXT:    b .LBB0_1
+; CHECK-O0-NEXT:  .LBB0_1: // %loop
+; CHECK-O0-NEXT:    // =>This Inner Loop Header: Depth=1
+; CHECK-O0-NEXT:    ldur w8, [x29, #-20] // 4-byte Folded Reload
+; CHECK-O0-NEXT:    stur w8, [x29, #-28] // 4-byte Folded Spill
+; CHECK-O0-NEXT:    sub x8, x29, #16
+; CHECK-O0-NEXT:    msr TPIDR2_EL0, x8
+; CHECK-O0-NEXT:    bl private_za_call
+; CHECK-O0-NEXT:    ldur w8, [x29, #-28] // 4-byte Folded Reload
+; CHECK-O0-NEXT:    ldur w10, [x29, #-24] // 4-byte Folded Reload
+; CHECK-O0-NEXT:    add w9, w8, #1
+; CHECK-O0-NEXT:    mov w8, w9
+; CHECK-O0-NEXT:    smstart za
+; CHECK-O0-NEXT:    mrs x11, TPIDR2_EL0
+; CHECK-O0-NEXT:    sub x0, x29, #16
+; CHECK-O0-NEXT:    cbz x11, .LBB0_2
+; CHECK-O0-NEXT:    b .LBB0_3
+; CHECK-O0-NEXT:  .LBB0_2: // %loop
+; CHECK-O0-NEXT:    // in Loop: Header=BB0_1 Depth=1
+; CHECK-O0-NEXT:    bl __arm_tpidr2_restore
+; CHECK-O0-NEXT:    b .LBB0_3
+; CHECK-O0-NEXT:  .LBB0_3: // %loop
+; CHECK-O0-NEXT:    // in Loop: Header=BB0_1 Depth=1
+; CHECK-O0-NEXT:    msr TPIDR2_EL0, xzr
+; CHECK-O0-NEXT:    subs w9, w9, w10
+; CHECK-O0-NEXT:    stur w8, [x29, #-20] // 4-byte Folded Spill
+; CHECK-O0-NEXT:    b.ne .LBB0_1
+; CHECK-O0-NEXT:    b .LBB0_4
+; CHECK-O0-NEXT:  .LBB0_4: // %exit
+; CHECK-O0-NEXT:    mov sp, x29
+; CHECK-O0-NEXT:    ldp x29, x30, [sp], #16 // 16-byte Folded Reload
+; CHECK-O0-NEXT:    b shared_za_call
+;
+; CHECK-O1-LABEL: private_za_loop_active_entry_and_exit:
+; CHECK-O1:       // %bb.0: // %entry
+; CHECK-O1-NEXT:    stp x29, x30, [sp, #-32]! // 16-byte Folded Spill
+; CHECK-O1-NEXT:    str x19, [sp, #16] // 8-byte Folded Spill
+; CHECK-O1-NEXT:    mov x29, sp
+; CHECK-O1-NEXT:    sub sp, sp, #16
+; CHECK-O1-NEXT:    rdsvl x8, #1
+; CHECK-O1-NEXT:    mov x9, sp
+; CHECK-O1-NEXT:    msub x9, x8, x8, x9
+; CHECK-O1-NEXT:    mov sp, x9
+; CHECK-O1-NEXT:    mov w19, w0
+; CHECK-O1-NEXT:    stp x9, x8, [x29, #-16]
+; CHECK-O1-NEXT:    bl shared_za_call
+; CHECK-O1-NEXT:    sub x8, x29, #16
+; CHECK-O1-NEXT:    cmp w19, #1
+; CHECK-O1-NEXT:    msr TPIDR2_EL0, x8
+; CHECK-O1-NEXT:    b.lt .LBB0_2
+; CHECK-O1-NEXT:  .LBB0_1: // %loop
+; CHECK-O1-NEXT:    // =>This Inner Loop Header: Depth=1
+; CHECK-O1-NEXT:    bl private_za_call
+; CHECK-O1-NEXT:    subs w19, w19, #1
+; CHECK-O1-NEXT:    b.ne .LBB0_1
+; CHECK-O1-NEXT:  .LBB0_2: // %exit
+; CHECK-O1-NEXT:    smstart za
+; CHECK-O1-NEXT:    mrs x8, TPIDR2_EL0
+; CHECK-O1-NEXT:    sub x0, x29, #16
+; CHECK-O1-NEXT:    cbnz x8, .LBB0_4
+; CHECK-O1-NEXT:  // %bb.3: // %exit
+; CHECK-O1-NEXT:    bl __arm_tpidr2_restore
+; CHECK-O1-NEXT:  .LBB0_4: // %exit
+; CHECK-O1-NEXT:    msr TPIDR2_EL0, xzr
+; CHECK-O1-NEXT:    mov sp, x29
+; CHECK-O1-NEXT:    ldr x19, [sp, #16] // 8-byte Folded Reload
+; CHECK-O1-NEXT:    ldp x29, x30, [sp], #32 // 16-byte Folded Reload
+; CHECK-O1-NEXT:    b shared_za_call
+entry:
+  %cmpgt = icmp sgt i32 %n, 0
+  tail call void @shared_za_call()
+  br i1 %cmpgt, label %loop, label %exit
+
+loop:
+  %iv = phi i32 [ %next_iv, %loop ], [ 0, %entry ]
+  tail call void @private_za_call()
+  %next_iv = add nuw nsw i32 %iv, 1
+  %cmpeq = icmp eq i32 %next_iv, %n
+  br i1 %cmpeq, label %exit, label %loop
+
+exit:
+  tail call void @shared_za_call()
+  ret void
+}
diff --git a/llvm/test/CodeGen/AArch64/sme-za-control-flow.ll b/llvm/test/CodeGen/AArch64/sme-za-control-flow.ll
index c753e9c569d22..fcb6c0277c1a1 100644
--- a/llvm/test/CodeGen/AArch64/sme-za-control-flow.ll
+++ b/llvm/test/CodeGen/AArch64/sme-za-control-flow.ll
@@ -96,7 +96,7 @@ exit:
   ret void
 }
 
-; FIXME: In the new lowering we could weight edges to avoid doing the lazy save in the loop.
+; This tests that with the new lowering we push state changes out of loops (at -O1 and above).
 define void @private_za_loop_active_entry_and_exit(i32 %n) "aarch64_inout_za" nounwind {
 ; CHECK-LABEL: private_za_loop_active_entry_and_exit:
 ; CHECK:       // %bb.0: // %entry
@@ -142,7 +142,7 @@ define void @private_za_loop_active_entry_and_exit(i32 %n) "aarch64_inout_za" no
 ; CHECK-NEWLOWERING-LABEL: private_za_loop_active_entry_and_exit:
 ; CHECK-NEWLOWERING:       // %bb.0: // %entry
 ; CHECK-NEWLOWERING-NEXT:    stp x29, x30, [sp, #-32]! // 16-byte Folded Spill
-; CHECK-NEWLOWERING-NEXT:    stp x20, x19, [sp, #16] // 16-byte Folded Spill
+; CHECK-NEWLOWERING-NEXT:    str x19, [sp, #16] // 8-byte Folded Spill
 ; CHECK-NEWLOWERING-NEXT:    mov x29, sp
 ; CHECK-NEWLOWERING-NEXT:    sub sp, sp, #16
 ; CHECK-NEWLOWERING-NEXT:    rdsvl x8, #1
@@ -152,31 +152,26 @@ define void @private_za_loop_active_entry_and_exit(i32 %n) "aarch64_inout_za" no
 ; CHECK-NEWLOWERING-NEXT:    mov w19, w0
 ; CHECK-NEWLOWERING-NEXT:    stp x9, x8, [x29, #-16]
 ; CHECK-NEWLOWERING-NEXT:    bl shared_za_call
+; CHECK-NEWLOWERING-NEXT:    sub x8, x29, #16
 ; CHECK-NEWLOWERING-NEXT:    cmp w19, #1
-; CHECK-NEWLOWERING-NEXT:    b.lt .LBB1_5
-; CHECK-NEWLOWERING-NEXT:  // %bb.1: // %loop.preheader
-; CHECK-NEWLOWERING-NEXT:    sub x20, x29, #16
-; CHECK-NEWLOWERING-NEXT:    b .LBB1_3
-; CHECK-NEWLOWERING-NEXT:  .LBB1_2: // %loop
-; CHECK-NEWLOWERING-NEXT:    // in Loop: Header=BB1_3 Depth=1
-; CHECK-NEWLOWERING-NEXT:    msr TPIDR2_EL0, xzr
-; CHECK-NEWLOWERING-NEXT:    cbz w19, .LBB1_5
-; CHECK-NEWLOWERING-NEXT:  .LBB1_3: // %loop
+; CHECK-NEWLOWERING-NEXT:    msr TPIDR2_EL0, x8
+; CHECK-NEWLOWERING-NEXT:    b.lt .LBB1_2
+; CHECK-NEWLOWERING-NEXT:  .LBB1_1: // %loop
 ; CHECK-NEWLOWERING-NEXT:    // =>This Inner Loop Header: Depth=1
-; CHECK-NEWLOWERING-NEXT:    msr TPIDR2_EL0, x20
 ; CHECK-NEWLOWERING-NEXT:    bl private_za_call
-; CHECK-NEWLOWERING-NEXT:    sub w19, w19, #1
+; CHECK-NEWLOWERING-NEXT:    subs w19, w19, #1
+; CHECK-NEWLOWERING-NEXT:    b.ne .LBB1_1
+; CHECK-NEWLOWERING-NEXT:  .LBB1_2: // %exit
 ; CHECK-NEWLOWERING-NEXT:    smstart za
 ; CHECK-NEWLOWERING-NEXT:    mrs x8, TPIDR2_EL0
 ; CHECK-NEWLOWERING-NEXT:    sub x0, x29, #16
-; CHECK-NEWLOWERING-NEXT:    cbnz x8, .LBB1_2
-; CHECK-NEWLOWERING-NEXT:  // %bb.4: // %loop
-; CHECK-NEWLOWERING-NEXT:    // in Loop: Header=BB1_3 Depth=1
+; CHECK-NEWLOWERING-NEXT:    cbnz x8, .LBB1_4
+; CHECK-NEWLOWERING-NEXT:  // %bb.3: // %exit
 ; CHECK-NEWLOWERING-NEXT:    bl __arm_tpidr2_restore
-; CHECK-NEWLOWERING-NEXT:    b .LBB1_2
-; CHECK-NEWLOWERING-NEXT:  .LBB1_5: // %exit
+; CHECK-NEWLOWERING-NEXT:  .LBB1_4: // %exit
+; CHECK-NEWLOWERING-NEXT:    msr TPIDR2_EL0, xzr
 ; CHECK-NEWLOWERING-NEXT:    mov sp, x29
-; CHECK-NEWLOWERING-NEXT:    ldp x20, x19, [sp, #16] // 16-byte Folded Reload
+; CHECK-NEWLOWERING-NEXT:    ldr x19, [sp, #16] // 8-byte Folded Reload
 ; CHECK-NEWLOWERING-NEXT:    ldp x29, x30, [sp], #32 // 16-byte Folded Reload
 ; CHECK-NEWLOWERING-NEXT:    b shared_za_call
 entry:
@@ -879,7 +874,7 @@ define void @loop_with_external_entry(i1 %c1, i1 %c2) "aarch64_inout_za" nounwin
 ; CHECK-NEWLOWERING-LABEL: loop_with_external_entry:
 ; CHECK-NEWLOWERING:       // %bb.0: // %entry
 ; CHECK-NEWLOWERING-NEXT:    stp x29, x30, [sp, #-32]! // 16-byte Folded Spill
-; CHECK-NEWLOWERING-NEXT:    stp x20, x19, [sp, #16] // 16-byte Folded Spill
+; CHECK-NEWLOWERING-NEXT:    str x19, [sp, #16] // 8-byte Folded Spill
 ; CHECK-NEWLOWERING-NEXT:    mov x29, sp
 ; CHECK-NEWLOWERING-NEXT:    sub sp, sp, #16
 ; CHECK-NEWLOWERING-NEXT:    rdsvl x8, #1
@@ -892,27 +887,23 @@ define void @loop_with_external_entry(i1 %c1, i1 %c2) "aarch64_inout_za" nounwin
 ; CHECK-NEWLOWERING-NEXT:  // %bb.1: // %init
 ; CHECK-NEWLOWERING-NEXT:    bl shared_za_call
 ; CHECK-NEWLOWERING-NEXT:  .LBB11_2: // %loop.preheader
-; CHECK-NEWLOWERING-NEXT:    sub x20, x29, #16
-; CHECK-NEWLOWERING-NEXT:    b .LBB11_4
+; CHECK-NEWLOWERING-NEXT:    sub x8, x29, #16
+; CHECK-NEWLOWERING-NEXT:    msr TPIDR2_EL0, x8
 ; CHECK-NEWLOWERING-NEXT:  .LBB11_3: // %loop
-; CHECK-NEWLOWERING-NEXT:    // in Loop: Header=BB11_4 Depth=1
-; CHECK-NEWLOWERING-NEXT:    msr TPIDR2_EL0, xzr
-; CHECK-NEWLOWERING-NEXT:    tbz w19, #0, .LBB11_6
-; CHECK-NEWLOWERING-NEXT:  .LBB11_4: // %loop
 ; CHECK-NEWLOWERING-NEXT:    // =>This Inner Loop Header: Depth=1
-; CHECK-NEWLOWERING-NEXT:    msr TPIDR2_EL0, x20
 ; CHECK-NEWLOWERING-NEXT:    bl private_za_call
+; CHECK-NEWLOWERING-NEXT:    tbnz w19, #0, .LBB11_3
+; CHECK-NEWLOWERING-NEXT:  // %bb.4: // %exit
 ; CHECK-NEWLOWERING-NEXT:    smstart za
 ; CHECK-NEWLOWERING-NEXT:    mrs x8, TPIDR2_EL0
 ; CHECK-NEWLOWERING-NEXT:    sub x0, x29, #16
-; CHECK-NEWLOWERING-NEXT:    cbnz x8, .LBB11_3
-; CHECK-NEWLOWERING-NEXT:  // %bb.5: // %loop
-; CHECK-NEWLOWERING-NEXT:    // in Loop: Header=BB11_4 Depth=1
+; CHECK-NEWLOWERING-NEXT:    cbnz x8, .LBB11_6
+; CHECK-NEWLOWERING-NEXT:  // %bb.5: // %exit
 ; CHECK-NEWLOWERING-NEXT:    bl __arm_tpidr2_restore
-; CHECK-NEWLOWERING-NEXT:    b .LBB11_3
 ; CHECK-NEWLOWERING-NEXT:  .LBB11_6: // %exit
+; CHECK-NEWLOWERING-NEXT:    msr TPIDR2_EL0, xzr
 ; CHECK-NEWLOWERING-NEXT:    mov sp, x29
-; CHECK-NEWLOWERING-NEXT:    ldp x20, x19, [sp, #16] // 16-byte Folded Reload
+; CHECK-NEWLOWERING-NEXT:    ldr x19, [sp, #16] // 8-byte Folded Reload
 ; CHECK-NEWLOWERING-NEXT:    ldp x29, x30, [sp], #32 // 16-byte Folded Reload
 ; CHECK-NEWLOWERING-NEXT:    ret
 entry:


