[llvm] [AArch64] Add ISel support for partial reductions to use SVE2.1 udot/sdot (PR #158310)
Damian Heaton via llvm-commits
llvm-commits at lists.llvm.org
Fri Sep 12 08:05:16 PDT 2025
https://github.com/dheaton-arm created https://github.com/llvm/llvm-project/pull/158310
This allows dot products of scalable 8xi16 vectors (and fixed-length vectors that are converted into scalable vectors), accumulating into a scalable 4xi32 vector, to lower to a single instruction (`udot`/`sdot`) rather than a sequence of `umlalb`s and `umlalt`s.
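For illustration, here is a minimal C sketch (not part of the patch; the function name and types are my own) of the kind of widening dot-product loop that, when vectorized for a target with SVE2.1 enabled, may produce the `llvm.experimental.vector.partial.reduce.add` pattern that these new patterns select to a single `udot`:

/* Illustrative only: 16-bit unsigned inputs, 32-bit accumulator.  If the
 * loop vectorizer emits llvm.experimental.vector.partial.reduce.add for
 * this reduction (this depends on the target/cost model), the patterns
 * added here lower each vector step to one udot instead of an
 * umlalb/umlalt pair. */
int dot_u16(const unsigned short *a, const unsigned short *b, int n) {
  int acc = 0;
  for (int i = 0; i < n; ++i)
    acc += (int)a[i] * (int)b[i];  /* zero-extend to i32, multiply, accumulate */
  return acc;
}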
>From 35e983ac8a9c87ac0b50e504cbf2386e53ddf73c Mon Sep 17 00:00:00 2001
From: Damian Heaton <Damian.Heaton at arm.com>
Date: Fri, 12 Sep 2025 13:23:58 +0000
Subject: [PATCH] [AArch64] Add ISel support for partial reductions to use
SVE2.1 udot/sdot
This allows dot products of scalable 8xi16 vectors (and fixed-length vectors
that are converted into scalable vectors), accumulating into a scalable 4xi32
vector, to lower to a single instruction (`udot`/`sdot`) rather than a
sequence of `umlalb`s and `umlalt`s.
---
.../lib/Target/AArch64/AArch64SVEInstrInfo.td | 7 +++
.../AArch64/sve2p1-dots-partial-reduction.ll | 62 +++++++++++++++++++
2 files changed, 69 insertions(+)
create mode 100644 llvm/test/CodeGen/AArch64/sve2p1-dots-partial-reduction.ll
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index 05a3eab638eaa..88dfe4342ccae 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -4237,6 +4237,13 @@ defm UDOT_ZZZ_HtoS : sve2p1_two_way_dot_vv<"udot", 0b1, int_aarch64_sve_udot_x2
defm SDOT_ZZZI_HtoS : sve2p1_two_way_dot_vvi<"sdot", 0b0, int_aarch64_sve_sdot_lane_x2>;
defm UDOT_ZZZI_HtoS : sve2p1_two_way_dot_vvi<"udot", 0b1, int_aarch64_sve_udot_lane_x2>;
+let Predicates = [HasSVE2p1_or_SME2] in {
+ def : Pat<(nxv4i32 (partial_reduce_umla nxv4i32:$Acc, nxv8i16:$MulLHS, nxv8i16:$MulRHS)),
+ (UDOT_ZZZ_HtoS $Acc, $MulLHS, $MulRHS)>;
+ def : Pat<(nxv4i32 (partial_reduce_smla nxv4i32:$Acc, nxv8i16:$MulLHS, nxv8i16:$MulRHS)),
+ (SDOT_ZZZ_HtoS $Acc, $MulLHS, $MulRHS)>;
+} // End HasSVE2p1_or_SME2
+
defm SQCVTN_Z2Z_StoH : sve2p1_multi_vec_extract_narrow<"sqcvtn", 0b00, int_aarch64_sve_sqcvtn_x2>;
defm UQCVTN_Z2Z_StoH : sve2p1_multi_vec_extract_narrow<"uqcvtn", 0b01, int_aarch64_sve_uqcvtn_x2>;
defm SQCVTUN_Z2Z_StoH : sve2p1_multi_vec_extract_narrow<"sqcvtun", 0b10, int_aarch64_sve_sqcvtun_x2>;
diff --git a/llvm/test/CodeGen/AArch64/sve2p1-dots-partial-reduction.ll b/llvm/test/CodeGen/AArch64/sve2p1-dots-partial-reduction.ll
new file mode 100644
index 0000000000000..8abd3a86ff1f7
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve2p1-dots-partial-reduction.ll
@@ -0,0 +1,62 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sve2p1 < %s | FileCheck %s
+
+define <vscale x 4 x i32> @udot(<vscale x 4 x i32> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b) {
+; CHECK-LABEL: udot:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: udot z0.s, z1.h, z2.h
+; CHECK-NEXT: ret
+entry:
+ %a.wide = zext <vscale x 8 x i16> %a to <vscale x 8 x i32>
+ %b.wide = zext <vscale x 8 x i16> %b to <vscale x 8 x i32>
+ %mult = mul nuw nsw <vscale x 8 x i32> %a.wide, %b.wide
+ %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add(<vscale x 4 x i32> %acc, <vscale x 8 x i32> %mult)
+ ret <vscale x 4 x i32> %partial.reduce
+}
+
+define <vscale x 4 x i32> @sdot(<vscale x 4 x i32> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b) {
+; CHECK-LABEL: sdot:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: sdot z0.s, z1.h, z2.h
+; CHECK-NEXT: ret
+entry:
+ %a.wide = sext <vscale x 8 x i16> %a to <vscale x 8 x i32>
+ %b.wide = sext <vscale x 8 x i16> %b to <vscale x 8 x i32>
+ %mult = mul nuw nsw <vscale x 8 x i32> %a.wide, %b.wide
+ %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add(<vscale x 4 x i32> %acc, <vscale x 8 x i32> %mult)
+ ret <vscale x 4 x i32> %partial.reduce
+}
+
+define <4 x i32> @fixed_udot_s_h(<4 x i32> %acc, <8 x i16> %a, <8 x i16> %b) {
+; CHECK-LABEL: fixed_udot_s_h:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT: udot z0.s, z1.h, z2.h
+; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT: ret
+entry:
+ %a.wide = zext <8 x i16> %a to <8 x i32>
+ %b.wide = zext <8 x i16> %b to <8 x i32>
+ %mult = mul nuw nsw <8 x i32> %a.wide, %b.wide
+ %partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add(<4 x i32> %acc, <8 x i32> %mult)
+ ret <4 x i32> %partial.reduce
+}
+
+define <4 x i32> @fixed_sdot_s_h(<4 x i32> %acc, <8 x i16> %a, <8 x i16> %b) {
+; CHECK-LABEL: fixed_sdot_s_h:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT: sdot z0.s, z1.h, z2.h
+; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT: ret
+entry:
+ %a.wide = sext <8 x i16> %a to <8 x i32>
+ %b.wide = sext <8 x i16> %b to <8 x i32>
+ %mult = mul nuw nsw <8 x i32> %a.wide, %b.wide
+ %partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add(<4 x i32> %acc, <8 x i32> %mult)
+ ret <4 x i32> %partial.reduce
+}