[Mlir-commits] [mlir] d9e04b0 - [mlir][CAPI] Expose the rest of MLIRContext's constructors

Krzysztof Drewniak llvmlistbot at llvm.org
Mon Jul 10 13:17:25 PDT 2023


Author: Krzysztof Drewniak
Date: 2023-07-10T20:17:21Z
New Revision: d9e04b0626aef6269abd6328f6a189c313eacbba

URL: https://github.com/llvm/llvm-project/commit/d9e04b0626aef6269abd6328f6a189c313eacbba
DIFF: https://github.com/llvm/llvm-project/commit/d9e04b0626aef6269abd6328f6a189c313eacbba.diff

LOG: [mlir][CAPI] Expose the rest of MLIRContext's constructors

It's recommended practice that people calling MLIR in a loop
pre-create an LLVM ThreadPool and a dialect registry and then
explicitly pass those into an MLIRContext for each compilation.
However, the C API does not expose the functions needed to follow this
recommendation from a project that isn't calling MLIR's C++ API directly.

Add the necessary APIs to mlir-c, including a wrapper around LLVM's
ThreadPool class (so as to avoid having to amend or re-export parts
of the LLVM API).

Reviewed By: makslevental

Differential Revision: https://reviews.llvm.org/D153593

Added: 
    

Modified: 
    mlir/include/mlir-c/IR.h
    mlir/include/mlir-c/Support.h
    mlir/include/mlir/CAPI/Support.h
    mlir/lib/CAPI/IR/IR.cpp
    mlir/lib/CAPI/IR/Support.cpp
    mlir/test/CAPI/ir.c

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 6b5d8cc4b8c033..26f7f0738b8bf1 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -84,8 +84,19 @@ typedef struct MlirNamedAttribute MlirNamedAttribute;
 //===----------------------------------------------------------------------===//
 
 /// Creates an MLIR context and transfers its ownership to the caller.
+/// This sets the default multithreading option (enabled).
 MLIR_CAPI_EXPORTED MlirContext mlirContextCreate(void);
 
+/// Creates an MLIR context with multithreading enabled or disabled according
+/// to `threadingEnabled` and transfers its ownership to the caller.
+MLIR_CAPI_EXPORTED MlirContext
+mlirContextCreateWithThreading(bool threadingEnabled);
+
+/// Creates an MLIR context with the given multithreading setting and the
+/// dialects of `registry` registered (not loaded). The caller owns the result.
+MLIR_CAPI_EXPORTED MlirContext mlirContextCreateWithRegistry(
+    MlirDialectRegistry registry, bool threadingEnabled);
+
 /// Checks if two contexts are equal.
 MLIR_CAPI_EXPORTED bool mlirContextEqual(MlirContext ctx1, MlirContext ctx2);
 
@@ -144,6 +155,13 @@ mlirContextLoadAllAvailableDialects(MlirContext context);
 MLIR_CAPI_EXPORTED bool mlirContextIsRegisteredOperation(MlirContext context,
                                                          MlirStringRef name);
 
+/// Sets the thread pool of the context explicitly, enabling multithreading in
+/// the process. The context must have multithreading disabled beforehand and
+/// the pool must outlive it. Use this to reuse one pool across compilations in
+/// long-running applications; see the C++ docs of MLIRContext for details.
+MLIR_CAPI_EXPORTED void mlirContextSetThreadPool(MlirContext context,
+                                                 MlirLlvmThreadPool threadPool);
+
 //===----------------------------------------------------------------------===//
 // Dialect API.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir-c/Support.h b/mlir/include/mlir-c/Support.h
index 8d0188e319672d..78fc94f93439ec 100644
--- a/mlir/include/mlir-c/Support.h
+++ b/mlir/include/mlir-c/Support.h
@@ -56,6 +56,8 @@ extern "C" {
   };                                                                           \
   typedef struct name name
 
+/// Opaque handle to an llvm::ThreadPool, so C users need not include LLVM.
+DEFINE_C_API_STRUCT(MlirLlvmThreadPool, void);
 DEFINE_C_API_STRUCT(MlirTypeID, const void);
 DEFINE_C_API_STRUCT(MlirTypeIDAllocator, void);
 
@@ -138,6 +140,17 @@ inline static MlirLogicalResult mlirLogicalResultFailure(void) {
   return res;
 }
 
+//===----------------------------------------------------------------------===//
+// MlirLlvmThreadPool.
+//===----------------------------------------------------------------------===//
+
+/// Creates an LLVM thread pool owned by the caller. Exposed here so clients
+/// need not pull in the LLVM headers; release with mlirLlvmThreadPoolDestroy.
+MLIR_CAPI_EXPORTED MlirLlvmThreadPool mlirLlvmThreadPoolCreate(void);
+
+/// Destroys an LLVM thread pool. Contexts using it must be destroyed first.
+MLIR_CAPI_EXPORTED void mlirLlvmThreadPoolDestroy(MlirLlvmThreadPool pool);
+
 //===----------------------------------------------------------------------===//
 // TypeID API.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/CAPI/Support.h b/mlir/include/mlir/CAPI/Support.h
index f3e8a67e0ac360..82aa05185858e3 100644
--- a/mlir/include/mlir/CAPI/Support.h
+++ b/mlir/include/mlir/CAPI/Support.h
@@ -21,6 +21,10 @@
 #include "mlir/Support/TypeID.h"
 #include "llvm/ADT/StringRef.h"
 
+namespace llvm {
+class ThreadPool;
+} // namespace llvm
+
 /// Converts a StringRef into its MLIR C API equivalent.
 inline MlirStringRef wrap(llvm::StringRef ref) {
   return mlirStringRefCreate(ref.data(), ref.size());
@@ -41,6 +45,7 @@ inline mlir::LogicalResult unwrap(MlirLogicalResult res) {
   return mlir::success(mlirLogicalResultIsSuccess(res));
 }
 
+DEFINE_C_API_PTR_METHODS(MlirLlvmThreadPool, llvm::ThreadPool)
 DEFINE_C_API_METHODS(MlirTypeID, mlir::TypeID)
 DEFINE_C_API_PTR_METHODS(MlirTypeIDAllocator, mlir::TypeIDAllocator)
 

diff  --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index 16b333afc102d7..8c3ea09e9fb1a9 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -39,6 +39,23 @@ MlirContext mlirContextCreate() {
   return wrap(context);
 }
 
+static inline MLIRContext::Threading toThreadingEnum(bool threadingEnabled) {
+  return threadingEnabled ? MLIRContext::Threading::ENABLED
+                          : MLIRContext::Threading::DISABLED;
+}
+
+MlirContext mlirContextCreateWithThreading(bool threadingEnabled) {
+  auto *context = new MLIRContext(toThreadingEnum(threadingEnabled));
+  return wrap(context);
+}
+
+MlirContext mlirContextCreateWithRegistry(MlirDialectRegistry registry,
+                                          bool threadingEnabled) {
+  auto *context =
+      new MLIRContext(*unwrap(registry), toThreadingEnum(threadingEnabled));
+  return wrap(context);
+}
+
 bool mlirContextEqual(MlirContext ctx1, MlirContext ctx2) {
   return unwrap(ctx1) == unwrap(ctx2);
 }
@@ -84,6 +101,11 @@ void mlirContextLoadAllAvailableDialects(MlirContext context) {
   unwrap(context)->loadAllAvailableDialects();
 }
 
+void mlirContextSetThreadPool(MlirContext context,
+                              MlirLlvmThreadPool threadPool) {
+  unwrap(context)->setThreadPool(*unwrap(threadPool));
+}
+
 //===----------------------------------------------------------------------===//
 // Dialect API.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/CAPI/IR/Support.cpp b/mlir/lib/CAPI/IR/Support.cpp
index ea081b2e99b59a..81c9fc77192640 100644
--- a/mlir/lib/CAPI/IR/Support.cpp
+++ b/mlir/lib/CAPI/IR/Support.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/CAPI/Support.h"
 #include "llvm/ADT/StringRef.h"
+#include "llvm/Support/ThreadPool.h"
 
 #include <cstring>
 
@@ -20,6 +21,17 @@ bool mlirStringRefEqual(MlirStringRef string, MlirStringRef other) {
          llvm::StringRef(other.data, other.length);
 }
 
+//===----------------------------------------------------------------------===//
+// LLVM ThreadPool API (handles own a heap-allocated llvm::ThreadPool).
+//===----------------------------------------------------------------------===//
+MlirLlvmThreadPool mlirLlvmThreadPoolCreate() {
+  return wrap(new llvm::ThreadPool());
+}
+
+void mlirLlvmThreadPoolDestroy(MlirLlvmThreadPool threadPool) {
+  delete unwrap(threadPool);
+}
+
 //===----------------------------------------------------------------------===//
 // TypeID API.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index 00f2e8d749d513..d388fcc45e31e1 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -2210,6 +2210,18 @@ int testDialectRegistry(void) {
   return 0;
 }
 
+void testExplicitThreadPools(void) {
+  MlirLlvmThreadPool threadPool = mlirLlvmThreadPoolCreate();
+  MlirDialectRegistry registry = mlirDialectRegistryCreate();
+  mlirRegisterAllDialects(registry);
+  MlirContext context =
+      mlirContextCreateWithRegistry(registry, /*threadingEnabled=*/false);
+  mlirContextSetThreadPool(context, threadPool);
+  mlirContextDestroy(context);
+  mlirDialectRegistryDestroy(registry);
+  mlirLlvmThreadPoolDestroy(threadPool);
+}
+
 void testDiagnostics(void) {
   MlirContext ctx = mlirContextCreate();
   MlirDiagnosticHandlerID id = mlirContextAttachDiagnosticHandler(
@@ -2310,6 +2322,7 @@ int main(void) {
 
   mlirContextDestroy(ctx);
 
+  testExplicitThreadPools();
   testDiagnostics();
   return 0;
 }


        


More information about the Mlir-commits mailing list