[llvm] 1d8db2f - [RISCV][NFC] Refactor lowerToScalableOp.

Jianjian GUAN via llvm-commits llvm-commits at lists.llvm.org
Mon Jul 3 23:09:32 PDT 2023


Author: Jianjian GUAN
Date: 2023-07-04T14:09:24+08:00
New Revision: 1d8db2fab30369594c72ac33357eca8baeb71199

URL: https://github.com/llvm/llvm-project/commit/1d8db2fab30369594c72ac33357eca8baeb71199
DIFF: https://github.com/llvm/llvm-project/commit/1d8db2fab30369594c72ac33357eca8baeb71199.diff

LOG: [RISCV][NFC] Refactor lowerToScalableOp.

Refactor lowerToScalableOp to combine switch case code.

Reviewed By: frasercrmck

Differential Revision: https://reviews.llvm.org/D153948

Added: 
    

Modified: 
    llvm/lib/Target/RISCV/RISCVISelLowering.cpp
    llvm/lib/Target/RISCV/RISCVISelLowering.h

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index be824d7a4f2502..506b1c52529ee0 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -4564,6 +4564,105 @@ SDValue RISCVTargetLowering::LowerIS_FPCLASS(SDValue Op,
                       ISD::CondCode::SETNE);
 }
 
+/// Get the RISC-V target-specific VL opcode for a given SDNode.
+static unsigned getRISCVVLOp(SDValue Op) {
+#define OP_CASE(NODE)                                                          \
+  case ISD::NODE:                                                              \
+    return RISCVISD::NODE##_VL;
+  switch (Op.getOpcode()) {
+  default:
+    llvm_unreachable("don't have RISC-V specified VL op for this SDNode");
+    // clang-format off
+  OP_CASE(ADD)
+  OP_CASE(SUB)
+  OP_CASE(MUL)
+  OP_CASE(MULHS)
+  OP_CASE(MULHU)
+  OP_CASE(SDIV)
+  OP_CASE(SREM)
+  OP_CASE(UDIV)
+  OP_CASE(UREM)
+  OP_CASE(SHL)
+  OP_CASE(SRA)
+  OP_CASE(SRL)
+  OP_CASE(SADDSAT)
+  OP_CASE(UADDSAT)
+  OP_CASE(SSUBSAT)
+  OP_CASE(USUBSAT)
+  OP_CASE(FADD)
+  OP_CASE(FSUB)
+  OP_CASE(FMUL)
+  OP_CASE(FDIV)
+  OP_CASE(FNEG)
+  OP_CASE(FABS)
+  OP_CASE(FSQRT)
+  OP_CASE(SMIN)
+  OP_CASE(SMAX)
+  OP_CASE(UMIN)
+  OP_CASE(UMAX)
+  OP_CASE(FMINNUM)
+  OP_CASE(FMAXNUM)
+  OP_CASE(STRICT_FADD)
+  OP_CASE(STRICT_FSUB)
+  OP_CASE(STRICT_FMUL)
+  OP_CASE(STRICT_FDIV)
+  OP_CASE(STRICT_FSQRT)
+    // clang-format on
+#undef OP_CASE
+  case ISD::FMA:
+    return RISCVISD::VFMADD_VL;
+  case ISD::STRICT_FMA:
+    return RISCVISD::STRICT_VFMADD_VL;
+  case ISD::AND:
+    if (Op.getSimpleValueType().getVectorElementType() == MVT::i1)
+      return RISCVISD::VMAND_VL;
+    return RISCVISD::AND_VL;
+  case ISD::OR:
+    if (Op.getSimpleValueType().getVectorElementType() == MVT::i1)
+      return RISCVISD::VMOR_VL;
+    return RISCVISD::OR_VL;
+  case ISD::XOR:
+    if (Op.getSimpleValueType().getVectorElementType() == MVT::i1)
+      return RISCVISD::VMXOR_VL;
+    return RISCVISD::XOR_VL;
+  }
+}
+
+/// Return true if a RISC-V target-specific op has a merge operand.
+static bool hasMergeOp(unsigned Opcode) {
+  assert(Opcode > RISCVISD::FIRST_NUMBER &&
+         Opcode <= RISCVISD::STRICT_VFROUND_NOEXCEPT_VL &&
+         "not a RISC-V target specific op");
+  assert(RISCVISD::STRICT_VFROUND_NOEXCEPT_VL - RISCVISD::FIRST_NUMBER == 421 &&
+         "adding target specific op should update this function");
+  if (Opcode >= RISCVISD::ADD_VL && Opcode <= RISCVISD::FMAXNUM_VL)
+    return true;
+  if (Opcode == RISCVISD::FCOPYSIGN_VL)
+    return true;
+  if (Opcode >= RISCVISD::VWMUL_VL && Opcode <= RISCVISD::VFWSUB_W_VL)
+    return true;
+  if (Opcode >= RISCVISD::STRICT_FADD_VL && Opcode <= RISCVISD::STRICT_FDIV_VL)
+    return true;
+  return false;
+}
+
+/// Return true if a RISC-V target-specific op has a mask operand.
+static bool hasMaskOp(unsigned Opcode) {
+  assert(Opcode > RISCVISD::FIRST_NUMBER &&
+         Opcode <= RISCVISD::STRICT_VFROUND_NOEXCEPT_VL &&
+         "not a RISC-V target specific op");
+  assert(RISCVISD::STRICT_VFROUND_NOEXCEPT_VL - RISCVISD::FIRST_NUMBER == 421 &&
+         "adding target specific op should update this function");
+  if (Opcode >= RISCVISD::TRUNCATE_VECTOR_VL && Opcode <= RISCVISD::SETCC_VL)
+    return true;
+  if (Opcode >= RISCVISD::VRGATHER_VX_VL && Opcode <= RISCVISD::VFIRST_VL)
+    return true;
+  if (Opcode >= RISCVISD::STRICT_FADD_VL &&
+      Opcode <= RISCVISD::STRICT_VFROUND_NOEXCEPT_VL)
+    return true;
+  return false;
+}
+
 SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
                                             SelectionDAG &DAG) const {
   switch (Op.getOpcode()) {
@@ -5154,83 +5253,46 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
     return lowerFixedLengthVectorSetccToRVV(Op, DAG);
   }
   case ISD::ADD:
-    return lowerToScalableOp(Op, DAG, RISCVISD::ADD_VL, /*HasMergeOp*/ true);
   case ISD::SUB:
-    return lowerToScalableOp(Op, DAG, RISCVISD::SUB_VL, /*HasMergeOp*/ true);
   case ISD::MUL:
-    return lowerToScalableOp(Op, DAG, RISCVISD::MUL_VL, /*HasMergeOp*/ true);
   case ISD::MULHS:
-    return lowerToScalableOp(Op, DAG, RISCVISD::MULHS_VL, /*HasMergeOp*/ true);
   case ISD::MULHU:
-    return lowerToScalableOp(Op, DAG, RISCVISD::MULHU_VL, /*HasMergeOp*/ true);
   case ISD::AND:
-    return lowerFixedLengthVectorLogicOpToRVV(Op, DAG, RISCVISD::VMAND_VL,
-                                              RISCVISD::AND_VL);
   case ISD::OR:
-    return lowerFixedLengthVectorLogicOpToRVV(Op, DAG, RISCVISD::VMOR_VL,
-                                              RISCVISD::OR_VL);
   case ISD::XOR:
-    return lowerFixedLengthVectorLogicOpToRVV(Op, DAG, RISCVISD::VMXOR_VL,
-                                              RISCVISD::XOR_VL);
   case ISD::SDIV:
-    return lowerToScalableOp(Op, DAG, RISCVISD::SDIV_VL, /*HasMergeOp*/ true);
   case ISD::SREM:
-    return lowerToScalableOp(Op, DAG, RISCVISD::SREM_VL, /*HasMergeOp*/ true);
   case ISD::UDIV:
-    return lowerToScalableOp(Op, DAG, RISCVISD::UDIV_VL, /*HasMergeOp*/ true);
   case ISD::UREM:
-    return lowerToScalableOp(Op, DAG, RISCVISD::UREM_VL, /*HasMergeOp*/ true);
+    return lowerToScalableOp(Op, DAG);
   case ISD::SHL:
   case ISD::SRA:
   case ISD::SRL:
     if (Op.getSimpleValueType().isFixedLengthVector())
-      return lowerFixedLengthVectorShiftToRVV(Op, DAG);
+      return lowerToScalableOp(Op, DAG);
     // This can be called for an i32 shift amount that needs to be promoted.
     assert(Op.getOperand(1).getValueType() == MVT::i32 && Subtarget.is64Bit() &&
            "Unexpected custom legalisation");
     return SDValue();
   case ISD::SADDSAT:
-    return lowerToScalableOp(Op, DAG, RISCVISD::SADDSAT_VL,
-                             /*HasMergeOp*/ true);
   case ISD::UADDSAT:
-    return lowerToScalableOp(Op, DAG, RISCVISD::UADDSAT_VL,
-                             /*HasMergeOp*/ true);
   case ISD::SSUBSAT:
-    return lowerToScalableOp(Op, DAG, RISCVISD::SSUBSAT_VL,
-                             /*HasMergeOp*/ true);
   case ISD::USUBSAT:
-    return lowerToScalableOp(Op, DAG, RISCVISD::USUBSAT_VL,
-                             /*HasMergeOp*/ true);
   case ISD::FADD:
-    return lowerToScalableOp(Op, DAG, RISCVISD::FADD_VL, /*HasMergeOp*/ true);
   case ISD::FSUB:
-    return lowerToScalableOp(Op, DAG, RISCVISD::FSUB_VL, /*HasMergeOp*/ true);
   case ISD::FMUL:
-    return lowerToScalableOp(Op, DAG, RISCVISD::FMUL_VL, /*HasMergeOp*/ true);
   case ISD::FDIV:
-    return lowerToScalableOp(Op, DAG, RISCVISD::FDIV_VL, /*HasMergeOp*/ true);
   case ISD::FNEG:
-    return lowerToScalableOp(Op, DAG, RISCVISD::FNEG_VL);
   case ISD::FABS:
-    return lowerToScalableOp(Op, DAG, RISCVISD::FABS_VL);
   case ISD::FSQRT:
-    return lowerToScalableOp(Op, DAG, RISCVISD::FSQRT_VL);
   case ISD::FMA:
-    return lowerToScalableOp(Op, DAG, RISCVISD::VFMADD_VL);
   case ISD::SMIN:
-    return lowerToScalableOp(Op, DAG, RISCVISD::SMIN_VL, /*HasMergeOp*/ true);
   case ISD::SMAX:
-    return lowerToScalableOp(Op, DAG, RISCVISD::SMAX_VL, /*HasMergeOp*/ true);
   case ISD::UMIN:
-    return lowerToScalableOp(Op, DAG, RISCVISD::UMIN_VL, /*HasMergeOp*/ true);
   case ISD::UMAX:
-    return lowerToScalableOp(Op, DAG, RISCVISD::UMAX_VL, /*HasMergeOp*/ true);
   case ISD::FMINNUM:
-    return lowerToScalableOp(Op, DAG, RISCVISD::FMINNUM_VL,
-                             /*HasMergeOp*/ true);
   case ISD::FMAXNUM:
-    return lowerToScalableOp(Op, DAG, RISCVISD::FMAXNUM_VL,
-                             /*HasMergeOp*/ true);
+    return lowerToScalableOp(Op, DAG);
   case ISD::ABS:
   case ISD::VP_ABS:
     return lowerABS(Op, DAG);
@@ -5243,21 +5305,12 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
   case ISD::FCOPYSIGN:
     return lowerFixedLengthVectorFCOPYSIGNToRVV(Op, DAG);
   case ISD::STRICT_FADD:
-    return lowerToScalableOp(Op, DAG, RISCVISD::STRICT_FADD_VL,
-                             /*HasMergeOp*/ true);
   case ISD::STRICT_FSUB:
-    return lowerToScalableOp(Op, DAG, RISCVISD::STRICT_FSUB_VL,
-                             /*HasMergeOp*/ true);
   case ISD::STRICT_FMUL:
-    return lowerToScalableOp(Op, DAG, RISCVISD::STRICT_FMUL_VL,
-                             /*HasMergeOp*/ true);
   case ISD::STRICT_FDIV:
-    return lowerToScalableOp(Op, DAG, RISCVISD::STRICT_FDIV_VL,
-                             /*HasMergeOp*/ true);
   case ISD::STRICT_FSQRT:
-    return lowerToScalableOp(Op, DAG, RISCVISD::STRICT_FSQRT_VL);
   case ISD::STRICT_FMA:
-    return lowerToScalableOp(Op, DAG, RISCVISD::STRICT_VFMADD_VL);
+    return lowerToScalableOp(Op, DAG);
   case ISD::STRICT_FSETCC:
   case ISD::STRICT_FSETCCS:
     return lowerVectorStrictFSetcc(Op, DAG);
@@ -8338,31 +8391,6 @@ SDValue RISCVTargetLowering::lowerVectorStrictFSetcc(SDValue Op,
   return Res;
 }
 
-SDValue RISCVTargetLowering::lowerFixedLengthVectorLogicOpToRVV(
-    SDValue Op, SelectionDAG &DAG, unsigned MaskOpc, unsigned VecOpc) const {
-  MVT VT = Op.getSimpleValueType();
-
-  if (VT.getVectorElementType() == MVT::i1)
-    return lowerToScalableOp(Op, DAG, MaskOpc, /*HasMergeOp*/ false,
-                             /*HasMask*/ false);
-
-  return lowerToScalableOp(Op, DAG, VecOpc, /*HasMergeOp*/ true);
-}
-
-SDValue
-RISCVTargetLowering::lowerFixedLengthVectorShiftToRVV(SDValue Op,
-                                                      SelectionDAG &DAG) const {
-  unsigned Opc;
-  switch (Op.getOpcode()) {
-  default: llvm_unreachable("Unexpected opcode!");
-  case ISD::SHL: Opc = RISCVISD::SHL_VL; break;
-  case ISD::SRA: Opc = RISCVISD::SRA_VL; break;
-  case ISD::SRL: Opc = RISCVISD::SRL_VL; break;
-  }
-
-  return lowerToScalableOp(Op, DAG, Opc, /*HasMergeOp*/ true);
-}
-
 // Lower vector ABS to smax(X, sub(0, X)).
 SDValue RISCVTargetLowering::lowerABS(SDValue Op, SelectionDAG &DAG) const {
   SDLoc DL(Op);
@@ -8446,9 +8474,12 @@ SDValue RISCVTargetLowering::lowerFixedLengthVectorSelectToRVV(
   return convertFromScalableVector(VT, Select, DAG, Subtarget);
 }
 
-SDValue RISCVTargetLowering::lowerToScalableOp(SDValue Op, SelectionDAG &DAG,
-                                               unsigned NewOpc, bool HasMergeOp,
-                                               bool HasMask) const {
+SDValue RISCVTargetLowering::lowerToScalableOp(SDValue Op,
+                                               SelectionDAG &DAG) const {
+  unsigned NewOpc = getRISCVVLOp(Op);
+  bool HasMergeOp = hasMergeOp(NewOpc);
+  bool HasMask = hasMaskOp(NewOpc);
+
   MVT VT = Op.getSimpleValueType();
   MVT ContainerVT = getContainerForFixedLengthVector(VT);
 

diff  --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index f6092b5889c29c..21915e6b15f71b 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -838,14 +838,9 @@ class RISCVTargetLowering : public TargetLowering {
   SDValue lowerFixedLengthVectorLoadToRVV(SDValue Op, SelectionDAG &DAG) const;
   SDValue lowerFixedLengthVectorStoreToRVV(SDValue Op, SelectionDAG &DAG) const;
   SDValue lowerFixedLengthVectorSetccToRVV(SDValue Op, SelectionDAG &DAG) const;
-  SDValue lowerFixedLengthVectorLogicOpToRVV(SDValue Op, SelectionDAG &DAG,
-                                             unsigned MaskOpc,
-                                             unsigned VecOpc) const;
-  SDValue lowerFixedLengthVectorShiftToRVV(SDValue Op, SelectionDAG &DAG) const;
   SDValue lowerFixedLengthVectorSelectToRVV(SDValue Op,
                                             SelectionDAG &DAG) const;
-  SDValue lowerToScalableOp(SDValue Op, SelectionDAG &DAG, unsigned NewOpc,
-                            bool HasMergeOp = false, bool HasMask = true) const;
+  SDValue lowerToScalableOp(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerIS_FPCLASS(SDValue Op, SelectionDAG &DAG) const;
   SDValue lowerVPOp(SDValue Op, SelectionDAG &DAG, unsigned RISCVISDOpc,
                     bool HasMergeOp = false) const;


        


More information about the llvm-commits mailing list