[llvm] fd66195 - [VPlan] Manage compare predicates in VPRecipeWithIRFlags.
Florian Hahn via llvm-commits
llvm-commits at lists.llvm.org
Sat Sep 2 13:45:52 PDT 2023
Author: Florian Hahn
Date: 2023-09-02T21:45:24+01:00
New Revision: fd6619577790f049ce72fa043351028d8c132d58
URL: https://github.com/llvm/llvm-project/commit/fd6619577790f049ce72fa043351028d8c132d58
DIFF: https://github.com/llvm/llvm-project/commit/fd6619577790f049ce72fa043351028d8c132d58.diff
LOG: [VPlan] Manage compare predicates in VPRecipeWithIRFlags.
Extend VPRecipeWithIRFlags to also manage predicates for compares. This
allows removing the custom ICmpULE opcode from VPInstruction, which was a
workaround for the missing predicate handling.
This simplifies the code a bit while also allowing compares with any
predicate. It also fixes a case where the compare predicate wasn't
printed for VPReplicateRecipes (e.g. `icmp eq` was printed as just `icmp`).
Discussed/split off from D150398.
Reviewed By: Ayal
Differential Revision: https://reviews.llvm.org/D158992
Added:
Modified:
llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
llvm/lib/Transforms/Vectorize/VPlan.h
llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
llvm/test/Transforms/LoopVectorize/vplan-sink-scalars-and-merge.ll
Removed:
################################################################################
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h b/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
index a0d496babad3cc..2df8fea85454ea 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
@@ -45,13 +45,17 @@ class VPBuilder {
VPBasicBlock *BB = nullptr;
VPBasicBlock::iterator InsertPt = VPBasicBlock::iterator();
+ /// Insert \p VPI in BB at InsertPt if BB is set.
+ VPInstruction *tryInsertInstruction(VPInstruction *VPI) {
+ if (BB)
+ BB->insert(VPI, InsertPt);
+ return VPI;
+ }
+
VPInstruction *createInstruction(unsigned Opcode,
ArrayRef<VPValue *> Operands, DebugLoc DL,
const Twine &Name = "") {
- VPInstruction *Instr = new VPInstruction(Opcode, Operands, DL, Name);
- if (BB)
- BB->insert(Instr, InsertPt);
- return Instr;
+ return tryInsertInstruction(new VPInstruction(Opcode, Operands, DL, Name));
}
VPInstruction *createInstruction(unsigned Opcode,
@@ -152,6 +156,12 @@ class VPBuilder {
Name);
}
+ /// Create a new ICmp VPInstruction with predicate \p Pred and operands \p A
+ /// and \p B.
+ /// TODO: add createFCmp when needed.
+ VPValue *createICmp(CmpInst::Predicate Pred, VPValue *A, VPValue *B,
+ DebugLoc DL = {}, const Twine &Name = "");
+
//===--------------------------------------------------------------------===//
// RAII helpers.
//===--------------------------------------------------------------------===//
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 750f8ff22a22ab..03a5b08085b19c 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -7372,6 +7372,14 @@ void LoopVectorizationCostModel::collectInLoopReductions() {
}
}
+VPValue *VPBuilder::createICmp(CmpInst::Predicate Pred, VPValue *A, VPValue *B,
+ DebugLoc DL, const Twine &Name) {
+ assert(Pred >= CmpInst::FIRST_ICMP_PREDICATE &&
+ Pred <= CmpInst::LAST_ICMP_PREDICATE && "invalid predicate");
+ return tryInsertInstruction(
+ new VPInstruction(Instruction::ICmp, Pred, A, B, DL, Name));
+}
+
// TODO: we could return a pair of values that specify the max VF and
// min VF, to be used in `buildVPlans(MinVF, MaxVF)` instead of
// `buildVPlans(VF, VF)`. We cannot do it because VPLAN at the moment
@@ -8079,7 +8087,7 @@ void VPRecipeBuilder::createHeaderMask(VPlan &Plan) {
nullptr, "active.lane.mask");
} else {
VPValue *BTC = Plan.getOrCreateBackedgeTakenCount();
- BlockMask = Builder.createNaryOp(VPInstruction::ICmpULE, {IV, BTC});
+ BlockMask = Builder.createICmp(CmpInst::ICMP_ULE, IV, BTC);
}
BlockMaskCache[Header] = BlockMask;
}
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index 0d45882adb3586..8a497e82e1181b 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -814,6 +814,7 @@ class VPRecipeBase : public ilist_node_with_parent<VPRecipeBase, VPBasicBlock>,
/// Class to record LLVM IR flag for a recipe along with it.
class VPRecipeWithIRFlags : public VPRecipeBase {
enum class OperationType : unsigned char {
+ Cmp,
OverflowingBinOp,
PossiblyExactOp,
GEPOp,
@@ -851,11 +852,12 @@ class VPRecipeWithIRFlags : public VPRecipeBase {
OperationType OpType;
union {
+ CmpInst::Predicate CmpPredicate;
WrapFlagsTy WrapFlags;
ExactFlagsTy ExactFlags;
GEPFlagsTy GEPFlags;
FastMathFlagsTy FMFs;
- unsigned char AllFlags;
+ unsigned AllFlags;
};
public:
@@ -869,7 +871,10 @@ class VPRecipeWithIRFlags : public VPRecipeBase {
template <typename IterT>
VPRecipeWithIRFlags(const unsigned char SC, IterT Operands, Instruction &I)
: VPRecipeWithIRFlags(SC, Operands) {
- if (auto *Op = dyn_cast<OverflowingBinaryOperator>(&I)) {
+ if (auto *Op = dyn_cast<CmpInst>(&I)) {
+ OpType = OperationType::Cmp;
+ CmpPredicate = Op->getPredicate();
+ } else if (auto *Op = dyn_cast<OverflowingBinaryOperator>(&I)) {
OpType = OperationType::OverflowingBinOp;
WrapFlags = {Op->hasNoUnsignedWrap(), Op->hasNoSignedWrap()};
} else if (auto *Op = dyn_cast<PossiblyExactOperator>(&I)) {
@@ -884,6 +889,12 @@ class VPRecipeWithIRFlags : public VPRecipeBase {
}
}
+ template <typename IterT>
+ VPRecipeWithIRFlags(const unsigned char SC, IterT Operands,
+ CmpInst::Predicate Pred)
+ : VPRecipeBase(SC, Operands), OpType(OperationType::Cmp),
+ CmpPredicate(Pred) {}
+
template <typename IterT>
VPRecipeWithIRFlags(const unsigned char SC, IterT Operands,
WrapFlagsTy WrapFlags)
@@ -922,6 +933,7 @@ class VPRecipeWithIRFlags : public VPRecipeBase {
FMFs.NoNaNs = false;
FMFs.NoInfs = false;
break;
+ case OperationType::Cmp:
case OperationType::Other:
break;
}
@@ -949,11 +961,18 @@ class VPRecipeWithIRFlags : public VPRecipeBase {
I->setHasAllowContract(FMFs.AllowContract);
I->setHasApproxFunc(FMFs.ApproxFunc);
break;
+ case OperationType::Cmp:
case OperationType::Other:
break;
}
}
+ CmpInst::Predicate getPredicate() const {
+ assert(OpType == OperationType::Cmp &&
+ "recipe doesn't have a compare predicate");
+ return CmpPredicate;
+ }
+
bool isInBounds() const {
assert(OpType == OperationType::GEPOp &&
"recipe doesn't have inbounds flag");
@@ -996,7 +1015,6 @@ class VPInstruction : public VPRecipeWithIRFlags, public VPValue {
Instruction::OtherOpsEnd + 1, // Combines the incoming and previous
// values of a first-order recurrence.
Not,
- ICmpULE,
SLPLoad,
SLPStore,
ActiveLaneMask,
@@ -1042,6 +1060,9 @@ class VPInstruction : public VPRecipeWithIRFlags, public VPValue {
DebugLoc DL = {}, const Twine &Name = "")
: VPInstruction(Opcode, ArrayRef<VPValue *>(Operands), DL, Name) {}
+ VPInstruction(unsigned Opcode, CmpInst::Predicate Pred, VPValue *A,
+ VPValue *B, DebugLoc DL = {}, const Twine &Name = "");
+
VPInstruction(unsigned Opcode, std::initializer_list<VPValue *> Operands,
WrapFlagsTy WrapFlags, DebugLoc DL = {}, const Twine &Name = "")
: VPRecipeWithIRFlags(VPDef::VPInstructionSC, Operands, WrapFlags),
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 7964d603b168a3..9302b9974b183b 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -116,8 +116,8 @@ bool VPRecipeBase::mayHaveSideEffects() const {
return false;
case VPInstructionSC:
switch (cast<VPInstruction>(this)->getOpcode()) {
+ case Instruction::ICmp:
case VPInstruction::Not:
- case VPInstruction::ICmpULE:
case VPInstruction::CalculateTripCountMinusVF:
case VPInstruction::CanonicalIVIncrement:
case VPInstruction::CanonicalIVIncrementForPart:
@@ -246,6 +246,16 @@ FastMathFlags VPRecipeWithIRFlags::getFastMathFlags() const {
return Res;
}
+VPInstruction::VPInstruction(unsigned Opcode, CmpInst::Predicate Pred,
+ VPValue *A, VPValue *B, DebugLoc DL,
+ const Twine &Name)
+ : VPRecipeWithIRFlags(VPDef::VPInstructionSC, ArrayRef<VPValue *>({A, B}),
+ Pred),
+ VPValue(this), Opcode(Opcode), DL(DL), Name(Name.str()) {
+ assert(Opcode == Instruction::ICmp &&
+ "only ICmp predicates supported at the moment");
+}
+
VPInstruction::VPInstruction(unsigned Opcode,
std::initializer_list<VPValue *> Operands,
FastMathFlags FMFs, DebugLoc DL, const Twine &Name)
@@ -271,10 +281,10 @@ Value *VPInstruction::generateInstruction(VPTransformState &State,
Value *A = State.get(getOperand(0), Part);
return Builder.CreateNot(A, Name);
}
- case VPInstruction::ICmpULE: {
+ case Instruction::ICmp: {
Value *A = State.get(getOperand(0), Part);
Value *B = State.get(getOperand(1), Part);
- return Builder.CreateICmpULE(A, B, Name);
+ return Builder.CreateCmp(getPredicate(), A, B, Name);
}
case Instruction::Select: {
Value *Cond = State.get(getOperand(0), Part);
@@ -444,9 +454,6 @@ void VPInstruction::print(raw_ostream &O, const Twine &Indent,
case VPInstruction::Not:
O << "not";
break;
- case VPInstruction::ICmpULE:
- O << "icmp ule";
- break;
case VPInstruction::SLPLoad:
O << "combined load";
break;
@@ -618,6 +625,9 @@ VPRecipeWithIRFlags::FastMathFlagsTy::FastMathFlagsTy(
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
void VPRecipeWithIRFlags::printFlags(raw_ostream &O) const {
switch (OpType) {
+ case OperationType::Cmp:
+ O << " " << CmpInst::getPredicateName(getPredicate());
+ break;
case OperationType::PossiblyExactOp:
if (ExactFlags.IsExact)
O << " exact";
@@ -741,8 +751,6 @@ void VPWidenRecipe::print(raw_ostream &O, const Twine &Indent,
const Instruction *UI = getUnderlyingInstr();
O << " = " << UI->getOpcodeName();
printFlags(O);
- if (auto *Cmp = dyn_cast<CmpInst>(UI))
- O << Cmp->getPredicate() << " ";
printOperands(O, SlotTracker);
}
#endif
diff --git a/llvm/test/Transforms/LoopVectorize/vplan-sink-scalars-and-merge.ll b/llvm/test/Transforms/LoopVectorize/vplan-sink-scalars-and-merge.ll
index 4d8190319b62e2..06061465939002 100644
--- a/llvm/test/Transforms/LoopVectorize/vplan-sink-scalars-and-merge.ll
+++ b/llvm/test/Transforms/LoopVectorize/vplan-sink-scalars-and-merge.ll
@@ -906,7 +906,7 @@ define void @update_multiple_users(ptr noalias %src, ptr noalias %dst, i1 %c) {
; CHECK-NEXT: pred.store.if:
; CHECK-NEXT: REPLICATE ir<%l1> = load ir<%src>
; CHECK-NEXT: REPLICATE ir<%l2> = trunc ir<%l1>
-; CHECK-NEXT: REPLICATE ir<%cmp> = icmp ir<%l1>, ir<0>
+; CHECK-NEXT: REPLICATE ir<%cmp> = icmp eq ir<%l1>, ir<0>
; CHECK-NEXT: REPLICATE ir<%sel> = select ir<%cmp>, ir<5>, ir<%l2>
; CHECK-NEXT: REPLICATE store ir<%sel>, ir<%dst>
; CHECK-NEXT: Successor(s): pred.store.continue
More information about the llvm-commits
mailing list