[llvm] [DAGCombiner] Add combine avg from shifts (PR #113909)

via llvm-commits llvm-commits at lists.llvm.org
Mon Oct 28 06:39:10 PDT 2024


llvmbot wrote:



@llvm/pr-subscribers-backend-arm

Author: None (dnsampaio)

<details>
<summary>Changes</summary>

This teaches the DAGCombiner to fold:
`(asr (add nsw x, y), 1) -> (avgfloors x, y)`
`(lsr (add nuw x, y), 1) -> (avgflooru x, y)`

as well as to combine them into the ceil variants:
`(avgfloors (add nsw x, y), 1) -> (avgceils x, y)`
`(avgflooru (add nuw x, y), 1) -> (avgceilu x, y)`

iff the avg nodes are valid (legal or custom) for the target.
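
For intuition, here is a minimal scalar sketch (not part of the patch) of why the folds are sound: when the addition is known not to wrap, shifting the sum right by one is exactly the flooring average, and adding one before the shift gives the ceiling average, matching the classic wrap-free expansions of the avg nodes:

```cpp
// Scalar illustration only, not from the patch.  With a non-wrapping add,
// (x + y) >> 1 is the flooring average and (x + y + 1) >> 1 is the ceiling
// average; both agree with the standard overflow-free expansions.
#include <cassert>
#include <cstdint>

int main() {
  uint32_t x = 10, y = 15;              // x + y is known not to wrap ("nuw")
  uint32_t floor_avg = (x + y) >> 1;    // what (lsr (add nuw x, y), 1) computes
  uint32_t ceil_avg = (x + y + 1) >> 1; // the avgflooru(add, 1) form
  assert(floor_avg == ((x & y) + ((x ^ y) >> 1))); // avgflooru(x, y)
  assert(ceil_avg == ((x | y) - ((x ^ y) >> 1)));  // avgceilu(x, y)
  return 0;
}
```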

It also removes some of the ARM MVE patterns that are now dead code, and adds the avg opcodes to `IsQRMVEInstruction` so as to preserve the immediate splatting behaviour as before.

---
Full diff: https://github.com/llvm/llvm-project/pull/113909.diff


3 Files Affected:

- (modified) llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (+70) 
- (modified) llvm/lib/Target/ARM/ARMISelLowering.cpp (+4) 
- (modified) llvm/lib/Target/ARM/ARMInstrMVE.td (+9-76) 


``````````diff
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index b800204d917503..125e822dddb59a 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -401,6 +401,8 @@ namespace {
     SDValue PromoteExtend(SDValue Op);
     bool PromoteLoad(SDValue Op);
 
+    SDValue combineAVG(SDNode *N);
+
     SDValue combineMinNumMaxNum(const SDLoc &DL, EVT VT, SDValue LHS,
                                 SDValue RHS, SDValue True, SDValue False,
                                 ISD::CondCode CC);
@@ -5354,6 +5356,18 @@ SDValue DAGCombiner::visitAVG(SDNode *N) {
           DAG.getNode(ISD::ADD, DL, VT, N0, DAG.getAllOnesConstant(DL, VT)));
   }
 
+  // Fold avgfloor((add n[su]w x, y), 1) -> avgceil(x, y) when the add does
+  // not wrap and the target supports the ceil variant.
+  if (Opcode == ISD::AVGFLOORU || Opcode == ISD::AVGFLOORS) {
+    SDValue Add;
+    if (sd_match(N,
+                 m_c_BinOp(Opcode,
+                           m_AllOf(m_Value(Add), m_Add(m_Value(X), m_Value(Y))),
+                           m_One()))) {
+      if (IsSigned) {
+        if (hasOperation(ISD::AVGCEILS, VT) &&
+            Add->getFlags().hasNoSignedWrap())
+          return DAG.getNode(ISD::AVGCEILS, DL, VT, X, Y);
+      } else if (hasOperation(ISD::AVGCEILU, VT) &&
+                 Add->getFlags().hasNoUnsignedWrap()) {
+        return DAG.getNode(ISD::AVGCEILU, DL, VT, X, Y);
+      }
+    }
+  }
+
   return SDValue();
 }
 
@@ -10626,6 +10640,9 @@ SDValue DAGCombiner::visitSRA(SDNode *N) {
   if (SDValue NarrowLoad = reduceLoadWidth(N))
     return NarrowLoad;
 
+  if (SDValue AVG = combineAVG(N))
+    return AVG;
+
   return SDValue();
 }
 
@@ -10880,6 +10897,9 @@ SDValue DAGCombiner::visitSRL(SDNode *N) {
   if (SDValue MULH = combineShiftToMULH(N, DL, DAG, TLI))
     return MULH;
 
+  if (SDValue AVG = combineAVG(N))
+    return AVG;
+
   return SDValue();
 }
 
@@ -11393,6 +11413,56 @@ static SDValue combineMinNumMaxNumImpl(const SDLoc &DL, EVT VT, SDValue LHS,
   }
 }
 
+SDValue DAGCombiner::combineAVG(SDNode *N) {
+  const auto Opcode = N->getOpcode();
+
+  // Convert (sr[al] (add n[su]w x, y), 1) -> (avgfloor[su] x, y).
+  if (Opcode != ISD::SRA && Opcode != ISD::SRL)
+    return SDValue();
+
+  unsigned FloorISD = 0;
+  auto VT = N->getValueType(0);
+  bool IsUnsigned = false;
+
+  // Decide whether this is a signed or unsigned average based on the shift.
+  switch (Opcode) {
+  case ISD::SRA:
+    if (hasOperation(ISD::AVGFLOORS, VT))
+      FloorISD = ISD::AVGFLOORS;
+    break;
+  case ISD::SRL:
+    IsUnsigned = true;
+    if (hasOperation(ISD::AVGFLOORU, VT))
+      FloorISD = ISD::AVGFLOORU;
+    break;
+  default:
+    return SDValue();
+  }
+
+  // The target does not support the matching avgfloor node; bail out.
+  if (!FloorISD)
+    return SDValue();
+
+  // Captured values.
+  SDValue A, B, Add;
+
+  // Match the floor-average form; it is common to both the floor and ceil
+  // variants (visitAVG later turns the floor node into the ceil one).
+  if (!sd_match(N, m_BinOp(Opcode,
+                           m_AllOf(m_Value(Add), m_Add(m_Value(A), m_Value(B))),
+                           m_One())))
+    return SDValue();
+
+  // Can't optimize adds that may wrap.
+  if (IsUnsigned && !Add->getFlags().hasNoUnsignedWrap())
+    return SDValue();
+
+  if (!IsUnsigned && !Add->getFlags().hasNoSignedWrap())
+    return SDValue();
+
+  return DAG.getNode(FloorISD, SDLoc(N), VT, A, B);
+}
+
 /// Generate Min/Max node
 SDValue DAGCombiner::combineMinNumMaxNum(const SDLoc &DL, EVT VT, SDValue LHS,
                                          SDValue RHS, SDValue True,
diff --git a/llvm/lib/Target/ARM/ARMISelLowering.cpp b/llvm/lib/Target/ARM/ARMISelLowering.cpp
index a98b7a8420927e..123bcac000d37a 100644
--- a/llvm/lib/Target/ARM/ARMISelLowering.cpp
+++ b/llvm/lib/Target/ARM/ARMISelLowering.cpp
@@ -7951,6 +7951,10 @@ static bool IsQRMVEInstruction(const SDNode *N, const SDNode *Op) {
   case ISD::MUL:
   case ISD::SADDSAT:
   case ISD::UADDSAT:
+  case ISD::AVGFLOORS:
+  case ISD::AVGFLOORU:
+  case ISD::AVGCEILS:
+  case ISD::AVGCEILU:
     return true;
   case ISD::SUB:
   case ISD::SSUBSAT:
diff --git a/llvm/lib/Target/ARM/ARMInstrMVE.td b/llvm/lib/Target/ARM/ARMInstrMVE.td
index 04d5d00eef10e6..8c8403ac58b080 100644
--- a/llvm/lib/Target/ARM/ARMInstrMVE.td
+++ b/llvm/lib/Target/ARM/ARMInstrMVE.td
@@ -2222,64 +2222,6 @@ defm MVE_VRHADDu8  : MVE_VRHADD<MVE_v16u8, avgceilu>;
 defm MVE_VRHADDu16 : MVE_VRHADD<MVE_v8u16, avgceilu>;
 defm MVE_VRHADDu32 : MVE_VRHADD<MVE_v4u32, avgceilu>;
 
-// Rounding Halving Add perform the arithemtic operation with an extra bit of
-// precision, before performing the shift, to void clipping errors. We're not
-// modelling that here with these patterns, but we're using no wrap forms of
-// add to ensure that the extra bit of information is not needed for the
-// arithmetic or the rounding.
-let Predicates = [HasMVEInt] in {
-  def : Pat<(v16i8 (ARMvshrsImm (addnsw (addnsw (v16i8 MQPR:$Qm), (v16i8 MQPR:$Qn)),
-                                        (v16i8 (ARMvmovImm (i32 3585)))),
-                                (i32 1))),
-            (MVE_VRHADDs8 MQPR:$Qm, MQPR:$Qn)>;
-  def : Pat<(v8i16 (ARMvshrsImm (addnsw (addnsw (v8i16 MQPR:$Qm), (v8i16 MQPR:$Qn)),
-                                        (v8i16 (ARMvmovImm (i32 2049)))),
-                                (i32 1))),
-            (MVE_VRHADDs16 MQPR:$Qm, MQPR:$Qn)>;
-  def : Pat<(v4i32 (ARMvshrsImm (addnsw (addnsw (v4i32 MQPR:$Qm), (v4i32 MQPR:$Qn)),
-                                        (v4i32 (ARMvmovImm (i32 1)))),
-                                (i32 1))),
-            (MVE_VRHADDs32 MQPR:$Qm, MQPR:$Qn)>;
-  def : Pat<(v16i8 (ARMvshruImm (addnuw (addnuw (v16i8 MQPR:$Qm), (v16i8 MQPR:$Qn)),
-                                        (v16i8 (ARMvmovImm (i32 3585)))),
-                                (i32 1))),
-            (MVE_VRHADDu8 MQPR:$Qm, MQPR:$Qn)>;
-  def : Pat<(v8i16 (ARMvshruImm (addnuw (addnuw (v8i16 MQPR:$Qm), (v8i16 MQPR:$Qn)),
-                                        (v8i16 (ARMvmovImm (i32 2049)))),
-                                (i32 1))),
-            (MVE_VRHADDu16 MQPR:$Qm, MQPR:$Qn)>;
-  def : Pat<(v4i32 (ARMvshruImm (addnuw (addnuw (v4i32 MQPR:$Qm), (v4i32 MQPR:$Qn)),
-                                        (v4i32 (ARMvmovImm (i32 1)))),
-                                (i32 1))),
-            (MVE_VRHADDu32 MQPR:$Qm, MQPR:$Qn)>;
-
-  def : Pat<(v16i8 (ARMvshrsImm (addnsw (addnsw (v16i8 MQPR:$Qm), (v16i8 MQPR:$Qn)),
-                                        (v16i8 (ARMvdup (i32 1)))),
-                                (i32 1))),
-            (MVE_VRHADDs8 MQPR:$Qm, MQPR:$Qn)>;
-  def : Pat<(v8i16 (ARMvshrsImm (addnsw (addnsw (v8i16 MQPR:$Qm), (v8i16 MQPR:$Qn)),
-                                        (v8i16 (ARMvdup (i32 1)))),
-                                (i32 1))),
-            (MVE_VRHADDs16 MQPR:$Qm, MQPR:$Qn)>;
-  def : Pat<(v4i32 (ARMvshrsImm (addnsw (addnsw (v4i32 MQPR:$Qm), (v4i32 MQPR:$Qn)),
-                                        (v4i32 (ARMvdup (i32 1)))),
-                                (i32 1))),
-            (MVE_VRHADDs32 MQPR:$Qm, MQPR:$Qn)>;
-  def : Pat<(v16i8 (ARMvshruImm (addnuw (addnuw (v16i8 MQPR:$Qm), (v16i8 MQPR:$Qn)),
-                                        (v16i8 (ARMvdup (i32 1)))),
-                                (i32 1))),
-            (MVE_VRHADDu8 MQPR:$Qm, MQPR:$Qn)>;
-  def : Pat<(v8i16 (ARMvshruImm (addnuw (addnuw (v8i16 MQPR:$Qm), (v8i16 MQPR:$Qn)),
-                                        (v8i16 (ARMvdup (i32 1)))),
-                                (i32 1))),
-            (MVE_VRHADDu16 MQPR:$Qm, MQPR:$Qn)>;
-  def : Pat<(v4i32 (ARMvshruImm (addnuw (addnuw (v4i32 MQPR:$Qm), (v4i32 MQPR:$Qn)),
-                                        (v4i32 (ARMvdup (i32 1)))),
-                                (i32 1))),
-            (MVE_VRHADDu32 MQPR:$Qm, MQPR:$Qn)>;
-}
-
-
 class MVE_VHADDSUB<string iname, string suffix, bit U, bit subtract,
                    bits<2> size, list<dag> pattern=[]>
   : MVE_int<iname, suffix, size, pattern> {
@@ -2303,8 +2245,7 @@ class MVE_VHSUB_<string suffix, bit U, bits<2> size,
   : MVE_VHADDSUB<"vhsub", suffix, U, 0b1, size, pattern>;
 
 multiclass MVE_VHADD_m<MVEVectorVTInfo VTI, SDNode Op,
-                      SDPatternOperator unpred_op, Intrinsic PredInt, PatFrag add_op,
-                      SDNode shift_op> {
+                      SDPatternOperator unpred_op, Intrinsic PredInt> {
   def "" : MVE_VHADD_<VTI.Suffix, VTI.Unsigned, VTI.Size>;
   defvar Inst = !cast<Instruction>(NAME);
   defm : MVE_TwoOpPattern<VTI, Op, PredInt, (? (i32 VTI.Unsigned)), !cast<Instruction>(NAME)>;
@@ -2313,26 +2254,18 @@ multiclass MVE_VHADD_m<MVEVectorVTInfo VTI, SDNode Op,
     // Unpredicated add-and-divide-by-two
     def : Pat<(VTI.Vec (unpred_op (VTI.Vec MQPR:$Qm), (VTI.Vec MQPR:$Qn), (i32 VTI.Unsigned))),
               (VTI.Vec (Inst (VTI.Vec MQPR:$Qm), (VTI.Vec MQPR:$Qn)))>;
-
-    def : Pat<(VTI.Vec (shift_op (add_op (VTI.Vec MQPR:$Qm), (VTI.Vec MQPR:$Qn)), (i32 1))),
-              (Inst MQPR:$Qm, MQPR:$Qn)>;
   }
 }
 
-multiclass MVE_VHADD<MVEVectorVTInfo VTI, SDNode Op, PatFrag add_op, SDNode shift_op>
-  : MVE_VHADD_m<VTI, Op, int_arm_mve_vhadd, int_arm_mve_hadd_predicated, add_op,
-                shift_op>;
+multiclass MVE_VHADD<MVEVectorVTInfo VTI, SDNode Op>
+  : MVE_VHADD_m<VTI, Op, int_arm_mve_vhadd, int_arm_mve_hadd_predicated>;
 
-// Halving add/sub perform the arithemtic operation with an extra bit of
-// precision, before performing the shift, to void clipping errors. We're not
-// modelling that here with these patterns, but we're using no wrap forms of
-// add/sub to ensure that the extra bit of information is not needed.
-defm MVE_VHADDs8  : MVE_VHADD<MVE_v16s8, avgfloors, addnsw, ARMvshrsImm>;
-defm MVE_VHADDs16 : MVE_VHADD<MVE_v8s16, avgfloors, addnsw, ARMvshrsImm>;
-defm MVE_VHADDs32 : MVE_VHADD<MVE_v4s32, avgfloors, addnsw, ARMvshrsImm>;
-defm MVE_VHADDu8  : MVE_VHADD<MVE_v16u8, avgflooru, addnuw, ARMvshruImm>;
-defm MVE_VHADDu16 : MVE_VHADD<MVE_v8u16, avgflooru, addnuw, ARMvshruImm>;
-defm MVE_VHADDu32 : MVE_VHADD<MVE_v4u32, avgflooru, addnuw, ARMvshruImm>;
+defm MVE_VHADDs8  : MVE_VHADD<MVE_v16s8, avgfloors>;
+defm MVE_VHADDs16 : MVE_VHADD<MVE_v8s16, avgfloors>;
+defm MVE_VHADDs32 : MVE_VHADD<MVE_v4s32, avgfloors>;
+defm MVE_VHADDu8  : MVE_VHADD<MVE_v16u8, avgflooru>;
+defm MVE_VHADDu16 : MVE_VHADD<MVE_v8u16, avgflooru>;
+defm MVE_VHADDu32 : MVE_VHADD<MVE_v4u32, avgflooru>;
 
 multiclass MVE_VHSUB_m<MVEVectorVTInfo VTI,
                       SDPatternOperator unpred_op, Intrinsic pred_int, PatFrag sub_op,

``````````

</details>
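
For reference, the new folds only fire when `hasOperation` reports the avg node as available for the value type. A rough sketch (not the exact ARM code) of how a target opts in, inside its `TargetLowering` constructor:

```cpp
// Sketch only: marking the average nodes Legal (or Custom) for a vector type
// is what makes DAGCombiner::hasOperation() return true, which enables the
// shift-of-add -> avg folds for that target.
for (MVT VT : {MVT::v16i8, MVT::v8i16, MVT::v4i32}) {
  setOperationAction(ISD::AVGFLOORS, VT, Legal);
  setOperationAction(ISD::AVGFLOORU, VT, Legal);
  setOperationAction(ISD::AVGCEILS, VT, Legal);
  setOperationAction(ISD::AVGCEILU, VT, Legal);
}
```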


https://github.com/llvm/llvm-project/pull/113909


More information about the llvm-commits mailing list