[llvm] 6da3361 - [SCEV] Look through multiply in computeConstantDifference() (#103051)

via llvm-commits llvm-commits at lists.llvm.org
Wed Aug 14 00:37:41 PDT 2024


Author: Nikita Popov
Date: 2024-08-14T09:37:38+02:00
New Revision: 6da3361f504495cef71caa4de4297234b6ea7fc7

URL: https://github.com/llvm/llvm-project/commit/6da3361f504495cef71caa4de4297234b6ea7fc7
DIFF: https://github.com/llvm/llvm-project/commit/6da3361f504495cef71caa4de4297234b6ea7fc7.diff

LOG: [SCEV] Look through multiply in computeConstantDifference() (#103051)

Inside computeConstantDifference(), handle the case where the two sides
are of the form `C * %x` and `C * %y` for the same constant `C`, in which
case we can strip off the common multiplication (as long as we remember
to multiply by it for the following difference calculation).

There is an obvious alternative implementation here, which would be to
directly decompose multiplies inside the "Multiplicity" accumulation.
This does work, but I've found this to be both significantly slower
(because everything has to work on APInt) and more complex in
implementation (e.g. because we now need to match back the new More/Less
with an arbitrary factor) without providing more power in practice. As
such, I went for the simpler variant here.

This is the last step to make computeConstantDifference() sufficiently
powerful to replace existing uses of
`cast<SCEVConstant>(getMinusSCEV())` with it.

Added: 
    

Modified: 
    llvm/lib/Analysis/ScalarEvolution.cpp
    llvm/unittests/Analysis/ScalarEvolutionTest.cpp

Removed: 
    


################################################################################
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 9a568a252f8d27..af341c55205de8 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -11953,9 +11953,10 @@ ScalarEvolution::computeConstantDifference(const SCEV *More, const SCEV *Less) {
 
   unsigned BW = getTypeSizeInBits(More->getType());
   APInt Diff(BW, 0);
+  APInt DiffMul(BW, 1);
   // Try various simplifications to reduce the difference to a constant. Limit
   // the number of allowed simplifications to keep compile-time low.
-  for (unsigned I = 0; I < 4; ++I) {
+  for (unsigned I = 0; I < 5; ++I) {
     if (More == Less)
       return Diff;
 
@@ -11980,15 +11981,36 @@ ScalarEvolution::computeConstantDifference(const SCEV *More, const SCEV *Less) {
       continue;
     }
 
+    // Try to match a common constant multiply.
+    auto MatchConstMul =
+        [](const SCEV *S) -> std::optional<std::pair<const SCEV *, APInt>> {
+      auto *M = dyn_cast<SCEVMulExpr>(S);
+      if (!M || M->getNumOperands() != 2 ||
+          !isa<SCEVConstant>(M->getOperand(0)))
+        return std::nullopt;
+      return {
+          {M->getOperand(1), cast<SCEVConstant>(M->getOperand(0))->getAPInt()}};
+    };
+    if (auto MatchedMore = MatchConstMul(More)) {
+      if (auto MatchedLess = MatchConstMul(Less)) {
+        if (MatchedMore->second == MatchedLess->second) {
+          More = MatchedMore->first;
+          Less = MatchedLess->first;
+          DiffMul *= MatchedMore->second;
+          continue;
+        }
+      }
+    }
+
     // Try to cancel out common factors in two add expressions.
     SmallDenseMap<const SCEV *, int, 8> Multiplicity;
     auto Add = [&](const SCEV *S, int Mul) {
       if (auto *C = dyn_cast<SCEVConstant>(S)) {
         if (Mul == 1) {
-          Diff += C->getAPInt();
+          Diff += C->getAPInt() * DiffMul;
         } else {
           assert(Mul == -1);
-          Diff -= C->getAPInt();
+          Diff -= C->getAPInt() * DiffMul;
         }
       } else
         Multiplicity[S] += Mul;

diff --git a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
index 42aad6ae507bf6..d4d90d80f4cea1 100644
--- a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
+++ b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
@@ -1142,6 +1142,8 @@ TEST_F(ScalarEvolutionsTest, SCEVComputeConstantDifference) {
         %var = load i32, ptr %ptr
         %iv2pvar = add i32 %iv2, %var
         %iv2pvarp3 = add i32 %iv2pvar, 3
+        %iv2pvarm3 = mul i32 %iv2pvar, 3
+        %iv2pvarp3m3 = mul i32 %iv2pvarp3, 3
         %cmp2 = icmp sle i32 %iv2.next, %sz
         br i1 %cmp2, label %loop2.body, label %exit
       exit:
@@ -1178,6 +1180,12 @@ TEST_F(ScalarEvolutionsTest, SCEVComputeConstantDifference) {
     // %var + {{3,+,1},+,1}
     const SCEV *ScevIV2PVarP3 =
         SE.getSCEV(getInstructionByName(F, "iv2pvarp3"));
+    // 3 * (%var + {{0,+,1},+,1})
+    const SCEV *ScevIV2PVarM3 =
+        SE.getSCEV(getInstructionByName(F, "iv2pvarm3"));
+    // 3 * (%var + {{3,+,1},+,1})
+    const SCEV *ScevIV2PVarP3M3 =
+        SE.getSCEV(getInstructionByName(F, "iv2pvarp3m3"));
 
     auto diff = [&SE](const SCEV *LHS, const SCEV *RHS) -> std::optional<int> {
       auto ConstantDiffOrNone = computeConstantDifference(SE, LHS, RHS);
@@ -1204,6 +1212,7 @@ TEST_F(ScalarEvolutionsTest, SCEVComputeConstantDifference) {
     EXPECT_EQ(diff(ScevIVNext, ScevIVNext), 0);
     EXPECT_EQ(diff(ScevIV2P3, ScevIV2), 3);
     EXPECT_EQ(diff(ScevIV2PVar, ScevIV2PVarP3), -3);
+    EXPECT_EQ(diff(ScevIV2PVarP3M3, ScevIV2PVarM3), 9);
     EXPECT_EQ(diff(ScevV0, ScevIV), std::nullopt);
     EXPECT_EQ(diff(ScevIVNext, ScevV3), std::nullopt);
     EXPECT_EQ(diff(ScevYY, ScevV3), std::nullopt);


        


More information about the llvm-commits mailing list