[llvm-branch-commits] [llvm] release/22.x: [AArch64] Fix partial_reduce v16i8 -> v2i32 (#177119) (PR #177324)
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Thu Jan 22 01:20:45 PST 2026
llvmbot wrote:
@llvm/pr-subscribers-backend-aarch64
Author: None (llvmbot)
Changes
Backport de997639876db38d20c7ed9fb0c683a239d56bf5
Requested by: @sdesmalen-arm
---
Full diff: https://github.com/llvm/llvm-project/pull/177324.diff
2 Files Affected:
- (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+5-5)
- (modified) llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll (+43)
``````````diff
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 093927049e9d1..c70caa3a67f06 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -31533,13 +31533,9 @@ AArch64TargetLowering::LowerPARTIAL_REDUCE_MLA(SDValue Op,
EVT OrigResultVT = ResultVT;
EVT OpVT = LHS.getValueType();
- bool ConvertToScalable =
- ResultVT.isFixedLengthVector() &&
- useSVEForFixedLengthVectorVT(ResultVT, /*OverrideNEON=*/true);
-
// We can handle this case natively by accumulating into a wider
// zero-padded vector.
- if (!ConvertToScalable && ResultVT == MVT::v2i32 && OpVT == MVT::v16i8) {
+ if (ResultVT == MVT::v2i32 && OpVT == MVT::v16i8) {
SDValue ZeroVec = DAG.getConstant(0, DL, MVT::v4i32);
SDValue WideAcc = DAG.getInsertSubvector(DL, ZeroVec, Acc, 0);
SDValue Wide =
@@ -31548,6 +31544,10 @@ AArch64TargetLowering::LowerPARTIAL_REDUCE_MLA(SDValue Op,
return DAG.getExtractSubvector(DL, MVT::v2i32, Reduced, 0);
}
+ bool ConvertToScalable =
+ ResultVT.isFixedLengthVector() &&
+ useSVEForFixedLengthVectorVT(ResultVT, /*OverrideNEON=*/true);
+
if (ConvertToScalable) {
ResultVT = getContainerForFixedLengthVector(DAG, ResultVT);
OpVT = getContainerForFixedLengthVector(DAG, LHS.getValueType());
diff --git a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
index da0c01f13b960..8c66d237e686d 100644
--- a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
@@ -1299,3 +1299,46 @@ entry:
%partial.reduce = tail call <vscale x 4 x i32> @llvm.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %mult)
ret <vscale x 4 x i32> %partial.reduce
}
+
+define <2 x i32> @udot_v16i8tov2i32(<2 x i32> %acc, <16 x i8> %input) "target-features"="+dotprod" {
+; CHECK-SVE2-LABEL: udot_v16i8tov2i32:
+; CHECK-SVE2: // %bb.0: // %entry
+; CHECK-SVE2-NEXT: movi v2.16b, #1
+; CHECK-SVE2-NEXT: fmov d0, d0
+; CHECK-SVE2-NEXT: // kill: def $q1 killed $q1 def $z1
+; CHECK-SVE2-NEXT: udot z0.s, z1.b, z2.b
+; CHECK-SVE2-NEXT: addp v0.4s, v0.4s, v0.4s
+; CHECK-SVE2-NEXT: // kill: def $d0 killed $d0 killed $q0
+; CHECK-SVE2-NEXT: ret
+;
+; CHECK-SVE2-I8MM-LABEL: udot_v16i8tov2i32:
+; CHECK-SVE2-I8MM: // %bb.0: // %entry
+; CHECK-SVE2-I8MM-NEXT: movi v2.16b, #1
+; CHECK-SVE2-I8MM-NEXT: fmov d0, d0
+; CHECK-SVE2-I8MM-NEXT: // kill: def $q1 killed $q1 def $z1
+; CHECK-SVE2-I8MM-NEXT: udot z0.s, z1.b, z2.b
+; CHECK-SVE2-I8MM-NEXT: addp v0.4s, v0.4s, v0.4s
+; CHECK-SVE2-I8MM-NEXT: // kill: def $d0 killed $d0 killed $q0
+; CHECK-SVE2-I8MM-NEXT: ret
+;
+; CHECK-SME-LABEL: udot_v16i8tov2i32:
+; CHECK-SME: // %bb.0: // %entry
+; CHECK-SME-NEXT: uunpklo z2.h, z1.b
+; CHECK-SME-NEXT: ext z1.b, z1.b, z1.b, #8
+; CHECK-SME-NEXT: uunpklo z1.h, z1.b
+; CHECK-SME-NEXT: uaddwb z0.s, z0.s, z2.h
+; CHECK-SME-NEXT: uaddwt z0.s, z0.s, z2.h
+; CHECK-SME-NEXT: ext z2.b, z2.b, z2.b, #8
+; CHECK-SME-NEXT: uaddwb z0.s, z0.s, z2.h
+; CHECK-SME-NEXT: uaddwt z0.s, z0.s, z2.h
+; CHECK-SME-NEXT: uaddwb z0.s, z0.s, z1.h
+; CHECK-SME-NEXT: uaddwt z0.s, z0.s, z1.h
+; CHECK-SME-NEXT: ext z1.b, z1.b, z1.b, #8
+; CHECK-SME-NEXT: uaddwb z0.s, z0.s, z1.h
+; CHECK-SME-NEXT: uaddwt z0.s, z0.s, z1.h
+; CHECK-SME-NEXT: ret
+entry:
+ %input.wide = zext <16 x i8> %input to <16 x i32>
+ %partial.reduce = tail call <2 x i32> @llvm.vector.partial.reduce.add(<2 x i32> %acc, <16 x i32> %input.wide)
+ ret <2 x i32> %partial.reduce
+}
``````````
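
The functional change above moves the native v16i8 -> v2i32 path ahead of the ConvertToScalable check, so the same lowering is used even when fixed-length vectors would otherwise be converted to scalable containers. As a rough illustration of what that native path computes, here is a minimal scalar sketch (plain C++, not LLVM code; the function name and lane bookkeeping are only illustrative): the accumulator is zero-padded to four i32 lanes, a UDOT against an all-ones vector adds four consecutive bytes into each lane, and a pairwise ADDP folds the result back down so the low two lanes form the `<2 x i32>` result.

```cpp
#include <array>
#include <cstdint>

// Scalar model of the native lowering path: zero-pad the 2-lane accumulator
// to 4 lanes, UDOT four u8 values into each 32-bit lane, then fold pairs of
// lanes with ADDP and keep the low two lanes.
std::array<uint32_t, 2> udot_v16i8_to_v2i32(std::array<uint32_t, 2> acc,
                                            const std::array<uint8_t, 16> &in) {
  // WideAcc = insert <2 x i32> acc into a zeroed <4 x i32> at index 0.
  std::array<uint32_t, 4> wide = {acc[0], acc[1], 0, 0};

  // udot z0.s, z1.b, z2.b with z2 = all-ones: each 32-bit lane accumulates
  // the sum of four consecutive input bytes (dot product with 1).
  for (int lane = 0; lane < 4; ++lane)
    for (int j = 0; j < 4; ++j)
      wide[lane] += in[4 * lane + j];

  // addp v0.4s, v0.4s, v0.4s, then extract the low <2 x i32>.
  return {wide[0] + wide[1], wide[2] + wide[3]};
}
```

The two result lanes together still carry acc plus the sum of all sixteen zero-extended bytes; the per-lane split shown here is just the one this lowering produces, since a partial reduction does not pin down a particular distribution across lanes.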
https://github.com/llvm/llvm-project/pull/177324