[llvm] [PatternMatch][VPlan] Add std::function match overload. NFCI (PR #146374)
via llvm-commits
llvm-commits at lists.llvm.org
Mon Jun 30 08:42:23 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-vectorizers
@llvm/pr-subscribers-llvm-analysis
Author: Luke Lau (lukel97)
<details>
<summary>Changes</summary>
A relatively common use case for PatternMatch is calling match inside all_of/any_of/none_of. This patch adds a single-argument overload of match to both the LLVM IR and VPlan pattern matchers that returns a unary predicate, so callers no longer need to write a wrapping lambda themselves.
---
Full diff: https://github.com/llvm/llvm-project/pull/146374.diff
13 Files Affected:
- (modified) llvm/include/llvm/IR/PatternMatch.h (+5)
- (modified) llvm/lib/Analysis/InstructionSimplify.cpp (+6-10)
- (modified) llvm/lib/Analysis/ValueTracking.cpp (+2-3)
- (modified) llvm/lib/CodeGen/InterleavedAccessPass.cpp (+3-4)
- (modified) llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp (+2-6)
- (modified) llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp (+1-3)
- (modified) llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp (+1-1)
- (modified) llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp (+2-2)
- (modified) llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp (+1-1)
- (modified) llvm/lib/Transforms/Scalar/LICM.cpp (+3-4)
- (modified) llvm/lib/Transforms/Vectorize/LoopVectorize.cpp (+16-19)
- (modified) llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp (+3-4)
- (modified) llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h (+10)
``````````diff
diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h
index 1f86cdfd94e17..e5013c31f1f40 100644
--- a/llvm/include/llvm/IR/PatternMatch.h
+++ b/llvm/include/llvm/IR/PatternMatch.h
@@ -50,6 +50,11 @@ template <typename Val, typename Pattern> bool match(Val *V, const Pattern &P) {
return P.match(V);
}
+/// Returns a unary predicate that checks whether its argument matches \p P,
+/// for use with range algorithms, e.g. `all_of(Ops, match(m_Zero()))`.
+/// The pattern is captured by value, so the predicate remains valid even if it
+/// is stored and outlives the (usually temporary) pattern passed in; the
+/// closure is returned directly rather than wrapped in a std::function, which
+/// avoids type-erasure overhead on every call.
+template <typename Val = const Value, typename Pattern>
+auto match(const Pattern &P) {
+  return [P](Val *V) { return P.match(V); };
+}
+
template <typename Pattern> bool match(ArrayRef<int> Mask, const Pattern &P) {
return P.match(Mask);
}
diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
index cb1dae92faf92..fba9dd80f02c7 100644
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -5028,14 +5028,12 @@ static Value *simplifyGEPInst(Type *SrcTy, Value *Ptr,
}
// All-zero GEP is a no-op, unless it performs a vector splat.
- if (Ptr->getType() == GEPTy &&
- all_of(Indices, [](const auto *V) { return match(V, m_Zero()); }))
+ if (Ptr->getType() == GEPTy && all_of(Indices, match(m_Zero())))
return Ptr;
// getelementptr poison, idx -> poison
// getelementptr baseptr, poison -> poison
- if (isa<PoisonValue>(Ptr) ||
- any_of(Indices, [](const auto *V) { return isa<PoisonValue>(V); }))
+ if (isa<PoisonValue>(Ptr) || any_of(Indices, match(m_Poison())))
return PoisonValue::get(GEPTy);
// getelementptr undef, idx -> undef
@@ -5092,8 +5090,7 @@ static Value *simplifyGEPInst(Type *SrcTy, Value *Ptr,
}
if (!IsScalableVec && Q.DL.getTypeAllocSize(LastType) == 1 &&
- all_of(Indices.drop_back(1),
- [](Value *Idx) { return match(Idx, m_Zero()); })) {
+ all_of(Indices.drop_back(1), match(m_Zero()))) {
unsigned IdxWidth =
Q.DL.getIndexSizeInBits(Ptr->getType()->getPointerAddressSpace());
if (Q.DL.getTypeSizeInBits(Indices.back()->getType()) == IdxWidth) {
@@ -5123,8 +5120,7 @@ static Value *simplifyGEPInst(Type *SrcTy, Value *Ptr,
}
// Check to see if this is constant foldable.
- if (!isa<Constant>(Ptr) ||
- !all_of(Indices, [](Value *V) { return isa<Constant>(V); }))
+ if (!isa<Constant>(Ptr) || !all_of(Indices, match(m_Constant())))
return nullptr;
if (!ConstantExpr::isSupportedGetElementPtr(SrcTy))
@@ -5649,7 +5645,7 @@ static Constant *simplifyFPOp(ArrayRef<Value *> Ops, FastMathFlags FMF,
RoundingMode Rounding) {
// Poison is independent of anything else. It always propagates from an
// operand to a math result.
- if (any_of(Ops, [](Value *V) { return match(V, m_Poison()); }))
+ if (any_of(Ops, match(m_Poison())))
return PoisonValue::get(Ops[0]->getType());
for (Value *V : Ops) {
@@ -7116,7 +7112,7 @@ static Value *simplifyInstructionWithOperands(Instruction *I,
switch (I->getOpcode()) {
default:
- if (llvm::all_of(NewOps, [](Value *V) { return isa<Constant>(V); })) {
+ if (all_of(NewOps, match(m_Constant()))) {
SmallVector<Constant *, 8> NewConstOps(NewOps.size());
transform(NewOps, NewConstOps.begin(),
[](Value *V) { return cast<Constant>(V); });
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index e576f4899810a..6cc50bf7e3ee1 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -251,9 +251,8 @@ bool llvm::haveNoCommonBitsSet(const WithCache<const Value *> &LHSCache,
}
bool llvm::isOnlyUsedInZeroComparison(const Instruction *I) {
- return !I->user_empty() && all_of(I->users(), [](const User *U) {
- return match(U, m_ICmp(m_Value(), m_Zero()));
- });
+ return !I->user_empty() &&
+ all_of(I->users(), match(m_ICmp(m_Value(), m_Zero())));
}
bool llvm::isOnlyUsedInZeroEqualityComparison(const Instruction *I) {
diff --git a/llvm/lib/CodeGen/InterleavedAccessPass.cpp b/llvm/lib/CodeGen/InterleavedAccessPass.cpp
index 9c4c86cebe7e5..2d2d48b004e77 100644
--- a/llvm/lib/CodeGen/InterleavedAccessPass.cpp
+++ b/llvm/lib/CodeGen/InterleavedAccessPass.cpp
@@ -294,10 +294,9 @@ bool InterleavedAccessImpl::lowerInterleavedLoad(
continue;
}
if (auto *BI = dyn_cast<BinaryOperator>(User)) {
- if (!BI->user_empty() && all_of(BI->users(), [](auto *U) {
- auto *SVI = dyn_cast<ShuffleVectorInst>(U);
- return SVI && isa<UndefValue>(SVI->getOperand(1));
- })) {
+ using namespace PatternMatch;
+ if (!BI->user_empty() &&
+ all_of(BI->users(), match(m_Shuffle(m_Value(), m_Undef())))) {
for (auto *SVI : BI->users())
BinOpShuffles.insert(cast<ShuffleVectorInst>(SVI));
continue;
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index e721f0cd5f9e3..cb6dc4b5b0fc5 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -2307,12 +2307,8 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) {
// and let's try to sink `(sub 0, b)` into `b` itself. But only if this isn't
// a pure negation used by a select that looks like abs/nabs.
bool IsNegation = match(Op0, m_ZeroInt());
- if (!IsNegation || none_of(I.users(), [&I, Op1](const User *U) {
- const Instruction *UI = dyn_cast<Instruction>(U);
- if (!UI)
- return false;
- return match(UI, m_c_Select(m_Specific(Op1), m_Specific(&I)));
- })) {
+ if (!IsNegation ||
+ none_of(I.users(), match(m_c_Select(m_Specific(Op1), m_Specific(&I))))) {
if (Value *NegOp1 = Negator::Negate(IsNegation, /* IsNSW */ IsNegation &&
I.hasNoSignedWrap(),
Op1, *this))
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index e33d111167c04..cc6ad9bf44cd2 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -1418,9 +1418,7 @@ InstCombinerImpl::foldShuffledIntrinsicOperands(IntrinsicInst *II) {
// At least 1 operand must be a shuffle with 1 use because we are creating 2
// instructions.
- if (none_of(II->args(), [](Value *V) {
- return isa<ShuffleVectorInst>(V) && V->hasOneUse();
- }))
+ if (none_of(II->args(), match(m_OneUse(m_Shuffle(m_Value(), m_Value())))))
return nullptr;
// See if all arguments are shuffled with the same mask.
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 0894ca92086f3..27b239417de04 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -1341,7 +1341,7 @@ Instruction *InstCombinerImpl::foldICmpWithConstant(ICmpInst &Cmp) {
return nullptr;
if (auto *Phi = dyn_cast<PHINode>(Op0))
- if (all_of(Phi->operands(), [](Value *V) { return isa<Constant>(V); })) {
+ if (all_of(Phi->operands(), match(m_Constant()))) {
SmallVector<Constant *> Ops;
for (Value *V : Phi->incoming_values()) {
Constant *Res =
diff --git a/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp b/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp
index 6477141ab095f..d992e2f57a0c7 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp
@@ -339,7 +339,7 @@ bool InstCombinerImpl::foldIntegerTypedPHI(PHINode &PN) {
Instruction *InstCombinerImpl::foldPHIArgIntToPtrToPHI(PHINode &PN) {
// convert ptr2int ( phi[ int2ptr(ptr2int(x))] ) --> ptr2int ( phi [ x ] )
// Make sure all uses of phi are ptr2int.
- if (!all_of(PN.users(), [](User *U) { return isa<PtrToIntInst>(U); }))
+ if (!all_of(PN.users(), match(m_PtrToInt(m_Value()))))
return nullptr;
// Iterating over all operands to check presence of target pointers for
@@ -1298,7 +1298,7 @@ static Value *simplifyUsingControlFlow(InstCombiner &Self, PHINode &PN,
// \ /
// phi [v1] [v2]
// Make sure all inputs are constants.
- if (!all_of(PN.operands(), [](Value *V) { return isa<ConstantInt>(V); }))
+ if (!all_of(PN.operands(), match(m_ConstantInt())))
return nullptr;
BasicBlock *BB = PN.getParent();
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 73ba0f78e8053..c43a8cb53e4e9 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -3142,7 +3142,7 @@ static Instruction *foldNestedSelects(SelectInst &OuterSelVal,
// Profitability check - avoid increasing instruction count.
if (none_of(ArrayRef<Value *>({OuterSelVal.getCondition(), InnerSelVal}),
- [](Value *V) { return V->hasOneUse(); }))
+ match(m_OneUse(m_Value()))))
return nullptr;
// The appropriate hand of the outermost `select` must be a select itself.
diff --git a/llvm/lib/Transforms/Scalar/LICM.cpp b/llvm/lib/Transforms/Scalar/LICM.cpp
index cf84366c4200b..c2edc3a33edc1 100644
--- a/llvm/lib/Transforms/Scalar/LICM.cpp
+++ b/llvm/lib/Transforms/Scalar/LICM.cpp
@@ -435,10 +435,9 @@ bool LoopInvariantCodeMotion::runOnLoop(Loop *L, AAResults *AA, LoopInfo *LI,
// potentially happen in other passes where instructions are being moved
// across that edge.
bool HasCoroSuspendInst = llvm::any_of(L->getBlocks(), [](BasicBlock *BB) {
- return llvm::any_of(*BB, [](Instruction &I) {
- IntrinsicInst *II = dyn_cast<IntrinsicInst>(&I);
- return II && II->getIntrinsicID() == Intrinsic::coro_suspend;
- });
+ using namespace PatternMatch;
+ return any_of(make_pointer_range(*BB),
+ match(m_Intrinsic<Intrinsic::coro_suspend>()));
});
MemorySSAUpdater MSSAU(MSSA);
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 95479373b4393..3fe9c46ac7656 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -7036,11 +7036,12 @@ static bool planContainsAdditionalSimplifications(VPlan &Plan,
// Unused FOR splices are removed by VPlan transforms, so the VPlan-based
// cost model won't cost it whilst the legacy will.
if (auto *FOR = dyn_cast<VPFirstOrderRecurrencePHIRecipe>(&R)) {
- if (none_of(FOR->users(), [](VPUser *U) {
- auto *VPI = dyn_cast<VPInstruction>(U);
- return VPI && VPI->getOpcode() ==
- VPInstruction::FirstOrderRecurrenceSplice;
- }))
+ using namespace VPlanPatternMatch;
+ if (none_of(
+ FOR->users(),
+ match(
+ m_VPInstruction<VPInstruction::FirstOrderRecurrenceSplice>(
+ m_VPValue(), m_VPValue()))))
return true;
}
// The VPlan-based cost model is more accurate for partial reduction and
@@ -7449,13 +7450,11 @@ DenseMap<const SCEV *, Value *> LoopVectorizationPlanner::executePlan(
Hints.setAlreadyVectorized();
// Check if it's EVL-vectorized and mark the corresponding metadata.
+ using namespace VPlanPatternMatch;
bool IsEVLVectorized =
- llvm::any_of(*HeaderVPBB, [](const VPRecipeBase &Recipe) {
- // Looking for the ExplictVectorLength VPInstruction.
- if (const auto *VI = dyn_cast<VPInstruction>(&Recipe))
- return VI->getOpcode() == VPInstruction::ExplicitVectorLength;
- return false;
- });
+ any_of(make_pointer_range(*HeaderVPBB),
+ match(m_VPInstruction<VPInstruction::ExplicitVectorLength>(
+ m_VPValue())));
if (IsEVLVectorized) {
LLVMContext &Context = L->getHeader()->getContext();
MDNode *LoopID = L->getLoopID();
@@ -9737,10 +9736,9 @@ static void preparePlanForMainVectorLoop(VPlan &MainPlan, VPlan &EpiPlan) {
// If there is a suitable resume value for the canonical induction in the
// scalar (which will become vector) epilogue loop we are done. Otherwise
// create it below.
- if (any_of(*MainScalarPH, [VectorTC](VPRecipeBase &R) {
- return match(&R, m_VPInstruction<Instruction::PHI>(m_Specific(VectorTC),
- m_SpecificInt(0)));
- }))
+ if (any_of(make_pointer_range(*MainScalarPH),
+ match(m_VPInstruction<Instruction::PHI>(m_Specific(VectorTC),
+ m_SpecificInt(0)))))
return;
VPBuilder ScalarPHBuilder(MainScalarPH, MainScalarPH->begin());
ScalarPHBuilder.createScalarPhi(
@@ -9778,10 +9776,9 @@ preparePlanForEpilogueVectorLoop(VPlan &Plan, Loop *L,
match(
P.getIncomingValueForBlock(EPI.MainLoopIterationCountCheck),
m_SpecificInt(0)) &&
- all_of(P.incoming_values(), [&EPI](Value *Inc) {
- return Inc == EPI.VectorTripCount ||
- match(Inc, m_SpecificInt(0));
- }))
+ all_of(P.incoming_values(),
+ match(m_CombineOr(m_Specific(EPI.VectorTripCount),
+ m_SpecificInt(0)))))
return &P;
return nullptr;
});
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 0941bf61953f1..79bf939cd591e 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -20708,10 +20708,9 @@ void BoUpSLP::computeMinimumValueSizes() {
IsTruncRoot = true;
}
bool IsSignedCmp = false;
- if (UserIgnoreList && all_of(*UserIgnoreList, [](Value *V) {
- return match(V, m_SMin(m_Value(), m_Value())) ||
- match(V, m_SMax(m_Value(), m_Value()));
- }))
+ if (UserIgnoreList &&
+ all_of(*UserIgnoreList, match(m_CombineOr(m_SMin(m_Value(), m_Value()),
+ m_SMax(m_Value(), m_Value())))))
IsSignedCmp = true;
while (NodeIdx < VectorizableTree.size()) {
ArrayRef<Value *> TreeRoot = VectorizableTree[NodeIdx]->Scalars;
diff --git a/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h b/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h
index efea99f22d086..4aba5fb010559 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h
@@ -29,11 +29,21 @@ template <typename Val, typename Pattern> bool match(Val *V, const Pattern &P) {
return P.match(V);
}
+/// Returns a unary predicate over `Val *` that checks whether its argument
+/// matches \p P. \p Val cannot be deduced and must be given explicitly, e.g.
+/// `match<VPValue>(Pat)`. The pattern is captured by value so the predicate
+/// stays valid if it outlives the (usually temporary) pattern passed in.
+template <typename Val, typename Pattern>
+auto match(const Pattern &P) {
+  return [P](Val *V) { return P.match(V); };
+}
+
template <typename Pattern> bool match(VPUser *U, const Pattern &P) {
auto *R = dyn_cast<VPRecipeBase>(U);
return R && match(R, P);
}
+/// Returns a unary predicate over VPUser pointers, e.g. for
+/// `none_of(V->users(), match(Pat))`. A user matches only if it is a
+/// VPRecipeBase that matches \p P (see match(VPUser *, const Pattern &)).
+/// The pattern is captured by value so the predicate stays valid if it
+/// outlives the (usually temporary) pattern passed in.
+template <typename Pattern>
+auto match(const Pattern &P) {
+  return [P](VPUser *U) { return match(U, P); };
+}
+
template <typename Class> struct class_match {
template <typename ITy> bool match(ITy *V) const { return isa<Class>(V); }
};
``````````
</details>
https://github.com/llvm/llvm-project/pull/146374
More information about the llvm-commits
mailing list