[Mlir-commits] [mlir] [MLIR][Python] Impl XOpInterface(s) from Python, with X=Transform and X=MemoryEffects (PR #176920)

Rolf Morel llvmlistbot at llvm.org
Tue Jan 27 12:55:41 PST 2026


https://github.com/rolfmorel updated https://github.com/llvm/llvm-project/pull/176920

>From 256fff028969ba360b89edb98c1aea147fd9cea6 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Wed, 7 Jan 2026 12:05:58 -0800
Subject: [PATCH 1/9] [MLIR][Python] Expose TransformOpInterface and
 MemoryEffectsOpInterfaces

Enables implementing these two interfaces from Python through defining
the relevant methods on a Python class which then serve as callbacks for
a new FallbackModel C++ class that acts as a "trampoline" to Python for
when the Interface's methods are called from C++. Like in the C++
codebase, these FallbackModels are a late-binding mechanism which can be
attached to an operation after its definition.
---
 mlir/include/mlir-c/Dialect/Transform.h       | 135 ++++++++++
 mlir/include/mlir-c/IR.h                      |   8 +
 mlir/include/mlir-c/Interfaces.h              |  39 ++-
 mlir/include/mlir/Bindings/Python/IRCore.h    |  11 +-
 mlir/include/mlir/CAPI/Dialect/Transform.h    |  28 ++
 mlir/include/mlir/CAPI/IR.h                   |   1 +
 mlir/include/mlir/CAPI/Interfaces.h           |   8 +
 mlir/lib/Bindings/Python/DialectTransform.cpp | 250 ++++++++++++++++-
 mlir/lib/Bindings/Python/IRCore.cpp           |  53 +++-
 mlir/lib/Bindings/Python/IRInterfaces.cpp     | 181 ++++---------
 mlir/lib/Bindings/Python/IRInterfaces.h       | 156 +++++++++++
 mlir/lib/Bindings/Python/Rewrite.cpp          |  64 +----
 mlir/lib/Bindings/Python/Rewrite.h            |  71 ++++-
 mlir/lib/CAPI/Dialect/Transform.cpp           | 196 ++++++++++++++
 mlir/lib/CAPI/IR/IR.cpp                       |  12 +
 mlir/lib/CAPI/Interfaces/Interfaces.cpp       |  74 +++++
 mlir/lib/IR/OperationSupport.cpp              |   2 +-
 mlir/python/CMakeLists.txt                    |   2 +
 mlir/python/mlir/_mlir_libs/__init__.py       |   1 +
 .../python/dialects/transform_op_interface.py | 252 ++++++++++++++++++
 20 files changed, 1344 insertions(+), 200 deletions(-)
 create mode 100644 mlir/include/mlir/CAPI/Dialect/Transform.h
 create mode 100644 mlir/lib/Bindings/Python/IRInterfaces.h
 create mode 100644 mlir/test/python/dialects/transform_op_interface.py

diff --git a/mlir/include/mlir-c/Dialect/Transform.h b/mlir/include/mlir-c/Dialect/Transform.h
index 911c9ef659a1e..38bd9756176a0 100644
--- a/mlir/include/mlir-c/Dialect/Transform.h
+++ b/mlir/include/mlir-c/Dialect/Transform.h
@@ -10,7 +10,9 @@
 #ifndef MLIR_C_DIALECT_TRANSFORM_H
 #define MLIR_C_DIALECT_TRANSFORM_H
 
+#include "mlir-c/Interfaces.h"
 #include "mlir-c/IR.h"
+#include "mlir-c/Rewrite.h"
 #include "mlir-c/Support.h"
 
 #ifdef __cplusplus
@@ -19,6 +21,32 @@ extern "C" {
 
 MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Transform, transform);
 
+#define DEFINE_C_API_STRUCT(name, storage)                                     \
+  struct name {                                                                \
+    storage *ptr;                                                              \
+  };                                                                           \
+  typedef struct name name
+
+DEFINE_C_API_STRUCT(MlirTransformResults, void);
+DEFINE_C_API_STRUCT(MlirTransformRewriter, void);
+DEFINE_C_API_STRUCT(MlirTransformState, void);
+
+#undef DEFINE_C_API_STRUCT
+
+//===---------------------------------------------------------------------===//
+// DiagnosedSilenceableFailure
+//===---------------------------------------------------------------------===//
+
+/// Enum representing the result of a transform operation.
+typedef enum {
+  /// The operation succeeded.
+  MlirDiagnosedSilenceableFailureSuccess,
+  /// The operation failed in a silenceable way.
+  MlirDiagnosedSilenceableFailureSilenceableFailure,
+  /// The operation failed definitively.
+  MlirDiagnosedSilenceableFailureDefiniteFailure
+} MlirDiagnosedSilenceableFailure;
+
 //===---------------------------------------------------------------------===//
 // AnyOpType
 //===---------------------------------------------------------------------===//
@@ -86,6 +114,113 @@ MLIR_CAPI_EXPORTED MlirStringRef mlirTransformParamTypeGetName(void);
 
 MLIR_CAPI_EXPORTED MlirType mlirTransformParamTypeGetType(MlirType type);
 
+//===---------------------------------------------------------------------===//
+// TransformRewriter
+//===---------------------------------------------------------------------===//
+
+/// Cast the TransformRewriter to a RewriterBase
+MLIR_CAPI_EXPORTED MlirRewriterBase
+mlirTransformRewriterAsBase(MlirTransformRewriter rewriter);
+
+//===---------------------------------------------------------------------===//
+// TransformResults
+//===---------------------------------------------------------------------===//
+
+/// Set the payload operations for a transform result by iterating over a list.
+MLIR_CAPI_EXPORTED void mlirTransformResultsSetOps(MlirTransformResults results,
+                                                   MlirValue result,
+                                                   intptr_t numOps,
+                                                   MlirOperation *ops);
+
+/// Set the payload values for a transform result by iterating over a list.
+MLIR_CAPI_EXPORTED void
+mlirTransformResultsSetValues(MlirTransformResults results, MlirValue result,
+                              intptr_t numValues, MlirValue *values);
+
+/// Set the parameters for a transform result by iterating over a list.
+MLIR_CAPI_EXPORTED void
+mlirTransformResultsSetParams(MlirTransformResults results, MlirValue result,
+                              intptr_t numParams, MlirAttribute *params);
+
+//===---------------------------------------------------------------------===//
+// TransformState
+//===---------------------------------------------------------------------===//
+
+/// Callback for iterating over payload operations.
+typedef void (*MlirOperationCallback)(MlirOperation, void *userData);
+
+/// Iterate over payload operations associated with the transform IR value.
+/// Calls the callback for each payload operation.
+MLIR_CAPI_EXPORTED void
+mlirTransformStateForEachPayloadOp(MlirTransformState state, MlirValue value,
+                                   MlirOperationCallback callback,
+                                   void *userData);
+
+/// Callback for iterating over payload values.
+typedef void (*MlirValueCallback)(MlirValue, void *userData);
+
+/// Iterate over payload values associated with the transform IR value.
+/// Calls the callback for each payload value.
+MLIR_CAPI_EXPORTED void
+mlirTransformStateForEachPayloadValue(MlirTransformState state, MlirValue value,
+                                      MlirValueCallback callback,
+                                      void *userData);
+
+/// Callback for iterating over parameters.
+typedef void (*MlirAttributeCallback)(MlirAttribute, void *userData);
+
+/// Iterate over parameters associated with the transform IR value.
+/// Calls the callback for each parameter.
+MLIR_CAPI_EXPORTED void
+mlirTransformStateForEachParam(MlirTransformState state, MlirValue value,
+                               MlirAttributeCallback callback, void *userData);
+
+//===---------------------------------------------------------------------===//
+// TransformOpInterface
+//===---------------------------------------------------------------------===//
+
+/// Returns the interface TypeID of the TransformOpInterface.
+MLIR_CAPI_EXPORTED MlirTypeID mlirTransformOpInterfaceTypeID(void);
+
+/// Callbacks for implementing TransformOpInterface from external code.
+typedef struct {
+  /// Optional constructor for the user data.
+  /// Set to nullptr to disable it.
+  void (*construct)(void *userData);
+  /// Optional destructor for the user data.
+  /// Set to nullptr to disable it.
+  void (*destruct)(void *userData);
+  /// Apply callback that implements the transformation.
+  MlirDiagnosedSilenceableFailure (*apply)(MlirOperation op,
+                                           MlirTransformRewriter rewriter,
+                                           MlirTransformResults results,
+                                           MlirTransformState state,
+                                           void *userData);
+  /// Callback to check if repeated handle operands are allowed.
+  bool (*allowsRepeatedHandleOperands)(MlirOperation op, void *userData);
+  void *userData;
+} MlirTransformOpInterfaceCallbacks;
+
+/// Attach TransformOpInterface to the operation with the given name using
+/// the provided callbacks.
+MLIR_CAPI_EXPORTED void mlirTransformOpInterfaceAttachFallbackModel(
+    MlirContext ctx, MlirStringRef opName,
+    MlirTransformOpInterfaceCallbacks callbacks);
+
+//===---------------------------------------------------------------------===//
+// Transform-specifc MemoryEffectsOpInterface helpers
+//===---------------------------------------------------------------------===//
+
+/// Helper to mark operands as only reading handles.
+MLIR_CAPI_EXPORTED void
+mlirTransformOnlyReadsHandle(MlirOpOperand *operands, intptr_t numOperands,
+                             MlirMemoryEffectInstancesList effects);
+
+/// Helper to mark results as producing handles.
+MLIR_CAPI_EXPORTED void
+mlirTransformProducesHandle(MlirOpResult *results, intptr_t numResults,
+                            MlirMemoryEffectInstancesList effects);
+
 #ifdef __cplusplus
 }
 #endif
diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 80ff39c82a9ee..9b58267cf1c5a 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -55,6 +55,7 @@ DEFINE_C_API_STRUCT(MlirDialect, void);
 DEFINE_C_API_STRUCT(MlirDialectRegistry, void);
 DEFINE_C_API_STRUCT(MlirOperation, void);
 DEFINE_C_API_STRUCT(MlirOpOperand, void);
+DEFINE_C_API_STRUCT(MlirOpResult, void);
 DEFINE_C_API_STRUCT(MlirOpPrintingFlags, void);
 DEFINE_C_API_STRUCT(MlirBlock, void);
 DEFINE_C_API_STRUCT(MlirRegion, void);
@@ -673,6 +674,10 @@ MLIR_CAPI_EXPORTED intptr_t mlirOperationGetNumOperands(MlirOperation op);
 MLIR_CAPI_EXPORTED MlirValue mlirOperationGetOperand(MlirOperation op,
                                                      intptr_t pos);
 
+/// Returns `pos`-th OpOperand of the operation.
+MLIR_CAPI_EXPORTED MlirOpOperand mlirOperationGetOpOperand(MlirOperation op,
+                                                           intptr_t pos);
+
 /// Sets the `pos`-th operand of the operation.
 MLIR_CAPI_EXPORTED void mlirOperationSetOperand(MlirOperation op, intptr_t pos,
                                                 MlirValue newValue);
@@ -1044,6 +1049,9 @@ MLIR_CAPI_EXPORTED bool mlirValueIsABlockArgument(MlirValue value);
 /// Returns 1 if the value is an operation result, 0 otherwise.
 MLIR_CAPI_EXPORTED bool mlirValueIsAOpResult(MlirValue value);
 
+/// Cast the value to an OpResult. Asserts if the value is not an op result.
+MLIR_CAPI_EXPORTED MlirOpResult mlirValueToOpResult(MlirValue value);
+
 /// Returns the block in which this value is defined as an argument. Asserts if
 /// the value is not a block argument.
 MLIR_CAPI_EXPORTED MlirBlock mlirBlockArgumentGetOwner(MlirValue value);
diff --git a/mlir/include/mlir-c/Interfaces.h b/mlir/include/mlir-c/Interfaces.h
index a5a3473eaef59..17a812dcd86a9 100644
--- a/mlir/include/mlir-c/Interfaces.h
+++ b/mlir/include/mlir-c/Interfaces.h
@@ -22,6 +22,16 @@
 extern "C" {
 #endif
 
+#define DEFINE_C_API_STRUCT(name, storage)                                     \
+  struct name {                                                                \
+    storage *ptr;                                                              \
+  };                                                                           \
+  typedef struct name name
+
+DEFINE_C_API_STRUCT(MlirMemoryEffectInstancesList, void);
+
+#undef DEFINE_C_API_STRUCT
+
 /// Returns `true` if the given operation implements an interface identified by
 /// its TypeID.
 MLIR_CAPI_EXPORTED bool
@@ -42,7 +52,7 @@ mlirOperationImplementsInterfaceStatic(MlirStringRef operationName,
 //===----------------------------------------------------------------------===//
 
 /// Returns the interface TypeID of the InferTypeOpInterface.
-MLIR_CAPI_EXPORTED MlirTypeID mlirInferTypeOpInterfaceTypeID();
+MLIR_CAPI_EXPORTED MlirTypeID mlirInferTypeOpInterfaceTypeID(void);
 
 /// These callbacks are used to return multiple types from functions while
 /// transferring ownership to the caller. The first argument is the number of
@@ -65,7 +75,7 @@ MLIR_CAPI_EXPORTED MlirLogicalResult mlirInferTypeOpInterfaceInferReturnTypes(
 //===----------------------------------------------------------------------===//
 
 /// Returns the interface TypeID of the InferShapedTypeOpInterface.
-MLIR_CAPI_EXPORTED MlirTypeID mlirInferShapedTypeOpInterfaceTypeID();
+MLIR_CAPI_EXPORTED MlirTypeID mlirInferShapedTypeOpInterfaceTypeID(void);
 
 /// These callbacks are used to return multiple shaped type components from
 /// functions while transferring ownership to the caller. The first argument is
@@ -87,6 +97,31 @@ mlirInferShapedTypeOpInterfaceInferReturnTypes(
     void *properties, intptr_t nRegions, MlirRegion *regions,
     MlirShapedTypeComponentsCallback callback, void *userData);
 
+//===---------------------------------------------------------------------===//
+// MemoryEffectsOpInterface
+//===---------------------------------------------------------------------===//
+
+/// Returns the interface TypeID of the MemoryEffectsOpInterface.
+MLIR_CAPI_EXPORTED MlirTypeID mlirMemoryEffectsOpInterfaceTypeID(void);
+
+/// Callbacks for implementing MemoryEffectsOpInterface from external code.
+typedef struct {
+  /// Optional constructor for user data. Set to nullptr to disable it.
+  void (*construct)(void *userData);
+  /// Optional destructor for user data. Set to nullptr to disable it.
+  void (*destruct)(void *userData);
+  /// Get memory effects callback.
+  void (*getEffects)(MlirOperation op, MlirMemoryEffectInstancesList effects,
+                     void *userData);
+  void *userData;
+} MlirMemoryEffectsOpInterfaceCallbacks;
+
+/// Attach a new FallbackModel for the MemoryEffectsOpInterface to the named
+/// operation. The FallbackModel will call the provided callbacks.
+MLIR_CAPI_EXPORTED void mlirMemoryEffectsOpInterfaceAttachFallbackModel(
+    MlirContext ctx, MlirStringRef opName,
+    MlirMemoryEffectsOpInterfaceCallbacks callbacks);
+
 #ifdef __cplusplus
 }
 #endif
diff --git a/mlir/include/mlir/Bindings/Python/IRCore.h b/mlir/include/mlir/Bindings/Python/IRCore.h
index f9fc34e82c972..2460202a31b8a 100644
--- a/mlir/include/mlir/Bindings/Python/IRCore.h
+++ b/mlir/include/mlir/Bindings/Python/IRCore.h
@@ -1496,6 +1496,7 @@ class MLIR_PYTHON_API_EXPORTED PyOperationList {
 class MLIR_PYTHON_API_EXPORTED PyOpOperand {
 public:
   PyOpOperand(MlirOpOperand opOperand) : opOperand(opOperand) {}
+  operator MlirOpOperand() const { return opOperand; }
 
   nanobind::typed<nanobind::object, PyOpView> getOwner() const;
 
@@ -1602,6 +1603,7 @@ class MLIR_PYTHON_API_EXPORTED PyOpResult : public PyConcreteValue<PyOpResult> {
   static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult;
   static constexpr const char *pyClassName = "OpResult";
   using PyConcreteValue::PyConcreteValue;
+  operator MlirOpResult() { return mlirValueToOpResult(castFrom(*this)); }
 
   static void bindDerived(ClassTy &c);
 };
@@ -1838,13 +1840,20 @@ class MLIR_PYTHON_API_EXPORTED PyOpAttributeMap {
 MLIR_PYTHON_API_EXPORTED MlirValue getUniqueResult(MlirOperation operation);
 MLIR_PYTHON_API_EXPORTED void populateIRCore(nanobind::module_ &m);
 MLIR_PYTHON_API_EXPORTED void populateRoot(nanobind::module_ &m);
+
+/// Helper for creating an @classmethod.
+template <class Func, typename... Args>
+static nanobind::object classmethod(Func f, Args... args) {
+  nanobind::object cf = nanobind::cpp_function(f, args...);
+  return nanobind::borrow<nanobind::object>((PyClassMethod_New(cf.ptr())));
+}
+
 } // namespace MLIR_BINDINGS_PYTHON_DOMAIN
 } // namespace python
 } // namespace mlir
 
 namespace nanobind {
 namespace detail {
-
 template <>
 struct type_caster<
     mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::DefaultingPyMlirContext>
diff --git a/mlir/include/mlir/CAPI/Dialect/Transform.h b/mlir/include/mlir/CAPI/Dialect/Transform.h
new file mode 100644
index 0000000000000..792236cd8601f
--- /dev/null
+++ b/mlir/include/mlir/CAPI/Dialect/Transform.h
@@ -0,0 +1,28 @@
+//===- Transform.h - C API Utils for Transform dialect ----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains declarations of implementation details of the C API for
+// the Transform dialect. This file should not be included from C++ code other
+// than C API implementation nor from C code.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CAPI_DIALECT_TRANSFORM_H
+#define MLIR_CAPI_DIALECT_TRANSFORM_H
+
+#include "mlir-c/Dialect/Transform.h"
+#include "mlir/CAPI/Wrap.h"
+#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
+
+DEFINE_C_API_PTR_METHODS(MlirTransformRewriter,
+                         mlir::transform::TransformRewriter)
+DEFINE_C_API_PTR_METHODS(MlirTransformResults,
+                         mlir::transform::TransformResults)
+DEFINE_C_API_PTR_METHODS(MlirTransformState, mlir::transform::TransformState)
+
+#endif // MLIR_CAPI_DIALECT_TRANSFORM_H
diff --git a/mlir/include/mlir/CAPI/IR.h b/mlir/include/mlir/CAPI/IR.h
index 1836cb0acb67e..f58ff423d23f7 100644
--- a/mlir/include/mlir/CAPI/IR.h
+++ b/mlir/include/mlir/CAPI/IR.h
@@ -29,6 +29,7 @@ DEFINE_C_API_PTR_METHODS(MlirDialectRegistry, mlir::DialectRegistry)
 DEFINE_C_API_PTR_METHODS(MlirOperation, mlir::Operation)
 DEFINE_C_API_PTR_METHODS(MlirBlock, mlir::Block)
 DEFINE_C_API_PTR_METHODS(MlirOpOperand, mlir::OpOperand)
+DEFINE_C_API_PTR_METHODS(MlirOpResult, mlir::OpResult)
 DEFINE_C_API_PTR_METHODS(MlirOpPrintingFlags, mlir::OpPrintingFlags)
 DEFINE_C_API_PTR_METHODS(MlirRegion, mlir::Region)
 DEFINE_C_API_PTR_METHODS(MlirSymbolTable, mlir::SymbolTable)
diff --git a/mlir/include/mlir/CAPI/Interfaces.h b/mlir/include/mlir/CAPI/Interfaces.h
index 4154b8c9ec6cc..15afc9fb0f18e 100644
--- a/mlir/include/mlir/CAPI/Interfaces.h
+++ b/mlir/include/mlir/CAPI/Interfaces.h
@@ -15,4 +15,12 @@
 #ifndef MLIR_CAPI_INTERFACES_H
 #define MLIR_CAPI_INTERFACES_H
 
+#include "mlir-c/Interfaces.h"
+#include "mlir/CAPI/Wrap.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+
+DEFINE_C_API_PTR_METHODS(
+    MlirMemoryEffectInstancesList,
+    llvm::SmallVectorImpl<mlir::MemoryEffects::EffectInstance>)
+
 #endif // MLIR_CAPI_INTERFACES_H
diff --git a/mlir/lib/Bindings/Python/DialectTransform.cpp b/mlir/lib/Bindings/Python/DialectTransform.cpp
index 82905498921ce..7dbeb72a0d2a8 100644
--- a/mlir/lib/Bindings/Python/DialectTransform.cpp
+++ b/mlir/lib/Bindings/Python/DialectTransform.cpp
@@ -8,12 +8,14 @@
 
 #include <string>
 
+#include "IRInterfaces.h"
+#include "Rewrite.h"
 #include "mlir-c/Dialect/Transform.h"
 #include "mlir-c/IR.h"
 #include "mlir-c/Support.h"
 #include "mlir/Bindings/Python/IRCore.h"
-#include "mlir/Bindings/Python/Nanobind.h"
-#include "mlir/Bindings/Python/NanobindAdaptors.h"
+#include "nanobind/nanobind.h"
+#include <nanobind/trampoline.h>
 
 namespace nb = nanobind;
 using namespace mlir::python::nanobind_adaptors;
@@ -22,6 +24,208 @@ namespace mlir {
 namespace python {
 namespace MLIR_BINDINGS_PYTHON_DOMAIN {
 namespace transform {
+
+//===----------------------------------------------------------------------===//
+// TransformRewriter
+//===----------------------------------------------------------------------===//
+class PyTransformRewriter : public PyRewriterBase<PyTransformRewriter> {
+public:
+  static constexpr const char *pyClassName = "TransformRewriter";
+
+  PyTransformRewriter(MlirTransformRewriter rewriter)
+      : PyRewriterBase(mlirTransformRewriterAsBase(rewriter)) {}
+};
+
+//===----------------------------------------------------------------------===//
+// TransformResults
+//===----------------------------------------------------------------------===//
+class PyTransformResults {
+public:
+  PyTransformResults(MlirTransformResults results) : results(results) {}
+
+  MlirTransformResults get() const { return results; }
+
+  void setOps(MlirValue result, const nanobind::list &ops) {
+    std::vector<MlirOperation> opsVec;
+    opsVec.reserve(ops.size());
+    for (auto op : ops) {
+      opsVec.push_back(nb::cast<MlirOperation>(op));
+    }
+    mlirTransformResultsSetOps(results, result, opsVec.size(), opsVec.data());
+  }
+
+  void setValues(MlirValue result, const nanobind::list &values) {
+    std::vector<MlirValue> valuesVec;
+    valuesVec.reserve(values.size());
+    for (auto item : values) {
+      valuesVec.push_back(nb::cast<MlirValue>(item));
+    }
+    mlirTransformResultsSetValues(results, result, valuesVec.size(),
+                                  valuesVec.data());
+  }
+
+  void setParams(MlirValue result, const nanobind::list &params) {
+    std::vector<MlirAttribute> paramsVec;
+    paramsVec.reserve(params.size());
+    for (auto item : params) {
+      paramsVec.push_back(nb::cast<MlirAttribute>(item));
+    }
+    mlirTransformResultsSetParams(results, result, paramsVec.size(),
+                                  paramsVec.data());
+  }
+
+  static void bind(nanobind::module_ &m) {
+    nb::class_<PyTransformResults>(m, "TransformResults")
+        .def(nb::init<MlirTransformResults>())
+        .def("set_ops", &PyTransformResults::setOps,
+             "Set the payload operations for a transform result.",
+             nb::arg("result"), nb::arg("ops"))
+        .def("set_values", &PyTransformResults::setValues,
+             "Set the payload values for a transform result.",
+             nb::arg("result"), nb::arg("values"))
+        .def("set_params", &PyTransformResults::setParams,
+             "Set the parameters for a transform result.", nb::arg("result"),
+             nb::arg("params"));
+  }
+
+private:
+  MlirTransformResults results;
+};
+
+//===----------------------------------------------------------------------===//
+// TransformState
+//===----------------------------------------------------------------------===//
+class PyTransformState {
+public:
+  PyTransformState(MlirTransformState state) : state(state) {}
+
+  MlirTransformState get() const { return state; }
+
+  nanobind::list getPayloadOps(MlirValue value) {
+    nanobind::list result;
+    mlirTransformStateForEachPayloadOp(
+        state, value,
+        [](MlirOperation op, void *userData) {
+          static_cast<nanobind::list *>(userData)->append(op);
+        },
+        &result);
+    return result;
+  }
+
+  nanobind::list getPayloadValues(MlirValue value) {
+    nanobind::list result;
+    mlirTransformStateForEachPayloadValue(
+        state, value,
+        [](MlirValue val, void *userData) {
+          static_cast<nanobind::list *>(userData)->append(val);
+        },
+        &result);
+    return result;
+  }
+
+  nanobind::list getParams(MlirValue value) {
+    nanobind::list result;
+    mlirTransformStateForEachParam(
+        state, value,
+        [](MlirAttribute attr, void *userData) {
+          static_cast<nanobind::list *>(userData)->append(attr);
+        },
+        &result);
+    return result;
+  }
+
+  static void bind(nanobind::module_ &m) {
+    nb::class_<PyTransformState>(m, "TransformState")
+        .def(nb::init<MlirTransformState>())
+        .def("get_payload_ops", &PyTransformState::getPayloadOps,
+             "Get the payload operations associated with a transform IR value.",
+             nb::arg("operand"))
+        .def("get_payload_values", &PyTransformState::getPayloadValues,
+             "Get the payload values associated with a transform IR value.",
+             nb::arg("operand"))
+        .def("get_params", &PyTransformState::getParams,
+             "Get the parameters (attributes) associated with a transform IR "
+             "value.",
+             nb::arg("operand"));
+  }
+
+private:
+  MlirTransformState state;
+};
+
+//===----------------------------------------------------------------------===//
+// TransformOpInterface
+//===----------------------------------------------------------------------===//
+class PyTransformOpInterface
+    : public PyConcreteOpInterface<PyTransformOpInterface> {
+public:
+  using PyConcreteOpInterface<PyTransformOpInterface>::PyConcreteOpInterface;
+
+  constexpr static const char *pyClassName = "TransformOpInterface";
+  constexpr static GetTypeIDFunctionTy getInterfaceID =
+      &mlirTransformOpInterfaceTypeID;
+
+  /// Attach a new TransformOpInterface FallbackModel to the named operation.
+  /// The FallbackModel acts as a trampoline for callbacks on the Python class.
+  static void attach(nb::object &pyClass, const std::string &opName,
+                     DefaultingPyMlirContext ctx) {
+    // Prepare the callbacks that will be used by the FallbackModel.
+    MlirTransformOpInterfaceCallbacks callbacks;
+    // Make the pointer to the Python class available to the callbacks.
+    callbacks.userData = pyClass.ptr();
+    nb::handle(static_cast<PyObject *>(callbacks.userData)).inc_ref();
+
+    // The above ref bump is all we need as initialization, no need to run the
+    // construct callback.
+    callbacks.construct = nullptr;
+    // Upon the FallbackModel's destruction, drop the ref to the Python class.
+    callbacks.destruct = [](void *userData) {
+      nb::handle(static_cast<PyObject *>(userData)).dec_ref();
+    };
+    // The apply callback which calls into Python.
+    callbacks.apply = [](MlirOperation op, MlirTransformRewriter rewriter,
+                         MlirTransformResults results, MlirTransformState state,
+                         void *userData) -> MlirDiagnosedSilenceableFailure {
+      nb::handle pyClass(static_cast<PyObject *>(userData));
+
+      auto pyApply = nb::cast<nb::callable>(nb::getattr(pyClass, "apply"));
+
+      auto pyRewriter = PyTransformRewriter(rewriter);
+      auto pyResults = PyTransformResults(results);
+      auto pyState = PyTransformState(state);
+
+      // Invoke `pyClass.apply(op, rewriter, results, state)` as a classmethod.
+      nb::object res = pyApply(op, pyRewriter, pyResults, pyState);
+
+      return nb::cast<MlirDiagnosedSilenceableFailure>(res);
+    };
+
+    // The allows_repeated_handle_operands callback which calls into Python.
+    callbacks.allowsRepeatedHandleOperands = [](MlirOperation op,
+                                                void *userData) -> bool {
+      nb::handle pyClass(static_cast<PyObject *>(userData));
+
+      auto pyAllowRepeatedHandleOperands = nb::cast<nb::callable>(
+          nb::getattr(pyClass, "allow_repeated_handle_operands"));
+
+      // Invoke `pyClass.allow_repeated_handle_operands(op)` as a classmethod.
+      nb::object res = pyAllowRepeatedHandleOperands(op);
+
+      return nb::cast<bool>(res);
+    };
+
+    // Attach a FallbackModel, which calls into Python, to the named operation.
+    mlirTransformOpInterfaceAttachFallbackModel(
+        ctx->get(), wrap(StringRef(opName.c_str())), callbacks);
+  }
+
+  static void bindDerived(ClassTy &transformOpInterfaceClass) {
+    transformOpInterfaceClass.attr("attach") =
+        classmethod(&PyTransformOpInterface::attach, nb::arg("cls"),
+                    nb::arg("op_name"), nb::arg("ctx") = nb::none());
+  }
+};
+
 //===-------------------------------------------------------------------===//
 // AnyOpType
 //===-------------------------------------------------------------------===//
@@ -162,12 +366,54 @@ struct ParamType : PyConcreteType<ParamType> {
   }
 };
 
+//===----------------------------------------------------------------------===//
+// MemoryEffectsOpInterface helpers
+//===----------------------------------------------------------------------===//
+
+namespace {
+void onlyReadsHandle(nb::list &operands, PyMemoryEffectsInstanceList effects) {
+  std::vector<MlirOpOperand> operandsVec;
+  operandsVec.reserve(operands.size());
+  for (auto operand : operands)
+    operandsVec.push_back(nb::cast<PyOpOperand>(operand));
+  mlirTransformOnlyReadsHandle(operandsVec.data(), operandsVec.size(),
+                               effects.effects);
+};
+
+void producesHandle(nb::list &results, PyMemoryEffectsInstanceList effects) {
+  std::vector<MlirOpResult> resultsVec;
+  resultsVec.reserve(results.size());
+  for (auto result : results)
+    resultsVec.push_back(nb::cast<PyOpResult>(result));
+  mlirTransformProducesHandle(resultsVec.data(), resultsVec.size(),
+                              effects.effects);
+};
+} // namespace
+
 static void populateDialectTransformSubmodule(nb::module_ &m) {
+  nb::enum_<MlirDiagnosedSilenceableFailure>(m, "DiagnosedSilenceableFailure")
+      .value("Success", MlirDiagnosedSilenceableFailureSuccess)
+      .value("SilenceableFailure",
+             MlirDiagnosedSilenceableFailureSilenceableFailure)
+      .value("DefiniteFailure", MlirDiagnosedSilenceableFailureDefiniteFailure);
+
   AnyOpType::bind(m);
   AnyParamType::bind(m);
   AnyValueType::bind(m);
   OperationType::bind(m);
   ParamType::bind(m);
+
+  PyTransformRewriter::bind(m);
+  PyTransformResults::bind(m);
+  PyTransformState::bind(m);
+  PyTransformOpInterface::bind(m);
+
+  m.def("only_reads_handle", onlyReadsHandle,
+        "Mark operands as only reading handles.", nb::arg("operands"),
+        nb::arg("effects"));
+
+  m.def("produces_handle", producesHandle, "Mark results as producing handles.",
+        nb::arg("results"), nb::arg("effects"));
 }
 } // namespace transform
 } // namespace MLIR_BINDINGS_PYTHON_DOMAIN
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index eb00363a54034..f3108473c712a 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -10,7 +10,6 @@
 #include "mlir/Bindings/Python/Globals.h"
 #include "mlir/Bindings/Python/IRCore.h"
 #include "mlir/Bindings/Python/NanobindUtils.h"
-#include "mlir/Bindings/Python/NanobindAdaptors.h"
 #include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
 // clang-format on
 #include "mlir-c/BuiltinAttributes.h"
@@ -52,13 +51,6 @@ operations.
 // Utilities.
 //------------------------------------------------------------------------------
 
-/// Helper for creating an @classmethod.
-template <class Func, typename... Args>
-static nb::object classmethod(Func f, Args... args) {
-  nb::object cf = nb::cpp_function(f, args...);
-  return nb::borrow<nb::object>((PyClassMethod_New(cf.ptr())));
-}
-
 static nb::object
 createCustomDialectWrapper(const std::string &dialectNamespace,
                            nb::object dialectDescriptor) {
@@ -2306,6 +2298,44 @@ PyOpOperandList PyOpOperandList::slice(intptr_t startIndex, intptr_t length,
   return PyOpOperandList(operation, startIndex, length, step);
 }
 
+/// A list of OpOperands. Internally, these are stored as consecutive elements,
+/// random access is cheap. The (returned) OpOperand list is associated with the
+/// operation whose operands these are, and thus extends the lifetime of this
+/// operation.
+class PyOpOpOperandList : public Sliceable<PyOpOpOperandList, PyOpOperand> {
+public:
+  static constexpr const char *pyClassName = "OpOpOperandList";
+  using SliceableT = Sliceable<PyOpOperandList, PyOpOperand>;
+
+  PyOpOpOperandList(PyOperationRef operation, intptr_t startIndex = 0,
+                    intptr_t length = -1, intptr_t step = 1)
+      : Sliceable(startIndex,
+                  length == -1 ? mlirOperationGetNumOperands(operation->get())
+                               : length,
+                  step),
+        operation(operation) {}
+
+private:
+  /// Give the parent CRTP class access to hook implementations below.
+  friend class Sliceable<PyOpOpOperandList, PyOpOperand>;
+
+  intptr_t getRawNumElements() {
+    operation->checkValid();
+    return mlirOperationGetNumOperands(operation->get());
+  }
+
+  PyOpOperand getRawElement(intptr_t pos) {
+    MlirOpOperand opOperand = mlirOperationGetOpOperand(operation->get(), pos);
+    return PyOpOperand(opOperand);
+  }
+
+  PyOpOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
+    return PyOpOpOperandList(operation, startIndex, length, step);
+  }
+
+  PyOperationRef operation;
+};
+
 PyOpSuccessors::PyOpSuccessors(PyOperationRef operation, intptr_t startIndex,
                                intptr_t length, intptr_t step)
     : Sliceable(startIndex,
@@ -3534,6 +3564,12 @@ void populateIRCore(nb::module_ &m) {
             return PyOpOperandList(self.getOperation().getRef());
           },
           "Returns the list of operation operands.")
+      .def_prop_ro(
+          "op_operands",
+          [](PyOperationBase &self) {
+            return PyOpOpOperandList(self.getOperation().getRef());
+          },
+          "Returns the list of op operands.")
       .def_prop_ro(
           "regions",
           [](PyOperationBase &self) {
@@ -4825,6 +4861,7 @@ void populateIRCore(nb::module_ &m) {
   PyOpAttributeMap::bind(m);
   PyOpOperandIterator::bind(m);
   PyOpOperandList::bind(m);
+  PyOpOpOperandList::bind(m);
   PyOpResultList::bind(m);
   PyOpSuccessors::bind(m);
   PyRegionIterator::bind(m);
diff --git a/mlir/lib/Bindings/Python/IRInterfaces.cpp b/mlir/lib/Bindings/Python/IRInterfaces.cpp
index 09112d4989ae4..05633746c3136 100644
--- a/mlir/lib/Bindings/Python/IRInterfaces.cpp
+++ b/mlir/lib/Bindings/Python/IRInterfaces.cpp
@@ -12,12 +12,12 @@
 #include <utility>
 #include <vector>
 
+#include "IRInterfaces.h"
 #include "mlir-c/BuiltinAttributes.h"
 #include "mlir-c/IR.h"
 #include "mlir-c/Interfaces.h"
 #include "mlir-c/Support.h"
 #include "mlir/Bindings/Python/IRCore.h"
-#include "mlir/Bindings/Python/Nanobind.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
 
@@ -26,18 +26,6 @@ namespace nb = nanobind;
 namespace mlir {
 namespace python {
 namespace MLIR_BINDINGS_PYTHON_DOMAIN {
-constexpr static const char *constructorDoc =
-    R"(Creates an interface from a given operation/opview object or from a
-subclass of OpView. Raises ValueError if the operation does not implement the
-interface.)";
-
-constexpr static const char *operationDoc =
-    R"(Returns an Operation for which the interface was constructed.)";
-
-constexpr static const char *opviewDoc =
-    R"(Returns an OpView subclass _instance_ for which the interface was
-constructed)";
-
 constexpr static const char *inferReturnTypesDoc =
     R"(Given the arguments required to build an operation, attempts to infer
 its return types. Raises ValueError on failure.)";
@@ -124,119 +112,6 @@ wrapRegions(std::optional<std::vector<PyRegion>> regions) {
 
 } // namespace
 
-/// CRTP base class for Python classes representing MLIR Op interfaces.
-/// Interface hierarchies are flat so no base class is expected here. The
-/// derived class is expected to define the following static fields:
-///  - `const char *pyClassName` - the name of the Python class to create;
-///  - `GetTypeIDFunctionTy getInterfaceID` - the function producing the TypeID
-///    of the interface.
-/// Derived classes may redefine the `bindDerived(ClassTy &)` method to bind
-/// interface-specific methods.
-///
-/// An interface class may be constructed from either an Operation/OpView object
-/// or from a subclass of OpView. In the latter case, only the static interface
-/// methods are available, similarly to calling ConcereteOp::staticMethod on the
-/// C++ side. Implementations of concrete interfaces can use the `isStatic`
-/// method to check whether the interface object was constructed from a class or
-/// an operation/opview instance. The `getOpName` always succeeds and returns a
-/// canonical name of the operation suitable for lookups.
-template <typename ConcreteIface>
-class PyConcreteOpInterface {
-protected:
-  using ClassTy = nb::class_<ConcreteIface>;
-  using GetTypeIDFunctionTy = MlirTypeID (*)();
-
-public:
-  /// Constructs an interface instance from an object that is either an
-  /// operation or a subclass of OpView. In the latter case, only the static
-  /// methods of the interface are accessible to the caller.
-  PyConcreteOpInterface(nb::object object, DefaultingPyMlirContext context)
-      : obj(std::move(object)) {
-    try {
-      operation = &nb::cast<PyOperation &>(obj);
-    } catch (nb::cast_error &) {
-      // Do nothing.
-    }
-
-    try {
-      operation = &nb::cast<PyOpView &>(obj).getOperation();
-    } catch (nb::cast_error &) {
-      // Do nothing.
-    }
-
-    if (operation != nullptr) {
-      if (!mlirOperationImplementsInterface(*operation,
-                                            ConcreteIface::getInterfaceID())) {
-        std::string msg = "the operation does not implement ";
-        throw nb::value_error((msg + ConcreteIface::pyClassName).c_str());
-      }
-
-      MlirIdentifier identifier = mlirOperationGetName(*operation);
-      MlirStringRef stringRef = mlirIdentifierStr(identifier);
-      opName = std::string(stringRef.data, stringRef.length);
-    } else {
-      try {
-        opName = nb::cast<std::string>(obj.attr("OPERATION_NAME"));
-      } catch (nb::cast_error &) {
-        throw nb::type_error(
-            "Op interface does not refer to an operation or OpView class");
-      }
-
-      if (!mlirOperationImplementsInterfaceStatic(
-              mlirStringRefCreate(opName.data(), opName.length()),
-              context.resolve().get(), ConcreteIface::getInterfaceID())) {
-        std::string msg = "the operation does not implement ";
-        throw nb::value_error((msg + ConcreteIface::pyClassName).c_str());
-      }
-    }
-  }
-
-  /// Creates the Python bindings for this class in the given module.
-  static void bind(nb::module_ &m) {
-    nb::class_<ConcreteIface> cls(m, ConcreteIface::pyClassName);
-    cls.def(nb::init<nb::object, DefaultingPyMlirContext>(), nb::arg("object"),
-            nb::arg("context") = nb::none(), constructorDoc)
-        .def_prop_ro("operation", &PyConcreteOpInterface::getOperationObject,
-                     operationDoc)
-        .def_prop_ro("opview", &PyConcreteOpInterface::getOpView, opviewDoc);
-    ConcreteIface::bindDerived(cls);
-  }
-
-  /// Hook for derived classes to add class-specific bindings.
-  static void bindDerived(ClassTy &cls) {}
-
-  /// Returns `true` if this object was constructed from a subclass of OpView
-  /// rather than from an operation instance.
-  bool isStatic() { return operation == nullptr; }
-
-  /// Returns the operation instance from which this object was constructed.
-  /// Throws a type error if this object was constructed from a subclass of
-  /// OpView.
-  nb::typed<nb::object, PyOperation> getOperationObject() {
-    if (operation == nullptr)
-      throw nb::type_error("Cannot get an operation from a static interface");
-    return operation->getRef().releaseObject();
-  }
-
-  /// Returns the opview of the operation instance from which this object was
-  /// constructed. Throws a type error if this object was constructed form a
-  /// subclass of OpView.
-  nb::typed<nb::object, PyOpView> getOpView() {
-    if (operation == nullptr)
-      throw nb::type_error("Cannot get an opview from a static interface");
-    return operation->createOpView();
-  }
-
-  /// Returns the canonical name of the operation this interface is constructed
-  /// from.
-  const std::string &getOpName() { return opName; }
-
-private:
-  PyOperation *operation = nullptr;
-  std::string opName;
-  nb::object obj;
-};
-
 /// Python wrapper for InferTypeOpInterface. This interface has only static
 /// methods.
 class PyInferTypeOpInterface
@@ -464,10 +339,62 @@ class PyInferShapedTypeOpInterface
   }
 };
 
+/// Wrapper around the MemoryEffectsOpInterface.
+class PyMemoryEffectsOpInterface
+    : public PyConcreteOpInterface<PyMemoryEffectsOpInterface> {
+public:
+  using PyConcreteOpInterface<
+      PyMemoryEffectsOpInterface>::PyConcreteOpInterface;
+
+  constexpr static const char *pyClassName = "MemoryEffectsOpInterface";
+  constexpr static GetTypeIDFunctionTy getInterfaceID =
+      &mlirMemoryEffectsOpInterfaceTypeID;
+
+  /// Attach a new MemoryEffectsOpInterface FallbackModel to the named
+  /// operation. The FallbackModel acts as a trampoline for callbacks on the
+  /// Python class.
+  static void attach(nb::object &pySubclass, const std::string &opName,
+                     DefaultingPyMlirContext ctx) {
+    MlirMemoryEffectsOpInterfaceCallbacks callbacks;
+    callbacks.userData = pySubclass.ptr();
+    nb::handle(static_cast<PyObject *>(callbacks.userData)).inc_ref();
+    callbacks.construct = nullptr;
+    callbacks.destruct = [](void *userData) {
+      nb::handle(static_cast<PyObject *>(userData)).dec_ref();
+    };
+    callbacks.getEffects = [](MlirOperation op,
+                              MlirMemoryEffectInstancesList effects,
+                              void *userData) {
+      nb::handle pyClass(static_cast<PyObject *>(userData));
+
+      // Get the 'get_effects' method from the Python class.
+      auto pyGetEffects =
+          nb::cast<nb::callable>(nb::getattr(pyClass, "get_effects"));
+
+      PyMemoryEffectsInstanceList effectsWrapper{effects};
+
+      // Invoke `pyClass.get_effects(op, effects)`.
+      pyGetEffects(op, effectsWrapper);
+    };
+
+    mlirMemoryEffectsOpInterfaceAttachFallbackModel(
+        ctx->get(), wrap(StringRef(opName.c_str())), callbacks);
+  }
+
+  static void bindDerived(ClassTy &memoryEffectsOpInterfaceClass) {
+    memoryEffectsOpInterfaceClass.attr("attach") =
+        classmethod(&PyMemoryEffectsOpInterface::attach, nb::arg("cls"),
+                    nb::arg("op_name"), nb::arg("ctx") = nb::none());
+  }
+};
+
 void populateIRInterfaces(nb::module_ &m) {
+  nb::class_<PyMemoryEffectsInstanceList>(m, "MemoryEffectInstancesList");
+
+  PyInferShapedTypeOpInterface::bind(m);
   PyInferTypeOpInterface::bind(m);
+  PyMemoryEffectsOpInterface::bind(m);
   PyShapedTypeComponents::bind(m);
-  PyInferShapedTypeOpInterface::bind(m);
 }
 } // namespace MLIR_BINDINGS_PYTHON_DOMAIN
 } // namespace python
diff --git a/mlir/lib/Bindings/Python/IRInterfaces.h b/mlir/lib/Bindings/Python/IRInterfaces.h
new file mode 100644
index 0000000000000..3374b505b5034
--- /dev/null
+++ b/mlir/lib/Bindings/Python/IRInterfaces.h
@@ -0,0 +1,156 @@
+//===- IRInterfaces.h - IR Interfaces for Python Bindings -------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_BINDINGS_PYTHON_IRINTERFACES_H
+#define MLIR_BINDINGS_PYTHON_IRINTERFACES_H
+
+#include "mlir-c/IR.h"
+#include "mlir-c/Interfaces.h"
+#include "mlir-c/Support.h"
+#include "mlir/Bindings/Python/IRCore.h"
+
+namespace nb = nanobind;
+
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
+
+constexpr static const char *constructorDoc =
+    R"(Creates an interface from a given operation/opview object or from a
+subclass of OpView. Raises ValueError if the operation does not implement the
+interface.)";
+
+constexpr static const char *operationDoc =
+    R"(Returns an Operation for which the interface was constructed.)";
+
+constexpr static const char *opviewDoc =
+    R"(Returns an OpView subclass _instance_ for which the interface was
+constructed)";
+
+/// CRTP base class for Python classes representing MLIR Op interfaces.
+/// Interface hierarchies are flat so no base class is expected here. The
+/// derived class is expected to define the following static fields:
+///  - `const char *pyClassName` - the name of the Python class to create;
+///  - `GetTypeIDFunctionTy getInterfaceID` - the function producing the TypeID
+///    of the interface.
+/// Derived classes may redefine the `bindDerived(ClassTy &)` method to bind
+/// interface-specific methods.
+///
+/// An interface class may be constructed from either an Operation/OpView object
+/// or from a subclass of OpView. In the latter case, only the static interface
+/// methods are available, similarly to calling ConcereteOp::staticMethod on the
+/// C++ side. Implementations of concrete interfaces can use the `isStatic`
+/// method to check whether the interface object was constructed from a class or
+/// an operation/opview instance. The `getOpName` always succeeds and returns a
+/// canonical name of the operation suitable for lookups.
+template <typename ConcreteIface>
+class PyConcreteOpInterface {
+protected:
+  using ClassTy = nb::class_<ConcreteIface>;
+  using GetTypeIDFunctionTy = MlirTypeID (*)();
+
+public:
+  /// Constructs an interface instance from an object that is either an
+  /// operation or a subclass of OpView. In the latter case, only the static
+  /// methods of the interface are accessible to the caller.
+  PyConcreteOpInterface(nb::object object, DefaultingPyMlirContext context)
+      : obj(std::move(object)) {
+    try {
+      operation = &nb::cast<PyOperation &>(obj);
+    } catch (nb::cast_error &) {
+      // Do nothing.
+    }
+
+    try {
+      operation = &nb::cast<PyOpView &>(obj).getOperation();
+    } catch (nb::cast_error &) {
+      // Do nothing.
+    }
+
+    if (operation != nullptr) {
+      if (!mlirOperationImplementsInterface(*operation,
+                                            ConcreteIface::getInterfaceID())) {
+        std::string msg = "the operation does not implement ";
+        throw nb::value_error((msg + ConcreteIface::pyClassName).c_str());
+      }
+
+      MlirIdentifier identifier = mlirOperationGetName(*operation);
+      MlirStringRef stringRef = mlirIdentifierStr(identifier);
+      opName = std::string(stringRef.data, stringRef.length);
+    } else {
+      try {
+        opName = nb::cast<std::string>(obj.attr("OPERATION_NAME"));
+      } catch (nb::cast_error &) {
+        throw nb::type_error(
+            "Op interface does not refer to an operation or OpView class");
+      }
+
+      if (!mlirOperationImplementsInterfaceStatic(
+              mlirStringRefCreate(opName.data(), opName.length()),
+              context.resolve().get(), ConcreteIface::getInterfaceID())) {
+        std::string msg = "the operation does not implement ";
+        throw nb::value_error((msg + ConcreteIface::pyClassName).c_str());
+      }
+    }
+  }
+
+  /// Creates the Python bindings for this class in the given module.
+  static void bind(nb::module_ &m) {
+    nb::class_<ConcreteIface> cls(m, ConcreteIface::pyClassName);
+    cls.def(nb::init<nb::object, DefaultingPyMlirContext>(), nb::arg("object"),
+            nb::arg("context") = nb::none(), constructorDoc)
+        .def_prop_ro("operation", &PyConcreteOpInterface::getOperationObject,
+                     operationDoc)
+        .def_prop_ro("opview", &PyConcreteOpInterface::getOpView, opviewDoc);
+    ConcreteIface::bindDerived(cls);
+  }
+
+  /// Hook for derived classes to add class-specific bindings.
+  static void bindDerived(ClassTy &cls) {}
+
+  /// Returns `true` if this object was constructed from a subclass of OpView
+  /// rather than from an operation instance.
+  bool isStatic() { return operation == nullptr; }
+
+  /// Returns the operation instance from which this object was constructed.
+  /// Throws a type error if this object was constructed from a subclass of
+  /// OpView.
+  nb::typed<nb::object, PyOperation> getOperationObject() {
+    if (operation == nullptr)
+      throw nb::type_error("Cannot get an operation from a static interface");
+    return operation->getRef().releaseObject();
+  }
+
+  /// Returns the opview of the operation instance from which this object was
+  /// constructed. Throws a type error if this object was constructed form a
+  /// subclass of OpView.
+  nb::typed<nb::object, PyOpView> getOpView() {
+    if (operation == nullptr)
+      throw nb::type_error("Cannot get an opview from a static interface");
+    return operation->createOpView();
+  }
+
+  /// Returns the canonical name of the operation this interface is constructed
+  /// from.
+  const std::string &getOpName() { return opName; }
+
+private:
+  PyOperation *operation = nullptr;
+  std::string opName;
+  nb::object obj;
+};
+
+struct PyMemoryEffectsInstanceList {
+  MlirMemoryEffectInstancesList effects;
+};
+
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
+
+#endif // MLIR_BINDINGS_PYTHON_IRINTERFACES_H
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 2b649f79c5982..dcb95e81fc126 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -8,14 +8,11 @@
 
 #include "Rewrite.h"
 
+#include "mlir-c/Bindings/Python/Interop.h"
 #include "mlir-c/IR.h"
 #include "mlir-c/Rewrite.h"
 #include "mlir-c/Support.h"
 #include "mlir/Bindings/Python/IRCore.h"
-// clang-format off
-#include "mlir/Bindings/Python/Nanobind.h"
-#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
-// clang-format on
 #include "mlir/Config/mlir-config.h"
 #include "nanobind/nanobind.h"
 
@@ -28,38 +25,12 @@ namespace mlir {
 namespace python {
 namespace MLIR_BINDINGS_PYTHON_DOMAIN {
 
-class PyPatternRewriter {
+class PyPatternRewriter : public PyRewriterBase<PyPatternRewriter> {
 public:
-  PyPatternRewriter(MlirPatternRewriter rewriter)
-      : base(mlirPatternRewriterAsBase(rewriter)),
-        ctx(PyMlirContext::forContext(mlirRewriterBaseGetContext(base))) {}
-
-  PyInsertionPoint getInsertionPoint() const {
-    MlirBlock block = mlirRewriterBaseGetInsertionBlock(base);
-    MlirOperation op = mlirRewriterBaseGetOperationAfterInsertion(base);
-
-    if (mlirOperationIsNull(op)) {
-      MlirOperation owner = mlirBlockGetParentOperation(block);
-      auto parent = PyOperation::forOperation(ctx, owner);
-      return PyInsertionPoint(PyBlock(parent, block));
-    }
-
-    return PyInsertionPoint(PyOperation::forOperation(ctx, op));
-  }
-
-  void replaceOp(MlirOperation op, MlirOperation newOp) {
-    mlirRewriterBaseReplaceOpWithOperation(base, op, newOp);
-  }
-
-  void replaceOp(MlirOperation op, const std::vector<MlirValue> &values) {
-    mlirRewriterBaseReplaceOpWithValues(base, op, values.size(), values.data());
-  }
+  static constexpr const char *pyClassName = "PatternRewriter";
 
-  void eraseOp(const PyOperation &op) { mlirRewriterBaseEraseOp(base, op); }
-
-private:
-  MlirRewriterBase base;
-  PyMlirContextRef ctx;
+  PyPatternRewriter(MlirPatternRewriter rewriter)
+      : PyRewriterBase(mlirPatternRewriterAsBase(rewriter)) {}
 };
 
 struct PyMlirPDLResultList : MlirPDLResultList {};
@@ -340,29 +311,8 @@ void populateRewriteSubmodule(nb::module_ &m) {
   //----------------------------------------------------------------------------
   // Mapping of the PatternRewriter
   //----------------------------------------------------------------------------
-  nb::class_<PyPatternRewriter>(m, "PatternRewriter")
-      .def_prop_ro("ip", &PyPatternRewriter::getInsertionPoint,
-                   "The current insertion point of the PatternRewriter.")
-      .def(
-          "replace_op",
-          [](PyPatternRewriter &self, PyOperationBase &op,
-             PyOperationBase &newOp) {
-            self.replaceOp(op.getOperation(), newOp.getOperation());
-          },
-          "Replace an operation with a new operation.", nb::arg("op"),
-          nb::arg("new_op"))
-      .def(
-          "replace_op",
-          [](PyPatternRewriter &self, PyOperationBase &op,
-             const std::vector<PyValue> &values) {
-            std::vector<MlirValue> values_(values.size());
-            std::copy(values.begin(), values.end(), values_.begin());
-            self.replaceOp(op.getOperation(), values_);
-          },
-          "Replace an operation with a list of values.", nb::arg("op"),
-          nb::arg("values"))
-      .def("erase_op", &PyPatternRewriter::eraseOp, "Erase an operation.",
-           nb::arg("op"));
+
+  PyPatternRewriter::bind(m);
 
   //----------------------------------------------------------------------------
   // Mapping of the RewritePatternSet
diff --git a/mlir/lib/Bindings/Python/Rewrite.h b/mlir/lib/Bindings/Python/Rewrite.h
index d287f19187708..9b4dfafb59b86 100644
--- a/mlir/lib/Bindings/Python/Rewrite.h
+++ b/mlir/lib/Bindings/Python/Rewrite.h
@@ -9,13 +9,80 @@
 #ifndef MLIR_BINDINGS_PYTHON_REWRITE_H
 #define MLIR_BINDINGS_PYTHON_REWRITE_H
 
-#include "mlir/Bindings/Python/NanobindUtils.h"
+#include "mlir-c/Rewrite.h"
+#include "mlir/Bindings/Python/IRCore.h"
+
+#include <nanobind/nanobind.h>
+
+namespace nb = nanobind;
 
 namespace mlir {
 namespace python {
 namespace MLIR_BINDINGS_PYTHON_DOMAIN {
+
+/// CRTP Base class for rewriter wrappers.
+template <typename DerivedTy>
+class PyRewriterBase {
+public:
+  PyRewriterBase(MlirRewriterBase rewriter)
+      : base(rewriter),
+        ctx(PyMlirContext::forContext(mlirRewriterBaseGetContext(base))) {}
+
+  PyInsertionPoint getInsertionPoint() const {
+    MlirBlock block = mlirRewriterBaseGetInsertionBlock(base);
+    MlirOperation op = mlirRewriterBaseGetOperationAfterInsertion(base);
+
+    if (mlirOperationIsNull(op)) {
+      MlirOperation owner = mlirBlockGetParentOperation(block);
+      auto parent = PyOperation::forOperation(ctx, owner);
+      return PyInsertionPoint(PyBlock(parent, block));
+    }
+
+    return PyInsertionPoint(PyOperation::forOperation(ctx, op));
+  }
+
+  void replaceOp(MlirOperation op, MlirOperation newOp) {
+    mlirRewriterBaseReplaceOpWithOperation(base, op, newOp);
+  }
+
+  void replaceOp(MlirOperation op, const std::vector<MlirValue> &values) {
+    mlirRewriterBaseReplaceOpWithValues(base, op, values.size(), values.data());
+  }
+
+  void eraseOp(MlirOperation op) { mlirRewriterBaseEraseOp(base, op); }
+
+  static void bind(nanobind::module_ &m) {
+    nb::class_<DerivedTy>(m, DerivedTy::pyClassName)
+        .def_prop_ro("ip", &PyRewriterBase::getInsertionPoint,
+                     "The current insertion point of the PatternRewriter.")
+        .def(
+            "replace_op",
+            [](DerivedTy &self, PyOperationBase &op, PyOperationBase &newOp) {
+              self.replaceOp(op.getOperation(), newOp.getOperation());
+            },
+            "Replace an operation with a new operation.", nb::arg("op"),
+            nb::arg("new_op"))
+        .def(
+            "replace_op",
+            [](DerivedTy &self, PyOperationBase &op,
+               const std::vector<PyValue> &values) {
+              std::vector<MlirValue> values_(values.size());
+              std::copy(values.begin(), values.end(), values_.begin());
+              self.replaceOp(op.getOperation(), values_);
+            },
+            "Replace an operation with a list of values.", nb::arg("op"),
+            nb::arg("values"))
+        .def("erase_op", &DerivedTy::eraseOp, "Erase an operation.",
+             nb::arg("op"));
+  }
+
+private:
+  MlirRewriterBase base;
+  PyMlirContextRef ctx;
+};
+
 void populateRewriteSubmodule(nanobind::module_ &m);
-}
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
 } // namespace python
 } // namespace mlir
 
diff --git a/mlir/lib/CAPI/Dialect/Transform.cpp b/mlir/lib/CAPI/Dialect/Transform.cpp
index 18d4c075dbb9c..fdb27ec5cdc89 100644
--- a/mlir/lib/CAPI/Dialect/Transform.cpp
+++ b/mlir/lib/CAPI/Dialect/Transform.cpp
@@ -8,9 +8,13 @@
 
 #include "mlir-c/Dialect/Transform.h"
 #include "mlir-c/Support.h"
+#include "mlir/CAPI/Dialect/Transform.h"
+#include "mlir/CAPI/Interfaces.h"
 #include "mlir/CAPI/Registration.h"
+#include "mlir/CAPI/Rewrite.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformTypes.h"
+#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
 
 using namespace mlir;
 
@@ -126,3 +130,195 @@ MlirStringRef mlirTransformParamTypeGetName(void) {
 MlirType mlirTransformParamTypeGetType(MlirType type) {
   return wrap(cast<transform::ParamType>(unwrap(type)).getType());
 }
+
+//===---------------------------------------------------------------------===//
+// TransformRewriter
+//===---------------------------------------------------------------------===//
+
+/// Casts a `MlirTransformRewriter` to a `MlirRewriterBase`.
+MlirRewriterBase mlirTransformRewriterAsBase(MlirTransformRewriter rewriter) {
+  mlir::transform::TransformRewriter *t = unwrap(rewriter);
+  mlir::RewriterBase *base = static_cast<mlir::RewriterBase *>(t);
+  return wrap(base);
+}
+
+//===---------------------------------------------------------------------===//
+// TransformResults
+//===---------------------------------------------------------------------===//
+
+void mlirTransformResultsSetOps(MlirTransformResults results, MlirValue result,
+                                intptr_t numOps, MlirOperation *ops) {
+  SmallVector<Operation *> opsVec;
+  opsVec.reserve(numOps);
+  for (intptr_t i = 0; i < numOps; ++i)
+    opsVec.push_back(unwrap(ops[i]));
+  unwrap(results)->set(cast<OpResult>(unwrap(result)), opsVec);
+}
+
+void mlirTransformResultsSetValues(MlirTransformResults results,
+                                   MlirValue result, intptr_t numValues,
+                                   MlirValue *values) {
+  SmallVector<Value> valuesVec;
+  valuesVec.reserve(numValues);
+  for (intptr_t i = 0; i < numValues; ++i)
+    valuesVec.push_back(unwrap(values[i]));
+  unwrap(results)->setValues(cast<OpResult>(unwrap(result)), valuesVec);
+}
+
+void mlirTransformResultsSetParams(MlirTransformResults results,
+                                   MlirValue result, intptr_t numParams,
+                                   MlirAttribute *params) {
+  SmallVector<Attribute> paramsVec;
+  paramsVec.reserve(numParams);
+  for (intptr_t i = 0; i < numParams; ++i)
+    paramsVec.push_back(unwrap(params[i]));
+  unwrap(results)->setParams(cast<OpResult>(unwrap(result)), paramsVec);
+}
+
+//===---------------------------------------------------------------------===//
+// TransformState
+//===---------------------------------------------------------------------===//
+
+void mlirTransformStateForEachPayloadOp(MlirTransformState state,
+                                        MlirValue value,
+                                        MlirOperationCallback callback,
+                                        void *userData) {
+  for (Operation *op : unwrap(state)->getPayloadOps(unwrap(value)))
+    callback(wrap(op), userData);
+}
+
+void mlirTransformStateForEachPayloadValue(MlirTransformState state,
+                                           MlirValue value,
+                                           MlirValueCallback callback,
+                                           void *userData) {
+  for (Value val : unwrap(state)->getPayloadValues(unwrap(value)))
+    callback(wrap(val), userData);
+}
+
+void mlirTransformStateForEachParam(MlirTransformState state, MlirValue value,
+                                    MlirAttributeCallback callback,
+                                    void *userData) {
+  for (Attribute attr : unwrap(state)->getParams(unwrap(value)))
+    callback(wrap(attr), userData);
+}
+
+//===---------------------------------------------------------------------===//
+// TransformOpInterface
+//===---------------------------------------------------------------------===//
+
+MlirTypeID mlirTransformOpInterfaceTypeID(void) {
+  return wrap(transform::TransformOpInterface::getInterfaceID());
+}
+
+/// Fallback model for the TransformOpInterface that uses C API callbacks.
+class TransformOpInterfaceFallbackModel
+    : public mlir::transform::TransformOpInterface::FallbackModel<
+          TransformOpInterfaceFallbackModel> {
+public:
+  /// Sets the callbacks that this FallbackModel will use.
+  /// NB: the callbacks can only be set through this method as the
+  /// RegisteredOperationName::attachInterface mechanism default-constructs
+  /// the FallbackModel without being able to provide arguments.
+  void setCallbacks(MlirTransformOpInterfaceCallbacks callbacks) {
+    this->callbacks = callbacks;
+  }
+
+  ~TransformOpInterfaceFallbackModel() {
+    if (callbacks.destruct)
+      callbacks.destruct(callbacks.userData);
+  }
+
+  static TypeID getInterfaceID() {
+    return transform::TransformOpInterface::getInterfaceID();
+  }
+
+  static bool classof(const mlir::transform::detail::
+                          TransformOpInterfaceInterfaceTraits::Concept *op) {
+    // Enable casting back to the FallbackModel from the Interface. This is
+    // necessary as attachInterface(...) default-constructs the FallbackModel
+    // without being able to pass in the callbacks and returns just the Concept.
+    return true;
+  }
+
+  ::mlir::DiagnosedSilenceableFailure
+  apply(Operation *op, ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::transform::TransformResults &transformResults,
+        ::mlir::transform::TransformState &state) const {
+    assert(callbacks.apply && "apply callback not set");
+
+    MlirDiagnosedSilenceableFailure status =
+        callbacks.apply(wrap(op), wrap(&rewriter), wrap(&transformResults),
+                        wrap(&state), callbacks.userData);
+
+    switch (status) {
+    case MlirDiagnosedSilenceableFailureSuccess:
+      return DiagnosedSilenceableFailure::success();
+    case MlirDiagnosedSilenceableFailureSilenceableFailure:
+      // TODO: enable passing diagnostic info from C API to C++ API.
+      return DiagnosedSilenceableFailure::silenceableFailure(std::move(
+          *(op->emitError()
+            << "TransformOpInterfaceFallbackModel: silenceable failure")
+               .getUnderlyingDiagnostic()));
+    case MlirDiagnosedSilenceableFailureDefiniteFailure:
+      return DiagnosedSilenceableFailure::definiteFailure();
+    }
+    llvm_unreachable("unknown transform status");
+  }
+
+  bool allowsRepeatedHandleOperands(Operation *op) const {
+    assert(callbacks.allowsRepeatedHandleOperands &&
+           "allowsRepeatedHandleOperands callback not set");
+    return callbacks.allowsRepeatedHandleOperands(wrap(op), callbacks.userData);
+  }
+
+private:
+  MlirTransformOpInterfaceCallbacks callbacks;
+};
+
+/// Attach a TransformOpInterface FallbackModel to the given named operation.
+/// The FallbackModel uses the provided callbacks to implement the interface.
+void mlirTransformOpInterfaceAttachFallbackModel(
+    MlirContext ctx, MlirStringRef opName,
+    MlirTransformOpInterfaceCallbacks callbacks) {
+  // Look up the operation definition in the context.
+  std::optional<RegisteredOperationName> opInfo =
+      RegisteredOperationName::lookup(unwrap(opName), unwrap(ctx));
+
+  if (!opInfo.has_value()) {
+    llvm::errs() << "Operation '" << unwrap(opName)
+                 << "' not found in context\n";
+    return;
+  }
+
+  // NB: the following default-constructs the FallbackModel _without_ being able
+  // to provide arguments.
+  opInfo->attachInterface<TransformOpInterfaceFallbackModel>();
+  // Cast to get the underlying FallbackModel and set the callbacks.
+  auto *model = cast<TransformOpInterfaceFallbackModel>(
+      opInfo->getInterface<TransformOpInterfaceFallbackModel>());
+
+  assert(model && "Failed to get TransformOpInterfaceFallbackModel");
+  model->setCallbacks(callbacks);
+}
+
+//===---------------------------------------------------------------------===//
+// MemoryEffectsOpInterface helpers
+//===---------------------------------------------------------------------===//
+
+/// Set the effect for the operands to only read the transform handles.
+void mlirTransformOnlyReadsHandle(MlirOpOperand *operands, intptr_t numOperands,
+                                  MlirMemoryEffectInstancesList effects) {
+  MutableArrayRef<OpOperand> operandArray(unwrap(*operands), numOperands);
+  transform::onlyReadsHandle(operandArray, *unwrap(effects));
+}
+
+/// Set the effect for the results to that they produce transform handles.
+void mlirTransformProducesHandle(MlirOpResult *results, intptr_t numResults,
+                                 MlirMemoryEffectInstancesList effects) {
+  // NB: calling `producesHandle()` `numResults` as we cannot cast array of
+  // `OpResult`s to a single `ResultRange` (and neither is `ResultRange` exposed
+  // to Python). `producesHandle` iterates over the given `ResultRange` anyway.
+  for (intptr_t i = 0; i < numResults; ++i)
+    transform::producesHandle(ResultRange(*unwrap(results[i])),
+                              *unwrap(effects));
+}
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index 09666932004a4..ccfd30e2354c5 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -30,6 +30,7 @@
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Parser/Parser.h"
 #include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/ThreadPool.h"
 
 #include <cstddef>
@@ -714,6 +715,10 @@ MlirValue mlirOperationGetOperand(MlirOperation op, intptr_t pos) {
   return wrap(unwrap(op)->getOperand(static_cast<unsigned>(pos)));
 }
 
+MlirOpOperand mlirOperationGetOpOperand(MlirOperation op, intptr_t pos) {
+  return wrap(&unwrap(op)->getOpOperand(static_cast<unsigned>(pos)));
+}
+
 void mlirOperationSetOperand(MlirOperation op, intptr_t pos,
                              MlirValue newValue) {
   unwrap(op)->setOperand(static_cast<unsigned>(pos), unwrap(newValue));
@@ -1121,6 +1126,13 @@ bool mlirValueIsAOpResult(MlirValue value) {
   return llvm::isa<OpResult>(unwrap(value));
 }
 
+/// Cast an MlirValue to an MlirOpResult, asserting in case of a type mismatch.
+MlirOpResult mlirValueToOpResult(MlirValue value) {
+  return TypeSwitch<Value, MlirOpResult>(unwrap(value))
+      .Case<OpResult>([&](OpResult opResult) { return wrap(&opResult); })
+      .DefaultUnreachable("expected an OpResult");
+}
+
 MlirBlock mlirBlockArgumentGetOwner(MlirValue value) {
   return wrap(llvm::dyn_cast<BlockArgument>(unwrap(value)).getOwner());
 }
diff --git a/mlir/lib/CAPI/Interfaces/Interfaces.cpp b/mlir/lib/CAPI/Interfaces/Interfaces.cpp
index ef3fc23869550..50f007f5d15f6 100644
--- a/mlir/lib/CAPI/Interfaces/Interfaces.cpp
+++ b/mlir/lib/CAPI/Interfaces/Interfaces.cpp
@@ -167,3 +167,77 @@ MlirLogicalResult mlirInferShapedTypeOpInterfaceInferReturnTypes(
   }
   return mlirLogicalResultSuccess();
 }
+
+//===---------------------------------------------------------------------===//
+// MemoryEffectOpInterface
+//===---------------------------------------------------------------------===//
+
+MlirTypeID mlirMemoryEffectsOpInterfaceTypeID() {
+  return wrap(MemoryEffectOpInterface::getInterfaceID());
+}
+
+/// Fallback model for the MemoryEffectsOpInterface that uses C API callbacks.
+class MemoryEffectOpInterfaceFallbackModel
+    : public mlir::MemoryEffectOpInterface::FallbackModel<
+          MemoryEffectOpInterfaceFallbackModel> {
+public:
+  /// Sets the callbacks that this FallbackModel will use.
+  /// NB: the callbacks can only be set through this method as the
+  /// RegisteredOperationName::attachInterface mechanism default-constructs
+  /// the FallbackModel without being able to provide arguments.
+  void setCallbacks(MlirMemoryEffectsOpInterfaceCallbacks callbacks) {
+    this->callbacks = callbacks;
+  }
+
+  ~MemoryEffectOpInterfaceFallbackModel() {
+    if (callbacks.destruct)
+      callbacks.destruct(callbacks.userData);
+  }
+
+  static TypeID getInterfaceID() {
+    return MemoryEffectOpInterface::getInterfaceID();
+  }
+
+  static bool classof(const mlir::MemoryEffectOpInterface::Concept *op) {
+    // Enable casting back to the FallbackModel from the Interface. This is
+    // necessary as attachInterface(...) default-constructs the FallbackModel
+    // without being able to pass in the callbacks and returns just the Concept.
+    return true;
+  }
+
+  void
+  getEffects(Operation *op,
+             SmallVectorImpl<MemoryEffects::EffectInstance> &effects) const {
+    assert(callbacks.getEffects && "getEffects callback not set");
+    MlirMemoryEffectInstancesList cEffects = wrap(&effects);
+    callbacks.getEffects(wrap(op), cEffects, callbacks.userData);
+  }
+
+private:
+  MlirMemoryEffectsOpInterfaceCallbacks callbacks;
+};
+
+/// Attach a MemoryEffectsOpInterface FallbackModel to the given named op.
+/// The FallbackModel uses the provided callbacks to implement the interface.
+void mlirMemoryEffectsOpInterfaceAttachFallbackModel(
+    MlirContext ctx, MlirStringRef opName,
+    MlirMemoryEffectsOpInterfaceCallbacks callbacks) {
+  // Look up the operation definition in the context
+  std::optional<RegisteredOperationName> opInfo =
+      RegisteredOperationName::lookup(unwrap(opName), unwrap(ctx));
+
+  if (!opInfo.has_value()) {
+    llvm::errs() << "Operation '" << unwrap(opName)
+                 << "' not found in context\n";
+    return;
+  }
+
+  // NB: the following default-constructs the FallbackModel _without_ being able
+  // to provide arguments.
+  opInfo->attachInterface<MemoryEffectOpInterfaceFallbackModel>();
+  // Cast to get the underlying FallbackModel and set the callbacks.
+  auto *model = cast<MemoryEffectOpInterfaceFallbackModel>(
+      opInfo->getInterface<MemoryEffectOpInterfaceFallbackModel>());
+  assert(model && "Failed to get MemoryEffectOpInterfaceFallbackModel");
+  model->setCallbacks(callbacks);
+}
diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
index 2a37f3860fe00..c53c59d87d039 100644
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -575,7 +575,7 @@ MutableOperandRange MutableOperandRangeRange::dereference(const OwnerT &object,
 
 ResultRange::ResultRange(OpResult result)
     : ResultRange(static_cast<detail::OpResultImpl *>(Value(result).getImpl()),
-                  1) {}
+                  0) {}
 
 ResultRange::use_range ResultRange::getUses() const {
   return {use_begin(), use_end()};
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 8ab145ada85dd..13bf7624f5626 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -671,8 +671,10 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.Transform.Nanobind
   ROOT_DIR "${PYTHON_SOURCE_DIR}"
   SOURCES
     DialectTransform.cpp
+    Rewrite.h
   PRIVATE_LINK_LIBS
     LLVMSupport
+    MLIRPythonExtension.Core
   EMBED_CAPI_LINK_LIBS
     MLIRCAPIIR
     MLIRCAPITransformDialect
diff --git a/mlir/python/mlir/_mlir_libs/__init__.py b/mlir/python/mlir/_mlir_libs/__init__.py
index c0e8775149d41..303616080a75b 100644
--- a/mlir/python/mlir/_mlir_libs/__init__.py
+++ b/mlir/python/mlir/_mlir_libs/__init__.py
@@ -242,6 +242,7 @@ def __str__(self):
     Sequence.register(ir.BlockPredecessors)
     Sequence.register(ir.OperationList)
     Sequence.register(ir.OpOperandList)
+    Sequence.register(ir.OpOpOperandList)
     Sequence.register(ir.OpResultList)
     Sequence.register(ir.OpSuccessors)
     Sequence.register(ir.RegionSequence)
diff --git a/mlir/test/python/dialects/transform_op_interface.py b/mlir/test/python/dialects/transform_op_interface.py
new file mode 100644
index 0000000000000..85e9ec89a3de3
--- /dev/null
+++ b/mlir/test/python/dialects/transform_op_interface.py
@@ -0,0 +1,252 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+from contextlib import contextmanager
+
+from mlir import ir
+from mlir.ir import TypeAttr, F32Type, UnitAttr
+from mlir.dialects import transform, irdl, func
+from mlir.dialects.transform import AnyOpType, AnyValueType, AnyParamType, structured, interpreter
+
+from mlir.dialects._ods_common import get_op_result_or_value, get_op_results_or_values
+
+
+def run(emit_schedule):
+    with ir.Context(), ir.Location.unknown():
+        irdl_module = ir.Module.create()
+        with ir.InsertionPoint(irdl_module.body):
+            my_transform = irdl.dialect("my_transform")
+            with ir.InsertionPoint(my_transform.body):
+                OneOpInOneOpOut.emit_irdl()
+                OpValParamInParamOpValOut.emit_irdl()
+                OpsParamInParamsOut.emit_irdl()
+        irdl_module.operation.verify()
+
+        irdl.load_dialects(irdl_module)
+
+        payload = emit_payload()
+        schedule = emit_schedule()
+
+        print("payload:", payload)
+        print("schedule:", schedule)
+        named_seq = schedule.operation.regions[0].blocks[0].operations[0]
+
+        interpreter.apply_named_sequence(
+            payload,
+            named_seq,
+            schedule,
+        )
+
+    del payload
+    del schedule
+
+
+# Payload used by all tests
+def emit_payload():
+    payload_module = ir.Module.create()
+    with ir.InsertionPoint(payload_module.body):
+
+        @func.FuncOp.from_py_func(
+            F32Type.get(), F32Type.get(), results=[F32Type.get()]
+        )
+        def name_of_func(a, b):
+            func.ReturnOp([b])
+
+    return payload_module
+
+
+class OneOpInOneOpOut:
+    name = "one_op_in_one_op_out"
+
+    def __init__(self, op_arg: AnyOpType):
+        self.op = ir.Operation.create(
+            "my_transform.one_op_in_one_op_out",
+            [AnyOpType.get()],
+            [get_op_result_or_value(op_arg)],
+        )
+
+    @property
+    def result(self):
+        return self.op.results[0]
+
+    @classmethod
+    def emit_irdl(cls):
+        op = irdl.operation_(cls.name)
+        with ir.InsertionPoint(op.body):
+            op_handle_type = irdl.is_(TypeAttr.get(AnyOpType.get()))
+            irdl.operands_(
+                [op_handle_type],
+                ["arg"],
+                [irdl.Variadicity.single],
+            )
+            irdl.results_([op_handle_type], ["result"], [irdl.Variadicity.single])
+        return op
+
+
+class OpValParamInParamOpValOut:
+    name = "op_val_param_in_param_op_val_out"
+
+    def __init__(
+        self,
+        op_arg: AnyOpType,
+        val: AnyValueType,
+        param: AnyParamType,
+    ):
+        self.op = ir.Operation.create(
+            "my_transform." + self.name,
+            [
+                AnyParamType.get(),
+                AnyOpType.get(),
+                AnyValueType.get(),
+            ],
+            [
+                get_op_result_or_value(op_arg),
+                get_op_result_or_value(val),
+                get_op_result_or_value(param),
+            ],
+        )
+
+    @property
+    def param_res(self):
+        return self.op.results[0]
+
+    @property
+    def op_res(self):
+        return self.op.results[1]
+
+    @property
+    def value_res(self):
+        return self.op.results[2]
+
+    @classmethod
+    def emit_irdl(cls):
+        op = irdl.operation_(cls.name)
+        with ir.InsertionPoint(op.body):
+            op_handle_type = irdl.is_(TypeAttr.get(AnyOpType.get()))
+            value_handle_type = irdl.is_(TypeAttr.get(AnyValueType.get()))
+            param_handle_type = irdl.is_(TypeAttr.get(AnyParamType.get()))
+            irdl.operands_(
+                [op_handle_type, value_handle_type, param_handle_type],
+                ["op_arg", "value_arg", "param_arg"],
+                [
+                    irdl.Variadicity.single,
+                    irdl.Variadicity.single,
+                    irdl.Variadicity.single,
+                ],
+            )
+            irdl.results_(
+                [param_handle_type, op_handle_type, value_handle_type],
+                ["param_res", "op_res", "value_res"],
+                [
+                    irdl.Variadicity.single,
+                    irdl.Variadicity.single,
+                    irdl.Variadicity.single,
+                ],
+            )
+        return op
+
+
+class OpsParamInParamsOut:
+    name = "ops_param_in_params_out"
+
+    def __init__(
+        self,
+        ops: list[AnyOpType],
+        param: AnyParamType,
+    ):
+        self.op = ir.Operation.create(
+            "my_transform." + self.name,
+            [AnyParamType.get()],
+            [get_op_results_or_values(ops), get_op_result_or_value(param)],
+        )
+
+    @property
+    def param_results(self):
+        return self.op.results
+
+    @classmethod
+    def emit_irdl(cls):
+        op = irdl.operation_(cls.name)
+        with ir.InsertionPoint(op.body):
+            op_handle_type = irdl.is_(TypeAttr.get(AnyOpType.get()))
+            param_handle_type = irdl.is_(TypeAttr.get(AnyParamType.get()))
+            irdl.operands_(
+                [op_handle_type, param_handle_type],
+                ["op_args", "param_arg"],
+                [irdl.Variadicity.variadic, irdl.Variadicity.single],
+            )
+            irdl.results_(
+                [param_handle_type], ["param_results"], [irdl.Variadicity.variadic]
+            )
+        return op
+
+
+ at contextmanager
+def schedule_boilerplate():
+    schedule = ir.Module.create()
+    schedule.operation.attributes["transform.with_named_sequence"] = ir.UnitAttr.get()
+    with ir.InsertionPoint(schedule.body):
+        named_sequence = transform.NamedSequenceOp(
+            "__transform_main",
+            [AnyOpType.get()],
+            [AnyOpType.get()],
+            arg_attrs=[{"transform.consumed": UnitAttr.get()}],
+        )
+        with ir.InsertionPoint(named_sequence.body):
+            yield schedule, named_sequence
+
+
+ at run
+def OneOpInOneOpOutTransformOpInterface():
+    class OneOpInOneOpOutTransformOpInterfaceFallbackModel(
+        transform.TransformOpInterface
+    ):
+        @staticmethod
+        def apply(
+            op_: ir.Operation,
+            rewriter: transform.TransformRewriter,
+            results: transform.TransformResults,
+            state: transform.TransformState,
+        ):
+            targets = state.get_payload_ops(arg := op_.operands[0])
+            target_names = [t.opview.name.value for t in targets]
+            print(
+                f"OneOpInOneOpOutTransformOpInterfaceFallbackModel: target_names={target_names}"
+            )
+            results.set_ops(result := op_.results[0], targets)
+            return transform.DiagnosedSilenceableFailure.Success
+
+        @staticmethod
+        def allow_repeated_handle_operands(op_: ir.Operation) -> bool:
+            return False
+
+    OneOpInOneOpOutTransformOpInterfaceFallbackModel.attach(
+        "my_transform.one_op_in_one_op_out", ir.Context.current
+    )
+
+    # TransformOpInterface-implementing ops are also required to implement MemoryEffectsOpInterface.
+    class MemoryEffectsOpInterfaceFallbackModel(ir.MemoryEffectsOpInterface):
+        @staticmethod
+        def get_effects(op_: ir.Operation, effects):
+            transform.only_reads_handle(list(op_.op_operands), effects)
+            transform.produces_handle(list(op_.results), effects)
+
+    MemoryEffectsOpInterfaceFallbackModel.attach(
+        "my_transform.one_op_in_one_op_out", ir.Context.current
+    )
+
+    with schedule_boilerplate() as (schedule, named_seq):
+        print(f"{named_seq=}")
+        func_handle = structured.MatchOp.match_op_names(
+            named_seq.bodyTarget, ["func.func"]
+        ).result
+        func_handle.dump()
+        # CHECK: OneOpInOneOpOutTransformOpInterfaceFallbackModel: target_names=['name_of_func']
+        out = OneOpInOneOpOut(func_handle).result
+        out.dump()
+        print(out.owner)
+        transform.YieldOp([out])
+        named_seq.verify()
+        print("named_seq", named_seq)
+        print("named_seq.parent", named_seq.parent)
+
+    return schedule

>From 369396ca6eb5e0d50eea9f022ec93e0eb9673609 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Tue, 20 Jan 2026 05:11:21 -0800
Subject: [PATCH 2/9] Format fix

---
 mlir/include/mlir-c/Dialect/Transform.h             |  2 +-
 mlir/test/python/dialects/transform_op_interface.py | 12 ++++++++----
 2 files changed, 9 insertions(+), 5 deletions(-)

diff --git a/mlir/include/mlir-c/Dialect/Transform.h b/mlir/include/mlir-c/Dialect/Transform.h
index 38bd9756176a0..b087482345557 100644
--- a/mlir/include/mlir-c/Dialect/Transform.h
+++ b/mlir/include/mlir-c/Dialect/Transform.h
@@ -10,8 +10,8 @@
 #ifndef MLIR_C_DIALECT_TRANSFORM_H
 #define MLIR_C_DIALECT_TRANSFORM_H
 
-#include "mlir-c/Interfaces.h"
 #include "mlir-c/IR.h"
+#include "mlir-c/Interfaces.h"
 #include "mlir-c/Rewrite.h"
 #include "mlir-c/Support.h"
 
diff --git a/mlir/test/python/dialects/transform_op_interface.py b/mlir/test/python/dialects/transform_op_interface.py
index 85e9ec89a3de3..f856f698f38cb 100644
--- a/mlir/test/python/dialects/transform_op_interface.py
+++ b/mlir/test/python/dialects/transform_op_interface.py
@@ -5,7 +5,13 @@
 from mlir import ir
 from mlir.ir import TypeAttr, F32Type, UnitAttr
 from mlir.dialects import transform, irdl, func
-from mlir.dialects.transform import AnyOpType, AnyValueType, AnyParamType, structured, interpreter
+from mlir.dialects.transform import (
+    AnyOpType,
+    AnyValueType,
+    AnyParamType,
+    structured,
+    interpreter,
+)
 
 from mlir.dialects._ods_common import get_op_result_or_value, get_op_results_or_values
 
@@ -45,9 +51,7 @@ def emit_payload():
     payload_module = ir.Module.create()
     with ir.InsertionPoint(payload_module.body):
 
-        @func.FuncOp.from_py_func(
-            F32Type.get(), F32Type.get(), results=[F32Type.get()]
-        )
+        @func.FuncOp.from_py_func(F32Type.get(), F32Type.get(), results=[F32Type.get()])
         def name_of_func(a, b):
             func.ReturnOp([b])
 

>From 508fdf5186bf7ea5b9314059c75dda3e0b8aeaa6 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Mon, 26 Jan 2026 04:24:38 -0800
Subject: [PATCH 3/9] More tests

---
 .../python/dialects/transform_op_interface.py | 299 +++++++++++++++++-
 1 file changed, 285 insertions(+), 14 deletions(-)

diff --git a/mlir/test/python/dialects/transform_op_interface.py b/mlir/test/python/dialects/transform_op_interface.py
index f856f698f38cb..f1b9d4c23a5de 100644
--- a/mlir/test/python/dialects/transform_op_interface.py
+++ b/mlir/test/python/dialects/transform_op_interface.py
@@ -1,11 +1,14 @@
 # RUN: %PYTHON %s | FileCheck %s
 
+from typing import Sequence
+
 from contextlib import contextmanager
 
 from mlir import ir
 from mlir.ir import TypeAttr, F32Type, UnitAttr
-from mlir.dialects import transform, irdl, func
+from mlir.dialects import transform, irdl, func, arith
 from mlir.dialects.transform import (
+    debug as transform_debug,
     AnyOpType,
     AnyValueType,
     AnyParamType,
@@ -22,13 +25,22 @@ def run(emit_schedule):
         with ir.InsertionPoint(irdl_module.body):
             my_transform = irdl.dialect("my_transform")
             with ir.InsertionPoint(my_transform.body):
+                GetNamedAttributeOp.emit_irdl()
                 OneOpInOneOpOut.emit_irdl()
                 OpValParamInParamOpValOut.emit_irdl()
-                OpsParamInParamsOut.emit_irdl()
+                OpsParamsInValuesParamOut.emit_irdl()
         irdl_module.operation.verify()
 
+        print(irdl_module)
         irdl.load_dialects(irdl_module)
 
+        GetNamedAttributeTransformOpInterfaceFallbackModel.attach(
+            "my_transform." + GetNamedAttributeOp.name
+        )
+        GetNamedAttributeMemoryEffectsOpInterfaceFallbackModel.attach(
+            "my_transform." + GetNamedAttributeOp.name
+        )
+
         payload = emit_payload()
         schedule = emit_schedule()
 
@@ -53,11 +65,82 @@ def emit_payload():
 
         @func.FuncOp.from_py_func(F32Type.get(), F32Type.get(), results=[F32Type.get()])
         def name_of_func(a, b):
-            func.ReturnOp([b])
+            c = arith.AddFOp(a, b)
+            d = arith.constant(F32Type.get(), 42.0)
+            e = arith.constant(F32Type.get(), 24.0)
+            func.ReturnOp([c.results[0]])
 
     return payload_module
 
 
+class GetNamedAttributeOp:
+    name = "get_named_attribute"
+
+    def __init__(self, target: AnyOpType, attr_name: str):
+        self.op = ir.Operation.create(
+            "my_transform.get_named_attribute",
+            [AnyParamType.get()],
+            [get_op_result_or_value(target)],
+            {"attr_name": ir.StringAttr.get(attr_name)},
+        )
+
+    @property
+    def attr_as_param(self) -> ir.Value:
+        return self.op.results[0]
+
+    @classmethod
+    def emit_irdl(cls):
+        op = irdl.operation_(cls.name)
+        with ir.InsertionPoint(op.body):
+            op_handle_type = irdl.is_(TypeAttr.get(AnyOpType.get()))
+            param_handle_type = irdl.is_(TypeAttr.get(AnyParamType.get()))
+            name_handle_kind = irdl.base(base_name="#builtin.string")
+            irdl.operands_(
+                [op_handle_type],
+                ["target"],
+                [irdl.Variadicity.single],
+            )
+            irdl.attributes_([name_handle_kind], ["attr_name"])
+            irdl.results_(
+                [param_handle_type], ["attr_as_param"], [irdl.Variadicity.single]
+            )
+        return op
+
+
+class GetNamedAttributeTransformOpInterfaceFallbackModel(
+    transform.TransformOpInterface
+):
+    @staticmethod
+    def apply(
+        op_: ir.Operation,
+        rewriter: transform.TransformRewriter,
+        results: transform.TransformResults,
+        state: transform.TransformState,
+    ):
+        targets = state.get_payload_ops(target := op_.operands[0])
+        associated_attrs = []
+        for target_op in targets:
+            assoc_attr = target_op.attributes.get(op_.attributes["attr_name"].value)
+            if assoc_attr is None:
+                return transform.DiagnosedSilenceableFailure.RecoverableFailure
+            associated_attrs.append(assoc_attr)
+        results.set_params(op_.results[0], associated_attrs)
+        return transform.DiagnosedSilenceableFailure.Success
+
+    @staticmethod
+    def allow_repeated_handle_operands(op_: ir.Operation) -> bool:
+        return False
+
+
+class GetNamedAttributeMemoryEffectsOpInterfaceFallbackModel(
+    ir.MemoryEffectsOpInterface
+):
+    @staticmethod
+    def get_effects(op_: ir.Operation, effects):
+        transform.only_reads_handle(list(op_.op_operands), effects)
+        transform.produces_handle(list(op_.results), effects)
+
+
 class OneOpInOneOpOut:
     name = "one_op_in_one_op_out"
 
@@ -149,37 +232,55 @@ def emit_irdl(cls):
         return op
 
 
-class OpsParamInParamsOut:
-    name = "ops_param_in_params_out"
+class OpsParamsInValuesParamOut:
+    name = "ops_params_in_values_param_out"
 
     def __init__(
         self,
-        ops: list[AnyOpType],
-        param: AnyParamType,
+        value_results: Sequence[AnyValueType],
+        ops: Sequence[AnyOpType],
+        params: Sequence[AnyParamType],
     ):
+        def as_i32(x):
+            return ir.IntegerAttr.get(ir.IntegerType.get_signless(32), x)
+
         self.op = ir.Operation.create(
             "my_transform." + self.name,
-            [AnyParamType.get()],
-            [get_op_results_or_values(ops), get_op_result_or_value(param)],
+            list(value_results) + [AnyParamType.get()],
+            list(get_op_results_or_values(ops))
+            + list(get_op_results_or_values(params)),
+            {
+                "operandSegmentSizes": ir.DenseI32ArrayAttr.get(
+                    [len(ops), len(params)]
+                ),
+                "resultSegmentSizes": ir.DenseI32ArrayAttr.get([len(value_results)]),
+            },
         )
 
     @property
-    def param_results(self):
-        return self.op.results
+    def param(self):
+        return self.op.results[-1]
+
+    @property
+    def values(self):
+        return self.op.results[:-1]
 
     @classmethod
     def emit_irdl(cls):
         op = irdl.operation_(cls.name)
         with ir.InsertionPoint(op.body):
             op_handle_type = irdl.is_(TypeAttr.get(AnyOpType.get()))
+            value_handle_type = irdl.is_(TypeAttr.get(AnyValueType.get()))
             param_handle_type = irdl.is_(TypeAttr.get(AnyParamType.get()))
             irdl.operands_(
                 [op_handle_type, param_handle_type],
-                ["op_args", "param_arg"],
-                [irdl.Variadicity.variadic, irdl.Variadicity.single],
+                ["ops", "params"],
+                [irdl.Variadicity.variadic, irdl.Variadicity.variadic],
             )
             irdl.results_(
-                [param_handle_type], ["param_results"], [irdl.Variadicity.variadic]
+                [value_handle_type, param_handle_type],
+                ["value_results", "param"],
+                [irdl.Variadicity.variadic, irdl.Variadicity.single],
             )
         return op
 
@@ -254,3 +355,173 @@ def get_effects(op_: ir.Operation, effects):
         print("named_seq.parent", named_seq.parent)
 
     return schedule
+
+
+ at run
+def OpValParamInParamOpValOutTransformOpInterface():
+    class OpValParamInParamOpValOutTransformOpInterfaceFallbackModel(
+        transform.TransformOpInterface
+    ):
+        @staticmethod
+        def apply(
+            op_: ir.Operation,
+            rewriter: transform.TransformRewriter,
+            results: transform.TransformResults,
+            state: transform.TransformState,
+        ):
+            ops = state.get_payload_ops(op_.operands[0])
+            values = state.get_payload_values(op_.operands[1])
+            params = state.get_params(op_.operands[2])
+            print(
+                f"OpValParamInParamOpValOutTransformOpInterfaceFallbackModel: ops={len(ops)}, values={len(values)}, params={len(params)}"
+            )
+            results.set_params(op_.results[0], params)
+            results.set_ops(op_.results[1], ops)
+            results.set_values(op_.results[2], values)
+            return transform.DiagnosedSilenceableFailure.Success
+
+        @staticmethod
+        def allow_repeated_handle_operands(op_: ir.Operation) -> bool:
+            return False
+
+    OpValParamInParamOpValOutTransformOpInterfaceFallbackModel.attach(
+        "my_transform.op_val_param_in_param_op_val_out", ir.Context.current
+    )
+
+    # TransformOpInterface-implementing ops are also required to implement MemoryEffectsOpInterface.
+    class MemoryEffectsOpInterfaceFallbackModel(ir.MemoryEffectsOpInterface):
+        @staticmethod
+        def get_effects(op_: ir.Operation, effects):
+            transform.only_reads_handle(list(op_.op_operands), effects)
+            transform.produces_handle(list(op_.results), effects)
+
+    MemoryEffectsOpInterfaceFallbackModel.attach(
+        "my_transform.op_val_param_in_param_op_val_out", ir.Context.current
+    )
+
+    with schedule_boilerplate() as (schedule, named_seq):
+        func_handle = structured.MatchOp.match_op_names(
+            named_seq.bodyTarget, ["func.func"]
+        ).result
+        addf_handle = structured.MatchOp.match_op_names(
+            named_seq.bodyTarget, ["arith.addf"]
+        ).result
+        func_and_addf = transform.MergeHandlesOp([func_handle, addf_handle])
+        value_handle = transform.GetResultOp(
+            AnyValueType.get(), addf_handle, [0]
+        ).result
+        param_handle = transform.ParamConstantOp(
+            AnyParamType.get(), ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 42)
+        ).param
+
+        # CHECK: OpValParamInParamOpValOutTransformOpInterfaceFallbackModel: ops=2, values=1, params=1
+        op_val_param_op = OpValParamInParamOpValOut(
+            func_and_addf, value_handle, param_handle
+        )
+
+        transform.YieldOp([op_val_param_op.op_res])
+        named_seq.verify()
+
+    return schedule
+
+
+ at run
+def OpsParamsInValuesParamOutTransformOpInterface():
+    class OpsParamsInValuesParamOutTransformOpInterfaceFallbackModel(
+        transform.TransformOpInterface
+    ):
+        @staticmethod
+        def apply(
+            op_: ir.Operation,
+            rewriter: transform.TransformRewriter,
+            results: transform.TransformResults,
+            state: transform.TransformState,
+        ):
+            # The last operand is the param. All previous ones are ops.
+            op_handles, param_handles = [], []
+            for operand in op_.operands:
+                if isinstance(operand.type, transform.AnyOpType):
+                    op_handles.append(operand)
+                else:
+                    param_handles.append(operand)
+
+            ops_count = 0
+            value_handles = []
+            for op_handle in op_handles:
+                ops = state.get_payload_ops(op_handle)
+                ops_count += len(ops)
+                value_handles.append(list(op.results[:1] for op in ops))
+
+            param_count = 0
+            param_sum = 0
+            for param_handle in param_handles:
+                params = state.get_params(param_handle)
+                param_count += len(params)
+                param_sum += sum(p.value for p in params)
+
+            print(
+                f"OpsParamInValuesParamOutTransformOpInterfaceFallbackModel: #op_handles={len(op_handles)}, ops_count={ops_count}, #param_handles={len(param_handles)}, param_count={param_count}"
+            )
+
+            assert len(op_.results) + 1 == len(op_handles)
+            for i in range(len(op_.results) - 1):
+                results.set_values(
+                    op_.results[i],
+                    value_handles[i],
+                )
+            results.set_params(
+                op_.results[-1],
+                [ir.IntegerAttr.get(ir.IntegerType.get_signless(32), param_sum)],
+            )
+            return transform.DiagnosedSilenceableFailure.Success
+
+        @staticmethod
+        def allow_repeated_handle_operands(op_: ir.Operation) -> bool:
+            return True
+
+    OpsParamsInValuesParamOutTransformOpInterfaceFallbackModel.attach(
+        "my_transform." + OpsParamsInValuesParamOut.name
+    )
+
+    class OpsParamsInParamsOutMemoryEffectsOpInterfaceFallbackModel(
+        ir.MemoryEffectsOpInterface
+    ):
+        @staticmethod
+        def get_effects(op_: ir.Operation, effects):
+            transform.only_reads_handle(list(op_.op_operands), effects)
+            transform.produces_handle(list(op_.results), effects)
+
+    OpsParamsInParamsOutMemoryEffectsOpInterfaceFallbackModel.attach(
+        "my_transform." + OpsParamsInValuesParamOut.name
+    )
+
+    with schedule_boilerplate() as (schedule, named_seq):
+        func_handle = structured.MatchOp.match_op_names(
+            named_seq.bodyTarget, ["func.func"]
+        ).result
+        csts_handle = structured.MatchOp.match_op_names(
+            named_seq.bodyTarget, ["arith.constant"]
+        ).result
+        csts_as_param = GetNamedAttributeOp(csts_handle, "value").attr_as_param
+
+        param_handle = transform.ParamConstantOp(
+            AnyParamType.get(), ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 123)
+        ).param
+
+        # CHECK: OpsParamInParamsOutTransformOpInterfaceFallbackModel: op_count=2, param_count=1
+        op = OpsParamsInValuesParamOut(
+            [transform.AnyValueType.get()] * 2 + [transform.AnyParamType.get()],
+            [func_handle, csts_handle],
+            [csts_as_param, param_handle],
+        )
+        print(op.op)
+        # CHECK: Sum of params: 189
+        transform_debug.EmitParamAsRemarkOp(op.param, message="Sum of params")
+
+        transform_debug.EmitRemarkAtOp(op.values[0], message="Value results 0")
+        transform_debug.EmitRemarkAtOp(op.values[1], message="Value results 1")
+
+        transform.YieldOp([func_handle])
+        named_seq.verify()
+
+    return schedule

>From 28609afc3d2d14fc8708f134cef2d13731223948 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Mon, 26 Jan 2026 12:57:27 -0800
Subject: [PATCH 4/9] Switch to op def DSL and add more tests

---
 mlir/lib/Bindings/Python/DialectTransform.cpp |  16 +-
 mlir/lib/Bindings/Python/Globals.cpp          |   3 +-
 mlir/python/mlir/dialects/ext.py              |  37 +-
 .../python/dialects/transform_op_interface.py | 571 +++++++-----------
 4 files changed, 266 insertions(+), 361 deletions(-)

diff --git a/mlir/lib/Bindings/Python/DialectTransform.cpp b/mlir/lib/Bindings/Python/DialectTransform.cpp
index 7dbeb72a0d2a8..3756518eeb06a 100644
--- a/mlir/lib/Bindings/Python/DialectTransform.cpp
+++ b/mlir/lib/Bindings/Python/DialectTransform.cpp
@@ -194,8 +194,12 @@ class PyTransformOpInterface
       auto pyResults = PyTransformResults(results);
       auto pyState = PyTransformState(state);
 
-      // Invoke `pyClass.apply(op, rewriter, results, state)` as a classmethod.
-      nb::object res = pyApply(op, pyRewriter, pyResults, pyState);
+      // Invoke `pyClass.apply(opview(op), rewriter, results, state)` as a
+      // staticmethod.
+      PyMlirContextRef context =
+          PyMlirContext::forContext(mlirOperationGetContext(op));
+      auto opview = PyOperation::forOperation(context, op)->createOpView();
+      nb::object res = pyApply(opview, pyRewriter, pyResults, pyState);
 
       return nb::cast<MlirDiagnosedSilenceableFailure>(res);
     };
@@ -208,8 +212,12 @@ class PyTransformOpInterface
       auto pyAllowRepeatedHandleOperands = nb::cast<nb::callable>(
           nb::getattr(pyClass, "allow_repeated_handle_operands"));
 
-      // Invoke `pyClass.allow_repeated_handle_operands(op)` as a classmethod.
-      nb::object res = pyAllowRepeatedHandleOperands(op);
+      // Invoke `pyClass.allow_repeated_handle_operands(opview(op))` as a
+      // staticmethod.
+      PyMlirContextRef context =
+          PyMlirContext::forContext(mlirOperationGetContext(op));
+      auto opview = PyOperation::forOperation(context, op)->createOpView();
+      nb::object res = pyAllowRepeatedHandleOperands(opview);
 
       return nb::cast<bool>(res);
     };
diff --git a/mlir/lib/Bindings/Python/Globals.cpp b/mlir/lib/Bindings/Python/Globals.cpp
index e2e8693ba45f3..590943b445250 100644
--- a/mlir/lib/Bindings/Python/Globals.cpp
+++ b/mlir/lib/Bindings/Python/Globals.cpp
@@ -194,8 +194,7 @@ PyGlobals::lookupOperationClass(llvm::StringRef operationName) {
   // Make sure dialect module is loaded.
   auto split = operationName.split('.');
   llvm::StringRef dialectNamespace = split.first;
-  if (!loadDialectModule(dialectNamespace))
-    return std::nullopt;
+  loadDialectModule(dialectNamespace);
 
   nb::ft_lock_guard lock(mutex);
   auto foundIt = operationClassMap.find(operationName);
diff --git a/mlir/python/mlir/dialects/ext.py b/mlir/python/mlir/dialects/ext.py
index 237c27bf62f77..3fff63dd0578b 100644
--- a/mlir/python/mlir/dialects/ext.py
+++ b/mlir/python/mlir/dialects/ext.py
@@ -29,10 +29,14 @@
     "Dialect",
     "Operand",
     "Result",
+    "register_dialect",
+    "register_operation",
 ]
 
 Operand = ir.Value
 Result = ir.OpResult
+register_dialect = _cext.register_dialect
+register_operation = _cext.register_operation
 
 
 class ConstraintLoweringContext:
@@ -424,6 +428,11 @@ class AddOp(MyInt.Operation, name="add"):
     ```
     """
 
+    class ExtOperation(Operation):
+        def __init__(*args, **kwargs):
+            raise RuntimeError("Cannot instantiate Dialect.ExtOperation directly.")
+
+
     @classmethod
     def __init_subclass__(cls, name: str, **kwargs):
         cls.name = name
@@ -431,7 +440,7 @@ def __init_subclass__(cls, name: str, **kwargs):
         cls.operations = []
         cls.Operation = type(
             "Operation",
-            (Operation,),
+            (cls.ExtOperation,),
             {"_dialect_obj": cls, "_dialect_name": name},
         )
 
@@ -451,21 +460,25 @@ def _emit_module(cls) -> ir.Module:
         return m
 
     @classmethod
-    def load(cls) -> None:
-        if hasattr(cls, "_mlir_module"):
-            raise RuntimeError(f"Dialect {cls.name} is already loaded.")
+    def load(cls, register=True, context: Optional[ir.Context] = None) -> None:
+        context = context or ir.Context.current
 
-        mlir_module = cls._emit_module()
+        try:
+            context.dialects[cls.name]
+            raise RuntimeError(f"Dialect {cls.name} is already loaded.")
+        except IndexError:
+            pass  # Dialect not loaded yet.
 
+        cls._mlir_module = cls._emit_module()
         pm = PassManager()
         pm.add("canonicalize, cse")
-        pm.run(mlir_module.operation)
-
-        irdl.load_dialects(mlir_module)
+        pm.run(cls._mlir_module.operation)
 
-        _cext.register_dialect(cls)
+        irdl.load_dialects(cls._mlir_module)
 
-        for op in cls.operations:
-            _cext.register_operation(cls)(op)
+        if register:
+            _cext.register_dialect(cls)
 
-        cls._mlir_module = mlir_module
+            register_dialect_operation = _cext.register_operation(cls)
+            for op in cls.operations:
+                register_dialect_operation(op)
diff --git a/mlir/test/python/dialects/transform_op_interface.py b/mlir/test/python/dialects/transform_op_interface.py
index f1b9d4c23a5de..ab919911bdd31 100644
--- a/mlir/test/python/dialects/transform_op_interface.py
+++ b/mlir/test/python/dialects/transform_op_interface.py
@@ -1,14 +1,13 @@
-# RUN: %PYTHON %s | FileCheck %s
+# RUN: env PYTHONUNBUFFERED=1 %PYTHON %s 2>&1 | FileCheck %s
 
 from typing import Sequence
 
 from contextlib import contextmanager
 
 from mlir import ir
-from mlir.ir import TypeAttr, F32Type, UnitAttr
-from mlir.dialects import transform, irdl, func, arith
+from mlir.ir import F32Type, UnitAttr
+from mlir.dialects import transform, func, arith, ext
 from mlir.dialects.transform import (
-    debug as transform_debug,
     AnyOpType,
     AnyValueType,
     AnyParamType,
@@ -16,47 +15,31 @@
     interpreter,
 )
 
-from mlir.dialects._ods_common import get_op_result_or_value, get_op_results_or_values
+ at ext.register_dialect
+class MyTransform(ext.Dialect, name="my_transform"):
+    pass
 
 
 def run(emit_schedule):
-    with ir.Context(), ir.Location.unknown():
-        irdl_module = ir.Module.create()
-        with ir.InsertionPoint(irdl_module.body):
-            my_transform = irdl.dialect("my_transform")
-            with ir.InsertionPoint(my_transform.body):
-                GetNamedAttributeOp.emit_irdl()
-                OneOpInOneOpOut.emit_irdl()
-                OpValParamInParamOpValOut.emit_irdl()
-                OpsParamsInValuesParamOut.emit_irdl()
-        irdl_module.operation.verify()
-
-        print(irdl_module)
-        irdl.load_dialects(irdl_module)
-
-        GetNamedAttributeTransformOpInterfaceFallbackModel.attach(
-            "my_transform." + GetNamedAttributeOp.name
-        )
-        GetNamedAttributeMemoryEffectsOpInterfaceFallbackModel.attach(
-            "my_transform." + GetNamedAttributeOp.name
-        )
-
+    print(f"Test: {emit_schedule.__name__}")
+    with ir.Context() as ctx, ir.Location.unknown():
         payload = emit_payload()
-        schedule = emit_schedule()
 
-        print("payload:", payload)
-        print("schedule:", schedule)
-        named_seq = schedule.operation.regions[0].blocks[0].operations[0]
+        MyTransform.load(register=False)
+
+        GetNamedAttributeOp.attach_interface_impls(ctx)
+        PrintParamOp.attach_interface_impls(ctx)
+
+        # NB: Other newly defined my_transform ops have their interfaces attached
+        #     in their respective test functions.
+        schedule = emit_schedule()
 
         interpreter.apply_named_sequence(
             payload,
-            named_seq,
+            named_seq := schedule.operation.regions[0].blocks[0].operations[0],
             schedule,
         )
 
-    del payload
-    del schedule
-
 
 # Payload used by all tests
 def emit_payload():
@@ -65,338 +48,200 @@ def emit_payload():
 
         @func.FuncOp.from_py_func(F32Type.get(), F32Type.get(), results=[F32Type.get()])
         def name_of_func(a, b):
-            c = arith.AddFOp(a, b)
-            d = arith.constant(F32Type.get(), 42.0)
-            e = arith.constant(F32Type.get(), 24.0)
-            func.ReturnOp([c.results[0]])
+            c = arith.addf(a, b)
+            i32 = ir.IntegerType.get_signless(32)
+            c42 = arith.constant(i32, 42)
+            c24 = arith.constant(i32, 24)
+            func.ReturnOp([c])
 
     return payload_module
 
 
-class GetNamedAttributeOp:
-    name = "get_named_attribute"
-
-    def __init__(self, target: AnyOpType, attr_name: str):
-        self.op = ir.Operation.create(
-            "my_transform.get_named_attribute",
-            [AnyParamType.get()],
-            [get_op_result_or_value(target)],
-            {"attr_name": ir.StringAttr.get(attr_name)},
+ at contextmanager
+def schedule_boilerplate():
+    schedule = ir.Module.create()
+    schedule.operation.attributes["transform.with_named_sequence"] = ir.UnitAttr.get()
+    with ir.InsertionPoint(schedule.body):
+        named_sequence = transform.NamedSequenceOp(
+            "__transform_main",
+            [AnyOpType.get()],
+            [AnyOpType.get()],
+            arg_attrs=[{"transform.consumed": UnitAttr.get()}],
         )
-
-    @property
-    def attr_as_param(self) -> ir.Value:
-        return self.op.results[0]
-
-    @classmethod
-    def emit_irdl(cls):
-        op = irdl.operation_(cls.name)
-        with ir.InsertionPoint(op.body):
-            op_handle_type = irdl.is_(TypeAttr.get(AnyOpType.get()))
-            param_handle_type = irdl.is_(TypeAttr.get(AnyParamType.get()))
-            name_handle_kind = irdl.base(base_name="#builtin.string")
-            irdl.operands_(
-                [op_handle_type],
-                ["target"],
-                [irdl.Variadicity.single],
-            )
-            irdl.attributes_([name_handle_kind], ["attr_name"])
-            irdl.results_(
-                [param_handle_type], ["attr_as_param"], [irdl.Variadicity.single]
-            )
-        return op
-
-
-class GetNamedAttributeTransformOpInterfaceFallbackModel(
-    transform.TransformOpInterface
-):
-    @staticmethod
-    def apply(
-        op_: ir.Operation,
-        rewriter: transform.TransformRewriter,
-        results: transform.TransformResults,
-        state: transform.TransformState,
-    ):
-        targets = state.get_payload_ops(target := op_.operands[0])
-        associated_attrs = []
-        for target_op in targets:
-            assoc_attr = target_op.attributes.get(op_.attributes["attr_name"].value)
-            if assoc_attr is None:
-                return transform.DiagnosedSilenceableFailure.RecoverableFailure
-            associated_attrs.append(assoc_attr)
-        results.set_params(op_.results[0], associated_attrs)
-        return transform.DiagnosedSilenceableFailure.Success
-
-    @staticmethod
-    def allow_repeated_handle_operands(op_: ir.Operation) -> bool:
-        return False
+        with ir.InsertionPoint(named_sequence.body):
+            yield schedule, named_sequence
 
 
-class GetNamedAttributeMemoryEffectsOpInterfaceFallbackModel(
-    ir.MemoryEffectsOpInterface
-):
+# MemoryEffectsOpInterface implementation for TransformOpInterface-implementing ops.
+# Used by all ops defined below.
+class MemoryEffectsOpInterfaceFallbackModel(ir.MemoryEffectsOpInterface):
     @staticmethod
     def get_effects(op_: ir.Operation, effects):
         transform.only_reads_handle(list(op_.op_operands), effects)
         transform.produces_handle(list(op_.results), effects)
 
 
-class OneOpInOneOpOut:
-    name = "one_op_in_one_op_out"
-
-    def __init__(self, op_arg: AnyOpType):
-        self.op = ir.Operation.create(
-            "my_transform.one_op_in_one_op_out",
-            [AnyOpType.get()],
-            [get_op_result_or_value(op_arg)],
-        )
-
-    @property
-    def result(self):
-        return self.op.results[0]
+# Demonstration of a TransformOpInterface-implementing op that gets named attributes
+# from target ops and produces them as param handles.
+ at ext.register_operation(MyTransform)
+class GetNamedAttributeOp(MyTransform.Operation, name="get_named_attribute"):
+    target: ext.Operand[transform.AnyOpType]
+    attr_name: ir.StringAttr
+    attr_as_param: ext.Result[transform.AnyParamType[()]]
 
     @classmethod
-    def emit_irdl(cls):
-        op = irdl.operation_(cls.name)
-        with ir.InsertionPoint(op.body):
-            op_handle_type = irdl.is_(TypeAttr.get(AnyOpType.get()))
-            irdl.operands_(
-                [op_handle_type],
-                ["arg"],
-                [irdl.Variadicity.single],
-            )
-            irdl.results_([op_handle_type], ["result"], [irdl.Variadicity.single])
-        return op
-
-
-class OpValParamInParamOpValOut:
-    name = "op_val_param_in_param_op_val_out"
-
-    def __init__(
-        self,
-        op_arg: AnyOpType,
-        val: AnyValueType,
-        param: AnyParamType,
-    ):
-        self.op = ir.Operation.create(
-            "my_transform." + self.name,
-            [
-                AnyParamType.get(),
-                AnyOpType.get(),
-                AnyValueType.get(),
-            ],
-            [
-                get_op_result_or_value(op_arg),
-                get_op_result_or_value(val),
-                get_op_result_or_value(param),
-            ],
-        )
+    def attach_interface_impls(cls, ctx=None):
+        cls.TransformOpInterfaceFallbackModel.attach(cls.OPERATION_NAME, ctx)
+        MemoryEffectsOpInterfaceFallbackModel.attach(cls.OPERATION_NAME, ctx)
 
-    @property
-    def param_res(self):
-        return self.op.results[0]
+    class TransformOpInterfaceFallbackModel(transform.TransformOpInterface):
+        @staticmethod
+        def apply(
+            op: "GetNamedAttributeOp",
+            rewriter: transform.TransformRewriter,
+            results: transform.TransformResults,
+            state: transform.TransformState,
+        ):
+            target_ops = state.get_payload_ops(op.target)
+            associated_attrs = []
+            for target_op in target_ops:
+                assoc_attr = target_op.attributes.get(op.attr_name.value)
+                if assoc_attr is None:
+                    return transform.DiagnosedSilenceableFailure.RecoverableFailure
+                associated_attrs.append(assoc_attr)
+            results.set_params(op.attr_as_param, associated_attrs)
+            return transform.DiagnosedSilenceableFailure.Success
 
-    @property
-    def op_res(self):
-        return self.op.results[1]
+        @staticmethod
+        def allow_repeated_handle_operands(op: "GetNamedAttributeOp") -> bool:
+            return False
 
-    @property
-    def value_res(self):
-        return self.op.results[2]
 
-    @classmethod
-    def emit_irdl(cls):
-        op = irdl.operation_(cls.name)
-        with ir.InsertionPoint(op.body):
-            op_handle_type = irdl.is_(TypeAttr.get(AnyOpType.get()))
-            value_handle_type = irdl.is_(TypeAttr.get(AnyValueType.get()))
-            param_handle_type = irdl.is_(TypeAttr.get(AnyParamType.get()))
-            irdl.operands_(
-                [op_handle_type, value_handle_type, param_handle_type],
-                ["op_arg", "value_arg", "param_arg"],
-                [
-                    irdl.Variadicity.single,
-                    irdl.Variadicity.single,
-                    irdl.Variadicity.single,
-                ],
-            )
-            irdl.results_(
-                [param_handle_type, op_handle_type, value_handle_type],
-                ["param_res", "op_res", "value_res"],
-                [
-                    irdl.Variadicity.single,
-                    irdl.Variadicity.single,
-                    irdl.Variadicity.single,
-                ],
-            )
-        return op
-
-
-class OpsParamsInValuesParamOut:
-    name = "ops_params_in_values_param_out"
-
-    def __init__(
-        self,
-        value_results: Sequence[AnyValueType],
-        ops: Sequence[AnyOpType],
-        params: Sequence[AnyParamType],
-    ):
-        def as_i32(x):
-            return ir.IntegerAttr.get(ir.IntegerType.get_signless(32), x)
-
-        self.op = ir.Operation.create(
-            "my_transform." + self.name,
-            list(value_results) + [AnyParamType.get()],
-            list(get_op_results_or_values(ops))
-            + list(get_op_results_or_values(params)),
-            {
-                "operandSegmentSizes": ir.DenseI32ArrayAttr.get(
-                    [len(ops), len(params)]
-                ),
-                "resultSegmentSizes": ir.DenseI32ArrayAttr.get([len(value_results)]),
-            },
-        )
+ at ext.register_operation(MyTransform)
+class PrintParamOp(MyTransform.Operation, name="print_param"):
+    target: ext.Operand[transform.AnyParamType]
+    name: ir.StringAttr
 
-    @property
-    def param(self):
-        return self.op.results[-1]
+    @classmethod
+    def attach_interface_impls(cls, ctx=None):
+        cls.TransformOpInterfaceFallbackModel.attach(cls.OPERATION_NAME, ctx)
+        MemoryEffectsOpInterfaceFallbackModel.attach(cls.OPERATION_NAME, ctx)
 
-    @property
-    def values(self):
-        return self.op.results[:-1]
+    class TransformOpInterfaceFallbackModel(transform.TransformOpInterface):
+        @staticmethod
+        def apply(
+            op: "PrintParamOp",
+            rewriter: transform.TransformRewriter,
+            results: transform.TransformResults,
+            state: transform.TransformState,
+        ):
+            target_attrs = state.get_params(op.target)
+            print(f"[[[ IR printer: {op.name.value} ]]]")
+            for attr in target_attrs:
+                print(attr)
+            return transform.DiagnosedSilenceableFailure.Success
 
-    @classmethod
-    def emit_irdl(cls):
-        op = irdl.operation_(cls.name)
-        with ir.InsertionPoint(op.body):
-            op_handle_type = irdl.is_(TypeAttr.get(AnyOpType.get()))
-            value_handle_type = irdl.is_(TypeAttr.get(AnyValueType.get()))
-            param_handle_type = irdl.is_(TypeAttr.get(AnyParamType.get()))
-            irdl.operands_(
-                [op_handle_type, param_handle_type],
-                ["ops", "params"],
-                [irdl.Variadicity.variadic, irdl.Variadicity.variadic],
-            )
-            irdl.results_(
-                [value_handle_type, param_handle_type],
-                ["value_results", "param"],
-                [irdl.Variadicity.variadic, irdl.Variadicity.single],
-            )
-        return op
+        @staticmethod
+        def allow_repeated_handle_operands(op: "GetNamedAttributeOp") -> bool:
+            return False
 
 
- at contextmanager
-def schedule_boilerplate():
-    schedule = ir.Module.create()
-    schedule.operation.attributes["transform.with_named_sequence"] = ir.UnitAttr.get()
-    with ir.InsertionPoint(schedule.body):
-        named_sequence = transform.NamedSequenceOp(
-            "__transform_main",
-            [AnyOpType.get()],
-            [AnyOpType.get()],
-            arg_attrs=[{"transform.consumed": UnitAttr.get()}],
-        )
-        with ir.InsertionPoint(named_sequence.body):
-            yield schedule, named_sequence
+# Syntax for an op with one op handle operand and one op handle result.
+ at ext.register_operation(MyTransform)
+class OneOpInOneOpOut(MyTransform.Operation, name="one_op_in_one_op_out"):
+    target: ext.Operand[transform.AnyOpType]
+    res: ext.Result[transform.AnyOpType[()]]
 
 
+# CHECK-LABEL: Test: OneOpInOneOpOutTransformOpInterface
 @run
 def OneOpInOneOpOutTransformOpInterface():
-    class OneOpInOneOpOutTransformOpInterfaceFallbackModel(
-        transform.TransformOpInterface
-    ):
+    # Define an implementation of the TransformOpInterface for OneOpInOneOpOut.
+    class TransformOpInterfaceFallbackModel(transform.TransformOpInterface):
         @staticmethod
         def apply(
-            op_: ir.Operation,
+            op: OneOpInOneOpOut,
             rewriter: transform.TransformRewriter,
             results: transform.TransformResults,
             state: transform.TransformState,
         ):
-            targets = state.get_payload_ops(arg := op_.operands[0])
-            target_names = [t.opview.name.value for t in targets]
-            print(
-                f"OneOpInOneOpOutTransformOpInterfaceFallbackModel: target_names={target_names}"
-            )
-            results.set_ops(result := op_.results[0], targets)
+            target_ops = state.get_payload_ops(op.target)
+            target_names = [t.opview.name.value for t in target_ops]
+            print(f"OneOpInOneOpOutTransformOpInterface: target_names={target_names}")
+            results.set_ops(op.res, target_ops)
             return transform.DiagnosedSilenceableFailure.Success
 
         @staticmethod
-        def allow_repeated_handle_operands(op_: ir.Operation) -> bool:
+        def allow_repeated_handle_operands(op: OneOpInOneOpOut) -> bool:
             return False
 
-    OneOpInOneOpOutTransformOpInterfaceFallbackModel.attach(
-        "my_transform.one_op_in_one_op_out", ir.Context.current
-    )
-
-    # TransformOpInterface-implementing ops are also required to implement MemoryEffectsOpInterface.
-    class MemoryEffectsOpInterfaceFallbackModel(ir.MemoryEffectsOpInterface):
-        @staticmethod
-        def get_effects(op_: ir.Operation, effects):
-            transform.only_reads_handle(list(op_.op_operands), effects)
-            transform.produces_handle(list(op_.results), effects)
+    # Attach the interface implementation to the op.
+    TransformOpInterfaceFallbackModel.attach(OneOpInOneOpOut.OPERATION_NAME)
 
-    MemoryEffectsOpInterfaceFallbackModel.attach(
-        "my_transform.one_op_in_one_op_out", ir.Context.current
-    )
+    # TransformOpInterface-implementing ops are also required to implement MemoryEffectsOpInterface. The above defined fallback model works for this op.
+    MemoryEffectsOpInterfaceFallbackModel.attach(OneOpInOneOpOut.OPERATION_NAME)
 
     with schedule_boilerplate() as (schedule, named_seq):
-        print(f"{named_seq=}")
         func_handle = structured.MatchOp.match_op_names(
             named_seq.bodyTarget, ["func.func"]
         ).result
-        func_handle.dump()
-        # CHECK: OneOpInOneOpOutTransformOpInterfaceFallbackModel: target_names=['name_of_func']
+        # CHECK: OneOpInOneOpOutTransformOpInterface: target_names=['name_of_func']
         out = OneOpInOneOpOut(func_handle).result
-        out.dump()
-        print(out.owner)
+        # CHECK: Output handle from OneOpInOneOpOut
+        # CHECK-NEXT: func.func @
+        transform.PrintOp(target=out, name="Output handle from OneOpInOneOpOut")
         transform.YieldOp([out])
-        named_seq.verify()
-        print("named_seq", named_seq)
-        print("named_seq.parent", named_seq.parent)
 
     return schedule
 
 
+ at ext.register_operation(MyTransform)
+class OpValParamInParamOpValOut(
+    MyTransform.Operation, name="op_val_param_in_param_op_val_out"
+):
+    # operands
+    op_arg: ext.Operand[transform.AnyOpType]
+    val_arg: ext.Operand[transform.AnyValueType]
+    param_arg: ext.Operand[transform.AnyParamType]
+    # results
+    param_res: ext.Result[transform.AnyParamType[()]]
+    op_res: ext.Result[transform.AnyOpType[()]]
+    value_res: ext.Result[transform.AnyValueType[()]]
+
+
+# CHECK-LABEL: Test: OpValParamInParamOpValOutTransformOpInterface
 @run
 def OpValParamInParamOpValOutTransformOpInterface():
-    class OpValParamInParamOpValOutTransformOpInterfaceFallbackModel(
-        transform.TransformOpInterface
-    ):
+    class TransformOpInterfaceFallbackModel(transform.TransformOpInterface):
         @staticmethod
         def apply(
-            op_: ir.Operation,
+            op: OpValParamInParamOpValOut,
             rewriter: transform.TransformRewriter,
             results: transform.TransformResults,
             state: transform.TransformState,
         ):
-            ops = state.get_payload_ops(op_.operands[0])
-            values = state.get_payload_values(op_.operands[1])
-            params = state.get_params(op_.operands[2])
+            ops = state.get_payload_ops(op.op_arg)
+            values = state.get_payload_values(op.val_arg)
+            params = state.get_params(op.param_arg)
             print(
-                f"OpValParamInParamOpValOutTransformOpInterfaceFallbackModel: ops={len(ops)}, values={len(values)}, params={len(params)}"
+                f"OpValParamInParamOpValOutTransformOpInterface: ops={len(ops)}, values={len(values)}, params={len(params)}"
             )
-            results.set_params(op_.results[0], params)
-            results.set_ops(op_.results[1], ops)
-            results.set_values(op_.results[2], values)
+            results.set_params(op.param_res, params)
+            results.set_ops(op.op_res, ops)
+            results.set_values(op.value_res, values)
             return transform.DiagnosedSilenceableFailure.Success
 
         @staticmethod
-        def allow_repeated_handle_operands(op_: ir.Operation) -> bool:
+        def allow_repeated_handle_operands(op: OpValParamInParamOpValOut) -> bool:
             return False
 
-    OpValParamInParamOpValOutTransformOpInterfaceFallbackModel.attach(
-        "my_transform.op_val_param_in_param_op_val_out", ir.Context.current
+    TransformOpInterfaceFallbackModel.attach(
+        OpValParamInParamOpValOut.OPERATION_NAME, ir.Context.current
     )
 
-    # TransformOpInterface-implementing ops are also required to implement MemoryEffectsOpInterface.
-    class MemoryEffectsOpInterfaceFallbackModel(ir.MemoryEffectsOpInterface):
-        @staticmethod
-        def get_effects(op_: ir.Operation, effects):
-            transform.only_reads_handle(list(op_.op_operands), effects)
-            transform.produces_handle(list(op_.results), effects)
-
+    # TransformOpInterface-implementing ops are also required to implement MemoryEffectsOpInterface. The above defined fallback model works for this op.
     MemoryEffectsOpInterfaceFallbackModel.attach(
-        "my_transform.op_val_param_in_param_op_val_out", ir.Context.current
+        OpValParamInParamOpValOut.OPERATION_NAME, ir.Context.current
     )
 
     with schedule_boilerplate() as (schedule, named_seq):
@@ -414,10 +259,36 @@ def get_effects(op_: ir.Operation, effects):
             AnyParamType.get(), ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 42)
         ).param
 
-        # CHECK: OpValParamInParamOpValOutTransformOpInterfaceFallbackModel: ops=2, values=1, params=1
+        # CHECK: OpValParamInParamOpValOutTransformOpInterface: ops=2, values=1, params=1
         op_val_param_op = OpValParamInParamOpValOut(
             func_and_addf, value_handle, param_handle
         )
+        # CHECK: Ops passed through OpValParamInParamOpValOut:
+        # CHECK-NEXT: func.func
+        # CHECK: arith.addf
+        transform.PrintOp(
+            target=op_val_param_op.op_res,
+            name="Ops passed through OpValParamInParamOpValOut:",
+        )
+
+        # CHECK: Ops defining values passed through OpValParamInParamOpValOut:
+        # CHECK-NEXT: arith.addf
+        addf_as_res = transform.GetDefiningOp(
+            transform.AnyOpType.get(), op_val_param_op.value_res
+        ).result
+        transform.PrintOp(
+            target=addf_as_res,
+            name="Ops defining values passed through OpValParamInParamOpValOut:",
+        )
+
+        # CHECK: Parameter passed through OpValParamInParamOpValOut:
+        # CHECK-NEXT: 42 : i32
+        PrintParamOp(
+            op_val_param_op.param_res,
+            name=ir.StringAttr.get(
+                "Parameter passed through OpValParamInParamOpValOut:"
+            ),
+        )
 
         transform.YieldOp([op_val_param_op.op_res])
         named_seq.verify()
@@ -425,74 +296,64 @@ def get_effects(op_: ir.Operation, effects):
     return schedule
 
 
+ at ext.register_operation(MyTransform)
+class OpsParamsInValuesParamOut(
+    MyTransform.Operation, name="ops_params_in_values_param_out"
+):
+    # operands
+    ops: Sequence[ext.Operand[transform.AnyOpType]]
+    params: Sequence[ext.Operand[transform.AnyParamType]]
+    # results
+    values: Sequence[ext.Result[transform.AnyValueType]]
+    param: ext.Result[transform.AnyParamType]
+
+
+# CHECK-LABEL: Test: OpsParamsInValuesParamOutTransformOpInterface
 @run
 def OpsParamsInValuesParamOutTransformOpInterface():
-    class OpsParamsInValuesParamOutTransformOpInterfaceFallbackModel(
-        transform.TransformOpInterface
-    ):
+    class TransformOpInterfaceFallbackModel(transform.TransformOpInterface):
         @staticmethod
         def apply(
-            op_: ir.Operation,
+            op: OpsParamsInValuesParamOut,
             rewriter: transform.TransformRewriter,
             results: transform.TransformResults,
             state: transform.TransformState,
         ):
-            # The last operand is the param. All previous ones are ops.
-            op_handles, param_handles = [], []
-            for operand in op_.operands:
-                if isinstance(operand.type, transform.AnyOpType):
-                    op_handles.append(operand)
-                else:
-                    param_handles.append(operand)
-
             ops_count = 0
             value_handles = []
-            for op_handle in op_handles:
+            for op_handle in op.ops:
                 ops = state.get_payload_ops(op_handle)
                 ops_count += len(ops)
-                value_handles.append(list(op.results[:1] for op in ops))
+                value_handles.append([i for op in ops for i in op.results])
 
             param_count = 0
             param_sum = 0
-            for param_handle in param_handles:
+            for param_handle in op.params:
                 params = state.get_params(param_handle)
                 param_count += len(params)
                 param_sum += sum(p.value for p in params)
 
             print(
-                f"OpsParamInValuesParamOutTransformOpInterfaceFallbackModel: #op_handles={len(op_handles)}, ops_count={ops_count}, #param_handles={len(param_handles)}, param_count={param_count}"
+                f"OpsParamsInValuesParamOutTransformOpInterfaceFallbackModel: op_count={ops_count}, param_count={param_count}"
             )
 
-            assert len(op_.results) + 1 == len(op_handles)
-            for i in range(len(op_.results) - 1):
-                results.set_values(
-                    op_.results[i],
-                    value_handles[i],
-                )
+            assert len(op.values) == len(op.ops)
+            for value_res_handle, value_vector in zip(op.values, value_handles):
+                results.set_values(value_res_handle, value_vector)
             results.set_params(
-                op_.results[-1],
+                op.param,
                 [ir.IntegerAttr.get(ir.IntegerType.get_signless(32), param_sum)],
             )
             return transform.DiagnosedSilenceableFailure.Success
 
         @staticmethod
-        def allow_repeated_handle_operands(op_: ir.Operation) -> bool:
+        def allow_repeated_handle_operands(op: OpsParamsInValuesParamOut) -> bool:
             return True
 
-    OpsParamsInValuesParamOutTransformOpInterfaceFallbackModel.attach(
-        "my_transform." + OpsParamsInValuesParamOut.name
-    )
+    TransformOpInterfaceFallbackModel.attach(OpsParamsInValuesParamOut.OPERATION_NAME)
 
-    class OpsParamsInParamsOutMemoryEffectsOpInterfaceFallbackModel(
-        ir.MemoryEffectsOpInterface
-    ):
-        @staticmethod
-        def get_effects(op_: ir.Operation, effects):
-            transform.only_reads_handle(list(op_.op_operands), effects)
-            transform.produces_handle(list(op_.results), effects)
-
-    OpsParamsInParamsOutMemoryEffectsOpInterfaceFallbackModel.attach(
-        "my_transform." + OpsParamsInValuesParamOut.name
+    MemoryEffectsOpInterfaceFallbackModel.attach(
+        OpsParamsInValuesParamOut.OPERATION_NAME
     )
 
     with schedule_boilerplate() as (schedule, named_seq):
@@ -502,24 +363,48 @@ def get_effects(op_: ir.Operation, effects):
         csts_handle = structured.MatchOp.match_op_names(
             named_seq.bodyTarget, ["arith.constant"]
         ).result
-        csts_as_param = GetNamedAttributeOp(csts_handle, "value").attr_as_param
+        csts_as_param = GetNamedAttributeOp(
+            csts_handle, attr_name=ir.StringAttr.get("value")
+        ).attr_as_param
 
         param_handle = transform.ParamConstantOp(
             AnyParamType.get(), ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 123)
         ).param
 
-        # CHECK: OpsParamInParamsOutTransformOpInterfaceFallbackModel: op_count=2, param_count=1
+        # CHECK: OpsParamsInValuesParamOutTransformOpInterfaceFallbackModel: op_count=3, param_count=3
         op = OpsParamsInValuesParamOut(
-            [transform.AnyValueType.get()] * 2 + [transform.AnyParamType.get()],
+            [transform.AnyValueType.get()] * 2,
+            transform.AnyParamType.get(),
             [func_handle, csts_handle],
             [csts_as_param, param_handle],
         )
-        print(op.op)
-        # CHECK: Sum of params: 189
-        transform_debug.EmitParamAsRemarkOp(op.param, message="Sum of params")
 
-        transform_debug.EmitRemarkAtOp(op.values[0], message="Value results 0")
-        transform_debug.EmitRemarkAtOp(op.values[1], message="Value results 1")
+        empty_handle = transform.GetDefiningOp(transform.AnyOpType.get(), op.values[0])
+        # CHECK: Defining op of value result 0
+        transform.PrintOp(
+            target=empty_handle.result, name="Defining op of value result 0"
+        )
+        # NB: no result on the func.func, so output is expected to be empty
+        cst1_res, cst2_res = transform.SplitHandleOp(
+            [transform.AnyValueType.get()] * 2, op.values[1]
+        ).results
+
+        cst1_again = transform.GetDefiningOp(transform.AnyOpType.get(), cst1_res)
+        # CHECK-NEXT: Defining op of first constant
+        # CHECK-NEXT: arith.constant 42 : i32
+        transform.PrintOp(
+            target=cst1_again.result, name="Defining op of first constant"
+        )
+        cst2_again = transform.GetDefiningOp(transform.AnyOpType.get(), cst2_res)
+        # CHECK-NEXT: Defining op of second constant
+        # CHECK-NEXT: arith.constant 24 : i32
+        transform.PrintOp(
+            target=cst2_again.result, name="Defining op of second constant"
+        )
+
+        # CHECK: Sum of params:
+        # CHECK-NEXT: 189 : i32
+        PrintParamOp(op.param, name=ir.StringAttr.get("Sum of params:"))
 
         transform.YieldOp([func_handle])
         named_seq.verify()

>From d4efd172b3fb7ac483dc60234e00b9afd32f05be Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Tue, 27 Jan 2026 06:48:10 -0800
Subject: [PATCH 5/9] New rewriter test and general cleanup

---
 mlir/include/mlir-c/Dialect/Transform.h       |   5 +
 mlir/lib/Bindings/Python/DialectTransform.cpp |  23 ++-
 mlir/lib/CAPI/Dialect/Transform.cpp           |   7 +
 .../python/dialects/transform_op_interface.py | 131 +++++++++++++-----
 4 files changed, 126 insertions(+), 40 deletions(-)

diff --git a/mlir/include/mlir-c/Dialect/Transform.h b/mlir/include/mlir-c/Dialect/Transform.h
index b087482345557..674d5f8b7b72d 100644
--- a/mlir/include/mlir-c/Dialect/Transform.h
+++ b/mlir/include/mlir-c/Dialect/Transform.h
@@ -216,6 +216,11 @@ MLIR_CAPI_EXPORTED void
 mlirTransformOnlyReadsHandle(MlirOpOperand *operands, intptr_t numOperands,
                              MlirMemoryEffectInstancesList effects);
 
+/// Helper to mark operands as consuming handles.
+MLIR_CAPI_EXPORTED void
+mlirTransformConsumesHandle(MlirOpOperand *operands, intptr_t numOperands,
+                            MlirMemoryEffectInstancesList effects);
+
 /// Helper to mark results as producing handles.
 MLIR_CAPI_EXPORTED void
 mlirTransformProducesHandle(MlirOpResult *results, intptr_t numResults,
diff --git a/mlir/lib/Bindings/Python/DialectTransform.cpp b/mlir/lib/Bindings/Python/DialectTransform.cpp
index 3756518eeb06a..16a025a0581e5 100644
--- a/mlir/lib/Bindings/Python/DialectTransform.cpp
+++ b/mlir/lib/Bindings/Python/DialectTransform.cpp
@@ -106,7 +106,10 @@ class PyTransformState {
     mlirTransformStateForEachPayloadOp(
         state, value,
         [](MlirOperation op, void *userData) {
-          static_cast<nanobind::list *>(userData)->append(op);
+          PyMlirContextRef context =
+              PyMlirContext::forContext(mlirOperationGetContext(op));
+          auto opview = PyOperation::forOperation(context, op)->createOpView();
+          static_cast<nanobind::list *>(userData)->append(opview);
         },
         &result);
     return result;
@@ -379,18 +382,24 @@ struct ParamType : PyConcreteType<ParamType> {
 //===----------------------------------------------------------------------===//
 
 namespace {
-void onlyReadsHandle(nb::list &operands, PyMemoryEffectsInstanceList effects) {
+void onlyReadsHandle(nb::iterable &operands, PyMemoryEffectsInstanceList effects) {
   std::vector<MlirOpOperand> operandsVec;
-  operandsVec.reserve(operands.size());
   for (auto operand : operands)
     operandsVec.push_back(nb::cast<PyOpOperand>(operand));
   mlirTransformOnlyReadsHandle(operandsVec.data(), operandsVec.size(),
                                effects.effects);
 };
 
-void producesHandle(nb::list &results, PyMemoryEffectsInstanceList effects) {
+void consumesHandle(nb::iterable &operands, PyMemoryEffectsInstanceList effects) {
+  std::vector<MlirOpOperand> operandsVec;
+  for (auto operand : operands)
+    operandsVec.push_back(nb::cast<PyOpOperand>(operand));
+  mlirTransformConsumesHandle(operandsVec.data(), operandsVec.size(),
+                              effects.effects);
+};
+
+void producesHandle(nb::iterable &results, PyMemoryEffectsInstanceList effects) {
   std::vector<MlirOpResult> resultsVec;
-  resultsVec.reserve(results.size());
   for (auto result : results)
     resultsVec.push_back(nb::cast<PyOpResult>(result));
   mlirTransformProducesHandle(resultsVec.data(), resultsVec.size(),
@@ -420,6 +429,10 @@ static void populateDialectTransformSubmodule(nb::module_ &m) {
         "Mark operands as only reading handles.", nb::arg("operands"),
         nb::arg("effects"));
 
+  m.def("consumes_handle", consumesHandle,
+        "Mark operands as consuming handles.", nb::arg("operands"),
+        nb::arg("effects"));
+
   m.def("produces_handle", producesHandle, "Mark results as producing handles.",
         nb::arg("results"), nb::arg("effects"));
 }
diff --git a/mlir/lib/CAPI/Dialect/Transform.cpp b/mlir/lib/CAPI/Dialect/Transform.cpp
index fdb27ec5cdc89..121de43c1d065 100644
--- a/mlir/lib/CAPI/Dialect/Transform.cpp
+++ b/mlir/lib/CAPI/Dialect/Transform.cpp
@@ -312,6 +312,13 @@ void mlirTransformOnlyReadsHandle(MlirOpOperand *operands, intptr_t numOperands,
   transform::onlyReadsHandle(operandArray, *unwrap(effects));
 }
 
+/// Set the effect for the operands to consuming the transform handles.
+void mlirTransformConsumesHandle(MlirOpOperand *operands, intptr_t numOperands,
+                                 MlirMemoryEffectInstancesList effects) {
+  MutableArrayRef<OpOperand> operandArray(unwrap(*operands), numOperands);
+  transform::consumesHandle(operandArray, *unwrap(effects));
+}
+
 /// Set the effect for the results to that they produce transform handles.
 void mlirTransformProducesHandle(MlirOpResult *results, intptr_t numResults,
                                  MlirMemoryEffectInstancesList effects) {
diff --git a/mlir/test/python/dialects/transform_op_interface.py b/mlir/test/python/dialects/transform_op_interface.py
index ab919911bdd31..b2bc97ca3daeb 100644
--- a/mlir/test/python/dialects/transform_op_interface.py
+++ b/mlir/test/python/dialects/transform_op_interface.py
@@ -5,9 +5,9 @@
 from contextlib import contextmanager
 
 from mlir import ir
-from mlir.ir import F32Type, UnitAttr
-from mlir.dialects import transform, func, arith, ext
+from mlir.dialects import index, transform, func, arith, ext
 from mlir.dialects.transform import (
+    DiagnosedSilenceableFailure,
     AnyOpType,
     AnyValueType,
     AnyParamType,
@@ -15,6 +15,7 @@
     interpreter,
 )
 
+
 @ext.register_dialect
 class MyTransform(ext.Dialect, name="my_transform"):
     pass
@@ -36,7 +37,7 @@ def run(emit_schedule):
 
         interpreter.apply_named_sequence(
             payload,
-            named_seq := schedule.operation.regions[0].blocks[0].operations[0],
+            _named_seq := schedule.operation.regions[0].blocks[0].operations[0],
             schedule,
         )
 
@@ -45,13 +46,14 @@ def run(emit_schedule):
 def emit_payload():
     payload_module = ir.Module.create()
     with ir.InsertionPoint(payload_module.body):
+        f32 = ir.F32Type.get()
 
-        @func.FuncOp.from_py_func(F32Type.get(), F32Type.get(), results=[F32Type.get()])
+        @func.FuncOp.from_py_func(f32, f32, results=[f32])
         def name_of_func(a, b):
             c = arith.addf(a, b)
             i32 = ir.IntegerType.get_signless(32)
-            c42 = arith.constant(i32, 42)
-            c24 = arith.constant(i32, 24)
+            arith.constant(i32, 42)
+            arith.constant(i32, 24)
             func.ReturnOp([c])
 
     return payload_module
@@ -66,19 +68,19 @@ def schedule_boilerplate():
             "__transform_main",
             [AnyOpType.get()],
             [AnyOpType.get()],
-            arg_attrs=[{"transform.consumed": UnitAttr.get()}],
+            arg_attrs=[{"transform.consumed": ir.UnitAttr.get()}],
         )
         with ir.InsertionPoint(named_sequence.body):
             yield schedule, named_sequence
 
 
 # MemoryEffectsOpInterface implementation for TransformOpInterface-implementing ops.
-# Used by all ops defined below.
+# Used by most ops defined below.
 class MemoryEffectsOpInterfaceFallbackModel(ir.MemoryEffectsOpInterface):
     @staticmethod
-    def get_effects(op_: ir.Operation, effects):
-        transform.only_reads_handle(list(op_.op_operands), effects)
-        transform.produces_handle(list(op_.results), effects)
+    def get_effects(op: ir.Operation, effects):
+        transform.only_reads_handle(op.op_operands, effects)
+        transform.produces_handle(op.results, effects)
 
 
 # Demonstration of a TransformOpInterface-implementing op that gets named attributes
@@ -98,22 +100,22 @@ class TransformOpInterfaceFallbackModel(transform.TransformOpInterface):
         @staticmethod
         def apply(
             op: "GetNamedAttributeOp",
-            rewriter: transform.TransformRewriter,
+            _rewriter: transform.TransformRewriter,
             results: transform.TransformResults,
             state: transform.TransformState,
-        ):
+        ) -> DiagnosedSilenceableFailure:
             target_ops = state.get_payload_ops(op.target)
             associated_attrs = []
             for target_op in target_ops:
                 assoc_attr = target_op.attributes.get(op.attr_name.value)
                 if assoc_attr is None:
-                    return transform.DiagnosedSilenceableFailure.RecoverableFailure
+                    return DiagnosedSilenceableFailure.RecoverableFailure
                 associated_attrs.append(assoc_attr)
             results.set_params(op.attr_as_param, associated_attrs)
-            return transform.DiagnosedSilenceableFailure.Success
+            return DiagnosedSilenceableFailure.Success
 
         @staticmethod
-        def allow_repeated_handle_operands(op: "GetNamedAttributeOp") -> bool:
+        def allow_repeated_handle_operands(_op: "GetNamedAttributeOp") -> bool:
             return False
 
 
@@ -134,15 +136,15 @@ def apply(
             rewriter: transform.TransformRewriter,
             results: transform.TransformResults,
             state: transform.TransformState,
-        ):
+        ) -> DiagnosedSilenceableFailure:
             target_attrs = state.get_params(op.target)
             print(f"[[[ IR printer: {op.name.value} ]]]")
             for attr in target_attrs:
                 print(attr)
-            return transform.DiagnosedSilenceableFailure.Success
+            return DiagnosedSilenceableFailure.Success
 
         @staticmethod
-        def allow_repeated_handle_operands(op: "GetNamedAttributeOp") -> bool:
+        def allow_repeated_handle_operands(_op: "GetNamedAttributeOp") -> bool:
             return False
 
 
@@ -156,23 +158,23 @@ class OneOpInOneOpOut(MyTransform.Operation, name="one_op_in_one_op_out"):
 # CHECK-LABEL: Test: OneOpInOneOpOutTransformOpInterface
 @run
 def OneOpInOneOpOutTransformOpInterface():
-    # Define an implementation of the TransformOpInterface for OneOpInOneOpOut.
+    # Define a simple passthrough implementation of the TransformOpInterface for OneOpInOneOpOut.
     class TransformOpInterfaceFallbackModel(transform.TransformOpInterface):
         @staticmethod
         def apply(
             op: OneOpInOneOpOut,
-            rewriter: transform.TransformRewriter,
+            _rewriter: transform.TransformRewriter,
             results: transform.TransformResults,
             state: transform.TransformState,
-        ):
+        ) -> DiagnosedSilenceableFailure:
             target_ops = state.get_payload_ops(op.target)
-            target_names = [t.opview.name.value for t in target_ops]
+            target_names = [t.name.value for t in target_ops]
             print(f"OneOpInOneOpOutTransformOpInterface: target_names={target_names}")
             results.set_ops(op.res, target_ops)
-            return transform.DiagnosedSilenceableFailure.Success
+            return DiagnosedSilenceableFailure.Success
 
         @staticmethod
-        def allow_repeated_handle_operands(op: OneOpInOneOpOut) -> bool:
+        def allow_repeated_handle_operands(_op: OneOpInOneOpOut) -> bool:
             return False
 
     # Attach the interface implementation to the op.
@@ -188,13 +190,72 @@ def allow_repeated_handle_operands(op: OneOpInOneOpOut) -> bool:
         # CHECK: OneOpInOneOpOutTransformOpInterface: target_names=['name_of_func']
         out = OneOpInOneOpOut(func_handle).result
         # CHECK: Output handle from OneOpInOneOpOut
-        # CHECK-NEXT: func.func @
+        # CHECK-NEXT: func.func @name_of_func
         transform.PrintOp(target=out, name="Output handle from OneOpInOneOpOut")
         transform.YieldOp([out])
 
     return schedule
 
 
+# CHECK-LABEL: Test: OneOpInOneOpOutTransformOpInterfaceRewriterImpl
+ at run
+def OneOpInOneOpOutTransformOpInterfaceRewriterImpl():
+    # Define an implementation of the TransformOpInterface for OneOpInOneOpOut where
+    # the rewriter is used (to replace arith.constants by index.constants).
+    class TransformOpInterfaceFallbackModel(transform.TransformOpInterface):
+        @staticmethod
+        def apply(
+            op: OneOpInOneOpOut,
+            rewriter: transform.TransformRewriter,
+            results: transform.TransformResults,
+            state: transform.TransformState,
+        ) -> DiagnosedSilenceableFailure:
+            result_ops = []
+            for target_op in state.get_payload_ops(op.target):
+                with ir.InsertionPoint(target_op):
+                    index_version = index.constant(target_op.value.value)
+                result_ops.append(index_version.owner)
+                rewriter.replace_op(target_op, [index_version])
+            results.set_ops(op.res, result_ops)
+            return DiagnosedSilenceableFailure.Success
+
+        @staticmethod
+        def allow_repeated_handle_operands(_op: OneOpInOneOpOut) -> bool:
+            return False
+
+    # Attach the interface implementation to the op.
+    TransformOpInterfaceFallbackModel.attach(OneOpInOneOpOut.OPERATION_NAME)
+
+    # TransformOpInterface-implementing ops are also required to implement MemoryEffectsOpInterface. The above defined fallback model works for this op.
+    MemoryEffectsOpInterfaceFallbackModel.attach(OneOpInOneOpOut.OPERATION_NAME)
+
+    with schedule_boilerplate() as (schedule, named_seq):
+        func_handle = structured.MatchOp.match_op_names(
+            named_seq.bodyTarget, ["func.func"]
+        ).result
+        csts_handle = structured.MatchOp.match_op_names(
+            named_seq.bodyTarget, ["arith.constant"]
+        ).result
+        # CHECK: Before replacement:
+        # CHECK-NOT: index.constant
+        # CHECK-DAG: arith.constant 42 : i32
+        # CHECK-DAG: arith.constant 24 : i32
+        transform.PrintOp(target=func_handle, name="Before replacement:")
+        out = OneOpInOneOpOut(csts_handle).result
+        # CHECK: After replacement:
+        # CHECK-NOT: arith.constant
+        # CHECK-DAG: index.constant 42
+        # CHECK-DAG: index.constant 24
+        transform.PrintOp(target=func_handle, name="After replacement:")
+        # CHECK: Output handle from OneOpInOneOpOut:
+        # CHECK-NEXT: index.constant 42
+        # CHECK-NEXT: index.constant 24
+        transform.PrintOp(target=out, name="Output handle from OneOpInOneOpOut:")
+        transform.YieldOp([out])
+
+    return schedule
+
+
 @ext.register_operation(MyTransform)
 class OpValParamInParamOpValOut(
     MyTransform.Operation, name="op_val_param_in_param_op_val_out"
@@ -216,10 +277,10 @@ class TransformOpInterfaceFallbackModel(transform.TransformOpInterface):
         @staticmethod
         def apply(
             op: OpValParamInParamOpValOut,
-            rewriter: transform.TransformRewriter,
+            _rewriter: transform.TransformRewriter,
             results: transform.TransformResults,
             state: transform.TransformState,
-        ):
+        ) -> DiagnosedSilenceableFailure:
             ops = state.get_payload_ops(op.op_arg)
             values = state.get_payload_values(op.val_arg)
             params = state.get_params(op.param_arg)
@@ -229,10 +290,10 @@ def apply(
             results.set_params(op.param_res, params)
             results.set_ops(op.op_res, ops)
             results.set_values(op.value_res, values)
-            return transform.DiagnosedSilenceableFailure.Success
+            return DiagnosedSilenceableFailure.Success
 
         @staticmethod
-        def allow_repeated_handle_operands(op: OpValParamInParamOpValOut) -> bool:
+        def allow_repeated_handle_operands(_op: OpValParamInParamOpValOut) -> bool:
             return False
 
     TransformOpInterfaceFallbackModel.attach(
@@ -315,10 +376,10 @@ class TransformOpInterfaceFallbackModel(transform.TransformOpInterface):
         @staticmethod
         def apply(
             op: OpsParamsInValuesParamOut,
-            rewriter: transform.TransformRewriter,
+            _rewriter: transform.TransformRewriter,
             results: transform.TransformResults,
             state: transform.TransformState,
-        ):
+        ) -> DiagnosedSilenceableFailure:
             ops_count = 0
             value_handles = []
             for op_handle in op.ops:
@@ -344,11 +405,11 @@ def apply(
                 op.param,
                 [ir.IntegerAttr.get(ir.IntegerType.get_signless(32), param_sum)],
             )
-            return transform.DiagnosedSilenceableFailure.Success
+            return DiagnosedSilenceableFailure.Success
 
         @staticmethod
-        def allow_repeated_handle_operands(op: OpsParamsInValuesParamOut) -> bool:
-            return True
+        def allow_repeated_handle_operands(_op: OpsParamsInValuesParamOut) -> bool:
+            return False
 
     TransformOpInterfaceFallbackModel.attach(OpsParamsInValuesParamOut.OPERATION_NAME)
 

>From 40f692331d9a5cc7e89e49288239c0763583ecb7 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Tue, 27 Jan 2026 10:58:38 -0800
Subject: [PATCH 6/9] Add docstrings to tests

---
 .../python/dialects/transform_op_interface.py | 25 +++++++++++++++++--
 1 file changed, 23 insertions(+), 2 deletions(-)

diff --git a/mlir/test/python/dialects/transform_op_interface.py b/mlir/test/python/dialects/transform_op_interface.py
index b2bc97ca3daeb..c707883866a03 100644
--- a/mlir/test/python/dialects/transform_op_interface.py
+++ b/mlir/test/python/dialects/transform_op_interface.py
@@ -158,6 +158,11 @@ class OneOpInOneOpOut(MyTransform.Operation, name="one_op_in_one_op_out"):
 # CHECK-LABEL: Test: OneOpInOneOpOutTransformOpInterface
 @run
 def OneOpInOneOpOutTransformOpInterface():
+    """Tests a simple passthrough interface implementation.
+
+    Checks that the target ops are correctly identified and passed as results.
+    """
+
     # Define a simple passthrough implementation of the TransformOpInterface for OneOpInOneOpOut.
     class TransformOpInterfaceFallbackModel(transform.TransformOpInterface):
         @staticmethod
@@ -200,8 +205,12 @@ def allow_repeated_handle_operands(_op: OneOpInOneOpOut) -> bool:
 # CHECK-LABEL: Test: OneOpInOneOpOutTransformOpInterfaceRewriterImpl
 @run
 def OneOpInOneOpOutTransformOpInterfaceRewriterImpl():
-    # Define an implementation of the TransformOpInterface for OneOpInOneOpOut where
-    # the rewriter is used (to replace arith.constants by index.constants).
+    """Tests an interface implementation using the rewriter to modify the IR.
+
+    Checks that `arith.constant` ops are replaced by `index.constant` ops and
+    that the results are correctly updated.
+    """
+
     class TransformOpInterfaceFallbackModel(transform.TransformOpInterface):
         @staticmethod
         def apply(
@@ -273,6 +282,12 @@ class OpValParamInParamOpValOut(
 # CHECK-LABEL: Test: OpValParamInParamOpValOutTransformOpInterface
 @run
 def OpValParamInParamOpValOutTransformOpInterface():
+    """Tests an interface implementation involving Op, Value, and Param types.
+
+    Checks that payload ops, values, and parameters are correctly permuted and
+    propagated and accessible from the (permuted) result handles.
+    """
+
     class TransformOpInterfaceFallbackModel(transform.TransformOpInterface):
         @staticmethod
         def apply(
@@ -372,6 +387,12 @@ class OpsParamsInValuesParamOut(
 # CHECK-LABEL: Test: OpsParamsInValuesParamOutTransformOpInterface
 @run
 def OpsParamsInValuesParamOutTransformOpInterface():
+    """Tests an interface with variadic Op and Param operands and variadic Value results.
+
+    Checks correct handling of multiple handles, parameter aggregation, and
+    result generation.
+    """
+
     class TransformOpInterfaceFallbackModel(transform.TransformOpInterface):
         @staticmethod
         def apply(

>From e5eed4f2df6028e1ca2ac91363f8c8869f0f17ad Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Tue, 27 Jan 2026 12:04:50 -0800
Subject: [PATCH 7/9] Formatting and minor load change

---
 mlir/lib/Bindings/Python/DialectTransform.cpp |  9 ++++++---
 mlir/python/mlir/dialects/ext.py              | 12 +++---------
 2 files changed, 9 insertions(+), 12 deletions(-)

diff --git a/mlir/lib/Bindings/Python/DialectTransform.cpp b/mlir/lib/Bindings/Python/DialectTransform.cpp
index 16a025a0581e5..ecb911fd7a3ba 100644
--- a/mlir/lib/Bindings/Python/DialectTransform.cpp
+++ b/mlir/lib/Bindings/Python/DialectTransform.cpp
@@ -382,7 +382,8 @@ struct ParamType : PyConcreteType<ParamType> {
 //===----------------------------------------------------------------------===//
 
 namespace {
-void onlyReadsHandle(nb::iterable &operands, PyMemoryEffectsInstanceList effects) {
+void onlyReadsHandle(nb::iterable &operands,
+                     PyMemoryEffectsInstanceList effects) {
   std::vector<MlirOpOperand> operandsVec;
   for (auto operand : operands)
     operandsVec.push_back(nb::cast<PyOpOperand>(operand));
@@ -390,7 +391,8 @@ void onlyReadsHandle(nb::iterable &operands, PyMemoryEffectsInstanceList effects
                                effects.effects);
 };
 
-void consumesHandle(nb::iterable &operands, PyMemoryEffectsInstanceList effects) {
+void consumesHandle(nb::iterable &operands,
+                    PyMemoryEffectsInstanceList effects) {
   std::vector<MlirOpOperand> operandsVec;
   for (auto operand : operands)
     operandsVec.push_back(nb::cast<PyOpOperand>(operand));
@@ -398,7 +400,8 @@ void consumesHandle(nb::iterable &operands, PyMemoryEffectsInstanceList effects)
                               effects.effects);
 };
 
-void producesHandle(nb::iterable &results, PyMemoryEffectsInstanceList effects) {
+void producesHandle(nb::iterable &results,
+                    PyMemoryEffectsInstanceList effects) {
   std::vector<MlirOpResult> resultsVec;
   for (auto result : results)
     resultsVec.push_back(nb::cast<PyOpResult>(result));
diff --git a/mlir/python/mlir/dialects/ext.py b/mlir/python/mlir/dialects/ext.py
index 3fff63dd0578b..f60dd62c5e6c8 100644
--- a/mlir/python/mlir/dialects/ext.py
+++ b/mlir/python/mlir/dialects/ext.py
@@ -432,7 +432,6 @@ class ExtOperation(Operation):
         def __init__(*args, **kwargs):
             raise RuntimeError("Cannot instantiate Dialect.ExtOperation directly.")
 
-
     @classmethod
     def __init_subclass__(cls, name: str, **kwargs):
         cls.name = name
@@ -460,14 +459,9 @@ def _emit_module(cls) -> ir.Module:
         return m
 
     @classmethod
-    def load(cls, register=True, context: Optional[ir.Context] = None) -> None:
-        context = context or ir.Context.current
-
-        try:
-            context.dialects[cls.name]
-            raise RuntimeError(f"Dialect {cls.name} is already loaded.")
-        except IndexError:
-            pass  # Dialect not loaded yet.
+    def load(cls, register=True) -> None:
+        if hasattr(cls, "_mlir_module"):
+            return
 
         cls._mlir_module = cls._emit_module()
         pm = PassManager()

>From 7372de4b797fc4db0a7190d90b60fc9c1b396635 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Tue, 27 Jan 2026 12:41:46 -0800
Subject: [PATCH 8/9] Another ext.py fix

---
 mlir/python/mlir/dialects/ext.py                    | 4 ++--
 mlir/test/python/dialects/transform_op_interface.py | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/python/mlir/dialects/ext.py b/mlir/python/mlir/dialects/ext.py
index f60dd62c5e6c8..45a7249399cca 100644
--- a/mlir/python/mlir/dialects/ext.py
+++ b/mlir/python/mlir/dialects/ext.py
@@ -459,8 +459,8 @@ def _emit_module(cls) -> ir.Module:
         return m
 
     @classmethod
-    def load(cls, register=True) -> None:
-        if hasattr(cls, "_mlir_module"):
+    def load(cls, register=True, reload=False) -> None:
+        if hasattr(cls, "_mlir_module") and not reload:
             return
 
         cls._mlir_module = cls._emit_module()
diff --git a/mlir/test/python/dialects/transform_op_interface.py b/mlir/test/python/dialects/transform_op_interface.py
index c707883866a03..7818ce919d2f3 100644
--- a/mlir/test/python/dialects/transform_op_interface.py
+++ b/mlir/test/python/dialects/transform_op_interface.py
@@ -26,7 +26,7 @@ def run(emit_schedule):
     with ir.Context() as ctx, ir.Location.unknown():
         payload = emit_payload()
 
-        MyTransform.load(register=False)
+        MyTransform.load(register=False, reload=True)
 
         GetNamedAttributeOp.attach_interface_impls(ctx)
         PrintParamOp.attach_interface_impls(ctx)

>From 97316be8b62015b7a5ff50c24007e8f5f90d1066 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Tue, 27 Jan 2026 12:55:00 -0800
Subject: [PATCH 9/9] Opview from memoryeffectsinterface

---
 mlir/lib/Bindings/Python/IRInterfaces.cpp | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Bindings/Python/IRInterfaces.cpp b/mlir/lib/Bindings/Python/IRInterfaces.cpp
index 05633746c3136..538ec3e62f231 100644
--- a/mlir/lib/Bindings/Python/IRInterfaces.cpp
+++ b/mlir/lib/Bindings/Python/IRInterfaces.cpp
@@ -373,8 +373,12 @@ class PyMemoryEffectsOpInterface
 
       PyMemoryEffectsInstanceList effectsWrapper{effects};
 
+      PyMlirContextRef context =
+          PyMlirContext::forContext(mlirOperationGetContext(op));
+      auto opview = PyOperation::forOperation(context, op)->createOpView();
+
       // Invoke `pyClass.get_effects(op, effects)`.
-      pyGetEffects(op, effectsWrapper);
+      pyGetEffects(opview, effectsWrapper);
     };
 
     mlirMemoryEffectsOpInterfaceAttachFallbackModel(



More information about the Mlir-commits mailing list