[Mlir-commits] [mlir] [MLIR][Python] Support dynamic traits in python-defined dialects (PR #179705)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Feb 6 22:46:05 PST 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Twice (PragmaTwice)

<details>
<summary>Changes</summary>

This is a follow-up to #<!-- -->169045 and the second part of #<!-- -->179086.

In #<!-- -->179086, we added support for defining regions in Python-defined ops, but its usefulness was quite limited because we still couldn’t mark an op as a terminator (`IsTerminator`) or as having no terminator (`NoTerminator`). In this PR, we port `DynamicOpTrait` (introduced on the C++ side for `DynamicDialect` in #<!-- -->177735) to Python, so that traits can be attached dynamically to Python-defined ops.


---
Full diff: https://github.com/llvm/llvm-project/pull/179705.diff


6 Files Affected:

- (added) mlir/include/mlir-c/ExtensibleDialect.h (+76) 
- (modified) mlir/include/mlir/Bindings/Python/IRCore.h (+25) 
- (modified) mlir/lib/Bindings/Python/IRCore.cpp (+127) 
- (modified) mlir/lib/CAPI/IR/CMakeLists.txt (+1) 
- (added) mlir/lib/CAPI/IR/ExtensibleDialect.cpp (+87) 
- (modified) mlir/test/python/dialects/ext.py (+90-13) 


``````````diff
diff --git a/mlir/include/mlir-c/ExtensibleDialect.h b/mlir/include/mlir-c/ExtensibleDialect.h
new file mode 100644
index 0000000000000..98457805f57c0
--- /dev/null
+++ b/mlir/include/mlir-c/ExtensibleDialect.h
@@ -0,0 +1,76 @@
+//===-- mlir-c/ExtensibleDialect.h - Extensible dialect APIs -----*- C -*-====//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM
+// Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header provides APIs for extensible dialects.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_C_EXTENSIBLEDIALECT_H
+#define MLIR_C_EXTENSIBLEDIALECT_H
+
+#include "mlir-c/IR.h"
+#include "mlir-c/Support.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+//===----------------------------------------------------------------------===//
+/// Opaque type declarations (see mlir-c/IR.h for more details).
+//===----------------------------------------------------------------------===//
+
+#define DEFINE_C_API_STRUCT(name, storage)                                     \
+  struct name {                                                                \
+    storage *ptr;                                                              \
+  };                                                                           \
+  typedef struct name name
+
+DEFINE_C_API_STRUCT(MlirDynamicOpTrait, void);
+
+/// Attach a dynamic op trait to the operation named `opName` in `context`.
+/// The operation name must be registered in `context` and must be modeled by
+/// a dynamic dialect.
+/// Ownership of the trait is transferred to the operation name by this call,
+/// so `dynamicOpTrait` must not be used or destroyed afterwards.
+MLIR_CAPI_EXPORTED bool
+mlirDynamicOpTraitAttach(MlirDynamicOpTrait dynamicOpTrait,
+                         MlirStringRef opName, MlirContext context);
+
+/// Create a new dynamic op trait indicating the operation is a terminator.
+MLIR_CAPI_EXPORTED MlirDynamicOpTrait mlirDynamicOpTraitGetIsTerminator(void);
+
+/// Create a new dynamic op trait indicating regions have no terminator.
+MLIR_CAPI_EXPORTED MlirDynamicOpTrait mlirDynamicOpTraitGetNoTerminator(void);
+
+/// Destroy a dynamic op trait that has not been attached to an operation.
+MLIR_CAPI_EXPORTED void
+mlirDynamicOpTraitDestroy(MlirDynamicOpTrait dynamicOpTrait);
+
+typedef struct {
+  /// Optional hook called with `userData` when the trait is created.
+  /// Set to nullptr to disable it.
+  void (*construct)(void *userData);
+  /// Optional hook called with `userData` when the trait is destroyed.
+  /// Set to nullptr to disable it.
+  void (*destruct)(void *userData);
+  /// Required callback to verify the operation; must not be null.
+  MlirLogicalResult (*verifyTrait)(MlirOperation op, void *userData);
+  /// Required callback to verify the operation with access to regions.
+  MlirLogicalResult (*verifyRegionTrait)(MlirOperation op, void *userData);
+} MlirDynamicOpTraitCallbacks;
+
+/// Create a custom dynamic op trait with the given type ID and callbacks.
+MLIR_CAPI_EXPORTED MlirDynamicOpTrait mlirDynamicOpTraitCreate(
+    MlirTypeID typeID, MlirDynamicOpTraitCallbacks callbacks, void *userData);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif // MLIR_C_EXTENSIBLEDIALECT_H
diff --git a/mlir/include/mlir/Bindings/Python/IRCore.h b/mlir/include/mlir/Bindings/Python/IRCore.h
index 4bb49e6bc245d..e551a49bb34a8 100644
--- a/mlir/include/mlir/Bindings/Python/IRCore.h
+++ b/mlir/include/mlir/Bindings/Python/IRCore.h
@@ -23,6 +23,7 @@
 #include "mlir-c/BuiltinAttributes.h"
 #include "mlir-c/Debug.h"
 #include "mlir-c/Diagnostics.h"
+#include "mlir-c/ExtensibleDialect.h"
 #include "mlir-c/IR.h"
 #include "mlir-c/IntegerSet.h"
 #include "mlir-c/Support.h"
@@ -1844,6 +1845,30 @@ class MLIR_PYTHON_API_EXPORTED PyOpAdaptor {
   PyOpAttributeMap attributes;
 };
 
+class MLIR_PYTHON_API_EXPORTED PyDynamicOpTrait {
+public:
+  static bool attach(const nanobind::object &opName,
+                     const nanobind::object &target, PyMlirContext &context);
+
+  static void bind(nanobind::module_ &m);
+};
+
+namespace PyDynamicOpTraits {
+
+class MLIR_PYTHON_API_EXPORTED IsTerminator : public PyDynamicOpTrait {
+public:
+  static bool attach(const nanobind::object &opName, PyMlirContext &context);
+  static void bind(nanobind::module_ &m);
+};
+
+class MLIR_PYTHON_API_EXPORTED NoTerminator : public PyDynamicOpTrait {
+public:
+  static bool attach(const nanobind::object &opName, PyMlirContext &context);
+  static void bind(nanobind::module_ &m);
+};
+
+} // namespace PyDynamicOpTraits
+
 MLIR_PYTHON_API_EXPORTED MlirValue getUniqueResult(MlirOperation operation);
 MLIR_PYTHON_API_EXPORTED void populateIRCore(nanobind::module_ &m);
 MLIR_PYTHON_API_EXPORTED void populateRoot(nanobind::module_ &m);
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 7f34343eba6c9..aaa19b65d15b1 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -16,6 +16,7 @@
 #include "mlir-c/BuiltinAttributes.h"
 #include "mlir-c/Debug.h"
 #include "mlir-c/Diagnostics.h"
+#include "mlir-c/ExtensibleDialect.h"
 #include "mlir-c/IR.h"
 #include "mlir-c/Support.h"
 
@@ -2521,6 +2522,127 @@ void PyOpAdaptor::bind(nb::module_ &m) {
           "Returns the attributes of the adaptor.");
 }
 
+static MlirLogicalResult verifyTraitByMethod(MlirOperation op, void *userData,
+                                             const char *methodName) {
+  // `userData` is the Python target object; a missing method means success.
+  nb::handle targetObj(static_cast<PyObject *>(userData));
+  if (!nb::hasattr(targetObj, methodName))
+    return mlirLogicalResultSuccess();
+  PyMlirContextRef ctx = PyMlirContext::forContext(mlirOperationGetContext(op));
+  PyOperationRef pyOp = PyOperation::forOperation(ctx, op);
+  bool success = nb::cast<bool>(targetObj.attr(methodName)(pyOp.get()));
+  return success ? mlirLogicalResultSuccess() : mlirLogicalResultFailure();
+}
+
+static bool attachOpTrait(const nb::object &opName, MlirDynamicOpTrait trait,
+                          PyMlirContext &context) {
+  std::string opNameStr; // From OPERATION_NAME for a class, else the string.
+  if (opName.is_type()) {
+    opNameStr = nb::cast<std::string>(opName.attr("OPERATION_NAME"));
+  } else if (nb::isinstance<nb::str>(opName)) {
+    opNameStr = nb::cast<std::string>(opName);
+  } else {
+    mlirDynamicOpTraitDestroy(trait); // Never handed off; do not leak it.
+    throw nb::type_error("the op_name argument must be a type or a string");
+  }
+  return mlirDynamicOpTraitAttach(
+      trait, MlirStringRef{opNameStr.data(), opNameStr.size()}, context.get());
+}
+
+bool PyDynamicOpTrait::attach(const nb::object &opName,
+                              const nb::object &target,
+                              PyMlirContext &context) {
+  if (!nb::hasattr(target, "verify") && !nb::hasattr(target, "verify_region"))
+    throw nb::type_error(
+        "the target object must have at least one of 'verify' or "
+        "'verify_region' methods");
+  // `construct`/`destruct` take and release a reference to `target`.
+  MlirDynamicOpTraitCallbacks callbacks;
+  callbacks.construct = [](void *userData) {
+    nb::handle(static_cast<PyObject *>(userData)).inc_ref();
+  };
+  callbacks.destruct = [](void *userData) {
+    nb::handle(static_cast<PyObject *>(userData)).dec_ref();
+  };
+  // Route each verifier to `verify` / `verify_region` on the target.
+  callbacks.verifyTrait = [](MlirOperation op,
+                             void *userData) -> MlirLogicalResult {
+    return verifyTraitByMethod(op, userData, "verify");
+  };
+  callbacks.verifyRegionTrait = [](MlirOperation op,
+                                   void *userData) -> MlirLogicalResult {
+    return verifyTraitByMethod(op, userData, "verify_region");
+  };
+  // Cache the type ID on `target` (NOTE(review): an inherited one is reused).
+  constexpr const char *typeIDAttr = "_TYPE_ID";
+  if (!nb::hasattr(target, typeIDAttr)) {
+    nb::setattr(target, typeIDAttr,
+                nb::cast(PyTypeID(PyGlobals::get().allocateTypeID())));
+  }
+  MlirDynamicOpTrait trait = mlirDynamicOpTraitCreate(
+      nb::cast<PyTypeID>(target.attr(typeIDAttr)).get(), callbacks,
+      static_cast<void *>(target.ptr()));
+  return attachOpTrait(opName, trait, context);
+}
+
+void PyDynamicOpTrait::bind(nb::module_ &m) {
+  nb::class_<PyDynamicOpTrait> cls(m, "DynamicOpTrait");
+  cls.def_static(
+      "attach_target",
+      [](const nb::object &opName, const nb::object &target,
+         DefaultingPyMlirContext context) {
+        return PyDynamicOpTrait::attach(opName, target, *context.get());
+      },
+      "Attach the dynamic op trait with the target object to the given "
+      "operation name.",
+      nb::arg("op_name"), nb::arg("target"),
+      nb::arg("context").none() = nb::none());
+  cls.attr("attach") = classmethod(
+      [](const nb::object &cls, const nb::object &opName,
+         DefaultingPyMlirContext context) {
+        return PyDynamicOpTrait::attach(opName, cls, *context.get());
+      },
+      nb::arg("cls"), nb::arg("op_name"),
+      nb::arg("context").none() = nb::none(),
+      "Attach the dynamic op trait subclass to the given operation name.");
+}
+
+bool PyDynamicOpTraits::IsTerminator::attach(const nb::object &opName,
+                                             PyMlirContext &context) {
+  MlirDynamicOpTrait trait = mlirDynamicOpTraitGetIsTerminator();
+  return attachOpTrait(opName, trait, context);
+}
+
+void PyDynamicOpTraits::IsTerminator::bind(nb::module_ &m) {
+  nb::class_<PyDynamicOpTraits::IsTerminator, PyDynamicOpTrait> cls(
+      m, "IsTerminatorTrait");
+  cls.attr("attach") = classmethod(
+      [](const nb::object &cls, const nb::object &opName,
+         DefaultingPyMlirContext context) {
+        return PyDynamicOpTraits::IsTerminator::attach(opName, *context.get());
+      },
+      "Attach IsTerminator trait to the given operation name.", nb::arg("cls"),
+      nb::arg("op_name"), nb::arg("context").none() = nb::none());
+}
+
+bool PyDynamicOpTraits::NoTerminator::attach(const nb::object &opName,
+                                             PyMlirContext &context) {
+  MlirDynamicOpTrait trait = mlirDynamicOpTraitGetNoTerminator();
+  return attachOpTrait(opName, trait, context);
+}
+
+void PyDynamicOpTraits::NoTerminator::bind(nb::module_ &m) {
+  nb::class_<PyDynamicOpTraits::NoTerminator, PyDynamicOpTrait> cls(
+      m, "NoTerminatorTrait");
+  cls.attr("attach") = classmethod(
+      [](const nb::object &cls, const nb::object &opName,
+         DefaultingPyMlirContext context) {
+        return PyDynamicOpTraits::NoTerminator::attach(opName, *context.get());
+      },
+      "Attach NoTerminator trait to the given operation name.", nb::arg("cls"),
+      nb::arg("op_name"), nb::arg("context").none() = nb::none());
+}
+
 } // namespace MLIR_BINDINGS_PYTHON_DOMAIN
 } // namespace python
 } // namespace mlir
@@ -4844,6 +4966,11 @@ void populateIRCore(nb::module_ &m) {
 
   // Attribute builder getter.
   PyAttrBuilderMap::bind(m);
+
+  // Extensible Dialect
+  PyDynamicOpTrait::bind(m);
+  PyDynamicOpTraits::IsTerminator::bind(m);
+  PyDynamicOpTraits::NoTerminator::bind(m);
 }
 } // namespace MLIR_BINDINGS_PYTHON_DOMAIN
 } // namespace python
diff --git a/mlir/lib/CAPI/IR/CMakeLists.txt b/mlir/lib/CAPI/IR/CMakeLists.txt
index 36f28520d6757..d78f9d9735aa3 100644
--- a/mlir/lib/CAPI/IR/CMakeLists.txt
+++ b/mlir/lib/CAPI/IR/CMakeLists.txt
@@ -6,6 +6,7 @@ add_mlir_upstream_c_api_library(MLIRCAPIIR
   BuiltinTypes.cpp
   Diagnostics.cpp
   DialectHandle.cpp
+  ExtensibleDialect.cpp
   IntegerSet.cpp
   IR.cpp
   Pass.cpp
diff --git a/mlir/lib/CAPI/IR/ExtensibleDialect.cpp b/mlir/lib/CAPI/IR/ExtensibleDialect.cpp
new file mode 100644
index 0000000000000..f3239d996a0e6
--- /dev/null
+++ b/mlir/lib/CAPI/IR/ExtensibleDialect.cpp
@@ -0,0 +1,87 @@
+//===- ExtensibleDialect - C API for MLIR Extensible Dialect --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir-c/ExtensibleDialect.h"
+#include "mlir/CAPI/IR.h"
+#include "mlir/CAPI/Support.h"
+#include "mlir/IR/ExtensibleDialect.h"
+#include "mlir/IR/OperationSupport.h"
+
+using namespace mlir;
+
+DEFINE_C_API_PTR_METHODS(MlirDynamicOpTrait, DynamicOpTrait)
+
+bool mlirDynamicOpTraitAttach(MlirDynamicOpTrait dynamicOpTrait,
+                              MlirStringRef opName, MlirContext context) {
+  // Take ownership first so the trait is released on every exit path.
+  std::unique_ptr<DynamicOpTrait> trait(unwrap(dynamicOpTrait));
+  std::optional<RegisteredOperationName> opNameFound =
+      RegisteredOperationName::lookup(unwrap(opName), unwrap(context));
+  if (!opNameFound)
+    return false; // Unregistered name: report failure instead of UB.
+
+  // getImpl() is protected, so expose it through a small helper struct.
+  struct RegisteredOperationNameWithImpl : RegisteredOperationName {
+    Impl *getImpl() { return RegisteredOperationName::getImpl(); }
+  };
+  OperationName::Impl *impl =
+      static_cast<RegisteredOperationNameWithImpl &>(*opNameFound).getImpl();
+
+  // TODO: check that `impl` is a DynamicOpDefinition via llvm-style RTTI.
+  return static_cast<DynamicOpDefinition *>(impl)->addTrait(std::move(trait));
+}
+
+MlirDynamicOpTrait mlirDynamicOpTraitGetIsTerminator() {
+  return wrap(new DynamicOpTraits::IsTerminator());
+}
+
+MlirDynamicOpTrait mlirDynamicOpTraitGetNoTerminator() {
+  return wrap(new DynamicOpTraits::NoTerminator());
+}
+
+void mlirDynamicOpTraitDestroy(MlirDynamicOpTrait dynamicOpTrait) {
+  delete unwrap(dynamicOpTrait);
+}
+
+namespace mlir {
+/// DynamicOpTrait implemented by user-provided C callbacks and user data.
+class ExternalDynamicOpTrait : public DynamicOpTrait {
+public:
+  ExternalDynamicOpTrait(TypeID typeID, MlirDynamicOpTraitCallbacks callbacks,
+                         void *userData)
+      : typeID(typeID), callbacks(callbacks), userData(userData) {
+    if (callbacks.construct)
+      callbacks.construct(userData);
+  }
+  ~ExternalDynamicOpTrait() {
+    if (callbacks.destruct)
+      callbacks.destruct(userData);
+  }
+  // Verification is forwarded to the callbacks, which must be non-null.
+  LogicalResult verifyTrait(Operation *op) const override {
+    return unwrap(callbacks.verifyTrait(wrap(op), userData));
+  }
+  LogicalResult verifyRegionTrait(Operation *op) const override {
+    return unwrap(callbacks.verifyRegionTrait(wrap(op), userData));
+  }
+
+  TypeID getTypeID() const override { return typeID; }
+
+private:
+  TypeID typeID;
+  MlirDynamicOpTraitCallbacks callbacks;
+  void *userData;
+};
+
+} // namespace mlir
+
+MlirDynamicOpTrait mlirDynamicOpTraitCreate(
+    MlirTypeID typeID, MlirDynamicOpTraitCallbacks callbacks, void *userData) {
+  return wrap(
+      new mlir::ExternalDynamicOpTrait(unwrap(typeID), callbacks, userData));
+}
diff --git a/mlir/test/python/dialects/ext.py b/mlir/test/python/dialects/ext.py
index 30e705726756b..4a69d2e8deb00 100644
--- a/mlir/test/python/dialects/ext.py
+++ b/mlir/test/python/dialects/ext.py
@@ -350,22 +350,54 @@ class TestRegion(Dialect, name="ext_region"):
 
     class IfOp(TestRegion.Operation, name="if"):
         cond: Operand[IntegerType[1]]
+        result: Result[Any]
         then: Region
         else_: Region
 
+    class YieldOp(TestRegion.Operation, name="yield"):
+        value: Operand[Any]
+
+    class NoTermOp(TestRegion.Operation, name="no_term"):
+        body: Region
+
     with Context(), Location.unknown():
         TestRegion.load()
         # CHECK: irdl.dialect @ext_region {
-        # CHECK:     irdl.operation @if {
+        # CHECK:   irdl.operation @if {
         # CHECK:     %0 = irdl.is i1
         # CHECK:     irdl.operands(cond: %0)
-        # CHECK:     %1 = irdl.region
+        # CHECK:     %1 = irdl.any
+        # CHECK:     irdl.results(result: %1)
         # CHECK:     %2 = irdl.region
-        # CHECK:     irdl.regions(then: %1, else_: %2)
+        # CHECK:     %3 = irdl.region
+        # CHECK:     irdl.regions(then: %2, else_: %3)
+        # CHECK:   }
+        # CHECK:   irdl.operation @yield {
+        # CHECK:     %0 = irdl.any
+        # CHECK:     irdl.operands(value: %0)
+        # CHECK:   }
+        # CHECK:   irdl.operation @no_term {
+        # CHECK:     %0 = irdl.region
+        # CHECK:     irdl.regions(body: %0)
+        # CHECK:   }
         # CHECK: }
         print(TestRegion._mlir_module)
 
-        # CHECK: (self, /, cond, *, loc=None, ip=None)
+        IsTerminatorTrait.attach(YieldOp)
+        NoTerminatorTrait.attach(NoTermOp)
+
+        class ParentIsIfTrait(DynamicOpTrait):
+            @staticmethod
+            def verify(op) -> bool:
+                if not isinstance(op.parent.opview, IfOp):
+                    raise RuntimeError(
+                        f"{op.name} should be put inside {IfOp.OPERATION_NAME}"
+                    )
+                return True
+
+        ParentIsIfTrait.attach(YieldOp)
+
+        # CHECK: (self, /, result, cond, *, loc=None, ip=None)
         print(IfOp.__init__.__signature__)
 
         # CHECK: None None
@@ -373,36 +405,44 @@ class IfOp(TestRegion.Operation, name="if"):
         # CHECK: (2, True)
         print(IfOp._ODS_REGIONS)
 
-        from mlir.dialects import llvm
-
         module = Module.create()
         with InsertionPoint(module.body):
             i1 = IntegerType.get_signless(1)
             i32 = IntegerType.get_signless(32)
             cond = arith.constant(i1, 1)
 
-            if_ = IfOp(cond)
+            if_ = IfOp(i32, cond)
             if_.then.blocks.append()
             if_.else_.blocks.append()
 
             with InsertionPoint(if_.then.blocks[0]):
                 v = arith.constant(i32, 2)
-                llvm.unreachable()
+                YieldOp(v)
 
             with InsertionPoint(if_.else_.blocks[0]):
                 v = arith.constant(i32, 3)
-                llvm.unreachable()
+                YieldOp(v)
+
+            nt = NoTermOp()
+            nt.body.blocks.append()
+
+            with InsertionPoint(nt.body.blocks[0]):
+                arith.constant(i32, 4)
+                # No terminator here
 
         assert module.operation.verify()
         # CHECK: module {
         # CHECK:     %true = arith.constant true
-        # CHECK:     "ext_region.if"(%true) ({
+        # CHECK:     %0 = "ext_region.if"(%true) ({
         # CHECK:         %c2_i32 = arith.constant 2 : i32
-        # CHECK:         llvm.unreachable
+        # CHECK:         "ext_region.yield"(%c2_i32) : (i32) -> ()
         # CHECK:     }, {
         # CHECK:         %c3_i32 = arith.constant 3 : i32
-        # CHECK:         llvm.unreachable
-        # CHECK:     }) : (i1) -> ()
+        # CHECK:         "ext_region.yield"(%c3_i32) : (i32) -> ()
+        # CHECK:     }) : (i1) -> i32
+        # CHECK:     "ext_region.no_term"() ({
+        # CHECK:       %c4_i32 = arith.constant 4 : i32
+        # CHECK:     }) : () -> ()
         # CHECK: }
         print(module)
 
@@ -410,3 +450,40 @@ class IfOp(TestRegion.Operation, name="if"):
         print(if_.then.blocks[0])
         # CHECK: %c3_i32 = arith.constant 3 : i32
         print(if_.else_.blocks[0])
+
+        # CHECK-LABEL: Testing violation cases
+        print("Testing violation cases:")
+
+        module = Module.create()
+        with InsertionPoint(module.body):
+            i1 = IntegerType.get_signless(1)
+            i32 = IntegerType.get_signless(32)
+            cond = arith.constant(i1, 1)
+
+            if_ = IfOp(i32, cond)
+            if_.then.blocks.append()
+            if_.else_.blocks.append()
+
+            with InsertionPoint(if_.then.blocks[0]):
+                v = arith.constant(i32, 2)
+
+            with InsertionPoint(if_.else_.blocks[0]):
+                v = arith.constant(i32, 3)
+
+        try:
+            module.operation.verify()
+        except Exception as e:
+            # CHECK: Verification failed:
+            # CHECK: block with no terminator
+            print(e)
+
+        module = Module.create()
+        with InsertionPoint(module.body):
+            v = arith.constant(i32, 2)
+            YieldOp(v)
+
+        try:
+            module.operation.verify()
+        except Exception as e:
+            # CHECK: ext_region.yield should be put inside ext_region.if
+            print(e)

``````````

</details>


https://github.com/llvm/llvm-project/pull/179705


More information about the Mlir-commits mailing list