[clang] 15c49b9 - [Coroutines] [CodeGen] Don't change AST in CodeGen/Coroutines

Chuanqi Xu via cfe-commits cfe-commits at lists.llvm.org
Fri Feb 28 00:04:42 PST 2025


Author: Chuanqi Xu
Date: 2025-02-28T16:03:50+08:00
New Revision: 15c49b9db3f60bdbd320271d5e97f118c00b95dd

URL: https://github.com/llvm/llvm-project/commit/15c49b9db3f60bdbd320271d5e97f118c00b95dd
DIFF: https://github.com/llvm/llvm-project/commit/15c49b9db3f60bdbd320271d5e97f118c00b95dd.diff

LOG: [Coroutines] [CodeGen] Don't change AST in CodeGen/Coroutines

This is the root cause of other odd bugs.

We performed a hack in CodeGen/Coroutines, but we didn't recognize that
CodeGen is a consumer of the AST. CodeGen shouldn't change the AST in
any way. Doing so breaks the assumption about the ASTConsumer in Clang's
framework, which may break other clang-based tools that depend on
multiple consumers working together.

The fix here is simple, but I am not super happy about the test: it is
too specific and verbose. We can remove the test if we can get the
signature of the AST in ASTContext.

Added: 
    clang/unittests/Frontend/NoAlterCodeGenActionTest.cpp

Modified: 
    clang/lib/CodeGen/CGCoroutine.cpp
    clang/unittests/Frontend/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/clang/lib/CodeGen/CGCoroutine.cpp b/clang/lib/CodeGen/CGCoroutine.cpp
index 9abf2e8c9190d..058ec01f8ce0e 100644
--- a/clang/lib/CodeGen/CGCoroutine.cpp
+++ b/clang/lib/CodeGen/CGCoroutine.cpp
@@ -942,9 +942,16 @@ void CodeGenFunction::EmitCoroutineBody(const CoroutineBodyStmt &S) {
   if (Stmt *Ret = S.getReturnStmt()) {
     // Since we already emitted the return value above, so we shouldn't
     // emit it again here.
-    if (GroManager.DirectEmit)
+    Expr *PreviousRetValue = nullptr;
+    if (GroManager.DirectEmit) {
+      PreviousRetValue = cast<ReturnStmt>(Ret)->getRetValue();
       cast<ReturnStmt>(Ret)->setRetValue(nullptr);
+    }
     EmitStmt(Ret);
+    // Set the return value back. The code generator, as the AST **Consumer**,
+    // shouldn't change the AST.
+    if (PreviousRetValue)
+      cast<ReturnStmt>(Ret)->setRetValue(PreviousRetValue);
   }
 
   // LLVM require the frontend to mark the coroutine.

diff  --git a/clang/unittests/Frontend/CMakeLists.txt b/clang/unittests/Frontend/CMakeLists.txt
index 0f05813338f2a..3c94846243870 100644
--- a/clang/unittests/Frontend/CMakeLists.txt
+++ b/clang/unittests/Frontend/CMakeLists.txt
@@ -10,6 +10,7 @@ add_clang_unittest(FrontendTests
   FixedPointString.cpp
   FrontendActionTest.cpp
   CodeGenActionTest.cpp
+  NoAlterCodeGenActionTest.cpp
   ParsedSourceLocationTest.cpp
   PCHPreambleTest.cpp
   ReparseWorkingDirTest.cpp
@@ -27,4 +28,5 @@ clang_target_link_libraries(FrontendTests
   clangCodeGen
   clangFrontendTool
   clangSerialization
+  clangTooling
   )

diff  --git a/clang/unittests/Frontend/NoAlterCodeGenActionTest.cpp b/clang/unittests/Frontend/NoAlterCodeGenActionTest.cpp
new file mode 100644
index 0000000000000..e7a3bf5a7f87a
--- /dev/null
+++ b/clang/unittests/Frontend/NoAlterCodeGenActionTest.cpp
@@ -0,0 +1,198 @@
+//===- unittests/Frontend/NoAlterCodeGenActionTest.cpp --------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Unit tests verifying that CodeGenAction does not alter the AST.
+//
+//===----------------------------------------------------------------------===//
+
+#include "clang/AST/ASTConsumer.h"
+#include "clang/AST/RecursiveASTVisitor.h"
+#include "clang/Basic/LangStandard.h"
+#include "clang/CodeGen/BackendUtil.h"
+#include "clang/CodeGen/CodeGenAction.h"
+#include "clang/Frontend/CompilerInstance.h"
+#include "clang/Frontend/MultiplexConsumer.h"
+#include "clang/Lex/PreprocessorOptions.h"
+#include "clang/Tooling/Tooling.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/VirtualFileSystem.h"
+#include "gtest/gtest.h"
+
+using namespace llvm;
+using namespace clang;
+using namespace clang::frontend;
+using namespace clang::tooling;
+
+namespace {
+
+class ASTChecker : public RecursiveASTVisitor<ASTChecker> {
+public:
+  ASTContext &Ctx;
+  ASTChecker(ASTContext &Ctx) : Ctx(Ctx) {}
+  bool VisitReturnStmt(ReturnStmt *RS) {
+    EXPECT_TRUE(RS->getRetValue());
+    return true;
+  }
+
+  bool VisitCoroutineBodyStmt(CoroutineBodyStmt *CS) {
+    return VisitReturnStmt(cast<ReturnStmt>(CS->getReturnStmt()));
+  }
+};
+
+class ASTCheckerConsumer : public ASTConsumer {
+public:
+  void HandleTranslationUnit(ASTContext &Ctx) override {
+    ASTChecker Checker(Ctx);
+    Checker.TraverseAST(Ctx);
+  }
+};
+
+class TestCodeGenAction : public EmitLLVMAction {
+public:
+  using Base = EmitLLVMAction;
+  TestCodeGenAction(llvm::LLVMContext *_VMContext = nullptr)
+      : EmitLLVMAction(_VMContext) {}
+
+  std::unique_ptr<ASTConsumer> CreateASTConsumer(CompilerInstance &CI,
+                                                 StringRef InFile) override {
+    std::vector<std::unique_ptr<ASTConsumer>> Consumers;
+    Consumers.push_back(std::make_unique<ASTCheckerConsumer>());
+    Consumers.push_back(Base::CreateASTConsumer(CI, InFile));
+    return std::make_unique<MultiplexConsumer>(std::move(Consumers));
+  }
+};
+
+const char *test_contents = R"cpp(
+
+namespace std {
+
+template <typename R, typename...> struct coroutine_traits {
+  using promise_type = typename R::promise_type;
+};
+
+template <typename Promise = void> struct coroutine_handle;
+
+template <> struct coroutine_handle<void> {
+  static coroutine_handle from_address(void *addr) noexcept;
+  void operator()() { resume(); }
+  void *address() const noexcept;
+  void resume() const { __builtin_coro_resume(ptr); }
+  void destroy() const { __builtin_coro_destroy(ptr); }
+  bool done() const;
+  coroutine_handle &operator=(decltype(nullptr));
+  coroutine_handle(decltype(nullptr)) : ptr(nullptr) {}
+  coroutine_handle() : ptr(nullptr) {}
+//  void reset() { ptr = nullptr; } // add to P0057?
+  explicit operator bool() const;
+
+protected:
+  void *ptr;
+};
+
+template <typename Promise> struct coroutine_handle : coroutine_handle<> {
+  using coroutine_handle<>::operator=;
+
+  static coroutine_handle from_address(void *addr) noexcept;
+
+  Promise &promise() const;
+  static coroutine_handle from_promise(Promise &promise);
+};
+
+template <typename _PromiseT>
+bool operator==(coroutine_handle<_PromiseT> const &_Left,
+                coroutine_handle<_PromiseT> const &_Right) noexcept {
+  return _Left.address() == _Right.address();
+}
+
+template <typename _PromiseT>
+bool operator!=(coroutine_handle<_PromiseT> const &_Left,
+                coroutine_handle<_PromiseT> const &_Right) noexcept {
+  return !(_Left == _Right);
+}
+
+struct noop_coroutine_promise {};
+
+template <>
+struct coroutine_handle<noop_coroutine_promise> {
+  operator coroutine_handle<>() const noexcept;
+
+  constexpr explicit operator bool() const noexcept { return true; }
+  constexpr bool done() const noexcept { return false; }
+
+  constexpr void operator()() const noexcept {}
+  constexpr void resume() const noexcept {}
+  constexpr void destroy() const noexcept {}
+
+  noop_coroutine_promise &promise() const noexcept {
+    return *static_cast<noop_coroutine_promise *>(
+        __builtin_coro_promise(this->__handle_, alignof(noop_coroutine_promise), false));
+  }
+
+  constexpr void *address() const noexcept { return __handle_; }
+
+private:
+  friend coroutine_handle<noop_coroutine_promise> noop_coroutine() noexcept;
+
+  coroutine_handle() noexcept {
+    this->__handle_ = __builtin_coro_noop();
+  }
+
+  void *__handle_ = nullptr;
+};
+
+using noop_coroutine_handle = coroutine_handle<noop_coroutine_promise>;
+
+inline noop_coroutine_handle noop_coroutine() noexcept { return noop_coroutine_handle(); }
+
+struct suspend_always {
+  bool await_ready() noexcept { return false; }
+  void await_suspend(coroutine_handle<>) noexcept {}
+  void await_resume() noexcept {}
+};
+struct suspend_never {
+  bool await_ready() noexcept { return true; }
+  void await_suspend(coroutine_handle<>) noexcept {}
+  void await_resume() noexcept {}
+};
+
+} // namespace std
+
+using namespace std;
+
+class invoker {
+public:
+  class invoker_promise {
+  public:
+    invoker get_return_object() { return invoker{}; }
+    auto initial_suspend() { return suspend_always{}; }
+    auto final_suspend() noexcept { return suspend_always{}; }
+    void return_void() {}
+    void unhandled_exception() {}
+  };
+  using promise_type = invoker_promise;
+  invoker() {}
+  invoker(const invoker &) = delete;
+  invoker &operator=(const invoker &) = delete;
+  invoker(invoker &&) = delete;
+  invoker &operator=(invoker &&) = delete;
+};
+
+invoker g() {
+  co_return;
+}
+
+)cpp";
+
+TEST(CodeGenTest, TestNonAlterTest) {
+  EXPECT_TRUE(runToolOnCodeWithArgs(std::make_unique<TestCodeGenAction>(),
+                                    test_contents,
+                                    {
+                                        "-std=c++20",
+                                    }));
+}
+} // namespace


        


More information about the cfe-commits mailing list