[llvm] [AArch64] Tweak the cost-model of partial reductions to mitigate regressions from #181706 (PR #181707)
Sander de Smalen via llvm-commits
llvm-commits at lists.llvm.org
Thu Mar 26 07:44:31 PDT 2026
https://github.com/sdesmalen-arm updated https://github.com/llvm/llvm-project/pull/181707
>From b9181c5e8f833986ff2228f6d2d559fbc410d6b5 Mon Sep 17 00:00:00 2001
From: Sander de Smalen <sander.desmalen at arm.com>
Date: Tue, 3 Mar 2026 21:51:25 +0000
Subject: [PATCH 1/5] NFC Pre-commit of rerunning checks on
partial-reduce-chained.ll
---
.../AArch64/partial-reduce-chained.ll | 48 +++++++++----------
1 file changed, 24 insertions(+), 24 deletions(-)
diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-chained.ll b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-chained.ll
index d1fde2cdaafe1..15e0220b71d61 100644
--- a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-chained.ll
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-chained.ll
@@ -193,7 +193,7 @@ define i32 @chained_partial_reduce_add_add(ptr %a, ptr %b, ptr %c, i32 %N) #0 {
; CHECK-NEON-NEXT: [[PARTIAL_REDUCE3]] = call <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[PARTIAL_REDUCE]], <16 x i32> [[TMP11]])
; CHECK-NEON-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 16
; CHECK-NEON-NEXT: [[TMP12:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
-; CHECK-NEON-NEXT: br i1 [[TMP12]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP6:![0-9]+]]
+; CHECK-NEON-NEXT: br i1 [[TMP12]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP4:![0-9]+]]
; CHECK-NEON: middle.block:
; CHECK-NEON-NEXT: [[TMP13:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[PARTIAL_REDUCE3]])
; CHECK-NEON-NEXT: [[CMP_N:%.*]] = icmp eq i64 [[WIDE_TRIP_COUNT]], [[N_VEC]]
@@ -234,7 +234,7 @@ define i32 @chained_partial_reduce_add_add(ptr %a, ptr %b, ptr %c, i32 %N) #0 {
; CHECK-SVE-NEXT: [[PARTIAL_REDUCE3]] = call <vscale x 4 x i32> @llvm.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> [[PARTIAL_REDUCE]], <vscale x 16 x i32> [[TMP11]])
; CHECK-SVE-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP3]]
; CHECK-SVE-NEXT: [[TMP20:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
-; CHECK-SVE-NEXT: br i1 [[TMP20]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP6:![0-9]+]]
+; CHECK-SVE-NEXT: br i1 [[TMP20]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP4:![0-9]+]]
; CHECK-SVE: middle.block:
; CHECK-SVE-NEXT: [[TMP13:%.*]] = call i32 @llvm.vector.reduce.add.nxv4i32(<vscale x 4 x i32> [[PARTIAL_REDUCE3]])
; CHECK-SVE-NEXT: [[CMP_N:%.*]] = icmp eq i64 [[WIDE_TRIP_COUNT]], [[N_VEC]]
@@ -275,7 +275,7 @@ define i32 @chained_partial_reduce_add_add(ptr %a, ptr %b, ptr %c, i32 %N) #0 {
; CHECK-SVE-MAXBW-NEXT: [[PARTIAL_REDUCE3]] = call <vscale x 2 x i32> @llvm.vector.partial.reduce.add.nxv2i32.nxv8i32(<vscale x 2 x i32> [[PARTIAL_REDUCE]], <vscale x 8 x i32> [[TMP17]])
; CHECK-SVE-MAXBW-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP3]]
; CHECK-SVE-MAXBW-NEXT: [[TMP18:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
-; CHECK-SVE-MAXBW-NEXT: br i1 [[TMP18]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP6:![0-9]+]]
+; CHECK-SVE-MAXBW-NEXT: br i1 [[TMP18]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP4:![0-9]+]]
; CHECK-SVE-MAXBW: middle.block:
; CHECK-SVE-MAXBW-NEXT: [[TMP19:%.*]] = call i32 @llvm.vector.reduce.add.nxv2i32(<vscale x 2 x i32> [[PARTIAL_REDUCE3]])
; CHECK-SVE-MAXBW-NEXT: [[CMP_N:%.*]] = icmp eq i64 [[WIDE_TRIP_COUNT]], [[N_VEC]]
@@ -345,7 +345,7 @@ define i32 @chained_partial_reduce_sub_add(ptr %a, ptr %b, ptr %c, i32 %N) #0 {
; CHECK-NEON-NEXT: [[PARTIAL_REDUCE3]] = call <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[PARTIAL_REDUCE]], <16 x i32> [[TMP12]])
; CHECK-NEON-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 16
; CHECK-NEON-NEXT: [[TMP14:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
-; CHECK-NEON-NEXT: br i1 [[TMP14]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP8:![0-9]+]]
+; CHECK-NEON-NEXT: br i1 [[TMP14]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP6:![0-9]+]]
; CHECK-NEON: middle.block:
; CHECK-NEON-NEXT: [[TMP15:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[PARTIAL_REDUCE3]])
; CHECK-NEON-NEXT: [[CMP_N:%.*]] = icmp eq i64 [[WIDE_TRIP_COUNT]], [[N_VEC]]
@@ -387,7 +387,7 @@ define i32 @chained_partial_reduce_sub_add(ptr %a, ptr %b, ptr %c, i32 %N) #0 {
; CHECK-SVE-NEXT: [[PARTIAL_REDUCE3]] = call <vscale x 4 x i32> @llvm.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> [[PARTIAL_REDUCE]], <vscale x 16 x i32> [[TMP12]])
; CHECK-SVE-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP3]]
; CHECK-SVE-NEXT: [[TMP20:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
-; CHECK-SVE-NEXT: br i1 [[TMP20]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP8:![0-9]+]]
+; CHECK-SVE-NEXT: br i1 [[TMP20]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP6:![0-9]+]]
; CHECK-SVE: middle.block:
; CHECK-SVE-NEXT: [[TMP14:%.*]] = call i32 @llvm.vector.reduce.add.nxv4i32(<vscale x 4 x i32> [[PARTIAL_REDUCE3]])
; CHECK-SVE-NEXT: [[CMP_N:%.*]] = icmp eq i64 [[WIDE_TRIP_COUNT]], [[N_VEC]]
@@ -429,7 +429,7 @@ define i32 @chained_partial_reduce_sub_add(ptr %a, ptr %b, ptr %c, i32 %N) #0 {
; CHECK-SVE-MAXBW-NEXT: [[PARTIAL_REDUCE3]] = call <vscale x 2 x i32> @llvm.vector.partial.reduce.add.nxv2i32.nxv8i32(<vscale x 2 x i32> [[PARTIAL_REDUCE]], <vscale x 8 x i32> [[TMP18]])
; CHECK-SVE-MAXBW-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP3]]
; CHECK-SVE-MAXBW-NEXT: [[TMP20:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
-; CHECK-SVE-MAXBW-NEXT: br i1 [[TMP20]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP8:![0-9]+]]
+; CHECK-SVE-MAXBW-NEXT: br i1 [[TMP20]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP6:![0-9]+]]
; CHECK-SVE-MAXBW: middle.block:
; CHECK-SVE-MAXBW-NEXT: [[TMP21:%.*]] = call i32 @llvm.vector.reduce.add.nxv2i32(<vscale x 2 x i32> [[PARTIAL_REDUCE3]])
; CHECK-SVE-MAXBW-NEXT: [[CMP_N:%.*]] = icmp eq i64 [[WIDE_TRIP_COUNT]], [[N_VEC]]
@@ -500,7 +500,7 @@ define i32 @chained_partial_reduce_sub_sub(ptr %a, ptr %b, ptr %c, i32 %N) #0 {
; CHECK-NEON-NEXT: [[PARTIAL_REDUCE3]] = call <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[PARTIAL_REDUCE]], <16 x i32> [[TMP12]])
; CHECK-NEON-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 16
; CHECK-NEON-NEXT: [[TMP14:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
-; CHECK-NEON-NEXT: br i1 [[TMP14]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP10:![0-9]+]]
+; CHECK-NEON-NEXT: br i1 [[TMP14]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP8:![0-9]+]]
; CHECK-NEON: middle.block:
; CHECK-NEON-NEXT: [[TMP15:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[PARTIAL_REDUCE3]])
; CHECK-NEON-NEXT: [[TMP11:%.*]] = sub i32 0, [[TMP15]]
@@ -542,7 +542,7 @@ define i32 @chained_partial_reduce_sub_sub(ptr %a, ptr %b, ptr %c, i32 %N) #0 {
; CHECK-SVE-NEXT: [[PARTIAL_REDUCE3]] = call <vscale x 4 x i32> @llvm.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> [[PARTIAL_REDUCE]], <vscale x 16 x i32> [[TMP12]])
; CHECK-SVE-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP3]]
; CHECK-SVE-NEXT: [[TMP20:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
-; CHECK-SVE-NEXT: br i1 [[TMP20]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP10:![0-9]+]]
+; CHECK-SVE-NEXT: br i1 [[TMP20]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP8:![0-9]+]]
; CHECK-SVE: middle.block:
; CHECK-SVE-NEXT: [[TMP15:%.*]] = call i32 @llvm.vector.reduce.add.nxv4i32(<vscale x 4 x i32> [[PARTIAL_REDUCE3]])
; CHECK-SVE-NEXT: [[TMP18:%.*]] = sub i32 0, [[TMP15]]
@@ -584,7 +584,7 @@ define i32 @chained_partial_reduce_sub_sub(ptr %a, ptr %b, ptr %c, i32 %N) #0 {
; CHECK-SVE-MAXBW-NEXT: [[PARTIAL_REDUCE3]] = call <vscale x 2 x i32> @llvm.vector.partial.reduce.add.nxv2i32.nxv8i32(<vscale x 2 x i32> [[PARTIAL_REDUCE]], <vscale x 8 x i32> [[TMP18]])
; CHECK-SVE-MAXBW-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP3]]
; CHECK-SVE-MAXBW-NEXT: [[TMP20:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
-; CHECK-SVE-MAXBW-NEXT: br i1 [[TMP20]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP10:![0-9]+]]
+; CHECK-SVE-MAXBW-NEXT: br i1 [[TMP20]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP8:![0-9]+]]
; CHECK-SVE-MAXBW: middle.block:
; CHECK-SVE-MAXBW-NEXT: [[TMP21:%.*]] = call i32 @llvm.vector.reduce.add.nxv2i32(<vscale x 2 x i32> [[PARTIAL_REDUCE3]])
; CHECK-SVE-MAXBW-NEXT: [[TMP15:%.*]] = sub i32 0, [[TMP21]]
@@ -659,7 +659,7 @@ define i32 @chained_partial_reduce_add_add_add(ptr %a, ptr %b, ptr %c, i32 %N) #
; CHECK-NEON-NEXT: [[PARTIAL_REDUCE4]] = call <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[PARTIAL_REDUCE3]], <16 x i32> [[TMP12]])
; CHECK-NEON-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 16
; CHECK-NEON-NEXT: [[TMP13:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
-; CHECK-NEON-NEXT: br i1 [[TMP13]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP12:![0-9]+]]
+; CHECK-NEON-NEXT: br i1 [[TMP13]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP10:![0-9]+]]
; CHECK-NEON: middle.block:
; CHECK-NEON-NEXT: [[TMP14:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[PARTIAL_REDUCE4]])
; CHECK-NEON-NEXT: [[CMP_N:%.*]] = icmp eq i64 [[WIDE_TRIP_COUNT]], [[N_VEC]]
@@ -702,7 +702,7 @@ define i32 @chained_partial_reduce_add_add_add(ptr %a, ptr %b, ptr %c, i32 %N) #
; CHECK-SVE-NEXT: [[PARTIAL_REDUCE4]] = call <vscale x 4 x i32> @llvm.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> [[PARTIAL_REDUCE3]], <vscale x 16 x i32> [[TMP12]])
; CHECK-SVE-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP3]]
; CHECK-SVE-NEXT: [[TMP22:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
-; CHECK-SVE-NEXT: br i1 [[TMP22]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP12:![0-9]+]]
+; CHECK-SVE-NEXT: br i1 [[TMP22]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP10:![0-9]+]]
; CHECK-SVE: middle.block:
; CHECK-SVE-NEXT: [[TMP14:%.*]] = call i32 @llvm.vector.reduce.add.nxv4i32(<vscale x 4 x i32> [[PARTIAL_REDUCE4]])
; CHECK-SVE-NEXT: [[CMP_N:%.*]] = icmp eq i64 [[WIDE_TRIP_COUNT]], [[N_VEC]]
@@ -745,7 +745,7 @@ define i32 @chained_partial_reduce_add_add_add(ptr %a, ptr %b, ptr %c, i32 %N) #
; CHECK-SVE-MAXBW-NEXT: [[PARTIAL_REDUCE4]] = call <vscale x 2 x i32> @llvm.vector.partial.reduce.add.nxv2i32.nxv8i32(<vscale x 2 x i32> [[PARTIAL_REDUCE3]], <vscale x 8 x i32> [[TMP18]])
; CHECK-SVE-MAXBW-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP3]]
; CHECK-SVE-MAXBW-NEXT: [[TMP19:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
-; CHECK-SVE-MAXBW-NEXT: br i1 [[TMP19]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP12:![0-9]+]]
+; CHECK-SVE-MAXBW-NEXT: br i1 [[TMP19]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP10:![0-9]+]]
; CHECK-SVE-MAXBW: middle.block:
; CHECK-SVE-MAXBW-NEXT: [[TMP20:%.*]] = call i32 @llvm.vector.reduce.add.nxv2i32(<vscale x 2 x i32> [[PARTIAL_REDUCE4]])
; CHECK-SVE-MAXBW-NEXT: [[CMP_N:%.*]] = icmp eq i64 [[WIDE_TRIP_COUNT]], [[N_VEC]]
@@ -823,7 +823,7 @@ define i32 @chained_partial_reduce_sub_add_sub(ptr %a, ptr %b, ptr %c, i32 %N) #
; CHECK-NEON-NEXT: [[PARTIAL_REDUCE4]] = call <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[PARTIAL_REDUCE3]], <16 x i32> [[TMP15]])
; CHECK-NEON-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 16
; CHECK-NEON-NEXT: [[TMP16:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
-; CHECK-NEON-NEXT: br i1 [[TMP16]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP14:![0-9]+]]
+; CHECK-NEON-NEXT: br i1 [[TMP16]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP12:![0-9]+]]
; CHECK-NEON: middle.block:
; CHECK-NEON-NEXT: [[TMP17:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[PARTIAL_REDUCE4]])
; CHECK-NEON-NEXT: [[CMP_N:%.*]] = icmp eq i64 [[WIDE_TRIP_COUNT]], [[N_VEC]]
@@ -868,7 +868,7 @@ define i32 @chained_partial_reduce_sub_add_sub(ptr %a, ptr %b, ptr %c, i32 %N) #
; CHECK-SVE-NEXT: [[PARTIAL_REDUCE4]] = call <vscale x 4 x i32> @llvm.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> [[PARTIAL_REDUCE3]], <vscale x 16 x i32> [[TMP14]])
; CHECK-SVE-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP3]]
; CHECK-SVE-NEXT: [[TMP22:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
-; CHECK-SVE-NEXT: br i1 [[TMP22]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP14:![0-9]+]]
+; CHECK-SVE-NEXT: br i1 [[TMP22]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP12:![0-9]+]]
; CHECK-SVE: middle.block:
; CHECK-SVE-NEXT: [[TMP16:%.*]] = call i32 @llvm.vector.reduce.add.nxv4i32(<vscale x 4 x i32> [[PARTIAL_REDUCE4]])
; CHECK-SVE-NEXT: [[CMP_N:%.*]] = icmp eq i64 [[WIDE_TRIP_COUNT]], [[N_VEC]]
@@ -913,7 +913,7 @@ define i32 @chained_partial_reduce_sub_add_sub(ptr %a, ptr %b, ptr %c, i32 %N) #
; CHECK-SVE-MAXBW-NEXT: [[PARTIAL_REDUCE4]] = call <vscale x 2 x i32> @llvm.vector.partial.reduce.add.nxv2i32.nxv8i32(<vscale x 2 x i32> [[PARTIAL_REDUCE3]], <vscale x 8 x i32> [[TMP20]])
; CHECK-SVE-MAXBW-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP3]]
; CHECK-SVE-MAXBW-NEXT: [[TMP22:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
-; CHECK-SVE-MAXBW-NEXT: br i1 [[TMP22]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP14:![0-9]+]]
+; CHECK-SVE-MAXBW-NEXT: br i1 [[TMP22]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP12:![0-9]+]]
; CHECK-SVE-MAXBW: middle.block:
; CHECK-SVE-MAXBW-NEXT: [[TMP23:%.*]] = call i32 @llvm.vector.reduce.add.nxv2i32(<vscale x 2 x i32> [[PARTIAL_REDUCE4]])
; CHECK-SVE-MAXBW-NEXT: [[CMP_N:%.*]] = icmp eq i64 [[WIDE_TRIP_COUNT]], [[N_VEC]]
@@ -987,7 +987,7 @@ define i32 @chained_partial_reduce_madd_extadd(ptr %a, ptr %b, ptr %c, i32 %N) #
; CHECK-NEON-NEXT: [[PARTIAL_REDUCE3]] = call <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[PARTIAL_REDUCE]], <16 x i32> [[TMP9]])
; CHECK-NEON-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 16
; CHECK-NEON-NEXT: [[TMP12:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
-; CHECK-NEON-NEXT: br i1 [[TMP12]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP16:![0-9]+]]
+; CHECK-NEON-NEXT: br i1 [[TMP12]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP14:![0-9]+]]
; CHECK-NEON: middle.block:
; CHECK-NEON-NEXT: [[TMP13:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[PARTIAL_REDUCE3]])
; CHECK-NEON-NEXT: [[CMP_N:%.*]] = icmp eq i64 [[WIDE_TRIP_COUNT]], [[N_VEC]]
@@ -1027,7 +1027,7 @@ define i32 @chained_partial_reduce_madd_extadd(ptr %a, ptr %b, ptr %c, i32 %N) #
; CHECK-SVE-NEXT: [[PARTIAL_REDUCE3]] = call <vscale x 4 x i32> @llvm.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> [[PARTIAL_REDUCE]], <vscale x 16 x i32> [[TMP10]])
; CHECK-SVE-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP3]]
; CHECK-SVE-NEXT: [[TMP19:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
-; CHECK-SVE-NEXT: br i1 [[TMP19]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP16:![0-9]+]]
+; CHECK-SVE-NEXT: br i1 [[TMP19]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP14:![0-9]+]]
; CHECK-SVE: middle.block:
; CHECK-SVE-NEXT: [[TMP12:%.*]] = call i32 @llvm.vector.reduce.add.nxv4i32(<vscale x 4 x i32> [[PARTIAL_REDUCE3]])
; CHECK-SVE-NEXT: [[CMP_N:%.*]] = icmp eq i64 [[WIDE_TRIP_COUNT]], [[N_VEC]]
@@ -1067,7 +1067,7 @@ define i32 @chained_partial_reduce_madd_extadd(ptr %a, ptr %b, ptr %c, i32 %N) #
; CHECK-SVE-MAXBW-NEXT: [[PARTIAL_REDUCE3]] = call <vscale x 2 x i32> @llvm.vector.partial.reduce.add.nxv2i32.nxv8i32(<vscale x 2 x i32> [[PARTIAL_REDUCE]], <vscale x 8 x i32> [[TMP15]])
; CHECK-SVE-MAXBW-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP3]]
; CHECK-SVE-MAXBW-NEXT: [[TMP18:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
-; CHECK-SVE-MAXBW-NEXT: br i1 [[TMP18]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP16:![0-9]+]]
+; CHECK-SVE-MAXBW-NEXT: br i1 [[TMP18]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP14:![0-9]+]]
; CHECK-SVE-MAXBW: middle.block:
; CHECK-SVE-MAXBW-NEXT: [[TMP19:%.*]] = call i32 @llvm.vector.reduce.add.nxv2i32(<vscale x 2 x i32> [[PARTIAL_REDUCE3]])
; CHECK-SVE-MAXBW-NEXT: [[CMP_N:%.*]] = icmp eq i64 [[WIDE_TRIP_COUNT]], [[N_VEC]]
@@ -1131,7 +1131,7 @@ define i32 @chained_partial_reduce_extadd_extadd(ptr %a, ptr %b, i32 %N) #0 {
; CHECK-NEON-NEXT: [[PARTIAL_REDUCE2]] = call <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[PARTIAL_REDUCE]], <16 x i32> [[TMP6]])
; CHECK-NEON-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 16
; CHECK-NEON-NEXT: [[TMP9:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
-; CHECK-NEON-NEXT: br i1 [[TMP9]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP18:![0-9]+]]
+; CHECK-NEON-NEXT: br i1 [[TMP9]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP16:![0-9]+]]
; CHECK-NEON: middle.block:
; CHECK-NEON-NEXT: [[TMP8:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[PARTIAL_REDUCE2]])
; CHECK-NEON-NEXT: [[CMP_N:%.*]] = icmp eq i64 [[WIDE_TRIP_COUNT]], [[N_VEC]]
@@ -1167,7 +1167,7 @@ define i32 @chained_partial_reduce_extadd_extadd(ptr %a, ptr %b, i32 %N) #0 {
; CHECK-SVE-NEXT: [[PARTIAL_REDUCE2]] = call <vscale x 4 x i32> @llvm.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> [[PARTIAL_REDUCE]], <vscale x 16 x i32> [[TMP10]])
; CHECK-SVE-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP3]]
; CHECK-SVE-NEXT: [[TMP15:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
-; CHECK-SVE-NEXT: br i1 [[TMP15]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP18:![0-9]+]]
+; CHECK-SVE-NEXT: br i1 [[TMP15]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP16:![0-9]+]]
; CHECK-SVE: middle.block:
; CHECK-SVE-NEXT: [[TMP9:%.*]] = call i32 @llvm.vector.reduce.add.nxv4i32(<vscale x 4 x i32> [[PARTIAL_REDUCE2]])
; CHECK-SVE-NEXT: [[CMP_N:%.*]] = icmp eq i64 [[WIDE_TRIP_COUNT]], [[N_VEC]]
@@ -1203,7 +1203,7 @@ define i32 @chained_partial_reduce_extadd_extadd(ptr %a, ptr %b, i32 %N) #0 {
; CHECK-SVE-MAXBW-NEXT: [[PARTIAL_REDUCE2]] = call <vscale x 2 x i32> @llvm.vector.partial.reduce.add.nxv2i32.nxv8i32(<vscale x 2 x i32> [[PARTIAL_REDUCE]], <vscale x 8 x i32> [[TMP12]])
; CHECK-SVE-MAXBW-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP3]]
; CHECK-SVE-MAXBW-NEXT: [[TMP15:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
-; CHECK-SVE-MAXBW-NEXT: br i1 [[TMP15]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP18:![0-9]+]]
+; CHECK-SVE-MAXBW-NEXT: br i1 [[TMP15]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP16:![0-9]+]]
; CHECK-SVE-MAXBW: middle.block:
; CHECK-SVE-MAXBW-NEXT: [[TMP14:%.*]] = call i32 @llvm.vector.reduce.add.nxv2i32(<vscale x 2 x i32> [[PARTIAL_REDUCE2]])
; CHECK-SVE-MAXBW-NEXT: [[CMP_N:%.*]] = icmp eq i64 [[WIDE_TRIP_COUNT]], [[N_VEC]]
@@ -1268,7 +1268,7 @@ define i32 @chained_partial_reduce_extadd_madd(ptr %a, ptr %b, ptr %c, i32 %N) #
; CHECK-NEON-NEXT: [[PARTIAL_REDUCE3]] = call <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[PARTIAL_REDUCE]], <16 x i32> [[TMP10]])
; CHECK-NEON-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 16
; CHECK-NEON-NEXT: [[TMP13:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
-; CHECK-NEON-NEXT: br i1 [[TMP13]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP20:![0-9]+]]
+; CHECK-NEON-NEXT: br i1 [[TMP13]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP18:![0-9]+]]
; CHECK-NEON: middle.block:
; CHECK-NEON-NEXT: [[TMP12:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[PARTIAL_REDUCE3]])
; CHECK-NEON-NEXT: [[CMP_N:%.*]] = icmp eq i64 [[WIDE_TRIP_COUNT]], [[N_VEC]]
@@ -1308,7 +1308,7 @@ define i32 @chained_partial_reduce_extadd_madd(ptr %a, ptr %b, ptr %c, i32 %N) #
; CHECK-SVE-NEXT: [[PARTIAL_REDUCE3]] = call <vscale x 4 x i32> @llvm.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> [[PARTIAL_REDUCE]], <vscale x 16 x i32> [[TMP10]])
; CHECK-SVE-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP3]]
; CHECK-SVE-NEXT: [[TMP19:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
-; CHECK-SVE-NEXT: br i1 [[TMP19]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP20:![0-9]+]]
+; CHECK-SVE-NEXT: br i1 [[TMP19]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP18:![0-9]+]]
; CHECK-SVE: middle.block:
; CHECK-SVE-NEXT: [[TMP12:%.*]] = call i32 @llvm.vector.reduce.add.nxv4i32(<vscale x 4 x i32> [[PARTIAL_REDUCE3]])
; CHECK-SVE-NEXT: [[CMP_N:%.*]] = icmp eq i64 [[WIDE_TRIP_COUNT]], [[N_VEC]]
@@ -1348,7 +1348,7 @@ define i32 @chained_partial_reduce_extadd_madd(ptr %a, ptr %b, ptr %c, i32 %N) #
; CHECK-SVE-MAXBW-NEXT: [[PARTIAL_REDUCE3]] = call <vscale x 2 x i32> @llvm.vector.partial.reduce.add.nxv2i32.nxv8i32(<vscale x 2 x i32> [[PARTIAL_REDUCE]], <vscale x 8 x i32> [[TMP16]])
; CHECK-SVE-MAXBW-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP3]]
; CHECK-SVE-MAXBW-NEXT: [[TMP19:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
-; CHECK-SVE-MAXBW-NEXT: br i1 [[TMP19]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP20:![0-9]+]]
+; CHECK-SVE-MAXBW-NEXT: br i1 [[TMP19]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP18:![0-9]+]]
; CHECK-SVE-MAXBW: middle.block:
; CHECK-SVE-MAXBW-NEXT: [[TMP18:%.*]] = call i32 @llvm.vector.reduce.add.nxv2i32(<vscale x 2 x i32> [[PARTIAL_REDUCE3]])
; CHECK-SVE-MAXBW-NEXT: [[CMP_N:%.*]] = icmp eq i64 [[WIDE_TRIP_COUNT]], [[N_VEC]]
>From 3d044c3fb6847945d9b617aa705c5bb241ecf265 Mon Sep 17 00:00:00 2001
From: Sander de Smalen <sander.desmalen at arm.com>
Date: Tue, 3 Mar 2026 11:40:11 +0000
Subject: [PATCH 2/5] Various changes to the cost-model.
This has a number of changes to the partial reduction cost-model:
* Implement the fact that *MLALB/T instructions can be used for
16-bit -> 32-bit partial reductions (or *MLAL/MLAL2 for NEON).
* Fixes the cost of reductions that don't have specific lowering:
rather than returning a random number, we now return the cost of
expanding the partial reduction in ISel.
For sub-reductions we scale the cost to make them slightly cheaper,
so that they're still candidates for forming cdot operations.
* Reduce the cost of FP reductions, which are currently prohibitively
expensive.
---
.../AArch64/AArch64TargetTransformInfo.cpp | 65 +++++++++++--------
.../partial-reduce-add-sdot-i16-i32.ll | 28 +++++++-
.../AArch64/partial-reduce-fdot-product.ll | 44 ++++++-------
.../AArch64/partial-reduce-sub-sdot.ll | 4 +-
4 files changed, 89 insertions(+), 52 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 5f73e10c9c626..140c168c6e0ee 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -5958,22 +5958,26 @@ InstructionCost AArch64TTIImpl::getPartialReductionCost(
std::pair<InstructionCost, MVT> InputLT =
getTypeLegalizationCost(InputVectorType);
- InstructionCost Cost = InputLT.first * TTI::TCC_Basic;
+ // Returns cost of expanding the partial reduction in ISel.
+ auto GetExpandCost = [&]() -> InstructionCost {
+ unsigned ExtOpc = AccumVectorType->getElementType()->isFloatingPointTy()
+ ? Instruction::FPExt
+ : Instruction::ZExt;
+
+ Type *ExtVectorType =
+ VectorType::get(AccumVectorType->getElementType(), VF);
+ return (BinOp ? 2 : 1) *
+ getCastInstrCost(ExtOpc, ExtVectorType, InputVectorType,
+ TTI::CastContextHint::None, CostKind) +
+ (BinOp ? getArithmeticInstrCost(*BinOp, ExtVectorType, CostKind)
+ : InstructionCost()) +
+ Ratio * getArithmeticInstrCost(Opcode, AccumVectorType, CostKind);
+ };
- // The sub/negation cannot be folded into the operands of
- // ISD::PARTIAL_REDUCE_*MLA, so make the cost more expensive.
- if (Opcode == Instruction::Sub)
- Cost += 8;
-
- // Prefer using full types by costing half-full input types as more expensive.
- if (TypeSize::isKnownLT(InputVectorType->getPrimitiveSizeInBits(),
- TypeSize::getScalable(128)))
- // FIXME: This can be removed after the cost of the extends are folded into
- // the dot-product expression in VPlan, after landing:
- // https://github.com/llvm/llvm-project/pull/147302
- Cost *= 2;
+ bool IsSub = Opcode == Instruction::Sub;
+ InstructionCost Cost = InputLT.first * TTI::TCC_Basic;
- if (ST->isSVEorStreamingSVEAvailable() && !IsUSDot) {
+ if (ST->isSVEorStreamingSVEAvailable() && !IsUSDot && !IsSub) {
// i16 -> i64 is natively supported for udot/sdot
if (AccumLT.second.getScalarType() == MVT::i64 &&
InputLT.second.getScalarType() == MVT::i16)
@@ -5994,29 +5998,38 @@ InstructionCost AArch64TTIImpl::getPartialReductionCost(
return Cost;
}
+ // For a ratio of 2, we can use 2 [u|s|f|bf]mlalb/t instructions.
+ if (Ratio == 2 && !IsSub &&
+ llvm::is_contained({MVT::i16, MVT::i32, MVT::f16, MVT::bf16},
+ InputLT.second.getScalarType().SimpleTy))
+ return Cost * 2;
+
// i8 -> i32 is natively supported for udot/sdot/usdot, both for NEON and SVE.
if (ST->isSVEorStreamingSVEAvailable() ||
(AccumLT.second.isFixedLengthVector() && ST->isNeonAvailable() &&
ST->hasDotProd())) {
if (AccumLT.second.getScalarType() == MVT::i32 &&
- InputLT.second.getScalarType() == MVT::i8)
+ InputLT.second.getScalarType() == MVT::i8 && !IsSub)
return Cost;
}
// f16 -> f32 is natively supported for fdot
- if (Opcode == Instruction::FAdd && (ST->hasSME2() || ST->hasSVE2p1())) {
- if (AccumLT.second.getScalarType() == MVT::f32 &&
- InputLT.second.getScalarType() == MVT::f16 &&
- AccumLT.second.getVectorMinNumElements() == 4 &&
- InputLT.second.getVectorMinNumElements() == 8)
- return Cost;
- // Floating-point types aren't promoted, so expanding the partial reduction
- // is more expensive.
- return Cost + 20;
+ if (Opcode == Instruction::FAdd && (ST->hasSME2() || ST->hasSVE2p1()) &&
+ AccumLT.second.getScalarType() == MVT::f32 &&
+ InputLT.second.getScalarType() == MVT::f16 &&
+ AccumLT.second.getVectorMinNumElements() == 4 &&
+ InputLT.second.getVectorMinNumElements() == 8)
+ return Cost;
+
+ if (IsSub) {
+ // Slightly lower the cost of a sub reduction so that it can be considered
+ // as candidate for 'cdot' operations. This is a somewhat arbitrary number,
+ // because we don't yet model these operations directly.
+ return (8 * GetExpandCost()) / 10;
}
- // Add additional cost for the extends that would need to be inserted.
- return Cost + 2;
+ // By default, assume the operation is expanded.
+ return GetExpandCost();
}
InstructionCost
diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-add-sdot-i16-i32.ll b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-add-sdot-i16-i32.ll
index 02afd113d3efa..6e320959e0b4c 100644
--- a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-add-sdot-i16-i32.ll
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-add-sdot-i16-i32.ll
@@ -17,14 +17,19 @@
; RUN: -disable-output < %s 2>&1 | FileCheck %s --check-prefix=CHECK-SCALABLE
; LV: Checking a loop in 'sext_reduction_i16_to_i32'
-; CHECK-FIXED-BASE: Cost of 3 for VF 8: EXPRESSION vp<%8> = ir<%acc> + partial.reduce.add (ir<%load> sext to i32)
+; CHECK-FIXED-BASE: Cost of 2 for VF 8: EXPRESSION vp<%8> = ir<%acc> + partial.reduce.add (ir<%load> sext to i32)
; CHECK-FIXED: Cost of 1 for VF 8: EXPRESSION vp<%8> = ir<%acc> + partial.reduce.add (ir<%load> sext to i32)
; CHECK-SCALABLE: Cost of 1 for VF vscale x 8: EXPRESSION vp<%8> = ir<%acc> + partial.reduce.add (ir<%load> sext to i32)
; LV: Checking a loop in 'zext_reduction_i16_to_i32'
-; CHECK-FIXED-BASE: Cost of 3 for VF 8: EXPRESSION vp<%8> = ir<%acc> + partial.reduce.add (ir<%load> zext to i32)
+; CHECK-FIXED-BASE: Cost of 2 for VF 8: EXPRESSION vp<%8> = ir<%acc> + partial.reduce.add (ir<%load> zext to i32)
; CHECK-FIXED: Cost of 1 for VF 8: EXPRESSION vp<%8> = ir<%acc> + partial.reduce.add (ir<%load> zext to i32)
; CHECK-SCALABLE: Cost of 1 for VF vscale x 8: EXPRESSION vp<%8> = ir<%acc> + partial.reduce.add (ir<%load> zext to i32)
+
+; LV: Checking a loop in 'fpext_reduction_half_to_float'
+; CHECK-FIXED-BASE: Cost of 2 for VF 8: EXPRESSION vp<%8> = ir<%acc> + partial.reduce.fadd (ir<%load> reassoc contract fpext to float)
+; CHECK-FIXED: Cost of 2 for VF 8: EXPRESSION vp<%8> = ir<%acc> + partial.reduce.fadd (ir<%load> reassoc contract fpext to float)
+; CHECK-SCALABLE: Cost of 2 for VF vscale x 8: EXPRESSION vp<%8> = ir<%acc> + partial.reduce.fadd (ir<%load> reassoc contract fpext to float)
target triple = "aarch64"
define i32 @sext_reduction_i16_to_i32(ptr %arr, i32 %n) vscale_range(1,16) {
@@ -64,3 +69,22 @@ loop:
exit:
ret i32 %add
}
+
+define float @fpext_reduction_half_to_float(ptr %arr, i32 %n) vscale_range(1,16) {
+entry:
+ br label %loop
+
+loop:
+ %iv = phi i32 [ 0, %entry ], [ %iv.next, %loop ]
+ %acc = phi float [ 0.0, %entry ], [ %add, %loop ]
+ %gep = getelementptr inbounds half, ptr %arr, i32 %iv
+ %load = load half, ptr %gep
+ %zext = fpext half %load to float
+ %add = fadd reassoc contract float %acc, %zext
+ %iv.next = add i32 %iv, 1
+ %cmp = icmp ult i32 %iv.next, %n
+ br i1 %cmp, label %loop, label %exit
+
+exit:
+ ret float %add
+}
diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-fdot-product.ll b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-fdot-product.ll
index c94b8996718e1..8573e1b0937f2 100644
--- a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-fdot-product.ll
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-fdot-product.ll
@@ -143,42 +143,42 @@ define double @fdot_f32_f64(ptr %a, ptr %b) #0 {
; CHECK-SAME: ptr [[A:%.*]], ptr [[B:%.*]]) #[[ATTR0]] {
; CHECK-NEXT: [[ENTRY:.*:]]
; CHECK-NEXT: [[TMP0:%.*]] = call i64 @llvm.vscale.i64()
-; CHECK-NEXT: [[TMP1:%.*]] = shl nuw i64 [[TMP0]], 2
+; CHECK-NEXT: [[TMP1:%.*]] = shl nuw i64 [[TMP0]], 3
; CHECK-NEXT: [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 1024, [[TMP1]]
; CHECK-NEXT: br i1 [[MIN_ITERS_CHECK]], label %[[SCALAR_PH:.*]], label %[[VECTOR_PH:.*]]
; CHECK: [[VECTOR_PH]]:
; CHECK-NEXT: [[TMP2:%.*]] = call i64 @llvm.vscale.i64()
-; CHECK-NEXT: [[TMP6:%.*]] = shl nuw i64 [[TMP2]], 1
+; CHECK-NEXT: [[TMP6:%.*]] = shl nuw i64 [[TMP2]], 2
; CHECK-NEXT: [[TMP3:%.*]] = shl nuw i64 [[TMP6]], 1
; CHECK-NEXT: [[N_MOD_VF:%.*]] = urem i64 1024, [[TMP3]]
; CHECK-NEXT: [[N_VEC:%.*]] = sub i64 1024, [[N_MOD_VF]]
; CHECK-NEXT: br label %[[VECTOR_BODY:.*]]
; CHECK: [[VECTOR_BODY]]:
; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, %[[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ]
-; CHECK-NEXT: [[VEC_PHI:%.*]] = phi <vscale x 2 x double> [ insertelement (<vscale x 2 x double> splat (double -0.000000e+00), double 0.000000e+00, i32 0), %[[VECTOR_PH]] ], [ [[TMP18:%.*]], %[[VECTOR_BODY]] ]
-; CHECK-NEXT: [[VEC_PHI1:%.*]] = phi <vscale x 2 x double> [ splat (double -0.000000e+00), %[[VECTOR_PH]] ], [ [[TMP19:%.*]], %[[VECTOR_BODY]] ]
-; CHECK-NEXT: [[TMP4:%.*]] = getelementptr float, ptr [[A]], i64 [[INDEX]]
-; CHECK-NEXT: [[TMP7:%.*]] = getelementptr float, ptr [[TMP4]], i64 [[TMP6]]
-; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <vscale x 2 x float>, ptr [[TMP4]], align 1
-; CHECK-NEXT: [[WIDE_LOAD2:%.*]] = load <vscale x 2 x float>, ptr [[TMP7]], align 1
-; CHECK-NEXT: [[TMP8:%.*]] = fpext <vscale x 2 x float> [[WIDE_LOAD]] to <vscale x 2 x double>
-; CHECK-NEXT: [[TMP9:%.*]] = fpext <vscale x 2 x float> [[WIDE_LOAD2]] to <vscale x 2 x double>
-; CHECK-NEXT: [[TMP10:%.*]] = getelementptr float, ptr [[B]], i64 [[INDEX]]
-; CHECK-NEXT: [[TMP13:%.*]] = getelementptr float, ptr [[TMP10]], i64 [[TMP6]]
-; CHECK-NEXT: [[WIDE_LOAD3:%.*]] = load <vscale x 2 x float>, ptr [[TMP10]], align 1
-; CHECK-NEXT: [[WIDE_LOAD4:%.*]] = load <vscale x 2 x float>, ptr [[TMP13]], align 1
-; CHECK-NEXT: [[TMP14:%.*]] = fpext <vscale x 2 x float> [[WIDE_LOAD3]] to <vscale x 2 x double>
-; CHECK-NEXT: [[TMP15:%.*]] = fpext <vscale x 2 x float> [[WIDE_LOAD4]] to <vscale x 2 x double>
-; CHECK-NEXT: [[TMP16:%.*]] = fmul <vscale x 2 x double> [[TMP14]], [[TMP8]]
-; CHECK-NEXT: [[TMP17:%.*]] = fmul <vscale x 2 x double> [[TMP15]], [[TMP9]]
-; CHECK-NEXT: [[TMP18]] = fadd reassoc contract <vscale x 2 x double> [[TMP16]], [[VEC_PHI]]
-; CHECK-NEXT: [[TMP19]] = fadd reassoc contract <vscale x 2 x double> [[TMP17]], [[VEC_PHI1]]
+; CHECK-NEXT: [[VEC_PHI:%.*]] = phi <vscale x 2 x double> [ insertelement (<vscale x 2 x double> splat (double -0.000000e+00), double 0.000000e+00, i32 0), %[[VECTOR_PH]] ], [ [[PARTIAL_REDUCE:%.*]], %[[VECTOR_BODY]] ]
+; CHECK-NEXT: [[VEC_PHI1:%.*]] = phi <vscale x 2 x double> [ insertelement (<vscale x 2 x double> splat (double -0.000000e+00), double -0.000000e+00, i32 0), %[[VECTOR_PH]] ], [ [[PARTIAL_REDUCE5:%.*]], %[[VECTOR_BODY]] ]
+; CHECK-NEXT: [[TMP5:%.*]] = getelementptr float, ptr [[A]], i64 [[INDEX]]
+; CHECK-NEXT: [[TMP15:%.*]] = getelementptr float, ptr [[TMP5]], i64 [[TMP6]]
+; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <vscale x 4 x float>, ptr [[TMP5]], align 1
+; CHECK-NEXT: [[WIDE_LOAD2:%.*]] = load <vscale x 4 x float>, ptr [[TMP15]], align 1
+; CHECK-NEXT: [[TMP7:%.*]] = getelementptr float, ptr [[B]], i64 [[INDEX]]
+; CHECK-NEXT: [[TMP8:%.*]] = getelementptr float, ptr [[TMP7]], i64 [[TMP6]]
+; CHECK-NEXT: [[WIDE_LOAD3:%.*]] = load <vscale x 4 x float>, ptr [[TMP7]], align 1
+; CHECK-NEXT: [[WIDE_LOAD4:%.*]] = load <vscale x 4 x float>, ptr [[TMP8]], align 1
+; CHECK-NEXT: [[TMP9:%.*]] = fpext <vscale x 4 x float> [[WIDE_LOAD3]] to <vscale x 4 x double>
+; CHECK-NEXT: [[TMP10:%.*]] = fpext <vscale x 4 x float> [[WIDE_LOAD]] to <vscale x 4 x double>
+; CHECK-NEXT: [[TMP11:%.*]] = fmul <vscale x 4 x double> [[TMP9]], [[TMP10]]
+; CHECK-NEXT: [[PARTIAL_REDUCE]] = call reassoc contract <vscale x 2 x double> @llvm.vector.partial.reduce.fadd.nxv2f64.nxv4f64(<vscale x 2 x double> [[VEC_PHI]], <vscale x 4 x double> [[TMP11]])
+; CHECK-NEXT: [[TMP12:%.*]] = fpext <vscale x 4 x float> [[WIDE_LOAD4]] to <vscale x 4 x double>
+; CHECK-NEXT: [[TMP13:%.*]] = fpext <vscale x 4 x float> [[WIDE_LOAD2]] to <vscale x 4 x double>
+; CHECK-NEXT: [[TMP14:%.*]] = fmul <vscale x 4 x double> [[TMP12]], [[TMP13]]
+; CHECK-NEXT: [[PARTIAL_REDUCE5]] = call reassoc contract <vscale x 2 x double> @llvm.vector.partial.reduce.fadd.nxv2f64.nxv4f64(<vscale x 2 x double> [[VEC_PHI1]], <vscale x 4 x double> [[TMP14]])
; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP3]]
; CHECK-NEXT: [[TMP20:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
; CHECK-NEXT: br i1 [[TMP20]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP6:![0-9]+]]
; CHECK: [[MIDDLE_BLOCK]]:
-; CHECK-NEXT: [[BIN_RDX:%.*]] = fadd reassoc contract <vscale x 2 x double> [[TMP19]], [[TMP18]]
-; CHECK-NEXT: [[TMP21:%.*]] = call reassoc contract double @llvm.vector.reduce.fadd.nxv2f64(double -0.000000e+00, <vscale x 2 x double> [[BIN_RDX]])
+; CHECK-NEXT: [[BIN_RDX:%.*]] = fadd reassoc contract <vscale x 2 x double> [[PARTIAL_REDUCE5]], [[PARTIAL_REDUCE]]
+; CHECK-NEXT: [[TMP16:%.*]] = call reassoc contract double @llvm.vector.reduce.fadd.nxv2f64(double -0.000000e+00, <vscale x 2 x double> [[BIN_RDX]])
; CHECK-NEXT: [[CMP_N:%.*]] = icmp eq i64 1024, [[N_VEC]]
; CHECK-NEXT: br i1 [[CMP_N]], [[FOR_EXIT:label %.*]], label %[[SCALAR_PH]]
; CHECK: [[SCALAR_PH]]:
diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-sub-sdot.ll b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-sub-sdot.ll
index f06b2137c2b8d..ff3881f2c7fda 100644
--- a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-sub-sdot.ll
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-sub-sdot.ll
@@ -15,9 +15,9 @@
; COMMON: LV: Checking a loop in 'add_sub_chained_reduction'
; SVE: Cost of 1 for VF vscale x 16: EXPRESSION vp<{{.*}}> = ir<%acc> + partial.reduce.add (mul (ir<%load1> sext to i32), (ir<%load2> sext to i32))
-; SVE: Cost of 9 for VF vscale x 16: EXPRESSION vp<{{.*}}> = vp<%9> + partial.reduce.add (sub (0, mul (ir<%load2> sext to i32), (ir<%load3> sext to i32)))
+; SVE: Cost of 16 for VF vscale x 16: EXPRESSION vp<{{.*}}> = vp<%9> + partial.reduce.add (sub (0, mul (ir<%load2> sext to i32), (ir<%load3> sext to i32)))
; NEON: Cost of 1 for VF 16: EXPRESSION vp<{{.*}}> = ir<%acc> + partial.reduce.add (mul (ir<%load1> sext to i32), (ir<%load2> sext to i32))
-; NEON: Cost of 9 for VF 16: EXPRESSION vp<{{.*}}> = vp<%9> + partial.reduce.add (sub (0, mul (ir<%load2> sext to i32), (ir<%load3> sext to i32)))
+; NEON: Cost of 16 for VF 16: EXPRESSION vp<{{.*}}> = vp<%9> + partial.reduce.add (sub (0, mul (ir<%load2> sext to i32), (ir<%load3> sext to i32)))
target triple = "aarch64"
>From f75db42d0b5c051ff20e22720bcac379e51f19aa Mon Sep 17 00:00:00 2001
From: Sander de Smalen <sander.desmalen at arm.com>
Date: Tue, 17 Mar 2026 09:54:49 +0000
Subject: [PATCH 3/5] Improvements to cost-model
The chosen costs are more precise, because the cost model now makes better use of the target features to determine whether an operation can be expanded.
The costs in sdot-i16-i32 are now more accurate, and the loops that previously did not vectorise now result in equivalent or better codegen.
---
.../AArch64/AArch64TargetTransformInfo.cpp | 98 +++++++++++--------
.../partial-reduce-add-sdot-i16-i32.ll | 6 +-
.../AArch64/partial-reduce-fdot-product.ll | 16 +--
.../AArch64/partial-reduce-no-dotprod.ll | 18 ++--
.../AArch64/partial-reduce-sub-sdot.ll | 4 +-
5 files changed, 81 insertions(+), 61 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 140c168c6e0ee..fb29dba8718a9 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -5895,10 +5895,6 @@ InstructionCost AArch64TTIImpl::getPartialReductionCost(
if (CostKind != TTI::TCK_RecipThroughput)
return Invalid;
- if (VF.isFixed() && !ST->isSVEorStreamingSVEAvailable() &&
- (!ST->isNeonAvailable() || !ST->hasDotProd()))
- return Invalid;
-
if ((Opcode != Instruction::Add && Opcode != Instruction::Sub &&
Opcode != Instruction::FAdd) ||
OpAExtend == TTI::PR_None)
@@ -5958,26 +5954,26 @@ InstructionCost AArch64TTIImpl::getPartialReductionCost(
std::pair<InstructionCost, MVT> InputLT =
getTypeLegalizationCost(InputVectorType);
- // Returns cost of expanding the partial reduction in ISel.
- auto GetExpandCost = [&]() -> InstructionCost {
- unsigned ExtOpc = AccumVectorType->getElementType()->isFloatingPointTy()
- ? Instruction::FPExt
- : Instruction::ZExt;
+ bool IsSub = Opcode == Instruction::Sub;
+ InstructionCost Cost = InputLT.first * TTI::TCC_Basic;
- Type *ExtVectorType =
- VectorType::get(AccumVectorType->getElementType(), VF);
- return (BinOp ? 2 : 1) *
- getCastInstrCost(ExtOpc, ExtVectorType, InputVectorType,
- TTI::CastContextHint::None, CostKind) +
- (BinOp ? getArithmeticInstrCost(*BinOp, ExtVectorType, CostKind)
- : InstructionCost()) +
- Ratio * getArithmeticInstrCost(Opcode, AccumVectorType, CostKind);
+ // Returns true if the subtarget supports the operation for a given type.
+ auto IsSupported = [&AccumLT](bool SVEPred, bool NEONPred) -> bool {
+ return SVEPred || (AccumLT.second.isFixedLengthVector() &&
+ AccumLT.second.getSizeInBits() <= 128 && NEONPred);
};
- bool IsSub = Opcode == Instruction::Sub;
- InstructionCost Cost = InputLT.first * TTI::TCC_Basic;
+ // i8 -> i32 is natively supported with udot/sdot/usdot, both for NEON and
+ // SVE.
+ if (IsSupported(ST->isSVEorStreamingSVEAvailable(), ST->hasDotProd()) &&
+ !IsSub) {
+ if (AccumLT.second.getScalarType() == MVT::i32 &&
+ InputLT.second.getScalarType() == MVT::i8)
+ return Cost;
+ }
- if (ST->isSVEorStreamingSVEAvailable() && !IsUSDot && !IsSub) {
+ if (IsSupported(ST->isSVEorStreamingSVEAvailable(), false) && !IsUSDot &&
+ !IsSub) {
// i16 -> i64 is natively supported for udot/sdot
if (AccumLT.second.getScalarType() == MVT::i64 &&
InputLT.second.getScalarType() == MVT::i16)
@@ -5996,30 +5992,54 @@ InstructionCost AArch64TTIImpl::getPartialReductionCost(
// the extends in the IR are still counted. This can be fixed
// after https://github.com/llvm/llvm-project/pull/147302 has landed.
return Cost;
+ // f16 -> f32 is natively supported for fdot
+ if (Opcode == Instruction::FAdd && (ST->hasSME2() || ST->hasSVE2p1()) &&
+ AccumLT.second.getScalarType() == MVT::f32 &&
+ InputLT.second.getScalarType() == MVT::f16 &&
+ AccumLT.second.getVectorMinNumElements() == 4 &&
+ InputLT.second.getVectorMinNumElements() == 8)
+ return Cost;
}
- // For a ratio of 2, we can use 2 [u|s|f|bf]mlalb/t instructions.
- if (Ratio == 2 && !IsSub &&
- llvm::is_contained({MVT::i16, MVT::i32, MVT::f16, MVT::bf16},
- InputLT.second.getScalarType().SimpleTy))
- return Cost * 2;
+ // For a ratio of 2, we can use *mlal top/bottom instructions.
+ if (Ratio == 2 && !IsSub) {
+ MVT InVT = InputLT.second.getScalarType();
- // i8 -> i32 is natively supported for udot/sdot/usdot, both for NEON and SVE.
- if (ST->isSVEorStreamingSVEAvailable() ||
- (AccumLT.second.isFixedLengthVector() && ST->isNeonAvailable() &&
- ST->hasDotProd())) {
- if (AccumLT.second.getScalarType() == MVT::i32 &&
- InputLT.second.getScalarType() == MVT::i8 && !IsSub)
- return Cost;
+ // SVE2 [us]mlalb/t and NEON [us]mlal(2)
+ if (IsSupported(ST->isSVEorStreamingSVEAvailable() && ST->hasSVE2(),
+ ST->hasNEON()) &&
+ llvm::is_contained({MVT::i8, MVT::i16, MVT::i32}, InVT.SimpleTy))
+ return Cost * 2;
+
+ // SVE2 fmlalb/t and NEON fmlal(2)
+ if (IsSupported(ST->isSVEorStreamingSVEAvailable() && ST->hasSVE2(),
+ ST->hasFP16FML()) &&
+ InVT == MVT::f16)
+ return Cost * 2;
+
+ // SVE and NEON bfmlalb/t
+ if (IsSupported(ST->isSVEorStreamingSVEAvailable() && ST->hasBF16(),
+ ST->hasBF16()) &&
+ InVT == MVT::bf16)
+ return Cost * 2;
}
- // f16 -> f32 is natively supported for fdot
- if (Opcode == Instruction::FAdd && (ST->hasSME2() || ST->hasSVE2p1()) &&
- AccumLT.second.getScalarType() == MVT::f32 &&
- InputLT.second.getScalarType() == MVT::f16 &&
- AccumLT.second.getVectorMinNumElements() == 4 &&
- InputLT.second.getVectorMinNumElements() == 8)
- return Cost;
+ // Returns cost of expanding the partial reduction in ISel.
+ auto GetExpandCost = [&]() -> InstructionCost {
+ unsigned ExtOpc = AccumVectorType->getElementType()->isFloatingPointTy()
+ ? Instruction::FPExt
+ : Instruction::ZExt;
+
+ Type *ExtVectorType =
+ VectorType::get(AccumVectorType->getElementType(), VF);
+ return (BinOp ? 2 : 1) *
+ getCastInstrCost(ExtOpc, ExtVectorType, InputVectorType,
+ TTI::CastContextHint::None, CostKind) +
+ (BinOp ? getArithmeticInstrCost(*BinOp, ExtVectorType, CostKind)
+ : InstructionCost()) +
+ Log2_32(Ratio) *
+ getArithmeticInstrCost(Opcode, AccumVectorType, CostKind);
+ };
if (IsSub) {
// Slightly lower the cost of a sub reduction so that it can be considered
diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-add-sdot-i16-i32.ll b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-add-sdot-i16-i32.ll
index 6e320959e0b4c..6aca12ebf7dac 100644
--- a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-add-sdot-i16-i32.ll
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-add-sdot-i16-i32.ll
@@ -27,9 +27,9 @@
; CHECK-SCALABLE: Cost of 1 for VF vscale x 8: EXPRESSION vp<%8> = ir<%acc> + partial.reduce.add (ir<%load> zext to i32)
; LV: Checking a loop in 'fpext_reduction_half_to_float'
-; CHECK-FIXED-BASE: Cost of 2 for VF 8: EXPRESSION vp<%8> = ir<%acc> + partial.reduce.fadd (ir<%load> reassoc contract fpext to float)
-; CHECK-FIXED: Cost of 2 for VF 8: EXPRESSION vp<%8> = ir<%acc> + partial.reduce.fadd (ir<%load> reassoc contract fpext to float)
-; CHECK-SCALABLE: Cost of 2 for VF vscale x 8: EXPRESSION vp<%8> = ir<%acc> + partial.reduce.fadd (ir<%load> reassoc contract fpext to float)
+; CHECK-FIXED-BASE: Cost of 3 for VF 8: EXPRESSION vp<%8> = ir<%acc> + partial.reduce.fadd (ir<%load> reassoc contract fpext to float)
+; CHECK-FIXED: Cost of 1 for VF 8: EXPRESSION vp<%8> = ir<%acc> + partial.reduce.fadd (ir<%load> reassoc contract fpext to float)
+; CHECK-SCALABLE: Cost of 1 for VF vscale x 8: EXPRESSION vp<%8> = ir<%acc> + partial.reduce.fadd (ir<%load> reassoc contract fpext to float)
target triple = "aarch64"
define i32 @sext_reduction_i16_to_i32(ptr %arr, i32 %n) vscale_range(1,16) {
diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-fdot-product.ll b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-fdot-product.ll
index 8573e1b0937f2..aa6623bed78a7 100644
--- a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-fdot-product.ll
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-fdot-product.ll
@@ -331,20 +331,20 @@ define float @not_fdot_f16_f32_nosve(ptr %a, ptr %b) {
; CHECK-NEXT: br label %[[VECTOR_BODY:.*]]
; CHECK: [[VECTOR_BODY]]:
; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, %[[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ]
-; CHECK-NEXT: [[VEC_PHI1:%.*]] = phi <8 x float> [ <float 0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00>, %[[VECTOR_PH]] ], [ [[TMP11:%.*]], %[[VECTOR_BODY]] ]
+; CHECK-NEXT: [[VEC_PHI:%.*]] = phi <4 x float> [ <float 0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00>, %[[VECTOR_PH]] ], [ [[PARTIAL_REDUCE:%.*]], %[[VECTOR_BODY]] ]
; CHECK-NEXT: [[TMP0:%.*]] = getelementptr half, ptr [[A]], i64 [[INDEX]]
; CHECK-NEXT: [[WIDE_LOAD2:%.*]] = load <8 x half>, ptr [[TMP0]], align 1
-; CHECK-NEXT: [[TMP3:%.*]] = fpext <8 x half> [[WIDE_LOAD2]] to <8 x float>
; CHECK-NEXT: [[TMP4:%.*]] = getelementptr half, ptr [[B]], i64 [[INDEX]]
; CHECK-NEXT: [[WIDE_LOAD4:%.*]] = load <8 x half>, ptr [[TMP4]], align 1
; CHECK-NEXT: [[TMP7:%.*]] = fpext <8 x half> [[WIDE_LOAD4]] to <8 x float>
+; CHECK-NEXT: [[TMP3:%.*]] = fpext <8 x half> [[WIDE_LOAD2]] to <8 x float>
; CHECK-NEXT: [[TMP9:%.*]] = fmul <8 x float> [[TMP7]], [[TMP3]]
-; CHECK-NEXT: [[TMP11]] = fadd reassoc contract <8 x float> [[TMP9]], [[VEC_PHI1]]
+; CHECK-NEXT: [[PARTIAL_REDUCE]] = call reassoc contract <4 x float> @llvm.vector.partial.reduce.fadd.v4f32.v8f32(<4 x float> [[VEC_PHI]], <8 x float> [[TMP9]])
; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 8
; CHECK-NEXT: [[TMP12:%.*]] = icmp eq i64 [[INDEX_NEXT]], 1024
; CHECK-NEXT: br i1 [[TMP12]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP11:![0-9]+]]
; CHECK: [[MIDDLE_BLOCK]]:
-; CHECK-NEXT: [[TMP13:%.*]] = call reassoc contract float @llvm.vector.reduce.fadd.v8f32(float -0.000000e+00, <8 x float> [[TMP11]])
+; CHECK-NEXT: [[TMP13:%.*]] = call reassoc contract float @llvm.vector.reduce.fadd.v4f32(float -0.000000e+00, <4 x float> [[PARTIAL_REDUCE]])
; CHECK-NEXT: br label %[[FOR_EXIT:.*]]
; CHECK: [[FOR_EXIT]]:
; CHECK-NEXT: ret float [[TMP13]]
@@ -380,20 +380,20 @@ define double @not_fdot_f32_f64_nosve(ptr %a, ptr %b) {
; CHECK-NEXT: br label %[[VECTOR_BODY:.*]]
; CHECK: [[VECTOR_BODY]]:
; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, %[[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ]
-; CHECK-NEXT: [[VEC_PHI1:%.*]] = phi <4 x double> [ <double 0.000000e+00, double -0.000000e+00, double -0.000000e+00, double -0.000000e+00>, %[[VECTOR_PH]] ], [ [[TMP11:%.*]], %[[VECTOR_BODY]] ]
+; CHECK-NEXT: [[VEC_PHI:%.*]] = phi <2 x double> [ <double 0.000000e+00, double -0.000000e+00>, %[[VECTOR_PH]] ], [ [[PARTIAL_REDUCE:%.*]], %[[VECTOR_BODY]] ]
; CHECK-NEXT: [[TMP0:%.*]] = getelementptr float, ptr [[A]], i64 [[INDEX]]
; CHECK-NEXT: [[WIDE_LOAD2:%.*]] = load <4 x float>, ptr [[TMP0]], align 1
-; CHECK-NEXT: [[TMP3:%.*]] = fpext <4 x float> [[WIDE_LOAD2]] to <4 x double>
; CHECK-NEXT: [[TMP4:%.*]] = getelementptr float, ptr [[B]], i64 [[INDEX]]
; CHECK-NEXT: [[WIDE_LOAD4:%.*]] = load <4 x float>, ptr [[TMP4]], align 1
; CHECK-NEXT: [[TMP7:%.*]] = fpext <4 x float> [[WIDE_LOAD4]] to <4 x double>
+; CHECK-NEXT: [[TMP3:%.*]] = fpext <4 x float> [[WIDE_LOAD2]] to <4 x double>
; CHECK-NEXT: [[TMP9:%.*]] = fmul <4 x double> [[TMP7]], [[TMP3]]
-; CHECK-NEXT: [[TMP11]] = fadd reassoc contract <4 x double> [[TMP9]], [[VEC_PHI1]]
+; CHECK-NEXT: [[PARTIAL_REDUCE]] = call reassoc contract <2 x double> @llvm.vector.partial.reduce.fadd.v2f64.v4f64(<2 x double> [[VEC_PHI]], <4 x double> [[TMP9]])
; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 4
; CHECK-NEXT: [[TMP12:%.*]] = icmp eq i64 [[INDEX_NEXT]], 1024
; CHECK-NEXT: br i1 [[TMP12]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP12:![0-9]+]]
; CHECK: [[MIDDLE_BLOCK]]:
-; CHECK-NEXT: [[TMP13:%.*]] = call reassoc contract double @llvm.vector.reduce.fadd.v4f64(double -0.000000e+00, <4 x double> [[TMP11]])
+; CHECK-NEXT: [[TMP13:%.*]] = call reassoc contract double @llvm.vector.reduce.fadd.v2f64(double -0.000000e+00, <2 x double> [[PARTIAL_REDUCE]])
; CHECK-NEXT: br label %[[FOR_EXIT:.*]]
; CHECK: [[FOR_EXIT]]:
; CHECK-NEXT: ret double [[TMP13]]
diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-no-dotprod.ll b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-no-dotprod.ll
index a439f5189794a..87076baff3296 100644
--- a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-no-dotprod.ll
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-no-dotprod.ll
@@ -13,30 +13,30 @@ define i32 @not_dotp(ptr %a, ptr %b) {
; CHECK-NEXT: br label %[[VECTOR_BODY:.*]]
; CHECK: [[VECTOR_BODY]]:
; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, %[[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ]
-; CHECK-NEXT: [[VEC_PHI:%.*]] = phi <16 x i32> [ zeroinitializer, %[[VECTOR_PH]] ], [ [[TMP13:%.*]], %[[VECTOR_BODY]] ]
-; CHECK-NEXT: [[VEC_PHI1:%.*]] = phi <16 x i32> [ zeroinitializer, %[[VECTOR_PH]] ], [ [[TMP14:%.*]], %[[VECTOR_BODY]] ]
+; CHECK-NEXT: [[VEC_PHI:%.*]] = phi <4 x i32> [ zeroinitializer, %[[VECTOR_PH]] ], [ [[PARTIAL_REDUCE:%.*]], %[[VECTOR_BODY]] ]
+; CHECK-NEXT: [[VEC_PHI1:%.*]] = phi <4 x i32> [ zeroinitializer, %[[VECTOR_PH]] ], [ [[PARTIAL_REDUCE5:%.*]], %[[VECTOR_BODY]] ]
; CHECK-NEXT: [[TMP1:%.*]] = getelementptr i8, ptr [[A]], i64 [[INDEX]]
; CHECK-NEXT: [[TMP3:%.*]] = getelementptr i8, ptr [[TMP1]], i64 16
; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[TMP1]], align 1
; CHECK-NEXT: [[WIDE_LOAD2:%.*]] = load <16 x i8>, ptr [[TMP3]], align 1
-; CHECK-NEXT: [[TMP4:%.*]] = zext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
-; CHECK-NEXT: [[TMP5:%.*]] = zext <16 x i8> [[WIDE_LOAD2]] to <16 x i32>
; CHECK-NEXT: [[TMP6:%.*]] = getelementptr i8, ptr [[B]], i64 [[INDEX]]
; CHECK-NEXT: [[TMP8:%.*]] = getelementptr i8, ptr [[TMP6]], i64 16
; CHECK-NEXT: [[WIDE_LOAD3:%.*]] = load <16 x i8>, ptr [[TMP6]], align 1
; CHECK-NEXT: [[WIDE_LOAD4:%.*]] = load <16 x i8>, ptr [[TMP8]], align 1
; CHECK-NEXT: [[TMP9:%.*]] = zext <16 x i8> [[WIDE_LOAD3]] to <16 x i32>
-; CHECK-NEXT: [[TMP10:%.*]] = zext <16 x i8> [[WIDE_LOAD4]] to <16 x i32>
+; CHECK-NEXT: [[TMP4:%.*]] = zext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
; CHECK-NEXT: [[TMP11:%.*]] = mul <16 x i32> [[TMP9]], [[TMP4]]
+; CHECK-NEXT: [[PARTIAL_REDUCE]] = call <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI]], <16 x i32> [[TMP11]])
+; CHECK-NEXT: [[TMP10:%.*]] = zext <16 x i8> [[WIDE_LOAD4]] to <16 x i32>
+; CHECK-NEXT: [[TMP5:%.*]] = zext <16 x i8> [[WIDE_LOAD2]] to <16 x i32>
; CHECK-NEXT: [[TMP12:%.*]] = mul <16 x i32> [[TMP10]], [[TMP5]]
-; CHECK-NEXT: [[TMP13]] = add <16 x i32> [[TMP11]], [[VEC_PHI]]
-; CHECK-NEXT: [[TMP14]] = add <16 x i32> [[TMP12]], [[VEC_PHI1]]
+; CHECK-NEXT: [[PARTIAL_REDUCE5]] = call <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI1]], <16 x i32> [[TMP12]])
; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 32
; CHECK-NEXT: [[TMP15:%.*]] = icmp eq i64 [[INDEX_NEXT]], 992
; CHECK-NEXT: br i1 [[TMP15]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
; CHECK: [[MIDDLE_BLOCK]]:
-; CHECK-NEXT: [[BIN_RDX:%.*]] = add <16 x i32> [[TMP14]], [[TMP13]]
-; CHECK-NEXT: [[TMP16:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[BIN_RDX]])
+; CHECK-NEXT: [[BIN_RDX:%.*]] = add <4 x i32> [[PARTIAL_REDUCE5]], [[PARTIAL_REDUCE]]
+; CHECK-NEXT: [[TMP13:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[BIN_RDX]])
; CHECK-NEXT: br label %[[SCALAR_PH:.*]]
; CHECK: [[SCALAR_PH]]:
;
diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-sub-sdot.ll b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-sub-sdot.ll
index ff3881f2c7fda..112e4d713d42b 100644
--- a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-sub-sdot.ll
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-sub-sdot.ll
@@ -15,9 +15,9 @@
; COMMON: LV: Checking a loop in 'add_sub_chained_reduction'
; SVE: Cost of 1 for VF vscale x 16: EXPRESSION vp<{{.*}}> = ir<%acc> + partial.reduce.add (mul (ir<%load1> sext to i32), (ir<%load2> sext to i32))
-; SVE: Cost of 16 for VF vscale x 16: EXPRESSION vp<{{.*}}> = vp<%9> + partial.reduce.add (sub (0, mul (ir<%load2> sext to i32), (ir<%load3> sext to i32)))
+; SVE: Cost of 14 for VF vscale x 16: EXPRESSION vp<{{.*}}> = vp<%9> + partial.reduce.add (sub (0, mul (ir<%load2> sext to i32), (ir<%load3> sext to i32)))
; NEON: Cost of 1 for VF 16: EXPRESSION vp<{{.*}}> = ir<%acc> + partial.reduce.add (mul (ir<%load1> sext to i32), (ir<%load2> sext to i32))
-; NEON: Cost of 16 for VF 16: EXPRESSION vp<{{.*}}> = vp<%9> + partial.reduce.add (sub (0, mul (ir<%load2> sext to i32), (ir<%load3> sext to i32)))
+; NEON: Cost of 14 for VF 16: EXPRESSION vp<{{.*}}> = vp<%9> + partial.reduce.add (sub (0, mul (ir<%load2> sext to i32), (ir<%load3> sext to i32)))
target triple = "aarch64"
>From 001dd5b1dbeab749409ae12a5fb750c28dea3fda Mon Sep 17 00:00:00 2001
From: Sander de Smalen <sander.desmalen at arm.com>
Date: Wed, 18 Mar 2026 10:10:26 +0000
Subject: [PATCH 4/5] Distinguish between extends
---
.../llvm/Analysis/TargetTransformInfo.h | 3 +++
llvm/lib/Analysis/TargetTransformInfo.cpp | 16 +++++++++++++
.../AArch64/AArch64TargetTransformInfo.cpp | 24 ++++++++++---------
3 files changed, 32 insertions(+), 11 deletions(-)
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index f01eb3aaad3d3..38172f4648001 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -283,6 +283,9 @@ class TargetTransformInfo {
/// Get the kind of extension that a cast opcode represents.
LLVM_ABI static PartialReductionExtendKind
getPartialReductionExtendKind(Instruction::CastOps CastOpc);
+ /// Get the cast opcode for an extension kind.
+ LLVM_ABI static Instruction::CastOps
+ getOpcodeForPartialReductionExtendKind(PartialReductionExtendKind Kind);
/// Construct a TTI object using a type implementing the \c Concept
/// API below.
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index a196492641d53..69ecde011e1a9 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -1050,6 +1050,22 @@ TargetTransformInfo::getPartialReductionExtendKind(Instruction *I) {
return PR_None;
}
+Instruction::CastOps
+TargetTransformInfo::getOpcodeForPartialReductionExtendKind(
+ TargetTransformInfo::PartialReductionExtendKind Kind) {
+ switch (Kind) {
+ case TargetTransformInfo::PR_ZeroExtend:
+ return Instruction::CastOps::ZExt;
+ case TargetTransformInfo::PR_SignExtend:
+ return Instruction::CastOps::SExt;
+ case TargetTransformInfo::PR_FPExtend:
+ return Instruction::CastOps::FPExt;
+ default:
+ break;
+ }
+ llvm_unreachable("Unhandled partial reduction extend kind");
+}
+
TargetTransformInfo::PartialReductionExtendKind
TargetTransformInfo::getPartialReductionExtendKind(
Instruction::CastOps CastOpc) {
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index fb29dba8718a9..acef2b4b95b55 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -6026,19 +6026,21 @@ InstructionCost AArch64TTIImpl::getPartialReductionCost(
// Returns cost of expanding the partial reduction in ISel.
auto GetExpandCost = [&]() -> InstructionCost {
- unsigned ExtOpc = AccumVectorType->getElementType()->isFloatingPointTy()
- ? Instruction::FPExt
- : Instruction::ZExt;
-
Type *ExtVectorType =
VectorType::get(AccumVectorType->getElementType(), VF);
- return (BinOp ? 2 : 1) *
- getCastInstrCost(ExtOpc, ExtVectorType, InputVectorType,
- TTI::CastContextHint::None, CostKind) +
- (BinOp ? getArithmeticInstrCost(*BinOp, ExtVectorType, CostKind)
- : InstructionCost()) +
- Log2_32(Ratio) *
- getArithmeticInstrCost(Opcode, AccumVectorType, CostKind);
+ auto ExtendCostA = getCastInstrCost(
+ TTI::getOpcodeForPartialReductionExtendKind(OpAExtend), ExtVectorType,
+ InputVectorType, TTI::CastContextHint::None, CostKind);
+ auto RedOpCost = Log2_32(Ratio) *
+ getArithmeticInstrCost(Opcode, AccumVectorType, CostKind);
+ if (!BinOp)
+ return ExtendCostA + RedOpCost;
+
+ auto ExtendCostB = getCastInstrCost(
+ TTI::getOpcodeForPartialReductionExtendKind(OpBExtend), ExtVectorType,
+ InputVectorType, TTI::CastContextHint::None, CostKind);
+ return ExtendCostA + ExtendCostB + RedOpCost +
+ getArithmeticInstrCost(*BinOp, ExtVectorType, CostKind);
};
if (IsSub) {
>From b60114a5e20904191865f9abb5a1f0c5640fada0 Mon Sep 17 00:00:00 2001
From: Sander de Smalen <sander.desmalen at arm.com>
Date: Thu, 26 Mar 2026 14:40:26 +0000
Subject: [PATCH 5/5] Address comments
---
.../AArch64/AArch64TargetTransformInfo.cpp | 46 +++++++++----------
.../AArch64/partial-reduce-sub-sdot.ll | 4 +-
2 files changed, 24 insertions(+), 26 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index acef2b4b95b55..9de96fff2bb69 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -5923,6 +5923,7 @@ InstructionCost AArch64TTIImpl::getPartialReductionCost(
bool IsUSDot = OpBExtend != TTI::PR_None && OpAExtend != OpBExtend;
if (IsUSDot && !ST->hasMatMulInt8())
+ // FIXME: Remove this early bailout in favour of expand cost.
return Invalid;
unsigned Ratio =
@@ -5954,26 +5955,28 @@ InstructionCost AArch64TTIImpl::getPartialReductionCost(
std::pair<InstructionCost, MVT> InputLT =
getTypeLegalizationCost(InputVectorType);
- bool IsSub = Opcode == Instruction::Sub;
- InstructionCost Cost = InputLT.first * TTI::TCC_Basic;
-
// Returns true if the subtarget supports the operation for a given type.
- auto IsSupported = [&AccumLT](bool SVEPred, bool NEONPred) -> bool {
- return SVEPred || (AccumLT.second.isFixedLengthVector() &&
- AccumLT.second.getSizeInBits() <= 128 && NEONPred);
+ auto IsSupported = [&](bool SVEPred, bool NEONPred) -> bool {
+ return (ST->isSVEorStreamingSVEAvailable() && SVEPred) ||
+ (AccumLT.second.isFixedLengthVector() &&
+ AccumLT.second.getSizeInBits() <= 128 && ST->isNeonAvailable() &&
+ NEONPred);
};
- // i8 -> i32 is natively supported with udot/sdot/usdot, both for NEON and
- // SVE.
- if (IsSupported(ST->isSVEorStreamingSVEAvailable(), ST->hasDotProd()) &&
- !IsSub) {
- if (AccumLT.second.getScalarType() == MVT::i32 &&
- InputLT.second.getScalarType() == MVT::i8)
+ bool IsSub = Opcode == Instruction::Sub;
+ InstructionCost Cost = InputLT.first * TTI::TCC_Basic;
+
+ if (AccumLT.second.getScalarType() == MVT::i32 &&
+ InputLT.second.getScalarType() == MVT::i8 && !IsSub) {
+ // i8 -> i32 is natively supported with udot/sdot for both NEON and SVE.
+ if (!IsUSDot && IsSupported(true, ST->hasDotProd()))
+ return Cost;
+ // i8 -> i32 usdot requires +i8mm
+ if (IsUSDot && IsSupported(ST->hasMatMulInt8(), ST->hasMatMulInt8()))
return Cost;
}
- if (IsSupported(ST->isSVEorStreamingSVEAvailable(), false) && !IsUSDot &&
- !IsSub) {
+ if (ST->isSVEorStreamingSVEAvailable() && !IsUSDot && !IsSub) {
// i16 -> i64 is natively supported for udot/sdot
if (AccumLT.second.getScalarType() == MVT::i64 &&
InputLT.second.getScalarType() == MVT::i16)
@@ -6006,21 +6009,16 @@ InstructionCost AArch64TTIImpl::getPartialReductionCost(
MVT InVT = InputLT.second.getScalarType();
// SVE2 [us]mlalb/t and NEON [us]mlal(2)
- if (IsSupported(ST->isSVEorStreamingSVEAvailable() && ST->hasSVE2(),
- ST->hasNEON()) &&
+ if (IsSupported(ST->hasSVE2(), true) &&
llvm::is_contained({MVT::i8, MVT::i16, MVT::i32}, InVT.SimpleTy))
return Cost * 2;
// SVE2 fmlalb/t and NEON fmlal(2)
- if (IsSupported(ST->isSVEorStreamingSVEAvailable() && ST->hasSVE2(),
- ST->hasFP16FML()) &&
- InVT == MVT::f16)
+ if (IsSupported(ST->hasSVE2(), ST->hasFP16FML()) && InVT == MVT::f16)
return Cost * 2;
// SVE and NEON bfmlalb/t
- if (IsSupported(ST->isSVEorStreamingSVEAvailable() && ST->hasBF16(),
- ST->hasBF16()) &&
- InVT == MVT::bf16)
+ if (IsSupported(ST->hasBF16(), ST->hasBF16()) && InVT == MVT::bf16)
return Cost * 2;
}
@@ -6031,8 +6029,8 @@ InstructionCost AArch64TTIImpl::getPartialReductionCost(
auto ExtendCostA = getCastInstrCost(
TTI::getOpcodeForPartialReductionExtendKind(OpAExtend), ExtVectorType,
InputVectorType, TTI::CastContextHint::None, CostKind);
- auto RedOpCost = Log2_32(Ratio) *
- getArithmeticInstrCost(Opcode, AccumVectorType, CostKind);
+ auto RedOpCost =
+ (Ratio - 1) * getArithmeticInstrCost(Opcode, AccumVectorType, CostKind);
if (!BinOp)
return ExtendCostA + RedOpCost;
diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-sub-sdot.ll b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-sub-sdot.ll
index 112e4d713d42b..72b3301ece094 100644
--- a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-sub-sdot.ll
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-sub-sdot.ll
@@ -15,9 +15,9 @@
; COMMON: LV: Checking a loop in 'add_sub_chained_reduction'
; SVE: Cost of 1 for VF vscale x 16: EXPRESSION vp<{{.*}}> = ir<%acc> + partial.reduce.add (mul (ir<%load1> sext to i32), (ir<%load2> sext to i32))
-; SVE: Cost of 14 for VF vscale x 16: EXPRESSION vp<{{.*}}> = vp<%9> + partial.reduce.add (sub (0, mul (ir<%load2> sext to i32), (ir<%load3> sext to i32)))
+; SVE: Cost of 15 for VF vscale x 16: EXPRESSION vp<{{.*}}> = vp<%9> + partial.reduce.add (sub (0, mul (ir<%load2> sext to i32), (ir<%load3> sext to i32)))
; NEON: Cost of 1 for VF 16: EXPRESSION vp<{{.*}}> = ir<%acc> + partial.reduce.add (mul (ir<%load1> sext to i32), (ir<%load2> sext to i32))
-; NEON: Cost of 14 for VF 16: EXPRESSION vp<{{.*}}> = vp<%9> + partial.reduce.add (sub (0, mul (ir<%load2> sext to i32), (ir<%load3> sext to i32)))
+; NEON: Cost of 15 for VF 16: EXPRESSION vp<{{.*}}> = vp<%9> + partial.reduce.add (sub (0, mul (ir<%load2> sext to i32), (ir<%load3> sext to i32)))
target triple = "aarch64"
More information about the llvm-commits
mailing list