[Mlir-commits] [mlir] [MLIR][Python] Support `has_trait` for operations (PR #188492)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Mar 25 07:05:59 PDT 2026


https://github.com/PragmaTwice created https://github.com/llvm/llvm-project/pull/188492

This PR adds a `has_trait(trait_cls)` API to `_OperationBase` that can be used for:
- C++-defined operations and C++-defined traits (`func_return_op.has_trait(IsTerminatorTrait)`)
- Python-defined operations and C++-defined traits (`my_python_op.has_trait(IsTerminatorTrait)`)
- Python-defined operations and Python-defined traits (`my_python_op.has_trait(MyPythonTrait)`)

>From f138fd904532c958432638b9d294ce51c1aaee9e Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Wed, 25 Mar 2026 22:01:18 +0800
Subject: [PATCH] [MLIR][Python] Support has_trait for operations

---
 mlir/include/mlir-c/ExtensibleDialect.h    |  8 ++++++++
 mlir/include/mlir-c/IR.h                   |  4 ++++
 mlir/include/mlir/Bindings/Python/IRCore.h |  2 ++
 mlir/include/mlir/IR/ExtensibleDialect.h   |  3 ++-
 mlir/lib/Bindings/Python/IRCore.cpp        | 11 +++++++++--
 mlir/lib/CAPI/IR/ExtensibleDialect.cpp     |  8 ++++++++
 mlir/lib/CAPI/IR/IR.cpp                    |  4 ++++
 mlir/test/python/dialects/builtin.py       | 12 ++++++++++++
 mlir/test/python/dialects/ext.py           | 15 ++++++++++++++-
 mlir/test/python/dialects/func.py          | 11 +++++++++++
 10 files changed, 74 insertions(+), 4 deletions(-)

diff --git a/mlir/include/mlir-c/ExtensibleDialect.h b/mlir/include/mlir-c/ExtensibleDialect.h
index fee26772e1560..c849658c710bb 100644
--- a/mlir/include/mlir-c/ExtensibleDialect.h
+++ b/mlir/include/mlir-c/ExtensibleDialect.h
@@ -50,10 +50,18 @@ mlirDynamicOpTraitAttach(MlirDynamicOpTrait dynamicOpTrait,
 MLIR_CAPI_EXPORTED MlirDynamicOpTrait
 mlirDynamicOpTraitIsTerminatorCreate(void);
 
+/// Get the type ID of the dynamic op trait that indicates the operation is a
+/// terminator.
+MLIR_CAPI_EXPORTED MlirTypeID mlirDynamicOpTraitIsTerminatorGetTypeID(void);
+
 /// Get the dynamic op trait that indicates regions have no terminator.
 MLIR_CAPI_EXPORTED MlirDynamicOpTrait
 mlirDynamicOpTraitNoTerminatorCreate(void);
 
+/// Get the type ID of the dynamic op trait that indicates regions have no
+/// terminator.
+MLIR_CAPI_EXPORTED MlirTypeID mlirDynamicOpTraitNoTerminatorGetTypeID(void);
+
 /// Destroy the dynamic op trait.
 MLIR_CAPI_EXPORTED void
 mlirDynamicOpTraitDestroy(MlirDynamicOpTrait dynamicOpTrait);
diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 805f0ffaaf7ce..efe1d933e3f6a 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -631,6 +631,10 @@ MLIR_CAPI_EXPORTED size_t mlirOperationHashValue(MlirOperation op);
 /// Gets the context this operation is associated with
 MLIR_CAPI_EXPORTED MlirContext mlirOperationGetContext(MlirOperation op);
 
+/// Checks if the operation has a trait identified by the given type id.
+MLIR_CAPI_EXPORTED bool mlirOperationHasTrait(MlirOperation op,
+                                              MlirTypeID traitTypeID);
+
 /// Gets the location of the operation.
 MLIR_CAPI_EXPORTED MlirLocation mlirOperationGetLocation(MlirOperation op);
 
diff --git a/mlir/include/mlir/Bindings/Python/IRCore.h b/mlir/include/mlir/Bindings/Python/IRCore.h
index eefc51d519d62..6b67b78cd2184 100644
--- a/mlir/include/mlir/Bindings/Python/IRCore.h
+++ b/mlir/include/mlir/Bindings/Python/IRCore.h
@@ -1861,6 +1861,8 @@ class MLIR_PYTHON_API_EXPORTED PyDynamicOpTrait {
                      const nanobind::object &target, PyMlirContext &context);
 
   static void bind(nanobind::module_ &m);
+
+  static inline const char *typeIDAttr = "_trait_typeid";
 };
 
 namespace PyDynamicOpTraits {
diff --git a/mlir/include/mlir/IR/ExtensibleDialect.h b/mlir/include/mlir/IR/ExtensibleDialect.h
index dcbf6813506d5..d26ddafd2ef1d 100644
--- a/mlir/include/mlir/IR/ExtensibleDialect.h
+++ b/mlir/include/mlir/IR/ExtensibleDialect.h
@@ -401,7 +401,8 @@ class DynamicOpTraitList {
 template <template <typename T> class Trait>
 class DynamicOpTraitImpl : public DynamicOpTrait {
 public:
-  TypeID getTypeID() const override { return TypeID::get<Trait>(); }
+  static TypeID getStaticTypeID() { return TypeID::get<Trait>(); }
+  TypeID getTypeID() const override { return getStaticTypeID(); }
 };
 
 namespace DynamicOpTraits {
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index f3f1ee4ce343f..0b5fbc635fb50 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -2608,7 +2608,6 @@ bool PyDynamicOpTrait::attach(const nb::object &opName,
 
   // To ensure that the same dynamic trait gets the same TypeID despite how many
   // times `attach` is called, we store it as an attribute on the target class.
-  constexpr const char *typeIDAttr = "_TYPE_ID";
   if (!nb::hasattr(target, typeIDAttr)) {
     nb::setattr(target, typeIDAttr,
                 nb::cast(PyTypeID(PyGlobals::get().allocateTypeID())));
@@ -2642,6 +2641,7 @@ bool PyDynamicOpTraits::IsTerminator::attach(const nb::object &opName,
 void PyDynamicOpTraits::IsTerminator::bind(nb::module_ &m) {
   nb::class_<PyDynamicOpTraits::IsTerminator, PyDynamicOpTrait> cls(
       m, "IsTerminatorTrait");
+  cls.attr(typeIDAttr) = PyTypeID(mlirDynamicOpTraitIsTerminatorGetTypeID());
   cls.attr("attach") = classmethod(
       [](const nb::object &cls, const nb::object &opName,
          DefaultingPyMlirContext context) {
@@ -2660,6 +2660,7 @@ bool PyDynamicOpTraits::NoTerminator::attach(const nb::object &opName,
 void PyDynamicOpTraits::NoTerminator::bind(nb::module_ &m) {
   nb::class_<PyDynamicOpTraits::NoTerminator, PyDynamicOpTrait> cls(
       m, "NoTerminatorTrait");
+  cls.attr(typeIDAttr) = PyTypeID(mlirDynamicOpTraitNoTerminatorGetTypeID());
   cls.attr("attach") = classmethod(
       [](const nb::object &cls, const nb::object &opName,
          DefaultingPyMlirContext context) {
@@ -3963,7 +3964,13 @@ void populateIRCore(nb::module_ &m) {
              Args:
                callback: A callable that takes an Operation and returns a WalkResult.
                walk_order: The order of traversal (PRE_ORDER or POST_ORDER).
-               op_class: If provided, only operations of this type are passed to the callback.)");
+               op_class: If provided, only operations of this type are passed to the callback.)")
+      .def("has_trait", [](PyOperationBase &self, nb::type_object &traitCls) {
+        PyTypeID traitTypeID =
+            nb::cast<PyTypeID>(traitCls.attr(PyDynamicOpTrait::typeIDAttr));
+        return mlirOperationHasTrait(self.getOperation().get(),
+                                     traitTypeID.get());
+      });
 
   nb::class_<PyOperation, PyOperationBase>(m, "Operation")
       .def_static(
diff --git a/mlir/lib/CAPI/IR/ExtensibleDialect.cpp b/mlir/lib/CAPI/IR/ExtensibleDialect.cpp
index 60a7f7d9064eb..5c1f74884f2c5 100644
--- a/mlir/lib/CAPI/IR/ExtensibleDialect.cpp
+++ b/mlir/lib/CAPI/IR/ExtensibleDialect.cpp
@@ -42,10 +42,18 @@ MlirDynamicOpTrait mlirDynamicOpTraitIsTerminatorCreate() {
   return wrap(new DynamicOpTraits::IsTerminator());
 }
 
+MlirTypeID mlirDynamicOpTraitIsTerminatorGetTypeID() {
+  return wrap(DynamicOpTraits::IsTerminator::getStaticTypeID());
+}
+
 MlirDynamicOpTrait mlirDynamicOpTraitNoTerminatorCreate() {
   return wrap(new DynamicOpTraits::NoTerminator());
 }
 
+MlirTypeID mlirDynamicOpTraitNoTerminatorGetTypeID() {
+  return wrap(DynamicOpTraits::NoTerminator::getStaticTypeID());
+}
+
 void mlirDynamicOpTraitDestroy(MlirDynamicOpTrait dynamicOpTrait) {
   delete unwrap(dynamicOpTrait);
 }
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index 29f9287279b8f..9f39ebfa9a450 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -651,6 +651,10 @@ MlirContext mlirOperationGetContext(MlirOperation op) {
   return wrap(unwrap(op)->getContext());
 }
 
+bool mlirOperationHasTrait(MlirOperation op, MlirTypeID traitTypeID) {
+  return unwrap(op)->getName().hasTrait(unwrap(traitTypeID));
+}
+
 MlirLocation mlirOperationGetLocation(MlirOperation op) {
   return wrap(unwrap(op)->getLoc());
 }
diff --git a/mlir/test/python/dialects/builtin.py b/mlir/test/python/dialects/builtin.py
index 973a0eaeca2cd..bd58b2880a816 100644
--- a/mlir/test/python/dialects/builtin.py
+++ b/mlir/test/python/dialects/builtin.py
@@ -250,3 +250,15 @@ def testDenseElementsAttr():
         idx_type = IndexType.get()
         print(DenseElementsAttr.get(idx_values, type=VectorType.get([4], idx_type)))
         # CHECK{LITERAL}: dense<[0, 1, 2, 3]> : vector<4xindex>
+
+
+# CHECK-LABEL: TEST: testBuiltinTraits
+@run
+def testBuiltinTraits():
+    with Context() as ctx, Location.unknown() as loc:
+        module = builtin.ModuleOp()
+
+        # CHECK: True
+        print(module.has_trait(NoTerminatorTrait))
+        # CHECK: False
+        print(module.has_trait(IsTerminatorTrait))
diff --git a/mlir/test/python/dialects/ext.py b/mlir/test/python/dialects/ext.py
index 78c74684cef77..c161ccf6970bf 100644
--- a/mlir/test/python/dialects/ext.py
+++ b/mlir/test/python/dialects/ext.py
@@ -444,7 +444,7 @@ class NoTermOp(TestRegion.Operation, name="no_term", traits=[NoTerminatorTrait])
 
             with InsertionPoint(if_.then.blocks[0]):
                 v = arith.constant(i32, 2)
-                YieldOp(v)
+                yield_ = YieldOp(v)
 
             with InsertionPoint(if_.else_.blocks[0]):
                 v = arith.constant(i32, 3)
@@ -473,6 +473,19 @@ class NoTermOp(TestRegion.Operation, name="no_term", traits=[NoTerminatorTrait])
         # CHECK: }
         print(module)
 
+        # CHECK: True
+        print(yield_.has_trait(IsTerminatorTrait))
+        # CHECK: False
+        print(yield_.has_trait(NoTerminatorTrait))
+        # CHECK: True
+        print(yield_.has_trait(ParentIsIfTrait))
+        # CHECK: False
+        print(nt.has_trait(IsTerminatorTrait))
+        # CHECK: True
+        print(nt.has_trait(NoTerminatorTrait))
+        # CHECK: False
+        print(nt.has_trait(ParentIsIfTrait))
+
         # CHECK: %c2_i32 = arith.constant 2 : i32
         print(if_.then.blocks[0])
         # CHECK: %c3_i32 = arith.constant 3 : i32
diff --git a/mlir/test/python/dialects/func.py b/mlir/test/python/dialects/func.py
index 6b3932ce64f13..cfe4ca4a56ed6 100644
--- a/mlir/test/python/dialects/func.py
+++ b/mlir/test/python/dialects/func.py
@@ -133,3 +133,14 @@ def testFunctionArgAttrs():
 
 # CHECK: func private @foo(f32 {test.foo = "bar"})
 # CHECK: func private @foo2(f32, f32  {test.baz = "qux"})
+
+
+# CHECK-LABEL: TEST: testFunctionTraits
+@constructAndPrintInModule
+def testFunctionTraits():
+    ret = func.ReturnOp([])
+
+    # CHECK: False
+    print(ret.has_trait(NoTerminatorTrait))
+    # CHECK: True
+    print(ret.has_trait(IsTerminatorTrait))



More information about the Mlir-commits mailing list