[llvm] 463f50b - [RISCV] Add RISCVISD::VFWMUL_VL. Use it to replace isel patterns with a DAG combine.

Craig Topper via llvm-commits llvm-commits at lists.llvm.org
Tue May 30 14:39:45 PDT 2023


Author: Craig Topper
Date: 2023-05-30T14:38:16-07:00
New Revision: 463f50b436a2ac3000a90d273f2ed05893e8864f

URL: https://github.com/llvm/llvm-project/commit/463f50b436a2ac3000a90d273f2ed05893e8864f
DIFF: https://github.com/llvm/llvm-project/commit/463f50b436a2ac3000a90d273f2ed05893e8864f.diff

LOG: [RISCV] Add RISCVISD::VFWMUL_VL. Use it to replace isel patterns with a DAG combine.

This is more consistent with how we handle integer widening multiply.

A follow-up patch will add support for matching vfwmul when the
multiplicand is being squared (i.e. both operands are the same value).

Added: 
    

Modified: 
    llvm/lib/Target/RISCV/RISCVISelLowering.cpp
    llvm/lib/Target/RISCV/RISCVISelLowering.h
    llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 3dc04d0f29e9..9d0267912c9f 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -11355,6 +11355,38 @@ static SDValue performVFMADD_VLCombine(SDNode *N, SelectionDAG &DAG) {
                      N->getOperand(2), Mask, VL);
 }
 
+static SDValue performVFMUL_VLCombine(SDNode *N, SelectionDAG &DAG) {
+  // FIXME: Strict FP opcodes are not handled yet; callers must not pass one.
+  assert(!N->isTargetStrictFPOpcode() && "Unexpected opcode");
+
+  // Try to form a widening multiply: fold the FP extends of both operands away.
+  SDValue Op0 = N->getOperand(0);
+  SDValue Op1 = N->getOperand(1);
+  SDValue Merge = N->getOperand(2);
+  SDValue Mask = N->getOperand(3);
+  SDValue VL = N->getOperand(4);
+
+  if (Op0.getOpcode() != RISCVISD::FP_EXTEND_VL ||
+      Op1.getOpcode() != RISCVISD::FP_EXTEND_VL)
+    return SDValue();
+
+  // Only fold when neither extend has another user. TODO: Refactor to handle
+  // more complex cases similar to combineBinOp_VLToVWBinOp_VL.
+  if (!Op0.hasOneUse() || !Op1.hasOneUse())
+    return SDValue();
+
+  // Both extends must use the same mask and VL as the multiply itself.
+  if (Op0.getOperand(1) != Mask || Op0.getOperand(2) != VL ||
+      Op1.getOperand(1) != Mask || Op1.getOperand(2) != VL)
+    return SDValue();
+
+  Op0 = Op0.getOperand(0);
+  Op1 = Op1.getOperand(0);
+
+  return DAG.getNode(RISCVISD::VFWMUL_VL, SDLoc(N), N->getValueType(0), Op0,
+                     Op1, Merge, Mask, VL);
+}
+
 static SDValue performSRACombine(SDNode *N, SelectionDAG &DAG,
                                  const RISCVSubtarget &Subtarget) {
   assert(N->getOpcode() == ISD::SRA && "Unexpected opcode");
@@ -12229,6 +12261,8 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
   case RISCVISD::STRICT_VFMSUB_VL:
   case RISCVISD::STRICT_VFNMSUB_VL:
     return performVFMADD_VLCombine(N, DAG);
+  case RISCVISD::FMUL_VL:
+    return performVFMUL_VLCombine(N, DAG);
   case ISD::LOAD:
   case ISD::STORE: {
     if (DCI.isAfterLegalizeDAG())
@@ -15339,6 +15373,7 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
   NODE_NAME_CASE(VWADDU_W_VL)
   NODE_NAME_CASE(VWSUB_W_VL)
   NODE_NAME_CASE(VWSUBU_W_VL)
+  NODE_NAME_CASE(VFWMUL_VL)
   NODE_NAME_CASE(VNSRL_VL)
   NODE_NAME_CASE(SETCC_VL)
   NODE_NAME_CASE(VSELECT_VL)

diff  --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index 829ff1fd4692..af6849cf73e6 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -284,6 +284,8 @@ enum NodeType : unsigned {
   VWSUB_W_VL,
   VWSUBU_W_VL,
 
+  VFWMUL_VL,
+
   // Narrowing logical shift right.
   // Operands are (source, shift, passthru, mask, vl)
   VNSRL_VL,

diff  --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
index 76e2a2b4f56b..b83ae5ff7cdd 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
@@ -388,6 +388,8 @@ def riscv_vwaddu_vl : SDNode<"RISCVISD::VWADDU_VL", SDT_RISCVVWBinOp_VL, [SDNPCo
 def riscv_vwsub_vl :  SDNode<"RISCVISD::VWSUB_VL",  SDT_RISCVVWBinOp_VL, []>;
 def riscv_vwsubu_vl : SDNode<"RISCVISD::VWSUBU_VL", SDT_RISCVVWBinOp_VL, []>;
 
+def riscv_vfwmul_vl : SDNode<"RISCVISD::VFWMUL_VL", SDT_RISCVVWBinOp_VL, [SDNPCommutative]>;
+
 def SDT_RISCVVNBinOp_VL : SDTypeProfile<1, 5, [SDTCisVec<0>,
                                                SDTCisSameNumEltsAs<0, 1>,
                                                SDTCisOpSmallerThanOp<0, 1>,
@@ -726,6 +728,7 @@ multiclass VPatBinaryWVL_VV_VX<SDPatternOperator vop, string instruction_name> {
     }
   }
 }
+
 multiclass VPatBinaryWVL_VV_VX_WV_WX<SDPatternOperator vop, SDNode vop_w,
                                      string instruction_name>
     : VPatBinaryWVL_VV_VX<vop, instruction_name> {
@@ -1346,6 +1349,24 @@ multiclass VPatWidenReductionVL_Ext_VL<SDNode vop, PatFrags extop, string instru
   }
 }
 
+multiclass VPatBinaryFPWVL_VV_VF<SDNode vop, string instruction_name> { // Isel patterns mapping node `vop` onto `instruction_name` pseudos.
+  foreach fvtiToFWti = AllWidenableFloatVectors in {
+    defvar vti = fvtiToFWti.Vti; // Presumably the narrow (source) type of the pair.
+    defvar wti = fvtiToFWti.Wti; // Presumably the wide (result) type of the pair.
+    let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates,
+                                 GetVTypePredicates<wti>.Predicates) in { // Predicates of both types are combined.
+      defm : VPatBinaryVL_V<vop, instruction_name, "VV", // "VV"-suffixed patterns.
+                            wti.Vector, vti.Vector, vti.Vector, vti.Mask,
+                            vti.Log2SEW, vti.LMul, wti.RegClass, vti.RegClass,
+                            vti.RegClass>;
+      defm : VPatBinaryVL_VF<vop, instruction_name#"_V"#vti.ScalarSuffix, // Patterns taking vti.ScalarRegClass.
+                             wti.Vector, vti.Vector, vti.Mask, vti.Log2SEW,
+                             vti.LMul, wti.RegClass, vti.RegClass,
+                             vti.ScalarRegClass>;
+    }
+  }
+}
+
 multiclass VPatWidenBinaryFPVL_VV_VF<SDNode op, PatFrags extop, string instruction_name> {
   foreach fvtiToFWti = AllWidenableFloatVectors in {
     defvar fvti = fvtiToFWti.Vti;
@@ -1918,7 +1939,7 @@ defm : VPatBinaryFPVL_VV_VF_E<any_riscv_fdiv_vl, "PseudoVFDIV">;
 defm : VPatBinaryFPVL_R_VF_E<any_riscv_fdiv_vl, "PseudoVFRDIV">;
 
 // 13.5. Vector Widening Floating-Point Multiply Instructions
-defm : VPatWidenBinaryFPVL_VV_VF<riscv_fmul_vl, riscv_fpextend_vl_oneuse, "PseudoVFWMUL">;
+defm : VPatBinaryFPWVL_VV_VF<riscv_vfwmul_vl, "PseudoVFWMUL">;
 
 // 13.6 Vector Single-Width Floating-Point Fused Multiply-Add Instructions.
 defm : VPatFPMulAddVL_VV_VF<any_riscv_vfmadd_vl,  "PseudoVFMADD">;


        


More information about the llvm-commits mailing list