[llvm] 7e878aa - [PatternMatch] Add support for capture-and-match (NFC) (#149825)

via llvm-commits llvm-commits at lists.llvm.org
Wed Jul 23 01:05:13 PDT 2025


Author: Nikita Popov
Date: 2025-07-23T10:05:09+02:00
New Revision: 7e878aaf23dd559fa491a0bf6168f15f939c5965

URL: https://github.com/llvm/llvm-project/commit/7e878aaf23dd559fa491a0bf6168f15f939c5965
DIFF: https://github.com/llvm/llvm-project/commit/7e878aaf23dd559fa491a0bf6168f15f939c5965.diff

LOG: [PatternMatch] Add support for capture-and-match (NFC) (#149825)

When using PatternMatch, a common problem is that we want to both match
something against a pattern and capture the value/instruction for
various reasons (e.g. to access flags).

Currently the two ways to do that are to either capture using
m_Value/m_Instruction and do a separate match on the result, or to use
the somewhat awkward `m_CombineAnd(m_XYZ, m_Value(V))` pattern.

This PR adds a variant of `m_Value`/`m_Instruction` which
does both a capture and a match. `m_Value(V, m_XYZ)` is basically
equivalent to `m_CombineAnd(m_XYZ, m_Value(V))`.

I've ported two InstCombine files to this pattern as a sample.

Added: 
    

Modified: 
    llvm/include/llvm/IR/PatternMatch.h
    llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
    llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h
index 50e50a91389e2..27c5d5ca08cd6 100644
--- a/llvm/include/llvm/IR/PatternMatch.h
+++ b/llvm/include/llvm/IR/PatternMatch.h
@@ -822,12 +822,52 @@ template <typename Class> struct bind_ty {
   }
 };
 
+/// Check whether the value has the given Class and matches the nested
+/// pattern. Capture it into the provided variable if successful.
+template <typename Class, typename MatchTy> struct bind_and_match_ty {
+  Class *&VR;
+  MatchTy Match;
+
+  bind_and_match_ty(Class *&V, const MatchTy &Match) : VR(V), Match(Match) {}
+
+  template <typename ITy> bool match(ITy *V) const {
+    auto *CV = dyn_cast<Class>(V);
+    if (CV && Match.match(V)) {
+      VR = CV;
+      return true;
+    }
+    return false;
+  }
+};
+
 /// Match a value, capturing it if we match.
 inline bind_ty<Value> m_Value(Value *&V) { return V; }
 inline bind_ty<const Value> m_Value(const Value *&V) { return V; }
 
+/// Match against the nested pattern, and capture the value if we match.
+template <typename MatchTy>
+inline bind_and_match_ty<Value, MatchTy> m_Value(Value *&V,
+                                                 const MatchTy &Match) {
+  return {V, Match};
+}
+
+/// Match against the nested pattern, and capture the value if we match.
+template <typename MatchTy>
+inline bind_and_match_ty<const Value, MatchTy> m_Value(const Value *&V,
+                                                       const MatchTy &Match) {
+  return {V, Match};
+}
+
 /// Match an instruction, capturing it if we match.
 inline bind_ty<Instruction> m_Instruction(Instruction *&I) { return I; }
+
+/// Match against the nested pattern, and capture the instruction if we match.
+template <typename MatchTy>
+inline bind_and_match_ty<Instruction, MatchTy>
+m_Instruction(Instruction *&I, const MatchTy &Match) {
+  return {I, Match};
+}
+
 /// Match a unary operator, capturing it if we match.
 inline bind_ty<UnaryOperator> m_UnOp(UnaryOperator *&I) { return I; }
 /// Match a binary operator, capturing it if we match.

diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index 7f605be976549..d934638c15e75 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -1355,9 +1355,9 @@ Instruction *InstCombinerImpl::
   // right-shift of X and a "select".
   Value *X, *Select;
   Instruction *LowBitsToSkip, *Extract;
-  if (!match(&I, m_c_BinOp(m_TruncOrSelf(m_CombineAnd(
-                               m_LShr(m_Value(X), m_Instruction(LowBitsToSkip)),
-                               m_Instruction(Extract))),
+  if (!match(&I, m_c_BinOp(m_TruncOrSelf(m_Instruction(
+                               Extract, m_LShr(m_Value(X),
+                                               m_Instruction(LowBitsToSkip)))),
                            m_Value(Select))))
     return nullptr;
 
@@ -1763,13 +1763,12 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
     Constant *C;
     // (add X, (sext/zext (icmp eq X, C)))
     //    -> (select (icmp eq X, C), (add C, (sext/zext 1)), X)
-    auto CondMatcher = m_CombineAnd(
-        m_Value(Cond),
-        m_SpecificICmp(ICmpInst::ICMP_EQ, m_Deferred(A), m_ImmConstant(C)));
+    auto CondMatcher =
+        m_Value(Cond, m_SpecificICmp(ICmpInst::ICMP_EQ, m_Deferred(A),
+                                     m_ImmConstant(C)));
 
     if (match(&I,
-              m_c_Add(m_Value(A),
-                      m_CombineAnd(m_Value(Ext), m_ZExtOrSExt(CondMatcher)))) &&
+              m_c_Add(m_Value(A), m_Value(Ext, m_ZExtOrSExt(CondMatcher)))) &&
         Ext->hasOneUse()) {
       Value *Add = isa<ZExtInst>(Ext) ? InstCombiner::AddOne(C)
                                       : InstCombiner::SubOne(C);

diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index 3beda6bc5ba38..b231c04319106 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -2025,10 +2025,9 @@ static Instruction *foldComplexAndOrPatterns(BinaryOperator &I,
     if (CountUses && !Op->hasOneUse())
       return false;
 
-    if (match(Op, m_c_BinOp(FlippedOpcode,
-                            m_CombineAnd(m_Value(X),
-                                         m_Not(m_c_BinOp(Opcode, m_A, m_B))),
-                            m_C)))
+    if (match(Op,
+              m_c_BinOp(FlippedOpcode,
+                        m_Value(X, m_Not(m_c_BinOp(Opcode, m_A, m_B))), m_C)))
       return !CountUses || X->hasOneUse();
 
     return false;
@@ -2079,10 +2078,10 @@ static Instruction *foldComplexAndOrPatterns(BinaryOperator &I,
     // result is more undefined than a source:
     // (~(A & B) | C) & ~(C & (A ^ B)) --> (A ^ B ^ C) | ~(A | C) is invalid.
     if (Opcode == Instruction::Or && Op0->hasOneUse() &&
-        match(Op1, m_OneUse(m_Not(m_CombineAnd(
-                       m_Value(Y),
-                       m_c_BinOp(Opcode, m_Specific(C),
-                                 m_c_Xor(m_Specific(A), m_Specific(B)))))))) {
+        match(Op1,
+              m_OneUse(m_Not(m_Value(
+                  Y, m_c_BinOp(Opcode, m_Specific(C),
+                               m_c_Xor(m_Specific(A), m_Specific(B)))))))) {
       // X = ~(A | B)
       // Y = (C | (A ^ B)
       Value *Or = cast<BinaryOperator>(X)->getOperand(0);
@@ -2098,12 +2097,11 @@ static Instruction *foldComplexAndOrPatterns(BinaryOperator &I,
   if (match(Op0,
             m_OneUse(m_c_BinOp(FlippedOpcode,
                                m_BinOp(FlippedOpcode, m_Value(B), m_Value(C)),
-                               m_CombineAnd(m_Value(X), m_Not(m_Value(A)))))) ||
-      match(Op0, m_OneUse(m_c_BinOp(
-                     FlippedOpcode,
-                     m_c_BinOp(FlippedOpcode, m_Value(C),
-                               m_CombineAnd(m_Value(X), m_Not(m_Value(A)))),
-                     m_Value(B))))) {
+                               m_Value(X, m_Not(m_Value(A)))))) ||
+      match(Op0, m_OneUse(m_c_BinOp(FlippedOpcode,
+                                    m_c_BinOp(FlippedOpcode, m_Value(C),
+                                              m_Value(X, m_Not(m_Value(A)))),
+                                    m_Value(B))))) {
     // X = ~A
     // (~A & B & C) | ~(A | B | C) --> ~(A | (B ^ C))
     // (~A | B | C) & ~(A & B & C) --> (~A | (B ^ C))
@@ -2434,8 +2432,7 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) {
   // (-(X & 1)) & Y --> (X & 1) == 0 ? 0 : Y
   Value *Neg;
   if (match(&I,
-            m_c_And(m_CombineAnd(m_Value(Neg),
-                                 m_OneUse(m_Neg(m_And(m_Value(), m_One())))),
+            m_c_And(m_Value(Neg, m_OneUse(m_Neg(m_And(m_Value(), m_One())))),
                     m_Value(Y)))) {
     Value *Cmp = Builder.CreateIsNull(Neg);
     return SelectInst::Create(Cmp, ConstantInt::getNullValue(Ty), Y);
@@ -3728,9 +3725,8 @@ static Value *foldOrUnsignedUMulOverflowICmp(BinaryOperator &I,
   const APInt *C1, *C2;
   if (match(&I,
             m_c_Or(m_ExtractValue<1>(
-                       m_CombineAnd(m_Intrinsic<Intrinsic::umul_with_overflow>(
-                                        m_Value(X), m_APInt(C1)),
-                                    m_Value(WOV))),
+                       m_Value(WOV, m_Intrinsic<Intrinsic::umul_with_overflow>(
+                                        m_Value(X), m_APInt(C1)))),
                    m_OneUse(m_SpecificCmp(ICmpInst::ICMP_UGT,
                                           m_ExtractValue<0>(m_Deferred(WOV)),
                                           m_APInt(C2))))) &&
@@ -3988,12 +3984,12 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
     // ~(B & ?) | (A ^ B) --> ~((B & ?) & A)
     Instruction *And;
     if ((Op0->hasOneUse() || Op1->hasOneUse()) &&
-        match(Op0, m_Not(m_CombineAnd(m_Instruction(And),
-                                      m_c_And(m_Specific(A), m_Value())))))
+        match(Op0,
+              m_Not(m_Instruction(And, m_c_And(m_Specific(A), m_Value())))))
       return BinaryOperator::CreateNot(Builder.CreateAnd(And, B));
     if ((Op0->hasOneUse() || Op1->hasOneUse()) &&
-        match(Op0, m_Not(m_CombineAnd(m_Instruction(And),
-                                      m_c_And(m_Specific(B), m_Value())))))
+        match(Op0,
+              m_Not(m_Instruction(And, m_c_And(m_Specific(B), m_Value())))))
       return BinaryOperator::CreateNot(Builder.CreateAnd(And, A));
 
     // (~A | C) | (A ^ B) --> ~(A & B) | C
@@ -4125,16 +4121,13 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
   // treating any non-zero result as overflow. In that case, we overflow if both
   // umul.with.overflow operands are != 0, as in that case the result can only
   // be 0, iff the multiplication overflows.
-  if (match(&I,
-            m_c_Or(m_CombineAnd(m_ExtractValue<1>(m_Value(UMulWithOv)),
-                                m_Value(Ov)),
-                   m_CombineAnd(
-                       m_SpecificICmp(ICmpInst::ICMP_NE,
-                                      m_CombineAnd(m_ExtractValue<0>(
-                                                       m_Deferred(UMulWithOv)),
-                                                   m_Value(Mul)),
-                                      m_ZeroInt()),
-                       m_Value(MulIsNotZero)))) &&
+  if (match(&I, m_c_Or(m_Value(Ov, m_ExtractValue<1>(m_Value(UMulWithOv))),
+                       m_Value(MulIsNotZero,
+                               m_SpecificICmp(
+                                   ICmpInst::ICMP_NE,
+                                   m_Value(Mul, m_ExtractValue<0>(
+                                                    m_Deferred(UMulWithOv))),
+                                   m_ZeroInt())))) &&
       (Ov->hasOneUse() || (MulIsNotZero->hasOneUse() && Mul->hasOneUse()))) {
     Value *A, *B;
     if (match(UMulWithOv, m_Intrinsic<Intrinsic::umul_with_overflow>(
@@ -4151,9 +4144,8 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
   const WithOverflowInst *WO;
   const Value *WOV;
   const APInt *C1, *C2;
-  if (match(&I, m_c_Or(m_CombineAnd(m_ExtractValue<1>(m_CombineAnd(
-                                        m_WithOverflowInst(WO), m_Value(WOV))),
-                                    m_Value(Ov)),
+  if (match(&I, m_c_Or(m_Value(Ov, m_ExtractValue<1>(
+                                       m_Value(WOV, m_WithOverflowInst(WO)))),
                        m_OneUse(m_ICmp(Pred, m_ExtractValue<0>(m_Deferred(WOV)),
                                        m_APInt(C2))))) &&
       (WO->getBinaryOp() == Instruction::Add ||
@@ -4501,8 +4493,7 @@ static Instruction *visitMaskedMerge(BinaryOperator &I,
   Value *M;
   if (!match(&I, m_c_Xor(m_Value(B),
                          m_OneUse(m_c_And(
-                             m_CombineAnd(m_c_Xor(m_Deferred(B), m_Value(X)),
-                                          m_Value(D)),
+                             m_Value(D, m_c_Xor(m_Deferred(B), m_Value(X))),
                              m_Value(M))))))
     return nullptr;
 
@@ -5206,8 +5197,7 @@ Instruction *InstCombinerImpl::visitXor(BinaryOperator &I) {
   //   (X ^ C) ^ Y --> (X ^ Y) ^ C
   // Just like we do in other places, we completely avoid the fold
   // for constantexprs, at least to avoid endless combine loop.
-  if (match(&I, m_c_Xor(m_OneUse(m_Xor(m_CombineAnd(m_Value(X),
-                                                    m_Unless(m_ConstantExpr())),
+  if (match(&I, m_c_Xor(m_OneUse(m_Xor(m_Value(X, m_Unless(m_ConstantExpr())),
                                        m_ImmConstant(C1))),
                         m_Value(Y))))
     return BinaryOperator::CreateXor(Builder.CreateXor(X, Y), C1);


        


More information about the llvm-commits mailing list