[llvm] [AArch64] Add AArch64 lowering for usdot (PR #143403)
via llvm-commits
llvm-commits at lists.llvm.org
Mon Jun 9 08:41:33 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-aarch64
Author: Nicholas Guy (NickGuy-Arm)
<details>
<summary>Changes</summary>
---
Full diff: https://github.com/llvm/llvm-project/pull/143403.diff
4 Files Affected:
- (modified) llvm/include/llvm/Target/TargetSelectionDAG.td (+2)
- (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+6)
- (modified) llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td (+5)
- (modified) llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll (+53-115)
``````````diff
diff --git a/llvm/include/llvm/Target/TargetSelectionDAG.td b/llvm/include/llvm/Target/TargetSelectionDAG.td
index 5123bbe090898..9ac228110eb9c 100644
--- a/llvm/include/llvm/Target/TargetSelectionDAG.td
+++ b/llvm/include/llvm/Target/TargetSelectionDAG.td
@@ -521,6 +521,8 @@ def partial_reduce_umla : SDNode<"ISD::PARTIAL_REDUCE_UMLA",
SDTPartialReduceMLA>;
def partial_reduce_smla : SDNode<"ISD::PARTIAL_REDUCE_SMLA",
SDTPartialReduceMLA>;
+def partial_reduce_sumla : SDNode<"ISD::PARTIAL_REDUCE_SUMLA",
+ SDTPartialReduceMLA>;
def fadd : SDNode<"ISD::FADD" , SDTFPBinOp, [SDNPCommutative]>;
def fsub : SDNode<"ISD::FSUB" , SDTFPBinOp>;
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 121720e7defd4..23ad29836da2a 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1895,6 +1895,11 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setPartialReduceMLAAction(MLAOps, MVT::nxv2i64, MVT::nxv16i8, Custom);
+ if (Subtarget->hasMatMulInt8()) {
+ setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_SUMLA, MVT::nxv4i32, MVT::nxv16i8, Legal);
+ setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_SUMLA, MVT::nxv2i64, MVT::nxv16i8, Custom);
+ }
+
// Wide add types
if (Subtarget->hasSVE2() || Subtarget->hasSME()) {
setPartialReduceMLAAction(MLAOps, MVT::nxv2i64, MVT::nxv4i32, Legal);
@@ -7516,6 +7521,7 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
return LowerVECTOR_HISTOGRAM(Op, DAG);
case ISD::PARTIAL_REDUCE_SMLA:
case ISD::PARTIAL_REDUCE_UMLA:
+ case ISD::PARTIAL_REDUCE_SUMLA:
return LowerPARTIAL_REDUCE_MLA(Op, DAG);
}
}
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index 12da015ae0ddb..09b9bc7de5637 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -4116,6 +4116,11 @@ let Predicates = [HasSVEAES2, HasNonStreamingSVE2p1_or_SSVE_AES] in {
def PMULL_2ZZZ_Q : sve_crypto_pmull_multi<"pmull">;
}
+let Predicates = [HasSVE_or_SME, HasMatMulInt8] in {
+ def : Pat<(nxv4i32 (partial_reduce_sumla nxv4i32:$Acc, nxv16i8:$LHS, nxv16i8:$RHS)),
+ (USDOT_ZZZ $Acc, $RHS, $LHS)>;
+ } // End HasSVE_or_SME, HasMatMulInt8
+
//===----------------------------------------------------------------------===//
// SME or SVE2.1 instructions
//===----------------------------------------------------------------------===//
diff --git a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
index d3ccfaaf20a22..221a15e5c8fe6 100644
--- a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
@@ -3,7 +3,7 @@
; RUN: llc -mtriple=aarch64 -mattr=+sve2 %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-NOI8MM
; RUN: llc -mtriple=aarch64 -mattr=+sve,+i8mm -aarch64-enable-partial-reduce-nodes %s -o - | FileCheck %s --check-prefixes=CHECK-NEWLOWERING,CHECK-NEWLOWERING-SVE
; RUN: llc -mtriple=aarch64 -mattr=+sve2,+i8mm -aarch64-enable-partial-reduce-nodes %s -o - | FileCheck %s --check-prefixes=CHECK-NEWLOWERING,CHECK-NEWLOWERING-SVE2
-; RUN: llc -mtriple=aarch64 -mattr=+sme -force-streaming -aarch64-enable-partial-reduce-nodes %s -o - | FileCheck %s --check-prefixes=CHECK-NEWLOWERING,CHECK-NEWLOWERING-SME
+; RUN: llc -mtriple=aarch64 -mattr=+sve,+sme,+i8mm -force-streaming -aarch64-enable-partial-reduce-nodes %s -o - | FileCheck %s --check-prefixes=CHECK-NEWLOWERING,CHECK-NEWLOWERING-SME
define <vscale x 4 x i32> @udot(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
; CHECK-LABEL: udot:
@@ -106,23 +106,7 @@ define <vscale x 4 x i32> @usdot(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a,
;
; CHECK-NEWLOWERING-LABEL: usdot:
; CHECK-NEWLOWERING: // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT: uunpklo z3.h, z1.b
-; CHECK-NEWLOWERING-NEXT: sunpklo z4.h, z2.b
-; CHECK-NEWLOWERING-NEXT: ptrue p0.s
-; CHECK-NEWLOWERING-NEXT: uunpkhi z1.h, z1.b
-; CHECK-NEWLOWERING-NEXT: sunpkhi z2.h, z2.b
-; CHECK-NEWLOWERING-NEXT: uunpklo z5.s, z3.h
-; CHECK-NEWLOWERING-NEXT: sunpklo z6.s, z4.h
-; CHECK-NEWLOWERING-NEXT: uunpkhi z3.s, z3.h
-; CHECK-NEWLOWERING-NEXT: sunpkhi z4.s, z4.h
-; CHECK-NEWLOWERING-NEXT: mla z0.s, p0/m, z5.s, z6.s
-; CHECK-NEWLOWERING-NEXT: uunpklo z5.s, z1.h
-; CHECK-NEWLOWERING-NEXT: sunpklo z6.s, z2.h
-; CHECK-NEWLOWERING-NEXT: uunpkhi z1.s, z1.h
-; CHECK-NEWLOWERING-NEXT: sunpkhi z2.s, z2.h
-; CHECK-NEWLOWERING-NEXT: mla z0.s, p0/m, z3.s, z4.s
-; CHECK-NEWLOWERING-NEXT: mla z0.s, p0/m, z5.s, z6.s
-; CHECK-NEWLOWERING-NEXT: mla z0.s, p0/m, z1.s, z2.s
+; CHECK-NEWLOWERING-NEXT: usdot z0.s, z1.b, z2.b
; CHECK-NEWLOWERING-NEXT: ret
entry:
%a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
@@ -161,23 +145,7 @@ define <vscale x 4 x i32> @sudot(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a,
;
; CHECK-NEWLOWERING-LABEL: sudot:
; CHECK-NEWLOWERING: // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT: sunpklo z3.h, z1.b
-; CHECK-NEWLOWERING-NEXT: uunpklo z4.h, z2.b
-; CHECK-NEWLOWERING-NEXT: ptrue p0.s
-; CHECK-NEWLOWERING-NEXT: sunpkhi z1.h, z1.b
-; CHECK-NEWLOWERING-NEXT: uunpkhi z2.h, z2.b
-; CHECK-NEWLOWERING-NEXT: sunpklo z5.s, z3.h
-; CHECK-NEWLOWERING-NEXT: uunpklo z6.s, z4.h
-; CHECK-NEWLOWERING-NEXT: sunpkhi z3.s, z3.h
-; CHECK-NEWLOWERING-NEXT: uunpkhi z4.s, z4.h
-; CHECK-NEWLOWERING-NEXT: mla z0.s, p0/m, z5.s, z6.s
-; CHECK-NEWLOWERING-NEXT: sunpklo z5.s, z1.h
-; CHECK-NEWLOWERING-NEXT: uunpklo z6.s, z2.h
-; CHECK-NEWLOWERING-NEXT: sunpkhi z1.s, z1.h
-; CHECK-NEWLOWERING-NEXT: uunpkhi z2.s, z2.h
-; CHECK-NEWLOWERING-NEXT: mla z0.s, p0/m, z3.s, z4.s
-; CHECK-NEWLOWERING-NEXT: mla z0.s, p0/m, z5.s, z6.s
-; CHECK-NEWLOWERING-NEXT: mla z0.s, p0/m, z1.s, z2.s
+; CHECK-NEWLOWERING-NEXT: usdot z0.s, z2.b, z1.b
; CHECK-NEWLOWERING-NEXT: ret
entry:
%a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
@@ -329,46 +297,31 @@ define <vscale x 4 x i64> @usdot_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i
; CHECK-NOI8MM-NEXT: mla z0.d, p0/m, z2.d, z3.d
; CHECK-NOI8MM-NEXT: ret
;
-; CHECK-NEWLOWERING-LABEL: usdot_8to64:
-; CHECK-NEWLOWERING: // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT: uunpkhi z4.h, z2.b
-; CHECK-NEWLOWERING-NEXT: uunpklo z2.h, z2.b
-; CHECK-NEWLOWERING-NEXT: sunpkhi z5.h, z3.b
-; CHECK-NEWLOWERING-NEXT: sunpklo z3.h, z3.b
-; CHECK-NEWLOWERING-NEXT: ptrue p0.d
-; CHECK-NEWLOWERING-NEXT: uunpklo z6.s, z4.h
-; CHECK-NEWLOWERING-NEXT: uunpklo z7.s, z2.h
-; CHECK-NEWLOWERING-NEXT: sunpklo z24.s, z5.h
-; CHECK-NEWLOWERING-NEXT: sunpklo z25.s, z3.h
-; CHECK-NEWLOWERING-NEXT: uunpkhi z4.s, z4.h
-; CHECK-NEWLOWERING-NEXT: uunpkhi z2.s, z2.h
-; CHECK-NEWLOWERING-NEXT: sunpkhi z5.s, z5.h
-; CHECK-NEWLOWERING-NEXT: sunpkhi z3.s, z3.h
-; CHECK-NEWLOWERING-NEXT: uunpklo z26.d, z6.s
-; CHECK-NEWLOWERING-NEXT: uunpklo z27.d, z7.s
-; CHECK-NEWLOWERING-NEXT: sunpklo z28.d, z24.s
-; CHECK-NEWLOWERING-NEXT: sunpklo z29.d, z25.s
-; CHECK-NEWLOWERING-NEXT: uunpkhi z6.d, z6.s
-; CHECK-NEWLOWERING-NEXT: uunpkhi z7.d, z7.s
-; CHECK-NEWLOWERING-NEXT: sunpkhi z24.d, z24.s
-; CHECK-NEWLOWERING-NEXT: sunpkhi z25.d, z25.s
-; CHECK-NEWLOWERING-NEXT: mla z1.d, p0/m, z26.d, z28.d
-; CHECK-NEWLOWERING-NEXT: uunpklo z26.d, z4.s
-; CHECK-NEWLOWERING-NEXT: sunpklo z28.d, z5.s
-; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z27.d, z29.d
-; CHECK-NEWLOWERING-NEXT: uunpklo z27.d, z2.s
-; CHECK-NEWLOWERING-NEXT: sunpklo z29.d, z3.s
-; CHECK-NEWLOWERING-NEXT: uunpkhi z4.d, z4.s
-; CHECK-NEWLOWERING-NEXT: uunpkhi z2.d, z2.s
-; CHECK-NEWLOWERING-NEXT: sunpkhi z5.d, z5.s
-; CHECK-NEWLOWERING-NEXT: sunpkhi z3.d, z3.s
-; CHECK-NEWLOWERING-NEXT: mla z1.d, p0/m, z6.d, z24.d
-; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z7.d, z25.d
-; CHECK-NEWLOWERING-NEXT: mla z1.d, p0/m, z26.d, z28.d
-; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z27.d, z29.d
-; CHECK-NEWLOWERING-NEXT: mla z1.d, p0/m, z4.d, z5.d
-; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z2.d, z3.d
-; CHECK-NEWLOWERING-NEXT: ret
+; CHECK-NEWLOWERING-SVE-LABEL: usdot_8to64:
+; CHECK-NEWLOWERING-SVE: // %bb.0: // %entry
+; CHECK-NEWLOWERING-SVE-NEXT: movi v4.2d, #0000000000000000
+; CHECK-NEWLOWERING-SVE-NEXT: usdot z4.s, z2.b, z3.b
+; CHECK-NEWLOWERING-SVE-NEXT: sunpklo z2.d, z4.s
+; CHECK-NEWLOWERING-SVE-NEXT: sunpkhi z3.d, z4.s
+; CHECK-NEWLOWERING-SVE-NEXT: add z0.d, z0.d, z2.d
+; CHECK-NEWLOWERING-SVE-NEXT: add z0.d, z0.d, z3.d
+; CHECK-NEWLOWERING-SVE-NEXT: ret
+;
+; CHECK-NEWLOWERING-SVE2-LABEL: usdot_8to64:
+; CHECK-NEWLOWERING-SVE2: // %bb.0: // %entry
+; CHECK-NEWLOWERING-SVE2-NEXT: movi v4.2d, #0000000000000000
+; CHECK-NEWLOWERING-SVE2-NEXT: usdot z4.s, z2.b, z3.b
+; CHECK-NEWLOWERING-SVE2-NEXT: saddwb z0.d, z0.d, z4.s
+; CHECK-NEWLOWERING-SVE2-NEXT: saddwt z0.d, z0.d, z4.s
+; CHECK-NEWLOWERING-SVE2-NEXT: ret
+;
+; CHECK-NEWLOWERING-SME-LABEL: usdot_8to64:
+; CHECK-NEWLOWERING-SME: // %bb.0: // %entry
+; CHECK-NEWLOWERING-SME-NEXT: mov z4.s, #0 // =0x0
+; CHECK-NEWLOWERING-SME-NEXT: usdot z4.s, z2.b, z3.b
+; CHECK-NEWLOWERING-SME-NEXT: saddwb z0.d, z0.d, z4.s
+; CHECK-NEWLOWERING-SME-NEXT: saddwt z0.d, z0.d, z4.s
+; CHECK-NEWLOWERING-SME-NEXT: ret
entry:
%a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i64>
%b.wide = sext <vscale x 16 x i8> %b to <vscale x 16 x i64>
@@ -430,46 +383,31 @@ define <vscale x 4 x i64> @sudot_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i
; CHECK-NOI8MM-NEXT: mla z0.d, p0/m, z2.d, z3.d
; CHECK-NOI8MM-NEXT: ret
;
-; CHECK-NEWLOWERING-LABEL: sudot_8to64:
-; CHECK-NEWLOWERING: // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT: sunpkhi z4.h, z2.b
-; CHECK-NEWLOWERING-NEXT: sunpklo z2.h, z2.b
-; CHECK-NEWLOWERING-NEXT: uunpkhi z5.h, z3.b
-; CHECK-NEWLOWERING-NEXT: uunpklo z3.h, z3.b
-; CHECK-NEWLOWERING-NEXT: ptrue p0.d
-; CHECK-NEWLOWERING-NEXT: sunpklo z6.s, z4.h
-; CHECK-NEWLOWERING-NEXT: sunpklo z7.s, z2.h
-; CHECK-NEWLOWERING-NEXT: uunpklo z24.s, z5.h
-; CHECK-NEWLOWERING-NEXT: uunpklo z25.s, z3.h
-; CHECK-NEWLOWERING-NEXT: sunpkhi z4.s, z4.h
-; CHECK-NEWLOWERING-NEXT: sunpkhi z2.s, z2.h
-; CHECK-NEWLOWERING-NEXT: uunpkhi z5.s, z5.h
-; CHECK-NEWLOWERING-NEXT: uunpkhi z3.s, z3.h
-; CHECK-NEWLOWERING-NEXT: sunpklo z26.d, z6.s
-; CHECK-NEWLOWERING-NEXT: sunpklo z27.d, z7.s
-; CHECK-NEWLOWERING-NEXT: uunpklo z28.d, z24.s
-; CHECK-NEWLOWERING-NEXT: uunpklo z29.d, z25.s
-; CHECK-NEWLOWERING-NEXT: sunpkhi z6.d, z6.s
-; CHECK-NEWLOWERING-NEXT: sunpkhi z7.d, z7.s
-; CHECK-NEWLOWERING-NEXT: uunpkhi z24.d, z24.s
-; CHECK-NEWLOWERING-NEXT: uunpkhi z25.d, z25.s
-; CHECK-NEWLOWERING-NEXT: mla z1.d, p0/m, z26.d, z28.d
-; CHECK-NEWLOWERING-NEXT: sunpklo z26.d, z4.s
-; CHECK-NEWLOWERING-NEXT: uunpklo z28.d, z5.s
-; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z27.d, z29.d
-; CHECK-NEWLOWERING-NEXT: sunpklo z27.d, z2.s
-; CHECK-NEWLOWERING-NEXT: uunpklo z29.d, z3.s
-; CHECK-NEWLOWERING-NEXT: sunpkhi z4.d, z4.s
-; CHECK-NEWLOWERING-NEXT: sunpkhi z2.d, z2.s
-; CHECK-NEWLOWERING-NEXT: uunpkhi z5.d, z5.s
-; CHECK-NEWLOWERING-NEXT: uunpkhi z3.d, z3.s
-; CHECK-NEWLOWERING-NEXT: mla z1.d, p0/m, z6.d, z24.d
-; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z7.d, z25.d
-; CHECK-NEWLOWERING-NEXT: mla z1.d, p0/m, z26.d, z28.d
-; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z27.d, z29.d
-; CHECK-NEWLOWERING-NEXT: mla z1.d, p0/m, z4.d, z5.d
-; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z2.d, z3.d
-; CHECK-NEWLOWERING-NEXT: ret
+; CHECK-NEWLOWERING-SVE-LABEL: sudot_8to64:
+; CHECK-NEWLOWERING-SVE: // %bb.0: // %entry
+; CHECK-NEWLOWERING-SVE-NEXT: movi v4.2d, #0000000000000000
+; CHECK-NEWLOWERING-SVE-NEXT: usdot z4.s, z3.b, z2.b
+; CHECK-NEWLOWERING-SVE-NEXT: sunpklo z2.d, z4.s
+; CHECK-NEWLOWERING-SVE-NEXT: sunpkhi z3.d, z4.s
+; CHECK-NEWLOWERING-SVE-NEXT: add z0.d, z0.d, z2.d
+; CHECK-NEWLOWERING-SVE-NEXT: add z0.d, z0.d, z3.d
+; CHECK-NEWLOWERING-SVE-NEXT: ret
+;
+; CHECK-NEWLOWERING-SVE2-LABEL: sudot_8to64:
+; CHECK-NEWLOWERING-SVE2: // %bb.0: // %entry
+; CHECK-NEWLOWERING-SVE2-NEXT: movi v4.2d, #0000000000000000
+; CHECK-NEWLOWERING-SVE2-NEXT: usdot z4.s, z3.b, z2.b
+; CHECK-NEWLOWERING-SVE2-NEXT: saddwb z0.d, z0.d, z4.s
+; CHECK-NEWLOWERING-SVE2-NEXT: saddwt z0.d, z0.d, z4.s
+; CHECK-NEWLOWERING-SVE2-NEXT: ret
+;
+; CHECK-NEWLOWERING-SME-LABEL: sudot_8to64:
+; CHECK-NEWLOWERING-SME: // %bb.0: // %entry
+; CHECK-NEWLOWERING-SME-NEXT: mov z4.s, #0 // =0x0
+; CHECK-NEWLOWERING-SME-NEXT: usdot z4.s, z3.b, z2.b
+; CHECK-NEWLOWERING-SME-NEXT: saddwb z0.d, z0.d, z4.s
+; CHECK-NEWLOWERING-SME-NEXT: saddwt z0.d, z0.d, z4.s
+; CHECK-NEWLOWERING-SME-NEXT: ret
entry:
%a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i64>
%b.wide = zext <vscale x 16 x i8> %b to <vscale x 16 x i64>
``````````
</details>
https://github.com/llvm/llvm-project/pull/143403
More information about the llvm-commits mailing list