[Mlir-commits] [mlir] [MLIR][Python] Support `has_trait` for operations (PR #188492)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Mar 25 07:06:41 PDT 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-core
Author: Twice (PragmaTwice)
<details>
<summary>Changes</summary>
This PR adds a `has_trait(trait_cls)` API to `_OperationBase` that can be used for:
- C++-defined operations and C++-defined traits (`func_return_op.has_trait(IsTerminatorTrait)`)
- Python-defined operations and C++-defined traits (`my_python_op.has_trait(IsTerminatorTrait)`)
- Python-defined operations and Python-defined traits (`my_python_op.has_trait(MyPythonTrait)`)
---
Full diff: https://github.com/llvm/llvm-project/pull/188492.diff
10 Files Affected:
- (modified) mlir/include/mlir-c/ExtensibleDialect.h (+8)
- (modified) mlir/include/mlir-c/IR.h (+4)
- (modified) mlir/include/mlir/Bindings/Python/IRCore.h (+2)
- (modified) mlir/include/mlir/IR/ExtensibleDialect.h (+2-1)
- (modified) mlir/lib/Bindings/Python/IRCore.cpp (+9-2)
- (modified) mlir/lib/CAPI/IR/ExtensibleDialect.cpp (+8)
- (modified) mlir/lib/CAPI/IR/IR.cpp (+4)
- (modified) mlir/test/python/dialects/builtin.py (+12)
- (modified) mlir/test/python/dialects/ext.py (+14-1)
- (modified) mlir/test/python/dialects/func.py (+11)
``````````diff
diff --git a/mlir/include/mlir-c/ExtensibleDialect.h b/mlir/include/mlir-c/ExtensibleDialect.h
index fee26772e1560..c849658c710bb 100644
--- a/mlir/include/mlir-c/ExtensibleDialect.h
+++ b/mlir/include/mlir-c/ExtensibleDialect.h
@@ -50,10 +50,18 @@ mlirDynamicOpTraitAttach(MlirDynamicOpTrait dynamicOpTrait,
MLIR_CAPI_EXPORTED MlirDynamicOpTrait
mlirDynamicOpTraitIsTerminatorCreate(void);
+/// Get the type ID of the dynamic op trait that indicates the operation is a
+/// terminator.
+MLIR_CAPI_EXPORTED MlirTypeID mlirDynamicOpTraitIsTerminatorGetTypeID(void);
+
/// Get the dynamic op trait that indicates regions have no terminator.
MLIR_CAPI_EXPORTED MlirDynamicOpTrait
mlirDynamicOpTraitNoTerminatorCreate(void);
+/// Get the type ID of the dynamic op trait that indicates regions have no
+/// terminator.
+MLIR_CAPI_EXPORTED MlirTypeID mlirDynamicOpTraitNoTerminatorGetTypeID(void);
+
/// Destroy the dynamic op trait.
MLIR_CAPI_EXPORTED void
mlirDynamicOpTraitDestroy(MlirDynamicOpTrait dynamicOpTrait);
diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 805f0ffaaf7ce..efe1d933e3f6a 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -631,6 +631,10 @@ MLIR_CAPI_EXPORTED size_t mlirOperationHashValue(MlirOperation op);
/// Gets the context this operation is associated with
MLIR_CAPI_EXPORTED MlirContext mlirOperationGetContext(MlirOperation op);
+/// Checks if the operation has a trait identified by the given type id.
+MLIR_CAPI_EXPORTED bool mlirOperationHasTrait(MlirOperation op,
+ MlirTypeID traitTypeID);
+
/// Gets the location of the operation.
MLIR_CAPI_EXPORTED MlirLocation mlirOperationGetLocation(MlirOperation op);
diff --git a/mlir/include/mlir/Bindings/Python/IRCore.h b/mlir/include/mlir/Bindings/Python/IRCore.h
index eefc51d519d62..6b67b78cd2184 100644
--- a/mlir/include/mlir/Bindings/Python/IRCore.h
+++ b/mlir/include/mlir/Bindings/Python/IRCore.h
@@ -1861,6 +1861,8 @@ class MLIR_PYTHON_API_EXPORTED PyDynamicOpTrait {
const nanobind::object &target, PyMlirContext &context);
static void bind(nanobind::module_ &m);
+
+ static inline const char *typeIDAttr = "_trait_typeid";
};
namespace PyDynamicOpTraits {
diff --git a/mlir/include/mlir/IR/ExtensibleDialect.h b/mlir/include/mlir/IR/ExtensibleDialect.h
index dcbf6813506d5..d26ddafd2ef1d 100644
--- a/mlir/include/mlir/IR/ExtensibleDialect.h
+++ b/mlir/include/mlir/IR/ExtensibleDialect.h
@@ -401,7 +401,8 @@ class DynamicOpTraitList {
template <template <typename T> class Trait>
class DynamicOpTraitImpl : public DynamicOpTrait {
public:
- TypeID getTypeID() const override { return TypeID::get<Trait>(); }
+ static TypeID getStaticTypeID() { return TypeID::get<Trait>(); }
+ TypeID getTypeID() const override { return getStaticTypeID(); }
};
namespace DynamicOpTraits {
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index f3f1ee4ce343f..0b5fbc635fb50 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -2608,7 +2608,6 @@ bool PyDynamicOpTrait::attach(const nb::object &opName,
// To ensure that the same dynamic trait gets the same TypeID despite how many
// times `attach` is called, we store it as an attribute on the target class.
- constexpr const char *typeIDAttr = "_TYPE_ID";
if (!nb::hasattr(target, typeIDAttr)) {
nb::setattr(target, typeIDAttr,
nb::cast(PyTypeID(PyGlobals::get().allocateTypeID())));
@@ -2642,6 +2641,7 @@ bool PyDynamicOpTraits::IsTerminator::attach(const nb::object &opName,
void PyDynamicOpTraits::IsTerminator::bind(nb::module_ &m) {
nb::class_<PyDynamicOpTraits::IsTerminator, PyDynamicOpTrait> cls(
m, "IsTerminatorTrait");
+ cls.attr(typeIDAttr) = PyTypeID(mlirDynamicOpTraitIsTerminatorGetTypeID());
cls.attr("attach") = classmethod(
[](const nb::object &cls, const nb::object &opName,
DefaultingPyMlirContext context) {
@@ -2660,6 +2660,7 @@ bool PyDynamicOpTraits::NoTerminator::attach(const nb::object &opName,
void PyDynamicOpTraits::NoTerminator::bind(nb::module_ &m) {
nb::class_<PyDynamicOpTraits::NoTerminator, PyDynamicOpTrait> cls(
m, "NoTerminatorTrait");
+ cls.attr(typeIDAttr) = PyTypeID(mlirDynamicOpTraitNoTerminatorGetTypeID());
cls.attr("attach") = classmethod(
[](const nb::object &cls, const nb::object &opName,
DefaultingPyMlirContext context) {
@@ -3963,7 +3964,13 @@ void populateIRCore(nb::module_ &m) {
Args:
callback: A callable that takes an Operation and returns a WalkResult.
walk_order: The order of traversal (PRE_ORDER or POST_ORDER).
- op_class: If provided, only operations of this type are passed to the callback.)");
+ op_class: If provided, only operations of this type are passed to the callback.)")
+ .def("has_trait", [](PyOperationBase &self, nb::type_object &traitCls) {
+ PyTypeID traitTypeID =
+ nb::cast<PyTypeID>(traitCls.attr(PyDynamicOpTrait::typeIDAttr));
+ return mlirOperationHasTrait(self.getOperation().get(),
+ traitTypeID.get());
+ });
nb::class_<PyOperation, PyOperationBase>(m, "Operation")
.def_static(
diff --git a/mlir/lib/CAPI/IR/ExtensibleDialect.cpp b/mlir/lib/CAPI/IR/ExtensibleDialect.cpp
index 60a7f7d9064eb..5c1f74884f2c5 100644
--- a/mlir/lib/CAPI/IR/ExtensibleDialect.cpp
+++ b/mlir/lib/CAPI/IR/ExtensibleDialect.cpp
@@ -42,10 +42,18 @@ MlirDynamicOpTrait mlirDynamicOpTraitIsTerminatorCreate() {
return wrap(new DynamicOpTraits::IsTerminator());
}
+MlirTypeID mlirDynamicOpTraitIsTerminatorGetTypeID() {
+ return wrap(DynamicOpTraits::IsTerminator::getStaticTypeID());
+}
+
MlirDynamicOpTrait mlirDynamicOpTraitNoTerminatorCreate() {
return wrap(new DynamicOpTraits::NoTerminator());
}
+MlirTypeID mlirDynamicOpTraitNoTerminatorGetTypeID() {
+ return wrap(DynamicOpTraits::NoTerminator::getStaticTypeID());
+}
+
void mlirDynamicOpTraitDestroy(MlirDynamicOpTrait dynamicOpTrait) {
delete unwrap(dynamicOpTrait);
}
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index 29f9287279b8f..9f39ebfa9a450 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -651,6 +651,10 @@ MlirContext mlirOperationGetContext(MlirOperation op) {
return wrap(unwrap(op)->getContext());
}
+bool mlirOperationHasTrait(MlirOperation op, MlirTypeID traitTypeID) {
+ return unwrap(op)->getName().hasTrait(unwrap(traitTypeID));
+}
+
MlirLocation mlirOperationGetLocation(MlirOperation op) {
return wrap(unwrap(op)->getLoc());
}
diff --git a/mlir/test/python/dialects/builtin.py b/mlir/test/python/dialects/builtin.py
index 973a0eaeca2cd..bd58b2880a816 100644
--- a/mlir/test/python/dialects/builtin.py
+++ b/mlir/test/python/dialects/builtin.py
@@ -250,3 +250,15 @@ def testDenseElementsAttr():
idx_type = IndexType.get()
print(DenseElementsAttr.get(idx_values, type=VectorType.get([4], idx_type)))
# CHECK{LITERAL}: dense<[0, 1, 2, 3]> : vector<4xindex>
+
+
+# CHECK-LABEL: TEST: testBuiltinTraits
+@run
+def testBuiltinTraits():
+ with Context() as ctx, Location.unknown() as loc:
+ module = builtin.ModuleOp()
+
+ # CHECK: True
+ print(module.has_trait(NoTerminatorTrait))
+ # CHECK: False
+ print(module.has_trait(IsTerminatorTrait))
diff --git a/mlir/test/python/dialects/ext.py b/mlir/test/python/dialects/ext.py
index 78c74684cef77..c161ccf6970bf 100644
--- a/mlir/test/python/dialects/ext.py
+++ b/mlir/test/python/dialects/ext.py
@@ -444,7 +444,7 @@ class NoTermOp(TestRegion.Operation, name="no_term", traits=[NoTerminatorTrait])
with InsertionPoint(if_.then.blocks[0]):
v = arith.constant(i32, 2)
- YieldOp(v)
+ yield_ = YieldOp(v)
with InsertionPoint(if_.else_.blocks[0]):
v = arith.constant(i32, 3)
@@ -473,6 +473,19 @@ class NoTermOp(TestRegion.Operation, name="no_term", traits=[NoTerminatorTrait])
# CHECK: }
print(module)
+ # CHECK: True
+ print(yield_.has_trait(IsTerminatorTrait))
+ # CHECK: False
+ print(yield_.has_trait(NoTerminatorTrait))
+ # CHECK: True
+ print(yield_.has_trait(ParentIsIfTrait))
+ # CHECK: False
+ print(nt.has_trait(IsTerminatorTrait))
+ # CHECK: True
+ print(nt.has_trait(NoTerminatorTrait))
+ # CHECK: False
+ print(nt.has_trait(ParentIsIfTrait))
+
# CHECK: %c2_i32 = arith.constant 2 : i32
print(if_.then.blocks[0])
# CHECK: %c3_i32 = arith.constant 3 : i32
diff --git a/mlir/test/python/dialects/func.py b/mlir/test/python/dialects/func.py
index 6b3932ce64f13..cfe4ca4a56ed6 100644
--- a/mlir/test/python/dialects/func.py
+++ b/mlir/test/python/dialects/func.py
@@ -133,3 +133,14 @@ def testFunctionArgAttrs():
# CHECK: func private @foo(f32 {test.foo = "bar"})
# CHECK: func private @foo2(f32, f32 {test.baz = "qux"})
+
+
+# CHECK-LABEL: TEST: testFunctionTraits
+@constructAndPrintInModule
+def testFunctionTraits():
+ ret = func.ReturnOp([])
+
+ # CHECK: False
+ print(ret.has_trait(NoTerminatorTrait))
+ # CHECK: True
+ print(ret.has_trait(IsTerminatorTrait))
``````````
</details>
https://github.com/llvm/llvm-project/pull/188492
More information about the Mlir-commits
mailing list