[llvm] e160b2a - [AArch64] Improve codegen for partial.reduce.add v16i8 -> v2i32 (#161833)
via llvm-commits
llvm-commits at lists.llvm.org
Thu Oct 9 09:08:14 PDT 2025
Author: Sander de Smalen
Date: 2025-10-09T17:08:10+01:00
New Revision: e160b2a03c44f254d80287d74026ddacd2868089
URL: https://github.com/llvm/llvm-project/commit/e160b2a03c44f254d80287d74026ddacd2868089
DIFF: https://github.com/llvm/llvm-project/commit/e160b2a03c44f254d80287d74026ddacd2868089.diff
LOG: [AArch64] Improve codegen for partial.reduce.add v16i8 -> v2i32 (#161833)
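Rather than expanding this partial reduction, we can handle it natively by
widening the v2i32 accumulator into a zero-padded v4i32 vector, performing
the v16i8 -> v4i32 partial reduction, and then folding the wide result back
down to v2i32 with a pairwise add (ADDP).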
Added:
Modified:
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
Removed:
################################################################################
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index dc8e7c84f5e2c..31b3d1807933b 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1458,6 +1458,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setPartialReduceMLAAction(MLAOps, MVT::v4i32, MVT::v16i8, Legal);
setPartialReduceMLAAction(MLAOps, MVT::v2i32, MVT::v8i8, Legal);
+ setPartialReduceMLAAction(MLAOps, MVT::v2i32, MVT::v16i8, Custom);
setPartialReduceMLAAction(MLAOps, MVT::v2i64, MVT::v16i8, Custom);
if (Subtarget->hasMatMulInt8()) {
@@ -30769,6 +30770,17 @@ AArch64TargetLowering::LowerPARTIAL_REDUCE_MLA(SDValue Op,
ResultVT.isFixedLengthVector() &&
useSVEForFixedLengthVectorVT(ResultVT, /*OverrideNEON=*/true);
+ // We can handle this case natively by accumulating into a wider
+ // zero-padded vector.
+ if (!ConvertToScalable && ResultVT == MVT::v2i32 && OpVT == MVT::v16i8) {
+ SDValue ZeroVec = DAG.getConstant(0, DL, MVT::v4i32);
+ SDValue WideAcc = DAG.getInsertSubvector(DL, ZeroVec, Acc, 0);
+ SDValue Wide =
+ DAG.getNode(Op.getOpcode(), DL, MVT::v4i32, WideAcc, LHS, RHS);
+ SDValue Reduced = DAG.getNode(AArch64ISD::ADDP, DL, MVT::v4i32, Wide, Wide);
+ return DAG.getExtractSubvector(DL, MVT::v2i32, Reduced, 0);
+ }
+
if (ConvertToScalable) {
ResultVT = getContainerForFixedLengthVector(DAG, ResultVT);
OpVT = getContainerForFixedLengthVector(DAG, LHS.getValueType());
diff --git a/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
index 428750740fc56..dfff35d9eb1b2 100644
--- a/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
@@ -1451,3 +1451,52 @@ define <4 x i32> @partial_reduce_shl_zext_non_const_rhs(<16 x i8> %l, <4 x i32>
%red = tail call <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %part, <16 x i32> %shift)
ret <4 x i32> %red
}
+
+define <2 x i32> @udot_v16i8tov2i32(<2 x i32> %acc, <16 x i8> %input) {
+; CHECK-NODOT-LABEL: udot_v16i8tov2i32:
+; CHECK-NODOT: // %bb.0: // %entry
+; CHECK-NODOT-NEXT: ushll v2.8h, v1.8b, #0
+; CHECK-NODOT-NEXT: // kill: def $d0 killed $d0 def $q0
+; CHECK-NODOT-NEXT: ushll2 v1.8h, v1.16b, #0
+; CHECK-NODOT-NEXT: ushll v3.4s, v2.4h, #0
+; CHECK-NODOT-NEXT: uaddw v0.4s, v0.4s, v2.4h
+; CHECK-NODOT-NEXT: ushll2 v4.4s, v2.8h, #0
+; CHECK-NODOT-NEXT: ext v2.16b, v2.16b, v2.16b, #8
+; CHECK-NODOT-NEXT: ext v3.16b, v3.16b, v3.16b, #8
+; CHECK-NODOT-NEXT: add v0.2s, v3.2s, v0.2s
+; CHECK-NODOT-NEXT: ext v3.16b, v4.16b, v4.16b, #8
+; CHECK-NODOT-NEXT: uaddw v0.4s, v0.4s, v2.4h
+; CHECK-NODOT-NEXT: ushll v2.4s, v1.4h, #0
+; CHECK-NODOT-NEXT: add v0.2s, v3.2s, v0.2s
+; CHECK-NODOT-NEXT: ext v2.16b, v2.16b, v2.16b, #8
+; CHECK-NODOT-NEXT: ushll2 v3.4s, v1.8h, #0
+; CHECK-NODOT-NEXT: uaddw v0.4s, v0.4s, v1.4h
+; CHECK-NODOT-NEXT: ext v1.16b, v1.16b, v1.16b, #8
+; CHECK-NODOT-NEXT: add v0.2s, v2.2s, v0.2s
+; CHECK-NODOT-NEXT: ext v2.16b, v3.16b, v3.16b, #8
+; CHECK-NODOT-NEXT: uaddw v0.4s, v0.4s, v1.4h
+; CHECK-NODOT-NEXT: add v0.2s, v2.2s, v0.2s
+; CHECK-NODOT-NEXT: ret
+;
+; CHECK-DOT-LABEL: udot_v16i8tov2i32:
+; CHECK-DOT: // %bb.0: // %entry
+; CHECK-DOT-NEXT: movi v2.16b, #1
+; CHECK-DOT-NEXT: fmov d0, d0
+; CHECK-DOT-NEXT: udot v0.4s, v1.16b, v2.16b
+; CHECK-DOT-NEXT: addp v0.4s, v0.4s, v0.4s
+; CHECK-DOT-NEXT: // kill: def $d0 killed $d0 killed $q0
+; CHECK-DOT-NEXT: ret
+;
+; CHECK-DOT-I8MM-LABEL: udot_v16i8tov2i32:
+; CHECK-DOT-I8MM: // %bb.0: // %entry
+; CHECK-DOT-I8MM-NEXT: movi v2.16b, #1
+; CHECK-DOT-I8MM-NEXT: fmov d0, d0
+; CHECK-DOT-I8MM-NEXT: udot v0.4s, v1.16b, v2.16b
+; CHECK-DOT-I8MM-NEXT: addp v0.4s, v0.4s, v0.4s
+; CHECK-DOT-I8MM-NEXT: // kill: def $d0 killed $d0 killed $q0
+; CHECK-DOT-I8MM-NEXT: ret
+entry:
+ %input.wide = zext <16 x i8> %input to <16 x i32>
+ %partial.reduce = tail call <2 x i32> @llvm.vector.partial.reduce.add(<2 x i32> %acc, <16 x i32> %input.wide)
+ ret <2 x i32> %partial.reduce
+}
More information about the llvm-commits
mailing list