[llvm] [RISCV] Replace RISCVISD::VP_MERGE_VL with a new node that has a separate passthru operand. (PR #75682)

via llvm-commits llvm-commits at lists.llvm.org
Fri Dec 15 19:13:24 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-risc-v

Author: Craig Topper (topperc)

<details>
<summary>Changes</summary>

ISD::VP_MERGE takes elements at positions past EVL from the false operand. The vmerge instruction, however, encodes three register operands and takes tail elements from its destination register vd.

This patch adds a new RISCVISD opcode, VMERGE_VL, that models the tail source explicitly as a separate passthru operand. During lowering, we copy the false operand into this passthru operand.

I think we can merge RISCVISD::VSELECT_VL with this new opcode by using an UNDEF passthru, but I'll save that for another patch.

---
Full diff: https://github.com/llvm/llvm-project/pull/75682.diff


3 Files Affected:

- (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+19-7) 
- (modified) llvm/lib/Target/RISCV/RISCVISelLowering.h (+2-4) 
- (modified) llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td (+69-56) 


``````````diff
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 4a8ff73ec47295..7a9cab0aeb0db4 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -5519,7 +5519,7 @@ static unsigned getRISCVVLOp(SDValue Op) {
   case ISD::VP_SELECT:
     return RISCVISD::VSELECT_VL;
   case ISD::VP_MERGE:
-    return RISCVISD::VP_MERGE_VL;
+    return RISCVISD::VMERGE_VL;
   case ISD::VP_ASHR:
     return RISCVISD::SRA_VL;
   case ISD::VP_LSHR:
@@ -5567,6 +5567,8 @@ static bool hasMergeOp(unsigned Opcode) {
     return true;
   if (Opcode >= RISCVISD::STRICT_FADD_VL && Opcode <= RISCVISD::STRICT_FDIV_VL)
     return true;
+  if (Opcode == RISCVISD::VMERGE_VL)
+    return true;
   return false;
 }
 
@@ -8229,8 +8231,8 @@ static SDValue lowerVectorIntrinsicScalars(SDValue Op, SelectionDAG &DAG,
                          AVL);
     // TUMA or TUMU: Currently we always emit tumu policy regardless of tuma.
     // It's fine because vmerge does not care mask policy.
-    return DAG.getNode(RISCVISD::VP_MERGE_VL, DL, VT, Mask, Vec, MaskedOff,
-                       AVL);
+    return DAG.getNode(RISCVISD::VMERGE_VL, DL, VT, Mask, Vec,
+                       MaskedOff, MaskedOff, AVL);
   }
   }
 
@@ -10303,9 +10305,19 @@ SDValue RISCVTargetLowering::lowerVPOp(SDValue Op, SelectionDAG &DAG) const {
   for (const auto &OpIdx : enumerate(Op->ops())) {
     SDValue V = OpIdx.value();
     assert(!isa<VTSDNode>(V) && "Unexpected VTSDNode node!");
-    // Add dummy merge value before the mask.
-    if (HasMergeOp && *ISD::getVPMaskIdx(Op.getOpcode()) == OpIdx.index())
-      Ops.push_back(DAG.getUNDEF(ContainerVT));
+    // Add dummy merge value before the mask. Or if there isn't a mask, before
+    // EVL.
+    if (HasMergeOp) {
+      auto MaskIdx = ISD::getVPMaskIdx(Op.getOpcode());
+      if (MaskIdx) {
+        if (*MaskIdx == OpIdx.index())
+          Ops.push_back(DAG.getUNDEF(ContainerVT));
+      } else if (ISD::getVPExplicitVectorLengthIdx(Op.getOpcode()) == OpIdx.index()) {
+        // For VP_MERGE, copy the false operand instead of an undef value.
+        assert(Op.getOpcode() == ISD::VP_MERGE);
+        Ops.push_back(Ops.back());
+      }
+    }
     // Pass through operands which aren't fixed-length vectors.
     if (!V.getValueType().isFixedLengthVector()) {
       Ops.push_back(V);
@@ -18561,7 +18573,7 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
   NODE_NAME_CASE(VNSRL_VL)
   NODE_NAME_CASE(SETCC_VL)
   NODE_NAME_CASE(VSELECT_VL)
-  NODE_NAME_CASE(VP_MERGE_VL)
+  NODE_NAME_CASE(VMERGE_VL)
   NODE_NAME_CASE(VMAND_VL)
   NODE_NAME_CASE(VMOR_VL)
   NODE_NAME_CASE(VMXOR_VL)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index 41a2dc5771c82d..765c6d3fb3b7c6 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -332,10 +332,8 @@ enum NodeType : unsigned {
 
   // Vector select with an additional VL operand. This operation is unmasked.
   VSELECT_VL,
-  // Vector select with operand #2 (the value when the condition is false) tied
-  // to the destination and an additional VL operand. This operation is
-  // unmasked.
-  VP_MERGE_VL,
+  // General vmerge node with mask, true, false, passthru, and vl operands.
+  VMERGE_VL,
 
   // Mask binary operators.
   VMAND_VL,
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
index dc6b57fad32105..33bdc3366aa3e3 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
@@ -344,7 +344,14 @@ def SDT_RISCVSelect_VL  : SDTypeProfile<1, 4, [
 ]>;
 
 def riscv_vselect_vl  : SDNode<"RISCVISD::VSELECT_VL", SDT_RISCVSelect_VL>;
-def riscv_vp_merge_vl : SDNode<"RISCVISD::VP_MERGE_VL", SDT_RISCVSelect_VL>;
+
+def SDT_RISCVVMERGE_VL  : SDTypeProfile<1, 5, [
+  SDTCisVec<0>, SDTCisVec<1>, SDTCisSameNumEltsAs<0, 1>, SDTCVecEltisVT<1, i1>,
+  SDTCisSameAs<0, 2>, SDTCisSameAs<2, 3>, SDTCisSameAs<0, 4>,
+  SDTCisVT<5, XLenVT>
+]>;
+
+def riscv_vmerge_vl : SDNode<"RISCVISD::VMERGE_VL", SDT_RISCVVMERGE_VL>;
 
 def SDT_RISCVVMSETCLR_VL : SDTypeProfile<1, 1, [SDTCVecEltisVT<0, i1>,
                                                 SDTCisVT<1, XLenVT>]>;
@@ -675,14 +682,14 @@ multiclass VPatTiedBinaryNoMaskVL_V<SDNode vop,
                      op2_reg_class:$rs2,
                      GPR:$vl, sew, TAIL_AGNOSTIC)>;
   // Tail undisturbed
-  def : Pat<(riscv_vp_merge_vl true_mask,
+  def : Pat<(riscv_vmerge_vl true_mask,
              (result_type (vop
                            result_reg_class:$rs1,
                            (op2_type op2_reg_class:$rs2),
                            srcvalue,
                            true_mask,
                            VLOpFrag)),
-             result_reg_class:$rs1, VLOpFrag),
+             result_reg_class:$rs1, result_reg_class:$rs1, VLOpFrag),
             (!cast<Instruction>(instruction_name#"_"#suffix#"_"# vlmul.MX#"_TIED")
                      result_reg_class:$rs1,
                      op2_reg_class:$rs2,
@@ -712,14 +719,14 @@ multiclass VPatTiedBinaryNoMaskVL_V_RM<SDNode vop,
                      FRM_DYN,
                      GPR:$vl, sew, TAIL_AGNOSTIC)>;
   // Tail undisturbed
-  def : Pat<(riscv_vp_merge_vl true_mask,
+  def : Pat<(riscv_vmerge_vl true_mask,
              (result_type (vop
                            result_reg_class:$rs1,
                            (op2_type op2_reg_class:$rs2),
                            srcvalue,
                            true_mask,
                            VLOpFrag)),
-             result_reg_class:$rs1, VLOpFrag),
+             result_reg_class:$rs1, result_reg_class:$rs1, VLOpFrag),
             (!cast<Instruction>(instruction_name#"_"#suffix#"_"# vlmul.MX#"_TIED")
                      result_reg_class:$rs1,
                      op2_reg_class:$rs2,
@@ -1697,21 +1704,21 @@ multiclass VPatMultiplyAccVL_VV_VX<PatFrag op, string instruction_name> {
   foreach vti = AllIntegerVectors in {
   defvar suffix = vti.LMul.MX;
   let Predicates = GetVTypePredicates<vti>.Predicates in {
-    def : Pat<(riscv_vp_merge_vl (vti.Mask V0),
+    def : Pat<(riscv_vmerge_vl (vti.Mask V0),
                 (vti.Vector (op vti.RegClass:$rd,
                                 (riscv_mul_vl_oneuse vti.RegClass:$rs1, vti.RegClass:$rs2,
                                     srcvalue, (vti.Mask true_mask), VLOpFrag),
                                 srcvalue, (vti.Mask true_mask), VLOpFrag)),
-                            vti.RegClass:$rd, VLOpFrag),
+                            vti.RegClass:$rd, vti.RegClass:$rd, VLOpFrag),
               (!cast<Instruction>(instruction_name#"_VV_"# suffix #"_MASK")
                    vti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2,
                    (vti.Mask V0), GPR:$vl, vti.Log2SEW, TU_MU)>;
-    def : Pat<(riscv_vp_merge_vl (vti.Mask V0),
+    def : Pat<(riscv_vmerge_vl (vti.Mask V0),
                 (vti.Vector (op vti.RegClass:$rd,
                                 (riscv_mul_vl_oneuse (SplatPat XLenVT:$rs1), vti.RegClass:$rs2,
                                     srcvalue, (vti.Mask true_mask), VLOpFrag),
                                 srcvalue, (vti.Mask true_mask), VLOpFrag)),
-                            vti.RegClass:$rd, VLOpFrag),
+                            vti.RegClass:$rd, vti.RegClass:$rd, VLOpFrag),
               (!cast<Instruction>(instruction_name#"_VX_"# suffix #"_MASK")
                    vti.RegClass:$rd, vti.ScalarRegClass:$rs1, vti.RegClass:$rs2,
                    (vti.Mask V0), GPR:$vl, vti.Log2SEW, TU_MU)>;
@@ -1840,17 +1847,17 @@ multiclass VPatFPMulAccVL_VV_VF<PatFrag vop, string instruction_name> {
   foreach vti = AllFloatVectors in {
   defvar suffix = vti.LMul.MX;
   let Predicates = GetVTypePredicates<vti>.Predicates in {
-    def : Pat<(riscv_vp_merge_vl (vti.Mask V0),
+    def : Pat<(riscv_vmerge_vl (vti.Mask V0),
                            (vti.Vector (vop vti.RegClass:$rs1, vti.RegClass:$rs2,
                             vti.RegClass:$rd, (vti.Mask true_mask), VLOpFrag)),
-                            vti.RegClass:$rd, VLOpFrag),
+                            vti.RegClass:$rd, vti.RegClass:$rd, VLOpFrag),
               (!cast<Instruction>(instruction_name#"_VV_"# suffix #"_MASK")
                    vti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2,
                    (vti.Mask V0), GPR:$vl, vti.Log2SEW, TU_MU)>;
-    def : Pat<(riscv_vp_merge_vl (vti.Mask V0),
+    def : Pat<(riscv_vmerge_vl (vti.Mask V0),
                            (vti.Vector (vop (SplatFPOp vti.ScalarRegClass:$rs1), vti.RegClass:$rs2,
                             vti.RegClass:$rd, (vti.Mask true_mask), VLOpFrag)),
-                            vti.RegClass:$rd, VLOpFrag),
+                            vti.RegClass:$rd, vti.RegClass:$rd, VLOpFrag),
               (!cast<Instruction>(instruction_name#"_V" # vti.ScalarSuffix # "_" # suffix # "_MASK")
                    vti.RegClass:$rd, vti.ScalarRegClass:$rs1, vti.RegClass:$rs2,
                    (vti.Mask V0), GPR:$vl, vti.Log2SEW, TU_MU)>;
@@ -1876,10 +1883,10 @@ multiclass VPatFPMulAccVL_VV_VF_RM<PatFrag vop, string instruction_name> {
   foreach vti = AllFloatVectors in {
   defvar suffix = vti.LMul.MX;
   let Predicates = GetVTypePredicates<vti>.Predicates in {
-    def : Pat<(riscv_vp_merge_vl (vti.Mask V0),
+    def : Pat<(riscv_vmerge_vl (vti.Mask V0),
                            (vti.Vector (vop vti.RegClass:$rs1, vti.RegClass:$rs2,
                             vti.RegClass:$rd, (vti.Mask true_mask), VLOpFrag)),
-                            vti.RegClass:$rd, VLOpFrag),
+                            vti.RegClass:$rd, vti.RegClass:$rd, VLOpFrag),
               (!cast<Instruction>(instruction_name#"_VV_"# suffix #"_MASK")
                    vti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2,
                    (vti.Mask V0),
@@ -1887,10 +1894,10 @@ multiclass VPatFPMulAccVL_VV_VF_RM<PatFrag vop, string instruction_name> {
                    // RISCVInsertReadWriteCSR
                    FRM_DYN,
                    GPR:$vl, vti.Log2SEW, TU_MU)>;
-    def : Pat<(riscv_vp_merge_vl (vti.Mask V0),
+    def : Pat<(riscv_vmerge_vl (vti.Mask V0),
                            (vti.Vector (vop (SplatFPOp vti.ScalarRegClass:$rs1), vti.RegClass:$rs2,
                             vti.RegClass:$rd, (vti.Mask true_mask), VLOpFrag)),
-                            vti.RegClass:$rd, VLOpFrag),
+                            vti.RegClass:$rd, vti.RegClass:$rd, VLOpFrag),
               (!cast<Instruction>(instruction_name#"_V" # vti.ScalarSuffix # "_" # suffix # "_MASK")
                    vti.RegClass:$rd, vti.ScalarRegClass:$rs1, vti.RegClass:$rs2,
                    (vti.Mask V0),
@@ -2273,29 +2280,32 @@ foreach vti = AllIntegerVectors in {
                    (vti.Vector (IMPLICIT_DEF)),
                    vti.RegClass:$rs2, simm5:$rs1, (vti.Mask V0), GPR:$vl, vti.Log2SEW)>;
 
-    def : Pat<(vti.Vector (riscv_vp_merge_vl (vti.Mask V0),
-                                             vti.RegClass:$rs1,
-                                             vti.RegClass:$rs2,
-                                             VLOpFrag)),
+    def : Pat<(vti.Vector (riscv_vmerge_vl (vti.Mask V0),
+                                           vti.RegClass:$rs1,
+                                           vti.RegClass:$rs2,
+                                           vti.RegClass:$merge,
+                                           VLOpFrag)),
               (!cast<Instruction>("PseudoVMERGE_VVM_"#vti.LMul.MX)
-                   vti.RegClass:$rs2, vti.RegClass:$rs2, vti.RegClass:$rs1,
-                   (vti.Mask V0), GPR:$vl, vti.Log2SEW)>;
+                  vti.RegClass:$merge, vti.RegClass:$rs2, vti.RegClass:$rs1,
+                  (vti.Mask V0), GPR:$vl, vti.Log2SEW)>;
 
-    def : Pat<(vti.Vector (riscv_vp_merge_vl (vti.Mask V0),
-                                             (SplatPat XLenVT:$rs1),
-                                             vti.RegClass:$rs2,
-                                             VLOpFrag)),
+    def : Pat<(vti.Vector (riscv_vmerge_vl (vti.Mask V0),
+                                            (SplatPat XLenVT:$rs1),
+                                            vti.RegClass:$rs2,
+                                            vti.RegClass:$merge,
+                                            VLOpFrag)),
               (!cast<Instruction>("PseudoVMERGE_VXM_"#vti.LMul.MX)
-                   vti.RegClass:$rs2, vti.RegClass:$rs2, GPR:$rs1,
-                   (vti.Mask V0), GPR:$vl, vti.Log2SEW)>;
-
-    def : Pat<(vti.Vector (riscv_vp_merge_vl (vti.Mask V0),
-                                             (SplatPat_simm5 simm5:$rs1),
-                                             vti.RegClass:$rs2,
-                                             VLOpFrag)),
+                  vti.RegClass:$merge, vti.RegClass:$rs2, GPR:$rs1,
+                  (vti.Mask V0), GPR:$vl, vti.Log2SEW)>;
+
+    def : Pat<(vti.Vector (riscv_vmerge_vl (vti.Mask V0),
+                                           (SplatPat_simm5 simm5:$rs1),
+                                           vti.RegClass:$rs2,
+                                           vti.RegClass:$merge,
+                                           VLOpFrag)),
               (!cast<Instruction>("PseudoVMERGE_VIM_"#vti.LMul.MX)
-                   vti.RegClass:$rs2, vti.RegClass:$rs2, simm5:$rs1,
-                   (vti.Mask V0), GPR:$vl, vti.Log2SEW)>;
+                  vti.RegClass:$merge, vti.RegClass:$rs2, simm5:$rs1,
+                  (vti.Mask V0), GPR:$vl, vti.Log2SEW)>;
   }
 }
 
@@ -2493,21 +2503,23 @@ foreach fvti = AllFloatVectors in {
                    (fvti.Vector (IMPLICIT_DEF)),
                    fvti.RegClass:$rs2, 0, (fvti.Mask V0), GPR:$vl, fvti.Log2SEW)>;
 
-    def : Pat<(fvti.Vector (riscv_vp_merge_vl (fvti.Mask V0),
-                                              fvti.RegClass:$rs1,
-                                              fvti.RegClass:$rs2,
-                                              VLOpFrag)),
-              (!cast<Instruction>("PseudoVMERGE_VVM_"#fvti.LMul.MX)
-                   fvti.RegClass:$rs2, fvti.RegClass:$rs2, fvti.RegClass:$rs1, (fvti.Mask V0),
-                   GPR:$vl, fvti.Log2SEW)>;
-
-    def : Pat<(fvti.Vector (riscv_vp_merge_vl (fvti.Mask V0),
-                                              (SplatFPOp (fvti.Scalar fpimm0)),
-                                              fvti.RegClass:$rs2,
-                                              VLOpFrag)),
-              (!cast<Instruction>("PseudoVMERGE_VIM_"#fvti.LMul.MX)
-                   fvti.RegClass:$rs2, fvti.RegClass:$rs2, 0, (fvti.Mask V0),
-                   GPR:$vl, fvti.Log2SEW)>;
+  def : Pat<(fvti.Vector (riscv_vmerge_vl (fvti.Mask V0),
+                                          fvti.RegClass:$rs1,
+                                          fvti.RegClass:$rs2,
+                                          fvti.RegClass:$merge,
+                                          VLOpFrag)),
+            (!cast<Instruction>("PseudoVMERGE_VVM_"#fvti.LMul.MX)
+                 fvti.RegClass:$merge, fvti.RegClass:$rs2, fvti.RegClass:$rs1, (fvti.Mask V0),
+                 GPR:$vl, fvti.Log2SEW)>;
+
+  def : Pat<(fvti.Vector (riscv_vmerge_vl (fvti.Mask V0),
+                                          (SplatFPOp (fvti.Scalar fpimm0)),
+                                          fvti.RegClass:$rs2,
+                                          fvti.RegClass:$merge,
+                                          VLOpFrag)),
+            (!cast<Instruction>("PseudoVMERGE_VIM_"#fvti.LMul.MX)
+                 fvti.RegClass:$merge, fvti.RegClass:$rs2, 0, (fvti.Mask V0),
+                 GPR:$vl, fvti.Log2SEW)>;
   }
 
   let Predicates = GetVTypePredicates<fvti>.Predicates in {
@@ -2521,12 +2533,13 @@ foreach fvti = AllFloatVectors in {
                    (fvti.Scalar fvti.ScalarRegClass:$rs1),
                    (fvti.Mask V0), GPR:$vl, fvti.Log2SEW)>;
 
-    def : Pat<(fvti.Vector (riscv_vp_merge_vl (fvti.Mask V0),
-                                              (SplatFPOp fvti.ScalarRegClass:$rs1),
-                                              fvti.RegClass:$rs2,
-                                              VLOpFrag)),
+    def : Pat<(fvti.Vector (riscv_vmerge_vl (fvti.Mask V0),
+                                            (SplatFPOp fvti.ScalarRegClass:$rs1),
+                                            fvti.RegClass:$rs2,
+                                            fvti.RegClass:$merge,
+                                            VLOpFrag)),
               (!cast<Instruction>("PseudoVFMERGE_V"#fvti.ScalarSuffix#"M_"#fvti.LMul.MX)
-                   fvti.RegClass:$rs2, fvti.RegClass:$rs2,
+                   fvti.RegClass:$merge, fvti.RegClass:$rs2,
                    (fvti.Scalar fvti.ScalarRegClass:$rs1),
                    (fvti.Mask V0), GPR:$vl, fvti.Log2SEW)>;
 

``````````

</details>


https://github.com/llvm/llvm-project/pull/75682


More information about the llvm-commits mailing list