[llvm] 98a6b6e - [SCEV] Improve code using SCEVPatternMatch (NFC) (#163946)

via llvm-commits llvm-commits at lists.llvm.org
Fri Oct 17 06:49:56 PDT 2025


Author: Ramkumar Ramachandra
Date: 2025-10-17T14:49:52+01:00
New Revision: 98a6b6e78ae4cf53329ef5b5464a055fed259014

URL: https://github.com/llvm/llvm-project/commit/98a6b6e78ae4cf53329ef5b5464a055fed259014
DIFF: https://github.com/llvm/llvm-project/commit/98a6b6e78ae4cf53329ef5b5464a055fed259014.diff

LOG: [SCEV] Improve code using SCEVPatternMatch (NFC) (#163946)

Added: 
    

Modified: 
    llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
    llvm/lib/Analysis/ScalarEvolution.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
index 68198ec9b8a9f..9354eef98fe91 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
@@ -256,6 +256,20 @@ m_scev_UDiv(const Op0_t &Op0, const Op1_t &Op1) {
   return m_scev_Binary<SCEVUDivExpr>(Op0, Op1);
 }
 
+/// Match a two-operand SCEVSMaxExpr whose operands match Op0 and Op1.
+template <typename Op0_t, typename Op1_t>
+inline SCEVBinaryExpr_match<SCEVSMaxExpr, Op0_t, Op1_t>
+m_scev_SMax(const Op0_t &Op0, const Op1_t &Op1) {
+  return m_scev_Binary<SCEVSMaxExpr>(Op0, Op1);
+}
+
+/// Match a two-operand SCEVMinMaxExpr, regardless of its min/max kind.
+template <typename Op0_t, typename Op1_t>
+inline SCEVBinaryExpr_match<SCEVMinMaxExpr, Op0_t, Op1_t>
+m_scev_MinMax(const Op0_t &Op0, const Op1_t &Op1) {
+  return m_scev_Binary<SCEVMinMaxExpr>(Op0, Op1);
+}
+
 /// Match unsigned remainder pattern.
 /// Matches patterns generated by getURemExpr.
 template <typename Op0_t, typename Op1_t> struct SCEVURem_match {

diff  --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 442b9d1e8a30e..e06b0956d4c82 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -1840,19 +1840,19 @@ const SCEV *ScalarEvolution::getZeroExtendExprImpl(const SCEV *Op, Type *Ty,
     //   = zext((2^K * (trunc X to i{N-K}))<nuw>) to iM
     //   = (2^K * (zext(trunc X to i{N-K}) to iM))<nuw>.
     //
-    if (SM->getNumOperands() == 2)
-      if (auto *MulLHS = dyn_cast<SCEVConstant>(SM->getOperand(0)))
-        if (MulLHS->getAPInt().isPowerOf2())
-          if (auto *TruncRHS = dyn_cast<SCEVTruncateExpr>(SM->getOperand(1))) {
-            int NewTruncBits = getTypeSizeInBits(TruncRHS->getType()) -
-                               MulLHS->getAPInt().logBase2();
-            Type *NewTruncTy = IntegerType::get(getContext(), NewTruncBits);
-            return getMulExpr(
-                getZeroExtendExpr(MulLHS, Ty),
-                getZeroExtendExpr(
-                    getTruncateExpr(TruncRHS->getOperand(), NewTruncTy), Ty),
-                SCEV::FlagNUW, Depth + 1);
-          }
+    const APInt *Pow2; // The 2^K multiplier.
+    const SCEV *X;     // The value under the trunc.
+    if (match(SM, m_scev_Mul(m_scev_APInt(Pow2), m_scev_Trunc(m_SCEV(X)))) &&
+        Pow2->isPowerOf2()) {
+      // Narrow X to (trunc width - K) bits; zero-extend both factors.
+      const SCEV *TruncOp = SM->getOperand(1);
+      int NewTruncBits =
+          getTypeSizeInBits(TruncOp->getType()) - Pow2->logBase2();
+      Type *NewTruncTy = IntegerType::get(getContext(), NewTruncBits);
+      return getMulExpr(getZeroExtendExpr(SM->getOperand(0), Ty),
+                        getZeroExtendExpr(getTruncateExpr(X, NewTruncTy), Ty),
+                        SCEV::FlagNUW, Depth + 1);
+    }
   }
 
   // zext(umin(x, y)) -> umin(zext(x), zext(y))
@@ -3144,20 +3144,19 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
   if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
     if (Ops.size() == 2) {
       // C1*(C2+V) -> C1*C2 + C1*V
-      if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1]))
-        // If any of Add's ops are Adds or Muls with a constant, apply this
-        // transformation as well.
-        //
-        // TODO: There are some cases where this transformation is not
-        // profitable; for example, Add = (C0 + X) * Y + Z.  Maybe the scope of
-        // this transformation should be narrowed down.
-        if (Add->getNumOperands() == 2 && containsConstantInAddMulChain(Add)) {
-          const SCEV *LHS = getMulExpr(LHSC, Add->getOperand(0),
-                                       SCEV::FlagAnyWrap, Depth + 1);
-          const SCEV *RHS = getMulExpr(LHSC, Add->getOperand(1),
-                                       SCEV::FlagAnyWrap, Depth + 1);
-          return getAddExpr(LHS, RHS, SCEV::FlagAnyWrap, Depth + 1);
-        }
+      // Distribute only if Ops[1] is a binary add whose add/mul chain contains
+      // a constant, as determined by containsConstantInAddMulChain.
+      // TODO: Not always profitable, e.g. Ops[1] = (C0 + X) * Y + Z; consider
+      // narrowing the scope of this transformation.
+      const SCEV *AddOpL, *AddOpR;
+      if (match(Ops[1], m_scev_Add(m_SCEV(AddOpL), m_SCEV(AddOpR))) &&
+          containsConstantInAddMulChain(Ops[1])) {
+        const SCEV *ScaledL =
+            getMulExpr(LHSC, AddOpL, SCEV::FlagAnyWrap, Depth + 1);
+        const SCEV *ScaledR =
+            getMulExpr(LHSC, AddOpR, SCEV::FlagAnyWrap, Depth + 1);
+        return getAddExpr(ScaledL, ScaledR, SCEV::FlagAnyWrap, Depth + 1);
+      }
 
       if (Ops[0]->isAllOnesValue()) {
         // If we have a mul by -1 of an add, try distributing the -1 among the
@@ -3578,20 +3577,12 @@ const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS,
   }
 
   // ((-C + (C smax %x)) /u %x) evaluates to zero, for any positive constant C.
-  if (const auto *AE = dyn_cast<SCEVAddExpr>(LHS);
-      AE && AE->getNumOperands() == 2) {
-    if (const auto *VC = dyn_cast<SCEVConstant>(AE->getOperand(0))) {
-      const APInt &NegC = VC->getAPInt();
-      if (NegC.isNegative() && !NegC.isMinSignedValue()) {
-        const auto *MME = dyn_cast<SCEVSMaxExpr>(AE->getOperand(1));
-        if (MME && MME->getNumOperands() == 2 &&
-            isa<SCEVConstant>(MME->getOperand(0)) &&
-            cast<SCEVConstant>(MME->getOperand(0))->getAPInt() == -NegC &&
-            MME->getOperand(1) == RHS)
-          return getZero(LHS->getType());
-      }
-    }
-  }
+  // Match NegC + smax(C, RHS) with NegC negative, not signed-min, C == -NegC.
+  const APInt *NegC, *C;
+  const auto SMaxPat = m_scev_SMax(m_scev_APInt(C), m_scev_Specific(RHS));
+  if (match(LHS, m_scev_Add(m_scev_APInt(NegC), SMaxPat)) &&
+      NegC->isNegative() && !NegC->isMinSignedValue() && *C == -*NegC)
+    return getZero(LHS->getType());
 
   // TODO: Generalize to handle any common factors.
   // udiv (mul nuw a, vscale), (mul nuw b, vscale) --> udiv a, b
@@ -10791,19 +10782,15 @@ static bool HasSameValue(const SCEV *A, const SCEV *B) {
 }
 
 static bool MatchBinarySub(const SCEV *S, const SCEV *&LHS, const SCEV *&RHS) {
-  const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(S);
-  if (!Add || Add->getNumOperands() != 2)
+  const SCEV *AddL, *AddR; // Operands of S when it is a binary add.
+  if (!match(S, m_scev_Add(m_SCEV(AddL), m_SCEV(AddR))))
     return false;
-  if (auto *ME = dyn_cast<SCEVMulExpr>(Add->getOperand(0));
-      ME && ME->getNumOperands() == 2 && ME->getOperand(0)->isAllOnesValue()) {
-    LHS = Add->getOperand(1);
-    RHS = ME->getOperand(1);
+  if (match(AddL, m_scev_Mul(m_scev_AllOnes(), m_SCEV(RHS)))) {
+    LHS = AddR; // S == AddR + (-1 * RHS)
     return true;
   }
-  if (auto *ME = dyn_cast<SCEVMulExpr>(Add->getOperand(1));
-      ME && ME->getNumOperands() == 2 && ME->getOperand(0)->isAllOnesValue()) {
-    LHS = Add->getOperand(0);
-    RHS = ME->getOperand(1);
+  if (match(AddR, m_scev_Mul(m_scev_AllOnes(), m_SCEV(RHS)))) {
+    LHS = AddL; // S == AddL + (-1 * RHS)
     return true;
   }
   return false;
@@ -12166,13 +12153,10 @@ bool ScalarEvolution::isImpliedCondBalancedTypes(
 bool ScalarEvolution::splitBinaryAdd(const SCEV *Expr,
                                      const SCEV *&L, const SCEV *&R,
                                      SCEV::NoWrapFlags &Flags) {
-  const auto *AE = dyn_cast<SCEVAddExpr>(Expr);
-  if (!AE || AE->getNumOperands() != 2)
+  if (!match(Expr, m_scev_Add(m_SCEV(L), m_SCEV(R)))) // Binary adds only.
     return false;
 
-  L = AE->getOperand(0);
-  R = AE->getOperand(1);
-  Flags = AE->getNoWrapFlags();
+  Flags = cast<SCEVAddExpr>(Expr)->getNoWrapFlags(); // Safe: matched above.
   return true;
 }
 
@@ -15550,19 +15534,18 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
     auto IsMinMaxSCEVWithNonNegativeConstant =
         [&](const SCEV *Expr, SCEVTypes &SCTy, const SCEV *&LHS,
             const SCEV *&RHS) {
-          if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr)) {
-            if (MinMax->getNumOperands() != 2)
-              return false;
-            if (auto *C = dyn_cast<SCEVConstant>(MinMax->getOperand(0))) {
-              if (C->getAPInt().isNegative())
-                return false;
-              SCTy = MinMax->getSCEVType();
-              LHS = MinMax->getOperand(0);
-              RHS = MinMax->getOperand(1);
-              return true;
-            }
-          }
-          return false;
+          // Match a binary min/max whose first operand is a non-negative
+          // constant. SCTy, LHS and RHS are written only on success.
+          const APInt *C;
+          const SCEV *MinMaxLHS, *MinMaxRHS;
+          if (!match(Expr,
+                     m_scev_MinMax(m_SCEV(MinMaxLHS), m_SCEV(MinMaxRHS))) ||
+              !match(MinMaxLHS, m_scev_APInt(C)) || C->isNegative())
+            return false;
+          SCTy = Expr->getSCEVType();
+          LHS = MinMaxLHS;
+          RHS = MinMaxRHS;
+          return true;
         };
 
     // Return a new SCEV that modifies \p Expr to the closest number divides by


        


More information about the llvm-commits mailing list