[llvm] [AArch64] Improve codegen for partial.reduce.add v16i8 -> v2i32 (PR #161833)

Sander de Smalen via llvm-commits llvm-commits at lists.llvm.org
Fri Oct 3 06:20:54 PDT 2025


https://github.com/sdesmalen-arm updated https://github.com/llvm/llvm-project/pull/161833

From fc9534882a0ae4d2e9af16d64198b644179afd76 Mon Sep 17 00:00:00 2001
From: Sander de Smalen <sander.desmalen at arm.com>
Date: Fri, 3 Oct 2025 11:35:07 +0100
Subject: [PATCH 1/3] Precommit test

---
 .../neon-partial-reduce-dot-product.ll        | 31 +++++++++++++++++++
 1 file changed, 31 insertions(+)

diff --git a/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
index 428750740fc56..824a3708451ba 100644
--- a/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
@@ -1451,3 +1451,34 @@ define <4 x i32> @partial_reduce_shl_zext_non_const_rhs(<16 x i8> %l, <4 x i32>
   %red = tail call <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %part, <16 x i32> %shift)
   ret <4 x i32> %red
 }
+
+define <2 x i32> @udot_v16i8tov2i32(<2 x i32> %acc, <16 x i8> %input) {
+; CHECK-COMMON-LABEL: udot_v16i8tov2i32:
+; CHECK-COMMON:       // %bb.0: // %entry
+; CHECK-COMMON-NEXT:    ushll v2.8h, v1.8b, #0
+; CHECK-COMMON-NEXT:    // kill: def $d0 killed $d0 def $q0
+; CHECK-COMMON-NEXT:    ushll2 v1.8h, v1.16b, #0
+; CHECK-COMMON-NEXT:    ushll v3.4s, v2.4h, #0
+; CHECK-COMMON-NEXT:    uaddw v0.4s, v0.4s, v2.4h
+; CHECK-COMMON-NEXT:    ushll2 v4.4s, v2.8h, #0
+; CHECK-COMMON-NEXT:    ext v2.16b, v2.16b, v2.16b, #8
+; CHECK-COMMON-NEXT:    ext v3.16b, v3.16b, v3.16b, #8
+; CHECK-COMMON-NEXT:    add v0.2s, v3.2s, v0.2s
+; CHECK-COMMON-NEXT:    ext v3.16b, v4.16b, v4.16b, #8
+; CHECK-COMMON-NEXT:    uaddw v0.4s, v0.4s, v2.4h
+; CHECK-COMMON-NEXT:    ushll v2.4s, v1.4h, #0
+; CHECK-COMMON-NEXT:    add v0.2s, v3.2s, v0.2s
+; CHECK-COMMON-NEXT:    ext v2.16b, v2.16b, v2.16b, #8
+; CHECK-COMMON-NEXT:    ushll2 v3.4s, v1.8h, #0
+; CHECK-COMMON-NEXT:    uaddw v0.4s, v0.4s, v1.4h
+; CHECK-COMMON-NEXT:    ext v1.16b, v1.16b, v1.16b, #8
+; CHECK-COMMON-NEXT:    add v0.2s, v2.2s, v0.2s
+; CHECK-COMMON-NEXT:    ext v2.16b, v3.16b, v3.16b, #8
+; CHECK-COMMON-NEXT:    uaddw v0.4s, v0.4s, v1.4h
+; CHECK-COMMON-NEXT:    add v0.2s, v2.2s, v0.2s
+; CHECK-COMMON-NEXT:    ret
+entry:
+    %input.wide = zext <16 x i8> %input to <16 x i32>
+    %partial.reduce = tail call <2 x i32> @llvm.vector.partial.reduce.add(<2 x i32> %acc, <16 x i32> %input.wide)
+    ret <2 x i32> %partial.reduce
+}

From f15c8ed731cd1ee93279954625fad8376d1dcc85 Mon Sep 17 00:00:00 2001
From: Sander de Smalen <sander.desmalen at arm.com>
Date: Fri, 3 Oct 2025 11:36:13 +0100
Subject: [PATCH 2/3] [AArch64] Improve codegen for partial.reduce.add v16i8 ->
 v2i32

Rather than expanding the operation, we can handle this case natively
by zero-padding the v2i32 accumulator out to v4i32, performing the
legal v4i32 partial reduction, and adding the two halves of the
result back together.
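
Sketched in LLVM IR rather than the SelectionDAG nodes the lowering
actually builds (the %wide.acc/%lo/%hi names are purely illustrative),
the transform is roughly:

  ; Sketch only -- the lowering emits equivalent SelectionDAG nodes.
  %wide.acc = shufflevector <2 x i32> %acc, <2 x i32> zeroinitializer,
                            <4 x i32> <i32 0, i32 1, i32 2, i32 3>
  %wide = call <4 x i32> @llvm.vector.partial.reduce.add(<4 x i32> %wide.acc,
                                                         <16 x i32> %input.wide)
  %lo  = shufflevector <4 x i32> %wide, <4 x i32> poison, <2 x i32> <i32 0, i32 1>
  %hi  = shufflevector <4 x i32> %wide, <4 x i32> poison, <2 x i32> <i32 2, i32 3>
  %res = add <2 x i32> %lo, %hi

The v4i32 <- v16i8 partial reduction then lowers as it already does
today (a single udot/sdot when dot-product instructions are available,
the existing widening-add sequence otherwise), and the final add folds
the two halves of the widened result back into the requested v2i32.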
---
 .../Target/AArch64/AArch64ISelLowering.cpp    | 13 ++++
 .../neon-partial-reduce-dot-product.ll        | 66 ++++++++++++-------
 2 files changed, 55 insertions(+), 24 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 70d5ad7d660f1..056d367a11949 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1458,6 +1458,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
 
       setPartialReduceMLAAction(MLAOps, MVT::v4i32, MVT::v16i8, Legal);
       setPartialReduceMLAAction(MLAOps, MVT::v2i32, MVT::v8i8, Legal);
+      setPartialReduceMLAAction(MLAOps, MVT::v2i32, MVT::v16i8, Custom);
       setPartialReduceMLAAction(MLAOps, MVT::v2i64, MVT::v16i8, Custom);
 
       if (Subtarget->hasMatMulInt8()) {
@@ -30768,6 +30769,18 @@ AArch64TargetLowering::LowerPARTIAL_REDUCE_MLA(SDValue Op,
       ResultVT.isFixedLengthVector() &&
       useSVEForFixedLengthVectorVT(ResultVT, /*OverrideNEON=*/true);
 
+  // We can handle this case natively by accumulating into a wider
+  // zero-padded vector.
+  if (!ConvertToScalable && ResultVT == MVT::v2i32 && OpVT == MVT::v16i8) {
+    SDValue ZeroVec = DAG.getConstant(0, DL, MVT::v4i32);
+    SDValue WideAcc = DAG.getInsertSubvector(DL, ZeroVec, Acc, 0);
+    SDValue Wide = DAG.getNode(Op.getOpcode(), DL, MVT::v4i32,
+                               WideAcc, LHS, RHS);
+    SDValue Lo = DAG.getExtractSubvector(DL, MVT::v2i32, Wide, 0);
+    SDValue Hi = DAG.getExtractSubvector(DL, MVT::v2i32, Wide, 2);
+    return DAG.getNode(ISD::ADD, DL, ResultVT, Lo, Hi);
+  }
+
   if (ConvertToScalable) {
     ResultVT = getContainerForFixedLengthVector(DAG, ResultVT);
     OpVT = getContainerForFixedLengthVector(DAG, LHS.getValueType());
diff --git a/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
index 824a3708451ba..fc9e3c8a52850 100644
--- a/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
@@ -1453,30 +1453,48 @@ define <4 x i32> @partial_reduce_shl_zext_non_const_rhs(<16 x i8> %l, <4 x i32>
 }
 
 define <2 x i32> @udot_v16i8tov2i32(<2 x i32> %acc, <16 x i8> %input) {
-; CHECK-COMMON-LABEL: udot_v16i8tov2i32:
-; CHECK-COMMON:       // %bb.0: // %entry
-; CHECK-COMMON-NEXT:    ushll v2.8h, v1.8b, #0
-; CHECK-COMMON-NEXT:    // kill: def $d0 killed $d0 def $q0
-; CHECK-COMMON-NEXT:    ushll2 v1.8h, v1.16b, #0
-; CHECK-COMMON-NEXT:    ushll v3.4s, v2.4h, #0
-; CHECK-COMMON-NEXT:    uaddw v0.4s, v0.4s, v2.4h
-; CHECK-COMMON-NEXT:    ushll2 v4.4s, v2.8h, #0
-; CHECK-COMMON-NEXT:    ext v2.16b, v2.16b, v2.16b, #8
-; CHECK-COMMON-NEXT:    ext v3.16b, v3.16b, v3.16b, #8
-; CHECK-COMMON-NEXT:    add v0.2s, v3.2s, v0.2s
-; CHECK-COMMON-NEXT:    ext v3.16b, v4.16b, v4.16b, #8
-; CHECK-COMMON-NEXT:    uaddw v0.4s, v0.4s, v2.4h
-; CHECK-COMMON-NEXT:    ushll v2.4s, v1.4h, #0
-; CHECK-COMMON-NEXT:    add v0.2s, v3.2s, v0.2s
-; CHECK-COMMON-NEXT:    ext v2.16b, v2.16b, v2.16b, #8
-; CHECK-COMMON-NEXT:    ushll2 v3.4s, v1.8h, #0
-; CHECK-COMMON-NEXT:    uaddw v0.4s, v0.4s, v1.4h
-; CHECK-COMMON-NEXT:    ext v1.16b, v1.16b, v1.16b, #8
-; CHECK-COMMON-NEXT:    add v0.2s, v2.2s, v0.2s
-; CHECK-COMMON-NEXT:    ext v2.16b, v3.16b, v3.16b, #8
-; CHECK-COMMON-NEXT:    uaddw v0.4s, v0.4s, v1.4h
-; CHECK-COMMON-NEXT:    add v0.2s, v2.2s, v0.2s
-; CHECK-COMMON-NEXT:    ret
+; CHECK-NODOT-LABEL: udot_v16i8tov2i32:
+; CHECK-NODOT:       // %bb.0: // %entry
+; CHECK-NODOT-NEXT:    ushll v2.8h, v1.8b, #0
+; CHECK-NODOT-NEXT:    // kill: def $d0 killed $d0 def $q0
+; CHECK-NODOT-NEXT:    ushll2 v1.8h, v1.16b, #0
+; CHECK-NODOT-NEXT:    ushll v3.4s, v2.4h, #0
+; CHECK-NODOT-NEXT:    uaddw v0.4s, v0.4s, v2.4h
+; CHECK-NODOT-NEXT:    ushll2 v4.4s, v2.8h, #0
+; CHECK-NODOT-NEXT:    ext v2.16b, v2.16b, v2.16b, #8
+; CHECK-NODOT-NEXT:    ext v3.16b, v3.16b, v3.16b, #8
+; CHECK-NODOT-NEXT:    add v0.2s, v3.2s, v0.2s
+; CHECK-NODOT-NEXT:    ext v3.16b, v4.16b, v4.16b, #8
+; CHECK-NODOT-NEXT:    uaddw v0.4s, v0.4s, v2.4h
+; CHECK-NODOT-NEXT:    ushll v2.4s, v1.4h, #0
+; CHECK-NODOT-NEXT:    add v0.2s, v3.2s, v0.2s
+; CHECK-NODOT-NEXT:    ext v2.16b, v2.16b, v2.16b, #8
+; CHECK-NODOT-NEXT:    ushll2 v3.4s, v1.8h, #0
+; CHECK-NODOT-NEXT:    uaddw v0.4s, v0.4s, v1.4h
+; CHECK-NODOT-NEXT:    ext v1.16b, v1.16b, v1.16b, #8
+; CHECK-NODOT-NEXT:    add v0.2s, v2.2s, v0.2s
+; CHECK-NODOT-NEXT:    ext v2.16b, v3.16b, v3.16b, #8
+; CHECK-NODOT-NEXT:    uaddw v0.4s, v0.4s, v1.4h
+; CHECK-NODOT-NEXT:    add v0.2s, v2.2s, v0.2s
+; CHECK-NODOT-NEXT:    ret
+;
+; CHECK-DOT-LABEL: udot_v16i8tov2i32:
+; CHECK-DOT:       // %bb.0: // %entry
+; CHECK-DOT-NEXT:    movi v2.16b, #1
+; CHECK-DOT-NEXT:    fmov d0, d0
+; CHECK-DOT-NEXT:    udot v0.4s, v1.16b, v2.16b
+; CHECK-DOT-NEXT:    ext v1.16b, v0.16b, v0.16b, #8
+; CHECK-DOT-NEXT:    add v0.2s, v0.2s, v1.2s
+; CHECK-DOT-NEXT:    ret
+;
+; CHECK-DOT-I8MM-LABEL: udot_v16i8tov2i32:
+; CHECK-DOT-I8MM:       // %bb.0: // %entry
+; CHECK-DOT-I8MM-NEXT:    movi v2.16b, #1
+; CHECK-DOT-I8MM-NEXT:    fmov d0, d0
+; CHECK-DOT-I8MM-NEXT:    udot v0.4s, v1.16b, v2.16b
+; CHECK-DOT-I8MM-NEXT:    ext v1.16b, v0.16b, v0.16b, #8
+; CHECK-DOT-I8MM-NEXT:    add v0.2s, v0.2s, v1.2s
+; CHECK-DOT-I8MM-NEXT:    ret
 entry:
     %input.wide = zext <16 x i8> %input to <16 x i32>
     %partial.reduce = tail call <2 x i32> @llvm.vector.partial.reduce.add(<2 x i32> %acc, <16 x i32> %input.wide)

From aebcbc9b4e97cecb43a5484538659e3a47e7e3d5 Mon Sep 17 00:00:00 2001
From: Sander de Smalen <sander.desmalen at arm.com>
Date: Fri, 3 Oct 2025 14:18:58 +0100
Subject: [PATCH 3/3] Use addp instead of ext + add
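
addp v0.4s, v0.4s, v0.4s pairwise-adds adjacent lanes, so the low
64 bits of the result hold <wide[0]+wide[1], wide[2]+wide[3]> rather
than the <wide[0]+wide[2], wide[1]+wide[3]> produced by the previous
ext + add. Since llvm.vector.partial.reduce.add does not prescribe
which input lanes feed which result lane, both lowerings are valid,
and addp needs one instruction fewer. In LLVM IR (a sketch only, the
names are illustrative), the new tail is equivalent to:

  ; Sketch only -- matches 'addp v0.4s, v0.4s, v0.4s' plus taking the low half.
  %evens = shufflevector <4 x i32> %wide, <4 x i32> poison, <2 x i32> <i32 0, i32 2>
  %odds  = shufflevector <4 x i32> %wide, <4 x i32> poison, <2 x i32> <i32 1, i32 3>
  %res   = add <2 x i32> %evens, %odds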

---
 llvm/lib/Target/AArch64/AArch64ISelLowering.cpp           | 5 ++---
 .../CodeGen/AArch64/neon-partial-reduce-dot-product.ll    | 8 ++++----
 2 files changed, 6 insertions(+), 7 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 056d367a11949..a837e2b9d15e7 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -30776,9 +30776,8 @@ AArch64TargetLowering::LowerPARTIAL_REDUCE_MLA(SDValue Op,
     SDValue WideAcc = DAG.getInsertSubvector(DL, ZeroVec, Acc, 0);
     SDValue Wide = DAG.getNode(Op.getOpcode(), DL, MVT::v4i32,
                                WideAcc, LHS, RHS);
-    SDValue Lo = DAG.getExtractSubvector(DL, MVT::v2i32, Wide, 0);
-    SDValue Hi = DAG.getExtractSubvector(DL, MVT::v2i32, Wide, 2);
-    return DAG.getNode(ISD::ADD, DL, ResultVT, Lo, Hi);
+    SDValue Reduced = DAG.getNode(AArch64ISD::ADDP, DL, MVT::v4i32, Wide, Wide);
+    return DAG.getExtractSubvector(DL, MVT::v2i32, Reduced, 0);
   }
 
   if (ConvertToScalable) {
diff --git a/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
index fc9e3c8a52850..dfff35d9eb1b2 100644
--- a/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
@@ -1483,8 +1483,8 @@ define <2 x i32> @udot_v16i8tov2i32(<2 x i32> %acc, <16 x i8> %input) {
 ; CHECK-DOT-NEXT:    movi v2.16b, #1
 ; CHECK-DOT-NEXT:    fmov d0, d0
 ; CHECK-DOT-NEXT:    udot v0.4s, v1.16b, v2.16b
-; CHECK-DOT-NEXT:    ext v1.16b, v0.16b, v0.16b, #8
-; CHECK-DOT-NEXT:    add v0.2s, v0.2s, v1.2s
+; CHECK-DOT-NEXT:    addp v0.4s, v0.4s, v0.4s
+; CHECK-DOT-NEXT:    // kill: def $d0 killed $d0 killed $q0
 ; CHECK-DOT-NEXT:    ret
 ;
 ; CHECK-DOT-I8MM-LABEL: udot_v16i8tov2i32:
@@ -1492,8 +1492,8 @@ define <2 x i32> @udot_v16i8tov2i32(<2 x i32> %acc, <16 x i8> %input) {
 ; CHECK-DOT-I8MM-NEXT:    movi v2.16b, #1
 ; CHECK-DOT-I8MM-NEXT:    fmov d0, d0
 ; CHECK-DOT-I8MM-NEXT:    udot v0.4s, v1.16b, v2.16b
-; CHECK-DOT-I8MM-NEXT:    ext v1.16b, v0.16b, v0.16b, #8
-; CHECK-DOT-I8MM-NEXT:    add v0.2s, v0.2s, v1.2s
+; CHECK-DOT-I8MM-NEXT:    addp v0.4s, v0.4s, v0.4s
+; CHECK-DOT-I8MM-NEXT:    // kill: def $d0 killed $d0 killed $q0
 ; CHECK-DOT-I8MM-NEXT:    ret
 entry:
     %input.wide = zext <16 x i8> %input to <16 x i32>


