[llvm] [AArch64] Add Neon USDOT support (PR #143525)

Nicholas Guy via llvm-commits llvm-commits at lists.llvm.org
Tue Jun 10 05:54:01 PDT 2025


https://github.com/NickGuy-Arm created https://github.com/llvm/llvm-project/pull/143525

(No description was provided for this pull request.)

From 9a7aa7f7ea3b04523dc03d12e3728b90f979911d Mon Sep 17 00:00:00 2001
From: Nick Guy <nicholas.guy at arm.com>
Date: Tue, 10 Jun 2025 13:43:53 +0100
Subject: [PATCH] [AArch64] Add Neon USDOT support

---
 .../Target/AArch64/AArch64ISelLowering.cpp    |  10 ++
 llvm/lib/Target/AArch64/AArch64InstrInfo.td   |   5 +
 .../neon-partial-reduce-dot-product.ll        | 128 +++---------------
 3 files changed, 31 insertions(+), 112 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index caac00c5b2faa..766599d567efd 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1464,6 +1464,16 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
       setPartialReduceMLAAction(MLAOps, MVT::v4i32, MVT::v16i8, Legal);
       setPartialReduceMLAAction(MLAOps, MVT::v2i32, MVT::v8i8, Legal);
       setPartialReduceMLAAction(MLAOps, MVT::v2i64, MVT::v16i8, Custom);
+
+      if (Subtarget->hasMatMulInt8()) {
+        setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_SUMLA, MVT::v4i32,
+                                  MVT::v16i8, Legal);
+        setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_SUMLA, MVT::v2i64,
+                                  MVT::v16i8, Custom);
+
+        setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_SUMLA, MVT::v2i32,
+                                  MVT::v8i8, Legal);
+      }
     }
 
   } else /* !isNeonAvailable */ {
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index f5b66b75eb407..f90f12b5ac3c7 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -1711,6 +1711,11 @@ multiclass SIMDSUDOTIndex {
 
 defm SUDOTlane : SIMDSUDOTIndex;
 
+def : Pat<(v2i32 (partial_reduce_sumla v2i32:$Acc, v8i8:$LHS, v8i8:$RHS)),
+              (USDOTv8i8 $Acc, $RHS, $LHS)>;
+def : Pat<(v4i32 (partial_reduce_sumla v4i32:$Acc, v16i8:$LHS, v16i8:$RHS)),
+              (USDOTv16i8 $Acc, $RHS, $LHS)>;
+
 }
 
 // ARMv8.2-A FP16 Fused Multiply-Add Long
diff --git a/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
index d977d8fc9cf21..0c7b3c7d3c138 100644
--- a/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
@@ -181,14 +181,7 @@ define <4 x i32> @usdot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) {
 ;
 ; CHECK-NEWLOWERING-I8MM-LABEL: usdot:
 ; CHECK-NEWLOWERING-I8MM:       // %bb.0:
-; CHECK-NEWLOWERING-I8MM-NEXT:    ushll v3.8h, v1.8b, #0
-; CHECK-NEWLOWERING-I8MM-NEXT:    sshll v4.8h, v2.8b, #0
-; CHECK-NEWLOWERING-I8MM-NEXT:    ushll2 v1.8h, v1.16b, #0
-; CHECK-NEWLOWERING-I8MM-NEXT:    sshll2 v2.8h, v2.16b, #0
-; CHECK-NEWLOWERING-I8MM-NEXT:    smlal v0.4s, v4.4h, v3.4h
-; CHECK-NEWLOWERING-I8MM-NEXT:    smlal2 v0.4s, v4.8h, v3.8h
-; CHECK-NEWLOWERING-I8MM-NEXT:    smlal v0.4s, v2.4h, v1.4h
-; CHECK-NEWLOWERING-I8MM-NEXT:    smlal2 v0.4s, v2.8h, v1.8h
+; CHECK-NEWLOWERING-I8MM-NEXT:    usdot v0.4s, v1.16b, v2.16b
 ; CHECK-NEWLOWERING-I8MM-NEXT:    ret
   %u.wide = zext <16 x i8> %u to <16 x i32>
   %s.wide = sext <16 x i8> %s to <16 x i32>
@@ -247,15 +240,8 @@ define <4 x i32> @usdot_in_loop(ptr %p1, ptr %p2){
 ; CHECK-NEWLOWERING-I8MM-NEXT:    ldr q3, [x1, x8]
 ; CHECK-NEWLOWERING-I8MM-NEXT:    mov v0.16b, v1.16b
 ; CHECK-NEWLOWERING-I8MM-NEXT:    add x8, x8, #16
-; CHECK-NEWLOWERING-I8MM-NEXT:    sshll v4.8h, v2.8b, #0
-; CHECK-NEWLOWERING-I8MM-NEXT:    ushll v5.8h, v3.8b, #0
-; CHECK-NEWLOWERING-I8MM-NEXT:    sshll2 v2.8h, v2.16b, #0
-; CHECK-NEWLOWERING-I8MM-NEXT:    ushll2 v3.8h, v3.16b, #0
+; CHECK-NEWLOWERING-I8MM-NEXT:    usdot v1.4s, v3.16b, v2.16b
 ; CHECK-NEWLOWERING-I8MM-NEXT:    cmp x8, #16
-; CHECK-NEWLOWERING-I8MM-NEXT:    smlal v1.4s, v4.4h, v5.4h
-; CHECK-NEWLOWERING-I8MM-NEXT:    smlal2 v1.4s, v4.8h, v5.8h
-; CHECK-NEWLOWERING-I8MM-NEXT:    smlal v1.4s, v2.4h, v3.4h
-; CHECK-NEWLOWERING-I8MM-NEXT:    smlal2 v1.4s, v2.8h, v3.8h
 ; CHECK-NEWLOWERING-I8MM-NEXT:    b.ne .LBB6_1
 ; CHECK-NEWLOWERING-I8MM-NEXT:  // %bb.2: // %end
 ; CHECK-NEWLOWERING-I8MM-NEXT:    ret
@@ -306,19 +292,7 @@ define <2 x i32> @usdot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) #0{
 ;
 ; CHECK-NEWLOWERING-I8MM-LABEL: usdot_narrow:
 ; CHECK-NEWLOWERING-I8MM:       // %bb.0:
-; CHECK-NEWLOWERING-I8MM-NEXT:    ushll v1.8h, v1.8b, #0
-; CHECK-NEWLOWERING-I8MM-NEXT:    sshll v2.8h, v2.8b, #0
-; CHECK-NEWLOWERING-I8MM-NEXT:    // kill: def $d0 killed $d0 def $q0
-; CHECK-NEWLOWERING-I8MM-NEXT:    smull v3.4s, v2.4h, v1.4h
-; CHECK-NEWLOWERING-I8MM-NEXT:    smlal v0.4s, v2.4h, v1.4h
-; CHECK-NEWLOWERING-I8MM-NEXT:    ext v4.16b, v1.16b, v1.16b, #8
-; CHECK-NEWLOWERING-I8MM-NEXT:    ext v5.16b, v2.16b, v2.16b, #8
-; CHECK-NEWLOWERING-I8MM-NEXT:    smull2 v1.4s, v2.8h, v1.8h
-; CHECK-NEWLOWERING-I8MM-NEXT:    ext v3.16b, v3.16b, v3.16b, #8
-; CHECK-NEWLOWERING-I8MM-NEXT:    ext v1.16b, v1.16b, v1.16b, #8
-; CHECK-NEWLOWERING-I8MM-NEXT:    add v0.2s, v3.2s, v0.2s
-; CHECK-NEWLOWERING-I8MM-NEXT:    smlal v0.4s, v5.4h, v4.4h
-; CHECK-NEWLOWERING-I8MM-NEXT:    add v0.2s, v1.2s, v0.2s
+; CHECK-NEWLOWERING-I8MM-NEXT:    usdot v0.2s, v1.8b, v2.8b
 ; CHECK-NEWLOWERING-I8MM-NEXT:    ret
   %u.wide = zext <8 x i8> %u to <8 x i32>
   %s.wide = sext <8 x i8> %s to <8 x i32>
@@ -347,14 +321,7 @@ define <4 x i32> @sudot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) #0{
 ;
 ; CHECK-NEWLOWERING-I8MM-LABEL: sudot:
 ; CHECK-NEWLOWERING-I8MM:       // %bb.0:
-; CHECK-NEWLOWERING-I8MM-NEXT:    sshll v3.8h, v1.8b, #0
-; CHECK-NEWLOWERING-I8MM-NEXT:    ushll v4.8h, v2.8b, #0
-; CHECK-NEWLOWERING-I8MM-NEXT:    sshll2 v1.8h, v1.16b, #0
-; CHECK-NEWLOWERING-I8MM-NEXT:    ushll2 v2.8h, v2.16b, #0
-; CHECK-NEWLOWERING-I8MM-NEXT:    smlal v0.4s, v4.4h, v3.4h
-; CHECK-NEWLOWERING-I8MM-NEXT:    smlal2 v0.4s, v4.8h, v3.8h
-; CHECK-NEWLOWERING-I8MM-NEXT:    smlal v0.4s, v2.4h, v1.4h
-; CHECK-NEWLOWERING-I8MM-NEXT:    smlal2 v0.4s, v2.8h, v1.8h
+; CHECK-NEWLOWERING-I8MM-NEXT:    usdot v0.4s, v2.16b, v1.16b
 ; CHECK-NEWLOWERING-I8MM-NEXT:    ret
   %s.wide = sext <16 x i8> %u to <16 x i32>
   %u.wide = zext <16 x i8> %s to <16 x i32>
@@ -413,15 +380,8 @@ define <4 x i32> @sudot_in_loop(ptr %p1, ptr %p2){
 ; CHECK-NEWLOWERING-I8MM-NEXT:    ldr q3, [x1, x8]
 ; CHECK-NEWLOWERING-I8MM-NEXT:    mov v0.16b, v1.16b
 ; CHECK-NEWLOWERING-I8MM-NEXT:    add x8, x8, #16
-; CHECK-NEWLOWERING-I8MM-NEXT:    ushll v4.8h, v2.8b, #0
-; CHECK-NEWLOWERING-I8MM-NEXT:    sshll v5.8h, v3.8b, #0
-; CHECK-NEWLOWERING-I8MM-NEXT:    ushll2 v2.8h, v2.16b, #0
-; CHECK-NEWLOWERING-I8MM-NEXT:    sshll2 v3.8h, v3.16b, #0
+; CHECK-NEWLOWERING-I8MM-NEXT:    usdot v1.4s, v2.16b, v3.16b
 ; CHECK-NEWLOWERING-I8MM-NEXT:    cmp x8, #16
-; CHECK-NEWLOWERING-I8MM-NEXT:    smlal v1.4s, v4.4h, v5.4h
-; CHECK-NEWLOWERING-I8MM-NEXT:    smlal2 v1.4s, v4.8h, v5.8h
-; CHECK-NEWLOWERING-I8MM-NEXT:    smlal v1.4s, v2.4h, v3.4h
-; CHECK-NEWLOWERING-I8MM-NEXT:    smlal2 v1.4s, v2.8h, v3.8h
 ; CHECK-NEWLOWERING-I8MM-NEXT:    b.ne .LBB9_1
 ; CHECK-NEWLOWERING-I8MM-NEXT:  // %bb.2: // %end
 ; CHECK-NEWLOWERING-I8MM-NEXT:    ret
@@ -472,19 +432,7 @@ define <2 x i32> @sudot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) #0{
 ;
 ; CHECK-NEWLOWERING-I8MM-LABEL: sudot_narrow:
 ; CHECK-NEWLOWERING-I8MM:       // %bb.0:
-; CHECK-NEWLOWERING-I8MM-NEXT:    sshll v1.8h, v1.8b, #0
-; CHECK-NEWLOWERING-I8MM-NEXT:    ushll v2.8h, v2.8b, #0
-; CHECK-NEWLOWERING-I8MM-NEXT:    // kill: def $d0 killed $d0 def $q0
-; CHECK-NEWLOWERING-I8MM-NEXT:    smull v3.4s, v2.4h, v1.4h
-; CHECK-NEWLOWERING-I8MM-NEXT:    smlal v0.4s, v2.4h, v1.4h
-; CHECK-NEWLOWERING-I8MM-NEXT:    ext v4.16b, v1.16b, v1.16b, #8
-; CHECK-NEWLOWERING-I8MM-NEXT:    ext v5.16b, v2.16b, v2.16b, #8
-; CHECK-NEWLOWERING-I8MM-NEXT:    smull2 v1.4s, v2.8h, v1.8h
-; CHECK-NEWLOWERING-I8MM-NEXT:    ext v3.16b, v3.16b, v3.16b, #8
-; CHECK-NEWLOWERING-I8MM-NEXT:    ext v1.16b, v1.16b, v1.16b, #8
-; CHECK-NEWLOWERING-I8MM-NEXT:    add v0.2s, v3.2s, v0.2s
-; CHECK-NEWLOWERING-I8MM-NEXT:    smlal v0.4s, v5.4h, v4.4h
-; CHECK-NEWLOWERING-I8MM-NEXT:    add v0.2s, v1.2s, v0.2s
+; CHECK-NEWLOWERING-I8MM-NEXT:    usdot v0.2s, v2.8b, v1.8b
 ; CHECK-NEWLOWERING-I8MM-NEXT:    ret
   %u.wide = sext <8 x i8> %u to <8 x i32>
   %s.wide = zext <8 x i8> %s to <8 x i32>
@@ -614,26 +562,10 @@ define <4 x i64> @usdot_8to64(<4 x i64> %acc, <16 x i8> %a, <16 x i8> %b){
 ;
 ; CHECK-NEWLOWERING-I8MM-LABEL: usdot_8to64:
 ; CHECK-NEWLOWERING-I8MM:       // %bb.0: // %entry
-; CHECK-NEWLOWERING-I8MM-NEXT:    ushll v4.8h, v2.8b, #0
-; CHECK-NEWLOWERING-I8MM-NEXT:    ushll2 v2.8h, v2.16b, #0
-; CHECK-NEWLOWERING-I8MM-NEXT:    sshll v5.8h, v3.8b, #0
-; CHECK-NEWLOWERING-I8MM-NEXT:    sshll2 v3.8h, v3.16b, #0
-; CHECK-NEWLOWERING-I8MM-NEXT:    ushll v6.4s, v4.4h, #0
-; CHECK-NEWLOWERING-I8MM-NEXT:    ushll v7.4s, v2.4h, #0
-; CHECK-NEWLOWERING-I8MM-NEXT:    sshll v16.4s, v5.4h, #0
-; CHECK-NEWLOWERING-I8MM-NEXT:    sshll v17.4s, v3.4h, #0
-; CHECK-NEWLOWERING-I8MM-NEXT:    ushll2 v4.4s, v4.8h, #0
-; CHECK-NEWLOWERING-I8MM-NEXT:    ushll2 v2.4s, v2.8h, #0
-; CHECK-NEWLOWERING-I8MM-NEXT:    sshll2 v5.4s, v5.8h, #0
-; CHECK-NEWLOWERING-I8MM-NEXT:    sshll2 v3.4s, v3.8h, #0
-; CHECK-NEWLOWERING-I8MM-NEXT:    smlal v0.2d, v6.2s, v16.2s
-; CHECK-NEWLOWERING-I8MM-NEXT:    smlal v1.2d, v7.2s, v17.2s
-; CHECK-NEWLOWERING-I8MM-NEXT:    smlal2 v0.2d, v6.4s, v16.4s
-; CHECK-NEWLOWERING-I8MM-NEXT:    smlal2 v1.2d, v7.4s, v17.4s
-; CHECK-NEWLOWERING-I8MM-NEXT:    smlal v0.2d, v4.2s, v5.2s
-; CHECK-NEWLOWERING-I8MM-NEXT:    smlal v1.2d, v2.2s, v3.2s
-; CHECK-NEWLOWERING-I8MM-NEXT:    smlal2 v0.2d, v4.4s, v5.4s
-; CHECK-NEWLOWERING-I8MM-NEXT:    smlal2 v1.2d, v2.4s, v3.4s
+; CHECK-NEWLOWERING-I8MM-NEXT:    movi v4.2d, #0000000000000000
+; CHECK-NEWLOWERING-I8MM-NEXT:    usdot v4.4s, v2.16b, v3.16b
+; CHECK-NEWLOWERING-I8MM-NEXT:    saddw v0.2d, v0.2d, v4.2s
+; CHECK-NEWLOWERING-I8MM-NEXT:    saddw2 v0.2d, v0.2d, v4.4s
 ; CHECK-NEWLOWERING-I8MM-NEXT:    ret
 entry:
   %a.wide = zext <16 x i8> %a to <16 x i64>
@@ -679,26 +611,10 @@ define <4 x i64> @sudot_8to64(<4 x i64> %acc, <16 x i8> %a, <16 x i8> %b) {
 ;
 ; CHECK-NEWLOWERING-I8MM-LABEL: sudot_8to64:
 ; CHECK-NEWLOWERING-I8MM:       // %bb.0: // %entry
-; CHECK-NEWLOWERING-I8MM-NEXT:    sshll v4.8h, v2.8b, #0
-; CHECK-NEWLOWERING-I8MM-NEXT:    sshll2 v2.8h, v2.16b, #0
-; CHECK-NEWLOWERING-I8MM-NEXT:    ushll v5.8h, v3.8b, #0
-; CHECK-NEWLOWERING-I8MM-NEXT:    ushll2 v3.8h, v3.16b, #0
-; CHECK-NEWLOWERING-I8MM-NEXT:    sshll v6.4s, v4.4h, #0
-; CHECK-NEWLOWERING-I8MM-NEXT:    sshll v7.4s, v2.4h, #0
-; CHECK-NEWLOWERING-I8MM-NEXT:    ushll v16.4s, v5.4h, #0
-; CHECK-NEWLOWERING-I8MM-NEXT:    ushll v17.4s, v3.4h, #0
-; CHECK-NEWLOWERING-I8MM-NEXT:    sshll2 v4.4s, v4.8h, #0
-; CHECK-NEWLOWERING-I8MM-NEXT:    sshll2 v2.4s, v2.8h, #0
-; CHECK-NEWLOWERING-I8MM-NEXT:    ushll2 v5.4s, v5.8h, #0
-; CHECK-NEWLOWERING-I8MM-NEXT:    ushll2 v3.4s, v3.8h, #0
-; CHECK-NEWLOWERING-I8MM-NEXT:    smlal v0.2d, v6.2s, v16.2s
-; CHECK-NEWLOWERING-I8MM-NEXT:    smlal v1.2d, v7.2s, v17.2s
-; CHECK-NEWLOWERING-I8MM-NEXT:    smlal2 v0.2d, v6.4s, v16.4s
-; CHECK-NEWLOWERING-I8MM-NEXT:    smlal2 v1.2d, v7.4s, v17.4s
-; CHECK-NEWLOWERING-I8MM-NEXT:    smlal v0.2d, v4.2s, v5.2s
-; CHECK-NEWLOWERING-I8MM-NEXT:    smlal v1.2d, v2.2s, v3.2s
-; CHECK-NEWLOWERING-I8MM-NEXT:    smlal2 v0.2d, v4.4s, v5.4s
-; CHECK-NEWLOWERING-I8MM-NEXT:    smlal2 v1.2d, v2.4s, v3.4s
+; CHECK-NEWLOWERING-I8MM-NEXT:    movi v4.2d, #0000000000000000
+; CHECK-NEWLOWERING-I8MM-NEXT:    usdot v4.4s, v3.16b, v2.16b
+; CHECK-NEWLOWERING-I8MM-NEXT:    saddw v0.2d, v0.2d, v4.2s
+; CHECK-NEWLOWERING-I8MM-NEXT:    saddw2 v0.2d, v0.2d, v4.4s
 ; CHECK-NEWLOWERING-I8MM-NEXT:    ret
 entry:
   %a.wide = sext <16 x i8> %a to <16 x i64>
@@ -1147,21 +1063,9 @@ define <4 x i32> @usdot_multiple_zext_users(ptr %p1, ptr %p2, ptr %p3) {
 ; CHECK-NEWLOWERING-I8MM-NEXT:    ldr q3, [x1, x8]
 ; CHECK-NEWLOWERING-I8MM-NEXT:    ldr q4, [x2, x8]
 ; CHECK-NEWLOWERING-I8MM-NEXT:    add x8, x8, #16
-; CHECK-NEWLOWERING-I8MM-NEXT:    sshll v5.8h, v2.8b, #0
-; CHECK-NEWLOWERING-I8MM-NEXT:    ushll v6.8h, v4.8b, #0
-; CHECK-NEWLOWERING-I8MM-NEXT:    sshll v7.8h, v3.8b, #0
-; CHECK-NEWLOWERING-I8MM-NEXT:    sshll2 v2.8h, v2.16b, #0
-; CHECK-NEWLOWERING-I8MM-NEXT:    ushll2 v4.8h, v4.16b, #0
-; CHECK-NEWLOWERING-I8MM-NEXT:    sshll2 v3.8h, v3.16b, #0
+; CHECK-NEWLOWERING-I8MM-NEXT:    usdot v0.4s, v4.16b, v2.16b
+; CHECK-NEWLOWERING-I8MM-NEXT:    usdot v1.4s, v4.16b, v3.16b
 ; CHECK-NEWLOWERING-I8MM-NEXT:    cmp x8, #1024
-; CHECK-NEWLOWERING-I8MM-NEXT:    smlal v0.4s, v5.4h, v6.4h
-; CHECK-NEWLOWERING-I8MM-NEXT:    smlal v1.4s, v7.4h, v6.4h
-; CHECK-NEWLOWERING-I8MM-NEXT:    smlal2 v0.4s, v5.8h, v6.8h
-; CHECK-NEWLOWERING-I8MM-NEXT:    smlal2 v1.4s, v7.8h, v6.8h
-; CHECK-NEWLOWERING-I8MM-NEXT:    smlal v0.4s, v2.4h, v4.4h
-; CHECK-NEWLOWERING-I8MM-NEXT:    smlal v1.4s, v3.4h, v4.4h
-; CHECK-NEWLOWERING-I8MM-NEXT:    smlal2 v0.4s, v2.8h, v4.8h
-; CHECK-NEWLOWERING-I8MM-NEXT:    smlal2 v1.4s, v3.8h, v4.8h
 ; CHECK-NEWLOWERING-I8MM-NEXT:    b.ne .LBB28_1
 ; CHECK-NEWLOWERING-I8MM-NEXT:  // %bb.2: // %end
 ; CHECK-NEWLOWERING-I8MM-NEXT:    add v0.4s, v1.4s, v0.4s



More information about the llvm-commits mailing list