[Mlir-commits] [mlir] 1503293 - [MLIR][Python] Support `has_trait` for operations (#188492)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Mar 26 09:02:36 PDT 2026


Author: Twice
Date: 2026-03-27T00:02:31+08:00
New Revision: 1503293cbc7ee0fbeece3c48c2c4d9f4d87d19fd

URL: https://github.com/llvm/llvm-project/commit/1503293cbc7ee0fbeece3c48c2c4d9f4d87d19fd
DIFF: https://github.com/llvm/llvm-project/commit/1503293cbc7ee0fbeece3c48c2c4d9f4d87d19fd.diff

LOG: [MLIR][Python] Support `has_trait` for operations (#188492)

This PR adds a `has_trait(trait_cls)` API to `_OperationBase` (plus a
matching `has_trait` classmethod on `OpView`) that can be used for:
- C++-defined operations and C++-defined traits (e.g.
`func_return_op.has_trait(IsTerminatorTrait)`)
- Python-defined operations and C++-defined traits (e.g.
`my_python_op.has_trait(IsTerminatorTrait)`)
- Python-defined operations and Python-defined traits (e.g.
`my_python_op.has_trait(MyPythonTrait)`)

---------

Co-authored-by: Maksim Levental <maksim.levental at gmail.com>

Added: 
    

Modified: 
    mlir/include/mlir-c/ExtensibleDialect.h
    mlir/include/mlir-c/IR.h
    mlir/include/mlir/Bindings/Python/IRCore.h
    mlir/include/mlir/IR/ExtensibleDialect.h
    mlir/lib/Bindings/Python/IRCore.cpp
    mlir/lib/CAPI/IR/ExtensibleDialect.cpp
    mlir/lib/CAPI/IR/IR.cpp
    mlir/test/python/dialects/builtin.py
    mlir/test/python/dialects/ext.py
    mlir/test/python/dialects/func.py

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir-c/ExtensibleDialect.h b/mlir/include/mlir-c/ExtensibleDialect.h
index fee26772e1560..c849658c710bb 100644
--- a/mlir/include/mlir-c/ExtensibleDialect.h
+++ b/mlir/include/mlir-c/ExtensibleDialect.h
@@ -50,10 +50,18 @@ mlirDynamicOpTraitAttach(MlirDynamicOpTrait dynamicOpTrait,
 MLIR_CAPI_EXPORTED MlirDynamicOpTrait
 mlirDynamicOpTraitIsTerminatorCreate(void);
 
+/// Get the type ID of the dynamic op trait that indicates the operation is a
+/// terminator.
+MLIR_CAPI_EXPORTED MlirTypeID mlirDynamicOpTraitIsTerminatorGetTypeID(void);
+
 /// Get the dynamic op trait that indicates regions have no terminator.
 MLIR_CAPI_EXPORTED MlirDynamicOpTrait
 mlirDynamicOpTraitNoTerminatorCreate(void);
 
+/// Get the type ID of the dynamic op trait that indicates regions have no
+/// terminator.
+MLIR_CAPI_EXPORTED MlirTypeID mlirDynamicOpTraitNoTerminatorGetTypeID(void);
+
 /// Destroy the dynamic op trait.
 MLIR_CAPI_EXPORTED void
 mlirDynamicOpTraitDestroy(MlirDynamicOpTrait dynamicOpTrait);

diff  --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 805f0ffaaf7ce..8d30051d615f4 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -631,6 +631,11 @@ MLIR_CAPI_EXPORTED size_t mlirOperationHashValue(MlirOperation op);
 /// Gets the context this operation is associated with
 MLIR_CAPI_EXPORTED MlirContext mlirOperationGetContext(MlirOperation op);
 
+/// Checks if the operation name has a trait identified by the given type id.
+MLIR_CAPI_EXPORTED bool mlirOperationNameHasTrait(MlirStringRef opName,
+                                                  MlirTypeID traitTypeID,
+                                                  MlirContext context);
+
 /// Gets the location of the operation.
 MLIR_CAPI_EXPORTED MlirLocation mlirOperationGetLocation(MlirOperation op);
 

diff  --git a/mlir/include/mlir/Bindings/Python/IRCore.h b/mlir/include/mlir/Bindings/Python/IRCore.h
index db8427cfc4f78..5f451606d5f82 100644
--- a/mlir/include/mlir/Bindings/Python/IRCore.h
+++ b/mlir/include/mlir/Bindings/Python/IRCore.h
@@ -1842,6 +1842,8 @@ class MLIR_PYTHON_API_EXPORTED PyDynamicOpTrait {
                      const nanobind::object &target, PyMlirContext &context);
 
   static void bind(nanobind::module_ &m);
+
+  static inline const char *typeIDAttr = "_trait_typeid";
 };
 
 namespace PyDynamicOpTraits {

diff  --git a/mlir/include/mlir/IR/ExtensibleDialect.h b/mlir/include/mlir/IR/ExtensibleDialect.h
index dcbf6813506d5..d26ddafd2ef1d 100644
--- a/mlir/include/mlir/IR/ExtensibleDialect.h
+++ b/mlir/include/mlir/IR/ExtensibleDialect.h
@@ -401,7 +401,8 @@ class DynamicOpTraitList {
 template <template <typename T> class Trait>
 class DynamicOpTraitImpl : public DynamicOpTrait {
 public:
-  TypeID getTypeID() const override { return TypeID::get<Trait>(); }
+  static TypeID getStaticTypeID() { return TypeID::get<Trait>(); }
+  TypeID getTypeID() const override { return getStaticTypeID(); }
 };
 
 namespace DynamicOpTraits {

diff  --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 350eb52765e1a..b52328a37e5f1 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -2583,7 +2583,6 @@ bool PyDynamicOpTrait::attach(const nb::object &opName,
 
   // To ensure that the same dynamic trait gets the same TypeID despite how many
   // times `attach` is called, we store it as an attribute on the target class.
-  constexpr const char *typeIDAttr = "_TYPE_ID";
   if (!nb::hasattr(target, typeIDAttr)) {
     nb::setattr(target, typeIDAttr,
                 nb::cast(PyTypeID(PyGlobals::get().allocateTypeID())));
@@ -2617,6 +2616,7 @@ bool PyDynamicOpTraits::IsTerminator::attach(const nb::object &opName,
 void PyDynamicOpTraits::IsTerminator::bind(nb::module_ &m) {
   nb::class_<PyDynamicOpTraits::IsTerminator, PyDynamicOpTrait> cls(
       m, "IsTerminatorTrait");
+  cls.attr(typeIDAttr) = PyTypeID(mlirDynamicOpTraitIsTerminatorGetTypeID());
   cls.attr("attach") = classmethod(
       [](const nb::object &cls, const nb::object &opName,
          DefaultingPyMlirContext context) {
@@ -2635,6 +2635,7 @@ bool PyDynamicOpTraits::NoTerminator::attach(const nb::object &opName,
 void PyDynamicOpTraits::NoTerminator::bind(nb::module_ &m) {
   nb::class_<PyDynamicOpTraits::NoTerminator, PyDynamicOpTrait> cls(
       m, "NoTerminatorTrait");
+  cls.attr(typeIDAttr) = PyTypeID(mlirDynamicOpTraitNoTerminatorGetTypeID());
   cls.attr("attach") = classmethod(
       [](const nb::object &cls, const nb::object &opName,
          DefaultingPyMlirContext context) {
@@ -3938,7 +3939,19 @@ void populateIRCore(nb::module_ &m) {
              Args:
                callback: A callable that takes an Operation and returns a WalkResult.
                walk_order: The order of traversal (PRE_ORDER or POST_ORDER).
-               op_class: If provided, only operations of this type are passed to the callback.)");
+               op_class: If provided, only operations of this type are passed to the callback.)")
+      .def(
+          "has_trait",
+          [](PyOperationBase &self, nb::type_object &traitCls) {
+            PyTypeID traitTypeID =
+                nb::cast<PyTypeID>(traitCls.attr(PyDynamicOpTrait::typeIDAttr));
+            MlirIdentifier opName =
+                mlirOperationGetName(self.getOperation().get());
+            return mlirOperationNameHasTrait(
+                mlirIdentifierStr(opName), traitTypeID.get(),
+                self.getOperation().getContext()->get());
+          },
+          "trait_cls"_a, "Checks if the operation has a given trait.");
 
   nb::class_<PyOperation, PyOperationBase>(m, "Operation")
       .def_static(
@@ -4150,6 +4163,18 @@ void populateIRCore(nb::module_ &m) {
       nb::sig("def parse(cls, source: str, *, source_name: str = '', context: Context | None = None) -> typing.Self"),
       // clang-format on
       "Parses a specific, generated OpView based on class level attributes.");
+  opViewClass.attr("has_trait") = classmethod(
+      [](nb::object &self, nb::type_object &traitCls,
+         DefaultingPyMlirContext &context) {
+        PyTypeID traitTypeID =
+            nb::cast<PyTypeID>(traitCls.attr(PyDynamicOpTrait::typeIDAttr));
+        std::string opName = nb::cast<std::string>(self.attr("OPERATION_NAME"));
+        return mlirOperationNameHasTrait(
+            mlirStringRefCreate(opName.data(), opName.size()),
+            traitTypeID.get(), context->get());
+      },
+      "cls"_a, "trait_cls"_a, "context"_a = nb::none(),
+      "Checks if the operation has a given trait.");
 
   PyOpAdaptor::bind(m);
 

diff  --git a/mlir/lib/CAPI/IR/ExtensibleDialect.cpp b/mlir/lib/CAPI/IR/ExtensibleDialect.cpp
index 60a7f7d9064eb..5c1f74884f2c5 100644
--- a/mlir/lib/CAPI/IR/ExtensibleDialect.cpp
+++ b/mlir/lib/CAPI/IR/ExtensibleDialect.cpp
@@ -42,10 +42,18 @@ MlirDynamicOpTrait mlirDynamicOpTraitIsTerminatorCreate() {
   return wrap(new DynamicOpTraits::IsTerminator());
 }
 
+MlirTypeID mlirDynamicOpTraitIsTerminatorGetTypeID() {
+  return wrap(DynamicOpTraits::IsTerminator::getStaticTypeID());
+}
+
 MlirDynamicOpTrait mlirDynamicOpTraitNoTerminatorCreate() {
   return wrap(new DynamicOpTraits::NoTerminator());
 }
 
+MlirTypeID mlirDynamicOpTraitNoTerminatorGetTypeID() {
+  return wrap(DynamicOpTraits::NoTerminator::getStaticTypeID());
+}
+
 void mlirDynamicOpTraitDestroy(MlirDynamicOpTrait dynamicOpTrait) {
   delete unwrap(dynamicOpTrait);
 }

diff  --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index 29f9287279b8f..c7ecab15a04d4 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -651,6 +651,12 @@ MlirContext mlirOperationGetContext(MlirOperation op) {
   return wrap(unwrap(op)->getContext());
 }
 
+bool mlirOperationNameHasTrait(MlirStringRef opName, MlirTypeID traitTypeID,
+                               MlirContext context) {
+  return OperationName(unwrap(opName), unwrap(context))
+      .hasTrait(unwrap(traitTypeID));
+}
+
 MlirLocation mlirOperationGetLocation(MlirOperation op) {
   return wrap(unwrap(op)->getLoc());
 }

diff  --git a/mlir/test/python/dialects/builtin.py b/mlir/test/python/dialects/builtin.py
index 973a0eaeca2cd..52ada96ee46f9 100644
--- a/mlir/test/python/dialects/builtin.py
+++ b/mlir/test/python/dialects/builtin.py
@@ -250,3 +250,19 @@ def testDenseElementsAttr():
         idx_type = IndexType.get()
         print(DenseElementsAttr.get(idx_values, type=VectorType.get([4], idx_type)))
         # CHECK{LITERAL}: dense<[0, 1, 2, 3]> : vector<4xindex>
+
+
+# CHECK-LABEL: TEST: testBuiltinTraits
+@run
+def testBuiltinTraits():
+    with Context() as ctx, Location.unknown() as loc:
+        module = builtin.ModuleOp()
+
+        # CHECK: True
+        print(module.operation.has_trait(NoTerminatorTrait))
+        # CHECK: False
+        print(module.operation.has_trait(IsTerminatorTrait))
+        # CHECK: True
+        print(builtin.ModuleOp.has_trait(NoTerminatorTrait))
+        # CHECK: False
+        print(builtin.ModuleOp.has_trait(IsTerminatorTrait))

diff  --git a/mlir/test/python/dialects/ext.py b/mlir/test/python/dialects/ext.py
index 78c74684cef77..90f68b876b139 100644
--- a/mlir/test/python/dialects/ext.py
+++ b/mlir/test/python/dialects/ext.py
@@ -444,7 +444,7 @@ class NoTermOp(TestRegion.Operation, name="no_term", traits=[NoTerminatorTrait])
 
             with InsertionPoint(if_.then.blocks[0]):
                 v = arith.constant(i32, 2)
-                YieldOp(v)
+                yield_ = YieldOp(v)
 
             with InsertionPoint(if_.else_.blocks[0]):
                 v = arith.constant(i32, 3)
@@ -473,6 +473,25 @@ class NoTermOp(TestRegion.Operation, name="no_term", traits=[NoTerminatorTrait])
         # CHECK: }
         print(module)
 
+        # CHECK: True
+        print(yield_.has_trait(IsTerminatorTrait))
+        # CHECK: False
+        print(yield_.has_trait(NoTerminatorTrait))
+        # CHECK: True
+        print(yield_.has_trait(ParentIsIfTrait))
+        # CHECK: False
+        print(nt.operation.has_trait(IsTerminatorTrait))
+        # CHECK: True
+        print(nt.operation.has_trait(NoTerminatorTrait))
+        # CHECK: False
+        print(nt.operation.has_trait(ParentIsIfTrait))
+        # CHECK: False
+        print(NoTermOp.has_trait(IsTerminatorTrait))
+        # CHECK: True
+        print(NoTermOp.has_trait(NoTerminatorTrait))
+        # CHECK: False
+        print(NoTermOp.has_trait(ParentIsIfTrait))
+
         # CHECK: %c2_i32 = arith.constant 2 : i32
         print(if_.then.blocks[0])
         # CHECK: %c3_i32 = arith.constant 3 : i32

diff  --git a/mlir/test/python/dialects/func.py b/mlir/test/python/dialects/func.py
index 6b3932ce64f13..3c56bd993aec4 100644
--- a/mlir/test/python/dialects/func.py
+++ b/mlir/test/python/dialects/func.py
@@ -133,3 +133,18 @@ def testFunctionArgAttrs():
 
 # CHECK: func private @foo(f32 {test.foo = "bar"})
 # CHECK: func private @foo2(f32, f32  {test.baz = "qux"})
+
+
+# CHECK-LABEL: TEST: testFunctionTraits
+@constructAndPrintInModule
+def testFunctionTraits():
+    ret = func.ReturnOp([])
+
+    # CHECK: False
+    print(ret.operation.has_trait(NoTerminatorTrait))
+    # CHECK: True
+    print(ret.operation.has_trait(IsTerminatorTrait))
+    # CHECK: False
+    print(func.ReturnOp.has_trait(NoTerminatorTrait))
+    # CHECK: True
+    print(func.ReturnOp.has_trait(IsTerminatorTrait))


        


More information about the Mlir-commits mailing list