[Mlir-commits] [mlir] [MLIR][Python] Add C and Python API for `mlir::DynamicAttr` (PR #182820)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Feb 23 00:44:54 PST 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Twice (PragmaTwice)

<details>
<summary>Changes</summary>

This PR adds C and Python API support for `mlir::DynamicAttr`. It primarily enables attributes in dialects that are dynamically generated via IRDL to be constructed in Python, and allows retrieving the parameters contained in a dynamic attribute from Python.

This PR is quite similar to #<!-- -->182751, so I used GitHub Copilot tab completion for some of the code, but verified it manually.

---
Full diff: https://github.com/llvm/llvm-project/pull/182820.diff


5 Files Affected:

- (modified) mlir/include/mlir-c/ExtensibleDialect.h (+37) 
- (modified) mlir/include/mlir/Bindings/Python/IRAttributes.h (+12) 
- (modified) mlir/lib/Bindings/Python/IRAttributes.cpp (+65) 
- (modified) mlir/lib/CAPI/IR/ExtensibleDialect.cpp (+48) 
- (modified) mlir/test/python/dialects/irdl.py (+79) 


``````````diff
diff --git a/mlir/include/mlir-c/ExtensibleDialect.h b/mlir/include/mlir-c/ExtensibleDialect.h
index eabcd080c5d4b..d6aa8181c024c 100644
--- a/mlir/include/mlir-c/ExtensibleDialect.h
+++ b/mlir/include/mlir-c/ExtensibleDialect.h
@@ -33,6 +33,7 @@ extern "C" {
 
 DEFINE_C_API_STRUCT(MlirDynamicOpTrait, void);
 DEFINE_C_API_STRUCT(MlirDynamicTypeDefinition, void);
+DEFINE_C_API_STRUCT(MlirDynamicAttrDefinition, void);
 
 #undef DEFINE_C_API_STRUCT
 
@@ -113,6 +114,42 @@ mlirDynamicTypeDefinitionGetName(MlirDynamicTypeDefinition typeDef);
 MLIR_CAPI_EXPORTED MlirDialect
 mlirDynamicTypeDefinitionGetDialect(MlirDynamicTypeDefinition typeDef);
 
+/// Look up an attribute definition by name in `dialect`, which must be an
+/// extensible dialect. Returns a null definition if none is registered.
+MLIR_CAPI_EXPORTED MlirDynamicAttrDefinition
+mlirExtensibleDialectLookupAttrDefinition(MlirDialect dialect,
+                                          MlirStringRef attrName);
+
+/// Returns true if `attr` is an instance of a dynamic attribute.
+MLIR_CAPI_EXPORTED bool mlirAttributeIsADynamicAttr(MlirAttribute attr);
+
+/// Get the TypeID shared by all dynamic attributes (mlir::DynamicAttr).
+MLIR_CAPI_EXPORTED MlirTypeID mlirDynamicAttrGetTypeID(void);
+
+/// Get a dynamic attribute instantiating `attrDef` with the `numAttrs`
+/// parameter attributes pointed to by `attrs`.
+MLIR_CAPI_EXPORTED MlirAttribute mlirDynamicAttrGet(
+    MlirDynamicAttrDefinition attrDef, MlirAttribute *attrs, intptr_t numAttrs);
+
+/// Get the number of parameters of `attr`, which must be a dynamic attribute.
+MLIR_CAPI_EXPORTED intptr_t mlirDynamicAttrGetNumParams(MlirAttribute attr);
+
+/// Get parameter `index` of dynamic attribute `attr` (must be in range).
+MLIR_CAPI_EXPORTED MlirAttribute mlirDynamicAttrGetParam(MlirAttribute attr,
+                                                         intptr_t index);
+
+/// Get the definition that dynamic attribute `attr` was instantiated from.
+MLIR_CAPI_EXPORTED MlirDynamicAttrDefinition
+mlirDynamicAttrGetAttrDef(MlirAttribute attr);
+
+/// Get the name of `attrDef`, without the dialect namespace prefix.
+MLIR_CAPI_EXPORTED MlirStringRef
+mlirDynamicAttrDefinitionGetName(MlirDynamicAttrDefinition attrDef);
+
+/// Get the dialect that `attrDef` belongs to.
+MLIR_CAPI_EXPORTED MlirDialect
+mlirDynamicAttrDefinitionGetDialect(MlirDynamicAttrDefinition attrDef);
+
 #ifdef __cplusplus
 }
 #endif
diff --git a/mlir/include/mlir/Bindings/Python/IRAttributes.h b/mlir/include/mlir/Bindings/Python/IRAttributes.h
index 2a6dbe229044e..2f1e4a2ad99d0 100644
--- a/mlir/include/mlir/Bindings/Python/IRAttributes.h
+++ b/mlir/include/mlir/Bindings/Python/IRAttributes.h
@@ -588,6 +588,18 @@ class MLIR_PYTHON_API_EXPORTED PyStridedLayoutAttribute
   static void bindDerived(ClassTy &c);
 };
 
+class MLIR_PYTHON_API_EXPORTED PyDynamicAttribute
+    : public PyConcreteAttribute<PyDynamicAttribute> {
+public: // Binds mlir::DynamicAttr, e.g. attributes of IRDL-defined dialects.
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADynamicAttr;
+  static constexpr const char *pyClassName = "DynamicAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirDynamicAttrGetTypeID; // One TypeID shared by every DynamicAttr.
+
+  static void bindDerived(ClassTy &c); // Adds `get`, `params`, `attr_name`.
+};
+
 MLIR_PYTHON_API_EXPORTED void populateIRAttributes(nanobind::module_ &m);
 } // namespace MLIR_BINDINGS_PYTHON_DOMAIN
 } // namespace python
diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index 8c4a2dcd5a7f7..370b9d8d6e062 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -1374,6 +1374,86 @@ void PyStringAttribute::bindDerived(ClassTy &c) {
      "Returns the value of the string attribute as `bytes`");
 }

+void PyDynamicAttribute::bindDerived(ClassTy &c) {
+  // `get` takes "<dialectName>.<attributeName>": the dialect is looked up by
+  // namespace and must be extensible (e.g. loaded from IRDL); the attribute
+  // definition registered in it is then instantiated with `attributes`.
+  c.def_static(
+      "get",
+      [](const std::string &fullAttrName, const std::vector<PyAttribute> &attrs,
+         DefaultingPyMlirContext context) {
+        // Split at the first '.': namespace before it, attribute name after.
+        size_t dotPos = fullAttrName.find('.');
+        if (dotPos == std::string::npos) {
+          throw nb::value_error(
+              "Expected full attribute name to be in the format "
+              "'<dialectName>.<attributeName>'.");
+        }
+
+        std::string dialectName = fullAttrName.substr(0, dotPos);
+        std::string attrName = fullAttrName.substr(dotPos + 1);
+        PyDialects dialects(context->getRef());
+        MlirDialect dialect = dialects.getDialectForKey(dialectName, false);
+        // The C lookup casts to ExtensibleDialect, so reject other dialects
+        // here with a Python error instead of reaching that cast.
+        if (!mlirDialectIsAExtensibleDialect(dialect))
+          throw nb::value_error(
+              ("Dialect '" + dialectName + "' is not an extensible dialect.")
+                  .c_str());
+
+        // A null definition means the dialect has no attribute by that name.
+        MlirDynamicAttrDefinition attrDef =
+            mlirExtensibleDialectLookupAttrDefinition(
+                dialect, toMlirStringRef(attrName));
+        if (attrDef.ptr == nullptr) {
+          throw nb::value_error(("Dialect '" + dialectName +
+                                 "' does not contain an attribute named '" +
+                                 attrName + "'.")
+                                    .c_str());
+        }
+
+        std::vector<MlirAttribute> mlirAttrs;
+        mlirAttrs.reserve(attrs.size());
+        for (const auto &attr : attrs)
+          mlirAttrs.push_back(attr.get());
+        MlirAttribute attr =
+            mlirDynamicAttrGet(attrDef, mlirAttrs.data(),
+                               static_cast<intptr_t>(mlirAttrs.size()));
+        return PyDynamicAttribute(context->getRef(), attr);
+      },
+      nb::arg("full_attr_name"), nb::arg("attributes"),
+      nb::arg("context") = nb::none(), "Create a dynamic attribute.");
+  c.def_prop_ro(
+      "params",
+      [](PyDynamicAttribute &self) {
+        // intptr_t matches the C API's parameter count and index types.
+        // NOTE(review): elements are plain PyAttribute objects; consider
+        // downcasting them to their concrete Python attribute classes.
+        intptr_t numParams = mlirDynamicAttrGetNumParams(self);
+        std::vector<PyAttribute> params;
+        params.reserve(static_cast<size_t>(numParams));
+        for (intptr_t i = 0; i < numParams; ++i)
+          params.emplace_back(self.getContext(),
+                              mlirDynamicAttrGetParam(self, i));
+        return params;
+      },
+      "Returns the parameters of the dynamic attribute as a list of "
+      "attributes.");
+  c.def_prop_ro(
+      "attr_name",
+      [](PyDynamicAttribute &self) {
+        MlirDynamicAttrDefinition attrDef = mlirDynamicAttrGetAttrDef(self);
+        MlirStringRef name = mlirDynamicAttrDefinitionGetName(attrDef);
+        MlirDialect dialect = mlirDynamicAttrDefinitionGetDialect(attrDef);
+        MlirStringRef dialectNamespace = mlirDialectGetNamespace(dialect);
+        // Rebuild the "<dialectName>.<attributeName>" form accepted by `get`.
+        return std::string(dialectNamespace.data, dialectNamespace.length) +
+               "." + std::string(name.data, name.length);
+      },
+      "Returns the full name of the dynamic attribute, in the format "
+      "'<dialectName>.<attributeName>'.");
+}
+
 void populateIRAttributes(nb::module_ &m) {
   PyAffineMapAttribute::bind(m);
   PyDenseBoolArrayAttribute::bind(m);
@@ -1426,6 +1490,7 @@ void populateIRAttributes(nb::module_ &m) {
   PyUnitAttribute::bind(m);
 
   PyStridedLayoutAttribute::bind(m);
+  PyDynamicAttribute::bind(m);
 }
 } // namespace MLIR_BINDINGS_PYTHON_DOMAIN
 } // namespace python
diff --git a/mlir/lib/CAPI/IR/ExtensibleDialect.cpp b/mlir/lib/CAPI/IR/ExtensibleDialect.cpp
index 1659b0afd7354..51cec5f95a201 100644
--- a/mlir/lib/CAPI/IR/ExtensibleDialect.cpp
+++ b/mlir/lib/CAPI/IR/ExtensibleDialect.cpp
@@ -16,6 +16,7 @@ using namespace mlir;
 
 DEFINE_C_API_PTR_METHODS(MlirDynamicOpTrait, DynamicOpTrait)
 DEFINE_C_API_PTR_METHODS(MlirDynamicTypeDefinition, DynamicTypeDefinition)
+DEFINE_C_API_PTR_METHODS(MlirDynamicAttrDefinition, DynamicAttrDefinition)
 
 bool mlirDynamicOpTraitAttach(MlirDynamicOpTrait dynamicOpTrait,
                               MlirStringRef opName, MlirContext context) {
@@ -137,3 +138,55 @@ MlirDialect
 mlirDynamicTypeDefinitionGetDialect(MlirDynamicTypeDefinition typeDef) {
   return wrap(unwrap(typeDef)->getDialect());
 }
+
+MlirDynamicAttrDefinition
+mlirExtensibleDialectLookupAttrDefinition(MlirDialect dialect,
+                                          MlirStringRef attrName) {
+  // llvm::cast asserts that `dialect` is an ExtensibleDialect; the lookup
+  // yields a null definition when no attribute has the given name.
+  return wrap(llvm::cast<mlir::ExtensibleDialect>(unwrap(dialect))
+                  ->lookupAttrDefinition(unwrap(attrName)));
+}
+
+bool mlirAttributeIsADynamicAttr(MlirAttribute attr) {
+  return llvm::isa<mlir::DynamicAttr>(unwrap(attr));
+}
+
+MlirTypeID mlirDynamicAttrGetTypeID(void) {
+  return wrap(mlir::DynamicAttr::getTypeID());
+}
+
+MlirAttribute mlirDynamicAttrGet(MlirDynamicAttrDefinition attrDef,
+                                 MlirAttribute *attrs, intptr_t numAttrs) {
+  // unwrapList converts the C array into C++ attributes held in `storage`.
+  // NOTE(review): DynamicAttr::get presumably asserts that these parameters
+  // pass the definition's verifier; confirm if a checked variant is needed.
+  llvm::SmallVector<mlir::Attribute> storage;
+  return wrap(mlir::DynamicAttr::get(
+      unwrap(attrDef),
+      unwrapList(static_cast<size_t>(numAttrs), attrs, storage)));
+}
+
+// The accessors below require a DynamicAttr; llvm::cast asserts otherwise.
+intptr_t mlirDynamicAttrGetNumParams(MlirAttribute attr) {
+  return static_cast<intptr_t>(
+      llvm::cast<mlir::DynamicAttr>(unwrap(attr)).getParams().size());
+}
+
+MlirAttribute mlirDynamicAttrGetParam(MlirAttribute attr, intptr_t index) {
+  return wrap(llvm::cast<mlir::DynamicAttr>(unwrap(attr)).getParams()[index]);
+}
+
+MlirDynamicAttrDefinition mlirDynamicAttrGetAttrDef(MlirAttribute attr) {
+  return wrap(llvm::cast<mlir::DynamicAttr>(unwrap(attr)).getAttrDef());
+}
+
+MlirStringRef
+mlirDynamicAttrDefinitionGetName(MlirDynamicAttrDefinition attrDef) {
+  return wrap(unwrap(attrDef)->getName());
+}
+
+MlirDialect
+mlirDynamicAttrDefinitionGetDialect(MlirDynamicAttrDefinition attrDef) {
+  return wrap(unwrap(attrDef)->getDialect());
+}
diff --git a/mlir/test/python/dialects/irdl.py b/mlir/test/python/dialects/irdl.py
index da5d51d1b082a..004a795511f75 100644
--- a/mlir/test/python/dialects/irdl.py
+++ b/mlir/test/python/dialects/irdl.py
@@ -66,6 +66,7 @@ def testIRDL():
         m.dump()
 
 
+# CHECK: TEST: testIRDLTypes
 @run
 def testIRDLTypes():
     with Context() as ctx, Location.unknown():
@@ -137,5 +138,83 @@ def testIRDLTypes():
         with InsertionPoint(m.body):
             Operation.create("irdl_type_test.op1", results=[t1])
 
+        assert m.operation.verify()
         # CHECK: %0 = "irdl_type_test.op1"() : () -> !irdl_type_test.type1<42 : i32>
         m.dump()
+
+
+# CHECK: TEST: testIRDLAttrs
+@run
+def testIRDLAttrs():
+    with Context() as ctx, Location.unknown():
+        module = Module.create()
+        with InsertionPoint(module.body):
+            irdl_test = dialect("irdl_attr_test")
+            with InsertionPoint(irdl_test.body):
+                attr1 = attribute("attr1")
+                with InsertionPoint(attr1.body):
+                    iattr = base(base_name="#builtin.integer")
+                    parameters([iattr], ["val"])
+                attr2 = attribute("attr2")
+                with InsertionPoint(attr2.body):
+                    iattr = base(base_name="#builtin.integer")
+                    unit = is_(UnitAttr.get())
+                    parameters([iattr, unit], ["val1", "val2"])
+                op1 = operation_("op1")
+                with InsertionPoint(op1.body):
+                    a1 = base(base_ref=["irdl_attr_test", "attr1"])
+                    attributes_([a1], ["attr"])
+
+        # CHECK: module {
+        # CHECK:   irdl.dialect @irdl_attr_test {
+        # CHECK:     irdl.attribute @attr1 {
+        # CHECK:       %0 = irdl.base "#builtin.integer"
+        # CHECK:       irdl.parameters(val: %0)
+        # CHECK:     }
+        # CHECK:     irdl.attribute @attr2 {
+        # CHECK:       %0 = irdl.base "#builtin.integer"
+        # CHECK:       %1 = irdl.is unit
+        # CHECK:       irdl.parameters(val1: %0, val2: %1)
+        # CHECK:     }
+        # CHECK:     irdl.operation @op1 {
+        # CHECK:       %0 = irdl.base @irdl_attr_test::@attr1
+        # CHECK:       irdl.attributes {"attr" = %0}
+        # CHECK:     }
+        # CHECK:   }
+        # CHECK: }
+        assert module.operation.verify()
+        module.dump()
+
+        load_dialects(module)
+
+        i32 = IntegerType.get(32)
+        a1 = DynamicAttr.get("irdl_attr_test.attr1", [IntegerAttr.get(i32, 42)])
+        # CHECK: #irdl_attr_test.attr1<42 : i32>
+        a1.dump()
+        # CHECK: irdl_attr_test.attr1
+        print(a1.attr_name, file=sys.stderr)
+        # CHECK: 1
+        print(len(a1.params), file=sys.stderr)
+        # CHECK: 42 : i32
+        a1.params[0].dump()
+        a2 = DynamicAttr.get(
+            "irdl_attr_test.attr2", [IntegerAttr.get(i32, 33), UnitAttr.get()]
+        )
+        # CHECK: #irdl_attr_test.attr2<33 : i32, unit>
+        a2.dump()
+        # CHECK: irdl_attr_test.attr2
+        print(a2.attr_name, file=sys.stderr)
+        # CHECK: 2
+        print(len(a2.params), file=sys.stderr)
+        # CHECK: 33 : i32
+        a2.params[0].dump()
+        # CHECK: unit
+        a2.params[1].dump()
+
+        m = Module.create()
+        with InsertionPoint(m.body):
+            Operation.create("irdl_attr_test.op1", attributes={"attr": a1})
+
+        assert m.operation.verify()
+        # CHECK: "irdl_attr_test.op1"() {attr = #irdl_attr_test.attr1<42 : i32>} : () -> ()
+        m.dump()

``````````

</details>


https://github.com/llvm/llvm-project/pull/182820


More information about the Mlir-commits mailing list