[llvm] [AArch64] Add AArch64 lowering for usdot (PR #143403)

via llvm-commits llvm-commits at lists.llvm.org
Mon Jun 9 08:41:33 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-aarch64

Author: Nicholas Guy (NickGuy-Arm)

<details>
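<summary>Changes</summary>

Adds AArch64 SVE lowering for `ISD::PARTIAL_REDUCE_SUMLA` (mixed-sign partial-reduction multiply-accumulate), so that it selects the `USDOT` instruction:

- Defines the TableGen node `partial_reduce_sumla` for `ISD::PARTIAL_REDUCE_SUMLA`.
- When the subtarget has the i8mm extension (`hasMatMulInt8()`):
  - the nxv16i8 → nxv4i32 form is marked `Legal` and matched directly to `USDOT_ZZZ`, with the two multiplicand operands swapped;
  - the nxv16i8 → nxv2i64 form is marked `Custom` and routed through `LowerPARTIAL_REDUCE_MLA`.
- Updates `sve-partial-reduce-dot-product.ll`. The `usdot`/`sudot` tests now produce a single `usdot`. The `*_8to64` tests now produce a `usdot` into a zeroed 32-bit accumulator followed by widening adds: `sunpklo`/`sunpkhi` + `add` on SVE, and `saddwb`/`saddwt` on SVE2/SME. The streaming-SME RUN line now also enables `+sve,+i8mm`.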



---
Full diff: https://github.com/llvm/llvm-project/pull/143403.diff


4 Files Affected:

- (modified) llvm/include/llvm/Target/TargetSelectionDAG.td (+2) 
- (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+6) 
- (modified) llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td (+5) 
- (modified) llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll (+53-115) 


``````````diff
diff --git a/llvm/include/llvm/Target/TargetSelectionDAG.td b/llvm/include/llvm/Target/TargetSelectionDAG.td
index 5123bbe090898..9ac228110eb9c 100644
--- a/llvm/include/llvm/Target/TargetSelectionDAG.td
+++ b/llvm/include/llvm/Target/TargetSelectionDAG.td
@@ -521,6 +521,8 @@ def partial_reduce_umla : SDNode<"ISD::PARTIAL_REDUCE_UMLA",
                                  SDTPartialReduceMLA>;
 def partial_reduce_smla : SDNode<"ISD::PARTIAL_REDUCE_SMLA",
                                  SDTPartialReduceMLA>;
+def partial_reduce_sumla : SDNode<"ISD::PARTIAL_REDUCE_SUMLA",
+                                 SDTPartialReduceMLA>;
 
 def fadd       : SDNode<"ISD::FADD"       , SDTFPBinOp, [SDNPCommutative]>;
 def fsub       : SDNode<"ISD::FSUB"       , SDTFPBinOp>;
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 121720e7defd4..23ad29836da2a 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1895,6 +1895,11 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
 
     setPartialReduceMLAAction(MLAOps, MVT::nxv2i64, MVT::nxv16i8, Custom);
 
+    if (Subtarget->hasMatMulInt8()) {
+      setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_SUMLA, MVT::nxv4i32, MVT::nxv16i8, Legal);
+      setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_SUMLA, MVT::nxv2i64, MVT::nxv16i8, Custom);
+    }
+
     // Wide add types
     if (Subtarget->hasSVE2() || Subtarget->hasSME()) {
       setPartialReduceMLAAction(MLAOps, MVT::nxv2i64, MVT::nxv4i32, Legal);
@@ -7516,6 +7521,7 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
     return LowerVECTOR_HISTOGRAM(Op, DAG);
   case ISD::PARTIAL_REDUCE_SMLA:
   case ISD::PARTIAL_REDUCE_UMLA:
+  case ISD::PARTIAL_REDUCE_SUMLA:
     return LowerPARTIAL_REDUCE_MLA(Op, DAG);
   }
 }
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index 12da015ae0ddb..09b9bc7de5637 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -4116,6 +4116,11 @@ let Predicates = [HasSVEAES2, HasNonStreamingSVE2p1_or_SSVE_AES] in {
   def PMULL_2ZZZ_Q : sve_crypto_pmull_multi<"pmull">;
 }
 
+let Predicates = [HasSVE_or_SME, HasMatMulInt8] in {
+    def : Pat<(nxv4i32 (partial_reduce_sumla nxv4i32:$Acc, nxv16i8:$LHS, nxv16i8:$RHS)),
+              (USDOT_ZZZ $Acc, $RHS, $LHS)>;
+  } // End HasSVE_or_SME, HasMatMulInt8
+
 //===----------------------------------------------------------------------===//
 // SME or SVE2.1 instructions
 //===----------------------------------------------------------------------===//
diff --git a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
index d3ccfaaf20a22..221a15e5c8fe6 100644
--- a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
@@ -3,7 +3,7 @@
 ; RUN: llc -mtriple=aarch64 -mattr=+sve2 %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-NOI8MM
 ; RUN: llc -mtriple=aarch64 -mattr=+sve,+i8mm -aarch64-enable-partial-reduce-nodes %s -o - | FileCheck %s --check-prefixes=CHECK-NEWLOWERING,CHECK-NEWLOWERING-SVE
 ; RUN: llc -mtriple=aarch64 -mattr=+sve2,+i8mm -aarch64-enable-partial-reduce-nodes %s -o - | FileCheck %s --check-prefixes=CHECK-NEWLOWERING,CHECK-NEWLOWERING-SVE2
-; RUN: llc -mtriple=aarch64 -mattr=+sme -force-streaming -aarch64-enable-partial-reduce-nodes %s -o - | FileCheck %s --check-prefixes=CHECK-NEWLOWERING,CHECK-NEWLOWERING-SME
+; RUN: llc -mtriple=aarch64 -mattr=+sve,+sme,+i8mm -force-streaming -aarch64-enable-partial-reduce-nodes %s -o - | FileCheck %s --check-prefixes=CHECK-NEWLOWERING,CHECK-NEWLOWERING-SME
 
 define <vscale x 4 x i32> @udot(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
 ; CHECK-LABEL: udot:
@@ -106,23 +106,7 @@ define <vscale x 4 x i32> @usdot(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a,
 ;
 ; CHECK-NEWLOWERING-LABEL: usdot:
 ; CHECK-NEWLOWERING:       // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT:    uunpklo z3.h, z1.b
-; CHECK-NEWLOWERING-NEXT:    sunpklo z4.h, z2.b
-; CHECK-NEWLOWERING-NEXT:    ptrue p0.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z1.h, z1.b
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z2.h, z2.b
-; CHECK-NEWLOWERING-NEXT:    uunpklo z5.s, z3.h
-; CHECK-NEWLOWERING-NEXT:    sunpklo z6.s, z4.h
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z3.s, z3.h
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z4.s, z4.h
-; CHECK-NEWLOWERING-NEXT:    mla z0.s, p0/m, z5.s, z6.s
-; CHECK-NEWLOWERING-NEXT:    uunpklo z5.s, z1.h
-; CHECK-NEWLOWERING-NEXT:    sunpklo z6.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z1.s, z1.h
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z2.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    mla z0.s, p0/m, z3.s, z4.s
-; CHECK-NEWLOWERING-NEXT:    mla z0.s, p0/m, z5.s, z6.s
-; CHECK-NEWLOWERING-NEXT:    mla z0.s, p0/m, z1.s, z2.s
+; CHECK-NEWLOWERING-NEXT:    usdot z0.s, z1.b, z2.b
 ; CHECK-NEWLOWERING-NEXT:    ret
 entry:
   %a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
@@ -161,23 +145,7 @@ define <vscale x 4 x i32> @sudot(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a,
 ;
 ; CHECK-NEWLOWERING-LABEL: sudot:
 ; CHECK-NEWLOWERING:       // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT:    sunpklo z3.h, z1.b
-; CHECK-NEWLOWERING-NEXT:    uunpklo z4.h, z2.b
-; CHECK-NEWLOWERING-NEXT:    ptrue p0.s
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z1.h, z1.b
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z2.h, z2.b
-; CHECK-NEWLOWERING-NEXT:    sunpklo z5.s, z3.h
-; CHECK-NEWLOWERING-NEXT:    uunpklo z6.s, z4.h
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z3.s, z3.h
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z4.s, z4.h
-; CHECK-NEWLOWERING-NEXT:    mla z0.s, p0/m, z5.s, z6.s
-; CHECK-NEWLOWERING-NEXT:    sunpklo z5.s, z1.h
-; CHECK-NEWLOWERING-NEXT:    uunpklo z6.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z1.s, z1.h
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z2.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    mla z0.s, p0/m, z3.s, z4.s
-; CHECK-NEWLOWERING-NEXT:    mla z0.s, p0/m, z5.s, z6.s
-; CHECK-NEWLOWERING-NEXT:    mla z0.s, p0/m, z1.s, z2.s
+; CHECK-NEWLOWERING-NEXT:    usdot z0.s, z2.b, z1.b
 ; CHECK-NEWLOWERING-NEXT:    ret
 entry:
   %a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
@@ -329,46 +297,31 @@ define <vscale x 4 x i64> @usdot_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i
 ; CHECK-NOI8MM-NEXT:    mla z0.d, p0/m, z2.d, z3.d
 ; CHECK-NOI8MM-NEXT:    ret
 ;
-; CHECK-NEWLOWERING-LABEL: usdot_8to64:
-; CHECK-NEWLOWERING:       // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z4.h, z2.b
-; CHECK-NEWLOWERING-NEXT:    uunpklo z2.h, z2.b
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z5.h, z3.b
-; CHECK-NEWLOWERING-NEXT:    sunpklo z3.h, z3.b
-; CHECK-NEWLOWERING-NEXT:    ptrue p0.d
-; CHECK-NEWLOWERING-NEXT:    uunpklo z6.s, z4.h
-; CHECK-NEWLOWERING-NEXT:    uunpklo z7.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    sunpklo z24.s, z5.h
-; CHECK-NEWLOWERING-NEXT:    sunpklo z25.s, z3.h
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z4.s, z4.h
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z2.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z5.s, z5.h
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z3.s, z3.h
-; CHECK-NEWLOWERING-NEXT:    uunpklo z26.d, z6.s
-; CHECK-NEWLOWERING-NEXT:    uunpklo z27.d, z7.s
-; CHECK-NEWLOWERING-NEXT:    sunpklo z28.d, z24.s
-; CHECK-NEWLOWERING-NEXT:    sunpklo z29.d, z25.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z6.d, z6.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z7.d, z7.s
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z24.d, z24.s
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z25.d, z25.s
-; CHECK-NEWLOWERING-NEXT:    mla z1.d, p0/m, z26.d, z28.d
-; CHECK-NEWLOWERING-NEXT:    uunpklo z26.d, z4.s
-; CHECK-NEWLOWERING-NEXT:    sunpklo z28.d, z5.s
-; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z27.d, z29.d
-; CHECK-NEWLOWERING-NEXT:    uunpklo z27.d, z2.s
-; CHECK-NEWLOWERING-NEXT:    sunpklo z29.d, z3.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z4.d, z4.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z2.d, z2.s
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z5.d, z5.s
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z3.d, z3.s
-; CHECK-NEWLOWERING-NEXT:    mla z1.d, p0/m, z6.d, z24.d
-; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z7.d, z25.d
-; CHECK-NEWLOWERING-NEXT:    mla z1.d, p0/m, z26.d, z28.d
-; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z27.d, z29.d
-; CHECK-NEWLOWERING-NEXT:    mla z1.d, p0/m, z4.d, z5.d
-; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z2.d, z3.d
-; CHECK-NEWLOWERING-NEXT:    ret
+; CHECK-NEWLOWERING-SVE-LABEL: usdot_8to64:
+; CHECK-NEWLOWERING-SVE:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-SVE-NEXT:    movi v4.2d, #0000000000000000
+; CHECK-NEWLOWERING-SVE-NEXT:    usdot z4.s, z2.b, z3.b
+; CHECK-NEWLOWERING-SVE-NEXT:    sunpklo z2.d, z4.s
+; CHECK-NEWLOWERING-SVE-NEXT:    sunpkhi z3.d, z4.s
+; CHECK-NEWLOWERING-SVE-NEXT:    add z0.d, z0.d, z2.d
+; CHECK-NEWLOWERING-SVE-NEXT:    add z0.d, z0.d, z3.d
+; CHECK-NEWLOWERING-SVE-NEXT:    ret
+;
+; CHECK-NEWLOWERING-SVE2-LABEL: usdot_8to64:
+; CHECK-NEWLOWERING-SVE2:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-SVE2-NEXT:    movi v4.2d, #0000000000000000
+; CHECK-NEWLOWERING-SVE2-NEXT:    usdot z4.s, z2.b, z3.b
+; CHECK-NEWLOWERING-SVE2-NEXT:    saddwb z0.d, z0.d, z4.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    saddwt z0.d, z0.d, z4.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    ret
+;
+; CHECK-NEWLOWERING-SME-LABEL: usdot_8to64:
+; CHECK-NEWLOWERING-SME:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-SME-NEXT:    mov z4.s, #0 // =0x0
+; CHECK-NEWLOWERING-SME-NEXT:    usdot z4.s, z2.b, z3.b
+; CHECK-NEWLOWERING-SME-NEXT:    saddwb z0.d, z0.d, z4.s
+; CHECK-NEWLOWERING-SME-NEXT:    saddwt z0.d, z0.d, z4.s
+; CHECK-NEWLOWERING-SME-NEXT:    ret
 entry:
   %a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i64>
   %b.wide = sext <vscale x 16 x i8> %b to <vscale x 16 x i64>
@@ -430,46 +383,31 @@ define <vscale x 4 x i64> @sudot_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i
 ; CHECK-NOI8MM-NEXT:    mla z0.d, p0/m, z2.d, z3.d
 ; CHECK-NOI8MM-NEXT:    ret
 ;
-; CHECK-NEWLOWERING-LABEL: sudot_8to64:
-; CHECK-NEWLOWERING:       // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z4.h, z2.b
-; CHECK-NEWLOWERING-NEXT:    sunpklo z2.h, z2.b
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z5.h, z3.b
-; CHECK-NEWLOWERING-NEXT:    uunpklo z3.h, z3.b
-; CHECK-NEWLOWERING-NEXT:    ptrue p0.d
-; CHECK-NEWLOWERING-NEXT:    sunpklo z6.s, z4.h
-; CHECK-NEWLOWERING-NEXT:    sunpklo z7.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    uunpklo z24.s, z5.h
-; CHECK-NEWLOWERING-NEXT:    uunpklo z25.s, z3.h
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z4.s, z4.h
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z2.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z5.s, z5.h
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z3.s, z3.h
-; CHECK-NEWLOWERING-NEXT:    sunpklo z26.d, z6.s
-; CHECK-NEWLOWERING-NEXT:    sunpklo z27.d, z7.s
-; CHECK-NEWLOWERING-NEXT:    uunpklo z28.d, z24.s
-; CHECK-NEWLOWERING-NEXT:    uunpklo z29.d, z25.s
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z6.d, z6.s
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z7.d, z7.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z24.d, z24.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z25.d, z25.s
-; CHECK-NEWLOWERING-NEXT:    mla z1.d, p0/m, z26.d, z28.d
-; CHECK-NEWLOWERING-NEXT:    sunpklo z26.d, z4.s
-; CHECK-NEWLOWERING-NEXT:    uunpklo z28.d, z5.s
-; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z27.d, z29.d
-; CHECK-NEWLOWERING-NEXT:    sunpklo z27.d, z2.s
-; CHECK-NEWLOWERING-NEXT:    uunpklo z29.d, z3.s
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z4.d, z4.s
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z2.d, z2.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z5.d, z5.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z3.d, z3.s
-; CHECK-NEWLOWERING-NEXT:    mla z1.d, p0/m, z6.d, z24.d
-; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z7.d, z25.d
-; CHECK-NEWLOWERING-NEXT:    mla z1.d, p0/m, z26.d, z28.d
-; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z27.d, z29.d
-; CHECK-NEWLOWERING-NEXT:    mla z1.d, p0/m, z4.d, z5.d
-; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z2.d, z3.d
-; CHECK-NEWLOWERING-NEXT:    ret
+; CHECK-NEWLOWERING-SVE-LABEL: sudot_8to64:
+; CHECK-NEWLOWERING-SVE:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-SVE-NEXT:    movi v4.2d, #0000000000000000
+; CHECK-NEWLOWERING-SVE-NEXT:    usdot z4.s, z3.b, z2.b
+; CHECK-NEWLOWERING-SVE-NEXT:    sunpklo z2.d, z4.s
+; CHECK-NEWLOWERING-SVE-NEXT:    sunpkhi z3.d, z4.s
+; CHECK-NEWLOWERING-SVE-NEXT:    add z0.d, z0.d, z2.d
+; CHECK-NEWLOWERING-SVE-NEXT:    add z0.d, z0.d, z3.d
+; CHECK-NEWLOWERING-SVE-NEXT:    ret
+;
+; CHECK-NEWLOWERING-SVE2-LABEL: sudot_8to64:
+; CHECK-NEWLOWERING-SVE2:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-SVE2-NEXT:    movi v4.2d, #0000000000000000
+; CHECK-NEWLOWERING-SVE2-NEXT:    usdot z4.s, z3.b, z2.b
+; CHECK-NEWLOWERING-SVE2-NEXT:    saddwb z0.d, z0.d, z4.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    saddwt z0.d, z0.d, z4.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    ret
+;
+; CHECK-NEWLOWERING-SME-LABEL: sudot_8to64:
+; CHECK-NEWLOWERING-SME:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-SME-NEXT:    mov z4.s, #0 // =0x0
+; CHECK-NEWLOWERING-SME-NEXT:    usdot z4.s, z3.b, z2.b
+; CHECK-NEWLOWERING-SME-NEXT:    saddwb z0.d, z0.d, z4.s
+; CHECK-NEWLOWERING-SME-NEXT:    saddwt z0.d, z0.d, z4.s
+; CHECK-NEWLOWERING-SME-NEXT:    ret
 entry:
   %a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i64>
   %b.wide = zext <vscale x 16 x i8> %b to <vscale x 16 x i64>

``````````

</details>


https://github.com/llvm/llvm-project/pull/143403


More information about the llvm-commits mailing list