r206039 - [PGO] Change MapRegionCounters to be a RecursiveASTVisitor.

Bob Wilson bob.wilson at apple.com
Fri Apr 11 10:16:14 PDT 2014


Author: bwilson
Date: Fri Apr 11 12:16:13 2014
New Revision: 206039

URL: http://llvm.org/viewvc/llvm-project?rev=206039&view=rev
Log:
[PGO] Change MapRegionCounters to be a RecursiveASTVisitor.

This avoids the overhead of spelling out all of the traversal code by hand,
as ConstStmtVisitor required, and makes this code a lot easier to maintain.

Modified:
    cfe/trunk/lib/CodeGen/CodeGenPGO.cpp

Modified: cfe/trunk/lib/CodeGen/CodeGenPGO.cpp
URL: http://llvm.org/viewvc/llvm-project/cfe/trunk/lib/CodeGen/CodeGenPGO.cpp?rev=206039&r1=206038&r2=206039&view=diff
==============================================================================
--- cfe/trunk/lib/CodeGen/CodeGenPGO.cpp (original)
+++ cfe/trunk/lib/CodeGen/CodeGenPGO.cpp Fri Apr 11 12:16:13 2014
@@ -321,8 +321,8 @@ llvm::Function *CodeGenPGO::emitInitiali
 }
 
 namespace {
-  /// A StmtVisitor that fills a map of statements to PGO counters.
-  struct MapRegionCounters : public ConstStmtVisitor<MapRegionCounters> {
+  /// A RecursiveASTVisitor that fills a map of statements to PGO counters.
+  struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> {
     /// The next counter value to assign.
     unsigned NextCounter;
     /// The map of statements to counters.
@@ -331,135 +331,55 @@ namespace {
     MapRegionCounters(llvm::DenseMap<const Stmt *, unsigned> &CounterMap)
         : NextCounter(0), CounterMap(CounterMap) {}
 
-    void VisitChildren(const Stmt *S) {
-      for (Stmt::const_child_range I = S->children(); I; ++I)
-        if (*I)
-         this->Visit(*I);
+    // Do not traverse the BlockDecl inside a BlockExpr since each BlockDecl
+    // is handled as a separate function.
+    bool TraverseBlockExpr(BlockExpr *block) { return true; }
+
+    bool VisitDecl(const Decl *D) {
+      switch (D->getKind()) {
+      default:
+        break;
+      case Decl::Function:
+      case Decl::CXXMethod:
+      case Decl::CXXConstructor:
+      case Decl::CXXDestructor:
+      case Decl::CXXConversion:
+      case Decl::ObjCMethod:
+      case Decl::Block:
+        CounterMap[D->getBody()] = NextCounter++;
+        break;
+      }
+      return true;
     }
-    void VisitStmt(const Stmt *S) { VisitChildren(S); }
 
-    /// Assign a counter to track entry to the function body.
-    void VisitFunctionDecl(const FunctionDecl *D) {
-      CounterMap[D->getBody()] = NextCounter++;
-      Visit(D->getBody());
-    }
-    void VisitObjCMethodDecl(const ObjCMethodDecl *D) {
-      CounterMap[D->getBody()] = NextCounter++;
-      Visit(D->getBody());
-    }
-    void VisitBlockDecl(const BlockDecl *D) {
-      CounterMap[D->getBody()] = NextCounter++;
-      Visit(D->getBody());
-    }
-    /// Assign a counter to track the block following a label.
-    void VisitLabelStmt(const LabelStmt *S) {
-      CounterMap[S] = NextCounter++;
-      Visit(S->getSubStmt());
-    }
-    /// Assign a counter for the body of a while loop.
-    void VisitWhileStmt(const WhileStmt *S) {
-      CounterMap[S] = NextCounter++;
-      Visit(S->getCond());
-      Visit(S->getBody());
-    }
-    /// Assign a counter for the body of a do-while loop.
-    void VisitDoStmt(const DoStmt *S) {
-      CounterMap[S] = NextCounter++;
-      Visit(S->getBody());
-      Visit(S->getCond());
-    }
-    /// Assign a counter for the body of a for loop.
-    void VisitForStmt(const ForStmt *S) {
-      CounterMap[S] = NextCounter++;
-      if (S->getInit())
-        Visit(S->getInit());
-      const Expr *E;
-      if ((E = S->getCond()))
-        Visit(E);
-      if ((E = S->getInc()))
-        Visit(E);
-      Visit(S->getBody());
-    }
-    /// Assign a counter for the body of a for-range loop.
-    void VisitCXXForRangeStmt(const CXXForRangeStmt *S) {
-      CounterMap[S] = NextCounter++;
-      Visit(S->getRangeStmt());
-      Visit(S->getBeginEndStmt());
-      Visit(S->getCond());
-      Visit(S->getLoopVarStmt());
-      Visit(S->getBody());
-      Visit(S->getInc());
-    }
-    /// Assign a counter for the body of a for-collection loop.
-    void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
-      CounterMap[S] = NextCounter++;
-      Visit(S->getElement());
-      Visit(S->getBody());
-    }
-    /// Assign a counter for the exit block of the switch statement.
-    void VisitSwitchStmt(const SwitchStmt *S) {
-      CounterMap[S] = NextCounter++;
-      Visit(S->getCond());
-      Visit(S->getBody());
-    }
-    /// Assign a counter for a particular case in a switch. This counts jumps
-    /// from the switch header as well as fallthrough from the case before this
-    /// one.
-    void VisitCaseStmt(const CaseStmt *S) {
-      CounterMap[S] = NextCounter++;
-      Visit(S->getSubStmt());
-    }
-    /// Assign a counter for the default case of a switch statement. The count
-    /// is the number of branches from the loop header to the default, and does
-    /// not include fallthrough from previous cases. If we have multiple
-    /// conditional branch blocks from the switch instruction to the default
-    /// block, as with large GNU case ranges, this is the counter for the last
-    /// edge in that series, rather than the first.
-    void VisitDefaultStmt(const DefaultStmt *S) {
-      CounterMap[S] = NextCounter++;
-      Visit(S->getSubStmt());
-    }
-    /// Assign a counter for the "then" part of an if statement. The count for
-    /// the "else" part, if it exists, will be calculated from this counter.
-    void VisitIfStmt(const IfStmt *S) {
-      CounterMap[S] = NextCounter++;
-      Visit(S->getCond());
-      Visit(S->getThen());
-      if (S->getElse())
-        Visit(S->getElse());
-    }
-    /// Assign a counter for the continuation block of a C++ try statement.
-    void VisitCXXTryStmt(const CXXTryStmt *S) {
-      CounterMap[S] = NextCounter++;
-      Visit(S->getTryBlock());
-      for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I)
-        Visit(S->getHandler(I));
-    }
-    /// Assign a counter for a catch statement's handler block.
-    void VisitCXXCatchStmt(const CXXCatchStmt *S) {
-      CounterMap[S] = NextCounter++;
-      Visit(S->getHandlerBlock());
-    }
-    /// Assign a counter for the "true" part of a conditional operator. The
-    /// count in the "false" part will be calculated from this counter.
-    void VisitAbstractConditionalOperator(
-        const AbstractConditionalOperator *E) {
-      CounterMap[E] = NextCounter++;
-      Visit(E->getCond());
-      Visit(E->getTrueExpr());
-      Visit(E->getFalseExpr());
-    }
-    /// Assign a counter for the right hand side of a logical and operator.
-    void VisitBinLAnd(const BinaryOperator *E) {
-      CounterMap[E] = NextCounter++;
-      Visit(E->getLHS());
-      Visit(E->getRHS());
-    }
-    /// Assign a counter for the right hand side of a logical or operator.
-    void VisitBinLOr(const BinaryOperator *E) {
-      CounterMap[E] = NextCounter++;
-      Visit(E->getLHS());
-      Visit(E->getRHS());
+    bool VisitStmt(const Stmt *S) {
+      switch (S->getStmtClass()) {
+      default:
+        break;
+      case Stmt::LabelStmtClass:
+      case Stmt::WhileStmtClass:
+      case Stmt::DoStmtClass:
+      case Stmt::ForStmtClass:
+      case Stmt::CXXForRangeStmtClass:
+      case Stmt::ObjCForCollectionStmtClass:
+      case Stmt::SwitchStmtClass:
+      case Stmt::CaseStmtClass:
+      case Stmt::DefaultStmtClass:
+      case Stmt::IfStmtClass:
+      case Stmt::CXXTryStmtClass:
+      case Stmt::CXXCatchStmtClass:
+      case Stmt::ConditionalOperatorClass:
+      case Stmt::BinaryConditionalOperatorClass:
+        CounterMap[S] = NextCounter++;
+        break;
+      case Stmt::BinaryOperatorClass: {
+        const BinaryOperator *BO = cast<BinaryOperator>(S);
+        if (BO->getOpcode() == BO_LAnd || BO->getOpcode() == BO_LOr)
+          CounterMap[S] = NextCounter++;
+        break;
+      }
+      }
+      return true;
     }
   };
 
@@ -504,6 +424,7 @@ namespace {
     }
 
     void VisitFunctionDecl(const FunctionDecl *D) {
+      // Counter tracks entry to the function body.
       RegionCounter Cnt(PGO, D->getBody());
       Cnt.beginRegion();
       CountMap[D->getBody()] = PGO.getCurrentRegionCount();
@@ -511,6 +432,7 @@ namespace {
     }
 
     void VisitObjCMethodDecl(const ObjCMethodDecl *D) {
+      // Counter tracks entry to the method body.
       RegionCounter Cnt(PGO, D->getBody());
       Cnt.beginRegion();
       CountMap[D->getBody()] = PGO.getCurrentRegionCount();
@@ -518,6 +440,7 @@ namespace {
     }
 
     void VisitBlockDecl(const BlockDecl *D) {
+      // Counter tracks entry to the block body.
       RegionCounter Cnt(PGO, D->getBody());
       Cnt.beginRegion();
       CountMap[D->getBody()] = PGO.getCurrentRegionCount();
@@ -540,6 +463,7 @@ namespace {
 
     void VisitLabelStmt(const LabelStmt *S) {
       RecordNextStmtCount = false;
+      // Counter tracks the block following the label.
       RegionCounter Cnt(PGO, S);
       Cnt.beginRegion();
       CountMap[S] = PGO.getCurrentRegionCount();
@@ -564,6 +488,7 @@ namespace {
 
     void VisitWhileStmt(const WhileStmt *S) {
       RecordStmtCount(S);
+      // Counter tracks the body of the loop.
       RegionCounter Cnt(PGO, S);
       BreakContinueStack.push_back(BreakContinue());
       // Visit the body region first so the break/continue adjustments can be
@@ -589,6 +514,7 @@ namespace {
 
     void VisitDoStmt(const DoStmt *S) {
       RecordStmtCount(S);
+      // Counter tracks the body of the loop.
       RegionCounter Cnt(PGO, S);
       BreakContinueStack.push_back(BreakContinue());
       Cnt.beginRegion(/*AddIncomingFallThrough=*/true);
@@ -615,6 +541,7 @@ namespace {
       RecordStmtCount(S);
       if (S->getInit())
         Visit(S->getInit());
+      // Counter tracks the body of the loop.
       RegionCounter Cnt(PGO, S);
       BreakContinueStack.push_back(BreakContinue());
       // Visit the body region first. (This is basically the same as a while
@@ -653,6 +580,7 @@ namespace {
       RecordStmtCount(S);
       Visit(S->getRangeStmt());
       Visit(S->getBeginEndStmt());
+      // Counter tracks the body of the loop.
       RegionCounter Cnt(PGO, S);
       BreakContinueStack.push_back(BreakContinue());
       // Visit the body region first. (This is basically the same as a while
@@ -687,6 +615,7 @@ namespace {
     void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
       RecordStmtCount(S);
       Visit(S->getElement());
+      // Counter tracks the body of the loop.
       RegionCounter Cnt(PGO, S);
       BreakContinueStack.push_back(BreakContinue());
       Cnt.beginRegion();
@@ -708,6 +637,7 @@ namespace {
       BreakContinue BC = BreakContinueStack.pop_back_val();
       if (!BreakContinueStack.empty())
         BreakContinueStack.back().ContinueCount += BC.ContinueCount;
+      // Counter tracks the exit block of the switch.
       RegionCounter ExitCnt(PGO, S);
       ExitCnt.beginRegion();
       RecordNextStmtCount = true;
@@ -715,6 +645,9 @@ namespace {
 
     void VisitCaseStmt(const CaseStmt *S) {
       RecordNextStmtCount = false;
+      // Counter for this particular case. This counts only jumps from the
+      // switch header and does not include fallthrough from the case before
+      // this one.
       RegionCounter Cnt(PGO, S);
       Cnt.beginRegion(/*AddIncomingFallThrough=*/true);
       CountMap[S] = Cnt.getCount();
@@ -724,6 +657,8 @@ namespace {
 
     void VisitDefaultStmt(const DefaultStmt *S) {
       RecordNextStmtCount = false;
+      // Counter for this default case. This does not include fallthrough from
+      // the previous case.
       RegionCounter Cnt(PGO, S);
       Cnt.beginRegion(/*AddIncomingFallThrough=*/true);
       CountMap[S] = Cnt.getCount();
@@ -733,6 +668,8 @@ namespace {
 
     void VisitIfStmt(const IfStmt *S) {
       RecordStmtCount(S);
+      // Counter tracks the "then" part of an if statement. The count for
+      // the "else" part, if it exists, will be calculated from this counter.
       RegionCounter Cnt(PGO, S);
       Visit(S->getCond());
 
@@ -756,6 +693,7 @@ namespace {
       Visit(S->getTryBlock());
       for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I)
         Visit(S->getHandler(I));
+      // Counter tracks the continuation block of the try statement.
       RegionCounter Cnt(PGO, S);
       Cnt.beginRegion();
       RecordNextStmtCount = true;
@@ -763,6 +701,7 @@ namespace {
 
     void VisitCXXCatchStmt(const CXXCatchStmt *S) {
       RecordNextStmtCount = false;
+      // Counter tracks the catch statement's handler block.
       RegionCounter Cnt(PGO, S);
       Cnt.beginRegion();
       CountMap[S] = PGO.getCurrentRegionCount();
@@ -772,6 +711,8 @@ namespace {
     void VisitAbstractConditionalOperator(
         const AbstractConditionalOperator *E) {
       RecordStmtCount(E);
+      // Counter tracks the "true" part of a conditional operator. The
+      // count in the "false" part will be calculated from this counter.
       RegionCounter Cnt(PGO, E);
       Visit(E->getCond());
 
@@ -791,6 +732,7 @@ namespace {
 
     void VisitBinLAnd(const BinaryOperator *E) {
       RecordStmtCount(E);
+      // Counter tracks the right hand side of a logical and operator.
       RegionCounter Cnt(PGO, E);
       Visit(E->getLHS());
       Cnt.beginRegion();
@@ -803,6 +745,7 @@ namespace {
 
     void VisitBinLOr(const BinaryOperator *E) {
       RecordStmtCount(E);
+      // Counter tracks the right hand side of a logical or operator.
       RegionCounter Cnt(PGO, E);
       Visit(E->getLHS());
       Cnt.beginRegion();
@@ -884,11 +827,11 @@ void CodeGenPGO::mapRegionCounters(const
   RegionCounterMap.reset(new llvm::DenseMap<const Stmt *, unsigned>);
   MapRegionCounters Walker(*RegionCounterMap);
   if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
-    Walker.VisitFunctionDecl(FD);
+    Walker.TraverseDecl(const_cast<FunctionDecl *>(FD));
   else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
-    Walker.VisitObjCMethodDecl(MD);
+    Walker.TraverseDecl(const_cast<ObjCMethodDecl *>(MD));
   else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
-    Walker.VisitBlockDecl(BD);
+    Walker.TraverseDecl(const_cast<BlockDecl *>(BD));
   NumRegionCounters = Walker.NextCounter;
   // FIXME: The number of counters isn't sufficient for the hash
   FunctionHash = NumRegionCounters;





More information about the cfe-commits mailing list