[llvm] [SCEV] Unify and optimize constant folding (NFC) (PR #101473)

Nikita Popov via llvm-commits llvm-commits at lists.llvm.org
Thu Aug 1 06:51:15 PDT 2024


https://github.com/nikic updated https://github.com/llvm/llvm-project/pull/101473

>From d48adf87b6309f7003f9a2f2db7b34f59420df72 Mon Sep 17 00:00:00 2001
From: Nikita Popov <npopov at redhat.com>
Date: Wed, 31 Jul 2024 17:10:30 +0200
Subject: [PATCH 1/2] [SCEV] Unify and optimize constant folding (NFC)

Add a ConstantFoldAndGroupOps() helper that performs the constant
folding and grouping transforms shared by all n-ary ops. This moves
the constant folding prior to grouping, which is more efficient, and
excludes any constant from the sort.

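The helper takes three hooks: one that folds two constants together,
one that recognizes identity constants (which are dropped from the
operand list), and one that recognizes absorber constants (which are
returned directly as the result).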
---
 llvm/lib/Analysis/ScalarEvolution.cpp | 184 ++++++++++++--------------
 1 file changed, 84 insertions(+), 100 deletions(-)

diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index fb56d5d436653..16f0d7ae2f992 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -831,6 +831,41 @@ static bool hasHugeExpression(ArrayRef<const SCEV *> Ops) {
   });
 }
 
+template <typename FoldT, typename IsIdentityT, typename IsAbsorberT>
+static const SCEV *
+ConstantFoldAndGroupOps(ScalarEvolution &SE, LoopInfo &LI, DominatorTree &DT,
+                        SmallVectorImpl<const SCEV *> &Ops, FoldT Fold,
+                        IsIdentityT IsIdentity, IsAbsorberT IsAbsorber) {
+  const SCEVConstant *Folded = nullptr;
+  for (unsigned Idx = 0; Idx < Ops.size();) {
+    const SCEV *Op = Ops[Idx];
+    if (const auto *C = dyn_cast<SCEVConstant>(Op)) {
+      if (!Folded)
+        Folded = C;
+      else
+        Folded = cast<SCEVConstant>(
+            SE.getConstant(Fold(Folded->getAPInt(), C->getAPInt())));
+      Ops.erase(Ops.begin() + Idx);
+      continue;
+    }
+    ++Idx;
+  }
+
+  if (Ops.empty()) {
+    assert(Folded && "Must have folded value");
+    return Folded;
+  }
+
+  if (Folded && IsAbsorber(Folded->getAPInt()))
+    return Folded;
+
+  GroupByComplexity(Ops, &LI, DT);
+  if (Folded && !IsIdentity(Folded->getAPInt()))
+    Ops.insert(Ops.begin(), Folded);
+
+  return Ops.size() == 1 ? Ops[0] : nullptr;
+}
+
 //===----------------------------------------------------------------------===//
 //                      Simple SCEV method implementations
 //===----------------------------------------------------------------------===//
@@ -2504,30 +2539,15 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
   assert(NumPtrs <= 1 && "add has at most one pointer operand");
 #endif
 
-  // Sort by complexity, this groups all similar expression types together.
-  GroupByComplexity(Ops, &LI, DT);
+  const SCEV *Folded = ConstantFoldAndGroupOps(
+      *this, LI, DT, Ops,
+      [](const APInt &C1, const APInt &C2) { return C1 + C2; },
+      [](const APInt &C) { return C.isZero(); }, // identity
+      [](const APInt &C) { return false; });     // absorber
+  if (Folded)
+    return Folded;
 
-  // If there are any constants, fold them together.
-  unsigned Idx = 0;
-  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
-    ++Idx;
-    assert(Idx < Ops.size());
-    while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
-      // We found two constants, fold them together!
-      Ops[0] = getConstant(LHSC->getAPInt() + RHSC->getAPInt());
-      if (Ops.size() == 2) return Ops[0];
-      Ops.erase(Ops.begin()+1);  // Erase the folded element
-      LHSC = cast<SCEVConstant>(Ops[0]);
-    }
-
-    // If we are left with a constant zero being added, strip it off.
-    if (LHSC->getValue()->isZero()) {
-      Ops.erase(Ops.begin());
-      --Idx;
-    }
-
-    if (Ops.size() == 1) return Ops[0];
-  }
+  unsigned Idx = isa<SCEVConstant>(Ops[0]) ? 1 : 0;
 
   // Delay expensive flag strengthening until necessary.
   auto ComputeFlags = [this, OrigFlags](const ArrayRef<const SCEV *> Ops) {
@@ -3097,35 +3117,13 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
            "SCEVMulExpr operand types don't match!");
 #endif
 
-  // Sort by complexity, this groups all similar expression types together.
-  GroupByComplexity(Ops, &LI, DT);
-
-  // If there are any constants, fold them together.
-  unsigned Idx = 0;
-  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
-    ++Idx;
-    assert(Idx < Ops.size());
-    while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
-      // We found two constants, fold them together!
-      Ops[0] = getConstant(LHSC->getAPInt() * RHSC->getAPInt());
-      if (Ops.size() == 2) return Ops[0];
-      Ops.erase(Ops.begin()+1);  // Erase the folded element
-      LHSC = cast<SCEVConstant>(Ops[0]);
-    }
-
-    // If we have a multiply of zero, it will always be zero.
-    if (LHSC->getValue()->isZero())
-      return LHSC;
-
-    // If we are left with a constant one being multiplied, strip it off.
-    if (LHSC->getValue()->isOne()) {
-      Ops.erase(Ops.begin());
-      --Idx;
-    }
-
-    if (Ops.size() == 1)
-      return Ops[0];
-  }
+  const SCEV *Folded = ConstantFoldAndGroupOps(
+      *this, LI, DT, Ops,
+      [](const APInt &C1, const APInt &C2) { return C1 * C2; },
+      [](const APInt &C) { return C.isOne(); },   // identity
+      [](const APInt &C) { return C.isZero(); }); // absorber
+  if (Folded)
+    return Folded;
 
   // Delay expensive flag strengthening until necessary.
   auto ComputeFlags = [this, OrigFlags](const ArrayRef<const SCEV *> Ops) {
@@ -3202,6 +3200,7 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
   }
 
   // Skip over the add expression until we get to a multiply.
+  unsigned Idx = 0;
   while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
     ++Idx;
 
@@ -3829,61 +3828,46 @@ const SCEV *ScalarEvolution::getMinMaxExpr(SCEVTypes Kind,
   bool IsSigned = Kind == scSMaxExpr || Kind == scSMinExpr;
   bool IsMax = Kind == scSMaxExpr || Kind == scUMaxExpr;
 
-  // Sort by complexity, this groups all similar expression types together.
-  GroupByComplexity(Ops, &LI, DT);
+  const SCEV *Folded = ConstantFoldAndGroupOps(
+      *this, LI, DT, Ops,
+      [&](const APInt &C1, const APInt &C2) {
+        switch (Kind) {
+        case scSMaxExpr:
+          return APIntOps::smax(C1, C2);
+        case scSMinExpr:
+          return APIntOps::smin(C1, C2);
+        case scUMaxExpr:
+          return APIntOps::umax(C1, C2);
+        case scUMinExpr:
+          return APIntOps::umin(C1, C2);
+        default:
+          llvm_unreachable("Unknown SCEV min/max opcode");
+        }
+      },
+      [&](const APInt &C) {
+        // identity
+        if (IsMax)
+          return IsSigned ? C.isMinSignedValue() : C.isMinValue();
+        else
+          return IsSigned ? C.isMaxSignedValue() : C.isMaxValue();
+      },
+      [&](const APInt &C) {
+        // absorber
+        if (IsMax)
+          return IsSigned ? C.isMaxSignedValue() : C.isMaxValue();
+        else
+          return IsSigned ? C.isMinSignedValue() : C.isMinValue();
+      });
+  if (Folded)
+    return Folded;
 
   // Check if we have created the same expression before.
   if (const SCEV *S = findExistingSCEVInCache(Kind, Ops)) {
     return S;
   }
 
-  // If there are any constants, fold them together.
-  unsigned Idx = 0;
-  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
-    ++Idx;
-    assert(Idx < Ops.size());
-    auto FoldOp = [&](const APInt &LHS, const APInt &RHS) {
-      switch (Kind) {
-      case scSMaxExpr:
-        return APIntOps::smax(LHS, RHS);
-      case scSMinExpr:
-        return APIntOps::smin(LHS, RHS);
-      case scUMaxExpr:
-        return APIntOps::umax(LHS, RHS);
-      case scUMinExpr:
-        return APIntOps::umin(LHS, RHS);
-      default:
-        llvm_unreachable("Unknown SCEV min/max opcode");
-      }
-    };
-
-    while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
-      // We found two constants, fold them together!
-      ConstantInt *Fold = ConstantInt::get(
-          getContext(), FoldOp(LHSC->getAPInt(), RHSC->getAPInt()));
-      Ops[0] = getConstant(Fold);
-      Ops.erase(Ops.begin()+1);  // Erase the folded element
-      if (Ops.size() == 1) return Ops[0];
-      LHSC = cast<SCEVConstant>(Ops[0]);
-    }
-
-    bool IsMinV = LHSC->getValue()->isMinValue(IsSigned);
-    bool IsMaxV = LHSC->getValue()->isMaxValue(IsSigned);
-
-    if (IsMax ? IsMinV : IsMaxV) {
-      // If we are left with a constant minimum(/maximum)-int, strip it off.
-      Ops.erase(Ops.begin());
-      --Idx;
-    } else if (IsMax ? IsMaxV : IsMinV) {
-      // If we have a max(/min) with a constant maximum(/minimum)-int,
-      // it will always be the extremum.
-      return LHSC;
-    }
-
-    if (Ops.size() == 1) return Ops[0];
-  }
-
   // Find the first operation of the same kind
+  unsigned Idx = 0;
   while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < Kind)
     ++Idx;
 

>From 20b5603a978909e0fb1d891bb70948cb268239ce Mon Sep 17 00:00:00 2001
From: Nikita Popov <npopov at redhat.com>
Date: Thu, 1 Aug 2024 15:50:51 +0200
Subject: [PATCH 2/2] rename & comment

---
 llvm/lib/Analysis/ScalarEvolution.cpp | 16 ++++++++++++----
 1 file changed, 12 insertions(+), 4 deletions(-)

diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 16f0d7ae2f992..98e3ad66e895c 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -831,9 +831,17 @@ static bool hasHugeExpression(ArrayRef<const SCEV *> Ops) {
   });
 }
 
+/// Performs a number of common optimizations on the passed \p Ops. If the
+/// whole expression reduces down to a single operand, it will be returned.
+///
+/// The following optimizations are performed:
+///  * Fold constants using the \p Fold function.
+///  * Remove identity constants satisfying \p IsIdentity.
+///  * If a constant satisfies \p IsAbsorber, return it.
+///  * Sort operands by complexity.
 template <typename FoldT, typename IsIdentityT, typename IsAbsorberT>
 static const SCEV *
-ConstantFoldAndGroupOps(ScalarEvolution &SE, LoopInfo &LI, DominatorTree &DT,
+constantFoldAndGroupOps(ScalarEvolution &SE, LoopInfo &LI, DominatorTree &DT,
                         SmallVectorImpl<const SCEV *> &Ops, FoldT Fold,
                         IsIdentityT IsIdentity, IsAbsorberT IsAbsorber) {
   const SCEVConstant *Folded = nullptr;
@@ -2539,7 +2547,7 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
   assert(NumPtrs <= 1 && "add has at most one pointer operand");
 #endif
 
-  const SCEV *Folded = ConstantFoldAndGroupOps(
+  const SCEV *Folded = constantFoldAndGroupOps(
       *this, LI, DT, Ops,
       [](const APInt &C1, const APInt &C2) { return C1 + C2; },
       [](const APInt &C) { return C.isZero(); }, // identity
@@ -3117,7 +3125,7 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
            "SCEVMulExpr operand types don't match!");
 #endif
 
-  const SCEV *Folded = ConstantFoldAndGroupOps(
+  const SCEV *Folded = constantFoldAndGroupOps(
       *this, LI, DT, Ops,
       [](const APInt &C1, const APInt &C2) { return C1 * C2; },
       [](const APInt &C) { return C.isOne(); },   // identity
@@ -3828,7 +3836,7 @@ const SCEV *ScalarEvolution::getMinMaxExpr(SCEVTypes Kind,
   bool IsSigned = Kind == scSMaxExpr || Kind == scSMinExpr;
   bool IsMax = Kind == scSMaxExpr || Kind == scUMaxExpr;
 
-  const SCEV *Folded = ConstantFoldAndGroupOps(
+  const SCEV *Folded = constantFoldAndGroupOps(
       *this, LI, DT, Ops,
       [&](const APInt &C1, const APInt &C2) {
         switch (Kind) {



More information about the llvm-commits mailing list