[Mlir-commits] [mlir] [MLIR][Python] Support dialect conversion in python bindings (PR #177782)
Maksim Levental
llvmlistbot at llvm.org
Wed Jan 28 20:51:13 PST 2026
================
@@ -447,6 +520,146 @@ MlirRewriterBase mlirPatternRewriterAsBase(MlirPatternRewriter rewriter) {
return wrap(static_cast<mlir::RewriterBase *>(unwrap(rewriter)));
}
+//===----------------------------------------------------------------------===//
+/// ConversionPatternRewriter API
+//===----------------------------------------------------------------------===//
+
+// Upcasts a ConversionPatternRewriter handle to its PatternRewriter base so
+// it can be passed to the generic PatternRewriter C API.
+MlirPatternRewriter mlirConversionPatternRewriterAsPatternRewriter(
+    MlirConversionPatternRewriter rewriter) {
+  mlir::PatternRewriter *asBase = unwrap(rewriter);
+  return wrap(asBase);
+}
+
+//===----------------------------------------------------------------------===//
+/// ConversionTarget API
+//===----------------------------------------------------------------------===//
+
+// Heap-allocates a ConversionTarget bound to `context`. The caller owns the
+// result and releases it with mlirConversionTargetDestroy.
+MlirConversionTarget mlirConversionTargetCreate(MlirContext context) {
+  mlir::MLIRContext &ctx = *unwrap(context);
+  auto *target = new mlir::ConversionTarget(ctx);
+  return wrap(target);
+}
+
+// Deletes a ConversionTarget previously returned by mlirConversionTargetCreate.
+void mlirConversionTargetDestroy(MlirConversionTarget target) {
+  auto *owned = unwrap(target);
+  delete owned;
+}
+
+// Marks the operation named `opName` as legal on `target`. The name is
+// resolved in the target's own context.
+void mlirConversionTargetAddLegalOp(MlirConversionTarget target,
+                                    MlirStringRef opName) {
+  auto *conversionTarget = unwrap(target);
+  mlir::OperationName name(unwrap(opName), &conversionTarget->getContext());
+  conversionTarget->addLegalOp(name);
+}
+
+// Marks the operation named `opName` as illegal on `target`. The name is
+// resolved in the target's own context.
+void mlirConversionTargetAddIllegalOp(MlirConversionTarget target,
+                                      MlirStringRef opName) {
+  auto *conversionTarget = unwrap(target);
+  mlir::OperationName name(unwrap(opName), &conversionTarget->getContext());
+  conversionTarget->addIllegalOp(name);
+}
+
+// Registers the dialect named `dialectName` as legal on `target` via
+// ConversionTarget::addLegalDialect.
+void mlirConversionTargetAddLegalDialect(MlirConversionTarget target,
+                                         MlirStringRef dialectName) {
+  mlir::StringRef name = unwrap(dialectName);
+  unwrap(target)->addLegalDialect(name);
+}
+
+// Registers the dialect named `dialectName` as illegal on `target` via
+// ConversionTarget::addIllegalDialect.
+void mlirConversionTargetAddIllegalDialect(MlirConversionTarget target,
+                                           MlirStringRef dialectName) {
+  mlir::StringRef name = unwrap(dialectName);
+  unwrap(target)->addIllegalDialect(name);
+}
+
+//===----------------------------------------------------------------------===//
+/// TypeConverter API
+//===----------------------------------------------------------------------===//
+
+// Heap-allocates an empty TypeConverter. The caller owns the result and
+// releases it with mlirTypeConverterDestroy.
+MlirTypeConverter mlirTypeConverterCreate() {
+  auto *converter = new mlir::TypeConverter();
+  return wrap(converter);
+}
+
+// Deletes a TypeConverter previously returned by mlirTypeConverterCreate.
+void mlirTypeConverterDestroy(MlirTypeConverter typeConverter) {
+  auto *owned = unwrap(typeConverter);
+  delete owned;
+}
+
+// Registers a C conversion callback on `typeConverter`. The callback is
+// invoked as convertType(type, &converted, userData):
+//  - a failure result yields std::nullopt, which lets the TypeConverter try
+//    its other registered conversions;
+//  - a success result with a null `converted` yields a null Type;
+//  - otherwise the converted type is returned.
+// `userData` is captured by value and must remain valid for as long as the
+// converter may be used (not checked here).
+void mlirTypeConverterAddConversion(
+    MlirTypeConverter typeConverter,
+    MlirTypeConverterConversionCallback convertType, void *userData) {
+  auto adaptor = [convertType, userData](Type type) -> std::optional<Type> {
+    MlirType converted{nullptr};
+    bool failed = mlirLogicalResultIsFailure(
+        convertType(wrap(type), &converted, userData));
+    if (failed)
+      return std::nullopt;
+    Type result = mlirTypeIsNull(converted) ? Type() : unwrap(converted);
+    return result;
+  };
+  unwrap(typeConverter)->addConversion(std::move(adaptor));
+}
+
+//===----------------------------------------------------------------------===//
+/// ConversionPattern API
+//===----------------------------------------------------------------------===//
+
+namespace mlir {
+
+/// ConversionPattern whose matchAndRewrite is delegated to C callbacks, so
+/// that patterns can be implemented outside C++ (e.g. by language bindings).
+/// When provided, the `construct` / `destruct` callbacks are invoked with
+/// `userData` as the pattern is created and destroyed, respectively.
+class ExternalConversionPattern : public mlir::ConversionPattern {
+public:
+  /// `typeConverter` must be non-null: it is dereferenced to initialize the
+  /// ConversionPattern base. NOTE(review): the base presumably keeps a
+  /// reference to the converter, so it should outlive this pattern — confirm.
+  ExternalConversionPattern(MlirConversionPatternCallbacks callbacks,
+                            void *userData, StringRef rootName,
+                            PatternBenefit benefit, MLIRContext *context,
+                            TypeConverter *typeConverter,
+                            ArrayRef<StringRef> generatedNames)
+      : ConversionPattern(requireTypeConverter(typeConverter), rootName,
+                          benefit, context, generatedNames),
+        callbacks(callbacks), userData(userData) {
+    if (callbacks.construct)
+      callbacks.construct(userData);
+  }
+
+  // The destructor fires `destruct` on `userData`. A copy would fire it a
+  // second time for the same `userData`, so copying (and therefore moving)
+  // is disallowed.
+  ExternalConversionPattern(const ExternalConversionPattern &) = delete;
+  ExternalConversionPattern &
+  operator=(const ExternalConversionPattern &) = delete;
+
+  ~ExternalConversionPattern() override {
+    if (callbacks.destruct)
+      callbacks.destruct(userData);
+  }
+
+  /// Forwards to the C `matchAndRewrite` callback. The operands are passed as
+  /// a contiguous MlirValue array together with their count.
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    std::vector<MlirValue> wrappedOperands;
+    wrappedOperands.reserve(operands.size());
+    for (Value val : operands)
+      wrappedOperands.push_back(wrap(val));
+    return unwrap(callbacks.matchAndRewrite(
+        wrap(static_cast<const mlir::ConversionPattern *>(this)), wrap(op),
+        wrappedOperands.size(), wrappedOperands.data(), wrap(&rewriter),
+        userData));
+  }
+
+private:
+  /// Asserts (in debug builds) that a type converter was supplied, before it
+  /// is dereferenced in the base-class initializer.
+  static TypeConverter &requireTypeConverter(TypeConverter *typeConverter) {
+    assert(typeConverter &&
+           "ExternalConversionPattern requires a non-null TypeConverter");
+    return *typeConverter;
+  }
+
+  MlirConversionPatternCallbacks callbacks;
+  void *userData;
+};
+
+} // namespace mlir
+
+MlirConversionPattern mlirOpConversionPatternCreate(
+ MlirStringRef rootName, unsigned benefit, MlirContext context,
+ MlirTypeConverter typeConverter, MlirConversionPatternCallbacks callbacks,
+ void *userData, size_t nGeneratedNames, MlirStringRef *generatedNames) {
+ std::vector<mlir::StringRef> generatedNamesVec;
+ generatedNamesVec.reserve(nGeneratedNames);
+ for (size_t i = 0; i < nGeneratedNames; ++i) {
+ generatedNamesVec.push_back(unwrap(generatedNames[i]));
+ }
----------------
makslevental wrote:
nit:
```suggestion
for (size_t i = 0; i < nGeneratedNames; ++i)
generatedNamesVec.push_back(unwrap(generatedNames[i]));
```
https://github.com/llvm/llvm-project/pull/177782
More information about the Mlir-commits
mailing list