[llvm] [LV] Use SCEVPatternMatch to improve code (NFC) (PR #154568)

Ramkumar Ramachandra via llvm-commits llvm-commits at lists.llvm.org
Wed Aug 20 09:16:45 PDT 2025


https://github.com/artagnon created https://github.com/llvm/llvm-project/pull/154568

This change required splitting m_scev_SpecificInt into unsigned and signed variants: m_scev_SpecificUInt (taking a uint64_t) and m_scev_SpecificSInt (taking an int64_t).

From c7672005b42315277d845f8352620f0b0b549c06 Mon Sep 17 00:00:00 2001
From: Ramkumar Ramachandra <ramkumar.ramachandra at codasip.com>
Date: Wed, 20 Aug 2025 17:04:33 +0100
Subject: [PATCH] [LV] Use SCEVPatternMatch to improve code (NFC)

The change has necessitated splitting up m_scev_SpecificInt into signed
and unsigned variants.
---
 .../Analysis/ScalarEvolutionPatternMatch.h    | 21 +++++++++----
 llvm/lib/Analysis/HashRecognize.cpp           |  4 +--
 .../Transforms/Scalar/LoopStrengthReduce.cpp  |  4 +--
 .../Transforms/Vectorize/LoopVectorize.cpp    | 30 ++++++-------------
 4 files changed, 29 insertions(+), 30 deletions(-)

diff --git a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
index 011d5994dc670..c233ae385db3a 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
@@ -107,14 +107,25 @@ struct specificscev_ty {
 /// Match if we have a specific specified SCEV.
 inline specificscev_ty m_scev_Specific(const SCEV *S) { return S; }
 
-struct is_specific_cst {
-  uint64_t CV;
-  is_specific_cst(uint64_t C) : CV(C) {}
-  bool isValue(const APInt &C) const { return C == CV; }
+template <typename ITy> struct is_specific_cst {
+  ITy CV;
+  is_specific_cst(ITy C) : CV(C) {}
+  bool isValue(const APInt &C) const {
+    if constexpr (std::is_unsigned_v<ITy>)
+      return C.tryZExtValue() == CV;
+    return C.trySExtValue() == CV;
+  }
 };
 
 /// Match an SCEV constant with a plain unsigned integer.
-inline cst_pred_ty<is_specific_cst> m_scev_SpecificInt(uint64_t V) { return V; }
+inline cst_pred_ty<is_specific_cst<uint64_t>> m_scev_SpecificUInt(uint64_t V) {
+  return V;
+}
+
+/// Match an SCEV constant with a plain signed integer.
+inline cst_pred_ty<is_specific_cst<int64_t>> m_scev_SpecificSInt(int64_t V) {
+  return V;
+}
 
 struct bind_cst_ty {
   const APInt *&CR;
diff --git a/llvm/lib/Analysis/HashRecognize.cpp b/llvm/lib/Analysis/HashRecognize.cpp
index 92c9e37dbb484..0c5f02bceddfa 100644
--- a/llvm/lib/Analysis/HashRecognize.cpp
+++ b/llvm/lib/Analysis/HashRecognize.cpp
@@ -562,9 +562,9 @@ static std::optional<bool> isBigEndianBitShift(Value *V, ScalarEvolution &SE) {
     return {};
 
   const SCEV *E = SE.getSCEV(V);
-  if (match(E, m_scev_UDiv(m_SCEV(), m_scev_SpecificInt(2))))
+  if (match(E, m_scev_UDiv(m_SCEV(), m_scev_SpecificUInt(2))))
     return false;
-  if (match(E, m_scev_Mul(m_scev_SpecificInt(2), m_SCEV())))
+  if (match(E, m_scev_Mul(m_scev_SpecificUInt(2), m_SCEV())))
     return true;
   return {};
 }
diff --git a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
index e3ef9d8680b53..b68a6519c0c83 100644
--- a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
@@ -2532,8 +2532,8 @@ ICmpInst *LSRInstance::OptimizeMax(ICmpInst *Cond, IVStrideUse* &CondUse) {
   // Check the relevant induction variable for conformance to
   // the pattern.
   const SCEV *IV = SE.getSCEV(Cond->getOperand(0));
-  if (!match(IV,
-             m_scev_AffineAddRec(m_scev_SpecificInt(1), m_scev_SpecificInt(1))))
+  if (!match(IV, m_scev_AffineAddRec(m_scev_SpecificUInt(1),
+                                     m_scev_SpecificUInt(1))))
     return Cond;
 
   assert(cast<SCEVAddRecExpr>(IV)->getLoop() == L &&
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 70f884016d08c..e23b6ff7464fe 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -5903,21 +5903,11 @@ void LoopVectorizationCostModel::setVectorizedCallDecision(ElementCount VF) {
             // TODO: do we need to figure out the cost of an extract to get the
             // first lane? Or do we hope that it will be folded away?
             ScalarEvolution *SE = PSE.getSE();
-            const auto *SAR =
-                dyn_cast<SCEVAddRecExpr>(SE->getSCEV(ScalarParam));
-
-            if (!SAR || SAR->getLoop() != TheLoop) {
+            if (!match(SE->getSCEV(ScalarParam),
+                       m_scev_AffineAddRec(
+                           m_SCEV(), m_scev_SpecificSInt(Param.LinearStepOrPos),
+                           m_SpecificLoop(TheLoop))))
               ParamsOk = false;
-              break;
-            }
-
-            const SCEVConstant *Step =
-                dyn_cast<SCEVConstant>(SAR->getStepRecurrence(*SE));
-
-            if (!Step ||
-                Step->getAPInt().getSExtValue() != Param.LinearStepOrPos)
-              ParamsOk = false;
-
             break;
           }
           case VFParamKind::GlobalPredicate:
@@ -8873,13 +8863,12 @@ VPlanPtr LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(
   };
   for (auto [_, Stride] : Legal->getLAI()->getSymbolicStrides()) {
     auto *StrideV = cast<SCEVUnknown>(Stride)->getValue();
-    auto *ScevStride = dyn_cast<SCEVConstant>(PSE.getSCEV(StrideV));
-    // Only handle constant strides for now.
-    if (!ScevStride)
+    const APInt *ScevStride;
+    if (!match(PSE.getSCEV(StrideV), m_scev_APInt(ScevStride)))
       continue;
 
-    auto *CI = Plan->getOrAddLiveIn(
-        ConstantInt::get(Stride->getType(), ScevStride->getAPInt()));
+    auto *CI =
+        Plan->getOrAddLiveIn(ConstantInt::get(Stride->getType(), *ScevStride));
     if (VPValue *StrideVPV = Plan->getLiveIn(StrideV))
       StrideVPV->replaceUsesWithIf(CI, CanUseVersionedStride);
 
@@ -8892,8 +8881,7 @@ VPlanPtr LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(
       if (!StrideVPV)
         continue;
       unsigned BW = U->getType()->getScalarSizeInBits();
-      APInt C = isa<SExtInst>(U) ? ScevStride->getAPInt().sext(BW)
-                                 : ScevStride->getAPInt().zext(BW);
+      APInt C = isa<SExtInst>(U) ? ScevStride->sext(BW) : ScevStride->zext(BW);
       VPValue *CI = Plan->getOrAddLiveIn(ConstantInt::get(U->getType(), C));
       StrideVPV->replaceUsesWithIf(CI, CanUseVersionedStride);
     }



More information about the llvm-commits mailing list