[llvm] 5d6d649 - [RISCV][NFC] Simplify lowerVPOp.

Jianjian GUAN via llvm-commits llvm-commits at lists.llvm.org
Thu Jul 27 01:42:29 PDT 2023


Author: Jianjian GUAN
Date: 2023-07-27T16:42:20+08:00
New Revision: 5d6d6493ff74139c63f3d0e60ed29633217690dc

URL: https://github.com/llvm/llvm-project/commit/5d6d6493ff74139c63f3d0e60ed29633217690dc
DIFF: https://github.com/llvm/llvm-project/commit/5d6d6493ff74139c63f3d0e60ed29633217690dc.diff

LOG: [RISCV][NFC] Simplify lowerVPOp.

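This patch is similar to https://reviews.llvm.org/D153948: instead of passing the RISC-V VL opcode (and whether it has a merge operand) explicitly at each call site, the VP lowering functions now derive them from the ISD node via the helper functions getRISCVVLOp and hasMergeOp.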

Reviewed By: craig.topper

Differential Revision: https://reviews.llvm.org/D154411

Added: 
    

Modified: 
    llvm/lib/Target/RISCV/RISCVISelLowering.cpp
    llvm/lib/Target/RISCV/RISCVISelLowering.h

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 34eee88d398c40..abfd56b8400952 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -4664,10 +4664,13 @@ static unsigned getRISCVVLOp(SDValue Op) {
 #define OP_CASE(NODE)                                                          \
   case ISD::NODE:                                                              \
     return RISCVISD::NODE##_VL;
+#define VP_CASE(NODE)                                                          \
+  case ISD::VP_##NODE:                                                         \
+    return RISCVISD::NODE##_VL;
+  // clang-format off
   switch (Op.getOpcode()) {
   default:
     llvm_unreachable("don't have RISC-V specified VL op for this SDNode");
-    // clang-format off
   OP_CASE(ADD)
   OP_CASE(SUB)
   OP_CASE(MUL)
@@ -4702,25 +4705,81 @@ static unsigned getRISCVVLOp(SDValue Op) {
   OP_CASE(STRICT_FMUL)
   OP_CASE(STRICT_FDIV)
   OP_CASE(STRICT_FSQRT)
-    // clang-format on
-#undef OP_CASE
+  VP_CASE(ADD)        // VP_ADD
+  VP_CASE(SUB)        // VP_SUB
+  VP_CASE(MUL)        // VP_MUL
+  VP_CASE(SDIV)       // VP_SDIV
+  VP_CASE(SREM)       // VP_SREM
+  VP_CASE(UDIV)       // VP_UDIV
+  VP_CASE(UREM)       // VP_UREM
+  VP_CASE(SHL)        // VP_SHL
+  VP_CASE(FADD)       // VP_FADD
+  VP_CASE(FSUB)       // VP_FSUB
+  VP_CASE(FMUL)       // VP_FMUL
+  VP_CASE(FDIV)       // VP_FDIV
+  VP_CASE(FNEG)       // VP_FNEG
+  VP_CASE(FABS)       // VP_FABS
+  VP_CASE(SMIN)       // VP_SMIN
+  VP_CASE(SMAX)       // VP_SMAX
+  VP_CASE(UMIN)       // VP_UMIN
+  VP_CASE(UMAX)       // VP_UMAX
+  VP_CASE(FMINNUM)    // VP_FMINNUM
+  VP_CASE(FMAXNUM)    // VP_FMAXNUM
+  VP_CASE(FCOPYSIGN)  // VP_FCOPYSIGN
+  VP_CASE(SETCC)      // VP_SETCC
+  VP_CASE(SINT_TO_FP) // VP_SINT_TO_FP
+  VP_CASE(UINT_TO_FP) // VP_UINT_TO_FP
+  VP_CASE(BITREVERSE) // VP_BITREVERSE
+  VP_CASE(BSWAP)      // VP_BSWAP
+  VP_CASE(CTLZ)       // VP_CTLZ
+  VP_CASE(CTTZ)       // VP_CTTZ
+  VP_CASE(CTPOP)      // VP_CTPOP
+  case ISD::VP_CTLZ_ZERO_UNDEF:
+    return RISCVISD::CTLZ_VL;
+  case ISD::VP_CTTZ_ZERO_UNDEF:
+    return RISCVISD::CTTZ_VL;
   case ISD::FMA:
+  case ISD::VP_FMA:
     return RISCVISD::VFMADD_VL;
   case ISD::STRICT_FMA:
     return RISCVISD::STRICT_VFMADD_VL;
   case ISD::AND:
+  case ISD::VP_AND:
     if (Op.getSimpleValueType().getVectorElementType() == MVT::i1)
       return RISCVISD::VMAND_VL;
     return RISCVISD::AND_VL;
   case ISD::OR:
+  case ISD::VP_OR:
     if (Op.getSimpleValueType().getVectorElementType() == MVT::i1)
       return RISCVISD::VMOR_VL;
     return RISCVISD::OR_VL;
   case ISD::XOR:
+  case ISD::VP_XOR:
     if (Op.getSimpleValueType().getVectorElementType() == MVT::i1)
       return RISCVISD::VMXOR_VL;
     return RISCVISD::XOR_VL;
+  case ISD::VP_SELECT:
+    return RISCVISD::VSELECT_VL;
+  case ISD::VP_MERGE:
+    return RISCVISD::VP_MERGE_VL;
+  case ISD::VP_ASHR:
+    return RISCVISD::SRA_VL;
+  case ISD::VP_LSHR:
+    return RISCVISD::SRL_VL;
+  case ISD::VP_SQRT:
+    return RISCVISD::FSQRT_VL;
+  case ISD::VP_SIGN_EXTEND:
+    return RISCVISD::VSEXT_VL;
+  case ISD::VP_ZERO_EXTEND:
+    return RISCVISD::VZEXT_VL;
+  case ISD::VP_FP_TO_SINT:
+    return RISCVISD::VFCVT_RTZ_X_F_VL;
+  case ISD::VP_FP_TO_UINT:
+    return RISCVISD::VFCVT_RTZ_XU_F_VL;
   }
+  // clang-format on
+#undef OP_CASE
+#undef VP_CASE
 }
 
 /// Return true if a RISC-V target specified op has a merge operand.
@@ -4739,6 +4798,8 @@ static bool hasMergeOp(unsigned Opcode) {
     return true;
   if (Opcode >= RISCVISD::VWMUL_VL && Opcode <= RISCVISD::VFWSUB_W_VL)
     return true;
+  if (Opcode == RISCVISD::SETCC_VL)
+    return true;
   if (Opcode >= RISCVISD::STRICT_FADD_VL && Opcode <= RISCVISD::STRICT_FDIV_VL)
     return true;
   return false;
@@ -5476,106 +5537,72 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
   case ISD::EH_DWARF_CFA:
     return lowerEH_DWARF_CFA(Op, DAG);
   case ISD::VP_SELECT:
-    return lowerVPOp(Op, DAG, RISCVISD::VSELECT_VL);
   case ISD::VP_MERGE:
-    return lowerVPOp(Op, DAG, RISCVISD::VP_MERGE_VL);
   case ISD::VP_ADD:
-    return lowerVPOp(Op, DAG, RISCVISD::ADD_VL, /*HasMergeOp*/ true);
   case ISD::VP_SUB:
-    return lowerVPOp(Op, DAG, RISCVISD::SUB_VL, /*HasMergeOp*/ true);
   case ISD::VP_MUL:
-    return lowerVPOp(Op, DAG, RISCVISD::MUL_VL, /*HasMergeOp*/ true);
   case ISD::VP_SDIV:
-    return lowerVPOp(Op, DAG, RISCVISD::SDIV_VL, /*HasMergeOp*/ true);
   case ISD::VP_UDIV:
-    return lowerVPOp(Op, DAG, RISCVISD::UDIV_VL, /*HasMergeOp*/ true);
   case ISD::VP_SREM:
-    return lowerVPOp(Op, DAG, RISCVISD::SREM_VL, /*HasMergeOp*/ true);
   case ISD::VP_UREM:
-    return lowerVPOp(Op, DAG, RISCVISD::UREM_VL, /*HasMergeOp*/ true);
+    return lowerVPOp(Op, DAG);
   case ISD::VP_AND:
-    return lowerLogicVPOp(Op, DAG, RISCVISD::VMAND_VL, RISCVISD::AND_VL);
   case ISD::VP_OR:
-    return lowerLogicVPOp(Op, DAG, RISCVISD::VMOR_VL, RISCVISD::OR_VL);
   case ISD::VP_XOR:
-    return lowerLogicVPOp(Op, DAG, RISCVISD::VMXOR_VL, RISCVISD::XOR_VL);
+    return lowerLogicVPOp(Op, DAG);
   case ISD::VP_ASHR:
-    return lowerVPOp(Op, DAG, RISCVISD::SRA_VL, /*HasMergeOp*/ true);
   case ISD::VP_LSHR:
-    return lowerVPOp(Op, DAG, RISCVISD::SRL_VL, /*HasMergeOp*/ true);
   case ISD::VP_SHL:
-    return lowerVPOp(Op, DAG, RISCVISD::SHL_VL, /*HasMergeOp*/ true);
   case ISD::VP_FADD:
-    return lowerVPOp(Op, DAG, RISCVISD::FADD_VL, /*HasMergeOp*/ true);
   case ISD::VP_FSUB:
-    return lowerVPOp(Op, DAG, RISCVISD::FSUB_VL, /*HasMergeOp*/ true);
   case ISD::VP_FMUL:
-    return lowerVPOp(Op, DAG, RISCVISD::FMUL_VL, /*HasMergeOp*/ true);
   case ISD::VP_FDIV:
-    return lowerVPOp(Op, DAG, RISCVISD::FDIV_VL, /*HasMergeOp*/ true);
   case ISD::VP_FNEG:
-    return lowerVPOp(Op, DAG, RISCVISD::FNEG_VL);
   case ISD::VP_FABS:
-    return lowerVPOp(Op, DAG, RISCVISD::FABS_VL);
   case ISD::VP_SQRT:
-    return lowerVPOp(Op, DAG, RISCVISD::FSQRT_VL);
   case ISD::VP_FMA:
-    return lowerVPOp(Op, DAG, RISCVISD::VFMADD_VL);
   case ISD::VP_FMINNUM:
-    return lowerVPOp(Op, DAG, RISCVISD::FMINNUM_VL, /*HasMergeOp*/ true);
   case ISD::VP_FMAXNUM:
-    return lowerVPOp(Op, DAG, RISCVISD::FMAXNUM_VL, /*HasMergeOp*/ true);
   case ISD::VP_FCOPYSIGN:
-    return lowerVPOp(Op, DAG, RISCVISD::FCOPYSIGN_VL, /*HasMergeOp*/ true);
+    return lowerVPOp(Op, DAG);
   case ISD::VP_SIGN_EXTEND:
   case ISD::VP_ZERO_EXTEND:
     if (Op.getOperand(0).getSimpleValueType().getVectorElementType() == MVT::i1)
       return lowerVPExtMaskOp(Op, DAG);
-    return lowerVPOp(Op, DAG,
-                     Op.getOpcode() == ISD::VP_SIGN_EXTEND
-                         ? RISCVISD::VSEXT_VL
-                         : RISCVISD::VZEXT_VL);
+    return lowerVPOp(Op, DAG);
   case ISD::VP_TRUNCATE:
     return lowerVectorTruncLike(Op, DAG);
   case ISD::VP_FP_EXTEND:
   case ISD::VP_FP_ROUND:
     return lowerVectorFPExtendOrRoundLike(Op, DAG);
   case ISD::VP_FP_TO_SINT:
-    return lowerVPFPIntConvOp(Op, DAG, RISCVISD::VFCVT_RTZ_X_F_VL);
   case ISD::VP_FP_TO_UINT:
-    return lowerVPFPIntConvOp(Op, DAG, RISCVISD::VFCVT_RTZ_XU_F_VL);
   case ISD::VP_SINT_TO_FP:
-    return lowerVPFPIntConvOp(Op, DAG, RISCVISD::SINT_TO_FP_VL);
   case ISD::VP_UINT_TO_FP:
-    return lowerVPFPIntConvOp(Op, DAG, RISCVISD::UINT_TO_FP_VL);
+    return lowerVPFPIntConvOp(Op, DAG);
   case ISD::VP_SETCC:
     if (Op.getOperand(0).getSimpleValueType().getVectorElementType() == MVT::i1)
       return lowerVPSetCCMaskOp(Op, DAG);
-    return lowerVPOp(Op, DAG, RISCVISD::SETCC_VL, /*HasMergeOp*/ true);
+    [[fallthrough]];
   case ISD::VP_SMIN:
-    return lowerVPOp(Op, DAG, RISCVISD::SMIN_VL, /*HasMergeOp*/ true);
   case ISD::VP_SMAX:
-    return lowerVPOp(Op, DAG, RISCVISD::SMAX_VL, /*HasMergeOp*/ true);
   case ISD::VP_UMIN:
-    return lowerVPOp(Op, DAG, RISCVISD::UMIN_VL, /*HasMergeOp*/ true);
   case ISD::VP_UMAX:
-    return lowerVPOp(Op, DAG, RISCVISD::UMAX_VL, /*HasMergeOp*/ true);
   case ISD::VP_BITREVERSE:
-    return lowerVPOp(Op, DAG, RISCVISD::BITREVERSE_VL, /*HasMergeOp*/ true);
   case ISD::VP_BSWAP:
-    return lowerVPOp(Op, DAG, RISCVISD::BSWAP_VL, /*HasMergeOp*/ true);
+    return lowerVPOp(Op, DAG);
   case ISD::VP_CTLZ:
   case ISD::VP_CTLZ_ZERO_UNDEF:
     if (Subtarget.hasStdExtZvbb())
-      return lowerVPOp(Op, DAG, RISCVISD::CTLZ_VL, /*HasMergeOp*/ true);
+      return lowerVPOp(Op, DAG);
     return lowerCTLZ_CTTZ_ZERO_UNDEF(Op, DAG);
   case ISD::VP_CTTZ:
   case ISD::VP_CTTZ_ZERO_UNDEF:
     if (Subtarget.hasStdExtZvbb())
-      return lowerVPOp(Op, DAG, RISCVISD::CTTZ_VL, /*HasMergeOp*/ true);
+      return lowerVPOp(Op, DAG);
     return lowerCTLZ_CTTZ_ZERO_UNDEF(Op, DAG);
   case ISD::VP_CTPOP:
-    return lowerVPOp(Op, DAG, RISCVISD::CTPOP_VL, /*HasMergeOp*/ true);
+    return lowerVPOp(Op, DAG);
   case ISD::EXPERIMENTAL_VP_STRIDED_LOAD:
     return lowerVPStridedLoad(Op, DAG);
   case ISD::EXPERIMENTAL_VP_STRIDED_STORE:
@@ -8827,9 +8854,10 @@ SDValue RISCVTargetLowering::lowerToScalableOp(SDValue Op,
 // * The EVL operand is promoted from i32 to i64 on RV64.
 // * Fixed-length vectors are converted to their scalable-vector container
 //   types.
-SDValue RISCVTargetLowering::lowerVPOp(SDValue Op, SelectionDAG &DAG,
-                                       unsigned RISCVISDOpc,
-                                       bool HasMergeOp) const {
+SDValue RISCVTargetLowering::lowerVPOp(SDValue Op, SelectionDAG &DAG) const {
+  unsigned RISCVISDOpc = getRISCVVLOp(Op);
+  bool HasMergeOp = hasMergeOp(RISCVISDOpc);
+
   SDLoc DL(Op);
   MVT VT = Op.getSimpleValueType();
   SmallVector<SDValue, 4> Ops;
@@ -8978,13 +9006,14 @@ SDValue RISCVTargetLowering::lowerVPSetCCMaskOp(SDValue Op,
 }
 
 // Lower Floating-Point/Integer Type-Convert VP SDNodes
-SDValue RISCVTargetLowering::lowerVPFPIntConvOp(SDValue Op, SelectionDAG &DAG,
-                                                unsigned RISCVISDOpc) const {
+SDValue RISCVTargetLowering::lowerVPFPIntConvOp(SDValue Op,
+                                                SelectionDAG &DAG) const {
   SDLoc DL(Op);
 
   SDValue Src = Op.getOperand(0);
   SDValue Mask = Op.getOperand(1);
   SDValue VL = Op.getOperand(2);
+  unsigned RISCVISDOpc = getRISCVVLOp(Op);
 
   MVT DstVT = Op.getSimpleValueType();
   MVT SrcVT = Src.getSimpleValueType();
@@ -9110,12 +9139,11 @@ SDValue RISCVTargetLowering::lowerVPFPIntConvOp(SDValue Op, SelectionDAG &DAG,
   return convertFromScalableVector(VT, Result, DAG, Subtarget);
 }
 
-SDValue RISCVTargetLowering::lowerLogicVPOp(SDValue Op, SelectionDAG &DAG,
-                                            unsigned MaskOpc,
-                                            unsigned VecOpc) const {
+SDValue RISCVTargetLowering::lowerLogicVPOp(SDValue Op,
+                                            SelectionDAG &DAG) const {
   MVT VT = Op.getSimpleValueType();
   if (VT.getVectorElementType() != MVT::i1)
-    return lowerVPOp(Op, DAG, VecOpc, true);
+    return lowerVPOp(Op, DAG);
 
   // It is safe to drop mask parameter as masked-off elements are undef.
   SDValue Op1 = Op->getOperand(0);
@@ -9131,7 +9159,7 @@ SDValue RISCVTargetLowering::lowerLogicVPOp(SDValue Op, SelectionDAG &DAG,
   }
 
   SDLoc DL(Op);
-  SDValue Val = DAG.getNode(MaskOpc, DL, ContainerVT, Op1, Op2, VL);
+  SDValue Val = DAG.getNode(getRISCVVLOp(Op), DL, ContainerVT, Op1, Op2, VL);
   if (!IsFixed)
     return Val;
   return convertFromScalableVector(VT, Val, DAG, Subtarget);

diff  --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index 1809e1c57c9210..26475bf424472d 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -878,14 +878,11 @@ class RISCVTargetLowering : public TargetLowering {
                                             SelectionDAG &DAG) const;
   SDValue lowerToScalableOp(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerIS_FPCLASS(SDValue Op, SelectionDAG &DAG) const;
-  SDValue lowerVPOp(SDValue Op, SelectionDAG &DAG, unsigned RISCVISDOpc,
-                    bool HasMergeOp = false) const;
-  SDValue lowerLogicVPOp(SDValue Op, SelectionDAG &DAG, unsigned MaskOpc,
-                         unsigned VecOpc) const;
+  SDValue lowerVPOp(SDValue Op, SelectionDAG &DAG) const;
+  SDValue lowerLogicVPOp(SDValue Op, SelectionDAG &DAG) const;
   SDValue lowerVPExtMaskOp(SDValue Op, SelectionDAG &DAG) const;
   SDValue lowerVPSetCCMaskOp(SDValue Op, SelectionDAG &DAG) const;
-  SDValue lowerVPFPIntConvOp(SDValue Op, SelectionDAG &DAG,
-                             unsigned RISCVISDOpc) const;
+  SDValue lowerVPFPIntConvOp(SDValue Op, SelectionDAG &DAG) const;
   SDValue lowerVPStridedLoad(SDValue Op, SelectionDAG &DAG) const;
   SDValue lowerVPStridedStore(SDValue Op, SelectionDAG &DAG) const;
   SDValue lowerFixedLengthVectorExtendToRVV(SDValue Op, SelectionDAG &DAG,


        


More information about the llvm-commits mailing list