[llvm] [RISCV] fold trunc_vl (srl_vl (vwaddu X, Y), splat 1) -> vaaddu X, Y (PR #76550)
via llvm-commits
llvm-commits at lists.llvm.org
Thu Jan 4 07:24:07 PST 2024
https://github.com/sun-jacobi updated https://github.com/llvm/llvm-project/pull/76550
>From 73947743e193fb3b6148a184df3e4a59cb69475a Mon Sep 17 00:00:00 2001
From: sun-jacobi <sun1011jacobi at gmail.com>
Date: Fri, 5 Jan 2024 00:23:30 +0900
Subject: [PATCH] [RISCV][ISel] Implement combineUnsignedAvgFloor.
---
llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 97 +++++++++++++++++--
llvm/lib/Target/RISCV/RISCVISelLowering.h | 4 +
.../Target/RISCV/RISCVInstrInfoVVLPatterns.td | 24 +++++
3 files changed, 119 insertions(+), 6 deletions(-)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 27bb69dc9868c8..5fb1b9bfcfb74f 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -859,6 +859,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
setOperationAction(ISD::VECTOR_DEINTERLEAVE, VT, Custom);
setOperationAction(ISD::VECTOR_INTERLEAVE, VT, Custom);
+ setOperationAction(ISD::AVGFLOORU, VT, Custom);
+
// Splice
setOperationAction(ISD::VECTOR_SPLICE, VT, Custom);
@@ -1177,6 +1179,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
ISD::UREM, ISD::SHL, ISD::SRA, ISD::SRL},
VT, Custom);
+ setOperationAction(ISD::AVGFLOORU, VT, Custom);
+
setOperationAction(
{ISD::SMIN, ISD::SMAX, ISD::UMIN, ISD::UMAX, ISD::ABS}, VT, Custom);
@@ -1375,7 +1379,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
setTargetDAGCombine({ISD::INTRINSIC_VOID, ISD::INTRINSIC_W_CHAIN,
ISD::INTRINSIC_WO_CHAIN, ISD::ADD, ISD::SUB, ISD::AND,
- ISD::OR, ISD::XOR, ISD::SETCC, ISD::SELECT});
+ ISD::TRUNCATE, ISD::OR, ISD::XOR, ISD::SETCC,
+ ISD::SELECT});
if (Subtarget.is64Bit())
setTargetDAGCombine(ISD::SRA);
@@ -1385,9 +1390,6 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
if (Subtarget.hasStdExtZbb())
setTargetDAGCombine({ISD::UMAX, ISD::UMIN, ISD::SMAX, ISD::SMIN});
- if (Subtarget.hasStdExtZbs() && Subtarget.is64Bit())
- setTargetDAGCombine(ISD::TRUNCATE);
-
if (Subtarget.hasStdExtZbkb())
setTargetDAGCombine(ISD::BITREVERSE);
if (Subtarget.hasStdExtZfhminOrZhinxmin())
@@ -5501,6 +5503,8 @@ static unsigned getRISCVVLOp(SDValue Op) {
VP_CASE(CTLZ) // VP_CTLZ
VP_CASE(CTTZ) // VP_CTTZ
VP_CASE(CTPOP) // VP_CTPOP
+ case ISD::AVGFLOORU:
+ return RISCVISD::UAVGADD_VL;
case ISD::CTLZ_ZERO_UNDEF:
case ISD::VP_CTLZ_ZERO_UNDEF:
return RISCVISD::CTLZ_VL;
@@ -5563,7 +5567,7 @@ static bool hasMergeOp(unsigned Opcode) {
Opcode <= RISCVISD::LAST_RISCV_STRICTFP_OPCODE &&
"not a RISC-V target specific op");
static_assert(RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP ==
- 125 &&
+ 126 &&
RISCVISD::LAST_RISCV_STRICTFP_OPCODE -
ISD::FIRST_TARGET_STRICTFP_OPCODE ==
21 &&
@@ -5589,7 +5593,7 @@ static bool hasMaskOp(unsigned Opcode) {
Opcode <= RISCVISD::LAST_RISCV_STRICTFP_OPCODE &&
"not a RISC-V target specific op");
static_assert(RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP ==
- 125 &&
+ 126 &&
RISCVISD::LAST_RISCV_STRICTFP_OPCODE -
ISD::FIRST_TARGET_STRICTFP_OPCODE ==
21 &&
@@ -6438,6 +6442,8 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
assert(Op.getOperand(1).getValueType() == MVT::i32 && Subtarget.is64Bit() &&
"Unexpected custom legalisation");
return SDValue();
+ case ISD::AVGFLOORU:
+ return lowerUnsignedAvgFloor(Op, DAG);
case ISD::FADD:
case ISD::FSUB:
case ISD::FMUL:
@@ -10298,6 +10304,36 @@ SDValue RISCVTargetLowering::lowerToScalableOp(SDValue Op,
return convertFromScalableVector(VT, ScalableRes, DAG, Subtarget);
}
+// Lower vector AVGFLOORU(X, Y)
+SDValue RISCVTargetLowering::lowerUnsignedAvgFloor(SDValue Op,
+ SelectionDAG &DAG) const {
+ SDLoc DL(Op);
+ assert((Op.getOpcode() == ISD::AVGFLOORU) &&
+ "Opcode should be ISD::AVGFLOORU");
+
+ MVT VT = Op.getSimpleValueType();
+ SDValue X = Op.getOperand(0);
+ SDValue Y = Op.getOperand(1);
+
+ MVT ContainerVT = VT;
+ if (VT.isFixedLengthVector()) {
+ ContainerVT = getContainerForFixedLengthVector(VT);
+ X = convertToScalableVector(ContainerVT, X, DAG, Subtarget);
+ Y = convertToScalableVector(ContainerVT, Y, DAG, Subtarget);
+ }
+
+ auto [Mask, VL] = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);
+
+ SDValue RM = DAG.getTargetConstant(0b10, DL, Subtarget.getXLenVT());
+ SDValue Result = DAG.getNode(RISCVISD::UAVGADD_VL, DL, ContainerVT,
+ {X, Y, DAG.getUNDEF(ContainerVT), Mask, VL, RM});
+
+ if (VT.isFixedLengthVector())
+ Result = convertFromScalableVector(VT, Result, DAG, Subtarget);
+
+ return Result;
+}
+
// Lower a VP_* ISD node to the corresponding RISCVISD::*_VL node:
// * Operands of each node are assumed to be in the same order.
// * The EVL operand is promoted from i32 to i64 on RV64.
@@ -12357,6 +12393,51 @@ static SDValue combineAddOfBooleanXor(SDNode *N, SelectionDAG &DAG) {
N0.getOperand(0));
}
+static SDValue combineUnsignedAvgFloor(SDNode *N, SelectionDAG &DAG,
+ const RISCVSubtarget &Subtarget) {
+
+ if (!Subtarget.hasVInstructions())
+ return SDValue();
+
+ EVT VT = N->getValueType(0);
+ if (!VT.isVector() || !VT.isInteger())
+ return SDValue();
+
+ assert(N->getOpcode() == ISD::TRUNCATE && "Opcode should be ISD::TRUNCATE");
+
+ if (!DAG.getTargetLoweringInfo().isTypeLegal(VT))
+ return SDValue();
+
+ SDValue Srl = N->getOperand(0);
+
+ // (lshr X, 1)
+ if (!Srl.hasOneUse() || Srl.getOpcode() != ISD::SRL ||
+ !isOneOrOneSplat(Srl->getOperand(1)))
+ return SDValue();
+
+  SDValue WidenedAdd = Srl.getOperand(0);
+
+  if (!WidenedAdd.hasOneUse() || WidenedAdd.getOpcode() != ISD::ADD)
+    return SDValue();
+
+  SDValue N0 = WidenedAdd.getOperand(0);
+  SDValue N1 = WidenedAdd.getOperand(1);
+
+ auto IsZext = [&](SDValue V) {
+ if (V.getOpcode() != ISD::ZERO_EXTEND)
+ return false;
+
+ return V.getOperand(0)->getValueType(0) == VT;
+ };
+
+ if (!IsZext(N0) || !IsZext(N1))
+ return SDValue();
+
+ SDLoc DL(N);
+ return DAG.getNode(ISD::AVGFLOORU, DL, VT, N0->getOperand(0),
+ N1->getOperand(0));
+}
+
static SDValue performADDCombine(SDNode *N, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget) {
if (SDValue V = combineAddOfBooleanXor(N, DAG))
@@ -12490,6 +12571,9 @@ static SDValue combineDeMorganOfBoolean(SDNode *N, SelectionDAG &DAG) {
static SDValue performTRUNCATECombine(SDNode *N, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget) {
+ if (SDValue V = combineUnsignedAvgFloor(N, DAG, Subtarget))
+ return V;
+
SDValue N0 = N->getOperand(0);
EVT VT = N->getValueType(0);
@@ -18619,6 +18703,7 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
NODE_NAME_CASE(SMAX_VL)
NODE_NAME_CASE(UMIN_VL)
NODE_NAME_CASE(UMAX_VL)
+ NODE_NAME_CASE(UAVGADD_VL)
NODE_NAME_CASE(BITREVERSE_VL)
NODE_NAME_CASE(BSWAP_VL)
NODE_NAME_CASE(CTLZ_VL)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index 58ed611efc83d1..911b2fcf2aec05 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -252,6 +252,9 @@ enum NodeType : unsigned {
UADDSAT_VL,
SSUBSAT_VL,
USUBSAT_VL,
+
+ // Averaging adds of unsigned integers.
+ UAVGADD_VL,
MULHS_VL,
MULHU_VL,
@@ -903,6 +906,7 @@ class RISCVTargetLowering : public TargetLowering {
SDValue lowerFixedLengthVectorSelectToRVV(SDValue Op,
SelectionDAG &DAG) const;
SDValue lowerToScalableOp(SDValue Op, SelectionDAG &DAG) const;
+ SDValue lowerUnsignedAvgFloor(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerIS_FPCLASS(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerVPOp(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerLogicVPOp(SDValue Op, SelectionDAG &DAG) const;
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
index 5b50a4a78c018b..570bca5ca49086 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
@@ -57,6 +57,15 @@ def SDT_RISCVCopySign_VL : SDTypeProfile<1, 5, [SDTCisSameAs<0, 1>,
SDTCisSameNumEltsAs<0, 4>,
SDTCisVT<5, XLenVT>]>;
+def SDT_RISCVIntBinOp_RM_VL : SDTypeProfile<1, 6, [SDTCisSameAs<0, 1>,
+ SDTCisSameAs<0, 2>,
+ SDTCisVec<0>, SDTCisInt<0>,
+ SDTCisSameAs<0, 3>,
+ SDTCVecEltisVT<4, i1>,
+ SDTCisSameNumEltsAs<0, 4>,
+ SDTCisVT<5, XLenVT>,
+ SDTCisVT<6, XLenVT>]>; // Rounding Mode
+
def riscv_vmv_v_v_vl : SDNode<"RISCVISD::VMV_V_V_VL",
SDTypeProfile<1, 3, [SDTCisVec<0>,
SDTCisSameAs<0, 1>,
@@ -115,6 +124,7 @@ def riscv_saddsat_vl : SDNode<"RISCVISD::SADDSAT_VL", SDT_RISCVIntBinOp_VL, [S
def riscv_uaddsat_vl : SDNode<"RISCVISD::UADDSAT_VL", SDT_RISCVIntBinOp_VL, [SDNPCommutative]>;
def riscv_ssubsat_vl : SDNode<"RISCVISD::SSUBSAT_VL", SDT_RISCVIntBinOp_VL>;
def riscv_usubsat_vl : SDNode<"RISCVISD::USUBSAT_VL", SDT_RISCVIntBinOp_VL>;
+def riscv_uavgadd_vl : SDNode<"RISCVISD::UAVGADD_VL", SDT_RISCVIntBinOp_RM_VL, [SDNPCommutative]>;
def riscv_fadd_vl : SDNode<"RISCVISD::FADD_VL", SDT_RISCVFPBinOp_VL, [SDNPCommutative]>;
def riscv_fsub_vl : SDNode<"RISCVISD::FSUB_VL", SDT_RISCVFPBinOp_VL>;
@@ -2338,6 +2348,20 @@ defm : VPatBinaryVL_VV_VX_VI<riscv_uaddsat_vl, "PseudoVSADDU">;
defm : VPatBinaryVL_VV_VX<riscv_ssubsat_vl, "PseudoVSSUB">;
defm : VPatBinaryVL_VV_VX<riscv_usubsat_vl, "PseudoVSSUBU">;
+// 12.2. Vector Single-Width Averaging Add and Subtract
+foreach vti = AllIntegerVectors in {
+ let Predicates = GetVTypePredicates<vti>.Predicates in {
+ def : Pat<(riscv_uavgadd_vl (vti.Vector vti.RegClass:$rs1),
+ (vti.Vector vti.RegClass:$rs2),
+ vti.RegClass:$merge, (vti.Mask V0), VLOpFrag,
+ (XLenVT timm:$rounding_mode)),
+ (!cast<Instruction>("PseudoVAADDU_VV_"# vti.LMul.MX#"_MASK")
+ vti.RegClass:$merge, vti.RegClass:$rs1, vti.RegClass:$rs2,
+ (vti.Mask V0), (XLenVT timm:$rounding_mode),
+ GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
+ }
+}
+
// 12.5. Vector Narrowing Fixed-Point Clip Instructions
class VPatTruncSatClipMaxMinBase<string inst,
VTypeInfo vti,
More information about the llvm-commits
mailing list