[Mlir-commits] [mlir] [MLIR][Python] Support dynamic traits in python-defined dialects (PR #179705)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Feb 5 08:05:57 PST 2026


https://github.com/PragmaTwice updated https://github.com/llvm/llvm-project/pull/179705

>From 3b084d320d8b30c665c1d2a3a52dfe6aacca982a Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Thu, 5 Feb 2026 00:04:33 +0800
Subject: [PATCH 1/3] [MLIR][Python] Support dynamic traits in python-defined
 dialects

---
 mlir/include/mlir-c/ExtensibleDialect.h    | 55 ++++++++++++++++++++++
 mlir/include/mlir/Bindings/Python/IRCore.h | 39 +++++++++++++++
 mlir/lib/Bindings/Python/IRCore.cpp        | 21 +++++++++
 mlir/lib/CAPI/IR/CMakeLists.txt            |  1 +
 mlir/lib/CAPI/IR/ExtensibleDialect.cpp     | 47 ++++++++++++++++++
 mlir/test/python/dialects/ext.py           | 55 +++++++++++++++++-----
 6 files changed, 205 insertions(+), 13 deletions(-)
 create mode 100644 mlir/include/mlir-c/ExtensibleDialect.h
 create mode 100644 mlir/lib/CAPI/IR/ExtensibleDialect.cpp

diff --git a/mlir/include/mlir-c/ExtensibleDialect.h b/mlir/include/mlir-c/ExtensibleDialect.h
new file mode 100644
index 0000000000000..4a77c76c92d76
--- /dev/null
+++ b/mlir/include/mlir-c/ExtensibleDialect.h
@@ -0,0 +1,58 @@
+//===-- mlir-c/ExtensibleDialect.h - C API for extensible dialect -*- C -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header declares the C API for attaching dynamic op traits to operations
+// defined by extensible dialects.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_C_EXTENSIBLEDIALECT_H
+#define MLIR_C_EXTENSIBLEDIALECT_H
+
+#include "mlir-c/IR.h"
+#include "mlir-c/Support.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+//===----------------------------------------------------------------------===//
+/// Opaque type declarations (see mlir-c/IR.h for more details).
+//===----------------------------------------------------------------------===//
+
+#define DEFINE_C_API_STRUCT(name, storage)                                     \
+  struct name {                                                                \
+    storage *ptr;                                                              \
+  };                                                                           \
+  typedef struct name name
+
+DEFINE_C_API_STRUCT(MlirDynamicOpTrait, void);
+
+#undef DEFINE_C_API_STRUCT
+
+/// Attach a dynamic op trait to the given operation name.
+/// Note that the operation name must be modeled by dynamic dialect and must be
+/// registered.
+MLIR_CAPI_EXPORTED bool
+mlirDynamicOpTraitAttach(MlirDynamicOpTrait dynamicOpTrait,
+                         MlirStringRef opName, MlirContext context);
+
+// The getters below return a newly allocated trait that is owned by the
+// caller until it is passed to mlirDynamicOpTraitAttach.
+
+/// Get the dynamic op trait that indicates the operation is a terminator.
+MLIR_CAPI_EXPORTED MlirDynamicOpTrait mlirDynamicOpTraitGetIsTerminator(void);
+
+/// Get the dynamic op trait that indicates regions have no terminator.
+MLIR_CAPI_EXPORTED MlirDynamicOpTrait mlirDynamicOpTraitGetNoTerminator(void);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif // MLIR_C_EXTENSIBLEDIALECT_H
diff --git a/mlir/include/mlir/Bindings/Python/IRCore.h b/mlir/include/mlir/Bindings/Python/IRCore.h
index 4bb49e6bc245d..45a04ccc4bb3d 100644
--- a/mlir/include/mlir/Bindings/Python/IRCore.h
+++ b/mlir/include/mlir/Bindings/Python/IRCore.h
@@ -23,6 +23,7 @@
 #include "mlir-c/BuiltinAttributes.h"
 #include "mlir-c/Debug.h"
 #include "mlir-c/Diagnostics.h"
+#include "mlir-c/ExtensibleDialect.h"
 #include "mlir-c/IR.h"
 #include "mlir-c/IntegerSet.h"
 #include "mlir-c/Support.h"
@@ -1844,6 +1845,56 @@ class MLIR_PYTHON_API_EXPORTED PyOpAdaptor {
   PyOpAttributeMap attributes;
 };
 
+/// Python-side wrapper around a dynamic op trait created through the C API.
+/// attach() passes the wrapped trait to mlirDynamicOpTraitAttach, which takes
+/// ownership of it.
+class MLIR_PYTHON_API_EXPORTED PyDynamicOpTrait {
+public:
+  PyDynamicOpTrait(MlirDynamicOpTrait trait) : trait(trait) {}
+
+  bool attach(std::string opName, DefaultingPyMlirContext context) {
+    return mlirDynamicOpTraitAttach(trait,
+                                    MlirStringRef{opName.data(), opName.size()},
+                                    context.get()->get());
+  }
+
+  static void bind(nanobind::module_ &m);
+
+private:
+  MlirDynamicOpTrait trait;
+
+public:
+  // Copies would share the raw trait pointer, so attaching both would hand
+  // ownership of the same trait to the C API twice.
+  PyDynamicOpTrait(const PyDynamicOpTrait &) = delete;
+  PyDynamicOpTrait &operator=(const PyDynamicOpTrait &) = delete;
+};
+
+namespace PyDynamicOpTraits {
+
+/// Python `IsTerminatorTrait`: marks an operation as a terminator.
+class IsTerminator : public PyDynamicOpTrait {
+public:
+  IsTerminator() : PyDynamicOpTrait(mlirDynamicOpTraitGetIsTerminator()) {}
+  static void bind(nanobind::module_ &m) {
+    nanobind::class_<IsTerminator, PyDynamicOpTrait>(m, "IsTerminatorTrait")
+        .def(nanobind::init<>());
+  }
+};
+
+/// Python `NoTerminatorTrait`: marks an operation whose regions need no
+/// terminator.
+class NoTerminator : public PyDynamicOpTrait {
+public:
+  NoTerminator() : PyDynamicOpTrait(mlirDynamicOpTraitGetNoTerminator()) {}
+  static void bind(nanobind::module_ &m) {
+    nanobind::class_<NoTerminator, PyDynamicOpTrait>(m, "NoTerminatorTrait")
+        .def(nanobind::init<>());
+  }
+};
+
+} // namespace PyDynamicOpTraits
+
 MLIR_PYTHON_API_EXPORTED MlirValue getUniqueResult(MlirOperation operation);
 MLIR_PYTHON_API_EXPORTED void populateIRCore(nanobind::module_ &m);
 MLIR_PYTHON_API_EXPORTED void populateRoot(nanobind::module_ &m);
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 7f34343eba6c9..f8fc737e9b8fe 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -2521,6 +2521,24 @@ void PyOpAdaptor::bind(nb::module_ &m) {
           "Returns the attributes of the adaptor.");
 }
 
+void PyDynamicOpTrait::bind(nb::module_ &m) {
+  // `attach` is overloaded: it accepts either an operation name or an OpView
+  // subclass, whose OPERATION_NAME attribute supplies the name.
+  nb::class_<PyDynamicOpTrait>(m, "DynamicOpTrait")
+      .def("attach", &PyDynamicOpTrait::attach,
+           "Attach the dynamic op trait to the given operation name.",
+           nb::arg("op_name"), nb::arg("context").none() = nb::none())
+      .def(
+          "attach",
+          [](PyDynamicOpTrait &self, const nb::type_object &opView,
+             DefaultingPyMlirContext context) {
+            return self.attach(
+                nb::cast<std::string>(opView.attr("OPERATION_NAME")), context);
+          },
+          "Attach the dynamic op trait to the given OpView class.",
+          nb::arg("op_view"), nb::arg("context").none() = nb::none());
+}
+
 } // namespace MLIR_BINDINGS_PYTHON_DOMAIN
 } // namespace python
 } // namespace mlir
@@ -4844,6 +4860,11 @@ void populateIRCore(nb::module_ &m) {
 
   // Attribute builder getter.
   PyAttrBuilderMap::bind(m);
+
+  // Extensible Dialect
+  PyDynamicOpTrait::bind(m);
+  PyDynamicOpTraits::IsTerminator::bind(m);
+  PyDynamicOpTraits::NoTerminator::bind(m);
 }
 } // namespace MLIR_BINDINGS_PYTHON_DOMAIN
 } // namespace python
diff --git a/mlir/lib/CAPI/IR/CMakeLists.txt b/mlir/lib/CAPI/IR/CMakeLists.txt
index 36f28520d6757..d78f9d9735aa3 100644
--- a/mlir/lib/CAPI/IR/CMakeLists.txt
+++ b/mlir/lib/CAPI/IR/CMakeLists.txt
@@ -6,6 +6,7 @@ add_mlir_upstream_c_api_library(MLIRCAPIIR
   BuiltinTypes.cpp
   Diagnostics.cpp
   DialectHandle.cpp
+  ExtensibleDialect.cpp
   IntegerSet.cpp
   IR.cpp
   Pass.cpp
diff --git a/mlir/lib/CAPI/IR/ExtensibleDialect.cpp b/mlir/lib/CAPI/IR/ExtensibleDialect.cpp
new file mode 100644
index 0000000000000..6fe527963895d
--- /dev/null
+++ b/mlir/lib/CAPI/IR/ExtensibleDialect.cpp
@@ -0,0 +1,52 @@
+//===- ExtensibleDialect.cpp - C API for MLIR Extensible Dialect ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir-c/ExtensibleDialect.h"
+#include "mlir/CAPI/IR.h"
+#include "mlir/CAPI/Support.h"
+#include "mlir/IR/ExtensibleDialect.h"
+#include "mlir/IR/OperationSupport.h"
+
+using namespace mlir;
+
+DEFINE_C_API_PTR_METHODS(MlirDynamicOpTrait, DynamicOpTrait)
+
+bool mlirDynamicOpTraitAttach(MlirDynamicOpTrait dynamicOpTrait,
+                              MlirStringRef opName, MlirContext context) {
+  std::optional<RegisteredOperationName> opNameFound =
+      RegisteredOperationName::lookup(unwrap(opName), unwrap(context));
+  if (!opNameFound) {
+    // Fail gracefully rather than asserting: release builds would otherwise
+    // dereference an empty optional. This call owns the trait, so free it.
+    delete unwrap(dynamicOpTrait);
+    return false;
+  }
+
+  // getImpl() is protected, so expose it through a local subclass.
+  // NOTE(review): this downcasts to a type the object does not actually have
+  // and relies on the subclass adding no state.
+  struct RegisteredOperationNameWithImpl : RegisteredOperationName {
+    Impl *getImpl() { return RegisteredOperationName::getImpl(); }
+  };
+  OperationName::Impl *impl =
+      static_cast<RegisteredOperationNameWithImpl &>(*opNameFound).getImpl();
+
+  DynamicOpTrait *trait = unwrap(dynamicOpTrait);
+  // TODO: we should check whether the `impl` is a DynamicOpDefinition here
+  // via llvm-style RTTI.
+  return static_cast<DynamicOpDefinition *>(impl)->addTrait(
+      std::unique_ptr<DynamicOpTrait>(trait));
+}
+
+MlirDynamicOpTrait mlirDynamicOpTraitGetIsTerminator() {
+  return wrap(new DynamicOpTraits::IsTerminator());
+}
+
+MlirDynamicOpTrait mlirDynamicOpTraitGetNoTerminator() {
+  return wrap(new DynamicOpTraits::NoTerminator());
+}
diff --git a/mlir/test/python/dialects/ext.py b/mlir/test/python/dialects/ext.py
index 30e705726756b..48ac5c5c51d19 100644
--- a/mlir/test/python/dialects/ext.py
+++ b/mlir/test/python/dialects/ext.py
@@ -350,22 +350,43 @@ class TestRegion(Dialect, name="ext_region"):
 
     class IfOp(TestRegion.Operation, name="if"):
         cond: Operand[IntegerType[1]]
+        result: Result[Any]
         then: Region
         else_: Region
+    
+    class YieldOp(TestRegion.Operation, name="yield"):
+        value: Operand[Any]
+
+    class NoTermOp(TestRegion.Operation, name="no_term"):
+        body: Region
 
     with Context(), Location.unknown():
         TestRegion.load()
         # CHECK: irdl.dialect @ext_region {
-        # CHECK:     irdl.operation @if {
+        # CHECK:   irdl.operation @if {
         # CHECK:     %0 = irdl.is i1
         # CHECK:     irdl.operands(cond: %0)
-        # CHECK:     %1 = irdl.region
+        # CHECK:     %1 = irdl.any
+        # CHECK:     irdl.results(result: %1)
         # CHECK:     %2 = irdl.region
-        # CHECK:     irdl.regions(then: %1, else_: %2)
+        # CHECK:     %3 = irdl.region
+        # CHECK:     irdl.regions(then: %2, else_: %3)
+        # CHECK:   }
+        # CHECK:   irdl.operation @yield {
+        # CHECK:     %0 = irdl.any
+        # CHECK:     irdl.operands(value: %0)
+        # CHECK:   }
+        # CHECK:   irdl.operation @no_term {
+        # CHECK:     %0 = irdl.region
+        # CHECK:     irdl.regions(body: %0)
+        # CHECK:   }
         # CHECK: }
         print(TestRegion._mlir_module)
 
-        # CHECK: (self, /, cond, *, loc=None, ip=None)
+        IsTerminatorTrait().attach(YieldOp)
+        NoTerminatorTrait().attach(NoTermOp)
+
+        # CHECK: (self, /, result, cond, *, loc=None, ip=None)
         print(IfOp.__init__.__signature__)
 
         # CHECK: None None
@@ -373,36 +394,44 @@ class IfOp(TestRegion.Operation, name="if"):
         # CHECK: (2, True)
         print(IfOp._ODS_REGIONS)
 
-        from mlir.dialects import llvm
-
         module = Module.create()
         with InsertionPoint(module.body):
             i1 = IntegerType.get_signless(1)
             i32 = IntegerType.get_signless(32)
             cond = arith.constant(i1, 1)
 
-            if_ = IfOp(cond)
+            if_ = IfOp(i32, cond)
             if_.then.blocks.append()
             if_.else_.blocks.append()
 
             with InsertionPoint(if_.then.blocks[0]):
                 v = arith.constant(i32, 2)
-                llvm.unreachable()
+                YieldOp(v)
 
             with InsertionPoint(if_.else_.blocks[0]):
                 v = arith.constant(i32, 3)
-                llvm.unreachable()
+                YieldOp(v)
+            
+            nt = NoTermOp()
+            nt.body.blocks.append()
+
+            with InsertionPoint(nt.body.blocks[0]):
+                arith.constant(i32, 4)
+                # No terminator here
 
         assert module.operation.verify()
         # CHECK: module {
         # CHECK:     %true = arith.constant true
-        # CHECK:     "ext_region.if"(%true) ({
+        # CHECK:     %0 = "ext_region.if"(%true) ({
         # CHECK:         %c2_i32 = arith.constant 2 : i32
-        # CHECK:         llvm.unreachable
+        # CHECK:         "ext_region.yield"(%c2_i32) : (i32) -> ()
         # CHECK:     }, {
         # CHECK:         %c3_i32 = arith.constant 3 : i32
-        # CHECK:         llvm.unreachable
-        # CHECK:     }) : (i1) -> ()
+        # CHECK:         "ext_region.yield"(%c3_i32) : (i32) -> ()
+        # CHECK:     }) : (i1) -> i32
+        # CHECK:     "ext_region.no_term"() ({
+        # CHECK:       %c4_i32 = arith.constant 4 : i32
+        # CHECK:     }) : () -> ()
         # CHECK: }
         print(module)
 

>From dfc972d08ff59af29a79787d7bc1b0165f6ea3d1 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Thu, 5 Feb 2026 00:25:12 +0800
Subject: [PATCH 2/3] format

---
 mlir/test/python/dialects/ext.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/test/python/dialects/ext.py b/mlir/test/python/dialects/ext.py
index 48ac5c5c51d19..0062e4ae0b804 100644
--- a/mlir/test/python/dialects/ext.py
+++ b/mlir/test/python/dialects/ext.py
@@ -353,7 +353,7 @@ class IfOp(TestRegion.Operation, name="if"):
         result: Result[Any]
         then: Region
         else_: Region
-    
+
     class YieldOp(TestRegion.Operation, name="yield"):
         value: Operand[Any]
 
@@ -411,7 +411,7 @@ class NoTermOp(TestRegion.Operation, name="no_term"):
             with InsertionPoint(if_.else_.blocks[0]):
                 v = arith.constant(i32, 3)
                 YieldOp(v)
-            
+
             nt = NoTermOp()
             nt.body.blocks.append()
 

>From 47047c562ac009ee7af05923917d91bf3786e5b0 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Fri, 6 Feb 2026 00:05:14 +0800
Subject: [PATCH 3/3] refine

---
 mlir/include/mlir-c/ExtensibleDialect.h    |  6 ++++++
 mlir/include/mlir/Bindings/Python/IRCore.h | 11 +++++++++++
 mlir/lib/Bindings/Python/IRCore.cpp        | 12 +++---------
 mlir/lib/CAPI/IR/ExtensibleDialect.cpp     |  9 ++++++---
 4 files changed, 26 insertions(+), 12 deletions(-)

diff --git a/mlir/include/mlir-c/ExtensibleDialect.h b/mlir/include/mlir-c/ExtensibleDialect.h
index 4a77c76c92d76..c2c56d9096e63 100644
--- a/mlir/include/mlir-c/ExtensibleDialect.h
+++ b/mlir/include/mlir-c/ExtensibleDialect.h
@@ -38,6 +38,8 @@ DEFINE_C_API_STRUCT(MlirDynamicOpTrait, void);
 /// Attach a dynamic op trait to the given operation name.
 /// Note that the operation name must be modeled by dynamic dialect and must be
 /// registered.
+/// Ownership of the trait is transferred by this call, even if it returns
+/// false; the trait must not be used or destroyed afterwards.
 MLIR_CAPI_EXPORTED bool
 mlirDynamicOpTraitAttach(MlirDynamicOpTrait dynamicOpTrait,
                          MlirStringRef opName, MlirContext context);
@@ -48,6 +50,10 @@ MLIR_CAPI_EXPORTED MlirDynamicOpTrait mlirDynamicOpTraitGetIsTerminator(void);
 /// Get the dynamic op trait that indicates regions have no terminator.
 MLIR_CAPI_EXPORTED MlirDynamicOpTrait mlirDynamicOpTraitGetNoTerminator(void);
 
+/// Destroy a trait not passed to mlirDynamicOpTraitAttach; null is ignored.
+MLIR_CAPI_EXPORTED void
+mlirDynamicOpTraitDestroy(MlirDynamicOpTrait dynamicOpTrait);
+
 #ifdef __cplusplus
 }
 #endif
diff --git a/mlir/include/mlir/Bindings/Python/IRCore.h b/mlir/include/mlir/Bindings/Python/IRCore.h
index 45a04ccc4bb3d..47d2564837398 100644
--- a/mlir/include/mlir/Bindings/Python/IRCore.h
+++ b/mlir/include/mlir/Bindings/Python/IRCore.h
@@ -1848,13 +1848,28 @@ class MLIR_PYTHON_API_EXPORTED PyOpAdaptor {
 class MLIR_PYTHON_API_EXPORTED PyDynamicOpTrait {
 public:
   PyDynamicOpTrait(MlirDynamicOpTrait trait) : trait(trait) {}
+  // Releases the trait unless attach() has already handed it to the C API.
+  ~PyDynamicOpTrait() { mlirDynamicOpTraitDestroy(trait); }
 
   bool attach(std::string opName, DefaultingPyMlirContext context) {
+    // The C API takes ownership of the trait, so it can be attached only once.
+    // Check at runtime: an assert would let release builds pass a null trait.
+    if (!this->trait.ptr)
+      throw std::runtime_error("Trait has already been attached");
+
+    MlirDynamicOpTrait trait = this->trait;
+    this->trait = MlirDynamicOpTrait{nullptr};
     return mlirDynamicOpTraitAttach(trait,
                                     MlirStringRef{opName.data(), opName.size()},
                                     context.get()->get());
   }
 
+  bool attachToOpView(const nanobind::type_object &opView,
+                      DefaultingPyMlirContext context) {
+    return attach(nanobind::cast<std::string>(opView.attr("OPERATION_NAME")),
+                  context);
+  }
+
   static void bind(nanobind::module_ &m);
 
 private:
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index f8fc737e9b8fe..2b892344e2161 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -2526,15 +2526,9 @@ void PyDynamicOpTrait::bind(nb::module_ &m) {
       .def("attach", &PyDynamicOpTrait::attach,
            "Attach the dynamic op trait to the given operation name.",
            nb::arg("op_name"), nb::arg("context").none() = nb::none())
-      .def(
-          "attach",
-          [](PyDynamicOpTrait &self, const nb::type_object &opView,
-             DefaultingPyMlirContext context) {
-            return self.attach(
-                nb::cast<std::string>(opView.attr("OPERATION_NAME")), context);
-          },
-          "Attach the dynamic op trait to the given OpView class.",
-          nb::arg("op_view"), nb::arg("context").none() = nb::none());
+      .def("attach", &PyDynamicOpTrait::attachToOpView,
+           "Attach the dynamic op trait to the given OpView class.",
+           nb::arg("op_view"), nb::arg("context").none() = nb::none());
 }
 
 } // namespace MLIR_BINDINGS_PYTHON_DOMAIN
diff --git a/mlir/lib/CAPI/IR/ExtensibleDialect.cpp b/mlir/lib/CAPI/IR/ExtensibleDialect.cpp
index 6fe527963895d..b33cd3c8952fc 100644
--- a/mlir/lib/CAPI/IR/ExtensibleDialect.cpp
+++ b/mlir/lib/CAPI/IR/ExtensibleDialect.cpp
@@ -31,11 +31,10 @@ bool mlirDynamicOpTraitAttach(MlirDynamicOpTrait dynamicOpTrait,
   OperationName::Impl *impl =
       static_cast<RegisteredOperationNameWithImpl &>(*opNameFound).getImpl();
 
-  DynamicOpTrait *trait = unwrap(dynamicOpTrait);
+  std::unique_ptr<DynamicOpTrait> trait(unwrap(dynamicOpTrait));
   // TODO: we should check whether the `impl` is a DynamicOpDefinition here
   // via llvm-style RTTI.
-  return static_cast<DynamicOpDefinition *>(impl)->addTrait(
-      std::unique_ptr<DynamicOpTrait>(trait));
+  return static_cast<DynamicOpDefinition *>(impl)->addTrait(std::move(trait));
 }
 
 MlirDynamicOpTrait mlirDynamicOpTraitGetIsTerminator() {
@@ -45,3 +44,8 @@ MlirDynamicOpTrait mlirDynamicOpTraitGetIsTerminator() {
 MlirDynamicOpTrait mlirDynamicOpTraitGetNoTerminator() {
   return wrap(new DynamicOpTraits::NoTerminator());
 }
+
+void mlirDynamicOpTraitDestroy(MlirDynamicOpTrait dynamicOpTrait) {
+  // Adopt the (possibly null) trait; unique_ptr releases it on scope exit.
+  std::unique_ptr<DynamicOpTrait> owned(unwrap(dynamicOpTrait));
+}



More information about the Mlir-commits mailing list