[llvm] [AArch64] Add fixed-length SVE USDOT support (PR #143730)

Nicholas Guy via llvm-commits llvm-commits at lists.llvm.org
Thu Jun 12 05:51:35 PDT 2025


https://github.com/NickGuy-Arm updated https://github.com/llvm/llvm-project/pull/143730
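
For readers skimming the series: the IR shape this patch enables is a mixed-sign (zext-by-sext) i8 multiply feeding the partial-reduce intrinsic. A minimal sketch, condensed from the vl128 test added below (function name hypothetical, not part of the patch):

  define <4 x i32> @usdot_sketch(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) {
    ; unsigned-by-signed i8 products, accumulated four-to-one into i32 lanes
    %u.wide = zext <16 x i8> %u to <16 x i32>
    %s.wide = sext <16 x i8> %s to <16 x i32>
    %mult = mul nuw nsw <16 x i32> %s.wide, %u.wide
    %red = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add(<4 x i32> %acc, <16 x i32> %mult)
    ret <4 x i32> %red
  }

This pattern becomes a PARTIAL_REDUCE_SUMLA node, and the Custom action added here (guarded on +i8mm) lets the fixed-length SVE path select a single usdot instead of the widened mla sequences visible in the pre-patch check lines.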

>From 8b1ac3404aa2e61e8c9fcf3f23693415b8cfbef0 Mon Sep 17 00:00:00 2001
From: Nick Guy <nicholas.guy at arm.com>
Date: Wed, 11 Jun 2025 16:27:04 +0100
Subject: [PATCH 1/4] [AArch64] Add fixed-length SVE USDOT support

---
 .../Target/AArch64/AArch64ISelLowering.cpp    |   7 +
 .../sve-fixed-length-partial-reduce.ll        | 160 +++++++++++++++++-
 2 files changed, 165 insertions(+), 2 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 766599d567efd..03f381f9d7a93 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -2272,6 +2272,13 @@ void AArch64TargetLowering::addTypeForFixedLengthSVE(MVT VT) {
       setPartialReduceMLAAction(MLAOps, VT,
                                 MVT::getVectorVT(MVT::i8, NumElts * 2), Custom);
     }
+
+    if (Subtarget->hasMatMulInt8()) {
+      if (VT.getVectorElementType() == MVT::i32)
+        setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_SUMLA, VT, MVT::getVectorVT(MVT::i8, NumElts * 4), Custom);
+      else if (VT.getVectorElementType() == MVT::i64)
+        setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_SUMLA, VT, MVT::getVectorVT(MVT::i8, NumElts * 8), Custom);
+    }
   }
 
   // Lower fixed length vector operations to scalable equivalents.
diff --git a/llvm/test/CodeGen/AArch64/sve-fixed-length-partial-reduce.ll b/llvm/test/CodeGen/AArch64/sve-fixed-length-partial-reduce.ll
index 79d766d1b9908..81ed3e73481f8 100644
--- a/llvm/test/CodeGen/AArch64/sve-fixed-length-partial-reduce.ll
+++ b/llvm/test/CodeGen/AArch64/sve-fixed-length-partial-reduce.ll
@@ -1,6 +1,6 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
-; RUN: llc -mattr=+dotprod -aarch64-enable-partial-reduce-nodes=true < %s | FileCheck %s --check-prefixes=COMMON,NEON
-; RUN: llc -mattr=+sve,+dotprod -aarch64-enable-partial-reduce-nodes=true < %s | FileCheck %s --check-prefixes=COMMON,SVE
+; RUN: llc -mattr=+dotprod,+i8mm -aarch64-enable-partial-reduce-nodes=true < %s | FileCheck %s --check-prefixes=COMMON,NEON
+; RUN: llc -mattr=+sve,+dotprod,+i8mm -aarch64-enable-partial-reduce-nodes=true < %s | FileCheck %s --check-prefixes=COMMON,SVE
 ; RUN: llc -mattr=+sme -aarch64-enable-partial-reduce-nodes=true -force-streaming < %s | FileCheck %s --check-prefix=SME
 
 target triple = "aarch64"
@@ -407,6 +407,46 @@ define <4 x i32> @four_way_i8_i32_vl128(ptr %accptr, ptr %uptr, ptr %sptr) {
   ret <4 x i32> %partial.reduce
 }
 
+define <4 x i32> @four_way_i8_i32_vl128_usdot(ptr %accptr, ptr %uptr, ptr %sptr) {
+; COMMON-LABEL: four_way_i8_i32_vl128_usdot:
+; COMMON:       // %bb.0:
+; COMMON-NEXT:    ldr q0, [x0]
+; COMMON-NEXT:    ldr q1, [x1]
+; COMMON-NEXT:    ldr q2, [x2]
+; COMMON-NEXT:    usdot v0.4s, v1.16b, v2.16b
+; COMMON-NEXT:    ret
+;
+; SME-LABEL: four_way_i8_i32_vl128_usdot:
+; SME:       // %bb.0:
+; SME-NEXT:    ptrue p0.s, vl4
+; SME-NEXT:    ldr q2, [x0]
+; SME-NEXT:    mov w8, #4 // =0x4
+; SME-NEXT:    ld1b { z0.s }, p0/z, [x1]
+; SME-NEXT:    ld1sb { z1.s }, p0/z, [x2]
+; SME-NEXT:    mad z0.s, p0/m, z1.s, z2.s
+; SME-NEXT:    ld1b { z1.s }, p0/z, [x1, x8]
+; SME-NEXT:    ld1sb { z2.s }, p0/z, [x2, x8]
+; SME-NEXT:    mov w8, #8 // =0x8
+; SME-NEXT:    mla z0.s, p0/m, z2.s, z1.s
+; SME-NEXT:    ld1b { z1.s }, p0/z, [x1, x8]
+; SME-NEXT:    ld1sb { z2.s }, p0/z, [x2, x8]
+; SME-NEXT:    mov w8, #12 // =0xc
+; SME-NEXT:    mla z0.s, p0/m, z2.s, z1.s
+; SME-NEXT:    ld1b { z1.s }, p0/z, [x1, x8]
+; SME-NEXT:    ld1sb { z2.s }, p0/z, [x2, x8]
+; SME-NEXT:    mla z0.s, p0/m, z2.s, z1.s
+; SME-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; SME-NEXT:    ret
+  %acc = load <4 x i32>, ptr %accptr
+  %u = load <16 x i8>, ptr %uptr
+  %s = load <16 x i8>, ptr %sptr
+  %u.wide = zext <16 x i8> %u to <16 x i32>
+  %s.wide = sext <16 x i8> %s to <16 x i32>
+  %mult = mul nuw nsw <16 x i32> %s.wide, %u.wide
+  %partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add(<4 x i32> %acc, <16 x i32> %mult)
+  ret <4 x i32> %partial.reduce
+}
+
 define <8 x i32> @four_way_i8_i32_vl128_double_width(ptr %accptr, ptr %uptr, ptr %sptr) {
 ;
 ; COMMON-LABEL: four_way_i8_i32_vl128_double_width:
@@ -438,6 +478,67 @@ define <8 x i32> @four_way_i8_i32_vl128_double_width(ptr %accptr, ptr %uptr, ptr
   ret <8 x i32> %partial.reduce
 }
 
+define <8 x i32> @four_way_i8_i32_vl128_double_width_usdot(ptr %accptr, ptr %uptr, ptr %sptr) {
+;
+; COMMON-LABEL: four_way_i8_i32_vl128_double_width_usdot:
+; COMMON:       // %bb.0:
+; COMMON-NEXT:    ldp q0, q1, [x0]
+; COMMON-NEXT:    ldp q3, q2, [x1]
+; COMMON-NEXT:    ldp q5, q4, [x2]
+; COMMON-NEXT:    usdot v0.4s, v3.16b, v5.16b
+; COMMON-NEXT:    usdot v1.4s, v2.16b, v4.16b
+; COMMON-NEXT:    ret
+;
+; SME-LABEL: four_way_i8_i32_vl128_double_width_usdot:
+; SME:       // %bb.0:
+; SME-NEXT:    ptrue p0.s, vl4
+; SME-NEXT:    mov w8, #16 // =0x10
+; SME-NEXT:    mov w9, #4 // =0x4
+; SME-NEXT:    ldp q5, q4, [x0]
+; SME-NEXT:    ld1b { z0.s }, p0/z, [x1, x8]
+; SME-NEXT:    ld1b { z1.s }, p0/z, [x1]
+; SME-NEXT:    ld1sb { z2.s }, p0/z, [x2, x8]
+; SME-NEXT:    ld1sb { z3.s }, p0/z, [x2]
+; SME-NEXT:    mov w8, #20 // =0x14
+; SME-NEXT:    ld1b { z6.s }, p0/z, [x1, x8]
+; SME-NEXT:    mad z0.s, p0/m, z2.s, z4.s
+; SME-NEXT:    ld1b { z2.s }, p0/z, [x1, x9]
+; SME-NEXT:    ld1sb { z4.s }, p0/z, [x2, x9]
+; SME-NEXT:    mad z1.s, p0/m, z3.s, z5.s
+; SME-NEXT:    ld1sb { z3.s }, p0/z, [x2, x8]
+; SME-NEXT:    mov w8, #24 // =0x18
+; SME-NEXT:    mov w9, #8 // =0x8
+; SME-NEXT:    ld1b { z5.s }, p0/z, [x1, x8]
+; SME-NEXT:    mla z0.s, p0/m, z3.s, z6.s
+; SME-NEXT:    ld1sb { z3.s }, p0/z, [x2, x8]
+; SME-NEXT:    mov w8, #28 // =0x1c
+; SME-NEXT:    mla z1.s, p0/m, z4.s, z2.s
+; SME-NEXT:    ld1b { z2.s }, p0/z, [x1, x9]
+; SME-NEXT:    ld1sb { z4.s }, p0/z, [x2, x9]
+; SME-NEXT:    mov w9, #12 // =0xc
+; SME-NEXT:    ld1b { z6.s }, p0/z, [x1, x8]
+; SME-NEXT:    mla z1.s, p0/m, z4.s, z2.s
+; SME-NEXT:    movprfx z2, z0
+; SME-NEXT:    mla z2.s, p0/m, z3.s, z5.s
+; SME-NEXT:    ld1b { z0.s }, p0/z, [x1, x9]
+; SME-NEXT:    ld1sb { z3.s }, p0/z, [x2, x8]
+; SME-NEXT:    ld1sb { z4.s }, p0/z, [x2, x9]
+; SME-NEXT:    mad z0.s, p0/m, z4.s, z1.s
+; SME-NEXT:    movprfx z1, z2
+; SME-NEXT:    mla z1.s, p0/m, z3.s, z6.s
+; SME-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; SME-NEXT:    // kill: def $q1 killed $q1 killed $z1
+; SME-NEXT:    ret
+  %acc = load <8 x i32>, ptr %accptr
+  %u = load <32 x i8>, ptr %uptr
+  %s = load <32 x i8>, ptr %sptr
+  %u.wide = zext <32 x i8> %u to <32 x i32>
+  %s.wide = sext <32 x i8> %s to <32 x i32>
+  %mult = mul nuw nsw <32 x i32> %s.wide, %u.wide
+  %partial.reduce = tail call <8 x i32> @llvm.experimental.vector.partial.reduce.add(<8 x i32> %acc, <32 x i32> %mult)
+  ret <8 x i32> %partial.reduce
+}
+
 define <8 x i32> @four_way_i8_i32_vl256(ptr %accptr, ptr %uptr, ptr %sptr) vscale_range(2,2) {
 ;
 ;
@@ -483,6 +584,61 @@ define <8 x i32> @four_way_i8_i32_vl256(ptr %accptr, ptr %uptr, ptr %sptr) vscal
   ret <8 x i32> %partial.reduce
 }
 
+define <8 x i32> @four_way_i8_i32_vl256_usdot(ptr %accptr, ptr %uptr, ptr %sptr) vscale_range(2,2) {
+;
+;
+; NEON-LABEL: four_way_i8_i32_vl256_usdot:
+; NEON:       // %bb.0:
+; NEON-NEXT:    ldp q0, q1, [x0]
+; NEON-NEXT:    ldp q3, q2, [x1]
+; NEON-NEXT:    ldp q5, q4, [x2]
+; NEON-NEXT:    usdot v0.4s, v3.16b, v5.16b
+; NEON-NEXT:    usdot v1.4s, v2.16b, v4.16b
+; NEON-NEXT:    ret
+;
+; SVE-LABEL: four_way_i8_i32_vl256_usdot:
+; SVE:       // %bb.0:
+; SVE-NEXT:    ldr z0, [x0]
+; SVE-NEXT:    ldr z1, [x1]
+; SVE-NEXT:    ldr z2, [x2]
+; SVE-NEXT:    usdot z0.s, z1.b, z2.b
+; SVE-NEXT:    mov z1.d, z0.d
+; SVE-NEXT:    ext z1.b, z1.b, z0.b, #16
+; SVE-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; SVE-NEXT:    // kill: def $q1 killed $q1 killed $z1
+; SVE-NEXT:    ret
+;
+; SME-LABEL: four_way_i8_i32_vl256_usdot:
+; SME:       // %bb.0:
+; SME-NEXT:    ptrue p0.s
+; SME-NEXT:    ldr z0, [x0]
+; SME-NEXT:    ld1b { z1.s }, p0/z, [x1]
+; SME-NEXT:    ld1sb { z2.s }, p0/z, [x2]
+; SME-NEXT:    mla z0.s, p0/m, z2.s, z1.s
+; SME-NEXT:    ld1b { z1.s }, p0/z, [x1, #1, mul vl]
+; SME-NEXT:    ld1sb { z2.s }, p0/z, [x2, #1, mul vl]
+; SME-NEXT:    mla z0.s, p0/m, z2.s, z1.s
+; SME-NEXT:    ld1b { z1.s }, p0/z, [x1, #2, mul vl]
+; SME-NEXT:    ld1sb { z2.s }, p0/z, [x2, #2, mul vl]
+; SME-NEXT:    mla z0.s, p0/m, z2.s, z1.s
+; SME-NEXT:    ld1b { z1.s }, p0/z, [x1, #3, mul vl]
+; SME-NEXT:    ld1sb { z2.s }, p0/z, [x2, #3, mul vl]
+; SME-NEXT:    mla z0.s, p0/m, z2.s, z1.s
+; SME-NEXT:    mov z1.d, z0.d
+; SME-NEXT:    ext z1.b, z1.b, z0.b, #16
+; SME-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; SME-NEXT:    // kill: def $q1 killed $q1 killed $z1
+; SME-NEXT:    ret
+  %acc = load <8 x i32>, ptr %accptr
+  %u = load <32 x i8>, ptr %uptr
+  %s = load <32 x i8>, ptr %sptr
+  %u.wide = zext <32 x i8> %u to <32 x i32>
+  %s.wide = sext <32 x i8> %s to <32 x i32>
+  %mult = mul nuw nsw <32 x i32> %s.wide, %u.wide
+  %partial.reduce = tail call <8 x i32> @llvm.experimental.vector.partial.reduce.add(<8 x i32> %acc, <32 x i32> %mult)
+  ret <8 x i32> %partial.reduce
+}
+
 ;
 ; Four-way dot (i16 -> i64)
 ;

>From 5e0d1c410b4be1b46340e4d7bf490d9e7d2bcf40 Mon Sep 17 00:00:00 2001
From: Nick Guy <nicholas.guy at arm.com>
Date: Wed, 11 Jun 2025 16:37:49 +0100
Subject: [PATCH 2/4] clang-format

---
 llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 03f381f9d7a93..64ce3f986e9eb 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -2275,9 +2275,13 @@ void AArch64TargetLowering::addTypeForFixedLengthSVE(MVT VT) {
 
     if (Subtarget->hasMatMulInt8()) {
       if (VT.getVectorElementType() == MVT::i32)
-        setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_SUMLA, VT, MVT::getVectorVT(MVT::i8, NumElts * 4), Custom);
+        setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_SUMLA, VT,
+                                  MVT::getVectorVT(MVT::i8, NumElts * 4),
+                                  Custom);
       else if (VT.getVectorElementType() == MVT::i64)
-        setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_SUMLA, VT, MVT::getVectorVT(MVT::i8, NumElts * 8), Custom);
+        setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_SUMLA, VT,
+                                  MVT::getVectorVT(MVT::i8, NumElts * 8),
+                                  Custom);
     }
   }
 

>From b6ff6b7d3193f91cfbdffabc1d534b8ad80a7d1e Mon Sep 17 00:00:00 2001
From: Nick Guy <nicholas.guy at arm.com>
Date: Wed, 11 Jun 2025 16:42:02 +0100
Subject: [PATCH 3/4] Update sme test with +i8mm

---
 .../sve-fixed-length-partial-reduce.ll        | 79 +++----------------
 1 file changed, 13 insertions(+), 66 deletions(-)

diff --git a/llvm/test/CodeGen/AArch64/sve-fixed-length-partial-reduce.ll b/llvm/test/CodeGen/AArch64/sve-fixed-length-partial-reduce.ll
index 81ed3e73481f8..a688a460f0f90 100644
--- a/llvm/test/CodeGen/AArch64/sve-fixed-length-partial-reduce.ll
+++ b/llvm/test/CodeGen/AArch64/sve-fixed-length-partial-reduce.ll
@@ -1,7 +1,7 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
 ; RUN: llc -mattr=+dotprod,+i8mm -aarch64-enable-partial-reduce-nodes=true < %s | FileCheck %s --check-prefixes=COMMON,NEON
 ; RUN: llc -mattr=+sve,+dotprod,+i8mm -aarch64-enable-partial-reduce-nodes=true < %s | FileCheck %s --check-prefixes=COMMON,SVE
-; RUN: llc -mattr=+sme -aarch64-enable-partial-reduce-nodes=true -force-streaming < %s | FileCheck %s --check-prefix=SME
+; RUN: llc -mattr=+sme,+i8mm -aarch64-enable-partial-reduce-nodes=true -force-streaming < %s | FileCheck %s --check-prefix=SME
 
 target triple = "aarch64"
 
@@ -418,23 +418,10 @@ define <4 x i32> @four_way_i8_i32_vl128_usdot(ptr %accptr, ptr %uptr, ptr %sptr)
 ;
 ; SME-LABEL: four_way_i8_i32_vl128_usdot:
 ; SME:       // %bb.0:
-; SME-NEXT:    ptrue p0.s, vl4
-; SME-NEXT:    ldr q2, [x0]
-; SME-NEXT:    mov w8, #4 // =0x4
-; SME-NEXT:    ld1b { z0.s }, p0/z, [x1]
-; SME-NEXT:    ld1sb { z1.s }, p0/z, [x2]
-; SME-NEXT:    mad z0.s, p0/m, z1.s, z2.s
-; SME-NEXT:    ld1b { z1.s }, p0/z, [x1, x8]
-; SME-NEXT:    ld1sb { z2.s }, p0/z, [x2, x8]
-; SME-NEXT:    mov w8, #8 // =0x8
-; SME-NEXT:    mla z0.s, p0/m, z2.s, z1.s
-; SME-NEXT:    ld1b { z1.s }, p0/z, [x1, x8]
-; SME-NEXT:    ld1sb { z2.s }, p0/z, [x2, x8]
-; SME-NEXT:    mov w8, #12 // =0xc
-; SME-NEXT:    mla z0.s, p0/m, z2.s, z1.s
-; SME-NEXT:    ld1b { z1.s }, p0/z, [x1, x8]
-; SME-NEXT:    ld1sb { z2.s }, p0/z, [x2, x8]
-; SME-NEXT:    mla z0.s, p0/m, z2.s, z1.s
+; SME-NEXT:    ldr q0, [x0]
+; SME-NEXT:    ldr q1, [x1]
+; SME-NEXT:    ldr q2, [x2]
+; SME-NEXT:    usdot z0.s, z1.b, z2.b
 ; SME-NEXT:    // kill: def $q0 killed $q0 killed $z0
 ; SME-NEXT:    ret
   %acc = load <4 x i32>, ptr %accptr
@@ -491,41 +478,11 @@ define <8 x i32> @four_way_i8_i32_vl128_double_width_usdot(ptr %accptr, ptr %upt
 ;
 ; SME-LABEL: four_way_i8_i32_vl128_double_width_usdot:
 ; SME:       // %bb.0:
-; SME-NEXT:    ptrue p0.s, vl4
-; SME-NEXT:    mov w8, #16 // =0x10
-; SME-NEXT:    mov w9, #4 // =0x4
-; SME-NEXT:    ldp q5, q4, [x0]
-; SME-NEXT:    ld1b { z0.s }, p0/z, [x1, x8]
-; SME-NEXT:    ld1b { z1.s }, p0/z, [x1]
-; SME-NEXT:    ld1sb { z2.s }, p0/z, [x2, x8]
-; SME-NEXT:    ld1sb { z3.s }, p0/z, [x2]
-; SME-NEXT:    mov w8, #20 // =0x14
-; SME-NEXT:    ld1b { z6.s }, p0/z, [x1, x8]
-; SME-NEXT:    mad z0.s, p0/m, z2.s, z4.s
-; SME-NEXT:    ld1b { z2.s }, p0/z, [x1, x9]
-; SME-NEXT:    ld1sb { z4.s }, p0/z, [x2, x9]
-; SME-NEXT:    mad z1.s, p0/m, z3.s, z5.s
-; SME-NEXT:    ld1sb { z3.s }, p0/z, [x2, x8]
-; SME-NEXT:    mov w8, #24 // =0x18
-; SME-NEXT:    mov w9, #8 // =0x8
-; SME-NEXT:    ld1b { z5.s }, p0/z, [x1, x8]
-; SME-NEXT:    mla z0.s, p0/m, z3.s, z6.s
-; SME-NEXT:    ld1sb { z3.s }, p0/z, [x2, x8]
-; SME-NEXT:    mov w8, #28 // =0x1c
-; SME-NEXT:    mla z1.s, p0/m, z4.s, z2.s
-; SME-NEXT:    ld1b { z2.s }, p0/z, [x1, x9]
-; SME-NEXT:    ld1sb { z4.s }, p0/z, [x2, x9]
-; SME-NEXT:    mov w9, #12 // =0xc
-; SME-NEXT:    ld1b { z6.s }, p0/z, [x1, x8]
-; SME-NEXT:    mla z1.s, p0/m, z4.s, z2.s
-; SME-NEXT:    movprfx z2, z0
-; SME-NEXT:    mla z2.s, p0/m, z3.s, z5.s
-; SME-NEXT:    ld1b { z0.s }, p0/z, [x1, x9]
-; SME-NEXT:    ld1sb { z3.s }, p0/z, [x2, x8]
-; SME-NEXT:    ld1sb { z4.s }, p0/z, [x2, x9]
-; SME-NEXT:    mad z0.s, p0/m, z4.s, z1.s
-; SME-NEXT:    movprfx z1, z2
-; SME-NEXT:    mla z1.s, p0/m, z3.s, z6.s
+; SME-NEXT:    ldp q0, q1, [x0]
+; SME-NEXT:    ldp q3, q2, [x1]
+; SME-NEXT:    ldp q5, q4, [x2]
+; SME-NEXT:    usdot z0.s, z3.b, z5.b
+; SME-NEXT:    usdot z1.s, z2.b, z4.b
 ; SME-NEXT:    // kill: def $q0 killed $q0 killed $z0
 ; SME-NEXT:    // kill: def $q1 killed $q1 killed $z1
 ; SME-NEXT:    ret
@@ -610,20 +567,10 @@ define <8 x i32> @four_way_i8_i32_vl256_usdot(ptr %accptr, ptr %uptr, ptr %sptr)
 ;
 ; SME-LABEL: four_way_i8_i32_vl256_usdot:
 ; SME:       // %bb.0:
-; SME-NEXT:    ptrue p0.s
 ; SME-NEXT:    ldr z0, [x0]
-; SME-NEXT:    ld1b { z1.s }, p0/z, [x1]
-; SME-NEXT:    ld1sb { z2.s }, p0/z, [x2]
-; SME-NEXT:    mla z0.s, p0/m, z2.s, z1.s
-; SME-NEXT:    ld1b { z1.s }, p0/z, [x1, #1, mul vl]
-; SME-NEXT:    ld1sb { z2.s }, p0/z, [x2, #1, mul vl]
-; SME-NEXT:    mla z0.s, p0/m, z2.s, z1.s
-; SME-NEXT:    ld1b { z1.s }, p0/z, [x1, #2, mul vl]
-; SME-NEXT:    ld1sb { z2.s }, p0/z, [x2, #2, mul vl]
-; SME-NEXT:    mla z0.s, p0/m, z2.s, z1.s
-; SME-NEXT:    ld1b { z1.s }, p0/z, [x1, #3, mul vl]
-; SME-NEXT:    ld1sb { z2.s }, p0/z, [x2, #3, mul vl]
-; SME-NEXT:    mla z0.s, p0/m, z2.s, z1.s
+; SME-NEXT:    ldr z1, [x1]
+; SME-NEXT:    ldr z2, [x2]
+; SME-NEXT:    usdot z0.s, z1.b, z2.b
 ; SME-NEXT:    mov z1.d, z0.d
 ; SME-NEXT:    ext z1.b, z1.b, z0.b, #16
 ; SME-NEXT:    // kill: def $q0 killed $q0 killed $z0

>From f0d2eb16921d32266ec8f9cfb0649ca64e42689e Mon Sep 17 00:00:00 2001
From: Nick Guy <nicholas.guy at arm.com>
Date: Thu, 12 Jun 2025 13:49:12 +0100
Subject: [PATCH 4/4] Add additional test cases

---
 .../sve-fixed-length-partial-reduce.ll        | 121 ++++++++++++++++++
 1 file changed, 121 insertions(+)

diff --git a/llvm/test/CodeGen/AArch64/sve-fixed-length-partial-reduce.ll b/llvm/test/CodeGen/AArch64/sve-fixed-length-partial-reduce.ll
index a688a460f0f90..af813ff16a202 100644
--- a/llvm/test/CodeGen/AArch64/sve-fixed-length-partial-reduce.ll
+++ b/llvm/test/CodeGen/AArch64/sve-fixed-length-partial-reduce.ll
@@ -434,6 +434,127 @@ define <4 x i32> @four_way_i8_i32_vl128_usdot(ptr %accptr, ptr %uptr, ptr %sptr)
   ret <4 x i32> %partial.reduce
 }
 
+define <4 x i32> @four_way_i8_i32_vl128_sudot(ptr %accptr, ptr %uptr, ptr %sptr) {
+; COMMON-LABEL: four_way_i8_i32_vl128_sudot:
+; COMMON:       // %bb.0:
+; COMMON-NEXT:    ldr q0, [x0]
+; COMMON-NEXT:    ldr q1, [x1]
+; COMMON-NEXT:    ldr q2, [x2]
+; COMMON-NEXT:    usdot v0.4s, v2.16b, v1.16b
+; COMMON-NEXT:    ret
+;
+; SME-LABEL: four_way_i8_i32_vl128_sudot:
+; SME:       // %bb.0:
+; SME-NEXT:    ldr q0, [x0]
+; SME-NEXT:    ldr q1, [x1]
+; SME-NEXT:    ldr q2, [x2]
+; SME-NEXT:    usdot z0.s, z2.b, z1.b
+; SME-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; SME-NEXT:    ret
+  %acc = load <4 x i32>, ptr %accptr
+  %u = load <16 x i8>, ptr %uptr
+  %s = load <16 x i8>, ptr %sptr
+  %u.wide = sext <16 x i8> %u to <16 x i32>
+  %s.wide = zext <16 x i8> %s to <16 x i32>
+  %mult = mul nuw nsw <16 x i32> %s.wide, %u.wide
+  %partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add(<4 x i32> %acc, <16 x i32> %mult)
+  ret <4 x i32> %partial.reduce
+}
+
+define <2 x i64> @four_way_i8_i64_vl128_usdot(ptr %accptr, ptr %uptr, ptr %sptr) {
+; NEON-LABEL: four_way_i8_i64_vl128_usdot:
+; NEON:       // %bb.0:
+; NEON-NEXT:    movi v0.2d, #0000000000000000
+; NEON-NEXT:    ldr q1, [x1]
+; NEON-NEXT:    ldr q2, [x2]
+; NEON-NEXT:    usdot v0.4s, v1.16b, v2.16b
+; NEON-NEXT:    ldr q1, [x0]
+; NEON-NEXT:    saddw v1.2d, v1.2d, v0.2s
+; NEON-NEXT:    saddw2 v0.2d, v1.2d, v0.4s
+; NEON-NEXT:    ret
+;
+; SVE-LABEL: four_way_i8_i64_vl128_usdot:
+; SVE:       // %bb.0:
+; SVE-NEXT:    movi v0.2d, #0000000000000000
+; SVE-NEXT:    ldr q1, [x1]
+; SVE-NEXT:    ldr q2, [x2]
+; SVE-NEXT:    usdot z0.s, z1.b, z2.b
+; SVE-NEXT:    ldr q2, [x0]
+; SVE-NEXT:    sunpklo z1.d, z0.s
+; SVE-NEXT:    sunpkhi z0.d, z0.s
+; SVE-NEXT:    add z1.d, z2.d, z1.d
+; SVE-NEXT:    add z0.d, z1.d, z0.d
+; SVE-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; SVE-NEXT:    ret
+;
+; SME-LABEL: four_way_i8_i64_vl128_usdot:
+; SME:       // %bb.0:
+; SME-NEXT:    mov z0.s, #0 // =0x0
+; SME-NEXT:    ldr q1, [x1]
+; SME-NEXT:    ldr q2, [x2]
+; SME-NEXT:    usdot z0.s, z1.b, z2.b
+; SME-NEXT:    ldr q1, [x0]
+; SME-NEXT:    saddwb z1.d, z1.d, z0.s
+; SME-NEXT:    saddwt z0.d, z1.d, z0.s
+; SME-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; SME-NEXT:    ret
+  %acc = load <2 x i64>, ptr %accptr
+  %u = load <16 x i8>, ptr %uptr
+  %s = load <16 x i8>, ptr %sptr
+  %u.wide = zext <16 x i8> %u to <16 x i64>
+  %s.wide = sext <16 x i8> %s to <16 x i64>
+  %mult = mul nuw nsw <16 x i64> %s.wide, %u.wide
+  %partial.reduce = tail call <2 x i64> @llvm.experimental.vector.partial.reduce.add(<2 x i64> %acc, <16 x i64> %mult)
+  ret <2 x i64> %partial.reduce
+}
+
+define <2 x i64> @four_way_i16_i64_vl128_usdot(ptr %accptr, ptr %uptr, ptr %sptr) {
+; COMMON-LABEL: four_way_i16_i64_vl128_usdot:
+; COMMON:       // %bb.0:
+; COMMON-NEXT:    ldr q1, [x1]
+; COMMON-NEXT:    ldr q2, [x2]
+; COMMON-NEXT:    ldr q0, [x0]
+; COMMON-NEXT:    ushll v3.4s, v1.4h, #0
+; COMMON-NEXT:    sshll v4.4s, v2.4h, #0
+; COMMON-NEXT:    ushll2 v1.4s, v1.8h, #0
+; COMMON-NEXT:    sshll2 v2.4s, v2.8h, #0
+; COMMON-NEXT:    smlal v0.2d, v4.2s, v3.2s
+; COMMON-NEXT:    smlal2 v0.2d, v4.4s, v3.4s
+; COMMON-NEXT:    smlal v0.2d, v2.2s, v1.2s
+; COMMON-NEXT:    smlal2 v0.2d, v2.4s, v1.4s
+; COMMON-NEXT:    ret
+;
+; SME-LABEL: four_way_i16_i64_vl128_usdot:
+; SME:       // %bb.0:
+; SME-NEXT:    ptrue p0.d, vl2
+; SME-NEXT:    ldr q2, [x0]
+; SME-NEXT:    mov x8, #2 // =0x2
+; SME-NEXT:    ld1h { z0.d }, p0/z, [x1]
+; SME-NEXT:    ld1sh { z1.d }, p0/z, [x2]
+; SME-NEXT:    mad z0.d, p0/m, z1.d, z2.d
+; SME-NEXT:    ld1h { z1.d }, p0/z, [x1, x8, lsl #1]
+; SME-NEXT:    ld1sh { z2.d }, p0/z, [x2, x8, lsl #1]
+; SME-NEXT:    mov x8, #4 // =0x4
+; SME-NEXT:    mla z0.d, p0/m, z2.d, z1.d
+; SME-NEXT:    ld1h { z1.d }, p0/z, [x1, x8, lsl #1]
+; SME-NEXT:    ld1sh { z2.d }, p0/z, [x2, x8, lsl #1]
+; SME-NEXT:    mov x8, #6 // =0x6
+; SME-NEXT:    mla z0.d, p0/m, z2.d, z1.d
+; SME-NEXT:    ld1h { z1.d }, p0/z, [x1, x8, lsl #1]
+; SME-NEXT:    ld1sh { z2.d }, p0/z, [x2, x8, lsl #1]
+; SME-NEXT:    mla z0.d, p0/m, z2.d, z1.d
+; SME-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; SME-NEXT:    ret
+  %acc = load <2 x i64>, ptr %accptr
+  %u = load <8 x i16>, ptr %uptr
+  %s = load <8 x i16>, ptr %sptr
+  %u.wide = zext <8 x i16> %u to <8 x i64>
+  %s.wide = sext <8 x i16> %s to <8 x i64>
+  %mult = mul nuw nsw <8 x i64> %s.wide, %u.wide
+  %partial.reduce = tail call <2 x i64> @llvm.experimental.vector.partial.reduce.add(<2 x i64> %acc, <8 x i64> %mult)
+  ret <2 x i64> %partial.reduce
+}
+
 define <8 x i32> @four_way_i8_i32_vl128_double_width(ptr %accptr, ptr %uptr, ptr %sptr) {
 ;
 ; COMMON-LABEL: four_way_i8_i32_vl128_double_width:


