[llvm] 18abc7e - [PatternMatch] Introduce m_c_Select (#114328)

via llvm-commits llvm-commits at lists.llvm.org
Mon Nov 25 05:47:27 PST 2024


Author: David Green
Date: 2024-11-25T13:47:23Z
New Revision: 18abc7e0c5b34e9e7bbe0893a4a5281c0937f7d8

URL: https://github.com/llvm/llvm-project/commit/18abc7e0c5b34e9e7bbe0893a4a5281c0937f7d8
DIFF: https://github.com/llvm/llvm-project/commit/18abc7e0c5b34e9e7bbe0893a4a5281c0937f7d8.diff

LOG: [PatternMatch] Introduce m_c_Select (#114328)

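The new m_c_Select(L, R) matcher matches either m_Select(m_Value(), L, R) or m_Select(m_Value(), R, L), i.e. a select whose true/false operands may appear in either order; the condition operand is matched by m_Value() and is not constrained.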

Added: 
    

Modified: 
    llvm/include/llvm/IR/PatternMatch.h
    llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
    llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
    llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
    llvm/lib/Transforms/Utils/SimplifyCFG.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h
index 0d6df727906324..fc4c0124d00b84 100644
--- a/llvm/include/llvm/IR/PatternMatch.h
+++ b/llvm/include/llvm/IR/PatternMatch.h
@@ -1716,7 +1716,8 @@ template <typename T0, typename T1, unsigned Opcode> struct TwoOps_match {
 };
 
 /// Matches instructions with Opcode and three operands.
-template <typename T0, typename T1, typename T2, unsigned Opcode>
+template <typename T0, typename T1, typename T2, unsigned Opcode,
+          bool CommutableOp2Op3 = false>
 struct ThreeOps_match {
   T0 Op1;
   T1 Op2;
@@ -1728,8 +1729,12 @@ struct ThreeOps_match {
   template <typename OpTy> bool match(OpTy *V) {
     if (V->getValueID() == Value::InstructionVal + Opcode) {
       auto *I = cast<Instruction>(V);
-      return Op1.match(I->getOperand(0)) && Op2.match(I->getOperand(1)) &&
-             Op3.match(I->getOperand(2));
+      if (!Op1.match(I->getOperand(0)))
+        return false;
+      if (Op2.match(I->getOperand(1)) && Op3.match(I->getOperand(2)))
+        return true;
+      return CommutableOp2Op3 && Op2.match(I->getOperand(2)) &&
+             Op3.match(I->getOperand(1));
     }
     return false;
   }
@@ -1781,6 +1786,14 @@ m_SelectCst(const Cond &C) {
   return m_Select(C, m_ConstantInt<L>(), m_ConstantInt<R>());
 }
 
+/// Match Select(C, LHS, RHS) or Select(C, RHS, LHS)
+template <typename LHS, typename RHS>
+inline ThreeOps_match<decltype(m_Value()), LHS, RHS, Instruction::Select, true>
+m_c_Select(const LHS &L, const RHS &R) {
+  return ThreeOps_match<decltype(m_Value()), LHS, RHS, Instruction::Select,
+                        true>(m_Value(), L, R);
+}
+
 /// Matches FreezeInst.
 template <typename OpTy>
 inline OneOps_match<OpTy, Instruction::Freeze> m_Freeze(const OpTy &Op) {

diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index 46ce011c5f7880..6fe96935818531 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -2245,9 +2245,7 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) {
         const Instruction *UI = dyn_cast<Instruction>(U);
         if (!UI)
           return false;
-        return match(UI,
-                     m_Select(m_Value(), m_Specific(Op1), m_Specific(&I))) ||
-               match(UI, m_Select(m_Value(), m_Specific(&I), m_Specific(Op1)));
+        return match(UI, m_c_Select(m_Specific(Op1), m_Specific(&I)));
       })) {
     if (Value *NegOp1 = Negator::Negate(IsNegation, /* IsNSW */ IsNegation &&
                                                         I.hasNoSignedWrap(),

diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 42c0acd1e45ec1..fd38738e3be80b 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -1736,9 +1736,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
     Value *X;
     if (match(IIOperand, m_Neg(m_Value(X))))
       return replaceOperand(*II, 0, X);
-    if (match(IIOperand, m_Select(m_Value(), m_Value(X), m_Neg(m_Deferred(X)))))
-      return replaceOperand(*II, 0, X);
-    if (match(IIOperand, m_Select(m_Value(), m_Neg(m_Value(X)), m_Deferred(X))))
+    if (match(IIOperand, m_c_Select(m_Neg(m_Value(X)), m_Deferred(X))))
       return replaceOperand(*II, 0, X);
 
     Value *Y;

diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 857724470f2223..fed21db393ed22 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -8460,9 +8460,7 @@ Instruction *InstCombinerImpl::visitFCmpInst(FCmpInst &I) {
     case Instruction::Select:
       // fcmp eq (cond ? x : -x), 0 --> fcmp eq x, 0
       if (FCmpInst::isEquality(Pred) && match(RHSC, m_AnyZeroFP()) &&
-          (match(LHSI,
-                 m_Select(m_Value(), m_Value(X), m_FNeg(m_Deferred(X)))) ||
-           match(LHSI, m_Select(m_Value(), m_FNeg(m_Value(X)), m_Deferred(X)))))
+          match(LHSI, m_c_Select(m_FNeg(m_Value(X)), m_Deferred(X))))
         return replaceOperand(I, 0, X);
       if (Instruction *NV = FoldOpIntoSelect(I, cast<SelectInst>(LHSI)))
         return NV;

diff  --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
index b664bde5d320a1..c7e814bced57db 100644
--- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -3842,10 +3842,7 @@ static bool foldTwoEntryPHINode(PHINode *PN, const TargetTransformInfo &TTI,
   // These can often be turned into switches and other things.
   auto IsBinOpOrAnd = [](Value *V) {
     return match(
-        V, m_CombineOr(
-               m_BinOp(),
-               m_CombineOr(m_Select(m_Value(), m_ImmConstant(), m_Value()),
-                           m_Select(m_Value(), m_Value(), m_ImmConstant()))));
+        V, m_CombineOr(m_BinOp(), m_c_Select(m_ImmConstant(), m_Value())));
   };
   if (PN->getType()->isIntegerTy(1) &&
       (IsBinOpOrAnd(PN->getIncomingValue(0)) ||


        


More information about the llvm-commits mailing list