[llvm] 4b69689 - [AArch64] Implement spill/fill of predicate pair register classes (#76068)

via llvm-commits llvm-commits at lists.llvm.org
Fri Dec 22 07:54:16 PST 2023


Author: Momchil Velikov
Date: 2023-12-22T15:54:12Z
New Revision: 4b6968952e653cb4da301d404717899393e4c530

URL: https://github.com/llvm/llvm-project/commit/4b6968952e653cb4da301d404717899393e4c530
DIFF: https://github.com/llvm/llvm-project/commit/4b6968952e653cb4da301d404717899393e4c530.diff

LOG: [AArch64] Implement spill/fill of predicate pair register classes (#76068)

We are getting ICE with, e.g.
```
#include <arm_sve.h>

 void g();
 svboolx2_t f0(int64_t i, int64_t n) {
     svboolx2_t r = svwhilelt_b16_x2(i, n);
     g();
     return r;
 }
```

Added: 
    llvm/test/CodeGen/AArch64/sve-pred-pair-spill-fill.ll

Modified: 
    llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp
    llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
    llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
    llvm/test/CodeGen/AArch64/spillfill-sve.mir

Removed: 
    


################################################################################
diff --git a/llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp b/llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp
index 757471d6a905e1..bb7f4d907ffd7f 100644
--- a/llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp
@@ -747,6 +747,15 @@ bool AArch64ExpandPseudo::expandSetTagLoop(
 bool AArch64ExpandPseudo::expandSVESpillFill(MachineBasicBlock &MBB,
                                              MachineBasicBlock::iterator MBBI,
                                              unsigned Opc, unsigned N) {
+  assert((Opc == AArch64::LDR_ZXI || Opc == AArch64::STR_ZXI ||
+          Opc == AArch64::LDR_PXI || Opc == AArch64::STR_PXI) &&
+         "Unexpected opcode");
+  unsigned RState = (Opc == AArch64::LDR_ZXI || Opc == AArch64::LDR_PXI)
+                        ? RegState::Define
+                        : 0;
+  unsigned sub0 = (Opc == AArch64::LDR_ZXI || Opc == AArch64::STR_ZXI)
+                      ? AArch64::zsub0
+                      : AArch64::psub0;
   const TargetRegisterInfo *TRI =
       MBB.getParent()->getSubtarget().getRegisterInfo();
   MachineInstr &MI = *MBBI;
@@ -756,9 +765,8 @@ bool AArch64ExpandPseudo::expandSVESpillFill(MachineBasicBlock &MBB,
     assert(ImmOffset >= -256 && ImmOffset < 256 &&
            "Immediate spill offset out of range");
     BuildMI(MBB, MBBI, MI.getDebugLoc(), TII->get(Opc))
-        .addReg(
-            TRI->getSubReg(MI.getOperand(0).getReg(), AArch64::zsub0 + Offset),
-            Opc == AArch64::LDR_ZXI ? RegState::Define : 0)
+        .addReg(TRI->getSubReg(MI.getOperand(0).getReg(), sub0 + Offset),
+                RState)
         .addReg(MI.getOperand(1).getReg(), getKillRegState(Kill))
         .addImm(ImmOffset);
   }
@@ -1492,12 +1500,16 @@ bool AArch64ExpandPseudo::expandMI(MachineBasicBlock &MBB,
      return expandSVESpillFill(MBB, MBBI, AArch64::STR_ZXI, 3);
    case AArch64::STR_ZZXI:
      return expandSVESpillFill(MBB, MBBI, AArch64::STR_ZXI, 2);
+   case AArch64::STR_PPXI:
+     return expandSVESpillFill(MBB, MBBI, AArch64::STR_PXI, 2);
    case AArch64::LDR_ZZZZXI:
      return expandSVESpillFill(MBB, MBBI, AArch64::LDR_ZXI, 4);
    case AArch64::LDR_ZZZXI:
      return expandSVESpillFill(MBB, MBBI, AArch64::LDR_ZXI, 3);
    case AArch64::LDR_ZZXI:
      return expandSVESpillFill(MBB, MBBI, AArch64::LDR_ZXI, 2);
+   case AArch64::LDR_PPXI:
+     return expandSVESpillFill(MBB, MBBI, AArch64::LDR_PXI, 2);
    case AArch64::BLR_RVMARKER:
      return expandCALL_RVMARKER(MBB, MBBI);
    case AArch64::BLR_BTI:

diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
index bc9678c13971fa..1cfbf4737a6f72 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
@@ -3771,6 +3771,13 @@ bool AArch64InstrInfo::getMemOpInfo(unsigned Opcode, TypeSize &Scale,
     MinOffset = -256;
     MaxOffset = 255;
     break;
+  case AArch64::LDR_PPXI:
+  case AArch64::STR_PPXI:
+    Scale = TypeSize::getScalable(2);
+    Width = TypeSize::getScalable(2 * 2);
+    MinOffset = -256;
+    MaxOffset = 254;
+    break;
   case AArch64::LDR_ZXI:
   case AArch64::STR_ZXI:
     Scale = TypeSize::getScalable(16);
@@ -4814,6 +4821,10 @@ void AArch64InstrInfo::storeRegToStackSlot(MachineBasicBlock &MBB,
         assert(SrcReg != AArch64::WSP);
     } else if (AArch64::FPR32RegClass.hasSubClassEq(RC))
       Opc = AArch64::STRSui;
+    else if (AArch64::PPR2RegClass.hasSubClassEq(RC)) {
+      Opc = AArch64::STR_PPXI;
+      StackID = TargetStackID::ScalableVector;
+    }
     break;
   case 8:
     if (AArch64::GPR64allRegClass.hasSubClassEq(RC)) {
@@ -4990,6 +5001,10 @@ void AArch64InstrInfo::loadRegFromStackSlot(MachineBasicBlock &MBB,
         assert(DestReg != AArch64::WSP);
     } else if (AArch64::FPR32RegClass.hasSubClassEq(RC))
       Opc = AArch64::LDRSui;
+    else if (AArch64::PPR2RegClass.hasSubClassEq(RC)) {
+      Opc = AArch64::LDR_PPXI;
+      StackID = TargetStackID::ScalableVector;
+    }
     break;
   case 8:
     if (AArch64::GPR64allRegClass.hasSubClassEq(RC)) {

diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index 3dae6f7795ee93..344a153890631e 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -2398,11 +2398,13 @@ let Predicates = [HasSVEorSME] in {
     def LDR_ZZXI   : Pseudo<(outs   ZZ_b_strided_and_contiguous:$Zd), (ins GPR64sp:$sp, simm4s1:$offset),[]>, Sched<[]>;
     def LDR_ZZZXI  : Pseudo<(outs  ZZZ_b:$Zd), (ins GPR64sp:$sp, simm4s1:$offset),[]>, Sched<[]>;
     def LDR_ZZZZXI : Pseudo<(outs ZZZZ_b_strided_and_contiguous:$Zd), (ins GPR64sp:$sp, simm4s1:$offset),[]>, Sched<[]>;
+    def LDR_PPXI   : Pseudo<(outs PPR2:$pp), (ins GPR64sp:$sp, simm4s1:$offset),[]>, Sched<[]>;
   }
   let mayStore = 1, hasSideEffects = 0 in {
     def STR_ZZXI   : Pseudo<(outs), (ins   ZZ_b_strided_and_contiguous:$Zs, GPR64sp:$sp, simm4s1:$offset),[]>, Sched<[]>;
     def STR_ZZZXI  : Pseudo<(outs), (ins  ZZZ_b:$Zs, GPR64sp:$sp, simm4s1:$offset),[]>, Sched<[]>;
     def STR_ZZZZXI : Pseudo<(outs), (ins ZZZZ_b_strided_and_contiguous:$Zs, GPR64sp:$sp, simm4s1:$offset),[]>, Sched<[]>;
+    def STR_PPXI   : Pseudo<(outs), (ins PPR2:$pp, GPR64sp:$sp, simm4s1:$offset),[]>, Sched<[]>;
   }
 
   let AddedComplexity = 1 in {

diff --git a/llvm/test/CodeGen/AArch64/spillfill-sve.mir b/llvm/test/CodeGen/AArch64/spillfill-sve.mir
index 01756b84600192..ef7d55a1c2395f 100644
--- a/llvm/test/CodeGen/AArch64/spillfill-sve.mir
+++ b/llvm/test/CodeGen/AArch64/spillfill-sve.mir
@@ -7,6 +7,8 @@
   target triple = "aarch64--linux-gnu"
 
   define aarch64_sve_vector_pcs void @spills_fills_stack_id_ppr() #0 { entry: unreachable }
+  define aarch64_sve_vector_pcs void @spills_fills_stack_id_ppr2() #0 { entry: unreachable }
+  define aarch64_sve_vector_pcs void @spills_fills_stack_id_ppr2mul2() #0 { entry: unreachable }
   define aarch64_sve_vector_pcs void @spills_fills_stack_id_pnr() #1 { entry: unreachable }
   define aarch64_sve_vector_pcs void @spills_fills_stack_id_virtreg_pnr() #1 { entry: unreachable }
   define aarch64_sve_vector_pcs void @spills_fills_stack_id_zpr() #0 { entry: unreachable }
@@ -64,6 +66,96 @@ body:             |
     RET_ReallyLR
 ...
 ---
+name: spills_fills_stack_id_ppr2
+tracksRegLiveness: true
+registers:
+  - { id: 0, class: ppr2 }
+stack:
+liveins:
+  - { reg: '$p0_p1', virtual-reg: '%0' }
+body:             |
+  bb.0.entry:
+    liveins: $p0_p1
+
+    ; CHECK-LABEL: name: spills_fills_stack_id_ppr2
+    ; CHECK: stack:
+    ; CHECK:      - { id: 0, name: '', type: spill-slot, offset: 0, size: 4, alignment: 2
+    ; CHECK-NEXT:     stack-id: scalable-vector, callee-saved-register: ''
+
+    ; EXPAND-LABEL: name: spills_fills_stack_id_ppr2
+    ; EXPAND: STR_PXI $p0, $sp, 6
+    ; EXPAND: STR_PXI $p1, $sp, 7
+    ; EXPAND: $p0 = LDR_PXI $sp, 6
+    ; EXPAND: $p1 = LDR_PXI $sp, 7
+
+    %0:ppr2 = COPY $p0_p1
+
+    $p0 = IMPLICIT_DEF
+    $p1 = IMPLICIT_DEF
+    $p2 = IMPLICIT_DEF
+    $p3 = IMPLICIT_DEF
+    $p4 = IMPLICIT_DEF
+    $p5 = IMPLICIT_DEF
+    $p6 = IMPLICIT_DEF
+    $p7 = IMPLICIT_DEF
+    $p8 = IMPLICIT_DEF
+    $p9 = IMPLICIT_DEF
+    $p10 = IMPLICIT_DEF
+    $p11 = IMPLICIT_DEF
+    $p12 = IMPLICIT_DEF
+    $p13 = IMPLICIT_DEF
+    $p14 = IMPLICIT_DEF
+    $p15 = IMPLICIT_DEF
+
+    $p0_p1 = COPY %0
+    RET_ReallyLR
+...
+---
+name: spills_fills_stack_id_ppr2mul2
+tracksRegLiveness: true
+registers:
+  - { id: 0, class: ppr2mul2 }
+stack:
+liveins:
+  - { reg: '$p0_p1', virtual-reg: '%0' }
+body:             |
+  bb.0.entry:
+    liveins: $p0_p1
+
+    ; CHECK-LABEL: name: spills_fills_stack_id_ppr2mul2
+    ; CHECK: stack:
+    ; CHECK:      - { id: 0, name: '', type: spill-slot, offset: 0, size: 4, alignment: 2
+    ; CHECK-NEXT:     stack-id: scalable-vector, callee-saved-register: ''
+
+    ; EXPAND-LABEL: name: spills_fills_stack_id_ppr2mul2
+    ; EXPAND: STR_PXI $p0, $sp, 6
+    ; EXPAND: STR_PXI $p1, $sp, 7
+    ; EXPAND: $p0 = LDR_PXI $sp, 6
+    ; EXPAND: $p1 = LDR_PXI $sp, 7
+
+    %0:ppr2mul2 = COPY $p0_p1
+
+    $p0 = IMPLICIT_DEF
+    $p1 = IMPLICIT_DEF
+    $p2 = IMPLICIT_DEF
+    $p3 = IMPLICIT_DEF
+    $p4 = IMPLICIT_DEF
+    $p5 = IMPLICIT_DEF
+    $p6 = IMPLICIT_DEF
+    $p7 = IMPLICIT_DEF
+    $p8 = IMPLICIT_DEF
+    $p9 = IMPLICIT_DEF
+    $p10 = IMPLICIT_DEF
+    $p11 = IMPLICIT_DEF
+    $p12 = IMPLICIT_DEF
+    $p13 = IMPLICIT_DEF
+    $p14 = IMPLICIT_DEF
+    $p15 = IMPLICIT_DEF
+
+    $p0_p1 = COPY %0
+    RET_ReallyLR
+...
+---
 name: spills_fills_stack_id_pnr
 tracksRegLiveness: true
 registers:

diff --git a/llvm/test/CodeGen/AArch64/sve-pred-pair-spill-fill.ll b/llvm/test/CodeGen/AArch64/sve-pred-pair-spill-fill.ll
new file mode 100644
index 00000000000000..4dcc81feb72f1b
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve-pred-pair-spill-fill.ll
@@ -0,0 +1,67 @@
+; RUN: llc < %s | FileCheck %s
+
+; Derived from 
+; #include <arm_sve.h>
+
+; void g();
+
+; svboolx2_t f0(int64_t i, int64_t n) {
+;     svboolx2_t r = svwhilelt_b16_x2(i, n);
+;     g();
+;     return r;
+; }
+
+; svboolx2_t f1(svcount_t n) {
+;     svboolx2_t r = svpext_lane_c8_x2(n, 1);
+;     g();
+;     return r;
+; }
+; 
+; Check that predicate register pairs are spilled/filled without an ICE in the backend.
+
+target triple = "aarch64-unknown-linux"
+
+define <vscale x 32 x i1> @f0(i64 %i, i64 %n) #0 {
+entry:
+  %0 = tail call { <vscale x 8 x i1>, <vscale x 8 x i1> } @llvm.aarch64.sve.whilelt.x2.nxv8i1(i64 %i, i64 %n)
+  %1 = extractvalue { <vscale x 8 x i1>, <vscale x 8 x i1> } %0, 0
+  %2 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv8i1(<vscale x 8 x i1> %1)
+  %3 = tail call <vscale x 32 x i1> @llvm.vector.insert.nxv32i1.nxv16i1(<vscale x 32 x i1> poison, <vscale x 16 x i1> %2, i64 0)
+  %4 = extractvalue { <vscale x 8 x i1>, <vscale x 8 x i1> } %0, 1
+  %5 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv8i1(<vscale x 8 x i1> %4)
+  %6 = tail call <vscale x 32 x i1> @llvm.vector.insert.nxv32i1.nxv16i1(<vscale x 32 x i1> %3, <vscale x 16 x i1> %5, i64 16)
+  tail call void @g()
+  ret <vscale x 32 x i1> %6
+}
+; CHECK-LABEL: f0:
+; CHECK: whilelt { p0.h, p1.h }
+; CHECK: str p0, [sp, #6, mul vl]
+; CHECK: str p1, [sp, #7, mul vl]
+; CHECK: ldr p0, [sp, #6, mul vl]
+; CHECK: ldr p1, [sp, #7, mul vl]
+
+define <vscale x 32 x i1> @f1(target("aarch64.svcount") %n) #0 {
+entry:
+  %0 = tail call { <vscale x 16 x i1>, <vscale x 16 x i1> } @llvm.aarch64.sve.pext.x2.nxv16i1(target("aarch64.svcount") %n, i32 1)
+  %1 = extractvalue { <vscale x 16 x i1>, <vscale x 16 x i1> } %0, 0
+  %2 = tail call <vscale x 32 x i1> @llvm.vector.insert.nxv32i1.nxv16i1(<vscale x 32 x i1> poison, <vscale x 16 x i1> %1, i64 0)
+  %3 = extractvalue { <vscale x 16 x i1>, <vscale x 16 x i1> } %0, 1
+  %4 = tail call <vscale x 32 x i1> @llvm.vector.insert.nxv32i1.nxv16i1(<vscale x 32 x i1> %2, <vscale x 16 x i1> %3, i64 16)
+  tail call void @g()
+  ret <vscale x 32 x i1> %4
+}
+
+; CHECK-LABEL: f1:
+; CHECK: pext { p0.b, p1.b }
+; CHECK: str p0, [sp, #6, mul vl]
+; CHECK: str p1, [sp, #7, mul vl]
+; CHECK: ldr p0, [sp, #6, mul vl]
+; CHECK: ldr p1, [sp, #7, mul vl]
+
+declare void @g(...)
+declare { <vscale x 8 x i1>, <vscale x 8 x i1> } @llvm.aarch64.sve.whilelt.x2.nxv8i1(i64, i64)
+declare <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv8i1(<vscale x 8 x i1>)
+declare <vscale x 32 x i1> @llvm.vector.insert.nxv32i1.nxv16i1(<vscale x 32 x i1>, <vscale x 16 x i1>, i64 immarg)
+declare { <vscale x 16 x i1>, <vscale x 16 x i1> } @llvm.aarch64.sve.pext.x2.nxv16i1(target("aarch64.svcount"), i32 immarg) #1
+
+attributes #0 = { nounwind "target-features"="+sve,+sve2,+sve2p1" }


        


More information about the llvm-commits mailing list