[llvm] [LV] Support multiplies by constants when forming scaled reductions. (PR #161092)
Florian Hahn via llvm-commits
llvm-commits at lists.llvm.org
Thu Oct 2 03:15:19 PDT 2025
https://github.com/fhahn updated https://github.com/llvm/llvm-project/pull/161092
>From 13aeeb7cd5debc910c1c1af126a5ec43edad0ec6 Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Thu, 25 Sep 2025 14:09:34 +0100
Subject: [PATCH 1/2] [LV] Support multiplies by constants when forming scaled
reductions.
We can create partial reductions for multiplies by a constant if the
constant is small enough that extending it from the source type to the
destination type (using the same kind of extend as the other operand)
does not change its value.
This only handles a constant on the right-hand side of a multiply, relying
on other passes to canonicalize the input.
Alive2 Proofs: https://alive2.llvm.org/ce/z/iWRMr6
---
.../Transforms/Vectorize/LoopVectorize.cpp | 7 ++++
llvm/lib/Transforms/Vectorize/VPlan.cpp | 10 ++++++
llvm/lib/Transforms/Vectorize/VPlanHelpers.h | 4 +++
.../lib/Transforms/Vectorize/VPlanRecipes.cpp | 9 +++++
.../AArch64/partial-reduce-constant-ops.ll | 34 +++++++++----------
5 files changed, 47 insertions(+), 17 deletions(-)
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 12fb46da8e71a..b55685f735929 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -7937,6 +7937,13 @@ bool VPRecipeBuilder::getScaledReductions(
auto CollectExtInfo = [this, &Exts, &ExtOpTypes,
&ExtKinds](SmallVectorImpl<Value *> &Ops) -> bool {
for (const auto &[I, OpI] : enumerate(Ops)) {
+ auto *CI = dyn_cast<ConstantInt>(OpI);
+ if (I > 0 && CI &&
+ canConstantBeExtended(CI, ExtOpTypes[0], ExtKinds[0])) {
+ ExtOpTypes[I] = ExtOpTypes[0];
+ ExtKinds[I] = ExtKinds[0];
+ continue;
+ }
Value *ExtOp;
if (!match(OpI, m_ZExtOrSExt(m_Value(ExtOp))))
return false;
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.cpp b/llvm/lib/Transforms/Vectorize/VPlan.cpp
index 81f1956c96254..6273c066b00d4 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlan.cpp
@@ -1741,6 +1741,16 @@ void LoopVectorizationPlanner::printPlans(raw_ostream &O) {
}
#endif
+bool llvm::canConstantBeExtended(const ConstantInt *CI, Type *NarrowType,
+ TTI::PartialReductionExtendKind ExtKind) {
+ APInt TruncatedVal = CI->getValue().trunc(NarrowType->getScalarSizeInBits());
+ unsigned WideSize = CI->getType()->getScalarSizeInBits();
+ APInt ExtendedVal = ExtKind == TTI::PR_SignExtend
+ ? TruncatedVal.sext(WideSize)
+ : TruncatedVal.zext(WideSize);
+ return ExtendedVal == CI->getValue();
+}
+
TargetTransformInfo::OperandValueInfo
VPCostContext::getOperandInfo(VPValue *V) const {
if (!V->isLiveIn())
diff --git a/llvm/lib/Transforms/Vectorize/VPlanHelpers.h b/llvm/lib/Transforms/Vectorize/VPlanHelpers.h
index fe59774b7c838..fc1a09e9850f6 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanHelpers.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanHelpers.h
@@ -468,6 +468,10 @@ class VPlanPrinter {
};
#endif
+/// Check if a constant \p CI can be safely treated as having been extended
+/// from a narrower type with the given extension kind.
+bool canConstantBeExtended(const ConstantInt *CI, Type *NarrowType,
+ TTI::PartialReductionExtendKind ExtKind);
} // end namespace llvm
#endif // LLVM_TRANSFORMS_VECTORIZE_VPLAN_H
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 46909a53a9547..a9e8bd17d0ae1 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -340,6 +340,15 @@ VPPartialReductionRecipe::computeCost(ElementCount VF,
: Widen->getOperand(1));
ExtAType = GetExtendKind(ExtAR);
ExtBType = GetExtendKind(ExtBR);
+
+ if (!ExtBR && Widen->getOperand(1)->isLiveIn()) {
+ auto *CI =
+ dyn_cast<ConstantInt>(Widen->getOperand(1)->getLiveInIRValue());
+ if (CI && canConstantBeExtended(CI, InputTypeA, ExtAType)) {
+ InputTypeB = InputTypeA;
+ ExtBType = ExtAType;
+ }
+ }
};
if (isa<VPWidenCastRecipe>(OpR)) {
diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-constant-ops.ll b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-constant-ops.ll
index 0086f6e61cd36..b033f6051f812 100644
--- a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-constant-ops.ll
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-constant-ops.ll
@@ -20,22 +20,22 @@ define i32 @red_zext_mul_by_63(ptr %start, ptr %end) {
; CHECK-NEXT: br label %[[VECTOR_BODY:.*]]
; CHECK: [[VECTOR_BODY]]:
; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, %[[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ]
-; CHECK-NEXT: [[VEC_PHI:%.*]] = phi <16 x i32> [ zeroinitializer, %[[VECTOR_PH]] ], [ [[TMP5:%.*]], %[[VECTOR_BODY]] ]
+; CHECK-NEXT: [[VEC_PHI:%.*]] = phi <4 x i32> [ zeroinitializer, %[[VECTOR_PH]] ], [ [[PARTIAL_REDUCE:%.*]], %[[VECTOR_BODY]] ]
; CHECK-NEXT: [[NEXT_GEP:%.*]] = getelementptr i8, ptr [[START]], i64 [[INDEX]]
; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[NEXT_GEP]], align 1
; CHECK-NEXT: [[TMP3:%.*]] = zext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
; CHECK-NEXT: [[TMP4:%.*]] = mul <16 x i32> [[TMP3]], splat (i32 63)
-; CHECK-NEXT: [[TMP5]] = add <16 x i32> [[VEC_PHI]], [[TMP4]]
+; CHECK-NEXT: [[PARTIAL_REDUCE]] = call <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI]], <16 x i32> [[TMP4]])
; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 16
-; CHECK-NEXT: [[TMP6:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
-; CHECK-NEXT: br i1 [[TMP6]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
+; CHECK-NEXT: [[TMP5:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
+; CHECK-NEXT: br i1 [[TMP5]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
; CHECK: [[MIDDLE_BLOCK]]:
-; CHECK-NEXT: [[TMP7:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP5]])
+; CHECK-NEXT: [[TMP6:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[PARTIAL_REDUCE]])
; CHECK-NEXT: [[CMP_N:%.*]] = icmp eq i64 [[TMP1]], [[N_VEC]]
; CHECK-NEXT: br i1 [[CMP_N]], label %[[EXIT:.*]], label %[[SCALAR_PH]]
; CHECK: [[SCALAR_PH]]:
; CHECK-NEXT: [[BC_RESUME_VAL:%.*]] = phi ptr [ [[TMP2]], %[[MIDDLE_BLOCK]] ], [ [[START]], %[[ENTRY]] ]
-; CHECK-NEXT: [[BC_MERGE_RDX:%.*]] = phi i32 [ [[TMP7]], %[[MIDDLE_BLOCK]] ], [ 0, %[[ENTRY]] ]
+; CHECK-NEXT: [[BC_MERGE_RDX:%.*]] = phi i32 [ [[TMP6]], %[[MIDDLE_BLOCK]] ], [ 0, %[[ENTRY]] ]
; CHECK-NEXT: br label %[[LOOP:.*]]
; CHECK: [[LOOP]]:
; CHECK-NEXT: [[PTR_IV:%.*]] = phi ptr [ [[BC_RESUME_VAL]], %[[SCALAR_PH]] ], [ [[GEP_IV_NEXT:%.*]], %[[LOOP]] ]
@@ -48,7 +48,7 @@ define i32 @red_zext_mul_by_63(ptr %start, ptr %end) {
; CHECK-NEXT: [[EC:%.*]] = icmp eq ptr [[PTR_IV]], [[END]]
; CHECK-NEXT: br i1 [[EC]], label %[[EXIT]], label %[[LOOP]], !llvm.loop [[LOOP3:![0-9]+]]
; CHECK: [[EXIT]]:
-; CHECK-NEXT: [[RED_NEXT_LCSSA:%.*]] = phi i32 [ [[RED_NEXT]], %[[LOOP]] ], [ [[TMP7]], %[[MIDDLE_BLOCK]] ]
+; CHECK-NEXT: [[RED_NEXT_LCSSA:%.*]] = phi i32 [ [[RED_NEXT]], %[[LOOP]] ], [ [[TMP6]], %[[MIDDLE_BLOCK]] ]
; CHECK-NEXT: ret i32 [[RED_NEXT_LCSSA]]
;
entry:
@@ -86,17 +86,17 @@ define i32 @red_zext_mul_by_255(ptr %start, ptr %end) {
; CHECK-NEXT: br label %[[VECTOR_BODY:.*]]
; CHECK: [[VECTOR_BODY]]:
; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, %[[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ]
-; CHECK-NEXT: [[VEC_PHI:%.*]] = phi <16 x i32> [ zeroinitializer, %[[VECTOR_PH]] ], [ [[TMP5:%.*]], %[[VECTOR_BODY]] ]
+; CHECK-NEXT: [[VEC_PHI:%.*]] = phi <4 x i32> [ zeroinitializer, %[[VECTOR_PH]] ], [ [[PARTIAL_REDUCE:%.*]], %[[VECTOR_BODY]] ]
; CHECK-NEXT: [[NEXT_GEP:%.*]] = getelementptr i8, ptr [[START]], i64 [[INDEX]]
; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[NEXT_GEP]], align 1
; CHECK-NEXT: [[TMP3:%.*]] = zext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
; CHECK-NEXT: [[TMP4:%.*]] = mul <16 x i32> [[TMP3]], splat (i32 255)
-; CHECK-NEXT: [[TMP5]] = add <16 x i32> [[VEC_PHI]], [[TMP4]]
+; CHECK-NEXT: [[PARTIAL_REDUCE]] = call <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI]], <16 x i32> [[TMP4]])
; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 16
; CHECK-NEXT: [[TMP6:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
; CHECK-NEXT: br i1 [[TMP6]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP4:![0-9]+]]
; CHECK: [[MIDDLE_BLOCK]]:
-; CHECK-NEXT: [[TMP7:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP5]])
+; CHECK-NEXT: [[TMP7:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[PARTIAL_REDUCE]])
; CHECK-NEXT: [[CMP_N:%.*]] = icmp eq i64 [[TMP1]], [[N_VEC]]
; CHECK-NEXT: br i1 [[CMP_N]], label %[[EXIT:.*]], label %[[SCALAR_PH]]
; CHECK: [[SCALAR_PH]]:
@@ -218,22 +218,22 @@ define i32 @red_sext_mul_by_63(ptr %start, ptr %end) {
; CHECK-NEXT: br label %[[VECTOR_BODY:.*]]
; CHECK: [[VECTOR_BODY]]:
; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, %[[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ]
-; CHECK-NEXT: [[VEC_PHI:%.*]] = phi <16 x i32> [ zeroinitializer, %[[VECTOR_PH]] ], [ [[TMP5:%.*]], %[[VECTOR_BODY]] ]
+; CHECK-NEXT: [[VEC_PHI:%.*]] = phi <4 x i32> [ zeroinitializer, %[[VECTOR_PH]] ], [ [[PARTIAL_REDUCE:%.*]], %[[VECTOR_BODY]] ]
; CHECK-NEXT: [[NEXT_GEP:%.*]] = getelementptr i8, ptr [[START]], i64 [[INDEX]]
; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[NEXT_GEP]], align 1
; CHECK-NEXT: [[TMP3:%.*]] = sext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
; CHECK-NEXT: [[TMP4:%.*]] = mul <16 x i32> [[TMP3]], splat (i32 63)
-; CHECK-NEXT: [[TMP5]] = add <16 x i32> [[VEC_PHI]], [[TMP4]]
+; CHECK-NEXT: [[PARTIAL_REDUCE]] = call <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI]], <16 x i32> [[TMP4]])
; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 16
-; CHECK-NEXT: [[TMP6:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
-; CHECK-NEXT: br i1 [[TMP6]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP8:![0-9]+]]
+; CHECK-NEXT: [[TMP5:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
+; CHECK-NEXT: br i1 [[TMP5]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP8:![0-9]+]]
; CHECK: [[MIDDLE_BLOCK]]:
-; CHECK-NEXT: [[TMP7:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP5]])
+; CHECK-NEXT: [[TMP6:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[PARTIAL_REDUCE]])
; CHECK-NEXT: [[CMP_N:%.*]] = icmp eq i64 [[TMP1]], [[N_VEC]]
; CHECK-NEXT: br i1 [[CMP_N]], label %[[EXIT:.*]], label %[[SCALAR_PH]]
; CHECK: [[SCALAR_PH]]:
; CHECK-NEXT: [[BC_RESUME_VAL:%.*]] = phi ptr [ [[TMP2]], %[[MIDDLE_BLOCK]] ], [ [[START]], %[[ENTRY]] ]
-; CHECK-NEXT: [[BC_MERGE_RDX:%.*]] = phi i32 [ [[TMP7]], %[[MIDDLE_BLOCK]] ], [ 0, %[[ENTRY]] ]
+; CHECK-NEXT: [[BC_MERGE_RDX:%.*]] = phi i32 [ [[TMP6]], %[[MIDDLE_BLOCK]] ], [ 0, %[[ENTRY]] ]
; CHECK-NEXT: br label %[[LOOP:.*]]
; CHECK: [[LOOP]]:
; CHECK-NEXT: [[PTR_IV:%.*]] = phi ptr [ [[BC_RESUME_VAL]], %[[SCALAR_PH]] ], [ [[GEP_IV_NEXT:%.*]], %[[LOOP]] ]
@@ -246,7 +246,7 @@ define i32 @red_sext_mul_by_63(ptr %start, ptr %end) {
; CHECK-NEXT: [[EC:%.*]] = icmp eq ptr [[PTR_IV]], [[END]]
; CHECK-NEXT: br i1 [[EC]], label %[[EXIT]], label %[[LOOP]], !llvm.loop [[LOOP9:![0-9]+]]
; CHECK: [[EXIT]]:
-; CHECK-NEXT: [[RED_NEXT_LCSSA:%.*]] = phi i32 [ [[RED_NEXT]], %[[LOOP]] ], [ [[TMP7]], %[[MIDDLE_BLOCK]] ]
+; CHECK-NEXT: [[RED_NEXT_LCSSA:%.*]] = phi i32 [ [[RED_NEXT]], %[[LOOP]] ], [ [[TMP6]], %[[MIDDLE_BLOCK]] ]
; CHECK-NEXT: ret i32 [[RED_NEXT_LCSSA]]
;
entry:
>From a30869fbad5ca4216b5234e3d725dfe1e6727dfe Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Thu, 2 Oct 2025 09:31:30 +0100
Subject: [PATCH 2/2] !fixup dyn_cast->cast, thanks!
---
llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp | 5 ++---
1 file changed, 2 insertions(+), 3 deletions(-)
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index a9e8bd17d0ae1..67b9244e9dc72 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -342,9 +342,8 @@ VPPartialReductionRecipe::computeCost(ElementCount VF,
ExtBType = GetExtendKind(ExtBR);
if (!ExtBR && Widen->getOperand(1)->isLiveIn()) {
- auto *CI =
- dyn_cast<ConstantInt>(Widen->getOperand(1)->getLiveInIRValue());
- if (CI && canConstantBeExtended(CI, InputTypeA, ExtAType)) {
+ auto *CI = cast<ConstantInt>(Widen->getOperand(1)->getLiveInIRValue());
+ if (canConstantBeExtended(CI, InputTypeA, ExtAType)) {
InputTypeB = InputTypeA;
ExtBType = ExtAType;
}
More information about the llvm-commits
mailing list