[llvm] 8a4ad88 - [SCEV] Do not cache comparison result upon reached max depth as "equivalence". PR48725

Max Kazantsev via llvm-commits llvm-commits at lists.llvm.org
Thu Jan 28 21:36:42 PST 2021


Author: Max Kazantsev
Date: 2021-01-29T12:08:34+07:00
New Revision: 8a4ad8849f4898dd19e31b9dcede7ace3575d00d

URL: https://github.com/llvm/llvm-project/commit/8a4ad8849f4898dd19e31b9dcede7ace3575d00d
DIFF: https://github.com/llvm/llvm-project/commit/8a4ad8849f4898dd19e31b9dcede7ace3575d00d.diff

LOG: [SCEV] Do not cache comparison result upon reached max depth as "equivalence". PR48725

We use `EquivalenceClasses` to cache the fact that two SCEVs are equivalent.
This saves time in the situation where `A` is equivalent to `B` and `B` is equivalent to `C`,
making the check "is `A` equivalent to `C`?" cheaper.

We also return `0` from the comparator when we reach the max analysis depth, to save
compile time. When that happens, the two SCEVs being compared also get cached as equivalent.

Now, imagine the following situation:
- `A` is proved equivalent to `B`;
- `C` is proved equivalent to `D`;
- Comparison of `A` against `D` is proved non-zero;
- Comparison of `B` against `C` reaches max depth (and gets cached as equivalence).

Now, before the invocation of compare(`B`, `C`), `A` and `D` belonged
to different equivalence classes, and their comparison returned non-zero.
After the invocation of compare(`B`, `C`), the equivalence classes get merged
and `A`, `B`, `C` and `D` all fall into the same equivalence class. So the comparator
changes its behavior for the pair `A` and `D`, with weird consequences following from it.
This comparator is ultimately used in `std::stable_sort`, and this behavior change
makes it crash (it looks like it causes memory corruption).

Solution: this patch changes `CompareSCEVComplexity` to return `None`
when the max depth is reached. So in this case, we do not cache these SCEVs
(and their parents in the tree) as being equivalent.

Differential Revision: https://reviews.llvm.org/D94654
Reviewed By: lebedev.ri

Added: 
    llvm/test/Transforms/LoopStrengthReduce/pr48725.ll

Modified: 
    llvm/lib/Analysis/ScalarEvolution.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 0b575ae421fc..134d4d0fc8c5 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -689,11 +689,13 @@ CompareValueComplexity(EquivalenceClasses<const Value *> &EqCacheValue,
 // Return negative, zero, or positive, if LHS is less than, equal to, or greater
 // than RHS, respectively. A three-way result allows recursive comparisons to be
 // more efficient.
-static int CompareSCEVComplexity(
-    EquivalenceClasses<const SCEV *> &EqCacheSCEV,
-    EquivalenceClasses<const Value *> &EqCacheValue,
-    const LoopInfo *const LI, const SCEV *LHS, const SCEV *RHS,
-    DominatorTree &DT, unsigned Depth = 0) {
+// If the max analysis depth was reached, return None, assuming we do not know
+// if they are equivalent for sure.
+static Optional<int>
+CompareSCEVComplexity(EquivalenceClasses<const SCEV *> &EqCacheSCEV,
+                      EquivalenceClasses<const Value *> &EqCacheValue,
+                      const LoopInfo *const LI, const SCEV *LHS,
+                      const SCEV *RHS, DominatorTree &DT, unsigned Depth = 0) {
   // Fast-path: SCEVs are uniqued so we can do a quick equality check.
   if (LHS == RHS)
     return 0;
@@ -703,8 +705,12 @@ static int CompareSCEVComplexity(
   if (LType != RType)
     return (int)LType - (int)RType;
 
-  if (Depth > MaxSCEVCompareDepth || EqCacheSCEV.isEquivalent(LHS, RHS))
+  if (EqCacheSCEV.isEquivalent(LHS, RHS))
     return 0;
+
+  if (Depth > MaxSCEVCompareDepth)
+    return None;
+
   // Aside from the getSCEVType() ordering, the particular ordering
   // isn't very important except that it's beneficial to be consistent,
   // so that (a + b) and (b + a) don't end up as different expressions.
@@ -759,9 +765,9 @@ static int CompareSCEVComplexity(
 
     // Lexicographically compare.
     for (unsigned i = 0; i != LNumOps; ++i) {
-      int X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI,
-                                    LA->getOperand(i), RA->getOperand(i), DT,
-                                    Depth + 1);
+      auto X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI,
+                                     LA->getOperand(i), RA->getOperand(i), DT,
+                                     Depth + 1);
       if (X != 0)
         return X;
     }
@@ -784,9 +790,9 @@ static int CompareSCEVComplexity(
       return (int)LNumOps - (int)RNumOps;
 
     for (unsigned i = 0; i != LNumOps; ++i) {
-      int X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI,
-                                    LC->getOperand(i), RC->getOperand(i), DT,
-                                    Depth + 1);
+      auto X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI,
+                                     LC->getOperand(i), RC->getOperand(i), DT,
+                                     Depth + 1);
       if (X != 0)
         return X;
     }
@@ -799,8 +805,8 @@ static int CompareSCEVComplexity(
     const SCEVUDivExpr *RC = cast<SCEVUDivExpr>(RHS);
 
     // Lexicographically compare udiv expressions.
-    int X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LC->getLHS(),
-                                  RC->getLHS(), DT, Depth + 1);
+    auto X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LC->getLHS(),
+                                   RC->getLHS(), DT, Depth + 1);
     if (X != 0)
       return X;
     X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LC->getRHS(),
@@ -818,9 +824,9 @@ static int CompareSCEVComplexity(
     const SCEVCastExpr *RC = cast<SCEVCastExpr>(RHS);
 
     // Compare cast expressions by operand.
-    int X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI,
-                                  LC->getOperand(), RC->getOperand(), DT,
-                                  Depth + 1);
+    auto X =
+        CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LC->getOperand(),
+                              RC->getOperand(), DT, Depth + 1);
     if (X == 0)
       EqCacheSCEV.unionSets(LHS, RHS);
     return X;
@@ -847,19 +853,25 @@ static void GroupByComplexity(SmallVectorImpl<const SCEV *> &Ops,
 
   EquivalenceClasses<const SCEV *> EqCacheSCEV;
   EquivalenceClasses<const Value *> EqCacheValue;
+
+  // Whether LHS has provably less complexity than RHS.
+  auto IsLessComplex = [&](const SCEV *LHS, const SCEV *RHS) {
+    auto Complexity =
+        CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LHS, RHS, DT);
+    return Complexity && *Complexity < 0;
+  };
   if (Ops.size() == 2) {
     // This is the common case, which also happens to be trivially simple.
     // Special case it.
     const SCEV *&LHS = Ops[0], *&RHS = Ops[1];
-    if (CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, RHS, LHS, DT) < 0)
+    if (IsLessComplex(RHS, LHS))
       std::swap(LHS, RHS);
     return;
   }
 
   // Do the rough sort by complexity.
   llvm::stable_sort(Ops, [&](const SCEV *LHS, const SCEV *RHS) {
-    return CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LHS, RHS, DT) <
-           0;
+    return IsLessComplex(LHS, RHS);
   });
 
   // Now that we are sorted by complexity, group elements of the same

diff  --git a/llvm/test/Transforms/LoopStrengthReduce/pr48725.ll b/llvm/test/Transforms/LoopStrengthReduce/pr48725.ll
new file mode 100644
index 000000000000..dfba28a38018
--- /dev/null
+++ b/llvm/test/Transforms/LoopStrengthReduce/pr48725.ll
@@ -0,0 +1,101 @@
+; RUN: opt -S -loop-reduce < %s | FileCheck %s
+
+source_filename = "./simple.ll"
+target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128-ni:1-p2:32:8:8:32-ni:2"
+target triple = "x86_64-unknown-linux-gnu"
+
+; CHECK-LABEL: test
+define void @test() {
+bb:
+  br label %bb1
+
+bb1:                                              ; preds = %bb1, %bb
+  %tmp = phi i32 [ undef, %bb ], [ %tmp87, %bb1 ]
+  %tmp2 = phi i32 [ undef, %bb ], [ %tmp86, %bb1 ]
+  %tmp3 = mul i32 %tmp, undef
+  %tmp4 = xor i32 %tmp3, -1
+  %tmp5 = add i32 %tmp, %tmp4
+  %tmp6 = add i32 %tmp2, -1
+  %tmp7 = add i32 %tmp5, %tmp6
+  %tmp8 = mul i32 %tmp7, %tmp3
+  %tmp9 = xor i32 %tmp8, -1
+  %tmp10 = add i32 %tmp7, %tmp9
+  %tmp11 = add i32 %tmp10, undef
+  %tmp12 = mul i32 %tmp11, %tmp8
+  %tmp13 = xor i32 %tmp12, -1
+  %tmp14 = add i32 %tmp11, %tmp13
+  %tmp15 = add i32 %tmp14, undef
+  %tmp16 = mul i32 %tmp15, %tmp12
+  %tmp17 = add i32 %tmp15, undef
+  %tmp18 = add i32 %tmp17, undef
+  %tmp19 = mul i32 %tmp18, %tmp16
+  %tmp20 = xor i32 %tmp19, -1
+  %tmp21 = add i32 %tmp18, %tmp20
+  %tmp22 = add i32 %tmp21, undef
+  %tmp23 = mul i32 %tmp22, %tmp19
+  %tmp24 = xor i32 %tmp23, -1
+  %tmp25 = add i32 %tmp22, %tmp24
+  %tmp26 = add i32 %tmp25, undef
+  %tmp27 = mul i32 %tmp26, %tmp23
+  %tmp28 = xor i32 %tmp27, -1
+  %tmp29 = add i32 %tmp26, %tmp28
+  %tmp30 = add i32 %tmp29, undef
+  %tmp31 = mul i32 %tmp30, %tmp27
+  %tmp32 = xor i32 %tmp31, -1
+  %tmp33 = add i32 %tmp30, %tmp32
+  %tmp34 = add i32 %tmp33, undef
+  %tmp35 = mul i32 %tmp34, %tmp31
+  %tmp36 = xor i32 %tmp35, -1
+  %tmp37 = add i32 %tmp34, %tmp36
+  %tmp38 = add i32 %tmp2, -9
+  %tmp39 = add i32 %tmp37, %tmp38
+  %tmp40 = mul i32 %tmp39, %tmp35
+  %tmp41 = xor i32 %tmp40, -1
+  %tmp42 = add i32 %tmp39, %tmp41
+  %tmp43 = add i32 %tmp42, undef
+  %tmp44 = mul i32 %tmp43, %tmp40
+  %tmp45 = xor i32 %tmp44, -1
+  %tmp46 = add i32 %tmp43, %tmp45
+  %tmp47 = add i32 %tmp46, undef
+  %tmp48 = mul i32 %tmp47, %tmp44
+  %tmp49 = xor i32 %tmp48, -1
+  %tmp50 = add i32 %tmp47, %tmp49
+  %tmp51 = add i32 %tmp50, undef
+  %tmp52 = mul i32 %tmp51, %tmp48
+  %tmp53 = xor i32 %tmp52, -1
+  %tmp54 = add i32 %tmp51, %tmp53
+  %tmp55 = add i32 %tmp54, undef
+  %tmp56 = mul i32 %tmp55, %tmp52
+  %tmp57 = xor i32 %tmp56, -1
+  %tmp58 = add i32 %tmp55, %tmp57
+  %tmp59 = add i32 %tmp2, -14
+  %tmp60 = add i32 %tmp58, %tmp59
+  %tmp61 = mul i32 %tmp60, %tmp56
+  %tmp62 = xor i32 %tmp61, -1
+  %tmp63 = add i32 %tmp60, %tmp62
+  %tmp64 = add i32 %tmp63, undef
+  %tmp65 = mul i32 %tmp64, %tmp61
+  %tmp66 = xor i32 %tmp65, -1
+  %tmp67 = add i32 %tmp64, %tmp66
+  %tmp68 = add i32 %tmp67, undef
+  %tmp69 = mul i32 %tmp68, %tmp65
+  %tmp70 = xor i32 %tmp69, -1
+  %tmp71 = add i32 %tmp68, %tmp70
+  %tmp72 = add i32 %tmp71, undef
+  %tmp73 = mul i32 %tmp72, %tmp69
+  %tmp74 = xor i32 %tmp73, -1
+  %tmp75 = add i32 %tmp72, %tmp74
+  %tmp76 = add i32 %tmp75, undef
+  %tmp77 = mul i32 %tmp76, %tmp73
+  %tmp78 = xor i32 %tmp77, -1
+  %tmp79 = add i32 %tmp76, %tmp78
+  %tmp80 = add i32 %tmp79, undef
+  %tmp81 = mul i32 %tmp80, %tmp77
+  %tmp82 = xor i32 %tmp81, -1
+  %tmp83 = add i32 %tmp80, %tmp82
+  %tmp84 = add i32 %tmp83, undef
+  %tmp85 = add i32 %tmp84, undef
+  %tmp86 = add i32 %tmp2, -21
+  %tmp87 = add i32 %tmp85, %tmp86
+  br label %bb1
+}


        


More information about the llvm-commits mailing list