[Mlir-commits] [mlir] [MLIR][Python] Support `has_trait` for operations (PR #188492)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Mar 26 08:31:40 PDT 2026
https://github.com/PragmaTwice updated https://github.com/llvm/llvm-project/pull/188492
>From f138fd904532c958432638b9d294ce51c1aaee9e Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Wed, 25 Mar 2026 22:01:18 +0800
Subject: [PATCH 1/3] [MLIR][Python] Support has_trait for operations
---
mlir/include/mlir-c/ExtensibleDialect.h | 8 ++++++++
mlir/include/mlir-c/IR.h | 4 ++++
mlir/include/mlir/Bindings/Python/IRCore.h | 2 ++
mlir/include/mlir/IR/ExtensibleDialect.h | 3 ++-
mlir/lib/Bindings/Python/IRCore.cpp | 11 +++++++++--
mlir/lib/CAPI/IR/ExtensibleDialect.cpp | 8 ++++++++
mlir/lib/CAPI/IR/IR.cpp | 4 ++++
mlir/test/python/dialects/builtin.py | 12 ++++++++++++
mlir/test/python/dialects/ext.py | 15 ++++++++++++++-
mlir/test/python/dialects/func.py | 11 +++++++++++
10 files changed, 74 insertions(+), 4 deletions(-)
diff --git a/mlir/include/mlir-c/ExtensibleDialect.h b/mlir/include/mlir-c/ExtensibleDialect.h
index fee26772e1560..c849658c710bb 100644
--- a/mlir/include/mlir-c/ExtensibleDialect.h
+++ b/mlir/include/mlir-c/ExtensibleDialect.h
@@ -50,10 +50,18 @@ mlirDynamicOpTraitAttach(MlirDynamicOpTrait dynamicOpTrait,
MLIR_CAPI_EXPORTED MlirDynamicOpTrait
mlirDynamicOpTraitIsTerminatorCreate(void);
+/// Get the type ID of the dynamic op trait that indicates the operation is a
+/// terminator.
+MLIR_CAPI_EXPORTED MlirTypeID mlirDynamicOpTraitIsTerminatorGetTypeID(void);
+
/// Get the dynamic op trait that indicates regions have no terminator.
MLIR_CAPI_EXPORTED MlirDynamicOpTrait
mlirDynamicOpTraitNoTerminatorCreate(void);
+/// Get the type ID of the dynamic op trait that indicates regions have no
+/// terminator.
+MLIR_CAPI_EXPORTED MlirTypeID mlirDynamicOpTraitNoTerminatorGetTypeID(void);
+
/// Destroy the dynamic op trait.
MLIR_CAPI_EXPORTED void
mlirDynamicOpTraitDestroy(MlirDynamicOpTrait dynamicOpTrait);
diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 805f0ffaaf7ce..efe1d933e3f6a 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -631,6 +631,10 @@ MLIR_CAPI_EXPORTED size_t mlirOperationHashValue(MlirOperation op);
/// Gets the context this operation is associated with
MLIR_CAPI_EXPORTED MlirContext mlirOperationGetContext(MlirOperation op);
+/// Checks if the operation has a trait identified by the given type id.
+MLIR_CAPI_EXPORTED bool mlirOperationHasTrait(MlirOperation op,
+ MlirTypeID traitTypeID);
+
/// Gets the location of the operation.
MLIR_CAPI_EXPORTED MlirLocation mlirOperationGetLocation(MlirOperation op);
diff --git a/mlir/include/mlir/Bindings/Python/IRCore.h b/mlir/include/mlir/Bindings/Python/IRCore.h
index eefc51d519d62..6b67b78cd2184 100644
--- a/mlir/include/mlir/Bindings/Python/IRCore.h
+++ b/mlir/include/mlir/Bindings/Python/IRCore.h
@@ -1861,6 +1861,8 @@ class MLIR_PYTHON_API_EXPORTED PyDynamicOpTrait {
const nanobind::object &target, PyMlirContext &context);
static void bind(nanobind::module_ &m);
+
+ static inline const char *typeIDAttr = "_trait_typeid";
};
namespace PyDynamicOpTraits {
diff --git a/mlir/include/mlir/IR/ExtensibleDialect.h b/mlir/include/mlir/IR/ExtensibleDialect.h
index dcbf6813506d5..d26ddafd2ef1d 100644
--- a/mlir/include/mlir/IR/ExtensibleDialect.h
+++ b/mlir/include/mlir/IR/ExtensibleDialect.h
@@ -401,7 +401,8 @@ class DynamicOpTraitList {
template <template <typename T> class Trait>
class DynamicOpTraitImpl : public DynamicOpTrait {
public:
- TypeID getTypeID() const override { return TypeID::get<Trait>(); }
+ static TypeID getStaticTypeID() { return TypeID::get<Trait>(); }
+ TypeID getTypeID() const override { return getStaticTypeID(); }
};
namespace DynamicOpTraits {
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index f3f1ee4ce343f..0b5fbc635fb50 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -2608,7 +2608,6 @@ bool PyDynamicOpTrait::attach(const nb::object &opName,
// To ensure that the same dynamic trait gets the same TypeID despite how many
// times `attach` is called, we store it as an attribute on the target class.
- constexpr const char *typeIDAttr = "_TYPE_ID";
if (!nb::hasattr(target, typeIDAttr)) {
nb::setattr(target, typeIDAttr,
nb::cast(PyTypeID(PyGlobals::get().allocateTypeID())));
@@ -2642,6 +2641,7 @@ bool PyDynamicOpTraits::IsTerminator::attach(const nb::object &opName,
void PyDynamicOpTraits::IsTerminator::bind(nb::module_ &m) {
nb::class_<PyDynamicOpTraits::IsTerminator, PyDynamicOpTrait> cls(
m, "IsTerminatorTrait");
+ cls.attr(typeIDAttr) = PyTypeID(mlirDynamicOpTraitIsTerminatorGetTypeID());
cls.attr("attach") = classmethod(
[](const nb::object &cls, const nb::object &opName,
DefaultingPyMlirContext context) {
@@ -2660,6 +2660,7 @@ bool PyDynamicOpTraits::NoTerminator::attach(const nb::object &opName,
void PyDynamicOpTraits::NoTerminator::bind(nb::module_ &m) {
nb::class_<PyDynamicOpTraits::NoTerminator, PyDynamicOpTrait> cls(
m, "NoTerminatorTrait");
+ cls.attr(typeIDAttr) = PyTypeID(mlirDynamicOpTraitNoTerminatorGetTypeID());
cls.attr("attach") = classmethod(
[](const nb::object &cls, const nb::object &opName,
DefaultingPyMlirContext context) {
@@ -3963,7 +3964,13 @@ void populateIRCore(nb::module_ &m) {
Args:
callback: A callable that takes an Operation and returns a WalkResult.
walk_order: The order of traversal (PRE_ORDER or POST_ORDER).
- op_class: If provided, only operations of this type are passed to the callback.)");
+ op_class: If provided, only operations of this type are passed to the callback.)")
+ .def("has_trait", [](PyOperationBase &self, nb::type_object &traitCls) {
+ PyTypeID traitTypeID =
+ nb::cast<PyTypeID>(traitCls.attr(PyDynamicOpTrait::typeIDAttr));
+ return mlirOperationHasTrait(self.getOperation().get(),
+ traitTypeID.get());
+ });
nb::class_<PyOperation, PyOperationBase>(m, "Operation")
.def_static(
diff --git a/mlir/lib/CAPI/IR/ExtensibleDialect.cpp b/mlir/lib/CAPI/IR/ExtensibleDialect.cpp
index 60a7f7d9064eb..5c1f74884f2c5 100644
--- a/mlir/lib/CAPI/IR/ExtensibleDialect.cpp
+++ b/mlir/lib/CAPI/IR/ExtensibleDialect.cpp
@@ -42,10 +42,18 @@ MlirDynamicOpTrait mlirDynamicOpTraitIsTerminatorCreate() {
return wrap(new DynamicOpTraits::IsTerminator());
}
+MlirTypeID mlirDynamicOpTraitIsTerminatorGetTypeID() {
+ return wrap(DynamicOpTraits::IsTerminator::getStaticTypeID());
+}
+
MlirDynamicOpTrait mlirDynamicOpTraitNoTerminatorCreate() {
return wrap(new DynamicOpTraits::NoTerminator());
}
+MlirTypeID mlirDynamicOpTraitNoTerminatorGetTypeID() {
+ return wrap(DynamicOpTraits::NoTerminator::getStaticTypeID());
+}
+
void mlirDynamicOpTraitDestroy(MlirDynamicOpTrait dynamicOpTrait) {
delete unwrap(dynamicOpTrait);
}
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index 29f9287279b8f..9f39ebfa9a450 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -651,6 +651,10 @@ MlirContext mlirOperationGetContext(MlirOperation op) {
return wrap(unwrap(op)->getContext());
}
+bool mlirOperationHasTrait(MlirOperation op, MlirTypeID traitTypeID) {
+ return unwrap(op)->getName().hasTrait(unwrap(traitTypeID));
+}
+
MlirLocation mlirOperationGetLocation(MlirOperation op) {
return wrap(unwrap(op)->getLoc());
}
diff --git a/mlir/test/python/dialects/builtin.py b/mlir/test/python/dialects/builtin.py
index 973a0eaeca2cd..bd58b2880a816 100644
--- a/mlir/test/python/dialects/builtin.py
+++ b/mlir/test/python/dialects/builtin.py
@@ -250,3 +250,15 @@ def testDenseElementsAttr():
idx_type = IndexType.get()
print(DenseElementsAttr.get(idx_values, type=VectorType.get([4], idx_type)))
# CHECK{LITERAL}: dense<[0, 1, 2, 3]> : vector<4xindex>
+
+
+# CHECK-LABEL: TEST: testBuiltinTraits
+@run
+def testBuiltinTraits():
+ with Context() as ctx, Location.unknown() as loc:
+ module = builtin.ModuleOp()
+
+ # CHECK: True
+ print(module.has_trait(NoTerminatorTrait))
+ # CHECK: False
+ print(module.has_trait(IsTerminatorTrait))
diff --git a/mlir/test/python/dialects/ext.py b/mlir/test/python/dialects/ext.py
index 78c74684cef77..c161ccf6970bf 100644
--- a/mlir/test/python/dialects/ext.py
+++ b/mlir/test/python/dialects/ext.py
@@ -444,7 +444,7 @@ class NoTermOp(TestRegion.Operation, name="no_term", traits=[NoTerminatorTrait])
with InsertionPoint(if_.then.blocks[0]):
v = arith.constant(i32, 2)
- YieldOp(v)
+ yield_ = YieldOp(v)
with InsertionPoint(if_.else_.blocks[0]):
v = arith.constant(i32, 3)
@@ -473,6 +473,19 @@ class NoTermOp(TestRegion.Operation, name="no_term", traits=[NoTerminatorTrait])
# CHECK: }
print(module)
+ # CHECK: True
+ print(yield_.has_trait(IsTerminatorTrait))
+ # CHECK: False
+ print(yield_.has_trait(NoTerminatorTrait))
+ # CHECK: True
+ print(yield_.has_trait(ParentIsIfTrait))
+ # CHECK: False
+ print(nt.has_trait(IsTerminatorTrait))
+ # CHECK: True
+ print(nt.has_trait(NoTerminatorTrait))
+ # CHECK: False
+ print(nt.has_trait(ParentIsIfTrait))
+
# CHECK: %c2_i32 = arith.constant 2 : i32
print(if_.then.blocks[0])
# CHECK: %c3_i32 = arith.constant 3 : i32
diff --git a/mlir/test/python/dialects/func.py b/mlir/test/python/dialects/func.py
index 6b3932ce64f13..cfe4ca4a56ed6 100644
--- a/mlir/test/python/dialects/func.py
+++ b/mlir/test/python/dialects/func.py
@@ -133,3 +133,14 @@ def testFunctionArgAttrs():
# CHECK: func private @foo(f32 {test.foo = "bar"})
# CHECK: func private @foo2(f32, f32 {test.baz = "qux"})
+
+
+# CHECK-LABEL: TEST: testFunctionTraits
+@constructAndPrintInModule
+def testFunctionTraits():
+ ret = func.ReturnOp([])
+
+ # CHECK: False
+ print(ret.has_trait(NoTerminatorTrait))
+ # CHECK: True
+ print(ret.has_trait(IsTerminatorTrait))
>From acf4ead4607965ca96600bde56f34512c9f7de1b Mon Sep 17 00:00:00 2001
From: Twice <twice at apache.org>
Date: Thu, 26 Mar 2026 23:16:49 +0800
Subject: [PATCH 2/3] Update mlir/lib/CAPI/IR/IR.cpp
Co-authored-by: Maksim Levental <maksim.levental at gmail.com>
---
mlir/lib/CAPI/IR/IR.cpp | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index 9f39ebfa9a450..3d06ba654a6f5 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -651,8 +651,8 @@ MlirContext mlirOperationGetContext(MlirOperation op) {
return wrap(unwrap(op)->getContext());
}
-bool mlirOperationHasTrait(MlirOperation op, MlirTypeID traitTypeID) {
- return unwrap(op)->getName().hasTrait(unwrap(traitTypeID));
+bool mlirOperationNameHasTrait(MlirStringRef opName, MlirTypeID traitTypeID, MlirContext context) {
+ return OperationName(unwrap(opName), unwrap(context)).hasTrait(unwrap(traitTypeID));
}
MlirLocation mlirOperationGetLocation(MlirOperation op) {
>From 94e4b0629bc2c39bd31485c5e1ab9a65b82f6ca9 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Thu, 26 Mar 2026 23:31:19 +0800
Subject: [PATCH 3/3] fix
---
mlir/include/mlir-c/IR.h | 7 ++++---
mlir/lib/Bindings/Python/IRCore.cpp | 30 ++++++++++++++++++++++------
mlir/lib/CAPI/IR/IR.cpp | 6 ++++--
mlir/test/python/dialects/builtin.py | 8 ++++++--
mlir/test/python/dialects/ext.py | 12 ++++++++---
mlir/test/python/dialects/func.py | 8 ++++++--
6 files changed, 53 insertions(+), 18 deletions(-)
diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index efe1d933e3f6a..8d30051d615f4 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -631,9 +631,10 @@ MLIR_CAPI_EXPORTED size_t mlirOperationHashValue(MlirOperation op);
/// Gets the context this operation is associated with
MLIR_CAPI_EXPORTED MlirContext mlirOperationGetContext(MlirOperation op);
-/// Checks if the operation has a trait identified by the given type id.
-MLIR_CAPI_EXPORTED bool mlirOperationHasTrait(MlirOperation op,
- MlirTypeID traitTypeID);
+/// Checks if the operation name has a trait identified by the given type id.
+MLIR_CAPI_EXPORTED bool mlirOperationNameHasTrait(MlirStringRef opName,
+ MlirTypeID traitTypeID,
+ MlirContext context);
/// Gets the location of the operation.
MLIR_CAPI_EXPORTED MlirLocation mlirOperationGetLocation(MlirOperation op);
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 0b5fbc635fb50..9f35509075a3a 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -3965,12 +3965,18 @@ void populateIRCore(nb::module_ &m) {
callback: A callable that takes an Operation and returns a WalkResult.
walk_order: The order of traversal (PRE_ORDER or POST_ORDER).
op_class: If provided, only operations of this type are passed to the callback.)")
- .def("has_trait", [](PyOperationBase &self, nb::type_object &traitCls) {
- PyTypeID traitTypeID =
- nb::cast<PyTypeID>(traitCls.attr(PyDynamicOpTrait::typeIDAttr));
- return mlirOperationHasTrait(self.getOperation().get(),
- traitTypeID.get());
- });
+ .def(
+ "has_trait",
+ [](PyOperationBase &self, nb::type_object &traitCls) {
+ PyTypeID traitTypeID =
+ nb::cast<PyTypeID>(traitCls.attr(PyDynamicOpTrait::typeIDAttr));
+ MlirIdentifier opName =
+ mlirOperationGetName(self.getOperation().get());
+ return mlirOperationNameHasTrait(
+ mlirIdentifierStr(opName), traitTypeID.get(),
+ self.getOperation().getContext()->get());
+ },
+ "trait_cls"_a, "Checks if the operation has a given trait.");
nb::class_<PyOperation, PyOperationBase>(m, "Operation")
.def_static(
@@ -4172,6 +4178,18 @@ void populateIRCore(nb::module_ &m) {
"cls"_a, "source"_a, nb::kw_only(), "source_name"_a = "",
"context"_a = nb::none(),
"Parses a specific, generated OpView based on class level attributes.");
+ opViewClass.attr("has_trait") = classmethod(
+ [](nb::object &self, nb::type_object &traitCls,
+ DefaultingPyMlirContext &context) {
+ PyTypeID traitTypeID =
+ nb::cast<PyTypeID>(traitCls.attr(PyDynamicOpTrait::typeIDAttr));
+ std::string opName = nb::cast<std::string>(self.attr("OPERATION_NAME"));
+ return mlirOperationNameHasTrait(
+ mlirStringRefCreate(opName.data(), opName.size()),
+ traitTypeID.get(), context->get());
+ },
+ "cls"_a, "trait_cls"_a, "context"_a = nb::none(),
+ "Checks if the operation has a given trait.");
PyOpAdaptor::bind(m);
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index 3d06ba654a6f5..c7ecab15a04d4 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -651,8 +651,10 @@ MlirContext mlirOperationGetContext(MlirOperation op) {
return wrap(unwrap(op)->getContext());
}
-bool mlirOperationNameHasTrait(MlirStringRef opName, MlirTypeID traitTypeID, MlirContext context) {
- return OperationName(unwrap(opName), unwrap(context)).hasTrait(unwrap(traitTypeID));
+bool mlirOperationNameHasTrait(MlirStringRef opName, MlirTypeID traitTypeID,
+ MlirContext context) {
+ return OperationName(unwrap(opName), unwrap(context))
+ .hasTrait(unwrap(traitTypeID));
}
MlirLocation mlirOperationGetLocation(MlirOperation op) {
diff --git a/mlir/test/python/dialects/builtin.py b/mlir/test/python/dialects/builtin.py
index bd58b2880a816..52ada96ee46f9 100644
--- a/mlir/test/python/dialects/builtin.py
+++ b/mlir/test/python/dialects/builtin.py
@@ -259,6 +259,10 @@ def testBuiltinTraits():
module = builtin.ModuleOp()
# CHECK: True
- print(module.has_trait(NoTerminatorTrait))
+ print(module.operation.has_trait(NoTerminatorTrait))
# CHECK: False
- print(module.has_trait(IsTerminatorTrait))
+ print(module.operation.has_trait(IsTerminatorTrait))
+ # CHECK: True
+ print(builtin.ModuleOp.has_trait(NoTerminatorTrait))
+ # CHECK: False
+ print(builtin.ModuleOp.has_trait(IsTerminatorTrait))
diff --git a/mlir/test/python/dialects/ext.py b/mlir/test/python/dialects/ext.py
index c161ccf6970bf..90f68b876b139 100644
--- a/mlir/test/python/dialects/ext.py
+++ b/mlir/test/python/dialects/ext.py
@@ -480,11 +480,17 @@ class NoTermOp(TestRegion.Operation, name="no_term", traits=[NoTerminatorTrait])
# CHECK: True
print(yield_.has_trait(ParentIsIfTrait))
# CHECK: False
- print(nt.has_trait(IsTerminatorTrait))
+ print(nt.operation.has_trait(IsTerminatorTrait))
# CHECK: True
- print(nt.has_trait(NoTerminatorTrait))
+ print(nt.operation.has_trait(NoTerminatorTrait))
# CHECK: False
- print(nt.has_trait(ParentIsIfTrait))
+ print(nt.operation.has_trait(ParentIsIfTrait))
+ # CHECK: False
+ print(NoTermOp.has_trait(IsTerminatorTrait))
+ # CHECK: True
+ print(NoTermOp.has_trait(NoTerminatorTrait))
+ # CHECK: False
+ print(NoTermOp.has_trait(ParentIsIfTrait))
# CHECK: %c2_i32 = arith.constant 2 : i32
print(if_.then.blocks[0])
diff --git a/mlir/test/python/dialects/func.py b/mlir/test/python/dialects/func.py
index cfe4ca4a56ed6..3c56bd993aec4 100644
--- a/mlir/test/python/dialects/func.py
+++ b/mlir/test/python/dialects/func.py
@@ -141,6 +141,10 @@ def testFunctionTraits():
ret = func.ReturnOp([])
# CHECK: False
- print(ret.has_trait(NoTerminatorTrait))
+ print(ret.operation.has_trait(NoTerminatorTrait))
# CHECK: True
- print(ret.has_trait(IsTerminatorTrait))
+ print(ret.operation.has_trait(IsTerminatorTrait))
+ # CHECK: False
+ print(func.ReturnOp.has_trait(NoTerminatorTrait))
+ # CHECK: True
+ print(func.ReturnOp.has_trait(IsTerminatorTrait))
More information about the Mlir-commits
mailing list