[llvm] 4589911 - [SCEVPatternMatch] Extend with more matchers (#138836)

via llvm-commits llvm-commits at lists.llvm.org
Fri May 9 01:20:18 PDT 2025


Author: Ramkumar Ramachandra
Date: 2025-05-09T09:20:14+01:00
New Revision: 458991197d252e28ce4720a0770ef7d183435eeb

URL: https://github.com/llvm/llvm-project/commit/458991197d252e28ce4720a0770ef7d183435eeb
DIFF: https://github.com/llvm/llvm-project/commit/458991197d252e28ce4720a0770ef7d183435eeb.diff

LOG: [SCEVPatternMatch] Extend with more matchers (#138836)

Added: 
    

Modified: 
    llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
    llvm/lib/Analysis/LoopAccessAnalysis.cpp
    llvm/lib/Analysis/ScalarEvolution.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
index 674147ca175ef..5e53a4f502153 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
@@ -23,6 +23,8 @@ template <typename Pattern> bool match(const SCEV *S, const Pattern &P) {
 }
 
 template <typename Predicate> struct cst_pred_ty : public Predicate {
+  cst_pred_ty() = default;
+  cst_pred_ty(uint64_t V) : Predicate(V) {}
   bool match(const SCEV *S) const {
     assert((isa<SCEVCouldNotCompute>(S) || !S->getType()->isVectorTy()) &&
            "no vector types expected from SCEVs");
@@ -58,6 +60,8 @@ template <typename Class> struct class_match {
   template <typename ITy> bool match(ITy *V) const { return isa<Class>(V); }
 };
 
+inline class_match<const SCEV> m_SCEV() { return class_match<const SCEV>(); }
+
 template <typename Class> struct bind_ty {
   Class *&VR;
 
@@ -93,6 +97,34 @@ struct specificscev_ty {
 /// Match if we have a specific specified SCEV.
 inline specificscev_ty m_Specific(const SCEV *S) { return S; }
 
+struct is_specific_cst {
+  uint64_t CV;
+  is_specific_cst(uint64_t C) : CV(C) {}
+  bool isValue(const APInt &C) const { return C == CV; }
+};
+
+/// Match an SCEV constant with a plain unsigned integer.
+inline cst_pred_ty<is_specific_cst> m_scev_SpecificInt(uint64_t V) { return V; }
+
+struct bind_cst_ty {
+  const APInt *&CR;
+
+  bind_cst_ty(const APInt *&Op0) : CR(Op0) {}
+
+  bool match(const SCEV *S) const {
+    assert((isa<SCEVCouldNotCompute>(S) || !S->getType()->isVectorTy()) &&
+           "no vector types expected from SCEVs");
+    auto *C = dyn_cast<SCEVConstant>(S);
+    if (!C)
+      return false;
+    CR = &C->getAPInt();
+    return true;
+  }
+};
+
+/// Match an SCEV constant and bind it to an APInt.
+inline bind_cst_ty m_scev_APInt(const APInt *&C) { return C; }
+
 /// Match a unary SCEV.
 template <typename SCEVTy, typename Op0_t> struct SCEVUnaryExpr_match {
   Op0_t Op0;
@@ -149,6 +181,17 @@ m_scev_Add(const Op0_t &Op0, const Op1_t &Op1) {
   return m_scev_Binary<SCEVAddExpr>(Op0, Op1);
 }
 
+template <typename Op0_t, typename Op1_t>
+inline SCEVBinaryExpr_match<SCEVMulExpr, Op0_t, Op1_t>
+m_scev_Mul(const Op0_t &Op0, const Op1_t &Op1) {
+  return m_scev_Binary<SCEVMulExpr>(Op0, Op1);
+}
+
+template <typename Op0_t, typename Op1_t>
+inline SCEVBinaryExpr_match<SCEVUDivExpr, Op0_t, Op1_t>
+m_scev_UDiv(const Op0_t &Op0, const Op1_t &Op1) {
+  return m_scev_Binary<SCEVUDivExpr>(Op0, Op1);
+}
 } // namespace SCEVPatternMatch
 } // namespace llvm
 

diff  --git a/llvm/lib/Analysis/LoopAccessAnalysis.cpp b/llvm/lib/Analysis/LoopAccessAnalysis.cpp
index f222a9905c3bb..535b71cd5897e 100644
--- a/llvm/lib/Analysis/LoopAccessAnalysis.cpp
+++ b/llvm/lib/Analysis/LoopAccessAnalysis.cpp
@@ -30,6 +30,7 @@
 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
 #include "llvm/Analysis/ScalarEvolution.h"
 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
+#include "llvm/Analysis/ScalarEvolutionPatternMatch.h"
 #include "llvm/Analysis/TargetLibraryInfo.h"
 #include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/Analysis/ValueTracking.h"
@@ -65,6 +66,7 @@
 #include <vector>
 
 using namespace llvm;
+using namespace llvm::SCEVPatternMatch;
 
 #define DEBUG_TYPE "loop-accesses"
 
@@ -811,8 +813,8 @@ getStrideFromAddRec(const SCEVAddRecExpr *AR, const Loop *Lp, Type *AccessTy,
   const SCEV *Step = AR->getStepRecurrence(*PSE.getSE());
 
   // Calculate the pointer stride and check if it is constant.
-  const SCEVConstant *C = dyn_cast<SCEVConstant>(Step);
-  if (!C) {
+  const APInt *APStepVal;
+  if (!match(Step, m_scev_APInt(APStepVal))) {
     LLVM_DEBUG({
       dbgs() << "LAA: Bad stride - Not a constant strided ";
       if (Ptr)
@@ -825,13 +827,12 @@ getStrideFromAddRec(const SCEVAddRecExpr *AR, const Loop *Lp, Type *AccessTy,
   const auto &DL = Lp->getHeader()->getDataLayout();
   TypeSize AllocSize = DL.getTypeAllocSize(AccessTy);
   int64_t Size = AllocSize.getFixedValue();
-  const APInt &APStepVal = C->getAPInt();
 
   // Huge step value - give up.
-  if (APStepVal.getBitWidth() > 64)
+  if (APStepVal->getBitWidth() > 64)
     return std::nullopt;
 
-  int64_t StepVal = APStepVal.getSExtValue();
+  int64_t StepVal = APStepVal->getSExtValue();
 
   // Strided access.
   int64_t Stride = StepVal / Size;
@@ -2061,11 +2062,10 @@ MemoryDepChecker::isDependent(const MemAccessInfo &A, unsigned AIdx,
           DL, SE, *(PSE.getSymbolicMaxBackedgeTakenCount()), *Dist, MaxStride))
     return Dependence::NoDep;
 
-  const SCEVConstant *ConstDist = dyn_cast<SCEVConstant>(Dist);
-
   // Attempt to prove strided accesses independent.
-  if (ConstDist) {
-    uint64_t Distance = ConstDist->getAPInt().abs().getZExtValue();
+  const APInt *ConstDist = nullptr;
+  if (match(Dist, m_scev_APInt(ConstDist))) {
+    uint64_t Distance = ConstDist->abs().getZExtValue();
 
     // If the distance between accesses and their strides are known constants,
     // check whether the accesses interlace each other.
@@ -2111,9 +2111,8 @@ MemoryDepChecker::isDependent(const MemAccessInfo &A, unsigned AIdx,
         FoundNonConstantDistanceDependence |= ShouldRetryWithRuntimeCheck;
         return Dependence::Unknown;
       }
-      if (!HasSameSize ||
-          couldPreventStoreLoadForward(
-              ConstDist->getAPInt().abs().getZExtValue(), TypeByteSize)) {
+      if (!HasSameSize || couldPreventStoreLoadForward(
+                              ConstDist->abs().getZExtValue(), TypeByteSize)) {
         LLVM_DEBUG(
             dbgs() << "LAA: Forward but may prevent st->ld forwarding\n");
         return Dependence::ForwardButPreventsForwarding;

diff  --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index ac69ad598a65a..3f9614254ae7a 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -7149,16 +7149,11 @@ ConstantRange ScalarEvolution::getRangeViaFactoring(const SCEV *Start,
       assert(SE.getTypeSizeInBits(S->getType()) == BitWidth &&
              "Should be!");
 
-      // Peel off a constant offset:
-      if (auto *SA = dyn_cast<SCEVAddExpr>(S)) {
-        // In the future we could consider being smarter here and handle
-        // {Start+Step,+,Step} too.
-        if (SA->getNumOperands() != 2 || !isa<SCEVConstant>(SA->getOperand(0)))
-          return;
-
-        Offset = cast<SCEVConstant>(SA->getOperand(0))->getAPInt();
-        S = SA->getOperand(1);
-      }
+      // Peel off a constant offset. In the future we could consider being
+      // smarter here and handle {Start+Step,+,Step} too.
+      const APInt *Off;
+      if (match(S, m_scev_Add(m_scev_APInt(Off), m_SCEV(S))))
+        Offset = *Off;
 
       // Peel off a cast operation
       if (auto *SCast = dyn_cast<SCEVIntegralCastExpr>(S)) {
@@ -7337,11 +7332,11 @@ bool ScalarEvolution::isGuaranteedNotToBePoison(const SCEV *Op) {
 
 bool ScalarEvolution::isGuaranteedNotToCauseUB(const SCEV *Op) {
   return !SCEVExprContains(Op, [this](const SCEV *S) {
-    auto *UDiv = dyn_cast<SCEVUDivExpr>(S);
+    const SCEV *Op1;
+    bool M = match(S, m_scev_UDiv(m_SCEV(), m_SCEV(Op1)));
     // The UDiv may be UB if the divisor is poison or zero. Unless the divisor
     // is a non-zero constant, we have to assume the UDiv may be UB.
-    return UDiv && (!isKnownNonZero(UDiv->getOperand(1)) ||
-                    !isGuaranteedNotToBePoison(UDiv->getOperand(1)));
+    return M && (!isKnownNonZero(Op1) || !isGuaranteedNotToBePoison(Op1));
   });
 }
 


        


More information about the llvm-commits mailing list