[llvm] 3ea45a6 - [AArch64] Add fixed-length SVE USDOT support (#143730)
via llvm-commits
llvm-commits at lists.llvm.org
Fri Jun 13 08:18:57 PDT 2025
Author: Nicholas Guy
Date: 2025-06-13T16:18:54+01:00
New Revision: 3ea45a65edb2f033e59a12f71a8241f220791ac8
URL: https://github.com/llvm/llvm-project/commit/3ea45a65edb2f033e59a12f71a8241f220791ac8
DIFF: https://github.com/llvm/llvm-project/commit/3ea45a65edb2f033e59a12f71a8241f220791ac8.diff
LOG: [AArch64] Add fixed-length SVE USDOT support (#143730)
Added:
Modified:
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
llvm/test/CodeGen/AArch64/sve-fixed-length-partial-reduce.ll
Removed:
################################################################################
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 781a1281db402..7519ac5260a64 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -2272,6 +2272,17 @@ void AArch64TargetLowering::addTypeForFixedLengthSVE(MVT VT) {
setPartialReduceMLAAction(MLAOps, VT,
MVT::getVectorVT(MVT::i8, NumElts * 2), Custom);
}
+
+ if (Subtarget->hasMatMulInt8()) {
+ if (VT.getVectorElementType() == MVT::i32)
+ setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_SUMLA, VT,
+ MVT::getVectorVT(MVT::i8, NumElts * 4),
+ Custom);
+ else if (VT.getVectorElementType() == MVT::i64)
+ setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_SUMLA, VT,
+ MVT::getVectorVT(MVT::i8, NumElts * 8),
+ Custom);
+ }
}
// Lower fixed length vector operations to scalable equivalents.
diff --git a/llvm/test/CodeGen/AArch64/sve-fixed-length-partial-reduce.ll b/llvm/test/CodeGen/AArch64/sve-fixed-length-partial-reduce.ll
index 79d766d1b9908..af813ff16a202 100644
--- a/llvm/test/CodeGen/AArch64/sve-fixed-length-partial-reduce.ll
+++ b/llvm/test/CodeGen/AArch64/sve-fixed-length-partial-reduce.ll
@@ -1,7 +1,7 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
-; RUN: llc -mattr=+dotprod -aarch64-enable-partial-reduce-nodes=true < %s | FileCheck %s --check-prefixes=COMMON,NEON
-; RUN: llc -mattr=+sve,+dotprod -aarch64-enable-partial-reduce-nodes=true < %s | FileCheck %s --check-prefixes=COMMON,SVE
-; RUN: llc -mattr=+sme -aarch64-enable-partial-reduce-nodes=true -force-streaming < %s | FileCheck %s --check-prefix=SME
+; RUN: llc -mattr=+dotprod,+i8mm -aarch64-enable-partial-reduce-nodes=true < %s | FileCheck %s --check-prefixes=COMMON,NEON
+; RUN: llc -mattr=+sve,+dotprod,+i8mm -aarch64-enable-partial-reduce-nodes=true < %s | FileCheck %s --check-prefixes=COMMON,SVE
+; RUN: llc -mattr=+sme,+i8mm -aarch64-enable-partial-reduce-nodes=true -force-streaming < %s | FileCheck %s --check-prefix=SME
target triple = "aarch64"
@@ -407,6 +407,154 @@ define <4 x i32> @four_way_i8_i32_vl128(ptr %accptr, ptr %uptr, ptr %sptr) {
ret <4 x i32> %partial.reduce
}
+define <4 x i32> @four_way_i8_i32_vl128_usdot(ptr %accptr, ptr %uptr, ptr %sptr) {
+; COMMON-LABEL: four_way_i8_i32_vl128_usdot:
+; COMMON: // %bb.0:
+; COMMON-NEXT: ldr q0, [x0]
+; COMMON-NEXT: ldr q1, [x1]
+; COMMON-NEXT: ldr q2, [x2]
+; COMMON-NEXT: usdot v0.4s, v1.16b, v2.16b
+; COMMON-NEXT: ret
+;
+; SME-LABEL: four_way_i8_i32_vl128_usdot:
+; SME: // %bb.0:
+; SME-NEXT: ldr q0, [x0]
+; SME-NEXT: ldr q1, [x1]
+; SME-NEXT: ldr q2, [x2]
+; SME-NEXT: usdot z0.s, z1.b, z2.b
+; SME-NEXT: // kill: def $q0 killed $q0 killed $z0
+; SME-NEXT: ret
+ %acc = load <4 x i32>, ptr %accptr
+ %u = load <16 x i8>, ptr %uptr
+ %s = load <16 x i8>, ptr %sptr
+ %u.wide = zext <16 x i8> %u to <16 x i32> ; unsigned operand
+ %s.wide = sext <16 x i8> %s to <16 x i32> ; signed operand
+ %mult = mul nsw <16 x i32> %s.wide, %u.wide ; nsw only: %s.wide may be negative, so the product can wrap as unsigned and nuw would make it poison
+ %partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add(<4 x i32> %acc, <16 x i32> %mult) ; 4 products per i32 lane; checks expect a single USDOT
+ ret <4 x i32> %partial.reduce
+}
+
+define <4 x i32> @four_way_i8_i32_vl128_sudot(ptr %accptr, ptr %uptr, ptr %sptr) {
+; COMMON-LABEL: four_way_i8_i32_vl128_sudot:
+; COMMON: // %bb.0:
+; COMMON-NEXT: ldr q0, [x0]
+; COMMON-NEXT: ldr q1, [x1]
+; COMMON-NEXT: ldr q2, [x2]
+; COMMON-NEXT: usdot v0.4s, v2.16b, v1.16b
+; COMMON-NEXT: ret
+;
+; SME-LABEL: four_way_i8_i32_vl128_sudot:
+; SME: // %bb.0:
+; SME-NEXT: ldr q0, [x0]
+; SME-NEXT: ldr q1, [x1]
+; SME-NEXT: ldr q2, [x2]
+; SME-NEXT: usdot z0.s, z2.b, z1.b
+; SME-NEXT: // kill: def $q0 killed $q0 killed $z0
+; SME-NEXT: ret
+ %acc = load <4 x i32>, ptr %accptr
+ %u = load <16 x i8>, ptr %uptr
+ %s = load <16 x i8>, ptr %sptr
+ %u.wide = sext <16 x i8> %u to <16 x i32> ; NB: in this test %u is the signed operand (roles swapped vs. the usdot test)
+ %s.wide = zext <16 x i8> %s to <16 x i32> ; ...and %s is the unsigned operand
+ %mult = mul nsw <16 x i32> %s.wide, %u.wide ; nsw only: %u.wide may be negative, so the product can wrap as unsigned and nuw would make it poison
+ %partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add(<4 x i32> %acc, <16 x i32> %mult) ; signed x unsigned: checks expect USDOT with sources commuted (unsigned %s first)
+ ret <4 x i32> %partial.reduce
+}
+
+define <2 x i64> @four_way_i8_i64_vl128_usdot(ptr %accptr, ptr %uptr, ptr %sptr) {
+; NEON-LABEL: four_way_i8_i64_vl128_usdot:
+; NEON: // %bb.0:
+; NEON-NEXT: movi v0.2d, #0000000000000000
+; NEON-NEXT: ldr q1, [x1]
+; NEON-NEXT: ldr q2, [x2]
+; NEON-NEXT: usdot v0.4s, v1.16b, v2.16b
+; NEON-NEXT: ldr q1, [x0]
+; NEON-NEXT: saddw v1.2d, v1.2d, v0.2s
+; NEON-NEXT: saddw2 v0.2d, v1.2d, v0.4s
+; NEON-NEXT: ret
+;
+; SVE-LABEL: four_way_i8_i64_vl128_usdot:
+; SVE: // %bb.0:
+; SVE-NEXT: movi v0.2d, #0000000000000000
+; SVE-NEXT: ldr q1, [x1]
+; SVE-NEXT: ldr q2, [x2]
+; SVE-NEXT: usdot z0.s, z1.b, z2.b
+; SVE-NEXT: ldr q2, [x0]
+; SVE-NEXT: sunpklo z1.d, z0.s
+; SVE-NEXT: sunpkhi z0.d, z0.s
+; SVE-NEXT: add z1.d, z2.d, z1.d
+; SVE-NEXT: add z0.d, z1.d, z0.d
+; SVE-NEXT: // kill: def $q0 killed $q0 killed $z0
+; SVE-NEXT: ret
+;
+; SME-LABEL: four_way_i8_i64_vl128_usdot:
+; SME: // %bb.0:
+; SME-NEXT: mov z0.s, #0 // =0x0
+; SME-NEXT: ldr q1, [x1]
+; SME-NEXT: ldr q2, [x2]
+; SME-NEXT: usdot z0.s, z1.b, z2.b
+; SME-NEXT: ldr q1, [x0]
+; SME-NEXT: saddwb z1.d, z1.d, z0.s
+; SME-NEXT: saddwt z0.d, z1.d, z0.s
+; SME-NEXT: // kill: def $q0 killed $q0 killed $z0
+; SME-NEXT: ret
+ %acc = load <2 x i64>, ptr %accptr
+ %u = load <16 x i8>, ptr %uptr
+ %s = load <16 x i8>, ptr %sptr
+ %u.wide = zext <16 x i8> %u to <16 x i64> ; unsigned operand
+ %s.wide = sext <16 x i8> %s to <16 x i64> ; signed operand
+ %mult = mul nsw <16 x i64> %s.wide, %u.wide ; nsw only: %s.wide may be negative, so the product can wrap as unsigned and nuw would make it poison
+ %partial.reduce = tail call <2 x i64> @llvm.experimental.vector.partial.reduce.add(<2 x i64> %acc, <16 x i64> %mult) ; 8 products per i64 lane; checks expect USDOT into a zeroed i32 vector, then a sign-widening add into %acc
+ ret <2 x i64> %partial.reduce
+}
+
+define <2 x i64> @four_way_i16_i64_vl128_usdot(ptr %accptr, ptr %uptr, ptr %sptr) {
+; COMMON-LABEL: four_way_i16_i64_vl128_usdot:
+; COMMON: // %bb.0:
+; COMMON-NEXT: ldr q1, [x1]
+; COMMON-NEXT: ldr q2, [x2]
+; COMMON-NEXT: ldr q0, [x0]
+; COMMON-NEXT: ushll v3.4s, v1.4h, #0
+; COMMON-NEXT: sshll v4.4s, v2.4h, #0
+; COMMON-NEXT: ushll2 v1.4s, v1.8h, #0
+; COMMON-NEXT: sshll2 v2.4s, v2.8h, #0
+; COMMON-NEXT: smlal v0.2d, v4.2s, v3.2s
+; COMMON-NEXT: smlal2 v0.2d, v4.4s, v3.4s
+; COMMON-NEXT: smlal v0.2d, v2.2s, v1.2s
+; COMMON-NEXT: smlal2 v0.2d, v2.4s, v1.4s
+; COMMON-NEXT: ret
+;
+; SME-LABEL: four_way_i16_i64_vl128_usdot:
+; SME: // %bb.0:
+; SME-NEXT: ptrue p0.d, vl2
+; SME-NEXT: ldr q2, [x0]
+; SME-NEXT: mov x8, #2 // =0x2
+; SME-NEXT: ld1h { z0.d }, p0/z, [x1]
+; SME-NEXT: ld1sh { z1.d }, p0/z, [x2]
+; SME-NEXT: mad z0.d, p0/m, z1.d, z2.d
+; SME-NEXT: ld1h { z1.d }, p0/z, [x1, x8, lsl #1]
+; SME-NEXT: ld1sh { z2.d }, p0/z, [x2, x8, lsl #1]
+; SME-NEXT: mov x8, #4 // =0x4
+; SME-NEXT: mla z0.d, p0/m, z2.d, z1.d
+; SME-NEXT: ld1h { z1.d }, p0/z, [x1, x8, lsl #1]
+; SME-NEXT: ld1sh { z2.d }, p0/z, [x2, x8, lsl #1]
+; SME-NEXT: mov x8, #6 // =0x6
+; SME-NEXT: mla z0.d, p0/m, z2.d, z1.d
+; SME-NEXT: ld1h { z1.d }, p0/z, [x1, x8, lsl #1]
+; SME-NEXT: ld1sh { z2.d }, p0/z, [x2, x8, lsl #1]
+; SME-NEXT: mla z0.d, p0/m, z2.d, z1.d
+; SME-NEXT: // kill: def $q0 killed $q0 killed $z0
+; SME-NEXT: ret
+ %acc = load <2 x i64>, ptr %accptr
+ %u = load <8 x i16>, ptr %uptr
+ %s = load <8 x i16>, ptr %sptr
+ %u.wide = zext <8 x i16> %u to <8 x i64> ; unsigned operand
+ %s.wide = sext <8 x i16> %s to <8 x i64> ; signed operand
+ %mult = mul nsw <8 x i64> %s.wide, %u.wide ; nsw only: %s.wide may be negative, so the product can wrap as unsigned and nuw would make it poison
+ %partial.reduce = tail call <2 x i64> @llvm.experimental.vector.partial.reduce.add(<2 x i64> %acc, <8 x i64> %mult) ; i16 inputs: checks show no USDOT, just extend + multiply-accumulate (smlal / mla)
+ ret <2 x i64> %partial.reduce
+}
+
define <8 x i32> @four_way_i8_i32_vl128_double_width(ptr %accptr, ptr %uptr, ptr %sptr) {
;
; COMMON-LABEL: four_way_i8_i32_vl128_double_width:
@@ -438,6 +586,37 @@ define <8 x i32> @four_way_i8_i32_vl128_double_width(ptr %accptr, ptr %uptr, ptr
ret <8 x i32> %partial.reduce
}
+define <8 x i32> @four_way_i8_i32_vl128_double_width_usdot(ptr %accptr, ptr %uptr, ptr %sptr) {
+;
+; COMMON-LABEL: four_way_i8_i32_vl128_double_width_usdot:
+; COMMON: // %bb.0:
+; COMMON-NEXT: ldp q0, q1, [x0]
+; COMMON-NEXT: ldp q3, q2, [x1]
+; COMMON-NEXT: ldp q5, q4, [x2]
+; COMMON-NEXT: usdot v0.4s, v3.16b, v5.16b
+; COMMON-NEXT: usdot v1.4s, v2.16b, v4.16b
+; COMMON-NEXT: ret
+;
+; SME-LABEL: four_way_i8_i32_vl128_double_width_usdot:
+; SME: // %bb.0:
+; SME-NEXT: ldp q0, q1, [x0]
+; SME-NEXT: ldp q3, q2, [x1]
+; SME-NEXT: ldp q5, q4, [x2]
+; SME-NEXT: usdot z0.s, z3.b, z5.b
+; SME-NEXT: usdot z1.s, z2.b, z4.b
+; SME-NEXT: // kill: def $q0 killed $q0 killed $z0
+; SME-NEXT: // kill: def $q1 killed $q1 killed $z1
+; SME-NEXT: ret
+ %acc = load <8 x i32>, ptr %accptr
+ %u = load <32 x i8>, ptr %uptr
+ %s = load <32 x i8>, ptr %sptr
+ %u.wide = zext <32 x i8> %u to <32 x i32> ; unsigned operand
+ %s.wide = sext <32 x i8> %s to <32 x i32> ; signed operand
+ %mult = mul nsw <32 x i32> %s.wide, %u.wide ; nsw only: %s.wide may be negative, so the product can wrap as unsigned and nuw would make it poison
+ %partial.reduce = tail call <8 x i32> @llvm.experimental.vector.partial.reduce.add(<8 x i32> %acc, <32 x i32> %mult) ; 256-bit accumulator with no vscale_range: checks expect two USDOTs, one per 128-bit half
+ ret <8 x i32> %partial.reduce
+}
+
define <8 x i32> @four_way_i8_i32_vl256(ptr %accptr, ptr %uptr, ptr %sptr) vscale_range(2,2) {
;
;
@@ -483,6 +662,51 @@ define <8 x i32> @four_way_i8_i32_vl256(ptr %accptr, ptr %uptr, ptr %sptr) vscal
ret <8 x i32> %partial.reduce
}
+define <8 x i32> @four_way_i8_i32_vl256_usdot(ptr %accptr, ptr %uptr, ptr %sptr) vscale_range(2,2) {
+;
+;
+; NEON-LABEL: four_way_i8_i32_vl256_usdot:
+; NEON: // %bb.0:
+; NEON-NEXT: ldp q0, q1, [x0]
+; NEON-NEXT: ldp q3, q2, [x1]
+; NEON-NEXT: ldp q5, q4, [x2]
+; NEON-NEXT: usdot v0.4s, v3.16b, v5.16b
+; NEON-NEXT: usdot v1.4s, v2.16b, v4.16b
+; NEON-NEXT: ret
+;
+; SVE-LABEL: four_way_i8_i32_vl256_usdot:
+; SVE: // %bb.0:
+; SVE-NEXT: ldr z0, [x0]
+; SVE-NEXT: ldr z1, [x1]
+; SVE-NEXT: ldr z2, [x2]
+; SVE-NEXT: usdot z0.s, z1.b, z2.b
+; SVE-NEXT: mov z1.d, z0.d
+; SVE-NEXT: ext z1.b, z1.b, z0.b, #16
+; SVE-NEXT: // kill: def $q0 killed $q0 killed $z0
+; SVE-NEXT: // kill: def $q1 killed $q1 killed $z1
+; SVE-NEXT: ret
+;
+; SME-LABEL: four_way_i8_i32_vl256_usdot:
+; SME: // %bb.0:
+; SME-NEXT: ldr z0, [x0]
+; SME-NEXT: ldr z1, [x1]
+; SME-NEXT: ldr z2, [x2]
+; SME-NEXT: usdot z0.s, z1.b, z2.b
+; SME-NEXT: mov z1.d, z0.d
+; SME-NEXT: ext z1.b, z1.b, z0.b, #16
+; SME-NEXT: // kill: def $q0 killed $q0 killed $z0
+; SME-NEXT: // kill: def $q1 killed $q1 killed $z1
+; SME-NEXT: ret
+ %acc = load <8 x i32>, ptr %accptr
+ %u = load <32 x i8>, ptr %uptr
+ %s = load <32 x i8>, ptr %sptr
+ %u.wide = zext <32 x i8> %u to <32 x i32> ; unsigned operand
+ %s.wide = sext <32 x i8> %s to <32 x i32> ; signed operand
+ %mult = mul nsw <32 x i32> %s.wide, %u.wide ; nsw only: %s.wide may be negative, so the product can wrap as unsigned and nuw would make it poison
+ %partial.reduce = tail call <8 x i32> @llvm.experimental.vector.partial.reduce.add(<8 x i32> %acc, <32 x i32> %mult) ; vscale_range(2,2): SVE/SME checks expect one z-register USDOT, NEON checks expect two 128-bit USDOTs
+ ret <8 x i32> %partial.reduce
+}
+
;
; Four-way dot (i16 -> i64)
;
More information about the llvm-commits mailing list