[llvm-branch-commits] [llvm] 8ca60db - [LoopUtils] reduce FMF and min/max complexity when forming reductions

Sanjay Patel via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Wed Dec 30 12:35:26 PST 2020


Author: Sanjay Patel
Date: 2020-12-30T15:22:26-05:00
New Revision: 8ca60db40bd944dc5f67e0f200a403b4e03818ea

URL: https://github.com/llvm/llvm-project/commit/8ca60db40bd944dc5f67e0f200a403b4e03818ea
DIFF: https://github.com/llvm/llvm-project/commit/8ca60db40bd944dc5f67e0f200a403b4e03818ea.diff

LOG: [LoopUtils] reduce FMF and min/max complexity when forming reductions

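This removes the NoNaN parameter from createTargetReduction() and
replaces the TargetTransformInfo::ReductionFlags parameter of
createSimpleTargetReduction() with a
RecurrenceDescriptor::MinMaxRecurrenceKind.

I don't know if there's some way this changes what the vectorizers
may produce for reductions, but I have added test coverage with
3567908 and 5ced712 to show that both passes already have bugs in
this area. Hopefully this does not make things worse before we can
really fix it.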

Added: 
    

Modified: 
    llvm/include/llvm/Transforms/Utils/LoopUtils.h
    llvm/lib/Transforms/Utils/LoopUtils.cpp
    llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
    llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Transforms/Utils/LoopUtils.h b/llvm/include/llvm/Transforms/Utils/LoopUtils.h
index ef348ed56129..ba2bb0a4c6b0 100644
--- a/llvm/include/llvm/Transforms/Utils/LoopUtils.h
+++ b/llvm/include/llvm/Transforms/Utils/LoopUtils.h
@@ -365,24 +365,21 @@ Value *getShuffleReduction(IRBuilderBase &Builder, Value *Src, unsigned Op,
 
 /// Create a target reduction of the given vector. The reduction operation
 /// is described by the \p Opcode parameter. min/max reductions require
-/// additional information supplied in \p Flags.
+/// additional information supplied in \p MinMaxKind.
 /// The target is queried to determine if intrinsics or shuffle sequences are
 /// required to implement the reduction.
 /// Fast-math-flags are propagated using the IRBuilder's setting.
-Value *createSimpleTargetReduction(IRBuilderBase &B,
-                                   const TargetTransformInfo *TTI,
-                                   unsigned Opcode, Value *Src,
-                                   TargetTransformInfo::ReductionFlags Flags =
-                                       TargetTransformInfo::ReductionFlags(),
-                                   ArrayRef<Value *> RedOps = None);
+Value *createSimpleTargetReduction(
+    IRBuilderBase &B, const TargetTransformInfo *TTI, unsigned Opcode,
+    Value *Src, RecurrenceDescriptor::MinMaxRecurrenceKind MinMaxKind,
+    ArrayRef<Value *> RedOps = None);
 
 /// Create a generic target reduction using a recurrence descriptor \p Desc
 /// The target is queried to determine if intrinsics or shuffle sequences are
 /// required to implement the reduction.
 /// Fast-math-flags are propagated using the RecurrenceDescriptor.
 Value *createTargetReduction(IRBuilderBase &B, const TargetTransformInfo *TTI,
-                             RecurrenceDescriptor &Desc, Value *Src,
-                             bool NoNaN = false);
+                             RecurrenceDescriptor &Desc, Value *Src);
 
 /// Get the intersection (logical and) of all of the potential IR flags
 /// of each scalar operation (VL) that will be converted into a vector (I).

diff  --git a/llvm/lib/Transforms/Utils/LoopUtils.cpp b/llvm/lib/Transforms/Utils/LoopUtils.cpp
index a3665a5636e5..8dc7709c6e55 100644
--- a/llvm/lib/Transforms/Utils/LoopUtils.cpp
+++ b/llvm/lib/Transforms/Utils/LoopUtils.cpp
@@ -985,14 +985,12 @@ llvm::getShuffleReduction(IRBuilderBase &Builder, Value *Src, unsigned Op,
 /// flags (if generating min/max reductions).
 Value *llvm::createSimpleTargetReduction(
     IRBuilderBase &Builder, const TargetTransformInfo *TTI, unsigned Opcode,
-    Value *Src, TargetTransformInfo::ReductionFlags Flags,
+    Value *Src, RecurrenceDescriptor::MinMaxRecurrenceKind MinMaxKind,
     ArrayRef<Value *> RedOps) {
   auto *SrcVTy = cast<VectorType>(Src->getType());
 
   std::function<Value *()> BuildFunc;
   using RD = RecurrenceDescriptor;
-  RD::MinMaxRecurrenceKind MinMaxKind = RD::MRK_Invalid;
-
   switch (Opcode) {
   case Instruction::Add:
     BuildFunc = [&]() { return Builder.CreateAddReduce(Src); };
@@ -1024,33 +1022,42 @@ Value *llvm::createSimpleTargetReduction(
     };
     break;
   case Instruction::ICmp:
-    if (Flags.IsMaxOp) {
-      MinMaxKind = Flags.IsSigned ? RD::MRK_SIntMax : RD::MRK_UIntMax;
-      BuildFunc = [&]() {
-        return Builder.CreateIntMaxReduce(Src, Flags.IsSigned);
-      };
-    } else {
-      MinMaxKind = Flags.IsSigned ? RD::MRK_SIntMin : RD::MRK_UIntMin;
-      BuildFunc = [&]() {
-        return Builder.CreateIntMinReduce(Src, Flags.IsSigned);
-      };
+    switch (MinMaxKind) {
+    case RD::MRK_SIntMax:
+      BuildFunc = [&]() { return Builder.CreateIntMaxReduce(Src, true); };
+      break;
+    case RD::MRK_SIntMin:
+      BuildFunc = [&]() { return Builder.CreateIntMinReduce(Src, true); };
+      break;
+    case RD::MRK_UIntMax:
+      BuildFunc = [&]() { return Builder.CreateIntMaxReduce(Src, false); };
+      break;
+    case RD::MRK_UIntMin:
+      BuildFunc = [&]() { return Builder.CreateIntMinReduce(Src, false); };
+      break;
+    default:
+      llvm_unreachable("Unexpected min/max reduction type");
     }
     break;
   case Instruction::FCmp:
-    if (Flags.IsMaxOp) {
-      MinMaxKind = RD::MRK_FloatMax;
+    assert((MinMaxKind == RD::MRK_FloatMax || MinMaxKind == RD::MRK_FloatMin) &&
+           "Unexpected min/max reduction type");
+    if (MinMaxKind == RD::MRK_FloatMax)
       BuildFunc = [&]() { return Builder.CreateFPMaxReduce(Src); };
-    } else {
-      MinMaxKind = RD::MRK_FloatMin;
+    else
       BuildFunc = [&]() { return Builder.CreateFPMinReduce(Src); };
-    }
     break;
   default:
     llvm_unreachable("Unhandled opcode");
-    break;
   }
+  TargetTransformInfo::ReductionFlags RdxFlags;
+  RdxFlags.IsMaxOp = MinMaxKind == RD::MRK_SIntMax ||
+                     MinMaxKind == RD::MRK_UIntMax ||
+                     MinMaxKind == RD::MRK_FloatMax;
+  RdxFlags.IsSigned =
+      MinMaxKind == RD::MRK_SIntMax || MinMaxKind == RD::MRK_SIntMin;
   if (ForceReductionIntrinsic ||
-      TTI->useReductionIntrinsic(Opcode, Src->getType(), Flags))
+      TTI->useReductionIntrinsic(Opcode, Src->getType(), RdxFlags))
     return BuildFunc();
   return getShuffleReduction(Builder, Src, Opcode, MinMaxKind, RedOps);
 }
@@ -1058,12 +1065,9 @@ Value *llvm::createSimpleTargetReduction(
 /// Create a vector reduction using a given recurrence descriptor.
 Value *llvm::createTargetReduction(IRBuilderBase &B,
                                    const TargetTransformInfo *TTI,
-                                   RecurrenceDescriptor &Desc, Value *Src,
-                                   bool NoNaN) {
+                                   RecurrenceDescriptor &Desc, Value *Src) {
   // TODO: Support in-order reductions based on the recurrence descriptor.
   using RD = RecurrenceDescriptor;
-  TargetTransformInfo::ReductionFlags Flags;
-  Flags.NoNaN = NoNaN;
 
   // All ops in the reduction inherit fast-math-flags from the recurrence
   // descriptor.
@@ -1071,11 +1075,8 @@ Value *llvm::createTargetReduction(IRBuilderBase &B,
   B.setFastMathFlags(Desc.getFastMathFlags());
 
   RD::MinMaxRecurrenceKind MMKind = Desc.getMinMaxRecurrenceKind();
-  Flags.IsMaxOp = MMKind == RD::MRK_SIntMax || MMKind == RD::MRK_UIntMax ||
-                  MMKind == RD::MRK_FloatMax;
-  Flags.IsSigned = MMKind == RD::MRK_SIntMax || MMKind == RD::MRK_SIntMin;
   return createSimpleTargetReduction(B, TTI, Desc.getRecurrenceBinOp(), Src,
-                                     Flags);
+                                     MMKind);
 }
 
 void llvm::propagateIRFlags(Value *I, ArrayRef<Value *> VL, Value *OpValue) {

diff  --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index c48b650c3c3e..3bc946603eb0 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -4325,9 +4325,8 @@ void InnerLoopVectorizer::fixReduction(PHINode *Phi) {
   // Create the reduction after the loop. Note that inloop reductions create the
   // target reduction in the loop using a Reduction recipe.
   if (VF.isVector() && !IsInLoopReductionPhi) {
-    bool NoNaN = Legal->hasFunNoNaNAttr();
     ReducedPartRdx =
-        createTargetReduction(Builder, TTI, RdxDesc, ReducedPartRdx, NoNaN);
+        createTargetReduction(Builder, TTI, RdxDesc, ReducedPartRdx);
     // If the reduction can be performed in a smaller type, we need to extend
     // the reduction to the wider type before we branch to the original loop.
     if (Phi->getType() != RdxDesc.getRecurrenceType())
@@ -8783,7 +8782,7 @@ void VPReductionRecipe::execute(VPTransformState &State) {
       NewVecOp = Select;
     }
     Value *NewRed =
-        createTargetReduction(State.Builder, TTI, *RdxDesc, NewVecOp, NoNaN);
+        createTargetReduction(State.Builder, TTI, *RdxDesc, NewVecOp);
     Value *PrevInChain = State.get(getChainOp(), Part);
     Value *NextInChain;
     if (Kind == RecurrenceDescriptor::RK_IntegerMinMax ||

diff  --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index eff0690eda82..2bbac1ac549e 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -6455,7 +6455,7 @@ class HorizontalReduction {
 
     /// Kind of the reduction operation.
     RD::RecurrenceKind Kind = RD::RK_NoRecurrence;
-    TargetTransformInfo::ReductionFlags RdxFlags;
+    RD::MinMaxRecurrenceKind MMKind = RD::MRK_Invalid;
 
     /// Checks if the reduction operation can be vectorized.
     bool isVectorizable() const {
@@ -6499,10 +6499,13 @@ class HorizontalReduction {
       case RD::RK_IntegerMinMax: {
         assert(Opcode == Instruction::ICmp && "Expected integer types.");
         ICmpInst::Predicate Pred;
-        if (RdxFlags.IsMaxOp)
-          Pred = RdxFlags.IsSigned ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT;
-        else
-          Pred = RdxFlags.IsSigned ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT;
+        switch (MMKind) {
+        case RD::MRK_SIntMax: Pred = ICmpInst::ICMP_SGT; break;
+        case RD::MRK_SIntMin: Pred = ICmpInst::ICMP_SLT; break;
+        case RD::MRK_UIntMax: Pred = ICmpInst::ICMP_UGT; break;
+        case RD::MRK_UIntMin: Pred = ICmpInst::ICMP_ULT; break;
+        default: llvm_unreachable("Unexpected min/max value");
+        }
         Value *Cmp = Builder.CreateICmp(Pred, LHS, RHS, Name);
         return Builder.CreateSelect(Cmp, LHS, RHS, Name);
       }
@@ -6521,9 +6524,9 @@ class HorizontalReduction {
     }
 
     /// Constructor for reduction operations with opcode and type.
-    OperationData(unsigned Opcode, RD::RecurrenceKind Kind,
-                  TargetTransformInfo::ReductionFlags Flags)
-        : Opcode(Opcode), Kind(Kind), RdxFlags(Flags) {
+    OperationData(unsigned Opcode, RD::RecurrenceKind RdxKind,
+                  RD::MinMaxRecurrenceKind MinMaxKind)
+        : Opcode(Opcode), Kind(RdxKind), MMKind(MinMaxKind) {
       assert(Kind != RD::RK_NoRecurrence && "Expected reduction operation.");
     }
 
@@ -6640,6 +6643,7 @@ class HorizontalReduction {
 
     /// Get kind of reduction data.
     RD::RecurrenceKind getKind() const { return Kind; }
+    RD::MinMaxRecurrenceKind getMinMaxKind() const { return MMKind; }
     Value *getLHS(Instruction *I) const {
       if (Kind == RD::RK_NoRecurrence)
         return nullptr;
@@ -6706,8 +6710,6 @@ class HorizontalReduction {
         llvm_unreachable("Unknown reduction operation.");
       }
     }
-
-    TargetTransformInfo::ReductionFlags getFlags() const { return RdxFlags; }
   };
 
   WeakTrackingVH ReductionRoot;
@@ -6749,28 +6751,32 @@ class HorizontalReduction {
 
     TargetTransformInfo::ReductionFlags RdxFlags;
     if (match(I, m_Add(m_Value(), m_Value())))
-      return OperationData(I->getOpcode(), RD::RK_IntegerAdd, RdxFlags);
+      return OperationData(I->getOpcode(), RD::RK_IntegerAdd, RD::MRK_Invalid);
     if (match(I, m_Mul(m_Value(), m_Value())))
-      return OperationData(I->getOpcode(), RD::RK_IntegerMult, RdxFlags);
+      return OperationData(I->getOpcode(), RD::RK_IntegerMult, RD::MRK_Invalid);
     if (match(I, m_And(m_Value(), m_Value())))
-      return OperationData(I->getOpcode(), RD::RK_IntegerAnd, RdxFlags);
+      return OperationData(I->getOpcode(), RD::RK_IntegerAnd, RD::MRK_Invalid);
     if (match(I, m_Or(m_Value(), m_Value())))
-      return OperationData(I->getOpcode(), RD::RK_IntegerOr, RdxFlags);
+      return OperationData(I->getOpcode(), RD::RK_IntegerOr, RD::MRK_Invalid);
     if (match(I, m_Xor(m_Value(), m_Value())))
-      return OperationData(I->getOpcode(), RD::RK_IntegerXor, RdxFlags);
+      return OperationData(I->getOpcode(), RD::RK_IntegerXor, RD::MRK_Invalid);
     if (match(I, m_FAdd(m_Value(), m_Value())))
-      return OperationData(I->getOpcode(), RD::RK_FloatAdd, RdxFlags);
+      return OperationData(I->getOpcode(), RD::RK_FloatAdd, RD::MRK_Invalid);
     if (match(I, m_FMul(m_Value(), m_Value())))
-      return OperationData(I->getOpcode(), RD::RK_FloatMult, RdxFlags);
-
-    if (match(I, m_MaxOrMin(m_Value(), m_Value()))) {
-      RdxFlags.IsMaxOp = match(I, m_UMax(m_Value(), m_Value())) ||
-                         match(I, m_SMax(m_Value(), m_Value()));
-      RdxFlags.IsSigned = match(I, m_SMin(m_Value(), m_Value())) ||
-                          match(I, m_SMax(m_Value(), m_Value()));
-      return OperationData(Instruction::ICmp, RD::RK_IntegerMinMax, RdxFlags);
-    }
-
+      return OperationData(I->getOpcode(), RD::RK_FloatMult, RD::MRK_Invalid);
+
+    if (match(I, m_SMax(m_Value(), m_Value())))
+      return OperationData(Instruction::ICmp, RD::RK_IntegerMinMax,
+                           RD::MRK_SIntMax);
+    if (match(I, m_SMin(m_Value(), m_Value())))
+      return OperationData(Instruction::ICmp, RD::RK_IntegerMinMax,
+                           RD::MRK_SIntMin);
+    if (match(I, m_UMax(m_Value(), m_Value())))
+      return OperationData(Instruction::ICmp, RD::RK_IntegerMinMax,
+                           RD::MRK_UIntMax);
+    if (match(I, m_UMin(m_Value(), m_Value())))
+      return OperationData(Instruction::ICmp, RD::RK_IntegerMinMax,
+                           RD::MRK_UIntMin);
 
     if (auto *Select = dyn_cast<SelectInst>(I)) {
       // Try harder: look for min/max pattern based on instructions producing
@@ -6814,28 +6820,23 @@ class HorizontalReduction {
       switch (Pred) {
       default:
         return OperationData(*I);
-      case CmpInst::ICMP_ULT:
-      case CmpInst::ICMP_ULE:
-        RdxFlags.IsMaxOp = false;
-        RdxFlags.IsSigned = false;
-        break;
+      case CmpInst::ICMP_SGT:
+      case CmpInst::ICMP_SGE:
+        return OperationData(Instruction::ICmp, RD::RK_IntegerMinMax,
+                             RD::MRK_SIntMax);
       case CmpInst::ICMP_SLT:
       case CmpInst::ICMP_SLE:
-        RdxFlags.IsMaxOp = false;
-        RdxFlags.IsSigned = true;
-        break;
+        return OperationData(Instruction::ICmp, RD::RK_IntegerMinMax,
+                             RD::MRK_SIntMin);
       case CmpInst::ICMP_UGT:
       case CmpInst::ICMP_UGE:
-        RdxFlags.IsMaxOp = true;
-        RdxFlags.IsSigned = false;
-        break;
-      case CmpInst::ICMP_SGT:
-      case CmpInst::ICMP_SGE:
-        RdxFlags.IsMaxOp = true;
-        RdxFlags.IsSigned = true;
-        break;
+        return OperationData(Instruction::ICmp, RD::RK_IntegerMinMax,
+                             RD::MRK_UIntMax);
+      case CmpInst::ICMP_ULT:
+      case CmpInst::ICMP_ULE:
+        return OperationData(Instruction::ICmp, RD::RK_IntegerMinMax,
+                             RD::MRK_UIntMin);
       }
-      return OperationData(Instruction::ICmp, RD::RK_IntegerMinMax, RdxFlags);
     }
     return OperationData(*I);
   }
@@ -7186,8 +7187,8 @@ class HorizontalReduction {
       break;
     case RD::RK_IntegerMinMax: {
       auto *VecCondTy = cast<VectorType>(CmpInst::makeCmpResultType(VecTy));
-      bool IsUnsigned = !RdxTreeInst.getFlags().IsSigned;
-
+      RD::MinMaxRecurrenceKind MMKind = RdxTreeInst.getMinMaxKind();
+      bool IsUnsigned = MMKind == RD::MRK_UIntMax || MMKind == RD::MRK_UIntMin;
       PairwiseRdxCost =
           TTI->getMinMaxReductionCost(VecTy, VecCondTy,
                                       /*IsPairwiseForm=*/true, IsUnsigned);
@@ -7248,7 +7249,7 @@ class HorizontalReduction {
       assert(Builder.getFastMathFlags().isFast() && "Expected 'fast' FMF");
       return createSimpleTargetReduction(
           Builder, TTI, RdxTreeInst.getOpcode(), VectorizedValue,
-          RdxTreeInst.getFlags(), ReductionOps.back());
+          RdxTreeInst.getMinMaxKind(), ReductionOps.back());
     }
 
     Value *TmpVec = VectorizedValue;


        


More information about the llvm-branch-commits mailing list