[llvm] [AArch64][NEON][SVE] Lower i8 to i64 partial reduction to a dot product (PR #110220)
James Chesterman via llvm-commits
llvm-commits at lists.llvm.org
Fri Sep 27 07:34:12 PDT 2024
https://github.com/JamesChesterman updated https://github.com/llvm/llvm-project/pull/110220
>From 58b92fd1b9119e59e36a206e55cb52aaaa674a14 Mon Sep 17 00:00:00 2001
From: James Chesterman <james.chesterman at arm.com>
Date: Mon, 23 Sep 2024 14:17:25 +0000
Subject: [PATCH 1/3] [AArch64][NEON][SVE] Lower i8 to i64 partial reduction to
a dot product
Instead of being expanded, an i8 to i64 partial reduction can be lowered
to an i8 to i32 dot product whose result is sign extended and added to
the accumulator.
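For readers less familiar with the partial-reduce intrinsic, here is a small, self-contained C++ reference model of the transform (illustration only: partialReduceDirect, partialReduceViaDot, Vec16x8 and Acc4x64 are made-up names, and nothing below is LLVM API). It computes a 16 x i8 -> 4 x i64 partial reduction of a widened multiply two ways, using one valid lane grouping (four consecutive products per accumulator lane, as the dot instructions do), and checks that an i32 dot followed by a sign extension gives the same result. The signed case is shown; the unsigned and mixed-sign cases follow the same shape.

// Illustrative scalar model only -- none of these names exist in LLVM.
#include <array>
#include <cassert>
#include <cstdint>

using Vec16x8 = std::array<int8_t, 16>;
using Acc4x64 = std::array<int64_t, 4>;

// Direct form: widen each i8 to i64, multiply, and fold four consecutive
// products into each accumulator lane (the grouping dot instructions use).
Acc4x64 partialReduceDirect(Acc4x64 Acc, const Vec16x8 &A, const Vec16x8 &B) {
  for (int Lane = 0; Lane < 4; ++Lane)
    for (int J = 0; J < 4; ++J)
      Acc[Lane] += int64_t(A[4 * Lane + J]) * int64_t(B[4 * Lane + J]);
  return Acc;
}

// Dot-product form: a 4-way i8*i8 dot accumulated in i32, then sign extended
// and added to the i64 accumulator. The i32 sum cannot overflow: the worst
// signed case is 4 * (-128) * (-128) = 65536 and the worst unsigned case is
// 4 * 255 * 255 = 260100, both far below 2^31 -- which is also why a sign
// extension of the (non-negative) unsigned dot result equals a zero extension.
Acc4x64 partialReduceViaDot(Acc4x64 Acc, const Vec16x8 &A, const Vec16x8 &B) {
  for (int Lane = 0; Lane < 4; ++Lane) {
    int32_t Dot = 0; // mirrors the zero accumulator fed to the dot node
    for (int J = 0; J < 4; ++J)
      Dot += int32_t(A[4 * Lane + J]) * int32_t(B[4 * Lane + J]);
    Acc[Lane] += int64_t(Dot); // sign extension to 64 bits
  }
  return Acc;
}

int main() {
  Vec16x8 A, B;
  for (int I = 0; I < 16; ++I) {
    A[I] = int8_t(I * 17 - 128);
    B[I] = int8_t(127 - I * 9);
  }
  Acc4x64 Acc = {1, -2, 3, -4};
  assert(partialReduceDirect(Acc, A, B) == partialReduceViaDot(Acc, A, B));
  return 0;
}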
---
.../Target/AArch64/AArch64ISelLowering.cpp | 24 ++-
.../neon-partial-reduce-dot-product.ll | 156 ++++++++++++++
.../AArch64/sve-partial-reduce-dot-product.ll | 190 ++++++++++++++++++
3 files changed, 366 insertions(+), 4 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 4166d9bd22bc01..af66b6b0e43b81 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1996,8 +1996,8 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
return true;
EVT VT = EVT::getEVT(I->getType());
- return VT != MVT::nxv4i32 && VT != MVT::nxv2i64 && VT != MVT::v4i32 &&
- VT != MVT::v2i32;
+ return VT != MVT::nxv4i64 && VT != MVT::nxv4i32 && VT != MVT::nxv2i64 &&
+ VT != MVT::v4i64 && VT != MVT::v4i32 && VT != MVT::v2i32;
}
bool AArch64TargetLowering::shouldExpandCttzElements(EVT VT) const {
@@ -21916,8 +21916,10 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
// Dot products operate on chunks of four elements so there must be four times
// as many elements in the wide type
- if (!(ReducedType == MVT::nxv4i32 && MulSrcType == MVT::nxv16i8) &&
+ if (!(ReducedType == MVT::nxv4i64 && MulSrcType == MVT::nxv16i8) &&
+ !(ReducedType == MVT::nxv4i32 && MulSrcType == MVT::nxv16i8) &&
!(ReducedType == MVT::nxv2i64 && MulSrcType == MVT::nxv8i16) &&
+ !(ReducedType == MVT::v4i64 && MulSrcType == MVT::v16i8) &&
!(ReducedType == MVT::v4i32 && MulSrcType == MVT::v16i8) &&
!(ReducedType == MVT::v2i32 && MulSrcType == MVT::v8i8))
return SDValue();
@@ -21930,7 +21932,7 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
bool Scalable = N->getValueType(0).isScalableVT();
// There's no nxv2i64 version of usdot
- if (Scalable && ReducedType != MVT::nxv4i32)
+ if (Scalable && ReducedType != MVT::nxv4i32 && ReducedType != MVT::nxv4i64)
return SDValue();
Opcode = AArch64ISD::USDOT;
@@ -21942,6 +21944,20 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
else
Opcode = AArch64ISD::UDOT;
+ // Partial reduction lowering for (nx)v16i8 to (nx)v4i64 requires an i32 dot
+ // product followed by a zero / sign extension
+ if ((ReducedType == MVT::nxv4i64 && MulSrcType == MVT::nxv16i8) ||
+ (ReducedType == MVT::v4i64 && MulSrcType == MVT::v16i8)) {
+ EVT ReducedTypeHalved = (ReducedType.isScalableVector()) ? MVT::nxv4i32 : MVT::v4i32;
+
+ auto Doti32 =
+ DAG.getNode(Opcode, DL, ReducedTypeHalved,
+ DAG.getConstant(0, DL, ReducedTypeHalved), A, B);
+ auto Extended = DAG.getSExtOrTrunc(Doti32, DL, ReducedType);
+ return DAG.getNode(ISD::ADD, DL, NarrowOp.getValueType(),
+ {NarrowOp, Extended});
+ }
+
return DAG.getNode(Opcode, DL, ReducedType, NarrowOp, A, B);
}
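For the fixed-length v4i64 case, the node sequence built above (i32 dot with a zero accumulator, sign extension, add) corresponds to the udot + saddw/saddw2 pattern checked under CHECK-DOT in the NEON test below. A hedged arm_neon.h sketch of that shape, assuming an AArch64 compiler with the dotprod extension: the intrinsics are standard ACLE ones, but udot_8to64_sketch itself is a made-up illustration, not the generated code.

#include <arm_neon.h>

// Illustration only: mirrors the CHECK-DOT sequence for udot_8to64 below.
void udot_8to64_sketch(int64x2_t AccLo, int64x2_t AccHi, uint8x16_t A,
                       uint8x16_t B, int64x2_t *OutLo, int64x2_t *OutHi) {
  uint32x4_t Dot = vdotq_u32(vdupq_n_u32(0), A, B); // udot v4.4s, v2.16b, v3.16b
  int32x4_t DotS = vreinterpretq_s32_u32(Dot);      // lanes are < 2^31, so the
                                                    // sign extension below is
                                                    // value-preserving
  *OutLo = vaddw_s32(AccLo, vget_low_s32(DotS));    // saddw  v0.2d, v0.2d, v4.2s
  *OutHi = vaddw_high_s32(AccHi, DotS);             // saddw2 v1.2d, v1.2d, v4.4s
}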
diff --git a/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
index 841da1f8ea57c1..c1b9a4c9dbb797 100644
--- a/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
@@ -211,6 +211,162 @@ define <2 x i32> @sudot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) #0{
ret <2 x i32> %partial.reduce
}
+define <4 x i64> @udot_8to64(<4 x i64> %acc, <16 x i8> %a, <16 x i8> %b) {
+; CHECK-DOT-LABEL: udot_8to64:
+; CHECK-DOT: // %bb.0: // %entry
+; CHECK-DOT-NEXT: movi v4.2d, #0000000000000000
+; CHECK-DOT-NEXT: udot v4.4s, v2.16b, v3.16b
+; CHECK-DOT-NEXT: saddw2 v1.2d, v1.2d, v4.4s
+; CHECK-DOT-NEXT: saddw v0.2d, v0.2d, v4.2s
+; CHECK-DOT-NEXT: ret
+;
+; CHECK-NODOT-LABEL: udot_8to64:
+; CHECK-NODOT: // %bb.0: // %entry
+; CHECK-NODOT-NEXT: umull v4.8h, v2.8b, v3.8b
+; CHECK-NODOT-NEXT: umull2 v2.8h, v2.16b, v3.16b
+; CHECK-NODOT-NEXT: ushll v3.4s, v4.4h, #0
+; CHECK-NODOT-NEXT: ushll v5.4s, v2.4h, #0
+; CHECK-NODOT-NEXT: ushll2 v4.4s, v4.8h, #0
+; CHECK-NODOT-NEXT: ushll2 v2.4s, v2.8h, #0
+; CHECK-NODOT-NEXT: uaddw2 v1.2d, v1.2d, v3.4s
+; CHECK-NODOT-NEXT: uaddw v0.2d, v0.2d, v3.2s
+; CHECK-NODOT-NEXT: uaddl2 v3.2d, v4.4s, v5.4s
+; CHECK-NODOT-NEXT: uaddl v4.2d, v4.2s, v5.2s
+; CHECK-NODOT-NEXT: uaddw2 v1.2d, v1.2d, v2.4s
+; CHECK-NODOT-NEXT: uaddw v0.2d, v0.2d, v2.2s
+; CHECK-NODOT-NEXT: add v1.2d, v3.2d, v1.2d
+; CHECK-NODOT-NEXT: add v0.2d, v4.2d, v0.2d
+; CHECK-NODOT-NEXT: ret
+entry:
+ %a.wide = zext <16 x i8> %a to <16 x i64>
+ %b.wide = zext <16 x i8> %b to <16 x i64>
+ %mult = mul nuw nsw <16 x i64> %a.wide, %b.wide
+ %partial.reduce = tail call <4 x i64> @llvm.experimental.vector.partial.reduce.add.v4i64.v16i64(
+ <4 x i64> %acc, <16 x i64> %mult)
+ ret <4 x i64> %partial.reduce
+}
+
+define <4 x i64> @sdot_8to64(<4 x i64> %acc, <16 x i8> %a, <16 x i8> %b){
+; CHECK-DOT-LABEL: sdot_8to64:
+; CHECK-DOT: // %bb.0: // %entry
+; CHECK-DOT-NEXT: movi v4.2d, #0000000000000000
+; CHECK-DOT-NEXT: sdot v4.4s, v2.16b, v3.16b
+; CHECK-DOT-NEXT: saddw2 v1.2d, v1.2d, v4.4s
+; CHECK-DOT-NEXT: saddw v0.2d, v0.2d, v4.2s
+; CHECK-DOT-NEXT: ret
+;
+; CHECK-NODOT-LABEL: sdot_8to64:
+; CHECK-NODOT: // %bb.0: // %entry
+; CHECK-NODOT-NEXT: smull v4.8h, v2.8b, v3.8b
+; CHECK-NODOT-NEXT: smull2 v2.8h, v2.16b, v3.16b
+; CHECK-NODOT-NEXT: sshll v3.4s, v4.4h, #0
+; CHECK-NODOT-NEXT: sshll v5.4s, v2.4h, #0
+; CHECK-NODOT-NEXT: sshll2 v4.4s, v4.8h, #0
+; CHECK-NODOT-NEXT: sshll2 v2.4s, v2.8h, #0
+; CHECK-NODOT-NEXT: saddw2 v1.2d, v1.2d, v3.4s
+; CHECK-NODOT-NEXT: saddw v0.2d, v0.2d, v3.2s
+; CHECK-NODOT-NEXT: saddl2 v3.2d, v4.4s, v5.4s
+; CHECK-NODOT-NEXT: saddl v4.2d, v4.2s, v5.2s
+; CHECK-NODOT-NEXT: saddw2 v1.2d, v1.2d, v2.4s
+; CHECK-NODOT-NEXT: saddw v0.2d, v0.2d, v2.2s
+; CHECK-NODOT-NEXT: add v1.2d, v3.2d, v1.2d
+; CHECK-NODOT-NEXT: add v0.2d, v4.2d, v0.2d
+; CHECK-NODOT-NEXT: ret
+entry:
+ %a.wide = sext <16 x i8> %a to <16 x i64>
+ %b.wide = sext <16 x i8> %b to <16 x i64>
+ %mult = mul nuw nsw <16 x i64> %a.wide, %b.wide
+ %partial.reduce = tail call <4 x i64> @llvm.experimental.vector.partial.reduce.add.v4i64.v16i64(
+ <4 x i64> %acc, <16 x i64> %mult)
+ ret <4 x i64> %partial.reduce
+}
+
+define <4 x i64> @usdot_8to64(<4 x i64> %acc, <16 x i8> %a, <16 x i8> %b){
+; CHECK-NOI8MM-LABEL: usdot_8to64:
+; CHECK-NOI8MM: // %bb.0: // %entry
+; CHECK-NOI8MM-NEXT: ushll v4.8h, v2.8b, #0
+; CHECK-NOI8MM-NEXT: sshll v5.8h, v3.8b, #0
+; CHECK-NOI8MM-NEXT: ushll2 v2.8h, v2.16b, #0
+; CHECK-NOI8MM-NEXT: sshll2 v3.8h, v3.16b, #0
+; CHECK-NOI8MM-NEXT: ushll v6.4s, v4.4h, #0
+; CHECK-NOI8MM-NEXT: sshll v7.4s, v5.4h, #0
+; CHECK-NOI8MM-NEXT: ushll2 v4.4s, v4.8h, #0
+; CHECK-NOI8MM-NEXT: sshll2 v5.4s, v5.8h, #0
+; CHECK-NOI8MM-NEXT: ushll2 v16.4s, v2.8h, #0
+; CHECK-NOI8MM-NEXT: sshll2 v17.4s, v3.8h, #0
+; CHECK-NOI8MM-NEXT: ushll v2.4s, v2.4h, #0
+; CHECK-NOI8MM-NEXT: sshll v3.4s, v3.4h, #0
+; CHECK-NOI8MM-NEXT: smlal2 v1.2d, v6.4s, v7.4s
+; CHECK-NOI8MM-NEXT: smlal v0.2d, v6.2s, v7.2s
+; CHECK-NOI8MM-NEXT: smull v18.2d, v4.2s, v5.2s
+; CHECK-NOI8MM-NEXT: smull2 v4.2d, v4.4s, v5.4s
+; CHECK-NOI8MM-NEXT: smlal2 v1.2d, v16.4s, v17.4s
+; CHECK-NOI8MM-NEXT: smlal v0.2d, v16.2s, v17.2s
+; CHECK-NOI8MM-NEXT: smlal2 v4.2d, v2.4s, v3.4s
+; CHECK-NOI8MM-NEXT: smlal v18.2d, v2.2s, v3.2s
+; CHECK-NOI8MM-NEXT: add v1.2d, v4.2d, v1.2d
+; CHECK-NOI8MM-NEXT: add v0.2d, v18.2d, v0.2d
+; CHECK-NOI8MM-NEXT: ret
+;
+; CHECK-I8MM-LABEL: usdot_8to64:
+; CHECK-I8MM: // %bb.0: // %entry
+; CHECK-I8MM-NEXT: movi v4.2d, #0000000000000000
+; CHECK-I8MM-NEXT: usdot v4.4s, v2.16b, v3.16b
+; CHECK-I8MM-NEXT: saddw2 v1.2d, v1.2d, v4.4s
+; CHECK-I8MM-NEXT: saddw v0.2d, v0.2d, v4.2s
+; CHECK-I8MM-NEXT: ret
+entry:
+ %a.wide = zext <16 x i8> %a to <16 x i64>
+ %b.wide = sext <16 x i8> %b to <16 x i64>
+ %mult = mul nuw nsw <16 x i64> %a.wide, %b.wide
+ %partial.reduce = tail call <4 x i64> @llvm.experimental.vector.partial.reduce.add.v4i64.v16i64(
+ <4 x i64> %acc, <16 x i64> %mult)
+ ret <4 x i64> %partial.reduce
+}
+
+define <4 x i64> @sudot_8to64(<4 x i64> %acc, <16 x i8> %a, <16 x i8> %b) {
+; CHECK-NOI8MM-LABEL: sudot_8to64:
+; CHECK-NOI8MM: // %bb.0: // %entry
+; CHECK-NOI8MM-NEXT: sshll v4.8h, v2.8b, #0
+; CHECK-NOI8MM-NEXT: ushll v5.8h, v3.8b, #0
+; CHECK-NOI8MM-NEXT: sshll2 v2.8h, v2.16b, #0
+; CHECK-NOI8MM-NEXT: ushll2 v3.8h, v3.16b, #0
+; CHECK-NOI8MM-NEXT: sshll v6.4s, v4.4h, #0
+; CHECK-NOI8MM-NEXT: ushll v7.4s, v5.4h, #0
+; CHECK-NOI8MM-NEXT: sshll2 v4.4s, v4.8h, #0
+; CHECK-NOI8MM-NEXT: ushll2 v5.4s, v5.8h, #0
+; CHECK-NOI8MM-NEXT: sshll2 v16.4s, v2.8h, #0
+; CHECK-NOI8MM-NEXT: ushll2 v17.4s, v3.8h, #0
+; CHECK-NOI8MM-NEXT: sshll v2.4s, v2.4h, #0
+; CHECK-NOI8MM-NEXT: ushll v3.4s, v3.4h, #0
+; CHECK-NOI8MM-NEXT: smlal2 v1.2d, v6.4s, v7.4s
+; CHECK-NOI8MM-NEXT: smlal v0.2d, v6.2s, v7.2s
+; CHECK-NOI8MM-NEXT: smull v18.2d, v4.2s, v5.2s
+; CHECK-NOI8MM-NEXT: smull2 v4.2d, v4.4s, v5.4s
+; CHECK-NOI8MM-NEXT: smlal2 v1.2d, v16.4s, v17.4s
+; CHECK-NOI8MM-NEXT: smlal v0.2d, v16.2s, v17.2s
+; CHECK-NOI8MM-NEXT: smlal2 v4.2d, v2.4s, v3.4s
+; CHECK-NOI8MM-NEXT: smlal v18.2d, v2.2s, v3.2s
+; CHECK-NOI8MM-NEXT: add v1.2d, v4.2d, v1.2d
+; CHECK-NOI8MM-NEXT: add v0.2d, v18.2d, v0.2d
+; CHECK-NOI8MM-NEXT: ret
+;
+; CHECK-I8MM-LABEL: sudot_8to64:
+; CHECK-I8MM: // %bb.0: // %entry
+; CHECK-I8MM-NEXT: movi v4.2d, #0000000000000000
+; CHECK-I8MM-NEXT: usdot v4.4s, v3.16b, v2.16b
+; CHECK-I8MM-NEXT: saddw2 v1.2d, v1.2d, v4.4s
+; CHECK-I8MM-NEXT: saddw v0.2d, v0.2d, v4.2s
+; CHECK-I8MM-NEXT: ret
+entry:
+ %a.wide = sext <16 x i8> %a to <16 x i64>
+ %b.wide = zext <16 x i8> %b to <16 x i64>
+ %mult = mul nuw nsw <16 x i64> %a.wide, %b.wide
+ %partial.reduce = tail call <4 x i64> @llvm.experimental.vector.partial.reduce.add.v4i64.v16i64(
+ <4 x i64> %acc, <16 x i64> %mult)
+ ret <4 x i64> %partial.reduce
+}
+
define <4 x i32> @not_udot(<4 x i32> %acc, <8 x i8> %u, <8 x i8> %s) #0{
; CHECK-LABEL: not_udot:
; CHECK: // %bb.0:
diff --git a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
index 00e5ac479d02c9..66d6e0388bbf94 100644
--- a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
@@ -126,6 +126,196 @@ entry:
ret <vscale x 4 x i32> %partial.reduce
}
+define <vscale x 4 x i64> @udot_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
+; CHECK-LABEL: udot_8to64:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: mov z4.s, #0 // =0x0
+; CHECK-NEXT: udot z4.s, z2.b, z3.b
+; CHECK-NEXT: sunpklo z2.d, z4.s
+; CHECK-NEXT: sunpkhi z3.d, z4.s
+; CHECK-NEXT: add z0.d, z0.d, z2.d
+; CHECK-NEXT: add z1.d, z1.d, z3.d
+; CHECK-NEXT: ret
+entry:
+ %a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i64>
+ %b.wide = zext <vscale x 16 x i8> %b to <vscale x 16 x i64>
+ %mult = mul nuw nsw <vscale x 16 x i64> %a.wide, %b.wide
+ %partial.reduce = tail call <vscale x 4 x i64> @llvm.experimental.vector.partial.reduce.add.nxv4i64.nxv16i64(
+ <vscale x 4 x i64> %acc, <vscale x 16 x i64> %mult)
+ ret <vscale x 4 x i64> %partial.reduce
+}
+
+define <vscale x 4 x i64> @sdot_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b){
+; CHECK-LABEL: sdot_8to64:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: mov z4.s, #0 // =0x0
+; CHECK-NEXT: sdot z4.s, z2.b, z3.b
+; CHECK-NEXT: sunpklo z2.d, z4.s
+; CHECK-NEXT: sunpkhi z3.d, z4.s
+; CHECK-NEXT: add z0.d, z0.d, z2.d
+; CHECK-NEXT: add z1.d, z1.d, z3.d
+; CHECK-NEXT: ret
+entry:
+ %a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i64>
+ %b.wide = sext <vscale x 16 x i8> %b to <vscale x 16 x i64>
+ %mult = mul nuw nsw <vscale x 16 x i64> %a.wide, %b.wide
+ %partial.reduce = tail call <vscale x 4 x i64> @llvm.experimental.vector.partial.reduce.add.nxv4i64.nxv16i64(
+ <vscale x 4 x i64> %acc, <vscale x 16 x i64> %mult)
+ ret <vscale x 4 x i64> %partial.reduce
+}
+
+define <vscale x 4 x i64> @usdot_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b){
+; CHECK-I8MM-LABEL: usdot_8to64:
+; CHECK-I8MM: // %bb.0: // %entry
+; CHECK-I8MM-NEXT: mov z4.s, #0 // =0x0
+; CHECK-I8MM-NEXT: usdot z4.s, z2.b, z3.b
+; CHECK-I8MM-NEXT: sunpklo z2.d, z4.s
+; CHECK-I8MM-NEXT: sunpkhi z3.d, z4.s
+; CHECK-I8MM-NEXT: add z0.d, z0.d, z2.d
+; CHECK-I8MM-NEXT: add z1.d, z1.d, z3.d
+; CHECK-I8MM-NEXT: ret
+;
+; CHECK-NOI8MM-LABEL: usdot_8to64:
+; CHECK-NOI8MM: // %bb.0: // %entry
+; CHECK-NOI8MM-NEXT: str x29, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NOI8MM-NEXT: addvl sp, sp, #-2
+; CHECK-NOI8MM-NEXT: str z9, [sp] // 16-byte Folded Spill
+; CHECK-NOI8MM-NEXT: str z8, [sp, #1, mul vl] // 16-byte Folded Spill
+; CHECK-NOI8MM-NEXT: .cfi_escape 0x0f, 0x0c, 0x8f, 0x00, 0x11, 0x10, 0x22, 0x11, 0x10, 0x92, 0x2e, 0x00, 0x1e, 0x22 // sp + 16 + 16 * VG
+; CHECK-NOI8MM-NEXT: .cfi_offset w29, -16
+; CHECK-NOI8MM-NEXT: .cfi_escape 0x10, 0x48, 0x0a, 0x11, 0x70, 0x22, 0x11, 0x78, 0x92, 0x2e, 0x00, 0x1e, 0x22 // $d8 @ cfa - 16 - 8 * VG
+; CHECK-NOI8MM-NEXT: .cfi_escape 0x10, 0x49, 0x0a, 0x11, 0x70, 0x22, 0x11, 0x70, 0x92, 0x2e, 0x00, 0x1e, 0x22 // $d9 @ cfa - 16 - 16 * VG
+; CHECK-NOI8MM-NEXT: uunpklo z4.h, z2.b
+; CHECK-NOI8MM-NEXT: sunpklo z5.h, z3.b
+; CHECK-NOI8MM-NEXT: uunpkhi z2.h, z2.b
+; CHECK-NOI8MM-NEXT: sunpkhi z3.h, z3.b
+; CHECK-NOI8MM-NEXT: ptrue p0.d
+; CHECK-NOI8MM-NEXT: uunpklo z6.s, z4.h
+; CHECK-NOI8MM-NEXT: uunpkhi z4.s, z4.h
+; CHECK-NOI8MM-NEXT: sunpklo z7.s, z5.h
+; CHECK-NOI8MM-NEXT: sunpkhi z5.s, z5.h
+; CHECK-NOI8MM-NEXT: uunpklo z24.s, z2.h
+; CHECK-NOI8MM-NEXT: uunpkhi z2.s, z2.h
+; CHECK-NOI8MM-NEXT: sunpklo z25.s, z3.h
+; CHECK-NOI8MM-NEXT: sunpkhi z3.s, z3.h
+; CHECK-NOI8MM-NEXT: uunpkhi z26.d, z6.s
+; CHECK-NOI8MM-NEXT: uunpklo z6.d, z6.s
+; CHECK-NOI8MM-NEXT: uunpklo z27.d, z4.s
+; CHECK-NOI8MM-NEXT: sunpklo z28.d, z7.s
+; CHECK-NOI8MM-NEXT: sunpklo z29.d, z5.s
+; CHECK-NOI8MM-NEXT: uunpkhi z4.d, z4.s
+; CHECK-NOI8MM-NEXT: sunpkhi z7.d, z7.s
+; CHECK-NOI8MM-NEXT: sunpkhi z5.d, z5.s
+; CHECK-NOI8MM-NEXT: uunpkhi z30.d, z24.s
+; CHECK-NOI8MM-NEXT: uunpkhi z31.d, z2.s
+; CHECK-NOI8MM-NEXT: uunpklo z24.d, z24.s
+; CHECK-NOI8MM-NEXT: uunpklo z2.d, z2.s
+; CHECK-NOI8MM-NEXT: sunpkhi z8.d, z25.s
+; CHECK-NOI8MM-NEXT: sunpklo z25.d, z25.s
+; CHECK-NOI8MM-NEXT: sunpklo z9.d, z3.s
+; CHECK-NOI8MM-NEXT: mul z27.d, z27.d, z29.d
+; CHECK-NOI8MM-NEXT: mla z0.d, p0/m, z6.d, z28.d
+; CHECK-NOI8MM-NEXT: sunpkhi z3.d, z3.s
+; CHECK-NOI8MM-NEXT: mul z4.d, z4.d, z5.d
+; CHECK-NOI8MM-NEXT: mla z1.d, p0/m, z26.d, z7.d
+; CHECK-NOI8MM-NEXT: mla z0.d, p0/m, z2.d, z9.d
+; CHECK-NOI8MM-NEXT: movprfx z2, z27
+; CHECK-NOI8MM-NEXT: mla z2.d, p0/m, z24.d, z25.d
+; CHECK-NOI8MM-NEXT: ldr z9, [sp] // 16-byte Folded Reload
+; CHECK-NOI8MM-NEXT: mla z1.d, p0/m, z31.d, z3.d
+; CHECK-NOI8MM-NEXT: movprfx z3, z4
+; CHECK-NOI8MM-NEXT: mla z3.d, p0/m, z30.d, z8.d
+; CHECK-NOI8MM-NEXT: ldr z8, [sp, #1, mul vl] // 16-byte Folded Reload
+; CHECK-NOI8MM-NEXT: add z0.d, z2.d, z0.d
+; CHECK-NOI8MM-NEXT: add z1.d, z3.d, z1.d
+; CHECK-NOI8MM-NEXT: addvl sp, sp, #2
+; CHECK-NOI8MM-NEXT: ldr x29, [sp], #16 // 8-byte Folded Reload
+; CHECK-NOI8MM-NEXT: ret
+entry:
+ %a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i64>
+ %b.wide = sext <vscale x 16 x i8> %b to <vscale x 16 x i64>
+ %mult = mul nuw nsw <vscale x 16 x i64> %a.wide, %b.wide
+ %partial.reduce = tail call <vscale x 4 x i64> @llvm.experimental.vector.partial.reduce.add.nxv4i64.nxv16i64(
+ <vscale x 4 x i64> %acc, <vscale x 16 x i64> %mult)
+ ret <vscale x 4 x i64> %partial.reduce
+}
+
+define <vscale x 4 x i64> @sudot_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
+; CHECK-I8MM-LABEL: sudot_8to64:
+; CHECK-I8MM: // %bb.0: // %entry
+; CHECK-I8MM-NEXT: mov z4.s, #0 // =0x0
+; CHECK-I8MM-NEXT: usdot z4.s, z3.b, z2.b
+; CHECK-I8MM-NEXT: sunpklo z2.d, z4.s
+; CHECK-I8MM-NEXT: sunpkhi z3.d, z4.s
+; CHECK-I8MM-NEXT: add z0.d, z0.d, z2.d
+; CHECK-I8MM-NEXT: add z1.d, z1.d, z3.d
+; CHECK-I8MM-NEXT: ret
+;
+; CHECK-NOI8MM-LABEL: sudot_8to64:
+; CHECK-NOI8MM: // %bb.0: // %entry
+; CHECK-NOI8MM-NEXT: str x29, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NOI8MM-NEXT: addvl sp, sp, #-2
+; CHECK-NOI8MM-NEXT: str z9, [sp] // 16-byte Folded Spill
+; CHECK-NOI8MM-NEXT: str z8, [sp, #1, mul vl] // 16-byte Folded Spill
+; CHECK-NOI8MM-NEXT: .cfi_escape 0x0f, 0x0c, 0x8f, 0x00, 0x11, 0x10, 0x22, 0x11, 0x10, 0x92, 0x2e, 0x00, 0x1e, 0x22 // sp + 16 + 16 * VG
+; CHECK-NOI8MM-NEXT: .cfi_offset w29, -16
+; CHECK-NOI8MM-NEXT: .cfi_escape 0x10, 0x48, 0x0a, 0x11, 0x70, 0x22, 0x11, 0x78, 0x92, 0x2e, 0x00, 0x1e, 0x22 // $d8 @ cfa - 16 - 8 * VG
+; CHECK-NOI8MM-NEXT: .cfi_escape 0x10, 0x49, 0x0a, 0x11, 0x70, 0x22, 0x11, 0x70, 0x92, 0x2e, 0x00, 0x1e, 0x22 // $d9 @ cfa - 16 - 16 * VG
+; CHECK-NOI8MM-NEXT: sunpklo z4.h, z2.b
+; CHECK-NOI8MM-NEXT: uunpklo z5.h, z3.b
+; CHECK-NOI8MM-NEXT: sunpkhi z2.h, z2.b
+; CHECK-NOI8MM-NEXT: uunpkhi z3.h, z3.b
+; CHECK-NOI8MM-NEXT: ptrue p0.d
+; CHECK-NOI8MM-NEXT: sunpklo z6.s, z4.h
+; CHECK-NOI8MM-NEXT: sunpkhi z4.s, z4.h
+; CHECK-NOI8MM-NEXT: uunpklo z7.s, z5.h
+; CHECK-NOI8MM-NEXT: uunpkhi z5.s, z5.h
+; CHECK-NOI8MM-NEXT: sunpklo z24.s, z2.h
+; CHECK-NOI8MM-NEXT: sunpkhi z2.s, z2.h
+; CHECK-NOI8MM-NEXT: uunpklo z25.s, z3.h
+; CHECK-NOI8MM-NEXT: uunpkhi z3.s, z3.h
+; CHECK-NOI8MM-NEXT: sunpkhi z26.d, z6.s
+; CHECK-NOI8MM-NEXT: sunpklo z6.d, z6.s
+; CHECK-NOI8MM-NEXT: sunpklo z27.d, z4.s
+; CHECK-NOI8MM-NEXT: uunpklo z28.d, z7.s
+; CHECK-NOI8MM-NEXT: uunpklo z29.d, z5.s
+; CHECK-NOI8MM-NEXT: sunpkhi z4.d, z4.s
+; CHECK-NOI8MM-NEXT: uunpkhi z7.d, z7.s
+; CHECK-NOI8MM-NEXT: uunpkhi z5.d, z5.s
+; CHECK-NOI8MM-NEXT: sunpkhi z30.d, z24.s
+; CHECK-NOI8MM-NEXT: sunpkhi z31.d, z2.s
+; CHECK-NOI8MM-NEXT: sunpklo z24.d, z24.s
+; CHECK-NOI8MM-NEXT: sunpklo z2.d, z2.s
+; CHECK-NOI8MM-NEXT: uunpkhi z8.d, z25.s
+; CHECK-NOI8MM-NEXT: uunpklo z25.d, z25.s
+; CHECK-NOI8MM-NEXT: uunpklo z9.d, z3.s
+; CHECK-NOI8MM-NEXT: mul z27.d, z27.d, z29.d
+; CHECK-NOI8MM-NEXT: mla z0.d, p0/m, z6.d, z28.d
+; CHECK-NOI8MM-NEXT: uunpkhi z3.d, z3.s
+; CHECK-NOI8MM-NEXT: mul z4.d, z4.d, z5.d
+; CHECK-NOI8MM-NEXT: mla z1.d, p0/m, z26.d, z7.d
+; CHECK-NOI8MM-NEXT: mla z0.d, p0/m, z2.d, z9.d
+; CHECK-NOI8MM-NEXT: movprfx z2, z27
+; CHECK-NOI8MM-NEXT: mla z2.d, p0/m, z24.d, z25.d
+; CHECK-NOI8MM-NEXT: ldr z9, [sp] // 16-byte Folded Reload
+; CHECK-NOI8MM-NEXT: mla z1.d, p0/m, z31.d, z3.d
+; CHECK-NOI8MM-NEXT: movprfx z3, z4
+; CHECK-NOI8MM-NEXT: mla z3.d, p0/m, z30.d, z8.d
+; CHECK-NOI8MM-NEXT: ldr z8, [sp, #1, mul vl] // 16-byte Folded Reload
+; CHECK-NOI8MM-NEXT: add z0.d, z2.d, z0.d
+; CHECK-NOI8MM-NEXT: add z1.d, z3.d, z1.d
+; CHECK-NOI8MM-NEXT: addvl sp, sp, #2
+; CHECK-NOI8MM-NEXT: ldr x29, [sp], #16 // 8-byte Folded Reload
+; CHECK-NOI8MM-NEXT: ret
+entry:
+ %a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i64>
+ %b.wide = zext <vscale x 16 x i8> %b to <vscale x 16 x i64>
+ %mult = mul nuw nsw <vscale x 16 x i64> %a.wide, %b.wide
+ %partial.reduce = tail call <vscale x 4 x i64> @llvm.experimental.vector.partial.reduce.add.nxv4i64.nxv16i64(
+ <vscale x 4 x i64> %acc, <vscale x 16 x i64> %mult)
+ ret <vscale x 4 x i64> %partial.reduce
+}
+
define <vscale x 4 x i32> @not_udot(<vscale x 4 x i32> %acc, <vscale x 8 x i8> %a, <vscale x 8 x i8> %b) {
; CHECK-LABEL: not_udot:
; CHECK: // %bb.0: // %entry
>From 6d3314881aded9cd5d7b405e53a6a4dcb7d8498e Mon Sep 17 00:00:00 2001
From: James Chesterman <james.chesterman at arm.com>
Date: Fri, 27 Sep 2024 12:27:28 +0000
Subject: [PATCH 2/3] Fix formatting.
---
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 10 +++++-----
1 file changed, 5 insertions(+), 5 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index af66b6b0e43b81..1775de31e6d9a4 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -21948,14 +21948,14 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
// product followed by a zero / sign extension
if ((ReducedType == MVT::nxv4i64 && MulSrcType == MVT::nxv16i8) ||
(ReducedType == MVT::v4i64 && MulSrcType == MVT::v16i8)) {
- EVT ReducedTypeHalved = (ReducedType.isScalableVector()) ? MVT::nxv4i32 : MVT::v4i32;
+ EVT ReducedTypeHalved =
+ (ReducedType.isScalableVector()) ? MVT::nxv4i32 : MVT::v4i32;
- auto Doti32 =
- DAG.getNode(Opcode, DL, ReducedTypeHalved,
- DAG.getConstant(0, DL, ReducedTypeHalved), A, B);
+ auto Doti32 = DAG.getNode(Opcode, DL, ReducedTypeHalved,
+ DAG.getConstant(0, DL, ReducedTypeHalved), A, B);
auto Extended = DAG.getSExtOrTrunc(Doti32, DL, ReducedType);
return DAG.getNode(ISD::ADD, DL, NarrowOp.getValueType(),
- {NarrowOp, Extended});
+ {NarrowOp, Extended});
}
return DAG.getNode(Opcode, DL, ReducedType, NarrowOp, A, B);
>From 7fb289d45105bf599b2f00d983f70b2566116f2e Mon Sep 17 00:00:00 2001
From: James Chesterman <james.chesterman at arm.com>
Date: Fri, 27 Sep 2024 14:33:42 +0000
Subject: [PATCH 3/3] Changed which overload of getNode() is called
---
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 1775de31e6d9a4..fbf4b726e7ed0d 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -21954,8 +21954,8 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
auto Doti32 = DAG.getNode(Opcode, DL, ReducedTypeHalved,
DAG.getConstant(0, DL, ReducedTypeHalved), A, B);
auto Extended = DAG.getSExtOrTrunc(Doti32, DL, ReducedType);
- return DAG.getNode(ISD::ADD, DL, NarrowOp.getValueType(),
- {NarrowOp, Extended});
+ return DAG.getNode(ISD::ADD, DL, NarrowOp.getValueType(), NarrowOp,
+ Extended);
}
return DAG.getNode(Opcode, DL, ReducedType, NarrowOp, A, B);