[llvm] 12bd049 - [AArch64] Enable fixed-length vector support for partial-reductions (#142032)
via llvm-commits
llvm-commits at lists.llvm.org
Fri May 30 09:47:34 PDT 2025
Author: Sander de Smalen
Date: 2025-05-30T17:47:31+01:00
New Revision: 12bd04951054f3b7a4c604d84820ba809b67cedb
URL: https://github.com/llvm/llvm-project/commit/12bd04951054f3b7a4c604d84820ba809b67cedb
DIFF: https://github.com/llvm/llvm-project/commit/12bd04951054f3b7a4c604d84820ba809b67cedb.diff
LOG: [AArch64] Enable fixed-length vector support for partial-reductions (#142032)
This enables the use of the [us]dot, [us]addw[bt] and [us]mlal[bt]
instructions in Streaming mode, and for wider fixed-length vectors when the
runtime vector length is known to be 256 bits or larger.
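As an illustration of what this covers, the four-way i8 -> i32 partial
reduction below is taken verbatim from the added test; with
-aarch64-enable-partial-reduce-nodes=true it now selects a single dot
product (udot v0.4s with NEON dotprod, or udot z0.s under -force-streaming
with SME), rather than being expanded:

define <4 x i32> @four_way_i8_i32_vl128(ptr %accptr, ptr %uptr, ptr %sptr) {
  %acc = load <4 x i32>, ptr %accptr
  %u = load <16 x i8>, ptr %uptr
  %s = load <16 x i8>, ptr %sptr
  %u.wide = zext <16 x i8> %u to <16 x i32>
  %s.wide = zext <16 x i8> %s to <16 x i32>
  %mult = mul nuw nsw <16 x i32> %s.wide, %u.wide
  %partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add(<4 x i32> %acc, <16 x i32> %mult)
  ret <4 x i32> %partial.reduce
}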
Added:
llvm/test/CodeGen/AArch64/sve-fixed-length-partial-reduce.ll
Modified:
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Removed:
################################################################################
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 349bcd95c09f6..ef68734fa039c 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1935,6 +1935,18 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
Custom);
setOperationAction(ISD::EXPERIMENTAL_VECTOR_HISTOGRAM, MVT::nxv2i64,
Custom);
+
+ if (EnablePartialReduceNodes) {
+ static const unsigned MLAOps[] = {ISD::PARTIAL_REDUCE_SMLA,
+ ISD::PARTIAL_REDUCE_UMLA};
+ // Must be lowered to SVE instructions.
+ setPartialReduceMLAAction(MLAOps, MVT::v2i64, MVT::v4i32, Custom);
+ setPartialReduceMLAAction(MLAOps, MVT::v2i64, MVT::v8i16, Custom);
+ setPartialReduceMLAAction(MLAOps, MVT::v2i64, MVT::v16i8, Custom);
+ setPartialReduceMLAAction(MLAOps, MVT::v4i32, MVT::v8i16, Custom);
+ setPartialReduceMLAAction(MLAOps, MVT::v4i32, MVT::v16i8, Custom);
+ setPartialReduceMLAAction(MLAOps, MVT::v8i16, MVT::v16i8, Custom);
+ }
}
}
@@ -2230,6 +2242,28 @@ void AArch64TargetLowering::addTypeForFixedLengthSVE(MVT VT) {
bool PreferNEON = VT.is64BitVector() || VT.is128BitVector();
bool PreferSVE = !PreferNEON && Subtarget->isSVEAvailable();
+ if (EnablePartialReduceNodes) {
+ static const unsigned MLAOps[] = {ISD::PARTIAL_REDUCE_SMLA,
+ ISD::PARTIAL_REDUCE_UMLA};
+ unsigned NumElts = VT.getVectorNumElements();
+ if (VT.getVectorElementType() == MVT::i64) {
+ setPartialReduceMLAAction(MLAOps, VT,
+ MVT::getVectorVT(MVT::i8, NumElts * 8), Custom);
+ setPartialReduceMLAAction(
+ MLAOps, VT, MVT::getVectorVT(MVT::i16, NumElts * 4), Custom);
+ setPartialReduceMLAAction(
+ MLAOps, VT, MVT::getVectorVT(MVT::i32, NumElts * 2), Custom);
+ } else if (VT.getVectorElementType() == MVT::i32) {
+ setPartialReduceMLAAction(MLAOps, VT,
+ MVT::getVectorVT(MVT::i8, NumElts * 4), Custom);
+ setPartialReduceMLAAction(
+ MLAOps, VT, MVT::getVectorVT(MVT::i16, NumElts * 2), Custom);
+ } else if (VT.getVectorElementType() == MVT::i16) {
+ setPartialReduceMLAAction(MLAOps, VT,
+ MVT::getVectorVT(MVT::i8, NumElts * 2), Custom);
+ }
+ }
+
// Lower fixed length vector operations to scalable equivalents.
setOperationAction(ISD::ABDS, VT, Default);
setOperationAction(ISD::ABDU, VT, Default);
@@ -29251,50 +29285,61 @@ SDValue AArch64TargetLowering::LowerVECTOR_HISTOGRAM(SDValue Op,
SDValue
AArch64TargetLowering::LowerPARTIAL_REDUCE_MLA(SDValue Op,
SelectionDAG &DAG) const {
- bool Scalable = Op.getValueType().isScalableVector();
-
- assert((!Scalable || Subtarget->isSVEorStreamingSVEAvailable()) &&
- "SVE or StreamingSVE must be available when using scalable vectors.");
- assert((Scalable || Subtarget->hasDotProd()) &&
- "Dotprod must be available when targeting NEON dot product "
- "instructions.");
-
SDLoc DL(Op);
SDValue Acc = Op.getOperand(0);
SDValue LHS = Op.getOperand(1);
SDValue RHS = Op.getOperand(2);
EVT ResultVT = Op.getValueType();
+ EVT OrigResultVT = ResultVT;
+ EVT OpVT = LHS.getValueType();
- assert((Scalable && ResultVT == MVT::nxv2i64 &&
- LHS.getValueType() == MVT::nxv16i8) ||
- (!Scalable && ResultVT == MVT::v2i64 &&
- LHS.getValueType() == MVT::v16i8));
+ bool ConvertToScalable =
+ ResultVT.isFixedLengthVector() &&
+ useSVEForFixedLengthVectorVT(ResultVT, /*OverrideNEON=*/true);
- EVT DotVT = Scalable ? MVT::nxv4i32 : MVT::v4i32;
+ if (ConvertToScalable) {
+ ResultVT = getContainerForFixedLengthVector(DAG, ResultVT);
+ OpVT = getContainerForFixedLengthVector(DAG, LHS.getValueType());
+ Acc = convertToScalableVector(DAG, ResultVT, Acc);
+ LHS = convertToScalableVector(DAG, OpVT, LHS);
+ RHS = convertToScalableVector(DAG, OpVT, RHS);
+ Op = DAG.getNode(Op.getOpcode(), DL, ResultVT, {Acc, LHS, RHS});
+ }
+
+ // Two-way and four-way partial reductions are supported by patterns.
+ // We only need to handle the 8-way partial reduction.
+ if (ResultVT.getScalarType() != MVT::i64 || OpVT.getScalarType() != MVT::i8)
+ return ConvertToScalable ? convertFromScalableVector(DAG, OrigResultVT, Op)
+ : Op;
+
+ EVT DotVT = ResultVT.isScalableVector() ? MVT::nxv4i32 : MVT::v4i32;
SDValue DotNode = DAG.getNode(Op.getOpcode(), DL, DotVT,
DAG.getConstant(0, DL, DotVT), LHS, RHS);
+ SDValue Res;
bool IsUnsigned = Op.getOpcode() == ISD::PARTIAL_REDUCE_UMLA;
- if (Scalable &&
- (Subtarget->hasSVE2() || Subtarget->isStreamingSVEAvailable())) {
+ if (Subtarget->hasSVE2() || Subtarget->isStreamingSVEAvailable()) {
unsigned LoOpcode = IsUnsigned ? AArch64ISD::UADDWB : AArch64ISD::SADDWB;
unsigned HiOpcode = IsUnsigned ? AArch64ISD::UADDWT : AArch64ISD::SADDWT;
SDValue Lo = DAG.getNode(LoOpcode, DL, ResultVT, Acc, DotNode);
- return DAG.getNode(HiOpcode, DL, ResultVT, Lo, DotNode);
- }
-
- // Fold (nx)v4i32 into (nx)v2i64
- auto [DotNodeLo, DotNodeHi] = DAG.SplitVector(DotNode, DL);
- if (IsUnsigned) {
- DotNodeLo = DAG.getZExtOrTrunc(DotNodeLo, DL, ResultVT);
- DotNodeHi = DAG.getZExtOrTrunc(DotNodeHi, DL, ResultVT);
+ Res = DAG.getNode(HiOpcode, DL, ResultVT, Lo, DotNode);
} else {
- DotNodeLo = DAG.getSExtOrTrunc(DotNodeLo, DL, ResultVT);
- DotNodeHi = DAG.getSExtOrTrunc(DotNodeHi, DL, ResultVT);
+ // Fold (nx)v4i32 into (nx)v2i64
+ auto [DotNodeLo, DotNodeHi] = DAG.SplitVector(DotNode, DL);
+ if (IsUnsigned) {
+ DotNodeLo = DAG.getZExtOrTrunc(DotNodeLo, DL, ResultVT);
+ DotNodeHi = DAG.getZExtOrTrunc(DotNodeHi, DL, ResultVT);
+ } else {
+ DotNodeLo = DAG.getSExtOrTrunc(DotNodeLo, DL, ResultVT);
+ DotNodeHi = DAG.getSExtOrTrunc(DotNodeHi, DL, ResultVT);
+ }
+ auto Lo = DAG.getNode(ISD::ADD, DL, ResultVT, Acc, DotNodeLo);
+ Res = DAG.getNode(ISD::ADD, DL, ResultVT, Lo, DotNodeHi);
}
- auto Lo = DAG.getNode(ISD::ADD, DL, ResultVT, Acc, DotNodeLo);
- return DAG.getNode(ISD::ADD, DL, ResultVT, Lo, DotNodeHi);
+
+ return ConvertToScalable ? convertFromScalableVector(DAG, OrigResultVT, Res)
+ : Res;
}
SDValue
diff --git a/llvm/test/CodeGen/AArch64/sve-fixed-length-partial-reduce.ll b/llvm/test/CodeGen/AArch64/sve-fixed-length-partial-reduce.ll
new file mode 100644
index 0000000000000..79d766d1b9908
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve-fixed-length-partial-reduce.ll
@@ -0,0 +1,791 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mattr=+dotprod -aarch64-enable-partial-reduce-nodes=true < %s | FileCheck %s --check-prefixes=COMMON,NEON
+; RUN: llc -mattr=+sve,+dotprod -aarch64-enable-partial-reduce-nodes=true < %s | FileCheck %s --check-prefixes=COMMON,SVE
+; RUN: llc -mattr=+sme -aarch64-enable-partial-reduce-nodes=true -force-streaming < %s | FileCheck %s --check-prefix=SME
+
+target triple = "aarch64"
+
+;
+; Two-way mla (i8 -> i16)
+;
+
+define <8 x i16> @two_way_i8_i16_vl128(ptr %accptr, ptr %uptr, ptr %sptr) {
+;
+; COMMON-LABEL: two_way_i8_i16_vl128:
+; COMMON: // %bb.0:
+; COMMON-NEXT: ldr q0, [x0]
+; COMMON-NEXT: ldr q1, [x1]
+; COMMON-NEXT: ldr q2, [x2]
+; COMMON-NEXT: umlal v0.8h, v2.8b, v1.8b
+; COMMON-NEXT: umlal2 v0.8h, v2.16b, v1.16b
+; COMMON-NEXT: ret
+;
+; SME-LABEL: two_way_i8_i16_vl128:
+; SME: // %bb.0:
+; SME-NEXT: ldr q0, [x0]
+; SME-NEXT: ldr q1, [x1]
+; SME-NEXT: ldr q2, [x2]
+; SME-NEXT: umlalb z0.h, z2.b, z1.b
+; SME-NEXT: umlalt z0.h, z2.b, z1.b
+; SME-NEXT: // kill: def $q0 killed $q0 killed $z0
+; SME-NEXT: ret
+ %acc = load <8 x i16>, ptr %accptr
+ %u = load <16 x i8>, ptr %uptr
+ %s = load <16 x i8>, ptr %sptr
+ %u.wide = zext <16 x i8> %u to <16 x i16>
+ %s.wide = zext <16 x i8> %s to <16 x i16>
+ %mult = mul nuw nsw <16 x i16> %s.wide, %u.wide
+ %partial.reduce = tail call <8 x i16> @llvm.experimental.vector.partial.reduce.add(<8 x i16> %acc, <16 x i16> %mult)
+ ret <8 x i16> %partial.reduce
+}
+
+define <16 x i16> @two_way_i8_i16_vl128_double_width(ptr %accptr, ptr %uptr, ptr %sptr) {
+;
+; COMMON-LABEL: two_way_i8_i16_vl128_double_width:
+; COMMON: // %bb.0:
+; COMMON-NEXT: ldp q0, q1, [x0]
+; COMMON-NEXT: ldp q2, q3, [x1]
+; COMMON-NEXT: ldp q4, q5, [x2]
+; COMMON-NEXT: umlal v0.8h, v4.8b, v2.8b
+; COMMON-NEXT: umlal v1.8h, v5.8b, v3.8b
+; COMMON-NEXT: umlal2 v0.8h, v4.16b, v2.16b
+; COMMON-NEXT: umlal2 v1.8h, v5.16b, v3.16b
+; COMMON-NEXT: ret
+;
+; SME-LABEL: two_way_i8_i16_vl128_double_width:
+; SME: // %bb.0:
+; SME-NEXT: ldp q0, q1, [x0]
+; SME-NEXT: ldp q3, q2, [x1]
+; SME-NEXT: ldp q5, q4, [x2]
+; SME-NEXT: umlalb z0.h, z5.b, z3.b
+; SME-NEXT: umlalb z1.h, z4.b, z2.b
+; SME-NEXT: umlalt z0.h, z5.b, z3.b
+; SME-NEXT: umlalt z1.h, z4.b, z2.b
+; SME-NEXT: // kill: def $q0 killed $q0 killed $z0
+; SME-NEXT: // kill: def $q1 killed $q1 killed $z1
+; SME-NEXT: ret
+ %acc = load <16 x i16>, ptr %accptr
+ %u = load <32 x i8>, ptr %uptr
+ %s = load <32 x i8>, ptr %sptr
+ %u.wide = zext <32 x i8> %u to <32 x i16>
+ %s.wide = zext <32 x i8> %s to <32 x i16>
+ %mult = mul nuw nsw <32 x i16> %s.wide, %u.wide
+ %partial.reduce = tail call <16 x i16> @llvm.experimental.vector.partial.reduce.add(<16 x i16> %acc, <32 x i16> %mult)
+ ret <16 x i16> %partial.reduce
+}
+
+define <16 x i16> @two_way_i8_i16_vl256(ptr %accptr, ptr %uptr, ptr %sptr) vscale_range(2,2) {
+;
+;
+; NEON-LABEL: two_way_i8_i16_vl256:
+; NEON: // %bb.0:
+; NEON-NEXT: ldp q0, q1, [x0]
+; NEON-NEXT: ldp q2, q3, [x1]
+; NEON-NEXT: ldp q4, q5, [x2]
+; NEON-NEXT: umlal v0.8h, v4.8b, v2.8b
+; NEON-NEXT: umlal v1.8h, v5.8b, v3.8b
+; NEON-NEXT: umlal2 v0.8h, v4.16b, v2.16b
+; NEON-NEXT: umlal2 v1.8h, v5.16b, v3.16b
+; NEON-NEXT: ret
+;
+; SVE-LABEL: two_way_i8_i16_vl256:
+; SVE: // %bb.0:
+; SVE-NEXT: ldr z0, [x1]
+; SVE-NEXT: ldr z1, [x2]
+; SVE-NEXT: ptrue p0.h
+; SVE-NEXT: ldr z4, [x0]
+; SVE-NEXT: uunpklo z2.h, z0.b
+; SVE-NEXT: uunpklo z3.h, z1.b
+; SVE-NEXT: uunpkhi z0.h, z0.b
+; SVE-NEXT: uunpkhi z1.h, z1.b
+; SVE-NEXT: mad z2.h, p0/m, z3.h, z4.h
+; SVE-NEXT: mad z0.h, p0/m, z1.h, z2.h
+; SVE-NEXT: mov z1.d, z0.d
+; SVE-NEXT: ext z1.b, z1.b, z0.b, #16
+; SVE-NEXT: // kill: def $q0 killed $q0 killed $z0
+; SVE-NEXT: // kill: def $q1 killed $q1 killed $z1
+; SVE-NEXT: ret
+;
+; SME-LABEL: two_way_i8_i16_vl256:
+; SME: // %bb.0:
+; SME-NEXT: ldr z0, [x0]
+; SME-NEXT: ldr z1, [x1]
+; SME-NEXT: ldr z2, [x2]
+; SME-NEXT: umlalb z0.h, z2.b, z1.b
+; SME-NEXT: umlalt z0.h, z2.b, z1.b
+; SME-NEXT: mov z1.d, z0.d
+; SME-NEXT: ext z1.b, z1.b, z0.b, #16
+; SME-NEXT: // kill: def $q0 killed $q0 killed $z0
+; SME-NEXT: // kill: def $q1 killed $q1 killed $z1
+; SME-NEXT: ret
+ %acc = load <16 x i16>, ptr %accptr
+ %u = load <32 x i8>, ptr %uptr
+ %s = load <32 x i8>, ptr %sptr
+ %u.wide = zext <32 x i8> %u to <32 x i16>
+ %s.wide = zext <32 x i8> %s to <32 x i16>
+ %mult = mul nuw nsw <32 x i16> %s.wide, %u.wide
+ %partial.reduce = tail call <16 x i16> @llvm.experimental.vector.partial.reduce.add(<16 x i16> %acc, <32 x i16> %mult)
+ ret <16 x i16> %partial.reduce
+}
+
+;
+; Two-way mla (i16 -> i32)
+;
+
+define <4 x i32> @two_way_i16_i32_vl128(ptr %accptr, ptr %uptr, ptr %sptr) {
+;
+; COMMON-LABEL: two_way_i16_i32_vl128:
+; COMMON: // %bb.0:
+; COMMON-NEXT: ldr q0, [x0]
+; COMMON-NEXT: ldr q1, [x1]
+; COMMON-NEXT: ldr q2, [x2]
+; COMMON-NEXT: umlal v0.4s, v2.4h, v1.4h
+; COMMON-NEXT: umlal2 v0.4s, v2.8h, v1.8h
+; COMMON-NEXT: ret
+;
+; SME-LABEL: two_way_i16_i32_vl128:
+; SME: // %bb.0:
+; SME-NEXT: ldr q0, [x0]
+; SME-NEXT: ldr q1, [x1]
+; SME-NEXT: ldr q2, [x2]
+; SME-NEXT: umlalb z0.s, z2.h, z1.h
+; SME-NEXT: umlalt z0.s, z2.h, z1.h
+; SME-NEXT: // kill: def $q0 killed $q0 killed $z0
+; SME-NEXT: ret
+ %acc = load <4 x i32>, ptr %accptr
+ %u = load <8 x i16>, ptr %uptr
+ %s = load <8 x i16>, ptr %sptr
+ %u.wide = zext <8 x i16> %u to <8 x i32>
+ %s.wide = zext <8 x i16> %s to <8 x i32>
+ %mult = mul nuw nsw <8 x i32> %s.wide, %u.wide
+ %partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add(<4 x i32> %acc, <8 x i32> %mult)
+ ret <4 x i32> %partial.reduce
+}
+
+define <8 x i32> @two_way_i16_i32_vl128_double_width(ptr %accptr, ptr %uptr, ptr %sptr) {
+;
+; COMMON-LABEL: two_way_i16_i32_vl128_double_width:
+; COMMON: // %bb.0:
+; COMMON-NEXT: ldp q0, q1, [x0]
+; COMMON-NEXT: ldp q2, q3, [x1]
+; COMMON-NEXT: ldp q4, q5, [x2]
+; COMMON-NEXT: umlal v0.4s, v4.4h, v2.4h
+; COMMON-NEXT: umlal v1.4s, v5.4h, v3.4h
+; COMMON-NEXT: umlal2 v0.4s, v4.8h, v2.8h
+; COMMON-NEXT: umlal2 v1.4s, v5.8h, v3.8h
+; COMMON-NEXT: ret
+;
+; SME-LABEL: two_way_i16_i32_vl128_double_width:
+; SME: // %bb.0:
+; SME-NEXT: ldp q0, q1, [x0]
+; SME-NEXT: ldp q3, q2, [x1]
+; SME-NEXT: ldp q5, q4, [x2]
+; SME-NEXT: umlalb z0.s, z5.h, z3.h
+; SME-NEXT: umlalb z1.s, z4.h, z2.h
+; SME-NEXT: umlalt z0.s, z5.h, z3.h
+; SME-NEXT: umlalt z1.s, z4.h, z2.h
+; SME-NEXT: // kill: def $q0 killed $q0 killed $z0
+; SME-NEXT: // kill: def $q1 killed $q1 killed $z1
+; SME-NEXT: ret
+ %acc = load <8 x i32>, ptr %accptr
+ %u = load <16 x i16>, ptr %uptr
+ %s = load <16 x i16>, ptr %sptr
+ %u.wide = zext <16 x i16> %u to <16 x i32>
+ %s.wide = zext <16 x i16> %s to <16 x i32>
+ %mult = mul nuw nsw <16 x i32> %s.wide, %u.wide
+ %partial.reduce = tail call <8 x i32> @llvm.experimental.vector.partial.reduce.add(<8 x i32> %acc, <16 x i32> %mult)
+ ret <8 x i32> %partial.reduce
+}
+
+define <8 x i32> @two_way_i16_i32_vl256(ptr %accptr, ptr %uptr, ptr %sptr) vscale_range(2,2) {
+;
+;
+; NEON-LABEL: two_way_i16_i32_vl256:
+; NEON: // %bb.0:
+; NEON-NEXT: ldp q0, q1, [x0]
+; NEON-NEXT: ldp q2, q3, [x1]
+; NEON-NEXT: ldp q4, q5, [x2]
+; NEON-NEXT: umlal v0.4s, v4.4h, v2.4h
+; NEON-NEXT: umlal v1.4s, v5.4h, v3.4h
+; NEON-NEXT: umlal2 v0.4s, v4.8h, v2.8h
+; NEON-NEXT: umlal2 v1.4s, v5.8h, v3.8h
+; NEON-NEXT: ret
+;
+; SVE-LABEL: two_way_i16_i32_vl256:
+; SVE: // %bb.0:
+; SVE-NEXT: ldr z0, [x1]
+; SVE-NEXT: ldr z1, [x2]
+; SVE-NEXT: ptrue p0.s
+; SVE-NEXT: ldr z4, [x0]
+; SVE-NEXT: uunpklo z2.s, z0.h
+; SVE-NEXT: uunpklo z3.s, z1.h
+; SVE-NEXT: uunpkhi z0.s, z0.h
+; SVE-NEXT: uunpkhi z1.s, z1.h
+; SVE-NEXT: mad z2.s, p0/m, z3.s, z4.s
+; SVE-NEXT: mad z0.s, p0/m, z1.s, z2.s
+; SVE-NEXT: mov z1.d, z0.d
+; SVE-NEXT: ext z1.b, z1.b, z0.b, #16
+; SVE-NEXT: // kill: def $q0 killed $q0 killed $z0
+; SVE-NEXT: // kill: def $q1 killed $q1 killed $z1
+; SVE-NEXT: ret
+;
+; SME-LABEL: two_way_i16_i32_vl256:
+; SME: // %bb.0:
+; SME-NEXT: ldr z0, [x0]
+; SME-NEXT: ldr z1, [x1]
+; SME-NEXT: ldr z2, [x2]
+; SME-NEXT: umlalb z0.s, z2.h, z1.h
+; SME-NEXT: umlalt z0.s, z2.h, z1.h
+; SME-NEXT: mov z1.d, z0.d
+; SME-NEXT: ext z1.b, z1.b, z0.b, #16
+; SME-NEXT: // kill: def $q0 killed $q0 killed $z0
+; SME-NEXT: // kill: def $q1 killed $q1 killed $z1
+; SME-NEXT: ret
+ %acc = load <8 x i32>, ptr %accptr
+ %u = load <16 x i16>, ptr %uptr
+ %s = load <16 x i16>, ptr %sptr
+ %u.wide = zext <16 x i16> %u to <16 x i32>
+ %s.wide = zext <16 x i16> %s to <16 x i32>
+ %mult = mul nuw nsw <16 x i32> %s.wide, %u.wide
+ %partial.reduce = tail call <8 x i32> @llvm.experimental.vector.partial.reduce.add(<8 x i32> %acc, <16 x i32> %mult)
+ ret <8 x i32> %partial.reduce
+}
+
+;
+; Two-way mla (i32 -> i64)
+;
+
+define <2 x i64> @two_way_i32_i64_vl128(ptr %accptr, ptr %uptr, ptr %sptr) {
+;
+; COMMON-LABEL: two_way_i32_i64_vl128:
+; COMMON: // %bb.0:
+; COMMON-NEXT: ldr q0, [x0]
+; COMMON-NEXT: ldr q1, [x1]
+; COMMON-NEXT: ldr q2, [x2]
+; COMMON-NEXT: umlal v0.2d, v2.2s, v1.2s
+; COMMON-NEXT: umlal2 v0.2d, v2.4s, v1.4s
+; COMMON-NEXT: ret
+;
+; SME-LABEL: two_way_i32_i64_vl128:
+; SME: // %bb.0:
+; SME-NEXT: ldr q0, [x0]
+; SME-NEXT: ldr q1, [x1]
+; SME-NEXT: ldr q2, [x2]
+; SME-NEXT: umlalb z0.d, z2.s, z1.s
+; SME-NEXT: umlalt z0.d, z2.s, z1.s
+; SME-NEXT: // kill: def $q0 killed $q0 killed $z0
+; SME-NEXT: ret
+ %acc = load <2 x i64>, ptr %accptr
+ %u = load <4 x i32>, ptr %uptr
+ %s = load <4 x i32>, ptr %sptr
+ %u.wide = zext <4 x i32> %u to <4 x i64>
+ %s.wide = zext <4 x i32> %s to <4 x i64>
+ %mult = mul nuw nsw <4 x i64> %s.wide, %u.wide
+ %partial.reduce = tail call <2 x i64> @llvm.experimental.vector.partial.reduce.add(<2 x i64> %acc, <4 x i64> %mult)
+ ret <2 x i64> %partial.reduce
+}
+
+define <4 x i64> @two_way_i32_i64_vl128_double_width(ptr %accptr, ptr %uptr, ptr %sptr) {
+;
+; COMMON-LABEL: two_way_i32_i64_vl128_double_width:
+; COMMON: // %bb.0:
+; COMMON-NEXT: ldp q0, q1, [x0]
+; COMMON-NEXT: ldp q2, q3, [x1]
+; COMMON-NEXT: ldp q4, q5, [x2]
+; COMMON-NEXT: umlal v0.2d, v4.2s, v2.2s
+; COMMON-NEXT: umlal v1.2d, v5.2s, v3.2s
+; COMMON-NEXT: umlal2 v0.2d, v4.4s, v2.4s
+; COMMON-NEXT: umlal2 v1.2d, v5.4s, v3.4s
+; COMMON-NEXT: ret
+;
+; SME-LABEL: two_way_i32_i64_vl128_double_width:
+; SME: // %bb.0:
+; SME-NEXT: ldp q0, q1, [x0]
+; SME-NEXT: ldp q3, q2, [x1]
+; SME-NEXT: ldp q5, q4, [x2]
+; SME-NEXT: umlalb z0.d, z5.s, z3.s
+; SME-NEXT: umlalb z1.d, z4.s, z2.s
+; SME-NEXT: umlalt z0.d, z5.s, z3.s
+; SME-NEXT: umlalt z1.d, z4.s, z2.s
+; SME-NEXT: // kill: def $q0 killed $q0 killed $z0
+; SME-NEXT: // kill: def $q1 killed $q1 killed $z1
+; SME-NEXT: ret
+ %acc = load <4 x i64>, ptr %accptr
+ %u = load <8 x i32>, ptr %uptr
+ %s = load <8 x i32>, ptr %sptr
+ %u.wide = zext <8 x i32> %u to <8 x i64>
+ %s.wide = zext <8 x i32> %s to <8 x i64>
+ %mult = mul nuw nsw <8 x i64> %s.wide, %u.wide
+ %partial.reduce = tail call <4 x i64> @llvm.experimental.vector.partial.reduce.add(<4 x i64> %acc, <8 x i64> %mult)
+ ret <4 x i64> %partial.reduce
+}
+
+define <4 x i64> @two_way_i32_i64_vl256(ptr %accptr, ptr %uptr, ptr %sptr) vscale_range(2,2) {
+;
+;
+; NEON-LABEL: two_way_i32_i64_vl256:
+; NEON: // %bb.0:
+; NEON-NEXT: ldp q0, q1, [x0]
+; NEON-NEXT: ldp q2, q3, [x1]
+; NEON-NEXT: ldp q4, q5, [x2]
+; NEON-NEXT: umlal v0.2d, v4.2s, v2.2s
+; NEON-NEXT: umlal v1.2d, v5.2s, v3.2s
+; NEON-NEXT: umlal2 v0.2d, v4.4s, v2.4s
+; NEON-NEXT: umlal2 v1.2d, v5.4s, v3.4s
+; NEON-NEXT: ret
+;
+; SVE-LABEL: two_way_i32_i64_vl256:
+; SVE: // %bb.0:
+; SVE-NEXT: ldr z0, [x1]
+; SVE-NEXT: ldr z1, [x2]
+; SVE-NEXT: ptrue p0.d
+; SVE-NEXT: ldr z4, [x0]
+; SVE-NEXT: uunpklo z2.d, z0.s
+; SVE-NEXT: uunpklo z3.d, z1.s
+; SVE-NEXT: uunpkhi z0.d, z0.s
+; SVE-NEXT: uunpkhi z1.d, z1.s
+; SVE-NEXT: mad z2.d, p0/m, z3.d, z4.d
+; SVE-NEXT: mad z0.d, p0/m, z1.d, z2.d
+; SVE-NEXT: mov z1.d, z0.d
+; SVE-NEXT: ext z1.b, z1.b, z0.b, #16
+; SVE-NEXT: // kill: def $q0 killed $q0 killed $z0
+; SVE-NEXT: // kill: def $q1 killed $q1 killed $z1
+; SVE-NEXT: ret
+;
+; SME-LABEL: two_way_i32_i64_vl256:
+; SME: // %bb.0:
+; SME-NEXT: ldr z0, [x0]
+; SME-NEXT: ldr z1, [x1]
+; SME-NEXT: ldr z2, [x2]
+; SME-NEXT: umlalb z0.d, z2.s, z1.s
+; SME-NEXT: umlalt z0.d, z2.s, z1.s
+; SME-NEXT: mov z1.d, z0.d
+; SME-NEXT: ext z1.b, z1.b, z0.b, #16
+; SME-NEXT: // kill: def $q0 killed $q0 killed $z0
+; SME-NEXT: // kill: def $q1 killed $q1 killed $z1
+; SME-NEXT: ret
+ %acc = load <4 x i64>, ptr %accptr
+ %u = load <8 x i32>, ptr %uptr
+ %s = load <8 x i32>, ptr %sptr
+ %u.wide = zext <8 x i32> %u to <8 x i64>
+ %s.wide = zext <8 x i32> %s to <8 x i64>
+ %mult = mul nuw nsw <8 x i64> %s.wide, %u.wide
+ %partial.reduce = tail call <4 x i64> @llvm.experimental.vector.partial.reduce.add(<4 x i64> %acc, <8 x i64> %mult)
+ ret <4 x i64> %partial.reduce
+}
+
+
+;
+; Four-way dot (i8 -> i32)
+;
+
+define <4 x i32> @four_way_i8_i32_vl128(ptr %accptr, ptr %uptr, ptr %sptr) {
+;
+; COMMON-LABEL: four_way_i8_i32_vl128:
+; COMMON: // %bb.0:
+; COMMON-NEXT: ldr q0, [x0]
+; COMMON-NEXT: ldr q1, [x1]
+; COMMON-NEXT: ldr q2, [x2]
+; COMMON-NEXT: udot v0.4s, v2.16b, v1.16b
+; COMMON-NEXT: ret
+;
+; SME-LABEL: four_way_i8_i32_vl128:
+; SME: // %bb.0:
+; SME-NEXT: ldr q0, [x0]
+; SME-NEXT: ldr q1, [x1]
+; SME-NEXT: ldr q2, [x2]
+; SME-NEXT: udot z0.s, z2.b, z1.b
+; SME-NEXT: // kill: def $q0 killed $q0 killed $z0
+; SME-NEXT: ret
+ %acc = load <4 x i32>, ptr %accptr
+ %u = load <16 x i8>, ptr %uptr
+ %s = load <16 x i8>, ptr %sptr
+ %u.wide = zext <16 x i8> %u to <16 x i32>
+ %s.wide = zext <16 x i8> %s to <16 x i32>
+ %mult = mul nuw nsw <16 x i32> %s.wide, %u.wide
+ %partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add(<4 x i32> %acc, <16 x i32> %mult)
+ ret <4 x i32> %partial.reduce
+}
+
+define <8 x i32> @four_way_i8_i32_vl128_double_width(ptr %accptr, ptr %uptr, ptr %sptr) {
+;
+; COMMON-LABEL: four_way_i8_i32_vl128_double_width:
+; COMMON: // %bb.0:
+; COMMON-NEXT: ldp q0, q1, [x0]
+; COMMON-NEXT: ldp q3, q2, [x1]
+; COMMON-NEXT: ldp q5, q4, [x2]
+; COMMON-NEXT: udot v0.4s, v5.16b, v3.16b
+; COMMON-NEXT: udot v1.4s, v4.16b, v2.16b
+; COMMON-NEXT: ret
+;
+; SME-LABEL: four_way_i8_i32_vl128_double_width:
+; SME: // %bb.0:
+; SME-NEXT: ldp q0, q1, [x0]
+; SME-NEXT: ldp q3, q2, [x1]
+; SME-NEXT: ldp q5, q4, [x2]
+; SME-NEXT: udot z0.s, z5.b, z3.b
+; SME-NEXT: udot z1.s, z4.b, z2.b
+; SME-NEXT: // kill: def $q0 killed $q0 killed $z0
+; SME-NEXT: // kill: def $q1 killed $q1 killed $z1
+; SME-NEXT: ret
+ %acc = load <8 x i32>, ptr %accptr
+ %u = load <32 x i8>, ptr %uptr
+ %s = load <32 x i8>, ptr %sptr
+ %u.wide = zext <32 x i8> %u to <32 x i32>
+ %s.wide = zext <32 x i8> %s to <32 x i32>
+ %mult = mul nuw nsw <32 x i32> %s.wide, %u.wide
+ %partial.reduce = tail call <8 x i32> @llvm.experimental.vector.partial.reduce.add(<8 x i32> %acc, <32 x i32> %mult)
+ ret <8 x i32> %partial.reduce
+}
+
+define <8 x i32> @four_way_i8_i32_vl256(ptr %accptr, ptr %uptr, ptr %sptr) vscale_range(2,2) {
+;
+;
+; NEON-LABEL: four_way_i8_i32_vl256:
+; NEON: // %bb.0:
+; NEON-NEXT: ldp q0, q1, [x0]
+; NEON-NEXT: ldp q3, q2, [x1]
+; NEON-NEXT: ldp q5, q4, [x2]
+; NEON-NEXT: udot v0.4s, v5.16b, v3.16b
+; NEON-NEXT: udot v1.4s, v4.16b, v2.16b
+; NEON-NEXT: ret
+;
+; SVE-LABEL: four_way_i8_i32_vl256:
+; SVE: // %bb.0:
+; SVE-NEXT: ldr z0, [x0]
+; SVE-NEXT: ldr z1, [x1]
+; SVE-NEXT: ldr z2, [x2]
+; SVE-NEXT: udot z0.s, z2.b, z1.b
+; SVE-NEXT: mov z1.d, z0.d
+; SVE-NEXT: ext z1.b, z1.b, z0.b, #16
+; SVE-NEXT: // kill: def $q0 killed $q0 killed $z0
+; SVE-NEXT: // kill: def $q1 killed $q1 killed $z1
+; SVE-NEXT: ret
+;
+; SME-LABEL: four_way_i8_i32_vl256:
+; SME: // %bb.0:
+; SME-NEXT: ldr z0, [x0]
+; SME-NEXT: ldr z1, [x1]
+; SME-NEXT: ldr z2, [x2]
+; SME-NEXT: udot z0.s, z2.b, z1.b
+; SME-NEXT: mov z1.d, z0.d
+; SME-NEXT: ext z1.b, z1.b, z0.b, #16
+; SME-NEXT: // kill: def $q0 killed $q0 killed $z0
+; SME-NEXT: // kill: def $q1 killed $q1 killed $z1
+; SME-NEXT: ret
+ %acc = load <8 x i32>, ptr %accptr
+ %u = load <32 x i8>, ptr %uptr
+ %s = load <32 x i8>, ptr %sptr
+ %u.wide = zext <32 x i8> %u to <32 x i32>
+ %s.wide = zext <32 x i8> %s to <32 x i32>
+ %mult = mul nuw nsw <32 x i32> %s.wide, %u.wide
+ %partial.reduce = tail call <8 x i32> @llvm.experimental.vector.partial.reduce.add(<8 x i32> %acc, <32 x i32> %mult)
+ ret <8 x i32> %partial.reduce
+}
+
+;
+; Four-way dot (i16 -> i64)
+;
+
+define <2 x i64> @four_way_i16_i64_vl128(ptr %accptr, ptr %uptr, ptr %sptr) {
+;
+; COMMON-LABEL: four_way_i16_i64_vl128:
+; COMMON: // %bb.0:
+; COMMON-NEXT: ldr q0, [x1]
+; COMMON-NEXT: ldr q1, [x2]
+; COMMON-NEXT: ldr q3, [x0]
+; COMMON-NEXT: umull v2.4s, v1.4h, v0.4h
+; COMMON-NEXT: umull2 v0.4s, v1.8h, v0.8h
+; COMMON-NEXT: uaddw v3.2d, v3.2d, v2.2s
+; COMMON-NEXT: uaddw2 v1.2d, v3.2d, v2.4s
+; COMMON-NEXT: uaddw v1.2d, v1.2d, v0.2s
+; COMMON-NEXT: uaddw2 v0.2d, v1.2d, v0.4s
+; COMMON-NEXT: ret
+;
+; SME-LABEL: four_way_i16_i64_vl128:
+; SME: // %bb.0:
+; SME-NEXT: ldr q0, [x0]
+; SME-NEXT: ldr q1, [x1]
+; SME-NEXT: ldr q2, [x2]
+; SME-NEXT: udot z0.d, z2.h, z1.h
+; SME-NEXT: // kill: def $q0 killed $q0 killed $z0
+; SME-NEXT: ret
+ %acc = load <2 x i64>, ptr %accptr
+ %u = load <8 x i16>, ptr %uptr
+ %s = load <8 x i16>, ptr %sptr
+ %u.wide = zext <8 x i16> %u to <8 x i64>
+ %s.wide = zext <8 x i16> %s to <8 x i64>
+ %mult = mul nuw nsw <8 x i64> %s.wide, %u.wide
+ %partial.reduce = tail call <2 x i64> @llvm.experimental.vector.partial.reduce.add(<2 x i64> %acc, <8 x i64> %mult)
+ ret <2 x i64> %partial.reduce
+}
+
+define <4 x i64> @four_way_i16_i64_vl128_double_width(ptr %accptr, ptr %uptr, ptr %sptr) {
+;
+; COMMON-LABEL: four_way_i16_i64_vl128_double_width:
+; COMMON: // %bb.0:
+; COMMON-NEXT: ldp q0, q1, [x1]
+; COMMON-NEXT: ldp q2, q3, [x2]
+; COMMON-NEXT: ldp q7, q6, [x0]
+; COMMON-NEXT: umull v4.4s, v3.4h, v1.4h
+; COMMON-NEXT: umull v5.4s, v2.4h, v0.4h
+; COMMON-NEXT: umull2 v1.4s, v3.8h, v1.8h
+; COMMON-NEXT: umull2 v0.4s, v2.8h, v0.8h
+; COMMON-NEXT: uaddw v7.2d, v7.2d, v5.2s
+; COMMON-NEXT: uaddw v6.2d, v6.2d, v4.2s
+; COMMON-NEXT: uaddw2 v2.2d, v7.2d, v5.4s
+; COMMON-NEXT: uaddw2 v3.2d, v6.2d, v4.4s
+; COMMON-NEXT: uaddw v2.2d, v2.2d, v0.2s
+; COMMON-NEXT: uaddw v3.2d, v3.2d, v1.2s
+; COMMON-NEXT: uaddw2 v0.2d, v2.2d, v0.4s
+; COMMON-NEXT: uaddw2 v1.2d, v3.2d, v1.4s
+; COMMON-NEXT: ret
+;
+; SME-LABEL: four_way_i16_i64_vl128_double_width:
+; SME: // %bb.0:
+; SME-NEXT: ldp q0, q1, [x0]
+; SME-NEXT: ldp q3, q2, [x1]
+; SME-NEXT: ldp q5, q4, [x2]
+; SME-NEXT: udot z0.d, z5.h, z3.h
+; SME-NEXT: udot z1.d, z4.h, z2.h
+; SME-NEXT: // kill: def $q0 killed $q0 killed $z0
+; SME-NEXT: // kill: def $q1 killed $q1 killed $z1
+; SME-NEXT: ret
+ %acc = load <4 x i64>, ptr %accptr
+ %u = load <16 x i16>, ptr %uptr
+ %s = load <16 x i16>, ptr %sptr
+ %u.wide = zext <16 x i16> %u to <16 x i64>
+ %s.wide = zext <16 x i16> %s to <16 x i64>
+ %mult = mul nuw nsw <16 x i64> %s.wide, %u.wide
+ %partial.reduce = tail call <4 x i64> @llvm.experimental.vector.partial.reduce.add(<4 x i64> %acc, <16 x i64> %mult)
+ ret <4 x i64> %partial.reduce
+}
+
+define <4 x i64> @four_way_i16_i64_vl256(ptr %accptr, ptr %uptr, ptr %sptr) vscale_range(2,2) {
+;
+;
+; NEON-LABEL: four_way_i16_i64_vl256:
+; NEON: // %bb.0:
+; NEON-NEXT: ldp q0, q1, [x1]
+; NEON-NEXT: ldp q2, q3, [x2]
+; NEON-NEXT: ldp q7, q6, [x0]
+; NEON-NEXT: umull v4.4s, v3.4h, v1.4h
+; NEON-NEXT: umull v5.4s, v2.4h, v0.4h
+; NEON-NEXT: umull2 v1.4s, v3.8h, v1.8h
+; NEON-NEXT: umull2 v0.4s, v2.8h, v0.8h
+; NEON-NEXT: uaddw v7.2d, v7.2d, v5.2s
+; NEON-NEXT: uaddw v6.2d, v6.2d, v4.2s
+; NEON-NEXT: uaddw2 v2.2d, v7.2d, v5.4s
+; NEON-NEXT: uaddw2 v3.2d, v6.2d, v4.4s
+; NEON-NEXT: uaddw v2.2d, v2.2d, v0.2s
+; NEON-NEXT: uaddw v3.2d, v3.2d, v1.2s
+; NEON-NEXT: uaddw2 v0.2d, v2.2d, v0.4s
+; NEON-NEXT: uaddw2 v1.2d, v3.2d, v1.4s
+; NEON-NEXT: ret
+;
+; SVE-LABEL: four_way_i16_i64_vl256:
+; SVE: // %bb.0:
+; SVE-NEXT: ldr z0, [x0]
+; SVE-NEXT: ldr z1, [x1]
+; SVE-NEXT: ldr z2, [x2]
+; SVE-NEXT: udot z0.d, z2.h, z1.h
+; SVE-NEXT: mov z1.d, z0.d
+; SVE-NEXT: ext z1.b, z1.b, z0.b, #16
+; SVE-NEXT: // kill: def $q0 killed $q0 killed $z0
+; SVE-NEXT: // kill: def $q1 killed $q1 killed $z1
+; SVE-NEXT: ret
+;
+; SME-LABEL: four_way_i16_i64_vl256:
+; SME: // %bb.0:
+; SME-NEXT: ldr z0, [x0]
+; SME-NEXT: ldr z1, [x1]
+; SME-NEXT: ldr z2, [x2]
+; SME-NEXT: udot z0.d, z2.h, z1.h
+; SME-NEXT: mov z1.d, z0.d
+; SME-NEXT: ext z1.b, z1.b, z0.b, #16
+; SME-NEXT: // kill: def $q0 killed $q0 killed $z0
+; SME-NEXT: // kill: def $q1 killed $q1 killed $z1
+; SME-NEXT: ret
+ %acc = load <4 x i64>, ptr %accptr
+ %u = load <16 x i16>, ptr %uptr
+ %s = load <16 x i16>, ptr %sptr
+ %u.wide = zext <16 x i16> %u to <16 x i64>
+ %s.wide = zext <16 x i16> %s to <16 x i64>
+ %mult = mul nuw nsw <16 x i64> %s.wide, %u.wide
+ %partial.reduce = tail call <4 x i64> @llvm.experimental.vector.partial.reduce.add(<4 x i64> %acc, <16 x i64> %mult)
+ ret <4 x i64> %partial.reduce
+}
+
+
+;
+; Eight-way dot, requires two steps (i8 -> i64)
+;
+
+define <2 x i64> @eight_way_i8_i64_vl128(ptr %accptr, ptr %uptr, ptr %sptr) {
+;
+; NEON-LABEL: eight_way_i8_i64_vl128:
+; NEON: // %bb.0:
+; NEON-NEXT: movi v0.2d, #0000000000000000
+; NEON-NEXT: ldr q1, [x1]
+; NEON-NEXT: ldr q2, [x2]
+; NEON-NEXT: udot v0.4s, v2.16b, v1.16b
+; NEON-NEXT: ldr q1, [x0]
+; NEON-NEXT: uaddw v1.2d, v1.2d, v0.2s
+; NEON-NEXT: uaddw2 v0.2d, v1.2d, v0.4s
+; NEON-NEXT: ret
+;
+; SVE-LABEL: eight_way_i8_i64_vl128:
+; SVE: // %bb.0:
+; SVE-NEXT: movi v0.2d, #0000000000000000
+; SVE-NEXT: ldr q1, [x1]
+; SVE-NEXT: ldr q2, [x2]
+; SVE-NEXT: udot z0.s, z2.b, z1.b
+; SVE-NEXT: ldr q2, [x0]
+; SVE-NEXT: uunpklo z1.d, z0.s
+; SVE-NEXT: uunpkhi z0.d, z0.s
+; SVE-NEXT: add z1.d, z2.d, z1.d
+; SVE-NEXT: add z0.d, z1.d, z0.d
+; SVE-NEXT: // kill: def $q0 killed $q0 killed $z0
+; SVE-NEXT: ret
+;
+; SME-LABEL: eight_way_i8_i64_vl128:
+; SME: // %bb.0:
+; SME-NEXT: mov z0.s, #0 // =0x0
+; SME-NEXT: ldr q1, [x1]
+; SME-NEXT: ldr q2, [x2]
+; SME-NEXT: udot z0.s, z2.b, z1.b
+; SME-NEXT: ldr q1, [x0]
+; SME-NEXT: uaddwb z1.d, z1.d, z0.s
+; SME-NEXT: uaddwt z0.d, z1.d, z0.s
+; SME-NEXT: // kill: def $q0 killed $q0 killed $z0
+; SME-NEXT: ret
+ %acc = load <2 x i64>, ptr %accptr
+ %u = load <16 x i8>, ptr %uptr
+ %s = load <16 x i8>, ptr %sptr
+ %u.wide = zext <16 x i8> %u to <16 x i64>
+ %s.wide = zext <16 x i8> %s to <16 x i64>
+ %mult = mul nuw nsw <16 x i64> %s.wide, %u.wide
+ %partial.reduce = tail call <2 x i64> @llvm.experimental.vector.partial.reduce.add(<2 x i64> %acc, <16 x i64> %mult)
+ ret <2 x i64> %partial.reduce
+}
+
+define <4 x i64> @four_way_i8_i64_vl128_double_width(ptr %accptr, ptr %uptr, ptr %sptr) {
+;
+; NEON-LABEL: four_way_i8_i64_vl128_double_width:
+; NEON: // %bb.0:
+; NEON-NEXT: movi v1.2d, #0000000000000000
+; NEON-NEXT: movi v0.2d, #0000000000000000
+; NEON-NEXT: ldp q3, q2, [x1]
+; NEON-NEXT: ldp q5, q4, [x2]
+; NEON-NEXT: udot v0.4s, v5.16b, v3.16b
+; NEON-NEXT: udot v1.4s, v4.16b, v2.16b
+; NEON-NEXT: ldp q3, q2, [x0]
+; NEON-NEXT: uaddw v3.2d, v3.2d, v0.2s
+; NEON-NEXT: uaddw v2.2d, v2.2d, v1.2s
+; NEON-NEXT: uaddw2 v0.2d, v3.2d, v0.4s
+; NEON-NEXT: uaddw2 v1.2d, v2.2d, v1.4s
+; NEON-NEXT: ret
+;
+; SVE-LABEL: four_way_i8_i64_vl128_double_width:
+; SVE: // %bb.0:
+; SVE-NEXT: movi v0.2d, #0000000000000000
+; SVE-NEXT: movi v1.2d, #0000000000000000
+; SVE-NEXT: ldp q3, q2, [x1]
+; SVE-NEXT: ldp q5, q4, [x2]
+; SVE-NEXT: udot z1.s, z5.b, z3.b
+; SVE-NEXT: udot z0.s, z4.b, z2.b
+; SVE-NEXT: ldp q5, q4, [x0]
+; SVE-NEXT: uunpklo z2.d, z1.s
+; SVE-NEXT: uunpklo z3.d, z0.s
+; SVE-NEXT: uunpkhi z1.d, z1.s
+; SVE-NEXT: uunpkhi z6.d, z0.s
+; SVE-NEXT: add z0.d, z5.d, z2.d
+; SVE-NEXT: add z2.d, z4.d, z3.d
+; SVE-NEXT: add z0.d, z0.d, z1.d
+; SVE-NEXT: add z1.d, z2.d, z6.d
+; SVE-NEXT: // kill: def $q0 killed $q0 killed $z0
+; SVE-NEXT: // kill: def $q1 killed $q1 killed $z1
+; SVE-NEXT: ret
+;
+; SME-LABEL: four_way_i8_i64_vl128_double_width:
+; SME: // %bb.0:
+; SME-NEXT: mov z1.s, #0 // =0x0
+; SME-NEXT: mov z0.s, #0 // =0x0
+; SME-NEXT: ldp q3, q2, [x1]
+; SME-NEXT: ldp q5, q4, [x2]
+; SME-NEXT: udot z0.s, z5.b, z3.b
+; SME-NEXT: udot z1.s, z4.b, z2.b
+; SME-NEXT: ldp q3, q2, [x0]
+; SME-NEXT: uaddwb z3.d, z3.d, z0.s
+; SME-NEXT: uaddwb z2.d, z2.d, z1.s
+; SME-NEXT: uaddwt z0.d, z3.d, z0.s
+; SME-NEXT: uaddwt z1.d, z2.d, z1.s
+; SME-NEXT: // kill: def $q0 killed $q0 killed $z0
+; SME-NEXT: // kill: def $q1 killed $q1 killed $z1
+; SME-NEXT: ret
+ %acc = load <4 x i64>, ptr %accptr
+ %u = load <32 x i8>, ptr %uptr
+ %s = load <32 x i8>, ptr %sptr
+ %u.wide = zext <32 x i8> %u to <32 x i64>
+ %s.wide = zext <32 x i8> %s to <32 x i64>
+ %mult = mul nuw nsw <32 x i64> %s.wide, %u.wide
+ %partial.reduce = tail call <4 x i64> @llvm.experimental.vector.partial.reduce.add(<4 x i64> %acc, <32 x i64> %mult)
+ ret <4 x i64> %partial.reduce
+}
+
+define <4 x i64> @four_way_i8_i64_vl256(ptr %accptr, ptr %uptr, ptr %sptr) vscale_range(2,2) {
+; NEON-LABEL: four_way_i8_i64_vl256:
+; NEON: // %bb.0:
+; NEON-NEXT: movi v1.2d, #0000000000000000
+; NEON-NEXT: movi v0.2d, #0000000000000000
+; NEON-NEXT: ldp q3, q2, [x1]
+; NEON-NEXT: ldp q5, q4, [x2]
+; NEON-NEXT: udot v0.4s, v5.16b, v3.16b
+; NEON-NEXT: udot v1.4s, v4.16b, v2.16b
+; NEON-NEXT: ldp q3, q2, [x0]
+; NEON-NEXT: uaddw v3.2d, v3.2d, v0.2s
+; NEON-NEXT: uaddw v2.2d, v2.2d, v1.2s
+; NEON-NEXT: uaddw2 v0.2d, v3.2d, v0.4s
+; NEON-NEXT: uaddw2 v1.2d, v2.2d, v1.4s
+; NEON-NEXT: ret
+;
+; SVE-LABEL: four_way_i8_i64_vl256:
+; SVE: // %bb.0:
+; SVE-NEXT: movi v0.2d, #0000000000000000
+; SVE-NEXT: ldr z1, [x1]
+; SVE-NEXT: ldr z2, [x2]
+; SVE-NEXT: udot z0.s, z2.b, z1.b
+; SVE-NEXT: ldr z2, [x0]
+; SVE-NEXT: uunpklo z1.d, z0.s
+; SVE-NEXT: uunpkhi z0.d, z0.s
+; SVE-NEXT: add z1.d, z2.d, z1.d
+; SVE-NEXT: add z0.d, z1.d, z0.d
+; SVE-NEXT: mov z1.d, z0.d
+; SVE-NEXT: ext z1.b, z1.b, z0.b, #16
+; SVE-NEXT: // kill: def $q0 killed $q0 killed $z0
+; SVE-NEXT: // kill: def $q1 killed $q1 killed $z1
+; SVE-NEXT: ret
+;
+; SME-LABEL: four_way_i8_i64_vl256:
+; SME: // %bb.0:
+; SME-NEXT: ldr z0, [x1]
+; SME-NEXT: ldr z1, [x2]
+; SME-NEXT: mov z2.s, #0 // =0x0
+; SME-NEXT: udot z2.s, z1.b, z0.b
+; SME-NEXT: ldr z0, [x0]
+; SME-NEXT: uaddwb z0.d, z0.d, z2.s
+; SME-NEXT: uaddwt z0.d, z0.d, z2.s
+; SME-NEXT: mov z1.d, z0.d
+; SME-NEXT: ext z1.b, z1.b, z0.b, #16
+; SME-NEXT: // kill: def $q0 killed $q0 killed $z0
+; SME-NEXT: // kill: def $q1 killed $q1 killed $z1
+; SME-NEXT: ret
+ %acc = load <4 x i64>, ptr %accptr
+ %u = load <32 x i8>, ptr %uptr
+ %s = load <32 x i8>, ptr %sptr
+ %u.wide = zext <32 x i8> %u to <32 x i64>
+ %s.wide = zext <32 x i8> %s to <32 x i64>
+ %mult = mul nuw nsw <32 x i64> %s.wide, %u.wide
+ %partial.reduce = tail call <4 x i64> @llvm.experimental.vector.partial.reduce.add(<4 x i64> %acc, <32 x i64> %mult)
+ ret <4 x i64> %partial.reduce
+}