[llvm] [AArch64][SME] Remove implicit-def's on smstart (PR #69012)

Jon Roelofs via llvm-commits llvm-commits at lists.llvm.org
Mon Nov 6 13:49:36 PST 2023


https://github.com/jroelofs updated https://github.com/llvm/llvm-project/pull/69012

>From fd79f502533bffa7c9d2790008e931e93cb345ce Mon Sep 17 00:00:00 2001
From: Jon Roelofs <jonathan_roelofs at apple.com>
Date: Fri, 13 Oct 2023 09:47:22 -0700
Subject: [PATCH] [AArch64][SME] Remove implicit-def's on smstart

When we lower calls, the sequence of argument copy-to-reg nodes is glued to
the smstart.  In the InstrEmitter, these glued copies are turned into implicit
defs on the smstart, since the actual call instruction uses those physregs.
This results in the register allocator adding unnecessary copies of regs that
are preserved anyway.
---
 .../Target/AArch64/AArch64ISelLowering.cpp    | 16 ++++++++
 llvm/lib/Target/AArch64/AArch64ISelLowering.h |  3 ++
 .../lib/Target/AArch64/AArch64SMEInstrInfo.td |  4 +-
 llvm/lib/Target/AArch64/SMEInstrFormats.td    |  1 +
 .../AArch64/sme-streaming-interface.ll        | 41 +++++++++++++++----
 5 files changed, 56 insertions(+), 9 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index f5193a9f2adf30c..7162a6d366d6f67 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -7283,6 +7283,22 @@ static bool checkZExtBool(SDValue Arg, const SelectionDAG &DAG) {
   return ZExtBool;
 }
 
+void AArch64TargetLowering::AdjustInstrPostInstrSelection(MachineInstr &MI,
+                                                          SDNode *Node) const {
+  // Live-in physreg copies glued to SMSTART become implicit-defs in the
+  // InstrEmitter. Remove the GPR ones: those regs are actually preserved, so
+  // the fake clobbers only force extra arg copies in the register allocator.
+  // Iterate in reverse so removeOperand() never shifts an unvisited index.
+  if (MI.getOpcode() == AArch64::MSRpstatesvcrImm1 ||
+      MI.getOpcode() == AArch64::MSRpstatePseudo)
+    for (unsigned I = MI.getNumOperands() - 1; I > 0; --I)
+      if (MachineOperand &MO = MI.getOperand(I);
+          MO.isReg() && MO.isImplicit() && MO.isDef() &&
+          (AArch64::GPR32RegClass.contains(MO.getReg()) ||
+           AArch64::GPR64RegClass.contains(MO.getReg())))
+        MI.removeOperand(I);
+}
+
 SDValue AArch64TargetLowering::changeStreamingMode(
     SelectionDAG &DAG, SDLoc DL, bool Enable,
     SDValue Chain, SDValue InGlue, SDValue PStateSM, bool Entry) const {
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 7332a95615a4da5..de31c1bb41baa3b 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -967,6 +967,9 @@ class AArch64TargetLowering : public TargetLowering {
                                const SDLoc &DL, SelectionDAG &DAG,
                                SmallVectorImpl<SDValue> &InVals) const override;
 
+  void AdjustInstrPostInstrSelection(MachineInstr &MI,
+                                     SDNode *Node) const override;
+
   SDValue LowerCall(CallLoweringInfo & /*CLI*/,
                     SmallVectorImpl<SDValue> &InVals) const override;
 
diff --git a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
index bb9464a8d2e1cf2..8a76690a9add4be 100644
--- a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
@@ -203,7 +203,9 @@ def : Pat<(i64 (int_aarch64_sme_get_tpidr2)),
 def MSRpstatePseudo :
   Pseudo<(outs),
            (ins svcr_op:$pstatefield, timm0_1:$imm, GPR64:$rtpstate, timm0_1:$expected_pstate, variable_ops), []>,
-    Sched<[WriteSys]>;
+    Sched<[WriteSys]> {
+  let hasPostISelHook = 1;
+}
 
 def : Pat<(AArch64_smstart (i32 svcr_op:$pstate), (i64 GPR64:$rtpstate), (i64 timm0_1:$expected_pstate)),
           (MSRpstatePseudo svcr_op:$pstate, 0b1, GPR64:$rtpstate, timm0_1:$expected_pstate)>;
diff --git a/llvm/lib/Target/AArch64/SMEInstrFormats.td b/llvm/lib/Target/AArch64/SMEInstrFormats.td
index 4f40fa538b0c3c7..880dda87f1c6f74 100644
--- a/llvm/lib/Target/AArch64/SMEInstrFormats.td
+++ b/llvm/lib/Target/AArch64/SMEInstrFormats.td
@@ -215,6 +215,7 @@ def MSRpstatesvcrImm1
   let Inst{11-9} = pstatefield;
   let Inst{8} = imm;
   let Inst{7-5} = 0b011; // op2
+  let hasPostISelHook = 1;
 }
 
 def : InstAlias<"smstart",    (MSRpstatesvcrImm1 0b011, 0b1)>;
diff --git a/llvm/test/CodeGen/AArch64/sme-streaming-interface.ll b/llvm/test/CodeGen/AArch64/sme-streaming-interface.ll
index 102ed896ce7b3e8..dd7d6470ad7b084 100644
--- a/llvm/test/CodeGen/AArch64/sme-streaming-interface.ll
+++ b/llvm/test/CodeGen/AArch64/sme-streaming-interface.ll
@@ -368,15 +368,11 @@ define i8 @call_to_non_streaming_pass_sve_objects(ptr nocapture noundef readnone
 ; CHECK-NEXT:    stp d9, d8, [sp, #48] // 16-byte Folded Spill
 ; CHECK-NEXT:    stp x29, x30, [sp, #64] // 16-byte Folded Spill
 ; CHECK-NEXT:    addvl sp, sp, #-3
-; CHECK-NEXT:    rdsvl x8, #1
-; CHECK-NEXT:    addvl x9, sp, #2
-; CHECK-NEXT:    addvl x10, sp, #1
-; CHECK-NEXT:    mov x11, sp
+; CHECK-NEXT:    rdsvl x3, #1
+; CHECK-NEXT:    addvl x0, sp, #2
+; CHECK-NEXT:    addvl x1, sp, #1
+; CHECK-NEXT:    mov x2, sp
 ; CHECK-NEXT:    smstop sm
-; CHECK-NEXT:    mov x0, x9
-; CHECK-NEXT:    mov x1, x10
-; CHECK-NEXT:    mov x2, x11
-; CHECK-NEXT:    mov x3, x8
 ; CHECK-NEXT:    bl foo
 ; CHECK-NEXT:    smstart sm
 ; CHECK-NEXT:    ptrue p0.b
@@ -400,8 +396,37 @@ entry:
   ret i8 %vecext
 }
 
+define void @call_to_non_streaming_pass_args(ptr nocapture noundef readnone %ptr, i64 %long1, i64 %long2, i32 %int1, i32 %int2, float %float1, float %float2, double %double1, double %double2) #0 {
+; CHECK-LABEL: call_to_non_streaming_pass_args:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    sub sp, sp, #112
+; CHECK-NEXT:    stp d15, d14, [sp, #32] // 16-byte Folded Spill
+; CHECK-NEXT:    stp d13, d12, [sp, #48] // 16-byte Folded Spill
+; CHECK-NEXT:    stp d11, d10, [sp, #64] // 16-byte Folded Spill
+; CHECK-NEXT:    stp d9, d8, [sp, #80] // 16-byte Folded Spill
+; CHECK-NEXT:    str x30, [sp, #96] // 8-byte Folded Spill
+; CHECK-NEXT:    stp d2, d3, [sp, #16] // 16-byte Folded Spill
+; CHECK-NEXT:    stp s0, s1, [sp, #8] // 8-byte Folded Spill
+; CHECK-NEXT:    smstop sm
+; CHECK-NEXT:    ldp s0, s1, [sp, #8] // 8-byte Folded Reload
+; CHECK-NEXT:    ldp d2, d3, [sp, #16] // 16-byte Folded Reload
+; CHECK-NEXT:    bl bar
+; CHECK-NEXT:    smstart sm
+; CHECK-NEXT:    ldp d9, d8, [sp, #80] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr x30, [sp, #96] // 8-byte Folded Reload
+; CHECK-NEXT:    ldp d11, d10, [sp, #64] // 16-byte Folded Reload
+; CHECK-NEXT:    ldp d13, d12, [sp, #48] // 16-byte Folded Reload
+; CHECK-NEXT:    ldp d15, d14, [sp, #32] // 16-byte Folded Reload
+; CHECK-NEXT:    add sp, sp, #112
+; CHECK-NEXT:    ret
+entry:
+  call void @bar(ptr noundef nonnull %ptr, i64 %long1, i64 %long2, i32 %int1, i32 %int2, float %float1, float %float2, double %double1, double %double2)
+  ret void
+}
+
 declare i64 @llvm.aarch64.sme.cntsb()
 
 declare void @foo(ptr noundef, ptr noundef, ptr noundef, i64 noundef)
+declare void @bar(ptr noundef, i64 noundef, i64 noundef, i32 noundef, i32 noundef, float noundef, float noundef, double noundef, double noundef)
 
 attributes #0 = { nounwind vscale_range(1,16) "aarch64_pstate_sm_enabled" }



More information about the llvm-commits mailing list