r302624 - Add ASTMatchRefactorer and ReplaceNodeWithTemplate to RefactoringCallbacks

Eric Liu via cfe-commits cfe-commits at lists.llvm.org
Wed May 10 00:48:45 PDT 2017


Author: ioeric
Date: Wed May 10 02:48:45 2017
New Revision: 302624

URL: http://llvm.org/viewvc/llvm-project?rev=302624&view=rev
Log:
Add ASTMatchRefactorer and ReplaceNodeWithTemplate to RefactoringCallbacks

Summary: This is the first change toward developing a clang-query-based search-and-replace tool.

Reviewers: klimek, bkramer, ioeric, sbenza, jbangert

Reviewed By: ioeric, jbangert

Subscribers: sbenza, ioeric, cfe-commits

Patch by Julian Bangert!

Differential Revision: https://reviews.llvm.org/D29621

Modified:
    cfe/trunk/include/clang/Tooling/RefactoringCallbacks.h
    cfe/trunk/lib/Tooling/RefactoringCallbacks.cpp
    cfe/trunk/unittests/Tooling/RefactoringCallbacksTest.cpp

Modified: cfe/trunk/include/clang/Tooling/RefactoringCallbacks.h
URL: http://llvm.org/viewvc/llvm-project/cfe/trunk/include/clang/Tooling/RefactoringCallbacks.h?rev=302624&r1=302623&r2=302624&view=diff
==============================================================================
--- cfe/trunk/include/clang/Tooling/RefactoringCallbacks.h (original)
+++ cfe/trunk/include/clang/Tooling/RefactoringCallbacks.h Wed May 10 02:48:45 2017
@@ -47,6 +47,32 @@ protected:
   Replacements Replace;
 };
 
+/// \brief Adaptor between \c ast_matchers::MatchFinder and \c
+/// tooling::RefactoringTool.
+/// Runs the AST matchers and adds the \c tooling::Replacements of every
+/// registered callback to a caller-provided map keyed by file path.
+class ASTMatchRefactorer {
+public:
+  ASTMatchRefactorer(std::map<std::string, Replacements> &FileToReplaces);
+
+  template <typename T>
+  void addMatcher(const T &Matcher, RefactoringCallback *Callback) {
+    MatchFinder.addMatcher(Matcher, Callback);
+    Callbacks.push_back(Callback); // Remembered so its replacements can be collected.
+  }
+
+  void addDynamicMatcher(const ast_matchers::internal::DynTypedMatcher &Matcher,
+                         RefactoringCallback *Callback);
+
+  std::unique_ptr<ASTConsumer> newASTConsumer();
+
+private:
+  friend class RefactoringASTConsumer; // The consumer reads the members below.
+  std::vector<RefactoringCallback *> Callbacks; // Every callback registered so far.
+  ast_matchers::MatchFinder MatchFinder;
+  std::map<std::string, Replacements> &FileToReplaces; // File path -> replacements.
+};
+
 /// \brief Replace the text of the statement bound to \c FromId with the text in
 /// \c ToText.
 class ReplaceStmtWithText : public RefactoringCallback {
@@ -59,6 +85,29 @@ private:
   std::string ToText;
 };
 
+/// \brief Replace the text of an AST node bound to \c FromId with the result of
+/// evaluating the template in \c ToTemplate.
+///
+/// Each "${NodeName}" in \c ToTemplate is replaced by the source text of the
+/// node bound to NodeName, and "$$" yields a literal "$". \c create() returns
+/// an error for an unterminated "${" or any other use of "$".
+class ReplaceNodeWithTemplate : public RefactoringCallback {
+public:
+  static llvm::Expected<std::unique_ptr<ReplaceNodeWithTemplate>>
+  create(StringRef FromId, StringRef ToTemplate);
+  void run(const ast_matchers::MatchFinder::MatchResult &Result) override;
+
+private:
+  struct TemplateElement {
+    enum { Literal, Identifier } Type; // Verbatim text vs. name of a bound node.
+    std::string Value; // The literal text, or the bound node's name.
+  };
+  ReplaceNodeWithTemplate(llvm::StringRef FromId,
+                          std::vector<TemplateElement> &&Template);
+  std::string FromId; // Id of the node whose text gets replaced.
+  std::vector<TemplateElement> Template; // Parsed form of ToTemplate.
+};
+
 /// \brief Replace the text of the statement bound to \c FromId with the text of
 /// the statement bound to \c ToId.
 class ReplaceStmtWithStmt : public RefactoringCallback {

Modified: cfe/trunk/lib/Tooling/RefactoringCallbacks.cpp
URL: http://llvm.org/viewvc/llvm-project/cfe/trunk/lib/Tooling/RefactoringCallbacks.cpp?rev=302624&r1=302623&r2=302624&view=diff
==============================================================================
--- cfe/trunk/lib/Tooling/RefactoringCallbacks.cpp (original)
+++ cfe/trunk/lib/Tooling/RefactoringCallbacks.cpp Wed May 10 02:48:45 2017
@@ -9,8 +9,13 @@
 //
 //
 //===----------------------------------------------------------------------===//
-#include "clang/Lex/Lexer.h"
 #include "clang/Tooling/RefactoringCallbacks.h"
+#include "clang/ASTMatchers/ASTMatchFinder.h"
+#include "clang/Basic/SourceLocation.h"
+#include "clang/Lex/Lexer.h"
+
+using llvm::StringError;
+using llvm::make_error;
 
 namespace clang {
 namespace tooling {
@@ -20,18 +25,62 @@ tooling::Replacements &RefactoringCallba
   return Replace;
 }
 
-static Replacement replaceStmtWithText(SourceManager &Sources,
-                                       const Stmt &From,
+ASTMatchRefactorer::ASTMatchRefactorer(
+    std::map<std::string, Replacements> &FileToReplaces)
+    : FileToReplaces(FileToReplaces) {} // Binds to the caller's map; no copy.
+
+void ASTMatchRefactorer::addDynamicMatcher(
+    const ast_matchers::internal::DynTypedMatcher &Matcher,
+    RefactoringCallback *Callback) {
+  MatchFinder.addDynamicMatcher(Matcher, Callback);
+  Callbacks.push_back(Callback); // Tracked so its replacements get collected.
+}
+
+class RefactoringASTConsumer : public ASTConsumer {
+public:
+  RefactoringASTConsumer(ASTMatchRefactorer &Refactoring)
+      : Refactoring(Refactoring) {}
+
+  void HandleTranslationUnit(ASTContext &Context) override {
+    // The ASTMatchRefactorer is re-used between translation units.
+    // Clear every callback's replacements so each is only emitted once.
+    for (const auto &Callback : Refactoring.Callbacks) {
+      Callback->getReplacements().clear();
+    }
+    Refactoring.MatchFinder.matchAST(Context); // Run all registered matchers.
+    for (const auto &Callback : Refactoring.Callbacks) {
+      for (const auto &Replacement : Callback->getReplacements()) {
+        llvm::Error Err = // Merge into the entry for the replacement's file.
+            Refactoring.FileToReplaces[Replacement.getFilePath()].add(
+                Replacement);
+        if (Err) { // Report the failure, skip this replacement, keep going.
+          llvm::errs() << "Skipping replacement " << Replacement.toString()
+                       << " due to this error:\n"
+                       << toString(std::move(Err)) << "\n";
+        }
+      }
+    }
+  }
+
+private:
+  ASTMatchRefactorer &Refactoring; // Source of callbacks, finder and result map.
+};
+
+std::unique_ptr<ASTConsumer> ASTMatchRefactorer::newASTConsumer() {
+  return llvm::make_unique<RefactoringASTConsumer>(*this); // Refers back to *this.
+}
+
+static Replacement replaceStmtWithText(SourceManager &Sources, const Stmt &From,
                                        StringRef Text) {
-  return tooling::Replacement(Sources, CharSourceRange::getTokenRange(
-      From.getSourceRange()), Text);
+  return tooling::Replacement(
+      Sources, CharSourceRange::getTokenRange(From.getSourceRange()), Text);
 }
-static Replacement replaceStmtWithStmt(SourceManager &Sources,
-                                       const Stmt &From,
+static Replacement replaceStmtWithStmt(SourceManager &Sources, const Stmt &From,
                                        const Stmt &To) {
-  return replaceStmtWithText(Sources, From, Lexer::getSourceText(
-      CharSourceRange::getTokenRange(To.getSourceRange()),
-      Sources, LangOptions()));
+  return replaceStmtWithText(
+      Sources, From,
+      Lexer::getSourceText(CharSourceRange::getTokenRange(To.getSourceRange()),
+                           Sources, LangOptions()));
 }
 
 ReplaceStmtWithText::ReplaceStmtWithText(StringRef FromId, StringRef ToText)
@@ -103,5 +152,90 @@ void ReplaceIfStmtWithItsBody::run(
   }
 }
 
+ReplaceNodeWithTemplate::ReplaceNodeWithTemplate(
+    llvm::StringRef FromId, std::vector<TemplateElement> &&Template)
+    : FromId(FromId), Template(std::move(Template)) {} // Move in; a bare name would copy.
+
+llvm::Expected<std::unique_ptr<ReplaceNodeWithTemplate>> // Callback or parse error.
+ReplaceNodeWithTemplate::create(StringRef FromId, StringRef ToTemplate) {
+  std::vector<TemplateElement> ParsedTemplate; // Elements in template order.
+  for (size_t Index = 0; Index < ToTemplate.size();) {
+    if (ToTemplate[Index] == '$') {
+      if (ToTemplate.substr(Index, 2) == "$$") {
+        Index += 2; // "$$" is the escape for one literal '$'.
+        ParsedTemplate.push_back(
+            TemplateElement{TemplateElement::Literal, "$"});
+      } else if (ToTemplate.substr(Index, 2) == "${") {
+        size_t EndOfIdentifier = ToTemplate.find("}", Index);
+        if (EndOfIdentifier == std::string::npos) {
+          return make_error<StringError>(
+              "Unterminated ${...} in replacement template near " +
+                  ToTemplate.substr(Index),
+              std::make_error_code(std::errc::bad_message));
+        }
+        std::string SourceNodeName = // Text between "${" and the closing "}".
+            ToTemplate.substr(Index + 2, EndOfIdentifier - Index - 2);
+        ParsedTemplate.push_back(
+            TemplateElement{TemplateElement::Identifier, SourceNodeName});
+        Index = EndOfIdentifier + 1; // Resume right after the closing '}'.
+      } else {
+        return make_error<StringError>( // '$' not followed by '$' or '{'.
+            "Invalid $ in replacement template near " +
+                ToTemplate.substr(Index),
+            std::make_error_code(std::errc::bad_message));
+      }
+    } else {
+      size_t NextIndex = ToTemplate.find('$', Index + 1); // Start of next '$' form.
+      ParsedTemplate.push_back( // Literal run from Index up to NextIndex.
+          TemplateElement{TemplateElement::Literal,
+                          ToTemplate.substr(Index, NextIndex - Index)});
+      Index = NextIndex; // If no '$' was found, the loop condition ends parsing.
+    }
+  }
+  return std::unique_ptr<ReplaceNodeWithTemplate>(
+      new ReplaceNodeWithTemplate(FromId, std::move(ParsedTemplate)));
+}
+
+void ReplaceNodeWithTemplate::run(
+    const ast_matchers::MatchFinder::MatchResult &Result) {
+  const auto &NodeMap = Result.Nodes.getMap(); // Bound id -> matched node.
+
+  std::string ToText; // Replacement text, built element by element.
+  for (const auto &Element : Template) {
+    switch (Element.Type) {
+    case TemplateElement::Literal:
+      ToText += Element.Value; // Literal text is appended verbatim.
+      break;
+    case TemplateElement::Identifier: {
+      auto NodeIter = NodeMap.find(Element.Value); // Value holds the node id.
+      if (NodeIter == NodeMap.end()) {
+        llvm::errs() << "Node " << Element.Value
+                     << " used in replacement template not bound in Matcher \n";
+        llvm::report_fatal_error("Unbound node in replacement template.");
+      }
+      CharSourceRange Source = // Token range of the bound node.
+          CharSourceRange::getTokenRange(NodeIter->second.getSourceRange());
+      ToText += Lexer::getSourceText(Source, *Result.SourceManager,
+                                     Result.Context->getLangOpts());
+      break;
+    }
+    }
+  }
+  if (NodeMap.count(FromId) == 0) { // The node to replace must be bound as well.
+    llvm::errs() << "Node to be replaced " << FromId
+                 << " not bound in query.\n";
+    llvm::report_fatal_error("FromId node not bound in MatchResult");
+  }
+  auto Replacement =
+      tooling::Replacement(*Result.SourceManager, &NodeMap.at(FromId), ToText,
+                           Result.Context->getLangOpts());
+  llvm::Error Err = Replace.add(Replacement); // A failed add is fatal below.
+  if (Err) {
+    llvm::errs() << "Query and replace failed in " << Replacement.getFilePath()
+                 << "! " << llvm::toString(std::move(Err)) << "\n";
+    llvm::report_fatal_error("Replacement failed");
+  }
+}
+
 } // end namespace tooling
 } // end namespace clang

Modified: cfe/trunk/unittests/Tooling/RefactoringCallbacksTest.cpp
URL: http://llvm.org/viewvc/llvm-project/cfe/trunk/unittests/Tooling/RefactoringCallbacksTest.cpp?rev=302624&r1=302623&r2=302624&view=diff
==============================================================================
--- cfe/trunk/unittests/Tooling/RefactoringCallbacksTest.cpp (original)
+++ cfe/trunk/unittests/Tooling/RefactoringCallbacksTest.cpp Wed May 10 02:48:45 2017
@@ -7,10 +7,10 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "clang/Tooling/RefactoringCallbacks.h"
 #include "RewriterTestContext.h"
 #include "clang/ASTMatchers/ASTMatchFinder.h"
 #include "clang/ASTMatchers/ASTMatchers.h"
+#include "clang/Tooling/RefactoringCallbacks.h"
 #include "gtest/gtest.h"
 
 namespace clang {
@@ -19,11 +19,10 @@ namespace tooling {
 using namespace ast_matchers;
 
 template <typename T>
-void expectRewritten(const std::string &Code,
-                     const std::string &Expected,
-                     const T &AMatcher,
-                     RefactoringCallback &Callback) {
-  MatchFinder Finder;
+void expectRewritten(const std::string &Code, const std::string &Expected,
+                     const T &AMatcher, RefactoringCallback &Callback) {
+  std::map<std::string, Replacements> FileToReplace;
+  ASTMatchRefactorer Finder(FileToReplace);
   Finder.addMatcher(AMatcher, &Callback);
   std::unique_ptr<tooling::FrontendActionFactory> Factory(
       tooling::newFrontendActionFactory(&Finder));
@@ -31,7 +30,7 @@ void expectRewritten(const std::string &
       << "Parsing error in \"" << Code << "\"";
   RewriterTestContext Context;
   FileID ID = Context.createInMemoryFile("input.cc", Code);
-  EXPECT_TRUE(tooling::applyAllReplacements(Callback.getReplacements(),
+  EXPECT_TRUE(tooling::applyAllReplacements(FileToReplace["input.cc"],
                                             Context.Rewrite));
   EXPECT_EQ(Expected, Context.getRewrittenText(ID));
 }
@@ -61,18 +60,18 @@ TEST(RefactoringCallbacksTest, ReplacesI
   std::string Code = "void f() { int i = 1; }";
   std::string Expected = "void f() { int i = 2; }";
   ReplaceStmtWithText Callback("id", "2");
-  expectRewritten(Code, Expected, id("id", expr(integerLiteral())),
-                  Callback);
+  expectRewritten(Code, Expected, id("id", expr(integerLiteral())), Callback);
 }
 
 TEST(RefactoringCallbacksTest, ReplacesStmtWithStmt) {
   std::string Code = "void f() { int i = false ? 1 : i * 2; }";
   std::string Expected = "void f() { int i = i * 2; }";
   ReplaceStmtWithStmt Callback("always-false", "should-be");
-  expectRewritten(Code, Expected,
-      id("always-false", conditionalOperator(
-          hasCondition(cxxBoolLiteral(equals(false))),
-          hasFalseExpression(id("should-be", expr())))),
+  expectRewritten(
+      Code, Expected,
+      id("always-false",
+         conditionalOperator(hasCondition(cxxBoolLiteral(equals(false))),
+                             hasFalseExpression(id("should-be", expr())))),
       Callback);
 }
 
@@ -80,10 +79,10 @@ TEST(RefactoringCallbacksTest, ReplacesI
   std::string Code = "bool a; void f() { if (a) f(); else a = true; }";
   std::string Expected = "bool a; void f() { f(); }";
   ReplaceIfStmtWithItsBody Callback("id", true);
-  expectRewritten(Code, Expected,
-      id("id", ifStmt(
-          hasCondition(implicitCastExpr(hasSourceExpression(
-              declRefExpr(to(varDecl(hasName("a"))))))))),
+  expectRewritten(
+      Code, Expected,
+      id("id", ifStmt(hasCondition(implicitCastExpr(hasSourceExpression(
+                   declRefExpr(to(varDecl(hasName("a"))))))))),
       Callback);
 }
 
@@ -92,9 +91,63 @@ TEST(RefactoringCallbacksTest, RemovesEn
   std::string Expected = "void f() {  }";
   ReplaceIfStmtWithItsBody Callback("id", false);
   expectRewritten(Code, Expected,
-      id("id", ifStmt(hasCondition(cxxBoolLiteral(equals(false))))),
-      Callback);
+                  id("id", ifStmt(hasCondition(cxxBoolLiteral(equals(false))))),
+                  Callback);
 }
 
+TEST(RefactoringCallbacksTest, TemplateJustText) {
+  std::string Code = "void f() { int i = 1; }";
+  std::string Expected = "void f() { FOO }";
+  auto Callback = ReplaceNodeWithTemplate::create("id", "FOO"); // No '$' at all.
+  EXPECT_FALSE(Callback.takeError());
+  expectRewritten(Code, Expected, id("id", declStmt()), **Callback);
+}
+
+TEST(RefactoringCallbacksTest, TemplateSimpleSubst) {
+  std::string Code = "void f() { int i = 1; }";
+  std::string Expected = "void f() { long x = 1; }"; // ${init} became "1".
+  auto Callback = ReplaceNodeWithTemplate::create("decl", "long x = ${init}");
+  EXPECT_FALSE(Callback.takeError());
+  expectRewritten(Code, Expected,
+                  id("decl", varDecl(hasInitializer(id("init", expr())))),
+                  **Callback);
+}
+
+TEST(RefactoringCallbacksTest, TemplateLiteral) {
+  std::string Code = "void f() { int i = 1; }";
+  std::string Expected = "void f() { string x = \"$-1\"; }"; // "$$" became "$".
+  auto Callback = ReplaceNodeWithTemplate::create("decl",
+                                                  "string x = \"$$-${init}\"");
+  EXPECT_FALSE(Callback.takeError());
+  expectRewritten(Code, Expected,
+                  id("decl", varDecl(hasInitializer(id("init", expr())))),
+                  **Callback);
+}
+
+static void ExpectStringError(const std::string &Expected,
+                              llvm::Error E) {
+  std::string Found; // Receives the logged message of a handled StringError.
+  handleAllErrors(std::move(E), [&](const llvm::StringError &SE) {
+      llvm::raw_string_ostream Stream(Found);
+      SE.log(Stream);
+    });
+  EXPECT_EQ(Expected, Found); // Stays empty if no StringError was handled.
+}
+
+TEST(RefactoringCallbacksTest, TemplateUnterminated) {
+  auto Callback = ReplaceNodeWithTemplate::create("decl", // "${init" lacks '}'.
+                                                  "string x = \"$$-${init\"");
+  ExpectStringError("Unterminated ${...} in replacement template near ${init\"",
+                    Callback.takeError());
+}
+
+TEST(RefactoringCallbacksTest, TemplateUnknownDollar) {
+  auto Callback = ReplaceNodeWithTemplate::create("decl", // '$' followed by '<'.
+                                                  "string x = \"$<");
+  ExpectStringError("Invalid $ in replacement template near $<",
+                    Callback.takeError());
+}
+
+
 } // end namespace ast_matchers
 } // end namespace clang




More information about the cfe-commits mailing list