[Mlir-commits] [mlir] 2387fad - [mlir][capi] Add external pass creation to MLIR C-API

Daniel Resnick llvmlistbot at llvm.org
Mon Apr 4 09:29:42 PDT 2022


Author: Daniel Resnick
Date: 2022-04-04T10:27:11-06:00
New Revision: 2387fadea3a807ba59993a23529035c13f478002

URL: https://github.com/llvm/llvm-project/commit/2387fadea3a807ba59993a23529035c13f478002
DIFF: https://github.com/llvm/llvm-project/commit/2387fadea3a807ba59993a23529035c13f478002.diff

LOG: [mlir][capi] Add external pass creation to MLIR C-API

Adds the ability to create external passes using the C-API. This allows passes
to be written in C or languages that use the C-bindings.

Differential Revision: https://reviews.llvm.org/D121866

Added: 
    

Modified: 
    mlir/include/mlir-c/IR.h
    mlir/include/mlir-c/Pass.h
    mlir/include/mlir-c/Support.h
    mlir/include/mlir/CAPI/IR.h
    mlir/include/mlir/CAPI/Support.h
    mlir/lib/CAPI/IR/IR.cpp
    mlir/lib/CAPI/IR/Pass.cpp
    mlir/lib/CAPI/IR/Support.cpp
    mlir/test/CAPI/CMakeLists.txt
    mlir/test/CAPI/pass.c

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index d999554664d96..1c0517f5918f1 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -62,7 +62,6 @@ DEFINE_C_API_STRUCT(MlirIdentifier, const void);
 DEFINE_C_API_STRUCT(MlirLocation, const void);
 DEFINE_C_API_STRUCT(MlirModule, const void);
 DEFINE_C_API_STRUCT(MlirType, const void);
-DEFINE_C_API_STRUCT(MlirTypeID, const void);
 DEFINE_C_API_STRUCT(MlirValue, const void);
 
 #undef DEFINE_C_API_STRUCT
@@ -757,19 +756,6 @@ MLIR_CAPI_EXPORTED bool mlirIdentifierEqual(MlirIdentifier ident,
 /// Gets the string value of the identifier.
 MLIR_CAPI_EXPORTED MlirStringRef mlirIdentifierStr(MlirIdentifier ident);
 
-//===----------------------------------------------------------------------===//
-// TypeID API.
-//===----------------------------------------------------------------------===//
-
-/// Checks whether a type id is null.
-static inline bool mlirTypeIDIsNull(MlirTypeID typeID) { return !typeID.ptr; }
-
-/// Checks if two type ids are equal.
-MLIR_CAPI_EXPORTED bool mlirTypeIDEqual(MlirTypeID typeID1, MlirTypeID typeID2);
-
-/// Returns the hash value of the type id.
-MLIR_CAPI_EXPORTED size_t mlirTypeIDHashValue(MlirTypeID typeID);
-
 //===----------------------------------------------------------------------===//
 // Symbol and SymbolTable API.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir-c/Pass.h b/mlir/include/mlir-c/Pass.h
index d8b2168127f90..cdb947bdeb2fd 100644
--- a/mlir/include/mlir-c/Pass.h
+++ b/mlir/include/mlir-c/Pass.h
@@ -15,6 +15,7 @@
 #define MLIR_C_PASS_H
 
 #include "mlir-c/IR.h"
+#include "mlir-c/Registration.h"
 #include "mlir-c/Support.h"
 
 #ifdef __cplusplus
@@ -41,11 +42,16 @@ extern "C" {
   typedef struct name name
 
 DEFINE_C_API_STRUCT(MlirPass, void);
+DEFINE_C_API_STRUCT(MlirExternalPass, void);
 DEFINE_C_API_STRUCT(MlirPassManager, void);
 DEFINE_C_API_STRUCT(MlirOpPassManager, void);
 
 #undef DEFINE_C_API_STRUCT
 
+//===----------------------------------------------------------------------===//
+// PassManager/OpPassManager APIs.
+//===----------------------------------------------------------------------===//
+
 /// Create a new top-level PassManager.
 MLIR_CAPI_EXPORTED MlirPassManager mlirPassManagerCreate(MlirContext ctx);
 
@@ -112,6 +118,55 @@ MLIR_CAPI_EXPORTED void mlirPrintPassPipeline(MlirOpPassManager passManager,
 MLIR_CAPI_EXPORTED MlirLogicalResult
 mlirParsePassPipeline(MlirOpPassManager passManager, MlirStringRef pipeline);
 
+//===----------------------------------------------------------------------===//
+// External Pass API.
+//
+// This API allows passes to be defined outside of MLIR, not necessarily in
+// C++, and registered with the MLIR pass management infrastructure.
+//
+//===----------------------------------------------------------------------===//
+
+/// Structure of external `MlirPass` callbacks.
+/// All callbacks are required to be set unless otherwise specified.
+struct MlirExternalPassCallbacks {
+  /// This callback is called when the pass is created.
+  /// This is analogous to a C++ pass constructor.
+  void (*construct)(void *userData);
+
+  /// This callback is called when the pass is destroyed.
+  /// This is analogous to a C++ pass destructor.
+  void (*destruct)(void *userData);
+
+  /// This callback is optional.
+  /// The callback is called before the pass is run, allowing a chance to
+  /// initialize any complex state necessary for running the pass.
+  /// See Pass::initialize(MLIRContext *).
+  MlirLogicalResult (*initialize)(MlirContext ctx, void *userData);
+
+  /// This callback is called when the pass is cloned.
+  /// Returns the `userData` for the clone. See Pass::clonePass().
+  void *(*clone)(void *userData);
+
+  /// This callback is called when the pass is run.
+  /// See Pass::runOnOperation().
+  void (*run)(MlirOperation op, MlirExternalPass pass, void *userData);
+};
+typedef struct MlirExternalPassCallbacks MlirExternalPassCallbacks;
+
+/// Creates an external `MlirPass` that calls the supplied `callbacks` using the
+/// supplied `userData`. If `opName` is empty, the pass is a generic operation
+/// pass. Otherwise it is an operation pass specific to the `opName` operation.
+MLIR_CAPI_EXPORTED MlirPass mlirCreateExternalPass(
+    MlirTypeID passID, MlirStringRef name, MlirStringRef argument,
+    MlirStringRef description, MlirStringRef opName,
+    intptr_t nDependentDialects, MlirDialectHandle *dependentDialects,
+    MlirExternalPassCallbacks callbacks, void *userData);
+
+/// This signals that the pass has failed. This is only valid to call during
+/// the `run` callback of `MlirExternalPassCallbacks`.
+/// See Pass::signalPassFailure().
+MLIR_CAPI_EXPORTED void mlirExternalPassSignalFailure(MlirExternalPass pass);
+
 #ifdef __cplusplus
 }
 #endif

diff  --git a/mlir/include/mlir-c/Support.h b/mlir/include/mlir-c/Support.h
index f20e58fe62317..5d20fb78d5dc2 100644
--- a/mlir/include/mlir-c/Support.h
+++ b/mlir/include/mlir-c/Support.h
@@ -50,6 +50,17 @@
 extern "C" {
 #endif
 
+#define DEFINE_C_API_STRUCT(name, storage)                                     \
+  struct name {                                                                \
+    storage *ptr;                                                              \
+  };                                                                           \
+  typedef struct name name
+
+DEFINE_C_API_STRUCT(MlirTypeID, const void);
+DEFINE_C_API_STRUCT(MlirTypeIDAllocator, void);
+
+#undef DEFINE_C_API_STRUCT
+
 //===----------------------------------------------------------------------===//
 // MlirStringRef.
 //===----------------------------------------------------------------------===//
@@ -127,6 +138,38 @@ inline static MlirLogicalResult mlirLogicalResultFailure() {
   return res;
 }
 
+//===----------------------------------------------------------------------===//
+// TypeID API.
+//===----------------------------------------------------------------------===//
+
+/// `ptr` must be 8-byte aligned and unique to a type valid for the duration of
+/// the returned type id's usage.
+MLIR_CAPI_EXPORTED MlirTypeID mlirTypeIDCreate(const void *ptr);
+
+/// Checks whether a type id is null.
+static inline bool mlirTypeIDIsNull(MlirTypeID typeID) { return !typeID.ptr; }
+
+/// Checks if two type ids are equal.
+MLIR_CAPI_EXPORTED bool mlirTypeIDEqual(MlirTypeID typeID1, MlirTypeID typeID2);
+
+/// Returns the hash value of the type id.
+MLIR_CAPI_EXPORTED size_t mlirTypeIDHashValue(MlirTypeID typeID);
+
+//===----------------------------------------------------------------------===//
+// TypeIDAllocator API.
+//===----------------------------------------------------------------------===//
+
+/// Creates a type id allocator for dynamic type id creation.
+MLIR_CAPI_EXPORTED MlirTypeIDAllocator mlirTypeIDAllocatorCreate(void);
+
+/// Deallocates the allocator and all allocated type ids.
+MLIR_CAPI_EXPORTED void
+mlirTypeIDAllocatorDestroy(MlirTypeIDAllocator allocator);
+
+/// Allocates a type id that is valid for the lifetime of the allocator.
+MLIR_CAPI_EXPORTED MlirTypeID
+mlirTypeIDAllocatorAllocateTypeID(MlirTypeIDAllocator allocator);
+
 #ifdef __cplusplus
 }
 #endif

diff  --git a/mlir/include/mlir/CAPI/IR.h b/mlir/include/mlir/CAPI/IR.h
index 06cf7762a9c0e..899b411670747 100644
--- a/mlir/include/mlir/CAPI/IR.h
+++ b/mlir/include/mlir/CAPI/IR.h
@@ -34,7 +34,6 @@ DEFINE_C_API_METHODS(MlirIdentifier, mlir::StringAttr)
 DEFINE_C_API_METHODS(MlirLocation, mlir::Location)
 DEFINE_C_API_METHODS(MlirModule, mlir::ModuleOp)
 DEFINE_C_API_METHODS(MlirType, mlir::Type)
-DEFINE_C_API_METHODS(MlirTypeID, mlir::TypeID)
 DEFINE_C_API_METHODS(MlirValue, mlir::Value)
 
 #endif // MLIR_CAPI_IR_H

diff  --git a/mlir/include/mlir/CAPI/Support.h b/mlir/include/mlir/CAPI/Support.h
index 6d9a59abf111f..f3e8a67e0ac36 100644
--- a/mlir/include/mlir/CAPI/Support.h
+++ b/mlir/include/mlir/CAPI/Support.h
@@ -16,7 +16,9 @@
 #define MLIR_CAPI_SUPPORT_H
 
 #include "mlir-c/Support.h"
+#include "mlir/CAPI/Wrap.h"
 #include "mlir/Support/LogicalResult.h"
+#include "mlir/Support/TypeID.h"
 #include "llvm/ADT/StringRef.h"
 
 /// Converts a StringRef into its MLIR C API equivalent.
@@ -39,4 +41,7 @@ inline mlir::LogicalResult unwrap(MlirLogicalResult res) {
   return mlir::success(mlirLogicalResultIsSuccess(res));
 }
 
+DEFINE_C_API_METHODS(MlirTypeID, mlir::TypeID)
+DEFINE_C_API_PTR_METHODS(MlirTypeIDAllocator, mlir::TypeIDAllocator)
+
 #endif // MLIR_CAPI_SUPPORT_H

diff  --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index 75ac93ffb6897..527aa4eaf9e89 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -787,18 +787,6 @@ MlirStringRef mlirIdentifierStr(MlirIdentifier ident) {
   return wrap(unwrap(ident).strref());
 }
 
-//===----------------------------------------------------------------------===//
-// TypeID API.
-//===----------------------------------------------------------------------===//
-
-bool mlirTypeIDEqual(MlirTypeID typeID1, MlirTypeID typeID2) {
-  return unwrap(typeID1) == unwrap(typeID2);
-}
-
-size_t mlirTypeIDHashValue(MlirTypeID typeID) {
-  return hash_value(unwrap(typeID));
-}
-
 //===----------------------------------------------------------------------===//
 // Symbol and SymbolTable API.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/CAPI/IR/Pass.cpp b/mlir/lib/CAPI/IR/Pass.cpp
index 4bfc9d0132b37..a2998939a1586 100644
--- a/mlir/lib/CAPI/IR/Pass.cpp
+++ b/mlir/lib/CAPI/IR/Pass.cpp
@@ -77,3 +77,94 @@ MlirLogicalResult mlirParsePassPipeline(MlirOpPassManager passManager,
   // stream and redirect to a diagnostic.
   return wrap(mlir::parsePassPipeline(unwrap(pipeline), *unwrap(passManager)));
 }
+
+//===----------------------------------------------------------------------===//
+// External Pass API.
+//===----------------------------------------------------------------------===//
+
+namespace mlir {
+class ExternalPass;
+} // namespace mlir
+DEFINE_C_API_PTR_METHODS(MlirExternalPass, mlir::ExternalPass)
+
+namespace mlir {
+/// This pass class wraps external passes defined in other languages using the
+/// MLIR C API; instances are created by `mlirCreateExternalPass`.
+class ExternalPass : public Pass {
+public:
+  ExternalPass(TypeID passID, StringRef name, StringRef argument,
+               StringRef description, Optional<StringRef> opName,
+               ArrayRef<MlirDialectHandle> dependentDialects,
+               MlirExternalPassCallbacks callbacks, void *userData)
+      : Pass(passID, opName), id(passID), name(name), argument(argument),
+        description(description), dependentDialects(dependentDialects),
+        callbacks(callbacks), userData(userData) {
+    callbacks.construct(userData);
+  }
+
+  ~ExternalPass() override { callbacks.destruct(userData); }
+
+  StringRef getName() const override { return name; }
+  StringRef getArgument() const override { return argument; }
+  StringRef getDescription() const override { return description; }
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    MlirDialectRegistry cRegistry = wrap(&registry);
+    for (MlirDialectHandle dialect : dependentDialects)
+      mlirDialectHandleInsertDialect(dialect, cRegistry);
+  }
+
+  void signalPassFailure() { Pass::signalPassFailure(); }
+
+protected:
+  LogicalResult initialize(MLIRContext *ctx) override {
+    if (callbacks.initialize)
+      return unwrap(callbacks.initialize(wrap(ctx), userData));
+    return success();
+  }
+
+  bool canScheduleOn(RegisteredOperationName opName) const override {
+    if (Optional<StringRef> specifiedOpName = getOpName())
+      return opName.getStringRef() == specifiedOpName;
+    return true;
+  }
+
+  void runOnOperation() override {
+    callbacks.run(wrap(getOperation()), wrap(this), userData);
+  }
+
+  std::unique_ptr<Pass> clonePass() const override {
+    void *clonedUserData = callbacks.clone(userData);
+    return std::make_unique<ExternalPass>(id, name, argument, description,
+                                          getOpName(), dependentDialects,
+                                          callbacks, clonedUserData);
+  }
+
+private:
+  TypeID id;
+  std::string name;
+  std::string argument;
+  std::string description;
+  std::vector<MlirDialectHandle> dependentDialects;
+  MlirExternalPassCallbacks callbacks;
+  void *userData;
+};
+} // namespace mlir
+
+MlirPass mlirCreateExternalPass(MlirTypeID passID, MlirStringRef name,
+                                MlirStringRef argument,
+                                MlirStringRef description, MlirStringRef opName,
+                                intptr_t nDependentDialects,
+                                MlirDialectHandle *dependentDialects,
+                                MlirExternalPassCallbacks callbacks,
+                                void *userData) {
+  return wrap(static_cast<mlir::Pass *>(new mlir::ExternalPass(
+      unwrap(passID), unwrap(name), unwrap(argument), unwrap(description),
+      opName.length > 0 ? Optional<StringRef>(unwrap(opName)) : None,
+      {dependentDialects, static_cast<size_t>(nDependentDialects)}, callbacks,
+      userData)));
+}
+
+void mlirExternalPassSignalFailure(MlirExternalPass pass) {
+  unwrap(pass)->signalPassFailure();
+}

diff  --git a/mlir/lib/CAPI/IR/Support.cpp b/mlir/lib/CAPI/IR/Support.cpp
index b6e1f9180c771..cbfbb54769aa9 100644
--- a/mlir/lib/CAPI/IR/Support.cpp
+++ b/mlir/lib/CAPI/IR/Support.cpp
@@ -6,7 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir-c/Support.h"
+#include "mlir/CAPI/Support.h"
 #include "llvm/ADT/StringRef.h"
 
 #include <cstring>
@@ -19,3 +19,40 @@ bool mlirStringRefEqual(MlirStringRef string, MlirStringRef other) {
   return llvm::StringRef(string.data, string.length) ==
          llvm::StringRef(other.data, other.length);
 }
+
+//===----------------------------------------------------------------------===//
+// TypeID API.
+//===----------------------------------------------------------------------===//
+
+MlirTypeID mlirTypeIDCreate(const void *ptr) {
+  assert(reinterpret_cast<uintptr_t>(ptr) % 8 == 0 &&
+         "ptr must be 8 byte aligned");
+  // This is essentially a no-op that returns back `ptr`, but by going through
+  // the `TypeID` functions we can get compiler errors in case the `TypeID`
+  // API/representation changes.
+  return wrap(mlir::TypeID::getFromOpaquePointer(ptr));
+}
+
+bool mlirTypeIDEqual(MlirTypeID typeID1, MlirTypeID typeID2) {
+  return unwrap(typeID1) == unwrap(typeID2);
+}
+
+size_t mlirTypeIDHashValue(MlirTypeID typeID) {
+  return hash_value(unwrap(typeID));
+}
+
+//===----------------------------------------------------------------------===//
+// TypeIDAllocator API.
+//===----------------------------------------------------------------------===//
+
+MlirTypeIDAllocator mlirTypeIDAllocatorCreate() {
+  return wrap(new mlir::TypeIDAllocator());
+}
+
+void mlirTypeIDAllocatorDestroy(MlirTypeIDAllocator allocator) {
+  delete unwrap(allocator);
+}
+
+MlirTypeID mlirTypeIDAllocatorAllocateTypeID(MlirTypeIDAllocator allocator) {
+  return wrap(unwrap(allocator)->allocate());
+}

diff  --git a/mlir/test/CAPI/CMakeLists.txt b/mlir/test/CAPI/CMakeLists.txt
index 8a7a567f2a775..a4e2f024dd122 100644
--- a/mlir/test/CAPI/CMakeLists.txt
+++ b/mlir/test/CAPI/CMakeLists.txt
@@ -49,6 +49,7 @@ _add_capi_test_executable(mlir-capi-llvm-test
 _add_capi_test_executable(mlir-capi-pass-test
   pass.c
   LINK_LIBS PRIVATE
+    MLIRCAPIFunc
     MLIRCAPIIR
     MLIRCAPIRegistration
     MLIRCAPITransforms

diff  --git a/mlir/test/CAPI/pass.c b/mlir/test/CAPI/pass.c
index 69d052c3131a3..aeabac4fd844d 100644
--- a/mlir/test/CAPI/pass.c
+++ b/mlir/test/CAPI/pass.c
@@ -11,6 +11,7 @@
  */
 
 #include "mlir-c/Pass.h"
+#include "mlir-c/Dialect/Func.h"
 #include "mlir-c/IR.h"
 #include "mlir-c/Registration.h"
 #include "mlir-c/Transforms.h"
@@ -169,7 +170,9 @@ void testParsePassPipeline() {
                                      " func.func(print-op-stats))"));
   // Expect a failure, we haven't registered the print-op-stats pass yet.
   if (mlirLogicalResultIsSuccess(status)) {
-    fprintf(stderr, "Unexpected success parsing pipeline without registering the pass\n");
+    fprintf(
+        stderr,
+        "Unexpected success parsing pipeline without registering the pass\n");
     exit(EXIT_FAILURE);
   }
   // Try again after registrating the pass.
@@ -180,7 +183,8 @@ void testParsePassPipeline() {
                                      " func.func(print-op-stats))"));
   // Expect a failure, we haven't registered the print-op-stats pass yet.
   if (mlirLogicalResultIsFailure(status)) {
-    fprintf(stderr, "Unexpected failure parsing pipeline after registering the pass\n");
+    fprintf(stderr,
+            "Unexpected failure parsing pipeline after registering the pass\n");
     exit(EXIT_FAILURE);
   }
 
@@ -194,10 +198,333 @@ void testParsePassPipeline() {
   mlirContextDestroy(ctx);
 }
 
+// Counts how many times each external pass callback has been invoked.
+struct TestExternalPassUserData {
+  int constructCallCount;
+  int destructCallCount;
+  int initializeCallCount;
+  int cloneCallCount;
+  int runCallCount;
+};
+typedef struct TestExternalPassUserData TestExternalPassUserData;
+
+void testConstructExternalPass(void *userData) {
+  ++((TestExternalPassUserData *)userData)->constructCallCount;
+}
+
+void testDestructExternalPass(void *userData) {
+  ++((TestExternalPassUserData *)userData)->destructCallCount;
+}
+
+MlirLogicalResult testInitializeExternalPass(MlirContext ctx, void *userData) {
+  ++((TestExternalPassUserData *)userData)->initializeCallCount;
+  return mlirLogicalResultSuccess();
+}
+
+MlirLogicalResult testInitializeFailingExternalPass(MlirContext ctx,
+                                                    void *userData) {
+  ++((TestExternalPassUserData *)userData)->initializeCallCount;
+  return mlirLogicalResultFailure();
+}
+
+void *testCloneExternalPass(void *userData) {
+  ++((TestExternalPassUserData *)userData)->cloneCallCount;
+  return userData;
+}
+
+void testRunExternalPass(MlirOperation op, MlirExternalPass pass,
+                         void *userData) {
+  ++((TestExternalPassUserData *)userData)->runCallCount;
+}
+
+void testRunExternalFuncPass(MlirOperation op, MlirExternalPass pass,
+                             void *userData) {
+  ++((TestExternalPassUserData *)userData)->runCallCount;
+  MlirStringRef opName = mlirIdentifierStr(mlirOperationGetName(op));
+  if (!mlirStringRefEqual(opName,
+                          mlirStringRefCreateFromCString("func.func"))) {
+    mlirExternalPassSignalFailure(pass);
+  }
+}
+
+void testRunFailingExternalPass(MlirOperation op, MlirExternalPass pass,
+                                void *userData) {
+  ++((TestExternalPassUserData *)userData)->runCallCount;
+  mlirExternalPassSignalFailure(pass);
+}
+
+// Combines `initializePass` and `runPass` with the counting callbacks above.
+MlirExternalPassCallbacks makeTestExternalPassCallbacks(
+    MlirLogicalResult (*initializePass)(MlirContext ctx, void *userData),
+    void (*runPass)(MlirOperation op, MlirExternalPass, void *userData)) {
+  return (MlirExternalPassCallbacks){.construct = testConstructExternalPass,
+                                     .destruct = testDestructExternalPass,
+                                     .initialize = initializePass,
+                                     .clone = testCloneExternalPass,
+                                     .run = runPass};
+}
+
+void testExternalPass() {
+  MlirContext ctx = mlirContextCreate();
+  mlirRegisterAllDialects(ctx);
+
+  MlirModule module = mlirModuleCreateParse(
+      ctx,
+      // clang-format off
+      mlirStringRefCreateFromCString(
+"func @foo(%arg0 : i32) -> i32 {                                            \n"
+"  %res = arith.addi %arg0, %arg0 : i32                                     \n"
+"  return %res : i32                                                        \n"
+"}"));
+  // clang-format on
+  if (mlirModuleIsNull(module)) {
+    fprintf(stderr, "Unexpected failure parsing module.\n");
+    exit(EXIT_FAILURE);
+  }
+
+  MlirStringRef description = mlirStringRefCreateFromCString("");
+  MlirStringRef emptyOpName = mlirStringRefCreateFromCString("");
+
+  MlirTypeIDAllocator typeIDAllocator = mlirTypeIDAllocatorCreate();
+
+  // Run a generic pass
+  {
+    MlirTypeID passID = mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator);
+    MlirStringRef name = mlirStringRefCreateFromCString("TestExternalPass");
+    MlirStringRef argument =
+        mlirStringRefCreateFromCString("test-external-pass");
+    TestExternalPassUserData userData = {0};
+
+    MlirPass externalPass = mlirCreateExternalPass(
+        passID, name, argument, description, emptyOpName, 0, NULL,
+        makeTestExternalPassCallbacks(NULL, testRunExternalPass), &userData);
+
+    if (userData.constructCallCount != 1) {
+      fprintf(stderr, "Expected constructCallCount to be 1\n");
+      exit(EXIT_FAILURE);
+    }
+
+    MlirPassManager pm = mlirPassManagerCreate(ctx);
+    mlirPassManagerAddOwnedPass(pm, externalPass);
+    MlirLogicalResult success = mlirPassManagerRun(pm, module);
+    if (mlirLogicalResultIsFailure(success)) {
+      fprintf(stderr, "Unexpected failure running external pass.\n");
+      exit(EXIT_FAILURE);
+    }
+
+    if (userData.runCallCount != 1) {
+      fprintf(stderr, "Expected runCallCount to be 1\n");
+      exit(EXIT_FAILURE);
+    }
+
+    mlirPassManagerDestroy(pm);
+
+    if (userData.destructCallCount != userData.constructCallCount) {
+      fprintf(stderr, "Expected destructCallCount to be equal to "
+                      "constructCallCount\n");
+      exit(EXIT_FAILURE);
+    }
+  }
+
+  // Run a func operation pass
+  {
+    MlirTypeID passID = mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator);
+    MlirStringRef name = mlirStringRefCreateFromCString("TestExternalFuncPass");
+    MlirStringRef argument =
+        mlirStringRefCreateFromCString("test-external-func-pass");
+    TestExternalPassUserData userData = {0};
+    MlirDialectHandle funcHandle = mlirGetDialectHandle__func__();
+    MlirStringRef funcOpName = mlirStringRefCreateFromCString("func.func");
+
+    MlirPass externalPass = mlirCreateExternalPass(
+        passID, name, argument, description, funcOpName, 1, &funcHandle,
+        makeTestExternalPassCallbacks(NULL, testRunExternalFuncPass),
+        &userData);
+
+    if (userData.constructCallCount != 1) {
+      fprintf(stderr, "Expected constructCallCount to be 1\n");
+      exit(EXIT_FAILURE);
+    }
+
+    MlirPassManager pm = mlirPassManagerCreate(ctx);
+    MlirOpPassManager nestedFuncPm =
+        mlirPassManagerGetNestedUnder(pm, funcOpName);
+    mlirOpPassManagerAddOwnedPass(nestedFuncPm, externalPass);
+    MlirLogicalResult success = mlirPassManagerRun(pm, module);
+    if (mlirLogicalResultIsFailure(success)) {
+      fprintf(stderr, "Unexpected failure running external operation pass.\n");
+      exit(EXIT_FAILURE);
+    }
+
+    // Since this is a nested pass, it can be cloned and run in parallel.
+    if (userData.cloneCallCount != userData.constructCallCount - 1) {
+      fprintf(stderr, "Expected cloneCallCount to be constructCallCount - 1\n");
+      exit(EXIT_FAILURE);
+    }
+
+    // The pass should only be run once since there is only one func op.
+    if (userData.runCallCount != 1) {
+      fprintf(stderr, "Expected runCallCount to be 1\n");
+      exit(EXIT_FAILURE);
+    }
+
+    mlirPassManagerDestroy(pm);
+
+    if (userData.destructCallCount != userData.constructCallCount) {
+      fprintf(stderr, "Expected destructCallCount to be equal to "
+                      "constructCallCount\n");
+      exit(EXIT_FAILURE);
+    }
+  }
+
+  // Run a pass with `initialize` set
+  {
+    MlirTypeID passID = mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator);
+    MlirStringRef name = mlirStringRefCreateFromCString("TestExternalPass");
+    MlirStringRef argument =
+        mlirStringRefCreateFromCString("test-external-pass");
+    TestExternalPassUserData userData = {0};
+
+    MlirPass externalPass = mlirCreateExternalPass(
+        passID, name, argument, description, emptyOpName, 0, NULL,
+        makeTestExternalPassCallbacks(testInitializeExternalPass,
+                                      testRunExternalPass),
+        &userData);
+
+    if (userData.constructCallCount != 1) {
+      fprintf(stderr, "Expected constructCallCount to be 1\n");
+      exit(EXIT_FAILURE);
+    }
+
+    MlirPassManager pm = mlirPassManagerCreate(ctx);
+    mlirPassManagerAddOwnedPass(pm, externalPass);
+    MlirLogicalResult success = mlirPassManagerRun(pm, module);
+    if (mlirLogicalResultIsFailure(success)) {
+      fprintf(stderr, "Unexpected failure running external pass.\n");
+      exit(EXIT_FAILURE);
+    }
+
+    if (userData.initializeCallCount != 1) {
+      fprintf(stderr, "Expected initializeCallCount to be 1\n");
+      exit(EXIT_FAILURE);
+    }
+
+    if (userData.runCallCount != 1) {
+      fprintf(stderr, "Expected runCallCount to be 1\n");
+      exit(EXIT_FAILURE);
+    }
+
+    mlirPassManagerDestroy(pm);
+
+    if (userData.destructCallCount != userData.constructCallCount) {
+      fprintf(stderr, "Expected destructCallCount to be equal to "
+                      "constructCallCount\n");
+      exit(EXIT_FAILURE);
+    }
+  }
+
+  // Run a pass that fails during `initialize`
+  {
+    MlirTypeID passID = mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator);
+    MlirStringRef name =
+        mlirStringRefCreateFromCString("TestExternalFailingPass");
+    MlirStringRef argument =
+        mlirStringRefCreateFromCString("test-external-failing-pass");
+    TestExternalPassUserData userData = {0};
+
+    MlirPass externalPass = mlirCreateExternalPass(
+        passID, name, argument, description, emptyOpName, 0, NULL,
+        makeTestExternalPassCallbacks(testInitializeFailingExternalPass,
+                                      testRunExternalPass),
+        &userData);
+
+    if (userData.constructCallCount != 1) {
+      fprintf(stderr, "Expected constructCallCount to be 1\n");
+      exit(EXIT_FAILURE);
+    }
+
+    MlirPassManager pm = mlirPassManagerCreate(ctx);
+    mlirPassManagerAddOwnedPass(pm, externalPass);
+    MlirLogicalResult success = mlirPassManagerRun(pm, module);
+    if (mlirLogicalResultIsSuccess(success)) {
+      fprintf(
+          stderr,
+          "Expected failure running pass manager on failing external pass.\n");
+      exit(EXIT_FAILURE);
+    }
+
+    if (userData.initializeCallCount != 1) {
+      fprintf(stderr, "Expected initializeCallCount to be 1\n");
+      exit(EXIT_FAILURE);
+    }
+
+    if (userData.runCallCount != 0) {
+      fprintf(stderr, "Expected runCallCount to be 0\n");
+      exit(EXIT_FAILURE);
+    }
+
+    mlirPassManagerDestroy(pm);
+
+    if (userData.destructCallCount != userData.constructCallCount) {
+      fprintf(stderr, "Expected destructCallCount to be equal to "
+                      "constructCallCount\n");
+      exit(EXIT_FAILURE);
+    }
+  }
+
+  // Run a pass that fails during `run`
+  {
+    MlirTypeID passID = mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator);
+    MlirStringRef name =
+        mlirStringRefCreateFromCString("TestExternalFailingPass");
+    MlirStringRef argument =
+        mlirStringRefCreateFromCString("test-external-failing-pass");
+    TestExternalPassUserData userData = {0};
+
+    MlirPass externalPass = mlirCreateExternalPass(
+        passID, name, argument, description, emptyOpName, 0, NULL,
+        makeTestExternalPassCallbacks(NULL, testRunFailingExternalPass),
+        &userData);
+
+    if (userData.constructCallCount != 1) {
+      fprintf(stderr, "Expected constructCallCount to be 1\n");
+      exit(EXIT_FAILURE);
+    }
+
+    MlirPassManager pm = mlirPassManagerCreate(ctx);
+    mlirPassManagerAddOwnedPass(pm, externalPass);
+    MlirLogicalResult success = mlirPassManagerRun(pm, module);
+    if (mlirLogicalResultIsSuccess(success)) {
+      fprintf(
+          stderr,
+          "Expected failure running pass manager on failing external pass.\n");
+      exit(EXIT_FAILURE);
+    }
+
+    if (userData.runCallCount != 1) {
+      fprintf(stderr, "Expected runCallCount to be 1\n");
+      exit(EXIT_FAILURE);
+    }
+
+    mlirPassManagerDestroy(pm);
+
+    if (userData.destructCallCount != userData.constructCallCount) {
+      fprintf(stderr, "Expected destructCallCount to be equal to "
+                      "constructCallCount\n");
+      exit(EXIT_FAILURE);
+    }
+  }
+
+  mlirTypeIDAllocatorDestroy(typeIDAllocator);
+  mlirModuleDestroy(module);
+  mlirContextDestroy(ctx);
+}
 
 int main() {
   testRunPassOnModule();
   testRunPassOnNestedModule();
   testPrintPassPipeline();
   testParsePassPipeline();
+  testExternalPass();
   return 0;
 }


        


More information about the Mlir-commits mailing list