[llvm] 26a9f3f - [SLP][NFC]Cleanup getSameOpcode, return InstructionsState::invalid() for non-valid inputs
Alexey Bataev via llvm-commits
llvm-commits at lists.llvm.org
Fri Nov 8 14:03:25 PST 2024
Author: Alexey Bataev
Date: 2024-11-08T14:00:32-08:00
New Revision: 26a9f3f5906c62cff7f2245b98affa432b504a87
URL: https://github.com/llvm/llvm-project/commit/26a9f3f5906c62cff7f2245b98affa432b504a87
DIFF: https://github.com/llvm/llvm-project/commit/26a9f3f5906c62cff7f2245b98affa432b504a87.diff
LOG: [SLP][NFC]Cleanup getSameOpcode, return InstructionsState::invalid() for non-valid inputs
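Just a cleanup, plus related changes in the callers: InstructionsState::invalid() carries a null OpValue, so BoUpSLP::buildTree_rec now reads VL.front() instead of S.OpValue and uses isa_and_present for the remaining S.OpValue type check.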
Added:
Modified:
llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
Removed:
################################################################################
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index a6accf0318a30f..4a73b9c2c4b34a 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -832,6 +832,7 @@ struct InstructionsState {
InstructionsState() = delete;
InstructionsState(Value *OpValue, Instruction *MainOp, Instruction *AltOp)
: OpValue(OpValue), MainOp(MainOp), AltOp(AltOp) {}
+ static InstructionsState invalid() { return {nullptr, nullptr, nullptr}; }
};
} // end anonymous namespace
@@ -891,20 +892,19 @@ static bool isCmpSameOrSwapped(const CmpInst *BaseCI, const CmpInst *CI,
/// could be vectorized even if its structure is diverse.
static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
const TargetLibraryInfo &TLI) {
- constexpr unsigned BaseIndex = 0;
// Make sure these are all Instructions.
- if (llvm::any_of(VL, [](Value *V) { return !isa<Instruction>(V); }))
- return InstructionsState(VL[BaseIndex], nullptr, nullptr);
+ if (!all_of(VL, IsaPred<Instruction>))
+ return InstructionsState::invalid();
- bool IsCastOp = isa<CastInst>(VL[BaseIndex]);
- bool IsBinOp = isa<BinaryOperator>(VL[BaseIndex]);
- bool IsCmpOp = isa<CmpInst>(VL[BaseIndex]);
+ Value *V = VL.front();
+ bool IsCastOp = isa<CastInst>(V);
+ bool IsBinOp = isa<BinaryOperator>(V);
+ bool IsCmpOp = isa<CmpInst>(V);
CmpInst::Predicate BasePred =
- IsCmpOp ? cast<CmpInst>(VL[BaseIndex])->getPredicate()
- : CmpInst::BAD_ICMP_PREDICATE;
- unsigned Opcode = cast<Instruction>(VL[BaseIndex])->getOpcode();
+ IsCmpOp ? cast<CmpInst>(V)->getPredicate() : CmpInst::BAD_ICMP_PREDICATE;
+ unsigned Opcode = cast<Instruction>(V)->getOpcode();
unsigned AltOpcode = Opcode;
- unsigned AltIndex = BaseIndex;
+ unsigned AltIndex = 0;
bool SwappedPredsCompatible = [&]() {
if (!IsCmpOp)
@@ -931,14 +931,14 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
}();
// Check for one alternate opcode from another BinaryOperator.
// TODO - generalize to support all operators (types, calls etc.).
- auto *IBase = cast<Instruction>(VL[BaseIndex]);
+ auto *IBase = cast<Instruction>(V);
Intrinsic::ID BaseID = 0;
SmallVector<VFInfo> BaseMappings;
if (auto *CallBase = dyn_cast<CallInst>(IBase)) {
BaseID = getVectorIntrinsicIDForCall(CallBase, &TLI);
BaseMappings = VFDatabase(*CallBase).getMappings(*CallBase);
if (!isTriviallyVectorizable(BaseID) && BaseMappings.empty())
- return InstructionsState(VL[BaseIndex], nullptr, nullptr);
+ return InstructionsState::invalid();
}
for (int Cnt = 0, E = VL.size(); Cnt < E; Cnt++) {
auto *I = cast<Instruction>(VL[Cnt]);
@@ -970,7 +970,7 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
}
}
} else if (auto *Inst = dyn_cast<CmpInst>(VL[Cnt]); Inst && IsCmpOp) {
- auto *BaseInst = cast<CmpInst>(VL[BaseIndex]);
+ auto *BaseInst = cast<CmpInst>(V);
Type *Ty0 = BaseInst->getOperand(0)->getType();
Type *Ty1 = Inst->getOperand(0)->getType();
if (Ty0 == Ty1) {
@@ -988,7 +988,7 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
if (isCmpSameOrSwapped(BaseInst, Inst, TLI))
continue;
auto *AltInst = cast<CmpInst>(VL[AltIndex]);
- if (AltIndex != BaseIndex) {
+ if (AltIndex) {
if (isCmpSameOrSwapped(AltInst, Inst, TLI))
continue;
} else if (BasePred != CurrentPred) {
@@ -1007,27 +1007,28 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
if (auto *Gep = dyn_cast<GetElementPtrInst>(I)) {
if (Gep->getNumOperands() != 2 ||
Gep->getOperand(0)->getType() != IBase->getOperand(0)->getType())
- return InstructionsState(VL[BaseIndex], nullptr, nullptr);
+ return InstructionsState::invalid();
} else if (auto *EI = dyn_cast<ExtractElementInst>(I)) {
if (!isVectorLikeInstWithConstOps(EI))
- return InstructionsState(VL[BaseIndex], nullptr, nullptr);
+ return InstructionsState::invalid();
} else if (auto *LI = dyn_cast<LoadInst>(I)) {
auto *BaseLI = cast<LoadInst>(IBase);
if (!LI->isSimple() || !BaseLI->isSimple())
- return InstructionsState(VL[BaseIndex], nullptr, nullptr);
+ return InstructionsState::invalid();
} else if (auto *Call = dyn_cast<CallInst>(I)) {
auto *CallBase = cast<CallInst>(IBase);
if (Call->getCalledFunction() != CallBase->getCalledFunction())
- return InstructionsState(VL[BaseIndex], nullptr, nullptr);
- if (Call->hasOperandBundles() && (!CallBase->hasOperandBundles() ||
- !std::equal(Call->op_begin() + Call->getBundleOperandsStartIndex(),
- Call->op_begin() + Call->getBundleOperandsEndIndex(),
- CallBase->op_begin() +
- CallBase->getBundleOperandsStartIndex())))
- return InstructionsState(VL[BaseIndex], nullptr, nullptr);
+ return InstructionsState::invalid();
+ if (Call->hasOperandBundles() &&
+ (!CallBase->hasOperandBundles() ||
+ !std::equal(Call->op_begin() + Call->getBundleOperandsStartIndex(),
+ Call->op_begin() + Call->getBundleOperandsEndIndex(),
+ CallBase->op_begin() +
+ CallBase->getBundleOperandsStartIndex())))
+ return InstructionsState::invalid();
Intrinsic::ID ID = getVectorIntrinsicIDForCall(Call, &TLI);
if (ID != BaseID)
- return InstructionsState(VL[BaseIndex], nullptr, nullptr);
+ return InstructionsState::invalid();
if (!ID) {
SmallVector<VFInfo> Mappings = VFDatabase(*Call).getMappings(*Call);
if (Mappings.size() != BaseMappings.size() ||
@@ -1037,15 +1038,15 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
Mappings.front().Shape.VF != BaseMappings.front().Shape.VF ||
Mappings.front().Shape.Parameters !=
BaseMappings.front().Shape.Parameters)
- return InstructionsState(VL[BaseIndex], nullptr, nullptr);
+ return InstructionsState::invalid();
}
}
continue;
}
- return InstructionsState(VL[BaseIndex], nullptr, nullptr);
+ return InstructionsState::invalid();
}
- return InstructionsState(VL[BaseIndex], cast<Instruction>(VL[BaseIndex]),
+ return InstructionsState(V, cast<Instruction>(V),
cast<Instruction>(VL[AltIndex]));
}
@@ -8019,7 +8020,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
}
// Don't handle vectors.
- if (!SLPReVec && getValueType(S.OpValue)->isVectorTy()) {
+ if (!SLPReVec && getValueType(VL.front())->isVectorTy()) {
LLVM_DEBUG(dbgs() << "SLP: Gathering due to vector type.\n");
newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx);
return;
@@ -8088,7 +8089,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
UserTreeIdx.UserTE->State == TreeEntry::ScatterVectorize;
bool AreAllSameBlock = S.getOpcode() && allSameBlock(VL);
bool AreScatterAllGEPSameBlock =
- (IsScatterVectorizeUserTE && S.OpValue->getType()->isPointerTy() &&
+ (IsScatterVectorizeUserTE && VL.front()->getType()->isPointerTy() &&
VL.size() > 2 &&
all_of(VL,
[&BB](Value *V) {
@@ -8104,7 +8105,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
SortedIndices));
bool AreAllSameInsts = AreAllSameBlock || AreScatterAllGEPSameBlock;
if (!AreAllSameInsts || (!S.getOpcode() && allConstant(VL)) || isSplat(VL) ||
- (isa<InsertElementInst, ExtractValueInst, ExtractElementInst>(
+ (isa_and_present<InsertElementInst, ExtractValueInst, ExtractElementInst>(
S.OpValue) &&
!all_of(VL, isVectorLikeInstWithConstOps)) ||
NotProfitableForVectorization(VL)) {
@@ -8161,7 +8162,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
// Special processing for sorted pointers for ScatterVectorize node with
// constant indeces only.
if (!AreAllSameBlock && AreScatterAllGEPSameBlock) {
- assert(S.OpValue->getType()->isPointerTy() &&
+ assert(VL.front()->getType()->isPointerTy() &&
count_if(VL, IsaPred<GetElementPtrInst>) >= 2 &&
"Expected pointers only.");
// Reset S to make it GetElementPtr kind of node.
More information about the llvm-commits
mailing list