[llvm] [LV] Support multiplies by constants when forming scaled reductions. (PR #161092)
Florian Hahn via llvm-commits
llvm-commits at lists.llvm.org
Sun Sep 28 12:43:25 PDT 2025
https://github.com/fhahn created https://github.com/llvm/llvm-project/pull/161092
We can create partial reductions for multiplies by constants if the constant is small enough to be extended from the source type to the destination type without changing its value.
This only handles a constant on the right-hand side of a multiply, relying on other passes to canonicalize the input.
Alive2 Proofs: https://alive2.llvm.org/ce/z/iWRMr6
From f29245a8aa7b8aacbcf6c627f22ea46247a2a1cc Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Thu, 25 Sep 2025 14:09:34 +0100
Subject: [PATCH] [LV] Support multiplies by constants when forming scaled
reductions.
We can create partial reductions for multiplies by constants if the
constant is small enough to be extended from the source type to the
destination type without changing its value.
This only handles a constant on the right-hand side of a multiply,
relying on other passes to canonicalize the input.
Alive2 Proofs: https://alive2.llvm.org/ce/z/iWRMr6
---
.../Transforms/Vectorize/LoopVectorize.cpp | 7 ++++
llvm/lib/Transforms/Vectorize/VPlan.cpp | 10 ++++++
llvm/lib/Transforms/Vectorize/VPlanHelpers.h | 4 +++
.../lib/Transforms/Vectorize/VPlanRecipes.cpp | 9 +++++
.../AArch64/partial-reduce-constant-ops.ll | 34 +++++++++----------
5 files changed, 47 insertions(+), 17 deletions(-)
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 96f52076b1837..7e0249259c11a 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -7935,6 +7935,13 @@ bool VPRecipeBuilder::getScaledReductions(
auto CollectExtInfo = [this, &Exts, &ExtOpTypes,
&ExtKinds](SmallVectorImpl<Value *> &Ops) -> bool {
for (const auto &[I, OpI] : enumerate(Ops)) {
+ auto *CI = dyn_cast<ConstantInt>(OpI);
+ if (I > 0 && CI &&
+ canConstantBeExtended(CI, ExtOpTypes[0], ExtKinds[0])) {
+ ExtOpTypes[I] = ExtOpTypes[0];
+ ExtKinds[I] = ExtKinds[0];
+ continue;
+ }
Value *ExtOp;
if (!match(OpI, m_ZExtOrSExt(m_Value(ExtOp))))
return false;
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.cpp b/llvm/lib/Transforms/Vectorize/VPlan.cpp
index 81f1956c96254..6273c066b00d4 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlan.cpp
@@ -1741,6 +1741,16 @@ void LoopVectorizationPlanner::printPlans(raw_ostream &O) {
}
#endif
+bool llvm::canConstantBeExtended(const ConstantInt *CI, Type *NarrowType,
+ TTI::PartialReductionExtendKind ExtKind) {
+ APInt TruncatedVal = CI->getValue().trunc(NarrowType->getScalarSizeInBits());
+ unsigned WideSize = CI->getType()->getScalarSizeInBits();
+ APInt ExtendedVal = ExtKind == TTI::PR_SignExtend
+ ? TruncatedVal.sext(WideSize)
+ : TruncatedVal.zext(WideSize);
+ return ExtendedVal == CI->getValue();
+}
+
TargetTransformInfo::OperandValueInfo
VPCostContext::getOperandInfo(VPValue *V) const {
if (!V->isLiveIn())
diff --git a/llvm/lib/Transforms/Vectorize/VPlanHelpers.h b/llvm/lib/Transforms/Vectorize/VPlanHelpers.h
index fe59774b7c838..fc1a09e9850f6 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanHelpers.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanHelpers.h
@@ -468,6 +468,10 @@ class VPlanPrinter {
};
#endif
+/// Check if a constant \p CI can be safely treated as having been extended
+/// from a narrower type with the given extension kind.
+bool canConstantBeExtended(const ConstantInt *CI, Type *NarrowType,
+ TTI::PartialReductionExtendKind ExtKind);
} // end namespace llvm
#endif // LLVM_TRANSFORMS_VECTORIZE_VPLAN_H
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index cf5e6bf0ac418..c5f3f5bcc6042 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -340,6 +340,15 @@ VPPartialReductionRecipe::computeCost(ElementCount VF,
: Widen->getOperand(1));
ExtAType = GetExtendKind(ExtAR);
ExtBType = GetExtendKind(ExtBR);
+
+ if (!ExtBR && Widen->getOperand(1)->isLiveIn()) {
+ auto *CI =
+ dyn_cast<ConstantInt>(Widen->getOperand(1)->getLiveInIRValue());
+ if (CI && canConstantBeExtended(CI, InputTypeA, ExtAType)) {
+ InputTypeB = InputTypeA;
+ ExtBType = ExtAType;
+ }
+ }
};
if (isa<VPWidenCastRecipe>(OpR)) {
diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-constant-ops.ll b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-constant-ops.ll
index 0086f6e61cd36..b033f6051f812 100644
--- a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-constant-ops.ll
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-constant-ops.ll
@@ -20,22 +20,22 @@ define i32 @red_zext_mul_by_63(ptr %start, ptr %end) {
; CHECK-NEXT: br label %[[VECTOR_BODY:.*]]
; CHECK: [[VECTOR_BODY]]:
; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, %[[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ]
-; CHECK-NEXT: [[VEC_PHI:%.*]] = phi <16 x i32> [ zeroinitializer, %[[VECTOR_PH]] ], [ [[TMP5:%.*]], %[[VECTOR_BODY]] ]
+; CHECK-NEXT: [[VEC_PHI:%.*]] = phi <4 x i32> [ zeroinitializer, %[[VECTOR_PH]] ], [ [[PARTIAL_REDUCE:%.*]], %[[VECTOR_BODY]] ]
; CHECK-NEXT: [[NEXT_GEP:%.*]] = getelementptr i8, ptr [[START]], i64 [[INDEX]]
; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[NEXT_GEP]], align 1
; CHECK-NEXT: [[TMP3:%.*]] = zext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
; CHECK-NEXT: [[TMP4:%.*]] = mul <16 x i32> [[TMP3]], splat (i32 63)
-; CHECK-NEXT: [[TMP5]] = add <16 x i32> [[VEC_PHI]], [[TMP4]]
+; CHECK-NEXT: [[PARTIAL_REDUCE]] = call <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI]], <16 x i32> [[TMP4]])
; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 16
-; CHECK-NEXT: [[TMP6:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
-; CHECK-NEXT: br i1 [[TMP6]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
+; CHECK-NEXT: [[TMP5:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
+; CHECK-NEXT: br i1 [[TMP5]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
; CHECK: [[MIDDLE_BLOCK]]:
-; CHECK-NEXT: [[TMP7:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP5]])
+; CHECK-NEXT: [[TMP6:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[PARTIAL_REDUCE]])
; CHECK-NEXT: [[CMP_N:%.*]] = icmp eq i64 [[TMP1]], [[N_VEC]]
; CHECK-NEXT: br i1 [[CMP_N]], label %[[EXIT:.*]], label %[[SCALAR_PH]]
; CHECK: [[SCALAR_PH]]:
; CHECK-NEXT: [[BC_RESUME_VAL:%.*]] = phi ptr [ [[TMP2]], %[[MIDDLE_BLOCK]] ], [ [[START]], %[[ENTRY]] ]
-; CHECK-NEXT: [[BC_MERGE_RDX:%.*]] = phi i32 [ [[TMP7]], %[[MIDDLE_BLOCK]] ], [ 0, %[[ENTRY]] ]
+; CHECK-NEXT: [[BC_MERGE_RDX:%.*]] = phi i32 [ [[TMP6]], %[[MIDDLE_BLOCK]] ], [ 0, %[[ENTRY]] ]
; CHECK-NEXT: br label %[[LOOP:.*]]
; CHECK: [[LOOP]]:
; CHECK-NEXT: [[PTR_IV:%.*]] = phi ptr [ [[BC_RESUME_VAL]], %[[SCALAR_PH]] ], [ [[GEP_IV_NEXT:%.*]], %[[LOOP]] ]
@@ -48,7 +48,7 @@ define i32 @red_zext_mul_by_63(ptr %start, ptr %end) {
; CHECK-NEXT: [[EC:%.*]] = icmp eq ptr [[PTR_IV]], [[END]]
; CHECK-NEXT: br i1 [[EC]], label %[[EXIT]], label %[[LOOP]], !llvm.loop [[LOOP3:![0-9]+]]
; CHECK: [[EXIT]]:
-; CHECK-NEXT: [[RED_NEXT_LCSSA:%.*]] = phi i32 [ [[RED_NEXT]], %[[LOOP]] ], [ [[TMP7]], %[[MIDDLE_BLOCK]] ]
+; CHECK-NEXT: [[RED_NEXT_LCSSA:%.*]] = phi i32 [ [[RED_NEXT]], %[[LOOP]] ], [ [[TMP6]], %[[MIDDLE_BLOCK]] ]
; CHECK-NEXT: ret i32 [[RED_NEXT_LCSSA]]
;
entry:
@@ -86,17 +86,17 @@ define i32 @red_zext_mul_by_255(ptr %start, ptr %end) {
; CHECK-NEXT: br label %[[VECTOR_BODY:.*]]
; CHECK: [[VECTOR_BODY]]:
; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, %[[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ]
-; CHECK-NEXT: [[VEC_PHI:%.*]] = phi <16 x i32> [ zeroinitializer, %[[VECTOR_PH]] ], [ [[TMP5:%.*]], %[[VECTOR_BODY]] ]
+; CHECK-NEXT: [[VEC_PHI:%.*]] = phi <4 x i32> [ zeroinitializer, %[[VECTOR_PH]] ], [ [[PARTIAL_REDUCE:%.*]], %[[VECTOR_BODY]] ]
; CHECK-NEXT: [[NEXT_GEP:%.*]] = getelementptr i8, ptr [[START]], i64 [[INDEX]]
; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[NEXT_GEP]], align 1
; CHECK-NEXT: [[TMP3:%.*]] = zext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
; CHECK-NEXT: [[TMP4:%.*]] = mul <16 x i32> [[TMP3]], splat (i32 255)
-; CHECK-NEXT: [[TMP5]] = add <16 x i32> [[VEC_PHI]], [[TMP4]]
+; CHECK-NEXT: [[PARTIAL_REDUCE]] = call <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI]], <16 x i32> [[TMP4]])
; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 16
; CHECK-NEXT: [[TMP6:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
; CHECK-NEXT: br i1 [[TMP6]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP4:![0-9]+]]
; CHECK: [[MIDDLE_BLOCK]]:
-; CHECK-NEXT: [[TMP7:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP5]])
+; CHECK-NEXT: [[TMP7:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[PARTIAL_REDUCE]])
; CHECK-NEXT: [[CMP_N:%.*]] = icmp eq i64 [[TMP1]], [[N_VEC]]
; CHECK-NEXT: br i1 [[CMP_N]], label %[[EXIT:.*]], label %[[SCALAR_PH]]
; CHECK: [[SCALAR_PH]]:
@@ -218,22 +218,22 @@ define i32 @red_sext_mul_by_63(ptr %start, ptr %end) {
; CHECK-NEXT: br label %[[VECTOR_BODY:.*]]
; CHECK: [[VECTOR_BODY]]:
; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, %[[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ]
-; CHECK-NEXT: [[VEC_PHI:%.*]] = phi <16 x i32> [ zeroinitializer, %[[VECTOR_PH]] ], [ [[TMP5:%.*]], %[[VECTOR_BODY]] ]
+; CHECK-NEXT: [[VEC_PHI:%.*]] = phi <4 x i32> [ zeroinitializer, %[[VECTOR_PH]] ], [ [[PARTIAL_REDUCE:%.*]], %[[VECTOR_BODY]] ]
; CHECK-NEXT: [[NEXT_GEP:%.*]] = getelementptr i8, ptr [[START]], i64 [[INDEX]]
; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[NEXT_GEP]], align 1
; CHECK-NEXT: [[TMP3:%.*]] = sext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
; CHECK-NEXT: [[TMP4:%.*]] = mul <16 x i32> [[TMP3]], splat (i32 63)
-; CHECK-NEXT: [[TMP5]] = add <16 x i32> [[VEC_PHI]], [[TMP4]]
+; CHECK-NEXT: [[PARTIAL_REDUCE]] = call <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI]], <16 x i32> [[TMP4]])
; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 16
-; CHECK-NEXT: [[TMP6:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
-; CHECK-NEXT: br i1 [[TMP6]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP8:![0-9]+]]
+; CHECK-NEXT: [[TMP5:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
+; CHECK-NEXT: br i1 [[TMP5]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP8:![0-9]+]]
; CHECK: [[MIDDLE_BLOCK]]:
-; CHECK-NEXT: [[TMP7:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP5]])
+; CHECK-NEXT: [[TMP6:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[PARTIAL_REDUCE]])
; CHECK-NEXT: [[CMP_N:%.*]] = icmp eq i64 [[TMP1]], [[N_VEC]]
; CHECK-NEXT: br i1 [[CMP_N]], label %[[EXIT:.*]], label %[[SCALAR_PH]]
; CHECK: [[SCALAR_PH]]:
; CHECK-NEXT: [[BC_RESUME_VAL:%.*]] = phi ptr [ [[TMP2]], %[[MIDDLE_BLOCK]] ], [ [[START]], %[[ENTRY]] ]
-; CHECK-NEXT: [[BC_MERGE_RDX:%.*]] = phi i32 [ [[TMP7]], %[[MIDDLE_BLOCK]] ], [ 0, %[[ENTRY]] ]
+; CHECK-NEXT: [[BC_MERGE_RDX:%.*]] = phi i32 [ [[TMP6]], %[[MIDDLE_BLOCK]] ], [ 0, %[[ENTRY]] ]
; CHECK-NEXT: br label %[[LOOP:.*]]
; CHECK: [[LOOP]]:
; CHECK-NEXT: [[PTR_IV:%.*]] = phi ptr [ [[BC_RESUME_VAL]], %[[SCALAR_PH]] ], [ [[GEP_IV_NEXT:%.*]], %[[LOOP]] ]
@@ -246,7 +246,7 @@ define i32 @red_sext_mul_by_63(ptr %start, ptr %end) {
; CHECK-NEXT: [[EC:%.*]] = icmp eq ptr [[PTR_IV]], [[END]]
; CHECK-NEXT: br i1 [[EC]], label %[[EXIT]], label %[[LOOP]], !llvm.loop [[LOOP9:![0-9]+]]
; CHECK: [[EXIT]]:
-; CHECK-NEXT: [[RED_NEXT_LCSSA:%.*]] = phi i32 [ [[RED_NEXT]], %[[LOOP]] ], [ [[TMP7]], %[[MIDDLE_BLOCK]] ]
+; CHECK-NEXT: [[RED_NEXT_LCSSA:%.*]] = phi i32 [ [[RED_NEXT]], %[[LOOP]] ], [ [[TMP6]], %[[MIDDLE_BLOCK]] ]
; CHECK-NEXT: ret i32 [[RED_NEXT_LCSSA]]
;
entry:
More information about the llvm-commits
mailing list