[Mlir-commits] [mlir] d9e04b0 - [mlir][CAPI] Expose the rest of MLIRContext's constructors
Krzysztof Drewniak
llvmlistbot at llvm.org
Mon Jul 10 13:17:25 PDT 2023
Author: Krzysztof Drewniak
Date: 2023-07-10T20:17:21Z
New Revision: d9e04b0626aef6269abd6328f6a189c313eacbba
URL: https://github.com/llvm/llvm-project/commit/d9e04b0626aef6269abd6328f6a189c313eacbba
DIFF: https://github.com/llvm/llvm-project/commit/d9e04b0626aef6269abd6328f6a189c313eacbba.diff
LOG: [mlir][CAPI] Expose the rest of MLIRContext's constructors
It's recommended practice that people calling MLIR in a loop
pre-create an LLVM ThreadPool and a dialect registry and then
explicitly pass those into an MLIRContext for each compilation.
However, the C API does not expose the functions needed to follow this
recommendation from a project that isn't calling MLIR's C++ directly.
Add the necessary APIs to mlir-c, including a wrapper around LLVM's
ThreadPool struct (so as to avoid having to amend or re-export parts
of the LLVM API).
Reviewed By: makslevental
Differential Revision: https://reviews.llvm.org/D153593
Added:
Modified:
mlir/include/mlir-c/IR.h
mlir/include/mlir-c/Support.h
mlir/include/mlir/CAPI/Support.h
mlir/lib/CAPI/IR/IR.cpp
mlir/lib/CAPI/IR/Support.cpp
mlir/test/CAPI/ir.c
Removed:
################################################################################
diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 6b5d8cc4b8c033..26f7f0738b8bf1 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -84,8 +84,19 @@ typedef struct MlirNamedAttribute MlirNamedAttribute;
//===----------------------------------------------------------------------===//
/// Creates an MLIR context and transfers its ownership to the caller.
+/// This sets the default multithreading option (enabled).
MLIR_CAPI_EXPORTED MlirContext mlirContextCreate(void);
+/// Creates an MLIR context with multithreading explicitly enabled or disabled
+/// and transfers its ownership to the caller.
+MLIR_CAPI_EXPORTED MlirContext
+mlirContextCreateWithThreading(bool threadingEnabled);
+
+/// Creates an MLIR context with multithreading explicitly enabled or disabled,
+/// registering the dialects from the provided DialectRegistry.
+MLIR_CAPI_EXPORTED MlirContext mlirContextCreateWithRegistry(
+ MlirDialectRegistry registry, bool threadingEnabled);
+
/// Checks if two contexts are equal.
MLIR_CAPI_EXPORTED bool mlirContextEqual(MlirContext ctx1, MlirContext ctx2);
@@ -144,6 +155,13 @@ mlirContextLoadAllAvailableDialects(MlirContext context);
MLIR_CAPI_EXPORTED bool mlirContextIsRegisteredOperation(MlirContext context,
MlirStringRef name);
+/// Sets the thread pool of the context explicitly, enabling multithreading in
+/// the process. The context must have multithreading disabled beforehand, and
+/// the pool must outlive it. Use this to share one pool across compilations in
+/// long-running applications; see the C++ docs for MLIRContext for details.
+MLIR_CAPI_EXPORTED void mlirContextSetThreadPool(MlirContext context,
+ MlirLlvmThreadPool threadPool);
+
//===----------------------------------------------------------------------===//
// Dialect API.
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir-c/Support.h b/mlir/include/mlir-c/Support.h
index 8d0188e319672d..78fc94f93439ec 100644
--- a/mlir/include/mlir-c/Support.h
+++ b/mlir/include/mlir-c/Support.h
@@ -56,6 +56,8 @@ extern "C" {
}; \
typedef struct name name
+/// Re-export llvm::ThreadPool so as to avoid including the LLVM C API directly.
+DEFINE_C_API_STRUCT(MlirLlvmThreadPool, void);
DEFINE_C_API_STRUCT(MlirTypeID, const void);
DEFINE_C_API_STRUCT(MlirTypeIDAllocator, void);
@@ -138,6 +140,17 @@ inline static MlirLogicalResult mlirLogicalResultFailure(void) {
return res;
}
+//===----------------------------------------------------------------------===//
+// MlirLlvmThreadPool.
+//===----------------------------------------------------------------------===//
+
+/// Creates an LLVM thread pool, to be released with mlirLlvmThreadPoolDestroy.
+/// Re-exported here to avoid pulling in the LLVM headers directly.
+MLIR_CAPI_EXPORTED MlirLlvmThreadPool mlirLlvmThreadPoolCreate(void);
+
+/// Destroy an LLVM thread pool.
+MLIR_CAPI_EXPORTED void mlirLlvmThreadPoolDestroy(MlirLlvmThreadPool pool);
+
//===----------------------------------------------------------------------===//
// TypeID API.
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/CAPI/Support.h b/mlir/include/mlir/CAPI/Support.h
index f3e8a67e0ac360..82aa05185858e3 100644
--- a/mlir/include/mlir/CAPI/Support.h
+++ b/mlir/include/mlir/CAPI/Support.h
@@ -21,6 +21,10 @@
#include "mlir/Support/TypeID.h"
#include "llvm/ADT/StringRef.h"
+namespace llvm {
+class ThreadPool;
+} // namespace llvm
+
/// Converts a StringRef into its MLIR C API equivalent.
inline MlirStringRef wrap(llvm::StringRef ref) {
return mlirStringRefCreate(ref.data(), ref.size());
@@ -41,6 +45,7 @@ inline mlir::LogicalResult unwrap(MlirLogicalResult res) {
return mlir::success(mlirLogicalResultIsSuccess(res));
}
+DEFINE_C_API_PTR_METHODS(MlirLlvmThreadPool, llvm::ThreadPool)
DEFINE_C_API_METHODS(MlirTypeID, mlir::TypeID)
DEFINE_C_API_PTR_METHODS(MlirTypeIDAllocator, mlir::TypeIDAllocator)
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index 16b333afc102d7..8c3ea09e9fb1a9 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -39,6 +39,23 @@ MlirContext mlirContextCreate() {
return wrap(context);
}
+static inline MLIRContext::Threading toThreadingEnum(bool threadingEnabled) { // C API bool -> Threading enum.
+ return threadingEnabled ? MLIRContext::Threading::ENABLED
+ : MLIRContext::Threading::DISABLED;
+}
+/// Heap-allocates a context with the requested threading; caller owns it.
+MlirContext mlirContextCreateWithThreading(bool threadingEnabled) {
+ auto *context = new MLIRContext(toThreadingEnum(threadingEnabled));
+ return wrap(context);
+}
+/// Heap-allocates a context built from \p registry; caller owns it.
+MlirContext mlirContextCreateWithRegistry(MlirDialectRegistry registry,
+ bool threadingEnabled) {
+ auto *context =
+ new MLIRContext(*unwrap(registry), toThreadingEnum(threadingEnabled));
+ return wrap(context);
+}
+
bool mlirContextEqual(MlirContext ctx1, MlirContext ctx2) {
return unwrap(ctx1) == unwrap(ctx2);
}
@@ -84,6 +101,11 @@ void mlirContextLoadAllAvailableDialects(MlirContext context) {
unwrap(context)->loadAllAvailableDialects();
}
+void mlirContextSetThreadPool(MlirContext context,
+ MlirLlvmThreadPool threadPool) { // NOTE(review): C++ setThreadPool expects threading disabled first.
+ unwrap(context)->setThreadPool(*unwrap(threadPool)); // Pool is passed by reference; caller keeps ownership.
+}
+
//===----------------------------------------------------------------------===//
// Dialect API.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/CAPI/IR/Support.cpp b/mlir/lib/CAPI/IR/Support.cpp
index ea081b2e99b59a..81c9fc77192640 100644
--- a/mlir/lib/CAPI/IR/Support.cpp
+++ b/mlir/lib/CAPI/IR/Support.cpp
@@ -8,6 +8,7 @@
#include "mlir/CAPI/Support.h"
#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/ThreadPool.h"
#include <cstring>
@@ -20,6 +21,17 @@ bool mlirStringRefEqual(MlirStringRef string, MlirStringRef other) {
llvm::StringRef(other.data, other.length);
}
+//===----------------------------------------------------------------------===//
+// LLVM ThreadPool API.
+//===----------------------------------------------------------------------===//
+MlirLlvmThreadPool mlirLlvmThreadPoolCreate() {
+ return wrap(new llvm::ThreadPool()); // Heap-allocated; free with mlirLlvmThreadPoolDestroy.
+}
+/// Frees a pool obtained from mlirLlvmThreadPoolCreate.
+void mlirLlvmThreadPoolDestroy(MlirLlvmThreadPool threadPool) {
+ delete unwrap(threadPool); // NOTE(review): presumably contexts using it must be destroyed first.
+}
+
//===----------------------------------------------------------------------===//
// TypeID API.
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index 00f2e8d749d513..d388fcc45e31e1 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -2210,6 +2210,18 @@ int testDialectRegistry(void) {
return 0;
}
+void testExplicitThreadPools(void) { // Smoke test: hand an explicitly created pool to a context.
+ MlirLlvmThreadPool threadPool = mlirLlvmThreadPoolCreate();
+ MlirDialectRegistry registry = mlirDialectRegistryCreate();
+ mlirRegisterAllDialects(registry);
+ MlirContext context =
+ mlirContextCreateWithRegistry(registry, /*threadingEnabled=*/false); // Threading starts off; setting the pool turns it on.
+ mlirContextSetThreadPool(context, threadPool);
+ mlirContextDestroy(context); // Tear down the context before the pool it uses.
+ mlirDialectRegistryDestroy(registry);
+ mlirLlvmThreadPoolDestroy(threadPool);
+}
+
void testDiagnostics(void) {
MlirContext ctx = mlirContextCreate();
MlirDiagnosticHandlerID id = mlirContextAttachDiagnosticHandler(
@@ -2310,6 +2322,7 @@ int main(void) {
mlirContextDestroy(ctx);
+ testExplicitThreadPools();
testDiagnostics();
return 0;
}
More information about the Mlir-commits
mailing list