[llvm] [AArch64][CostModel] Add constraints on which partial reductions are natively supported on Neon and SVE (PR #163728)
Sushant Gokhale via llvm-commits
llvm-commits at lists.llvm.org
Thu Oct 16 02:56:45 PDT 2025
https://github.com/sushgokh created https://github.com/llvm/llvm-project/pull/163728
PR #158641 refined and refactored the cost model for partial reductions, but in doing so it dropped certain constraints. Specifically, partial reductions such as i32 -> i64 are not natively supported on Neon or SVE. This patch adds back the constraint that was present before PR #158641.
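To illustrate the restored constraint, here is a minimal standalone sketch of the ratio-based gate. The helper name and the surrounding harness are illustrative assumptions on my part; the real check lives in AArch64TTIImpl::getPartialReductionCost, operates on LLVM types, and returns InstructionCost::getInvalid() instead of a bool.

// Minimal sketch, not the actual LLVM code: models the ratio-based gate
// that this patch restores in the AArch64 cost model.
#include <cassert>
#include <cstdio>

// A partial reduction accumulating InputBits-wide elements into
// AccumBits-wide elements passes this gate only when the accumulator
// element is at least 4x wider than the input element.
static bool ratioGateAllowsPartialReduction(unsigned AccumBits,
                                            unsigned InputBits) {
  if (InputBits == 0 || AccumBits % InputBits != 0)
    return false;
  unsigned Ratio = AccumBits / InputBits;
  // Ratio == 2 covers i32 -> i64 and i16 -> i32, which the patch rejects
  // (the real cost model returns Invalid for these).
  return Ratio >= 4;
}

int main() {
  assert(ratioGateAllowsPartialReduction(32, 8));   // i8  -> i32, Ratio == 4
  assert(ratioGateAllowsPartialReduction(64, 8));   // i8  -> i64, Ratio == 8
  assert(!ratioGateAllowsPartialReduction(64, 32)); // i32 -> i64, Ratio == 2
  assert(!ratioGateAllowsPartialReduction(32, 16)); // i16 -> i32, Ratio == 2
  std::puts("ratio gate sketch behaves as described");
  return 0;
}

With this gate in place, the sext_reduction_i32_to_i64 test below falls back to a plain widened add reduction instead of a partial reduction, which is why its checks change from <2 x i64> to <4 x i64> vectors.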
From 431c530d85b779d75bc20a87ef543b74d023f309 Mon Sep 17 00:00:00 2001
From: sgokhale <sgokhale at nvidia.com>
Date: Thu, 16 Oct 2025 02:39:44 -0700
Subject: [PATCH] [AArch64][CostModel] Add constraints on which partial
reductions are natively supported on Neon and SVE
PR #158641 refined and refactored the cost model for partial reductions,
but in doing so it dropped certain constraints. Specifically, partial
reductions such as i32 -> i64 are not natively supported on Neon or SVE.
This patch adds back the constraint that was present before PR #158641.
---
.../AArch64/AArch64TargetTransformInfo.cpp | 3 +
.../LoopVectorize/AArch64/partial-reduce.ll | 82 +++++++++----------
2 files changed, 44 insertions(+), 41 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 479e34515fc8a..bfb2cacbd7dca 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -5661,6 +5661,9 @@ InstructionCost AArch64TTIImpl::getPartialReductionCost(
AccumType->getScalarSizeInBits() / InputTypeA->getScalarSizeInBits();
if (VF.getKnownMinValue() <= Ratio)
return Invalid;
+ // i32 -> i64 and i16 -> i32 are not natively supported on Neon and SVE.
+ if (Ratio < 4)
+ return Invalid;
VectorType *InputVectorType = VectorType::get(InputTypeA, VF);
VectorType *AccumVectorType =
diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce.ll b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce.ll
index 1e6bcb12029e7..0a61e8e9a505f 100644
--- a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce.ll
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce.ll
@@ -1117,24 +1117,24 @@ define i64 @sext_reduction_i32_to_i64(ptr %arr, i64 %n) #1 {
; CHECK-INTERLEAVE1-SAME: ptr [[ARR:%.*]], i64 [[N:%.*]]) #[[ATTR2]] {
; CHECK-INTERLEAVE1-NEXT: entry:
; CHECK-INTERLEAVE1-NEXT: [[UMAX:%.*]] = call i64 @llvm.umax.i64(i64 [[N]], i64 1)
-; CHECK-INTERLEAVE1-NEXT: [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 [[UMAX]], 2
+; CHECK-INTERLEAVE1-NEXT: [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 [[UMAX]], 4
; CHECK-INTERLEAVE1-NEXT: br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
; CHECK-INTERLEAVE1: vector.ph:
-; CHECK-INTERLEAVE1-NEXT: [[N_MOD_VF:%.*]] = urem i64 [[UMAX]], 2
+; CHECK-INTERLEAVE1-NEXT: [[N_MOD_VF:%.*]] = urem i64 [[UMAX]], 4
; CHECK-INTERLEAVE1-NEXT: [[N_VEC:%.*]] = sub i64 [[UMAX]], [[N_MOD_VF]]
; CHECK-INTERLEAVE1-NEXT: br label [[VECTOR_BODY:%.*]]
; CHECK-INTERLEAVE1: vector.body:
; CHECK-INTERLEAVE1-NEXT: [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
-; CHECK-INTERLEAVE1-NEXT: [[VEC_PHI:%.*]] = phi <2 x i64> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP2:%.*]], [[VECTOR_BODY]] ]
+; CHECK-INTERLEAVE1-NEXT: [[VEC_PHI:%.*]] = phi <4 x i64> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP2:%.*]], [[VECTOR_BODY]] ]
; CHECK-INTERLEAVE1-NEXT: [[TMP4:%.*]] = getelementptr inbounds i32, ptr [[ARR]], i64 [[INDEX]]
-; CHECK-INTERLEAVE1-NEXT: [[WIDE_LOAD:%.*]] = load <2 x i32>, ptr [[TMP4]], align 4
-; CHECK-INTERLEAVE1-NEXT: [[TMP1:%.*]] = sext <2 x i32> [[WIDE_LOAD]] to <2 x i64>
-; CHECK-INTERLEAVE1-NEXT: [[TMP2]] = add <2 x i64> [[VEC_PHI]], [[TMP1]]
-; CHECK-INTERLEAVE1-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 2
+; CHECK-INTERLEAVE1-NEXT: [[WIDE_LOAD:%.*]] = load <4 x i32>, ptr [[TMP4]], align 4
+; CHECK-INTERLEAVE1-NEXT: [[TMP1:%.*]] = sext <4 x i32> [[WIDE_LOAD]] to <4 x i64>
+; CHECK-INTERLEAVE1-NEXT: [[TMP2]] = add <4 x i64> [[VEC_PHI]], [[TMP1]]
+; CHECK-INTERLEAVE1-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 4
; CHECK-INTERLEAVE1-NEXT: [[TMP7:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
; CHECK-INTERLEAVE1-NEXT: br i1 [[TMP7]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP21:![0-9]+]]
; CHECK-INTERLEAVE1: middle.block:
-; CHECK-INTERLEAVE1-NEXT: [[TMP5:%.*]] = call i64 @llvm.vector.reduce.add.v2i64(<2 x i64> [[TMP2]])
+; CHECK-INTERLEAVE1-NEXT: [[TMP5:%.*]] = call i64 @llvm.vector.reduce.add.v4i64(<4 x i64> [[TMP2]])
; CHECK-INTERLEAVE1-NEXT: [[CMP_N:%.*]] = icmp eq i64 [[UMAX]], [[N_VEC]]
; CHECK-INTERLEAVE1-NEXT: br i1 [[CMP_N]], label [[EXIT:%.*]], label [[SCALAR_PH]]
; CHECK-INTERLEAVE1: scalar.ph:
@@ -1143,42 +1143,42 @@ define i64 @sext_reduction_i32_to_i64(ptr %arr, i64 %n) #1 {
; CHECK-INTERLEAVED-SAME: ptr [[ARR:%.*]], i64 [[N:%.*]]) #[[ATTR2]] {
; CHECK-INTERLEAVED-NEXT: entry:
; CHECK-INTERLEAVED-NEXT: [[UMAX:%.*]] = call i64 @llvm.umax.i64(i64 [[N]], i64 1)
-; CHECK-INTERLEAVED-NEXT: [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 [[UMAX]], 8
+; CHECK-INTERLEAVED-NEXT: [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 [[UMAX]], 16
; CHECK-INTERLEAVED-NEXT: br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
; CHECK-INTERLEAVED: vector.ph:
-; CHECK-INTERLEAVED-NEXT: [[N_MOD_VF:%.*]] = urem i64 [[UMAX]], 8
+; CHECK-INTERLEAVED-NEXT: [[N_MOD_VF:%.*]] = urem i64 [[UMAX]], 16
; CHECK-INTERLEAVED-NEXT: [[N_VEC:%.*]] = sub i64 [[UMAX]], [[N_MOD_VF]]
; CHECK-INTERLEAVED-NEXT: br label [[VECTOR_BODY:%.*]]
; CHECK-INTERLEAVED: vector.body:
; CHECK-INTERLEAVED-NEXT: [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
-; CHECK-INTERLEAVED-NEXT: [[VEC_PHI:%.*]] = phi <2 x i64> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP8:%.*]], [[VECTOR_BODY]] ]
-; CHECK-INTERLEAVED-NEXT: [[VEC_PHI1:%.*]] = phi <2 x i64> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP9:%.*]], [[VECTOR_BODY]] ]
-; CHECK-INTERLEAVED-NEXT: [[VEC_PHI2:%.*]] = phi <2 x i64> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP10:%.*]], [[VECTOR_BODY]] ]
-; CHECK-INTERLEAVED-NEXT: [[VEC_PHI3:%.*]] = phi <2 x i64> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP11:%.*]], [[VECTOR_BODY]] ]
+; CHECK-INTERLEAVED-NEXT: [[VEC_PHI:%.*]] = phi <4 x i64> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP8:%.*]], [[VECTOR_BODY]] ]
+; CHECK-INTERLEAVED-NEXT: [[VEC_PHI1:%.*]] = phi <4 x i64> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP9:%.*]], [[VECTOR_BODY]] ]
+; CHECK-INTERLEAVED-NEXT: [[VEC_PHI2:%.*]] = phi <4 x i64> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP10:%.*]], [[VECTOR_BODY]] ]
+; CHECK-INTERLEAVED-NEXT: [[VEC_PHI3:%.*]] = phi <4 x i64> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP11:%.*]], [[VECTOR_BODY]] ]
; CHECK-INTERLEAVED-NEXT: [[TMP4:%.*]] = getelementptr inbounds i32, ptr [[ARR]], i64 [[INDEX]]
-; CHECK-INTERLEAVED-NEXT: [[TMP1:%.*]] = getelementptr inbounds i32, ptr [[TMP4]], i32 2
; CHECK-INTERLEAVED-NEXT: [[TMP2:%.*]] = getelementptr inbounds i32, ptr [[TMP4]], i32 4
-; CHECK-INTERLEAVED-NEXT: [[TMP3:%.*]] = getelementptr inbounds i32, ptr [[TMP4]], i32 6
-; CHECK-INTERLEAVED-NEXT: [[WIDE_LOAD:%.*]] = load <2 x i32>, ptr [[TMP4]], align 4
-; CHECK-INTERLEAVED-NEXT: [[WIDE_LOAD4:%.*]] = load <2 x i32>, ptr [[TMP1]], align 4
-; CHECK-INTERLEAVED-NEXT: [[WIDE_LOAD5:%.*]] = load <2 x i32>, ptr [[TMP2]], align 4
-; CHECK-INTERLEAVED-NEXT: [[WIDE_LOAD6:%.*]] = load <2 x i32>, ptr [[TMP3]], align 4
-; CHECK-INTERLEAVED-NEXT: [[TMP14:%.*]] = sext <2 x i32> [[WIDE_LOAD]] to <2 x i64>
-; CHECK-INTERLEAVED-NEXT: [[TMP5:%.*]] = sext <2 x i32> [[WIDE_LOAD4]] to <2 x i64>
-; CHECK-INTERLEAVED-NEXT: [[TMP6:%.*]] = sext <2 x i32> [[WIDE_LOAD5]] to <2 x i64>
-; CHECK-INTERLEAVED-NEXT: [[TMP7:%.*]] = sext <2 x i32> [[WIDE_LOAD6]] to <2 x i64>
-; CHECK-INTERLEAVED-NEXT: [[TMP8]] = add <2 x i64> [[VEC_PHI]], [[TMP14]]
-; CHECK-INTERLEAVED-NEXT: [[TMP9]] = add <2 x i64> [[VEC_PHI1]], [[TMP5]]
-; CHECK-INTERLEAVED-NEXT: [[TMP10]] = add <2 x i64> [[VEC_PHI2]], [[TMP6]]
-; CHECK-INTERLEAVED-NEXT: [[TMP11]] = add <2 x i64> [[VEC_PHI3]], [[TMP7]]
-; CHECK-INTERLEAVED-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 8
+; CHECK-INTERLEAVED-NEXT: [[TMP14:%.*]] = getelementptr inbounds i32, ptr [[TMP4]], i32 8
+; CHECK-INTERLEAVED-NEXT: [[TMP3:%.*]] = getelementptr inbounds i32, ptr [[TMP4]], i32 12
+; CHECK-INTERLEAVED-NEXT: [[WIDE_LOAD:%.*]] = load <4 x i32>, ptr [[TMP4]], align 4
+; CHECK-INTERLEAVED-NEXT: [[WIDE_LOAD4:%.*]] = load <4 x i32>, ptr [[TMP2]], align 4
+; CHECK-INTERLEAVED-NEXT: [[WIDE_LOAD5:%.*]] = load <4 x i32>, ptr [[TMP14]], align 4
+; CHECK-INTERLEAVED-NEXT: [[WIDE_LOAD6:%.*]] = load <4 x i32>, ptr [[TMP3]], align 4
+; CHECK-INTERLEAVED-NEXT: [[TMP15:%.*]] = sext <4 x i32> [[WIDE_LOAD]] to <4 x i64>
+; CHECK-INTERLEAVED-NEXT: [[TMP5:%.*]] = sext <4 x i32> [[WIDE_LOAD4]] to <4 x i64>
+; CHECK-INTERLEAVED-NEXT: [[TMP6:%.*]] = sext <4 x i32> [[WIDE_LOAD5]] to <4 x i64>
+; CHECK-INTERLEAVED-NEXT: [[TMP7:%.*]] = sext <4 x i32> [[WIDE_LOAD6]] to <4 x i64>
+; CHECK-INTERLEAVED-NEXT: [[TMP8]] = add <4 x i64> [[VEC_PHI]], [[TMP15]]
+; CHECK-INTERLEAVED-NEXT: [[TMP9]] = add <4 x i64> [[VEC_PHI1]], [[TMP5]]
+; CHECK-INTERLEAVED-NEXT: [[TMP10]] = add <4 x i64> [[VEC_PHI2]], [[TMP6]]
+; CHECK-INTERLEAVED-NEXT: [[TMP11]] = add <4 x i64> [[VEC_PHI3]], [[TMP7]]
+; CHECK-INTERLEAVED-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 16
; CHECK-INTERLEAVED-NEXT: [[TMP12:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
; CHECK-INTERLEAVED-NEXT: br i1 [[TMP12]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP21:![0-9]+]]
; CHECK-INTERLEAVED: middle.block:
-; CHECK-INTERLEAVED-NEXT: [[BIN_RDX:%.*]] = add <2 x i64> [[TMP9]], [[TMP8]]
-; CHECK-INTERLEAVED-NEXT: [[BIN_RDX7:%.*]] = add <2 x i64> [[TMP10]], [[BIN_RDX]]
-; CHECK-INTERLEAVED-NEXT: [[BIN_RDX8:%.*]] = add <2 x i64> [[TMP11]], [[BIN_RDX7]]
-; CHECK-INTERLEAVED-NEXT: [[TMP13:%.*]] = call i64 @llvm.vector.reduce.add.v2i64(<2 x i64> [[BIN_RDX8]])
+; CHECK-INTERLEAVED-NEXT: [[BIN_RDX:%.*]] = add <4 x i64> [[TMP9]], [[TMP8]]
+; CHECK-INTERLEAVED-NEXT: [[BIN_RDX7:%.*]] = add <4 x i64> [[TMP10]], [[BIN_RDX]]
+; CHECK-INTERLEAVED-NEXT: [[BIN_RDX8:%.*]] = add <4 x i64> [[TMP11]], [[BIN_RDX7]]
+; CHECK-INTERLEAVED-NEXT: [[TMP13:%.*]] = call i64 @llvm.vector.reduce.add.v4i64(<4 x i64> [[BIN_RDX8]])
; CHECK-INTERLEAVED-NEXT: [[CMP_N:%.*]] = icmp eq i64 [[UMAX]], [[N_VEC]]
; CHECK-INTERLEAVED-NEXT: br i1 [[CMP_N]], label [[EXIT:%.*]], label [[SCALAR_PH]]
; CHECK-INTERLEAVED: scalar.ph:
@@ -1187,24 +1187,24 @@ define i64 @sext_reduction_i32_to_i64(ptr %arr, i64 %n) #1 {
; CHECK-MAXBW-SAME: ptr [[ARR:%.*]], i64 [[N:%.*]]) #[[ATTR2]] {
; CHECK-MAXBW-NEXT: entry:
; CHECK-MAXBW-NEXT: [[UMAX:%.*]] = call i64 @llvm.umax.i64(i64 [[N]], i64 1)
-; CHECK-MAXBW-NEXT: [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 [[UMAX]], 2
+; CHECK-MAXBW-NEXT: [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 [[UMAX]], 4
; CHECK-MAXBW-NEXT: br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
; CHECK-MAXBW: vector.ph:
-; CHECK-MAXBW-NEXT: [[N_MOD_VF:%.*]] = urem i64 [[UMAX]], 2
+; CHECK-MAXBW-NEXT: [[N_MOD_VF:%.*]] = urem i64 [[UMAX]], 4
; CHECK-MAXBW-NEXT: [[N_VEC:%.*]] = sub i64 [[UMAX]], [[N_MOD_VF]]
; CHECK-MAXBW-NEXT: br label [[VECTOR_BODY:%.*]]
; CHECK-MAXBW: vector.body:
; CHECK-MAXBW-NEXT: [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
-; CHECK-MAXBW-NEXT: [[VEC_PHI:%.*]] = phi <2 x i64> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP2:%.*]], [[VECTOR_BODY]] ]
+; CHECK-MAXBW-NEXT: [[VEC_PHI:%.*]] = phi <4 x i64> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP2:%.*]], [[VECTOR_BODY]] ]
; CHECK-MAXBW-NEXT: [[TMP4:%.*]] = getelementptr inbounds i32, ptr [[ARR]], i64 [[INDEX]]
-; CHECK-MAXBW-NEXT: [[WIDE_LOAD:%.*]] = load <2 x i32>, ptr [[TMP4]], align 4
-; CHECK-MAXBW-NEXT: [[TMP1:%.*]] = sext <2 x i32> [[WIDE_LOAD]] to <2 x i64>
-; CHECK-MAXBW-NEXT: [[TMP2]] = add <2 x i64> [[VEC_PHI]], [[TMP1]]
-; CHECK-MAXBW-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 2
+; CHECK-MAXBW-NEXT: [[WIDE_LOAD:%.*]] = load <4 x i32>, ptr [[TMP4]], align 4
+; CHECK-MAXBW-NEXT: [[TMP1:%.*]] = sext <4 x i32> [[WIDE_LOAD]] to <4 x i64>
+; CHECK-MAXBW-NEXT: [[TMP2]] = add <4 x i64> [[VEC_PHI]], [[TMP1]]
+; CHECK-MAXBW-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 4
; CHECK-MAXBW-NEXT: [[TMP7:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
; CHECK-MAXBW-NEXT: br i1 [[TMP7]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP21:![0-9]+]]
; CHECK-MAXBW: middle.block:
-; CHECK-MAXBW-NEXT: [[TMP5:%.*]] = call i64 @llvm.vector.reduce.add.v2i64(<2 x i64> [[TMP2]])
+; CHECK-MAXBW-NEXT: [[TMP5:%.*]] = call i64 @llvm.vector.reduce.add.v4i64(<4 x i64> [[TMP2]])
; CHECK-MAXBW-NEXT: [[CMP_N:%.*]] = icmp eq i64 [[UMAX]], [[N_VEC]]
; CHECK-MAXBW-NEXT: br i1 [[CMP_N]], label [[EXIT:%.*]], label [[SCALAR_PH]]
; CHECK-MAXBW: scalar.ph: