[llvm] Draft (PR #111774)
via llvm-commits
llvm-commits at lists.llvm.org
Wed Oct 9 17:24:32 PDT 2024
https://github.com/c8ef created https://github.com/llvm/llvm-project/pull/111774
(No description provided.)
From 769c4f1dd643f2714a6d2c1c9c392547d02e1eff Mon Sep 17 00:00:00 2001
From: c8ef <c8ef at outlook.com>
Date: Thu, 10 Oct 2024 00:23:28 +0000
Subject: [PATCH] match select icmp
---
llvm/include/llvm/CodeGen/SDPatternMatch.h | 94 ++++++++++++++++++++--
1 file changed, 86 insertions(+), 8 deletions(-)
diff --git a/llvm/include/llvm/CodeGen/SDPatternMatch.h b/llvm/include/llvm/CodeGen/SDPatternMatch.h
index 04135ee7e1c022..b629dd50aced00 100644
--- a/llvm/include/llvm/CodeGen/SDPatternMatch.h
+++ b/llvm/include/llvm/CodeGen/SDPatternMatch.h
@@ -542,6 +542,80 @@ struct BinaryOpc_match {
}
};
+/// Matches the select-of-setcc spelling of an integer min/max:
+///   (select (setcc L, R, CC), L, R)  -- Pred_t is checked against CC
+///   (select (setcc L, R, CC), R, L)  -- Pred_t is checked against !CC
+/// Both ISD::SELECT and ISD::VSELECT are accepted, so scalar and vector
+/// idioms are recognized. Only integer comparisons are accepted.
+///
+/// On success LHS is matched against L and RHS against R; when Commutable is
+/// set, the swapped assignment (LHS vs. R, RHS vs. L) is tried as well.
+template <typename LHS_P, typename RHS_P, typename Pred_t,
+          bool Commutable = false, bool ExcludeChain = false>
+struct MaxMin_match {
+  using PredType = Pred_t;
+  LHS_P LHS;
+  RHS_P RHS;
+
+  MaxMin_match(const LHS_P &L, const RHS_P &R) : LHS(L), RHS(R) {}
+
+  template <typename MatchContext>
+  bool match(const MatchContext &Ctx, SDValue N) {
+    if (!sd_context_match(N, Ctx, m_Opc(ISD::SELECT)) &&
+        !sd_context_match(N, Ctx, m_Opc(ISD::VSELECT)))
+      return false;
+
+    EffectiveOperands<ExcludeChain> EO_SELECT(N, Ctx);
+    assert(EO_SELECT.Size == 3);
+    SDValue Cond = N->getOperand(EO_SELECT.FirstIndex);
+    SDValue TrueValue = N->getOperand(EO_SELECT.FirstIndex + 1);
+    SDValue FalseValue = N->getOperand(EO_SELECT.FirstIndex + 2);
+
+    if (!sd_context_match(Cond, Ctx, m_Opc(ISD::SETCC)))
+      return false;
+
+    EffectiveOperands<ExcludeChain> EO_SETCC(Cond, Ctx);
+    assert(EO_SETCC.Size == 3);
+    SDValue L = Cond->getOperand(EO_SETCC.FirstIndex);
+    SDValue R = Cond->getOperand(EO_SETCC.FirstIndex + 1);
+    auto *CondNode =
+        cast<CondCodeSDNode>(Cond->getOperand(EO_SETCC.FirstIndex + 2));
+
+    // SETGT/SETLT/... are also legal "don't care about NaN" codes on a
+    // floating-point SETCC. The min/max matchers built on this class stand
+    // for integer SMIN/SMAX/UMIN/UMAX, so refuse non-integer comparisons.
+    if (!L.getValueType().isInteger())
+      return false;
+
+    // The select has to choose between exactly the two compared values.
+    bool SameOrder = TrueValue == L && FalseValue == R;
+    bool Swapped = TrueValue == R && FalseValue == L;
+    if (!SameOrder && !Swapped)
+      return false;
+
+    // select(L cc R, R, L) == select(L !cc R, L, R): normalize to the form
+    // whose true value is L before asking Pred_t about the condition code.
+    ISD::CondCode CC = SameOrder
+                           ? CondNode->get()
+                           : getSetCCInverse(CondNode->get(), L.getValueType());
+    if (!Pred_t::match(CC))
+      return false;
+
+    return (LHS.match(Ctx, L) && RHS.match(Ctx, R)) ||
+           (Commutable && LHS.match(Ctx, R) && RHS.match(Ctx, L));
+  }
+};
+
+// Condition-code filter for signed max: select(setcc(L, R, CC), L, R) yields
+// the signed maximum of L and R exactly when CC is SETGT or SETGE.
+struct smax_pred_ty {
+  static bool match(ISD::CondCode CC) {
+    switch (CC) {
+    case ISD::SETGT:
+    case ISD::SETGE:
+      return true;
+    default:
+      return false;
+    }
+  }
+};
+
+// Condition-code filter for unsigned max: select(setcc(L, R, CC), L, R)
+// yields the unsigned maximum of L and R exactly when CC is SETUGT or SETUGE.
+struct umax_pred_ty {
+  static bool match(ISD::CondCode CC) {
+    switch (CC) {
+    case ISD::SETUGT:
+    case ISD::SETUGE:
+      return true;
+    default:
+      return false;
+    }
+  }
+};
+
+// Condition-code filter for signed min: select(setcc(L, R, CC), L, R) yields
+// the signed minimum of L and R exactly when CC is SETLT or SETLE.
+struct smin_pred_ty {
+  static bool match(ISD::CondCode CC) {
+    switch (CC) {
+    case ISD::SETLT:
+    case ISD::SETLE:
+      return true;
+    default:
+      return false;
+    }
+  }
+};
+
+// Condition-code filter for unsigned min: select(setcc(L, R, CC), L, R)
+// yields the unsigned minimum of L and R exactly when CC is SETULT or SETULE.
+struct umin_pred_ty {
+  static bool match(ISD::CondCode CC) {
+    switch (CC) {
+    case ISD::SETULT:
+    case ISD::SETULE:
+      return true;
+    default:
+      return false;
+    }
+  }
+};
+
template <typename LHS, typename RHS>
inline BinaryOpc_match<LHS, RHS> m_BinOp(unsigned Opc, const LHS &L,
const RHS &R) {
@@ -609,23 +683,27 @@ inline BinaryOpc_match<LHS, RHS, true> m_Xor(const LHS &L, const RHS &R) {
}
template <typename LHS, typename RHS>
-inline BinaryOpc_match<LHS, RHS, true> m_SMin(const LHS &L, const RHS &R) {
- return BinaryOpc_match<LHS, RHS, true>(ISD::SMIN, L, R);
+inline auto m_SMin(const LHS &L, const RHS &R) {
+  // Accept either a dedicated ISD::SMIN node or a select(setcc) that picks
+  // the signed minimum of the two compared values; operands may be commuted.
+  using NodeM = BinaryOpc_match<LHS, RHS, /*Commutable=*/true>;
+  using SelectM = MaxMin_match<LHS, RHS, smin_pred_ty, /*Commutable=*/true>;
+  return m_AnyOf(NodeM(ISD::SMIN, L, R), SelectM(L, R));
}
template <typename LHS, typename RHS>
-inline BinaryOpc_match<LHS, RHS, true> m_SMax(const LHS &L, const RHS &R) {
- return BinaryOpc_match<LHS, RHS, true>(ISD::SMAX, L, R);
+inline auto m_SMax(const LHS &L, const RHS &R) {
+  // Accept either a dedicated ISD::SMAX node or a select(setcc) that picks
+  // the signed maximum of the two compared values; operands may be commuted.
+  using NodeM = BinaryOpc_match<LHS, RHS, /*Commutable=*/true>;
+  using SelectM = MaxMin_match<LHS, RHS, smax_pred_ty, /*Commutable=*/true>;
+  return m_AnyOf(NodeM(ISD::SMAX, L, R), SelectM(L, R));
}
template <typename LHS, typename RHS>
-inline BinaryOpc_match<LHS, RHS, true> m_UMin(const LHS &L, const RHS &R) {
- return BinaryOpc_match<LHS, RHS, true>(ISD::UMIN, L, R);
+inline auto m_UMin(const LHS &L, const RHS &R) {
+  // Accept either a dedicated ISD::UMIN node or a select(setcc) that picks
+  // the unsigned minimum of the two compared values; operands may be commuted.
+  using NodeM = BinaryOpc_match<LHS, RHS, /*Commutable=*/true>;
+  using SelectM = MaxMin_match<LHS, RHS, umin_pred_ty, /*Commutable=*/true>;
+  return m_AnyOf(NodeM(ISD::UMIN, L, R), SelectM(L, R));
}
template <typename LHS, typename RHS>
-inline BinaryOpc_match<LHS, RHS, true> m_UMax(const LHS &L, const RHS &R) {
- return BinaryOpc_match<LHS, RHS, true>(ISD::UMAX, L, R);
+inline auto m_UMax(const LHS &L, const RHS &R) {
+  // Accept either a dedicated ISD::UMAX node or a select(setcc) that picks
+  // the unsigned maximum of the two compared values; operands may be commuted.
+  using NodeM = BinaryOpc_match<LHS, RHS, /*Commutable=*/true>;
+  using SelectM = MaxMin_match<LHS, RHS, umax_pred_ty, /*Commutable=*/true>;
+  return m_AnyOf(NodeM(ISD::UMAX, L, R), SelectM(L, R));
}
template <typename LHS, typename RHS>
More information about the llvm-commits mailing list