[llvm] b8d719f - [RISCV] Add support for fixed vector FMA.

Craig Topper via llvm-commits <llvm-commits at lists.llvm.org>
Mon Feb 8 11:14:05 PST 2021


Author: Craig Topper
Date: 2021-02-08T11:12:56-08:00
New Revision: b8d719fbe81c88ec9e8c9dbe406c1b7de4c1ba05

URL: https://github.com/llvm/llvm-project/commit/b8d719fbe81c88ec9e8c9dbe406c1b7de4c1ba05
DIFF: https://github.com/llvm/llvm-project/commit/b8d719fbe81c88ec9e8c9dbe406c1b7de4c1ba05.diff

LOG: [RISCV] Add support for fixed vector FMA.

Follow-up to D95705. This does not include the commuting support from D95800.

Differential Revision: https://reviews.llvm.org/D96103
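
For context: the practical effect is that the generic llvm.fma.* intrinsics on
fixed-length vectors are now custom-lowered to the new RISCVISD::FMA_VL node
and selected as vfmadd.vv, instead of being legalized element-wise. A minimal
IR example, modeled on the tests added below (the by-value signature and the
function name are illustrative; the in-tree tests pass the vectors through
memory):

    ; Fused multiply-add over a fixed-length vector; with this patch it
    ; selects to vfmadd.vv given a suitable V-extension configuration.
    declare <4 x float> @llvm.fma.v4f32(<4 x float>, <4 x float>, <4 x float>)

    define <4 x float> @fma_example(<4 x float> %a, <4 x float> %b, <4 x float> %c) {
      %d = call <4 x float> @llvm.fma.v4f32(<4 x float> %a, <4 x float> %b, <4 x float> %c)
      ret <4 x float> %d
    }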

Added: 
    

Modified: 
    llvm/lib/Target/RISCV/RISCVISelLowering.cpp
    llvm/lib/Target/RISCV/RISCVISelLowering.h
    llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
    llvm/test/CodeGen/RISCV/rvv/fixed-vectors-fp.ll

Removed: 
    


################################################################################
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index fc130e3b26de..525c08f22b9f 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -558,6 +558,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
         setOperationAction(ISD::FMUL, VT, Custom);
         setOperationAction(ISD::FDIV, VT, Custom);
         setOperationAction(ISD::FNEG, VT, Custom);
+        setOperationAction(ISD::FMA, VT, Custom);
       }
     }
   }
@@ -1044,6 +1045,8 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
     return lowerToScalableOp(Op, DAG, RISCVISD::FDIV_VL);
   case ISD::FNEG:
     return lowerToScalableOp(Op, DAG, RISCVISD::FNEG_VL);
+  case ISD::FMA:
+    return lowerToScalableOp(Op, DAG, RISCVISD::FMA_VL);
   }
 }
 
@@ -4575,6 +4578,7 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
   NODE_NAME_CASE(FMUL_VL)
   NODE_NAME_CASE(FDIV_VL)
   NODE_NAME_CASE(FNEG_VL)
+  NODE_NAME_CASE(FMA_VL)
   NODE_NAME_CASE(VMCLR_VL)
   NODE_NAME_CASE(VMSET_VL)
   NODE_NAME_CASE(VLE_VL)

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index db209bb80193..354a6b767fd1 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -154,6 +154,7 @@ enum NodeType : unsigned {
   FMUL_VL,
   FDIV_VL,
   FNEG_VL,
+  FMA_VL,
 
   // Set mask vector to all zeros or ones.
   VMCLR_VL,

diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
index 765f08c67667..b23e5bff109b 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
@@ -70,6 +70,15 @@ def riscv_fmul_vl : SDNode<"RISCVISD::FMUL_VL", SDT_RISCVFPBinOp_VL, [SDNPCommut
 def riscv_fdiv_vl : SDNode<"RISCVISD::FDIV_VL", SDT_RISCVFPBinOp_VL>;
 def riscv_fneg_vl : SDNode<"RISCVISD::FNEG_VL", SDT_RISCVFPUnOp_VL>;
 
+def SDT_RISCVVecFMA_VL : SDTypeProfile<1, 5, [SDTCisSameAs<0, 1>,
+                                              SDTCisSameAs<0, 2>,
+                                              SDTCisSameAs<0, 3>,
+                                              SDTCisVec<0>, SDTCisFP<0>,
+                                              SDTCVecEltisVT<4, i1>,
+                                              SDTCisSameNumEltsAs<0, 4>,
+                                              SDTCisVT<5, XLenVT>]>;
+def riscv_fma_vl : SDNode<"RISCVISD::FMA_VL", SDT_RISCVVecFMA_VL>;
+
 def SDT_RISCVVMSETCLR_VL : SDTypeProfile<1, 1, [SDTCisVec<0>,
                                                 SDTCVecEltisVT<0, i1>,
                                                 SDTCisVT<1, XLenVT>]>;
@@ -178,7 +187,19 @@ defm "" : VPatBinaryFPVL_VV_VF<riscv_fsub_vl, "PseudoVFSUB">;
 defm "" : VPatBinaryFPVL_VV_VF<riscv_fmul_vl, "PseudoVFMUL">;
 defm "" : VPatBinaryFPVL_VV_VF<riscv_fdiv_vl, "PseudoVFDIV">;
 
-// 14.10. Vector Floating-Point Sign-Injection Instructions
+// 14.6. Vector Single-Width Floating-Point Fused Multiply-Add Instructions.
+foreach vti = AllFloatVectors in {
+  // NOTE: We choose VFMADD because it has the most commuting freedom. So it
+  // works best with how TwoAddressInstructionPass tries commuting.
+  def : Pat<(vti.Vector (riscv_fma_vl vti.RegClass:$rd, vti.RegClass:$rs1,
+                                      vti.RegClass:$rs2, (vti.Mask true_mask),
+                                      (XLenVT (VLOp GPR:$vl)))),
+            (!cast<Instruction>("PseudoVFMADD_VV_"# vti.LMul.MX)
+                 vti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2,
+                 GPR:$vl, vti.SEW)>;
+}
+
+// 14.12. Vector Floating-Point Sign-Injection Instructions
 // Handle fneg with VFSGNJN using the same input for both operands.
 foreach vti = AllFloatVectors in {
   def : Pat<(riscv_fneg_vl (vti.Vector vti.RegClass:$rs), (vti.Mask true_mask),

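Reading the new pattern: SDT_RISCVVecFMA_VL gives riscv_fma_vl three
identically-typed vector sources plus two extra operands, an i1 mask vector
with the same element count (operand 4) and the vector length as an XLenVT
scalar (operand 5); the pattern only fires when the mask is all-ones. VFMADD
is preferred because, per the V spec, vfmadd.vv vd, vs1, vs2 computes
vd[i] = (vd[i] * vs1[i]) + vs2[i], overwriting a multiplicand, while
vfmacc.vv overwrites the addend; since the multiplication commutes, this
leaves TwoAddressInstructionPass the most freedom to tie a dead source to
the destination. A hand-written sketch (not from the patch; names and
register roles are illustrative):

    ; fma(%a, %b, %c) == fma(%b, %a, %c): either multiplicand may sit in the
    ; tied vd slot of vfmadd.vv, so whichever of %a/%b dies can be overwritten.
    ;   vfmadd.vv vd, vs1, vs2  ; vd = (vd  * vs1) + vs2  -- clobbers a multiplicand
    ;   vfmacc.vv vd, vs1, vs2  ; vd = (vs1 * vs2) + vd   -- clobbers the addend
    define <4 x float> @fma_commute(<4 x float> %a, <4 x float> %b, <4 x float> %c) {
      %d = call <4 x float> @llvm.fma.v4f32(<4 x float> %b, <4 x float> %a, <4 x float> %c)
      ret <4 x float> %d
    }
    declare <4 x float> @llvm.fma.v4f32(<4 x float>, <4 x float>, <4 x float>)
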
diff  --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-fp.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-fp.ll
index 17ac800f9a0b..bfeba6939208 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-fp.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-fp.ll
@@ -253,6 +253,72 @@ define void @fneg_v2f64(<2 x double>* %x) {
   ret void
 }
 
+define void @fma_v8f16(<8 x half>* %x, <8 x half>* %y, <8 x half>* %z) {
+; CHECK-LABEL: fma_v8f16:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    addi a3, zero, 8
+; CHECK-NEXT:    vsetvli a4, a3, e16,m1,ta,mu
+; CHECK-NEXT:    vle16.v v25, (a0)
+; CHECK-NEXT:    vle16.v v26, (a1)
+; CHECK-NEXT:    vle16.v v27, (a2)
+; CHECK-NEXT:    vsetvli a1, a3, e16,m1,tu,mu
+; CHECK-NEXT:    vfmadd.vv v25, v26, v27
+; CHECK-NEXT:    vsetvli a1, a3, e16,m1,ta,mu
+; CHECK-NEXT:    vse16.v v25, (a0)
+; CHECK-NEXT:    ret
+  %a = load <8 x half>, <8 x half>* %x
+  %b = load <8 x half>, <8 x half>* %y
+  %c = load <8 x half>, <8 x half>* %z
+  %d = call <8 x half> @llvm.fma.v8f16(<8 x half> %a, <8 x half> %b, <8 x half> %c)
+  store <8 x half> %d, <8 x half>* %x
+  ret void
+}
+declare <8 x half> @llvm.fma.v8f16(<8 x half>, <8 x half>, <8 x half>)
+
+define void @fma_v4f32(<4 x float>* %x, <4 x float>* %y, <4 x float>* %z) {
+; CHECK-LABEL: fma_v4f32:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    addi a3, zero, 4
+; CHECK-NEXT:    vsetvli a4, a3, e32,m1,ta,mu
+; CHECK-NEXT:    vle32.v v25, (a0)
+; CHECK-NEXT:    vle32.v v26, (a1)
+; CHECK-NEXT:    vle32.v v27, (a2)
+; CHECK-NEXT:    vsetvli a1, a3, e32,m1,tu,mu
+; CHECK-NEXT:    vfmadd.vv v25, v26, v27
+; CHECK-NEXT:    vsetvli a1, a3, e32,m1,ta,mu
+; CHECK-NEXT:    vse32.v v25, (a0)
+; CHECK-NEXT:    ret
+  %a = load <4 x float>, <4 x float>* %x
+  %b = load <4 x float>, <4 x float>* %y
+  %c = load <4 x float>, <4 x float>* %z
+  %d = call <4 x float> @llvm.fma.v4f32(<4 x float> %a, <4 x float> %b, <4 x float> %c)
+  store <4 x float> %d, <4 x float>* %x
+  ret void
+}
+declare <4 x float> @llvm.fma.v4f32(<4 x float>, <4 x float>, <4 x float>)
+
+define void @fma_v2f64(<2 x double>* %x, <2 x double>* %y, <2 x double>* %z) {
+; CHECK-LABEL: fma_v2f64:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    addi a3, zero, 2
+; CHECK-NEXT:    vsetvli a4, a3, e64,m1,ta,mu
+; CHECK-NEXT:    vle64.v v25, (a0)
+; CHECK-NEXT:    vle64.v v26, (a1)
+; CHECK-NEXT:    vle64.v v27, (a2)
+; CHECK-NEXT:    vsetvli a1, a3, e64,m1,tu,mu
+; CHECK-NEXT:    vfmadd.vv v25, v26, v27
+; CHECK-NEXT:    vsetvli a1, a3, e64,m1,ta,mu
+; CHECK-NEXT:    vse64.v v25, (a0)
+; CHECK-NEXT:    ret
+  %a = load <2 x double>, <2 x double>* %x
+  %b = load <2 x double>, <2 x double>* %y
+  %c = load <2 x double>, <2 x double>* %z
+  %d = call <2 x double> @llvm.fma.v2f64(<2 x double> %a, <2 x double> %b, <2 x double> %c)
+  store <2 x double> %d, <2 x double>* %x
+  ret void
+}
+declare <2 x double> @llvm.fma.v2f64(<2 x double>, <2 x double>, <2 x double>)
+
 define void @fadd_v16f16(<16 x half>* %x, <16 x half>* %y) {
 ; LMULMAX2-LABEL: fadd_v16f16:
 ; LMULMAX2:       # %bb.0:
@@ -924,3 +990,132 @@ define void @fneg_v4f64(<4 x double>* %x) {
   store <4 x double> %b, <4 x double>* %x
   ret void
 }
+
+define void @fma_v16f16(<16 x half>* %x, <16 x half>* %y, <16 x half>* %z) {
+; LMULMAX2-LABEL: fma_v16f16:
+; LMULMAX2:       # %bb.0:
+; LMULMAX2-NEXT:    addi a3, zero, 16
+; LMULMAX2-NEXT:    vsetvli a4, a3, e16,m2,ta,mu
+; LMULMAX2-NEXT:    vle16.v v26, (a0)
+; LMULMAX2-NEXT:    vle16.v v28, (a1)
+; LMULMAX2-NEXT:    vle16.v v30, (a2)
+; LMULMAX2-NEXT:    vsetvli a1, a3, e16,m2,tu,mu
+; LMULMAX2-NEXT:    vfmadd.vv v26, v28, v30
+; LMULMAX2-NEXT:    vsetvli a1, a3, e16,m2,ta,mu
+; LMULMAX2-NEXT:    vse16.v v26, (a0)
+; LMULMAX2-NEXT:    ret
+;
+; LMULMAX1-LABEL: fma_v16f16:
+; LMULMAX1:       # %bb.0:
+; LMULMAX1-NEXT:    addi a3, zero, 8
+; LMULMAX1-NEXT:    vsetvli a4, a3, e16,m1,ta,mu
+; LMULMAX1-NEXT:    vle16.v v25, (a0)
+; LMULMAX1-NEXT:    addi a4, a0, 16
+; LMULMAX1-NEXT:    vle16.v v26, (a4)
+; LMULMAX1-NEXT:    vle16.v v27, (a1)
+; LMULMAX1-NEXT:    addi a1, a1, 16
+; LMULMAX1-NEXT:    vle16.v v28, (a1)
+; LMULMAX1-NEXT:    addi a1, a2, 16
+; LMULMAX1-NEXT:    vle16.v v29, (a1)
+; LMULMAX1-NEXT:    vle16.v v30, (a2)
+; LMULMAX1-NEXT:    vsetvli a1, a3, e16,m1,tu,mu
+; LMULMAX1-NEXT:    vfmadd.vv v26, v28, v29
+; LMULMAX1-NEXT:    vfmadd.vv v25, v27, v30
+; LMULMAX1-NEXT:    vsetvli a1, a3, e16,m1,ta,mu
+; LMULMAX1-NEXT:    vse16.v v25, (a0)
+; LMULMAX1-NEXT:    vse16.v v26, (a4)
+; LMULMAX1-NEXT:    ret
+  %a = load <16 x half>, <16 x half>* %x
+  %b = load <16 x half>, <16 x half>* %y
+  %c = load <16 x half>, <16 x half>* %z
+  %d = call <16 x half> @llvm.fma.v16f16(<16 x half> %a, <16 x half> %b, <16 x half> %c)
+  store <16 x half> %d, <16 x half>* %x
+  ret void
+}
+declare <16 x half> @llvm.fma.v16f16(<16 x half>, <16 x half>, <16 x half>)
+
+define void @fma_v8f32(<8 x float>* %x, <8 x float>* %y, <8 x float>* %z) {
+; LMULMAX2-LABEL: fma_v8f32:
+; LMULMAX2:       # %bb.0:
+; LMULMAX2-NEXT:    addi a3, zero, 8
+; LMULMAX2-NEXT:    vsetvli a4, a3, e32,m2,ta,mu
+; LMULMAX2-NEXT:    vle32.v v26, (a0)
+; LMULMAX2-NEXT:    vle32.v v28, (a1)
+; LMULMAX2-NEXT:    vle32.v v30, (a2)
+; LMULMAX2-NEXT:    vsetvli a1, a3, e32,m2,tu,mu
+; LMULMAX2-NEXT:    vfmadd.vv v26, v28, v30
+; LMULMAX2-NEXT:    vsetvli a1, a3, e32,m2,ta,mu
+; LMULMAX2-NEXT:    vse32.v v26, (a0)
+; LMULMAX2-NEXT:    ret
+;
+; LMULMAX1-LABEL: fma_v8f32:
+; LMULMAX1:       # %bb.0:
+; LMULMAX1-NEXT:    addi a3, zero, 4
+; LMULMAX1-NEXT:    vsetvli a4, a3, e32,m1,ta,mu
+; LMULMAX1-NEXT:    vle32.v v25, (a0)
+; LMULMAX1-NEXT:    addi a4, a0, 16
+; LMULMAX1-NEXT:    vle32.v v26, (a4)
+; LMULMAX1-NEXT:    vle32.v v27, (a1)
+; LMULMAX1-NEXT:    addi a1, a1, 16
+; LMULMAX1-NEXT:    vle32.v v28, (a1)
+; LMULMAX1-NEXT:    addi a1, a2, 16
+; LMULMAX1-NEXT:    vle32.v v29, (a1)
+; LMULMAX1-NEXT:    vle32.v v30, (a2)
+; LMULMAX1-NEXT:    vsetvli a1, a3, e32,m1,tu,mu
+; LMULMAX1-NEXT:    vfmadd.vv v26, v28, v29
+; LMULMAX1-NEXT:    vfmadd.vv v25, v27, v30
+; LMULMAX1-NEXT:    vsetvli a1, a3, e32,m1,ta,mu
+; LMULMAX1-NEXT:    vse32.v v25, (a0)
+; LMULMAX1-NEXT:    vse32.v v26, (a4)
+; LMULMAX1-NEXT:    ret
+  %a = load <8 x float>, <8 x float>* %x
+  %b = load <8 x float>, <8 x float>* %y
+  %c = load <8 x float>, <8 x float>* %z
+  %d = call <8 x float> @llvm.fma.v8f32(<8 x float> %a, <8 x float> %b, <8 x float> %c)
+  store <8 x float> %d, <8 x float>* %x
+  ret void
+}
+declare <8 x float> @llvm.fma.v8f32(<8 x float>, <8 x float>, <8 x float>)
+
+define void @fma_v4f64(<4 x double>* %x, <4 x double>* %y, <4 x double>* %z) {
+; LMULMAX2-LABEL: fma_v4f64:
+; LMULMAX2:       # %bb.0:
+; LMULMAX2-NEXT:    addi a3, zero, 4
+; LMULMAX2-NEXT:    vsetvli a4, a3, e64,m2,ta,mu
+; LMULMAX2-NEXT:    vle64.v v26, (a0)
+; LMULMAX2-NEXT:    vle64.v v28, (a1)
+; LMULMAX2-NEXT:    vle64.v v30, (a2)
+; LMULMAX2-NEXT:    vsetvli a1, a3, e64,m2,tu,mu
+; LMULMAX2-NEXT:    vfmadd.vv v26, v28, v30
+; LMULMAX2-NEXT:    vsetvli a1, a3, e64,m2,ta,mu
+; LMULMAX2-NEXT:    vse64.v v26, (a0)
+; LMULMAX2-NEXT:    ret
+;
+; LMULMAX1-LABEL: fma_v4f64:
+; LMULMAX1:       # %bb.0:
+; LMULMAX1-NEXT:    addi a3, zero, 2
+; LMULMAX1-NEXT:    vsetvli a4, a3, e64,m1,ta,mu
+; LMULMAX1-NEXT:    vle64.v v25, (a0)
+; LMULMAX1-NEXT:    addi a4, a0, 16
+; LMULMAX1-NEXT:    vle64.v v26, (a4)
+; LMULMAX1-NEXT:    vle64.v v27, (a1)
+; LMULMAX1-NEXT:    addi a1, a1, 16
+; LMULMAX1-NEXT:    vle64.v v28, (a1)
+; LMULMAX1-NEXT:    addi a1, a2, 16
+; LMULMAX1-NEXT:    vle64.v v29, (a1)
+; LMULMAX1-NEXT:    vle64.v v30, (a2)
+; LMULMAX1-NEXT:    vsetvli a1, a3, e64,m1,tu,mu
+; LMULMAX1-NEXT:    vfmadd.vv v26, v28, v29
+; LMULMAX1-NEXT:    vfmadd.vv v25, v27, v30
+; LMULMAX1-NEXT:    vsetvli a1, a3, e64,m1,ta,mu
+; LMULMAX1-NEXT:    vse64.v v25, (a0)
+; LMULMAX1-NEXT:    vse64.v v26, (a4)
+; LMULMAX1-NEXT:    ret
+  %a = load <4 x double>, <4 x double>* %x
+  %b = load <4 x double>, <4 x double>* %y
+  %c = load <4 x double>, <4 x double>* %z
+  %d = call <4 x double> @llvm.fma.v4f64(<4 x double> %a, <4 x double> %b, <4 x double> %c)
+  store <4 x double> %d, <4 x double>* %x
+  ret void
+}
+declare <4 x double> @llvm.fma.v4f64(<4 x double>, <4 x double>, <4 x double>)

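To reproduce the CHECK lines locally: the test's RUN lines are not part of
this diff. The LMULMAX2/LMULMAX1 prefixes correspond to the
-riscv-v-fixed-length-vector-lmul-max=2/1 llc settings; an assumed invocation
(my reconstruction, not taken from the patch) looks like:

    ; RUN: llc -mtriple=riscv64 -mattr=+d,+experimental-zfh,+experimental-v \
    ; RUN:   -riscv-v-vector-bits-min=128 -riscv-v-fixed-length-vector-lmul-max=2 < %s \
    ; RUN:   | FileCheck --check-prefixes=CHECK,LMULMAX2 %s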