[llvm] [AArch64] Add tablegen patterns for fmla index with extract 0. (PR #114976)

David Green via llvm-commits llvm-commits at lists.llvm.org
Tue Nov 5 04:42:00 PST 2024


https://github.com/davemgreen created https://github.com/llvm/llvm-project/pull/114976

We have tablegen patterns to produce an indexed `fmla s0, s1, v2.s[2]` from
  `fma extract(Rn, lane), Rm, Ra -> fmla`
But for the case of lane==0, we want to prefer the simple `fmadd s0, s1, s2`. So we have patterns for
  `fma extract(Rn, 0), Rm, Ra -> fmadd`

The problem arises when we have two extracts, as tablegen starts to prefer the second pattern, as it looks more specialized. This patch adds additional patterns to catch this case:
  `fma extract(Rn, index), extract(Rm, 0), Ra -> fmla`
To make sure the simpler fmadd keeps being used when both lanes are extracted from lane 0 we need to add patterns for that case too:
  `fma extract(Rn, 0), extract(Rm, 0), Ra -> fmadd`

>From 2d2c95a9c4a34416c93c09d5df75a87d89699c29 Mon Sep 17 00:00:00 2001
From: David Green <david.green at arm.com>
Date: Tue, 5 Nov 2024 12:05:25 +0000
Subject: [PATCH] [AArch64] Add tablegen patterns for fmla index with extract
 0.

We have tablegen patterns to produce an indexed `fmla s0, s1, v2.s[2]` from
  fma extract(Rn, lane), Rm, Ra -> fmla
But for the case of lane==0, we want to prefer the simple `fmadd s0, s1, s2`. So
we have patterns for
  fma extract(Rn, 0), Rm, Ra -> fmadd

The problem arises when we have two extracts, as tablegen starts to prefer the
second pattern, as it looks more specialized. This patch adds additional patterns
to catch this case:
  fma extract(Rn, index), extract(Rm, 0), Ra -> fmla
To make sure the simpler fmadd keeps being used when both lanes are extracted
from lane 0 we need to add patterns for that case too:
  fma extract(Rn, 0), extract(Rm, 0), Ra -> fmadd
---
 .../lib/Target/AArch64/AArch64InstrFormats.td | 36 +++++++++++++++++++
 .../AArch64/complex-deinterleaving-f16-mul.ll |  8 ++---
 .../CodeGen/AArch64/fp16_intrinsic_lane.ll    |  6 ++--
 .../AArch64/neon-scalar-by-elem-fma.ll        | 24 +++++--------
 4 files changed, 50 insertions(+), 24 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64InstrFormats.td b/llvm/lib/Target/AArch64/AArch64InstrFormats.td
index e44caef686be29..b5f6388ea00285 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrFormats.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrFormats.td
@@ -5821,6 +5821,13 @@ multiclass ThreeOperandFPData<bit isNegated, bit isSub,string asm,
                        (f16 FPR16:$Ra))),
             (!cast<Instruction>(NAME # Hrrr)
               (f16 (EXTRACT_SUBREG V128:$Rn, hsub)), FPR16:$Rm, FPR16:$Ra)>;
+
+  def : Pat<(f16 (node (f16 (extractelt (v8f16 V128:$Rn), (i64 0))),
+                       (f16 (extractelt (v8f16 V128:$Rm), (i64 0))),
+                       (f16 FPR16:$Ra))),
+            (!cast<Instruction>(NAME # Hrrr)
+              (f16 (EXTRACT_SUBREG V128:$Rn, hsub)),
+              (f16 (EXTRACT_SUBREG V128:$Rm, hsub)), FPR16:$Ra)>;
   }
 
   def : Pat<(f32 (node (f32 FPR32:$Rn),
@@ -5835,6 +5842,13 @@ multiclass ThreeOperandFPData<bit isNegated, bit isSub,string asm,
             (!cast<Instruction>(NAME # Srrr)
               (EXTRACT_SUBREG V128:$Rn, ssub), FPR32:$Rm, FPR32:$Ra)>;
 
+  def : Pat<(f32 (node (f32 (extractelt (v4f32 V128:$Rn), (i64 0))),
+                       (f32 (extractelt (v4f32 V128:$Rm), (i64 0))),
+                       (f32 FPR32:$Ra))),
+            (!cast<Instruction>(NAME # Srrr)
+              (EXTRACT_SUBREG V128:$Rn, ssub),
+              (EXTRACT_SUBREG V128:$Rm, ssub), FPR32:$Ra)>;
+
   def : Pat<(f64 (node (f64 FPR64:$Rn),
                        (f64 (extractelt (v2f64 V128:$Rm), (i64 0))),
                        (f64 FPR64:$Ra))),
@@ -5846,6 +5860,13 @@ multiclass ThreeOperandFPData<bit isNegated, bit isSub,string asm,
                        (f64 FPR64:$Ra))),
             (!cast<Instruction>(NAME # Drrr)
               (EXTRACT_SUBREG V128:$Rn, dsub), FPR64:$Rm, FPR64:$Ra)>;
+
+  def : Pat<(f64 (node (f64 (extractelt (v2f64 V128:$Rn), (i64 0))),
+                       (f64 (extractelt (v2f64 V128:$Rm), (i64 0))),
+                       (f64 FPR64:$Ra))),
+            (!cast<Instruction>(NAME # Drrr)
+              (EXTRACT_SUBREG V128:$Rn, dsub),
+              (EXTRACT_SUBREG V128:$Rm, dsub), FPR64:$Ra)>;
 }
 
 //---
@@ -9282,6 +9303,11 @@ multiclass SIMDFPIndexedTiedPatterns<string INST, SDPatternOperator OpNode> {
                          (vector_extract (v8f16 V128_lo:$Rm), VectorIndexH:$idx))),
             (!cast<Instruction>(INST # "v1i16_indexed") FPR16:$Rd, FPR16:$Rn,
                 V128_lo:$Rm, VectorIndexH:$idx)>;
+  def : Pat<(f16 (OpNode (f16 FPR16:$Rd),
+                         (vector_extract (v8f16 V128:$Rn), (i64 0)),
+                         (vector_extract (v8f16 V128_lo:$Rm), VectorIndexH:$idx))),
+            (!cast<Instruction>(INST # "v1i16_indexed") FPR16:$Rd,
+                (f16 (EXTRACT_SUBREG V128:$Rn, hsub)), V128_lo:$Rm, VectorIndexH:$idx)>;
   } // Predicates = [HasNEON, HasFullFP16]
 
   // 2 variants for the .2s version: DUPLANE from 128-bit and DUP scalar.
@@ -9323,12 +9349,22 @@ multiclass SIMDFPIndexedTiedPatterns<string INST, SDPatternOperator OpNode> {
                          (vector_extract (v4f32 V128:$Rm), VectorIndexS:$idx))),
             (!cast<Instruction>(INST # "v1i32_indexed") FPR32:$Rd, FPR32:$Rn,
                 V128:$Rm, VectorIndexS:$idx)>;
+  def : Pat<(f32 (OpNode (f32 FPR32:$Rd),
+                         (vector_extract (v4f32 V128:$Rn), (i64 0)),
+                         (vector_extract (v4f32 V128:$Rm), VectorIndexS:$idx))),
+            (!cast<Instruction>(INST # "v1i32_indexed") FPR32:$Rd,
+                (f32 (EXTRACT_SUBREG V128:$Rn, ssub)), V128:$Rm, VectorIndexS:$idx)>;
 
   // 1 variant for 64-bit scalar version: extract from .1d or from .2d
   def : Pat<(f64 (OpNode (f64 FPR64:$Rd), (f64 FPR64:$Rn),
                          (vector_extract (v2f64 V128:$Rm), VectorIndexD:$idx))),
             (!cast<Instruction>(INST # "v1i64_indexed") FPR64:$Rd, FPR64:$Rn,
                 V128:$Rm, VectorIndexD:$idx)>;
+  def : Pat<(f64 (OpNode (f64 FPR64:$Rd),
+                         (vector_extract (v2f64 V128:$Rn), (i64 0)),
+                         (vector_extract (v2f64 V128:$Rm), VectorIndexD:$idx))),
+            (!cast<Instruction>(INST # "v1i64_indexed") FPR64:$Rd,
+                (f64 (EXTRACT_SUBREG V128:$Rn, dsub)), V128:$Rm, VectorIndexD:$idx)>;
 }
 
 let mayRaiseFPException = 1, Uses = [FPCR] in
diff --git a/llvm/test/CodeGen/AArch64/complex-deinterleaving-f16-mul.ll b/llvm/test/CodeGen/AArch64/complex-deinterleaving-f16-mul.ll
index fbe913e5472cc2..afcdb76067f433 100644
--- a/llvm/test/CodeGen/AArch64/complex-deinterleaving-f16-mul.ll
+++ b/llvm/test/CodeGen/AArch64/complex-deinterleaving-f16-mul.ll
@@ -11,10 +11,10 @@ define <2 x half> @complex_mul_v2f16(<2 x half> %a, <2 x half> %b) {
 ; CHECK-NEXT:    mov h2, v0.h[1]
 ; CHECK-NEXT:    // kill: def $d1 killed $d1 def $q1
 ; CHECK-NEXT:    fmul h3, h0, v1.h[1]
-; CHECK-NEXT:    fmul h4, h2, v1.h[1]
-; CHECK-NEXT:    fmadd h2, h1, h2, h3
-; CHECK-NEXT:    fnmsub h0, h1, h0, h4
-; CHECK-NEXT:    mov v0.h[1], v2.h[0]
+; CHECK-NEXT:    fmul h2, h2, v1.h[1]
+; CHECK-NEXT:    fmla h3, h1, v0.h[1]
+; CHECK-NEXT:    fnmsub h0, h1, h0, h2
+; CHECK-NEXT:    mov v0.h[1], v3.h[0]
 ; CHECK-NEXT:    // kill: def $d0 killed $d0 killed $q0
 ; CHECK-NEXT:    ret
 entry:
diff --git a/llvm/test/CodeGen/AArch64/fp16_intrinsic_lane.ll b/llvm/test/CodeGen/AArch64/fp16_intrinsic_lane.ll
index 725c44c9788988..368683e2b93af4 100644
--- a/llvm/test/CodeGen/AArch64/fp16_intrinsic_lane.ll
+++ b/llvm/test/CodeGen/AArch64/fp16_intrinsic_lane.ll
@@ -120,8 +120,7 @@ define half @t_vfmah_lane_f16_3_0(half %a, <4 x half> %c) {
 ; CHECK-LABEL: t_vfmah_lane_f16_3_0:
 ; CHECK:       // %bb.0: // %entry
 ; CHECK-NEXT:    // kill: def $d1 killed $d1 def $q1
-; CHECK-NEXT:    mov h2, v1.h[3]
-; CHECK-NEXT:    fmadd h0, h1, h2, h0
+; CHECK-NEXT:    fmla h0, h1, v1.h[3]
 ; CHECK-NEXT:    ret
 entry:
   %b = extractelement <4 x half> %c, i32 0
@@ -310,8 +309,7 @@ define half @t_vfmsh_lane_f16_0_3(half %a, <4 x half> %c, i32 %lane) {
 ; CHECK-LABEL: t_vfmsh_lane_f16_0_3:
 ; CHECK:       // %bb.0: // %entry
 ; CHECK-NEXT:    // kill: def $d1 killed $d1 def $q1
-; CHECK-NEXT:    mov h2, v1.h[3]
-; CHECK-NEXT:    fmsub h0, h2, h1, h0
+; CHECK-NEXT:    fmls h0, h1, v1.h[3]
 ; CHECK-NEXT:    ret
 entry:
   %b = extractelement <4 x half> %c, i32 0
diff --git a/llvm/test/CodeGen/AArch64/neon-scalar-by-elem-fma.ll b/llvm/test/CodeGen/AArch64/neon-scalar-by-elem-fma.ll
index b2ea6ff200be1d..544d7680f01b80 100644
--- a/llvm/test/CodeGen/AArch64/neon-scalar-by-elem-fma.ll
+++ b/llvm/test/CodeGen/AArch64/neon-scalar-by-elem-fma.ll
@@ -84,8 +84,7 @@ define float @test_fmla_ss2S_1(float %a, float %b, <2 x float> %v) {
 define float @test_fmla_ss4S_3_ext0(float %a, <4 x float> %v) {
 ; CHECK-LABEL: test_fmla_ss4S_3_ext0:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    mov s2, v1.s[3]
-; CHECK-NEXT:    fmadd s0, s1, s2, s0
+; CHECK-NEXT:    fmla s0, s1, v1.s[3]
 ; CHECK-NEXT:    ret
   %tmp0 = extractelement <4 x float> %v, i32 0
   %tmp1 = extractelement <4 x float> %v, i32 3
@@ -96,8 +95,7 @@ define float @test_fmla_ss4S_3_ext0(float %a, <4 x float> %v) {
 define float @test_fmla_ss4S_3_ext0_swp(float %a, <4 x float> %v) {
 ; CHECK-LABEL: test_fmla_ss4S_3_ext0_swp:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    mov s2, v1.s[3]
-; CHECK-NEXT:    fmadd s0, s2, s1, s0
+; CHECK-NEXT:    fmla s0, s1, v1.s[3]
 ; CHECK-NEXT:    ret
   %tmp0 = extractelement <4 x float> %v, i32 0
   %tmp1 = extractelement <4 x float> %v, i32 3
@@ -120,8 +118,7 @@ define float @test_fmla_ss2S_3_ext0(float %a, <2 x float> %v) {
 ; CHECK-LABEL: test_fmla_ss2S_3_ext0:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    // kill: def $d1 killed $d1 def $q1
-; CHECK-NEXT:    mov s2, v1.s[1]
-; CHECK-NEXT:    fmadd s0, s1, s2, s0
+; CHECK-NEXT:    fmla s0, s1, v1.s[1]
 ; CHECK-NEXT:    ret
   %tmp0 = extractelement <2 x float> %v, i32 0
   %tmp1 = extractelement <2 x float> %v, i32 1
@@ -133,8 +130,7 @@ define float @test_fmla_ss2S_3_ext0_swp(float %a, <2 x float> %v) {
 ; CHECK-LABEL: test_fmla_ss2S_3_ext0_swp:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    // kill: def $d1 killed $d1 def $q1
-; CHECK-NEXT:    mov s2, v1.s[1]
-; CHECK-NEXT:    fmadd s0, s2, s1, s0
+; CHECK-NEXT:    fmla s0, s1, v1.s[1]
 ; CHECK-NEXT:    ret
   %tmp0 = extractelement <2 x float> %v, i32 0
   %tmp1 = extractelement <2 x float> %v, i32 1
@@ -218,8 +214,7 @@ define double @test_fmla_dd2D_1_swap(double %a, double %b, <2 x double> %v) {
 define double @test_fmla_ss2D_1_ext0(double %a, <2 x double> %v) {
 ; CHECK-LABEL: test_fmla_ss2D_1_ext0:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    mov d2, v1.d[1]
-; CHECK-NEXT:    fmadd d0, d1, d2, d0
+; CHECK-NEXT:    fmla d0, d1, v1.d[1]
 ; CHECK-NEXT:    ret
   %tmp0 = extractelement <2 x double> %v, i32 0
   %tmp1 = extractelement <2 x double> %v, i32 1
@@ -230,8 +225,7 @@ define double @test_fmla_ss2D_1_ext0(double %a, <2 x double> %v) {
 define double @test_fmla_ss2D_1_ext0_swp(double %a, <2 x double> %v) {
 ; CHECK-LABEL: test_fmla_ss2D_1_ext0_swp:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    mov d2, v1.d[1]
-; CHECK-NEXT:    fmadd d0, d2, d1, d0
+; CHECK-NEXT:    fmla d0, d1, v1.d[1]
 ; CHECK-NEXT:    ret
   %tmp0 = extractelement <2 x double> %v, i32 0
   %tmp1 = extractelement <2 x double> %v, i32 1
@@ -340,8 +334,7 @@ define float @test_fmls_ss2S_1(float %a, float %b, <2 x float> %v) {
 define float @test_fmls_ss4S_3_ext0(float %a, <4 x float> %v) {
 ; CHECK-LABEL: test_fmls_ss4S_3_ext0:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    mov s2, v1.s[3]
-; CHECK-NEXT:    fmsub s0, s1, s2, s0
+; CHECK-NEXT:    fmls s0, s1, v1.s[3]
 ; CHECK-NEXT:    ret
   %tmp0 = extractelement <4 x float> %v, i32 0
   %tmp1 = extractelement <4 x float> %v, i32 3
@@ -437,8 +430,7 @@ define double @test_fmls_dd2D_1_swap(double %a, double %b, <2 x double> %v) {
 define double @test_fmls_dd2D_1_ext0(double %a, <2 x double> %v) {
 ; CHECK-LABEL: test_fmls_dd2D_1_ext0:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    mov d2, v1.d[1]
-; CHECK-NEXT:    fmsub d0, d1, d2, d0
+; CHECK-NEXT:    fmls d0, d1, v1.d[1]
 ; CHECK-NEXT:    ret
   %tmp0 = extractelement <2 x double> %v, i32 0
   %tmp1 = extractelement <2 x double> %v, i32 1



More information about the llvm-commits mailing list