[Mlir-commits] [mlir] [MLIR][Python] Add C and Python API for `mlir::DynamicAttr` (PR #182820)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Feb 23 00:44:54 PST 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Twice (PragmaTwice)
<details>
<summary>Changes</summary>
This PR adds C and Python API support for `mlir::DynamicAttr`. It primarily enables attributes in dialects that are dynamically generated via IRDL to be constructed in Python, and allows retrieving the parameters contained in a dynamic attribute from Python.
This PR is quite similar to #182751, so I used GitHub Copilot tab-completion to autocomplete some of the code, but I verified it manually.
---
Full diff: https://github.com/llvm/llvm-project/pull/182820.diff
5 Files Affected:
- (modified) mlir/include/mlir-c/ExtensibleDialect.h (+37)
- (modified) mlir/include/mlir/Bindings/Python/IRAttributes.h (+12)
- (modified) mlir/lib/Bindings/Python/IRAttributes.cpp (+65)
- (modified) mlir/lib/CAPI/IR/ExtensibleDialect.cpp (+48)
- (modified) mlir/test/python/dialects/irdl.py (+79)
``````````diff
diff --git a/mlir/include/mlir-c/ExtensibleDialect.h b/mlir/include/mlir-c/ExtensibleDialect.h
index eabcd080c5d4b..d6aa8181c024c 100644
--- a/mlir/include/mlir-c/ExtensibleDialect.h
+++ b/mlir/include/mlir-c/ExtensibleDialect.h
@@ -33,6 +33,7 @@ extern "C" {
DEFINE_C_API_STRUCT(MlirDynamicOpTrait, void);
DEFINE_C_API_STRUCT(MlirDynamicTypeDefinition, void);
+DEFINE_C_API_STRUCT(MlirDynamicAttrDefinition, void);
#undef DEFINE_C_API_STRUCT
@@ -113,6 +114,42 @@ mlirDynamicTypeDefinitionGetName(MlirDynamicTypeDefinition typeDef);
MLIR_CAPI_EXPORTED MlirDialect
mlirDynamicTypeDefinitionGetDialect(MlirDynamicTypeDefinition typeDef);
+/// Look up a registered attribute definition by attribute name in the given
+/// dialect. Note that the dialect must be an extensible dialect (see
+/// `mlirDialectIsAExtensibleDialect`); the implementation only enforces this
+/// with an `llvm::cast` assertion. `attrName` is the attribute name without the
+/// dialect namespace prefix. The result has a null `ptr` when no attribute with
+/// that name is registered in the dialect.
+MLIR_CAPI_EXPORTED MlirDynamicAttrDefinition
+mlirExtensibleDialectLookupAttrDefinition(MlirDialect dialect,
+                                          MlirStringRef attrName);
+
+/// Check if the given attribute is a dynamic attribute.
+MLIR_CAPI_EXPORTED bool mlirAttributeIsADynamicAttr(MlirAttribute attr);
+
+/// Get the type ID of the `mlir::DynamicAttr` C++ class.
+/// NOTE(review): dynamic attribute instances appear to be uniqued under their
+/// definition's own type ID, so this presumably differs from
+/// `mlirAttributeGetTypeID` of any instance; prefer
+/// `mlirAttributeIsADynamicAttr` to test for dynamic attributes -- confirm.
+MLIR_CAPI_EXPORTED MlirTypeID mlirDynamicAttrGetTypeID(void);
+
+/// Get a dynamic attribute by instantiating the given attribute definition with
+/// the provided attributes. `attrs` points to `numAttrs` attributes that become
+/// the parameters, in order; the array is only read during the call.
+/// NOTE(review): the implementation calls `DynamicAttr::get` (not
+/// `getChecked`), so the parameters are presumably not verified against the
+/// definition's constraints here -- verify the IR that uses the result.
+MLIR_CAPI_EXPORTED MlirAttribute mlirDynamicAttrGet(
+    MlirDynamicAttrDefinition attrDef, MlirAttribute *attrs, intptr_t numAttrs);
+
+/// Get the number of parameters in the given dynamic attribute. `attr` must be
+/// a dynamic attribute (see `mlirAttributeIsADynamicAttr`).
+MLIR_CAPI_EXPORTED intptr_t mlirDynamicAttrGetNumParams(MlirAttribute attr);
+
+/// Get the parameter at the given index in the provided dynamic attribute.
+/// `attr` must be a dynamic attribute and `index` must be in
+/// `[0, mlirDynamicAttrGetNumParams(attr))`; neither is checked beyond LLVM
+/// assertions.
+MLIR_CAPI_EXPORTED MlirAttribute mlirDynamicAttrGetParam(MlirAttribute attr,
+                                                         intptr_t index);
+
+/// Get the attribute definition of the given dynamic attribute. `attr` must be
+/// a dynamic attribute.
+MLIR_CAPI_EXPORTED MlirDynamicAttrDefinition
+mlirDynamicAttrGetAttrDef(MlirAttribute attr);
+
+/// Get the name of the given dynamic attribute definition, without the dialect
+/// namespace prefix.
+MLIR_CAPI_EXPORTED MlirStringRef
+mlirDynamicAttrDefinitionGetName(MlirDynamicAttrDefinition attrDef);
+
+/// Get the dialect that the given dynamic attribute definition belongs to.
+MLIR_CAPI_EXPORTED MlirDialect
+mlirDynamicAttrDefinitionGetDialect(MlirDynamicAttrDefinition attrDef);
+
#ifdef __cplusplus
}
#endif
diff --git a/mlir/include/mlir/Bindings/Python/IRAttributes.h b/mlir/include/mlir/Bindings/Python/IRAttributes.h
index 2a6dbe229044e..2f1e4a2ad99d0 100644
--- a/mlir/include/mlir/Bindings/Python/IRAttributes.h
+++ b/mlir/include/mlir/Bindings/Python/IRAttributes.h
@@ -588,6 +588,18 @@ class MLIR_PYTHON_API_EXPORTED PyStridedLayoutAttribute
static void bindDerived(ClassTy &c);
};
+/// Python binding (`DynamicAttr`) for `mlir::DynamicAttr`: attributes whose
+/// definitions are registered at runtime in an extensible dialect, e.g. one
+/// generated from IRDL.
+/// NOTE(review): `getTypeIdFunction` reports the `mlir::DynamicAttr` class
+/// type ID; instances are presumably uniqued under their definition's type ID,
+/// so automatic downcasting to this class may not trigger -- confirm.
+class MLIR_PYTHON_API_EXPORTED PyDynamicAttribute
+    : public PyConcreteAttribute<PyDynamicAttribute> {
+public:
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static constexpr const char *pyClassName = "DynamicAttr";
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADynamicAttr;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirDynamicAttrGetTypeID;
+
+  /// Adds `get`, `params` and `attr_name` to the Python class.
+  static void bindDerived(ClassTy &c);
+};
+
MLIR_PYTHON_API_EXPORTED void populateIRAttributes(nanobind::module_ &m);
} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
} // namespace python
diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index 8c4a2dcd5a7f7..370b9d8d6e062 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -1374,6 +1374,70 @@ void PyStringAttribute::bindDerived(ClassTy &c) {
"Returns the value of the string attribute as `bytes`");
}
+/// Binds `DynamicAttr.get`, `DynamicAttr.params` and `DynamicAttr.attr_name`.
+void PyDynamicAttribute::bindDerived(ClassTy &c) {
+  c.def_static(
+      "get",
+      [](const std::string &fullAttrName, const std::vector<PyAttribute> &attrs,
+         DefaultingPyMlirContext context) {
+        // Split "<dialectName>.<attributeName>" at the first '.'.
+        size_t dotPos = fullAttrName.find('.');
+        if (dotPos == std::string::npos) {
+          throw nb::value_error(
+              "Expected full attribute name to be in the format "
+              "'<dialectName>.<attributeName>'.");
+        }
+
+        std::string dialectName = fullAttrName.substr(0, dotPos);
+        std::string attrName = fullAttrName.substr(dotPos + 1);
+        PyDialects dialects(context->getRef());
+        MlirDialect dialect = dialects.getDialectForKey(dialectName, false);
+        // The C API lookup below requires an extensible dialect; report a
+        // Python error instead of tripping its `llvm::cast` assertion.
+        if (!mlirDialectIsAExtensibleDialect(dialect))
+          throw nb::value_error(
+              ("Dialect '" + dialectName + "' is not an extensible dialect.")
+                  .c_str());
+
+        MlirDynamicAttrDefinition attrDef =
+            mlirExtensibleDialectLookupAttrDefinition(
+                dialect, toMlirStringRef(attrName));
+        if (attrDef.ptr == nullptr) {
+          throw nb::value_error(("Dialect '" + dialectName +
+                                 "' does not contain an attribute named '" +
+                                 attrName + "'.")
+                                    .c_str());
+        }
+
+        std::vector<MlirAttribute> mlirAttrs;
+        mlirAttrs.reserve(attrs.size());
+        for (const auto &attr : attrs)
+          mlirAttrs.push_back(attr.get());
+        // The C API takes a signed element count.
+        MlirAttribute attr = mlirDynamicAttrGet(
+            attrDef, mlirAttrs.data(), static_cast<intptr_t>(mlirAttrs.size()));
+        return PyDynamicAttribute(context->getRef(), attr);
+      },
+      nb::arg("full_attr_name"), nb::arg("attributes"),
+      nb::arg("context") = nb::none(), "Create a dynamic attribute.");
+  c.def_prop_ro(
+      "params",
+      [](PyDynamicAttribute &self) {
+        // Downcast each parameter so Python callers receive the concrete
+        // attribute subclass (e.g. IntegerAttr) rather than a bare Attribute.
+        intptr_t numParams = mlirDynamicAttrGetNumParams(self);
+        nb::list params;
+        for (intptr_t i = 0; i < numParams; ++i)
+          params.append(
+              PyAttribute(self.getContext(), mlirDynamicAttrGetParam(self, i))
+                  .maybeDownCast());
+        return params;
+      },
+      "Returns the parameters of the dynamic attribute as a list of "
+      "attributes.");
+  c.def_prop_ro(
+      "attr_name",
+      [](PyDynamicAttribute &self) {
+        MlirDynamicAttrDefinition attrDef = mlirDynamicAttrGetAttrDef(self);
+        MlirStringRef name = mlirDynamicAttrDefinitionGetName(attrDef);
+        MlirDialect dialect = mlirDynamicAttrDefinitionGetDialect(attrDef);
+        MlirStringRef dialectNamespace = mlirDialectGetNamespace(dialect);
+        return std::string(dialectNamespace.data, dialectNamespace.length) +
+               "." + std::string(name.data, name.length);
+      },
+      "Returns the full name of the dynamic attribute in the format "
+      "'<dialectName>.<attributeName>'.");
+}
+
void populateIRAttributes(nb::module_ &m) {
PyAffineMapAttribute::bind(m);
PyDenseBoolArrayAttribute::bind(m);
@@ -1426,6 +1490,7 @@ void populateIRAttributes(nb::module_ &m) {
PyUnitAttribute::bind(m);
PyStridedLayoutAttribute::bind(m);
+ PyDynamicAttribute::bind(m);
}
} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
} // namespace python
diff --git a/mlir/lib/CAPI/IR/ExtensibleDialect.cpp b/mlir/lib/CAPI/IR/ExtensibleDialect.cpp
index 1659b0afd7354..51cec5f95a201 100644
--- a/mlir/lib/CAPI/IR/ExtensibleDialect.cpp
+++ b/mlir/lib/CAPI/IR/ExtensibleDialect.cpp
@@ -16,6 +16,7 @@ using namespace mlir;
DEFINE_C_API_PTR_METHODS(MlirDynamicOpTrait, DynamicOpTrait)
DEFINE_C_API_PTR_METHODS(MlirDynamicTypeDefinition, DynamicTypeDefinition)
+DEFINE_C_API_PTR_METHODS(MlirDynamicAttrDefinition, DynamicAttrDefinition)
bool mlirDynamicOpTraitAttach(MlirDynamicOpTrait dynamicOpTrait,
MlirStringRef opName, MlirContext context) {
@@ -137,3 +138,50 @@ MlirDialect
mlirDynamicTypeDefinitionGetDialect(MlirDynamicTypeDefinition typeDef) {
return wrap(unwrap(typeDef)->getDialect());
}
+
+/// Looks up `attrName` among the dynamic attributes of an extensible dialect.
+MlirDynamicAttrDefinition
+mlirExtensibleDialectLookupAttrDefinition(MlirDialect dialect,
+                                          MlirStringRef attrName) {
+  // Callers must pass an extensible dialect; `llvm::cast` asserts this.
+  auto *extensibleDialect =
+      llvm::cast<mlir::ExtensibleDialect>(unwrap(dialect));
+  mlir::DynamicAttrDefinition *attrDef =
+      extensibleDialect->lookupAttrDefinition(unwrap(attrName));
+  return wrap(attrDef);
+}
+
+/// Returns true when `attr` is an instance of `mlir::DynamicAttr`.
+bool mlirAttributeIsADynamicAttr(MlirAttribute attr) {
+  mlir::Attribute attribute = unwrap(attr);
+  return llvm::isa<mlir::DynamicAttr>(attribute);
+}
+
+/// Returns the type ID of the `mlir::DynamicAttr` class.
+MlirTypeID mlirDynamicAttrGetTypeID(void) {
+  mlir::TypeID typeID = mlir::DynamicAttr::getTypeID();
+  return wrap(typeID);
+}
+
+/// Instantiates `attrDef` with the `numAttrs` attributes pointed to by `attrs`.
+MlirAttribute mlirDynamicAttrGet(MlirDynamicAttrDefinition attrDef,
+                                 MlirAttribute *attrs, intptr_t numAttrs) {
+  // Copy the C handles into C++ attributes; `attrs` itself is not retained.
+  llvm::SmallVector<mlir::Attribute> params;
+  params.reserve(numAttrs);
+  for (MlirAttribute cAttr : llvm::ArrayRef<MlirAttribute>(attrs, numAttrs))
+    params.push_back(unwrap(cAttr));
+  return wrap(mlir::DynamicAttr::get(unwrap(attrDef), params));
+}
+
+/// Returns how many parameters the dynamic attribute `attr` holds.
+intptr_t mlirDynamicAttrGetNumParams(MlirAttribute attr) {
+  auto dynamicAttr = llvm::cast<mlir::DynamicAttr>(unwrap(attr));
+  return static_cast<intptr_t>(dynamicAttr.getParams().size());
+}
+
+/// Returns parameter `index` of the dynamic attribute `attr` (unchecked index).
+MlirAttribute mlirDynamicAttrGetParam(MlirAttribute attr, intptr_t index) {
+  llvm::ArrayRef<mlir::Attribute> params =
+      llvm::cast<mlir::DynamicAttr>(unwrap(attr)).getParams();
+  return wrap(params[index]);
+}
+
+/// Returns the definition that the dynamic attribute `attr` instantiates.
+MlirDynamicAttrDefinition mlirDynamicAttrGetAttrDef(MlirAttribute attr) {
+  mlir::DynamicAttrDefinition *attrDef =
+      llvm::cast<mlir::DynamicAttr>(unwrap(attr)).getAttrDef();
+  return wrap(attrDef);
+}
+
+/// Returns the (namespace-less) name of the dynamic attribute definition.
+MlirStringRef
+mlirDynamicAttrDefinitionGetName(MlirDynamicAttrDefinition attrDef) {
+  llvm::StringRef name = unwrap(attrDef)->getName();
+  return wrap(name);
+}
+
+/// Returns the dialect that owns the dynamic attribute definition.
+MlirDialect
+mlirDynamicAttrDefinitionGetDialect(MlirDynamicAttrDefinition attrDef) {
+  mlir::Dialect *owner = unwrap(attrDef)->getDialect();
+  return wrap(owner);
+}
diff --git a/mlir/test/python/dialects/irdl.py b/mlir/test/python/dialects/irdl.py
index da5d51d1b082a..004a795511f75 100644
--- a/mlir/test/python/dialects/irdl.py
+++ b/mlir/test/python/dialects/irdl.py
@@ -66,6 +66,7 @@ def testIRDL():
m.dump()
+# CHECK: TEST: testIRDLTypes
@run
def testIRDLTypes():
with Context() as ctx, Location.unknown():
@@ -137,5 +138,83 @@ def testIRDLTypes():
with InsertionPoint(m.body):
Operation.create("irdl_type_test.op1", results=[t1])
+ assert m.operation.verify()
# CHECK: %0 = "irdl_type_test.op1"() : () -> !irdl_type_test.type1<42 : i32>
m.dump()
+
+
+# CHECK: TEST: testIRDLAttrs
+ at run
+def testIRDLAttrs():
+ with Context() as ctx, Location.unknown():
+ module = Module.create()
+ with InsertionPoint(module.body):
+ irdl_test = dialect("irdl_attr_test")
+ with InsertionPoint(irdl_test.body):
+ attr1 = attribute("attr1")
+ with InsertionPoint(attr1.body):
+ iattr = base(base_name="#builtin.integer")
+ parameters([iattr], ["val"])
+ attr2 = attribute("attr2")
+ with InsertionPoint(attr2.body):
+ iattr = base(base_name="#builtin.integer")
+ unit = is_(UnitAttr.get())
+ parameters([iattr, unit], ["val1", "val2"])
+ op1 = operation_("op1")
+ with InsertionPoint(op1.body):
+ a1 = base(base_ref=["irdl_attr_test", "attr1"])
+ attributes_([a1], ["attr"])
+
+ # CHECK: module {
+ # CHECK: irdl.dialect @irdl_attr_test {
+ # CHECK: irdl.attribute @attr1 {
+ # CHECK: %0 = irdl.base "#builtin.integer"
+ # CHECK: irdl.parameters(val: %0)
+ # CHECK: }
+ # CHECK: irdl.attribute @attr2 {
+ # CHECK: %0 = irdl.base "#builtin.integer"
+ # CHECK: %1 = irdl.is unit
+ # CHECK: irdl.parameters(val1: %0, val2: %1)
+ # CHECK: }
+ # CHECK: irdl.operation @op1 {
+ # CHECK: %0 = irdl.base @irdl_attr_test::@attr1
+ # CHECK: irdl.attributes {"attr" = %0}
+ # CHECK: }
+ # CHECK: }
+ # CHECK: }
+ module.operation.verify()
+ module.dump()
+
+ load_dialects(module)
+
+ i32 = IntegerType.get(32)
+ a1 = DynamicAttr.get("irdl_attr_test.attr1", [IntegerAttr.get(i32, 42)])
+ # CHECK: #irdl_attr_test.attr1<42 : i32>
+ a1.dump()
+ # CHECK: irdl_attr_test.attr1
+ print(a1.attr_name, file=sys.stderr)
+ # CHECK: 1
+ print(len(a1.params), file=sys.stderr)
+ # CHECK: 42 : i32
+ a1.params[0].dump()
+ a2 = DynamicAttr.get(
+ "irdl_attr_test.attr2", [IntegerAttr.get(i32, 33), UnitAttr.get()]
+ )
+ # CHECK: #irdl_attr_test.attr2<33 : i32, unit>
+ a2.dump()
+ # CHECK: irdl_attr_test.attr2
+ print(a2.attr_name, file=sys.stderr)
+ # CHECK: 2
+ print(len(a2.params), file=sys.stderr)
+ # CHECK: 33 : i32
+ a2.params[0].dump()
+ # CHECK: unit
+ a2.params[1].dump()
+
+ m = Module.create()
+ with InsertionPoint(m.body):
+ Operation.create("irdl_attr_test.op1", attributes={"attr": a1})
+
+ assert m.operation.verify()
+ # CHECK: "irdl_attr_test.op1"() {attr = #irdl_attr_test.attr1<42 : i32>} : () -> ()
+ m.dump()
``````````
</details>
https://github.com/llvm/llvm-project/pull/182820
More information about the Mlir-commits
mailing list