[llvm] e86a92f - [AArch64][SelectionDAG] Add support for 8to64 partial reduction cases (#138269)
via llvm-commits
llvm-commits at lists.llvm.org
Tue May 6 08:55:02 PDT 2025
Author: Nicholas Guy
Date: 2025-05-06T16:54:59+01:00
New Revision: e86a92f947b0ddf624b19b005a23e55823219524
URL: https://github.com/llvm/llvm-project/commit/e86a92f947b0ddf624b19b005a23e55823219524
DIFF: https://github.com/llvm/llvm-project/commit/e86a92f947b0ddf624b19b005a23e55823219524.diff
LOG: [AArch64][SelectionDAG] Add support for 8to64 partial reduction cases (#138269)
---------
Co-authored-by: James Chesterman <james.chesterman at arm.com>
Added:
Modified:
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
llvm/lib/Target/AArch64/AArch64ISelLowering.h
llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
Removed:
################################################################################
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 1c889d67c81e0..16ef5b072a028 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1868,6 +1868,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
// Other pairs will default to 'Expand'.
setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv8i16, Legal);
setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv16i8, Legal);
+
+ setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv16i8, Custom);
}
// Handle operations that are only available in non-streaming SVE mode.
@@ -7740,6 +7742,9 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
return LowerFLDEXP(Op, DAG);
case ISD::EXPERIMENTAL_VECTOR_HISTOGRAM:
return LowerVECTOR_HISTOGRAM(Op, DAG);
+ case ISD::PARTIAL_REDUCE_SMLA:
+ case ISD::PARTIAL_REDUCE_UMLA:
+ return LowerPARTIAL_REDUCE_MLA(Op, DAG);
}
}
@@ -29476,6 +29481,40 @@ SDValue AArch64TargetLowering::LowerVECTOR_HISTOGRAM(SDValue Op,
return Scatter;
}
+/// Lowers PARTIAL_REDUCE_(U|S)MLA for the nxv2i64/nxv16i8 accumulator/input
+/// pairing, which no single (u|s)dot covers. The reduction is done in two
+/// steps: a dot product into a zeroed nxv4i32 (nxv16i8 -> nxv4i32), then a
+/// widening accumulate into the nxv2i64 accumulator (nxv4i32 -> nxv2i64).
+SDValue
+AArch64TargetLowering::LowerPARTIAL_REDUCE_MLA(SDValue Op,
+                                               SelectionDAG &DAG) const {
+  SDLoc DL(Op);
+  SDValue Acc = Op.getOperand(0);
+  SDValue LHS = Op.getOperand(1);
+  SDValue RHS = Op.getOperand(2);
+  EVT ResultVT = Op.getValueType();
+  assert(ResultVT == MVT::nxv2i64 && LHS.getValueType() == MVT::nxv16i8 &&
+         "Expected an nxv2i64/nxv16i8 partial reduction");
+  // nxv16i8 -> nxv4i32 reuses the same opcode; that pairing is Legal.
+  SDValue DotNode = DAG.getNode(Op.getOpcode(), DL, MVT::nxv4i32,
+                                DAG.getConstant(0, DL, MVT::nxv4i32), LHS, RHS);
+  bool IsUnsigned = Op.getOpcode() == ISD::PARTIAL_REDUCE_UMLA;
+  if (Subtarget->hasSVE2() || Subtarget->isStreamingSVEAvailable()) {
+    // Widening adds of the bottom (even) then top (odd) i32 lanes.
+    unsigned LoOpcode = IsUnsigned ? AArch64ISD::UADDWB : AArch64ISD::SADDWB;
+    unsigned HiOpcode = IsUnsigned ? AArch64ISD::UADDWT : AArch64ISD::SADDWT;
+    SDValue Lo = DAG.getNode(LoOpcode, DL, ResultVT, Acc, DotNode);
+    return DAG.getNode(HiOpcode, DL, ResultVT, Lo, DotNode);
+  }
+  // Base SVE: unpack both halves to i64, sum them, add to the accumulator.
+  unsigned LoOpcode = IsUnsigned ? AArch64ISD::UUNPKLO : AArch64ISD::SUNPKLO;
+  unsigned HiOpcode = IsUnsigned ? AArch64ISD::UUNPKHI : AArch64ISD::SUNPKHI;
+  SDValue Lo = DAG.getNode(LoOpcode, DL, ResultVT, DotNode);
+  SDValue Hi = DAG.getNode(HiOpcode, DL, ResultVT, DotNode);
+  SDValue Extended = DAG.getNode(ISD::ADD, DL, ResultVT, Lo, Hi);
+  return DAG.getNode(ISD::ADD, DL, ResultVT, Acc, Extended);
+}
+
SDValue
AArch64TargetLowering::LowerFixedLengthFPToIntToSVE(SDValue Op,
SelectionDAG &DAG) const {
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index d9b535b910b80..9d8d1c22258be 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -1181,6 +1181,7 @@ class AArch64TargetLowering : public TargetLowering {
SDValue LowerVECTOR_DEINTERLEAVE(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerVECTOR_INTERLEAVE(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerVECTOR_HISTOGRAM(SDValue Op, SelectionDAG &DAG) const;
+ SDValue LowerPARTIAL_REDUCE_MLA(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerDIV(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerMUL(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerVectorSRA_SRL_SHL(SDValue Op, SelectionDAG &DAG) const;
diff --git a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
index 67be3f58e8a24..5bc9a101b1e44 100644
--- a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
@@ -1,7 +1,9 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
; RUN: llc -mtriple=aarch64 -mattr=+sve2,+i8mm %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-I8MM
; RUN: llc -mtriple=aarch64 -mattr=+sve2 %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-NOI8MM
-; RUN: llc -mtriple=aarch64 -mattr=+sve2,+i8mm -aarch64-enable-partial-reduce-nodes %s -o - | FileCheck %s --check-prefixes=CHECK-NEWLOWERING
+; RUN: llc -mtriple=aarch64 -mattr=+sve,+i8mm -aarch64-enable-partial-reduce-nodes %s -o - | FileCheck %s --check-prefixes=CHECK-NEWLOWERING,CHECK-NEWLOWERING-SVE
+; RUN: llc -mtriple=aarch64 -mattr=+sve2,+i8mm -aarch64-enable-partial-reduce-nodes %s -o - | FileCheck %s --check-prefixes=CHECK-NEWLOWERING,CHECK-NEWLOWERING-SVE2
+; RUN: llc -mtriple=aarch64 -mattr=+sme -force-streaming -aarch64-enable-partial-reduce-nodes %s -o - | FileCheck %s --check-prefixes=CHECK-NEWLOWERING,CHECK-NEWLOWERING-SME
define <vscale x 4 x i32> @udot(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
; CHECK-LABEL: udot:
@@ -196,46 +198,31 @@ define <vscale x 4 x i64> @udot_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i8
; CHECK-NEXT: add z1.d, z1.d, z3.d
; CHECK-NEXT: ret
;
-; CHECK-NEWLOWERING-LABEL: udot_8to64:
-; CHECK-NEWLOWERING: // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT: uunpkhi z4.h, z2.b
-; CHECK-NEWLOWERING-NEXT: uunpklo z2.h, z2.b
-; CHECK-NEWLOWERING-NEXT: uunpkhi z5.h, z3.b
-; CHECK-NEWLOWERING-NEXT: uunpklo z3.h, z3.b
-; CHECK-NEWLOWERING-NEXT: ptrue p0.d
-; CHECK-NEWLOWERING-NEXT: uunpklo z6.s, z4.h
-; CHECK-NEWLOWERING-NEXT: uunpklo z7.s, z2.h
-; CHECK-NEWLOWERING-NEXT: uunpklo z24.s, z5.h
-; CHECK-NEWLOWERING-NEXT: uunpklo z25.s, z3.h
-; CHECK-NEWLOWERING-NEXT: uunpkhi z4.s, z4.h
-; CHECK-NEWLOWERING-NEXT: uunpkhi z2.s, z2.h
-; CHECK-NEWLOWERING-NEXT: uunpkhi z5.s, z5.h
-; CHECK-NEWLOWERING-NEXT: uunpkhi z3.s, z3.h
-; CHECK-NEWLOWERING-NEXT: uunpklo z26.d, z6.s
-; CHECK-NEWLOWERING-NEXT: uunpklo z27.d, z7.s
-; CHECK-NEWLOWERING-NEXT: uunpklo z28.d, z24.s
-; CHECK-NEWLOWERING-NEXT: uunpklo z29.d, z25.s
-; CHECK-NEWLOWERING-NEXT: uunpkhi z6.d, z6.s
-; CHECK-NEWLOWERING-NEXT: uunpkhi z7.d, z7.s
-; CHECK-NEWLOWERING-NEXT: uunpkhi z24.d, z24.s
-; CHECK-NEWLOWERING-NEXT: uunpkhi z25.d, z25.s
-; CHECK-NEWLOWERING-NEXT: mla z1.d, p0/m, z26.d, z28.d
-; CHECK-NEWLOWERING-NEXT: uunpklo z26.d, z4.s
-; CHECK-NEWLOWERING-NEXT: uunpklo z28.d, z5.s
-; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z27.d, z29.d
-; CHECK-NEWLOWERING-NEXT: uunpklo z27.d, z2.s
-; CHECK-NEWLOWERING-NEXT: uunpklo z29.d, z3.s
-; CHECK-NEWLOWERING-NEXT: uunpkhi z4.d, z4.s
-; CHECK-NEWLOWERING-NEXT: uunpkhi z2.d, z2.s
-; CHECK-NEWLOWERING-NEXT: uunpkhi z5.d, z5.s
-; CHECK-NEWLOWERING-NEXT: uunpkhi z3.d, z3.s
-; CHECK-NEWLOWERING-NEXT: mla z1.d, p0/m, z6.d, z24.d
-; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z7.d, z25.d
-; CHECK-NEWLOWERING-NEXT: mla z1.d, p0/m, z26.d, z28.d
-; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z27.d, z29.d
-; CHECK-NEWLOWERING-NEXT: mla z1.d, p0/m, z4.d, z5.d
-; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z2.d, z3.d
-; CHECK-NEWLOWERING-NEXT: ret
+; CHECK-NEWLOWERING-SVE-LABEL: udot_8to64:
+; CHECK-NEWLOWERING-SVE: // %bb.0: // %entry
+; CHECK-NEWLOWERING-SVE-NEXT: movi v4.2d, #0000000000000000
+; CHECK-NEWLOWERING-SVE-NEXT: udot z4.s, z2.b, z3.b
+; CHECK-NEWLOWERING-SVE-NEXT: uunpkhi z2.d, z4.s
+; CHECK-NEWLOWERING-SVE-NEXT: uunpklo z3.d, z4.s
+; CHECK-NEWLOWERING-SVE-NEXT: add z2.d, z3.d, z2.d
+; CHECK-NEWLOWERING-SVE-NEXT: add z0.d, z0.d, z2.d
+; CHECK-NEWLOWERING-SVE-NEXT: ret
+;
+; CHECK-NEWLOWERING-SVE2-LABEL: udot_8to64:
+; CHECK-NEWLOWERING-SVE2: // %bb.0: // %entry
+; CHECK-NEWLOWERING-SVE2-NEXT: movi v4.2d, #0000000000000000
+; CHECK-NEWLOWERING-SVE2-NEXT: udot z4.s, z2.b, z3.b
+; CHECK-NEWLOWERING-SVE2-NEXT: uaddwb z0.d, z0.d, z4.s
+; CHECK-NEWLOWERING-SVE2-NEXT: uaddwt z0.d, z0.d, z4.s
+; CHECK-NEWLOWERING-SVE2-NEXT: ret
+;
+; CHECK-NEWLOWERING-SME-LABEL: udot_8to64:
+; CHECK-NEWLOWERING-SME: // %bb.0: // %entry
+; CHECK-NEWLOWERING-SME-NEXT: mov z4.s, #0 // =0x0
+; CHECK-NEWLOWERING-SME-NEXT: udot z4.s, z2.b, z3.b
+; CHECK-NEWLOWERING-SME-NEXT: uaddwb z0.d, z0.d, z4.s
+; CHECK-NEWLOWERING-SME-NEXT: uaddwt z0.d, z0.d, z4.s
+; CHECK-NEWLOWERING-SME-NEXT: ret
entry:
%a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i64>
%b.wide = zext <vscale x 16 x i8> %b to <vscale x 16 x i64>
@@ -256,46 +243,31 @@ define <vscale x 4 x i64> @sdot_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i8
; CHECK-NEXT: add z1.d, z1.d, z3.d
; CHECK-NEXT: ret
;
-; CHECK-NEWLOWERING-LABEL: sdot_8to64:
-; CHECK-NEWLOWERING: // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT: sunpkhi z4.h, z2.b
-; CHECK-NEWLOWERING-NEXT: sunpklo z2.h, z2.b
-; CHECK-NEWLOWERING-NEXT: sunpkhi z5.h, z3.b
-; CHECK-NEWLOWERING-NEXT: sunpklo z3.h, z3.b
-; CHECK-NEWLOWERING-NEXT: ptrue p0.d
-; CHECK-NEWLOWERING-NEXT: sunpklo z6.s, z4.h
-; CHECK-NEWLOWERING-NEXT: sunpklo z7.s, z2.h
-; CHECK-NEWLOWERING-NEXT: sunpklo z24.s, z5.h
-; CHECK-NEWLOWERING-NEXT: sunpklo z25.s, z3.h
-; CHECK-NEWLOWERING-NEXT: sunpkhi z4.s, z4.h
-; CHECK-NEWLOWERING-NEXT: sunpkhi z2.s, z2.h
-; CHECK-NEWLOWERING-NEXT: sunpkhi z5.s, z5.h
-; CHECK-NEWLOWERING-NEXT: sunpkhi z3.s, z3.h
-; CHECK-NEWLOWERING-NEXT: sunpklo z26.d, z6.s
-; CHECK-NEWLOWERING-NEXT: sunpklo z27.d, z7.s
-; CHECK-NEWLOWERING-NEXT: sunpklo z28.d, z24.s
-; CHECK-NEWLOWERING-NEXT: sunpklo z29.d, z25.s
-; CHECK-NEWLOWERING-NEXT: sunpkhi z6.d, z6.s
-; CHECK-NEWLOWERING-NEXT: sunpkhi z7.d, z7.s
-; CHECK-NEWLOWERING-NEXT: sunpkhi z24.d, z24.s
-; CHECK-NEWLOWERING-NEXT: sunpkhi z25.d, z25.s
-; CHECK-NEWLOWERING-NEXT: mla z1.d, p0/m, z26.d, z28.d
-; CHECK-NEWLOWERING-NEXT: sunpklo z26.d, z4.s
-; CHECK-NEWLOWERING-NEXT: sunpklo z28.d, z5.s
-; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z27.d, z29.d
-; CHECK-NEWLOWERING-NEXT: sunpklo z27.d, z2.s
-; CHECK-NEWLOWERING-NEXT: sunpklo z29.d, z3.s
-; CHECK-NEWLOWERING-NEXT: sunpkhi z4.d, z4.s
-; CHECK-NEWLOWERING-NEXT: sunpkhi z2.d, z2.s
-; CHECK-NEWLOWERING-NEXT: sunpkhi z5.d, z5.s
-; CHECK-NEWLOWERING-NEXT: sunpkhi z3.d, z3.s
-; CHECK-NEWLOWERING-NEXT: mla z1.d, p0/m, z6.d, z24.d
-; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z7.d, z25.d
-; CHECK-NEWLOWERING-NEXT: mla z1.d, p0/m, z26.d, z28.d
-; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z27.d, z29.d
-; CHECK-NEWLOWERING-NEXT: mla z1.d, p0/m, z4.d, z5.d
-; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z2.d, z3.d
-; CHECK-NEWLOWERING-NEXT: ret
+; CHECK-NEWLOWERING-SVE-LABEL: sdot_8to64:
+; CHECK-NEWLOWERING-SVE: // %bb.0: // %entry
+; CHECK-NEWLOWERING-SVE-NEXT: movi v4.2d, #0000000000000000
+; CHECK-NEWLOWERING-SVE-NEXT: sdot z4.s, z2.b, z3.b
+; CHECK-NEWLOWERING-SVE-NEXT: sunpkhi z2.d, z4.s
+; CHECK-NEWLOWERING-SVE-NEXT: sunpklo z3.d, z4.s
+; CHECK-NEWLOWERING-SVE-NEXT: add z2.d, z3.d, z2.d
+; CHECK-NEWLOWERING-SVE-NEXT: add z0.d, z0.d, z2.d
+; CHECK-NEWLOWERING-SVE-NEXT: ret
+;
+; CHECK-NEWLOWERING-SVE2-LABEL: sdot_8to64:
+; CHECK-NEWLOWERING-SVE2: // %bb.0: // %entry
+; CHECK-NEWLOWERING-SVE2-NEXT: movi v4.2d, #0000000000000000
+; CHECK-NEWLOWERING-SVE2-NEXT: sdot z4.s, z2.b, z3.b
+; CHECK-NEWLOWERING-SVE2-NEXT: saddwb z0.d, z0.d, z4.s
+; CHECK-NEWLOWERING-SVE2-NEXT: saddwt z0.d, z0.d, z4.s
+; CHECK-NEWLOWERING-SVE2-NEXT: ret
+;
+; CHECK-NEWLOWERING-SME-LABEL: sdot_8to64:
+; CHECK-NEWLOWERING-SME: // %bb.0: // %entry
+; CHECK-NEWLOWERING-SME-NEXT: mov z4.s, #0 // =0x0
+; CHECK-NEWLOWERING-SME-NEXT: sdot z4.s, z2.b, z3.b
+; CHECK-NEWLOWERING-SME-NEXT: saddwb z0.d, z0.d, z4.s
+; CHECK-NEWLOWERING-SME-NEXT: saddwt z0.d, z0.d, z4.s
+; CHECK-NEWLOWERING-SME-NEXT: ret
entry:
%a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i64>
%b.wide = sext <vscale x 16 x i8> %b to <vscale x 16 x i64>
More information about the llvm-commits mailing list