[llvm] [DAGCombiner] Add combine avg from shifts (PR #113909)

via llvm-commits llvm-commits at lists.llvm.org
Wed Oct 30 08:55:27 PDT 2024


https://github.com/dnsampaio updated https://github.com/llvm/llvm-project/pull/113909

>From bd5312ec825f717447f86ae7baab77e47b008ee1 Mon Sep 17 00:00:00 2001
From: Diogo Sampaio <dsampaio at kalrayinc.com>
Date: Mon, 28 Oct 2024 13:14:40 +0100
Subject: [PATCH] [DAGCombiner] Add combine avg from shifts

This teaches the DAGCombiner to fold:
`(asr (add nsw x, y), 1) -> (avgfloors x, y)`
`(lsr (add nuw x, y), 1) -> (avgflooru x, y)`

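as well as combining them into a ceil variant:
`(avgfloors (add nsw x, y), 1) -> (avgceils x, y)`
`(avgflooru (add nuw x, y), 1) -> (avgceilu x, y)`
`(avgfloors (add nsw x, 1), y) -> (avgceils x, y)`
`(avgflooru (add nuw x, 1), y) -> (avgceilu x, y)`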

when the resulting operation is supported by the target.

Removes some of the ARM MVE patterns that are now dead code.
It also adds the AVGFLOORS/AVGFLOORU opcodes to `IsQRMVEInstruction` so as to preserve the immediate splatting as before.
---
 llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp |  76 +++++++
 llvm/lib/Target/ARM/ARMISelLowering.cpp       |   2 +
 llvm/lib/Target/ARM/ARMInstrMVE.td            |  85 +-------
 llvm/test/CodeGen/AArch64/avg.ll              | 202 ++++++++++++++++++
 llvm/test/CodeGen/AArch64/sve-hadd.ll         |  40 ++++
 5 files changed, 329 insertions(+), 76 deletions(-)

diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index ceaf5d664131c3..0e8f10394e895f 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -401,6 +401,8 @@ namespace {
     SDValue PromoteExtend(SDValue Op);
     bool PromoteLoad(SDValue Op);
 
+    SDValue foldShiftToAvg(SDNode *N);
+
     SDValue combineMinNumMaxNum(const SDLoc &DL, EVT VT, SDValue LHS,
                                 SDValue RHS, SDValue True, SDValue False,
                                 ISD::CondCode CC);
@@ -5354,6 +5356,27 @@ SDValue DAGCombiner::visitAVG(SDNode *N) {
           DAG.getNode(ISD::ADD, DL, VT, N0, DAG.getAllOnesConstant(DL, VT)));
   }
 
+  // Fold avgfloor((add nw x,y), 1) -> avgceil(x,y)
+  // Fold avgfloor((add nw x,1), y) -> avgceil(x,y)
+  if ((Opcode == ISD::AVGFLOORU && hasOperation(ISD::AVGCEILU, VT)) ||
+      (Opcode == ISD::AVGFLOORS && hasOperation(ISD::AVGCEILS, VT))) {
+    SDValue Add;
+    if (sd_match(N,
+                 m_c_BinOp(Opcode,
+                           m_AllOf(m_Value(Add), m_Add(m_Value(X), m_Value(Y))),
+                           m_One())) ||
+        sd_match(N, m_c_BinOp(Opcode,
+                              m_AllOf(m_Value(Add), m_Add(m_Value(X), m_One())),
+                              m_Value(Y)))) {
+
+      if (IsSigned && Add->getFlags().hasNoSignedWrap())
+        return DAG.getNode(ISD::AVGCEILS, DL, VT, X, Y);
+
+      if (!IsSigned && Add->getFlags().hasNoUnsignedWrap())
+        return DAG.getNode(ISD::AVGCEILU, DL, VT, X, Y);
+    }
+  }
+
   return SDValue();
 }
 
@@ -10635,6 +10658,9 @@ SDValue DAGCombiner::visitSRA(SDNode *N) {
   if (SDValue NarrowLoad = reduceLoadWidth(N))
     return NarrowLoad;
 
+  if (SDValue AVG = foldShiftToAvg(N))
+    return AVG;
+
   return SDValue();
 }
 
@@ -10889,6 +10915,9 @@ SDValue DAGCombiner::visitSRL(SDNode *N) {
   if (SDValue MULH = combineShiftToMULH(N, DL, DAG, TLI))
     return MULH;
 
+  if (SDValue AVG = foldShiftToAvg(N))
+    return AVG;
+
   return SDValue();
 }
 
@@ -11402,6 +11431,53 @@ static SDValue combineMinNumMaxNumImpl(const SDLoc &DL, EVT VT, SDValue LHS,
   }
 }
 
+/// Try to turn a right shift by one of a non-wrapping add into the matching
+/// floor-average node, provided the target supports that node for VT:
+///   (sra (add nsw x, y), 1) -> (avgfloors x, y)
+///   (srl (add nuw x, y), 1) -> (avgflooru x, y)
+/// AVGFLOOR[SU] conceptually forms the sum with one extra bit of precision
+/// before halving it. The no-wrap flag on the add guarantees that the N-bit
+/// sum is already exact, so the plain shift computes the same value.
+/// Called from visitSRA and visitSRL. Returns the new node, or an empty
+/// SDValue if the fold does not apply.
+SDValue DAGCombiner::foldShiftToAvg(SDNode *N) {
+  const unsigned Opcode = N->getOpcode();
+  if (Opcode != ISD::SRA && Opcode != ISD::SRL)
+    return SDValue();
+
+  // The shift kind decides the signedness: a logical shift pairs with the
+  // unsigned average and NUW, an arithmetic shift with the signed average
+  // and NSW.
+  const bool IsUnsigned = Opcode == ISD::SRL;
+  const unsigned FloorISD = IsUnsigned ? ISD::AVGFLOORU : ISD::AVGFLOORS;
+  const EVT VT = N->getValueType(0);
+
+  // Only fold when the target supports the average node for this type.
+  if (!hasOperation(FloorISD, VT))
+    return SDValue();
+
+  // Match (sr[al] (add x, y), 1) and capture the add node itself so that its
+  // wrap flags can be inspected below. For vector types the shift amount is
+  // a splat of one, which m_One also accepts.
+  SDValue A, B, Add;
+  if (!sd_match(N, m_BinOp(Opcode,
+                           m_AllOf(m_Value(Add), m_Add(m_Value(A), m_Value(B))),
+                           m_One())))
+    return SDValue();
+
+  // Can't optimize adds that may wrap: a wrapping add drops the carry out of
+  // the N-bit sum, whereas AVGFLOOR[SU] keeps that extra bit before halving,
+  // so the two would disagree whenever the add overflows.
+  const SDNodeFlags Flags = Add->getFlags();
+  const bool NoWrap =
+      IsUnsigned ? Flags.hasNoUnsignedWrap() : Flags.hasNoSignedWrap();
+  if (!NoWrap)
+    return SDValue();
+
+  // The average has the same type as the shift it replaces.
+  return DAG.getNode(FloorISD, SDLoc(N), VT, A, B);
+}
+
 /// Generate Min/Max node
 SDValue DAGCombiner::combineMinNumMaxNum(const SDLoc &DL, EVT VT, SDValue LHS,
                                          SDValue RHS, SDValue True,
diff --git a/llvm/lib/Target/ARM/ARMISelLowering.cpp b/llvm/lib/Target/ARM/ARMISelLowering.cpp
index a98b7a8420927e..e08c0c0653eba2 100644
--- a/llvm/lib/Target/ARM/ARMISelLowering.cpp
+++ b/llvm/lib/Target/ARM/ARMISelLowering.cpp
@@ -7951,6 +7951,8 @@ static bool IsQRMVEInstruction(const SDNode *N, const SDNode *Op) {
   case ISD::MUL:
   case ISD::SADDSAT:
   case ISD::UADDSAT:
+  case ISD::AVGFLOORS:
+  case ISD::AVGFLOORU:
     return true;
   case ISD::SUB:
   case ISD::SSUBSAT:
diff --git a/llvm/lib/Target/ARM/ARMInstrMVE.td b/llvm/lib/Target/ARM/ARMInstrMVE.td
index 04d5d00eef10e6..8c8403ac58b080 100644
--- a/llvm/lib/Target/ARM/ARMInstrMVE.td
+++ b/llvm/lib/Target/ARM/ARMInstrMVE.td
@@ -2222,64 +2222,6 @@ defm MVE_VRHADDu8  : MVE_VRHADD<MVE_v16u8, avgceilu>;
 defm MVE_VRHADDu16 : MVE_VRHADD<MVE_v8u16, avgceilu>;
 defm MVE_VRHADDu32 : MVE_VRHADD<MVE_v4u32, avgceilu>;
 
-// Rounding Halving Add perform the arithemtic operation with an extra bit of
-// precision, before performing the shift, to void clipping errors. We're not
-// modelling that here with these patterns, but we're using no wrap forms of
-// add to ensure that the extra bit of information is not needed for the
-// arithmetic or the rounding.
-let Predicates = [HasMVEInt] in {
-  def : Pat<(v16i8 (ARMvshrsImm (addnsw (addnsw (v16i8 MQPR:$Qm), (v16i8 MQPR:$Qn)),
-                                        (v16i8 (ARMvmovImm (i32 3585)))),
-                                (i32 1))),
-            (MVE_VRHADDs8 MQPR:$Qm, MQPR:$Qn)>;
-  def : Pat<(v8i16 (ARMvshrsImm (addnsw (addnsw (v8i16 MQPR:$Qm), (v8i16 MQPR:$Qn)),
-                                        (v8i16 (ARMvmovImm (i32 2049)))),
-                                (i32 1))),
-            (MVE_VRHADDs16 MQPR:$Qm, MQPR:$Qn)>;
-  def : Pat<(v4i32 (ARMvshrsImm (addnsw (addnsw (v4i32 MQPR:$Qm), (v4i32 MQPR:$Qn)),
-                                        (v4i32 (ARMvmovImm (i32 1)))),
-                                (i32 1))),
-            (MVE_VRHADDs32 MQPR:$Qm, MQPR:$Qn)>;
-  def : Pat<(v16i8 (ARMvshruImm (addnuw (addnuw (v16i8 MQPR:$Qm), (v16i8 MQPR:$Qn)),
-                                        (v16i8 (ARMvmovImm (i32 3585)))),
-                                (i32 1))),
-            (MVE_VRHADDu8 MQPR:$Qm, MQPR:$Qn)>;
-  def : Pat<(v8i16 (ARMvshruImm (addnuw (addnuw (v8i16 MQPR:$Qm), (v8i16 MQPR:$Qn)),
-                                        (v8i16 (ARMvmovImm (i32 2049)))),
-                                (i32 1))),
-            (MVE_VRHADDu16 MQPR:$Qm, MQPR:$Qn)>;
-  def : Pat<(v4i32 (ARMvshruImm (addnuw (addnuw (v4i32 MQPR:$Qm), (v4i32 MQPR:$Qn)),
-                                        (v4i32 (ARMvmovImm (i32 1)))),
-                                (i32 1))),
-            (MVE_VRHADDu32 MQPR:$Qm, MQPR:$Qn)>;
-
-  def : Pat<(v16i8 (ARMvshrsImm (addnsw (addnsw (v16i8 MQPR:$Qm), (v16i8 MQPR:$Qn)),
-                                        (v16i8 (ARMvdup (i32 1)))),
-                                (i32 1))),
-            (MVE_VRHADDs8 MQPR:$Qm, MQPR:$Qn)>;
-  def : Pat<(v8i16 (ARMvshrsImm (addnsw (addnsw (v8i16 MQPR:$Qm), (v8i16 MQPR:$Qn)),
-                                        (v8i16 (ARMvdup (i32 1)))),
-                                (i32 1))),
-            (MVE_VRHADDs16 MQPR:$Qm, MQPR:$Qn)>;
-  def : Pat<(v4i32 (ARMvshrsImm (addnsw (addnsw (v4i32 MQPR:$Qm), (v4i32 MQPR:$Qn)),
-                                        (v4i32 (ARMvdup (i32 1)))),
-                                (i32 1))),
-            (MVE_VRHADDs32 MQPR:$Qm, MQPR:$Qn)>;
-  def : Pat<(v16i8 (ARMvshruImm (addnuw (addnuw (v16i8 MQPR:$Qm), (v16i8 MQPR:$Qn)),
-                                        (v16i8 (ARMvdup (i32 1)))),
-                                (i32 1))),
-            (MVE_VRHADDu8 MQPR:$Qm, MQPR:$Qn)>;
-  def : Pat<(v8i16 (ARMvshruImm (addnuw (addnuw (v8i16 MQPR:$Qm), (v8i16 MQPR:$Qn)),
-                                        (v8i16 (ARMvdup (i32 1)))),
-                                (i32 1))),
-            (MVE_VRHADDu16 MQPR:$Qm, MQPR:$Qn)>;
-  def : Pat<(v4i32 (ARMvshruImm (addnuw (addnuw (v4i32 MQPR:$Qm), (v4i32 MQPR:$Qn)),
-                                        (v4i32 (ARMvdup (i32 1)))),
-                                (i32 1))),
-            (MVE_VRHADDu32 MQPR:$Qm, MQPR:$Qn)>;
-}
-
-
 class MVE_VHADDSUB<string iname, string suffix, bit U, bit subtract,
                    bits<2> size, list<dag> pattern=[]>
   : MVE_int<iname, suffix, size, pattern> {
@@ -2303,8 +2245,7 @@ class MVE_VHSUB_<string suffix, bit U, bits<2> size,
   : MVE_VHADDSUB<"vhsub", suffix, U, 0b1, size, pattern>;
 
 multiclass MVE_VHADD_m<MVEVectorVTInfo VTI, SDNode Op,
-                      SDPatternOperator unpred_op, Intrinsic PredInt, PatFrag add_op,
-                      SDNode shift_op> {
+                      SDPatternOperator unpred_op, Intrinsic PredInt> {
   def "" : MVE_VHADD_<VTI.Suffix, VTI.Unsigned, VTI.Size>;
   defvar Inst = !cast<Instruction>(NAME);
   defm : MVE_TwoOpPattern<VTI, Op, PredInt, (? (i32 VTI.Unsigned)), !cast<Instruction>(NAME)>;
@@ -2313,26 +2254,18 @@ multiclass MVE_VHADD_m<MVEVectorVTInfo VTI, SDNode Op,
     // Unpredicated add-and-divide-by-two
     def : Pat<(VTI.Vec (unpred_op (VTI.Vec MQPR:$Qm), (VTI.Vec MQPR:$Qn), (i32 VTI.Unsigned))),
               (VTI.Vec (Inst (VTI.Vec MQPR:$Qm), (VTI.Vec MQPR:$Qn)))>;
-
-    def : Pat<(VTI.Vec (shift_op (add_op (VTI.Vec MQPR:$Qm), (VTI.Vec MQPR:$Qn)), (i32 1))),
-              (Inst MQPR:$Qm, MQPR:$Qn)>;
   }
 }
 
-multiclass MVE_VHADD<MVEVectorVTInfo VTI, SDNode Op, PatFrag add_op, SDNode shift_op>
-  : MVE_VHADD_m<VTI, Op, int_arm_mve_vhadd, int_arm_mve_hadd_predicated, add_op,
-                shift_op>;
+multiclass MVE_VHADD<MVEVectorVTInfo VTI, SDNode Op>
+  : MVE_VHADD_m<VTI, Op, int_arm_mve_vhadd, int_arm_mve_hadd_predicated>;
 
-// Halving add/sub perform the arithemtic operation with an extra bit of
-// precision, before performing the shift, to void clipping errors. We're not
-// modelling that here with these patterns, but we're using no wrap forms of
-// add/sub to ensure that the extra bit of information is not needed.
-defm MVE_VHADDs8  : MVE_VHADD<MVE_v16s8, avgfloors, addnsw, ARMvshrsImm>;
-defm MVE_VHADDs16 : MVE_VHADD<MVE_v8s16, avgfloors, addnsw, ARMvshrsImm>;
-defm MVE_VHADDs32 : MVE_VHADD<MVE_v4s32, avgfloors, addnsw, ARMvshrsImm>;
-defm MVE_VHADDu8  : MVE_VHADD<MVE_v16u8, avgflooru, addnuw, ARMvshruImm>;
-defm MVE_VHADDu16 : MVE_VHADD<MVE_v8u16, avgflooru, addnuw, ARMvshruImm>;
-defm MVE_VHADDu32 : MVE_VHADD<MVE_v4u32, avgflooru, addnuw, ARMvshruImm>;
+defm MVE_VHADDs8  : MVE_VHADD<MVE_v16s8, avgfloors>;
+defm MVE_VHADDs16 : MVE_VHADD<MVE_v8s16, avgfloors>;
+defm MVE_VHADDs32 : MVE_VHADD<MVE_v4s32, avgfloors>;
+defm MVE_VHADDu8  : MVE_VHADD<MVE_v16u8, avgflooru>;
+defm MVE_VHADDu16 : MVE_VHADD<MVE_v8u16, avgflooru>;
+defm MVE_VHADDu32 : MVE_VHADD<MVE_v4u32, avgflooru>;
 
 multiclass MVE_VHSUB_m<MVEVectorVTInfo VTI,
                       SDPatternOperator unpred_op, Intrinsic pred_int, PatFrag sub_op,
diff --git a/llvm/test/CodeGen/AArch64/avg.ll b/llvm/test/CodeGen/AArch64/avg.ll
index cabc0d346b806f..ea07b10c22c2e7 100644
--- a/llvm/test/CodeGen/AArch64/avg.ll
+++ b/llvm/test/CodeGen/AArch64/avg.ll
@@ -146,3 +146,205 @@ define <16 x i16> @sext_avgceils_mismatch(<16 x i4> %a0, <16 x i8> %a1) {
   %avg = sub <16 x i16> %or, %shift
   ret <16 x i16> %avg
 }
+
+define <8 x i16> @add_avgflooru(<8 x i16> %a0, <8 x i16> %a1) {
+; CHECK-LABEL: add_avgflooru:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    uhadd v0.8h, v0.8h, v1.8h
+; CHECK-NEXT:    ret
+  %add = add nuw <8 x i16> %a0, %a1
+  %avg = lshr <8 x i16> %add, splat(i16 1)
+  ret <8 x i16> %avg
+}
+
+define <8 x i16> @add_avgflooru_mismatch(<8 x i16> %a0, <8 x i16> %a1) {
+; CHECK-LABEL: add_avgflooru_mismatch:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    add v0.8h, v0.8h, v1.8h
+; CHECK-NEXT:    ushr v0.8h, v0.8h, #1
+; CHECK-NEXT:    ret
+  %add = add <8 x i16> %a0, %a1
+  %avg = lshr <8 x i16> %add, splat(i16 1)
+  ret <8 x i16> %avg
+}
+
+define <8 x i16> @add_avgceilu(<8 x i16> %a0, <8 x i16> %a1) {
+; CHECK-LABEL: add_avgceilu:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    urhadd v0.8h, v0.8h, v1.8h
+; CHECK-NEXT:    ret
+  %add0 = add nuw <8 x i16> %a0, splat(i16 1)
+  %add = add nuw <8 x i16> %a1, %add0
+  %avg = lshr <8 x i16> %add, splat(i16 1)
+  ret <8 x i16> %avg
+}
+
+define <8 x i16> @add_avgceilu2(<8 x i16> %a0, <8 x i16> %a1) {
+; CHECK-LABEL: add_avgceilu2:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    urhadd v0.8h, v1.8h, v0.8h
+; CHECK-NEXT:    ret
+  %add0 = add nuw <8 x i16> %a1, %a0
+  %add = add nuw <8 x i16> %add0, splat(i16 1)
+  %avg = lshr <8 x i16> %add, splat(i16 1)
+  ret <8 x i16> %avg
+}
+
+define <8 x i16> @add_avgceilu_mismatch1(<8 x i16> %a0, <8 x i16> %a1) {
+; CHECK-LABEL: add_avgceilu_mismatch1:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    movi v2.8h, #1
+; CHECK-NEXT:    add v0.8h, v1.8h, v0.8h
+; CHECK-NEXT:    uhadd v0.8h, v0.8h, v2.8h
+; CHECK-NEXT:    ret
+  %add0 = add <8 x i16> %a1, %a0
+  %add = add nuw <8 x i16> %add0, splat(i16 1)
+  %avg = lshr <8 x i16> %add, splat(i16 1)
+  ret <8 x i16> %avg
+}
+
+define <8 x i16> @add_avgceilu_mismatch2(<8 x i16> %a0, <8 x i16> %a1) {
+; CHECK-LABEL: add_avgceilu_mismatch2:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mvn v1.16b, v1.16b
+; CHECK-NEXT:    sub v0.8h, v0.8h, v1.8h
+; CHECK-NEXT:    ushr v0.8h, v0.8h, #1
+; CHECK-NEXT:    ret
+  %add0 = add nuw <8 x i16> %a1, %a0
+  %add = add <8 x i16> %add0, splat(i16 1)
+  %avg = lshr <8 x i16> %add, splat(i16 1)
+  ret <8 x i16> %avg
+}
+
+define <8 x i16> @add_avgceilu_mismatch3(<8 x i16> %a0, <8 x i16> %a1) {
+; CHECK-LABEL: add_avgceilu_mismatch3:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mvn v1.16b, v1.16b
+; CHECK-NEXT:    sub v0.8h, v0.8h, v1.8h
+; CHECK-NEXT:    ushr v0.8h, v0.8h, #1
+; CHECK-NEXT:    ret
+  %add0 = add nuw <8 x i16> %a1, %a0
+  %add = add <8 x i16> %add0, splat(i16 1)
+  %avg = lshr <8 x i16> %add, splat(i16 1)
+  ret <8 x i16> %avg
+}
+
+define <8 x i16> @add_avgfloors(<8 x i16> %a0, <8 x i16> %a1) {
+; CHECK-LABEL: add_avgfloors:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    shadd v0.8h, v0.8h, v1.8h
+; CHECK-NEXT:    ret
+  %add = add nsw <8 x i16> %a0, %a1
+  %avg = ashr <8 x i16> %add, splat(i16 1)
+  ret <8 x i16> %avg
+}
+
+define <8 x i16> @add_avgfloors_mismatch(<8 x i16> %a0, <8 x i16> %a1) {
+; CHECK-LABEL: add_avgfloors_mismatch:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    add v0.8h, v0.8h, v1.8h
+; CHECK-NEXT:    sshr v0.8h, v0.8h, #1
+; CHECK-NEXT:    ret
+  %add = add <8 x i16> %a0, %a1
+  %avg = ashr <8 x i16> %add, splat(i16 1)
+  ret <8 x i16> %avg
+}
+
+define <8 x i16> @add_avgfloors_mismatch2(<8 x i16> %a0, <8 x i16> %a1) {
+; CHECK-LABEL: add_avgfloors_mismatch2:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    add v0.8h, v0.8h, v1.8h
+; CHECK-NEXT:    sshr v0.8h, v0.8h, #2
+; CHECK-NEXT:    ret
+  %add = add nsw <8 x i16> %a0, %a1
+  %avg = ashr <8 x i16> %add, splat(i16 2) ; shift by 2 is not an average
+  ret <8 x i16> %avg
+}
+
+define <8 x i16> @add_avgceils(<8 x i16> %a0, <8 x i16> %a1) {
+; CHECK-LABEL: add_avgceils:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    srhadd v0.8h, v0.8h, v1.8h
+; CHECK-NEXT:    ret
+  %add0 = add nsw <8 x i16> %a0, splat(i16 1)
+  %add = add nsw <8 x i16> %a1, %add0
+  %avg = ashr <8 x i16> %add, splat(i16 1)
+  ret <8 x i16> %avg
+}
+
+define <8 x i16> @add_avgceils2(<8 x i16> %a0, <8 x i16> %a1) {
+; CHECK-LABEL: add_avgceils2:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    srhadd v0.8h, v1.8h, v0.8h
+; CHECK-NEXT:    ret
+  %add0 = add nsw <8 x i16> %a1, %a0
+  %add = add nsw <8 x i16> %add0, splat(i16 1)
+  %avg = ashr <8 x i16> %add, splat(i16 1)
+  ret <8 x i16> %avg
+}
+
+define <8 x i16> @add_avgceils_mismatch1(<8 x i16> %a0, <8 x i16> %a1) {
+; CHECK-LABEL: add_avgceils_mismatch1:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    movi v2.8h, #1
+; CHECK-NEXT:    add v0.8h, v1.8h, v0.8h
+; CHECK-NEXT:    shadd v0.8h, v0.8h, v2.8h
+; CHECK-NEXT:    ret
+  %add0 = add <8 x i16> %a1, %a0
+  %add = add nsw <8 x i16> %add0, splat(i16 1)
+  %avg = ashr <8 x i16> %add, splat(i16 1)
+  ret <8 x i16> %avg
+}
+
+define <8 x i16> @add_avgceils_mismatch2(<8 x i16> %a0, <8 x i16> %a1) {
+; CHECK-LABEL: add_avgceils_mismatch2:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mvn v1.16b, v1.16b
+; CHECK-NEXT:    sub v0.8h, v0.8h, v1.8h
+; CHECK-NEXT:    sshr v0.8h, v0.8h, #1
+; CHECK-NEXT:    ret
+  %add0 = add nsw <8 x i16> %a1, %a0
+  %add = add <8 x i16> %add0, splat(i16 1)
+  %avg = ashr <8 x i16> %add, splat(i16 1)
+  ret <8 x i16> %avg
+}
+
+define <8 x i16> @add_avgceils_mismatch3(<8 x i16> %a0, <8 x i16> %a1) {
+; CHECK-LABEL: add_avgceils_mismatch3:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mvn v1.16b, v1.16b
+; CHECK-NEXT:    sub v0.8h, v0.8h, v1.8h
+; CHECK-NEXT:    sshr v0.8h, v0.8h, #1
+; CHECK-NEXT:    ret
+  %add0 = add nsw <8 x i16> %a1, %a0
+  %add = add <8 x i16> %add0, splat(i16 1)
+  %avg = ashr <8 x i16> %add, splat(i16 1)
+  ret <8 x i16> %avg
+}
+
+define <8 x i16> @add_avgceils_mismatch4(<8 x i16> %a0, <8 x i16> %a1) {
+; CHECK-LABEL: add_avgceils_mismatch4:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mvn v0.16b, v0.16b
+; CHECK-NEXT:    sub v0.8h, v1.8h, v0.8h
+; CHECK-NEXT:    sshr v0.8h, v0.8h, #2
+; CHECK-NEXT:    ret
+  %add0 = add nsw <8 x i16> %a0, splat(i16 1)
+  %add = add nsw <8 x i16> %a1, %add0
+  %avg = ashr <8 x i16> %add, splat(i16 2)
+  ret <8 x i16> %avg
+}
+
+define <8 x i16> @add_avgceilu_mismatch(<8 x i16> %a0, <8 x i16> %a1) {
+; CHECK-LABEL: add_avgceilu_mismatch:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    movi v2.8h, #1
+; CHECK-NEXT:    add v0.8h, v1.8h, v0.8h
+; CHECK-NEXT:    add v0.8h, v0.8h, v2.8h
+; CHECK-NEXT:    ushr v0.8h, v0.8h, #2
+; CHECK-NEXT:    ret
+  %add0 = add nuw <8 x i16> %a1, %a0
+  %add = add nuw <8 x i16> %add0, splat(i16 1)
+  %avg = lshr <8 x i16> %add, splat(i16 2)
+  ret <8 x i16> %avg
+}
diff --git a/llvm/test/CodeGen/AArch64/sve-hadd.ll b/llvm/test/CodeGen/AArch64/sve-hadd.ll
index 6017e13ce00352..ce440d3095d3f3 100644
--- a/llvm/test/CodeGen/AArch64/sve-hadd.ll
+++ b/llvm/test/CodeGen/AArch64/sve-hadd.ll
@@ -1301,3 +1301,43 @@ entry:
   %result = trunc <vscale x 16 x i16> %s to <vscale x 16 x i8>
   ret <vscale x 16 x i8> %result
 }
+
+define <vscale x 2 x i64> @haddu_v2i64_add(<vscale x 2 x i64> %s0, <vscale x 2 x i64> %s1) {
+; SVE-LABEL: haddu_v2i64_add:
+; SVE:       // %bb.0: // %entry
+; SVE-NEXT:    eor z2.d, z0.d, z1.d
+; SVE-NEXT:    and z0.d, z0.d, z1.d
+; SVE-NEXT:    lsr z1.d, z2.d, #1
+; SVE-NEXT:    add z0.d, z0.d, z1.d
+; SVE-NEXT:    ret
+;
+; SVE2-LABEL: haddu_v2i64_add:
+; SVE2:       // %bb.0: // %entry
+; SVE2-NEXT:    ptrue p0.d
+; SVE2-NEXT:    uhadd z0.d, p0/m, z0.d, z1.d
+; SVE2-NEXT:    ret
+entry:
+  %add = add nuw nsw <vscale x 2 x i64> %s0, %s1
+  %avg = lshr <vscale x 2 x i64> %add, splat (i64 1)
+  ret <vscale x 2 x i64> %avg
+}
+
+define <vscale x 2 x i64> @hadds_v2i64_add(<vscale x 2 x i64> %s0, <vscale x 2 x i64> %s1) {
+; SVE-LABEL: hadds_v2i64_add:
+; SVE:       // %bb.0: // %entry
+; SVE-NEXT:    eor z2.d, z0.d, z1.d
+; SVE-NEXT:    and z0.d, z0.d, z1.d
+; SVE-NEXT:    asr z1.d, z2.d, #1
+; SVE-NEXT:    add z0.d, z0.d, z1.d
+; SVE-NEXT:    ret
+;
+; SVE2-LABEL: hadds_v2i64_add:
+; SVE2:       // %bb.0: // %entry
+; SVE2-NEXT:    ptrue p0.d
+; SVE2-NEXT:    shadd z0.d, p0/m, z0.d, z1.d
+; SVE2-NEXT:    ret
+entry:
+  %add = add nuw nsw <vscale x 2 x i64> %s0, %s1
+  %avg = ashr <vscale x 2 x i64> %add, splat (i64 1)
+  ret <vscale x 2 x i64> %avg
+}



More information about the llvm-commits mailing list