[llvm] 79af689 - [SCEV] Handle more adds in computeConstantDifference() (#101339)

via llvm-commits llvm-commits at lists.llvm.org
Fri Aug 2 04:43:05 PDT 2024


Author: Nikita Popov
Date: 2024-08-02T13:43:02+02:00
New Revision: 79af6892f8eee1d0fb96c243716c8b03936751c9

URL: https://github.com/llvm/llvm-project/commit/79af6892f8eee1d0fb96c243716c8b03936751c9
DIFF: https://github.com/llvm/llvm-project/commit/79af6892f8eee1d0fb96c243716c8b03936751c9.diff

LOG: [SCEV] Handle more adds in computeConstantDifference() (#101339)

Currently it only deals with the case where we're subtracting adds with
at most one non-constant operand. This patch extends it to cancel out
common operands for the subtraction of arbitrary add expressions.

The background here is that I want to replace a getMinusSCEV() call in
LAA with computeConstantDifference():

https://github.com/llvm/llvm-project/blob/93fecc2577ece0329f3bbe2719bbc5b4b9b30010/llvm/lib/Analysis/LoopAccessAnalysis.cpp#L1602-L1603

This particular call is very expensive in some cases (e.g. lencod with
LTO) and computeConstantDifference() could achieve this much more
cheaply, because it does not need to construct new SCEV expressions.

However, the current computeConstantDifference() implementation is too
weak for this and misses many basic cases. This is a step towards making
it more powerful while still keeping it pretty fast.

Added: 
    

Modified: 
    llvm/lib/Analysis/ScalarEvolution.cpp
    llvm/unittests/Analysis/ScalarEvolutionTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 06dd17dd3648b..b6c70451b9b06 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -11934,8 +11934,9 @@ ScalarEvolution::computeConstantDifference(const SCEV *More, const SCEV *Less) {
   // fairly deep in the call stack (i.e. is called many times).
 
   // X - X = 0.
+  unsigned BW = getTypeSizeInBits(More->getType());
   if (More == Less)
-    return APInt(getTypeSizeInBits(More->getType()), 0);
+    return APInt(BW, 0);
 
   if (isa<SCEVAddRecExpr>(Less) && isa<SCEVAddRecExpr>(More)) {
     const auto *LAR = cast<SCEVAddRecExpr>(Less);
@@ -11958,33 +11959,31 @@ ScalarEvolution::computeConstantDifference(const SCEV *More, const SCEV *Less) {
     // fall through
   }
 
-  if (isa<SCEVConstant>(Less) && isa<SCEVConstant>(More)) {
-    const auto &M = cast<SCEVConstant>(More)->getAPInt();
-    const auto &L = cast<SCEVConstant>(Less)->getAPInt();
-    return M - L;
-  }
-
-  SCEV::NoWrapFlags Flags;
-  const SCEV *LLess = nullptr, *RLess = nullptr;
-  const SCEV *LMore = nullptr, *RMore = nullptr;
-  const SCEVConstant *C1 = nullptr, *C2 = nullptr;
-  // Compare (X + C1) vs X.
-  if (splitBinaryAdd(Less, LLess, RLess, Flags))
-    if ((C1 = dyn_cast<SCEVConstant>(LLess)))
-      if (RLess == More)
-        return -(C1->getAPInt());
-
-  // Compare X vs (X + C2).
-  if (splitBinaryAdd(More, LMore, RMore, Flags))
-    if ((C2 = dyn_cast<SCEVConstant>(LMore)))
-      if (RMore == Less)
-        return C2->getAPInt();
+  // Try to cancel out common factors in two add expressions.
+  SmallDenseMap<const SCEV *, int, 8> Multiplicity;
+  APInt Diff(BW, 0);
+  auto Add = [&](const SCEV *S, int Mul) {
+    if (auto *C = dyn_cast<SCEVConstant>(S))
+      Diff += C->getAPInt() * Mul;
+    else
+      Multiplicity[S] += Mul;
+  };
+  auto Decompose = [&](const SCEV *S, int Mul) {
+    if (isa<SCEVAddExpr>(S)) {
+      for (const SCEV *Op : S->operands())
+        Add(Op, Mul);
+    } else
+      Add(S, Mul);
+  };
+  Decompose(More, 1);
+  Decompose(Less, -1);
 
-  // Compare (X + C1) vs (X + C2).
-  if (C1 && C2 && RLess == RMore)
-    return C2->getAPInt() - C1->getAPInt();
+  // Check whether all the non-constants cancel out.
+  for (const auto [_, Mul] : Multiplicity)
+    if (Mul != 0)
+      return std::nullopt;
 
-  return std::nullopt;
+  return Diff;
 }
 
 bool ScalarEvolution::isImpliedCondOperandsViaAddRecStart(

diff  --git a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
index 76e6095636305..3802ae4051f42 100644
--- a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
+++ b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
@@ -1117,10 +1117,12 @@ TEST_F(ScalarEvolutionsTest, SCEVComputeConstantDifference) {
   LLVMContext C;
   SMDiagnostic Err;
   std::unique_ptr<Module> M = parseAssemblyString(
-      "define void @foo(i32 %sz, i32 %pp) { "
+      "define void @foo(i32 %sz, i32 %pp, i32 %x) { "
       "entry: "
       "  %v0 = add i32 %pp, 0 "
       "  %v3 = add i32 %pp, 3 "
+      "  %vx = add i32 %pp, %x "
+      "  %vx3 = add i32 %vx, 3 "
       "  br label %loop.body "
       "loop.body: "
       "  %iv = phi i32 [ %iv.next, %loop.body ], [ 0, %entry ] "
@@ -1141,6 +1143,9 @@ TEST_F(ScalarEvolutionsTest, SCEVComputeConstantDifference) {
   runWithSE(*M, "foo", [](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
     auto *ScevV0 = SE.getSCEV(getInstructionByName(F, "v0")); // %pp
     auto *ScevV3 = SE.getSCEV(getInstructionByName(F, "v3")); // (3 + %pp)
+    auto *ScevVX = SE.getSCEV(getInstructionByName(F, "vx")); // (%pp + %x)
+    // (%pp + %x + 3)
+    auto *ScevVX3 = SE.getSCEV(getInstructionByName(F, "vx3"));
     auto *ScevIV = SE.getSCEV(getInstructionByName(F, "iv")); // {0,+,1}
     auto *ScevXA = SE.getSCEV(getInstructionByName(F, "xa")); // {%pp,+,1}
     auto *ScevYY = SE.getSCEV(getInstructionByName(F, "yy")); // {(3 + %pp),+,1}
@@ -1162,6 +1167,7 @@ TEST_F(ScalarEvolutionsTest, SCEVComputeConstantDifference) {
     EXPECT_EQ(diff(ScevV0, ScevV3), -3);
     EXPECT_EQ(diff(ScevV0, ScevV0), 0);
     EXPECT_EQ(diff(ScevV3, ScevV3), 0);
+    EXPECT_EQ(diff(ScevVX3, ScevVX), 3);
     EXPECT_EQ(diff(ScevIV, ScevIV), 0);
     EXPECT_EQ(diff(ScevXA, ScevXB), 0);
     EXPECT_EQ(diff(ScevXA, ScevYY), -3);


        


More information about the llvm-commits mailing list