[llvm] 1d8db2f - [RISCV][NFC] Refactor lowerToScalableOp.
Jianjian GUAN via llvm-commits
llvm-commits at lists.llvm.org
Mon Jul 3 23:09:32 PDT 2023
Author: Jianjian GUAN
Date: 2023-07-04T14:09:24+08:00
New Revision: 1d8db2fab30369594c72ac33357eca8baeb71199
URL: https://github.com/llvm/llvm-project/commit/1d8db2fab30369594c72ac33357eca8baeb71199
DIFF: https://github.com/llvm/llvm-project/commit/1d8db2fab30369594c72ac33357eca8baeb71199.diff
LOG: [RISCV][NFC] Refactor lowerToScalableOp.
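Refactor lowerToScalableOp to derive the RISCVISD VL opcode and its merge/mask operand flags from the input node itself, so the per-opcode switch cases in LowerOperation can share a single call.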
Reviewed By: frasercrmck
Differential Revision: https://reviews.llvm.org/D153948
Added:
Modified:
llvm/lib/Target/RISCV/RISCVISelLowering.cpp
llvm/lib/Target/RISCV/RISCVISelLowering.h
Removed:
################################################################################
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index be824d7a4f2502..506b1c52529ee0 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -4564,6 +4564,105 @@ SDValue RISCVTargetLowering::LowerIS_FPCLASS(SDValue Op,
ISD::CondCode::SETNE);
}
+/// Map the opcode of \p Op to the corresponding RISC-V VL node (RISCVISD::*_VL).
+///
+/// Most opcodes map one-to-one onto RISCVISD::<NAME>_VL via OP_CASE. FMA and
+/// STRICT_FMA map onto the VFMADD nodes. AND/OR/XOR pick the VMAND/VMOR/VMXOR
+/// form when the element type of \p Op's result is i1, and the plain
+/// AND/OR/XOR form otherwise. Any other opcode hits llvm_unreachable.
+static unsigned getRISCVVLOp(SDValue Op) {
+  // Expands to `case ISD::X: return RISCVISD::X_VL;` for the 1:1 mappings.
+#define OP_CASE(NODE) \
+  case ISD::NODE: \
+    return RISCVISD::NODE##_VL;
+  switch (Op.getOpcode()) {
+  // clang-format off
+  OP_CASE(ADD)
+  OP_CASE(SUB)
+  OP_CASE(MUL)
+  OP_CASE(MULHS)
+  OP_CASE(MULHU)
+  OP_CASE(SDIV)
+  OP_CASE(SREM)
+  OP_CASE(UDIV)
+  OP_CASE(UREM)
+  OP_CASE(SHL)
+  OP_CASE(SRA)
+  OP_CASE(SRL)
+  OP_CASE(SADDSAT)
+  OP_CASE(UADDSAT)
+  OP_CASE(SSUBSAT)
+  OP_CASE(USUBSAT)
+  OP_CASE(FADD)
+  OP_CASE(FSUB)
+  OP_CASE(FMUL)
+  OP_CASE(FDIV)
+  OP_CASE(FNEG)
+  OP_CASE(FABS)
+  OP_CASE(FSQRT)
+  OP_CASE(SMIN)
+  OP_CASE(SMAX)
+  OP_CASE(UMIN)
+  OP_CASE(UMAX)
+  OP_CASE(FMINNUM)
+  OP_CASE(FMAXNUM)
+  OP_CASE(STRICT_FADD)
+  OP_CASE(STRICT_FSUB)
+  OP_CASE(STRICT_FMUL)
+  OP_CASE(STRICT_FDIV)
+  OP_CASE(STRICT_FSQRT)
+  // clang-format on
+#undef OP_CASE
+  // Fused multiply-add has no 1:1 *_VL name; it maps onto VFMADD.
+  case ISD::FMA:
+    return RISCVISD::VFMADD_VL;
+  case ISD::STRICT_FMA:
+    return RISCVISD::STRICT_VFMADD_VL;
+  // Bitwise ops on i1-element vectors use the mask-register nodes.
+  case ISD::AND:
+    return Op.getSimpleValueType().getVectorElementType() == MVT::i1
+               ? RISCVISD::VMAND_VL
+               : RISCVISD::AND_VL;
+  case ISD::OR:
+    return Op.getSimpleValueType().getVectorElementType() == MVT::i1
+               ? RISCVISD::VMOR_VL
+               : RISCVISD::OR_VL;
+  case ISD::XOR:
+    return Op.getSimpleValueType().getVectorElementType() == MVT::i1
+               ? RISCVISD::VMXOR_VL
+               : RISCVISD::XOR_VL;
+  default:
+    llvm_unreachable("don't have RISC-V specified VL op for this SDNode");
+  }
+}
+
+/// Return true if the RISC-V target-specific opcode \p Opcode has a merge
+/// operand.
+///
+/// The answer is computed from contiguous ranges of the RISCVISD enum, so it
+/// depends on the declaration order of those enumerators. The static_assert
+/// below makes adding a new RISCVISD opcode break the build until this
+/// function is revisited.
+static bool hasMergeOp(unsigned Opcode) {
+  assert(Opcode > RISCVISD::FIRST_NUMBER &&
+         Opcode <= RISCVISD::STRICT_VFROUND_NOEXCEPT_VL &&
+         "not a RISC-V target specific op");
+  // Both sides are enumerators, so check this at compile time. A runtime
+  // assert() is compiled out under NDEBUG and only fires when this function
+  // is actually executed.
+  static_assert(RISCVISD::STRICT_VFROUND_NOEXCEPT_VL - RISCVISD::FIRST_NUMBER ==
+                    421,
+                "adding target specific op should update this function");
+  if (Opcode >= RISCVISD::ADD_VL && Opcode <= RISCVISD::FMAXNUM_VL)
+    return true;
+  if (Opcode == RISCVISD::FCOPYSIGN_VL)
+    return true;
+  if (Opcode >= RISCVISD::VWMUL_VL && Opcode <= RISCVISD::VFWSUB_W_VL)
+    return true;
+  if (Opcode >= RISCVISD::STRICT_FADD_VL && Opcode <= RISCVISD::STRICT_FDIV_VL)
+    return true;
+  return false;
+}
+
+/// Return true if the RISC-V target-specific opcode \p Opcode has a mask
+/// operand.
+///
+/// The answer is computed from contiguous ranges of the RISCVISD enum, so it
+/// depends on the declaration order of those enumerators. The static_assert
+/// below makes adding a new RISCVISD opcode break the build until this
+/// function is revisited.
+static bool hasMaskOp(unsigned Opcode) {
+  assert(Opcode > RISCVISD::FIRST_NUMBER &&
+         Opcode <= RISCVISD::STRICT_VFROUND_NOEXCEPT_VL &&
+         "not a RISC-V target specific op");
+  // Both sides are enumerators, so check this at compile time. A runtime
+  // assert() is compiled out under NDEBUG and only fires when this function
+  // is actually executed.
+  static_assert(RISCVISD::STRICT_VFROUND_NOEXCEPT_VL - RISCVISD::FIRST_NUMBER ==
+                    421,
+                "adding target specific op should update this function");
+  if (Opcode >= RISCVISD::TRUNCATE_VECTOR_VL && Opcode <= RISCVISD::SETCC_VL)
+    return true;
+  if (Opcode >= RISCVISD::VRGATHER_VX_VL && Opcode <= RISCVISD::VFIRST_VL)
+    return true;
+  if (Opcode >= RISCVISD::STRICT_FADD_VL &&
+      Opcode <= RISCVISD::STRICT_VFROUND_NOEXCEPT_VL)
+    return true;
+  return false;
+}
+
SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
SelectionDAG &DAG) const {
switch (Op.getOpcode()) {
@@ -5154,83 +5253,46 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
return lowerFixedLengthVectorSetccToRVV(Op, DAG);
}
case ISD::ADD:
- return lowerToScalableOp(Op, DAG, RISCVISD::ADD_VL, /*HasMergeOp*/ true);
case ISD::SUB:
- return lowerToScalableOp(Op, DAG, RISCVISD::SUB_VL, /*HasMergeOp*/ true);
case ISD::MUL:
- return lowerToScalableOp(Op, DAG, RISCVISD::MUL_VL, /*HasMergeOp*/ true);
case ISD::MULHS:
- return lowerToScalableOp(Op, DAG, RISCVISD::MULHS_VL, /*HasMergeOp*/ true);
case ISD::MULHU:
- return lowerToScalableOp(Op, DAG, RISCVISD::MULHU_VL, /*HasMergeOp*/ true);
case ISD::AND:
- return lowerFixedLengthVectorLogicOpToRVV(Op, DAG, RISCVISD::VMAND_VL,
- RISCVISD::AND_VL);
case ISD::OR:
- return lowerFixedLengthVectorLogicOpToRVV(Op, DAG, RISCVISD::VMOR_VL,
- RISCVISD::OR_VL);
case ISD::XOR:
- return lowerFixedLengthVectorLogicOpToRVV(Op, DAG, RISCVISD::VMXOR_VL,
- RISCVISD::XOR_VL);
case ISD::SDIV:
- return lowerToScalableOp(Op, DAG, RISCVISD::SDIV_VL, /*HasMergeOp*/ true);
case ISD::SREM:
- return lowerToScalableOp(Op, DAG, RISCVISD::SREM_VL, /*HasMergeOp*/ true);
case ISD::UDIV:
- return lowerToScalableOp(Op, DAG, RISCVISD::UDIV_VL, /*HasMergeOp*/ true);
case ISD::UREM:
- return lowerToScalableOp(Op, DAG, RISCVISD::UREM_VL, /*HasMergeOp*/ true);
+ return lowerToScalableOp(Op, DAG);
case ISD::SHL:
case ISD::SRA:
case ISD::SRL:
if (Op.getSimpleValueType().isFixedLengthVector())
- return lowerFixedLengthVectorShiftToRVV(Op, DAG);
+ return lowerToScalableOp(Op, DAG);
// This can be called for an i32 shift amount that needs to be promoted.
assert(Op.getOperand(1).getValueType() == MVT::i32 && Subtarget.is64Bit() &&
"Unexpected custom legalisation");
return SDValue();
case ISD::SADDSAT:
- return lowerToScalableOp(Op, DAG, RISCVISD::SADDSAT_VL,
- /*HasMergeOp*/ true);
case ISD::UADDSAT:
- return lowerToScalableOp(Op, DAG, RISCVISD::UADDSAT_VL,
- /*HasMergeOp*/ true);
case ISD::SSUBSAT:
- return lowerToScalableOp(Op, DAG, RISCVISD::SSUBSAT_VL,
- /*HasMergeOp*/ true);
case ISD::USUBSAT:
- return lowerToScalableOp(Op, DAG, RISCVISD::USUBSAT_VL,
- /*HasMergeOp*/ true);
case ISD::FADD:
- return lowerToScalableOp(Op, DAG, RISCVISD::FADD_VL, /*HasMergeOp*/ true);
case ISD::FSUB:
- return lowerToScalableOp(Op, DAG, RISCVISD::FSUB_VL, /*HasMergeOp*/ true);
case ISD::FMUL:
- return lowerToScalableOp(Op, DAG, RISCVISD::FMUL_VL, /*HasMergeOp*/ true);
case ISD::FDIV:
- return lowerToScalableOp(Op, DAG, RISCVISD::FDIV_VL, /*HasMergeOp*/ true);
case ISD::FNEG:
- return lowerToScalableOp(Op, DAG, RISCVISD::FNEG_VL);
case ISD::FABS:
- return lowerToScalableOp(Op, DAG, RISCVISD::FABS_VL);
case ISD::FSQRT:
- return lowerToScalableOp(Op, DAG, RISCVISD::FSQRT_VL);
case ISD::FMA:
- return lowerToScalableOp(Op, DAG, RISCVISD::VFMADD_VL);
case ISD::SMIN:
- return lowerToScalableOp(Op, DAG, RISCVISD::SMIN_VL, /*HasMergeOp*/ true);
case ISD::SMAX:
- return lowerToScalableOp(Op, DAG, RISCVISD::SMAX_VL, /*HasMergeOp*/ true);
case ISD::UMIN:
- return lowerToScalableOp(Op, DAG, RISCVISD::UMIN_VL, /*HasMergeOp*/ true);
case ISD::UMAX:
- return lowerToScalableOp(Op, DAG, RISCVISD::UMAX_VL, /*HasMergeOp*/ true);
case ISD::FMINNUM:
- return lowerToScalableOp(Op, DAG, RISCVISD::FMINNUM_VL,
- /*HasMergeOp*/ true);
case ISD::FMAXNUM:
- return lowerToScalableOp(Op, DAG, RISCVISD::FMAXNUM_VL,
- /*HasMergeOp*/ true);
+ return lowerToScalableOp(Op, DAG);
case ISD::ABS:
case ISD::VP_ABS:
return lowerABS(Op, DAG);
@@ -5243,21 +5305,12 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
case ISD::FCOPYSIGN:
return lowerFixedLengthVectorFCOPYSIGNToRVV(Op, DAG);
case ISD::STRICT_FADD:
- return lowerToScalableOp(Op, DAG, RISCVISD::STRICT_FADD_VL,
- /*HasMergeOp*/ true);
case ISD::STRICT_FSUB:
- return lowerToScalableOp(Op, DAG, RISCVISD::STRICT_FSUB_VL,
- /*HasMergeOp*/ true);
case ISD::STRICT_FMUL:
- return lowerToScalableOp(Op, DAG, RISCVISD::STRICT_FMUL_VL,
- /*HasMergeOp*/ true);
case ISD::STRICT_FDIV:
- return lowerToScalableOp(Op, DAG, RISCVISD::STRICT_FDIV_VL,
- /*HasMergeOp*/ true);
case ISD::STRICT_FSQRT:
- return lowerToScalableOp(Op, DAG, RISCVISD::STRICT_FSQRT_VL);
case ISD::STRICT_FMA:
- return lowerToScalableOp(Op, DAG, RISCVISD::STRICT_VFMADD_VL);
+ return lowerToScalableOp(Op, DAG);
case ISD::STRICT_FSETCC:
case ISD::STRICT_FSETCCS:
return lowerVectorStrictFSetcc(Op, DAG);
@@ -8338,31 +8391,6 @@ SDValue RISCVTargetLowering::lowerVectorStrictFSetcc(SDValue Op,
return Res;
}
-SDValue RISCVTargetLowering::lowerFixedLengthVectorLogicOpToRVV(
- SDValue Op, SelectionDAG &DAG, unsigned MaskOpc, unsigned VecOpc) const {
- MVT VT = Op.getSimpleValueType();
-
- if (VT.getVectorElementType() == MVT::i1)
- return lowerToScalableOp(Op, DAG, MaskOpc, /*HasMergeOp*/ false,
- /*HasMask*/ false);
-
- return lowerToScalableOp(Op, DAG, VecOpc, /*HasMergeOp*/ true);
-}
-
-SDValue
-RISCVTargetLowering::lowerFixedLengthVectorShiftToRVV(SDValue Op,
- SelectionDAG &DAG) const {
- unsigned Opc;
- switch (Op.getOpcode()) {
- default: llvm_unreachable("Unexpected opcode!");
- case ISD::SHL: Opc = RISCVISD::SHL_VL; break;
- case ISD::SRA: Opc = RISCVISD::SRA_VL; break;
- case ISD::SRL: Opc = RISCVISD::SRL_VL; break;
- }
-
- return lowerToScalableOp(Op, DAG, Opc, /*HasMergeOp*/ true);
-}
-
// Lower vector ABS to smax(X, sub(0, X)).
SDValue RISCVTargetLowering::lowerABS(SDValue Op, SelectionDAG &DAG) const {
SDLoc DL(Op);
@@ -8446,9 +8474,12 @@ SDValue RISCVTargetLowering::lowerFixedLengthVectorSelectToRVV(
return convertFromScalableVector(VT, Select, DAG, Subtarget);
}
-SDValue RISCVTargetLowering::lowerToScalableOp(SDValue Op, SelectionDAG &DAG,
- unsigned NewOpc, bool HasMergeOp,
- bool HasMask) const {
+SDValue RISCVTargetLowering::lowerToScalableOp(SDValue Op,
+ SelectionDAG &DAG) const {
+ unsigned NewOpc = getRISCVVLOp(Op);
+ bool HasMergeOp = hasMergeOp(NewOpc);
+ bool HasMask = hasMaskOp(NewOpc);
+
MVT VT = Op.getSimpleValueType();
MVT ContainerVT = getContainerForFixedLengthVector(VT);
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index f6092b5889c29c..21915e6b15f71b 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -838,14 +838,9 @@ class RISCVTargetLowering : public TargetLowering {
SDValue lowerFixedLengthVectorLoadToRVV(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerFixedLengthVectorStoreToRVV(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerFixedLengthVectorSetccToRVV(SDValue Op, SelectionDAG &DAG) const;
- SDValue lowerFixedLengthVectorLogicOpToRVV(SDValue Op, SelectionDAG &DAG,
- unsigned MaskOpc,
- unsigned VecOpc) const;
- SDValue lowerFixedLengthVectorShiftToRVV(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerFixedLengthVectorSelectToRVV(SDValue Op,
SelectionDAG &DAG) const;
- SDValue lowerToScalableOp(SDValue Op, SelectionDAG &DAG, unsigned NewOpc,
- bool HasMergeOp = false, bool HasMask = true) const;
+ SDValue lowerToScalableOp(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerIS_FPCLASS(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerVPOp(SDValue Op, SelectionDAG &DAG, unsigned RISCVISDOpc,
bool HasMergeOp = false) const;
More information about the llvm-commits
mailing list