[llvm] 7e878aa - [PatternMatch] Add support for capture-and-match (NFC) (#149825)
via llvm-commits
llvm-commits at lists.llvm.org
Wed Jul 23 01:05:13 PDT 2025
Author: Nikita Popov
Date: 2025-07-23T10:05:09+02:00
New Revision: 7e878aaf23dd559fa491a0bf6168f15f939c5965
URL: https://github.com/llvm/llvm-project/commit/7e878aaf23dd559fa491a0bf6168f15f939c5965
DIFF: https://github.com/llvm/llvm-project/commit/7e878aaf23dd559fa491a0bf6168f15f939c5965.diff
LOG: [PatternMatch] Add support for capture-and-match (NFC) (#149825)
When using PatternMatch, there is a common problem where we want to both
match something against a pattern and capture the matched
value/instruction for various reasons (e.g. to access flags).
Currently, the two ways to do that are to either capture using
m_Value/m_Instruction and do a separate match on the result, or to use
the somewhat awkward `m_CombineAnd(m_XYZ, m_Value(V))` pattern.
This PR adds a variant of `m_Value`/`m_Instruction` which
does both a capture and a match: `m_Value(V, m_XYZ)` is basically
equivalent to `m_CombineAnd(m_XYZ, m_Value(V))`.
I've ported two InstCombine files to this pattern as a sample.
Added:
Modified:
llvm/include/llvm/IR/PatternMatch.h
llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h
index 50e50a91389e2..27c5d5ca08cd6 100644
--- a/llvm/include/llvm/IR/PatternMatch.h
+++ b/llvm/include/llvm/IR/PatternMatch.h
@@ -822,12 +822,52 @@ template <typename Class> struct bind_ty {
}
};
+/// Check whether the value has the given Class and matches the nested
+/// pattern. Capture it into the provided variable if successful.
+template <typename Class, typename MatchTy> struct bind_and_match_ty {
+ Class *&VR;
+ MatchTy Match;
+
+ bind_and_match_ty(Class *&V, const MatchTy &Match) : VR(V), Match(Match) {}
+
+ template <typename ITy> bool match(ITy *V) const {
+ auto *CV = dyn_cast<Class>(V);
+ if (CV && Match.match(V)) {
+ VR = CV;
+ return true;
+ }
+ return false;
+ }
+};
+
/// Match a value, capturing it if we match.
inline bind_ty<Value> m_Value(Value *&V) { return V; }
inline bind_ty<const Value> m_Value(const Value *&V) { return V; }
+/// Match against the nested pattern, and capture the value if we match.
+template <typename MatchTy>
+inline bind_and_match_ty<Value, MatchTy> m_Value(Value *&V,
+ const MatchTy &Match) {
+ return {V, Match};
+}
+
+/// Match against the nested pattern, and capture the value if we match.
+template <typename MatchTy>
+inline bind_and_match_ty<const Value, MatchTy> m_Value(const Value *&V,
+ const MatchTy &Match) {
+ return {V, Match};
+}
+
/// Match an instruction, capturing it if we match.
inline bind_ty<Instruction> m_Instruction(Instruction *&I) { return I; }
+
+/// Match against the nested pattern, and capture the instruction if we match.
+template <typename MatchTy>
+inline bind_and_match_ty<Instruction, MatchTy>
+m_Instruction(Instruction *&I, const MatchTy &Match) {
+ return {I, Match};
+}
+
/// Match a unary operator, capturing it if we match.
inline bind_ty<UnaryOperator> m_UnOp(UnaryOperator *&I) { return I; }
/// Match a binary operator, capturing it if we match.
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index 7f605be976549..d934638c15e75 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -1355,9 +1355,9 @@ Instruction *InstCombinerImpl::
// right-shift of X and a "select".
Value *X, *Select;
Instruction *LowBitsToSkip, *Extract;
- if (!match(&I, m_c_BinOp(m_TruncOrSelf(m_CombineAnd(
- m_LShr(m_Value(X), m_Instruction(LowBitsToSkip)),
- m_Instruction(Extract))),
+ if (!match(&I, m_c_BinOp(m_TruncOrSelf(m_Instruction(
+ Extract, m_LShr(m_Value(X),
+ m_Instruction(LowBitsToSkip)))),
m_Value(Select))))
return nullptr;
@@ -1763,13 +1763,12 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
Constant *C;
// (add X, (sext/zext (icmp eq X, C)))
// -> (select (icmp eq X, C), (add C, (sext/zext 1)), X)
- auto CondMatcher = m_CombineAnd(
- m_Value(Cond),
- m_SpecificICmp(ICmpInst::ICMP_EQ, m_Deferred(A), m_ImmConstant(C)));
+ auto CondMatcher =
+ m_Value(Cond, m_SpecificICmp(ICmpInst::ICMP_EQ, m_Deferred(A),
+ m_ImmConstant(C)));
if (match(&I,
- m_c_Add(m_Value(A),
- m_CombineAnd(m_Value(Ext), m_ZExtOrSExt(CondMatcher)))) &&
+ m_c_Add(m_Value(A), m_Value(Ext, m_ZExtOrSExt(CondMatcher)))) &&
Ext->hasOneUse()) {
Value *Add = isa<ZExtInst>(Ext) ? InstCombiner::AddOne(C)
: InstCombiner::SubOne(C);
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index 3beda6bc5ba38..b231c04319106 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -2025,10 +2025,9 @@ static Instruction *foldComplexAndOrPatterns(BinaryOperator &I,
if (CountUses && !Op->hasOneUse())
return false;
- if (match(Op, m_c_BinOp(FlippedOpcode,
- m_CombineAnd(m_Value(X),
- m_Not(m_c_BinOp(Opcode, m_A, m_B))),
- m_C)))
+ if (match(Op,
+ m_c_BinOp(FlippedOpcode,
+ m_Value(X, m_Not(m_c_BinOp(Opcode, m_A, m_B))), m_C)))
return !CountUses || X->hasOneUse();
return false;
@@ -2079,10 +2078,10 @@ static Instruction *foldComplexAndOrPatterns(BinaryOperator &I,
// result is more undefined than a source:
// (~(A & B) | C) & ~(C & (A ^ B)) --> (A ^ B ^ C) | ~(A | C) is invalid.
if (Opcode == Instruction::Or && Op0->hasOneUse() &&
- match(Op1, m_OneUse(m_Not(m_CombineAnd(
- m_Value(Y),
- m_c_BinOp(Opcode, m_Specific(C),
- m_c_Xor(m_Specific(A), m_Specific(B)))))))) {
+ match(Op1,
+ m_OneUse(m_Not(m_Value(
+ Y, m_c_BinOp(Opcode, m_Specific(C),
+ m_c_Xor(m_Specific(A), m_Specific(B)))))))) {
// X = ~(A | B)
// Y = (C | (A ^ B)
Value *Or = cast<BinaryOperator>(X)->getOperand(0);
@@ -2098,12 +2097,11 @@ static Instruction *foldComplexAndOrPatterns(BinaryOperator &I,
if (match(Op0,
m_OneUse(m_c_BinOp(FlippedOpcode,
m_BinOp(FlippedOpcode, m_Value(B), m_Value(C)),
- m_CombineAnd(m_Value(X), m_Not(m_Value(A)))))) ||
- match(Op0, m_OneUse(m_c_BinOp(
- FlippedOpcode,
- m_c_BinOp(FlippedOpcode, m_Value(C),
- m_CombineAnd(m_Value(X), m_Not(m_Value(A)))),
- m_Value(B))))) {
+ m_Value(X, m_Not(m_Value(A)))))) ||
+ match(Op0, m_OneUse(m_c_BinOp(FlippedOpcode,
+ m_c_BinOp(FlippedOpcode, m_Value(C),
+ m_Value(X, m_Not(m_Value(A)))),
+ m_Value(B))))) {
// X = ~A
// (~A & B & C) | ~(A | B | C) --> ~(A | (B ^ C))
// (~A | B | C) & ~(A & B & C) --> (~A | (B ^ C))
@@ -2434,8 +2432,7 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) {
// (-(X & 1)) & Y --> (X & 1) == 0 ? 0 : Y
Value *Neg;
if (match(&I,
- m_c_And(m_CombineAnd(m_Value(Neg),
- m_OneUse(m_Neg(m_And(m_Value(), m_One())))),
+ m_c_And(m_Value(Neg, m_OneUse(m_Neg(m_And(m_Value(), m_One())))),
m_Value(Y)))) {
Value *Cmp = Builder.CreateIsNull(Neg);
return SelectInst::Create(Cmp, ConstantInt::getNullValue(Ty), Y);
@@ -3728,9 +3725,8 @@ static Value *foldOrUnsignedUMulOverflowICmp(BinaryOperator &I,
const APInt *C1, *C2;
if (match(&I,
m_c_Or(m_ExtractValue<1>(
- m_CombineAnd(m_Intrinsic<Intrinsic::umul_with_overflow>(
- m_Value(X), m_APInt(C1)),
- m_Value(WOV))),
+ m_Value(WOV, m_Intrinsic<Intrinsic::umul_with_overflow>(
+ m_Value(X), m_APInt(C1)))),
m_OneUse(m_SpecificCmp(ICmpInst::ICMP_UGT,
m_ExtractValue<0>(m_Deferred(WOV)),
m_APInt(C2))))) &&
@@ -3988,12 +3984,12 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
// ~(B & ?) | (A ^ B) --> ~((B & ?) & A)
Instruction *And;
if ((Op0->hasOneUse() || Op1->hasOneUse()) &&
- match(Op0, m_Not(m_CombineAnd(m_Instruction(And),
- m_c_And(m_Specific(A), m_Value())))))
+ match(Op0,
+ m_Not(m_Instruction(And, m_c_And(m_Specific(A), m_Value())))))
return BinaryOperator::CreateNot(Builder.CreateAnd(And, B));
if ((Op0->hasOneUse() || Op1->hasOneUse()) &&
- match(Op0, m_Not(m_CombineAnd(m_Instruction(And),
- m_c_And(m_Specific(B), m_Value())))))
+ match(Op0,
+ m_Not(m_Instruction(And, m_c_And(m_Specific(B), m_Value())))))
return BinaryOperator::CreateNot(Builder.CreateAnd(And, A));
// (~A | C) | (A ^ B) --> ~(A & B) | C
@@ -4125,16 +4121,13 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
// treating any non-zero result as overflow. In that case, we overflow if both
// umul.with.overflow operands are != 0, as in that case the result can only
// be 0, iff the multiplication overflows.
- if (match(&I,
- m_c_Or(m_CombineAnd(m_ExtractValue<1>(m_Value(UMulWithOv)),
- m_Value(Ov)),
- m_CombineAnd(
- m_SpecificICmp(ICmpInst::ICMP_NE,
- m_CombineAnd(m_ExtractValue<0>(
- m_Deferred(UMulWithOv)),
- m_Value(Mul)),
- m_ZeroInt()),
- m_Value(MulIsNotZero)))) &&
+ if (match(&I, m_c_Or(m_Value(Ov, m_ExtractValue<1>(m_Value(UMulWithOv))),
+ m_Value(MulIsNotZero,
+ m_SpecificICmp(
+ ICmpInst::ICMP_NE,
+ m_Value(Mul, m_ExtractValue<0>(
+ m_Deferred(UMulWithOv))),
+ m_ZeroInt())))) &&
(Ov->hasOneUse() || (MulIsNotZero->hasOneUse() && Mul->hasOneUse()))) {
Value *A, *B;
if (match(UMulWithOv, m_Intrinsic<Intrinsic::umul_with_overflow>(
@@ -4151,9 +4144,8 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
const WithOverflowInst *WO;
const Value *WOV;
const APInt *C1, *C2;
- if (match(&I, m_c_Or(m_CombineAnd(m_ExtractValue<1>(m_CombineAnd(
- m_WithOverflowInst(WO), m_Value(WOV))),
- m_Value(Ov)),
+ if (match(&I, m_c_Or(m_Value(Ov, m_ExtractValue<1>(
+ m_Value(WOV, m_WithOverflowInst(WO)))),
m_OneUse(m_ICmp(Pred, m_ExtractValue<0>(m_Deferred(WOV)),
m_APInt(C2))))) &&
(WO->getBinaryOp() == Instruction::Add ||
@@ -4501,8 +4493,7 @@ static Instruction *visitMaskedMerge(BinaryOperator &I,
Value *M;
if (!match(&I, m_c_Xor(m_Value(B),
m_OneUse(m_c_And(
- m_CombineAnd(m_c_Xor(m_Deferred(B), m_Value(X)),
- m_Value(D)),
+ m_Value(D, m_c_Xor(m_Deferred(B), m_Value(X))),
m_Value(M))))))
return nullptr;
@@ -5206,8 +5197,7 @@ Instruction *InstCombinerImpl::visitXor(BinaryOperator &I) {
// (X ^ C) ^ Y --> (X ^ Y) ^ C
// Just like we do in other places, we completely avoid the fold
// for constantexprs, at least to avoid endless combine loop.
- if (match(&I, m_c_Xor(m_OneUse(m_Xor(m_CombineAnd(m_Value(X),
- m_Unless(m_ConstantExpr())),
+ if (match(&I, m_c_Xor(m_OneUse(m_Xor(m_Value(X, m_Unless(m_ConstantExpr())),
m_ImmConstant(C1))),
m_Value(Y))))
return BinaryOperator::CreateXor(Builder.CreateXor(X, Y), C1);
More information about the llvm-commits
mailing list