[llvm] [SCEV] Add initial matchers for SCEV expressions. (NFC) (PR #119390)
Florian Hahn via llvm-commits
llvm-commits at lists.llvm.org
Fri Dec 13 06:48:32 PST 2024
https://github.com/fhahn updated https://github.com/llvm/llvm-project/pull/119390
>From ecaa67e2f4beb4905a9d605b4e7f2b6eae61ed0c Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Tue, 10 Dec 2024 15:01:52 +0000
Subject: [PATCH 1/2] [SCEV] Add initial matchers for SCEV expressions. (NFC)
This patch adds initial matchers for SCEV expressions with an arbitrary
number of operands and specializes them for binary add expressions.
Also adds matchers for SCEVConstant and SCEVUnknown.
This patch only converts a few instances to use the new matchers to make
sure everything builds as expected for now.
Depends on https://github.com/llvm/llvm-project/pull/119389
---
.../Analysis/ScalarEvolutionPatternMatch.h | 81 +++++++++++++++++++
llvm/lib/Analysis/ScalarEvolution.cpp | 12 ++-
2 files changed, 86 insertions(+), 7 deletions(-)
diff --git a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
index 21d2ef3c867d7d..79295dd324c4ba 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
@@ -52,6 +52,87 @@ inline cst_pred_ty<is_all_ones> m_scev_AllOnes() {
return cst_pred_ty<is_all_ones>();
}
+/// Matches any value whose dynamic type is \p Class, without binding it.
+template <typename Class> struct class_match {
+  template <typename ITy> bool match(ITy *V) const {
+    return isa<Class>(V);
+  }
+};
+
+/// Matches a value whose dynamic type is \p Class and, on success, stores the
+/// cast pointer into the referenced variable \p VR. On failure \p VR is left
+/// untouched.
+template <typename Class> struct bind_ty {
+  Class *&VR;
+
+  bind_ty(Class *&V) : VR(V) {}
+
+  template <typename ITy> bool match(ITy *V) const {
+    auto *CV = dyn_cast<Class>(V);
+    if (!CV)
+      return false;
+    VR = CV;
+    return true;
+  }
+};
+
+/// Match any SCEV, binding it to \p V on success.
+inline bind_ty<const SCEV> m_SCEV(const SCEV *&V) {
+  return bind_ty<const SCEV>(V);
+}
+
+/// Match a SCEVConstant, binding it to \p V on success.
+inline bind_ty<const SCEVConstant> m_SCEVConstant(const SCEVConstant *&V) {
+  return bind_ty<const SCEVConstant>(V);
+}
+
+/// Match a SCEVUnknown, binding it to \p V on success.
+inline bind_ty<const SCEVUnknown> m_SCEVUnknown(const SCEVUnknown *&V) {
+  return bind_ty<const SCEVUnknown>(V);
+}
+
+namespace detail {
+
+/// Invoke \p P as P(Element, Index) on each element of \p Ops, left to right,
+/// stopping at the first call that returns false. An empty tuple yields true.
+template <typename TupleTy, typename Fn, std::size_t... Is>
+bool CheckTupleElements(const TupleTy &Ops, Fn P, std::index_sequence<Is...>) {
+  return (... && P(std::get<Is>(Ops), Is));
+}
+
+/// Helper to check if predicate \p P holds on all tuple elements in \p Ops.
+template <typename TupleTy, typename Fn>
+bool all_of_tuple_elements(const TupleTy &Ops, Fn P) {
+  constexpr std::size_t NumElts = std::tuple_size<TupleTy>::value;
+  return CheckTupleElements(Ops, P, std::make_index_sequence<NumElts>());
+}
+
+} // namespace detail
+
+/// Matches a SCEV of dynamic type \p SCEVTy whose operands match, one-to-one
+/// and in order, the sub-matchers stored in the tuple \p Ops. The SCEV must
+/// have exactly as many operands as \p Ops has elements.
+template <typename Ops_t, typename SCEVTy> struct SCEV_match {
+  Ops_t Ops;
+
+  SCEV_match() : Ops() {
+    static_assert(std::tuple_size<Ops_t>::value == 0,
+                  "constructor can only be used with zero operands");
+  }
+  SCEV_match(Ops_t Ops) : Ops(Ops) {}
+  template <typename A_t, typename B_t> SCEV_match(A_t A, B_t B) : Ops({A, B}) {
+    static_assert(std::tuple_size<Ops_t>::value == 2,
+                  "constructor can only be used for binary matcher");
+  }
+
+  bool match(const SCEV *S) const {
+    auto *Cast = dyn_cast<SCEVTy>(S);
+    if (!Cast || Cast->getNumOperands() != std::tuple_size<Ops_t>::value)
+      return false;
+    // Sub-matchers are taken by value here; bind_ty copies still write through
+    // the reference they hold, so bindings reach the caller's variables.
+    return detail::all_of_tuple_elements(Ops, [Cast](auto Op, unsigned Idx) {
+      return Op.match(Cast->getOperand(Idx));
+    });
+  }
+};
+
+/// Matcher for a SCEV of type \p SCEVTy with exactly two operands.
+template <typename Op0_t, typename Op1_t, typename SCEVTy>
+using BinarySCEV_match = SCEV_match<std::tuple<Op0_t, Op1_t>, SCEVTy>;
+
+/// Match a two-operand SCEV of type \p SCEVTy whose operands match \p Op0 and
+/// \p Op1, in that order. \p SCEVTy is the last template parameter and cannot
+/// be deduced, so all three template arguments must be spelled out.
+template <typename Op0_t, typename Op1_t, typename SCEVTy>
+inline BinarySCEV_match<Op0_t, Op1_t, SCEVTy>
+m_scev_Binary(const Op0_t &Op0, const Op1_t &Op1) {
+  using MatcherTy = BinarySCEV_match<Op0_t, Op1_t, SCEVTy>;
+  return MatcherTy(Op0, Op1);
+}
+
+/// Match a two-operand SCEVAddExpr whose operands match \p Op0 and \p Op1, in
+/// that order (the commuted form is not tried).
+template <typename Op0_t, typename Op1_t>
+inline BinarySCEV_match<Op0_t, Op1_t, SCEVAddExpr>
+m_scev_Add(const Op0_t &Op0, const Op1_t &Op1) {
+  return m_scev_Binary<Op0_t, Op1_t, SCEVAddExpr>(Op0, Op1);
+}
+
} // namespace SCEVPatternMatch
} // namespace llvm
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index e18133971f5bf0..66729fd970c3c0 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -15381,14 +15381,12 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
// (X >=u C1).
auto MatchRangeCheckIdiom = [&SE, Predicate, LHS, RHS, &RewriteMap,
&ExprsToRewrite]() {
- auto *AddExpr = dyn_cast<SCEVAddExpr>(LHS);
- if (!AddExpr || AddExpr->getNumOperands() != 2)
- return false;
-
- auto *C1 = dyn_cast<SCEVConstant>(AddExpr->getOperand(0));
- auto *LHSUnknown = dyn_cast<SCEVUnknown>(AddExpr->getOperand(1));
+ const SCEVConstant *C1;
+ const SCEVUnknown *LHSUnknown;
auto *C2 = dyn_cast<SCEVConstant>(RHS);
- if (!C1 || !C2 || !LHSUnknown)
+ if (!match(LHS,
+ m_scev_Add(m_SCEVConstant(C1), m_SCEVUnknown(LHSUnknown))) ||
+ !C2)
return false;
auto ExactRegion =
>From a0dc4283e90f96c39d7d50b16f68b41fbb5d055f Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Fri, 13 Dec 2024 14:47:27 +0000
Subject: [PATCH 2/2] !fixup add unary matcher for ZExt and SExt
---
.../Analysis/ScalarEvolutionPatternMatch.h | 32 +++++++++++++++++++
llvm/lib/Analysis/ScalarEvolution.cpp | 23 ++++++-------
2 files changed, 41 insertions(+), 14 deletions(-)
diff --git a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
index 79295dd324c4ba..6e7c985b2b6e86 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
@@ -79,6 +79,18 @@ inline bind_ty<const SCEVUnknown> m_SCEVUnknown(const SCEVUnknown *&V) {
return V;
}
+/// Match a specified const SCEV *. The candidate is compared to \p Expr by
+/// pointer equality.
+struct specificscev_ty {
+  const SCEV *Expr;
+
+  specificscev_ty(const SCEV *Expr) : Expr(Expr) {}
+
+  // const, like every other matcher here, so it can be used through const
+  // references to enclosing matchers.
+  template <typename ITy> bool match(ITy *S) const { return S == Expr; }
+};
+
+/// Match if we have a specific specified SCEV.
+inline specificscev_ty m_Specific(const SCEV *S) { return S; }
+
namespace detail {
template <typename TupleTy, typename Fn, std::size_t... Is>
@@ -133,6 +145,26 @@ m_scev_Add(const Op0_t &Op0, const Op1_t &Op1) {
return BinarySCEV_match<Op0_t, Op1_t, SCEVAddExpr>(Op0, Op1);
}
+/// Matcher for a SCEV of type \p SCEVTy with exactly one operand.
+template <typename Op0_t, typename SCEVTy>
+using UnarySCEV_match = SCEV_match<std::tuple<Op0_t>, SCEVTy>;
+
+/// Match a one-operand SCEV of type \p SCEVTy whose operand matches \p Op0.
+/// \p SCEVTy cannot be deduced, so both template arguments must be given.
+template <typename Op0_t, typename SCEVTy>
+inline UnarySCEV_match<Op0_t, SCEVTy> m_scev_Unary(const Op0_t &Op0) {
+  // Build the one-element operand tuple explicitly rather than relying on an
+  // implicit Op0_t -> std::tuple<Op0_t> conversion.
+  return UnarySCEV_match<Op0_t, SCEVTy>(std::tuple<Op0_t>(Op0));
+}
+
+/// Match a SCEVSignExtendExpr whose operand matches \p Op0.
+template <typename Op0_t>
+inline UnarySCEV_match<Op0_t, SCEVSignExtendExpr>
+m_scev_SExt(const Op0_t &Op0) {
+  return m_scev_Unary<Op0_t, SCEVSignExtendExpr>(Op0);
+}
+
+/// Match a SCEVZeroExtendExpr whose operand matches \p Op0.
+template <typename Op0_t>
+inline UnarySCEV_match<Op0_t, SCEVZeroExtendExpr>
+m_scev_ZExt(const Op0_t &Op0) {
+  return m_scev_Unary<Op0_t, SCEVZeroExtendExpr>(Op0);
+}
+
} // namespace SCEVPatternMatch
} // namespace llvm
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 66729fd970c3c0..d4bec26ff6c52d 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -12724,33 +12724,28 @@ bool ScalarEvolution::isImpliedViaOperations(ICmpInst::Predicate Pred,
+/// Return true if "LHS Pred RHS" is known to hold because one side is a sign
+/// extension and the other a zero extension of the very same SCEV operand.
static bool isKnownPredicateExtendIdiom(ICmpInst::Predicate Pred,
const SCEV *LHS, const SCEV *RHS) {
// zext x u<= sext x, sext x s<= zext x
+  const SCEV *Op;
switch (Pred) {
case ICmpInst::ICMP_SGE:
std::swap(LHS, RHS);
[[fallthrough]];
case ICmpInst::ICMP_SLE: {
-    // If operand >=s 0 then ZExt == SExt. If operand <s 0 then SExt <s ZExt.
-    const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(LHS);
-    const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(RHS);
-    if (SExt && ZExt && SExt->getOperand() == ZExt->getOperand())
-      return true;
-    break;
+    // If operand >=s 0 then ZExt == SExt. If operand <s 0 then SExt <s ZExt.
+    return match(LHS, m_scev_SExt(m_SCEV(Op))) &&
+           match(RHS, m_scev_ZExt(m_Specific(Op)));
}
case ICmpInst::ICMP_UGE:
std::swap(LHS, RHS);
[[fallthrough]];
case ICmpInst::ICMP_ULE: {
-    // If operand >=s 0 then ZExt == SExt. If operand <s 0 then ZExt <u SExt.
-    const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS);
-    const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(RHS);
-    if (SExt && ZExt && SExt->getOperand() == ZExt->getOperand())
-      return true;
-    break;
+    // If operand >=s 0 then ZExt == SExt. If operand <s 0 then ZExt <u SExt.
+    return match(LHS, m_scev_ZExt(m_SCEV(Op))) &&
+           match(RHS, m_scev_SExt(m_Specific(Op)));
}
default:
-    break;
+    return false;
};
-  return false;
+  llvm_unreachable("unhandled case");
}
bool
More information about the llvm-commits
mailing list