[llvm] 16d3a82 - [RISCV] Add merge operand to RISCVISD::VRGATHER*_VL nodes.

Craig Topper via llvm-commits llvm-commits at lists.llvm.org
Mon Jun 20 19:16:46 PDT 2022


Author: Craig Topper
Date: 2022-06-20T18:58:24-07:00
New Revision: 16d3a82de53dab4bb5ed468aff92df276f8a6e39

URL: https://github.com/llvm/llvm-project/commit/16d3a82de53dab4bb5ed468aff92df276f8a6e39
DIFF: https://github.com/llvm/llvm-project/commit/16d3a82de53dab4bb5ed468aff92df276f8a6e39.diff

LOG: [RISCV] Add merge operand to RISCVISD::VRGATHER*_VL nodes.

Use it in place of VSELECT_VL+VRGATHER*_VL.

This simplifies the isel patterns.

Overall, I think trying to match select+op to create masked instructions
in isel doesn't scale. We either need to do it in DAG combine, pre-isel
peephole, or post-isel peephole. I don't yet know which is the right
answer, but for this case it seemed best to be able to request the
masked form directly from lowering.

Reviewed By: frasercrmck

Differential Revision: https://reviews.llvm.org/D128023

Added: 
    

Modified: 
    llvm/lib/Target/RISCV/RISCVISelLowering.cpp
    llvm/lib/Target/RISCV/RISCVISelLowering.h
    llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index b80dc298890c5..60e490fa2ce93 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -1957,8 +1957,8 @@ static SDValue matchSplatAsGather(SDValue SplatVal, MVT VT, const SDLoc &DL,
   SDValue Mask, VL;
   std::tie(Mask, VL) = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);
 
-  SDValue Gather = DAG.getNode(RISCVISD::VRGATHER_VX_VL, DL, ContainerVT, Vec,
-                               Idx, Mask, VL);
+  SDValue Gather = DAG.getNode(RISCVISD::VRGATHER_VX_VL, DL, ContainerVT,
+                               DAG.getUNDEF(ContainerVT), Vec, Idx, Mask, VL);
 
   if (!VT.isFixedLengthVector())
     return Gather;
@@ -2581,9 +2581,9 @@ static SDValue lowerVECTOR_SHUFFLE(SDValue Op, SelectionDAG &DAG,
 
       V1 = convertToScalableVector(ContainerVT, V1, DAG, Subtarget);
       assert(Lane < (int)NumElts && "Unexpected lane!");
-      SDValue Gather =
-          DAG.getNode(RISCVISD::VRGATHER_VX_VL, DL, ContainerVT, V1,
-                      DAG.getConstant(Lane, DL, XLenVT), TrueMask, VL);
+      SDValue Gather = DAG.getNode(
+          RISCVISD::VRGATHER_VX_VL, DL, ContainerVT, DAG.getUNDEF(ContainerVT),
+          V1, DAG.getConstant(Lane, DL, XLenVT), TrueMask, VL);
       return convertFromScalableVector(VT, Gather, DAG, Subtarget);
     }
   }
@@ -2793,16 +2793,17 @@ static SDValue lowerVECTOR_SHUFFLE(SDValue Op, SelectionDAG &DAG,
     // that's beneficial.
     if (LHSIndexCounts.size() == 1) {
       int SplatIndex = LHSIndexCounts.begin()->getFirst();
-      Gather =
-          DAG.getNode(GatherVXOpc, DL, ContainerVT, V1,
-                      DAG.getConstant(SplatIndex, DL, XLenVT), TrueMask, VL);
+      Gather = DAG.getNode(
+          GatherVXOpc, DL, ContainerVT, DAG.getUNDEF(ContainerVT), V1,
+          DAG.getConstant(SplatIndex, DL, XLenVT), TrueMask, VL);
     } else {
       SDValue LHSIndices = DAG.getBuildVector(IndexVT, DL, GatherIndicesLHS);
       LHSIndices =
           convertToScalableVector(IndexContainerVT, LHSIndices, DAG, Subtarget);
 
-      Gather = DAG.getNode(GatherVVOpc, DL, ContainerVT, V1, LHSIndices,
-                           TrueMask, VL);
+      Gather =
+          DAG.getNode(GatherVVOpc, DL, ContainerVT, DAG.getUNDEF(ContainerVT),
+                      V1, LHSIndices, TrueMask, VL);
     }
   }
 
@@ -2810,27 +2811,26 @@ static SDValue lowerVECTOR_SHUFFLE(SDValue Op, SelectionDAG &DAG,
   // additional vrgather.
   if (!V2.isUndef()) {
     V2 = convertToScalableVector(ContainerVT, V2, DAG, Subtarget);
+
+    MVT MaskContainerVT = ContainerVT.changeVectorElementType(MVT::i1);
+    SelectMask =
+        convertToScalableVector(MaskContainerVT, SelectMask, DAG, Subtarget);
+
     // If only one index is used, we can use a "splat" vrgather.
     // TODO: We can splat the most-common index and fix-up any stragglers, if
     // that's beneficial.
     if (RHSIndexCounts.size() == 1) {
       int SplatIndex = RHSIndexCounts.begin()->getFirst();
-      V2 = DAG.getNode(GatherVXOpc, DL, ContainerVT, V2,
-                       DAG.getConstant(SplatIndex, DL, XLenVT), TrueMask, VL);
+      Gather =
+          DAG.getNode(GatherVXOpc, DL, ContainerVT, Gather, V2,
+                      DAG.getConstant(SplatIndex, DL, XLenVT), SelectMask, VL);
     } else {
       SDValue RHSIndices = DAG.getBuildVector(IndexVT, DL, GatherIndicesRHS);
       RHSIndices =
           convertToScalableVector(IndexContainerVT, RHSIndices, DAG, Subtarget);
-      V2 = DAG.getNode(GatherVVOpc, DL, ContainerVT, V2, RHSIndices, TrueMask,
-                       VL);
+      Gather = DAG.getNode(GatherVVOpc, DL, ContainerVT, Gather, V2, RHSIndices,
+                           SelectMask, VL);
     }
-
-    MVT MaskContainerVT = ContainerVT.changeVectorElementType(MVT::i1);
-    SelectMask =
-        convertToScalableVector(MaskContainerVT, SelectMask, DAG, Subtarget);
-
-    Gather = DAG.getNode(RISCVISD::VSELECT_VL, DL, ContainerVT, SelectMask, V2,
-                         Gather, VL);
   }
 
   return convertFromScalableVector(VT, Gather, DAG, Subtarget);
@@ -5691,7 +5691,8 @@ SDValue RISCVTargetLowering::lowerVECTOR_REVERSE(SDValue Op,
   SDValue Indices =
       DAG.getNode(RISCVISD::SUB_VL, DL, IntVT, SplatVL, VID, Mask, VL);
 
-  return DAG.getNode(GatherOpc, DL, VecVT, Op.getOperand(0), Indices, Mask, VL);
+  return DAG.getNode(GatherOpc, DL, VecVT, DAG.getUNDEF(VecVT),
+                     Op.getOperand(0), Indices, Mask, VL);
 }
 
 SDValue RISCVTargetLowering::lowerVECTOR_SPLICE(SDValue Op,

diff  --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index 91fdd5e887482..672086da46d82 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -286,8 +286,8 @@ enum NodeType : unsigned {
   VMCLR_VL,
   VMSET_VL,
 
-  // Matches the semantics of vrgather.vx and vrgather.vv with an extra operand
-  // for VL.
+  // Matches the semantics of vrgather.vx and vrgather.vv with extra operands
+  // for passthru and VL. First operand is the passthru operand.
   VRGATHER_VX_VL,
   VRGATHER_VV_VL,
   VRGATHEREI16_VV_VL,

diff  --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
index 78c81557aea97..5b1ce03984a2d 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
@@ -143,30 +143,33 @@ def riscv_setcc_vl : SDNode<"RISCVISD::SETCC_VL",
                                                  SDTCisVT<5, XLenVT>]>>;
 
 def riscv_vrgather_vx_vl : SDNode<"RISCVISD::VRGATHER_VX_VL",
-                                  SDTypeProfile<1, 4, [SDTCisVec<0>,
+                                  SDTypeProfile<1, 5, [SDTCisVec<0>,
                                                        SDTCisSameAs<0, 1>,
-                                                       SDTCisVT<2, XLenVT>,
-                                                       SDTCVecEltisVT<3, i1>,
-                                                       SDTCisSameNumEltsAs<0, 3>,
-                                                       SDTCisVT<4, XLenVT>]>>;
+                                                       SDTCisSameAs<0, 2>,
+                                                       SDTCisVT<3, XLenVT>,
+                                                       SDTCVecEltisVT<4, i1>,
+                                                       SDTCisSameNumEltsAs<0, 4>,
+                                                       SDTCisVT<5, XLenVT>]>>;
 def riscv_vrgather_vv_vl : SDNode<"RISCVISD::VRGATHER_VV_VL",
-                                  SDTypeProfile<1, 4, [SDTCisVec<0>,
+                                  SDTypeProfile<1, 5, [SDTCisVec<0>,
                                                        SDTCisSameAs<0, 1>,
-                                                       SDTCisInt<2>,
-                                                       SDTCisSameNumEltsAs<0, 2>,
-                                                       SDTCisSameSizeAs<0, 2>,
-                                                       SDTCVecEltisVT<3, i1>,
+                                                       SDTCisSameAs<0, 2>,
+                                                       SDTCisInt<3>,
                                                        SDTCisSameNumEltsAs<0, 3>,
-                                                       SDTCisVT<4, XLenVT>]>>;
+                                                       SDTCisSameSizeAs<0, 3>,
+                                                       SDTCVecEltisVT<4, i1>,
+                                                       SDTCisSameNumEltsAs<0, 4>,
+                                                       SDTCisVT<5, XLenVT>]>>;
 def riscv_vrgatherei16_vv_vl : SDNode<"RISCVISD::VRGATHEREI16_VV_VL",
-                                      SDTypeProfile<1, 4, [SDTCisVec<0>,
+                                      SDTypeProfile<1, 5, [SDTCisVec<0>,
                                                            SDTCisSameAs<0, 1>,
-                                                           SDTCisInt<2>,
-                                                           SDTCVecEltisVT<2, i16>,
-                                                           SDTCisSameNumEltsAs<0, 2>,
-                                                           SDTCVecEltisVT<3, i1>,
+                                                           SDTCisSameAs<0, 2>,
+                                                           SDTCisInt<3>,
+                                                           SDTCVecEltisVT<3, i16>,
                                                            SDTCisSameNumEltsAs<0, 3>,
-                                                           SDTCisVT<4, XLenVT>]>>;
+                                                           SDTCVecEltisVT<4, i1>,
+                                                           SDTCisSameNumEltsAs<0, 4>,
+                                                           SDTCisVT<5, XLenVT>]>>;
 
 def SDT_RISCVSelect_VL  : SDTypeProfile<1, 4, [
   SDTCisVec<0>, SDTCisVec<1>, SDTCisSameNumEltsAs<0, 1>, SDTCVecEltisVT<1, i1>,
@@ -1835,43 +1838,40 @@ foreach vti = AllIntegerVectors in {
             (!cast<Instruction>("PseudoVMV_S_X_"#vti.LMul.MX)
                 vti.RegClass:$merge,
                 (vti.Scalar vti.ScalarRegClass:$rs1), GPR:$vl, vti.Log2SEW)>;
-  def : Pat<(vti.Vector (riscv_vrgather_vv_vl vti.RegClass:$rs2,
+  def : Pat<(vti.Vector (riscv_vrgather_vv_vl (vti.Vector srcvalue),
+                                              vti.RegClass:$rs2,
                                               (vti.Vector vti.RegClass:$rs1),
                                               (vti.Mask true_mask),
                                               VLOpFrag)),
             (!cast<Instruction>("PseudoVRGATHER_VV_"# vti.LMul.MX)
                  vti.RegClass:$rs2, vti.RegClass:$rs1, GPR:$vl, vti.Log2SEW)>;
-  def : Pat<(vti.Vector (riscv_vrgather_vx_vl vti.RegClass:$rs2, GPR:$rs1,
+  def : Pat<(vti.Vector (riscv_vrgather_vx_vl (vti.Vector srcvalue),
+                                              vti.RegClass:$rs2, GPR:$rs1,
                                               (vti.Mask true_mask),
                                               VLOpFrag)),
             (!cast<Instruction>("PseudoVRGATHER_VX_"# vti.LMul.MX)
                  vti.RegClass:$rs2, GPR:$rs1, GPR:$vl, vti.Log2SEW)>;
-  def : Pat<(vti.Vector (riscv_vrgather_vx_vl vti.RegClass:$rs2, uimm5:$imm,
+  def : Pat<(vti.Vector (riscv_vrgather_vx_vl (vti.Vector srcvalue),
+                                              vti.RegClass:$rs2, uimm5:$imm,
                                               (vti.Mask true_mask),
                                               VLOpFrag)),
             (!cast<Instruction>("PseudoVRGATHER_VI_"# vti.LMul.MX)
                  vti.RegClass:$rs2, uimm5:$imm, GPR:$vl, vti.Log2SEW)>;
 
-  def : Pat<(vti.Vector (riscv_vselect_vl (vti.Mask V0),
-                                          (riscv_vrgather_vv_vl
-                                            vti.RegClass:$rs2,
-                                            vti.RegClass:$rs1,
-                                            (vti.Mask true_mask),
-                                            VLOpFrag),
-                                          vti.RegClass:$merge,
-                                          VLOpFrag)),
+  def : Pat<(vti.Vector (riscv_vrgather_vv_vl vti.RegClass:$merge,
+                                              vti.RegClass:$rs2,
+                                              vti.RegClass:$rs1,
+                                              (vti.Mask V0),
+                                              VLOpFrag)),
             (!cast<Instruction>("PseudoVRGATHER_VV_"# vti.LMul.MX#"_MASK")
                  vti.RegClass:$merge, vti.RegClass:$rs2, vti.RegClass:$rs1,
                  (vti.Mask V0), GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
 
-  def : Pat<(vti.Vector (riscv_vselect_vl (vti.Mask V0),
-                                          (riscv_vrgather_vx_vl
-                                            vti.RegClass:$rs2,
-                                            uimm5:$imm,
-                                            (vti.Mask true_mask),
-                                            VLOpFrag),
-                                          vti.RegClass:$merge,
-                                          VLOpFrag)),
+  def : Pat<(vti.Vector (riscv_vrgather_vx_vl vti.RegClass:$merge,
+                                              vti.RegClass:$rs2,
+                                              uimm5:$imm,
+                                              (vti.Mask V0),
+                                              VLOpFrag)),
             (!cast<Instruction>("PseudoVRGATHER_VI_"# vti.LMul.MX#"_MASK")
                  vti.RegClass:$merge, vti.RegClass:$rs2, uimm5:$imm,
                  (vti.Mask V0), GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
@@ -1884,21 +1884,20 @@ foreach vti = AllIntegerVectors in {
     defvar emul_str = octuple_to_str<octuple_emul>.ret;
     defvar ivti = !cast<VTypeInfo>("VI16" # emul_str);
     defvar inst = "PseudoVRGATHEREI16_VV_" # vti.LMul.MX # "_" # emul_str;
-    def : Pat<(vti.Vector (riscv_vrgatherei16_vv_vl vti.RegClass:$rs2,
+    def : Pat<(vti.Vector (riscv_vrgatherei16_vv_vl (vti.Vector srcvalue),
+                                                    vti.RegClass:$rs2,
                                                     (ivti.Vector ivti.RegClass:$rs1),
                                                     (vti.Mask true_mask),
                                                     VLOpFrag)),
               (!cast<Instruction>(inst)
                    vti.RegClass:$rs2, ivti.RegClass:$rs1, GPR:$vl, vti.Log2SEW)>;
 
-    def : Pat<(vti.Vector (riscv_vselect_vl (vti.Mask V0),
-                                            (riscv_vrgatherei16_vv_vl
-                                              vti.RegClass:$rs2,
-                                              (ivti.Vector ivti.RegClass:$rs1),
-                                              (vti.Mask true_mask),
-                                              VLOpFrag),
-                                            vti.RegClass:$merge,
-                                            VLOpFrag)),
+    def : Pat<(vti.Vector
+               (riscv_vrgatherei16_vv_vl vti.RegClass:$merge,
+                                         vti.RegClass:$rs2,
+                                         (ivti.Vector ivti.RegClass:$rs1),
+                                         (vti.Mask V0),
+                                         VLOpFrag)),
               (!cast<Instruction>(inst#"_MASK")
                    vti.RegClass:$merge, vti.RegClass:$rs2, ivti.RegClass:$rs1,
                    (vti.Mask V0), GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
@@ -1923,43 +1922,42 @@ foreach vti = AllFloatVectors in {
                 vti.RegClass:$merge,
                 (vti.Scalar vti.ScalarRegClass:$rs1), GPR:$vl, vti.Log2SEW)>;
   defvar ivti = GetIntVTypeInfo<vti>.Vti;
-  def : Pat<(vti.Vector (riscv_vrgather_vv_vl vti.RegClass:$rs2,
+  def : Pat<(vti.Vector (riscv_vrgather_vv_vl (vti.Vector srcvalue),
+                                              vti.RegClass:$rs2,
                                               (ivti.Vector vti.RegClass:$rs1),
                                               (vti.Mask true_mask),
                                               VLOpFrag)),
             (!cast<Instruction>("PseudoVRGATHER_VV_"# vti.LMul.MX)
                  vti.RegClass:$rs2, vti.RegClass:$rs1, GPR:$vl, vti.Log2SEW)>;
-  def : Pat<(vti.Vector (riscv_vrgather_vx_vl vti.RegClass:$rs2, GPR:$rs1,
+  def : Pat<(vti.Vector (riscv_vrgather_vx_vl (vti.Vector srcvalue),
+                                              vti.RegClass:$rs2, GPR:$rs1,
                                               (vti.Mask true_mask),
                                               VLOpFrag)),
             (!cast<Instruction>("PseudoVRGATHER_VX_"# vti.LMul.MX)
                  vti.RegClass:$rs2, GPR:$rs1, GPR:$vl, vti.Log2SEW)>;
-  def : Pat<(vti.Vector (riscv_vrgather_vx_vl vti.RegClass:$rs2, uimm5:$imm,
+  def : Pat<(vti.Vector (riscv_vrgather_vx_vl (vti.Vector srcvalue),
+                                              vti.RegClass:$rs2, uimm5:$imm,
                                               (vti.Mask true_mask),
                                               VLOpFrag)),
             (!cast<Instruction>("PseudoVRGATHER_VI_"# vti.LMul.MX)
                  vti.RegClass:$rs2, uimm5:$imm, GPR:$vl, vti.Log2SEW)>;
 
-  def : Pat<(vti.Vector (riscv_vselect_vl (vti.Mask V0),
-                                          (riscv_vrgather_vv_vl
-                                            vti.RegClass:$rs2,
-                                            (ivti.Vector vti.RegClass:$rs1),
-                                            (vti.Mask true_mask),
-                                            VLOpFrag),
-                                          vti.RegClass:$merge,
-                                          VLOpFrag)),
+  def : Pat<(vti.Vector
+             (riscv_vrgather_vv_vl vti.RegClass:$merge,
+                                   vti.RegClass:$rs2,
+                                   (ivti.Vector vti.RegClass:$rs1),
+                                   (vti.Mask V0),
+                                   VLOpFrag)),
             (!cast<Instruction>("PseudoVRGATHER_VV_"# vti.LMul.MX#"_MASK")
                  vti.RegClass:$merge, vti.RegClass:$rs2, vti.RegClass:$rs1,
                  (vti.Mask V0), GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
 
-  def : Pat<(vti.Vector (riscv_vselect_vl (vti.Mask V0),
-                                          (riscv_vrgather_vx_vl
-                                            vti.RegClass:$rs2,
-                                            uimm5:$imm,
-                                            (vti.Mask true_mask),
-                                            VLOpFrag),
-                                          vti.RegClass:$merge,
-                                          VLOpFrag)),
+  def : Pat<(vti.Vector
+             (riscv_vrgather_vx_vl vti.RegClass:$merge,
+                                   vti.RegClass:$rs2,
+                                   uimm5:$imm,
+                                   (vti.Mask V0),
+                                   VLOpFrag)),
             (!cast<Instruction>("PseudoVRGATHER_VI_"# vti.LMul.MX#"_MASK")
                  vti.RegClass:$merge, vti.RegClass:$rs2, uimm5:$imm,
                  (vti.Mask V0), GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
@@ -1971,21 +1969,20 @@ foreach vti = AllFloatVectors in {
     defvar emul_str = octuple_to_str<octuple_emul>.ret;
     defvar ivti = !cast<VTypeInfo>("VI16" # emul_str);
     defvar inst = "PseudoVRGATHEREI16_VV_" # vti.LMul.MX # "_" # emul_str;
-    def : Pat<(vti.Vector (riscv_vrgatherei16_vv_vl vti.RegClass:$rs2,
+    def : Pat<(vti.Vector (riscv_vrgatherei16_vv_vl (vti.Vector srcvalue),
+                                                    vti.RegClass:$rs2,
                                                     (ivti.Vector ivti.RegClass:$rs1),
                                                     (vti.Mask true_mask),
                                                     VLOpFrag)),
               (!cast<Instruction>(inst)
                    vti.RegClass:$rs2, ivti.RegClass:$rs1, GPR:$vl, vti.Log2SEW)>;
 
-    def : Pat<(vti.Vector (riscv_vselect_vl (vti.Mask V0),
-                                            (riscv_vrgatherei16_vv_vl
-                                              vti.RegClass:$rs2,
-                                              (ivti.Vector ivti.RegClass:$rs1),
-                                              (vti.Mask true_mask),
-                                              VLOpFrag),
-                                            vti.RegClass:$merge,
-                                            VLOpFrag)),
+    def : Pat<(vti.Vector
+               (riscv_vrgatherei16_vv_vl vti.RegClass:$merge,
+                                         vti.RegClass:$rs2,
+                                         (ivti.Vector ivti.RegClass:$rs1),
+                                         (vti.Mask V0),
+                                         VLOpFrag)),
               (!cast<Instruction>(inst#"_MASK")
                    vti.RegClass:$merge, vti.RegClass:$rs2, ivti.RegClass:$rs1,
                    (vti.Mask V0), GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;


        


More information about the llvm-commits mailing list