[llvm] 4157bfb - [RISCV] Add RISCVISD nodes for vfwadd/vfwsub.

Craig Topper via llvm-commits llvm-commits at lists.llvm.org
Mon Jun 5 14:15:39 PDT 2023


Author: Craig Topper
Date: 2023-06-05T14:12:47-07:00
New Revision: 4157bfb230da068644c4f63fb1ee35f095f26cea

URL: https://github.com/llvm/llvm-project/commit/4157bfb230da068644c4f63fb1ee35f095f26cea
DIFF: https://github.com/llvm/llvm-project/commit/4157bfb230da068644c4f63fb1ee35f095f26cea.diff

LOG: [RISCV] Add RISCVISD nodes for vfwadd/vfwsub.

Add a DAG combine to form these from FADD_VL/FSUB_VL and FP_EXTEND_VL.

This makes it consistent with the other widening ops and lets us handle
the case where both operands are the same FP_EXTEND_VL.

Differential Revision: https://reviews.llvm.org/D151969

Added: 
    

Modified: 
    llvm/lib/Target/RISCV/RISCVISelLowering.cpp
    llvm/lib/Target/RISCV/RISCVISelLowering.h
    llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
    llvm/test/CodeGen/RISCV/rvv/vfwadd-vp.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index d33c48397536a..e8b6560036f08 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -11473,6 +11473,58 @@ static SDValue performVFMUL_VLCombine(SDNode *N, SelectionDAG &DAG) {
                      Op1, Merge, Mask, VL);
 }
 
+// Fold FP_EXTEND_VL operands of an FADD_VL/FSUB_VL into a widening node:
+// VFWADD_VL/VFWSUB_VL when both sources were extended, or VFWADD_W_VL/
+// VFWSUB_W_VL when only the second one was. Returns SDValue() otherwise.
+static SDValue performFADDSUB_VLCombine(SDNode *N, SelectionDAG &DAG) {
+  // Operands are (lhs, rhs, merge, mask, vl).
+  SDValue LHS = N->getOperand(0);
+  SDValue RHS = N->getOperand(1);
+  SDValue Passthru = N->getOperand(2);
+  SDValue Mask = N->getOperand(3);
+  SDValue VL = N->getOperand(4);
+  const bool IsSub = N->getOpcode() != RISCVISD::FADD_VL;
+
+  // An extend can be absorbed only if N is its sole user and it is predicated
+  // by the same mask and VL as N. When both operands are the same extend, N
+  // accounts for exactly two uses of it.
+  auto IsFoldableExt = [&](SDValue Ext) {
+    if (Ext.getOpcode() != RISCVISD::FP_EXTEND_VL)
+      return false;
+    bool SoleUser =
+        LHS == RHS ? Ext->hasNUsesOfValue(2, 0) : Ext.hasOneUse();
+    return SoleUser && Ext.getOperand(1) == Mask && Ext.getOperand(2) == VL;
+  };
+
+  bool FoldLHS = IsFoldableExt(LHS);
+  bool FoldRHS = IsFoldableExt(RHS);
+
+  // The second source of every widening node is the narrow one. A subtract
+  // can't be commuted, so it needs a foldable RHS; an add may instead swap a
+  // foldable LHS into the RHS slot.
+  if (!FoldRHS) {
+    if (IsSub || !FoldLHS)
+      return SDValue();
+    std::swap(LHS, RHS);
+    std::swap(FoldLHS, FoldRHS);
+  }
+
+  // Strip the extends: always from RHS, from LHS only if it folds as well.
+  RHS = RHS.getOperand(0);
+  if (FoldLHS)
+    LHS = LHS.getOperand(0);
+
+  unsigned Opc;
+  if (FoldLHS) // Both sources narrow.
+    Opc = IsSub ? RISCVISD::VFWSUB_VL : RISCVISD::VFWADD_VL;
+  else // LHS stays wide.
+    Opc = IsSub ? RISCVISD::VFWSUB_W_VL : RISCVISD::VFWADD_W_VL;
+
+  SDLoc DL(N);
+  EVT VT = N->getValueType(0);
+  return DAG.getNode(Opc, DL, VT, LHS, RHS, Passthru, Mask, VL);
+}
+
 static SDValue performSRACombine(SDNode *N, SelectionDAG &DAG,
                                  const RISCVSubtarget &Subtarget) {
   assert(N->getOpcode() == ISD::SRA && "Unexpected opcode");
@@ -12349,6 +12401,9 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
     return performVFMADD_VLCombine(N, DAG);
   case RISCVISD::FMUL_VL:
     return performVFMUL_VLCombine(N, DAG);
+  case RISCVISD::FADD_VL:
+  case RISCVISD::FSUB_VL:
+    return performFADDSUB_VLCombine(N, DAG);
   case ISD::LOAD:
   case ISD::STORE: {
     if (DCI.isAfterLegalizeDAG())
@@ -15460,6 +15515,10 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
   NODE_NAME_CASE(VWSUB_W_VL)
   NODE_NAME_CASE(VWSUBU_W_VL)
   NODE_NAME_CASE(VFWMUL_VL)
+  NODE_NAME_CASE(VFWADD_VL)
+  NODE_NAME_CASE(VFWSUB_VL)
+  NODE_NAME_CASE(VFWADD_W_VL)
+  NODE_NAME_CASE(VFWSUB_W_VL)
   NODE_NAME_CASE(VNSRL_VL)
   NODE_NAME_CASE(SETCC_VL)
   NODE_NAME_CASE(VSELECT_VL)

diff  --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index fb67ed5445068..69d5dffa15d98 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -285,6 +285,10 @@ enum NodeType : unsigned {
   VWSUBU_W_VL,
 
   VFWMUL_VL,
+  VFWADD_VL,
+  VFWSUB_VL,
+  VFWADD_W_VL,
+  VFWSUB_W_VL,
 
   // Narrowing logical shift right.
   // Operands are (source, shift, passthru, mask, vl)

diff  --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
index 056c5ce61bbd7..71df6e4a6fce2 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
@@ -400,6 +400,10 @@ def SDT_RISCVVWFPBinOp_VL : SDTypeProfile<1, 5, [SDTCisVec<0>, SDTCisFP<0>,
                                                  SDTCVecEltisVT<4, i1>,
                                                  SDTCisVT<5, XLenVT>]>;
 def riscv_vfwmul_vl : SDNode<"RISCVISD::VFWMUL_VL", SDT_RISCVVWFPBinOp_VL, [SDNPCommutative]>;
+// Widening FP add/sub with two narrow sources (the inputs of the folded
+// FP_EXTEND_VLs). Only the add is marked commutative.
+def riscv_vfwadd_vl : SDNode<"RISCVISD::VFWADD_VL", SDT_RISCVVWFPBinOp_VL, [SDNPCommutative]>;
+def riscv_vfwsub_vl : SDNode<"RISCVISD::VFWSUB_VL", SDT_RISCVVWFPBinOp_VL, []>;
 
 def SDT_RISCVVNIntBinOp_VL : SDTypeProfile<1, 5, [SDTCisVec<0>, SDTCisInt<0>,
                                                   SDTCisInt<1>,
@@ -426,6 +430,21 @@ def riscv_vwaddu_w_vl : SDNode<"RISCVISD::VWADDU_W_VL", SDT_RISCVVWIntBinOpW_VL>
 def riscv_vwsub_w_vl :  SDNode<"RISCVISD::VWSUB_W_VL",  SDT_RISCVVWIntBinOpW_VL>;
 def riscv_vwsubu_w_vl : SDNode<"RISCVISD::VWSUBU_W_VL", SDT_RISCVVWIntBinOpW_VL>;
 
+// Widening FP binop whose first source already has the wide result type.
+// Operands are (wide source, narrow source, passthru, mask, vl).
+def SDT_RISCVVWFPBinOpW_VL : SDTypeProfile<1, 5, [SDTCisVec<0>, SDTCisFP<0>,
+                                                  SDTCisSameAs<0, 1>,
+                                                  SDTCisFP<2>,
+                                                  SDTCisSameNumEltsAs<1, 2>,
+                                                  SDTCisOpSmallerThanOp<2, 1>,
+                                                  SDTCisSameAs<0, 3>,
+                                                  SDTCisSameNumEltsAs<1, 4>,
+                                                  SDTCVecEltisVT<4, i1>,
+                                                  SDTCisVT<5, XLenVT>]>;
+
+def riscv_vfwadd_w_vl :  SDNode<"RISCVISD::VFWADD_W_VL", SDT_RISCVVWFPBinOpW_VL>;
+def riscv_vfwsub_w_vl :  SDNode<"RISCVISD::VFWSUB_W_VL", SDT_RISCVVWFPBinOpW_VL>;
+
 def SDTRVVVecReduce : SDTypeProfile<1, 6, [
   SDTCisVec<0>, SDTCisVec<1>, SDTCisVec<2>, SDTCisSameAs<0, 3>,
   SDTCVecEltisVT<4, i1>, SDTCisSameNumEltsAs<2, 4>, SDTCisVT<5, XLenVT>,
@@ -1375,70 +1394,33 @@ multiclass VPatBinaryFPWVL_VV_VF<SDNode vop, string instruction_name> {
   }
 }
 
-multiclass VPatWidenBinaryFPVL_VV_VF<SDNode op, PatFrags extop, string instruction_name> {
-  foreach fvtiToFWti = AllWidenableFloatVectors in {
-    defvar fvti = fvtiToFWti.Vti;
-    defvar fwti = fvtiToFWti.Wti;
-    let Predicates = !listconcat(GetVTypePredicates<fvti>.Predicates,
-                                 GetVTypePredicates<fwti>.Predicates) in {
-      def : Pat<(fwti.Vector (op (fwti.Vector (extop (fvti.Vector fvti.RegClass:$rs2),
-                                                     (fvti.Mask true_mask), VLOpFrag)),
-                                 (fwti.Vector (extop (fvti.Vector fvti.RegClass:$rs1),
-                                                     (fvti.Mask true_mask), VLOpFrag)),
-                                 srcvalue, (fwti.Mask true_mask), VLOpFrag)),
-                (!cast<Instruction>(instruction_name#"_VV_"#fvti.LMul.MX)
-                   fvti.RegClass:$rs2, fvti.RegClass:$rs1,
-                   GPR:$vl, fvti.Log2SEW)>;
-      def : Pat<(fwti.Vector (op (fwti.Vector (extop (fvti.Vector fvti.RegClass:$rs2),
-                                                     (fvti.Mask true_mask), VLOpFrag)),
-                                 (fwti.Vector (extop (fvti.Vector (SplatFPOp fvti.ScalarRegClass:$rs1)),
-                                                     (fvti.Mask true_mask), VLOpFrag)),
-                                 srcvalue, (fwti.Mask true_mask), VLOpFrag)),
-                (!cast<Instruction>(instruction_name#"_V"#fvti.ScalarSuffix#"_"#fvti.LMul.MX)
-                   fvti.RegClass:$rs2, fvti.ScalarRegClass:$rs1,
-                   GPR:$vl, fvti.Log2SEW)>;
-    }
-  }
-}
-
-multiclass VPatWidenBinaryFPVL_WV_WF<SDNode op, PatFrags extop, string instruction_name> {
+// Widening FP add/sub: .vv/.vf forms are selected from vop (both sources
+// narrow) and .wv/.wf forms from vop_w (first source already wide).
+multiclass VPatBinaryFPWVL_VV_VF_WV_WF<SDNode vop, SDNode vop_w, string instruction_name>
+    : VPatBinaryFPWVL_VV_VF<vop, instruction_name> {
   foreach fvtiToFWti = AllWidenableFloatVectors in {
-    defvar fvti = fvtiToFWti.Vti;
-    defvar fwti = fvtiToFWti.Wti;
-    let Predicates = !listconcat(GetVTypePredicates<fvti>.Predicates,
-                                 GetVTypePredicates<fwti>.Predicates) in {
-      def : Pat<(fwti.Vector (op (fwti.Vector fwti.RegClass:$rs2),
-                                 (fwti.Vector (extop (fvti.Vector fvti.RegClass:$rs1),
-                                                     (fvti.Mask true_mask), VLOpFrag)),
-                                 srcvalue, (fwti.Mask true_mask), VLOpFrag)),
-                (!cast<Instruction>(instruction_name#"_WV_"#fvti.LMul.MX#"_TIED")
-                   fwti.RegClass:$rs2, fvti.RegClass:$rs1,
-                   GPR:$vl, fvti.Log2SEW, TAIL_AGNOSTIC)>;
-      // Tail undisturbed
-      def : Pat<(riscv_vp_merge_vl true_mask,
-                 (fwti.Vector (op (fwti.Vector fwti.RegClass:$rs2),
-                                  (fwti.Vector (extop (fvti.Vector fvti.RegClass:$rs1),
-                                                      (fvti.Mask true_mask), VLOpFrag)),
-                                  srcvalue, (fwti.Mask true_mask), VLOpFrag)),
-                 fwti.RegClass:$rs2, VLOpFrag),
-                (!cast<Instruction>(instruction_name#"_WV_"#fvti.LMul.MX#"_TIED")
-                   fwti.RegClass:$rs2, fvti.RegClass:$rs1,
-                   GPR:$vl, fvti.Log2SEW, TAIL_UNDISTURBED_MASK_UNDISTURBED)>;
-      def : Pat<(fwti.Vector (op (fwti.Vector fwti.RegClass:$rs2),
-                                 (fwti.Vector (extop (fvti.Vector (SplatFPOp fvti.ScalarRegClass:$rs1)),
-                                                     (fvti.Mask true_mask), VLOpFrag)),
-                                 srcvalue, (fwti.Mask true_mask), VLOpFrag)),
-                (!cast<Instruction>(instruction_name#"_W"#fvti.ScalarSuffix#"_"#fvti.LMul.MX)
-                   fwti.RegClass:$rs2, fvti.ScalarRegClass:$rs1,
-                   GPR:$vl, fvti.Log2SEW)>;
+    defvar vti = fvtiToFWti.Vti;
+    defvar wti = fvtiToFWti.Wti;
+    let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates,
+                                 GetVTypePredicates<wti>.Predicates) in {
+      defm : VPatTiedBinaryNoMaskVL_V<vop_w, instruction_name, "WV",
+                                      wti.Vector, vti.Vector, vti.Log2SEW,
+                                      vti.LMul, wti.RegClass, vti.RegClass>;
+      // VPatTiedBinaryNoMaskVL_V only covers an all-ones mask, but the combine
+      // that forms vop_w nodes keeps the FADD_VL/FSUB_VL mask, so also match
+      // the general masked .wv form.
+      def : VPatBinaryVL_V<vop_w, instruction_name, "WV",
+                           wti.Vector, wti.Vector, vti.Vector, vti.Mask,
+                           vti.Log2SEW, vti.LMul, wti.RegClass, wti.RegClass,
+                           vti.RegClass>;
+      def : VPatBinaryVL_VF<vop_w, instruction_name#"_W"#vti.ScalarSuffix,
+                            wti.Vector, wti.Vector, vti.Vector, vti.Mask,
+                            vti.Log2SEW, vti.LMul, wti.RegClass, wti.RegClass,
+                            vti.ScalarRegClass>;
     }
   }
 }
 
-multiclass VPatWidenBinaryFPVL_VV_VF_WV_WF<SDNode op, string instruction_name>
-    : VPatWidenBinaryFPVL_VV_VF<op, riscv_fpextend_vl_oneuse, instruction_name>,
-      VPatWidenBinaryFPVL_WV_WF<op, riscv_fpextend_vl_oneuse, instruction_name>;
-
 multiclass VPatNarrowShiftSplatExt_WX<SDNode op, PatFrags extop, string instruction_name> {
   foreach vtiToWti = AllWidenableIntVectors in {
     defvar vti = vtiToWti.Vti;
@@ -1938,8 +1920,8 @@ defm : VPatBinaryFPVL_VV_VF<any_riscv_fsub_vl, "PseudoVFSUB">;
 defm : VPatBinaryFPVL_R_VF<any_riscv_fsub_vl, "PseudoVFRSUB">;
 
 // 13.3. Vector Widening Floating-Point Add/Subtract Instructions
-defm : VPatWidenBinaryFPVL_VV_VF_WV_WF<riscv_fadd_vl, "PseudoVFWADD">;
-defm : VPatWidenBinaryFPVL_VV_VF_WV_WF<riscv_fsub_vl, "PseudoVFWSUB">;
+defm : VPatBinaryFPWVL_VV_VF_WV_WF<riscv_vfwadd_vl, riscv_vfwadd_w_vl, "PseudoVFWADD">;
+defm : VPatBinaryFPWVL_VV_VF_WV_WF<riscv_vfwsub_vl, riscv_vfwsub_w_vl, "PseudoVFWSUB">;
 
 // 13.4. Vector Single-Width Floating-Point Multiply/Divide Instructions
 defm : VPatBinaryFPVL_VV_VF<any_riscv_fmul_vl, "PseudoVFMUL">;

diff  --git a/llvm/test/CodeGen/RISCV/rvv/vfwadd-vp.ll b/llvm/test/CodeGen/RISCV/rvv/vfwadd-vp.ll
index 1c2ba683cd876..661d8cc5a468d 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vfwadd-vp.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vfwadd-vp.ll
@@ -1,6 +1,23 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
 ; RUN: llc < %s -mtriple=riscv64 -mattr=+v,+experimental-zvfh | FileCheck %s
 
+; The same vp.fpext feeds both operands of the vp.fadd. Both of its uses are in
+; the add, so it should still fold into a single vfwadd.vv.
+define <vscale x 2 x float> @vfwadd_same_operand(<vscale x 2 x half> %arg, i32 signext %vl) {
+; CHECK-LABEL: vfwadd_same_operand:
+; CHECK:       # %bb.0: # %bb
+; CHECK-NEXT:    slli a0, a0, 32
+; CHECK-NEXT:    srli a0, a0, 32
+; CHECK-NEXT:    vsetvli zero, a0, e16, mf2, ta, ma
+; CHECK-NEXT:    vfwadd.vv v9, v8, v8
+; CHECK-NEXT:    vmv1r.v v8, v9
+; CHECK-NEXT:    ret
+bb:
+  %tmp = call <vscale x 2 x float> @llvm.vp.fpext.nxv2f32.nxv2f16(<vscale x 2 x half> %arg, <vscale x 2 x i1> shufflevector (<vscale x 2 x i1> insertelement (<vscale x 2 x i1> poison, i1 true, i32 0), <vscale x 2 x i1> poison, <vscale x 2 x i32> zeroinitializer), i32 %vl)
+  %tmp2 = call <vscale x 2 x float> @llvm.vp.fadd.nxv2f32(<vscale x 2 x float> %tmp, <vscale x 2 x float> %tmp, <vscale x 2 x i1> shufflevector (<vscale x 2 x i1> insertelement (<vscale x 2 x i1> poison, i1 true, i32 0), <vscale x 2 x i1> poison, <vscale x 2 x i32> zeroinitializer), i32 %vl)
+  ret <vscale x 2 x float> %tmp2
+}
+
 define <vscale x 2 x float> @vfwadd_tu(<vscale x 2 x half> %arg, <vscale x 2 x float> %arg1, i32 signext %arg2) {
 ; CHECK-LABEL: vfwadd_tu:
 ; CHECK:       # %bb.0: # %bb


        


More information about the llvm-commits mailing list