[llvm] r285535 - [SCEV] Try to order n-ary expressions in CompareValueComplexity

Sanjoy Das via llvm-commits llvm-commits at lists.llvm.org
Sun Oct 30 20:32:43 PDT 2016


Author: sanjoy
Date: Sun Oct 30 22:32:43 2016
New Revision: 285535

URL: http://llvm.org/viewvc/llvm-project?rev=285535&view=rev
Log:
[SCEV] Try to order n-ary expressions in CompareValueComplexity

Modified:
    llvm/trunk/lib/Analysis/ScalarEvolution.cpp
    llvm/trunk/unittests/Analysis/ScalarEvolutionTest.cpp

Modified: llvm/trunk/lib/Analysis/ScalarEvolution.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Analysis/ScalarEvolution.cpp?rev=285535&r1=285534&r2=285535&view=diff
==============================================================================
--- llvm/trunk/lib/Analysis/ScalarEvolution.cpp (original)
+++ llvm/trunk/lib/Analysis/ScalarEvolution.cpp Sun Oct 30 22:32:43 2016
@@ -62,6 +62,7 @@
 #include "llvm/ADT/Optional.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/ScopeExit.h"
+#include "llvm/ADT/Sequence.h"
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/ADT/Statistic.h"
 #include "llvm/Analysis/AssumptionCache.h"
@@ -453,9 +454,29 @@ bool SCEVUnknown::isOffsetOf(Type *&CTy,
 //                               SCEV Utilities
 //===----------------------------------------------------------------------===//
 
-static int CompareValueComplexity(const LoopInfo *const LI, Value *LV,
-                                  Value *RV, unsigned DepthLeft = 2) {
-  if (DepthLeft == 0)
+/// Compare the two values \p LV and \p RV in terms of their "complexity" where
+/// "complexity" is a partial (and somewhat ad-hoc) relation used to order
+/// operands in SCEV expressions.  \p EqCache is a set of pairs of values that
+/// have been previously deemed to be "equally complex" by this routine.  It is
+/// intended to avoid exponential time complexity in cases like:
+///
+///   %a = f(%x, %y)
+///   %b = f(%a, %a)
+///   %c = f(%b, %b)
+///
+///   %d = f(%x, %y)
+///   %e = f(%d, %d)
+///   %f = f(%e, %e)
+///
+///   CompareValueComplexity(%f, %c)
+///
+/// Since we do not continue running this routine on expression trees once we
+/// have seen unequal values, there is no need to track them in the cache.
+static int
+CompareValueComplexity(SmallSet<std::pair<Value *, Value *>, 8> &EqCache,
+                       const LoopInfo *const LI, Value *LV, Value *RV,
+                       unsigned DepthLeft = 2) {
+  if (DepthLeft == 0 || EqCache.count({LV, RV}))
     return 0;
 
   // Order pointer values after integer values. This helps SCEVExpander form
@@ -510,14 +531,17 @@ static int CompareValueComplexity(const
     // Compare the number of operands.
     unsigned LNumOps = LInst->getNumOperands(),
              RNumOps = RInst->getNumOperands();
-    if (LNumOps != RNumOps || LNumOps != 1)
+    if (LNumOps != RNumOps)
       return (int)LNumOps - (int)RNumOps;
 
-    // We only bother "recursing" if we have one operand to look at (so we don't
-    // really recurse as much as we iterate).  We can consider expanding this
-    // logic in the future.
-    return CompareValueComplexity(LI, LInst->getOperand(0),
-                                  RInst->getOperand(0), DepthLeft - 1);
+    for (unsigned Idx : seq(0u, LNumOps)) {
+      int Result =
+          CompareValueComplexity(EqCache, LI, LInst->getOperand(Idx),
+                                 RInst->getOperand(Idx), DepthLeft - 1);
+      if (Result != 0)
+        return Result;
+    }
+    EqCache.insert({LV, RV}); // Cache only after *all* operands matched.
   }
 
   return 0;
@@ -545,7 +569,8 @@ static int CompareSCEVComplexity(const L
     const SCEVUnknown *LU = cast<SCEVUnknown>(LHS);
     const SCEVUnknown *RU = cast<SCEVUnknown>(RHS);
 
-    return CompareValueComplexity(LI, LU->getValue(), RU->getValue());
+    SmallSet<std::pair<Value *, Value *>, 8> EqCache;
+    return CompareValueComplexity(EqCache, LI, LU->getValue(), RU->getValue());
   }
 
   case scConstant: {

Modified: llvm/trunk/unittests/Analysis/ScalarEvolutionTest.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/unittests/Analysis/ScalarEvolutionTest.cpp?rev=285535&r1=285534&r2=285535&view=diff
==============================================================================
--- llvm/trunk/unittests/Analysis/ScalarEvolutionTest.cpp (original)
+++ llvm/trunk/unittests/Analysis/ScalarEvolutionTest.cpp Sun Oct 30 22:32:43 2016
@@ -349,6 +349,8 @@ TEST_F(ScalarEvolutionsTest, Commutative
       "@var_0 = external global i32, align 4"
       "@var_1 = external global i32, align 4"
       " "
+      "declare i32 @unknown(i32, i32, i32)"
+      " "
       "define void @f_1(i8* nocapture %arr, i32 %n, i32* %A, i32* %B) "
       "    local_unnamed_addr { "
       "entry: "
@@ -391,6 +393,13 @@ TEST_F(ScalarEvolutionsTest, Commutative
       "  %y = load i32, i32* @var_1"
       "  ret void"
       "} "
+      " "
+      "define void @f_4(i32 %a, i32 %b, i32 %c) { "
+      "  %x = call i32 @unknown(i32 %a, i32 %b, i32 %c)"
+      "  %y = call i32 @unknown(i32 %b, i32 %c, i32 %a)"
+      "  %z = call i32 @unknown(i32 %c, i32 %a, i32 %b)"
+      "  ret void"
+      "} "
       ,
       Err, C);
 
@@ -419,22 +428,18 @@ TEST_F(ScalarEvolutionsTest, Commutative
     EXPECT_TRUE(isa<SCEVAddRecExpr>(SecondExprForIV0));
   });
 
-  RunWithFunctionAndSE("f_2", [&](Function &F, ScalarEvolution &SE) {
-    auto *LoadArg0 = SE.getSCEV(getInstructionByName(F, "x"));
-    auto *LoadArg1 = SE.getSCEV(getInstructionByName(F, "y"));
-    auto *LoadArg2 = SE.getSCEV(getInstructionByName(F, "z"));
-
-    auto *MulA = SE.getMulExpr(LoadArg0, LoadArg1);
-    auto *MulB = SE.getMulExpr(LoadArg1, LoadArg0);
-
-    EXPECT_EQ(MulA, MulB);
-
-    SmallVector<const SCEV *, 3> Ops0 = {LoadArg0, LoadArg1, LoadArg2};
-    SmallVector<const SCEV *, 3> Ops1 = {LoadArg0, LoadArg2, LoadArg1};
-    SmallVector<const SCEV *, 3> Ops2 = {LoadArg1, LoadArg0, LoadArg2};
-    SmallVector<const SCEV *, 3> Ops3 = {LoadArg1, LoadArg2, LoadArg0};
-    SmallVector<const SCEV *, 3> Ops4 = {LoadArg2, LoadArg1, LoadArg0};
-    SmallVector<const SCEV *, 3> Ops5 = {LoadArg2, LoadArg0, LoadArg1};
+  auto CheckCommutativeMulExprs = [&](ScalarEvolution &SE, const SCEV *A,
+                                      const SCEV *B, const SCEV *C) {
+    EXPECT_EQ(SE.getMulExpr(A, B), SE.getMulExpr(B, A));
+    EXPECT_EQ(SE.getMulExpr(B, C), SE.getMulExpr(C, B));
+    EXPECT_EQ(SE.getMulExpr(A, C), SE.getMulExpr(C, A));
+
+    SmallVector<const SCEV *, 3> Ops0 = {A, B, C};
+    SmallVector<const SCEV *, 3> Ops1 = {A, C, B};
+    SmallVector<const SCEV *, 3> Ops2 = {B, A, C};
+    SmallVector<const SCEV *, 3> Ops3 = {B, C, A};
+    SmallVector<const SCEV *, 3> Ops4 = {C, B, A};
+    SmallVector<const SCEV *, 3> Ops5 = {C, A, B};
 
     auto *Mul0 = SE.getMulExpr(Ops0);
     auto *Mul1 = SE.getMulExpr(Ops1);
@@ -443,11 +448,17 @@ TEST_F(ScalarEvolutionsTest, Commutative
     auto *Mul4 = SE.getMulExpr(Ops4);
     auto *Mul5 = SE.getMulExpr(Ops5);
 
-    EXPECT_EQ(Mul0, Mul1);
-    EXPECT_EQ(Mul1, Mul2);
-    EXPECT_EQ(Mul2, Mul3);
-    EXPECT_EQ(Mul3, Mul4);
-    EXPECT_EQ(Mul4, Mul5);
+    EXPECT_EQ(Mul0, Mul1) << "Expected " << *Mul0 << " == " << *Mul1;
+    EXPECT_EQ(Mul1, Mul2) << "Expected " << *Mul1 << " == " << *Mul2;
+    EXPECT_EQ(Mul2, Mul3) << "Expected " << *Mul2 << " == " << *Mul3;
+    EXPECT_EQ(Mul3, Mul4) << "Expected " << *Mul3 << " == " << *Mul4;
+    EXPECT_EQ(Mul4, Mul5) << "Expected " << *Mul4 << " == " << *Mul5;
+  };
+
+  RunWithFunctionAndSE("f_2", [&](Function &F, ScalarEvolution &SE) {
+    CheckCommutativeMulExprs(SE, SE.getSCEV(getInstructionByName(F, "x")),
+                             SE.getSCEV(getInstructionByName(F, "y")),
+                             SE.getSCEV(getInstructionByName(F, "z")));
   });
 
   RunWithFunctionAndSE("f_3", [&](Function &F, ScalarEvolution &SE) {
@@ -459,6 +470,12 @@ TEST_F(ScalarEvolutionsTest, Commutative
 
     EXPECT_EQ(MulA, MulB) << "MulA = " << *MulA << ", MulB = " << *MulB;
   });
+
+  RunWithFunctionAndSE("f_4", [&](Function &F, ScalarEvolution &SE) {
+    CheckCommutativeMulExprs(SE, SE.getSCEV(getInstructionByName(F, "x")),
+                             SE.getSCEV(getInstructionByName(F, "y")),
+                             SE.getSCEV(getInstructionByName(F, "z")));
+  });
 }
 
 }  // end anonymous namespace




More information about the llvm-commits mailing list