[llvm] [AArch64][SME] Merge back-to-back SME call regions (PR #142111)

Benjamin Maxwell via llvm-commits llvm-commits at lists.llvm.org
Fri May 30 02:01:10 PDT 2025


https://github.com/MacDue created https://github.com/llvm/llvm-project/pull/142111

Proof-of-concept patch for merging back-to-back SME call regions (e.g. lazy saves, zt0 spills, streaming mode switches). 

>From 9ddd8629f073222f598922dd27b285cf9dcf264b Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 22 May 2025 11:23:25 +0000
Subject: [PATCH 1/3] [AArch64][SME] Simplify initialization of TPIDR2 block

This patch updates the definition of `AArch64ISD::INIT_TPIDR2OBJ` to
take the number of save slices (which is currently always all ZA
slices). Using this, we can initialize the TPIDR2 block with a single
STP of the save buffer pointer and the number of save slices. The
reserved bytes (10-15) will be implicitly zeroed, as the result of RDSVL
will always fit within 16 bits. Using an STP is also possible for big-endian
targets with an additional left shift.

Note: We used to write the number of save slices to the TPIDR2 block
before every call with a lazy save; however, based on 6.6.9 "Changes to
the TPIDR2 block" in the aapcs64 (https://github.com/ARM-software/abi-aa/blob/main/aapcs64/aapcs64.rst#changes-to-the-tpidr2-block),
it seems we can rely on callers preserving the contents of the TPIDR2
block.
---
 .../Target/AArch64/AArch64ISelLowering.cpp    | 42 ++++++------
 .../lib/Target/AArch64/AArch64SMEInstrInfo.td |  6 +-
 .../AArch64/sme-disable-gisel-fisel.ll        | 23 ++-----
 .../CodeGen/AArch64/sme-lazy-save-call.ll     | 57 ++++++----------
 .../AArch64/sme-shared-za-interface.ll        | 18 ++---
 .../AArch64/sme-tpidr2-init-aarch64be.ll      | 66 +++++++++++++++++++
 .../AArch64/sme-za-lazy-save-buffer.ll        | 30 ++++-----
 llvm/test/CodeGen/AArch64/sme-zt0-state.ll    |  9 +--
 llvm/test/CodeGen/AArch64/stack-hazard.ll     | 32 +++------
 .../CodeGen/AArch64/sve-stack-frame-layout.ll |  9 +--
 10 files changed, 150 insertions(+), 142 deletions(-)
 create mode 100644 llvm/test/CodeGen/AArch64/sme-tpidr2-init-aarch64be.ll

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 4dacd2273306e..3c1ee0560aef9 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -3227,20 +3227,24 @@ AArch64TargetLowering::EmitInitTPIDR2Object(MachineInstr &MI,
   TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
   if (TPIDR2.Uses > 0) {
     const TargetInstrInfo *TII = Subtarget->getInstrInfo();
-    // Store the buffer pointer to the TPIDR2 stack object.
-    BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(AArch64::STRXui))
+    unsigned TPIDInitSaveSlicesReg = MI.getOperand(1).getReg();
+    if (!Subtarget->isLittleEndian()) {
+      unsigned TmpReg =
+          MF->getRegInfo().createVirtualRegister(&AArch64::GPR64RegClass);
+      // For big-endian targets move "num_za_save_slices" to the top two bytes.
+      BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(AArch64::UBFMXri), TmpReg)
+          .addReg(TPIDInitSaveSlicesReg)
+          .addImm(16)
+          .addImm(15);
+      TPIDInitSaveSlicesReg = TmpReg;
+    }
+    // Store buffer pointer and num_za_save_slices.
+    // Bytes 10-15 are implicitly zeroed.
+    BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(AArch64::STPXi))
         .addReg(MI.getOperand(0).getReg())
+        .addReg(TPIDInitSaveSlicesReg)
         .addFrameIndex(TPIDR2.FrameIndex)
         .addImm(0);
-    // Set the reserved bytes (10-15) to zero
-    BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(AArch64::STRHHui))
-        .addReg(AArch64::WZR)
-        .addFrameIndex(TPIDR2.FrameIndex)
-        .addImm(5);
-    BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(AArch64::STRWui))
-        .addReg(AArch64::WZR)
-        .addFrameIndex(TPIDR2.FrameIndex)
-        .addImm(3);
   } else
     MFI.RemoveStackObject(TPIDR2.FrameIndex);
 
@@ -8344,9 +8348,12 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
                            {Chain, Size, DAG.getConstant(1, DL, MVT::i64)});
       MFI.CreateVariableSizedObject(Align(16), nullptr);
     }
+    SDValue NumZaSaveSlices = DAG.getNode(AArch64ISD::RDSVL, DL, MVT::i64,
+                                          DAG.getConstant(1, DL, MVT::i32));
     Chain = DAG.getNode(
         AArch64ISD::INIT_TPIDR2OBJ, DL, DAG.getVTList(MVT::Other),
-        {/*Chain*/ Buffer.getValue(1), /*Buffer ptr*/ Buffer.getValue(0)});
+        {/*Chain*/ Buffer.getValue(1), /*Buffer ptr*/ Buffer.getValue(0),
+         /*Num save slices*/ NumZaSaveSlices});
   } else if (SMEAttrs(MF.getFunction()).hasAgnosticZAInterface()) {
     // Call __arm_sme_state_size().
     SDValue BufferSize =
@@ -9127,19 +9134,10 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
   bool RequiresLazySave = CallAttrs.requiresLazySave();
   bool RequiresSaveAllZA = CallAttrs.requiresPreservingAllZAState();
   if (RequiresLazySave) {
-    const TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
-    MachinePointerInfo MPI =
-        MachinePointerInfo::getStack(MF, TPIDR2.FrameIndex);
+    TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
     SDValue TPIDR2ObjAddr = DAG.getFrameIndex(
         TPIDR2.FrameIndex,
         DAG.getTargetLoweringInfo().getFrameIndexTy(DAG.getDataLayout()));
-    SDValue NumZaSaveSlicesAddr =
-        DAG.getNode(ISD::ADD, DL, TPIDR2ObjAddr.getValueType(), TPIDR2ObjAddr,
-                    DAG.getConstant(8, DL, TPIDR2ObjAddr.getValueType()));
-    SDValue NumZaSaveSlices = DAG.getNode(AArch64ISD::RDSVL, DL, MVT::i64,
-                                          DAG.getConstant(1, DL, MVT::i32));
-    Chain = DAG.getTruncStore(Chain, DL, NumZaSaveSlices, NumZaSaveSlicesAddr,
-                              MPI, MVT::i16);
     Chain = DAG.getNode(
         ISD::INTRINSIC_VOID, DL, MVT::Other, Chain,
         DAG.getConstant(Intrinsic::aarch64_sme_set_tpidr2, DL, MVT::i32),
diff --git a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
index e7482da001074..8667c778782a1 100644
--- a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
@@ -54,10 +54,10 @@ let usesCustomInserter = 1, Defs = [SP], Uses = [SP] in {
 def : Pat<(i64 (AArch64AllocateZABuffer GPR64:$size)),
           (AllocateZABuffer $size)>;
 
-def AArch64InitTPIDR2Obj  : SDNode<"AArch64ISD::INIT_TPIDR2OBJ", SDTypeProfile<0, 1,
-                              [SDTCisInt<0>]>, [SDNPHasChain, SDNPMayStore]>;
+def AArch64InitTPIDR2Obj  : SDNode<"AArch64ISD::INIT_TPIDR2OBJ", SDTypeProfile<0, 2,
+                              [SDTCisInt<0>, SDTCisInt<1>]>, [SDNPHasChain, SDNPMayStore]>;
 let usesCustomInserter = 1 in {
-  def InitTPIDR2Obj : Pseudo<(outs), (ins GPR64:$buffer), [(AArch64InitTPIDR2Obj GPR64:$buffer)]>, Sched<[WriteI]> {}
+  def InitTPIDR2Obj : Pseudo<(outs), (ins GPR64:$buffer, GPR64:$save_slices), [(AArch64InitTPIDR2Obj GPR64:$buffer, GPR64:$save_slices)]>, Sched<[WriteI]> {}
 }
 
 // Nodes to allocate a save buffer for SME.
diff --git a/llvm/test/CodeGen/AArch64/sme-disable-gisel-fisel.ll b/llvm/test/CodeGen/AArch64/sme-disable-gisel-fisel.ll
index 4a52bf27a7591..13ffbc1296217 100644
--- a/llvm/test/CodeGen/AArch64/sme-disable-gisel-fisel.ll
+++ b/llvm/test/CodeGen/AArch64/sme-disable-gisel-fisel.ll
@@ -268,10 +268,7 @@ define double  @za_shared_caller_to_za_none_callee(double %x) nounwind noinline
 ; CHECK-COMMON-NEXT:    mov x9, sp
 ; CHECK-COMMON-NEXT:    msub x9, x8, x8, x9
 ; CHECK-COMMON-NEXT:    mov sp, x9
-; CHECK-COMMON-NEXT:    stur x9, [x29, #-16]
-; CHECK-COMMON-NEXT:    sturh wzr, [x29, #-6]
-; CHECK-COMMON-NEXT:    stur wzr, [x29, #-4]
-; CHECK-COMMON-NEXT:    sturh w8, [x29, #-8]
+; CHECK-COMMON-NEXT:    stp x9, x8, [x29, #-16]
 ; CHECK-COMMON-NEXT:    sub x8, x29, #16
 ; CHECK-COMMON-NEXT:    msr TPIDR2_EL0, x8
 ; CHECK-COMMON-NEXT:    bl normal_callee
@@ -310,12 +307,9 @@ define fp128 @f128_call_za(fp128 %a, fp128 %b) "aarch64_inout_za" nounwind {
 ; CHECK-COMMON-NEXT:    mov x9, sp
 ; CHECK-COMMON-NEXT:    msub x9, x8, x8, x9
 ; CHECK-COMMON-NEXT:    mov sp, x9
-; CHECK-COMMON-NEXT:    stur x9, [x29, #-16]
-; CHECK-COMMON-NEXT:    sub x9, x29, #16
-; CHECK-COMMON-NEXT:    sturh wzr, [x29, #-6]
-; CHECK-COMMON-NEXT:    stur wzr, [x29, #-4]
-; CHECK-COMMON-NEXT:    sturh w8, [x29, #-8]
-; CHECK-COMMON-NEXT:    msr TPIDR2_EL0, x9
+; CHECK-COMMON-NEXT:    sub x10, x29, #16
+; CHECK-COMMON-NEXT:    stp x9, x8, [x29, #-16]
+; CHECK-COMMON-NEXT:    msr TPIDR2_EL0, x10
 ; CHECK-COMMON-NEXT:    bl __addtf3
 ; CHECK-COMMON-NEXT:    smstart za
 ; CHECK-COMMON-NEXT:    mrs x8, TPIDR2_EL0
@@ -375,12 +369,9 @@ define double @frem_call_za(double %a, double %b) "aarch64_inout_za" nounwind {
 ; CHECK-COMMON-NEXT:    mov x9, sp
 ; CHECK-COMMON-NEXT:    msub x9, x8, x8, x9
 ; CHECK-COMMON-NEXT:    mov sp, x9
-; CHECK-COMMON-NEXT:    stur x9, [x29, #-16]
-; CHECK-COMMON-NEXT:    sub x9, x29, #16
-; CHECK-COMMON-NEXT:    sturh wzr, [x29, #-6]
-; CHECK-COMMON-NEXT:    stur wzr, [x29, #-4]
-; CHECK-COMMON-NEXT:    sturh w8, [x29, #-8]
-; CHECK-COMMON-NEXT:    msr TPIDR2_EL0, x9
+; CHECK-COMMON-NEXT:    sub x10, x29, #16
+; CHECK-COMMON-NEXT:    stp x9, x8, [x29, #-16]
+; CHECK-COMMON-NEXT:    msr TPIDR2_EL0, x10
 ; CHECK-COMMON-NEXT:    bl fmod
 ; CHECK-COMMON-NEXT:    smstart za
 ; CHECK-COMMON-NEXT:    mrs x8, TPIDR2_EL0
diff --git a/llvm/test/CodeGen/AArch64/sme-lazy-save-call.ll b/llvm/test/CodeGen/AArch64/sme-lazy-save-call.ll
index e463e833bdbde..26c6dc4eb978b 100644
--- a/llvm/test/CodeGen/AArch64/sme-lazy-save-call.ll
+++ b/llvm/test/CodeGen/AArch64/sme-lazy-save-call.ll
@@ -16,12 +16,9 @@ define void @test_lazy_save_1_callee() nounwind "aarch64_inout_za" {
 ; CHECK-NEXT:    mov x9, sp
 ; CHECK-NEXT:    msub x9, x8, x8, x9
 ; CHECK-NEXT:    mov sp, x9
-; CHECK-NEXT:    stur x9, [x29, #-16]
-; CHECK-NEXT:    sub x9, x29, #16
-; CHECK-NEXT:    sturh wzr, [x29, #-6]
-; CHECK-NEXT:    stur wzr, [x29, #-4]
-; CHECK-NEXT:    sturh w8, [x29, #-8]
-; CHECK-NEXT:    msr TPIDR2_EL0, x9
+; CHECK-NEXT:    sub x10, x29, #16
+; CHECK-NEXT:    stp x9, x8, [x29, #-16]
+; CHECK-NEXT:    msr TPIDR2_EL0, x10
 ; CHECK-NEXT:    bl private_za_callee
 ; CHECK-NEXT:    smstart za
 ; CHECK-NEXT:    mrs x8, TPIDR2_EL0
@@ -43,21 +40,17 @@ define void @test_lazy_save_1_callee() nounwind "aarch64_inout_za" {
 define void @test_lazy_save_2_callees() nounwind "aarch64_inout_za" {
 ; CHECK-LABEL: test_lazy_save_2_callees:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    stp x29, x30, [sp, #-48]! // 16-byte Folded Spill
-; CHECK-NEXT:    str x21, [sp, #16] // 8-byte Folded Spill
+; CHECK-NEXT:    stp x29, x30, [sp, #-32]! // 16-byte Folded Spill
+; CHECK-NEXT:    stp x20, x19, [sp, #16] // 16-byte Folded Spill
 ; CHECK-NEXT:    mov x29, sp
-; CHECK-NEXT:    stp x20, x19, [sp, #32] // 16-byte Folded Spill
 ; CHECK-NEXT:    sub sp, sp, #16
-; CHECK-NEXT:    rdsvl x20, #1
-; CHECK-NEXT:    mov x8, sp
-; CHECK-NEXT:    msub x8, x20, x20, x8
-; CHECK-NEXT:    mov sp, x8
-; CHECK-NEXT:    sub x21, x29, #16
-; CHECK-NEXT:    stur x8, [x29, #-16]
-; CHECK-NEXT:    sturh wzr, [x29, #-6]
-; CHECK-NEXT:    stur wzr, [x29, #-4]
-; CHECK-NEXT:    sturh w20, [x29, #-8]
-; CHECK-NEXT:    msr TPIDR2_EL0, x21
+; CHECK-NEXT:    rdsvl x8, #1
+; CHECK-NEXT:    mov x9, sp
+; CHECK-NEXT:    msub x9, x8, x8, x9
+; CHECK-NEXT:    mov sp, x9
+; CHECK-NEXT:    sub x20, x29, #16
+; CHECK-NEXT:    stp x9, x8, [x29, #-16]
+; CHECK-NEXT:    msr TPIDR2_EL0, x20
 ; CHECK-NEXT:    bl private_za_callee
 ; CHECK-NEXT:    smstart za
 ; CHECK-NEXT:    mrs x8, TPIDR2_EL0
@@ -67,8 +60,7 @@ define void @test_lazy_save_2_callees() nounwind "aarch64_inout_za" {
 ; CHECK-NEXT:    bl __arm_tpidr2_restore
 ; CHECK-NEXT:  .LBB1_2:
 ; CHECK-NEXT:    msr TPIDR2_EL0, xzr
-; CHECK-NEXT:    sturh w20, [x29, #-8]
-; CHECK-NEXT:    msr TPIDR2_EL0, x21
+; CHECK-NEXT:    msr TPIDR2_EL0, x20
 ; CHECK-NEXT:    bl private_za_callee
 ; CHECK-NEXT:    smstart za
 ; CHECK-NEXT:    mrs x8, TPIDR2_EL0
@@ -79,9 +71,8 @@ define void @test_lazy_save_2_callees() nounwind "aarch64_inout_za" {
 ; CHECK-NEXT:  .LBB1_4:
 ; CHECK-NEXT:    msr TPIDR2_EL0, xzr
 ; CHECK-NEXT:    mov sp, x29
-; CHECK-NEXT:    ldp x20, x19, [sp, #32] // 16-byte Folded Reload
-; CHECK-NEXT:    ldr x21, [sp, #16] // 8-byte Folded Reload
-; CHECK-NEXT:    ldp x29, x30, [sp], #48 // 16-byte Folded Reload
+; CHECK-NEXT:    ldp x20, x19, [sp, #16] // 16-byte Folded Reload
+; CHECK-NEXT:    ldp x29, x30, [sp], #32 // 16-byte Folded Reload
 ; CHECK-NEXT:    ret
   call void @private_za_callee()
   call void @private_za_callee()
@@ -100,12 +91,9 @@ define float @test_lazy_save_expanded_intrinsic(float %a) nounwind "aarch64_inou
 ; CHECK-NEXT:    mov x9, sp
 ; CHECK-NEXT:    msub x9, x8, x8, x9
 ; CHECK-NEXT:    mov sp, x9
-; CHECK-NEXT:    stur x9, [x29, #-16]
-; CHECK-NEXT:    sub x9, x29, #16
-; CHECK-NEXT:    sturh wzr, [x29, #-6]
-; CHECK-NEXT:    stur wzr, [x29, #-4]
-; CHECK-NEXT:    sturh w8, [x29, #-8]
-; CHECK-NEXT:    msr TPIDR2_EL0, x9
+; CHECK-NEXT:    sub x10, x29, #16
+; CHECK-NEXT:    stp x9, x8, [x29, #-16]
+; CHECK-NEXT:    msr TPIDR2_EL0, x10
 ; CHECK-NEXT:    bl cosf
 ; CHECK-NEXT:    smstart za
 ; CHECK-NEXT:    mrs x8, TPIDR2_EL0
@@ -141,12 +129,9 @@ define void @test_lazy_save_and_conditional_smstart() nounwind "aarch64_inout_za
 ; CHECK-NEXT:    mov x9, sp
 ; CHECK-NEXT:    msub x9, x8, x8, x9
 ; CHECK-NEXT:    mov sp, x9
-; CHECK-NEXT:    stur x9, [x29, #-80]
-; CHECK-NEXT:    sub x9, x29, #80
-; CHECK-NEXT:    sturh wzr, [x29, #-70]
-; CHECK-NEXT:    stur wzr, [x29, #-68]
-; CHECK-NEXT:    sturh w8, [x29, #-72]
-; CHECK-NEXT:    msr TPIDR2_EL0, x9
+; CHECK-NEXT:    sub x10, x29, #80
+; CHECK-NEXT:    stp x9, x8, [x29, #-80]
+; CHECK-NEXT:    msr TPIDR2_EL0, x10
 ; CHECK-NEXT:    bl __arm_sme_state
 ; CHECK-NEXT:    and x20, x0, #0x1
 ; CHECK-NEXT:    tbz w20, #0, .LBB3_2
diff --git a/llvm/test/CodeGen/AArch64/sme-shared-za-interface.ll b/llvm/test/CodeGen/AArch64/sme-shared-za-interface.ll
index 393ff3b79aedf..d12c304905b4b 100644
--- a/llvm/test/CodeGen/AArch64/sme-shared-za-interface.ll
+++ b/llvm/test/CodeGen/AArch64/sme-shared-za-interface.ll
@@ -14,12 +14,9 @@ define void @disable_tailcallopt() "aarch64_inout_za" nounwind {
 ; CHECK-NEXT:    mov x9, sp
 ; CHECK-NEXT:    msub x9, x8, x8, x9
 ; CHECK-NEXT:    mov sp, x9
-; CHECK-NEXT:    stur x9, [x29, #-16]
-; CHECK-NEXT:    sub x9, x29, #16
-; CHECK-NEXT:    sturh wzr, [x29, #-6]
-; CHECK-NEXT:    stur wzr, [x29, #-4]
-; CHECK-NEXT:    sturh w8, [x29, #-8]
-; CHECK-NEXT:    msr TPIDR2_EL0, x9
+; CHECK-NEXT:    sub x10, x29, #16
+; CHECK-NEXT:    stp x9, x8, [x29, #-16]
+; CHECK-NEXT:    msr TPIDR2_EL0, x10
 ; CHECK-NEXT:    bl private_za_callee
 ; CHECK-NEXT:    smstart za
 ; CHECK-NEXT:    mrs x8, TPIDR2_EL0
@@ -47,12 +44,9 @@ define fp128 @f128_call_za(fp128 %a, fp128 %b) "aarch64_inout_za" nounwind {
 ; CHECK-NEXT:    mov x9, sp
 ; CHECK-NEXT:    msub x9, x8, x8, x9
 ; CHECK-NEXT:    mov sp, x9
-; CHECK-NEXT:    stur x9, [x29, #-16]
-; CHECK-NEXT:    sub x9, x29, #16
-; CHECK-NEXT:    sturh wzr, [x29, #-6]
-; CHECK-NEXT:    stur wzr, [x29, #-4]
-; CHECK-NEXT:    sturh w8, [x29, #-8]
-; CHECK-NEXT:    msr TPIDR2_EL0, x9
+; CHECK-NEXT:    sub x10, x29, #16
+; CHECK-NEXT:    stp x9, x8, [x29, #-16]
+; CHECK-NEXT:    msr TPIDR2_EL0, x10
 ; CHECK-NEXT:    bl __addtf3
 ; CHECK-NEXT:    smstart za
 ; CHECK-NEXT:    mrs x8, TPIDR2_EL0
diff --git a/llvm/test/CodeGen/AArch64/sme-tpidr2-init-aarch64be.ll b/llvm/test/CodeGen/AArch64/sme-tpidr2-init-aarch64be.ll
new file mode 100644
index 0000000000000..78823e8b4da60
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sme-tpidr2-init-aarch64be.ll
@@ -0,0 +1,66 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mtriple=aarch64 -aarch64-streaming-hazard-size=0 -mattr=+sve -mattr=+sme < %s | FileCheck %s
+; RUN: llc -mtriple=aarch64_be -aarch64-streaming-hazard-size=0 -mattr=+sve -mattr=+sme < %s | FileCheck %s --check-prefix=CHECK-BE
+
+declare void @private_za_callee()
+declare float @llvm.cos.f32(float)
+
+; Test TPIDR2_EL0 is initialized correctly for AArch64 big-endian.
+define void @test_tpidr2_init() nounwind "aarch64_inout_za" {
+; CHECK-LABEL: test_tpidr2_init:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    stp x29, x30, [sp, #-32]! // 16-byte Folded Spill
+; CHECK-NEXT:    str x19, [sp, #16] // 8-byte Folded Spill
+; CHECK-NEXT:    mov x29, sp
+; CHECK-NEXT:    sub sp, sp, #16
+; CHECK-NEXT:    rdsvl x8, #1
+; CHECK-NEXT:    mov x9, sp
+; CHECK-NEXT:    msub x9, x8, x8, x9
+; CHECK-NEXT:    mov sp, x9
+; CHECK-NEXT:    sub x10, x29, #16
+; CHECK-NEXT:    stp x9, x8, [x29, #-16]
+; CHECK-NEXT:    msr TPIDR2_EL0, x10
+; CHECK-NEXT:    bl private_za_callee
+; CHECK-NEXT:    smstart za
+; CHECK-NEXT:    mrs x8, TPIDR2_EL0
+; CHECK-NEXT:    sub x0, x29, #16
+; CHECK-NEXT:    cbnz x8, .LBB0_2
+; CHECK-NEXT:  // %bb.1:
+; CHECK-NEXT:    bl __arm_tpidr2_restore
+; CHECK-NEXT:  .LBB0_2:
+; CHECK-NEXT:    msr TPIDR2_EL0, xzr
+; CHECK-NEXT:    mov sp, x29
+; CHECK-NEXT:    ldr x19, [sp, #16] // 8-byte Folded Reload
+; CHECK-NEXT:    ldp x29, x30, [sp], #32 // 16-byte Folded Reload
+; CHECK-NEXT:    ret
+;
+; CHECK-BE-LABEL: test_tpidr2_init:
+; CHECK-BE:       // %bb.0:
+; CHECK-BE-NEXT:    stp x29, x30, [sp, #-32]! // 16-byte Folded Spill
+; CHECK-BE-NEXT:    str x19, [sp, #16] // 8-byte Folded Spill
+; CHECK-BE-NEXT:    mov x29, sp
+; CHECK-BE-NEXT:    sub sp, sp, #16
+; CHECK-BE-NEXT:    rdsvl x8, #1
+; CHECK-BE-NEXT:    mov x9, sp
+; CHECK-BE-NEXT:    msub x9, x8, x8, x9
+; CHECK-BE-NEXT:    mov sp, x9
+; CHECK-BE-NEXT:    lsl x8, x8, #48
+; CHECK-BE-NEXT:    sub x10, x29, #16
+; CHECK-BE-NEXT:    stp x9, x8, [x29, #-16]
+; CHECK-BE-NEXT:    msr TPIDR2_EL0, x10
+; CHECK-BE-NEXT:    bl private_za_callee
+; CHECK-BE-NEXT:    smstart za
+; CHECK-BE-NEXT:    mrs x8, TPIDR2_EL0
+; CHECK-BE-NEXT:    sub x0, x29, #16
+; CHECK-BE-NEXT:    cbnz x8, .LBB0_2
+; CHECK-BE-NEXT:  // %bb.1:
+; CHECK-BE-NEXT:    bl __arm_tpidr2_restore
+; CHECK-BE-NEXT:  .LBB0_2:
+; CHECK-BE-NEXT:    msr TPIDR2_EL0, xzr
+; CHECK-BE-NEXT:    mov sp, x29
+; CHECK-BE-NEXT:    ldr x19, [sp, #16] // 8-byte Folded Reload
+; CHECK-BE-NEXT:    ldp x29, x30, [sp], #32 // 16-byte Folded Reload
+; CHECK-BE-NEXT:    ret
+  call void @private_za_callee()
+  ret void
+}
diff --git a/llvm/test/CodeGen/AArch64/sme-za-lazy-save-buffer.ll b/llvm/test/CodeGen/AArch64/sme-za-lazy-save-buffer.ll
index ad3f7f5514d0e..256045cbe44f8 100644
--- a/llvm/test/CodeGen/AArch64/sme-za-lazy-save-buffer.ll
+++ b/llvm/test/CodeGen/AArch64/sme-za-lazy-save-buffer.ll
@@ -21,11 +21,9 @@ define float @multi_bb_stpidr2_save_required(i32 %a, float %b, float %c) "aarch6
 ; CHECK-NEXT:    .cfi_offset w29, -16
 ; CHECK-NEXT:    rdsvl x8, #1
 ; CHECK-NEXT:    mov x9, sp
-; CHECK-NEXT:    msub x8, x8, x8, x9
-; CHECK-NEXT:    mov sp, x8
-; CHECK-NEXT:    stur x8, [x29, #-16]
-; CHECK-NEXT:    sturh wzr, [x29, #-6]
-; CHECK-NEXT:    stur wzr, [x29, #-4]
+; CHECK-NEXT:    msub x9, x8, x8, x9
+; CHECK-NEXT:    mov sp, x9
+; CHECK-NEXT:    stp x9, x8, [x29, #-16]
 ; CHECK-NEXT:    cbz w0, .LBB1_2
 ; CHECK-NEXT:  // %bb.1: // %use_b
 ; CHECK-NEXT:    fmov s1, #4.00000000
@@ -33,10 +31,8 @@ define float @multi_bb_stpidr2_save_required(i32 %a, float %b, float %c) "aarch6
 ; CHECK-NEXT:    b .LBB1_5
 ; CHECK-NEXT:  .LBB1_2: // %use_c
 ; CHECK-NEXT:    fmov s0, s1
-; CHECK-NEXT:    rdsvl x8, #1
-; CHECK-NEXT:    sub x9, x29, #16
-; CHECK-NEXT:    sturh w8, [x29, #-8]
-; CHECK-NEXT:    msr TPIDR2_EL0, x9
+; CHECK-NEXT:    sub x8, x29, #16
+; CHECK-NEXT:    msr TPIDR2_EL0, x8
 ; CHECK-NEXT:    bl cosf
 ; CHECK-NEXT:    smstart za
 ; CHECK-NEXT:    mrs x8, TPIDR2_EL0
@@ -77,20 +73,18 @@ define float @multi_bb_stpidr2_save_required_stackprobe(i32 %a, float %b, float
 ; CHECK-NEXT:    .cfi_offset w29, -16
 ; CHECK-NEXT:    rdsvl x8, #1
 ; CHECK-NEXT:    mov x9, sp
-; CHECK-NEXT:    msub x8, x8, x8, x9
+; CHECK-NEXT:    msub x9, x8, x8, x9
 ; CHECK-NEXT:  .LBB2_1: // =>This Inner Loop Header: Depth=1
 ; CHECK-NEXT:    sub sp, sp, #16, lsl #12 // =65536
-; CHECK-NEXT:    cmp sp, x8
+; CHECK-NEXT:    cmp sp, x9
 ; CHECK-NEXT:    b.le .LBB2_3
 ; CHECK-NEXT:  // %bb.2: // in Loop: Header=BB2_1 Depth=1
 ; CHECK-NEXT:    str xzr, [sp]
 ; CHECK-NEXT:    b .LBB2_1
 ; CHECK-NEXT:  .LBB2_3:
-; CHECK-NEXT:    mov sp, x8
+; CHECK-NEXT:    mov sp, x9
 ; CHECK-NEXT:    ldr xzr, [sp]
-; CHECK-NEXT:    stur x8, [x29, #-16]
-; CHECK-NEXT:    sturh wzr, [x29, #-6]
-; CHECK-NEXT:    stur wzr, [x29, #-4]
+; CHECK-NEXT:    stp x9, x8, [x29, #-16]
 ; CHECK-NEXT:    cbz w0, .LBB2_5
 ; CHECK-NEXT:  // %bb.4: // %use_b
 ; CHECK-NEXT:    fmov s1, #4.00000000
@@ -98,10 +92,8 @@ define float @multi_bb_stpidr2_save_required_stackprobe(i32 %a, float %b, float
 ; CHECK-NEXT:    b .LBB2_8
 ; CHECK-NEXT:  .LBB2_5: // %use_c
 ; CHECK-NEXT:    fmov s0, s1
-; CHECK-NEXT:    rdsvl x8, #1
-; CHECK-NEXT:    sub x9, x29, #16
-; CHECK-NEXT:    sturh w8, [x29, #-8]
-; CHECK-NEXT:    msr TPIDR2_EL0, x9
+; CHECK-NEXT:    sub x8, x29, #16
+; CHECK-NEXT:    msr TPIDR2_EL0, x8
 ; CHECK-NEXT:    bl cosf
 ; CHECK-NEXT:    smstart za
 ; CHECK-NEXT:    mrs x8, TPIDR2_EL0
diff --git a/llvm/test/CodeGen/AArch64/sme-zt0-state.ll b/llvm/test/CodeGen/AArch64/sme-zt0-state.ll
index 63577e4d217a8..7d1ca3946f2d7 100644
--- a/llvm/test/CodeGen/AArch64/sme-zt0-state.ll
+++ b/llvm/test/CodeGen/AArch64/sme-zt0-state.ll
@@ -39,13 +39,10 @@ define void @za_zt0_shared_caller_no_state_callee(ptr %callee) "aarch64_inout_za
 ; CHECK-NEXT:    mov x9, sp
 ; CHECK-NEXT:    msub x9, x8, x8, x9
 ; CHECK-NEXT:    mov sp, x9
-; CHECK-NEXT:    stur x9, [x29, #-16]
-; CHECK-NEXT:    sub x9, x29, #16
+; CHECK-NEXT:    sub x10, x29, #16
 ; CHECK-NEXT:    sub x19, x29, #80
-; CHECK-NEXT:    sturh wzr, [x29, #-6]
-; CHECK-NEXT:    stur wzr, [x29, #-4]
-; CHECK-NEXT:    sturh w8, [x29, #-8]
-; CHECK-NEXT:    msr TPIDR2_EL0, x9
+; CHECK-NEXT:    stp x9, x8, [x29, #-16]
+; CHECK-NEXT:    msr TPIDR2_EL0, x10
 ; CHECK-NEXT:    str zt0, [x19]
 ; CHECK-NEXT:    blr x0
 ; CHECK-NEXT:    smstart za
diff --git a/llvm/test/CodeGen/AArch64/stack-hazard.ll b/llvm/test/CodeGen/AArch64/stack-hazard.ll
index 791d7580c327d..f1db7188943db 100644
--- a/llvm/test/CodeGen/AArch64/stack-hazard.ll
+++ b/llvm/test/CodeGen/AArch64/stack-hazard.ll
@@ -2832,12 +2832,9 @@ define i32 @vastate(i32 %x) "aarch64_inout_za" "aarch64_pstate_sm_enabled" "targ
 ; CHECK0-NEXT:    mov w20, w0
 ; CHECK0-NEXT:    msub x9, x8, x8, x9
 ; CHECK0-NEXT:    mov sp, x9
-; CHECK0-NEXT:    stur x9, [x29, #-80]
-; CHECK0-NEXT:    sub x9, x29, #80
-; CHECK0-NEXT:    sturh wzr, [x29, #-70]
-; CHECK0-NEXT:    stur wzr, [x29, #-68]
-; CHECK0-NEXT:    sturh w8, [x29, #-72]
-; CHECK0-NEXT:    msr TPIDR2_EL0, x9
+; CHECK0-NEXT:    sub x10, x29, #80
+; CHECK0-NEXT:    stp x9, x8, [x29, #-80]
+; CHECK0-NEXT:    msr TPIDR2_EL0, x10
 ; CHECK0-NEXT:    .cfi_offset vg, -32
 ; CHECK0-NEXT:    smstop sm
 ; CHECK0-NEXT:    bl other
@@ -2906,12 +2903,9 @@ define i32 @vastate(i32 %x) "aarch64_inout_za" "aarch64_pstate_sm_enabled" "targ
 ; CHECK64-NEXT:    mov w20, w0
 ; CHECK64-NEXT:    msub x9, x8, x8, x9
 ; CHECK64-NEXT:    mov sp, x9
-; CHECK64-NEXT:    stur x9, [x29, #-208]
-; CHECK64-NEXT:    sub x9, x29, #208
-; CHECK64-NEXT:    sturh wzr, [x29, #-198]
-; CHECK64-NEXT:    stur wzr, [x29, #-196]
-; CHECK64-NEXT:    sturh w8, [x29, #-200]
-; CHECK64-NEXT:    msr TPIDR2_EL0, x9
+; CHECK64-NEXT:    sub x10, x29, #208
+; CHECK64-NEXT:    stp x9, x8, [x29, #-208]
+; CHECK64-NEXT:    msr TPIDR2_EL0, x10
 ; CHECK64-NEXT:    .cfi_offset vg, -32
 ; CHECK64-NEXT:    smstop sm
 ; CHECK64-NEXT:    bl other
@@ -2986,16 +2980,10 @@ define i32 @vastate(i32 %x) "aarch64_inout_za" "aarch64_pstate_sm_enabled" "targ
 ; CHECK1024-NEXT:    mov w20, w0
 ; CHECK1024-NEXT:    msub x9, x8, x8, x9
 ; CHECK1024-NEXT:    mov sp, x9
-; CHECK1024-NEXT:    sub x10, x29, #1872
-; CHECK1024-NEXT:    stur x9, [x10, #-256]
-; CHECK1024-NEXT:    sub x9, x29, #1862
-; CHECK1024-NEXT:    sub x10, x29, #1860
-; CHECK1024-NEXT:    sturh wzr, [x9, #-256]
-; CHECK1024-NEXT:    sub x9, x29, #2128
-; CHECK1024-NEXT:    stur wzr, [x10, #-256]
-; CHECK1024-NEXT:    sub x10, x29, #1864
-; CHECK1024-NEXT:    sturh w8, [x10, #-256]
-; CHECK1024-NEXT:    msr TPIDR2_EL0, x9
+; CHECK1024-NEXT:    sub x10, x29, #2128
+; CHECK1024-NEXT:    sub x11, x29, #1616
+; CHECK1024-NEXT:    stp x9, x8, [x11, #-512]
+; CHECK1024-NEXT:    msr TPIDR2_EL0, x10
 ; CHECK1024-NEXT:    .cfi_offset vg, -32
 ; CHECK1024-NEXT:    smstop sm
 ; CHECK1024-NEXT:    bl other
diff --git a/llvm/test/CodeGen/AArch64/sve-stack-frame-layout.ll b/llvm/test/CodeGen/AArch64/sve-stack-frame-layout.ll
index c5cf4593cc86d..e0fe3049289ca 100644
--- a/llvm/test/CodeGen/AArch64/sve-stack-frame-layout.ll
+++ b/llvm/test/CodeGen/AArch64/sve-stack-frame-layout.ll
@@ -548,12 +548,9 @@ define i32 @vastate(i32 %x) "aarch64_inout_za" "aarch64_pstate_sm_enabled" "targ
 ; CHECK-NEXT:    mov w20, w0
 ; CHECK-NEXT:    msub x9, x8, x8, x9
 ; CHECK-NEXT:    mov sp, x9
-; CHECK-NEXT:    stur x9, [x29, #-80]
-; CHECK-NEXT:    sub x9, x29, #80
-; CHECK-NEXT:    sturh wzr, [x29, #-70]
-; CHECK-NEXT:    stur wzr, [x29, #-68]
-; CHECK-NEXT:    sturh w8, [x29, #-72]
-; CHECK-NEXT:    msr TPIDR2_EL0, x9
+; CHECK-NEXT:    sub x10, x29, #80
+; CHECK-NEXT:    stp x9, x8, [x29, #-80]
+; CHECK-NEXT:    msr TPIDR2_EL0, x10
 ; CHECK-NEXT:    .cfi_offset vg, -32
 ; CHECK-NEXT:    smstop sm
 ; CHECK-NEXT:    bl other

>From 3a45655ed1f08ae4653ecee5954fbceae4a7b85c Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Fri, 30 May 2025 08:43:30 +0000
Subject: [PATCH 2/3] [AArch64][SME] Abstract SME call lowering with high-level
 ISD nodes (NFC)

---
 .../Target/AArch64/AArch64ISelLowering.cpp    | 475 +++++++++++-------
 llvm/lib/Target/AArch64/AArch64ISelLowering.h |  20 +-
 .../AArch64/AArch64MachineFunctionInfo.h      |   6 +
 .../AArch64/Utils/AArch64SMEAttributes.h      |   2 +
 4 files changed, 311 insertions(+), 192 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 3c1ee0560aef9..9866149868fd7 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -2730,6 +2730,9 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
     MAKE_CASE(AArch64ISD::COND_SMSTOP)
     MAKE_CASE(AArch64ISD::RESTORE_ZA)
     MAKE_CASE(AArch64ISD::RESTORE_ZT)
+    MAKE_CASE(AArch64ISD::SME_CALL_START)
+    MAKE_CASE(AArch64ISD::SME_CALL_SM_CHANGE)
+    MAKE_CASE(AArch64ISD::SME_CALL_END)
     MAKE_CASE(AArch64ISD::SAVE_ZT)
     MAKE_CASE(AArch64ISD::CALL)
     MAKE_CASE(AArch64ISD::ADRP)
@@ -7750,9 +7753,258 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
   case ISD::PARTIAL_REDUCE_SMLA:
   case ISD::PARTIAL_REDUCE_UMLA:
     return LowerPARTIAL_REDUCE_MLA(Op, DAG);
+  case AArch64ISD::SME_CALL_START:
+    return LowerSME_CALL_START(Op, DAG);
+  case AArch64ISD::SME_CALL_SM_CHANGE:
+    return LowerSME_CALL_SM_CHANGE(Op, DAG);
+  case AArch64ISD::SME_CALL_END:
+    return LowerSME_CALL_END(Op, DAG);
   }
 }
 
+static SDNode *findSMECallStart(SDNode *N) {
+  unsigned Opcode = N->getOpcode();
+  if (Opcode == AArch64ISD::SME_CALL_START)
+    return N;
+  if (Opcode == AArch64ISD::SME_CALL_SM_CHANGE)
+    return findSMECallStart(N->getOperand(1).getNode());
+  if (Opcode == AArch64ISD::SME_CALL_END)
+    return findSMECallStart(N->getOperand(1).getNode());
+  llvm_unreachable("Unexpected opcode!");
+}
+
+static SMECallAttrs findSMECallAttrs(SDNode *N) {
+  SDNode *Start = findSMECallStart(N);
+  SMEAttrs CallerAttrs(Start->getConstantOperandVal(1));
+  SMEAttrs CalleeAttrs(Start->getConstantOperandVal(2));
+  SMEAttrs CallsiteAttrs(Start->getConstantOperandVal(3));
+  return SMECallAttrs(CallerAttrs, CalleeAttrs, CallsiteAttrs);
+}
+
+static unsigned getOrCreateZT0SpillSlot(AArch64FunctionInfo *FuncInfo,
+                                        MachineFrameInfo &MFI) {
+  unsigned ZTObj = FuncInfo->getZT0Idx();
+  if (ZTObj == std::numeric_limits<int>::max()) {
+    ZTObj = MFI.CreateSpillStackObject(64, Align(16));
+    FuncInfo->setZT0Idx(ZTObj);
+  }
+  return ZTObj;
+}
+
+// Emit a call to __arm_sme_save or __arm_sme_restore.
+static SDValue emitSMEStateSaveRestore(const AArch64TargetLowering &TLI,
+                                       SelectionDAG &DAG,
+                                       AArch64FunctionInfo *Info, SDLoc DL,
+                                       SDValue Chain, bool IsSave) {
+  MachineFunction &MF = DAG.getMachineFunction();
+  AArch64FunctionInfo *FuncInfo = MF.getInfo<AArch64FunctionInfo>();
+  FuncInfo->setSMESaveBufferUsed();
+
+  TargetLowering::ArgListTy Args;
+  TargetLowering::ArgListEntry Entry;
+  Entry.Ty = PointerType::getUnqual(*DAG.getContext());
+  Entry.Node =
+      DAG.getCopyFromReg(Chain, DL, Info->getSMESaveBufferAddr(), MVT::i64);
+  Args.push_back(Entry);
+
+  SDValue Callee =
+      DAG.getExternalSymbol(IsSave ? "__arm_sme_save" : "__arm_sme_restore",
+                            TLI.getPointerTy(DAG.getDataLayout()));
+  auto *RetTy = Type::getVoidTy(*DAG.getContext());
+  TargetLowering::CallLoweringInfo CLI(DAG);
+  CLI.setDebugLoc(DL).setChain(Chain).setLibCallee(
+      CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1, RetTy,
+      Callee, std::move(Args));
+  return TLI.LowerCallTo(CLI).second;
+}
+
+static AArch64SME::ToggleCondition
+getSMToggleCondition(const SMECallAttrs &CallAttrs) {
+  if (!CallAttrs.caller().hasStreamingCompatibleInterface() ||
+      CallAttrs.caller().hasStreamingBody())
+    return AArch64SME::Always;
+  if (CallAttrs.callee().hasNonStreamingInterface())
+    return AArch64SME::IfCallerIsStreaming;
+  if (CallAttrs.callee().hasStreamingInterface())
+    return AArch64SME::IfCallerIsNonStreaming;
+
+  llvm_unreachable("Unsupported attributes");
+}
+
+SDValue AArch64TargetLowering::LowerSME_CALL_START(SDValue Op,
+                                                   SelectionDAG &DAG) const {
+  SDLoc DL(Op);
+  SMECallAttrs CallAttrs = findSMECallAttrs(Op.getNode());
+  auto &MF = DAG.getMachineFunction();
+  auto *FuncInfo = MF.getInfo<AArch64FunctionInfo>();
+  bool RequiresLazySave = CallAttrs.requiresLazySave();
+  bool RequiresSaveAllZA = CallAttrs.requiresPreservingAllZAState();
+
+  SDValue Chain = Op->getOperand(0);
+  if (RequiresLazySave) {
+    TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
+    SDValue TPIDR2ObjAddr = DAG.getFrameIndex(
+        TPIDR2.FrameIndex,
+        DAG.getTargetLoweringInfo().getFrameIndexTy(DAG.getDataLayout()));
+    Chain = DAG.getNode(
+        ISD::INTRINSIC_VOID, DL, MVT::Other, Chain,
+        DAG.getConstant(Intrinsic::aarch64_sme_set_tpidr2, DL, MVT::i32),
+        TPIDR2ObjAddr);
+  } else if (RequiresSaveAllZA) {
+    assert(!CallAttrs.callee().hasSharedZAInterface() &&
+           "Cannot share state that may not exist");
+    Chain = emitSMEStateSaveRestore(*this, DAG, FuncInfo, DL, Chain,
+                                    /*IsSave=*/true);
+  }
+
+  SDValue PStateSM;
+  bool RequiresSMChange = CallAttrs.requiresSMChange();
+  if (RequiresSMChange) {
+    if (CallAttrs.caller().hasStreamingInterfaceOrBody())
+      PStateSM = DAG.getConstant(1, DL, MVT::i64);
+    else if (CallAttrs.caller().hasNonStreamingInterface())
+      PStateSM = DAG.getConstant(0, DL, MVT::i64);
+    else
+      PStateSM = getRuntimePStateSM(DAG, Chain, DL, MVT::i64);
+  } else {
+    PStateSM = DAG.getUNDEF(MVT::i64);
+  }
+
+  SDValue ZTFrameIdx;
+  bool ShouldPreserveZT0 = CallAttrs.requiresPreservingZT0();
+
+  // If the caller has ZT0 state which will not be preserved by the callee,
+  // spill ZT0 before the call.
+  if (ShouldPreserveZT0) {
+    MachineFrameInfo &MFI = MF.getFrameInfo();
+    unsigned ZTObj = getOrCreateZT0SpillSlot(FuncInfo, MFI);
+    ZTFrameIdx = DAG.getFrameIndex(
+        ZTObj,
+        DAG.getTargetLoweringInfo().getFrameIndexTy(DAG.getDataLayout()));
+
+    Chain = DAG.getNode(AArch64ISD::SAVE_ZT, DL, DAG.getVTList(MVT::Other),
+                        {Chain, DAG.getConstant(0, DL, MVT::i32), ZTFrameIdx});
+  }
+
+  // If caller shares ZT0 but the callee is not shared ZA, we need to stop
+  // PSTATE.ZA before the call if there is no lazy-save active.
+  bool DisableZA = CallAttrs.requiresDisablingZABeforeCall();
+  assert((!DisableZA || !RequiresLazySave) &&
+         "Lazy-save should have PSTATE.SM=1 on entry to the function");
+
+  if (DisableZA)
+    Chain = DAG.getNode(
+        AArch64ISD::SMSTOP, DL, DAG.getVTList(MVT::Other, MVT::Glue), Chain,
+        DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32));
+
+  return DAG.getMergeValues({PStateSM, Chain, DAG.getUNDEF(MVT::Glue)}, DL);
+}
+
+SDValue
+AArch64TargetLowering::LowerSME_CALL_SM_CHANGE(SDValue Op,
+                                               SelectionDAG &DAG) const {
+  SDLoc DL(Op);
+  SMECallAttrs CallAttrs = findSMECallAttrs(Op.getNode());
+  SDValue Chain = Op->getOperand(0);
+  SDValue PStateSM = Op->getOperand(2);
+  SDValue InGlue;
+
+  if (!Subtarget->isTargetDarwin() || Subtarget->hasSVE()) {
+    Chain = DAG.getNode(AArch64ISD::VG_SAVE, DL,
+                        DAG.getVTList(MVT::Other, MVT::Glue), Chain);
+    InGlue = Chain.getValue(1);
+  }
+
+  SDValue NewChain = changeStreamingMode(
+      DAG, DL, CallAttrs.callee().hasStreamingInterface(), Chain, InGlue,
+      getSMToggleCondition(CallAttrs), PStateSM);
+
+  return DAG.getMergeValues({NewChain, NewChain.getValue(1)}, DL);
+}
+
+SDValue AArch64TargetLowering::LowerSME_CALL_END(SDValue Op,
+                                                 SelectionDAG &DAG) const {
+  SDLoc DL(Op);
+  SMECallAttrs CallAttrs = findSMECallAttrs(Op.getNode());
+  bool ShouldPreserveZT0 = CallAttrs.requiresPreservingZT0();
+  auto &MF = DAG.getMachineFunction();
+  auto *FuncInfo = MF.getInfo<AArch64FunctionInfo>();
+  SDValue Result = Op->getOperand(0);
+  SDValue PStateSM = Op->getOperand(2);
+  SDValue InGlue = Op->getOperand(3);
+  const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo();
+  bool RequiresLazySave = CallAttrs.requiresLazySave();
+  bool RequiresSaveAllZA = CallAttrs.requiresPreservingAllZAState();
+  bool RequiresSMChange = CallAttrs.requiresSMChange();
+
+  if (RequiresSMChange) {
+    assert(PStateSM && "Expected a PStateSM to be set");
+    Result = changeStreamingMode(
+        DAG, DL, !CallAttrs.callee().hasStreamingInterface(), Result, InGlue,
+        getSMToggleCondition(CallAttrs), PStateSM);
+
+    if (!Subtarget->isTargetDarwin() || Subtarget->hasSVE()) {
+      InGlue = Result.getValue(1);
+      Result =
+          DAG.getNode(AArch64ISD::VG_RESTORE, DL,
+                      DAG.getVTList(MVT::Other, MVT::Glue), {Result, InGlue});
+    }
+  }
+
+  if (CallAttrs.requiresEnablingZAAfterCall())
+    // Unconditionally resume ZA.
+    Result = DAG.getNode(
+        AArch64ISD::SMSTART, DL, DAG.getVTList(MVT::Other, MVT::Glue), Result,
+        DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32));
+
+  if (ShouldPreserveZT0) {
+    MachineFrameInfo &MFI = MF.getFrameInfo();
+    unsigned ZTObj = getOrCreateZT0SpillSlot(FuncInfo, MFI);
+    SDValue ZTFrameIdx = DAG.getFrameIndex(
+        ZTObj,
+        DAG.getTargetLoweringInfo().getFrameIndexTy(DAG.getDataLayout()));
+    Result =
+        DAG.getNode(AArch64ISD::RESTORE_ZT, DL, DAG.getVTList(MVT::Other),
+                    {Result, DAG.getConstant(0, DL, MVT::i32), ZTFrameIdx});
+  }
+
+  if (RequiresLazySave) {
+    // Conditionally restore the lazy save using a pseudo node.
+    TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
+    SDValue RegMask = DAG.getRegisterMask(
+        TRI->SMEABISupportRoutinesCallPreservedMaskFromX0());
+    SDValue RestoreRoutine = DAG.getTargetExternalSymbol(
+        "__arm_tpidr2_restore", getPointerTy(DAG.getDataLayout()));
+    SDValue TPIDR2_EL0 = DAG.getNode(
+        ISD::INTRINSIC_W_CHAIN, DL, MVT::i64, Result,
+        DAG.getConstant(Intrinsic::aarch64_sme_get_tpidr2, DL, MVT::i32));
+
+    // Copy the address of the TPIDR2 block into X0 before 'calling' the
+    // RESTORE_ZA pseudo.
+    SDValue Glue;
+    SDValue TPIDR2Block = DAG.getFrameIndex(
+        TPIDR2.FrameIndex,
+        DAG.getTargetLoweringInfo().getFrameIndexTy(DAG.getDataLayout()));
+    Result = DAG.getCopyToReg(Result, DL, AArch64::X0, TPIDR2Block, Glue);
+    Result =
+        DAG.getNode(AArch64ISD::RESTORE_ZA, DL, MVT::Other,
+                    {Result, TPIDR2_EL0, DAG.getRegister(AArch64::X0, MVT::i64),
+                     RestoreRoutine, RegMask, Result.getValue(1)});
+
+    // Finally reset the TPIDR2_EL0 register to 0.
+    Result = DAG.getNode(
+        ISD::INTRINSIC_VOID, DL, MVT::Other, Result,
+        DAG.getConstant(Intrinsic::aarch64_sme_set_tpidr2, DL, MVT::i32),
+        DAG.getConstant(0, DL, MVT::i64));
+    TPIDR2.Uses++;
+  } else if (RequiresSaveAllZA) {
+    Result = emitSMEStateSaveRestore(*this, DAG, FuncInfo, DL, Result,
+                                     /*IsSave=*/false);
+  }
+
+  return DAG.getMergeValues({Result, DAG.getUNDEF(MVT::Glue)}, DL);
+}
+
 bool AArch64TargetLowering::mergeStoresAfterLegalization(EVT VT) const {
   return !Subtarget->useSVEForFixedLengthVectors();
 }
@@ -8949,46 +9201,6 @@ SDValue AArch64TargetLowering::changeStreamingMode(SelectionDAG &DAG, SDLoc DL,
   return DAG.getNode(Opcode, DL, DAG.getVTList(MVT::Other, MVT::Glue), Ops);
 }
 
-// Emit a call to __arm_sme_save or __arm_sme_restore.
-static SDValue emitSMEStateSaveRestore(const AArch64TargetLowering &TLI,
-                                       SelectionDAG &DAG,
-                                       AArch64FunctionInfo *Info, SDLoc DL,
-                                       SDValue Chain, bool IsSave) {
-  MachineFunction &MF = DAG.getMachineFunction();
-  AArch64FunctionInfo *FuncInfo = MF.getInfo<AArch64FunctionInfo>();
-  FuncInfo->setSMESaveBufferUsed();
-
-  TargetLowering::ArgListTy Args;
-  TargetLowering::ArgListEntry Entry;
-  Entry.Ty = PointerType::getUnqual(*DAG.getContext());
-  Entry.Node =
-      DAG.getCopyFromReg(Chain, DL, Info->getSMESaveBufferAddr(), MVT::i64);
-  Args.push_back(Entry);
-
-  SDValue Callee =
-      DAG.getExternalSymbol(IsSave ? "__arm_sme_save" : "__arm_sme_restore",
-                            TLI.getPointerTy(DAG.getDataLayout()));
-  auto *RetTy = Type::getVoidTy(*DAG.getContext());
-  TargetLowering::CallLoweringInfo CLI(DAG);
-  CLI.setDebugLoc(DL).setChain(Chain).setLibCallee(
-      CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1, RetTy,
-      Callee, std::move(Args));
-  return TLI.LowerCallTo(CLI).second;
-}
-
-static AArch64SME::ToggleCondition
-getSMToggleCondition(const SMECallAttrs &CallAttrs) {
-  if (!CallAttrs.caller().hasStreamingCompatibleInterface() ||
-      CallAttrs.caller().hasStreamingBody())
-    return AArch64SME::Always;
-  if (CallAttrs.callee().hasNonStreamingInterface())
-    return AArch64SME::IfCallerIsStreaming;
-  if (CallAttrs.callee().hasStreamingInterface())
-    return AArch64SME::IfCallerIsNonStreaming;
-
-  llvm_unreachable("Unsupported attributes");
-}
-
 /// LowerCall - Lower a call to a callseq_start + CALL + callseq_end chain,
 /// and add input and output parameter nodes.
 SDValue
@@ -9118,91 +9330,25 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
   // Determine whether we need any streaming mode changes.
   SMECallAttrs CallAttrs = getSMECallAttrs(MF.getFunction(), CLI);
 
-  auto DescribeCallsite =
-      [&](OptimizationRemarkAnalysis &R) -> OptimizationRemarkAnalysis & {
-    R << "call from '" << ore::NV("Caller", MF.getName()) << "' to '";
-    if (auto *ES = dyn_cast<ExternalSymbolSDNode>(CLI.Callee))
-      R << ore::NV("Callee", ES->getSymbol());
-    else if (CLI.CB && CLI.CB->getCalledFunction())
-      R << ore::NV("Callee", CLI.CB->getCalledFunction()->getName());
-    else
-      R << "unknown callee";
-    R << "'";
-    return R;
-  };
-
+  SDValue SMECallStart;
+  SDValue PStateSM;
   bool RequiresLazySave = CallAttrs.requiresLazySave();
   bool RequiresSaveAllZA = CallAttrs.requiresPreservingAllZAState();
-  if (RequiresLazySave) {
-    TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
-    SDValue TPIDR2ObjAddr = DAG.getFrameIndex(
-        TPIDR2.FrameIndex,
-        DAG.getTargetLoweringInfo().getFrameIndexTy(DAG.getDataLayout()));
-    Chain = DAG.getNode(
-        ISD::INTRINSIC_VOID, DL, MVT::Other, Chain,
-        DAG.getConstant(Intrinsic::aarch64_sme_set_tpidr2, DL, MVT::i32),
-        TPIDR2ObjAddr);
-    OptimizationRemarkEmitter ORE(&MF.getFunction());
-    ORE.emit([&]() {
-      auto R = CLI.CB ? OptimizationRemarkAnalysis("sme", "SMELazySaveZA",
-                                                   CLI.CB)
-                      : OptimizationRemarkAnalysis("sme", "SMELazySaveZA",
-                                                   &MF.getFunction());
-      return DescribeCallsite(R) << " sets up a lazy save for ZA";
-    });
-  } else if (RequiresSaveAllZA) {
-    assert(!CallAttrs.callee().hasSharedZAInterface() &&
-           "Cannot share state that may not exist");
-    Chain = emitSMEStateSaveRestore(*this, DAG, FuncInfo, DL, Chain,
-                                    /*IsSave=*/true);
-  }
-
-  SDValue PStateSM;
   bool RequiresSMChange = CallAttrs.requiresSMChange();
-  if (RequiresSMChange) {
-    if (CallAttrs.caller().hasStreamingInterfaceOrBody())
-      PStateSM = DAG.getConstant(1, DL, MVT::i64);
-    else if (CallAttrs.caller().hasNonStreamingInterface())
-      PStateSM = DAG.getConstant(0, DL, MVT::i64);
-    else
-      PStateSM = getRuntimePStateSM(DAG, Chain, DL, MVT::i64);
-    OptimizationRemarkEmitter ORE(&MF.getFunction());
-    ORE.emit([&]() {
-      auto R = CLI.CB ? OptimizationRemarkAnalysis("sme", "SMETransition",
-                                                   CLI.CB)
-                      : OptimizationRemarkAnalysis("sme", "SMETransition",
-                                                   &MF.getFunction());
-      DescribeCallsite(R) << " requires a streaming mode transition";
-      return R;
-    });
-  }
-
-  SDValue ZTFrameIdx;
-  MachineFrameInfo &MFI = MF.getFrameInfo();
   bool ShouldPreserveZT0 = CallAttrs.requiresPreservingZT0();
-
-  // If the caller has ZT0 state which will not be preserved by the callee,
-  // spill ZT0 before the call.
-  if (ShouldPreserveZT0) {
-    unsigned ZTObj = MFI.CreateSpillStackObject(64, Align(16));
-    ZTFrameIdx = DAG.getFrameIndex(
-        ZTObj,
-        DAG.getTargetLoweringInfo().getFrameIndexTy(DAG.getDataLayout()));
-
-    Chain = DAG.getNode(AArch64ISD::SAVE_ZT, DL, DAG.getVTList(MVT::Other),
-                        {Chain, DAG.getConstant(0, DL, MVT::i32), ZTFrameIdx});
-  }
-
-  // If caller shares ZT0 but the callee is not shared ZA, we need to stop
-  // PSTATE.ZA before the call if there is no lazy-save active.
   bool DisableZA = CallAttrs.requiresDisablingZABeforeCall();
-  assert((!DisableZA || !RequiresLazySave) &&
-         "Lazy-save should have PSTATE.SM=1 on entry to the function");
+  bool IsSMECall = RequiresLazySave || RequiresSaveAllZA || RequiresSMChange ||
+                   ShouldPreserveZT0 || DisableZA;
 
-  if (DisableZA)
-    Chain = DAG.getNode(
-        AArch64ISD::SMSTOP, DL, DAG.getVTList(MVT::Other, MVT::Glue), Chain,
-        DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32));
+  if (IsSMECall) {
+    PStateSM = DAG.getNode(
+        AArch64ISD::SME_CALL_START, DL,
+        DAG.getVTList(MVT::i64, MVT::Other, MVT::Glue), Chain,
+        DAG.getTargetConstant(unsigned(CallAttrs.caller()), DL, MVT::i32),
+        DAG.getTargetConstant(unsigned(CallAttrs.callee()), DL, MVT::i32),
+        DAG.getTargetConstant(unsigned(CallAttrs.callsite()), DL, MVT::i32));
+    Chain = SMECallStart = PStateSM.getValue(1);
+  }
 
   // Adjust the stack pointer for the new arguments...
   // These operations are automatically eliminated by the prolog/epilog pass
@@ -9470,18 +9616,13 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
     Chain = DAG.getNode(ISD::TokenFactor, DL, MVT::Other, MemOpChains);
 
   SDValue InGlue;
+  SDValue SMECallSMChange;
   if (RequiresSMChange) {
-    if (!Subtarget->isTargetDarwin() || Subtarget->hasSVE()) {
-      Chain = DAG.getNode(AArch64ISD::VG_SAVE, DL,
-                          DAG.getVTList(MVT::Other, MVT::Glue), Chain);
-      InGlue = Chain.getValue(1);
-    }
-
-    SDValue NewChain = changeStreamingMode(
-        DAG, DL, CallAttrs.callee().hasStreamingInterface(), Chain, InGlue,
-        getSMToggleCondition(CallAttrs), PStateSM);
-    Chain = NewChain.getValue(0);
-    InGlue = NewChain.getValue(1);
+    SMECallSMChange = DAG.getNode(AArch64ISD::SME_CALL_SM_CHANGE, DL,
+                                  DAG.getVTList(MVT::Other, MVT::Glue), Chain,
+                                  SMECallStart, PStateSM);
+    Chain = SMECallSMChange;
+    InGlue = SMECallSMChange.getValue(1);
   }
 
   // Build a sequence of copy-to-reg nodes chained together with token chain
@@ -9662,63 +9803,11 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
   if (!Ins.empty())
     InGlue = Result.getValue(Result->getNumValues() - 1);
 
-  if (RequiresSMChange) {
-    assert(PStateSM && "Expected a PStateSM to be set");
-    Result = changeStreamingMode(
-        DAG, DL, !CallAttrs.callee().hasStreamingInterface(), Result, InGlue,
-        getSMToggleCondition(CallAttrs), PStateSM);
-
-    if (!Subtarget->isTargetDarwin() || Subtarget->hasSVE()) {
-      InGlue = Result.getValue(1);
-      Result =
-          DAG.getNode(AArch64ISD::VG_RESTORE, DL,
-                      DAG.getVTList(MVT::Other, MVT::Glue), {Result, InGlue});
-    }
-  }
-
-  if (CallAttrs.requiresEnablingZAAfterCall())
-    // Unconditionally resume ZA.
-    Result = DAG.getNode(
-        AArch64ISD::SMSTART, DL, DAG.getVTList(MVT::Other, MVT::Glue), Result,
-        DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32));
-
-  if (ShouldPreserveZT0)
-    Result =
-        DAG.getNode(AArch64ISD::RESTORE_ZT, DL, DAG.getVTList(MVT::Other),
-                    {Result, DAG.getConstant(0, DL, MVT::i32), ZTFrameIdx});
-
-  if (RequiresLazySave) {
-    // Conditionally restore the lazy save using a pseudo node.
-    TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
-    SDValue RegMask = DAG.getRegisterMask(
-        TRI->SMEABISupportRoutinesCallPreservedMaskFromX0());
-    SDValue RestoreRoutine = DAG.getTargetExternalSymbol(
-        "__arm_tpidr2_restore", getPointerTy(DAG.getDataLayout()));
-    SDValue TPIDR2_EL0 = DAG.getNode(
-        ISD::INTRINSIC_W_CHAIN, DL, MVT::i64, Result,
-        DAG.getConstant(Intrinsic::aarch64_sme_get_tpidr2, DL, MVT::i32));
-
-    // Copy the address of the TPIDR2 block into X0 before 'calling' the
-    // RESTORE_ZA pseudo.
-    SDValue Glue;
-    SDValue TPIDR2Block = DAG.getFrameIndex(
-        TPIDR2.FrameIndex,
-        DAG.getTargetLoweringInfo().getFrameIndexTy(DAG.getDataLayout()));
-    Result = DAG.getCopyToReg(Result, DL, AArch64::X0, TPIDR2Block, Glue);
-    Result =
-        DAG.getNode(AArch64ISD::RESTORE_ZA, DL, MVT::Other,
-                    {Result, TPIDR2_EL0, DAG.getRegister(AArch64::X0, MVT::i64),
-                     RestoreRoutine, RegMask, Result.getValue(1)});
-
-    // Finally reset the TPIDR2_EL0 register to 0.
-    Result = DAG.getNode(
-        ISD::INTRINSIC_VOID, DL, MVT::Other, Result,
-        DAG.getConstant(Intrinsic::aarch64_sme_set_tpidr2, DL, MVT::i32),
-        DAG.getConstant(0, DL, MVT::i64));
-    TPIDR2.Uses++;
-  } else if (RequiresSaveAllZA) {
-    Result = emitSMEStateSaveRestore(*this, DAG, FuncInfo, DL, Result,
-                                     /*IsSave=*/false);
+  if (IsSMECall) {
+    Result = DAG.getNode(AArch64ISD::SME_CALL_END, DL,
+                         DAG.getVTList(MVT::Other, MVT::Glue), Result,
+                         SMECallSMChange ? SMECallSMChange : SMECallStart,
+                         PStateSM ? PStateSM : DAG.getUNDEF(MVT::i64), InGlue);
   }
 
   if (RequiresSMChange || RequiresLazySave || ShouldPreserveZT0 ||
@@ -27828,6 +27917,18 @@ AArch64TargetLowering::getPreferredVectorAction(MVT VT) const {
   return TargetLoweringBase::getPreferredVectorAction(VT);
 }
 
+TargetLoweringBase::LegalizeAction
+AArch64TargetLowering::getCustomOperationAction(SDNode &N) const {
+  switch (N.getOpcode()) {
+  default:
+    return Legal;
+  case AArch64ISD::SME_CALL_START:
+  case AArch64ISD::SME_CALL_SM_CHANGE:
+  case AArch64ISD::SME_CALL_END:
+    return Custom;
+  }
+}
+
 // In v8.4a, ldp and stp instructions are guaranteed to be single-copy atomic
 // provided the address is 16-byte aligned.
 bool AArch64TargetLowering::isOpSuitableForLDPSTP(const Instruction *I) const {
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index b59526bf01888..174e19f3badae 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -79,6 +79,10 @@ enum NodeType : unsigned {
   RESTORE_ZT,
   SAVE_ZT,
 
+  SME_CALL_START,
+  SME_CALL_SM_CHANGE,
+  SME_CALL_END,
+
   // A call with the callee in x16, i.e. "blr x16".
   CALL_ARM64EC_TO_X64,
 
@@ -823,6 +827,9 @@ class AArch64TargetLowering : public TargetLowering {
   TargetLoweringBase::LegalizeTypeAction
   getPreferredVectorAction(MVT VT) const override;
 
+  TargetLoweringBase::LegalizeAction
+  getCustomOperationAction(SDNode &) const override;
+
   /// If the target has a standard location for the stack protector cookie,
   /// returns the address of that location. Otherwise, returns nullptr.
   Value *getIRStackGuard(IRBuilderBase &IRB) const override;
@@ -1028,6 +1035,11 @@ class AArch64TargetLowering : public TargetLowering {
   /// True if stack clash protection is enabled for this functions.
   bool hasInlineStackProbe(const MachineFunction &MF) const override;
 
+  // Returns the runtime value for PSTATE.SM by generating a call to
+  // __arm_sme_state.
+  SDValue getRuntimePStateSM(SelectionDAG &DAG, SDValue Chain, SDLoc DL,
+                             EVT VT) const;
+
 private:
   /// Keep a pointer to the AArch64Subtarget around so that we can
   /// make the right decision when generating code for different targets.
@@ -1211,6 +1223,9 @@ class AArch64TargetLowering : public TargetLowering {
   SDValue LowerWindowsDYNAMIC_STACKALLOC(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerInlineDYNAMIC_STACKALLOC(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerDYNAMIC_STACKALLOC(SDValue Op, SelectionDAG &DAG) const;
+  SDValue LowerSME_CALL_START(SDValue Op, SelectionDAG &DAG) const;
+  SDValue LowerSME_CALL_SM_CHANGE(SDValue Op, SelectionDAG &DAG) const;
+  SDValue LowerSME_CALL_END(SDValue Op, SelectionDAG &DAG) const;
 
   SDValue LowerAVG(SDValue Op, SelectionDAG &DAG, unsigned NewOp) const;
 
@@ -1347,11 +1362,6 @@ class AArch64TargetLowering : public TargetLowering {
   // This function does not handle predicate bitcasts.
   SDValue getSVESafeBitCast(EVT VT, SDValue Op, SelectionDAG &DAG) const;
 
-  // Returns the runtime value for PSTATE.SM by generating a call to
-  // __arm_sme_state.
-  SDValue getRuntimePStateSM(SelectionDAG &DAG, SDValue Chain, SDLoc DL,
-                             EVT VT) const;
-
   bool preferScalarizeSplat(SDNode *N) const override;
 
   unsigned getMinimumJumpTableEntries() const override;
diff --git a/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h b/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h
index d3026ca45c349..02cd0398bea04 100644
--- a/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h
@@ -245,6 +245,9 @@ class AArch64FunctionInfo final : public MachineFunctionInfo {
   int64_t VGIdx = std::numeric_limits<int>::max();
   int64_t StreamingVGIdx = std::numeric_limits<int>::max();
 
+  // The stack slot where ZT0 is stored.
+  int64_t ZT0Idx = std::numeric_limits<int>::max();
+
 public:
   AArch64FunctionInfo(const Function &F, const AArch64Subtarget *STI);
 
@@ -275,6 +278,9 @@ class AArch64FunctionInfo final : public MachineFunctionInfo {
   int64_t getStreamingVGIdx() const { return StreamingVGIdx; };
   void setStreamingVGIdx(unsigned FrameIdx) { StreamingVGIdx = FrameIdx; };
 
+  int64_t getZT0Idx() const { return ZT0Idx; };
+  void setZT0Idx(unsigned FrameIdx) { ZT0Idx = FrameIdx; };
+
   bool isSVECC() const { return IsSVECC; };
   void setIsSVECC(bool s) { IsSVECC = s; };
 
diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
index f1be0ecbee7ed..4d136f1f65263 100644
--- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
+++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
@@ -145,6 +145,8 @@ class SMEAttrs {
     return Bitmask == Other.Bitmask;
   }
 
+  explicit operator unsigned() const { return Bitmask; }
+
 private:
   void addKnownFunctionAttrs(StringRef FuncName);
 };

>From 2efbedd239a3a8cfa002065bfde3570a9b999fb0 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Fri, 30 May 2025 08:45:20 +0000
Subject: [PATCH 3/3] [AArch64][SME] Merge back-to-back SME call regions

---
 .../Target/AArch64/AArch64ISelLowering.cpp    | 61 +++++++++++++++++++
 .../AArch64/Utils/AArch64SMEAttributes.h      |  7 +++
 .../CodeGen/AArch64/sme-lazy-save-call.ll     | 19 ++----
 .../test/CodeGen/AArch64/sme-peephole-opts.ll | 27 ++------
 4 files changed, 79 insertions(+), 35 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 9866149868fd7..54c616f15b8cd 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -26719,6 +26719,65 @@ static SDValue performSHLCombine(SDNode *N,
   return DAG.getNode(ISD::AND, DL, VT, NewShift, NewRHS);
 }
 
+static SDValue performSMECallCombine(SDNode *SMECallEnd,
+                                     TargetLowering::DAGCombinerInfo &DCI,
+                                     SelectionDAG &DAG) {
+  SDLoc DL(SMECallEnd);
+  SDNode *SMECallStart = findSMECallStart(SMECallEnd);
+  SDValue StartOutGlue = SDValue(SMECallStart, 2);
+  if (!StartOutGlue.use_empty())
+    return SDValue();
+
+  SDValue StartChain = SMECallStart->getOperand(0);
+  if (StartChain->getOpcode() != AArch64ISD::SME_CALL_END)
+    return SDValue();
+
+  // TODO: This probably does not need to be a CALLSEQ_END.
+  SDNode *PrevSMECallEnd = StartChain.getNode();
+  SDValue PrevCallOutChain = PrevSMECallEnd->getOperand(0);
+  if (PrevCallOutChain->getOpcode() != ISD::CALLSEQ_END)
+    return SDValue();
+
+  SDNode *PrevSMECallStart = findSMECallStart(PrevSMECallEnd);
+  SMECallAttrs CallAttrs = findSMECallAttrs(SMECallStart);
+  SMECallAttrs PrevCallAttrs = findSMECallAttrs(PrevSMECallStart);
+
+  // TODO: Handle case where we're already in (e.g.) streaming mode and just
+  // need to enable ZA etc.
+  if (CallAttrs != PrevCallAttrs)
+    return SDValue();
+
+  SDNode *MaybeSMSwitch = PrevSMECallEnd->getOperand(1).getNode();
+  if (MaybeSMSwitch->getOpcode() == AArch64ISD::SME_CALL_SM_CHANGE) {
+    SDNode *SMSwitch = SMECallEnd->getOperand(1).getNode();
+    // Remove duplicate SME_CALL_SM_CHANGE.
+    // FIXME: Can we avoid adding fake glue? This is needed as we need to remove
+    // this (duplicate) SME_CALL_SM_CHANGE, but it has been glued to another
+    // node so we need something to replace the glue.
+    SDValue FakeGlue = DAG.getUNDEF(MVT::Glue);
+    SDValue NopSMESwitch = DAG.getMergeValues(
+        {SMSwitch->getOperand(0), FakeGlue}, SDLoc(SMECallEnd));
+    DAG.ReplaceAllUsesWith(SMSwitch, NopSMESwitch.getNode());
+  }
+
+  // Update the last SME_CALL_END to point to the SME_CALL_START or
+  // SME_CALL_SM_CHANGE from the previous SME_CALL region:
+  DAG.UpdateNodeOperands(
+      SMECallEnd,
+      /*Chain=*/SMECallEnd->getOperand(0),
+      /*SMECallStart Or SMSwitch=*/PrevSMECallEnd->getOperand(1),
+      /*PStateSM=*/SMECallEnd->getOperand(2),
+      /*InGlue=*/SMECallEnd->getOperand(3));
+  // Remove the SME_CALL_END for the previous SME_CALL region:
+  DAG.ReplaceAllUsesWith(PrevSMECallEnd, PrevCallOutChain.getNode());
+  // Remove the SME_CALL_START for the current/next SME_CALL region:
+  SDValue NopSMECallStart = DAG.getMergeValues(
+      {SDValue(PrevSMECallStart, 0), PrevCallOutChain, DAG.getUNDEF(MVT::Glue)},
+      DL);
+  DAG.ReplaceAllUsesWith(SMECallStart, NopSMECallStart.getNode());
+  return SDValue();
+}
+
 SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
                                                  DAGCombinerInfo &DCI) const {
   SelectionDAG &DAG = DCI.DAG;
@@ -26880,6 +26939,8 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
   case AArch64ISD::UMULL:
   case AArch64ISD::PMULL:
     return performMULLCombine(N, DCI, DAG);
+  case AArch64ISD::SME_CALL_END:
+    return performSMECallCombine(N, DCI, DAG);
   case ISD::INTRINSIC_VOID:
   case ISD::INTRINSIC_W_CHAIN:
     switch (N->getConstantOperandVal(1)) {
diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
index 4d136f1f65263..2c3c8dff691a9 100644
--- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
+++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
@@ -203,6 +203,13 @@ class SMECallAttrs {
     return caller().hasAgnosticZAInterface() &&
            !callee().hasAgnosticZAInterface() && !callee().isSMEABIRoutine();
   }
+
+  bool operator==(SMECallAttrs const &Other) const {
+    return caller() == Other.caller() && callee() == Other.callee() &&
+           callsite() == Other.callsite();
+  }
+
+  bool operator!=(SMECallAttrs const &Other) const { return !(*this == Other); }
 };
 
 } // namespace llvm
diff --git a/llvm/test/CodeGen/AArch64/sme-lazy-save-call.ll b/llvm/test/CodeGen/AArch64/sme-lazy-save-call.ll
index 26c6dc4eb978b..cb72a8eb3d5ac 100644
--- a/llvm/test/CodeGen/AArch64/sme-lazy-save-call.ll
+++ b/llvm/test/CodeGen/AArch64/sme-lazy-save-call.ll
@@ -41,16 +41,17 @@ define void @test_lazy_save_2_callees() nounwind "aarch64_inout_za" {
 ; CHECK-LABEL: test_lazy_save_2_callees:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    stp x29, x30, [sp, #-32]! // 16-byte Folded Spill
-; CHECK-NEXT:    stp x20, x19, [sp, #16] // 16-byte Folded Spill
+; CHECK-NEXT:    str x19, [sp, #16] // 8-byte Folded Spill
 ; CHECK-NEXT:    mov x29, sp
 ; CHECK-NEXT:    sub sp, sp, #16
 ; CHECK-NEXT:    rdsvl x8, #1
 ; CHECK-NEXT:    mov x9, sp
 ; CHECK-NEXT:    msub x9, x8, x8, x9
 ; CHECK-NEXT:    mov sp, x9
-; CHECK-NEXT:    sub x20, x29, #16
+; CHECK-NEXT:    sub x10, x29, #16
 ; CHECK-NEXT:    stp x9, x8, [x29, #-16]
-; CHECK-NEXT:    msr TPIDR2_EL0, x20
+; CHECK-NEXT:    msr TPIDR2_EL0, x10
+; CHECK-NEXT:    bl private_za_callee
 ; CHECK-NEXT:    bl private_za_callee
 ; CHECK-NEXT:    smstart za
 ; CHECK-NEXT:    mrs x8, TPIDR2_EL0
@@ -60,18 +61,8 @@ define void @test_lazy_save_2_callees() nounwind "aarch64_inout_za" {
 ; CHECK-NEXT:    bl __arm_tpidr2_restore
 ; CHECK-NEXT:  .LBB1_2:
 ; CHECK-NEXT:    msr TPIDR2_EL0, xzr
-; CHECK-NEXT:    msr TPIDR2_EL0, x20
-; CHECK-NEXT:    bl private_za_callee
-; CHECK-NEXT:    smstart za
-; CHECK-NEXT:    mrs x8, TPIDR2_EL0
-; CHECK-NEXT:    sub x0, x29, #16
-; CHECK-NEXT:    cbnz x8, .LBB1_4
-; CHECK-NEXT:  // %bb.3:
-; CHECK-NEXT:    bl __arm_tpidr2_restore
-; CHECK-NEXT:  .LBB1_4:
-; CHECK-NEXT:    msr TPIDR2_EL0, xzr
 ; CHECK-NEXT:    mov sp, x29
-; CHECK-NEXT:    ldp x20, x19, [sp, #16] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr x19, [sp, #16] // 8-byte Folded Reload
 ; CHECK-NEXT:    ldp x29, x30, [sp], #32 // 16-byte Folded Reload
 ; CHECK-NEXT:    ret
   call void @private_za_callee()
diff --git a/llvm/test/CodeGen/AArch64/sme-peephole-opts.ll b/llvm/test/CodeGen/AArch64/sme-peephole-opts.ll
index 130a316bcc2ba..400f8937df90c 100644
--- a/llvm/test/CodeGen/AArch64/sme-peephole-opts.ll
+++ b/llvm/test/CodeGen/AArch64/sme-peephole-opts.ll
@@ -75,21 +75,11 @@ define void @test2() nounwind "aarch64_pstate_sm_compatible" {
 ; CHECK-NEXT:    smstop sm
 ; CHECK-NEXT:  .LBB2_2:
 ; CHECK-NEXT:    bl callee
+; CHECK-NEXT:    bl callee
 ; CHECK-NEXT:    tbz w19, #0, .LBB2_4
 ; CHECK-NEXT:  // %bb.3:
 ; CHECK-NEXT:    smstart sm
 ; CHECK-NEXT:  .LBB2_4:
-; CHECK-NEXT:    bl __arm_sme_state
-; CHECK-NEXT:    and x19, x0, #0x1
-; CHECK-NEXT:    tbz w19, #0, .LBB2_6
-; CHECK-NEXT:  // %bb.5:
-; CHECK-NEXT:    smstop sm
-; CHECK-NEXT:  .LBB2_6:
-; CHECK-NEXT:    bl callee
-; CHECK-NEXT:    tbz w19, #0, .LBB2_8
-; CHECK-NEXT:  // %bb.7:
-; CHECK-NEXT:    smstart sm
-; CHECK-NEXT:  .LBB2_8:
 ; CHECK-NEXT:    ldp d9, d8, [sp, #48] // 16-byte Folded Reload
 ; CHECK-NEXT:    ldr x19, [sp, #80] // 8-byte Folded Reload
 ; CHECK-NEXT:    ldp d11, d10, [sp, #32] // 16-byte Folded Reload
@@ -252,22 +242,17 @@ define float @test6(float %f) nounwind "aarch64_pstate_sm_enabled" {
 define void @test7() nounwind "aarch64_inout_zt0" {
 ; CHECK-LABEL: test7:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    sub sp, sp, #144
-; CHECK-NEXT:    stp x30, x19, [sp, #128] // 16-byte Folded Spill
-; CHECK-NEXT:    add x19, sp, #64
-; CHECK-NEXT:    str zt0, [x19]
-; CHECK-NEXT:    smstop za
-; CHECK-NEXT:    bl callee
-; CHECK-NEXT:    smstart za
-; CHECK-NEXT:    ldr zt0, [x19]
+; CHECK-NEXT:    sub sp, sp, #80
+; CHECK-NEXT:    stp x30, x19, [sp, #64] // 16-byte Folded Spill
 ; CHECK-NEXT:    mov x19, sp
 ; CHECK-NEXT:    str zt0, [x19]
 ; CHECK-NEXT:    smstop za
 ; CHECK-NEXT:    bl callee
+; CHECK-NEXT:    bl callee
 ; CHECK-NEXT:    smstart za
 ; CHECK-NEXT:    ldr zt0, [x19]
-; CHECK-NEXT:    ldp x30, x19, [sp, #128] // 16-byte Folded Reload
-; CHECK-NEXT:    add sp, sp, #144
+; CHECK-NEXT:    ldp x30, x19, [sp, #64] // 16-byte Folded Reload
+; CHECK-NEXT:    add sp, sp, #80
 ; CHECK-NEXT:    ret
   call void @callee()
   call void @callee()



More information about the llvm-commits mailing list