[Mlir-commits] [mlir] [MLIR][Python] Impl XOpInterface(s) from Python, with X=Transform and X=MemoryEffects (PR #176920)
llvmlistbot at llvm.org
Tue Jan 27 12:57:54 PST 2026
llvmbot wrote:
@llvm/pr-subscribers-mlir-core
Author: Rolf Morel (rolfmorel)
<details>
<summary>Changes</summary>
Provides the infrastructure for implementing and late-binding OpInterfaces from Python.
* On the MLIR C API declaration side, each `XOpInterface` gets a callbacks struct and a `mlirXOpInterfaceAttachFallbackModel(ctx, opName, callbacks)` function. The struct holds one callback per interface method, optional `construct`/`destruct` callbacks, and a `userData` member that is passed as the last argument to each callback.
  * For `MemoryEffectsOpInterface`: https://github.com/llvm/llvm-project/pull/176920/changes#diff-da9b0ad13ba75167cca65c45470dd4063bceca4603f6b6c82e3e01179b7b566bR107-R123
  * For `TransformOpInterface`: https://github.com/llvm/llvm-project/pull/176920/changes#diff-8c29f5db66b6d811f22ed39e91578a320fd8053e317ce3d8bfeafcd5d86d9469R185-R208
* This C API is implemented by a subclass of `XOpInterface::FallbackModel` that stores the callbacks struct and forwards each interface method to the corresponding callback, passing `userData` along. Given a callbacks struct, a new `FallbackModel` is created and attached, i.e. late-bound, to the named op. (MLIR's interface infrastructure then returns this registered `FallbackModel` whenever the op is cast to `XOpInterface`.)
  * For `MemoryEffectsOpInterface`: https://github.com/llvm/llvm-project/pull/176920/changes#diff-e599b77dd395948103500ae86e792c37d622af721bb1bcdacada921b20250e4eR179-R243
  * For `TransformOpInterface`: https://github.com/llvm/llvm-project/pull/176920/changes#diff-07f07055f38e2bcd03357d9be81bd563c51d499a094c418816d39faa145bd11eR213-R302
* On the Python side, we expose a stand-in `XOpInterface` base class with a single classmethod, `XOpInterface.attach(cls, op_name, ctx)`. Users subclass it (`class MyInterfaceImpl(XOpInterface): ...`) and implement the interface's methods with the expected names and signatures. Calling `attach` on the subclass (`MyInterfaceImpl.attach("my_dialect.my_op", ctx)`) prepares the callbacks struct _with `userData` set to the subclass_, which is then used to look up the methods. `mlirXOpInterfaceAttachFallbackModel(...)` registers these callbacks (and `userData`) as an `XOpInterface::FallbackModel`. From then on, calls to the interface methods originating in C++ are answered by the Python methods.
  * For `MemoryEffectsOpInterface`: https://github.com/llvm/llvm-project/pull/176920/changes#diff-2400af23004f6bad0f7c3193cccaa544816b7e624c3513f9afe5939222236039R343-R389
  * For `TransformOpInterface`: https://github.com/llvm/llvm-project/pull/176920/changes#diff-22bf2dd9a1ae598e9ee814d2ed92f3e394d0394844fa8ac3330b63c2bb131a17R159-R238
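To make the late-binding flow concrete, here is a pure-Python mock of the three pieces described above: the callbacks struct, a `FallbackModel` that forwards to it, and an `attach` classmethod that passes the subclass itself as `userData`. This is a sketch under stated assumptions, not the PR's code. `CallbacksStruct`, `FallbackModel`, `REGISTRY`, `attach_fallback_model`, `cast_to_interface` and the `get_effects` method name are hypothetical stand-ins, and no real `mlir` module is imported.

```python
# Hypothetical pure-Python mock of the C-callbacks / FallbackModel scheme.
# None of these names are the real MLIR bindings API.
from dataclasses import dataclass
from typing import Any, Callable, Optional


@dataclass
class CallbacksStruct:
    # Mirrors the C struct: optional construct/destruct hooks, one callback
    # per interface method, and an opaque user_data slot.
    construct: Optional[Callable[[Any], None]]
    destruct: Optional[Callable[[Any], None]]
    get_effects: Callable[[str, list, Any], None]
    user_data: Any


class FallbackModel:
    """Stands in for XOpInterface::FallbackModel: forwards each interface
    method to the stored callback, passing user_data along."""

    def __init__(self, callbacks: CallbacksStruct):
        self.callbacks = callbacks
        if callbacks.construct is not None:
            callbacks.construct(callbacks.user_data)
        # Destruction (calling callbacks.destruct) is not modelled here.

    def get_effects(self, op_name: str, effects: list) -> None:
        self.callbacks.get_effects(op_name, effects, self.callbacks.user_data)


# Maps (context, op name) to the attached model (hypothetical registry).
REGISTRY: dict = {}


def attach_fallback_model(ctx: Any, op_name: str, callbacks: CallbacksStruct) -> None:
    REGISTRY[(ctx, op_name)] = FallbackModel(callbacks)


def cast_to_interface(ctx: Any, op_name: str) -> Optional[FallbackModel]:
    # Returns the late-bound model, or None if nothing was attached.
    return REGISTRY.get((ctx, op_name))


class MemoryEffectsOpInterface:
    """Stand-in base class with a single classmethod, attach."""

    @classmethod
    def attach(cls, op_name: str, ctx: Any) -> None:
        def get_effects_trampoline(op: str, effects: list, user_data: Any) -> None:
            # user_data is the subclass; look up the user's method on it.
            user_data.get_effects(op, effects)

        attach_fallback_model(
            ctx,
            op_name,
            CallbacksStruct(construct=None, destruct=None,
                            get_effects=get_effects_trampoline, user_data=cls),
        )


class MyEffects(MemoryEffectsOpInterface):
    # Assumption: the user's method is a staticmethod looked up on the class.
    @staticmethod
    def get_effects(op_name: str, effects: list) -> None:
        effects.append(("read", op_name))


ctx = object()
MyEffects.attach("my_dialect.my_op", ctx)
effects: list = []
cast_to_interface(ctx, "my_dialect.my_op").get_effects("my_dialect.my_op", effects)
print(effects)  # [('read', 'my_dialect.my_op')]
```

In this mock, passing the subclass as `userData` means a single trampoline per interface method is enough: the user's implementation is looked up on the class each time the interface method is called.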
This PR enables implementing the TransformOpInterface and the MemoryEffectsOpInterface, both of which are required for making an op into a transform op.
All code besides the parts linked above exists to support exposing these interfaces: it exposes the types of the interface methods' arguments, along with functions/methods for manipulating those arguments (e.g. declaring side effects on `OpOperand`s and `OpResult`s, reading the payload associated with operand handles, and setting the payload of result handles).
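As an illustration of declaring side effects on handles, the mock below mirrors the `mlirTransformOnlyReadsHandle`, `mlirTransformConsumesHandle` and `mlirTransformProducesHandle` helpers declared in the diff. The effect kinds follow our understanding of the corresponding C++ `onlyReadsHandle`/`consumesHandle`/`producesHandle` helpers (read; read + free; allocate + write, all on the transform mapping resource). `EffectInstance`, the snake_case helper names and the string-valued handles are hypothetical.

```python
# Hypothetical mock of transform-specific memory-effect helpers.
# Effect kinds reflect our understanding of the C++ helpers; names are illustrative.
from dataclasses import dataclass
from typing import List

TRANSFORM_MAPPING = "TransformMappingResource"


@dataclass(frozen=True)
class EffectInstance:
    effect: str    # "read", "free", "allocate" or "write"
    target: str    # name of the operand or result handle
    resource: str  # resource the effect applies to


def only_reads_handle(operands: List[str], effects: List[EffectInstance]) -> None:
    # Reading a handle leaves it usable afterwards.
    for operand in operands:
        effects.append(EffectInstance("read", operand, TRANSFORM_MAPPING))


def consumes_handle(operands: List[str], effects: List[EffectInstance]) -> None:
    # Consuming a handle reads it and then invalidates ("frees") it.
    for operand in operands:
        effects.append(EffectInstance("read", operand, TRANSFORM_MAPPING))
        effects.append(EffectInstance("free", operand, TRANSFORM_MAPPING))


def produces_handle(results: List[str], effects: List[EffectInstance]) -> None:
    # Producing a handle allocates it and writes its payload mapping.
    for result in results:
        effects.append(EffectInstance("allocate", result, TRANSFORM_MAPPING))
        effects.append(EffectInstance("write", result, TRANSFORM_MAPPING))


def get_effects(operands: List[str], results: List[str],
                effects: List[EffectInstance]) -> None:
    # A hypothetical transform op that consumes its operands and produces results.
    consumes_handle(operands, effects)
    produces_handle(results, effects)


effects: List[EffectInstance] = []
get_effects(["%target"], ["%res"], effects)
print([(e.effect, e.target) for e in effects])
# [('read', '%target'), ('free', '%target'), ('allocate', '%res'), ('write', '%res')]
```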
---
Patch is 91.71 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/176920.diff
22 Files Affected:
- (modified) mlir/include/mlir-c/Dialect/Transform.h (+140)
- (modified) mlir/include/mlir-c/IR.h (+8)
- (modified) mlir/include/mlir-c/Interfaces.h (+37-2)
- (modified) mlir/include/mlir/Bindings/Python/IRCore.h (+10-1)
- (added) mlir/include/mlir/CAPI/Dialect/Transform.h (+28)
- (modified) mlir/include/mlir/CAPI/IR.h (+1)
- (modified) mlir/include/mlir/CAPI/Interfaces.h (+8)
- (modified) mlir/lib/Bindings/Python/DialectTransform.cpp (+272-2)
- (modified) mlir/lib/Bindings/Python/Globals.cpp (+1-2)
- (modified) mlir/lib/Bindings/Python/IRCore.cpp (+45-8)
- (modified) mlir/lib/Bindings/Python/IRInterfaces.cpp (+58-127)
- (added) mlir/lib/Bindings/Python/IRInterfaces.h (+156)
- (modified) mlir/lib/Bindings/Python/Rewrite.cpp (+7-57)
- (modified) mlir/lib/Bindings/Python/Rewrite.h (+69-2)
- (modified) mlir/lib/CAPI/Dialect/Transform.cpp (+203)
- (modified) mlir/lib/CAPI/IR/IR.cpp (+12)
- (modified) mlir/lib/CAPI/Interfaces/Interfaces.cpp (+74)
- (modified) mlir/lib/IR/OperationSupport.cpp (+1-1)
- (modified) mlir/python/CMakeLists.txt (+2)
- (modified) mlir/python/mlir/_mlir_libs/__init__.py (+1)
- (modified) mlir/python/mlir/dialects/ext.py (+20-13)
- (added) mlir/test/python/dialects/transform_op_interface.py (+494)
``````````diff
diff --git a/mlir/include/mlir-c/Dialect/Transform.h b/mlir/include/mlir-c/Dialect/Transform.h
index 911c9ef659a1e..674d5f8b7b72d 100644
--- a/mlir/include/mlir-c/Dialect/Transform.h
+++ b/mlir/include/mlir-c/Dialect/Transform.h
@@ -11,6 +11,8 @@
#define MLIR_C_DIALECT_TRANSFORM_H
#include "mlir-c/IR.h"
+#include "mlir-c/Interfaces.h"
+#include "mlir-c/Rewrite.h"
#include "mlir-c/Support.h"
#ifdef __cplusplus
@@ -19,6 +21,32 @@ extern "C" {
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Transform, transform);
+#define DEFINE_C_API_STRUCT(name, storage) \
+ struct name { \
+ storage *ptr; \
+ }; \
+ typedef struct name name
+
+DEFINE_C_API_STRUCT(MlirTransformResults, void);
+DEFINE_C_API_STRUCT(MlirTransformRewriter, void);
+DEFINE_C_API_STRUCT(MlirTransformState, void);
+
+#undef DEFINE_C_API_STRUCT
+
+//===---------------------------------------------------------------------===//
+// DiagnosedSilenceableFailure
+//===---------------------------------------------------------------------===//
+
+/// Enum representing the result of a transform operation.
+typedef enum {
+ /// The operation succeeded.
+ MlirDiagnosedSilenceableFailureSuccess,
+ /// The operation failed in a silenceable way.
+ MlirDiagnosedSilenceableFailureSilenceableFailure,
+ /// The operation failed definitively.
+ MlirDiagnosedSilenceableFailureDefiniteFailure
+} MlirDiagnosedSilenceableFailure;
+
//===---------------------------------------------------------------------===//
// AnyOpType
//===---------------------------------------------------------------------===//
@@ -86,6 +114,118 @@ MLIR_CAPI_EXPORTED MlirStringRef mlirTransformParamTypeGetName(void);
MLIR_CAPI_EXPORTED MlirType mlirTransformParamTypeGetType(MlirType type);
+//===---------------------------------------------------------------------===//
+// TransformRewriter
+//===---------------------------------------------------------------------===//
+
+/// Cast the TransformRewriter to a RewriterBase
+MLIR_CAPI_EXPORTED MlirRewriterBase
+mlirTransformRewriterAsBase(MlirTransformRewriter rewriter);
+
+//===---------------------------------------------------------------------===//
+// TransformResults
+//===---------------------------------------------------------------------===//
+
+/// Set the payload operations for a transform result by iterating over a list.
+MLIR_CAPI_EXPORTED void mlirTransformResultsSetOps(MlirTransformResults results,
+ MlirValue result,
+ intptr_t numOps,
+ MlirOperation *ops);
+
+/// Set the payload values for a transform result by iterating over a list.
+MLIR_CAPI_EXPORTED void
+mlirTransformResultsSetValues(MlirTransformResults results, MlirValue result,
+ intptr_t numValues, MlirValue *values);
+
+/// Set the parameters for a transform result by iterating over a list.
+MLIR_CAPI_EXPORTED void
+mlirTransformResultsSetParams(MlirTransformResults results, MlirValue result,
+ intptr_t numParams, MlirAttribute *params);
+
+//===---------------------------------------------------------------------===//
+// TransformState
+//===---------------------------------------------------------------------===//
+
+/// Callback for iterating over payload operations.
+typedef void (*MlirOperationCallback)(MlirOperation, void *userData);
+
+/// Iterate over payload operations associated with the transform IR value.
+/// Calls the callback for each payload operation.
+MLIR_CAPI_EXPORTED void
+mlirTransformStateForEachPayloadOp(MlirTransformState state, MlirValue value,
+ MlirOperationCallback callback,
+ void *userData);
+
+/// Callback for iterating over payload values.
+typedef void (*MlirValueCallback)(MlirValue, void *userData);
+
+/// Iterate over payload values associated with the transform IR value.
+/// Calls the callback for each payload value.
+MLIR_CAPI_EXPORTED void
+mlirTransformStateForEachPayloadValue(MlirTransformState state, MlirValue value,
+ MlirValueCallback callback,
+ void *userData);
+
+/// Callback for iterating over parameters.
+typedef void (*MlirAttributeCallback)(MlirAttribute, void *userData);
+
+/// Iterate over parameters associated with the transform IR value.
+/// Calls the callback for each parameter.
+MLIR_CAPI_EXPORTED void
+mlirTransformStateForEachParam(MlirTransformState state, MlirValue value,
+ MlirAttributeCallback callback, void *userData);
+
+//===---------------------------------------------------------------------===//
+// TransformOpInterface
+//===---------------------------------------------------------------------===//
+
+/// Returns the interface TypeID of the TransformOpInterface.
+MLIR_CAPI_EXPORTED MlirTypeID mlirTransformOpInterfaceTypeID(void);
+
+/// Callbacks for implementing TransformOpInterface from external code.
+typedef struct {
+ /// Optional constructor for the user data.
+ /// Set to nullptr to disable it.
+ void (*construct)(void *userData);
+ /// Optional destructor for the user data.
+ /// Set to nullptr to disable it.
+ void (*destruct)(void *userData);
+ /// Apply callback that implements the transformation.
+ MlirDiagnosedSilenceableFailure (*apply)(MlirOperation op,
+ MlirTransformRewriter rewriter,
+ MlirTransformResults results,
+ MlirTransformState state,
+ void *userData);
+ /// Callback to check if repeated handle operands are allowed.
+ bool (*allowsRepeatedHandleOperands)(MlirOperation op, void *userData);
+ void *userData;
+} MlirTransformOpInterfaceCallbacks;
+
+/// Attach TransformOpInterface to the operation with the given name using
+/// the provided callbacks.
+MLIR_CAPI_EXPORTED void mlirTransformOpInterfaceAttachFallbackModel(
+ MlirContext ctx, MlirStringRef opName,
+ MlirTransformOpInterfaceCallbacks callbacks);
+
+//===---------------------------------------------------------------------===//
+// Transform-specific MemoryEffectsOpInterface helpers
+//===---------------------------------------------------------------------===//
+
+/// Helper to mark operands as only reading handles.
+MLIR_CAPI_EXPORTED void
+mlirTransformOnlyReadsHandle(MlirOpOperand *operands, intptr_t numOperands,
+ MlirMemoryEffectInstancesList effects);
+
+/// Helper to mark operands as consuming handles.
+MLIR_CAPI_EXPORTED void
+mlirTransformConsumesHandle(MlirOpOperand *operands, intptr_t numOperands,
+ MlirMemoryEffectInstancesList effects);
+
+/// Helper to mark results as producing handles.
+MLIR_CAPI_EXPORTED void
+mlirTransformProducesHandle(MlirOpResult *results, intptr_t numResults,
+ MlirMemoryEffectInstancesList effects);
+
#ifdef __cplusplus
}
#endif
diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 80ff39c82a9ee..9b58267cf1c5a 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -55,6 +55,7 @@ DEFINE_C_API_STRUCT(MlirDialect, void);
DEFINE_C_API_STRUCT(MlirDialectRegistry, void);
DEFINE_C_API_STRUCT(MlirOperation, void);
DEFINE_C_API_STRUCT(MlirOpOperand, void);
+DEFINE_C_API_STRUCT(MlirOpResult, void);
DEFINE_C_API_STRUCT(MlirOpPrintingFlags, void);
DEFINE_C_API_STRUCT(MlirBlock, void);
DEFINE_C_API_STRUCT(MlirRegion, void);
@@ -673,6 +674,10 @@ MLIR_CAPI_EXPORTED intptr_t mlirOperationGetNumOperands(MlirOperation op);
MLIR_CAPI_EXPORTED MlirValue mlirOperationGetOperand(MlirOperation op,
intptr_t pos);
+/// Returns `pos`-th OpOperand of the operation.
+MLIR_CAPI_EXPORTED MlirOpOperand mlirOperationGetOpOperand(MlirOperation op,
+ intptr_t pos);
+
/// Sets the `pos`-th operand of the operation.
MLIR_CAPI_EXPORTED void mlirOperationSetOperand(MlirOperation op, intptr_t pos,
MlirValue newValue);
@@ -1044,6 +1049,9 @@ MLIR_CAPI_EXPORTED bool mlirValueIsABlockArgument(MlirValue value);
/// Returns 1 if the value is an operation result, 0 otherwise.
MLIR_CAPI_EXPORTED bool mlirValueIsAOpResult(MlirValue value);
+/// Cast the value to an OpResult. Asserts if the value is not an op result.
+MLIR_CAPI_EXPORTED MlirOpResult mlirValueToOpResult(MlirValue value);
+
/// Returns the block in which this value is defined as an argument. Asserts if
/// the value is not a block argument.
MLIR_CAPI_EXPORTED MlirBlock mlirBlockArgumentGetOwner(MlirValue value);
diff --git a/mlir/include/mlir-c/Interfaces.h b/mlir/include/mlir-c/Interfaces.h
index a5a3473eaef59..17a812dcd86a9 100644
--- a/mlir/include/mlir-c/Interfaces.h
+++ b/mlir/include/mlir-c/Interfaces.h
@@ -22,6 +22,16 @@
extern "C" {
#endif
+#define DEFINE_C_API_STRUCT(name, storage) \
+ struct name { \
+ storage *ptr; \
+ }; \
+ typedef struct name name
+
+DEFINE_C_API_STRUCT(MlirMemoryEffectInstancesList, void);
+
+#undef DEFINE_C_API_STRUCT
+
/// Returns `true` if the given operation implements an interface identified by
/// its TypeID.
MLIR_CAPI_EXPORTED bool
@@ -42,7 +52,7 @@ mlirOperationImplementsInterfaceStatic(MlirStringRef operationName,
//===----------------------------------------------------------------------===//
/// Returns the interface TypeID of the InferTypeOpInterface.
-MLIR_CAPI_EXPORTED MlirTypeID mlirInferTypeOpInterfaceTypeID();
+MLIR_CAPI_EXPORTED MlirTypeID mlirInferTypeOpInterfaceTypeID(void);
/// These callbacks are used to return multiple types from functions while
/// transferring ownership to the caller. The first argument is the number of
@@ -65,7 +75,7 @@ MLIR_CAPI_EXPORTED MlirLogicalResult mlirInferTypeOpInterfaceInferReturnTypes(
//===----------------------------------------------------------------------===//
/// Returns the interface TypeID of the InferShapedTypeOpInterface.
-MLIR_CAPI_EXPORTED MlirTypeID mlirInferShapedTypeOpInterfaceTypeID();
+MLIR_CAPI_EXPORTED MlirTypeID mlirInferShapedTypeOpInterfaceTypeID(void);
/// These callbacks are used to return multiple shaped type components from
/// functions while transferring ownership to the caller. The first argument is
@@ -87,6 +97,31 @@ mlirInferShapedTypeOpInterfaceInferReturnTypes(
void *properties, intptr_t nRegions, MlirRegion *regions,
MlirShapedTypeComponentsCallback callback, void *userData);
+//===---------------------------------------------------------------------===//
+// MemoryEffectsOpInterface
+//===---------------------------------------------------------------------===//
+
+/// Returns the interface TypeID of the MemoryEffectsOpInterface.
+MLIR_CAPI_EXPORTED MlirTypeID mlirMemoryEffectsOpInterfaceTypeID(void);
+
+/// Callbacks for implementing MemoryEffectsOpInterface from external code.
+typedef struct {
+ /// Optional constructor for user data. Set to nullptr to disable it.
+ void (*construct)(void *userData);
+ /// Optional destructor for user data. Set to nullptr to disable it.
+ void (*destruct)(void *userData);
+ /// Get memory effects callback.
+ void (*getEffects)(MlirOperation op, MlirMemoryEffectInstancesList effects,
+ void *userData);
+ void *userData;
+} MlirMemoryEffectsOpInterfaceCallbacks;
+
+/// Attach a new FallbackModel for the MemoryEffectsOpInterface to the named
+/// operation. The FallbackModel will call the provided callbacks.
+MLIR_CAPI_EXPORTED void mlirMemoryEffectsOpInterfaceAttachFallbackModel(
+ MlirContext ctx, MlirStringRef opName,
+ MlirMemoryEffectsOpInterfaceCallbacks callbacks);
+
#ifdef __cplusplus
}
#endif
diff --git a/mlir/include/mlir/Bindings/Python/IRCore.h b/mlir/include/mlir/Bindings/Python/IRCore.h
index f9fc34e82c972..2460202a31b8a 100644
--- a/mlir/include/mlir/Bindings/Python/IRCore.h
+++ b/mlir/include/mlir/Bindings/Python/IRCore.h
@@ -1496,6 +1496,7 @@ class MLIR_PYTHON_API_EXPORTED PyOperationList {
class MLIR_PYTHON_API_EXPORTED PyOpOperand {
public:
PyOpOperand(MlirOpOperand opOperand) : opOperand(opOperand) {}
+ operator MlirOpOperand() const { return opOperand; }
nanobind::typed<nanobind::object, PyOpView> getOwner() const;
@@ -1602,6 +1603,7 @@ class MLIR_PYTHON_API_EXPORTED PyOpResult : public PyConcreteValue<PyOpResult> {
static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult;
static constexpr const char *pyClassName = "OpResult";
using PyConcreteValue::PyConcreteValue;
+ operator MlirOpResult() { return mlirValueToOpResult(castFrom(*this)); }
static void bindDerived(ClassTy &c);
};
@@ -1838,13 +1840,20 @@ class MLIR_PYTHON_API_EXPORTED PyOpAttributeMap {
MLIR_PYTHON_API_EXPORTED MlirValue getUniqueResult(MlirOperation operation);
MLIR_PYTHON_API_EXPORTED void populateIRCore(nanobind::module_ &m);
MLIR_PYTHON_API_EXPORTED void populateRoot(nanobind::module_ &m);
+
+/// Helper for creating an @classmethod.
+template <class Func, typename... Args>
+static nanobind::object classmethod(Func f, Args... args) {
+ nanobind::object cf = nanobind::cpp_function(f, args...);
+ return nanobind::steal<nanobind::object>(PyClassMethod_New(cf.ptr()));
+}
+
} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
} // namespace python
} // namespace mlir
namespace nanobind {
namespace detail {
-
template <>
struct type_caster<
mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::DefaultingPyMlirContext>
diff --git a/mlir/include/mlir/CAPI/Dialect/Transform.h b/mlir/include/mlir/CAPI/Dialect/Transform.h
new file mode 100644
index 0000000000000..792236cd8601f
--- /dev/null
+++ b/mlir/include/mlir/CAPI/Dialect/Transform.h
@@ -0,0 +1,28 @@
+//===- Transform.h - C API Utils for Transform dialect ----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains declarations of implementation details of the C API for
+// the Transform dialect. This file should not be included from C++ code other
+// than C API implementation nor from C code.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CAPI_DIALECT_TRANSFORM_H
+#define MLIR_CAPI_DIALECT_TRANSFORM_H
+
+#include "mlir-c/Dialect/Transform.h"
+#include "mlir/CAPI/Wrap.h"
+#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
+
+DEFINE_C_API_PTR_METHODS(MlirTransformRewriter,
+ mlir::transform::TransformRewriter)
+DEFINE_C_API_PTR_METHODS(MlirTransformResults,
+ mlir::transform::TransformResults)
+DEFINE_C_API_PTR_METHODS(MlirTransformState, mlir::transform::TransformState)
+
+#endif // MLIR_CAPI_DIALECT_TRANSFORM_H
diff --git a/mlir/include/mlir/CAPI/IR.h b/mlir/include/mlir/CAPI/IR.h
index 1836cb0acb67e..f58ff423d23f7 100644
--- a/mlir/include/mlir/CAPI/IR.h
+++ b/mlir/include/mlir/CAPI/IR.h
@@ -29,6 +29,7 @@ DEFINE_C_API_PTR_METHODS(MlirDialectRegistry, mlir::DialectRegistry)
DEFINE_C_API_PTR_METHODS(MlirOperation, mlir::Operation)
DEFINE_C_API_PTR_METHODS(MlirBlock, mlir::Block)
DEFINE_C_API_PTR_METHODS(MlirOpOperand, mlir::OpOperand)
+DEFINE_C_API_PTR_METHODS(MlirOpResult, mlir::OpResult)
DEFINE_C_API_PTR_METHODS(MlirOpPrintingFlags, mlir::OpPrintingFlags)
DEFINE_C_API_PTR_METHODS(MlirRegion, mlir::Region)
DEFINE_C_API_PTR_METHODS(MlirSymbolTable, mlir::SymbolTable)
diff --git a/mlir/include/mlir/CAPI/Interfaces.h b/mlir/include/mlir/CAPI/Interfaces.h
index 4154b8c9ec6cc..15afc9fb0f18e 100644
--- a/mlir/include/mlir/CAPI/Interfaces.h
+++ b/mlir/include/mlir/CAPI/Interfaces.h
@@ -15,4 +15,12 @@
#ifndef MLIR_CAPI_INTERFACES_H
#define MLIR_CAPI_INTERFACES_H
+#include "mlir-c/Interfaces.h"
+#include "mlir/CAPI/Wrap.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+
+DEFINE_C_API_PTR_METHODS(
+ MlirMemoryEffectInstancesList,
+ llvm::SmallVectorImpl<mlir::MemoryEffects::EffectInstance>)
+
#endif // MLIR_CAPI_INTERFACES_H
diff --git a/mlir/lib/Bindings/Python/DialectTransform.cpp b/mlir/lib/Bindings/Python/DialectTransform.cpp
index 82905498921ce..ecb911fd7a3ba 100644
--- a/mlir/lib/Bindings/Python/DialectTransform.cpp
+++ b/mlir/lib/Bindings/Python/DialectTransform.cpp
@@ -8,12 +8,14 @@
#include <string>
+#include "IRInterfaces.h"
+#include "Rewrite.h"
#include "mlir-c/Dialect/Transform.h"
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"
#include "mlir/Bindings/Python/IRCore.h"
-#include "mlir/Bindings/Python/Nanobind.h"
-#include "mlir/Bindings/Python/NanobindAdaptors.h"
+#include "nanobind/nanobind.h"
+#include <nanobind/trampoline.h>
namespace nb = nanobind;
using namespace mlir::python::nanobind_adaptors;
@@ -22,6 +24,219 @@ namespace mlir {
namespace python {
namespace MLIR_BINDINGS_PYTHON_DOMAIN {
namespace transform {
+
+//===----------------------------------------------------------------------===//
+// TransformRewriter
+//===----------------------------------------------------------------------===//
+class PyTransformRewriter : public PyRewriterBase<PyTransformRewriter> {
+public:
+ static constexpr const char *pyClassName = "TransformRewriter";
+
+ PyTransformRewriter(MlirTransformRewriter rewriter)
+ : PyRewriterBase(mlirTransformRewriterAsBase(rewriter)) {}
+};
+
+//===----------------------------------------------------------------------===//
+// TransformResults
+//===----------------------------------------------------------------------===//
+class PyTransformResults {
+public:
+ PyTransformResults(MlirTransformResults results) : results(results) {}
+
+ MlirTransformResults get() const { return results; }
+
+ void setOps(MlirValue result, const nanobind::list &ops) {
+ std::vector<MlirOperation> opsVec;
+ opsVec.reserve(ops.size());
+ for (auto op : ops) {
+ opsVec.push_back(nb::cast<MlirOperation>(op));
+ }
+ mlirTransformResultsSetOps(results, result, opsVec.size(), opsVec.data());
+ }
+
+ void setValues(MlirValue result, const nanobind::list &values) {
+ std::vector<MlirValue> valuesVec;
+ valuesVec.reserve(values.size());
+ for (auto item : values) {
+ valuesVec.push_back(nb::cast<MlirValue>(item));
+ }
+ mlirTransformResultsSetValues(results, result, valuesVec.size(),
+ valuesVec.data());
+ }
+
+ void setParams(MlirValue result, const nanobind::list &params) {
+ std::vector<MlirAttribute> paramsVec;
+ paramsVec.reserve(params.size());
+ for (auto item : params) {
+ paramsVec.push_back(nb::cast<MlirAttribute>(item));
+ }
+ mlirTransformResultsSetParams(results, result, paramsVec.size(),
+ paramsVec.data());
+ }
+
+ static void bind(nanobind::module_ &m) {
+ nb::class_<PyTransformResults>(m, "TransformResults")
+ .def(nb::init<MlirTransformResults>())
+ .def("set_ops", &PyTransformResults::setOps,
+ "Set the payload operations for a transform result.",
+ nb::arg("result"), nb::arg("ops"))
+ .def("set_values", &PyTransformResults::setValues,
+ "Set the payload values for a transform result.",
+ nb::arg("result"), nb::arg("values"))
+ .def("set_params", &PyTransformResults::setParams,
+ "Set the parameters for a transform result.", nb::arg("result"),
+ nb::arg("params"));
+ }
+
+private:
+ MlirTransformResults results;
+};
+
+//===----------------------------------------------------------------------===//
+// TransformS...
[truncated]
``````````
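The `MlirTransformState` accessors in the diff (`mlirTransformStateForEachPayloadOp` and friends) follow the usual C API pattern of a callback plus an opaque `userData` pointer instead of returning an array. The part of the diff that binds `TransformState` to Python is truncated, so the following is only a hedged sketch of one way a binding could collect such a callback into a Python list. `MockTransformState`, `for_each_payload_op` and `get_payload_ops` are hypothetical, and payload ops are represented as plain strings.

```python
# Hypothetical sketch of the callback + userData iteration pattern.
from typing import Any, Callable, Dict, List


class MockTransformState:
    def __init__(self, mapping: Dict[str, List[str]]):
        # Maps a transform IR value (handle) to its payload ops.
        self.mapping = mapping


def for_each_payload_op(state: MockTransformState, value: str,
                        callback: Callable[[str, Any], None],
                        user_data: Any) -> None:
    # C-style iteration: call the callback once per payload op,
    # forwarding user_data untouched.
    for op in state.mapping.get(value, []):
        callback(op, user_data)


def get_payload_ops(state: MockTransformState, value: str) -> List[str]:
    # The binding passes the list being filled as user_data.
    collected: List[str] = []
    for_each_payload_op(state, value, lambda op, acc: acc.append(op), collected)
    return collected


state = MockTransformState({"%h": ["func.func @a", "func.func @b"]})
print(get_payload_ops(state, "%h"))  # ['func.func @a', 'func.func @b']
```

The opposite direction, as in `PyTransformResults::setOps` above, copies a Python list into a contiguous vector and hands its size and data pointer to the C function.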
</details>
https://github.com/llvm/llvm-project/pull/176920