[llvm-commits] [llvm] r155616 - in /llvm/trunk: lib/Transforms/Scalar/Reassociate.cpp test/Transforms/Reassociate/mulfactor.ll

Chandler Carruth chandlerc at gmail.com
Wed Apr 25 22:30:30 PDT 2012


Author: chandlerc
Date: Thu Apr 26 00:30:30 2012
New Revision: 155616

URL: http://llvm.org/viewvc/llvm-project?rev=155616&view=rev
Log:
Teach the reassociate pass to fold chains of multiplies with repeated
elements to minimize the number of multiplies required to compute the
final result. This uses a heuristic to attempt to form near-optimal
binary exponentiation-style multiply chains. While there are some cases
it misses, it seems to do at least a decent job on a very diverse range of
inputs.

Initial benchmarks show no interesting regressions, and an 8%
improvement on SPASS. Let me know if any other interesting results (in
either direction) crop up!

Credit to Richard Smith for the core algorithm, and helping code the
patch itself.

Modified:
    llvm/trunk/lib/Transforms/Scalar/Reassociate.cpp
    llvm/trunk/test/Transforms/Reassociate/mulfactor.ll

Modified: llvm/trunk/lib/Transforms/Scalar/Reassociate.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Transforms/Scalar/Reassociate.cpp?rev=155616&r1=155615&r2=155616&view=diff
==============================================================================
--- llvm/trunk/lib/Transforms/Scalar/Reassociate.cpp (original)
+++ llvm/trunk/lib/Transforms/Scalar/Reassociate.cpp Thu Apr 26 00:30:30 2012
@@ -31,10 +31,12 @@
 #include "llvm/Pass.h"
 #include "llvm/Assembly/Writer.h"
 #include "llvm/Support/CFG.h"
+#include "llvm/Support/IRBuilder.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/ValueHandle.h"
 #include "llvm/Support/raw_ostream.h"
 #include "llvm/ADT/PostOrderIterator.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/Statistic.h"
 #include "llvm/ADT/DenseMap.h"
 #include <algorithm>
@@ -72,6 +74,45 @@
 #endif
   
 namespace {
+  /// \brief Utility class representing a base and exponent pair which form one
+  /// factor of some product.
+  ///
+  /// The nested function objects are comparators and predicates over Factor
+  /// values for use with STL algorithms. Their call operators are const so
+  /// that they can also be invoked through a const function object.
+  struct Factor {
+    /// The value being raised to a power.
+    Value *Base;
+    /// The exponent Base is raised to within the product.
+    unsigned Power;
+
+    Factor(Value *Base, unsigned Power) : Base(Base), Power(Power) {}
+
+    /// \brief Sort factors by their Base.
+    ///
+    /// This compares the Base pointers themselves, so the resulting order is
+    /// address-based rather than derived from the IR.
+    struct BaseSorter {
+      bool operator()(const Factor &LHS, const Factor &RHS) const {
+        return LHS.Base < RHS.Base;
+      }
+    };
+
+    /// \brief Compare factors for equal bases.
+    struct BaseEqual {
+      bool operator()(const Factor &LHS, const Factor &RHS) const {
+        return LHS.Base == RHS.Base;
+      }
+    };
+
+    /// \brief Sort factors in descending order by their power.
+    struct PowerDescendingSorter {
+      bool operator()(const Factor &LHS, const Factor &RHS) const {
+        return LHS.Power > RHS.Power;
+      }
+    };
+
+    /// \brief Compare factors for equal powers.
+    struct PowerEqual {
+      bool operator()(const Factor &LHS, const Factor &RHS) const {
+        return LHS.Power == RHS.Power;
+      }
+    };
+  };
+}
+
+namespace {
   class Reassociate : public FunctionPass {
     DenseMap<BasicBlock*, unsigned> RankMap;
     DenseMap<AssertingVH<Value>, unsigned> ValueRankMap;
@@ -98,6 +139,11 @@
     Value *OptimizeExpression(BinaryOperator *I,
                               SmallVectorImpl<ValueEntry> &Ops);
     Value *OptimizeAdd(Instruction *I, SmallVectorImpl<ValueEntry> &Ops);
+    bool collectMultiplyFactors(SmallVectorImpl<ValueEntry> &Ops,
+                                SmallVectorImpl<Factor> &Factors);
+    Value *buildMinimalMultiplyDAG(IRBuilder<> &Builder,
+                                   SmallVectorImpl<Factor> &Factors);
+    Value *OptimizeMul(BinaryOperator *I, SmallVectorImpl<ValueEntry> &Ops);
     void LinearizeExprTree(BinaryOperator *I, SmallVectorImpl<ValueEntry> &Ops);
     void LinearizeExpr(BinaryOperator *I);
     Value *RemoveFactorFromExpression(Value *V, Value *Factor);
@@ -888,6 +934,199 @@
   return 0;
 }
 
+namespace {
+  /// \brief Predicate tests whether a ValueEntry's op is in a map.
+  ///
+  /// Used with std::remove_if to drop every operand whose value has an entry
+  /// in Map. Only the presence of a key matters, not its mapped count. The
+  /// map is held by reference, so it must outlive the predicate.
+  struct IsValueInMap {
+    const DenseMap<Value *, unsigned> &Map;
+
+    IsValueInMap(const DenseMap<Value *, unsigned> &Map) : Map(Map) {}
+
+    bool operator()(const ValueEntry &Entry) const {
+      return Map.find(Entry.Op) != Map.end();
+    }
+  };
+}
+
+/// \brief Build up a vector of value/power pairs factoring a product.
+///
+/// Given a series of multiplication operands, build a vector of factors and
+/// the powers each is raised to when forming the final product. Sort them in
+/// the order of descending power.
+///
+/// Only operands whose value has no remaining uses (use_empty()) are
+/// candidates, and only values occurring two or more times become factors.
+/// Counting occurrences, for example:
+///
+///      (x*x)          -> [(x, 2)]
+///     ((x*x)*x)       -> [(x, 3)]
+///   ((((x*y)*x)*y)*x) -> [(x, 3), (y, 2)]
+///
+/// The first two examples fall below the threshold described next and
+/// yield false.
+///
+/// If the powers of the repeated values sum to at least 4, every occurrence
+/// of each such value is removed from Ops. A factor with an odd power then
+/// has one copy merged back into Ops' sorted order, and its power is lowered
+/// to the next even number. In the last example above, that leaves one x in
+/// Ops and Factors holding (x, 2) and (y, 2). Otherwise Ops is left
+/// unchanged.
+///
+/// Occurrences are counted across all of Ops rather than only within
+/// adjacent runs. Every occurrence of a factor is removed from Ops, so every
+/// occurrence must also contribute to that factor's power, or the rebuilt
+/// product would silently lose terms.
+///
+/// \returns true if factors were extracted (Ops and Factors updated as
+/// described above). Returns false if there is nothing worth simplifying; in
+/// that case any entries appended to Factors should be ignored.
+bool Reassociate::collectMultiplyFactors(SmallVectorImpl<ValueEntry> &Ops,
+                                         SmallVectorImpl<Factor> &Factors) {
+  // Count the occurrences of every candidate value. Note that 'use_empty'
+  // here means the only uses are in the linearized tree represented by Ops:
+  // the values were removed from the actual operations, reducing their use
+  // count.
+  DenseMap<Value *, unsigned> FactorCounts;
+  for (unsigned Idx = 0, Size = Ops.size(); Idx != Size; ++Idx)
+    if (Ops[Idx].Op->use_empty())
+      ++FactorCounts[Ops[Idx].Op];
+
+  // Record every value occurring two or more times as a factor, in order of
+  // first appearance, so the result does not depend on map iteration order.
+  // Values occurring only once are dropped from the map so they stay in Ops.
+  // Recorded factors keep a zeroed map entry so they are removed below.
+  unsigned FactorPowerSum = 0;
+  for (unsigned Idx = 0, Size = Ops.size(); Idx != Size; ++Idx) {
+    DenseMap<Value *, unsigned>::iterator CountIt =
+      FactorCounts.find(Ops[Idx].Op);
+    // Skip values with other uses, and factors that were already recorded.
+    if (CountIt == FactorCounts.end() || CountIt->second == 0)
+      continue;
+    if (CountIt->second == 1) {
+      FactorCounts.erase(CountIt);
+      continue;
+    }
+    Factors.push_back(Factor(CountIt->first, CountIt->second));
+    FactorPowerSum += CountIt->second;
+    CountIt->second = 0;
+  }
+
+  // We can only simplify factors if the sum of the powers of our simplifiable
+  // factors is 4 or higher. When that is the case, we will *always* have
+  // a simplification. This is an important invariant to prevent cyclically
+  // trying to simplify already minimal formations.
+  if (FactorPowerSum < 4)
+    return false;
+
+  // Remove all the operands which are in the map, i.e. every occurrence of
+  // every recorded factor.
+  Ops.erase(std::remove_if(Ops.begin(), Ops.end(), IsValueInMap(FactorCounts)),
+            Ops.end());
+
+  // Add back into the Ops list one copy of any value with an odd power, and
+  // make the power even. This allows the outer-most multiplication tree to
+  // remain intact during simplification.
+  unsigned OldOpsSize = Ops.size();
+  for (unsigned Idx = 0, Size = Factors.size(); Idx != Size; ++Idx) {
+    if (Factors[Idx].Power & 1) {
+      Ops.push_back(ValueEntry(getRank(Factors[Idx].Base), Factors[Idx].Base));
+      --Factors[Idx].Power;
+      --FactorPowerSum;
+    }
+  }
+  // None of the adjustments above should have reduced the sum of factor powers
+  // below our minimum of '4'. Every factor keeps an even power of at least 2,
+  // so two or more factors sum to at least 4. A lone factor must have started
+  // with a power of at least 4.
+  assert(FactorPowerSum >= 4);
+
+  // Patch up the sort of the ops vector by sorting the factors we added back
+  // onto the back, and merging the two sequences.
+  if (OldOpsSize != Ops.size()) {
+    SmallVectorImpl<ValueEntry>::iterator MiddleIt = Ops.begin() + OldOpsSize;
+    std::sort(MiddleIt, Ops.end());
+    std::inplace_merge(Ops.begin(), MiddleIt, Ops.end());
+  }
+
+  std::sort(Factors.begin(), Factors.end(), Factor::PowerDescendingSorter());
+  return true;
+}
+
+/// \brief Fold the values in Ops into one left-leaning chain of multiplies.
+///
+/// Values are consumed from the back of Ops, and each new multiply takes the
+/// running product as its left operand. When Ops holds exactly one value,
+/// that value is returned and Ops is left untouched. Otherwise Ops is emptied
+/// and the final multiply is returned. Ops must not be empty.
+static Value *buildMultiplyTree(IRBuilder<> &Builder,
+                                SmallVectorImpl<Value*> &Ops) {
+  if (Ops.size() == 1)
+    return Ops.back();
+
+  Value *Product = Ops.pop_back_val();
+  Value *Next = Ops.pop_back_val();
+  Product = Builder.CreateMul(Product, Next);
+  while (!Ops.empty()) {
+    Next = Ops.pop_back_val();
+    Product = Builder.CreateMul(Product, Next);
+  }
+  return Product;
+}
+
+/// \brief Build a minimal multiplication DAG for (a^x)*(b^y)*(c^z)*...
+///
+/// Given a vector of values raised to various powers, where no two values are
+/// equal and the powers are sorted in decreasing order, compute a (heuristic,
+/// near-minimal) DAG of multiplies to compute the final product, and return
+/// that product value.
+///
+/// The scheme resembles binary exponentiation:
+///  1. Factors sharing a power are first multiplied together, so they can be
+///     raised to that power as one value.
+///  2. The base of every factor with an odd power goes into the outer
+///     product, and every power is halved.
+///  3. The product of the halved factors is built recursively. It is then
+///     squared by adding it to the outer product twice.
+///
+/// Factors is modified in place. Bases may be replaced by merged products,
+/// entries with equal powers are collapsed, and all powers are halved, once
+/// per recursion level.
+Value *Reassociate::buildMinimalMultiplyDAG(IRBuilder<> &Builder,
+                                            SmallVectorImpl<Factor> &Factors) {
+  // Requires at least one factor, with Factors[0] (the largest power, given
+  // the descending sort) nonzero.
+  assert(Factors[0].Power);
+  SmallVector<Value *, 4> OuterProduct;
+  // Merge each run of factors with equal, nonzero powers. LastIdx indexes the
+  // first factor of the current run. The powers are sorted in decreasing
+  // order, so the loop can stop at the first zero power.
+  for (unsigned LastIdx = 0, Idx = 1, Size = Factors.size();
+       Idx < Size && Factors[Idx].Power > 0; ++Idx) {
+    if (Factors[Idx].Power != Factors[LastIdx].Power) {
+      LastIdx = Idx;
+      continue;
+    }
+
+    // We want to multiply across all the factors with the same power so that
+    // we can raise them to that power as a single entity. Build a mini tree
+    // for that.
+    SmallVector<Value *, 4> InnerProduct;
+    InnerProduct.push_back(Factors[LastIdx].Base);
+    do {
+      InnerProduct.push_back(Factors[Idx].Base);
+      ++Idx;
+    } while (Idx < Size && Factors[Idx].Power == Factors[LastIdx].Power);
+
+    // Reset the base value of the first factor to the new expression tree.
+    // We'll remove all the factors with the same power in a second pass.
+    // NOTE(review): cast<> assumes buildMultiplyTree produced a new
+    // BinaryOperator, i.e. that IRBuilder did not constant-fold the multiply.
+    // Presumably no base is a Constant here; confirm.
+    Factors[LastIdx].Base
+      = ReassociateExpression(
+          cast<BinaryOperator>(buildMultiplyTree(Builder, InnerProduct)));
+
+    // Idx now indexes the first factor past the run, so the next run starts
+    // there. The loop increment then moves Idx one past it.
+    LastIdx = Idx;
+  }
+  // Unique factors with equal powers -- we've folded them into the first one's
+  // base. Zero-power factors also collapse here, which is harmless because
+  // they no longer contribute to the product.
+  Factors.erase(std::unique(Factors.begin(), Factors.end(),
+                            Factor::PowerEqual()),
+                Factors.end());
+
+  // Iteratively collect the base of each factor with an odd power into the
+  // outer product, and halve each power in preparation for squaring the
+  // expression.
+  for (unsigned Idx = 0, Size = Factors.size(); Idx != Size; ++Idx) {
+    if (Factors[Idx].Power & 1)
+      OuterProduct.push_back(Factors[Idx].Base);
+    Factors[Idx].Power >>= 1;
+  }
+  // If any power remains, recursively build the product of the halved
+  // factors and square it by multiplying it into the outer product twice.
+  if (Factors[0].Power) {
+    Value *SquareRoot = buildMinimalMultiplyDAG(Builder, Factors);
+    OuterProduct.push_back(SquareRoot);
+    OuterProduct.push_back(SquareRoot);
+  }
+  // A single outer term needs no multiply at all.
+  if (OuterProduct.size() == 1)
+    return OuterProduct.front();
+
+  return ReassociateExpression(
+    cast<BinaryOperator>(buildMultiplyTree(Builder, OuterProduct)));
+}
+
+/// \brief Rewrite repeated factors of a linearized multiply into a DAG.
+///
+/// Ops holds the operands of the linearized multiply tree for I. When
+/// collectMultiplyFactors finds enough repeated operands, they are replaced
+/// by one value that computes their product via buildMinimalMultiplyDAG. The
+/// new instructions are inserted before I.
+///
+/// \returns That value if it accounts for the whole product. Otherwise the
+/// value is inserted into Ops at its sorted position and null is returned.
+/// Null is also returned when nothing changed.
+Value *Reassociate::OptimizeMul(BinaryOperator *I,
+                                SmallVectorImpl<ValueEntry> &Ops) {
+  // With three or fewer operands a balanced tree cannot need fewer multiplies
+  // than the linear chain, so there is nothing to gain.
+  if (Ops.size() < 4)
+    return 0;
+
+  // Turn linear trees of multiplies without other uses of the intermediate
+  // stages into minimal multiply DAGs with perfect sub-expression re-use.
+  SmallVector<Factor, 4> Factors;
+  bool FoundRepeats = collectMultiplyFactors(Ops, Factors);
+  if (!FoundRepeats)
+    return 0; // All distinct factors, so nothing left for us to do.
+
+  IRBuilder<> Builder(I);
+  Value *Product = buildMinimalMultiplyDAG(Builder, Factors);
+  if (!Ops.empty()) {
+    ValueEntry ProductEntry(getRank(Product), Product);
+    SmallVectorImpl<ValueEntry>::iterator InsertPos =
+      std::lower_bound(Ops.begin(), Ops.end(), ProductEntry);
+    Ops.insert(InsertPos, ProductEntry);
+    return 0;
+  }
+  return Product;
+}
+
 Value *Reassociate::OptimizeExpression(BinaryOperator *I,
                                        SmallVectorImpl<ValueEntry> &Ops) {
   // Now that we have the linearized expression tree, try to optimize it.
@@ -937,30 +1176,28 @@
 
   // Handle destructive annihilation due to identities between elements in the
   // argument list here.
+  unsigned NumOps = Ops.size();
   switch (Opcode) {
   default: break;
   case Instruction::And:
   case Instruction::Or:
-  case Instruction::Xor: {
-    unsigned NumOps = Ops.size();
+  case Instruction::Xor:
     if (Value *Result = OptimizeAndOrXor(Opcode, Ops))
       return Result;
-    IterateOptimization |= Ops.size() != NumOps;
     break;
-  }
 
-  case Instruction::Add: {
-    unsigned NumOps = Ops.size();
+  case Instruction::Add:
     if (Value *Result = OptimizeAdd(I, Ops))
       return Result;
-    IterateOptimization |= Ops.size() != NumOps;
-  }
+    break;
 
+  case Instruction::Mul:
+    if (Value *Result = OptimizeMul(I, Ops))
+      return Result;
     break;
-  //case Instruction::Mul:
   }
 
-  if (IterateOptimization)
+  if (IterateOptimization || Ops.size() != NumOps)
     return OptimizeExpression(I, Ops);
   return 0;
 }

Modified: llvm/trunk/test/Transforms/Reassociate/mulfactor.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/Transforms/Reassociate/mulfactor.ll?rev=155616&r1=155615&r2=155616&view=diff
==============================================================================
--- llvm/trunk/test/Transforms/Reassociate/mulfactor.ll (original)
+++ llvm/trunk/test/Transforms/Reassociate/mulfactor.ll Thu Apr 26 00:30:30 2012
@@ -33,3 +33,102 @@
 	ret i32 %d
 }
 
+define i32 @test3(i32 %x) {
+; (x^8)
+; The body below computes this with a chain of seven multiplies; reassociate
+; is expected to rebuild it with only three (x^2, x^4, x^8 by squaring).
+; CHECK: @test3
+; CHECK: mul
+; CHECK-NEXT: mul
+; CHECK-NEXT: mul
+; CHECK-NEXT: ret
+
+entry:
+  %a = mul i32 %x, %x
+  %b = mul i32 %a, %x
+  %c = mul i32 %b, %x
+  %d = mul i32 %c, %x
+  %e = mul i32 %d, %x
+  %f = mul i32 %e, %x
+  %g = mul i32 %f, %x
+  ret i32 %g
+}
+
+define i32 @test4(i32 %x) {
+; (x^7)
+; The body below computes this with a chain of six multiplies; reassociate
+; is expected to need only four.
+; CHECK: @test4
+; CHECK: mul
+; CHECK-NEXT: mul
+; CHECK-NEXT: mul
+; CHECK-NEXT: mul
+; CHECK-NEXT: ret
+
+entry:
+  %a = mul i32 %x, %x
+  %b = mul i32 %a, %x
+  %c = mul i32 %b, %x
+  %d = mul i32 %c, %x
+  %e = mul i32 %d, %x
+  %f = mul i32 %e, %x
+  ret i32 %f
+}
+
+define i32 @test5(i32 %x, i32 %y) {
+; (x^4) * (y^2)
+; The body below computes this with a chain of five multiplies; reassociate
+; is expected to need only three.
+; CHECK: @test5
+; CHECK: mul
+; CHECK-NEXT: mul
+; CHECK-NEXT: mul
+; CHECK-NEXT: ret
+
+entry:
+  %a = mul i32 %x, %y
+  %b = mul i32 %a, %y
+  %c = mul i32 %b, %x
+  %d = mul i32 %c, %x
+  %e = mul i32 %d, %x
+  ret i32 %e
+}
+
+define i32 @test6(i32 %x, i32 %y, i32 %z) {
+; (x^5) * (y^3) * z
+; The body below computes this with a chain of eight multiplies; reassociate
+; is expected to need only six.
+; CHECK: @test6
+; CHECK: mul
+; CHECK-NEXT: mul
+; CHECK-NEXT: mul
+; CHECK-NEXT: mul
+; CHECK-NEXT: mul
+; CHECK-NEXT: mul
+; CHECK-NEXT: ret
+
+entry:
+  %a = mul i32 %x, %y
+  %b = mul i32 %a, %x
+  %c = mul i32 %b, %y
+  %d = mul i32 %c, %x
+  %e = mul i32 %d, %y
+  %f = mul i32 %e, %x
+  %g = mul i32 %f, %z
+  %h = mul i32 %g, %x
+  ret i32 %h
+}
+
+define i32 @test7(i32 %x, i32 %y, i32 %z) {
+; (x^4) * (y^3) * (z^2)
+; The body below computes this with a chain of eight multiplies; reassociate
+; is expected to need only five.
+; CHECK: @test7
+; CHECK: mul
+; CHECK-NEXT: mul
+; CHECK-NEXT: mul
+; CHECK-NEXT: mul
+; CHECK-NEXT: mul
+; CHECK-NEXT: ret
+
+entry:
+  %a = mul i32 %y, %x
+  %b = mul i32 %a, %z
+  %c = mul i32 %b, %z
+  %d = mul i32 %c, %x
+  %e = mul i32 %d, %y
+  %f = mul i32 %e, %y
+  %g = mul i32 %f, %x
+  %h = mul i32 %g, %x
+  ret i32 %h
+}





More information about the llvm-commits mailing list