[llvm] [AArch64] Add SME peephole optimizer pass (PR #104612)

Sander de Smalen via llvm-commits llvm-commits at lists.llvm.org
Tue Aug 20 01:01:53 PDT 2024


https://github.com/sdesmalen-arm updated https://github.com/llvm/llvm-project/pull/104612

From cdfa9cc6417db52d19af8f8111842d28f4112eba Mon Sep 17 00:00:00 2001
From: Sander de Smalen <sander.desmalen at arm.com>
Date: Fri, 16 Aug 2024 16:43:45 +0100
Subject: [PATCH 1/4] Precommit test

---
 .../test/CodeGen/AArch64/sme-peephole-opts.ll | 560 ++++++++++++++++++
 1 file changed, 560 insertions(+)
 create mode 100644 llvm/test/CodeGen/AArch64/sme-peephole-opts.ll

diff --git a/llvm/test/CodeGen/AArch64/sme-peephole-opts.ll b/llvm/test/CodeGen/AArch64/sme-peephole-opts.ll
new file mode 100644
index 00000000000000..66efb71596171a
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sme-peephole-opts.ll
@@ -0,0 +1,560 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sve,+sme2 < %s | FileCheck %s
+
+declare void @callee()
+declare void @callee_farg(float)
+declare float @callee_farg_fret(float)
+
+; normal caller -> streaming callees
+define void @test0() nounwind {
+; CHECK-LABEL: test0:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    stp d15, d14, [sp, #-80]! // 16-byte Folded Spill
+; CHECK-NEXT:    cntd x9
+; CHECK-NEXT:    stp d13, d12, [sp, #16] // 16-byte Folded Spill
+; CHECK-NEXT:    stp d11, d10, [sp, #32] // 16-byte Folded Spill
+; CHECK-NEXT:    stp d9, d8, [sp, #48] // 16-byte Folded Spill
+; CHECK-NEXT:    stp x30, x9, [sp, #64] // 16-byte Folded Spill
+; CHECK-NEXT:    smstart sm
+; CHECK-NEXT:    bl callee
+; CHECK-NEXT:    smstop sm
+; CHECK-NEXT:    smstart sm
+; CHECK-NEXT:    bl callee
+; CHECK-NEXT:    smstop sm
+; CHECK-NEXT:    ldp d9, d8, [sp, #48] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr x30, [sp, #64] // 8-byte Folded Reload
+; CHECK-NEXT:    ldp d11, d10, [sp, #32] // 16-byte Folded Reload
+; CHECK-NEXT:    ldp d13, d12, [sp, #16] // 16-byte Folded Reload
+; CHECK-NEXT:    ldp d15, d14, [sp], #80 // 16-byte Folded Reload
+; CHECK-NEXT:    ret
+  call void @callee() "aarch64_pstate_sm_enabled"
+  call void @callee() "aarch64_pstate_sm_enabled"
+  ret void
+}
+
+; streaming caller -> normal callees
+define void @test1() nounwind "aarch64_pstate_sm_enabled" {
+; CHECK-LABEL: test1:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    stp d15, d14, [sp, #-80]! // 16-byte Folded Spill
+; CHECK-NEXT:    cntd x9
+; CHECK-NEXT:    stp d13, d12, [sp, #16] // 16-byte Folded Spill
+; CHECK-NEXT:    stp d11, d10, [sp, #32] // 16-byte Folded Spill
+; CHECK-NEXT:    stp d9, d8, [sp, #48] // 16-byte Folded Spill
+; CHECK-NEXT:    stp x30, x9, [sp, #64] // 16-byte Folded Spill
+; CHECK-NEXT:    smstop sm
+; CHECK-NEXT:    bl callee
+; CHECK-NEXT:    smstart sm
+; CHECK-NEXT:    smstop sm
+; CHECK-NEXT:    bl callee
+; CHECK-NEXT:    smstart sm
+; CHECK-NEXT:    ldp d9, d8, [sp, #48] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr x30, [sp, #64] // 8-byte Folded Reload
+; CHECK-NEXT:    ldp d11, d10, [sp, #32] // 16-byte Folded Reload
+; CHECK-NEXT:    ldp d13, d12, [sp, #16] // 16-byte Folded Reload
+; CHECK-NEXT:    ldp d15, d14, [sp], #80 // 16-byte Folded Reload
+; CHECK-NEXT:    ret
+  call void @callee()
+  call void @callee()
+  ret void
+}
+
+; streaming-compatible caller -> normal callees
+; the conditional smstart/smstop pairs are not yet optimized away.
+define void @test2() nounwind "aarch64_pstate_sm_compatible" {
+; CHECK-LABEL: test2:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    stp d15, d14, [sp, #-96]! // 16-byte Folded Spill
+; CHECK-NEXT:    cntd x9
+; CHECK-NEXT:    stp d13, d12, [sp, #16] // 16-byte Folded Spill
+; CHECK-NEXT:    stp d11, d10, [sp, #32] // 16-byte Folded Spill
+; CHECK-NEXT:    stp d9, d8, [sp, #48] // 16-byte Folded Spill
+; CHECK-NEXT:    stp x30, x9, [sp, #64] // 16-byte Folded Spill
+; CHECK-NEXT:    str x19, [sp, #80] // 8-byte Folded Spill
+; CHECK-NEXT:    bl __arm_sme_state
+; CHECK-NEXT:    and x19, x0, #0x1
+; CHECK-NEXT:    tbz w19, #0, .LBB2_2
+; CHECK-NEXT:  // %bb.1:
+; CHECK-NEXT:    smstop sm
+; CHECK-NEXT:  .LBB2_2:
+; CHECK-NEXT:    bl callee
+; CHECK-NEXT:    tbz w19, #0, .LBB2_4
+; CHECK-NEXT:  // %bb.3:
+; CHECK-NEXT:    smstart sm
+; CHECK-NEXT:  .LBB2_4:
+; CHECK-NEXT:    bl __arm_sme_state
+; CHECK-NEXT:    and x19, x0, #0x1
+; CHECK-NEXT:    tbz w19, #0, .LBB2_6
+; CHECK-NEXT:  // %bb.5:
+; CHECK-NEXT:    smstop sm
+; CHECK-NEXT:  .LBB2_6:
+; CHECK-NEXT:    bl callee
+; CHECK-NEXT:    tbz w19, #0, .LBB2_8
+; CHECK-NEXT:  // %bb.7:
+; CHECK-NEXT:    smstart sm
+; CHECK-NEXT:  .LBB2_8:
+; CHECK-NEXT:    ldp d9, d8, [sp, #48] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr x19, [sp, #80] // 8-byte Folded Reload
+; CHECK-NEXT:    ldp d11, d10, [sp, #32] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr x30, [sp, #64] // 8-byte Folded Reload
+; CHECK-NEXT:    ldp d13, d12, [sp, #16] // 16-byte Folded Reload
+; CHECK-NEXT:    ldp d15, d14, [sp], #96 // 16-byte Folded Reload
+; CHECK-NEXT:    ret
+  call void @callee()
+  call void @callee()
+  ret void
+}
+
+; streaming-compatible caller -> mixed callees
+define void @test3() nounwind "aarch64_pstate_sm_compatible" {
+; CHECK-LABEL: test3:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    stp d15, d14, [sp, #-96]! // 16-byte Folded Spill
+; CHECK-NEXT:    cntd x9
+; CHECK-NEXT:    stp d13, d12, [sp, #16] // 16-byte Folded Spill
+; CHECK-NEXT:    stp d11, d10, [sp, #32] // 16-byte Folded Spill
+; CHECK-NEXT:    stp d9, d8, [sp, #48] // 16-byte Folded Spill
+; CHECK-NEXT:    stp x30, x9, [sp, #64] // 16-byte Folded Spill
+; CHECK-NEXT:    str x19, [sp, #80] // 8-byte Folded Spill
+; CHECK-NEXT:    bl __arm_sme_state
+; CHECK-NEXT:    and x19, x0, #0x1
+; CHECK-NEXT:    tbnz w19, #0, .LBB3_2
+; CHECK-NEXT:  // %bb.1:
+; CHECK-NEXT:    smstart sm
+; CHECK-NEXT:  .LBB3_2:
+; CHECK-NEXT:    bl callee
+; CHECK-NEXT:    tbnz w19, #0, .LBB3_4
+; CHECK-NEXT:  // %bb.3:
+; CHECK-NEXT:    smstop sm
+; CHECK-NEXT:  .LBB3_4:
+; CHECK-NEXT:    bl __arm_sme_state
+; CHECK-NEXT:    and x19, x0, #0x1
+; CHECK-NEXT:    tbz w19, #0, .LBB3_6
+; CHECK-NEXT:  // %bb.5:
+; CHECK-NEXT:    smstop sm
+; CHECK-NEXT:  .LBB3_6:
+; CHECK-NEXT:    bl callee
+; CHECK-NEXT:    tbz w19, #0, .LBB3_8
+; CHECK-NEXT:  // %bb.7:
+; CHECK-NEXT:    smstart sm
+; CHECK-NEXT:  .LBB3_8:
+; CHECK-NEXT:    bl __arm_sme_state
+; CHECK-NEXT:    and x19, x0, #0x1
+; CHECK-NEXT:    tbnz w19, #0, .LBB3_10
+; CHECK-NEXT:  // %bb.9:
+; CHECK-NEXT:    smstart sm
+; CHECK-NEXT:  .LBB3_10:
+; CHECK-NEXT:    bl callee
+; CHECK-NEXT:    tbnz w19, #0, .LBB3_12
+; CHECK-NEXT:  // %bb.11:
+; CHECK-NEXT:    smstop sm
+; CHECK-NEXT:  .LBB3_12:
+; CHECK-NEXT:    ldp d9, d8, [sp, #48] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr x19, [sp, #80] // 8-byte Folded Reload
+; CHECK-NEXT:    ldp d11, d10, [sp, #32] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr x30, [sp, #64] // 8-byte Folded Reload
+; CHECK-NEXT:    ldp d13, d12, [sp, #16] // 16-byte Folded Reload
+; CHECK-NEXT:    ldp d15, d14, [sp], #96 // 16-byte Folded Reload
+; CHECK-NEXT:    ret
+  call void @callee() "aarch64_pstate_sm_enabled"
+  call void @callee()
+  call void @callee() "aarch64_pstate_sm_enabled"
+  ret void
+}
+
+; streaming caller -> normal callees (pass 0.0f)
+define void @test4() nounwind "aarch64_pstate_sm_enabled" {
+; CHECK-LABEL: test4:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    stp d15, d14, [sp, #-80]! // 16-byte Folded Spill
+; CHECK-NEXT:    cntd x9
+; CHECK-NEXT:    stp d13, d12, [sp, #16] // 16-byte Folded Spill
+; CHECK-NEXT:    stp d11, d10, [sp, #32] // 16-byte Folded Spill
+; CHECK-NEXT:    stp d9, d8, [sp, #48] // 16-byte Folded Spill
+; CHECK-NEXT:    stp x30, x9, [sp, #64] // 16-byte Folded Spill
+; CHECK-NEXT:    smstop sm
+; CHECK-NEXT:    fmov s0, wzr
+; CHECK-NEXT:    bl callee_farg
+; CHECK-NEXT:    smstart sm
+; CHECK-NEXT:    smstop sm
+; CHECK-NEXT:    fmov s0, wzr
+; CHECK-NEXT:    bl callee_farg
+; CHECK-NEXT:    smstart sm
+; CHECK-NEXT:    ldp d9, d8, [sp, #48] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr x30, [sp, #64] // 8-byte Folded Reload
+; CHECK-NEXT:    ldp d11, d10, [sp, #32] // 16-byte Folded Reload
+; CHECK-NEXT:    ldp d13, d12, [sp, #16] // 16-byte Folded Reload
+; CHECK-NEXT:    ldp d15, d14, [sp], #80 // 16-byte Folded Reload
+; CHECK-NEXT:    ret
+  call void @callee_farg(float zeroinitializer)
+  call void @callee_farg(float zeroinitializer)
+  ret void
+}
+
+; streaming caller -> normal callees (pass fp arg)
+define void @test5(float %f) nounwind "aarch64_pstate_sm_enabled" {
+; CHECK-LABEL: test5:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    sub sp, sp, #96
+; CHECK-NEXT:    cntd x9
+; CHECK-NEXT:    stp d15, d14, [sp, #16] // 16-byte Folded Spill
+; CHECK-NEXT:    stp d13, d12, [sp, #32] // 16-byte Folded Spill
+; CHECK-NEXT:    stp d11, d10, [sp, #48] // 16-byte Folded Spill
+; CHECK-NEXT:    stp d9, d8, [sp, #64] // 16-byte Folded Spill
+; CHECK-NEXT:    stp x30, x9, [sp, #80] // 16-byte Folded Spill
+; CHECK-NEXT:    str s0, [sp, #12] // 4-byte Folded Spill
+; CHECK-NEXT:    smstop sm
+; CHECK-NEXT:    ldr s0, [sp, #12] // 4-byte Folded Reload
+; CHECK-NEXT:    bl callee_farg
+; CHECK-NEXT:    smstart sm
+; CHECK-NEXT:    smstop sm
+; CHECK-NEXT:    ldr s0, [sp, #12] // 4-byte Folded Reload
+; CHECK-NEXT:    bl callee_farg
+; CHECK-NEXT:    smstart sm
+; CHECK-NEXT:    ldp d9, d8, [sp, #64] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr x30, [sp, #80] // 8-byte Folded Reload
+; CHECK-NEXT:    ldp d11, d10, [sp, #48] // 16-byte Folded Reload
+; CHECK-NEXT:    ldp d13, d12, [sp, #32] // 16-byte Folded Reload
+; CHECK-NEXT:    ldp d15, d14, [sp, #16] // 16-byte Folded Reload
+; CHECK-NEXT:    add sp, sp, #96
+; CHECK-NEXT:    ret
+  call void @callee_farg(float %f)
+  call void @callee_farg(float %f)
+  ret void
+}
+
+define float @test6(float %f) nounwind "aarch64_pstate_sm_enabled" {
+; CHECK-LABEL: test6:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    sub sp, sp, #96
+; CHECK-NEXT:    cntd x9
+; CHECK-NEXT:    stp d15, d14, [sp, #16] // 16-byte Folded Spill
+; CHECK-NEXT:    stp d13, d12, [sp, #32] // 16-byte Folded Spill
+; CHECK-NEXT:    stp d11, d10, [sp, #48] // 16-byte Folded Spill
+; CHECK-NEXT:    stp d9, d8, [sp, #64] // 16-byte Folded Spill
+; CHECK-NEXT:    stp x30, x9, [sp, #80] // 16-byte Folded Spill
+; CHECK-NEXT:    str s0, [sp, #12] // 4-byte Folded Spill
+; CHECK-NEXT:    smstop sm
+; CHECK-NEXT:    ldr s0, [sp, #12] // 4-byte Folded Reload
+; CHECK-NEXT:    bl callee_farg_fret
+; CHECK-NEXT:    str s0, [sp, #12] // 4-byte Folded Spill
+; CHECK-NEXT:    smstart sm
+; CHECK-NEXT:    smstop sm
+; CHECK-NEXT:    ldr s0, [sp, #12] // 4-byte Folded Reload
+; CHECK-NEXT:    bl callee_farg_fret
+; CHECK-NEXT:    str s0, [sp, #12] // 4-byte Folded Spill
+; CHECK-NEXT:    smstart sm
+; CHECK-NEXT:    ldr s0, [sp, #12] // 4-byte Folded Reload
+; CHECK-NEXT:    ldp d9, d8, [sp, #64] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr x30, [sp, #80] // 8-byte Folded Reload
+; CHECK-NEXT:    ldp d11, d10, [sp, #48] // 16-byte Folded Reload
+; CHECK-NEXT:    ldp d13, d12, [sp, #32] // 16-byte Folded Reload
+; CHECK-NEXT:    ldp d15, d14, [sp, #16] // 16-byte Folded Reload
+; CHECK-NEXT:    add sp, sp, #96
+; CHECK-NEXT:    ret
+  %res0 = call float @callee_farg_fret(float %f)
+  %res1 = call float @callee_farg_fret(float %res0)
+  ret float %res1
+}
+
+; the save/restore of zt0 to the stack is not yet optimised away by the pass,
+; because of the ldr/str of zt0; further analysis is needed to determine
+; whether the redundant str can be removed.
+define void @test7() nounwind "aarch64_inout_zt0" {
+; CHECK-LABEL: test7:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    sub sp, sp, #144
+; CHECK-NEXT:    stp x30, x19, [sp, #128] // 16-byte Folded Spill
+; CHECK-NEXT:    add x19, sp, #64
+; CHECK-NEXT:    str zt0, [x19]
+; CHECK-NEXT:    smstop za
+; CHECK-NEXT:    bl callee
+; CHECK-NEXT:    smstart za
+; CHECK-NEXT:    ldr zt0, [x19]
+; CHECK-NEXT:    mov x19, sp
+; CHECK-NEXT:    str zt0, [x19]
+; CHECK-NEXT:    smstop za
+; CHECK-NEXT:    bl callee
+; CHECK-NEXT:    smstart za
+; CHECK-NEXT:    ldr zt0, [x19]
+; CHECK-NEXT:    ldp x30, x19, [sp, #128] // 16-byte Folded Reload
+; CHECK-NEXT:    add sp, sp, #144
+; CHECK-NEXT:    ret
+  call void @callee()
+  call void @callee()
+  ret void
+}
+
+; test that 'smstop za' is not cancelled out against the preceding 'smstart sm'.
+define void @test8() nounwind "aarch64_pstate_sm_enabled" {
+; CHECK-LABEL: test8:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    stp d15, d14, [sp, #-80]! // 16-byte Folded Spill
+; CHECK-NEXT:    cntd x9
+; CHECK-NEXT:    stp d13, d12, [sp, #16] // 16-byte Folded Spill
+; CHECK-NEXT:    stp d11, d10, [sp, #32] // 16-byte Folded Spill
+; CHECK-NEXT:    stp d9, d8, [sp, #48] // 16-byte Folded Spill
+; CHECK-NEXT:    stp x30, x9, [sp, #64] // 16-byte Folded Spill
+; CHECK-NEXT:    smstop sm
+; CHECK-NEXT:    bl callee
+; CHECK-NEXT:    smstart sm
+; CHECK-NEXT:    smstop za
+; CHECK-NEXT:    ldp d9, d8, [sp, #48] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr x30, [sp, #64] // 8-byte Folded Reload
+; CHECK-NEXT:    ldp d11, d10, [sp, #32] // 16-byte Folded Reload
+; CHECK-NEXT:    ldp d13, d12, [sp, #16] // 16-byte Folded Reload
+; CHECK-NEXT:    ldp d15, d14, [sp], #80 // 16-byte Folded Reload
+; CHECK-NEXT:    ret
+  call void @callee()
+  call void @llvm.aarch64.sme.za.disable()
+  ret void
+}
+
+; test that the 'smstart' and 'smstop' are entirely removed,
+; along with any code to read 'vg' for the CFI.
+define void @test9() "aarch64_pstate_sm_body" {
+; CHECK-LABEL: test9:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    stp d15, d14, [sp, #-96]! // 16-byte Folded Spill
+; CHECK-NEXT:    .cfi_def_cfa_offset 96
+; CHECK-NEXT:    rdsvl x9, #1
+; CHECK-NEXT:    stp d13, d12, [sp, #16] // 16-byte Folded Spill
+; CHECK-NEXT:    lsr x9, x9, #3
+; CHECK-NEXT:    stp d11, d10, [sp, #32] // 16-byte Folded Spill
+; CHECK-NEXT:    stp d9, d8, [sp, #48] // 16-byte Folded Spill
+; CHECK-NEXT:    stp x30, x9, [sp, #64] // 16-byte Folded Spill
+; CHECK-NEXT:    cntd x9
+; CHECK-NEXT:    str x9, [sp, #80] // 8-byte Folded Spill
+; CHECK-NEXT:    .cfi_offset vg, -16
+; CHECK-NEXT:    .cfi_offset w30, -32
+; CHECK-NEXT:    .cfi_offset b8, -40
+; CHECK-NEXT:    .cfi_offset b9, -48
+; CHECK-NEXT:    .cfi_offset b10, -56
+; CHECK-NEXT:    .cfi_offset b11, -64
+; CHECK-NEXT:    .cfi_offset b12, -72
+; CHECK-NEXT:    .cfi_offset b13, -80
+; CHECK-NEXT:    .cfi_offset b14, -88
+; CHECK-NEXT:    .cfi_offset b15, -96
+; CHECK-NEXT:    smstart sm
+; CHECK-NEXT:    .cfi_offset vg, -24
+; CHECK-NEXT:    smstop sm
+; CHECK-NEXT:    bl callee
+; CHECK-NEXT:    smstart sm
+; CHECK-NEXT:    .cfi_restore vg
+; CHECK-NEXT:    smstop sm
+; CHECK-NEXT:    ldp d9, d8, [sp, #48] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr x30, [sp, #64] // 8-byte Folded Reload
+; CHECK-NEXT:    ldp d11, d10, [sp, #32] // 16-byte Folded Reload
+; CHECK-NEXT:    ldp d13, d12, [sp, #16] // 16-byte Folded Reload
+; CHECK-NEXT:    ldp d15, d14, [sp], #96 // 16-byte Folded Reload
+; CHECK-NEXT:    .cfi_def_cfa_offset 0
+; CHECK-NEXT:    .cfi_restore w30
+; CHECK-NEXT:    .cfi_restore b8
+; CHECK-NEXT:    .cfi_restore b9
+; CHECK-NEXT:    .cfi_restore b10
+; CHECK-NEXT:    .cfi_restore b11
+; CHECK-NEXT:    .cfi_restore b12
+; CHECK-NEXT:    .cfi_restore b13
+; CHECK-NEXT:    .cfi_restore b14
+; CHECK-NEXT:    .cfi_restore b15
+; CHECK-NEXT:    ret
+  call void @callee()
+  ret void
+}
+
+; similar to the above, but in this case only the first
+; 'smstart/smstop' pair can be removed and the code required
+; for the CFI is still needed.
+define void @test10() "aarch64_pstate_sm_body" {
+; CHECK-LABEL: test10:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    stp d15, d14, [sp, #-96]! // 16-byte Folded Spill
+; CHECK-NEXT:    .cfi_def_cfa_offset 96
+; CHECK-NEXT:    rdsvl x9, #1
+; CHECK-NEXT:    stp d13, d12, [sp, #16] // 16-byte Folded Spill
+; CHECK-NEXT:    lsr x9, x9, #3
+; CHECK-NEXT:    stp d11, d10, [sp, #32] // 16-byte Folded Spill
+; CHECK-NEXT:    stp d9, d8, [sp, #48] // 16-byte Folded Spill
+; CHECK-NEXT:    stp x30, x9, [sp, #64] // 16-byte Folded Spill
+; CHECK-NEXT:    cntd x9
+; CHECK-NEXT:    str x9, [sp, #80] // 8-byte Folded Spill
+; CHECK-NEXT:    .cfi_offset vg, -16
+; CHECK-NEXT:    .cfi_offset w30, -32
+; CHECK-NEXT:    .cfi_offset b8, -40
+; CHECK-NEXT:    .cfi_offset b9, -48
+; CHECK-NEXT:    .cfi_offset b10, -56
+; CHECK-NEXT:    .cfi_offset b11, -64
+; CHECK-NEXT:    .cfi_offset b12, -72
+; CHECK-NEXT:    .cfi_offset b13, -80
+; CHECK-NEXT:    .cfi_offset b14, -88
+; CHECK-NEXT:    .cfi_offset b15, -96
+; CHECK-NEXT:    smstart sm
+; CHECK-NEXT:    .cfi_offset vg, -24
+; CHECK-NEXT:    smstop sm
+; CHECK-NEXT:    bl callee
+; CHECK-NEXT:    smstart sm
+; CHECK-NEXT:    .cfi_restore vg
+; CHECK-NEXT:    bl callee
+; CHECK-NEXT:    .cfi_offset vg, -24
+; CHECK-NEXT:    smstop sm
+; CHECK-NEXT:    bl callee
+; CHECK-NEXT:    smstart sm
+; CHECK-NEXT:    .cfi_restore vg
+; CHECK-NEXT:    smstop sm
+; CHECK-NEXT:    ldp d9, d8, [sp, #48] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr x30, [sp, #64] // 8-byte Folded Reload
+; CHECK-NEXT:    ldp d11, d10, [sp, #32] // 16-byte Folded Reload
+; CHECK-NEXT:    ldp d13, d12, [sp, #16] // 16-byte Folded Reload
+; CHECK-NEXT:    ldp d15, d14, [sp], #96 // 16-byte Folded Reload
+; CHECK-NEXT:    .cfi_def_cfa_offset 0
+; CHECK-NEXT:    .cfi_restore w30
+; CHECK-NEXT:    .cfi_restore b8
+; CHECK-NEXT:    .cfi_restore b9
+; CHECK-NEXT:    .cfi_restore b10
+; CHECK-NEXT:    .cfi_restore b11
+; CHECK-NEXT:    .cfi_restore b12
+; CHECK-NEXT:    .cfi_restore b13
+; CHECK-NEXT:    .cfi_restore b14
+; CHECK-NEXT:    .cfi_restore b15
+; CHECK-NEXT:    ret
+  call void @callee()
+  call void @callee() "aarch64_pstate_sm_enabled"
+  call void @callee()
+  ret void
+}
+
+; test that an operation like a store is executed in the right
+; streaming mode and blocks the optimization.
+define void @test11(ptr %p) nounwind "aarch64_pstate_sm_enabled" {
+; CHECK-LABEL: test11:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    stp d15, d14, [sp, #-96]! // 16-byte Folded Spill
+; CHECK-NEXT:    cntd x9
+; CHECK-NEXT:    stp d13, d12, [sp, #16] // 16-byte Folded Spill
+; CHECK-NEXT:    stp d11, d10, [sp, #32] // 16-byte Folded Spill
+; CHECK-NEXT:    stp d9, d8, [sp, #48] // 16-byte Folded Spill
+; CHECK-NEXT:    stp x30, x9, [sp, #64] // 16-byte Folded Spill
+; CHECK-NEXT:    str x19, [sp, #80] // 8-byte Folded Spill
+; CHECK-NEXT:    mov x19, x0
+; CHECK-NEXT:    smstop sm
+; CHECK-NEXT:    bl callee
+; CHECK-NEXT:    smstart sm
+; CHECK-NEXT:    mov z0.b, #0 // =0x0
+; CHECK-NEXT:    ptrue p0.b
+; CHECK-NEXT:    st1b { z0.b }, p0, [x19]
+; CHECK-NEXT:    smstop sm
+; CHECK-NEXT:    bl callee
+; CHECK-NEXT:    smstart sm
+; CHECK-NEXT:    ldp d9, d8, [sp, #48] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr x19, [sp, #80] // 8-byte Folded Reload
+; CHECK-NEXT:    ldp d11, d10, [sp, #32] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr x30, [sp, #64] // 8-byte Folded Reload
+; CHECK-NEXT:    ldp d13, d12, [sp, #16] // 16-byte Folded Reload
+; CHECK-NEXT:    ldp d15, d14, [sp], #96 // 16-byte Folded Reload
+; CHECK-NEXT:    ret
+  call void @callee()
+  store <vscale x 16 x i8> zeroinitializer, ptr %p
+  call void @callee()
+  ret void
+}
+
+; test that 'smstart sm' and 'smstop za' don't get folded away together.
+; this case could be optimized further by considering streaming mode
+; separately from ZA.
+define void @test12() "aarch64_pstate_sm_body" {
+; CHECK-LABEL: test12:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    stp d15, d14, [sp, #-96]! // 16-byte Folded Spill
+; CHECK-NEXT:    .cfi_def_cfa_offset 96
+; CHECK-NEXT:    rdsvl x9, #1
+; CHECK-NEXT:    stp d13, d12, [sp, #16] // 16-byte Folded Spill
+; CHECK-NEXT:    lsr x9, x9, #3
+; CHECK-NEXT:    stp d11, d10, [sp, #32] // 16-byte Folded Spill
+; CHECK-NEXT:    stp d9, d8, [sp, #48] // 16-byte Folded Spill
+; CHECK-NEXT:    stp x30, x9, [sp, #64] // 16-byte Folded Spill
+; CHECK-NEXT:    cntd x9
+; CHECK-NEXT:    str x9, [sp, #80] // 8-byte Folded Spill
+; CHECK-NEXT:    .cfi_offset vg, -16
+; CHECK-NEXT:    .cfi_offset w30, -32
+; CHECK-NEXT:    .cfi_offset b8, -40
+; CHECK-NEXT:    .cfi_offset b9, -48
+; CHECK-NEXT:    .cfi_offset b10, -56
+; CHECK-NEXT:    .cfi_offset b11, -64
+; CHECK-NEXT:    .cfi_offset b12, -72
+; CHECK-NEXT:    .cfi_offset b13, -80
+; CHECK-NEXT:    .cfi_offset b14, -88
+; CHECK-NEXT:    .cfi_offset b15, -96
+; CHECK-NEXT:    smstart sm
+; CHECK-NEXT:    smstop za
+; CHECK-NEXT:    .cfi_offset vg, -24
+; CHECK-NEXT:    smstop sm
+; CHECK-NEXT:    bl callee
+; CHECK-NEXT:    smstart sm
+; CHECK-NEXT:    .cfi_restore vg
+; CHECK-NEXT:    smstart za
+; CHECK-NEXT:    smstop sm
+; CHECK-NEXT:    ldp d9, d8, [sp, #48] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr x30, [sp, #64] // 8-byte Folded Reload
+; CHECK-NEXT:    ldp d11, d10, [sp, #32] // 16-byte Folded Reload
+; CHECK-NEXT:    ldp d13, d12, [sp, #16] // 16-byte Folded Reload
+; CHECK-NEXT:    ldp d15, d14, [sp], #96 // 16-byte Folded Reload
+; CHECK-NEXT:    .cfi_def_cfa_offset 0
+; CHECK-NEXT:    .cfi_restore w30
+; CHECK-NEXT:    .cfi_restore b8
+; CHECK-NEXT:    .cfi_restore b9
+; CHECK-NEXT:    .cfi_restore b10
+; CHECK-NEXT:    .cfi_restore b11
+; CHECK-NEXT:    .cfi_restore b12
+; CHECK-NEXT:    .cfi_restore b13
+; CHECK-NEXT:    .cfi_restore b14
+; CHECK-NEXT:    .cfi_restore b15
+; CHECK-NEXT:    ret
+  call void @llvm.aarch64.sme.za.disable()
+  call void @callee()
+  call void @llvm.aarch64.sme.za.enable()
+  ret void
+}
+
+; We conservatively don't remove the smstart/smstop pair yet when there are COPY
+; instructions that copy SVE registers, because we can't yet conclusively prove
+; that this is safe (although for this example, it would be).
+define void @test13(ptr %ptr) nounwind "aarch64_pstate_sm_enabled" {
+; CHECK-LABEL: test13:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    stp d15, d14, [sp, #-96]! // 16-byte Folded Spill
+; CHECK-NEXT:    cntd x9
+; CHECK-NEXT:    stp d13, d12, [sp, #16] // 16-byte Folded Spill
+; CHECK-NEXT:    stp d11, d10, [sp, #32] // 16-byte Folded Spill
+; CHECK-NEXT:    stp d9, d8, [sp, #48] // 16-byte Folded Spill
+; CHECK-NEXT:    stp x29, x30, [sp, #64] // 16-byte Folded Spill
+; CHECK-NEXT:    stp x9, x19, [sp, #80] // 16-byte Folded Spill
+; CHECK-NEXT:    addvl sp, sp, #-1
+; CHECK-NEXT:    mov z0.s, #0 // =0x0
+; CHECK-NEXT:    mov x19, x0
+; CHECK-NEXT:    str z0, [sp] // 16-byte Folded Spill
+; CHECK-NEXT:    smstop sm
+; CHECK-NEXT:    ldr z0, [sp] // 16-byte Folded Reload
+; CHECK-NEXT:    bl callee_farg_fret
+; CHECK-NEXT:    str z0, [sp] // 16-byte Folded Spill
+; CHECK-NEXT:    smstart sm
+; CHECK-NEXT:    smstop sm
+; CHECK-NEXT:    ldr z0, [sp] // 16-byte Folded Reload
+; CHECK-NEXT:    bl callee_farg_fret
+; CHECK-NEXT:    str z0, [sp] // 16-byte Folded Spill
+; CHECK-NEXT:    smstart sm
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    ldr z0, [sp] // 16-byte Folded Reload
+; CHECK-NEXT:    st1w { z0.s }, p0, [x19]
+; CHECK-NEXT:    addvl sp, sp, #1
+; CHECK-NEXT:    ldp x29, x30, [sp, #64] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr x19, [sp, #88] // 8-byte Folded Reload
+; CHECK-NEXT:    ldp d9, d8, [sp, #48] // 16-byte Folded Reload
+; CHECK-NEXT:    ldp d11, d10, [sp, #32] // 16-byte Folded Reload
+; CHECK-NEXT:    ldp d13, d12, [sp, #16] // 16-byte Folded Reload
+; CHECK-NEXT:    ldp d15, d14, [sp], #96 // 16-byte Folded Reload
+; CHECK-NEXT:    ret
+  %res0 = call <vscale x 4 x float> @callee_farg_fret(<vscale x 4 x float> zeroinitializer)
+  %res1 = call <vscale x 4 x float> @callee_farg_fret(<vscale x 4 x float> %res0)
+  store <vscale x 4 x float> %res1, ptr %ptr
+  ret void
+}

From 84c917e981e6387b613dadef8b17dc2c061988be Mon Sep 17 00:00:00 2001
From: Sander de Smalen <sander.desmalen at arm.com>
Date: Fri, 16 Aug 2024 14:22:15 +0100
Subject: [PATCH 2/4] [AArch64] Add SME peephole optimizer pass

This pass removes back-to-back smstart/smstop instructions
to reduce the number of streaming mode changes in a function.

The implementation as proposed doesn't aim to solve all problems
yet and leaves a number of cases that can be optimized in the
future.
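
For illustration, here is a minimal source-level sketch of the kind of
code that produces the back-to-back sequences this pass removes. It
corresponds to test0 in sme-peephole-opts.ll; the ACLE '__arm_streaming'
keyword is assumed here and is not part of this patch:

    // A non-streaming caller makes two consecutive calls to a streaming
    // callee. Each call is bracketed by an smstart/smstop pair, so the
    // 'smstop sm' after the first call is immediately followed by the
    // 'smstart sm' for the second call; the peephole removes that inner
    // pair and keeps one mode change on entry and one on exit.
    void callee(void) __arm_streaming;

    void caller(void) {
      callee();   // smstart sm; bl callee
      callee();   // bl callee; smstop sm
    }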
---
 llvm/lib/Target/AArch64/AArch64.h             |   2 +
 .../Target/AArch64/AArch64TargetMachine.cpp   |   9 +
 llvm/lib/Target/AArch64/CMakeLists.txt        |   1 +
 llvm/lib/Target/AArch64/SMEPeepholeOpt.cpp    | 216 ++++++++++++++++++
 llvm/test/CodeGen/AArch64/O3-pipeline.ll      |   1 +
 .../test/CodeGen/AArch64/sme-peephole-opts.ll |  59 +----
 .../CodeGen/AArch64/sme-streaming-body.ll     |  20 +-
 .../AArch64/sme-streaming-interface.ll        |   2 -
 .../CodeGen/AArch64/sme-toggle-pstateza.ll    |   7 +-
 llvm/test/CodeGen/AArch64/sme-vg-to-stack.ll  |  12 -
 .../streaming-compatible-memory-ops.ll        |   2 -
 11 files changed, 241 insertions(+), 90 deletions(-)
 create mode 100644 llvm/lib/Target/AArch64/SMEPeepholeOpt.cpp

diff --git a/llvm/lib/Target/AArch64/AArch64.h b/llvm/lib/Target/AArch64/AArch64.h
index ff19327c692021..62fbf94e803f0c 100644
--- a/llvm/lib/Target/AArch64/AArch64.h
+++ b/llvm/lib/Target/AArch64/AArch64.h
@@ -59,6 +59,7 @@ FunctionPass *createAArch64CleanupLocalDynamicTLSPass();
 
 FunctionPass *createAArch64CollectLOHPass();
 FunctionPass *createSMEABIPass();
+FunctionPass *createSMEPeepholeOptPass();
 ModulePass *createSVEIntrinsicOptsPass();
 InstructionSelector *
 createAArch64InstructionSelector(const AArch64TargetMachine &,
@@ -110,6 +111,7 @@ void initializeFalkorHWPFFixPass(PassRegistry&);
 void initializeFalkorMarkStridedAccessesLegacyPass(PassRegistry&);
 void initializeLDTLSCleanupPass(PassRegistry&);
 void initializeSMEABIPass(PassRegistry &);
+void initializeSMEPeepholeOptPass(PassRegistry &);
 void initializeSVEIntrinsicOptsPass(PassRegistry &);
 void initializeAArch64Arm64ECCallLoweringPass(PassRegistry &);
 } // end namespace llvm
diff --git a/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp b/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp
index bcd677310d1247..bd5684a287381a 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp
@@ -167,6 +167,11 @@ static cl::opt<bool>
                            cl::desc("Enable SVE intrinsic opts"),
                            cl::init(true));
 
+static cl::opt<bool>
+    EnableSMEPeepholeOpt("enable-aarch64-sme-peephole-opt", cl::init(true),
+                         cl::Hidden,
+                         cl::desc("Perform SME peephole optimization"));
+
 static cl::opt<bool> EnableFalkorHWPFFix("aarch64-enable-falkor-hwpf-fix",
                                          cl::init(true), cl::Hidden);
 
@@ -256,6 +261,7 @@ extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeAArch64Target() {
   initializeLDTLSCleanupPass(*PR);
   initializeKCFIPass(*PR);
   initializeSMEABIPass(*PR);
+  initializeSMEPeepholeOptPass(*PR);
   initializeSVEIntrinsicOptsPass(*PR);
   initializeAArch64SpeculationHardeningPass(*PR);
   initializeAArch64SLSHardeningPass(*PR);
@@ -754,6 +760,9 @@ bool AArch64PassConfig::addGlobalInstructionSelect() {
 }
 
 void AArch64PassConfig::addMachineSSAOptimization() {
+  if (TM->getOptLevel() != CodeGenOptLevel::None && EnableSMEPeepholeOpt)
+    addPass(createSMEPeepholeOptPass());
+
   // Run default MachineSSAOptimization first.
   TargetPassConfig::addMachineSSAOptimization();
 
diff --git a/llvm/lib/Target/AArch64/CMakeLists.txt b/llvm/lib/Target/AArch64/CMakeLists.txt
index 639bc0707dff24..da13db8e68b0e6 100644
--- a/llvm/lib/Target/AArch64/CMakeLists.txt
+++ b/llvm/lib/Target/AArch64/CMakeLists.txt
@@ -87,6 +87,7 @@ add_llvm_target(AArch64CodeGen
   AArch64TargetObjectFile.cpp
   AArch64TargetTransformInfo.cpp
   SMEABIPass.cpp
+  SMEPeepholeOpt.cpp
   SVEIntrinsicOpts.cpp
   AArch64SIMDInstrOpt.cpp
 
diff --git a/llvm/lib/Target/AArch64/SMEPeepholeOpt.cpp b/llvm/lib/Target/AArch64/SMEPeepholeOpt.cpp
new file mode 100644
index 00000000000000..e6b8c6664f9fee
--- /dev/null
+++ b/llvm/lib/Target/AArch64/SMEPeepholeOpt.cpp
@@ -0,0 +1,216 @@
+//===- SMEPeepholeOpt.cpp - SME peephole optimization pass ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+// This pass tries to remove back-to-back (smstart, smstop) and
+// (smstop, smstart) sequences. The pass is conservative when it cannot
+// determine that it is safe to remove these sequences.
+//===----------------------------------------------------------------------===//
+
+#include "AArch64InstrInfo.h"
+#include "AArch64MachineFunctionInfo.h"
+#include "AArch64Subtarget.h"
+#include "Utils/AArch64SMEAttributes.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/CodeGen/MachineBasicBlock.h"
+#include "llvm/CodeGen/MachineFunctionPass.h"
+
+using namespace llvm;
+
+#define DEBUG_TYPE "aarch64-sme-peephole-opt"
+
+namespace {
+
+struct SMEPeepholeOpt : public MachineFunctionPass {
+  static char ID;
+
+  SMEPeepholeOpt() : MachineFunctionPass(ID) {
+    initializeSMEPeepholeOptPass(*PassRegistry::getPassRegistry());
+  }
+
+  bool runOnMachineFunction(MachineFunction &MF) override;
+
+  StringRef getPassName() const override {
+    return "SME Peephole Optimization pass";
+  }
+
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.setPreservesCFG();
+    MachineFunctionPass::getAnalysisUsage(AU);
+  }
+
+  bool optimizeStartStopPairs(MachineBasicBlock &MBB,
+                              bool &HasRemainingSMChange) const;
+};
+
+char SMEPeepholeOpt::ID = 0;
+
+} // end anonymous namespace
+
+static bool isConditionalStartStop(const MachineInstr *MI) {
+  return MI->getOpcode() == AArch64::MSRpstatePseudo;
+}
+
+static bool isMatchingStartStopPair(const MachineInstr *MI1,
+                                    const MachineInstr *MI2) {
+  // We only consider the same type of streaming mode change here, i.e.
+  // start/stop SM, or start/stop ZA pairs.
+  if (MI1->getOperand(0).getImm() != MI2->getOperand(0).getImm())
+    return false;
+
+  // One must be 'start', the other must be 'stop'
+  if (MI1->getOperand(1).getImm() == MI2->getOperand(1).getImm())
+    return false;
+
+  bool IsConditional = isConditionalStartStop(MI2);
+  if (isConditionalStartStop(MI1) != IsConditional)
+    return false;
+
+  if (!IsConditional)
+    return true;
+
+  // Check to make sure the conditional start/stop pairs are identical.
+  if (MI1->getOperand(2).getImm() != MI2->getOperand(2).getImm())
+    return false;
+
+  // Ensure reg masks are identical.
+  if (MI1->getOperand(4).getRegMask() != MI2->getOperand(4).getRegMask())
+    return false;
+
+  // This optimisation is unlikely to happen in practice for conditional
+  // smstart/smstop pairs as the virtual registers for pstate.sm will always
+  // be different.
+  // TODO: For this optimisation to apply to conditional smstart/smstop,
+  // this pass will need to do more work to remove redundant calls to
+  // __arm_sme_state.
+
+  // Only consider conditional start/stop pairs which read the same register
+  // holding the original value of pstate.sm, as some conditional start/stops
+  // require the state on entry to the function.
+  if (MI1->getOperand(3).isReg() && MI2->getOperand(3).isReg()) {
+    Register Reg1 = MI1->getOperand(3).getReg();
+    Register Reg2 = MI2->getOperand(3).getReg();
+    if (Reg1.isPhysical() || Reg2.isPhysical() || Reg1 != Reg2)
+      return false;
+  }
+
+  return true;
+}
+
+static bool ChangesStreamingMode(const MachineInstr *MI) {
+  assert((MI->getOpcode() == AArch64::MSRpstatesvcrImm1 ||
+          MI->getOpcode() == AArch64::MSRpstatePseudo) &&
+         "Expected MI to be a smstart/smstop instruction");
+  return MI->getOperand(0).getImm() == AArch64SVCR::SVCRSM ||
+         MI->getOperand(0).getImm() == AArch64SVCR::SVCRSMZA;
+}
+
+bool SMEPeepholeOpt::optimizeStartStopPairs(MachineBasicBlock &MBB,
+                                            bool &HasRemainingSMChange) const {
+  SmallVector<MachineInstr *, 4> ToBeRemoved;
+
+  bool Changed = false;
+  MachineInstr *Prev = nullptr;
+  HasRemainingSMChange = false;
+
+  auto Reset = [&]() {
+    if (Prev && ChangesStreamingMode(Prev))
+      HasRemainingSMChange = true;
+    Prev = nullptr;
+    ToBeRemoved.clear();
+  };
+
+  // Walk through instructions in the block trying to find pairs of smstart
+  // and smstop nodes that cancel each other out. We only permit a limited
+  // set of instructions to appear between them, otherwise we reset our
+  // tracking.
+  for (MachineInstr &MI : make_early_inc_range(MBB)) {
+    switch (MI.getOpcode()) {
+    default:
+      Reset();
+      break;
+    case AArch64::COPY: {
+      // Permit copies of 32 and 64-bit registers.
+      if (!MI.getOperand(1).isReg()) {
+        Reset();
+        break;
+      }
+      Register Reg = MI.getOperand(1).getReg();
+      if (!AArch64::GPR32RegClass.contains(Reg) &&
+          !AArch64::GPR64RegClass.contains(Reg))
+        Reset();
+      break;
+    }
+    case AArch64::ADJCALLSTACKDOWN:
+    case AArch64::ADJCALLSTACKUP:
+    case AArch64::ANDXri:
+    case AArch64::ADDXri:
+      // We permit these as they don't generate SVE/NEON instructions.
+      break;
+    case AArch64::VGRestorePseudo:
+    case AArch64::VGSavePseudo:
+      // When the smstart/smstop are removed, we should also remove
+      // the pseudos that save/restore the VG value for CFI info.
+      ToBeRemoved.push_back(&MI);
+      break;
+    case AArch64::MSRpstatesvcrImm1:
+    case AArch64::MSRpstatePseudo: {
+      if (!Prev)
+        Prev = &MI;
+      else if (isMatchingStartStopPair(Prev, &MI)) {
+        // If they match, we can remove them, and possibly any instructions
+        // that we marked for deletion in between.
+        Prev->eraseFromParent();
+        MI.eraseFromParent();
+        for (MachineInstr *TBR : ToBeRemoved)
+          TBR->eraseFromParent();
+        ToBeRemoved.clear();
+        Prev = nullptr;
+        Changed = true;
+      } else {
+        Reset();
+        Prev = &MI;
+      }
+      break;
+    }
+    }
+  }
+
+  return Changed;
+}
+
+INITIALIZE_PASS(SMEPeepholeOpt, "aarch64-sme-peephole-opt",
+                "SME Peephole Optimization", false, false)
+
+bool SMEPeepholeOpt::runOnMachineFunction(MachineFunction &MF) {
+  if (skipFunction(MF.getFunction()))
+    return false;
+
+  if (!MF.getSubtarget<AArch64Subtarget>().hasSME())
+    return false;
+
+  assert(MF.getRegInfo().isSSA() && "Expected to be run on SSA form!");
+
+  bool Changed = false;
+  bool FunctionHasRemainingSMChange = false;
+
+  // Even if the block lives in a function with no SME attributes attached we
+  // still have to analyze all the blocks because we may call a streaming
+  // function that requires smstart/smstop pairs.
+  for (MachineBasicBlock &MBB : MF) {
+    bool BlockHasRemainingSMChange;
+    Changed |= optimizeStartStopPairs(MBB, BlockHasRemainingSMChange);
+    FunctionHasRemainingSMChange |= BlockHasRemainingSMChange;
+  }
+
+  AArch64FunctionInfo *AFI = MF.getInfo<AArch64FunctionInfo>();
+  if (Changed && AFI->hasStreamingModeChanges())
+    AFI->setHasStreamingModeChanges(FunctionHasRemainingSMChange);
+
+  return Changed;
+}
+
+FunctionPass *llvm::createSMEPeepholeOptPass() { return new SMEPeepholeOpt(); }
diff --git a/llvm/test/CodeGen/AArch64/O3-pipeline.ll b/llvm/test/CodeGen/AArch64/O3-pipeline.ll
index 72a888bde5ebbc..3465b717261cf5 100644
--- a/llvm/test/CodeGen/AArch64/O3-pipeline.ll
+++ b/llvm/test/CodeGen/AArch64/O3-pipeline.ll
@@ -122,6 +122,7 @@
 ; CHECK-NEXT:       MachineDominator Tree Construction
 ; CHECK-NEXT:       AArch64 Local Dynamic TLS Access Clean-up
 ; CHECK-NEXT:       Finalize ISel and expand pseudo-instructions
+; CHECK-NEXT:       SME Peephole Optimization pass
 ; CHECK-NEXT:       Lazy Machine Block Frequency Analysis
 ; CHECK-NEXT:       Early Tail Duplication
 ; CHECK-NEXT:       Optimize machine instruction PHIs
diff --git a/llvm/test/CodeGen/AArch64/sme-peephole-opts.ll b/llvm/test/CodeGen/AArch64/sme-peephole-opts.ll
index 66efb71596171a..275327e54dee86 100644
--- a/llvm/test/CodeGen/AArch64/sme-peephole-opts.ll
+++ b/llvm/test/CodeGen/AArch64/sme-peephole-opts.ll
@@ -17,8 +17,6 @@ define void @test0() nounwind {
 ; CHECK-NEXT:    stp x30, x9, [sp, #64] // 16-byte Folded Spill
 ; CHECK-NEXT:    smstart sm
 ; CHECK-NEXT:    bl callee
-; CHECK-NEXT:    smstop sm
-; CHECK-NEXT:    smstart sm
 ; CHECK-NEXT:    bl callee
 ; CHECK-NEXT:    smstop sm
 ; CHECK-NEXT:    ldp d9, d8, [sp, #48] // 16-byte Folded Reload
@@ -44,8 +42,6 @@ define void @test1() nounwind "aarch64_pstate_sm_enabled" {
 ; CHECK-NEXT:    stp x30, x9, [sp, #64] // 16-byte Folded Spill
 ; CHECK-NEXT:    smstop sm
 ; CHECK-NEXT:    bl callee
-; CHECK-NEXT:    smstart sm
-; CHECK-NEXT:    smstop sm
 ; CHECK-NEXT:    bl callee
 ; CHECK-NEXT:    smstart sm
 ; CHECK-NEXT:    ldp d9, d8, [sp, #48] // 16-byte Folded Reload
@@ -175,8 +171,6 @@ define void @test4() nounwind "aarch64_pstate_sm_enabled" {
 ; CHECK-NEXT:    smstop sm
 ; CHECK-NEXT:    fmov s0, wzr
 ; CHECK-NEXT:    bl callee_farg
-; CHECK-NEXT:    smstart sm
-; CHECK-NEXT:    smstop sm
 ; CHECK-NEXT:    fmov s0, wzr
 ; CHECK-NEXT:    bl callee_farg
 ; CHECK-NEXT:    smstart sm
@@ -206,8 +200,6 @@ define void @test5(float %f) nounwind "aarch64_pstate_sm_enabled" {
 ; CHECK-NEXT:    smstop sm
 ; CHECK-NEXT:    ldr s0, [sp, #12] // 4-byte Folded Reload
 ; CHECK-NEXT:    bl callee_farg
-; CHECK-NEXT:    smstart sm
-; CHECK-NEXT:    smstop sm
 ; CHECK-NEXT:    ldr s0, [sp, #12] // 4-byte Folded Reload
 ; CHECK-NEXT:    bl callee_farg
 ; CHECK-NEXT:    smstart sm
@@ -315,48 +307,11 @@ define void @test8() nounwind "aarch64_pstate_sm_enabled" {
 define void @test9() "aarch64_pstate_sm_body" {
 ; CHECK-LABEL: test9:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    stp d15, d14, [sp, #-96]! // 16-byte Folded Spill
-; CHECK-NEXT:    .cfi_def_cfa_offset 96
-; CHECK-NEXT:    rdsvl x9, #1
-; CHECK-NEXT:    stp d13, d12, [sp, #16] // 16-byte Folded Spill
-; CHECK-NEXT:    lsr x9, x9, #3
-; CHECK-NEXT:    stp d11, d10, [sp, #32] // 16-byte Folded Spill
-; CHECK-NEXT:    stp d9, d8, [sp, #48] // 16-byte Folded Spill
-; CHECK-NEXT:    stp x30, x9, [sp, #64] // 16-byte Folded Spill
-; CHECK-NEXT:    cntd x9
-; CHECK-NEXT:    str x9, [sp, #80] // 8-byte Folded Spill
-; CHECK-NEXT:    .cfi_offset vg, -16
-; CHECK-NEXT:    .cfi_offset w30, -32
-; CHECK-NEXT:    .cfi_offset b8, -40
-; CHECK-NEXT:    .cfi_offset b9, -48
-; CHECK-NEXT:    .cfi_offset b10, -56
-; CHECK-NEXT:    .cfi_offset b11, -64
-; CHECK-NEXT:    .cfi_offset b12, -72
-; CHECK-NEXT:    .cfi_offset b13, -80
-; CHECK-NEXT:    .cfi_offset b14, -88
-; CHECK-NEXT:    .cfi_offset b15, -96
-; CHECK-NEXT:    smstart sm
-; CHECK-NEXT:    .cfi_offset vg, -24
-; CHECK-NEXT:    smstop sm
+; CHECK-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEXT:    .cfi_def_cfa_offset 16
+; CHECK-NEXT:    .cfi_offset w30, -16
 ; CHECK-NEXT:    bl callee
-; CHECK-NEXT:    smstart sm
-; CHECK-NEXT:    .cfi_restore vg
-; CHECK-NEXT:    smstop sm
-; CHECK-NEXT:    ldp d9, d8, [sp, #48] // 16-byte Folded Reload
-; CHECK-NEXT:    ldr x30, [sp, #64] // 8-byte Folded Reload
-; CHECK-NEXT:    ldp d11, d10, [sp, #32] // 16-byte Folded Reload
-; CHECK-NEXT:    ldp d13, d12, [sp, #16] // 16-byte Folded Reload
-; CHECK-NEXT:    ldp d15, d14, [sp], #96 // 16-byte Folded Reload
-; CHECK-NEXT:    .cfi_def_cfa_offset 0
-; CHECK-NEXT:    .cfi_restore w30
-; CHECK-NEXT:    .cfi_restore b8
-; CHECK-NEXT:    .cfi_restore b9
-; CHECK-NEXT:    .cfi_restore b10
-; CHECK-NEXT:    .cfi_restore b11
-; CHECK-NEXT:    .cfi_restore b12
-; CHECK-NEXT:    .cfi_restore b13
-; CHECK-NEXT:    .cfi_restore b14
-; CHECK-NEXT:    .cfi_restore b15
+; CHECK-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
 ; CHECK-NEXT:    ret
   call void @callee()
   ret void
@@ -388,9 +343,6 @@ define void @test10() "aarch64_pstate_sm_body" {
 ; CHECK-NEXT:    .cfi_offset b13, -80
 ; CHECK-NEXT:    .cfi_offset b14, -88
 ; CHECK-NEXT:    .cfi_offset b15, -96
-; CHECK-NEXT:    smstart sm
-; CHECK-NEXT:    .cfi_offset vg, -24
-; CHECK-NEXT:    smstop sm
 ; CHECK-NEXT:    bl callee
 ; CHECK-NEXT:    smstart sm
 ; CHECK-NEXT:    .cfi_restore vg
@@ -398,9 +350,6 @@ define void @test10() "aarch64_pstate_sm_body" {
 ; CHECK-NEXT:    .cfi_offset vg, -24
 ; CHECK-NEXT:    smstop sm
 ; CHECK-NEXT:    bl callee
-; CHECK-NEXT:    smstart sm
-; CHECK-NEXT:    .cfi_restore vg
-; CHECK-NEXT:    smstop sm
 ; CHECK-NEXT:    ldp d9, d8, [sp, #48] // 16-byte Folded Reload
 ; CHECK-NEXT:    ldr x30, [sp, #64] // 8-byte Folded Reload
 ; CHECK-NEXT:    ldp d11, d10, [sp, #32] // 16-byte Folded Reload
diff --git a/llvm/test/CodeGen/AArch64/sme-streaming-body.ll b/llvm/test/CodeGen/AArch64/sme-streaming-body.ll
index 3afd571ffba28e..673dca5e5fa37b 100644
--- a/llvm/test/CodeGen/AArch64/sme-streaming-body.ll
+++ b/llvm/test/CodeGen/AArch64/sme-streaming-body.ll
@@ -136,25 +136,9 @@ define <2 x i64> @locally_streaming_caller_no_callee(<2 x i64> %a) "aarch64_psta
 define void @locally_streaming_caller_locally_streaming_callee() "aarch64_pstate_sm_body" nounwind {
 ; CHECK-LABEL: locally_streaming_caller_locally_streaming_callee:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    stp d15, d14, [sp, #-96]! // 16-byte Folded Spill
-; CHECK-NEXT:    rdsvl x9, #1
-; CHECK-NEXT:    stp d13, d12, [sp, #16] // 16-byte Folded Spill
-; CHECK-NEXT:    lsr x9, x9, #3
-; CHECK-NEXT:    stp d11, d10, [sp, #32] // 16-byte Folded Spill
-; CHECK-NEXT:    stp d9, d8, [sp, #48] // 16-byte Folded Spill
-; CHECK-NEXT:    stp x30, x9, [sp, #64] // 16-byte Folded Spill
-; CHECK-NEXT:    cntd x9
-; CHECK-NEXT:    str x9, [sp, #80] // 8-byte Folded Spill
-; CHECK-NEXT:    smstart sm
-; CHECK-NEXT:    smstop sm
+; CHECK-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
 ; CHECK-NEXT:    bl locally_streaming_caller_streaming_callee
-; CHECK-NEXT:    smstart sm
-; CHECK-NEXT:    smstop sm
-; CHECK-NEXT:    ldp d9, d8, [sp, #48] // 16-byte Folded Reload
-; CHECK-NEXT:    ldr x30, [sp, #64] // 8-byte Folded Reload
-; CHECK-NEXT:    ldp d11, d10, [sp, #32] // 16-byte Folded Reload
-; CHECK-NEXT:    ldp d13, d12, [sp, #16] // 16-byte Folded Reload
-; CHECK-NEXT:    ldp d15, d14, [sp], #96 // 16-byte Folded Reload
+; CHECK-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
 ; CHECK-NEXT:    ret
 
   call void @locally_streaming_caller_streaming_callee();
diff --git a/llvm/test/CodeGen/AArch64/sme-streaming-interface.ll b/llvm/test/CodeGen/AArch64/sme-streaming-interface.ll
index 4321493434230f..bd0734df9e23e6 100644
--- a/llvm/test/CodeGen/AArch64/sme-streaming-interface.ll
+++ b/llvm/test/CodeGen/AArch64/sme-streaming-interface.ll
@@ -269,8 +269,6 @@ define <vscale x 4 x i32> @smstart_clobber_sve_duplicate(<vscale x 4 x i32> %x)
 ; CHECK-NEXT:    str z0, [sp] // 16-byte Folded Spill
 ; CHECK-NEXT:    smstart sm
 ; CHECK-NEXT:    bl streaming_callee
-; CHECK-NEXT:    smstop sm
-; CHECK-NEXT:    smstart sm
 ; CHECK-NEXT:    bl streaming_callee
 ; CHECK-NEXT:    smstop sm
 ; CHECK-NEXT:    ldr z0, [sp] // 16-byte Folded Reload
diff --git a/llvm/test/CodeGen/AArch64/sme-toggle-pstateza.ll b/llvm/test/CodeGen/AArch64/sme-toggle-pstateza.ll
index 3c50ab54e561e6..cc119dae1aa4d5 100644
--- a/llvm/test/CodeGen/AArch64/sme-toggle-pstateza.ll
+++ b/llvm/test/CodeGen/AArch64/sme-toggle-pstateza.ll
@@ -1,7 +1,12 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
-; RUN: llc -mtriple=aarch64 -mattr=+sme -verify-machineinstrs < %s | FileCheck %s
+; RUN: llc -mtriple=aarch64 -mattr=+sme -enable-aarch64-sme-peephole-opt=true -verify-machineinstrs < %s | FileCheck %s --check-prefix=CHECK-OPT
+; RUN: llc -mtriple=aarch64 -mattr=+sme -enable-aarch64-sme-peephole-opt=false -verify-machineinstrs < %s | FileCheck %s --check-prefix=CHECK
 
 define void @toggle_pstate_za() {
+; CHECK-OPT-LABEL: toggle_pstate_za:
+; CHECK-OPT:       // %bb.0:
+; CHECK-OPT-NEXT:    ret
+;
 ; CHECK-LABEL: toggle_pstate_za:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    smstart za
diff --git a/llvm/test/CodeGen/AArch64/sme-vg-to-stack.ll b/llvm/test/CodeGen/AArch64/sme-vg-to-stack.ll
index 6264ce0cf4ae6d..a96f9e382ed1a8 100644
--- a/llvm/test/CodeGen/AArch64/sme-vg-to-stack.ll
+++ b/llvm/test/CodeGen/AArch64/sme-vg-to-stack.ll
@@ -669,9 +669,6 @@ define void @vg_locally_streaming_fn() #3 {
 ; CHECK-NEXT:    .cfi_offset b13, -80
 ; CHECK-NEXT:    .cfi_offset b14, -88
 ; CHECK-NEXT:    .cfi_offset b15, -96
-; CHECK-NEXT:    smstart sm
-; CHECK-NEXT:    .cfi_offset vg, -24
-; CHECK-NEXT:    smstop sm
 ; CHECK-NEXT:    bl callee
 ; CHECK-NEXT:    smstart sm
 ; CHECK-NEXT:    .cfi_restore vg
@@ -679,9 +676,6 @@ define void @vg_locally_streaming_fn() #3 {
 ; CHECK-NEXT:    .cfi_offset vg, -24
 ; CHECK-NEXT:    smstop sm
 ; CHECK-NEXT:    bl callee
-; CHECK-NEXT:    smstart sm
-; CHECK-NEXT:    .cfi_restore vg
-; CHECK-NEXT:    smstop sm
 ; CHECK-NEXT:    ldp d9, d8, [sp, #48] // 16-byte Folded Reload
 ; CHECK-NEXT:    ldr x30, [sp, #64] // 8-byte Folded Reload
 ; CHECK-NEXT:    ldp d11, d10, [sp, #32] // 16-byte Folded Reload
@@ -725,9 +719,6 @@ define void @vg_locally_streaming_fn() #3 {
 ; FP-CHECK-NEXT:    .cfi_offset b13, -80
 ; FP-CHECK-NEXT:    .cfi_offset b14, -88
 ; FP-CHECK-NEXT:    .cfi_offset b15, -96
-; FP-CHECK-NEXT:    smstart sm
-; FP-CHECK-NEXT:    .cfi_offset vg, -16
-; FP-CHECK-NEXT:    smstop sm
 ; FP-CHECK-NEXT:    bl callee
 ; FP-CHECK-NEXT:    smstart sm
 ; FP-CHECK-NEXT:    .cfi_restore vg
@@ -735,9 +726,6 @@ define void @vg_locally_streaming_fn() #3 {
 ; FP-CHECK-NEXT:    .cfi_offset vg, -16
 ; FP-CHECK-NEXT:    smstop sm
 ; FP-CHECK-NEXT:    bl callee
-; FP-CHECK-NEXT:    smstart sm
-; FP-CHECK-NEXT:    .cfi_restore vg
-; FP-CHECK-NEXT:    smstop sm
 ; FP-CHECK-NEXT:    .cfi_def_cfa wsp, 96
 ; FP-CHECK-NEXT:    ldp x29, x30, [sp, #64] // 16-byte Folded Reload
 ; FP-CHECK-NEXT:    ldp d9, d8, [sp, #48] // 16-byte Folded Reload
diff --git a/llvm/test/CodeGen/AArch64/streaming-compatible-memory-ops.ll b/llvm/test/CodeGen/AArch64/streaming-compatible-memory-ops.ll
index 106d6190e88b9c..20faeb23eed59d 100644
--- a/llvm/test/CodeGen/AArch64/streaming-compatible-memory-ops.ll
+++ b/llvm/test/CodeGen/AArch64/streaming-compatible-memory-ops.ll
@@ -264,8 +264,6 @@ define void @sb_memcpy(i64 noundef %n) "aarch64_pstate_sm_body" nounwind {
 ; CHECK-NO-SME-ROUTINES-NEXT:    ldr x1, [x1, :got_lo12:src]
 ; CHECK-NO-SME-ROUTINES-NEXT:    smstop sm
 ; CHECK-NO-SME-ROUTINES-NEXT:    bl memcpy
-; CHECK-NO-SME-ROUTINES-NEXT:    smstart sm
-; CHECK-NO-SME-ROUTINES-NEXT:    smstop sm
 ; CHECK-NO-SME-ROUTINES-NEXT:    ldp d9, d8, [sp, #48] // 16-byte Folded Reload
 ; CHECK-NO-SME-ROUTINES-NEXT:    ldr x30, [sp, #64] // 8-byte Folded Reload
 ; CHECK-NO-SME-ROUTINES-NEXT:    ldp d11, d10, [sp, #32] // 16-byte Folded Reload

From a16b9e5842715e34d9800dfe63a1eafd45ed82f2 Mon Sep 17 00:00:00 2001
From: Sander de Smalen <sander.desmalen at arm.com>
Date: Mon, 19 Aug 2024 16:00:29 +0100
Subject: [PATCH 3/4] Relax restrictions for COPY and refactor

The restrictions for the COPY nodes were too strict. I couldn't
come up with any tests or cases where a COPY on its own would
result in any issues.

Also refactored the code a bit so that we don't need to do any
analysis on COPY nodes when the algorithm isn't trying to match a
'smstart/smstop' sequence.
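
As a sketch of the effect (ACLE '__arm_streaming' syntax assumed; this
mirrors test6 in sme-peephole-opts.ll): the float returned by the first
call feeds the second call, which at the MIR level goes through FPR
COPYs and coalescer-barrier pseudos between the two calls. These
previously reset the matching; with the relaxed handling, the
'smstart sm'/'smstop sm' between the calls can now be folded away. A
COPY involving an SVE (ZPR/PPR) register operand still resets the
matching (see test13).

    float callee_farg_fret(float);

    float caller(float f) __arm_streaming {
      // smstop sm ... bl callee_farg_fret
      float tmp = callee_farg_fret(f);
      // bl callee_farg_fret ... smstart sm (no pair left between the calls)
      return callee_farg_fret(tmp);
    }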
---
 llvm/lib/Target/AArch64/SMEPeepholeOpt.cpp    | 93 +++++++++++++------
 .../test/CodeGen/AArch64/sme-peephole-opts.ll |  4 -
 .../CodeGen/AArch64/sme-streaming-body.ll     | 26 +-----
 3 files changed, 68 insertions(+), 55 deletions(-)

diff --git a/llvm/lib/Target/AArch64/SMEPeepholeOpt.cpp b/llvm/lib/Target/AArch64/SMEPeepholeOpt.cpp
index e6b8c6664f9fee..0a5ab8a3cc5979 100644
--- a/llvm/lib/Target/AArch64/SMEPeepholeOpt.cpp
+++ b/llvm/lib/Target/AArch64/SMEPeepholeOpt.cpp
@@ -17,6 +17,8 @@
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/CodeGen/MachineBasicBlock.h"
 #include "llvm/CodeGen/MachineFunctionPass.h"
+#include "llvm/CodeGen/MachineRegisterInfo.h"
+#include "llvm/CodeGen/TargetRegisterInfo.h"
 
 using namespace llvm;
 
@@ -108,8 +110,30 @@ static bool ChangesStreamingMode(const MachineInstr *MI) {
          MI->getOperand(0).getImm() == AArch64SVCR::SVCRSMZA;
 }
 
+static bool isSVERegOp(const TargetRegisterInfo &TRI,
+                       const MachineRegisterInfo &MRI,
+                       const MachineOperand &MO) {
+  if (!MO.isReg())
+    return false;
+
+  Register R = MO.getReg();
+  if (R.isPhysical())
+    return llvm::any_of(TRI.subregs_inclusive(R), [](const MCPhysReg &SR) {
+      return AArch64::ZPRRegClass.contains(SR) ||
+             AArch64::PPRRegClass.contains(SR);
+    });
+
+  const TargetRegisterClass *RC = MRI.getRegClass(R);
+  return TRI.getCommonSubClass(&AArch64::ZPRRegClass, RC) ||
+         TRI.getCommonSubClass(&AArch64::PPRRegClass, RC);
+}
+
 bool SMEPeepholeOpt::optimizeStartStopPairs(MachineBasicBlock &MBB,
                                             bool &HasRemainingSMChange) const {
+  const MachineRegisterInfo &MRI = MBB.getParent()->getRegInfo();
+  const TargetRegisterInfo &TRI =
+      *MBB.getParent()->getSubtarget().getRegisterInfo();
+
   SmallVector<MachineInstr *, 4> ToBeRemoved;
 
   bool Changed = false;
@@ -129,33 +153,6 @@ bool SMEPeepholeOpt::optimizeStartStopPairs(MachineBasicBlock &MBB,
   // tracking.
   for (MachineInstr &MI : make_early_inc_range(MBB)) {
     switch (MI.getOpcode()) {
-    default:
-      Reset();
-      break;
-    case AArch64::COPY: {
-      // Permit copies of 32 and 64-bit registers.
-      if (!MI.getOperand(1).isReg()) {
-        Reset();
-        break;
-      }
-      Register Reg = MI.getOperand(1).getReg();
-      if (!AArch64::GPR32RegClass.contains(Reg) &&
-          !AArch64::GPR64RegClass.contains(Reg))
-        Reset();
-      break;
-    }
-    case AArch64::ADJCALLSTACKDOWN:
-    case AArch64::ADJCALLSTACKUP:
-    case AArch64::ANDXri:
-    case AArch64::ADDXri:
-      // We permit these as they don't generate SVE/NEON instructions.
-      break;
-    case AArch64::VGRestorePseudo:
-    case AArch64::VGSavePseudo:
-      // When the smstart/smstop are removed, we should also remove
-      // the pseudos that save/restore the VG value for CFI info.
-      ToBeRemoved.push_back(&MI);
-      break;
     case AArch64::MSRpstatesvcrImm1:
     case AArch64::MSRpstatePseudo: {
       if (!Prev)
@@ -174,8 +171,50 @@ bool SMEPeepholeOpt::optimizeStartStopPairs(MachineBasicBlock &MBB,
         Reset();
         Prev = &MI;
       }
+      continue;
+    }
+    default:
+      if (!Prev)
+        // Avoid doing expensive checks when Prev is nullptr.
+        continue;
       break;
     }
+
+    // Test if the instructions in between the start/stop sequence are agnostic
+    // of streaming mode. If not, the algorithm should reset.
+    switch (MI.getOpcode()) {
+    default:
+      Reset();
+      break;
+    case AArch64::COALESCER_BARRIER_FPR16:
+    case AArch64::COALESCER_BARRIER_FPR32:
+    case AArch64::COALESCER_BARRIER_FPR64:
+    case AArch64::COALESCER_BARRIER_FPR128:
+    case AArch64::COPY:
+      // These instructions should be safe when executed on their own, but
+      // the code remains conservative when SVE registers are used. There may
+      // exist subtle cases where executing a COPY in a different mode results
+      // in different behaviour, even if we can't yet come up with any
+      // concrete example/test-case.
+      if (isSVERegOp(TRI, MRI, MI.getOperand(0)) ||
+          isSVERegOp(TRI, MRI, MI.getOperand(1)))
+        Reset();
+      break;
+    case AArch64::ADJCALLSTACKDOWN:
+    case AArch64::ADJCALLSTACKUP:
+    case AArch64::ANDXri:
+    case AArch64::ADDXri:
+      // We permit these as they don't generate SVE/NEON instructions.
+      break;
+    case AArch64::VGRestorePseudo:
+    case AArch64::VGSavePseudo:
+      // When the smstart/smstop are removed, we should also remove
+      // the pseudos that save/restore the VG value for CFI info.
+      ToBeRemoved.push_back(&MI);
+      break;
+    case AArch64::MSRpstatesvcrImm1:
+    case AArch64::MSRpstatePseudo:
+      llvm_unreachable("Should have been handled");
     }
   }
 
diff --git a/llvm/test/CodeGen/AArch64/sme-peephole-opts.ll b/llvm/test/CodeGen/AArch64/sme-peephole-opts.ll
index 275327e54dee86..cb8a825a201ad6 100644
--- a/llvm/test/CodeGen/AArch64/sme-peephole-opts.ll
+++ b/llvm/test/CodeGen/AArch64/sme-peephole-opts.ll
@@ -229,10 +229,6 @@ define float @test6(float %f) nounwind "aarch64_pstate_sm_enabled" {
 ; CHECK-NEXT:    smstop sm
 ; CHECK-NEXT:    ldr s0, [sp, #12] // 4-byte Folded Reload
 ; CHECK-NEXT:    bl callee_farg_fret
-; CHECK-NEXT:    str s0, [sp, #12] // 4-byte Folded Spill
-; CHECK-NEXT:    smstart sm
-; CHECK-NEXT:    smstop sm
-; CHECK-NEXT:    ldr s0, [sp, #12] // 4-byte Folded Reload
 ; CHECK-NEXT:    bl callee_farg_fret
 ; CHECK-NEXT:    str s0, [sp, #12] // 4-byte Folded Spill
 ; CHECK-NEXT:    smstart sm
diff --git a/llvm/test/CodeGen/AArch64/sme-streaming-body.ll b/llvm/test/CodeGen/AArch64/sme-streaming-body.ll
index 673dca5e5fa37b..572b1fff3520a9 100644
--- a/llvm/test/CodeGen/AArch64/sme-streaming-body.ll
+++ b/llvm/test/CodeGen/AArch64/sme-streaming-body.ll
@@ -256,31 +256,9 @@ declare void @use_ptr(ptr) "aarch64_pstate_sm_compatible"
 define double @call_to_intrinsic_without_chain(double %x) nounwind "aarch64_pstate_sm_body" {
 ; CHECK-LABEL: call_to_intrinsic_without_chain:
 ; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    sub sp, sp, #112
-; CHECK-NEXT:    rdsvl x9, #1
-; CHECK-NEXT:    stp d15, d14, [sp, #16] // 16-byte Folded Spill
-; CHECK-NEXT:    lsr x9, x9, #3
-; CHECK-NEXT:    stp d13, d12, [sp, #32] // 16-byte Folded Spill
-; CHECK-NEXT:    stp d11, d10, [sp, #48] // 16-byte Folded Spill
-; CHECK-NEXT:    stp x30, x9, [sp, #80] // 16-byte Folded Spill
-; CHECK-NEXT:    cntd x9
-; CHECK-NEXT:    stp d9, d8, [sp, #64] // 16-byte Folded Spill
-; CHECK-NEXT:    str x9, [sp, #96] // 8-byte Folded Spill
-; CHECK-NEXT:    str d0, [sp, #8] // 8-byte Folded Spill
-; CHECK-NEXT:    smstart sm
-; CHECK-NEXT:    smstop sm
-; CHECK-NEXT:    ldr d0, [sp, #8] // 8-byte Folded Reload
+; CHECK-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
 ; CHECK-NEXT:    bl cos
-; CHECK-NEXT:    str d0, [sp, #8] // 8-byte Folded Spill
-; CHECK-NEXT:    smstart sm
-; CHECK-NEXT:    smstop sm
-; CHECK-NEXT:    ldp d9, d8, [sp, #64] // 16-byte Folded Reload
-; CHECK-NEXT:    ldr d0, [sp, #8] // 8-byte Folded Reload
-; CHECK-NEXT:    ldp d11, d10, [sp, #48] // 16-byte Folded Reload
-; CHECK-NEXT:    ldr x30, [sp, #80] // 8-byte Folded Reload
-; CHECK-NEXT:    ldp d13, d12, [sp, #32] // 16-byte Folded Reload
-; CHECK-NEXT:    ldp d15, d14, [sp, #16] // 16-byte Folded Reload
-; CHECK-NEXT:    add sp, sp, #112
+; CHECK-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
 ; CHECK-NEXT:    ret
 entry:
   %0 = call fast double @llvm.cos.f64(double %x)

From 55f389e2333ecc14ab46df88594b25a21daa13d8 Mon Sep 17 00:00:00 2001
From: Sander de Smalen <sander.desmalen at arm.com>
Date: Tue, 20 Aug 2024 08:56:02 +0100
Subject: [PATCH 4/4] Fix issue where HasStreamingModeChanges may be set to
 'false' incorrectly

---
 llvm/lib/Target/AArch64/SMEPeepholeOpt.cpp | 33 +++++++++++++---------
 1 file changed, 19 insertions(+), 14 deletions(-)

diff --git a/llvm/lib/Target/AArch64/SMEPeepholeOpt.cpp b/llvm/lib/Target/AArch64/SMEPeepholeOpt.cpp
index 0a5ab8a3cc5979..ba737afadaf943 100644
--- a/llvm/lib/Target/AArch64/SMEPeepholeOpt.cpp
+++ b/llvm/lib/Target/AArch64/SMEPeepholeOpt.cpp
@@ -45,7 +45,7 @@ struct SMEPeepholeOpt : public MachineFunctionPass {
   }
 
   bool optimizeStartStopPairs(MachineBasicBlock &MBB,
-                              bool &HasRemainingSMChange) const;
+                              bool &HasRemovedAllSMChanges) const;
 };
 
 char SMEPeepholeOpt::ID = 0;
@@ -128,21 +128,18 @@ static bool isSVERegOp(const TargetRegisterInfo &TRI,
          TRI.getCommonSubClass(&AArch64::PPRRegClass, RC);
 }
 
-bool SMEPeepholeOpt::optimizeStartStopPairs(MachineBasicBlock &MBB,
-                                            bool &HasRemainingSMChange) const {
+bool SMEPeepholeOpt::optimizeStartStopPairs(
+    MachineBasicBlock &MBB, bool &HasRemovedAllSMChanges) const {
   const MachineRegisterInfo &MRI = MBB.getParent()->getRegInfo();
   const TargetRegisterInfo &TRI =
       *MBB.getParent()->getSubtarget().getRegisterInfo();
 
-  SmallVector<MachineInstr *, 4> ToBeRemoved;
-
   bool Changed = false;
   MachineInstr *Prev = nullptr;
-  HasRemainingSMChange = false;
+  SmallVector<MachineInstr *, 4> ToBeRemoved;
 
+  // Convenience function to reset the matching of a sequence.
   auto Reset = [&]() {
-    if (Prev && ChangesStreamingMode(Prev))
-      HasRemainingSMChange = true;
     Prev = nullptr;
     ToBeRemoved.clear();
   };
@@ -151,10 +148,15 @@ bool SMEPeepholeOpt::optimizeStartStopPairs(MachineBasicBlock &MBB,
   // and smstop nodes that cancel each other out. We only permit a limited
   // set of instructions to appear between them, otherwise we reset our
   // tracking.
+  unsigned NumSMChanges = 0;
+  unsigned NumSMChangesRemoved = 0;
   for (MachineInstr &MI : make_early_inc_range(MBB)) {
     switch (MI.getOpcode()) {
     case AArch64::MSRpstatesvcrImm1:
     case AArch64::MSRpstatePseudo: {
+      if (ChangesStreamingMode(&MI))
+        NumSMChanges++;
+
       if (!Prev)
         Prev = &MI;
       else if (isMatchingStartStopPair(Prev, &MI)) {
@@ -167,6 +169,7 @@ bool SMEPeepholeOpt::optimizeStartStopPairs(MachineBasicBlock &MBB,
         ToBeRemoved.clear();
         Prev = nullptr;
         Changed = true;
+        NumSMChangesRemoved += 2;
       } else {
         Reset();
         Prev = &MI;
@@ -218,6 +221,8 @@ bool SMEPeepholeOpt::optimizeStartStopPairs(MachineBasicBlock &MBB,
     }
   }
 
+  HasRemovedAllSMChanges =
+      NumSMChanges && (NumSMChanges == NumSMChangesRemoved);
   return Changed;
 }
 
@@ -234,20 +239,20 @@ bool SMEPeepholeOpt::runOnMachineFunction(MachineFunction &MF) {
   assert(MF.getRegInfo().isSSA() && "Expected to be run on SSA form!");
 
   bool Changed = false;
-  bool FunctionHasRemainingSMChange = false;
+  bool FunctionHasAllSMChangesRemoved = false;
 
   // Even if the block lives in a function with no SME attributes attached we
   // still have to analyze all the blocks because we may call a streaming
   // function that requires smstart/smstop pairs.
   for (MachineBasicBlock &MBB : MF) {
-    bool BlockHasRemainingSMChange;
-    Changed |= optimizeStartStopPairs(MBB, BlockHasRemainingSMChange);
-    FunctionHasRemainingSMChange |= BlockHasRemainingSMChange;
+    bool BlockHasAllSMChangesRemoved;
+    Changed |= optimizeStartStopPairs(MBB, BlockHasAllSMChangesRemoved);
+    FunctionHasAllSMChangesRemoved |= BlockHasAllSMChangesRemoved;
   }
 
   AArch64FunctionInfo *AFI = MF.getInfo<AArch64FunctionInfo>();
-  if (Changed && AFI->hasStreamingModeChanges())
-    AFI->setHasStreamingModeChanges(FunctionHasRemainingSMChange);
+  if (FunctionHasAllSMChangesRemoved)
+    AFI->setHasStreamingModeChanges(false);
 
   return Changed;
 }


