[Mlir-commits] [mlir] f992f97 - [MLIR][Python] Support dialect conversion in python bindings (#177782)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jan 30 20:37:54 PST 2026


Author: Twice
Date: 2026-01-31T12:37:49+08:00
New Revision: f992f9719fe13c9ed8bf8e3571d190a69e0e5593

URL: https://github.com/llvm/llvm-project/commit/f992f9719fe13c9ed8bf8e3571d190a69e0e5593
DIFF: https://github.com/llvm/llvm-project/commit/f992f9719fe13c9ed8bf8e3571d190a69e0e5593.diff

LOG: [MLIR][Python] Support dialect conversion in python bindings (#177782)

This PR adds dialect conversion support to the MLIR Python bindings.
Because it introduces a number of new APIs, it’s a fairly large PR. It
mainly includes the following parts:

* Add a set of types and APIs to the C API, including
`MlirConversionTarget`, `MlirConversionPattern`, `MlirTypeConverter`,
`MlirConversionPatternRewriter`, and others.
* Add the corresponding types and APIs to the Python bindings.
* Extend `mlir-tblgen` with codegen for Python adaptor classes, which
generates an adaptor class for each op.

Note that this PR only adds support for 1-to-1 conversions; 1-to-N
type/value conversions are not supported yet.

---------

Co-authored-by: Maksim Levental <maksim.levental at gmail.com>

Added: 
    

Modified: 
    mlir/include/mlir-c/Rewrite.h
    mlir/include/mlir/Bindings/Python/Globals.h
    mlir/include/mlir/Bindings/Python/IRCore.h
    mlir/include/mlir/CAPI/Rewrite.h
    mlir/include/mlir/Transforms/DialectConversion.h
    mlir/lib/Bindings/Python/Globals.cpp
    mlir/lib/Bindings/Python/IRCore.cpp
    mlir/lib/Bindings/Python/Rewrite.cpp
    mlir/lib/CAPI/Transforms/Rewrite.cpp
    mlir/test/mlir-tblgen/op-python-bindings.td
    mlir/test/python/rewrite.py
    mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h
index 26f7f08535b41..b4f93fd5a9b78 100644
--- a/mlir/include/mlir-c/Rewrite.h
+++ b/mlir/include/mlir-c/Rewrite.h
@@ -59,6 +59,12 @@ typedef enum {
 DEFINE_C_API_STRUCT(MlirRewritePatternSet, void);
 DEFINE_C_API_STRUCT(MlirPatternRewriter, void);
 DEFINE_C_API_STRUCT(MlirRewritePattern, const void);
+/// Handle types for the dialect conversion API declared below.
+DEFINE_C_API_STRUCT(MlirConversionTarget, void);
+DEFINE_C_API_STRUCT(MlirConversionPattern, const void);
+DEFINE_C_API_STRUCT(MlirTypeConverter, void);
+DEFINE_C_API_STRUCT(MlirConversionPatternRewriter, void);
+DEFINE_C_API_STRUCT(MlirConversionConfig, void);
 
 //===----------------------------------------------------------------------===//
 /// RewriterBase API inherited from OpBuilder
@@ -423,6 +428,52 @@ MLIR_CAPI_EXPORTED void
 mlirWalkAndApplyPatterns(MlirOperation op,
                          MlirFrozenRewritePatternSet patterns);
 
+/// Apply a partial conversion on the given operation.
+MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPartialConversion(
+    MlirOperation op, MlirConversionTarget target,
+    MlirFrozenRewritePatternSet patterns, MlirConversionConfig config);
+
+/// Apply a full conversion on the given operation.
+MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyFullConversion(
+    MlirOperation op, MlirConversionTarget target,
+    MlirFrozenRewritePatternSet patterns, MlirConversionConfig config);
+
+//===----------------------------------------------------------------------===//
+/// ConversionConfig API
+//===----------------------------------------------------------------------===//
+
+/// Create a default ConversionConfig.
+MLIR_CAPI_EXPORTED MlirConversionConfig mlirConversionConfigCreate(void);
+
+/// Destroy the given ConversionConfig.
+MLIR_CAPI_EXPORTED void
+mlirConversionConfigDestroy(MlirConversionConfig config);
+
+/// Folding behavior during dialect conversion, set per ConversionConfig.
+typedef enum {
+  MLIR_DIALECT_CONVERSION_FOLDING_MODE_NEVER,
+  MLIR_DIALECT_CONVERSION_FOLDING_MODE_BEFORE_PATTERNS,
+  MLIR_DIALECT_CONVERSION_FOLDING_MODE_AFTER_PATTERNS,
+} MlirDialectConversionFoldingMode;
+
+/// Set the folding mode for the given ConversionConfig.
+MLIR_CAPI_EXPORTED void
+mlirConversionConfigSetFoldingMode(MlirConversionConfig config,
+                                   MlirDialectConversionFoldingMode mode);
+
+/// Get the folding mode for the given ConversionConfig.
+MLIR_CAPI_EXPORTED MlirDialectConversionFoldingMode
+mlirConversionConfigGetFoldingMode(MlirConversionConfig config);
+
+/// Enable or disable building materializations during conversion.
+MLIR_CAPI_EXPORTED void
+mlirConversionConfigEnableBuildMaterializations(MlirConversionConfig config,
+                                                bool enable);
+
+/// Check if building materializations during conversion is enabled.
+MLIR_CAPI_EXPORTED bool
+mlirConversionConfigIsBuildMaterializationsEnabled(MlirConversionConfig config);
+
 //===----------------------------------------------------------------------===//
 /// PatternRewriter API
 //===----------------------------------------------------------------------===//
@@ -431,6 +481,107 @@ mlirWalkAndApplyPatterns(MlirOperation op,
 MLIR_CAPI_EXPORTED MlirRewriterBase
 mlirPatternRewriterAsBase(MlirPatternRewriter rewriter);
 
+//===----------------------------------------------------------------------===//
+/// ConversionPatternRewriter API
+//===----------------------------------------------------------------------===//
+
+/// Cast the ConversionPatternRewriter to a PatternRewriter.
+MLIR_CAPI_EXPORTED MlirPatternRewriter
+mlirConversionPatternRewriterAsPatternRewriter(
+    MlirConversionPatternRewriter rewriter);
+
+//===----------------------------------------------------------------------===//
+/// ConversionTarget API
+//===----------------------------------------------------------------------===//
+
+/// Create an empty ConversionTarget for the given context.
+MLIR_CAPI_EXPORTED MlirConversionTarget
+mlirConversionTargetCreate(MlirContext context);
+
+/// Destroy the given ConversionTarget.
+MLIR_CAPI_EXPORTED void
+mlirConversionTargetDestroy(MlirConversionTarget target);
+
+/// Register the operation named `opName` as legal.
+MLIR_CAPI_EXPORTED void
+mlirConversionTargetAddLegalOp(MlirConversionTarget target,
+                               MlirStringRef opName);
+
+/// Register the operation named `opName` as illegal.
+MLIR_CAPI_EXPORTED void
+mlirConversionTargetAddIllegalOp(MlirConversionTarget target,
+                                 MlirStringRef opName);
+
+/// Register all operations of the dialect named `dialectName` as legal.
+MLIR_CAPI_EXPORTED void
+mlirConversionTargetAddLegalDialect(MlirConversionTarget target,
+                                    MlirStringRef dialectName);
+
+/// Register all operations of the dialect named `dialectName` as illegal.
+MLIR_CAPI_EXPORTED void
+mlirConversionTargetAddIllegalDialect(MlirConversionTarget target,
+                                      MlirStringRef dialectName);
+
+//===----------------------------------------------------------------------===//
+/// TypeConverter API
+//===----------------------------------------------------------------------===//
+
+/// Create a TypeConverter; release it with mlirTypeConverterDestroy.
+MLIR_CAPI_EXPORTED MlirTypeConverter mlirTypeConverterCreate(void);
+
+/// Destroy the given TypeConverter.
+MLIR_CAPI_EXPORTED void
+mlirTypeConverterDestroy(MlirTypeConverter typeConverter);
+
+/// Callback type for type conversion functions.
+/// Returns failure or sets convertedType to MlirType{NULL} to indicate failure.
+/// If failure is returned, the converter is allowed to try another
+/// conversion function to perform the conversion.
+typedef MlirLogicalResult (*MlirTypeConverterConversionCallback)(
+    MlirType type, MlirType *convertedType, void *userData);
+
+/// Add a conversion callback; `userData` is not owned by the converter.
+MLIR_CAPI_EXPORTED void
+mlirTypeConverterAddConversion(MlirTypeConverter typeConverter,
+                               MlirTypeConverterConversionCallback convertType,
+                               void *userData);
+
+//===----------------------------------------------------------------------===//
+/// ConversionPattern API
+//===----------------------------------------------------------------------===//
+
+typedef struct {
+  /// Optional constructor for the user data.
+  /// Set to nullptr to disable it.
+  void (*construct)(void *userData);
+  /// Optional destructor for the user data.
+  /// Set to nullptr to disable it.
+  void (*destruct)(void *userData);
+  /// The callback function to match against code rooted at the specified
+  /// operation, and perform the conversion rewrite if the match is successful,
+  /// corresponding to ConversionPattern::matchAndRewrite.
+  MlirLogicalResult (*matchAndRewrite)(MlirConversionPattern pattern,
+                                       MlirOperation op, intptr_t nOperands,
+                                       MlirValue *operands,
+                                       MlirConversionPatternRewriter rewriter,
+                                       void *userData);
+} MlirConversionPatternCallbacks;
+
+/// Create a conversion pattern rooted at operations named `rootName`,
+/// corresponding to mlir::OpConversionPattern.
+MLIR_CAPI_EXPORTED MlirConversionPattern mlirOpConversionPatternCreate(
+    MlirStringRef rootName, unsigned benefit, MlirContext context,
+    MlirTypeConverter typeConverter, MlirConversionPatternCallbacks callbacks,
+    void *userData, size_t nGeneratedNames, MlirStringRef *generatedNames);
+
+/// Get the type converter used by this conversion pattern.
+MLIR_CAPI_EXPORTED MlirTypeConverter
+mlirConversionPatternGetTypeConverter(MlirConversionPattern pattern);
+
+/// Cast the ConversionPattern to a RewritePattern.
+MLIR_CAPI_EXPORTED MlirRewritePattern
+mlirConversionPatternAsRewritePattern(MlirConversionPattern pattern);
+
 //===----------------------------------------------------------------------===//
 /// RewritePattern API
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Bindings/Python/Globals.h b/mlir/include/mlir/Bindings/Python/Globals.h
index 5548a716cbe21..6a722575c4e48 100644
--- a/mlir/include/mlir/Bindings/Python/Globals.h
+++ b/mlir/include/mlir/Bindings/Python/Globals.h
@@ -94,6 +94,12 @@ class MLIR_PYTHON_API_EXPORTED PyGlobals {
   void registerOperationImpl(const std::string &operationName,
                              nanobind::object pyClass, bool replace = false);
 
+  /// Registers an operation adaptor class under the given operation name.
+  /// Raises an exception if the mapping already exists and replace == false.
+  /// This is intended to be called by implementation code.
+  void registerOpAdaptorImpl(const std::string &operationName,
+                             nanobind::object pyClass, bool replace = false);
+
   /// Returns the custom Attribute builder for Attribute kind.
   std::optional<nanobind::callable>
   lookupAttributeBuilder(const std::string &attributeKind);
@@ -117,6 +123,12 @@ class MLIR_PYTHON_API_EXPORTED PyGlobals {
   std::optional<nanobind::object>
   lookupOperationClass(llvm::StringRef operationName);
 
+  /// Looks up the adaptor class registered for `operationName`, or
+  /// std::nullopt if there is none. Note that this may trigger a load of the
+  /// dialect, which can arbitrarily re-enter.
+  std::optional<nanobind::object>
+  lookupOpAdaptorClass(llvm::StringRef operationName);
+
   class MLIR_PYTHON_API_EXPORTED TracebackLoc {
   public:
     bool locTracebacksEnabled();
@@ -184,6 +196,8 @@ class MLIR_PYTHON_API_EXPORTED PyGlobals {
   llvm::StringMap<nanobind::object> dialectClassMap;
   /// Map of full operation name to external operation class object.
   llvm::StringMap<nanobind::object> operationClassMap;
+  /// Map of full operation name to external operation adaptor class object.
+  llvm::StringMap<nanobind::object> opAdaptorClassMap;
   /// Map of attribute ODS name to custom builder.
   llvm::StringMap<nanobind::callable> attributeBuilderMap;
   /// Map of MlirTypeID to custom type caster.

diff  --git a/mlir/include/mlir/Bindings/Python/IRCore.h b/mlir/include/mlir/Bindings/Python/IRCore.h
index a9861c3c72555..4bb49e6bc245d 100644
--- a/mlir/include/mlir/Bindings/Python/IRCore.h
+++ b/mlir/include/mlir/Bindings/Python/IRCore.h
@@ -1828,6 +1828,22 @@ class MLIR_PYTHON_API_EXPORTED PyOpAttributeMap {
   PyOperationRef operation;
 };
 
+/// Base class of operation adaptors: an operand list plus an attribute map.
+class MLIR_PYTHON_API_EXPORTED PyOpAdaptor {
+public:
+  PyOpAdaptor(nanobind::list operands, PyOpAttributeMap attributes)
+      : operands(std::move(operands)), attributes(std::move(attributes)) {}
+  PyOpAdaptor(nanobind::list operands, PyOpView &opView)
+      : operands(std::move(operands)),
+        attributes(opView.getOperation().getRef()) {}
+
+  static void bind(nanobind::module_ &m);
+
+private:
+  nanobind::list operands;
+  PyOpAttributeMap attributes;
+};
+
 MLIR_PYTHON_API_EXPORTED MlirValue getUniqueResult(MlirOperation operation);
 MLIR_PYTHON_API_EXPORTED void populateIRCore(nanobind::module_ &m);
 MLIR_PYTHON_API_EXPORTED void populateRoot(nanobind::module_ &m);

diff  --git a/mlir/include/mlir/CAPI/Rewrite.h b/mlir/include/mlir/CAPI/Rewrite.h
index 9c96d354d4fc9..a172f59b3e3ce 100644
--- a/mlir/include/mlir/CAPI/Rewrite.h
+++ b/mlir/include/mlir/CAPI/Rewrite.h
@@ -19,6 +19,7 @@
 #include "mlir/CAPI/Wrap.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
+#include "mlir/Transforms/DialectConversion.h"
 
 DEFINE_C_API_PTR_METHODS(MlirRewriterBase, mlir::RewriterBase)
 DEFINE_C_API_PTR_METHODS(MlirRewritePattern, const mlir::RewritePattern)
@@ -26,6 +27,13 @@ DEFINE_C_API_PTR_METHODS(MlirRewritePatternSet, mlir::RewritePatternSet)
 DEFINE_C_API_PTR_METHODS(MlirFrozenRewritePatternSet,
                          mlir::FrozenRewritePatternSet)
 DEFINE_C_API_PTR_METHODS(MlirPatternRewriter, mlir::PatternRewriter)
+// Map the dialect conversion C API handles to their C++ counterparts.
+DEFINE_C_API_PTR_METHODS(MlirConversionTarget, mlir::ConversionTarget)
+DEFINE_C_API_PTR_METHODS(MlirConversionPattern, const mlir::ConversionPattern)
+DEFINE_C_API_PTR_METHODS(MlirTypeConverter, mlir::TypeConverter)
+DEFINE_C_API_PTR_METHODS(MlirConversionPatternRewriter,
+                         mlir::ConversionPatternRewriter)
+DEFINE_C_API_PTR_METHODS(MlirConversionConfig, mlir::ConversionConfig)
 
 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
 DEFINE_C_API_PTR_METHODS(MlirPDLPatternModule, mlir::PDLPatternModule)
@@ -33,4 +41,4 @@ DEFINE_C_API_PTR_METHODS(MlirPDLResultList, mlir::PDLResultList)
 DEFINE_C_API_PTR_METHODS(MlirPDLValue, const mlir::PDLValue)
 #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
 
-#endif // MLIR_CAPIREWRITER_H
+#endif // MLIR_CAPI_REWRITE_H

diff  --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 9f449080b0f37..0f67b9eceab59 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -1108,6 +1108,9 @@ class ConversionTarget {
   ConversionTarget(MLIRContext &ctx) : ctx(ctx) {}
   virtual ~ConversionTarget() = default;
 
+  /// Return the context this target was constructed with.
+  MLIRContext &getContext() const { return ctx; }
+
   //===--------------------------------------------------------------------===//
   // Legality Registration
   //===--------------------------------------------------------------------===//

diff  --git a/mlir/lib/Bindings/Python/Globals.cpp b/mlir/lib/Bindings/Python/Globals.cpp
index e2e8693ba45f3..3d7ee3d30656e 100644
--- a/mlir/lib/Bindings/Python/Globals.cpp
+++ b/mlir/lib/Bindings/Python/Globals.cpp
@@ -137,6 +137,20 @@ void PyGlobals::registerOperationImpl(const std::string &operationName,
   found = std::move(pyClass);
 }
 
+// Records `pyClass` as the adaptor class for `operationName`, under the
+// globals mutex. Throws if one is already registered and `replace` is false.
+void PyGlobals::registerOpAdaptorImpl(const std::string &operationName,
+                                      nb::object pyClass, bool replace) {
+  nb::ft_lock_guard lock(mutex);
+  nb::object &found = opAdaptorClassMap[operationName];
+  if (found && !replace) {
+    throw std::runtime_error((llvm::Twine("Operation adaptor of '") +
+                              operationName + "' is already registered.")
+                                 .str());
+  }
+  found = std::move(pyClass);
+}
+
 std::optional<nb::callable>
 PyGlobals::lookupAttributeBuilder(const std::string &attributeKind) {
   nb::ft_lock_guard lock(mutex);
@@ -207,6 +219,24 @@ PyGlobals::lookupOperationClass(llvm::StringRef operationName) {
   return std::nullopt;
 }
 
+std::optional<nb::object>
+PyGlobals::lookupOpAdaptorClass(llvm::StringRef operationName) {
+  // Load the dialect module first; loading it may register adaptor classes.
+  auto split = operationName.split('.');
+  llvm::StringRef dialectNamespace = split.first;
+  if (!loadDialectModule(dialectNamespace))
+    return std::nullopt;
+
+  nb::ft_lock_guard lock(mutex);
+  auto foundIt = opAdaptorClassMap.find(operationName);
+  if (foundIt != opAdaptorClassMap.end()) {
+    assert(foundIt->second && "OpAdaptor is defined");
+    return foundIt->second;
+  }
+  // Not found and loading did not yield a registration.
+  return std::nullopt;
+}
+
 bool PyGlobals::TracebackLoc::locTracebacksEnabled() {
   nanobind::ft_lock_guard lock(mutex);
   return locTracebackEnabled_;

diff  --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index c7653f1ee4b15..dda4a027f0a30 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -2505,6 +2505,23 @@ void PyOpAttributeMap::bind(nb::module_ &m) {
           "Returns a list of `(name, attribute)` tuples.");
 }
 
+// Exposed as `OpAdaptor`; used for ops without a registered adaptor class.
+void PyOpAdaptor::bind(nb::module_ &m) {
+  nb::class_<PyOpAdaptor>(m, "OpAdaptor")
+      .def(nb::init<nb::list, PyOpAttributeMap>(),
+           "Creates an OpAdaptor with the given operands and attributes.",
+           "operands"_a, "attributes"_a)
+      .def(nb::init<nb::list, PyOpView &>(),
+           "Creates an OpAdaptor with the given operands and operation view.",
+           "operands"_a, "opview"_a)
+      .def_prop_ro(
+          "operands", [](PyOpAdaptor &self) { return self.operands; },
+          "Returns the operands of the adaptor.")
+      .def_prop_ro(
+          "attributes", [](PyOpAdaptor &self) { return self.attributes; },
+          "Returns the attributes of the adaptor.");
+}
+
 } // namespace MLIR_BINDINGS_PYTHON_DOMAIN
 } // namespace python
 } // namespace mlir
@@ -2755,6 +2771,28 @@ void populateRoot(nb::module_ &m) {
       "dialect_class"_a, nb::kw_only(), "replace"_a = false,
       "Produce a class decorator for registering an Operation class as part of "
       "a dialect");
+  m.def(
+      "register_op_adaptor",
+      [](const nb::type_object &opClass, bool replace) -> nb::object {
+        return nb::cpp_function(
+            [opClass,
+             replace](nb::type_object adaptorClass) -> nb::type_object {
+              std::string operationName =
+                  nb::cast<std::string>(adaptorClass.attr("OPERATION_NAME"));
+              PyGlobals::get().registerOpAdaptorImpl(operationName,
+                                                     adaptorClass, replace);
+              // Expose the adaptor class on the op class as `opClass.Adaptor`.
+              opClass.attr("Adaptor") = adaptorClass;
+              return adaptorClass;
+            });
+      },
+      // clang-format off
+      nb::sig("def register_op_adaptor(op_class: type, *, replace: bool = False) "
+        "-> typing.Callable[[type[T]], type[T]]"),
+      // clang-format on
+      "op_class"_a, nb::kw_only(), "replace"_a = false,
+      "Produce a class decorator for registering an OpAdaptor class for an "
+      "operation.");
   m.def(
       MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR,
       [](PyTypeID mlirTypeID, bool replace) -> nb::object {
@@ -3933,6 +3971,9 @@ void populateIRCore(nb::module_ &m) {
       "context"_a = nb::none(),
       "Parses a specific, generated OpView based on class level attributes.");
 
+  // Default adaptor class; per-op adaptors are added via register_op_adaptor.
+  PyOpAdaptor::bind(m);
+
   //----------------------------------------------------------------------------
   // Mapping of PyRegion.
   //----------------------------------------------------------------------------

diff  --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 2b649f79c5982..bd95adbca5274 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -11,6 +11,7 @@
 #include "mlir-c/IR.h"
 #include "mlir-c/Rewrite.h"
 #include "mlir-c/Support.h"
+#include "mlir/Bindings/Python/Globals.h"
 #include "mlir/Bindings/Python/IRCore.h"
 // clang-format off
 #include "mlir/Bindings/Python/Nanobind.h"
@@ -18,6 +19,7 @@
 // clang-format on
 #include "mlir/Config/mlir-config.h"
 #include "nanobind/nanobind.h"
+#include <type_traits>
 
 namespace nb = nanobind;
 using namespace mlir;
@@ -62,9 +64,128 @@ class PyPatternRewriter {
   PyMlirContextRef ctx;
 };
 
-struct PyMlirPDLResultList : MlirPDLResultList {};
+/// Python-side wrapper for a ConversionPatternRewriter, stored as its
+/// PatternRewriter view. It must inherit publicly: the binding registers
+/// PyPatternRewriter as its Python base class.
+class PyConversionPatternRewriter : public PyPatternRewriter {
+public:
+  PyConversionPatternRewriter(MlirConversionPatternRewriter rewriter)
+      : PyPatternRewriter(
+            mlirConversionPatternRewriterAsPatternRewriter(rewriter)) {}
+};
+
+/// Owning wrapper around an MlirConversionTarget. The target is destroyed
+/// together with this object, so copying is disabled to avoid destroying
+/// it twice.
+class PyConversionTarget {
+public:
+  PyConversionTarget(MlirContext context)
+      : target(mlirConversionTargetCreate(context)) {}
+  PyConversionTarget(const PyConversionTarget &) = delete;
+  PyConversionTarget &operator=(const PyConversionTarget &) = delete;
+  ~PyConversionTarget() { mlirConversionTargetDestroy(target); }
+
+  /// Mark the operation named `opName` as legal.
+  void addLegalOp(const std::string &opName) {
+    mlirConversionTargetAddLegalOp(
+        target, mlirStringRefCreate(opName.data(), opName.size()));
+  }
+
+  /// Mark the operation named `opName` as illegal.
+  void addIllegalOp(const std::string &opName) {
+    mlirConversionTargetAddIllegalOp(
+        target, mlirStringRefCreate(opName.data(), opName.size()));
+  }
+
+  /// Mark all operations of the dialect named `dialectName` as legal.
+  void addLegalDialect(const std::string &dialectName) {
+    mlirConversionTargetAddLegalDialect(
+        target, mlirStringRefCreate(dialectName.data(), dialectName.size()));
+  }
+
+  /// Mark all operations of the dialect named `dialectName` as illegal.
+  void addIllegalDialect(const std::string &dialectName) {
+    mlirConversionTargetAddIllegalDialect(
+        target, mlirStringRefCreate(dialectName.data(), dialectName.size()));
+  }
+
+  MlirConversionTarget get() { return target; }
+
+private:
+  MlirConversionTarget target;
+};
+
+/// Wrapper around an MlirTypeConverter. If `owner` is set, the converter
+/// was created here and is destroyed with this object; otherwise it is a
+/// borrowed view (e.g. the converter of a ConversionPattern).
+class PyTypeConverter {
+public:
+  PyTypeConverter() : typeConverter(mlirTypeConverterCreate()), owner(true) {}
+  PyTypeConverter(MlirTypeConverter typeConverter)
+      : typeConverter(typeConverter), owner(false) {}
+  // Copying would destroy an owned converter twice; moving transfers
+  // ownership (temporaries are moved into Python objects).
+  PyTypeConverter(const PyTypeConverter &) = delete;
+  PyTypeConverter &operator=(const PyTypeConverter &) = delete;
+  PyTypeConverter(PyTypeConverter &&other) noexcept
+      : typeConverter(other.typeConverter), owner(other.owner),
+        conversions(std::move(other.conversions)) {
+    other.owner = false;
+  }
+  PyTypeConverter &operator=(PyTypeConverter &&) = delete;
+  ~PyTypeConverter() {
+    if (owner)
+      mlirTypeConverterDestroy(typeConverter);
+  }
+
+  /// Register `convert` as a type conversion. It is called with a Type and
+  /// returns either the converted Type or None to signal failure. A strong
+  /// reference is kept in `conversions` because the C API only stores the
+  /// raw PyObject pointer as user data.
+  void addConversion(const nb::callable &convert) {
+    conversions.push_back(convert);
+    mlirTypeConverterAddConversion(
+        typeConverter,
+        [](MlirType type, MlirType *converted,
+           void *userData) -> MlirLogicalResult {
+          nb::handle f = nb::handle(static_cast<PyObject *>(userData));
+          auto ctx = PyMlirContext::forContext(mlirTypeGetContext(type));
+          nb::object res = f(PyType(ctx, type).maybeDownCast());
+          if (res.is_none())
+            return mlirLogicalResultFailure();
+
+          *converted = nb::cast<PyType>(res).get();
+          return mlirLogicalResultSuccess();
+        },
+        convert.ptr());
+  }
+
+  MlirTypeConverter get() { return typeConverter; }
+
+private:
+  MlirTypeConverter typeConverter;
+  bool owner;
+  /// Strong references to the registered conversion callables.
+  std::vector<nb::object> conversions;
+};
+
+/// Non-owning wrapper around an MlirConversionPattern.
+class PyConversionPattern {
+public:
+  PyConversionPattern(MlirConversionPattern pattern) : pattern(pattern) {}
+
+  /// Returns a borrowed (non-owning) view of the pattern's type converter.
+  PyTypeConverter getTypeConverter() {
+    return PyTypeConverter(mlirConversionPatternGetTypeConverter(pattern));
+  }
+
+private:
+  MlirConversionPattern pattern;
+};
 
 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
+struct PyMlirPDLResultList : MlirPDLResultList {};
+
 static nb::object objectFromPDLValue(MlirPDLValue value) {
   if (MlirValue v = mlirPDLValueAsValue(value); !mlirValueIsNull(v))
     return nb::cast(v);
@@ -216,6 +303,57 @@ class PyRewritePatternSet {
     mlirRewritePatternSetAdd(set, pattern);
   }
 
+  /// Add a conversion pattern rooted at the op named `rootName` to `set`.
+  /// When the pattern is tried on an op, `matchAndRewrite` is invoked with
+  /// (op view, adaptor instance, type converter, rewriter), and its return
+  /// value is turned into a LogicalResult by logicalResultFromObject.
+  void addConversion(MlirStringRef rootName, unsigned benefit,
+                     const nb::callable &matchAndRewrite,
+                     PyTypeConverter &typeConverter) {
+    MlirConversionPatternCallbacks callbacks;
+    // The callable is the pattern's user data: construct/destruct take and
+    // release a reference to it so it stays alive as long as the pattern.
+    callbacks.construct = [](void *userData) {
+      nb::handle(static_cast<PyObject *>(userData)).inc_ref();
+    };
+    callbacks.destruct = [](void *userData) {
+      nb::handle(static_cast<PyObject *>(userData)).dec_ref();
+    };
+    callbacks.matchAndRewrite =
+        [](MlirConversionPattern pattern, MlirOperation op, intptr_t nOperands,
+           MlirValue *operands, MlirConversionPatternRewriter rewriter,
+           void *userData) -> MlirLogicalResult {
+      nb::handle f(static_cast<PyObject *>(userData));
+
+      PyMlirContextRef ctx =
+          PyMlirContext::forContext(mlirOperationGetContext(op));
+      nb::object opView = PyOperation::forOperation(ctx, op)->createOpView();
+
+      std::vector<MlirValue> operandsVec(operands, operands + nOperands);
+      // Prefer the adaptor class registered for this op; fall back to the
+      // generic OpAdaptor base class.
+      nb::object adaptorCls =
+          PyGlobals::get()
+              .lookupOpAdaptorClass(
+                  unwrap(mlirIdentifierStr(mlirOperationGetName(op))))
+              .value_or(nb::borrow(nb::type<PyOpAdaptor>()));
+
+      // NOTE(review): if `f` raises, the Python error escapes as a C++
+      // exception through the C API callback; confirm this is intended.
+      nb::object res = f(opView, adaptorCls(operandsVec, opView),
+                         PyConversionPattern(pattern).getTypeConverter(),
+                         PyConversionPatternRewriter(rewriter));
+      return logicalResultFromObject(res);
+    };
+    MlirConversionPattern pattern = mlirOpConversionPatternCreate(
+        rootName, benefit, ctx, typeConverter.get(), callbacks,
+        matchAndRewrite.ptr(),
+        /* nGeneratedNames */ 0,
+        /* generatedNames */ nullptr);
+    mlirRewritePatternSetAdd(set,
+                             mlirConversionPatternAsRewritePattern(pattern));
+  }
+
   PyFrozenRewritePatternSet freeze() {
     MlirRewritePatternSet s = set;
     set.ptr = nullptr;
@@ -324,6 +452,52 @@ class PyGreedyRewriteConfig {
   }
 };
 
+/// Python-facing mirror of MlirDialectConversionFoldingMode; every
+/// enumerator has the same underlying value as its C API counterpart.
+enum class PyDialectConversionFoldingMode : std::underlying_type_t<
+    MlirDialectConversionFoldingMode> {
+  Never = MLIR_DIALECT_CONVERSION_FOLDING_MODE_NEVER,
+  BeforePatterns = MLIR_DIALECT_CONVERSION_FOLDING_MODE_BEFORE_PATTERNS,
+  AfterPatterns = MLIR_DIALECT_CONVERSION_FOLDING_MODE_AFTER_PATTERNS,
+};
+
+/// Python wrapper around an MlirConversionConfig. The config lives in a
+/// shared_ptr whose deleter calls mlirConversionConfigDestroy, so copies of
+/// this wrapper share one underlying config.
+class PyConversionConfig {
+public:
+  PyConversionConfig()
+      : config(mlirConversionConfigCreate().ptr,
+               PyConversionConfig::customDeleter) {}
+
+  /// Returns a non-owning handle to the wrapped config.
+  MlirConversionConfig get() { return MlirConversionConfig{config.get()}; }
+
+  void setFoldingMode(PyDialectConversionFoldingMode mode) {
+    mlirConversionConfigSetFoldingMode(get(),
+                                       MlirDialectConversionFoldingMode(mode));
+  }
+
+  PyDialectConversionFoldingMode getFoldingMode() {
+    return PyDialectConversionFoldingMode(
+        mlirConversionConfigGetFoldingMode(get()));
+  }
+
+  void enableBuildMaterializations(bool enabled) {
+    mlirConversionConfigEnableBuildMaterializations(get(), enabled);
+  }
+
+  bool isBuildMaterializationsEnabled() {
+    return mlirConversionConfigIsBuildMaterializationsEnabled(get());
+  }
+
+private:
+  std::shared_ptr<void> config;
+  static void customDeleter(void *c) {
+    mlirConversionConfigDestroy(MlirConversionConfig{c});
+  }
+};
+
 /// Create the `mlir.rewrite` here.
 void populateRewriteSubmodule(nb::module_ &m) {
   // Enum definitions
@@ -337,6 +505,13 @@ void populateRewriteSubmodule(nb::module_ &m) {
       .value("DISABLED", PyGreedySimplifyRegionLevel::DISABLED)
       .value("NORMAL", PyGreedySimplifyRegionLevel::NORMAL)
       .value("AGGRESSIVE", PyGreedySimplifyRegionLevel::AGGRESSIVE);
+
+  // Value type of the ConversionConfig.folding_mode property.
+  nb::enum_<PyDialectConversionFoldingMode>(m, "DialectConversionFoldingMode")
+      .value("NEVER", PyDialectConversionFoldingMode::Never)
+      .value("BEFORE_PATTERNS", PyDialectConversionFoldingMode::BeforePatterns)
+      .value("AFTER_PATTERNS", PyDialectConversionFoldingMode::AfterPatterns);
+
   //----------------------------------------------------------------------------
   // Mapping of the PatternRewriter
   //----------------------------------------------------------------------------
@@ -409,9 +583,104 @@ void populateRewriteSubmodule(nb::module_ &m) {
                   If possible, the operation is cast to its corresponding OpView subclass
                   before being passed to the callable.
               benefit: The benefit of the pattern, defaulting to 1.)")
+      .def(
+          "add_conversion",
+          [](PyRewritePatternSet &self, nb::handle root, const nb::callable &fn,
+             PyTypeConverter &typeConverter, unsigned benefit) {
+            // `root` may be an operation name string or an OpView subclass
+            // (anything with an OPERATION_NAME attribute), as documented below.
+            std::string opName =
+                nb::isinstance<nb::str>(root)
+                    ? nb::cast<std::string>(root)
+                    : nb::cast<std::string>(root.attr("OPERATION_NAME"));
+            self.addConversion(
+                mlirStringRefCreate(opName.data(), opName.size()), benefit, fn,
+                typeConverter);
+          },
+          "root"_a, "fn"_a, "type_converter"_a, "benefit"_a = 1,
+          // The created pattern keeps the type converter's handle, so keep the
+          // converter (arg 4) alive at least as long as this set (arg 1).
+          // NOTE(review): this does not cover sets produced by freeze().
+          nb::keep_alive<1, 4>(),
+          R"(
+            Add a new conversion pattern on the specified root operation,
+            using the provided callable for matching and rewriting,
+            and assign it the given benefit.
+
+            Args:
+              root: The root operation to which this pattern applies.
+                    This may be either an OpView subclass (e.g., ``arith.AddIOp``) or
+                    an operation name string (e.g., ``"arith.addi"``).
+              fn: The callable to use for matching and rewriting,
+                  which takes an operation, its adaptor,
+                  the type converter and a pattern rewriter as arguments.
+                  The match is considered successful iff the callable returns
+                  a value where ``bool(value)`` is ``False`` (e.g. ``None``).
+                  If possible, the operation is cast to its corresponding OpView subclass
+                  before being passed to the callable.
+              type_converter: The type converter to convert types in the IR.
+              benefit: The benefit of the pattern, defaulting to 1.)")
       .def("freeze", &PyRewritePatternSet::freeze,
            "Freeze the pattern set into a frozen one.");
 
+  nb::class_<PyConversionPatternRewriter, PyPatternRewriter>(
+      m, "ConversionPatternRewriter");
+
+  nb::class_<PyConversionTarget>(m, "ConversionTarget")
+      .def(
+          "__init__",
+          [](PyConversionTarget &self, DefaultingPyMlirContext context) {
+            new (&self) PyConversionTarget(context.get()->get());
+          },
+          "context"_a = nb::none())
+      .def(
+          "add_legal_op",
+          [](PyConversionTarget &self, const nb::args &ops) {
+            for (auto op : ops) {
+              std::string opName =
+                  nb::cast<std::string>(op.attr("OPERATION_NAME"));
+              self.addLegalOp(opName);
+            }
+          },
+          "ops"_a, "Mark the given operations as legal.")
+      .def(
+          "add_illegal_op",
+          [](PyConversionTarget &self, const nb::args &ops) {
+            for (auto op : ops) {
+              std::string opName =
+                  nb::cast<std::string>(op.attr("OPERATION_NAME"));
+              self.addIllegalOp(opName);
+            }
+          },
+          "ops"_a, "Mark the given operations as illegal.")
+      .def(
+          "add_legal_dialect",
+          [](PyConversionTarget &self, const nb::args &dialects) {
+            for (auto dialect : dialects) {
+              std::string dialectName =
+                  nb::cast<std::string>(dialect.attr("DIALECT_NAMESPACE"));
+              self.addLegalDialect(dialectName);
+            }
+          },
+          "dialects"_a, "Mark the given dialects as legal.")
+      .def(
+          "add_illegal_dialect",
+          [](PyConversionTarget &self, const nb::args &dialects) {
+            for (auto dialect : dialects) {
+              std::string dialectName =
+                  nb::cast<std::string>(dialect.attr("DIALECT_NAMESPACE"));
+              self.addIllegalDialect(dialectName);
+            }
+          },
+          "dialects"_a, "Mark the given dialects as illegal.");
+
+  nb::class_<PyTypeConverter>(m, "TypeConverter")
+      .def(nb::init<>(), "Create a new TypeConverter.")
+      // The converter only stores a raw pointer to `convert`, so keep the
+      // callable (arg 2) alive at least as long as the converter (arg 1).
+      .def("add_conversion", &PyTypeConverter::addConversion, "convert"_a,
+           nb::keep_alive<1, 2>(), "Register a type conversion function.");
+
   //----------------------------------------------------------------------------
   // Mapping of the PDLResultList and PDLModule
   //----------------------------------------------------------------------------
@@ -498,6 +757,17 @@ void populateRewriteSubmodule(nb::module_ &m) {
                    &PyGreedyRewriteConfig::enableConstantCSE,
                    "Enable or disable constant CSE");
 
+  {
+    // Exposes PyConversionConfig to Python as `ConversionConfig`, with
+    // read/write properties for its folding mode and materialization flag.
+    nb::class_<PyConversionConfig> conversionConfigClass(m, "ConversionConfig");
+    conversionConfigClass.def(nb::init<>(),
+                              "Create a conversion config with defaults");
+    conversionConfigClass.def_prop_rw(
+        "folding_mode", &PyConversionConfig::getFoldingMode,
+        &PyConversionConfig::setFoldingMode,
+        "folding behavior during dialect conversion");
+    conversionConfigClass.def_prop_rw(
+        "build_materializations",
+        &PyConversionConfig::isBuildMaterializationsEnabled,
+        &PyConversionConfig::enableBuildMaterializations,
+        "Whether the dialect conversion attempts to build "
+        "source/target materializations");
+  }
+
   nb::class_<PyFrozenRewritePatternSet>(m, "FrozenRewritePatternSet")
       .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR,
                    &PyFrozenRewritePatternSet::getCapsule)
@@ -539,7 +809,35 @@ void populateRewriteSubmodule(nb::module_ &m) {
           },
           "op"_a, "set"_a,
           "Applies the given patterns to the given op by a fast walk-based "
-          "driver.");
+          "driver.")
+      .def(
+          "apply_partial_conversion",
+          [](PyOperationBase &op, PyConversionTarget &target,
+             PyFrozenRewritePatternSet &set,
+             std::optional<PyConversionConfig> config) {
+            // Fall back to a default config when the caller passed None.
+            // emplace() constructs it in place instead of moving/copying a
+            // temporary (PyConversionConfig presumably wraps an owned
+            // MlirConversionConfig).
+            // NOTE(review): a user-supplied config arrives as a copy held by
+            // the optional; confirm PyConversionConfig copies safely.
+            if (!config)
+              config.emplace();
+            MlirLogicalResult status = mlirApplyPartialConversion(
+                op.getOperation(), target.get(), set.get(), config->get());
+            if (mlirLogicalResultIsFailure(status))
+              throw std::runtime_error("partial conversion failed");
+          },
+          "op"_a, "target"_a, "set"_a, "config"_a = nb::none(),
+          "Applies a partial conversion on the given operation.")
+      .def(
+          "apply_full_conversion",
+          [](PyOperationBase &op, PyConversionTarget &target,
+             PyFrozenRewritePatternSet &set,
+             std::optional<PyConversionConfig> config) {
+            // Same defaulting as apply_partial_conversion: build in place.
+            if (!config)
+              config.emplace();
+            MlirLogicalResult status = mlirApplyFullConversion(
+                op.getOperation(), target.get(), set.get(), config->get());
+            if (mlirLogicalResultIsFailure(status))
+              throw std::runtime_error("full conversion failed");
+          },
+          "op"_a, "target"_a, "set"_a, "config"_a = nb::none(),
+          "Applies a full conversion on the given operation.");
 }
 } // namespace MLIR_BINDINGS_PYTHON_DOMAIN
 } // namespace python

diff  --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp
index 65701ab070508..5900f08ae1730 100644
--- a/mlir/lib/CAPI/Transforms/Rewrite.cpp
+++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir-c/Rewrite.h"
 
+#include "mlir-c/Support.h"
 #include "mlir-c/Transforms.h"
 #include "mlir/CAPI/IR.h"
 #include "mlir/CAPI/Rewrite.h"
@@ -17,6 +18,7 @@
 #include "mlir/IR/PDLPatternMatch.h.inc"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
+#include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/WalkPatternRewriteDriver.h"
+#include "llvm/Support/ErrorHandling.h"
 
@@ -441,6 +443,73 @@ void mlirWalkAndApplyPatterns(MlirOperation op,
   mlir::walkAndApplyPatterns(unwrap(op), *unwrap(patterns));
 }
 
+// Runs the partial dialect conversion driver on `op`. The C handles are
+// unwrapped and dereferenced into the C++ objects the driver takes.
+MlirLogicalResult mlirApplyPartialConversion(
+    MlirOperation op, MlirConversionTarget target,
+    MlirFrozenRewritePatternSet patterns, MlirConversionConfig config) {
+  mlir::LogicalResult result = mlir::applyPartialConversion(
+      unwrap(op), *unwrap(target), *unwrap(patterns), *unwrap(config));
+  return wrap(result);
+}
+
+// Runs the full dialect conversion driver on `op`, with the same argument
+// unwrapping as mlirApplyPartialConversion.
+MlirLogicalResult mlirApplyFullConversion(
+    MlirOperation op, MlirConversionTarget target,
+    MlirFrozenRewritePatternSet patterns, MlirConversionConfig config) {
+  mlir::LogicalResult result = mlir::applyFullConversion(
+      unwrap(op), *unwrap(target), *unwrap(patterns), *unwrap(config));
+  return wrap(result);
+}
+
+//===----------------------------------------------------------------------===//
+/// ConversionConfig API
+//===----------------------------------------------------------------------===//
+
+// Heap-allocates a default-constructed ConversionConfig. The handle is
+// released with mlirConversionConfigDestroy.
+MlirConversionConfig mlirConversionConfigCreate(void) {
+  auto *cppConfig = new mlir::ConversionConfig();
+  return wrap(cppConfig);
+}
+
+// Deletes a config previously returned by mlirConversionConfigCreate.
+void mlirConversionConfigDestroy(MlirConversionConfig config) {
+  mlir::ConversionConfig *cppConfig = unwrap(config);
+  delete cppConfig;
+}
+
+// Stores the C++ counterpart of `mode` into the config. A C enum crossing the
+// API boundary can hold values other than its enumerators. Each case therefore
+// assigns directly rather than through an uninitialized local, and any
+// unlisted value is treated as a caller bug.
+void mlirConversionConfigSetFoldingMode(MlirConversionConfig config,
+                                        MlirDialectConversionFoldingMode mode) {
+  mlir::ConversionConfig *cppConfig = unwrap(config);
+  switch (mode) {
+  case MLIR_DIALECT_CONVERSION_FOLDING_MODE_NEVER:
+    cppConfig->foldingMode = mlir::DialectConversionFoldingMode::Never;
+    return;
+  case MLIR_DIALECT_CONVERSION_FOLDING_MODE_BEFORE_PATTERNS:
+    cppConfig->foldingMode =
+        mlir::DialectConversionFoldingMode::BeforePatterns;
+    return;
+  case MLIR_DIALECT_CONVERSION_FOLDING_MODE_AFTER_PATTERNS:
+    cppConfig->foldingMode = mlir::DialectConversionFoldingMode::AfterPatterns;
+    return;
+  }
+  llvm_unreachable("unknown MlirDialectConversionFoldingMode");
+}
+
+// Maps the config's C++ folding mode back to the C enum. The switch covers
+// every enumerator. The trailing llvm_unreachable keeps control from falling
+// off the end of this non-void function.
+MlirDialectConversionFoldingMode
+mlirConversionConfigGetFoldingMode(MlirConversionConfig config) {
+  switch (unwrap(config)->foldingMode) {
+  case mlir::DialectConversionFoldingMode::Never:
+    return MLIR_DIALECT_CONVERSION_FOLDING_MODE_NEVER;
+  case mlir::DialectConversionFoldingMode::BeforePatterns:
+    return MLIR_DIALECT_CONVERSION_FOLDING_MODE_BEFORE_PATTERNS;
+  case mlir::DialectConversionFoldingMode::AfterPatterns:
+    return MLIR_DIALECT_CONVERSION_FOLDING_MODE_AFTER_PATTERNS;
+  }
+  llvm_unreachable("unknown mlir::DialectConversionFoldingMode");
+}
+
+// Sets whether the conversion driver may build source/target
+// materializations.
+void mlirConversionConfigEnableBuildMaterializations(
+    MlirConversionConfig config, bool enable) {
+  mlir::ConversionConfig &cppConfig = *unwrap(config);
+  cppConfig.buildMaterializations = enable;
+}
+
+// Reads back the flag set by mlirConversionConfigEnableBuildMaterializations.
+bool mlirConversionConfigIsBuildMaterializationsEnabled(
+    MlirConversionConfig config) {
+  const mlir::ConversionConfig &cppConfig = *unwrap(config);
+  return cppConfig.buildMaterializations;
+}
+
 //===----------------------------------------------------------------------===//
 /// PatternRewriter API
 //===----------------------------------------------------------------------===//
@@ -449,6 +518,145 @@ MlirRewriterBase mlirPatternRewriterAsBase(MlirPatternRewriter rewriter) {
   return wrap(static_cast<mlir::RewriterBase *>(unwrap(rewriter)));
 }
 
+//===----------------------------------------------------------------------===//
+/// ConversionPatternRewriter API
+//===----------------------------------------------------------------------===//
+
+MlirPatternRewriter mlirConversionPatternRewriterAsPatternRewriter(
+    MlirConversionPatternRewriter rewriter) {
+  // View the conversion rewriter through its PatternRewriter base.
+  auto *patternRewriter = static_cast<mlir::PatternRewriter *>(unwrap(rewriter));
+  return wrap(patternRewriter);
+}
+
+//===----------------------------------------------------------------------===//
+/// ConversionTarget API
+//===----------------------------------------------------------------------===//
+
+// Heap-allocates a ConversionTarget bound to `context`. It is released with
+// mlirConversionTargetDestroy.
+MlirConversionTarget mlirConversionTargetCreate(MlirContext context) {
+  mlir::MLIRContext *ctx = unwrap(context);
+  return wrap(new mlir::ConversionTarget(*ctx));
+}
+
+void mlirConversionTargetDestroy(MlirConversionTarget target) {
+  mlir::ConversionTarget *cppTarget = unwrap(target);
+  delete cppTarget;
+}
+
+/// Builds the OperationName for `opName` in the context `target` belongs to.
+static mlir::OperationName
+makeTargetOperationName(mlir::ConversionTarget &target, MlirStringRef opName) {
+  return mlir::OperationName(unwrap(opName), &target.getContext());
+}
+
+void mlirConversionTargetAddLegalOp(MlirConversionTarget target,
+                                    MlirStringRef opName) {
+  mlir::ConversionTarget &cppTarget = *unwrap(target);
+  cppTarget.addLegalOp(makeTargetOperationName(cppTarget, opName));
+}
+
+void mlirConversionTargetAddIllegalOp(MlirConversionTarget target,
+                                      MlirStringRef opName) {
+  mlir::ConversionTarget &cppTarget = *unwrap(target);
+  cppTarget.addIllegalOp(makeTargetOperationName(cppTarget, opName));
+}
+
+void mlirConversionTargetAddLegalDialect(MlirConversionTarget target,
+                                         MlirStringRef dialectName) {
+  mlir::StringRef name = unwrap(dialectName);
+  unwrap(target)->addLegalDialect(name);
+}
+
+void mlirConversionTargetAddIllegalDialect(MlirConversionTarget target,
+                                           MlirStringRef dialectName) {
+  mlir::StringRef name = unwrap(dialectName);
+  unwrap(target)->addIllegalDialect(name);
+}
+
+//===----------------------------------------------------------------------===//
+/// TypeConverter API
+//===----------------------------------------------------------------------===//
+
+// Heap-allocates an empty TypeConverter. It is released with
+// mlirTypeConverterDestroy.
+MlirTypeConverter mlirTypeConverterCreate(void) {
+  auto *converter = new mlir::TypeConverter();
+  return wrap(converter);
+}
+
+void mlirTypeConverterDestroy(MlirTypeConverter typeConverter) {
+  mlir::TypeConverter *converter = unwrap(typeConverter);
+  delete converter;
+}
+
+// Registers a C callback as a 1:1 type conversion function.
+void mlirTypeConverterAddConversion(
+    MlirTypeConverter typeConverter,
+    MlirTypeConverterConversionCallback convertType, void *userData) {
+  mlir::TypeConverter *converter = unwrap(typeConverter);
+  converter->addConversion(
+      [convertType, userData](Type type) -> std::optional<Type> {
+        MlirType resultType{nullptr};
+        MlirLogicalResult status =
+            convertType(wrap(type), &resultType, userData);
+        // Callback failure: std::nullopt lets the converter try another
+        // registered conversion function.
+        if (mlirLogicalResultIsFailure(status))
+          return std::nullopt;
+        // Callback success with a null type: hand back a null Type, which
+        // TypeConverter treats as a failed conversion.
+        if (mlirTypeIsNull(resultType))
+          return nullptr;
+        return unwrap(resultType);
+      });
+}
+
+//===----------------------------------------------------------------------===//
+/// ConversionPattern API
+//===----------------------------------------------------------------------===//
+
+namespace mlir {
+
+/// A ConversionPattern whose match-and-rewrite logic is supplied from outside
+/// C++ (e.g. the Python bindings) through MlirConversionPatternCallbacks.
+/// `userData` is passed back to every callback. The optional `construct` and
+/// `destruct` callbacks run in the constructor and destructor respectively.
+class ExternalConversionPattern : public mlir::ConversionPattern {
+public:
+  /// `typeConverter` is dereferenced here and must be non-null.
+  ExternalConversionPattern(MlirConversionPatternCallbacks callbacks,
+                            void *userData, StringRef rootName,
+                            PatternBenefit benefit, MLIRContext *context,
+                            TypeConverter *typeConverter,
+                            ArrayRef<StringRef> generatedNames)
+      : ConversionPattern(*typeConverter, rootName, benefit, context,
+                          generatedNames),
+        callbacks(callbacks), userData(userData) {
+    if (callbacks.construct)
+      callbacks.construct(userData);
+  }
+
+  // The destructor invokes `destruct` on `userData`. A copy would run it a
+  // second time for a single `construct`, so copying is forbidden.
+  ExternalConversionPattern(const ExternalConversionPattern &) = delete;
+  ExternalConversionPattern &
+  operator=(const ExternalConversionPattern &) = delete;
+
+  ~ExternalConversionPattern() override {
+    if (callbacks.destruct)
+      callbacks.destruct(userData);
+  }
+
+  /// Wraps the operands supplied by the conversion driver as C handles and
+  /// forwards everything to the external `matchAndRewrite` callback.
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    std::vector<MlirValue> wrappedOperands;
+    wrappedOperands.reserve(operands.size());
+    for (Value val : operands)
+      wrappedOperands.push_back(wrap(val));
+    return unwrap(callbacks.matchAndRewrite(
+        wrap(static_cast<const mlir::ConversionPattern *>(this)), wrap(op),
+        wrappedOperands.size(), wrappedOperands.data(), wrap(&rewriter),
+        userData));
+  }
+
+private:
+  MlirConversionPatternCallbacks callbacks;
+  void *userData;
+};
+
+} // namespace mlir
+
+// Creates an ExternalConversionPattern rooted at `rootName` that dispatches to
+// `callbacks`, converting the C array of generated op names along the way.
+MlirConversionPattern mlirOpConversionPatternCreate(
+    MlirStringRef rootName, unsigned benefit, MlirContext context,
+    MlirTypeConverter typeConverter, MlirConversionPatternCallbacks callbacks,
+    void *userData, size_t nGeneratedNames, MlirStringRef *generatedNames) {
+  std::vector<mlir::StringRef> names(nGeneratedNames);
+  for (size_t idx = 0; idx != nGeneratedNames; ++idx)
+    names[idx] = unwrap(generatedNames[idx]);
+
+  auto *pattern = new mlir::ExternalConversionPattern(
+      callbacks, userData, unwrap(rootName), PatternBenefit(benefit),
+      unwrap(context), unwrap(typeConverter), names);
+  return wrap(pattern);
+}
+
+MlirTypeConverter
+mlirConversionPatternGetTypeConverter(MlirConversionPattern pattern) {
+  // The pattern exposes its converter as const; the C handle is non-const.
+  const TypeConverter *converter = unwrap(pattern)->getTypeConverter();
+  return wrap(const_cast<TypeConverter *>(converter));
+}
+
+MlirRewritePattern
+mlirConversionPatternAsRewritePattern(MlirConversionPattern pattern) {
+  // View the conversion pattern through its RewritePattern base.
+  const auto *rewritePattern = static_cast<const RewritePattern *>(unwrap(pattern));
+  return wrap(rewritePattern);
+}
+
 //===----------------------------------------------------------------------===//
 /// RewritePattern API
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td
index 929851724ba71..9bfda1ec02303 100644
--- a/mlir/test/mlir-tblgen/op-python-bindings.td
+++ b/mlir/test/mlir-tblgen/op-python-bindings.td
@@ -16,8 +16,8 @@ class TestOp<string mnemonic, list<Trait> traits = []> :
     Op<Test_Dialect, mnemonic, traits>;
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
-// CHECK: class AttrSizedOperandsOp(_ods_ir.OpView):
-// CHECK-LABEL: OPERATION_NAME = "test.attr_sized_operands"
+// CHECK-LABEL: class AttrSizedOperandsOp(_ods_ir.OpView):
+// CHECK: OPERATION_NAME = "test.attr_sized_operands"
 // CHECK: _ODS_OPERAND_SEGMENTS = [-1,1,0,]
 def AttrSizedOperandsOp : TestOp<"attr_sized_operands",
                                  [AttrSizedOperandSegments]> {
@@ -64,8 +64,8 @@ def AttrSizedOperandsOp : TestOp<"attr_sized_operands",
 // CHECK:   return AttrSizedOperandsOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip)
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
-// CHECK: class AttrSizedResultsOp(_ods_ir.OpView):
-// CHECK-LABEL: OPERATION_NAME = "test.attr_sized_results"
+// CHECK-LABEL: class AttrSizedResultsOp(_ods_ir.OpView):
+// CHECK: OPERATION_NAME = "test.attr_sized_results"
 // CHECK: _ODS_RESULT_SEGMENTS = [0,1,-1,]
 def AttrSizedResultsOp : TestOp<"attr_sized_results",
                                [AttrSizedResultSegments]> {
@@ -112,10 +112,9 @@ def AttrSizedResultsOp : TestOp<"attr_sized_results",
 // CHECK:   op = AttrSizedResultsOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip); results = op.results
 // CHECK:   return results if len(results) > 1 else (results[0] if len(results) == 1 else op)
 
-
 // CHECK: @_ods_cext.register_operation(_Dialect)
-// CHECK: class AttributedOp(_ods_ir.OpView):
-// CHECK-LABEL: OPERATION_NAME = "test.attributed_op"
+// CHECK-LABEL: class AttributedOp(_ods_ir.OpView):
+// CHECK: OPERATION_NAME = "test.attributed_op"
 // CHECK-NOT: _ODS_OPERAND_SEGMENTS
 // CHECK-NOT: _ODS_RESULT_SEGMENTS
 def AttributedOp : TestOp<"attributed_op"> {
@@ -159,13 +158,30 @@ def AttributedOp : TestOp<"attributed_op"> {
   let arguments = (ins I32Attr:$i32attr, OptionalAttr<F32Attr>:$optionalF32Attr,
                    UnitAttr:$unitAttr, I32Attr:$in);
 }
+// CHECK: @_ods_cext.register_op_adaptor(AttributedOp)
+// CHECK-LABEL: class AttributedOpAdaptor(_ods_ir.OpAdaptor):
+// CHECK:   OPERATION_NAME = "test.attributed_op"
+// CHECK:   @builtins.property
+// CHECK:   def i32attr(self) -> _ods_ir.IntegerAttr:
+// CHECK:     return self.attributes["i32attr"]
+// CHECK:   @builtins.property
+// CHECK:   def optionalF32Attr(self) -> _Optional[_ods_ir.FloatAttr]:
+// CHECK:     if "optionalF32Attr" not in self.attributes:
+// CHECK:       return None
+// CHECK:     return self.attributes["optionalF32Attr"]
+// CHECK:   @builtins.property
+// CHECK:   def unitAttr(self) -> bool:
+// CHECK:     return "unitAttr" in self.attributes
+// CHECK:   @builtins.property
+// CHECK:   def in_(self) -> _ods_ir.IntegerAttr:
+// CHECK:     return self.attributes["in"]
 
 // CHECK: def attributed_op(i32attr, in_, *, optional_f32_attr=None, unit_attr=None, loc=None, ip=None) -> AttributedOp:
 // CHECK:     return AttributedOp(i32attr=i32attr, in_=in_, optionalF32Attr=optional_f32_attr, unitAttr=unit_attr, loc=loc, ip=ip)
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
-// CHECK: class AttributedOpWithOperands(_ods_ir.OpView):
-// CHECK-LABEL: OPERATION_NAME = "test.attributed_op_with_operands"
+// CHECK-LABEL: class AttributedOpWithOperands(_ods_ir.OpView):
+// CHECK: OPERATION_NAME = "test.attributed_op_with_operands"
 // CHECK-NOT: _ODS_OPERAND_SEGMENTS
 // CHECK-NOT: _ODS_RESULT_SEGMENTS
 def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> {
@@ -201,8 +217,8 @@ def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> {
 // CHECK:   return AttributedOpWithOperands(_gen_arg_0=_gen_arg_0, _gen_arg_2=_gen_arg_2, in_=in_, is_=is_, loc=loc, ip=ip)
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
-// CHECK: class DefaultValuedAttrsOp(_ods_ir.OpView):
-// CHECK-LABEL: OPERATION_NAME = "test.default_valued_attrs"
+// CHECK-LABEL: class DefaultValuedAttrsOp(_ods_ir.OpView):
+// CHECK: OPERATION_NAME = "test.default_valued_attrs"
 def DefaultValuedAttrsOp : TestOp<"default_valued_attrs"> {
   // CHECK: def __init__(self, *, arr=None, unsupported=None, loc=None, ip=None):
   // CHECK:   operands = []
@@ -283,8 +299,8 @@ def DescriptionOp : TestOp<"description"> {
 }
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
-// CHECK: class EmptyOp(_ods_ir.OpView):
-// CHECK-LABEL: OPERATION_NAME = "test.empty"
+// CHECK-LABEL: class EmptyOp(_ods_ir.OpView):
+// CHECK: OPERATION_NAME = "test.empty"
 def EmptyOp : TestOp<"empty">;
   // CHECK: def __init__(self, *, loc=None, ip=None):
   // CHECK:   operands = []
@@ -329,8 +345,8 @@ def InferResultTypesOp : TestOp<"infer_result_types_op", [InferTypeOpInterface]>
 // CHECK:   return InferResultTypesOp(results=results, loc=loc, ip=ip).results
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
-// CHECK: class MissingNamesOp(_ods_ir.OpView):
-// CHECK-LABEL: OPERATION_NAME = "test.missing_names"
+// CHECK-LABEL: class MissingNamesOp(_ods_ir.OpView):
+// CHECK: OPERATION_NAME = "test.missing_names"
 def MissingNamesOp : TestOp<"missing_names"> {
   // CHECK: def __init__(self, i32, _gen_res_1, i64, _gen_arg_0, f32, _gen_arg_2, *, loc=None, ip=None):
   // CHECK:   operands = []
@@ -368,8 +384,8 @@ def MissingNamesOp : TestOp<"missing_names"> {
 // CHECK:   return MissingNamesOp(i32=i32, _gen_res_1=_gen_res_1, i64=i64, _gen_arg_0=_gen_arg_0, f32=f32, _gen_arg_2=_gen_arg_2, loc=loc, ip=ip).results
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
-// CHECK: class OneOptionalOperandOp(_ods_ir.OpView):
-// CHECK-LABEL: OPERATION_NAME = "test.one_optional_operand"
+// CHECK-LABEL: class OneOptionalOperandOp(_ods_ir.OpView):
+// CHECK: OPERATION_NAME = "test.one_optional_operand"
 // CHECK-NOT: _ODS_OPERAND_SEGMENTS
 // CHECK-NOT: _ODS_RESULT_SEGMENTS
 def OneOptionalOperandOp : TestOp<"one_optional_operand"> {
@@ -395,13 +411,22 @@ def OneOptionalOperandOp : TestOp<"one_optional_operand"> {
   // CHECK: def optional(self) -> _Optional[_ods_ir.Value]:
   // CHECK:   return None if len(self.operation.operands) < 2 else self.operation.operands[1]
 }
+// CHECK: @_ods_cext.register_op_adaptor(OneOptionalOperandOp)
+// CHECK-LABEL: class OneOptionalOperandOpAdaptor(_ods_ir.OpAdaptor):
+// CHECK:   OPERATION_NAME = "test.one_optional_operand"
+// CHECK:   @builtins.property
+// CHECK:   def non_optional(self) -> _ods_ir.Value:
+// CHECK:     return self.operands[0]
+// CHECK:   @builtins.property
+// CHECK:   def optional(self) -> _Optional[_ods_ir.Value]:
+// CHECK:     return None if len(self.operands) < 2 else self.operands[1]
 
 // CHECK: def one_optional_operand(non_optional, *, optional=None, loc=None, ip=None) -> OneOptionalOperandOp:
 // CHECK:   return OneOptionalOperandOp(non_optional=non_optional, optional=optional, loc=loc, ip=ip)
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
-// CHECK: class OneVariadicOperandOp(_ods_ir.OpView):
-// CHECK-LABEL: OPERATION_NAME = "test.one_variadic_operand"
+// CHECK-LABEL: class OneVariadicOperandOp(_ods_ir.OpView):
+// CHECK: OPERATION_NAME = "test.one_variadic_operand"
 // CHECK-NOT: _ODS_OPERAND_SEGMENTS
 // CHECK-NOT: _ODS_RESULT_SEGMENTS
 def OneVariadicOperandOp : TestOp<"one_variadic_operand"> {
@@ -428,13 +453,23 @@ def OneVariadicOperandOp : TestOp<"one_variadic_operand"> {
   // CHECK:   return self.operation.operands[1:1 + _ods_variadic_group_length]
   let arguments = (ins AnyType:$non_variadic, Variadic<AnyType>:$variadic);
 }
+// CHECK: @_ods_cext.register_op_adaptor(OneVariadicOperandOp)
+// CHECK-LABEL: class OneVariadicOperandOpAdaptor(_ods_ir.OpAdaptor):
+// CHECK:   OPERATION_NAME = "test.one_variadic_operand"
+// CHECK:   @builtins.property
+// CHECK:   def non_variadic(self) -> _ods_ir.Value:
+// CHECK:     return self.operands[0]
+// CHECK:   @builtins.property
+// CHECK:   def variadic(self) -> _ods_ir.OpOperandList:
+// CHECK:     _ods_variadic_group_length = len(self.operands) - 2 + 1
+// CHECK:     return self.operands[1:1 + _ods_variadic_group_length]
 
 // CHECK: def one_variadic_operand(non_variadic, variadic, *, loc=None, ip=None) -> OneVariadicOperandOp:
 // CHECK:   return OneVariadicOperandOp(non_variadic=non_variadic, variadic=variadic, loc=loc, ip=ip)
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
-// CHECK: class OneVariadicResultOp(_ods_ir.OpView):
-// CHECK-LABEL: OPERATION_NAME = "test.one_variadic_result"
+// CHECK-LABEL: class OneVariadicResultOp(_ods_ir.OpView):
+// CHECK: OPERATION_NAME = "test.one_variadic_result"
 // CHECK-NOT: _ODS_OPERAND_SEGMENTS
 // CHECK-NOT: _ODS_RESULT_SEGMENTS
 def OneVariadicResultOp : TestOp<"one_variadic_result"> {
@@ -468,8 +503,8 @@ def OneVariadicResultOp : TestOp<"one_variadic_result"> {
 // CHECK:   return results if len(results) > 1 else (results[0] if len(results) == 1 else op)
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
-// CHECK: class PythonKeywordOp(_ods_ir.OpView):
-// CHECK-LABEL: OPERATION_NAME = "test.python_keyword"
+// CHECK-LABEL: class PythonKeywordOp(_ods_ir.OpView):
+// CHECK: OPERATION_NAME = "test.python_keyword"
 def PythonKeywordOp : TestOp<"python_keyword"> {
   // CHECK: def __init__(self, in_, *, loc=None, ip=None):
   // CHECK:   operands = []
@@ -518,8 +553,8 @@ def SameResultsVariadicOp : TestOp<"same_results_variadic", [SameOperandsAndResu
 
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
-// CHECK: class SameVariadicOperandSizeOp(_ods_ir.OpView):
-// CHECK-LABEL: OPERATION_NAME = "test.same_variadic_operand"
+// CHECK-LABEL: class SameVariadicOperandSizeOp(_ods_ir.OpView):
+// CHECK: OPERATION_NAME = "test.same_variadic_operand"
 def SameVariadicOperandSizeOp : TestOp<"same_variadic_operand",
                                        [SameVariadicOperandSize]> {
   // CHECK: @builtins.property
@@ -544,8 +579,8 @@ def SameVariadicOperandSizeOp : TestOp<"same_variadic_operand",
 // CHECK:   return SameVariadicOperandSizeOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip)
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
-// CHECK: class SameVariadicResultSizeOp(_ods_ir.OpView):
-// CHECK-LABEL: OPERATION_NAME = "test.same_variadic_result"
+// CHECK-LABEL: class SameVariadicResultSizeOp(_ods_ir.OpView):
+// CHECK: OPERATION_NAME = "test.same_variadic_result"
 def SameVariadicResultSizeOp : TestOp<"same_variadic_result",
                                       [SameVariadicResultSize]> {
   // CHECK: @builtins.property
@@ -571,8 +606,8 @@ def SameVariadicResultSizeOp : TestOp<"same_variadic_result",
 // CHECK:   return results if len(results) > 1 else (results[0] if len(results) == 1 else op)
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
-// CHECK: class SimpleOp(_ods_ir.OpView):
-// CHECK-LABEL: OPERATION_NAME = "test.simple"
+// CHECK-LABEL: class SimpleOp(_ods_ir.OpView):
+// CHECK: OPERATION_NAME = "test.simple"
 def SimpleOp : TestOp<"simple"> {
   // CHECK: def __init__(self, i64, f64, i32, f32, *, loc=None, ip=None):
   // CHECK:   operands = []
@@ -611,8 +646,8 @@ def SimpleOp : TestOp<"simple"> {
 // CHECK: def simple(i64, f64, i32, f32, *, loc=None, ip=None) -> _ods_ir.OpResultList:
 // CHECK:   return SimpleOp(i64=i64, f64=f64, i32=i32, f32=f32, loc=loc, ip=ip).results
 
-// CHECK: class VariadicAndNormalRegionOp(_ods_ir.OpView):
-// CHECK-LABEL: OPERATION_NAME = "test.variadic_and_normal_region"
+// CHECK-LABEL: class VariadicAndNormalRegionOp(_ods_ir.OpView):
+// CHECK: OPERATION_NAME = "test.variadic_and_normal_region"
 def VariadicAndNormalRegionOp : TestOp<"variadic_and_normal_region"> {
   // CHECK:  def __init__(self, num_variadic, *, loc=None, ip=None):
   // CHECK:    operands = []
@@ -639,8 +674,8 @@ def VariadicAndNormalRegionOp : TestOp<"variadic_and_normal_region"> {
 // CHECK: def variadic_and_normal_region(num_variadic, *, loc=None, ip=None) -> VariadicAndNormalRegionOp:
 // CHECK:   return VariadicAndNormalRegionOp(num_variadic=num_variadic, loc=loc, ip=ip)
 
-// CHECK: class VariadicRegionOp(_ods_ir.OpView):
-// CHECK-LABEL: OPERATION_NAME = "test.variadic_region"
+// CHECK-LABEL: class VariadicRegionOp(_ods_ir.OpView):
+// CHECK: OPERATION_NAME = "test.variadic_region"
 def VariadicRegionOp : TestOp<"variadic_region"> {
   // CHECK:  def __init__(self, num_variadic, *, loc=None, ip=None):
   // CHECK:    operands = []
@@ -664,8 +699,8 @@ def VariadicRegionOp : TestOp<"variadic_region"> {
 // CHECK:   return VariadicRegionOp(num_variadic=num_variadic, loc=loc, ip=ip)
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
-// CHECK: class WithSpecialCharactersOp(_ods_ir.OpView):
-// CHECK-LABEL: OPERATION_NAME = "test.123with--special.characters"
+// CHECK-LABEL: class WithSpecialCharactersOp(_ods_ir.OpView):
+// CHECK: OPERATION_NAME = "test.123with--special.characters"
 def WithSpecialCharactersOp : TestOp<"123with--special.characters"> {
 }
 
@@ -673,8 +708,8 @@ def WithSpecialCharactersOp : TestOp<"123with--special.characters"> {
 // CHECK:   return WithSpecialCharactersOp(loc=loc, ip=ip)
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
-// CHECK: class WithSuccessorsOp(_ods_ir.OpView):
-// CHECK-LABEL: OPERATION_NAME = "test.with_successors"
+// CHECK-LABEL: class WithSuccessorsOp(_ods_ir.OpView):
+// CHECK: OPERATION_NAME = "test.with_successors"
 def WithSuccessorsOp : TestOp<"with_successors"> {
   // CHECK-NOT:  _ods_successors = None
   // CHECK:      _ods_successors = []
@@ -687,8 +722,8 @@ def WithSuccessorsOp : TestOp<"with_successors"> {
 // CHECK: def with_successors(successor, successors, *, loc=None, ip=None) -> WithSuccessorsOp:
 // CHECK:   return WithSuccessorsOp(successor=successor, successors=successors, loc=loc, ip=ip)
 
-// CHECK: class snake_case(_ods_ir.OpView):
-// CHECK-LABEL: OPERATION_NAME = "test.snake_case"
+// CHECK-LABEL: class snake_case(_ods_ir.OpView):
+// CHECK: OPERATION_NAME = "test.snake_case"
 def already_snake_case : TestOp<"snake_case"> {}
 // CHECK: def snake_case_(*, loc=None, ip=None) -> snake_case:
 // CHECK:   return snake_case(loc=loc, ip=ip)

diff  --git a/mlir/test/python/rewrite.py b/mlir/test/python/rewrite.py
index 8ef49981a8b3c..d805831369ce0 100644
--- a/mlir/test/python/rewrite.py
+++ b/mlir/test/python/rewrite.py
@@ -226,3 +226,76 @@ def constant_1_to_2(op, rewriter):
         # CHECK: %c2_i64 = arith.constant 2 : i64
         # CHECK: return %c2_i64, %c2_i64 : i64
         print(module)
+
+
+ at run
+def testConversionPattern():
+    from mlir.dialects import smt
+
+    def convert_int(t):
+        if isinstance(t, IntegerType):
+            return smt.IntType.get()
+
+    converter = TypeConverter()
+    converter.add_conversion(convert_int)
+
+    def convert_constant(op, adaptor, type_converter, rewriter):
+        assert isinstance(op, arith.ConstantOp)
+        assert isinstance(adaptor, arith.ConstantOpAdaptor)
+        with rewriter.ip:
+            new_op = smt.IntConstantOp(op.value, loc=op.location)
+        rewriter.replace_op(op, new_op)
+
+    def convert_addi(op, adaptor, type_converter, rewriter):
+        assert isinstance(op, arith.AddIOp)
+        assert isinstance(adaptor, arith.AddIOpAdaptor)
+        with rewriter.ip:
+            new_op = smt.IntAddOp([adaptor.lhs, adaptor.rhs], loc=op.location)
+        rewriter.replace_op(op, new_op)
+
+    def convert_muli(op, adaptor, type_converter, rewriter):
+        assert isinstance(op, arith.MulIOp)
+        assert isinstance(adaptor, arith.MulIOpAdaptor)
+        with rewriter.ip:
+            new_op = smt.IntMulOp([adaptor.lhs, adaptor.rhs], loc=op.location)
+        rewriter.replace_op(op, new_op)
+
+    with Context():
+        patterns = RewritePatternSet()
+        patterns.add_conversion(arith.ConstantOp, convert_constant, converter)
+        patterns.add_conversion(arith.AddIOp, convert_addi, converter)
+        patterns.add_conversion(arith.MulIOp, convert_muli, converter)
+
+        module = ModuleOp.parse(
+            r"""
+            module {
+                func.func @f(%0: i64) -> i64 {
+                    %1 = arith.constant 3 : i64
+                    %2 = arith.addi %0, %1 : i64
+                    %3 = arith.muli %2, %1 : i64
+                    return %3 : i64
+                }
+            }
+            """
+        )
+
+        target = ConversionTarget()
+        target.add_legal_dialect(smt._Dialect)
+        target.add_illegal_op(arith.ConstantOp, arith.AddIOp, arith.MulIOp)
+
+        frozen = patterns.freeze()
+        config = ConversionConfig()
+        config.build_materializations = False
+
+        apply_partial_conversion(module, target, frozen, config)
+        assert module.operation.verify()
+
+        # CHECK: func.func @f(%arg0: i64) -> i64 {
+        # CHECK:     %0 = builtin.unrealized_conversion_cast %arg0 : i64 to !smt.int
+        # CHECK:     %c3 = smt.int.constant 3
+        # CHECK:     %1 = smt.int.add %0, %c3
+        # CHECK:     %2 = smt.int.mul %1, %c3
+        # CHECK:     %3 = builtin.unrealized_conversion_cast %2 : !smt.int to i64
+        # CHECK:     return %3 : i64
+        # CHECK: }
+        print(module)

diff  --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 4c1f82b164c39..e8acf4ce40fc8 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -72,6 +72,15 @@ class {0}(_ods_ir.OpView):{2}
   OPERATION_NAME = "{1}"
 )Py";
 
+/// Template for an operation adaptor class:
+///   {0} is the Python class name of the adapted operation; the generated
+///       adaptor is named "{0}Adaptor" and registered against {0};
+///   {1} is the operation name.
+constexpr const char *opAdaptorClassTemplate = R"Py(
+@_ods_cext.register_op_adaptor({0})
+class {0}Adaptor(_ods_ir.OpAdaptor):
+  OPERATION_NAME = "{1}"
+)Py";
+
 /// Template for class level declarations of operand and result
 /// segment specs.
 ///   {0} is either "OPERAND" or "RESULT"
@@ -100,7 +109,7 @@ constexpr const char *opClassRegionSpecTemplate = R"Py(
 constexpr const char *opSingleTemplate = R"Py(
   @builtins.property
   def {0}(self) -> {3}:
-    return self.operation.{1}s[{2}]
+    return self.{1}s[{2}]
 )Py";
 
 /// Template for single-element accessor after a variable-length group:
@@ -114,8 +123,8 @@ constexpr const char *opSingleTemplate = R"Py(
 constexpr const char *opSingleAfterVariableTemplate = R"Py(
   @builtins.property
   def {0}(self) -> {4}:
-    _ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1
-    return self.operation.{1}s[{3} + _ods_variadic_group_length - 1]
+    _ods_variadic_group_length = len(self.{1}s) - {2} + 1
+    return self.{1}s[{3} + _ods_variadic_group_length - 1]
 )Py";
 
 /// Template for an optional element accessor:
@@ -130,7 +139,7 @@ constexpr const char *opSingleAfterVariableTemplate = R"Py(
 constexpr const char *opOneOptionalTemplate = R"Py(
   @builtins.property
   def {0}(self) -> _Optional[{4}]:
-    return None if len(self.operation.{1}s) < {2} else self.operation.{1}s[{3}]
+    return None if len(self.{1}s) < {2} else self.{1}s[{3}]
 )Py";
 
 /// Template for the variadic group accessor in the single variadic group case:
@@ -142,8 +151,8 @@ constexpr const char *opOneOptionalTemplate = R"Py(
 constexpr const char *opOneVariadicTemplate = R"Py(
   @builtins.property
   def {0}(self) -> {4}:
-    _ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1
-    return self.operation.{1}s[{3}:{3} + _ods_variadic_group_length]
+    _ods_variadic_group_length = len(self.{1}s) - {2} + 1
+    return self.{1}s[{3}:{3} + _ods_variadic_group_length]
 )Py";
 
 /// First part of the template for equally-sized variadic group accessor:
@@ -157,20 +166,20 @@ constexpr const char *opOneVariadicTemplate = R"Py(
 constexpr const char *opVariadicEqualPrefixTemplate = R"Py(
   @builtins.property
   def {0}(self) -> {6}:
-    start, elements_per_group = _ods_equally_sized_accessor(self.operation.{1}s, {2}, {3}, {4}, {5}))Py";
+    start, elements_per_group = _ods_equally_sized_accessor(self.{1}s, {2}, {3}, {4}, {5}))Py";
 
 /// Second part of the template for equally-sized case, accessing a single
 /// element:
 ///   {0} is either 'operand' or 'result'.
 constexpr const char *opVariadicEqualSimpleTemplate = R"Py(
-    return self.operation.{0}s[start]
+    return self.{0}s[start]
 )Py";
 
 /// Second part of the template for equally-sized case, accessing a variadic
 /// group:
 ///   {0} is either 'operand' or 'result'.
 constexpr const char *opVariadicEqualVariadicTemplate = R"Py(
-    return self.operation.{0}s[start:start + elements_per_group]
+    return self.{0}s[start:start + elements_per_group]
 )Py";
 
 /// Template for an attribute-sized group accessor:
@@ -178,14 +187,16 @@ constexpr const char *opVariadicEqualVariadicTemplate = R"Py(
 ///   {1} is either 'operand' or 'result';
 ///   {2} is the position of the group in the group list;
 ///   {3} is a return suffix (expected [0] for single-element, empty for
-///       variadic, and opVariadicSegmentOptionalTrailingTemplate for optional).
-///   {4} is the type hint.
+///       variadic, and opVariadicSegmentOptionalTrailingTemplate for optional);
+///   {4} is the type hint;
+///   {5} is the element-list path after `self.`, without the trailing 's';
+///   {6} is the attribute-dictionary path after `self.`.
 constexpr const char *opVariadicSegmentTemplate = R"Py(
   @builtins.property
   def {0}(self) -> {4}:
     {1}_range = _ods_segmented_accessor(
-         self.operation.{1}s,
-         self.operation.attributes["{1}SegmentSizes"], {2})
+         self.{5}s,
+         self.{6}["{1}SegmentSizes"], {2})
     return {1}_range{3}
 )Py";
 
@@ -217,6 +228,28 @@ constexpr const char *optionalAttributeGetterTemplate = R"Py(
     return self.operation.attributes["{1}"]
 )Py";
 
+/// Template for a non-optional adaptor attribute getter (`self.attributes`):
+///   {0} is the name of the attribute sanitized for Python;
+///   {1} is the original name of the attribute;
+///   {2} is the type hint.
+constexpr const char *adaptorAttributeGetterTemplate = R"Py(
+  @builtins.property
+  def {0}(self) -> {2}:
+    return self.attributes["{1}"]
+)Py";
+
+/// Template for an optional adaptor attribute getter, returning None if absent:
+///   {0} is the name of the attribute sanitized for Python;
+///   {1} is the original name of the attribute;
+///   {2} is the type hint (the getter returns `_Optional[{2}]`).
+constexpr const char *adaptorOptionalAttributeGetterTemplate = R"Py(
+  @builtins.property
+  def {0}(self) -> _Optional[{2}]:
+    if "{1}" not in self.attributes:
+      return None
+    return self.attributes["{1}"]
+)Py";
+
 /// Template for a getter of a unit operation attribute, returns True of the
 /// unit attribute is present, False otherwise (unit attributes have meaning
 /// by mere presence):
@@ -228,6 +261,17 @@ constexpr const char *unitAttributeGetterTemplate = R"Py(
     return "{1}" in self.operation.attributes
 )Py";
 
+/// Template for a getter of a unit operation attribute for adaptors; returns
+/// True if the unit attribute is present, False otherwise (unit attributes have
+/// meaning by mere presence):
+///    {0} is the name of the attribute sanitized for Python;
+///    {1} is the original name of the attribute.
+constexpr const char *adaptorUnitAttributeGetterTemplate = R"Py(
+  @builtins.property
+  def {0}(self) -> bool:
+    return "{1}" in self.attributes
+)Py";
+
 /// Template for an operation attribute setter:
 ///    {0} is the name of the attribute sanitized for Python;
 ///    {1} is the original name of the attribute.
@@ -365,7 +409,8 @@ static void emitElementAccessors(
     const Operator &op, raw_ostream &os, const char *kind,
     unsigned numVariadicGroups, unsigned numElements,
     llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)>
-        getElement) {
+        getElement,
+    bool isAdaptor = false) {
   assert(llvm::is_contained(SmallVector<StringRef, 2>{"operand", "result"},
                             kind) &&
          "unsupported kind");
@@ -376,6 +421,8 @@ static void emitElementAccessors(
                                       StringRef(kind).drop_front());
   std::string attrSizedTrait = attrSizedTraitForKind(kind);
 
+  std::string pyAttrName = isAdaptor ? kind : std::string("operation.") + kind;
+
   // If there is only one variable-length element group, its size can be
   // inferred from the total number of elements. If there are none, the
   // generation is straightforward.
@@ -394,20 +441,20 @@ static void emitElementAccessors(
         type = llvm::formatv("{0}[{1}]", type, pythonType);
       if (element.isVariableLength()) {
         if (element.isOptional()) {
-          os << formatv(opOneOptionalTemplate, sanitizeName(element.name), kind,
-                        numElements, i, type);
+          os << formatv(opOneOptionalTemplate, sanitizeName(element.name),
+                        pyAttrName, numElements, i, type);
         } else {
           type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.OpOperandList"
                                                    : "_ods_ir.OpResultList";
-          os << formatv(opOneVariadicTemplate, sanitizeName(element.name), kind,
-                        numElements, i, type);
+          os << formatv(opOneVariadicTemplate, sanitizeName(element.name),
+                        pyAttrName, numElements, i, type);
         }
       } else if (seenVariableLength) {
         os << formatv(opSingleAfterVariableTemplate, sanitizeName(element.name),
-                      kind, numElements, i, type);
+                      pyAttrName, numElements, i, type);
       } else {
-        os << formatv(opSingleTemplate, sanitizeName(element.name), kind, i,
-                      type);
+        os << formatv(opSingleTemplate, sanitizeName(element.name), pyAttrName,
+                      i, type);
       }
     }
     return;
@@ -445,12 +492,12 @@ static void emitElementAccessors(
             type += "[" + pythonType.str() + "]";
         }
         os << formatv(opVariadicEqualPrefixTemplate, sanitizeName(element.name),
-                      kind, numSimpleLength, numVariadicGroups,
+                      pyAttrName, numSimpleLength, numVariadicGroups,
                       numPrecedingSimple, numPrecedingVariadic, type);
         os << formatv(element.isVariableLength()
                           ? opVariadicEqualVariadicTemplate
                           : opVariadicEqualSimpleTemplate,
-                      kind);
+                      pyAttrName);
       }
       if (element.isVariableLength())
         ++numPrecedingVariadic;
@@ -491,7 +538,8 @@ static void emitElementAccessors(
       }
 
       os << formatv(opVariadicSegmentTemplate, sanitizeName(element.name), kind,
-                    i, trailing, type);
+                    i, trailing, type, pyAttrName,
+                    isAdaptor ? "attributes" : "operation.attributes");
     }
     return;
   }
@@ -625,6 +673,34 @@ static void emitAttributeAccessors(const Operator &op, raw_ostream &os) {
   }
 }
 
+/// Emits one `@builtins.property` getter per named, non-derived attribute of
+/// `op` into the Python adaptor class body; getters read `self.attributes`.
+static void emitAdaptorAttributeAccessors(const Operator &op, raw_ostream &os) {
+  for (const auto &attrEntry : op.getAttributes()) {
+    // Derived attributes are C++-only helpers that are not exposed to Python,
+    // and unnamed attributes cannot be looked up by name; skip both.
+    if (attrEntry.attr.isDerivedAttr() || attrEntry.name.empty())
+      continue;
+
+    const std::string pyName = sanitizeName(attrEntry.name);
+    const bool isUnit =
+        attrEntry.attr.getStorageType().trim() == "::mlir::UnitAttr";
+
+    // Unit attributes are meaningful by presence alone, so they get a bool
+    // property instead of a value getter.
+    if (isUnit) {
+      os << formatv(adaptorUnitAttributeGetterTemplate, pyName, attrEntry.name);
+      continue;
+    }
+
+    const char *getterTemplate = attrEntry.attr.isOptional()
+                                     ? adaptorOptionalAttributeGetterTemplate
+                                     : adaptorAttributeGetterTemplate;
+    const std::string pyType = "_ods_ir." + getPythonAttrName(attrEntry.attr);
+    os << formatv(getterTemplate, pyName, attrEntry.name, pyType);
+  }
+}
+
 /// Template for the default auto-generated builder.
 ///   {0} is a comma-separated list of builder arguments, including the trailing
 ///       `loc` and `ip`;
@@ -1194,6 +1270,11 @@ static std::string makeDocStringForOp(const Operator &op) {
   return docString;
 }
 
+/// Emits operand accessors (backed by `self.operands`) for `op`'s adaptor.
+static void emitAdaptorOperandAccessors(const Operator &op, raw_ostream &os) {
+  emitElementAccessors(op, os, "operand", op.getNumVariableLengthOperands(),
+                       getNumOperands(op), getOperand, /*isAdaptor=*/true);
+}
 /// Emits bindings for a specific Op to the given output stream.
 static void emitOpBindings(const Operator &op, raw_ostream &os) {
   os << formatv(opClassTemplate, op.getCppClassName(), op.getOperationName(),
@@ -1213,6 +1294,12 @@ static void emitOpBindings(const Operator &op, raw_ostream &os) {
   emitAttributeAccessors(op, os);
   emitResultAccessors(op, os);
   emitRegionAccessors(op, os);
+
+  os << formatv(opAdaptorClassTemplate, op.getCppClassName(),
+                op.getOperationName());
+  emitAdaptorOperandAccessors(op, os);
+  emitAdaptorAttributeAccessors(op, os);
+
   emitValueBuilder(op, functionArgs, os);
 }
 


        


More information about the Mlir-commits mailing list