[llvm] 61d3ad9 - [SCEVPatternMatch] Introduce m_scev_AffineAddRec (#140377)

via llvm-commits llvm-commits at lists.llvm.org
Mon May 19 04:02:10 PDT 2025


Author: Ramkumar Ramachandra
Date: 2025-05-19T12:02:07+01:00
New Revision: 61d3ad963c9a8764a20c01c26374000d9ba5975d

URL: https://github.com/llvm/llvm-project/commit/61d3ad963c9a8764a20c01c26374000d9ba5975d
DIFF: https://github.com/llvm/llvm-project/commit/61d3ad963c9a8764a20c01c26374000d9ba5975d.diff

LOG: [SCEVPatternMatch] Introduce m_scev_AffineAddRec (#140377)

Introduce m_scev_AffineAddRec to match affine AddRecs, along with a
class_match for SCEVConstant, and demonstrate their utility in LSR and
SCEV. While at it, rename m_Specific to m_scev_Specific for clarity.
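
For illustration, a minimal sketch of how the new matchers read at a use
site. This is a hypothetical example, not code from this commit; the helper
names are invented, and it assumes the usual ScalarEvolution headers are
available.

  #include "llvm/Analysis/ScalarEvolution.h"
  #include "llvm/Analysis/ScalarEvolutionPatternMatch.h"

  using namespace llvm;
  using namespace SCEVPatternMatch;

  // Match an affine AddRec {Start,+,Step} whose step is a SCEVConstant,
  // replacing the old dyn_cast + isAffine() + isa<SCEVConstant> sequence.
  static bool hasConstantStep(const SCEV *S) {
    const SCEV *Start;
    return match(S, m_scev_AffineAddRec(m_SCEV(Start), m_SCEVConstant()));
  }

  // Match two affine AddRecs sharing a step: bind Step from LHS, then
  // require the identical SCEV on RHS via the renamed m_scev_Specific.
  static bool haveSameStep(const SCEV *LHS, const SCEV *RHS) {
    const SCEV *LStart, *RStart, *Step;
    return match(LHS, m_scev_AffineAddRec(m_SCEV(LStart), m_SCEV(Step))) &&
           match(RHS, m_scev_AffineAddRec(m_SCEV(RStart),
                                          m_scev_Specific(Step)));
  }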

Added: 
    

Modified: 
    llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
    llvm/lib/Analysis/ScalarEvolution.cpp
    llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp

Removed: 
    


################################################################################
diff --git a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
index 5e53a4f502153..cfb1b4c6ea6b4 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
@@ -61,6 +61,9 @@ template <typename Class> struct class_match {
 };
 
 inline class_match<const SCEV> m_SCEV() { return class_match<const SCEV>(); }
+inline class_match<const SCEVConstant> m_SCEVConstant() {
+  return class_match<const SCEVConstant>();
+}
 
 template <typename Class> struct bind_ty {
   Class *&VR;
@@ -95,7 +98,7 @@ struct specificscev_ty {
 };
 
 /// Match if we have a specific specified SCEV.
-inline specificscev_ty m_Specific(const SCEV *S) { return S; }
+inline specificscev_ty m_scev_Specific(const SCEV *S) { return S; }
 
 struct is_specific_cst {
   uint64_t CV;
@@ -192,6 +195,12 @@ inline SCEVBinaryExpr_match<SCEVUDivExpr, Op0_t, Op1_t>
 m_scev_UDiv(const Op0_t &Op0, const Op1_t &Op1) {
   return m_scev_Binary<SCEVUDivExpr>(Op0, Op1);
 }
+
+template <typename Op0_t, typename Op1_t>
+inline SCEVBinaryExpr_match<SCEVAddRecExpr, Op0_t, Op1_t>
+m_scev_AffineAddRec(const Op0_t &Op0, const Op1_t &Op1) {
+  return m_scev_Binary<SCEVAddRecExpr>(Op0, Op1);
+}
 } // namespace SCEVPatternMatch
 } // namespace llvm
 

diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 69714a112310e..342c8e39e2d3c 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -12480,26 +12480,21 @@ static bool IsKnownPredicateViaAddRecStart(ScalarEvolution &SE,
   if (!ICmpInst::isRelational(Pred))
     return false;
 
-  const SCEVAddRecExpr *LAR = dyn_cast<SCEVAddRecExpr>(LHS);
-  if (!LAR)
-    return false;
-  const SCEVAddRecExpr *RAR = dyn_cast<SCEVAddRecExpr>(RHS);
-  if (!RAR)
+  const SCEV *LStart, *RStart, *Step;
+  if (!match(LHS, m_scev_AffineAddRec(m_SCEV(LStart), m_SCEV(Step))) ||
+      !match(RHS, m_scev_AffineAddRec(m_SCEV(RStart), m_scev_Specific(Step))))
     return false;
+  const SCEVAddRecExpr *LAR = cast<SCEVAddRecExpr>(LHS);
+  const SCEVAddRecExpr *RAR = cast<SCEVAddRecExpr>(RHS);
   if (LAR->getLoop() != RAR->getLoop())
     return false;
-  if (!LAR->isAffine() || !RAR->isAffine())
-    return false;
-
-  if (LAR->getStepRecurrence(SE) != RAR->getStepRecurrence(SE))
-    return false;
 
   SCEV::NoWrapFlags NW = ICmpInst::isSigned(Pred) ?
                          SCEV::FlagNSW : SCEV::FlagNUW;
   if (!LAR->getNoWrapFlags(NW) || !RAR->getNoWrapFlags(NW))
     return false;
 
-  return SE.isKnownPredicate(Pred, LAR->getStart(), RAR->getStart());
+  return SE.isKnownPredicate(Pred, LStart, RStart);
 }
 
 /// Is LHS `Pred` RHS true on the virtue of LHS or RHS being a Min or Max
@@ -12716,7 +12711,7 @@ static bool isKnownPredicateExtendIdiom(CmpPredicate Pred, const SCEV *LHS,
   case ICmpInst::ICMP_SLE: {
     // If operand >=s 0 then ZExt == SExt. If operand <s 0 then SExt <s ZExt.
     return match(LHS, m_scev_SExt(m_SCEV(Op))) &&
-           match(RHS, m_scev_ZExt(m_Specific(Op)));
+           match(RHS, m_scev_ZExt(m_scev_Specific(Op)));
   }
   case ICmpInst::ICMP_UGE:
     std::swap(LHS, RHS);
@@ -12724,7 +12719,7 @@ static bool isKnownPredicateExtendIdiom(CmpPredicate Pred, const SCEV *LHS,
   case ICmpInst::ICMP_ULE: {
     // If operand >=u 0 then ZExt == SExt.  If operand <u 0 then ZExt <u SExt.
     return match(LHS, m_scev_ZExt(m_SCEV(Op))) &&
-           match(RHS, m_scev_SExt(m_Specific(Op)));
+           match(RHS, m_scev_SExt(m_scev_Specific(Op)));
   }
   default:
     return false;

diff --git a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
index 33d6c77f61cfd..bdab14ed34c54 100644
--- a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
@@ -77,11 +77,11 @@
 #include "llvm/Analysis/ScalarEvolution.h"
 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
 #include "llvm/Analysis/ScalarEvolutionNormalization.h"
+#include "llvm/Analysis/ScalarEvolutionPatternMatch.h"
 #include "llvm/Analysis/TargetLibraryInfo.h"
 #include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/BinaryFormat/Dwarf.h"
-#include "llvm/Config/llvm-config.h"
 #include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/Constant.h"
 #include "llvm/IR/Constants.h"
@@ -128,6 +128,7 @@
 #include <utility>
 
 using namespace llvm;
+using namespace SCEVPatternMatch;
 
 #define DEBUG_TYPE "loop-reduce"
 
@@ -556,16 +557,17 @@ static void DoInitialMatch(const SCEV *S, Loop *L,
   }
 
   // Look at addrec operands.
-  if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(S))
-    if (!AR->getStart()->isZero() && AR->isAffine()) {
-      DoInitialMatch(AR->getStart(), L, Good, Bad, SE);
-      DoInitialMatch(SE.getAddRecExpr(SE.getConstant(AR->getType(), 0),
-                                      AR->getStepRecurrence(SE),
-                                      // FIXME: AR->getNoWrapFlags()
-                                      AR->getLoop(), SCEV::FlagAnyWrap),
-                     L, Good, Bad, SE);
-      return;
-    }
+  const SCEV *Start, *Step;
+  if (match(S, m_scev_AffineAddRec(m_SCEV(Start), m_SCEV(Step))) &&
+      !Start->isZero()) {
+    DoInitialMatch(Start, L, Good, Bad, SE);
+    DoInitialMatch(SE.getAddRecExpr(SE.getConstant(S->getType(), 0), Step,
+                                    // FIXME: AR->getNoWrapFlags()
+                                    cast<SCEVAddRecExpr>(S)->getLoop(),
+                                    SCEV::FlagAnyWrap),
+                   L, Good, Bad, SE);
+    return;
+  }
 
   // Handle a multiplication by -1 (negation) if it didn't fold.
   if (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(S))
@@ -1411,22 +1413,16 @@ void Cost::RateRegister(const Formula &F, const SCEV *Reg,
     unsigned LoopCost = 1;
     if (TTI->isIndexedLoadLegal(TTI->MIM_PostInc, AR->getType()) ||
         TTI->isIndexedStoreLegal(TTI->MIM_PostInc, AR->getType())) {
-
-      // If the step size matches the base offset, we could use pre-indexed
-      // addressing.
-      if (AMK == TTI::AMK_PreIndexed && F.BaseOffset.isFixed()) {
-        if (auto *Step = dyn_cast<SCEVConstant>(AR->getStepRecurrence(*SE)))
-          if (Step->getAPInt() == F.BaseOffset.getFixedValue())
-            LoopCost = 0;
-      } else if (AMK == TTI::AMK_PostIndexed) {
-        const SCEV *LoopStep = AR->getStepRecurrence(*SE);
-        if (isa<SCEVConstant>(LoopStep)) {
-          const SCEV *LoopStart = AR->getStart();
-          if (!isa<SCEVConstant>(LoopStart) &&
-              SE->isLoopInvariant(LoopStart, L))
-            LoopCost = 0;
-        }
-      }
+      const SCEV *Start;
+      const SCEVConstant *Step;
+      if (match(AR, m_scev_AffineAddRec(m_SCEV(Start), m_SCEVConstant(Step))))
+        // If the step size matches the base offset, we could use pre-indexed
+        // addressing.
+        if ((AMK == TTI::AMK_PreIndexed && F.BaseOffset.isFixed() &&
+             Step->getAPInt() == F.BaseOffset.getFixedValue()) ||
+            (AMK == TTI::AMK_PostIndexed && !isa<SCEVConstant>(Start) &&
+             SE->isLoopInvariant(Start, L)))
+          LoopCost = 0;
     }
     C.AddRecCost += LoopCost;
 
@@ -2519,13 +2515,11 @@ ICmpInst *LSRInstance::OptimizeMax(ICmpInst *Cond, IVStrideUse* &CondUse) {
   // Check the relevant induction variable for conformance to
   // the pattern.
   const SCEV *IV = SE.getSCEV(Cond->getOperand(0));
-  const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(IV);
-  if (!AR || !AR->isAffine() ||
-      AR->getStart() != One ||
-      AR->getStepRecurrence(SE) != One)
+  if (!match(IV,
+             m_scev_AffineAddRec(m_scev_SpecificInt(1), m_scev_SpecificInt(1))))
     return Cond;
 
-  assert(AR->getLoop() == L &&
+  assert(cast<SCEVAddRecExpr>(IV)->getLoop() == L &&
          "Loop condition operand is an addrec in a 
diff erent loop!");
 
   // Check the right operand of the select, and remember it, as it will
@@ -3320,7 +3314,7 @@ void LSRInstance::CollectChains() {
 void LSRInstance::FinalizeChain(IVChain &Chain) {
   assert(!Chain.Incs.empty() && "empty IV chains are not allowed");
   LLVM_DEBUG(dbgs() << "Final Chain: " << *Chain.Incs[0].UserInst << "\n");
-  
+
   for (const IVInc &Inc : Chain) {
     LLVM_DEBUG(dbgs() << "        Inc: " << *Inc.UserInst << "\n");
     auto UseI = find(Inc.UserInst->operands(), Inc.IVOperand);
@@ -3823,26 +3817,27 @@ static const SCEV *CollectSubexprs(const SCEV *S, const SCEVConstant *C,
         Ops.push_back(C ? SE.getMulExpr(C, Remainder) : Remainder);
     }
     return nullptr;
-  } else if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(S)) {
+  }
+  const SCEV *Start, *Step;
+  if (match(S, m_scev_AffineAddRec(m_SCEV(Start), m_SCEV(Step)))) {
     // Split a non-zero base out of an addrec.
-    if (AR->getStart()->isZero() || !AR->isAffine())
+    if (Start->isZero())
       return S;
 
-    const SCEV *Remainder = CollectSubexprs(AR->getStart(),
-                                            C, Ops, L, SE, Depth+1);
+    const SCEV *Remainder = CollectSubexprs(Start, C, Ops, L, SE, Depth + 1);
     // Split the non-zero AddRec unless it is part of a nested recurrence that
     // does not pertain to this loop.
-    if (Remainder && (AR->getLoop() == L || !isa<SCEVAddRecExpr>(Remainder))) {
+    if (Remainder && (cast<SCEVAddRecExpr>(S)->getLoop() == L ||
+                      !isa<SCEVAddRecExpr>(Remainder))) {
       Ops.push_back(C ? SE.getMulExpr(C, Remainder) : Remainder);
       Remainder = nullptr;
     }
-    if (Remainder != AR->getStart()) {
+    if (Remainder != Start) {
       if (!Remainder)
-        Remainder = SE.getConstant(AR->getType(), 0);
-      return SE.getAddRecExpr(Remainder,
-                              AR->getStepRecurrence(SE),
-                              AR->getLoop(),
-                              //FIXME: AR->getNoWrapFlags(SCEV::FlagNW)
+        Remainder = SE.getConstant(S->getType(), 0);
+      return SE.getAddRecExpr(Remainder, Step,
+                              cast<SCEVAddRecExpr>(S)->getLoop(),
+                              // FIXME: AR->getNoWrapFlags(SCEV::FlagNW)
                               SCEV::FlagAnyWrap);
     }
   } else if (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(S)) {
@@ -3870,17 +3865,13 @@ static bool mayUsePostIncMode(const TargetTransformInfo &TTI,
   if (LU.Kind != LSRUse::Address ||
       !LU.AccessTy.getType()->isIntOrIntVectorTy())
     return false;
-  const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(S);
-  if (!AR)
-    return false;
-  const SCEV *LoopStep = AR->getStepRecurrence(SE);
-  if (!isa<SCEVConstant>(LoopStep))
+  const SCEV *Start;
+  if (!match(S, m_scev_AffineAddRec(m_SCEV(Start), m_SCEVConstant())))
     return false;
   // Check if a post-indexed load/store can be used.
-  if (TTI.isIndexedLoadLegal(TTI.MIM_PostInc, AR->getType()) ||
-      TTI.isIndexedStoreLegal(TTI.MIM_PostInc, AR->getType())) {
-    const SCEV *LoopStart = AR->getStart();
-    if (!isa<SCEVConstant>(LoopStart) && SE.isLoopInvariant(LoopStart, L))
+  if (TTI.isIndexedLoadLegal(TTI.MIM_PostInc, S->getType()) ||
+      TTI.isIndexedStoreLegal(TTI.MIM_PostInc, S->getType())) {
+    if (!isa<SCEVConstant>(Start) && SE.isLoopInvariant(Start, L))
       return true;
   }
   return false;
@@ -4139,18 +4130,15 @@ void LSRInstance::GenerateConstantOffsetsImpl(
   // base pointer for each iteration of the loop, resulting in no extra add/sub
   // instructions for pointer updating.
   if (AMK == TTI::AMK_PreIndexed && LU.Kind == LSRUse::Address) {
-    if (auto *GAR = dyn_cast<SCEVAddRecExpr>(G)) {
-      if (auto *StepRec =
-          dyn_cast<SCEVConstant>(GAR->getStepRecurrence(SE))) {
-        const APInt &StepInt = StepRec->getAPInt();
-        int64_t Step = StepInt.isNegative() ?
-          StepInt.getSExtValue() : StepInt.getZExtValue();
-
-        for (Immediate Offset : Worklist) {
-          if (Offset.isFixed()) {
-            Offset = Immediate::getFixed(Offset.getFixedValue() - Step);
-            GenerateOffset(G, Offset);
-          }
+    const APInt *StepInt;
+    if (match(G, m_scev_AffineAddRec(m_SCEV(), m_scev_APInt(StepInt)))) {
+      int64_t Step = StepInt->isNegative() ? StepInt->getSExtValue()
+                                           : StepInt->getZExtValue();
+
+      for (Immediate Offset : Worklist) {
+        if (Offset.isFixed()) {
+          Offset = Immediate::getFixed(Offset.getFixedValue() - Step);
+          GenerateOffset(G, Offset);
         }
       }
     }
@@ -6621,7 +6609,7 @@ struct SCEVDbgValueBuilder {
       if (Op.getOp() != dwarf::DW_OP_LLVM_arg) {
         Op.appendToVector(DestExpr);
         continue;
-      } 
+      }
 
       DestExpr.push_back(dwarf::DW_OP_LLVM_arg);
       // `DW_OP_LLVM_arg n` represents the nth LocationOp in this SCEV,
