[Mlir-commits] [mlir] [MLIR][Python] Support dynamic traits in python-defined dialects (PR #179705)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Feb 4 08:20:00 PST 2026
https://github.com/PragmaTwice created https://github.com/llvm/llvm-project/pull/179705
🚧 WIP 🚧
>From 3b084d320d8b30c665c1d2a3a52dfe6aacca982a Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Thu, 5 Feb 2026 00:04:33 +0800
Subject: [PATCH] [MLIR][Python] Support dynamic traits in python-defined
dialects
---
mlir/include/mlir-c/ExtensibleDialect.h | 55 ++++++++++++++++++++++
mlir/include/mlir/Bindings/Python/IRCore.h | 39 +++++++++++++++
mlir/lib/Bindings/Python/IRCore.cpp | 21 +++++++++
mlir/lib/CAPI/IR/CMakeLists.txt | 1 +
mlir/lib/CAPI/IR/ExtensibleDialect.cpp | 47 ++++++++++++++++++
mlir/test/python/dialects/ext.py | 55 +++++++++++++++++-----
6 files changed, 205 insertions(+), 13 deletions(-)
create mode 100644 mlir/include/mlir-c/ExtensibleDialect.h
create mode 100644 mlir/lib/CAPI/IR/ExtensibleDialect.cpp
diff --git a/mlir/include/mlir-c/ExtensibleDialect.h b/mlir/include/mlir-c/ExtensibleDialect.h
new file mode 100644
index 0000000000000..4a77c76c92d76
--- /dev/null
+++ b/mlir/include/mlir-c/ExtensibleDialect.h
@@ -0,0 +1,55 @@
+//===-- mlir-c/ExtensibleDialect.h - C API for extensible dialect -*- C -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header declares the C API for attaching dynamic op traits to
+// operations defined by extensible (dynamic) dialects. This API is minimalist
+// and experimental at the moment.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_C_EXTENSIBLEDIALECT_H
+#define MLIR_C_EXTENSIBLEDIALECT_H
+
+#include "mlir-c/IR.h"
+#include "mlir-c/Support.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+//===----------------------------------------------------------------------===//
+/// Opaque type declarations (see mlir-c/IR.h for more details).
+//===----------------------------------------------------------------------===//
+
+#define DEFINE_C_API_STRUCT(name, storage) \
+  struct name { \
+    storage *ptr; \
+  }; \
+  typedef struct name name
+
+/// Opaque handle to a dynamic op trait (an mlir::DynamicOpTrait).
+DEFINE_C_API_STRUCT(MlirDynamicOpTrait, void);
+
+// Do not leak the helper macro to includers, consistent with mlir-c/IR.h.
+#undef DEFINE_C_API_STRUCT
+
+/// Attaches `dynamicOpTrait` to the operation named `opName` in `context`.
+/// The operation name must be registered in `context`, and the operation must
+/// be defined by an extensible (dynamic) dialect.
+/// This call takes ownership of the trait: the handle must not be used or
+/// attached again afterwards. Returns whether the trait was attached.
+MLIR_CAPI_EXPORTED bool
+mlirDynamicOpTraitAttach(MlirDynamicOpTrait dynamicOpTrait,
+                         MlirStringRef opName, MlirContext context);
+
+/// Creates a new dynamic op trait indicating that the operation is a
+/// terminator. Each call allocates a fresh trait owned by the caller until it
+/// is passed to mlirDynamicOpTraitAttach.
+MLIR_CAPI_EXPORTED MlirDynamicOpTrait mlirDynamicOpTraitGetIsTerminator(void);
+
+/// Creates a new dynamic op trait indicating that the operation's regions have
+/// no terminator. Each call allocates a fresh trait owned by the caller until
+/// it is passed to mlirDynamicOpTraitAttach.
+MLIR_CAPI_EXPORTED MlirDynamicOpTrait mlirDynamicOpTraitGetNoTerminator(void);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif // MLIR_C_EXTENSIBLEDIALECT_H
diff --git a/mlir/include/mlir/Bindings/Python/IRCore.h b/mlir/include/mlir/Bindings/Python/IRCore.h
index 4bb49e6bc245d..45a04ccc4bb3d 100644
--- a/mlir/include/mlir/Bindings/Python/IRCore.h
+++ b/mlir/include/mlir/Bindings/Python/IRCore.h
@@ -23,6 +23,7 @@
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/Debug.h"
#include "mlir-c/Diagnostics.h"
+#include "mlir-c/ExtensibleDialect.h"
#include "mlir-c/IR.h"
#include "mlir-c/IntegerSet.h"
#include "mlir-c/Support.h"
@@ -1844,6 +1845,44 @@ class MLIR_PYTHON_API_EXPORTED PyOpAdaptor {
PyOpAttributeMap attributes;
};
+/// Python-facing wrapper around an MlirDynamicOpTrait handle.
+///
+/// mlirDynamicOpTraitAttach takes ownership of the trait, so the handle held
+/// here is cleared on attach. Each trait object can therefore be attached only
+/// once, and copies are disabled so two wrappers never share one handle.
+/// NOTE(review): a trait that is never attached is not freed; the C API has no
+/// destroy entry point yet.
+class MLIR_PYTHON_API_EXPORTED PyDynamicOpTrait {
+public:
+  PyDynamicOpTrait(MlirDynamicOpTrait trait) : trait(trait) {}
+  PyDynamicOpTrait(const PyDynamicOpTrait &) = delete;
+  PyDynamicOpTrait &operator=(const PyDynamicOpTrait &) = delete;
+
+  /// Attaches the trait to the registered operation `opName` and returns the
+  /// C API result. Raises ValueError if this trait was already attached.
+  bool attach(const std::string &opName, DefaultingPyMlirContext context) {
+    if (!trait.ptr)
+      throw nanobind::value_error("dynamic op trait has already been attached");
+    MlirDynamicOpTrait owned = trait;
+    // Forget the handle before the call: the C API takes ownership of it.
+    trait.ptr = nullptr;
+    return mlirDynamicOpTraitAttach(owned,
+                                    MlirStringRef{opName.data(), opName.size()},
+                                    context.get()->get());
+  }
+
+  static void bind(nanobind::module_ &m);
+
+private:
+  /// Owned handle; null once the trait has been attached.
+  MlirDynamicOpTrait trait;
+};
+
+namespace PyDynamicOpTraits {
+
+/// Exposed to Python as `IsTerminatorTrait`. Each instance wraps a freshly
+/// created trait indicating that the operation is a terminator.
+class IsTerminator : public PyDynamicOpTrait {
+public:
+  IsTerminator() : PyDynamicOpTrait(mlirDynamicOpTraitGetIsTerminator()) {}
+
+  static void bind(nanobind::module_ &m) {
+    auto cls =
+        nanobind::class_<IsTerminator, PyDynamicOpTrait>(m, "IsTerminatorTrait");
+    cls.def(nanobind::init<>());
+  }
+};
+
+/// Exposed to Python as `NoTerminatorTrait`. Each instance wraps a freshly
+/// created trait indicating that the operation's regions have no terminator.
+class NoTerminator : public PyDynamicOpTrait {
+public:
+  NoTerminator() : PyDynamicOpTrait(mlirDynamicOpTraitGetNoTerminator()) {}
+
+  static void bind(nanobind::module_ &m) {
+    auto cls =
+        nanobind::class_<NoTerminator, PyDynamicOpTrait>(m, "NoTerminatorTrait");
+    cls.def(nanobind::init<>());
+  }
+};
+
+} // namespace PyDynamicOpTraits
+
MLIR_PYTHON_API_EXPORTED MlirValue getUniqueResult(MlirOperation operation);
MLIR_PYTHON_API_EXPORTED void populateIRCore(nanobind::module_ &m);
MLIR_PYTHON_API_EXPORTED void populateRoot(nanobind::module_ &m);
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 7f34343eba6c9..f8fc737e9b8fe 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -2521,6 +2521,22 @@ void PyOpAdaptor::bind(nb::module_ &m) {
"Returns the attributes of the adaptor.");
}
+/// Registers the Python `DynamicOpTrait` base class with two `attach`
+/// overloads: one taking an operation name string, one taking an OpView
+/// subclass whose `OPERATION_NAME` attribute supplies the name.
+void PyDynamicOpTrait::bind(nb::module_ &m) {
+  // Resolve the OpView class to its operation name, then defer to attach().
+  auto attachToOpView = [](PyDynamicOpTrait &self,
+                           const nb::type_object &opView,
+                           DefaultingPyMlirContext context) {
+    std::string opName = nb::cast<std::string>(opView.attr("OPERATION_NAME"));
+    return self.attach(opName, context);
+  };
+
+  nb::class_<PyDynamicOpTrait> cls(m, "DynamicOpTrait");
+  cls.def("attach", &PyDynamicOpTrait::attach,
+          "Attach the dynamic op trait to the given operation name.",
+          nb::arg("op_name"), nb::arg("context").none() = nb::none());
+  cls.def("attach", attachToOpView,
+          "Attach the dynamic op trait to the given OpView class.",
+          nb::arg("op_view"), nb::arg("context").none() = nb::none());
+}
+
} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
} // namespace python
} // namespace mlir
@@ -4844,6 +4860,11 @@ void populateIRCore(nb::module_ &m) {
// Attribute builder getter.
PyAttrBuilderMap::bind(m);
+
+ // Extensible Dialect
+ PyDynamicOpTrait::bind(m);
+ PyDynamicOpTraits::IsTerminator::bind(m);
+ PyDynamicOpTraits::NoTerminator::bind(m);
}
} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
} // namespace python
diff --git a/mlir/lib/CAPI/IR/CMakeLists.txt b/mlir/lib/CAPI/IR/CMakeLists.txt
index 36f28520d6757..d78f9d9735aa3 100644
--- a/mlir/lib/CAPI/IR/CMakeLists.txt
+++ b/mlir/lib/CAPI/IR/CMakeLists.txt
@@ -6,6 +6,7 @@ add_mlir_upstream_c_api_library(MLIRCAPIIR
BuiltinTypes.cpp
Diagnostics.cpp
DialectHandle.cpp
+ ExtensibleDialect.cpp
IntegerSet.cpp
IR.cpp
Pass.cpp
diff --git a/mlir/lib/CAPI/IR/ExtensibleDialect.cpp b/mlir/lib/CAPI/IR/ExtensibleDialect.cpp
new file mode 100644
index 0000000000000..6fe527963895d
--- /dev/null
+++ b/mlir/lib/CAPI/IR/ExtensibleDialect.cpp
@@ -0,0 +1,47 @@
+//===- ExtensibleDialect.cpp - C API for MLIR Extensible Dialect ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir-c/ExtensibleDialect.h"
+#include "mlir/CAPI/IR.h"
+#include "mlir/CAPI/Support.h"
+#include "mlir/IR/ExtensibleDialect.h"
+#include "mlir/IR/OperationSupport.h"
+
+#include <memory>
+#include <optional>
+#include <utility>
+
+using namespace mlir;
+
+DEFINE_C_API_PTR_METHODS(MlirDynamicOpTrait, DynamicOpTrait)
+
+/// Attaches a dynamic trait to a registered operation defined by an
+/// extensible dialect. The trait is owned by this function from the first
+/// statement, so it is released on every failure path instead of leaking.
+/// Returns false when `opName` is not registered in `context`, otherwise the
+/// result of DynamicOpDefinition::addTrait.
+bool mlirDynamicOpTraitAttach(MlirDynamicOpTrait dynamicOpTrait,
+                              MlirStringRef opName, MlirContext context) {
+  std::unique_ptr<DynamicOpTrait> trait(unwrap(dynamicOpTrait));
+
+  std::optional<RegisteredOperationName> opNameFound =
+      RegisteredOperationName::lookup(unwrap(opName), unwrap(context));
+  // Report failure through the return value: an assert alone would turn an
+  // unknown name into a dereference of an empty optional in release builds.
+  if (!opNameFound)
+    return false;
+
+  // The original getImpl() is protected, so we create a small helper struct
+  // here.
+  struct RegisteredOperationNameWithImpl : RegisteredOperationName {
+    Impl *getImpl() { return RegisteredOperationName::getImpl(); }
+  };
+  OperationName::Impl *impl =
+      static_cast<RegisteredOperationNameWithImpl &>(*opNameFound).getImpl();
+
+  // TODO: we should check whether the `impl` is a DynamicOpDefinition here
+  // via llvm-style RTTI.
+  return static_cast<DynamicOpDefinition *>(impl)->addTrait(std::move(trait));
+}
+
+/// Heap-allocates a new IsTerminator trait; ownership passes to the caller.
+MlirDynamicOpTrait mlirDynamicOpTraitGetIsTerminator() {
+  DynamicOpTrait *trait = new DynamicOpTraits::IsTerminator();
+  return wrap(trait);
+}
+
+/// Heap-allocates a new NoTerminator trait; ownership passes to the caller.
+MlirDynamicOpTrait mlirDynamicOpTraitGetNoTerminator() {
+  DynamicOpTrait *trait = new DynamicOpTraits::NoTerminator();
+  return wrap(trait);
+}
diff --git a/mlir/test/python/dialects/ext.py b/mlir/test/python/dialects/ext.py
index 30e705726756b..48ac5c5c51d19 100644
--- a/mlir/test/python/dialects/ext.py
+++ b/mlir/test/python/dialects/ext.py
@@ -350,22 +350,43 @@ class TestRegion(Dialect, name="ext_region"):
class IfOp(TestRegion.Operation, name="if"):
cond: Operand[IntegerType[1]]
+ result: Result[Any]
then: Region
else_: Region
+
+ class YieldOp(TestRegion.Operation, name="yield"):
+ value: Operand[Any]
+
+ class NoTermOp(TestRegion.Operation, name="no_term"):
+ body: Region
with Context(), Location.unknown():
TestRegion.load()
# CHECK: irdl.dialect @ext_region {
- # CHECK: irdl.operation @if {
+ # CHECK: irdl.operation @if {
# CHECK: %0 = irdl.is i1
# CHECK: irdl.operands(cond: %0)
- # CHECK: %1 = irdl.region
+ # CHECK: %1 = irdl.any
+ # CHECK: irdl.results(result: %1)
# CHECK: %2 = irdl.region
- # CHECK: irdl.regions(then: %1, else_: %2)
+ # CHECK: %3 = irdl.region
+ # CHECK: irdl.regions(then: %2, else_: %3)
+ # CHECK: }
+ # CHECK: irdl.operation @yield {
+ # CHECK: %0 = irdl.any
+ # CHECK: irdl.operands(value: %0)
+ # CHECK: }
+ # CHECK: irdl.operation @no_term {
+ # CHECK: %0 = irdl.region
+ # CHECK: irdl.regions(body: %0)
+ # CHECK: }
# CHECK: }
print(TestRegion._mlir_module)
- # CHECK: (self, /, cond, *, loc=None, ip=None)
+ IsTerminatorTrait().attach(YieldOp)
+ NoTerminatorTrait().attach(NoTermOp)
+
+ # CHECK: (self, /, result, cond, *, loc=None, ip=None)
print(IfOp.__init__.__signature__)
# CHECK: None None
@@ -373,36 +394,44 @@ class IfOp(TestRegion.Operation, name="if"):
# CHECK: (2, True)
print(IfOp._ODS_REGIONS)
- from mlir.dialects import llvm
-
module = Module.create()
with InsertionPoint(module.body):
i1 = IntegerType.get_signless(1)
i32 = IntegerType.get_signless(32)
cond = arith.constant(i1, 1)
- if_ = IfOp(cond)
+ if_ = IfOp(i32, cond)
if_.then.blocks.append()
if_.else_.blocks.append()
with InsertionPoint(if_.then.blocks[0]):
v = arith.constant(i32, 2)
- llvm.unreachable()
+ YieldOp(v)
with InsertionPoint(if_.else_.blocks[0]):
v = arith.constant(i32, 3)
- llvm.unreachable()
+ YieldOp(v)
+
+ nt = NoTermOp()
+ nt.body.blocks.append()
+
+ with InsertionPoint(nt.body.blocks[0]):
+ arith.constant(i32, 4)
+ # No terminator here
assert module.operation.verify()
# CHECK: module {
# CHECK: %true = arith.constant true
- # CHECK: "ext_region.if"(%true) ({
+ # CHECK: %0 = "ext_region.if"(%true) ({
# CHECK: %c2_i32 = arith.constant 2 : i32
- # CHECK: llvm.unreachable
+ # CHECK: "ext_region.yield"(%c2_i32) : (i32) -> ()
# CHECK: }, {
# CHECK: %c3_i32 = arith.constant 3 : i32
- # CHECK: llvm.unreachable
- # CHECK: }) : (i1) -> ()
+ # CHECK: "ext_region.yield"(%c3_i32) : (i32) -> ()
+ # CHECK: }) : (i1) -> i32
+ # CHECK: "ext_region.no_term"() ({
+ # CHECK: %c4_i32 = arith.constant 4 : i32
+ # CHECK: }) : () -> ()
# CHECK: }
print(module)
More information about the Mlir-commits
mailing list