[llvm] 463f50b - [RISCV] Add RISCVISD::VFWMUL_VL. Use it to replace isel patterns with a DAG combine.
Craig Topper via llvm-commits
llvm-commits at lists.llvm.org
Tue May 30 14:39:45 PDT 2023
Author: Craig Topper
Date: 2023-05-30T14:38:16-07:00
New Revision: 463f50b436a2ac3000a90d273f2ed05893e8864f
URL: https://github.com/llvm/llvm-project/commit/463f50b436a2ac3000a90d273f2ed05893e8864f
DIFF: https://github.com/llvm/llvm-project/commit/463f50b436a2ac3000a90d273f2ed05893e8864f.diff
LOG: [RISCV] Add RISCVISD::VFWMUL_VL. Use it to replace isel patterns with a DAG combine.
This is more consistent with how we handle integer widening multiply.
A follow-up patch will add support for matching vfwmul when the
multiplicand is being squared.
Added:
Modified:
llvm/lib/Target/RISCV/RISCVISelLowering.cpp
llvm/lib/Target/RISCV/RISCVISelLowering.h
llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
Removed:
################################################################################
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 3dc04d0f29e9..9d0267912c9f 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -11355,6 +11355,38 @@ static SDValue performVFMADD_VLCombine(SDNode *N, SelectionDAG &DAG) {
N->getOperand(2), Mask, VL);
}
+static SDValue performVFMUL_VLCombine(SDNode *N, SelectionDAG &DAG) {
+ // FIXME: Ignore strict opcodes for now (the assert below rejects them).
+ assert(!N->isTargetStrictFPOpcode() && "Unexpected opcode");
+
+ // Try to form widening multiply from a multiply of two FP_EXTEND_VL nodes.
+ SDValue Op0 = N->getOperand(0); // First multiplicand.
+ SDValue Op1 = N->getOperand(1); // Second multiplicand.
+ SDValue Merge = N->getOperand(2); // Forwarded unchanged to the new node.
+ SDValue Mask = N->getOperand(3);
+ SDValue VL = N->getOperand(4);
+
+ if (Op0.getOpcode() != RISCVISD::FP_EXTEND_VL ||
+ Op1.getOpcode() != RISCVISD::FP_EXTEND_VL)
+ return SDValue(); // Both multiplicands must be extends; no combine.
+
+ // TODO: Refactor to handle more complex cases similar to
+ // combineBinOp_VLToVWBinOp_VL.
+ if (!Op0.hasOneUse() || !Op1.hasOneUse())
+ return SDValue(); // Bail out if either extend has another user.
+
+ // The extends' mask (operand 1) and VL (operand 2) must match this node's.
+ if (Op0.getOperand(1) != Mask || Op0.getOperand(2) != VL ||
+ Op1.getOperand(1) != Mask || Op1.getOperand(2) != VL)
+ return SDValue();
+
+ Op0 = Op0.getOperand(0); // Look through each extend to its source value.
+ Op1 = Op1.getOperand(0);
+
+ return DAG.getNode(RISCVISD::VFWMUL_VL, SDLoc(N), N->getValueType(0), Op0,
+ Op1, Merge, Mask, VL); // Same result type as N; sources are un-extended.
+}
+
static SDValue performSRACombine(SDNode *N, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget) {
assert(N->getOpcode() == ISD::SRA && "Unexpected opcode");
@@ -12229,6 +12261,8 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
case RISCVISD::STRICT_VFMSUB_VL:
case RISCVISD::STRICT_VFNMSUB_VL:
return performVFMADD_VLCombine(N, DAG);
+ case RISCVISD::FMUL_VL:
+ return performVFMUL_VLCombine(N, DAG);
case ISD::LOAD:
case ISD::STORE: {
if (DCI.isAfterLegalizeDAG())
@@ -15339,6 +15373,7 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
NODE_NAME_CASE(VWADDU_W_VL)
NODE_NAME_CASE(VWSUB_W_VL)
NODE_NAME_CASE(VWSUBU_W_VL)
+ NODE_NAME_CASE(VFWMUL_VL)
NODE_NAME_CASE(VNSRL_VL)
NODE_NAME_CASE(SETCC_VL)
NODE_NAME_CASE(VSELECT_VL)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index 829ff1fd4692..af6849cf73e6 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -284,6 +284,8 @@ enum NodeType : unsigned {
VWSUB_W_VL,
VWSUBU_W_VL,
+ VFWMUL_VL,
+
// Narrowing logical shift right.
// Operands are (source, shift, passthru, mask, vl)
VNSRL_VL,
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
index 76e2a2b4f56b..b83ae5ff7cdd 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
@@ -388,6 +388,8 @@ def riscv_vwaddu_vl : SDNode<"RISCVISD::VWADDU_VL", SDT_RISCVVWBinOp_VL, [SDNPCo
def riscv_vwsub_vl : SDNode<"RISCVISD::VWSUB_VL", SDT_RISCVVWBinOp_VL, []>;
def riscv_vwsubu_vl : SDNode<"RISCVISD::VWSUBU_VL", SDT_RISCVVWBinOp_VL, []>;
+def riscv_vfwmul_vl : SDNode<"RISCVISD::VFWMUL_VL", SDT_RISCVVWBinOp_VL, [SDNPCommutative]>;
+
def SDT_RISCVVNBinOp_VL : SDTypeProfile<1, 5, [SDTCisVec<0>,
SDTCisSameNumEltsAs<0, 1>,
SDTCisOpSmallerThanOp<0, 1>,
@@ -726,6 +728,7 @@ multiclass VPatBinaryWVL_VV_VX<SDPatternOperator vop, string instruction_name> {
}
}
}
+
multiclass VPatBinaryWVL_VV_VX_WV_WX<SDPatternOperator vop, SDNode vop_w,
string instruction_name>
: VPatBinaryWVL_VV_VX<vop, instruction_name> {
@@ -1346,6 +1349,24 @@ multiclass VPatWidenReductionVL_Ext_VL<SDNode vop, PatFrags extop, string instru
}
}
+multiclass VPatBinaryFPWVL_VV_VF<SDNode vop, string instruction_name> {
+ foreach fvtiToFWti = AllWidenableFloatVectors in { // One pattern set per type pair.
+ defvar vti = fvtiToFWti.Vti; // Operand type of the pair.
+ defvar wti = fvtiToFWti.Wti; // Result type of the pair.
+ let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates, // Both types'
+ GetVTypePredicates<wti>.Predicates) in { // predicates are required.
+ defm : VPatBinaryVL_V<vop, instruction_name, "VV", // Vector-vector form:
+ wti.Vector, vti.Vector, vti.Vector, vti.Mask, // wti result, vti operands.
+ vti.Log2SEW, vti.LMul, wti.RegClass, vti.RegClass,
+ vti.RegClass>;
+ defm : VPatBinaryVL_VF<vop, instruction_name#"_V"#vti.ScalarSuffix, // Vector-scalar form:
+ wti.Vector, vti.Vector, vti.Mask, vti.Log2SEW, // wti result, vti vector operand,
+ vti.LMul, wti.RegClass, vti.RegClass,
+ vti.ScalarRegClass>; // and a scalar register operand.
+ }
+ }
+}
+
multiclass VPatWidenBinaryFPVL_VV_VF<SDNode op, PatFrags extop, string instruction_name> {
foreach fvtiToFWti = AllWidenableFloatVectors in {
defvar fvti = fvtiToFWti.Vti;
@@ -1918,7 +1939,7 @@ defm : VPatBinaryFPVL_VV_VF_E<any_riscv_fdiv_vl, "PseudoVFDIV">;
defm : VPatBinaryFPVL_R_VF_E<any_riscv_fdiv_vl, "PseudoVFRDIV">;
// 13.5. Vector Widening Floating-Point Multiply Instructions
-defm : VPatWidenBinaryFPVL_VV_VF<riscv_fmul_vl, riscv_fpextend_vl_oneuse, "PseudoVFWMUL">;
+defm : VPatBinaryFPWVL_VV_VF<riscv_vfwmul_vl, "PseudoVFWMUL">;
// 13.6 Vector Single-Width Floating-Point Fused Multiply-Add Instructions.
defm : VPatFPMulAddVL_VV_VF<any_riscv_vfmadd_vl, "PseudoVFMADD">;
More information about the llvm-commits
mailing list