[llvm] 92a9bcc - [AArch64] Add tablegen patterns for fmla index with extract 0. (#114976)

via llvm-commits llvm-commits at lists.llvm.org
Fri Nov 8 08:18:26 PST 2024


Author: David Green
Date: 2024-11-08T16:18:22Z
New Revision: 92a9bcc84d435ce28d59e7b07e2fb83a7f6bca63

URL: https://github.com/llvm/llvm-project/commit/92a9bcc84d435ce28d59e7b07e2fb83a7f6bca63
DIFF: https://github.com/llvm/llvm-project/commit/92a9bcc84d435ce28d59e7b07e2fb83a7f6bca63.diff

LOG: [AArch64] Add tablegen patterns for fmla index with extract 0. (#114976)

We have tablegen patterns that produce an indexed `fmla s0, s1, v2.s[2]`
from
  `fma extract(Rn, lane), Rm, Ra -> fmla`
For lane == 0, however, we prefer the simpler `fmadd s0, s1, s2`, so we
also have patterns for
  `fma extract(Rn, 0), Rm, Ra -> fmadd`

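The problem arises when both multiplicands are extracts: tablegen then
prefers the second (lane 0) pattern because it looks more specialized. As a
result, the non-zero lane is moved out with a separate `mov` and an `fmadd`
is emitted instead of a single indexed `fmla`. This patch adds patterns to
catch this case: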
  `fma extract(Rn, index), extract(Rm, 0), Ra -> fmla`
To make sure the simpler fmadd is still selected when both operands are
extracted from lane 0, we also add patterns for that case:
  `fma extract(Rn, 0), extract(Rm, 0), Ra -> fmadd`

Added: 
    

Modified: 
    llvm/lib/Target/AArch64/AArch64InstrFormats.td
    llvm/test/CodeGen/AArch64/complex-deinterleaving-f16-mul.ll
    llvm/test/CodeGen/AArch64/fp16_intrinsic_lane.ll
    llvm/test/CodeGen/AArch64/neon-scalar-by-elem-fma.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/AArch64/AArch64InstrFormats.td b/llvm/lib/Target/AArch64/AArch64InstrFormats.td
index e44caef686be29..b5f6388ea00285 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrFormats.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrFormats.td
@@ -5821,6 +5821,13 @@ multiclass ThreeOperandFPData<bit isNegated, bit isSub,string asm,
                        (f16 FPR16:$Ra))),
             (!cast<Instruction>(NAME # Hrrr)
               (f16 (EXTRACT_SUBREG V128:$Rn, hsub)), FPR16:$Rm, FPR16:$Ra)>;
+
+  def : Pat<(f16 (node (f16 (extractelt (v8f16 V128:$Rn), (i64 0))),
+                       (f16 (extractelt (v8f16 V128:$Rm), (i64 0))),
+                       (f16 FPR16:$Ra))),
+            (!cast<Instruction>(NAME # Hrrr)
+              (f16 (EXTRACT_SUBREG V128:$Rn, hsub)),
+              (f16 (EXTRACT_SUBREG V128:$Rm, hsub)), FPR16:$Ra)>;
   }
 
   def : Pat<(f32 (node (f32 FPR32:$Rn),
@@ -5835,6 +5842,13 @@ multiclass ThreeOperandFPData<bit isNegated, bit isSub,string asm,
             (!cast<Instruction>(NAME # Srrr)
               (EXTRACT_SUBREG V128:$Rn, ssub), FPR32:$Rm, FPR32:$Ra)>;
 
+  def : Pat<(f32 (node (f32 (extractelt (v4f32 V128:$Rn), (i64 0))),
+                       (f32 (extractelt (v4f32 V128:$Rm), (i64 0))),
+                       (f32 FPR32:$Ra))),
+            (!cast<Instruction>(NAME # Srrr)
+              (EXTRACT_SUBREG V128:$Rn, ssub),
+              (EXTRACT_SUBREG V128:$Rm, ssub), FPR32:$Ra)>;
+
   def : Pat<(f64 (node (f64 FPR64:$Rn),
                        (f64 (extractelt (v2f64 V128:$Rm), (i64 0))),
                        (f64 FPR64:$Ra))),
@@ -5846,6 +5860,13 @@ multiclass ThreeOperandFPData<bit isNegated, bit isSub,string asm,
                        (f64 FPR64:$Ra))),
             (!cast<Instruction>(NAME # Drrr)
               (EXTRACT_SUBREG V128:$Rn, dsub), FPR64:$Rm, FPR64:$Ra)>;
+
+  def : Pat<(f64 (node (f64 (extractelt (v2f64 V128:$Rn), (i64 0))),
+                       (f64 (extractelt (v2f64 V128:$Rm), (i64 0))),
+                       (f64 FPR64:$Ra))),
+            (!cast<Instruction>(NAME # Drrr)
+              (EXTRACT_SUBREG V128:$Rn, dsub),
+              (EXTRACT_SUBREG V128:$Rm, dsub), FPR64:$Ra)>;
 }
 
 //---
@@ -9282,6 +9303,11 @@ multiclass SIMDFPIndexedTiedPatterns<string INST, SDPatternOperator OpNode> {
                          (vector_extract (v8f16 V128_lo:$Rm), VectorIndexH:$idx))),
             (!cast<Instruction>(INST # "v1i16_indexed") FPR16:$Rd, FPR16:$Rn,
                 V128_lo:$Rm, VectorIndexH:$idx)>;
+  def : Pat<(f16 (OpNode (f16 FPR16:$Rd),
+                         (vector_extract (v8f16 V128:$Rn), (i64 0)),
+                         (vector_extract (v8f16 V128_lo:$Rm), VectorIndexH:$idx))),
+            (!cast<Instruction>(INST # "v1i16_indexed") FPR16:$Rd,
+                (f16 (EXTRACT_SUBREG V128:$Rn, hsub)), V128_lo:$Rm, VectorIndexH:$idx)>;
   } // Predicates = [HasNEON, HasFullFP16]
 
   // 2 variants for the .2s version: DUPLANE from 128-bit and DUP scalar.
@@ -9323,12 +9349,22 @@ multiclass SIMDFPIndexedTiedPatterns<string INST, SDPatternOperator OpNode> {
                          (vector_extract (v4f32 V128:$Rm), VectorIndexS:$idx))),
             (!cast<Instruction>(INST # "v1i32_indexed") FPR32:$Rd, FPR32:$Rn,
                 V128:$Rm, VectorIndexS:$idx)>;
+  def : Pat<(f32 (OpNode (f32 FPR32:$Rd),
+                         (vector_extract (v4f32 V128:$Rn), (i64 0)),
+                         (vector_extract (v4f32 V128:$Rm), VectorIndexS:$idx))),
+            (!cast<Instruction>(INST # "v1i32_indexed") FPR32:$Rd,
+                (f32 (EXTRACT_SUBREG V128:$Rn, ssub)), V128:$Rm, VectorIndexS:$idx)>;
 
   // 1 variant for 64-bit scalar version: extract from .1d or from .2d
   def : Pat<(f64 (OpNode (f64 FPR64:$Rd), (f64 FPR64:$Rn),
                          (vector_extract (v2f64 V128:$Rm), VectorIndexD:$idx))),
             (!cast<Instruction>(INST # "v1i64_indexed") FPR64:$Rd, FPR64:$Rn,
                 V128:$Rm, VectorIndexD:$idx)>;
+  def : Pat<(f64 (OpNode (f64 FPR64:$Rd),
+                         (vector_extract (v2f64 V128:$Rn), (i64 0)),
+                         (vector_extract (v2f64 V128:$Rm), VectorIndexD:$idx))),
+            (!cast<Instruction>(INST # "v1i64_indexed") FPR64:$Rd,
+                (f64 (EXTRACT_SUBREG V128:$Rn, dsub)), V128:$Rm, VectorIndexD:$idx)>;
 }
 
 let mayRaiseFPException = 1, Uses = [FPCR] in

diff  --git a/llvm/test/CodeGen/AArch64/complex-deinterleaving-f16-mul.ll b/llvm/test/CodeGen/AArch64/complex-deinterleaving-f16-mul.ll
index fbe913e5472cc2..afcdb76067f433 100644
--- a/llvm/test/CodeGen/AArch64/complex-deinterleaving-f16-mul.ll
+++ b/llvm/test/CodeGen/AArch64/complex-deinterleaving-f16-mul.ll
@@ -11,10 +11,10 @@ define <2 x half> @complex_mul_v2f16(<2 x half> %a, <2 x half> %b) {
 ; CHECK-NEXT:    mov h2, v0.h[1]
 ; CHECK-NEXT:    // kill: def $d1 killed $d1 def $q1
 ; CHECK-NEXT:    fmul h3, h0, v1.h[1]
-; CHECK-NEXT:    fmul h4, h2, v1.h[1]
-; CHECK-NEXT:    fmadd h2, h1, h2, h3
-; CHECK-NEXT:    fnmsub h0, h1, h0, h4
-; CHECK-NEXT:    mov v0.h[1], v2.h[0]
+; CHECK-NEXT:    fmul h2, h2, v1.h[1]
+; CHECK-NEXT:    fmla h3, h1, v0.h[1]
+; CHECK-NEXT:    fnmsub h0, h1, h0, h2
+; CHECK-NEXT:    mov v0.h[1], v3.h[0]
 ; CHECK-NEXT:    // kill: def $d0 killed $d0 killed $q0
 ; CHECK-NEXT:    ret
 entry:

diff  --git a/llvm/test/CodeGen/AArch64/fp16_intrinsic_lane.ll b/llvm/test/CodeGen/AArch64/fp16_intrinsic_lane.ll
index 725c44c9788988..368683e2b93af4 100644
--- a/llvm/test/CodeGen/AArch64/fp16_intrinsic_lane.ll
+++ b/llvm/test/CodeGen/AArch64/fp16_intrinsic_lane.ll
@@ -120,8 +120,7 @@ define half @t_vfmah_lane_f16_3_0(half %a, <4 x half> %c) {
 ; CHECK-LABEL: t_vfmah_lane_f16_3_0:
 ; CHECK:       // %bb.0: // %entry
 ; CHECK-NEXT:    // kill: def $d1 killed $d1 def $q1
-; CHECK-NEXT:    mov h2, v1.h[3]
-; CHECK-NEXT:    fmadd h0, h1, h2, h0
+; CHECK-NEXT:    fmla h0, h1, v1.h[3]
 ; CHECK-NEXT:    ret
 entry:
   %b = extractelement <4 x half> %c, i32 0
@@ -310,8 +309,7 @@ define half @t_vfmsh_lane_f16_0_3(half %a, <4 x half> %c, i32 %lane) {
 ; CHECK-LABEL: t_vfmsh_lane_f16_0_3:
 ; CHECK:       // %bb.0: // %entry
 ; CHECK-NEXT:    // kill: def $d1 killed $d1 def $q1
-; CHECK-NEXT:    mov h2, v1.h[3]
-; CHECK-NEXT:    fmsub h0, h2, h1, h0
+; CHECK-NEXT:    fmls h0, h1, v1.h[3]
 ; CHECK-NEXT:    ret
 entry:
   %b = extractelement <4 x half> %c, i32 0

diff  --git a/llvm/test/CodeGen/AArch64/neon-scalar-by-elem-fma.ll b/llvm/test/CodeGen/AArch64/neon-scalar-by-elem-fma.ll
index b2ea6ff200be1d..544d7680f01b80 100644
--- a/llvm/test/CodeGen/AArch64/neon-scalar-by-elem-fma.ll
+++ b/llvm/test/CodeGen/AArch64/neon-scalar-by-elem-fma.ll
@@ -84,8 +84,7 @@ define float @test_fmla_ss2S_1(float %a, float %b, <2 x float> %v) {
 define float @test_fmla_ss4S_3_ext0(float %a, <4 x float> %v) {
 ; CHECK-LABEL: test_fmla_ss4S_3_ext0:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    mov s2, v1.s[3]
-; CHECK-NEXT:    fmadd s0, s1, s2, s0
+; CHECK-NEXT:    fmla s0, s1, v1.s[3]
 ; CHECK-NEXT:    ret
   %tmp0 = extractelement <4 x float> %v, i32 0
   %tmp1 = extractelement <4 x float> %v, i32 3
@@ -96,8 +95,7 @@ define float @test_fmla_ss4S_3_ext0(float %a, <4 x float> %v) {
 define float @test_fmla_ss4S_3_ext0_swp(float %a, <4 x float> %v) {
 ; CHECK-LABEL: test_fmla_ss4S_3_ext0_swp:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    mov s2, v1.s[3]
-; CHECK-NEXT:    fmadd s0, s2, s1, s0
+; CHECK-NEXT:    fmla s0, s1, v1.s[3]
 ; CHECK-NEXT:    ret
   %tmp0 = extractelement <4 x float> %v, i32 0
   %tmp1 = extractelement <4 x float> %v, i32 3
@@ -120,8 +118,7 @@ define float @test_fmla_ss2S_3_ext0(float %a, <2 x float> %v) {
 ; CHECK-LABEL: test_fmla_ss2S_3_ext0:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    // kill: def $d1 killed $d1 def $q1
-; CHECK-NEXT:    mov s2, v1.s[1]
-; CHECK-NEXT:    fmadd s0, s1, s2, s0
+; CHECK-NEXT:    fmla s0, s1, v1.s[1]
 ; CHECK-NEXT:    ret
   %tmp0 = extractelement <2 x float> %v, i32 0
   %tmp1 = extractelement <2 x float> %v, i32 1
@@ -133,8 +130,7 @@ define float @test_fmla_ss2S_3_ext0_swp(float %a, <2 x float> %v) {
 ; CHECK-LABEL: test_fmla_ss2S_3_ext0_swp:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    // kill: def $d1 killed $d1 def $q1
-; CHECK-NEXT:    mov s2, v1.s[1]
-; CHECK-NEXT:    fmadd s0, s2, s1, s0
+; CHECK-NEXT:    fmla s0, s1, v1.s[1]
 ; CHECK-NEXT:    ret
   %tmp0 = extractelement <2 x float> %v, i32 0
   %tmp1 = extractelement <2 x float> %v, i32 1
@@ -218,8 +214,7 @@ define double @test_fmla_dd2D_1_swap(double %a, double %b, <2 x double> %v) {
 define double @test_fmla_ss2D_1_ext0(double %a, <2 x double> %v) {
 ; CHECK-LABEL: test_fmla_ss2D_1_ext0:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    mov d2, v1.d[1]
-; CHECK-NEXT:    fmadd d0, d1, d2, d0
+; CHECK-NEXT:    fmla d0, d1, v1.d[1]
 ; CHECK-NEXT:    ret
   %tmp0 = extractelement <2 x double> %v, i32 0
   %tmp1 = extractelement <2 x double> %v, i32 1
@@ -230,8 +225,7 @@ define double @test_fmla_ss2D_1_ext0(double %a, <2 x double> %v) {
 define double @test_fmla_ss2D_1_ext0_swp(double %a, <2 x double> %v) {
 ; CHECK-LABEL: test_fmla_ss2D_1_ext0_swp:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    mov d2, v1.d[1]
-; CHECK-NEXT:    fmadd d0, d2, d1, d0
+; CHECK-NEXT:    fmla d0, d1, v1.d[1]
 ; CHECK-NEXT:    ret
   %tmp0 = extractelement <2 x double> %v, i32 0
   %tmp1 = extractelement <2 x double> %v, i32 1
@@ -340,8 +334,7 @@ define float @test_fmls_ss2S_1(float %a, float %b, <2 x float> %v) {
 define float @test_fmls_ss4S_3_ext0(float %a, <4 x float> %v) {
 ; CHECK-LABEL: test_fmls_ss4S_3_ext0:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    mov s2, v1.s[3]
-; CHECK-NEXT:    fmsub s0, s1, s2, s0
+; CHECK-NEXT:    fmls s0, s1, v1.s[3]
 ; CHECK-NEXT:    ret
   %tmp0 = extractelement <4 x float> %v, i32 0
   %tmp1 = extractelement <4 x float> %v, i32 3
@@ -437,8 +430,7 @@ define double @test_fmls_dd2D_1_swap(double %a, double %b, <2 x double> %v) {
 define double @test_fmls_dd2D_1_ext0(double %a, <2 x double> %v) {
 ; CHECK-LABEL: test_fmls_dd2D_1_ext0:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    mov d2, v1.d[1]
-; CHECK-NEXT:    fmsub d0, d1, d2, d0
+; CHECK-NEXT:    fmls d0, d1, v1.d[1]
 ; CHECK-NEXT:    ret
   %tmp0 = extractelement <2 x double> %v, i32 0
   %tmp1 = extractelement <2 x double> %v, i32 1


        


More information about the llvm-commits mailing list