[llvm] r355868 - Relax constraints for reduction vectorization
Sanjoy Das via llvm-commits
llvm-commits at lists.llvm.org
Mon Mar 11 14:36:42 PDT 2019
Author: sanjoy
Date: Mon Mar 11 14:36:41 2019
New Revision: 355868
URL: http://llvm.org/viewvc/llvm-project?rev=355868&view=rev
Log:
Relax constraints for reduction vectorization
Summary:
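Gating the vectorization of reductions on *all* fast-math flags seems unnecessary;
`reassoc` should be sufficient. Vectorizing a floating-point reduction reorders
its operations, and permitting that reordering is exactly what `reassoc` grants.
The remaining flags are now taken from the intersection of the flags on the
scalar reduction instructions and propagated to the vector code, instead of
being forced to `fast`.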
Reviewers: tvvikram, mkuper, kristof.beyls, sdesmalen, Ayal
Reviewed By: sdesmalen
Subscribers: dcaballe, huntergr, jmolloy, mcrosier, jlebar, bixia, llvm-commits
Tags: #llvm
Differential Revision: https://reviews.llvm.org/D57728
Added:
llvm/trunk/test/Transforms/LoopVectorize/reduction-fastmath.ll
Modified:
llvm/trunk/include/llvm/Analysis/IVDescriptors.h
llvm/trunk/include/llvm/IR/Operator.h
llvm/trunk/include/llvm/Transforms/Utils/LoopUtils.h
llvm/trunk/lib/Analysis/IVDescriptors.cpp
llvm/trunk/lib/CodeGen/ExpandReductions.cpp
llvm/trunk/lib/Transforms/Utils/LoopUtils.cpp
llvm/trunk/lib/Transforms/Vectorize/LoopVectorize.cpp
llvm/trunk/lib/Transforms/Vectorize/SLPVectorizer.cpp
Modified: llvm/trunk/include/llvm/Analysis/IVDescriptors.h
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/include/llvm/Analysis/IVDescriptors.h?rev=355868&r1=355867&r2=355868&view=diff
==============================================================================
--- llvm/trunk/include/llvm/Analysis/IVDescriptors.h (original)
+++ llvm/trunk/include/llvm/Analysis/IVDescriptors.h Mon Mar 11 14:36:41 2019
@@ -89,10 +89,12 @@ public:
RecurrenceDescriptor() = default;
RecurrenceDescriptor(Value *Start, Instruction *Exit, RecurrenceKind K,
- MinMaxRecurrenceKind MK, Instruction *UAI, Type *RT,
- bool Signed, SmallPtrSetImpl<Instruction *> &CI)
- : StartValue(Start), LoopExitInstr(Exit), Kind(K), MinMaxKind(MK),
- UnsafeAlgebraInst(UAI), RecurrenceType(RT), IsSigned(Signed) {
+ FastMathFlags FMF, MinMaxRecurrenceKind MK,
+ Instruction *UAI, Type *RT, bool Signed,
+ SmallPtrSetImpl<Instruction *> &CI)
+ : StartValue(Start), LoopExitInstr(Exit), Kind(K), FMF(FMF),
+ MinMaxKind(MK), UnsafeAlgebraInst(UAI), RecurrenceType(RT),
+ IsSigned(Signed) {
CastInsts.insert(CI.begin(), CI.end());
}
@@ -198,6 +200,8 @@ public:
MinMaxRecurrenceKind getMinMaxRecurrenceKind() { return MinMaxKind; }
+ FastMathFlags getFastMathFlags() { return FMF; }
+
TrackingVH<Value> getRecurrenceStartValue() { return StartValue; }
Instruction *getLoopExitInstr() { return LoopExitInstr; }
@@ -237,6 +241,9 @@ private:
Instruction *LoopExitInstr = nullptr;
// The kind of the recurrence.
RecurrenceKind Kind = RK_NoRecurrence;
+ // The fast-math flags on the recurrent instructions. We propagate these
+ // fast-math flags into the vectorized FP instructions we generate.
+ FastMathFlags FMF;
// If this a min/max recurrence the kind of recurrence.
MinMaxRecurrenceKind MinMaxKind = MRK_Invalid;
// First occurrence of unasfe algebra in the PHI's use-chain.
Modified: llvm/trunk/include/llvm/IR/Operator.h
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/include/llvm/IR/Operator.h?rev=355868&r1=355867&r2=355868&view=diff
==============================================================================
--- llvm/trunk/include/llvm/IR/Operator.h (original)
+++ llvm/trunk/include/llvm/IR/Operator.h Mon Mar 11 14:36:41 2019
@@ -187,6 +187,12 @@ public:
FastMathFlags() = default;
+ static FastMathFlags getFast() {
+ FastMathFlags FMF;
+ FMF.setFast();
+ return FMF;
+ }
+
bool any() const { return Flags != 0; }
bool none() const { return Flags == 0; }
bool all() const { return Flags == ~0U; }
Modified: llvm/trunk/include/llvm/Transforms/Utils/LoopUtils.h
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/include/llvm/Transforms/Utils/LoopUtils.h?rev=355868&r1=355867&r2=355868&view=diff
==============================================================================
--- llvm/trunk/include/llvm/Transforms/Utils/LoopUtils.h (original)
+++ llvm/trunk/include/llvm/Transforms/Utils/LoopUtils.h Mon Mar 11 14:36:41 2019
@@ -296,6 +296,7 @@ getOrderedReduction(IRBuilder<> &Builder
Value *getShuffleReduction(IRBuilder<> &Builder, Value *Src, unsigned Op,
RecurrenceDescriptor::MinMaxRecurrenceKind
MinMaxKind = RecurrenceDescriptor::MRK_Invalid,
+ FastMathFlags FMF = FastMathFlags(),
ArrayRef<Value *> RedOps = None);
/// Create a target reduction of the given vector. The reduction operation
@@ -308,6 +309,7 @@ Value *createSimpleTargetReduction(IRBui
unsigned Opcode, Value *Src,
TargetTransformInfo::ReductionFlags Flags =
TargetTransformInfo::ReductionFlags(),
+ FastMathFlags FMF = FastMathFlags(),
ArrayRef<Value *> RedOps = None);
/// Create a generic target reduction using a recurrence descriptor \p Desc
Modified: llvm/trunk/lib/Analysis/IVDescriptors.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Analysis/IVDescriptors.cpp?rev=355868&r1=355867&r2=355868&view=diff
==============================================================================
--- llvm/trunk/lib/Analysis/IVDescriptors.cpp (original)
+++ llvm/trunk/lib/Analysis/IVDescriptors.cpp Mon Mar 11 14:36:41 2019
@@ -251,6 +251,10 @@ bool RecurrenceDescriptor::AddReductionV
Worklist.push_back(Start);
VisitedInsts.insert(Start);
+ // Start with all flags set because we will intersect this with the reduction
+ // flags from all the reduction operations.
+ FastMathFlags FMF = FastMathFlags::getFast();
+
// A value in the reduction can be used:
// - By the reduction:
// - Reduction operation:
@@ -296,6 +300,8 @@ bool RecurrenceDescriptor::AddReductionV
ReduxDesc = isRecurrenceInstr(Cur, Kind, ReduxDesc, HasFunNoNaNAttr);
if (!ReduxDesc.isRecurrence())
return false;
+ if (isa<FPMathOperator>(ReduxDesc.getPatternInst()))
+ FMF &= ReduxDesc.getPatternInst()->getFastMathFlags();
}
bool IsASelect = isa<SelectInst>(Cur);
@@ -441,7 +447,7 @@ bool RecurrenceDescriptor::AddReductionV
// Save the description of this reduction variable.
RecurrenceDescriptor RD(
- RdxStart, ExitInstruction, Kind, ReduxDesc.getMinMaxKind(),
+ RdxStart, ExitInstruction, Kind, FMF, ReduxDesc.getMinMaxKind(),
ReduxDesc.getUnsafeAlgebraInst(), RecurrenceType, IsSigned, CastInsts);
RedDes = RD;
@@ -550,7 +556,7 @@ RecurrenceDescriptor::InstDesc
RecurrenceDescriptor::isRecurrenceInstr(Instruction *I, RecurrenceKind Kind,
InstDesc &Prev, bool HasFunNoNaNAttr) {
Instruction *UAI = Prev.getUnsafeAlgebraInst();
- if (!UAI && isa<FPMathOperator>(I) && !I->isFast())
+ if (!UAI && isa<FPMathOperator>(I) && !I->hasAllowReassoc())
UAI = I; // Found an unsafe (unvectorizable) algebra instruction.
switch (I->getOpcode()) {
Modified: llvm/trunk/lib/CodeGen/ExpandReductions.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/CodeGen/ExpandReductions.cpp?rev=355868&r1=355867&r2=355868&view=diff
==============================================================================
--- llvm/trunk/lib/CodeGen/ExpandReductions.cpp (original)
+++ llvm/trunk/lib/CodeGen/ExpandReductions.cpp Mon Mar 11 14:36:41 2019
@@ -118,9 +118,11 @@ bool expandReductions(Function &F, const
}
if (!TTI->shouldExpandReduction(II))
continue;
+ FastMathFlags FMF =
+ isa<FPMathOperator>(II) ? II->getFastMathFlags() : FastMathFlags{};
Value *Rdx =
IsOrdered ? getOrderedReduction(Builder, Acc, Vec, getOpcode(ID), MRK)
- : getShuffleReduction(Builder, Vec, getOpcode(ID), MRK);
+ : getShuffleReduction(Builder, Vec, getOpcode(ID), MRK, FMF);
II->replaceAllUsesWith(Rdx);
II->eraseFromParent();
Changed = true;
Modified: llvm/trunk/lib/Transforms/Utils/LoopUtils.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Transforms/Utils/LoopUtils.cpp?rev=355868&r1=355867&r2=355868&view=diff
==============================================================================
--- llvm/trunk/lib/Transforms/Utils/LoopUtils.cpp (original)
+++ llvm/trunk/lib/Transforms/Utils/LoopUtils.cpp Mon Mar 11 14:36:41 2019
@@ -671,13 +671,9 @@ bool llvm::hasIterationCountInvariantInP
return true;
}
-/// Adds a 'fast' flag to floating point operations.
-static Value *addFastMathFlag(Value *V) {
- if (isa<FPMathOperator>(V)) {
- FastMathFlags Flags;
- Flags.setFast();
- cast<Instruction>(V)->setFastMathFlags(Flags);
- }
+static Value *addFastMathFlag(Value *V, FastMathFlags FMF) {
+ if (isa<FPMathOperator>(V))
+ cast<Instruction>(V)->setFastMathFlags(FMF);
return V;
}
@@ -761,7 +757,7 @@ llvm::getOrderedReduction(IRBuilder<> &B
Value *
llvm::getShuffleReduction(IRBuilder<> &Builder, Value *Src, unsigned Op,
RecurrenceDescriptor::MinMaxRecurrenceKind MinMaxKind,
- ArrayRef<Value *> RedOps) {
+ FastMathFlags FMF, ArrayRef<Value *> RedOps) {
unsigned VF = Src->getType()->getVectorNumElements();
// VF is a power of 2 so we can emit the reduction using log2(VF) shuffles
// and vector ops, reducing the set of values being computed by half each
@@ -786,7 +782,8 @@ llvm::getShuffleReduction(IRBuilder<> &B
if (Op != Instruction::ICmp && Op != Instruction::FCmp) {
// Floating point operations had to be 'fast' to enable the reduction.
TmpVec = addFastMathFlag(Builder.CreateBinOp((Instruction::BinaryOps)Op,
- TmpVec, Shuf, "bin.rdx"));
+ TmpVec, Shuf, "bin.rdx"),
+ FMF);
} else {
assert(MinMaxKind != RecurrenceDescriptor::MRK_Invalid &&
"Invalid min/max");
@@ -803,7 +800,7 @@ llvm::getShuffleReduction(IRBuilder<> &B
/// flags (if generating min/max reductions).
Value *llvm::createSimpleTargetReduction(
IRBuilder<> &Builder, const TargetTransformInfo *TTI, unsigned Opcode,
- Value *Src, TargetTransformInfo::ReductionFlags Flags,
+ Value *Src, TargetTransformInfo::ReductionFlags Flags, FastMathFlags FMF,
ArrayRef<Value *> RedOps) {
assert(isa<VectorType>(Src->getType()) && "Type must be a vector");
@@ -873,7 +870,7 @@ Value *llvm::createSimpleTargetReduction
}
if (TTI->useReductionIntrinsic(Opcode, Src->getType(), Flags))
return BuildFunc();
- return getShuffleReduction(Builder, Src, Opcode, MinMaxKind, RedOps);
+ return getShuffleReduction(Builder, Src, Opcode, MinMaxKind, FMF, RedOps);
}
/// Create a vector reduction using a given recurrence descriptor.
@@ -888,28 +885,37 @@ Value *llvm::createTargetReduction(IRBui
Flags.NoNaN = NoNaN;
switch (RecKind) {
case RD::RK_FloatAdd:
- return createSimpleTargetReduction(B, TTI, Instruction::FAdd, Src, Flags);
+ return createSimpleTargetReduction(B, TTI, Instruction::FAdd, Src, Flags,
+ Desc.getFastMathFlags());
case RD::RK_FloatMult:
- return createSimpleTargetReduction(B, TTI, Instruction::FMul, Src, Flags);
+ return createSimpleTargetReduction(B, TTI, Instruction::FMul, Src, Flags,
+ Desc.getFastMathFlags());
case RD::RK_IntegerAdd:
- return createSimpleTargetReduction(B, TTI, Instruction::Add, Src, Flags);
+ return createSimpleTargetReduction(B, TTI, Instruction::Add, Src, Flags,
+ Desc.getFastMathFlags());
case RD::RK_IntegerMult:
- return createSimpleTargetReduction(B, TTI, Instruction::Mul, Src, Flags);
+ return createSimpleTargetReduction(B, TTI, Instruction::Mul, Src, Flags,
+ Desc.getFastMathFlags());
case RD::RK_IntegerAnd:
- return createSimpleTargetReduction(B, TTI, Instruction::And, Src, Flags);
+ return createSimpleTargetReduction(B, TTI, Instruction::And, Src, Flags,
+ Desc.getFastMathFlags());
case RD::RK_IntegerOr:
- return createSimpleTargetReduction(B, TTI, Instruction::Or, Src, Flags);
+ return createSimpleTargetReduction(B, TTI, Instruction::Or, Src, Flags,
+ Desc.getFastMathFlags());
case RD::RK_IntegerXor:
- return createSimpleTargetReduction(B, TTI, Instruction::Xor, Src, Flags);
+ return createSimpleTargetReduction(B, TTI, Instruction::Xor, Src, Flags,
+ Desc.getFastMathFlags());
case RD::RK_IntegerMinMax: {
RD::MinMaxRecurrenceKind MMKind = Desc.getMinMaxRecurrenceKind();
Flags.IsMaxOp = (MMKind == RD::MRK_SIntMax || MMKind == RD::MRK_UIntMax);
Flags.IsSigned = (MMKind == RD::MRK_SIntMax || MMKind == RD::MRK_SIntMin);
- return createSimpleTargetReduction(B, TTI, Instruction::ICmp, Src, Flags);
+ return createSimpleTargetReduction(B, TTI, Instruction::ICmp, Src, Flags,
+ Desc.getFastMathFlags());
}
case RD::RK_FloatMinMax: {
Flags.IsMaxOp = Desc.getMinMaxRecurrenceKind() == RD::MRK_FloatMax;
- return createSimpleTargetReduction(B, TTI, Instruction::FCmp, Src, Flags);
+ return createSimpleTargetReduction(B, TTI, Instruction::FCmp, Src, Flags,
+ Desc.getFastMathFlags());
}
default:
llvm_unreachable("Unhandled RecKind");
Modified: llvm/trunk/lib/Transforms/Vectorize/LoopVectorize.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Transforms/Vectorize/LoopVectorize.cpp?rev=355868&r1=355867&r2=355868&view=diff
==============================================================================
--- llvm/trunk/lib/Transforms/Vectorize/LoopVectorize.cpp (original)
+++ llvm/trunk/lib/Transforms/Vectorize/LoopVectorize.cpp Mon Mar 11 14:36:41 2019
@@ -319,11 +319,14 @@ static unsigned getReciprocalPredBlockPr
/// A helper function that adds a 'fast' flag to floating-point operations.
static Value *addFastMathFlag(Value *V) {
- if (isa<FPMathOperator>(V)) {
- FastMathFlags Flags;
- Flags.setFast();
- cast<Instruction>(V)->setFastMathFlags(Flags);
- }
+ if (isa<FPMathOperator>(V))
+ cast<Instruction>(V)->setFastMathFlags(FastMathFlags::getFast());
+ return V;
+}
+
+static Value *addFastMathFlag(Value *V, FastMathFlags FMF) {
+ if (isa<FPMathOperator>(V))
+ cast<Instruction>(V)->setFastMathFlags(FMF);
return V;
}
@@ -3612,7 +3615,8 @@ void InnerLoopVectorizer::fixReduction(P
// Floating point operations had to be 'fast' to enable the reduction.
ReducedPartRdx = addFastMathFlag(
Builder.CreateBinOp((Instruction::BinaryOps)Op, RdxPart,
- ReducedPartRdx, "bin.rdx"));
+ ReducedPartRdx, "bin.rdx"),
+ RdxDesc.getFastMathFlags());
else
ReducedPartRdx = createMinMaxOp(Builder, MinMaxKind, ReducedPartRdx,
RdxPart);
Modified: llvm/trunk/lib/Transforms/Vectorize/SLPVectorizer.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Transforms/Vectorize/SLPVectorizer.cpp?rev=355868&r1=355867&r2=355868&view=diff
==============================================================================
--- llvm/trunk/lib/Transforms/Vectorize/SLPVectorizer.cpp (original)
+++ llvm/trunk/lib/Transforms/Vectorize/SLPVectorizer.cpp Mon Mar 11 14:36:41 2019
@@ -5929,7 +5929,8 @@ private:
if (!IsPairwiseReduction)
return createSimpleTargetReduction(
Builder, TTI, ReductionData.getOpcode(), VectorizedValue,
- ReductionData.getFlags(), ReductionOps.back());
+ ReductionData.getFlags(), FastMathFlags::getFast(),
+ ReductionOps.back());
Value *TmpVec = VectorizedValue;
for (unsigned i = ReduxWidth / 2; i != 0; i >>= 1) {
Added: llvm/trunk/test/Transforms/LoopVectorize/reduction-fastmath.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/Transforms/LoopVectorize/reduction-fastmath.ll?rev=355868&view=auto
==============================================================================
--- llvm/trunk/test/Transforms/LoopVectorize/reduction-fastmath.ll (added)
+++ llvm/trunk/test/Transforms/LoopVectorize/reduction-fastmath.ll Mon Mar 11 14:36:41 2019
@@ -0,0 +1,112 @@
+; RUN: opt -S -loop-vectorize < %s | FileCheck %s
+
+target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128"
+target triple = "x86_64-unknown-linux-gnu"
+
+define float @reduction_sum_float_ieee(i32 %n, float* %array) {
+; CHECK-LABEL: define float @reduction_sum_float_ieee(
+entry:
+ %entry.cond = icmp ne i32 0, 4096
+ br i1 %entry.cond, label %loop, label %loop.exit
+
+loop:
+ %idx = phi i32 [ 0, %entry ], [ %idx.inc, %loop ]
+ %sum = phi float [ 0.000000e+00, %entry ], [ %sum.inc, %loop ]
+ %address = getelementptr float, float* %array, i32 %idx
+ %value = load float, float* %address
+ %sum.inc = fadd float %sum, %value
+ %idx.inc = add i32 %idx, 1
+ %be.cond = icmp ne i32 %idx.inc, 4096
+ br i1 %be.cond, label %loop, label %loop.exit
+
+loop.exit:
+ %sum.lcssa = phi float [ %sum.inc, %loop ], [ 0.000000e+00, %entry ]
+; CHECK-NOT: %wide.load = load <4 x float>, <4 x float>*
+; CHECK: ret float %sum.lcssa
+ ret float %sum.lcssa
+}
+
+define float @reduction_sum_float_fastmath(i32 %n, float* %array) {
+; CHECK-LABEL: define float @reduction_sum_float_fastmath(
+; CHECK: fadd fast <4 x float>
+; CHECK: fadd fast <4 x float>
+; CHECK: fadd fast <4 x float>
+; CHECK: fadd fast <4 x float>
+; CHECK: fadd fast <4 x float>
+entry:
+ %entry.cond = icmp ne i32 0, 4096
+ br i1 %entry.cond, label %loop, label %loop.exit
+
+loop:
+ %idx = phi i32 [ 0, %entry ], [ %idx.inc, %loop ]
+ %sum = phi float [ 0.000000e+00, %entry ], [ %sum.inc, %loop ]
+ %address = getelementptr float, float* %array, i32 %idx
+ %value = load float, float* %address
+ %sum.inc = fadd fast float %sum, %value
+ %idx.inc = add i32 %idx, 1
+ %be.cond = icmp ne i32 %idx.inc, 4096
+ br i1 %be.cond, label %loop, label %loop.exit
+
+loop.exit:
+ %sum.lcssa = phi float [ %sum.inc, %loop ], [ 0.000000e+00, %entry ]
+; CHECK: ret float %sum.lcssa
+ ret float %sum.lcssa
+}
+
+define float @reduction_sum_float_only_reassoc(i32 %n, float* %array) {
+; CHECK-LABEL: define float @reduction_sum_float_only_reassoc(
+; CHECK-NOT: fadd fast
+; CHECK: fadd reassoc <4 x float>
+; CHECK: fadd reassoc <4 x float>
+; CHECK: fadd reassoc <4 x float>
+; CHECK: fadd reassoc <4 x float>
+; CHECK: fadd reassoc <4 x float>
+
+entry:
+ %entry.cond = icmp ne i32 0, 4096
+ br i1 %entry.cond, label %loop, label %loop.exit
+
+loop:
+ %idx = phi i32 [ 0, %entry ], [ %idx.inc, %loop ]
+ %sum = phi float [ 0.000000e+00, %entry ], [ %sum.inc, %loop ]
+ %address = getelementptr float, float* %array, i32 %idx
+ %value = load float, float* %address
+ %sum.inc = fadd reassoc float %sum, %value
+ %idx.inc = add i32 %idx, 1
+ %be.cond = icmp ne i32 %idx.inc, 4096
+ br i1 %be.cond, label %loop, label %loop.exit
+
+loop.exit:
+ %sum.lcssa = phi float [ %sum.inc, %loop ], [ 0.000000e+00, %entry ]
+; CHECK: ret float %sum.lcssa
+ ret float %sum.lcssa
+}
+
+define float @reduction_sum_float_only_reassoc_and_contract(i32 %n, float* %array) {
+; CHECK-LABEL: define float @reduction_sum_float_only_reassoc_and_contract(
+; CHECK-NOT: fadd fast
+; CHECK: fadd reassoc contract <4 x float>
+; CHECK: fadd reassoc contract <4 x float>
+; CHECK: fadd reassoc contract <4 x float>
+; CHECK: fadd reassoc contract <4 x float>
+; CHECK: fadd reassoc contract <4 x float>
+
+entry:
+ %entry.cond = icmp ne i32 0, 4096
+ br i1 %entry.cond, label %loop, label %loop.exit
+
+loop:
+ %idx = phi i32 [ 0, %entry ], [ %idx.inc, %loop ]
+ %sum = phi float [ 0.000000e+00, %entry ], [ %sum.inc, %loop ]
+ %address = getelementptr float, float* %array, i32 %idx
+ %value = load float, float* %address
+ %sum.inc = fadd reassoc contract float %sum, %value
+ %idx.inc = add i32 %idx, 1
+ %be.cond = icmp ne i32 %idx.inc, 4096
+ br i1 %be.cond, label %loop, label %loop.exit
+
+loop.exit:
+ %sum.lcssa = phi float [ %sum.inc, %loop ], [ 0.000000e+00, %entry ]
+; CHECK: ret float %sum.lcssa
+ ret float %sum.lcssa
+}
More information about the llvm-commits mailing list