[llvm] [LSR] Clean up code using SCEVPatternMatch (NFC) (PR #145556)

Ramkumar Ramachandra via llvm-commits llvm-commits at lists.llvm.org
Tue Jun 24 10:25:39 PDT 2025


https://github.com/artagnon created https://github.com/llvm/llvm-project/pull/145556

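Replace manual dyn_cast/getOperand chains with SCEVPatternMatch matchers (m_scev_APInt, m_scev_Mul, m_SCEVConstant, m_scev_AffineAddRec) in ExtractImmediate, isHighCostExpansion, canFoldIVIncExpr, CollectSubexprs, and SCEVDbgValueBuilder::SCEVToValueExpr / SCEVToIterCountExpr.

Note that two parts of this change are not strictly NFC:

- canFoldIVIncExpr now rejects a scalable (vscale) increment when the constant needs more than 64 significant bits. Previously it rejected it whenever the constant's type was wider than 64 bits.
- SCEVToValueExpr and SCEVToIterCountExpr no longer return false when the AddRec's start is itself an AddRec. SCEVToIterCountExpr also no longer emits the "Unsupported nested AddRec" debug message. Unless m_scev_AffineAddRec rejects such starts on its own, nested AddRecs now reach the expression-building code.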

>From 0074f5dcedc26cdb76f7a96ea14c4654526c853a Mon Sep 17 00:00:00 2001
From: Ramkumar Ramachandra <ramkumar.ramachandra at codasip.com>
Date: Tue, 24 Jun 2025 18:12:09 +0100
Subject: [PATCH] [LSR] Clean up code using SCEVPatternMatch (NFC)

---
 .../Transforms/Scalar/LoopStrengthReduce.cpp  | 111 ++++++++----------
 1 file changed, 48 insertions(+), 63 deletions(-)

diff --git a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
index 4ba69034d6448..9ffcfb47b0a89 100644
--- a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
@@ -923,10 +923,12 @@ static const SCEV *getExactSDiv(const SCEV *LHS, const SCEV *RHS,
 /// If S involves the addition of a constant integer value, return that integer
 /// value, and mutate S to point to a new SCEV with that value excluded.
 static Immediate ExtractImmediate(const SCEV *&S, ScalarEvolution &SE) {
-  if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S)) {
-    if (C->getAPInt().getSignificantBits() <= 64) {
-      S = SE.getConstant(C->getType(), 0);
-      return Immediate::getFixed(C->getValue()->getSExtValue());
+  const APInt *C;
+  const SCEV *Op1;
+  if (match(S, m_scev_APInt(C))) {
+    if (C->getSignificantBits() <= 64) {
+      S = SE.getConstant(S->getType(), 0);
+      return Immediate::getFixed(C->getSExtValue());
     }
   } else if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(S)) {
     SmallVector<const SCEV *, 8> NewOps(Add->operands());
@@ -942,14 +944,11 @@ static Immediate ExtractImmediate(const SCEV *&S, ScalarEvolution &SE) {
                            // FIXME: AR->getNoWrapFlags(SCEV::FlagNW)
                            SCEV::FlagAnyWrap);
     return Result;
-  } else if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(S)) {
-    if (EnableVScaleImmediates && M->getNumOperands() == 2) {
-      if (const SCEVConstant *C = dyn_cast<SCEVConstant>(M->getOperand(0)))
-        if (isa<SCEVVScale>(M->getOperand(1))) {
-          S = SE.getConstant(M->getType(), 0);
-          return Immediate::getScalable(C->getValue()->getSExtValue());
-        }
-    }
+  } else if (EnableVScaleImmediates &&
+             match(S, m_scev_Mul(m_scev_APInt(C), m_SCEV(Op1))) &&
+             isa<SCEVVScale>(Op1)) {
+    S = SE.getConstant(S->getType(), 0);
+    return Immediate::getScalable(C->getSExtValue());
   }
   return Immediate::getZero();
 }
@@ -1133,23 +1132,22 @@ static bool isHighCostExpansion(const SCEV *S,
     return false;
   }
 
-  if (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(S)) {
-    if (Mul->getNumOperands() == 2) {
-      // Multiplication by a constant is ok
-      if (isa<SCEVConstant>(Mul->getOperand(0)))
-        return isHighCostExpansion(Mul->getOperand(1), Processed, SE);
-
-      // If we have the value of one operand, check if an existing
-      // multiplication already generates this expression.
-      if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(Mul->getOperand(1))) {
-        Value *UVal = U->getValue();
-        for (User *UR : UVal->users()) {
-          // If U is a constant, it may be used by a ConstantExpr.
-          Instruction *UI = dyn_cast<Instruction>(UR);
-          if (UI && UI->getOpcode() == Instruction::Mul &&
-              SE.isSCEVable(UI->getType())) {
-            return SE.getSCEV(UI) == Mul;
-          }
+  const SCEV *Op0, *Op1;
+  if (match(S, m_scev_Mul(m_SCEV(Op0), m_SCEV(Op1)))) {
+    // Multiplication by a constant is ok
+    if (isa<SCEVConstant>(Op0))
+      return isHighCostExpansion(Op1, Processed, SE);
+
+    // If we have the value of one operand, check if an existing
+    // multiplication already generates this expression.
+    if (const auto *U = dyn_cast<SCEVUnknown>(Op1)) {
+      Value *UVal = U->getValue();
+      for (User *UR : UVal->users()) {
+        // If U is a constant, it may be used by a ConstantExpr.
+        Instruction *UI = dyn_cast<Instruction>(UR);
+        if (UI && UI->getOpcode() == Instruction::Mul &&
+            SE.isSCEVable(UI->getType())) {
+          return SE.getSCEV(UI) == S;
         }
       }
     }
@@ -3333,14 +3331,12 @@ static bool canFoldIVIncExpr(const SCEV *IncExpr, Instruction *UserInst,
     IncOffset = Immediate::getFixed(IncConst->getValue()->getSExtValue());
   } else {
     // Look for mul(vscale, constant), to detect a scalable offset.
-    auto *IncVScale = dyn_cast<SCEVMulExpr>(IncExpr);
-    if (!IncVScale || IncVScale->getNumOperands() != 2 ||
-        !isa<SCEVVScale>(IncVScale->getOperand(1)))
-      return false;
-    auto *Scale = dyn_cast<SCEVConstant>(IncVScale->getOperand(0));
-    if (!Scale || Scale->getType()->getScalarSizeInBits() > 64)
+    const APInt *C;
+    const SCEV *Op1;
+    if (!match(IncExpr, m_scev_Mul(m_scev_APInt(C), m_SCEV(Op1))) ||
+        !isa<SCEVVScale>(Op1) || C->getSignificantBits() > 64)
       return false;
-    IncOffset = Immediate::getScalable(Scale->getValue()->getSExtValue());
+    IncOffset = Immediate::getScalable(C->getSExtValue());
   }
 
   if (!isAddressUse(TTI, UserInst, Operand))
@@ -3818,6 +3814,8 @@ static const SCEV *CollectSubexprs(const SCEV *S, const SCEVConstant *C,
     return nullptr;
   }
   const SCEV *Start, *Step;
+  const SCEVConstant *Op0;
+  const SCEV *Op1;
   if (match(S, m_scev_AffineAddRec(m_SCEV(Start), m_SCEV(Step)))) {
     // Split a non-zero base out of an addrec.
     if (Start->isZero())
@@ -3839,19 +3837,13 @@ static const SCEV *CollectSubexprs(const SCEV *S, const SCEVConstant *C,
                               // FIXME: AR->getNoWrapFlags(SCEV::FlagNW)
                               SCEV::FlagAnyWrap);
     }
-  } else if (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(S)) {
+  } else if (match(S, m_scev_Mul(m_SCEVConstant(Op0), m_SCEV(Op1)))) {
     // Break (C * (a + b + c)) into C*a + C*b + C*c.
-    if (Mul->getNumOperands() != 2)
-      return S;
-    if (const SCEVConstant *Op0 =
-        dyn_cast<SCEVConstant>(Mul->getOperand(0))) {
-      C = C ? cast<SCEVConstant>(SE.getMulExpr(C, Op0)) : Op0;
-      const SCEV *Remainder =
-        CollectSubexprs(Mul->getOperand(1), C, Ops, L, SE, Depth+1);
-      if (Remainder)
-        Ops.push_back(SE.getMulExpr(C, Remainder));
-      return nullptr;
-    }
+    C = C ? cast<SCEVConstant>(SE.getMulExpr(C, Op0)) : Op0;
+    const SCEV *Remainder = CollectSubexprs(Op1, C, Ops, L, SE, Depth + 1);
+    if (Remainder)
+      Ops.push_back(SE.getMulExpr(C, Remainder));
+    return nullptr;
   }
   return S;
 }
@@ -6478,13 +6470,10 @@ struct SCEVDbgValueBuilder {
   /// Components of the expression are omitted if they are an identity function.
   /// Chain (non-affine) SCEVs are not supported.
   bool SCEVToValueExpr(const llvm::SCEVAddRecExpr &SAR, ScalarEvolution &SE) {
-    assert(SAR.isAffine() && "Expected affine SCEV");
-    // TODO: Is this check needed?
-    if (isa<SCEVAddRecExpr>(SAR.getStart()))
-      return false;
-
-    const SCEV *Start = SAR.getStart();
-    const SCEV *Stride = SAR.getStepRecurrence(SE);
+    const SCEV *Start, *Stride;
+    [[maybe_unused]] bool Match =
+        match(&SAR, m_scev_AffineAddRec(m_SCEV(Start), m_SCEV(Stride)));
+    assert(Match && "Expected affine SCEV");
 
     // Skip pushing arithmetic noops.
     if (!isIdentityFunction(llvm::dwarf::DW_OP_mul, Stride)) {
@@ -6549,14 +6538,10 @@ struct SCEVDbgValueBuilder {
   /// Components of the expression are omitted if they are an identity function.
   bool SCEVToIterCountExpr(const llvm::SCEVAddRecExpr &SAR,
                            ScalarEvolution &SE) {
-    assert(SAR.isAffine() && "Expected affine SCEV");
-    if (isa<SCEVAddRecExpr>(SAR.getStart())) {
-      LLVM_DEBUG(dbgs() << "scev-salvage: IV SCEV. Unsupported nested AddRec: "
-                        << SAR << '\n');
-      return false;
-    }
-    const SCEV *Start = SAR.getStart();
-    const SCEV *Stride = SAR.getStepRecurrence(SE);
+    const SCEV *Start, *Stride;
+    [[maybe_unused]] bool Match =
+        match(&SAR, m_scev_AffineAddRec(m_SCEV(Start), m_SCEV(Stride)));
+    assert(Match && "Expected affine SCEV");
 
     // Skip pushing arithmetic noops.
     if (!isIdentityFunction(llvm::dwarf::DW_OP_minus, Start)) {



More information about the llvm-commits mailing list