[llvm-commits] CVS: llvm/lib/Transforms/Scalar/Reassociate.cpp

Chris Lattner lattner at cs.uiuc.edu
Sat Mar 4 01:31:27 PST 2006



Changes in directory llvm/lib/Transforms/Scalar:

Reassociate.cpp updated: 1.56 -> 1.57
---
Log message:

Add factoring of multiplications, e.g. turning A*A+A*B into A*(A+B).
Testcase here: Transforms/Reassociate/mulfactor.ll


---
Diffs of the changes:  (+186 -49)

 Reassociate.cpp |  235 ++++++++++++++++++++++++++++++++++++++++++++------------
 1 files changed, 186 insertions(+), 49 deletions(-)


Index: llvm/lib/Transforms/Scalar/Reassociate.cpp
diff -u llvm/lib/Transforms/Scalar/Reassociate.cpp:1.56 llvm/lib/Transforms/Scalar/Reassociate.cpp:1.57
--- llvm/lib/Transforms/Scalar/Reassociate.cpp:1.56	Sun Jan 22 17:32:06 2006
+++ llvm/lib/Transforms/Scalar/Reassociate.cpp	Sat Mar  4 03:31:13 2006
@@ -41,6 +41,7 @@
   Statistic<> NumChanged("reassociate","Number of insts reassociated");
   Statistic<> NumSwapped("reassociate","Number of insts with operands swapped");
   Statistic<> NumAnnihil("reassociate","Number of expr tree annihilated");
+  Statistic<> NumFactor ("reassociate","Number of multiplies factored");
 
   struct ValueEntry {
     unsigned Rank;
@@ -50,7 +51,20 @@
   inline bool operator<(const ValueEntry &LHS, const ValueEntry &RHS) {
     return LHS.Rank > RHS.Rank;   // Sort so that highest rank goes to start.
   }
+}
 
+/// PrintOps - Print out the expression identified in the Ops list.
+///
+static void PrintOps(Instruction *I, const std::vector<ValueEntry> &Ops) {
+  Module *M = I->getParent()->getParent()->getParent();
+  std::cerr << Instruction::getOpcodeName(I->getOpcode()) << " "
+  << *Ops[0].Op->getType();
+  for (unsigned i = 0, e = Ops.size(); i != e; ++i)
+    WriteAsOperand(std::cerr << " ", Ops[i].Op, false, true, M)
+      << "," << Ops[i].Rank;
+}
+  
+namespace {  
   class Reassociate : public FunctionPass {
     std::map<BasicBlock*, unsigned> RankMap;
     std::map<Value*, unsigned> ValueRankMap;
@@ -66,10 +80,13 @@
     unsigned getRank(Value *V);
     void RewriteExprTree(BinaryOperator *I, unsigned Idx,
                          std::vector<ValueEntry> &Ops);
-    void OptimizeExpression(unsigned Opcode, std::vector<ValueEntry> &Ops);
+    Value *OptimizeExpression(BinaryOperator *I, std::vector<ValueEntry> &Ops);
     void LinearizeExprTree(BinaryOperator *I, std::vector<ValueEntry> &Ops);
     void LinearizeExpr(BinaryOperator *I);
+    Value *RemoveFactorFromExpression(Value *V, Value *Factor);
     void ReassociateBB(BasicBlock *BB);
+    
+    void RemoveDeadBinaryOp(Value *V);
   };
 
   RegisterOpt<Reassociate> X("reassociate", "Reassociate expressions");
@@ -78,6 +95,15 @@
 // Public interface to the Reassociate pass
 FunctionPass *llvm::createReassociatePass() { return new Reassociate(); }
 
+void Reassociate::RemoveDeadBinaryOp(Value *V) {
+  BinaryOperator *BOp = dyn_cast<BinaryOperator>(V);
+  if (!BOp || !BOp->use_empty()) return;
+  
+  Value *LHS = BOp->getOperand(0), *RHS = BOp->getOperand(1);
+  RemoveDeadBinaryOp(LHS);
+  RemoveDeadBinaryOp(RHS);
+}
+
 
 static bool isUnmovableInstruction(Instruction *I) {
   if (I->getOpcode() == Instruction::PHI ||
@@ -207,9 +233,6 @@
 /// form of the the expression (((a+b)+c)+d), and collects information about the
 /// rank of the non-tree operands.
 ///
-/// This returns the rank of the RHS operand, which is known to be the highest
-/// rank value in the expression tree.
-///
 void Reassociate::LinearizeExprTree(BinaryOperator *I,
                                     std::vector<ValueEntry> &Ops) {
   Value *LHS = I->getOperand(0), *RHS = I->getOperand(1);
@@ -279,12 +302,17 @@
   if (i+2 == Ops.size()) {
     if (I->getOperand(0) != Ops[i].Op ||
         I->getOperand(1) != Ops[i+1].Op) {
+      Value *OldLHS = I->getOperand(0);
       DEBUG(std::cerr << "RA: " << *I);
       I->setOperand(0, Ops[i].Op);
       I->setOperand(1, Ops[i+1].Op);
       DEBUG(std::cerr << "TO: " << *I);
       MadeChange = true;
       ++NumChanged;
+      
+      // If we reassociated a tree to fewer operands (e.g. (1+a+2) -> (a+3)),
+      // delete the extra, now dead, nodes.
+      RemoveDeadBinaryOp(OldLHS);
     }
     return;
   }
@@ -297,7 +325,15 @@
     MadeChange = true;
     ++NumChanged;
   }
-  RewriteExprTree(cast<BinaryOperator>(I->getOperand(0)), i+1, Ops);
+  
+  BinaryOperator *LHS = cast<BinaryOperator>(I->getOperand(0));
+  assert(LHS->getOpcode() == I->getOpcode() &&
+         "Improper expression tree!");
+  
+  // Compactify the tree instructions together with each other to guarantee
+  // that the expression tree is dominated by all of Ops.
+  LHS->moveBefore(I);
+  RewriteExprTree(LHS, i+1, Ops);
 }
 
 
@@ -405,19 +441,57 @@
   return i;
 }
 
-void Reassociate::OptimizeExpression(unsigned Opcode,
-                                     std::vector<ValueEntry> &Ops) {
+/// EmitAddTreeOfValues - Emit a tree of add instructions, summing Ops together
+/// and returning the result.  Insert the tree before I.
+static Value *EmitAddTreeOfValues(Instruction *I, std::vector<Value*> &Ops) {
+  if (Ops.size() == 1) return Ops.back();
+  
+  Value *V1 = Ops.back();
+  Ops.pop_back();
+  Value *V2 = EmitAddTreeOfValues(I, Ops);
+  return BinaryOperator::createAdd(V2, V1, "tmp", I);
+}
+
+/// RemoveFactorFromExpression - If V is an expression tree that is a 
+/// multiplication sequence, and if this sequence contains a multiply by Factor,
+/// remove Factor from the tree and return the new tree.
+Value *Reassociate::RemoveFactorFromExpression(Value *V, Value *Factor) {
+  BinaryOperator *BO = isReassociableOp(V, Instruction::Mul);
+  if (!BO) return 0;
+  
+  std::vector<ValueEntry> Factors;
+  LinearizeExprTree(BO, Factors);
+
+  bool FoundFactor = false;
+  for (unsigned i = 0, e = Factors.size(); i != e; ++i)
+    if (Factors[i].Op == Factor) {
+      FoundFactor = true;
+      Factors.erase(Factors.begin()+i);
+      break;
+    }
+  if (!FoundFactor) return 0;
+  
+  if (Factors.size() == 1) return Factors[0].Op;
+  
+  RewriteExprTree(BO, 0, Factors);
+  return BO;
+}
+
+
+Value *Reassociate::OptimizeExpression(BinaryOperator *I,
+                                       std::vector<ValueEntry> &Ops) {
   // Now that we have the linearized expression tree, try to optimize it.
   // Start by folding any constants that we found.
   bool IterateOptimization = false;
-  if (Ops.size() == 1) return;
+  if (Ops.size() == 1) return Ops[0].Op;
 
+  unsigned Opcode = I->getOpcode();
+  
   if (Constant *V1 = dyn_cast<Constant>(Ops[Ops.size()-2].Op))
     if (Constant *V2 = dyn_cast<Constant>(Ops.back().Op)) {
       Ops.pop_back();
       Ops.back().Op = ConstantExpr::get(Opcode, V1, V2);
-      OptimizeExpression(Opcode, Ops);
-      return;
+      return OptimizeExpression(I, Ops);
     }
 
   // Check for destructive annihilation due to a constant being used.
@@ -426,30 +500,24 @@
     default: break;
     case Instruction::And:
       if (CstVal->isNullValue()) {           // ... & 0 -> 0
-        Ops[0].Op = CstVal;
-        Ops.erase(Ops.begin()+1, Ops.end());
         ++NumAnnihil;
-        return;
+        return CstVal;
       } else if (CstVal->isAllOnesValue()) { // ... & -1 -> ...
         Ops.pop_back();
       }
       break;
     case Instruction::Mul:
       if (CstVal->isNullValue()) {           // ... * 0 -> 0
-        Ops[0].Op = CstVal;
-        Ops.erase(Ops.begin()+1, Ops.end());
         ++NumAnnihil;
-        return;
+        return CstVal;
       } else if (cast<ConstantInt>(CstVal)->getRawValue() == 1) {
         Ops.pop_back();                      // ... * 1 -> ...
       }
       break;
     case Instruction::Or:
       if (CstVal->isAllOnesValue()) {        // ... | -1 -> -1
-        Ops[0].Op = CstVal;
-        Ops.erase(Ops.begin()+1, Ops.end());
         ++NumAnnihil;
-        return;
+        return CstVal;
       }
       // FALLTHROUGH!
     case Instruction::Add:
@@ -458,7 +526,7 @@
         Ops.pop_back();
       break;
     }
-  if (Ops.size() == 1) return;
+  if (Ops.size() == 1) return Ops[0].Op;
 
   // Handle destructive annihilation do to identities between elements in the
   // argument list here.
@@ -477,15 +545,11 @@
         unsigned FoundX = FindInOperandList(Ops, i, X);
         if (FoundX != i) {
           if (Opcode == Instruction::And) {   // ...&X&~X = 0
-            Ops[0].Op = Constant::getNullValue(X->getType());
-            Ops.erase(Ops.begin()+1, Ops.end());
             ++NumAnnihil;
-            return;
+            return Constant::getNullValue(X->getType());
           } else if (Opcode == Instruction::Or) {   // ...|X|~X = -1
-            Ops[0].Op = ConstantIntegral::getAllOnesValue(X->getType());
-            Ops.erase(Ops.begin()+1, Ops.end());
             ++NumAnnihil;
-            return;
+            return ConstantIntegral::getAllOnesValue(X->getType());
           }
         }
       }
@@ -503,10 +567,8 @@
         } else {
           assert(Opcode == Instruction::Xor);
           if (e == 2) {
-            Ops[0].Op = Constant::getNullValue(Ops[0].Op->getType());
-            Ops.erase(Ops.begin()+1, Ops.end());
             ++NumAnnihil;
-            return;
+            return Constant::getNullValue(Ops[0].Op->getType());
           }
           // ... X^X -> ...
           Ops.erase(Ops.begin()+i, Ops.begin()+i+2);
@@ -520,7 +582,7 @@
 
   case Instruction::Add:
     // Scan the operand lists looking for X and -X pairs.  If we find any, we
-    // can simplify the expression. X+-X == 0
+    // can simplify the expression. X+-X == 0.
     for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
       assert(i < Ops.size());
       // Check for X and -X in the operand list.
@@ -530,10 +592,8 @@
         if (FoundX != i) {
           // Remove X and -X from the operand list.
           if (Ops.size() == 2) {
-            Ops[0].Op = Constant::getNullValue(X->getType());
-            Ops.pop_back();
             ++NumAnnihil;
-            return;
+            return Constant::getNullValue(X->getType());
           } else {
             Ops.erase(Ops.begin()+i);
             if (i < FoundX)
@@ -549,30 +609,99 @@
         }
       }
     }
+    
+
+    // Scan the operand list, checking to see if there are any common factors
+    // between operands.  Consider something like A*A+A*B*C+D.  We would like to
+    // reassociate this to A*(A+B*C)+D, which reduces the number of multiplies.
+    // To efficiently find this, we count the number of times a factor occurs
+    // for any ADD operands that are MULs.
+    std::map<Value*, unsigned> FactorOccurrences;
+    unsigned MaxOcc = 0;
+    Value *MaxOccVal = 0;
+    if (!I->getType()->isFloatingPoint()) {
+      for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
+        if (BinaryOperator *BOp = dyn_cast<BinaryOperator>(Ops[i].Op))
+          if (BOp->getOpcode() == Instruction::Mul && BOp->hasOneUse()) {
+            // Compute all of the factors of this added value.
+            std::vector<ValueEntry> Factors;
+            LinearizeExprTree(BOp, Factors);
+            assert(Factors.size() > 1 && "Bad linearize!");
+            
+            // Add one to FactorOccurrences for each unique factor in this op.
+            if (Factors.size() == 2) {
+              unsigned Occ = ++FactorOccurrences[Factors[0].Op];
+              if (Occ > MaxOcc) { MaxOcc = Occ; MaxOccVal = Factors[0].Op; }
+              if (Factors[0].Op != Factors[1].Op) {   // Don't double count A*A.
+                Occ = ++FactorOccurrences[Factors[1].Op];
+                if (Occ > MaxOcc) { MaxOcc = Occ; MaxOccVal = Factors[1].Op; }
+              }
+            } else {
+              std::set<Value*> Duplicates;
+              for (unsigned i = 0, e = Factors.size(); i != e; ++i)
+                if (Duplicates.insert(Factors[i].Op).second) {
+                  unsigned Occ = ++FactorOccurrences[Factors[i].Op];
+                  if (Occ > MaxOcc) { MaxOcc = Occ; MaxOccVal = Factors[i].Op; }
+                }
+            }
+          }
+      }
+    }
+
+    // If any factor occurred more than one time, we can pull it out.
+    if (MaxOcc > 1) {
+      DEBUG(std::cerr << "\nFACTORING [" << MaxOcc << "]: "
+                      << *MaxOccVal << "\n");
+      
+      // Create a new instruction that uses the MaxOccVal twice.  If we don't do
+      // this, we could otherwise run into situations where removing a factor
+      // from an expression will drop a use of MaxOccVal, and this can cause
+      // RemoveFactorFromExpression on successive values to behave differently.
+      Instruction *DummyInst = BinaryOperator::createAdd(MaxOccVal, MaxOccVal);
+      std::vector<Value*> NewMulOps;
+      for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
+        if (Value *V = RemoveFactorFromExpression(Ops[i].Op, MaxOccVal)) {
+          NewMulOps.push_back(V);
+          Ops.erase(Ops.begin()+i);
+          --i; --e;
+        }
+      }
+      
+      // No need for extra uses anymore.
+      delete DummyInst;
+
+      Value *V = EmitAddTreeOfValues(I, NewMulOps);
+      // FIXME: Must optimize V now, to handle this case:
+      // A*A*B + A*A*C -> A*(A*B+A*C)   -> A*(A*(B+C))
+      V = BinaryOperator::createMul(V, MaxOccVal, "tmp", I);
+
+      ++NumFactor;
+      
+      if (Ops.size() == 0)
+        return V;
+
+      // Add the new value to the list of things being added.
+      Ops.insert(Ops.begin(), ValueEntry(getRank(V), V));
+      
+      // Rewrite the tree so that there is now a use of V.
+      RewriteExprTree(I, 0, Ops);
+      return OptimizeExpression(I, Ops);
+    }
     break;
   //case Instruction::Mul:
   }
 
   if (IterateOptimization)
-    OptimizeExpression(Opcode, Ops);
+    return OptimizeExpression(I, Ops);
+  return 0;
 }
 
-/// PrintOps - Print out the expression identified in the Ops list.
-///
-static void PrintOps(unsigned Opcode, const std::vector<ValueEntry> &Ops,
-                     BasicBlock *BB) {
-  Module *M = BB->getParent()->getParent();
-  std::cerr << Instruction::getOpcodeName(Opcode) << " "
-            << *Ops[0].Op->getType();
-  for (unsigned i = 0, e = Ops.size(); i != e; ++i)
-    WriteAsOperand(std::cerr << " ", Ops[i].Op, false, true, M)
-      << "," << Ops[i].Rank;
-}
 
 /// ReassociateBB - Inspect all of the instructions in this basic block,
 /// reassociating them as we go.
 void Reassociate::ReassociateBB(BasicBlock *BB) {
-  for (BasicBlock::iterator BI = BB->begin(); BI != BB->end(); ++BI) {
+  for (BasicBlock::iterator BBI = BB->begin(); BBI != BB->end(); ) {
+    Instruction *BI = BBI++;
     if (BI->getOpcode() == Instruction::Shl &&
         isa<ConstantInt>(BI->getOperand(1)))
       if (Instruction *NI = ConvertShiftToMul(BI)) {
@@ -623,7 +752,7 @@
     std::vector<ValueEntry> Ops;
     LinearizeExprTree(I, Ops);
 
-    DEBUG(std::cerr << "RAIn:\t"; PrintOps(I->getOpcode(), Ops, BB);
+    DEBUG(std::cerr << "RAIn:\t"; PrintOps(I, Ops);
           std::cerr << "\n");
 
     // Now that we have linearized the tree to a list and have gathered all of
@@ -636,7 +765,14 @@
 
     // OptimizeExpression - Now that we have the expression tree in a convenient
     // sorted form, optimize it globally if possible.
-    OptimizeExpression(I->getOpcode(), Ops);
+    if (Value *V = OptimizeExpression(I, Ops)) {
+      // This expression tree simplified to something that isn't a tree,
+      // eliminate it.
+      DEBUG(std::cerr << "Reassoc to scalar: " << *V << "\n");
+      I->replaceAllUsesWith(V);
+      RemoveDeadBinaryOp(I);
+      continue;
+    }
 
     // We want to sink immediates as deeply as possible except in the case where
     // this is a multiply tree used only by an add, and the immediate is a -1.
@@ -650,13 +786,14 @@
       Ops.pop_back();
     }
 
-    DEBUG(std::cerr << "RAOut:\t"; PrintOps(I->getOpcode(), Ops, BB);
+    DEBUG(std::cerr << "RAOut:\t"; PrintOps(I, Ops);
           std::cerr << "\n");
 
     if (Ops.size() == 1) {
       // This expression tree simplified to something that isn't a tree,
       // eliminate it.
       I->replaceAllUsesWith(Ops[0].Op);
+      RemoveDeadBinaryOp(I);
     } else {
       // Now that we ordered and optimized the expressions, splat them back into
       // the expression tree, removing any unneeded nodes.






More information about the llvm-commits mailing list