[llvm] a986998 - [RISCV] Introduce RISCVISD::VWMACC(U/SU)_VL opcode
Nitin John Raj via llvm-commits
llvm-commits at lists.llvm.org
Fri Jun 16 16:15:23 PDT 2023
Author: Nitin John Raj
Date: 2023-06-16T16:11:35-07:00
New Revision: a986998bad9254ae0513d548edebd55a1153b609
URL: https://github.com/llvm/llvm-project/commit/a986998bad9254ae0513d548edebd55a1153b609
DIFF: https://github.com/llvm/llvm-project/commit/a986998bad9254ae0513d548edebd55a1153b609.diff
LOG: [RISCV] Introduce RISCVISD::VWMACC(U/SU)_VL opcode
Differential Revision: https://reviews.llvm.org/D153057
Added:
llvm/test/CodeGen/RISCV/rvv/vwmaccsu-vp.ll
Modified:
llvm/lib/Target/RISCV/RISCVISelLowering.cpp
llvm/lib/Target/RISCV/RISCVISelLowering.h
llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
Removed:
llvm/test/CodeGen/RISCV/rvv/vwmaccus-vp.ll
################################################################################
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 02f3b584d19aa..d7b841ccb2dfd 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -12136,6 +12136,63 @@ static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG,
return convertFromScalableVector(VT, Res, DAG, Subtarget);
}
+/// Fold (add_vl Addend, (vwmul*_vl A, B, undef, Mask, VL), undef, Mask, VL)
+/// into (vwmacc*_vl A, B, Addend, Mask, VL).
+///
+/// The widening multiply may appear as either operand of the add. The fold
+/// only fires when:
+///  - neither node has a defined passthru (the VWMACC*_VL nodes have no
+///    passthru operand),
+///  - the widening multiply has exactly one user (mirrors the one-use
+///    TableGen patterns this combine replaces; otherwise the product stays
+///    live and the multiply would effectively be computed twice),
+///  - both nodes use the same mask and VL.
+/// VWMUL_VL / VWMULU_VL / VWMULSU_VL map to VWMACC_VL / VWMACCU_VL /
+/// VWMACCSU_VL respectively. Returns an empty SDValue when no fold applies.
+static SDValue combineToVWMACC(SDNode *N, SelectionDAG &DAG,
+                               const RISCVSubtarget &Subtarget) {
+  assert(N->getOpcode() == RISCVISD::ADD_VL);
+  SDValue Addend = N->getOperand(0);
+  SDValue MulOp = N->getOperand(1);
+  SDValue AddMergeOp = N->getOperand(2);
+
+  // A defined passthru would have to be merged into the result, and the
+  // fused node has no operand to carry it.
+  if (!AddMergeOp.isUndef())
+    return SDValue();
+
+  auto IsVWMulOpc = [](unsigned Opc) {
+    switch (Opc) {
+    case RISCVISD::VWMUL_VL:
+    case RISCVISD::VWMULU_VL:
+    case RISCVISD::VWMULSU_VL:
+      return true;
+    default:
+      return false;
+    }
+  };
+
+  // Accept the widening multiply on either side of the add.
+  if (!IsVWMulOpc(MulOp.getOpcode()))
+    std::swap(Addend, MulOp);
+
+  if (!IsVWMulOpc(MulOp.getOpcode()))
+    return SDValue();
+
+  // If the product has other users the widening multiply is not removed, so
+  // folding would only duplicate the multiply work.
+  if (!MulOp.hasOneUse())
+    return SDValue();
+
+  SDValue MulMergeOp = MulOp.getOperand(2);
+
+  if (!MulMergeOp.isUndef())
+    return SDValue();
+
+  SDValue AddMask = N->getOperand(3);
+  SDValue AddVL = N->getOperand(4);
+  SDValue MulMask = MulOp.getOperand(3);
+  SDValue MulVL = MulOp.getOperand(4);
+
+  // The fused node carries a single mask/VL pair, so both ops must agree.
+  if (AddMask != MulMask || AddVL != MulVL)
+    return SDValue();
+
+  // The opcode arithmetic below relies on both opcode groups being declared
+  // contiguously and in the same (signed, unsigned, signed*unsigned) order.
+  static_assert(RISCVISD::VWMACC_VL + 1 == RISCVISD::VWMACCU_VL,
+                "Unexpected opcode after VWMACC_VL!");
+  static_assert(RISCVISD::VWMACC_VL + 2 == RISCVISD::VWMACCSU_VL,
+                "Unexpected opcode after VWMACC_VL!");
+  static_assert(RISCVISD::VWMUL_VL + 1 == RISCVISD::VWMULU_VL,
+                "Unexpected opcode after VWMUL_VL!");
+  static_assert(RISCVISD::VWMUL_VL + 2 == RISCVISD::VWMULSU_VL,
+                "Unexpected opcode after VWMUL_VL!");
+  unsigned Opc = RISCVISD::VWMACC_VL + MulOp.getOpcode() - RISCVISD::VWMUL_VL;
+
+  SDLoc DL(N);
+  EVT VT = N->getValueType(0);
+  SDValue Ops[] = {MulOp.getOperand(0), MulOp.getOperand(1), Addend, AddMask,
+                   AddVL};
+  return DAG.getNode(Opc, DL, VT, Ops);
+}
+
SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
DAGCombinerInfo &DCI) const {
SelectionDAG &DAG = DCI.DAG;
@@ -12546,6 +12603,9 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
break;
}
case RISCVISD::ADD_VL:
+ if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI))
+ return V;
+ return combineToVWMACC(N, DAG, Subtarget);
case RISCVISD::SUB_VL:
case RISCVISD::VWADD_W_VL:
case RISCVISD::VWADDU_W_VL:
@@ -15683,6 +15743,9 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
NODE_NAME_CASE(VFWSUB_VL)
NODE_NAME_CASE(VFWADD_W_VL)
NODE_NAME_CASE(VFWSUB_W_VL)
+ NODE_NAME_CASE(VWMACC_VL)
+ NODE_NAME_CASE(VWMACCU_VL)
+ NODE_NAME_CASE(VWMACCSU_VL)
NODE_NAME_CASE(VNSRL_VL)
NODE_NAME_CASE(SETCC_VL)
NODE_NAME_CASE(VSELECT_VL)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index fb7b029db662c..dddfe87df9064 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -294,6 +294,12 @@ enum NodeType : unsigned {
VFWADD_W_VL,
VFWSUB_W_VL,
+ // Widening ternary operations with a mask as the fourth operand and VL as the
+ // fifth operand.
+ VWMACC_VL,
+ VWMACCU_VL,
+ VWMACCSU_VL,
+
// Narrowing logical shift right.
// Operands are (source, shift, passthru, mask, vl)
VNSRL_VL,
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
index abf1290bd9d94..e17844c2cc8fe 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
@@ -395,6 +395,19 @@ def riscv_vwaddu_vl : SDNode<"RISCVISD::VWADDU_VL", SDT_RISCVVWIntBinOp_VL, [S
def riscv_vwsub_vl : SDNode<"RISCVISD::VWSUB_VL", SDT_RISCVVWIntBinOp_VL, []>;
def riscv_vwsubu_vl : SDNode<"RISCVISD::VWSUBU_VL", SDT_RISCVVWIntBinOp_VL, []>;
+// Type profile for the widening integer multiply-accumulate VL nodes.
+// Operand layout (index 0 is the result):
+//   0: result  - integer vector, same type as the addend (operand 3)
+//   1: op1     - integer vector narrower than the result, same elt count
+//   2: op2     - same type as op1
+//   3: addend  - same type as the result
+//   4: mask    - i1 vector with op1's element count
+//   5: vl      - XLenVT
+def SDT_RISCVVWIntTernOp_VL : SDTypeProfile<1, 5, [SDTCisVec<0>, SDTCisInt<0>,
+ SDTCisInt<1>,
+ SDTCisSameNumEltsAs<0, 1>,
+ SDTCisOpSmallerThanOp<1, 0>,
+ SDTCisSameAs<1, 2>,
+ SDTCisSameAs<0, 3>,
+ SDTCisSameNumEltsAs<1, 4>,
+ SDTCVecEltisVT<4, i1>,
+ SDTCisVT<5, XLenVT>]>;
+// These nodes are formed by combineToVWMACC from an ADD_VL whose other
+// operand is the matching VWMUL*_VL. Only the VWMACC and VWMACCU forms are
+// marked commutative in op1/op2; the SU form treats its two multiplicands
+// asymmetrically, so it is not.
+def riscv_vwmacc_vl : SDNode<"RISCVISD::VWMACC_VL", SDT_RISCVVWIntTernOp_VL, [SDNPCommutative]>;
+def riscv_vwmaccu_vl : SDNode<"RISCVISD::VWMACCU_VL", SDT_RISCVVWIntTernOp_VL, [SDNPCommutative]>;
+def riscv_vwmaccsu_vl : SDNode<"RISCVISD::VWMACCSU_VL", SDT_RISCVVWIntTernOp_VL, []>;
+
def SDT_RISCVVWFPBinOp_VL : SDTypeProfile<1, 5, [SDTCisVec<0>, SDTCisFP<0>,
SDTCisFP<1>,
SDTCisSameNumEltsAs<0, 1>,
@@ -1407,30 +1420,27 @@ multiclass VPatMultiplyAccVL_VV_VX<PatFrag op, string instruction_name> {
}
}
-multiclass VPatWidenMultiplyAddVL_VV_VX<PatFrag op1, string instruction_name> {
+// Select a VWMACC*_VL node (op1, op2, addend, mask, vl) into the masked
+// "_VV_<LMUL>_MASK" / "_VX_<LMUL>_MASK" pseudo named by instr_name.
+// The accumulator ($rd) is passed as the pseudo's first operand and the
+// mask as V0, with a TAIL_AGNOSTIC policy. The VX pattern matches a
+// splatted scalar as the first multiplicand only.
+multiclass VPatWidenMultiplyAddVL_VV_VX<SDNode vwmacc_op, string instr_name> {
foreach vtiTowti = AllWidenableIntVectors in {
defvar vti = vtiTowti.Vti;
defvar wti = vtiTowti.Wti;
let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates,
GetVTypePredicates<wti>.Predicates) in {
- def : Pat<(wti.Vector
- (riscv_add_vl wti.RegClass:$rd,
- (op1 vti.RegClass:$rs1,
- (vti.Vector vti.RegClass:$rs2),
- srcvalue, (vti.Mask true_mask), VLOpFrag),
- srcvalue, (vti.Mask true_mask), VLOpFrag)),
- (!cast<Instruction>(instruction_name#"_VV_" # vti.LMul.MX)
- wti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2,
- GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
- def : Pat<(wti.Vector
- (riscv_add_vl wti.RegClass:$rd,
- (op1 (SplatPat XLenVT:$rs1),
- (vti.Vector vti.RegClass:$rs2),
- srcvalue, (vti.Mask true_mask), VLOpFrag),
- srcvalue, (vti.Mask true_mask), VLOpFrag)),
- (!cast<Instruction>(instruction_name#"_VX_" # vti.LMul.MX)
- wti.RegClass:$rd, vti.ScalarRegClass:$rs1, vti.RegClass:$rs2,
- GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
+ def : Pat<(vwmacc_op (vti.Vector vti.RegClass:$rs1),
+ (vti.Vector vti.RegClass:$rs2),
+ (wti.Vector wti.RegClass:$rd),
+ (vti.Mask V0), VLOpFrag),
+ (!cast<Instruction>(instr_name#"_VV_"#vti.LMul.MX#"_MASK")
+ wti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2,
+ (vti.Mask V0), GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
+ def : Pat<(vwmacc_op (SplatPat XLenVT:$rs1),
+ (vti.Vector vti.RegClass:$rs2),
+ (wti.Vector wti.RegClass:$rd),
+ (vti.Mask V0), VLOpFrag),
+ (!cast<Instruction>(instr_name#"_VX_"#vti.LMul.MX#"_MASK")
+ wti.RegClass:$rd, vti.ScalarRegClass:$rs1,
+ vti.RegClass:$rs2, (vti.Mask V0), GPR:$vl, vti.Log2SEW,
+ TAIL_AGNOSTIC)>;
}
}
}
@@ -1704,25 +1714,21 @@ defm : VPatMultiplyAccVL_VV_VX<riscv_add_vl_oneuse, "PseudoVMACC">;
defm : VPatMultiplyAccVL_VV_VX<riscv_sub_vl_oneuse, "PseudoVNMSAC">;
// 11.14. Vector Widening Integer Multiply-Add Instructions
-defm : VPatWidenMultiplyAddVL_VV_VX<riscv_vwmul_vl_oneuse, "PseudoVWMACC">;
-defm : VPatWidenMultiplyAddVL_VV_VX<riscv_vwmulu_vl_oneuse, "PseudoVWMACCU">;
-defm : VPatWidenMultiplyAddVL_VV_VX<riscv_vwmulsu_vl_oneuse, "PseudoVWMACCSU">;
+// The VWMACC*_VL nodes are formed in C++ (combineToVWMACC) from an ADD_VL of
+// a VWMUL*_VL, so these patterns match the fused node directly rather than
+// the add/multiply pair.
+defm : VPatWidenMultiplyAddVL_VV_VX<riscv_vwmacc_vl, "PseudoVWMACC">;
+defm : VPatWidenMultiplyAddVL_VV_VX<riscv_vwmaccu_vl, "PseudoVWMACCU">;
+defm : VPatWidenMultiplyAddVL_VV_VX<riscv_vwmaccsu_vl, "PseudoVWMACCSU">;
+// The multiclass VX pattern only matches a splat as the first multiplicand.
+// When VWMACCSU_VL has the splat as its second multiplicand, select the
+// vwmaccus.vx pseudo instead, passing the scalar ($rs2) ahead of the vector
+// ($rs1).
foreach vtiTowti = AllWidenableIntVectors in {
defvar vti = vtiTowti.Vti;
defvar wti = vtiTowti.Wti;
let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates,
GetVTypePredicates<wti>.Predicates) in
- def : Pat<(wti.Vector
- (riscv_add_vl wti.RegClass:$rd,
- (riscv_vwmulsu_vl_oneuse (vti.Vector vti.RegClass:$rs1),
- (SplatPat XLenVT:$rs2),
- srcvalue,
- (vti.Mask true_mask),
- VLOpFrag),
- srcvalue, (vti.Mask true_mask),VLOpFrag)),
- (!cast<Instruction>("PseudoVWMACCUS_VX_" # vti.LMul.MX)
- wti.RegClass:$rd, vti.ScalarRegClass:$rs2, vti.RegClass:$rs1,
- GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
+ def : Pat<(riscv_vwmaccsu_vl (vti.Vector vti.RegClass:$rs1),
+ (SplatPat XLenVT:$rs2),
+ (wti.Vector wti.RegClass:$rd),
+ (vti.Mask V0), VLOpFrag),
+ (!cast<Instruction>("PseudoVWMACCUS_VX_"#vti.LMul.MX#"_MASK")
+ wti.RegClass:$rd, vti.ScalarRegClass:$rs2, vti.RegClass:$rs1,
+ (vti.Mask V0), GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
}
// 11.15. Vector Integer Merge Instructions
diff --git a/llvm/test/CodeGen/RISCV/rvv/vwmaccus-vp.ll b/llvm/test/CodeGen/RISCV/rvv/vwmaccsu-vp.ll
similarity index 100%
rename from llvm/test/CodeGen/RISCV/rvv/vwmaccus-vp.ll
rename to llvm/test/CodeGen/RISCV/rvv/vwmaccsu-vp.ll
More information about the llvm-commits mailing list