[llvm] 61cc873 - [LV] Recognize intrinsic min/max reductions

David Green via llvm-commits llvm-commits at lists.llvm.org
Wed Sep 15 02:45:59 PDT 2021


Author: David Green
Date: 2021-09-15T10:45:50+01:00
New Revision: 61cc873a8ef1f3c77114b5322cf1f9f551c74978

URL: https://github.com/llvm/llvm-project/commit/61cc873a8ef1f3c77114b5322cf1f9f551c74978
DIFF: https://github.com/llvm/llvm-project/commit/61cc873a8ef1f3c77114b5322cf1f9f551c74978.diff

LOG: [LV] Recognize intrinsic min/max reductions

This extends the reduction logic in the vectorizer to handle intrinsic
versions of min and max, both the floating point variants already
created by instcombine under fastmath and the integer variants from
D98152.

As a bonus this allows us to match a chain of min or max operations into
a single reduction, similar to how add/mul/etc work.

Differential Revision: https://reviews.llvm.org/D109645

Added: 
    

Modified: 
    llvm/include/llvm/Analysis/IVDescriptors.h
    llvm/lib/Analysis/IVDescriptors.cpp
    llvm/test/Transforms/LoopVectorize/minmax_reduction.ll

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Analysis/IVDescriptors.h b/llvm/include/llvm/Analysis/IVDescriptors.h
index 98148f61baa19..59ad0a335a3a0 100644
--- a/llvm/include/llvm/Analysis/IVDescriptors.h
+++ b/llvm/include/llvm/Analysis/IVDescriptors.h
@@ -117,7 +117,7 @@ class RecurrenceDescriptor {
   /// compare instruction to the select instruction and stores this pointer in
   /// 'PatternLastInst' member of the returned struct.
   static InstDesc isRecurrenceInstr(Instruction *I, RecurKind Kind,
-                                    InstDesc &Prev, FastMathFlags FMF);
+                                    InstDesc &Prev, FastMathFlags FuncFMF);
 
   /// Returns true if instruction I has multiple uses in Insts
   static bool hasMultipleUsesOf(Instruction *I,
@@ -127,12 +127,13 @@ class RecurrenceDescriptor {
   /// Returns true if all uses of the instruction I is within the Set.
   static bool areAllUsesIn(Instruction *I, SmallPtrSetImpl<Instruction *> &Set);
 
-  /// Returns a struct describing if the instruction is a
-  /// Select(ICmp(X, Y), X, Y) instruction pattern corresponding to a min(X, Y)
-  /// or max(X, Y). \p Prev specifies the description of an already processed
-  /// select instruction, so its corresponding cmp can be matched to it.
-  static InstDesc isMinMaxSelectCmpPattern(Instruction *I,
-                                           const InstDesc &Prev);
+  /// Returns a struct describing if the instruction is an llvm.(s/u)(min/max),
+  /// llvm.minnum/maxnum or a Select(ICmp(X, Y), X, Y) pair of instructions
+  /// corresponding to a min(X, Y) or max(X, Y), matching the recurrence kind \p
+  /// Kind. \p Prev specifies the description of an already processed select
+  /// instruction, so its corresponding cmp can be matched to it.
+  static InstDesc isMinMaxPattern(Instruction *I, RecurKind Kind,
+                                  const InstDesc &Prev);
 
   /// Returns a struct describing if the instruction is a
   /// Select(FCmp(X, Y), (Z = X op PHINode), PHINode) instruction pattern.
@@ -150,7 +151,7 @@ class RecurrenceDescriptor {
   /// non-null, the minimal bit width needed to compute the reduction will be
   /// computed.
   static bool AddReductionVar(PHINode *Phi, RecurKind Kind, Loop *TheLoop,
-                              FastMathFlags FMF,
+                              FastMathFlags FuncFMF,
                               RecurrenceDescriptor &RedDes,
                               DemandedBits *DB = nullptr,
                               AssumptionCache *AC = nullptr,

diff  --git a/llvm/lib/Analysis/IVDescriptors.cpp b/llvm/lib/Analysis/IVDescriptors.cpp
index 53375efe97643..c04083c2a6101 100644
--- a/llvm/lib/Analysis/IVDescriptors.cpp
+++ b/llvm/lib/Analysis/IVDescriptors.cpp
@@ -423,7 +423,8 @@ bool RecurrenceDescriptor::AddReductionVar(PHINode *Phi, RecurKind Kind,
                  ((!isa<FCmpInst>(UI) && !isa<ICmpInst>(UI) &&
                    !isa<SelectInst>(UI)) ||
                   (!isConditionalRdxPattern(Kind, UI).isRecurrence() &&
-                   !isMinMaxSelectCmpPattern(UI, IgnoredVal).isRecurrence())))
+                   !isMinMaxPattern(UI, Kind, IgnoredVal)
+                        .isRecurrence())))
         return false;
 
       // Remember that we completed the cycle.
@@ -435,8 +436,10 @@ bool RecurrenceDescriptor::AddReductionVar(PHINode *Phi, RecurKind Kind,
   }
 
   // This means we have seen one but not the other instruction of the
-  // pattern or more than just a select and cmp.
-  if (isMinMaxRecurrenceKind(Kind) && NumCmpSelectPatternInst != 2)
+  // pattern or more than just a select and cmp. Zero implies that we saw an
+  // llvm.min/max intrinsic, which is always OK.
+  if (isMinMaxRecurrenceKind(Kind) && NumCmpSelectPatternInst != 2 &&
+      NumCmpSelectPatternInst != 0)
     return false;
 
   if (!FoundStartPHI || !FoundReduxOp || !ExitInstruction)
@@ -506,10 +509,12 @@ bool RecurrenceDescriptor::AddReductionVar(PHINode *Phi, RecurKind Kind,
 }
 
 RecurrenceDescriptor::InstDesc
-RecurrenceDescriptor::isMinMaxSelectCmpPattern(Instruction *I,
-                                               const InstDesc &Prev) {
-  assert((isa<CmpInst>(I) || isa<SelectInst>(I)) &&
-         "Expected a cmp or select instruction");
+RecurrenceDescriptor::isMinMaxPattern(Instruction *I, RecurKind Kind,
+                                      const InstDesc &Prev) {
+  assert((isa<CmpInst>(I) || isa<SelectInst>(I) || isa<CallInst>(I)) &&
+         "Expected a cmp or select or call instruction");
+  if (!isMinMaxRecurrenceKind(Kind))
+    return InstDesc(false, I);
 
   // We must handle the select(cmp()) as a single instruction. Advance to the
   // select.
@@ -519,28 +524,33 @@ RecurrenceDescriptor::isMinMaxSelectCmpPattern(Instruction *I,
       return InstDesc(Select, Prev.getRecKind());
   }
 
-  // Only match select with single use cmp condition.
-  if (!match(I, m_Select(m_OneUse(m_Cmp(Pred, m_Value(), m_Value())), m_Value(),
+  // Only match select with single use cmp condition, or a min/max intrinsic.
+  if (!isa<IntrinsicInst>(I) &&
+      !match(I, m_Select(m_OneUse(m_Cmp(Pred, m_Value(), m_Value())), m_Value(),
                          m_Value())))
     return InstDesc(false, I);
 
   // Look for a min/max pattern.
   if (match(I, m_UMin(m_Value(), m_Value())))
-    return InstDesc(I, RecurKind::UMin);
+    return InstDesc(Kind == RecurKind::UMin, I);
   if (match(I, m_UMax(m_Value(), m_Value())))
-    return InstDesc(I, RecurKind::UMax);
+    return InstDesc(Kind == RecurKind::UMax, I);
   if (match(I, m_SMax(m_Value(), m_Value())))
-    return InstDesc(I, RecurKind::SMax);
+    return InstDesc(Kind == RecurKind::SMax, I);
   if (match(I, m_SMin(m_Value(), m_Value())))
-    return InstDesc(I, RecurKind::SMin);
+    return InstDesc(Kind == RecurKind::SMin, I);
   if (match(I, m_OrdFMin(m_Value(), m_Value())))
-    return InstDesc(I, RecurKind::FMin);
+    return InstDesc(Kind == RecurKind::FMin, I);
   if (match(I, m_OrdFMax(m_Value(), m_Value())))
-    return InstDesc(I, RecurKind::FMax);
+    return InstDesc(Kind == RecurKind::FMax, I);
   if (match(I, m_UnordFMin(m_Value(), m_Value())))
-    return InstDesc(I, RecurKind::FMin);
+    return InstDesc(Kind == RecurKind::FMin, I);
   if (match(I, m_UnordFMax(m_Value(), m_Value())))
-    return InstDesc(I, RecurKind::FMax);
+    return InstDesc(Kind == RecurKind::FMax, I);
+  if (match(I, m_Intrinsic<Intrinsic::minnum>(m_Value(), m_Value())))
+    return InstDesc(Kind == RecurKind::FMin, I);
+  if (match(I, m_Intrinsic<Intrinsic::maxnum>(m_Value(), m_Value())))
+    return InstDesc(Kind == RecurKind::FMax, I);
 
   return InstDesc(false, I);
 }
@@ -593,7 +603,8 @@ RecurrenceDescriptor::isConditionalRdxPattern(RecurKind Kind, Instruction *I) {
 
 RecurrenceDescriptor::InstDesc
 RecurrenceDescriptor::isRecurrenceInstr(Instruction *I, RecurKind Kind,
-                                        InstDesc &Prev, FastMathFlags FMF) {
+                                        InstDesc &Prev, FastMathFlags FuncFMF) {
+  assert(Prev.getRecKind() == RecurKind::None || Prev.getRecKind() == Kind);
   switch (I->getOpcode()) {
   default:
     return InstDesc(false, I);
@@ -624,9 +635,13 @@ RecurrenceDescriptor::isRecurrenceInstr(Instruction *I, RecurKind Kind,
     LLVM_FALLTHROUGH;
   case Instruction::FCmp:
   case Instruction::ICmp:
+  case Instruction::Call:
     if (isIntMinMaxRecurrenceKind(Kind) ||
-        (FMF.noNaNs() && FMF.noSignedZeros() && isFPMinMaxRecurrenceKind(Kind)))
-      return isMinMaxSelectCmpPattern(I, Prev);
+        (((FuncFMF.noNaNs() && FuncFMF.noSignedZeros()) ||
+          (isa<FPMathOperator>(I) && I->hasNoNaNs() &&
+           I->hasNoSignedZeros())) &&
+         isFPMinMaxRecurrenceKind(Kind)))
+      return isMinMaxPattern(I, Kind, Prev);
     return InstDesc(false, I);
   }
 }

diff  --git a/llvm/test/Transforms/LoopVectorize/minmax_reduction.ll b/llvm/test/Transforms/LoopVectorize/minmax_reduction.ll
index dc55f1081d70f..73a54e7dba09f 100644
--- a/llvm/test/Transforms/LoopVectorize/minmax_reduction.ll
+++ b/llvm/test/Transforms/LoopVectorize/minmax_reduction.ll
@@ -876,7 +876,8 @@ for.end:
 }
 
 ; CHECK-LABEL: @smin_intrinsic(
-; CHECK-NOT: <2 x i32> @llvm.smin.v2i32
+; CHECK: <2 x i32> @llvm.smin.v2i32
+; CHECK: i32 @llvm.vector.reduce.smin.v2i32
 define i32 @smin_intrinsic(i32* nocapture readonly %x) {
 entry:
   br label %for.body
@@ -896,7 +897,8 @@ for.cond.cleanup:                                 ; preds = %for.body
 }
 
 ; CHECK-LABEL: @smax_intrinsic(
-; CHECK-NOT: <2 x i32> @llvm.smax.v2i32
+; CHECK: <2 x i32> @llvm.smax.v2i32
+; CHECK: i32 @llvm.vector.reduce.smax.v2i32
 define i32 @smax_intrinsic(i32* nocapture readonly %x) {
 entry:
   br label %for.body
@@ -916,7 +918,8 @@ for.cond.cleanup:                                 ; preds = %for.body
 }
 
 ; CHECK-LABEL: @umin_intrinsic(
-; CHECK-NOT: <2 x i32> @llvm.umin.v2i32
+; CHECK: <2 x i32> @llvm.umin.v2i32
+; CHECK: i32 @llvm.vector.reduce.umin.v2i32
 define i32 @umin_intrinsic(i32* nocapture readonly %x) {
 entry:
   br label %for.body
@@ -936,7 +939,8 @@ for.cond.cleanup:                                 ; preds = %for.body
 }
 
 ; CHECK-LABEL: @umax_intrinsic(
-; CHECK-NOT: <2 x i32> @llvm.umax.v2i32
+; CHECK: <2 x i32> @llvm.umax.v2i32
+; CHECK: i32 @llvm.vector.reduce.umax.v2i32
 define i32 @umax_intrinsic(i32* nocapture readonly %x) {
 entry:
   br label %for.body
@@ -956,7 +960,8 @@ for.cond.cleanup:                                 ; preds = %for.body
 }
 
 ; CHECK-LABEL: @fmin_intrinsic(
-; CHECK-NOT: nnan nsz <2 x float> @llvm.minnum.v2f32
+; CHECK: nnan nsz <2 x float> @llvm.minnum.v2f32
+; CHECK: nnan nsz float @llvm.vector.reduce.fmin.v2f32
 define float @fmin_intrinsic(float* nocapture readonly %x) {
 entry:
   br label %for.body
@@ -976,7 +981,8 @@ for.body:                                         ; preds = %entry, %for.body
 }
 
 ; CHECK-LABEL: @fmax_intrinsic(
-; CHECK-NOT: fast <2 x float> @llvm.maxnum.v2f32
+; CHECK: fast <2 x float> @llvm.maxnum.v2f32
+; CHECK: fast float @llvm.vector.reduce.fmax.v2f32
 define float @fmax_intrinsic(float* nocapture readonly %x) {
 entry:
   br label %for.body
@@ -1060,8 +1066,9 @@ for.body:                                         ; preds = %entry, %for.body
 }
 
 ; CHECK-LABEL: @sminmin(
-; CHECK-NOT: <2 x i32> @llvm.smin.v2i32
-; CHECK-NOT: <2 x i32> @llvm.smin.v2i32
+; CHECK: <2 x i32> @llvm.smin.v2i32
+; CHECK: <2 x i32> @llvm.smin.v2i32
+; CHECK: i32 @llvm.vector.reduce.smin.v2i32
 define i32 @sminmin(i32* nocapture readonly %x, i32* nocapture readonly %y) {
 entry:
   br label %for.body


        


More information about the llvm-commits mailing list