[llvm] [PatternMatch] Introduce m_c_Select (PR #114328)

David Green via llvm-commits llvm-commits at lists.llvm.org
Wed Nov 20 05:44:52 PST 2024


https://github.com/davemgreen updated https://github.com/llvm/llvm-project/pull/114328

>From 1b1e5e2be209a9c6275aef3fa6b5944e72761cf4 Mon Sep 17 00:00:00 2001
From: David Green <david.green at arm.com>
Date: Wed, 30 Oct 2024 23:27:58 +0000
Subject: [PATCH 1/2] [PatternMatch] Introduce m_c_Select

This matches m_Select(C, L, R) or m_Select(C, R, L), i.e. a select whose true and false operands may appear in either order.
---
 llvm/include/llvm/IR/PatternMatch.h            | 18 +++++++++++++++---
 .../InstCombine/InstCombineAddSub.cpp          |  3 +--
 .../InstCombine/InstCombineCalls.cpp           |  5 ++---
 .../InstCombine/InstCombineCompares.cpp        |  4 +---
 llvm/lib/Transforms/Utils/SimplifyCFG.cpp      |  6 ++----
 5 files changed, 21 insertions(+), 15 deletions(-)

diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h
index 0d6df727906324..46bcce3ffc160e 100644
--- a/llvm/include/llvm/IR/PatternMatch.h
+++ b/llvm/include/llvm/IR/PatternMatch.h
@@ -1716,7 +1716,8 @@ template <typename T0, typename T1, unsigned Opcode> struct TwoOps_match {
 };
 
 /// Matches instructions with Opcode and three operands.
-template <typename T0, typename T1, typename T2, unsigned Opcode>
+template <typename T0, typename T1, typename T2, unsigned Opcode,
+          bool CommutableOp2Op3 = false>
 struct ThreeOps_match {
   T0 Op1;
   T1 Op2;
@@ -1728,8 +1729,12 @@ struct ThreeOps_match {
   template <typename OpTy> bool match(OpTy *V) {
     if (V->getValueID() == Value::InstructionVal + Opcode) {
       auto *I = cast<Instruction>(V);
-      return Op1.match(I->getOperand(0)) && Op2.match(I->getOperand(1)) &&
-             Op3.match(I->getOperand(2));
+      if (!Op1.match(I->getOperand(0)))
+        return false;
+      if (Op2.match(I->getOperand(1)) && Op3.match(I->getOperand(2)))
+        return true;
+      return CommutableOp2Op3 && Op2.match(I->getOperand(2)) &&
+             Op3.match(I->getOperand(1));
     }
     return false;
   }
@@ -1781,6 +1786,13 @@ m_SelectCst(const Cond &C) {
   return m_Select(C, m_ConstantInt<L>(), m_ConstantInt<R>());
 }
 
+/// Match Select(C, LHS, RHS) or Select(C, RHS, LHS)
+template <typename Cond, typename LHS, typename RHS>
+inline ThreeOps_match<Cond, LHS, RHS, Instruction::Select, true>
+m_c_Select(const Cond &C, const LHS &L, const RHS &R) {
+  return ThreeOps_match<Cond, LHS, RHS, Instruction::Select, true>(C, L, R);
+}
+
 /// Matches FreezeInst.
 template <typename OpTy>
 inline OneOps_match<OpTy, Instruction::Freeze> m_Freeze(const OpTy &Op) {
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index 46ce011c5f7880..ce802a99e24285 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -2246,8 +2246,7 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) {
         if (!UI)
           return false;
         return match(UI,
-                     m_Select(m_Value(), m_Specific(Op1), m_Specific(&I))) ||
-               match(UI, m_Select(m_Value(), m_Specific(&I), m_Specific(Op1)));
+                     m_c_Select(m_Value(), m_Specific(Op1), m_Specific(&I)));
       })) {
     if (Value *NegOp1 = Negator::Negate(IsNegation, /* IsNSW */ IsNegation &&
                                                         I.hasNoSignedWrap(),
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 42c0acd1e45ec1..00dcc08ce59d9f 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -1736,9 +1736,8 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
     Value *X;
     if (match(IIOperand, m_Neg(m_Value(X))))
       return replaceOperand(*II, 0, X);
-    if (match(IIOperand, m_Select(m_Value(), m_Value(X), m_Neg(m_Deferred(X)))))
-      return replaceOperand(*II, 0, X);
-    if (match(IIOperand, m_Select(m_Value(), m_Neg(m_Value(X)), m_Deferred(X))))
+    if (match(IIOperand,
+              m_c_Select(m_Value(), m_Neg(m_Value(X)), m_Deferred(X))))
       return replaceOperand(*II, 0, X);
 
     Value *Y;
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index d602a907e72bcd..35596c286f8cc7 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -8437,9 +8437,7 @@ Instruction *InstCombinerImpl::visitFCmpInst(FCmpInst &I) {
     case Instruction::Select:
       // fcmp eq (cond ? x : -x), 0 --> fcmp eq x, 0
       if (FCmpInst::isEquality(Pred) && match(RHSC, m_AnyZeroFP()) &&
-          (match(LHSI,
-                 m_Select(m_Value(), m_Value(X), m_FNeg(m_Deferred(X)))) ||
-           match(LHSI, m_Select(m_Value(), m_FNeg(m_Value(X)), m_Deferred(X)))))
+          match(LHSI, m_c_Select(m_Value(), m_FNeg(m_Value(X)), m_Deferred(X))))
         return replaceOperand(I, 0, X);
       if (Instruction *NV = FoldOpIntoSelect(I, cast<SelectInst>(LHSI)))
         return NV;
diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
index 1991ec82d1e1e4..ac106cb32387eb 100644
--- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -3814,10 +3814,8 @@ static bool foldTwoEntryPHINode(PHINode *PN, const TargetTransformInfo &TTI,
   // These can often be turned into switches and other things.
   auto IsBinOpOrAnd = [](Value *V) {
     return match(
-        V, m_CombineOr(
-               m_BinOp(),
-               m_CombineOr(m_Select(m_Value(), m_ImmConstant(), m_Value()),
-                           m_Select(m_Value(), m_Value(), m_ImmConstant()))));
+        V, m_CombineOr(m_BinOp(),
+                       m_c_Select(m_Value(), m_ImmConstant(), m_Value())));
   };
   if (PN->getType()->isIntegerTy(1) &&
       (IsBinOpOrAnd(PN->getIncomingValue(0)) ||

>From fdff967ae849e30d594ba0d121629f3f64ce957b Mon Sep 17 00:00:00 2001
From: David Green <david.green at arm.com>
Date: Wed, 20 Nov 2024 13:44:35 +0000
Subject: [PATCH 2/2] Remove the condition parameter from m_c_Select

---
 llvm/include/llvm/IR/PatternMatch.h                     | 9 +++++----
 llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp   | 3 +--
 llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp    | 3 +--
 llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp | 2 +-
 llvm/lib/Transforms/Utils/SimplifyCFG.cpp               | 3 +--
 5 files changed, 9 insertions(+), 11 deletions(-)

diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h
index 46bcce3ffc160e..fc4c0124d00b84 100644
--- a/llvm/include/llvm/IR/PatternMatch.h
+++ b/llvm/include/llvm/IR/PatternMatch.h
@@ -1787,10 +1787,11 @@ m_SelectCst(const Cond &C) {
 }
 
 /// Match Select(C, LHS, RHS) or Select(C, RHS, LHS)
-template <typename Cond, typename LHS, typename RHS>
-inline ThreeOps_match<Cond, LHS, RHS, Instruction::Select, true>
-m_c_Select(const Cond &C, const LHS &L, const RHS &R) {
-  return ThreeOps_match<Cond, LHS, RHS, Instruction::Select, true>(C, L, R);
+template <typename LHS, typename RHS>
+inline ThreeOps_match<decltype(m_Value()), LHS, RHS, Instruction::Select, true>
+m_c_Select(const LHS &L, const RHS &R) {
+  return ThreeOps_match<decltype(m_Value()), LHS, RHS, Instruction::Select,
+                        true>(m_Value(), L, R);
 }
 
 /// Matches FreezeInst.
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index ce802a99e24285..6fe96935818531 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -2245,8 +2245,7 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) {
         const Instruction *UI = dyn_cast<Instruction>(U);
         if (!UI)
           return false;
-        return match(UI,
-                     m_c_Select(m_Value(), m_Specific(Op1), m_Specific(&I)));
+        return match(UI, m_c_Select(m_Specific(Op1), m_Specific(&I)));
       })) {
     if (Value *NegOp1 = Negator::Negate(IsNegation, /* IsNSW */ IsNegation &&
                                                         I.hasNoSignedWrap(),
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 00dcc08ce59d9f..fd38738e3be80b 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -1736,8 +1736,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
     Value *X;
     if (match(IIOperand, m_Neg(m_Value(X))))
       return replaceOperand(*II, 0, X);
-    if (match(IIOperand,
-              m_c_Select(m_Value(), m_Neg(m_Value(X)), m_Deferred(X))))
+    if (match(IIOperand, m_c_Select(m_Neg(m_Value(X)), m_Deferred(X))))
       return replaceOperand(*II, 0, X);
 
     Value *Y;
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 35596c286f8cc7..acf01a8f1f7fc5 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -8437,7 +8437,7 @@ Instruction *InstCombinerImpl::visitFCmpInst(FCmpInst &I) {
     case Instruction::Select:
       // fcmp eq (cond ? x : -x), 0 --> fcmp eq x, 0
       if (FCmpInst::isEquality(Pred) && match(RHSC, m_AnyZeroFP()) &&
-          match(LHSI, m_c_Select(m_Value(), m_FNeg(m_Value(X)), m_Deferred(X))))
+          match(LHSI, m_c_Select(m_FNeg(m_Value(X)), m_Deferred(X))))
         return replaceOperand(I, 0, X);
       if (Instruction *NV = FoldOpIntoSelect(I, cast<SelectInst>(LHSI)))
         return NV;
diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
index ac106cb32387eb..0c84e6fae496f5 100644
--- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -3814,8 +3814,7 @@ static bool foldTwoEntryPHINode(PHINode *PN, const TargetTransformInfo &TTI,
   // These can often be turned into switches and other things.
   auto IsBinOpOrAnd = [](Value *V) {
     return match(
-        V, m_CombineOr(m_BinOp(),
-                       m_c_Select(m_Value(), m_ImmConstant(), m_Value())));
+        V, m_CombineOr(m_BinOp(), m_c_Select(m_ImmConstant(), m_Value())));
   };
   if (PN->getType()->isIntegerTy(1) &&
       (IsBinOpOrAnd(PN->getIncomingValue(0)) ||



More information about the llvm-commits mailing list