[llvm] 61a238e - [RISCV] Add isel patterns for fixed vector fmsub/fnmadd/fnmsub.

Craig Topper via llvm-commits llvm-commits at lists.llvm.org
Tue Feb 16 12:04:57 PST 2021


Author: Craig Topper
Date: 2021-02-16T12:03:33-08:00
New Revision: 61a238e6e134ca610d5534bb2ff7121f3ce0083b

URL: https://github.com/llvm/llvm-project/commit/61a238e6e134ca610d5534bb2ff7121f3ce0083b
DIFF: https://github.com/llvm/llvm-project/commit/61a238e6e134ca610d5534bb2ff7121f3ce0083b.diff

LOG: [RISCV] Add isel patterns for fixed vector fmsub/fnmadd/fnmsub.

Added: 
    

Modified: 
    llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
    llvm/test/CodeGen/RISCV/rvv/fixed-vectors-fp.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
index 7eb569d17ca2..739b3f71e4bf 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
@@ -92,7 +92,7 @@ def SDT_RISCVVecFMA_VL : SDTypeProfile<1, 5, [SDTCisSameAs<0, 1>,
                                               SDTCVecEltisVT<4, i1>,
                                               SDTCisSameNumEltsAs<0, 4>,
                                               SDTCisVT<5, XLenVT>]>;
-def riscv_fma_vl : SDNode<"RISCVISD::FMA_VL", SDT_RISCVVecFMA_VL>;
+def riscv_fma_vl : SDNode<"RISCVISD::FMA_VL", SDT_RISCVVecFMA_VL, [SDNPCommutative]>;
 
 def riscv_setcc_vl : SDNode<"RISCVISD::SETCC_VL",
                             SDTypeProfile<1, 5, [SDTCVecEltisVT<0, i1>,
@@ -472,12 +472,42 @@ foreach vti = AllFloatVectors in {
   // NOTE: We choose VFMADD because it has the most commuting freedom. So it
   // works best with how TwoAddressInstructionPass tries commuting.
   defvar suffix = vti.LMul.MX # "_COMMUTABLE";
-  def : Pat<(vti.Vector (riscv_fma_vl vti.RegClass:$rd, vti.RegClass:$rs1,
+  def : Pat<(vti.Vector (riscv_fma_vl vti.RegClass:$rs1, vti.RegClass:$rd,
                                       vti.RegClass:$rs2, (vti.Mask true_mask),
                                       (XLenVT (VLOp GPR:$vl)))),
             (!cast<Instruction>("PseudoVFMADD_VV_"# suffix)
                  vti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2,
                  GPR:$vl, vti.SEW)>;
+  def : Pat<(vti.Vector (riscv_fma_vl vti.RegClass:$rs1, vti.RegClass:$rd,
+                                      (riscv_fneg_vl vti.RegClass:$rs2,
+                                                     (vti.Mask true_mask),
+                                                     (XLenVT (VLOp GPR:$vl))),
+                                      (vti.Mask true_mask),
+                                      (XLenVT (VLOp GPR:$vl)))),
+            (!cast<Instruction>("PseudoVFMSUB_VV_"# suffix)
+                 vti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2,
+                 GPR:$vl, vti.SEW)>;
+  def : Pat<(vti.Vector (riscv_fma_vl (riscv_fneg_vl vti.RegClass:$rs1,
+                                                     (vti.Mask true_mask),
+                                                     (XLenVT (VLOp GPR:$vl))),
+                                      vti.RegClass:$rd,
+                                      (riscv_fneg_vl vti.RegClass:$rs2,
+                                                     (vti.Mask true_mask),
+                                                     (XLenVT (VLOp GPR:$vl))),
+                                      (vti.Mask true_mask),
+                                      (XLenVT (VLOp GPR:$vl)))),
+            (!cast<Instruction>("PseudoVFNMADD_VV_"# suffix)
+                 vti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2,
+                 GPR:$vl, vti.SEW)>;
+  def : Pat<(vti.Vector (riscv_fma_vl (riscv_fneg_vl vti.RegClass:$rs1,
+                                                     (vti.Mask true_mask),
+                                                     (XLenVT (VLOp GPR:$vl))),
+                                      vti.RegClass:$rd, vti.RegClass:$rs2,
+                                      (vti.Mask true_mask),
+                                      (XLenVT (VLOp GPR:$vl)))),
+            (!cast<Instruction>("PseudoVFNMSUB_VV_"# suffix)
+                 vti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2,
+                 GPR:$vl, vti.SEW)>;
 
   // The choice of VFMADD here is arbitrary, vfmadd.vf and vfmacc.vf are equally
   // commutable.
@@ -488,6 +518,61 @@ foreach vti = AllFloatVectors in {
             (!cast<Instruction>("PseudoVFMADD_V" # vti.ScalarSuffix # "_" # suffix)
                  vti.RegClass:$rd, vti.ScalarRegClass:$rs1, vti.RegClass:$rs2,
                  GPR:$vl, vti.SEW)>;
+  def : Pat<(vti.Vector (riscv_fma_vl (SplatFPOp vti.ScalarRegClass:$rs1),
+                                       vti.RegClass:$rd,
+                                       (riscv_fneg_vl vti.RegClass:$rs2,
+                                                      (vti.Mask true_mask),
+                                                      (XLenVT (VLOp GPR:$vl))),
+                                       (vti.Mask true_mask),
+                                       (XLenVT (VLOp GPR:$vl)))),
+            (!cast<Instruction>("PseudoVFMSUB_V" # vti.ScalarSuffix # "_" # suffix)
+                 vti.RegClass:$rd, vti.ScalarRegClass:$rs1, vti.RegClass:$rs2,
+                 GPR:$vl, vti.SEW)>;
+  def : Pat<(vti.Vector (riscv_fma_vl (SplatFPOp vti.ScalarRegClass:$rs1),
+                                       (riscv_fneg_vl vti.RegClass:$rd,
+                                                      (vti.Mask true_mask),
+                                                      (XLenVT (VLOp GPR:$vl))),
+                                       (riscv_fneg_vl vti.RegClass:$rs2,
+                                                      (vti.Mask true_mask),
+                                                      (XLenVT (VLOp GPR:$vl))),
+                                       (vti.Mask true_mask),
+                                       (XLenVT (VLOp GPR:$vl)))),
+            (!cast<Instruction>("PseudoVFNMADD_V" # vti.ScalarSuffix # "_" # suffix)
+                 vti.RegClass:$rd, vti.ScalarRegClass:$rs1, vti.RegClass:$rs2,
+                 GPR:$vl, vti.SEW)>;
+  def : Pat<(vti.Vector (riscv_fma_vl (SplatFPOp vti.ScalarRegClass:$rs1),
+                                       (riscv_fneg_vl vti.RegClass:$rd,
+                                                      (vti.Mask true_mask),
+                                                      (XLenVT (VLOp GPR:$vl))),
+                                       vti.RegClass:$rs2,
+                                       (vti.Mask true_mask),
+                                       (XLenVT (VLOp GPR:$vl)))),
+            (!cast<Instruction>("PseudoVFNMSUB_V" # vti.ScalarSuffix # "_" # suffix)
+                 vti.RegClass:$rd, vti.ScalarRegClass:$rs1, vti.RegClass:$rs2,
+                 GPR:$vl, vti.SEW)>;
+
+  // The splat might be negated.
+  def : Pat<(vti.Vector (riscv_fma_vl (riscv_fneg_vl (SplatFPOp vti.ScalarRegClass:$rs1),
+                                                     (vti.Mask true_mask),
+                                                     (XLenVT (VLOp GPR:$vl))),
+                                       vti.RegClass:$rd,
+                                       (riscv_fneg_vl vti.RegClass:$rs2,
+                                                      (vti.Mask true_mask),
+                                                      (XLenVT (VLOp GPR:$vl))),
+                                       (vti.Mask true_mask),
+                                       (XLenVT (VLOp GPR:$vl)))),
+            (!cast<Instruction>("PseudoVFNMADD_V" # vti.ScalarSuffix # "_" # suffix)
+                 vti.RegClass:$rd, vti.ScalarRegClass:$rs1, vti.RegClass:$rs2,
+                 GPR:$vl, vti.SEW)>;
+  def : Pat<(vti.Vector (riscv_fma_vl (riscv_fneg_vl (SplatFPOp vti.ScalarRegClass:$rs1),
+                                                     (vti.Mask true_mask),
+                                                     (XLenVT (VLOp GPR:$vl))),
+                                       vti.RegClass:$rd, vti.RegClass:$rs2,
+                                       (vti.Mask true_mask),
+                                       (XLenVT (VLOp GPR:$vl)))),
+            (!cast<Instruction>("PseudoVFNMSUB_V" # vti.ScalarSuffix # "_" # suffix)
+                 vti.RegClass:$rd, vti.ScalarRegClass:$rs1, vti.RegClass:$rs2,
+                 GPR:$vl, vti.SEW)>;
 }
 
 // 14.11. Vector Floating-Point Compare Instructions

diff  --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-fp.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-fp.ll
index 4dd8260a5b48..be8d54f287b7 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-fp.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-fp.ll
@@ -409,6 +409,67 @@ define void @fma_v2f64(<2 x double>* %x, <2 x double>* %y, <2 x double>* %z) {
 }
 declare <2 x double> @llvm.fma.v2f64(<2 x double>, <2 x double>, <2 x double>)
 
+define void @fmsub_v8f16(<8 x half>* %x, <8 x half>* %y, <8 x half>* %z) {
+; CHECK-LABEL: fmsub_v8f16:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    addi a3, zero, 8
+; CHECK-NEXT:    vsetvli a3, a3, e16,m1,ta,mu
+; CHECK-NEXT:    vle16.v v25, (a0)
+; CHECK-NEXT:    vle16.v v26, (a1)
+; CHECK-NEXT:    vle16.v v27, (a2)
+; CHECK-NEXT:    vfmsac.vv v27, v25, v26
+; CHECK-NEXT:    vse16.v v27, (a0)
+; CHECK-NEXT:    ret
+  %a = load <8 x half>, <8 x half>* %x
+  %b = load <8 x half>, <8 x half>* %y
+  %c = load <8 x half>, <8 x half>* %z
+  %neg = fneg <8 x half> %c
+  %d = call <8 x half> @llvm.fma.v8f16(<8 x half> %a, <8 x half> %b, <8 x half> %neg)
+  store <8 x half> %d, <8 x half>* %x
+  ret void
+}
+
+define void @fnmsub_v4f32(<4 x float>* %x, <4 x float>* %y, <4 x float>* %z) {
+; CHECK-LABEL: fnmsub_v4f32:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    addi a3, zero, 4
+; CHECK-NEXT:    vsetvli a3, a3, e32,m1,ta,mu
+; CHECK-NEXT:    vle32.v v25, (a0)
+; CHECK-NEXT:    vle32.v v26, (a1)
+; CHECK-NEXT:    vle32.v v27, (a2)
+; CHECK-NEXT:    vfnmsac.vv v27, v25, v26
+; CHECK-NEXT:    vse32.v v27, (a0)
+; CHECK-NEXT:    ret
+  %a = load <4 x float>, <4 x float>* %x
+  %b = load <4 x float>, <4 x float>* %y
+  %c = load <4 x float>, <4 x float>* %z
+  %neg = fneg <4 x float> %a
+  %d = call <4 x float> @llvm.fma.v4f32(<4 x float> %neg, <4 x float> %b, <4 x float> %c)
+  store <4 x float> %d, <4 x float>* %x
+  ret void
+}
+
+define void @fnmadd_v2f64(<2 x double>* %x, <2 x double>* %y, <2 x double>* %z) {
+; CHECK-LABEL: fnmadd_v2f64:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    addi a3, zero, 2
+; CHECK-NEXT:    vsetvli a3, a3, e64,m1,ta,mu
+; CHECK-NEXT:    vle64.v v25, (a0)
+; CHECK-NEXT:    vle64.v v26, (a1)
+; CHECK-NEXT:    vle64.v v27, (a2)
+; CHECK-NEXT:    vfnmacc.vv v27, v25, v26
+; CHECK-NEXT:    vse64.v v27, (a0)
+; CHECK-NEXT:    ret
+  %a = load <2 x double>, <2 x double>* %x
+  %b = load <2 x double>, <2 x double>* %y
+  %c = load <2 x double>, <2 x double>* %z
+  %neg = fneg <2 x double> %b
+  %neg2 = fneg <2 x double> %c
+  %d = call <2 x double> @llvm.fma.v2f64(<2 x double> %a, <2 x double> %neg, <2 x double> %neg2)
+  store <2 x double> %d, <2 x double>* %x
+  ret void
+}
+
 define void @fadd_v16f16(<16 x half>* %x, <16 x half>* %y) {
 ; LMULMAX2-LABEL: fadd_v16f16:
 ; LMULMAX2:       # %bb.0:
@@ -1613,9 +1674,8 @@ define void @fma_vf_v8f16(<8 x half>* %x, <8 x half>* %y, half %z) {
 ; CHECK-NEXT:    vsetvli a2, a2, e16,m1,ta,mu
 ; CHECK-NEXT:    vle16.v v25, (a0)
 ; CHECK-NEXT:    vle16.v v26, (a1)
-; CHECK-NEXT:    vfmv.v.f v27, fa0
-; CHECK-NEXT:    vfmadd.vv v27, v25, v26
-; CHECK-NEXT:    vse16.v v27, (a0)
+; CHECK-NEXT:    vfmacc.vf v26, fa0, v25
+; CHECK-NEXT:    vse16.v v26, (a0)
 ; CHECK-NEXT:    ret
   %a = load <8 x half>, <8 x half>* %x
   %b = load <8 x half>, <8 x half>* %y
@@ -1633,9 +1693,8 @@ define void @fma_vf_v4f32(<4 x float>* %x, <4 x float>* %y, float %z) {
 ; CHECK-NEXT:    vsetvli a2, a2, e32,m1,ta,mu
 ; CHECK-NEXT:    vle32.v v25, (a0)
 ; CHECK-NEXT:    vle32.v v26, (a1)
-; CHECK-NEXT:    vfmv.v.f v27, fa0
-; CHECK-NEXT:    vfmadd.vv v27, v25, v26
-; CHECK-NEXT:    vse32.v v27, (a0)
+; CHECK-NEXT:    vfmacc.vf v26, fa0, v25
+; CHECK-NEXT:    vse32.v v26, (a0)
 ; CHECK-NEXT:    ret
   %a = load <4 x float>, <4 x float>* %x
   %b = load <4 x float>, <4 x float>* %y
@@ -1653,9 +1712,8 @@ define void @fma_vf_v2f64(<2 x double>* %x, <2 x double>* %y, double %z) {
 ; CHECK-NEXT:    vsetvli a2, a2, e64,m1,ta,mu
 ; CHECK-NEXT:    vle64.v v25, (a0)
 ; CHECK-NEXT:    vle64.v v26, (a1)
-; CHECK-NEXT:    vfmv.v.f v27, fa0
-; CHECK-NEXT:    vfmadd.vv v27, v25, v26
-; CHECK-NEXT:    vse64.v v27, (a0)
+; CHECK-NEXT:    vfmacc.vf v26, fa0, v25
+; CHECK-NEXT:    vse64.v v26, (a0)
 ; CHECK-NEXT:    ret
   %a = load <2 x double>, <2 x double>* %x
   %b = load <2 x double>, <2 x double>* %y
@@ -1722,3 +1780,105 @@ define void @fma_fv_v2f64(<2 x double>* %x, <2 x double>* %y, double %z) {
   store <2 x double> %e, <2 x double>* %x
   ret void
 }
+
+define void @fmsub_vf_v8f16(<8 x half>* %x, <8 x half>* %y, half %z) {
+; CHECK-LABEL: fmsub_vf_v8f16:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    addi a2, zero, 8
+; CHECK-NEXT:    vsetvli a2, a2, e16,m1,ta,mu
+; CHECK-NEXT:    vle16.v v25, (a0)
+; CHECK-NEXT:    vle16.v v26, (a1)
+; CHECK-NEXT:    vfmsac.vf v26, fa0, v25
+; CHECK-NEXT:    vse16.v v26, (a0)
+; CHECK-NEXT:    ret
+  %a = load <8 x half>, <8 x half>* %x
+  %b = load <8 x half>, <8 x half>* %y
+  %c = insertelement <8 x half> undef, half %z, i32 0
+  %d = shufflevector <8 x half> %c, <8 x half> undef, <8 x i32> zeroinitializer
+  %neg = fneg <8 x half> %b
+  %e = call <8 x half> @llvm.fma.v8f16(<8 x half> %a, <8 x half> %d, <8 x half> %neg)
+  store <8 x half> %e, <8 x half>* %x
+  ret void
+}
+
+define void @fnmsub_vf_v4f32(<4 x float>* %x, <4 x float>* %y, float %z) {
+; CHECK-LABEL: fnmsub_vf_v4f32:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    addi a2, zero, 4
+; CHECK-NEXT:    vsetvli a2, a2, e32,m1,ta,mu
+; CHECK-NEXT:    vle32.v v25, (a0)
+; CHECK-NEXT:    vle32.v v26, (a1)
+; CHECK-NEXT:    vfnmsac.vf v26, fa0, v25
+; CHECK-NEXT:    vse32.v v26, (a0)
+; CHECK-NEXT:    ret
+  %a = load <4 x float>, <4 x float>* %x
+  %b = load <4 x float>, <4 x float>* %y
+  %c = insertelement <4 x float> undef, float %z, i32 0
+  %d = shufflevector <4 x float> %c, <4 x float> undef, <4 x i32> zeroinitializer
+  %neg = fneg <4 x float> %a
+  %e = call <4 x float> @llvm.fma.v4f32(<4 x float> %neg, <4 x float> %d, <4 x float> %b)
+  store <4 x float> %e, <4 x float>* %x
+  ret void
+}
+
+define void @fnmadd_vf_v2f64(<2 x double>* %x, <2 x double>* %y, double %z) {
+; CHECK-LABEL: fnmadd_vf_v2f64:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    addi a2, zero, 2
+; CHECK-NEXT:    vsetvli a2, a2, e64,m1,ta,mu
+; CHECK-NEXT:    vle64.v v25, (a0)
+; CHECK-NEXT:    vle64.v v26, (a1)
+; CHECK-NEXT:    vfnmacc.vf v26, fa0, v25
+; CHECK-NEXT:    vse64.v v26, (a0)
+; CHECK-NEXT:    ret
+  %a = load <2 x double>, <2 x double>* %x
+  %b = load <2 x double>, <2 x double>* %y
+  %c = insertelement <2 x double> undef, double %z, i32 0
+  %d = shufflevector <2 x double> %c, <2 x double> undef, <2 x i32> zeroinitializer
+  %neg = fneg <2 x double> %a
+  %neg2 = fneg <2 x double> %b
+  %e = call <2 x double> @llvm.fma.v2f64(<2 x double> %neg, <2 x double> %d, <2 x double> %neg2)
+  store <2 x double> %e, <2 x double>* %x
+  ret void
+}
+
+define void @fnmsub_fv_v4f32(<4 x float>* %x, <4 x float>* %y, float %z) {
+; CHECK-LABEL: fnmsub_fv_v4f32:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    addi a2, zero, 4
+; CHECK-NEXT:    vsetvli a2, a2, e32,m1,ta,mu
+; CHECK-NEXT:    vle32.v v25, (a0)
+; CHECK-NEXT:    vle32.v v26, (a1)
+; CHECK-NEXT:    vfnmsac.vf v26, fa0, v25
+; CHECK-NEXT:    vse32.v v26, (a0)
+; CHECK-NEXT:    ret
+  %a = load <4 x float>, <4 x float>* %x
+  %b = load <4 x float>, <4 x float>* %y
+  %c = insertelement <4 x float> undef, float %z, i32 0
+  %d = shufflevector <4 x float> %c, <4 x float> undef, <4 x i32> zeroinitializer
+  %neg = fneg <4 x float> %d
+  %e = call <4 x float> @llvm.fma.v4f32(<4 x float> %neg, <4 x float> %a, <4 x float> %b)
+  store <4 x float> %e, <4 x float>* %x
+  ret void
+}
+
+define void @fnmadd_fv_v2f64(<2 x double>* %x, <2 x double>* %y, double %z) {
+; CHECK-LABEL: fnmadd_fv_v2f64:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    addi a2, zero, 2
+; CHECK-NEXT:    vsetvli a2, a2, e64,m1,ta,mu
+; CHECK-NEXT:    vle64.v v25, (a0)
+; CHECK-NEXT:    vle64.v v26, (a1)
+; CHECK-NEXT:    vfnmacc.vf v26, fa0, v25
+; CHECK-NEXT:    vse64.v v26, (a0)
+; CHECK-NEXT:    ret
+  %a = load <2 x double>, <2 x double>* %x
+  %b = load <2 x double>, <2 x double>* %y
+  %c = insertelement <2 x double> undef, double %z, i32 0
+  %d = shufflevector <2 x double> %c, <2 x double> undef, <2 x i32> zeroinitializer
+  %neg = fneg <2 x double> %d
+  %neg2 = fneg <2 x double> %b
+  %e = call <2 x double> @llvm.fma.v2f64(<2 x double> %neg, <2 x double> %a, <2 x double> %neg2)
+  store <2 x double> %e, <2 x double>* %x
+  ret void
+}


        


More information about the llvm-commits mailing list