[llvm] r320515 - Reassociate: add global reassociation algorithm

Fiona Glaser via llvm-commits llvm-commits at lists.llvm.org
Tue Dec 12 11:18:02 PST 2017


Author: escha
Date: Tue Dec 12 11:18:02 2017
New Revision: 320515

URL: http://llvm.org/viewvc/llvm-project?rev=320515&view=rev
Log:
Reassociate: add global reassociation algorithm

This algorithm (explained more in the source code) takes into account
global redundancies by building a "pair map" to find common subexprs.

The primary motivation of this is to handle situations like

foo = (a * b) * c
bar = (a * d) * c

where we currently don't identify that "a * c" is redundant.

Accordingly, it prioritizes the emission of a * c so that CSE
can remove the redundant calculation later.

Does not change the actual reassociation algorithm -- only the
order in which the reassociated operand chain is reconstructed.

Gives ~1.5% floating point math instruction count reduction on
a large offline suite of graphics shaders.

Modified:
    llvm/trunk/include/llvm/Transforms/Scalar/Reassociate.h
    llvm/trunk/lib/Transforms/Scalar/Reassociate.cpp
    llvm/trunk/test/Transforms/Reassociate/basictest.ll
    llvm/trunk/test/Transforms/Reassociate/fast-ReassociateVector.ll
    llvm/trunk/test/Transforms/Reassociate/fast-fp-commute.ll

Modified: llvm/trunk/include/llvm/Transforms/Scalar/Reassociate.h
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/include/llvm/Transforms/Scalar/Reassociate.h?rev=320515&r1=320514&r2=320515&view=diff
==============================================================================
--- llvm/trunk/include/llvm/Transforms/Scalar/Reassociate.h (original)
+++ llvm/trunk/include/llvm/Transforms/Scalar/Reassociate.h Tue Dec 12 11:18:02 2017
@@ -72,6 +72,13 @@ class ReassociatePass : public PassInfoM
   DenseMap<BasicBlock *, unsigned> RankMap;
   DenseMap<AssertingVH<Value>, unsigned> ValueRankMap;
   SetVector<AssertingVH<Instruction>> RedoInsts;
+
+  // Arbitrary, but prevents quadratic behavior.
+  static const unsigned GlobalReassociateLimit = 10;
+  static const unsigned NumBinaryOps =
+      Instruction::BinaryOpsEnd - Instruction::BinaryOpsBegin;
+  DenseMap<std::pair<Value *, Value *>, unsigned> PairMap[NumBinaryOps];
+
   bool MadeChange;
 
 public:
@@ -105,6 +112,7 @@ private:
                                  SetVector<AssertingVH<Instruction>> &Insts);
   void OptimizeInst(Instruction *I);
   Instruction *canonicalizeNegConstExpr(Instruction *I);
+  void BuildPairMap(ReversePostOrderTraversal<Function *> &RPOT);
 };
 
 } // end namespace llvm

Modified: llvm/trunk/lib/Transforms/Scalar/Reassociate.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Transforms/Scalar/Reassociate.cpp?rev=320515&r1=320514&r2=320515&view=diff
==============================================================================
--- llvm/trunk/lib/Transforms/Scalar/Reassociate.cpp (original)
+++ llvm/trunk/lib/Transforms/Scalar/Reassociate.cpp Tue Dec 12 11:18:02 2017
@@ -27,6 +27,7 @@
 #include "llvm/ADT/PostOrderIterator.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/SmallSet.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/Statistic.h"
 #include "llvm/Analysis/GlobalsModRef.h"
@@ -2184,11 +2185,104 @@ void ReassociatePass::ReassociateExpress
     return;
   }
 
+  if (Ops.size() > 2 && Ops.size() <= GlobalReassociateLimit) {
+    // Find the pair with the highest count in the pairmap and move it to the
+    // back of the list so that it can later be CSE'd.
+    // example:
+    //   a*b*c*d*e
+    // if c*e is the most "popular" pair, we can express this as
+    //   (((c*e)*d)*b)*a
+    unsigned Max = 1;
+    unsigned BestRank = 0;
+    std::pair<unsigned, unsigned> BestPair;
+    unsigned Idx = I->getOpcode() - Instruction::BinaryOpsBegin;
+    for (unsigned i = 0; i < Ops.size() - 1; ++i)
+      for (unsigned j = i + 1; j < Ops.size(); ++j) {
+        unsigned Score = 0;
+        Value *Op0 = Ops[i].Op;
+        Value *Op1 = Ops[j].Op;
+        if (std::less<Value *>()(Op1, Op0))
+          std::swap(Op0, Op1);
+        auto it = PairMap[Idx].find({Op0, Op1});
+        if (it != PairMap[Idx].end())
+          Score += it->second;
+
+        unsigned MaxRank = std::max(Ops[i].Rank, Ops[j].Rank);
+        if (Score > Max || (Score == Max && MaxRank < BestRank)) {
+          BestPair = {i, j};
+          Max = Score;
+          BestRank = MaxRank;
+        }
+      }
+    if (Max > 1) {
+      auto Op0 = Ops[BestPair.first];
+      auto Op1 = Ops[BestPair.second];
+      Ops.erase(&Ops[BestPair.second]);
+      Ops.erase(&Ops[BestPair.first]);
+      Ops.push_back(Op0);
+      Ops.push_back(Op1);
+    }
+  }
   // Now that we ordered and optimized the expressions, splat them back into
   // the expression tree, removing any unneeded nodes.
   RewriteExprTree(I, Ops);
 }
 
+void
+ReassociatePass::BuildPairMap(ReversePostOrderTraversal<Function *> &RPOT) {
+  // Per opcode, count the expression trees each operand pair occurs in.
+  for (BasicBlock *BI : RPOT) {
+    for (Instruction &I : *BI) {
+      if (!I.isAssociative())
+        continue;
+
+      // Skip non-root nodes: I's sole user has the same opcode as I.
+      if (I.hasOneUse() && I.user_back()->getOpcode() == I.getOpcode())
+        continue;
+
+      // Collect the leaf operands of the same-opcode tree rooted at I.
+      // NOTE(review): run() calls BuildPairMap before its reassociation
+      // loop, so the IR is only canonical if an earlier pass run made it so.
+      SmallVector<Value *, 8> Worklist = { I.getOperand(0), I.getOperand(1) };
+      SmallVector<Value *, 8> Ops;
+      while (!Worklist.empty() && Ops.size() <= GlobalReassociateLimit) {
+        Value *Op = Worklist.pop_back_val();
+        Instruction *OpI = dyn_cast<Instruction>(Op);
+        if (!OpI || OpI->getOpcode() != I.getOpcode() || !OpI->hasOneUse()) {
+          Ops.push_back(Op);
+          continue;
+        }
+        // Be paranoid about self-referencing expressions in unreachable code.
+        if (OpI->getOperand(0) != OpI)
+          Worklist.push_back(OpI->getOperand(0));
+        if (OpI->getOperand(1) != OpI)
+          Worklist.push_back(OpI->getOperand(1));
+      }
+      // Skip chains with more than GlobalReassociateLimit operands.
+      if (Ops.size() > GlobalReassociateLimit)
+        continue;
+
+      // Bump the count of every unordered operand pair, once per tree.
+      unsigned BinaryIdx = I.getOpcode() - Instruction::BinaryOpsBegin;
+      SmallSet<std::pair<Value *, Value*>, 32> Visited;
+      for (unsigned i = 0; i < Ops.size() - 1; ++i) {
+        for (unsigned j = i + 1; j < Ops.size(); ++j) {
+          // Order each pair by pointer value so (a, b) and (b, a) coincide.
+          Value *Op0 = Ops[i];
+          Value *Op1 = Ops[j];
+          if (std::less<Value *>()(Op1, Op0))
+            std::swap(Op0, Op1);
+          if (!Visited.insert({Op0, Op1}).second)
+            continue;
+          auto res = PairMap[BinaryIdx].insert({{Op0, Op1}, 1});
+          if (!res.second)
+            ++res.first->second;
+        }
+      }
+    }
+  }
+}
+
 PreservedAnalyses ReassociatePass::run(Function &F, FunctionAnalysisManager &) {
   // Get the functions basic blocks in Reverse Post Order. This order is used by
   // BuildRankMap to pre calculate ranks correctly. It also excludes dead basic
@@ -2199,8 +2293,20 @@ PreservedAnalyses ReassociatePass::run(F
   // Calculate the rank map for F.
   BuildRankMap(F, RPOT);
 
+  // Build the pair map before running reassociate.
+  // Technically this would be more accurate if we did it after one round
+  // of reassociation, but in practice it doesn't seem to help much on
+  // real-world code, so don't waste the compile time running reassociate
+  // twice.
+  // If a user wants, they could explicitly run reassociate twice in their
+  // pass pipeline for further potential gains.
+  // It might also be possible to update the pair map during runtime, but the
+  // overhead of that may be large if there are many reassociable chains.
+  BuildPairMap(RPOT);
+
   MadeChange = false;
-  // Traverse the same blocks that was analysed by BuildRankMap.
+
+  // Traverse the same blocks that were analysed by BuildRankMap.
   for (BasicBlock *BI : RPOT) {
     assert(RankMap.count(&*BI) && "BB should be ranked.");
     // Optimize every instruction in the basic block.
@@ -2239,9 +2345,11 @@ PreservedAnalyses ReassociatePass::run(F
     }
   }
 
-  // We are done with the rank map.
+  // We are done with the rank map and pair map.
   RankMap.clear();
   ValueRankMap.clear();
+  for (auto &Entry : PairMap)
+    Entry.clear();
 
   if (MadeChange) {
     PreservedAnalyses PA;

Modified: llvm/trunk/test/Transforms/Reassociate/basictest.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/Transforms/Reassociate/basictest.ll?rev=320515&r1=320514&r2=320515&view=diff
==============================================================================
--- llvm/trunk/test/Transforms/Reassociate/basictest.ll (original)
+++ llvm/trunk/test/Transforms/Reassociate/basictest.ll Tue Dec 12 11:18:02 2017
@@ -242,3 +242,18 @@ if.then:
 if.end:                                           ; preds = %entry
   ret i64 0
 }
+
+; CHECK-LABEL: @test17
+; CHECK: %[[A:.*]] = mul i32 %X4, %X3
+; CHECK-NEXT:  %[[C:.*]] = mul i32 %[[A]], %X1
+; CHECK-NEXT: %[[D:.*]] = mul i32 %[[A]], %X2
+; CHECK-NEXT: %[[E:.*]] = xor i32 %[[C]], %[[D]]
+; CHECK-NEXT: ret i32 %[[E]]
+define i32 @test17(i32 %X1, i32 %X2, i32 %X3, i32 %X4) {
+  %A = mul i32 %X3, %X1
+  %B = mul i32 %X3, %X2
+  %C = mul i32 %A, %X4
+  %D = mul i32 %B, %X4
+  %E = xor i32 %C, %D
+  ret i32 %E
+}

Modified: llvm/trunk/test/Transforms/Reassociate/fast-ReassociateVector.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/Transforms/Reassociate/fast-ReassociateVector.ll?rev=320515&r1=320514&r2=320515&view=diff
==============================================================================
--- llvm/trunk/test/Transforms/Reassociate/fast-ReassociateVector.ll (original)
+++ llvm/trunk/test/Transforms/Reassociate/fast-ReassociateVector.ll Tue Dec 12 11:18:02 2017
@@ -286,8 +286,8 @@ define <2 x float> @test10_reassoc(<2 x
 
 define <2 x double> @test11(<2 x double> %x, <2 x double> %y) {
 ; CHECK-LABEL: @test11(
-; CHECK-NEXT:    [[FACTOR:%.*]] = fmul fast <2 x double> [[X:%.*]], <double 2.000000e+00, double 2.000000e+00>
-; CHECK-NEXT:    [[REASS_MUL:%.*]] = fmul fast <2 x double> [[FACTOR]], [[Y:%.*]]
+; CHECK-NEXT:    [[FACTOR:%.*]] = fmul fast <2 x double> [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT:    [[REASS_MUL:%.*]] = fmul fast <2 x double> [[FACTOR]], <double 2.000000e+00, double 2.000000e+00>
 ; CHECK-NEXT:    ret <2 x double> [[REASS_MUL]]
 ;
   %1 = fmul fast <2 x double> %x, %y

Modified: llvm/trunk/test/Transforms/Reassociate/fast-fp-commute.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/Transforms/Reassociate/fast-fp-commute.ll?rev=320515&r1=320514&r2=320515&view=diff
==============================================================================
--- llvm/trunk/test/Transforms/Reassociate/fast-fp-commute.ll (original)
+++ llvm/trunk/test/Transforms/Reassociate/fast-fp-commute.ll Tue Dec 12 11:18:02 2017
@@ -34,8 +34,8 @@ define float @test2(float %x, float %y)
 
 define float @test3(float %x, float %y) {
 ; CHECK-LABEL: @test3(
-; CHECK-NEXT:    [[FACTOR:%.*]] = fmul fast float %x, 2.000000e+00
-; CHECK-NEXT:    [[REASS_MUL:%.*]] = fmul fast float [[FACTOR]], %y
+; CHECK-NEXT:    [[FACTOR:%.*]] = fmul fast float %y, %x
+; CHECK-NEXT:    [[REASS_MUL:%.*]] = fmul fast float [[FACTOR]], 2.000000e+00
 ; CHECK-NEXT:    ret float [[REASS_MUL]]
 ;
   %1 = fmul fast float %x, %y




More information about the llvm-commits mailing list