[llvm] 9910740 - [LoopVectorize] Propagate fast-math flags for VPInstruction
Rosie Sumpter via llvm-commits
llvm-commits at lists.llvm.org
Wed Nov 24 00:59:58 PST 2021
Author: Rosie Sumpter
Date: 2021-11-24T08:50:04Z
New Revision: 991074012a6c9a294c5c64cf51502934a8e9bb36
URL: https://github.com/llvm/llvm-project/commit/991074012a6c9a294c5c64cf51502934a8e9bb36
DIFF: https://github.com/llvm/llvm-project/commit/991074012a6c9a294c5c64cf51502934a8e9bb36.diff
LOG: [LoopVectorize] Propagate fast-math flags for VPInstruction
In-loop vector reductions which use the llvm.fmuladd intrinsic involve
the creation of two recipes: a VPReductionRecipe for the fadd and a
VPInstruction for the fmul. If the call to llvm.fmuladd has fast-math flags,
these should be propagated through to the fmul instruction, so an
interface setFastMathFlags has been added to the VPInstruction class to
enable this.
Differential Revision: https://reviews.llvm.org/D113125
Added:
Modified:
llvm/include/llvm/IR/Operator.h
llvm/lib/IR/AsmWriter.cpp
llvm/lib/IR/Operator.cpp
llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
llvm/lib/Transforms/Vectorize/VPlan.cpp
llvm/lib/Transforms/Vectorize/VPlan.h
llvm/test/Transforms/LoopVectorize/AArch64/scalable-strict-fadd.ll
llvm/test/Transforms/LoopVectorize/vplan-printing.ll
Removed:
################################################################################
diff --git a/llvm/include/llvm/IR/Operator.h b/llvm/include/llvm/IR/Operator.h
index b83d83f0d0ab0..7d232bba0864c 100644
--- a/llvm/include/llvm/IR/Operator.h
+++ b/llvm/include/llvm/IR/Operator.h
@@ -250,8 +250,16 @@ class FastMathFlags {
bool operator!=(const FastMathFlags &OtherFlags) const {
return Flags != OtherFlags.Flags;
}
+
+ /// Print fast-math flags to \p O.
+ void print(raw_ostream &O) const;
};
+inline raw_ostream &operator<<(raw_ostream &O, FastMathFlags FMF) {
+ FMF.print(O);
+ return O;
+}
+
/// Utility class for floating point operations which can have
/// information about relaxed accuracy requirements attached to them.
class FPMathOperator : public Operator {
diff --git a/llvm/lib/IR/AsmWriter.cpp b/llvm/lib/IR/AsmWriter.cpp
index 7734c0a8de58a..48b3d2bb1bb56 100644
--- a/llvm/lib/IR/AsmWriter.cpp
+++ b/llvm/lib/IR/AsmWriter.cpp
@@ -1309,27 +1309,8 @@ static void WriteAsOperandInternal(raw_ostream &Out, const Metadata *MD,
bool FromValue = false);
static void WriteOptimizationInfo(raw_ostream &Out, const User *U) {
- if (const FPMathOperator *FPO = dyn_cast<const FPMathOperator>(U)) {
- // 'Fast' is an abbreviation for all fast-math-flags.
- if (FPO->isFast())
- Out << " fast";
- else {
- if (FPO->hasAllowReassoc())
- Out << " reassoc";
- if (FPO->hasNoNaNs())
- Out << " nnan";
- if (FPO->hasNoInfs())
- Out << " ninf";
- if (FPO->hasNoSignedZeros())
- Out << " nsz";
- if (FPO->hasAllowReciprocal())
- Out << " arcp";
- if (FPO->hasAllowContract())
- Out << " contract";
- if (FPO->hasApproxFunc())
- Out << " afn";
- }
- }
+ if (const FPMathOperator *FPO = dyn_cast<const FPMathOperator>(U))
+ Out << FPO->getFastMathFlags();
if (const OverflowingBinaryOperator *OBO =
dyn_cast<OverflowingBinaryOperator>(U)) {
diff --git a/llvm/lib/IR/Operator.cpp b/llvm/lib/IR/Operator.cpp
index cf309ffd6212e..d15fcfbc5b9f1 100644
--- a/llvm/lib/IR/Operator.cpp
+++ b/llvm/lib/IR/Operator.cpp
@@ -226,4 +226,25 @@ bool GEPOperator::collectOffset(
}
return true;
}
+
+void FastMathFlags::print(raw_ostream &O) const {
+ if (all())
+ O << " fast";
+ else {
+ if (allowReassoc())
+ O << " reassoc";
+ if (noNaNs())
+ O << " nnan";
+ if (noInfs())
+ O << " ninf";
+ if (noSignedZeros())
+ O << " nsz";
+ if (allowReciprocal())
+ O << " arcp";
+ if (allowContract())
+ O << " contract";
+ if (approxFunc())
+ O << " afn";
+ }
+}
} // namespace llvm
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index fcb0640f4a1aa..48657a3ed8cb0 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -9803,6 +9803,7 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
// fadd reduction.
VPInstruction *FMulRecipe = new VPInstruction(
Instruction::FMul, {VecOp, Plan->getVPValue(R->getOperand(1))});
+ FMulRecipe->setFastMathFlags(R->getFastMathFlags());
WidenRecipe->getParent()->insert(FMulRecipe,
WidenRecipe->getIterator());
VecOp = FMulRecipe;
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.cpp b/llvm/lib/Transforms/Vectorize/VPlan.cpp
index 638467f94e1c1..99e86735ea1b4 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlan.cpp
@@ -718,6 +718,8 @@ void VPInstruction::generateInstruction(VPTransformState &State,
void VPInstruction::execute(VPTransformState &State) {
assert(!State.Instance && "VPInstruction executing an Instance");
+ IRBuilderBase::FastMathFlagGuard FMFGuard(State.Builder);
+ State.Builder.setFastMathFlags(FMF);
for (unsigned Part = 0; Part < State.UF; ++Part)
generateInstruction(State, Part);
}
@@ -760,6 +762,8 @@ void VPInstruction::print(raw_ostream &O, const Twine &Indent,
O << Instruction::getOpcodeName(getOpcode());
}
+ O << FMF;
+
for (const VPValue *Operand : operands()) {
O << " ";
Operand->printAsOperand(O, SlotTracker);
@@ -767,6 +771,16 @@ void VPInstruction::print(raw_ostream &O, const Twine &Indent,
}
#endif
+void VPInstruction::setFastMathFlags(FastMathFlags FMFNew) {
+ // Make sure the VPInstruction is a floating-point operation.
+ assert((Opcode == Instruction::FAdd || Opcode == Instruction::FMul ||
+ Opcode == Instruction::FNeg || Opcode == Instruction::FSub ||
+ Opcode == Instruction::FDiv || Opcode == Instruction::FRem ||
+ Opcode == Instruction::FCmp) &&
+ "this op can't take fast-math flags");
+ FMF = FMFNew;
+}
+
/// Generate the code inside the body of the vectorized loop. Assumes a single
/// LoopVectorBody basic-block was created for this. Introduce additional
/// basic-blocks as needed, and fill them all.
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index 4161810e39cd1..9e2bd0540afb6 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -794,6 +794,7 @@ class VPInstruction : public VPRecipeBase, public VPValue {
private:
typedef unsigned char OpcodeTy;
OpcodeTy Opcode;
+ FastMathFlags FMF;
/// Utility method serving execute(): generates a single instance of the
/// modeled instruction.
@@ -875,6 +876,9 @@ class VPInstruction : public VPRecipeBase, public VPValue {
return true;
}
}
+
+ /// Set the fast-math flags.
+ void setFastMathFlags(FastMathFlags FMFNew);
};
/// VPWidenRecipe is a recipe for producing a copy of vector type its
diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/scalable-strict-fadd.ll b/llvm/test/Transforms/LoopVectorize/AArch64/scalable-strict-fadd.ll
index 94dfa876e42a7..b80b2113e2cd5 100644
--- a/llvm/test/Transforms/LoopVectorize/AArch64/scalable-strict-fadd.ll
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/scalable-strict-fadd.ll
@@ -483,10 +483,10 @@ define float @fmuladd_strict_fmf(float* %a, float* %b, i64 %n) #0 {
; CHECK-ORDERED: [[WIDE_LOAD5:%.*]] = load <vscale x 8 x float>, <vscale x 8 x float>*
; CHECK-ORDERED: [[WIDE_LOAD6:%.*]] = load <vscale x 8 x float>, <vscale x 8 x float>*
; CHECK-ORDERED: [[WIDE_LOAD7:%.*]] = load <vscale x 8 x float>, <vscale x 8 x float>*
-; CHECK-ORDERED: [[FMUL:%.*]] = fmul <vscale x 8 x float> [[WIDE_LOAD]], [[WIDE_LOAD4]]
-; CHECK-ORDERED: [[FMUL1:%.*]] = fmul <vscale x 8 x float> [[WIDE_LOAD1]], [[WIDE_LOAD5]]
-; CHECK-ORDERED: [[FMUL2:%.*]] = fmul <vscale x 8 x float> [[WIDE_LOAD2]], [[WIDE_LOAD6]]
-; CHECK-ORDERED: [[FMUL3:%.*]] = fmul <vscale x 8 x float> [[WIDE_LOAD3]], [[WIDE_LOAD7]]
+; CHECK-ORDERED: [[FMUL:%.*]] = fmul nnan <vscale x 8 x float> [[WIDE_LOAD]], [[WIDE_LOAD4]]
+; CHECK-ORDERED: [[FMUL1:%.*]] = fmul nnan <vscale x 8 x float> [[WIDE_LOAD1]], [[WIDE_LOAD5]]
+; CHECK-ORDERED: [[FMUL2:%.*]] = fmul nnan <vscale x 8 x float> [[WIDE_LOAD2]], [[WIDE_LOAD6]]
+; CHECK-ORDERED: [[FMUL3:%.*]] = fmul nnan <vscale x 8 x float> [[WIDE_LOAD3]], [[WIDE_LOAD7]]
; CHECK-ORDERED: [[RDX:%.*]] = call nnan float @llvm.vector.reduce.fadd.nxv8f32(float [[VEC_PHI]], <vscale x 8 x float> [[FMUL]])
; CHECK-ORDERED: [[RDX1:%.*]] = call nnan float @llvm.vector.reduce.fadd.nxv8f32(float [[RDX]], <vscale x 8 x float> [[FMUL1]])
; CHECK-ORDERED: [[RDX2:%.*]] = call nnan float @llvm.vector.reduce.fadd.nxv8f32(float [[RDX1]], <vscale x 8 x float> [[FMUL2]])
diff --git a/llvm/test/Transforms/LoopVectorize/vplan-printing.ll b/llvm/test/Transforms/LoopVectorize/vplan-printing.ll
index 80576f5cc9aa1..9efaae3fd18f7 100644
--- a/llvm/test/Transforms/LoopVectorize/vplan-printing.ll
+++ b/llvm/test/Transforms/LoopVectorize/vplan-printing.ll
@@ -242,4 +242,40 @@ for.end:
ret void
}
+define float @print_fmuladd_strict(float* %a, float* %b, i64 %n) {
+; CHECK-LABEL: Checking a loop in "print_fmuladd_strict"
+; CHECK: VPlan 'Initial VPlan for VF={4},UF>=1' {
+; CHECK-NEXT: <x1> vector loop: {
+; CHECK-NEXT: for.body:
+; CHECK-NEXT: WIDEN-INDUCTION %iv = phi 0, %iv.next
+; CHECK-NEXT: WIDEN-REDUCTION-PHI ir<%sum.07> = phi ir<0.000000e+00>, ir<%muladd>
+; CHECK-NEXT: CLONE ir<%arrayidx> = getelementptr ir<%a>, ir<%iv>
+; CHECK-NEXT: WIDEN ir<%0> = load ir<%arrayidx>
+; CHECK-NEXT: CLONE ir<%arrayidx2> = getelementptr ir<%b>, ir<%iv>
+; CHECK-NEXT: WIDEN ir<%1> = load ir<%arrayidx2>
+; CHECK-NEXT: EMIT vp<%6> = fmul nnan ninf nsz ir<%0> ir<%1>
+; CHECK-NEXT: REDUCE ir<%muladd> = ir<%sum.07> + reduce.fadd (vp<%6>)
+; CHECK-NEXT: No successors
+; CHECK-NEXT: }
+
+entry:
+ br label %for.body
+
+for.body:
+ %iv = phi i64 [ 0, %entry ], [ %iv.next, %for.body ]
+ %sum.07 = phi float [ 0.000000e+00, %entry ], [ %muladd, %for.body ]
+ %arrayidx = getelementptr inbounds float, float* %a, i64 %iv
+ %0 = load float, float* %arrayidx, align 4
+ %arrayidx2 = getelementptr inbounds float, float* %b, i64 %iv
+ %1 = load float, float* %arrayidx2, align 4
+ %muladd = tail call nnan ninf nsz float @llvm.fmuladd.f32(float %0, float %1, float %sum.07)
+ %iv.next = add nuw nsw i64 %iv, 1
+ %exitcond.not = icmp eq i64 %iv.next, %n
+ br i1 %exitcond.not, label %for.end, label %for.body
+
+for.end:
+ ret float %muladd
+}
+
declare float @llvm.sqrt.f32(float) nounwind readnone
+declare float @llvm.fmuladd.f32(float, float, float)
More information about the llvm-commits
mailing list