[clang] [clang-repl] Factor out CreateJITBuilder() and allow specialization in derived classes (PR #84461)

Stefan Gränitz via cfe-commits cfe-commits at lists.llvm.org
Fri Mar 8 03:05:55 PST 2024


https://github.com/weliveindetail created https://github.com/llvm/llvm-project/pull/84461

The LLJITBuilder interface provides a very convenient way to configure the ORCv2 JIT engine. IncrementalExecutor already used it internally to construct the JIT, but didn't provide external access. This patch lifts control of the creation process to the Interpreter and allows injection of a custom instance through the extended interface. The Interpreter's default behavior remains unchanged and the IncrementalExecutor remains an implementation detail.

>From f87c7ccf59ad436087ada4be479102f7317d50ad Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Stefan=20Gr=C3=A4nitz?= <stefan.graenitz at gmail.com>
Date: Fri, 8 Mar 2024 00:08:56 +0100
Subject: [PATCH 1/3] [clang-repl] Expose RuntimeInterfaceBuilder for
 customizations

---
 clang/include/clang/Interpreter/Interpreter.h |  35 ++-
 clang/lib/Interpreter/Interpreter.cpp         | 247 ++++++++++--------
 clang/unittests/Interpreter/CMakeLists.txt    |   1 +
 .../Interpreter/InterpreterExtensionsTest.cpp |  79 ++++++
 4 files changed, 253 insertions(+), 109 deletions(-)
 create mode 100644 clang/unittests/Interpreter/InterpreterExtensionsTest.cpp

diff --git a/clang/include/clang/Interpreter/Interpreter.h b/clang/include/clang/Interpreter/Interpreter.h
index c8f932e95c4798..d972d960dcb7cd 100644
--- a/clang/include/clang/Interpreter/Interpreter.h
+++ b/clang/include/clang/Interpreter/Interpreter.h
@@ -18,6 +18,7 @@
 #include "clang/AST/GlobalDecl.h"
 #include "clang/Interpreter/PartialTranslationUnit.h"
 #include "clang/Interpreter/Value.h"
+#include "clang/Sema/Ownership.h"
 
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ExecutionEngine/JITSymbol.h"
@@ -75,17 +76,26 @@ class IncrementalCompilerBuilder {
   llvm::StringRef CudaSDKPath;
 };
 
+/// Generate glue code between the Interpreter's built-in runtime and user code.
+class RuntimeInterfaceBuilder {
+public:
+  virtual ~RuntimeInterfaceBuilder() = default;
+
+  using TransformExprFunction = ExprResult(RuntimeInterfaceBuilder *Builder,
+                                           Expr *, ArrayRef<Expr *>);
+  virtual TransformExprFunction *getPrintValueTransformer() = 0;
+};
+
 /// Provides top-level interfaces for incremental compilation and execution.
 class Interpreter {
   std::unique_ptr<llvm::orc::ThreadSafeContext> TSCtx;
   std::unique_ptr<IncrementalParser> IncrParser;
   std::unique_ptr<IncrementalExecutor> IncrExecutor;
+  std::unique_ptr<RuntimeInterfaceBuilder> RuntimeIB;
 
   // An optional parser for CUDA offloading
   std::unique_ptr<IncrementalParser> DeviceParser;
 
-  Interpreter(std::unique_ptr<CompilerInstance> CI, llvm::Error &Err);
-
   llvm::Error CreateExecutor();
   unsigned InitPTUSize = 0;
 
@@ -94,8 +104,25 @@ class Interpreter {
   // printing happens, it's in an invalid state.
   Value LastValue;
 
+  // Add a call to an Expr to report its result. We query the function from
+  // RuntimeInterfaceBuilder once and store it as a function pointer to avoid
+  // frequent virtual function calls.
+  RuntimeInterfaceBuilder::TransformExprFunction *AddPrintValueCall = nullptr;
+
+protected:
+  // Derived classes can make use of an extended interface of the Interpreter.
+  // That's useful for testing and out-of-tree clients.
+  Interpreter(std::unique_ptr<CompilerInstance> CI, llvm::Error &Err);
+
+  // Lazily construct the RuntimeInterfaceBuilder. The provided instance will be
+  // used for the entire lifetime of the interpreter. The default implementation
+  // targets the in-process __clang_Interpreter runtime. Override this to use a
+  // custom runtime.
+  virtual std::unique_ptr<RuntimeInterfaceBuilder> FindRuntimeInterface();
+
 public:
-  ~Interpreter();
+  virtual ~Interpreter();
+
   static llvm::Expected<std::unique_ptr<Interpreter>>
   create(std::unique_ptr<CompilerInstance> CI);
   static llvm::Expected<std::unique_ptr<Interpreter>>
@@ -143,8 +170,6 @@ class Interpreter {
 private:
   size_t getEffectivePTUSize() const;
 
-  bool FindRuntimeInterface();
-
   llvm::DenseMap<CXXRecordDecl *, llvm::orc::ExecutorAddr> Dtors;
 
   llvm::SmallVector<Expr *, 4> ValuePrintingInfo;
diff --git a/clang/lib/Interpreter/Interpreter.cpp b/clang/lib/Interpreter/Interpreter.cpp
index 37696b28976428..3485da8196683a 100644
--- a/clang/lib/Interpreter/Interpreter.cpp
+++ b/clang/lib/Interpreter/Interpreter.cpp
@@ -507,9 +507,13 @@ static constexpr llvm::StringRef MagicRuntimeInterface[] = {
     "__clang_Interpreter_SetValueWithAlloc",
     "__clang_Interpreter_SetValueCopyArr", "__ci_newtag"};
 
-bool Interpreter::FindRuntimeInterface() {
+static std::unique_ptr<RuntimeInterfaceBuilder>
+createInProcessRuntimeInterfaceBuilder(Interpreter &Interp, ASTContext &Ctx,
+                                       Sema &S);
+
+std::unique_ptr<RuntimeInterfaceBuilder> Interpreter::FindRuntimeInterface() {
   if (llvm::all_of(ValuePrintingInfo, [](Expr *E) { return E != nullptr; }))
-    return true;
+    return nullptr;
 
   Sema &S = getCompilerInstance()->getSema();
   ASTContext &Ctx = S.getASTContext();
@@ -528,120 +532,34 @@ bool Interpreter::FindRuntimeInterface() {
 
   if (!LookupInterface(ValuePrintingInfo[NoAlloc],
                        MagicRuntimeInterface[NoAlloc]))
-    return false;
+    return nullptr;
   if (!LookupInterface(ValuePrintingInfo[WithAlloc],
                        MagicRuntimeInterface[WithAlloc]))
-    return false;
+    return nullptr;
   if (!LookupInterface(ValuePrintingInfo[CopyArray],
                        MagicRuntimeInterface[CopyArray]))
-    return false;
+    return nullptr;
   if (!LookupInterface(ValuePrintingInfo[NewTag],
                        MagicRuntimeInterface[NewTag]))
-    return false;
-  return true;
+    return nullptr;
+
+  return createInProcessRuntimeInterfaceBuilder(*this, Ctx, S);
 }
 
 namespace {
 
-class RuntimeInterfaceBuilder
-    : public TypeVisitor<RuntimeInterfaceBuilder, Interpreter::InterfaceKind> {
-  clang::Interpreter &Interp;
+class InterfaceKindVisitor
+    : public TypeVisitor<InterfaceKindVisitor, Interpreter::InterfaceKind> {
+  friend class InProcessRuntimeInterfaceBuilder;
+
   ASTContext &Ctx;
   Sema &S;
   Expr *E;
   llvm::SmallVector<Expr *, 3> Args;
 
 public:
-  RuntimeInterfaceBuilder(clang::Interpreter &In, ASTContext &C, Sema &SemaRef,
-                          Expr *VE, ArrayRef<Expr *> FixedArgs)
-      : Interp(In), Ctx(C), S(SemaRef), E(VE) {
-    // The Interpreter* parameter and the out parameter `OutVal`.
-    for (Expr *E : FixedArgs)
-      Args.push_back(E);
-
-    // Get rid of ExprWithCleanups.
-    if (auto *EWC = llvm::dyn_cast_if_present<ExprWithCleanups>(E))
-      E = EWC->getSubExpr();
-  }
-
-  ExprResult getCall() {
-    QualType Ty = E->getType();
-    QualType DesugaredTy = Ty.getDesugaredType(Ctx);
-
-    // For lvalue struct, we treat it as a reference.
-    if (DesugaredTy->isRecordType() && E->isLValue()) {
-      DesugaredTy = Ctx.getLValueReferenceType(DesugaredTy);
-      Ty = Ctx.getLValueReferenceType(Ty);
-    }
-
-    Expr *TypeArg =
-        CStyleCastPtrExpr(S, Ctx.VoidPtrTy, (uintptr_t)Ty.getAsOpaquePtr());
-    // The QualType parameter `OpaqueType`, represented as `void*`.
-    Args.push_back(TypeArg);
-
-    // We push the last parameter based on the type of the Expr. Note we need
-    // special care for rvalue struct.
-    Interpreter::InterfaceKind Kind = Visit(&*DesugaredTy);
-    switch (Kind) {
-    case Interpreter::InterfaceKind::WithAlloc:
-    case Interpreter::InterfaceKind::CopyArray: {
-      // __clang_Interpreter_SetValueWithAlloc.
-      ExprResult AllocCall = S.ActOnCallExpr(
-          /*Scope=*/nullptr,
-          Interp.getValuePrintingInfo()[Interpreter::InterfaceKind::WithAlloc],
-          E->getBeginLoc(), Args, E->getEndLoc());
-      assert(!AllocCall.isInvalid() && "Can't create runtime interface call!");
-
-      TypeSourceInfo *TSI = Ctx.getTrivialTypeSourceInfo(Ty, SourceLocation());
-
-      // Force CodeGen to emit destructor.
-      if (auto *RD = Ty->getAsCXXRecordDecl()) {
-        auto *Dtor = S.LookupDestructor(RD);
-        Dtor->addAttr(UsedAttr::CreateImplicit(Ctx));
-        Interp.getCompilerInstance()->getASTConsumer().HandleTopLevelDecl(
-            DeclGroupRef(Dtor));
-      }
-
-      // __clang_Interpreter_SetValueCopyArr.
-      if (Kind == Interpreter::InterfaceKind::CopyArray) {
-        const auto *ConstantArrTy =
-            cast<ConstantArrayType>(DesugaredTy.getTypePtr());
-        size_t ArrSize = Ctx.getConstantArrayElementCount(ConstantArrTy);
-        Expr *ArrSizeExpr = IntegerLiteralExpr(Ctx, ArrSize);
-        Expr *Args[] = {E, AllocCall.get(), ArrSizeExpr};
-        return S.ActOnCallExpr(
-            /*Scope *=*/nullptr,
-            Interp
-                .getValuePrintingInfo()[Interpreter::InterfaceKind::CopyArray],
-            SourceLocation(), Args, SourceLocation());
-      }
-      Expr *Args[] = {
-          AllocCall.get(),
-          Interp.getValuePrintingInfo()[Interpreter::InterfaceKind::NewTag]};
-      ExprResult CXXNewCall = S.BuildCXXNew(
-          E->getSourceRange(),
-          /*UseGlobal=*/true, /*PlacementLParen=*/SourceLocation(), Args,
-          /*PlacementRParen=*/SourceLocation(),
-          /*TypeIdParens=*/SourceRange(), TSI->getType(), TSI, std::nullopt,
-          E->getSourceRange(), E);
-
-      assert(!CXXNewCall.isInvalid() &&
-             "Can't create runtime placement new call!");
-
-      return S.ActOnFinishFullExpr(CXXNewCall.get(),
-                                   /*DiscardedValue=*/false);
-    }
-      // __clang_Interpreter_SetValueNoAlloc.
-    case Interpreter::InterfaceKind::NoAlloc: {
-      return S.ActOnCallExpr(
-          /*Scope=*/nullptr,
-          Interp.getValuePrintingInfo()[Interpreter::InterfaceKind::NoAlloc],
-          E->getBeginLoc(), Args, E->getEndLoc());
-    }
-    default:
-      llvm_unreachable("Unhandled Interpreter::InterfaceKind");
-    }
-  }
+  InterfaceKindVisitor(ASTContext &Ctx, Sema &S, Expr *E)
+      : Ctx(Ctx), S(S), E(E) {}
 
   Interpreter::InterfaceKind VisitRecordType(const RecordType *Ty) {
     return Interpreter::InterfaceKind::WithAlloc;
@@ -713,8 +631,124 @@ class RuntimeInterfaceBuilder
     Args.push_back(CastedExpr.get());
   }
 };
+
+class InProcessRuntimeInterfaceBuilder : public RuntimeInterfaceBuilder {
+  Interpreter &Interp;
+  ASTContext &Ctx;
+  Sema &S;
+
+public:
+  InProcessRuntimeInterfaceBuilder(Interpreter &Interp, ASTContext &C, Sema &S)
+      : Interp(Interp), Ctx(C), S(S) {}
+
+  TransformExprFunction *getPrintValueTransformer() override {
+    return &transformForValuePrinting;
+  }
+
+private:
+  static ExprResult transformForValuePrinting(RuntimeInterfaceBuilder *Builder,
+                                              Expr *E,
+                                              ArrayRef<Expr *> FixedArgs) {
+    auto *B = static_cast<InProcessRuntimeInterfaceBuilder *>(Builder);
+
+    // Get rid of ExprWithCleanups.
+    if (auto *EWC = llvm::dyn_cast_if_present<ExprWithCleanups>(E))
+      E = EWC->getSubExpr();
+
+    InterfaceKindVisitor Visitor(B->Ctx, B->S, E);
+
+    // The Interpreter* parameter and the out parameter `OutVal`.
+    for (Expr *E : FixedArgs)
+      Visitor.Args.push_back(E);
+
+    QualType Ty = E->getType();
+    QualType DesugaredTy = Ty.getDesugaredType(B->Ctx);
+
+    // For lvalue struct, we treat it as a reference.
+    if (DesugaredTy->isRecordType() && E->isLValue()) {
+      DesugaredTy = B->Ctx.getLValueReferenceType(DesugaredTy);
+      Ty = B->Ctx.getLValueReferenceType(Ty);
+    }
+
+    Expr *TypeArg = CStyleCastPtrExpr(B->S, B->Ctx.VoidPtrTy,
+                                      (uintptr_t)Ty.getAsOpaquePtr());
+    // The QualType parameter `OpaqueType`, represented as `void*`.
+    Visitor.Args.push_back(TypeArg);
+
+    // We push the last parameter based on the type of the Expr. Note we need
+    // special care for rvalue struct.
+    Interpreter::InterfaceKind Kind = Visitor.Visit(&*DesugaredTy);
+    switch (Kind) {
+    case Interpreter::InterfaceKind::WithAlloc:
+    case Interpreter::InterfaceKind::CopyArray: {
+      // __clang_Interpreter_SetValueWithAlloc.
+      ExprResult AllocCall = B->S.ActOnCallExpr(
+          /*Scope=*/nullptr,
+          B->Interp
+              .getValuePrintingInfo()[Interpreter::InterfaceKind::WithAlloc],
+          E->getBeginLoc(), Visitor.Args, E->getEndLoc());
+      assert(!AllocCall.isInvalid() && "Can't create runtime interface call!");
+
+      TypeSourceInfo *TSI =
+          B->Ctx.getTrivialTypeSourceInfo(Ty, SourceLocation());
+
+      // Force CodeGen to emit destructor.
+      if (auto *RD = Ty->getAsCXXRecordDecl()) {
+        auto *Dtor = B->S.LookupDestructor(RD);
+        Dtor->addAttr(UsedAttr::CreateImplicit(B->Ctx));
+        B->Interp.getCompilerInstance()->getASTConsumer().HandleTopLevelDecl(
+            DeclGroupRef(Dtor));
+      }
+
+      // __clang_Interpreter_SetValueCopyArr.
+      if (Kind == Interpreter::InterfaceKind::CopyArray) {
+        const auto *ConstantArrTy =
+            cast<ConstantArrayType>(DesugaredTy.getTypePtr());
+        size_t ArrSize = B->Ctx.getConstantArrayElementCount(ConstantArrTy);
+        Expr *ArrSizeExpr = IntegerLiteralExpr(B->Ctx, ArrSize);
+        Expr *Args[] = {E, AllocCall.get(), ArrSizeExpr};
+        return B->S.ActOnCallExpr(
+            /*Scope *=*/nullptr,
+            B->Interp
+                .getValuePrintingInfo()[Interpreter::InterfaceKind::CopyArray],
+            SourceLocation(), Args, SourceLocation());
+      }
+      Expr *Args[] = {
+          AllocCall.get(),
+          B->Interp.getValuePrintingInfo()[Interpreter::InterfaceKind::NewTag]};
+      ExprResult CXXNewCall = B->S.BuildCXXNew(
+          E->getSourceRange(),
+          /*UseGlobal=*/true, /*PlacementLParen=*/SourceLocation(), Args,
+          /*PlacementRParen=*/SourceLocation(),
+          /*TypeIdParens=*/SourceRange(), TSI->getType(), TSI, std::nullopt,
+          E->getSourceRange(), E);
+
+      assert(!CXXNewCall.isInvalid() &&
+             "Can't create runtime placement new call!");
+
+      return B->S.ActOnFinishFullExpr(CXXNewCall.get(),
+                                      /*DiscardedValue=*/false);
+    }
+      // __clang_Interpreter_SetValueNoAlloc.
+    case Interpreter::InterfaceKind::NoAlloc: {
+      return B->S.ActOnCallExpr(
+          /*Scope=*/nullptr,
+          B->Interp.getValuePrintingInfo()[Interpreter::InterfaceKind::NoAlloc],
+          E->getBeginLoc(), Visitor.Args, E->getEndLoc());
+    }
+    default:
+      llvm_unreachable("Unhandled Interpreter::InterfaceKind");
+    }
+  }
+};
 } // namespace
 
+static std::unique_ptr<RuntimeInterfaceBuilder>
+createInProcessRuntimeInterfaceBuilder(Interpreter &Interp, ASTContext &Ctx,
+                                       Sema &S) {
+  return std::make_unique<InProcessRuntimeInterfaceBuilder>(Interp, Ctx, S);
+}
+
 // This synthesizes a call expression to a speciall
 // function that is responsible for generating the Value.
 // In general, we transform:
@@ -733,8 +767,13 @@ Expr *Interpreter::SynthesizeExpr(Expr *E) {
   Sema &S = getCompilerInstance()->getSema();
   ASTContext &Ctx = S.getASTContext();
 
-  if (!FindRuntimeInterface())
-    llvm_unreachable("We can't find the runtime iterface for pretty print!");
+  if (!RuntimeIB) {
+    RuntimeIB = FindRuntimeInterface();
+    AddPrintValueCall = RuntimeIB->getPrintValueTransformer();
+  }
+
+  assert(AddPrintValueCall &&
+         "We don't have a runtime interface for pretty print!");
 
   // Create parameter `ThisInterp`.
   auto *ThisInterp = CStyleCastPtrExpr(S, Ctx.VoidPtrTy, (uintptr_t)this);
@@ -743,9 +782,9 @@ Expr *Interpreter::SynthesizeExpr(Expr *E) {
   auto *OutValue = CStyleCastPtrExpr(S, Ctx.VoidPtrTy, (uintptr_t)&LastValue);
 
   // Build `__clang_Interpreter_SetValue*` call.
-  RuntimeInterfaceBuilder Builder(*this, Ctx, S, E, {ThisInterp, OutValue});
+  ExprResult Result =
+      AddPrintValueCall(RuntimeIB.get(), E, {ThisInterp, OutValue});
 
-  ExprResult Result = Builder.getCall();
   // It could fail, like printing an array type in C. (not supported)
   if (Result.isInvalid())
     return E;
diff --git a/clang/unittests/Interpreter/CMakeLists.txt b/clang/unittests/Interpreter/CMakeLists.txt
index 0ddedb283e07d1..046d96ad0ec644 100644
--- a/clang/unittests/Interpreter/CMakeLists.txt
+++ b/clang/unittests/Interpreter/CMakeLists.txt
@@ -10,6 +10,7 @@ add_clang_unittest(ClangReplInterpreterTests
   IncrementalCompilerBuilderTest.cpp
   IncrementalProcessingTest.cpp
   InterpreterTest.cpp
+  InterpreterExtensionsTest.cpp
   CodeCompletionTest.cpp
   )
 target_link_libraries(ClangReplInterpreterTests PUBLIC
diff --git a/clang/unittests/Interpreter/InterpreterExtensionsTest.cpp b/clang/unittests/Interpreter/InterpreterExtensionsTest.cpp
new file mode 100644
index 00000000000000..4e9f2dba210a37
--- /dev/null
+++ b/clang/unittests/Interpreter/InterpreterExtensionsTest.cpp
@@ -0,0 +1,79 @@
+//===- unittests/Interpreter/InterpreterExtensionsTest.cpp ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Unit tests for Clang's Interpreter library.
+//
+//===----------------------------------------------------------------------===//
+
+#include "clang/Interpreter/Interpreter.h"
+
+#include "clang/AST/Expr.h"
+#include "clang/Frontend/CompilerInstance.h"
+#include "clang/Sema/Lookup.h"
+#include "clang/Sema/Sema.h"
+
+#include "llvm/Support/Error.h"
+#include "llvm/Testing/Support/Error.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include <system_error>
+
+using namespace clang;
+namespace {
+
+class RecordRuntimeIBMetrics : public Interpreter {
+  struct NoopRuntimeInterfaceBuilder : public RuntimeInterfaceBuilder {
+    NoopRuntimeInterfaceBuilder(Sema &S) : S(S) {}
+
+    TransformExprFunction *getPrintValueTransformer() override {
+      TransformerQueries += 1;
+      return &noop;
+    }
+
+    static ExprResult noop(RuntimeInterfaceBuilder *Builder, Expr *E,
+                           ArrayRef<Expr *> FixedArgs) {
+      auto *B = static_cast<NoopRuntimeInterfaceBuilder *>(Builder);
+      B->TransformedExprs += 1;
+      return B->S.ActOnFinishFullExpr(E, /*DiscardedValue=*/false);
+    }
+
+    Sema &S;
+    size_t TransformedExprs = 0;
+    size_t TransformerQueries = 0;
+  };
+
+public:
+  // Inherit with using wouldn't make it public
+  RecordRuntimeIBMetrics(std::unique_ptr<CompilerInstance> CI, llvm::Error &Err)
+      : Interpreter(std::move(CI), Err) {}
+
+  std::unique_ptr<RuntimeInterfaceBuilder> FindRuntimeInterface() override {
+    assert(RuntimeIBPtr == nullptr && "We create the builder only once");
+    Sema &S = getCompilerInstance()->getSema();
+    auto RuntimeIB = std::make_unique<NoopRuntimeInterfaceBuilder>(S);
+    RuntimeIBPtr = RuntimeIB.get();
+    return RuntimeIB;
+  }
+
+  NoopRuntimeInterfaceBuilder *RuntimeIBPtr = nullptr;
+};
+
+TEST(InterpreterExtensionsTest, FindRuntimeInterface) {
+  clang::IncrementalCompilerBuilder CB;
+  llvm::Error ErrOut = llvm::Error::success();
+  RecordRuntimeIBMetrics Interp(cantFail(CB.CreateCpp()), ErrOut);
+  cantFail(std::move(ErrOut));
+  cantFail(Interp.Parse("int a = 1; a"));
+  cantFail(Interp.Parse("int b = 2; b"));
+  cantFail(Interp.Parse("int c = 3; c"));
+  EXPECT_EQ(3U, Interp.RuntimeIBPtr->TransformedExprs);
+  EXPECT_EQ(1U, Interp.RuntimeIBPtr->TransformerQueries);
+}
+
+} // end anonymous namespace

>From 1972f25790ca50ffc616903be49a744c9f9aa938 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Stefan=20Gr=C3=A4nitz?= <stefan.graenitz at gmail.com>
Date: Thu, 7 Mar 2024 23:04:22 +0100
Subject: [PATCH 2/3] [clang-repl] Expose CreateExecutor() and ResetExecutor()

---
 clang/include/clang/Interpreter/Interpreter.h |  9 ++++++-
 clang/lib/Interpreter/Interpreter.cpp         | 17 ++++++++++++
 clang/unittests/Interpreter/CMakeLists.txt    |  1 +
 .../Interpreter/InterpreterExtensionsTest.cpp | 26 +++++++++++++++++++
 4 files changed, 52 insertions(+), 1 deletion(-)

diff --git a/clang/include/clang/Interpreter/Interpreter.h b/clang/include/clang/Interpreter/Interpreter.h
index d972d960dcb7cd..344f38da5ab21b 100644
--- a/clang/include/clang/Interpreter/Interpreter.h
+++ b/clang/include/clang/Interpreter/Interpreter.h
@@ -96,7 +96,6 @@ class Interpreter {
   // An optional parser for CUDA offloading
   std::unique_ptr<IncrementalParser> DeviceParser;
 
-  llvm::Error CreateExecutor();
   unsigned InitPTUSize = 0;
 
   // This member holds the last result of the value printing. It's a class
@@ -114,6 +113,14 @@ class Interpreter {
   // That's useful for testing and out-of-tree clients.
   Interpreter(std::unique_ptr<CompilerInstance> CI, llvm::Error &Err);
 
+  // Create the internal IncrementalExecutor, or re-create it after calling
+  // ResetExecutor().
+  llvm::Error CreateExecutor();
+
+  // Delete the internal IncrementalExecutor. This causes a hard shutdown of the
+  // JIT engine. In particular, it doesn't run cleanup or destructors.
+  void ResetExecutor();
+
   // Lazily construct the RuntimeInterfaceBuilder. The provided instance will be
   // used for the entire lifetime of the interpreter. The default implementation
   // targets the in-process __clang_Interpreter runtime. Override this to use a
diff --git a/clang/lib/Interpreter/Interpreter.cpp b/clang/lib/Interpreter/Interpreter.cpp
index 3485da8196683a..0b4ef946de6ac4 100644
--- a/clang/lib/Interpreter/Interpreter.cpp
+++ b/clang/lib/Interpreter/Interpreter.cpp
@@ -36,9 +36,11 @@
 #include "clang/Lex/PreprocessorOptions.h"
 #include "clang/Sema/Lookup.h"
 #include "llvm/ExecutionEngine/JITSymbol.h"
+#include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h"
 #include "llvm/ExecutionEngine/Orc/LLJIT.h"
 #include "llvm/IR/Module.h"
 #include "llvm/Support/Errc.h"
+#include "llvm/Support/Error.h"
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/raw_ostream.h"
 #include "llvm/TargetParser/Host.h"
@@ -368,9 +370,22 @@ Interpreter::Parse(llvm::StringRef Code) {
   return IncrParser->Parse(Code);
 }
 
+static llvm::Expected<llvm::orc::JITTargetMachineBuilder>
+createJITTargetMachineBuilder(const std::string &TT) {
+  if (TT == llvm::sys::getProcessTriple())
+    return llvm::orc::JITTargetMachineBuilder::detectHost();
+  // FIXME: This can fail as well if the target is not registered! We just don't
+  // catch it yet.
+  return llvm::orc::JITTargetMachineBuilder(llvm::Triple(TT));
+}
+
 llvm::Error Interpreter::CreateExecutor() {
   const clang::TargetInfo &TI =
       getCompilerInstance()->getASTContext().getTargetInfo();
+  if (IncrExecutor)
+    return llvm::make_error<llvm::StringError>("Operation failed. "
+                                               "Execution engine exists",
+                                               std::error_code());
   llvm::Error Err = llvm::Error::success();
   auto Executor = std::make_unique<IncrementalExecutor>(*TSCtx, Err, TI);
   if (!Err)
@@ -379,6 +394,8 @@ llvm::Error Interpreter::CreateExecutor() {
   return Err;
 }
 
+void Interpreter::ResetExecutor() { IncrExecutor.reset(); }
+
 llvm::Error Interpreter::Execute(PartialTranslationUnit &T) {
   assert(T.TheModule);
   if (!IncrExecutor) {
diff --git a/clang/unittests/Interpreter/CMakeLists.txt b/clang/unittests/Interpreter/CMakeLists.txt
index 046d96ad0ec644..498070b43d922e 100644
--- a/clang/unittests/Interpreter/CMakeLists.txt
+++ b/clang/unittests/Interpreter/CMakeLists.txt
@@ -4,6 +4,7 @@ set(LLVM_LINK_COMPONENTS
   OrcJIT
   Support
   TargetParser
+  TestingSupport
   )
 
 add_clang_unittest(ClangReplInterpreterTests
diff --git a/clang/unittests/Interpreter/InterpreterExtensionsTest.cpp b/clang/unittests/Interpreter/InterpreterExtensionsTest.cpp
index 4e9f2dba210a37..8b88db1fcfede0 100644
--- a/clang/unittests/Interpreter/InterpreterExtensionsTest.cpp
+++ b/clang/unittests/Interpreter/InterpreterExtensionsTest.cpp
@@ -27,6 +27,32 @@
 using namespace clang;
 namespace {
 
+class TestCreateResetExecutor : public Interpreter {
+public:
+  TestCreateResetExecutor(std::unique_ptr<CompilerInstance> CI,
+                          llvm::Error &Err)
+      : Interpreter(std::move(CI), Err) {}
+
+  llvm::Error testCreateExecutor() {
+    return Interpreter::CreateExecutor();
+  }
+
+  void resetExecutor() { Interpreter::ResetExecutor(); }
+};
+
+TEST(InterpreterExtensionsTest, ExecutorCreateReset) {
+  clang::IncrementalCompilerBuilder CB;
+  llvm::Error ErrOut = llvm::Error::success();
+  TestCreateResetExecutor Interp(cantFail(CB.CreateCpp()), ErrOut);
+  cantFail(std::move(ErrOut));
+  cantFail(Interp.testCreateExecutor());
+  Interp.resetExecutor();
+  cantFail(Interp.testCreateExecutor());
+  EXPECT_THAT_ERROR(Interp.testCreateExecutor(),
+                    llvm::FailedWithMessage("Operation failed. "
+                                            "Execution engine exists"));
+}
+
 class RecordRuntimeIBMetrics : public Interpreter {
   struct NoopRuntimeInterfaceBuilder : public RuntimeInterfaceBuilder {
     NoopRuntimeInterfaceBuilder(Sema &S) : S(S) {}

>From cc46034e5288ba54dfc8a0ea7d54792cbb99227d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Stefan=20Gr=C3=A4nitz?= <stefan.graenitz at gmail.com>
Date: Thu, 7 Mar 2024 23:04:22 +0100
Subject: [PATCH 3/3] [clang-repl] Add CreateJITBuilder() for specialization in
 derived classes

The LLJITBuilder interface provides a very convenient way to configure the JIT.
---
 clang/include/clang/Interpreter/Interpreter.h |   9 ++
 clang/lib/Interpreter/IncrementalExecutor.cpp |  33 ++---
 clang/lib/Interpreter/IncrementalExecutor.h   |   9 +-
 clang/lib/Interpreter/Interpreter.cpp         |  21 +++-
 .../Interpreter/InterpreterExtensionsTest.cpp | 115 ++++++++++++++++++
 5 files changed, 165 insertions(+), 22 deletions(-)

diff --git a/clang/include/clang/Interpreter/Interpreter.h b/clang/include/clang/Interpreter/Interpreter.h
index 344f38da5ab21b..8b47a899873715 100644
--- a/clang/include/clang/Interpreter/Interpreter.h
+++ b/clang/include/clang/Interpreter/Interpreter.h
@@ -29,7 +29,9 @@
 
 namespace llvm {
 namespace orc {
+class JITTargetMachineBuilder;
 class LLJIT;
+class LLJITBuilder;
 class ThreadSafeContext;
 } // namespace orc
 } // namespace llvm
@@ -127,6 +129,13 @@ class Interpreter {
   // custom runtime.
   virtual std::unique_ptr<RuntimeInterfaceBuilder> FindRuntimeInterface();
 
+  // Lazily construct the ORCv2 JITBuilder. This is called when the internal
+  // IncrementalExecutor is created. The default implementation populates an
+  // in-process JIT with debugging support. Override this to configure the JIT
+  // engine used for execution.
+  virtual llvm::Expected<std::unique_ptr<llvm::orc::LLJITBuilder>>
+  CreateJITBuilder(CompilerInstance &CI);
+
 public:
   virtual ~Interpreter();
 
diff --git a/clang/lib/Interpreter/IncrementalExecutor.cpp b/clang/lib/Interpreter/IncrementalExecutor.cpp
index 40bcef94797d43..6f036107c14a9c 100644
--- a/clang/lib/Interpreter/IncrementalExecutor.cpp
+++ b/clang/lib/Interpreter/IncrementalExecutor.cpp
@@ -20,6 +20,7 @@
 #include "llvm/ExecutionEngine/Orc/Debugging/DebuggerSupport.h"
 #include "llvm/ExecutionEngine/Orc/ExecutionUtils.h"
 #include "llvm/ExecutionEngine/Orc/IRCompileLayer.h"
+#include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h"
 #include "llvm/ExecutionEngine/Orc/LLJIT.h"
 #include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h"
 #include "llvm/ExecutionEngine/Orc/TargetProcess/JITLoaderGDB.h"
@@ -36,26 +37,28 @@ LLVM_ATTRIBUTE_USED void linkComponents() {
 
 namespace clang {
 
+llvm::Expected<std::unique_ptr<llvm::orc::LLJITBuilder>>
+IncrementalExecutor::createDefaultJITBuilder(
+    llvm::orc::JITTargetMachineBuilder JTMB) {
+  auto JITBuilder = std::make_unique<llvm::orc::LLJITBuilder>();
+  JITBuilder->setJITTargetMachineBuilder(std::move(JTMB));
+  JITBuilder->setPrePlatformSetup([](llvm::orc::LLJIT &J) {
+    // Try to enable debugging of JIT'd code (only works with JITLink for
+    // ELF and MachO).
+    consumeError(llvm::orc::enableDebuggerSupport(J));
+    return llvm::Error::success();
+  });
+  return std::move(JITBuilder);
+}
+
 IncrementalExecutor::IncrementalExecutor(llvm::orc::ThreadSafeContext &TSC,
-                                         llvm::Error &Err,
-                                         const clang::TargetInfo &TI)
+                                         llvm::orc::LLJITBuilder &JITBuilder,
+                                         llvm::Error &Err)
     : TSCtx(TSC) {
   using namespace llvm::orc;
   llvm::ErrorAsOutParameter EAO(&Err);
 
-  auto JTMB = JITTargetMachineBuilder(TI.getTriple());
-  JTMB.addFeatures(TI.getTargetOpts().Features);
-  LLJITBuilder Builder;
-  Builder.setJITTargetMachineBuilder(JTMB);
-  Builder.setPrePlatformSetup(
-      [](LLJIT &J) {
-        // Try to enable debugging of JIT'd code (only works with JITLink for
-        // ELF and MachO).
-        consumeError(enableDebuggerSupport(J));
-        return llvm::Error::success();
-      });
-
-  if (auto JitOrErr = Builder.create())
+  if (auto JitOrErr = JITBuilder.create())
     Jit = std::move(*JitOrErr);
   else {
     Err = JitOrErr.takeError();
diff --git a/clang/lib/Interpreter/IncrementalExecutor.h b/clang/lib/Interpreter/IncrementalExecutor.h
index dd0a210a061415..b4347209e14fe3 100644
--- a/clang/lib/Interpreter/IncrementalExecutor.h
+++ b/clang/lib/Interpreter/IncrementalExecutor.h
@@ -23,7 +23,9 @@
 namespace llvm {
 class Error;
 namespace orc {
+class JITTargetMachineBuilder;
 class LLJIT;
+class LLJITBuilder;
 class ThreadSafeContext;
 } // namespace orc
 } // namespace llvm
@@ -44,8 +46,8 @@ class IncrementalExecutor {
 public:
   enum SymbolNameKind { IRName, LinkerName };
 
-  IncrementalExecutor(llvm::orc::ThreadSafeContext &TSC, llvm::Error &Err,
-                      const clang::TargetInfo &TI);
+  IncrementalExecutor(llvm::orc::ThreadSafeContext &TSC,
+                      llvm::orc::LLJITBuilder &JITBuilder, llvm::Error &Err);
   ~IncrementalExecutor();
 
   llvm::Error addModule(PartialTranslationUnit &PTU);
@@ -56,6 +58,9 @@ class IncrementalExecutor {
   getSymbolAddress(llvm::StringRef Name, SymbolNameKind NameKind) const;
 
   llvm::orc::LLJIT &GetExecutionEngine() { return *Jit; }
+
+  static llvm::Expected<std::unique_ptr<llvm::orc::LLJITBuilder>>
+  createDefaultJITBuilder(llvm::orc::JITTargetMachineBuilder JTMB);
 };
 
 } // end namespace clang
diff --git a/clang/lib/Interpreter/Interpreter.cpp b/clang/lib/Interpreter/Interpreter.cpp
index 0b4ef946de6ac4..ad4ebfd4deb708 100644
--- a/clang/lib/Interpreter/Interpreter.cpp
+++ b/clang/lib/Interpreter/Interpreter.cpp
@@ -373,21 +373,32 @@ Interpreter::Parse(llvm::StringRef Code) {
 static llvm::Expected<llvm::orc::JITTargetMachineBuilder>
 createJITTargetMachineBuilder(const std::string &TT) {
   if (TT == llvm::sys::getProcessTriple())
+    // This fails immediately if the target backend is not registered
     return llvm::orc::JITTargetMachineBuilder::detectHost();
-  // FIXME: This can fail as well if the target is not registered! We just don't
-  // catch it yet.
+
+  // If the target backend is not registered, LLJITBuilder::create() will fail
   return llvm::orc::JITTargetMachineBuilder(llvm::Triple(TT));
 }
 
+llvm::Expected<std::unique_ptr<llvm::orc::LLJITBuilder>>
+Interpreter::CreateJITBuilder(CompilerInstance &CI) {
+  auto JTMB = createJITTargetMachineBuilder(CI.getTargetOpts().Triple);
+  if (!JTMB)
+    return JTMB.takeError();
+  return IncrementalExecutor::createDefaultJITBuilder(std::move(*JTMB));
+}
+
 llvm::Error Interpreter::CreateExecutor() {
-  const clang::TargetInfo &TI =
-      getCompilerInstance()->getASTContext().getTargetInfo();
   if (IncrExecutor)
     return llvm::make_error<llvm::StringError>("Operation failed. "
                                                "Execution engine exists",
                                                std::error_code());
+  llvm::Expected<std::unique_ptr<llvm::orc::LLJITBuilder>> JB =
+      CreateJITBuilder(*getCompilerInstance());
+  if (!JB)
+    return JB.takeError();
   llvm::Error Err = llvm::Error::success();
-  auto Executor = std::make_unique<IncrementalExecutor>(*TSCtx, Err, TI);
+  auto Executor = std::make_unique<IncrementalExecutor>(*TSCtx, **JB, Err);
   if (!Err)
     IncrExecutor = std::move(Executor);
 
diff --git a/clang/unittests/Interpreter/InterpreterExtensionsTest.cpp b/clang/unittests/Interpreter/InterpreterExtensionsTest.cpp
index 8b88db1fcfede0..a7778ddfcfd7db 100644
--- a/clang/unittests/Interpreter/InterpreterExtensionsTest.cpp
+++ b/clang/unittests/Interpreter/InterpreterExtensionsTest.cpp
@@ -17,7 +17,11 @@
 #include "clang/Sema/Lookup.h"
 #include "clang/Sema/Sema.h"
 
+#include "llvm/ExecutionEngine/Orc/LLJIT.h"
+#include "llvm/ExecutionEngine/Orc/Shared/ExecutorAddress.h"
 #include "llvm/Support/Error.h"
+#include "llvm/Support/TargetSelect.h"
+#include "llvm/Support/Threading.h"
 #include "llvm/Testing/Support/Error.h"
 
 #include "gmock/gmock.h"
@@ -33,11 +37,27 @@ class TestCreateResetExecutor : public Interpreter {
                           llvm::Error &Err)
       : Interpreter(std::move(CI), Err) {}
 
+  llvm::Error testCreateJITBuilderError() {
+    JB = nullptr;
+    return Interpreter::CreateExecutor();
+  }
+
   llvm::Error testCreateExecutor() {
+    JB = std::make_unique<llvm::orc::LLJITBuilder>();
     return Interpreter::CreateExecutor();
   }
 
   void resetExecutor() { Interpreter::ResetExecutor(); }
+
+private:
+  llvm::Expected<std::unique_ptr<llvm::orc::LLJITBuilder>>
+  CreateJITBuilder(CompilerInstance &CI) override {
+    if (JB)
+      return std::move(JB);
+    return llvm::make_error<llvm::StringError>("TestError", std::error_code());
+  }
+
+  std::unique_ptr<llvm::orc::LLJITBuilder> JB;
 };
 
 TEST(InterpreterExtensionsTest, ExecutorCreateReset) {
@@ -45,6 +65,8 @@ TEST(InterpreterExtensionsTest, ExecutorCreateReset) {
   llvm::Error ErrOut = llvm::Error::success();
   TestCreateResetExecutor Interp(cantFail(CB.CreateCpp()), ErrOut);
   cantFail(std::move(ErrOut));
+  EXPECT_THAT_ERROR(Interp.testCreateJITBuilderError(),
+                    llvm::FailedWithMessage("TestError"));
   cantFail(Interp.testCreateExecutor());
   Interp.resetExecutor();
   cantFail(Interp.testCreateExecutor());
@@ -102,4 +124,97 @@ TEST(InterpreterExtensionsTest, FindRuntimeInterface) {
   EXPECT_EQ(1U, Interp.RuntimeIBPtr->TransformerQueries);
 }
 
+class CustomJBInterpreter : public Interpreter {
+  using CustomJITBuilderCreatorFunction =
+      std::function<llvm::Expected<std::unique_ptr<llvm::orc::LLJITBuilder>>()>;
+  CustomJITBuilderCreatorFunction JBCreator = nullptr;
+
+public:
+  CustomJBInterpreter(std::unique_ptr<CompilerInstance> CI, llvm::Error &ErrOut)
+      : Interpreter(std::move(CI), ErrOut) {}
+
+  ~CustomJBInterpreter() override {
+    // Skip cleanUp() because it would trigger LLJIT default dtors
+    Interpreter::ResetExecutor();
+  }
+
+  void setCustomJITBuilderCreator(CustomJITBuilderCreatorFunction Fn) {
+    JBCreator = std::move(Fn);
+  }
+
+  llvm::Expected<std::unique_ptr<llvm::orc::LLJITBuilder>>
+  CreateJITBuilder(CompilerInstance &CI) override {
+    if (JBCreator)
+      return JBCreator();
+    return Interpreter::CreateJITBuilder(CI);
+  }
+
+  llvm::Error CreateExecutor() { return Interpreter::CreateExecutor(); }
+};
+
+static void initArmTarget() {
+  static llvm::once_flag F;
+  llvm::call_once(F, [] {
+    LLVMInitializeARMTarget();
+    LLVMInitializeARMTargetInfo();
+    LLVMInitializeARMTargetMC();
+    LLVMInitializeARMAsmPrinter();
+  });
+}
+
+llvm::llvm_shutdown_obj Shutdown;
+
+TEST(InterpreterExtensionsTest, DefaultCrossJIT) {
+  IncrementalCompilerBuilder CB;
+  CB.SetTargetTriple("armv6-none-eabi");
+  auto CI = cantFail(CB.CreateCpp());
+  llvm::Error ErrOut = llvm::Error::success();
+  CustomJBInterpreter Interp(std::move(CI), ErrOut);
+  cantFail(std::move(ErrOut));
+
+  initArmTarget();
+  cantFail(Interp.CreateExecutor());
+}
+
+TEST(InterpreterExtensionsTest, CustomCrossJIT) {
+  std::string TargetTriple = "armv6-none-eabi";
+
+  IncrementalCompilerBuilder CB;
+  CB.SetTargetTriple(TargetTriple);
+  auto CI = cantFail(CB.CreateCpp());
+  llvm::Error ErrOut = llvm::Error::success();
+  CustomJBInterpreter Interp(std::move(CI), ErrOut);
+  cantFail(std::move(ErrOut));
+
+  using namespace llvm::orc;
+  LLJIT *JIT = nullptr;
+  std::vector<std::unique_ptr<llvm::MemoryBuffer>> Objs;
+  Interp.setCustomJITBuilderCreator([&]() {
+    initArmTarget();
+    auto JTMB = JITTargetMachineBuilder(llvm::Triple(TargetTriple));
+    JTMB.setCPU("cortex-m0plus");
+    auto JB = std::make_unique<LLJITBuilder>();
+    JB->setJITTargetMachineBuilder(JTMB);
+    JB->setPlatformSetUp(setUpInactivePlatform);
+    JB->setNotifyCreatedCallback([&](LLJIT &J) {
+      ObjectLayer &ObjLayer = J.getObjLinkingLayer();
+      auto *JITLinkObjLayer = llvm::cast<ObjectLinkingLayer>(&ObjLayer);
+      JITLinkObjLayer->setReturnObjectBuffer(
+          [&Objs](std::unique_ptr<llvm::MemoryBuffer> MB) {
+            Objs.push_back(std::move(MB));
+          });
+      JIT = &J;
+      return llvm::Error::success();
+    });
+    return JB;
+  });
+
+  EXPECT_EQ(0U, Objs.size());
+  cantFail(Interp.CreateExecutor());
+  cantFail(Interp.ParseAndExecute("int a = 1;"));
+  ExecutorAddr Addr = cantFail(JIT->lookup("a"));
+  EXPECT_NE(0U, Addr.getValue());
+  EXPECT_EQ(1U, Objs.size());
+}
+
 } // end anonymous namespace



More information about the cfe-commits mailing list