[llvm] [SCEVPatternMatch] Extend m_scev_AffineAddRec with Loop (PR #141132)

Ramkumar Ramachandra via llvm-commits llvm-commits at lists.llvm.org
Fri May 23 08:08:56 PDT 2025


https://github.com/artagnon updated https://github.com/llvm/llvm-project/pull/141132

>From f256b7607df15cbc7fd574b77d59677a773886a9 Mon Sep 17 00:00:00 2001
From: Ramkumar Ramachandra <ramkumar.ramachandra at codasip.com>
Date: Thu, 22 May 2025 19:37:58 +0100
Subject: [PATCH 1/4] [SCEVPatternMatch] Extend m_scev_AffineAddRec with Loop

---
 .../Analysis/ScalarEvolutionPatternMatch.h    | 13 ++++++--
 llvm/lib/Transforms/Scalar/IndVarSimplify.cpp |  3 +-
 .../Transforms/Scalar/LoopIdiomRecognize.cpp  | 32 ++++++++-----------
 3 files changed, 25 insertions(+), 23 deletions(-)

diff --git a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
index cfb1b4c6ea6b4..bce10442e3b5c 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
@@ -14,6 +14,7 @@
 #define LLVM_ANALYSIS_SCALAREVOLUTIONPATTERNMATCH_H
 
 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
+#include <type_traits>
 
 namespace llvm {
 namespace SCEVPatternMatch {
@@ -162,13 +163,18 @@ template <typename SCEVTy, typename Op0_t, typename Op1_t>
 struct SCEVBinaryExpr_match {
   Op0_t Op0;
   Op1_t Op1;
+  const Loop *L;
 
-  SCEVBinaryExpr_match(Op0_t Op0, Op1_t Op1) : Op0(Op0), Op1(Op1) {}
+  SCEVBinaryExpr_match(Op0_t Op0, Op1_t Op1, const Loop *L = nullptr)
+      : Op0(Op0), Op1(Op1), L(L) {}
 
   bool match(const SCEV *S) const {
     auto *E = dyn_cast<SCEVTy>(S);
+    bool LoopMatches = true;
+    if constexpr (std::is_same_v<SCEVTy, SCEVAddRecExpr>)
+      LoopMatches = !L || (E && E->getLoop() == L);
     return E && E->getNumOperands() == 2 && Op0.match(E->getOperand(0)) &&
-           Op1.match(E->getOperand(1));
+           Op1.match(E->getOperand(1)) && LoopMatches;
   }
 };
 
@@ -198,7 +204,8 @@ m_scev_UDiv(const Op0_t &Op0, const Op1_t &Op1) {
 
 template <typename Op0_t, typename Op1_t>
 inline SCEVBinaryExpr_match<SCEVAddRecExpr, Op0_t, Op1_t>
-m_scev_AffineAddRec(const Op0_t &Op0, const Op1_t &Op1) {
+m_scev_AffineAddRec(const Op0_t &Op0, const Op1_t &Op1,
+                    const Loop *L = nullptr) {
   return m_scev_Binary<SCEVAddRecExpr>(Op0, Op1);
 }
 } // namespace SCEVPatternMatch
diff --git a/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp b/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp
index e774e5fd99cbb..68db70d7abf9b 100644
--- a/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp
+++ b/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp
@@ -808,8 +808,7 @@ static bool isLoopCounter(PHINode* Phi, Loop *L,
     return false;
 
   const SCEV *S = SE->getSCEV(Phi);
-  if (!match(S, m_scev_AffineAddRec(m_SCEV(), m_scev_One())) ||
-      cast<SCEVAddRecExpr>(S)->getLoop() != L)
+  if (!match(S, m_scev_AffineAddRec(m_SCEV(), m_scev_One(), L)))
     return false;
 
   int LatchIdx = Phi->getBasicBlockIndex(L->getLoopLatch());
diff --git a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
index 0d5e0156b22be..2b2f50c7047c9 100644
--- a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
@@ -455,8 +455,8 @@ LoopIdiomRecognize::isLegalStore(StoreInst *SI) {
   // random store we can't handle.
   const SCEV *StoreEv = SE->getSCEV(StorePtr);
   const SCEVConstant *Stride;
-  if (!match(StoreEv, m_scev_AffineAddRec(m_SCEV(), m_SCEVConstant(Stride))) ||
-      cast<SCEVAddRecExpr>(StoreEv)->getLoop() != CurLoop)
+  if (!match(StoreEv,
+             m_scev_AffineAddRec(m_SCEV(), m_SCEVConstant(Stride), CurLoop)))
     return LegalStoreKind::None;
 
   // See if the store can be turned into a memset.
@@ -513,8 +513,7 @@ LoopIdiomRecognize::isLegalStore(StoreInst *SI) {
 
     // The store and load must share the same stride.
     if (!match(LoadEv,
-               m_scev_AffineAddRec(m_SCEV(), m_scev_Specific(Stride))) ||
-        cast<SCEVAddRecExpr>(LoadEv)->getLoop() != CurLoop)
+               m_scev_AffineAddRec(m_SCEV(), m_scev_Specific(Stride), CurLoop)))
       return LegalStoreKind::None;
 
     // Success.  This store can be converted into a memcpy.
@@ -787,11 +786,13 @@ bool LoopIdiomRecognize::processLoopMemCpy(MemCpyInst *MCI,
   // See if the load and store pointer expressions are AddRec like {base,+,1} on
   // the current loop, which indicates a strided load and store.  If we have
   // something else, it's a random load or store we can't handle.
-  const SCEVAddRecExpr *StoreEv = dyn_cast<SCEVAddRecExpr>(SE->getSCEV(Dest));
-  if (!StoreEv || StoreEv->getLoop() != CurLoop || !StoreEv->isAffine())
-    return false;
-  const SCEVAddRecExpr *LoadEv = dyn_cast<SCEVAddRecExpr>(SE->getSCEV(Source));
-  if (!LoadEv || LoadEv->getLoop() != CurLoop || !LoadEv->isAffine())
+  const SCEV *StoreEv = SE->getSCEV(Dest);
+  const SCEV *LoadEv = SE->getSCEV(Source);
+  const APInt *StoreStrideValue, *LoadStrideValue;
+  if (!match(StoreEv, m_scev_AffineAddRec(
+                          m_SCEV(), m_scev_APInt(StoreStrideValue), CurLoop)) ||
+      !match(LoadEv, m_scev_AffineAddRec(
+                         m_SCEV(), m_scev_APInt(LoadStrideValue), CurLoop)))
     return false;
 
   // Reject memcpys that are so large that they overflow an unsigned.
@@ -801,10 +802,6 @@ bool LoopIdiomRecognize::processLoopMemCpy(MemCpyInst *MCI,
 
   // Check if the stride matches the size of the memcpy. If so, then we know
   // that every byte is touched in the loop.
-  const APInt *StoreStrideValue, *LoadStrideValue;
-  if (!match(StoreEv->getOperand(1), m_scev_APInt(StoreStrideValue)) ||
-      !match(LoadEv->getOperand(1), m_scev_APInt(LoadStrideValue)))
-    return false;
 
   // Huge stride value - give up
   if (StoreStrideValue->getBitWidth() > 64 ||
@@ -830,8 +827,8 @@ bool LoopIdiomRecognize::processLoopMemCpy(MemCpyInst *MCI,
 
   return processLoopStoreOfLoopLoad(
       Dest, Source, SE->getConstant(Dest->getType(), SizeInBytes),
-      MCI->getDestAlign(), MCI->getSourceAlign(), MCI, MCI, StoreEv, LoadEv,
-      BECount);
+      MCI->getDestAlign(), MCI->getSourceAlign(), MCI, MCI,
+      cast<SCEVAddRecExpr>(StoreEv), cast<SCEVAddRecExpr>(LoadEv), BECount);
 }
 
 /// processLoopMemSet - See if this memset can be promoted to a large memset.
@@ -852,12 +849,11 @@ bool LoopIdiomRecognize::processLoopMemSet(MemSetInst *MSI,
   // random store we can't handle.
   const SCEV *Ev = SE->getSCEV(Pointer);
   const SCEV *PointerStrideSCEV;
-  if (!match(Ev, m_scev_AffineAddRec(m_SCEV(), m_SCEV(PointerStrideSCEV)))) {
+  if (!match(Ev, m_scev_AffineAddRec(m_SCEV(), m_SCEV(PointerStrideSCEV),
+                                     CurLoop))) {
     LLVM_DEBUG(dbgs() << "  Pointer is not affine, abort\n");
     return false;
   }
-  if (cast<SCEVAddRecExpr>(Ev)->getLoop() != CurLoop)
-    return false;
 
   const SCEV *MemsetSizeSCEV = SE->getSCEV(MSI->getLength());
   if (!PointerStrideSCEV || !MemsetSizeSCEV)

>From c20224aa2b108373c75ed02622ce0327e8f1f353 Mon Sep 17 00:00:00 2001
From: Ramkumar Ramachandra <ramkumar.ramachandra at codasip.com>
Date: Thu, 22 May 2025 21:36:58 +0100
Subject: [PATCH 2/4] [SCEVPM] Address review: split loop check into SCEVAffineAddRec_match

---
 .../Analysis/ScalarEvolutionPatternMatch.h    | 26 +++++++++++++------
 1 file changed, 18 insertions(+), 8 deletions(-)

diff --git a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
index bce10442e3b5c..3d9a72da6d645 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
@@ -163,18 +163,28 @@ template <typename SCEVTy, typename Op0_t, typename Op1_t>
 struct SCEVBinaryExpr_match {
   Op0_t Op0;
   Op1_t Op1;
+
+  SCEVBinaryExpr_match(Op0_t Op0, Op1_t Op1) : Op0(Op0), Op1(Op1) {}
+
+  bool match(const SCEV *S) const {
+    auto *E = dyn_cast<SCEVTy>(S);
+    return E && E->getNumOperands() == 2 && Op0.match(E->getOperand(0)) &&
+           Op1.match(E->getOperand(1));
+  }
+};
+
+template <typename Op0_t, typename Op1_t> struct SCEVAffineAddRec_match {
+  Op0_t Op0;
+  Op1_t Op1;
   const Loop *L;
 
-  SCEVBinaryExpr_match(Op0_t Op0, Op1_t Op1, const Loop *L = nullptr)
+  SCEVAffineAddRec_match(Op0_t Op0, Op1_t Op1, const Loop *L = nullptr)
       : Op0(Op0), Op1(Op1), L(L) {}
 
   bool match(const SCEV *S) const {
-    auto *E = dyn_cast<SCEVTy>(S);
-    bool LoopMatches = true;
-    if constexpr (std::is_same_v<SCEVTy, SCEVAddRecExpr>)
-      LoopMatches = !L || (E && E->getLoop() == L);
+    auto *E = dyn_cast<SCEVAddRecExpr>(S);
     return E && E->getNumOperands() == 2 && Op0.match(E->getOperand(0)) &&
-           Op1.match(E->getOperand(1)) && LoopMatches;
+           Op1.match(E->getOperand(1)) && (!L || E->getLoop() == L);
   }
 };
 
@@ -203,10 +213,10 @@ m_scev_UDiv(const Op0_t &Op0, const Op1_t &Op1) {
 }
 
 template <typename Op0_t, typename Op1_t>
-inline SCEVBinaryExpr_match<SCEVAddRecExpr, Op0_t, Op1_t>
+inline SCEVAffineAddRec_match<Op0_t, Op1_t>
 m_scev_AffineAddRec(const Op0_t &Op0, const Op1_t &Op1,
                     const Loop *L = nullptr) {
-  return m_scev_Binary<SCEVAddRecExpr>(Op0, Op1);
+  return SCEVAffineAddRec_match<Op0_t, Op1_t>(Op0, Op1, L);
 }
 } // namespace SCEVPatternMatch
 } // namespace llvm

>From 48f465e3def99097b060a3e9134c87bdf2c50537 Mon Sep 17 00:00:00 2001
From: Ramkumar Ramachandra <ramkumar.ramachandra at codasip.com>
Date: Thu, 22 May 2025 21:48:09 +0100
Subject: [PATCH 3/4] [SCEVPM] Fix default arg thinko

---
 llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
index 3d9a72da6d645..437c9edb49600 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
@@ -178,7 +178,7 @@ template <typename Op0_t, typename Op1_t> struct SCEVAffineAddRec_match {
   Op1_t Op1;
   const Loop *L;
 
-  SCEVAffineAddRec_match(Op0_t Op0, Op1_t Op1, const Loop *L = nullptr)
+  SCEVAffineAddRec_match(Op0_t Op0, Op1_t Op1, const Loop *L)
       : Op0(Op0), Op1(Op1), L(L) {}
 
   bool match(const SCEV *S) const {

>From 36aff8e566561e9a6d19217bc3af35a715c09074 Mon Sep 17 00:00:00 2001
From: Ramkumar Ramachandra <ramkumar.ramachandra at codasip.com>
Date: Fri, 23 May 2025 16:08:17 +0100
Subject: [PATCH 4/4] [SCEVPM] Strip dead include

---
 llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h | 1 -
 1 file changed, 1 deletion(-)

diff --git a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
index 437c9edb49600..5fc3840bd1d4c 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
@@ -14,7 +14,6 @@
 #define LLVM_ANALYSIS_SCALAREVOLUTIONPATTERNMATCH_H
 
 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
-#include <type_traits>
 
 namespace llvm {
 namespace SCEVPatternMatch {



More information about the llvm-commits mailing list