[llvm] [SCEV] Add initial matchers for SCEV expressions. (NFC) (PR #119390)

Florian Hahn via llvm-commits llvm-commits at lists.llvm.org
Tue Dec 17 02:46:58 PST 2024


https://github.com/fhahn updated https://github.com/llvm/llvm-project/pull/119390

>From 462d03c2b735253621ff91567c13529c0d9afbc7 Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Tue, 10 Dec 2024 15:01:52 +0000
Subject: [PATCH 1/4] [SCEV] Add initial matchers for SCEV expressions. (NFC)

This patch adds initial matchers for SCEV expressions with an arbitrary
number of operands and specializes them for binary add expressions.

Also adds matchers for SCEVConstant and SCEVUnknown.

This patch only converts a few instances to use the new matchers to make
sure everything builds as expected for now.

Depends on https://github.com/llvm/llvm-project/pull/119389
---
 .../Analysis/ScalarEvolutionPatternMatch.h    | 81 +++++++++++++++++++
 llvm/lib/Analysis/ScalarEvolution.cpp         | 12 ++-
 2 files changed, 86 insertions(+), 7 deletions(-)

diff --git a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
index 21d2ef3c867d7d..79295dd324c4ba 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
@@ -52,6 +52,87 @@ inline cst_pred_ty<is_all_ones> m_scev_AllOnes() {
   return cst_pred_ty<is_all_ones>();
 }
 
+template <typename Class> struct class_match {
+  template <typename ITy> bool match(ITy *V) const { return isa<Class>(V); }
+};
+
+template <typename Class> struct bind_ty {
+  Class *&VR;
+
+  bind_ty(Class *&V) : VR(V) {}
+
+  template <typename ITy> bool match(ITy *V) const {
+    if (auto *CV = dyn_cast<Class>(V)) {
+      VR = CV;
+      return true;
+    }
+    return false;
+  }
+};
+
+/// Match a SCEV, capturing it if we match.
+inline bind_ty<const SCEV> m_SCEV(const SCEV *&V) { return V; }
+inline bind_ty<const SCEVConstant> m_SCEVConstant(const SCEVConstant *&V) {
+  return V;
+}
+inline bind_ty<const SCEVUnknown> m_SCEVUnknown(const SCEVUnknown *&V) {
+  return V;
+}
+
+namespace detail {
+
+template <typename TupleTy, typename Fn, std::size_t... Is>
+bool CheckTupleElements(const TupleTy &Ops, Fn P, std::index_sequence<Is...>) {
+  return (P(std::get<Is>(Ops), Is) && ...);
+}
+
+/// Helper to check if predicate \p P holds on all tuple elements in \p Ops
+template <typename TupleTy, typename Fn>
+bool all_of_tuple_elements(const TupleTy &Ops, Fn P) {
+  return CheckTupleElements(
+      Ops, P, std::make_index_sequence<std::tuple_size<TupleTy>::value>{});
+}
+
+} // namespace detail
+
+template <typename Ops_t, typename SCEVTy> struct SCEV_match {
+  Ops_t Ops;
+
+  SCEV_match() : Ops() {
+    static_assert(std::tuple_size<Ops_t>::value == 0 &&
+                  "constructor can only be used with zero operands");
+  }
+  SCEV_match(Ops_t Ops) : Ops(Ops) {}
+  template <typename A_t, typename B_t> SCEV_match(A_t A, B_t B) : Ops({A, B}) {
+    static_assert(std::tuple_size<Ops_t>::value == 2 &&
+                  "constructor can only be used for binary matcher");
+  }
+
+  bool match(const SCEV *S) const {
+    auto *Cast = dyn_cast<SCEVTy>(S);
+    if (!Cast || Cast->getNumOperands() != std::tuple_size<Ops_t>::value)
+      return false;
+    return detail::all_of_tuple_elements(Ops, [Cast](auto Op, unsigned Idx) {
+      return Op.match(Cast->getOperand(Idx));
+    });
+  }
+};
+
+template <typename Op0_t, typename Op1_t, typename SCEVTy>
+using BinarySCEV_match = SCEV_match<std::tuple<Op0_t, Op1_t>, SCEVTy>;
+
+template <typename Op0_t, typename Op1_t, typename SCEVTy>
+inline BinarySCEV_match<Op0_t, Op1_t, SCEVTy> m_scev_Binary(const Op0_t &Op0,
+                                                            const Op1_t &Op1) {
+  return BinarySCEV_match<Op0_t, Op1_t, SCEVTy>(Op0, Op1);
+}
+
+template <typename Op0_t, typename Op1_t>
+inline BinarySCEV_match<Op0_t, Op1_t, SCEVAddExpr>
+m_scev_Add(const Op0_t &Op0, const Op1_t &Op1) {
+  return BinarySCEV_match<Op0_t, Op1_t, SCEVAddExpr>(Op0, Op1);
+}
+
 } // namespace SCEVPatternMatch
 } // namespace llvm
 
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index e2c2500052e7d6..ff30265bc68082 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -15417,14 +15417,12 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
     // (X >=u C1).
     auto MatchRangeCheckIdiom = [&SE, Predicate, LHS, RHS, &RewriteMap,
                                  &ExprsToRewrite]() {
-      auto *AddExpr = dyn_cast<SCEVAddExpr>(LHS);
-      if (!AddExpr || AddExpr->getNumOperands() != 2)
-        return false;
-
-      auto *C1 = dyn_cast<SCEVConstant>(AddExpr->getOperand(0));
-      auto *LHSUnknown = dyn_cast<SCEVUnknown>(AddExpr->getOperand(1));
+      const SCEVConstant *C1;
+      const SCEVUnknown *LHSUnknown;
       auto *C2 = dyn_cast<SCEVConstant>(RHS);
-      if (!C1 || !C2 || !LHSUnknown)
+      if (!match(LHS,
+                 m_scev_Add(m_SCEVConstant(C1), m_SCEVUnknown(LHSUnknown))) ||
+          !C2)
         return false;
 
       auto ExactRegion =

>From b79559a85703b661a2754623efbd71bf5f7dc065 Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Fri, 13 Dec 2024 14:47:27 +0000
Subject: [PATCH 2/4] !fixup add unary matcher for ZExt and SExt

---
 .../Analysis/ScalarEvolutionPatternMatch.h    | 32 +++++++++++++++++++
 llvm/lib/Analysis/ScalarEvolution.cpp         | 23 ++++++-------
 2 files changed, 41 insertions(+), 14 deletions(-)

diff --git a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
index 79295dd324c4ba..6e7c985b2b6e86 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
@@ -79,6 +79,18 @@ inline bind_ty<const SCEVUnknown> m_SCEVUnknown(const SCEVUnknown *&V) {
   return V;
 }
 
+/// Match a specified const SCEV * by pointer identity.
+struct specificscev_ty {
+  const SCEV *Expr;
+
+  specificscev_ty(const SCEV *Expr) : Expr(Expr) {}
+
+  template <typename ITy> bool match(ITy *S) const { return S == Expr; }
+};
+
+/// Match if we have a specific specified SCEV.
+inline specificscev_ty m_Specific(const SCEV *S) { return S; }
+
 namespace detail {
 
 template <typename TupleTy, typename Fn, std::size_t... Is>
@@ -133,6 +145,26 @@ m_scev_Add(const Op0_t &Op0, const Op1_t &Op1) {
   return BinarySCEV_match<Op0_t, Op1_t, SCEVAddExpr>(Op0, Op1);
 }
 
+template <typename Op0_t, typename SCEVTy>
+using UnarySCEV_match = SCEV_match<std::tuple<Op0_t>, SCEVTy>;
+
+template <typename Op0_t, typename Op1_t, typename SCEVTy>
+inline UnarySCEV_match<Op0_t, SCEVTy> m_scev_Unary(const Op0_t &Op0) {
+  return UnarySCEV_match<Op0_t, SCEVTy>(Op0);
+}
+
+template <typename Op0_t>
+inline UnarySCEV_match<Op0_t, SCEVSignExtendExpr>
+m_scev_SExt(const Op0_t &Op0) {
+  return UnarySCEV_match<Op0_t, SCEVSignExtendExpr>(Op0);
+}
+
+template <typename Op0_t>
+inline UnarySCEV_match<Op0_t, SCEVZeroExtendExpr>
+m_scev_ZExt(const Op0_t &Op0) {
+  return UnarySCEV_match<Op0_t, SCEVZeroExtendExpr>(Op0);
+}
+
 } // namespace SCEVPatternMatch
 } // namespace llvm
 
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index ff30265bc68082..c820e8bf7266ad 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -12725,33 +12725,28 @@ bool ScalarEvolution::isImpliedViaOperations(ICmpInst::Predicate Pred,
 static bool isKnownPredicateExtendIdiom(ICmpInst::Predicate Pred,
                                         const SCEV *LHS, const SCEV *RHS) {
   // zext x u<= sext x, sext x s<= zext x
+  const SCEV *Op;
   switch (Pred) {
   case ICmpInst::ICMP_SGE:
     std::swap(LHS, RHS);
     [[fallthrough]];
   case ICmpInst::ICMP_SLE: {
-    // If operand >=s 0 then ZExt == SExt.  If operand <s 0 then SExt <s ZExt.
-    const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(LHS);
-    const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(RHS);
-    if (SExt && ZExt && SExt->getOperand() == ZExt->getOperand())
-      return true;
-    break;
+    // If operand >=s 0 then ZExt == SExt. If operand <s 0 then SExt <s ZExt.
+    return match(LHS, m_scev_SExt(m_SCEV(Op))) &&
+           match(RHS, m_scev_ZExt(m_Specific(Op)));
   }
   case ICmpInst::ICMP_UGE:
     std::swap(LHS, RHS);
     [[fallthrough]];
   case ICmpInst::ICMP_ULE: {
-    // If operand >=s 0 then ZExt == SExt.  If operand <s 0 then ZExt <u SExt.
-    const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS);
-    const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(RHS);
-    if (SExt && ZExt && SExt->getOperand() == ZExt->getOperand())
-      return true;
-    break;
+    // If operand >=s 0 then ZExt == SExt. If operand <s 0 then ZExt <u SExt.
+    return match(LHS, m_scev_ZExt(m_SCEV(Op))) &&
+           match(RHS, m_scev_SExt(m_Specific(Op)));
   }
   default:
-    break;
+    return false;
   };
-  return false;
+  llvm_unreachable("unhandled case");
 }
 
 bool

>From 4fa18ab39f5954c16e6d9cd0ca5832077bc91685 Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Fri, 13 Dec 2024 14:50:43 +0000
Subject: [PATCH 3/4] !fixup SCEV_match -> SCEVExpr_match.

---
 .../Analysis/ScalarEvolutionPatternMatch.h    | 35 ++++++++++---------
 1 file changed, 18 insertions(+), 17 deletions(-)

diff --git a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
index 6e7c985b2b6e86..96101205163da4 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
@@ -107,15 +107,16 @@ bool all_of_tuple_elements(const TupleTy &Ops, Fn P) {
 
 } // namespace detail
 
-template <typename Ops_t, typename SCEVTy> struct SCEV_match {
+template <typename Ops_t, typename SCEVTy> struct SCEVExpr_match {
   Ops_t Ops;
 
-  SCEV_match() : Ops() {
+  SCEVExpr_match() : Ops() {
     static_assert(std::tuple_size<Ops_t>::value == 0 &&
                   "constructor can only be used with zero operands");
   }
-  SCEV_match(Ops_t Ops) : Ops(Ops) {}
-  template <typename A_t, typename B_t> SCEV_match(A_t A, B_t B) : Ops({A, B}) {
+  SCEVExpr_match(Ops_t Ops) : Ops(Ops) {}
+  template <typename A_t, typename B_t>
+  SCEVExpr_match(A_t A, B_t B) : Ops({A, B}) {
     static_assert(std::tuple_size<Ops_t>::value == 2 &&
                   "constructor can only be used for binary matcher");
   }
@@ -131,38 +132,38 @@ template <typename Ops_t, typename SCEVTy> struct SCEV_match {
 };
 
 template <typename Op0_t, typename Op1_t, typename SCEVTy>
-using BinarySCEV_match = SCEV_match<std::tuple<Op0_t, Op1_t>, SCEVTy>;
+using BinarySCEVExpr_match = SCEVExpr_match<std::tuple<Op0_t, Op1_t>, SCEVTy>;
 
 template <typename Op0_t, typename Op1_t, typename SCEVTy>
-inline BinarySCEV_match<Op0_t, Op1_t, SCEVTy> m_scev_Binary(const Op0_t &Op0,
-                                                            const Op1_t &Op1) {
-  return BinarySCEV_match<Op0_t, Op1_t, SCEVTy>(Op0, Op1);
+inline BinarySCEVExpr_match<Op0_t, Op1_t, SCEVTy>
+m_scev_Binary(const Op0_t &Op0, const Op1_t &Op1) {
+  return BinarySCEVExpr_match<Op0_t, Op1_t, SCEVTy>(Op0, Op1);
 }
 
 template <typename Op0_t, typename Op1_t>
-inline BinarySCEV_match<Op0_t, Op1_t, SCEVAddExpr>
+inline BinarySCEVExpr_match<Op0_t, Op1_t, SCEVAddExpr>
 m_scev_Add(const Op0_t &Op0, const Op1_t &Op1) {
-  return BinarySCEV_match<Op0_t, Op1_t, SCEVAddExpr>(Op0, Op1);
+  return BinarySCEVExpr_match<Op0_t, Op1_t, SCEVAddExpr>(Op0, Op1);
 }
 
 template <typename Op0_t, typename SCEVTy>
-using UnarySCEV_match = SCEV_match<std::tuple<Op0_t>, SCEVTy>;
+using UnarySCEVExpr_match = SCEVExpr_match<std::tuple<Op0_t>, SCEVTy>;
 
 template <typename Op0_t, typename Op1_t, typename SCEVTy>
-inline UnarySCEV_match<Op0_t, SCEVTy> m_scev_Unary(const Op0_t &Op0) {
-  return UnarySCEV_match<Op0_t, SCEVTy>(Op0);
+inline UnarySCEVExpr_match<Op0_t, SCEVTy> m_scev_Unary(const Op0_t &Op0) {
+  return UnarySCEVExpr_match<Op0_t, SCEVTy>(Op0);
 }
 
 template <typename Op0_t>
-inline UnarySCEV_match<Op0_t, SCEVSignExtendExpr>
+inline UnarySCEVExpr_match<Op0_t, SCEVSignExtendExpr>
 m_scev_SExt(const Op0_t &Op0) {
-  return UnarySCEV_match<Op0_t, SCEVSignExtendExpr>(Op0);
+  return UnarySCEVExpr_match<Op0_t, SCEVSignExtendExpr>(Op0);
 }
 
 template <typename Op0_t>
-inline UnarySCEV_match<Op0_t, SCEVZeroExtendExpr>
+inline UnarySCEVExpr_match<Op0_t, SCEVZeroExtendExpr>
 m_scev_ZExt(const Op0_t &Op0) {
-  return UnarySCEV_match<Op0_t, SCEVZeroExtendExpr>(Op0);
+  return UnarySCEVExpr_match<Op0_t, SCEVZeroExtendExpr>(Op0);
 }
 
 } // namespace SCEVPatternMatch

>From a71a973aed904caf3b233afe3d50098df1735d33 Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Tue, 17 Dec 2024 10:42:06 +0000
Subject: [PATCH 4/4] !fixup replace generic matcher with dedicated matcher

---
 .../Analysis/ScalarEvolutionPatternMatch.h    | 95 ++++++++-----------
 1 file changed, 38 insertions(+), 57 deletions(-)

diff --git a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
index 96101205163da4..900f6d0fd05ab6 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
@@ -91,79 +91,60 @@ struct specificscev_ty {
 /// Match if we have a specific specified SCEV.
 inline specificscev_ty m_Specific(const SCEV *S) { return S; }
 
-namespace detail {
+/// Match a SCEVTy with exactly one operand, matched by Op0.
+template <typename SCEVTy, typename Op0_t> struct SCEVUnaryExpr_match {
+  Op0_t Op0;
 
-template <typename TupleTy, typename Fn, std::size_t... Is>
-bool CheckTupleElements(const TupleTy &Ops, Fn P, std::index_sequence<Is...>) {
-  return (P(std::get<Is>(Ops), Is) && ...);
+  SCEVUnaryExpr_match(Op0_t Op0) : Op0(Op0) {}
+
+  bool match(const SCEV *S) {
+    auto *E = dyn_cast<SCEVTy>(S);
+    return E && E->getNumOperands() == 1 && Op0.match(E->getOperand(0));
+  }
+};
+
+template <typename SCEVTy, typename Op0_t>
+inline SCEVUnaryExpr_match<SCEVTy, Op0_t> m_scev_Unary(const Op0_t &Op0) {
+  return SCEVUnaryExpr_match<SCEVTy, Op0_t>(Op0);
 }
 
-/// Helper to check if predicate \p P holds on all tuple elements in \p Ops
-template <typename TupleTy, typename Fn>
-bool all_of_tuple_elements(const TupleTy &Ops, Fn P) {
-  return CheckTupleElements(
-      Ops, P, std::make_index_sequence<std::tuple_size<TupleTy>::value>{});
+template <typename Op0_t>
+inline SCEVUnaryExpr_match<SCEVSignExtendExpr, Op0_t>
+m_scev_SExt(const Op0_t &Op0) {
+  return m_scev_Unary<SCEVSignExtendExpr>(Op0);
 }
 
-} // namespace detail
+template <typename Op0_t>
+inline SCEVUnaryExpr_match<SCEVZeroExtendExpr, Op0_t>
+m_scev_ZExt(const Op0_t &Op0) {
+  return m_scev_Unary<SCEVZeroExtendExpr>(Op0);
+}
 
-template <typename Ops_t, typename SCEVTy> struct SCEVExpr_match {
-  Ops_t Ops;
+/// Match a SCEVTy with exactly two operands, matched in order by Op0, Op1.
+template <typename SCEVTy, typename Op0_t, typename Op1_t>
+struct SCEVBinaryExpr_match {
+  Op0_t Op0;
+  Op1_t Op1;
 
-  SCEVExpr_match() : Ops() {
-    static_assert(std::tuple_size<Ops_t>::value == 0 &&
-                  "constructor can only be used with zero operands");
-  }
-  SCEVExpr_match(Ops_t Ops) : Ops(Ops) {}
-  template <typename A_t, typename B_t>
-  SCEVExpr_match(A_t A, B_t B) : Ops({A, B}) {
-    static_assert(std::tuple_size<Ops_t>::value == 2 &&
-                  "constructor can only be used for binary matcher");
-  }
+  SCEVBinaryExpr_match(Op0_t Op0, Op1_t Op1) : Op0(Op0), Op1(Op1) {}
 
-  bool match(const SCEV *S) const {
-    auto *Cast = dyn_cast<SCEVTy>(S);
-    if (!Cast || Cast->getNumOperands() != std::tuple_size<Ops_t>::value)
-      return false;
-    return detail::all_of_tuple_elements(Ops, [Cast](auto Op, unsigned Idx) {
-      return Op.match(Cast->getOperand(Idx));
-    });
+  bool match(const SCEV *S) {
+    auto *E = dyn_cast<SCEVTy>(S);
+    return E && E->getNumOperands() == 2 && Op0.match(E->getOperand(0)) &&
+           Op1.match(E->getOperand(1));
   }
 };
 
-template <typename Op0_t, typename Op1_t, typename SCEVTy>
-using BinarySCEVExpr_match = SCEVExpr_match<std::tuple<Op0_t, Op1_t>, SCEVTy>;
-
-template <typename Op0_t, typename Op1_t, typename SCEVTy>
-inline BinarySCEVExpr_match<Op0_t, Op1_t, SCEVTy>
+template <typename SCEVTy, typename Op0_t, typename Op1_t>
+inline SCEVBinaryExpr_match<SCEVTy, Op0_t, Op1_t>
 m_scev_Binary(const Op0_t &Op0, const Op1_t &Op1) {
-  return BinarySCEVExpr_match<Op0_t, Op1_t, SCEVTy>(Op0, Op1);
+  return SCEVBinaryExpr_match<SCEVTy, Op0_t, Op1_t>(Op0, Op1);
 }
 
 template <typename Op0_t, typename Op1_t>
-inline BinarySCEVExpr_match<Op0_t, Op1_t, SCEVAddExpr>
+inline SCEVBinaryExpr_match<SCEVAddExpr, Op0_t, Op1_t>
 m_scev_Add(const Op0_t &Op0, const Op1_t &Op1) {
-  return BinarySCEVExpr_match<Op0_t, Op1_t, SCEVAddExpr>(Op0, Op1);
-}
-
-template <typename Op0_t, typename SCEVTy>
-using UnarySCEVExpr_match = SCEVExpr_match<std::tuple<Op0_t>, SCEVTy>;
-
-template <typename Op0_t, typename Op1_t, typename SCEVTy>
-inline UnarySCEVExpr_match<Op0_t, SCEVTy> m_scev_Unary(const Op0_t &Op0) {
-  return UnarySCEVExpr_match<Op0_t, SCEVTy>(Op0);
-}
-
-template <typename Op0_t>
-inline UnarySCEVExpr_match<Op0_t, SCEVSignExtendExpr>
-m_scev_SExt(const Op0_t &Op0) {
-  return UnarySCEVExpr_match<Op0_t, SCEVSignExtendExpr>(Op0);
-}
-
-template <typename Op0_t>
-inline UnarySCEVExpr_match<Op0_t, SCEVZeroExtendExpr>
-m_scev_ZExt(const Op0_t &Op0) {
-  return UnarySCEVExpr_match<Op0_t, SCEVZeroExtendExpr>(Op0);
+  return m_scev_Binary<SCEVAddExpr>(Op0, Op1);
 }
 
 } // namespace SCEVPatternMatch



More information about the llvm-commits mailing list