[Mlir-commits] [mlir] [MLIR][Python] Support dynamic traits in python-defined dialects (PR #179705)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Feb 6 22:43:46 PST 2026
https://github.com/PragmaTwice updated https://github.com/llvm/llvm-project/pull/179705
>From 3b084d320d8b30c665c1d2a3a52dfe6aacca982a Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Thu, 5 Feb 2026 00:04:33 +0800
Subject: [PATCH 1/8] [MLIR][Python] Support dynamic traits in python-defined
dialects
---
mlir/include/mlir-c/ExtensibleDialect.h | 55 ++++++++++++++++++++++
mlir/include/mlir/Bindings/Python/IRCore.h | 39 +++++++++++++++
mlir/lib/Bindings/Python/IRCore.cpp | 21 +++++++++
mlir/lib/CAPI/IR/CMakeLists.txt | 1 +
mlir/lib/CAPI/IR/ExtensibleDialect.cpp | 47 ++++++++++++++++++
mlir/test/python/dialects/ext.py | 55 +++++++++++++++++-----
6 files changed, 205 insertions(+), 13 deletions(-)
create mode 100644 mlir/include/mlir-c/ExtensibleDialect.h
create mode 100644 mlir/lib/CAPI/IR/ExtensibleDialect.cpp
diff --git a/mlir/include/mlir-c/ExtensibleDialect.h b/mlir/include/mlir-c/ExtensibleDialect.h
new file mode 100644
index 0000000000000..4a77c76c92d76
--- /dev/null
+++ b/mlir/include/mlir-c/ExtensibleDialect.h
@@ -0,0 +1,55 @@
+//===-- mlir-c/ExtensibleDialect.h - Extensible dialect management ---*- C
+//-*-====//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM
+// Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header provides basic access to the MLIR JIT. This is minimalist and
+// experimental at the moment.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_C_EXTENSIBLEDIALECT_H
+#define MLIR_C_EXTENSIBLEDIALECT_H
+
+#include "mlir-c/IR.h"
+#include "mlir-c/Support.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+//===----------------------------------------------------------------------===//
+/// Opaque type declarations (see mlir-c/IR.h for more details).
+//===----------------------------------------------------------------------===//
+
+#define DEFINE_C_API_STRUCT(name, storage) \
+ struct name { \
+ storage *ptr; \
+ }; \
+ typedef struct name name
+
+DEFINE_C_API_STRUCT(MlirDynamicOpTrait, void);
+
+/// Attach a dynamic op trait to the given operation name.
+/// Note that the operation name must be modeled by dynamic dialect and must be
+/// registered.
+MLIR_CAPI_EXPORTED bool
+mlirDynamicOpTraitAttach(MlirDynamicOpTrait dynamicOpTrait,
+ MlirStringRef opName, MlirContext context);
+
+/// Get the dynamic op trait that indicates the operation is a terminator.
+MLIR_CAPI_EXPORTED MlirDynamicOpTrait mlirDynamicOpTraitGetIsTerminator(void);
+
+/// Get the dynamic op trait that indicates regions have no terminator.
+MLIR_CAPI_EXPORTED MlirDynamicOpTrait mlirDynamicOpTraitGetNoTerminator(void);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif // MLIR_C_EXTENSIBLEDIALECT_H
diff --git a/mlir/include/mlir/Bindings/Python/IRCore.h b/mlir/include/mlir/Bindings/Python/IRCore.h
index 4bb49e6bc245d..45a04ccc4bb3d 100644
--- a/mlir/include/mlir/Bindings/Python/IRCore.h
+++ b/mlir/include/mlir/Bindings/Python/IRCore.h
@@ -23,6 +23,7 @@
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/Debug.h"
#include "mlir-c/Diagnostics.h"
+#include "mlir-c/ExtensibleDialect.h"
#include "mlir-c/IR.h"
#include "mlir-c/IntegerSet.h"
#include "mlir-c/Support.h"
@@ -1844,6 +1845,44 @@ class MLIR_PYTHON_API_EXPORTED PyOpAdaptor {
PyOpAttributeMap attributes;
};
+class MLIR_PYTHON_API_EXPORTED PyDynamicOpTrait {
+public:
+ PyDynamicOpTrait(MlirDynamicOpTrait trait) : trait(trait) {}
+
+ bool attach(std::string opName, DefaultingPyMlirContext context) {
+ return mlirDynamicOpTraitAttach(trait,
+ MlirStringRef{opName.data(), opName.size()},
+ context.get()->get());
+ }
+
+ static void bind(nanobind::module_ &m);
+
+private:
+ MlirDynamicOpTrait trait;
+};
+
+namespace PyDynamicOpTraits {
+
+class IsTerminator : public PyDynamicOpTrait {
+public:
+ IsTerminator() : PyDynamicOpTrait(mlirDynamicOpTraitGetIsTerminator()) {}
+ static void bind(nanobind::module_ &m) {
+ nanobind::class_<IsTerminator, PyDynamicOpTrait>(m, "IsTerminatorTrait")
+ .def(nanobind::init<>());
+ }
+};
+
+class NoTerminator : public PyDynamicOpTrait {
+public:
+ NoTerminator() : PyDynamicOpTrait(mlirDynamicOpTraitGetNoTerminator()) {}
+ static void bind(nanobind::module_ &m) {
+ nanobind::class_<NoTerminator, PyDynamicOpTrait>(m, "NoTerminatorTrait")
+ .def(nanobind::init<>());
+ }
+};
+
+} // namespace PyDynamicOpTraits
+
MLIR_PYTHON_API_EXPORTED MlirValue getUniqueResult(MlirOperation operation);
MLIR_PYTHON_API_EXPORTED void populateIRCore(nanobind::module_ &m);
MLIR_PYTHON_API_EXPORTED void populateRoot(nanobind::module_ &m);
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 7f34343eba6c9..f8fc737e9b8fe 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -2521,6 +2521,22 @@ void PyOpAdaptor::bind(nb::module_ &m) {
"Returns the attributes of the adaptor.");
}
+void PyDynamicOpTrait::bind(nb::module_ &m) {
+ nb::class_<PyDynamicOpTrait>(m, "DynamicOpTrait")
+ .def("attach", &PyDynamicOpTrait::attach,
+ "Attach the dynamic op trait to the given operation name.",
+ nb::arg("op_name"), nb::arg("context").none() = nb::none())
+ .def(
+ "attach",
+ [](PyDynamicOpTrait &self, const nb::type_object &opView,
+ DefaultingPyMlirContext context) {
+ return self.attach(
+ nb::cast<std::string>(opView.attr("OPERATION_NAME")), context);
+ },
+ "Attach the dynamic op trait to the given OpView class.",
+ nb::arg("op_view"), nb::arg("context").none() = nb::none());
+}
+
} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
} // namespace python
} // namespace mlir
@@ -4844,6 +4860,11 @@ void populateIRCore(nb::module_ &m) {
// Attribute builder getter.
PyAttrBuilderMap::bind(m);
+
+ // Extensible Dialect
+ PyDynamicOpTrait::bind(m);
+ PyDynamicOpTraits::IsTerminator::bind(m);
+ PyDynamicOpTraits::NoTerminator::bind(m);
}
} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
} // namespace python
diff --git a/mlir/lib/CAPI/IR/CMakeLists.txt b/mlir/lib/CAPI/IR/CMakeLists.txt
index 36f28520d6757..d78f9d9735aa3 100644
--- a/mlir/lib/CAPI/IR/CMakeLists.txt
+++ b/mlir/lib/CAPI/IR/CMakeLists.txt
@@ -6,6 +6,7 @@ add_mlir_upstream_c_api_library(MLIRCAPIIR
BuiltinTypes.cpp
Diagnostics.cpp
DialectHandle.cpp
+ ExtensibleDialect.cpp
IntegerSet.cpp
IR.cpp
Pass.cpp
diff --git a/mlir/lib/CAPI/IR/ExtensibleDialect.cpp b/mlir/lib/CAPI/IR/ExtensibleDialect.cpp
new file mode 100644
index 0000000000000..6fe527963895d
--- /dev/null
+++ b/mlir/lib/CAPI/IR/ExtensibleDialect.cpp
@@ -0,0 +1,47 @@
+//===- ExtensibleDialect - C API for MLIR Extensible Dialect
+//-----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir-c/ExtensibleDialect.h"
+#include "mlir/CAPI/IR.h"
+#include "mlir/CAPI/Support.h"
+#include "mlir/IR/ExtensibleDialect.h"
+#include "mlir/IR/OperationSupport.h"
+
+using namespace mlir;
+
+DEFINE_C_API_PTR_METHODS(MlirDynamicOpTrait, DynamicOpTrait)
+
+bool mlirDynamicOpTraitAttach(MlirDynamicOpTrait dynamicOpTrait,
+ MlirStringRef opName, MlirContext context) {
+ std::optional<RegisteredOperationName> opNameFound =
+ RegisteredOperationName::lookup(unwrap(opName), unwrap(context));
+ assert(opNameFound && "operation name must be registered in the context");
+
+ // The original getImpl() is protected, so we create a small helper struct
+ // here.
+ struct RegisteredOperationNameWithImpl : RegisteredOperationName {
+ Impl *getImpl() { return RegisteredOperationName::getImpl(); }
+ };
+ OperationName::Impl *impl =
+ static_cast<RegisteredOperationNameWithImpl &>(*opNameFound).getImpl();
+
+ DynamicOpTrait *trait = unwrap(dynamicOpTrait);
+ // TODO: we should check whether the `impl` is a DynamicOpDefinition here
+ // via llvm-style RTTI.
+ return static_cast<DynamicOpDefinition *>(impl)->addTrait(
+ std::unique_ptr<DynamicOpTrait>(trait));
+}
+
+MlirDynamicOpTrait mlirDynamicOpTraitGetIsTerminator() {
+ return wrap(new DynamicOpTraits::IsTerminator());
+}
+
+MlirDynamicOpTrait mlirDynamicOpTraitGetNoTerminator() {
+ return wrap(new DynamicOpTraits::NoTerminator());
+}
diff --git a/mlir/test/python/dialects/ext.py b/mlir/test/python/dialects/ext.py
index 30e705726756b..48ac5c5c51d19 100644
--- a/mlir/test/python/dialects/ext.py
+++ b/mlir/test/python/dialects/ext.py
@@ -350,22 +350,43 @@ class TestRegion(Dialect, name="ext_region"):
class IfOp(TestRegion.Operation, name="if"):
cond: Operand[IntegerType[1]]
+ result: Result[Any]
then: Region
else_: Region
+
+ class YieldOp(TestRegion.Operation, name="yield"):
+ value: Operand[Any]
+
+ class NoTermOp(TestRegion.Operation, name="no_term"):
+ body: Region
with Context(), Location.unknown():
TestRegion.load()
# CHECK: irdl.dialect @ext_region {
- # CHECK: irdl.operation @if {
+ # CHECK: irdl.operation @if {
# CHECK: %0 = irdl.is i1
# CHECK: irdl.operands(cond: %0)
- # CHECK: %1 = irdl.region
+ # CHECK: %1 = irdl.any
+ # CHECK: irdl.results(result: %1)
# CHECK: %2 = irdl.region
- # CHECK: irdl.regions(then: %1, else_: %2)
+ # CHECK: %3 = irdl.region
+ # CHECK: irdl.regions(then: %2, else_: %3)
+ # CHECK: }
+ # CHECK: irdl.operation @yield {
+ # CHECK: %0 = irdl.any
+ # CHECK: irdl.operands(value: %0)
+ # CHECK: }
+ # CHECK: irdl.operation @no_term {
+ # CHECK: %0 = irdl.region
+ # CHECK: irdl.regions(body: %0)
+ # CHECK: }
# CHECK: }
print(TestRegion._mlir_module)
- # CHECK: (self, /, cond, *, loc=None, ip=None)
+ IsTerminatorTrait().attach(YieldOp)
+ NoTerminatorTrait().attach(NoTermOp)
+
+ # CHECK: (self, /, result, cond, *, loc=None, ip=None)
print(IfOp.__init__.__signature__)
# CHECK: None None
@@ -373,36 +394,44 @@ class IfOp(TestRegion.Operation, name="if"):
# CHECK: (2, True)
print(IfOp._ODS_REGIONS)
- from mlir.dialects import llvm
-
module = Module.create()
with InsertionPoint(module.body):
i1 = IntegerType.get_signless(1)
i32 = IntegerType.get_signless(32)
cond = arith.constant(i1, 1)
- if_ = IfOp(cond)
+ if_ = IfOp(i32, cond)
if_.then.blocks.append()
if_.else_.blocks.append()
with InsertionPoint(if_.then.blocks[0]):
v = arith.constant(i32, 2)
- llvm.unreachable()
+ YieldOp(v)
with InsertionPoint(if_.else_.blocks[0]):
v = arith.constant(i32, 3)
- llvm.unreachable()
+ YieldOp(v)
+
+ nt = NoTermOp()
+ nt.body.blocks.append()
+
+ with InsertionPoint(nt.body.blocks[0]):
+ arith.constant(i32, 4)
+ # No terminator here
assert module.operation.verify()
# CHECK: module {
# CHECK: %true = arith.constant true
- # CHECK: "ext_region.if"(%true) ({
+ # CHECK: %0 = "ext_region.if"(%true) ({
# CHECK: %c2_i32 = arith.constant 2 : i32
- # CHECK: llvm.unreachable
+ # CHECK: "ext_region.yield"(%c2_i32) : (i32) -> ()
# CHECK: }, {
# CHECK: %c3_i32 = arith.constant 3 : i32
- # CHECK: llvm.unreachable
- # CHECK: }) : (i1) -> ()
+ # CHECK: "ext_region.yield"(%c3_i32) : (i32) -> ()
+ # CHECK: }) : (i1) -> i32
+ # CHECK: "ext_region.no_term"() ({
+ # CHECK: %c4_i32 = arith.constant 4 : i32
+ # CHECK: }) : () -> ()
# CHECK: }
print(module)
>From dfc972d08ff59af29a79787d7bc1b0165f6ea3d1 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Thu, 5 Feb 2026 00:25:12 +0800
Subject: [PATCH 2/8] format
---
mlir/test/python/dialects/ext.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/test/python/dialects/ext.py b/mlir/test/python/dialects/ext.py
index 48ac5c5c51d19..0062e4ae0b804 100644
--- a/mlir/test/python/dialects/ext.py
+++ b/mlir/test/python/dialects/ext.py
@@ -353,7 +353,7 @@ class IfOp(TestRegion.Operation, name="if"):
result: Result[Any]
then: Region
else_: Region
-
+
class YieldOp(TestRegion.Operation, name="yield"):
value: Operand[Any]
@@ -411,7 +411,7 @@ class NoTermOp(TestRegion.Operation, name="no_term"):
with InsertionPoint(if_.else_.blocks[0]):
v = arith.constant(i32, 3)
YieldOp(v)
-
+
nt = NoTermOp()
nt.body.blocks.append()
>From 47047c562ac009ee7af05923917d91bf3786e5b0 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Fri, 6 Feb 2026 00:05:14 +0800
Subject: [PATCH 3/8] refine
---
mlir/include/mlir-c/ExtensibleDialect.h | 6 ++++++
mlir/include/mlir/Bindings/Python/IRCore.h | 11 +++++++++++
mlir/lib/Bindings/Python/IRCore.cpp | 12 +++---------
mlir/lib/CAPI/IR/ExtensibleDialect.cpp | 9 ++++++---
4 files changed, 26 insertions(+), 12 deletions(-)
diff --git a/mlir/include/mlir-c/ExtensibleDialect.h b/mlir/include/mlir-c/ExtensibleDialect.h
index 4a77c76c92d76..c2c56d9096e63 100644
--- a/mlir/include/mlir-c/ExtensibleDialect.h
+++ b/mlir/include/mlir-c/ExtensibleDialect.h
@@ -38,6 +38,8 @@ DEFINE_C_API_STRUCT(MlirDynamicOpTrait, void);
/// Attach a dynamic op trait to the given operation name.
/// Note that the operation name must be modeled by dynamic dialect and must be
/// registered.
+/// The ownership of the trait will be transferred to the operation name
+/// after this call.
MLIR_CAPI_EXPORTED bool
mlirDynamicOpTraitAttach(MlirDynamicOpTrait dynamicOpTrait,
MlirStringRef opName, MlirContext context);
@@ -48,6 +50,10 @@ MLIR_CAPI_EXPORTED MlirDynamicOpTrait mlirDynamicOpTraitGetIsTerminator(void);
/// Get the dynamic op trait that indicates regions have no terminator.
MLIR_CAPI_EXPORTED MlirDynamicOpTrait mlirDynamicOpTraitGetNoTerminator(void);
+/// Destroy the dynamic op trait.
+MLIR_CAPI_EXPORTED void
+mlirDynamicOpTraitDestroy(MlirDynamicOpTrait dynamicOpTrait);
+
#ifdef __cplusplus
}
#endif
diff --git a/mlir/include/mlir/Bindings/Python/IRCore.h b/mlir/include/mlir/Bindings/Python/IRCore.h
index 45a04ccc4bb3d..47d2564837398 100644
--- a/mlir/include/mlir/Bindings/Python/IRCore.h
+++ b/mlir/include/mlir/Bindings/Python/IRCore.h
@@ -1848,13 +1848,24 @@ class MLIR_PYTHON_API_EXPORTED PyOpAdaptor {
class MLIR_PYTHON_API_EXPORTED PyDynamicOpTrait {
public:
PyDynamicOpTrait(MlirDynamicOpTrait trait) : trait(trait) {}
+ ~PyDynamicOpTrait() { mlirDynamicOpTraitDestroy(trait); }
bool attach(std::string opName, DefaultingPyMlirContext context) {
+ assert(this->trait.ptr && "Trait has already been attached");
+
+ MlirDynamicOpTrait trait = this->trait;
+ this->trait = MlirDynamicOpTrait{nullptr};
return mlirDynamicOpTraitAttach(trait,
MlirStringRef{opName.data(), opName.size()},
context.get()->get());
}
+ bool attachToOpView(const nanobind::type_object &opView,
+ DefaultingPyMlirContext context) {
+ return attach(nanobind::cast<std::string>(opView.attr("OPERATION_NAME")),
+ context);
+ }
+
static void bind(nanobind::module_ &m);
private:
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index f8fc737e9b8fe..2b892344e2161 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -2526,15 +2526,9 @@ void PyDynamicOpTrait::bind(nb::module_ &m) {
.def("attach", &PyDynamicOpTrait::attach,
"Attach the dynamic op trait to the given operation name.",
nb::arg("op_name"), nb::arg("context").none() = nb::none())
- .def(
- "attach",
- [](PyDynamicOpTrait &self, const nb::type_object &opView,
- DefaultingPyMlirContext context) {
- return self.attach(
- nb::cast<std::string>(opView.attr("OPERATION_NAME")), context);
- },
- "Attach the dynamic op trait to the given OpView class.",
- nb::arg("op_view"), nb::arg("context").none() = nb::none());
+ .def("attach", &PyDynamicOpTrait::attachToOpView,
+ "Attach the dynamic op trait to the given OpView class.",
+ nb::arg("op_view"), nb::arg("context").none() = nb::none());
}
} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
diff --git a/mlir/lib/CAPI/IR/ExtensibleDialect.cpp b/mlir/lib/CAPI/IR/ExtensibleDialect.cpp
index 6fe527963895d..b33cd3c8952fc 100644
--- a/mlir/lib/CAPI/IR/ExtensibleDialect.cpp
+++ b/mlir/lib/CAPI/IR/ExtensibleDialect.cpp
@@ -31,11 +31,10 @@ bool mlirDynamicOpTraitAttach(MlirDynamicOpTrait dynamicOpTrait,
OperationName::Impl *impl =
static_cast<RegisteredOperationNameWithImpl &>(*opNameFound).getImpl();
- DynamicOpTrait *trait = unwrap(dynamicOpTrait);
+ std::unique_ptr<DynamicOpTrait> trait(unwrap(dynamicOpTrait));
// TODO: we should check whether the `impl` is a DynamicOpDefinition here
// via llvm-style RTTI.
- return static_cast<DynamicOpDefinition *>(impl)->addTrait(
- std::unique_ptr<DynamicOpTrait>(trait));
+ return static_cast<DynamicOpDefinition *>(impl)->addTrait(std::move(trait));
}
MlirDynamicOpTrait mlirDynamicOpTraitGetIsTerminator() {
@@ -45,3 +44,7 @@ MlirDynamicOpTrait mlirDynamicOpTraitGetIsTerminator() {
MlirDynamicOpTrait mlirDynamicOpTraitGetNoTerminator() {
return wrap(new DynamicOpTraits::NoTerminator());
}
+
+void mlirDynamicOpTraitDestroy(MlirDynamicOpTrait dynamicOpTrait) {
+ delete unwrap(dynamicOpTrait);
+}
>From 9e2d3c68b344e3888dcd1d1a247f7d7258680133 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Fri, 6 Feb 2026 00:06:43 +0800
Subject: [PATCH 4/8] refine
---
mlir/include/mlir/Bindings/Python/IRCore.h | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/include/mlir/Bindings/Python/IRCore.h b/mlir/include/mlir/Bindings/Python/IRCore.h
index 47d2564837398..313f09cd1b7fb 100644
--- a/mlir/include/mlir/Bindings/Python/IRCore.h
+++ b/mlir/include/mlir/Bindings/Python/IRCore.h
@@ -1874,7 +1874,7 @@ class MLIR_PYTHON_API_EXPORTED PyDynamicOpTrait {
namespace PyDynamicOpTraits {
-class IsTerminator : public PyDynamicOpTrait {
+class MLIR_PYTHON_API_EXPORTED IsTerminator : public PyDynamicOpTrait {
public:
IsTerminator() : PyDynamicOpTrait(mlirDynamicOpTraitGetIsTerminator()) {}
static void bind(nanobind::module_ &m) {
@@ -1883,7 +1883,7 @@ class IsTerminator : public PyDynamicOpTrait {
}
};
-class NoTerminator : public PyDynamicOpTrait {
+class MLIR_PYTHON_API_EXPORTED NoTerminator : public PyDynamicOpTrait {
public:
NoTerminator() : PyDynamicOpTrait(mlirDynamicOpTraitGetNoTerminator()) {}
static void bind(nanobind::module_ &m) {
>From 3d89b6d086878d99915279db8bd19e9fdd30da66 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Fri, 6 Feb 2026 00:11:02 +0800
Subject: [PATCH 5/8] refine
---
mlir/include/mlir-c/ExtensibleDialect.h | 6 ++----
mlir/lib/CAPI/IR/ExtensibleDialect.cpp | 3 +--
2 files changed, 3 insertions(+), 6 deletions(-)
diff --git a/mlir/include/mlir-c/ExtensibleDialect.h b/mlir/include/mlir-c/ExtensibleDialect.h
index c2c56d9096e63..fdbc5a9752c76 100644
--- a/mlir/include/mlir-c/ExtensibleDialect.h
+++ b/mlir/include/mlir-c/ExtensibleDialect.h
@@ -1,5 +1,4 @@
-//===-- mlir-c/ExtensibleDialect.h - Extensible dialect management ---*- C
-//-*-====//
+//===-- mlir-c/ExtensibleDialect.h - Extensible dialect APIs -----*- C -*-====//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
// Exceptions.
@@ -8,8 +7,7 @@
//
//===----------------------------------------------------------------------===//
//
-// This header provides basic access to the MLIR JIT. This is minimalist and
-// experimental at the moment.
+// This header provides APIs for extensible dialects.
//
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/CAPI/IR/ExtensibleDialect.cpp b/mlir/lib/CAPI/IR/ExtensibleDialect.cpp
index b33cd3c8952fc..aebcb55bd1042 100644
--- a/mlir/lib/CAPI/IR/ExtensibleDialect.cpp
+++ b/mlir/lib/CAPI/IR/ExtensibleDialect.cpp
@@ -1,5 +1,4 @@
-//===- ExtensibleDialect - C API for MLIR Extensible Dialect
-//-----------------===//
+//===- ExtensibleDialect - C API for MLIR Extensible Dialect --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
>From ffb04c4e02964624a382256bf334544fb51bfa39 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Fri, 6 Feb 2026 00:19:06 +0800
Subject: [PATCH 6/8] refine
---
mlir/test/python/dialects/ext.py | 26 ++++++++++++++++++++++++++
1 file changed, 26 insertions(+)
diff --git a/mlir/test/python/dialects/ext.py b/mlir/test/python/dialects/ext.py
index 0062e4ae0b804..4a497d30a1320 100644
--- a/mlir/test/python/dialects/ext.py
+++ b/mlir/test/python/dialects/ext.py
@@ -439,3 +439,29 @@ class NoTermOp(TestRegion.Operation, name="no_term"):
print(if_.then.blocks[0])
# CHECK: %c3_i32 = arith.constant 3 : i32
print(if_.else_.blocks[0])
+
+ # CHECK-LABEL: Testing violation cases
+ print("Testing violation cases:")
+
+ module = Module.create()
+ with InsertionPoint(module.body):
+ i1 = IntegerType.get_signless(1)
+ i32 = IntegerType.get_signless(32)
+ cond = arith.constant(i1, 1)
+
+ if_ = IfOp(i32, cond)
+ if_.then.blocks.append()
+ if_.else_.blocks.append()
+
+ with InsertionPoint(if_.then.blocks[0]):
+ v = arith.constant(i32, 2)
+
+ with InsertionPoint(if_.else_.blocks[0]):
+ v = arith.constant(i32, 3)
+
+ try:
+ module.operation.verify()
+ except Exception as e:
+ # CHECK: Verification failed:
+ # CHECK: block with no terminator
+ print(e)
>From bea09d47e556f054ef95cf871ce7245d1a76d2b2 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Sat, 7 Feb 2026 14:41:08 +0800
Subject: [PATCH 7/8] refine
---
mlir/.clang-tidy | 1 -
mlir/include/mlir-c/ExtensibleDialect.h | 17 +++
mlir/include/mlir/Bindings/Python/IRCore.h | 37 +-----
mlir/lib/Bindings/Python/IRCore.cpp | 126 +++++++++++++++++++--
mlir/lib/CAPI/IR/ExtensibleDialect.cpp | 38 +++++++
mlir/test/python/dialects/ext.py | 26 ++++-
6 files changed, 204 insertions(+), 41 deletions(-)
diff --git a/mlir/.clang-tidy b/mlir/.clang-tidy
index eb8cbbeb9723c..57e2da85cce0e 100644
--- a/mlir/.clang-tidy
+++ b/mlir/.clang-tidy
@@ -29,7 +29,6 @@ Checks: >
modernize-use-emplace,
modernize-use-nullptr,
modernize-use-override,
- modernize-use-using,
performance-for-range-copy,
performance-implicit-conversion-in-loop,
performance-inefficient-algorithm,
diff --git a/mlir/include/mlir-c/ExtensibleDialect.h b/mlir/include/mlir-c/ExtensibleDialect.h
index fdbc5a9752c76..98457805f57c0 100644
--- a/mlir/include/mlir-c/ExtensibleDialect.h
+++ b/mlir/include/mlir-c/ExtensibleDialect.h
@@ -52,6 +52,30 @@ MlirDynamicOpTrait mlirDynamicOpTraitGetNoTerminator(void);
MLIR_CAPI_EXPORTED void
mlirDynamicOpTraitDestroy(MlirDynamicOpTrait dynamicOpTrait);
+/// Callbacks defining an externally implemented dynamic op trait created with
+/// mlirDynamicOpTraitCreate. Every callback receives the `userData` pointer
+/// that was passed to mlirDynamicOpTraitCreate.
+typedef struct {
+  /// Optional constructor for the user data, invoked once when the trait is
+  /// created. Set to nullptr to disable it.
+  void (*construct)(void *userData);
+  /// Optional destructor for the user data, invoked once when the trait is
+  /// destroyed. Set to nullptr to disable it.
+  void (*destruct)(void *userData);
+  /// The callback function to verify the operation.
+  MlirLogicalResult (*verifyTrait)(MlirOperation op, void *userData);
+  /// The callback function to verify the operation with access to regions.
+  MlirLogicalResult (*verifyRegionTrait)(MlirOperation op, void *userData);
+} MlirDynamicOpTraitCallbacks;
+
+/// Create a custom dynamic op trait with the given type ID and callbacks.
+/// The returned trait is owned by the caller until it is either attached with
+/// mlirDynamicOpTraitAttach or released with mlirDynamicOpTraitDestroy.
+MLIR_CAPI_EXPORTED MlirDynamicOpTrait mlirDynamicOpTraitCreate(
+    MlirTypeID typeID, MlirDynamicOpTraitCallbacks callbacks, void *userData);
+
+#undef DEFINE_C_API_STRUCT
+
#ifdef __cplusplus
}
#endif
diff --git a/mlir/include/mlir/Bindings/Python/IRCore.h b/mlir/include/mlir/Bindings/Python/IRCore.h
index 313f09cd1b7fb..e551a49bb34a8 100644
--- a/mlir/include/mlir/Bindings/Python/IRCore.h
+++ b/mlir/include/mlir/Bindings/Python/IRCore.h
@@ -1847,49 +1847,24 @@ class MLIR_PYTHON_API_EXPORTED PyOpAdaptor {
class MLIR_PYTHON_API_EXPORTED PyDynamicOpTrait {
public:
- PyDynamicOpTrait(MlirDynamicOpTrait trait) : trait(trait) {}
- ~PyDynamicOpTrait() { mlirDynamicOpTraitDestroy(trait); }
-
- bool attach(std::string opName, DefaultingPyMlirContext context) {
- assert(this->trait.ptr && "Trait has already been attached");
-
- MlirDynamicOpTrait trait = this->trait;
- this->trait = MlirDynamicOpTrait{nullptr};
- return mlirDynamicOpTraitAttach(trait,
- MlirStringRef{opName.data(), opName.size()},
- context.get()->get());
- }
-
- bool attachToOpView(const nanobind::type_object &opView,
- DefaultingPyMlirContext context) {
- return attach(nanobind::cast<std::string>(opView.attr("OPERATION_NAME")),
- context);
- }
+ static bool attach(const nanobind::object &opName,
+ const nanobind::object &target, PyMlirContext &context);
static void bind(nanobind::module_ &m);
-
-private:
- MlirDynamicOpTrait trait;
};
namespace PyDynamicOpTraits {
class MLIR_PYTHON_API_EXPORTED IsTerminator : public PyDynamicOpTrait {
public:
- IsTerminator() : PyDynamicOpTrait(mlirDynamicOpTraitGetIsTerminator()) {}
- static void bind(nanobind::module_ &m) {
- nanobind::class_<IsTerminator, PyDynamicOpTrait>(m, "IsTerminatorTrait")
- .def(nanobind::init<>());
- }
+ static bool attach(const nanobind::object &opName, PyMlirContext &context);
+ static void bind(nanobind::module_ &m);
};
class MLIR_PYTHON_API_EXPORTED NoTerminator : public PyDynamicOpTrait {
public:
- NoTerminator() : PyDynamicOpTrait(mlirDynamicOpTraitGetNoTerminator()) {}
- static void bind(nanobind::module_ &m) {
- nanobind::class_<NoTerminator, PyDynamicOpTrait>(m, "NoTerminatorTrait")
- .def(nanobind::init<>());
- }
+ static bool attach(const nanobind::object &opName, PyMlirContext &context);
+ static void bind(nanobind::module_ &m);
};
} // namespace PyDynamicOpTraits
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 2b892344e2161..aaa19b65d15b1 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -16,6 +16,7 @@
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/Debug.h"
#include "mlir-c/Diagnostics.h"
+#include "mlir-c/ExtensibleDialect.h"
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"
@@ -2521,14 +2522,144 @@ void PyOpAdaptor::bind(nb::module_ &m) {
"Returns the attributes of the adaptor.");
}
+/// Invokes `methodName` on the Python object stored in `userData` to verify
+/// `op`. A missing method counts as success. Any exception raised by the call
+/// (including a return value that cannot be cast to bool) is reported as an
+/// error diagnostic on `op` and turned into a failure result, because C++
+/// exceptions must not unwind through the C verifier callback boundary.
+static MlirLogicalResult verifyTraitByMethod(MlirOperation op, void *userData,
+                                             const char *methodName) {
+  nb::handle targetObj(static_cast<PyObject *>(userData));
+  if (!nb::hasattr(targetObj, methodName))
+    return mlirLogicalResultSuccess();
+  try {
+    PyMlirContextRef ctx =
+        PyMlirContext::forContext(mlirOperationGetContext(op));
+    PyOperationRef pyOp = PyOperation::forOperation(ctx, op);
+    bool success = nb::cast<bool>(targetObj.attr(methodName)(pyOp.get()));
+    return success ? mlirLogicalResultSuccess() : mlirLogicalResultFailure();
+  } catch (const std::exception &e) {
+    mlirEmitError(mlirOperationGetLocation(op), e.what());
+    return mlirLogicalResultFailure();
+  }
+}
+
+/// Resolves `opName` (a class exposing `OPERATION_NAME`, or a string) and
+/// attaches `trait` to that operation in `context`.
+/// Takes ownership of `trait`: if the name cannot be resolved, the trait is
+/// destroyed before the exception propagates so that it is never leaked.
+static bool attachOpTrait(const nb::object &opName, MlirDynamicOpTrait trait,
+                          PyMlirContext &context) {
+  std::string opNameStr;
+  try {
+    if (opName.is_type()) {
+      opNameStr = nb::cast<std::string>(opName.attr("OPERATION_NAME"));
+    } else if (nb::isinstance<nb::str>(opName)) {
+      opNameStr = nb::cast<std::string>(opName);
+    } else {
+      throw nb::type_error("the op_name argument must be a type or a string");
+    }
+  } catch (...) {
+    mlirDynamicOpTraitDestroy(trait);
+    throw;
+  }
+
+  return mlirDynamicOpTraitAttach(
+      trait, MlirStringRef{opNameStr.data(), opNameStr.size()}, context.get());
+}
+
+bool PyDynamicOpTrait::attach(const nb::object &opName,
+ const nb::object &target,
+ PyMlirContext &context) {
+ if (!nb::hasattr(target, "verify") && !nb::hasattr(target, "verify_region"))
+ throw nb::type_error(
+ "the target object must have at least one of 'verify' or "
+ "'verify_region' methods");
+
+ MlirDynamicOpTraitCallbacks callbacks;
+ callbacks.construct = [](void *userData) {
+ nb::handle(static_cast<PyObject *>(userData)).inc_ref();
+ };
+ callbacks.destruct = [](void *userData) {
+ nb::handle(static_cast<PyObject *>(userData)).dec_ref();
+ };
+
+ callbacks.verifyTrait = [](MlirOperation op,
+ void *userData) -> MlirLogicalResult {
+ return verifyTraitByMethod(op, userData, "verify");
+ };
+ callbacks.verifyRegionTrait = [](MlirOperation op,
+ void *userData) -> MlirLogicalResult {
+ return verifyTraitByMethod(op, userData, "verify_region");
+ };
+
+ constexpr const char *typeIDAttr = "_TYPE_ID";
+ if (!nb::hasattr(target, typeIDAttr)) {
+ nb::setattr(target, typeIDAttr,
+ nb::cast(PyTypeID(PyGlobals::get().allocateTypeID())));
+ }
+ MlirDynamicOpTrait trait = mlirDynamicOpTraitCreate(
+ nb::cast<PyTypeID>(target.attr(typeIDAttr)).get(), callbacks,
+ static_cast<void *>(target.ptr()));
+ return attachOpTrait(opName, trait, context);
+}
+
void PyDynamicOpTrait::bind(nb::module_ &m) {
- nb::class_<PyDynamicOpTrait>(m, "DynamicOpTrait")
- .def("attach", &PyDynamicOpTrait::attach,
- "Attach the dynamic op trait to the given operation name.",
- nb::arg("op_name"), nb::arg("context").none() = nb::none())
- .def("attach", &PyDynamicOpTrait::attachToOpView,
- "Attach the dynamic op trait to the given OpView class.",
- nb::arg("op_view"), nb::arg("context").none() = nb::none());
+ nb::class_<PyDynamicOpTrait> cls(m, "DynamicOpTrait");
+ cls.def_static(
+ "attach_target",
+ [](const nb::object &opName, const nb::object &target,
+ DefaultingPyMlirContext context) {
+ return PyDynamicOpTrait::attach(opName, target, *context.get());
+ },
+ "Attach the dynamic op trait with the target object to the given "
+ "operation name.",
+ nb::arg("op_name"), nb::arg("target"),
+ nb::arg("context").none() = nb::none());
+ cls.attr("attach") = classmethod(
+ [](const nb::object &cls, const nb::object &opName,
+ DefaultingPyMlirContext context) {
+ return PyDynamicOpTrait::attach(opName, cls, *context.get());
+ },
+ nb::arg("cls"), nb::arg("op_name"),
+ nb::arg("context").none() = nb::none(),
+ "Attach the dynamic op trait subclass to the given operation name.");
+}
+
+bool PyDynamicOpTraits::IsTerminator::attach(const nb::object &opName,
+ PyMlirContext &context) {
+ MlirDynamicOpTrait trait = mlirDynamicOpTraitGetIsTerminator();
+ return attachOpTrait(opName, trait, context);
+}
+
+void PyDynamicOpTraits::IsTerminator::bind(nb::module_ &m) {
+ nb::class_<PyDynamicOpTraits::IsTerminator, PyDynamicOpTrait> cls(
+ m, "IsTerminatorTrait");
+ cls.attr("attach") = classmethod(
+ [](const nb::object &cls, const nb::object &opName,
+ DefaultingPyMlirContext context) {
+ return PyDynamicOpTraits::IsTerminator::attach(opName, *context.get());
+ },
+ "Attach IsTerminator trait to the given operation name.", nb::arg("cls"),
+ nb::arg("op_name"), nb::arg("context").none() = nb::none());
+}
+
+bool PyDynamicOpTraits::NoTerminator::attach(const nb::object &opName,
+ PyMlirContext &context) {
+ MlirDynamicOpTrait trait = mlirDynamicOpTraitGetNoTerminator();
+ return attachOpTrait(opName, trait, context);
+}
+
+void PyDynamicOpTraits::NoTerminator::bind(nb::module_ &m) {
+ nb::class_<PyDynamicOpTraits::NoTerminator, PyDynamicOpTrait> cls(
+ m, "NoTerminatorTrait");
+ cls.attr("attach") = classmethod(
+ [](const nb::object &cls, const nb::object &opName,
+ DefaultingPyMlirContext context) {
+ return PyDynamicOpTraits::NoTerminator::attach(opName, *context.get());
+ },
+ "Attach NoTerminator trait to the given operation name.", nb::arg("cls"),
+ nb::arg("op_name"), nb::arg("context").none() = nb::none());
}
} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
diff --git a/mlir/lib/CAPI/IR/ExtensibleDialect.cpp b/mlir/lib/CAPI/IR/ExtensibleDialect.cpp
index aebcb55bd1042..f3239d996a0e6 100644
--- a/mlir/lib/CAPI/IR/ExtensibleDialect.cpp
+++ b/mlir/lib/CAPI/IR/ExtensibleDialect.cpp
@@ -47,3 +47,56 @@ MlirDynamicOpTrait mlirDynamicOpTraitGetNoTerminator() {
void mlirDynamicOpTraitDestroy(MlirDynamicOpTrait dynamicOpTrait) {
delete unwrap(dynamicOpTrait);
}
+
+namespace mlir {
+
+/// A DynamicOpTrait whose behavior is supplied through C callbacks.
+///
+/// `construct` (if set) runs once when the trait is created and `destruct`
+/// (if set) runs once when it is destroyed, letting clients manage the
+/// lifetime of `userData`. Verification callbacks left as null are treated as
+/// always succeeding.
+class ExternalDynamicOpTrait : public DynamicOpTrait {
+public:
+  ExternalDynamicOpTrait(TypeID typeID, MlirDynamicOpTraitCallbacks callbacks,
+                         void *userData)
+      : typeID(typeID), callbacks(callbacks), userData(userData) {
+    if (callbacks.construct)
+      callbacks.construct(userData);
+  }
+  ~ExternalDynamicOpTrait() {
+    if (callbacks.destruct)
+      callbacks.destruct(userData);
+  }
+
+  // Copying would pair one `construct` call with two `destruct` calls on the
+  // same userData, so the trait is neither copyable nor movable.
+  ExternalDynamicOpTrait(const ExternalDynamicOpTrait &) = delete;
+  ExternalDynamicOpTrait &operator=(const ExternalDynamicOpTrait &) = delete;
+
+  LogicalResult verifyTrait(Operation *op) const override {
+    if (!callbacks.verifyTrait)
+      return success();
+    return unwrap(callbacks.verifyTrait(wrap(op), userData));
+  }
+  LogicalResult verifyRegionTrait(Operation *op) const override {
+    if (!callbacks.verifyRegionTrait)
+      return success();
+    return unwrap(callbacks.verifyRegionTrait(wrap(op), userData));
+  }
+
+  TypeID getTypeID() const override { return typeID; }
+
+private:
+  TypeID typeID;
+  MlirDynamicOpTraitCallbacks callbacks;
+  void *userData;
+};
+
+} // namespace mlir
+
+MlirDynamicOpTrait mlirDynamicOpTraitCreate(
+    MlirTypeID typeID, MlirDynamicOpTraitCallbacks callbacks, void *userData) {
+  return wrap(
+      new mlir::ExternalDynamicOpTrait(unwrap(typeID), callbacks, userData));
+}
diff --git a/mlir/test/python/dialects/ext.py b/mlir/test/python/dialects/ext.py
index 4a497d30a1320..4a69d2e8deb00 100644
--- a/mlir/test/python/dialects/ext.py
+++ b/mlir/test/python/dialects/ext.py
@@ -383,8 +383,19 @@ class NoTermOp(TestRegion.Operation, name="no_term"):
# CHECK: }
print(TestRegion._mlir_module)
- IsTerminatorTrait().attach(YieldOp)
- NoTerminatorTrait().attach(NoTermOp)
+ IsTerminatorTrait.attach(YieldOp)
+ NoTerminatorTrait.attach(NoTermOp)
+
+ class ParentIsIfTrait(DynamicOpTrait):
+ @staticmethod
+ def verify(op) -> bool:
+ if not isinstance(op.parent.opview, IfOp):
+ raise RuntimeError(
+ f"{op.name} should be put inside {IfOp.OPERATION_NAME}"
+ )
+ return True
+
+ ParentIsIfTrait.attach(YieldOp)
# CHECK: (self, /, result, cond, *, loc=None, ip=None)
print(IfOp.__init__.__signature__)
@@ -465,3 +476,14 @@ class NoTermOp(TestRegion.Operation, name="no_term"):
# CHECK: Verification failed:
# CHECK: block with no terminator
print(e)
+
+ module = Module.create()
+ with InsertionPoint(module.body):
+ v = arith.constant(i32, 2)
+ YieldOp(v)
+
+ try:
+ module.operation.verify()
+ except Exception as e:
+ # CHECK: ext_region.yield should be put inside ext_region.if
+ print(e)
>From c35375e65114f75af4808886744144bfacddb0f0 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Sat, 7 Feb 2026 14:43:31 +0800
Subject: [PATCH 8/8] remove useless change
---
mlir/.clang-tidy | 1 +
1 file changed, 1 insertion(+)
diff --git a/mlir/.clang-tidy b/mlir/.clang-tidy
index 57e2da85cce0e..eb8cbbeb9723c 100644
--- a/mlir/.clang-tidy
+++ b/mlir/.clang-tidy
@@ -29,6 +29,7 @@ Checks: >
modernize-use-emplace,
modernize-use-nullptr,
modernize-use-override,
+ modernize-use-using,
performance-for-range-copy,
performance-implicit-conversion-in-loop,
performance-inefficient-algorithm,
More information about the Mlir-commits
mailing list