[llvm-branch-commits] [llvm] 1c54112 - [SLP] refactor more reduction functions; NFC
Sanjay Patel via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Wed Jan 20 08:20:39 PST 2021
Author: Sanjay Patel
Date: 2021-01-20T11:14:48-05:00
New Revision: 1c54112a5762ebab2c14a90c55f27d00bfced7f8
URL: https://github.com/llvm/llvm-project/commit/1c54112a5762ebab2c14a90c55f27d00bfced7f8
DIFF: https://github.com/llvm/llvm-project/commit/1c54112a5762ebab2c14a90c55f27d00bfced7f8.diff
LOG: [SLP] refactor more reduction functions; NFC
We were able to remove almost all of the state from
OperationData, so these helper functions no longer make
sense as members of that class; instead, make them static
and pass the RecurKind in as a parameter.
More streamlining is possible, but I'm trying to avoid
logic/typo bugs while fixing this. Eventually, we should
not need the `OperationData` class.
Added:
Modified:
llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
Removed:
################################################################################
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 3d657b0b898c..3192d7959f70 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -6427,76 +6427,6 @@ class HorizontalReduction {
return IsLeafValue || Kind != RecurKind::None;
}
- /// Return true if this operation is a cmp+select idiom.
- bool isCmpSel() const {
- assert(Kind != RecurKind::None && "Expected reduction operation.");
- return RecurrenceDescriptor::isIntMinMaxRecurrenceKind(Kind);
- }
-
- /// Get the index of the first operand.
- unsigned getFirstOperandIndex() const {
- assert(!!*this && "The opcode is not set.");
- // We allow calling this before 'Kind' is set, so handle that specially.
- if (Kind == RecurKind::None)
- return 0;
- return isCmpSel() ? 1 : 0;
- }
-
- /// Total number of operands in the reduction operation.
- unsigned getNumberOfOperands() const {
- assert(Kind != RecurKind::None && !!*this &&
- "Expected reduction operation.");
- return isCmpSel() ? 3 : 2;
- }
-
- /// Checks if the instruction is in basic block \p BB.
- /// For a min/max reduction check that both compare and select are in \p BB.
- bool hasSameParent(Instruction *I, BasicBlock *BB, bool IsRedOp) const {
- assert(Kind != RecurKind::None && !!*this &&
- "Expected reduction operation.");
- if (IsRedOp && isCmpSel()) {
- auto *Cmp = cast<Instruction>(cast<SelectInst>(I)->getCondition());
- return I->getParent() == BB && Cmp && Cmp->getParent() == BB;
- }
- return I->getParent() == BB;
- }
-
- /// Expected number of uses for reduction operations/reduced values.
- bool hasRequiredNumberOfUses(Instruction *I, bool IsReductionOp) const {
- assert(Kind != RecurKind::None && !!*this &&
- "Expected reduction operation.");
- // SelectInst must be used twice while the condition op must have single
- // use only.
- if (isCmpSel())
- return I->hasNUses(2) &&
- (!IsReductionOp ||
- cast<SelectInst>(I)->getCondition()->hasOneUse());
-
- // Arithmetic reduction operation must be used once only.
- return I->hasOneUse();
- }
-
- /// Initializes the list of reduction operations.
- void initReductionOps(ReductionOpsListType &ReductionOps) {
- assert(Kind != RecurKind::None && !!*this &&
- "Expected reduction operation.");
- if (isCmpSel())
- ReductionOps.assign(2, ReductionOpsType());
- else
- ReductionOps.assign(1, ReductionOpsType());
- }
-
- /// Add all reduction operations for the reduction instruction \p I.
- void addReductionOps(Instruction *I, ReductionOpsListType &ReductionOps) {
- assert(Kind != RecurKind::None && "Expected reduction operation.");
- if (isCmpSel()) {
- ReductionOps[0].emplace_back(cast<SelectInst>(I)->getCondition());
- ReductionOps[1].emplace_back(I);
- } else {
- ReductionOps[0].emplace_back(I);
- }
- }
-
/// Checks if instruction is associative and can be vectorized.
bool isAssociative(Instruction *I) const {
assert(Kind != RecurKind::None && "Expected reduction operation.");
@@ -6529,16 +6459,6 @@ class HorizontalReduction {
/// Get kind of reduction data.
RecurKind getKind() const { return Kind; }
- Value *getLHS(Instruction *I) const {
- if (Kind == RecurKind::None)
- return nullptr;
- return I->getOperand(getFirstOperandIndex());
- }
- Value *getRHS(Instruction *I) const {
- if (Kind == RecurKind::None)
- return nullptr;
- return I->getOperand(getFirstOperandIndex() + 1);
- }
};
WeakTrackingVH ReductionRoot;
@@ -6559,7 +6479,7 @@ class HorizontalReduction {
// Do not perform analysis of remaining operands of ParentStackElem.first
// instruction, this whole instruction is an extra argument.
OperationData OpData = getOperationData(ParentStackElem.first);
- ParentStackElem.second = OpData.getNumberOfOperands();
+ ParentStackElem.second = getNumberOfOperands(OpData.getKind());
} else {
// We ran into something like:
// ParentStackElem.first += ... + ExtraArg + ...
@@ -6730,6 +6650,81 @@ class HorizontalReduction {
return OperationData(*I);
}
+ /// Return true if \p Kind is reduced via a cmp+select (integer min/max) idiom.
+ static bool isCmpSel(RecurKind Kind) {
+ return RecurrenceDescriptor::isIntMinMaxRecurrenceKind(Kind);
+ }
+
+ /// Index of the first reduced operand: 1 for cmp+select (skip the condition), else 0.
+ static unsigned getFirstOperandIndex(RecurKind Kind) {
+ // Kind may be RecurKind::None (e.g. data for a leaf value); that also uses index 0.
+ if (Kind == RecurKind::None)
+ return 0;
+ return isCmpSel(Kind) ? 1 : 0;
+ }
+
+ /// Total number of operands in the reduction operation: 3 for cmp+select, else 2.
+ static unsigned getNumberOfOperands(RecurKind Kind) {
+ return isCmpSel(Kind) ? 3 : 2;
+ }
+
+ /// Checks if \p I is in basic block \p BB. For a reduction op (\p IsRedOp) of a
+ /// cmp+select kind, the select's compare condition must also be in \p BB.
+ static bool hasSameParent(RecurKind Kind, Instruction *I, BasicBlock *BB,
+ bool IsRedOp) {
+ if (IsRedOp && isCmpSel(Kind)) {
+ auto *Cmp = cast<Instruction>(cast<SelectInst>(I)->getCondition());
+ return I->getParent() == BB && Cmp && Cmp->getParent() == BB; // NOTE(review): cast<> presumably never yields null, so `Cmp &&` looks redundant - confirm.
+ }
+ return I->getParent() == BB;
+ }
+
+ /// Return true if \p I has the number of uses expected for a reduction op / reduced value.
+ static bool hasRequiredNumberOfUses(RecurKind Kind, Instruction *I,
+ bool IsReductionOp) {
+ // SelectInst must be used twice while the condition op must have single
+ // use only.
+ if (isCmpSel(Kind))
+ return I->hasNUses(2) &&
+ (!IsReductionOp ||
+ cast<SelectInst>(I)->getCondition()->hasOneUse());
+
+ // Arithmetic reduction operation must be used once only.
+ return I->hasOneUse();
+ }
+
+ /// Reset \p ReductionOps: two lists ([0] compares, [1] selects) for cmp+select, else one.
+ static void initReductionOps(RecurKind Kind,
+ ReductionOpsListType &ReductionOps) {
+ if (isCmpSel(Kind))
+ ReductionOps.assign(2, ReductionOpsType());
+ else
+ ReductionOps.assign(1, ReductionOpsType());
+ }
+
+ /// Append \p I (and, for cmp+select, its compare condition) to \p ReductionOps.
+ static void addReductionOps(RecurKind Kind, Instruction *I,
+ ReductionOpsListType &ReductionOps) {
+ assert(Kind != RecurKind::None && "Expected reduction operation.");
+ if (isCmpSel(Kind)) {
+ ReductionOps[0].emplace_back(cast<SelectInst>(I)->getCondition());
+ ReductionOps[1].emplace_back(I);
+ } else {
+ ReductionOps[0].emplace_back(I);
+ }
+ }
+ /// getLHS/getRHS: the two reduced operands of \p I (after any select condition); nullptr if \p Kind is None.
+ static Value *getLHS(RecurKind Kind, Instruction *I) {
+ if (Kind == RecurKind::None)
+ return nullptr;
+ return I->getOperand(getFirstOperandIndex(Kind));
+ }
+ static Value *getRHS(RecurKind Kind, Instruction *I) {
+ if (Kind == RecurKind::None)
+ return nullptr;
+ return I->getOperand(getFirstOperandIndex(Kind) + 1);
+ }
+
public:
HorizontalReduction() = default;
@@ -6744,13 +6739,13 @@ class HorizontalReduction {
// r *= v1 + v2 + v3 + v4
// In such a case start looking for a tree rooted in the first '+'.
if (Phi) {
- if (RdxTreeInst.getLHS(B) == Phi) {
+ if (getLHS(RdxTreeInst.getKind(), B) == Phi) {
Phi = nullptr;
- B = dyn_cast<Instruction>(RdxTreeInst.getRHS(B));
+ B = dyn_cast<Instruction>(getRHS(RdxTreeInst.getKind(), B));
RdxTreeInst = getOperationData(B);
- } else if (RdxTreeInst.getRHS(B) == Phi) {
+ } else if (getRHS(RdxTreeInst.getKind(), B) == Phi) {
Phi = nullptr;
- B = dyn_cast<Instruction>(RdxTreeInst.getLHS(B));
+ B = dyn_cast<Instruction>(getLHS(RdxTreeInst.getKind(), B));
RdxTreeInst = getOperationData(B);
}
}
@@ -6775,8 +6770,9 @@ class HorizontalReduction {
// Post order traverse the reduction tree starting at B. We only handle true
// trees containing only binary operators.
SmallVector<std::pair<Instruction *, unsigned>, 32> Stack;
- Stack.push_back(std::make_pair(B, RdxTreeInst.getFirstOperandIndex()));
- RdxTreeInst.initReductionOps(ReductionOps);
+ Stack.push_back(
+ std::make_pair(B, getFirstOperandIndex(RdxTreeInst.getKind())));
+ initReductionOps(RdxTreeInst.getKind(), ReductionOps);
while (!Stack.empty()) {
Instruction *TreeN = Stack.back().first;
unsigned EdgeToVisit = Stack.back().second++;
@@ -6784,7 +6780,8 @@ class HorizontalReduction {
bool IsReducedValue = OpData != RdxTreeInst;
// Postorder visit.
- if (IsReducedValue || EdgeToVisit == OpData.getNumberOfOperands()) {
+ if (IsReducedValue ||
+ EdgeToVisit == getNumberOfOperands(OpData.getKind())) {
if (IsReducedValue)
ReducedVals.push_back(TreeN);
else {
@@ -6802,7 +6799,7 @@ class HorizontalReduction {
markExtraArg(Stack[Stack.size() - 2], TreeN);
ExtraArgs.erase(TreeN);
} else
- RdxTreeInst.addReductionOps(TreeN, ReductionOps);
+ addReductionOps(RdxTreeInst.getKind(), TreeN, ReductionOps);
}
// Retract.
Stack.pop_back();
@@ -6822,8 +6819,8 @@ class HorizontalReduction {
// ultimate reduction.
const bool IsRdxInst = EdgeOpData == RdxTreeInst;
if (I && I != Phi && I != B &&
- RdxTreeInst.hasSameParent(I, B->getParent(), IsRdxInst) &&
- RdxTreeInst.hasRequiredNumberOfUses(I, IsRdxInst) &&
+ hasSameParent(RdxTreeInst.getKind(), I, B->getParent(), IsRdxInst) &&
+ hasRequiredNumberOfUses(RdxTreeInst.getKind(), I, IsRdxInst) &&
(!LeafOpcode || LeafOpcode == I->getOpcode() || IsRdxInst)) {
if (IsRdxInst) {
// We need to be able to reassociate the reduction operations.
@@ -6835,7 +6832,8 @@ class HorizontalReduction {
} else if (!LeafOpcode) {
LeafOpcode = I->getOpcode();
}
- Stack.push_back(std::make_pair(I, EdgeOpData.getFirstOperandIndex()));
+ Stack.push_back(
+ std::make_pair(I, getFirstOperandIndex(EdgeOpData.getKind())));
continue;
}
// NextV is an extra argument for TreeN (its parent operation).
@@ -6976,7 +6974,7 @@ class HorizontalReduction {
// Emit a reduction. If the root is a select (min/max idiom), the insert
// point is the compare condition of that select.
Instruction *RdxRootInst = cast<Instruction>(ReductionRoot);
- if (RdxTreeInst.isCmpSel())
+ if (isCmpSel(RdxTreeInst.getKind()))
Builder.SetInsertPoint(getCmpForMinMaxReduction(RdxRootInst));
else
Builder.SetInsertPoint(RdxRootInst);
@@ -7019,7 +7017,7 @@ class HorizontalReduction {
// select, we also have to RAUW for the compare instruction feeding the
// reduction root. That's because the original compare may have extra uses
// besides the final select of the reduction.
- if (RdxTreeInst.isCmpSel()) {
+ if (isCmpSel(RdxTreeInst.getKind())) {
if (auto *VecSelect = dyn_cast<SelectInst>(VectorizedTree)) {
Instruction *ScalarCmp =
getCmpForMinMaxReduction(cast<Instruction>(ReductionRoot));
More information about the llvm-branch-commits mailing list