[llvm] [AArch64][SME] Handle SME state around TLS-descriptor calls (PR #155608)

Benjamin Maxwell via llvm-commits llvm-commits at lists.llvm.org
Tue Nov 4 08:57:01 PST 2025


https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/155608

>From b2ccb35fe3c946f3336b8b94a515f5ac3cc138e5 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Tue, 26 Aug 2025 16:07:21 +0000
Subject: [PATCH 1/2] [AArch64][SME] Handle SME state around TLS-descriptor
 calls

This patch ensures we switch out of streaming mode before TLS-descriptor
calls (and switch back in afterwards). ZA state is also preserved when
using the new SME ABI lowering (`-aarch64-new-sme-abi`).

Fixes #152165
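
As a rough source-level sketch (not taken from this patch; the function name
is made up), the affected pattern is a streaming function reading a
thread_local variable under -fPIC. The TLS-descriptor resolver is an ordinary
non-streaming callee, so the call must be wrapped with smstop/smstart (and,
for functions with ZA state, a lazy ZA save):

  __thread int x;

  // Hypothetical streaming function: under -fPIC the access to `x` lowers to
  // a TLS-descriptor call, which previously left PSTATE.SM (and ZA) unmanaged.
  int read_x(void) __arm_streaming {
    return x;
  }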
---
 .../Target/AArch64/AArch64ISelLowering.cpp    |  24 ++-
 llvm/lib/Target/AArch64/AArch64InstrInfo.td   |   4 +-
 llvm/lib/Target/AArch64/MachineSMEABIPass.cpp |   6 +
 llvm/test/CodeGen/AArch64/sme-dynamic-tls.ll  | 160 ++++++++++++++++++
 4 files changed, 190 insertions(+), 4 deletions(-)
 create mode 100644 llvm/test/CodeGen/AArch64/sme-dynamic-tls.ll

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 60aa61e993b26..b85639cfb4d8f 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -10623,16 +10623,36 @@ SDValue AArch64TargetLowering::LowerELFTLSDescCallSeq(SDValue SymAddr,
                                                       const SDLoc &DL,
                                                       SelectionDAG &DAG) const {
   EVT PtrVT = getPointerTy(DAG.getDataLayout());
+  auto &MF = DAG.getMachineFunction();
+  auto *FuncInfo = MF.getInfo<AArch64FunctionInfo>();
 
+  SDValue Glue;
   SDValue Chain = DAG.getEntryNode();
   SDVTList NodeTys = DAG.getVTList(MVT::Other, MVT::Glue);
 
+  SMECallAttrs TLSCallAttrs(FuncInfo->getSMEFnAttrs(), {}, SMEAttrs::Normal);
+  bool RequiresSMChange = TLSCallAttrs.requiresSMChange();
+
+  if (RequiresSMChange) {
+    Chain = changeStreamingMode(DAG, DL, /*Enable=*/false, Chain, Glue,
+                                getSMToggleCondition(TLSCallAttrs));
+    Glue = Chain.getValue(1);
+  }
+
   unsigned Opcode =
       DAG.getMachineFunction().getInfo<AArch64FunctionInfo>()->hasELFSignedGOT()
           ? AArch64ISD::TLSDESC_AUTH_CALLSEQ
           : AArch64ISD::TLSDESC_CALLSEQ;
-  Chain = DAG.getNode(Opcode, DL, NodeTys, {Chain, SymAddr});
-  SDValue Glue = Chain.getValue(1);
+  SDValue Ops[] = {Chain, SymAddr, Glue};
+  Chain = DAG.getNode(Opcode, DL, NodeTys,
+                      Glue ? ArrayRef(Ops) : ArrayRef(Ops).drop_back());
+  Glue = Chain.getValue(1);
+
+  if (RequiresSMChange) {
+    Chain = changeStreamingMode(DAG, DL, /*Enable=*/true, Chain, Glue,
+                                getSMToggleCondition(TLSCallAttrs));
+    Glue = Chain.getValue(1);
+  }
 
   return DAG.getCopyFromReg(Chain, DL, AArch64::X0, PtrVT, Glue);
 }
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index 2871a20e28b65..027ed215fe727 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -1034,11 +1034,11 @@ def AArch64uitof: SDNode<"AArch64ISD::UITOF", SDT_AArch64ITOF>;
 // offset of a variable into X0, using the TLSDesc model.
 def AArch64tlsdesc_callseq : SDNode<"AArch64ISD::TLSDESC_CALLSEQ",
                                     SDT_AArch64TLSDescCallSeq,
-                                    [SDNPOutGlue, SDNPHasChain, SDNPVariadic]>;
+                                    [SDNPOutGlue, SDNPOptInGlue, SDNPHasChain, SDNPVariadic]>;
 
 def AArch64tlsdesc_auth_callseq : SDNode<"AArch64ISD::TLSDESC_AUTH_CALLSEQ",
                                     SDT_AArch64TLSDescCallSeq,
-                                    [SDNPOutGlue, SDNPHasChain, SDNPVariadic]>;
+                                    [SDNPOutGlue, SDNPOptInGlue, SDNPHasChain, SDNPVariadic]>;
 
 def AArch64WrapperLarge : SDNode<"AArch64ISD::WrapperLarge",
                                  SDT_AArch64WrapperLarge>;
diff --git a/llvm/lib/Target/AArch64/MachineSMEABIPass.cpp b/llvm/lib/Target/AArch64/MachineSMEABIPass.cpp
index 7cb500394cec2..0f6d930cc1fba 100644
--- a/llvm/lib/Target/AArch64/MachineSMEABIPass.cpp
+++ b/llvm/lib/Target/AArch64/MachineSMEABIPass.cpp
@@ -231,6 +231,12 @@ getZAStateBeforeInst(const TargetRegisterInfo &TRI, MachineInstr &MI,
   if (MI.getOpcode() == AArch64::RequiresZASavePseudo)
     return {ZAState::LOCAL_SAVED, std::prev(InsertPt)};
 
+  // TLS-descriptor calls don't use the standard call lowering, so handle them
+  // as a special case here. Assume a private ZA interface.
+  if (MI.getOpcode() == AArch64::TLSDESC_CALLSEQ ||
+      MI.getOpcode() == AArch64::TLSDESC_AUTH_CALLSEQ)
+    return {ZAState::LOCAL_SAVED, InsertPt};
+
   if (MI.isReturn())
     return {ZAOffAtReturn ? ZAState::OFF : ZAState::ACTIVE, InsertPt};
 
diff --git a/llvm/test/CodeGen/AArch64/sme-dynamic-tls.ll b/llvm/test/CodeGen/AArch64/sme-dynamic-tls.ll
new file mode 100644
index 0000000000000..0f6c0c9705cc5
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sme-dynamic-tls.ll
@@ -0,0 +1,160 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mtriple=aarch64-unknown-linux-gnu -mattr=+sme -aarch64-new-sme-abi -relocation-model=pic < %s | FileCheck %s
+
+@x = external thread_local local_unnamed_addr global i32, align 4
+
+define i32 @load_tls_streaming_compat() nounwind "aarch64_pstate_sm_compatible" {
+; CHECK-LABEL: load_tls_streaming_compat:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    stp d15, d14, [sp, #-80]! // 16-byte Folded Spill
+; CHECK-NEXT:    stp d13, d12, [sp, #16] // 16-byte Folded Spill
+; CHECK-NEXT:    stp d11, d10, [sp, #32] // 16-byte Folded Spill
+; CHECK-NEXT:    stp d9, d8, [sp, #48] // 16-byte Folded Spill
+; CHECK-NEXT:    str x30, [sp, #64] // 8-byte Folded Spill
+; CHECK-NEXT:    bl __arm_sme_state
+; CHECK-NEXT:    mov x8, x0
+; CHECK-NEXT:    tbz w8, #0, .LBB0_2
+; CHECK-NEXT:  // %bb.1: // %entry
+; CHECK-NEXT:    smstop sm
+; CHECK-NEXT:  .LBB0_2: // %entry
+; CHECK-NEXT:    adrp x0, :tlsdesc:x
+; CHECK-NEXT:    ldr x1, [x0, :tlsdesc_lo12:x]
+; CHECK-NEXT:    add x0, x0, :tlsdesc_lo12:x
+; CHECK-NEXT:    .tlsdesccall x
+; CHECK-NEXT:    blr x1
+; CHECK-NEXT:    tbz w8, #0, .LBB0_4
+; CHECK-NEXT:  // %bb.3: // %entry
+; CHECK-NEXT:    smstart sm
+; CHECK-NEXT:  .LBB0_4: // %entry
+; CHECK-NEXT:    mrs x8, TPIDR_EL0
+; CHECK-NEXT:    ldp d9, d8, [sp, #48] // 16-byte Folded Reload
+; CHECK-NEXT:    ldp d11, d10, [sp, #32] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr w0, [x8, x0]
+; CHECK-NEXT:    ldp d13, d12, [sp, #16] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr x30, [sp, #64] // 8-byte Folded Reload
+; CHECK-NEXT:    ldp d15, d14, [sp], #80 // 16-byte Folded Reload
+; CHECK-NEXT:    ret
+entry:
+  %0 = tail call align 4 ptr @llvm.threadlocal.address.p0(ptr align 4 @x)
+  %1 = load i32, ptr %0, align 4
+  ret i32 %1
+}
+
+define i32 @load_tls_streaming() nounwind "aarch64_pstate_sm_enabled" {
+; CHECK-LABEL: load_tls_streaming:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    stp d15, d14, [sp, #-80]! // 16-byte Folded Spill
+; CHECK-NEXT:    stp d13, d12, [sp, #16] // 16-byte Folded Spill
+; CHECK-NEXT:    stp d11, d10, [sp, #32] // 16-byte Folded Spill
+; CHECK-NEXT:    stp d9, d8, [sp, #48] // 16-byte Folded Spill
+; CHECK-NEXT:    str x30, [sp, #64] // 8-byte Folded Spill
+; CHECK-NEXT:    smstop sm
+; CHECK-NEXT:    adrp x0, :tlsdesc:x
+; CHECK-NEXT:    ldr x1, [x0, :tlsdesc_lo12:x]
+; CHECK-NEXT:    add x0, x0, :tlsdesc_lo12:x
+; CHECK-NEXT:    .tlsdesccall x
+; CHECK-NEXT:    blr x1
+; CHECK-NEXT:    smstart sm
+; CHECK-NEXT:    mrs x8, TPIDR_EL0
+; CHECK-NEXT:    ldp d9, d8, [sp, #48] // 16-byte Folded Reload
+; CHECK-NEXT:    ldp d11, d10, [sp, #32] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr w0, [x8, x0]
+; CHECK-NEXT:    ldp d13, d12, [sp, #16] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr x30, [sp, #64] // 8-byte Folded Reload
+; CHECK-NEXT:    ldp d15, d14, [sp], #80 // 16-byte Folded Reload
+; CHECK-NEXT:    ret
+entry:
+  %0 = tail call align 4 ptr @llvm.threadlocal.address.p0(ptr align 4 @x)
+  %1 = load i32, ptr %0, align 4
+  ret i32 %1
+}
+
+define i32 @load_tls_shared_za() nounwind "aarch64_inout_za" {
+; CHECK-LABEL: load_tls_shared_za:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    stp x29, x30, [sp, #-16]! // 16-byte Folded Spill
+; CHECK-NEXT:    mov x29, sp
+; CHECK-NEXT:    sub sp, sp, #16
+; CHECK-NEXT:    rdsvl x8, #1
+; CHECK-NEXT:    mov x9, sp
+; CHECK-NEXT:    msub x9, x8, x8, x9
+; CHECK-NEXT:    mov sp, x9
+; CHECK-NEXT:    sub x10, x29, #16
+; CHECK-NEXT:    stp x9, x8, [x29, #-16]
+; CHECK-NEXT:    msr TPIDR2_EL0, x10
+; CHECK-NEXT:    adrp x0, :tlsdesc:x
+; CHECK-NEXT:    ldr x1, [x0, :tlsdesc_lo12:x]
+; CHECK-NEXT:    add x0, x0, :tlsdesc_lo12:x
+; CHECK-NEXT:    .tlsdesccall x
+; CHECK-NEXT:    blr x1
+; CHECK-NEXT:    mrs x8, TPIDR_EL0
+; CHECK-NEXT:    ldr w0, [x8, x0]
+; CHECK-NEXT:    mov w8, w0
+; CHECK-NEXT:    smstart za
+; CHECK-NEXT:    mrs x9, TPIDR2_EL0
+; CHECK-NEXT:    sub x0, x29, #16
+; CHECK-NEXT:    cbnz x9, .LBB2_2
+; CHECK-NEXT:  // %bb.1: // %entry
+; CHECK-NEXT:    bl __arm_tpidr2_restore
+; CHECK-NEXT:  .LBB2_2: // %entry
+; CHECK-NEXT:    mov w0, w8
+; CHECK-NEXT:    msr TPIDR2_EL0, xzr
+; CHECK-NEXT:    mov sp, x29
+; CHECK-NEXT:    ldp x29, x30, [sp], #16 // 16-byte Folded Reload
+; CHECK-NEXT:    ret
+entry:
+  %0 = tail call align 4 ptr @llvm.threadlocal.address.p0(ptr align 4 @x)
+  %1 = load i32, ptr %0, align 4
+  ret i32 %1
+}
+
+define i32 @load_tls_streaming_shared_za() nounwind "aarch64_inout_za" "aarch64_pstate_sm_enabled" {
+; CHECK-LABEL: load_tls_streaming_shared_za:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    stp d15, d14, [sp, #-96]! // 16-byte Folded Spill
+; CHECK-NEXT:    stp d13, d12, [sp, #16] // 16-byte Folded Spill
+; CHECK-NEXT:    stp d11, d10, [sp, #32] // 16-byte Folded Spill
+; CHECK-NEXT:    stp d9, d8, [sp, #48] // 16-byte Folded Spill
+; CHECK-NEXT:    stp x29, x30, [sp, #64] // 16-byte Folded Spill
+; CHECK-NEXT:    add x29, sp, #64
+; CHECK-NEXT:    str x19, [sp, #80] // 8-byte Folded Spill
+; CHECK-NEXT:    sub sp, sp, #16
+; CHECK-NEXT:    rdsvl x8, #1
+; CHECK-NEXT:    mov x9, sp
+; CHECK-NEXT:    msub x9, x8, x8, x9
+; CHECK-NEXT:    mov sp, x9
+; CHECK-NEXT:    stp x9, x8, [x29, #-80]
+; CHECK-NEXT:    smstop sm
+; CHECK-NEXT:    sub x8, x29, #80
+; CHECK-NEXT:    msr TPIDR2_EL0, x8
+; CHECK-NEXT:    adrp x0, :tlsdesc:x
+; CHECK-NEXT:    ldr x1, [x0, :tlsdesc_lo12:x]
+; CHECK-NEXT:    add x0, x0, :tlsdesc_lo12:x
+; CHECK-NEXT:    .tlsdesccall x
+; CHECK-NEXT:    blr x1
+; CHECK-NEXT:    smstart sm
+; CHECK-NEXT:    mrs x8, TPIDR_EL0
+; CHECK-NEXT:    ldr w0, [x8, x0]
+; CHECK-NEXT:    mov w8, w0
+; CHECK-NEXT:    smstart za
+; CHECK-NEXT:    mrs x9, TPIDR2_EL0
+; CHECK-NEXT:    sub x0, x29, #80
+; CHECK-NEXT:    cbnz x9, .LBB3_2
+; CHECK-NEXT:  // %bb.1: // %entry
+; CHECK-NEXT:    bl __arm_tpidr2_restore
+; CHECK-NEXT:  .LBB3_2: // %entry
+; CHECK-NEXT:    mov w0, w8
+; CHECK-NEXT:    msr TPIDR2_EL0, xzr
+; CHECK-NEXT:    sub sp, x29, #64
+; CHECK-NEXT:    ldp x29, x30, [sp, #64] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr x19, [sp, #80] // 8-byte Folded Reload
+; CHECK-NEXT:    ldp d9, d8, [sp, #48] // 16-byte Folded Reload
+; CHECK-NEXT:    ldp d11, d10, [sp, #32] // 16-byte Folded Reload
+; CHECK-NEXT:    ldp d13, d12, [sp, #16] // 16-byte Folded Reload
+; CHECK-NEXT:    ldp d15, d14, [sp], #96 // 16-byte Folded Reload
+; CHECK-NEXT:    ret
+entry:
+  %0 = tail call align 4 ptr @llvm.threadlocal.address.p0(ptr align 4 @x)
+  %1 = load i32, ptr %0, align 4
+  ret i32 %1
+}

>From 3e9e66cba1016d15fe1ac99823348dc9cb97564a Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Tue, 4 Nov 2025 16:54:29 +0000
Subject: [PATCH 2/2] Fixup ZA saves

---
 .../Target/AArch64/AArch64ISelLowering.cpp    | 36 +++++++++++--------
 .../lib/Target/AArch64/AArch64SMEInstrInfo.td |  4 +--
 llvm/lib/Target/AArch64/MachineSMEABIPass.cpp | 20 ++++++-----
 llvm/test/CodeGen/AArch64/sme-dynamic-tls.ll  |  3 +-
 4 files changed, 36 insertions(+), 27 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index b85639cfb4d8f..0b09bd554bd38 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -9617,8 +9617,9 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
       // using a chain can result in incorrect scheduling. The markers refer to
       // the position just before the CALLSEQ_START (though occur after as
       // CALLSEQ_START lacks in-glue).
-      Chain = DAG.getNode(*ZAMarkerNode, DL, DAG.getVTList(MVT::Other),
-                          {Chain, Chain.getValue(1)});
+      Chain =
+          DAG.getNode(*ZAMarkerNode, DL, DAG.getVTList(MVT::Other, MVT::Glue),
+                      {Chain, Chain.getValue(1)});
     }
   }
 
@@ -10633,26 +10634,31 @@ SDValue AArch64TargetLowering::LowerELFTLSDescCallSeq(SDValue SymAddr,
   SMECallAttrs TLSCallAttrs(FuncInfo->getSMEFnAttrs(), {}, SMEAttrs::Normal);
   bool RequiresSMChange = TLSCallAttrs.requiresSMChange();
 
-  if (RequiresSMChange) {
-    Chain = changeStreamingMode(DAG, DL, /*Enable=*/false, Chain, Glue,
-                                getSMToggleCondition(TLSCallAttrs));
-    Glue = Chain.getValue(1);
-  }
+  auto ChainAndGlue = [](SDValue Chain) -> std::pair<SDValue, SDValue> {
+    return {Chain, Chain.getValue(1)};
+  };
+
+  if (RequiresSMChange)
+    std::tie(Chain, Glue) =
+        ChainAndGlue(changeStreamingMode(DAG, DL, /*Enable=*/false, Chain, Glue,
+                                         getSMToggleCondition(TLSCallAttrs)));
 
   unsigned Opcode =
       DAG.getMachineFunction().getInfo<AArch64FunctionInfo>()->hasELFSignedGOT()
           ? AArch64ISD::TLSDESC_AUTH_CALLSEQ
           : AArch64ISD::TLSDESC_CALLSEQ;
   SDValue Ops[] = {Chain, SymAddr, Glue};
-  Chain = DAG.getNode(Opcode, DL, NodeTys,
-                      Glue ? ArrayRef(Ops) : ArrayRef(Ops).drop_back());
-  Glue = Chain.getValue(1);
+  std::tie(Chain, Glue) = ChainAndGlue(DAG.getNode(
+      Opcode, DL, NodeTys, Glue ? ArrayRef(Ops) : ArrayRef(Ops).drop_back()));
 
-  if (RequiresSMChange) {
-    Chain = changeStreamingMode(DAG, DL, /*Enable=*/true, Chain, Glue,
-                                getSMToggleCondition(TLSCallAttrs));
-    Glue = Chain.getValue(1);
-  }
+  if (TLSCallAttrs.requiresLazySave())
+    std::tie(Chain, Glue) = ChainAndGlue(DAG.getNode(
+        AArch64ISD::REQUIRES_ZA_SAVE, DL, NodeTys, {Chain, Chain.getValue(1)}));
+
+  if (RequiresSMChange)
+    std::tie(Chain, Glue) =
+        ChainAndGlue(changeStreamingMode(DAG, DL, /*Enable=*/true, Chain, Glue,
+                                         getSMToggleCondition(TLSCallAttrs)));
 
   return DAG.getCopyFromReg(Chain, DL, AArch64::X0, PtrVT, Glue);
 }
diff --git a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
index 752b185832c30..5bb70ee11b06d 100644
--- a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
@@ -113,12 +113,12 @@ def CommitZASavePseudo
 
 def AArch64_inout_za_use
   : SDNode<"AArch64ISD::INOUT_ZA_USE", SDTypeProfile<0, 0,[]>,
-           [SDNPHasChain, SDNPInGlue]>;
+           [SDNPHasChain, SDNPInGlue, SDNPOutGlue]>;
 def : Pat<(AArch64_inout_za_use), (InOutZAUsePseudo)>;
 
 def AArch64_requires_za_save
   : SDNode<"AArch64ISD::REQUIRES_ZA_SAVE", SDTypeProfile<0, 0,[]>,
-           [SDNPHasChain, SDNPInGlue]>;
+           [SDNPHasChain, SDNPInGlue, SDNPOutGlue]>;
 def : Pat<(AArch64_requires_za_save), (RequiresZASavePseudo)>;
 
 def AArch64_sme_state_alloc
diff --git a/llvm/lib/Target/AArch64/MachineSMEABIPass.cpp b/llvm/lib/Target/AArch64/MachineSMEABIPass.cpp
index 0f6d930cc1fba..24d30c731b945 100644
--- a/llvm/lib/Target/AArch64/MachineSMEABIPass.cpp
+++ b/llvm/lib/Target/AArch64/MachineSMEABIPass.cpp
@@ -231,12 +231,6 @@ getZAStateBeforeInst(const TargetRegisterInfo &TRI, MachineInstr &MI,
   if (MI.getOpcode() == AArch64::RequiresZASavePseudo)
     return {ZAState::LOCAL_SAVED, std::prev(InsertPt)};
 
-  // TLS-descriptor calls don't use the standard call lowering, so handle them
-  // as a special case here. Assume a private ZA interface.
-  if (MI.getOpcode() == AArch64::TLSDESC_CALLSEQ ||
-      MI.getOpcode() == AArch64::TLSDESC_AUTH_CALLSEQ)
-    return {ZAState::LOCAL_SAVED, InsertPt};
-
   if (MI.isReturn())
     return {ZAOffAtReturn ? ZAState::OFF : ZAState::ACTIVE, InsertPt};
 
@@ -387,6 +381,17 @@ static void setPhysLiveRegs(LiveRegUnits &LiveUnits, LiveRegs PhysLiveRegs) {
     LiveUnits.addReg(AArch64::W0_HI);
 }
 
+[[maybe_unused]] bool isCallStartOpcode(unsigned Opc) {
+  switch (Opc) {
+  case AArch64::TLSDESC_CALLSEQ:
+  case AArch64::TLSDESC_AUTH_CALLSEQ:
+  case AArch64::ADJCALLSTACKDOWN:
+    return true;
+  default:
+    return false;
+  }
+}
+
 FunctionInfo MachineSMEABI::collectNeededZAStates(SMEAttrs SMEFnAttrs) {
   assert((SMEFnAttrs.hasAgnosticZAInterface() || SMEFnAttrs.hasZT0State() ||
           SMEFnAttrs.hasZAState()) &&
@@ -430,8 +435,7 @@ FunctionInfo MachineSMEABI::collectNeededZAStates(SMEAttrs SMEFnAttrs) {
       // Note: We treat Agnostic ZA as inout_za with an alternate save/restore.
       auto [NeededState, InsertPt] = getZAStateBeforeInst(
           *TRI, MI, /*ZAOffAtReturn=*/SMEFnAttrs.hasPrivateZAInterface());
-      assert((InsertPt == MBBI ||
-              InsertPt->getOpcode() == AArch64::ADJCALLSTACKDOWN) &&
+      assert((InsertPt == MBBI || isCallStartOpcode(InsertPt->getOpcode())) &&
              "Unexpected state change insertion point!");
       // TODO: Do something to avoid state changes where NZCV is live.
       if (MBBI == FirstTerminatorInsertPt)
diff --git a/llvm/test/CodeGen/AArch64/sme-dynamic-tls.ll b/llvm/test/CodeGen/AArch64/sme-dynamic-tls.ll
index 0f6c0c9705cc5..f72ccadea5dba 100644
--- a/llvm/test/CodeGen/AArch64/sme-dynamic-tls.ll
+++ b/llvm/test/CodeGen/AArch64/sme-dynamic-tls.ll
@@ -11,8 +11,7 @@ define i32 @load_tls_streaming_compat() nounwind "aarch64_pstate_sm_compatible"
 ; CHECK-NEXT:    stp d11, d10, [sp, #32] // 16-byte Folded Spill
 ; CHECK-NEXT:    stp d9, d8, [sp, #48] // 16-byte Folded Spill
 ; CHECK-NEXT:    str x30, [sp, #64] // 8-byte Folded Spill
-; CHECK-NEXT:    bl __arm_sme_state
-; CHECK-NEXT:    mov x8, x0
+; CHECK-NEXT:    mrs x8, SVCR
 ; CHECK-NEXT:    tbz w8, #0, .LBB0_2
 ; CHECK-NEXT:  // %bb.1: // %entry
 ; CHECK-NEXT:    smstop sm


