[Mlir-commits] [mlir] [MLIR][Python] Support dialect conversion in python bindings (PR #177782)

Maksim Levental llvmlistbot at llvm.org
Wed Jan 28 20:51:13 PST 2026


================
@@ -447,6 +520,146 @@ MlirRewriterBase mlirPatternRewriterAsBase(MlirPatternRewriter rewriter) {
   return wrap(static_cast<mlir::RewriterBase *>(unwrap(rewriter)));
 }
 
+//===----------------------------------------------------------------------===//
+/// ConversionPatternRewriter API
+//===----------------------------------------------------------------------===//
+
+// Views a ConversionPatternRewriter handle as its PatternRewriter base so the
+// generic PatternRewriter C API entry points can operate on it.
+MlirPatternRewriter mlirConversionPatternRewriterAsPatternRewriter(
+    MlirConversionPatternRewriter rewriter) {
+  // Implicit derived-to-base pointer conversion; no object is created.
+  mlir::PatternRewriter *base = unwrap(rewriter);
+  return wrap(base);
+}
+
+//===----------------------------------------------------------------------===//
+/// ConversionTarget API
+//===----------------------------------------------------------------------===//
+
+// Heap-allocates a ConversionTarget bound to `context` and hands ownership to
+// the caller; mlirConversionTargetDestroy is the matching release call.
+MlirConversionTarget mlirConversionTargetCreate(MlirContext context) {
+  mlir::MLIRContext &ctx = *unwrap(context);
+  auto *target = new mlir::ConversionTarget(ctx);
+  return wrap(target);
+}
+
+// Deletes a ConversionTarget previously returned by mlirConversionTargetCreate.
+void mlirConversionTargetDestroy(MlirConversionTarget target) {
+  auto *owned = unwrap(target);
+  delete owned;
+}
+
+// Marks the operation named `opName` as legal on `target`. The name is
+// resolved in the target's own MLIRContext.
+void mlirConversionTargetAddLegalOp(MlirConversionTarget target,
+                                    MlirStringRef opName) {
+  auto *convTarget = unwrap(target);
+  mlir::OperationName name(unwrap(opName), &convTarget->getContext());
+  convTarget->addLegalOp(name);
+}
+
+// Marks the operation named `opName` as illegal on `target`. The name is
+// resolved in the target's own MLIRContext.
+void mlirConversionTargetAddIllegalOp(MlirConversionTarget target,
+                                      MlirStringRef opName) {
+  auto *convTarget = unwrap(target);
+  mlir::OperationName name(unwrap(opName), &convTarget->getContext());
+  convTarget->addIllegalOp(name);
+}
+
+// Registers the dialect named `dialectName` as legal on `target`.
+void mlirConversionTargetAddLegalDialect(MlirConversionTarget target,
+                                         MlirStringRef dialectName) {
+  mlir::StringRef name = unwrap(dialectName);
+  unwrap(target)->addLegalDialect(name);
+}
+
+// Registers the dialect named `dialectName` as illegal on `target`.
+void mlirConversionTargetAddIllegalDialect(MlirConversionTarget target,
+                                           MlirStringRef dialectName) {
+  mlir::StringRef name = unwrap(dialectName);
+  unwrap(target)->addIllegalDialect(name);
+}
+
+//===----------------------------------------------------------------------===//
+/// TypeConverter API
+//===----------------------------------------------------------------------===//
+
+// Heap-allocates a TypeConverter with no conversions registered yet; the
+// caller owns it and releases it with mlirTypeConverterDestroy.
+MlirTypeConverter mlirTypeConverterCreate() {
+  auto *converter = new mlir::TypeConverter();
+  return wrap(converter);
+}
+
+// Deletes a TypeConverter previously returned by mlirTypeConverterCreate.
+void mlirTypeConverterDestroy(MlirTypeConverter typeConverter) {
+  auto *owned = unwrap(typeConverter);
+  delete owned;
+}
+
+// Registers the C callback `convertType` as a conversion on `typeConverter`.
+//
+// The callback's outcome is mapped as follows:
+//   * failure result           -> std::nullopt, so the converter may go on to
+//                                 try other registered conversion functions;
+//   * success, null out-type   -> a null Type (per the TypeConverter docs this
+//                                 signals that the conversion failed);
+//   * success, non-null type   -> that type is the converted result.
+// `userData` is passed through unchanged on every invocation. The pointer is
+// stored inside the registered callback, so whatever it refers to must stay
+// alive for as long as the converter may run this conversion.
+void mlirTypeConverterAddConversion(
+    MlirTypeConverter typeConverter,
+    MlirTypeConverterConversionCallback convertType, void *userData) {
+  auto *converter = unwrap(typeConverter);
+  converter->addConversion(
+      [convertType, userData](Type type) -> std::optional<Type> {
+        MlirType result{nullptr};
+        if (mlirLogicalResultIsFailure(
+                convertType(wrap(type), &result, userData)))
+          return std::nullopt;
+        return mlirTypeIsNull(result) ? Type() : unwrap(result);
+      });
+}
+
+//===----------------------------------------------------------------------===//
+/// ConversionPattern API
+//===----------------------------------------------------------------------===//
+
+namespace mlir {
+
+/// A ConversionPattern whose matchAndRewrite is implemented by C callbacks,
+/// so conversion patterns can be defined outside C++ (e.g. from the Python
+/// bindings).
+///
+/// `callbacks.construct` (if set) runs once when the pattern is created and
+/// `callbacks.destruct` (if set) runs once when it is destroyed; both receive
+/// `userData`. Because those hooks are tied to this object's lifetime, the
+/// pattern is neither copyable nor movable: an implicitly generated copy would
+/// run `destruct` a second time for a single `construct`.
+class ExternalConversionPattern : public mlir::ConversionPattern {
+public:
+  /// `typeConverter` is dereferenced unconditionally and therefore must be
+  /// non-null. NOTE(review): the ConversionPattern base presumably keeps a
+  /// reference to it, so it should outlive this pattern — confirm callers
+  /// never pass a null converter.
+  ExternalConversionPattern(MlirConversionPatternCallbacks callbacks,
+                            void *userData, StringRef rootName,
+                            PatternBenefit benefit, MLIRContext *context,
+                            TypeConverter *typeConverter,
+                            ArrayRef<StringRef> generatedNames)
+      : ConversionPattern(*typeConverter, rootName, benefit, context,
+                          generatedNames),
+        callbacks(callbacks), userData(userData) {
+    if (callbacks.construct)
+      callbacks.construct(userData);
+  }
+
+  // Copying would pair one `construct` call with two `destruct` calls on the
+  // same userData; declaring these deleted also suppresses implicit moves.
+  ExternalConversionPattern(const ExternalConversionPattern &) = delete;
+  ExternalConversionPattern &
+  operator=(const ExternalConversionPattern &) = delete;
+
+  ~ExternalConversionPattern() override {
+    if (callbacks.destruct)
+      callbacks.destruct(userData);
+  }
+
+  /// Forwards the request to `callbacks.matchAndRewrite`, passing the pattern,
+  /// the op, `operands` wrapped as a contiguous MlirValue array, the rewriter,
+  /// and `userData`. The callback's result is returned as-is.
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    std::vector<MlirValue> wrappedOperands;
+    wrappedOperands.reserve(operands.size());
+    for (Value val : operands)
+      wrappedOperands.push_back(wrap(val));
+    return unwrap(callbacks.matchAndRewrite(
+        wrap(static_cast<const mlir::ConversionPattern *>(this)), wrap(op),
+        wrappedOperands.size(), wrappedOperands.data(), wrap(&rewriter),
+        userData));
+  }
+
+private:
+  MlirConversionPatternCallbacks callbacks;
+  void *userData;
+};
+
+} // namespace mlir
+
+MlirConversionPattern mlirOpConversionPatternCreate(
+    MlirStringRef rootName, unsigned benefit, MlirContext context,
+    MlirTypeConverter typeConverter, MlirConversionPatternCallbacks callbacks,
+    void *userData, size_t nGeneratedNames, MlirStringRef *generatedNames) {
+  std::vector<mlir::StringRef> generatedNamesVec;
+  generatedNamesVec.reserve(nGeneratedNames);
+  for (size_t i = 0; i < nGeneratedNames; ++i) {
+    generatedNamesVec.push_back(unwrap(generatedNames[i]));
+  }
----------------
makslevental wrote:

nit:
```suggestion
  for (size_t i = 0; i < nGeneratedNames; ++i)
    generatedNamesVec.push_back(unwrap(generatedNames[i]));
```

https://github.com/llvm/llvm-project/pull/177782


More information about the Mlir-commits mailing list