[Mlir-commits] [mlir] [MLIR][Python] Expose TransformOpInterface and MemoryEffectsOpInterfaces (PR #176920)
Rolf Morel
llvmlistbot at llvm.org
Tue Jan 20 04:58:39 PST 2026
https://github.com/rolfmorel created https://github.com/llvm/llvm-project/pull/176920
Enables implementing these two interfaces from Python by defining the relevant methods on a Python class. These methods serve as callbacks for a new C++ FallbackModel class, which acts as a "trampoline" into Python whenever the interface's methods are called from C++. As in the C++ codebase, these FallbackModels are a late-binding mechanism: they can be attached to an operation after its definition.
>From 12c733af13bf13fce8734dd800b80409c1bf9889 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Wed, 7 Jan 2026 12:05:58 -0800
Subject: [PATCH] [MLIR][Python] Expose TransformOpInterface and
MemoryEffectsOpInterfaces
Enables implementing these two interfaces from Python by defining the
relevant methods on a Python class. These methods serve as callbacks for
a new C++ FallbackModel class, which acts as a "trampoline" into Python
whenever the interface's methods are called from C++. As in the C++
codebase, these FallbackModels are a late-binding mechanism: they can be
attached to an operation after its definition.
---
mlir/include/mlir-c/Dialect/Transform.h | 135 ++++++++++
mlir/include/mlir-c/IR.h | 8 +
mlir/include/mlir-c/Interfaces.h | 39 ++-
mlir/include/mlir/Bindings/Python/IRCore.h | 11 +-
mlir/include/mlir/CAPI/Dialect/Transform.h | 28 ++
mlir/include/mlir/CAPI/IR.h | 1 +
mlir/include/mlir/CAPI/Interfaces.h | 8 +
mlir/lib/Bindings/Python/DialectTransform.cpp | 250 ++++++++++++++++-
mlir/lib/Bindings/Python/IRCore.cpp | 53 +++-
mlir/lib/Bindings/Python/IRInterfaces.cpp | 181 ++++---------
mlir/lib/Bindings/Python/IRInterfaces.h | 156 +++++++++++
mlir/lib/Bindings/Python/Rewrite.cpp | 64 +----
mlir/lib/Bindings/Python/Rewrite.h | 71 ++++-
mlir/lib/CAPI/Dialect/Transform.cpp | 196 ++++++++++++++
mlir/lib/CAPI/IR/IR.cpp | 12 +
mlir/lib/CAPI/Interfaces/Interfaces.cpp | 74 +++++
mlir/lib/IR/OperationSupport.cpp | 2 +-
mlir/python/CMakeLists.txt | 2 +
mlir/python/mlir/_mlir_libs/__init__.py | 1 +
.../python/dialects/transform_op_interface.py | 252 ++++++++++++++++++
20 files changed, 1344 insertions(+), 200 deletions(-)
create mode 100644 mlir/include/mlir/CAPI/Dialect/Transform.h
create mode 100644 mlir/lib/Bindings/Python/IRInterfaces.h
create mode 100644 mlir/test/python/dialects/transform_op_interface.py
diff --git a/mlir/include/mlir-c/Dialect/Transform.h b/mlir/include/mlir-c/Dialect/Transform.h
index 911c9ef659a1e..38bd9756176a0 100644
--- a/mlir/include/mlir-c/Dialect/Transform.h
+++ b/mlir/include/mlir-c/Dialect/Transform.h
@@ -10,7 +10,9 @@
#ifndef MLIR_C_DIALECT_TRANSFORM_H
#define MLIR_C_DIALECT_TRANSFORM_H
+#include "mlir-c/Interfaces.h"
#include "mlir-c/IR.h"
+#include "mlir-c/Rewrite.h"
#include "mlir-c/Support.h"
#ifdef __cplusplus
@@ -19,6 +21,32 @@ extern "C" {
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Transform, transform);
+#define DEFINE_C_API_STRUCT(name, storage) \
+ struct name { \
+ storage *ptr; \
+ }; \
+ typedef struct name name
+/// Opaque handles wrapping transform::TransformResults/Rewriter/State.
+DEFINE_C_API_STRUCT(MlirTransformResults, void);
+DEFINE_C_API_STRUCT(MlirTransformRewriter, void);
+DEFINE_C_API_STRUCT(MlirTransformState, void);
+
+#undef DEFINE_C_API_STRUCT
+
+//===---------------------------------------------------------------------===//
+// DiagnosedSilenceableFailure
+//===---------------------------------------------------------------------===//
+
+/// Result of a transform op's `apply` callback (see the callbacks below).
+typedef enum {
+ /// The operation succeeded.
+ MlirDiagnosedSilenceableFailureSuccess,
+ /// The operation failed in a silenceable way.
+ MlirDiagnosedSilenceableFailureSilenceableFailure,
+ /// The operation failed definitively.
+ MlirDiagnosedSilenceableFailureDefiniteFailure
+} MlirDiagnosedSilenceableFailure;
+
//===---------------------------------------------------------------------===//
// AnyOpType
//===---------------------------------------------------------------------===//
@@ -86,6 +114,113 @@ MLIR_CAPI_EXPORTED MlirStringRef mlirTransformParamTypeGetName(void);
MLIR_CAPI_EXPORTED MlirType mlirTransformParamTypeGetType(MlirType type);
+//===---------------------------------------------------------------------===//
+// TransformRewriter
+//===---------------------------------------------------------------------===//
+
+/// Returns the given TransformRewriter viewed as a generic RewriterBase.
+MLIR_CAPI_EXPORTED MlirRewriterBase
+mlirTransformRewriterAsBase(MlirTransformRewriter rewriter);
+
+//===---------------------------------------------------------------------===//
+// TransformResults
+//===---------------------------------------------------------------------===//
+
+/// Set the payload operations of transform result `result` to ops[0..numOps).
+MLIR_CAPI_EXPORTED void mlirTransformResultsSetOps(MlirTransformResults results,
+ MlirValue result,
+ intptr_t numOps,
+ MlirOperation *ops);
+
+/// Set the payload values of `result` to values[0..numValues).
+MLIR_CAPI_EXPORTED void
+mlirTransformResultsSetValues(MlirTransformResults results, MlirValue result,
+ intptr_t numValues, MlirValue *values);
+
+/// Set the parameters of `result` to params[0..numParams).
+MLIR_CAPI_EXPORTED void
+mlirTransformResultsSetParams(MlirTransformResults results, MlirValue result,
+ intptr_t numParams, MlirAttribute *params);
+
+//===---------------------------------------------------------------------===//
+// TransformState
+//===---------------------------------------------------------------------===//
+
+/// Callback for iterating over payload operations.
+typedef void (*MlirOperationCallback)(MlirOperation, void *userData);
+
+/// Iterate over payload operations associated with the transform IR value.
+/// Calls the callback for each payload operation.
+MLIR_CAPI_EXPORTED void
+mlirTransformStateForEachPayloadOp(MlirTransformState state, MlirValue value,
+ MlirOperationCallback callback,
+ void *userData);
+
+/// Callback for iterating over payload values.
+typedef void (*MlirValueCallback)(MlirValue, void *userData);
+
+/// Iterate over payload values associated with the transform IR value.
+/// Calls the callback for each payload value.
+MLIR_CAPI_EXPORTED void
+mlirTransformStateForEachPayloadValue(MlirTransformState state, MlirValue value,
+ MlirValueCallback callback,
+ void *userData);
+
+/// Callback for iterating over parameters.
+typedef void (*MlirAttributeCallback)(MlirAttribute, void *userData);
+
+/// Iterate over parameters associated with the transform IR value.
+/// Calls the callback for each parameter.
+MLIR_CAPI_EXPORTED void
+mlirTransformStateForEachParam(MlirTransformState state, MlirValue value,
+ MlirAttributeCallback callback, void *userData);
+
+//===---------------------------------------------------------------------===//
+// TransformOpInterface
+//===---------------------------------------------------------------------===//
+
+/// Returns the interface TypeID of the TransformOpInterface.
+MLIR_CAPI_EXPORTED MlirTypeID mlirTransformOpInterfaceTypeID(void);
+
+/// Callbacks for implementing TransformOpInterface from external code.
+typedef struct {
+ /// Optional constructor for the user data.
+ /// Set to NULL to disable it.
+ void (*construct)(void *userData);
+ /// Optional destructor for the user data.
+ /// Set to NULL to disable it.
+ void (*destruct)(void *userData);
+ /// Apply callback that implements the transformation.
+ MlirDiagnosedSilenceableFailure (*apply)(MlirOperation op,
+ MlirTransformRewriter rewriter,
+ MlirTransformResults results,
+ MlirTransformState state,
+ void *userData);
+ /// Callback to check if repeated handle operands are allowed.
+ bool (*allowsRepeatedHandleOperands)(MlirOperation op, void *userData);
+ void *userData; // Passed as-is to every callback above.
+} MlirTransformOpInterfaceCallbacks;
+
+/// Attach TransformOpInterface to the operation with the given name using
+/// the provided callbacks.
+MLIR_CAPI_EXPORTED void mlirTransformOpInterfaceAttachFallbackModel(
+ MlirContext ctx, MlirStringRef opName,
+ MlirTransformOpInterfaceCallbacks callbacks);
+
+//===---------------------------------------------------------------------===//
+// Transform-specific MemoryEffectsOpInterface helpers
+//===---------------------------------------------------------------------===//
+
+/// Mark operands[0..numOperands) as only reading handles, in `effects`.
+MLIR_CAPI_EXPORTED void
+mlirTransformOnlyReadsHandle(MlirOpOperand *operands, intptr_t numOperands,
+ MlirMemoryEffectInstancesList effects);
+
+/// Mark results[0..numResults) as producing handles, in `effects`.
+MLIR_CAPI_EXPORTED void
+mlirTransformProducesHandle(MlirOpResult *results, intptr_t numResults,
+ MlirMemoryEffectInstancesList effects);
+
#ifdef __cplusplus
}
#endif
diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 80ff39c82a9ee..9b58267cf1c5a 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -55,6 +55,7 @@ DEFINE_C_API_STRUCT(MlirDialect, void);
DEFINE_C_API_STRUCT(MlirDialectRegistry, void);
DEFINE_C_API_STRUCT(MlirOperation, void);
DEFINE_C_API_STRUCT(MlirOpOperand, void);
+DEFINE_C_API_STRUCT(MlirOpResult, void);
DEFINE_C_API_STRUCT(MlirOpPrintingFlags, void);
DEFINE_C_API_STRUCT(MlirBlock, void);
DEFINE_C_API_STRUCT(MlirRegion, void);
@@ -673,6 +674,10 @@ MLIR_CAPI_EXPORTED intptr_t mlirOperationGetNumOperands(MlirOperation op);
MLIR_CAPI_EXPORTED MlirValue mlirOperationGetOperand(MlirOperation op,
intptr_t pos);
+/// Returns `pos`-th OpOperand of the operation.
+MLIR_CAPI_EXPORTED MlirOpOperand mlirOperationGetOpOperand(MlirOperation op,
+ intptr_t pos);
+
/// Sets the `pos`-th operand of the operation.
MLIR_CAPI_EXPORTED void mlirOperationSetOperand(MlirOperation op, intptr_t pos,
MlirValue newValue);
@@ -1044,6 +1049,9 @@ MLIR_CAPI_EXPORTED bool mlirValueIsABlockArgument(MlirValue value);
/// Returns 1 if the value is an operation result, 0 otherwise.
MLIR_CAPI_EXPORTED bool mlirValueIsAOpResult(MlirValue value);
+/// Cast the value to an OpResult. Asserts if the value is not an op result.
+MLIR_CAPI_EXPORTED MlirOpResult mlirValueToOpResult(MlirValue value);
+
/// Returns the block in which this value is defined as an argument. Asserts if
/// the value is not a block argument.
MLIR_CAPI_EXPORTED MlirBlock mlirBlockArgumentGetOwner(MlirValue value);
diff --git a/mlir/include/mlir-c/Interfaces.h b/mlir/include/mlir-c/Interfaces.h
index a5a3473eaef59..17a812dcd86a9 100644
--- a/mlir/include/mlir-c/Interfaces.h
+++ b/mlir/include/mlir-c/Interfaces.h
@@ -22,6 +22,16 @@
extern "C" {
#endif
+#define DEFINE_C_API_STRUCT(name, storage) \
+ struct name { \
+ storage *ptr; \
+ }; \
+ typedef struct name name
+/// Opaque handle to a list of MemoryEffects::EffectInstance.
+DEFINE_C_API_STRUCT(MlirMemoryEffectInstancesList, void);
+
+#undef DEFINE_C_API_STRUCT
+
/// Returns `true` if the given operation implements an interface identified by
/// its TypeID.
MLIR_CAPI_EXPORTED bool
@@ -42,7 +52,7 @@ mlirOperationImplementsInterfaceStatic(MlirStringRef operationName,
//===----------------------------------------------------------------------===//
/// Returns the interface TypeID of the InferTypeOpInterface.
-MLIR_CAPI_EXPORTED MlirTypeID mlirInferTypeOpInterfaceTypeID();
+MLIR_CAPI_EXPORTED MlirTypeID mlirInferTypeOpInterfaceTypeID(void);
/// These callbacks are used to return multiple types from functions while
/// transferring ownership to the caller. The first argument is the number of
@@ -65,7 +75,7 @@ MLIR_CAPI_EXPORTED MlirLogicalResult mlirInferTypeOpInterfaceInferReturnTypes(
//===----------------------------------------------------------------------===//
/// Returns the interface TypeID of the InferShapedTypeOpInterface.
-MLIR_CAPI_EXPORTED MlirTypeID mlirInferShapedTypeOpInterfaceTypeID();
+MLIR_CAPI_EXPORTED MlirTypeID mlirInferShapedTypeOpInterfaceTypeID(void);
/// These callbacks are used to return multiple shaped type components from
/// functions while transferring ownership to the caller. The first argument is
@@ -87,6 +97,31 @@ mlirInferShapedTypeOpInterfaceInferReturnTypes(
void *properties, intptr_t nRegions, MlirRegion *regions,
MlirShapedTypeComponentsCallback callback, void *userData);
+//===---------------------------------------------------------------------===//
+// MemoryEffectsOpInterface
+//===---------------------------------------------------------------------===//
+
+/// Returns the interface TypeID of the MemoryEffectsOpInterface.
+MLIR_CAPI_EXPORTED MlirTypeID mlirMemoryEffectsOpInterfaceTypeID(void);
+
+/// Callbacks for implementing MemoryEffectsOpInterface from external code.
+typedef struct {
+ /// Optional constructor for user data. Set to NULL to disable it.
+ void (*construct)(void *userData);
+ /// Optional destructor for user data. Set to NULL to disable it.
+ void (*destruct)(void *userData);
+ /// Callback that reports the memory effects of `op` into `effects`.
+ void (*getEffects)(MlirOperation op, MlirMemoryEffectInstancesList effects,
+ void *userData);
+ void *userData; // Passed as-is to every callback above.
+} MlirMemoryEffectsOpInterfaceCallbacks;
+
+/// Attach a new FallbackModel for the MemoryEffectsOpInterface to the named
+/// operation. The FallbackModel will call the provided callbacks.
+MLIR_CAPI_EXPORTED void mlirMemoryEffectsOpInterfaceAttachFallbackModel(
+ MlirContext ctx, MlirStringRef opName,
+ MlirMemoryEffectsOpInterfaceCallbacks callbacks);
+
#ifdef __cplusplus
}
#endif
diff --git a/mlir/include/mlir/Bindings/Python/IRCore.h b/mlir/include/mlir/Bindings/Python/IRCore.h
index 599771f8a3283..c1371ad650858 100644
--- a/mlir/include/mlir/Bindings/Python/IRCore.h
+++ b/mlir/include/mlir/Bindings/Python/IRCore.h
@@ -1495,6 +1495,7 @@ class MLIR_PYTHON_API_EXPORTED PyOperationList {
class MLIR_PYTHON_API_EXPORTED PyOpOperand {
public:
PyOpOperand(MlirOpOperand opOperand) : opOperand(opOperand) {}
+ operator MlirOpOperand() const { return opOperand; }
nanobind::typed<nanobind::object, PyOpView> getOwner() const;
@@ -1601,6 +1602,7 @@ class MLIR_PYTHON_API_EXPORTED PyOpResult : public PyConcreteValue<PyOpResult> {
static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult;
static constexpr const char *pyClassName = "OpResult";
using PyConcreteValue::PyConcreteValue;
+ operator MlirOpResult() { return mlirValueToOpResult(castFrom(*this)); }
static void bindDerived(ClassTy &c);
};
@@ -1837,13 +1839,20 @@ class MLIR_PYTHON_API_EXPORTED PyOpAttributeMap {
MLIR_PYTHON_API_EXPORTED MlirValue getUniqueResult(MlirOperation operation);
MLIR_PYTHON_API_EXPORTED void populateIRCore(nanobind::module_ &m);
MLIR_PYTHON_API_EXPORTED void populateRoot(nanobind::module_ &m);
+
+/// Wraps `f` in a Python classmethod; owns PyClassMethod_New's new reference.
+template <class Func, typename... Args>
+static nanobind::object classmethod(Func f, Args... args) {
+ nanobind::object cf = nanobind::cpp_function(f, args...);
+ return nanobind::steal<nanobind::object>(PyClassMethod_New(cf.ptr()));
+}
+
} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
} // namespace python
} // namespace mlir
namespace nanobind {
namespace detail {
-
template <>
struct type_caster<
mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::DefaultingPyMlirContext>
diff --git a/mlir/include/mlir/CAPI/Dialect/Transform.h b/mlir/include/mlir/CAPI/Dialect/Transform.h
new file mode 100644
index 0000000000000..792236cd8601f
--- /dev/null
+++ b/mlir/include/mlir/CAPI/Dialect/Transform.h
@@ -0,0 +1,28 @@
+//===- Transform.h - C API Utils for Transform dialect ----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains declarations of implementation details of the C API for
+// the Transform dialect. This file should not be included from C++ code other
+// than C API implementation nor from C code.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CAPI_DIALECT_TRANSFORM_H
+#define MLIR_CAPI_DIALECT_TRANSFORM_H
+
+#include "mlir-c/Dialect/Transform.h"
+#include "mlir/CAPI/Wrap.h"
+#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
+
+DEFINE_C_API_PTR_METHODS(MlirTransformRewriter,
+ mlir::transform::TransformRewriter)
+DEFINE_C_API_PTR_METHODS(MlirTransformResults,
+ mlir::transform::TransformResults)
+DEFINE_C_API_PTR_METHODS(MlirTransformState, mlir::transform::TransformState)
+
+#endif // MLIR_CAPI_DIALECT_TRANSFORM_H
diff --git a/mlir/include/mlir/CAPI/IR.h b/mlir/include/mlir/CAPI/IR.h
index 1836cb0acb67e..f58ff423d23f7 100644
--- a/mlir/include/mlir/CAPI/IR.h
+++ b/mlir/include/mlir/CAPI/IR.h
@@ -29,6 +29,7 @@ DEFINE_C_API_PTR_METHODS(MlirDialectRegistry, mlir::DialectRegistry)
DEFINE_C_API_PTR_METHODS(MlirOperation, mlir::Operation)
DEFINE_C_API_PTR_METHODS(MlirBlock, mlir::Block)
DEFINE_C_API_PTR_METHODS(MlirOpOperand, mlir::OpOperand)
+DEFINE_C_API_PTR_METHODS(MlirOpResult, mlir::OpResult)
DEFINE_C_API_PTR_METHODS(MlirOpPrintingFlags, mlir::OpPrintingFlags)
DEFINE_C_API_PTR_METHODS(MlirRegion, mlir::Region)
DEFINE_C_API_PTR_METHODS(MlirSymbolTable, mlir::SymbolTable)
diff --git a/mlir/include/mlir/CAPI/Interfaces.h b/mlir/include/mlir/CAPI/Interfaces.h
index 4154b8c9ec6cc..15afc9fb0f18e 100644
--- a/mlir/include/mlir/CAPI/Interfaces.h
+++ b/mlir/include/mlir/CAPI/Interfaces.h
@@ -15,4 +15,12 @@
#ifndef MLIR_CAPI_INTERFACES_H
#define MLIR_CAPI_INTERFACES_H
+#include "mlir-c/Interfaces.h"
+#include "mlir/CAPI/Wrap.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+
+DEFINE_C_API_PTR_METHODS(
+ MlirMemoryEffectInstancesList,
+ llvm::SmallVectorImpl<mlir::MemoryEffects::EffectInstance>)
+
#endif // MLIR_CAPI_INTERFACES_H
diff --git a/mlir/lib/Bindings/Python/DialectTransform.cpp b/mlir/lib/Bindings/Python/DialectTransform.cpp
index 82905498921ce..7dbeb72a0d2a8 100644
--- a/mlir/lib/Bindings/Python/DialectTransform.cpp
+++ b/mlir/lib/Bindings/Python/DialectTransform.cpp
@@ -8,12 +8,14 @@
#include <string>
+#include "IRInterfaces.h"
+#include "Rewrite.h"
#include "mlir-c/Dialect/Transform.h"
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"
#include "mlir/Bindings/Python/IRCore.h"
-#include "mlir/Bindings/Python/Nanobind.h"
-#include "mlir/Bindings/Python/NanobindAdaptors.h"
+#include "nanobind/nanobind.h"
+#include <nanobind/trampoline.h>
namespace nb = nanobind;
using namespace mlir::python::nanobind_adaptors;
@@ -22,6 +24,208 @@ namespace mlir {
namespace python {
namespace MLIR_BINDINGS_PYTHON_DOMAIN {
namespace transform {
+
+//===----------------------------------------------------------------------===//
+// TransformRewriter
+//===----------------------------------------------------------------------===//
+class PyTransformRewriter : public PyRewriterBase<PyTransformRewriter> {
+public:
+ static constexpr const char *pyClassName = "TransformRewriter";
+ /// Wraps `rewriter` through its generic RewriterBase view.
+ PyTransformRewriter(MlirTransformRewriter rewriter)
+ : PyRewriterBase(mlirTransformRewriterAsBase(rewriter)) {}
+};
+
+//===----------------------------------------------------------------------===//
+// TransformResults
+//===----------------------------------------------------------------------===//
+class PyTransformResults {
+public:
+ PyTransformResults(MlirTransformResults results) : results(results) {}
+ /// Returns the wrapped C handle.
+ MlirTransformResults get() const { return results; }
+ /// Sets the payload operations of `result` to the Operations in `ops`.
+ void setOps(MlirValue result, const nanobind::list &ops) {
+ std::vector<MlirOperation> opsVec;
+ opsVec.reserve(ops.size());
+ for (auto op : ops) {
+ opsVec.push_back(nb::cast<MlirOperation>(op));
+ }
+ mlirTransformResultsSetOps(results, result, opsVec.size(), opsVec.data());
+ }
+ /// Sets the payload values of `result` to the Values in `values`.
+ void setValues(MlirValue result, const nanobind::list &values) {
+ std::vector<MlirValue> valuesVec;
+ valuesVec.reserve(values.size());
+ for (auto item : values) {
+ valuesVec.push_back(nb::cast<MlirValue>(item));
+ }
+ mlirTransformResultsSetValues(results, result, valuesVec.size(),
+ valuesVec.data());
+ }
+ /// Sets the parameters of `result` to the Attributes in `params`.
+ void setParams(MlirValue result, const nanobind::list &params) {
+ std::vector<MlirAttribute> paramsVec;
+ paramsVec.reserve(params.size());
+ for (auto item : params) {
+ paramsVec.push_back(nb::cast<MlirAttribute>(item));
+ }
+ mlirTransformResultsSetParams(results, result, paramsVec.size(),
+ paramsVec.data());
+ }
+
+ static void bind(nanobind::module_ &m) {
+ nb::class_<PyTransformResults>(m, "TransformResults")
+ .def(nb::init<MlirTransformResults>())
+ .def("set_ops", &PyTransformResults::setOps,
+ "Set the payload operations for a transform result.",
+ nb::arg("result"), nb::arg("ops"))
+ .def("set_values", &PyTransformResults::setValues,
+ "Set the payload values for a transform result.",
+ nb::arg("result"), nb::arg("values"))
+ .def("set_params", &PyTransformResults::setParams,
+ "Set the parameters for a transform result.", nb::arg("result"),
+ nb::arg("params"));
+ }
+
+private:
+ MlirTransformResults results;
+};
+
+//===----------------------------------------------------------------------===//
+// TransformState
+//===----------------------------------------------------------------------===//
+class PyTransformState {
+public:
+ PyTransformState(MlirTransformState state) : state(state) {}
+ /// Returns the wrapped C handle.
+ MlirTransformState get() const { return state; }
+ /// Returns a new list of the payload operations associated with `value`.
+ nanobind::list getPayloadOps(MlirValue value) {
+ nanobind::list result;
+ mlirTransformStateForEachPayloadOp(
+ state, value,
+ [](MlirOperation op, void *userData) {
+ static_cast<nanobind::list *>(userData)->append(op);
+ },
+ &result);
+ return result;
+ }
+ /// Returns a new list of the payload values associated with `value`.
+ nanobind::list getPayloadValues(MlirValue value) {
+ nanobind::list result;
+ mlirTransformStateForEachPayloadValue(
+ state, value,
+ [](MlirValue val, void *userData) {
+ static_cast<nanobind::list *>(userData)->append(val);
+ },
+ &result);
+ return result;
+ }
+ /// Returns a new list of the parameters associated with `value`.
+ nanobind::list getParams(MlirValue value) {
+ nanobind::list result;
+ mlirTransformStateForEachParam(
+ state, value,
+ [](MlirAttribute attr, void *userData) {
+ static_cast<nanobind::list *>(userData)->append(attr);
+ },
+ &result);
+ return result;
+ }
+ /// Binds as `TransformState`; the Python keyword for `value` is `operand`.
+ static void bind(nanobind::module_ &m) {
+ nb::class_<PyTransformState>(m, "TransformState")
+ .def(nb::init<MlirTransformState>())
+ .def("get_payload_ops", &PyTransformState::getPayloadOps,
+ "Get the payload operations associated with a transform IR value.",
+ nb::arg("operand"))
+ .def("get_payload_values", &PyTransformState::getPayloadValues,
+ "Get the payload values associated with a transform IR value.",
+ nb::arg("operand"))
+ .def("get_params", &PyTransformState::getParams,
+ "Get the parameters (attributes) associated with a transform IR "
+ "value.",
+ nb::arg("operand"));
+ }
+
+private:
+ MlirTransformState state;
+};
+
+//===----------------------------------------------------------------------===//
+// TransformOpInterface
+//===----------------------------------------------------------------------===//
+class PyTransformOpInterface
+ : public PyConcreteOpInterface<PyTransformOpInterface> {
+public:
+ using PyConcreteOpInterface<PyTransformOpInterface>::PyConcreteOpInterface;
+
+ constexpr static const char *pyClassName = "TransformOpInterface";
+ constexpr static GetTypeIDFunctionTy getInterfaceID =
+ &mlirTransformOpInterfaceTypeID;
+
+ /// Attach a new TransformOpInterface FallbackModel to the named operation.
+ /// The FallbackModel acts as a trampoline for callbacks on the Python class.
+ static void attach(nb::object &pyClass, const std::string &opName,
+ DefaultingPyMlirContext ctx) {
+ // Prepare the callbacks (zero-initialized) used by the FallbackModel.
+ MlirTransformOpInterfaceCallbacks callbacks{};
+ // Make the pointer to the Python class available to the callbacks.
+ callbacks.userData = pyClass.ptr();
+ nb::handle(static_cast<PyObject *>(callbacks.userData)).inc_ref();
+
+ // The above ref bump is all the initialization needed, so there is no
+ // construct callback.
+ callbacks.construct = nullptr;
+ // Upon the FallbackModel's destruction, drop the ref to the Python class.
+ callbacks.destruct = [](void *userData) {
+ nb::handle(static_cast<PyObject *>(userData)).dec_ref();
+ };
+ // The apply callback, which calls into Python (exceptions are not caught).
+ callbacks.apply = [](MlirOperation op, MlirTransformRewriter rewriter,
+ MlirTransformResults results, MlirTransformState state,
+ void *userData) -> MlirDiagnosedSilenceableFailure {
+ nb::handle pyClass(static_cast<PyObject *>(userData));
+
+ auto pyApply = nb::cast<nb::callable>(nb::getattr(pyClass, "apply"));
+
+ auto pyRewriter = PyTransformRewriter(rewriter);
+ auto pyResults = PyTransformResults(results);
+ auto pyState = PyTransformState(state);
+
+ // Invoke `pyClass.apply(op, rewriter, results, state)` as a classmethod.
+ nb::object res = pyApply(op, pyRewriter, pyResults, pyState);
+
+ return nb::cast<MlirDiagnosedSilenceableFailure>(res);
+ };
+
+ // allowsRepeatedHandleOperands: calls `allow_repeated_handle_operands`.
+ callbacks.allowsRepeatedHandleOperands = [](MlirOperation op,
+ void *userData) -> bool {
+ nb::handle pyClass(static_cast<PyObject *>(userData));
+
+ auto pyAllowRepeatedHandleOperands = nb::cast<nb::callable>(
+ nb::getattr(pyClass, "allow_repeated_handle_operands"));
+
+ // Invoke `pyClass.allow_repeated_handle_operands(op)` as a classmethod.
+ nb::object res = pyAllowRepeatedHandleOperands(op);
+
+ return nb::cast<bool>(res);
+ };
+
+ // Attach a FallbackModel, which calls into Python, to the named operation.
+ mlirTransformOpInterfaceAttachFallbackModel(
+ ctx->get(), mlirStringRefCreate(opName.data(), opName.size()), callbacks);
+ }
+ /// Binds `attach(op_name, ctx=None)` as a classmethod of the Python class.
+ static void bindDerived(ClassTy &transformOpInterfaceClass) {
+ transformOpInterfaceClass.attr("attach") =
+ classmethod(&PyTransformOpInterface::attach, nb::arg("cls"),
+ nb::arg("op_name"), nb::arg("ctx") = nb::none());
+ }
+};
+
//===-------------------------------------------------------------------===//
// AnyOpType
//===-------------------------------------------------------------------===//
@@ -162,12 +366,54 @@ struct ParamType : PyConcreteType<ParamType> {
}
};
+//===----------------------------------------------------------------------===//
+// MemoryEffectsOpInterface helpers
+//===----------------------------------------------------------------------===//
+namespace {
+/// Marks the OpOperands in `operands` as only reading handles, in `effects`.
+void onlyReadsHandle(nb::list &operands, PyMemoryEffectsInstanceList effects) {
+ std::vector<MlirOpOperand> operandsVec;
+ operandsVec.reserve(operands.size());
+ for (auto operand : operands)
+ operandsVec.push_back(nb::cast<PyOpOperand>(operand));
+ mlirTransformOnlyReadsHandle(operandsVec.data(), operandsVec.size(),
+ effects.effects);
+}
+/// Marks the OpResults in `results` as producing handles, in `effects`.
+void producesHandle(nb::list &results, PyMemoryEffectsInstanceList effects) {
+ std::vector<MlirOpResult> resultsVec;
+ resultsVec.reserve(results.size());
+ for (auto result : results)
+ resultsVec.push_back(nb::cast<PyOpResult>(result));
+ mlirTransformProducesHandle(resultsVec.data(), resultsVec.size(),
+ effects.effects);
+}
+} // namespace
+
static void populateDialectTransformSubmodule(nb::module_ &m) {
+ nb::enum_<MlirDiagnosedSilenceableFailure>(m, "DiagnosedSilenceableFailure")
+ .value("Success", MlirDiagnosedSilenceableFailureSuccess)
+ .value("SilenceableFailure",
+ MlirDiagnosedSilenceableFailureSilenceableFailure)
+ .value("DefiniteFailure", MlirDiagnosedSilenceableFailureDefiniteFailure);
+
AnyOpType::bind(m);
AnyParamType::bind(m);
AnyValueType::bind(m);
OperationType::bind(m);
ParamType::bind(m);
+ // Objects handed to Python `apply` implementations, plus the interface.
+ PyTransformRewriter::bind(m);
+ PyTransformResults::bind(m);
+ PyTransformState::bind(m);
+ PyTransformOpInterface::bind(m);
+ // Helpers for populating a MemoryEffects list from Python.
+ m.def("only_reads_handle", onlyReadsHandle,
+ "Mark operands as only reading handles.", nb::arg("operands"),
+ nb::arg("effects"));
+
+ m.def("produces_handle", producesHandle, "Mark results as producing handles.",
+ nb::arg("results"), nb::arg("effects"));
}
} // namespace transform
} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index eb00363a54034..f3108473c712a 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -10,7 +10,6 @@
#include "mlir/Bindings/Python/Globals.h"
#include "mlir/Bindings/Python/IRCore.h"
#include "mlir/Bindings/Python/NanobindUtils.h"
-#include "mlir/Bindings/Python/NanobindAdaptors.h"
#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
// clang-format on
#include "mlir-c/BuiltinAttributes.h"
@@ -52,13 +51,6 @@ operations.
// Utilities.
//------------------------------------------------------------------------------
-/// Helper for creating an @classmethod.
-template <class Func, typename... Args>
-static nb::object classmethod(Func f, Args... args) {
- nb::object cf = nb::cpp_function(f, args...);
- return nb::borrow<nb::object>((PyClassMethod_New(cf.ptr())));
-}
-
static nb::object
createCustomDialectWrapper(const std::string &dialectNamespace,
nb::object dialectDescriptor) {
@@ -2306,6 +2298,44 @@ PyOpOperandList PyOpOperandList::slice(intptr_t startIndex, intptr_t length,
return PyOpOperandList(operation, startIndex, length, step);
}
+/// A list of the OpOperands of an operation. Operands are stored
+/// consecutively, so random access is cheap. The list holds a reference to
+/// its owning operation and thus keeps it alive; the PyOpOperand elements it
+/// yields carry no such reference.
+class PyOpOpOperandList : public Sliceable<PyOpOpOperandList, PyOpOperand> {
+public:
+ static constexpr const char *pyClassName = "OpOpOperandList";
+ using SliceableT = Sliceable<PyOpOpOperandList, PyOpOperand>;
+ /// Views `length` operands (-1: all) of `operation` from `startIndex`.
+ PyOpOpOperandList(PyOperationRef operation, intptr_t startIndex = 0,
+ intptr_t length = -1, intptr_t step = 1)
+ : Sliceable(startIndex,
+ length == -1 ? mlirOperationGetNumOperands(operation->get())
+ : length,
+ step),
+ operation(operation) {}
+
+private:
+ /// Give the parent CRTP class access to hook implementations below.
+ friend class Sliceable<PyOpOpOperandList, PyOpOperand>;
+ /// Returns the operation's operand count after validating the operation.
+ intptr_t getRawNumElements() {
+ operation->checkValid();
+ return mlirOperationGetNumOperands(operation->get());
+ }
+ /// Wraps the `pos`-th OpOperand after validating the operation.
+ PyOpOperand getRawElement(intptr_t pos) {
+ operation->checkValid();
+ return PyOpOperand(mlirOperationGetOpOperand(operation->get(), pos));
+ }
+ /// Returns a sub-list over the same owning operation.
+ PyOpOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
+ return PyOpOpOperandList(operation, startIndex, length, step);
+ }
+ /// Owning operation; referenced so that it outlives this list.
+ PyOperationRef operation;
+};
+
PyOpSuccessors::PyOpSuccessors(PyOperationRef operation, intptr_t startIndex,
intptr_t length, intptr_t step)
: Sliceable(startIndex,
@@ -3534,6 +3564,12 @@ void populateIRCore(nb::module_ &m) {
return PyOpOperandList(self.getOperation().getRef());
},
"Returns the list of operation operands.")
+ .def_prop_ro(
+ "op_operands",
+ [](PyOperationBase &self) {
+ return PyOpOpOperandList(self.getOperation().getRef());
+ },
+ "Returns the list of op operands.")
.def_prop_ro(
"regions",
[](PyOperationBase &self) {
@@ -4825,6 +4861,7 @@ void populateIRCore(nb::module_ &m) {
PyOpAttributeMap::bind(m);
PyOpOperandIterator::bind(m);
PyOpOperandList::bind(m);
+ PyOpOpOperandList::bind(m);
PyOpResultList::bind(m);
PyOpSuccessors::bind(m);
PyRegionIterator::bind(m);
diff --git a/mlir/lib/Bindings/Python/IRInterfaces.cpp b/mlir/lib/Bindings/Python/IRInterfaces.cpp
index 09112d4989ae4..05633746c3136 100644
--- a/mlir/lib/Bindings/Python/IRInterfaces.cpp
+++ b/mlir/lib/Bindings/Python/IRInterfaces.cpp
@@ -12,12 +12,12 @@
#include <utility>
#include <vector>
+#include "IRInterfaces.h"
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/IR.h"
#include "mlir-c/Interfaces.h"
#include "mlir-c/Support.h"
#include "mlir/Bindings/Python/IRCore.h"
-#include "mlir/Bindings/Python/Nanobind.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
@@ -26,18 +26,6 @@ namespace nb = nanobind;
namespace mlir {
namespace python {
namespace MLIR_BINDINGS_PYTHON_DOMAIN {
-constexpr static const char *constructorDoc =
- R"(Creates an interface from a given operation/opview object or from a
-subclass of OpView. Raises ValueError if the operation does not implement the
-interface.)";
-
-constexpr static const char *operationDoc =
- R"(Returns an Operation for which the interface was constructed.)";
-
-constexpr static const char *opviewDoc =
- R"(Returns an OpView subclass _instance_ for which the interface was
-constructed)";
-
constexpr static const char *inferReturnTypesDoc =
R"(Given the arguments required to build an operation, attempts to infer
its return types. Raises ValueError on failure.)";
@@ -124,119 +112,6 @@ wrapRegions(std::optional<std::vector<PyRegion>> regions) {
} // namespace
-/// CRTP base class for Python classes representing MLIR Op interfaces.
-/// Interface hierarchies are flat so no base class is expected here. The
-/// derived class is expected to define the following static fields:
-/// - `const char *pyClassName` - the name of the Python class to create;
-/// - `GetTypeIDFunctionTy getInterfaceID` - the function producing the TypeID
-/// of the interface.
-/// Derived classes may redefine the `bindDerived(ClassTy &)` method to bind
-/// interface-specific methods.
-///
-/// An interface class may be constructed from either an Operation/OpView object
-/// or from a subclass of OpView. In the latter case, only the static interface
-/// methods are available, similarly to calling ConcereteOp::staticMethod on the
-/// C++ side. Implementations of concrete interfaces can use the `isStatic`
-/// method to check whether the interface object was constructed from a class or
-/// an operation/opview instance. The `getOpName` always succeeds and returns a
-/// canonical name of the operation suitable for lookups.
-template <typename ConcreteIface>
-class PyConcreteOpInterface {
-protected:
- using ClassTy = nb::class_<ConcreteIface>;
- using GetTypeIDFunctionTy = MlirTypeID (*)();
-
-public:
- /// Constructs an interface instance from an object that is either an
- /// operation or a subclass of OpView. In the latter case, only the static
- /// methods of the interface are accessible to the caller.
- PyConcreteOpInterface(nb::object object, DefaultingPyMlirContext context)
- : obj(std::move(object)) {
- try {
- operation = &nb::cast<PyOperation &>(obj);
- } catch (nb::cast_error &) {
- // Do nothing.
- }
-
- try {
- operation = &nb::cast<PyOpView &>(obj).getOperation();
- } catch (nb::cast_error &) {
- // Do nothing.
- }
-
- if (operation != nullptr) {
- if (!mlirOperationImplementsInterface(*operation,
- ConcreteIface::getInterfaceID())) {
- std::string msg = "the operation does not implement ";
- throw nb::value_error((msg + ConcreteIface::pyClassName).c_str());
- }
-
- MlirIdentifier identifier = mlirOperationGetName(*operation);
- MlirStringRef stringRef = mlirIdentifierStr(identifier);
- opName = std::string(stringRef.data, stringRef.length);
- } else {
- try {
- opName = nb::cast<std::string>(obj.attr("OPERATION_NAME"));
- } catch (nb::cast_error &) {
- throw nb::type_error(
- "Op interface does not refer to an operation or OpView class");
- }
-
- if (!mlirOperationImplementsInterfaceStatic(
- mlirStringRefCreate(opName.data(), opName.length()),
- context.resolve().get(), ConcreteIface::getInterfaceID())) {
- std::string msg = "the operation does not implement ";
- throw nb::value_error((msg + ConcreteIface::pyClassName).c_str());
- }
- }
- }
-
- /// Creates the Python bindings for this class in the given module.
- static void bind(nb::module_ &m) {
- nb::class_<ConcreteIface> cls(m, ConcreteIface::pyClassName);
- cls.def(nb::init<nb::object, DefaultingPyMlirContext>(), nb::arg("object"),
- nb::arg("context") = nb::none(), constructorDoc)
- .def_prop_ro("operation", &PyConcreteOpInterface::getOperationObject,
- operationDoc)
- .def_prop_ro("opview", &PyConcreteOpInterface::getOpView, opviewDoc);
- ConcreteIface::bindDerived(cls);
- }
-
- /// Hook for derived classes to add class-specific bindings.
- static void bindDerived(ClassTy &cls) {}
-
- /// Returns `true` if this object was constructed from a subclass of OpView
- /// rather than from an operation instance.
- bool isStatic() { return operation == nullptr; }
-
- /// Returns the operation instance from which this object was constructed.
- /// Throws a type error if this object was constructed from a subclass of
- /// OpView.
- nb::typed<nb::object, PyOperation> getOperationObject() {
- if (operation == nullptr)
- throw nb::type_error("Cannot get an operation from a static interface");
- return operation->getRef().releaseObject();
- }
-
- /// Returns the opview of the operation instance from which this object was
- /// constructed. Throws a type error if this object was constructed form a
- /// subclass of OpView.
- nb::typed<nb::object, PyOpView> getOpView() {
- if (operation == nullptr)
- throw nb::type_error("Cannot get an opview from a static interface");
- return operation->createOpView();
- }
-
- /// Returns the canonical name of the operation this interface is constructed
- /// from.
- const std::string &getOpName() { return opName; }
-
-private:
- PyOperation *operation = nullptr;
- std::string opName;
- nb::object obj;
-};
-
/// Python wrapper for InferTypeOpInterface. This interface has only static
/// methods.
class PyInferTypeOpInterface
@@ -464,10 +339,62 @@ class PyInferShapedTypeOpInterface
}
};
+/// Wrapper around the MemoryEffectsOpInterface.
+class PyMemoryEffectsOpInterface
+    : public PyConcreteOpInterface<PyMemoryEffectsOpInterface> {
+public:
+  using PyConcreteOpInterface<
+      PyMemoryEffectsOpInterface>::PyConcreteOpInterface;
+
+  constexpr static const char *pyClassName = "MemoryEffectsOpInterface";
+  constexpr static GetTypeIDFunctionTy getInterfaceID =
+      &mlirMemoryEffectsOpInterfaceTypeID;
+
+  /// Attach a new MemoryEffectsOpInterface FallbackModel to the named
+  /// operation. The FallbackModel acts as a trampoline for callbacks on the
+  /// Python class: querying the op's effects from C++ invokes
+  /// `pySubclass.get_effects(op, effects)`.
+  ///
+  /// A reference to `pySubclass` is taken here and handed to the callbacks as
+  /// `userData`; the `destruct` callback releases it.
+  /// Throws TypeError if `pySubclass` has no `get_effects` attribute.
+  static void attach(nb::object &pySubclass, const std::string &opName,
+                     DefaultingPyMlirContext ctx) {
+    // Fail at attach time rather than on the first effects query.
+    if (!nb::hasattr(pySubclass, "get_effects"))
+      throw nb::type_error(
+          "MemoryEffectsOpInterface subclass must define 'get_effects'");
+
+    // Value-initialize so any callback slot not assigned below is null.
+    MlirMemoryEffectsOpInterfaceCallbacks callbacks{};
+    callbacks.userData = pySubclass.ptr();
+    nb::handle(static_cast<PyObject *>(callbacks.userData)).inc_ref();
+    callbacks.construct = nullptr;
+    callbacks.destruct = [](void *userData) {
+      nb::handle(static_cast<PyObject *>(userData)).dec_ref();
+    };
+    callbacks.getEffects = [](MlirOperation op,
+                              MlirMemoryEffectInstancesList effects,
+                              void *userData) {
+      nb::handle pyClass(static_cast<PyObject *>(userData));
+
+      // Looked up on every call, so the class's current `get_effects` is used.
+      auto pyGetEffects =
+          nb::cast<nb::callable>(nb::getattr(pyClass, "get_effects"));
+
+      // Build the Python Operation explicitly instead of relying on a type
+      // caster for the raw `MlirOperation` handle being available here.
+      nb::object pyOp =
+          PyOperation::forOperation(
+              PyMlirContext::forContext(mlirOperationGetContext(op)), op)
+              .releaseObject();
+      PyMemoryEffectsInstanceList effectsWrapper{effects};
+
+      // Invoke `pyClass.get_effects(op, effects)`.
+      pyGetEffects(pyOp, effectsWrapper);
+    };
+
+    // The bindings only depend on the C API, so build the name through it.
+    mlirMemoryEffectsOpInterfaceAttachFallbackModel(
+        ctx->get(), mlirStringRefCreate(opName.data(), opName.size()),
+        callbacks);
+  }
+
+  static void bindDerived(ClassTy &memoryEffectsOpInterfaceClass) {
+    memoryEffectsOpInterfaceClass.attr("attach") =
+        classmethod(&PyMemoryEffectsOpInterface::attach, nb::arg("cls"),
+                    nb::arg("op_name"), nb::arg("ctx") = nb::none());
+  }
+};
+
void populateIRInterfaces(nb::module_ &m) {
+ // Registered without any methods: it only carries the effects list that
+ // PyMemoryEffectsOpInterface::attach hands to Python `get_effects` callbacks.
+ nb::class_<PyMemoryEffectsInstanceList>(m, "MemoryEffectInstancesList");
+
+ PyInferShapedTypeOpInterface::bind(m);
PyInferTypeOpInterface::bind(m);
+ PyMemoryEffectsOpInterface::bind(m);
PyShapedTypeComponents::bind(m);
- PyInferShapedTypeOpInterface::bind(m);
}
} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
} // namespace python
diff --git a/mlir/lib/Bindings/Python/IRInterfaces.h b/mlir/lib/Bindings/Python/IRInterfaces.h
new file mode 100644
index 0000000000000..3374b505b5034
--- /dev/null
+++ b/mlir/lib/Bindings/Python/IRInterfaces.h
@@ -0,0 +1,156 @@
+//===- IRInterfaces.h - IR Interfaces for Python Bindings -------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_BINDINGS_PYTHON_IRINTERFACES_H
+#define MLIR_BINDINGS_PYTHON_IRINTERFACES_H
+
+#include "mlir-c/IR.h"
+#include "mlir-c/Interfaces.h"
+#include "mlir-c/Support.h"
+#include "mlir/Bindings/Python/IRCore.h"
+
+namespace nb = nanobind;
+
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
+
+// Docstrings shared by every interface class bound through
+// PyConcreteOpInterface::bind (constructor, `operation`, `opview`).
+constexpr static const char *constructorDoc =
+ R"(Creates an interface from a given operation/opview object or from a
+subclass of OpView. Raises ValueError if the operation does not implement the
+interface.)";
+
+constexpr static const char *operationDoc =
+ R"(Returns an Operation for which the interface was constructed.)";
+
+constexpr static const char *opviewDoc =
+ R"(Returns an OpView subclass _instance_ for which the interface was
+constructed)";
+
+/// CRTP base class for Python classes representing MLIR Op interfaces.
+/// Interface hierarchies are flat so no base class is expected here. The
+/// derived class is expected to define the following static fields:
+///  - `const char *pyClassName` - the name of the Python class to create;
+///  - `GetTypeIDFunctionTy getInterfaceID` - the function producing the TypeID
+///    of the interface.
+/// Derived classes may redefine the `bindDerived(ClassTy &)` method to bind
+/// interface-specific methods.
+///
+/// An interface class may be constructed from either an Operation/OpView object
+/// or from a subclass of OpView. In the latter case, only the static interface
+/// methods are available, similarly to calling ConcreteOp::staticMethod on the
+/// C++ side. Implementations of concrete interfaces can use the `isStatic`
+/// method to check whether the interface object was constructed from a class or
+/// an operation/opview instance. The `getOpName` always succeeds and returns a
+/// canonical name of the operation suitable for lookups.
+template <typename ConcreteIface>
+class PyConcreteOpInterface {
+protected:
+  using ClassTy = nb::class_<ConcreteIface>;
+  using GetTypeIDFunctionTy = MlirTypeID (*)();
+
+public:
+  /// Constructs an interface instance from an object that is either an
+  /// operation/opview or a subclass of OpView. In the latter case, only the
+  /// static methods of the interface are accessible to the caller.
+  /// Throws ValueError if the operation does not implement the interface, and
+  /// TypeError if `object` has no usable `OPERATION_NAME` attribute.
+  PyConcreteOpInterface(nb::object object, DefaultingPyMlirContext context)
+      : obj(std::move(object)) {
+    try {
+      operation = &nb::cast<PyOperation &>(obj);
+    } catch (nb::cast_error &) {
+      // Not an Operation; try an OpView below.
+    }
+
+    // Only attempt the OpView cast when the Operation cast failed, avoiding a
+    // second cast (and thrown exception) for plain Operations.
+    if (operation == nullptr) {
+      try {
+        operation = &nb::cast<PyOpView &>(obj).getOperation();
+      } catch (nb::cast_error &) {
+        // Not an OpView instance either; treated as an OpView class below.
+      }
+    }
+
+    if (operation != nullptr) {
+      if (!mlirOperationImplementsInterface(*operation,
+                                            ConcreteIface::getInterfaceID())) {
+        std::string msg = "the operation does not implement ";
+        throw nb::value_error((msg + ConcreteIface::pyClassName).c_str());
+      }
+
+      MlirIdentifier identifier = mlirOperationGetName(*operation);
+      MlirStringRef stringRef = mlirIdentifierStr(identifier);
+      opName = std::string(stringRef.data, stringRef.length);
+    } else {
+      try {
+        opName = nb::cast<std::string>(obj.attr("OPERATION_NAME"));
+      } catch (nb::cast_error &) {
+        throw nb::type_error(
+            "Op interface does not refer to an operation or OpView class");
+      }
+
+      if (!mlirOperationImplementsInterfaceStatic(
+              mlirStringRefCreate(opName.data(), opName.length()),
+              context.resolve().get(), ConcreteIface::getInterfaceID())) {
+        std::string msg = "the operation does not implement ";
+        throw nb::value_error((msg + ConcreteIface::pyClassName).c_str());
+      }
+    }
+  }
+
+  /// Creates the Python bindings for this class in the given module.
+  static void bind(nb::module_ &m) {
+    nb::class_<ConcreteIface> cls(m, ConcreteIface::pyClassName);
+    cls.def(nb::init<nb::object, DefaultingPyMlirContext>(), nb::arg("object"),
+            nb::arg("context") = nb::none(), constructorDoc)
+        .def_prop_ro("operation", &PyConcreteOpInterface::getOperationObject,
+                     operationDoc)
+        .def_prop_ro("opview", &PyConcreteOpInterface::getOpView, opviewDoc);
+    ConcreteIface::bindDerived(cls);
+  }
+
+  /// Hook for derived classes to add class-specific bindings.
+  static void bindDerived(ClassTy &cls) {}
+
+  /// Returns `true` if this object was constructed from a subclass of OpView
+  /// rather than from an operation instance.
+  bool isStatic() { return operation == nullptr; }
+
+  /// Returns the operation instance from which this object was constructed.
+  /// Throws a type error if this object was constructed from a subclass of
+  /// OpView.
+  nb::typed<nb::object, PyOperation> getOperationObject() {
+    if (operation == nullptr)
+      throw nb::type_error("Cannot get an operation from a static interface");
+    return operation->getRef().releaseObject();
+  }
+
+  /// Returns the opview of the operation instance from which this object was
+  /// constructed. Throws a type error if this object was constructed from a
+  /// subclass of OpView.
+  nb::typed<nb::object, PyOpView> getOpView() {
+    if (operation == nullptr)
+      throw nb::type_error("Cannot get an opview from a static interface");
+    return operation->createOpView();
+  }
+
+  /// Returns the canonical name of the operation this interface is constructed
+  /// from.
+  const std::string &getOpName() { return opName; }
+
+private:
+  PyOperation *operation = nullptr;
+  std::string opName;
+  nb::object obj;
+};
+
+/// Holds an `MlirMemoryEffectInstancesList` handle so that it can be passed to
+/// Python as a bound object (e.g. the `effects` argument of `get_effects`).
+struct PyMemoryEffectsInstanceList {
+ MlirMemoryEffectInstancesList effects;
+};
+
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
+
+#endif // MLIR_BINDINGS_PYTHON_IRINTERFACES_H
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 2b649f79c5982..dcb95e81fc126 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -8,14 +8,11 @@
#include "Rewrite.h"
+#include "mlir-c/Bindings/Python/Interop.h"
#include "mlir-c/IR.h"
#include "mlir-c/Rewrite.h"
#include "mlir-c/Support.h"
#include "mlir/Bindings/Python/IRCore.h"
-// clang-format off
-#include "mlir/Bindings/Python/Nanobind.h"
-#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
-// clang-format on
#include "mlir/Config/mlir-config.h"
#include "nanobind/nanobind.h"
@@ -28,38 +25,12 @@ namespace mlir {
namespace python {
namespace MLIR_BINDINGS_PYTHON_DOMAIN {
-class PyPatternRewriter {
+/// Python wrapper for an `MlirPatternRewriter`. The rewriter methods and their
+/// Python bindings are inherited from `PyRewriterBase`.
+class PyPatternRewriter : public PyRewriterBase<PyPatternRewriter> {
public:
- PyPatternRewriter(MlirPatternRewriter rewriter)
- : base(mlirPatternRewriterAsBase(rewriter)),
- ctx(PyMlirContext::forContext(mlirRewriterBaseGetContext(base))) {}
-
- PyInsertionPoint getInsertionPoint() const {
- MlirBlock block = mlirRewriterBaseGetInsertionBlock(base);
- MlirOperation op = mlirRewriterBaseGetOperationAfterInsertion(base);
-
- if (mlirOperationIsNull(op)) {
- MlirOperation owner = mlirBlockGetParentOperation(block);
- auto parent = PyOperation::forOperation(ctx, owner);
- return PyInsertionPoint(PyBlock(parent, block));
- }
-
- return PyInsertionPoint(PyOperation::forOperation(ctx, op));
- }
-
- void replaceOp(MlirOperation op, MlirOperation newOp) {
- mlirRewriterBaseReplaceOpWithOperation(base, op, newOp);
- }
-
- void replaceOp(MlirOperation op, const std::vector<MlirValue> &values) {
- mlirRewriterBaseReplaceOpWithValues(base, op, values.size(), values.data());
- }
+ /// Name of the Python class registered by `PyRewriterBase::bind`.
+ static constexpr const char *pyClassName = "PatternRewriter";
- void eraseOp(const PyOperation &op) { mlirRewriterBaseEraseOp(base, op); }
-
-private:
- MlirRewriterBase base;
- PyMlirContextRef ctx;
+ /// Wraps `rewriter` through its `MlirRewriterBase` view.
+ PyPatternRewriter(MlirPatternRewriter rewriter)
+ : PyRewriterBase(mlirPatternRewriterAsBase(rewriter)) {}
};
struct PyMlirPDLResultList : MlirPDLResultList {};
@@ -340,29 +311,8 @@ void populateRewriteSubmodule(nb::module_ &m) {
//----------------------------------------------------------------------------
// Mapping of the PatternRewriter
//----------------------------------------------------------------------------
- nb::class_<PyPatternRewriter>(m, "PatternRewriter")
- .def_prop_ro("ip", &PyPatternRewriter::getInsertionPoint,
- "The current insertion point of the PatternRewriter.")
- .def(
- "replace_op",
- [](PyPatternRewriter &self, PyOperationBase &op,
- PyOperationBase &newOp) {
- self.replaceOp(op.getOperation(), newOp.getOperation());
- },
- "Replace an operation with a new operation.", nb::arg("op"),
- nb::arg("new_op"))
- .def(
- "replace_op",
- [](PyPatternRewriter &self, PyOperationBase &op,
- const std::vector<PyValue> &values) {
- std::vector<MlirValue> values_(values.size());
- std::copy(values.begin(), values.end(), values_.begin());
- self.replaceOp(op.getOperation(), values_);
- },
- "Replace an operation with a list of values.", nb::arg("op"),
- nb::arg("values"))
- .def("erase_op", &PyPatternRewriter::eraseOp, "Erase an operation.",
- nb::arg("op"));
+
+ PyPatternRewriter::bind(m);
//----------------------------------------------------------------------------
// Mapping of the RewritePatternSet
diff --git a/mlir/lib/Bindings/Python/Rewrite.h b/mlir/lib/Bindings/Python/Rewrite.h
index d287f19187708..9b4dfafb59b86 100644
--- a/mlir/lib/Bindings/Python/Rewrite.h
+++ b/mlir/lib/Bindings/Python/Rewrite.h
@@ -9,13 +9,80 @@
#ifndef MLIR_BINDINGS_PYTHON_REWRITE_H
#define MLIR_BINDINGS_PYTHON_REWRITE_H
-#include "mlir/Bindings/Python/NanobindUtils.h"
+#include "mlir-c/Rewrite.h"
+#include "mlir/Bindings/Python/IRCore.h"
+
+#include <nanobind/nanobind.h>
+
+namespace nb = nanobind;
namespace mlir {
namespace python {
namespace MLIR_BINDINGS_PYTHON_DOMAIN {
+
+/// CRTP base class for Python wrappers around rewriters exposed through
+/// `MlirRewriterBase`. `DerivedTy` must declare
+/// `static constexpr const char *pyClassName`, the name under which `bind`
+/// registers the Python class.
+template <typename DerivedTy>
+class PyRewriterBase {
+public:
+  PyRewriterBase(MlirRewriterBase rewriter)
+      : base(rewriter),
+        ctx(PyMlirContext::forContext(mlirRewriterBaseGetContext(base))) {}
+
+  /// Returns the rewriter's insertion point: before the operation returned by
+  /// mlirRewriterBaseGetOperationAfterInsertion, or at the end of the
+  /// insertion block when that operation is null.
+  PyInsertionPoint getInsertionPoint() const {
+    MlirBlock block = mlirRewriterBaseGetInsertionBlock(base);
+    MlirOperation op = mlirRewriterBaseGetOperationAfterInsertion(base);
+
+    if (mlirOperationIsNull(op)) {
+      MlirOperation owner = mlirBlockGetParentOperation(block);
+      auto parent = PyOperation::forOperation(ctx, owner);
+      return PyInsertionPoint(PyBlock(parent, block));
+    }
+
+    return PyInsertionPoint(PyOperation::forOperation(ctx, op));
+  }
+
+  void replaceOp(MlirOperation op, MlirOperation newOp) {
+    mlirRewriterBaseReplaceOpWithOperation(base, op, newOp);
+  }
+
+  void replaceOp(MlirOperation op, const std::vector<MlirValue> &values) {
+    mlirRewriterBaseReplaceOpWithValues(base, op, values.size(), values.data());
+  }
+
+  void eraseOp(MlirOperation op) { mlirRewriterBaseEraseOp(base, op); }
+
+  /// Registers `DerivedTy` as the Python class `DerivedTy::pyClassName`, with
+  /// the rewriter methods shared by all rewriter wrappers.
+  static void bind(nanobind::module_ &m) {
+    nb::class_<DerivedTy>(m, DerivedTy::pyClassName)
+        .def_prop_ro("ip", &DerivedTy::getInsertionPoint,
+                     "The current insertion point of the rewriter.")
+        .def(
+            "replace_op",
+            [](DerivedTy &self, PyOperationBase &op, PyOperationBase &newOp) {
+              self.replaceOp(op.getOperation(), newOp.getOperation());
+            },
+            "Replace an operation with a new operation.", nb::arg("op"),
+            nb::arg("new_op"))
+        .def(
+            "replace_op",
+            [](DerivedTy &self, PyOperationBase &op,
+               const std::vector<PyValue> &values) {
+              std::vector<MlirValue> values_(values.size());
+              std::copy(values.begin(), values.end(), values_.begin());
+              self.replaceOp(op.getOperation(), values_);
+            },
+            "Replace an operation with a list of values.", nb::arg("op"),
+            nb::arg("values"))
+        // Take the op as PyOperationBase, like `replace_op`, so Operations and
+        // OpViews are both accepted without needing a type caster for the raw
+        // `MlirOperation` parameter of `eraseOp`.
+        .def(
+            "erase_op",
+            [](DerivedTy &self, PyOperationBase &op) {
+              self.eraseOp(op.getOperation());
+            },
+            "Erase an operation.", nb::arg("op"));
+  }
+
+private:
+  MlirRewriterBase base;
+  PyMlirContextRef ctx;
+};
+
void populateRewriteSubmodule(nanobind::module_ &m);
-}
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
} // namespace python
} // namespace mlir
diff --git a/mlir/lib/CAPI/Dialect/Transform.cpp b/mlir/lib/CAPI/Dialect/Transform.cpp
index 18d4c075dbb9c..fdb27ec5cdc89 100644
--- a/mlir/lib/CAPI/Dialect/Transform.cpp
+++ b/mlir/lib/CAPI/Dialect/Transform.cpp
@@ -8,9 +8,13 @@
#include "mlir-c/Dialect/Transform.h"
#include "mlir-c/Support.h"
+#include "mlir/CAPI/Dialect/Transform.h"
+#include "mlir/CAPI/Interfaces.h"
#include "mlir/CAPI/Registration.h"
+#include "mlir/CAPI/Rewrite.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
+#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
using namespace mlir;
@@ -126,3 +130,195 @@ MlirStringRef mlirTransformParamTypeGetName(void) {
MlirType mlirTransformParamTypeGetType(MlirType type) {
return wrap(cast<transform::ParamType>(unwrap(type)).getType());
}
+
+//===---------------------------------------------------------------------===//
+// TransformRewriter
+//===---------------------------------------------------------------------===//
+
+/// Upcasts an `MlirTransformRewriter` to the `MlirRewriterBase` it derives
+/// from, so that the generic rewriter C API can operate on it.
+MlirRewriterBase mlirTransformRewriterAsBase(MlirTransformRewriter rewriter) {
+  mlir::RewriterBase *asBase = unwrap(rewriter);
+  return wrap(asBase);
+}
+
+//===---------------------------------------------------------------------===//
+// TransformResults
+//===---------------------------------------------------------------------===//
+
+/// Associates the `numOps` payload operations in `ops` with the transform
+/// handle `result` (which must be an OpResult).
+void mlirTransformResultsSetOps(MlirTransformResults results, MlirValue result,
+                                intptr_t numOps, MlirOperation *ops) {
+  SmallVector<Operation *> payloadOps;
+  payloadOps.reserve(numOps);
+  for (MlirOperation cOp : ArrayRef<MlirOperation>(ops, numOps))
+    payloadOps.push_back(unwrap(cOp));
+  OpResult handle = cast<OpResult>(unwrap(result));
+  unwrap(results)->set(handle, payloadOps);
+}
+
+/// Associates the `numValues` payload values in `values` with the transform
+/// handle `result` (which must be an OpResult).
+void mlirTransformResultsSetValues(MlirTransformResults results,
+                                   MlirValue result, intptr_t numValues,
+                                   MlirValue *values) {
+  SmallVector<Value> payloadValues;
+  payloadValues.reserve(numValues);
+  for (MlirValue cValue : ArrayRef<MlirValue>(values, numValues))
+    payloadValues.push_back(unwrap(cValue));
+  OpResult handle = cast<OpResult>(unwrap(result));
+  unwrap(results)->setValues(handle, payloadValues);
+}
+
+/// Associates the `numParams` attributes in `params` with the transform
+/// parameter handle `result` (which must be an OpResult).
+void mlirTransformResultsSetParams(MlirTransformResults results,
+                                   MlirValue result, intptr_t numParams,
+                                   MlirAttribute *params) {
+  SmallVector<Attribute> paramAttrs;
+  paramAttrs.reserve(numParams);
+  for (MlirAttribute cAttr : ArrayRef<MlirAttribute>(params, numParams))
+    paramAttrs.push_back(unwrap(cAttr));
+  OpResult handle = cast<OpResult>(unwrap(result));
+  unwrap(results)->setParams(handle, paramAttrs);
+}
+
+//===---------------------------------------------------------------------===//
+// TransformState
+//===---------------------------------------------------------------------===//
+
+/// Invokes `callback` once for every payload operation mapped to `value`.
+void mlirTransformStateForEachPayloadOp(MlirTransformState state,
+                                        MlirValue value,
+                                        MlirOperationCallback callback,
+                                        void *userData) {
+  transform::TransformState &transformState = *unwrap(state);
+  Value handle = unwrap(value);
+  for (Operation *payloadOp : transformState.getPayloadOps(handle))
+    callback(wrap(payloadOp), userData);
+}
+
+/// Invokes `callback` once for every payload value mapped to `value`.
+void mlirTransformStateForEachPayloadValue(MlirTransformState state,
+                                           MlirValue value,
+                                           MlirValueCallback callback,
+                                           void *userData) {
+  transform::TransformState &transformState = *unwrap(state);
+  Value handle = unwrap(value);
+  for (Value payloadValue : transformState.getPayloadValues(handle))
+    callback(wrap(payloadValue), userData);
+}
+
+/// Invokes `callback` once for every parameter attribute mapped to `value`.
+void mlirTransformStateForEachParam(MlirTransformState state, MlirValue value,
+                                    MlirAttributeCallback callback,
+                                    void *userData) {
+  transform::TransformState &transformState = *unwrap(state);
+  Value handle = unwrap(value);
+  for (Attribute paramAttr : transformState.getParams(handle))
+    callback(wrap(paramAttr), userData);
+}
+
+//===---------------------------------------------------------------------===//
+// TransformOpInterface
+//===---------------------------------------------------------------------===//
+
+/// Returns the TypeID identifying the TransformOpInterface.
+MlirTypeID mlirTransformOpInterfaceTypeID(void) {
+  TypeID interfaceId = transform::TransformOpInterface::getInterfaceID();
+  return wrap(interfaceId);
+}
+
+/// Fallback model for the TransformOpInterface that forwards each interface
+/// method to C API callbacks.
+class TransformOpInterfaceFallbackModel
+    : public mlir::transform::TransformOpInterface::FallbackModel<
+          TransformOpInterfaceFallbackModel> {
+public:
+  /// Sets the callbacks that this FallbackModel will use.
+  /// NB: the callbacks can only be set through this method as the
+  /// RegisteredOperationName::attachInterface mechanism default-constructs
+  /// the FallbackModel without being able to provide arguments.
+  void setCallbacks(MlirTransformOpInterfaceCallbacks callbacks) {
+    this->callbacks = callbacks;
+  }
+
+  /// Releases the callbacks' user data through `destruct`, if provided.
+  /// NOTE(review): confirm the op's interface map actually runs this
+  /// destructor when it releases the model.
+  ~TransformOpInterfaceFallbackModel() {
+    if (callbacks.destruct)
+      callbacks.destruct(callbacks.userData);
+  }
+
+  static TypeID getInterfaceID() {
+    return transform::TransformOpInterface::getInterfaceID();
+  }
+
+  /// Unconditionally true, to enable casting back to the FallbackModel from
+  /// the Concept returned by getInterface(). This is necessary as
+  /// attachInterface(...) default-constructs the FallbackModel without being
+  /// able to pass in the callbacks. Callers must make sure the Concept really
+  /// is this model before casting.
+  static bool classof(const mlir::transform::detail::
+                          TransformOpInterfaceInterfaceTraits::Concept *op) {
+    return true;
+  }
+
+  /// Forwards to the `apply` callback and converts its C status code.
+  ::mlir::DiagnosedSilenceableFailure
+  apply(Operation *op, ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::transform::TransformResults &transformResults,
+        ::mlir::transform::TransformState &state) const {
+    assert(callbacks.apply && "apply callback not set");
+
+    MlirDiagnosedSilenceableFailure status =
+        callbacks.apply(wrap(op), wrap(&rewriter), wrap(&transformResults),
+                        wrap(&state), callbacks.userData);
+
+    switch (status) {
+    case MlirDiagnosedSilenceableFailureSuccess:
+      return DiagnosedSilenceableFailure::success();
+    case MlirDiagnosedSilenceableFailureSilenceableFailure: {
+      // Build the Diagnostic directly rather than moving it out of
+      // `op->emitError()`: that InFlightDiagnostic would still report its
+      // (emptied) diagnostic when destroyed, whereas a silenceable failure
+      // must only be reported if the caller decides to.
+      // TODO: enable passing diagnostic info from C API to C++ API.
+      Diagnostic diag(op->getLoc(), DiagnosticSeverity::Error);
+      diag << "TransformOpInterfaceFallbackModel: silenceable failure";
+      return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
+    }
+    case MlirDiagnosedSilenceableFailureDefiniteFailure:
+      return DiagnosedSilenceableFailure::definiteFailure();
+    }
+    llvm_unreachable("unknown transform status");
+  }
+
+  /// Forwards to the `allowsRepeatedHandleOperands` callback.
+  bool allowsRepeatedHandleOperands(Operation *op) const {
+    assert(callbacks.allowsRepeatedHandleOperands &&
+           "allowsRepeatedHandleOperands callback not set");
+    return callbacks.allowsRepeatedHandleOperands(wrap(op), callbacks.userData);
+  }
+
+private:
+  /// Value-initialized so every callback is null (and the destructor is
+  /// well-defined) until setCallbacks() is called.
+  MlirTransformOpInterfaceCallbacks callbacks = {};
+};
+
+/// Attach a TransformOpInterface FallbackModel to the given named operation.
+/// The FallbackModel uses the provided callbacks to implement the interface.
+/// If the model cannot be attached (unknown operation, or the operation
+/// already has a TransformOpInterface implementation) an error is printed and
+/// `callbacks.destruct`, if set, is called to release `callbacks.userData`.
+void mlirTransformOpInterfaceAttachFallbackModel(
+    MlirContext ctx, MlirStringRef opName,
+    MlirTransformOpInterfaceCallbacks callbacks) {
+  // Releases `userData` for callbacks that will not be installed: no model
+  // will ever call `destruct` for them otherwise.
+  auto discardCallbacks = [&callbacks]() {
+    if (callbacks.destruct)
+      callbacks.destruct(callbacks.userData);
+  };
+
+  // Look up the operation definition in the context.
+  std::optional<RegisteredOperationName> opInfo =
+      RegisteredOperationName::lookup(unwrap(opName), unwrap(ctx));
+  if (!opInfo.has_value()) {
+    llvm::errs() << "Operation '" << unwrap(opName)
+                 << "' not found in context\n";
+    discardCallbacks();
+    return;
+  }
+
+  // attachInterface() ignores a repeated registration of an interface (the
+  // interface map drops the new model). getInterface() would then return the
+  // pre-existing model, the always-true classof() would let the cast below
+  // succeed, and setCallbacks() would write into an object of another type.
+  if (opInfo->hasInterface(transform::TransformOpInterface::getInterfaceID())) {
+    llvm::errs() << "Operation '" << unwrap(opName)
+                 << "' already implements TransformOpInterface\n";
+    discardCallbacks();
+    return;
+  }
+
+  // NB: the following default-constructs the FallbackModel _without_ being able
+  // to provide arguments.
+  opInfo->attachInterface<TransformOpInterfaceFallbackModel>();
+  // Cast to get the underlying FallbackModel and set the callbacks.
+  auto *model = cast<TransformOpInterfaceFallbackModel>(
+      opInfo->getInterface<TransformOpInterfaceFallbackModel>());
+
+  assert(model && "Failed to get TransformOpInterfaceFallbackModel");
+  model->setCallbacks(callbacks);
+}
+
+//===---------------------------------------------------------------------===//
+// MemoryEffectsOpInterface helpers
+//===---------------------------------------------------------------------===//
+
+/// Record, in `effects`, that each of the given operands only reads its
+/// transform handle.
+void mlirTransformOnlyReadsHandle(MlirOpOperand *operands, intptr_t numOperands,
+                                  MlirMemoryEffectInstancesList effects) {
+  // The caller's OpOperands need not be contiguous in memory (nor belong to
+  // the same operation), so they cannot be viewed as a single
+  // MutableArrayRef<OpOperand> starting at the first one. Pass each operand as
+  // a one-element range instead; this also never touches `operands` when
+  // `numOperands` is zero.
+  for (intptr_t i = 0; i < numOperands; ++i) {
+    OpOperand &operand = *unwrap(operands[i]);
+    transform::onlyReadsHandle(MutableArrayRef<OpOperand>(operand),
+                               *unwrap(effects));
+  }
+}
+
+/// Record, in `effects`, that each of the given results produces a transform
+/// handle.
+void mlirTransformProducesHandle(MlirOpResult *results, intptr_t numResults,
+ MlirMemoryEffectInstancesList effects) {
+ // NB: calling `producesHandle()` `numResults` times, once per result, as we
+ // cannot cast an array of `OpResult`s to a single `ResultRange` (and neither
+ // is `ResultRange` exposed to Python). `producesHandle` iterates over the
+ // given `ResultRange` anyway.
+ // NOTE(review): relies on `ResultRange(OpResult)` spanning exactly one
+ // result, and on every `results[i]` pointing at a live `OpResult` - confirm.
+ for (intptr_t i = 0; i < numResults; ++i)
+ transform::producesHandle(ResultRange(*unwrap(results[i])),
+ *unwrap(effects));
+}
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index 09666932004a4..ccfd30e2354c5 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -30,6 +30,7 @@
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Parser/Parser.h"
#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/ThreadPool.h"
#include <cstddef>
@@ -714,6 +715,10 @@ MlirValue mlirOperationGetOperand(MlirOperation op, intptr_t pos) {
return wrap(unwrap(op)->getOperand(static_cast<unsigned>(pos)));
}
+/// Returns a handle to the `pos`-th OpOperand owned by `op`.
+MlirOpOperand mlirOperationGetOpOperand(MlirOperation op, intptr_t pos) {
+  OpOperand &operand = unwrap(op)->getOpOperand(static_cast<unsigned>(pos));
+  return wrap(&operand);
+}
+
void mlirOperationSetOperand(MlirOperation op, intptr_t pos,
MlirValue newValue) {
unwrap(op)->setOperand(static_cast<unsigned>(pos), unwrap(newValue));
@@ -1121,6 +1126,13 @@ bool mlirValueIsAOpResult(MlirValue value) {
return llvm::isa<OpResult>(unwrap(value));
}
+/// Cast an MlirValue to an MlirOpResult, asserting in case of a type mismatch.
+/// NOTE(review): the returned handle wraps `&opResult`, the address of the
+/// lambda's by-value parameter, so it appears to dangle as soon as this
+/// function returns (e.g. if later dereferenced by
+/// mlirTransformProducesHandle) - confirm what MlirOpResult should point at.
+MlirOpResult mlirValueToOpResult(MlirValue value) {
+ return TypeSwitch<Value, MlirOpResult>(unwrap(value))
+ .Case<OpResult>([&](OpResult opResult) { return wrap(&opResult); })
+ .DefaultUnreachable("expected an OpResult");
+}
+
MlirBlock mlirBlockArgumentGetOwner(MlirValue value) {
return wrap(llvm::dyn_cast<BlockArgument>(unwrap(value)).getOwner());
}
diff --git a/mlir/lib/CAPI/Interfaces/Interfaces.cpp b/mlir/lib/CAPI/Interfaces/Interfaces.cpp
index ef3fc23869550..50f007f5d15f6 100644
--- a/mlir/lib/CAPI/Interfaces/Interfaces.cpp
+++ b/mlir/lib/CAPI/Interfaces/Interfaces.cpp
@@ -167,3 +167,77 @@ MlirLogicalResult mlirInferShapedTypeOpInterfaceInferReturnTypes(
}
return mlirLogicalResultSuccess();
}
+
+//===---------------------------------------------------------------------===//
+// MemoryEffectOpInterface
+//===---------------------------------------------------------------------===//
+
+/// Returns the TypeID identifying the MemoryEffectOpInterface.
+MlirTypeID mlirMemoryEffectsOpInterfaceTypeID() {
+  TypeID interfaceId = MemoryEffectOpInterface::getInterfaceID();
+  return wrap(interfaceId);
+}
+
+/// Fallback model for the MemoryEffectsOpInterface that forwards `getEffects`
+/// to C API callbacks.
+class MemoryEffectOpInterfaceFallbackModel
+    : public mlir::MemoryEffectOpInterface::FallbackModel<
+          MemoryEffectOpInterfaceFallbackModel> {
+public:
+  /// Sets the callbacks that this FallbackModel will use.
+  /// NB: the callbacks can only be set through this method as the
+  /// RegisteredOperationName::attachInterface mechanism default-constructs
+  /// the FallbackModel without being able to provide arguments.
+  void setCallbacks(MlirMemoryEffectsOpInterfaceCallbacks callbacks) {
+    this->callbacks = callbacks;
+  }
+
+  /// Releases the callbacks' user data through `destruct`, if provided.
+  /// NOTE(review): confirm the op's interface map actually runs this
+  /// destructor when it releases the model.
+  ~MemoryEffectOpInterfaceFallbackModel() {
+    if (callbacks.destruct)
+      callbacks.destruct(callbacks.userData);
+  }
+
+  static TypeID getInterfaceID() {
+    return MemoryEffectOpInterface::getInterfaceID();
+  }
+
+  /// Unconditionally true, to enable casting back to the FallbackModel from
+  /// the Concept returned by getInterface(). This is necessary as
+  /// attachInterface(...) default-constructs the FallbackModel without being
+  /// able to pass in the callbacks. Callers must make sure the Concept really
+  /// is this model before casting.
+  static bool classof(const mlir::MemoryEffectOpInterface::Concept *op) {
+    return true;
+  }
+
+  /// Forwards to the `getEffects` callback, passing `effects` as a C handle.
+  void
+  getEffects(Operation *op,
+             SmallVectorImpl<MemoryEffects::EffectInstance> &effects) const {
+    assert(callbacks.getEffects && "getEffects callback not set");
+    MlirMemoryEffectInstancesList cEffects = wrap(&effects);
+    callbacks.getEffects(wrap(op), cEffects, callbacks.userData);
+  }
+
+private:
+  /// Value-initialized so every callback is null (and the destructor is
+  /// well-defined) until setCallbacks() is called.
+  MlirMemoryEffectsOpInterfaceCallbacks callbacks = {};
+};
+
+/// Attach a MemoryEffectsOpInterface FallbackModel to the given named op.
+/// The FallbackModel uses the provided callbacks to implement the interface.
+/// If the model cannot be attached (unknown operation, or the operation
+/// already has a MemoryEffectOpInterface implementation) an error is printed
+/// and `callbacks.destruct`, if set, is called to release `callbacks.userData`.
+void mlirMemoryEffectsOpInterfaceAttachFallbackModel(
+    MlirContext ctx, MlirStringRef opName,
+    MlirMemoryEffectsOpInterfaceCallbacks callbacks) {
+  // Releases `userData` for callbacks that will not be installed: no model
+  // will ever call `destruct` for them otherwise.
+  auto discardCallbacks = [&callbacks]() {
+    if (callbacks.destruct)
+      callbacks.destruct(callbacks.userData);
+  };
+
+  // Look up the operation definition in the context.
+  std::optional<RegisteredOperationName> opInfo =
+      RegisteredOperationName::lookup(unwrap(opName), unwrap(ctx));
+  if (!opInfo.has_value()) {
+    llvm::errs() << "Operation '" << unwrap(opName)
+                 << "' not found in context\n";
+    discardCallbacks();
+    return;
+  }
+
+  // attachInterface() ignores a repeated registration of an interface (the
+  // interface map drops the new model). getInterface() would then return the
+  // op's existing model, the always-true classof() would let the cast below
+  // succeed, and setCallbacks() would write into an object of another type.
+  if (opInfo->hasInterface(MemoryEffectOpInterface::getInterfaceID())) {
+    llvm::errs() << "Operation '" << unwrap(opName)
+                 << "' already implements MemoryEffectOpInterface\n";
+    discardCallbacks();
+    return;
+  }
+
+  // NB: the following default-constructs the FallbackModel _without_ being able
+  // to provide arguments.
+  opInfo->attachInterface<MemoryEffectOpInterfaceFallbackModel>();
+  // Cast to get the underlying FallbackModel and set the callbacks.
+  auto *model = cast<MemoryEffectOpInterfaceFallbackModel>(
+      opInfo->getInterface<MemoryEffectOpInterfaceFallbackModel>());
+  assert(model && "Failed to get MemoryEffectOpInterfaceFallbackModel");
+  model->setCallbacks(callbacks);
+}
diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
index 2a37f3860fe00..c53c59d87d039 100644
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -575,7 +575,7 @@ MutableOperandRange MutableOperandRangeRange::dereference(const OwnerT &object,
ResultRange::ResultRange(OpResult result)
: ResultRange(static_cast<detail::OpResultImpl *>(Value(result).getImpl()),
- 1) {}
+ /*count=*/1) {}
ResultRange::use_range ResultRange::getUses() const {
return {use_begin(), use_end()};
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 003a06b16daac..748b881efa1dd 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -670,8 +670,10 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.Transform.Nanobind
ROOT_DIR "${PYTHON_SOURCE_DIR}"
SOURCES
DialectTransform.cpp
+ Rewrite.h
PRIVATE_LINK_LIBS
LLVMSupport
+ MLIRPythonExtension.Core
EMBED_CAPI_LINK_LIBS
MLIRCAPIIR
MLIRCAPITransformDialect
diff --git a/mlir/python/mlir/_mlir_libs/__init__.py b/mlir/python/mlir/_mlir_libs/__init__.py
index c0e8775149d41..303616080a75b 100644
--- a/mlir/python/mlir/_mlir_libs/__init__.py
+++ b/mlir/python/mlir/_mlir_libs/__init__.py
@@ -242,6 +242,7 @@ def __str__(self):
Sequence.register(ir.BlockPredecessors)
Sequence.register(ir.OperationList)
Sequence.register(ir.OpOperandList)
+ Sequence.register(ir.OpOpOperandList)
Sequence.register(ir.OpResultList)
Sequence.register(ir.OpSuccessors)
Sequence.register(ir.RegionSequence)
diff --git a/mlir/test/python/dialects/transform_op_interface.py b/mlir/test/python/dialects/transform_op_interface.py
new file mode 100644
index 0000000000000..85e9ec89a3de3
--- /dev/null
+++ b/mlir/test/python/dialects/transform_op_interface.py
@@ -0,0 +1,252 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+from contextlib import contextmanager
+
+from mlir import ir
+from mlir.ir import TypeAttr, F32Type, UnitAttr
+from mlir.dialects import transform, irdl, func
+from mlir.dialects.transform import AnyOpType, AnyValueType, AnyParamType, structured, interpreter
+
+from mlir.dialects._ods_common import get_op_result_or_value, get_op_results_or_values
+
+
+def run(emit_schedule):
+    """Load the `my_transform` IRDL dialect and interpret `emit_schedule()`.
+
+    Applied as a decorator, so the decorated builder runs immediately inside a
+    fresh Context. It must return a module whose first op is the named sequence
+    to apply to the shared payload built by `emit_payload()`.
+    """
+    with ir.Context(), ir.Location.unknown():
+        irdl_module = ir.Module.create()
+        with ir.InsertionPoint(irdl_module.body):
+            dialect_op = irdl.dialect("my_transform")
+            with ir.InsertionPoint(dialect_op.body):
+                for op_def in (
+                    OneOpInOneOpOut,
+                    OpValParamInParamOpValOut,
+                    OpsParamInParamsOut,
+                ):
+                    op_def.emit_irdl()
+        irdl_module.operation.verify()
+        irdl.load_dialects(irdl_module)
+
+        payload, schedule = emit_payload(), emit_schedule()
+        print("payload:", payload)
+        print("schedule:", schedule)
+        named_seq = schedule.operation.regions[0].blocks[0].operations[0]
+        interpreter.apply_named_sequence(payload, named_seq, schedule)
+        # Drop the module references before leaving the Context.
+        del payload, schedule
+
+
+def emit_payload():
+    """Build the payload module shared by all tests.
+
+    It holds `func.func @name_of_func(f32, f32) -> f32`, returning its 2nd arg.
+    """
+    f32 = F32Type.get()
+    payload_module = ir.Module.create()
+    with ir.InsertionPoint(payload_module.body):
+        @func.FuncOp.from_py_func(f32, f32, results=[f32])
+        def name_of_func(a, b):
+            func.ReturnOp([b])
+    return payload_module
+
+
+class OneOpInOneOpOut:
+    """Transform op `my_transform.one_op_in_one_op_out`: one op handle in and out."""
+    name = "one_op_in_one_op_out"
+
+    def __init__(self, op_arg: AnyOpType):
+        # Build the full name from `name` so it cannot drift from `emit_irdl`.
+        self.op = ir.Operation.create(
+            "my_transform." + self.name,
+            [AnyOpType.get()],
+            [get_op_result_or_value(op_arg)],
+        )
+
+    @property
+    def result(self):
+        """The single `AnyOpType` result handle."""
+        return self.op.results[0]
+
+    @classmethod
+    def emit_irdl(cls):
+        """Emit this op's `irdl.operation` definition and return it."""
+        op = irdl.operation_(cls.name)
+        with ir.InsertionPoint(op.body):
+            any_op = irdl.is_(TypeAttr.get(AnyOpType.get()))
+            irdl.operands_([any_op], ["arg"], [irdl.Variadicity.single])
+            irdl.results_([any_op], ["result"], [irdl.Variadicity.single])
+        return op
+
+
+class OpValParamInParamOpValOut:
+    """Transform op taking (op, value, param) handles and yielding three handles.
+
+    The results come back as (param, op, value), i.e. in a different order
+    from the operands.
+    """
+
+    name = "op_val_param_in_param_op_val_out"
+
+    def __init__(
+        self,
+        op_arg: AnyOpType,
+        val: AnyValueType,
+        param: AnyParamType,
+    ):
+        """Create the op from handles accepted by `get_op_result_or_value`."""
+        operands = [get_op_result_or_value(v) for v in (op_arg, val, param)]
+        # Result types in (param, op, value) order, matching `emit_irdl`.
+        result_types = [AnyParamType.get(), AnyOpType.get(), AnyValueType.get()]
+        self.op = ir.Operation.create(
+            "my_transform." + self.name,
+            result_types,
+            operands,
+        )
+
+    @property
+    def param_res(self):
+        """Result #0, the param handle."""
+        return self.op.results[0]
+
+    @property
+    def op_res(self):
+        """Result #1, the op handle."""
+        return self.op.results[1]
+
+    @property
+    def value_res(self):
+        """Result #2, the value handle."""
+        return self.op.results[2]
+
+    @classmethod
+    def emit_irdl(cls):
+        """Emit this op's `irdl.operation` definition and return it."""
+        single = irdl.Variadicity.single
+        op = irdl.operation_(cls.name)
+        with ir.InsertionPoint(op.body):
+            any_op = irdl.is_(TypeAttr.get(AnyOpType.get()))
+            any_value = irdl.is_(TypeAttr.get(AnyValueType.get()))
+            any_param = irdl.is_(TypeAttr.get(AnyParamType.get()))
+            # Every operand and result is a single (non-variadic) handle.
+            irdl.operands_(
+                [any_op, any_value, any_param],
+                ["op_arg", "value_arg", "param_arg"],
+                [single] * 3,
+            )
+            irdl.results_(
+                [any_param, any_op, any_value],
+                ["param_res", "op_res", "value_res"],
+                [single] * 3,
+            )
+        return op
+
+
+class OpsParamInParamsOut:
+    """Transform op: variadic op handles plus one param in, variadic params out."""
+
+    name = "ops_param_in_params_out"
+
+    def __init__(self, ops: list[AnyOpType], param: AnyParamType):
+        # Operation.create needs a flat operand list: splice the variadic `ops`
+        # handles in ahead of the single trailing `param` (the order in emit_irdl).
+        self.op = ir.Operation.create(
+            "my_transform." + self.name,
+            [AnyParamType.get()],
+            [*get_op_results_or_values(ops), get_op_result_or_value(param)],
+        )
+
+    @property
+    def param_results(self):
+        """All param results (currently always exactly one, see `__init__`)."""
+        return self.op.results
+
+    @classmethod
+    def emit_irdl(cls):
+        """Emit this op's `irdl.operation` definition and return it."""
+        op = irdl.operation_(cls.name)
+        with ir.InsertionPoint(op.body):
+            any_op = irdl.is_(TypeAttr.get(AnyOpType.get()))
+            any_param = irdl.is_(TypeAttr.get(AnyParamType.get()))
+            irdl.operands_(
+                [any_op, any_param],
+                ["op_args", "param_arg"],
+                [irdl.Variadicity.variadic, irdl.Variadicity.single],
+            )
+            irdl.results_([any_param], ["param_results"], [irdl.Variadicity.variadic])
+        return op
+
+
+@contextmanager
+def schedule_boilerplate():
+    """Yield (schedule, named_sequence) with insertion inside `__transform_main`."""
+    schedule = ir.Module.create()
+    schedule.operation.attributes["transform.with_named_sequence"] = UnitAttr.get()
+    with ir.InsertionPoint(schedule.body):
+        named_sequence = transform.NamedSequenceOp(
+            "__transform_main", [AnyOpType.get()], [AnyOpType.get()],
+            arg_attrs=[{"transform.consumed": UnitAttr.get()}],
+        )
+        with ir.InsertionPoint(named_sequence.body):
+            # The caller is expected to emit the terminator (e.g. transform.YieldOp).
+            yield schedule, named_sequence
+
+
+@run
+def OneOpInOneOpOutTransformOpInterface():
+    """Attach Python interface models to one_op_in_one_op_out; return the schedule."""
+    op_name = "my_transform." + OneOpInOneOpOut.name
+
+    class OneOpInOneOpOutTransformOpInterfaceFallbackModel(
+        transform.TransformOpInterface
+    ):
+        @staticmethod
+        def apply(
+            op_: ir.Operation,
+            rewriter: transform.TransformRewriter,
+            results: transform.TransformResults,
+            state: transform.TransformState,
+        ):
+            # Map the op's single result to the payload ops of its single operand.
+            targets = state.get_payload_ops(op_.operands[0])
+            target_names = [t.opview.name.value for t in targets]
+            print(
+                f"OneOpInOneOpOutTransformOpInterfaceFallbackModel: target_names={target_names}"
+            )
+            results.set_ops(op_.results[0], targets)
+            return transform.DiagnosedSilenceableFailure.Success
+
+        @staticmethod
+        def allow_repeated_handle_operands(op_: ir.Operation) -> bool:
+            return False
+
+    OneOpInOneOpOutTransformOpInterfaceFallbackModel.attach(op_name, ir.Context.current)
+
+    # TransformOpInterface-implementing ops are also required to implement MemoryEffectsOpInterface.
+    class MemoryEffectsOpInterfaceFallbackModel(ir.MemoryEffectsOpInterface):
+        @staticmethod
+        def get_effects(op_: ir.Operation, effects):
+            transform.only_reads_handle(list(op_.op_operands), effects)
+            transform.produces_handle(list(op_.results), effects)
+
+    MemoryEffectsOpInterfaceFallbackModel.attach(op_name, ir.Context.current)
+
+    with schedule_boilerplate() as (schedule, named_seq):
+        print(f"{named_seq=}")
+        func_handle = structured.MatchOp.match_op_names(
+            named_seq.bodyTarget, ["func.func"]
+        ).result
+        func_handle.dump()
+        # CHECK: OneOpInOneOpOutTransformOpInterfaceFallbackModel: target_names=['name_of_func']
+        out = OneOpInOneOpOut(func_handle).result
+        out.dump()
+        print(out.owner)
+        transform.YieldOp([out])
+        named_seq.verify()
+        print("named_seq", named_seq)
+        print("named_seq.parent", named_seq.parent)
+
+    return schedule
More information about the Mlir-commits
mailing list