[Mlir-commits] [mlir] [MLIR][Python] Impl XOpInterface(s) from Python, with X=Transform and X=MemoryEffects (PR #176920)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Feb 3 01:01:29 PST 2026
================
@@ -126,3 +131,216 @@ MlirStringRef mlirTransformParamTypeGetName(void) {
/// Returns the type reported by `ParamType::getType()` for the given
/// `!transform.param` type.
MlirType mlirTransformParamTypeGetType(MlirType type) {
  auto paramType = cast<transform::ParamType>(unwrap(type));
  MlirType innerType = wrap(paramType.getType());
  return innerType;
}
+
+//===---------------------------------------------------------------------===//
+// TransformRewriter
+//===---------------------------------------------------------------------===//
+
+/// Casts a `MlirTransformRewriter` to a `MlirRewriterBase`.
+MlirRewriterBase mlirTransformRewriterAsBase(MlirTransformRewriter rewriter) {
+ mlir::transform::TransformRewriter *t = unwrap(rewriter);
+ mlir::RewriterBase *base = static_cast<mlir::RewriterBase *>(t);
+ return wrap(base);
+}
+
+//===---------------------------------------------------------------------===//
+// TransformResults
+//===---------------------------------------------------------------------===//
+
+void mlirTransformResultsSetOps(MlirTransformResults results, MlirValue result,
+ intptr_t numOps, MlirOperation *ops) {
+ SmallVector<Operation *> opsVec;
+ opsVec.reserve(numOps);
+ for (intptr_t i = 0; i < numOps; ++i)
+ opsVec.push_back(unwrap(ops[i]));
+ unwrap(results)->set(cast<OpResult>(unwrap(result)), opsVec);
+}
+
+void mlirTransformResultsSetValues(MlirTransformResults results,
+ MlirValue result, intptr_t numValues,
+ MlirValue *values) {
+ SmallVector<Value> valuesVec;
+ valuesVec.reserve(numValues);
+ for (intptr_t i = 0; i < numValues; ++i)
+ valuesVec.push_back(unwrap(values[i]));
+ unwrap(results)->setValues(cast<OpResult>(unwrap(result)), valuesVec);
+}
+
+void mlirTransformResultsSetParams(MlirTransformResults results,
+ MlirValue result, intptr_t numParams,
+ MlirAttribute *params) {
+ SmallVector<Attribute> paramsVec;
+ paramsVec.reserve(numParams);
+ for (intptr_t i = 0; i < numParams; ++i)
+ paramsVec.push_back(unwrap(params[i]));
+ unwrap(results)->setParams(cast<OpResult>(unwrap(result)), paramsVec);
+}
+
+//===---------------------------------------------------------------------===//
+// TransformState
+//===---------------------------------------------------------------------===//
+
+void mlirTransformStateForEachPayloadOp(MlirTransformState state,
+ MlirValue value,
+ MlirOperationCallback callback,
+ void *userData) {
+ for (Operation *op : unwrap(state)->getPayloadOps(unwrap(value)))
+ callback(wrap(op), userData);
+}
+
+void mlirTransformStateForEachPayloadValue(MlirTransformState state,
+ MlirValue value,
+ MlirValueCallback callback,
+ void *userData) {
+ for (Value val : unwrap(state)->getPayloadValues(unwrap(value)))
+ callback(wrap(val), userData);
+}
+
+void mlirTransformStateForEachParam(MlirTransformState state, MlirValue value,
+ MlirAttributeCallback callback,
+ void *userData) {
+ for (Attribute attr : unwrap(state)->getParams(unwrap(value)))
+ callback(wrap(attr), userData);
+}
+
+//===---------------------------------------------------------------------===//
+// TransformOpInterface
+//===---------------------------------------------------------------------===//
+
+MlirTypeID mlirTransformOpInterfaceTypeID(void) {
+ return wrap(transform::TransformOpInterface::getInterfaceID());
+}
+
+/// Fallback model for the TransformOpInterface that uses C API callbacks.
+class TransformOpInterfaceFallbackModel
+ : public mlir::transform::TransformOpInterface::FallbackModel<
+ TransformOpInterfaceFallbackModel> {
+public:
+ /// Sets the callbacks that this FallbackModel will use.
+ /// NB: the callbacks can only be set through this method as the
+ /// RegisteredOperationName::attachInterface mechanism default-constructs
+ /// the FallbackModel without being able to provide arguments.
+ void setCallbacks(MlirTransformOpInterfaceCallbacks callbacks) {
+ this->callbacks = callbacks;
+ }
+
+ ~TransformOpInterfaceFallbackModel() {
+ if (callbacks.destruct)
+ callbacks.destruct(callbacks.userData);
+ }
+
+ static TypeID getInterfaceID() {
+ return transform::TransformOpInterface::getInterfaceID();
+ }
+
+ static bool classof(const mlir::transform::detail::
+ TransformOpInterfaceInterfaceTraits::Concept *op) {
+ // Enable casting back to the FallbackModel from the Interface. This is
+ // necessary as attachInterface(...) default-constructs the FallbackModel
+ // without being able to pass in the callbacks and returns just the Concept.
+ return true;
+ }
+
+ ::mlir::DiagnosedSilenceableFailure
+ apply(Operation *op, ::mlir::transform::TransformRewriter &rewriter,
+ ::mlir::transform::TransformResults &transformResults,
+ ::mlir::transform::TransformState &state) const {
+ assert(callbacks.apply && "apply callback not set");
+
+ MlirDiagnosedSilenceableFailure status =
+ callbacks.apply(wrap(op), wrap(&rewriter), wrap(&transformResults),
+ wrap(&state), callbacks.userData);
+
+ switch (status) {
+ case MlirDiagnosedSilenceableFailureSuccess:
+ return DiagnosedSilenceableFailure::success();
+ case MlirDiagnosedSilenceableFailureSilenceableFailure:
+ // TODO: enable passing diagnostic info from C API to C++ API.
+ return DiagnosedSilenceableFailure::silenceableFailure(std::move(
+ *(op->emitError()
+ << "TransformOpInterfaceFallbackModel: silenceable failure")
+ .getUnderlyingDiagnostic()));
+ case MlirDiagnosedSilenceableFailureDefiniteFailure:
+ return DiagnosedSilenceableFailure::definiteFailure();
+ }
+ llvm_unreachable("unknown transform status");
+ }
+
+ bool allowsRepeatedHandleOperands(Operation *op) const {
+ assert(callbacks.allowsRepeatedHandleOperands &&
+ "allowsRepeatedHandleOperands callback not set");
+ return callbacks.allowsRepeatedHandleOperands(wrap(op), callbacks.userData);
+ }
+
+private:
+ MlirTransformOpInterfaceCallbacks callbacks;
+};
+
+/// Attach a TransformOpInterface FallbackModel to the given named operation.
+/// The FallbackModel uses the provided callbacks to implement the interface.
+void mlirTransformOpInterfaceAttachFallbackModel(
+ MlirContext ctx, MlirStringRef opName,
+ MlirTransformOpInterfaceCallbacks callbacks) {
+ // Look up the operation definition in the context.
+ std::optional<RegisteredOperationName> opInfo =
+ RegisteredOperationName::lookup(unwrap(opName), unwrap(ctx));
+
+ if (!opInfo.has_value()) {
+ llvm::errs() << "Operation '" << unwrap(opName)
+ << "' not found in context\n";
+ return;
+ }
----------------
PragmaTwice wrote:
Should we make it an assertion?
https://github.com/llvm/llvm-project/pull/176920
More information about the Mlir-commits
mailing list