[llvm] [RISCV] fold trunc_vl (srl_vl (vwaddu X, Y), splat 1) -> vaaddu X, Y (PR #76550)

via llvm-commits llvm-commits at lists.llvm.org
Thu Jan 4 07:24:07 PST 2024


https://github.com/sun-jacobi updated https://github.com/llvm/llvm-project/pull/76550

>From 73947743e193fb3b6148a184df3e4a59cb69475a Mon Sep 17 00:00:00 2001
From: sun-jacobi <sun1011jacobi at gmail.com>
Date: Fri, 5 Jan 2024 00:23:30 +0900
Subject: [PATCH] [RISCV][ISel] Implement combineUnsignedAvgFloor.

---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp   | 97 +++++++++++++++++--
 llvm/lib/Target/RISCV/RISCVISelLowering.h     |  4 +
 .../Target/RISCV/RISCVInstrInfoVVLPatterns.td | 24 +++++
 3 files changed, 119 insertions(+), 6 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 27bb69dc9868c8..5fb1b9bfcfb74f 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -859,6 +859,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
       setOperationAction(ISD::VECTOR_DEINTERLEAVE, VT, Custom);
       setOperationAction(ISD::VECTOR_INTERLEAVE, VT, Custom);
 
+      setOperationAction(ISD::AVGFLOORU, VT, Custom);
+
       // Splice
       setOperationAction(ISD::VECTOR_SPLICE, VT, Custom);
 
@@ -1177,6 +1179,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
                             ISD::UREM, ISD::SHL, ISD::SRA, ISD::SRL},
                            VT, Custom);
 
+        setOperationAction(ISD::AVGFLOORU, VT, Custom);
+
         setOperationAction(
             {ISD::SMIN, ISD::SMAX, ISD::UMIN, ISD::UMAX, ISD::ABS}, VT, Custom);
 
@@ -1375,7 +1379,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
 
   setTargetDAGCombine({ISD::INTRINSIC_VOID, ISD::INTRINSIC_W_CHAIN,
                        ISD::INTRINSIC_WO_CHAIN, ISD::ADD, ISD::SUB, ISD::AND,
-                       ISD::OR, ISD::XOR, ISD::SETCC, ISD::SELECT});
+                       ISD::TRUNCATE, ISD::OR, ISD::XOR, ISD::SETCC,
+                       ISD::SELECT});
   if (Subtarget.is64Bit())
     setTargetDAGCombine(ISD::SRA);
 
@@ -1385,9 +1390,6 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
   if (Subtarget.hasStdExtZbb())
     setTargetDAGCombine({ISD::UMAX, ISD::UMIN, ISD::SMAX, ISD::SMIN});
 
-  if (Subtarget.hasStdExtZbs() && Subtarget.is64Bit())
-    setTargetDAGCombine(ISD::TRUNCATE);
-
   if (Subtarget.hasStdExtZbkb())
     setTargetDAGCombine(ISD::BITREVERSE);
   if (Subtarget.hasStdExtZfhminOrZhinxmin())
@@ -5501,6 +5503,8 @@ static unsigned getRISCVVLOp(SDValue Op) {
   VP_CASE(CTLZ)       // VP_CTLZ
   VP_CASE(CTTZ)       // VP_CTTZ
   VP_CASE(CTPOP)      // VP_CTPOP
+  case ISD::AVGFLOORU:
+    return RISCVISD::UAVGADD_VL;
   case ISD::CTLZ_ZERO_UNDEF:
   case ISD::VP_CTLZ_ZERO_UNDEF:
     return RISCVISD::CTLZ_VL;
@@ -5563,7 +5567,7 @@ static bool hasMergeOp(unsigned Opcode) {
          Opcode <= RISCVISD::LAST_RISCV_STRICTFP_OPCODE &&
          "not a RISC-V target specific op");
   static_assert(RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP ==
-                    125 &&
+                    126 &&
                 RISCVISD::LAST_RISCV_STRICTFP_OPCODE -
                         ISD::FIRST_TARGET_STRICTFP_OPCODE ==
                     21 &&
@@ -5589,7 +5593,7 @@ static bool hasMaskOp(unsigned Opcode) {
          Opcode <= RISCVISD::LAST_RISCV_STRICTFP_OPCODE &&
          "not a RISC-V target specific op");
   static_assert(RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP ==
-                    125 &&
+                    126 &&
                 RISCVISD::LAST_RISCV_STRICTFP_OPCODE -
                         ISD::FIRST_TARGET_STRICTFP_OPCODE ==
                     21 &&
@@ -6438,6 +6442,8 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
     assert(Op.getOperand(1).getValueType() == MVT::i32 && Subtarget.is64Bit() &&
            "Unexpected custom legalisation");
     return SDValue();
+  case ISD::AVGFLOORU:
+    return lowerUnsignedAvgFloor(Op, DAG);
   case ISD::FADD:
   case ISD::FSUB:
   case ISD::FMUL:
@@ -10298,6 +10304,36 @@ SDValue RISCVTargetLowering::lowerToScalableOp(SDValue Op,
   return convertFromScalableVector(VT, ScalableRes, DAG, Subtarget);
 }
 
+// Lower ISD::AVGFLOORU(X, Y) to RISCVISD::UAVGADD_VL using the default mask/VL.
+SDValue RISCVTargetLowering::lowerUnsignedAvgFloor(SDValue Op,
+                                                   SelectionDAG &DAG) const {
+  SDLoc DL(Op);
+  assert((Op.getOpcode() == ISD::AVGFLOORU) &&
+         "Opcode should be ISD::AVGFLOORU");
+
+  MVT VT = Op.getSimpleValueType();
+  SDValue X = Op.getOperand(0);
+  SDValue Y = Op.getOperand(1);
+  // Fixed-length operands are first converted to their scalable container.
+  MVT ContainerVT = VT;
+  if (VT.isFixedLengthVector()) {
+    ContainerVT = getContainerForFixedLengthVector(VT);
+    X = convertToScalableVector(ContainerVT, X, DAG, Subtarget);
+    Y = convertToScalableVector(ContainerVT, Y, DAG, Subtarget);
+  }
+
+  auto [Mask, VL] = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);
+  // Rounding-mode operand; NOTE(review): 0b10 should be vxrm "rdn" - confirm.
+  SDValue RM = DAG.getTargetConstant(0b10, DL, Subtarget.getXLenVT());
+  SDValue Result = DAG.getNode(RISCVISD::UAVGADD_VL, DL, ContainerVT,
+                               {X, Y, DAG.getUNDEF(ContainerVT), Mask, VL, RM});
+  // A fixed-length input gets its result converted back to the original type.
+  if (VT.isFixedLengthVector())
+    Result = convertFromScalableVector(VT, Result, DAG, Subtarget);
+
+  return Result;
+}
+
 // Lower a VP_* ISD node to the corresponding RISCVISD::*_VL node:
 // * Operands of each node are assumed to be in the same order.
 // * The EVL operand is promoted from i32 to i64 on RV64.
@@ -12357,6 +12393,51 @@ static SDValue combineAddOfBooleanXor(SDNode *N, SelectionDAG &DAG) {
                      N0.getOperand(0));
 }
 
+static SDValue combineUnsignedAvgFloor(SDNode *N, SelectionDAG &DAG,
+                                       const RISCVSubtarget &Subtarget) {
+  // Fold (trunc (srl (add (zext X), (zext Y)), 1)) -> (avgflooru X, Y).
+  if (!Subtarget.hasVInstructions())
+    return SDValue();
+  // Only integer vector results are handled; anything else is left alone.
+  EVT VT = N->getValueType(0);
+  if (!VT.isVector() || !VT.isInteger())
+    return SDValue();
+
+  assert(N->getOpcode() == ISD::TRUNCATE && "Opcode should be ISD::TRUNCATE");
+  // The narrow (truncated) result type must be legal for the target.
+  if (!DAG.getTargetLoweringInfo().isTypeLegal(VT))
+    return SDValue();
+
+  SDValue Srl = N->getOperand(0);
+
+  // The truncated value must be a single-use SRL by 1 (per isOneOrOneSplat).
+  if (!Srl.hasOneUse() || Srl.getOpcode() != ISD::SRL ||
+      !isOneOrOneSplat(Srl->getOperand(1)))
+    return SDValue();
+
+  SDValue WiddenAdd = Srl.getOperand(0);
+  // The shifted value must be a single-use ISD::ADD.
+  if (!WiddenAdd.hasOneUse() || WiddenAdd.getOpcode() != ISD::ADD)
+    return SDValue();
+
+  SDValue N0 = WiddenAdd.getOperand(0);
+  SDValue N1 = WiddenAdd.getOperand(1);
+  // True if V is a ZERO_EXTEND whose source type equals the result type VT.
+  auto IsZext = [&](SDValue V) {
+    if (V.getOpcode() != ISD::ZERO_EXTEND)
+      return false;
+
+    return V.getOperand(0)->getValueType(0) == VT;
+  };
+
+  if (!IsZext(N0) || !IsZext(N1))
+    return SDValue();
+
+  SDLoc DL(N);
+  return DAG.getNode(ISD::AVGFLOORU, DL, VT, N0->getOperand(0),
+                     N1->getOperand(0));
+}
+
 static SDValue performADDCombine(SDNode *N, SelectionDAG &DAG,
                                  const RISCVSubtarget &Subtarget) {
   if (SDValue V = combineAddOfBooleanXor(N, DAG))
@@ -12490,6 +12571,9 @@ static SDValue combineDeMorganOfBoolean(SDNode *N, SelectionDAG &DAG) {
 
 static SDValue performTRUNCATECombine(SDNode *N, SelectionDAG &DAG,
                                       const RISCVSubtarget &Subtarget) {
+  if (SDValue V = combineUnsignedAvgFloor(N, DAG, Subtarget))
+    return V;
+
   SDValue N0 = N->getOperand(0);
   EVT VT = N->getValueType(0);
 
@@ -18619,6 +18703,7 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
   NODE_NAME_CASE(SMAX_VL)
   NODE_NAME_CASE(UMIN_VL)
   NODE_NAME_CASE(UMAX_VL)
+  NODE_NAME_CASE(UAVGADD_VL)
   NODE_NAME_CASE(BITREVERSE_VL)
   NODE_NAME_CASE(BSWAP_VL)
   NODE_NAME_CASE(CTLZ_VL)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index 58ed611efc83d1..911b2fcf2aec05 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -252,6 +252,9 @@ enum NodeType : unsigned {
   UADDSAT_VL,
   SSUBSAT_VL,
   USUBSAT_VL,
+  
+  // Averaging adds of unsigned integers.
+  UAVGADD_VL,
 
   MULHS_VL,
   MULHU_VL,
@@ -903,6 +906,7 @@ class RISCVTargetLowering : public TargetLowering {
   SDValue lowerFixedLengthVectorSelectToRVV(SDValue Op,
                                             SelectionDAG &DAG) const;
   SDValue lowerToScalableOp(SDValue Op, SelectionDAG &DAG) const;
+  SDValue lowerUnsignedAvgFloor(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerIS_FPCLASS(SDValue Op, SelectionDAG &DAG) const;
   SDValue lowerVPOp(SDValue Op, SelectionDAG &DAG) const;
   SDValue lowerLogicVPOp(SDValue Op, SelectionDAG &DAG) const;
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
index 5b50a4a78c018b..570bca5ca49086 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
@@ -57,6 +57,15 @@ def SDT_RISCVCopySign_VL : SDTypeProfile<1, 5, [SDTCisSameAs<0, 1>,
                                                 SDTCisSameNumEltsAs<0, 4>,
                                                 SDTCisVT<5, XLenVT>]>;
 
+def SDT_RISCVIntBinOp_RM_VL : SDTypeProfile<1, 6, [SDTCisSameAs<0, 1>,
+                                                   SDTCisSameAs<0, 2>,
+                                                   SDTCisVec<0>, SDTCisInt<0>,
+                                                   SDTCisSameAs<0, 3>, // Merge
+                                                   SDTCVecEltisVT<4, i1>, // Mask
+                                                   SDTCisSameNumEltsAs<0, 4>,
+                                                   SDTCisVT<5, XLenVT>, // VL
+                                                   SDTCisVT<6, XLenVT>]>; // Rounding mode
+
 def riscv_vmv_v_v_vl : SDNode<"RISCVISD::VMV_V_V_VL",
                               SDTypeProfile<1, 3, [SDTCisVec<0>,
                                                    SDTCisSameAs<0, 1>,
@@ -115,6 +124,7 @@ def riscv_saddsat_vl   : SDNode<"RISCVISD::SADDSAT_VL", SDT_RISCVIntBinOp_VL, [S
 def riscv_uaddsat_vl   : SDNode<"RISCVISD::UADDSAT_VL", SDT_RISCVIntBinOp_VL, [SDNPCommutative]>;
 def riscv_ssubsat_vl   : SDNode<"RISCVISD::SSUBSAT_VL", SDT_RISCVIntBinOp_VL>;
 def riscv_usubsat_vl   : SDNode<"RISCVISD::USUBSAT_VL", SDT_RISCVIntBinOp_VL>;
+def riscv_uavgadd_vl   : SDNode<"RISCVISD::UAVGADD_VL", SDT_RISCVIntBinOp_RM_VL, [SDNPCommutative]>;
 
 def riscv_fadd_vl  : SDNode<"RISCVISD::FADD_VL",  SDT_RISCVFPBinOp_VL, [SDNPCommutative]>;
 def riscv_fsub_vl  : SDNode<"RISCVISD::FSUB_VL",  SDT_RISCVFPBinOp_VL>;
@@ -2338,6 +2348,20 @@ defm : VPatBinaryVL_VV_VX_VI<riscv_uaddsat_vl, "PseudoVSADDU">;
 defm : VPatBinaryVL_VV_VX<riscv_ssubsat_vl, "PseudoVSSUB">;
 defm : VPatBinaryVL_VV_VX<riscv_usubsat_vl, "PseudoVSSUBU">;
 
+// 12.2. Vector Single-Width Averaging Add and Subtract (uavgadd_vl -> VAADDU_VV)
+foreach vti = AllIntegerVectors in {
+  let Predicates = GetVTypePredicates<vti>.Predicates in {
+    def : Pat<(riscv_uavgadd_vl (vti.Vector vti.RegClass:$rs1),
+                            (vti.Vector vti.RegClass:$rs2),
+                            vti.RegClass:$merge, (vti.Mask V0), VLOpFrag,
+                            (XLenVT timm:$rounding_mode)),
+              (!cast<Instruction>("PseudoVAADDU_VV_"# vti.LMul.MX#"_MASK")
+                   vti.RegClass:$merge, vti.RegClass:$rs1, vti.RegClass:$rs2,
+                   (vti.Mask V0), (XLenVT timm:$rounding_mode),
+                   GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
+  }
+}
+
 // 12.5. Vector Narrowing Fixed-Point Clip Instructions
 class VPatTruncSatClipMaxMinBase<string inst,
                                  VTypeInfo vti,



More information about the llvm-commits mailing list