[llvm] bbed5f2 - [LoopVectorize] improve IR fast-math-flags propagation in reductions

Sanjay Patel via llvm-commits llvm-commits at lists.llvm.org
Mon Feb 1 13:21:45 PST 2021


Author: Sanjay Patel
Date: 2021-02-01T16:21:36-05:00
New Revision: bbed5f2f8a04ae3a49f5e8f900c117f405101424

URL: https://github.com/llvm/llvm-project/commit/bbed5f2f8a04ae3a49f5e8f900c117f405101424
DIFF: https://github.com/llvm/llvm-project/commit/bbed5f2f8a04ae3a49f5e8f900c117f405101424.diff

LOG: [LoopVectorize] improve IR fast-math-flags propagation in reductions

This is another step (see D95452) towards correcting fast-math-flags
bugs in vector reductions.

There are multiple bugs visible in the test diffs, and this is still
not working as it should. We still use function attributes (rather
than FMF) to drive part of the logic, but we are not checking for
the correct FP function attributes.

Note that FMF may not be propagated optimally on selects (example
in https://llvm.org/PR35607 ). That's why I'm proposing to union the
FMF of an fcmp+select pair to avoid regressions on existing vectorizer
tests.

Differential Revision: https://reviews.llvm.org/D95690

Added: 
    

Modified: 
    llvm/include/llvm/IR/Operator.h
    llvm/include/llvm/Transforms/Utils/LoopUtils.h
    llvm/lib/Analysis/IVDescriptors.cpp
    llvm/lib/Transforms/Utils/LoopUtils.cpp
    llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
    llvm/test/Transforms/LoopVectorize/X86/reduction-fastmath.ll
    llvm/test/Transforms/LoopVectorize/float-minmax-instruction-flag.ll

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/IR/Operator.h b/llvm/include/llvm/IR/Operator.h
index acfacbd6c74e..4c16eb17e2a5 100644
--- a/llvm/include/llvm/IR/Operator.h
+++ b/llvm/include/llvm/IR/Operator.h
@@ -239,6 +239,9 @@ class FastMathFlags {
   void operator&=(const FastMathFlags &OtherFlags) {
     Flags &= OtherFlags.Flags;
   }
+  void operator|=(const FastMathFlags &OtherFlags) {
+    Flags |= OtherFlags.Flags;
+  }
 };
 
 /// Utility class for floating point operations which can have

diff  --git a/llvm/include/llvm/Transforms/Utils/LoopUtils.h b/llvm/include/llvm/Transforms/Utils/LoopUtils.h
index c1973fe0ae52..5f2169cb42e8 100644
--- a/llvm/include/llvm/Transforms/Utils/LoopUtils.h
+++ b/llvm/include/llvm/Transforms/Utils/LoopUtils.h
@@ -356,6 +356,7 @@ bool canSinkOrHoistInst(Instruction &I, AAResults *AA, DominatorTree *DT,
                         OptimizationRemarkEmitter *ORE = nullptr);
 
 /// Returns a Min/Max operation corresponding to MinMaxRecurrenceKind.
+/// The Builder's fast-math-flags must be set to propagate the expected values.
 Value *createMinMaxOp(IRBuilderBase &Builder, RecurKind RK, Value *Left,
                       Value *Right);
 

diff  --git a/llvm/lib/Analysis/IVDescriptors.cpp b/llvm/lib/Analysis/IVDescriptors.cpp
index 7f311d8f9a2b..91befad26de5 100644
--- a/llvm/lib/Analysis/IVDescriptors.cpp
+++ b/llvm/lib/Analysis/IVDescriptors.cpp
@@ -302,8 +302,18 @@ bool RecurrenceDescriptor::AddReductionVar(PHINode *Phi, RecurKind Kind,
       if (!ReduxDesc.isRecurrence())
         return false;
       // FIXME: FMF is allowed on phi, but propagation is not handled correctly.
-      if (isa<FPMathOperator>(ReduxDesc.getPatternInst()) && !IsAPhi)
-        FMF &= ReduxDesc.getPatternInst()->getFastMathFlags();
+      if (isa<FPMathOperator>(ReduxDesc.getPatternInst()) && !IsAPhi) {
+        FastMathFlags CurFMF = ReduxDesc.getPatternInst()->getFastMathFlags();
+        if (auto *Sel = dyn_cast<SelectInst>(ReduxDesc.getPatternInst())) {
+          // Accept FMF on either fcmp or select of a min/max idiom.
+          // TODO: This is a hack to work-around the fact that FMF may not be
+          //       assigned/propagated correctly. If that problem is fixed or we
+          //       standardize on fmin/fmax via intrinsics, this can be removed.
+          assert(isa<FCmpInst>(Sel->getCondition()) && "Expected fcmp min/max");
+          CurFMF |= cast<FCmpInst>(Sel->getCondition())->getFastMathFlags();
+        }
+        FMF &= CurFMF;
+      }
       // Update this reduction kind if we matched a new instruction.
       // TODO: Can we eliminate the need for a 2nd InstDesc by keeping 'Kind'
       //       state accurate while processing the worklist?

diff  --git a/llvm/lib/Transforms/Utils/LoopUtils.cpp b/llvm/lib/Transforms/Utils/LoopUtils.cpp
index 5c66cb0313fd..07dc3ac44c8c 100644
--- a/llvm/lib/Transforms/Utils/LoopUtils.cpp
+++ b/llvm/lib/Transforms/Utils/LoopUtils.cpp
@@ -944,12 +944,6 @@ Value *llvm::createMinMaxOp(IRBuilderBase &Builder, RecurKind RK, Value *Left,
     break;
   }
 
-  // We only match FP sequences that are 'fast', so we can unconditionally
-  // set it on any generated instructions.
-  IRBuilderBase::FastMathFlagGuard FMFG(Builder);
-  FastMathFlags FMF;
-  FMF.setFast();
-  Builder.setFastMathFlags(FMF);
   Value *Cmp = Builder.CreateCmp(Pred, Left, Right, "rdx.minmax.cmp");
   Value *Select = Builder.CreateSelect(Cmp, Left, Right, "rdx.minmax.select");
   return Select;

diff  --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index f770cb6bb2b1..ec36b8292ad3 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -403,12 +403,6 @@ static Value *addFastMathFlag(Value *V) {
   return V;
 }
 
-static Value *addFastMathFlag(Value *V, FastMathFlags FMF) {
-  if (isa<FPMathOperator>(V))
-    cast<Instruction>(V)->setFastMathFlags(FMF);
-  return V;
-}
-
 /// A helper function that returns an integer or floating-point constant with
 /// value C.
 static Constant *getSignedIntOrFpConstant(Type *Ty, int64_t C) {
@@ -4301,16 +4295,19 @@ void InnerLoopVectorizer::fixReduction(PHINode *Phi) {
   // terminate on this line. This is the easiest way to ensure we don't
   // accidentally cause an extra step back into the loop while debugging.
   setDebugLocFromInst(Builder, LoopMiddleBlock->getTerminator());
-  for (unsigned Part = 1; Part < UF; ++Part) {
-    Value *RdxPart = VectorLoopValueMap.getVectorValue(LoopExitInst, Part);
-    if (Op != Instruction::ICmp && Op != Instruction::FCmp)
-      // Floating point operations had to be 'fast' to enable the reduction.
-      ReducedPartRdx = addFastMathFlag(
-          Builder.CreateBinOp((Instruction::BinaryOps)Op, RdxPart,
-                              ReducedPartRdx, "bin.rdx"),
-          RdxDesc.getFastMathFlags());
-    else
-      ReducedPartRdx = createMinMaxOp(Builder, RK, ReducedPartRdx, RdxPart);
+  {
+    // Floating-point operations should have some FMF to enable the reduction.
+    IRBuilderBase::FastMathFlagGuard FMFG(Builder);
+    Builder.setFastMathFlags(RdxDesc.getFastMathFlags());
+    for (unsigned Part = 1; Part < UF; ++Part) {
+      Value *RdxPart = VectorLoopValueMap.getVectorValue(LoopExitInst, Part);
+      if (Op != Instruction::ICmp && Op != Instruction::FCmp) {
+        ReducedPartRdx = Builder.CreateBinOp(
+            (Instruction::BinaryOps)Op, RdxPart, ReducedPartRdx, "bin.rdx");
+      } else {
+        ReducedPartRdx = createMinMaxOp(Builder, RK, ReducedPartRdx, RdxPart);
+      }
+    }
   }
 
   // Create the reduction after the loop. Note that inloop reductions create the

diff  --git a/llvm/test/Transforms/LoopVectorize/X86/reduction-fastmath.ll b/llvm/test/Transforms/LoopVectorize/X86/reduction-fastmath.ll
index c70b11ebc1f1..a52e9e35a93b 100644
--- a/llvm/test/Transforms/LoopVectorize/X86/reduction-fastmath.ll
+++ b/llvm/test/Transforms/LoopVectorize/X86/reduction-fastmath.ll
@@ -262,7 +262,8 @@ loop.exit:
   ret float %sum.lcssa
 }
 
-; FIXME: Some fcmp are 'nnan ninf', some are 'fast', but the reduction is sequential?
+; New instructions should have the same FMF as the original code.
+; Note that the select inherits FMF from its fcmp condition.
 
 define float @PR35538(float* nocapture readonly %a, i32 %N) #0 {
 ; CHECK-LABEL: @PR35538(
@@ -299,9 +300,9 @@ define float @PR35538(float* nocapture readonly %a, i32 %N) #0 {
 ; CHECK-NEXT:    [[TMP12:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
 ; CHECK-NEXT:    br i1 [[TMP12]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], [[LOOP8:!llvm.loop !.*]]
 ; CHECK:       middle.block:
-; CHECK-NEXT:    [[RDX_MINMAX_CMP:%.*]] = fcmp fast ogt <4 x float> [[TMP10]], [[TMP11]]
-; CHECK-NEXT:    [[RDX_MINMAX_SELECT:%.*]] = select fast <4 x i1> [[RDX_MINMAX_CMP]], <4 x float> [[TMP10]], <4 x float> [[TMP11]]
-; CHECK-NEXT:    [[TMP13:%.*]] = call float @llvm.vector.reduce.fmax.v4f32(<4 x float> [[RDX_MINMAX_SELECT]])
+; CHECK-NEXT:    [[RDX_MINMAX_CMP:%.*]] = fcmp nnan ninf ogt <4 x float> [[TMP10]], [[TMP11]]
+; CHECK-NEXT:    [[RDX_MINMAX_SELECT:%.*]] = select nnan ninf <4 x i1> [[RDX_MINMAX_CMP]], <4 x float> [[TMP10]], <4 x float> [[TMP11]]
+; CHECK-NEXT:    [[TMP13:%.*]] = call nnan ninf float @llvm.vector.reduce.fmax.v4f32(<4 x float> [[RDX_MINMAX_SELECT]])
 ; CHECK-NEXT:    [[CMP_N:%.*]] = icmp eq i64 [[WIDE_TRIP_COUNT]], [[N_VEC]]
 ; CHECK-NEXT:    br i1 [[CMP_N]], label [[FOR_COND_CLEANUP_LOOPEXIT:%.*]], label [[SCALAR_PH]]
 ; CHECK:       scalar.ph:
@@ -349,6 +350,8 @@ for.body:
   br i1 %exitcond, label %for.cond.cleanup, label %for.body
 }
 
+; Same as above, but this time the select already has matching FMF with its condition.
+
 define float @PR35538_more_FMF(float* nocapture readonly %a, i32 %N) #0 {
 ; CHECK-LABEL: @PR35538_more_FMF(
 ; CHECK-NEXT:  entry:
@@ -384,8 +387,8 @@ define float @PR35538_more_FMF(float* nocapture readonly %a, i32 %N) #0 {
 ; CHECK-NEXT:    [[TMP12:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
 ; CHECK-NEXT:    br i1 [[TMP12]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], [[LOOP10:!llvm.loop !.*]]
 ; CHECK:       middle.block:
-; CHECK-NEXT:    [[RDX_MINMAX_CMP:%.*]] = fcmp fast ogt <4 x float> [[TMP10]], [[TMP11]]
-; CHECK-NEXT:    [[RDX_MINMAX_SELECT:%.*]] = select fast <4 x i1> [[RDX_MINMAX_CMP]], <4 x float> [[TMP10]], <4 x float> [[TMP11]]
+; CHECK-NEXT:    [[RDX_MINMAX_CMP:%.*]] = fcmp nnan ninf ogt <4 x float> [[TMP10]], [[TMP11]]
+; CHECK-NEXT:    [[RDX_MINMAX_SELECT:%.*]] = select nnan ninf <4 x i1> [[RDX_MINMAX_CMP]], <4 x float> [[TMP10]], <4 x float> [[TMP11]]
 ; CHECK-NEXT:    [[TMP13:%.*]] = call nnan ninf float @llvm.vector.reduce.fmax.v4f32(<4 x float> [[RDX_MINMAX_SELECT]])
 ; CHECK-NEXT:    [[CMP_N:%.*]] = icmp eq i64 [[WIDE_TRIP_COUNT]], [[N_VEC]]
 ; CHECK-NEXT:    br i1 [[CMP_N]], label [[FOR_COND_CLEANUP_LOOPEXIT:%.*]], label [[SCALAR_PH]]

diff  --git a/llvm/test/Transforms/LoopVectorize/float-minmax-instruction-flag.ll b/llvm/test/Transforms/LoopVectorize/float-minmax-instruction-flag.ll
index d6fc86956023..3dcca2fb30a3 100644
--- a/llvm/test/Transforms/LoopVectorize/float-minmax-instruction-flag.ll
+++ b/llvm/test/Transforms/LoopVectorize/float-minmax-instruction-flag.ll
@@ -69,11 +69,11 @@ define float @minloopattr(float* nocapture readonly %arg) #0 {
 ; CHECK-NEXT:    br i1 [[TMP6]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], [[LOOP0:!llvm.loop !.*]]
 ; CHECK:       middle.block:
 ; CHECK-NEXT:    [[RDX_SHUF:%.*]] = shufflevector <4 x float> [[TMP5]], <4 x float> poison, <4 x i32> <i32 2, i32 3, i32 undef, i32 undef>
-; CHECK-NEXT:    [[RDX_MINMAX_CMP:%.*]] = fcmp fast olt <4 x float> [[TMP5]], [[RDX_SHUF]]
-; CHECK-NEXT:    [[RDX_MINMAX_SELECT:%.*]] = select fast <4 x i1> [[RDX_MINMAX_CMP]], <4 x float> [[TMP5]], <4 x float> [[RDX_SHUF]]
+; CHECK-NEXT:    [[RDX_MINMAX_CMP:%.*]] = fcmp olt <4 x float> [[TMP5]], [[RDX_SHUF]]
+; CHECK-NEXT:    [[RDX_MINMAX_SELECT:%.*]] = select <4 x i1> [[RDX_MINMAX_CMP]], <4 x float> [[TMP5]], <4 x float> [[RDX_SHUF]]
 ; CHECK-NEXT:    [[RDX_SHUF1:%.*]] = shufflevector <4 x float> [[RDX_MINMAX_SELECT]], <4 x float> poison, <4 x i32> <i32 1, i32 undef, i32 undef, i32 undef>
-; CHECK-NEXT:    [[RDX_MINMAX_CMP2:%.*]] = fcmp fast olt <4 x float> [[RDX_MINMAX_SELECT]], [[RDX_SHUF1]]
-; CHECK-NEXT:    [[RDX_MINMAX_SELECT3:%.*]] = select fast <4 x i1> [[RDX_MINMAX_CMP2]], <4 x float> [[RDX_MINMAX_SELECT]], <4 x float> [[RDX_SHUF1]]
+; CHECK-NEXT:    [[RDX_MINMAX_CMP2:%.*]] = fcmp olt <4 x float> [[RDX_MINMAX_SELECT]], [[RDX_SHUF1]]
+; CHECK-NEXT:    [[RDX_MINMAX_SELECT3:%.*]] = select <4 x i1> [[RDX_MINMAX_CMP2]], <4 x float> [[RDX_MINMAX_SELECT]], <4 x float> [[RDX_SHUF1]]
 ; CHECK-NEXT:    [[TMP7:%.*]] = extractelement <4 x float> [[RDX_MINMAX_SELECT3]], i32 0
 ; CHECK-NEXT:    [[CMP_N:%.*]] = icmp eq i64 65536, 65536
 ; CHECK-NEXT:    br i1 [[CMP_N]], label [[OUT:%.*]], label [[SCALAR_PH]]


        


More information about the llvm-commits mailing list