[llvm] 36b423e - [SLP] NFC. Refactor getSameOpcode and reduce for loop iterations. (#122241)
via llvm-commits
llvm-commits at lists.llvm.org
Thu Jan 9 17:06:10 PST 2025
Author: Han-Kuan Chen
Date: 2025-01-10T09:06:07+08:00
New Revision: 36b423e0f85cb35eb8b211662a0fab70d476f501
URL: https://github.com/llvm/llvm-project/commit/36b423e0f85cb35eb8b211662a0fab70d476f501
DIFF: https://github.com/llvm/llvm-project/commit/36b423e0f85cb35eb8b211662a0fab70d476f501.diff
LOG: [SLP] NFC. Refactor getSameOpcode and reduce for loop iterations. (#122241)
Replace the Cnt and AltIndex indices with the MainOp and AltOp instruction pointers.
Reduce the number of for-loop iterations by starting the loop after MainOp instead of at the beginning of VL.
Added:
Modified:
llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
Removed:
################################################################################
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 6360ddb57007d6..8ff70fdb1180b0 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -916,24 +916,22 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
if (It == VL.end())
return InstructionsState::invalid();
- Value *V = *It;
+ Instruction *MainOp = cast<Instruction>(*It);
unsigned InstCnt = std::count_if(It, VL.end(), IsaPred<Instruction>);
- if ((VL.size() > 2 && !isa<PHINode>(V) && InstCnt < VL.size() / 2) ||
+ if ((VL.size() > 2 && !isa<PHINode>(MainOp) && InstCnt < VL.size() / 2) ||
(VL.size() == 2 && InstCnt < 2))
return InstructionsState::invalid();
- bool IsCastOp = isa<CastInst>(V);
- bool IsBinOp = isa<BinaryOperator>(V);
- bool IsCmpOp = isa<CmpInst>(V);
- CmpInst::Predicate BasePred =
- IsCmpOp ? cast<CmpInst>(V)->getPredicate() : CmpInst::BAD_ICMP_PREDICATE;
- unsigned Opcode = cast<Instruction>(V)->getOpcode();
+ bool IsCastOp = isa<CastInst>(MainOp);
+ bool IsBinOp = isa<BinaryOperator>(MainOp);
+ bool IsCmpOp = isa<CmpInst>(MainOp);
+ CmpInst::Predicate BasePred = IsCmpOp ? cast<CmpInst>(MainOp)->getPredicate()
+ : CmpInst::BAD_ICMP_PREDICATE;
+ Instruction *AltOp = MainOp;
+ unsigned Opcode = MainOp->getOpcode();
unsigned AltOpcode = Opcode;
- unsigned AltIndex = std::distance(VL.begin(), It);
- bool SwappedPredsCompatible = [&]() {
- if (!IsCmpOp)
- return false;
+ bool SwappedPredsCompatible = IsCmpOp && [&]() {
SetVector<unsigned> UniquePreds, UniqueNonSwappedPreds;
UniquePreds.insert(BasePred);
UniqueNonSwappedPreds.insert(BasePred);
@@ -956,18 +954,18 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
}();
// Check for one alternate opcode from another BinaryOperator.
// TODO - generalize to support all operators (types, calls etc.).
- auto *IBase = cast<Instruction>(V);
Intrinsic::ID BaseID = 0;
SmallVector<VFInfo> BaseMappings;
- if (auto *CallBase = dyn_cast<CallInst>(IBase)) {
+ if (auto *CallBase = dyn_cast<CallInst>(MainOp)) {
BaseID = getVectorIntrinsicIDForCall(CallBase, &TLI);
BaseMappings = VFDatabase(*CallBase).getMappings(*CallBase);
if (!isTriviallyVectorizable(BaseID) && BaseMappings.empty())
return InstructionsState::invalid();
}
bool AnyPoison = InstCnt != VL.size();
- for (int Cnt = 0, E = VL.size(); Cnt < E; Cnt++) {
- auto *I = dyn_cast<Instruction>(VL[Cnt]);
+ // Skip MainOp.
+ for (Value *V : iterator_range(It + 1, VL.end())) {
+ auto *I = dyn_cast<Instruction>(V);
if (!I)
continue;
@@ -983,11 +981,11 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
if (Opcode == AltOpcode && isValidForAlternation(InstOpcode) &&
isValidForAlternation(Opcode)) {
AltOpcode = InstOpcode;
- AltIndex = Cnt;
+ AltOp = I;
continue;
}
} else if (IsCastOp && isa<CastInst>(I)) {
- Value *Op0 = IBase->getOperand(0);
+ Value *Op0 = MainOp->getOperand(0);
Type *Ty0 = Op0->getType();
Value *Op1 = I->getOperand(0);
Type *Ty1 = Op1->getType();
@@ -999,12 +997,12 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
isValidForAlternation(InstOpcode) &&
"Cast isn't safe for alternation, logic needs to be updated!");
AltOpcode = InstOpcode;
- AltIndex = Cnt;
+ AltOp = I;
continue;
}
}
- } else if (auto *Inst = dyn_cast<CmpInst>(VL[Cnt]); Inst && IsCmpOp) {
- auto *BaseInst = cast<CmpInst>(V);
+ } else if (auto *Inst = dyn_cast<CmpInst>(I); Inst && IsCmpOp) {
+ auto *BaseInst = cast<CmpInst>(MainOp);
Type *Ty0 = BaseInst->getOperand(0)->getType();
Type *Ty1 = Inst->getOperand(0)->getType();
if (Ty0 == Ty1) {
@@ -1018,21 +1016,21 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
CmpInst::Predicate SwappedCurrentPred =
CmpInst::getSwappedPredicate(CurrentPred);
- if ((E == 2 || SwappedPredsCompatible) &&
+ if ((VL.size() == 2 || SwappedPredsCompatible) &&
(BasePred == CurrentPred || BasePred == SwappedCurrentPred))
continue;
if (isCmpSameOrSwapped(BaseInst, Inst, TLI))
continue;
- auto *AltInst = cast<CmpInst>(VL[AltIndex]);
- if (AltIndex) {
+ auto *AltInst = cast<CmpInst>(AltOp);
+ if (MainOp != AltOp) {
if (isCmpSameOrSwapped(AltInst, Inst, TLI))
continue;
} else if (BasePred != CurrentPred) {
assert(
isValidForAlternation(InstOpcode) &&
"CmpInst isn't safe for alternation, logic needs to be updated!");
- AltIndex = Cnt;
+ AltOp = I;
continue;
}
CmpInst::Predicate AltPred = AltInst->getPredicate();
@@ -1046,17 +1044,17 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
"CastInst.");
if (auto *Gep = dyn_cast<GetElementPtrInst>(I)) {
if (Gep->getNumOperands() != 2 ||
- Gep->getOperand(0)->getType() != IBase->getOperand(0)->getType())
+ Gep->getOperand(0)->getType() != MainOp->getOperand(0)->getType())
return InstructionsState::invalid();
} else if (auto *EI = dyn_cast<ExtractElementInst>(I)) {
if (!isVectorLikeInstWithConstOps(EI))
return InstructionsState::invalid();
} else if (auto *LI = dyn_cast<LoadInst>(I)) {
- auto *BaseLI = cast<LoadInst>(IBase);
+ auto *BaseLI = cast<LoadInst>(MainOp);
if (!LI->isSimple() || !BaseLI->isSimple())
return InstructionsState::invalid();
} else if (auto *Call = dyn_cast<CallInst>(I)) {
- auto *CallBase = cast<CallInst>(IBase);
+ auto *CallBase = cast<CallInst>(MainOp);
if (Call->getCalledFunction() != CallBase->getCalledFunction())
return InstructionsState::invalid();
if (Call->hasOperandBundles() &&
@@ -1086,8 +1084,7 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
return InstructionsState::invalid();
}
- return InstructionsState(cast<Instruction>(V),
- cast<Instruction>(VL[AltIndex]));
+ return InstructionsState(MainOp, AltOp);
}
/// \returns true if all of the values in \p VL have the same type or false
More information about the llvm-commits
mailing list