[llvm] f2a05c6 - [RISCV] Add RISCVISD nodes for VWFMADD_VL.
Craig Topper via llvm-commits
llvm-commits at lists.llvm.org
Sun May 14 22:41:30 PDT 2023
Author: Craig Topper
Date: 2023-05-14T22:35:47-07:00
New Revision: f2a05c64e3880278c8b3afa5a78a93eb26d244e5
URL: https://github.com/llvm/llvm-project/commit/f2a05c64e3880278c8b3afa5a78a93eb26d244e5
DIFF: https://github.com/llvm/llvm-project/commit/f2a05c64e3880278c8b3afa5a78a93eb26d244e5.diff
LOG: [RISCV] Add RISCVISD nodes for VWFMADD_VL.
Use it to replace isel patterns with a DAG combine of FP_EXTEND_VL+VFMADD_VL.
This makes it similar to how other widening operations are handled.
I plan to use this to make it easier to form tail undisturbed vfwmacc.
Added:
Modified:
llvm/lib/Target/RISCV/RISCVISelLowering.cpp
llvm/lib/Target/RISCV/RISCVISelLowering.h
llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
Removed:
################################################################################
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 064a283d1cc4..3edff32f09ea 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -11196,7 +11196,7 @@ static unsigned negateFMAOpcode(unsigned Opcode, bool NegMul, bool NegAcc) {
return Opcode;
}
-static SDValue performVFMADD_VLCombine(SDNode *N, SelectionDAG &DAG) {
+static SDValue combineVFMADD_VLWithVFNEG_VL(SDNode *N, SelectionDAG &DAG) {
// Fold FNEG_VL into FMA opcodes.
// The first operand of strict-fp is chain.
unsigned Offset = N->isTargetStrictFPOpcode();
@@ -11233,6 +11233,75 @@ static SDValue performVFMADD_VLCombine(SDNode *N, SelectionDAG &DAG) {
VL);
}
+/// Combine for RISCVISD::VFMADD_VL and its negated/subtracting variants.
+/// First folds FNEG_VL operands into the opcode (this also covers the STRICT_
+/// forms), then tries to turn a multiply of two FP_EXTEND_VL values into one
+/// of the widening VFW*_VL nodes.
+static SDValue performVFMADD_VLCombine(SDNode *N, SelectionDAG &DAG) {
+  if (SDValue V = combineVFMADD_VLWithVFNEG_VL(N, DAG))
+    return V;
+
+  // FIXME: Ignore strict opcodes for now.
+  if (N->isTargetStrictFPOpcode())
+    return SDValue();
+
+  // Try to form widening FMA:
+  //   (op (fp_extend_vl A, M, VL), (fp_extend_vl B, M, VL), C, M, VL)
+  //     -> (widening op A, B, C, M, VL)
+  SDValue Op0 = N->getOperand(0);
+  SDValue Op1 = N->getOperand(1);
+  SDValue Mask = N->getOperand(3);
+  SDValue VL = N->getOperand(4);
+
+  if (Op0.getOpcode() != RISCVISD::FP_EXTEND_VL ||
+      Op1.getOpcode() != RISCVISD::FP_EXTEND_VL)
+    return SDValue();
+
+  // The extends must not be needed elsewhere. Squaring one extend (Op0 == Op1)
+  // gives that node exactly two uses, both by N, which is also fine.
+  // TODO: Refactor to handle more complex cases similar to
+  // combineBinOp_VLToVWBinOp_VL.
+  if ((!Op0.hasOneUse() || !Op1.hasOneUse()) &&
+      (Op0 != Op1 || !Op0->hasNUsesOfValue(2, Op0.getResNo())))
+    return SDValue();
+
+  // Check the mask and VL are the same.
+  if (Op0.getOperand(1) != Mask || Op0.getOperand(2) != VL ||
+      Op1.getOperand(1) != Mask || Op1.getOperand(2) != VL)
+    return SDValue();
+
+  // The VFW*_VL isel patterns only match an all-ones (VMSET_VL) mask, so a
+  // masked widening node could not be selected. Leave masked FMAs to the
+  // non-widening patterns.
+  if (Mask.getOpcode() != RISCVISD::VMSET_VL)
+    return SDValue();
+
+  unsigned NewOpc;
+  switch (N->getOpcode()) {
+  default:
+    llvm_unreachable("Unexpected opcode");
+  case RISCVISD::VFMADD_VL:
+    NewOpc = RISCVISD::VFWMADD_VL;
+    break;
+  case RISCVISD::VFNMSUB_VL:
+    NewOpc = RISCVISD::VFWNMSUB_VL;
+    break;
+  case RISCVISD::VFNMADD_VL:
+    NewOpc = RISCVISD::VFWNMADD_VL;
+    break;
+  case RISCVISD::VFMSUB_VL:
+    NewOpc = RISCVISD::VFWMSUB_VL;
+    break;
+  }
+
+  // Multiply the narrow sources directly; the addend (operand 2) stays wide.
+  Op0 = Op0.getOperand(0);
+  Op1 = Op1.getOperand(0);
+
+  return DAG.getNode(NewOpc, SDLoc(N), N->getValueType(0), Op0, Op1,
+                     N->getOperand(2), Mask, VL);
+}
+
static SDValue performSRACombine(SDNode *N, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget) {
assert(N->getOpcode() == ISD::SRA && "Unexpected opcode");
@@ -15074,6 +15143,10 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
NODE_NAME_CASE(VFNMADD_VL)
NODE_NAME_CASE(VFMSUB_VL)
NODE_NAME_CASE(VFNMSUB_VL)
+ NODE_NAME_CASE(VFWMADD_VL)
+ NODE_NAME_CASE(VFWNMADD_VL)
+ NODE_NAME_CASE(VFWMSUB_VL)
+ NODE_NAME_CASE(VFWNMSUB_VL)
NODE_NAME_CASE(FCOPYSIGN_VL)
NODE_NAME_CASE(SMIN_VL)
NODE_NAME_CASE(SMAX_VL)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index 54e0b18ee8a6..3936c51884cb 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -256,6 +256,13 @@ enum NodeType : unsigned {
VFMSUB_VL,
VFNMSUB_VL,
+ // Vector widening FMA ops with two narrow multiplicands, a wide addend as the
+ // third operand, a mask as a fourth operand and VL as a fifth operand.
+ VFWMADD_VL,
+ VFWNMADD_VL,
+ VFWMSUB_VL,
+ VFWNMSUB_VL,
+
// Widening instructions with a merge value a third operand, a mask as a
// fourth operand, and VL as a fifth operand.
VWMUL_VL,
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
index fc59d1f049ed..ea5084620bb5 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
@@ -136,11 +136,25 @@ def SDT_RISCVVecFMA_VL : SDTypeProfile<1, 5, [SDTCisSameAs<0, 1>,
SDTCVecEltisVT<4, i1>,
SDTCisSameNumEltsAs<0, 4>,
SDTCisVT<5, XLenVT>]>;
-def riscv_vfmadd_vl : SDNode<"RISCVISD::VFMADD_VL", SDT_RISCVVecFMA_VL, [SDNPCommutative]>;
+def riscv_vfmadd_vl : SDNode<"RISCVISD::VFMADD_VL", SDT_RISCVVecFMA_VL, [SDNPCommutative]>;
def riscv_vfnmadd_vl : SDNode<"RISCVISD::VFNMADD_VL", SDT_RISCVVecFMA_VL, [SDNPCommutative]>;
-def riscv_vfmsub_vl : SDNode<"RISCVISD::VFMSUB_VL", SDT_RISCVVecFMA_VL, [SDNPCommutative]>;
+def riscv_vfmsub_vl : SDNode<"RISCVISD::VFMSUB_VL", SDT_RISCVVecFMA_VL, [SDNPCommutative]>;
def riscv_vfnmsub_vl : SDNode<"RISCVISD::VFNMSUB_VL", SDT_RISCVVecFMA_VL, [SDNPCommutative]>;
+def SDT_RISCVWVecFMA_VL : SDTypeProfile<1, 5, [SDTCisVec<0>, SDTCisFP<0>,  // #0 (result): wide FP vector.
+ SDTCisVec<1>, SDTCisFP<1>,  // #1: first multiplicand, an FP vector
+ SDTCisOpSmallerThanOp<1, 0>,  //     whose type is smaller than #0
+ SDTCisSameNumEltsAs<0, 1>,  //     but with the same element count.
+ SDTCisSameAs<1, 2>,  // #2: second multiplicand, same type as #1.
+ SDTCisSameAs<0, 3>,  // #3: addend, same type as the result.
+ SDTCVecEltisVT<4, i1>,  // #4: mask, a vector of i1
+ SDTCisSameNumEltsAs<0, 4>,  //     with one element per result element.
+ SDTCisVT<5, XLenVT>]>;  // #5: VL.
+def riscv_vfwmadd_vl : SDNode<"RISCVISD::VFWMADD_VL", SDT_RISCVWVecFMA_VL, [SDNPCommutative]>;
+def riscv_vfwnmadd_vl : SDNode<"RISCVISD::VFWNMADD_VL", SDT_RISCVWVecFMA_VL, [SDNPCommutative]>;
+def riscv_vfwmsub_vl : SDNode<"RISCVISD::VFWMSUB_VL", SDT_RISCVWVecFMA_VL, [SDNPCommutative]>;
+def riscv_vfwnmsub_vl : SDNode<"RISCVISD::VFWNMSUB_VL", SDT_RISCVWVecFMA_VL, [SDNPCommutative]>;
+
def riscv_strict_vfmadd_vl : SDNode<"RISCVISD::STRICT_VFMADD_VL", SDT_RISCVVecFMA_VL, [SDNPCommutative, SDNPHasChain]>;
def riscv_strict_vfnmadd_vl : SDNode<"RISCVISD::STRICT_VFNMADD_VL", SDT_RISCVVecFMA_VL, [SDNPCommutative, SDNPHasChain]>;
def riscv_strict_vfmsub_vl : SDNode<"RISCVISD::STRICT_VFMSUB_VL", SDT_RISCVVecFMA_VL, [SDNPCommutative, SDNPHasChain]>;
@@ -1514,25 +1528,15 @@ multiclass VPatWidenFPMulAccVL_VV_VF<SDNode vop, string instruction_name> {
foreach vtiToWti = AllWidenableFloatVectors in {
defvar vti = vtiToWti.Vti;
defvar wti = vtiToWti.Wti;
- def : Pat<(vop
- (wti.Vector (riscv_fpextend_vl_oneuse
- (vti.Vector vti.RegClass:$rs1),
- (vti.Mask true_mask), VLOpFrag)),
- (wti.Vector (riscv_fpextend_vl_oneuse
- (vti.Vector vti.RegClass:$rs2),
- (vti.Mask true_mask), VLOpFrag)),
+ def : Pat<(vop (vti.Vector vti.RegClass:$rs1),
+ (vti.Vector vti.RegClass:$rs2),
(wti.Vector wti.RegClass:$rd), (vti.Mask true_mask),
VLOpFrag),
(!cast<Instruction>(instruction_name#"_VV_"#vti.LMul.MX)
wti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2,
GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
- def : Pat<(vop
- (wti.Vector (riscv_fpextend_vl_oneuse
- (vti.Vector (SplatFPOp vti.ScalarRegClass:$rs1)),
- (vti.Mask true_mask), VLOpFrag)),
- (wti.Vector (riscv_fpextend_vl_oneuse
- (vti.Vector vti.RegClass:$rs2),
- (vti.Mask true_mask), VLOpFrag)),
+ def : Pat<(vop (vti.Vector (SplatFPOp vti.ScalarRegClass:$rs1)),
+ (vti.Vector vti.RegClass:$rs2),
(wti.Vector wti.RegClass:$rd), (vti.Mask true_mask),
VLOpFrag),
(!cast<Instruction>(instruction_name#"_V"#vti.ScalarSuffix#"_"#vti.LMul.MX)
@@ -1827,10 +1831,10 @@ defm : VPatFPMulAccVL_VV_VF<riscv_vfnmadd_vl_oneuse, "PseudoVFNMACC">;
defm : VPatFPMulAccVL_VV_VF<riscv_vfnmsub_vl_oneuse, "PseudoVFNMSAC">;
// 13.7. Vector Widening Floating-Point Fused Multiply-Add Instructions
-defm : VPatWidenFPMulAccVL_VV_VF<riscv_vfmadd_vl, "PseudoVFWMACC">;
-defm : VPatWidenFPMulAccVL_VV_VF<riscv_vfnmadd_vl, "PseudoVFWNMACC">;
-defm : VPatWidenFPMulAccVL_VV_VF<riscv_vfmsub_vl, "PseudoVFWMSAC">;
-defm : VPatWidenFPMulAccVL_VV_VF<riscv_vfnmsub_vl, "PseudoVFWNMSAC">;
+defm : VPatWidenFPMulAccVL_VV_VF<riscv_vfwmadd_vl, "PseudoVFWMACC">;
+defm : VPatWidenFPMulAccVL_VV_VF<riscv_vfwnmadd_vl, "PseudoVFWNMACC">;
+defm : VPatWidenFPMulAccVL_VV_VF<riscv_vfwmsub_vl, "PseudoVFWMSAC">;
+defm : VPatWidenFPMulAccVL_VV_VF<riscv_vfwnmsub_vl, "PseudoVFWNMSAC">;
// 13.11. Vector Floating-Point MIN/MAX Instructions
defm : VPatBinaryFPVL_VV_VF<riscv_fminnum_vl, "PseudoVFMIN">;
More information about the llvm-commits
mailing list