[Mlir-commits] [mlir] [MLIR][Python] Impl XOpInterface(s) from Python, with X=Transform and X=MemoryEffects (PR #176920)
Rolf Morel
llvmlistbot at llvm.org
Tue Jan 27 12:55:41 PST 2026
https://github.com/rolfmorel updated https://github.com/llvm/llvm-project/pull/176920
>From 256fff028969ba360b89edb98c1aea147fd9cea6 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Wed, 7 Jan 2026 12:05:58 -0800
Subject: [PATCH 1/9] [MLIR][Python] Expose TransformOpInterface and
MemoryEffectsOpInterface
Enables implementing these two interfaces from Python by defining the
relevant methods on a Python class. These methods serve as callbacks for a
new C++ FallbackModel class, which acts as a "trampoline" into Python whenever
the interface's methods are called from C++. As in the C++ codebase, these
FallbackModels are a late-binding mechanism: they can be attached to an
operation after its definition.
---
mlir/include/mlir-c/Dialect/Transform.h | 135 ++++++++++
mlir/include/mlir-c/IR.h | 8 +
mlir/include/mlir-c/Interfaces.h | 39 ++-
mlir/include/mlir/Bindings/Python/IRCore.h | 11 +-
mlir/include/mlir/CAPI/Dialect/Transform.h | 28 ++
mlir/include/mlir/CAPI/IR.h | 1 +
mlir/include/mlir/CAPI/Interfaces.h | 8 +
mlir/lib/Bindings/Python/DialectTransform.cpp | 250 ++++++++++++++++-
mlir/lib/Bindings/Python/IRCore.cpp | 53 +++-
mlir/lib/Bindings/Python/IRInterfaces.cpp | 181 ++++---------
mlir/lib/Bindings/Python/IRInterfaces.h | 156 +++++++++++
mlir/lib/Bindings/Python/Rewrite.cpp | 64 +----
mlir/lib/Bindings/Python/Rewrite.h | 71 ++++-
mlir/lib/CAPI/Dialect/Transform.cpp | 196 ++++++++++++++
mlir/lib/CAPI/IR/IR.cpp | 12 +
mlir/lib/CAPI/Interfaces/Interfaces.cpp | 74 +++++
mlir/lib/IR/OperationSupport.cpp | 2 +-
mlir/python/CMakeLists.txt | 2 +
mlir/python/mlir/_mlir_libs/__init__.py | 1 +
.../python/dialects/transform_op_interface.py | 252 ++++++++++++++++++
20 files changed, 1344 insertions(+), 200 deletions(-)
create mode 100644 mlir/include/mlir/CAPI/Dialect/Transform.h
create mode 100644 mlir/lib/Bindings/Python/IRInterfaces.h
create mode 100644 mlir/test/python/dialects/transform_op_interface.py
diff --git a/mlir/include/mlir-c/Dialect/Transform.h b/mlir/include/mlir-c/Dialect/Transform.h
index 911c9ef659a1e..38bd9756176a0 100644
--- a/mlir/include/mlir-c/Dialect/Transform.h
+++ b/mlir/include/mlir-c/Dialect/Transform.h
@@ -10,7 +10,9 @@
#ifndef MLIR_C_DIALECT_TRANSFORM_H
#define MLIR_C_DIALECT_TRANSFORM_H
+#include "mlir-c/Interfaces.h"
#include "mlir-c/IR.h"
+#include "mlir-c/Rewrite.h"
#include "mlir-c/Support.h"
#ifdef __cplusplus
@@ -19,6 +21,32 @@ extern "C" {
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Transform, transform);
+#define DEFINE_C_API_STRUCT(name, storage) \
+ struct name { \
+ storage *ptr; \
+ }; \
+ typedef struct name name
+
+DEFINE_C_API_STRUCT(MlirTransformResults, void);
+DEFINE_C_API_STRUCT(MlirTransformRewriter, void);
+DEFINE_C_API_STRUCT(MlirTransformState, void);
+
+#undef DEFINE_C_API_STRUCT
+
+//===---------------------------------------------------------------------===//
+// DiagnosedSilenceableFailure
+//===---------------------------------------------------------------------===//
+
+/// Enum representing the result of a transform operation.
+typedef enum {
+ /// The operation succeeded.
+ MlirDiagnosedSilenceableFailureSuccess,
+ /// The operation failed in a silenceable way.
+ MlirDiagnosedSilenceableFailureSilenceableFailure,
+ /// The operation failed definitively.
+ MlirDiagnosedSilenceableFailureDefiniteFailure
+} MlirDiagnosedSilenceableFailure;
+
//===---------------------------------------------------------------------===//
// AnyOpType
//===---------------------------------------------------------------------===//
@@ -86,6 +114,113 @@ MLIR_CAPI_EXPORTED MlirStringRef mlirTransformParamTypeGetName(void);
MLIR_CAPI_EXPORTED MlirType mlirTransformParamTypeGetType(MlirType type);
+//===---------------------------------------------------------------------===//
+// TransformRewriter
+//===---------------------------------------------------------------------===//
+
+/// Cast the TransformRewriter to a RewriterBase
+MLIR_CAPI_EXPORTED MlirRewriterBase
+mlirTransformRewriterAsBase(MlirTransformRewriter rewriter);
+
+//===---------------------------------------------------------------------===//
+// TransformResults
+//===---------------------------------------------------------------------===//
+
+/// Set the payload operations for a transform result by iterating over a list.
+MLIR_CAPI_EXPORTED void mlirTransformResultsSetOps(MlirTransformResults results,
+ MlirValue result,
+ intptr_t numOps,
+ MlirOperation *ops);
+
+/// Set the payload values for a transform result by iterating over a list.
+MLIR_CAPI_EXPORTED void
+mlirTransformResultsSetValues(MlirTransformResults results, MlirValue result,
+ intptr_t numValues, MlirValue *values);
+
+/// Set the parameters for a transform result by iterating over a list.
+MLIR_CAPI_EXPORTED void
+mlirTransformResultsSetParams(MlirTransformResults results, MlirValue result,
+ intptr_t numParams, MlirAttribute *params);
+
+//===---------------------------------------------------------------------===//
+// TransformState
+//===---------------------------------------------------------------------===//
+
+/// Callback for iterating over payload operations.
+typedef void (*MlirOperationCallback)(MlirOperation, void *userData);
+
+/// Iterate over payload operations associated with the transform IR value.
+/// Calls the callback for each payload operation.
+MLIR_CAPI_EXPORTED void
+mlirTransformStateForEachPayloadOp(MlirTransformState state, MlirValue value,
+ MlirOperationCallback callback,
+ void *userData);
+
+/// Callback for iterating over payload values.
+typedef void (*MlirValueCallback)(MlirValue, void *userData);
+
+/// Iterate over payload values associated with the transform IR value.
+/// Calls the callback for each payload value.
+MLIR_CAPI_EXPORTED void
+mlirTransformStateForEachPayloadValue(MlirTransformState state, MlirValue value,
+ MlirValueCallback callback,
+ void *userData);
+
+/// Callback for iterating over parameters.
+typedef void (*MlirAttributeCallback)(MlirAttribute, void *userData);
+
+/// Iterate over parameters associated with the transform IR value.
+/// Calls the callback for each parameter.
+MLIR_CAPI_EXPORTED void
+mlirTransformStateForEachParam(MlirTransformState state, MlirValue value,
+ MlirAttributeCallback callback, void *userData);
+
+//===---------------------------------------------------------------------===//
+// TransformOpInterface
+//===---------------------------------------------------------------------===//
+
+/// Returns the interface TypeID of the TransformOpInterface.
+MLIR_CAPI_EXPORTED MlirTypeID mlirTransformOpInterfaceTypeID(void);
+
+/// Callbacks for implementing TransformOpInterface from external code.
+typedef struct {
+ /// Optional constructor for the user data.
+ /// Set to nullptr to disable it.
+ void (*construct)(void *userData);
+ /// Optional destructor for the user data.
+ /// Set to nullptr to disable it.
+ void (*destruct)(void *userData);
+ /// Apply callback that implements the transformation.
+ MlirDiagnosedSilenceableFailure (*apply)(MlirOperation op,
+ MlirTransformRewriter rewriter,
+ MlirTransformResults results,
+ MlirTransformState state,
+ void *userData);
+ /// Callback to check if repeated handle operands are allowed.
+ bool (*allowsRepeatedHandleOperands)(MlirOperation op, void *userData);
+ void *userData;
+} MlirTransformOpInterfaceCallbacks;
+
+/// Attach TransformOpInterface to the operation with the given name using
+/// the provided callbacks.
+MLIR_CAPI_EXPORTED void mlirTransformOpInterfaceAttachFallbackModel(
+ MlirContext ctx, MlirStringRef opName,
+ MlirTransformOpInterfaceCallbacks callbacks);
+
+//===---------------------------------------------------------------------===//
+// Transform-specific MemoryEffectsOpInterface helpers
+//===---------------------------------------------------------------------===//
+
+/// Helper to mark `operands[0..numOperands)` as only reading handles.
+MLIR_CAPI_EXPORTED void
+mlirTransformOnlyReadsHandle(MlirOpOperand *operands, intptr_t numOperands,
+ MlirMemoryEffectInstancesList effects);
+
+/// Helper to mark `results[0..numResults)` as producing handles.
+MLIR_CAPI_EXPORTED void
+mlirTransformProducesHandle(MlirOpResult *results, intptr_t numResults,
+ MlirMemoryEffectInstancesList effects);
+
#ifdef __cplusplus
}
#endif
diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 80ff39c82a9ee..9b58267cf1c5a 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -55,6 +55,7 @@ DEFINE_C_API_STRUCT(MlirDialect, void);
DEFINE_C_API_STRUCT(MlirDialectRegistry, void);
DEFINE_C_API_STRUCT(MlirOperation, void);
DEFINE_C_API_STRUCT(MlirOpOperand, void);
+DEFINE_C_API_STRUCT(MlirOpResult, void);
DEFINE_C_API_STRUCT(MlirOpPrintingFlags, void);
DEFINE_C_API_STRUCT(MlirBlock, void);
DEFINE_C_API_STRUCT(MlirRegion, void);
@@ -673,6 +674,10 @@ MLIR_CAPI_EXPORTED intptr_t mlirOperationGetNumOperands(MlirOperation op);
MLIR_CAPI_EXPORTED MlirValue mlirOperationGetOperand(MlirOperation op,
intptr_t pos);
+/// Returns `pos`-th OpOperand of the operation.
+MLIR_CAPI_EXPORTED MlirOpOperand mlirOperationGetOpOperand(MlirOperation op,
+ intptr_t pos);
+
/// Sets the `pos`-th operand of the operation.
MLIR_CAPI_EXPORTED void mlirOperationSetOperand(MlirOperation op, intptr_t pos,
MlirValue newValue);
@@ -1044,6 +1049,9 @@ MLIR_CAPI_EXPORTED bool mlirValueIsABlockArgument(MlirValue value);
/// Returns 1 if the value is an operation result, 0 otherwise.
MLIR_CAPI_EXPORTED bool mlirValueIsAOpResult(MlirValue value);
+/// Cast the value to an OpResult. Asserts if the value is not an op result.
+MLIR_CAPI_EXPORTED MlirOpResult mlirValueToOpResult(MlirValue value);
+
/// Returns the block in which this value is defined as an argument. Asserts if
/// the value is not a block argument.
MLIR_CAPI_EXPORTED MlirBlock mlirBlockArgumentGetOwner(MlirValue value);
diff --git a/mlir/include/mlir-c/Interfaces.h b/mlir/include/mlir-c/Interfaces.h
index a5a3473eaef59..17a812dcd86a9 100644
--- a/mlir/include/mlir-c/Interfaces.h
+++ b/mlir/include/mlir-c/Interfaces.h
@@ -22,6 +22,16 @@
extern "C" {
#endif
+#define DEFINE_C_API_STRUCT(name, storage) \
+ struct name { \
+ storage *ptr; \
+ }; \
+ typedef struct name name
+
+DEFINE_C_API_STRUCT(MlirMemoryEffectInstancesList, void);
+
+#undef DEFINE_C_API_STRUCT
+
/// Returns `true` if the given operation implements an interface identified by
/// its TypeID.
MLIR_CAPI_EXPORTED bool
@@ -42,7 +52,7 @@ mlirOperationImplementsInterfaceStatic(MlirStringRef operationName,
//===----------------------------------------------------------------------===//
/// Returns the interface TypeID of the InferTypeOpInterface.
-MLIR_CAPI_EXPORTED MlirTypeID mlirInferTypeOpInterfaceTypeID();
+MLIR_CAPI_EXPORTED MlirTypeID mlirInferTypeOpInterfaceTypeID(void);
/// These callbacks are used to return multiple types from functions while
/// transferring ownership to the caller. The first argument is the number of
@@ -65,7 +75,7 @@ MLIR_CAPI_EXPORTED MlirLogicalResult mlirInferTypeOpInterfaceInferReturnTypes(
//===----------------------------------------------------------------------===//
/// Returns the interface TypeID of the InferShapedTypeOpInterface.
-MLIR_CAPI_EXPORTED MlirTypeID mlirInferShapedTypeOpInterfaceTypeID();
+MLIR_CAPI_EXPORTED MlirTypeID mlirInferShapedTypeOpInterfaceTypeID(void);
/// These callbacks are used to return multiple shaped type components from
/// functions while transferring ownership to the caller. The first argument is
@@ -87,6 +97,31 @@ mlirInferShapedTypeOpInterfaceInferReturnTypes(
void *properties, intptr_t nRegions, MlirRegion *regions,
MlirShapedTypeComponentsCallback callback, void *userData);
+//===---------------------------------------------------------------------===//
+// MemoryEffectsOpInterface
+//===---------------------------------------------------------------------===//
+
+/// Returns the interface TypeID of the MemoryEffectsOpInterface.
+MLIR_CAPI_EXPORTED MlirTypeID mlirMemoryEffectsOpInterfaceTypeID(void);
+
+/// Callbacks for implementing MemoryEffectsOpInterface from external code.
+typedef struct {
+ /// Optional constructor for user data. Set to nullptr to disable it.
+ void (*construct)(void *userData);
+ /// Optional destructor for user data. Set to nullptr to disable it.
+ void (*destruct)(void *userData);
+ /// Get memory effects callback.
+ void (*getEffects)(MlirOperation op, MlirMemoryEffectInstancesList effects,
+ void *userData);
+ void *userData;
+} MlirMemoryEffectsOpInterfaceCallbacks;
+
+/// Attach a new FallbackModel for the MemoryEffectsOpInterface to the named
+/// operation. The FallbackModel will call the provided callbacks.
+MLIR_CAPI_EXPORTED void mlirMemoryEffectsOpInterfaceAttachFallbackModel(
+ MlirContext ctx, MlirStringRef opName,
+ MlirMemoryEffectsOpInterfaceCallbacks callbacks);
+
#ifdef __cplusplus
}
#endif
diff --git a/mlir/include/mlir/Bindings/Python/IRCore.h b/mlir/include/mlir/Bindings/Python/IRCore.h
index f9fc34e82c972..2460202a31b8a 100644
--- a/mlir/include/mlir/Bindings/Python/IRCore.h
+++ b/mlir/include/mlir/Bindings/Python/IRCore.h
@@ -1496,6 +1496,7 @@ class MLIR_PYTHON_API_EXPORTED PyOperationList {
class MLIR_PYTHON_API_EXPORTED PyOpOperand {
public:
PyOpOperand(MlirOpOperand opOperand) : opOperand(opOperand) {}
+ operator MlirOpOperand() const { return opOperand; }
nanobind::typed<nanobind::object, PyOpView> getOwner() const;
@@ -1602,6 +1603,7 @@ class MLIR_PYTHON_API_EXPORTED PyOpResult : public PyConcreteValue<PyOpResult> {
static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult;
static constexpr const char *pyClassName = "OpResult";
using PyConcreteValue::PyConcreteValue;
+ operator MlirOpResult() { return mlirValueToOpResult(castFrom(*this)); }
static void bindDerived(ClassTy &c);
};
@@ -1838,13 +1840,20 @@ class MLIR_PYTHON_API_EXPORTED PyOpAttributeMap {
MLIR_PYTHON_API_EXPORTED MlirValue getUniqueResult(MlirOperation operation);
MLIR_PYTHON_API_EXPORTED void populateIRCore(nanobind::module_ &m);
MLIR_PYTHON_API_EXPORTED void populateRoot(nanobind::module_ &m);
+
+/// Helper for creating an @classmethod; steals PyClassMethod_New's new ref.
+template <class Func, typename... Args>
+static nanobind::object classmethod(Func f, Args... args) {
+ nanobind::object cf = nanobind::cpp_function(f, args...);
+ return nanobind::steal<nanobind::object>(PyClassMethod_New(cf.ptr()));
+}
+
} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
} // namespace python
} // namespace mlir
namespace nanobind {
namespace detail {
-
template <>
struct type_caster<
mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::DefaultingPyMlirContext>
diff --git a/mlir/include/mlir/CAPI/Dialect/Transform.h b/mlir/include/mlir/CAPI/Dialect/Transform.h
new file mode 100644
index 0000000000000..792236cd8601f
--- /dev/null
+++ b/mlir/include/mlir/CAPI/Dialect/Transform.h
@@ -0,0 +1,28 @@
+//===- Transform.h - C API Utils for Transform dialect ----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains declarations of implementation details of the C API for
+// the Transform dialect. This file should not be included from C++ code other
+// than C API implementation nor from C code.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CAPI_DIALECT_TRANSFORM_H
+#define MLIR_CAPI_DIALECT_TRANSFORM_H
+
+#include "mlir-c/Dialect/Transform.h"
+#include "mlir/CAPI/Wrap.h"
+#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
+
+DEFINE_C_API_PTR_METHODS(MlirTransformRewriter,
+ mlir::transform::TransformRewriter)
+DEFINE_C_API_PTR_METHODS(MlirTransformResults,
+ mlir::transform::TransformResults)
+DEFINE_C_API_PTR_METHODS(MlirTransformState, mlir::transform::TransformState)
+
+#endif // MLIR_CAPI_DIALECT_TRANSFORM_H
diff --git a/mlir/include/mlir/CAPI/IR.h b/mlir/include/mlir/CAPI/IR.h
index 1836cb0acb67e..f58ff423d23f7 100644
--- a/mlir/include/mlir/CAPI/IR.h
+++ b/mlir/include/mlir/CAPI/IR.h
@@ -29,6 +29,7 @@ DEFINE_C_API_PTR_METHODS(MlirDialectRegistry, mlir::DialectRegistry)
DEFINE_C_API_PTR_METHODS(MlirOperation, mlir::Operation)
DEFINE_C_API_PTR_METHODS(MlirBlock, mlir::Block)
DEFINE_C_API_PTR_METHODS(MlirOpOperand, mlir::OpOperand)
+DEFINE_C_API_PTR_METHODS(MlirOpResult, mlir::OpResult)
DEFINE_C_API_PTR_METHODS(MlirOpPrintingFlags, mlir::OpPrintingFlags)
DEFINE_C_API_PTR_METHODS(MlirRegion, mlir::Region)
DEFINE_C_API_PTR_METHODS(MlirSymbolTable, mlir::SymbolTable)
diff --git a/mlir/include/mlir/CAPI/Interfaces.h b/mlir/include/mlir/CAPI/Interfaces.h
index 4154b8c9ec6cc..15afc9fb0f18e 100644
--- a/mlir/include/mlir/CAPI/Interfaces.h
+++ b/mlir/include/mlir/CAPI/Interfaces.h
@@ -15,4 +15,12 @@
#ifndef MLIR_CAPI_INTERFACES_H
#define MLIR_CAPI_INTERFACES_H
+#include "mlir-c/Interfaces.h"
+#include "mlir/CAPI/Wrap.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+
+DEFINE_C_API_PTR_METHODS(
+ MlirMemoryEffectInstancesList,
+ llvm::SmallVectorImpl<mlir::MemoryEffects::EffectInstance>)
+
#endif // MLIR_CAPI_INTERFACES_H
diff --git a/mlir/lib/Bindings/Python/DialectTransform.cpp b/mlir/lib/Bindings/Python/DialectTransform.cpp
index 82905498921ce..7dbeb72a0d2a8 100644
--- a/mlir/lib/Bindings/Python/DialectTransform.cpp
+++ b/mlir/lib/Bindings/Python/DialectTransform.cpp
@@ -8,12 +8,14 @@
#include <string>
+#include "IRInterfaces.h"
+#include "Rewrite.h"
#include "mlir-c/Dialect/Transform.h"
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"
#include "mlir/Bindings/Python/IRCore.h"
-#include "mlir/Bindings/Python/Nanobind.h"
-#include "mlir/Bindings/Python/NanobindAdaptors.h"
+#include "nanobind/nanobind.h"
+#include <nanobind/trampoline.h>
namespace nb = nanobind;
using namespace mlir::python::nanobind_adaptors;
@@ -22,6 +24,208 @@ namespace mlir {
namespace python {
namespace MLIR_BINDINGS_PYTHON_DOMAIN {
namespace transform {
+
+//===----------------------------------------------------------------------===//
+// TransformRewriter
+//===----------------------------------------------------------------------===//
+class PyTransformRewriter : public PyRewriterBase<PyTransformRewriter> {
+public:
+ static constexpr const char *pyClassName = "TransformRewriter";
+
+ PyTransformRewriter(MlirTransformRewriter rewriter)
+ : PyRewriterBase(mlirTransformRewriterAsBase(rewriter)) {}
+};
+
+//===----------------------------------------------------------------------===//
+// TransformResults
+//===----------------------------------------------------------------------===//
+class PyTransformResults {
+public:
+ PyTransformResults(MlirTransformResults results) : results(results) {}
+
+ MlirTransformResults get() const { return results; }
+
+ /// Set the payload operations for the transform result `result`.
+ void setOps(MlirValue result, const nanobind::list &ops) {
+ std::vector<MlirOperation> opsVec;
+ opsVec.reserve(ops.size());
+ for (auto op : ops)
+ opsVec.push_back(nb::cast<MlirOperation>(op));
+ mlirTransformResultsSetOps(results, result, opsVec.size(), opsVec.data());
+ }
+
+ /// Set the payload values for the transform result `result`.
+ void setValues(MlirValue result, const nanobind::list &values) {
+ std::vector<MlirValue> valuesVec;
+ valuesVec.reserve(values.size());
+ for (auto item : values)
+ valuesVec.push_back(nb::cast<MlirValue>(item));
+ mlirTransformResultsSetValues(results, result, valuesVec.size(),
+ valuesVec.data());
+ }
+
+ /// Set the parameters (attributes) for the transform result `result`.
+ void setParams(MlirValue result, const nanobind::list &params) {
+ std::vector<MlirAttribute> paramsVec;
+ paramsVec.reserve(params.size());
+ for (auto item : params)
+ paramsVec.push_back(nb::cast<MlirAttribute>(item));
+ mlirTransformResultsSetParams(results, result, paramsVec.size(),
+ paramsVec.data());
+ }
+
+ static void bind(nanobind::module_ &m) {
+ nb::class_<PyTransformResults>(m, "TransformResults")
+ .def(nb::init<MlirTransformResults>())
+ .def("set_ops", &PyTransformResults::setOps,
+ "Set the payload operations for a transform result.",
+ nb::arg("result"), nb::arg("ops"))
+ .def("set_values", &PyTransformResults::setValues,
+ "Set the payload values for a transform result.",
+ nb::arg("result"), nb::arg("values"))
+ .def("set_params", &PyTransformResults::setParams,
+ "Set the parameters for a transform result.", nb::arg("result"),
+ nb::arg("params"));
+ }
+
+private:
+ MlirTransformResults results;
+};
+
+//===----------------------------------------------------------------------===//
+// TransformState
+//===----------------------------------------------------------------------===//
+class PyTransformState {
+public:
+ PyTransformState(MlirTransformState state) : state(state) {}
+
+ MlirTransformState get() const { return state; }
+
+ nanobind::list getPayloadOps(MlirValue value) {
+ nanobind::list result;
+ mlirTransformStateForEachPayloadOp(
+ state, value,
+ [](MlirOperation op, void *userData) {
+ static_cast<nanobind::list *>(userData)->append(op);
+ },
+ &result);
+ return result;
+ }
+
+ nanobind::list getPayloadValues(MlirValue value) {
+ nanobind::list result;
+ mlirTransformStateForEachPayloadValue(
+ state, value,
+ [](MlirValue val, void *userData) {
+ static_cast<nanobind::list *>(userData)->append(val);
+ },
+ &result);
+ return result;
+ }
+
+ nanobind::list getParams(MlirValue value) {
+ nanobind::list result;
+ mlirTransformStateForEachParam(
+ state, value,
+ [](MlirAttribute attr, void *userData) {
+ static_cast<nanobind::list *>(userData)->append(attr);
+ },
+ &result);
+ return result;
+ }
+
+ static void bind(nanobind::module_ &m) {
+ nb::class_<PyTransformState>(m, "TransformState")
+ .def(nb::init<MlirTransformState>())
+ .def("get_payload_ops", &PyTransformState::getPayloadOps,
+ "Get the payload operations associated with a transform IR value.",
+ nb::arg("operand"))
+ .def("get_payload_values", &PyTransformState::getPayloadValues,
+ "Get the payload values associated with a transform IR value.",
+ nb::arg("operand"))
+ .def("get_params", &PyTransformState::getParams,
+ "Get the parameters (attributes) associated with a transform IR "
+ "value.",
+ nb::arg("operand"));
+ }
+
+private:
+ MlirTransformState state;
+};
+
+//===----------------------------------------------------------------------===//
+// TransformOpInterface
+//===----------------------------------------------------------------------===//
+class PyTransformOpInterface
+ : public PyConcreteOpInterface<PyTransformOpInterface> {
+public:
+ using PyConcreteOpInterface<PyTransformOpInterface>::PyConcreteOpInterface;
+
+ constexpr static const char *pyClassName = "TransformOpInterface";
+ constexpr static GetTypeIDFunctionTy getInterfaceID =
+ &mlirTransformOpInterfaceTypeID;
+
+ /// Attach a new TransformOpInterface FallbackModel to the named operation.
+ /// The FallbackModel acts as a trampoline for callbacks on the Python class.
+ static void attach(nb::object &pyClass, const std::string &opName,
+ DefaultingPyMlirContext ctx) {
+ // Prepare the callbacks that will be used by the FallbackModel.
+ MlirTransformOpInterfaceCallbacks callbacks;
+ // Hand the Python class to the callbacks and hold a ref to keep it alive.
+ callbacks.userData = pyClass.ptr();
+ nb::handle(static_cast<PyObject *>(callbacks.userData)).inc_ref();
+
+ // The above ref bump is all we need as initialization, no need to run the
+ // construct callback.
+ callbacks.construct = nullptr;
+ // Upon the FallbackModel's destruction, drop the ref to the Python class.
+ callbacks.destruct = [](void *userData) {
+ nb::handle(static_cast<PyObject *>(userData)).dec_ref();
+ };
+ // The apply callback, which forwards to the Python class's `apply`.
+ callbacks.apply = [](MlirOperation op, MlirTransformRewriter rewriter,
+ MlirTransformResults results, MlirTransformState state,
+ void *userData) -> MlirDiagnosedSilenceableFailure {
+ nb::handle pyClass(static_cast<PyObject *>(userData));
+
+ auto pyApply = nb::cast<nb::callable>(nb::getattr(pyClass, "apply"));
+
+ auto pyRewriter = PyTransformRewriter(rewriter);
+ auto pyResults = PyTransformResults(results);
+ auto pyState = PyTransformState(state);
+
+ // Invoke `pyClass.apply(op, rewriter, results, state)` as a classmethod.
+ nb::object res = pyApply(op, pyRewriter, pyResults, pyState);
+ // NOTE(review): Python errors propagate as C++ exceptions; verify caller.
+ return nb::cast<MlirDiagnosedSilenceableFailure>(res);
+ };
+
+ // The allowsRepeatedHandleOperands callback, forwarding to the Python class.
+ callbacks.allowsRepeatedHandleOperands = [](MlirOperation op,
+ void *userData) -> bool {
+ nb::handle pyClass(static_cast<PyObject *>(userData));
+
+ auto pyAllowRepeatedHandleOperands = nb::cast<nb::callable>(
+ nb::getattr(pyClass, "allow_repeated_handle_operands"));
+
+ // Invoke `pyClass.allow_repeated_handle_operands(op)` as a classmethod.
+ nb::object res = pyAllowRepeatedHandleOperands(op);
+
+ return nb::cast<bool>(res);
+ };
+
+ // Attach a FallbackModel, which calls into Python, to the named operation.
+ mlirTransformOpInterfaceAttachFallbackModel(
+ ctx->get(), wrap(StringRef(opName.c_str())), callbacks);
+ }
+
+ static void bindDerived(ClassTy &transformOpInterfaceClass) {
+ transformOpInterfaceClass.attr("attach") =
+ classmethod(&PyTransformOpInterface::attach, nb::arg("cls"),
+ nb::arg("op_name"), nb::arg("ctx") = nb::none());
+ }
+};
+
//===-------------------------------------------------------------------===//
// AnyOpType
//===-------------------------------------------------------------------===//
@@ -162,12 +366,54 @@ struct ParamType : PyConcreteType<ParamType> {
}
};
+//===----------------------------------------------------------------------===//
+// MemoryEffectsOpInterface helpers
+//===----------------------------------------------------------------------===//
+
+namespace {
+// Python-list adapters for mlirTransform{OnlyReads,Produces}Handle.
+void onlyReadsHandle(nb::list &operands, PyMemoryEffectsInstanceList effects) {
+ std::vector<MlirOpOperand> operandsVec;
+ operandsVec.reserve(operands.size());
+ for (auto operand : operands)
+ operandsVec.push_back(nb::cast<PyOpOperand>(operand));
+ mlirTransformOnlyReadsHandle(operandsVec.data(), operandsVec.size(),
+ effects.effects);
+}
+void producesHandle(nb::list &results, PyMemoryEffectsInstanceList effects) {
+ std::vector<MlirOpResult> resultsVec;
+ resultsVec.reserve(results.size());
+ for (auto result : results)
+ resultsVec.push_back(nb::cast<PyOpResult>(result));
+ mlirTransformProducesHandle(resultsVec.data(), resultsVec.size(),
+ effects.effects);
+}
+} // namespace
+
static void populateDialectTransformSubmodule(nb::module_ &m) {
+ nb::enum_<MlirDiagnosedSilenceableFailure>(m, "DiagnosedSilenceableFailure")
+ .value("Success", MlirDiagnosedSilenceableFailureSuccess)
+ .value("SilenceableFailure",
+ MlirDiagnosedSilenceableFailureSilenceableFailure)
+ .value("DefiniteFailure", MlirDiagnosedSilenceableFailureDefiniteFailure);
+
AnyOpType::bind(m);
AnyParamType::bind(m);
AnyValueType::bind(m);
OperationType::bind(m);
ParamType::bind(m);
+
+ PyTransformRewriter::bind(m);
+ PyTransformResults::bind(m);
+ PyTransformState::bind(m);
+ PyTransformOpInterface::bind(m);
+
+ m.def("only_reads_handle", onlyReadsHandle,
+ "Mark operands as only reading handles.", nb::arg("operands"),
+ nb::arg("effects"));
+
+ m.def("produces_handle", producesHandle, "Mark results as producing handles.",
+ nb::arg("results"), nb::arg("effects"));
}
} // namespace transform
} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index eb00363a54034..f3108473c712a 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -10,7 +10,6 @@
#include "mlir/Bindings/Python/Globals.h"
#include "mlir/Bindings/Python/IRCore.h"
#include "mlir/Bindings/Python/NanobindUtils.h"
-#include "mlir/Bindings/Python/NanobindAdaptors.h"
#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
// clang-format on
#include "mlir-c/BuiltinAttributes.h"
@@ -52,13 +51,6 @@ operations.
// Utilities.
//------------------------------------------------------------------------------
-/// Helper for creating an @classmethod.
-template <class Func, typename... Args>
-static nb::object classmethod(Func f, Args... args) {
- nb::object cf = nb::cpp_function(f, args...);
- return nb::borrow<nb::object>((PyClassMethod_New(cf.ptr())));
-}
-
static nb::object
createCustomDialectWrapper(const std::string &dialectNamespace,
nb::object dialectDescriptor) {
@@ -2306,6 +2298,44 @@ PyOpOperandList PyOpOperandList::slice(intptr_t startIndex, intptr_t length,
return PyOpOperandList(operation, startIndex, length, step);
}
+/// A sliceable list of the OpOperands of an operation, i.e. PyOpOperand
+/// objects fetched by index via mlirOperationGetOpOperand (cheap random
+/// access). The list holds a reference to the operation whose operands these
+/// are, and thus extends the lifetime of that operation.
+class PyOpOpOperandList : public Sliceable<PyOpOpOperandList, PyOpOperand> {
+public:
+ static constexpr const char *pyClassName = "OpOpOperandList";
+ using SliceableT = Sliceable<PyOpOpOperandList, PyOpOperand>;
+
+ PyOpOpOperandList(PyOperationRef operation, intptr_t startIndex = 0,
+ intptr_t length = -1, intptr_t step = 1)
+ : Sliceable(startIndex,
+ length == -1 ? mlirOperationGetNumOperands(operation->get())
+ : length,
+ step),
+ operation(operation) {}
+
+private:
+ /// Give the parent CRTP class access to hook implementations below.
+ friend class Sliceable<PyOpOpOperandList, PyOpOperand>;
+
+ intptr_t getRawNumElements() {
+ operation->checkValid();
+ return mlirOperationGetNumOperands(operation->get());
+ }
+
+ PyOpOperand getRawElement(intptr_t pos) {
+ MlirOpOperand opOperand = mlirOperationGetOpOperand(operation->get(), pos);
+ return PyOpOperand(opOperand);
+ }
+
+ PyOpOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
+ return PyOpOpOperandList(operation, startIndex, length, step);
+ }
+
+ PyOperationRef operation;
+};
+
PyOpSuccessors::PyOpSuccessors(PyOperationRef operation, intptr_t startIndex,
intptr_t length, intptr_t step)
: Sliceable(startIndex,
@@ -3534,6 +3564,12 @@ void populateIRCore(nb::module_ &m) {
return PyOpOperandList(self.getOperation().getRef());
},
"Returns the list of operation operands.")
+ .def_prop_ro(
+ "op_operands",
+ [](PyOperationBase &self) {
+ return PyOpOpOperandList(self.getOperation().getRef());
+ },
+ "Returns the list of op operands.")
.def_prop_ro(
"regions",
[](PyOperationBase &self) {
@@ -4825,6 +4861,7 @@ void populateIRCore(nb::module_ &m) {
PyOpAttributeMap::bind(m);
PyOpOperandIterator::bind(m);
PyOpOperandList::bind(m);
+ PyOpOpOperandList::bind(m);
PyOpResultList::bind(m);
PyOpSuccessors::bind(m);
PyRegionIterator::bind(m);
diff --git a/mlir/lib/Bindings/Python/IRInterfaces.cpp b/mlir/lib/Bindings/Python/IRInterfaces.cpp
index 09112d4989ae4..05633746c3136 100644
--- a/mlir/lib/Bindings/Python/IRInterfaces.cpp
+++ b/mlir/lib/Bindings/Python/IRInterfaces.cpp
@@ -12,12 +12,12 @@
#include <utility>
#include <vector>
+#include "IRInterfaces.h"
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/IR.h"
#include "mlir-c/Interfaces.h"
#include "mlir-c/Support.h"
#include "mlir/Bindings/Python/IRCore.h"
-#include "mlir/Bindings/Python/Nanobind.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
@@ -26,18 +26,6 @@ namespace nb = nanobind;
namespace mlir {
namespace python {
namespace MLIR_BINDINGS_PYTHON_DOMAIN {
-constexpr static const char *constructorDoc =
- R"(Creates an interface from a given operation/opview object or from a
-subclass of OpView. Raises ValueError if the operation does not implement the
-interface.)";
-
-constexpr static const char *operationDoc =
- R"(Returns an Operation for which the interface was constructed.)";
-
-constexpr static const char *opviewDoc =
- R"(Returns an OpView subclass _instance_ for which the interface was
-constructed)";
-
constexpr static const char *inferReturnTypesDoc =
R"(Given the arguments required to build an operation, attempts to infer
its return types. Raises ValueError on failure.)";
@@ -124,119 +112,6 @@ wrapRegions(std::optional<std::vector<PyRegion>> regions) {
} // namespace
-/// CRTP base class for Python classes representing MLIR Op interfaces.
-/// Interface hierarchies are flat so no base class is expected here. The
-/// derived class is expected to define the following static fields:
-/// - `const char *pyClassName` - the name of the Python class to create;
-/// - `GetTypeIDFunctionTy getInterfaceID` - the function producing the TypeID
-/// of the interface.
-/// Derived classes may redefine the `bindDerived(ClassTy &)` method to bind
-/// interface-specific methods.
-///
-/// An interface class may be constructed from either an Operation/OpView object
-/// or from a subclass of OpView. In the latter case, only the static interface
-/// methods are available, similarly to calling ConcereteOp::staticMethod on the
-/// C++ side. Implementations of concrete interfaces can use the `isStatic`
-/// method to check whether the interface object was constructed from a class or
-/// an operation/opview instance. The `getOpName` always succeeds and returns a
-/// canonical name of the operation suitable for lookups.
-template <typename ConcreteIface>
-class PyConcreteOpInterface {
-protected:
- using ClassTy = nb::class_<ConcreteIface>;
- using GetTypeIDFunctionTy = MlirTypeID (*)();
-
-public:
- /// Constructs an interface instance from an object that is either an
- /// operation or a subclass of OpView. In the latter case, only the static
- /// methods of the interface are accessible to the caller.
- PyConcreteOpInterface(nb::object object, DefaultingPyMlirContext context)
- : obj(std::move(object)) {
- try {
- operation = &nb::cast<PyOperation &>(obj);
- } catch (nb::cast_error &) {
- // Do nothing.
- }
-
- try {
- operation = &nb::cast<PyOpView &>(obj).getOperation();
- } catch (nb::cast_error &) {
- // Do nothing.
- }
-
- if (operation != nullptr) {
- if (!mlirOperationImplementsInterface(*operation,
- ConcreteIface::getInterfaceID())) {
- std::string msg = "the operation does not implement ";
- throw nb::value_error((msg + ConcreteIface::pyClassName).c_str());
- }
-
- MlirIdentifier identifier = mlirOperationGetName(*operation);
- MlirStringRef stringRef = mlirIdentifierStr(identifier);
- opName = std::string(stringRef.data, stringRef.length);
- } else {
- try {
- opName = nb::cast<std::string>(obj.attr("OPERATION_NAME"));
- } catch (nb::cast_error &) {
- throw nb::type_error(
- "Op interface does not refer to an operation or OpView class");
- }
-
- if (!mlirOperationImplementsInterfaceStatic(
- mlirStringRefCreate(opName.data(), opName.length()),
- context.resolve().get(), ConcreteIface::getInterfaceID())) {
- std::string msg = "the operation does not implement ";
- throw nb::value_error((msg + ConcreteIface::pyClassName).c_str());
- }
- }
- }
-
- /// Creates the Python bindings for this class in the given module.
- static void bind(nb::module_ &m) {
- nb::class_<ConcreteIface> cls(m, ConcreteIface::pyClassName);
- cls.def(nb::init<nb::object, DefaultingPyMlirContext>(), nb::arg("object"),
- nb::arg("context") = nb::none(), constructorDoc)
- .def_prop_ro("operation", &PyConcreteOpInterface::getOperationObject,
- operationDoc)
- .def_prop_ro("opview", &PyConcreteOpInterface::getOpView, opviewDoc);
- ConcreteIface::bindDerived(cls);
- }
-
- /// Hook for derived classes to add class-specific bindings.
- static void bindDerived(ClassTy &cls) {}
-
- /// Returns `true` if this object was constructed from a subclass of OpView
- /// rather than from an operation instance.
- bool isStatic() { return operation == nullptr; }
-
- /// Returns the operation instance from which this object was constructed.
- /// Throws a type error if this object was constructed from a subclass of
- /// OpView.
- nb::typed<nb::object, PyOperation> getOperationObject() {
- if (operation == nullptr)
- throw nb::type_error("Cannot get an operation from a static interface");
- return operation->getRef().releaseObject();
- }
-
- /// Returns the opview of the operation instance from which this object was
- /// constructed. Throws a type error if this object was constructed form a
- /// subclass of OpView.
- nb::typed<nb::object, PyOpView> getOpView() {
- if (operation == nullptr)
- throw nb::type_error("Cannot get an opview from a static interface");
- return operation->createOpView();
- }
-
- /// Returns the canonical name of the operation this interface is constructed
- /// from.
- const std::string &getOpName() { return opName; }
-
-private:
- PyOperation *operation = nullptr;
- std::string opName;
- nb::object obj;
-};
-
/// Python wrapper for InferTypeOpInterface. This interface has only static
/// methods.
class PyInferTypeOpInterface
@@ -464,10 +339,62 @@ class PyInferShapedTypeOpInterface
}
};
+/// Wrapper around the MemoryEffectsOpInterface.
+///
+/// Besides the generic PyConcreteOpInterface constructor, this exposes an
+/// `attach` classmethod through which a Python subclass providing a
+/// `get_effects(op, effects)` method implements the interface for an
+/// already-registered operation.
+class PyMemoryEffectsOpInterface
+ : public PyConcreteOpInterface<PyMemoryEffectsOpInterface> {
+public:
+ using PyConcreteOpInterface<
+ PyMemoryEffectsOpInterface>::PyConcreteOpInterface;
+
+ constexpr static const char *pyClassName = "MemoryEffectsOpInterface";
+ constexpr static GetTypeIDFunctionTy getInterfaceID =
+ &mlirMemoryEffectsOpInterfaceTypeID;
+
+ /// Attach a new MemoryEffectsOpInterface FallbackModel to the named
+ /// operation. The FallbackModel acts as a trampoline for callbacks on the
+ /// Python class.
+ ///
+ /// `pySubclass` is the Python class `attach` is invoked on (bound as `cls`),
+ /// `opName` names the operation, and `ctx` is the context it is looked up in.
+ static void attach(nb::object &pySubclass, const std::string &opName,
+ DefaultingPyMlirContext ctx) {
+ MlirMemoryEffectsOpInterfaceCallbacks callbacks;
+ // The callbacks take a strong reference to the Python class; `destruct`
+ // below is what drops it again.
+ callbacks.userData = pySubclass.ptr();
+ nb::handle(static_cast<PyObject *>(callbacks.userData)).inc_ref();
+ callbacks.construct = nullptr;
+ callbacks.destruct = [](void *userData) {
+ nb::handle(static_cast<PyObject *>(userData)).dec_ref();
+ };
+ // Called from C++ whenever the op's memory effects are queried.
+ // NOTE(review): this touches Python objects without acquiring the GIL, so
+ // it presumably requires the C++ caller to already hold it — confirm.
+ callbacks.getEffects = [](MlirOperation op,
+ MlirMemoryEffectInstancesList effects,
+ void *userData) {
+ nb::handle pyClass(static_cast<PyObject *>(userData));
+
+ // Get the 'get_effects' method from the Python class. It is looked up on
+ // every call, so later redefinitions on the class are picked up.
+ auto pyGetEffects =
+ nb::cast<nb::callable>(nb::getattr(pyClass, "get_effects"));
+
+ PyMemoryEffectsInstanceList effectsWrapper{effects};
+
+ // Invoke `pyClass.get_effects(op, effects)`.
+ pyGetEffects(op, effectsWrapper);
+ };
+
+ mlirMemoryEffectsOpInterfaceAttachFallbackModel(
+ ctx->get(), wrap(StringRef(opName.c_str())), callbacks);
+ }
+
+ /// Exposes `attach` on the Python class as a classmethod taking `op_name`
+ /// and an optional `ctx`.
+ static void bindDerived(ClassTy &memoryEffectsOpInterfaceClass) {
+ memoryEffectsOpInterfaceClass.attr("attach") =
+ classmethod(&PyMemoryEffectsOpInterface::attach, nb::arg("cls"),
+ nb::arg("op_name"), nb::arg("ctx") = nb::none());
+ }
+};
+
void populateIRInterfaces(nb::module_ &m) {
+ // Opaque handle type passed to Python `get_effects` implementations; no
+ // methods are bound on it here.
+ nb::class_<PyMemoryEffectsInstanceList>(m, "MemoryEffectInstancesList");
+
+ PyInferShapedTypeOpInterface::bind(m);
PyInferTypeOpInterface::bind(m);
+ PyMemoryEffectsOpInterface::bind(m);
PyShapedTypeComponents::bind(m);
- PyInferShapedTypeOpInterface::bind(m);
}
} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
} // namespace python
diff --git a/mlir/lib/Bindings/Python/IRInterfaces.h b/mlir/lib/Bindings/Python/IRInterfaces.h
new file mode 100644
index 0000000000000..3374b505b5034
--- /dev/null
+++ b/mlir/lib/Bindings/Python/IRInterfaces.h
@@ -0,0 +1,156 @@
+//===- IRInterfaces.h - IR Interfaces for Python Bindings -------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_BINDINGS_PYTHON_IRINTERFACES_H
+#define MLIR_BINDINGS_PYTHON_IRINTERFACES_H
+
+#include "mlir-c/IR.h"
+#include "mlir-c/Interfaces.h"
+#include "mlir-c/Support.h"
+#include "mlir/Bindings/Python/IRCore.h"
+
+namespace nb = nanobind;
+
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
+
+// Python docstrings attached by PyConcreteOpInterface::bind to the
+// constructor and to the `operation` / `opview` properties of every interface.
+constexpr static const char *constructorDoc =
+ R"(Creates an interface from a given operation/opview object or from a
+subclass of OpView. Raises ValueError if the operation does not implement the
+interface.)";
+
+constexpr static const char *operationDoc =
+ R"(Returns an Operation for which the interface was constructed.)";
+
+constexpr static const char *opviewDoc =
+ R"(Returns an OpView subclass _instance_ for which the interface was
+constructed)";
+
+/// CRTP base class for Python classes representing MLIR Op interfaces.
+/// Interface hierarchies are flat so no base class is expected here. The
+/// derived class is expected to define the following static fields:
+/// - `const char *pyClassName` - the name of the Python class to create;
+/// - `GetTypeIDFunctionTy getInterfaceID` - the function producing the TypeID
+/// of the interface.
+/// Derived classes may redefine the `bindDerived(ClassTy &)` method to bind
+/// interface-specific methods.
+///
+/// An interface class may be constructed from either an Operation/OpView object
+/// or from a subclass of OpView. In the latter case, only the static interface
+/// methods are available, similarly to calling ConcreteOp::staticMethod on the
+/// C++ side. Implementations of concrete interfaces can use the `isStatic`
+/// method to check whether the interface object was constructed from a class or
+/// an operation/opview instance. The `getOpName` always succeeds and returns a
+/// canonical name of the operation suitable for lookups.
+template <typename ConcreteIface>
+class PyConcreteOpInterface {
+protected:
+ using ClassTy = nb::class_<ConcreteIface>;
+ using GetTypeIDFunctionTy = MlirTypeID (*)();
+
+public:
+ /// Constructs an interface instance from an object that is either an
+ /// operation or a subclass of OpView. In the latter case, only the static
+ /// methods of the interface are accessible to the caller.
+ /// Raises ValueError if the operation does not implement the interface and
+ /// TypeError if `object` is neither an operation nor an OpView class.
+ PyConcreteOpInterface(nb::object object, DefaultingPyMlirContext context)
+ : obj(std::move(object)) {
+ // Try a raw Operation first, then an OpView instance; a failed cast just
+ // leaves `operation` as it was.
+ try {
+ operation = &nb::cast<PyOperation &>(obj);
+ } catch (nb::cast_error &) {
+ // Do nothing.
+ }
+
+ try {
+ operation = &nb::cast<PyOpView &>(obj).getOperation();
+ } catch (nb::cast_error &) {
+ // Do nothing.
+ }
+
+ if (operation != nullptr) {
+ // Instance case: check the concrete operation and record its name.
+ if (!mlirOperationImplementsInterface(*operation,
+ ConcreteIface::getInterfaceID())) {
+ std::string msg = "the operation does not implement ";
+ throw nb::value_error((msg + ConcreteIface::pyClassName).c_str());
+ }
+
+ MlirIdentifier identifier = mlirOperationGetName(*operation);
+ MlirStringRef stringRef = mlirIdentifierStr(identifier);
+ opName = std::string(stringRef.data, stringRef.length);
+ } else {
+ // Static case: `obj` should be an OpView subclass carrying OPERATION_NAME.
+ // NOTE(review): if the attribute is missing, the getattr presumably raises
+ // nb::python_error (AttributeError) rather than nb::cast_error, so the
+ // TypeError below would not be produced — confirm.
+ try {
+ opName = nb::cast<std::string>(obj.attr("OPERATION_NAME"));
+ } catch (nb::cast_error &) {
+ throw nb::type_error(
+ "Op interface does not refer to an operation or OpView class");
+ }
+
+ if (!mlirOperationImplementsInterfaceStatic(
+ mlirStringRefCreate(opName.data(), opName.length()),
+ context.resolve().get(), ConcreteIface::getInterfaceID())) {
+ std::string msg = "the operation does not implement ";
+ throw nb::value_error((msg + ConcreteIface::pyClassName).c_str());
+ }
+ }
+ }
+
+ /// Creates the Python bindings for this class in the given module.
+ static void bind(nb::module_ &m) {
+ nb::class_<ConcreteIface> cls(m, ConcreteIface::pyClassName);
+ cls.def(nb::init<nb::object, DefaultingPyMlirContext>(), nb::arg("object"),
+ nb::arg("context") = nb::none(), constructorDoc)
+ .def_prop_ro("operation", &PyConcreteOpInterface::getOperationObject,
+ operationDoc)
+ .def_prop_ro("opview", &PyConcreteOpInterface::getOpView, opviewDoc);
+ ConcreteIface::bindDerived(cls);
+ }
+
+ /// Hook for derived classes to add class-specific bindings.
+ static void bindDerived(ClassTy &cls) {}
+
+ /// Returns `true` if this object was constructed from a subclass of OpView
+ /// rather than from an operation instance.
+ bool isStatic() { return operation == nullptr; }
+
+ /// Returns the operation instance from which this object was constructed.
+ /// Throws a type error if this object was constructed from a subclass of
+ /// OpView.
+ nb::typed<nb::object, PyOperation> getOperationObject() {
+ if (operation == nullptr)
+ throw nb::type_error("Cannot get an operation from a static interface");
+ return operation->getRef().releaseObject();
+ }
+
+ /// Returns the opview of the operation instance from which this object was
+ /// constructed. Throws a type error if this object was constructed from a
+ /// subclass of OpView.
+ nb::typed<nb::object, PyOpView> getOpView() {
+ if (operation == nullptr)
+ throw nb::type_error("Cannot get an opview from a static interface");
+ return operation->createOpView();
+ }
+
+ /// Returns the canonical name of the operation this interface is constructed
+ /// from.
+ const std::string &getOpName() { return opName; }
+
+private:
+ // Null when constructed from an OpView class (the "static" case).
+ PyOperation *operation = nullptr;
+ std::string opName;
+ // Holds a reference to the object this interface was constructed from.
+ nb::object obj;
+};
+
+/// Holder for a C API MemoryEffectInstancesList handle so that it can be
+/// passed to Python callbacks as a bound object.
+struct PyMemoryEffectsInstanceList {
+ MlirMemoryEffectInstancesList effects;
+};
+
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
+
+#endif // MLIR_BINDINGS_PYTHON_IRINTERFACES_H
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 2b649f79c5982..dcb95e81fc126 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -8,14 +8,11 @@
#include "Rewrite.h"
+#include "mlir-c/Bindings/Python/Interop.h"
#include "mlir-c/IR.h"
#include "mlir-c/Rewrite.h"
#include "mlir-c/Support.h"
#include "mlir/Bindings/Python/IRCore.h"
-// clang-format off
-#include "mlir/Bindings/Python/Nanobind.h"
-#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
-// clang-format on
#include "mlir/Config/mlir-config.h"
#include "nanobind/nanobind.h"
@@ -28,38 +25,12 @@ namespace mlir {
namespace python {
namespace MLIR_BINDINGS_PYTHON_DOMAIN {
-class PyPatternRewriter {
+/// Python wrapper for a C API pattern rewriter. The `ip`, `replace_op` and
+/// `erase_op` bindings come from the PyRewriterBase CRTP base.
+class PyPatternRewriter : public PyRewriterBase<PyPatternRewriter> {
public:
- PyPatternRewriter(MlirPatternRewriter rewriter)
- : base(mlirPatternRewriterAsBase(rewriter)),
- ctx(PyMlirContext::forContext(mlirRewriterBaseGetContext(base))) {}
-
- PyInsertionPoint getInsertionPoint() const {
- MlirBlock block = mlirRewriterBaseGetInsertionBlock(base);
- MlirOperation op = mlirRewriterBaseGetOperationAfterInsertion(base);
-
- if (mlirOperationIsNull(op)) {
- MlirOperation owner = mlirBlockGetParentOperation(block);
- auto parent = PyOperation::forOperation(ctx, owner);
- return PyInsertionPoint(PyBlock(parent, block));
- }
-
- return PyInsertionPoint(PyOperation::forOperation(ctx, op));
- }
-
- void replaceOp(MlirOperation op, MlirOperation newOp) {
- mlirRewriterBaseReplaceOpWithOperation(base, op, newOp);
- }
-
- void replaceOp(MlirOperation op, const std::vector<MlirValue> &values) {
- mlirRewriterBaseReplaceOpWithValues(base, op, values.size(), values.data());
- }
+ /// Name of the Python class registered by PyRewriterBase::bind.
+ static constexpr const char *pyClassName = "PatternRewriter";
- void eraseOp(const PyOperation &op) { mlirRewriterBaseEraseOp(base, op); }
-
-private:
- MlirRewriterBase base;
- PyMlirContextRef ctx;
+ /// Wraps `rewriter` through its RewriterBase handle.
+ PyPatternRewriter(MlirPatternRewriter rewriter)
+ : PyRewriterBase(mlirPatternRewriterAsBase(rewriter)) {}
};
struct PyMlirPDLResultList : MlirPDLResultList {};
@@ -340,29 +311,8 @@ void populateRewriteSubmodule(nb::module_ &m) {
//----------------------------------------------------------------------------
// Mapping of the PatternRewriter
//----------------------------------------------------------------------------
- nb::class_<PyPatternRewriter>(m, "PatternRewriter")
- .def_prop_ro("ip", &PyPatternRewriter::getInsertionPoint,
- "The current insertion point of the PatternRewriter.")
- .def(
- "replace_op",
- [](PyPatternRewriter &self, PyOperationBase &op,
- PyOperationBase &newOp) {
- self.replaceOp(op.getOperation(), newOp.getOperation());
- },
- "Replace an operation with a new operation.", nb::arg("op"),
- nb::arg("new_op"))
- .def(
- "replace_op",
- [](PyPatternRewriter &self, PyOperationBase &op,
- const std::vector<PyValue> &values) {
- std::vector<MlirValue> values_(values.size());
- std::copy(values.begin(), values.end(), values_.begin());
- self.replaceOp(op.getOperation(), values_);
- },
- "Replace an operation with a list of values.", nb::arg("op"),
- nb::arg("values"))
- .def("erase_op", &PyPatternRewriter::eraseOp, "Erase an operation.",
- nb::arg("op"));
+
+ PyPatternRewriter::bind(m);
//----------------------------------------------------------------------------
// Mapping of the RewritePatternSet
diff --git a/mlir/lib/Bindings/Python/Rewrite.h b/mlir/lib/Bindings/Python/Rewrite.h
index d287f19187708..9b4dfafb59b86 100644
--- a/mlir/lib/Bindings/Python/Rewrite.h
+++ b/mlir/lib/Bindings/Python/Rewrite.h
@@ -9,13 +9,80 @@
#ifndef MLIR_BINDINGS_PYTHON_REWRITE_H
#define MLIR_BINDINGS_PYTHON_REWRITE_H
-#include "mlir/Bindings/Python/NanobindUtils.h"
+#include "mlir-c/Rewrite.h"
+#include "mlir/Bindings/Python/IRCore.h"
+
+#include <nanobind/nanobind.h>
+
+namespace nb = nanobind;
namespace mlir {
namespace python {
namespace MLIR_BINDINGS_PYTHON_DOMAIN {
+
+/// CRTP base class for Python wrappers around MLIR rewriters.
+///
+/// `DerivedTy` must declare `static constexpr const char *pyClassName`, the
+/// name under which `bind` registers the Python class. The wrapper stores the
+/// C API rewriter handle together with the Python context it belongs to.
+template <typename DerivedTy>
+class PyRewriterBase {
+public:
+  PyRewriterBase(MlirRewriterBase rewriter)
+      : base(rewriter),
+        ctx(PyMlirContext::forContext(mlirRewriterBaseGetContext(base))) {}
+
+  /// Returns the rewriter's current insertion point. When no operation follows
+  /// the insertion point, it is expressed as the end of the insertion block.
+  PyInsertionPoint getInsertionPoint() const {
+    MlirBlock block = mlirRewriterBaseGetInsertionBlock(base);
+    MlirOperation op = mlirRewriterBaseGetOperationAfterInsertion(base);
+
+    if (mlirOperationIsNull(op)) {
+      MlirOperation owner = mlirBlockGetParentOperation(block);
+      auto parent = PyOperation::forOperation(ctx, owner);
+      return PyInsertionPoint(PyBlock(parent, block));
+    }
+
+    return PyInsertionPoint(PyOperation::forOperation(ctx, op));
+  }
+
+  /// Replaces `op` with the operation `newOp`.
+  void replaceOp(MlirOperation op, MlirOperation newOp) {
+    mlirRewriterBaseReplaceOpWithOperation(base, op, newOp);
+  }
+
+  /// Replaces `op` with the given list of values.
+  void replaceOp(MlirOperation op, const std::vector<MlirValue> &values) {
+    mlirRewriterBaseReplaceOpWithValues(base, op, values.size(), values.data());
+  }
+
+  /// Erases `op`.
+  void eraseOp(MlirOperation op) { mlirRewriterBaseEraseOp(base, op); }
+
+  /// Registers the Python class `DerivedTy::pyClassName` in module `m`.
+  ///
+  /// Every method is bound through a lambda whose `self` is `DerivedTy &`.
+  /// Binding member pointers declared on this CRTP base would make `self` the
+  /// base type, which is never registered with nanobind. `erase_op` accepts
+  /// any Operation/OpView, matching `replace_op`.
+  static void bind(nanobind::module_ &m) {
+    nb::class_<DerivedTy>(m, DerivedTy::pyClassName)
+        .def_prop_ro(
+            "ip", [](DerivedTy &self) { return self.getInsertionPoint(); },
+            "The current insertion point of the rewriter.")
+        .def(
+            "replace_op",
+            [](DerivedTy &self, PyOperationBase &op, PyOperationBase &newOp) {
+              self.replaceOp(op.getOperation(), newOp.getOperation());
+            },
+            "Replace an operation with a new operation.", nb::arg("op"),
+            nb::arg("new_op"))
+        .def(
+            "replace_op",
+            [](DerivedTy &self, PyOperationBase &op,
+               const std::vector<PyValue> &values) {
+              std::vector<MlirValue> values_(values.size());
+              std::copy(values.begin(), values.end(), values_.begin());
+              self.replaceOp(op.getOperation(), values_);
+            },
+            "Replace an operation with a list of values.", nb::arg("op"),
+            nb::arg("values"))
+        .def(
+            "erase_op",
+            [](DerivedTy &self, PyOperationBase &op) {
+              self.eraseOp(op.getOperation());
+            },
+            "Erase an operation.", nb::arg("op"));
+  }
+
+private:
+  MlirRewriterBase base;
+  PyMlirContextRef ctx;
+};
+
void populateRewriteSubmodule(nanobind::module_ &m);
-}
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
} // namespace python
} // namespace mlir
diff --git a/mlir/lib/CAPI/Dialect/Transform.cpp b/mlir/lib/CAPI/Dialect/Transform.cpp
index 18d4c075dbb9c..fdb27ec5cdc89 100644
--- a/mlir/lib/CAPI/Dialect/Transform.cpp
+++ b/mlir/lib/CAPI/Dialect/Transform.cpp
@@ -8,9 +8,13 @@
#include "mlir-c/Dialect/Transform.h"
#include "mlir-c/Support.h"
+#include "mlir/CAPI/Dialect/Transform.h"
+#include "mlir/CAPI/Interfaces.h"
#include "mlir/CAPI/Registration.h"
+#include "mlir/CAPI/Rewrite.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
+#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
using namespace mlir;
@@ -126,3 +130,195 @@ MlirStringRef mlirTransformParamTypeGetName(void) {
MlirType mlirTransformParamTypeGetType(MlirType type) {
return wrap(cast<transform::ParamType>(unwrap(type)).getType());
}
+
+//===---------------------------------------------------------------------===//
+// TransformRewriter
+//===---------------------------------------------------------------------===//
+
+/// Returns the RewriterBase view of a transform rewriter handle.
+MlirRewriterBase mlirTransformRewriterAsBase(MlirTransformRewriter rewriter) {
+  // Plain pointer upcast from TransformRewriter to its RewriterBase base.
+  return wrap(static_cast<mlir::RewriterBase *>(unwrap(rewriter)));
+}
+
+//===---------------------------------------------------------------------===//
+// TransformResults
+//===---------------------------------------------------------------------===//
+
+/// Associates the payload operations `ops[0..numOps)` with the transform
+/// handle `result` (which must be an OpResult).
+void mlirTransformResultsSetOps(MlirTransformResults results, MlirValue result,
+                                intptr_t numOps, MlirOperation *ops) {
+  SmallVector<Operation *> payloadOps;
+  payloadOps.reserve(numOps);
+  for (MlirOperation cOp : ArrayRef<MlirOperation>(ops, numOps))
+    payloadOps.push_back(unwrap(cOp));
+  OpResult handle = cast<OpResult>(unwrap(result));
+  unwrap(results)->set(handle, payloadOps);
+}
+
+/// Associates the payload values `values[0..numValues)` with the transform
+/// handle `result` (which must be an OpResult).
+void mlirTransformResultsSetValues(MlirTransformResults results,
+                                   MlirValue result, intptr_t numValues,
+                                   MlirValue *values) {
+  SmallVector<Value> payloadValues;
+  payloadValues.reserve(numValues);
+  for (MlirValue cValue : ArrayRef<MlirValue>(values, numValues))
+    payloadValues.push_back(unwrap(cValue));
+  OpResult handle = cast<OpResult>(unwrap(result));
+  unwrap(results)->setValues(handle, payloadValues);
+}
+
+/// Associates the parameter attributes `params[0..numParams)` with the
+/// transform handle `result` (which must be an OpResult).
+void mlirTransformResultsSetParams(MlirTransformResults results,
+                                   MlirValue result, intptr_t numParams,
+                                   MlirAttribute *params) {
+  SmallVector<Attribute> paramAttrs;
+  paramAttrs.reserve(numParams);
+  for (MlirAttribute cAttr : ArrayRef<MlirAttribute>(params, numParams))
+    paramAttrs.push_back(unwrap(cAttr));
+  OpResult handle = cast<OpResult>(unwrap(result));
+  unwrap(results)->setParams(handle, paramAttrs);
+}
+
+//===---------------------------------------------------------------------===//
+// TransformState
+//===---------------------------------------------------------------------===//
+
+/// Invokes `callback(op, userData)` for every payload operation mapped to the
+/// transform handle `value`.
+void mlirTransformStateForEachPayloadOp(MlirTransformState state,
+                                        MlirValue value,
+                                        MlirOperationCallback callback,
+                                        void *userData) {
+  auto payloadOps = unwrap(state)->getPayloadOps(unwrap(value));
+  for (Operation *payloadOp : payloadOps)
+    callback(wrap(payloadOp), userData);
+}
+
+/// Invokes `callback(value, userData)` for every payload value mapped to the
+/// transform handle `value`.
+void mlirTransformStateForEachPayloadValue(MlirTransformState state,
+                                           MlirValue value,
+                                           MlirValueCallback callback,
+                                           void *userData) {
+  auto payloadValues = unwrap(state)->getPayloadValues(unwrap(value));
+  for (Value payloadValue : payloadValues)
+    callback(wrap(payloadValue), userData);
+}
+
+/// Invokes `callback(attr, userData)` for every parameter attribute mapped to
+/// the transform handle `value`.
+void mlirTransformStateForEachParam(MlirTransformState state, MlirValue value,
+                                    MlirAttributeCallback callback,
+                                    void *userData) {
+  auto paramAttrs = unwrap(state)->getParams(unwrap(value));
+  for (Attribute paramAttr : paramAttrs)
+    callback(wrap(paramAttr), userData);
+}
+
+//===---------------------------------------------------------------------===//
+// TransformOpInterface
+//===---------------------------------------------------------------------===//
+
+MlirTypeID mlirTransformOpInterfaceTypeID(void) {
+  // TypeID identifying TransformOpInterface in interface lookups.
+  TypeID interfaceID = transform::TransformOpInterface::getInterfaceID();
+  return wrap(interfaceID);
+}
+
+/// Fallback model for the TransformOpInterface that uses C API callbacks.
+class TransformOpInterfaceFallbackModel
+    : public mlir::transform::TransformOpInterface::FallbackModel<
+          TransformOpInterfaceFallbackModel> {
+public:
+  /// Sets the callbacks that this FallbackModel will use. Any user data held
+  /// by previously installed callbacks is released first.
+  /// NB: the callbacks can only be set through this method as the
+  /// RegisteredOperationName::attachInterface mechanism default-constructs
+  /// the FallbackModel without being able to provide arguments. Repeated
+  /// attaches to the same op presumably hand back the already-attached model
+  /// (NOTE(review): confirm), so earlier user data must not be leaked.
+  void setCallbacks(MlirTransformOpInterfaceCallbacks newCallbacks) {
+    if (callbacks.destruct)
+      callbacks.destruct(callbacks.userData);
+    callbacks = newCallbacks;
+  }
+
+  /// Releases the user data through the `destruct` callback, if any.
+  /// NOTE(review): interface models may be freed without running their
+  /// destructor — confirm whether this is ever invoked.
+  ~TransformOpInterfaceFallbackModel() {
+    if (callbacks.destruct)
+      callbacks.destruct(callbacks.userData);
+  }
+
+  static TypeID getInterfaceID() {
+    return transform::TransformOpInterface::getInterfaceID();
+  }
+
+  static bool classof(const mlir::transform::detail::
+                          TransformOpInterfaceInterfaceTraits::Concept *op) {
+    // Enable casting back to the FallbackModel from the Interface. This is
+    // necessary as attachInterface(...) default-constructs the FallbackModel
+    // without being able to pass in the callbacks and returns just the Concept.
+    return true;
+  }
+
+  /// Forwards to the `apply` callback and maps its C status to a
+  /// DiagnosedSilenceableFailure.
+  ::mlir::DiagnosedSilenceableFailure
+  apply(Operation *op, ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::transform::TransformResults &transformResults,
+        ::mlir::transform::TransformState &state) const {
+    assert(callbacks.apply && "apply callback not set");
+
+    MlirDiagnosedSilenceableFailure status =
+        callbacks.apply(wrap(op), wrap(&rewriter), wrap(&transformResults),
+                        wrap(&state), callbacks.userData);
+
+    switch (status) {
+    case MlirDiagnosedSilenceableFailureSuccess:
+      return DiagnosedSilenceableFailure::success();
+    case MlirDiagnosedSilenceableFailureSilenceableFailure:
+      // Build the diagnostic without emitting it: a silenceable failure may
+      // still be suppressed by the caller. Moving the diagnostic out of a
+      // temporary `op->emitError()` would instead let that InFlightDiagnostic
+      // report the moved-from (empty) error eagerly on destruction.
+      // TODO: enable passing diagnostic info from C API to C++ API.
+      return mlir::emitSilenceableFailure(
+          op->getLoc(), "TransformOpInterfaceFallbackModel: silenceable failure");
+    case MlirDiagnosedSilenceableFailureDefiniteFailure:
+      return DiagnosedSilenceableFailure::definiteFailure();
+    }
+    llvm_unreachable("unknown transform status");
+  }
+
+  /// Forwards to the `allowsRepeatedHandleOperands` callback.
+  bool allowsRepeatedHandleOperands(Operation *op) const {
+    assert(callbacks.allowsRepeatedHandleOperands &&
+           "allowsRepeatedHandleOperands callback not set");
+    return callbacks.allowsRepeatedHandleOperands(wrap(op), callbacks.userData);
+  }
+
+private:
+  // Value-initialized so that `destruct` stays null until callbacks are set.
+  MlirTransformOpInterfaceCallbacks callbacks{};
+};
+
+/// Attach a TransformOpInterface FallbackModel to the given named operation.
+/// The FallbackModel uses the provided callbacks to implement the interface.
+/// Ownership of `callbacks.userData` is transferred. It is either handed to the
+/// attached model or, if `opName` is not registered in `ctx`, released
+/// immediately through `callbacks.destruct`.
+/// NOTE(review): the `cast` below relies on the model's always-true `classof`;
+/// if the op already carried a different TransformOpInterface model, that
+/// model could be reinterpreted as a FallbackModel — confirm callers only
+/// attach to ops that do not natively implement the interface.
+void mlirTransformOpInterfaceAttachFallbackModel(
+    MlirContext ctx, MlirStringRef opName,
+    MlirTransformOpInterfaceCallbacks callbacks) {
+  // Look up the operation definition in the context.
+  std::optional<RegisteredOperationName> opInfo =
+      RegisteredOperationName::lookup(unwrap(opName), unwrap(ctx));
+
+  if (!opInfo.has_value()) {
+    llvm::errs() << "Operation '" << unwrap(opName)
+                 << "' not found in context\n";
+    // No model takes over the user data, so release it here rather than leak.
+    if (callbacks.destruct)
+      callbacks.destruct(callbacks.userData);
+    return;
+  }
+
+  // NB: the following default-constructs the FallbackModel _without_ being able
+  // to provide arguments.
+  opInfo->attachInterface<TransformOpInterfaceFallbackModel>();
+  // Cast to get the underlying FallbackModel and set the callbacks.
+  auto *model = cast<TransformOpInterfaceFallbackModel>(
+      opInfo->getInterface<TransformOpInterfaceFallbackModel>());
+
+  assert(model && "Failed to get TransformOpInterfaceFallbackModel");
+  model->setCallbacks(callbacks);
+}
+
+//===---------------------------------------------------------------------===//
+// MemoryEffectsOpInterface helpers
+//===---------------------------------------------------------------------===//
+
+/// Set the effect for the operands to only read the transform handles.
+///
+/// Each entry of `operands` is handled on its own, so the operands need not be
+/// contiguous in memory (or belong to the same operation), and an empty array
+/// is never dereferenced. Previously the first element's address was
+/// reinterpreted as the start of a `numOperands`-long OpOperand array.
+void mlirTransformOnlyReadsHandle(MlirOpOperand *operands, intptr_t numOperands,
+                                  MlirMemoryEffectInstancesList effects) {
+  for (intptr_t i = 0; i < numOperands; ++i)
+    transform::onlyReadsHandle(MutableArrayRef<OpOperand>(*unwrap(operands[i])),
+                               *unwrap(effects));
+}
+
+/// Set the effect for the results so that they produce transform handles.
+void mlirTransformProducesHandle(MlirOpResult *results, intptr_t numResults,
+                                 MlirMemoryEffectInstancesList effects) {
+  // `producesHandle` expects a ResultRange, but an array of individual
+  // OpResults cannot be viewed as one (and ResultRange is not exposed to
+  // Python). Each result is therefore wrapped through ResultRange's
+  // single-OpResult constructor; `producesHandle` iterates the range anyway.
+  for (MlirOpResult cResult : ArrayRef<MlirOpResult>(results, numResults))
+    transform::producesHandle(ResultRange(*unwrap(cResult)), *unwrap(effects));
+}
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index 09666932004a4..ccfd30e2354c5 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -30,6 +30,7 @@
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Parser/Parser.h"
#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/ThreadPool.h"
#include <cstddef>
@@ -714,6 +715,10 @@ MlirValue mlirOperationGetOperand(MlirOperation op, intptr_t pos) {
return wrap(unwrap(op)->getOperand(static_cast<unsigned>(pos)));
}
+MlirOpOperand mlirOperationGetOpOperand(MlirOperation op, intptr_t pos) {
+  // getOpOperand returns a reference into the operation's operand list; its
+  // address is what the C handle wraps.
+  Operation *operation = unwrap(op);
+  OpOperand &operand = operation->getOpOperand(static_cast<unsigned>(pos));
+  return wrap(&operand);
+}
+
void mlirOperationSetOperand(MlirOperation op, intptr_t pos,
MlirValue newValue) {
unwrap(op)->setOperand(static_cast<unsigned>(pos), unwrap(newValue));
@@ -1121,6 +1126,13 @@ bool mlirValueIsAOpResult(MlirValue value) {
return llvm::isa<OpResult>(unwrap(value));
}
+/// Cast an MlirValue to an MlirOpResult, asserting in case of a type mismatch.
+/// FIXME(review): `wrap(&opResult)` takes the address of the lambda's by-value
+/// `OpResult` parameter, whose lifetime ends when the lambda returns, so the
+/// returned handle appears to dangle; dereferencing it via `unwrap` reads an
+/// object that no longer exists. `OpResult` is a value type with no stable
+/// per-result object to point at, so a fix likely needs a different
+/// representation for MlirOpResult — confirm against its definition.
+MlirOpResult mlirValueToOpResult(MlirValue value) {
+ return TypeSwitch<Value, MlirOpResult>(unwrap(value))
+ .Case<OpResult>([&](OpResult opResult) { return wrap(&opResult); })
+ .DefaultUnreachable("expected an OpResult");
+}
+
MlirBlock mlirBlockArgumentGetOwner(MlirValue value) {
return wrap(llvm::dyn_cast<BlockArgument>(unwrap(value)).getOwner());
}
diff --git a/mlir/lib/CAPI/Interfaces/Interfaces.cpp b/mlir/lib/CAPI/Interfaces/Interfaces.cpp
index ef3fc23869550..50f007f5d15f6 100644
--- a/mlir/lib/CAPI/Interfaces/Interfaces.cpp
+++ b/mlir/lib/CAPI/Interfaces/Interfaces.cpp
@@ -167,3 +167,77 @@ MlirLogicalResult mlirInferShapedTypeOpInterfaceInferReturnTypes(
}
return mlirLogicalResultSuccess();
}
+
+//===---------------------------------------------------------------------===//
+// MemoryEffectOpInterface
+//===---------------------------------------------------------------------===//
+
+MlirTypeID mlirMemoryEffectsOpInterfaceTypeID() {
+  // TypeID identifying MemoryEffectOpInterface in interface lookups.
+  TypeID interfaceID = MemoryEffectOpInterface::getInterfaceID();
+  return wrap(interfaceID);
+}
+
+/// Fallback model for the MemoryEffectsOpInterface that uses C API callbacks.
+class MemoryEffectOpInterfaceFallbackModel
+    : public mlir::MemoryEffectOpInterface::FallbackModel<
+          MemoryEffectOpInterfaceFallbackModel> {
+public:
+  /// Sets the callbacks that this FallbackModel will use. Any user data held
+  /// by previously installed callbacks is released first.
+  /// NB: the callbacks can only be set through this method as the
+  /// RegisteredOperationName::attachInterface mechanism default-constructs
+  /// the FallbackModel without being able to provide arguments. Repeated
+  /// attaches to the same op presumably hand back the already-attached model
+  /// (NOTE(review): confirm), so earlier user data must not be leaked.
+  void setCallbacks(MlirMemoryEffectsOpInterfaceCallbacks newCallbacks) {
+    if (callbacks.destruct)
+      callbacks.destruct(callbacks.userData);
+    callbacks = newCallbacks;
+  }
+
+  /// Releases the user data through the `destruct` callback, if any.
+  /// NOTE(review): interface models may be freed without running their
+  /// destructor — confirm whether this is ever invoked.
+  ~MemoryEffectOpInterfaceFallbackModel() {
+    if (callbacks.destruct)
+      callbacks.destruct(callbacks.userData);
+  }
+
+  static TypeID getInterfaceID() {
+    return MemoryEffectOpInterface::getInterfaceID();
+  }
+
+  static bool classof(const mlir::MemoryEffectOpInterface::Concept *op) {
+    // Enable casting back to the FallbackModel from the Interface. This is
+    // necessary as attachInterface(...) default-constructs the FallbackModel
+    // without being able to pass in the callbacks and returns just the Concept.
+    return true;
+  }
+
+  /// Forwards to the `getEffects` callback, which appends to `effects`
+  /// through the wrapped C handle.
+  void
+  getEffects(Operation *op,
+             SmallVectorImpl<MemoryEffects::EffectInstance> &effects) const {
+    assert(callbacks.getEffects && "getEffects callback not set");
+    MlirMemoryEffectInstancesList cEffects = wrap(&effects);
+    callbacks.getEffects(wrap(op), cEffects, callbacks.userData);
+  }
+
+private:
+  // Value-initialized so that `destruct` stays null until callbacks are set.
+  MlirMemoryEffectsOpInterfaceCallbacks callbacks{};
+};
+
+/// Attach a MemoryEffectsOpInterface FallbackModel to the given named op.
+/// The FallbackModel uses the provided callbacks to implement the interface.
+/// Ownership of `callbacks.userData` is transferred. It is either handed to the
+/// attached model or, if `opName` is not registered in `ctx`, released
+/// immediately through `callbacks.destruct`.
+/// NOTE(review): the `cast` below relies on the model's always-true `classof`;
+/// if the op already carried a native MemoryEffectOpInterface model, that
+/// model could be reinterpreted as a FallbackModel — confirm callers only
+/// attach to ops that do not natively implement the interface.
+void mlirMemoryEffectsOpInterfaceAttachFallbackModel(
+    MlirContext ctx, MlirStringRef opName,
+    MlirMemoryEffectsOpInterfaceCallbacks callbacks) {
+  // Look up the operation definition in the context
+  std::optional<RegisteredOperationName> opInfo =
+      RegisteredOperationName::lookup(unwrap(opName), unwrap(ctx));
+
+  if (!opInfo.has_value()) {
+    llvm::errs() << "Operation '" << unwrap(opName)
+                 << "' not found in context\n";
+    // No model takes over the user data, so release it here rather than leak.
+    if (callbacks.destruct)
+      callbacks.destruct(callbacks.userData);
+    return;
+  }
+
+  // NB: the following default-constructs the FallbackModel _without_ being able
+  // to provide arguments.
+  opInfo->attachInterface<MemoryEffectOpInterfaceFallbackModel>();
+  // Cast to get the underlying FallbackModel and set the callbacks.
+  auto *model = cast<MemoryEffectOpInterfaceFallbackModel>(
+      opInfo->getInterface<MemoryEffectOpInterfaceFallbackModel>());
+  assert(model && "Failed to get MemoryEffectOpInterfaceFallbackModel");
+  model->setCallbacks(callbacks);
+}
diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
index 2a37f3860fe00..c53c59d87d039 100644
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -575,7 +575,7 @@ MutableOperandRange MutableOperandRangeRange::dereference(const OwnerT &object,
ResultRange::ResultRange(OpResult result)
: ResultRange(static_cast<detail::OpResultImpl *>(Value(result).getImpl()),
- 1) {}
+ 0) {}
ResultRange::use_range ResultRange::getUses() const {
return {use_begin(), use_end()};
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 8ab145ada85dd..13bf7624f5626 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -671,8 +671,10 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.Transform.Nanobind
ROOT_DIR "${PYTHON_SOURCE_DIR}"
SOURCES
DialectTransform.cpp
+ Rewrite.h
PRIVATE_LINK_LIBS
LLVMSupport
+ MLIRPythonExtension.Core
EMBED_CAPI_LINK_LIBS
MLIRCAPIIR
MLIRCAPITransformDialect
diff --git a/mlir/python/mlir/_mlir_libs/__init__.py b/mlir/python/mlir/_mlir_libs/__init__.py
index c0e8775149d41..303616080a75b 100644
--- a/mlir/python/mlir/_mlir_libs/__init__.py
+++ b/mlir/python/mlir/_mlir_libs/__init__.py
@@ -242,6 +242,7 @@ def __str__(self):
Sequence.register(ir.BlockPredecessors)
Sequence.register(ir.OperationList)
Sequence.register(ir.OpOperandList)
+ Sequence.register(ir.OpOpOperandList)
Sequence.register(ir.OpResultList)
Sequence.register(ir.OpSuccessors)
Sequence.register(ir.RegionSequence)
diff --git a/mlir/test/python/dialects/transform_op_interface.py b/mlir/test/python/dialects/transform_op_interface.py
new file mode 100644
index 0000000000000..85e9ec89a3de3
--- /dev/null
+++ b/mlir/test/python/dialects/transform_op_interface.py
@@ -0,0 +1,252 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+from contextlib import contextmanager
+
+from mlir import ir
+from mlir.ir import TypeAttr, F32Type, UnitAttr
+from mlir.dialects import transform, irdl, func
+from mlir.dialects.transform import AnyOpType, AnyValueType, AnyParamType, structured, interpreter
+
+from mlir.dialects._ods_common import get_op_result_or_value, get_op_results_or_values
+
+
+def run(emit_schedule):
+ with ir.Context(), ir.Location.unknown():
+ irdl_module = ir.Module.create()
+ with ir.InsertionPoint(irdl_module.body):
+ my_transform = irdl.dialect("my_transform")
+ with ir.InsertionPoint(my_transform.body):
+ OneOpInOneOpOut.emit_irdl()
+ OpValParamInParamOpValOut.emit_irdl()
+ OpsParamInParamsOut.emit_irdl()
+ irdl_module.operation.verify()
+
+ irdl.load_dialects(irdl_module)
+
+ payload = emit_payload()
+ schedule = emit_schedule()
+
+ print("payload:", payload)
+ print("schedule:", schedule)
+ named_seq = schedule.operation.regions[0].blocks[0].operations[0]
+
+ interpreter.apply_named_sequence(
+ payload,
+ named_seq,
+ schedule,
+ )
+
+ del payload
+ del schedule
+
+
+# Payload used by all tests
+def emit_payload():
+ payload_module = ir.Module.create()
+ with ir.InsertionPoint(payload_module.body):
+
+ @func.FuncOp.from_py_func(
+ F32Type.get(), F32Type.get(), results=[F32Type.get()]
+ )
+ def name_of_func(a, b):
+ func.ReturnOp([b])
+
+ return payload_module
+
+
+class OneOpInOneOpOut:
+ name = "one_op_in_one_op_out"
+
+ def __init__(self, op_arg: AnyOpType):
+ self.op = ir.Operation.create(
+ "my_transform.one_op_in_one_op_out",
+ [AnyOpType.get()],
+ [get_op_result_or_value(op_arg)],
+ )
+
+ @property
+ def result(self):
+ return self.op.results[0]
+
+ @classmethod
+ def emit_irdl(cls):
+ op = irdl.operation_(cls.name)
+ with ir.InsertionPoint(op.body):
+ op_handle_type = irdl.is_(TypeAttr.get(AnyOpType.get()))
+ irdl.operands_(
+ [op_handle_type],
+ ["arg"],
+ [irdl.Variadicity.single],
+ )
+ irdl.results_([op_handle_type], ["result"], [irdl.Variadicity.single])
+ return op
+
+
+class OpValParamInParamOpValOut:
+ name = "op_val_param_in_param_op_val_out"
+
+ def __init__(
+ self,
+ op_arg: AnyOpType,
+ val: AnyValueType,
+ param: AnyParamType,
+ ):
+ self.op = ir.Operation.create(
+ "my_transform." + self.name,
+ [
+ AnyParamType.get(),
+ AnyOpType.get(),
+ AnyValueType.get(),
+ ],
+ [
+ get_op_result_or_value(op_arg),
+ get_op_result_or_value(val),
+ get_op_result_or_value(param),
+ ],
+ )
+
+ @property
+ def param_res(self):
+ return self.op.results[0]
+
+ @property
+ def op_res(self):
+ return self.op.results[1]
+
+ @property
+ def value_res(self):
+ return self.op.results[2]
+
+ @classmethod
+ def emit_irdl(cls):
+ op = irdl.operation_(cls.name)
+ with ir.InsertionPoint(op.body):
+ op_handle_type = irdl.is_(TypeAttr.get(AnyOpType.get()))
+ value_handle_type = irdl.is_(TypeAttr.get(AnyValueType.get()))
+ param_handle_type = irdl.is_(TypeAttr.get(AnyParamType.get()))
+ irdl.operands_(
+ [op_handle_type, value_handle_type, param_handle_type],
+ ["op_arg", "value_arg", "param_arg"],
+ [
+ irdl.Variadicity.single,
+ irdl.Variadicity.single,
+ irdl.Variadicity.single,
+ ],
+ )
+ irdl.results_(
+ [param_handle_type, op_handle_type, value_handle_type],
+ ["param_res", "op_res", "value_res"],
+ [
+ irdl.Variadicity.single,
+ irdl.Variadicity.single,
+ irdl.Variadicity.single,
+ ],
+ )
+ return op
+
+
+class OpsParamInParamsOut:
+ name = "ops_param_in_params_out"
+
+ def __init__(
+ self,
+ ops: list[AnyOpType],
+ param: AnyParamType,
+ ):
+ self.op = ir.Operation.create(
+ "my_transform." + self.name,
+ [AnyParamType.get()],
+ [get_op_results_or_values(ops), get_op_result_or_value(param)],
+ )
+
+ @property
+ def param_results(self):
+ return self.op.results
+
+ @classmethod
+ def emit_irdl(cls):
+ op = irdl.operation_(cls.name)
+ with ir.InsertionPoint(op.body):
+ op_handle_type = irdl.is_(TypeAttr.get(AnyOpType.get()))
+ param_handle_type = irdl.is_(TypeAttr.get(AnyParamType.get()))
+ irdl.operands_(
+ [op_handle_type, param_handle_type],
+ ["op_args", "param_arg"],
+ [irdl.Variadicity.variadic, irdl.Variadicity.single],
+ )
+ irdl.results_(
+ [param_handle_type], ["param_results"], [irdl.Variadicity.variadic]
+ )
+ return op
+
+
+@contextmanager
+def schedule_boilerplate():
+ schedule = ir.Module.create()
+ schedule.operation.attributes["transform.with_named_sequence"] = ir.UnitAttr.get()
+ with ir.InsertionPoint(schedule.body):
+ named_sequence = transform.NamedSequenceOp(
+ "__transform_main",
+ [AnyOpType.get()],
+ [AnyOpType.get()],
+ arg_attrs=[{"transform.consumed": UnitAttr.get()}],
+ )
+ with ir.InsertionPoint(named_sequence.body):
+ yield schedule, named_sequence
+
+
+@run
+def OneOpInOneOpOutTransformOpInterface():
+ class OneOpInOneOpOutTransformOpInterfaceFallbackModel(
+ transform.TransformOpInterface
+ ):
+ @staticmethod
+ def apply(
+ op_: ir.Operation,
+ rewriter: transform.TransformRewriter,
+ results: transform.TransformResults,
+ state: transform.TransformState,
+ ):
+ targets = state.get_payload_ops(arg := op_.operands[0])
+ target_names = [t.opview.name.value for t in targets]
+ print(
+ f"OneOpInOneOpOutTransformOpInterfaceFallbackModel: target_names={target_names}"
+ )
+ results.set_ops(result := op_.results[0], targets)
+ return transform.DiagnosedSilenceableFailure.Success
+
+ @staticmethod
+ def allow_repeated_handle_operands(op_: ir.Operation) -> bool:
+ return False
+
+ OneOpInOneOpOutTransformOpInterfaceFallbackModel.attach(
+ "my_transform.one_op_in_one_op_out", ir.Context.current
+ )
+
+ # TransformOpInterface-implementing ops are also required to implement MemoryEffectsOpInterface.
+ class MemoryEffectsOpInterfaceFallbackModel(ir.MemoryEffectsOpInterface):
+ @staticmethod
+ def get_effects(op_: ir.Operation, effects):
+ transform.only_reads_handle(list(op_.op_operands), effects)
+ transform.produces_handle(list(op_.results), effects)
+
+ MemoryEffectsOpInterfaceFallbackModel.attach(
+ "my_transform.one_op_in_one_op_out", ir.Context.current
+ )
+
+ with schedule_boilerplate() as (schedule, named_seq):
+ print(f"{named_seq=}")
+ func_handle = structured.MatchOp.match_op_names(
+ named_seq.bodyTarget, ["func.func"]
+ ).result
+ func_handle.dump()
+ # CHECK: OneOpInOneOpOutTransformOpInterfaceFallbackModel: target_names=['name_of_func']
+ out = OneOpInOneOpOut(func_handle).result
+ out.dump()
+ print(out.owner)
+ transform.YieldOp([out])
+ named_seq.verify()
+ print("named_seq", named_seq)
+ print("named_seq.parent", named_seq.parent)
+
+ return schedule
>From 369396ca6eb5e0d50eea9f022ec93e0eb9673609 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Tue, 20 Jan 2026 05:11:21 -0800
Subject: [PATCH 2/9] Format fix
---
mlir/include/mlir-c/Dialect/Transform.h | 2 +-
mlir/test/python/dialects/transform_op_interface.py | 12 ++++++++----
2 files changed, 9 insertions(+), 5 deletions(-)
diff --git a/mlir/include/mlir-c/Dialect/Transform.h b/mlir/include/mlir-c/Dialect/Transform.h
index 38bd9756176a0..b087482345557 100644
--- a/mlir/include/mlir-c/Dialect/Transform.h
+++ b/mlir/include/mlir-c/Dialect/Transform.h
@@ -10,8 +10,8 @@
#ifndef MLIR_C_DIALECT_TRANSFORM_H
#define MLIR_C_DIALECT_TRANSFORM_H
-#include "mlir-c/Interfaces.h"
#include "mlir-c/IR.h"
+#include "mlir-c/Interfaces.h"
#include "mlir-c/Rewrite.h"
#include "mlir-c/Support.h"
diff --git a/mlir/test/python/dialects/transform_op_interface.py b/mlir/test/python/dialects/transform_op_interface.py
index 85e9ec89a3de3..f856f698f38cb 100644
--- a/mlir/test/python/dialects/transform_op_interface.py
+++ b/mlir/test/python/dialects/transform_op_interface.py
@@ -5,7 +5,13 @@
from mlir import ir
from mlir.ir import TypeAttr, F32Type, UnitAttr
from mlir.dialects import transform, irdl, func
-from mlir.dialects.transform import AnyOpType, AnyValueType, AnyParamType, structured, interpreter
+from mlir.dialects.transform import (
+ AnyOpType,
+ AnyValueType,
+ AnyParamType,
+ structured,
+ interpreter,
+)
from mlir.dialects._ods_common import get_op_result_or_value, get_op_results_or_values
@@ -45,9 +51,7 @@ def emit_payload():
payload_module = ir.Module.create()
with ir.InsertionPoint(payload_module.body):
- @func.FuncOp.from_py_func(
- F32Type.get(), F32Type.get(), results=[F32Type.get()]
- )
+ @func.FuncOp.from_py_func(F32Type.get(), F32Type.get(), results=[F32Type.get()])
def name_of_func(a, b):
func.ReturnOp([b])
>From 508fdf5186bf7ea5b9314059c75dda3e0b8aeaa6 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Mon, 26 Jan 2026 04:24:38 -0800
Subject: [PATCH 3/9] More tests
---
.../python/dialects/transform_op_interface.py | 299 +++++++++++++++++-
1 file changed, 285 insertions(+), 14 deletions(-)
diff --git a/mlir/test/python/dialects/transform_op_interface.py b/mlir/test/python/dialects/transform_op_interface.py
index f856f698f38cb..f1b9d4c23a5de 100644
--- a/mlir/test/python/dialects/transform_op_interface.py
+++ b/mlir/test/python/dialects/transform_op_interface.py
@@ -1,11 +1,14 @@
# RUN: %PYTHON %s | FileCheck %s
+from typing import Sequence
+
from contextlib import contextmanager
from mlir import ir
from mlir.ir import TypeAttr, F32Type, UnitAttr
-from mlir.dialects import transform, irdl, func
+from mlir.dialects import transform, irdl, func, arith
from mlir.dialects.transform import (
+ debug as transform_debug,
AnyOpType,
AnyValueType,
AnyParamType,
@@ -22,13 +25,22 @@ def run(emit_schedule):
with ir.InsertionPoint(irdl_module.body):
my_transform = irdl.dialect("my_transform")
with ir.InsertionPoint(my_transform.body):
+ GetNamedAttributeOp.emit_irdl()
OneOpInOneOpOut.emit_irdl()
OpValParamInParamOpValOut.emit_irdl()
- OpsParamInParamsOut.emit_irdl()
+ OpsParamsInValuesParamOut.emit_irdl()
irdl_module.operation.verify()
+ print(irdl_module)
irdl.load_dialects(irdl_module)
+ GetNamedAttributeTransformOpInterfaceFallbackModel.attach(
+ "my_transform." + GetNamedAttributeOp.name
+ )
+ GetNamedAttributeMemoryEffectsOpInterfaceFallbackModel.attach(
+ "my_transform." + GetNamedAttributeOp.name
+ )
+
payload = emit_payload()
schedule = emit_schedule()
@@ -53,11 +65,82 @@ def emit_payload():
@func.FuncOp.from_py_func(F32Type.get(), F32Type.get(), results=[F32Type.get()])
def name_of_func(a, b):
- func.ReturnOp([b])
+ c = arith.AddFOp(a, b)
+ d = arith.constant(F32Type.get(), 42.0)
+ e = arith.constant(F32Type.get(), 24.0)
+ func.ReturnOp([c.results[0]])
return payload_module
+class GetNamedAttributeOp:
+ name = "get_named_attribute"
+
+ def __init__(self, target: AnyOpType, attr_name: str):
+ self.op = ir.Operation.create(
+ "my_transform.get_named_attribute",
+ [AnyParamType.get()],
+ [get_op_result_or_value(target)],
+ {"attr_name": ir.StringAttr.get(attr_name)},
+ )
+
+ @property
+ def attr_as_param(self) -> ir.Value:
+ return self.op.results[0]
+
+ @classmethod
+ def emit_irdl(cls):
+ op = irdl.operation_(cls.name)
+ with ir.InsertionPoint(op.body):
+ op_handle_type = irdl.is_(TypeAttr.get(AnyOpType.get()))
+ param_handle_type = irdl.is_(TypeAttr.get(AnyParamType.get()))
+ name_handle_kind = irdl.base(base_name="#builtin.string")
+ irdl.operands_(
+ [op_handle_type],
+ ["target"],
+ [irdl.Variadicity.single],
+ )
+ irdl.attributes_([name_handle_kind], ["attr_name"])
+ irdl.results_(
+ [param_handle_type], ["attr_as_param"], [irdl.Variadicity.single]
+ )
+ return op
+
+
+class GetNamedAttributeTransformOpInterfaceFallbackModel(
+ transform.TransformOpInterface
+):
+ @staticmethod
+ def apply(
+ op_: ir.Operation,
+ rewriter: transform.TransformRewriter,
+ results: transform.TransformResults,
+ state: transform.TransformState,
+ ):
+ targets = state.get_payload_ops(target := op_.operands[0])
+ associated_attrs = []
+ for target_op in targets:
+ assoc_attr = target_op.attributes.get(op_.attributes["attr_name"].value)
+ if assoc_attr is None:
+ return transform.DiagnosedSilenceableFailure.RecoverableFailure
+ associated_attrs.append(assoc_attr)
+ results.set_params(op_.results[0], associated_attrs)
+ return transform.DiagnosedSilenceableFailure.Success
+
+ @staticmethod
+ def allow_repeated_handle_operands(op_: ir.Operation) -> bool:
+ return False
+
+
+class GetNamedAttributeMemoryEffectsOpInterfaceFallbackModel(
+ ir.MemoryEffectsOpInterface
+):
+ @staticmethod
+ def get_effects(op_: ir.Operation, effects):
+ transform.only_reads_handle(list(op_.op_operands), effects)
+ transform.produces_handle(list(op_.results), effects)
+
+
class OneOpInOneOpOut:
name = "one_op_in_one_op_out"
@@ -149,37 +232,55 @@ def emit_irdl(cls):
return op
-class OpsParamInParamsOut:
- name = "ops_param_in_params_out"
+class OpsParamsInValuesParamOut:
+ name = "ops_params_in_values_param_out"
def __init__(
self,
- ops: list[AnyOpType],
- param: AnyParamType,
+ value_results: Sequence[AnyValueType],
+ ops: Sequence[AnyOpType],
+ params: Sequence[AnyParamType],
):
+ def as_i32(x):
+ return ir.IntegerAttr.get(ir.IntegerType.get_signless(32), x)
+
self.op = ir.Operation.create(
"my_transform." + self.name,
- [AnyParamType.get()],
- [get_op_results_or_values(ops), get_op_result_or_value(param)],
+ list(value_results) + [AnyParamType.get()],
+ list(get_op_results_or_values(ops))
+ + list(get_op_results_or_values(params)),
+ {
+ "operandSegmentSizes": ir.DenseI32ArrayAttr.get(
+ [len(ops), len(params)]
+ ),
+ "resultSegmentSizes": ir.DenseI32ArrayAttr.get([len(value_results)]),
+ },
)
@property
- def param_results(self):
- return self.op.results
+ def param(self):
+ return self.op.results[-1]
+
+ @property
+ def values(self):
+ return self.op.results[:-1]
@classmethod
def emit_irdl(cls):
op = irdl.operation_(cls.name)
with ir.InsertionPoint(op.body):
op_handle_type = irdl.is_(TypeAttr.get(AnyOpType.get()))
+ value_handle_type = irdl.is_(TypeAttr.get(AnyValueType.get()))
param_handle_type = irdl.is_(TypeAttr.get(AnyParamType.get()))
irdl.operands_(
[op_handle_type, param_handle_type],
- ["op_args", "param_arg"],
- [irdl.Variadicity.variadic, irdl.Variadicity.single],
+ ["ops", "params"],
+ [irdl.Variadicity.variadic, irdl.Variadicity.variadic],
)
irdl.results_(
- [param_handle_type], ["param_results"], [irdl.Variadicity.variadic]
+ [value_handle_type, param_handle_type],
+ ["value_results", "param"],
+ [irdl.Variadicity.variadic, irdl.Variadicity.single],
)
return op
@@ -254,3 +355,173 @@ def get_effects(op_: ir.Operation, effects):
print("named_seq.parent", named_seq.parent)
return schedule
+
+
+@run
+def OpValParamInParamOpValOutTransformOpInterface():
+ class OpValParamInParamOpValOutTransformOpInterfaceFallbackModel(
+ transform.TransformOpInterface
+ ):
+ @staticmethod
+ def apply(
+ op_: ir.Operation,
+ rewriter: transform.TransformRewriter,
+ results: transform.TransformResults,
+ state: transform.TransformState,
+ ):
+ ops = state.get_payload_ops(op_.operands[0])
+ values = state.get_payload_values(op_.operands[1])
+ params = state.get_params(op_.operands[2])
+ print(
+ f"OpValParamInParamOpValOutTransformOpInterfaceFallbackModel: ops={len(ops)}, values={len(values)}, params={len(params)}"
+ )
+ results.set_params(op_.results[0], params)
+ results.set_ops(op_.results[1], ops)
+ results.set_values(op_.results[2], values)
+ return transform.DiagnosedSilenceableFailure.Success
+
+ @staticmethod
+ def allow_repeated_handle_operands(op_: ir.Operation) -> bool:
+ return False
+
+ OpValParamInParamOpValOutTransformOpInterfaceFallbackModel.attach(
+ "my_transform.op_val_param_in_param_op_val_out", ir.Context.current
+ )
+
+ # TransformOpInterface-implementing ops are also required to implement MemoryEffectsOpInterface.
+ class MemoryEffectsOpInterfaceFallbackModel(ir.MemoryEffectsOpInterface):
+ @staticmethod
+ def get_effects(op_: ir.Operation, effects):
+ transform.only_reads_handle(list(op_.op_operands), effects)
+ transform.produces_handle(list(op_.results), effects)
+
+ MemoryEffectsOpInterfaceFallbackModel.attach(
+ "my_transform.op_val_param_in_param_op_val_out", ir.Context.current
+ )
+
+ with schedule_boilerplate() as (schedule, named_seq):
+ func_handle = structured.MatchOp.match_op_names(
+ named_seq.bodyTarget, ["func.func"]
+ ).result
+ addf_handle = structured.MatchOp.match_op_names(
+ named_seq.bodyTarget, ["arith.addf"]
+ ).result
+ func_and_addf = transform.MergeHandlesOp([func_handle, addf_handle])
+ value_handle = transform.GetResultOp(
+ AnyValueType.get(), addf_handle, [0]
+ ).result
+ param_handle = transform.ParamConstantOp(
+ AnyParamType.get(), ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 42)
+ ).param
+
+ # CHECK: OpValParamInParamOpValOutTransformOpInterfaceFallbackModel: ops=2, values=1, params=1
+ op_val_param_op = OpValParamInParamOpValOut(
+ func_and_addf, value_handle, param_handle
+ )
+
+ transform.YieldOp([op_val_param_op.op_res])
+ named_seq.verify()
+
+ return schedule
+
+
+@run
+def OpsParamsInValuesParamOutTransformOpInterface():
+ class OpsParamsInValuesParamOutTransformOpInterfaceFallbackModel(
+ transform.TransformOpInterface
+ ):
+ @staticmethod
+ def apply(
+ op_: ir.Operation,
+ rewriter: transform.TransformRewriter,
+ results: transform.TransformResults,
+ state: transform.TransformState,
+ ):
+ # The last operand is the param. All previous ones are ops.
+ op_handles, param_handles = [], []
+ for operand in op_.operands:
+ if isinstance(operand.type, transform.AnyOpType):
+ op_handles.append(operand)
+ else:
+ param_handles.append(operand)
+
+ ops_count = 0
+ value_handles = []
+ for op_handle in op_handles:
+ ops = state.get_payload_ops(op_handle)
+ ops_count += len(ops)
+ value_handles.append(list(op.results[:1] for op in ops))
+
+ param_count = 0
+ param_sum = 0
+ for param_handle in param_handles:
+ params = state.get_params(param_handle)
+ param_count += len(params)
+ param_sum += sum(p.value for p in params)
+
+ print(
+ f"OpsParamInValuesParamOutTransformOpInterfaceFallbackModel: #op_handles={len(op_handles)}, ops_count={ops_count}, #param_handles={len(param_handles)}, param_count={param_count}"
+ )
+
+ assert len(op_.results) + 1 == len(op_handles)
+ for i in range(len(op_.results) - 1):
+ results.set_values(
+ op_.results[i],
+ value_handles[i],
+ )
+ results.set_params(
+ op_.results[-1],
+ [ir.IntegerAttr.get(ir.IntegerType.get_signless(32), param_sum)],
+ )
+ return transform.DiagnosedSilenceableFailure.Success
+
+ @staticmethod
+ def allow_repeated_handle_operands(op_: ir.Operation) -> bool:
+ return True
+
+ OpsParamsInValuesParamOutTransformOpInterfaceFallbackModel.attach(
+ "my_transform." + OpsParamsInValuesParamOut.name
+ )
+
+ class OpsParamsInParamsOutMemoryEffectsOpInterfaceFallbackModel(
+ ir.MemoryEffectsOpInterface
+ ):
+ @staticmethod
+ def get_effects(op_: ir.Operation, effects):
+ transform.only_reads_handle(list(op_.op_operands), effects)
+ transform.produces_handle(list(op_.results), effects)
+
+ OpsParamsInParamsOutMemoryEffectsOpInterfaceFallbackModel.attach(
+ "my_transform." + OpsParamsInValuesParamOut.name
+ )
+
+ with schedule_boilerplate() as (schedule, named_seq):
+ func_handle = structured.MatchOp.match_op_names(
+ named_seq.bodyTarget, ["func.func"]
+ ).result
+ csts_handle = structured.MatchOp.match_op_names(
+ named_seq.bodyTarget, ["arith.constant"]
+ ).result
+ csts_as_param = GetNamedAttributeOp(csts_handle, "value").attr_as_param
+
+ param_handle = transform.ParamConstantOp(
+ AnyParamType.get(), ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 123)
+ ).param
+
+ # CHECK: OpsParamInParamsOutTransformOpInterfaceFallbackModel: op_count=2, param_count=1
+ op = OpsParamsInValuesParamOut(
+ [transform.AnyValueType.get()] * 2 + [transform.AnyParamType.get()],
+ [func_handle, csts_handle],
+ [csts_as_param, param_handle],
+ )
+ print(op.op)
+ # CHECK: Sum of params: 189
+ transform_debug.EmitParamAsRemarkOp(op.param, message="Sum of params")
+
+ transform_debug.EmitRemarkAtOp(op.values[0], message="Value results 0")
+ transform_debug.EmitRemarkAtOp(op.values[1], message="Value results 1")
+
+ transform.YieldOp([func_handle])
+ named_seq.verify()
+
+ return schedule
>From 28609afc3d2d14fc8708f134cef2d13731223948 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Mon, 26 Jan 2026 12:57:27 -0800
Subject: [PATCH 4/9] Switch to op def DSL and add more tests
---
mlir/lib/Bindings/Python/DialectTransform.cpp | 16 +-
mlir/lib/Bindings/Python/Globals.cpp | 3 +-
mlir/python/mlir/dialects/ext.py | 37 +-
.../python/dialects/transform_op_interface.py | 571 +++++++-----------
4 files changed, 266 insertions(+), 361 deletions(-)
diff --git a/mlir/lib/Bindings/Python/DialectTransform.cpp b/mlir/lib/Bindings/Python/DialectTransform.cpp
index 7dbeb72a0d2a8..3756518eeb06a 100644
--- a/mlir/lib/Bindings/Python/DialectTransform.cpp
+++ b/mlir/lib/Bindings/Python/DialectTransform.cpp
@@ -194,8 +194,12 @@ class PyTransformOpInterface
auto pyResults = PyTransformResults(results);
auto pyState = PyTransformState(state);
- // Invoke `pyClass.apply(op, rewriter, results, state)` as a classmethod.
- nb::object res = pyApply(op, pyRewriter, pyResults, pyState);
+ // Invoke `pyClass.apply(opview(op), rewriter, results, state)` as a
+ // staticmethod.
+ PyMlirContextRef context =
+ PyMlirContext::forContext(mlirOperationGetContext(op));
+ auto opview = PyOperation::forOperation(context, op)->createOpView();
+ nb::object res = pyApply(opview, pyRewriter, pyResults, pyState);
return nb::cast<MlirDiagnosedSilenceableFailure>(res);
};
@@ -208,8 +212,12 @@ class PyTransformOpInterface
auto pyAllowRepeatedHandleOperands = nb::cast<nb::callable>(
nb::getattr(pyClass, "allow_repeated_handle_operands"));
- // Invoke `pyClass.allow_repeated_handle_operands(op)` as a classmethod.
- nb::object res = pyAllowRepeatedHandleOperands(op);
+ // Invoke `pyClass.allow_repeated_handle_operands(opview(op))` as a
+ // staticmethod.
+ PyMlirContextRef context =
+ PyMlirContext::forContext(mlirOperationGetContext(op));
+ auto opview = PyOperation::forOperation(context, op)->createOpView();
+ nb::object res = pyAllowRepeatedHandleOperands(opview);
return nb::cast<bool>(res);
};
diff --git a/mlir/lib/Bindings/Python/Globals.cpp b/mlir/lib/Bindings/Python/Globals.cpp
index e2e8693ba45f3..590943b445250 100644
--- a/mlir/lib/Bindings/Python/Globals.cpp
+++ b/mlir/lib/Bindings/Python/Globals.cpp
@@ -194,8 +194,7 @@ PyGlobals::lookupOperationClass(llvm::StringRef operationName) {
// Make sure dialect module is loaded.
auto split = operationName.split('.');
llvm::StringRef dialectNamespace = split.first;
- if (!loadDialectModule(dialectNamespace))
- return std::nullopt;
+ loadDialectModule(dialectNamespace);
nb::ft_lock_guard lock(mutex);
auto foundIt = operationClassMap.find(operationName);
diff --git a/mlir/python/mlir/dialects/ext.py b/mlir/python/mlir/dialects/ext.py
index 237c27bf62f77..3fff63dd0578b 100644
--- a/mlir/python/mlir/dialects/ext.py
+++ b/mlir/python/mlir/dialects/ext.py
@@ -29,10 +29,14 @@
"Dialect",
"Operand",
"Result",
+ "register_dialect",
+ "register_operation",
]
Operand = ir.Value
Result = ir.OpResult
+register_dialect = _cext.register_dialect
+register_operation = _cext.register_operation
class ConstraintLoweringContext:
@@ -424,6 +428,11 @@ class AddOp(MyInt.Operation, name="add"):
```
"""
+ class ExtOperation(Operation):
+ def __init__(*args, **kwargs):
+ raise RuntimeError("Cannot instantiate Dialect.ExtOperation directly.")
+
+
@classmethod
def __init_subclass__(cls, name: str, **kwargs):
cls.name = name
@@ -431,7 +440,7 @@ def __init_subclass__(cls, name: str, **kwargs):
cls.operations = []
cls.Operation = type(
"Operation",
- (Operation,),
+ (cls.ExtOperation,),
{"_dialect_obj": cls, "_dialect_name": name},
)
@@ -451,21 +460,25 @@ def _emit_module(cls) -> ir.Module:
return m
@classmethod
- def load(cls) -> None:
- if hasattr(cls, "_mlir_module"):
- raise RuntimeError(f"Dialect {cls.name} is already loaded.")
+ def load(cls, register=True, context: Optional[ir.Context] = None) -> None:
+ context = context or ir.Context.current
- mlir_module = cls._emit_module()
+ try:
+ context.dialects[cls.name]
+ raise RuntimeError(f"Dialect {cls.name} is already loaded.")
+ except IndexError:
+ pass # Dialect not loaded yet.
+ cls._mlir_module = cls._emit_module()
pm = PassManager()
pm.add("canonicalize, cse")
- pm.run(mlir_module.operation)
-
- irdl.load_dialects(mlir_module)
+ pm.run(cls._mlir_module.operation)
- _cext.register_dialect(cls)
+ irdl.load_dialects(cls._mlir_module)
- for op in cls.operations:
- _cext.register_operation(cls)(op)
+ if register:
+ _cext.register_dialect(cls)
- cls._mlir_module = mlir_module
+ register_dialect_operation = _cext.register_operation(cls)
+ for op in cls.operations:
+ register_dialect_operation(op)
diff --git a/mlir/test/python/dialects/transform_op_interface.py b/mlir/test/python/dialects/transform_op_interface.py
index f1b9d4c23a5de..ab919911bdd31 100644
--- a/mlir/test/python/dialects/transform_op_interface.py
+++ b/mlir/test/python/dialects/transform_op_interface.py
@@ -1,14 +1,13 @@
-# RUN: %PYTHON %s | FileCheck %s
+# RUN: env PYTHONUNBUFFERED=1 %PYTHON %s 2>&1 | FileCheck %s
from typing import Sequence
from contextlib import contextmanager
from mlir import ir
-from mlir.ir import TypeAttr, F32Type, UnitAttr
-from mlir.dialects import transform, irdl, func, arith
+from mlir.ir import F32Type, UnitAttr
+from mlir.dialects import transform, func, arith, ext
from mlir.dialects.transform import (
- debug as transform_debug,
AnyOpType,
AnyValueType,
AnyParamType,
@@ -16,47 +15,31 @@
interpreter,
)
-from mlir.dialects._ods_common import get_op_result_or_value, get_op_results_or_values
+@ext.register_dialect
+class MyTransform(ext.Dialect, name="my_transform"):
+ pass
def run(emit_schedule):
- with ir.Context(), ir.Location.unknown():
- irdl_module = ir.Module.create()
- with ir.InsertionPoint(irdl_module.body):
- my_transform = irdl.dialect("my_transform")
- with ir.InsertionPoint(my_transform.body):
- GetNamedAttributeOp.emit_irdl()
- OneOpInOneOpOut.emit_irdl()
- OpValParamInParamOpValOut.emit_irdl()
- OpsParamsInValuesParamOut.emit_irdl()
- irdl_module.operation.verify()
-
- print(irdl_module)
- irdl.load_dialects(irdl_module)
-
- GetNamedAttributeTransformOpInterfaceFallbackModel.attach(
- "my_transform." + GetNamedAttributeOp.name
- )
- GetNamedAttributeMemoryEffectsOpInterfaceFallbackModel.attach(
- "my_transform." + GetNamedAttributeOp.name
- )
-
+ print(f"Test: {emit_schedule.__name__}")
+ with ir.Context() as ctx, ir.Location.unknown():
payload = emit_payload()
- schedule = emit_schedule()
- print("payload:", payload)
- print("schedule:", schedule)
- named_seq = schedule.operation.regions[0].blocks[0].operations[0]
+ MyTransform.load(register=False)
+
+ GetNamedAttributeOp.attach_interface_impls(ctx)
+ PrintParamOp.attach_interface_impls(ctx)
+
+ # NB: Other newly defined my_transform ops have their interfaces attached
+ # in their respective test functions.
+ schedule = emit_schedule()
interpreter.apply_named_sequence(
payload,
- named_seq,
+ named_seq := schedule.operation.regions[0].blocks[0].operations[0],
schedule,
)
- del payload
- del schedule
-
# Payload used by all tests
def emit_payload():
@@ -65,338 +48,200 @@ def emit_payload():
@func.FuncOp.from_py_func(F32Type.get(), F32Type.get(), results=[F32Type.get()])
def name_of_func(a, b):
- c = arith.AddFOp(a, b)
- d = arith.constant(F32Type.get(), 42.0)
- e = arith.constant(F32Type.get(), 24.0)
- func.ReturnOp([c.results[0]])
+ c = arith.addf(a, b)
+ i32 = ir.IntegerType.get_signless(32)
+ c42 = arith.constant(i32, 42)
+ c24 = arith.constant(i32, 24)
+ func.ReturnOp([c])
return payload_module
-class GetNamedAttributeOp:
- name = "get_named_attribute"
-
- def __init__(self, target: AnyOpType, attr_name: str):
- self.op = ir.Operation.create(
- "my_transform.get_named_attribute",
- [AnyParamType.get()],
- [get_op_result_or_value(target)],
- {"attr_name": ir.StringAttr.get(attr_name)},
+@contextmanager
+def schedule_boilerplate():
+ schedule = ir.Module.create()
+ schedule.operation.attributes["transform.with_named_sequence"] = ir.UnitAttr.get()
+ with ir.InsertionPoint(schedule.body):
+ named_sequence = transform.NamedSequenceOp(
+ "__transform_main",
+ [AnyOpType.get()],
+ [AnyOpType.get()],
+ arg_attrs=[{"transform.consumed": UnitAttr.get()}],
)
-
- @property
- def attr_as_param(self) -> ir.Value:
- return self.op.results[0]
-
- @classmethod
- def emit_irdl(cls):
- op = irdl.operation_(cls.name)
- with ir.InsertionPoint(op.body):
- op_handle_type = irdl.is_(TypeAttr.get(AnyOpType.get()))
- param_handle_type = irdl.is_(TypeAttr.get(AnyParamType.get()))
- name_handle_kind = irdl.base(base_name="#builtin.string")
- irdl.operands_(
- [op_handle_type],
- ["target"],
- [irdl.Variadicity.single],
- )
- irdl.attributes_([name_handle_kind], ["attr_name"])
- irdl.results_(
- [param_handle_type], ["attr_as_param"], [irdl.Variadicity.single]
- )
- return op
-
-
-class GetNamedAttributeTransformOpInterfaceFallbackModel(
- transform.TransformOpInterface
-):
- @staticmethod
- def apply(
- op_: ir.Operation,
- rewriter: transform.TransformRewriter,
- results: transform.TransformResults,
- state: transform.TransformState,
- ):
- targets = state.get_payload_ops(target := op_.operands[0])
- associated_attrs = []
- for target_op in targets:
- assoc_attr = target_op.attributes.get(op_.attributes["attr_name"].value)
- if assoc_attr is None:
- return transform.DiagnosedSilenceableFailure.RecoverableFailure
- associated_attrs.append(assoc_attr)
- results.set_params(op_.results[0], associated_attrs)
- return transform.DiagnosedSilenceableFailure.Success
-
- @staticmethod
- def allow_repeated_handle_operands(op_: ir.Operation) -> bool:
- return False
+ with ir.InsertionPoint(named_sequence.body):
+ yield schedule, named_sequence
-class GetNamedAttributeMemoryEffectsOpInterfaceFallbackModel(
- ir.MemoryEffectsOpInterface
-):
+# MemoryEffectsOpInterface implementation for TransformOpInterface-implementing ops.
+# Used by all ops defined below.
+class MemoryEffectsOpInterfaceFallbackModel(ir.MemoryEffectsOpInterface):
@staticmethod
def get_effects(op_: ir.Operation, effects):
transform.only_reads_handle(list(op_.op_operands), effects)
transform.produces_handle(list(op_.results), effects)
-class OneOpInOneOpOut:
- name = "one_op_in_one_op_out"
-
- def __init__(self, op_arg: AnyOpType):
- self.op = ir.Operation.create(
- "my_transform.one_op_in_one_op_out",
- [AnyOpType.get()],
- [get_op_result_or_value(op_arg)],
- )
-
- @property
- def result(self):
- return self.op.results[0]
+# Demonstration of a TransformOpInterface-implementing op that gets named attributes
+# from target ops and produces them as param handles.
+@ext.register_operation(MyTransform)
+class GetNamedAttributeOp(MyTransform.Operation, name="get_named_attribute"):
+ target: ext.Operand[transform.AnyOpType]
+ attr_name: ir.StringAttr
+ attr_as_param: ext.Result[transform.AnyParamType[()]]
@classmethod
- def emit_irdl(cls):
- op = irdl.operation_(cls.name)
- with ir.InsertionPoint(op.body):
- op_handle_type = irdl.is_(TypeAttr.get(AnyOpType.get()))
- irdl.operands_(
- [op_handle_type],
- ["arg"],
- [irdl.Variadicity.single],
- )
- irdl.results_([op_handle_type], ["result"], [irdl.Variadicity.single])
- return op
-
-
-class OpValParamInParamOpValOut:
- name = "op_val_param_in_param_op_val_out"
-
- def __init__(
- self,
- op_arg: AnyOpType,
- val: AnyValueType,
- param: AnyParamType,
- ):
- self.op = ir.Operation.create(
- "my_transform." + self.name,
- [
- AnyParamType.get(),
- AnyOpType.get(),
- AnyValueType.get(),
- ],
- [
- get_op_result_or_value(op_arg),
- get_op_result_or_value(val),
- get_op_result_or_value(param),
- ],
- )
+ def attach_interface_impls(cls, ctx=None):
+ cls.TransformOpInterfaceFallbackModel.attach(cls.OPERATION_NAME, ctx)
+ MemoryEffectsOpInterfaceFallbackModel.attach(cls.OPERATION_NAME, ctx)
- @property
- def param_res(self):
- return self.op.results[0]
+ class TransformOpInterfaceFallbackModel(transform.TransformOpInterface):
+ @staticmethod
+ def apply(
+ op: "GetNamedAttributeOp",
+ rewriter: transform.TransformRewriter,
+ results: transform.TransformResults,
+ state: transform.TransformState,
+ ):
+ target_ops = state.get_payload_ops(op.target)
+ associated_attrs = []
+ for target_op in target_ops:
+ assoc_attr = target_op.attributes.get(op.attr_name.value)
+ if assoc_attr is None:
+ return transform.DiagnosedSilenceableFailure.RecoverableFailure
+ associated_attrs.append(assoc_attr)
+ results.set_params(op.attr_as_param, associated_attrs)
+ return transform.DiagnosedSilenceableFailure.Success
- @property
- def op_res(self):
- return self.op.results[1]
+ @staticmethod
+ def allow_repeated_handle_operands(op: "GetNamedAttributeOp") -> bool:
+ return False
- @property
- def value_res(self):
- return self.op.results[2]
- @classmethod
- def emit_irdl(cls):
- op = irdl.operation_(cls.name)
- with ir.InsertionPoint(op.body):
- op_handle_type = irdl.is_(TypeAttr.get(AnyOpType.get()))
- value_handle_type = irdl.is_(TypeAttr.get(AnyValueType.get()))
- param_handle_type = irdl.is_(TypeAttr.get(AnyParamType.get()))
- irdl.operands_(
- [op_handle_type, value_handle_type, param_handle_type],
- ["op_arg", "value_arg", "param_arg"],
- [
- irdl.Variadicity.single,
- irdl.Variadicity.single,
- irdl.Variadicity.single,
- ],
- )
- irdl.results_(
- [param_handle_type, op_handle_type, value_handle_type],
- ["param_res", "op_res", "value_res"],
- [
- irdl.Variadicity.single,
- irdl.Variadicity.single,
- irdl.Variadicity.single,
- ],
- )
- return op
-
-
-class OpsParamsInValuesParamOut:
- name = "ops_params_in_values_param_out"
-
- def __init__(
- self,
- value_results: Sequence[AnyValueType],
- ops: Sequence[AnyOpType],
- params: Sequence[AnyParamType],
- ):
- def as_i32(x):
- return ir.IntegerAttr.get(ir.IntegerType.get_signless(32), x)
-
- self.op = ir.Operation.create(
- "my_transform." + self.name,
- list(value_results) + [AnyParamType.get()],
- list(get_op_results_or_values(ops))
- + list(get_op_results_or_values(params)),
- {
- "operandSegmentSizes": ir.DenseI32ArrayAttr.get(
- [len(ops), len(params)]
- ),
- "resultSegmentSizes": ir.DenseI32ArrayAttr.get([len(value_results)]),
- },
- )
+@ext.register_operation(MyTransform)
+class PrintParamOp(MyTransform.Operation, name="print_param"):
+ target: ext.Operand[transform.AnyParamType]
+ name: ir.StringAttr
- @property
- def param(self):
- return self.op.results[-1]
+ @classmethod
+ def attach_interface_impls(cls, ctx=None):
+ cls.TransformOpInterfaceFallbackModel.attach(cls.OPERATION_NAME, ctx)
+ MemoryEffectsOpInterfaceFallbackModel.attach(cls.OPERATION_NAME, ctx)
- @property
- def values(self):
- return self.op.results[:-1]
+ class TransformOpInterfaceFallbackModel(transform.TransformOpInterface):
+ @staticmethod
+ def apply(
+ op: "PrintParamOp",
+ rewriter: transform.TransformRewriter,
+ results: transform.TransformResults,
+ state: transform.TransformState,
+ ):
+ target_attrs = state.get_params(op.target)
+ print(f"[[[ IR printer: {op.name.value} ]]]")
+ for attr in target_attrs:
+ print(attr)
+ return transform.DiagnosedSilenceableFailure.Success
- @classmethod
- def emit_irdl(cls):
- op = irdl.operation_(cls.name)
- with ir.InsertionPoint(op.body):
- op_handle_type = irdl.is_(TypeAttr.get(AnyOpType.get()))
- value_handle_type = irdl.is_(TypeAttr.get(AnyValueType.get()))
- param_handle_type = irdl.is_(TypeAttr.get(AnyParamType.get()))
- irdl.operands_(
- [op_handle_type, param_handle_type],
- ["ops", "params"],
- [irdl.Variadicity.variadic, irdl.Variadicity.variadic],
- )
- irdl.results_(
- [value_handle_type, param_handle_type],
- ["value_results", "param"],
- [irdl.Variadicity.variadic, irdl.Variadicity.single],
- )
- return op
+ @staticmethod
+ def allow_repeated_handle_operands(op: "GetNamedAttributeOp") -> bool:
+ return False
-@contextmanager
-def schedule_boilerplate():
- schedule = ir.Module.create()
- schedule.operation.attributes["transform.with_named_sequence"] = ir.UnitAttr.get()
- with ir.InsertionPoint(schedule.body):
- named_sequence = transform.NamedSequenceOp(
- "__transform_main",
- [AnyOpType.get()],
- [AnyOpType.get()],
- arg_attrs=[{"transform.consumed": UnitAttr.get()}],
- )
- with ir.InsertionPoint(named_sequence.body):
- yield schedule, named_sequence
+# Syntax for an op with one op handle operand and one op handle result.
+@ext.register_operation(MyTransform)
+class OneOpInOneOpOut(MyTransform.Operation, name="one_op_in_one_op_out"):
+ target: ext.Operand[transform.AnyOpType]
+ res: ext.Result[transform.AnyOpType[()]]
+# CHECK-LABEL: Test: OneOpInOneOpOutTransformOpInterface
@run
def OneOpInOneOpOutTransformOpInterface():
- class OneOpInOneOpOutTransformOpInterfaceFallbackModel(
- transform.TransformOpInterface
- ):
+ # Define an implementation of the TransformOpInterface for OneOpInOneOpOut.
+ class TransformOpInterfaceFallbackModel(transform.TransformOpInterface):
@staticmethod
def apply(
- op_: ir.Operation,
+ op: OneOpInOneOpOut,
rewriter: transform.TransformRewriter,
results: transform.TransformResults,
state: transform.TransformState,
):
- targets = state.get_payload_ops(arg := op_.operands[0])
- target_names = [t.opview.name.value for t in targets]
- print(
- f"OneOpInOneOpOutTransformOpInterfaceFallbackModel: target_names={target_names}"
- )
- results.set_ops(result := op_.results[0], targets)
+ target_ops = state.get_payload_ops(op.target)
+ target_names = [t.opview.name.value for t in target_ops]
+ print(f"OneOpInOneOpOutTransformOpInterface: target_names={target_names}")
+ results.set_ops(op.res, target_ops)
return transform.DiagnosedSilenceableFailure.Success
@staticmethod
- def allow_repeated_handle_operands(op_: ir.Operation) -> bool:
+ def allow_repeated_handle_operands(op: OneOpInOneOpOut) -> bool:
return False
- OneOpInOneOpOutTransformOpInterfaceFallbackModel.attach(
- "my_transform.one_op_in_one_op_out", ir.Context.current
- )
-
- # TransformOpInterface-implementing ops are also required to implement MemoryEffectsOpInterface.
- class MemoryEffectsOpInterfaceFallbackModel(ir.MemoryEffectsOpInterface):
- @staticmethod
- def get_effects(op_: ir.Operation, effects):
- transform.only_reads_handle(list(op_.op_operands), effects)
- transform.produces_handle(list(op_.results), effects)
+ # Attach the interface implementation to the op.
+ TransformOpInterfaceFallbackModel.attach(OneOpInOneOpOut.OPERATION_NAME)
- MemoryEffectsOpInterfaceFallbackModel.attach(
- "my_transform.one_op_in_one_op_out", ir.Context.current
- )
+ # TransformOpInterface-implementing ops are also required to implement MemoryEffectsOpInterface. The above defined fallback model works for this op.
+ MemoryEffectsOpInterfaceFallbackModel.attach(OneOpInOneOpOut.OPERATION_NAME)
with schedule_boilerplate() as (schedule, named_seq):
- print(f"{named_seq=}")
func_handle = structured.MatchOp.match_op_names(
named_seq.bodyTarget, ["func.func"]
).result
- func_handle.dump()
- # CHECK: OneOpInOneOpOutTransformOpInterfaceFallbackModel: target_names=['name_of_func']
+ # CHECK: OneOpInOneOpOutTransformOpInterface: target_names=['name_of_func']
out = OneOpInOneOpOut(func_handle).result
- out.dump()
- print(out.owner)
+ # CHECK: Output handle from OneOpInOneOpOut
+ # CHECK-NEXT: func.func @
+ transform.PrintOp(target=out, name="Output handle from OneOpInOneOpOut")
transform.YieldOp([out])
- named_seq.verify()
- print("named_seq", named_seq)
- print("named_seq.parent", named_seq.parent)
return schedule
+@ext.register_operation(MyTransform)
+class OpValParamInParamOpValOut(
+ MyTransform.Operation, name="op_val_param_in_param_op_val_out"
+):
+ # operands
+ op_arg: ext.Operand[transform.AnyOpType]
+ val_arg: ext.Operand[transform.AnyValueType]
+ param_arg: ext.Operand[transform.AnyParamType]
+ # results
+ param_res: ext.Result[transform.AnyParamType[()]]
+ op_res: ext.Result[transform.AnyOpType[()]]
+ value_res: ext.Result[transform.AnyValueType[()]]
+
+
+# CHECK-LABEL: Test: OpValParamInParamOpValOutTransformOpInterface
@run
def OpValParamInParamOpValOutTransformOpInterface():
- class OpValParamInParamOpValOutTransformOpInterfaceFallbackModel(
- transform.TransformOpInterface
- ):
+ class TransformOpInterfaceFallbackModel(transform.TransformOpInterface):
@staticmethod
def apply(
- op_: ir.Operation,
+ op: OpValParamInParamOpValOut,
rewriter: transform.TransformRewriter,
results: transform.TransformResults,
state: transform.TransformState,
):
- ops = state.get_payload_ops(op_.operands[0])
- values = state.get_payload_values(op_.operands[1])
- params = state.get_params(op_.operands[2])
+ ops = state.get_payload_ops(op.op_arg)
+ values = state.get_payload_values(op.val_arg)
+ params = state.get_params(op.param_arg)
print(
- f"OpValParamInParamOpValOutTransformOpInterfaceFallbackModel: ops={len(ops)}, values={len(values)}, params={len(params)}"
+ f"OpValParamInParamOpValOutTransformOpInterface: ops={len(ops)}, values={len(values)}, params={len(params)}"
)
- results.set_params(op_.results[0], params)
- results.set_ops(op_.results[1], ops)
- results.set_values(op_.results[2], values)
+ results.set_params(op.param_res, params)
+ results.set_ops(op.op_res, ops)
+ results.set_values(op.value_res, values)
return transform.DiagnosedSilenceableFailure.Success
@staticmethod
- def allow_repeated_handle_operands(op_: ir.Operation) -> bool:
+ def allow_repeated_handle_operands(op: OpValParamInParamOpValOut) -> bool:
return False
- OpValParamInParamOpValOutTransformOpInterfaceFallbackModel.attach(
- "my_transform.op_val_param_in_param_op_val_out", ir.Context.current
+ TransformOpInterfaceFallbackModel.attach(
+ OpValParamInParamOpValOut.OPERATION_NAME, ir.Context.current
)
- # TransformOpInterface-implementing ops are also required to implement MemoryEffectsOpInterface.
- class MemoryEffectsOpInterfaceFallbackModel(ir.MemoryEffectsOpInterface):
- @staticmethod
- def get_effects(op_: ir.Operation, effects):
- transform.only_reads_handle(list(op_.op_operands), effects)
- transform.produces_handle(list(op_.results), effects)
-
+ # TransformOpInterface-implementing ops are also required to implement MemoryEffectsOpInterface. The above defined fallback model works for this op.
MemoryEffectsOpInterfaceFallbackModel.attach(
- "my_transform.op_val_param_in_param_op_val_out", ir.Context.current
+ OpValParamInParamOpValOut.OPERATION_NAME, ir.Context.current
)
with schedule_boilerplate() as (schedule, named_seq):
@@ -414,10 +259,36 @@ def get_effects(op_: ir.Operation, effects):
AnyParamType.get(), ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 42)
).param
- # CHECK: OpValParamInParamOpValOutTransformOpInterfaceFallbackModel: ops=2, values=1, params=1
+ # CHECK: OpValParamInParamOpValOutTransformOpInterface: ops=2, values=1, params=1
op_val_param_op = OpValParamInParamOpValOut(
func_and_addf, value_handle, param_handle
)
+ # CHECK: Ops passed through OpValParamInParamOpValOut:
+ # CHECK-NEXT: func.func
+ # CHECK: arith.addf
+ transform.PrintOp(
+ target=op_val_param_op.op_res,
+ name="Ops passed through OpValParamInParamOpValOut:",
+ )
+
+ # CHECK: Ops defining values passed through OpValParamInParamOpValOut:
+ # CHECK-NEXT: arith.addf
+ addf_as_res = transform.GetDefiningOp(
+ transform.AnyOpType.get(), op_val_param_op.value_res
+ ).result
+ transform.PrintOp(
+ target=addf_as_res,
+ name="Ops defining values passed through OpValParamInParamOpValOut:",
+ )
+
+ # CHECK: Parameter passed through OpValParamInParamOpValOut:
+ # CHECK-NEXT: 42 : i32
+ PrintParamOp(
+ op_val_param_op.param_res,
+ name=ir.StringAttr.get(
+ "Parameter passed through OpValParamInParamOpValOut:"
+ ),
+ )
transform.YieldOp([op_val_param_op.op_res])
named_seq.verify()
@@ -425,74 +296,64 @@ def get_effects(op_: ir.Operation, effects):
return schedule
+@ext.register_operation(MyTransform)
+class OpsParamsInValuesParamOut(
+ MyTransform.Operation, name="ops_params_in_values_param_out"
+):
+ # operands
+ ops: Sequence[ext.Operand[transform.AnyOpType]]
+ params: Sequence[ext.Operand[transform.AnyParamType]]
+ # results
+ values: Sequence[ext.Result[transform.AnyValueType]]
+ param: ext.Result[transform.AnyParamType]
+
+
+# CHECK-LABEL: Test: OpsParamsInValuesParamOutTransformOpInterface
@run
def OpsParamsInValuesParamOutTransformOpInterface():
- class OpsParamsInValuesParamOutTransformOpInterfaceFallbackModel(
- transform.TransformOpInterface
- ):
+ class TransformOpInterfaceFallbackModel(transform.TransformOpInterface):
@staticmethod
def apply(
- op_: ir.Operation,
+ op: OpsParamsInValuesParamOut,
rewriter: transform.TransformRewriter,
results: transform.TransformResults,
state: transform.TransformState,
):
- # The last operand is the param. All previous ones are ops.
- op_handles, param_handles = [], []
- for operand in op_.operands:
- if isinstance(operand.type, transform.AnyOpType):
- op_handles.append(operand)
- else:
- param_handles.append(operand)
-
ops_count = 0
value_handles = []
- for op_handle in op_handles:
+ for op_handle in op.ops:
ops = state.get_payload_ops(op_handle)
ops_count += len(ops)
- value_handles.append(list(op.results[:1] for op in ops))
+ value_handles.append([i for op in ops for i in op.results])
param_count = 0
param_sum = 0
- for param_handle in param_handles:
+ for param_handle in op.params:
params = state.get_params(param_handle)
param_count += len(params)
param_sum += sum(p.value for p in params)
print(
- f"OpsParamInValuesParamOutTransformOpInterfaceFallbackModel: #op_handles={len(op_handles)}, ops_count={ops_count}, #param_handles={len(param_handles)}, param_count={param_count}"
+ f"OpsParamsInValuesParamOutTransformOpInterfaceFallbackModel: op_count={ops_count}, param_count={param_count}"
)
- assert len(op_.results) + 1 == len(op_handles)
- for i in range(len(op_.results) - 1):
- results.set_values(
- op_.results[i],
- value_handles[i],
- )
+ assert len(op.values) == len(op.ops)
+ for value_res_handle, value_vector in zip(op.values, value_handles):
+ results.set_values(value_res_handle, value_vector)
results.set_params(
- op_.results[-1],
+ op.param,
[ir.IntegerAttr.get(ir.IntegerType.get_signless(32), param_sum)],
)
return transform.DiagnosedSilenceableFailure.Success
@staticmethod
- def allow_repeated_handle_operands(op_: ir.Operation) -> bool:
+ def allow_repeated_handle_operands(op: OpsParamsInValuesParamOut) -> bool:
return True
- OpsParamsInValuesParamOutTransformOpInterfaceFallbackModel.attach(
- "my_transform." + OpsParamsInValuesParamOut.name
- )
+ TransformOpInterfaceFallbackModel.attach(OpsParamsInValuesParamOut.OPERATION_NAME)
- class OpsParamsInParamsOutMemoryEffectsOpInterfaceFallbackModel(
- ir.MemoryEffectsOpInterface
- ):
- @staticmethod
- def get_effects(op_: ir.Operation, effects):
- transform.only_reads_handle(list(op_.op_operands), effects)
- transform.produces_handle(list(op_.results), effects)
-
- OpsParamsInParamsOutMemoryEffectsOpInterfaceFallbackModel.attach(
- "my_transform." + OpsParamsInValuesParamOut.name
+ MemoryEffectsOpInterfaceFallbackModel.attach(
+ OpsParamsInValuesParamOut.OPERATION_NAME
)
with schedule_boilerplate() as (schedule, named_seq):
@@ -502,24 +363,48 @@ def get_effects(op_: ir.Operation, effects):
csts_handle = structured.MatchOp.match_op_names(
named_seq.bodyTarget, ["arith.constant"]
).result
- csts_as_param = GetNamedAttributeOp(csts_handle, "value").attr_as_param
+ csts_as_param = GetNamedAttributeOp(
+ csts_handle, attr_name=ir.StringAttr.get("value")
+ ).attr_as_param
param_handle = transform.ParamConstantOp(
AnyParamType.get(), ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 123)
).param
- # CHECK: OpsParamInParamsOutTransformOpInterfaceFallbackModel: op_count=2, param_count=1
+ # CHECK: OpsParamsInValuesParamOutTransformOpInterfaceFallbackModel: op_count=3, param_count=3
op = OpsParamsInValuesParamOut(
- [transform.AnyValueType.get()] * 2 + [transform.AnyParamType.get()],
+ [transform.AnyValueType.get()] * 2,
+ transform.AnyParamType.get(),
[func_handle, csts_handle],
[csts_as_param, param_handle],
)
- print(op.op)
- # CHECK: Sum of params: 189
- transform_debug.EmitParamAsRemarkOp(op.param, message="Sum of params")
- transform_debug.EmitRemarkAtOp(op.values[0], message="Value results 0")
- transform_debug.EmitRemarkAtOp(op.values[1], message="Value results 1")
+ empty_handle = transform.GetDefiningOp(transform.AnyOpType.get(), op.values[0])
+ # CHECK: Defining op of value result 0
+ transform.PrintOp(
+ target=empty_handle.result, name="Defining op of value result 0"
+ )
+ # NB: no result on the func.func, so output is expected to be empty
+ cst1_res, cst2_res = transform.SplitHandleOp(
+ [transform.AnyValueType.get()] * 2, op.values[1]
+ ).results
+
+ cst1_again = transform.GetDefiningOp(transform.AnyOpType.get(), cst1_res)
+ # CHECK-NEXT: Defining op of first constant
+ # CHECK-NEXT: arith.constant 42 : i32
+ transform.PrintOp(
+ target=cst1_again.result, name="Defining op of first constant"
+ )
+ cst2_again = transform.GetDefiningOp(transform.AnyOpType.get(), cst2_res)
+ # CHECK-NEXT: Defining op of second constant
+ # CHECK-NEXT: arith.constant 24 : i32
+ transform.PrintOp(
+ target=cst2_again.result, name="Defining op of second constant"
+ )
+
+ # CHECK: Sum of params:
+ # CHECK-NEXT: 189 : i32
+ PrintParamOp(op.param, name=ir.StringAttr.get("Sum of params:"))
transform.YieldOp([func_handle])
named_seq.verify()
>From d4efd172b3fb7ac483dc60234e00b9afd32f05be Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Tue, 27 Jan 2026 06:48:10 -0800
Subject: [PATCH 5/9] New rewriter test and general cleanup
---
mlir/include/mlir-c/Dialect/Transform.h | 5 +
mlir/lib/Bindings/Python/DialectTransform.cpp | 23 ++-
mlir/lib/CAPI/Dialect/Transform.cpp | 7 +
.../python/dialects/transform_op_interface.py | 131 +++++++++++++-----
4 files changed, 126 insertions(+), 40 deletions(-)
diff --git a/mlir/include/mlir-c/Dialect/Transform.h b/mlir/include/mlir-c/Dialect/Transform.h
index b087482345557..674d5f8b7b72d 100644
--- a/mlir/include/mlir-c/Dialect/Transform.h
+++ b/mlir/include/mlir-c/Dialect/Transform.h
@@ -216,6 +216,11 @@ MLIR_CAPI_EXPORTED void
mlirTransformOnlyReadsHandle(MlirOpOperand *operands, intptr_t numOperands,
MlirMemoryEffectInstancesList effects);
+/// Helper to mark operands as consuming handles.
+MLIR_CAPI_EXPORTED void
+mlirTransformConsumesHandle(MlirOpOperand *operands, intptr_t numOperands,
+ MlirMemoryEffectInstancesList effects);
+
/// Helper to mark results as producing handles.
MLIR_CAPI_EXPORTED void
mlirTransformProducesHandle(MlirOpResult *results, intptr_t numResults,
diff --git a/mlir/lib/Bindings/Python/DialectTransform.cpp b/mlir/lib/Bindings/Python/DialectTransform.cpp
index 3756518eeb06a..16a025a0581e5 100644
--- a/mlir/lib/Bindings/Python/DialectTransform.cpp
+++ b/mlir/lib/Bindings/Python/DialectTransform.cpp
@@ -106,7 +106,10 @@ class PyTransformState {
mlirTransformStateForEachPayloadOp(
state, value,
[](MlirOperation op, void *userData) {
- static_cast<nanobind::list *>(userData)->append(op);
+ PyMlirContextRef context =
+ PyMlirContext::forContext(mlirOperationGetContext(op));
+ auto opview = PyOperation::forOperation(context, op)->createOpView();
+ static_cast<nanobind::list *>(userData)->append(opview);
},
&result);
return result;
@@ -379,18 +382,24 @@ struct ParamType : PyConcreteType<ParamType> {
//===----------------------------------------------------------------------===//
namespace {
-void onlyReadsHandle(nb::list &operands, PyMemoryEffectsInstanceList effects) {
+void onlyReadsHandle(nb::iterable &operands, PyMemoryEffectsInstanceList effects) {
std::vector<MlirOpOperand> operandsVec;
- operandsVec.reserve(operands.size());
for (auto operand : operands)
operandsVec.push_back(nb::cast<PyOpOperand>(operand));
mlirTransformOnlyReadsHandle(operandsVec.data(), operandsVec.size(),
effects.effects);
};
-void producesHandle(nb::list &results, PyMemoryEffectsInstanceList effects) {
+void consumesHandle(nb::iterable &operands, PyMemoryEffectsInstanceList effects) {
+ std::vector<MlirOpOperand> operandsVec;
+ for (auto operand : operands)
+ operandsVec.push_back(nb::cast<PyOpOperand>(operand));
+ mlirTransformConsumesHandle(operandsVec.data(), operandsVec.size(),
+ effects.effects);
+};
+
+void producesHandle(nb::iterable &results, PyMemoryEffectsInstanceList effects) {
std::vector<MlirOpResult> resultsVec;
- resultsVec.reserve(results.size());
for (auto result : results)
resultsVec.push_back(nb::cast<PyOpResult>(result));
mlirTransformProducesHandle(resultsVec.data(), resultsVec.size(),
@@ -420,6 +429,10 @@ static void populateDialectTransformSubmodule(nb::module_ &m) {
"Mark operands as only reading handles.", nb::arg("operands"),
nb::arg("effects"));
+ m.def("consumes_handle", consumesHandle,
+ "Mark operands as consuming handles.", nb::arg("operands"),
+ nb::arg("effects"));
+
m.def("produces_handle", producesHandle, "Mark results as producing handles.",
nb::arg("results"), nb::arg("effects"));
}
diff --git a/mlir/lib/CAPI/Dialect/Transform.cpp b/mlir/lib/CAPI/Dialect/Transform.cpp
index fdb27ec5cdc89..121de43c1d065 100644
--- a/mlir/lib/CAPI/Dialect/Transform.cpp
+++ b/mlir/lib/CAPI/Dialect/Transform.cpp
@@ -312,6 +312,13 @@ void mlirTransformOnlyReadsHandle(MlirOpOperand *operands, intptr_t numOperands,
transform::onlyReadsHandle(operandArray, *unwrap(effects));
}
+/// Set the effect for the operands to consuming the transform handles.
+void mlirTransformConsumesHandle(MlirOpOperand *operands, intptr_t numOperands,
+ MlirMemoryEffectInstancesList effects) {
+ MutableArrayRef<OpOperand> operandArray(unwrap(*operands), numOperands);
+ transform::consumesHandle(operandArray, *unwrap(effects));
+}
+
/// Set the effect for the results to that they produce transform handles.
void mlirTransformProducesHandle(MlirOpResult *results, intptr_t numResults,
MlirMemoryEffectInstancesList effects) {
diff --git a/mlir/test/python/dialects/transform_op_interface.py b/mlir/test/python/dialects/transform_op_interface.py
index ab919911bdd31..b2bc97ca3daeb 100644
--- a/mlir/test/python/dialects/transform_op_interface.py
+++ b/mlir/test/python/dialects/transform_op_interface.py
@@ -5,9 +5,9 @@
from contextlib import contextmanager
from mlir import ir
-from mlir.ir import F32Type, UnitAttr
-from mlir.dialects import transform, func, arith, ext
+from mlir.dialects import index, transform, func, arith, ext
from mlir.dialects.transform import (
+ DiagnosedSilenceableFailure,
AnyOpType,
AnyValueType,
AnyParamType,
@@ -15,6 +15,7 @@
interpreter,
)
+
@ext.register_dialect
class MyTransform(ext.Dialect, name="my_transform"):
pass
@@ -36,7 +37,7 @@ def run(emit_schedule):
interpreter.apply_named_sequence(
payload,
- named_seq := schedule.operation.regions[0].blocks[0].operations[0],
+ _named_seq := schedule.operation.regions[0].blocks[0].operations[0],
schedule,
)
@@ -45,13 +46,14 @@ def run(emit_schedule):
def emit_payload():
payload_module = ir.Module.create()
with ir.InsertionPoint(payload_module.body):
+ f32 = ir.F32Type.get()
- @func.FuncOp.from_py_func(F32Type.get(), F32Type.get(), results=[F32Type.get()])
+ @func.FuncOp.from_py_func(f32, f32, results=[f32])
def name_of_func(a, b):
c = arith.addf(a, b)
i32 = ir.IntegerType.get_signless(32)
- c42 = arith.constant(i32, 42)
- c24 = arith.constant(i32, 24)
+ arith.constant(i32, 42)
+ arith.constant(i32, 24)
func.ReturnOp([c])
return payload_module
@@ -66,19 +68,19 @@ def schedule_boilerplate():
"__transform_main",
[AnyOpType.get()],
[AnyOpType.get()],
- arg_attrs=[{"transform.consumed": UnitAttr.get()}],
+ arg_attrs=[{"transform.consumed": ir.UnitAttr.get()}],
)
with ir.InsertionPoint(named_sequence.body):
yield schedule, named_sequence
# MemoryEffectsOpInterface implementation for TransformOpInterface-implementing ops.
-# Used by all ops defined below.
+# Used by most ops defined below.
class MemoryEffectsOpInterfaceFallbackModel(ir.MemoryEffectsOpInterface):
@staticmethod
- def get_effects(op_: ir.Operation, effects):
- transform.only_reads_handle(list(op_.op_operands), effects)
- transform.produces_handle(list(op_.results), effects)
+ def get_effects(op: ir.Operation, effects):
+ transform.only_reads_handle(op.op_operands, effects)
+ transform.produces_handle(op.results, effects)
# Demonstration of a TransformOpInterface-implementing op that gets named attributes
@@ -98,22 +100,22 @@ class TransformOpInterfaceFallbackModel(transform.TransformOpInterface):
@staticmethod
def apply(
op: "GetNamedAttributeOp",
- rewriter: transform.TransformRewriter,
+ _rewriter: transform.TransformRewriter,
results: transform.TransformResults,
state: transform.TransformState,
- ):
+ ) -> DiagnosedSilenceableFailure:
target_ops = state.get_payload_ops(op.target)
associated_attrs = []
for target_op in target_ops:
assoc_attr = target_op.attributes.get(op.attr_name.value)
if assoc_attr is None:
- return transform.DiagnosedSilenceableFailure.RecoverableFailure
+ return DiagnosedSilenceableFailure.RecoverableFailure
associated_attrs.append(assoc_attr)
results.set_params(op.attr_as_param, associated_attrs)
- return transform.DiagnosedSilenceableFailure.Success
+ return DiagnosedSilenceableFailure.Success
@staticmethod
- def allow_repeated_handle_operands(op: "GetNamedAttributeOp") -> bool:
+ def allow_repeated_handle_operands(_op: "GetNamedAttributeOp") -> bool:
return False
@@ -134,15 +136,15 @@ def apply(
rewriter: transform.TransformRewriter,
results: transform.TransformResults,
state: transform.TransformState,
- ):
+ ) -> DiagnosedSilenceableFailure:
target_attrs = state.get_params(op.target)
print(f"[[[ IR printer: {op.name.value} ]]]")
for attr in target_attrs:
print(attr)
- return transform.DiagnosedSilenceableFailure.Success
+ return DiagnosedSilenceableFailure.Success
@staticmethod
- def allow_repeated_handle_operands(op: "GetNamedAttributeOp") -> bool:
+ def allow_repeated_handle_operands(_op: "GetNamedAttributeOp") -> bool:
return False
@@ -156,23 +158,23 @@ class OneOpInOneOpOut(MyTransform.Operation, name="one_op_in_one_op_out"):
# CHECK-LABEL: Test: OneOpInOneOpOutTransformOpInterface
@run
def OneOpInOneOpOutTransformOpInterface():
- # Define an implementation of the TransformOpInterface for OneOpInOneOpOut.
+ # Define a simple passthrough implementation of the TransformOpInterface for OneOpInOneOpOut.
class TransformOpInterfaceFallbackModel(transform.TransformOpInterface):
@staticmethod
def apply(
op: OneOpInOneOpOut,
- rewriter: transform.TransformRewriter,
+ _rewriter: transform.TransformRewriter,
results: transform.TransformResults,
state: transform.TransformState,
- ):
+ ) -> DiagnosedSilenceableFailure:
target_ops = state.get_payload_ops(op.target)
- target_names = [t.opview.name.value for t in target_ops]
+ target_names = [t.name.value for t in target_ops]
print(f"OneOpInOneOpOutTransformOpInterface: target_names={target_names}")
results.set_ops(op.res, target_ops)
- return transform.DiagnosedSilenceableFailure.Success
+ return DiagnosedSilenceableFailure.Success
@staticmethod
- def allow_repeated_handle_operands(op: OneOpInOneOpOut) -> bool:
+ def allow_repeated_handle_operands(_op: OneOpInOneOpOut) -> bool:
return False
# Attach the interface implementation to the op.
@@ -188,13 +190,72 @@ def allow_repeated_handle_operands(op: OneOpInOneOpOut) -> bool:
# CHECK: OneOpInOneOpOutTransformOpInterface: target_names=['name_of_func']
out = OneOpInOneOpOut(func_handle).result
# CHECK: Output handle from OneOpInOneOpOut
- # CHECK-NEXT: func.func @
+ # CHECK-NEXT: func.func @name_of_func
transform.PrintOp(target=out, name="Output handle from OneOpInOneOpOut")
transform.YieldOp([out])
return schedule
+# CHECK-LABEL: Test: OneOpInOneOpOutTransformOpInterfaceRewriterImpl
+@run
+def OneOpInOneOpOutTransformOpInterfaceRewriterImpl():
+ # Define an implementation of the TransformOpInterface for OneOpInOneOpOut where
+ # the rewriter is used (to replace arith.constants by index.constants).
+ class TransformOpInterfaceFallbackModel(transform.TransformOpInterface):
+ @staticmethod
+ def apply(
+ op: OneOpInOneOpOut,
+ rewriter: transform.TransformRewriter,
+ results: transform.TransformResults,
+ state: transform.TransformState,
+ ) -> DiagnosedSilenceableFailure:
+ result_ops = []
+ for target_op in state.get_payload_ops(op.target):
+ with ir.InsertionPoint(target_op):
+ index_version = index.constant(target_op.value.value)
+ result_ops.append(index_version.owner)
+ rewriter.replace_op(target_op, [index_version])
+ results.set_ops(op.res, result_ops)
+ return DiagnosedSilenceableFailure.Success
+
+ @staticmethod
+ def allow_repeated_handle_operands(_op: OneOpInOneOpOut) -> bool:
+ return False
+
+ # Attach the interface implementation to the op.
+ TransformOpInterfaceFallbackModel.attach(OneOpInOneOpOut.OPERATION_NAME)
+
+ # TransformOpInterface-implementing ops are also required to implement MemoryEffectsOpInterface. The above defined fallback model works for this op.
+ MemoryEffectsOpInterfaceFallbackModel.attach(OneOpInOneOpOut.OPERATION_NAME)
+
+ with schedule_boilerplate() as (schedule, named_seq):
+ func_handle = structured.MatchOp.match_op_names(
+ named_seq.bodyTarget, ["func.func"]
+ ).result
+ csts_handle = structured.MatchOp.match_op_names(
+ named_seq.bodyTarget, ["arith.constant"]
+ ).result
+ # CHECK: Before replacement:
+ # CHECK-NOT: index.constant
+ # CHECK-DAG: arith.constant 42 : i32
+ # CHECK-DAG: arith.constant 24 : i32
+ transform.PrintOp(target=func_handle, name="Before replacement:")
+ out = OneOpInOneOpOut(csts_handle).result
+ # CHECK: After replacement:
+ # CHECK-NOT: arith.constant
+ # CHECK-DAG: index.constant 42
+ # CHECK-DAG: index.constant 24
+ transform.PrintOp(target=func_handle, name="After replacement:")
+ # CHECK: Output handle from OneOpInOneOpOut:
+ # CHECK-NEXT: index.constant 42
+ # CHECK-NEXT: index.constant 24
+ transform.PrintOp(target=out, name="Output handle from OneOpInOneOpOut:")
+ transform.YieldOp([out])
+
+ return schedule
+
+
@ext.register_operation(MyTransform)
class OpValParamInParamOpValOut(
MyTransform.Operation, name="op_val_param_in_param_op_val_out"
@@ -216,10 +277,10 @@ class TransformOpInterfaceFallbackModel(transform.TransformOpInterface):
@staticmethod
def apply(
op: OpValParamInParamOpValOut,
- rewriter: transform.TransformRewriter,
+ _rewriter: transform.TransformRewriter,
results: transform.TransformResults,
state: transform.TransformState,
- ):
+ ) -> DiagnosedSilenceableFailure:
ops = state.get_payload_ops(op.op_arg)
values = state.get_payload_values(op.val_arg)
params = state.get_params(op.param_arg)
@@ -229,10 +290,10 @@ def apply(
results.set_params(op.param_res, params)
results.set_ops(op.op_res, ops)
results.set_values(op.value_res, values)
- return transform.DiagnosedSilenceableFailure.Success
+ return DiagnosedSilenceableFailure.Success
@staticmethod
- def allow_repeated_handle_operands(op: OpValParamInParamOpValOut) -> bool:
+ def allow_repeated_handle_operands(_op: OpValParamInParamOpValOut) -> bool:
return False
TransformOpInterfaceFallbackModel.attach(
@@ -315,10 +376,10 @@ class TransformOpInterfaceFallbackModel(transform.TransformOpInterface):
@staticmethod
def apply(
op: OpsParamsInValuesParamOut,
- rewriter: transform.TransformRewriter,
+ _rewriter: transform.TransformRewriter,
results: transform.TransformResults,
state: transform.TransformState,
- ):
+ ) -> DiagnosedSilenceableFailure:
ops_count = 0
value_handles = []
for op_handle in op.ops:
@@ -344,11 +405,11 @@ def apply(
op.param,
[ir.IntegerAttr.get(ir.IntegerType.get_signless(32), param_sum)],
)
- return transform.DiagnosedSilenceableFailure.Success
+ return DiagnosedSilenceableFailure.Success
@staticmethod
- def allow_repeated_handle_operands(op: OpsParamsInValuesParamOut) -> bool:
- return True
+ def allow_repeated_handle_operands(_op: OpsParamsInValuesParamOut) -> bool:
+ return False
TransformOpInterfaceFallbackModel.attach(OpsParamsInValuesParamOut.OPERATION_NAME)
>From 40f692331d9a5cc7e89e49288239c0763583ecb7 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Tue, 27 Jan 2026 10:58:38 -0800
Subject: [PATCH 6/9] Add docstrings to tests
---
.../python/dialects/transform_op_interface.py | 25 +++++++++++++++++--
1 file changed, 23 insertions(+), 2 deletions(-)
diff --git a/mlir/test/python/dialects/transform_op_interface.py b/mlir/test/python/dialects/transform_op_interface.py
index b2bc97ca3daeb..c707883866a03 100644
--- a/mlir/test/python/dialects/transform_op_interface.py
+++ b/mlir/test/python/dialects/transform_op_interface.py
@@ -158,6 +158,11 @@ class OneOpInOneOpOut(MyTransform.Operation, name="one_op_in_one_op_out"):
# CHECK-LABEL: Test: OneOpInOneOpOutTransformOpInterface
@run
def OneOpInOneOpOutTransformOpInterface():
+ """Tests a simple passthrough interface implementation.
+
+ Checks that the target ops are correctly identified and passed as results.
+ """
+
# Define a simple passthrough implementation of the TransformOpInterface for OneOpInOneOpOut.
class TransformOpInterfaceFallbackModel(transform.TransformOpInterface):
@staticmethod
@@ -200,8 +205,12 @@ def allow_repeated_handle_operands(_op: OneOpInOneOpOut) -> bool:
# CHECK-LABEL: Test: OneOpInOneOpOutTransformOpInterfaceRewriterImpl
@run
def OneOpInOneOpOutTransformOpInterfaceRewriterImpl():
- # Define an implementation of the TransformOpInterface for OneOpInOneOpOut where
- # the rewriter is used (to replace arith.constants by index.constants).
+ """Tests an interface implementation using the rewriter to modify the IR.
+
+ Checks that `arith.constant` ops are replaced by `index.constant` ops and
+ that the results are correctly updated.
+ """
+
class TransformOpInterfaceFallbackModel(transform.TransformOpInterface):
@staticmethod
def apply(
@@ -273,6 +282,12 @@ class OpValParamInParamOpValOut(
# CHECK-LABEL: Test: OpValParamInParamOpValOutTransformOpInterface
@run
def OpValParamInParamOpValOutTransformOpInterface():
+ """Tests an interface implementation involving Op, Value, and Param types.
+
+ Checks that payload ops, values, and parameters are correctly permuted and
+ propagated and accessible from the (permuted) result handles.
+ """
+
class TransformOpInterfaceFallbackModel(transform.TransformOpInterface):
@staticmethod
def apply(
@@ -372,6 +387,12 @@ class OpsParamsInValuesParamOut(
# CHECK-LABEL: Test: OpsParamsInValuesParamOutTransformOpInterface
@run
def OpsParamsInValuesParamOutTransformOpInterface():
+ """Tests an interface with variadic Op and Param operands and variadic Value results.
+
+ Checks correct handling of multiple handles, parameter aggregation, and
+ result generation.
+ """
+
class TransformOpInterfaceFallbackModel(transform.TransformOpInterface):
@staticmethod
def apply(
>From e5eed4f2df6028e1ca2ac91363f8c8869f0f17ad Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Tue, 27 Jan 2026 12:04:50 -0800
Subject: [PATCH 7/9] Formatting fixes; make ext.py `load` return early if already loaded
---
mlir/lib/Bindings/Python/DialectTransform.cpp | 9 ++++++---
mlir/python/mlir/dialects/ext.py | 12 +++---------
2 files changed, 9 insertions(+), 12 deletions(-)
diff --git a/mlir/lib/Bindings/Python/DialectTransform.cpp b/mlir/lib/Bindings/Python/DialectTransform.cpp
index 16a025a0581e5..ecb911fd7a3ba 100644
--- a/mlir/lib/Bindings/Python/DialectTransform.cpp
+++ b/mlir/lib/Bindings/Python/DialectTransform.cpp
@@ -382,7 +382,8 @@ struct ParamType : PyConcreteType<ParamType> {
//===----------------------------------------------------------------------===//
namespace {
-void onlyReadsHandle(nb::iterable &operands, PyMemoryEffectsInstanceList effects) {
+void onlyReadsHandle(nb::iterable &operands,
+ PyMemoryEffectsInstanceList effects) {
std::vector<MlirOpOperand> operandsVec;
for (auto operand : operands)
operandsVec.push_back(nb::cast<PyOpOperand>(operand));
@@ -390,7 +391,8 @@ void onlyReadsHandle(nb::iterable &operands, PyMemoryEffectsInstanceList effects
effects.effects);
};
-void consumesHandle(nb::iterable &operands, PyMemoryEffectsInstanceList effects) {
+void consumesHandle(nb::iterable &operands,
+ PyMemoryEffectsInstanceList effects) {
std::vector<MlirOpOperand> operandsVec;
for (auto operand : operands)
operandsVec.push_back(nb::cast<PyOpOperand>(operand));
@@ -398,7 +400,8 @@ void consumesHandle(nb::iterable &operands, PyMemoryEffectsInstanceList effects)
effects.effects);
};
-void producesHandle(nb::iterable &results, PyMemoryEffectsInstanceList effects) {
+void producesHandle(nb::iterable &results,
+ PyMemoryEffectsInstanceList effects) {
std::vector<MlirOpResult> resultsVec;
for (auto result : results)
resultsVec.push_back(nb::cast<PyOpResult>(result));
diff --git a/mlir/python/mlir/dialects/ext.py b/mlir/python/mlir/dialects/ext.py
index 3fff63dd0578b..f60dd62c5e6c8 100644
--- a/mlir/python/mlir/dialects/ext.py
+++ b/mlir/python/mlir/dialects/ext.py
@@ -432,7 +432,6 @@ class ExtOperation(Operation):
def __init__(*args, **kwargs):
raise RuntimeError("Cannot instantiate Dialect.ExtOperation directly.")
-
@classmethod
def __init_subclass__(cls, name: str, **kwargs):
cls.name = name
@@ -460,14 +459,9 @@ def _emit_module(cls) -> ir.Module:
return m
@classmethod
- def load(cls, register=True, context: Optional[ir.Context] = None) -> None:
- context = context or ir.Context.current
-
- try:
- context.dialects[cls.name]
- raise RuntimeError(f"Dialect {cls.name} is already loaded.")
- except IndexError:
- pass # Dialect not loaded yet.
+ def load(cls, register=True) -> None:
+ if hasattr(cls, "_mlir_module"):
+ return
cls._mlir_module = cls._emit_module()
pm = PassManager()
>From 7372de4b797fc4db0a7190d90b60fc9c1b396635 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Tue, 27 Jan 2026 12:41:46 -0800
Subject: [PATCH 8/9] ext.py: add `reload` option to `load` and use it in tests
---
mlir/python/mlir/dialects/ext.py | 4 ++--
mlir/test/python/dialects/transform_op_interface.py | 2 +-
2 files changed, 3 insertions(+), 3 deletions(-)
diff --git a/mlir/python/mlir/dialects/ext.py b/mlir/python/mlir/dialects/ext.py
index f60dd62c5e6c8..45a7249399cca 100644
--- a/mlir/python/mlir/dialects/ext.py
+++ b/mlir/python/mlir/dialects/ext.py
@@ -459,8 +459,8 @@ def _emit_module(cls) -> ir.Module:
return m
@classmethod
- def load(cls, register=True) -> None:
- if hasattr(cls, "_mlir_module"):
+ def load(cls, register=True, reload=False) -> None:
+ if hasattr(cls, "_mlir_module") and not reload:
return
cls._mlir_module = cls._emit_module()
diff --git a/mlir/test/python/dialects/transform_op_interface.py b/mlir/test/python/dialects/transform_op_interface.py
index c707883866a03..7818ce919d2f3 100644
--- a/mlir/test/python/dialects/transform_op_interface.py
+++ b/mlir/test/python/dialects/transform_op_interface.py
@@ -26,7 +26,7 @@ def run(emit_schedule):
with ir.Context() as ctx, ir.Location.unknown():
payload = emit_payload()
- MyTransform.load(register=False)
+ MyTransform.load(register=False, reload=True)
GetNamedAttributeOp.attach_interface_impls(ctx)
PrintParamOp.attach_interface_impls(ctx)
>From 97316be8b62015b7a5ff50c24007e8f5f90d1066 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Tue, 27 Jan 2026 12:55:00 -0800
Subject: [PATCH 9/9] Pass an OpView to the Python MemoryEffectsOpInterface `get_effects` callback
---
mlir/lib/Bindings/Python/IRInterfaces.cpp | 6 +++++-
1 file changed, 5 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Bindings/Python/IRInterfaces.cpp b/mlir/lib/Bindings/Python/IRInterfaces.cpp
index 05633746c3136..538ec3e62f231 100644
--- a/mlir/lib/Bindings/Python/IRInterfaces.cpp
+++ b/mlir/lib/Bindings/Python/IRInterfaces.cpp
@@ -373,8 +373,12 @@ class PyMemoryEffectsOpInterface
PyMemoryEffectsInstanceList effectsWrapper{effects};
+ PyMlirContextRef context =
+ PyMlirContext::forContext(mlirOperationGetContext(op));
+ auto opview = PyOperation::forOperation(context, op)->createOpView();
+
// Invoke `pyClass.get_effects(op, effects)`.
- pyGetEffects(op, effectsWrapper);
+ pyGetEffects(opview, effectsWrapper);
};
mlirMemoryEffectsOpInterfaceAttachFallbackModel(
More information about the Mlir-commits
mailing list