[clang-tools-extra] [clangd] Extract Function: add hoisting support (PR #75533)

Julian Schmidt via cfe-commits cfe-commits at lists.llvm.org
Sun Mar 24 06:10:55 PDT 2024


https://github.com/5chmidti updated https://github.com/llvm/llvm-project/pull/75533

>From c1130028fcbb3cd26dd1df537ca0fa449f44bfe1 Mon Sep 17 00:00:00 2001
From: Julian Schmidt <44101708+5chmidti at users.noreply.github.com>
Date: Sat, 21 Jan 2023 14:49:58 +0100
Subject: [PATCH 1/2] [clangd] Extract Function: add hoisting support

Adds support for hoisting variables that are declared inside the selected
region and used after it: they are returned from the extracted function so
that the code following the call can keep using them.
Uses the explicit variable type if only one decl needs hoisting;
otherwise returns a std::pair or std::tuple using auto return type
deduction (requires C++14) and unpacks it at the call site with a
structured binding (requires C++17) or, before C++17, with std::get<>.
---
 .../refactor/tweaks/ExtractFunction.cpp       | 159 +++++--
 .../unittests/tweaks/ExtractFunctionTests.cpp | 393 +++++++++++++++++-
 clang-tools-extra/docs/ReleaseNotes.rst       |   3 +
 3 files changed, 523 insertions(+), 32 deletions(-)

diff --git a/clang-tools-extra/clangd/refactor/tweaks/ExtractFunction.cpp b/clang-tools-extra/clangd/refactor/tweaks/ExtractFunction.cpp
index 0302839c58252e..02d1a6d0996a53 100644
--- a/clang-tools-extra/clangd/refactor/tweaks/ExtractFunction.cpp
+++ b/clang-tools-extra/clangd/refactor/tweaks/ExtractFunction.cpp
@@ -79,6 +79,13 @@ namespace {
 
 using Node = SelectionTree::Node;
 
+struct HoistSetComparator { // Orders decls by their getLocation().
+  bool operator()(const Decl *const Lhs, const Decl *const Rhs) const {
+    return Lhs->getLocation() < Rhs->getLocation();
+  }
+};
+using HoistSet = llvm::SmallSet<const NamedDecl *, 1, HoistSetComparator>;
+
 // ExtractionZone is the part of code that is being extracted.
 // EnclosingFunction is the function/method inside which the zone lies.
 // We split the file into 4 parts relative to extraction zone.
@@ -171,12 +178,13 @@ struct ExtractionZone {
   // semicolon after the extraction.
   const Node *getLastRootStmt() const { return Parent->Children.back(); }
 
-  // Checks if declarations inside extraction zone are accessed afterwards.
+  // Checks if declarations inside extraction zone are accessed afterwards and
+  // adds these declarations to the returned set.
   //
   // This performs a partial AST traversal proportional to the size of the
   // enclosing function, so it is possibly expensive.
-  bool requiresHoisting(const SourceManager &SM,
-                        const HeuristicResolver *Resolver) const {
+  HoistSet getDeclsToHoist(const SourceManager &SM,
+                           const HeuristicResolver *Resolver) const {
     // First find all the declarations that happened inside extraction zone.
     llvm::SmallSet<const Decl *, 1> DeclsInExtZone;
     for (auto *RootStmt : RootStmts) {
@@ -191,29 +199,28 @@ struct ExtractionZone {
     }
     // Early exit without performing expensive traversal below.
     if (DeclsInExtZone.empty())
-      return false;
-    // Then make sure they are not used outside the zone.
+      return {};
+    // Add any decl used after the selection to the returned set
+    HoistSet DeclsToHoist{};
     for (const auto *S : EnclosingFunction->getBody()->children()) {
       if (SM.isBeforeInTranslationUnit(S->getSourceRange().getEnd(),
                                        ZoneRange.getEnd()))
         continue;
-      bool HasPostUse = false;
       findExplicitReferences(
           S,
           [&](const ReferenceLoc &Loc) {
-            if (HasPostUse ||
-                SM.isBeforeInTranslationUnit(Loc.NameLoc, ZoneRange.getEnd()))
+            if (SM.isBeforeInTranslationUnit(Loc.NameLoc, ZoneRange.getEnd()))
               return;
-            HasPostUse = llvm::any_of(Loc.Targets,
-                                      [&DeclsInExtZone](const Decl *Target) {
-                                        return DeclsInExtZone.contains(Target);
-                                      });
+            for (const NamedDecl *const PostUse : llvm::make_filter_range(
+                     Loc.Targets, [&DeclsInExtZone](const Decl *Target) {
+                       return DeclsInExtZone.contains(Target);
+                     })) {
+              DeclsToHoist.insert(PostUse);
+            }
           },
           Resolver);
-      if (HasPostUse)
-        return true;
     }
-    return false;
+    return DeclsToHoist;
   }
 };
 
@@ -367,14 +374,17 @@ struct NewFunction {
   bool Static = false;
   ConstexprSpecKind Constexpr = ConstexprSpecKind::Unspecified;
   bool Const = false;
+  const HoistSet &ToHoist;
 
   // Decides whether the extracted function body and the function call need a
   // semicolon after extraction.
   tooling::ExtractionSemicolonPolicy SemicolonPolicy;
   const LangOptions *LangOpts;
-  NewFunction(tooling::ExtractionSemicolonPolicy SemicolonPolicy,
+  NewFunction(const HoistSet &ToHoist,
+              tooling::ExtractionSemicolonPolicy SemicolonPolicy,
               const LangOptions *LangOpts)
-      : SemicolonPolicy(SemicolonPolicy), LangOpts(LangOpts) {}
+      : ToHoist(ToHoist), SemicolonPolicy(SemicolonPolicy), LangOpts(LangOpts) {
+  }
   // Render the call for this function.
   std::string renderCall() const;
   // Render the definition for this function.
@@ -390,6 +400,7 @@ struct NewFunction {
   std::string renderSpecifiers(FunctionDeclKind K) const;
   std::string renderQualifiers() const;
   std::string renderDeclarationName(FunctionDeclKind K) const;
+  std::string renderHoistedCall() const;
   // Generate the function body.
   std::string getFuncBody(const SourceManager &SM) const;
 };
@@ -462,7 +473,55 @@ std::string NewFunction::renderDeclarationName(FunctionDeclKind K) const {
   return llvm::formatv("{0}{1}", QualifierName, Name);
 }
 
+// Renders the names of ToHoist's VarDecls/BindingDecls, comma-separated.
+std::string renderHoistSet(const HoistSet &ToHoist) {
+  std::string Res{};
+  bool NeedsComma = false;
+
+  for (const NamedDecl *DeclToHoist : ToHoist) {
+    if (llvm::isa<VarDecl>(DeclToHoist) ||
+        llvm::isa<BindingDecl>(DeclToHoist)) {
+      if (NeedsComma) {
+        Res += ", ";
+      }
+      Res += DeclToHoist->getNameAsString();
+      NeedsComma = true;
+    }
+  }
+  return Res;
+}
+
+std::string NewFunction::renderHoistedCall() const {
+  auto HoistedVarDecls = std::string{};
+  auto ExplicitUnpacking = std::string{};
+  const auto HasStructuredBinding = LangOpts->CPlusPlus17;
+  // >1 decl: unpack via a structured binding (C++17) or std::get<>.
+  if (ToHoist.size() > 1) {
+    if (HasStructuredBinding) {
+      HoistedVarDecls = "auto [" + renderHoistSet(ToHoist) + "] = ";
+    } else {
+      HoistedVarDecls = "auto returned = ";
+      size_t Index = 0U;
+      for (const NamedDecl *DeclToHoist : ToHoist)
+        if (llvm::isa<VarDecl, BindingDecl>(DeclToHoist)) // as renderHoistSet
+          ExplicitUnpacking +=
+              llvm::formatv("\nauto {0} = std::get<{1}>(returned);",
+                            DeclToHoist->getNameAsString(), Index++);
+    }
+  } else {
+    HoistedVarDecls = "auto " + renderHoistSet(ToHoist) + " = ";
+  }
+
+  return llvm::formatv(
+      "{0}{1}({2}){3}{4}", HoistedVarDecls, Name, renderParametersForCall(),
+      (SemicolonPolicy.isNeededInOriginalFunction() ? ";" : ""),
+      ExplicitUnpacking);
+}
+
 std::string NewFunction::renderCall() const {
+  if (!ToHoist.empty())
+    return renderHoistedCall();
+
   return std::string(
       llvm::formatv("{0}{1}({2}){3}", CallerReturnsValue ? "return " : "", Name,
                     renderParametersForCall(),
@@ -495,8 +554,22 @@ std::string NewFunction::getFuncBody(const SourceManager &SM) const {
   // - hoist decls
   // - add return statement
   // - Add semicolon
-  return toSourceCode(SM, BodyRange).str() +
-         (SemicolonPolicy.isNeededInExtractedFunction() ? ";" : "");
+  auto Body = toSourceCode(SM, BodyRange).str() +
+              (SemicolonPolicy.isNeededInExtractedFunction() ? ";" : "");
+
+  if (ToHoist.empty())
+    return Body;
+  // Return the hoisted decls: the single value itself, or a pair/tuple.
+  if (const bool NeedsTupleOrPair = ToHoist.size() > 1; NeedsTupleOrPair) {
+    const auto NeedsPair = ToHoist.size() == 2;
+    // FIXME: pair{}/tuple{} use C++17 CTAD; C++14 needs make_pair/make_tuple.
+    Body += "\nreturn " +
+            std::string(NeedsPair ? "std::pair{" : "std::tuple{") +
+            renderHoistSet(ToHoist) + "};";
+  } else {
+    Body += "\nreturn " + renderHoistSet(ToHoist) + ";";
+  }
+  return Body;
 }
 
 std::string NewFunction::Parameter::render(const DeclContext *Context) const {
@@ -674,10 +747,6 @@ bool createParameters(NewFunction &ExtractedFunc,
     const auto &DeclInfo = KeyVal.second;
     // If a Decl was Declared in zone and referenced in post zone, it
     // needs to be hoisted (we bail out in that case).
-    // FIXME: Support Decl Hoisting.
-    if (DeclInfo.DeclaredIn == ZoneRelative::Inside &&
-        DeclInfo.IsReferencedInPostZone)
-      return false;
     if (!DeclInfo.IsReferencedInZone)
       continue; // no need to pass as parameter, not referenced
     if (DeclInfo.DeclaredIn == ZoneRelative::Inside ||
@@ -723,6 +792,19 @@ getSemicolonPolicy(ExtractionZone &ExtZone, const SourceManager &SM,
   return SemicolonPolicy;
 }
 
+QualType getReturnTypeForHoisted(const FunctionDecl &EnclosingFunc,
+                                 const HoistSet &ToHoist) {
+  // Hoisting a single variable: return its declared type instead of auto.
+  if (ToHoist.size() == 1) {
+    if (const auto *const VDecl = llvm::dyn_cast<VarDecl>(*ToHoist.begin());
+        VDecl != nullptr) {
+      return VDecl->getType();
+    }
+  }
+
+  return EnclosingFunc.getParentASTContext().getAutoDeductType();
+}
+
 // Generate return type for ExtractedFunc. Return false if unable to do so.
 bool generateReturnProperties(NewFunction &ExtractedFunc,
                               const FunctionDecl &EnclosingFunc,
@@ -744,7 +826,11 @@ bool generateReturnProperties(NewFunction &ExtractedFunc,
     return true;
   }
   // FIXME: Generate new return statement if needed.
-  ExtractedFunc.ReturnType = EnclosingFunc.getParentASTContext().VoidTy;
+  ExtractedFunc.ReturnType =
+      ExtractedFunc.ToHoist.empty()
+          ? EnclosingFunc.getParentASTContext().VoidTy
+          : getReturnTypeForHoisted(EnclosingFunc, ExtractedFunc.ToHoist);
+
   return true;
 }
 
@@ -758,6 +844,7 @@ void captureMethodInfo(NewFunction &ExtractedFunc,
 // FIXME: add support for adding other function return types besides void.
 // FIXME: assign the value returned by non void extracted function.
 llvm::Expected<NewFunction> getExtractedFunction(ExtractionZone &ExtZone,
+                                                 const HoistSet &ToHoist,
                                                  const SourceManager &SM,
                                                  const LangOptions &LangOpts) {
   CapturedZoneInfo CapturedInfo = captureZoneInfo(ExtZone);
@@ -765,7 +852,7 @@ llvm::Expected<NewFunction> getExtractedFunction(ExtractionZone &ExtZone,
   if (CapturedInfo.BrokenControlFlow)
     return error("Cannot extract break/continue without corresponding "
                  "loop/switch statement.");
-  NewFunction ExtractedFunc(getSemicolonPolicy(ExtZone, SM, LangOpts),
+  NewFunction ExtractedFunc(ToHoist, getSemicolonPolicy(ExtZone, SM, LangOpts),
                             &LangOpts);
 
   ExtractedFunc.SyntacticDC =
@@ -814,6 +901,7 @@ class ExtractFunction : public Tweak {
 
 private:
   ExtractionZone ExtZone;
+  HoistSet ToHoist;
 };
 
 REGISTER_TWEAK(ExtractFunction)
@@ -879,8 +967,19 @@ bool ExtractFunction::prepare(const Selection &Inputs) {
       (hasReturnStmt(*MaybeExtZone) && !alwaysReturns(*MaybeExtZone)))
     return false;
 
-  // FIXME: Get rid of this check once we support hoisting.
-  if (MaybeExtZone->requiresHoisting(SM, Inputs.AST->getHeuristicResolver()))
+  ToHoist =
+      MaybeExtZone->getDeclsToHoist(SM, Inputs.AST->getHeuristicResolver());
+
+  // Only variables and structured bindings can be hoisted (renderHoistSet
+  // skips all other decls), so bail out if e.g. a type is used afterwards.
+  if (llvm::any_of(ToHoist, [](const NamedDecl *NDecl) {
+        return !llvm::isa<VarDecl, BindingDecl>(NDecl);
+      }))
+    return false;
+
+  const auto HasAutoReturnTypeDeduction = LangOpts.CPlusPlus14;
+  const auto RequiresPairOrTuple = ToHoist.size() > 1;
+  if (RequiresPairOrTuple && !HasAutoReturnTypeDeduction)
     return false;
 
   ExtZone = std::move(*MaybeExtZone);
@@ -890,7 +989,7 @@ bool ExtractFunction::prepare(const Selection &Inputs) {
 Expected<Tweak::Effect> ExtractFunction::apply(const Selection &Inputs) {
   const SourceManager &SM = Inputs.AST->getSourceManager();
   const LangOptions &LangOpts = Inputs.AST->getLangOpts();
-  auto ExtractedFunc = getExtractedFunction(ExtZone, SM, LangOpts);
+  auto ExtractedFunc = getExtractedFunction(ExtZone, ToHoist, SM, LangOpts);
   // FIXME: Add more types of errors.
   if (!ExtractedFunc)
     return ExtractedFunc.takeError();
@@ -913,8 +1012,8 @@ Expected<Tweak::Effect> ExtractFunction::apply(const Selection &Inputs) {
 
       tooling::Replacements OtherEdit(
           createForwardDeclaration(*ExtractedFunc, SM));
-      if (auto PathAndEdit = Tweak::Effect::fileEdit(SM, SM.getFileID(*FwdLoc),
-                                                 OtherEdit))
+      if (auto PathAndEdit =
+              Tweak::Effect::fileEdit(SM, SM.getFileID(*FwdLoc), OtherEdit))
         MultiFileEffect->ApplyEdits.try_emplace(PathAndEdit->first,
                                                 PathAndEdit->second);
       else
diff --git a/clang-tools-extra/clangd/unittests/tweaks/ExtractFunctionTests.cpp b/clang-tools-extra/clangd/unittests/tweaks/ExtractFunctionTests.cpp
index dec63d454d52c6..58b6f0d5c86957 100644
--- a/clang-tools-extra/clangd/unittests/tweaks/ExtractFunctionTests.cpp
+++ b/clang-tools-extra/clangd/unittests/tweaks/ExtractFunctionTests.cpp
@@ -30,8 +30,8 @@ TEST_F(ExtractFunctionTest, FunctionTest) {
   EXPECT_EQ(apply("auto lam = [](){ [[int x;]] }; "), "unavailable");
   // Partial statements aren't extracted.
   EXPECT_THAT(apply("int [[x = 0]];"), "unavailable");
-  // FIXME: Support hoisting.
-  EXPECT_THAT(apply(" [[int a = 5;]] a++; "), "unavailable");
+  // Extract regions that require hoisting
+  EXPECT_THAT(apply(" [[int a = 5;]] a++; "), HasSubstr("extracted"));
 
   // Ensure that end of Zone and Beginning of PostZone being adjacent doesn't
   // lead to break being included in the extraction zone.
@@ -192,6 +192,395 @@ F (extracted();)
   EXPECT_EQ(apply(CompoundFailInput), "unavailable");
 }
 
+TEST_F(ExtractFunctionTest, Hoisting) {
+  ExtraArgs.emplace_back("-std=c++17"); // Structured bindings and CTAD.
+  std::string HoistingInput = R"cpp(
+    int foo() {
+      int a = 3;
+      [[int x = 39 + a;
+      ++x;
+      int y = x * 2;
+      int z = 4;]]
+      return x + y + z;
+    }
+  )cpp";
+  std::string HoistingOutput = R"cpp(
+    auto extracted(int &a) {
+int x = 39 + a;
+      ++x;
+      int y = x * 2;
+      int z = 4;
+return std::tuple{x, y, z};
+}
+int foo() {
+      int a = 3;
+      auto [x, y, z] = extracted(a);
+      return x + y + z;
+    }
+  )cpp";
+  EXPECT_EQ(apply(HoistingInput), HoistingOutput);
+
+  std::string HoistingInput2 = R"cpp(
+    int foo() {
+      int a{};
+      [[int b = a + 1;]]
+      return b;
+    }
+  )cpp";
+  std::string HoistingOutput2 = R"cpp(
+    int extracted(int &a) {
+int b = a + 1;
+return b;
+}
+int foo() {
+      int a{};
+      auto b = extracted(a);
+      return b;
+    }
+  )cpp";
+  EXPECT_EQ(apply(HoistingInput2), HoistingOutput2);
+
+  std::string HoistingInput3 = R"cpp(
+    int foo(int b) {
+      int a{};
+      if (b == 42) {
+        [[a = 123;
+        return a + b;]]
+      }
+      a = 456;
+      return a;
+    }
+  )cpp";
+  std::string HoistingOutput3 = R"cpp(
+    int extracted(int &b, int &a) {
+a = 123;
+        return a + b;
+}
+int foo(int b) {
+      int a{};
+      if (b == 42) {
+        return extracted(b, a);
+      }
+      a = 456;
+      return a;
+    }
+  )cpp";
+  EXPECT_EQ(apply(HoistingInput3), HoistingOutput3);
+
+  std::string HoistingInput3B = R"cpp(
+    int foo(int b) {
+      [[int a{};
+      if (b == 42) {
+        a = 123;
+        return a + b;
+      }
+      a = 456;
+      return a;]]
+    }
+  )cpp";
+  std::string HoistingOutput3B = R"cpp(
+    int extracted(int &b) {
+int a{};
+      if (b == 42) {
+        a = 123;
+        return a + b;
+      }
+      a = 456;
+      return a;
+}
+int foo(int b) {
+      return extracted(b);
+    }
+  )cpp";
+  EXPECT_EQ(apply(HoistingInput3B), HoistingOutput3B);
+
+  std::string HoistingInput4 = R"cpp(
+    struct A {
+      bool flag;
+      int val;
+    };
+    A bar();
+    int foo(int b) {
+      int a = 0;
+      [[auto [flag, val] = bar();
+      int c = 4;
+      val = c + a;]]
+      return a + b + c + val;
+    }
+  )cpp";
+  std::string HoistingOutput4 = R"cpp(
+    struct A {
+      bool flag;
+      int val;
+    };
+    A bar();
+    auto extracted(int &a) {
+auto [flag, val] = bar();
+      int c = 4;
+      val = c + a;
+return std::pair{val, c};
+}
+int foo(int b) {
+      int a = 0;
+      auto [val, c] = extracted(a);
+      return a + b + c + val;
+    }
+  )cpp";
+  EXPECT_EQ(apply(HoistingInput4), HoistingOutput4);
+
+  // Cannot extract a selection that contains a type declaration that is used
+  // outside of the selected range
+  EXPECT_THAT(apply(R"cpp(
+      [[using MyType = int;]]
+      MyType x = 42;
+      MyType y = x;
+    )cpp"),
+              "unavailable");
+  EXPECT_THAT(apply(R"cpp(
+      [[using MyType = int;
+      MyType x = 42;]]
+      MyType y = x;
+    )cpp"),
+              "unavailable");
+  EXPECT_THAT(apply(R"cpp(
+      [[struct Bar {
+        int X;
+      };
+      auto Y = Bar{42};]]
+      auto Z = Bar{Y};
+    )cpp"),
+              "unavailable");
+
+  // Check that selections containing type declarations can be extracted if
+  // there are no uses of the type after the selection
+  std::string FullTypeAliasInput = R"cpp(
+    void foo() {
+      [[using MyType = int;
+      MyType x = 42;
+      MyType y = x;]]
+    }
+    )cpp";
+  std::string FullTypeAliasOutput = R"cpp(
+    void extracted() {
+using MyType = int;
+      MyType x = 42;
+      MyType y = x;
+}
+void foo() {
+      extracted();
+    }
+    )cpp";
+  EXPECT_EQ(apply(FullTypeAliasInput), FullTypeAliasOutput);
+
+  std::string FullStructInput = R"cpp(
+    int foo() {
+      [[struct Bar {
+        int X;
+      };
+      auto Y = Bar{42};
+      auto Z = Bar{Y};
+      return 42;]]
+    }
+    )cpp";
+  std::string FullStructOutput = R"cpp(
+    int extracted() {
+struct Bar {
+        int X;
+      };
+      auto Y = Bar{42};
+      auto Z = Bar{Y};
+      return 42;
+}
+int foo() {
+      return extracted();
+    }
+    )cpp";
+  EXPECT_EQ(apply(FullStructInput), FullStructOutput);
+
+  std::string ReturnTypeIsAliasedInput = R"cpp(
+    int foo() {
+      [[struct Bar {
+        int X;
+      };
+      auto Y = Bar{42};
+      auto Z = Bar{Y};
+      using MyInt = int;
+      MyInt A = 42;
+      return A;]]
+    }
+    )cpp";
+  std::string ReturnTypeIsAliasedOutput = R"cpp(
+    int extracted() {
+struct Bar {
+        int X;
+      };
+      auto Y = Bar{42};
+      auto Z = Bar{Y};
+      using MyInt = int;
+      MyInt A = 42;
+      return A;
+}
+int foo() {
+      return extracted();
+    }
+    )cpp";
+  EXPECT_EQ(apply(ReturnTypeIsAliasedInput), ReturnTypeIsAliasedOutput);
+
+  EXPECT_THAT(apply(R"cpp(
+      [[struct Bar {
+        int X;
+      };
+      auto Y = Bar{42};]]
+      auto Z = Bar{Y};
+    )cpp"),
+              "unavailable");
+
+  std::string ControlStmtInput1 = R"cpp(
+    float foo(float* p, int n) {
+      [[float sum = 0.0F;
+      for (int i = 0; i < n; ++i) {
+        if (p[i] > 0.0F) {
+          sum += p[i];
+        }
+      }
+      return sum]];
+    }
+    )cpp";
+
+  std::string ControlStmtOutput1 = R"cpp(
+    float extracted(float * &p, int &n) {
+float sum = 0.0F;
+      for (int i = 0; i < n; ++i) {
+        if (p[i] > 0.0F) {
+          sum += p[i];
+        }
+      }
+      return sum;
+}
+float foo(float* p, int n) {
+      return extracted(p, n);
+    }
+    )cpp";
+
+  EXPECT_EQ(apply(ControlStmtInput1), ControlStmtOutput1);
+
+  std::string ControlStmtInput2 = R"cpp(
+    float foo(float* p, int n) {
+      float sum = 0.0F;
+      [[for (int i = 0; i < n; ++i) {
+        if (p[i] > 0.0F) {
+          sum += p[i];
+        }
+      }
+      return sum]];
+    }
+    )cpp";
+
+  std::string ControlStmtOutput2 = R"cpp(
+    float extracted(float * &p, int &n, float &sum) {
+for (int i = 0; i < n; ++i) {
+        if (p[i] > 0.0F) {
+          sum += p[i];
+        }
+      }
+      return sum;
+}
+float foo(float* p, int n) {
+      float sum = 0.0F;
+      return extracted(p, n, sum);
+    }
+    )cpp";
+
+  EXPECT_EQ(apply(ControlStmtInput2), ControlStmtOutput2);
+}
+
+TEST_F(ExtractFunctionTest, HoistingCXX11) {
+  ExtraArgs.emplace_back("-std=c++11"); // No auto return type deduction.
+  std::string HoistingInput = R"cpp(
+    int foo() {
+      int a = 3;
+      [[int x = 39 + a;
+      ++x;
+      int y = x * 2;
+      int z = 4;]]
+      return x + y + z;
+    }
+  )cpp";
+  EXPECT_THAT(apply(HoistingInput), HasSubstr("unavailable"));
+
+  std::string HoistingInput2 = R"cpp(
+    int foo() {
+      int a;
+      [[int b = a + 1;]]
+      return b;
+    }
+  )cpp";
+  std::string HoistingOutput2 = R"cpp(
+    int extracted(int &a) {
+int b = a + 1;
+return b;
+}
+int foo() {
+      int a;
+      auto b = extracted(a);
+      return b;
+    }
+  )cpp";
+  EXPECT_EQ(apply(HoistingInput2), HoistingOutput2);
+}
+
+TEST_F(ExtractFunctionTest, HoistingCXX14) {
+  ExtraArgs.emplace_back("-std=c++14"); // No structured bindings.
+  std::string HoistingInput = R"cpp(
+    int foo() {
+      int a = 3;
+      [[int x = 39 + a;
+      ++x;
+      int y = x * 2;
+      int z = 4;]]
+      return x + y + z;
+    }
+  )cpp";
+  std::string HoistingOutput = R"cpp(
+    auto extracted(int &a) {
+int x = 39 + a;
+      ++x;
+      int y = x * 2;
+      int z = 4;
+return std::tuple{x, y, z};
+}
+int foo() {
+      int a = 3;
+      auto returned = extracted(a);
+auto x = std::get<0>(returned);
+auto y = std::get<1>(returned);
+auto z = std::get<2>(returned);
+      return x + y + z;
+    }
+  )cpp";
+  EXPECT_EQ(apply(HoistingInput), HoistingOutput);
+
+  std::string HoistingInput2 = R"cpp(
+    int foo() {
+      int a;
+      [[int b = a + 1;]]
+      return b;
+    }
+  )cpp";
+  std::string HoistingOutput2 = R"cpp(
+    int extracted(int &a) {
+int b = a + 1;
+return b;
+}
+int foo() {
+      int a;
+      auto b = extracted(a);
+      return b;
+    }
+  )cpp";
+  EXPECT_EQ(apply(HoistingInput2), HoistingOutput2);
+}
+
 TEST_F(ExtractFunctionTest, DifferentHeaderSourceTest) {
   Header = R"cpp(
     class SomeClass {
diff --git a/clang-tools-extra/docs/ReleaseNotes.rst b/clang-tools-extra/docs/ReleaseNotes.rst
index a604e9276668ae..d4103cba64b020 100644
--- a/clang-tools-extra/docs/ReleaseNotes.rst
+++ b/clang-tools-extra/docs/ReleaseNotes.rst
@@ -81,6 +81,9 @@ Objective-C
 Miscellaneous
 ^^^^^^^^^^^^^
 
+- The extract function tweak now supports hoisting: variables declared inside
+  the selection and used after it are returned from the extracted function.
+
 Improvements to clang-doc
 -------------------------
 

>From 550adfcabb4852241760bc9a37f34c4f44069d08 Mon Sep 17 00:00:00 2001
From: Julian Schmidt <44101708+5chmidti at users.noreply.github.com>
Date: Fri, 12 Jan 2024 03:34:31 +0100
Subject: [PATCH 2/2] fix return type of hoisted lambda + add tests

---
 .../refactor/tweaks/ExtractFunction.cpp       |  6 ++-
 .../unittests/tweaks/ExtractFunctionTests.cpp | 54 +++++++++++++++++++
 2 files changed, 59 insertions(+), 1 deletion(-)

diff --git a/clang-tools-extra/clangd/refactor/tweaks/ExtractFunction.cpp b/clang-tools-extra/clangd/refactor/tweaks/ExtractFunction.cpp
index 02d1a6d0996a53..b19197bda4d9cf 100644
--- a/clang-tools-extra/clangd/refactor/tweaks/ExtractFunction.cpp
+++ b/clang-tools-extra/clangd/refactor/tweaks/ExtractFunction.cpp
@@ -798,7 +798,11 @@ QualType getReturnTypeForHoisted(const FunctionDecl &EnclosingFunc,
   if (ToHoist.size() == 1) {
     if (const auto *const VDecl = llvm::dyn_cast<VarDecl>(*ToHoist.begin());
         VDecl != nullptr) {
-      return VDecl->getType();
+      // Return by value; a reference may bind to a local of the new function.
+      const auto Type = VDecl->getType().getNonReferenceType();
+      if (const auto *RD = Type->getAsCXXRecordDecl(); RD && RD->isLambda())
+        return EnclosingFunc.getParentASTContext().getAutoDeductType();
+      return Type;
     }
   }
 
diff --git a/clang-tools-extra/clangd/unittests/tweaks/ExtractFunctionTests.cpp b/clang-tools-extra/clangd/unittests/tweaks/ExtractFunctionTests.cpp
index 58b6f0d5c86957..8afef4a3562b07 100644
--- a/clang-tools-extra/clangd/unittests/tweaks/ExtractFunctionTests.cpp
+++ b/clang-tools-extra/clangd/unittests/tweaks/ExtractFunctionTests.cpp
@@ -492,6 +492,60 @@ float foo(float* p, int n) {
     )cpp";
 
   EXPECT_EQ(apply(ControlStmtInput2), ControlStmtOutput2);
+
+  std::string LambdaReturnInput1 = R"cpp(
+    void foo() {
+      for (;;) {
+        [[auto l = [](){};
+        l();]]
+        l();
+      }
+    }
+  )cpp";
+
+  std::string LambdaReturnOutput1 = R"cpp(
+    auto extracted() {
+auto l = [](){};
+        l();
+return l;
+}
+void foo() {
+      for (;;) {
+        auto l = extracted();
+        l();
+      }
+    }
+  )cpp";
+
+  EXPECT_EQ(apply(LambdaReturnInput1), LambdaReturnOutput1);
+
+  std::string LambdaReturnInput2 = R"cpp(
+    void foo() {
+      for (;;) {
+        [[auto l = [](int v){};
+        int v = 42;
+        l(v);]]
+        l(v);
+      }
+    }
+  )cpp";
+
+  std::string LambdaReturnOutput2 = R"cpp(
+    auto extracted() {
+auto l = [](int v){};
+        int v = 42;
+        l(v);
+return std::pair{l, v};
+}
+void foo() {
+      for (;;) {
+        auto [l, v] = extracted();
+        l(v);
+      }
+    }
+  )cpp";
+
+  EXPECT_EQ(apply(LambdaReturnInput2), LambdaReturnOutput2);
 }
 
 TEST_F(ExtractFunctionTest, HoistingCXX11) {



More information about the cfe-commits mailing list