[llvm] [AArch64][NEON][SVE] Lower i8 to i64 partial reduction to a dot product (PR #110220)

via llvm-commits llvm-commits at lists.llvm.org
Fri Sep 27 01:25:44 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-aarch64

Author: James Chesterman (JamesChesterman)

<details>
<summary>Changes</summary>

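An i8 to i64 partial reduction can instead be lowered to an i8 to i32 dot product (UDOT/SDOT/USDOT) into a zeroed accumulator, followed by a sign extension of the i32 result to i64 and an add into the original i64 accumulator.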

---
Full diff: https://github.com/llvm/llvm-project/pull/110220.diff


3 Files Affected:

- (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+20-4) 
- (modified) llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll (+156) 
- (modified) llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll (+190) 


``````````diff
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 4166d9bd22bc01..af66b6b0e43b81 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1996,8 +1996,8 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
     return true;
 
   EVT VT = EVT::getEVT(I->getType());
-  return VT != MVT::nxv4i32 && VT != MVT::nxv2i64 && VT != MVT::v4i32 &&
-         VT != MVT::v2i32;
+  return VT != MVT::nxv4i64 && VT != MVT::nxv4i32 && VT != MVT::nxv2i64 &&
+         VT != MVT::v4i64 && VT != MVT::v4i32 && VT != MVT::v2i32;
 }
 
 bool AArch64TargetLowering::shouldExpandCttzElements(EVT VT) const {
@@ -21916,8 +21916,10 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
 
   // Dot products operate on chunks of four elements so there must be four times
   // as many elements in the wide type
-  if (!(ReducedType == MVT::nxv4i32 && MulSrcType == MVT::nxv16i8) &&
+  if (!(ReducedType == MVT::nxv4i64 && MulSrcType == MVT::nxv16i8) &&
+      !(ReducedType == MVT::nxv4i32 && MulSrcType == MVT::nxv16i8) &&
       !(ReducedType == MVT::nxv2i64 && MulSrcType == MVT::nxv8i16) &&
+      !(ReducedType == MVT::v4i64 && MulSrcType == MVT::v16i8) &&
       !(ReducedType == MVT::v4i32 && MulSrcType == MVT::v16i8) &&
       !(ReducedType == MVT::v2i32 && MulSrcType == MVT::v8i8))
     return SDValue();
@@ -21930,7 +21932,7 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
 
     bool Scalable = N->getValueType(0).isScalableVT();
     // There's no nxv2i64 version of usdot
-    if (Scalable && ReducedType != MVT::nxv4i32)
+    if (Scalable && ReducedType != MVT::nxv4i32 && ReducedType != MVT::nxv4i64)
       return SDValue();
 
     Opcode = AArch64ISD::USDOT;
@@ -21942,6 +21944,20 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
   else
     Opcode = AArch64ISD::UDOT;
 
+  // Partial reduction lowering for (nx)v16i8 to (nx)v4i64 requires an i32 dot
+  // product followed by a zero / sign extension
+  if ((ReducedType == MVT::nxv4i64 && MulSrcType == MVT::nxv16i8) ||
+      (ReducedType == MVT::v4i64 && MulSrcType == MVT::v16i8)) {
+    EVT ReducedTypeHalved = (ReducedType.isScalableVector()) ? MVT::nxv4i32 : MVT::v4i32;
+
+    auto Doti32 =
+        DAG.getNode(Opcode, DL, ReducedTypeHalved,
+                    DAG.getConstant(0, DL, ReducedTypeHalved), A, B);
+    auto Extended = DAG.getSExtOrTrunc(Doti32, DL, ReducedType);
+    return DAG.getNode(ISD::ADD, DL, NarrowOp.getValueType(),
+                                  {NarrowOp, Extended});
+  }
+
   return DAG.getNode(Opcode, DL, ReducedType, NarrowOp, A, B);
 }
 
diff --git a/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
index 841da1f8ea57c1..c1b9a4c9dbb797 100644
--- a/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
@@ -211,6 +211,162 @@ define <2 x i32> @sudot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) #0{
   ret <2 x i32> %partial.reduce
 }
 
+define <4 x i64> @udot_8to64(<4 x i64> %acc, <16 x i8> %a, <16 x i8> %b) {
+; CHECK-DOT-LABEL: udot_8to64:
+; CHECK-DOT:       // %bb.0: // %entry
+; CHECK-DOT-NEXT:    movi v4.2d, #0000000000000000
+; CHECK-DOT-NEXT:    udot v4.4s, v2.16b, v3.16b
+; CHECK-DOT-NEXT:    saddw2 v1.2d, v1.2d, v4.4s
+; CHECK-DOT-NEXT:    saddw v0.2d, v0.2d, v4.2s
+; CHECK-DOT-NEXT:    ret
+;
+; CHECK-NODOT-LABEL: udot_8to64:
+; CHECK-NODOT:       // %bb.0: // %entry
+; CHECK-NODOT-NEXT:    umull v4.8h, v2.8b, v3.8b
+; CHECK-NODOT-NEXT:    umull2 v2.8h, v2.16b, v3.16b
+; CHECK-NODOT-NEXT:    ushll v3.4s, v4.4h, #0
+; CHECK-NODOT-NEXT:    ushll v5.4s, v2.4h, #0
+; CHECK-NODOT-NEXT:    ushll2 v4.4s, v4.8h, #0
+; CHECK-NODOT-NEXT:    ushll2 v2.4s, v2.8h, #0
+; CHECK-NODOT-NEXT:    uaddw2 v1.2d, v1.2d, v3.4s
+; CHECK-NODOT-NEXT:    uaddw v0.2d, v0.2d, v3.2s
+; CHECK-NODOT-NEXT:    uaddl2 v3.2d, v4.4s, v5.4s
+; CHECK-NODOT-NEXT:    uaddl v4.2d, v4.2s, v5.2s
+; CHECK-NODOT-NEXT:    uaddw2 v1.2d, v1.2d, v2.4s
+; CHECK-NODOT-NEXT:    uaddw v0.2d, v0.2d, v2.2s
+; CHECK-NODOT-NEXT:    add v1.2d, v3.2d, v1.2d
+; CHECK-NODOT-NEXT:    add v0.2d, v4.2d, v0.2d
+; CHECK-NODOT-NEXT:    ret
+entry: ; unsigned i8 x unsigned i8 products, partially reduced into a <4 x i64> accumulator
+  %a.wide = zext <16 x i8> %a to <16 x i64>
+  %b.wide = zext <16 x i8> %b to <16 x i64>
+  %mult = mul nuw nsw <16 x i64> %a.wide, %b.wide ; both operands zero-extended: product <= 255*255, so nuw/nsw hold
+  %partial.reduce = tail call <4 x i64> @llvm.experimental.vector.partial.reduce.add.v4i64.v16i64(
+  <4 x i64> %acc, <16 x i64> %mult)
+  ret <4 x i64> %partial.reduce
+}
+
+define <4 x i64> @sdot_8to64(<4 x i64> %acc, <16 x i8> %a, <16 x i8> %b){
+; CHECK-DOT-LABEL: sdot_8to64:
+; CHECK-DOT:       // %bb.0: // %entry
+; CHECK-DOT-NEXT:    movi v4.2d, #0000000000000000
+; CHECK-DOT-NEXT:    sdot v4.4s, v2.16b, v3.16b
+; CHECK-DOT-NEXT:    saddw2 v1.2d, v1.2d, v4.4s
+; CHECK-DOT-NEXT:    saddw v0.2d, v0.2d, v4.2s
+; CHECK-DOT-NEXT:    ret
+;
+; CHECK-NODOT-LABEL: sdot_8to64:
+; CHECK-NODOT:       // %bb.0: // %entry
+; CHECK-NODOT-NEXT:    smull v4.8h, v2.8b, v3.8b
+; CHECK-NODOT-NEXT:    smull2 v2.8h, v2.16b, v3.16b
+; CHECK-NODOT-NEXT:    sshll v3.4s, v4.4h, #0
+; CHECK-NODOT-NEXT:    sshll v5.4s, v2.4h, #0
+; CHECK-NODOT-NEXT:    sshll2 v4.4s, v4.8h, #0
+; CHECK-NODOT-NEXT:    sshll2 v2.4s, v2.8h, #0
+; CHECK-NODOT-NEXT:    saddw2 v1.2d, v1.2d, v3.4s
+; CHECK-NODOT-NEXT:    saddw v0.2d, v0.2d, v3.2s
+; CHECK-NODOT-NEXT:    saddl2 v3.2d, v4.4s, v5.4s
+; CHECK-NODOT-NEXT:    saddl v4.2d, v4.2s, v5.2s
+; CHECK-NODOT-NEXT:    saddw2 v1.2d, v1.2d, v2.4s
+; CHECK-NODOT-NEXT:    saddw v0.2d, v0.2d, v2.2s
+; CHECK-NODOT-NEXT:    add v1.2d, v3.2d, v1.2d
+; CHECK-NODOT-NEXT:    add v0.2d, v4.2d, v0.2d
+; CHECK-NODOT-NEXT:    ret
+entry: ; signed i8 x signed i8 products, partially reduced into a <4 x i64> accumulator
+  %a.wide = sext <16 x i8> %a to <16 x i64>
+  %b.wide = sext <16 x i8> %b to <16 x i64>
+  %mult = mul nsw <16 x i64> %a.wide, %b.wide ; no nuw: sext lanes may be negative and wrap unsigned; |i8*i8| <= 2^14 so nsw holds
+  %partial.reduce = tail call <4 x i64> @llvm.experimental.vector.partial.reduce.add.v4i64.v16i64(
+  <4 x i64> %acc, <16 x i64> %mult)
+  ret <4 x i64> %partial.reduce
+}
+
+define <4 x i64> @usdot_8to64(<4 x i64> %acc, <16 x i8> %a, <16 x i8> %b){
+; CHECK-NOI8MM-LABEL: usdot_8to64:
+; CHECK-NOI8MM:       // %bb.0: // %entry
+; CHECK-NOI8MM-NEXT:    ushll v4.8h, v2.8b, #0
+; CHECK-NOI8MM-NEXT:    sshll v5.8h, v3.8b, #0
+; CHECK-NOI8MM-NEXT:    ushll2 v2.8h, v2.16b, #0
+; CHECK-NOI8MM-NEXT:    sshll2 v3.8h, v3.16b, #0
+; CHECK-NOI8MM-NEXT:    ushll v6.4s, v4.4h, #0
+; CHECK-NOI8MM-NEXT:    sshll v7.4s, v5.4h, #0
+; CHECK-NOI8MM-NEXT:    ushll2 v4.4s, v4.8h, #0
+; CHECK-NOI8MM-NEXT:    sshll2 v5.4s, v5.8h, #0
+; CHECK-NOI8MM-NEXT:    ushll2 v16.4s, v2.8h, #0
+; CHECK-NOI8MM-NEXT:    sshll2 v17.4s, v3.8h, #0
+; CHECK-NOI8MM-NEXT:    ushll v2.4s, v2.4h, #0
+; CHECK-NOI8MM-NEXT:    sshll v3.4s, v3.4h, #0
+; CHECK-NOI8MM-NEXT:    smlal2 v1.2d, v6.4s, v7.4s
+; CHECK-NOI8MM-NEXT:    smlal v0.2d, v6.2s, v7.2s
+; CHECK-NOI8MM-NEXT:    smull v18.2d, v4.2s, v5.2s
+; CHECK-NOI8MM-NEXT:    smull2 v4.2d, v4.4s, v5.4s
+; CHECK-NOI8MM-NEXT:    smlal2 v1.2d, v16.4s, v17.4s
+; CHECK-NOI8MM-NEXT:    smlal v0.2d, v16.2s, v17.2s
+; CHECK-NOI8MM-NEXT:    smlal2 v4.2d, v2.4s, v3.4s
+; CHECK-NOI8MM-NEXT:    smlal v18.2d, v2.2s, v3.2s
+; CHECK-NOI8MM-NEXT:    add v1.2d, v4.2d, v1.2d
+; CHECK-NOI8MM-NEXT:    add v0.2d, v18.2d, v0.2d
+; CHECK-NOI8MM-NEXT:    ret
+;
+; CHECK-I8MM-LABEL: usdot_8to64:
+; CHECK-I8MM:       // %bb.0: // %entry
+; CHECK-I8MM-NEXT:    movi v4.2d, #0000000000000000
+; CHECK-I8MM-NEXT:    usdot v4.4s, v2.16b, v3.16b
+; CHECK-I8MM-NEXT:    saddw2 v1.2d, v1.2d, v4.4s
+; CHECK-I8MM-NEXT:    saddw v0.2d, v0.2d, v4.2s
+; CHECK-I8MM-NEXT:    ret
+entry: ; unsigned %a times signed %b, partially reduced into a <4 x i64> accumulator
+  %a.wide = zext <16 x i8> %a to <16 x i64>
+  %b.wide = sext <16 x i8> %b to <16 x i64>
+  %mult = mul nsw <16 x i64> %a.wide, %b.wide ; no nuw: negative %b lanes wrap unsigned; |255*-128| < 2^15 so nsw holds
+  %partial.reduce = tail call <4 x i64> @llvm.experimental.vector.partial.reduce.add.v4i64.v16i64(
+  <4 x i64> %acc, <16 x i64> %mult)
+  ret <4 x i64> %partial.reduce
+}
+
+define <4 x i64> @sudot_8to64(<4 x i64> %acc, <16 x i8> %a, <16 x i8> %b) {
+; CHECK-NOI8MM-LABEL: sudot_8to64:
+; CHECK-NOI8MM:       // %bb.0: // %entry
+; CHECK-NOI8MM-NEXT:    sshll v4.8h, v2.8b, #0
+; CHECK-NOI8MM-NEXT:    ushll v5.8h, v3.8b, #0
+; CHECK-NOI8MM-NEXT:    sshll2 v2.8h, v2.16b, #0
+; CHECK-NOI8MM-NEXT:    ushll2 v3.8h, v3.16b, #0
+; CHECK-NOI8MM-NEXT:    sshll v6.4s, v4.4h, #0
+; CHECK-NOI8MM-NEXT:    ushll v7.4s, v5.4h, #0
+; CHECK-NOI8MM-NEXT:    sshll2 v4.4s, v4.8h, #0
+; CHECK-NOI8MM-NEXT:    ushll2 v5.4s, v5.8h, #0
+; CHECK-NOI8MM-NEXT:    sshll2 v16.4s, v2.8h, #0
+; CHECK-NOI8MM-NEXT:    ushll2 v17.4s, v3.8h, #0
+; CHECK-NOI8MM-NEXT:    sshll v2.4s, v2.4h, #0
+; CHECK-NOI8MM-NEXT:    ushll v3.4s, v3.4h, #0
+; CHECK-NOI8MM-NEXT:    smlal2 v1.2d, v6.4s, v7.4s
+; CHECK-NOI8MM-NEXT:    smlal v0.2d, v6.2s, v7.2s
+; CHECK-NOI8MM-NEXT:    smull v18.2d, v4.2s, v5.2s
+; CHECK-NOI8MM-NEXT:    smull2 v4.2d, v4.4s, v5.4s
+; CHECK-NOI8MM-NEXT:    smlal2 v1.2d, v16.4s, v17.4s
+; CHECK-NOI8MM-NEXT:    smlal v0.2d, v16.2s, v17.2s
+; CHECK-NOI8MM-NEXT:    smlal2 v4.2d, v2.4s, v3.4s
+; CHECK-NOI8MM-NEXT:    smlal v18.2d, v2.2s, v3.2s
+; CHECK-NOI8MM-NEXT:    add v1.2d, v4.2d, v1.2d
+; CHECK-NOI8MM-NEXT:    add v0.2d, v18.2d, v0.2d
+; CHECK-NOI8MM-NEXT:    ret
+;
+; CHECK-I8MM-LABEL: sudot_8to64:
+; CHECK-I8MM:       // %bb.0: // %entry
+; CHECK-I8MM-NEXT:    movi v4.2d, #0000000000000000
+; CHECK-I8MM-NEXT:    usdot v4.4s, v3.16b, v2.16b
+; CHECK-I8MM-NEXT:    saddw2 v1.2d, v1.2d, v4.4s
+; CHECK-I8MM-NEXT:    saddw v0.2d, v0.2d, v4.2s
+; CHECK-I8MM-NEXT:    ret
+entry: ; signed %a times unsigned %b, partially reduced into a <4 x i64> accumulator
+  %a.wide = sext <16 x i8> %a to <16 x i64>
+  %b.wide = zext <16 x i8> %b to <16 x i64>
+  %mult = mul nsw <16 x i64> %a.wide, %b.wide ; no nuw: negative %a lanes wrap unsigned; |-128*255| < 2^15 so nsw holds
+  %partial.reduce = tail call <4 x i64> @llvm.experimental.vector.partial.reduce.add.v4i64.v16i64(
+  <4 x i64> %acc, <16 x i64> %mult)
+  ret <4 x i64> %partial.reduce
+}
+
 define <4 x i32> @not_udot(<4 x i32> %acc, <8 x i8> %u, <8 x i8> %s) #0{
 ; CHECK-LABEL: not_udot:
 ; CHECK:       // %bb.0:
diff --git a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
index 00e5ac479d02c9..66d6e0388bbf94 100644
--- a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
@@ -126,6 +126,196 @@ entry:
   ret <vscale x 4 x i32> %partial.reduce
 }
 
+define <vscale x 4 x i64> @udot_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
+; CHECK-LABEL: udot_8to64:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    mov z4.s, #0 // =0x0
+; CHECK-NEXT:    udot z4.s, z2.b, z3.b
+; CHECK-NEXT:    sunpklo z2.d, z4.s
+; CHECK-NEXT:    sunpkhi z3.d, z4.s
+; CHECK-NEXT:    add z0.d, z0.d, z2.d
+; CHECK-NEXT:    add z1.d, z1.d, z3.d
+; CHECK-NEXT:    ret
+entry: ; unsigned i8 x unsigned i8 products, partially reduced into a <vscale x 4 x i64> accumulator
+  %a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i64>
+  %b.wide = zext <vscale x 16 x i8> %b to <vscale x 16 x i64>
+  %mult = mul nuw nsw <vscale x 16 x i64> %a.wide, %b.wide ; both operands zero-extended: product <= 255*255, so nuw/nsw hold
+  %partial.reduce = tail call <vscale x 4 x i64> @llvm.experimental.vector.partial.reduce.add.nxv4i64.nxv16i64(
+  <vscale x 4 x i64> %acc, <vscale x 16 x i64> %mult)
+  ret <vscale x 4 x i64> %partial.reduce
+}
+
+define <vscale x 4 x i64> @sdot_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b){
+; CHECK-LABEL: sdot_8to64:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    mov z4.s, #0 // =0x0
+; CHECK-NEXT:    sdot z4.s, z2.b, z3.b
+; CHECK-NEXT:    sunpklo z2.d, z4.s
+; CHECK-NEXT:    sunpkhi z3.d, z4.s
+; CHECK-NEXT:    add z0.d, z0.d, z2.d
+; CHECK-NEXT:    add z1.d, z1.d, z3.d
+; CHECK-NEXT:    ret
+entry: ; signed i8 x signed i8 products, partially reduced into a <vscale x 4 x i64> accumulator
+  %a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i64>
+  %b.wide = sext <vscale x 16 x i8> %b to <vscale x 16 x i64>
+  %mult = mul nsw <vscale x 16 x i64> %a.wide, %b.wide ; no nuw: sext lanes may be negative and wrap unsigned; |i8*i8| <= 2^14 so nsw holds
+  %partial.reduce = tail call <vscale x 4 x i64> @llvm.experimental.vector.partial.reduce.add.nxv4i64.nxv16i64(
+  <vscale x 4 x i64> %acc, <vscale x 16 x i64> %mult)
+  ret <vscale x 4 x i64> %partial.reduce
+}
+
+define <vscale x 4 x i64> @usdot_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b){
+; CHECK-I8MM-LABEL: usdot_8to64:
+; CHECK-I8MM:       // %bb.0: // %entry
+; CHECK-I8MM-NEXT:    mov z4.s, #0 // =0x0
+; CHECK-I8MM-NEXT:    usdot z4.s, z2.b, z3.b
+; CHECK-I8MM-NEXT:    sunpklo z2.d, z4.s
+; CHECK-I8MM-NEXT:    sunpkhi z3.d, z4.s
+; CHECK-I8MM-NEXT:    add z0.d, z0.d, z2.d
+; CHECK-I8MM-NEXT:    add z1.d, z1.d, z3.d
+; CHECK-I8MM-NEXT:    ret
+;
+; CHECK-NOI8MM-LABEL: usdot_8to64:
+; CHECK-NOI8MM:       // %bb.0: // %entry
+; CHECK-NOI8MM-NEXT:    str x29, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NOI8MM-NEXT:    addvl sp, sp, #-2
+; CHECK-NOI8MM-NEXT:    str z9, [sp] // 16-byte Folded Spill
+; CHECK-NOI8MM-NEXT:    str z8, [sp, #1, mul vl] // 16-byte Folded Spill
+; CHECK-NOI8MM-NEXT:    .cfi_escape 0x0f, 0x0c, 0x8f, 0x00, 0x11, 0x10, 0x22, 0x11, 0x10, 0x92, 0x2e, 0x00, 0x1e, 0x22 // sp + 16 + 16 * VG
+; CHECK-NOI8MM-NEXT:    .cfi_offset w29, -16
+; CHECK-NOI8MM-NEXT:    .cfi_escape 0x10, 0x48, 0x0a, 0x11, 0x70, 0x22, 0x11, 0x78, 0x92, 0x2e, 0x00, 0x1e, 0x22 // $d8 @ cfa - 16 - 8 * VG
+; CHECK-NOI8MM-NEXT:    .cfi_escape 0x10, 0x49, 0x0a, 0x11, 0x70, 0x22, 0x11, 0x70, 0x92, 0x2e, 0x00, 0x1e, 0x22 // $d9 @ cfa - 16 - 16 * VG
+; CHECK-NOI8MM-NEXT:    uunpklo z4.h, z2.b
+; CHECK-NOI8MM-NEXT:    sunpklo z5.h, z3.b
+; CHECK-NOI8MM-NEXT:    uunpkhi z2.h, z2.b
+; CHECK-NOI8MM-NEXT:    sunpkhi z3.h, z3.b
+; CHECK-NOI8MM-NEXT:    ptrue p0.d
+; CHECK-NOI8MM-NEXT:    uunpklo z6.s, z4.h
+; CHECK-NOI8MM-NEXT:    uunpkhi z4.s, z4.h
+; CHECK-NOI8MM-NEXT:    sunpklo z7.s, z5.h
+; CHECK-NOI8MM-NEXT:    sunpkhi z5.s, z5.h
+; CHECK-NOI8MM-NEXT:    uunpklo z24.s, z2.h
+; CHECK-NOI8MM-NEXT:    uunpkhi z2.s, z2.h
+; CHECK-NOI8MM-NEXT:    sunpklo z25.s, z3.h
+; CHECK-NOI8MM-NEXT:    sunpkhi z3.s, z3.h
+; CHECK-NOI8MM-NEXT:    uunpkhi z26.d, z6.s
+; CHECK-NOI8MM-NEXT:    uunpklo z6.d, z6.s
+; CHECK-NOI8MM-NEXT:    uunpklo z27.d, z4.s
+; CHECK-NOI8MM-NEXT:    sunpklo z28.d, z7.s
+; CHECK-NOI8MM-NEXT:    sunpklo z29.d, z5.s
+; CHECK-NOI8MM-NEXT:    uunpkhi z4.d, z4.s
+; CHECK-NOI8MM-NEXT:    sunpkhi z7.d, z7.s
+; CHECK-NOI8MM-NEXT:    sunpkhi z5.d, z5.s
+; CHECK-NOI8MM-NEXT:    uunpkhi z30.d, z24.s
+; CHECK-NOI8MM-NEXT:    uunpkhi z31.d, z2.s
+; CHECK-NOI8MM-NEXT:    uunpklo z24.d, z24.s
+; CHECK-NOI8MM-NEXT:    uunpklo z2.d, z2.s
+; CHECK-NOI8MM-NEXT:    sunpkhi z8.d, z25.s
+; CHECK-NOI8MM-NEXT:    sunpklo z25.d, z25.s
+; CHECK-NOI8MM-NEXT:    sunpklo z9.d, z3.s
+; CHECK-NOI8MM-NEXT:    mul z27.d, z27.d, z29.d
+; CHECK-NOI8MM-NEXT:    mla z0.d, p0/m, z6.d, z28.d
+; CHECK-NOI8MM-NEXT:    sunpkhi z3.d, z3.s
+; CHECK-NOI8MM-NEXT:    mul z4.d, z4.d, z5.d
+; CHECK-NOI8MM-NEXT:    mla z1.d, p0/m, z26.d, z7.d
+; CHECK-NOI8MM-NEXT:    mla z0.d, p0/m, z2.d, z9.d
+; CHECK-NOI8MM-NEXT:    movprfx z2, z27
+; CHECK-NOI8MM-NEXT:    mla z2.d, p0/m, z24.d, z25.d
+; CHECK-NOI8MM-NEXT:    ldr z9, [sp] // 16-byte Folded Reload
+; CHECK-NOI8MM-NEXT:    mla z1.d, p0/m, z31.d, z3.d
+; CHECK-NOI8MM-NEXT:    movprfx z3, z4
+; CHECK-NOI8MM-NEXT:    mla z3.d, p0/m, z30.d, z8.d
+; CHECK-NOI8MM-NEXT:    ldr z8, [sp, #1, mul vl] // 16-byte Folded Reload
+; CHECK-NOI8MM-NEXT:    add z0.d, z2.d, z0.d
+; CHECK-NOI8MM-NEXT:    add z1.d, z3.d, z1.d
+; CHECK-NOI8MM-NEXT:    addvl sp, sp, #2
+; CHECK-NOI8MM-NEXT:    ldr x29, [sp], #16 // 8-byte Folded Reload
+; CHECK-NOI8MM-NEXT:    ret
+entry: ; unsigned %a times signed %b, partially reduced into a <vscale x 4 x i64> accumulator
+  %a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i64>
+  %b.wide = sext <vscale x 16 x i8> %b to <vscale x 16 x i64>
+  %mult = mul nsw <vscale x 16 x i64> %a.wide, %b.wide ; no nuw: negative %b lanes wrap unsigned; |255*-128| < 2^15 so nsw holds
+  %partial.reduce = tail call <vscale x 4 x i64> @llvm.experimental.vector.partial.reduce.add.nxv4i64.nxv16i64(
+  <vscale x 4 x i64> %acc, <vscale x 16 x i64> %mult)
+  ret <vscale x 4 x i64> %partial.reduce
+}
+
+define <vscale x 4 x i64> @sudot_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
+; CHECK-I8MM-LABEL: sudot_8to64:
+; CHECK-I8MM:       // %bb.0: // %entry
+; CHECK-I8MM-NEXT:    mov z4.s, #0 // =0x0
+; CHECK-I8MM-NEXT:    usdot z4.s, z3.b, z2.b
+; CHECK-I8MM-NEXT:    sunpklo z2.d, z4.s
+; CHECK-I8MM-NEXT:    sunpkhi z3.d, z4.s
+; CHECK-I8MM-NEXT:    add z0.d, z0.d, z2.d
+; CHECK-I8MM-NEXT:    add z1.d, z1.d, z3.d
+; CHECK-I8MM-NEXT:    ret
+;
+; CHECK-NOI8MM-LABEL: sudot_8to64:
+; CHECK-NOI8MM:       // %bb.0: // %entry
+; CHECK-NOI8MM-NEXT:    str x29, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NOI8MM-NEXT:    addvl sp, sp, #-2
+; CHECK-NOI8MM-NEXT:    str z9, [sp] // 16-byte Folded Spill
+; CHECK-NOI8MM-NEXT:    str z8, [sp, #1, mul vl] // 16-byte Folded Spill
+; CHECK-NOI8MM-NEXT:    .cfi_escape 0x0f, 0x0c, 0x8f, 0x00, 0x11, 0x10, 0x22, 0x11, 0x10, 0x92, 0x2e, 0x00, 0x1e, 0x22 // sp + 16 + 16 * VG
+; CHECK-NOI8MM-NEXT:    .cfi_offset w29, -16
+; CHECK-NOI8MM-NEXT:    .cfi_escape 0x10, 0x48, 0x0a, 0x11, 0x70, 0x22, 0x11, 0x78, 0x92, 0x2e, 0x00, 0x1e, 0x22 // $d8 @ cfa - 16 - 8 * VG
+; CHECK-NOI8MM-NEXT:    .cfi_escape 0x10, 0x49, 0x0a, 0x11, 0x70, 0x22, 0x11, 0x70, 0x92, 0x2e, 0x00, 0x1e, 0x22 // $d9 @ cfa - 16 - 16 * VG
+; CHECK-NOI8MM-NEXT:    sunpklo z4.h, z2.b
+; CHECK-NOI8MM-NEXT:    uunpklo z5.h, z3.b
+; CHECK-NOI8MM-NEXT:    sunpkhi z2.h, z2.b
+; CHECK-NOI8MM-NEXT:    uunpkhi z3.h, z3.b
+; CHECK-NOI8MM-NEXT:    ptrue p0.d
+; CHECK-NOI8MM-NEXT:    sunpklo z6.s, z4.h
+; CHECK-NOI8MM-NEXT:    sunpkhi z4.s, z4.h
+; CHECK-NOI8MM-NEXT:    uunpklo z7.s, z5.h
+; CHECK-NOI8MM-NEXT:    uunpkhi z5.s, z5.h
+; CHECK-NOI8MM-NEXT:    sunpklo z24.s, z2.h
+; CHECK-NOI8MM-NEXT:    sunpkhi z2.s, z2.h
+; CHECK-NOI8MM-NEXT:    uunpklo z25.s, z3.h
+; CHECK-NOI8MM-NEXT:    uunpkhi z3.s, z3.h
+; CHECK-NOI8MM-NEXT:    sunpkhi z26.d, z6.s
+; CHECK-NOI8MM-NEXT:    sunpklo z6.d, z6.s
+; CHECK-NOI8MM-NEXT:    sunpklo z27.d, z4.s
+; CHECK-NOI8MM-NEXT:    uunpklo z28.d, z7.s
+; CHECK-NOI8MM-NEXT:    uunpklo z29.d, z5.s
+; CHECK-NOI8MM-NEXT:    sunpkhi z4.d, z4.s
+; CHECK-NOI8MM-NEXT:    uunpkhi z7.d, z7.s
+; CHECK-NOI8MM-NEXT:    uunpkhi z5.d, z5.s
+; CHECK-NOI8MM-NEXT:    sunpkhi z30.d, z24.s
+; CHECK-NOI8MM-NEXT:    sunpkhi z31.d, z2.s
+; CHECK-NOI8MM-NEXT:    sunpklo z24.d, z24.s
+; CHECK-NOI8MM-NEXT:    sunpklo z2.d, z2.s
+; CHECK-NOI8MM-NEXT:    uunpkhi z8.d, z25.s
+; CHECK-NOI8MM-NEXT:    uunpklo z25.d, z25.s
+; CHECK-NOI8MM-NEXT:    uunpklo z9.d, z3.s
+; CHECK-NOI8MM-NEXT:    mul z27.d, z27.d, z29.d
+; CHECK-NOI8MM-NEXT:    mla z0.d, p0/m, z6.d, z28.d
+; CHECK-NOI8MM-NEXT:    uunpkhi z3.d, z3.s
+; CHECK-NOI8MM-NEXT:    mul z4.d, z4.d, z5.d
+; CHECK-NOI8MM-NEXT:    mla z1.d, p0/m, z26.d, z7.d
+; CHECK-NOI8MM-NEXT:    mla z0.d, p0/m, z2.d, z9.d
+; CHECK-NOI8MM-NEXT:    movprfx z2, z27
+; CHECK-NOI8MM-NEXT:    mla z2.d, p0/m, z24.d, z25.d
+; CHECK-NOI8MM-NEXT:    ldr z9, [sp] // 16-byte Folded Reload
+; CHECK-NOI8MM-NEXT:    mla z1.d, p0/m, z31.d, z3.d
+; CHECK-NOI8MM-NEXT:    movprfx z3, z4
+; CHECK-NOI8MM-NEXT:    mla z3.d, p0/m, z30.d, z8.d
+; CHECK-NOI8MM-NEXT:    ldr z8, [sp, #1, mul vl] // 16-byte Folded Reload
+; CHECK-NOI8MM-NEXT:    add z0.d, z2.d, z0.d
+; CHECK-NOI8MM-NEXT:    add z1.d, z3.d, z1.d
+; CHECK-NOI8MM-NEXT:    addvl sp, sp, #2
+; CHECK-NOI8MM-NEXT:    ldr x29, [sp], #16 // 8-byte Folded Reload
+; CHECK-NOI8MM-NEXT:    ret
+entry: ; signed %a times unsigned %b, partially reduced into a <vscale x 4 x i64> accumulator
+  %a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i64>
+  %b.wide = zext <vscale x 16 x i8> %b to <vscale x 16 x i64>
+  %mult = mul nsw <vscale x 16 x i64> %a.wide, %b.wide ; no nuw: negative %a lanes wrap unsigned; |-128*255| < 2^15 so nsw holds
+  %partial.reduce = tail call <vscale x 4 x i64> @llvm.experimental.vector.partial.reduce.add.nxv4i64.nxv16i64(
+  <vscale x 4 x i64> %acc, <vscale x 16 x i64> %mult)
+  ret <vscale x 4 x i64> %partial.reduce
+}
+
 define <vscale x 4 x i32> @not_udot(<vscale x 4 x i32> %acc, <vscale x 8 x i8> %a, <vscale x 8 x i8> %b) {
 ; CHECK-LABEL: not_udot:
 ; CHECK:       // %bb.0: // %entry

``````````

</details>


https://github.com/llvm/llvm-project/pull/110220


More information about the llvm-commits mailing list