[llvm] dbec6e4 - [AArch64] Add AArch64 SVE lowering for usdot (#143403)

Nicholas Guy via llvm-commits llvm-commits at lists.llvm.org
Tue Jun 10 05:11:09 PDT 2025


Author: Nicholas Guy
Date: 2025-06-10T13:11:05+01:00
New Revision: dbec6e424a6b356959ac800c2b7a8d4ab74e9862

URL: https://github.com/llvm/llvm-project/commit/dbec6e424a6b356959ac800c2b7a8d4ab74e9862
DIFF: https://github.com/llvm/llvm-project/commit/dbec6e424a6b356959ac800c2b7a8d4ab74e9862.diff

LOG: [AArch64] Add AArch64 SVE lowering for usdot (#143403)

Added: 
    

Modified: 
    llvm/include/llvm/Target/TargetSelectionDAG.td
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
    llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
    llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Target/TargetSelectionDAG.td b/llvm/include/llvm/Target/TargetSelectionDAG.td
index 5123bbe090898..9ac228110eb9c 100644
--- a/llvm/include/llvm/Target/TargetSelectionDAG.td
+++ b/llvm/include/llvm/Target/TargetSelectionDAG.td
@@ -521,6 +521,8 @@ def partial_reduce_umla : SDNode<"ISD::PARTIAL_REDUCE_UMLA",
                                  SDTPartialReduceMLA>;
 def partial_reduce_smla : SDNode<"ISD::PARTIAL_REDUCE_SMLA",
                                  SDTPartialReduceMLA>;
+def partial_reduce_sumla : SDNode<"ISD::PARTIAL_REDUCE_SUMLA",
+                                 SDTPartialReduceMLA>;
 
 def fadd       : SDNode<"ISD::FADD"       , SDTFPBinOp, [SDNPCommutative]>;
 def fsub       : SDNode<"ISD::FSUB"       , SDTFPBinOp>;

diff  --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index d30cfa257721f..caac00c5b2faa 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1895,6 +1895,13 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
 
     setPartialReduceMLAAction(MLAOps, MVT::nxv2i64, MVT::nxv16i8, Custom);
 
+    if (Subtarget->hasMatMulInt8()) {
+      setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_SUMLA, MVT::nxv4i32,
+                                MVT::nxv16i8, Legal);
+      setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_SUMLA, MVT::nxv2i64,
+                                MVT::nxv16i8, Custom);
+    }
+
     // Wide add types
     if (Subtarget->hasSVE2() || Subtarget->hasSME()) {
       setPartialReduceMLAAction(MLAOps, MVT::nxv2i64, MVT::nxv4i32, Legal);
@@ -7516,6 +7523,7 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
     return LowerVECTOR_HISTOGRAM(Op, DAG);
   case ISD::PARTIAL_REDUCE_SMLA:
   case ISD::PARTIAL_REDUCE_UMLA:
+  case ISD::PARTIAL_REDUCE_SUMLA:
     return LowerPARTIAL_REDUCE_MLA(Op, DAG);
   }
 }

diff  --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index bda1bcf2a730b..2360e30de63b0 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -4116,6 +4116,11 @@ let Predicates = [HasSVEAES2, HasNonStreamingSVE2_or_SSVE_AES] in {
   def PMULL_2ZZZ_Q : sve_crypto_pmull_multi<"pmull">;
 }
 
+let Predicates = [HasSVE_or_SME, HasMatMulInt8] in {
+    def : Pat<(nxv4i32 (partial_reduce_sumla nxv4i32:$Acc, nxv16i8:$LHS, nxv16i8:$RHS)),
+              (USDOT_ZZZ $Acc, $RHS, $LHS)>;
+  } // End HasSVE_or_SME, HasMatMulInt8
+
 //===----------------------------------------------------------------------===//
 // SME or SVE2.1 instructions
 //===----------------------------------------------------------------------===//

diff  --git a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
index d3ccfaaf20a22..221a15e5c8fe6 100644
--- a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
@@ -3,7 +3,7 @@
 ; RUN: llc -mtriple=aarch64 -mattr=+sve2 %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-NOI8MM
 ; RUN: llc -mtriple=aarch64 -mattr=+sve,+i8mm -aarch64-enable-partial-reduce-nodes %s -o - | FileCheck %s --check-prefixes=CHECK-NEWLOWERING,CHECK-NEWLOWERING-SVE
 ; RUN: llc -mtriple=aarch64 -mattr=+sve2,+i8mm -aarch64-enable-partial-reduce-nodes %s -o - | FileCheck %s --check-prefixes=CHECK-NEWLOWERING,CHECK-NEWLOWERING-SVE2
-; RUN: llc -mtriple=aarch64 -mattr=+sme -force-streaming -aarch64-enable-partial-reduce-nodes %s -o - | FileCheck %s --check-prefixes=CHECK-NEWLOWERING,CHECK-NEWLOWERING-SME
+; RUN: llc -mtriple=aarch64 -mattr=+sve,+sme,+i8mm -force-streaming -aarch64-enable-partial-reduce-nodes %s -o - | FileCheck %s --check-prefixes=CHECK-NEWLOWERING,CHECK-NEWLOWERING-SME
 
 define <vscale x 4 x i32> @udot(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
 ; CHECK-LABEL: udot:
@@ -106,23 +106,7 @@ define <vscale x 4 x i32> @usdot(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a,
 ;
 ; CHECK-NEWLOWERING-LABEL: usdot:
 ; CHECK-NEWLOWERING:       // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT:    uunpklo z3.h, z1.b
-; CHECK-NEWLOWERING-NEXT:    sunpklo z4.h, z2.b
-; CHECK-NEWLOWERING-NEXT:    ptrue p0.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z1.h, z1.b
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z2.h, z2.b
-; CHECK-NEWLOWERING-NEXT:    uunpklo z5.s, z3.h
-; CHECK-NEWLOWERING-NEXT:    sunpklo z6.s, z4.h
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z3.s, z3.h
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z4.s, z4.h
-; CHECK-NEWLOWERING-NEXT:    mla z0.s, p0/m, z5.s, z6.s
-; CHECK-NEWLOWERING-NEXT:    uunpklo z5.s, z1.h
-; CHECK-NEWLOWERING-NEXT:    sunpklo z6.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z1.s, z1.h
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z2.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    mla z0.s, p0/m, z3.s, z4.s
-; CHECK-NEWLOWERING-NEXT:    mla z0.s, p0/m, z5.s, z6.s
-; CHECK-NEWLOWERING-NEXT:    mla z0.s, p0/m, z1.s, z2.s
+; CHECK-NEWLOWERING-NEXT:    usdot z0.s, z1.b, z2.b
 ; CHECK-NEWLOWERING-NEXT:    ret
 entry:
   %a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
@@ -161,23 +145,7 @@ define <vscale x 4 x i32> @sudot(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a,
 ;
 ; CHECK-NEWLOWERING-LABEL: sudot:
 ; CHECK-NEWLOWERING:       // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT:    sunpklo z3.h, z1.b
-; CHECK-NEWLOWERING-NEXT:    uunpklo z4.h, z2.b
-; CHECK-NEWLOWERING-NEXT:    ptrue p0.s
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z1.h, z1.b
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z2.h, z2.b
-; CHECK-NEWLOWERING-NEXT:    sunpklo z5.s, z3.h
-; CHECK-NEWLOWERING-NEXT:    uunpklo z6.s, z4.h
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z3.s, z3.h
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z4.s, z4.h
-; CHECK-NEWLOWERING-NEXT:    mla z0.s, p0/m, z5.s, z6.s
-; CHECK-NEWLOWERING-NEXT:    sunpklo z5.s, z1.h
-; CHECK-NEWLOWERING-NEXT:    uunpklo z6.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z1.s, z1.h
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z2.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    mla z0.s, p0/m, z3.s, z4.s
-; CHECK-NEWLOWERING-NEXT:    mla z0.s, p0/m, z5.s, z6.s
-; CHECK-NEWLOWERING-NEXT:    mla z0.s, p0/m, z1.s, z2.s
+; CHECK-NEWLOWERING-NEXT:    usdot z0.s, z2.b, z1.b
 ; CHECK-NEWLOWERING-NEXT:    ret
 entry:
   %a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
@@ -329,46 +297,31 @@ define <vscale x 4 x i64> @usdot_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i
 ; CHECK-NOI8MM-NEXT:    mla z0.d, p0/m, z2.d, z3.d
 ; CHECK-NOI8MM-NEXT:    ret
 ;
-; CHECK-NEWLOWERING-LABEL: usdot_8to64:
-; CHECK-NEWLOWERING:       // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z4.h, z2.b
-; CHECK-NEWLOWERING-NEXT:    uunpklo z2.h, z2.b
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z5.h, z3.b
-; CHECK-NEWLOWERING-NEXT:    sunpklo z3.h, z3.b
-; CHECK-NEWLOWERING-NEXT:    ptrue p0.d
-; CHECK-NEWLOWERING-NEXT:    uunpklo z6.s, z4.h
-; CHECK-NEWLOWERING-NEXT:    uunpklo z7.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    sunpklo z24.s, z5.h
-; CHECK-NEWLOWERING-NEXT:    sunpklo z25.s, z3.h
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z4.s, z4.h
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z2.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z5.s, z5.h
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z3.s, z3.h
-; CHECK-NEWLOWERING-NEXT:    uunpklo z26.d, z6.s
-; CHECK-NEWLOWERING-NEXT:    uunpklo z27.d, z7.s
-; CHECK-NEWLOWERING-NEXT:    sunpklo z28.d, z24.s
-; CHECK-NEWLOWERING-NEXT:    sunpklo z29.d, z25.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z6.d, z6.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z7.d, z7.s
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z24.d, z24.s
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z25.d, z25.s
-; CHECK-NEWLOWERING-NEXT:    mla z1.d, p0/m, z26.d, z28.d
-; CHECK-NEWLOWERING-NEXT:    uunpklo z26.d, z4.s
-; CHECK-NEWLOWERING-NEXT:    sunpklo z28.d, z5.s
-; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z27.d, z29.d
-; CHECK-NEWLOWERING-NEXT:    uunpklo z27.d, z2.s
-; CHECK-NEWLOWERING-NEXT:    sunpklo z29.d, z3.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z4.d, z4.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z2.d, z2.s
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z5.d, z5.s
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z3.d, z3.s
-; CHECK-NEWLOWERING-NEXT:    mla z1.d, p0/m, z6.d, z24.d
-; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z7.d, z25.d
-; CHECK-NEWLOWERING-NEXT:    mla z1.d, p0/m, z26.d, z28.d
-; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z27.d, z29.d
-; CHECK-NEWLOWERING-NEXT:    mla z1.d, p0/m, z4.d, z5.d
-; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z2.d, z3.d
-; CHECK-NEWLOWERING-NEXT:    ret
+; CHECK-NEWLOWERING-SVE-LABEL: usdot_8to64:
+; CHECK-NEWLOWERING-SVE:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-SVE-NEXT:    movi v4.2d, #0000000000000000
+; CHECK-NEWLOWERING-SVE-NEXT:    usdot z4.s, z2.b, z3.b
+; CHECK-NEWLOWERING-SVE-NEXT:    sunpklo z2.d, z4.s
+; CHECK-NEWLOWERING-SVE-NEXT:    sunpkhi z3.d, z4.s
+; CHECK-NEWLOWERING-SVE-NEXT:    add z0.d, z0.d, z2.d
+; CHECK-NEWLOWERING-SVE-NEXT:    add z0.d, z0.d, z3.d
+; CHECK-NEWLOWERING-SVE-NEXT:    ret
+;
+; CHECK-NEWLOWERING-SVE2-LABEL: usdot_8to64:
+; CHECK-NEWLOWERING-SVE2:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-SVE2-NEXT:    movi v4.2d, #0000000000000000
+; CHECK-NEWLOWERING-SVE2-NEXT:    usdot z4.s, z2.b, z3.b
+; CHECK-NEWLOWERING-SVE2-NEXT:    saddwb z0.d, z0.d, z4.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    saddwt z0.d, z0.d, z4.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    ret
+;
+; CHECK-NEWLOWERING-SME-LABEL: usdot_8to64:
+; CHECK-NEWLOWERING-SME:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-SME-NEXT:    mov z4.s, #0 // =0x0
+; CHECK-NEWLOWERING-SME-NEXT:    usdot z4.s, z2.b, z3.b
+; CHECK-NEWLOWERING-SME-NEXT:    saddwb z0.d, z0.d, z4.s
+; CHECK-NEWLOWERING-SME-NEXT:    saddwt z0.d, z0.d, z4.s
+; CHECK-NEWLOWERING-SME-NEXT:    ret
 entry:
   %a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i64>
   %b.wide = sext <vscale x 16 x i8> %b to <vscale x 16 x i64>
@@ -430,46 +383,31 @@ define <vscale x 4 x i64> @sudot_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i
 ; CHECK-NOI8MM-NEXT:    mla z0.d, p0/m, z2.d, z3.d
 ; CHECK-NOI8MM-NEXT:    ret
 ;
-; CHECK-NEWLOWERING-LABEL: sudot_8to64:
-; CHECK-NEWLOWERING:       // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z4.h, z2.b
-; CHECK-NEWLOWERING-NEXT:    sunpklo z2.h, z2.b
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z5.h, z3.b
-; CHECK-NEWLOWERING-NEXT:    uunpklo z3.h, z3.b
-; CHECK-NEWLOWERING-NEXT:    ptrue p0.d
-; CHECK-NEWLOWERING-NEXT:    sunpklo z6.s, z4.h
-; CHECK-NEWLOWERING-NEXT:    sunpklo z7.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    uunpklo z24.s, z5.h
-; CHECK-NEWLOWERING-NEXT:    uunpklo z25.s, z3.h
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z4.s, z4.h
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z2.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z5.s, z5.h
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z3.s, z3.h
-; CHECK-NEWLOWERING-NEXT:    sunpklo z26.d, z6.s
-; CHECK-NEWLOWERING-NEXT:    sunpklo z27.d, z7.s
-; CHECK-NEWLOWERING-NEXT:    uunpklo z28.d, z24.s
-; CHECK-NEWLOWERING-NEXT:    uunpklo z29.d, z25.s
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z6.d, z6.s
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z7.d, z7.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z24.d, z24.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z25.d, z25.s
-; CHECK-NEWLOWERING-NEXT:    mla z1.d, p0/m, z26.d, z28.d
-; CHECK-NEWLOWERING-NEXT:    sunpklo z26.d, z4.s
-; CHECK-NEWLOWERING-NEXT:    uunpklo z28.d, z5.s
-; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z27.d, z29.d
-; CHECK-NEWLOWERING-NEXT:    sunpklo z27.d, z2.s
-; CHECK-NEWLOWERING-NEXT:    uunpklo z29.d, z3.s
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z4.d, z4.s
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z2.d, z2.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z5.d, z5.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z3.d, z3.s
-; CHECK-NEWLOWERING-NEXT:    mla z1.d, p0/m, z6.d, z24.d
-; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z7.d, z25.d
-; CHECK-NEWLOWERING-NEXT:    mla z1.d, p0/m, z26.d, z28.d
-; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z27.d, z29.d
-; CHECK-NEWLOWERING-NEXT:    mla z1.d, p0/m, z4.d, z5.d
-; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z2.d, z3.d
-; CHECK-NEWLOWERING-NEXT:    ret
+; CHECK-NEWLOWERING-SVE-LABEL: sudot_8to64:
+; CHECK-NEWLOWERING-SVE:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-SVE-NEXT:    movi v4.2d, #0000000000000000
+; CHECK-NEWLOWERING-SVE-NEXT:    usdot z4.s, z3.b, z2.b
+; CHECK-NEWLOWERING-SVE-NEXT:    sunpklo z2.d, z4.s
+; CHECK-NEWLOWERING-SVE-NEXT:    sunpkhi z3.d, z4.s
+; CHECK-NEWLOWERING-SVE-NEXT:    add z0.d, z0.d, z2.d
+; CHECK-NEWLOWERING-SVE-NEXT:    add z0.d, z0.d, z3.d
+; CHECK-NEWLOWERING-SVE-NEXT:    ret
+;
+; CHECK-NEWLOWERING-SVE2-LABEL: sudot_8to64:
+; CHECK-NEWLOWERING-SVE2:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-SVE2-NEXT:    movi v4.2d, #0000000000000000
+; CHECK-NEWLOWERING-SVE2-NEXT:    usdot z4.s, z3.b, z2.b
+; CHECK-NEWLOWERING-SVE2-NEXT:    saddwb z0.d, z0.d, z4.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    saddwt z0.d, z0.d, z4.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    ret
+;
+; CHECK-NEWLOWERING-SME-LABEL: sudot_8to64:
+; CHECK-NEWLOWERING-SME:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-SME-NEXT:    mov z4.s, #0 // =0x0
+; CHECK-NEWLOWERING-SME-NEXT:    usdot z4.s, z3.b, z2.b
+; CHECK-NEWLOWERING-SME-NEXT:    saddwb z0.d, z0.d, z4.s
+; CHECK-NEWLOWERING-SME-NEXT:    saddwt z0.d, z0.d, z4.s
+; CHECK-NEWLOWERING-SME-NEXT:    ret
 entry:
   %a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i64>
   %b.wide = zext <vscale x 16 x i8> %b to <vscale x 16 x i64>


        


More information about the llvm-commits mailing list