[llvm] a986998 - [RISCV] Introduce RISCVISD::VWMACC(U/SU)_VL opcode

Nitin John Raj via llvm-commits llvm-commits at lists.llvm.org
Fri Jun 16 16:15:23 PDT 2023


Author: Nitin John Raj
Date: 2023-06-16T16:11:35-07:00
New Revision: a986998bad9254ae0513d548edebd55a1153b609

URL: https://github.com/llvm/llvm-project/commit/a986998bad9254ae0513d548edebd55a1153b609
DIFF: https://github.com/llvm/llvm-project/commit/a986998bad9254ae0513d548edebd55a1153b609.diff

LOG: [RISCV] Introduce RISCVISD::VWMACC(U/SU)_VL opcode

Differential Revision: https://reviews.llvm.org/D153057

Added: 
    llvm/test/CodeGen/RISCV/rvv/vwmaccsu-vp.ll

Modified: 
    llvm/lib/Target/RISCV/RISCVISelLowering.cpp
    llvm/lib/Target/RISCV/RISCVISelLowering.h
    llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td

Removed: 
    llvm/test/CodeGen/RISCV/rvv/vwmaccus-vp.ll


################################################################################
diff  --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 02f3b584d19aa..d7b841ccb2dfd 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -12136,6 +12136,63 @@ static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG,
   return convertFromScalableVector(VT, Res, DAG, Subtarget);
 }
 
+static SDValue combineToVWMACC(SDNode *N, SelectionDAG &DAG,
+                               const RISCVSubtarget &Subtarget) {
+  assert(N->getOpcode() == RISCVISD::ADD_VL);
+  SDValue Addend = N->getOperand(0); // may trade places with MulOp below
+  SDValue MulOp = N->getOperand(1);
+  SDValue AddMergeOp = N->getOperand(2);
+  // Fold add(x, vwmul*(a, b)) into vwmacc*(a, b, x); empty SDValue if no match.
+  if (!AddMergeOp.isUndef()) // only fold when the add's merge operand is undef
+    return SDValue();
+  // Recognizes the three widening-multiply opcodes handled by this fold.
+  auto IsVWMulOpc = [](unsigned Opc) {
+    switch (Opc) {
+    case RISCVISD::VWMUL_VL:
+    case RISCVISD::VWMULU_VL:
+    case RISCVISD::VWMULSU_VL:
+      return true;
+    default:
+      return false;
+    }
+  };
+  // The multiply may be either ADD_VL operand; swap so that MulOp names it.
+  if (!IsVWMulOpc(MulOp.getOpcode()))
+    std::swap(Addend, MulOp);
+  // Still not a widening multiply after the swap: nothing to fold.
+  if (!IsVWMulOpc(MulOp.getOpcode()))
+    return SDValue();
+
+  SDValue MulMergeOp = MulOp.getOperand(2);
+  // The multiply's merge operand must be undef as well.
+  if (!MulMergeOp.isUndef())
+    return SDValue();
+  // NOTE(review): MulOp's use count is not checked; confirm this is intended.
+  SDValue AddMask = N->getOperand(3);
+  SDValue AddVL = N->getOperand(4);
+  SDValue MulMask = MulOp.getOperand(3);
+  SDValue MulVL = MulOp.getOperand(4);
+  // The fused node takes a single mask and VL, so both nodes must agree on them.
+  if (AddMask != MulMask || AddVL != MulVL)
+    return SDValue();
+  // Map VWMUL{,U,SU}_VL to VWMACC{,U,SU}_VL by enum offset (asserted below).
+  unsigned Opc = RISCVISD::VWMACC_VL + MulOp.getOpcode() - RISCVISD::VWMUL_VL;
+  static_assert(RISCVISD::VWMACC_VL + 1 == RISCVISD::VWMACCU_VL,
+                "Unexpected opcode after VWMACC_VL");
+  static_assert(RISCVISD::VWMACC_VL + 2 == RISCVISD::VWMACCSU_VL,
+                "Unexpected opcode after VWMACC_VL!");
+  static_assert(RISCVISD::VWMUL_VL + 1 == RISCVISD::VWMULU_VL,
+                "Unexpected opcode after VWMUL_VL!");
+  static_assert(RISCVISD::VWMUL_VL + 2 == RISCVISD::VWMULSU_VL,
+                "Unexpected opcode after VWMUL_VL!");
+  // New node operands: (mul operand 0, mul operand 1, addend, mask, vl).
+  SDLoc DL(N);
+  EVT VT = N->getValueType(0); // result keeps the type of the original ADD_VL
+  SDValue Ops[] = {MulOp.getOperand(0), MulOp.getOperand(1), Addend, AddMask,
+                   AddVL};
+  return DAG.getNode(Opc, DL, VT, Ops);
+}
+
 SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
                                                DAGCombinerInfo &DCI) const {
   SelectionDAG &DAG = DCI.DAG;
@@ -12546,6 +12603,9 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
     break;
   }
   case RISCVISD::ADD_VL:
+    if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI))
+      return V;
+    return combineToVWMACC(N, DAG, Subtarget);
   case RISCVISD::SUB_VL:
   case RISCVISD::VWADD_W_VL:
   case RISCVISD::VWADDU_W_VL:
@@ -15683,6 +15743,9 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
   NODE_NAME_CASE(VFWSUB_VL)
   NODE_NAME_CASE(VFWADD_W_VL)
   NODE_NAME_CASE(VFWSUB_W_VL)
+  NODE_NAME_CASE(VWMACC_VL)
+  NODE_NAME_CASE(VWMACCU_VL)
+  NODE_NAME_CASE(VWMACCSU_VL)
   NODE_NAME_CASE(VNSRL_VL)
   NODE_NAME_CASE(SETCC_VL)
   NODE_NAME_CASE(VSELECT_VL)

diff  --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index fb7b029db662c..dddfe87df9064 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -294,6 +294,12 @@ enum NodeType : unsigned {
   VFWADD_W_VL,
   VFWSUB_W_VL,
 
+  // Widening integer multiply-accumulate (ternary) operations.
+  // Operands are (multiplicand, multiplier, addend, mask, vl)
+  VWMACC_VL,
+  VWMACCU_VL,
+  VWMACCSU_VL,
+
   // Narrowing logical shift right.
   // Operands are (source, shift, passthru, mask, vl)
   VNSRL_VL,

diff  --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
index abf1290bd9d94..e17844c2cc8fe 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
@@ -395,6 +395,19 @@ def riscv_vwaddu_vl  : SDNode<"RISCVISD::VWADDU_VL",  SDT_RISCVVWIntBinOp_VL, [S
 def riscv_vwsub_vl   : SDNode<"RISCVISD::VWSUB_VL",   SDT_RISCVVWIntBinOp_VL, []>;
 def riscv_vwsubu_vl  : SDNode<"RISCVISD::VWSUBU_VL",  SDT_RISCVVWIntBinOp_VL, []>;
 
+def SDT_RISCVVWIntTernOp_VL : SDTypeProfile<1, 5, [SDTCisVec<0>, SDTCisInt<0>,
+                                                   SDTCisInt<1>,
+                                                   SDTCisSameNumEltsAs<0, 1>,
+                                                   SDTCisOpSmallerThanOp<1, 0>,
+                                                   SDTCisSameAs<1, 2>,
+                                                   SDTCisSameAs<0, 3>, // addend
+                                                   SDTCisSameNumEltsAs<1, 4>,
+                                                   SDTCVecEltisVT<4, i1>, // mask
+                                                   SDTCisVT<5, XLenVT>]>; // VL
+def riscv_vwmacc_vl : SDNode<"RISCVISD::VWMACC_VL", SDT_RISCVVWIntTernOp_VL, [SDNPCommutative]>;
+def riscv_vwmaccu_vl : SDNode<"RISCVISD::VWMACCU_VL", SDT_RISCVVWIntTernOp_VL, [SDNPCommutative]>;
+def riscv_vwmaccsu_vl : SDNode<"RISCVISD::VWMACCSU_VL", SDT_RISCVVWIntTernOp_VL, []>;
+
 def SDT_RISCVVWFPBinOp_VL : SDTypeProfile<1, 5, [SDTCisVec<0>, SDTCisFP<0>,
                                                  SDTCisFP<1>,
                                                  SDTCisSameNumEltsAs<0, 1>,
@@ -1407,30 +1420,27 @@ multiclass VPatMultiplyAccVL_VV_VX<PatFrag op, string instruction_name> {
   }
 }
 
-multiclass VPatWidenMultiplyAddVL_VV_VX<PatFrag op1, string instruction_name> {
+multiclass VPatWidenMultiplyAddVL_VV_VX<SDNode vwmacc_op, string instr_name> {
   foreach vtiTowti = AllWidenableIntVectors in {
     defvar vti = vtiTowti.Vti;
     defvar wti = vtiTowti.Wti;
     let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates,
                                  GetVTypePredicates<wti>.Predicates) in {
-      def : Pat<(wti.Vector
-               (riscv_add_vl wti.RegClass:$rd,
-                             (op1 vti.RegClass:$rs1,
-                                  (vti.Vector vti.RegClass:$rs2),
-                                  srcvalue, (vti.Mask true_mask), VLOpFrag),
-                            srcvalue, (vti.Mask true_mask), VLOpFrag)),
-              (!cast<Instruction>(instruction_name#"_VV_" # vti.LMul.MX)
-                   wti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2,
-                   GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
-      def : Pat<(wti.Vector
-               (riscv_add_vl wti.RegClass:$rd,
-                            (op1 (SplatPat XLenVT:$rs1),
-                                 (vti.Vector vti.RegClass:$rs2),
-                                 srcvalue, (vti.Mask true_mask), VLOpFrag),
-                             srcvalue, (vti.Mask true_mask), VLOpFrag)),
-              (!cast<Instruction>(instruction_name#"_VX_" # vti.LMul.MX)
-                   wti.RegClass:$rd, vti.ScalarRegClass:$rs1, vti.RegClass:$rs2,
-                   GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
+      def : Pat<(vwmacc_op (vti.Vector vti.RegClass:$rs1),
+                           (vti.Vector vti.RegClass:$rs2),
+                           (wti.Vector wti.RegClass:$rd),
+                           (vti.Mask V0), VLOpFrag),
+                (!cast<Instruction>(instr_name#"_VV_"#vti.LMul.MX#"_MASK")
+                    wti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2,
+                    (vti.Mask V0), GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
+      def : Pat<(vwmacc_op (SplatPat XLenVT:$rs1),
+                           (vti.Vector vti.RegClass:$rs2),
+                           (wti.Vector wti.RegClass:$rd),
+                           (vti.Mask V0), VLOpFrag),
+                (!cast<Instruction>(instr_name#"_VX_"#vti.LMul.MX#"_MASK")
+                    wti.RegClass:$rd, vti.ScalarRegClass:$rs1,
+                    vti.RegClass:$rs2, (vti.Mask V0), GPR:$vl, vti.Log2SEW,
+                    TAIL_AGNOSTIC)>;
     }
   }
 }
@@ -1704,25 +1714,21 @@ defm : VPatMultiplyAccVL_VV_VX<riscv_add_vl_oneuse, "PseudoVMACC">;
 defm : VPatMultiplyAccVL_VV_VX<riscv_sub_vl_oneuse, "PseudoVNMSAC">;
 
 // 11.14. Vector Widening Integer Multiply-Add Instructions
-defm : VPatWidenMultiplyAddVL_VV_VX<riscv_vwmul_vl_oneuse, "PseudoVWMACC">;
-defm : VPatWidenMultiplyAddVL_VV_VX<riscv_vwmulu_vl_oneuse, "PseudoVWMACCU">;
-defm : VPatWidenMultiplyAddVL_VV_VX<riscv_vwmulsu_vl_oneuse, "PseudoVWMACCSU">;
+defm : VPatWidenMultiplyAddVL_VV_VX<riscv_vwmacc_vl, "PseudoVWMACC">;
+defm : VPatWidenMultiplyAddVL_VV_VX<riscv_vwmaccu_vl, "PseudoVWMACCU">;
+defm : VPatWidenMultiplyAddVL_VV_VX<riscv_vwmaccsu_vl, "PseudoVWMACCSU">;
 foreach vtiTowti = AllWidenableIntVectors in {
   defvar vti = vtiTowti.Vti;
   defvar wti = vtiTowti.Wti;
   let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates,
                                GetVTypePredicates<wti>.Predicates) in
-  def : Pat<(wti.Vector
-             (riscv_add_vl wti.RegClass:$rd,
-                           (riscv_vwmulsu_vl_oneuse (vti.Vector vti.RegClass:$rs1),
-                                                    (SplatPat XLenVT:$rs2),
-                                                    srcvalue,
-                                                    (vti.Mask true_mask),
-                                                    VLOpFrag),
-                           srcvalue, (vti.Mask true_mask),VLOpFrag)),
-            (!cast<Instruction>("PseudoVWMACCUS_VX_" # vti.LMul.MX)
-                 wti.RegClass:$rd, vti.ScalarRegClass:$rs2, vti.RegClass:$rs1,
-                 GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
+  def : Pat<(riscv_vwmaccsu_vl (vti.Vector vti.RegClass:$rs1),
+                               (SplatPat XLenVT:$rs2),
+                               (wti.Vector wti.RegClass:$rd),
+                               (vti.Mask V0), VLOpFrag),
+            (!cast<Instruction>("PseudoVWMACCUS_VX_"#vti.LMul.MX#"_MASK")
+                wti.RegClass:$rd, vti.ScalarRegClass:$rs2, vti.RegClass:$rs1,
+                (vti.Mask V0), GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
 }
 
 // 11.15. Vector Integer Merge Instructions

diff  --git a/llvm/test/CodeGen/RISCV/rvv/vwmaccus-vp.ll b/llvm/test/CodeGen/RISCV/rvv/vwmaccsu-vp.ll
similarity index 100%
rename from llvm/test/CodeGen/RISCV/rvv/vwmaccus-vp.ll
rename to llvm/test/CodeGen/RISCV/rvv/vwmaccsu-vp.ll


        


More information about the llvm-commits mailing list