[llvm] 4bd5bd4 - [RISCV] Convert VSLIDE(UP|DOWN) nodes to "VL" versions (NFC)

Fraser Cormack via llvm-commits llvm-commits at lists.llvm.org
Mon Feb 15 02:39:08 PST 2021


Author: Fraser Cormack
Date: 2021-02-15T10:32:56Z
New Revision: 4bd5bd40094c7b8b691cf394d813efc48d82acfd

URL: https://github.com/llvm/llvm-project/commit/4bd5bd40094c7b8b691cf394d813efc48d82acfd
DIFF: https://github.com/llvm/llvm-project/commit/4bd5bd40094c7b8b691cf394d813efc48d82acfd.diff

LOG: [RISCV] Convert VSLIDE(UP|DOWN) nodes to "VL" versions (NFC)

This patch converts the RISCV VSLIDEUP and VSLIDEDOWN custom nodes into
ones carrying additional mask and vector-length operands. This is
primarily so they can be used when lowering both scalable and
fixed-length vectors.

This also takes the opportunity to create some helper functions to deal
with the common task of getting the default (unmasked) VL operands.

Reviewed By: craig.topper, arcbbb

Differential Revision: https://reviews.llvm.org/D96505

Added: 
    

Modified: 
    llvm/lib/Target/RISCV/RISCVISelLowering.cpp
    llvm/lib/Target/RISCV/RISCVISelLowering.h
    llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
    llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index bb297efaf297..c0685e670a7e 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -837,6 +837,30 @@ static SDValue convertFromScalableVector(EVT VT, SDValue V, SelectionDAG &DAG,
   return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, V, Zero);
 }
 
+// Gets the two common "VL" operands: an all-ones mask and the vector length.
+// VecVT may be fixed-length or scalable; ContainerVT is the scalable type that
+// holds it. VL is VecVT's element count if fixed-length, else the X0 register.
+static std::pair<SDValue, SDValue>
+getDefaultVLOps(MVT VecVT, MVT ContainerVT, SDLoc DL, SelectionDAG &DAG,
+                const RISCVSubtarget &Subtarget) {
+  assert(ContainerVT.isScalableVector() && "Expecting scalable container type");
+  MVT XLenVT = Subtarget.getXLenVT();
+  SDValue VL = VecVT.isFixedLengthVector()
+                   ? DAG.getConstant(VecVT.getVectorNumElements(), DL, XLenVT)
+                   : DAG.getRegister(RISCV::X0, XLenVT);
+  MVT MaskVT = MVT::getVectorVT(MVT::i1, ContainerVT.getVectorElementCount());
+  SDValue Mask = DAG.getNode(RISCVISD::VMSET_VL, DL, MaskVT, VL);
+  return {Mask, VL};
+}
+
+// As getDefaultVLOps, but VecVT must be scalable and is its own container.
+static std::pair<SDValue, SDValue>
+getDefaultScalableVLOps(MVT VecVT, SDLoc DL, SelectionDAG &DAG,
+                        const RISCVSubtarget &Subtarget) {
+  assert(VecVT.isScalableVector() && "Expecting a scalable vector");
+  return getDefaultVLOps(VecVT, VecVT, DL, DAG, Subtarget);
+}
+
 static SDValue lowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG,
                                  const RISCVSubtarget &Subtarget) {
   MVT VT = Op.getSimpleValueType();
@@ -845,8 +869,8 @@ static SDValue lowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG,
   MVT ContainerVT = getContainerForFixedLengthVector(DAG, VT, Subtarget);
 
   SDLoc DL(Op);
-  SDValue VL =
-      DAG.getConstant(VT.getVectorNumElements(), DL, Subtarget.getXLenVT());
+  SDValue Mask, VL;
+  std::tie(Mask, VL) = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);
 
   if (VT.getVectorElementType() == MVT::i1) {
     if (ISD::isBuildVectorAllZeros(Op.getNode())) {
@@ -879,8 +903,6 @@ static SDValue lowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG,
                 Op.getConstantOperandVal(i) == i);
 
   if (IsVID) {
-    MVT MaskVT = MVT::getVectorVT(MVT::i1, ContainerVT.getVectorElementCount());
-    SDValue Mask = DAG.getNode(RISCVISD::VMSET_VL, DL, MaskVT, VL);
     SDValue VID = DAG.getNode(RISCVISD::VID_VL, DL, ContainerVT, Mask, VL);
     return convertFromScalableVector(VT, VID, DAG, Subtarget);
   }
@@ -903,11 +925,9 @@ static SDValue lowerVECTOR_SHUFFLE(SDValue Op, SelectionDAG &DAG,
       V1 = convertToScalableVector(ContainerVT, V1, DAG, Subtarget);
       assert(Lane < (int)VT.getVectorNumElements() && "Unexpected lane!");
 
+      SDValue Mask, VL;
+      std::tie(Mask, VL) = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);
       MVT XLenVT = Subtarget.getXLenVT();
-      SDValue VL = DAG.getConstant(VT.getVectorNumElements(), DL, XLenVT);
-      MVT MaskVT =
-          MVT::getVectorVT(MVT::i1, ContainerVT.getVectorElementCount());
-      SDValue Mask = DAG.getNode(RISCVISD::VMSET_VL, DL, MaskVT, VL);
       SDValue Gather =
           DAG.getNode(RISCVISD::VRGATHER_VX_VL, DL, ContainerVT, V1,
                       DAG.getConstant(Lane, DL, XLenVT), Mask, VL);
@@ -1768,7 +1788,7 @@ SDValue RISCVTargetLowering::lowerVectorMaskTrunc(SDValue Op,
 SDValue RISCVTargetLowering::lowerINSERT_VECTOR_ELT(SDValue Op,
                                                     SelectionDAG &DAG) const {
   SDLoc DL(Op);
-  EVT VecVT = Op.getValueType();
+  MVT VecVT = Op.getSimpleValueType();
   SDValue Vec = Op.getOperand(0);
   SDValue Val = Op.getOperand(1);
   SDValue Idx = Op.getOperand(2);
@@ -1780,13 +1800,16 @@ SDValue RISCVTargetLowering::lowerINSERT_VECTOR_ELT(SDValue Op,
   if (Subtarget.is64Bit() || VecVT.getVectorElementType() != MVT::i64) {
     if (isNullConstant(Idx))
       return Op;
-    SDValue Slidedown = DAG.getNode(RISCVISD::VSLIDEDOWN, DL, VecVT,
-                                    DAG.getUNDEF(VecVT), Vec, Idx);
+    SDValue Mask, VL;
+    std::tie(Mask, VL) = getDefaultScalableVLOps(VecVT, DL, DAG, Subtarget);
+    SDValue Slidedown = DAG.getNode(RISCVISD::VSLIDEDOWN_VL, DL, VecVT,
+                                    DAG.getUNDEF(VecVT), Vec, Idx, Mask, VL);
     SDValue InsertElt0 =
         DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, VecVT, Slidedown, Val,
                     DAG.getConstant(0, DL, Subtarget.getXLenVT()));
 
-    return DAG.getNode(RISCVISD::VSLIDEUP, DL, VecVT, Vec, InsertElt0, Idx);
+    return DAG.getNode(RISCVISD::VSLIDEUP_VL, DL, VecVT, Vec, InsertElt0, Idx,
+                       Mask, VL);
   }
 
   // Custom-legalize INSERT_VECTOR_ELT where XLEN<SEW, as the SEW element type
@@ -1803,9 +1826,8 @@ SDValue RISCVTargetLowering::lowerINSERT_VECTOR_ELT(SDValue Op,
   SDValue SplattedVal = DAG.getSplatVector(VecVT, DL, Val);
   SDValue SplattedIdx = DAG.getNode(RISCVISD::SPLAT_VECTOR_I64, DL, VecVT, Idx);
 
-  SDValue VL = DAG.getRegister(RISCV::X0, Subtarget.getXLenVT());
-  MVT MaskVT = MVT::getVectorVT(MVT::i1, VecVT.getVectorElementCount());
-  SDValue Mask = DAG.getNode(RISCVISD::VMSET_VL, DL, MaskVT, VL);
+  SDValue Mask, VL;
+  std::tie(Mask, VL) = getDefaultScalableVLOps(VecVT, DL, DAG, Subtarget);
   SDValue VID = DAG.getNode(RISCVISD::VID_VL, DL, VecVT, Mask, VL);
   auto SetCCVT =
       getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VecVT);
@@ -1824,13 +1846,15 @@ SDValue RISCVTargetLowering::lowerEXTRACT_VECTOR_ELT(SDValue Op,
   SDValue Idx = Op.getOperand(1);
   SDValue Vec = Op.getOperand(0);
   EVT EltVT = Op.getValueType();
-  EVT VecVT = Vec.getValueType();
+  MVT VecVT = Vec.getSimpleValueType();
   MVT XLenVT = Subtarget.getXLenVT();
 
   // If the index is 0, the vector is already in the right position.
   if (!isNullConstant(Idx)) {
-    Vec = DAG.getNode(RISCVISD::VSLIDEDOWN, DL, VecVT, DAG.getUNDEF(VecVT), Vec,
-                      Idx);
+    SDValue Mask, VL;
+    std::tie(Mask, VL) = getDefaultScalableVLOps(VecVT, DL, DAG, Subtarget);
+    Vec = DAG.getNode(RISCVISD::VSLIDEDOWN_VL, DL, VecVT, DAG.getUNDEF(VecVT),
+                      Vec, Idx, Mask, VL);
   }
 
   if (!EltVT.isInteger()) {
@@ -2146,10 +2170,8 @@ SDValue RISCVTargetLowering::lowerToScalableOp(SDValue Op, SelectionDAG &DAG,
   }
 
   SDLoc DL(Op);
-  SDValue VL =
-      DAG.getConstant(VT.getVectorNumElements(), DL, Subtarget.getXLenVT());
-  MVT MaskVT = MVT::getVectorVT(MVT::i1, ContainerVT.getVectorElementCount());
-  SDValue Mask = DAG.getNode(RISCVISD::VMSET_VL, DL, MaskVT, VL);
+  SDValue Mask, VL;
+  std::tie(Mask, VL) = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);
   Ops.push_back(Mask);
   Ops.push_back(VL);
 
@@ -2398,19 +2420,22 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N,
     SDLoc DL(N);
     SDValue Vec = N->getOperand(0);
     SDValue Idx = N->getOperand(1);
-    EVT VecVT = Vec.getValueType();
+    MVT VecVT = Vec.getSimpleValueType();
     assert(!Subtarget.is64Bit() && N->getValueType(0) == MVT::i64 &&
            VecVT.getVectorElementType() == MVT::i64 &&
            "Unexpected EXTRACT_VECTOR_ELT legalization");
 
     SDValue Slidedown = Vec;
+    MVT XLenVT = Subtarget.getXLenVT();
     // Unless the index is known to be 0, we must slide the vector down to get
     // the desired element into index 0.
-    if (!isNullConstant(Idx))
-      Slidedown = DAG.getNode(RISCVISD::VSLIDEDOWN, DL, VecVT,
-                              DAG.getUNDEF(VecVT), Vec, Idx);
+    if (!isNullConstant(Idx)) {
+      SDValue Mask, VL;
+      std::tie(Mask, VL) = getDefaultScalableVLOps(VecVT, DL, DAG, Subtarget);
+      Slidedown = DAG.getNode(RISCVISD::VSLIDEDOWN_VL, DL, VecVT,
+                              DAG.getUNDEF(VecVT), Vec, Idx, Mask, VL);
+    }
 
-    MVT XLenVT = Subtarget.getXLenVT();
     // Extract the lower XLEN bits of the correct vector element.
     SDValue EltLo = DAG.getNode(RISCVISD::VMV_X_S, DL, XLenVT, Slidedown, Idx);
 
@@ -4713,8 +4738,8 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
   NODE_NAME_CASE(TRUNCATE_VECTOR)
   NODE_NAME_CASE(VLEFF)
   NODE_NAME_CASE(VLEFF_MASK)
-  NODE_NAME_CASE(VSLIDEUP)
-  NODE_NAME_CASE(VSLIDEDOWN)
+  NODE_NAME_CASE(VSLIDEUP_VL)
+  NODE_NAME_CASE(VSLIDEDOWN_VL)
   NODE_NAME_CASE(VID_VL)
   NODE_NAME_CASE(VFNCVT_ROD)
   NODE_NAME_CASE(VECREDUCE_ADD)

diff  --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index 9356f35e5899..08b2da1802b3 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -109,10 +109,11 @@ enum NodeType : unsigned {
   VLEFF,
   VLEFF_MASK,
   // Matches the semantics of vslideup/vslidedown. The first operand is the
-  // pass-thru operand, the second is the source vector, and the third is the
-  // XLenVT index (either constant or non-constant).
-  VSLIDEUP,
-  VSLIDEDOWN,
+  // pass-thru operand, the second is the source vector, the third is the
+  // XLenVT index (either constant or non-constant), the fourth is the mask
+  // and the fifth the VL.
+  VSLIDEUP_VL,
+  VSLIDEDOWN_VL,
   // Matches the semantics of the vid.v instruction, with a mask and VL
   // operand.
   VID_VL,

diff  --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
index 942b4e2e223a..d116908f38ec 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
@@ -820,43 +820,3 @@ foreach vti = AllFloatVectors in {
                     (vti.Scalar vti.ScalarRegClass:$rs1),
                     vti.AVL, vti.SEW)>;
 }
-
-def SDTRVVSlide : SDTypeProfile<1, 3, [
-  SDTCisVec<0>, SDTCisSameAs<1, 0>, SDTCisSameAs<2, 0>, SDTCisVT<3, XLenVT>
-]>;
-
-def riscv_slideup : SDNode<"RISCVISD::VSLIDEUP", SDTRVVSlide, []>;
-def riscv_slidedown : SDNode<"RISCVISD::VSLIDEDOWN", SDTRVVSlide, []>;
-
-let Predicates = [HasStdExtV] in {
-
-foreach vti = !listconcat(AllIntegerVectors, AllFloatVectors) in {
-    def : Pat<(vti.Vector (riscv_slideup (vti.Vector vti.RegClass:$rs3),
-                                         (vti.Vector vti.RegClass:$rs1),
-                                         uimm5:$rs2)),
-              (!cast<Instruction>("PseudoVSLIDEUP_VI_"#vti.LMul.MX)
-                  vti.RegClass:$rs3, vti.RegClass:$rs1, uimm5:$rs2,
-                  vti.AVL, vti.SEW)>;
-
-    def : Pat<(vti.Vector (riscv_slideup (vti.Vector vti.RegClass:$rs3),
-                                         (vti.Vector vti.RegClass:$rs1),
-                                         GPR:$rs2)),
-              (!cast<Instruction>("PseudoVSLIDEUP_VX_"#vti.LMul.MX)
-                  vti.RegClass:$rs3, vti.RegClass:$rs1, GPR:$rs2,
-                  vti.AVL, vti.SEW)>;
-
-    def : Pat<(vti.Vector (riscv_slidedown (vti.Vector vti.RegClass:$rs3),
-                                           (vti.Vector vti.RegClass:$rs1),
-                                           uimm5:$rs2)),
-              (!cast<Instruction>("PseudoVSLIDEDOWN_VI_"#vti.LMul.MX)
-                  vti.RegClass:$rs3, vti.RegClass:$rs1, uimm5:$rs2,
-                  vti.AVL, vti.SEW)>;
-
-    def : Pat<(vti.Vector (riscv_slidedown (vti.Vector vti.RegClass:$rs3),
-                                           (vti.Vector vti.RegClass:$rs1),
-                                           GPR:$rs2)),
-              (!cast<Instruction>("PseudoVSLIDEDOWN_VX_"#vti.LMul.MX)
-                  vti.RegClass:$rs3, vti.RegClass:$rs1, GPR:$rs2,
-                  vti.AVL, vti.SEW)>;
-}
-} // Predicates = [HasStdExtV]

diff  --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
index 321a80a92389..fad5c89cf7df 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
@@ -543,6 +543,14 @@ def riscv_vid_vl : SDNode<"RISCVISD::VID_VL", SDTypeProfile<1, 2,
                           [SDTCisVec<0>, SDTCisVec<1>, SDTCVecEltisVT<1, i1>,
                            SDTCisSameNumEltsAs<0, 1>, SDTCisVT<2, XLenVT>]>, []>;
 
+def SDTRVVSlide : SDTypeProfile<1, 5, [
+  SDTCisVec<0>, SDTCisSameAs<1, 0>, SDTCisSameAs<2, 0>, SDTCisVT<3, XLenVT>,
+  SDTCVecEltisVT<4, i1>, SDTCisSameNumEltsAs<0, 4>, SDTCisVT<5, XLenVT>
+]>;
+
+def riscv_slideup_vl   : SDNode<"RISCVISD::VSLIDEUP_VL", SDTRVVSlide, []>;
+def riscv_slidedown_vl : SDNode<"RISCVISD::VSLIDEDOWN_VL", SDTRVVSlide, []>;
+
 let Predicates = [HasStdExtV] in {
 
 foreach vti = AllIntegerVectors in
@@ -550,4 +558,38 @@ foreach vti = AllIntegerVectors in
                                       (XLenVT (VLOp GPR:$vl)))),
             (!cast<Instruction>("PseudoVID_V_"#vti.LMul.MX) GPR:$vl, vti.SEW)>;
 
+foreach vti = !listconcat(AllIntegerVectors, AllFloatVectors) in {
+  def : Pat<(vti.Vector (riscv_slideup_vl (vti.Vector vti.RegClass:$rs3),
+                                          (vti.Vector vti.RegClass:$rs1),
+                                          uimm5:$rs2, (vti.Mask true_mask),
+                                          (XLenVT (VLOp GPR:$vl)))),
+            (!cast<Instruction>("PseudoVSLIDEUP_VI_"#vti.LMul.MX)
+                vti.RegClass:$rs3, vti.RegClass:$rs1, uimm5:$rs2,
+                GPR:$vl, vti.SEW)>;
+
+  def : Pat<(vti.Vector (riscv_slideup_vl (vti.Vector vti.RegClass:$rs3),
+                                          (vti.Vector vti.RegClass:$rs1),
+                                          GPR:$rs2, (vti.Mask true_mask),
+                                          (XLenVT (VLOp GPR:$vl)))),
+            (!cast<Instruction>("PseudoVSLIDEUP_VX_"#vti.LMul.MX)
+                vti.RegClass:$rs3, vti.RegClass:$rs1, GPR:$rs2,
+                GPR:$vl, vti.SEW)>;
+
+  def : Pat<(vti.Vector (riscv_slidedown_vl (vti.Vector vti.RegClass:$rs3),
+                                            (vti.Vector vti.RegClass:$rs1),
+                                            uimm5:$rs2, (vti.Mask true_mask),
+                                            (XLenVT (VLOp GPR:$vl)))),
+            (!cast<Instruction>("PseudoVSLIDEDOWN_VI_"#vti.LMul.MX)
+                vti.RegClass:$rs3, vti.RegClass:$rs1, uimm5:$rs2,
+                GPR:$vl, vti.SEW)>;
+
+  def : Pat<(vti.Vector (riscv_slidedown_vl (vti.Vector vti.RegClass:$rs3),
+                                            (vti.Vector vti.RegClass:$rs1),
+                                            GPR:$rs2, (vti.Mask true_mask),
+                                            (XLenVT (VLOp GPR:$vl)))),
+            (!cast<Instruction>("PseudoVSLIDEDOWN_VX_"#vti.LMul.MX)
+                vti.RegClass:$rs3, vti.RegClass:$rs1, GPR:$rs2,
+                GPR:$vl, vti.SEW)>;
+}
+
 } // Predicates = [HasStdExtV]


        


More information about the llvm-commits mailing list