[llvm] r285535 - [SCEV] Try to order n-ary expressions in CompareValueComplexity
Sanjoy Das via llvm-commits
llvm-commits at lists.llvm.org
Sun Oct 30 20:32:43 PDT 2016
Author: sanjoy
Date: Sun Oct 30 22:32:43 2016
New Revision: 285535
URL: http://llvm.org/viewvc/llvm-project?rev=285535&view=rev
Log:
[SCEV] Try to order n-ary expressions in CompareValueComplexity
Modified:
llvm/trunk/lib/Analysis/ScalarEvolution.cpp
llvm/trunk/unittests/Analysis/ScalarEvolutionTest.cpp
Modified: llvm/trunk/lib/Analysis/ScalarEvolution.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Analysis/ScalarEvolution.cpp?rev=285535&r1=285534&r2=285535&view=diff
==============================================================================
--- llvm/trunk/lib/Analysis/ScalarEvolution.cpp (original)
+++ llvm/trunk/lib/Analysis/ScalarEvolution.cpp Sun Oct 30 22:32:43 2016
@@ -62,6 +62,7 @@
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
+#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/AssumptionCache.h"
@@ -453,9 +454,29 @@ bool SCEVUnknown::isOffsetOf(Type *&CTy,
// SCEV Utilities
//===----------------------------------------------------------------------===//
-static int CompareValueComplexity(const LoopInfo *const LI, Value *LV,
- Value *RV, unsigned DepthLeft = 2) {
- if (DepthLeft == 0)
+/// Compare the two values \p LV and \p RV in terms of their "complexity" where
+/// "complexity" is a partial (and somewhat ad-hoc) relation used to order
+/// operands in SCEV expressions. \p EqCache is a set of pairs of values that
+/// have been previously deemed to be "equally complex" by this routine. It is
+/// intended to avoid exponential time complexity in cases like:
+///
+/// %a = f(%x, %y)
+/// %b = f(%a, %a)
+/// %c = f(%b, %b)
+///
+/// %d = f(%x, %y)
+/// %e = f(%d, %d)
+/// %f = f(%e, %e)
+///
+/// CompareValueComplexity(%f, %c)
+///
+/// Since we do not continue running this routine on expression trees once we
+/// have seen unequal values, there is no need to track them in the cache.
+static int
+CompareValueComplexity(SmallSet<std::pair<Value *, Value *>, 8> &EqCache,
+ const LoopInfo *const LI, Value *LV, Value *RV,
+ unsigned DepthLeft = 2) {
+ if (DepthLeft == 0 || EqCache.count({LV, RV}))
return 0;
// Order pointer values after integer values. This helps SCEVExpander form
@@ -510,14 +531,17 @@ static int CompareValueComplexity(const
// Compare the number of operands.
unsigned LNumOps = LInst->getNumOperands(),
RNumOps = RInst->getNumOperands();
- if (LNumOps != RNumOps || LNumOps != 1)
+ if (LNumOps != RNumOps)
return (int)LNumOps - (int)RNumOps;
- // We only bother "recursing" if we have one operand to look at (so we don't
- // really recurse as much as we iterate). We can consider expanding this
- // logic in the future.
- return CompareValueComplexity(LI, LInst->getOperand(0),
- RInst->getOperand(0), DepthLeft - 1);
+ for (unsigned Idx : seq(0u, LNumOps)) {
+ int Result =
+ CompareValueComplexity(EqCache, LI, LInst->getOperand(Idx),
+ RInst->getOperand(Idx), DepthLeft - 1);
+ if (Result != 0)
+ return Result;
+ EqCache.insert({LV, RV});
+ }
}
return 0;
@@ -545,7 +569,8 @@ static int CompareSCEVComplexity(const L
const SCEVUnknown *LU = cast<SCEVUnknown>(LHS);
const SCEVUnknown *RU = cast<SCEVUnknown>(RHS);
- return CompareValueComplexity(LI, LU->getValue(), RU->getValue());
+ SmallSet<std::pair<Value *, Value *>, 8> EqCache;
+ return CompareValueComplexity(EqCache, LI, LU->getValue(), RU->getValue());
}
case scConstant: {
Modified: llvm/trunk/unittests/Analysis/ScalarEvolutionTest.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/unittests/Analysis/ScalarEvolutionTest.cpp?rev=285535&r1=285534&r2=285535&view=diff
==============================================================================
--- llvm/trunk/unittests/Analysis/ScalarEvolutionTest.cpp (original)
+++ llvm/trunk/unittests/Analysis/ScalarEvolutionTest.cpp Sun Oct 30 22:32:43 2016
@@ -349,6 +349,8 @@ TEST_F(ScalarEvolutionsTest, Commutative
"@var_0 = external global i32, align 4"
"@var_1 = external global i32, align 4"
" "
+ "declare i32 @unknown(i32, i32, i32)"
+ " "
"define void @f_1(i8* nocapture %arr, i32 %n, i32* %A, i32* %B) "
" local_unnamed_addr { "
"entry: "
@@ -391,6 +393,13 @@ TEST_F(ScalarEvolutionsTest, Commutative
" %y = load i32, i32* @var_1"
" ret void"
"} "
+ " "
+ "define void @f_4(i32 %a, i32 %b, i32 %c) { "
+ " %x = call i32 @unknown(i32 %a, i32 %b, i32 %c)"
+ " %y = call i32 @unknown(i32 %b, i32 %c, i32 %a)"
+ " %z = call i32 @unknown(i32 %c, i32 %a, i32 %b)"
+ " ret void"
+ "} "
,
Err, C);
@@ -419,22 +428,18 @@ TEST_F(ScalarEvolutionsTest, Commutative
EXPECT_TRUE(isa<SCEVAddRecExpr>(SecondExprForIV0));
});
- RunWithFunctionAndSE("f_2", [&](Function &F, ScalarEvolution &SE) {
- auto *LoadArg0 = SE.getSCEV(getInstructionByName(F, "x"));
- auto *LoadArg1 = SE.getSCEV(getInstructionByName(F, "y"));
- auto *LoadArg2 = SE.getSCEV(getInstructionByName(F, "z"));
-
- auto *MulA = SE.getMulExpr(LoadArg0, LoadArg1);
- auto *MulB = SE.getMulExpr(LoadArg1, LoadArg0);
-
- EXPECT_EQ(MulA, MulB);
-
- SmallVector<const SCEV *, 3> Ops0 = {LoadArg0, LoadArg1, LoadArg2};
- SmallVector<const SCEV *, 3> Ops1 = {LoadArg0, LoadArg2, LoadArg1};
- SmallVector<const SCEV *, 3> Ops2 = {LoadArg1, LoadArg0, LoadArg2};
- SmallVector<const SCEV *, 3> Ops3 = {LoadArg1, LoadArg2, LoadArg0};
- SmallVector<const SCEV *, 3> Ops4 = {LoadArg2, LoadArg1, LoadArg0};
- SmallVector<const SCEV *, 3> Ops5 = {LoadArg2, LoadArg0, LoadArg1};
+ auto CheckCommutativeMulExprs = [&](ScalarEvolution &SE, const SCEV *A,
+ const SCEV *B, const SCEV *C) {
+ EXPECT_EQ(SE.getMulExpr(A, B), SE.getMulExpr(B, A));
+ EXPECT_EQ(SE.getMulExpr(B, C), SE.getMulExpr(C, B));
+ EXPECT_EQ(SE.getMulExpr(A, C), SE.getMulExpr(C, A));
+
+ SmallVector<const SCEV *, 3> Ops0 = {A, B, C};
+ SmallVector<const SCEV *, 3> Ops1 = {A, C, B};
+ SmallVector<const SCEV *, 3> Ops2 = {B, A, C};
+ SmallVector<const SCEV *, 3> Ops3 = {B, C, A};
+ SmallVector<const SCEV *, 3> Ops4 = {C, B, A};
+ SmallVector<const SCEV *, 3> Ops5 = {C, A, B};
auto *Mul0 = SE.getMulExpr(Ops0);
auto *Mul1 = SE.getMulExpr(Ops1);
@@ -443,11 +448,17 @@ TEST_F(ScalarEvolutionsTest, Commutative
auto *Mul4 = SE.getMulExpr(Ops4);
auto *Mul5 = SE.getMulExpr(Ops5);
- EXPECT_EQ(Mul0, Mul1);
- EXPECT_EQ(Mul1, Mul2);
- EXPECT_EQ(Mul2, Mul3);
- EXPECT_EQ(Mul3, Mul4);
- EXPECT_EQ(Mul4, Mul5);
+ EXPECT_EQ(Mul0, Mul1) << "Expected " << *Mul0 << " == " << *Mul1;
+ EXPECT_EQ(Mul1, Mul2) << "Expected " << *Mul1 << " == " << *Mul2;
+ EXPECT_EQ(Mul2, Mul3) << "Expected " << *Mul2 << " == " << *Mul3;
+ EXPECT_EQ(Mul3, Mul4) << "Expected " << *Mul3 << " == " << *Mul4;
+ EXPECT_EQ(Mul4, Mul5) << "Expected " << *Mul4 << " == " << *Mul5;
+ };
+
+ RunWithFunctionAndSE("f_2", [&](Function &F, ScalarEvolution &SE) {
+ CheckCommutativeMulExprs(SE, SE.getSCEV(getInstructionByName(F, "x")),
+ SE.getSCEV(getInstructionByName(F, "y")),
+ SE.getSCEV(getInstructionByName(F, "z")));
});
RunWithFunctionAndSE("f_3", [&](Function &F, ScalarEvolution &SE) {
@@ -459,6 +470,12 @@ TEST_F(ScalarEvolutionsTest, Commutative
EXPECT_EQ(MulA, MulB) << "MulA = " << *MulA << ", MulB = " << *MulB;
});
+
+ RunWithFunctionAndSE("f_4", [&](Function &F, ScalarEvolution &SE) {
+ CheckCommutativeMulExprs(SE, SE.getSCEV(getInstructionByName(F, "x")),
+ SE.getSCEV(getInstructionByName(F, "y")),
+ SE.getSCEV(getInstructionByName(F, "z")));
+ });
}
} // end anonymous namespace
More information about the llvm-commits mailing list