[llvm] a629322 - Reland "[AArch64][SME] Add support for Copy/Spill/Fill of strided ZPR2/ZPR4 registers."

Sander de Smalen via llvm-commits llvm-commits at lists.llvm.org
Thu Aug 31 08:03:40 PDT 2023


Author: Sander de Smalen
Date: 2023-08-31T15:03:19Z
New Revision: a6293228fdd5aba8c04c63f02f3d017443feb3f2

URL: https://github.com/llvm/llvm-project/commit/a6293228fdd5aba8c04c63f02f3d017443feb3f2
DIFF: https://github.com/llvm/llvm-project/commit/a6293228fdd5aba8c04c63f02f3d017443feb3f2.diff

LOG: Reland "[AArch64][SME] Add support for Copy/Spill/Fill of strided ZPR2/ZPR4 registers."

This patch contains a few changes:

* It changes the alignment of the strided/contiguous ZPR2/ZPR4 register
  classes to 128 bits. This is important because, when we spill these
  registers to the stack, the address does not need to be 256/512-bit
  aligned: we split the single store/reload pseudo instruction into
  multiple STR_ZXI/LDR_ZXI (single-vector store/load) instructions,
  which only require 128-bit alignment (see the sketch after this
  list). Additionally, an alignment larger than the stack alignment is
  not supported for scalable vectors.

* It adds support for these register classes in storeRegToStackSlot,
  loadRegFromStackSlot and copyPhysReg.

* It adds tests only for the strided forms. There is no need to also
  test the contiguous forms, because registers such as z2_z3 and
  z4_z5_z6_z7 are also part of the regular ZPR2 and ZPR4 register
  classes, respectively, which are already covered and tested.

Reviewed By: dtemirbulatov

Differential Revision: https://reviews.llvm.org/D159189

Added: 
    

Modified: 
    llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
    llvm/lib/Target/AArch64/AArch64RegisterInfo.td
    llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
    llvm/test/CodeGen/AArch64/spillfill-sve.mir
    llvm/test/CodeGen/AArch64/sve-copy-zprpair.mir

Removed: 
    


################################################################################
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
index 088e1e5bdb6de2..4b98dbc1a8dc2c 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
@@ -3669,8 +3669,10 @@ void AArch64InstrInfo::copyPhysReg(MachineBasicBlock &MBB,
   }
 
   // Copy a Z register pair by copying the individual sub-registers.
-  if (AArch64::ZPR2RegClass.contains(DestReg) &&
-      AArch64::ZPR2RegClass.contains(SrcReg)) {
+  if ((AArch64::ZPR2RegClass.contains(DestReg) ||
+       AArch64::ZPR2StridedOrContiguousRegClass.contains(DestReg)) &&
+      (AArch64::ZPR2RegClass.contains(SrcReg) ||
+       AArch64::ZPR2StridedOrContiguousRegClass.contains(SrcReg))) {
     assert(Subtarget.hasSVEorSME() && "Unexpected SVE register.");
     static const unsigned Indices[] = {AArch64::zsub0, AArch64::zsub1};
     copyPhysRegTuple(MBB, I, DL, DestReg, SrcReg, KillSrc, AArch64::ORR_ZZZ,
@@ -3690,8 +3692,10 @@ void AArch64InstrInfo::copyPhysReg(MachineBasicBlock &MBB,
   }
 
   // Copy a Z register quad by copying the individual sub-registers.
-  if (AArch64::ZPR4RegClass.contains(DestReg) &&
-      AArch64::ZPR4RegClass.contains(SrcReg)) {
+  if ((AArch64::ZPR4RegClass.contains(DestReg) ||
+       AArch64::ZPR4StridedOrContiguousRegClass.contains(DestReg)) &&
+      (AArch64::ZPR4RegClass.contains(SrcReg) ||
+       AArch64::ZPR4StridedOrContiguousRegClass.contains(SrcReg))) {
     assert(Subtarget.hasSVEorSME() && "Unexpected SVE register.");
     static const unsigned Indices[] = {AArch64::zsub0, AArch64::zsub1,
                                        AArch64::zsub2, AArch64::zsub3};
@@ -4022,7 +4026,8 @@ void AArch64InstrInfo::storeRegToStackSlot(MachineBasicBlock &MBB,
       assert(Subtarget.hasNEON() && "Unexpected register store without NEON");
       Opc = AArch64::ST1Twov2d;
       Offset = false;
-    } else if (AArch64::ZPR2RegClass.hasSubClassEq(RC)) {
+    } else if (AArch64::ZPR2RegClass.hasSubClassEq(RC) ||
+               AArch64::ZPR2StridedOrContiguousRegClass.hasSubClassEq(RC)) {
       assert(Subtarget.hasSVE() && "Unexpected register store without SVE");
       Opc = AArch64::STR_ZZXI;
       StackID = TargetStackID::ScalableVector;
@@ -4044,7 +4049,8 @@ void AArch64InstrInfo::storeRegToStackSlot(MachineBasicBlock &MBB,
       assert(Subtarget.hasNEON() && "Unexpected register store without NEON");
       Opc = AArch64::ST1Fourv2d;
       Offset = false;
-    } else if (AArch64::ZPR4RegClass.hasSubClassEq(RC)) {
+    } else if (AArch64::ZPR4RegClass.hasSubClassEq(RC) ||
+               AArch64::ZPR4StridedOrContiguousRegClass.hasSubClassEq(RC)) {
       assert(Subtarget.hasSVE() && "Unexpected register store without SVE");
       Opc = AArch64::STR_ZZZZXI;
       StackID = TargetStackID::ScalableVector;
@@ -4178,7 +4184,8 @@ void AArch64InstrInfo::loadRegFromStackSlot(MachineBasicBlock &MBB,
       assert(Subtarget.hasNEON() && "Unexpected register load without NEON");
       Opc = AArch64::LD1Twov2d;
       Offset = false;
-    } else if (AArch64::ZPR2RegClass.hasSubClassEq(RC)) {
+    } else if (AArch64::ZPR2RegClass.hasSubClassEq(RC) ||
+               AArch64::ZPR2StridedOrContiguousRegClass.hasSubClassEq(RC)) {
       assert(Subtarget.hasSVE() && "Unexpected register load without SVE");
       Opc = AArch64::LDR_ZZXI;
       StackID = TargetStackID::ScalableVector;
@@ -4200,7 +4207,8 @@ void AArch64InstrInfo::loadRegFromStackSlot(MachineBasicBlock &MBB,
       assert(Subtarget.hasNEON() && "Unexpected register load without NEON");
       Opc = AArch64::LD1Fourv2d;
       Offset = false;
-    } else if (AArch64::ZPR4RegClass.hasSubClassEq(RC)) {
+    } else if (AArch64::ZPR4RegClass.hasSubClassEq(RC) ||
+               AArch64::ZPR4StridedOrContiguousRegClass.hasSubClassEq(RC)) {
       assert(Subtarget.hasSVE() && "Unexpected register load without SVE");
       Opc = AArch64::LDR_ZZZZXI;
       StackID = TargetStackID::ScalableVector;

diff --git a/llvm/lib/Target/AArch64/AArch64RegisterInfo.td b/llvm/lib/Target/AArch64/AArch64RegisterInfo.td
index 9aba263da4f47b..18fc8c77a0e44d 100644
--- a/llvm/lib/Target/AArch64/AArch64RegisterInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64RegisterInfo.td
@@ -1331,16 +1331,16 @@ def ZStridedQuadsHi : RegisterTuples<[zsub0, zsub1, zsub2, zsub3], [
   (trunc (rotl ZPR, 24), 4), (trunc (rotl ZPR, 28), 4)
 ]>;
 
-def ZPR2Strided : RegisterClass<"AArch64", [untyped], 256,
+def ZPR2Strided : RegisterClass<"AArch64", [untyped], 128,
                                 (add ZStridedPairsLo, ZStridedPairsHi)>  {
   let Size = 256;
 }
-def ZPR4Strided : RegisterClass<"AArch64", [untyped], 512,
+def ZPR4Strided : RegisterClass<"AArch64", [untyped], 128,
                                 (add ZStridedQuadsLo, ZStridedQuadsHi)>  {
   let Size = 512;
 }
 
-def ZPR2StridedOrContiguous : RegisterClass<"AArch64", [untyped], 256,
+def ZPR2StridedOrContiguous : RegisterClass<"AArch64", [untyped], 128,
                                 (add ZStridedPairsLo, ZStridedPairsHi,
                                 (decimate ZSeqPairs, 2))> {
   let Size = 256;
@@ -1387,7 +1387,7 @@ let EncoderMethod = "EncodeZPR2StridedRegisterClass",
       : RegisterOperand<ZPR2StridedOrContiguous, "printTypedVectorList<0,'d'>">;
 }
 
-def ZPR4StridedOrContiguous : RegisterClass<"AArch64", [untyped], 512,
+def ZPR4StridedOrContiguous : RegisterClass<"AArch64", [untyped], 128,
                                 (add ZStridedQuadsLo, ZStridedQuadsHi,
                                 (decimate ZSeqQuads, 4))>  {
   let Size = 512;

diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index 91942b7e42974a..98bee05743c8f4 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -2311,14 +2311,14 @@ let Predicates = [HasSVEorSME] in {
   // These get expanded to individual LDR_ZXI/STR_ZXI instructions in
   // AArch64ExpandPseudoInsts.
   let mayLoad = 1, hasSideEffects = 0 in {
-    def LDR_ZZXI   : Pseudo<(outs   ZZ_b:$Zd), (ins GPR64sp:$sp, simm4s1:$offset),[]>, Sched<[]>;
+    def LDR_ZZXI   : Pseudo<(outs   ZZ_b_strided_and_contiguous:$Zd), (ins GPR64sp:$sp, simm4s1:$offset),[]>, Sched<[]>;
     def LDR_ZZZXI  : Pseudo<(outs  ZZZ_b:$Zd), (ins GPR64sp:$sp, simm4s1:$offset),[]>, Sched<[]>;
-    def LDR_ZZZZXI : Pseudo<(outs ZZZZ_b:$Zd), (ins GPR64sp:$sp, simm4s1:$offset),[]>, Sched<[]>;
+    def LDR_ZZZZXI : Pseudo<(outs ZZZZ_b_strided_and_contiguous:$Zd), (ins GPR64sp:$sp, simm4s1:$offset),[]>, Sched<[]>;
   }
   let mayStore = 1, hasSideEffects = 0 in {
-    def STR_ZZXI   : Pseudo<(outs), (ins   ZZ_b:$Zs, GPR64sp:$sp, simm4s1:$offset),[]>, Sched<[]>;
+    def STR_ZZXI   : Pseudo<(outs), (ins   ZZ_b_strided_and_contiguous:$Zs, GPR64sp:$sp, simm4s1:$offset),[]>, Sched<[]>;
     def STR_ZZZXI  : Pseudo<(outs), (ins  ZZZ_b:$Zs, GPR64sp:$sp, simm4s1:$offset),[]>, Sched<[]>;
-    def STR_ZZZZXI : Pseudo<(outs), (ins ZZZZ_b:$Zs, GPR64sp:$sp, simm4s1:$offset),[]>, Sched<[]>;
+    def STR_ZZZZXI : Pseudo<(outs), (ins ZZZZ_b_strided_and_contiguous:$Zs, GPR64sp:$sp, simm4s1:$offset),[]>, Sched<[]>;
   }
 
   let AddedComplexity = 1 in {

diff --git a/llvm/test/CodeGen/AArch64/spillfill-sve.mir b/llvm/test/CodeGen/AArch64/spillfill-sve.mir
index 951dbc72defc82..10fbb2499b48a5 100644
--- a/llvm/test/CodeGen/AArch64/spillfill-sve.mir
+++ b/llvm/test/CodeGen/AArch64/spillfill-sve.mir
@@ -9,8 +9,10 @@
   define aarch64_sve_vector_pcs void @spills_fills_stack_id_ppr() #0 { entry: unreachable }
   define aarch64_sve_vector_pcs void @spills_fills_stack_id_zpr() #0 { entry: unreachable }
   define aarch64_sve_vector_pcs void @spills_fills_stack_id_zpr2() #0 { entry: unreachable }
+  define aarch64_sve_vector_pcs void @spills_fills_stack_id_zpr2strided() #0 { entry: unreachable }
   define aarch64_sve_vector_pcs void @spills_fills_stack_id_zpr3() #0 { entry: unreachable }
   define aarch64_sve_vector_pcs void @spills_fills_stack_id_zpr4() #0 { entry: unreachable }
+  define aarch64_sve_vector_pcs void @spills_fills_stack_id_zpr4strided() #0 { entry: unreachable }
 
   attributes #0 = { nounwind "target-features"="+sve" }
 
@@ -131,6 +133,51 @@ body:             |
     RET_ReallyLR
 ...
 ---
+name: spills_fills_stack_id_zpr2strided
+tracksRegLiveness: true
+registers:
+  - { id: 0, class: zpr2strided }
+stack:
+liveins:
+  - { reg: '$z0_z8', virtual-reg: '%0' }
+body:             |
+  bb.0.entry:
+    liveins: $z0_z1
+    successors: %bb.1
+
+    $z0_z8 = COPY $z0_z1
+
+    B %bb.1
+
+  bb.1:
+    liveins: $z0_z8
+
+    ; CHECK-LABEL: name: spills_fills_stack_id_zpr2strided
+    ; CHECK: stack:
+    ; CHECK:      - { id: 0, name: '', type: spill-slot, offset: 0, size: 32, alignment: 16
+    ; CHECK-NEXT:     stack-id: scalable-vector
+
+    ; EXPAND-LABEL: name: spills_fills_stack_id_zpr2strided
+    ; EXPAND: STR_ZXI $z0, $sp, 0
+    ; EXPAND: STR_ZXI $z8, $sp, 1
+    ; EXPAND: $z0 = LDR_ZXI $sp, 0
+    ; EXPAND: $z8 = LDR_ZXI $sp, 1
+
+    %0:zpr2strided = COPY $z0_z8
+
+    $z0_z1_z2_z3     = IMPLICIT_DEF
+    $z4_z5_z6_z7     = IMPLICIT_DEF
+    $z8_z9_z10_z11   = IMPLICIT_DEF
+    $z12_z13_z14_z15 = IMPLICIT_DEF
+    $z16_z17_z18_z19 = IMPLICIT_DEF
+    $z20_z21_z22_z23 = IMPLICIT_DEF
+    $z24_z25_z26_z27 = IMPLICIT_DEF
+    $z28_z29_z30_z31 = IMPLICIT_DEF
+
+    $z0_z8 = COPY %0
+    RET_ReallyLR
+...
+---
 name: spills_fills_stack_id_zpr3
 tracksRegLiveness: true
 registers:
@@ -210,3 +257,51 @@ body:             |
     $z0_z1_z2_z3 = COPY %0
     RET_ReallyLR
 ...
+---
+name: spills_fills_stack_id_zpr4strided
+tracksRegLiveness: true
+registers:
+  - { id: 0, class: zpr4strided }
+stack:
+liveins:
+  - { reg: '$z0_z4_z8_z12', virtual-reg: '%0' }
+body:             |
+  bb.0.entry:
+    liveins: $z0_z1_z2_z3
+
+    $z0_z4_z8_z12 = COPY $z0_z1_z2_z3
+
+    B %bb.1
+
+  bb.1:
+    liveins: $z0_z4_z8_z12
+
+    ; CHECK-LABEL: name: spills_fills_stack_id_zpr4strided
+    ; CHECK: stack:
+    ; CHECK:      - { id: 0, name: '', type: spill-slot, offset: 0, size: 64, alignment: 16
+    ; CHECK-NEXT:     stack-id: scalable-vector
+
+    ; EXPAND-LABEL: name: spills_fills_stack_id_zpr4strided
+    ; EXPAND: STR_ZXI $z0, $sp, 0
+    ; EXPAND: STR_ZXI $z4, $sp, 1
+    ; EXPAND: STR_ZXI $z8, $sp, 2
+    ; EXPAND: STR_ZXI $z12, $sp, 3
+    ; EXPAND: $z0 = LDR_ZXI $sp, 0
+    ; EXPAND: $z4 = LDR_ZXI $sp, 1
+    ; EXPAND: $z8 = LDR_ZXI $sp, 2
+    ; EXPAND: $z12 = LDR_ZXI $sp, 3
+
+    %0:zpr4strided = COPY $z0_z4_z8_z12
+
+    $z0_z1_z2_z3     = IMPLICIT_DEF
+    $z4_z5_z6_z7     = IMPLICIT_DEF
+    $z8_z9_z10_z11   = IMPLICIT_DEF
+    $z12_z13_z14_z15 = IMPLICIT_DEF
+    $z16_z17_z18_z19 = IMPLICIT_DEF
+    $z20_z21_z22_z23 = IMPLICIT_DEF
+    $z24_z25_z26_z27 = IMPLICIT_DEF
+    $z28_z29_z30_z31 = IMPLICIT_DEF
+
+    $z0_z4_z8_z12 = COPY %0
+    RET_ReallyLR
+...

diff --git a/llvm/test/CodeGen/AArch64/sve-copy-zprpair.mir b/llvm/test/CodeGen/AArch64/sve-copy-zprpair.mir
index 83a0b5dd1c14ac..a295d4eb7336b4 100644
--- a/llvm/test/CodeGen/AArch64/sve-copy-zprpair.mir
+++ b/llvm/test/CodeGen/AArch64/sve-copy-zprpair.mir
@@ -23,6 +23,29 @@ body:             |
     $z0_z1 = COPY $z1_z2
     RET_ReallyLR
 
+...
+---
+name:            copy_zpr2strided
+alignment:       4
+tracksRegLiveness: true
+liveins:
+  - { reg: '$z0_z1' }
+frameInfo:
+  maxCallFrameSize: 0
+body:             |
+  bb.0:
+    liveins: $z0_z1
+    ; CHECK-LABEL: name: copy_zpr2strided
+    ; CHECK: liveins: $z0_z1
+    ; CHECK: $z8 = ORR_ZZZ $z1, $z1
+    ; CHECK: $z0 = ORR_ZZZ $z0, $z0
+    ; CHECK: $z1 = ORR_ZZZ $z8, $z8
+    ; CHECK: $z0 = ORR_ZZZ $z0, $z0
+    ; CHECK: RET_ReallyLR
+    $z0_z8 = COPY $z0_z1
+    $z0_z1 = COPY $z0_z8
+    RET_ReallyLR
+
 ...
 ---
 name:            copy_zpr3
@@ -76,3 +99,30 @@ body:             |
     RET_ReallyLR
 
 ...
+---
+name:            copy_zpr4strided
+alignment:       4
+tracksRegLiveness: true
+liveins:
+  - { reg: '$z0_z1_z2_z3' }
+frameInfo:
+  maxCallFrameSize: 0
+body:             |
+  bb.0:
+    liveins: $z0_z1_z2_z3
+    ; CHECK-LABEL: name: copy_zpr4
+    ; CHECK: liveins: $z0_z1_z2_z3
+    ; CHECK: $z12 = ORR_ZZZ $z3, $z3
+    ; CHECK: $z8 = ORR_ZZZ $z2, $z2
+    ; CHECK: $z4 = ORR_ZZZ $z1, $z1
+    ; CHECK: $z0 = ORR_ZZZ $z0, $z0
+    ; CHECK: $z3 = ORR_ZZZ $z12, $z12
+    ; CHECK: $z2 = ORR_ZZZ $z8, $z8
+    ; CHECK: $z1 = ORR_ZZZ $z4, $z4
+    ; CHECK: $z0 = ORR_ZZZ $z0, $z0
+    ; CHECK: RET_ReallyLR
+    $z0_z4_z8_z12 = COPY $z0_z1_z2_z3
+    $z0_z1_z2_z3 = COPY $z0_z4_z8_z12
+    RET_ReallyLR
+
+...


        

