[llvm] 12bd049 - [AArch64] Enable fixed-length vector support for partial-reductions (#142032)

via llvm-commits llvm-commits at lists.llvm.org
Fri May 30 09:47:34 PDT 2025


Author: Sander de Smalen
Date: 2025-05-30T17:47:31+01:00
New Revision: 12bd04951054f3b7a4c604d84820ba809b67cedb

URL: https://github.com/llvm/llvm-project/commit/12bd04951054f3b7a4c604d84820ba809b67cedb
DIFF: https://github.com/llvm/llvm-project/commit/12bd04951054f3b7a4c604d84820ba809b67cedb.diff

LOG: [AArch64] Enable fixed-length vector support for partial-reductions (#142032)

This enables the use of the [us]dot, [us]addw[bt] and [us]mlal[bt]
instructions in Streaming mode, and for wider vectors when the runtime
vector length is known to be 256 bits or larger.
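
For illustration, a partial reduction of the following form (adapted from the
four_way_i8_i32 pattern in the new test below; the function and value names
are illustrative) now lowers to a single udot when built with
-mattr=+sme -force-streaming -aarch64-enable-partial-reduce-nodes=true:

  ; Illustrative sketch only, modelled on sve-fixed-length-partial-reduce.ll.
  define <4 x i32> @four_way_i8_i32_example(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) {
    ; Widen both i8 inputs, multiply, then accumulate four-way into the i32 accumulator.
    %u.wide = zext <16 x i8> %u to <16 x i32>
    %s.wide = zext <16 x i8> %s to <16 x i32>
    %mult = mul nuw nsw <16 x i32> %s.wide, %u.wide
    %red = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add(<4 x i32> %acc, <16 x i32> %mult)
    ret <4 x i32> %red
  }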

Added: 
    llvm/test/CodeGen/AArch64/sve-fixed-length-partial-reduce.ll

Modified: 
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Removed: 
    


################################################################################
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 349bcd95c09f6..ef68734fa039c 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1935,6 +1935,18 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
                          Custom);
       setOperationAction(ISD::EXPERIMENTAL_VECTOR_HISTOGRAM, MVT::nxv2i64,
                          Custom);
+
+      if (EnablePartialReduceNodes) {
+        static const unsigned MLAOps[] = {ISD::PARTIAL_REDUCE_SMLA,
+                                          ISD::PARTIAL_REDUCE_UMLA};
+        // Must be lowered to SVE instructions.
+        setPartialReduceMLAAction(MLAOps, MVT::v2i64, MVT::v4i32, Custom);
+        setPartialReduceMLAAction(MLAOps, MVT::v2i64, MVT::v8i16, Custom);
+        setPartialReduceMLAAction(MLAOps, MVT::v2i64, MVT::v16i8, Custom);
+        setPartialReduceMLAAction(MLAOps, MVT::v4i32, MVT::v8i16, Custom);
+        setPartialReduceMLAAction(MLAOps, MVT::v4i32, MVT::v16i8, Custom);
+        setPartialReduceMLAAction(MLAOps, MVT::v8i16, MVT::v16i8, Custom);
+      }
     }
   }
 
@@ -2230,6 +2242,28 @@ void AArch64TargetLowering::addTypeForFixedLengthSVE(MVT VT) {
   bool PreferNEON = VT.is64BitVector() || VT.is128BitVector();
   bool PreferSVE = !PreferNEON && Subtarget->isSVEAvailable();
 
+  if (EnablePartialReduceNodes) {
+    static const unsigned MLAOps[] = {ISD::PARTIAL_REDUCE_SMLA,
+                                      ISD::PARTIAL_REDUCE_UMLA};
+    unsigned NumElts = VT.getVectorNumElements();
+    if (VT.getVectorElementType() == MVT::i64) {
+      setPartialReduceMLAAction(MLAOps, VT,
+                                MVT::getVectorVT(MVT::i8, NumElts * 8), Custom);
+      setPartialReduceMLAAction(
+          MLAOps, VT, MVT::getVectorVT(MVT::i16, NumElts * 4), Custom);
+      setPartialReduceMLAAction(
+          MLAOps, VT, MVT::getVectorVT(MVT::i32, NumElts * 2), Custom);
+    } else if (VT.getVectorElementType() == MVT::i32) {
+      setPartialReduceMLAAction(MLAOps, VT,
+                                MVT::getVectorVT(MVT::i8, NumElts * 4), Custom);
+      setPartialReduceMLAAction(
+          MLAOps, VT, MVT::getVectorVT(MVT::i16, NumElts * 2), Custom);
+    } else if (VT.getVectorElementType() == MVT::i16) {
+      setPartialReduceMLAAction(MLAOps, VT,
+                                MVT::getVectorVT(MVT::i8, NumElts * 2), Custom);
+    }
+  }
+
   // Lower fixed length vector operations to scalable equivalents.
   setOperationAction(ISD::ABDS, VT, Default);
   setOperationAction(ISD::ABDU, VT, Default);
@@ -29251,50 +29285,61 @@ SDValue AArch64TargetLowering::LowerVECTOR_HISTOGRAM(SDValue Op,
 SDValue
 AArch64TargetLowering::LowerPARTIAL_REDUCE_MLA(SDValue Op,
                                                SelectionDAG &DAG) const {
-  bool Scalable = Op.getValueType().isScalableVector();
-
-  assert((!Scalable || Subtarget->isSVEorStreamingSVEAvailable()) &&
-         "SVE or StreamingSVE must be available when using scalable vectors.");
-  assert((Scalable || Subtarget->hasDotProd()) &&
-         "Dotprod must be available when targeting NEON dot product "
-         "instructions.");
-
   SDLoc DL(Op);
 
   SDValue Acc = Op.getOperand(0);
   SDValue LHS = Op.getOperand(1);
   SDValue RHS = Op.getOperand(2);
   EVT ResultVT = Op.getValueType();
+  EVT OrigResultVT = ResultVT;
+  EVT OpVT = LHS.getValueType();
 
-  assert((Scalable && ResultVT == MVT::nxv2i64 &&
-          LHS.getValueType() == MVT::nxv16i8) ||
-         (!Scalable && ResultVT == MVT::v2i64 &&
-          LHS.getValueType() == MVT::v16i8));
+  bool ConvertToScalable =
+      ResultVT.isFixedLengthVector() &&
+      useSVEForFixedLengthVectorVT(ResultVT, /*OverrideNEON=*/true);
 
-  EVT DotVT = Scalable ? MVT::nxv4i32 : MVT::v4i32;
+  if (ConvertToScalable) {
+    ResultVT = getContainerForFixedLengthVector(DAG, ResultVT);
+    OpVT = getContainerForFixedLengthVector(DAG, LHS.getValueType());
+    Acc = convertToScalableVector(DAG, ResultVT, Acc);
+    LHS = convertToScalableVector(DAG, OpVT, LHS);
+    RHS = convertToScalableVector(DAG, OpVT, RHS);
+    Op = DAG.getNode(Op.getOpcode(), DL, ResultVT, {Acc, LHS, RHS});
+  }
+
+  // Two-way and four-way partial reductions are supported by patterns.
+  // We only need to handle the 8-way partial reduction.
+  if (ResultVT.getScalarType() != MVT::i64 || OpVT.getScalarType() != MVT::i8)
+    return ConvertToScalable ? convertFromScalableVector(DAG, OrigResultVT, Op)
+                             : Op;
+
+  EVT DotVT = ResultVT.isScalableVector() ? MVT::nxv4i32 : MVT::v4i32;
   SDValue DotNode = DAG.getNode(Op.getOpcode(), DL, DotVT,
                                 DAG.getConstant(0, DL, DotVT), LHS, RHS);
 
+  SDValue Res;
   bool IsUnsigned = Op.getOpcode() == ISD::PARTIAL_REDUCE_UMLA;
-  if (Scalable &&
-      (Subtarget->hasSVE2() || Subtarget->isStreamingSVEAvailable())) {
+  if (Subtarget->hasSVE2() || Subtarget->isStreamingSVEAvailable()) {
     unsigned LoOpcode = IsUnsigned ? AArch64ISD::UADDWB : AArch64ISD::SADDWB;
     unsigned HiOpcode = IsUnsigned ? AArch64ISD::UADDWT : AArch64ISD::SADDWT;
     SDValue Lo = DAG.getNode(LoOpcode, DL, ResultVT, Acc, DotNode);
-    return DAG.getNode(HiOpcode, DL, ResultVT, Lo, DotNode);
-  }
-
-  // Fold (nx)v4i32 into (nx)v2i64
-  auto [DotNodeLo, DotNodeHi] = DAG.SplitVector(DotNode, DL);
-  if (IsUnsigned) {
-    DotNodeLo = DAG.getZExtOrTrunc(DotNodeLo, DL, ResultVT);
-    DotNodeHi = DAG.getZExtOrTrunc(DotNodeHi, DL, ResultVT);
+    Res = DAG.getNode(HiOpcode, DL, ResultVT, Lo, DotNode);
   } else {
-    DotNodeLo = DAG.getSExtOrTrunc(DotNodeLo, DL, ResultVT);
-    DotNodeHi = DAG.getSExtOrTrunc(DotNodeHi, DL, ResultVT);
+    // Fold (nx)v4i32 into (nx)v2i64
+    auto [DotNodeLo, DotNodeHi] = DAG.SplitVector(DotNode, DL);
+    if (IsUnsigned) {
+      DotNodeLo = DAG.getZExtOrTrunc(DotNodeLo, DL, ResultVT);
+      DotNodeHi = DAG.getZExtOrTrunc(DotNodeHi, DL, ResultVT);
+    } else {
+      DotNodeLo = DAG.getSExtOrTrunc(DotNodeLo, DL, ResultVT);
+      DotNodeHi = DAG.getSExtOrTrunc(DotNodeHi, DL, ResultVT);
+    }
+    auto Lo = DAG.getNode(ISD::ADD, DL, ResultVT, Acc, DotNodeLo);
+    Res = DAG.getNode(ISD::ADD, DL, ResultVT, Lo, DotNodeHi);
   }
-  auto Lo = DAG.getNode(ISD::ADD, DL, ResultVT, Acc, DotNodeLo);
-  return DAG.getNode(ISD::ADD, DL, ResultVT, Lo, DotNodeHi);
+
+  return ConvertToScalable ? convertFromScalableVector(DAG, OrigResultVT, Res)
+                           : Res;
 }
 
 SDValue

diff --git a/llvm/test/CodeGen/AArch64/sve-fixed-length-partial-reduce.ll b/llvm/test/CodeGen/AArch64/sve-fixed-length-partial-reduce.ll
new file mode 100644
index 0000000000000..79d766d1b9908
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve-fixed-length-partial-reduce.ll
@@ -0,0 +1,791 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mattr=+dotprod -aarch64-enable-partial-reduce-nodes=true < %s | FileCheck %s --check-prefixes=COMMON,NEON
+; RUN: llc -mattr=+sve,+dotprod -aarch64-enable-partial-reduce-nodes=true < %s | FileCheck %s --check-prefixes=COMMON,SVE
+; RUN: llc -mattr=+sme -aarch64-enable-partial-reduce-nodes=true -force-streaming < %s | FileCheck %s --check-prefix=SME
+
+target triple = "aarch64"
+
+;
+; Two-way mla (i8 -> i16)
+;
+
+define <8 x i16> @two_way_i8_i16_vl128(ptr %accptr, ptr %uptr, ptr %sptr) {
+;
+; COMMON-LABEL: two_way_i8_i16_vl128:
+; COMMON:       // %bb.0:
+; COMMON-NEXT:    ldr q0, [x0]
+; COMMON-NEXT:    ldr q1, [x1]
+; COMMON-NEXT:    ldr q2, [x2]
+; COMMON-NEXT:    umlal v0.8h, v2.8b, v1.8b
+; COMMON-NEXT:    umlal2 v0.8h, v2.16b, v1.16b
+; COMMON-NEXT:    ret
+;
+; SME-LABEL: two_way_i8_i16_vl128:
+; SME:       // %bb.0:
+; SME-NEXT:    ldr q0, [x0]
+; SME-NEXT:    ldr q1, [x1]
+; SME-NEXT:    ldr q2, [x2]
+; SME-NEXT:    umlalb z0.h, z2.b, z1.b
+; SME-NEXT:    umlalt z0.h, z2.b, z1.b
+; SME-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; SME-NEXT:    ret
+  %acc = load <8 x i16>, ptr %accptr
+  %u = load <16 x i8>, ptr %uptr
+  %s = load <16 x i8>, ptr %sptr
+  %u.wide = zext <16 x i8> %u to <16 x i16>
+  %s.wide = zext <16 x i8> %s to <16 x i16>
+  %mult = mul nuw nsw <16 x i16> %s.wide, %u.wide
+  %partial.reduce = tail call <8 x i16> @llvm.experimental.vector.partial.reduce.add(<8 x i16> %acc, <16 x i16> %mult)
+  ret <8 x i16> %partial.reduce
+}
+
+define <16 x i16> @two_way_i8_i16_vl128_double_width(ptr %accptr, ptr %uptr, ptr %sptr) {
+;
+; COMMON-LABEL: two_way_i8_i16_vl128_double_width:
+; COMMON:       // %bb.0:
+; COMMON-NEXT:    ldp q0, q1, [x0]
+; COMMON-NEXT:    ldp q2, q3, [x1]
+; COMMON-NEXT:    ldp q4, q5, [x2]
+; COMMON-NEXT:    umlal v0.8h, v4.8b, v2.8b
+; COMMON-NEXT:    umlal v1.8h, v5.8b, v3.8b
+; COMMON-NEXT:    umlal2 v0.8h, v4.16b, v2.16b
+; COMMON-NEXT:    umlal2 v1.8h, v5.16b, v3.16b
+; COMMON-NEXT:    ret
+;
+; SME-LABEL: two_way_i8_i16_vl128_double_width:
+; SME:       // %bb.0:
+; SME-NEXT:    ldp q0, q1, [x0]
+; SME-NEXT:    ldp q3, q2, [x1]
+; SME-NEXT:    ldp q5, q4, [x2]
+; SME-NEXT:    umlalb z0.h, z5.b, z3.b
+; SME-NEXT:    umlalb z1.h, z4.b, z2.b
+; SME-NEXT:    umlalt z0.h, z5.b, z3.b
+; SME-NEXT:    umlalt z1.h, z4.b, z2.b
+; SME-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; SME-NEXT:    // kill: def $q1 killed $q1 killed $z1
+; SME-NEXT:    ret
+  %acc = load <16 x i16>, ptr %accptr
+  %u = load <32 x i8>, ptr %uptr
+  %s = load <32 x i8>, ptr %sptr
+  %u.wide = zext <32 x i8> %u to <32 x i16>
+  %s.wide = zext <32 x i8> %s to <32 x i16>
+  %mult = mul nuw nsw <32 x i16> %s.wide, %u.wide
+  %partial.reduce = tail call <16 x i16> @llvm.experimental.vector.partial.reduce.add(<16 x i16> %acc, <32 x i16> %mult)
+  ret <16 x i16> %partial.reduce
+}
+
+define <16 x i16> @two_way_i8_i16_vl256(ptr %accptr, ptr %uptr, ptr %sptr) vscale_range(2,2) {
+;
+;
+; NEON-LABEL: two_way_i8_i16_vl256:
+; NEON:       // %bb.0:
+; NEON-NEXT:    ldp q0, q1, [x0]
+; NEON-NEXT:    ldp q2, q3, [x1]
+; NEON-NEXT:    ldp q4, q5, [x2]
+; NEON-NEXT:    umlal v0.8h, v4.8b, v2.8b
+; NEON-NEXT:    umlal v1.8h, v5.8b, v3.8b
+; NEON-NEXT:    umlal2 v0.8h, v4.16b, v2.16b
+; NEON-NEXT:    umlal2 v1.8h, v5.16b, v3.16b
+; NEON-NEXT:    ret
+;
+; SVE-LABEL: two_way_i8_i16_vl256:
+; SVE:       // %bb.0:
+; SVE-NEXT:    ldr z0, [x1]
+; SVE-NEXT:    ldr z1, [x2]
+; SVE-NEXT:    ptrue p0.h
+; SVE-NEXT:    ldr z4, [x0]
+; SVE-NEXT:    uunpklo z2.h, z0.b
+; SVE-NEXT:    uunpklo z3.h, z1.b
+; SVE-NEXT:    uunpkhi z0.h, z0.b
+; SVE-NEXT:    uunpkhi z1.h, z1.b
+; SVE-NEXT:    mad z2.h, p0/m, z3.h, z4.h
+; SVE-NEXT:    mad z0.h, p0/m, z1.h, z2.h
+; SVE-NEXT:    mov z1.d, z0.d
+; SVE-NEXT:    ext z1.b, z1.b, z0.b, #16
+; SVE-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; SVE-NEXT:    // kill: def $q1 killed $q1 killed $z1
+; SVE-NEXT:    ret
+;
+; SME-LABEL: two_way_i8_i16_vl256:
+; SME:       // %bb.0:
+; SME-NEXT:    ldr z0, [x0]
+; SME-NEXT:    ldr z1, [x1]
+; SME-NEXT:    ldr z2, [x2]
+; SME-NEXT:    umlalb z0.h, z2.b, z1.b
+; SME-NEXT:    umlalt z0.h, z2.b, z1.b
+; SME-NEXT:    mov z1.d, z0.d
+; SME-NEXT:    ext z1.b, z1.b, z0.b, #16
+; SME-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; SME-NEXT:    // kill: def $q1 killed $q1 killed $z1
+; SME-NEXT:    ret
+  %acc = load <16 x i16>, ptr %accptr
+  %u = load <32 x i8>, ptr %uptr
+  %s = load <32 x i8>, ptr %sptr
+  %u.wide = zext <32 x i8> %u to <32 x i16>
+  %s.wide = zext <32 x i8> %s to <32 x i16>
+  %mult = mul nuw nsw <32 x i16> %s.wide, %u.wide
+  %partial.reduce = tail call <16 x i16> @llvm.experimental.vector.partial.reduce.add(<16 x i16> %acc, <32 x i16> %mult)
+  ret <16 x i16> %partial.reduce
+}
+
+;
+; Two-way mla (i16 -> i32)
+;
+
+define <4 x i32> @two_way_i16_i32_vl128(ptr %accptr, ptr %uptr, ptr %sptr) {
+;
+; COMMON-LABEL: two_way_i16_i32_vl128:
+; COMMON:       // %bb.0:
+; COMMON-NEXT:    ldr q0, [x0]
+; COMMON-NEXT:    ldr q1, [x1]
+; COMMON-NEXT:    ldr q2, [x2]
+; COMMON-NEXT:    umlal v0.4s, v2.4h, v1.4h
+; COMMON-NEXT:    umlal2 v0.4s, v2.8h, v1.8h
+; COMMON-NEXT:    ret
+;
+; SME-LABEL: two_way_i16_i32_vl128:
+; SME:       // %bb.0:
+; SME-NEXT:    ldr q0, [x0]
+; SME-NEXT:    ldr q1, [x1]
+; SME-NEXT:    ldr q2, [x2]
+; SME-NEXT:    umlalb z0.s, z2.h, z1.h
+; SME-NEXT:    umlalt z0.s, z2.h, z1.h
+; SME-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; SME-NEXT:    ret
+  %acc = load <4 x i32>, ptr %accptr
+  %u = load <8 x i16>, ptr %uptr
+  %s = load <8 x i16>, ptr %sptr
+  %u.wide = zext <8 x i16> %u to <8 x i32>
+  %s.wide = zext <8 x i16> %s to <8 x i32>
+  %mult = mul nuw nsw <8 x i32> %s.wide, %u.wide
+  %partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add(<4 x i32> %acc, <8 x i32> %mult)
+  ret <4 x i32> %partial.reduce
+}
+
+define <8 x i32> @two_way_i16_i32_vl128_double_width(ptr %accptr, ptr %uptr, ptr %sptr) {
+;
+; COMMON-LABEL: two_way_i16_i32_vl128_double_width:
+; COMMON:       // %bb.0:
+; COMMON-NEXT:    ldp q0, q1, [x0]
+; COMMON-NEXT:    ldp q2, q3, [x1]
+; COMMON-NEXT:    ldp q4, q5, [x2]
+; COMMON-NEXT:    umlal v0.4s, v4.4h, v2.4h
+; COMMON-NEXT:    umlal v1.4s, v5.4h, v3.4h
+; COMMON-NEXT:    umlal2 v0.4s, v4.8h, v2.8h
+; COMMON-NEXT:    umlal2 v1.4s, v5.8h, v3.8h
+; COMMON-NEXT:    ret
+;
+; SME-LABEL: two_way_i16_i32_vl128_double_width:
+; SME:       // %bb.0:
+; SME-NEXT:    ldp q0, q1, [x0]
+; SME-NEXT:    ldp q3, q2, [x1]
+; SME-NEXT:    ldp q5, q4, [x2]
+; SME-NEXT:    umlalb z0.s, z5.h, z3.h
+; SME-NEXT:    umlalb z1.s, z4.h, z2.h
+; SME-NEXT:    umlalt z0.s, z5.h, z3.h
+; SME-NEXT:    umlalt z1.s, z4.h, z2.h
+; SME-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; SME-NEXT:    // kill: def $q1 killed $q1 killed $z1
+; SME-NEXT:    ret
+  %acc = load <8 x i32>, ptr %accptr
+  %u = load <16 x i16>, ptr %uptr
+  %s = load <16 x i16>, ptr %sptr
+  %u.wide = zext <16 x i16> %u to <16 x i32>
+  %s.wide = zext <16 x i16> %s to <16 x i32>
+  %mult = mul nuw nsw <16 x i32> %s.wide, %u.wide
+  %partial.reduce = tail call <8 x i32> @llvm.experimental.vector.partial.reduce.add(<8 x i32> %acc, <16 x i32> %mult)
+  ret <8 x i32> %partial.reduce
+}
+
+define <8 x i32> @two_way_i16_i32_vl256(ptr %accptr, ptr %uptr, ptr %sptr) vscale_range(2,2) {
+;
+;
+; NEON-LABEL: two_way_i16_i32_vl256:
+; NEON:       // %bb.0:
+; NEON-NEXT:    ldp q0, q1, [x0]
+; NEON-NEXT:    ldp q2, q3, [x1]
+; NEON-NEXT:    ldp q4, q5, [x2]
+; NEON-NEXT:    umlal v0.4s, v4.4h, v2.4h
+; NEON-NEXT:    umlal v1.4s, v5.4h, v3.4h
+; NEON-NEXT:    umlal2 v0.4s, v4.8h, v2.8h
+; NEON-NEXT:    umlal2 v1.4s, v5.8h, v3.8h
+; NEON-NEXT:    ret
+;
+; SVE-LABEL: two_way_i16_i32_vl256:
+; SVE:       // %bb.0:
+; SVE-NEXT:    ldr z0, [x1]
+; SVE-NEXT:    ldr z1, [x2]
+; SVE-NEXT:    ptrue p0.s
+; SVE-NEXT:    ldr z4, [x0]
+; SVE-NEXT:    uunpklo z2.s, z0.h
+; SVE-NEXT:    uunpklo z3.s, z1.h
+; SVE-NEXT:    uunpkhi z0.s, z0.h
+; SVE-NEXT:    uunpkhi z1.s, z1.h
+; SVE-NEXT:    mad z2.s, p0/m, z3.s, z4.s
+; SVE-NEXT:    mad z0.s, p0/m, z1.s, z2.s
+; SVE-NEXT:    mov z1.d, z0.d
+; SVE-NEXT:    ext z1.b, z1.b, z0.b, #16
+; SVE-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; SVE-NEXT:    // kill: def $q1 killed $q1 killed $z1
+; SVE-NEXT:    ret
+;
+; SME-LABEL: two_way_i16_i32_vl256:
+; SME:       // %bb.0:
+; SME-NEXT:    ldr z0, [x0]
+; SME-NEXT:    ldr z1, [x1]
+; SME-NEXT:    ldr z2, [x2]
+; SME-NEXT:    umlalb z0.s, z2.h, z1.h
+; SME-NEXT:    umlalt z0.s, z2.h, z1.h
+; SME-NEXT:    mov z1.d, z0.d
+; SME-NEXT:    ext z1.b, z1.b, z0.b, #16
+; SME-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; SME-NEXT:    // kill: def $q1 killed $q1 killed $z1
+; SME-NEXT:    ret
+  %acc = load <8 x i32>, ptr %accptr
+  %u = load <16 x i16>, ptr %uptr
+  %s = load <16 x i16>, ptr %sptr
+  %u.wide = zext <16 x i16> %u to <16 x i32>
+  %s.wide = zext <16 x i16> %s to <16 x i32>
+  %mult = mul nuw nsw <16 x i32> %s.wide, %u.wide
+  %partial.reduce = tail call <8 x i32> @llvm.experimental.vector.partial.reduce.add(<8 x i32> %acc, <16 x i32> %mult)
+  ret <8 x i32> %partial.reduce
+}
+
+;
+; Two-way mla (i32 -> i64)
+;
+
+define <2 x i64> @two_way_i32_i64_vl128(ptr %accptr, ptr %uptr, ptr %sptr) {
+;
+; COMMON-LABEL: two_way_i32_i64_vl128:
+; COMMON:       // %bb.0:
+; COMMON-NEXT:    ldr q0, [x0]
+; COMMON-NEXT:    ldr q1, [x1]
+; COMMON-NEXT:    ldr q2, [x2]
+; COMMON-NEXT:    umlal v0.2d, v2.2s, v1.2s
+; COMMON-NEXT:    umlal2 v0.2d, v2.4s, v1.4s
+; COMMON-NEXT:    ret
+;
+; SME-LABEL: two_way_i32_i64_vl128:
+; SME:       // %bb.0:
+; SME-NEXT:    ldr q0, [x0]
+; SME-NEXT:    ldr q1, [x1]
+; SME-NEXT:    ldr q2, [x2]
+; SME-NEXT:    umlalb z0.d, z2.s, z1.s
+; SME-NEXT:    umlalt z0.d, z2.s, z1.s
+; SME-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; SME-NEXT:    ret
+  %acc = load <2 x i64>, ptr %accptr
+  %u = load <4 x i32>, ptr %uptr
+  %s = load <4 x i32>, ptr %sptr
+  %u.wide = zext <4 x i32> %u to <4 x i64>
+  %s.wide = zext <4 x i32> %s to <4 x i64>
+  %mult = mul nuw nsw <4 x i64> %s.wide, %u.wide
+  %partial.reduce = tail call <2 x i64> @llvm.experimental.vector.partial.reduce.add(<2 x i64> %acc, <4 x i64> %mult)
+  ret <2 x i64> %partial.reduce
+}
+
+define <4 x i64> @two_way_i32_i64_vl128_double_width(ptr %accptr, ptr %uptr, ptr %sptr) {
+;
+; COMMON-LABEL: two_way_i32_i64_vl128_double_width:
+; COMMON:       // %bb.0:
+; COMMON-NEXT:    ldp q0, q1, [x0]
+; COMMON-NEXT:    ldp q2, q3, [x1]
+; COMMON-NEXT:    ldp q4, q5, [x2]
+; COMMON-NEXT:    umlal v0.2d, v4.2s, v2.2s
+; COMMON-NEXT:    umlal v1.2d, v5.2s, v3.2s
+; COMMON-NEXT:    umlal2 v0.2d, v4.4s, v2.4s
+; COMMON-NEXT:    umlal2 v1.2d, v5.4s, v3.4s
+; COMMON-NEXT:    ret
+;
+; SME-LABEL: two_way_i32_i64_vl128_double_width:
+; SME:       // %bb.0:
+; SME-NEXT:    ldp q0, q1, [x0]
+; SME-NEXT:    ldp q3, q2, [x1]
+; SME-NEXT:    ldp q5, q4, [x2]
+; SME-NEXT:    umlalb z0.d, z5.s, z3.s
+; SME-NEXT:    umlalb z1.d, z4.s, z2.s
+; SME-NEXT:    umlalt z0.d, z5.s, z3.s
+; SME-NEXT:    umlalt z1.d, z4.s, z2.s
+; SME-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; SME-NEXT:    // kill: def $q1 killed $q1 killed $z1
+; SME-NEXT:    ret
+  %acc = load <4 x i64>, ptr %accptr
+  %u = load <8 x i32>, ptr %uptr
+  %s = load <8 x i32>, ptr %sptr
+  %u.wide = zext <8 x i32> %u to <8 x i64>
+  %s.wide = zext <8 x i32> %s to <8 x i64>
+  %mult = mul nuw nsw <8 x i64> %s.wide, %u.wide
+  %partial.reduce = tail call <4 x i64> @llvm.experimental.vector.partial.reduce.add(<4 x i64> %acc, <8 x i64> %mult)
+  ret <4 x i64> %partial.reduce
+}
+
+define <4 x i64> @two_way_i32_i64_vl256(ptr %accptr, ptr %uptr, ptr %sptr) vscale_range(2,2) {
+;
+;
+; NEON-LABEL: two_way_i32_i64_vl256:
+; NEON:       // %bb.0:
+; NEON-NEXT:    ldp q0, q1, [x0]
+; NEON-NEXT:    ldp q2, q3, [x1]
+; NEON-NEXT:    ldp q4, q5, [x2]
+; NEON-NEXT:    umlal v0.2d, v4.2s, v2.2s
+; NEON-NEXT:    umlal v1.2d, v5.2s, v3.2s
+; NEON-NEXT:    umlal2 v0.2d, v4.4s, v2.4s
+; NEON-NEXT:    umlal2 v1.2d, v5.4s, v3.4s
+; NEON-NEXT:    ret
+;
+; SVE-LABEL: two_way_i32_i64_vl256:
+; SVE:       // %bb.0:
+; SVE-NEXT:    ldr z0, [x1]
+; SVE-NEXT:    ldr z1, [x2]
+; SVE-NEXT:    ptrue p0.d
+; SVE-NEXT:    ldr z4, [x0]
+; SVE-NEXT:    uunpklo z2.d, z0.s
+; SVE-NEXT:    uunpklo z3.d, z1.s
+; SVE-NEXT:    uunpkhi z0.d, z0.s
+; SVE-NEXT:    uunpkhi z1.d, z1.s
+; SVE-NEXT:    mad z2.d, p0/m, z3.d, z4.d
+; SVE-NEXT:    mad z0.d, p0/m, z1.d, z2.d
+; SVE-NEXT:    mov z1.d, z0.d
+; SVE-NEXT:    ext z1.b, z1.b, z0.b, #16
+; SVE-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; SVE-NEXT:    // kill: def $q1 killed $q1 killed $z1
+; SVE-NEXT:    ret
+;
+; SME-LABEL: two_way_i32_i64_vl256:
+; SME:       // %bb.0:
+; SME-NEXT:    ldr z0, [x0]
+; SME-NEXT:    ldr z1, [x1]
+; SME-NEXT:    ldr z2, [x2]
+; SME-NEXT:    umlalb z0.d, z2.s, z1.s
+; SME-NEXT:    umlalt z0.d, z2.s, z1.s
+; SME-NEXT:    mov z1.d, z0.d
+; SME-NEXT:    ext z1.b, z1.b, z0.b, #16
+; SME-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; SME-NEXT:    // kill: def $q1 killed $q1 killed $z1
+; SME-NEXT:    ret
+  %acc = load <4 x i64>, ptr %accptr
+  %u = load <8 x i32>, ptr %uptr
+  %s = load <8 x i32>, ptr %sptr
+  %u.wide = zext <8 x i32> %u to <8 x i64>
+  %s.wide = zext <8 x i32> %s to <8 x i64>
+  %mult = mul nuw nsw <8 x i64> %s.wide, %u.wide
+  %partial.reduce = tail call <4 x i64> @llvm.experimental.vector.partial.reduce.add(<4 x i64> %acc, <8 x i64> %mult)
+  ret <4 x i64> %partial.reduce
+}
+
+
+;
+; Four-way dot (i8 -> i32)
+;
+
+define <4 x i32> @four_way_i8_i32_vl128(ptr %accptr, ptr %uptr, ptr %sptr) {
+;
+; COMMON-LABEL: four_way_i8_i32_vl128:
+; COMMON:       // %bb.0:
+; COMMON-NEXT:    ldr q0, [x0]
+; COMMON-NEXT:    ldr q1, [x1]
+; COMMON-NEXT:    ldr q2, [x2]
+; COMMON-NEXT:    udot v0.4s, v2.16b, v1.16b
+; COMMON-NEXT:    ret
+;
+; SME-LABEL: four_way_i8_i32_vl128:
+; SME:       // %bb.0:
+; SME-NEXT:    ldr q0, [x0]
+; SME-NEXT:    ldr q1, [x1]
+; SME-NEXT:    ldr q2, [x2]
+; SME-NEXT:    udot z0.s, z2.b, z1.b
+; SME-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; SME-NEXT:    ret
+  %acc = load <4 x i32>, ptr %accptr
+  %u = load <16 x i8>, ptr %uptr
+  %s = load <16 x i8>, ptr %sptr
+  %u.wide = zext <16 x i8> %u to <16 x i32>
+  %s.wide = zext <16 x i8> %s to <16 x i32>
+  %mult = mul nuw nsw <16 x i32> %s.wide, %u.wide
+  %partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add(<4 x i32> %acc, <16 x i32> %mult)
+  ret <4 x i32> %partial.reduce
+}
+
+define <8 x i32> @four_way_i8_i32_vl128_double_width(ptr %accptr, ptr %uptr, ptr %sptr) {
+;
+; COMMON-LABEL: four_way_i8_i32_vl128_double_width:
+; COMMON:       // %bb.0:
+; COMMON-NEXT:    ldp q0, q1, [x0]
+; COMMON-NEXT:    ldp q3, q2, [x1]
+; COMMON-NEXT:    ldp q5, q4, [x2]
+; COMMON-NEXT:    udot v0.4s, v5.16b, v3.16b
+; COMMON-NEXT:    udot v1.4s, v4.16b, v2.16b
+; COMMON-NEXT:    ret
+;
+; SME-LABEL: four_way_i8_i32_vl128_double_width:
+; SME:       // %bb.0:
+; SME-NEXT:    ldp q0, q1, [x0]
+; SME-NEXT:    ldp q3, q2, [x1]
+; SME-NEXT:    ldp q5, q4, [x2]
+; SME-NEXT:    udot z0.s, z5.b, z3.b
+; SME-NEXT:    udot z1.s, z4.b, z2.b
+; SME-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; SME-NEXT:    // kill: def $q1 killed $q1 killed $z1
+; SME-NEXT:    ret
+  %acc = load <8 x i32>, ptr %accptr
+  %u = load <32 x i8>, ptr %uptr
+  %s = load <32 x i8>, ptr %sptr
+  %u.wide = zext <32 x i8> %u to <32 x i32>
+  %s.wide = zext <32 x i8> %s to <32 x i32>
+  %mult = mul nuw nsw <32 x i32> %s.wide, %u.wide
+  %partial.reduce = tail call <8 x i32> @llvm.experimental.vector.partial.reduce.add(<8 x i32> %acc, <32 x i32> %mult)
+  ret <8 x i32> %partial.reduce
+}
+
+define <8 x i32> @four_way_i8_i32_vl256(ptr %accptr, ptr %uptr, ptr %sptr) vscale_range(2,2) {
+;
+;
+; NEON-LABEL: four_way_i8_i32_vl256:
+; NEON:       // %bb.0:
+; NEON-NEXT:    ldp q0, q1, [x0]
+; NEON-NEXT:    ldp q3, q2, [x1]
+; NEON-NEXT:    ldp q5, q4, [x2]
+; NEON-NEXT:    udot v0.4s, v5.16b, v3.16b
+; NEON-NEXT:    udot v1.4s, v4.16b, v2.16b
+; NEON-NEXT:    ret
+;
+; SVE-LABEL: four_way_i8_i32_vl256:
+; SVE:       // %bb.0:
+; SVE-NEXT:    ldr z0, [x0]
+; SVE-NEXT:    ldr z1, [x1]
+; SVE-NEXT:    ldr z2, [x2]
+; SVE-NEXT:    udot z0.s, z2.b, z1.b
+; SVE-NEXT:    mov z1.d, z0.d
+; SVE-NEXT:    ext z1.b, z1.b, z0.b, #16
+; SVE-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; SVE-NEXT:    // kill: def $q1 killed $q1 killed $z1
+; SVE-NEXT:    ret
+;
+; SME-LABEL: four_way_i8_i32_vl256:
+; SME:       // %bb.0:
+; SME-NEXT:    ldr z0, [x0]
+; SME-NEXT:    ldr z1, [x1]
+; SME-NEXT:    ldr z2, [x2]
+; SME-NEXT:    udot z0.s, z2.b, z1.b
+; SME-NEXT:    mov z1.d, z0.d
+; SME-NEXT:    ext z1.b, z1.b, z0.b, #16
+; SME-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; SME-NEXT:    // kill: def $q1 killed $q1 killed $z1
+; SME-NEXT:    ret
+  %acc = load <8 x i32>, ptr %accptr
+  %u = load <32 x i8>, ptr %uptr
+  %s = load <32 x i8>, ptr %sptr
+  %u.wide = zext <32 x i8> %u to <32 x i32>
+  %s.wide = zext <32 x i8> %s to <32 x i32>
+  %mult = mul nuw nsw <32 x i32> %s.wide, %u.wide
+  %partial.reduce = tail call <8 x i32> @llvm.experimental.vector.partial.reduce.add(<8 x i32> %acc, <32 x i32> %mult)
+  ret <8 x i32> %partial.reduce
+}
+
+;
+; Four-way dot (i16 -> i64)
+;
+
+define <2 x i64> @four_way_i16_i64_vl128(ptr %accptr, ptr %uptr, ptr %sptr) {
+;
+; COMMON-LABEL: four_way_i16_i64_vl128:
+; COMMON:       // %bb.0:
+; COMMON-NEXT:    ldr q0, [x1]
+; COMMON-NEXT:    ldr q1, [x2]
+; COMMON-NEXT:    ldr q3, [x0]
+; COMMON-NEXT:    umull v2.4s, v1.4h, v0.4h
+; COMMON-NEXT:    umull2 v0.4s, v1.8h, v0.8h
+; COMMON-NEXT:    uaddw v3.2d, v3.2d, v2.2s
+; COMMON-NEXT:    uaddw2 v1.2d, v3.2d, v2.4s
+; COMMON-NEXT:    uaddw v1.2d, v1.2d, v0.2s
+; COMMON-NEXT:    uaddw2 v0.2d, v1.2d, v0.4s
+; COMMON-NEXT:    ret
+;
+; SME-LABEL: four_way_i16_i64_vl128:
+; SME:       // %bb.0:
+; SME-NEXT:    ldr q0, [x0]
+; SME-NEXT:    ldr q1, [x1]
+; SME-NEXT:    ldr q2, [x2]
+; SME-NEXT:    udot z0.d, z2.h, z1.h
+; SME-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; SME-NEXT:    ret
+  %acc = load <2 x i64>, ptr %accptr
+  %u = load <8 x i16>, ptr %uptr
+  %s = load <8 x i16>, ptr %sptr
+  %u.wide = zext <8 x i16> %u to <8 x i64>
+  %s.wide = zext <8 x i16> %s to <8 x i64>
+  %mult = mul nuw nsw <8 x i64> %s.wide, %u.wide
+  %partial.reduce = tail call <2 x i64> @llvm.experimental.vector.partial.reduce.add(<2 x i64> %acc, <8 x i64> %mult)
+  ret <2 x i64> %partial.reduce
+}
+
+define <4 x i64> @four_way_i16_i64_vl128_double_width(ptr %accptr, ptr %uptr, ptr %sptr) {
+;
+; COMMON-LABEL: four_way_i16_i64_vl128_double_width:
+; COMMON:       // %bb.0:
+; COMMON-NEXT:    ldp q0, q1, [x1]
+; COMMON-NEXT:    ldp q2, q3, [x2]
+; COMMON-NEXT:    ldp q7, q6, [x0]
+; COMMON-NEXT:    umull v4.4s, v3.4h, v1.4h
+; COMMON-NEXT:    umull v5.4s, v2.4h, v0.4h
+; COMMON-NEXT:    umull2 v1.4s, v3.8h, v1.8h
+; COMMON-NEXT:    umull2 v0.4s, v2.8h, v0.8h
+; COMMON-NEXT:    uaddw v7.2d, v7.2d, v5.2s
+; COMMON-NEXT:    uaddw v6.2d, v6.2d, v4.2s
+; COMMON-NEXT:    uaddw2 v2.2d, v7.2d, v5.4s
+; COMMON-NEXT:    uaddw2 v3.2d, v6.2d, v4.4s
+; COMMON-NEXT:    uaddw v2.2d, v2.2d, v0.2s
+; COMMON-NEXT:    uaddw v3.2d, v3.2d, v1.2s
+; COMMON-NEXT:    uaddw2 v0.2d, v2.2d, v0.4s
+; COMMON-NEXT:    uaddw2 v1.2d, v3.2d, v1.4s
+; COMMON-NEXT:    ret
+;
+; SME-LABEL: four_way_i16_i64_vl128_double_width:
+; SME:       // %bb.0:
+; SME-NEXT:    ldp q0, q1, [x0]
+; SME-NEXT:    ldp q3, q2, [x1]
+; SME-NEXT:    ldp q5, q4, [x2]
+; SME-NEXT:    udot z0.d, z5.h, z3.h
+; SME-NEXT:    udot z1.d, z4.h, z2.h
+; SME-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; SME-NEXT:    // kill: def $q1 killed $q1 killed $z1
+; SME-NEXT:    ret
+  %acc = load <4 x i64>, ptr %accptr
+  %u = load <16 x i16>, ptr %uptr
+  %s = load <16 x i16>, ptr %sptr
+  %u.wide = zext <16 x i16> %u to <16 x i64>
+  %s.wide = zext <16 x i16> %s to <16 x i64>
+  %mult = mul nuw nsw <16 x i64> %s.wide, %u.wide
+  %partial.reduce = tail call <4 x i64> @llvm.experimental.vector.partial.reduce.add(<4 x i64> %acc, <16 x i64> %mult)
+  ret <4 x i64> %partial.reduce
+}
+
+define <4 x i64> @four_way_i16_i64_vl256(ptr %accptr, ptr %uptr, ptr %sptr) vscale_range(2,2) {
+;
+;
+; NEON-LABEL: four_way_i16_i64_vl256:
+; NEON:       // %bb.0:
+; NEON-NEXT:    ldp q0, q1, [x1]
+; NEON-NEXT:    ldp q2, q3, [x2]
+; NEON-NEXT:    ldp q7, q6, [x0]
+; NEON-NEXT:    umull v4.4s, v3.4h, v1.4h
+; NEON-NEXT:    umull v5.4s, v2.4h, v0.4h
+; NEON-NEXT:    umull2 v1.4s, v3.8h, v1.8h
+; NEON-NEXT:    umull2 v0.4s, v2.8h, v0.8h
+; NEON-NEXT:    uaddw v7.2d, v7.2d, v5.2s
+; NEON-NEXT:    uaddw v6.2d, v6.2d, v4.2s
+; NEON-NEXT:    uaddw2 v2.2d, v7.2d, v5.4s
+; NEON-NEXT:    uaddw2 v3.2d, v6.2d, v4.4s
+; NEON-NEXT:    uaddw v2.2d, v2.2d, v0.2s
+; NEON-NEXT:    uaddw v3.2d, v3.2d, v1.2s
+; NEON-NEXT:    uaddw2 v0.2d, v2.2d, v0.4s
+; NEON-NEXT:    uaddw2 v1.2d, v3.2d, v1.4s
+; NEON-NEXT:    ret
+;
+; SVE-LABEL: four_way_i16_i64_vl256:
+; SVE:       // %bb.0:
+; SVE-NEXT:    ldr z0, [x0]
+; SVE-NEXT:    ldr z1, [x1]
+; SVE-NEXT:    ldr z2, [x2]
+; SVE-NEXT:    udot z0.d, z2.h, z1.h
+; SVE-NEXT:    mov z1.d, z0.d
+; SVE-NEXT:    ext z1.b, z1.b, z0.b, #16
+; SVE-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; SVE-NEXT:    // kill: def $q1 killed $q1 killed $z1
+; SVE-NEXT:    ret
+;
+; SME-LABEL: four_way_i16_i64_vl256:
+; SME:       // %bb.0:
+; SME-NEXT:    ldr z0, [x0]
+; SME-NEXT:    ldr z1, [x1]
+; SME-NEXT:    ldr z2, [x2]
+; SME-NEXT:    udot z0.d, z2.h, z1.h
+; SME-NEXT:    mov z1.d, z0.d
+; SME-NEXT:    ext z1.b, z1.b, z0.b, #16
+; SME-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; SME-NEXT:    // kill: def $q1 killed $q1 killed $z1
+; SME-NEXT:    ret
+  %acc = load <4 x i64>, ptr %accptr
+  %u = load <16 x i16>, ptr %uptr
+  %s = load <16 x i16>, ptr %sptr
+  %u.wide = zext <16 x i16> %u to <16 x i64>
+  %s.wide = zext <16 x i16> %s to <16 x i64>
+  %mult = mul nuw nsw <16 x i64> %s.wide, %u.wide
+  %partial.reduce = tail call <4 x i64> @llvm.experimental.vector.partial.reduce.add(<4 x i64> %acc, <16 x i64> %mult)
+  ret <4 x i64> %partial.reduce
+}
+
+
+;
+; Eight-way dot, requires two steps (i8 -> i64)
+;
+
+define <2 x i64> @eight_way_i8_i64_vl128(ptr %accptr, ptr %uptr, ptr %sptr) {
+;
+; NEON-LABEL: eight_way_i8_i64_vl128:
+; NEON:       // %bb.0:
+; NEON-NEXT:    movi v0.2d, #0000000000000000
+; NEON-NEXT:    ldr q1, [x1]
+; NEON-NEXT:    ldr q2, [x2]
+; NEON-NEXT:    udot v0.4s, v2.16b, v1.16b
+; NEON-NEXT:    ldr q1, [x0]
+; NEON-NEXT:    uaddw v1.2d, v1.2d, v0.2s
+; NEON-NEXT:    uaddw2 v0.2d, v1.2d, v0.4s
+; NEON-NEXT:    ret
+;
+; SVE-LABEL: eight_way_i8_i64_vl128:
+; SVE:       // %bb.0:
+; SVE-NEXT:    movi v0.2d, #0000000000000000
+; SVE-NEXT:    ldr q1, [x1]
+; SVE-NEXT:    ldr q2, [x2]
+; SVE-NEXT:    udot z0.s, z2.b, z1.b
+; SVE-NEXT:    ldr q2, [x0]
+; SVE-NEXT:    uunpklo z1.d, z0.s
+; SVE-NEXT:    uunpkhi z0.d, z0.s
+; SVE-NEXT:    add z1.d, z2.d, z1.d
+; SVE-NEXT:    add z0.d, z1.d, z0.d
+; SVE-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; SVE-NEXT:    ret
+;
+; SME-LABEL: eight_way_i8_i64_vl128:
+; SME:       // %bb.0:
+; SME-NEXT:    mov z0.s, #0 // =0x0
+; SME-NEXT:    ldr q1, [x1]
+; SME-NEXT:    ldr q2, [x2]
+; SME-NEXT:    udot z0.s, z2.b, z1.b
+; SME-NEXT:    ldr q1, [x0]
+; SME-NEXT:    uaddwb z1.d, z1.d, z0.s
+; SME-NEXT:    uaddwt z0.d, z1.d, z0.s
+; SME-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; SME-NEXT:    ret
+  %acc = load <2 x i64>, ptr %accptr
+  %u = load <16 x i8>, ptr %uptr
+  %s = load <16 x i8>, ptr %sptr
+  %u.wide = zext <16 x i8> %u to <16 x i64>
+  %s.wide = zext <16 x i8> %s to <16 x i64>
+  %mult = mul nuw nsw <16 x i64> %s.wide, %u.wide
+  %partial.reduce = tail call <2 x i64> @llvm.experimental.vector.partial.reduce.add(<2 x i64> %acc, <16 x i64> %mult)
+  ret <2 x i64> %partial.reduce
+}
+
+define <4 x i64> @four_way_i8_i64_vl128_double_width(ptr %accptr, ptr %uptr, ptr %sptr) {
+;
+; NEON-LABEL: four_way_i8_i64_vl128_double_width:
+; NEON:       // %bb.0:
+; NEON-NEXT:    movi v1.2d, #0000000000000000
+; NEON-NEXT:    movi v0.2d, #0000000000000000
+; NEON-NEXT:    ldp q3, q2, [x1]
+; NEON-NEXT:    ldp q5, q4, [x2]
+; NEON-NEXT:    udot v0.4s, v5.16b, v3.16b
+; NEON-NEXT:    udot v1.4s, v4.16b, v2.16b
+; NEON-NEXT:    ldp q3, q2, [x0]
+; NEON-NEXT:    uaddw v3.2d, v3.2d, v0.2s
+; NEON-NEXT:    uaddw v2.2d, v2.2d, v1.2s
+; NEON-NEXT:    uaddw2 v0.2d, v3.2d, v0.4s
+; NEON-NEXT:    uaddw2 v1.2d, v2.2d, v1.4s
+; NEON-NEXT:    ret
+;
+; SVE-LABEL: four_way_i8_i64_vl128_double_width:
+; SVE:       // %bb.0:
+; SVE-NEXT:    movi v0.2d, #0000000000000000
+; SVE-NEXT:    movi v1.2d, #0000000000000000
+; SVE-NEXT:    ldp q3, q2, [x1]
+; SVE-NEXT:    ldp q5, q4, [x2]
+; SVE-NEXT:    udot z1.s, z5.b, z3.b
+; SVE-NEXT:    udot z0.s, z4.b, z2.b
+; SVE-NEXT:    ldp q5, q4, [x0]
+; SVE-NEXT:    uunpklo z2.d, z1.s
+; SVE-NEXT:    uunpklo z3.d, z0.s
+; SVE-NEXT:    uunpkhi z1.d, z1.s
+; SVE-NEXT:    uunpkhi z6.d, z0.s
+; SVE-NEXT:    add z0.d, z5.d, z2.d
+; SVE-NEXT:    add z2.d, z4.d, z3.d
+; SVE-NEXT:    add z0.d, z0.d, z1.d
+; SVE-NEXT:    add z1.d, z2.d, z6.d
+; SVE-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; SVE-NEXT:    // kill: def $q1 killed $q1 killed $z1
+; SVE-NEXT:    ret
+;
+; SME-LABEL: four_way_i8_i64_vl128_double_width:
+; SME:       // %bb.0:
+; SME-NEXT:    mov z1.s, #0 // =0x0
+; SME-NEXT:    mov z0.s, #0 // =0x0
+; SME-NEXT:    ldp q3, q2, [x1]
+; SME-NEXT:    ldp q5, q4, [x2]
+; SME-NEXT:    udot z0.s, z5.b, z3.b
+; SME-NEXT:    udot z1.s, z4.b, z2.b
+; SME-NEXT:    ldp q3, q2, [x0]
+; SME-NEXT:    uaddwb z3.d, z3.d, z0.s
+; SME-NEXT:    uaddwb z2.d, z2.d, z1.s
+; SME-NEXT:    uaddwt z0.d, z3.d, z0.s
+; SME-NEXT:    uaddwt z1.d, z2.d, z1.s
+; SME-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; SME-NEXT:    // kill: def $q1 killed $q1 killed $z1
+; SME-NEXT:    ret
+  %acc = load <4 x i64>, ptr %accptr
+  %u = load <32 x i8>, ptr %uptr
+  %s = load <32 x i8>, ptr %sptr
+  %u.wide = zext <32 x i8> %u to <32 x i64>
+  %s.wide = zext <32 x i8> %s to <32 x i64>
+  %mult = mul nuw nsw <32 x i64> %s.wide, %u.wide
+  %partial.reduce = tail call <4 x i64> @llvm.experimental.vector.partial.reduce.add(<4 x i64> %acc, <32 x i64> %mult)
+  ret <4 x i64> %partial.reduce
+}
+
+define <4 x i64> @four_way_i8_i64_vl256(ptr %accptr, ptr %uptr, ptr %sptr) vscale_range(2,2) {
+; NEON-LABEL: four_way_i8_i64_vl256:
+; NEON:       // %bb.0:
+; NEON-NEXT:    movi v1.2d, #0000000000000000
+; NEON-NEXT:    movi v0.2d, #0000000000000000
+; NEON-NEXT:    ldp q3, q2, [x1]
+; NEON-NEXT:    ldp q5, q4, [x2]
+; NEON-NEXT:    udot v0.4s, v5.16b, v3.16b
+; NEON-NEXT:    udot v1.4s, v4.16b, v2.16b
+; NEON-NEXT:    ldp q3, q2, [x0]
+; NEON-NEXT:    uaddw v3.2d, v3.2d, v0.2s
+; NEON-NEXT:    uaddw v2.2d, v2.2d, v1.2s
+; NEON-NEXT:    uaddw2 v0.2d, v3.2d, v0.4s
+; NEON-NEXT:    uaddw2 v1.2d, v2.2d, v1.4s
+; NEON-NEXT:    ret
+;
+; SVE-LABEL: four_way_i8_i64_vl256:
+; SVE:       // %bb.0:
+; SVE-NEXT:    movi v0.2d, #0000000000000000
+; SVE-NEXT:    ldr z1, [x1]
+; SVE-NEXT:    ldr z2, [x2]
+; SVE-NEXT:    udot z0.s, z2.b, z1.b
+; SVE-NEXT:    ldr z2, [x0]
+; SVE-NEXT:    uunpklo z1.d, z0.s
+; SVE-NEXT:    uunpkhi z0.d, z0.s
+; SVE-NEXT:    add z1.d, z2.d, z1.d
+; SVE-NEXT:    add z0.d, z1.d, z0.d
+; SVE-NEXT:    mov z1.d, z0.d
+; SVE-NEXT:    ext z1.b, z1.b, z0.b, #16
+; SVE-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; SVE-NEXT:    // kill: def $q1 killed $q1 killed $z1
+; SVE-NEXT:    ret
+;
+; SME-LABEL: four_way_i8_i64_vl256:
+; SME:       // %bb.0:
+; SME-NEXT:    ldr z0, [x1]
+; SME-NEXT:    ldr z1, [x2]
+; SME-NEXT:    mov z2.s, #0 // =0x0
+; SME-NEXT:    udot z2.s, z1.b, z0.b
+; SME-NEXT:    ldr z0, [x0]
+; SME-NEXT:    uaddwb z0.d, z0.d, z2.s
+; SME-NEXT:    uaddwt z0.d, z0.d, z2.s
+; SME-NEXT:    mov z1.d, z0.d
+; SME-NEXT:    ext z1.b, z1.b, z0.b, #16
+; SME-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; SME-NEXT:    // kill: def $q1 killed $q1 killed $z1
+; SME-NEXT:    ret
+  %acc = load <4 x i64>, ptr %accptr
+  %u = load <32 x i8>, ptr %uptr
+  %s = load <32 x i8>, ptr %sptr
+  %u.wide = zext <32 x i8> %u to <32 x i64>
+  %s.wide = zext <32 x i8> %s to <32 x i64>
+  %mult = mul nuw nsw <32 x i64> %s.wide, %u.wide
+  %partial.reduce = tail call <4 x i64> @llvm.experimental.vector.partial.reduce.add(<4 x i64> %acc, <32 x i64> %mult)
+  ret <4 x i64> %partial.reduce
+}

More information about the llvm-commits mailing list