[llvm] [AArch64] Ensure Neoverse V1 scheduling model includes all SVE pseudos. (PR #84187)

David Green via llvm-commits llvm-commits at lists.llvm.org
Wed Mar 6 07:29:12 PST 2024


https://github.com/davemgreen created https://github.com/llvm/llvm-project/pull/84187

With the many pseudos used in SVE codegen, it is easy to miss instructions. This enables the existing test that checks that the scheduling info of the pseudos matches that of the real instructions, and adjusts the scheduling info in the NeoverseV1 model to make sure all of them are handled. Where possible, I opted to use the same info as in the NeoverseV2 model, to keep the differences between the two models smaller.

>From 3a4a3facbe82e0b968b852fbfd720ed35709944b Mon Sep 17 00:00:00 2001
From: David Green <david.green at arm.com>
Date: Wed, 6 Mar 2024 15:23:02 +0000
Subject: [PATCH] [AArch64] Ensure Neoverse V1 scheduling model includes all
 SVE pseudos.

With the many pseudos used in SVE codegen, it is easy to miss
instructions. This enables the existing test that checks that the
scheduling info of the pseudos matches that of the real instructions, and
adjusts the scheduling info in the NeoverseV1 model to make sure all of
them are handled. Where possible, I opted to use the same info as in the
NeoverseV2 model, to keep the differences between the two models smaller.
---
 .../Target/AArch64/AArch64SchedNeoverseV1.td  | 105 ++++++++++--------
 .../Target/AArch64/AArch64SchedNeoverseV2.td  |   6 +-
 .../AArch64/AArch64SVESchedPseudoTest.cpp     |   4 +
 3 files changed, 67 insertions(+), 48 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64SchedNeoverseV1.td b/llvm/lib/Target/AArch64/AArch64SchedNeoverseV1.td
index e50a401f8b2aec..c7dfd64b2fb24e 100644
--- a/llvm/lib/Target/AArch64/AArch64SchedNeoverseV1.td
+++ b/llvm/lib/Target/AArch64/AArch64SchedNeoverseV1.td
@@ -1372,18 +1372,18 @@ def : InstRW<[V1Write_3c_2M0], (instregex "^PTRUES_[BHSD]$")>;
 // Arithmetic, basic
 // Logical
 def : InstRW<[V1Write_2c_1V01],
-             (instregex "^(ABS|CNOT|NEG)_ZPmZ_[BHSD]$",
-                        "^(ADD|SUB)_Z(I|P[mZ]Z|ZZ)_[BHSD]$",
+             (instregex "^(ABS|CNOT|NEG)_ZPmZ_[BHSD]",
+                        "^(ADD|SUB)_Z(I|P[mZ]Z|ZZ)_[BHSD]",
                         "^ADR_[SU]XTW_ZZZ_D_[0123]$",
                         "^ADR_LSL_ZZZ_[SD]_[0123]$",
-                        "^[SU]ABD_ZP[mZ]Z_[BHSD]$",
-                        "^[SU](MAX|MIN)_Z(I|P[mZ]Z)_[BHSD]$",
+                        "^[SU]ABD_ZP[mZ]Z_[BHSD]",
+                        "^[SU](MAX|MIN)_Z(I|P[mZ]Z)_[BHSD]",
                         "^[SU]Q(ADD|SUB)_Z(I|ZZ)_[BHSD]$",
-                        "^SUBR_Z(I|P[mZ]Z)_[BHSD]$",
+                        "^SUBR_Z(I|P[mZ]Z)_[BHSD]",
                         "^(AND|EOR|ORR)_ZI$",
-                        "^(AND|BIC|EOR|EOR(BT|TB)?|ORR)_ZZZ$",
+                        "^(AND|BIC|EOR|EOR(BT|TB)?|ORR)_ZP?ZZ",
                         "^EOR(BT|TB)_ZZZ_[BHSD]$",
-                        "^(AND|BIC|EOR|NOT|ORR)_ZPmZ_[BHSD]$")>;
+                        "^(AND|BIC|EOR|NOT|ORR)_ZPmZ_[BHSD]")>;
 
 // Arithmetic, shift
 def : InstRW<[V1Write_2c_1V1],
@@ -1394,10 +1394,10 @@ def : InstRW<[V1Write_2c_1V1],
                         "^(ASRR|LSLR|LSRR)_ZPmZ_[BHSD]")>;
 
 // Arithmetic, shift right for divide
-def : InstRW<[V1Write_4c_1V1], (instregex "^ASRD_ZP[mZ]I_[BHSD]$")>;
+def : InstRW<[V1Write_4c_1V1], (instregex "^ASRD_(ZPmI|ZPZI)_[BHSD]")>;
 
 // Count/reverse bits
-def : InstRW<[V1Write_2c_1V01], (instregex "^(CLS|CLZ|CNT|RBIT)_ZPmZ_[BHSD]$")>;
+def : InstRW<[V1Write_2c_1V01], (instregex "^(CLS|CLZ|CNT|RBIT)_ZPmZ_[BHSD]")>;
 
 // Broadcast logical bitmask immediate to vector
 def : InstRW<[V1Write_2c_1V01], (instrs DUPM_ZI)>;
@@ -1420,10 +1420,10 @@ def : InstRW<[V1Write_3c_1V0], (instregex "^[SU]CVTF_ZPmZ_Dto[HSD]",
                                           "^[SU]CVTF_ZPmZ_StoD")>;
 
 // Convert to floating point, 32b to single or half
-def : InstRW<[V1Write_4c_2V0], (instregex "^[SU]CVTF_ZPmZ_Sto[HS]$")>;
+def : InstRW<[V1Write_4c_2V0], (instregex "^[SU]CVTF_ZPmZ_Sto[HS]")>;
 
 // Convert to floating point, 16b to half
-def : InstRW<[V1Write_6c_4V0], (instregex "^[SU]CVTF_ZPmZ_HtoH$")>;
+def : InstRW<[V1Write_6c_4V0], (instregex "^[SU]CVTF_ZPmZ_HtoH")>;
 
 // Copy, scalar
 def : InstRW<[V1Write_5c_1M0_1V01], (instregex "^CPY_ZPmR_[BHSD]$")>;
@@ -1432,10 +1432,12 @@ def : InstRW<[V1Write_5c_1M0_1V01], (instregex "^CPY_ZPmR_[BHSD]$")>;
 def : InstRW<[V1Write_2c_1V01], (instregex "^CPY_ZP([mz]I|mV)_[BHSD]$")>;
 
 // Divides, 32 bit
-def : InstRW<[V1Write_12c7_1V0], (instregex "^[SU]DIVR?_ZPmZ_S$")>;
+def : InstRW<[V1Write_12c7_1V0], (instregex "^[SU]DIVR?_ZPmZ_S",
+                                             "^[SU]DIV_ZPZZ_S")>;
 
 // Divides, 64 bit
-def : InstRW<[V1Write_20c7_1V0], (instregex "^[SU]DIVR?_ZPmZ_D$")>;
+def : InstRW<[V1Write_20c7_1V0], (instregex "^[SU]DIVR?_ZPmZ_D",
+                                             "^[SU]DIV_ZPZZ_D")>;
 
 // Dot product, 8 bit
 def : InstRW<[V1Write_3c_1V01], (instregex "^[SU]DOT_ZZZI?_S$")>;
@@ -1454,9 +1456,9 @@ def : InstRW<[V1Write_2c_1V01], (instregex "^DUP_ZI_[BHSD]$",
 def : InstRW<[V1Write_3c_1M0], (instregex "^DUP_ZR_[BHSD]$")>;
 
 // Extend, sign or zero
-def : InstRW<[V1Write_2c_1V1], (instregex "^[SU]XTB_ZPmZ_[HSD]$",
-                                          "^[SU]XTH_ZPmZ_[SD]$",
-                                          "^[SU]XTW_ZPmZ_[D]$")>;
+def : InstRW<[V1Write_2c_1V1], (instregex "^[SU]XTB_ZPmZ_[HSD]",
+                                          "^[SU]XTH_ZPmZ_[SD]",
+                                          "^[SU]XTW_ZPmZ_[D]")>;
 
 // Extract
 def : InstRW<[V1Write_2c_1V01], (instrs EXT_ZZI)>;
@@ -1489,18 +1491,22 @@ def : InstRW<[V1Write_2c_1V01], (instregex "^MOVPRFX_ZP[mz]Z_[BHSD]$",
 def : InstRW<[V1Write_3c_1V01], (instrs SMMLA_ZZZ, UMMLA_ZZZ, USMMLA_ZZZ)>;
 
 // Multiply, B, H, S element size
-def : InstRW<[V1Write_4c_1V0], (instregex "^MUL_(ZI|ZPmZ)_[BHS]$",
-                                          "^[SU]MULH_(ZPmZ|ZZZ)_[BHS]$")>;
+def : InstRW<[V1Write_4c_1V0], (instregex "^MUL_(ZI|ZPmZ|ZZZI|ZZZ)_[BHS]",
+                                          "^MUL_ZPZZ_[BHS]",
+                                          "^[SU]MULH_(ZPmZ|ZZZ)_[BHS]",
+                                          "^[SU]MULH_ZPZZ_[BHS]")>;
 
 // Multiply, D element size
 // Multiply accumulate, D element size
-def : InstRW<[V1Write_5c_2V0], (instregex "^MUL_(ZI|ZPmZ)_D$",
-                                          "^[SU]MULH_ZPmZ_D$",
-                                          "^(MLA|MLS|MAD|MSB)_ZPmZZ_D$")>;
+def : InstRW<[V1Write_5c_2V0], (instregex "^MUL_(ZI|ZPmZ|ZZZI|ZZZ)_D",
+                                          "^MUL_ZPZZ_D",
+                                          "^[SU]MULH_(ZPmZ|ZZZ)_D",
+                                          "^[SU]MULH_ZPZZ_D",
+                                          "^(MLA|MLS|MAD|MSB)_(ZPmZZ|ZPZZZ)_D")>;
 
 // Multiply accumulate, B, H, S element size
 // NOTE: This is not specified in the SOG.
-def : InstRW<[V1Write_4c_1V0], (instregex "^(ML[AS]|MAD|MSB)_ZPmZZ_[BHS]")>;
+def : InstRW<[V1Write_4c_1V0], (instregex "^(ML[AS]|MAD|MSB)_(ZPmZZ|ZPZZZ)_[BHS]")>;
 
 // Predicate counting vector
 def : InstRW<[V1Write_2c_1V0], (instregex "^([SU]Q)?(DEC|INC)[HWD]_ZPiI$")>;
@@ -1547,12 +1553,17 @@ def : InstRW<[V1Write_2c_1V01], (instregex "^SEL_ZPZZ_[BHSD]$",
 // -----------------------------------------------------------------------------
 
 // Floating point absolute value/difference
+def : InstRW<[V1Write_2c_1V01], (instregex "^FAB[SD]_ZPmZ_[HSD]",
+                                           "^FABD_ZPZZ_[HSD]",
+                                           "^FABS_ZPmZ_[HSD]")>;
+
 // Floating point arithmetic
-def : InstRW<[V1Write_2c_1V01], (instregex "^FAB[SD]_ZPmZ_[HSD]$",
-                                           "^F(ADD|SUB)_(ZPm[IZ]|ZZZ)_[HSD]$",
-                                           "^FADDP_ZPmZZ_[HSD]$",
-                                           "^FNEG_ZPmZ_[HSD]$",
-                                           "^FSUBR_ZPm[IZ]_[HSD]$")>;
+def : InstRW<[V1Write_2c_1V01], (instregex "^F(ADD|SUB)_(ZPm[IZ]|ZZZ)_[HSD]",
+                                           "^F(ADD|SUB)_ZPZ[IZ]_[HSD]",
+                                           "^FADDP_ZPmZZ_[HSD]",
+                                           "^FNEG_ZPmZ_[HSD]",
+                                           "^FSUBR_ZPm[IZ]_[HSD]",
+                                           "^FSUBR_(ZPZI|ZPZZ)_[HSD]")>;
 
 // Floating point associative add, F16
 def : InstRW<[V1Write_19c_18V0], (instrs FADDA_VPZ_H)>;
@@ -1577,40 +1588,44 @@ def : InstRW<[V1Write_5c_1V01], (instregex "^FCMLA_ZPmZZ_[HSD]$",
 
 // Floating point convert, long or narrow (F16 to F32 or F32 to F16)
 // Floating point convert to integer, F32
-def : InstRW<[V1Write_4c_2V0], (instregex "^FCVT_ZPmZ_(HtoS|StoH)$",
-                                          "^FCVTZ[SU]_ZPmZ_(HtoS|StoS)$")>;
+def : InstRW<[V1Write_4c_2V0], (instregex "^FCVT_ZPmZ_(HtoS|StoH)",
+                                          "^FCVTZ[SU]_ZPmZ_(HtoS|StoS)")>;
 
 // Floating point convert, long or narrow (F16 to F64, F32 to F64, F64 to F32 or F64 to F16)
 // Floating point convert to integer, F64
-def : InstRW<[V1Write_3c_1V0], (instregex "^FCVT_ZPmZ_(HtoD|StoD|DtoS|DtoH)$",
-                                          "^FCVTZ[SU]_ZPmZ_(HtoD|StoD|DtoS|DtoD)$")>;
+def : InstRW<[V1Write_3c_1V0], (instregex "^FCVT_ZPmZ_(HtoD|StoD|DtoS|DtoH)",
+                                          "^FCVTZ[SU]_ZPmZ_(HtoD|StoD|DtoS|DtoD)")>;
 
 // Floating point convert to integer, F16
-def : InstRW<[V1Write_6c_4V0], (instregex "^FCVTZ[SU]_ZPmZ_HtoH$")>;
+def : InstRW<[V1Write_6c_4V0], (instregex "^FCVTZ[SU]_ZPmZ_HtoH")>;
 
 // Floating point copy
 def : InstRW<[V1Write_2c_1V01], (instregex "^FCPY_ZPmI_[HSD]$",
                                            "^FDUP_ZI_[HSD]$")>;
 
 // Floating point divide, F16
-def : InstRW<[V1Write_13c10_1V0], (instregex "^FDIVR?_ZPmZ_H$")>;
+def : InstRW<[V1Write_13c10_1V0], (instregex "^FDIVR?_(ZPmZ|ZPZZ)_H")>;
 
 // Floating point divide, F32
-def : InstRW<[V1Write_10c7_1V0], (instregex "^FDIVR?_ZPmZ_S$")>;
+def : InstRW<[V1Write_10c7_1V0], (instregex "^FDIVR?_(ZPmZ|ZPZZ)_S")>;
 
 // Floating point divide, F64
-def : InstRW<[V1Write_15c7_1V0], (instregex "^FDIVR?_ZPmZ_D$")>;
+def : InstRW<[V1Write_15c7_1V0], (instregex "^FDIVR?_(ZPmZ|ZPZZ)_D")>;
 
 // Floating point min/max
-def : InstRW<[V1Write_2c_1V01], (instregex "^F(MAX|MIN)(NM)?_ZPm[IZ]_[HSD]$")>;
+def : InstRW<[V1Write_2c_1V01], (instregex "^F(MAX|MIN)(NM)?_ZPm[IZ]_[HSD]",
+                                           "^F(MAX|MIN)(NM)?_ZPZ[IZ]_[HSD]")>;
 
 // Floating point multiply
-def : InstRW<[V1Write_3c_1V01], (instregex "^F(SCALE|MULX)_ZPmZ_[HSD]$",
-                                           "^FMUL_(ZPm[IZ]|ZZZI?)_[HSD]$")>;
+def : InstRW<[V1Write_3c_1V01], (instregex "^(FSCALE|FMULX)_ZPmZ_[HSD]",
+                                           "^FMULX_ZPZZ_[HSD]",
+                                           "^FMUL_(ZPm[IZ]|ZZZI?)_[HSD]",
+                                           "^FMUL_ZPZ[IZ]_[HSD]")>;
 
 // Floating point multiply accumulate
 // Floating point reciprocal step
 def : InstRW<[V1Write_4c_1V01], (instregex "^F(N?M(AD|SB)|N?ML[AS])_ZPmZZ_[HSD]$",
+                                           "^FN?ML[AS]_ZPZZZ_[HSD]",
                                            "^FML[AS]_ZZZI_[HSD]$",
                                            "^F(RECPS|RSQRTS)_ZZZ_[HSD]$")>;
 
@@ -1624,7 +1639,7 @@ def : InstRW<[V1Write_4c_2V0], (instrs FRECPE_ZZ_S, FRSQRTE_ZZ_S)>;
 def : InstRW<[V1Write_3c_1V0], (instrs FRECPE_ZZ_D, FRSQRTE_ZZ_D)>;
 
 // Floating point reciprocal exponent
-def : InstRW<[V1Write_3c_1V0], (instregex "^FRECPX_ZPmZ_[HSD]$")>;
+def : InstRW<[V1Write_3c_1V0], (instregex "^FRECPX_ZPmZ_[HSD]")>;
 
 // Floating point reduction, F16
 def : InstRW<[V1Write_13c_6V01], (instregex "^F(ADD|((MAX|MIN)(NM)?))V_VPZ_H$")>;
@@ -1636,22 +1651,22 @@ def : InstRW<[V1Write_11c_1V_5V01], (instregex "^F(ADD|((MAX|MIN)(NM)?))V_VPZ_S$
 def : InstRW<[V1Write_9c_1V_4V01], (instregex "^F(ADD|((MAX|MIN)(NM)?))V_VPZ_D$")>;
 
 // Floating point round to integral, F16
-def : InstRW<[V1Write_6c_1V0], (instregex "^FRINT[AIMNPXZ]_ZPmZ_H$")>;
+def : InstRW<[V1Write_6c_1V0], (instregex "^FRINT[AIMNPXZ]_ZPmZ_H")>;
 
 // Floating point round to integral, F32
-def : InstRW<[V1Write_4c_1V0], (instregex "^FRINT[AIMNPXZ]_ZPmZ_S$")>;
+def : InstRW<[V1Write_4c_1V0], (instregex "^FRINT[AIMNPXZ]_ZPmZ_S")>;
 
 // Floating point round to integral, F64
-def : InstRW<[V1Write_3c_1V0], (instregex "^FRINT[AIMNPXZ]_ZPmZ_D$")>;
+def : InstRW<[V1Write_3c_1V0], (instregex "^FRINT[AIMNPXZ]_ZPmZ_D")>;
 
 // Floating point square root, F16
-def : InstRW<[V1Write_13c10_1V0], (instrs FSQRT_ZPmZ_H)>;
+def : InstRW<[V1Write_13c10_1V0], (instregex "^FSQRT_ZPmZ_H")>;
 
 // Floating point square root, F32
-def : InstRW<[V1Write_10c7_1V0], (instrs FSQRT_ZPmZ_S)>;
+def : InstRW<[V1Write_10c7_1V0], (instregex "^FSQRT_ZPmZ_S")>;
 
 // Floating point square root, F64
-def : InstRW<[V1Write_16c7_1V0], (instrs FSQRT_ZPmZ_D)>;
+def : InstRW<[V1Write_16c7_1V0], (instregex "^FSQRT_ZPmZ_D")>;
 
 // Floating point trigonometric
 def : InstRW<[V1Write_3c_1V01], (instregex "^FEXPA_ZZ_[HSD]$",
diff --git a/llvm/lib/Target/AArch64/AArch64SchedNeoverseV2.td b/llvm/lib/Target/AArch64/AArch64SchedNeoverseV2.td
index 807ce40bc5eac1..f10b94523d2e03 100644
--- a/llvm/lib/Target/AArch64/AArch64SchedNeoverseV2.td
+++ b/llvm/lib/Target/AArch64/AArch64SchedNeoverseV2.td
@@ -2567,13 +2567,13 @@ def : InstRW<[V2Write_4cyc_2V02], (instregex "^FRINT[AIMNPXZ]_ZPmZ_S")>;
 def : InstRW<[V2Write_3cyc_1V02], (instregex "^FRINT[AIMNPXZ]_ZPmZ_D")>;
 
 // Floating point square root, F16
-def : InstRW<[V2Write_13cyc_1V0_12rc], (instregex "^FSQRT_ZPmZ_H", "^FSQRT_ZPmZ_H")>;
+def : InstRW<[V2Write_13cyc_1V0_12rc], (instregex "^FSQRT_ZPmZ_H")>;
 
 // Floating point square root, F32
-def : InstRW<[V2Write_10cyc_1V0_9rc], (instregex "^FSQRT_ZPmZ_S", "^FSQRT_ZPmZ_S")>;
+def : InstRW<[V2Write_10cyc_1V0_9rc], (instregex "^FSQRT_ZPmZ_S")>;
 
 // Floating point square root, F64
-def : InstRW<[V2Write_16cyc_1V0_14rc], (instregex "^FSQRT_ZPmZ_D", "^FSQRT_ZPmZ_D")>;
+def : InstRW<[V2Write_16cyc_1V0_14rc], (instregex "^FSQRT_ZPmZ_D")>;
 
 // Floating point trigonometric exponentiation
 def : InstRW<[V2Write_3cyc_1V1], (instregex "^FEXPA_ZZ_[HSD]")>;
diff --git a/llvm/unittests/Target/AArch64/AArch64SVESchedPseudoTest.cpp b/llvm/unittests/Target/AArch64/AArch64SVESchedPseudoTest.cpp
index 9d8633353e1f9f..6098d4e6239251 100644
--- a/llvm/unittests/Target/AArch64/AArch64SVESchedPseudoTest.cpp
+++ b/llvm/unittests/Target/AArch64/AArch64SVESchedPseudoTest.cpp
@@ -107,6 +107,10 @@ TEST(AArch64SVESchedPseudoTesta510, IsCorrect) {
   runSVEPseudoTestForCPU("cortex-a510");
 }
 
+TEST(AArch64SVESchedPseudoTestv1, IsCorrect) {
+  runSVEPseudoTestForCPU("neoverse-v1");
+}
+
 TEST(AArch64SVESchedPseudoTestv2, IsCorrect) {
   runSVEPseudoTestForCPU("neoverse-v2");
 }



More information about the llvm-commits mailing list