[llvm] [AArch64][SME] Merge back-to-back SME call regions (PR #142111)
Benjamin Maxwell via llvm-commits
llvm-commits at lists.llvm.org
Fri May 30 02:01:10 PDT 2025
https://github.com/MacDue created https://github.com/llvm/llvm-project/pull/142111
Proof-of-concept patch for merging back-to-back SME call regions (e.g. lazy saves, ZT0 spills, streaming-mode switches), so that consecutive calls requiring the same setup share a single save/restore sequence.
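For reference, a minimal IR reproducer for the pattern being merged
(taken from test_lazy_save_2_callees in
llvm/test/CodeGen/AArch64/sme-lazy-save-call.ll):

  ; Two consecutive calls from a ZA-sharing function to a private-ZA
  ; callee. Before this series each call emitted its own TPIDR2 setup
  ; and conditional __arm_tpidr2_restore; afterwards they share one
  ; region (see the updated CHECK lines below).
  declare void @private_za_callee()

  define void @test_lazy_save_2_callees() nounwind "aarch64_inout_za" {
    call void @private_za_callee()
    call void @private_za_callee()
    ret void
  }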
From 9ddd8629f073222f598922dd27b285cf9dcf264b Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 22 May 2025 11:23:25 +0000
Subject: [PATCH 1/3] [AArch64][SME] Simplify initialization of TPIDR2 block
This patch updates the definition of `AArch64ISD::INIT_TPIDR2OBJ` to
take the number of save slices (currently always all ZA slices). Using
this, we can initialize the TPIDR2 block with a single STP that stores
both the save buffer pointer and the number of save slices. The
reserved bytes (10-15) are implicitly zeroed, since the result of RDSVL
always fits in 16 bits. A single STP also works for big-endian targets,
at the cost of an additional left shift to move the slice count into
the top two bytes.
Note: We used to write the number of save slices to the TPIDR2 block
before every call with a lazy save; however, based on 6.6.9 "Changes to
the TPIDR2 block" in the AAPCS64 (https://github.com/ARM-software/abi-aa/blob/main/aapcs64/aapcs64.rst#changes-to-the-tpidr2-block),
it seems we can rely on the contents of the TPIDR2 block being
preserved across calls.
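For context, a sketch of the 16-byte TPIDR2 block this initializes,
written as an LLVM struct type (field names are illustrative, based on
the AAPCS64 description of the TPIDR2 block):

  ; A single 16-byte STP of {za_save_buffer, num_za_save_slices} covers
  ; the whole block: RDSVL #1 fits in 16 bits, so bytes 10-15 of the
  ; second register are already zero.
  %tpidr2_block = type {
    ptr,     ; bytes 0-7:   za_save_buffer
    i16,     ; bytes 8-9:   num_za_save_slices
    [6 x i8] ; bytes 10-15: reserved (must be zero)
  }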
---
.../Target/AArch64/AArch64ISelLowering.cpp | 42 ++++++------
.../lib/Target/AArch64/AArch64SMEInstrInfo.td | 6 +-
.../AArch64/sme-disable-gisel-fisel.ll | 23 ++-----
.../CodeGen/AArch64/sme-lazy-save-call.ll | 57 ++++++----------
.../AArch64/sme-shared-za-interface.ll | 18 ++---
.../AArch64/sme-tpidr2-init-aarch64be.ll | 66 +++++++++++++++++++
.../AArch64/sme-za-lazy-save-buffer.ll | 30 ++++-----
llvm/test/CodeGen/AArch64/sme-zt0-state.ll | 9 +--
llvm/test/CodeGen/AArch64/stack-hazard.ll | 32 +++------
.../CodeGen/AArch64/sve-stack-frame-layout.ll | 9 +--
10 files changed, 150 insertions(+), 142 deletions(-)
create mode 100644 llvm/test/CodeGen/AArch64/sme-tpidr2-init-aarch64be.ll
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 4dacd2273306e..3c1ee0560aef9 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -3227,20 +3227,24 @@ AArch64TargetLowering::EmitInitTPIDR2Object(MachineInstr &MI,
TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
if (TPIDR2.Uses > 0) {
const TargetInstrInfo *TII = Subtarget->getInstrInfo();
- // Store the buffer pointer to the TPIDR2 stack object.
- BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(AArch64::STRXui))
+ unsigned TPIDInitSaveSlicesReg = MI.getOperand(1).getReg();
+ if (!Subtarget->isLittleEndian()) {
+ unsigned TmpReg =
+ MF->getRegInfo().createVirtualRegister(&AArch64::GPR64RegClass);
+ // For big-endian targets move "num_za_save_slices" to the top two bytes.
+ BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(AArch64::UBFMXri), TmpReg)
+ .addReg(TPIDInitSaveSlicesReg)
+ .addImm(16)
+ .addImm(15);
+ TPIDInitSaveSlicesReg = TmpReg;
+ }
+ // Store buffer pointer and num_za_save_slices.
+ // Bytes 10-15 are implicitly zeroed.
+ BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(AArch64::STPXi))
.addReg(MI.getOperand(0).getReg())
+ .addReg(TPIDInitSaveSlicesReg)
.addFrameIndex(TPIDR2.FrameIndex)
.addImm(0);
- // Set the reserved bytes (10-15) to zero
- BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(AArch64::STRHHui))
- .addReg(AArch64::WZR)
- .addFrameIndex(TPIDR2.FrameIndex)
- .addImm(5);
- BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(AArch64::STRWui))
- .addReg(AArch64::WZR)
- .addFrameIndex(TPIDR2.FrameIndex)
- .addImm(3);
} else
MFI.RemoveStackObject(TPIDR2.FrameIndex);
@@ -8344,9 +8348,12 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
{Chain, Size, DAG.getConstant(1, DL, MVT::i64)});
MFI.CreateVariableSizedObject(Align(16), nullptr);
}
+ SDValue NumZaSaveSlices = DAG.getNode(AArch64ISD::RDSVL, DL, MVT::i64,
+ DAG.getConstant(1, DL, MVT::i32));
Chain = DAG.getNode(
AArch64ISD::INIT_TPIDR2OBJ, DL, DAG.getVTList(MVT::Other),
- {/*Chain*/ Buffer.getValue(1), /*Buffer ptr*/ Buffer.getValue(0)});
+ {/*Chain*/ Buffer.getValue(1), /*Buffer ptr*/ Buffer.getValue(0),
+ /*Num save slices*/ NumZaSaveSlices});
} else if (SMEAttrs(MF.getFunction()).hasAgnosticZAInterface()) {
// Call __arm_sme_state_size().
SDValue BufferSize =
@@ -9127,19 +9134,10 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
bool RequiresLazySave = CallAttrs.requiresLazySave();
bool RequiresSaveAllZA = CallAttrs.requiresPreservingAllZAState();
if (RequiresLazySave) {
- const TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
- MachinePointerInfo MPI =
- MachinePointerInfo::getStack(MF, TPIDR2.FrameIndex);
+ TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
SDValue TPIDR2ObjAddr = DAG.getFrameIndex(
TPIDR2.FrameIndex,
DAG.getTargetLoweringInfo().getFrameIndexTy(DAG.getDataLayout()));
- SDValue NumZaSaveSlicesAddr =
- DAG.getNode(ISD::ADD, DL, TPIDR2ObjAddr.getValueType(), TPIDR2ObjAddr,
- DAG.getConstant(8, DL, TPIDR2ObjAddr.getValueType()));
- SDValue NumZaSaveSlices = DAG.getNode(AArch64ISD::RDSVL, DL, MVT::i64,
- DAG.getConstant(1, DL, MVT::i32));
- Chain = DAG.getTruncStore(Chain, DL, NumZaSaveSlices, NumZaSaveSlicesAddr,
- MPI, MVT::i16);
Chain = DAG.getNode(
ISD::INTRINSIC_VOID, DL, MVT::Other, Chain,
DAG.getConstant(Intrinsic::aarch64_sme_set_tpidr2, DL, MVT::i32),
diff --git a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
index e7482da001074..8667c778782a1 100644
--- a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
@@ -54,10 +54,10 @@ let usesCustomInserter = 1, Defs = [SP], Uses = [SP] in {
def : Pat<(i64 (AArch64AllocateZABuffer GPR64:$size)),
(AllocateZABuffer $size)>;
-def AArch64InitTPIDR2Obj : SDNode<"AArch64ISD::INIT_TPIDR2OBJ", SDTypeProfile<0, 1,
- [SDTCisInt<0>]>, [SDNPHasChain, SDNPMayStore]>;
+def AArch64InitTPIDR2Obj : SDNode<"AArch64ISD::INIT_TPIDR2OBJ", SDTypeProfile<0, 2,
+ [SDTCisInt<0>, SDTCisInt<1>]>, [SDNPHasChain, SDNPMayStore]>;
let usesCustomInserter = 1 in {
- def InitTPIDR2Obj : Pseudo<(outs), (ins GPR64:$buffer), [(AArch64InitTPIDR2Obj GPR64:$buffer)]>, Sched<[WriteI]> {}
+ def InitTPIDR2Obj : Pseudo<(outs), (ins GPR64:$buffer, GPR64:$save_slices), [(AArch64InitTPIDR2Obj GPR64:$buffer, GPR64:$save_slices)]>, Sched<[WriteI]> {}
}
// Nodes to allocate a save buffer for SME.
diff --git a/llvm/test/CodeGen/AArch64/sme-disable-gisel-fisel.ll b/llvm/test/CodeGen/AArch64/sme-disable-gisel-fisel.ll
index 4a52bf27a7591..13ffbc1296217 100644
--- a/llvm/test/CodeGen/AArch64/sme-disable-gisel-fisel.ll
+++ b/llvm/test/CodeGen/AArch64/sme-disable-gisel-fisel.ll
@@ -268,10 +268,7 @@ define double @za_shared_caller_to_za_none_callee(double %x) nounwind noinline
; CHECK-COMMON-NEXT: mov x9, sp
; CHECK-COMMON-NEXT: msub x9, x8, x8, x9
; CHECK-COMMON-NEXT: mov sp, x9
-; CHECK-COMMON-NEXT: stur x9, [x29, #-16]
-; CHECK-COMMON-NEXT: sturh wzr, [x29, #-6]
-; CHECK-COMMON-NEXT: stur wzr, [x29, #-4]
-; CHECK-COMMON-NEXT: sturh w8, [x29, #-8]
+; CHECK-COMMON-NEXT: stp x9, x8, [x29, #-16]
; CHECK-COMMON-NEXT: sub x8, x29, #16
; CHECK-COMMON-NEXT: msr TPIDR2_EL0, x8
; CHECK-COMMON-NEXT: bl normal_callee
@@ -310,12 +307,9 @@ define fp128 @f128_call_za(fp128 %a, fp128 %b) "aarch64_inout_za" nounwind {
; CHECK-COMMON-NEXT: mov x9, sp
; CHECK-COMMON-NEXT: msub x9, x8, x8, x9
; CHECK-COMMON-NEXT: mov sp, x9
-; CHECK-COMMON-NEXT: stur x9, [x29, #-16]
-; CHECK-COMMON-NEXT: sub x9, x29, #16
-; CHECK-COMMON-NEXT: sturh wzr, [x29, #-6]
-; CHECK-COMMON-NEXT: stur wzr, [x29, #-4]
-; CHECK-COMMON-NEXT: sturh w8, [x29, #-8]
-; CHECK-COMMON-NEXT: msr TPIDR2_EL0, x9
+; CHECK-COMMON-NEXT: sub x10, x29, #16
+; CHECK-COMMON-NEXT: stp x9, x8, [x29, #-16]
+; CHECK-COMMON-NEXT: msr TPIDR2_EL0, x10
; CHECK-COMMON-NEXT: bl __addtf3
; CHECK-COMMON-NEXT: smstart za
; CHECK-COMMON-NEXT: mrs x8, TPIDR2_EL0
@@ -375,12 +369,9 @@ define double @frem_call_za(double %a, double %b) "aarch64_inout_za" nounwind {
; CHECK-COMMON-NEXT: mov x9, sp
; CHECK-COMMON-NEXT: msub x9, x8, x8, x9
; CHECK-COMMON-NEXT: mov sp, x9
-; CHECK-COMMON-NEXT: stur x9, [x29, #-16]
-; CHECK-COMMON-NEXT: sub x9, x29, #16
-; CHECK-COMMON-NEXT: sturh wzr, [x29, #-6]
-; CHECK-COMMON-NEXT: stur wzr, [x29, #-4]
-; CHECK-COMMON-NEXT: sturh w8, [x29, #-8]
-; CHECK-COMMON-NEXT: msr TPIDR2_EL0, x9
+; CHECK-COMMON-NEXT: sub x10, x29, #16
+; CHECK-COMMON-NEXT: stp x9, x8, [x29, #-16]
+; CHECK-COMMON-NEXT: msr TPIDR2_EL0, x10
; CHECK-COMMON-NEXT: bl fmod
; CHECK-COMMON-NEXT: smstart za
; CHECK-COMMON-NEXT: mrs x8, TPIDR2_EL0
diff --git a/llvm/test/CodeGen/AArch64/sme-lazy-save-call.ll b/llvm/test/CodeGen/AArch64/sme-lazy-save-call.ll
index e463e833bdbde..26c6dc4eb978b 100644
--- a/llvm/test/CodeGen/AArch64/sme-lazy-save-call.ll
+++ b/llvm/test/CodeGen/AArch64/sme-lazy-save-call.ll
@@ -16,12 +16,9 @@ define void @test_lazy_save_1_callee() nounwind "aarch64_inout_za" {
; CHECK-NEXT: mov x9, sp
; CHECK-NEXT: msub x9, x8, x8, x9
; CHECK-NEXT: mov sp, x9
-; CHECK-NEXT: stur x9, [x29, #-16]
-; CHECK-NEXT: sub x9, x29, #16
-; CHECK-NEXT: sturh wzr, [x29, #-6]
-; CHECK-NEXT: stur wzr, [x29, #-4]
-; CHECK-NEXT: sturh w8, [x29, #-8]
-; CHECK-NEXT: msr TPIDR2_EL0, x9
+; CHECK-NEXT: sub x10, x29, #16
+; CHECK-NEXT: stp x9, x8, [x29, #-16]
+; CHECK-NEXT: msr TPIDR2_EL0, x10
; CHECK-NEXT: bl private_za_callee
; CHECK-NEXT: smstart za
; CHECK-NEXT: mrs x8, TPIDR2_EL0
@@ -43,21 +40,17 @@ define void @test_lazy_save_1_callee() nounwind "aarch64_inout_za" {
define void @test_lazy_save_2_callees() nounwind "aarch64_inout_za" {
; CHECK-LABEL: test_lazy_save_2_callees:
; CHECK: // %bb.0:
-; CHECK-NEXT: stp x29, x30, [sp, #-48]! // 16-byte Folded Spill
-; CHECK-NEXT: str x21, [sp, #16] // 8-byte Folded Spill
+; CHECK-NEXT: stp x29, x30, [sp, #-32]! // 16-byte Folded Spill
+; CHECK-NEXT: stp x20, x19, [sp, #16] // 16-byte Folded Spill
; CHECK-NEXT: mov x29, sp
-; CHECK-NEXT: stp x20, x19, [sp, #32] // 16-byte Folded Spill
; CHECK-NEXT: sub sp, sp, #16
-; CHECK-NEXT: rdsvl x20, #1
-; CHECK-NEXT: mov x8, sp
-; CHECK-NEXT: msub x8, x20, x20, x8
-; CHECK-NEXT: mov sp, x8
-; CHECK-NEXT: sub x21, x29, #16
-; CHECK-NEXT: stur x8, [x29, #-16]
-; CHECK-NEXT: sturh wzr, [x29, #-6]
-; CHECK-NEXT: stur wzr, [x29, #-4]
-; CHECK-NEXT: sturh w20, [x29, #-8]
-; CHECK-NEXT: msr TPIDR2_EL0, x21
+; CHECK-NEXT: rdsvl x8, #1
+; CHECK-NEXT: mov x9, sp
+; CHECK-NEXT: msub x9, x8, x8, x9
+; CHECK-NEXT: mov sp, x9
+; CHECK-NEXT: sub x20, x29, #16
+; CHECK-NEXT: stp x9, x8, [x29, #-16]
+; CHECK-NEXT: msr TPIDR2_EL0, x20
; CHECK-NEXT: bl private_za_callee
; CHECK-NEXT: smstart za
; CHECK-NEXT: mrs x8, TPIDR2_EL0
@@ -67,8 +60,7 @@ define void @test_lazy_save_2_callees() nounwind "aarch64_inout_za" {
; CHECK-NEXT: bl __arm_tpidr2_restore
; CHECK-NEXT: .LBB1_2:
; CHECK-NEXT: msr TPIDR2_EL0, xzr
-; CHECK-NEXT: sturh w20, [x29, #-8]
-; CHECK-NEXT: msr TPIDR2_EL0, x21
+; CHECK-NEXT: msr TPIDR2_EL0, x20
; CHECK-NEXT: bl private_za_callee
; CHECK-NEXT: smstart za
; CHECK-NEXT: mrs x8, TPIDR2_EL0
@@ -79,9 +71,8 @@ define void @test_lazy_save_2_callees() nounwind "aarch64_inout_za" {
; CHECK-NEXT: .LBB1_4:
; CHECK-NEXT: msr TPIDR2_EL0, xzr
; CHECK-NEXT: mov sp, x29
-; CHECK-NEXT: ldp x20, x19, [sp, #32] // 16-byte Folded Reload
-; CHECK-NEXT: ldr x21, [sp, #16] // 8-byte Folded Reload
-; CHECK-NEXT: ldp x29, x30, [sp], #48 // 16-byte Folded Reload
+; CHECK-NEXT: ldp x20, x19, [sp, #16] // 16-byte Folded Reload
+; CHECK-NEXT: ldp x29, x30, [sp], #32 // 16-byte Folded Reload
; CHECK-NEXT: ret
call void @private_za_callee()
call void @private_za_callee()
@@ -100,12 +91,9 @@ define float @test_lazy_save_expanded_intrinsic(float %a) nounwind "aarch64_inou
; CHECK-NEXT: mov x9, sp
; CHECK-NEXT: msub x9, x8, x8, x9
; CHECK-NEXT: mov sp, x9
-; CHECK-NEXT: stur x9, [x29, #-16]
-; CHECK-NEXT: sub x9, x29, #16
-; CHECK-NEXT: sturh wzr, [x29, #-6]
-; CHECK-NEXT: stur wzr, [x29, #-4]
-; CHECK-NEXT: sturh w8, [x29, #-8]
-; CHECK-NEXT: msr TPIDR2_EL0, x9
+; CHECK-NEXT: sub x10, x29, #16
+; CHECK-NEXT: stp x9, x8, [x29, #-16]
+; CHECK-NEXT: msr TPIDR2_EL0, x10
; CHECK-NEXT: bl cosf
; CHECK-NEXT: smstart za
; CHECK-NEXT: mrs x8, TPIDR2_EL0
@@ -141,12 +129,9 @@ define void @test_lazy_save_and_conditional_smstart() nounwind "aarch64_inout_za
; CHECK-NEXT: mov x9, sp
; CHECK-NEXT: msub x9, x8, x8, x9
; CHECK-NEXT: mov sp, x9
-; CHECK-NEXT: stur x9, [x29, #-80]
-; CHECK-NEXT: sub x9, x29, #80
-; CHECK-NEXT: sturh wzr, [x29, #-70]
-; CHECK-NEXT: stur wzr, [x29, #-68]
-; CHECK-NEXT: sturh w8, [x29, #-72]
-; CHECK-NEXT: msr TPIDR2_EL0, x9
+; CHECK-NEXT: sub x10, x29, #80
+; CHECK-NEXT: stp x9, x8, [x29, #-80]
+; CHECK-NEXT: msr TPIDR2_EL0, x10
; CHECK-NEXT: bl __arm_sme_state
; CHECK-NEXT: and x20, x0, #0x1
; CHECK-NEXT: tbz w20, #0, .LBB3_2
diff --git a/llvm/test/CodeGen/AArch64/sme-shared-za-interface.ll b/llvm/test/CodeGen/AArch64/sme-shared-za-interface.ll
index 393ff3b79aedf..d12c304905b4b 100644
--- a/llvm/test/CodeGen/AArch64/sme-shared-za-interface.ll
+++ b/llvm/test/CodeGen/AArch64/sme-shared-za-interface.ll
@@ -14,12 +14,9 @@ define void @disable_tailcallopt() "aarch64_inout_za" nounwind {
; CHECK-NEXT: mov x9, sp
; CHECK-NEXT: msub x9, x8, x8, x9
; CHECK-NEXT: mov sp, x9
-; CHECK-NEXT: stur x9, [x29, #-16]
-; CHECK-NEXT: sub x9, x29, #16
-; CHECK-NEXT: sturh wzr, [x29, #-6]
-; CHECK-NEXT: stur wzr, [x29, #-4]
-; CHECK-NEXT: sturh w8, [x29, #-8]
-; CHECK-NEXT: msr TPIDR2_EL0, x9
+; CHECK-NEXT: sub x10, x29, #16
+; CHECK-NEXT: stp x9, x8, [x29, #-16]
+; CHECK-NEXT: msr TPIDR2_EL0, x10
; CHECK-NEXT: bl private_za_callee
; CHECK-NEXT: smstart za
; CHECK-NEXT: mrs x8, TPIDR2_EL0
@@ -47,12 +44,9 @@ define fp128 @f128_call_za(fp128 %a, fp128 %b) "aarch64_inout_za" nounwind {
; CHECK-NEXT: mov x9, sp
; CHECK-NEXT: msub x9, x8, x8, x9
; CHECK-NEXT: mov sp, x9
-; CHECK-NEXT: stur x9, [x29, #-16]
-; CHECK-NEXT: sub x9, x29, #16
-; CHECK-NEXT: sturh wzr, [x29, #-6]
-; CHECK-NEXT: stur wzr, [x29, #-4]
-; CHECK-NEXT: sturh w8, [x29, #-8]
-; CHECK-NEXT: msr TPIDR2_EL0, x9
+; CHECK-NEXT: sub x10, x29, #16
+; CHECK-NEXT: stp x9, x8, [x29, #-16]
+; CHECK-NEXT: msr TPIDR2_EL0, x10
; CHECK-NEXT: bl __addtf3
; CHECK-NEXT: smstart za
; CHECK-NEXT: mrs x8, TPIDR2_EL0
diff --git a/llvm/test/CodeGen/AArch64/sme-tpidr2-init-aarch64be.ll b/llvm/test/CodeGen/AArch64/sme-tpidr2-init-aarch64be.ll
new file mode 100644
index 0000000000000..78823e8b4da60
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sme-tpidr2-init-aarch64be.ll
@@ -0,0 +1,66 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mtriple=aarch64 -aarch64-streaming-hazard-size=0 -mattr=+sve -mattr=+sme < %s | FileCheck %s
+; RUN: llc -mtriple=aarch64_be -aarch64-streaming-hazard-size=0 -mattr=+sve -mattr=+sme < %s | FileCheck %s --check-prefix=CHECK-BE
+
+declare void @private_za_callee()
+declare float @llvm.cos.f32(float)
+
+; Test TPIDR2_EL0 is initialized correctly for AArch64 big-endian.
+define void @test_tpidr2_init() nounwind "aarch64_inout_za" {
+; CHECK-LABEL: test_tpidr2_init:
+; CHECK: // %bb.0:
+; CHECK-NEXT: stp x29, x30, [sp, #-32]! // 16-byte Folded Spill
+; CHECK-NEXT: str x19, [sp, #16] // 8-byte Folded Spill
+; CHECK-NEXT: mov x29, sp
+; CHECK-NEXT: sub sp, sp, #16
+; CHECK-NEXT: rdsvl x8, #1
+; CHECK-NEXT: mov x9, sp
+; CHECK-NEXT: msub x9, x8, x8, x9
+; CHECK-NEXT: mov sp, x9
+; CHECK-NEXT: sub x10, x29, #16
+; CHECK-NEXT: stp x9, x8, [x29, #-16]
+; CHECK-NEXT: msr TPIDR2_EL0, x10
+; CHECK-NEXT: bl private_za_callee
+; CHECK-NEXT: smstart za
+; CHECK-NEXT: mrs x8, TPIDR2_EL0
+; CHECK-NEXT: sub x0, x29, #16
+; CHECK-NEXT: cbnz x8, .LBB0_2
+; CHECK-NEXT: // %bb.1:
+; CHECK-NEXT: bl __arm_tpidr2_restore
+; CHECK-NEXT: .LBB0_2:
+; CHECK-NEXT: msr TPIDR2_EL0, xzr
+; CHECK-NEXT: mov sp, x29
+; CHECK-NEXT: ldr x19, [sp, #16] // 8-byte Folded Reload
+; CHECK-NEXT: ldp x29, x30, [sp], #32 // 16-byte Folded Reload
+; CHECK-NEXT: ret
+;
+; CHECK-BE-LABEL: test_tpidr2_init:
+; CHECK-BE: // %bb.0:
+; CHECK-BE-NEXT: stp x29, x30, [sp, #-32]! // 16-byte Folded Spill
+; CHECK-BE-NEXT: str x19, [sp, #16] // 8-byte Folded Spill
+; CHECK-BE-NEXT: mov x29, sp
+; CHECK-BE-NEXT: sub sp, sp, #16
+; CHECK-BE-NEXT: rdsvl x8, #1
+; CHECK-BE-NEXT: mov x9, sp
+; CHECK-BE-NEXT: msub x9, x8, x8, x9
+; CHECK-BE-NEXT: mov sp, x9
+; CHECK-BE-NEXT: lsl x8, x8, #48
+; CHECK-BE-NEXT: sub x10, x29, #16
+; CHECK-BE-NEXT: stp x9, x8, [x29, #-16]
+; CHECK-BE-NEXT: msr TPIDR2_EL0, x10
+; CHECK-BE-NEXT: bl private_za_callee
+; CHECK-BE-NEXT: smstart za
+; CHECK-BE-NEXT: mrs x8, TPIDR2_EL0
+; CHECK-BE-NEXT: sub x0, x29, #16
+; CHECK-BE-NEXT: cbnz x8, .LBB0_2
+; CHECK-BE-NEXT: // %bb.1:
+; CHECK-BE-NEXT: bl __arm_tpidr2_restore
+; CHECK-BE-NEXT: .LBB0_2:
+; CHECK-BE-NEXT: msr TPIDR2_EL0, xzr
+; CHECK-BE-NEXT: mov sp, x29
+; CHECK-BE-NEXT: ldr x19, [sp, #16] // 8-byte Folded Reload
+; CHECK-BE-NEXT: ldp x29, x30, [sp], #32 // 16-byte Folded Reload
+; CHECK-BE-NEXT: ret
+ call void @private_za_callee()
+ ret void
+}
diff --git a/llvm/test/CodeGen/AArch64/sme-za-lazy-save-buffer.ll b/llvm/test/CodeGen/AArch64/sme-za-lazy-save-buffer.ll
index ad3f7f5514d0e..256045cbe44f8 100644
--- a/llvm/test/CodeGen/AArch64/sme-za-lazy-save-buffer.ll
+++ b/llvm/test/CodeGen/AArch64/sme-za-lazy-save-buffer.ll
@@ -21,11 +21,9 @@ define float @multi_bb_stpidr2_save_required(i32 %a, float %b, float %c) "aarch6
; CHECK-NEXT: .cfi_offset w29, -16
; CHECK-NEXT: rdsvl x8, #1
; CHECK-NEXT: mov x9, sp
-; CHECK-NEXT: msub x8, x8, x8, x9
-; CHECK-NEXT: mov sp, x8
-; CHECK-NEXT: stur x8, [x29, #-16]
-; CHECK-NEXT: sturh wzr, [x29, #-6]
-; CHECK-NEXT: stur wzr, [x29, #-4]
+; CHECK-NEXT: msub x9, x8, x8, x9
+; CHECK-NEXT: mov sp, x9
+; CHECK-NEXT: stp x9, x8, [x29, #-16]
; CHECK-NEXT: cbz w0, .LBB1_2
; CHECK-NEXT: // %bb.1: // %use_b
; CHECK-NEXT: fmov s1, #4.00000000
@@ -33,10 +31,8 @@ define float @multi_bb_stpidr2_save_required(i32 %a, float %b, float %c) "aarch6
; CHECK-NEXT: b .LBB1_5
; CHECK-NEXT: .LBB1_2: // %use_c
; CHECK-NEXT: fmov s0, s1
-; CHECK-NEXT: rdsvl x8, #1
-; CHECK-NEXT: sub x9, x29, #16
-; CHECK-NEXT: sturh w8, [x29, #-8]
-; CHECK-NEXT: msr TPIDR2_EL0, x9
+; CHECK-NEXT: sub x8, x29, #16
+; CHECK-NEXT: msr TPIDR2_EL0, x8
; CHECK-NEXT: bl cosf
; CHECK-NEXT: smstart za
; CHECK-NEXT: mrs x8, TPIDR2_EL0
@@ -77,20 +73,18 @@ define float @multi_bb_stpidr2_save_required_stackprobe(i32 %a, float %b, float
; CHECK-NEXT: .cfi_offset w29, -16
; CHECK-NEXT: rdsvl x8, #1
; CHECK-NEXT: mov x9, sp
-; CHECK-NEXT: msub x8, x8, x8, x9
+; CHECK-NEXT: msub x9, x8, x8, x9
; CHECK-NEXT: .LBB2_1: // =>This Inner Loop Header: Depth=1
; CHECK-NEXT: sub sp, sp, #16, lsl #12 // =65536
-; CHECK-NEXT: cmp sp, x8
+; CHECK-NEXT: cmp sp, x9
; CHECK-NEXT: b.le .LBB2_3
; CHECK-NEXT: // %bb.2: // in Loop: Header=BB2_1 Depth=1
; CHECK-NEXT: str xzr, [sp]
; CHECK-NEXT: b .LBB2_1
; CHECK-NEXT: .LBB2_3:
-; CHECK-NEXT: mov sp, x8
+; CHECK-NEXT: mov sp, x9
; CHECK-NEXT: ldr xzr, [sp]
-; CHECK-NEXT: stur x8, [x29, #-16]
-; CHECK-NEXT: sturh wzr, [x29, #-6]
-; CHECK-NEXT: stur wzr, [x29, #-4]
+; CHECK-NEXT: stp x9, x8, [x29, #-16]
; CHECK-NEXT: cbz w0, .LBB2_5
; CHECK-NEXT: // %bb.4: // %use_b
; CHECK-NEXT: fmov s1, #4.00000000
@@ -98,10 +92,8 @@ define float @multi_bb_stpidr2_save_required_stackprobe(i32 %a, float %b, float
; CHECK-NEXT: b .LBB2_8
; CHECK-NEXT: .LBB2_5: // %use_c
; CHECK-NEXT: fmov s0, s1
-; CHECK-NEXT: rdsvl x8, #1
-; CHECK-NEXT: sub x9, x29, #16
-; CHECK-NEXT: sturh w8, [x29, #-8]
-; CHECK-NEXT: msr TPIDR2_EL0, x9
+; CHECK-NEXT: sub x8, x29, #16
+; CHECK-NEXT: msr TPIDR2_EL0, x8
; CHECK-NEXT: bl cosf
; CHECK-NEXT: smstart za
; CHECK-NEXT: mrs x8, TPIDR2_EL0
diff --git a/llvm/test/CodeGen/AArch64/sme-zt0-state.ll b/llvm/test/CodeGen/AArch64/sme-zt0-state.ll
index 63577e4d217a8..7d1ca3946f2d7 100644
--- a/llvm/test/CodeGen/AArch64/sme-zt0-state.ll
+++ b/llvm/test/CodeGen/AArch64/sme-zt0-state.ll
@@ -39,13 +39,10 @@ define void @za_zt0_shared_caller_no_state_callee(ptr %callee) "aarch64_inout_za
; CHECK-NEXT: mov x9, sp
; CHECK-NEXT: msub x9, x8, x8, x9
; CHECK-NEXT: mov sp, x9
-; CHECK-NEXT: stur x9, [x29, #-16]
-; CHECK-NEXT: sub x9, x29, #16
+; CHECK-NEXT: sub x10, x29, #16
; CHECK-NEXT: sub x19, x29, #80
-; CHECK-NEXT: sturh wzr, [x29, #-6]
-; CHECK-NEXT: stur wzr, [x29, #-4]
-; CHECK-NEXT: sturh w8, [x29, #-8]
-; CHECK-NEXT: msr TPIDR2_EL0, x9
+; CHECK-NEXT: stp x9, x8, [x29, #-16]
+; CHECK-NEXT: msr TPIDR2_EL0, x10
; CHECK-NEXT: str zt0, [x19]
; CHECK-NEXT: blr x0
; CHECK-NEXT: smstart za
diff --git a/llvm/test/CodeGen/AArch64/stack-hazard.ll b/llvm/test/CodeGen/AArch64/stack-hazard.ll
index 791d7580c327d..f1db7188943db 100644
--- a/llvm/test/CodeGen/AArch64/stack-hazard.ll
+++ b/llvm/test/CodeGen/AArch64/stack-hazard.ll
@@ -2832,12 +2832,9 @@ define i32 @vastate(i32 %x) "aarch64_inout_za" "aarch64_pstate_sm_enabled" "targ
; CHECK0-NEXT: mov w20, w0
; CHECK0-NEXT: msub x9, x8, x8, x9
; CHECK0-NEXT: mov sp, x9
-; CHECK0-NEXT: stur x9, [x29, #-80]
-; CHECK0-NEXT: sub x9, x29, #80
-; CHECK0-NEXT: sturh wzr, [x29, #-70]
-; CHECK0-NEXT: stur wzr, [x29, #-68]
-; CHECK0-NEXT: sturh w8, [x29, #-72]
-; CHECK0-NEXT: msr TPIDR2_EL0, x9
+; CHECK0-NEXT: sub x10, x29, #80
+; CHECK0-NEXT: stp x9, x8, [x29, #-80]
+; CHECK0-NEXT: msr TPIDR2_EL0, x10
; CHECK0-NEXT: .cfi_offset vg, -32
; CHECK0-NEXT: smstop sm
; CHECK0-NEXT: bl other
@@ -2906,12 +2903,9 @@ define i32 @vastate(i32 %x) "aarch64_inout_za" "aarch64_pstate_sm_enabled" "targ
; CHECK64-NEXT: mov w20, w0
; CHECK64-NEXT: msub x9, x8, x8, x9
; CHECK64-NEXT: mov sp, x9
-; CHECK64-NEXT: stur x9, [x29, #-208]
-; CHECK64-NEXT: sub x9, x29, #208
-; CHECK64-NEXT: sturh wzr, [x29, #-198]
-; CHECK64-NEXT: stur wzr, [x29, #-196]
-; CHECK64-NEXT: sturh w8, [x29, #-200]
-; CHECK64-NEXT: msr TPIDR2_EL0, x9
+; CHECK64-NEXT: sub x10, x29, #208
+; CHECK64-NEXT: stp x9, x8, [x29, #-208]
+; CHECK64-NEXT: msr TPIDR2_EL0, x10
; CHECK64-NEXT: .cfi_offset vg, -32
; CHECK64-NEXT: smstop sm
; CHECK64-NEXT: bl other
@@ -2986,16 +2980,10 @@ define i32 @vastate(i32 %x) "aarch64_inout_za" "aarch64_pstate_sm_enabled" "targ
; CHECK1024-NEXT: mov w20, w0
; CHECK1024-NEXT: msub x9, x8, x8, x9
; CHECK1024-NEXT: mov sp, x9
-; CHECK1024-NEXT: sub x10, x29, #1872
-; CHECK1024-NEXT: stur x9, [x10, #-256]
-; CHECK1024-NEXT: sub x9, x29, #1862
-; CHECK1024-NEXT: sub x10, x29, #1860
-; CHECK1024-NEXT: sturh wzr, [x9, #-256]
-; CHECK1024-NEXT: sub x9, x29, #2128
-; CHECK1024-NEXT: stur wzr, [x10, #-256]
-; CHECK1024-NEXT: sub x10, x29, #1864
-; CHECK1024-NEXT: sturh w8, [x10, #-256]
-; CHECK1024-NEXT: msr TPIDR2_EL0, x9
+; CHECK1024-NEXT: sub x10, x29, #2128
+; CHECK1024-NEXT: sub x11, x29, #1616
+; CHECK1024-NEXT: stp x9, x8, [x11, #-512]
+; CHECK1024-NEXT: msr TPIDR2_EL0, x10
; CHECK1024-NEXT: .cfi_offset vg, -32
; CHECK1024-NEXT: smstop sm
; CHECK1024-NEXT: bl other
diff --git a/llvm/test/CodeGen/AArch64/sve-stack-frame-layout.ll b/llvm/test/CodeGen/AArch64/sve-stack-frame-layout.ll
index c5cf4593cc86d..e0fe3049289ca 100644
--- a/llvm/test/CodeGen/AArch64/sve-stack-frame-layout.ll
+++ b/llvm/test/CodeGen/AArch64/sve-stack-frame-layout.ll
@@ -548,12 +548,9 @@ define i32 @vastate(i32 %x) "aarch64_inout_za" "aarch64_pstate_sm_enabled" "targ
; CHECK-NEXT: mov w20, w0
; CHECK-NEXT: msub x9, x8, x8, x9
; CHECK-NEXT: mov sp, x9
-; CHECK-NEXT: stur x9, [x29, #-80]
-; CHECK-NEXT: sub x9, x29, #80
-; CHECK-NEXT: sturh wzr, [x29, #-70]
-; CHECK-NEXT: stur wzr, [x29, #-68]
-; CHECK-NEXT: sturh w8, [x29, #-72]
-; CHECK-NEXT: msr TPIDR2_EL0, x9
+; CHECK-NEXT: sub x10, x29, #80
+; CHECK-NEXT: stp x9, x8, [x29, #-80]
+; CHECK-NEXT: msr TPIDR2_EL0, x10
; CHECK-NEXT: .cfi_offset vg, -32
; CHECK-NEXT: smstop sm
; CHECK-NEXT: bl other
From 3a45655ed1f08ae4653ecee5954fbceae4a7b85c Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Fri, 30 May 2025 08:43:30 +0000
Subject: [PATCH 2/3] [AArch64][SME] Abstract SME call lowering with high-level
ISD nodes (NFC)
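A rough sketch of how the three new nodes bracket a call (pseudocode,
not real IR; operands and expansions follow the lowering code in this
patch):

  ; SME_CALL_START(chain, caller-attrs, callee-attrs, callsite-attrs)
  ;   -> {pstate.sm, chain, glue}
  ;   expands to: lazy-save setup / __arm_sme_save / ZT0 spill /
  ;   smstop za, as required by the attributes
  ; SME_CALL_SM_CHANGE(chain, start, pstate.sm) -> {chain, glue}
  ;   expands to: VG save + conditional streaming-mode switch
  ;   ... the call itself ...
  ; SME_CALL_END(chain, start-or-sm-change, pstate.sm, glue)
  ;   -> {chain, glue}
  ;   expands to: mode switch back, VG restore, smstart za, ZT0 reload,
  ;   conditional __arm_tpidr2_restore / __arm_sme_restore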
---
.../Target/AArch64/AArch64ISelLowering.cpp | 475 +++++++++++-------
llvm/lib/Target/AArch64/AArch64ISelLowering.h | 20 +-
.../AArch64/AArch64MachineFunctionInfo.h | 6 +
.../AArch64/Utils/AArch64SMEAttributes.h | 2 +
4 files changed, 311 insertions(+), 192 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 3c1ee0560aef9..9866149868fd7 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -2730,6 +2730,9 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
MAKE_CASE(AArch64ISD::COND_SMSTOP)
MAKE_CASE(AArch64ISD::RESTORE_ZA)
MAKE_CASE(AArch64ISD::RESTORE_ZT)
+ MAKE_CASE(AArch64ISD::SME_CALL_START)
+ MAKE_CASE(AArch64ISD::SME_CALL_SM_CHANGE)
+ MAKE_CASE(AArch64ISD::SME_CALL_END)
MAKE_CASE(AArch64ISD::SAVE_ZT)
MAKE_CASE(AArch64ISD::CALL)
MAKE_CASE(AArch64ISD::ADRP)
@@ -7750,9 +7753,258 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
case ISD::PARTIAL_REDUCE_SMLA:
case ISD::PARTIAL_REDUCE_UMLA:
return LowerPARTIAL_REDUCE_MLA(Op, DAG);
+ case AArch64ISD::SME_CALL_START:
+ return LowerSME_CALL_START(Op, DAG);
+ case AArch64ISD::SME_CALL_SM_CHANGE:
+ return LowerSME_CALL_SM_CHANGE(Op, DAG);
+ case AArch64ISD::SME_CALL_END:
+ return LowerSME_CALL_END(Op, DAG);
}
}
+static SDNode *findSMECallStart(SDNode *N) {
+ unsigned Opcode = N->getOpcode();
+ if (Opcode == AArch64ISD::SME_CALL_START)
+ return N;
+ if (Opcode == AArch64ISD::SME_CALL_SM_CHANGE)
+ return findSMECallStart(N->getOperand(1).getNode());
+ if (Opcode == AArch64ISD::SME_CALL_END)
+ return findSMECallStart(N->getOperand(1).getNode());
+ llvm_unreachable("Unexpected opcode!");
+}
+
+static SMECallAttrs findSMECallAttrs(SDNode *N) {
+ SDNode *Start = findSMECallStart(N);
+ SMEAttrs CallerAttrs(Start->getConstantOperandVal(1));
+ SMEAttrs CalleeAttrs(Start->getConstantOperandVal(2));
+ SMEAttrs CallsiteAttrs(Start->getConstantOperandVal(3));
+ return SMECallAttrs(CallerAttrs, CalleeAttrs, CallsiteAttrs);
+}
+
+static unsigned getOrCreateZT0SpillSlot(AArch64FunctionInfo *FuncInfo,
+ MachineFrameInfo &MFI) {
+ unsigned ZTObj = FuncInfo->getZT0Idx();
+ if (ZTObj == std::numeric_limits<int>::max()) {
+ ZTObj = MFI.CreateSpillStackObject(64, Align(16));
+ FuncInfo->setZT0Idx(ZTObj);
+ }
+ return ZTObj;
+}
+
+// Emit a call to __arm_sme_save or __arm_sme_restore.
+static SDValue emitSMEStateSaveRestore(const AArch64TargetLowering &TLI,
+ SelectionDAG &DAG,
+ AArch64FunctionInfo *Info, SDLoc DL,
+ SDValue Chain, bool IsSave) {
+ MachineFunction &MF = DAG.getMachineFunction();
+ AArch64FunctionInfo *FuncInfo = MF.getInfo<AArch64FunctionInfo>();
+ FuncInfo->setSMESaveBufferUsed();
+
+ TargetLowering::ArgListTy Args;
+ TargetLowering::ArgListEntry Entry;
+ Entry.Ty = PointerType::getUnqual(*DAG.getContext());
+ Entry.Node =
+ DAG.getCopyFromReg(Chain, DL, Info->getSMESaveBufferAddr(), MVT::i64);
+ Args.push_back(Entry);
+
+ SDValue Callee =
+ DAG.getExternalSymbol(IsSave ? "__arm_sme_save" : "__arm_sme_restore",
+ TLI.getPointerTy(DAG.getDataLayout()));
+ auto *RetTy = Type::getVoidTy(*DAG.getContext());
+ TargetLowering::CallLoweringInfo CLI(DAG);
+ CLI.setDebugLoc(DL).setChain(Chain).setLibCallee(
+ CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1, RetTy,
+ Callee, std::move(Args));
+ return TLI.LowerCallTo(CLI).second;
+}
+
+static AArch64SME::ToggleCondition
+getSMToggleCondition(const SMECallAttrs &CallAttrs) {
+ if (!CallAttrs.caller().hasStreamingCompatibleInterface() ||
+ CallAttrs.caller().hasStreamingBody())
+ return AArch64SME::Always;
+ if (CallAttrs.callee().hasNonStreamingInterface())
+ return AArch64SME::IfCallerIsStreaming;
+ if (CallAttrs.callee().hasStreamingInterface())
+ return AArch64SME::IfCallerIsNonStreaming;
+
+ llvm_unreachable("Unsupported attributes");
+}
+
+SDValue AArch64TargetLowering::LowerSME_CALL_START(SDValue Op,
+ SelectionDAG &DAG) const {
+ SDLoc DL(Op);
+ SMECallAttrs CallAttrs = findSMECallAttrs(Op.getNode());
+ auto &MF = DAG.getMachineFunction();
+ auto *FuncInfo = MF.getInfo<AArch64FunctionInfo>();
+ bool RequiresLazySave = CallAttrs.requiresLazySave();
+ bool RequiresSaveAllZA = CallAttrs.requiresPreservingAllZAState();
+
+ SDValue Chain = Op->getOperand(0);
+ if (RequiresLazySave) {
+ TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
+ SDValue TPIDR2ObjAddr = DAG.getFrameIndex(
+ TPIDR2.FrameIndex,
+ DAG.getTargetLoweringInfo().getFrameIndexTy(DAG.getDataLayout()));
+ Chain = DAG.getNode(
+ ISD::INTRINSIC_VOID, DL, MVT::Other, Chain,
+ DAG.getConstant(Intrinsic::aarch64_sme_set_tpidr2, DL, MVT::i32),
+ TPIDR2ObjAddr);
+ } else if (RequiresSaveAllZA) {
+ assert(!CallAttrs.callee().hasSharedZAInterface() &&
+ "Cannot share state that may not exist");
+ Chain = emitSMEStateSaveRestore(*this, DAG, FuncInfo, DL, Chain,
+ /*IsSave=*/true);
+ }
+
+ SDValue PStateSM;
+ bool RequiresSMChange = CallAttrs.requiresSMChange();
+ if (RequiresSMChange) {
+ if (CallAttrs.caller().hasStreamingInterfaceOrBody())
+ PStateSM = DAG.getConstant(1, DL, MVT::i64);
+ else if (CallAttrs.caller().hasNonStreamingInterface())
+ PStateSM = DAG.getConstant(0, DL, MVT::i64);
+ else
+ PStateSM = getRuntimePStateSM(DAG, Chain, DL, MVT::i64);
+ } else {
+ PStateSM = DAG.getUNDEF(MVT::i64);
+ }
+
+ SDValue ZTFrameIdx;
+ bool ShouldPreserveZT0 = CallAttrs.requiresPreservingZT0();
+
+ // If the caller has ZT0 state which will not be preserved by the callee,
+ // spill ZT0 before the call.
+ if (ShouldPreserveZT0) {
+ MachineFrameInfo &MFI = MF.getFrameInfo();
+ unsigned ZTObj = getOrCreateZT0SpillSlot(FuncInfo, MFI);
+ ZTFrameIdx = DAG.getFrameIndex(
+ ZTObj,
+ DAG.getTargetLoweringInfo().getFrameIndexTy(DAG.getDataLayout()));
+
+ Chain = DAG.getNode(AArch64ISD::SAVE_ZT, DL, DAG.getVTList(MVT::Other),
+ {Chain, DAG.getConstant(0, DL, MVT::i32), ZTFrameIdx});
+ }
+
+ // If caller shares ZT0 but the callee is not shared ZA, we need to stop
+ // PSTATE.ZA before the call if there is no lazy-save active.
+ bool DisableZA = CallAttrs.requiresDisablingZABeforeCall();
+ assert((!DisableZA || !RequiresLazySave) &&
+ "Lazy-save should have PSTATE.SM=1 on entry to the function");
+
+ if (DisableZA)
+ Chain = DAG.getNode(
+ AArch64ISD::SMSTOP, DL, DAG.getVTList(MVT::Other, MVT::Glue), Chain,
+ DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32));
+
+ return DAG.getMergeValues({PStateSM, Chain, DAG.getUNDEF(MVT::Glue)}, DL);
+}
+
+SDValue
+AArch64TargetLowering::LowerSME_CALL_SM_CHANGE(SDValue Op,
+ SelectionDAG &DAG) const {
+ SDLoc DL(Op);
+ SMECallAttrs CallAttrs = findSMECallAttrs(Op.getNode());
+ SDValue Chain = Op->getOperand(0);
+ SDValue PStateSM = Op->getOperand(2);
+ SDValue InGlue;
+
+ if (!Subtarget->isTargetDarwin() || Subtarget->hasSVE()) {
+ Chain = DAG.getNode(AArch64ISD::VG_SAVE, DL,
+ DAG.getVTList(MVT::Other, MVT::Glue), Chain);
+ InGlue = Chain.getValue(1);
+ }
+
+ SDValue NewChain = changeStreamingMode(
+ DAG, DL, CallAttrs.callee().hasStreamingInterface(), Chain, InGlue,
+ getSMToggleCondition(CallAttrs), PStateSM);
+
+ return DAG.getMergeValues({NewChain, NewChain.getValue(1)}, DL);
+}
+
+SDValue AArch64TargetLowering::LowerSME_CALL_END(SDValue Op,
+ SelectionDAG &DAG) const {
+ SDLoc DL(Op);
+ SMECallAttrs CallAttrs = findSMECallAttrs(Op.getNode());
+ bool ShouldPreserveZT0 = CallAttrs.requiresPreservingZT0();
+ auto &MF = DAG.getMachineFunction();
+ auto *FuncInfo = MF.getInfo<AArch64FunctionInfo>();
+ SDValue Result = Op->getOperand(0);
+ SDValue PStateSM = Op->getOperand(2);
+ SDValue InGlue = Op->getOperand(3);
+ const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo();
+ bool RequiresLazySave = CallAttrs.requiresLazySave();
+ bool RequiresSaveAllZA = CallAttrs.requiresPreservingAllZAState();
+ bool RequiresSMChange = CallAttrs.requiresSMChange();
+
+ if (RequiresSMChange) {
+ assert(PStateSM && "Expected a PStateSM to be set");
+ Result = changeStreamingMode(
+ DAG, DL, !CallAttrs.callee().hasStreamingInterface(), Result, InGlue,
+ getSMToggleCondition(CallAttrs), PStateSM);
+
+ if (!Subtarget->isTargetDarwin() || Subtarget->hasSVE()) {
+ InGlue = Result.getValue(1);
+ Result =
+ DAG.getNode(AArch64ISD::VG_RESTORE, DL,
+ DAG.getVTList(MVT::Other, MVT::Glue), {Result, InGlue});
+ }
+ }
+
+ if (CallAttrs.requiresEnablingZAAfterCall())
+ // Unconditionally resume ZA.
+ Result = DAG.getNode(
+ AArch64ISD::SMSTART, DL, DAG.getVTList(MVT::Other, MVT::Glue), Result,
+ DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32));
+
+ if (ShouldPreserveZT0) {
+ MachineFrameInfo &MFI = MF.getFrameInfo();
+ unsigned ZTObj = getOrCreateZT0SpillSlot(FuncInfo, MFI);
+ SDValue ZTFrameIdx = DAG.getFrameIndex(
+ ZTObj,
+ DAG.getTargetLoweringInfo().getFrameIndexTy(DAG.getDataLayout()));
+ Result =
+ DAG.getNode(AArch64ISD::RESTORE_ZT, DL, DAG.getVTList(MVT::Other),
+ {Result, DAG.getConstant(0, DL, MVT::i32), ZTFrameIdx});
+ }
+
+ if (RequiresLazySave) {
+ // Conditionally restore the lazy save using a pseudo node.
+ TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
+ SDValue RegMask = DAG.getRegisterMask(
+ TRI->SMEABISupportRoutinesCallPreservedMaskFromX0());
+ SDValue RestoreRoutine = DAG.getTargetExternalSymbol(
+ "__arm_tpidr2_restore", getPointerTy(DAG.getDataLayout()));
+ SDValue TPIDR2_EL0 = DAG.getNode(
+ ISD::INTRINSIC_W_CHAIN, DL, MVT::i64, Result,
+ DAG.getConstant(Intrinsic::aarch64_sme_get_tpidr2, DL, MVT::i32));
+
+ // Copy the address of the TPIDR2 block into X0 before 'calling' the
+ // RESTORE_ZA pseudo.
+ SDValue Glue;
+ SDValue TPIDR2Block = DAG.getFrameIndex(
+ TPIDR2.FrameIndex,
+ DAG.getTargetLoweringInfo().getFrameIndexTy(DAG.getDataLayout()));
+ Result = DAG.getCopyToReg(Result, DL, AArch64::X0, TPIDR2Block, Glue);
+ Result =
+ DAG.getNode(AArch64ISD::RESTORE_ZA, DL, MVT::Other,
+ {Result, TPIDR2_EL0, DAG.getRegister(AArch64::X0, MVT::i64),
+ RestoreRoutine, RegMask, Result.getValue(1)});
+
+ // Finally reset the TPIDR2_EL0 register to 0.
+ Result = DAG.getNode(
+ ISD::INTRINSIC_VOID, DL, MVT::Other, Result,
+ DAG.getConstant(Intrinsic::aarch64_sme_set_tpidr2, DL, MVT::i32),
+ DAG.getConstant(0, DL, MVT::i64));
+ TPIDR2.Uses++;
+ } else if (RequiresSaveAllZA) {
+ Result = emitSMEStateSaveRestore(*this, DAG, FuncInfo, DL, Result,
+ /*IsSave=*/false);
+ }
+
+ return DAG.getMergeValues({Result, DAG.getUNDEF(MVT::Glue)}, DL);
+}
+
bool AArch64TargetLowering::mergeStoresAfterLegalization(EVT VT) const {
return !Subtarget->useSVEForFixedLengthVectors();
}
@@ -8949,46 +9201,6 @@ SDValue AArch64TargetLowering::changeStreamingMode(SelectionDAG &DAG, SDLoc DL,
return DAG.getNode(Opcode, DL, DAG.getVTList(MVT::Other, MVT::Glue), Ops);
}
-// Emit a call to __arm_sme_save or __arm_sme_restore.
-static SDValue emitSMEStateSaveRestore(const AArch64TargetLowering &TLI,
- SelectionDAG &DAG,
- AArch64FunctionInfo *Info, SDLoc DL,
- SDValue Chain, bool IsSave) {
- MachineFunction &MF = DAG.getMachineFunction();
- AArch64FunctionInfo *FuncInfo = MF.getInfo<AArch64FunctionInfo>();
- FuncInfo->setSMESaveBufferUsed();
-
- TargetLowering::ArgListTy Args;
- TargetLowering::ArgListEntry Entry;
- Entry.Ty = PointerType::getUnqual(*DAG.getContext());
- Entry.Node =
- DAG.getCopyFromReg(Chain, DL, Info->getSMESaveBufferAddr(), MVT::i64);
- Args.push_back(Entry);
-
- SDValue Callee =
- DAG.getExternalSymbol(IsSave ? "__arm_sme_save" : "__arm_sme_restore",
- TLI.getPointerTy(DAG.getDataLayout()));
- auto *RetTy = Type::getVoidTy(*DAG.getContext());
- TargetLowering::CallLoweringInfo CLI(DAG);
- CLI.setDebugLoc(DL).setChain(Chain).setLibCallee(
- CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1, RetTy,
- Callee, std::move(Args));
- return TLI.LowerCallTo(CLI).second;
-}
-
-static AArch64SME::ToggleCondition
-getSMToggleCondition(const SMECallAttrs &CallAttrs) {
- if (!CallAttrs.caller().hasStreamingCompatibleInterface() ||
- CallAttrs.caller().hasStreamingBody())
- return AArch64SME::Always;
- if (CallAttrs.callee().hasNonStreamingInterface())
- return AArch64SME::IfCallerIsStreaming;
- if (CallAttrs.callee().hasStreamingInterface())
- return AArch64SME::IfCallerIsNonStreaming;
-
- llvm_unreachable("Unsupported attributes");
-}
-
/// LowerCall - Lower a call to a callseq_start + CALL + callseq_end chain,
/// and add input and output parameter nodes.
SDValue
@@ -9118,91 +9330,25 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
// Determine whether we need any streaming mode changes.
SMECallAttrs CallAttrs = getSMECallAttrs(MF.getFunction(), CLI);
- auto DescribeCallsite =
- [&](OptimizationRemarkAnalysis &R) -> OptimizationRemarkAnalysis & {
- R << "call from '" << ore::NV("Caller", MF.getName()) << "' to '";
- if (auto *ES = dyn_cast<ExternalSymbolSDNode>(CLI.Callee))
- R << ore::NV("Callee", ES->getSymbol());
- else if (CLI.CB && CLI.CB->getCalledFunction())
- R << ore::NV("Callee", CLI.CB->getCalledFunction()->getName());
- else
- R << "unknown callee";
- R << "'";
- return R;
- };
-
+ SDValue SMECallStart;
+ SDValue PStateSM;
bool RequiresLazySave = CallAttrs.requiresLazySave();
bool RequiresSaveAllZA = CallAttrs.requiresPreservingAllZAState();
- if (RequiresLazySave) {
- TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
- SDValue TPIDR2ObjAddr = DAG.getFrameIndex(
- TPIDR2.FrameIndex,
- DAG.getTargetLoweringInfo().getFrameIndexTy(DAG.getDataLayout()));
- Chain = DAG.getNode(
- ISD::INTRINSIC_VOID, DL, MVT::Other, Chain,
- DAG.getConstant(Intrinsic::aarch64_sme_set_tpidr2, DL, MVT::i32),
- TPIDR2ObjAddr);
- OptimizationRemarkEmitter ORE(&MF.getFunction());
- ORE.emit([&]() {
- auto R = CLI.CB ? OptimizationRemarkAnalysis("sme", "SMELazySaveZA",
- CLI.CB)
- : OptimizationRemarkAnalysis("sme", "SMELazySaveZA",
- &MF.getFunction());
- return DescribeCallsite(R) << " sets up a lazy save for ZA";
- });
- } else if (RequiresSaveAllZA) {
- assert(!CallAttrs.callee().hasSharedZAInterface() &&
- "Cannot share state that may not exist");
- Chain = emitSMEStateSaveRestore(*this, DAG, FuncInfo, DL, Chain,
- /*IsSave=*/true);
- }
-
- SDValue PStateSM;
bool RequiresSMChange = CallAttrs.requiresSMChange();
- if (RequiresSMChange) {
- if (CallAttrs.caller().hasStreamingInterfaceOrBody())
- PStateSM = DAG.getConstant(1, DL, MVT::i64);
- else if (CallAttrs.caller().hasNonStreamingInterface())
- PStateSM = DAG.getConstant(0, DL, MVT::i64);
- else
- PStateSM = getRuntimePStateSM(DAG, Chain, DL, MVT::i64);
- OptimizationRemarkEmitter ORE(&MF.getFunction());
- ORE.emit([&]() {
- auto R = CLI.CB ? OptimizationRemarkAnalysis("sme", "SMETransition",
- CLI.CB)
- : OptimizationRemarkAnalysis("sme", "SMETransition",
- &MF.getFunction());
- DescribeCallsite(R) << " requires a streaming mode transition";
- return R;
- });
- }
-
- SDValue ZTFrameIdx;
- MachineFrameInfo &MFI = MF.getFrameInfo();
bool ShouldPreserveZT0 = CallAttrs.requiresPreservingZT0();
-
- // If the caller has ZT0 state which will not be preserved by the callee,
- // spill ZT0 before the call.
- if (ShouldPreserveZT0) {
- unsigned ZTObj = MFI.CreateSpillStackObject(64, Align(16));
- ZTFrameIdx = DAG.getFrameIndex(
- ZTObj,
- DAG.getTargetLoweringInfo().getFrameIndexTy(DAG.getDataLayout()));
-
- Chain = DAG.getNode(AArch64ISD::SAVE_ZT, DL, DAG.getVTList(MVT::Other),
- {Chain, DAG.getConstant(0, DL, MVT::i32), ZTFrameIdx});
- }
-
- // If caller shares ZT0 but the callee is not shared ZA, we need to stop
- // PSTATE.ZA before the call if there is no lazy-save active.
bool DisableZA = CallAttrs.requiresDisablingZABeforeCall();
- assert((!DisableZA || !RequiresLazySave) &&
- "Lazy-save should have PSTATE.SM=1 on entry to the function");
+ bool IsSMECall = RequiresLazySave || RequiresSaveAllZA || RequiresSMChange ||
+ ShouldPreserveZT0 || DisableZA;
- if (DisableZA)
- Chain = DAG.getNode(
- AArch64ISD::SMSTOP, DL, DAG.getVTList(MVT::Other, MVT::Glue), Chain,
- DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32));
+ if (IsSMECall) {
+ PStateSM = DAG.getNode(
+ AArch64ISD::SME_CALL_START, DL,
+ DAG.getVTList(MVT::i64, MVT::Other, MVT::Glue), Chain,
+ DAG.getTargetConstant(unsigned(CallAttrs.caller()), DL, MVT::i32),
+ DAG.getTargetConstant(unsigned(CallAttrs.callee()), DL, MVT::i32),
+ DAG.getTargetConstant(unsigned(CallAttrs.callsite()), DL, MVT::i32));
+ Chain = SMECallStart = PStateSM.getValue(1);
+ }
// Adjust the stack pointer for the new arguments...
// These operations are automatically eliminated by the prolog/epilog pass
@@ -9470,18 +9616,13 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
Chain = DAG.getNode(ISD::TokenFactor, DL, MVT::Other, MemOpChains);
SDValue InGlue;
+ SDValue SMECallSMChange;
if (RequiresSMChange) {
- if (!Subtarget->isTargetDarwin() || Subtarget->hasSVE()) {
- Chain = DAG.getNode(AArch64ISD::VG_SAVE, DL,
- DAG.getVTList(MVT::Other, MVT::Glue), Chain);
- InGlue = Chain.getValue(1);
- }
-
- SDValue NewChain = changeStreamingMode(
- DAG, DL, CallAttrs.callee().hasStreamingInterface(), Chain, InGlue,
- getSMToggleCondition(CallAttrs), PStateSM);
- Chain = NewChain.getValue(0);
- InGlue = NewChain.getValue(1);
+ SMECallSMChange = DAG.getNode(AArch64ISD::SME_CALL_SM_CHANGE, DL,
+ DAG.getVTList(MVT::Other, MVT::Glue), Chain,
+ SMECallStart, PStateSM);
+ Chain = SMECallSMChange;
+ InGlue = SMECallSMChange.getValue(1);
}
// Build a sequence of copy-to-reg nodes chained together with token chain
@@ -9662,63 +9803,11 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
if (!Ins.empty())
InGlue = Result.getValue(Result->getNumValues() - 1);
- if (RequiresSMChange) {
- assert(PStateSM && "Expected a PStateSM to be set");
- Result = changeStreamingMode(
- DAG, DL, !CallAttrs.callee().hasStreamingInterface(), Result, InGlue,
- getSMToggleCondition(CallAttrs), PStateSM);
-
- if (!Subtarget->isTargetDarwin() || Subtarget->hasSVE()) {
- InGlue = Result.getValue(1);
- Result =
- DAG.getNode(AArch64ISD::VG_RESTORE, DL,
- DAG.getVTList(MVT::Other, MVT::Glue), {Result, InGlue});
- }
- }
-
- if (CallAttrs.requiresEnablingZAAfterCall())
- // Unconditionally resume ZA.
- Result = DAG.getNode(
- AArch64ISD::SMSTART, DL, DAG.getVTList(MVT::Other, MVT::Glue), Result,
- DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32));
-
- if (ShouldPreserveZT0)
- Result =
- DAG.getNode(AArch64ISD::RESTORE_ZT, DL, DAG.getVTList(MVT::Other),
- {Result, DAG.getConstant(0, DL, MVT::i32), ZTFrameIdx});
-
- if (RequiresLazySave) {
- // Conditionally restore the lazy save using a pseudo node.
- TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
- SDValue RegMask = DAG.getRegisterMask(
- TRI->SMEABISupportRoutinesCallPreservedMaskFromX0());
- SDValue RestoreRoutine = DAG.getTargetExternalSymbol(
- "__arm_tpidr2_restore", getPointerTy(DAG.getDataLayout()));
- SDValue TPIDR2_EL0 = DAG.getNode(
- ISD::INTRINSIC_W_CHAIN, DL, MVT::i64, Result,
- DAG.getConstant(Intrinsic::aarch64_sme_get_tpidr2, DL, MVT::i32));
-
- // Copy the address of the TPIDR2 block into X0 before 'calling' the
- // RESTORE_ZA pseudo.
- SDValue Glue;
- SDValue TPIDR2Block = DAG.getFrameIndex(
- TPIDR2.FrameIndex,
- DAG.getTargetLoweringInfo().getFrameIndexTy(DAG.getDataLayout()));
- Result = DAG.getCopyToReg(Result, DL, AArch64::X0, TPIDR2Block, Glue);
- Result =
- DAG.getNode(AArch64ISD::RESTORE_ZA, DL, MVT::Other,
- {Result, TPIDR2_EL0, DAG.getRegister(AArch64::X0, MVT::i64),
- RestoreRoutine, RegMask, Result.getValue(1)});
-
- // Finally reset the TPIDR2_EL0 register to 0.
- Result = DAG.getNode(
- ISD::INTRINSIC_VOID, DL, MVT::Other, Result,
- DAG.getConstant(Intrinsic::aarch64_sme_set_tpidr2, DL, MVT::i32),
- DAG.getConstant(0, DL, MVT::i64));
- TPIDR2.Uses++;
- } else if (RequiresSaveAllZA) {
- Result = emitSMEStateSaveRestore(*this, DAG, FuncInfo, DL, Result,
- /*IsSave=*/false);
+ if (IsSMECall) {
+ Result = DAG.getNode(AArch64ISD::SME_CALL_END, DL,
+ DAG.getVTList(MVT::Other, MVT::Glue), Result,
+ SMECallSMChange ? SMECallSMChange : SMECallStart,
+ PStateSM ? PStateSM : DAG.getUNDEF(MVT::i64), InGlue);
}
if (RequiresSMChange || RequiresLazySave || ShouldPreserveZT0 ||
@@ -27828,6 +27917,18 @@ AArch64TargetLowering::getPreferredVectorAction(MVT VT) const {
return TargetLoweringBase::getPreferredVectorAction(VT);
}
+TargetLoweringBase::LegalizeAction
+AArch64TargetLowering::getCustomOperationAction(SDNode &N) const {
+ switch (N.getOpcode()) {
+ default:
+ return Legal;
+ case AArch64ISD::SME_CALL_START:
+ case AArch64ISD::SME_CALL_SM_CHANGE:
+ case AArch64ISD::SME_CALL_END:
+ return Custom;
+ }
+}
+
// In v8.4a, ldp and stp instructions are guaranteed to be single-copy atomic
// provided the address is 16-byte aligned.
bool AArch64TargetLowering::isOpSuitableForLDPSTP(const Instruction *I) const {
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index b59526bf01888..174e19f3badae 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -79,6 +79,10 @@ enum NodeType : unsigned {
RESTORE_ZT,
SAVE_ZT,
+ SME_CALL_START,
+ SME_CALL_SM_CHANGE,
+ SME_CALL_END,
+
// A call with the callee in x16, i.e. "blr x16".
CALL_ARM64EC_TO_X64,
@@ -823,6 +827,9 @@ class AArch64TargetLowering : public TargetLowering {
TargetLoweringBase::LegalizeTypeAction
getPreferredVectorAction(MVT VT) const override;
+ TargetLoweringBase::LegalizeAction
+ getCustomOperationAction(SDNode &) const override;
+
/// If the target has a standard location for the stack protector cookie,
/// returns the address of that location. Otherwise, returns nullptr.
Value *getIRStackGuard(IRBuilderBase &IRB) const override;
@@ -1028,6 +1035,11 @@ class AArch64TargetLowering : public TargetLowering {
/// True if stack clash protection is enabled for this functions.
bool hasInlineStackProbe(const MachineFunction &MF) const override;
+ // Returns the runtime value for PSTATE.SM by generating a call to
+ // __arm_sme_state.
+ SDValue getRuntimePStateSM(SelectionDAG &DAG, SDValue Chain, SDLoc DL,
+ EVT VT) const;
+
private:
/// Keep a pointer to the AArch64Subtarget around so that we can
/// make the right decision when generating code for different targets.
@@ -1211,6 +1223,9 @@ class AArch64TargetLowering : public TargetLowering {
SDValue LowerWindowsDYNAMIC_STACKALLOC(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerInlineDYNAMIC_STACKALLOC(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerDYNAMIC_STACKALLOC(SDValue Op, SelectionDAG &DAG) const;
+ SDValue LowerSME_CALL_START(SDValue Op, SelectionDAG &DAG) const;
+ SDValue LowerSME_CALL_SM_CHANGE(SDValue Op, SelectionDAG &DAG) const;
+ SDValue LowerSME_CALL_END(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerAVG(SDValue Op, SelectionDAG &DAG, unsigned NewOp) const;
@@ -1347,11 +1362,6 @@ class AArch64TargetLowering : public TargetLowering {
// This function does not handle predicate bitcasts.
SDValue getSVESafeBitCast(EVT VT, SDValue Op, SelectionDAG &DAG) const;
- // Returns the runtime value for PSTATE.SM by generating a call to
- // __arm_sme_state.
- SDValue getRuntimePStateSM(SelectionDAG &DAG, SDValue Chain, SDLoc DL,
- EVT VT) const;
-
bool preferScalarizeSplat(SDNode *N) const override;
unsigned getMinimumJumpTableEntries() const override;
diff --git a/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h b/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h
index d3026ca45c349..02cd0398bea04 100644
--- a/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h
@@ -245,6 +245,9 @@ class AArch64FunctionInfo final : public MachineFunctionInfo {
int64_t VGIdx = std::numeric_limits<int>::max();
int64_t StreamingVGIdx = std::numeric_limits<int>::max();
+ // The stack slot where ZT0 is stored.
+ int64_t ZT0Idx = std::numeric_limits<int>::max();
+
public:
AArch64FunctionInfo(const Function &F, const AArch64Subtarget *STI);
@@ -275,6 +278,9 @@ class AArch64FunctionInfo final : public MachineFunctionInfo {
int64_t getStreamingVGIdx() const { return StreamingVGIdx; };
void setStreamingVGIdx(unsigned FrameIdx) { StreamingVGIdx = FrameIdx; };
+ int64_t getZT0Idx() const { return ZT0Idx; };
+ void setZT0Idx(unsigned FrameIdx) { ZT0Idx = FrameIdx; };
+
bool isSVECC() const { return IsSVECC; };
void setIsSVECC(bool s) { IsSVECC = s; };
diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
index f1be0ecbee7ed..4d136f1f65263 100644
--- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
+++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
@@ -145,6 +145,8 @@ class SMEAttrs {
return Bitmask == Other.Bitmask;
}
+ explicit operator unsigned() const { return Bitmask; }
+
private:
void addKnownFunctionAttrs(StringRef FuncName);
};
From 2efbedd239a3a8cfa002065bfde3570a9b999fb0 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Fri, 30 May 2025 08:45:20 +0000
Subject: [PATCH 3/3] [AArch64][SME] Merge back-to-back SME call regions
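Roughly, the DAG combine added here (performSMECallCombine) rewrites
(sketch; the merge only fires when both regions carry identical
SMECallAttrs):

  ; SME_CALL_START(A) .. call .. SME_CALL_END(A)
  ; SME_CALL_START(B) .. call .. SME_CALL_END(B)
  ;
  ; into:
  ;
  ; SME_CALL_START(A) .. call .. call .. SME_CALL_END(B)
  ;
  ; END(A) and START(B) become no-ops, and a duplicate
  ; SME_CALL_SM_CHANGE in the second region is dropped.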
---
.../Target/AArch64/AArch64ISelLowering.cpp | 61 +++++++++++++++++++
.../AArch64/Utils/AArch64SMEAttributes.h | 7 +++
.../CodeGen/AArch64/sme-lazy-save-call.ll | 19 ++----
.../test/CodeGen/AArch64/sme-peephole-opts.ll | 27 ++------
4 files changed, 79 insertions(+), 35 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 9866149868fd7..54c616f15b8cd 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -26719,6 +26719,65 @@ static SDValue performSHLCombine(SDNode *N,
return DAG.getNode(ISD::AND, DL, VT, NewShift, NewRHS);
}
+static SDValue performSMECallCombine(SDNode *SMECallEnd,
+ TargetLowering::DAGCombinerInfo &DCI,
+ SelectionDAG &DAG) {
+ SDLoc DL(SMECallEnd);
+ SDNode *SMECallStart = findSMECallStart(SMECallEnd);
+ SDValue StartOutGlue = SDValue(SMECallStart, 2);
+ if (!StartOutGlue.use_empty())
+ return SDValue();
+
+ SDValue StartChain = SMECallStart->getOperand(0);
+ if (StartChain->getOpcode() != AArch64ISD::SME_CALL_END)
+ return SDValue();
+
+ // TODO: This probably does not need to be a CALLSEQ_END.
+ SDNode *PrevSMECallEnd = StartChain.getNode();
+ SDValue PrevCallOutChain = PrevSMECallEnd->getOperand(0);
+ if (PrevCallOutChain->getOpcode() != ISD::CALLSEQ_END)
+ return SDValue();
+
+ SDNode *PrevSMECallStart = findSMECallStart(PrevSMECallEnd);
+ SMECallAttrs CallAttrs = findSMECallAttrs(SMECallStart);
+ SMECallAttrs PrevCallAttrs = findSMECallAttrs(PrevSMECallStart);
+
+ // TODO: Handle case where we're already in (e.g.) streaming mode and just
+ // need to enable ZA etc.
+ if (CallAttrs != PrevCallAttrs)
+ return SDValue();
+
+ SDNode *MaybeSMSwitch = PrevSMECallEnd->getOperand(1).getNode();
+ if (MaybeSMSwitch->getOpcode() == AArch64ISD::SME_CALL_SM_CHANGE) {
+ SDNode *SMSwitch = SMECallEnd->getOperand(1).getNode();
+ // Remove duplicate SME_CALL_SM_CHANGE.
+ // FIXME: Can we avoid adding fake glue? This is needed as we need to remove
+ // this (duplicate) SME_CALL_SM_CHANGE, but it has been glued to another
+ // node so we need something to replace the glue.
+ SDValue FakeGlue = DAG.getUNDEF(MVT::Glue);
+ SDValue NopSMESwitch = DAG.getMergeValues(
+ {SMSwitch->getOperand(0), FakeGlue}, SDLoc(SMECallEnd));
+ DAG.ReplaceAllUsesWith(SMSwitch, NopSMESwitch.getNode());
+ }
+
+ // Update the last SME_CALL_END to point to the SME_CALL_START or
+ // SME_CALL_SM_CHANGE from the previous SME_CALL region:
+ DAG.UpdateNodeOperands(
+ SMECallEnd,
+ /*Chain=*/SMECallEnd->getOperand(0),
+ /*SMECallStart Or SMSwitch=*/PrevSMECallEnd->getOperand(1),
+ /*PStateSM=*/SMECallEnd->getOperand(2),
+ /*InGlue=*/SMECallEnd->getOperand(3));
+ // Remove the SME_CALL_END for the previous SME_CALL region:
+ DAG.ReplaceAllUsesWith(PrevSMECallEnd, PrevCallOutChain.getNode());
+ // Remove the SME_CALL_START for the current/next SME_CALL region:
+ SDValue NopSMECallStart = DAG.getMergeValues(
+ {SDValue(PrevSMECallStart, 0), PrevCallOutChain, DAG.getUNDEF(MVT::Glue)},
+ DL);
+ DAG.ReplaceAllUsesWith(SMECallStart, NopSMECallStart.getNode());
+ return SDValue();
+}
+
SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
DAGCombinerInfo &DCI) const {
SelectionDAG &DAG = DCI.DAG;
@@ -26880,6 +26939,8 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
case AArch64ISD::UMULL:
case AArch64ISD::PMULL:
return performMULLCombine(N, DCI, DAG);
+ case AArch64ISD::SME_CALL_END:
+ return performSMECallCombine(N, DCI, DAG);
case ISD::INTRINSIC_VOID:
case ISD::INTRINSIC_W_CHAIN:
switch (N->getConstantOperandVal(1)) {
diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
index 4d136f1f65263..2c3c8dff691a9 100644
--- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
+++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
@@ -203,6 +203,13 @@ class SMECallAttrs {
return caller().hasAgnosticZAInterface() &&
!callee().hasAgnosticZAInterface() && !callee().isSMEABIRoutine();
}
+
+ bool operator==(SMECallAttrs const &Other) const {
+ return caller() == Other.caller() && callee() == Other.callee() &&
+ callsite() == Other.callsite();
+ }
+
+ bool operator!=(SMECallAttrs const &Other) const { return !(*this == Other); }
};
} // namespace llvm
diff --git a/llvm/test/CodeGen/AArch64/sme-lazy-save-call.ll b/llvm/test/CodeGen/AArch64/sme-lazy-save-call.ll
index 26c6dc4eb978b..cb72a8eb3d5ac 100644
--- a/llvm/test/CodeGen/AArch64/sme-lazy-save-call.ll
+++ b/llvm/test/CodeGen/AArch64/sme-lazy-save-call.ll
@@ -41,16 +41,17 @@ define void @test_lazy_save_2_callees() nounwind "aarch64_inout_za" {
; CHECK-LABEL: test_lazy_save_2_callees:
; CHECK: // %bb.0:
; CHECK-NEXT: stp x29, x30, [sp, #-32]! // 16-byte Folded Spill
-; CHECK-NEXT: stp x20, x19, [sp, #16] // 16-byte Folded Spill
+; CHECK-NEXT: str x19, [sp, #16] // 8-byte Folded Spill
; CHECK-NEXT: mov x29, sp
; CHECK-NEXT: sub sp, sp, #16
; CHECK-NEXT: rdsvl x8, #1
; CHECK-NEXT: mov x9, sp
; CHECK-NEXT: msub x9, x8, x8, x9
; CHECK-NEXT: mov sp, x9
-; CHECK-NEXT: sub x20, x29, #16
+; CHECK-NEXT: sub x10, x29, #16
; CHECK-NEXT: stp x9, x8, [x29, #-16]
-; CHECK-NEXT: msr TPIDR2_EL0, x20
+; CHECK-NEXT: msr TPIDR2_EL0, x10
+; CHECK-NEXT: bl private_za_callee
; CHECK-NEXT: bl private_za_callee
; CHECK-NEXT: smstart za
; CHECK-NEXT: mrs x8, TPIDR2_EL0
@@ -60,18 +61,8 @@ define void @test_lazy_save_2_callees() nounwind "aarch64_inout_za" {
; CHECK-NEXT: bl __arm_tpidr2_restore
; CHECK-NEXT: .LBB1_2:
; CHECK-NEXT: msr TPIDR2_EL0, xzr
-; CHECK-NEXT: msr TPIDR2_EL0, x20
-; CHECK-NEXT: bl private_za_callee
-; CHECK-NEXT: smstart za
-; CHECK-NEXT: mrs x8, TPIDR2_EL0
-; CHECK-NEXT: sub x0, x29, #16
-; CHECK-NEXT: cbnz x8, .LBB1_4
-; CHECK-NEXT: // %bb.3:
-; CHECK-NEXT: bl __arm_tpidr2_restore
-; CHECK-NEXT: .LBB1_4:
-; CHECK-NEXT: msr TPIDR2_EL0, xzr
; CHECK-NEXT: mov sp, x29
-; CHECK-NEXT: ldp x20, x19, [sp, #16] // 16-byte Folded Reload
+; CHECK-NEXT: ldr x19, [sp, #16] // 8-byte Folded Reload
; CHECK-NEXT: ldp x29, x30, [sp], #32 // 16-byte Folded Reload
; CHECK-NEXT: ret
call void @private_za_callee()
diff --git a/llvm/test/CodeGen/AArch64/sme-peephole-opts.ll b/llvm/test/CodeGen/AArch64/sme-peephole-opts.ll
index 130a316bcc2ba..400f8937df90c 100644
--- a/llvm/test/CodeGen/AArch64/sme-peephole-opts.ll
+++ b/llvm/test/CodeGen/AArch64/sme-peephole-opts.ll
@@ -75,21 +75,11 @@ define void @test2() nounwind "aarch64_pstate_sm_compatible" {
; CHECK-NEXT: smstop sm
; CHECK-NEXT: .LBB2_2:
; CHECK-NEXT: bl callee
+; CHECK-NEXT: bl callee
; CHECK-NEXT: tbz w19, #0, .LBB2_4
; CHECK-NEXT: // %bb.3:
; CHECK-NEXT: smstart sm
; CHECK-NEXT: .LBB2_4:
-; CHECK-NEXT: bl __arm_sme_state
-; CHECK-NEXT: and x19, x0, #0x1
-; CHECK-NEXT: tbz w19, #0, .LBB2_6
-; CHECK-NEXT: // %bb.5:
-; CHECK-NEXT: smstop sm
-; CHECK-NEXT: .LBB2_6:
-; CHECK-NEXT: bl callee
-; CHECK-NEXT: tbz w19, #0, .LBB2_8
-; CHECK-NEXT: // %bb.7:
-; CHECK-NEXT: smstart sm
-; CHECK-NEXT: .LBB2_8:
; CHECK-NEXT: ldp d9, d8, [sp, #48] // 16-byte Folded Reload
; CHECK-NEXT: ldr x19, [sp, #80] // 8-byte Folded Reload
; CHECK-NEXT: ldp d11, d10, [sp, #32] // 16-byte Folded Reload
@@ -252,22 +242,17 @@ define float @test6(float %f) nounwind "aarch64_pstate_sm_enabled" {
define void @test7() nounwind "aarch64_inout_zt0" {
; CHECK-LABEL: test7:
; CHECK: // %bb.0:
-; CHECK-NEXT: sub sp, sp, #144
-; CHECK-NEXT: stp x30, x19, [sp, #128] // 16-byte Folded Spill
-; CHECK-NEXT: add x19, sp, #64
-; CHECK-NEXT: str zt0, [x19]
-; CHECK-NEXT: smstop za
-; CHECK-NEXT: bl callee
-; CHECK-NEXT: smstart za
-; CHECK-NEXT: ldr zt0, [x19]
+; CHECK-NEXT: sub sp, sp, #80
+; CHECK-NEXT: stp x30, x19, [sp, #64] // 16-byte Folded Spill
; CHECK-NEXT: mov x19, sp
; CHECK-NEXT: str zt0, [x19]
; CHECK-NEXT: smstop za
; CHECK-NEXT: bl callee
+; CHECK-NEXT: bl callee
; CHECK-NEXT: smstart za
; CHECK-NEXT: ldr zt0, [x19]
-; CHECK-NEXT: ldp x30, x19, [sp, #128] // 16-byte Folded Reload
-; CHECK-NEXT: add sp, sp, #144
+; CHECK-NEXT: ldp x30, x19, [sp, #64] // 16-byte Folded Reload
+; CHECK-NEXT: add sp, sp, #80
; CHECK-NEXT: ret
call void @callee()
call void @callee()