[llvm] 7c4f188 - [LV] Support multiplies by constants when forming scaled reductions. (#161092)
via llvm-commits
llvm-commits at lists.llvm.org
Thu Oct 2 03:53:21 PDT 2025
Author: Florian Hahn
Date: 2025-10-02T10:53:17Z
New Revision: 7c4f188f27ee7c562c2aa11b2384fd0ef918be94
URL: https://github.com/llvm/llvm-project/commit/7c4f188f27ee7c562c2aa11b2384fd0ef918be94
DIFF: https://github.com/llvm/llvm-project/commit/7c4f188f27ee7c562c2aa11b2384fd0ef918be94.diff
LOG: [LV] Support multiplies by constants when forming scaled reductions. (#161092)
We can create partial reductions for multiplies by constants if the
constant is small enough to be extended from the source type to the
destination type without changing its value.
This only handles a constant on the right-hand side of a multiply, relying
on other passes to canonicalize the input.
Alive2 Proofs: https://alive2.llvm.org/ce/z/iWRMr6
PR: https://github.com/llvm/llvm-project/pull/161092
Added:
Modified:
llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
llvm/lib/Transforms/Vectorize/VPlan.cpp
llvm/lib/Transforms/Vectorize/VPlanHelpers.h
llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-constant-ops.ll
Removed:
################################################################################
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index e5d6c8118eb55..7fa787bc9befd 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -7954,6 +7954,13 @@ bool VPRecipeBuilder::getScaledReductions(
auto CollectExtInfo = [this, &Exts, &ExtOpTypes,
&ExtKinds](SmallVectorImpl<Value *> &Ops) -> bool {
for (const auto &[I, OpI] : enumerate(Ops)) {
+ auto *CI = dyn_cast<ConstantInt>(OpI);
+ if (I > 0 && CI &&
+ canConstantBeExtended(CI, ExtOpTypes[0], ExtKinds[0])) {
+ ExtOpTypes[I] = ExtOpTypes[0];
+ ExtKinds[I] = ExtKinds[0];
+ continue;
+ }
Value *ExtOp;
if (!match(OpI, m_ZExtOrSExt(m_Value(ExtOp))))
return false;
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.cpp b/llvm/lib/Transforms/Vectorize/VPlan.cpp
index 02eb6375aac41..07b191a787806 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlan.cpp
@@ -1753,6 +1753,16 @@ void LoopVectorizationPlanner::printPlans(raw_ostream &O) {
}
#endif
+bool llvm::canConstantBeExtended(const ConstantInt *CI, Type *NarrowType,
+ TTI::PartialReductionExtendKind ExtKind) {
+ APInt TruncatedVal = CI->getValue().trunc(NarrowType->getScalarSizeInBits());
+ unsigned WideSize = CI->getType()->getScalarSizeInBits();
+ APInt ExtendedVal = ExtKind == TTI::PR_SignExtend
+ ? TruncatedVal.sext(WideSize)
+ : TruncatedVal.zext(WideSize);
+ return ExtendedVal == CI->getValue();
+}
+
TargetTransformInfo::OperandValueInfo
VPCostContext::getOperandInfo(VPValue *V) const {
if (!V->isLiveIn())
diff --git a/llvm/lib/Transforms/Vectorize/VPlanHelpers.h b/llvm/lib/Transforms/Vectorize/VPlanHelpers.h
index fe59774b7c838..fc1a09e9850f6 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanHelpers.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanHelpers.h
@@ -468,6 +468,10 @@ class VPlanPrinter {
};
#endif
+/// Check if a constant \p CI can be safely treated as having been extended
+/// from a narrower type with the given extension kind.
+bool canConstantBeExtended(const ConstantInt *CI, Type *NarrowType,
+ TTI::PartialReductionExtendKind ExtKind);
} // end namespace llvm
#endif // LLVM_TRANSFORMS_VECTORIZE_VPLAN_H
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 46909a53a9547..67b9244e9dc72 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -340,6 +340,14 @@ VPPartialReductionRecipe::computeCost(ElementCount VF,
: Widen->getOperand(1));
ExtAType = GetExtendKind(ExtAR);
ExtBType = GetExtendKind(ExtBR);
+
+ if (!ExtBR && Widen->getOperand(1)->isLiveIn()) {
+ auto *CI = cast<ConstantInt>(Widen->getOperand(1)->getLiveInIRValue());
+ if (canConstantBeExtended(CI, InputTypeA, ExtAType)) {
+ InputTypeB = InputTypeA;
+ ExtBType = ExtAType;
+ }
+ }
};
if (isa<VPWidenCastRecipe>(OpR)) {
diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-constant-ops.ll b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-constant-ops.ll
index 0086f6e61cd36..b033f6051f812 100644
--- a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-constant-ops.ll
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-constant-ops.ll
@@ -20,22 +20,22 @@ define i32 @red_zext_mul_by_63(ptr %start, ptr %end) {
; CHECK-NEXT: br label %[[VECTOR_BODY:.*]]
; CHECK: [[VECTOR_BODY]]:
; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, %[[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ]
-; CHECK-NEXT: [[VEC_PHI:%.*]] = phi <16 x i32> [ zeroinitializer, %[[VECTOR_PH]] ], [ [[TMP5:%.*]], %[[VECTOR_BODY]] ]
+; CHECK-NEXT: [[VEC_PHI:%.*]] = phi <4 x i32> [ zeroinitializer, %[[VECTOR_PH]] ], [ [[PARTIAL_REDUCE:%.*]], %[[VECTOR_BODY]] ]
; CHECK-NEXT: [[NEXT_GEP:%.*]] = getelementptr i8, ptr [[START]], i64 [[INDEX]]
; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[NEXT_GEP]], align 1
; CHECK-NEXT: [[TMP3:%.*]] = zext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
; CHECK-NEXT: [[TMP4:%.*]] = mul <16 x i32> [[TMP3]], splat (i32 63)
-; CHECK-NEXT: [[TMP5]] = add <16 x i32> [[VEC_PHI]], [[TMP4]]
+; CHECK-NEXT: [[PARTIAL_REDUCE]] = call <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI]], <16 x i32> [[TMP4]])
; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 16
-; CHECK-NEXT: [[TMP6:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
-; CHECK-NEXT: br i1 [[TMP6]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
+; CHECK-NEXT: [[TMP5:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
+; CHECK-NEXT: br i1 [[TMP5]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
; CHECK: [[MIDDLE_BLOCK]]:
-; CHECK-NEXT: [[TMP7:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP5]])
+; CHECK-NEXT: [[TMP6:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[PARTIAL_REDUCE]])
; CHECK-NEXT: [[CMP_N:%.*]] = icmp eq i64 [[TMP1]], [[N_VEC]]
; CHECK-NEXT: br i1 [[CMP_N]], label %[[EXIT:.*]], label %[[SCALAR_PH]]
; CHECK: [[SCALAR_PH]]:
; CHECK-NEXT: [[BC_RESUME_VAL:%.*]] = phi ptr [ [[TMP2]], %[[MIDDLE_BLOCK]] ], [ [[START]], %[[ENTRY]] ]
-; CHECK-NEXT: [[BC_MERGE_RDX:%.*]] = phi i32 [ [[TMP7]], %[[MIDDLE_BLOCK]] ], [ 0, %[[ENTRY]] ]
+; CHECK-NEXT: [[BC_MERGE_RDX:%.*]] = phi i32 [ [[TMP6]], %[[MIDDLE_BLOCK]] ], [ 0, %[[ENTRY]] ]
; CHECK-NEXT: br label %[[LOOP:.*]]
; CHECK: [[LOOP]]:
; CHECK-NEXT: [[PTR_IV:%.*]] = phi ptr [ [[BC_RESUME_VAL]], %[[SCALAR_PH]] ], [ [[GEP_IV_NEXT:%.*]], %[[LOOP]] ]
@@ -48,7 +48,7 @@ define i32 @red_zext_mul_by_63(ptr %start, ptr %end) {
; CHECK-NEXT: [[EC:%.*]] = icmp eq ptr [[PTR_IV]], [[END]]
; CHECK-NEXT: br i1 [[EC]], label %[[EXIT]], label %[[LOOP]], !llvm.loop [[LOOP3:![0-9]+]]
; CHECK: [[EXIT]]:
-; CHECK-NEXT: [[RED_NEXT_LCSSA:%.*]] = phi i32 [ [[RED_NEXT]], %[[LOOP]] ], [ [[TMP7]], %[[MIDDLE_BLOCK]] ]
+; CHECK-NEXT: [[RED_NEXT_LCSSA:%.*]] = phi i32 [ [[RED_NEXT]], %[[LOOP]] ], [ [[TMP6]], %[[MIDDLE_BLOCK]] ]
; CHECK-NEXT: ret i32 [[RED_NEXT_LCSSA]]
;
entry:
@@ -86,17 +86,17 @@ define i32 @red_zext_mul_by_255(ptr %start, ptr %end) {
; CHECK-NEXT: br label %[[VECTOR_BODY:.*]]
; CHECK: [[VECTOR_BODY]]:
; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, %[[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ]
-; CHECK-NEXT: [[VEC_PHI:%.*]] = phi <16 x i32> [ zeroinitializer, %[[VECTOR_PH]] ], [ [[TMP5:%.*]], %[[VECTOR_BODY]] ]
+; CHECK-NEXT: [[VEC_PHI:%.*]] = phi <4 x i32> [ zeroinitializer, %[[VECTOR_PH]] ], [ [[PARTIAL_REDUCE:%.*]], %[[VECTOR_BODY]] ]
; CHECK-NEXT: [[NEXT_GEP:%.*]] = getelementptr i8, ptr [[START]], i64 [[INDEX]]
; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[NEXT_GEP]], align 1
; CHECK-NEXT: [[TMP3:%.*]] = zext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
; CHECK-NEXT: [[TMP4:%.*]] = mul <16 x i32> [[TMP3]], splat (i32 255)
-; CHECK-NEXT: [[TMP5]] = add <16 x i32> [[VEC_PHI]], [[TMP4]]
+; CHECK-NEXT: [[PARTIAL_REDUCE]] = call <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI]], <16 x i32> [[TMP4]])
; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 16
; CHECK-NEXT: [[TMP6:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
; CHECK-NEXT: br i1 [[TMP6]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP4:![0-9]+]]
; CHECK: [[MIDDLE_BLOCK]]:
-; CHECK-NEXT: [[TMP7:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP5]])
+; CHECK-NEXT: [[TMP7:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[PARTIAL_REDUCE]])
; CHECK-NEXT: [[CMP_N:%.*]] = icmp eq i64 [[TMP1]], [[N_VEC]]
; CHECK-NEXT: br i1 [[CMP_N]], label %[[EXIT:.*]], label %[[SCALAR_PH]]
; CHECK: [[SCALAR_PH]]:
@@ -218,22 +218,22 @@ define i32 @red_sext_mul_by_63(ptr %start, ptr %end) {
; CHECK-NEXT: br label %[[VECTOR_BODY:.*]]
; CHECK: [[VECTOR_BODY]]:
; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, %[[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ]
-; CHECK-NEXT: [[VEC_PHI:%.*]] = phi <16 x i32> [ zeroinitializer, %[[VECTOR_PH]] ], [ [[TMP5:%.*]], %[[VECTOR_BODY]] ]
+; CHECK-NEXT: [[VEC_PHI:%.*]] = phi <4 x i32> [ zeroinitializer, %[[VECTOR_PH]] ], [ [[PARTIAL_REDUCE:%.*]], %[[VECTOR_BODY]] ]
; CHECK-NEXT: [[NEXT_GEP:%.*]] = getelementptr i8, ptr [[START]], i64 [[INDEX]]
; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[NEXT_GEP]], align 1
; CHECK-NEXT: [[TMP3:%.*]] = sext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
; CHECK-NEXT: [[TMP4:%.*]] = mul <16 x i32> [[TMP3]], splat (i32 63)
-; CHECK-NEXT: [[TMP5]] = add <16 x i32> [[VEC_PHI]], [[TMP4]]
+; CHECK-NEXT: [[PARTIAL_REDUCE]] = call <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI]], <16 x i32> [[TMP4]])
; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 16
-; CHECK-NEXT: [[TMP6:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
-; CHECK-NEXT: br i1 [[TMP6]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP8:![0-9]+]]
+; CHECK-NEXT: [[TMP5:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
+; CHECK-NEXT: br i1 [[TMP5]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP8:![0-9]+]]
; CHECK: [[MIDDLE_BLOCK]]:
-; CHECK-NEXT: [[TMP7:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP5]])
+; CHECK-NEXT: [[TMP6:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[PARTIAL_REDUCE]])
; CHECK-NEXT: [[CMP_N:%.*]] = icmp eq i64 [[TMP1]], [[N_VEC]]
; CHECK-NEXT: br i1 [[CMP_N]], label %[[EXIT:.*]], label %[[SCALAR_PH]]
; CHECK: [[SCALAR_PH]]:
; CHECK-NEXT: [[BC_RESUME_VAL:%.*]] = phi ptr [ [[TMP2]], %[[MIDDLE_BLOCK]] ], [ [[START]], %[[ENTRY]] ]
-; CHECK-NEXT: [[BC_MERGE_RDX:%.*]] = phi i32 [ [[TMP7]], %[[MIDDLE_BLOCK]] ], [ 0, %[[ENTRY]] ]
+; CHECK-NEXT: [[BC_MERGE_RDX:%.*]] = phi i32 [ [[TMP6]], %[[MIDDLE_BLOCK]] ], [ 0, %[[ENTRY]] ]
; CHECK-NEXT: br label %[[LOOP:.*]]
; CHECK: [[LOOP]]:
; CHECK-NEXT: [[PTR_IV:%.*]] = phi ptr [ [[BC_RESUME_VAL]], %[[SCALAR_PH]] ], [ [[GEP_IV_NEXT:%.*]], %[[LOOP]] ]
@@ -246,7 +246,7 @@ define i32 @red_sext_mul_by_63(ptr %start, ptr %end) {
; CHECK-NEXT: [[EC:%.*]] = icmp eq ptr [[PTR_IV]], [[END]]
; CHECK-NEXT: br i1 [[EC]], label %[[EXIT]], label %[[LOOP]], !llvm.loop [[LOOP9:![0-9]+]]
; CHECK: [[EXIT]]:
-; CHECK-NEXT: [[RED_NEXT_LCSSA:%.*]] = phi i32 [ [[RED_NEXT]], %[[LOOP]] ], [ [[TMP7]], %[[MIDDLE_BLOCK]] ]
+; CHECK-NEXT: [[RED_NEXT_LCSSA:%.*]] = phi i32 [ [[RED_NEXT]], %[[LOOP]] ], [ [[TMP6]], %[[MIDDLE_BLOCK]] ]
; CHECK-NEXT: ret i32 [[RED_NEXT_LCSSA]]
;
entry:
More information about the llvm-commits
mailing list