[clang] [clang-refactor] Add Matcher Edit refactoring rule (PR #123782)

Сергеев Игнатий via cfe-commits cfe-commits at lists.llvm.org
Tue Jan 21 09:14:32 PST 2025


https://github.com/IgnatSergeev created https://github.com/llvm/llvm-project/pull/123782

Modified capabilities of the refactoring engine and the clang-refactor tool:
- Added a source location argument to the clang-refactor tool
- Added a corresponding source location requirement that requests a SourceLocation
- Added an AST matcher (hasPointWithin) that checks that a location is inside the matched node's SourceRange
- Added an AST location match requirement that requests an ast_matchers::DeclarationMatcher or StatementMatcher and a SourceLocation (via the source location requirement), wraps the given matcher with the hasPointWithin matcher for the requested location, matches the AST, and provides the last match result
- Added an edit generator requirement that requests and provides a transformer::EditGenerator
- Added an Edit Match refactoring rule that edits the given match result using the given EditGenerator

>From 6faa8301795b205729cbb6dc9bd5e1ad33683b91 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=D0=98=D0=B3=D0=BD=D0=B0=D1=82=20=D0=A1=D0=B5=D1=80=D0=B3?=
 =?UTF-8?q?=D0=B5=D0=B5=D0=B2?= <ignat.sergeev at softcom.su>
Date: Mon, 27 May 2024 19:43:11 +0000
Subject: [PATCH 1/3] Add source location requirement

Added a source location requirement to the refactoring engine
---
 .../clang/Basic/DiagnosticRefactoringKinds.td        |  2 ++
 .../Tooling/Refactoring/RefactoringActionRule.h      |  4 ++++
 .../Refactoring/RefactoringActionRuleRequirements.h  | 12 ++++++++++++
 .../Refactoring/RefactoringActionRulesInternal.h     |  5 +++++
 .../Tooling/Refactoring/RefactoringRuleContext.h     | 12 ++++++++++++
 5 files changed, 35 insertions(+)

diff --git a/clang/include/clang/Basic/DiagnosticRefactoringKinds.td b/clang/include/clang/Basic/DiagnosticRefactoringKinds.td
index e060fffc7280a7..960f44ab556629 100644
--- a/clang/include/clang/Basic/DiagnosticRefactoringKinds.td
+++ b/clang/include/clang/Basic/DiagnosticRefactoringKinds.td
@@ -28,6 +28,8 @@ def err_refactor_extract_simple_expression : Error<"the selected expression "
 def err_refactor_extract_prohibited_expression : Error<"the selected "
   "expression cannot be extracted">;
 
+def err_refactor_no_location : Error<"refactoring action can't be initiated "
+  "without a location">;
 }
 
 } // end of Refactoring diagnostics
diff --git a/clang/include/clang/Tooling/Refactoring/RefactoringActionRule.h b/clang/include/clang/Tooling/Refactoring/RefactoringActionRule.h
index c6a6c4f6093a34..374f19d6d82338 100644
--- a/clang/include/clang/Tooling/Refactoring/RefactoringActionRule.h
+++ b/clang/include/clang/Tooling/Refactoring/RefactoringActionRule.h
@@ -56,6 +56,10 @@ class RefactoringActionRule : public RefactoringActionRuleBase {
   /// to be fulfilled before refactoring can be performed.
   virtual bool hasSelectionRequirement() = 0;
 
+  /// Returns true when the rule has a source location requirement that has
+  /// to be fulfilled before refactoring can be performed.
+  virtual bool hasLocationRequirement() = 0;
+
   /// Traverses each refactoring option used by the rule and invokes the
   /// \c visit callback in the consumer for each option.
   ///
diff --git a/clang/include/clang/Tooling/Refactoring/RefactoringActionRuleRequirements.h b/clang/include/clang/Tooling/Refactoring/RefactoringActionRuleRequirements.h
index 1a318da3acca19..bdd2d98df73eb1 100644
--- a/clang/include/clang/Tooling/Refactoring/RefactoringActionRuleRequirements.h
+++ b/clang/include/clang/Tooling/Refactoring/RefactoringActionRuleRequirements.h
@@ -10,6 +10,7 @@
 #define LLVM_CLANG_TOOLING_REFACTORING_REFACTORINGACTIONRULEREQUIREMENTS_H
 
 #include "clang/Basic/LLVM.h"
+#include "clang/Basic/SourceLocation.h"
 #include "clang/Tooling/Refactoring/ASTSelection.h"
 #include "clang/Tooling/Refactoring/RefactoringDiagnostic.h"
 #include "clang/Tooling/Refactoring/RefactoringOption.h"
@@ -77,6 +78,17 @@ class CodeRangeASTSelectionRequirement : public ASTSelectionRequirement {
   evaluate(RefactoringRuleContext &Context) const;
 };
 
+/// A base class for any requirement that expects a source location (e.g. the
+/// cursor position passed to clang-refactor with the -location option).
+class SourceLocationRequirement : public RefactoringActionRuleRequirement {
+public:
+  Expected<SourceLocation> evaluate(RefactoringRuleContext &Context) const {
+    if (Context.getLocation().isValid())
+      return Context.getLocation();
+    return Context.createDiagnosticError(diag::err_refactor_no_location);
+  }
+};
+
 /// A base class for any requirement that requires some refactoring options.
 class RefactoringOptionsRequirement : public RefactoringActionRuleRequirement {
 public:
diff --git a/clang/include/clang/Tooling/Refactoring/RefactoringActionRulesInternal.h b/clang/include/clang/Tooling/Refactoring/RefactoringActionRulesInternal.h
index 33194c401ea143..52afb012f4874c 100644
--- a/clang/include/clang/Tooling/Refactoring/RefactoringActionRulesInternal.h
+++ b/clang/include/clang/Tooling/Refactoring/RefactoringActionRulesInternal.h
@@ -139,6 +139,11 @@ createRefactoringActionRule(const RequirementTypes &... Requirements) {
                                  RequirementTypes...>::value;
     }
 
+    bool hasLocationRequirement() override {
+      return internal::HasBaseOf<SourceLocationRequirement,
+                                 RequirementTypes...>::value;
+    }
+
     void visitRefactoringOptions(RefactoringOptionVisitor &Visitor) override {
       internal::visitRefactoringOptions(
           Visitor, Requirements,
diff --git a/clang/include/clang/Tooling/Refactoring/RefactoringRuleContext.h b/clang/include/clang/Tooling/Refactoring/RefactoringRuleContext.h
index 7d97f811f024e0..85bba662afcd28 100644
--- a/clang/include/clang/Tooling/Refactoring/RefactoringRuleContext.h
+++ b/clang/include/clang/Tooling/Refactoring/RefactoringRuleContext.h
@@ -30,6 +30,9 @@ namespace tooling {
 ///
 ///   - SelectionRange: an optional source selection ranges that can be used
 ///     to represent a selection in an editor.
+///
+///   - Location: an optional source location that can be used
+///     to represent a cursor in an editor.
 class RefactoringRuleContext {
 public:
   RefactoringRuleContext(const SourceManager &SM) : SM(SM) {}
@@ -40,8 +43,14 @@ class RefactoringRuleContext {
   /// refactoring engine. Can be invalid.
   SourceRange getSelectionRange() const { return SelectionRange; }
 
+  /// Returns the current source location as set by the
+  /// refactoring engine. Can be invalid.
+  SourceLocation getLocation() const { return Location; }
+
   void setSelectionRange(SourceRange R) { SelectionRange = R; }
 
+  void setLocation(SourceLocation L) { Location = L; }
+
   bool hasASTContext() const { return AST; }
 
   ASTContext &getASTContext() const {
@@ -73,6 +82,9 @@ class RefactoringRuleContext {
   /// An optional source selection range that's commonly used to represent
   /// a selection in an editor.
   SourceRange SelectionRange;
+  /// An optional source location that's commonly used to represent
+  /// a cursor in an editor.
+  SourceLocation Location;
   /// An optional AST for the translation unit on which a refactoring action
   /// might operate on.
   ASTContext *AST = nullptr;

>From 9232b48e3bf322a98d194e080abb2db9fd196172 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=D0=98=D0=B3=D0=BD=D0=B0=D1=82=20=D0=A1=D0=B5=D1=80=D0=B3?=
 =?UTF-8?q?=D0=B5=D0=B5=D0=B2?= <ignat.sergeev at softcom.su>
Date: Mon, 27 May 2024 20:26:52 +0000
Subject: [PATCH 2/3] Add source location argument

Added a source location argument to the clang-refactor CLI
---
 clang/tools/clang-refactor/ClangRefactor.cpp | 149 ++++++++++++++++++-
 1 file changed, 143 insertions(+), 6 deletions(-)

diff --git a/clang/tools/clang-refactor/ClangRefactor.cpp b/clang/tools/clang-refactor/ClangRefactor.cpp
index 968f0594085d40..937e605a3dafbb 100644
--- a/clang/tools/clang-refactor/ClangRefactor.cpp
+++ b/clang/tools/clang-refactor/ClangRefactor.cpp
@@ -164,6 +164,85 @@ SourceSelectionArgument::fromString(StringRef Value) {
   return nullptr;
 }
 
+/// Stores the parsed `-location` argument.
+class AbstractSourceLocationArgument {
+public:
+  virtual ~AbstractSourceLocationArgument() = default;
+
+  /// Parse the `-location` argument.
+  ///
+  /// \returns A valid argument when the parse succeeded, null otherwise.
+  static std::unique_ptr<AbstractSourceLocationArgument>
+  fromString(StringRef Value);
+
+  /// Prints any additional state associated with the location argument to
+  /// the given output stream.
+  virtual void print(raw_ostream &OS) {}
+
+  /// Returns a replacement refactoring result consumer (if any) that should
+  /// consume the results of a refactoring operation.
+  ///
+  /// This mirrors SourceSelectionArgument::createCustomConsumer, which lets a
+  /// test argument inject test-specific result handling. No location argument
+  /// overrides it yet, so the default returns null.
+  /// NOTE(review): ClangRefactorTool only queries the selection argument's
+  /// consumer, so this hook is currently unused for -location.
+  virtual std::unique_ptr<ClangRefactorToolConsumerInterface>
+  createCustomConsumer() {
+    return nullptr;
+  }
+
+  /// Runs the given refactoring function for each specified location.
+  ///
+  /// \returns true if an error occurred, false otherwise.
+  virtual bool
+  forAllLocations(const SourceManager &SM,
+                  llvm::function_ref<void(SourceLocation L)> Callback) = 0;
+};
+
+/// Stores the parsed -location=filename:line:column option.
+class SourceLocationArgument final : public AbstractSourceLocationArgument {
+public:
+  explicit SourceLocationArgument(ParsedSourceLocation Location)
+      : Location(std::move(Location)) {}
+
+  bool forAllLocations(
+      const SourceManager &SM,
+      llvm::function_ref<void(SourceLocation L)> Callback) override {
+    auto FE = SM.getFileManager().getFile(Location.FileName);
+    FileID FID = FE ? SM.translateFile(*FE) : FileID();
+    if (!FE || FID.isInvalid()) {
+      llvm::errs() << "error: -location=" << Location.FileName
+                   << ":... : given file is not in the target TU\n";
+      return true;
+    }
+
+    SourceLocation Loc = SM.getMacroArgExpandedLocation(
+        SM.translateLineCol(FID, Location.Line, Location.Column));
+    if (Loc.isInvalid()) {
+      llvm::errs() << "error: -location=" << Location.FileName << ':'
+                   << Location.Line << ':' << Location.Column
+                   << " : invalid source location\n";
+      return true;
+    }
+    Callback(Loc);
+    return false;
+  }
+
+private:
+  ParsedSourceLocation Location;
+};
+
+std::unique_ptr<AbstractSourceLocationArgument>
+AbstractSourceLocationArgument::fromString(StringRef Value) {
+  ParsedSourceLocation Location = ParsedSourceLocation::FromString(Value);
+  if (!Location.FileName.empty() && Location.Line != 0 && Location.Column != 0)
+    return std::make_unique<SourceLocationArgument>(std::move(Location));
+  llvm::errs() << "error: '-location' option must be specified using "
+                  "<file>:<line>:<column>\n";
+  return nullptr;
+}
+
 /// A container that stores the command-line options used by a single
 /// refactoring option.
 class RefactoringActionCommandLineOptions {
@@ -272,6 +351,17 @@ class RefactoringActionSubcommand : public cl::SubCommand {
         break;
       }
     }
+    // Create the -location option if any rule needs a source location.
+    for (const auto &Rule : this->ActionRules) {
+      if (Rule->hasLocationRequirement()) {
+        Location = std::make_unique<cl::opt<std::string>>(
+            "location",
+            cl::desc("The source location at which the refactoring should "
+                     "be initiated (<file>:<line>:<column>)"),
+            cl::cat(Category), cl::sub(*this));
+        break;
+      }
+    }
     // Create the refactoring options.
     for (const auto &Rule : this->ActionRules) {
       CommandLineRefactoringOptionCreator OptionCreator(Category, *this,
@@ -296,11 +386,28 @@ class RefactoringActionSubcommand : public cl::SubCommand {
     return false;
   }
 
+  /// Parses the "-location" command-line argument.
+  ///
+  /// \returns true on error, false otherwise.
+  bool parseLocationArgument() {
+    if (Location) {
+      ParsedLocation = AbstractSourceLocationArgument::fromString(*Location);
+      if (!ParsedLocation)
+        return true;
+    }
+    return false;
+  }
+
   SourceSelectionArgument *getSelection() const {
     assert(Selection && "selection not supported!");
     return ParsedSelection.get();
   }
 
+  AbstractSourceLocationArgument *getLocation() const {
+    assert(Location && "location not supported!");
+    return ParsedLocation.get();
+  }
+
   const RefactoringActionCommandLineOptions &getOptions() const {
     return Options;
   }
@@ -309,7 +416,9 @@ class RefactoringActionSubcommand : public cl::SubCommand {
   std::unique_ptr<RefactoringAction> Action;
   RefactoringActionRules ActionRules;
   std::unique_ptr<cl::opt<std::string>> Selection;
+  std::unique_ptr<cl::opt<std::string>> Location;
   std::unique_ptr<SourceSelectionArgument> ParsedSelection;
+  std::unique_ptr<AbstractSourceLocationArgument> ParsedLocation;
   RefactoringActionCommandLineOptions Options;
 };
 
@@ -399,6 +508,7 @@ class ClangRefactorTool {
     // consumer.
     std::unique_ptr<ClangRefactorToolConsumerInterface> TestConsumer;
     bool HasSelection = MatchingRule->hasSelectionRequirement();
+    bool HasLocation = MatchingRule->hasLocationRequirement();
     if (HasSelection)
       TestConsumer = SelectedSubcommand->getSelection()->createCustomConsumer();
     ClangRefactorToolConsumerInterface *ActiveConsumer =
@@ -424,6 +534,19 @@ class ClangRefactorTool {
       ActiveConsumer->endTU();
       return;
     }
+    if (HasLocation) {
+      assert(SelectedSubcommand->getLocation() && "Missing location argument?");
+      if (opts::Verbose)
+        SelectedSubcommand->getLocation()->print(llvm::outs());
+      if (SelectedSubcommand->getLocation()->forAllLocations(
+              Context.getSources(), [&](SourceLocation L) {
+                Context.setLocation(L);
+                InvokeRule(*ActiveConsumer);
+              }))
+        HasFailed = true;
+      ActiveConsumer->endTU();
+      return;
+    }
     InvokeRule(*ActiveConsumer);
     ActiveConsumer->endTU();
   }
@@ -528,6 +651,12 @@ class ClangRefactorTool {
       R.getEnd().print(llvm::outs(), Context.getSources());
       llvm::outs() << "\n";
     }
+    if (Context.getLocation().isValid()) {
+      SourceLocation L = Context.getLocation();
+      llvm::outs() << "  -location=";
+      L.print(llvm::outs(), Context.getSources());
+      llvm::outs() << "\n";
+    }
   }
 
   llvm::Expected<RefactoringActionRule *>
@@ -539,16 +668,24 @@ class ClangRefactorTool {
       CommandLineRefactoringOptionVisitor Visitor(Subcommand.getOptions());
       Rule->visitRefactoringOptions(Visitor);
       if (Visitor.getMissingRequiredOptions().empty()) {
-        if (!Rule->hasSelectionRequirement()) {
-          MatchingRules.push_back(Rule.get());
-        } else {
+        bool HasMissingOptions = false;
+        if (Rule->hasSelectionRequirement()) {
           Subcommand.parseSelectionArgument();
-          if (Subcommand.getSelection()) {
-            MatchingRules.push_back(Rule.get());
-          } else {
+          if (!Subcommand.getSelection()) {
             MissingOptions.insert("selection");
+            HasMissingOptions = true;
           }
         }
+        if (Rule->hasLocationRequirement()) {
+          Subcommand.parseLocationArgument();
+          if (!Subcommand.getLocation()) {
+            MissingOptions.insert("location");
+            HasMissingOptions = true;
+          }
+        }
+        if (!HasMissingOptions) {
+          MatchingRules.push_back(Rule.get());
+        }
       }
       for (const RefactoringOption *Opt : Visitor.getMissingRequiredOptions())
         MissingOptions.insert(Opt->getName());

>From 000fd27bf6ebea785b3e970096c601b70beaa6f1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=D0=98=D0=B3=D0=BD=D0=B0=D1=82=20=D0=A1=D0=B5=D1=80=D0=B3?=
 =?UTF-8?q?=D0=B5=D0=B5=D0=B2?= <ignat.sergeev at softcom.su>
Date: Thu, 22 Aug 2024 11:13:56 +0000
Subject: [PATCH 3/3] Add edit match rule

Added an AST location match requirement, an edit generator requirement, and the edit match rule
---
 .../clang/Basic/DiagnosticRefactoringKinds.td |  4 +
 .../Tooling/Refactoring/Edit/EditMatchRule.h  | 49 ++++++++++
 .../RefactoringActionRuleRequirements.h       | 90 ++++++++++++++++++-
 clang/lib/Tooling/Refactoring/CMakeLists.txt  |  2 +
 .../Refactoring/Edit/EditMatchRule.cpp        | 74 +++++++++++++++
 5 files changed, 218 insertions(+), 1 deletion(-)
 create mode 100644 clang/include/clang/Tooling/Refactoring/Edit/EditMatchRule.h
 create mode 100644 clang/lib/Tooling/Refactoring/Edit/EditMatchRule.cpp

diff --git a/clang/include/clang/Basic/DiagnosticRefactoringKinds.td b/clang/include/clang/Basic/DiagnosticRefactoringKinds.td
index 960f44ab556629..9eacb76cd9dadc 100644
--- a/clang/include/clang/Basic/DiagnosticRefactoringKinds.td
+++ b/clang/include/clang/Basic/DiagnosticRefactoringKinds.td
@@ -30,6 +30,10 @@ def err_refactor_extract_prohibited_expression : Error<"the selected "
 
 def err_refactor_no_location : Error<"refactoring action can't be initiated "
   "without a location">;
+def err_refactor_no_location_match : Error<"refactoring action can't be initiated "
+  "without a matching location">;
+def err_refactor_invalid_edit_generator : Error<"refactoring action can't be initiated "
+  "without a correct edit generator">;
 }
 
 } // end of Refactoring diagnostics
diff --git a/clang/include/clang/Tooling/Refactoring/Edit/EditMatchRule.h b/clang/include/clang/Tooling/Refactoring/Edit/EditMatchRule.h
new file mode 100644
index 00000000000000..46e5aa54c0bee5
--- /dev/null
+++ b/clang/include/clang/Tooling/Refactoring/Edit/EditMatchRule.h
@@ -0,0 +1,49 @@
+//===--- EditMatchRule.h - Clang refactoring library ----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CLANG_TOOLING_REFACTORING_EDIT_EDITMATCHRULE_H
+#define LLVM_CLANG_TOOLING_REFACTORING_EDIT_EDITMATCHRULE_H
+
+#include "clang/ASTMatchers/ASTMatchFinder.h"
+#include "clang/Tooling/Refactoring/RefactoringActionRules.h"
+#include "clang/Tooling/Transformer/RewriteRule.h"
+
+namespace clang {
+namespace tooling {
+
+/// An "Edit Match" refactoring rule applies the edits produced by an
+/// EditGenerator to a single AST match result.
+class EditMatchRule final : public SourceChangeRefactoringRule {
+public:
+  /// Initiates the edit match refactoring operation.
+  ///
+  /// \param R   The match result whose bound nodes are edited.
+  /// \param EG  The edit generator that computes the edits for \p R.
+  static Expected<EditMatchRule>
+  initiate(RefactoringRuleContext &Context,
+           ast_matchers::MatchFinder::MatchResult R,
+           transformer::EditGenerator EG);
+
+  static const RefactoringDescriptor &describe();
+
+private:
+  EditMatchRule(ast_matchers::MatchFinder::MatchResult R,
+                transformer::EditGenerator EG)
+      : Result(std::move(R)), EditGenerator(std::move(EG)) {}
+
+  Expected<AtomicChanges>
+  createSourceReplacements(RefactoringRuleContext &Context) override;
+
+  ast_matchers::MatchFinder::MatchResult Result;
+  transformer::EditGenerator EditGenerator;
+};
+
+} // end namespace tooling
+} // end namespace clang
+
+#endif // LLVM_CLANG_TOOLING_REFACTORING_EDIT_EDITMATCHRULE_H
diff --git a/clang/include/clang/Tooling/Refactoring/RefactoringActionRuleRequirements.h b/clang/include/clang/Tooling/Refactoring/RefactoringActionRuleRequirements.h
index bdd2d98df73eb1..4ad7fbbef31620 100644
--- a/clang/include/clang/Tooling/Refactoring/RefactoringActionRuleRequirements.h
+++ b/clang/include/clang/Tooling/Refactoring/RefactoringActionRuleRequirements.h
@@ -9,14 +9,19 @@
 #ifndef LLVM_CLANG_TOOLING_REFACTORING_REFACTORINGACTIONRULEREQUIREMENTS_H
 #define LLVM_CLANG_TOOLING_REFACTORING_REFACTORINGACTIONRULEREQUIREMENTS_H
 
+#include "clang/AST/ASTTypeTraits.h"
+#include "clang/ASTMatchers/ASTMatchFinder.h"
+#include "clang/ASTMatchers/ASTMatchers.h"
+#include "clang/ASTMatchers/ASTMatchersMacros.h"
 #include "clang/Basic/LLVM.h"
 #include "clang/Basic/SourceLocation.h"
 #include "clang/Tooling/Refactoring/ASTSelection.h"
 #include "clang/Tooling/Refactoring/RefactoringDiagnostic.h"
 #include "clang/Tooling/Refactoring/RefactoringOption.h"
 #include "clang/Tooling/Refactoring/RefactoringRuleContext.h"
+#include "clang/Tooling/Transformer/RewriteRule.h"
 #include "llvm/Support/Error.h"
 #include <type_traits>
 
 namespace clang {
 namespace tooling {
@@ -89,6 +94,90 @@ class SourceLocationRequirement : public RefactoringActionRuleRequirement {
   }
 };
 
+/// Matches statements and declarations whose source range contains \p L.
+AST_POLYMORPHIC_MATCHER_P(hasPointWithin,
+                          AST_POLYMORPHIC_SUPPORTED_TYPES(Stmt, Decl),
+                          FullSourceLoc, L) {
+  // isPointWithin asserts on invalid locations, which implicit nodes can have.
+  const SourceRange SR = Node.getSourceRange();
+  return L.isValid() && L.hasManager() && SR.isValid() &&
+         L.getManager().isPointWithin(L, SR.getBegin(), SR.getEnd());
+}
+
+/// An AST location match requirement is satisfied when the matcher matches a
+/// node whose source range contains the requested location. If several nodes
+/// match, the last match reported by MatchFinder is used (expected to be the
+/// innermost node, since parents are visited before their children).
+/// Each evaluate() call runs the matcher over the whole translation unit.
+template <typename MatcherType>
+class ASTLocMatchRequirement : public SourceLocationRequirement {
+public:
+  static_assert(
+      std::is_same<ast_matchers::StatementMatcher, MatcherType>::value ||
+          std::is_same<ast_matchers::DeclarationMatcher, MatcherType>::value,
+      "Expected a Statement or Declaration matcher");
+
+  class LocMatchCallback : public ast_matchers::MatchFinder::MatchCallback {
+  public:
+    void run(const clang::ast_matchers::MatchFinder::MatchResult &R) override {
+      Result = std::make_unique<ast_matchers::MatchFinder::MatchResult>(R);
+    }
+    std::unique_ptr<ast_matchers::MatchFinder::MatchResult> Result;
+  };
+
+  Expected<ast_matchers::MatchFinder::MatchResult>
+  evaluate(RefactoringRuleContext &Context) const {
+    Expected<SourceLocation> Location =
+        SourceLocationRequirement::evaluate(Context);
+    if (!Location)
+      return Location.takeError();
+    MatcherType M = createWrapperMatcher(
+        FullSourceLoc(*Location, Context.getASTContext().getSourceManager()),
+        Matcher);
+
+    ast_matchers::MatchFinder MF;
+    LocMatchCallback Callback;
+    MF.addMatcher(M, &Callback);
+    MF.matchAST(Context.getASTContext());
+    if (!Callback.Result)
+      return Context.createDiagnosticError(
+          diag::err_refactor_no_location_match);
+    return *Callback.Result;
+  }
+
+  ASTLocMatchRequirement(MatcherType M) : Matcher(std::move(M)) {}
+
+private:
+  ast_matchers::StatementMatcher
+  createWrapperMatcher(FullSourceLoc L,
+                       ast_matchers::StatementMatcher M) const {
+    return ast_matchers::stmt(M, hasPointWithin(L)).bind(transformer::RootID);
+  }
+
+  ast_matchers::DeclarationMatcher
+  createWrapperMatcher(FullSourceLoc L,
+                       ast_matchers::DeclarationMatcher M) const {
+    return ast_matchers::decl(M, hasPointWithin(L)).bind(transformer::RootID);
+  }
+
+  MatcherType Matcher;
+};
+
+/// Requirement that evaluates to the EditGenerator value given at its creation.
+class EditGeneratorRequirement : public RefactoringActionRuleRequirement {
+public:
+  Expected<transformer::EditGenerator>
+  evaluate(RefactoringRuleContext &Context) const {
+    return EditGenerator;
+  }
+
+  EditGeneratorRequirement(transformer::EditGenerator EG)
+      : EditGenerator(std::move(EG)) {}
+
+private:
+  transformer::EditGenerator EditGenerator;
+};
+
 /// A base class for any requirement that requires some refactoring options.
 class RefactoringOptionsRequirement : public RefactoringActionRuleRequirement {
 public:
diff --git a/clang/lib/Tooling/Refactoring/CMakeLists.txt b/clang/lib/Tooling/Refactoring/CMakeLists.txt
index d3077be8810aad..97023b6d6b97bb 100644
--- a/clang/lib/Tooling/Refactoring/CMakeLists.txt
+++ b/clang/lib/Tooling/Refactoring/CMakeLists.txt
@@ -13,6 +13,7 @@ add_clang_library(clangToolingRefactoring
   Rename/USRFinder.cpp
   Rename/USRFindingAction.cpp
   Rename/USRLocFinder.cpp
+  Edit/EditMatchRule.cpp
 
   LINK_LIBS
   clangAST
@@ -23,6 +24,7 @@ add_clang_library(clangToolingRefactoring
   clangLex
   clangRewrite
   clangToolingCore
+  clangTransformer
 
   DEPENDS
   omp_gen
diff --git a/clang/lib/Tooling/Refactoring/Edit/EditMatchRule.cpp b/clang/lib/Tooling/Refactoring/Edit/EditMatchRule.cpp
new file mode 100644
index 00000000000000..55c94b39237778
--- /dev/null
+++ b/clang/lib/Tooling/Refactoring/Edit/EditMatchRule.cpp
@@ -0,0 +1,74 @@
+//===--- EditMatchRule.cpp - Clang refactoring library --------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+/// Implements the "edit-match" refactoring rule that can edit matcher results
+///
+//===----------------------------------------------------------------------===//
+
+#include "clang/Tooling/Refactoring/Edit/EditMatchRule.h"
+#include "clang/AST/ASTContext.h"
+#include "clang/ASTMatchers/ASTMatchFinder.h"
+#include "clang/Basic/DiagnosticRefactoring.h"
+#include "clang/Rewrite/Core/Rewriter.h"
+#include "clang/Tooling/Refactoring/AtomicChange.h"
+#include "clang/Tooling/Transformer/RewriteRule.h"
+#include "clang/Tooling/Transformer/SourceCode.h"
+
+using namespace clang;
+using namespace tooling;
+using namespace transformer;
+
+Expected<EditMatchRule>
+EditMatchRule::initiate(RefactoringRuleContext &Context,
+                        ast_matchers::MatchFinder::MatchResult R,
+                        transformer::EditGenerator EG) {
+  return EditMatchRule(std::move(R), std::move(EG));
+}
+
+const RefactoringDescriptor &EditMatchRule::describe() {
+  static const RefactoringDescriptor Descriptor = {
+      "edit-match",
+      "Edit Match",
+      "Edits match result source code",
+  };
+  return Descriptor;
+}
+
+Expected<AtomicChanges>
+EditMatchRule::createSourceReplacements(RefactoringRuleContext &Context) {
+  const SourceManager &SM = Context.getASTContext().getSourceManager();
+  Expected<SmallVector<transformer::Edit, 1>> Edits = EditGenerator(Result);
+  if (!Edits)
+    return Edits.takeError();
+  if (Edits->empty())
+    return Context.createDiagnosticError(
+        diag::err_refactor_invalid_edit_generator);
+
+  // An AtomicChange only holds replacements for a single file, so group the
+  // edits per FileID, anchoring each change at the first edit in its file.
+  AtomicChanges Changes;
+  llvm::SmallDenseMap<FileID, size_t, 2> ChangeIndexByFile;
+  for (const transformer::Edit &E : *Edits) {
+    auto [It, Inserted] = ChangeIndexByFile.try_emplace(
+        SM.getFileID(E.Range.getBegin()), Changes.size());
+    if (Inserted)
+      Changes.emplace_back(SM, E.Range.getBegin());
+    AtomicChange &Change = Changes[It->second];
+    switch (E.Kind) {
+    case EditKind::Range:
+      if (llvm::Error Err = Change.replace(SM, E.Range, E.Replacement))
+        return std::move(Err);
+      break;
+    case EditKind::AddInclude:
+      Change.addHeader(E.Replacement);
+      break;
+    }
+  }
+  return Changes;
+}



More information about the cfe-commits mailing list