[llvm] [AArch64] Implement spill/fill of predicate pair register classes (PR #76068)

Momchil Velikov via llvm-commits llvm-commits at lists.llvm.org
Wed Dec 20 10:43:27 PST 2023


https://github.com/momchil-velikov updated https://github.com/llvm/llvm-project/pull/76068

>From 57faca24a873b38d3c1a46f7ae0dc7fadf2bf448 Mon Sep 17 00:00:00 2001
From: Momchil Velikov <momchil.velikov at arm.com>
Date: Wed, 20 Dec 2023 15:04:46 +0000
Subject: [PATCH 1/2] [AArch64] Implement spill/fill of predicate pair register
 classes

---
 .../AArch64/AArch64ExpandPseudoInsts.cpp      | 18 +++-
 llvm/lib/Target/AArch64/AArch64InstrInfo.cpp  | 17 ++++
 .../lib/Target/AArch64/AArch64SVEInstrInfo.td |  2 +
 llvm/test/CodeGen/AArch64/spillfill-sve.mir   | 92 +++++++++++++++++++
 .../AArch64/sve-pred-pair-spill-fill.ll       | 67 ++++++++++++++
 5 files changed, 193 insertions(+), 3 deletions(-)
 create mode 100644 llvm/test/CodeGen/AArch64/sve-pred-pair-spill-fill.ll

diff --git a/llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp b/llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp
index 757471d6a905e1..bb7f4d907ffd7f 100644
--- a/llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp
@@ -747,6 +747,15 @@ bool AArch64ExpandPseudo::expandSetTagLoop(
 bool AArch64ExpandPseudo::expandSVESpillFill(MachineBasicBlock &MBB,
                                              MachineBasicBlock::iterator MBBI,
                                              unsigned Opc, unsigned N) {
+  assert((Opc == AArch64::LDR_ZXI || Opc == AArch64::STR_ZXI ||
+          Opc == AArch64::LDR_PXI || Opc == AArch64::STR_PXI) &&
+         "Unexpected opcode");
+  unsigned RState = (Opc == AArch64::LDR_ZXI || Opc == AArch64::LDR_PXI)
+                        ? RegState::Define
+                        : 0;
+  unsigned sub0 = (Opc == AArch64::LDR_ZXI || Opc == AArch64::STR_ZXI)
+                      ? AArch64::zsub0
+                      : AArch64::psub0;
   const TargetRegisterInfo *TRI =
       MBB.getParent()->getSubtarget().getRegisterInfo();
   MachineInstr &MI = *MBBI;
@@ -756,9 +765,8 @@ bool AArch64ExpandPseudo::expandSVESpillFill(MachineBasicBlock &MBB,
     assert(ImmOffset >= -256 && ImmOffset < 256 &&
            "Immediate spill offset out of range");
     BuildMI(MBB, MBBI, MI.getDebugLoc(), TII->get(Opc))
-        .addReg(
-            TRI->getSubReg(MI.getOperand(0).getReg(), AArch64::zsub0 + Offset),
-            Opc == AArch64::LDR_ZXI ? RegState::Define : 0)
+        .addReg(TRI->getSubReg(MI.getOperand(0).getReg(), sub0 + Offset),
+                RState)
         .addReg(MI.getOperand(1).getReg(), getKillRegState(Kill))
         .addImm(ImmOffset);
   }
@@ -1492,12 +1500,16 @@ bool AArch64ExpandPseudo::expandMI(MachineBasicBlock &MBB,
      return expandSVESpillFill(MBB, MBBI, AArch64::STR_ZXI, 3);
    case AArch64::STR_ZZXI:
      return expandSVESpillFill(MBB, MBBI, AArch64::STR_ZXI, 2);
+   case AArch64::STR_PPXI:
+     return expandSVESpillFill(MBB, MBBI, AArch64::STR_PXI, 2);
    case AArch64::LDR_ZZZZXI:
      return expandSVESpillFill(MBB, MBBI, AArch64::LDR_ZXI, 4);
    case AArch64::LDR_ZZZXI:
      return expandSVESpillFill(MBB, MBBI, AArch64::LDR_ZXI, 3);
    case AArch64::LDR_ZZXI:
      return expandSVESpillFill(MBB, MBBI, AArch64::LDR_ZXI, 2);
+   case AArch64::LDR_PPXI:
+     return expandSVESpillFill(MBB, MBBI, AArch64::LDR_PXI, 2);
    case AArch64::BLR_RVMARKER:
      return expandCALL_RVMARKER(MBB, MBBI);
    case AArch64::BLR_BTI:
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
index 7d71c316bcb0a2..44a22a6f7ec0e3 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
@@ -2197,6 +2197,7 @@ unsigned AArch64InstrInfo::isLoadFromStackSlot(const MachineInstr &MI,
   case AArch64::LDRDui:
   case AArch64::LDRQui:
   case AArch64::LDR_PXI:
+  case AArch64::LDR_PPXI:
     if (MI.getOperand(0).getSubReg() == 0 && MI.getOperand(1).isFI() &&
         MI.getOperand(2).isImm() && MI.getOperand(2).getImm() == 0) {
       FrameIndex = MI.getOperand(1).getIndex();
@@ -2221,6 +2222,7 @@ unsigned AArch64InstrInfo::isStoreToStackSlot(const MachineInstr &MI,
   case AArch64::STRDui:
   case AArch64::STRQui:
   case AArch64::STR_PXI:
+  case AArch64::STR_PPXI:
     if (MI.getOperand(0).getSubReg() == 0 && MI.getOperand(1).isFI() &&
         MI.getOperand(2).isImm() && MI.getOperand(2).getImm() == 0) {
       FrameIndex = MI.getOperand(1).getIndex();
@@ -3771,6 +3773,13 @@ bool AArch64InstrInfo::getMemOpInfo(unsigned Opcode, TypeSize &Scale,
     MinOffset = -256;
     MaxOffset = 255;
     break;
+  case AArch64::LDR_PPXI:
+  case AArch64::STR_PPXI:
+    Scale = TypeSize::getScalable(2);
+    Width = TypeSize::getScalable(2 * 2);
+    MinOffset = -256;
+    MaxOffset = 254; // Second register is accessed at Offset + 1.
+    break;
   case AArch64::LDR_ZXI:
   case AArch64::STR_ZXI:
     Scale = TypeSize::getScalable(16);
@@ -4804,6 +4813,10 @@ void AArch64InstrInfo::storeRegToStackSlot(MachineBasicBlock &MBB,
         assert(SrcReg != AArch64::WSP);
     } else if (AArch64::FPR32RegClass.hasSubClassEq(RC))
       Opc = AArch64::STRSui;
+    else if (AArch64::PPR2RegClass.hasSubClassEq(RC)) {
+      Opc = AArch64::STR_PPXI;
+      StackID = TargetStackID::ScalableVector;
+    }
     break;
   case 8:
     if (AArch64::GPR64allRegClass.hasSubClassEq(RC)) {
@@ -4980,6 +4993,10 @@ void AArch64InstrInfo::loadRegFromStackSlot(MachineBasicBlock &MBB,
         assert(DestReg != AArch64::WSP);
     } else if (AArch64::FPR32RegClass.hasSubClassEq(RC))
       Opc = AArch64::LDRSui;
+    else if (AArch64::PPR2RegClass.hasSubClassEq(RC)) {
+      Opc = AArch64::LDR_PPXI;
+      StackID = TargetStackID::ScalableVector;
+    }
     break;
   case 8:
     if (AArch64::GPR64allRegClass.hasSubClassEq(RC)) {
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index f68059889d0c51..d496bf50e62d10 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -2378,11 +2378,13 @@ let Predicates = [HasSVEorSME] in {
     def LDR_ZZXI   : Pseudo<(outs   ZZ_b_strided_and_contiguous:$Zd), (ins GPR64sp:$sp, simm4s1:$offset),[]>, Sched<[]>;
     def LDR_ZZZXI  : Pseudo<(outs  ZZZ_b:$Zd), (ins GPR64sp:$sp, simm4s1:$offset),[]>, Sched<[]>;
     def LDR_ZZZZXI : Pseudo<(outs ZZZZ_b_strided_and_contiguous:$Zd), (ins GPR64sp:$sp, simm4s1:$offset),[]>, Sched<[]>;
+    def LDR_PPXI   : Pseudo<(outs PPR2:$pp), (ins GPR64sp:$sp, simm4s1:$offset),[]>, Sched<[]>;
   }
   let mayStore = 1, hasSideEffects = 0 in {
     def STR_ZZXI   : Pseudo<(outs), (ins   ZZ_b_strided_and_contiguous:$Zs, GPR64sp:$sp, simm4s1:$offset),[]>, Sched<[]>;
     def STR_ZZZXI  : Pseudo<(outs), (ins  ZZZ_b:$Zs, GPR64sp:$sp, simm4s1:$offset),[]>, Sched<[]>;
     def STR_ZZZZXI : Pseudo<(outs), (ins ZZZZ_b_strided_and_contiguous:$Zs, GPR64sp:$sp, simm4s1:$offset),[]>, Sched<[]>;
+    def STR_PPXI   : Pseudo<(outs), (ins PPR2:$pp, GPR64sp:$sp, simm4s1:$offset),[]>, Sched<[]>;
   }
 
   let AddedComplexity = 1 in {
diff --git a/llvm/test/CodeGen/AArch64/spillfill-sve.mir b/llvm/test/CodeGen/AArch64/spillfill-sve.mir
index 01756b84600192..ef7d55a1c2395f 100644
--- a/llvm/test/CodeGen/AArch64/spillfill-sve.mir
+++ b/llvm/test/CodeGen/AArch64/spillfill-sve.mir
@@ -7,6 +7,8 @@
   target triple = "aarch64--linux-gnu"
 
   define aarch64_sve_vector_pcs void @spills_fills_stack_id_ppr() #0 { entry: unreachable }
+  define aarch64_sve_vector_pcs void @spills_fills_stack_id_ppr2() #0 { entry: unreachable }
+  define aarch64_sve_vector_pcs void @spills_fills_stack_id_ppr2mul2() #0 { entry: unreachable }
   define aarch64_sve_vector_pcs void @spills_fills_stack_id_pnr() #1 { entry: unreachable }
   define aarch64_sve_vector_pcs void @spills_fills_stack_id_virtreg_pnr() #1 { entry: unreachable }
   define aarch64_sve_vector_pcs void @spills_fills_stack_id_zpr() #0 { entry: unreachable }
@@ -64,6 +66,96 @@ body:             |
     RET_ReallyLR
 ...
 ---
+name: spills_fills_stack_id_ppr2
+tracksRegLiveness: true
+registers:
+  - { id: 0, class: ppr2 }
+stack:
+liveins:
+  - { reg: '$p0_p1', virtual-reg: '%0' }
+body:             |
+  bb.0.entry:
+    liveins: $p0_p1
+
+    ; CHECK-LABEL: name: spills_fills_stack_id_ppr2
+    ; CHECK: stack:
+    ; CHECK:      - { id: 0, name: '', type: spill-slot, offset: 0, size: 4, alignment: 2
+    ; CHECK-NEXT:     stack-id: scalable-vector, callee-saved-register: ''
+
+    ; EXPAND-LABEL: name: spills_fills_stack_id_ppr2
+    ; EXPAND: STR_PXI $p0, $sp, 6
+    ; EXPAND: STR_PXI $p1, $sp, 7
+    ; EXPAND: $p0 = LDR_PXI $sp, 6
+    ; EXPAND: $p1 = LDR_PXI $sp, 7
+
+    %0:ppr2 = COPY $p0_p1
+
+    $p0 = IMPLICIT_DEF
+    $p1 = IMPLICIT_DEF
+    $p2 = IMPLICIT_DEF
+    $p3 = IMPLICIT_DEF
+    $p4 = IMPLICIT_DEF
+    $p5 = IMPLICIT_DEF
+    $p6 = IMPLICIT_DEF
+    $p7 = IMPLICIT_DEF
+    $p8 = IMPLICIT_DEF
+    $p9 = IMPLICIT_DEF
+    $p10 = IMPLICIT_DEF
+    $p11 = IMPLICIT_DEF
+    $p12 = IMPLICIT_DEF
+    $p13 = IMPLICIT_DEF
+    $p14 = IMPLICIT_DEF
+    $p15 = IMPLICIT_DEF
+
+    $p0_p1 = COPY %0
+    RET_ReallyLR
+...
+---
+name: spills_fills_stack_id_ppr2mul2
+tracksRegLiveness: true
+registers:
+  - { id: 0, class: ppr2mul2 }
+stack:
+liveins:
+  - { reg: '$p0_p1', virtual-reg: '%0' }
+body:             |
+  bb.0.entry:
+    liveins: $p0_p1
+
+    ; CHECK-LABEL: name: spills_fills_stack_id_ppr2mul2
+    ; CHECK: stack:
+    ; CHECK:      - { id: 0, name: '', type: spill-slot, offset: 0, size: 4, alignment: 2
+    ; CHECK-NEXT:     stack-id: scalable-vector, callee-saved-register: ''
+
+    ; EXPAND-LABEL: name: spills_fills_stack_id_ppr2mul2
+    ; EXPAND: STR_PXI $p0, $sp, 6
+    ; EXPAND: STR_PXI $p1, $sp, 7
+    ; EXPAND: $p0 = LDR_PXI $sp, 6
+    ; EXPAND: $p1 = LDR_PXI $sp, 7
+
+    %0:ppr2mul2 = COPY $p0_p1
+
+    $p0 = IMPLICIT_DEF
+    $p1 = IMPLICIT_DEF
+    $p2 = IMPLICIT_DEF
+    $p3 = IMPLICIT_DEF
+    $p4 = IMPLICIT_DEF
+    $p5 = IMPLICIT_DEF
+    $p6 = IMPLICIT_DEF
+    $p7 = IMPLICIT_DEF
+    $p8 = IMPLICIT_DEF
+    $p9 = IMPLICIT_DEF
+    $p10 = IMPLICIT_DEF
+    $p11 = IMPLICIT_DEF
+    $p12 = IMPLICIT_DEF
+    $p13 = IMPLICIT_DEF
+    $p14 = IMPLICIT_DEF
+    $p15 = IMPLICIT_DEF
+
+    $p0_p1 = COPY %0
+    RET_ReallyLR
+...
+---
 name: spills_fills_stack_id_pnr
 tracksRegLiveness: true
 registers:
diff --git a/llvm/test/CodeGen/AArch64/sve-pred-pair-spill-fill.ll b/llvm/test/CodeGen/AArch64/sve-pred-pair-spill-fill.ll
new file mode 100644
index 00000000000000..eb7950ef10c9c4
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve-pred-pair-spill-fill.ll
@@ -0,0 +1,67 @@
+; RUN: llc < %s | FileCheck %s
+
+; Derived from the following C source:
+; #include <arm_sve.h>
+
+; void g();
+
+; svboolx2_t f0(int64_t i, int64_t n) {
+;     svboolx2_t r = svwhilelt_b16_x2(i, n);
+;     g();
+;     return r;
+; }
+
+; svboolx2_t f1(svcount_t n) {
+;     svboolx2_t r = svpext_lane_c8_x2(n, 1);
+;     g();
+;     return r;
+; }
+; 
+; Check that predicate register pairs are spilled/filled without an ICE in the backend.
+
+target triple = "aarch64-unknown-linux"
+
+define <vscale x 32 x i1> @f0(i64 %i, i64 %n) #0 {
+entry:
+  %0 = tail call { <vscale x 8 x i1>, <vscale x 8 x i1> } @llvm.aarch64.sve.whilelt.x2.nxv8i1(i64 %i, i64 %n)
+  %1 = extractvalue { <vscale x 8 x i1>, <vscale x 8 x i1> } %0, 0
+  %2 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv8i1(<vscale x 8 x i1> %1)
+  %3 = tail call <vscale x 32 x i1> @llvm.vector.insert.nxv32i1.nxv16i1(<vscale x 32 x i1> poison, <vscale x 16 x i1> %2, i64 0)
+  %4 = extractvalue { <vscale x 8 x i1>, <vscale x 8 x i1> } %0, 1
+  %5 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv8i1(<vscale x 8 x i1> %4)
+  %6 = tail call <vscale x 32 x i1> @llvm.vector.insert.nxv32i1.nxv16i1(<vscale x 32 x i1> %3, <vscale x 16 x i1> %5, i64 16)
+  tail call void @g() #4
+  ret <vscale x 32 x i1> %6
+}
+; CHECK-LABEL: f0:
+; CHECK: whilelt { p0.h, p1.h }
+; CHECK: str p0, [sp, #6, mul vl]
+; CHECK: str p1, [sp, #7, mul vl]
+; CHECK: ldr p0, [sp, #6, mul vl]
+; CHECK: ldr p1, [sp, #7, mul vl]
+
+define <vscale x 32 x i1> @f1(target("aarch64.svcount") %n) #0 {
+entry:
+  %0 = tail call { <vscale x 16 x i1>, <vscale x 16 x i1> } @llvm.aarch64.sve.pext.x2.nxv16i1(target("aarch64.svcount") %n, i32 1)
+  %1 = extractvalue { <vscale x 16 x i1>, <vscale x 16 x i1> } %0, 0
+  %2 = tail call <vscale x 32 x i1> @llvm.vector.insert.nxv32i1.nxv16i1(<vscale x 32 x i1> poison, <vscale x 16 x i1> %1, i64 0)
+  %3 = extractvalue { <vscale x 16 x i1>, <vscale x 16 x i1> } %0, 1
+  %4 = tail call <vscale x 32 x i1> @llvm.vector.insert.nxv32i1.nxv16i1(<vscale x 32 x i1> %2, <vscale x 16 x i1> %3, i64 16)
+  tail call void @g() #4
+  ret <vscale x 32 x i1> %4
+}
+
+; CHECK-LABEL: f1:
+; CHECK: pext { p0.b, p1.b }
+; CHECK: str p0, [sp, #6, mul vl]
+; CHECK: str p1, [sp, #7, mul vl]
+; CHECK: ldr p0, [sp, #6, mul vl]
+; CHECK: ldr p1, [sp, #7, mul vl]
+
+declare void @g(...)
+declare { <vscale x 8 x i1>, <vscale x 8 x i1> } @llvm.aarch64.sve.whilelt.x2.nxv8i1(i64, i64)
+declare <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv8i1(<vscale x 8 x i1>)
+declare <vscale x 32 x i1> @llvm.vector.insert.nxv32i1.nxv16i1(<vscale x 32 x i1>, <vscale x 16 x i1>, i64 immarg)
+declare { <vscale x 16 x i1>, <vscale x 16 x i1> } @llvm.aarch64.sve.pext.x2.nxv16i1(target("aarch64.svcount"), i32 immarg)
+
+attributes #0 = { nounwind "target-features"="+sve,+sve2,+sve2p1" }

>From b179a0eb1181326ad6d4fa11dd1a3ec86999d06b Mon Sep 17 00:00:00 2001
From: Momchil Velikov <momchil.velikov at arm.com>
Date: Wed, 20 Dec 2023 18:38:41 +0000
Subject: [PATCH 2/2] [fixup] Remove LDR_PPXI/STR_PPXI from
 isLoadFromStackSlot/isStoreToStackSlot as not immediately necessary

---
 llvm/lib/Target/AArch64/AArch64InstrInfo.cpp          | 2 --
 llvm/test/CodeGen/AArch64/sve-pred-pair-spill-fill.ll | 4 ++--
 2 files changed, 2 insertions(+), 4 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
index 44a22a6f7ec0e3..0e705a28ce845f 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
@@ -2197,7 +2197,6 @@ unsigned AArch64InstrInfo::isLoadFromStackSlot(const MachineInstr &MI,
   case AArch64::LDRDui:
   case AArch64::LDRQui:
   case AArch64::LDR_PXI:
-  case AArch64::LDR_PPXI:
     if (MI.getOperand(0).getSubReg() == 0 && MI.getOperand(1).isFI() &&
         MI.getOperand(2).isImm() && MI.getOperand(2).getImm() == 0) {
       FrameIndex = MI.getOperand(1).getIndex();
@@ -2222,7 +2221,6 @@ unsigned AArch64InstrInfo::isStoreToStackSlot(const MachineInstr &MI,
   case AArch64::STRDui:
   case AArch64::STRQui:
   case AArch64::STR_PXI:
-  case AArch64::STR_PPXI:
     if (MI.getOperand(0).getSubReg() == 0 && MI.getOperand(1).isFI() &&
         MI.getOperand(2).isImm() && MI.getOperand(2).getImm() == 0) {
       FrameIndex = MI.getOperand(1).getIndex();
diff --git a/llvm/test/CodeGen/AArch64/sve-pred-pair-spill-fill.ll b/llvm/test/CodeGen/AArch64/sve-pred-pair-spill-fill.ll
index eb7950ef10c9c4..4dcc81feb72f1b 100644
--- a/llvm/test/CodeGen/AArch64/sve-pred-pair-spill-fill.ll
+++ b/llvm/test/CodeGen/AArch64/sve-pred-pair-spill-fill.ll
@@ -30,7 +30,7 @@ entry:
   %4 = extractvalue { <vscale x 8 x i1>, <vscale x 8 x i1> } %0, 1
   %5 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv8i1(<vscale x 8 x i1> %4)
   %6 = tail call <vscale x 32 x i1> @llvm.vector.insert.nxv32i1.nxv16i1(<vscale x 32 x i1> %3, <vscale x 16 x i1> %5, i64 16)
-  tail call void @g() #4
+  tail call void @g()
   ret <vscale x 32 x i1> %6
 }
 ; CHECK-LABEL: f0:
@@ -47,7 +47,7 @@ entry:
   %2 = tail call <vscale x 32 x i1> @llvm.vector.insert.nxv32i1.nxv16i1(<vscale x 32 x i1> poison, <vscale x 16 x i1> %1, i64 0)
   %3 = extractvalue { <vscale x 16 x i1>, <vscale x 16 x i1> } %0, 1
   %4 = tail call <vscale x 32 x i1> @llvm.vector.insert.nxv32i1.nxv16i1(<vscale x 32 x i1> %2, <vscale x 16 x i1> %3, i64 16)
-  tail call void @g() #4
+  tail call void @g()
   ret <vscale x 32 x i1> %4
 }
 



More information about the llvm-commits mailing list