[Mlir-commits] [mlir] [MLIR][Python] Impl XOpInterface(s) from Python, with X=Transform and X=MemoryEffects (PR #176920)

Rolf Morel llvmlistbot at llvm.org
Wed Feb 11 08:53:57 PST 2026


================
@@ -126,3 +131,216 @@ MlirStringRef mlirTransformParamTypeGetName(void) {
 /// Returns the type held by the given `transform::ParamType`.
 /// `type` must be a `transform::ParamType`; `cast` asserts otherwise.
 MlirType mlirTransformParamTypeGetType(MlirType type) {
   auto paramType = cast<transform::ParamType>(unwrap(type));
   return wrap(paramType.getType());
 }
+
+//===---------------------------------------------------------------------===//
+// TransformRewriter
+//===---------------------------------------------------------------------===//
+
+/// Views a `MlirTransformRewriter` as the generic `MlirRewriterBase` it
+/// derives from, via a `static_cast` of the wrapped pointer.
+MlirRewriterBase mlirTransformRewriterAsBase(MlirTransformRewriter rewriter) {
+  return wrap(static_cast<mlir::RewriterBase *>(unwrap(rewriter)));
+}
+
+//===---------------------------------------------------------------------===//
+// TransformResults
+//===---------------------------------------------------------------------===//
+
+/// Records the payload operations `ops[0..numOps)` as the value of the
+/// transform op result `result`. `result` must be an `OpResult`; `cast`
+/// asserts otherwise.
+void mlirTransformResultsSetOps(MlirTransformResults results, MlirValue result,
+                                intptr_t numOps, MlirOperation *ops) {
+  SmallVector<Operation *> payloadOps;
+  payloadOps.reserve(numOps);
+  for (MlirOperation op : ArrayRef<MlirOperation>(ops, numOps))
+    payloadOps.push_back(unwrap(op));
+  OpResult handle = cast<OpResult>(unwrap(result));
+  unwrap(results)->set(handle, payloadOps);
+}
+
+/// Records the payload values `values[0..numValues)` as the value of the
+/// transform op result `result`. `result` must be an `OpResult`; `cast`
+/// asserts otherwise.
+void mlirTransformResultsSetValues(MlirTransformResults results,
+                                   MlirValue result, intptr_t numValues,
+                                   MlirValue *values) {
+  SmallVector<Value> payloadValues;
+  payloadValues.reserve(numValues);
+  for (MlirValue v : ArrayRef<MlirValue>(values, numValues))
+    payloadValues.push_back(unwrap(v));
+  OpResult handle = cast<OpResult>(unwrap(result));
+  unwrap(results)->setValues(handle, payloadValues);
+}
+
+/// Records the attributes `params[0..numParams)` as the parameters bound to
+/// the transform op result `result`. `result` must be an `OpResult`; `cast`
+/// asserts otherwise.
+void mlirTransformResultsSetParams(MlirTransformResults results,
+                                   MlirValue result, intptr_t numParams,
+                                   MlirAttribute *params) {
+  SmallVector<Attribute> paramAttrs;
+  paramAttrs.reserve(numParams);
+  for (MlirAttribute attr : ArrayRef<MlirAttribute>(params, numParams))
+    paramAttrs.push_back(unwrap(attr));
+  OpResult handle = cast<OpResult>(unwrap(result));
+  unwrap(results)->setParams(handle, paramAttrs);
+}
+
+//===---------------------------------------------------------------------===//
+// TransformState
+//===---------------------------------------------------------------------===//
+
+/// Calls `callback(op, userData)` for each payload operation that `state`
+/// associates with the transform handle `value`.
+void mlirTransformStateForEachPayloadOp(MlirTransformState state,
+                                        MlirValue value,
+                                        MlirOperationCallback callback,
+                                        void *userData) {
+  auto *transformState = unwrap(state);
+  Value handle = unwrap(value);
+  for (Operation *payloadOp : transformState->getPayloadOps(handle))
+    callback(wrap(payloadOp), userData);
+}
+
+/// Calls `callback(value, userData)` for each payload value that `state`
+/// associates with the transform handle `value`.
+void mlirTransformStateForEachPayloadValue(MlirTransformState state,
+                                           MlirValue value,
+                                           MlirValueCallback callback,
+                                           void *userData) {
+  auto *transformState = unwrap(state);
+  Value handle = unwrap(value);
+  for (Value payloadValue : transformState->getPayloadValues(handle))
+    callback(wrap(payloadValue), userData);
+}
+
+/// Calls `callback(attr, userData)` for each parameter attribute that `state`
+/// associates with the transform param handle `value`.
+void mlirTransformStateForEachParam(MlirTransformState state, MlirValue value,
+                                    MlirAttributeCallback callback,
+                                    void *userData) {
+  auto *transformState = unwrap(state);
+  Value handle = unwrap(value);
+  for (Attribute param : transformState->getParams(handle))
+    callback(wrap(param), userData);
+}
+
+//===---------------------------------------------------------------------===//
+// TransformOpInterface
+//===---------------------------------------------------------------------===//
+
+/// Returns the TypeID that identifies `transform::TransformOpInterface`.
+MlirTypeID mlirTransformOpInterfaceTypeID(void) {
+  TypeID interfaceID = transform::TransformOpInterface::getInterfaceID();
+  return wrap(interfaceID);
+}
+
+/// Fallback model for the TransformOpInterface that uses C API callbacks.
+/// Every interface method forwards to the corresponding callback in a
+/// `MlirTransformOpInterfaceCallbacks`, passing along its `userData`.
+class TransformOpInterfaceFallbackModel
+    : public mlir::transform::TransformOpInterface::FallbackModel<
+          TransformOpInterfaceFallbackModel> {
+public:
+  /// Sets the callbacks that this FallbackModel will use.
+  /// NB: the callbacks can only be set through this method as the
+  /// RegisteredOperationName::attachInterface mechanism default-constructs
+  /// the FallbackModel without being able to provide arguments.
+  void setCallbacks(MlirTransformOpInterfaceCallbacks callbacks) {
+    this->callbacks = callbacks;
+  }
+
+  /// Releases the callbacks' user data through `destruct`, when provided.
+  /// NOTE(review): whether MLIR's interface map runs this destructor when it
+  /// frees the model is not visible here; confirm before relying on it.
+  ~TransformOpInterfaceFallbackModel() {
+    if (callbacks.destruct)
+      callbacks.destruct(callbacks.userData);
+  }
+
+  static TypeID getInterfaceID() {
+    return transform::TransformOpInterface::getInterfaceID();
+  }
+
+  static bool classof(const mlir::transform::detail::
+                          TransformOpInterfaceInterfaceTraits::Concept *op) {
+    // Enable casting back to the FallbackModel from the Interface. This is
+    // necessary as attachInterface(...) default-constructs the FallbackModel
+    // without being able to pass in the callbacks and returns just the Concept.
+    // NB: this unconditionally succeeds, so callers must make sure the
+    // Concept really is a TransformOpInterfaceFallbackModel before casting.
+    return true;
+  }
+
+  /// Runs the transform by invoking the `apply` callback and translating its
+  /// C status into a DiagnosedSilenceableFailure.
+  ::mlir::DiagnosedSilenceableFailure
+  apply(Operation *op, ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::transform::TransformResults &transformResults,
+        ::mlir::transform::TransformState &state) const {
+    assert(callbacks.apply && "apply callback not set");
+
+    MlirDiagnosedSilenceableFailure status =
+        callbacks.apply(wrap(op), wrap(&rewriter), wrap(&transformResults),
+                        wrap(&state), callbacks.userData);
+
+    switch (status) {
+    case MlirDiagnosedSilenceableFailureSuccess:
+      return DiagnosedSilenceableFailure::success();
+    case MlirDiagnosedSilenceableFailureSilenceableFailure: {
+      // Build a standalone Diagnostic rather than going through
+      // `op->emitError()`. An InFlightDiagnostic reports itself when it is
+      // destroyed, so moving its payload out would still emit an (empty)
+      // error immediately, defeating the failure being silenceable.
+      // TODO: enable passing diagnostic info from C API to C++ API.
+      Diagnostic diag(op->getLoc(), DiagnosticSeverity::Error);
+      diag << "TransformOpInterfaceFallbackModel: silenceable failure";
+      return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
+    }
+    case MlirDiagnosedSilenceableFailureDefiniteFailure:
+      return DiagnosedSilenceableFailure::definiteFailure();
+    }
+    llvm_unreachable("unknown transform status");
+  }
+
+  /// Forwards to the `allowsRepeatedHandleOperands` callback.
+  bool allowsRepeatedHandleOperands(Operation *op) const {
+    assert(callbacks.allowsRepeatedHandleOperands &&
+           "allowsRepeatedHandleOperands callback not set");
+    return callbacks.allowsRepeatedHandleOperands(wrap(op), callbacks.userData);
+  }
+
+private:
+  /// Value-initialized so that every callback (including `destruct`) is null
+  /// until `setCallbacks` runs. The model is default-constructed by
+  /// attachInterface, so without this the members would be indeterminate.
+  MlirTransformOpInterfaceCallbacks callbacks{};
+};
+
+/// Attach a TransformOpInterface FallbackModel to the given named operation.
+/// The FallbackModel uses the provided callbacks to implement the interface.
+/// Prints an error to `llvm::errs()` and does nothing if the operation is not
+/// registered in `ctx`, or if it already has a TransformOpInterface model
+/// (either native or attached by an earlier call).
+void mlirTransformOpInterfaceAttachFallbackModel(
+    MlirContext ctx, MlirStringRef opName,
+    MlirTransformOpInterfaceCallbacks callbacks) {
+  // Look up the operation definition in the context.
+  std::optional<RegisteredOperationName> opInfo =
+      RegisteredOperationName::lookup(unwrap(opName), unwrap(ctx));
+
+  if (!opInfo.has_value()) {
+    llvm::errs() << "Operation '" << unwrap(opName)
+                 << "' not found in context\n";
+    return;
+  }
+
+  // If a TransformOpInterface model is already registered, the model that
+  // `getInterface` returns below would not be the one we attach.
+  // NOTE(review): MLIR's InterfaceMap is believed to ignore repeated
+  // registrations; confirm.
+  // Since the FallbackModel's `classof` always succeeds, `cast` would then
+  // reinterpret an unrelated model object, and `setCallbacks` would write
+  // into it. Refuse instead.
+  if (opInfo->hasInterface(
+          transform::TransformOpInterface::getInterfaceID())) {
+    llvm::errs() << "Operation '" << unwrap(opName)
+                 << "' already implements TransformOpInterface\n";
+    return;
+  }
+
+  // NB: the following default-constructs the FallbackModel _without_ being able
+  // to provide arguments.
+  opInfo->attachInterface<TransformOpInterfaceFallbackModel>();
+  // Cast to get the underlying FallbackModel and set the callbacks. `cast`
+  // itself asserts that the interface pointer is non-null.
+  auto *model = cast<TransformOpInterfaceFallbackModel>(
+      opInfo->getInterface<TransformOpInterfaceFallbackModel>());
+  model->setCallbacks(callbacks);
+}
+
+//===---------------------------------------------------------------------===//
+// MemoryEffectsOpInterface helpers
+//===---------------------------------------------------------------------===//
+
+/// Set the effect for the operands to only read the transform handles.
+/// Each entry of `operands` is processed on its own, so the pointed-to
+/// OpOperands need not be adjacent in memory. An empty (possibly null)
+/// `operands` array is accepted when `numOperands` is 0.
+void mlirTransformOnlyReadsHandle(MlirOpOperand *operands, intptr_t numOperands,
+                                  MlirMemoryEffectInstancesList effects) {
+  SmallVectorImpl<MemoryEffects::EffectInstance> &effectList = *unwrap(effects);
+  // Wrap each operand in a one-element range instead of building a range from
+  // the first operand's address. The latter silently assumed that all
+  // operands were consecutive OpOperands of the same operation.
+  for (intptr_t i = 0; i < numOperands; ++i)
+    transform::onlyReadsHandle(MutableArrayRef<OpOperand>(*unwrap(operands[i])),
+                               effectList);
+}
+
+/// Set the effect for the operands to consuming the transform handles.
+/// Each entry of `operands` is processed on its own, so the pointed-to
+/// OpOperands need not be adjacent in memory. An empty (possibly null)
+/// `operands` array is accepted when `numOperands` is 0.
+void mlirTransformConsumesHandle(MlirOpOperand *operands, intptr_t numOperands,
+                                 MlirMemoryEffectInstancesList effects) {
+  SmallVectorImpl<MemoryEffects::EffectInstance> &effectList = *unwrap(effects);
+  // Wrap each operand in a one-element range instead of building a range from
+  // the first operand's address. The latter silently assumed that all
+  // operands were consecutive OpOperands of the same operation.
+  for (intptr_t i = 0; i < numOperands; ++i)
+    transform::consumesHandle(MutableArrayRef<OpOperand>(*unwrap(operands[i])),
+                              effectList);
+}
+
+/// Set the effect for the results to that they produce transform handles.
+void mlirTransformProducesHandle(MlirValue *results, intptr_t numResults,
+                                 MlirMemoryEffectInstancesList effects) {
+  // NB: calling `producesHandle()` `numResults` as we cannot cast array of
+  // `OpResult`s to a single `ResultRange` (and neither is `ResultRange` exposed
+  // to Python). `producesHandle` iterates over the given `ResultRange` anyway.
+  SmallVectorImpl<MemoryEffects::EffectInstance> &effectList = *unwrap(effects);
+  for (intptr_t i = 0; i < numResults; ++i)
+    TypeSwitch<Value, void>(unwrap(results[i]))
+        .Case<OpResult>([&](OpResult opResult) {
+          transform::producesHandle(ResultRange(opResult), effectList);
+        })
+        .DefaultUnreachable("expected an OpResult");
----------------
rolfmorel wrote:

Done.

https://github.com/llvm/llvm-project/pull/176920


More information about the Mlir-commits mailing list