[Mlir-commits] [mlir] 2387fad - [mlir][capi] Add external pass creation to MLIR C-API
Daniel Resnick
llvmlistbot at llvm.org
Mon Apr 4 09:29:42 PDT 2022
Author: Daniel Resnick
Date: 2022-04-04T10:27:11-06:00
New Revision: 2387fadea3a807ba59993a23529035c13f478002
URL: https://github.com/llvm/llvm-project/commit/2387fadea3a807ba59993a23529035c13f478002
DIFF: https://github.com/llvm/llvm-project/commit/2387fadea3a807ba59993a23529035c13f478002.diff
LOG: [mlir][capi] Add external pass creation to MLIR C-API
Adds the ability to create external passes using the C-API. This allows passes
to be written in C or languages that use the C-bindings.
Differential Revision: https://reviews.llvm.org/D121866
Added:
Modified:
mlir/include/mlir-c/IR.h
mlir/include/mlir-c/Pass.h
mlir/include/mlir-c/Support.h
mlir/include/mlir/CAPI/IR.h
mlir/include/mlir/CAPI/Support.h
mlir/lib/CAPI/IR/IR.cpp
mlir/lib/CAPI/IR/Pass.cpp
mlir/lib/CAPI/IR/Support.cpp
mlir/test/CAPI/CMakeLists.txt
mlir/test/CAPI/pass.c
Removed:
################################################################################
diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index d999554664d96..1c0517f5918f1 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -62,7 +62,6 @@ DEFINE_C_API_STRUCT(MlirIdentifier, const void);
DEFINE_C_API_STRUCT(MlirLocation, const void);
DEFINE_C_API_STRUCT(MlirModule, const void);
DEFINE_C_API_STRUCT(MlirType, const void);
-DEFINE_C_API_STRUCT(MlirTypeID, const void);
DEFINE_C_API_STRUCT(MlirValue, const void);
#undef DEFINE_C_API_STRUCT
@@ -757,19 +756,6 @@ MLIR_CAPI_EXPORTED bool mlirIdentifierEqual(MlirIdentifier ident,
/// Gets the string value of the identifier.
MLIR_CAPI_EXPORTED MlirStringRef mlirIdentifierStr(MlirIdentifier ident);
-//===----------------------------------------------------------------------===//
-// TypeID API.
-//===----------------------------------------------------------------------===//
-
-/// Checks whether a type id is null.
-static inline bool mlirTypeIDIsNull(MlirTypeID typeID) { return !typeID.ptr; }
-
-/// Checks if two type ids are equal.
-MLIR_CAPI_EXPORTED bool mlirTypeIDEqual(MlirTypeID typeID1, MlirTypeID typeID2);
-
-/// Returns the hash value of the type id.
-MLIR_CAPI_EXPORTED size_t mlirTypeIDHashValue(MlirTypeID typeID);
-
//===----------------------------------------------------------------------===//
// Symbol and SymbolTable API.
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir-c/Pass.h b/mlir/include/mlir-c/Pass.h
index d8b2168127f90..cdb947bdeb2fd 100644
--- a/mlir/include/mlir-c/Pass.h
+++ b/mlir/include/mlir-c/Pass.h
@@ -15,6 +15,7 @@
#define MLIR_C_PASS_H
#include "mlir-c/IR.h"
+#include "mlir-c/Registration.h"
#include "mlir-c/Support.h"
#ifdef __cplusplus
@@ -41,11 +42,16 @@ extern "C" {
typedef struct name name
DEFINE_C_API_STRUCT(MlirPass, void);
+DEFINE_C_API_STRUCT(MlirExternalPass, void);
DEFINE_C_API_STRUCT(MlirPassManager, void);
DEFINE_C_API_STRUCT(MlirOpPassManager, void);
#undef DEFINE_C_API_STRUCT
+//===----------------------------------------------------------------------===//
+// PassManager/OpPassManager APIs.
+//===----------------------------------------------------------------------===//
+
/// Create a new top-level PassManager.
MLIR_CAPI_EXPORTED MlirPassManager mlirPassManagerCreate(MlirContext ctx);
@@ -112,6 +118,55 @@ MLIR_CAPI_EXPORTED void mlirPrintPassPipeline(MlirOpPassManager passManager,
MLIR_CAPI_EXPORTED MlirLogicalResult
mlirParsePassPipeline(MlirOpPassManager passManager, MlirStringRef pipeline);
+//===----------------------------------------------------------------------===//
+// External Pass API.
+//
+// This API allows passes to be defined outside of MLIR, not necessarily in
+// C++, and registered with the MLIR pass management infrastructure.
+//
+//===----------------------------------------------------------------------===//
+
+/// Structure of external `MlirPass` callbacks.
+/// All callbacks are required to be set unless otherwise specified.
+/// Each callback receives the `userData` of the pass instance it is invoked
+/// for; a cloned pass uses the pointer returned by `clone` as its `userData`.
+struct MlirExternalPassCallbacks {
+ /// This callback is called when the pass is created, including when a pass
+ /// is created as a clone of another one.
+ /// This is analogous to a C++ pass constructor.
+ void (*construct)(void *userData);
+
+ /// This callback is called when the pass is destroyed.
+ /// This is analogous to a C++ pass destructor.
+ void (*destruct)(void *userData);
+
+ /// This callback is optional and may be NULL.
+ /// The callback is called before the pass is run, allowing a chance to
+ /// initialize any complex state necessary for running the pass. Returning
+ /// a failure result causes the pass manager run to fail.
+ /// See Pass::initialize(MLIRContext *).
+ MlirLogicalResult (*initialize)(MlirContext ctx, void *userData);
+
+ /// This callback is called when the pass is cloned. It returns the
+ /// `userData` to use for the clone, which is then passed to `construct`.
+ /// See Pass::clonePass().
+ void *(*clone)(void *userData);
+
+ /// This callback is called when the pass is run on `op`. Call
+ /// `mlirExternalPassSignalFailure(pass)` from here to report a failure.
+ /// See Pass::runOnOperation().
+ void (*run)(MlirOperation op, MlirExternalPass pass, void *userData);
+};
+typedef struct MlirExternalPassCallbacks MlirExternalPassCallbacks;
+
+/// Creates an external `MlirPass` that calls the supplied `callbacks` using the
+/// supplied `userData`. If `opName` is empty, the pass is a generic operation
+/// pass. Otherwise it is an operation pass specific to the specified operation
+/// name. `name`, `argument`, `description` and the first `nDependentDialects`
+/// entries of `dependentDialects` are copied into the pass; `dependentDialects`
+/// may be NULL when `nDependentDialects` is 0. The `construct` callback is
+/// invoked with `userData` before this function returns.
+MLIR_CAPI_EXPORTED MlirPass mlirCreateExternalPass(
+ MlirTypeID passID, MlirStringRef name, MlirStringRef argument,
+ MlirStringRef description, MlirStringRef opName,
+ intptr_t nDependentDialects, MlirDialectHandle *dependentDialects,
+ MlirExternalPassCallbacks callbacks, void *userData);
+
+/// This signals that the pass has failed. This is only valid to call during
+/// the `run` callback of `MlirExternalPassCallbacks`.
+/// See Pass::signalPassFailure().
+MLIR_CAPI_EXPORTED void mlirExternalPassSignalFailure(MlirExternalPass pass);
+
#ifdef __cplusplus
}
#endif
diff --git a/mlir/include/mlir-c/Support.h b/mlir/include/mlir-c/Support.h
index f20e58fe62317..5d20fb78d5dc2 100644
--- a/mlir/include/mlir-c/Support.h
+++ b/mlir/include/mlir-c/Support.h
@@ -50,6 +50,17 @@
extern "C" {
#endif
+/// Declares an opaque C handle type `typeName` holding a single pointer to
+/// `storageType`.
+#define DEFINE_C_API_STRUCT(typeName, storageType) \
+ struct typeName { \
+ storageType *ptr; \
+ }; \
+ typedef struct typeName typeName
+
+/// Opaque handle for a type id.
+DEFINE_C_API_STRUCT(MlirTypeID, const void);
+/// Opaque handle for a type id allocator.
+DEFINE_C_API_STRUCT(MlirTypeIDAllocator, void);
+
+#undef DEFINE_C_API_STRUCT
+
//===----------------------------------------------------------------------===//
// MlirStringRef.
//===----------------------------------------------------------------------===//
@@ -127,6 +138,38 @@ inline static MlirLogicalResult mlirLogicalResultFailure() {
return res;
}
+//===----------------------------------------------------------------------===//
+// TypeID API.
+//===----------------------------------------------------------------------===//
+
+/// Creates a type id from the opaque pointer `ptr`. `ptr` must be 8-byte
+/// aligned (checked by an assertion in debug builds) and must uniquely
+/// identify a single type for as long as the returned type id is in use.
+MLIR_CAPI_EXPORTED MlirTypeID mlirTypeIDCreate(const void *ptr);
+
+/// Checks whether a type id is null.
+static inline bool mlirTypeIDIsNull(MlirTypeID typeID) { return !typeID.ptr; }
+
+/// Checks if two type ids are equal.
+MLIR_CAPI_EXPORTED bool mlirTypeIDEqual(MlirTypeID typeID1, MlirTypeID typeID2);
+
+/// Returns the hash value of the type id.
+MLIR_CAPI_EXPORTED size_t mlirTypeIDHashValue(MlirTypeID typeID);
+
+//===----------------------------------------------------------------------===//
+// TypeIDAllocator API.
+//===----------------------------------------------------------------------===//
+
+/// Creates a type id allocator for dynamic type id creation. Release it with
+/// `mlirTypeIDAllocatorDestroy`.
+// `(void)` rather than `()`: in C an empty parameter list is not a prototype.
+MLIR_CAPI_EXPORTED MlirTypeIDAllocator mlirTypeIDAllocatorCreate(void);
+
+/// Deallocates the allocator and all type ids allocated from it.
+MLIR_CAPI_EXPORTED void
+mlirTypeIDAllocatorDestroy(MlirTypeIDAllocator allocator);
+
+/// Allocates a type id that is valid for the lifetime of the allocator.
+MLIR_CAPI_EXPORTED MlirTypeID
+mlirTypeIDAllocatorAllocateTypeID(MlirTypeIDAllocator allocator);
+
#ifdef __cplusplus
}
#endif
diff --git a/mlir/include/mlir/CAPI/IR.h b/mlir/include/mlir/CAPI/IR.h
index 06cf7762a9c0e..899b411670747 100644
--- a/mlir/include/mlir/CAPI/IR.h
+++ b/mlir/include/mlir/CAPI/IR.h
@@ -34,7 +34,6 @@ DEFINE_C_API_METHODS(MlirIdentifier, mlir::StringAttr)
DEFINE_C_API_METHODS(MlirLocation, mlir::Location)
DEFINE_C_API_METHODS(MlirModule, mlir::ModuleOp)
DEFINE_C_API_METHODS(MlirType, mlir::Type)
-DEFINE_C_API_METHODS(MlirTypeID, mlir::TypeID)
DEFINE_C_API_METHODS(MlirValue, mlir::Value)
#endif // MLIR_CAPI_IR_H
diff --git a/mlir/include/mlir/CAPI/Support.h b/mlir/include/mlir/CAPI/Support.h
index 6d9a59abf111f..f3e8a67e0ac36 100644
--- a/mlir/include/mlir/CAPI/Support.h
+++ b/mlir/include/mlir/CAPI/Support.h
@@ -16,7 +16,9 @@
#define MLIR_CAPI_SUPPORT_H
#include "mlir-c/Support.h"
+#include "mlir/CAPI/Wrap.h"
#include "mlir/Support/LogicalResult.h"
+#include "mlir/Support/TypeID.h"
#include "llvm/ADT/StringRef.h"
/// Converts a StringRef into its MLIR C API equivalent.
@@ -39,4 +41,7 @@ inline mlir::LogicalResult unwrap(MlirLogicalResult res) {
return mlir::success(mlirLogicalResultIsSuccess(res));
}
+// wrap()/unwrap() conversions between the C handles and their C++ types:
+// MlirTypeID converts to/from an mlir::TypeID value, while
+// MlirTypeIDAllocator converts to/from an mlir::TypeIDAllocator pointer.
+DEFINE_C_API_METHODS(MlirTypeID, mlir::TypeID)
+DEFINE_C_API_PTR_METHODS(MlirTypeIDAllocator, mlir::TypeIDAllocator)
+
#endif // MLIR_CAPI_SUPPORT_H
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index 75ac93ffb6897..527aa4eaf9e89 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -787,18 +787,6 @@ MlirStringRef mlirIdentifierStr(MlirIdentifier ident) {
return wrap(unwrap(ident).strref());
}
-//===----------------------------------------------------------------------===//
-// TypeID API.
-//===----------------------------------------------------------------------===//
-
-bool mlirTypeIDEqual(MlirTypeID typeID1, MlirTypeID typeID2) {
- return unwrap(typeID1) == unwrap(typeID2);
-}
-
-size_t mlirTypeIDHashValue(MlirTypeID typeID) {
- return hash_value(unwrap(typeID));
-}
-
//===----------------------------------------------------------------------===//
// Symbol and SymbolTable API.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/CAPI/IR/Pass.cpp b/mlir/lib/CAPI/IR/Pass.cpp
index 4bfc9d0132b37..a2998939a1586 100644
--- a/mlir/lib/CAPI/IR/Pass.cpp
+++ b/mlir/lib/CAPI/IR/Pass.cpp
@@ -77,3 +77,94 @@ MlirLogicalResult mlirParsePassPipeline(MlirOpPassManager passManager,
// stream and redirect to a diagnostic.
return wrap(mlir::parsePassPipeline(unwrap(pipeline), *unwrap(passManager)));
}
+
+//===----------------------------------------------------------------------===//
+// External Pass API.
+//===----------------------------------------------------------------------===//
+
+// Forward-declare ExternalPass so the MlirExternalPass wrap()/unwrap()
+// helpers exist before the class body, which wraps `this` in
+// runOnOperation().
+namespace mlir {
+class ExternalPass;
+} // namespace mlir
+DEFINE_C_API_PTR_METHODS(MlirExternalPass, mlir::ExternalPass)
+
+namespace mlir {
+/// This pass class wraps external passes defined in other languages using the
+/// MLIR C-interface. Each pass hook forwards to the matching callback in
+/// `MlirExternalPassCallbacks`, passing along this instance's `userData`.
+class ExternalPass : public Pass {
+public:
+ /// Copies the pass metadata and dependent dialect handles, then invokes the
+ /// user's `construct` callback with `userData`.
+ ExternalPass(TypeID passID, StringRef name, StringRef argument,
+ StringRef description, Optional<StringRef> opName,
+ ArrayRef<MlirDialectHandle> dependentDialects,
+ MlirExternalPassCallbacks callbacks, void *userData)
+ : Pass(passID, opName), id(passID), name(name), argument(argument),
+ description(description), dependentDialects(dependentDialects),
+ callbacks(callbacks), userData(userData) {
+ callbacks.construct(userData);
+ }
+
+ /// Invokes the user's `destruct` callback.
+ ~ExternalPass() override { callbacks.destruct(userData); }
+
+ StringRef getName() const override { return name; }
+ StringRef getArgument() const override { return argument; }
+ StringRef getDescription() const override { return description; }
+
+ /// Inserts every dependent dialect handle given at construction into
+ /// `registry`.
+ void getDependentDialects(DialectRegistry &registry) const override {
+ MlirDialectRegistry cRegistry = wrap(&registry);
+ for (MlirDialectHandle dialect : dependentDialects)
+ mlirDialectHandleInsertDialect(dialect, cRegistry);
+ }
+
+ /// Exposes Pass::signalPassFailure() for mlirExternalPassSignalFailure.
+ void signalPassFailure() { Pass::signalPassFailure(); }
+
+protected:
+ /// Forwards to the optional `initialize` callback; succeeds when unset.
+ LogicalResult initialize(MLIRContext *ctx) override {
+ if (callbacks.initialize)
+ return unwrap(callbacks.initialize(wrap(ctx), userData));
+ return success();
+ }
+
+ /// A generic pass (no op name) can run on any operation; otherwise only on
+ /// operations whose name matches the requested one.
+ bool canScheduleOn(RegisteredOperationName opName) const override {
+ if (Optional<StringRef> specifiedOpName = getOpName())
+ return opName.getStringRef() == specifiedOpName;
+ return true;
+ }
+
+ /// Forwards to the `run` callback with the current operation and a handle
+ /// to this pass, so the callback can signal failure.
+ void runOnOperation() override {
+ callbacks.run(wrap(getOperation()), wrap(this), userData);
+ }
+
+ /// Asks the user to clone its state via `clone`, then builds a new pass
+ /// around the returned user data (which runs `construct` on it).
+ std::unique_ptr<Pass> clonePass() const override {
+ void *clonedUserData = callbacks.clone(userData);
+ return std::make_unique<ExternalPass>(id, name, argument, description,
+ getOpName(), dependentDialects,
+ callbacks, clonedUserData);
+ }
+
+private:
+ // Kept so clonePass() can construct an identical pass.
+ TypeID id;
+ // Owned copies of the strings supplied through the C API.
+ std::string name;
+ std::string argument;
+ std::string description;
+ std::vector<MlirDialectHandle> dependentDialects;
+ MlirExternalPassCallbacks callbacks;
+ void *userData;
+};
+} // namespace mlir
+
+MlirPass mlirCreateExternalPass(MlirTypeID passID, MlirStringRef name,
+ MlirStringRef argument,
+ MlirStringRef description, MlirStringRef opName,
+ intptr_t nDependentDialects,
+ MlirDialectHandle *dependentDialects,
+ MlirExternalPassCallbacks callbacks,
+ void *userData) {
+ // An empty op name requests a generic pass that may run on any operation.
+ Optional<StringRef> anchorOpName = None;
+ if (opName.length > 0)
+ anchorOpName = unwrap(opName);
+
+ mlir::Pass *pass = new mlir::ExternalPass(
+ unwrap(passID), unwrap(name), unwrap(argument), unwrap(description),
+ anchorOpName,
+ {dependentDialects, static_cast<size_t>(nDependentDialects)}, callbacks,
+ userData);
+ return wrap(pass);
+}
+
+void mlirExternalPassSignalFailure(MlirExternalPass pass) {
+ mlir::ExternalPass *externalPass = unwrap(pass);
+ externalPass->signalPassFailure();
+}
diff --git a/mlir/lib/CAPI/IR/Support.cpp b/mlir/lib/CAPI/IR/Support.cpp
index b6e1f9180c771..cbfbb54769aa9 100644
--- a/mlir/lib/CAPI/IR/Support.cpp
+++ b/mlir/lib/CAPI/IR/Support.cpp
@@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir-c/Support.h"
+#include "mlir/CAPI/Support.h"
#include "llvm/ADT/StringRef.h"
#include <cstring>
@@ -19,3 +19,40 @@ bool mlirStringRefEqual(MlirStringRef string, MlirStringRef other) {
return llvm::StringRef(string.data, string.length) ==
llvm::StringRef(other.data, other.length);
}
+
+//===----------------------------------------------------------------------===//
+// TypeID API.
+//===----------------------------------------------------------------------===//
+
+MlirTypeID mlirTypeIDCreate(const void *ptr) {
+ // The C API contract requires an 8-byte aligned pointer; enforce it in
+ // debug builds.
+ assert((reinterpret_cast<uintptr_t>(ptr) & 0x7) == 0 &&
+ "ptr must be 8 byte aligned");
+ // Going through mlir::TypeID instead of wrapping `ptr` directly makes any
+ // change to TypeID's api/representation show up as a compile error here.
+ mlir::TypeID typeID = mlir::TypeID::getFromOpaquePointer(ptr);
+ return wrap(typeID);
+}
+
+bool mlirTypeIDEqual(MlirTypeID typeID1, MlirTypeID typeID2) {
+ mlir::TypeID lhs = unwrap(typeID1);
+ mlir::TypeID rhs = unwrap(typeID2);
+ return lhs == rhs;
+}
+
+size_t mlirTypeIDHashValue(MlirTypeID typeID) {
+ mlir::TypeID id = unwrap(typeID);
+ return hash_value(id);
+}
+
+//===----------------------------------------------------------------------===//
+// TypeIDAllocator API.
+//===----------------------------------------------------------------------===//
+
+MlirTypeIDAllocator mlirTypeIDAllocatorCreate() {
+ auto *allocator = new mlir::TypeIDAllocator();
+ return wrap(allocator);
+}
+
+void mlirTypeIDAllocatorDestroy(MlirTypeIDAllocator allocator) {
+ mlir::TypeIDAllocator *impl = unwrap(allocator);
+ delete impl;
+}
+
+MlirTypeID mlirTypeIDAllocatorAllocateTypeID(MlirTypeIDAllocator allocator) {
+ mlir::TypeIDAllocator &impl = *unwrap(allocator);
+ return wrap(impl.allocate());
+}
diff --git a/mlir/test/CAPI/CMakeLists.txt b/mlir/test/CAPI/CMakeLists.txt
index 8a7a567f2a775..a4e2f024dd122 100644
--- a/mlir/test/CAPI/CMakeLists.txt
+++ b/mlir/test/CAPI/CMakeLists.txt
@@ -49,6 +49,7 @@ _add_capi_test_executable(mlir-capi-llvm-test
_add_capi_test_executable(mlir-capi-pass-test
pass.c
LINK_LIBS PRIVATE
+ MLIRCAPIFunc
MLIRCAPIIR
MLIRCAPIRegistration
MLIRCAPITransforms
diff --git a/mlir/test/CAPI/pass.c b/mlir/test/CAPI/pass.c
index 69d052c3131a3..aeabac4fd844d 100644
--- a/mlir/test/CAPI/pass.c
+++ b/mlir/test/CAPI/pass.c
@@ -11,6 +11,7 @@
*/
#include "mlir-c/Pass.h"
+#include "mlir-c/Dialect/Func.h"
#include "mlir-c/IR.h"
#include "mlir-c/Registration.h"
#include "mlir-c/Transforms.h"
@@ -169,7 +170,9 @@ void testParsePassPipeline() {
" func.func(print-op-stats))"));
// Expect a failure, we haven't registered the print-op-stats pass yet.
if (mlirLogicalResultIsSuccess(status)) {
- fprintf(stderr, "Unexpected success parsing pipeline without registering the pass\n");
+ fprintf(
+ stderr,
+ "Unexpected success parsing pipeline without registering the pass\n");
exit(EXIT_FAILURE);
}
// Try again after registrating the pass.
@@ -180,7 +183,8 @@ void testParsePassPipeline() {
" func.func(print-op-stats))"));
// Expect a failure, we haven't registered the print-op-stats pass yet.
if (mlirLogicalResultIsFailure(status)) {
- fprintf(stderr, "Unexpected failure parsing pipeline after registering the pass\n");
+ fprintf(stderr,
+ "Unexpected failure parsing pipeline after registering the pass\n");
exit(EXIT_FAILURE);
}
@@ -194,10 +198,328 @@ void testParsePassPipeline() {
mlirContextDestroy(ctx);
}
+/// Counters recording how many times each external pass callback has been
+/// invoked for one test pass (shared by the pass and all of its clones).
+typedef struct TestExternalPassUserData {
+ int constructCallCount;
+ int destructCallCount;
+ int initializeCallCount;
+ int cloneCallCount;
+ int runCallCount;
+} TestExternalPassUserData;
+
+/// `construct` callback: counts one more pass construction.
+void testConstructExternalPass(void *userData) {
+ TestExternalPassUserData *counts = (TestExternalPassUserData *)userData;
+ counts->constructCallCount += 1;
+}
+
+/// `destruct` callback: counts one more pass destruction.
+void testDestructExternalPass(void *userData) {
+ TestExternalPassUserData *counts = (TestExternalPassUserData *)userData;
+ counts->destructCallCount += 1;
+}
+
+/// `initialize` callback that counts the call and always succeeds.
+MlirLogicalResult testInitializeExternalPass(MlirContext ctx, void *userData) {
+ TestExternalPassUserData *counts = (TestExternalPassUserData *)userData;
+ counts->initializeCallCount += 1;
+ return mlirLogicalResultSuccess();
+}
+
+/// `initialize` callback that counts the call and always fails.
+MlirLogicalResult testInitializeFailingExternalPass(MlirContext ctx,
+ void *userData) {
+ TestExternalPassUserData *counts = (TestExternalPassUserData *)userData;
+ counts->initializeCallCount += 1;
+ return mlirLogicalResultFailure();
+}
+
+/// `clone` callback: counts the call; the clone shares the same counters.
+void *testCloneExternalPass(void *userData) {
+ TestExternalPassUserData *counts = (TestExternalPassUserData *)userData;
+ counts->cloneCallCount += 1;
+ return userData;
+}
+
+/// `run` callback that only counts the call.
+void testRunExternalPass(MlirOperation op, MlirExternalPass pass,
+ void *userData) {
+ TestExternalPassUserData *counts = (TestExternalPassUserData *)userData;
+ counts->runCallCount += 1;
+}
+
+/// `run` callback that counts the call and signals failure unless it was
+/// scheduled on a `func.func` operation.
+void testRunExternalFuncPass(MlirOperation op, MlirExternalPass pass,
+ void *userData) {
+ TestExternalPassUserData *counts = (TestExternalPassUserData *)userData;
+ counts->runCallCount += 1;
+ MlirIdentifier opIdent = mlirOperationGetName(op);
+ MlirStringRef actualName = mlirIdentifierStr(opIdent);
+ MlirStringRef expectedName = mlirStringRefCreateFromCString("func.func");
+ bool isFuncOp = mlirStringRefEqual(actualName, expectedName);
+ if (!isFuncOp) {
+ mlirExternalPassSignalFailure(pass);
+ }
+}
+
+/// `run` callback that counts the call and always signals failure.
+void testRunFailingExternalPass(MlirOperation op, MlirExternalPass pass,
+ void *userData) {
+ TestExternalPassUserData *counts = (TestExternalPassUserData *)userData;
+ counts->runCallCount += 1;
+ mlirExternalPassSignalFailure(pass);
+}
+
+/// Builds a callback table that uses the counting construct/destruct/clone
+/// callbacks together with the given `initializePass` (may be NULL) and
+/// `runPass` callbacks.
+MlirExternalPassCallbacks makeTestExternalPassCallbacks(
+ MlirLogicalResult (*initializePass)(MlirContext ctx, void *userData),
+ void (*runPass)(MlirOperation op, MlirExternalPass, void *userData)) {
+ // Designated initializers keep each callback bound to the right field even
+ // if the members of MlirExternalPassCallbacks are reordered or extended.
+ MlirExternalPassCallbacks callbacks = {
+ .construct = testConstructExternalPass,
+ .destruct = testDestructExternalPass,
+ .initialize = initializePass,
+ .clone = testCloneExternalPass,
+ .run = runPass,
+ };
+ return callbacks;
+}
+
+/// Exercises the external pass C API: a generic pass, a pass nested under
+/// `func.func`, a pass with an `initialize` callback, and passes that fail in
+/// `initialize` and in `run`. For each pass it checks how many times every
+/// callback ran and that constructions and destructions balance once the
+/// pass manager is destroyed. Exits with EXIT_FAILURE on the first mismatch.
+void testExternalPass() {
+ MlirContext ctx = mlirContextCreate();
+ mlirRegisterAllDialects(ctx);
+
+ MlirModule module = mlirModuleCreateParse(
+ ctx,
+ // clang-format off
+ mlirStringRefCreateFromCString(
+"func @foo(%arg0 : i32) -> i32 { \n"
+" %res = arith.addi %arg0, %arg0 : i32 \n"
+" return %res : i32 \n"
+"}"));
+ // clang-format on
+ if (mlirModuleIsNull(module)) {
+ fprintf(stderr, "Unexpected failure parsing module.\n");
+ exit(EXIT_FAILURE);
+ }
+
+ MlirStringRef description = mlirStringRefCreateFromCString("");
+ MlirStringRef emptyOpName = mlirStringRefCreateFromCString("");
+
+ // Each pass below gets its own type id from this allocator; the ids are
+ // only valid while the allocator is alive.
+ MlirTypeIDAllocator typeIDAllocator = mlirTypeIDAllocatorCreate();
+
+ // Run a generic pass
+ {
+ MlirTypeID passID = mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator);
+ MlirStringRef name = mlirStringRefCreateFromCString("TestExternalPass");
+ MlirStringRef argument =
+ mlirStringRefCreateFromCString("test-external-pass");
+ TestExternalPassUserData userData = {0};
+
+ MlirPass externalPass = mlirCreateExternalPass(
+ passID, name, argument, description, emptyOpName, 0, NULL,
+ makeTestExternalPassCallbacks(NULL, testRunExternalPass), &userData);
+
+ if (userData.constructCallCount != 1) {
+ fprintf(stderr, "Expected constructCallCount to be 1\n");
+ exit(EXIT_FAILURE);
+ }
+
+ MlirPassManager pm = mlirPassManagerCreate(ctx);
+ mlirPassManagerAddOwnedPass(pm, externalPass);
+ MlirLogicalResult success = mlirPassManagerRun(pm, module);
+ if (mlirLogicalResultIsFailure(success)) {
+ fprintf(stderr, "Unexpected failure running external pass.\n");
+ exit(EXIT_FAILURE);
+ }
+
+ if (userData.runCallCount != 1) {
+ fprintf(stderr, "Expected runCallCount to be 1\n");
+ exit(EXIT_FAILURE);
+ }
+
+ mlirPassManagerDestroy(pm);
+
+ if (userData.destructCallCount != userData.constructCallCount) {
+ fprintf(stderr, "Expected destructCallCount to be equal to "
+ "constructCallCount\n");
+ exit(EXIT_FAILURE);
+ }
+ }
+
+ // Run a func operation pass
+ {
+ MlirTypeID passID = mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator);
+ MlirStringRef name = mlirStringRefCreateFromCString("TestExternalFuncPass");
+ MlirStringRef argument =
+ mlirStringRefCreateFromCString("test-external-func-pass");
+ TestExternalPassUserData userData = {0};
+ MlirDialectHandle funcHandle = mlirGetDialectHandle__func__();
+ MlirStringRef funcOpName = mlirStringRefCreateFromCString("func.func");
+
+ MlirPass externalPass = mlirCreateExternalPass(
+ passID, name, argument, description, funcOpName, 1, &funcHandle,
+ makeTestExternalPassCallbacks(NULL, testRunExternalFuncPass),
+ &userData);
+
+ if (userData.constructCallCount != 1) {
+ fprintf(stderr, "Expected constructCallCount to be 1\n");
+ exit(EXIT_FAILURE);
+ }
+
+ MlirPassManager pm = mlirPassManagerCreate(ctx);
+ MlirOpPassManager nestedFuncPm =
+ mlirPassManagerGetNestedUnder(pm, funcOpName);
+ mlirOpPassManagerAddOwnedPass(nestedFuncPm, externalPass);
+ MlirLogicalResult success = mlirPassManagerRun(pm, module);
+ if (mlirLogicalResultIsFailure(success)) {
+ fprintf(stderr, "Unexpected failure running external operation pass.\n");
+ exit(EXIT_FAILURE);
+ }
+
+ // Since this is a nested pass, it can be cloned and run in parallel; every
+ // construction after the first one must come from a clone.
+ if (userData.cloneCallCount != userData.constructCallCount - 1) {
+ fprintf(stderr, "Expected cloneCallCount to be constructCallCount - 1\n");
+ exit(EXIT_FAILURE);
+ }
+
+ // The pass should only be run once since there is only one func op
+ if (userData.runCallCount != 1) {
+ fprintf(stderr, "Expected runCallCount to be 1\n");
+ exit(EXIT_FAILURE);
+ }
+
+ mlirPassManagerDestroy(pm);
+
+ if (userData.destructCallCount != userData.constructCallCount) {
+ fprintf(stderr, "Expected destructCallCount to be equal to "
+ "constructCallCount\n");
+ exit(EXIT_FAILURE);
+ }
+ }
+
+ // Run a pass with `initialize` set
+ {
+ MlirTypeID passID = mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator);
+ MlirStringRef name = mlirStringRefCreateFromCString("TestExternalPass");
+ MlirStringRef argument =
+ mlirStringRefCreateFromCString("test-external-pass");
+ TestExternalPassUserData userData = {0};
+
+ MlirPass externalPass = mlirCreateExternalPass(
+ passID, name, argument, description, emptyOpName, 0, NULL,
+ makeTestExternalPassCallbacks(testInitializeExternalPass,
+ testRunExternalPass),
+ &userData);
+
+ if (userData.constructCallCount != 1) {
+ fprintf(stderr, "Expected constructCallCount to be 1\n");
+ exit(EXIT_FAILURE);
+ }
+
+ MlirPassManager pm = mlirPassManagerCreate(ctx);
+ mlirPassManagerAddOwnedPass(pm, externalPass);
+ MlirLogicalResult success = mlirPassManagerRun(pm, module);
+ if (mlirLogicalResultIsFailure(success)) {
+ fprintf(stderr, "Unexpected failure running external pass.\n");
+ exit(EXIT_FAILURE);
+ }
+
+ if (userData.initializeCallCount != 1) {
+ fprintf(stderr, "Expected initializeCallCount to be 1\n");
+ exit(EXIT_FAILURE);
+ }
+
+ if (userData.runCallCount != 1) {
+ fprintf(stderr, "Expected runCallCount to be 1\n");
+ exit(EXIT_FAILURE);
+ }
+
+ mlirPassManagerDestroy(pm);
+
+ if (userData.destructCallCount != userData.constructCallCount) {
+ fprintf(stderr, "Expected destructCallCount to be equal to "
+ "constructCallCount\n");
+ exit(EXIT_FAILURE);
+ }
+ }
+
+ // Run a pass that fails during `initialize`
+ {
+ MlirTypeID passID = mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator);
+ MlirStringRef name =
+ mlirStringRefCreateFromCString("TestExternalFailingPass");
+ MlirStringRef argument =
+ mlirStringRefCreateFromCString("test-external-failing-pass");
+ TestExternalPassUserData userData = {0};
+
+ MlirPass externalPass = mlirCreateExternalPass(
+ passID, name, argument, description, emptyOpName, 0, NULL,
+ makeTestExternalPassCallbacks(testInitializeFailingExternalPass,
+ testRunExternalPass),
+ &userData);
+
+ if (userData.constructCallCount != 1) {
+ fprintf(stderr, "Expected constructCallCount to be 1\n");
+ exit(EXIT_FAILURE);
+ }
+
+ MlirPassManager pm = mlirPassManagerCreate(ctx);
+ mlirPassManagerAddOwnedPass(pm, externalPass);
+ MlirLogicalResult success = mlirPassManagerRun(pm, module);
+ if (mlirLogicalResultIsSuccess(success)) {
+ fprintf(
+ stderr,
+ "Expected failure running pass manager on failing external pass.\n");
+ exit(EXIT_FAILURE);
+ }
+
+ if (userData.initializeCallCount != 1) {
+ fprintf(stderr, "Expected initializeCallCount to be 1\n");
+ exit(EXIT_FAILURE);
+ }
+
+ if (userData.runCallCount != 0) {
+ fprintf(stderr, "Expected runCallCount to be 0\n");
+ exit(EXIT_FAILURE);
+ }
+
+ mlirPassManagerDestroy(pm);
+
+ if (userData.destructCallCount != userData.constructCallCount) {
+ fprintf(stderr, "Expected destructCallCount to be equal to "
+ "constructCallCount\n");
+ exit(EXIT_FAILURE);
+ }
+ }
+
+ // Run a pass that fails during `run`
+ {
+ MlirTypeID passID = mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator);
+ MlirStringRef name =
+ mlirStringRefCreateFromCString("TestExternalFailingPass");
+ MlirStringRef argument =
+ mlirStringRefCreateFromCString("test-external-failing-pass");
+ TestExternalPassUserData userData = {0};
+
+ MlirPass externalPass = mlirCreateExternalPass(
+ passID, name, argument, description, emptyOpName, 0, NULL,
+ makeTestExternalPassCallbacks(NULL, testRunFailingExternalPass),
+ &userData);
+
+ if (userData.constructCallCount != 1) {
+ fprintf(stderr, "Expected constructCallCount to be 1\n");
+ exit(EXIT_FAILURE);
+ }
+
+ MlirPassManager pm = mlirPassManagerCreate(ctx);
+ mlirPassManagerAddOwnedPass(pm, externalPass);
+ MlirLogicalResult success = mlirPassManagerRun(pm, module);
+ if (mlirLogicalResultIsSuccess(success)) {
+ fprintf(
+ stderr,
+ "Expected failure running pass manager on failing external pass.\n");
+ exit(EXIT_FAILURE);
+ }
+
+ if (userData.runCallCount != 1) {
+ fprintf(stderr, "Expected runCallCount to be 1\n");
+ exit(EXIT_FAILURE);
+ }
+
+ mlirPassManagerDestroy(pm);
+
+ if (userData.destructCallCount != userData.constructCallCount) {
+ fprintf(stderr, "Expected destructCallCount to be equal to "
+ "constructCallCount\n");
+ exit(EXIT_FAILURE);
+ }
+ }
+
+ mlirTypeIDAllocatorDestroy(typeIDAllocator);
+ // Destroy the module parsed above before the context it was created in.
+ mlirModuleDestroy(module);
+ mlirContextDestroy(ctx);
+}
+
+// Runs every pass C-API test in sequence and returns 0 once all complete.
int main() {
testRunPassOnModule();
testRunPassOnNestedModule();
testPrintPassPipeline();
testParsePassPipeline();
+ testExternalPass();
return 0;
}
More information about the Mlir-commits
mailing list