[llvm] 9910740 - [LoopVectorize] Propagate fast-math flags for VPInstruction

Rosie Sumpter via llvm-commits llvm-commits at lists.llvm.org
Wed Nov 24 00:59:58 PST 2021


Author: Rosie Sumpter
Date: 2021-11-24T08:50:04Z
New Revision: 991074012a6c9a294c5c64cf51502934a8e9bb36

URL: https://github.com/llvm/llvm-project/commit/991074012a6c9a294c5c64cf51502934a8e9bb36
DIFF: https://github.com/llvm/llvm-project/commit/991074012a6c9a294c5c64cf51502934a8e9bb36.diff

LOG: [LoopVectorize] Propagate fast-math flags for VPInstruction

In-loop vector reductions which use the llvm.fmuladd intrinsic involve
the creation of two recipes: a VPReductionRecipe for the fadd and a
VPInstruction for the fmul. If the call to llvm.fmuladd has fast-math flags,
these should be propagated through to the fmul instruction, so an
interface setFastMathFlags has been added to the VPInstruction class to
enable this.

Differential Revision: https://reviews.llvm.org/D113125

Added: 
    

Modified: 
    llvm/include/llvm/IR/Operator.h
    llvm/lib/IR/AsmWriter.cpp
    llvm/lib/IR/Operator.cpp
    llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
    llvm/lib/Transforms/Vectorize/VPlan.cpp
    llvm/lib/Transforms/Vectorize/VPlan.h
    llvm/test/Transforms/LoopVectorize/AArch64/scalable-strict-fadd.ll
    llvm/test/Transforms/LoopVectorize/vplan-printing.ll

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/IR/Operator.h b/llvm/include/llvm/IR/Operator.h
index b83d83f0d0ab0..7d232bba0864c 100644
--- a/llvm/include/llvm/IR/Operator.h
+++ b/llvm/include/llvm/IR/Operator.h
@@ -250,8 +250,16 @@ class FastMathFlags {
   bool operator!=(const FastMathFlags &OtherFlags) const {
     return Flags != OtherFlags.Flags;
   }
+
+  /// Print the flags to \p O, each preceded by a space ("fast" if all set).
+  void print(raw_ostream &O) const;
 };
 
+inline raw_ostream &operator<<(raw_ostream &O, FastMathFlags FMF) {
+  FMF.print(O); // Emits each set flag with a leading space.
+  return O;
+}
+
 /// Utility class for floating point operations which can have
 /// information about relaxed accuracy requirements attached to them.
 class FPMathOperator : public Operator {

diff  --git a/llvm/lib/IR/AsmWriter.cpp b/llvm/lib/IR/AsmWriter.cpp
index 7734c0a8de58a..48b3d2bb1bb56 100644
--- a/llvm/lib/IR/AsmWriter.cpp
+++ b/llvm/lib/IR/AsmWriter.cpp
@@ -1309,27 +1309,8 @@ static void WriteAsOperandInternal(raw_ostream &Out, const Metadata *MD,
                                    bool FromValue = false);
 
 static void WriteOptimizationInfo(raw_ostream &Out, const User *U) {
-  if (const FPMathOperator *FPO = dyn_cast<const FPMathOperator>(U)) {
-    // 'Fast' is an abbreviation for all fast-math-flags.
-    if (FPO->isFast())
-      Out << " fast";
-    else {
-      if (FPO->hasAllowReassoc())
-        Out << " reassoc";
-      if (FPO->hasNoNaNs())
-        Out << " nnan";
-      if (FPO->hasNoInfs())
-        Out << " ninf";
-      if (FPO->hasNoSignedZeros())
-        Out << " nsz";
-      if (FPO->hasAllowReciprocal())
-        Out << " arcp";
-      if (FPO->hasAllowContract())
-        Out << " contract";
-      if (FPO->hasApproxFunc())
-        Out << " afn";
-    }
-  }
+  if (const FPMathOperator *FPO = dyn_cast<const FPMathOperator>(U))
+    Out << FPO->getFastMathFlags();
 
   if (const OverflowingBinaryOperator *OBO =
         dyn_cast<OverflowingBinaryOperator>(U)) {

diff  --git a/llvm/lib/IR/Operator.cpp b/llvm/lib/IR/Operator.cpp
index cf309ffd6212e..d15fcfbc5b9f1 100644
--- a/llvm/lib/IR/Operator.cpp
+++ b/llvm/lib/IR/Operator.cpp
@@ -226,4 +226,25 @@ bool GEPOperator::collectOffset(
   }
   return true;
 }
+
+void FastMathFlags::print(raw_ostream &O) const {
+  if (all()) // 'fast' is an abbreviation for all fast-math flags.
+    O << " fast";
+  else { // Otherwise print each set flag; nothing if none are set.
+    if (allowReassoc())
+      O << " reassoc";
+    if (noNaNs())
+      O << " nnan";
+    if (noInfs())
+      O << " ninf";
+    if (noSignedZeros())
+      O << " nsz";
+    if (allowReciprocal())
+      O << " arcp";
+    if (allowContract())
+      O << " contract";
+    if (approxFunc())
+      O << " afn";
+  }
+}
 } // namespace llvm

diff  --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index fcb0640f4a1aa..48657a3ed8cb0 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -9803,6 +9803,7 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
         // fadd reduction.
         VPInstruction *FMulRecipe = new VPInstruction(
             Instruction::FMul, {VecOp, Plan->getVPValue(R->getOperand(1))});
+        FMulRecipe->setFastMathFlags(R->getFastMathFlags());
         WidenRecipe->getParent()->insert(FMulRecipe,
                                          WidenRecipe->getIterator());
         VecOp = FMulRecipe;

diff  --git a/llvm/lib/Transforms/Vectorize/VPlan.cpp b/llvm/lib/Transforms/Vectorize/VPlan.cpp
index 638467f94e1c1..99e86735ea1b4 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlan.cpp
@@ -718,6 +718,8 @@ void VPInstruction::generateInstruction(VPTransformState &State,
 
 void VPInstruction::execute(VPTransformState &State) {
   assert(!State.Instance && "VPInstruction executing an Instance");
+  IRBuilderBase::FastMathFlagGuard FMFGuard(State.Builder);
+  State.Builder.setFastMathFlags(FMF); // Builder now uses this recipe's flags.
   for (unsigned Part = 0; Part < State.UF; ++Part)
     generateInstruction(State, Part);
 }
@@ -760,6 +762,8 @@ void VPInstruction::print(raw_ostream &O, const Twine &Indent,
     O << Instruction::getOpcodeName(getOpcode());
   }
 
+  O << FMF;
+
   for (const VPValue *Operand : operands()) {
     O << " ";
     Operand->printAsOperand(O, SlotTracker);
@@ -767,6 +771,16 @@ void VPInstruction::print(raw_ostream &O, const Twine &Indent,
 }
 #endif
 
+void VPInstruction::setFastMathFlags(FastMathFlags FMFNew) {
+  // Only floating-point opcodes (FP arithmetic or FCmp) may carry the flags.
+  assert((Opcode == Instruction::FAdd || Opcode == Instruction::FMul ||
+          Opcode == Instruction::FNeg || Opcode == Instruction::FSub ||
+          Opcode == Instruction::FDiv || Opcode == Instruction::FRem ||
+          Opcode == Instruction::FCmp) &&
+         "this op can't take fast-math flags");
+  FMF = FMFNew;
+}
+
 /// Generate the code inside the body of the vectorized loop. Assumes a single
 /// LoopVectorBody basic-block was created for this. Introduce additional
 /// basic-blocks as needed, and fill them all.

diff  --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index 4161810e39cd1..9e2bd0540afb6 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -794,6 +794,7 @@ class VPInstruction : public VPRecipeBase, public VPValue {
 private:
   typedef unsigned char OpcodeTy;
   OpcodeTy Opcode;
+  FastMathFlags FMF;
 
   /// Utility method serving execute(): generates a single instance of the
   /// modeled instruction.
@@ -875,6 +876,9 @@ class VPInstruction : public VPRecipeBase, public VPValue {
       return true;
     }
   }
+
+  /// Set the fast-math flags; asserts the opcode is a floating-point one.
+  void setFastMathFlags(FastMathFlags FMFNew);
 };
 
 /// VPWidenRecipe is a recipe for producing a copy of vector type its

diff  --git a/llvm/test/Transforms/LoopVectorize/AArch64/scalable-strict-fadd.ll b/llvm/test/Transforms/LoopVectorize/AArch64/scalable-strict-fadd.ll
index 94dfa876e42a7..b80b2113e2cd5 100644
--- a/llvm/test/Transforms/LoopVectorize/AArch64/scalable-strict-fadd.ll
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/scalable-strict-fadd.ll
@@ -483,10 +483,10 @@ define float @fmuladd_strict_fmf(float* %a, float* %b, i64 %n) #0 {
 ; CHECK-ORDERED: [[WIDE_LOAD5:%.*]] = load <vscale x 8 x float>, <vscale x 8 x float>*
 ; CHECK-ORDERED: [[WIDE_LOAD6:%.*]] = load <vscale x 8 x float>, <vscale x 8 x float>*
 ; CHECK-ORDERED: [[WIDE_LOAD7:%.*]] = load <vscale x 8 x float>, <vscale x 8 x float>*
-; CHECK-ORDERED: [[FMUL:%.*]] = fmul <vscale x 8 x float> [[WIDE_LOAD]], [[WIDE_LOAD4]]
-; CHECK-ORDERED: [[FMUL1:%.*]] = fmul <vscale x 8 x float> [[WIDE_LOAD1]], [[WIDE_LOAD5]]
-; CHECK-ORDERED: [[FMUL2:%.*]] = fmul <vscale x 8 x float> [[WIDE_LOAD2]], [[WIDE_LOAD6]]
-; CHECK-ORDERED: [[FMUL3:%.*]] = fmul <vscale x 8 x float> [[WIDE_LOAD3]], [[WIDE_LOAD7]]
+; CHECK-ORDERED: [[FMUL:%.*]] = fmul nnan <vscale x 8 x float> [[WIDE_LOAD]], [[WIDE_LOAD4]]
+; CHECK-ORDERED: [[FMUL1:%.*]] = fmul nnan <vscale x 8 x float> [[WIDE_LOAD1]], [[WIDE_LOAD5]]
+; CHECK-ORDERED: [[FMUL2:%.*]] = fmul nnan <vscale x 8 x float> [[WIDE_LOAD2]], [[WIDE_LOAD6]]
+; CHECK-ORDERED: [[FMUL3:%.*]] = fmul nnan <vscale x 8 x float> [[WIDE_LOAD3]], [[WIDE_LOAD7]]
 ; CHECK-ORDERED: [[RDX:%.*]] = call nnan float @llvm.vector.reduce.fadd.nxv8f32(float [[VEC_PHI]], <vscale x 8 x float> [[FMUL]])
 ; CHECK-ORDERED: [[RDX1:%.*]] = call nnan float @llvm.vector.reduce.fadd.nxv8f32(float [[RDX]], <vscale x 8 x float> [[FMUL1]])
 ; CHECK-ORDERED: [[RDX2:%.*]] = call nnan float @llvm.vector.reduce.fadd.nxv8f32(float [[RDX1]], <vscale x 8 x float> [[FMUL2]])

diff  --git a/llvm/test/Transforms/LoopVectorize/vplan-printing.ll b/llvm/test/Transforms/LoopVectorize/vplan-printing.ll
index 80576f5cc9aa1..9efaae3fd18f7 100644
--- a/llvm/test/Transforms/LoopVectorize/vplan-printing.ll
+++ b/llvm/test/Transforms/LoopVectorize/vplan-printing.ll
@@ -242,4 +242,40 @@ for.end:
   ret void
 }
 
+define float @print_fmuladd_strict(float* %a, float* %b, i64 %n) {
+; CHECK-LABEL: Checking a loop in "print_fmuladd_strict"
+; CHECK:      VPlan 'Initial VPlan for VF={4},UF>=1' {
+; CHECK-NEXT: <x1> vector loop: {
+; CHECK-NEXT: for.body:
+; CHECK-NEXT:   WIDEN-INDUCTION %iv = phi 0, %iv.next
+; CHECK-NEXT:   WIDEN-REDUCTION-PHI ir<%sum.07> = phi ir<0.000000e+00>, ir<%muladd>
+; CHECK-NEXT:   CLONE ir<%arrayidx> = getelementptr ir<%a>, ir<%iv>
+; CHECK-NEXT:   WIDEN ir<%0> = load ir<%arrayidx>
+; CHECK-NEXT:   CLONE ir<%arrayidx2> = getelementptr ir<%b>, ir<%iv>
+; CHECK-NEXT:   WIDEN ir<%1> = load ir<%arrayidx2>
+; CHECK-NEXT:   EMIT vp<%6> = fmul nnan ninf nsz ir<%0> ir<%1>
+; CHECK-NEXT:   REDUCE ir<%muladd> = ir<%sum.07> + reduce.fadd (vp<%6>)
+; CHECK-NEXT:   No successors
+; CHECK-NEXT: }
+
+entry:
+  br label %for.body
+
+for.body:
+  %iv = phi i64 [ 0, %entry ], [ %iv.next, %for.body ]
+  %sum.07 = phi float [ 0.000000e+00, %entry ], [ %muladd, %for.body ]
+  %arrayidx = getelementptr inbounds float, float* %a, i64 %iv
+  %0 = load float, float* %arrayidx, align 4
+  %arrayidx2 = getelementptr inbounds float, float* %b, i64 %iv
+  %1 = load float, float* %arrayidx2, align 4
+  %muladd = tail call nnan ninf nsz float @llvm.fmuladd.f32(float %0, float %1, float %sum.07)
+  %iv.next = add nuw nsw i64 %iv, 1
+  %exitcond.not = icmp eq i64 %iv.next, %n
+  br i1 %exitcond.not, label %for.end, label %for.body
+
+for.end:
+  ret float %muladd
+}
+
 declare float @llvm.sqrt.f32(float) nounwind readnone
+declare float @llvm.fmuladd.f32(float, float, float)


        


More information about the llvm-commits mailing list