[llvm] 8ea9576 - [SCEV] Add initial matchers for SCEV expressions. (NFC) (#119390)

via llvm-commits llvm-commits at lists.llvm.org
Tue Dec 17 04:13:00 PST 2024


Author: Florian Hahn
Date: 2024-12-17T12:12:56Z
New Revision: 8ea9576d94ec6b15a2a3ba181af15d136283bde4

URL: https://github.com/llvm/llvm-project/commit/8ea9576d94ec6b15a2a3ba181af15d136283bde4
DIFF: https://github.com/llvm/llvm-project/commit/8ea9576d94ec6b15a2a3ba181af15d136283bde4.diff

LOG: [SCEV] Add initial matchers for SCEV expressions. (NFC) (#119390)

This patch adds initial matchers for unary and binary SCEV expressions
and specializes them for SExt, ZExt and binary add expressions.

Also adds matchers for SCEVConstant and SCEVUnknown.

This patch only converts a few instances to use the new matchers to make
sure everything builds as expected for now.

The goal of the matchers is to hopefully make it slightly easier to
write code matching SCEV patterns.

Depends on https://github.com/llvm/llvm-project/pull/119389

PR: https://github.com/llvm/llvm-project/pull/119390

Added: 
    

Modified: 
    llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
    llvm/lib/Analysis/ScalarEvolution.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
index 21d2ef3c867d7d..900f6d0fd05ab6 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
@@ -52,6 +52,116 @@ inline cst_pred_ty<is_all_ones> m_scev_AllOnes() {
   return cst_pred_ty<is_all_ones>();
 }
 
+/// Match any SCEV that isa<Class>, without capturing it.
+template <typename Class> struct class_match {
+  template <typename ITy> bool match(ITy *V) const { return isa<Class>(V); }
+};
+
+/// Match any SCEV that dyn_casts to Class, storing the cast pointer in the
+/// caller-provided reference on success. The reference is not written when
+/// the match fails.
+template <typename Class> struct bind_ty {
+  Class *&VR; // Caller's variable that receives the matched SCEV.
+
+  bind_ty(Class *&V) : VR(V) {}
+
+  template <typename ITy> bool match(ITy *V) const {
+    if (auto *CV = dyn_cast<Class>(V)) {
+      VR = CV;
+      return true;
+    }
+    return false;
+  }
+};
+
+/// Match any SCEV, capturing it in \p V.
+inline bind_ty<const SCEV> m_SCEV(const SCEV *&V) { return V; }
+/// Match a SCEVConstant, capturing it in \p V.
+inline bind_ty<const SCEVConstant> m_SCEVConstant(const SCEVConstant *&V) {
+  return V;
+}
+/// Match a SCEVUnknown, capturing it in \p V.
+inline bind_ty<const SCEVUnknown> m_SCEVUnknown(const SCEVUnknown *&V) {
+  return V;
+}
+
+/// Match exactly the SCEV \p Expr, compared by pointer identity.
+struct specificscev_ty {
+  const SCEV *Expr;
+
+  specificscev_ty(const SCEV *Expr) : Expr(Expr) {}
+
+  template <typename ITy> bool match(ITy *S) const { return S == Expr; }
+};
+
+/// Match the specific SCEV \p S (pointer comparison).
+inline specificscev_ty m_Specific(const SCEV *S) { return S; }
+
+/// Match a SCEV of type \p SCEVTy that has exactly one operand, where that
+/// operand matches \p Op0.
+template <typename SCEVTy, typename Op0_t> struct SCEVUnaryExpr_match {
+  Op0_t Op0;
+
+  SCEVUnaryExpr_match(Op0_t Op0) : Op0(Op0) {}
+
+  bool match(const SCEV *S) {
+    auto *E = dyn_cast<SCEVTy>(S);
+    return E && E->getNumOperands() == 1 && Op0.match(E->getOperand(0));
+  }
+};
+
+/// Match a one-operand SCEV of type \p SCEVTy whose operand matches \p Op0.
+template <typename SCEVTy, typename Op0_t>
+inline SCEVUnaryExpr_match<SCEVTy, Op0_t> m_scev_Unary(const Op0_t &Op0) {
+  return SCEVUnaryExpr_match<SCEVTy, Op0_t>(Op0);
+}
+
+/// Match a sign extension (SCEVSignExtendExpr) whose operand matches \p Op0.
+template <typename Op0_t>
+inline SCEVUnaryExpr_match<SCEVSignExtendExpr, Op0_t>
+m_scev_SExt(const Op0_t &Op0) {
+  return m_scev_Unary<SCEVSignExtendExpr>(Op0);
+}
+
+/// Match a zero extension (SCEVZeroExtendExpr) whose operand matches \p Op0.
+template <typename Op0_t>
+inline SCEVUnaryExpr_match<SCEVZeroExtendExpr, Op0_t>
+m_scev_ZExt(const Op0_t &Op0) {
+  return m_scev_Unary<SCEVZeroExtendExpr>(Op0);
+}
+
+/// Match a SCEV of type \p SCEVTy that has exactly two operands, where
+/// operand 0 matches \p Op0 and operand 1 matches \p Op1. Operands are tried
+/// only in their stored order; the commuted form is not attempted.
+template <typename SCEVTy, typename Op0_t, typename Op1_t>
+struct SCEVBinaryExpr_match {
+  Op0_t Op0;
+  Op1_t Op1;
+
+  SCEVBinaryExpr_match(Op0_t Op0, Op1_t Op1) : Op0(Op0), Op1(Op1) {}
+
+  bool match(const SCEV *S) {
+    auto *E = dyn_cast<SCEVTy>(S);
+    return E && E->getNumOperands() == 2 && Op0.match(E->getOperand(0)) &&
+           Op1.match(E->getOperand(1));
+  }
+};
+
+/// Match a two-operand SCEV of type \p SCEVTy with operands (\p Op0, \p Op1).
+template <typename SCEVTy, typename Op0_t, typename Op1_t>
+inline SCEVBinaryExpr_match<SCEVTy, Op0_t, Op1_t>
+m_scev_Binary(const Op0_t &Op0, const Op1_t &Op1) {
+  return SCEVBinaryExpr_match<SCEVTy, Op0_t, Op1_t>(Op0, Op1);
+}
+
+/// Match an add (SCEVAddExpr) of exactly two operands, Op0 + Op1. Adds with
+/// any other number of operands do not match.
+template <typename Op0_t, typename Op1_t>
+inline SCEVBinaryExpr_match<SCEVAddExpr, Op0_t, Op1_t>
+m_scev_Add(const Op0_t &Op0, const Op1_t &Op1) {
+  return m_scev_Binary<SCEVAddExpr>(Op0, Op1);
+}
+
 } // namespace SCEVPatternMatch
 } // namespace llvm
 

diff  --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index e2c2500052e7d6..c820e8bf7266ad 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -12725,33 +12725,28 @@ bool ScalarEvolution::isImpliedViaOperations(ICmpInst::Predicate Pred,
 static bool isKnownPredicateExtendIdiom(ICmpInst::Predicate Pred,
                                         const SCEV *LHS, const SCEV *RHS) {
   // zext x u<= sext x, sext x s<= zext x
+  const SCEV *Op; // Operand shared by both extends.
   switch (Pred) {
   case ICmpInst::ICMP_SGE:
     std::swap(LHS, RHS);
     [[fallthrough]];
   case ICmpInst::ICMP_SLE: {
-    // If operand >=s 0 then ZExt == SExt.  If operand <s 0 then SExt <s ZExt.
-    const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(LHS);
-    const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(RHS);
-    if (SExt && ZExt && SExt->getOperand() == ZExt->getOperand())
-      return true;
-    break;
+    // If operand >=s 0 then ZExt == SExt. If operand <s 0 then SExt <s ZExt.
+    return match(LHS, m_scev_SExt(m_SCEV(Op))) &&
+           match(RHS, m_scev_ZExt(m_Specific(Op)));
   }
   case ICmpInst::ICMP_UGE:
     std::swap(LHS, RHS);
     [[fallthrough]];
   case ICmpInst::ICMP_ULE: {
-    // If operand >=s 0 then ZExt == SExt.  If operand <s 0 then ZExt <u SExt.
-    const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS);
-    const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(RHS);
-    if (SExt && ZExt && SExt->getOperand() == ZExt->getOperand())
-      return true;
-    break;
+    // If operand >=s 0 then ZExt == SExt. If operand <s 0 then ZExt <u SExt.
+    return match(LHS, m_scev_ZExt(m_SCEV(Op))) &&
+           match(RHS, m_scev_SExt(m_Specific(Op)));
   }
   default:
-    break;
+    return false;
   };
-  return false;
+  llvm_unreachable("unhandled case");
 }
 
 bool
@@ -15417,14 +15412,12 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
     // (X >=u C1).
     auto MatchRangeCheckIdiom = [&SE, Predicate, LHS, RHS, &RewriteMap,
                                  &ExprsToRewrite]() {
-      auto *AddExpr = dyn_cast<SCEVAddExpr>(LHS);
-      if (!AddExpr || AddExpr->getNumOperands() != 2)
-        return false;
-
-      auto *C1 = dyn_cast<SCEVConstant>(AddExpr->getOperand(0));
-      auto *LHSUnknown = dyn_cast<SCEVUnknown>(AddExpr->getOperand(1));
+      const SCEVConstant *C1;
+      const SCEVUnknown *LHSUnknown;
       auto *C2 = dyn_cast<SCEVConstant>(RHS);
-      if (!C1 || !C2 || !LHSUnknown)
+      if (!match(LHS,
+                 m_scev_Add(m_SCEVConstant(C1), m_SCEVUnknown(LHSUnknown))) ||
+          !C2)
         return false;
 
       auto ExactRegion =


        


More information about the llvm-commits mailing list