[llvm] [IVDesc] Unify calls to min-max patterns (PR #142769)
via llvm-commits
llvm-commits at lists.llvm.org
Wed Jun 4 04:48:09 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-transforms
Author: Ramkumar Ramachandra (artagnon)
<details>
<summary>Changes</summary>
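Generalize isMinMaxPattern over all MinMax RecurKinds, and remove the redundant calls to AddReductionVar for each individual MinMax RecurKind in isReductionPHI, keeping just one integral and one floating-point MinMax RecurKind. Generalizing over integral and floating-point MinMax RecurKinds together is left for a follow-up. The patch has the side-effect of enabling vectorization when two different MinMax patterns are mixed in one reduction chain (e.g. smax(smin(rdx, x[i]), y[i])). Reviewer note: such a mixed chain is not a valid reduction, because smin and smax do not compose associatively, so independent per-lane partial results cannot be combined with a single final smax. For example, with rdx starting at 0, x = {-10, -10}, y = {-5, -20} and VF=2, the scalar loop returns -10, while the vectorized loop computes lanes {-5, -10} and returns smax(-5, -10) = -5. This side-effect is therefore a miscompile, and the sminmax test in minmax_reduction.ll should keep its "don't vectorize" expectation.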
---
Full diff: https://github.com/llvm/llvm-project/pull/142769.diff
4 Files Affected:
- (modified) llvm/include/llvm/Analysis/IVDescriptors.h (+2-2)
- (modified) llvm/lib/Analysis/IVDescriptors.cpp (+38-85)
- (added) llvm/test/Transforms/LoopVectorize/minmax-reduction-mixed.ll (+68)
- (modified) llvm/test/Transforms/LoopVectorize/minmax_reduction.ll (+6-2)
``````````diff
diff --git a/llvm/include/llvm/Analysis/IVDescriptors.h b/llvm/include/llvm/Analysis/IVDescriptors.h
index d12780cde71d7..772a1c110a4c2 100644
--- a/llvm/include/llvm/Analysis/IVDescriptors.h
+++ b/llvm/include/llvm/Analysis/IVDescriptors.h
@@ -147,8 +147,8 @@ class RecurrenceDescriptor {
/// corresponding to a min(X, Y) or max(X, Y), matching the recurrence kind \p
/// Kind. \p Prev specifies the description of an already processed select
/// instruction, so its corresponding cmp can be matched to it.
- LLVM_ABI static InstDesc isMinMaxPattern(Instruction *I, RecurKind Kind,
- const InstDesc &Prev);
+ LLVM_ABI static InstDesc isMinMaxPattern(Instruction *I, const InstDesc &Prev,
+ FastMathFlags FuncFMF);
/// Returns a struct describing whether the instruction is either a
/// Select(ICmp(A, B), X, Y), or
diff --git a/llvm/lib/Analysis/IVDescriptors.cpp b/llvm/lib/Analysis/IVDescriptors.cpp
index d13f2e139ee4a..64fcfde97c97a 100644
--- a/llvm/lib/Analysis/IVDescriptors.cpp
+++ b/llvm/lib/Analysis/IVDescriptors.cpp
@@ -486,7 +486,7 @@ bool RecurrenceDescriptor::AddReductionVar(
(!isConditionalRdxPattern(Kind, UI).isRecurrence() &&
!isAnyOfPattern(TheLoop, Phi, UI, IgnoredVal)
.isRecurrence() &&
- !isMinMaxPattern(UI, Kind, IgnoredVal).isRecurrence())))
+ !isMinMaxPattern(UI, IgnoredVal, FuncFMF).isRecurrence())))
return false;
// Remember that we completed the cycle.
@@ -744,13 +744,10 @@ RecurrenceDescriptor::isFindLastIVPattern(Loop *TheLoop, PHINode *OrigPhi,
}
RecurrenceDescriptor::InstDesc
-RecurrenceDescriptor::isMinMaxPattern(Instruction *I, RecurKind Kind,
- const InstDesc &Prev) {
+RecurrenceDescriptor::isMinMaxPattern(Instruction *I, const InstDesc &Prev,
+ FastMathFlags FuncFMF) {
assert((isa<CmpInst>(I) || isa<SelectInst>(I) || isa<CallInst>(I)) &&
"Expected a cmp or select or call instruction");
- if (!isMinMaxRecurrenceKind(Kind))
- return InstDesc(false, I);
-
// We must handle the select(cmp()) as a single instruction. Advance to the
// select.
if (match(I, m_OneUse(m_Cmp()))) {
@@ -765,29 +762,40 @@ RecurrenceDescriptor::isMinMaxPattern(Instruction *I, RecurKind Kind,
// Look for a min/max pattern.
if (match(I, m_UMin(m_Value(), m_Value())))
- return InstDesc(Kind == RecurKind::UMin, I);
+ return InstDesc(I, RecurKind::UMin);
if (match(I, m_UMax(m_Value(), m_Value())))
- return InstDesc(Kind == RecurKind::UMax, I);
+ return InstDesc(I, RecurKind::UMax);
if (match(I, m_SMax(m_Value(), m_Value())))
- return InstDesc(Kind == RecurKind::SMax, I);
+ return InstDesc(I, RecurKind::SMax);
if (match(I, m_SMin(m_Value(), m_Value())))
- return InstDesc(Kind == RecurKind::SMin, I);
- if (match(I, m_OrdOrUnordFMin(m_Value(), m_Value())))
- return InstDesc(Kind == RecurKind::FMin, I);
- if (match(I, m_OrdOrUnordFMax(m_Value(), m_Value())))
- return InstDesc(Kind == RecurKind::FMax, I);
- if (match(I, m_FMinNum(m_Value(), m_Value())))
- return InstDesc(Kind == RecurKind::FMin, I);
- if (match(I, m_FMaxNum(m_Value(), m_Value())))
- return InstDesc(Kind == RecurKind::FMax, I);
+ return InstDesc(I, RecurKind::SMin);
+
+ // minimum/minnum and maximum/maxnum intrinsics do not require nsz and nnan
+ // flags since NaN and signed zeroes are propagated in the intrinsic
+ // implementation.
if (match(I, m_FMinimumNum(m_Value(), m_Value())))
- return InstDesc(Kind == RecurKind::FMinimumNum, I);
+ return InstDesc(I, RecurKind::FMinimumNum);
if (match(I, m_FMaximumNum(m_Value(), m_Value())))
- return InstDesc(Kind == RecurKind::FMaximumNum, I);
+ return InstDesc(I, RecurKind::FMaximumNum);
if (match(I, m_FMinimum(m_Value(), m_Value())))
- return InstDesc(Kind == RecurKind::FMinimum, I);
+ return InstDesc(I, RecurKind::FMinimum);
if (match(I, m_FMaximum(m_Value(), m_Value())))
- return InstDesc(Kind == RecurKind::FMaximum, I);
+ return InstDesc(I, RecurKind::FMaximum);
+
+ bool HasRequiredFMF =
+ (FuncFMF.noNaNs() && FuncFMF.noSignedZeros()) ||
+ (isa<FPMathOperator>(I) && I->hasNoNaNs() && I->hasNoSignedZeros());
+ if (!HasRequiredFMF)
+ return InstDesc(false, I);
+
+ if (match(I, m_OrdOrUnordFMin(m_Value(), m_Value())))
+ return InstDesc(I, RecurKind::FMin);
+ if (match(I, m_OrdOrUnordFMax(m_Value(), m_Value())))
+ return InstDesc(I, RecurKind::FMax);
+ if (match(I, m_FMinNum(m_Value(), m_Value())))
+ return InstDesc(I, RecurKind::FMin);
+ if (match(I, m_FMaxNum(m_Value(), m_Value())))
+ return InstDesc(I, RecurKind::FMax);
return InstDesc(false, I);
}
@@ -883,24 +891,9 @@ RecurrenceDescriptor::InstDesc RecurrenceDescriptor::isRecurrenceInstr(
case Instruction::Call:
if (isAnyOfRecurrenceKind(Kind))
return isAnyOfPattern(L, OrigPhi, I, Prev);
- auto HasRequiredFMF = [&]() {
- if (FuncFMF.noNaNs() && FuncFMF.noSignedZeros())
- return true;
- if (isa<FPMathOperator>(I) && I->hasNoNaNs() && I->hasNoSignedZeros())
- return true;
- // minimum/minnum and maximum/maxnum intrinsics do not require nsz and nnan
- // flags since NaN and signed zeroes are propagated in the intrinsic
- // implementation.
- return match(I, m_Intrinsic<Intrinsic::minimum>(m_Value(), m_Value())) ||
- match(I, m_Intrinsic<Intrinsic::maximum>(m_Value(), m_Value())) ||
- match(I,
- m_Intrinsic<Intrinsic::minimumnum>(m_Value(), m_Value())) ||
- match(I, m_Intrinsic<Intrinsic::maximumnum>(m_Value(), m_Value()));
- };
- if (isIntMinMaxRecurrenceKind(Kind) ||
- (HasRequiredFMF() && isFPMinMaxRecurrenceKind(Kind)))
- return isMinMaxPattern(I, Kind, Prev);
- else if (isFMulAddIntrinsic(I))
+ if (isMinMaxRecurrenceKind(Kind))
+ return isMinMaxPattern(I, Prev, FuncFMF);
+ if (isFMulAddIntrinsic(I))
return InstDesc(Kind == RecurKind::FMulAdd, I,
I->hasAllowReassoc() ? nullptr : I);
return InstDesc(false, I);
@@ -961,22 +954,14 @@ bool RecurrenceDescriptor::isReductionPHI(PHINode *Phi, Loop *TheLoop,
}
if (AddReductionVar(Phi, RecurKind::SMax, TheLoop, FMF, RedDes, DB, AC, DT,
SE)) {
- LLVM_DEBUG(dbgs() << "Found a SMAX reduction PHI." << *Phi << "\n");
- return true;
- }
- if (AddReductionVar(Phi, RecurKind::SMin, TheLoop, FMF, RedDes, DB, AC, DT,
- SE)) {
- LLVM_DEBUG(dbgs() << "Found a SMIN reduction PHI." << *Phi << "\n");
- return true;
- }
- if (AddReductionVar(Phi, RecurKind::UMax, TheLoop, FMF, RedDes, DB, AC, DT,
- SE)) {
- LLVM_DEBUG(dbgs() << "Found a UMAX reduction PHI." << *Phi << "\n");
+ LLVM_DEBUG(dbgs() << "Found an integral MinMax reduction PHI." << *Phi
+ << "\n");
return true;
}
- if (AddReductionVar(Phi, RecurKind::UMin, TheLoop, FMF, RedDes, DB, AC, DT,
+ if (AddReductionVar(Phi, RecurKind::FMax, TheLoop, FMF, RedDes, DB, AC, DT,
SE)) {
- LLVM_DEBUG(dbgs() << "Found a UMIN reduction PHI." << *Phi << "\n");
+ LLVM_DEBUG(dbgs() << "Found a floating-point MinMax reduction PHI." << *Phi
+ << "\n");
return true;
}
if (AddReductionVar(Phi, RecurKind::AnyOf, TheLoop, FMF, RedDes, DB, AC, DT,
@@ -1000,43 +985,11 @@ bool RecurrenceDescriptor::isReductionPHI(PHINode *Phi, Loop *TheLoop,
LLVM_DEBUG(dbgs() << "Found an FAdd reduction PHI." << *Phi << "\n");
return true;
}
- if (AddReductionVar(Phi, RecurKind::FMax, TheLoop, FMF, RedDes, DB, AC, DT,
- SE)) {
- LLVM_DEBUG(dbgs() << "Found a float MAX reduction PHI." << *Phi << "\n");
- return true;
- }
- if (AddReductionVar(Phi, RecurKind::FMin, TheLoop, FMF, RedDes, DB, AC, DT,
- SE)) {
- LLVM_DEBUG(dbgs() << "Found a float MIN reduction PHI." << *Phi << "\n");
- return true;
- }
if (AddReductionVar(Phi, RecurKind::FMulAdd, TheLoop, FMF, RedDes, DB, AC, DT,
SE)) {
LLVM_DEBUG(dbgs() << "Found an FMulAdd reduction PHI." << *Phi << "\n");
return true;
}
- if (AddReductionVar(Phi, RecurKind::FMaximum, TheLoop, FMF, RedDes, DB, AC, DT,
- SE)) {
- LLVM_DEBUG(dbgs() << "Found a float MAXIMUM reduction PHI." << *Phi << "\n");
- return true;
- }
- if (AddReductionVar(Phi, RecurKind::FMinimum, TheLoop, FMF, RedDes, DB, AC, DT,
- SE)) {
- LLVM_DEBUG(dbgs() << "Found a float MINIMUM reduction PHI." << *Phi << "\n");
- return true;
- }
- if (AddReductionVar(Phi, RecurKind::FMaximumNum, TheLoop, FMF, RedDes, DB, AC,
- DT, SE)) {
- LLVM_DEBUG(dbgs() << "Found a float MAXIMUMNUM reduction PHI." << *Phi
- << "\n");
- return true;
- }
- if (AddReductionVar(Phi, RecurKind::FMinimumNum, TheLoop, FMF, RedDes, DB, AC,
- DT, SE)) {
- LLVM_DEBUG(dbgs() << "Found a float MINIMUMNUM reduction PHI." << *Phi
- << "\n");
- return true;
- }
// Not a reduction of known type.
return false;
diff --git a/llvm/test/Transforms/LoopVectorize/minmax-reduction-mixed.ll b/llvm/test/Transforms/LoopVectorize/minmax-reduction-mixed.ll
new file mode 100644
index 0000000000000..77c50fd38e99d
--- /dev/null
+++ b/llvm/test/Transforms/LoopVectorize/minmax-reduction-mixed.ll
@@ -0,0 +1,68 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --check-globals none --version 5
+; RUN: opt -S -passes=loop-vectorize -force-vector-width=2 %s | FileCheck %s
+
+; A test with both smin and smax, with reduction on smax.
+
+define i32 @minmax.mixed(ptr %x, ptr %y) {
+; CHECK-LABEL: define i32 @minmax.mixed(
+; CHECK-SAME: ptr [[X:%.*]], ptr [[Y:%.*]]) {
+; CHECK-NEXT: [[ENTRY:.*]]:
+; CHECK-NEXT: br i1 false, label %[[SCALAR_PH:.*]], label %[[VECTOR_PH:.*]]
+; CHECK: [[VECTOR_PH]]:
+; CHECK-NEXT: br label %[[VECTOR_BODY:.*]]
+; CHECK: [[VECTOR_BODY]]:
+; CHECK-NEXT: [[INDEX:%.*]] = phi i32 [ 0, %[[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ]
+; CHECK-NEXT: [[VEC_PHI:%.*]] = phi <2 x i32> [ zeroinitializer, %[[VECTOR_PH]] ], [ [[TMP5:%.*]], %[[VECTOR_BODY]] ]
+; CHECK-NEXT: [[TMP0:%.*]] = getelementptr inbounds i32, ptr [[X]], i32 [[INDEX]]
+; CHECK-NEXT: [[TMP1:%.*]] = getelementptr inbounds i32, ptr [[TMP0]], i32 0
+; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <2 x i32>, ptr [[TMP1]], align 4
+; CHECK-NEXT: [[TMP2:%.*]] = call <2 x i32> @llvm.smin.v2i32(<2 x i32> [[VEC_PHI]], <2 x i32> [[WIDE_LOAD]])
+; CHECK-NEXT: [[TMP3:%.*]] = getelementptr inbounds i32, ptr [[Y]], i32 [[INDEX]]
+; CHECK-NEXT: [[TMP4:%.*]] = getelementptr inbounds i32, ptr [[TMP3]], i32 0
+; CHECK-NEXT: [[WIDE_LOAD1:%.*]] = load <2 x i32>, ptr [[TMP4]], align 4
+; CHECK-NEXT: [[TMP5]] = call <2 x i32> @llvm.smax.v2i32(<2 x i32> [[TMP2]], <2 x i32> [[WIDE_LOAD1]])
+; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 2
+; CHECK-NEXT: [[TMP6:%.*]] = icmp eq i32 [[INDEX_NEXT]], 1024
+; CHECK-NEXT: br i1 [[TMP6]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
+; CHECK: [[MIDDLE_BLOCK]]:
+; CHECK-NEXT: [[TMP7:%.*]] = call i32 @llvm.vector.reduce.smax.v2i32(<2 x i32> [[TMP5]])
+; CHECK-NEXT: br i1 true, label %[[EXIT:.*]], label %[[SCALAR_PH]]
+; CHECK: [[SCALAR_PH]]:
+; CHECK-NEXT: [[BC_RESUME_VAL:%.*]] = phi i32 [ 1024, %[[MIDDLE_BLOCK]] ], [ 0, %[[ENTRY]] ]
+; CHECK-NEXT: [[BC_MERGE_RDX:%.*]] = phi i32 [ [[TMP7]], %[[MIDDLE_BLOCK]] ], [ 0, %[[ENTRY]] ]
+; CHECK-NEXT: br label %[[LOOP:.*]]
+; CHECK: [[LOOP]]:
+; CHECK-NEXT: [[IV:%.*]] = phi i32 [ [[BC_RESUME_VAL]], %[[SCALAR_PH]] ], [ [[IV_NEXT:%.*]], %[[LOOP]] ]
+; CHECK-NEXT: [[RDX:%.*]] = phi i32 [ [[BC_MERGE_RDX]], %[[SCALAR_PH]] ], [ [[RDX_NEXT:%.*]], %[[LOOP]] ]
+; CHECK-NEXT: [[GEP_X_IV:%.*]] = getelementptr inbounds i32, ptr [[X]], i32 [[IV]]
+; CHECK-NEXT: [[LD_X:%.*]] = load i32, ptr [[GEP_X_IV]], align 4
+; CHECK-NEXT: [[SMIN:%.*]] = tail call i32 @llvm.smin.i32(i32 [[RDX]], i32 [[LD_X]])
+; CHECK-NEXT: [[GEP_Y_IV:%.*]] = getelementptr inbounds i32, ptr [[Y]], i32 [[IV]]
+; CHECK-NEXT: [[LD_Y:%.*]] = load i32, ptr [[GEP_Y_IV]], align 4
+; CHECK-NEXT: [[RDX_NEXT]] = tail call i32 @llvm.smax.i32(i32 [[SMIN]], i32 [[LD_Y]])
+; CHECK-NEXT: [[IV_NEXT]] = add nuw nsw i32 [[IV]], 1
+; CHECK-NEXT: [[EXIT_COND:%.*]] = icmp eq i32 [[IV_NEXT]], 1024
+; CHECK-NEXT: br i1 [[EXIT_COND]], label %[[EXIT]], label %[[LOOP]], !llvm.loop [[LOOP3:![0-9]+]]
+; CHECK: [[EXIT]]:
+; CHECK-NEXT: [[RDX_NEXT_LCSSA:%.*]] = phi i32 [ [[RDX_NEXT]], %[[LOOP]] ], [ [[TMP7]], %[[MIDDLE_BLOCK]] ]
+; CHECK-NEXT: ret i32 [[RDX_NEXT_LCSSA]]
+;
+entry:
+ br label %loop
+
+loop:
+ %iv = phi i32 [ 0, %entry ], [ %iv.next, %loop ]
+ %rdx = phi i32 [ 0, %entry ], [ %rdx.next, %loop ]
+ %gep.x.iv = getelementptr inbounds i32, ptr %x, i32 %iv
+ %ld.x = load i32, ptr %gep.x.iv, align 4
+ %smin = tail call i32 @llvm.smin.i32(i32 %rdx, i32 %ld.x)
+ %gep.y.iv = getelementptr inbounds i32, ptr %y, i32 %iv
+ %ld.y = load i32, ptr %gep.y.iv, align 4
+ %rdx.next = tail call i32 @llvm.smax.i32(i32 %smin, i32 %ld.y)
+ %iv.next = add nuw nsw i32 %iv, 1
+ %exit.cond = icmp eq i32 %iv.next, 1024
+ br i1 %exit.cond, label %exit, label %loop
+
+exit:
+ ret i32 %rdx.next
+}
diff --git a/llvm/test/Transforms/LoopVectorize/minmax_reduction.ll b/llvm/test/Transforms/LoopVectorize/minmax_reduction.ll
index 85a90f2e04c5e..1b72728da4f87 100644
--- a/llvm/test/Transforms/LoopVectorize/minmax_reduction.ll
+++ b/llvm/test/Transforms/LoopVectorize/minmax_reduction.ll
@@ -1042,8 +1042,12 @@ for.body: ; preds = %entry, %for.body
}
; CHECK-LABEL: @sminmax(
-; Min and max intrinsics - don't vectorize
-; CHECK-NOT: <2 x i32>
+; CHECK: <2 x i32> @llvm.smin.v2i32
+; CHECK: <2 x i32> @llvm.smin.v2i32
+; CHECK: <2 x i32> @llvm.smax.v2i32
+; CHECK: <2 x i32> @llvm.smax.v2i32
+; CHECK: middle.block
+; CHECK: i32 @llvm.vector.reduce.smax.v2i32
define i32 @sminmax(ptr nocapture readonly %x, ptr nocapture readonly %y) {
entry:
br label %for.body
``````````
</details>
https://github.com/llvm/llvm-project/pull/142769
More information about the llvm-commits mailing list