[llvm] [DAGCombiner] Add combine avg from shifts (PR #113909)
via llvm-commits
llvm-commits at lists.llvm.org
Wed Oct 30 08:55:27 PDT 2024
https://github.com/dnsampaio updated https://github.com/llvm/llvm-project/pull/113909
>From bd5312ec825f717447f86ae7baab77e47b008ee1 Mon Sep 17 00:00:00 2001
From: Diogo Sampaio <dsampaio at kalrayinc.com>
Date: Mon, 28 Oct 2024 13:14:40 +0100
Subject: [PATCH] [DAGCombiner] Add combine avg from shifts
This teaches dagcombiner to fold:
`(asr (add nsw x, y), 1) -> (avgfloors x, y)`
`(lsr (add nuw x, y), 1) -> (avgflooru x, y)`
as well as combining them into a ceil variant:
`(avgfloors (add nsw x, y), 1) -> (avgceils x, y)`
`(avgflooru (add nuw x, y), 1) -> (avgceilu x, y)`
iff valid for the target.
Removes some of the ARM MVE patterns that are now dead code.
It also adds the avgfloor opcodes to `IsQRMVEInstruction` so as to preserve the immediate splatting as before.
---
llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 76 +++++++
llvm/lib/Target/ARM/ARMISelLowering.cpp | 2 +
llvm/lib/Target/ARM/ARMInstrMVE.td | 85 +-------
llvm/test/CodeGen/AArch64/avg.ll | 202 ++++++++++++++++++
llvm/test/CodeGen/AArch64/sve-hadd.ll | 40 ++++
5 files changed, 329 insertions(+), 76 deletions(-)
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index ceaf5d664131c3..0e8f10394e895f 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -401,6 +401,8 @@ namespace {
SDValue PromoteExtend(SDValue Op);
bool PromoteLoad(SDValue Op);
+ SDValue foldShiftToAvg(SDNode *N);
+
SDValue combineMinNumMaxNum(const SDLoc &DL, EVT VT, SDValue LHS,
SDValue RHS, SDValue True, SDValue False,
ISD::CondCode CC);
@@ -5354,6 +5356,27 @@ SDValue DAGCombiner::visitAVG(SDNode *N) {
DAG.getNode(ISD::ADD, DL, VT, N0, DAG.getAllOnesConstant(DL, VT)));
}
+ // Fold avgfloor((add nw x,y), 1) -> avgceil(x,y)
+ // Fold avgfloor((add nw x,1), y) -> avgceil(x,y)
+ if ((Opcode == ISD::AVGFLOORU && hasOperation(ISD::AVGCEILU, VT)) ||
+ (Opcode == ISD::AVGFLOORS && hasOperation(ISD::AVGCEILS, VT))) {
+ SDValue Add;
+ if (sd_match(N,
+ m_c_BinOp(Opcode,
+ m_AllOf(m_Value(Add), m_Add(m_Value(X), m_Value(Y))),
+ m_One())) ||
+ sd_match(N, m_c_BinOp(Opcode,
+ m_AllOf(m_Value(Add), m_Add(m_Value(X), m_One())),
+ m_Value(Y)))) {
+
+ if (IsSigned && Add->getFlags().hasNoSignedWrap())
+ return DAG.getNode(ISD::AVGCEILS, DL, VT, X, Y);
+
+ if (!IsSigned && Add->getFlags().hasNoUnsignedWrap())
+ return DAG.getNode(ISD::AVGCEILU, DL, VT, X, Y);
+ }
+ }
+
return SDValue();
}
@@ -10635,6 +10658,9 @@ SDValue DAGCombiner::visitSRA(SDNode *N) {
if (SDValue NarrowLoad = reduceLoadWidth(N))
return NarrowLoad;
+ if (SDValue AVG = foldShiftToAvg(N))
+ return AVG;
+
return SDValue();
}
@@ -10889,6 +10915,9 @@ SDValue DAGCombiner::visitSRL(SDNode *N) {
if (SDValue MULH = combineShiftToMULH(N, DL, DAG, TLI))
return MULH;
+ if (SDValue AVG = foldShiftToAvg(N))
+ return AVG;
+
return SDValue();
}
@@ -11402,6 +11431,53 @@ static SDValue combineMinNumMaxNumImpl(const SDLoc &DL, EVT VT, SDValue LHS,
}
}
+SDValue DAGCombiner::foldShiftToAvg(SDNode *N) {
+ const unsigned Opcode = N->getOpcode();
+ // Fold a right shift by one of a no-wrap add into a floor average node.
+ // Convert (sr[al] (add n[su]w x, y), 1) -> (avgfloor[su] x, y)
+ if (Opcode != ISD::SRA && Opcode != ISD::SRL)
+ return SDValue();
+
+ unsigned FloorISD = 0;
+ auto VT = N->getValueType(0);
+ bool IsUnsigned = false;
+
+ // Decide whether signed or unsigned; give up if hasOperation() rejects VT.
+ switch (Opcode) {
+ case ISD::SRA:
+ if (!hasOperation(ISD::AVGFLOORS, VT))
+ return SDValue();
+ FloorISD = ISD::AVGFLOORS;
+ break;
+ case ISD::SRL:
+ IsUnsigned = true;
+ if (!hasOperation(ISD::AVGFLOORU, VT))
+ return SDValue();
+ FloorISD = ISD::AVGFLOORU;
+ break;
+ default:
+ return SDValue();
+ }
+
+ // Captured values: the add node itself and its two operands.
+ SDValue A, B, Add;
+
+ // Match (Opcode (add A, B), 1); Add is bound so its flags can be checked.
+ if (!sd_match(N, m_BinOp(Opcode,
+ m_AllOf(m_Value(Add), m_Add(m_Value(A), m_Value(B))),
+ m_One())))
+ return SDValue();
+
+ // Can't optimize adds that may wrap: SRL needs nuw, SRA needs nsw.
+ if (IsUnsigned && !Add->getFlags().hasNoUnsignedWrap())
+ return SDValue();
+
+ if (!IsUnsigned && !Add->getFlags().hasNoSignedWrap())
+ return SDValue();
+
+ return DAG.getNode(FloorISD, SDLoc(N), N->getValueType(0), {A, B});
+}
+
/// Generate Min/Max node
SDValue DAGCombiner::combineMinNumMaxNum(const SDLoc &DL, EVT VT, SDValue LHS,
SDValue RHS, SDValue True,
diff --git a/llvm/lib/Target/ARM/ARMISelLowering.cpp b/llvm/lib/Target/ARM/ARMISelLowering.cpp
index a98b7a8420927e..e08c0c0653eba2 100644
--- a/llvm/lib/Target/ARM/ARMISelLowering.cpp
+++ b/llvm/lib/Target/ARM/ARMISelLowering.cpp
@@ -7951,6 +7951,8 @@ static bool IsQRMVEInstruction(const SDNode *N, const SDNode *Op) {
case ISD::MUL:
case ISD::SADDSAT:
case ISD::UADDSAT:
+ case ISD::AVGFLOORS:
+ case ISD::AVGFLOORU:
return true;
case ISD::SUB:
case ISD::SSUBSAT:
diff --git a/llvm/lib/Target/ARM/ARMInstrMVE.td b/llvm/lib/Target/ARM/ARMInstrMVE.td
index 04d5d00eef10e6..8c8403ac58b080 100644
--- a/llvm/lib/Target/ARM/ARMInstrMVE.td
+++ b/llvm/lib/Target/ARM/ARMInstrMVE.td
@@ -2222,64 +2222,6 @@ defm MVE_VRHADDu8 : MVE_VRHADD<MVE_v16u8, avgceilu>;
defm MVE_VRHADDu16 : MVE_VRHADD<MVE_v8u16, avgceilu>;
defm MVE_VRHADDu32 : MVE_VRHADD<MVE_v4u32, avgceilu>;
-// Rounding Halving Add perform the arithemtic operation with an extra bit of
-// precision, before performing the shift, to void clipping errors. We're not
-// modelling that here with these patterns, but we're using no wrap forms of
-// add to ensure that the extra bit of information is not needed for the
-// arithmetic or the rounding.
-let Predicates = [HasMVEInt] in {
- def : Pat<(v16i8 (ARMvshrsImm (addnsw (addnsw (v16i8 MQPR:$Qm), (v16i8 MQPR:$Qn)),
- (v16i8 (ARMvmovImm (i32 3585)))),
- (i32 1))),
- (MVE_VRHADDs8 MQPR:$Qm, MQPR:$Qn)>;
- def : Pat<(v8i16 (ARMvshrsImm (addnsw (addnsw (v8i16 MQPR:$Qm), (v8i16 MQPR:$Qn)),
- (v8i16 (ARMvmovImm (i32 2049)))),
- (i32 1))),
- (MVE_VRHADDs16 MQPR:$Qm, MQPR:$Qn)>;
- def : Pat<(v4i32 (ARMvshrsImm (addnsw (addnsw (v4i32 MQPR:$Qm), (v4i32 MQPR:$Qn)),
- (v4i32 (ARMvmovImm (i32 1)))),
- (i32 1))),
- (MVE_VRHADDs32 MQPR:$Qm, MQPR:$Qn)>;
- def : Pat<(v16i8 (ARMvshruImm (addnuw (addnuw (v16i8 MQPR:$Qm), (v16i8 MQPR:$Qn)),
- (v16i8 (ARMvmovImm (i32 3585)))),
- (i32 1))),
- (MVE_VRHADDu8 MQPR:$Qm, MQPR:$Qn)>;
- def : Pat<(v8i16 (ARMvshruImm (addnuw (addnuw (v8i16 MQPR:$Qm), (v8i16 MQPR:$Qn)),
- (v8i16 (ARMvmovImm (i32 2049)))),
- (i32 1))),
- (MVE_VRHADDu16 MQPR:$Qm, MQPR:$Qn)>;
- def : Pat<(v4i32 (ARMvshruImm (addnuw (addnuw (v4i32 MQPR:$Qm), (v4i32 MQPR:$Qn)),
- (v4i32 (ARMvmovImm (i32 1)))),
- (i32 1))),
- (MVE_VRHADDu32 MQPR:$Qm, MQPR:$Qn)>;
-
- def : Pat<(v16i8 (ARMvshrsImm (addnsw (addnsw (v16i8 MQPR:$Qm), (v16i8 MQPR:$Qn)),
- (v16i8 (ARMvdup (i32 1)))),
- (i32 1))),
- (MVE_VRHADDs8 MQPR:$Qm, MQPR:$Qn)>;
- def : Pat<(v8i16 (ARMvshrsImm (addnsw (addnsw (v8i16 MQPR:$Qm), (v8i16 MQPR:$Qn)),
- (v8i16 (ARMvdup (i32 1)))),
- (i32 1))),
- (MVE_VRHADDs16 MQPR:$Qm, MQPR:$Qn)>;
- def : Pat<(v4i32 (ARMvshrsImm (addnsw (addnsw (v4i32 MQPR:$Qm), (v4i32 MQPR:$Qn)),
- (v4i32 (ARMvdup (i32 1)))),
- (i32 1))),
- (MVE_VRHADDs32 MQPR:$Qm, MQPR:$Qn)>;
- def : Pat<(v16i8 (ARMvshruImm (addnuw (addnuw (v16i8 MQPR:$Qm), (v16i8 MQPR:$Qn)),
- (v16i8 (ARMvdup (i32 1)))),
- (i32 1))),
- (MVE_VRHADDu8 MQPR:$Qm, MQPR:$Qn)>;
- def : Pat<(v8i16 (ARMvshruImm (addnuw (addnuw (v8i16 MQPR:$Qm), (v8i16 MQPR:$Qn)),
- (v8i16 (ARMvdup (i32 1)))),
- (i32 1))),
- (MVE_VRHADDu16 MQPR:$Qm, MQPR:$Qn)>;
- def : Pat<(v4i32 (ARMvshruImm (addnuw (addnuw (v4i32 MQPR:$Qm), (v4i32 MQPR:$Qn)),
- (v4i32 (ARMvdup (i32 1)))),
- (i32 1))),
- (MVE_VRHADDu32 MQPR:$Qm, MQPR:$Qn)>;
-}
-
-
class MVE_VHADDSUB<string iname, string suffix, bit U, bit subtract,
bits<2> size, list<dag> pattern=[]>
: MVE_int<iname, suffix, size, pattern> {
@@ -2303,8 +2245,7 @@ class MVE_VHSUB_<string suffix, bit U, bits<2> size,
: MVE_VHADDSUB<"vhsub", suffix, U, 0b1, size, pattern>;
multiclass MVE_VHADD_m<MVEVectorVTInfo VTI, SDNode Op,
- SDPatternOperator unpred_op, Intrinsic PredInt, PatFrag add_op,
- SDNode shift_op> {
+ SDPatternOperator unpred_op, Intrinsic PredInt> {
def "" : MVE_VHADD_<VTI.Suffix, VTI.Unsigned, VTI.Size>;
defvar Inst = !cast<Instruction>(NAME);
defm : MVE_TwoOpPattern<VTI, Op, PredInt, (? (i32 VTI.Unsigned)), !cast<Instruction>(NAME)>;
@@ -2313,26 +2254,18 @@ multiclass MVE_VHADD_m<MVEVectorVTInfo VTI, SDNode Op,
// Unpredicated add-and-divide-by-two
def : Pat<(VTI.Vec (unpred_op (VTI.Vec MQPR:$Qm), (VTI.Vec MQPR:$Qn), (i32 VTI.Unsigned))),
(VTI.Vec (Inst (VTI.Vec MQPR:$Qm), (VTI.Vec MQPR:$Qn)))>;
-
- def : Pat<(VTI.Vec (shift_op (add_op (VTI.Vec MQPR:$Qm), (VTI.Vec MQPR:$Qn)), (i32 1))),
- (Inst MQPR:$Qm, MQPR:$Qn)>;
}
}
-multiclass MVE_VHADD<MVEVectorVTInfo VTI, SDNode Op, PatFrag add_op, SDNode shift_op>
- : MVE_VHADD_m<VTI, Op, int_arm_mve_vhadd, int_arm_mve_hadd_predicated, add_op,
- shift_op>;
+multiclass MVE_VHADD<MVEVectorVTInfo VTI, SDNode Op>
+ : MVE_VHADD_m<VTI, Op, int_arm_mve_vhadd, int_arm_mve_hadd_predicated>;
-// Halving add/sub perform the arithemtic operation with an extra bit of
-// precision, before performing the shift, to void clipping errors. We're not
-// modelling that here with these patterns, but we're using no wrap forms of
-// add/sub to ensure that the extra bit of information is not needed.
-defm MVE_VHADDs8 : MVE_VHADD<MVE_v16s8, avgfloors, addnsw, ARMvshrsImm>;
-defm MVE_VHADDs16 : MVE_VHADD<MVE_v8s16, avgfloors, addnsw, ARMvshrsImm>;
-defm MVE_VHADDs32 : MVE_VHADD<MVE_v4s32, avgfloors, addnsw, ARMvshrsImm>;
-defm MVE_VHADDu8 : MVE_VHADD<MVE_v16u8, avgflooru, addnuw, ARMvshruImm>;
-defm MVE_VHADDu16 : MVE_VHADD<MVE_v8u16, avgflooru, addnuw, ARMvshruImm>;
-defm MVE_VHADDu32 : MVE_VHADD<MVE_v4u32, avgflooru, addnuw, ARMvshruImm>;
+defm MVE_VHADDs8 : MVE_VHADD<MVE_v16s8, avgfloors>;
+defm MVE_VHADDs16 : MVE_VHADD<MVE_v8s16, avgfloors>;
+defm MVE_VHADDs32 : MVE_VHADD<MVE_v4s32, avgfloors>;
+defm MVE_VHADDu8 : MVE_VHADD<MVE_v16u8, avgflooru>;
+defm MVE_VHADDu16 : MVE_VHADD<MVE_v8u16, avgflooru>;
+defm MVE_VHADDu32 : MVE_VHADD<MVE_v4u32, avgflooru>;
multiclass MVE_VHSUB_m<MVEVectorVTInfo VTI,
SDPatternOperator unpred_op, Intrinsic pred_int, PatFrag sub_op,
diff --git a/llvm/test/CodeGen/AArch64/avg.ll b/llvm/test/CodeGen/AArch64/avg.ll
index cabc0d346b806f..ea07b10c22c2e7 100644
--- a/llvm/test/CodeGen/AArch64/avg.ll
+++ b/llvm/test/CodeGen/AArch64/avg.ll
@@ -146,3 +146,205 @@ define <16 x i16> @sext_avgceils_mismatch(<16 x i4> %a0, <16 x i8> %a1) {
%avg = sub <16 x i16> %or, %shift
ret <16 x i16> %avg
}
+
+define <8 x i16> @add_avgflooru(<8 x i16> %a0, <8 x i16> %a1) {
+; CHECK-LABEL: add_avgflooru:
+; CHECK: // %bb.0:
+; CHECK-NEXT: uhadd v0.8h, v0.8h, v1.8h
+; CHECK-NEXT: ret
+ %add = add nuw <8 x i16> %a0, %a1
+ %avg = lshr <8 x i16> %add, splat(i16 1)
+ ret <8 x i16> %avg
+}
+
+define <8 x i16> @add_avgflooru_mismatch(<8 x i16> %a0, <8 x i16> %a1) {
+; CHECK-LABEL: add_avgflooru_mismatch:
+; CHECK: // %bb.0:
+; CHECK-NEXT: add v0.8h, v0.8h, v1.8h
+; CHECK-NEXT: ushr v0.8h, v0.8h, #1
+; CHECK-NEXT: ret
+ %add = add <8 x i16> %a0, %a1
+ %avg = lshr <8 x i16> %add, splat(i16 1)
+ ret <8 x i16> %avg
+}
+
+define <8 x i16> @add_avgceilu(<8 x i16> %a0, <8 x i16> %a1) {
+; CHECK-LABEL: add_avgceilu:
+; CHECK: // %bb.0:
+; CHECK-NEXT: urhadd v0.8h, v0.8h, v1.8h
+; CHECK-NEXT: ret
+ %add0 = add nuw <8 x i16> %a0, splat(i16 1)
+ %add = add nuw <8 x i16> %a1, %add0
+ %avg = lshr <8 x i16> %add, splat(i16 1)
+ ret <8 x i16> %avg
+}
+
+define <8 x i16> @add_avgceilu2(<8 x i16> %a0, <8 x i16> %a1) {
+; CHECK-LABEL: add_avgceilu2:
+; CHECK: // %bb.0:
+; CHECK-NEXT: urhadd v0.8h, v1.8h, v0.8h
+; CHECK-NEXT: ret
+ %add0 = add nuw <8 x i16> %a1, %a0
+ %add = add nuw <8 x i16> %add0, splat(i16 1)
+ %avg = lshr <8 x i16> %add, splat(i16 1)
+ ret <8 x i16> %avg
+}
+
+define <8 x i16> @add_avgceilu_mismatch1(<8 x i16> %a0, <8 x i16> %a1) {
+; CHECK-LABEL: add_avgceilu_mismatch1:
+; CHECK: // %bb.0:
+; CHECK-NEXT: movi v2.8h, #1
+; CHECK-NEXT: add v0.8h, v1.8h, v0.8h
+; CHECK-NEXT: uhadd v0.8h, v0.8h, v2.8h
+; CHECK-NEXT: ret
+ %add0 = add <8 x i16> %a1, %a0
+ %add = add nuw <8 x i16> %add0, splat(i16 1)
+ %avg = lshr <8 x i16> %add, splat(i16 1)
+ ret <8 x i16> %avg
+}
+
+define <8 x i16> @add_avgceilu_mismatch2(<8 x i16> %a0, <8 x i16> %a1) {
+; CHECK-LABEL: add_avgceilu_mismatch2:
+; CHECK: // %bb.0:
+; CHECK-NEXT: mvn v1.16b, v1.16b
+; CHECK-NEXT: sub v0.8h, v0.8h, v1.8h
+; CHECK-NEXT: ushr v0.8h, v0.8h, #1
+; CHECK-NEXT: ret
+ %add0 = add nuw <8 x i16> %a1, %a0
+ %add = add <8 x i16> %add0, splat(i16 1)
+ %avg = lshr <8 x i16> %add, splat(i16 1)
+ ret <8 x i16> %avg
+}
+
+define <8 x i16> @add_avgceilu_mismatch3(<8 x i16> %a0, <8 x i16> %a1) {
+; CHECK-LABEL: add_avgceilu_mismatch3:
+; CHECK: // %bb.0:
+; CHECK-NEXT: mvn v1.16b, v1.16b
+; CHECK-NEXT: sub v0.8h, v0.8h, v1.8h
+; CHECK-NEXT: ushr v0.8h, v0.8h, #1
+; CHECK-NEXT: ret
+ %add0 = add nuw <8 x i16> %a1, %a0
+ %add = add <8 x i16> %add0, splat(i16 1)
+ %avg = lshr <8 x i16> %add, splat(i16 1)
+ ret <8 x i16> %avg
+}
+
+define <8 x i16> @add_avgfloors(<8 x i16> %a0, <8 x i16> %a1) {
+; CHECK-LABEL: add_avgfloors:
+; CHECK: // %bb.0:
+; CHECK-NEXT: shadd v0.8h, v0.8h, v1.8h
+; CHECK-NEXT: ret
+ %add = add nsw <8 x i16> %a0, %a1
+ %avg = ashr <8 x i16> %add, splat(i16 1)
+ ret <8 x i16> %avg
+}
+
+define <8 x i16> @add_avgfloors_mismatch(<8 x i16> %a0, <8 x i16> %a1) {
+; CHECK-LABEL: add_avgfloors_mismatch:
+; CHECK: // %bb.0:
+; CHECK-NEXT: add v0.8h, v0.8h, v1.8h
+; CHECK-NEXT: sshr v0.8h, v0.8h, #1
+; CHECK-NEXT: ret
+ %add = add <8 x i16> %a0, %a1
+ %avg = ashr <8 x i16> %add, splat(i16 1)
+ ret <8 x i16> %avg
+}
+
+define <8 x i16> @add_avgfoor_mismatch2(<8 x i16> %a0, <8 x i16> %a1) {
+; CHECK-LABEL: add_avgfoor_mismatch2:
+; CHECK: // %bb.0:
+; CHECK-NEXT: add v0.8h, v0.8h, v1.8h
+; CHECK-NEXT: sshr v0.8h, v0.8h, #2
+; CHECK-NEXT: ret
+ %add = add nsw <8 x i16> %a0, %a1
+ %avg = ashr <8 x i16> %add, splat(i16 2)
+ ret <8 x i16> %avg
+}
+
+define <8 x i16> @add_avgceils(<8 x i16> %a0, <8 x i16> %a1) {
+; CHECK-LABEL: add_avgceils:
+; CHECK: // %bb.0:
+; CHECK-NEXT: srhadd v0.8h, v0.8h, v1.8h
+; CHECK-NEXT: ret
+ %add0 = add nsw <8 x i16> %a0, splat(i16 1)
+ %add = add nsw <8 x i16> %a1, %add0
+ %avg = ashr <8 x i16> %add, splat(i16 1)
+ ret <8 x i16> %avg
+}
+
+define <8 x i16> @add_avgceils2(<8 x i16> %a0, <8 x i16> %a1) {
+; CHECK-LABEL: add_avgceils2:
+; CHECK: // %bb.0:
+; CHECK-NEXT: srhadd v0.8h, v1.8h, v0.8h
+; CHECK-NEXT: ret
+ %add0 = add nsw <8 x i16> %a1, %a0
+ %add = add nsw <8 x i16> %add0, splat(i16 1)
+ %avg = ashr <8 x i16> %add, splat(i16 1)
+ ret <8 x i16> %avg
+}
+
+define <8 x i16> @add_avgceils_mismatch1(<8 x i16> %a0, <8 x i16> %a1) {
+; CHECK-LABEL: add_avgceils_mismatch1:
+; CHECK: // %bb.0:
+; CHECK-NEXT: movi v2.8h, #1
+; CHECK-NEXT: add v0.8h, v1.8h, v0.8h
+; CHECK-NEXT: shadd v0.8h, v0.8h, v2.8h
+; CHECK-NEXT: ret
+ %add0 = add <8 x i16> %a1, %a0
+ %add = add nsw <8 x i16> %add0, splat(i16 1)
+ %avg = ashr <8 x i16> %add, splat(i16 1)
+ ret <8 x i16> %avg
+}
+
+define <8 x i16> @add_avgceils_mismatch2(<8 x i16> %a0, <8 x i16> %a1) {
+; CHECK-LABEL: add_avgceils_mismatch2:
+; CHECK: // %bb.0:
+; CHECK-NEXT: mvn v1.16b, v1.16b
+; CHECK-NEXT: sub v0.8h, v0.8h, v1.8h
+; CHECK-NEXT: sshr v0.8h, v0.8h, #1
+; CHECK-NEXT: ret
+ %add0 = add nsw <8 x i16> %a1, %a0
+ %add = add <8 x i16> %add0, splat(i16 1)
+ %avg = ashr <8 x i16> %add, splat(i16 1)
+ ret <8 x i16> %avg
+}
+
+define <8 x i16> @add_avgceils_mismatch3(<8 x i16> %a0, <8 x i16> %a1) {
+; CHECK-LABEL: add_avgceils_mismatch3:
+; CHECK: // %bb.0:
+; CHECK-NEXT: mvn v1.16b, v1.16b
+; CHECK-NEXT: sub v0.8h, v0.8h, v1.8h
+; CHECK-NEXT: sshr v0.8h, v0.8h, #1
+; CHECK-NEXT: ret
+ %add0 = add nsw <8 x i16> %a1, %a0
+ %add = add <8 x i16> %add0, splat(i16 1)
+ %avg = ashr <8 x i16> %add, splat(i16 1)
+ ret <8 x i16> %avg
+}
+
+define <8 x i16> @add_avgceils_mismatch4(<8 x i16> %a0, <8 x i16> %a1) {
+; CHECK-LABEL: add_avgceils_mismatch4:
+; CHECK: // %bb.0:
+; CHECK-NEXT: mvn v0.16b, v0.16b
+; CHECK-NEXT: sub v0.8h, v1.8h, v0.8h
+; CHECK-NEXT: sshr v0.8h, v0.8h, #2
+; CHECK-NEXT: ret
+ %add0 = add nsw <8 x i16> %a0, splat(i16 1)
+ %add = add nsw <8 x i16> %a1, %add0
+ %avg = ashr <8 x i16> %add, splat(i16 2)
+ ret <8 x i16> %avg
+}
+
+define <8 x i16> @add_avgceilu_mismatch(<8 x i16> %a0, <8 x i16> %a1) {
+; CHECK-LABEL: add_avgceilu_mismatch:
+; CHECK: // %bb.0:
+; CHECK-NEXT: movi v2.8h, #1
+; CHECK-NEXT: add v0.8h, v1.8h, v0.8h
+; CHECK-NEXT: add v0.8h, v0.8h, v2.8h
+; CHECK-NEXT: ushr v0.8h, v0.8h, #2
+; CHECK-NEXT: ret
+ %add0 = add nuw <8 x i16> %a1, %a0
+ %add = add nuw <8 x i16> %add0, splat(i16 1)
+ %avg = lshr <8 x i16> %add, splat(i16 2)
+ ret <8 x i16> %avg
+}
diff --git a/llvm/test/CodeGen/AArch64/sve-hadd.ll b/llvm/test/CodeGen/AArch64/sve-hadd.ll
index 6017e13ce00352..ce440d3095d3f3 100644
--- a/llvm/test/CodeGen/AArch64/sve-hadd.ll
+++ b/llvm/test/CodeGen/AArch64/sve-hadd.ll
@@ -1301,3 +1301,43 @@ entry:
%result = trunc <vscale x 16 x i16> %s to <vscale x 16 x i8>
ret <vscale x 16 x i8> %result
}
+
+define <vscale x 2 x i64> @haddu_v2i64_add(<vscale x 2 x i64> %s0, <vscale x 2 x i64> %s1) {
+; SVE-LABEL: haddu_v2i64_add:
+; SVE: // %bb.0: // %entry
+; SVE-NEXT: eor z2.d, z0.d, z1.d
+; SVE-NEXT: and z0.d, z0.d, z1.d
+; SVE-NEXT: lsr z1.d, z2.d, #1
+; SVE-NEXT: add z0.d, z0.d, z1.d
+; SVE-NEXT: ret
+;
+; SVE2-LABEL: haddu_v2i64_add:
+; SVE2: // %bb.0: // %entry
+; SVE2-NEXT: ptrue p0.d
+; SVE2-NEXT: uhadd z0.d, p0/m, z0.d, z1.d
+; SVE2-NEXT: ret
+entry:
+ %add = add nuw nsw <vscale x 2 x i64> %s0, %s1
+ %avg = lshr <vscale x 2 x i64> %add, splat (i64 1)
+ ret <vscale x 2 x i64> %avg
+}
+
+define <vscale x 2 x i64> @hadds_v2i64_add(<vscale x 2 x i64> %s0, <vscale x 2 x i64> %s1) {
+; SVE-LABEL: hadds_v2i64_add:
+; SVE: // %bb.0: // %entry
+; SVE-NEXT: eor z2.d, z0.d, z1.d
+; SVE-NEXT: and z0.d, z0.d, z1.d
+; SVE-NEXT: asr z1.d, z2.d, #1
+; SVE-NEXT: add z0.d, z0.d, z1.d
+; SVE-NEXT: ret
+;
+; SVE2-LABEL: hadds_v2i64_add:
+; SVE2: // %bb.0: // %entry
+; SVE2-NEXT: ptrue p0.d
+; SVE2-NEXT: shadd z0.d, p0/m, z0.d, z1.d
+; SVE2-NEXT: ret
+entry:
+ %add = add nuw nsw <vscale x 2 x i64> %s0, %s1
+ %avg = ashr <vscale x 2 x i64> %add, splat (i64 1)
+ ret <vscale x 2 x i64> %avg
+}
More information about the llvm-commits
mailing list