[llvm] 8a4ad88 - [SCEV] Do not cache comparison result upon reached max depth as "equivalence". PR48725
Max Kazantsev via llvm-commits
llvm-commits at lists.llvm.org
Thu Jan 28 21:36:42 PST 2021
Author: Max Kazantsev
Date: 2021-01-29T12:08:34+07:00
New Revision: 8a4ad8849f4898dd19e31b9dcede7ace3575d00d
URL: https://github.com/llvm/llvm-project/commit/8a4ad8849f4898dd19e31b9dcede7ace3575d00d
DIFF: https://github.com/llvm/llvm-project/commit/8a4ad8849f4898dd19e31b9dcede7ace3575d00d.diff
LOG: [SCEV] Do not cache comparison result upon reached max depth as "equivalence". PR48725
We use `EquivalenceClasses` to cache the notion that two SCEVs are equivalent,
to save time in situations where `A` is equivalent to `B` and `B` is equivalent to `C`,
making the check "is `A` equivalent to `C`?" cheaper.
We also return `0` in the comparator when we reach max analysis depth to save
compile time. When this happens, we also cache the two compared SCEVs as being equivalent.
Now, imagine the following situation:
- `A` is proved equivalent to `B`;
- `C` is proved equivalent to `D`;
- Comparison of `A` against `D` is proved non-zero;
- Comparison of `B` against `C` reaches max depth (and gets cached as equivalence).
Now, before the invocation of compare(`B`, `C`), `A` and `D` belonged
to different equivalence classes, and their comparison returned non-zero.
After the invocation of compare(`B`, `C`), the equivalence classes get merged
and `A`, `B`, `C` and `D` all fall into the same equivalence class. So the comparator
will change its behavior for the pair `A` and `D`, with weird consequences following it.
This comparator is finally used in `std::stable_sort`, and this behavior change
makes it crash (looks like it's causing a memory corruption).
Solution: this patch changes `CompareSCEVComplexity` to return `None`
when the max depth is reached. So in this case, we do not cache these SCEVs
(and their parents in the tree) as being equivalent.
Differential Revision: https://reviews.llvm.org/D94654
Reviewed By: lebedev.ri
Added:
llvm/test/Transforms/LoopStrengthReduce/pr48725.ll
Modified:
llvm/lib/Analysis/ScalarEvolution.cpp
Removed:
################################################################################
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 0b575ae421fc..134d4d0fc8c5 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -689,11 +689,13 @@ CompareValueComplexity(EquivalenceClasses<const Value *> &EqCacheValue,
// Return negative, zero, or positive, if LHS is less than, equal to, or greater
// than RHS, respectively. A three-way result allows recursive comparisons to be
// more efficient.
-static int CompareSCEVComplexity(
- EquivalenceClasses<const SCEV *> &EqCacheSCEV,
- EquivalenceClasses<const Value *> &EqCacheValue,
- const LoopInfo *const LI, const SCEV *LHS, const SCEV *RHS,
- DominatorTree &DT, unsigned Depth = 0) {
+// If the max analysis depth was reached, return None, assuming we do not know
+// if they are equivalent for sure.
+static Optional<int>
+CompareSCEVComplexity(EquivalenceClasses<const SCEV *> &EqCacheSCEV,
+ EquivalenceClasses<const Value *> &EqCacheValue,
+ const LoopInfo *const LI, const SCEV *LHS,
+ const SCEV *RHS, DominatorTree &DT, unsigned Depth = 0) {
// Fast-path: SCEVs are uniqued so we can do a quick equality check.
if (LHS == RHS)
return 0;
@@ -703,8 +705,12 @@ static int CompareSCEVComplexity(
if (LType != RType)
return (int)LType - (int)RType;
- if (Depth > MaxSCEVCompareDepth || EqCacheSCEV.isEquivalent(LHS, RHS))
+ if (EqCacheSCEV.isEquivalent(LHS, RHS))
return 0;
+
+ if (Depth > MaxSCEVCompareDepth)
+ return None;
+
// Aside from the getSCEVType() ordering, the particular ordering
// isn't very important except that it's beneficial to be consistent,
// so that (a + b) and (b + a) don't end up as different expressions.
@@ -759,9 +765,9 @@ static int CompareSCEVComplexity(
// Lexicographically compare.
for (unsigned i = 0; i != LNumOps; ++i) {
- int X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI,
- LA->getOperand(i), RA->getOperand(i), DT,
- Depth + 1);
+ auto X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI,
+ LA->getOperand(i), RA->getOperand(i), DT,
+ Depth + 1);
if (X != 0)
return X;
}
@@ -784,9 +790,9 @@ static int CompareSCEVComplexity(
return (int)LNumOps - (int)RNumOps;
for (unsigned i = 0; i != LNumOps; ++i) {
- int X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI,
- LC->getOperand(i), RC->getOperand(i), DT,
- Depth + 1);
+ auto X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI,
+ LC->getOperand(i), RC->getOperand(i), DT,
+ Depth + 1);
if (X != 0)
return X;
}
@@ -799,8 +805,8 @@ static int CompareSCEVComplexity(
const SCEVUDivExpr *RC = cast<SCEVUDivExpr>(RHS);
// Lexicographically compare udiv expressions.
- int X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LC->getLHS(),
- RC->getLHS(), DT, Depth + 1);
+ auto X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LC->getLHS(),
+ RC->getLHS(), DT, Depth + 1);
if (X != 0)
return X;
X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LC->getRHS(),
@@ -818,9 +824,9 @@ static int CompareSCEVComplexity(
const SCEVCastExpr *RC = cast<SCEVCastExpr>(RHS);
// Compare cast expressions by operand.
- int X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI,
- LC->getOperand(), RC->getOperand(), DT,
- Depth + 1);
+ auto X =
+ CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LC->getOperand(),
+ RC->getOperand(), DT, Depth + 1);
if (X == 0)
EqCacheSCEV.unionSets(LHS, RHS);
return X;
@@ -847,19 +853,25 @@ static void GroupByComplexity(SmallVectorImpl<const SCEV *> &Ops,
EquivalenceClasses<const SCEV *> EqCacheSCEV;
EquivalenceClasses<const Value *> EqCacheValue;
+
+ // Whether LHS has provably less complexity than RHS.
+ auto IsLessComplex = [&](const SCEV *LHS, const SCEV *RHS) {
+ auto Complexity =
+ CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LHS, RHS, DT);
+ return Complexity && *Complexity < 0;
+ };
if (Ops.size() == 2) {
// This is the common case, which also happens to be trivially simple.
// Special case it.
const SCEV *&LHS = Ops[0], *&RHS = Ops[1];
- if (CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, RHS, LHS, DT) < 0)
+ if (IsLessComplex(RHS, LHS))
std::swap(LHS, RHS);
return;
}
// Do the rough sort by complexity.
llvm::stable_sort(Ops, [&](const SCEV *LHS, const SCEV *RHS) {
- return CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LHS, RHS, DT) <
- 0;
+ return IsLessComplex(LHS, RHS);
});
// Now that we are sorted by complexity, group elements of the same
diff --git a/llvm/test/Transforms/LoopStrengthReduce/pr48725.ll b/llvm/test/Transforms/LoopStrengthReduce/pr48725.ll
new file mode 100644
index 000000000000..dfba28a38018
--- /dev/null
+++ b/llvm/test/Transforms/LoopStrengthReduce/pr48725.ll
@@ -0,0 +1,101 @@
+; RUN: opt -S -loop-reduce < %s | FileCheck %s
+
+source_filename = "./simple.ll"
+target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128-ni:1-p2:32:8:8:32-ni:2"
+target triple = "x86_64-unknown-linux-gnu"
+
+; CHECK-LABEL: test
+define void @test() {
+bb:
+ br label %bb1
+
+bb1: ; preds = %bb1, %bb
+ %tmp = phi i32 [ undef, %bb ], [ %tmp87, %bb1 ]
+ %tmp2 = phi i32 [ undef, %bb ], [ %tmp86, %bb1 ]
+ %tmp3 = mul i32 %tmp, undef
+ %tmp4 = xor i32 %tmp3, -1
+ %tmp5 = add i32 %tmp, %tmp4
+ %tmp6 = add i32 %tmp2, -1
+ %tmp7 = add i32 %tmp5, %tmp6
+ %tmp8 = mul i32 %tmp7, %tmp3
+ %tmp9 = xor i32 %tmp8, -1
+ %tmp10 = add i32 %tmp7, %tmp9
+ %tmp11 = add i32 %tmp10, undef
+ %tmp12 = mul i32 %tmp11, %tmp8
+ %tmp13 = xor i32 %tmp12, -1
+ %tmp14 = add i32 %tmp11, %tmp13
+ %tmp15 = add i32 %tmp14, undef
+ %tmp16 = mul i32 %tmp15, %tmp12
+ %tmp17 = add i32 %tmp15, undef
+ %tmp18 = add i32 %tmp17, undef
+ %tmp19 = mul i32 %tmp18, %tmp16
+ %tmp20 = xor i32 %tmp19, -1
+ %tmp21 = add i32 %tmp18, %tmp20
+ %tmp22 = add i32 %tmp21, undef
+ %tmp23 = mul i32 %tmp22, %tmp19
+ %tmp24 = xor i32 %tmp23, -1
+ %tmp25 = add i32 %tmp22, %tmp24
+ %tmp26 = add i32 %tmp25, undef
+ %tmp27 = mul i32 %tmp26, %tmp23
+ %tmp28 = xor i32 %tmp27, -1
+ %tmp29 = add i32 %tmp26, %tmp28
+ %tmp30 = add i32 %tmp29, undef
+ %tmp31 = mul i32 %tmp30, %tmp27
+ %tmp32 = xor i32 %tmp31, -1
+ %tmp33 = add i32 %tmp30, %tmp32
+ %tmp34 = add i32 %tmp33, undef
+ %tmp35 = mul i32 %tmp34, %tmp31
+ %tmp36 = xor i32 %tmp35, -1
+ %tmp37 = add i32 %tmp34, %tmp36
+ %tmp38 = add i32 %tmp2, -9
+ %tmp39 = add i32 %tmp37, %tmp38
+ %tmp40 = mul i32 %tmp39, %tmp35
+ %tmp41 = xor i32 %tmp40, -1
+ %tmp42 = add i32 %tmp39, %tmp41
+ %tmp43 = add i32 %tmp42, undef
+ %tmp44 = mul i32 %tmp43, %tmp40
+ %tmp45 = xor i32 %tmp44, -1
+ %tmp46 = add i32 %tmp43, %tmp45
+ %tmp47 = add i32 %tmp46, undef
+ %tmp48 = mul i32 %tmp47, %tmp44
+ %tmp49 = xor i32 %tmp48, -1
+ %tmp50 = add i32 %tmp47, %tmp49
+ %tmp51 = add i32 %tmp50, undef
+ %tmp52 = mul i32 %tmp51, %tmp48
+ %tmp53 = xor i32 %tmp52, -1
+ %tmp54 = add i32 %tmp51, %tmp53
+ %tmp55 = add i32 %tmp54, undef
+ %tmp56 = mul i32 %tmp55, %tmp52
+ %tmp57 = xor i32 %tmp56, -1
+ %tmp58 = add i32 %tmp55, %tmp57
+ %tmp59 = add i32 %tmp2, -14
+ %tmp60 = add i32 %tmp58, %tmp59
+ %tmp61 = mul i32 %tmp60, %tmp56
+ %tmp62 = xor i32 %tmp61, -1
+ %tmp63 = add i32 %tmp60, %tmp62
+ %tmp64 = add i32 %tmp63, undef
+ %tmp65 = mul i32 %tmp64, %tmp61
+ %tmp66 = xor i32 %tmp65, -1
+ %tmp67 = add i32 %tmp64, %tmp66
+ %tmp68 = add i32 %tmp67, undef
+ %tmp69 = mul i32 %tmp68, %tmp65
+ %tmp70 = xor i32 %tmp69, -1
+ %tmp71 = add i32 %tmp68, %tmp70
+ %tmp72 = add i32 %tmp71, undef
+ %tmp73 = mul i32 %tmp72, %tmp69
+ %tmp74 = xor i32 %tmp73, -1
+ %tmp75 = add i32 %tmp72, %tmp74
+ %tmp76 = add i32 %tmp75, undef
+ %tmp77 = mul i32 %tmp76, %tmp73
+ %tmp78 = xor i32 %tmp77, -1
+ %tmp79 = add i32 %tmp76, %tmp78
+ %tmp80 = add i32 %tmp79, undef
+ %tmp81 = mul i32 %tmp80, %tmp77
+ %tmp82 = xor i32 %tmp81, -1
+ %tmp83 = add i32 %tmp80, %tmp82
+ %tmp84 = add i32 %tmp83, undef
+ %tmp85 = add i32 %tmp84, undef
+ %tmp86 = add i32 %tmp2, -21
+ %tmp87 = add i32 %tmp85, %tmp86
+ br label %bb1
+}
More information about the llvm-commits
mailing list