[llvm] Draft (PR #111774)

via llvm-commits llvm-commits at lists.llvm.org
Wed Oct 9 17:24:32 PDT 2024


https://github.com/c8ef created https://github.com/llvm/llvm-project/pull/111774

None

>From 769c4f1dd643f2714a6d2c1c9c392547d02e1eff Mon Sep 17 00:00:00 2001
From: c8ef <c8ef at outlook.com>
Date: Thu, 10 Oct 2024 00:23:28 +0000
Subject: [PATCH] match select icmp

---
 llvm/include/llvm/CodeGen/SDPatternMatch.h | 94 ++++++++++++++++++++--
 1 file changed, 86 insertions(+), 8 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/SDPatternMatch.h b/llvm/include/llvm/CodeGen/SDPatternMatch.h
index 04135ee7e1c022..b629dd50aced00 100644
--- a/llvm/include/llvm/CodeGen/SDPatternMatch.h
+++ b/llvm/include/llvm/CodeGen/SDPatternMatch.h
@@ -542,6 +542,80 @@ struct BinaryOpc_match {
   }
 };
 
+/// Matches (select (setcc L, R, CC), L, R) or its arm-swapped form when the
+/// normalized CC satisfies Pred_t; L and R are then matched against LHS/RHS.
+template <typename LHS_P, typename RHS_P, typename Pred_t,
+          bool Commutable = false, bool ExcludeChain = false>
+struct MaxMin_match {
+  using PredType = Pred_t;
+  LHS_P LHS;
+  RHS_P RHS;
+
+  MaxMin_match(const LHS_P &L, const RHS_P &R) : LHS(L), RHS(R) {}
+
+  template <typename MatchContext>
+  bool match(const MatchContext &Ctx, SDValue N) {
+    if (!sd_context_match(N, Ctx, m_Opc(ISD::SELECT)))
+      return false;
+    EffectiveOperands<ExcludeChain> EO_SELECT(N, Ctx);
+    assert(EO_SELECT.Size == 3);
+    SDValue Cond = N->getOperand(EO_SELECT.FirstIndex);
+    SDValue TrueValue = N->getOperand(EO_SELECT.FirstIndex + 1);
+    SDValue FalseValue = N->getOperand(EO_SELECT.FirstIndex + 2);
+
+    if (!sd_context_match(Cond, Ctx, m_Opc(ISD::SETCC)))
+      return false;
+    EffectiveOperands<ExcludeChain> EO_SETCC(Cond, Ctx);
+    assert(EO_SETCC.Size == 3);
+    SDValue L = Cond->getOperand(EO_SETCC.FirstIndex);
+    SDValue R = Cond->getOperand(EO_SETCC.FirstIndex + 1);
+    CondCodeSDNode *CondNode =
+        cast<CondCodeSDNode>(Cond->getOperand(EO_SETCC.FirstIndex + 2));
+
+    if ((TrueValue != L || FalseValue != R) &&
+        (TrueValue != R || FalseValue != L))
+      return false;
+
+    // select(cc(L, R), R, L) == select(!cc(L, R), L, R): normalize to the form
+    // whose true arm is L. Named CC so it does not shadow the SDValue Cond.
+    ISD::CondCode CC =
+        TrueValue == L ? CondNode->get()
+                       : getSetCCInverse(CondNode->get(), L.getValueType());
+    if (!Pred_t::match(CC))
+      return false;
+    return (LHS.match(Ctx, L) && RHS.match(Ctx, R)) ||
+           (Commutable && LHS.match(Ctx, R) && RHS.match(Ctx, L));
+  }
+};
+
+/// Accepts the signed "greater" codes (SETGT, SETGE) that identify an smax.
+struct smax_pred_ty {
+  static bool match(ISD::CondCode CC) {
+    return CC == ISD::SETGE || CC == ISD::SETGT;
+  }
+};
+
+/// Accepts the unsigned "greater" codes (SETUGT, SETUGE) that identify a umax.
+struct umax_pred_ty {
+  static bool match(ISD::CondCode CC) {
+    return CC == ISD::SETUGE || CC == ISD::SETUGT;
+  }
+};
+
+/// Accepts the signed "less" codes (SETLT, SETLE) that identify an smin.
+struct smin_pred_ty {
+  static bool match(ISD::CondCode CC) {
+    return CC == ISD::SETLE || CC == ISD::SETLT;
+  }
+};
+
+/// Accepts the unsigned "less" codes (SETULT, SETULE) that identify a umin.
+struct umin_pred_ty {
+  static bool match(ISD::CondCode CC) {
+    return CC == ISD::SETULE || CC == ISD::SETULT;
+  }
+};
+
 template <typename LHS, typename RHS>
 inline BinaryOpc_match<LHS, RHS> m_BinOp(unsigned Opc, const LHS &L,
                                          const RHS &R) {
@@ -609,23 +683,27 @@ inline BinaryOpc_match<LHS, RHS, true> m_Xor(const LHS &L, const RHS &R) {
 }
 
+/// Matches (smin L, R) or an equivalent select-of-setcc; root may be a SELECT.
 template <typename LHS, typename RHS>
-inline BinaryOpc_match<LHS, RHS, true> m_SMin(const LHS &L, const RHS &R) {
-  return BinaryOpc_match<LHS, RHS, true>(ISD::SMIN, L, R);
+inline auto m_SMin(const LHS &L, const RHS &R) {
+  return m_AnyOf(BinaryOpc_match<LHS, RHS, true>(ISD::SMIN, L, R),
+                 MaxMin_match<LHS, RHS, smin_pred_ty, true>(L, R));
 }
 
+/// Matches (smax L, R) or an equivalent select-of-setcc; root may be a SELECT.
 template <typename LHS, typename RHS>
-inline BinaryOpc_match<LHS, RHS, true> m_SMax(const LHS &L, const RHS &R) {
-  return BinaryOpc_match<LHS, RHS, true>(ISD::SMAX, L, R);
+inline auto m_SMax(const LHS &L, const RHS &R) {
+  return m_AnyOf(BinaryOpc_match<LHS, RHS, true>(ISD::SMAX, L, R),
+                 MaxMin_match<LHS, RHS, smax_pred_ty, true>(L, R));
 }
 
+/// Matches (umin L, R) or an equivalent select-of-setcc; root may be a SELECT.
 template <typename LHS, typename RHS>
-inline BinaryOpc_match<LHS, RHS, true> m_UMin(const LHS &L, const RHS &R) {
-  return BinaryOpc_match<LHS, RHS, true>(ISD::UMIN, L, R);
+inline auto m_UMin(const LHS &L, const RHS &R) {
+  return m_AnyOf(BinaryOpc_match<LHS, RHS, true>(ISD::UMIN, L, R),
+                 MaxMin_match<LHS, RHS, umin_pred_ty, true>(L, R));
 }
 
+/// Matches (umax L, R) or an equivalent select-of-setcc; root may be a SELECT.
 template <typename LHS, typename RHS>
-inline BinaryOpc_match<LHS, RHS, true> m_UMax(const LHS &L, const RHS &R) {
-  return BinaryOpc_match<LHS, RHS, true>(ISD::UMAX, L, R);
+inline auto m_UMax(const LHS &L, const RHS &R) {
+  return m_AnyOf(BinaryOpc_match<LHS, RHS, true>(ISD::UMAX, L, R),
+                 MaxMin_match<LHS, RHS, umax_pred_ty, true>(L, R));
 }
 
 template <typename LHS, typename RHS>



More information about the llvm-commits mailing list