[polly] r246142 - Use ISL to Determine Loop Trip Count

Johannes Doerfert via llvm-commits llvm-commits at lists.llvm.org
Wed Aug 26 23:53:53 PDT 2015


Author: jdoerfert
Date: Thu Aug 27 01:53:52 2015
New Revision: 246142

URL: http://llvm.org/viewvc/llvm-project?rev=246142&view=rev
Log:
Use ISL to Determine Loop Trip Count

  Use ISL to compute the loop trip count when scalar evolution is unable to do
  so.

Contributed-by: Matthew Simpson <mssimpso at codeaurora.org>

Differential Revision: http://reviews.llvm.org/D9444


Added:
    polly/trunk/test/ScopInfo/isl_trip_count_01.ll
    polly/trunk/test/ScopInfo/isl_trip_count_02.ll
Modified:
    polly/trunk/include/polly/ScopDetection.h
    polly/trunk/include/polly/ScopInfo.h
    polly/trunk/include/polly/Support/SCEVAffinator.h
    polly/trunk/lib/Analysis/ScopDetection.cpp
    polly/trunk/lib/Analysis/ScopInfo.cpp
    polly/trunk/lib/Support/SCEVAffinator.cpp

Modified: polly/trunk/include/polly/ScopDetection.h
URL: http://llvm.org/viewvc/llvm-project/polly/trunk/include/polly/ScopDetection.h?rev=246142&r1=246141&r2=246142&view=diff
==============================================================================
--- polly/trunk/include/polly/ScopDetection.h (original)
+++ polly/trunk/include/polly/ScopDetection.h Thu Aug 27 01:53:52 2015
@@ -299,6 +299,14 @@ private:
   /// @note An OpenMP subfunction will be marked as invalid.
   bool isValidFunction(llvm::Function &F);
 
+  /// @brief Check whether ISL can compute the trip count of a loop.
+  ///
+  /// @param L The loop to check.
+  /// @param Context The context of scop detection.
+  ///
+  /// @return True if ISL can compute the trip count of the loop.
+  bool canUseISLTripCount(Loop *L, DetectionContext &Context) const;
+
   /// @brief Print the locations of all detected scops.
   void printLocations(llvm::Function &F);
 

Modified: polly/trunk/include/polly/ScopInfo.h
URL: http://llvm.org/viewvc/llvm-project/polly/trunk/include/polly/ScopInfo.h?rev=246142&r1=246141&r2=246142&view=diff
==============================================================================
--- polly/trunk/include/polly/ScopInfo.h (original)
+++ polly/trunk/include/polly/ScopInfo.h Thu Aug 27 01:53:52 2015
@@ -557,6 +557,7 @@ private:
   __isl_give isl_set *buildConditionSet(const Comparison &Cmp);
   void addConditionsToDomain(TempScop &tempScop, const Region &CurRegion);
   void addLoopBoundsToDomain(TempScop &tempScop);
+  void addLoopTripCountToDomain(const Loop *L);
   void buildDomain(TempScop &tempScop, const Region &CurRegion);
 
   /// @brief Create the accesses for instructions in @p Block.
@@ -1221,6 +1222,9 @@ public:
   ///
   /// @return true if a change was made
   bool restrictDomains(__isl_take isl_union_set *Domain);
+
+  /// @brief Get the depth of a loop relative to the outermost loop in the Scop.
+  unsigned getRelativeLoopDepth(const Loop *L) const;
 };
 
 /// @brief Print Scop scop to raw_ostream O.

Modified: polly/trunk/include/polly/Support/SCEVAffinator.h
URL: http://llvm.org/viewvc/llvm-project/polly/trunk/include/polly/Support/SCEVAffinator.h?rev=246142&r1=246141&r2=246142&view=diff
==============================================================================
--- polly/trunk/include/polly/Support/SCEVAffinator.h (original)
+++ polly/trunk/include/polly/Support/SCEVAffinator.h Thu Aug 27 01:53:52 2015
@@ -71,8 +71,6 @@ private:
   llvm::ScalarEvolution &SE;
   const ScopStmt *Stmt;
 
-  int getLoopDepth(const llvm::Loop *L);
-
   __isl_give isl_pw_aff *visit(const llvm::SCEV *E);
   __isl_give isl_pw_aff *visitConstant(const llvm::SCEVConstant *E);
   __isl_give isl_pw_aff *visitTruncateExpr(const llvm::SCEVTruncateExpr *E);

Modified: polly/trunk/lib/Analysis/ScopDetection.cpp
URL: http://llvm.org/viewvc/llvm-project/polly/trunk/lib/Analysis/ScopDetection.cpp?rev=246142&r1=246141&r2=246142&view=diff
==============================================================================
--- polly/trunk/lib/Analysis/ScopDetection.cpp (original)
+++ polly/trunk/lib/Analysis/ScopDetection.cpp Thu Aug 27 01:53:52 2015
@@ -166,6 +166,11 @@ static cl::opt<bool>
                 cl::Hidden, cl::init(false), cl::ZeroOrMore,
                 cl::cat(PollyCategory));
 
+static cl::opt<bool> AllowNonSCEVBackedgeTakenCount(
+    "polly-allow-non-scev-backedge-taken-count",
+    cl::desc("Allow loops even if SCEV cannot provide a trip count"),
+    cl::Hidden, cl::init(true), cl::ZeroOrMore, cl::cat(PollyCategory));
+
 bool polly::PollyTrackFailures = false;
 bool polly::PollyDelinearize = false;
 StringRef polly::PollySkipFnAttr = "polly.skip.fn";
@@ -728,10 +733,60 @@ bool ScopDetection::isValidInstruction(I
   return invalid<ReportUnknownInst>(Context, /*Assert=*/true, &Inst);
 }
 
+bool ScopDetection::canUseISLTripCount(Loop *L,
+                                       DetectionContext &Context) const {
+  Region &CurRegion = Context.CurRegion;
+
+  // Ensure the loop has a single back edge.
+  if (L->getNumBackEdges() != 1)
+    return false;
+
+  // Ensure the loop has a single exiting block.
+  BasicBlock *ExitingBB = L->getExitingBlock();
+  if (!ExitingBB)
+    return false;
+
+  // Ensure the exiting block is terminated by a conditional branch.
+  BranchInst *Term = dyn_cast<BranchInst>(ExitingBB->getTerminator());
+  if (!Term || !Term->isConditional())
+    return false;
+
+  Value *Cond = Term->getCondition();
+
+  // A constant terminating condition can always be handled.
+  if (isa<ConstantInt>(Cond))
+    return true;
+
+  // Otherwise the condition must be an integer comparison. If one side of
+  // the comparison is a recurrence, the other side must be invariant in the
+  // region.
+  ICmpInst *Cmp = dyn_cast<ICmpInst>(Cond);
+  if (!Cmp)
+    return false;
+
+  Value *Op0 = Cmp->getOperand(0);
+  Value *Op1 = Cmp->getOperand(1);
+  const SCEV *LHS = SE->getSCEVAtScope(Op0, L);
+  const SCEV *RHS = SE->getSCEVAtScope(Op1, L);
+  if ((isa<SCEVAddRecExpr>(LHS) && !isInvariant(*Op1, CurRegion)) ||
+      (isa<SCEVAddRecExpr>(RHS) && !isInvariant(*Op0, CurRegion)))
+    return false;
+
+  // We can use ISL to compute the trip count of L.
+  return true;
+}
+
 bool ScopDetection::isValidLoop(Loop *L, DetectionContext &Context) const {
   // Is the loop count affine?
+  bool IsLoopCountAffine = false;
   const SCEV *LoopCount = SE->getBackedgeTakenCount(L);
-  if (isAffineExpr(&Context.CurRegion, LoopCount, *SE)) {
+  // If SCEV cannot compute the trip count, let ISL compute it instead, unless
+  // -polly-allow-non-scev-backedge-taken-count=false disables this fallback.
+  if (!isa<SCEVCouldNotCompute>(LoopCount))
+    IsLoopCountAffine = isAffineExpr(&Context.CurRegion, LoopCount, *SE);
+  else if (AllowNonSCEVBackedgeTakenCount)
+    IsLoopCountAffine = canUseISLTripCount(L, Context);
+  if (IsLoopCountAffine) {
     Context.hasAffineLoops = true;
     return true;
   }

Modified: polly/trunk/lib/Analysis/ScopInfo.cpp
URL: http://llvm.org/viewvc/llvm-project/polly/trunk/lib/Analysis/ScopInfo.cpp?rev=246142&r1=246141&r2=246142&view=diff
==============================================================================
--- polly/trunk/lib/Analysis/ScopInfo.cpp (original)
+++ polly/trunk/lib/Analysis/ScopInfo.cpp Thu Aug 27 01:53:52 2015
@@ -739,6 +739,65 @@ void ScopStmt::realignParams() {
   Domain = isl_set_align_params(Domain, Parent.getParamSpace());
 }
 
+void ScopStmt::addLoopTripCountToDomain(const Loop *L) {
+
+  unsigned loopDimension = getParent()->getRelativeLoopDepth(L);
+  ScalarEvolution *SE = getParent()->getSE();
+  isl_space *DomSpace = isl_set_get_space(Domain);
+
+  isl_space *MapSpace = isl_space_map_from_set(isl_space_copy(DomSpace));
+  isl_multi_aff *LoopMAff = isl_multi_aff_identity(MapSpace);
+  isl_aff *LoopAff = isl_multi_aff_get_aff(LoopMAff, loopDimension);
+  LoopAff = isl_aff_add_constant_si(LoopAff, 1);
+  LoopMAff = isl_multi_aff_set_aff(LoopMAff, loopDimension, LoopAff);
+  isl_map *TranslationMap = isl_map_from_multi_aff(LoopMAff);
+
+  BasicBlock *ExitingBB = L->getExitingBlock();
+  assert(ExitingBB && "Loop has more than one exiting block");
+
+  BranchInst *Term = dyn_cast<BranchInst>(ExitingBB->getTerminator());
+  assert(Term && Term->isConditional() && "Terminator is not conditional");
+
+  const SCEV *LHS = nullptr;
+  const SCEV *RHS = nullptr;
+  Value *Cond = Term->getCondition();
+  CmpInst::Predicate Pred = CmpInst::Predicate::BAD_ICMP_PREDICATE;
+
+  ICmpInst *CondICmpInst = dyn_cast<ICmpInst>(Cond);
+  ConstantInt *CondConstant = dyn_cast<ConstantInt>(Cond);
+  if (CondICmpInst) {
+    LHS = SE->getSCEVAtScope(CondICmpInst->getOperand(0), L);
+    RHS = SE->getSCEVAtScope(CondICmpInst->getOperand(1), L);
+    Pred = CondICmpInst->getPredicate();
+  } else if (CondConstant) {
+    LHS = SE->getConstant(CondConstant);
+    RHS = SE->getConstant(ConstantInt::getTrue(SE->getContext()));
+    Pred = CmpInst::Predicate::ICMP_EQ;
+  } else {
+    llvm_unreachable("Condition is neither a ConstantInt nor a ICmpInst");
+  }
+
+  if (!L->contains(Term->getSuccessor(0)))
+    Pred = ICmpInst::getInversePredicate(Pred);
+  Comparison Comp(LHS, RHS, Pred);
+
+  isl_set *CondSet = buildConditionSet(Comp);
+  isl_map *ForwardMap = isl_map_lex_le(isl_space_copy(DomSpace));
+  for (unsigned i = 0; i < isl_set_n_dim(Domain); i++)
+    if (i != loopDimension)
+      ForwardMap = isl_map_equate(ForwardMap, isl_dim_in, i, isl_dim_out, i);
+
+  ForwardMap = isl_map_apply_range(ForwardMap, isl_map_copy(TranslationMap));
+  isl_set *CondDom = isl_set_subtract(isl_set_copy(Domain), CondSet);
+  isl_set *ForwardCond = isl_set_apply(CondDom, isl_map_copy(ForwardMap));
+  isl_set *ForwardDomain = isl_set_apply(isl_set_copy(Domain), ForwardMap);
+  ForwardCond = isl_set_gist(ForwardCond, ForwardDomain);
+  Domain = isl_set_subtract(Domain, ForwardCond);
+
+  isl_map_free(TranslationMap);
+  isl_space_free(DomSpace);
+}
+
 __isl_give isl_set *ScopStmt::buildConditionSet(const Comparison &Comp) {
   isl_pw_aff *L = getPwAff(Comp.getLHS());
   isl_pw_aff *R = getPwAff(Comp.getRHS());
@@ -789,9 +848,15 @@ void ScopStmt::addLoopBoundsToDomain(Tem
     // IV <= LatchExecutions.
     const Loop *L = getLoopForDimension(i);
     const SCEV *LatchExecutions = SE->getBackedgeTakenCount(L);
-    isl_pw_aff *UpperBound = getPwAff(LatchExecutions);
-    isl_set *UpperBoundSet = isl_pw_aff_le_set(IV, UpperBound);
-    Domain = isl_set_intersect(Domain, UpperBoundSet);
+    if (!isa<SCEVCouldNotCompute>(LatchExecutions)) {
+      isl_pw_aff *UpperBound = getPwAff(LatchExecutions);
+      isl_set *UpperBoundSet = isl_pw_aff_le_set(IV, UpperBound);
+      Domain = isl_set_intersect(Domain, UpperBoundSet);
+    } else {
+      // If SCEV cannot provide a loop trip count we compute it with ISL.
+      addLoopTripCountToDomain(L);
+      isl_pw_aff_free(IV);
+    }
   }
 
   isl_local_space_free(LocalSpace);
@@ -2059,6 +2124,12 @@ ScopStmt *Scop::getStmtForBasicBlock(Bas
   return StmtMapIt->second;
 }
 
+unsigned Scop::getRelativeLoopDepth(const Loop *L) const {
+  Loop *OuterLoop = R.outermostLoopInRegion(const_cast<Loop *>(L));
+  assert(OuterLoop && "Scop does not contain this loop");
+  return L->getLoopDepth() - OuterLoop->getLoopDepth();
+}
+
 //===----------------------------------------------------------------------===//
 ScopInfo::ScopInfo() : RegionPass(ID), scop(0) {
   ctx = isl_ctx_alloc();

Modified: polly/trunk/lib/Support/SCEVAffinator.cpp
URL: http://llvm.org/viewvc/llvm-project/polly/trunk/lib/Support/SCEVAffinator.cpp?rev=246142&r1=246141&r2=246142&view=diff
==============================================================================
--- polly/trunk/lib/Support/SCEVAffinator.cpp (original)
+++ polly/trunk/lib/Support/SCEVAffinator.cpp Thu Aug 27 01:53:52 2015
@@ -160,7 +160,7 @@ SCEVAffinator::visitAddRecExpr(const SCE
     isl_space *Space = isl_space_set_alloc(Ctx, 0, NumIterators);
     isl_local_space *LocalSpace = isl_local_space_from_space(Space);
 
-    int loopDimension = getLoopDepth(Expr->getLoop());
+    unsigned loopDimension = S->getRelativeLoopDepth(Expr->getLoop());
 
     isl_aff *LAff = isl_aff_set_coefficient_si(
         isl_aff_zero_on_domain(LocalSpace), isl_dim_in, loopDimension, 1);
@@ -248,9 +248,3 @@ __isl_give isl_pw_aff *SCEVAffinator::vi
   llvm_unreachable(
       "Unknowns SCEV was neither parameter nor a valid instruction.");
 }
-
-int SCEVAffinator::getLoopDepth(const Loop *L) {
-  Loop *outerLoop = S->getRegion().outermostLoopInRegion(const_cast<Loop *>(L));
-  assert(outerLoop && "Scop does not contain this loop");
-  return L->getLoopDepth() - outerLoop->getLoopDepth();
-}

Added: polly/trunk/test/ScopInfo/isl_trip_count_01.ll
URL: http://llvm.org/viewvc/llvm-project/polly/trunk/test/ScopInfo/isl_trip_count_01.ll?rev=246142&view=auto
==============================================================================
--- polly/trunk/test/ScopInfo/isl_trip_count_01.ll (added)
+++ polly/trunk/test/ScopInfo/isl_trip_count_01.ll Thu Aug 27 01:53:52 2015
@@ -0,0 +1,38 @@
+; RUN: opt %loadPolly -polly-detect-unprofitable -polly-allow-non-scev-backedge-taken-count -polly-scops -analyze < %s | FileCheck %s
+;
+; CHECK: [M, N] -> { Stmt_while_body[i0] : i0 >= 0 and 4i0 <= -M + N }
+;
+;   void f(int *A, int N, int M) {
+;     int i = 0;
+;     while (M <= N) {
+;       A[i++] = 1;
+;       M += 4;
+;     }
+;   }
+;
+target datalayout = "e-m:e-p:32:32-i64:64-v128:64:128-a:0:32-n32-S64"
+
+define void @f(i32* nocapture %A, i32 %N, i32 %M) {
+entry:
+  %cmp3 = icmp sgt i32 %M, %N
+  br i1 %cmp3, label %while.end, label %while.body.preheader
+
+while.body.preheader:
+  br label %while.body
+
+while.body:
+  %i.05 = phi i32 [ %inc, %while.body ], [ 0, %while.body.preheader ]
+  %M.addr.04 = phi i32 [ %add, %while.body ], [ %M, %while.body.preheader ]
+  %inc = add nuw nsw i32 %i.05, 1
+  %arrayidx = getelementptr inbounds i32, i32* %A, i32 %i.05
+  store i32 1, i32* %arrayidx, align 4
+  %add = add nsw i32 %M.addr.04, 4
+  %cmp = icmp sgt i32 %add, %N
+  br i1 %cmp, label %while.end.loopexit, label %while.body
+
+while.end.loopexit:
+  br label %while.end
+
+while.end:
+  ret void
+}

Added: polly/trunk/test/ScopInfo/isl_trip_count_02.ll
URL: http://llvm.org/viewvc/llvm-project/polly/trunk/test/ScopInfo/isl_trip_count_02.ll?rev=246142&view=auto
==============================================================================
--- polly/trunk/test/ScopInfo/isl_trip_count_02.ll (added)
+++ polly/trunk/test/ScopInfo/isl_trip_count_02.ll Thu Aug 27 01:53:52 2015
@@ -0,0 +1,33 @@
+; RUN: opt %loadPolly -polly-detect-unprofitable -polly-allow-non-scev-backedge-taken-count -polly-scops -analyze < %s | FileCheck %s
+;
+; CHECK: [M, N] -> { Stmt_for_body[i0] : i0 >= 0 and N <= -1 + M };
+;
+;   void f(int *A, int N, int M) {
+;     for (int i = M; M > N; i++) // Exit condition is loop-invariant.
+;       A[i] = i;
+;   }
+;
+target datalayout = "e-m:e-p:32:32-i64:64-v128:64:128-a:0:32-n32-S64"
+
+define void @f(i32* %A, i32 %N, i32 %M) {
+entry:
+  br label %entry.split
+
+entry.split:
+  %cmp.1 = icmp sgt i32 %M, %N
+  br i1 %cmp.1, label %for.body, label %for.end
+
+for.body:
+  %indvars.iv = phi i32 [ %indvars.iv.next, %for.body ], [ %M, %entry.split ]
+  %arrayidx = getelementptr inbounds i32, i32* %A, i32 %indvars.iv
+  store i32 %indvars.iv, i32* %arrayidx, align 4
+  %cmp = icmp slt i32 %M, %N
+  %indvars.iv.next = add i32 %indvars.iv, 1
+  br i1 %cmp, label %for.cond.for.end_crit_edge, label %for.body
+
+for.cond.for.end_crit_edge:
+  br label %for.end
+
+for.end:
+  ret void
+}




More information about the llvm-commits mailing list