[llvm-branch-commits] [llvm] d777533 - [SLP] simplify reduction matching
Sanjay Patel via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Thu Jan 21 12:10:59 PST 2021
Author: Sanjay Patel
Date: 2021-01-21T14:58:57-05:00
New Revision: d77753381fe024434ae8ffaaacfe4b9ed9d4d760
URL: https://github.com/llvm/llvm-project/commit/d77753381fe024434ae8ffaaacfe4b9ed9d4d760
DIFF: https://github.com/llvm/llvm-project/commit/d77753381fe024434ae8ffaaacfe4b9ed9d4d760.diff
LOG: [SLP] simplify reduction matching
This is intended to be NFC (no functional change). It removes the
"OperationData" class, which had become nothing more than a wrapper
around a recurrence (reduction) kind.
I adjusted the matching logic to explicitly distinguish instructions
from non-instructions; that distinction is all the "IsLeafValue"
member was tracking.
Added:
Modified:
llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
Removed:
################################################################################
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 2597f88ab88d..73260016f443 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -6401,44 +6401,9 @@ class HorizontalReduction {
SmallVector<Value *, 32> ReducedVals;
// Use map vector to make stable output.
MapVector<Instruction *, Value *> ExtraArgs;
-
- /// This wraps functionality around a RecurKind (reduction kind).
- /// TODO: Remove this class if callers can use the 'Kind' value directly?
- class OperationData {
- /// Kind of the reduction operation.
- RecurKind Kind = RecurKind::None;
- bool IsLeafValue = false;
-
- public:
- explicit OperationData() = default;
-
- /// Constructor for reduced values. They are identified by the bool only.
- explicit OperationData(Instruction &I) { IsLeafValue = true; }
-
- /// Constructor for reduction operations with opcode and type.
- OperationData(RecurKind RdxKind) : Kind(RdxKind) {
- assert(Kind != RecurKind::None && "Expected reduction operation.");
- }
-
- explicit operator bool() const {
- return IsLeafValue || Kind != RecurKind::None;
- }
-
- /// Checks if two operation data are both a reduction op or both a reduced
- /// value.
- bool operator==(const OperationData &OD) const {
- return Kind == OD.Kind && IsLeafValue == OD.IsLeafValue;
- }
- bool operator!=(const OperationData &OD) const { return !(*this == OD); }
-
- /// Get kind of reduction data.
- RecurKind getKind() const { return Kind; }
- };
-
WeakTrackingVH ReductionRoot;
-
- /// The operation data of the reduction operation.
- OperationData RdxTreeInst;
+ /// The type of reduction operation.
+ RecurKind RdxKind;
/// Checks if instruction is associative and can be vectorized.
static bool isVectorizable(RecurKind Kind, Instruction *I) {
@@ -6471,8 +6436,8 @@ class HorizontalReduction {
// in this case.
// Do not perform analysis of remaining operands of ParentStackElem.first
// instruction, this whole instruction is an extra argument.
- OperationData OpData = getOperationData(ParentStackElem.first);
- ParentStackElem.second = getNumberOfOperands(OpData.getKind());
+ RecurKind RdxKind = getRdxKind(ParentStackElem.first);
+ ParentStackElem.second = getNumberOfOperands(RdxKind);
} else {
// We ran into something like:
// ParentStackElem.first += ... + ExtraArg + ...
@@ -6550,39 +6515,37 @@ class HorizontalReduction {
return Op;
}
- static OperationData getOperationData(Instruction *I) {
- if (!I)
- return OperationData();
-
+ static RecurKind getRdxKind(Instruction *I) {
+ assert(I && "Expected instruction for reduction matching");
TargetTransformInfo::ReductionFlags RdxFlags;
if (match(I, m_Add(m_Value(), m_Value())))
- return OperationData(RecurKind::Add);
+ return RecurKind::Add;
if (match(I, m_Mul(m_Value(), m_Value())))
- return OperationData(RecurKind::Mul);
+ return RecurKind::Mul;
if (match(I, m_And(m_Value(), m_Value())))
- return OperationData(RecurKind::And);
+ return RecurKind::And;
if (match(I, m_Or(m_Value(), m_Value())))
- return OperationData(RecurKind::Or);
+ return RecurKind::Or;
if (match(I, m_Xor(m_Value(), m_Value())))
- return OperationData(RecurKind::Xor);
+ return RecurKind::Xor;
if (match(I, m_FAdd(m_Value(), m_Value())))
- return OperationData(RecurKind::FAdd);
+ return RecurKind::FAdd;
if (match(I, m_FMul(m_Value(), m_Value())))
- return OperationData(RecurKind::FMul);
+ return RecurKind::FMul;
if (match(I, m_Intrinsic<Intrinsic::maxnum>(m_Value(), m_Value())))
- return OperationData(RecurKind::FMax);
+ return RecurKind::FMax;
if (match(I, m_Intrinsic<Intrinsic::minnum>(m_Value(), m_Value())))
- return OperationData(RecurKind::FMin);
+ return RecurKind::FMin;
if (match(I, m_SMax(m_Value(), m_Value())))
- return OperationData(RecurKind::SMax);
+ return RecurKind::SMax;
if (match(I, m_SMin(m_Value(), m_Value())))
- return OperationData(RecurKind::SMin);
+ return RecurKind::SMin;
if (match(I, m_UMax(m_Value(), m_Value())))
- return OperationData(RecurKind::UMax);
+ return RecurKind::UMax;
if (match(I, m_UMin(m_Value(), m_Value())))
- return OperationData(RecurKind::UMin);
+ return RecurKind::UMin;
if (auto *Select = dyn_cast<SelectInst>(I)) {
// Try harder: look for min/max pattern based on instructions producing
@@ -6608,39 +6571,39 @@ class HorizontalReduction {
if (match(Cond, m_Cmp(Pred, m_Specific(LHS), m_Instruction(L2)))) {
if (!isa<ExtractElementInst>(RHS) ||
!L2->isIdenticalTo(cast<Instruction>(RHS)))
- return OperationData(*I);
+ return RecurKind::None;
} else if (match(Cond, m_Cmp(Pred, m_Instruction(L1), m_Specific(RHS)))) {
if (!isa<ExtractElementInst>(LHS) ||
!L1->isIdenticalTo(cast<Instruction>(LHS)))
- return OperationData(*I);
+ return RecurKind::None;
} else {
if (!isa<ExtractElementInst>(LHS) || !isa<ExtractElementInst>(RHS))
- return OperationData(*I);
+ return RecurKind::None;
if (!match(Cond, m_Cmp(Pred, m_Instruction(L1), m_Instruction(L2))) ||
!L1->isIdenticalTo(cast<Instruction>(LHS)) ||
!L2->isIdenticalTo(cast<Instruction>(RHS)))
- return OperationData(*I);
+ return RecurKind::None;
}
TargetTransformInfo::ReductionFlags RdxFlags;
switch (Pred) {
default:
- return OperationData(*I);
+ return RecurKind::None;
case CmpInst::ICMP_SGT:
case CmpInst::ICMP_SGE:
- return OperationData(RecurKind::SMax);
+ return RecurKind::SMax;
case CmpInst::ICMP_SLT:
case CmpInst::ICMP_SLE:
- return OperationData(RecurKind::SMin);
+ return RecurKind::SMin;
case CmpInst::ICMP_UGT:
case CmpInst::ICMP_UGE:
- return OperationData(RecurKind::UMax);
+ return RecurKind::UMax;
case CmpInst::ICMP_ULT:
case CmpInst::ICMP_ULE:
- return OperationData(RecurKind::UMin);
+ return RecurKind::UMin;
}
}
- return OperationData(*I);
+ return RecurKind::None;
}
/// Return true if this operation is a cmp+select idiom.
@@ -6724,24 +6687,28 @@ class HorizontalReduction {
assert((!Phi || is_contained(Phi->operands(), B)) &&
"Phi needs to use the binary operator");
- RdxTreeInst = getOperationData(B);
+ RdxKind = getRdxKind(B);
// We could have a initial reductions that is not an add.
// r *= v1 + v2 + v3 + v4
// In such a case start looking for a tree rooted in the first '+'.
if (Phi) {
- if (getLHS(RdxTreeInst.getKind(), B) == Phi) {
+ if (getLHS(RdxKind, B) == Phi) {
Phi = nullptr;
- B = dyn_cast<Instruction>(getRHS(RdxTreeInst.getKind(), B));
- RdxTreeInst = getOperationData(B);
- } else if (getRHS(RdxTreeInst.getKind(), B) == Phi) {
+ B = dyn_cast<Instruction>(getRHS(RdxKind, B));
+ if (!B)
+ return false;
+ RdxKind = getRdxKind(B);
+ } else if (getRHS(RdxKind, B) == Phi) {
Phi = nullptr;
- B = dyn_cast<Instruction>(getLHS(RdxTreeInst.getKind(), B));
- RdxTreeInst = getOperationData(B);
+ B = dyn_cast<Instruction>(getLHS(RdxKind, B));
+ if (!B)
+ return false;
+ RdxKind = getRdxKind(B);
}
}
- if (!isVectorizable(RdxTreeInst.getKind(), B))
+ if (!isVectorizable(RdxKind, B))
return false;
// Analyze "regular" integer/FP types for reductions - no target-specific
@@ -6761,18 +6728,16 @@ class HorizontalReduction {
// Post order traverse the reduction tree starting at B. We only handle true
// trees containing only binary operators.
SmallVector<std::pair<Instruction *, unsigned>, 32> Stack;
- Stack.push_back(
- std::make_pair(B, getFirstOperandIndex(RdxTreeInst.getKind())));
- initReductionOps(RdxTreeInst.getKind());
+ Stack.push_back(std::make_pair(B, getFirstOperandIndex(RdxKind)));
+ initReductionOps(RdxKind);
while (!Stack.empty()) {
Instruction *TreeN = Stack.back().first;
unsigned EdgeToVisit = Stack.back().second++;
- const OperationData OpData = getOperationData(TreeN);
- bool IsReducedValue = OpData != RdxTreeInst;
+ const RecurKind TreeRdxKind = getRdxKind(TreeN);
+ bool IsReducedValue = TreeRdxKind != RdxKind;
// Postorder visit.
- if (IsReducedValue ||
- EdgeToVisit == getNumberOfOperands(OpData.getKind())) {
+ if (IsReducedValue || EdgeToVisit == getNumberOfOperands(TreeRdxKind)) {
if (IsReducedValue)
ReducedVals.push_back(TreeN);
else {
@@ -6790,7 +6755,7 @@ class HorizontalReduction {
markExtraArg(Stack[Stack.size() - 2], TreeN);
ExtraArgs.erase(TreeN);
} else
- addReductionOps(RdxTreeInst.getKind(), TreeN);
+ addReductionOps(RdxKind, TreeN);
}
// Retract.
Stack.pop_back();
@@ -6798,9 +6763,15 @@ class HorizontalReduction {
}
// Visit left or right.
- Value *NextV = TreeN->getOperand(EdgeToVisit);
- auto *I = dyn_cast<Instruction>(NextV);
- const OperationData EdgeOpData = getOperationData(I);
+ Value *EdgeVal = TreeN->getOperand(EdgeToVisit);
+ auto *I = dyn_cast<Instruction>(EdgeVal);
+ if (!I) {
+ // Edge value is not a reduction instruction or a leaf instruction.
+ // (It may be a constant, function argument, or something else.)
+ markExtraArg(Stack.back(), EdgeVal);
+ continue;
+ }
+ RecurKind EdgeRdxKind = getRdxKind(I);
// Continue analysis if the next operand is a reduction operation or
// (possibly) a leaf value. If the leaf value opcode is not set,
// the first met operation != reduction operation is considered as the
@@ -6808,14 +6779,14 @@ class HorizontalReduction {
// Only handle trees in the current basic block.
// Each tree node needs to have minimal number of users except for the
// ultimate reduction.
- const bool IsRdxInst = EdgeOpData == RdxTreeInst;
- if (I && I != Phi && I != B &&
- hasSameParent(RdxTreeInst.getKind(), I, B->getParent(), IsRdxInst) &&
- hasRequiredNumberOfUses(RdxTreeInst.getKind(), I, IsRdxInst) &&
+ const bool IsRdxInst = EdgeRdxKind == RdxKind;
+ if (I != Phi && I != B &&
+ hasSameParent(RdxKind, I, B->getParent(), IsRdxInst) &&
+ hasRequiredNumberOfUses(RdxKind, I, IsRdxInst) &&
(!LeafOpcode || LeafOpcode == I->getOpcode() || IsRdxInst)) {
if (IsRdxInst) {
// We need to be able to reassociate the reduction operations.
- if (!isVectorizable(EdgeOpData.getKind(), I)) {
+ if (!isVectorizable(EdgeRdxKind, I)) {
// I is an extra argument for TreeN (its parent operation).
markExtraArg(Stack.back(), I);
continue;
@@ -6823,12 +6794,11 @@ class HorizontalReduction {
} else if (!LeafOpcode) {
LeafOpcode = I->getOpcode();
}
- Stack.push_back(
- std::make_pair(I, getFirstOperandIndex(EdgeOpData.getKind())));
+ Stack.push_back(std::make_pair(I, getFirstOperandIndex(EdgeRdxKind)));
continue;
}
- // NextV is an extra argument for TreeN (its parent operation).
- markExtraArg(Stack.back(), NextV);
+ // I is an extra argument for TreeN (its parent operation).
+ markExtraArg(Stack.back(), I);
}
return true;
}
@@ -6922,7 +6892,7 @@ class HorizontalReduction {
}
if (V.isTreeTinyAndNotFullyVectorizable())
break;
- if (V.isLoadCombineReductionCandidate(RdxTreeInst.getKind()))
+ if (V.isLoadCombineReductionCandidate(RdxKind))
break;
V.computeMinimumValueSizes();
@@ -6965,7 +6935,7 @@ class HorizontalReduction {
// Emit a reduction. If the root is a select (min/max idiom), the insert
// point is the compare condition of that select.
Instruction *RdxRootInst = cast<Instruction>(ReductionRoot);
- if (isCmpSel(RdxTreeInst.getKind()))
+ if (isCmpSel(RdxKind))
Builder.SetInsertPoint(getCmpForMinMaxReduction(RdxRootInst));
else
Builder.SetInsertPoint(RdxRootInst);
@@ -6979,9 +6949,8 @@ class HorizontalReduction {
} else {
// Update the final value in the reduction.
Builder.SetCurrentDebugLocation(Loc);
- VectorizedTree =
- createOp(Builder, RdxTreeInst.getKind(), VectorizedTree,
- ReducedSubTree, "op.rdx", ReductionOps);
+ VectorizedTree = createOp(Builder, RdxKind, VectorizedTree,
+ ReducedSubTree, "op.rdx", ReductionOps);
}
i += ReduxWidth;
ReduxWidth = PowerOf2Floor(NumReducedVals - i);
@@ -6992,15 +6961,15 @@ class HorizontalReduction {
for (; i < NumReducedVals; ++i) {
auto *I = cast<Instruction>(ReducedVals[i]);
Builder.SetCurrentDebugLocation(I->getDebugLoc());
- VectorizedTree = createOp(Builder, RdxTreeInst.getKind(),
- VectorizedTree, I, "", ReductionOps);
+ VectorizedTree =
+ createOp(Builder, RdxKind, VectorizedTree, I, "", ReductionOps);
}
for (auto &Pair : ExternallyUsedValues) {
// Add each externally used value to the final reduction.
for (auto *I : Pair.second) {
Builder.SetCurrentDebugLocation(I->getDebugLoc());
- VectorizedTree = createOp(Builder, RdxTreeInst.getKind(),
- VectorizedTree, Pair.first, "op.extra", I);
+ VectorizedTree = createOp(Builder, RdxKind, VectorizedTree,
+ Pair.first, "op.extra", I);
}
}
@@ -7008,7 +6977,7 @@ class HorizontalReduction {
// select, we also have to RAUW for the compare instruction feeding the
// reduction root. That's because the original compare may have extra uses
// besides the final select of the reduction.
- if (isCmpSel(RdxTreeInst.getKind())) {
+ if (isCmpSel(RdxKind)) {
if (auto *VecSelect = dyn_cast<SelectInst>(VectorizedTree)) {
Instruction *ScalarCmp =
getCmpForMinMaxReduction(cast<Instruction>(ReductionRoot));
@@ -7032,10 +7001,8 @@ class HorizontalReduction {
unsigned ReduxWidth) {
Type *ScalarTy = FirstReducedVal->getType();
FixedVectorType *VectorTy = FixedVectorType::get(ScalarTy, ReduxWidth);
-
- RecurKind Kind = RdxTreeInst.getKind();
int VectorCost, ScalarCost;
- switch (Kind) {
+ switch (RdxKind) {
case RecurKind::Add:
case RecurKind::Mul:
case RecurKind::Or:
@@ -7043,7 +7010,7 @@ class HorizontalReduction {
case RecurKind::Xor:
case RecurKind::FAdd:
case RecurKind::FMul: {
- unsigned RdxOpcode = RecurrenceDescriptor::getOpcode(Kind);
+ unsigned RdxOpcode = RecurrenceDescriptor::getOpcode(RdxKind);
VectorCost = TTI->getArithmeticReductionCost(RdxOpcode, VectorTy,
/*IsPairwiseForm=*/false);
ScalarCost = TTI->getArithmeticInstrCost(RdxOpcode, ScalarTy);
@@ -7066,7 +7033,8 @@ class HorizontalReduction {
case RecurKind::UMax:
case RecurKind::UMin: {
auto *VecCondTy = cast<VectorType>(CmpInst::makeCmpResultType(VectorTy));
- bool IsUnsigned = Kind == RecurKind::UMax || Kind == RecurKind::UMin;
+ bool IsUnsigned =
+ RdxKind == RecurKind::UMax || RdxKind == RecurKind::UMin;
VectorCost =
TTI->getMinMaxReductionCost(VectorTy, VecCondTy,
/*IsPairwiseForm=*/false, IsUnsigned);
@@ -7098,8 +7066,7 @@ class HorizontalReduction {
// FIXME: The builder should use an FMF guard. It should not be hard-coded
// to 'fast'.
assert(Builder.getFastMathFlags().isFast() && "Expected 'fast' FMF");
- return createSimpleTargetReduction(Builder, TTI, VectorizedValue,
- RdxTreeInst.getKind(),
+ return createSimpleTargetReduction(Builder, TTI, VectorizedValue, RdxKind,
ReductionOps.back());
}
};
More information about the llvm-branch-commits mailing list