[llvm] 65697b1 - Remove value cache in SCEV comparator. (#100721)

via llvm-commits llvm-commits at lists.llvm.org
Tue Jul 30 05:27:41 PDT 2024


Author: Johannes Reifferscheid
Date: 2024-07-30T14:27:37+02:00
New Revision: 65697b1c7cc9824d51f22f109c6a32428a7dd557

URL: https://github.com/llvm/llvm-project/commit/65697b1c7cc9824d51f22f109c6a32428a7dd557
DIFF: https://github.com/llvm/llvm-project/commit/65697b1c7cc9824d51f22f109c6a32428a7dd557.diff

LOG: Remove value cache in SCEV comparator. (#100721)

The value cache almost never gets a hit, so it is unlikely to help with
performance. When it does get a hit, however, it is likely to make the
comparator inconsistent because of a bad interaction between the depth
limit and cache hits. This leads to crashes in debug builds. See the new
unit test for a reproducer.

Added: 
    

Modified: 
    llvm/lib/Analysis/ScalarEvolution.cpp
    llvm/unittests/Analysis/ScalarEvolutionTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index edd943d4eeea8..f60a6c585e5ba 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -597,11 +597,9 @@ void SCEVUnknown::allUsesReplacedWith(Value *New) {
 ///
 /// Since we do not continue running this routine on expression trees once we
 /// have seen unequal values, there is no need to track them in the cache.
-static int
-CompareValueComplexity(EquivalenceClasses<const Value *> &EqCacheValue,
-                       const LoopInfo *const LI, Value *LV, Value *RV,
-                       unsigned Depth) {
-  if (Depth > MaxValueCompareDepth || EqCacheValue.isEquivalent(LV, RV))
+static int CompareValueComplexity(const LoopInfo *const LI, Value *LV,
+                                  Value *RV, unsigned Depth) {
+  if (Depth > MaxValueCompareDepth)
     return 0;
 
   // Order pointer values after integer values. This helps SCEVExpander form
@@ -660,15 +658,13 @@ CompareValueComplexity(EquivalenceClasses<const Value *> &EqCacheValue,
       return (int)LNumOps - (int)RNumOps;
 
     for (unsigned Idx : seq(LNumOps)) {
-      int Result =
-          CompareValueComplexity(EqCacheValue, LI, LInst->getOperand(Idx),
-                                 RInst->getOperand(Idx), Depth + 1);
+      int Result = CompareValueComplexity(LI, LInst->getOperand(Idx),
+                                          RInst->getOperand(Idx), Depth + 1);
       if (Result != 0)
         return Result;
     }
   }
 
-  EqCacheValue.unionSets(LV, RV);
   return 0;
 }
 
@@ -679,7 +675,6 @@ CompareValueComplexity(EquivalenceClasses<const Value *> &EqCacheValue,
 // not know if they are equivalent for sure.
 static std::optional<int>
 CompareSCEVComplexity(EquivalenceClasses<const SCEV *> &EqCacheSCEV,
-                      EquivalenceClasses<const Value *> &EqCacheValue,
                       const LoopInfo *const LI, const SCEV *LHS,
                       const SCEV *RHS, DominatorTree &DT, unsigned Depth = 0) {
   // Fast-path: SCEVs are uniqued so we can do a quick equality check.
@@ -705,8 +700,8 @@ CompareSCEVComplexity(EquivalenceClasses<const SCEV *> &EqCacheSCEV,
     const SCEVUnknown *LU = cast<SCEVUnknown>(LHS);
     const SCEVUnknown *RU = cast<SCEVUnknown>(RHS);
 
-    int X = CompareValueComplexity(EqCacheValue, LI, LU->getValue(),
-                                   RU->getValue(), Depth + 1);
+    int X =
+        CompareValueComplexity(LI, LU->getValue(), RU->getValue(), Depth + 1);
     if (X == 0)
       EqCacheSCEV.unionSets(LHS, RHS);
     return X;
@@ -773,8 +768,8 @@ CompareSCEVComplexity(EquivalenceClasses<const SCEV *> &EqCacheSCEV,
       return (int)LNumOps - (int)RNumOps;
 
     for (unsigned i = 0; i != LNumOps; ++i) {
-      auto X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LOps[i],
-                                     ROps[i], DT, Depth + 1);
+      auto X = CompareSCEVComplexity(EqCacheSCEV, LI, LOps[i], ROps[i], DT,
+                                     Depth + 1);
       if (X != 0)
         return X;
     }
@@ -802,12 +797,10 @@ static void GroupByComplexity(SmallVectorImpl<const SCEV *> &Ops,
   if (Ops.size() < 2) return;  // Noop
 
   EquivalenceClasses<const SCEV *> EqCacheSCEV;
-  EquivalenceClasses<const Value *> EqCacheValue;
 
   // Whether LHS has provably less complexity than RHS.
   auto IsLessComplex = [&](const SCEV *LHS, const SCEV *RHS) {
-    auto Complexity =
-        CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LHS, RHS, DT);
+    auto Complexity = CompareSCEVComplexity(EqCacheSCEV, LI, LHS, RHS, DT);
     return Complexity && *Complexity < 0;
   };
   if (Ops.size() == 2) {

diff  --git a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
index a6a5ffda3cb70..76e6095636305 100644
--- a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
+++ b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
@@ -1625,4 +1625,40 @@ TEST_F(ScalarEvolutionsTest, ForgetValueWithOverflowInst) {
   });
 }
 
+TEST_F(ScalarEvolutionsTest, ComplexityComparatorIsStrictWeakOrdering) {
+  // Regression test for a case where caching of equivalent values caused the
+  // comparator to get inconsistent.
+  LLVMContext C;
+  SMDiagnostic Err;
+  std::unique_ptr<Module> M = parseAssemblyString(R"(
+    define i32 @foo(i32 %arg0) {
+      %1 = add i32 %arg0, 1
+      %2 = add i32 %arg0, 1
+      %3 = xor i32 %2, %1
+      %4 = add i32 %3, %2
+      %5 = add i32 %arg0, 1
+      %6 = xor i32 %5, %arg0
+      %7 = add i32 %arg0, %6
+      %8 = add i32 %5, %7
+      %9 = xor i32 %8, %7
+      %10 = add i32 %9, %8
+      %11 = xor i32 %10, %9
+      %12 = add i32 %11, %10
+      %13 = xor i32 %12, %11
+      %14 = add i32 %12, %13
+      %15 = add i32 %14, %4
+      ret i32 %15
+    })",
+                                                  Err, C);
+
+  ASSERT_TRUE(M && "Could not parse module?");
+  ASSERT_TRUE(!verifyModule(*M) && "Must have been well formed!");
+
+  runWithSE(*M, "foo", [](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
+    // When _LIBCPP_HARDENING_MODE == _LIBCPP_HARDENING_MODE_DEBUG, this will
+    // crash if the comparator has the specific caching bug.
+    SE.getSCEV(F.getEntryBlock().getTerminator()->getOperand(0));
+  });
+}
+
 }  // end namespace llvm


        


More information about the llvm-commits mailing list