[Mlir-commits] [mlir] 9566ee2 - [MLIR][python bindings] TypeCasters for Attributes

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jun 7 10:58:05 PDT 2023


Author: max
Date: 2023-06-07T12:01:00-05:00
New Revision: 9566ee280607d91fa2e5eca730a6765ac84dfd0f

URL: https://github.com/llvm/llvm-project/commit/9566ee280607d91fa2e5eca730a6765ac84dfd0f
DIFF: https://github.com/llvm/llvm-project/commit/9566ee280607d91fa2e5eca730a6765ac84dfd0f.diff

LOG: [MLIR][python bindings] TypeCasters for Attributes

Differential Revision: https://reviews.llvm.org/D151840

Added: 
    

Modified: 
    mlir/include/mlir-c/BuiltinAttributes.h
    mlir/include/mlir-c/IR.h
    mlir/include/mlir/Bindings/Python/PybindAdaptors.h
    mlir/lib/Bindings/Python/IRAttributes.cpp
    mlir/lib/Bindings/Python/IRCore.cpp
    mlir/lib/Bindings/Python/IRModule.h
    mlir/lib/CAPI/IR/BuiltinAttributes.cpp
    mlir/lib/CAPI/IR/IR.cpp
    mlir/test/python/ir/attributes.py

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir-c/BuiltinAttributes.h b/mlir/include/mlir-c/BuiltinAttributes.h
index 2e62879396db2..b760dd0cdb9a5 100644
--- a/mlir/include/mlir-c/BuiltinAttributes.h
+++ b/mlir/include/mlir-c/BuiltinAttributes.h
@@ -45,6 +45,9 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirAffineMapAttrGet(MlirAffineMap map);
 /// Returns the affine map wrapped in the given affine map attribute.
 MLIR_CAPI_EXPORTED MlirAffineMap mlirAffineMapAttrGetValue(MlirAttribute attr);
 
+/// Returns the typeID of an AffineMap attribute.
+MLIR_CAPI_EXPORTED MlirTypeID mlirAffineMapAttrGetTypeID(void);
+
 //===----------------------------------------------------------------------===//
 // Array attribute.
 //===----------------------------------------------------------------------===//
@@ -64,6 +67,9 @@ MLIR_CAPI_EXPORTED intptr_t mlirArrayAttrGetNumElements(MlirAttribute attr);
 MLIR_CAPI_EXPORTED MlirAttribute mlirArrayAttrGetElement(MlirAttribute attr,
                                                          intptr_t pos);
 
+/// Returns the typeID of an Array attribute.
+MLIR_CAPI_EXPORTED MlirTypeID mlirArrayAttrGetTypeID(void);
+
 //===----------------------------------------------------------------------===//
 // Dictionary attribute.
 //===----------------------------------------------------------------------===//
@@ -89,6 +95,9 @@ mlirDictionaryAttrGetElement(MlirAttribute attr, intptr_t pos);
 MLIR_CAPI_EXPORTED MlirAttribute
 mlirDictionaryAttrGetElementByName(MlirAttribute attr, MlirStringRef name);
 
+/// Returns the typeID of a Dictionary attribute.
+MLIR_CAPI_EXPORTED MlirTypeID mlirDictionaryAttrGetTypeID(void);
+
 //===----------------------------------------------------------------------===//
 // Floating point attribute.
 //===----------------------------------------------------------------------===//
@@ -115,6 +124,9 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirFloatAttrDoubleGetChecked(MlirLocation loc,
 /// the value as double.
 MLIR_CAPI_EXPORTED double mlirFloatAttrGetValueDouble(MlirAttribute attr);
 
+/// Returns the typeID of a Float attribute.
+MLIR_CAPI_EXPORTED MlirTypeID mlirFloatAttrGetTypeID(void);
+
 //===----------------------------------------------------------------------===//
 // Integer attribute.
 //===----------------------------------------------------------------------===//
@@ -142,6 +154,9 @@ MLIR_CAPI_EXPORTED int64_t mlirIntegerAttrGetValueSInt(MlirAttribute attr);
 /// is of unsigned type and fits into an unsigned 64-bit integer.
 MLIR_CAPI_EXPORTED uint64_t mlirIntegerAttrGetValueUInt(MlirAttribute attr);
 
+/// Returns the typeID of an Integer attribute.
+MLIR_CAPI_EXPORTED MlirTypeID mlirIntegerAttrGetTypeID(void);
+
 //===----------------------------------------------------------------------===//
 // Bool attribute.
 //===----------------------------------------------------------------------===//
@@ -162,6 +177,9 @@ MLIR_CAPI_EXPORTED bool mlirBoolAttrGetValue(MlirAttribute attr);
 /// Checks whether the given attribute is an integer set attribute.
 MLIR_CAPI_EXPORTED bool mlirAttributeIsAIntegerSet(MlirAttribute attr);
 
+/// Returns the typeID of an IntegerSet attribute.
+MLIR_CAPI_EXPORTED MlirTypeID mlirIntegerSetAttrGetTypeID(void);
+
 //===----------------------------------------------------------------------===//
 // Opaque attribute.
 //===----------------------------------------------------------------------===//
@@ -185,6 +203,9 @@ mlirOpaqueAttrGetDialectNamespace(MlirAttribute attr);
 /// the context in which the attribute lives.
 MLIR_CAPI_EXPORTED MlirStringRef mlirOpaqueAttrGetData(MlirAttribute attr);
 
+/// Returns the typeID of an Opaque attribute.
+MLIR_CAPI_EXPORTED MlirTypeID mlirOpaqueAttrGetTypeID(void);
+
 //===----------------------------------------------------------------------===//
 // String attribute.
 //===----------------------------------------------------------------------===//
@@ -206,6 +227,9 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirStringAttrTypedGet(MlirType type,
 /// long as the context in which the attribute lives.
 MLIR_CAPI_EXPORTED MlirStringRef mlirStringAttrGetValue(MlirAttribute attr);
 
+/// Returns the typeID of a String attribute.
+MLIR_CAPI_EXPORTED MlirTypeID mlirStringAttrGetTypeID(void);
+
 //===----------------------------------------------------------------------===//
 // SymbolRef attribute.
 //===----------------------------------------------------------------------===//
@@ -239,6 +263,9 @@ mlirSymbolRefAttrGetNumNestedReferences(MlirAttribute attr);
 MLIR_CAPI_EXPORTED MlirAttribute
 mlirSymbolRefAttrGetNestedReference(MlirAttribute attr, intptr_t pos);
 
+/// Returns the typeID of a SymbolRef attribute.
+MLIR_CAPI_EXPORTED MlirTypeID mlirSymbolRefAttrGetTypeID(void);
+
 //===----------------------------------------------------------------------===//
 // Flat SymbolRef attribute.
 //===----------------------------------------------------------------------===//
@@ -256,6 +283,9 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirFlatSymbolRefAttrGet(MlirContext ctx,
 MLIR_CAPI_EXPORTED MlirStringRef
 mlirFlatSymbolRefAttrGetValue(MlirAttribute attr);
 
+/// Returns the typeID of a FlatSymbolRef attribute.
+MLIR_CAPI_EXPORTED MlirTypeID mlirFlatSymbolRefAttrGetTypeID(void);
+
 //===----------------------------------------------------------------------===//
 // Type attribute.
 //===----------------------------------------------------------------------===//
@@ -270,6 +300,9 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirTypeAttrGet(MlirType type);
 /// Returns the type stored in the given type attribute.
 MLIR_CAPI_EXPORTED MlirType mlirTypeAttrGetValue(MlirAttribute attr);
 
+/// Returns the typeID of a Type attribute.
+MLIR_CAPI_EXPORTED MlirTypeID mlirTypeAttrGetTypeID(void);
+
 //===----------------------------------------------------------------------===//
 // Unit attribute.
 //===----------------------------------------------------------------------===//
@@ -280,6 +313,9 @@ MLIR_CAPI_EXPORTED bool mlirAttributeIsAUnit(MlirAttribute attr);
 /// Creates a unit attribute in the given context.
 MLIR_CAPI_EXPORTED MlirAttribute mlirUnitAttrGet(MlirContext ctx);
 
+/// Returns the typeID of a Unit attribute.
+MLIR_CAPI_EXPORTED MlirTypeID mlirUnitAttrGetTypeID(void);
+
 //===----------------------------------------------------------------------===//
 // Elements attributes.
 //===----------------------------------------------------------------------===//
@@ -306,6 +342,8 @@ MLIR_CAPI_EXPORTED int64_t mlirElementsAttrGetNumElements(MlirAttribute attr);
 // Dense array attribute.
 //===----------------------------------------------------------------------===//
 
+/// Returns the typeID of a DenseArray attribute.
+MLIR_CAPI_EXPORTED MlirTypeID mlirDenseArrayAttrGetTypeID(void);
 /// Checks whether the given attribute is a dense array attribute.
 MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseBoolArray(MlirAttribute attr);
 MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseI8Array(MlirAttribute attr);
@@ -370,6 +408,9 @@ MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseElements(MlirAttribute attr);
 MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseIntElements(MlirAttribute attr);
 MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseFPElements(MlirAttribute attr);
 
+/// Returns the typeID of a DenseIntOrFPElements attribute.
+MLIR_CAPI_EXPORTED MlirTypeID mlirDenseIntOrFPElementsAttrGetTypeID(void);
+
 /// Creates a dense elements attribute with the given Shaped type and elements
 /// in the same context as the type.
 MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrGet(
@@ -612,6 +653,9 @@ mlirSparseElementsAttrGetIndices(MlirAttribute attr);
 MLIR_CAPI_EXPORTED MlirAttribute
 mlirSparseElementsAttrGetValues(MlirAttribute attr);
 
+/// Returns the typeID of a SparseElements attribute.
+MLIR_CAPI_EXPORTED MlirTypeID mlirSparseElementsAttrGetTypeID(void);
+
 //===----------------------------------------------------------------------===//
 // Strided layout attribute.
 //===----------------------------------------------------------------------===//
@@ -635,6 +679,9 @@ mlirStridedLayoutAttrGetNumStrides(MlirAttribute attr);
 MLIR_CAPI_EXPORTED int64_t mlirStridedLayoutAttrGetStride(MlirAttribute attr,
                                                           intptr_t pos);
 
+/// Returns the typeID of a StridedLayout attribute.
+MLIR_CAPI_EXPORTED MlirTypeID mlirStridedLayoutAttrGetTypeID(void);
+
 #ifdef __cplusplus
 }
 #endif

diff  --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 8253981b3cda2..6b5d8cc4b8c03 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -860,6 +860,9 @@ MLIR_CAPI_EXPORTED MlirType mlirAttributeGetType(MlirAttribute attribute);
 /// Gets the type id of the attribute.
 MLIR_CAPI_EXPORTED MlirTypeID mlirAttributeGetTypeID(MlirAttribute attribute);
 
+/// Gets the dialect of the attribute.
+MLIR_CAPI_EXPORTED MlirDialect mlirAttributeGetDialect(MlirAttribute attribute);
+
 /// Checks whether an attribute is null.
 static inline bool mlirAttributeIsNull(MlirAttribute attr) { return !attr.ptr; }
 

diff  --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
index 272067a261edb..44a10d619d029 100644
--- a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
+++ b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
@@ -97,6 +97,7 @@ struct type_caster<MlirAttribute> {
     return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
         .attr("Attribute")
         .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
+        .attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)()
         .release();
   }
 };

diff  --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index 3c7926e784dbb..99881b35c96d3 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -80,6 +80,8 @@ class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> {
   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap;
   static constexpr const char *pyClassName = "AffineMapAttr";
   using PyConcreteAttribute::PyConcreteAttribute;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirAffineMapAttrGetTypeID;
 
   static void bindDerived(ClassTy &c) {
     c.def_static(
@@ -259,6 +261,8 @@ class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray;
   static constexpr const char *pyClassName = "ArrayAttr";
   using PyConcreteAttribute::PyConcreteAttribute;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirArrayAttrGetTypeID;
 
   class PyArrayAttributeIterator {
   public:
@@ -339,6 +343,8 @@ class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> {
   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat;
   static constexpr const char *pyClassName = "FloatAttr";
   using PyConcreteAttribute::PyConcreteAttribute;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirFloatAttrGetTypeID;
 
   static void bindDerived(ClassTy &c) {
     c.def_static(
@@ -406,6 +412,10 @@ class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> {
           return mlirIntegerAttrGetValueUInt(self);
         },
         "Returns the value of the integer attribute");
+    c.def_property_readonly_static("static_typeid",
+                                   [](py::object & /*class*/) -> MlirTypeID {
+                                     return mlirIntegerAttrGetTypeID();
+                                   });
   }
 };
 
@@ -438,6 +448,8 @@ class PyFlatSymbolRefAttribute
   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef;
   static constexpr const char *pyClassName = "FlatSymbolRefAttr";
   using PyConcreteAttribute::PyConcreteAttribute;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirFlatSymbolRefAttrGetTypeID;
 
   static void bindDerived(ClassTy &c) {
     c.def_static(
@@ -464,6 +476,8 @@ class PyOpaqueAttribute : public PyConcreteAttribute<PyOpaqueAttribute> {
   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAOpaque;
   static constexpr const char *pyClassName = "OpaqueAttr";
   using PyConcreteAttribute::PyConcreteAttribute;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirOpaqueAttrGetTypeID;
 
   static void bindDerived(ClassTy &c) {
     c.def_static(
@@ -501,6 +515,8 @@ class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> {
   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString;
   static constexpr const char *pyClassName = "StringAttr";
   using PyConcreteAttribute::PyConcreteAttribute;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirStringAttrGetTypeID;
 
   static void bindDerived(ClassTy &c) {
     c.def_static(
@@ -921,6 +937,8 @@ class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> {
   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary;
   static constexpr const char *pyClassName = "DictAttr";
   using PyConcreteAttribute::PyConcreteAttribute;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirDictionaryAttrGetTypeID;
 
   intptr_t dunderLen() { return mlirDictionaryAttrGetNumElements(*this); }
 
@@ -1013,6 +1031,8 @@ class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> {
   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType;
   static constexpr const char *pyClassName = "TypeAttr";
   using PyConcreteAttribute::PyConcreteAttribute;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirTypeAttrGetTypeID;
 
   static void bindDerived(ClassTy &c) {
     c.def_static(
@@ -1035,6 +1055,8 @@ class PyUnitAttribute : public PyConcreteAttribute<PyUnitAttribute> {
   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit;
   static constexpr const char *pyClassName = "UnitAttr";
   using PyConcreteAttribute::PyConcreteAttribute;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirUnitAttrGetTypeID;
 
   static void bindDerived(ClassTy &c) {
     c.def_static(
@@ -1054,6 +1076,8 @@ class PyStridedLayoutAttribute
   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAStridedLayout;
   static constexpr const char *pyClassName = "StridedLayoutAttr";
   using PyConcreteAttribute::PyConcreteAttribute;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirStridedLayoutAttrGetTypeID;
 
   static void bindDerived(ClassTy &c) {
     c.def_static(
@@ -1099,6 +1123,50 @@ class PyStridedLayoutAttribute
   }
 };
 
+py::object denseArrayAttributeCaster(PyAttribute &pyAttribute) {
+  if (PyDenseBoolArrayAttribute::isaFunction(pyAttribute))
+    return py::cast(PyDenseBoolArrayAttribute(pyAttribute));
+  if (PyDenseI8ArrayAttribute::isaFunction(pyAttribute))
+    return py::cast(PyDenseI8ArrayAttribute(pyAttribute));
+  if (PyDenseI16ArrayAttribute::isaFunction(pyAttribute))
+    return py::cast(PyDenseI16ArrayAttribute(pyAttribute));
+  if (PyDenseI32ArrayAttribute::isaFunction(pyAttribute))
+    return py::cast(PyDenseI32ArrayAttribute(pyAttribute));
+  if (PyDenseI64ArrayAttribute::isaFunction(pyAttribute))
+    return py::cast(PyDenseI64ArrayAttribute(pyAttribute));
+  if (PyDenseF32ArrayAttribute::isaFunction(pyAttribute))
+    return py::cast(PyDenseF32ArrayAttribute(pyAttribute));
+  if (PyDenseF64ArrayAttribute::isaFunction(pyAttribute))
+    return py::cast(PyDenseF64ArrayAttribute(pyAttribute));
+  std::string msg =
+      std::string("Can't cast unknown element type DenseArrayAttr (") +
+      std::string(py::repr(py::cast(pyAttribute))) + ")";
+  throw py::cast_error(msg);
+}
+
+py::object denseIntOrFPElementsAttributeCaster(PyAttribute &pyAttribute) {
+  if (PyDenseFPElementsAttribute::isaFunction(pyAttribute))
+    return py::cast(PyDenseFPElementsAttribute(pyAttribute));
+  if (PyDenseIntElementsAttribute::isaFunction(pyAttribute))
+    return py::cast(PyDenseIntElementsAttribute(pyAttribute));
+  std::string msg =
+      std::string(
+          "Can't cast unknown element type DenseIntOrFPElementsAttr (") +
+      std::string(py::repr(py::cast(pyAttribute))) + ")";
+  throw py::cast_error(msg);
+}
+
+py::object integerOrBoolAttributeCaster(PyAttribute &pyAttribute) {
+  // Dispatched for IntegerAttr's TypeID: try BoolAttr, then IntegerAttr.
+  if (PyBoolAttribute::isaFunction(pyAttribute))
+    return py::cast(PyBoolAttribute(pyAttribute));
+  if (PyIntegerAttribute::isaFunction(pyAttribute))
+    return py::cast(PyIntegerAttribute(pyAttribute));
+  throw py::cast_error(
+      std::string("Can't cast unknown IntegerAttr (") +
+      std::string(py::repr(py::cast(pyAttribute))) + ")");
+}
+
 } // namespace
 
 void mlir::python::populateIRAttributes(py::module &m) {
@@ -1118,6 +1186,9 @@ void mlir::python::populateIRAttributes(py::module &m) {
   PyDenseF32ArrayAttribute::PyDenseArrayIterator::bind(m);
   PyDenseF64ArrayAttribute::bind(m);
   PyDenseF64ArrayAttribute::PyDenseArrayIterator::bind(m);
+  PyGlobals::get().registerTypeCaster(
+      mlirDenseArrayAttrGetTypeID(),
+      pybind11::cpp_function(denseArrayAttributeCaster));
 
   PyArrayAttribute::bind(m);
   PyArrayAttribute::PyArrayAttributeIterator::bind(m);
@@ -1125,6 +1196,10 @@ void mlir::python::populateIRAttributes(py::module &m) {
   PyDenseElementsAttribute::bind(m);
   PyDenseFPElementsAttribute::bind(m);
   PyDenseIntElementsAttribute::bind(m);
+  PyGlobals::get().registerTypeCaster(
+      mlirDenseIntOrFPElementsAttrGetTypeID(),
+      pybind11::cpp_function(denseIntOrFPElementsAttributeCaster));
+
   PyDictAttribute::bind(m);
   PyFlatSymbolRefAttribute::bind(m);
   PyOpaqueAttribute::bind(m);
@@ -1132,6 +1207,9 @@ void mlir::python::populateIRAttributes(py::module &m) {
   PyIntegerAttribute::bind(m);
   PyStringAttribute::bind(m);
   PyTypeAttribute::bind(m);
+  PyGlobals::get().registerTypeCaster(
+      mlirIntegerAttrGetTypeID(),
+      pybind11::cpp_function(integerOrBoolAttributeCaster));
   PyUnitAttribute::bind(m);
 
   PyStridedLayoutAttribute::bind(m);

diff  --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index ec9066aa10cee..facd33c727351 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -2640,10 +2640,7 @@ void mlir::python::populateIRCore(py::module &m) {
           "Context that owns the Location")
       .def_property_readonly(
           "attr",
-          [](PyLocation &self) {
-            return PyAttribute(self.getContext(),
-                               mlirLocationGetAttribute(self));
-          },
+          [](PyLocation &self) { return mlirLocationGetAttribute(self); },
           "Get the underlying LocationAttr")
       .def(
           "emit_error",
@@ -3139,7 +3136,7 @@ void mlir::python::populateIRCore(py::module &m) {
                 context->get(), toMlirStringRef(attrSpec));
             if (mlirAttributeIsNull(type))
               throw MLIRError("Unable to parse attribute", errors.take());
-            return PyAttribute(context->getRef(), type);
+            return type;
           },
           py::arg("asm"), py::arg("context") = py::none(),
           "Parses an attribute from an assembly form. Raises an MLIRError on "
@@ -3175,18 +3172,38 @@ void mlir::python::populateIRCore(py::module &m) {
             return printAccum.join();
           },
           "Returns the assembly form of the Attribute.")
-      .def("__repr__", [](PyAttribute &self) {
-        // Generally, assembly formats are not printed for __repr__ because
-        // this can cause exceptionally long debug output and exceptions.
-        // However, attribute values are generally considered useful and are
-        // printed. This may need to be re-evaluated if debug dumps end up
-        // being excessive.
-        PyPrintAccumulator printAccum;
-        printAccum.parts.append("Attribute(");
-        mlirAttributePrint(self, printAccum.getCallback(),
-                           printAccum.getUserData());
-        printAccum.parts.append(")");
-        return printAccum.join();
+      .def("__repr__",
+           [](PyAttribute &self) {
+             // Generally, assembly formats are not printed for __repr__ because
+             // this can cause exceptionally long debug output and exceptions.
+             // However, attribute values are generally considered useful and
+             // are printed. This may need to be re-evaluated if debug dumps end
+             // up being excessive.
+             PyPrintAccumulator printAccum;
+             printAccum.parts.append("Attribute(");
+             mlirAttributePrint(self, printAccum.getCallback(),
+                                printAccum.getUserData());
+             printAccum.parts.append(")");
+             return printAccum.join();
+           })
+      .def_property_readonly(
+          "typeid",
+          [](PyAttribute &self) -> MlirTypeID {
+            MlirTypeID mlirTypeID = mlirAttributeGetTypeID(self);
+            assert(!mlirTypeIDIsNull(mlirTypeID) &&
+                   "mlirTypeID was expected to be non-null.");
+            return mlirTypeID;
+          })
+      .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, [](PyAttribute &self) {
+        MlirTypeID mlirTypeID = mlirAttributeGetTypeID(self);
+        assert(!mlirTypeIDIsNull(mlirTypeID) &&
+               "mlirTypeID was expected to be non-null.");
+        std::optional<pybind11::function> typeCaster =
+            PyGlobals::get().lookupTypeCaster(mlirTypeID,
+                                              mlirAttributeGetDialect(self));
+        if (!typeCaster)
+          return py::cast(self);
+        return typeCaster.value()(self);
       });
 
   //----------------------------------------------------------------------------
@@ -3216,13 +3233,7 @@ void mlir::python::populateIRCore(py::module &m) {
           "The name of the NamedAttribute binding")
       .def_property_readonly(
           "attr",
-          [](PyNamedAttribute &self) {
-            // TODO: When named attribute is removed/refactored, also remove
-            // this constructor (it does an inefficient table lookup).
-            auto contextRef = PyMlirContext::forContext(
-                mlirAttributeGetContext(self.namedAttr.attribute));
-            return PyAttribute(std::move(contextRef), self.namedAttr.attribute);
-          },
+          [](PyNamedAttribute &self) { return self.namedAttr.attribute; },
           py::keep_alive<0, 1>(),
           "The underlying generic attribute of the NamedAttribute binding");
 

diff  --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index 013bb7b9256f4..225580f0f4575 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -986,6 +986,8 @@ class PyConcreteAttribute : public BaseTy {
   //   const char *pyClassName
   using ClassTy = pybind11::class_<DerivedTy, BaseTy>;
   using IsAFunctionTy = bool (*)(MlirAttribute);
+  using GetTypeIDFunctionTy = MlirTypeID (*)();
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction = nullptr;
 
   PyConcreteAttribute() = default;
   PyConcreteAttribute(PyMlirContextRef contextRef, MlirAttribute attr)
@@ -1017,6 +1019,34 @@ class PyConcreteAttribute : public BaseTy {
         pybind11::arg("other"));
     cls.def_property_readonly(
         "type", [](PyAttribute &attr) { return mlirAttributeGetType(attr); });
+    cls.def_property_readonly_static(
+        "static_typeid", [](py::object & /*class*/) -> MlirTypeID {
+          if (DerivedTy::getTypeIdFunction)
+            return DerivedTy::getTypeIdFunction();
+          throw py::attribute_error(
+              (DerivedTy::pyClassName + llvm::Twine(" has no typeid.")).str());
+        });
+    cls.def_property_readonly("typeid", [](PyAttribute &self) {
+      return py::cast(self).attr("typeid").cast<MlirTypeID>();
+    });
+    cls.def("__repr__", [](DerivedTy &self) {
+      PyPrintAccumulator printAccum;
+      printAccum.parts.append(DerivedTy::pyClassName);
+      printAccum.parts.append("(");
+      mlirAttributePrint(self, printAccum.getCallback(),
+                         printAccum.getUserData());
+      printAccum.parts.append(")");
+      return printAccum.join();
+    });
+
+    if (DerivedTy::getTypeIdFunction) {
+      PyGlobals::get().registerTypeCaster(
+          DerivedTy::getTypeIdFunction(),
+          pybind11::cpp_function([](PyAttribute pyAttribute) -> DerivedTy {
+            return pyAttribute;
+          }));
+    }
+
     DerivedTy::bindDerived(cls);
   }
 

diff  --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
index f2441e0b0ae9b..289913d4f5480 100644
--- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
@@ -44,6 +44,10 @@ MlirAffineMap mlirAffineMapAttrGetValue(MlirAttribute attr) {
   return wrap(llvm::cast<AffineMapAttr>(unwrap(attr)).getValue());
 }
 
+MlirTypeID mlirAffineMapAttrGetTypeID(void) {
+  return wrap(AffineMapAttr::getTypeID());
+}
+
 //===----------------------------------------------------------------------===//
 // Array attribute.
 //===----------------------------------------------------------------------===//
@@ -68,6 +72,8 @@ MlirAttribute mlirArrayAttrGetElement(MlirAttribute attr, intptr_t pos) {
   return wrap(llvm::cast<ArrayAttr>(unwrap(attr)).getValue()[pos]);
 }
 
+MlirTypeID mlirArrayAttrGetTypeID(void) { return wrap(ArrayAttr::getTypeID()); }
+
 //===----------------------------------------------------------------------===//
 // Dictionary attribute.
 //===----------------------------------------------------------------------===//
@@ -102,6 +108,10 @@ MlirAttribute mlirDictionaryAttrGetElementByName(MlirAttribute attr,
   return wrap(llvm::cast<DictionaryAttr>(unwrap(attr)).get(unwrap(name)));
 }
 
+MlirTypeID mlirDictionaryAttrGetTypeID(void) {
+  return wrap(DictionaryAttr::getTypeID());
+}
+
 //===----------------------------------------------------------------------===//
 // Floating point attribute.
 //===----------------------------------------------------------------------===//
@@ -124,6 +134,8 @@ double mlirFloatAttrGetValueDouble(MlirAttribute attr) {
   return llvm::cast<FloatAttr>(unwrap(attr)).getValueAsDouble();
 }
 
+MlirTypeID mlirFloatAttrGetTypeID(void) { return wrap(FloatAttr::getTypeID()); }
+
 //===----------------------------------------------------------------------===//
 // Integer attribute.
 //===----------------------------------------------------------------------===//
@@ -148,6 +160,10 @@ uint64_t mlirIntegerAttrGetValueUInt(MlirAttribute attr) {
   return llvm::cast<IntegerAttr>(unwrap(attr)).getUInt();
 }
 
+MlirTypeID mlirIntegerAttrGetTypeID(void) {
+  return wrap(IntegerAttr::getTypeID());
+}
+
 //===----------------------------------------------------------------------===//
 // Bool attribute.
 //===----------------------------------------------------------------------===//
@@ -172,6 +188,10 @@ bool mlirAttributeIsAIntegerSet(MlirAttribute attr) {
   return llvm::isa<IntegerSetAttr>(unwrap(attr));
 }
 
+MlirTypeID mlirIntegerSetAttrGetTypeID(void) {
+  return wrap(IntegerSetAttr::getTypeID());
+}
+
 //===----------------------------------------------------------------------===//
 // Opaque attribute.
 //===----------------------------------------------------------------------===//
@@ -197,6 +217,10 @@ MlirStringRef mlirOpaqueAttrGetData(MlirAttribute attr) {
   return wrap(llvm::cast<OpaqueAttr>(unwrap(attr)).getAttrData());
 }
 
+MlirTypeID mlirOpaqueAttrGetTypeID(void) {
+  return wrap(OpaqueAttr::getTypeID());
+}
+
 //===----------------------------------------------------------------------===//
 // String attribute.
 //===----------------------------------------------------------------------===//
@@ -217,6 +241,10 @@ MlirStringRef mlirStringAttrGetValue(MlirAttribute attr) {
   return wrap(llvm::cast<StringAttr>(unwrap(attr)).getValue());
 }
 
+MlirTypeID mlirStringAttrGetTypeID(void) {
+  return wrap(StringAttr::getTypeID());
+}
+
 //===----------------------------------------------------------------------===//
 // SymbolRef attribute.
 //===----------------------------------------------------------------------===//
@@ -257,6 +285,10 @@ MlirAttribute mlirSymbolRefAttrGetNestedReference(MlirAttribute attr,
       llvm::cast<SymbolRefAttr>(unwrap(attr)).getNestedReferences()[pos]);
 }
 
+MlirTypeID mlirSymbolRefAttrGetTypeID(void) {
+  return wrap(SymbolRefAttr::getTypeID());
+}
+
 //===----------------------------------------------------------------------===//
 // Flat SymbolRef attribute.
 //===----------------------------------------------------------------------===//
@@ -273,6 +305,10 @@ MlirStringRef mlirFlatSymbolRefAttrGetValue(MlirAttribute attr) {
   return wrap(llvm::cast<FlatSymbolRefAttr>(unwrap(attr)).getValue());
 }
 
+MlirTypeID mlirFlatSymbolRefAttrGetTypeID(void) {
+  return wrap(FlatSymbolRefAttr::getTypeID());
+}
+
 //===----------------------------------------------------------------------===//
 // Type attribute.
 //===----------------------------------------------------------------------===//
@@ -289,6 +325,8 @@ MlirType mlirTypeAttrGetValue(MlirAttribute attr) {
   return wrap(llvm::cast<TypeAttr>(unwrap(attr)).getValue());
 }
 
+MlirTypeID mlirTypeAttrGetTypeID(void) { return wrap(TypeAttr::getTypeID()); }
+
 //===----------------------------------------------------------------------===//
 // Unit attribute.
 //===----------------------------------------------------------------------===//
@@ -301,6 +339,8 @@ MlirAttribute mlirUnitAttrGet(MlirContext ctx) {
   return wrap(UnitAttr::get(unwrap(ctx)));
 }
 
+MlirTypeID mlirUnitAttrGetTypeID(void) { return wrap(UnitAttr::getTypeID()); }
+
 //===----------------------------------------------------------------------===//
 // Elements attributes.
 //===----------------------------------------------------------------------===//
@@ -329,8 +369,13 @@ int64_t mlirElementsAttrGetNumElements(MlirAttribute attr) {
 // Dense array attribute.
 //===----------------------------------------------------------------------===//
 
+MlirTypeID mlirDenseArrayAttrGetTypeID() {
+  return wrap(DenseArrayAttr::getTypeID());
+}
+
 //===----------------------------------------------------------------------===//
 // IsA support.
+//===----------------------------------------------------------------------===//
 
 bool mlirAttributeIsADenseBoolArray(MlirAttribute attr) {
   return llvm::isa<DenseBoolArrayAttr>(unwrap(attr));
@@ -356,6 +401,7 @@ bool mlirAttributeIsADenseF64Array(MlirAttribute attr) {
 
 //===----------------------------------------------------------------------===//
 // Constructors.
+//===----------------------------------------------------------------------===//
 
 MlirAttribute mlirDenseBoolArrayGet(MlirContext ctx, intptr_t size,
                                     int const *values) {
@@ -395,6 +441,7 @@ MlirAttribute mlirDenseF64ArrayGet(MlirContext ctx, intptr_t size,
 
 //===----------------------------------------------------------------------===//
 // Accessors.
+//===----------------------------------------------------------------------===//
 
 intptr_t mlirDenseArrayGetNumElements(MlirAttribute attr) {
   return llvm::cast<DenseArrayAttr>(unwrap(attr)).size();
@@ -402,6 +449,7 @@ intptr_t mlirDenseArrayGetNumElements(MlirAttribute attr) {
 
 //===----------------------------------------------------------------------===//
 // Indexed accessors.
+//===----------------------------------------------------------------------===//
 
 bool mlirDenseBoolArrayGetElement(MlirAttribute attr, intptr_t pos) {
   return llvm::cast<DenseBoolArrayAttr>(unwrap(attr))[pos];
@@ -431,19 +479,27 @@ double mlirDenseF64ArrayGetElement(MlirAttribute attr, intptr_t pos) {
 
 //===----------------------------------------------------------------------===//
 // IsA support.
+//===----------------------------------------------------------------------===//
 
 bool mlirAttributeIsADenseElements(MlirAttribute attr) {
   return llvm::isa<DenseElementsAttr>(unwrap(attr));
 }
+
 bool mlirAttributeIsADenseIntElements(MlirAttribute attr) {
   return llvm::isa<DenseIntElementsAttr>(unwrap(attr));
 }
+
 bool mlirAttributeIsADenseFPElements(MlirAttribute attr) {
   return llvm::isa<DenseFPElementsAttr>(unwrap(attr));
 }
 
+MlirTypeID mlirDenseIntOrFPElementsAttrGetTypeID(void) {
+  return wrap(DenseIntOrFPElementsAttr::getTypeID());
+}
+
 //===----------------------------------------------------------------------===//
 // Constructors.
+//===----------------------------------------------------------------------===//
 
 MlirAttribute mlirDenseElementsAttrGet(MlirType shapedType,
                                        intptr_t numElements,
@@ -620,6 +676,7 @@ MlirAttribute mlirDenseElementsAttrReshapeGet(MlirAttribute attr,
 
 //===----------------------------------------------------------------------===//
 // Splat accessors.
+//===----------------------------------------------------------------------===//
 
 bool mlirDenseElementsAttrIsSplat(MlirAttribute attr) {
   return llvm::cast<DenseElementsAttr>(unwrap(attr)).isSplat();
@@ -663,6 +720,7 @@ MlirStringRef mlirDenseElementsAttrGetStringSplatValue(MlirAttribute attr) {
 
 //===----------------------------------------------------------------------===//
 // Indexed accessors.
+//===----------------------------------------------------------------------===//
 
 bool mlirDenseElementsAttrGetBoolValue(MlirAttribute attr, intptr_t pos) {
   return llvm::cast<DenseElementsAttr>(unwrap(attr)).getValues<bool>()[pos];
@@ -705,6 +763,7 @@ MlirStringRef mlirDenseElementsAttrGetStringValue(MlirAttribute attr,
 
 //===----------------------------------------------------------------------===//
 // Raw data accessors.
+//===----------------------------------------------------------------------===//
 
 const void *mlirDenseElementsAttrGetRawData(MlirAttribute attr) {
   return static_cast<const void *>(
@@ -876,6 +935,10 @@ MlirAttribute mlirSparseElementsAttrGetValues(MlirAttribute attr) {
   return wrap(llvm::cast<SparseElementsAttr>(unwrap(attr)).getValues());
 }
 
+MlirTypeID mlirSparseElementsAttrGetTypeID(void) {
+  return wrap(SparseElementsAttr::getTypeID());
+}
+
 //===----------------------------------------------------------------------===//
 // Strided layout attribute.
 //===----------------------------------------------------------------------===//
@@ -903,3 +966,7 @@ intptr_t mlirStridedLayoutAttrGetNumStrides(MlirAttribute attr) {
 int64_t mlirStridedLayoutAttrGetStride(MlirAttribute attr, intptr_t pos) {
   return llvm::cast<StridedLayoutAttr>(unwrap(attr)).getStrides()[pos];
 }
+
+MlirTypeID mlirStridedLayoutAttrGetTypeID(void) {
+  return wrap(StridedLayoutAttr::getTypeID());
+}

diff  --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index 373e01a1362c7..16b333afc102d 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -870,6 +870,10 @@ MlirTypeID mlirAttributeGetTypeID(MlirAttribute attr) {
   return wrap(unwrap(attr).getTypeID());
 }
 
+MlirDialect mlirAttributeGetDialect(MlirAttribute attr) {
+  return wrap(&unwrap(attr).getDialect());
+}
+
 bool mlirAttributeEqual(MlirAttribute a1, MlirAttribute a2) {
   return unwrap(a1) == unwrap(a2);
 }

diff  --git a/mlir/test/python/ir/attributes.py b/mlir/test/python/ir/attributes.py
index 3ee0691d606f8..221c186ae7d52 100644
--- a/mlir/test/python/ir/attributes.py
+++ b/mlir/test/python/ir/attributes.py
@@ -23,7 +23,7 @@ def testParsePrint():
     gc.collect()
     # CHECK: "hello"
     print(str(t))
-    # CHECK: Attribute("hello")
+    # CHECK: StringAttr("hello")
     print(repr(t))
 
 
@@ -134,7 +134,7 @@ def testStandardAttrCasts():
         a1 = Attribute.parse('"attr1"')
         astr = StringAttr(a1)
         aself = StringAttr(astr)
-        # CHECK: Attribute("attr1")
+        # CHECK: StringAttr("attr1")
         print(repr(astr))
         try:
             tillegal = StringAttr(Attribute.parse("1.0"))
@@ -324,32 +324,32 @@ def testDenseIntAttr():
 
 @run
 def testDenseArrayGetItem():
-    def print_item(AttrClass, attr_asm):
-        attr = AttrClass(Attribute.parse(attr_asm))
+    def print_item(attr_asm):
+        attr = Attribute.parse(attr_asm)
         print(f"{len(attr)}: {attr[0]}, {attr[1]}")
 
     with Context():
         # CHECK: 2: 0, 1
-        print_item(DenseBoolArrayAttr, "array<i1: false, true>")
+        print_item("array<i1: false, true>")
         # CHECK: 2: 2, 3
-        print_item(DenseI8ArrayAttr, "array<i8: 2, 3>")
+        print_item("array<i8: 2, 3>")
         # CHECK: 2: 4, 5
-        print_item(DenseI16ArrayAttr, "array<i16: 4, 5>")
+        print_item("array<i16: 4, 5>")
         # CHECK: 2: 6, 7
-        print_item(DenseI32ArrayAttr, "array<i32: 6, 7>")
+        print_item("array<i32: 6, 7>")
         # CHECK: 2: 8, 9
-        print_item(DenseI64ArrayAttr, "array<i64: 8, 9>")
+        print_item("array<i64: 8, 9>")
         # CHECK: 2: 1.{{0+}}, 2.{{0+}}
-        print_item(DenseF32ArrayAttr, "array<f32: 1.0, 2.0>")
+        print_item("array<f32: 1.0, 2.0>")
         # CHECK: 2: 3.{{0+}}, 4.{{0+}}
-        print_item(DenseF64ArrayAttr, "array<f64: 3.0, 4.0>")
+        print_item("array<f64: 3.0, 4.0>")
 
 
 # CHECK-LABEL: TEST: testDenseIntAttrGetItem
 @run
 def testDenseIntAttrGetItem():
     def print_item(attr_asm):
-        attr = DenseIntElementsAttr(Attribute.parse(attr_asm))
+        attr = Attribute.parse(attr_asm)
         dtype = ShapedType(attr.type).element_type
         try:
             item = attr[0]
@@ -592,3 +592,14 @@ def print_container_item(attr_asm):
         print(repr(type_attr.value))
         # CHECK: F32Type(f32)
         print(repr(type_attr.value.element_type))
+
+
+# CHECK-LABEL: TEST: testConcreteAttributesRoundTrip
+@run
+def testConcreteAttributesRoundTrip():
+    with Context(), Location.unknown():
+
+        # CHECK: FloatAttr(4.200000e+01 : f32)
+        print(repr(Attribute.parse("42.0 : f32")))
+
+        assert IntegerAttr.static_typeid is not None


        


More information about the Mlir-commits mailing list