[llvm] r228821 - [LoopReroll] Introduce the concept of DAGRootSets.

James Molloy james.molloy at arm.com
Wed Feb 11 01:19:47 PST 2015


Author: jamesm
Date: Wed Feb 11 03:19:47 2015
New Revision: 228821

URL: http://llvm.org/viewvc/llvm-project?rev=228821&view=rev
Log:
[LoopReroll] Introduce the concept of DAGRootSets.

A DAGRootSet models an induction variable being used in a rerollable
loop. For example:

   x[i*3+0] = y1
   x[i*3+1] = y2
   x[i*3+2] = y3

   Base instruction -> i*3
                    +---+----+
                   /    |     \
               ST[y1]  +1     +2  <-- Roots
                        |      |
                      ST[y2] ST[y3]

There may be multiple DAGRootSets, for example:

   x[i*2+0] = ...   (1)
   x[i*2+1] = ...   (1)
   x[i*2+4] = ...   (2)
   x[i*2+5] = ...   (2)
   x[(i+1234)*2+5678] = ... (3)
   x[(i+1234)*2+5679] = ... (3)

This concept is similar to the "Scale" member used previously, but allows
multiple independent sets of roots based off the same induction variable.

Modified:
    llvm/trunk/lib/Transforms/Scalar/LoopRerollPass.cpp
    llvm/trunk/test/Transforms/LoopReroll/basic.ll

Modified: llvm/trunk/lib/Transforms/Scalar/LoopRerollPass.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Transforms/Scalar/LoopRerollPass.cpp?rev=228821&r1=228820&r2=228821&view=diff
==============================================================================
--- llvm/trunk/lib/Transforms/Scalar/LoopRerollPass.cpp (original)
+++ llvm/trunk/lib/Transforms/Scalar/LoopRerollPass.cpp Wed Feb 11 03:19:47 2015
@@ -126,8 +126,8 @@ namespace {
     /// has to be less than 25 in order to fit into a SmallBitVector.
     IL_MaxRerollIterations = 16,
     /// The bitvector index used by loop induction variables and other
-    /// instructions that belong to no one particular iteration.
-    IL_LoopIncIdx,
+    /// instructions that belong to all iterations.
+    IL_All,
     IL_End
   };
 
@@ -323,6 +323,39 @@ namespace {
       DenseSet<int> Reds;
     };
 
+    // A DAGRootSet models an induction variable being used in a rerollable
+    // loop. For example,
+    //
+    //   x[i*3+0] = y1
+    //   x[i*3+1] = y2
+    //   x[i*3+2] = y3
+    //
+    //   Base instruction -> i*3               
+    //                    +---+----+
+    //                   /    |     \
+    //               ST[y1]  +1     +2  <-- Roots
+    //                        |      |
+    //                      ST[y2] ST[y3]
+    //
+    // There may be multiple DAGRootSets, for example:
+    //
+    //   x[i*2+0] = ...   (1)
+    //   x[i*2+1] = ...   (1)
+    //   x[i*2+4] = ...   (2)
+    //   x[i*2+5] = ...   (2)
+    //   x[(i+1234)*2+5678] = ... (3)
+    //   x[(i+1234)*2+5679] = ... (3)
+    //
+    // The loop will be rerolled by adding new loop induction variables,
+    // one for the Base instruction in each DAGRootSet.
+    //
+    struct DAGRootSet {
+      Instruction *BaseInst;
+      SmallInstructionVector Roots;
+      // The instructions between IV and BaseInst (but not including BaseInst).
+      SmallInstructionSet SubsumedInsts;
+    };
+
     // The set of all DAG roots, and state tracking of all roots
     // for a particular induction variable.
     struct DAGRootTracker {
@@ -345,8 +378,11 @@ namespace {
     protected:
       typedef MapVector<Instruction*, SmallBitVector> UsesTy;
 
-      bool findScaleFromMul();
-      bool collectAllRoots();
+      bool findRootsRecursive(Instruction *IVU,
+                              SmallInstructionSet SubsumedInsts);
+      bool findRootsBase(Instruction *IVU, SmallInstructionSet SubsumedInsts);
+      bool collectPossibleRoots(Instruction *Base,
+                                std::map<int64_t,Instruction*> &Roots);
 
       bool collectUsedInstructions(SmallInstructionSet &PossibleRedSet);
       void collectInLoopUserSet(const SmallInstructionVector &Roots,
@@ -359,6 +395,8 @@ namespace {
                                 DenseSet<Instruction *> &Users);
 
       UsesTy::iterator nextInstr(int Val, UsesTy &In, UsesTy::iterator I);
+      bool isBaseInst(Instruction *I);
+      bool isRootInst(Instruction *I);
 
       LoopReroll *Parent;
 
@@ -377,14 +415,12 @@ namespace {
       // to the indvar: a[i*2+0] = ...; a[i*2+1] = ... ;
       // If Inc is not 1, Scale = Inc.
       uint64_t Scale;
-      // If Scale != Inc, then RealIV is IV after its multiplication.
-      Instruction *RealIV;
       // The roots themselves.
-      SmallInstructionVector Roots;
+      SmallVector<DAGRootSet,16> RootSets;
       // All increment instructions for IV.
       SmallInstructionVector LoopIncs;
       // Map of all instructions in the loop (in order) to the iterations
-      // they are used in (or specially, IL_LoopIncIdx for instructions
+      // they are used in (or specially, IL_All for instructions
       // used in the loop increment mechanism).
       UsesTy Uses;
     };
@@ -586,155 +622,258 @@ static bool isSimpleLoadStore(Instructio
   return false;
 }
 
-bool LoopReroll::DAGRootTracker::findRoots() {
-
-  const SCEVAddRecExpr *RealIVSCEV = cast<SCEVAddRecExpr>(SE->getSCEV(IV));
-  Inc = cast<SCEVConstant>(RealIVSCEV->getOperand(1))->
-    getValue()->getZExtValue();
-
-  // The effective induction variable, IV, is normally also the real induction
-  // variable. When we're dealing with a loop like:
-  //   for (int i = 0; i < 500; ++i)
-  //     x[3*i] = ...;
-  //     x[3*i+1] = ...;
-  //     x[3*i+2] = ...;
-  // then the real IV is still i, but the effective IV is (3*i).
-  Scale = Inc;
-  RealIV = IV;
-  if (Inc == 1 && !findScaleFromMul())
-    return false;
+/// Return true if IVU is a "simple" arithmetic operation.
+/// This is used for narrowing the search space for DAGRoots; only arithmetic,
+/// integer casts and GEPs can be part of a DAGRoot.
+static bool isSimpleArithmeticOp(User *IVU) {
+  if (Instruction *I = dyn_cast<Instruction>(IVU)) {
+    switch (I->getOpcode()) {
+    default: return false;
+    case Instruction::Add:
+    case Instruction::Sub:
+    case Instruction::Mul:
+    case Instruction::Shl:
+    case Instruction::AShr:
+    case Instruction::LShr:
+    case Instruction::GetElementPtr:
+    case Instruction::Trunc:
+    case Instruction::ZExt:
+    case Instruction::SExt:
+      return true;
+    }
+  }
+  return false;
+}
 
-  // The set of increment instructions for each increment value.
-  if (!collectAllRoots())
+static bool isLoopIncrement(User *U, Instruction *IV) {
+  BinaryOperator *BO = dyn_cast<BinaryOperator>(U);
+  if (!BO || BO->getOpcode() != Instruction::Add)
     return false;
 
-  if (Roots.size() > IL_MaxRerollIterations) {
-    DEBUG(dbgs() << "LRR: Aborting - too many iterations found. "
-          << "#Found=" << Roots.size() << ", #Max=" << IL_MaxRerollIterations
-          << "\n");
-    return false;
+  for (auto *UU : BO->users()) {
+    PHINode *PN = dyn_cast<PHINode>(UU);
+    if (PN && PN == IV)
+      return true;
   }
-
-  return true;
+  return false;
 }
 
-// Recognize loops that are setup like this:
-//
-// %iv = phi [ (preheader, ...), (body, %iv.next) ]
-// %scaled.iv = mul %iv, scale
-// f(%scaled.iv)
-// %scaled.iv.1 = add %scaled.iv, 1
-// f(%scaled.iv.1)
-// %scaled.iv.2 = add %scaled.iv, 2
-// f(%scaled.iv.2)
-// %scaled.iv.scale_m_1 = add %scaled.iv, scale-1
-// f(%scaled.iv.scale_m_1)
-// ...
-// %iv.next = add %iv, 1
-// %cmp = icmp(%iv, ...)
-// br %cmp, header, exit
-//
-// and, if found, set IV = %scaled.iv, and add %iv.next to LoopIncs.
-bool LoopReroll::DAGRootTracker::findScaleFromMul() {
+bool LoopReroll::DAGRootTracker::
+collectPossibleRoots(Instruction *Base, std::map<int64_t,Instruction*> &Roots) {
+  SmallInstructionVector BaseUsers;
 
-  // This is a special case: here we're looking for all uses (except for
-  // the increment) to be multiplied by a common factor. The increment must
-  // be by one. This is to capture loops like:
-  //   for (int i = 0; i < 500; ++i) {
-  //     foo(3*i); foo(3*i+1); foo(3*i+2);
-  //   }
-  if (RealIV->getNumUses() != 2)
-    return false;
-  const SCEVAddRecExpr *RealIVSCEV = cast<SCEVAddRecExpr>(SE->getSCEV(RealIV));
-  Instruction *User1 = cast<Instruction>(*RealIV->user_begin()),
-              *User2 = cast<Instruction>(*std::next(RealIV->user_begin()));
-  if (!SE->isSCEVable(User1->getType()) || !SE->isSCEVable(User2->getType()))
-    return false;
-  const SCEVAddRecExpr *User1SCEV =
-                         dyn_cast<SCEVAddRecExpr>(SE->getSCEV(User1)),
-                       *User2SCEV =
-                         dyn_cast<SCEVAddRecExpr>(SE->getSCEV(User2));
-  if (!User1SCEV || !User1SCEV->isAffine() ||
-      !User2SCEV || !User2SCEV->isAffine())
-    return false;
+  for (auto *I : Base->users()) {
+    ConstantInt *CI = nullptr;
+
+    if (isLoopIncrement(I, IV)) {
+      LoopIncs.push_back(cast<Instruction>(I));
+      continue;
+    }
 
-  // We assume below that User1 is the scale multiply and User2 is the
-  // increment. If this can't be true, then swap them.
-  if (User1SCEV == RealIVSCEV->getPostIncExpr(*SE)) {
-    std::swap(User1, User2);
-    std::swap(User1SCEV, User2SCEV);
+    // The root nodes must be either GEPs, ORs or ADDs.
+    if (auto *BO = dyn_cast<BinaryOperator>(I)) {
+      if (BO->getOpcode() == Instruction::Add ||
+          BO->getOpcode() == Instruction::Or)
+        CI = dyn_cast<ConstantInt>(BO->getOperand(1));
+    } else if (auto *GEP = dyn_cast<GetElementPtrInst>(I)) {
+      Value *LastOperand = GEP->getOperand(GEP->getNumOperands()-1);
+      CI = dyn_cast<ConstantInt>(LastOperand);
+    }
+
+    if (!CI) {
+      if (Instruction *II = dyn_cast<Instruction>(I)) {
+        BaseUsers.push_back(II);
+        continue;
+      } else {
+        DEBUG(dbgs() << "LRR: Aborting due to non-instruction: " << *I << "\n");
+        return false;
+      }
+    }
+
+    int64_t V = CI->getValue().getSExtValue();
+    if (Roots.find(V) != Roots.end())
+      // No duplicates, please.
+      return false;
+
+    // FIXME: Add support for negative values.
+    if (V < 0) {
+      DEBUG(dbgs() << "LRR: Aborting due to negative value: " << V << "\n");
+      return false;
+    }
+
+    Roots[V] = cast<Instruction>(I);
   }
 
-  if (User2SCEV != RealIVSCEV->getPostIncExpr(*SE))
+  if (Roots.empty())
     return false;
-  assert(User2SCEV->getStepRecurrence(*SE)->isOne() &&
-         "Invalid non-unit step for multiplicative scaling");
-  LoopIncs.push_back(User2);
-
-  if (const SCEVConstant *MulScale =
-      dyn_cast<SCEVConstant>(User1SCEV->getStepRecurrence(*SE))) {
-    // Make sure that both the start and step have the same multiplier.
-    if (RealIVSCEV->getStart()->getType() != MulScale->getType())
-      return false;
-    if (SE->getMulExpr(RealIVSCEV->getStart(), MulScale) !=
-        User1SCEV->getStart())
+  
+  assert(Roots.find(0) == Roots.end() && "Didn't expect a zero index!");
+
+  // If we found non-loop-inc, non-root users of Base, assume they are
+  // for the zeroth root index. This is because "add %a, 0" gets optimized
+  // away.
+  if (BaseUsers.size())
+    Roots[0] = Base;
+
+  // Calculate the number of users of the base, or lowest indexed, iteration.
+  unsigned NumBaseUses = BaseUsers.size();
+  if (NumBaseUses == 0)
+    NumBaseUses = Roots.begin()->second->getNumUses();
+  
+  // Check that every node has the same number of users.
+  for (auto &KV : Roots) {
+    if (KV.first == 0)
+      continue;
+    if (KV.second->getNumUses() != NumBaseUses) {
+      DEBUG(dbgs() << "LRR: Aborting - Root and Base #users not the same: "
+            << "#Base=" << NumBaseUses << ", #Root=" <<
+            KV.second->getNumUses() << "\n");
       return false;
+    }
+  }
+
+  return true; 
+}
+
+bool LoopReroll::DAGRootTracker::
+findRootsRecursive(Instruction *I, SmallInstructionSet SubsumedInsts) {
+  // Does the user look like it could be part of a root set?
+  // All its users must be simple arithmetic ops.
+  if (I->getNumUses() > IL_MaxRerollIterations)
+    return false;
 
-    ConstantInt *MulScaleCI = MulScale->getValue();
-    if (!MulScaleCI->uge(2) || MulScaleCI->uge(MaxInc))
+  if ((I->getOpcode() == Instruction::Mul ||
+       I->getOpcode() == Instruction::PHI) &&
+      I != IV &&
+      findRootsBase(I, SubsumedInsts))
+    return true;
+
+  SubsumedInsts.insert(I);
+
+  for (User *V : I->users()) {
+    Instruction *I = dyn_cast<Instruction>(V);
+    if (std::find(LoopIncs.begin(), LoopIncs.end(), I) != LoopIncs.end())
+      continue;
+
+    if (!I || !isSimpleArithmeticOp(I) ||
+        !findRootsRecursive(I, SubsumedInsts))
       return false;
-    Scale = MulScaleCI->getZExtValue();
-    IV = User1;
-  } else
+  }
+  return true;
+}
+
+bool LoopReroll::DAGRootTracker::
+findRootsBase(Instruction *IVU, SmallInstructionSet SubsumedInsts) {
+
+  // The base instruction needs to be a multiply (so that we can erase it)
+  // or a PHI.
+  if (IVU->getOpcode() != Instruction::Mul &&
+      IVU->getOpcode() != Instruction::PHI)
     return false;
 
-  DEBUG(dbgs() << "LRR: Found possible scaling " << *User1 << "\n");
+  std::map<int64_t, Instruction*> V;
+  if (!collectPossibleRoots(IVU, V))
+    return false;
 
-  assert(Scale <= MaxInc && "Scale is too large");
-  assert(Scale > 1 && "Scale must be at least 2");
+  // If we didn't get a root for index zero, then IVU must be 
+  // subsumed.
+  if (V.find(0) == V.end())
+    SubsumedInsts.insert(IVU);
+
+  // Partition the map into runs of consecutive indexes.
+  DAGRootSet DRS;
+  DRS.BaseInst = nullptr;
+
+  for (auto &KV : V) {
+    if (!DRS.BaseInst) {
+      DRS.BaseInst = KV.second;
+      DRS.SubsumedInsts = SubsumedInsts;
+    } else if (DRS.Roots.empty()) {
+      DRS.Roots.push_back(KV.second);
+    } else if (V.find(KV.first - 1) != V.end()) {
+      DRS.Roots.push_back(KV.second);
+    } else {
+      // Linear sequence terminated.
+      RootSets.push_back(DRS);
+      DRS.BaseInst = KV.second;
+      DRS.SubsumedInsts = SubsumedInsts;
+      DRS.Roots.clear();
+    }
+  }
+  RootSets.push_back(DRS);
 
   return true;
 }
 
-// Collect all root increments with respect to the provided induction variable
-// (normally the PHI, but sometimes a multiply). A root increment is an
-// instruction, normally an add, with a positive constant less than Scale. In a
-// rerollable loop, each of these increments is the root of an instruction
-// graph isomorphic to the others. Also, we collect the final induction
-// increment (the increment equal to the Scale), and its users in LoopIncs.
-bool LoopReroll::DAGRootTracker::collectAllRoots() {
-  Roots.resize(Scale-1);
-  
-  for (User *U : IV->users()) {
-    Instruction *UI = cast<Instruction>(U);
-    if (!SE->isSCEVable(UI->getType()))
-      continue;
-    if (UI->getType() != IV->getType())
-      continue;
-    if (!L->contains(UI))
-      continue;
-    if (hasUsesOutsideLoop(UI, L))
-      continue;
+bool LoopReroll::DAGRootTracker::findRoots() {
 
-    if (const SCEVConstant *Diff = dyn_cast<SCEVConstant>(SE->getMinusSCEV(
-          SE->getSCEV(UI), SE->getSCEV(IV)))) {
-      uint64_t Idx = Diff->getValue()->getValue().getZExtValue();
-      if (Idx > 0 && Idx < Scale) {
-        if (Roots[Idx-1])
-          // No duplicates allowed.
-          return false;
-        Roots[Idx-1] = UI;
-      } else if (Idx == Scale && Inc > 1) {
-        LoopIncs.push_back(UI);
-      }
+  const SCEVAddRecExpr *RealIVSCEV = cast<SCEVAddRecExpr>(SE->getSCEV(IV));
+  Inc = cast<SCEVConstant>(RealIVSCEV->getOperand(1))->
+    getValue()->getZExtValue();
+
+  assert(RootSets.empty() && "Unclean state!");
+  if (Inc == 1) {
+    for (auto *IVU : IV->users()) {
+      if (isLoopIncrement(IVU, IV))
+        LoopIncs.push_back(cast<Instruction>(IVU));
+    }
+    if (!findRootsRecursive(IV, SmallInstructionSet()))
+      return false;
+    LoopIncs.push_back(IV);
+  } else {
+    if (!findRootsBase(IV, SmallInstructionSet()))
+      return false;
+  }
+
+  // Ensure all sets have the same size.
+  if (RootSets.empty()) {
+    DEBUG(dbgs() << "LRR: Aborting because no root sets found!\n");
+    return false;
+  }
+  for (auto &V : RootSets) {
+    if (V.Roots.empty() || V.Roots.size() != RootSets[0].Roots.size()) {
+      DEBUG(dbgs()
+            << "LRR: Aborting because not all root sets have the same size\n");
+      return false;
     }
   }
 
-  for (unsigned i = 0; i < Scale-1; ++i) {
-    if (!Roots[i])
+  // And ensure all loop iterations are consecutive. We rely on std::map
+  // providing ordered traversal.
+  for (auto &V : RootSets) {
+    const auto *ADR = dyn_cast<SCEVAddRecExpr>(SE->getSCEV(V.BaseInst));
+    if (!ADR)
       return false;
+
+    // Consider a DAGRootSet with N-1 roots (so N different values including
+    //   BaseInst).
+    // Define d = Roots[0] - BaseInst, which should be the same as
+    //   Roots[I] - Roots[I-1] for all I in [1..N).
+    // Define D = BaseInst@J - BaseInst@J-1, where "@J" means the value at the
+    //   loop iteration J.
+    //
+    // Now, for the loop iterations to be consecutive:
+    //   D = d * N
+
+    unsigned N = V.Roots.size() + 1;
+    const SCEV *StepSCEV = SE->getMinusSCEV(SE->getSCEV(V.Roots[0]), ADR);
+    const SCEV *ScaleSCEV = SE->getConstant(StepSCEV->getType(), N);
+    if (ADR->getStepRecurrence(*SE) != SE->getMulExpr(StepSCEV, ScaleSCEV)) {
+      DEBUG(dbgs() << "LRR: Aborting because iterations are not consecutive\n");
+      return false;
+    }
+  }
+  Scale = RootSets[0].Roots.size() + 1;
+
+  if (Scale > IL_MaxRerollIterations) {
+    DEBUG(dbgs() << "LRR: Aborting - too many iterations found. "
+          << "#Found=" << Scale << ", #Max=" << IL_MaxRerollIterations
+          << "\n");
+    return false;
   }
 
+  DEBUG(dbgs() << "LRR: Successfully found roots: Scale=" << Scale << "\n");
+
   return true;
 }
 
@@ -746,43 +885,57 @@ bool LoopReroll::DAGRootTracker::collect
   }
 
   SmallInstructionSet Exclude;
-  Exclude.insert(Roots.begin(), Roots.end());
+  for (auto &DRS : RootSets) {
+    Exclude.insert(DRS.Roots.begin(), DRS.Roots.end());
+    Exclude.insert(DRS.SubsumedInsts.begin(), DRS.SubsumedInsts.end());
+    Exclude.insert(DRS.BaseInst);
+  }
   Exclude.insert(LoopIncs.begin(), LoopIncs.end());
 
-  DenseSet<Instruction*> VBase;
-  collectInLoopUserSet(IV, Exclude, PossibleRedSet, VBase);
-  for (auto *I : VBase) {
-    Uses[I].set(0);
-  }
+  for (auto &DRS : RootSets) {
+    DenseSet<Instruction*> VBase;
+    collectInLoopUserSet(DRS.BaseInst, Exclude, PossibleRedSet, VBase);
+    for (auto *I : VBase) {
+      Uses[I].set(0);
+    }
 
-  unsigned Idx = 1;
-  for (auto *Root : Roots) {
-    DenseSet<Instruction*> V;
-    collectInLoopUserSet(Root, Exclude, PossibleRedSet, V);
-
-    // While we're here, check the use sets are the same size.
-    if (V.size() != VBase.size()) {
-      DEBUG(dbgs() << "LRR: Aborting - use sets are different sizes\n");
-      return false;
+    unsigned Idx = 1;
+    for (auto *Root : DRS.Roots) {
+      DenseSet<Instruction*> V;
+      collectInLoopUserSet(Root, Exclude, PossibleRedSet, V);
+
+      // While we're here, check the use sets are the same size.
+      if (V.size() != VBase.size()) {
+        DEBUG(dbgs() << "LRR: Aborting - use sets are different sizes\n");
+        return false;
+      }
+
+      for (auto *I : V) {
+        Uses[I].set(Idx);
+      }
+      ++Idx;
     }
 
-    for (auto *I : V) {
-      Uses[I].set(Idx);
+    // Make sure our subsumed instructions are remembered too.
+    for (auto *I : DRS.SubsumedInsts) {
+      Uses[I].set(IL_All);
     }
-    ++Idx;
   }
 
   // Make sure the loop increments are also accounted for.
+
   Exclude.clear();
-  Exclude.insert(Roots.begin(), Roots.end());
+  for (auto &DRS : RootSets) {
+    Exclude.insert(DRS.Roots.begin(), DRS.Roots.end());
+    Exclude.insert(DRS.SubsumedInsts.begin(), DRS.SubsumedInsts.end());
+    Exclude.insert(DRS.BaseInst);
+  }
 
   DenseSet<Instruction*> V;
   collectInLoopUserSet(LoopIncs, Exclude, PossibleRedSet, V);
   for (auto *I : V) {
-    Uses[I].set(IL_LoopIncIdx);
+    Uses[I].set(IL_All);
   }
-  if (IV != RealIV)
-    Uses[RealIV].set(IL_LoopIncIdx);
 
   return true;
 
@@ -796,6 +949,22 @@ LoopReroll::DAGRootTracker::nextInstr(in
   return I;
 }
 
+bool LoopReroll::DAGRootTracker::isBaseInst(Instruction *I) {
+  for (auto &DRS : RootSets) {
+    if (DRS.BaseInst == I)
+      return true;
+  }
+  return false;
+}
+
+bool LoopReroll::DAGRootTracker::isRootInst(Instruction *I) {
+  for (auto &DRS : RootSets) {
+    if (std::find(DRS.Roots.begin(), DRS.Roots.end(), I) != DRS.Roots.end())
+      return true;
+  }
+  return false;
+}
+
 bool LoopReroll::DAGRootTracker::validate(ReductionTracker &Reductions) {
   // We now need to check for equivalence of the use graph of each root with
   // that of the primary induction variable (excluding the roots). Our goal
@@ -823,7 +992,7 @@ bool LoopReroll::DAGRootTracker::validat
 
   // Make sure we mark the reduction PHIs as used in all iterations.
   for (auto *I : PossibleRedPHISet) {
-    Uses[I].set(IL_LoopIncIdx);
+    Uses[I].set(IL_All);
   }
 
   // Make sure all instructions in the loop are in one and only one
@@ -863,11 +1032,11 @@ bool LoopReroll::DAGRootTracker::validat
 
       // Skip over the IV or root instructions; only match their users.
       bool Continue = false;
-      if (BaseInst == RealIV || BaseInst == IV) {
+      if (isBaseInst(BaseInst)) {
         BaseIt = nextInstr(0, Uses, ++BaseIt);
         Continue = true;
       }
-      if (std::find(Roots.begin(), Roots.end(), RootInst) != Roots.end()) {
+      if (isRootInst(RootInst)) {
         LastRootIt = RootIt;
         RootIt = nextInstr(Iter, Uses, ++RootIt);
         Continue = true;
@@ -961,10 +1130,16 @@ bool LoopReroll::DAGRootTracker::validat
                 continue;
 
           DenseMap<Value *, Value *>::iterator BMI = BaseMap.find(Op2);
-          if (BMI != BaseMap.end())
+          if (BMI != BaseMap.end()) {
             Op2 = BMI->second;
-          else if (Roots[Iter-1] == (Instruction*) Op2)
-            Op2 = IV;
+          } else {
+            for (auto &DRS : RootSets) {
+              if (DRS.Roots[Iter-1] == (Instruction*) Op2) {
+                Op2 = DRS.BaseInst;
+                break;
+              }
+            }
+          }
 
           if (BaseInst->getOperand(Swapped ? unsigned(!j) : j) != Op2) {
             // If we've not already decided to swap the matched operands, and
@@ -1007,7 +1182,7 @@ bool LoopReroll::DAGRootTracker::validat
   }
 
   DEBUG(dbgs() << "LRR: Matched all iteration increments for " <<
-                  *RealIV << "\n");
+                  *IV << "\n");
 
   return true;
 }
@@ -1018,7 +1193,7 @@ void LoopReroll::DAGRootTracker::replace
   for (BasicBlock::reverse_iterator J = Header->rbegin();
        J != Header->rend();) {
     unsigned I = Uses[&*J].find_first();
-    if (I > 0 && I < IL_LoopIncIdx) {
+    if (I > 0 && I < IL_All) {
       Instruction *D = &*J;
       DEBUG(dbgs() << "LRR: removing: " << *D << "\n");
       D->eraseFromParent();
@@ -1028,54 +1203,53 @@ void LoopReroll::DAGRootTracker::replace
     ++J;
   }
 
-  // Insert the new induction variable.
-  const SCEVAddRecExpr *RealIVSCEV = cast<SCEVAddRecExpr>(SE->getSCEV(RealIV));
-  const SCEV *Start = RealIVSCEV->getStart();
-  if (Inc == 1)
-    Start = SE->getMulExpr(Start,
-                           SE->getConstant(Start->getType(), Scale));
-  const SCEVAddRecExpr *H =
-    cast<SCEVAddRecExpr>(SE->getAddRecExpr(Start,
-                           SE->getConstant(RealIVSCEV->getType(), 1),
-                           L, SCEV::FlagAnyWrap));
-  { // Limit the lifetime of SCEVExpander.
-    SCEVExpander Expander(*SE, "reroll");
-    Value *NewIV = Expander.expandCodeFor(H, IV->getType(), Header->begin());
+  // We need to create a new induction variable for each different BaseInst.
+  for (auto &DRS : RootSets) {
+    // Insert the new induction variable.
+    const SCEVAddRecExpr *RealIVSCEV =
+      cast<SCEVAddRecExpr>(SE->getSCEV(DRS.BaseInst));
+    const SCEV *Start = RealIVSCEV->getStart();
+    const SCEVAddRecExpr *H = cast<SCEVAddRecExpr>
+      (SE->getAddRecExpr(Start,
+                         SE->getConstant(RealIVSCEV->getType(), 1),
+                         L, SCEV::FlagAnyWrap));
+    { // Limit the lifetime of SCEVExpander.
+      SCEVExpander Expander(*SE, "reroll");
+      Value *NewIV = Expander.expandCodeFor(H, IV->getType(), Header->begin());
+
+      for (auto &KV : Uses) {
+        if (KV.second.find_first() == 0)
+          KV.first->replaceUsesOfWith(DRS.BaseInst, NewIV);
+      }
+
+      if (BranchInst *BI = dyn_cast<BranchInst>(Header->getTerminator())) {
+        // FIXME: Why do we need this check?
+        if (Uses[BI].find_first() == IL_All) {
+          const SCEV *ICSCEV = RealIVSCEV->evaluateAtIteration(IterCount, *SE);
+
+          // Iteration count SCEV minus 1
+          const SCEV *ICMinus1SCEV =
+            SE->getMinusSCEV(ICSCEV, SE->getConstant(ICSCEV->getType(), 1));
+
+          Value *ICMinus1; // Iteration count minus 1
+          if (isa<SCEVConstant>(ICMinus1SCEV)) {
+            ICMinus1 = Expander.expandCodeFor(ICMinus1SCEV, NewIV->getType(), BI);
+          } else {
+            BasicBlock *Preheader = L->getLoopPreheader();
+            if (!Preheader)
+              Preheader = InsertPreheaderForLoop(L, Parent);
 
-    for (auto &KV : Uses) {
-      if (KV.second.find_first() == 0)
-        KV.first->replaceUsesOfWith(IV, NewIV);
-    }
-
-    if (BranchInst *BI = dyn_cast<BranchInst>(Header->getTerminator())) {
-      // FIXME: Why do we need this check?
-      if (Uses[BI].find_first() == IL_LoopIncIdx) {
-        const SCEV *ICSCEV = RealIVSCEV->evaluateAtIteration(IterCount, *SE);
-        if (Inc == 1)
-          ICSCEV =
-            SE->getMulExpr(ICSCEV, SE->getConstant(ICSCEV->getType(), Scale));
-        // Iteration count SCEV minus 1
-        const SCEV *ICMinus1SCEV =
-          SE->getMinusSCEV(ICSCEV, SE->getConstant(ICSCEV->getType(), 1));
-
-        Value *ICMinus1; // Iteration count minus 1
-        if (isa<SCEVConstant>(ICMinus1SCEV)) {
-          ICMinus1 = Expander.expandCodeFor(ICMinus1SCEV, NewIV->getType(), BI);
-        } else {
-          BasicBlock *Preheader = L->getLoopPreheader();
-          if (!Preheader)
-            Preheader = InsertPreheaderForLoop(L, Parent);
-
-          ICMinus1 = Expander.expandCodeFor(ICMinus1SCEV, NewIV->getType(),
-                                            Preheader->getTerminator());
-        }
+            ICMinus1 = Expander.expandCodeFor(ICMinus1SCEV, NewIV->getType(),
+                                              Preheader->getTerminator());
+          }
 
-        Value *Cond =
+          Value *Cond =
             new ICmpInst(BI, CmpInst::ICMP_EQ, NewIV, ICMinus1, "exitcond");
-        BI->setCondition(Cond);
+          BI->setCondition(Cond);
 
-        if (BI->getSuccessor(1) != Header)
-          BI->swapSuccessors();
+          if (BI->getSuccessor(1) != Header)
+            BI->swapSuccessors();
+        }
       }
     }
   }

Modified: llvm/trunk/test/Transforms/LoopReroll/basic.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/Transforms/LoopReroll/basic.ll?rev=228821&r1=228820&r2=228821&view=diff
==============================================================================
--- llvm/trunk/test/Transforms/LoopReroll/basic.ll (original)
+++ llvm/trunk/test/Transforms/LoopReroll/basic.ll Wed Feb 11 03:19:47 2015
@@ -322,6 +322,173 @@ for.end:
   ret void
 }
 
+; void multi1(int *x) {
+;   y = foo(0)
+;   for (int i = 0; i < 500; ++i) {
+;     x[3*i] = y;
+;     x[3*i+1] = y;
+;     x[3*i+2] = y;
+;     x[3*i+6] = y;
+;     x[3*i+7] = y;
+;     x[3*i+8] = y;
+;   }
+; }
+
+; Function Attrs: nounwind uwtable
+define void @multi1(i32* nocapture %x) #0 {
+entry:
+  %call = tail call i32 @foo(i32 0) #1
+  br label %for.body
+
+for.body:                                         ; preds = %for.body, %entry
+  %indvars.iv = phi i64 [ 0, %entry ], [ %indvars.iv.next, %for.body ]
+  %0 = mul nsw i64 %indvars.iv, 3
+  %arrayidx = getelementptr inbounds i32* %x, i64 %0
+  store i32 %call, i32* %arrayidx, align 4
+  %1 = add nsw i64 %0, 1
+  %arrayidx4 = getelementptr inbounds i32* %x, i64 %1
+  store i32 %call, i32* %arrayidx4, align 4
+  %2 = add nsw i64 %0, 2
+  %arrayidx9 = getelementptr inbounds i32* %x, i64 %2
+  store i32 %call, i32* %arrayidx9, align 4
+  %3 = add nsw i64 %0, 6
+  %arrayidx6 = getelementptr inbounds i32* %x, i64 %3
+  store i32 %call, i32* %arrayidx6, align 4
+  %4 = add nsw i64 %0, 7
+  %arrayidx7 = getelementptr inbounds i32* %x, i64 %4
+  store i32 %call, i32* %arrayidx7, align 4
+  %5 = add nsw i64 %0, 8
+  %arrayidx8 = getelementptr inbounds i32* %x, i64 %5
+  store i32 %call, i32* %arrayidx8, align 4
+  %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1
+  %exitcond = icmp eq i64 %indvars.iv.next, 500
+  br i1 %exitcond, label %for.end, label %for.body
+
+; CHECK-LABEL: @multi1
+
+; CHECK:for.body:
+; CHECK:  %indvars.iv = phi i64 [ 0, %entry ], [ %indvars.iv.next, %for.body ]
+; CHECK:  %0 = add i64 %indvars.iv, 6
+; CHECK:  %arrayidx = getelementptr inbounds i32* %x, i64 %indvars.iv
+; CHECK:  store i32 %call, i32* %arrayidx, align 4
+; CHECK:  %arrayidx6 = getelementptr inbounds i32* %x, i64 %0
+; CHECK:  store i32 %call, i32* %arrayidx6, align 4
+; CHECK:  %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1
+; CHECK:  %exitcond2 = icmp eq i64 %0, 1505
+; CHECK:  br i1 %exitcond2, label %for.end, label %for.body
+
+for.end:                                          ; preds = %for.body
+  ret void
+}
+
+; void multi2(int *x) {
+;   y = foo(0)
+;   for (int i = 0; i < 500; ++i) {
+;     x[3*i] = y;
+;     x[3*i+1] = y;
+;     x[3*i+2] = y;
+;     x[3*(i+1)] = y;
+;     x[3*(i+1)+1] = y;
+;     x[3*(i+1)+2] = y;
+;   }
+; }
+
+; Function Attrs: nounwind uwtable
+define void @multi2(i32* nocapture %x) #0 {
+entry:
+  %call = tail call i32 @foo(i32 0) #1
+  br label %for.body
+
+for.body:                                         ; preds = %for.body, %entry
+  %indvars.iv = phi i64 [ 0, %entry ], [ %indvars.iv.next, %for.body ]
+  %0 = mul nsw i64 %indvars.iv, 3
+  %add = add nsw i64 %indvars.iv, 1
+  %newmul = mul nsw i64 %add, 3
+  %arrayidx = getelementptr inbounds i32* %x, i64 %0
+  store i32 %call, i32* %arrayidx, align 4
+  %1 = add nsw i64 %0, 1
+  %arrayidx4 = getelementptr inbounds i32* %x, i64 %1
+  store i32 %call, i32* %arrayidx4, align 4
+  %2 = add nsw i64 %0, 2
+  %arrayidx9 = getelementptr inbounds i32* %x, i64 %2
+  store i32 %call, i32* %arrayidx9, align 4
+  %arrayidx6 = getelementptr inbounds i32* %x, i64 %newmul
+  store i32 %call, i32* %arrayidx6, align 4
+  %3 = add nsw i64 %newmul, 1
+  %arrayidx7 = getelementptr inbounds i32* %x, i64 %3
+  store i32 %call, i32* %arrayidx7, align 4
+  %4 = add nsw i64 %newmul, 2
+  %arrayidx8 = getelementptr inbounds i32* %x, i64 %4
+  store i32 %call, i32* %arrayidx8, align 4
+  %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1
+  %exitcond = icmp eq i64 %indvars.iv.next, 500
+  br i1 %exitcond, label %for.end, label %for.body
+
+; CHECK-LABEL: @multi2
+
+; CHECK:for.body:
+; CHECK:  %indvars.iv = phi i64 [ 0, %entry ], [ %indvars.iv.next, %for.body ]
+; CHECK:  %0 = add i64 %indvars.iv, 3
+; CHECK:  %arrayidx = getelementptr inbounds i32* %x, i64 %indvars.iv
+; CHECK:  store i32 %call, i32* %arrayidx, align 4
+; CHECK:  %arrayidx6 = getelementptr inbounds i32* %x, i64 %0
+; CHECK:  store i32 %call, i32* %arrayidx6, align 4
+; CHECK:  %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1
+; CHECK:  %exitcond2 = icmp eq i64 %indvars.iv, 1499
+; CHECK:  br i1 %exitcond2, label %for.end, label %for.body
+
+for.end:                                          ; preds = %for.body
+  ret void
+}
+
+; void multi3(int *x) {
+;   y = foo(0)
+;   for (int i = 0; i < 500; ++i) {
+;     // Note: No zero index
+;     x[3*i+3] = y;
+;     x[3*i+4] = y;
+;     x[3*i+5] = y;
+;   }
+; }
+
+; Function Attrs: nounwind uwtable
+define void @multi3(i32* nocapture %x) #0 {
+entry:
+  %call = tail call i32 @foo(i32 0) #1
+  br label %for.body
+
+for.body:                                         ; preds = %for.body, %entry
+  %indvars.iv = phi i64 [ 0, %entry ], [ %indvars.iv.next, %for.body ]
+  %0 = mul nsw i64 %indvars.iv, 3
+  %x0 = add nsw i64 %0, 3
+  %add = add nsw i64 %indvars.iv, 1
+  %arrayidx = getelementptr inbounds i32* %x, i64 %x0
+  store i32 %call, i32* %arrayidx, align 4
+  %1 = add nsw i64 %0, 4
+  %arrayidx4 = getelementptr inbounds i32* %x, i64 %1
+  store i32 %call, i32* %arrayidx4, align 4
+  %2 = add nsw i64 %0, 5
+  %arrayidx9 = getelementptr inbounds i32* %x, i64 %2
+  store i32 %call, i32* %arrayidx9, align 4
+  %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1
+  %exitcond = icmp eq i64 %indvars.iv.next, 500
+  br i1 %exitcond, label %for.end, label %for.body
+
+; CHECK-LABEL: @multi3
+; CHECK: for.body:
+; CHECK:   %indvars.iv = phi i64 [ 0, %entry ], [ %indvars.iv.next, %for.body ]
+; CHECK:   %0 = add i64 %indvars.iv, 3
+; CHECK:   %arrayidx = getelementptr inbounds i32* %x, i64 %0
+; CHECK:   store i32 %call, i32* %arrayidx, align 4
+; CHECK:   %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1
+; CHECK:   %exitcond1 = icmp eq i64 %0, 1502
+; CHECK:   br i1 %exitcond1, label %for.end, label %for.body
+
+for.end:                                          ; preds = %for.body
+  ret void
+}
+
+
 attributes #0 = { nounwind uwtable }
 attributes #1 = { nounwind }
 





More information about the llvm-commits mailing list