[llvm] [SCEVPatternMatch] Extend with more matchers (PR #138836)

Ramkumar Ramachandra via llvm-commits llvm-commits at lists.llvm.org
Thu May 8 13:32:21 PDT 2025


https://github.com/artagnon updated https://github.com/llvm/llvm-project/pull/138836

From eecf6d1f52d027ceeec71ed7d5d986716c1b83d5 Mon Sep 17 00:00:00 2001
From: Ramkumar Ramachandra <ramkumar.ramachandra at codasip.com>
Date: Wed, 7 May 2025 10:16:25 +0100
Subject: [PATCH 1/4] [SCEVPatternMatch] Extend with more matchers

---
 .../Analysis/ScalarEvolutionPatternMatch.h    | 48 +++++++++++++++++++
 llvm/lib/Analysis/LoopAccessAnalysis.cpp      | 39 +++++----------
 llvm/lib/Analysis/ScalarEvolution.cpp         | 21 ++++----
 3 files changed, 69 insertions(+), 39 deletions(-)

diff --git a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
index 674147ca175ef..f5fa9c9b72c77 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
@@ -58,6 +58,8 @@ template <typename Class> struct class_match {
   template <typename ITy> bool match(ITy *V) const { return isa<Class>(V); }
 };
 
+inline class_match<const SCEV> m_SCEV() { return class_match<const SCEV>(); }
+
 template <typename Class> struct bind_ty {
   Class *&VR;
 
@@ -93,6 +95,41 @@ struct specificscev_ty {
 /// Match if we have a specific specified SCEV.
 inline specificscev_ty m_Specific(const SCEV *S) { return S; }
 
+template <typename Class> struct cst_match {
+  Class CV;
+
+  cst_match(Class Op0) : CV(Op0) {}
+
+  bool match(const SCEV *S) const {
+    assert((isa<SCEVCouldNotCompute>(S) || !S->getType()->isVectorTy()) &&
+           "no vector types expected from SCEVs");
+    auto *C = dyn_cast<SCEVConstant>(S);
+    return C && C->getAPInt() == CV;
+  }
+};
+
+/// Match an SCEV constant with a plain unsigned integer.
+inline cst_match<uint64_t> m_SCEVConstant(uint64_t V) { return V; }
+
+struct bind_cst_ty {
+  const APInt *&CR;
+
+  bind_cst_ty(const APInt *&Op0) : CR(Op0) {}
+
+  bool match(const SCEV *S) const {
+    assert((isa<SCEVCouldNotCompute>(S) || !S->getType()->isVectorTy()) &&
+           "no vector types expected from SCEVs");
+    auto *C = dyn_cast<SCEVConstant>(S);
+    if (!C)
+      return false;
+    CR = &C->getAPInt();
+    return true;
+  }
+};
+
+/// Match an SCEV constant and bind it to an APInt.
+inline bind_cst_ty m_SCEVConstant(const APInt *&C) { return C; }
+
 /// Match a unary SCEV.
 template <typename SCEVTy, typename Op0_t> struct SCEVUnaryExpr_match {
   Op0_t Op0;
@@ -149,6 +186,17 @@ m_scev_Add(const Op0_t &Op0, const Op1_t &Op1) {
   return m_scev_Binary<SCEVAddExpr>(Op0, Op1);
 }
 
+template <typename Op0_t, typename Op1_t>
+inline SCEVBinaryExpr_match<SCEVMulExpr, Op0_t, Op1_t>
+m_scev_Mul(const Op0_t &Op0, const Op1_t &Op1) {
+  return m_scev_Binary<SCEVMulExpr>(Op0, Op1);
+}
+
+template <typename Op0_t, typename Op1_t>
+inline SCEVBinaryExpr_match<SCEVUDivExpr, Op0_t, Op1_t>
+m_scev_UDiv(const Op0_t &Op0, const Op1_t &Op1) {
+  return m_scev_Binary<SCEVUDivExpr>(Op0, Op1);
+}
 } // namespace SCEVPatternMatch
 } // namespace llvm
 
diff --git a/llvm/lib/Analysis/LoopAccessAnalysis.cpp b/llvm/lib/Analysis/LoopAccessAnalysis.cpp
index f222a9905c3bb..dc3ea7833c129 100644
--- a/llvm/lib/Analysis/LoopAccessAnalysis.cpp
+++ b/llvm/lib/Analysis/LoopAccessAnalysis.cpp
@@ -30,6 +30,7 @@
 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
 #include "llvm/Analysis/ScalarEvolution.h"
 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
+#include "llvm/Analysis/ScalarEvolutionPatternMatch.h"
 #include "llvm/Analysis/TargetLibraryInfo.h"
 #include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/Analysis/ValueTracking.h"
@@ -65,6 +66,7 @@
 #include <vector>
 
 using namespace llvm;
+using namespace llvm::SCEVPatternMatch;
 
 #define DEBUG_TYPE "loop-accesses"
 
@@ -811,8 +813,8 @@ getStrideFromAddRec(const SCEVAddRecExpr *AR, const Loop *Lp, Type *AccessTy,
   const SCEV *Step = AR->getStepRecurrence(*PSE.getSE());
 
   // Calculate the pointer stride and check if it is constant.
-  const SCEVConstant *C = dyn_cast<SCEVConstant>(Step);
-  if (!C) {
+  const APInt *APStepVal;
+  if (!match(Step, m_SCEVConstant(APStepVal))) {
     LLVM_DEBUG({
       dbgs() << "LAA: Bad stride - Not a constant strided ";
       if (Ptr)
@@ -825,13 +827,12 @@ getStrideFromAddRec(const SCEVAddRecExpr *AR, const Loop *Lp, Type *AccessTy,
   const auto &DL = Lp->getHeader()->getDataLayout();
   TypeSize AllocSize = DL.getTypeAllocSize(AccessTy);
   int64_t Size = AllocSize.getFixedValue();
-  const APInt &APStepVal = C->getAPInt();
 
   // Huge step value - give up.
-  if (APStepVal.getBitWidth() > 64)
+  if (APStepVal->getBitWidth() > 64)
     return std::nullopt;
 
-  int64_t StepVal = APStepVal.getSExtValue();
+  int64_t StepVal = APStepVal->getSExtValue();
 
   // Strided access.
   int64_t Stride = StepVal / Size;
@@ -2061,11 +2062,10 @@ MemoryDepChecker::isDependent(const MemAccessInfo &A, unsigned AIdx,
           DL, SE, *(PSE.getSymbolicMaxBackedgeTakenCount()), *Dist, MaxStride))
     return Dependence::NoDep;
 
-  const SCEVConstant *ConstDist = dyn_cast<SCEVConstant>(Dist);
-
   // Attempt to prove strided accesses independent.
-  if (ConstDist) {
-    uint64_t Distance = ConstDist->getAPInt().abs().getZExtValue();
+  const APInt *ConstDist = nullptr;
+  if (match(Dist, m_SCEVConstant(ConstDist))) {
+    uint64_t Distance = ConstDist->abs().getZExtValue();
 
     // If the distance between accesses and their strides are known constants,
     // check whether the accesses interlace each other.
@@ -2111,9 +2111,8 @@ MemoryDepChecker::isDependent(const MemAccessInfo &A, unsigned AIdx,
         FoundNonConstantDistanceDependence |= ShouldRetryWithRuntimeCheck;
         return Dependence::Unknown;
       }
-      if (!HasSameSize ||
-          couldPreventStoreLoadForward(
-              ConstDist->getAPInt().abs().getZExtValue(), TypeByteSize)) {
+      if (!HasSameSize || couldPreventStoreLoadForward(
+                              ConstDist->abs().getZExtValue(), TypeByteSize)) {
         LLVM_DEBUG(
             dbgs() << "LAA: Forward but may prevent st->ld forwarding\n");
         return Dependence::ForwardButPreventsForwarding;
@@ -2864,20 +2863,8 @@ static const SCEV *getStrideFromPointer(Value *Ptr, ScalarEvolution *SE, Loop *L
 
   // Strip off the size of access multiplication if we are still analyzing the
   // pointer.
-  if (OrigPtr == Ptr) {
-    if (auto *M = dyn_cast<SCEVMulExpr>(V)) {
-      auto *StepConst = dyn_cast<SCEVConstant>(M->getOperand(0));
-      if (!StepConst)
-        return nullptr;
-
-      auto StepVal = StepConst->getAPInt().trySExtValue();
-      // Bail out on a non-unit pointer access size.
-      if (!StepVal || StepVal != 1)
-        return nullptr;
-
-      V = M->getOperand(1);
-    }
-  }
+  if (OrigPtr == Ptr)
+    match(V, m_scev_Mul(m_SCEVConstant(1), m_SCEV(V)));
 
   // Note that the restriction after this loop invariant check are only
   // profitability restrictions.
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index ac69ad598a65a..f0c35d233c42a 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -7149,16 +7149,11 @@ ConstantRange ScalarEvolution::getRangeViaFactoring(const SCEV *Start,
       assert(SE.getTypeSizeInBits(S->getType()) == BitWidth &&
              "Should be!");
 
-      // Peel off a constant offset:
-      if (auto *SA = dyn_cast<SCEVAddExpr>(S)) {
-        // In the future we could consider being smarter here and handle
-        // {Start+Step,+,Step} too.
-        if (SA->getNumOperands() != 2 || !isa<SCEVConstant>(SA->getOperand(0)))
-          return;
-
-        Offset = cast<SCEVConstant>(SA->getOperand(0))->getAPInt();
-        S = SA->getOperand(1);
-      }
+      // Peel off a constant offset. In the future we could consider being
+      // smarter here and handle {Start+Step,+,Step} too.
+      const APInt *Off;
+      if (match(S, m_scev_Add(m_SCEVConstant(Off), m_SCEV(S))))
+        Offset = *Off;
 
       // Peel off a cast operation
       if (auto *SCast = dyn_cast<SCEVIntegralCastExpr>(S)) {
@@ -7337,11 +7332,11 @@ bool ScalarEvolution::isGuaranteedNotToBePoison(const SCEV *Op) {
 
 bool ScalarEvolution::isGuaranteedNotToCauseUB(const SCEV *Op) {
   return !SCEVExprContains(Op, [this](const SCEV *S) {
-    auto *UDiv = dyn_cast<SCEVUDivExpr>(S);
+    const SCEV *Op1;
+    bool M = match(S, m_scev_UDiv(m_SCEV(), m_SCEV(Op1)));
     // The UDiv may be UB if the divisor is poison or zero. Unless the divisor
     // is a non-zero constant, we have to assume the UDiv may be UB.
-    return UDiv && (!isKnownNonZero(UDiv->getOperand(1)) ||
-                    !isGuaranteedNotToBePoison(UDiv->getOperand(1)));
+    return M && (!isKnownNonZero(Op1) || !isGuaranteedNotToBePoison(Op1));
   });
 }
 

From 5d3656055602f042bca5bc0a44038e611e3f0478 Mon Sep 17 00:00:00 2001
From: Ramkumar Ramachandra <ramkumar.ramachandra at codasip.com>
Date: Thu, 8 May 2025 16:42:57 +0100
Subject: [PATCH 2/4] [SCEVPatternMatch] Rename m_SCEVConstant overloads to m_scev_SpecificInt and m_scev_APInt

---
 llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h | 4 ++--
 llvm/lib/Analysis/LoopAccessAnalysis.cpp                 | 6 +++---
 llvm/lib/Analysis/ScalarEvolution.cpp                    | 2 +-
 3 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
index f5fa9c9b72c77..536d74f296931 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
@@ -109,7 +109,7 @@ template <typename Class> struct cst_match {
 };
 
 /// Match an SCEV constant with a plain unsigned integer.
-inline cst_match<uint64_t> m_SCEVConstant(uint64_t V) { return V; }
+inline cst_match<uint64_t> m_scev_SpecificInt(uint64_t V) { return V; }
 
 struct bind_cst_ty {
   const APInt *&CR;
@@ -128,7 +128,7 @@ struct bind_cst_ty {
 };
 
 /// Match an SCEV constant and bind it to an APInt.
-inline bind_cst_ty m_SCEVConstant(const APInt *&C) { return C; }
+inline bind_cst_ty m_scev_APInt(const APInt *&C) { return C; }
 
 /// Match a unary SCEV.
 template <typename SCEVTy, typename Op0_t> struct SCEVUnaryExpr_match {
diff --git a/llvm/lib/Analysis/LoopAccessAnalysis.cpp b/llvm/lib/Analysis/LoopAccessAnalysis.cpp
index dc3ea7833c129..d98d0d936d788 100644
--- a/llvm/lib/Analysis/LoopAccessAnalysis.cpp
+++ b/llvm/lib/Analysis/LoopAccessAnalysis.cpp
@@ -814,7 +814,7 @@ getStrideFromAddRec(const SCEVAddRecExpr *AR, const Loop *Lp, Type *AccessTy,
 
   // Calculate the pointer stride and check if it is constant.
   const APInt *APStepVal;
-  if (!match(Step, m_SCEVConstant(APStepVal))) {
+  if (!match(Step, m_scev_APInt(APStepVal))) {
     LLVM_DEBUG({
       dbgs() << "LAA: Bad stride - Not a constant strided ";
       if (Ptr)
@@ -2064,7 +2064,7 @@ MemoryDepChecker::isDependent(const MemAccessInfo &A, unsigned AIdx,
 
   // Attempt to prove strided accesses independent.
   const APInt *ConstDist = nullptr;
-  if (match(Dist, m_SCEVConstant(ConstDist))) {
+  if (match(Dist, m_scev_APInt(ConstDist))) {
     uint64_t Distance = ConstDist->abs().getZExtValue();
 
     // If the distance between accesses and their strides are known constants,
@@ -2864,7 +2864,7 @@ static const SCEV *getStrideFromPointer(Value *Ptr, ScalarEvolution *SE, Loop *L
   // Strip off the size of access multiplication if we are still analyzing the
   // pointer.
   if (OrigPtr == Ptr)
-    match(V, m_scev_Mul(m_SCEVConstant(1), m_SCEV(V)));
+    match(V, m_scev_Mul(m_scev_SpecificInt(1), m_SCEV(V)));
 
   // Note that the restriction after this loop invariant check are only
   // profitability restrictions.
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index f0c35d233c42a..3f9614254ae7a 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -7152,7 +7152,7 @@ ConstantRange ScalarEvolution::getRangeViaFactoring(const SCEV *Start,
       // Peel off a constant offset. In the future we could consider being
       // smarter here and handle {Start+Step,+,Step} too.
       const APInt *Off;
-      if (match(S, m_scev_Add(m_SCEVConstant(Off), m_SCEV(S))))
+      if (match(S, m_scev_Add(m_scev_APInt(Off), m_SCEV(S))))
         Offset = *Off;
 
       // Peel off a cast operation

From 780e98ebab8a430e8ceeca568848b208e3572c07 Mon Sep 17 00:00:00 2001
From: Ramkumar Ramachandra <ramkumar.ramachandra at codasip.com>
Date: Thu, 8 May 2025 20:55:14 +0100
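Subject: [PATCH 3/4] [LAA] Restore explicit multiplier check in getStrideFromPointer

Unlike the original dyn_cast-based code, the matcher-based form introduced
earlier in this series did not return nullptr when the SCEVMulExpr's first
operand was a non-constant or a constant other than 1. Restore the original
logic, which bails out in those cases.
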
---
 llvm/lib/Analysis/LoopAccessAnalysis.cpp | 16 ++++++++++++++--
 1 file changed, 14 insertions(+), 2 deletions(-)

diff --git a/llvm/lib/Analysis/LoopAccessAnalysis.cpp b/llvm/lib/Analysis/LoopAccessAnalysis.cpp
index d98d0d936d788..535b71cd5897e 100644
--- a/llvm/lib/Analysis/LoopAccessAnalysis.cpp
+++ b/llvm/lib/Analysis/LoopAccessAnalysis.cpp
@@ -2863,8 +2863,20 @@ static const SCEV *getStrideFromPointer(Value *Ptr, ScalarEvolution *SE, Loop *L
 
   // Strip off the size of access multiplication if we are still analyzing the
   // pointer.
-  if (OrigPtr == Ptr)
-    match(V, m_scev_Mul(m_scev_SpecificInt(1), m_SCEV(V)));
+  if (OrigPtr == Ptr) {
+    if (auto *M = dyn_cast<SCEVMulExpr>(V)) {
+      auto *StepConst = dyn_cast<SCEVConstant>(M->getOperand(0));
+      if (!StepConst)
+        return nullptr;
+
+      auto StepVal = StepConst->getAPInt().trySExtValue();
+      // Bail out on a non-unit pointer access size.
+      if (!StepVal || StepVal != 1)
+        return nullptr;
+
+      V = M->getOperand(1);
+    }
+  }
 
   // Note that the restriction after this loop invariant check are only
   // profitability restrictions.

From 1a5b7b72579c4025b456a2385ce42f3cc4c447b5 Mon Sep 17 00:00:00 2001
From: Ramkumar Ramachandra <ramkumar.ramachandra at codasip.com>
Date: Thu, 8 May 2025 21:30:35 +0100
Subject: [PATCH 4/4] [SCEVPatternMatch] Introduce is_specific_cst

---
 .../Analysis/ScalarEvolutionPatternMatch.h    | 19 +++++++------------
 1 file changed, 7 insertions(+), 12 deletions(-)

diff --git a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
index 536d74f296931..5e53a4f502153 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
@@ -23,6 +23,8 @@ template <typename Pattern> bool match(const SCEV *S, const Pattern &P) {
 }
 
 template <typename Predicate> struct cst_pred_ty : public Predicate {
+  cst_pred_ty() = default;
+  cst_pred_ty(uint64_t V) : Predicate(V) {}
   bool match(const SCEV *S) const {
     assert((isa<SCEVCouldNotCompute>(S) || !S->getType()->isVectorTy()) &&
            "no vector types expected from SCEVs");
@@ -95,21 +97,14 @@ struct specificscev_ty {
 /// Match if we have a specific specified SCEV.
 inline specificscev_ty m_Specific(const SCEV *S) { return S; }
 
-template <typename Class> struct cst_match {
-  Class CV;
-
-  cst_match(Class Op0) : CV(Op0) {}
-
-  bool match(const SCEV *S) const {
-    assert((isa<SCEVCouldNotCompute>(S) || !S->getType()->isVectorTy()) &&
-           "no vector types expected from SCEVs");
-    auto *C = dyn_cast<SCEVConstant>(S);
-    return C && C->getAPInt() == CV;
-  }
+struct is_specific_cst {
+  uint64_t CV;
+  is_specific_cst(uint64_t C) : CV(C) {}
+  bool isValue(const APInt &C) const { return C == CV; }
 };
 
 /// Match an SCEV constant with a plain unsigned integer.
-inline cst_match<uint64_t> m_scev_SpecificInt(uint64_t V) { return V; }
+inline cst_pred_ty<is_specific_cst> m_scev_SpecificInt(uint64_t V) { return V; }
 
 struct bind_cst_ty {
   const APInt *&CR;



More information about the llvm-commits mailing list