[llvm] [AArch64] Add ISel support for partial reductions to use SVE2.1 udot/sdot (PR #158310)

Damian Heaton via llvm-commits llvm-commits at lists.llvm.org
Mon Sep 15 03:45:47 PDT 2025


https://github.com/dheaton-arm updated https://github.com/llvm/llvm-project/pull/158310

>From ead45fa3c9b2f67b967f16ff510aff5868ccb1ca Mon Sep 17 00:00:00 2001
From: Damian Heaton <Damian.Heaton at arm.com>
Date: Fri, 12 Sep 2025 13:23:58 +0000
Subject: [PATCH 1/4] [AArch64] Add ISel support for partial reductions to use
 SVE2.1 udot/sdot

This allows dot products with scalable 8xi16 vectors (and fixed-length vectors
which are converted into scalable vectors), accumulating into a scalable 4xi32
vector, to lower to a single instruction (`udot`/`sdot`) rather than a sequence
of `umlalb`s and `umlalt`s.
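
For reference, here is a minimal IR sketch of the zero-extended pattern the new
TableGen patterns match (it mirrors the scalable-vector tests added below). With
+sve2p1 this now selects to a single `udot z0.s, z1.h, z2.h`; the sign-extended
variant selects `sdot` instead:

  ; %acc is <vscale x 4 x i32>, %a and %b are <vscale x 8 x i16>
  %a.wide = zext <vscale x 8 x i16> %a to <vscale x 8 x i32>
  %b.wide = zext <vscale x 8 x i16> %b to <vscale x 8 x i32>
  %mult = mul nuw nsw <vscale x 8 x i32> %a.wide, %b.wide
  %partial.reduce = call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add(<vscale x 4 x i32> %acc, <vscale x 8 x i32> %mult)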
---
 .../lib/Target/AArch64/AArch64SVEInstrInfo.td |  7 +++
 .../AArch64/sve2p1-dots-partial-reduction.ll  | 62 +++++++++++++++++++
 2 files changed, 69 insertions(+)
 create mode 100644 llvm/test/CodeGen/AArch64/sve2p1-dots-partial-reduction.ll

diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index 7604ffdc9f646..0528da568ae40 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -4238,6 +4238,13 @@ defm UDOT_ZZZ_HtoS  : sve2p1_two_way_dot_vv<"udot", 0b1, int_aarch64_sve_udot_x2
 defm SDOT_ZZZI_HtoS : sve2p1_two_way_dot_vvi<"sdot", 0b0, int_aarch64_sve_sdot_lane_x2>;
 defm UDOT_ZZZI_HtoS : sve2p1_two_way_dot_vvi<"udot", 0b1, int_aarch64_sve_udot_lane_x2>;
 
+let Predicates = [HasSVE2p1_or_SME2] in {
+  def : Pat<(nxv4i32 (partial_reduce_umla nxv4i32:$Acc, nxv8i16:$MulLHS, nxv8i16:$MulRHS)),
+            (UDOT_ZZZ_HtoS $Acc, $MulLHS, $MulRHS)>;
+  def : Pat<(nxv4i32 (partial_reduce_smla nxv4i32:$Acc, nxv8i16:$MulLHS, nxv8i16:$MulRHS)),
+            (SDOT_ZZZ_HtoS $Acc, $MulLHS, $MulRHS)>;
+} // End HasSVE_or_SME
+
 defm SQCVTN_Z2Z_StoH  : sve2p1_multi_vec_extract_narrow<"sqcvtn", 0b00, int_aarch64_sve_sqcvtn_x2>;
 defm UQCVTN_Z2Z_StoH  : sve2p1_multi_vec_extract_narrow<"uqcvtn", 0b01, int_aarch64_sve_uqcvtn_x2>;
 defm SQCVTUN_Z2Z_StoH : sve2p1_multi_vec_extract_narrow<"sqcvtun", 0b10, int_aarch64_sve_sqcvtun_x2>;
diff --git a/llvm/test/CodeGen/AArch64/sve2p1-dots-partial-reduction.ll b/llvm/test/CodeGen/AArch64/sve2p1-dots-partial-reduction.ll
new file mode 100644
index 0000000000000..8abd3a86ff1f7
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve2p1-dots-partial-reduction.ll
@@ -0,0 +1,62 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sve2p1 < %s | FileCheck %s
+
+define <vscale x 4 x i32> @udot(<vscale x 4 x i32> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b) {
+; CHECK-LABEL: udot:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    udot z0.s, z1.h, z2.h
+; CHECK-NEXT:    ret
+entry:
+  %a.wide = zext <vscale x 8 x i16> %a to <vscale x 8 x i32>
+  %b.wide = zext <vscale x 8 x i16> %b to <vscale x 8 x i32>
+  %mult = mul nuw nsw <vscale x 8 x i32> %a.wide, %b.wide
+  %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add(<vscale x 4 x i32> %acc, <vscale x 8 x i32> %mult)
+  ret <vscale x 4 x i32> %partial.reduce
+}
+
+define <vscale x 4 x i32> @sdot(<vscale x 4 x i32> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b) {
+; CHECK-LABEL: sdot:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    sdot z0.s, z1.h, z2.h
+; CHECK-NEXT:    ret
+entry:
+  %a.wide = sext <vscale x 8 x i16> %a to <vscale x 8 x i32>
+  %b.wide = sext <vscale x 8 x i16> %b to <vscale x 8 x i32>
+  %mult = mul nuw nsw <vscale x 8 x i32> %a.wide, %b.wide
+  %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add(<vscale x 4 x i32> %acc, <vscale x 8 x i32> %mult)
+  ret <vscale x 4 x i32> %partial.reduce
+}
+
+define <4 x i32> @fixed_udot_s_h(<4 x i32> %acc, <8 x i16> %a, <8 x i16> %b) {
+; CHECK-LABEL: fixed_udot_s_h:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT:    // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT:    // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT:    udot z0.s, z1.h, z2.h
+; CHECK-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT:    ret
+entry:
+  %a.wide = zext <8 x i16> %a to <8 x i32>
+  %b.wide = zext <8 x i16> %b to <8 x i32>
+  %mult = mul nuw nsw <8 x i32> %a.wide, %b.wide
+  %partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add(<4 x i32> %acc, <8 x i32> %mult)
+  ret <4 x i32> %partial.reduce
+}
+
+define <4 x i32> @fixed_sdot_s_h(<4 x i32> %acc, <8 x i16> %a, <8 x i16> %b) {
+; CHECK-LABEL: fixed_sdot_s_h:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT:    // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT:    // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT:    sdot z0.s, z1.h, z2.h
+; CHECK-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT:    ret
+entry:
+  %a.wide = sext <8 x i16> %a to <8 x i32>
+  %b.wide = sext <8 x i16> %b to <8 x i32>
+  %mult = mul nuw nsw <8 x i32> %a.wide, %b.wide
+  %partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add(<4 x i32> %acc, <8 x i32> %mult)
+  ret <4 x i32> %partial.reduce
+}

>From b8aa1d0c1edee7278bf4edd1723e81360b85a851 Mon Sep 17 00:00:00 2001
From: Damian Heaton <Damian.Heaton at arm.com>
Date: Fri, 12 Sep 2025 15:38:53 +0000
Subject: [PATCH 2/4] Add vl256 tests & fix comment typo

---
 .../lib/Target/AArch64/AArch64SVEInstrInfo.td |  2 +-
 .../AArch64/sve2p1-dots-partial-reduction.ll  | 34 ++++++++++++++++---
 2 files changed, 31 insertions(+), 5 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index 0528da568ae40..7fe4f7acdbd49 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -4243,7 +4243,7 @@ let Predicates = [HasSVE2p1_or_SME2] in {
             (UDOT_ZZZ_HtoS $Acc, $MulLHS, $MulRHS)>;
   def : Pat<(nxv4i32 (partial_reduce_smla nxv4i32:$Acc, nxv8i16:$MulLHS, nxv8i16:$MulRHS)),
             (SDOT_ZZZ_HtoS $Acc, $MulLHS, $MulRHS)>;
-} // End HasSVE_or_SME
+} // End HasSVE2p1_or_SME2
 
 defm SQCVTN_Z2Z_StoH  : sve2p1_multi_vec_extract_narrow<"sqcvtn", 0b00, int_aarch64_sve_sqcvtn_x2>;
 defm UQCVTN_Z2Z_StoH  : sve2p1_multi_vec_extract_narrow<"uqcvtn", 0b01, int_aarch64_sve_uqcvtn_x2>;
diff --git a/llvm/test/CodeGen/AArch64/sve2p1-dots-partial-reduction.ll b/llvm/test/CodeGen/AArch64/sve2p1-dots-partial-reduction.ll
index 8abd3a86ff1f7..7933730a068b1 100644
--- a/llvm/test/CodeGen/AArch64/sve2p1-dots-partial-reduction.ll
+++ b/llvm/test/CodeGen/AArch64/sve2p1-dots-partial-reduction.ll
@@ -1,8 +1,8 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
 ; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sve2p1 < %s | FileCheck %s
 
-define <vscale x 4 x i32> @udot(<vscale x 4 x i32> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b) {
-; CHECK-LABEL: udot:
+define <vscale x 4 x i32> @udot_vl128(<vscale x 4 x i32> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b) {
+; CHECK-LABEL: udot_vl128:
 ; CHECK:       // %bb.0: // %entry
 ; CHECK-NEXT:    udot z0.s, z1.h, z2.h
 ; CHECK-NEXT:    ret
@@ -14,8 +14,34 @@ entry:
   ret <vscale x 4 x i32> %partial.reduce
 }
 
-define <vscale x 4 x i32> @sdot(<vscale x 4 x i32> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b) {
-; CHECK-LABEL: sdot:
+define <vscale x 4 x i32> @sdot_vl128(<vscale x 4 x i32> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b)  {
+; CHECK-LABEL: sdot_vl128:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    sdot z0.s, z1.h, z2.h
+; CHECK-NEXT:    ret
+entry:
+  %a.wide = sext <vscale x 8 x i16> %a to <vscale x 8 x i32>
+  %b.wide = sext <vscale x 8 x i16> %b to <vscale x 8 x i32>
+  %mult = mul nuw nsw <vscale x 8 x i32> %a.wide, %b.wide
+  %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add(<vscale x 4 x i32> %acc, <vscale x 8 x i32> %mult)
+  ret <vscale x 4 x i32> %partial.reduce
+}
+
+define <vscale x 4 x i32> @udot_vl256(<vscale x 4 x i32> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b) vscale_range(2,2) {
+; CHECK-LABEL: udot_vl256:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    udot z0.s, z1.h, z2.h
+; CHECK-NEXT:    ret
+entry:
+  %a.wide = zext <vscale x 8 x i16> %a to <vscale x 8 x i32>
+  %b.wide = zext <vscale x 8 x i16> %b to <vscale x 8 x i32>
+  %mult = mul nuw nsw <vscale x 8 x i32> %a.wide, %b.wide
+  %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add(<vscale x 4 x i32> %acc, <vscale x 8 x i32> %mult)
+  ret <vscale x 4 x i32> %partial.reduce
+}
+
+define <vscale x 4 x i32> @sdot_vl256(<vscale x 4 x i32> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b) vscale_range(2,2) {
+; CHECK-LABEL: sdot_vl256:
 ; CHECK:       // %bb.0: // %entry
 ; CHECK-NEXT:    sdot z0.s, z1.h, z2.h
 ; CHECK-NEXT:    ret

>From c48caac72cfc2e2234ce0065fd784a589633f7ba Mon Sep 17 00:00:00 2001
From: Damian Heaton <Damian.Heaton at arm.com>
Date: Fri, 12 Sep 2025 16:21:26 +0000
Subject: [PATCH 3/4] Double the width of the base types for vl256 test
 variants

---
 .../AArch64/sve2p1-dots-partial-reduction.ll  | 30 ++++++++++---------
 1 file changed, 16 insertions(+), 14 deletions(-)

diff --git a/llvm/test/CodeGen/AArch64/sve2p1-dots-partial-reduction.ll b/llvm/test/CodeGen/AArch64/sve2p1-dots-partial-reduction.ll
index 7933730a068b1..770224743d1a7 100644
--- a/llvm/test/CodeGen/AArch64/sve2p1-dots-partial-reduction.ll
+++ b/llvm/test/CodeGen/AArch64/sve2p1-dots-partial-reduction.ll
@@ -27,30 +27,32 @@ entry:
   ret <vscale x 4 x i32> %partial.reduce
 }
 
-define <vscale x 4 x i32> @udot_vl256(<vscale x 4 x i32> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b) vscale_range(2,2) {
+define <vscale x 8 x i32> @udot_vl256(<vscale x 8 x i32> %acc, <vscale x 16 x i16> %a, <vscale x 16 x i16> %b) vscale_range(2,2) {
 ; CHECK-LABEL: udot_vl256:
 ; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    udot z0.s, z1.h, z2.h
+; CHECK-NEXT:    udot z0.s, z2.h, z4.h
+; CHECK-NEXT:    udot z1.s, z3.h, z5.h
 ; CHECK-NEXT:    ret
 entry:
-  %a.wide = zext <vscale x 8 x i16> %a to <vscale x 8 x i32>
-  %b.wide = zext <vscale x 8 x i16> %b to <vscale x 8 x i32>
-  %mult = mul nuw nsw <vscale x 8 x i32> %a.wide, %b.wide
-  %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add(<vscale x 4 x i32> %acc, <vscale x 8 x i32> %mult)
-  ret <vscale x 4 x i32> %partial.reduce
+  %a.wide = zext <vscale x 16 x i16> %a to <vscale x 16 x i32>
+  %b.wide = zext <vscale x 16 x i16> %b to <vscale x 16 x i32>
+  %mult = mul nuw nsw <vscale x 16 x i32> %a.wide, %b.wide
+  %partial.reduce = tail call <vscale x 8 x i32> @llvm.experimental.vector.partial.reduce.add(<vscale x 8 x i32> %acc, <vscale x 16 x i32> %mult)
+  ret <vscale x 8 x i32> %partial.reduce
 }
 
-define <vscale x 4 x i32> @sdot_vl256(<vscale x 4 x i32> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b) vscale_range(2,2) {
+define <vscale x 8 x i32> @sdot_vl256(<vscale x 8 x i32> %acc, <vscale x 16 x i16> %a, <vscale x 16 x i16> %b) vscale_range(2,2) {
 ; CHECK-LABEL: sdot_vl256:
 ; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    sdot z0.s, z1.h, z2.h
+; CHECK-NEXT:    sdot z0.s, z2.h, z4.h
+; CHECK-NEXT:    sdot z1.s, z3.h, z5.h
 ; CHECK-NEXT:    ret
 entry:
-  %a.wide = sext <vscale x 8 x i16> %a to <vscale x 8 x i32>
-  %b.wide = sext <vscale x 8 x i16> %b to <vscale x 8 x i32>
-  %mult = mul nuw nsw <vscale x 8 x i32> %a.wide, %b.wide
-  %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add(<vscale x 4 x i32> %acc, <vscale x 8 x i32> %mult)
-  ret <vscale x 4 x i32> %partial.reduce
+  %a.wide = sext <vscale x 16 x i16> %a to <vscale x 16 x i32>
+  %b.wide = sext <vscale x 16 x i16> %b to <vscale x 16 x i32>
+  %mult = mul nuw nsw <vscale x 16 x i32> %a.wide, %b.wide
+  %partial.reduce = tail call <vscale x 8 x i32> @llvm.experimental.vector.partial.reduce.add(<vscale x 8 x i32> %acc, <vscale x 16 x i32> %mult)
+  ret <vscale x 8 x i32> %partial.reduce
 }
 
 define <4 x i32> @fixed_udot_s_h(<4 x i32> %acc, <8 x i16> %a, <8 x i16> %b) {

>From 80d20cb0127ede7e1d0516d3c2152d22d1ad1a35 Mon Sep 17 00:00:00 2001
From: Damian Heaton <Damian.Heaton at arm.com>
Date: Mon, 15 Sep 2025 09:12:05 +0000
Subject: [PATCH 4/4] Remove `vscale` from vl256 test types

---
 .../AArch64/sve2p1-dots-partial-reduction.ll  | 56 ++++++++++++++-----
 1 file changed, 42 insertions(+), 14 deletions(-)

diff --git a/llvm/test/CodeGen/AArch64/sve2p1-dots-partial-reduction.ll b/llvm/test/CodeGen/AArch64/sve2p1-dots-partial-reduction.ll
index 770224743d1a7..a9943b7d99c71 100644
--- a/llvm/test/CodeGen/AArch64/sve2p1-dots-partial-reduction.ll
+++ b/llvm/test/CodeGen/AArch64/sve2p1-dots-partial-reduction.ll
@@ -27,32 +27,60 @@ entry:
   ret <vscale x 4 x i32> %partial.reduce
 }
 
-define <vscale x 8 x i32> @udot_vl256(<vscale x 8 x i32> %acc, <vscale x 16 x i16> %a, <vscale x 16 x i16> %b) vscale_range(2,2) {
+define <8 x i32> @udot_vl256(<8 x i32> %acc, <16 x i16> %a, <16 x i16> %b) vscale_range(2,2) {
 ; CHECK-LABEL: udot_vl256:
 ; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    // kill: def $q5 killed $q5 killed $z4_z5 def $z4_z5
+; CHECK-NEXT:    // kill: def $q3 killed $q3 killed $z2_z3 def $z2_z3
+; CHECK-NEXT:    // kill: def $q1 killed $q1 killed $z0_z1 def $z0_z1
+; CHECK-NEXT:    ptrue p0.h, vl8
+; CHECK-NEXT:    ptrue p1.s, vl4
+; CHECK-NEXT:    // kill: def $q4 killed $q4 killed $z4_z5 def $z4_z5
+; CHECK-NEXT:    // kill: def $q2 killed $q2 killed $z2_z3 def $z2_z3
+; CHECK-NEXT:    // kill: def $q0 killed $q0 killed $z0_z1 def $z0_z1
+; CHECK-NEXT:    splice z4.h, p0, { z4.h, z5.h }
+; CHECK-NEXT:    splice z2.h, p0, { z2.h, z3.h }
+; CHECK-NEXT:    splice z0.s, p1, { z0.s, z1.s }
 ; CHECK-NEXT:    udot z0.s, z2.h, z4.h
-; CHECK-NEXT:    udot z1.s, z3.h, z5.h
+; CHECK-NEXT:    movprfx z1, z0
+; CHECK-NEXT:    ext z1.b, z1.b, z0.b, #16
+; CHECK-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT:    // kill: def $q1 killed $q1 killed $z1
 ; CHECK-NEXT:    ret
 entry:
-  %a.wide = zext <vscale x 16 x i16> %a to <vscale x 16 x i32>
-  %b.wide = zext <vscale x 16 x i16> %b to <vscale x 16 x i32>
-  %mult = mul nuw nsw <vscale x 16 x i32> %a.wide, %b.wide
-  %partial.reduce = tail call <vscale x 8 x i32> @llvm.experimental.vector.partial.reduce.add(<vscale x 8 x i32> %acc, <vscale x 16 x i32> %mult)
-  ret <vscale x 8 x i32> %partial.reduce
+  %a.wide = zext <16 x i16> %a to <16 x i32>
+  %b.wide = zext <16 x i16> %b to <16 x i32>
+  %mult = mul nuw nsw <16 x i32> %a.wide, %b.wide
+  %partial.reduce = tail call <8 x i32> @llvm.experimental.vector.partial.reduce.add(<8 x i32> %acc, <16 x i32> %mult)
+  ret <8 x i32> %partial.reduce
 }
 
-define <vscale x 8 x i32> @sdot_vl256(<vscale x 8 x i32> %acc, <vscale x 16 x i16> %a, <vscale x 16 x i16> %b) vscale_range(2,2) {
+define <8 x i32> @sdot_vl256(<8 x i32> %acc, <16 x i16> %a, <16 x i16> %b) vscale_range(2,2) {
 ; CHECK-LABEL: sdot_vl256:
 ; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    // kill: def $q5 killed $q5 killed $z4_z5 def $z4_z5
+; CHECK-NEXT:    // kill: def $q3 killed $q3 killed $z2_z3 def $z2_z3
+; CHECK-NEXT:    // kill: def $q1 killed $q1 killed $z0_z1 def $z0_z1
+; CHECK-NEXT:    ptrue p0.h, vl8
+; CHECK-NEXT:    ptrue p1.s, vl4
+; CHECK-NEXT:    // kill: def $q4 killed $q4 killed $z4_z5 def $z4_z5
+; CHECK-NEXT:    // kill: def $q2 killed $q2 killed $z2_z3 def $z2_z3
+; CHECK-NEXT:    // kill: def $q0 killed $q0 killed $z0_z1 def $z0_z1
+; CHECK-NEXT:    splice z4.h, p0, { z4.h, z5.h }
+; CHECK-NEXT:    splice z2.h, p0, { z2.h, z3.h }
+; CHECK-NEXT:    splice z0.s, p1, { z0.s, z1.s }
 ; CHECK-NEXT:    sdot z0.s, z2.h, z4.h
-; CHECK-NEXT:    sdot z1.s, z3.h, z5.h
+; CHECK-NEXT:    movprfx z1, z0
+; CHECK-NEXT:    ext z1.b, z1.b, z0.b, #16
+; CHECK-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT:    // kill: def $q1 killed $q1 killed $z1
 ; CHECK-NEXT:    ret
 entry:
-  %a.wide = sext <vscale x 16 x i16> %a to <vscale x 16 x i32>
-  %b.wide = sext <vscale x 16 x i16> %b to <vscale x 16 x i32>
-  %mult = mul nuw nsw <vscale x 16 x i32> %a.wide, %b.wide
-  %partial.reduce = tail call <vscale x 8 x i32> @llvm.experimental.vector.partial.reduce.add(<vscale x 8 x i32> %acc, <vscale x 16 x i32> %mult)
-  ret <vscale x 8 x i32> %partial.reduce
+  %a.wide = sext <16 x i16> %a to <16 x i32>
+  %b.wide = sext <16 x i16> %b to <16 x i32>
+  %mult = mul nuw nsw <16 x i32> %a.wide, %b.wide
+  %partial.reduce = tail call <8 x i32> @llvm.experimental.vector.partial.reduce.add(<8 x i32> %acc, <16 x i32> %mult)
+  ret <8 x i32> %partial.reduce
 }
 
 define <4 x i32> @fixed_udot_s_h(<4 x i32> %acc, <8 x i16> %a, <8 x i16> %b) {


