[Mlir-commits] [mlir] [MLIR][Python] Impl XOpInterface(s) from Python, with X=Transform and X=MemoryEffects (PR #176920)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jan 27 12:57:54 PST 2026


llvmbot wrote:



@llvm/pr-subscribers-mlir-core

Author: Rolf Morel (rolfmorel)

<details>
<summary>Changes</summary>

Provides the infrastructure for implementing OpInterfaces from Python and late-binding them to ops.

* On the mlir-c API declaration side, each `XOpInterface` gets a callbacks struct and a `mlirXOpInterfaceAttachFallbackModel(ctx, op_name, callbacks)` function. The struct holds one callback per interface method, optional `construct`/`destruct` hooks, and a `userData` member that is passed as an argument to every callback.
  * For `MemoryEffectsOpInterface`: https://github.com/llvm/llvm-project/pull/176920/changes#diff-da9b0ad13ba75167cca65c45470dd4063bceca4603f6b6c82e3e01179b7b566bR107-R123
  * For `TransformOpInterface`: https://github.com/llvm/llvm-project/pull/176920/changes#diff-8c29f5db66b6d811f22ed39e91578a320fd8053e317ce3d8bfeafcd5d86d9469R185-R208
* The C API is implemented by a subclass of `XOpInterface::FallbackModel` that stores the callbacks struct and has each interface method call the corresponding callback, passing the userdata along. Given a callbacks struct, a new `FallbackModel` is created and attached, i.e., late-bound, to the named op. (MLIR's interface infrastructure then returns this registered `FallbackModel` whenever the op is cast to `XOpInterface`.)
  * For `MemoryEffectsOpInterface`: https://github.com/llvm/llvm-project/pull/176920/changes#diff-e599b77dd395948103500ae86e792c37d622af721bb1bcdacada921b20250e4eR179-R243
  * For `TransformOpInterface`: https://github.com/llvm/llvm-project/pull/176920/changes#diff-07f07055f38e2bcd03357d9be81bd563c51d499a094c418816d39faa145bd11eR213-R302
* On the Python side, we expose a stand-in `XOpInterface` base class with a single classmethod, `XOpInterface.attach(cls, op_name, ctx)`. Users subclass it (`class MyInterfaceImpl(XOpInterface): ...`) and implement the interface's methods with the expected names and signatures. Calling `attach` on the subclass (`MyInterfaceImpl.attach("my_dialect.my_op", ctx)`) prepares the callbacks struct _with userdata set to the subclass_, since the callbacks use it to look up the Python methods. `mlirXOpInterfaceAttachFallbackModel(...)` then registers these callbacks (and the userdata) as an `XOpInterface::FallbackModel`. From then on, calls to the interface methods (originating in C++) are answered by the Python methods.
  * For `MemoryEffectsOpInterface`: https://github.com/llvm/llvm-project/pull/176920/changes#diff-2400af23004f6bad0f7c3193cccaa544816b7e624c3513f9afe5939222236039R343-R389
  * For `TransformOpInterface`: https://github.com/llvm/llvm-project/pull/176920/changes#diff-22bf2dd9a1ae598e9ee814d2ed92f3e394d0394844fa8ac3330b63c2bb131a17R159-R238

This PR enables implementing the TransformOpInterface and the MemoryEffectsOpInterface, both of which are required for making an op into a transform op.

Everything besides the code linked above exists to support exposing these interfaces: the types of the interface methods' arguments are exposed, as are functions/methods for manipulating those arguments (e.g., specifying side effects on `OpOperand`s and `OpResult`s, and accessing and setting the transform handles associated with arguments and results).
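
For the memory-effects side, the C API declares `mlirTransformOnlyReadsHandle`, `mlirTransformConsumesHandle` and `mlirTransformProducesHandle`. The pure-Python model below sketches what a `getEffects` implementation for a transform op might record with them. It assumes these wrap MLIR's C++ `onlyReadsHandle`/`consumesHandle`/`producesHandle` helpers, which, as we understand them, record a Read effect, Read plus Free effects, and Allocate plus Write effects, respectively, on the transform mapping resource. Effects are plain tuples here, and the helper and operand names are hypothetical.

```python
# Hypothetical model of the transform handle-effect helpers. Effects are
# (effect kind, operand/result name, resource) tuples, not MLIR objects.
from typing import List, Sequence, Tuple

Effect = Tuple[str, str, str]
RESOURCE = "TransformMappingResource"


def only_reads_handle(operands: Sequence[str], effects: List[Effect]) -> None:
    # Reading a handle leaves its payload mapping intact.
    for operand in operands:
        effects.append(("Read", operand, RESOURCE))


def consumes_handle(operands: Sequence[str], effects: List[Effect]) -> None:
    # Consuming a handle reads it and then invalidates (frees) it.
    for operand in operands:
        effects.append(("Read", operand, RESOURCE))
        effects.append(("Free", operand, RESOURCE))


def produces_handle(results: Sequence[str], effects: List[Effect]) -> None:
    # Producing a handle allocates a new mapping entry and writes to it.
    for result in results:
        effects.append(("Allocate", result, RESOURCE))
        effects.append(("Write", result, RESOURCE))


def get_effects(out: List[Effect]) -> None:
    """Hypothetical op: consumes operand 0, reads operand 1, yields 1 result."""
    consumes_handle(["operand#0"], out)
    only_reads_handle(["operand#1"], out)
    produces_handle(["result#0"], out)


effects: List[Effect] = []
get_effects(effects)
for effect in effects:
    print(effect)
```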

---

Patch is 91.71 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/176920.diff


22 Files Affected:

- (modified) mlir/include/mlir-c/Dialect/Transform.h (+140) 
- (modified) mlir/include/mlir-c/IR.h (+8) 
- (modified) mlir/include/mlir-c/Interfaces.h (+37-2) 
- (modified) mlir/include/mlir/Bindings/Python/IRCore.h (+10-1) 
- (added) mlir/include/mlir/CAPI/Dialect/Transform.h (+28) 
- (modified) mlir/include/mlir/CAPI/IR.h (+1) 
- (modified) mlir/include/mlir/CAPI/Interfaces.h (+8) 
- (modified) mlir/lib/Bindings/Python/DialectTransform.cpp (+272-2) 
- (modified) mlir/lib/Bindings/Python/Globals.cpp (+1-2) 
- (modified) mlir/lib/Bindings/Python/IRCore.cpp (+45-8) 
- (modified) mlir/lib/Bindings/Python/IRInterfaces.cpp (+58-127) 
- (added) mlir/lib/Bindings/Python/IRInterfaces.h (+156) 
- (modified) mlir/lib/Bindings/Python/Rewrite.cpp (+7-57) 
- (modified) mlir/lib/Bindings/Python/Rewrite.h (+69-2) 
- (modified) mlir/lib/CAPI/Dialect/Transform.cpp (+203) 
- (modified) mlir/lib/CAPI/IR/IR.cpp (+12) 
- (modified) mlir/lib/CAPI/Interfaces/Interfaces.cpp (+74) 
- (modified) mlir/lib/IR/OperationSupport.cpp (+1-1) 
- (modified) mlir/python/CMakeLists.txt (+2) 
- (modified) mlir/python/mlir/_mlir_libs/__init__.py (+1) 
- (modified) mlir/python/mlir/dialects/ext.py (+20-13) 
- (added) mlir/test/python/dialects/transform_op_interface.py (+494) 


``````````diff
diff --git a/mlir/include/mlir-c/Dialect/Transform.h b/mlir/include/mlir-c/Dialect/Transform.h
index 911c9ef659a1e..674d5f8b7b72d 100644
--- a/mlir/include/mlir-c/Dialect/Transform.h
+++ b/mlir/include/mlir-c/Dialect/Transform.h
@@ -11,6 +11,8 @@
 #define MLIR_C_DIALECT_TRANSFORM_H
 
 #include "mlir-c/IR.h"
+#include "mlir-c/Interfaces.h"
+#include "mlir-c/Rewrite.h"
 #include "mlir-c/Support.h"
 
 #ifdef __cplusplus
@@ -19,6 +21,32 @@ extern "C" {
 
 MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Transform, transform);
 
+#define DEFINE_C_API_STRUCT(name, storage)                                     \
+  struct name {                                                                \
+    storage *ptr;                                                              \
+  };                                                                           \
+  typedef struct name name
+
+DEFINE_C_API_STRUCT(MlirTransformResults, void);
+DEFINE_C_API_STRUCT(MlirTransformRewriter, void);
+DEFINE_C_API_STRUCT(MlirTransformState, void);
+
+#undef DEFINE_C_API_STRUCT
+
+//===---------------------------------------------------------------------===//
+// DiagnosedSilenceableFailure
+//===---------------------------------------------------------------------===//
+
+/// Enum representing the result of a transform operation.
+typedef enum {
+  /// The operation succeeded.
+  MlirDiagnosedSilenceableFailureSuccess,
+  /// The operation failed in a silenceable way.
+  MlirDiagnosedSilenceableFailureSilenceableFailure,
+  /// The operation failed definitively.
+  MlirDiagnosedSilenceableFailureDefiniteFailure
+} MlirDiagnosedSilenceableFailure;
+
 //===---------------------------------------------------------------------===//
 // AnyOpType
 //===---------------------------------------------------------------------===//
@@ -86,6 +114,118 @@ MLIR_CAPI_EXPORTED MlirStringRef mlirTransformParamTypeGetName(void);
 
 MLIR_CAPI_EXPORTED MlirType mlirTransformParamTypeGetType(MlirType type);
 
+//===---------------------------------------------------------------------===//
+// TransformRewriter
+//===---------------------------------------------------------------------===//
+
+/// Cast the TransformRewriter to a RewriterBase
+MLIR_CAPI_EXPORTED MlirRewriterBase
+mlirTransformRewriterAsBase(MlirTransformRewriter rewriter);
+
+//===---------------------------------------------------------------------===//
+// TransformResults
+//===---------------------------------------------------------------------===//
+
+/// Set the payload operations for a transform result by iterating over a list.
+MLIR_CAPI_EXPORTED void mlirTransformResultsSetOps(MlirTransformResults results,
+                                                   MlirValue result,
+                                                   intptr_t numOps,
+                                                   MlirOperation *ops);
+
+/// Set the payload values for a transform result by iterating over a list.
+MLIR_CAPI_EXPORTED void
+mlirTransformResultsSetValues(MlirTransformResults results, MlirValue result,
+                              intptr_t numValues, MlirValue *values);
+
+/// Set the parameters for a transform result by iterating over a list.
+MLIR_CAPI_EXPORTED void
+mlirTransformResultsSetParams(MlirTransformResults results, MlirValue result,
+                              intptr_t numParams, MlirAttribute *params);
+
+//===---------------------------------------------------------------------===//
+// TransformState
+//===---------------------------------------------------------------------===//
+
+/// Callback for iterating over payload operations.
+typedef void (*MlirOperationCallback)(MlirOperation, void *userData);
+
+/// Iterate over payload operations associated with the transform IR value.
+/// Calls the callback for each payload operation.
+MLIR_CAPI_EXPORTED void
+mlirTransformStateForEachPayloadOp(MlirTransformState state, MlirValue value,
+                                   MlirOperationCallback callback,
+                                   void *userData);
+
+/// Callback for iterating over payload values.
+typedef void (*MlirValueCallback)(MlirValue, void *userData);
+
+/// Iterate over payload values associated with the transform IR value.
+/// Calls the callback for each payload value.
+MLIR_CAPI_EXPORTED void
+mlirTransformStateForEachPayloadValue(MlirTransformState state, MlirValue value,
+                                      MlirValueCallback callback,
+                                      void *userData);
+
+/// Callback for iterating over parameters.
+typedef void (*MlirAttributeCallback)(MlirAttribute, void *userData);
+
+/// Iterate over parameters associated with the transform IR value.
+/// Calls the callback for each parameter.
+MLIR_CAPI_EXPORTED void
+mlirTransformStateForEachParam(MlirTransformState state, MlirValue value,
+                               MlirAttributeCallback callback, void *userData);
+
+//===---------------------------------------------------------------------===//
+// TransformOpInterface
+//===---------------------------------------------------------------------===//
+
+/// Returns the interface TypeID of the TransformOpInterface.
+MLIR_CAPI_EXPORTED MlirTypeID mlirTransformOpInterfaceTypeID(void);
+
+/// Callbacks for implementing TransformOpInterface from external code.
+typedef struct {
+  /// Optional constructor for the user data.
+  /// Set to nullptr to disable it.
+  void (*construct)(void *userData);
+  /// Optional destructor for the user data.
+  /// Set to nullptr to disable it.
+  void (*destruct)(void *userData);
+  /// Apply callback that implements the transformation.
+  MlirDiagnosedSilenceableFailure (*apply)(MlirOperation op,
+                                           MlirTransformRewriter rewriter,
+                                           MlirTransformResults results,
+                                           MlirTransformState state,
+                                           void *userData);
+  /// Callback to check if repeated handle operands are allowed.
+  bool (*allowsRepeatedHandleOperands)(MlirOperation op, void *userData);
+  void *userData;
+} MlirTransformOpInterfaceCallbacks;
+
+/// Attach TransformOpInterface to the operation with the given name using
+/// the provided callbacks.
+MLIR_CAPI_EXPORTED void mlirTransformOpInterfaceAttachFallbackModel(
+    MlirContext ctx, MlirStringRef opName,
+    MlirTransformOpInterfaceCallbacks callbacks);
+
+//===---------------------------------------------------------------------===//
+// Transform-specific MemoryEffectsOpInterface helpers
+//===---------------------------------------------------------------------===//
+
+/// Helper to mark operands as only reading handles.
+MLIR_CAPI_EXPORTED void
+mlirTransformOnlyReadsHandle(MlirOpOperand *operands, intptr_t numOperands,
+                             MlirMemoryEffectInstancesList effects);
+
+/// Helper to mark operands as consuming handles.
+MLIR_CAPI_EXPORTED void
+mlirTransformConsumesHandle(MlirOpOperand *operands, intptr_t numOperands,
+                            MlirMemoryEffectInstancesList effects);
+
+/// Helper to mark results as producing handles.
+MLIR_CAPI_EXPORTED void
+mlirTransformProducesHandle(MlirOpResult *results, intptr_t numResults,
+                            MlirMemoryEffectInstancesList effects);
+
 #ifdef __cplusplus
 }
 #endif
diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 80ff39c82a9ee..9b58267cf1c5a 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -55,6 +55,7 @@ DEFINE_C_API_STRUCT(MlirDialect, void);
 DEFINE_C_API_STRUCT(MlirDialectRegistry, void);
 DEFINE_C_API_STRUCT(MlirOperation, void);
 DEFINE_C_API_STRUCT(MlirOpOperand, void);
+DEFINE_C_API_STRUCT(MlirOpResult, void);
 DEFINE_C_API_STRUCT(MlirOpPrintingFlags, void);
 DEFINE_C_API_STRUCT(MlirBlock, void);
 DEFINE_C_API_STRUCT(MlirRegion, void);
@@ -673,6 +674,10 @@ MLIR_CAPI_EXPORTED intptr_t mlirOperationGetNumOperands(MlirOperation op);
 MLIR_CAPI_EXPORTED MlirValue mlirOperationGetOperand(MlirOperation op,
                                                      intptr_t pos);
 
+/// Returns `pos`-th OpOperand of the operation.
+MLIR_CAPI_EXPORTED MlirOpOperand mlirOperationGetOpOperand(MlirOperation op,
+                                                           intptr_t pos);
+
 /// Sets the `pos`-th operand of the operation.
 MLIR_CAPI_EXPORTED void mlirOperationSetOperand(MlirOperation op, intptr_t pos,
                                                 MlirValue newValue);
@@ -1044,6 +1049,9 @@ MLIR_CAPI_EXPORTED bool mlirValueIsABlockArgument(MlirValue value);
 /// Returns 1 if the value is an operation result, 0 otherwise.
 MLIR_CAPI_EXPORTED bool mlirValueIsAOpResult(MlirValue value);
 
+/// Cast the value to an OpResult. Asserts if the value is not an op result.
+MLIR_CAPI_EXPORTED MlirOpResult mlirValueToOpResult(MlirValue value);
+
 /// Returns the block in which this value is defined as an argument. Asserts if
 /// the value is not a block argument.
 MLIR_CAPI_EXPORTED MlirBlock mlirBlockArgumentGetOwner(MlirValue value);
diff --git a/mlir/include/mlir-c/Interfaces.h b/mlir/include/mlir-c/Interfaces.h
index a5a3473eaef59..17a812dcd86a9 100644
--- a/mlir/include/mlir-c/Interfaces.h
+++ b/mlir/include/mlir-c/Interfaces.h
@@ -22,6 +22,16 @@
 extern "C" {
 #endif
 
+#define DEFINE_C_API_STRUCT(name, storage)                                     \
+  struct name {                                                                \
+    storage *ptr;                                                              \
+  };                                                                           \
+  typedef struct name name
+
+DEFINE_C_API_STRUCT(MlirMemoryEffectInstancesList, void);
+
+#undef DEFINE_C_API_STRUCT
+
 /// Returns `true` if the given operation implements an interface identified by
 /// its TypeID.
 MLIR_CAPI_EXPORTED bool
@@ -42,7 +52,7 @@ mlirOperationImplementsInterfaceStatic(MlirStringRef operationName,
 //===----------------------------------------------------------------------===//
 
 /// Returns the interface TypeID of the InferTypeOpInterface.
-MLIR_CAPI_EXPORTED MlirTypeID mlirInferTypeOpInterfaceTypeID();
+MLIR_CAPI_EXPORTED MlirTypeID mlirInferTypeOpInterfaceTypeID(void);
 
 /// These callbacks are used to return multiple types from functions while
 /// transferring ownership to the caller. The first argument is the number of
@@ -65,7 +75,7 @@ MLIR_CAPI_EXPORTED MlirLogicalResult mlirInferTypeOpInterfaceInferReturnTypes(
 //===----------------------------------------------------------------------===//
 
 /// Returns the interface TypeID of the InferShapedTypeOpInterface.
-MLIR_CAPI_EXPORTED MlirTypeID mlirInferShapedTypeOpInterfaceTypeID();
+MLIR_CAPI_EXPORTED MlirTypeID mlirInferShapedTypeOpInterfaceTypeID(void);
 
 /// These callbacks are used to return multiple shaped type components from
 /// functions while transferring ownership to the caller. The first argument is
@@ -87,6 +97,31 @@ mlirInferShapedTypeOpInterfaceInferReturnTypes(
     void *properties, intptr_t nRegions, MlirRegion *regions,
     MlirShapedTypeComponentsCallback callback, void *userData);
 
+//===---------------------------------------------------------------------===//
+// MemoryEffectsOpInterface
+//===---------------------------------------------------------------------===//
+
+/// Returns the interface TypeID of the MemoryEffectsOpInterface.
+MLIR_CAPI_EXPORTED MlirTypeID mlirMemoryEffectsOpInterfaceTypeID(void);
+
+/// Callbacks for implementing MemoryEffectsOpInterface from external code.
+typedef struct {
+  /// Optional constructor for user data. Set to nullptr to disable it.
+  void (*construct)(void *userData);
+  /// Optional destructor for user data. Set to nullptr to disable it.
+  void (*destruct)(void *userData);
+  /// Get memory effects callback.
+  void (*getEffects)(MlirOperation op, MlirMemoryEffectInstancesList effects,
+                     void *userData);
+  void *userData;
+} MlirMemoryEffectsOpInterfaceCallbacks;
+
+/// Attach a new FallbackModel for the MemoryEffectsOpInterface to the named
+/// operation. The FallbackModel will call the provided callbacks.
+MLIR_CAPI_EXPORTED void mlirMemoryEffectsOpInterfaceAttachFallbackModel(
+    MlirContext ctx, MlirStringRef opName,
+    MlirMemoryEffectsOpInterfaceCallbacks callbacks);
+
 #ifdef __cplusplus
 }
 #endif
diff --git a/mlir/include/mlir/Bindings/Python/IRCore.h b/mlir/include/mlir/Bindings/Python/IRCore.h
index f9fc34e82c972..2460202a31b8a 100644
--- a/mlir/include/mlir/Bindings/Python/IRCore.h
+++ b/mlir/include/mlir/Bindings/Python/IRCore.h
@@ -1496,6 +1496,7 @@ class MLIR_PYTHON_API_EXPORTED PyOperationList {
 class MLIR_PYTHON_API_EXPORTED PyOpOperand {
 public:
   PyOpOperand(MlirOpOperand opOperand) : opOperand(opOperand) {}
+  operator MlirOpOperand() const { return opOperand; }
 
   nanobind::typed<nanobind::object, PyOpView> getOwner() const;
 
@@ -1602,6 +1603,7 @@ class MLIR_PYTHON_API_EXPORTED PyOpResult : public PyConcreteValue<PyOpResult> {
   static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult;
   static constexpr const char *pyClassName = "OpResult";
   using PyConcreteValue::PyConcreteValue;
+  operator MlirOpResult() { return mlirValueToOpResult(castFrom(*this)); }
 
   static void bindDerived(ClassTy &c);
 };
@@ -1838,13 +1840,20 @@ class MLIR_PYTHON_API_EXPORTED PyOpAttributeMap {
 MLIR_PYTHON_API_EXPORTED MlirValue getUniqueResult(MlirOperation operation);
 MLIR_PYTHON_API_EXPORTED void populateIRCore(nanobind::module_ &m);
 MLIR_PYTHON_API_EXPORTED void populateRoot(nanobind::module_ &m);
+
+/// Helper for creating an @classmethod.
+template <class Func, typename... Args>
+static nanobind::object classmethod(Func f, Args... args) {
+  nanobind::object cf = nanobind::cpp_function(f, args...);
+  return nanobind::steal<nanobind::object>(PyClassMethod_New(cf.ptr()));
+}
+
 } // namespace MLIR_BINDINGS_PYTHON_DOMAIN
 } // namespace python
 } // namespace mlir
 
 namespace nanobind {
 namespace detail {
-
 template <>
 struct type_caster<
     mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::DefaultingPyMlirContext>
diff --git a/mlir/include/mlir/CAPI/Dialect/Transform.h b/mlir/include/mlir/CAPI/Dialect/Transform.h
new file mode 100644
index 0000000000000..792236cd8601f
--- /dev/null
+++ b/mlir/include/mlir/CAPI/Dialect/Transform.h
@@ -0,0 +1,28 @@
+//===- Transform.h - C API Utils for Transform dialect ----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains declarations of implementation details of the C API for
+// the Transform dialect. This file should not be included from C++ code other
+// than C API implementation nor from C code.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CAPI_DIALECT_TRANSFORM_H
+#define MLIR_CAPI_DIALECT_TRANSFORM_H
+
+#include "mlir-c/Dialect/Transform.h"
+#include "mlir/CAPI/Wrap.h"
+#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
+
+DEFINE_C_API_PTR_METHODS(MlirTransformRewriter,
+                         mlir::transform::TransformRewriter)
+DEFINE_C_API_PTR_METHODS(MlirTransformResults,
+                         mlir::transform::TransformResults)
+DEFINE_C_API_PTR_METHODS(MlirTransformState, mlir::transform::TransformState)
+
+#endif // MLIR_CAPI_DIALECT_TRANSFORM_H
diff --git a/mlir/include/mlir/CAPI/IR.h b/mlir/include/mlir/CAPI/IR.h
index 1836cb0acb67e..f58ff423d23f7 100644
--- a/mlir/include/mlir/CAPI/IR.h
+++ b/mlir/include/mlir/CAPI/IR.h
@@ -29,6 +29,7 @@ DEFINE_C_API_PTR_METHODS(MlirDialectRegistry, mlir::DialectRegistry)
 DEFINE_C_API_PTR_METHODS(MlirOperation, mlir::Operation)
 DEFINE_C_API_PTR_METHODS(MlirBlock, mlir::Block)
 DEFINE_C_API_PTR_METHODS(MlirOpOperand, mlir::OpOperand)
+DEFINE_C_API_PTR_METHODS(MlirOpResult, mlir::OpResult)
 DEFINE_C_API_PTR_METHODS(MlirOpPrintingFlags, mlir::OpPrintingFlags)
 DEFINE_C_API_PTR_METHODS(MlirRegion, mlir::Region)
 DEFINE_C_API_PTR_METHODS(MlirSymbolTable, mlir::SymbolTable)
diff --git a/mlir/include/mlir/CAPI/Interfaces.h b/mlir/include/mlir/CAPI/Interfaces.h
index 4154b8c9ec6cc..15afc9fb0f18e 100644
--- a/mlir/include/mlir/CAPI/Interfaces.h
+++ b/mlir/include/mlir/CAPI/Interfaces.h
@@ -15,4 +15,12 @@
 #ifndef MLIR_CAPI_INTERFACES_H
 #define MLIR_CAPI_INTERFACES_H
 
+#include "mlir-c/Interfaces.h"
+#include "mlir/CAPI/Wrap.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+
+DEFINE_C_API_PTR_METHODS(
+    MlirMemoryEffectInstancesList,
+    llvm::SmallVectorImpl<mlir::MemoryEffects::EffectInstance>)
+
 #endif // MLIR_CAPI_INTERFACES_H
diff --git a/mlir/lib/Bindings/Python/DialectTransform.cpp b/mlir/lib/Bindings/Python/DialectTransform.cpp
index 82905498921ce..ecb911fd7a3ba 100644
--- a/mlir/lib/Bindings/Python/DialectTransform.cpp
+++ b/mlir/lib/Bindings/Python/DialectTransform.cpp
@@ -8,12 +8,14 @@
 
 #include <string>
 
+#include "IRInterfaces.h"
+#include "Rewrite.h"
 #include "mlir-c/Dialect/Transform.h"
 #include "mlir-c/IR.h"
 #include "mlir-c/Support.h"
 #include "mlir/Bindings/Python/IRCore.h"
-#include "mlir/Bindings/Python/Nanobind.h"
-#include "mlir/Bindings/Python/NanobindAdaptors.h"
+#include "nanobind/nanobind.h"
+#include <nanobind/trampoline.h>
 
 namespace nb = nanobind;
 using namespace mlir::python::nanobind_adaptors;
@@ -22,6 +24,219 @@ namespace mlir {
 namespace python {
 namespace MLIR_BINDINGS_PYTHON_DOMAIN {
 namespace transform {
+
+//===----------------------------------------------------------------------===//
+// TransformRewriter
+//===----------------------------------------------------------------------===//
+class PyTransformRewriter : public PyRewriterBase<PyTransformRewriter> {
+public:
+  static constexpr const char *pyClassName = "TransformRewriter";
+
+  PyTransformRewriter(MlirTransformRewriter rewriter)
+      : PyRewriterBase(mlirTransformRewriterAsBase(rewriter)) {}
+};
+
+//===----------------------------------------------------------------------===//
+// TransformResults
+//===----------------------------------------------------------------------===//
+class PyTransformResults {
+public:
+  PyTransformResults(MlirTransformResults results) : results(results) {}
+
+  MlirTransformResults get() const { return results; }
+
+  void setOps(MlirValue result, const nanobind::list &ops) {
+    std::vector<MlirOperation> opsVec;
+    opsVec.reserve(ops.size());
+    for (auto op : ops) {
+      opsVec.push_back(nb::cast<MlirOperation>(op));
+    }
+    mlirTransformResultsSetOps(results, result, opsVec.size(), opsVec.data());
+  }
+
+  void setValues(MlirValue result, const nanobind::list &values) {
+    std::vector<MlirValue> valuesVec;
+    valuesVec.reserve(values.size());
+    for (auto item : values) {
+      valuesVec.push_back(nb::cast<MlirValue>(item));
+    }
+    mlirTransformResultsSetValues(results, result, valuesVec.size(),
+                                  valuesVec.data());
+  }
+
+  void setParams(MlirValue result, const nanobind::list &params) {
+    std::vector<MlirAttribute> paramsVec;
+    paramsVec.reserve(params.size());
+    for (auto item : params) {
+      paramsVec.push_back(nb::cast<MlirAttribute>(item));
+    }
+    mlirTransformResultsSetParams(results, result, paramsVec.size(),
+                                  paramsVec.data());
+  }
+
+  static void bind(nanobind::module_ &m) {
+    nb::class_<PyTransformResults>(m, "TransformResults")
+        .def(nb::init<MlirTransformResults>())
+        .def("set_ops", &PyTransformResults::setOps,
+             "Set the payload operations for a transform result.",
+             nb::arg("result"), nb::arg("ops"))
+        .def("set_values", &PyTransformResults::setValues,
+             "Set the payload values for a transform result.",
+             nb::arg("result"), nb::arg("values"))
+        .def("set_params", &PyTransformResults::setParams,
+             "Set the parameters for a transform result.", nb::arg("result"),
+             nb::arg("params"));
+  }
+
+private:
+  MlirTransformResults results;
+};
+
+//===----------------------------------------------------------------------===//
+// TransformS...
[truncated]

``````````
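
The `mlirTransformStateForEach*` functions declared in the diff return nothing; they report each payload item through a callback that receives the caller's `userData`. A binding can turn that into a Python list by passing the list as `userData` and appending from the callback. The pure-Python sketch below models that pattern; `FAKE_STATE`, `for_each_payload_op` and `get_payload_ops` are hypothetical stand-ins, and the sketch does not reflect how the PR's (truncated) binding code names or shapes this API.

```python
# Hypothetical model of a C-style "for each" iterator that reports items via
# callback(item, user_data), plus a wrapper that collects them into a list.
from typing import Any, Callable, Dict, List

OperationCallback = Callable[[str, Any], None]

# Stand-in for TransformState: transform IR value -> payload op names.
FAKE_STATE: Dict[str, List[str]] = {
    "%arg0": ["func.func @a", "func.func @b"],
}


def for_each_payload_op(state: Dict[str, List[str]], value: str,
                        callback: OperationCallback, user_data: Any) -> None:
    """Models mlirTransformStateForEachPayloadOp: no return, only callbacks."""
    for op in state.get(value, []):
        callback(op, user_data)


def _append(op: str, user_data: Any) -> None:
    # user_data is the Python list being filled in.
    user_data.append(op)


def get_payload_ops(state: Dict[str, List[str]], value: str) -> List[str]:
    """What a Python-facing accessor could build on top of the iterator."""
    collected: List[str] = []
    for_each_payload_op(state, value, _append, collected)
    return collected


print(get_payload_ops(FAKE_STATE, "%arg0"))  # ['func.func @a', 'func.func @b']
```

A callback-based C interface like this spares the C API from allocating a result array and handing its ownership across the language boundary.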

</details>


https://github.com/llvm/llvm-project/pull/176920

