[Mlir-commits] [mlir] [MLIR][Python] Support region and traits in python-defined dialects (PR #179032)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Jan 31 05:02:12 PST 2026


https://github.com/PragmaTwice updated https://github.com/llvm/llvm-project/pull/179032

>From f97aa360e4375830facb938bfaebd63738af9274 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Sat, 31 Jan 2026 20:59:22 +0800
Subject: [PATCH 1/2] [MLIR][Python] Support region and traits in
 python-defined dialects

---
 .../examples/standalone/python/CMakeLists.txt |  3 +-
 mlir/include/mlir-c/ExtensibleDialect.h       | 55 +++++++++++++
 mlir/include/mlir/Bindings/Python/IRCore.h    | 39 +++++++++
 mlir/lib/Bindings/Python/IRCore.cpp           | 21 +++++
 mlir/lib/CAPI/IR/CMakeLists.txt               |  1 +
 mlir/lib/CAPI/IR/ExtensibleDialect.cpp        | 47 +++++++++++
 mlir/python/mlir/dialects/ext.py              | 64 +++++++++++----
 mlir/test/python/dialects/ext.py              | 79 +++++++++++++++++++
 8 files changed, 292 insertions(+), 17 deletions(-)
 create mode 100644 mlir/include/mlir-c/ExtensibleDialect.h
 create mode 100644 mlir/lib/CAPI/IR/ExtensibleDialect.cpp

diff --git a/mlir/examples/standalone/python/CMakeLists.txt b/mlir/examples/standalone/python/CMakeLists.txt
index edaedf18cc843..77df1d1b9f92c 100644
--- a/mlir/examples/standalone/python/CMakeLists.txt
+++ b/mlir/examples/standalone/python/CMakeLists.txt
@@ -66,7 +66,8 @@ if(MLIR_PYTHON_STUBGEN_ENABLED)
   if(NOT EXTERNAL_PROJECT_BUILD)
     set(_core_type_stub_sources
       _mlir/__init__.pyi
-      _mlir/ir.pyi
+      _mlir/ir/__init__.pyi
+      _mlir/ir/dynamic_op_traits.pyi
       _mlir/passmanager.pyi
       _mlir/rewrite.pyi
     )
diff --git a/mlir/include/mlir-c/ExtensibleDialect.h b/mlir/include/mlir-c/ExtensibleDialect.h
new file mode 100644
index 0000000000000..4a77c76c92d76
--- /dev/null
+++ b/mlir/include/mlir-c/ExtensibleDialect.h
@@ -0,0 +1,55 @@
+//===-- mlir-c/ExtensibleDialect.h - Extensible dialect management ---*- C
+//-*-====//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM
+// Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header declares the C API for dynamic op traits of extensible dialects:
+// obtaining the built-in traits and attaching them to registered operations.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_C_EXTENSIBLEDIALECT_H
+#define MLIR_C_EXTENSIBLEDIALECT_H
+
+#include "mlir-c/IR.h"
+#include "mlir-c/Support.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+//===----------------------------------------------------------------------===//
+/// Opaque type declarations (see mlir-c/IR.h for more details).
+//===----------------------------------------------------------------------===//
+
+#define DEFINE_C_API_STRUCT(name, storage)                                     \
+  struct name {                                                                \
+    storage *ptr;                                                              \
+  };                                                                           \
+  typedef struct name name
+
+DEFINE_C_API_STRUCT(MlirDynamicOpTrait, void);
+
+/// Attach a dynamic op trait to the given operation name, taking ownership of
+/// the trait. The operation name must be registered in the context and must be
+/// modeled by a dynamic dialect.
+MLIR_CAPI_EXPORTED bool
+mlirDynamicOpTraitAttach(MlirDynamicOpTrait dynamicOpTrait,
+                         MlirStringRef opName, MlirContext context);
+
+/// Get the dynamic op trait that indicates the operation is a terminator.
+MLIR_CAPI_EXPORTED MlirDynamicOpTrait mlirDynamicOpTraitGetIsTerminator(void);
+
+/// Get the dynamic op trait that indicates regions have no terminator.
+MLIR_CAPI_EXPORTED MlirDynamicOpTrait mlirDynamicOpTraitGetNoTerminator(void);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif // MLIR_C_EXTENSIBLEDIALECT_H
diff --git a/mlir/include/mlir/Bindings/Python/IRCore.h b/mlir/include/mlir/Bindings/Python/IRCore.h
index 4bb49e6bc245d..45a04ccc4bb3d 100644
--- a/mlir/include/mlir/Bindings/Python/IRCore.h
+++ b/mlir/include/mlir/Bindings/Python/IRCore.h
@@ -23,6 +23,7 @@
 #include "mlir-c/BuiltinAttributes.h"
 #include "mlir-c/Debug.h"
 #include "mlir-c/Diagnostics.h"
+#include "mlir-c/ExtensibleDialect.h"
 #include "mlir-c/IR.h"
 #include "mlir-c/IntegerSet.h"
 #include "mlir-c/Support.h"
@@ -1844,6 +1845,44 @@ class MLIR_PYTHON_API_EXPORTED PyOpAdaptor {
   PyOpAttributeMap attributes;
 };
 
+class MLIR_PYTHON_API_EXPORTED PyDynamicOpTrait {
+public:
+  PyDynamicOpTrait(MlirDynamicOpTrait trait) : trait(trait) {}
+
+  bool attach(std::string opName, DefaultingPyMlirContext context) {
+    return mlirDynamicOpTraitAttach(trait,
+                                    MlirStringRef{opName.data(), opName.size()},
+                                    context.get()->get());
+  }
+
+  static void bind(nanobind::module_ &m);
+
+private:
+  MlirDynamicOpTrait trait;
+};
+
+namespace PyDynamicOpTraits {
+
+class IsTerminator : public PyDynamicOpTrait {
+public:
+  IsTerminator() : PyDynamicOpTrait(mlirDynamicOpTraitGetIsTerminator()) {}
+  static void bind(nanobind::module_ &m) {
+    nanobind::class_<IsTerminator, PyDynamicOpTrait>(m, "IsTerminatorTrait")
+        .def(nanobind::init<>());
+  }
+};
+
+class NoTerminator : public PyDynamicOpTrait {
+public:
+  NoTerminator() : PyDynamicOpTrait(mlirDynamicOpTraitGetNoTerminator()) {}
+  static void bind(nanobind::module_ &m) {
+    nanobind::class_<NoTerminator, PyDynamicOpTrait>(m, "NoTerminatorTrait")
+        .def(nanobind::init<>());
+  }
+};
+
+} // namespace PyDynamicOpTraits
+
 MLIR_PYTHON_API_EXPORTED MlirValue getUniqueResult(MlirOperation operation);
 MLIR_PYTHON_API_EXPORTED void populateIRCore(nanobind::module_ &m);
 MLIR_PYTHON_API_EXPORTED void populateRoot(nanobind::module_ &m);
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index dda4a027f0a30..b681dc246ea3c 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -2521,6 +2521,22 @@ void PyOpAdaptor::bind(nb::module_ &m) {
           "Returns the attributes of the adaptor.");
 }
 
+void PyDynamicOpTrait::bind(nb::module_ &m) {
+  nb::class_<PyDynamicOpTrait>(m, "DynamicOpTrait")
+      .def("attach", &PyDynamicOpTrait::attach,
+           "Attach the dynamic op trait to the given operation name.",
+           nb::arg("op_name"), nb::arg("context").none() = nb::none())
+      .def(
+          "attach",
+          [](PyDynamicOpTrait &self, const nb::type_object &opView,
+             DefaultingPyMlirContext context) {
+            return self.attach(
+                nb::cast<std::string>(opView.attr("OPERATION_NAME")), context);
+          },
+          "Attach the dynamic op trait to the given OpView class.",
+          nb::arg("op_view"), nb::arg("context").none() = nb::none());
+}
+
 } // namespace MLIR_BINDINGS_PYTHON_DOMAIN
 } // namespace python
 } // namespace mlir
@@ -4844,6 +4860,11 @@ void populateIRCore(nb::module_ &m) {
 
   // Attribute builder getter.
   PyAttrBuilderMap::bind(m);
+
+  // Extensible Dialect
+  PyDynamicOpTrait::bind(m);
+  PyDynamicOpTraits::IsTerminator::bind(m);
+  PyDynamicOpTraits::NoTerminator::bind(m);
 }
 } // namespace MLIR_BINDINGS_PYTHON_DOMAIN
 } // namespace python
diff --git a/mlir/lib/CAPI/IR/CMakeLists.txt b/mlir/lib/CAPI/IR/CMakeLists.txt
index 36f28520d6757..d78f9d9735aa3 100644
--- a/mlir/lib/CAPI/IR/CMakeLists.txt
+++ b/mlir/lib/CAPI/IR/CMakeLists.txt
@@ -6,6 +6,7 @@ add_mlir_upstream_c_api_library(MLIRCAPIIR
   BuiltinTypes.cpp
   Diagnostics.cpp
   DialectHandle.cpp
+  ExtensibleDialect.cpp
   IntegerSet.cpp
   IR.cpp
   Pass.cpp
diff --git a/mlir/lib/CAPI/IR/ExtensibleDialect.cpp b/mlir/lib/CAPI/IR/ExtensibleDialect.cpp
new file mode 100644
index 0000000000000..6fe527963895d
--- /dev/null
+++ b/mlir/lib/CAPI/IR/ExtensibleDialect.cpp
@@ -0,0 +1,47 @@
+//===- ExtensibleDialect - C API for MLIR Extensible Dialect
+//-----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir-c/ExtensibleDialect.h"
+#include "mlir/CAPI/IR.h"
+#include "mlir/CAPI/Support.h"
+#include "mlir/IR/ExtensibleDialect.h"
+#include "mlir/IR/OperationSupport.h"
+
+using namespace mlir;
+
+DEFINE_C_API_PTR_METHODS(MlirDynamicOpTrait, DynamicOpTrait)
+
+bool mlirDynamicOpTraitAttach(MlirDynamicOpTrait dynamicOpTrait,
+                              MlirStringRef opName, MlirContext context) {
+  std::optional<RegisteredOperationName> opNameFound =
+      RegisteredOperationName::lookup(unwrap(opName), unwrap(context));
+  assert(opNameFound && "operation name must be registered in the context");
+
+  // The original getImpl() is protected, so we create a small helper struct
+  // here.
+  struct RegisteredOperationNameWithImpl : RegisteredOperationName {
+    Impl *getImpl() { return RegisteredOperationName::getImpl(); }
+  };
+  OperationName::Impl *impl =
+      static_cast<RegisteredOperationNameWithImpl &>(*opNameFound).getImpl();
+
+  DynamicOpTrait *trait = unwrap(dynamicOpTrait);
+  // TODO: we should check whether the `impl` is a DynamicOpDefinition here
+  // via llvm-style RTTI.
+  return static_cast<DynamicOpDefinition *>(impl)->addTrait(
+      std::unique_ptr<DynamicOpTrait>(trait));
+}
+
+MlirDynamicOpTrait mlirDynamicOpTraitGetIsTerminator() {
+  return wrap(new DynamicOpTraits::IsTerminator());
+}
+
+MlirDynamicOpTrait mlirDynamicOpTraitGetNoTerminator() {
+  return wrap(new DynamicOpTraits::NoTerminator());
+}
diff --git a/mlir/python/mlir/dialects/ext.py b/mlir/python/mlir/dialects/ext.py
index 237c27bf62f77..31378f74d049f 100644
--- a/mlir/python/mlir/dialects/ext.py
+++ b/mlir/python/mlir/dialects/ext.py
@@ -29,10 +29,12 @@
     "Dialect",
     "Operand",
     "Result",
+    "Region",
 ]
 
 Operand = ir.Value
 Result = ir.OpResult
+Region = ir.Region
 
 
 class ConstraintLoweringContext:
@@ -102,7 +104,6 @@ class FieldDef:
     """
 
     name: str
-    constraint: Any
     variadicity: Variadicity
 
     @staticmethod
@@ -117,38 +118,50 @@ def from_type_hint(name, type_) -> "FieldDef":
 
         origin = get_origin(type_)
         if origin is ir.OpResult:
-            return ResultDef(name, get_args(type_)[0], variadicity)
+            return ResultDef(name, variadicity, get_args(type_)[0])
         elif origin is ir.Value:
-            return OperandDef(name, get_args(type_)[0], variadicity)
+            return OperandDef(name, variadicity, get_args(type_)[0])
         elif issubclass(origin or type_, ir.Attribute):
-            return AttributeDef(name, type_, variadicity)
+            return AttributeDef(name, variadicity, type_)
+        elif type_ is ir.Region:
+            return RegionDef(name, variadicity)
         raise TypeError(f"unsupported type in operation definition: {type_}")
 
 
 @dataclass
 class OperandDef(FieldDef):
-    pass
+    constraint: Any
 
 
 @dataclass
 class ResultDef(FieldDef):
-    pass
+    constraint: Any
 
 
 @dataclass
 class AttributeDef(FieldDef):
+    constraint: Any
+
     def __post_init__(self):
         if self.variadicity != Variadicity.single:
-            raise ValueError("optional attribute is not supported in IRDL")
+            raise ValueError("optional attribute is not currently supported")
+
+
+@dataclass
+class RegionDef(FieldDef):
+    def __post_init__(self):
+        if self.variadicity != Variadicity.single:
+            raise ValueError("optional region is not currently supported")
 
 
 def partition_fields(
     fields: List[FieldDef],
-) -> Tuple[List[OperandDef], List[AttributeDef], List[ResultDef]]:
+) -> Tuple[List[OperandDef], List[AttributeDef], List[ResultDef], List[RegionDef]]:
     operands = [i for i in fields if isinstance(i, OperandDef)]
     attrs = [i for i in fields if isinstance(i, AttributeDef)]
     results = [i for i in fields if isinstance(i, ResultDef)]
-    return operands, attrs, results
+    regions = [i for i in fields if isinstance(i, RegionDef)]
+    return operands, attrs, results, regions
 
 
 def normalize_value_range(
@@ -223,10 +236,11 @@ def __init_subclass__(cls, *, name: str = None, **kwargs):
 
         cls._generate_class_attributes(dialect_name, op_name, fields)
         cls._generate_init_method(fields)
-        operands, attrs, results = partition_fields(fields)
+        operands, attrs, results, regions = partition_fields(fields)
         cls._generate_attr_properties(attrs)
         cls._generate_operand_properties(operands)
         cls._generate_result_properties(results)
+        cls._generate_region_properties(regions)
 
         dialect_obj.operations.append(cls)
 
@@ -254,7 +268,11 @@ def _generate_init_signature(
         )
         # results are placed at the beginning of the parameter list,
         # but operands and attributes can appear in any relative order.
-        args = result_args + [i for i in fields if not isinstance(i, ResultDef)]
+        args = result_args + [
+            i
+            for i in fields
+            if not isinstance(i, ResultDef) and not isinstance(i, RegionDef)
+        ]
         positional_args = [
             i.name for i in args if i.variadicity != Variadicity.optional
         ]
@@ -272,7 +290,7 @@ def _generate_init_signature(
 
     @classmethod
     def _generate_init_method(cls, fields: List[FieldDef]) -> None:
-        operands, attrs, results = partition_fields(fields)
+        operands, attrs, results, regions = partition_fields(fields)
         inferred_types = [infer_type(i.constraint) for i in results]
 
         # we infer result types only when all result types can be inferred
@@ -299,7 +317,7 @@ def __init__(*args, **kwargs):
                 for attr in attrs
                 if args[attr.name] is not None
             )
-            _regions = None
+            _regions = len(regions) or None
             _ods_successors = None
             self = args["self"]
             super(Operation, self).__init__(
@@ -323,13 +341,13 @@ def __init__(*args, **kwargs):
     def _generate_class_attributes(
         cls, dialect_name: str, op_name: str, fields: List[FieldDef]
     ) -> None:
-        operands, attrs, results = partition_fields(fields)
+        operands, attrs, results, regions = partition_fields(fields)
 
         operand_segments = cls._generate_segments(operands)
         result_segments = cls._generate_segments(results)
 
         cls.OPERATION_NAME = f"{dialect_name}.{op_name}"
-        cls._ODS_REGIONS = (0, True)
+        cls._ODS_REGIONS = (len(regions), True)
         cls._ODS_OPERAND_SEGMENTS = operand_segments
         cls._ODS_RESULT_SEGMENTS = result_segments
 
@@ -342,6 +360,15 @@ def _generate_attr_properties(cls, attrs: List[AttributeDef]) -> None:
                 property(lambda self, name=attr.name: self.attributes[name]),
             )
 
+    @classmethod
+    def _generate_region_properties(cls, regions: List[RegionDef]) -> None:
+        for i, region in enumerate(regions):
+            setattr(
+                cls,
+                region.name,
+                property(lambda self, i=i: self.regions[i]),
+            )
+
     @classmethod
     def _generate_operand_properties(cls, operands: List[OperandDef]) -> None:
         for i, operand in enumerate(operands):
@@ -379,7 +406,7 @@ def getter(self, i=i, result=result):
     @classmethod
     def _emit_operation(cls) -> None:
         ctx = ConstraintLoweringContext()
-        operands, attrs, results = partition_fields(cls._fields)
+        operands, attrs, results, regions = partition_fields(cls._fields)
 
         op = irdl.operation_(cls._op_name)
         with ir.InsertionPoint(op.body):
@@ -400,6 +427,11 @@ def _emit_operation(cls) -> None:
                     [i.name for i in results],
                     [i.variadicity for i in results],
                 )
+            if regions:
+                irdl.regions_(
+                    [irdl.region([]) for _ in regions],
+                    [i.name for i in regions],
+                )
 
 
 class Dialect(ir.Dialect):
diff --git a/mlir/test/python/dialects/ext.py b/mlir/test/python/dialects/ext.py
index 483953ddfde51..77e5f63d092bc 100644
--- a/mlir/test/python/dialects/ext.py
+++ b/mlir/test/python/dialects/ext.py
@@ -76,6 +76,8 @@ class AddOp(MyInt.Operation, name="add"):
         print(add1._ODS_OPERAND_SEGMENTS)
         # CHECK: None
         print(add1._ODS_RESULT_SEGMENTS)
+        # CHECK: (0, True)
+        print(add1._ODS_REGIONS)
         # CHECK: %0 = "myint.constant"() {value = 2 : i32} : () -> i32
         print(add1.lhs.owner)
         # CHECK: %1 = "myint.constant"() {value = 3 : i32} : () -> i32
@@ -338,3 +340,80 @@ class TypeVarOp(Test.Operation, name="type_var"):
             except TypeError as e:
                 # CHECK:too many positional arguments
                 print(e)
+
+
+# CHECK: TEST: testExtDialectWithRegion
+@run
+def testExtDialectWithRegion():
+    class TestRegion(Dialect, name="ext_region"):
+        pass
+
+    class IfOp(TestRegion.Operation, name="if"):
+        cond: Operand[IntegerType[1]]
+        result: Result[Any]
+        then: Region
+        else_: Region
+
+    class YieldOp(TestRegion.Operation, name="yield"):
+        val: Operand[Any]
+
+    with Context(), Location.unknown():
+        TestRegion.load()
+        # CHECK: irdl.dialect @ext_region {
+        # CHECK:     irdl.operation @if {
+        # CHECK:     %0 = irdl.is i1
+        # CHECK:     irdl.operands(cond: %0)
+        # CHECK:     %1 = irdl.any
+        # CHECK:     irdl.results(result: %1)
+        # CHECK:     %2 = irdl.region
+        # CHECK:     %3 = irdl.region
+        # CHECK:     irdl.regions(then: %2, else_: %3)
+        # CHECK: }
+        # CHECK: irdl.operation @yield {
+        # CHECK:     %0 = irdl.any
+        # CHECK:     irdl.operands(val: %0)
+        # CHECK: }
+        print(TestRegion._mlir_module)
+
+        IsTerminatorTrait().attach(YieldOp)
+
+        # CHECK: (self, /, result, cond, *, loc=None, ip=None)
+        print(IfOp.__init__.__signature__)
+
+        # CHECK: None None
+        print(IfOp._ODS_OPERAND_SEGMENTS, IfOp._ODS_RESULT_SEGMENTS)
+        # CHECK: (2, True)
+        print(IfOp._ODS_REGIONS)
+
+        from mlir.dialects import llvm
+
+        module = Module.create()
+        with InsertionPoint(module.body):
+            i1 = IntegerType.get_signless(1)
+            i32 = IntegerType.get_signless(32)
+            cond = arith.constant(i1, 1)
+
+            if_ = IfOp(i32, cond)
+            if_.then.blocks.append()
+            if_.else_.blocks.append()
+
+            with InsertionPoint(if_.then.blocks[0]):
+                v = arith.constant(i32, 2)
+                YieldOp(v)
+
+            with InsertionPoint(if_.else_.blocks[0]):
+                v = arith.constant(i32, 3)
+                YieldOp(v)
+
+        assert module.operation.verify()
+        # CHECK: module {
+        # CHECK:     %true = arith.constant true
+        # CHECK:     %0 = "ext_region.if"(%true) ({
+        # CHECK:         %c2_i32 = arith.constant 2 : i32
+        # CHECK:         "ext_region.yield"(%c2_i32) : (i32) -> ()
+        # CHECK:     }, {
+        # CHECK:         %c3_i32 = arith.constant 3 : i32
+        # CHECK:         "ext_region.yield"(%c3_i32) : (i32) -> ()
+        # CHECK:     }) : (i1) -> i32
+        # CHECK: }
+        print(module)

>From b300ea16a006b548923609d72f7c2f18ead0ed12 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Sat, 31 Jan 2026 21:02:00 +0800
Subject: [PATCH 2/2] revert useless change

---
 mlir/examples/standalone/python/CMakeLists.txt | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/mlir/examples/standalone/python/CMakeLists.txt b/mlir/examples/standalone/python/CMakeLists.txt
index 77df1d1b9f92c..edaedf18cc843 100644
--- a/mlir/examples/standalone/python/CMakeLists.txt
+++ b/mlir/examples/standalone/python/CMakeLists.txt
@@ -66,8 +66,7 @@ if(MLIR_PYTHON_STUBGEN_ENABLED)
   if(NOT EXTERNAL_PROJECT_BUILD)
     set(_core_type_stub_sources
       _mlir/__init__.pyi
-      _mlir/ir/__init__.pyi
-      _mlir/ir/dynamic_op_traits.pyi
+      _mlir/ir.pyi
       _mlir/passmanager.pyi
       _mlir/rewrite.pyi
     )



More information about the Mlir-commits mailing list