[Mlir-commits] [mlir] 1503293 - [MLIR][Python] Support `has_trait` for operations (#188492)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Mar 26 09:02:36 PDT 2026
Author: Twice
Date: 2026-03-27T00:02:31+08:00
New Revision: 1503293cbc7ee0fbeece3c48c2c4d9f4d87d19fd
URL: https://github.com/llvm/llvm-project/commit/1503293cbc7ee0fbeece3c48c2c4d9f4d87d19fd
DIFF: https://github.com/llvm/llvm-project/commit/1503293cbc7ee0fbeece3c48c2c4d9f4d87d19fd.diff
LOG: [MLIR][Python] Support `has_trait` for operations (#188492)
This PR adds a `has_trait(trait_cls)` API to `_OperationBase` (also exposed
as a classmethod on `OpView` subclasses), which can be used with:
- C++-defined operations and C++-defined traits (e.g.
`func_return_op.has_trait(IsTerminatorTrait)`)
- Python-defined operations and C++-defined traits (e.g.
`my_python_op.has_trait(IsTerminatorTrait)`)
- Python-defined operations and Python-defined traits (e.g.
`my_python_op.has_trait(MyPythonTrait)`)
---------
Co-authored-by: Maksim Levental <maksim.levental at gmail.com>
Added:
Modified:
mlir/include/mlir-c/ExtensibleDialect.h
mlir/include/mlir-c/IR.h
mlir/include/mlir/Bindings/Python/IRCore.h
mlir/include/mlir/IR/ExtensibleDialect.h
mlir/lib/Bindings/Python/IRCore.cpp
mlir/lib/CAPI/IR/ExtensibleDialect.cpp
mlir/lib/CAPI/IR/IR.cpp
mlir/test/python/dialects/builtin.py
mlir/test/python/dialects/ext.py
mlir/test/python/dialects/func.py
Removed:
################################################################################
diff --git a/mlir/include/mlir-c/ExtensibleDialect.h b/mlir/include/mlir-c/ExtensibleDialect.h
index fee26772e1560..c849658c710bb 100644
--- a/mlir/include/mlir-c/ExtensibleDialect.h
+++ b/mlir/include/mlir-c/ExtensibleDialect.h
@@ -50,10 +50,18 @@ mlirDynamicOpTraitAttach(MlirDynamicOpTrait dynamicOpTrait,
MLIR_CAPI_EXPORTED MlirDynamicOpTrait
mlirDynamicOpTraitIsTerminatorCreate(void);
+/// Get the type ID of the dynamic op trait that indicates the operation is a
+/// terminator.
+MLIR_CAPI_EXPORTED MlirTypeID mlirDynamicOpTraitIsTerminatorGetTypeID(void);
+
/// Get the dynamic op trait that indicates regions have no terminator.
MLIR_CAPI_EXPORTED MlirDynamicOpTrait
mlirDynamicOpTraitNoTerminatorCreate(void);
+/// Get the type ID of the dynamic op trait that indicates regions have no
+/// terminator.
+MLIR_CAPI_EXPORTED MlirTypeID mlirDynamicOpTraitNoTerminatorGetTypeID(void);
+
/// Destroy the dynamic op trait.
MLIR_CAPI_EXPORTED void
mlirDynamicOpTraitDestroy(MlirDynamicOpTrait dynamicOpTrait);
diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 805f0ffaaf7ce..8d30051d615f4 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -631,6 +631,11 @@ MLIR_CAPI_EXPORTED size_t mlirOperationHashValue(MlirOperation op);
/// Gets the context this operation is associated with
MLIR_CAPI_EXPORTED MlirContext mlirOperationGetContext(MlirOperation op);
+/// Checks if the operation name has a trait identified by the given type id.
+MLIR_CAPI_EXPORTED bool mlirOperationNameHasTrait(MlirStringRef opName,
+ MlirTypeID traitTypeID,
+ MlirContext context);
+
/// Gets the location of the operation.
MLIR_CAPI_EXPORTED MlirLocation mlirOperationGetLocation(MlirOperation op);
diff --git a/mlir/include/mlir/Bindings/Python/IRCore.h b/mlir/include/mlir/Bindings/Python/IRCore.h
index db8427cfc4f78..5f451606d5f82 100644
--- a/mlir/include/mlir/Bindings/Python/IRCore.h
+++ b/mlir/include/mlir/Bindings/Python/IRCore.h
@@ -1842,6 +1842,8 @@ class MLIR_PYTHON_API_EXPORTED PyDynamicOpTrait {
const nanobind::object &target, PyMlirContext &context);
static void bind(nanobind::module_ &m);
+
+ static inline const char *typeIDAttr = "_trait_typeid";
};
namespace PyDynamicOpTraits {
diff --git a/mlir/include/mlir/IR/ExtensibleDialect.h b/mlir/include/mlir/IR/ExtensibleDialect.h
index dcbf6813506d5..d26ddafd2ef1d 100644
--- a/mlir/include/mlir/IR/ExtensibleDialect.h
+++ b/mlir/include/mlir/IR/ExtensibleDialect.h
@@ -401,7 +401,8 @@ class DynamicOpTraitList {
template <template <typename T> class Trait>
class DynamicOpTraitImpl : public DynamicOpTrait {
public:
- TypeID getTypeID() const override { return TypeID::get<Trait>(); }
+ static TypeID getStaticTypeID() { return TypeID::get<Trait>(); }
+ TypeID getTypeID() const override { return getStaticTypeID(); }
};
namespace DynamicOpTraits {
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 350eb52765e1a..b52328a37e5f1 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -2583,7 +2583,6 @@ bool PyDynamicOpTrait::attach(const nb::object &opName,
// To ensure that the same dynamic trait gets the same TypeID despite how many
// times `attach` is called, we store it as an attribute on the target class.
- constexpr const char *typeIDAttr = "_TYPE_ID";
if (!nb::hasattr(target, typeIDAttr)) {
nb::setattr(target, typeIDAttr,
nb::cast(PyTypeID(PyGlobals::get().allocateTypeID())));
@@ -2617,6 +2616,7 @@ bool PyDynamicOpTraits::IsTerminator::attach(const nb::object &opName,
void PyDynamicOpTraits::IsTerminator::bind(nb::module_ &m) {
nb::class_<PyDynamicOpTraits::IsTerminator, PyDynamicOpTrait> cls(
m, "IsTerminatorTrait");
+ cls.attr(typeIDAttr) = PyTypeID(mlirDynamicOpTraitIsTerminatorGetTypeID());
cls.attr("attach") = classmethod(
[](const nb::object &cls, const nb::object &opName,
DefaultingPyMlirContext context) {
@@ -2635,6 +2635,7 @@ bool PyDynamicOpTraits::NoTerminator::attach(const nb::object &opName,
void PyDynamicOpTraits::NoTerminator::bind(nb::module_ &m) {
nb::class_<PyDynamicOpTraits::NoTerminator, PyDynamicOpTrait> cls(
m, "NoTerminatorTrait");
+ cls.attr(typeIDAttr) = PyTypeID(mlirDynamicOpTraitNoTerminatorGetTypeID());
cls.attr("attach") = classmethod(
[](const nb::object &cls, const nb::object &opName,
DefaultingPyMlirContext context) {
@@ -3938,7 +3939,19 @@ void populateIRCore(nb::module_ &m) {
Args:
callback: A callable that takes an Operation and returns a WalkResult.
walk_order: The order of traversal (PRE_ORDER or POST_ORDER).
- op_class: If provided, only operations of this type are passed to the callback.)");
+ op_class: If provided, only operations of this type are passed to the callback.)")
+ .def(
+ "has_trait",
+ [](PyOperationBase &self, nb::type_object &traitCls) {
+ PyTypeID traitTypeID =
+ nb::cast<PyTypeID>(traitCls.attr(PyDynamicOpTrait::typeIDAttr));
+ MlirIdentifier opName =
+ mlirOperationGetName(self.getOperation().get());
+ return mlirOperationNameHasTrait(
+ mlirIdentifierStr(opName), traitTypeID.get(),
+ self.getOperation().getContext()->get());
+ },
+ "trait_cls"_a, "Checks if the operation has a given trait.");
nb::class_<PyOperation, PyOperationBase>(m, "Operation")
.def_static(
@@ -4150,6 +4163,18 @@ void populateIRCore(nb::module_ &m) {
nb::sig("def parse(cls, source: str, *, source_name: str = '', context: Context | None = None) -> typing.Self"),
// clang-format on
"Parses a specific, generated OpView based on class level attributes.");
+ opViewClass.attr("has_trait") = classmethod(
+ [](nb::object &self, nb::type_object &traitCls,
+ DefaultingPyMlirContext &context) {
+ PyTypeID traitTypeID =
+ nb::cast<PyTypeID>(traitCls.attr(PyDynamicOpTrait::typeIDAttr));
+ std::string opName = nb::cast<std::string>(self.attr("OPERATION_NAME"));
+ return mlirOperationNameHasTrait(
+ mlirStringRefCreate(opName.data(), opName.size()),
+ traitTypeID.get(), context->get());
+ },
+ "cls"_a, "trait_cls"_a, "context"_a = nb::none(),
+ "Checks if the operation has a given trait.");
PyOpAdaptor::bind(m);
diff --git a/mlir/lib/CAPI/IR/ExtensibleDialect.cpp b/mlir/lib/CAPI/IR/ExtensibleDialect.cpp
index 60a7f7d9064eb..5c1f74884f2c5 100644
--- a/mlir/lib/CAPI/IR/ExtensibleDialect.cpp
+++ b/mlir/lib/CAPI/IR/ExtensibleDialect.cpp
@@ -42,10 +42,18 @@ MlirDynamicOpTrait mlirDynamicOpTraitIsTerminatorCreate() {
return wrap(new DynamicOpTraits::IsTerminator());
}
+MlirTypeID mlirDynamicOpTraitIsTerminatorGetTypeID() {
+ return wrap(DynamicOpTraits::IsTerminator::getStaticTypeID());
+}
+
MlirDynamicOpTrait mlirDynamicOpTraitNoTerminatorCreate() {
return wrap(new DynamicOpTraits::NoTerminator());
}
+MlirTypeID mlirDynamicOpTraitNoTerminatorGetTypeID() {
+ return wrap(DynamicOpTraits::NoTerminator::getStaticTypeID());
+}
+
void mlirDynamicOpTraitDestroy(MlirDynamicOpTrait dynamicOpTrait) {
delete unwrap(dynamicOpTrait);
}
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index 29f9287279b8f..c7ecab15a04d4 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -651,6 +651,12 @@ MlirContext mlirOperationGetContext(MlirOperation op) {
return wrap(unwrap(op)->getContext());
}
+bool mlirOperationNameHasTrait(MlirStringRef opName, MlirTypeID traitTypeID,
+ MlirContext context) {
+ return OperationName(unwrap(opName), unwrap(context))
+ .hasTrait(unwrap(traitTypeID));
+}
+
MlirLocation mlirOperationGetLocation(MlirOperation op) {
return wrap(unwrap(op)->getLoc());
}
diff --git a/mlir/test/python/dialects/builtin.py b/mlir/test/python/dialects/builtin.py
index 973a0eaeca2cd..52ada96ee46f9 100644
--- a/mlir/test/python/dialects/builtin.py
+++ b/mlir/test/python/dialects/builtin.py
@@ -250,3 +250,19 @@ def testDenseElementsAttr():
idx_type = IndexType.get()
print(DenseElementsAttr.get(idx_values, type=VectorType.get([4], idx_type)))
# CHECK{LITERAL}: dense<[0, 1, 2, 3]> : vector<4xindex>
+
+
+# CHECK-LABEL: TEST: testBuiltinTraits
+@run
+def testBuiltinTraits():
+ with Context() as ctx, Location.unknown() as loc:
+ module = builtin.ModuleOp()
+
+ # CHECK: True
+ print(module.operation.has_trait(NoTerminatorTrait))
+ # CHECK: False
+ print(module.operation.has_trait(IsTerminatorTrait))
+ # CHECK: True
+ print(builtin.ModuleOp.has_trait(NoTerminatorTrait))
+ # CHECK: False
+ print(builtin.ModuleOp.has_trait(IsTerminatorTrait))
diff --git a/mlir/test/python/dialects/ext.py b/mlir/test/python/dialects/ext.py
index 78c74684cef77..90f68b876b139 100644
--- a/mlir/test/python/dialects/ext.py
+++ b/mlir/test/python/dialects/ext.py
@@ -444,7 +444,7 @@ class NoTermOp(TestRegion.Operation, name="no_term", traits=[NoTerminatorTrait])
with InsertionPoint(if_.then.blocks[0]):
v = arith.constant(i32, 2)
- YieldOp(v)
+ yield_ = YieldOp(v)
with InsertionPoint(if_.else_.blocks[0]):
v = arith.constant(i32, 3)
@@ -473,6 +473,25 @@ class NoTermOp(TestRegion.Operation, name="no_term", traits=[NoTerminatorTrait])
# CHECK: }
print(module)
+ # CHECK: True
+ print(yield_.has_trait(IsTerminatorTrait))
+ # CHECK: False
+ print(yield_.has_trait(NoTerminatorTrait))
+ # CHECK: True
+ print(yield_.has_trait(ParentIsIfTrait))
+ # CHECK: False
+ print(nt.operation.has_trait(IsTerminatorTrait))
+ # CHECK: True
+ print(nt.operation.has_trait(NoTerminatorTrait))
+ # CHECK: False
+ print(nt.operation.has_trait(ParentIsIfTrait))
+ # CHECK: False
+ print(NoTermOp.has_trait(IsTerminatorTrait))
+ # CHECK: True
+ print(NoTermOp.has_trait(NoTerminatorTrait))
+ # CHECK: False
+ print(NoTermOp.has_trait(ParentIsIfTrait))
+
# CHECK: %c2_i32 = arith.constant 2 : i32
print(if_.then.blocks[0])
# CHECK: %c3_i32 = arith.constant 3 : i32
diff --git a/mlir/test/python/dialects/func.py b/mlir/test/python/dialects/func.py
index 6b3932ce64f13..3c56bd993aec4 100644
--- a/mlir/test/python/dialects/func.py
+++ b/mlir/test/python/dialects/func.py
@@ -133,3 +133,18 @@ def testFunctionArgAttrs():
# CHECK: func private @foo(f32 {test.foo = "bar"})
# CHECK: func private @foo2(f32, f32 {test.baz = "qux"})
+
+
+# CHECK-LABEL: TEST: testFunctionTraits
+@constructAndPrintInModule
+def testFunctionTraits():
+ ret = func.ReturnOp([])
+
+ # CHECK: False
+ print(ret.operation.has_trait(NoTerminatorTrait))
+ # CHECK: True
+ print(ret.operation.has_trait(IsTerminatorTrait))
+ # CHECK: False
+ print(func.ReturnOp.has_trait(NoTerminatorTrait))
+ # CHECK: True
+ print(func.ReturnOp.has_trait(IsTerminatorTrait))
More information about the Mlir-commits mailing list