[llvm] f2a05c6 - [RISCV] Add RISCVISD nodes for VWFMADD_VL.

Craig Topper via llvm-commits llvm-commits at lists.llvm.org
Sun May 14 22:41:30 PDT 2023


Author: Craig Topper
Date: 2023-05-14T22:35:47-07:00
New Revision: f2a05c64e3880278c8b3afa5a78a93eb26d244e5

URL: https://github.com/llvm/llvm-project/commit/f2a05c64e3880278c8b3afa5a78a93eb26d244e5
DIFF: https://github.com/llvm/llvm-project/commit/f2a05c64e3880278c8b3afa5a78a93eb26d244e5.diff

LOG: [RISCV] Add RISCVISD nodes for VWFMADD_VL.

Use them to replace the isel patterns that matched FP_EXTEND_VL+VFMADD_VL with a DAG combine that forms the new widening nodes.

This makes it similar to how other widening operations are handled.

I plan to use this to make it easier to form tail-undisturbed vfwmacc.

Added: 
    

Modified: 
    llvm/lib/Target/RISCV/RISCVISelLowering.cpp
    llvm/lib/Target/RISCV/RISCVISelLowering.h
    llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 064a283d1cc4..3edff32f09ea 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -11196,7 +11196,7 @@ static unsigned negateFMAOpcode(unsigned Opcode, bool NegMul, bool NegAcc) {
   return Opcode;
 }
 
-static SDValue performVFMADD_VLCombine(SDNode *N, SelectionDAG &DAG) {
+static SDValue combineVFMADD_VLWithVFNEG_VL(SDNode *N, SelectionDAG &DAG) {
   // Fold FNEG_VL into FMA opcodes.
   // The first operand of strict-fp is chain.
   unsigned Offset = N->isTargetStrictFPOpcode();
@@ -11233,6 +11233,59 @@ static SDValue performVFMADD_VLCombine(SDNode *N, SelectionDAG &DAG) {
                      VL);
 }
 
+static SDValue performVFMADD_VLCombine(SDNode *N, SelectionDAG &DAG) {
+  if (SDValue V = combineVFMADD_VLWithVFNEG_VL(N, DAG))
+    return V;
+
+  // FIXME: Ignore strict opcodes for now.
+  if (N->isTargetStrictFPOpcode())
+    return SDValue();
+
+  // Try to form widening FMA.
+  SDValue Op0 = N->getOperand(0);
+  SDValue Op1 = N->getOperand(1);
+  SDValue Mask = N->getOperand(3);
+  SDValue VL = N->getOperand(4);
+
+  if (Op0.getOpcode() != RISCVISD::FP_EXTEND_VL ||
+      Op1.getOpcode() != RISCVISD::FP_EXTEND_VL)
+    return SDValue();
+
+  // TODO: Refactor to handle more complex cases similar to
+  // combineBinOp_VLToVWBinOp_VL.
+  if (!Op0.hasOneUse() || !Op1.hasOneUse())
+    return SDValue();
+
+  // Check the mask and VL are the same.
+  if (Op0.getOperand(1) != Mask || Op0.getOperand(2) != VL ||
+      Op1.getOperand(1) != Mask || Op1.getOperand(2) != VL)
+    return SDValue();
+
+  unsigned NewOpc;
+  switch (N->getOpcode()) {
+  default:
+    llvm_unreachable("Unexpected opcode");
+  case RISCVISD::VFMADD_VL:
+    NewOpc = RISCVISD::VFWMADD_VL;
+    break;
+  case RISCVISD::VFNMSUB_VL:
+    NewOpc = RISCVISD::VFWNMSUB_VL;
+    break;
+  case RISCVISD::VFNMADD_VL:
+    NewOpc = RISCVISD::VFWNMADD_VL;
+    break;
+  case RISCVISD::VFMSUB_VL:
+    NewOpc = RISCVISD::VFWMSUB_VL;
+    break;
+  }
+
+  Op0 = Op0.getOperand(0);
+  Op1 = Op1.getOperand(0);
+
+  return DAG.getNode(NewOpc, SDLoc(N), N->getValueType(0), Op0, Op1,
+                     N->getOperand(2), Mask, VL);
+}
+
 static SDValue performSRACombine(SDNode *N, SelectionDAG &DAG,
                                  const RISCVSubtarget &Subtarget) {
   assert(N->getOpcode() == ISD::SRA && "Unexpected opcode");
@@ -15074,6 +15127,10 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
   NODE_NAME_CASE(VFNMADD_VL)
   NODE_NAME_CASE(VFMSUB_VL)
   NODE_NAME_CASE(VFNMSUB_VL)
+  NODE_NAME_CASE(VFWMADD_VL)
+  NODE_NAME_CASE(VFWNMADD_VL)
+  NODE_NAME_CASE(VFWMSUB_VL)
+  NODE_NAME_CASE(VFWNMSUB_VL)
   NODE_NAME_CASE(FCOPYSIGN_VL)
   NODE_NAME_CASE(SMIN_VL)
   NODE_NAME_CASE(SMAX_VL)

diff  --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index 54e0b18ee8a6..3936c51884cb 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -256,6 +256,13 @@ enum NodeType : unsigned {
   VFMSUB_VL,
   VFNMSUB_VL,
 
+  // Vector widening FMA ops with a mask as a fourth operand and VL as a fifth
+  // operand.
+  VFWMADD_VL,
+  VFWNMADD_VL,
+  VFWMSUB_VL,
+  VFWNMSUB_VL,
+
   // Widening instructions with a merge value a third operand, a mask as a
   // fourth operand, and VL as a fifth operand.
   VWMUL_VL,

diff  --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
index fc59d1f049ed..ea5084620bb5 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
@@ -136,11 +136,25 @@ def SDT_RISCVVecFMA_VL : SDTypeProfile<1, 5, [SDTCisSameAs<0, 1>,
                                               SDTCVecEltisVT<4, i1>,
                                               SDTCisSameNumEltsAs<0, 4>,
                                               SDTCisVT<5, XLenVT>]>;
-def riscv_vfmadd_vl : SDNode<"RISCVISD::VFMADD_VL", SDT_RISCVVecFMA_VL, [SDNPCommutative]>;
+def riscv_vfmadd_vl  : SDNode<"RISCVISD::VFMADD_VL",  SDT_RISCVVecFMA_VL, [SDNPCommutative]>;
 def riscv_vfnmadd_vl : SDNode<"RISCVISD::VFNMADD_VL", SDT_RISCVVecFMA_VL, [SDNPCommutative]>;
-def riscv_vfmsub_vl : SDNode<"RISCVISD::VFMSUB_VL", SDT_RISCVVecFMA_VL, [SDNPCommutative]>;
+def riscv_vfmsub_vl  : SDNode<"RISCVISD::VFMSUB_VL",  SDT_RISCVVecFMA_VL, [SDNPCommutative]>;
 def riscv_vfnmsub_vl : SDNode<"RISCVISD::VFNMSUB_VL", SDT_RISCVVecFMA_VL, [SDNPCommutative]>;
 
+def SDT_RISCVWVecFMA_VL : SDTypeProfile<1, 5, [SDTCisVec<0>, SDTCisFP<0>,
+                                               SDTCisVec<1>, SDTCisFP<1>,
+                                               SDTCisOpSmallerThanOp<1, 0>,
+                                               SDTCisSameNumEltsAs<0, 1>,
+                                               SDTCisSameAs<1, 2>,
+                                               SDTCisSameAs<0, 3>,
+                                               SDTCVecEltisVT<4, i1>,
+                                               SDTCisSameNumEltsAs<0, 4>,
+                                               SDTCisVT<5, XLenVT>]>;
+def riscv_vfwmadd_vl  : SDNode<"RISCVISD::VFWMADD_VL",  SDT_RISCVWVecFMA_VL, [SDNPCommutative]>;
+def riscv_vfwnmadd_vl : SDNode<"RISCVISD::VFWNMADD_VL", SDT_RISCVWVecFMA_VL, [SDNPCommutative]>;
+def riscv_vfwmsub_vl  : SDNode<"RISCVISD::VFWMSUB_VL",  SDT_RISCVWVecFMA_VL, [SDNPCommutative]>;
+def riscv_vfwnmsub_vl : SDNode<"RISCVISD::VFWNMSUB_VL", SDT_RISCVWVecFMA_VL, [SDNPCommutative]>;
+
 def riscv_strict_vfmadd_vl : SDNode<"RISCVISD::STRICT_VFMADD_VL", SDT_RISCVVecFMA_VL, [SDNPCommutative, SDNPHasChain]>;
 def riscv_strict_vfnmadd_vl : SDNode<"RISCVISD::STRICT_VFNMADD_VL", SDT_RISCVVecFMA_VL, [SDNPCommutative, SDNPHasChain]>;
 def riscv_strict_vfmsub_vl : SDNode<"RISCVISD::STRICT_VFMSUB_VL", SDT_RISCVVecFMA_VL, [SDNPCommutative, SDNPHasChain]>;
@@ -1514,25 +1528,15 @@ multiclass VPatWidenFPMulAccVL_VV_VF<SDNode vop, string instruction_name> {
   foreach vtiToWti = AllWidenableFloatVectors in {
     defvar vti = vtiToWti.Vti;
     defvar wti = vtiToWti.Wti;
-    def : Pat<(vop
-                   (wti.Vector (riscv_fpextend_vl_oneuse
-                                    (vti.Vector vti.RegClass:$rs1),
-                                    (vti.Mask true_mask), VLOpFrag)),
-                   (wti.Vector (riscv_fpextend_vl_oneuse
-                                    (vti.Vector vti.RegClass:$rs2),
-                                    (vti.Mask true_mask), VLOpFrag)),
+    def : Pat<(vop (vti.Vector vti.RegClass:$rs1),
+                   (vti.Vector vti.RegClass:$rs2),
                    (wti.Vector wti.RegClass:$rd), (vti.Mask true_mask),
                    VLOpFrag),
               (!cast<Instruction>(instruction_name#"_VV_"#vti.LMul.MX)
                  wti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2,
                  GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
-    def : Pat<(vop
-                   (wti.Vector (riscv_fpextend_vl_oneuse
-                                    (vti.Vector (SplatFPOp vti.ScalarRegClass:$rs1)),
-                                    (vti.Mask true_mask), VLOpFrag)),
-                   (wti.Vector (riscv_fpextend_vl_oneuse
-                                    (vti.Vector vti.RegClass:$rs2),
-                                    (vti.Mask true_mask), VLOpFrag)),
+    def : Pat<(vop (vti.Vector (SplatFPOp vti.ScalarRegClass:$rs1)),
+                   (vti.Vector vti.RegClass:$rs2),
                    (wti.Vector wti.RegClass:$rd), (vti.Mask true_mask),
                    VLOpFrag),
               (!cast<Instruction>(instruction_name#"_V"#vti.ScalarSuffix#"_"#vti.LMul.MX)
@@ -1827,10 +1831,10 @@ defm : VPatFPMulAccVL_VV_VF<riscv_vfnmadd_vl_oneuse, "PseudoVFNMACC">;
 defm : VPatFPMulAccVL_VV_VF<riscv_vfnmsub_vl_oneuse, "PseudoVFNMSAC">;
 
 // 13.7. Vector Widening Floating-Point Fused Multiply-Add Instructions
-defm : VPatWidenFPMulAccVL_VV_VF<riscv_vfmadd_vl, "PseudoVFWMACC">;
-defm : VPatWidenFPMulAccVL_VV_VF<riscv_vfnmadd_vl, "PseudoVFWNMACC">;
-defm : VPatWidenFPMulAccVL_VV_VF<riscv_vfmsub_vl, "PseudoVFWMSAC">;
-defm : VPatWidenFPMulAccVL_VV_VF<riscv_vfnmsub_vl, "PseudoVFWNMSAC">;
+defm : VPatWidenFPMulAccVL_VV_VF<riscv_vfwmadd_vl, "PseudoVFWMACC">;
+defm : VPatWidenFPMulAccVL_VV_VF<riscv_vfwnmadd_vl, "PseudoVFWNMACC">;
+defm : VPatWidenFPMulAccVL_VV_VF<riscv_vfwmsub_vl, "PseudoVFWMSAC">;
+defm : VPatWidenFPMulAccVL_VV_VF<riscv_vfwnmsub_vl, "PseudoVFWNMSAC">;
 
 // 13.11. Vector Floating-Point MIN/MAX Instructions
 defm : VPatBinaryFPVL_VV_VF<riscv_fminnum_vl, "PseudoVFMIN">;


        


More information about the llvm-commits mailing list