[Mlir-commits] [mlir] bfb1ba7 - [MLIR][python bindings] Add TypeCaster for returning refined types from python APIs
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri May 26 09:02:13 PDT 2023
Author: max
Date: 2023-05-26T11:02:05-05:00
New Revision: bfb1ba752655bf09b35c486f6cc9817dbedfb1bb
URL: https://github.com/llvm/llvm-project/commit/bfb1ba752655bf09b35c486f6cc9817dbedfb1bb
DIFF: https://github.com/llvm/llvm-project/commit/bfb1ba752655bf09b35c486f6cc9817dbedfb1bb.diff
LOG: [MLIR][python bindings] Add TypeCaster for returning refined types from python APIs
depends on D150839
This diff uses `MlirTypeID` to register `TypeCaster`s (i.e., `[](PyType pyType) -> DerivedTy { return pyType; }`) for all concrete types (i.e., `PyConcrete<...>`) that are then queried for (by `MlirTypeID`) and called in `struct type_caster<MlirType>::cast`. The result is that anywhere an `MlirType mlirType` is returned from a python binding, that `mlirType` is automatically cast to the correct concrete type. For example:
```
c0 = arith.ConstantOp(f32, 0.0)
# CHECK: F32Type(f32)
print(repr(c0.result.type))
unranked_tensor_type = UnrankedTensorType.get(f32)
unranked_tensor = tensor.FromElementsOp(unranked_tensor_type, [c0]).result
# CHECK: UnrankedTensorType
print(type(unranked_tensor.type).__name__)
# CHECK: UnrankedTensorType(tensor<*xf32>)
print(repr(unranked_tensor.type))
```
This functionality immediately extends to typed attributes (i.e., `attr.type`).
The diff also implements similar functionality for `mlir_type_subclass`es but in a slightly different way - for such types (which have no corresponding C++ `class` or `struct`) the user must provide a type caster in Python (similar to how `AttrBuilder` works) or in C++ as a `py::cpp_function`.
Reviewed By: ftynse
Differential Revision: https://reviews.llvm.org/D150927
Added:
Modified:
mlir/include/mlir-c/Bindings/Python/Interop.h
mlir/include/mlir-c/Dialect/Transform.h
mlir/include/mlir-c/IR.h
mlir/include/mlir/Bindings/Python/PybindAdaptors.h
mlir/include/mlir/CAPI/Support.h
mlir/lib/Bindings/Python/DialectTransform.cpp
mlir/lib/Bindings/Python/Globals.h
mlir/lib/Bindings/Python/IRAttributes.cpp
mlir/lib/Bindings/Python/IRCore.cpp
mlir/lib/Bindings/Python/IRInterfaces.cpp
mlir/lib/Bindings/Python/IRModule.cpp
mlir/lib/Bindings/Python/IRModule.h
mlir/lib/Bindings/Python/IRTypes.cpp
mlir/lib/Bindings/Python/MainModule.cpp
mlir/lib/CAPI/Dialect/Transform.cpp
mlir/lib/CAPI/IR/BuiltinTypes.cpp
mlir/lib/CAPI/IR/IR.cpp
mlir/lib/CAPI/IR/Support.cpp
mlir/python/mlir/dialects/python_test.py
mlir/python/mlir/ir.py
mlir/test/python/dialects/python_test.py
mlir/test/python/ir/attributes.py
mlir/test/python/ir/builtin_types.py
mlir/test/python/lib/PythonTestCAPI.cpp
mlir/test/python/lib/PythonTestCAPI.h
mlir/test/python/lib/PythonTestModule.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir-c/Bindings/Python/Interop.h b/mlir/include/mlir-c/Bindings/Python/Interop.h
index 6ebb458082d7c..33332d6a34bb8 100644
--- a/mlir/include/mlir-c/Bindings/Python/Interop.h
+++ b/mlir/include/mlir-c/Bindings/Python/Interop.h
@@ -107,6 +107,23 @@
* delineated). */
#define MLIR_PYTHON_CAPI_FACTORY_ATTR "_CAPICreate"
+/** Attribute on MLIR Python objects that expose a function for downcasting the
+ * corresponding Python object to a subclass if the object is in fact a subclass
+ * (Concrete or mlir_type_subclass) of ir.Type. The signature of the function
+ * is: def maybe_downcast(self) -> object where the resulting object will
+ * (possibly) be an instance of the subclass.
+ */
+#define MLIR_PYTHON_MAYBE_DOWNCAST_ATTR "maybe_downcast"
+
+/** Attribute on main C extension module (_mlir) that corresponds to the
+ * type caster registration binding. The signature of the function is:
+ * def register_type_caster(MlirTypeID mlirTypeID, py::function typeCaster,
+ * bool replace)
+ * where replace indicates the typeCaster should replace any existing registered
+ * type casters (such as those for upstream ConcreteTypes).
+ */
+#define MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR "register_type_caster"
+
/// Gets a void* from a wrapped struct. Needed because const cast is different
/// between C/C++.
#ifdef __cplusplus
diff --git a/mlir/include/mlir-c/Dialect/Transform.h b/mlir/include/mlir-c/Dialect/Transform.h
index 864dffa3f7824..0409890b21406 100644
--- a/mlir/include/mlir-c/Dialect/Transform.h
+++ b/mlir/include/mlir-c/Dialect/Transform.h
@@ -33,6 +33,8 @@ MLIR_CAPI_EXPORTED MlirType mlirTransformAnyOpTypeGet(MlirContext ctx);
MLIR_CAPI_EXPORTED bool mlirTypeIsATransformOperationType(MlirType type);
+MLIR_CAPI_EXPORTED MlirTypeID mlirTransformOperationTypeGetTypeID(void);
+
MLIR_CAPI_EXPORTED MlirType
mlirTransformOperationTypeGet(MlirContext ctx, MlirStringRef operationName);
diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 13a3cb0130cea..8253981b3cda2 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -825,6 +825,9 @@ MLIR_CAPI_EXPORTED MlirContext mlirTypeGetContext(MlirType type);
/// Gets the type ID of the type.
MLIR_CAPI_EXPORTED MlirTypeID mlirTypeGetTypeID(MlirType type);
+/// Gets the dialect a type belongs to.
+MLIR_CAPI_EXPORTED MlirDialect mlirTypeGetDialect(MlirType type);
+
/// Checks whether a type is null.
static inline bool mlirTypeIsNull(MlirType type) { return !type.ptr; }
diff --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
index ccca3aa0172e6..272067a261edb 100644
--- a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
+++ b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
@@ -28,6 +28,7 @@
#include "llvm/ADT/Twine.h"
namespace py = pybind11;
+using namespace py::literals;
// Raw CAPI type casters need to be declared before use, so always include them
// first.
@@ -272,6 +273,7 @@ struct type_caster<MlirType> {
return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
.attr("Type")
.attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
+ .attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)()
.release();
}
};
@@ -424,20 +426,24 @@ class mlir_attribute_subclass : public pure_subclass {
class mlir_type_subclass : public pure_subclass {
public:
using IsAFunctionTy = bool (*)(MlirType);
+ using GetTypeIDFunctionTy = MlirTypeID (*)();
/// Subclasses by looking up the super-class dynamically.
mlir_type_subclass(py::handle scope, const char *typeClassName,
- IsAFunctionTy isaFunction)
+ IsAFunctionTy isaFunction,
+ GetTypeIDFunctionTy getTypeIDFunction = nullptr)
: mlir_type_subclass(
scope, typeClassName, isaFunction,
- py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")).attr("Type")) {}
+ py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")).attr("Type"),
+ getTypeIDFunction) {}
/// Subclasses with a provided mlir.ir.Type super-class. This must
/// be used if the subclass is being defined in the same extension module
/// as the mlir.ir class (otherwise, it will trigger a recursive
/// initialization).
mlir_type_subclass(py::handle scope, const char *typeClassName,
- IsAFunctionTy isaFunction, const py::object &superCls)
+ IsAFunctionTy isaFunction, const py::object &superCls,
+ GetTypeIDFunctionTy getTypeIDFunction = nullptr)
: pure_subclass(scope, typeClassName, superCls) {
// Casting constructor. Note that it hard, if not impossible, to properly
// call chain to parent `__init__` in pybind11 due to its special handling
@@ -471,6 +477,19 @@ class mlir_type_subclass : public pure_subclass {
"isinstance",
[isaFunction](MlirType other) { return isaFunction(other); },
py::arg("other_type"));
+ def("__repr__", [superCls, captureTypeName](py::object self) {
+ return py::repr(superCls(self))
+ .attr("replace")(superCls.attr("__name__"), captureTypeName);
+ });
+ if (getTypeIDFunction) {
+ py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+ .attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)(
+ getTypeIDFunction(),
+ pybind11::cpp_function(
+ [thisClass = thisClass](const py::object &mlirType) {
+ return thisClass(mlirType);
+ }));
+ }
}
};
diff --git a/mlir/include/mlir/CAPI/Support.h b/mlir/include/mlir/CAPI/Support.h
index f3e8a67e0ac36..e42413dbe6d28 100644
--- a/mlir/include/mlir/CAPI/Support.h
+++ b/mlir/include/mlir/CAPI/Support.h
@@ -44,4 +44,25 @@ inline mlir::LogicalResult unwrap(MlirLogicalResult res) {
DEFINE_C_API_METHODS(MlirTypeID, mlir::TypeID)
DEFINE_C_API_PTR_METHODS(MlirTypeIDAllocator, mlir::TypeIDAllocator)
+namespace llvm {
+
+template <>
+struct DenseMapInfo<MlirTypeID> {
+ static inline MlirTypeID getEmptyKey() {
+ auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
+ return mlirTypeIDCreate(pointer);
+ }
+ static inline MlirTypeID getTombstoneKey() {
+ auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
+ return mlirTypeIDCreate(pointer);
+ }
+ static inline unsigned getHashValue(const MlirTypeID &val) {
+ return mlirTypeIDHashValue(val);
+ }
+ static inline bool isEqual(const MlirTypeID &lhs, const MlirTypeID &rhs) {
+ return mlirTypeIDEqual(lhs, rhs);
+ }
+};
+} // namespace llvm
+
#endif // MLIR_CAPI_SUPPORT_H
diff --git a/mlir/lib/Bindings/Python/DialectTransform.cpp b/mlir/lib/Bindings/Python/DialectTransform.cpp
index a9db2428c6c3f..e4b8cee73ef16 100644
--- a/mlir/lib/Bindings/Python/DialectTransform.cpp
+++ b/mlir/lib/Bindings/Python/DialectTransform.cpp
@@ -36,7 +36,8 @@ void populateDialectTransformSubmodule(const pybind11::module &m) {
//===-------------------------------------------------------------------===//
auto operationType =
- mlir_type_subclass(m, "OperationType", mlirTypeIsATransformOperationType);
+ mlir_type_subclass(m, "OperationType", mlirTypeIsATransformOperationType,
+ mlirTransformOperationTypeGetTypeID);
operationType.def_classmethod(
"get",
[](py::object cls, const std::string &operationName, MlirContext ctx) {
diff --git a/mlir/lib/Bindings/Python/Globals.h b/mlir/lib/Bindings/Python/Globals.h
index 45d03689642ee..0fc7614ccad52 100644
--- a/mlir/lib/Bindings/Python/Globals.h
+++ b/mlir/lib/Bindings/Python/Globals.h
@@ -9,12 +9,15 @@
#ifndef MLIR_BINDINGS_PYTHON_GLOBALS_H
#define MLIR_BINDINGS_PYTHON_GLOBALS_H
+#include <optional>
#include <string>
#include <vector>
-#include <optional>
#include "PybindUtils.h"
+#include "mlir-c/IR.h"
+#include "mlir/CAPI/Support.h"
+#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSet.h"
@@ -54,16 +57,18 @@ class PyGlobals {
/// entities.
void loadDialectModule(llvm::StringRef dialectNamespace);
- /// Decorator for registering a custom Dialect class. The class object must
- /// have a DIALECT_NAMESPACE attribute.
- pybind11::object registerDialectDecorator(pybind11::object pyClass);
-
/// Adds a user-friendly Attribute builder.
/// Raises an exception if the mapping already exists.
/// This is intended to be called by implementation code.
void registerAttributeBuilder(const std::string &attributeKind,
pybind11::function pyFunc);
+ /// Adds a user-friendly type caster. Raises an exception if the mapping
+ /// already exists and replace == false. This is intended to be called by
+ /// implementation code.
+ void registerTypeCaster(MlirTypeID mlirTypeID, pybind11::function typeCaster,
+ bool replace = false);
+
/// Adds a concrete implementation dialect class.
/// Raises an exception if the mapping already exists.
/// This is intended to be called by implementation code.
@@ -80,6 +85,10 @@ class PyGlobals {
std::optional<pybind11::function>
lookupAttributeBuilder(const std::string &attributeKind);
+ /// Returns the custom type caster for MlirTypeID mlirTypeID.
+ std::optional<pybind11::function> lookupTypeCaster(MlirTypeID mlirTypeID,
+ MlirDialect dialect);
+
/// Looks up a registered dialect class by namespace. Note that this may
/// trigger loading of the defining module and can arbitrarily re-enter.
std::optional<pybind11::object>
@@ -101,6 +110,10 @@ class PyGlobals {
llvm::StringMap<pybind11::object> operationClassMap;
/// Map of attribute ODS name to custom builder.
llvm::StringMap<pybind11::object> attributeBuilderMap;
+ /// Map of MlirTypeID to custom type caster.
+ llvm::DenseMap<MlirTypeID, pybind11::object> typeCasterMap;
+ /// Cache for map of MlirTypeID to custom type caster.
+ llvm::DenseMap<MlirTypeID, pybind11::object> typeCasterMapCache;
/// Set of dialect namespaces that we have attempted to import implementation
/// modules for.
diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index 0ab47cc241d91..3c7926e784dbb 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -15,6 +15,7 @@
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
+#include "mlir/Bindings/Python/PybindAdaptors.h"
namespace py = pybind11;
using namespace mlir;
@@ -1023,8 +1024,7 @@ class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> {
py::arg("value"), py::arg("context") = py::none(),
"Gets a uniqued Type attribute");
c.def_property_readonly("value", [](PyTypeAttribute &self) {
- return PyType(self.getContext()->getRef(),
- mlirTypeAttrGetValue(self.get()));
+ return mlirTypeAttrGetValue(self.get());
});
}
};
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index a6bd4d849b97c..ec9066aa10cee 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -25,6 +25,7 @@
#include <utility>
namespace py = pybind11;
+using namespace py::literals;
using namespace mlir;
using namespace mlir::python;
@@ -2121,13 +2122,12 @@ class PyOpResult : public PyConcreteValue<PyOpResult> {
/// Returns the list of types of the values held by container.
template <typename Container>
-static std::vector<PyType> getValueTypes(Container &container,
- PyMlirContextRef &context) {
- std::vector<PyType> result;
+static std::vector<MlirType> getValueTypes(Container &container,
+ PyMlirContextRef &context) {
+ std::vector<MlirType> result;
result.reserve(container.size());
for (int i = 0, e = container.size(); i < e; ++i) {
- result.push_back(
- PyType(context, mlirValueGetType(container.getElement(i).get())));
+ result.push_back(mlirValueGetType(container.getElement(i).get()));
}
return result;
}
@@ -3148,11 +3148,8 @@ void mlir::python::populateIRCore(py::module &m) {
"context",
[](PyAttribute &self) { return self.getContext().getObject(); },
"Context that owns the Attribute")
- .def_property_readonly("type",
- [](PyAttribute &self) {
- return PyType(self.getContext()->getRef(),
- mlirAttributeGetType(self));
- })
+ .def_property_readonly(
+ "type", [](PyAttribute &self) { return mlirAttributeGetType(self); })
.def(
"get_named",
[](PyAttribute &self, std::string name) {
@@ -3247,7 +3244,7 @@ void mlir::python::populateIRCore(py::module &m) {
mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec));
if (mlirTypeIsNull(type))
throw MLIRError("Unable to parse type", errors.take());
- return PyType(context->getRef(), type);
+ return type;
},
py::arg("asm"), py::arg("context") = py::none(),
kContextParseTypeDocstring)
@@ -3284,6 +3281,18 @@ void mlir::python::populateIRCore(py::module &m) {
printAccum.parts.append(")");
return printAccum.join();
})
+ .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
+ [](PyType &self) {
+ MlirTypeID mlirTypeID = mlirTypeGetTypeID(self);
+ assert(!mlirTypeIDIsNull(mlirTypeID) &&
+ "mlirTypeID was expected to be non-null.");
+ std::optional<pybind11::function> typeCaster =
+ PyGlobals::get().lookupTypeCaster(mlirTypeID,
+ mlirTypeGetDialect(self));
+ if (!typeCaster)
+ return py::cast(self);
+ return typeCaster.value()(self);
+ })
.def_property_readonly("typeid", [](PyType &self) -> MlirTypeID {
MlirTypeID mlirTypeID = mlirTypeGetTypeID(self);
if (!mlirTypeIDIsNull(mlirTypeID))
@@ -3387,12 +3396,8 @@ void mlir::python::populateIRCore(py::module &m) {
return printAccum.join();
},
py::arg("use_local_scope") = false, kGetNameAsOperand)
- .def_property_readonly("type",
- [](PyValue &self) {
- return PyType(
- self.getParentOperation()->getContext(),
- mlirValueGetType(self.get()));
- })
+ .def_property_readonly(
+ "type", [](PyValue &self) { return mlirValueGetType(self.get()); })
.def(
"replace_all_uses_with",
[](PyValue &self, PyValue &with) {
diff --git a/mlir/lib/Bindings/Python/IRInterfaces.cpp b/mlir/lib/Bindings/Python/IRInterfaces.cpp
index 0a7a25c0005fe..25fcaccd236db 100644
--- a/mlir/lib/Bindings/Python/IRInterfaces.cpp
+++ b/mlir/lib/Bindings/Python/IRInterfaces.cpp
@@ -321,11 +321,7 @@ class PyShapedTypeComponents {
py::module_local())
.def_property_readonly(
"element_type",
- [](PyShapedTypeComponents &self) {
- return PyType(PyMlirContext::forContext(
- mlirTypeGetContext(self.elementType)),
- self.elementType);
- },
+ [](PyShapedTypeComponents &self) { return self.elementType; },
"Returns the element type of the shaped type components.")
.def_static(
"get",
diff --git a/mlir/lib/Bindings/Python/IRModule.cpp b/mlir/lib/Bindings/Python/IRModule.cpp
index 7c49f20f1a9ec..d9a66bce0fecb 100644
--- a/mlir/lib/Bindings/Python/IRModule.cpp
+++ b/mlir/lib/Bindings/Python/IRModule.cpp
@@ -14,6 +14,7 @@
#include <vector>
#include "mlir-c/Bindings/Python/Interop.h"
+#include "mlir-c/Support.h"
namespace py = pybind11;
using namespace mlir;
@@ -72,6 +73,15 @@ void PyGlobals::registerAttributeBuilder(const std::string &attributeKind,
found = std::move(pyFunc);
}
+void PyGlobals::registerTypeCaster(MlirTypeID mlirTypeID,
+ pybind11::function typeCaster,
+ bool replace) {
+ pybind11::object &found = typeCasterMap[mlirTypeID];
+ if (found && !found.is_none() && !replace)
+ throw std::runtime_error("Type caster is already registered");
+ found = std::move(typeCaster);
+}
+
void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
py::object pyClass) {
py::object &found = dialectClassMap[dialectNamespace];
@@ -110,6 +120,39 @@ PyGlobals::lookupAttributeBuilder(const std::string &attributeKind) {
return std::nullopt;
}
+/// Looks up the type caster registered for `mlirTypeID`; returns std::nullopt
+/// when none is registered. The Python module for `dialect` is loaded first so
+/// that casters registered while importing that module are visible here.
+std::optional<py::function> PyGlobals::lookupTypeCaster(MlirTypeID mlirTypeID,
+                                                        MlirDialect dialect) {
+  // Ensure the dialect's implementation module has been imported (it may
+  // register casters for its types). Already-attempted namespaces are tracked
+  // in loadedDialectModulesCache.
+  loadDialectModule(unwrap(mlirDialectGetNamespace(dialect)));
+
+  // Consult the canonical map directly instead of typeCasterMapCache:
+  // registerTypeCaster(..., replace=true) only updates typeCasterMap, so a
+  // cached positive entry would keep returning the replaced caster. (The old
+  // "negative cache" also wrote py::none() into typeCasterMap itself rather
+  // than into the cache.) registerTypeCaster treats none entries as absent,
+  // so they are treated as "not registered" here as well.
+  const auto foundIt = typeCasterMap.find(mlirTypeID);
+  if (foundIt == typeCasterMap.end() || foundIt->second.is_none())
+    return std::nullopt;
+  assert(foundIt->second && "py::function is defined");
+  return foundIt->second;
+}
+
std::optional<py::object>
PyGlobals::lookupDialectClass(const std::string &dialectNamespace) {
loadDialectModule(dialectNamespace);
@@ -164,4 +207,5 @@ PyGlobals::lookupOperationClass(llvm::StringRef operationName) {
void PyGlobals::clearImportCache() {
loadedDialectModulesCache.clear();
operationClassMapCache.clear();
+ typeCasterMapCache.clear();
}
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index cfa3737cf30d2..013bb7b9256f4 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -13,6 +13,7 @@
#include <utility>
#include <vector>
+#include "Globals.h"
#include "PybindUtils.h"
#include "mlir-c/AffineExpr.h"
@@ -868,9 +869,7 @@ class PyConcreteType : public BaseTy {
PyConcreteType() = default;
PyConcreteType(PyMlirContextRef contextRef, MlirType t)
- : BaseTy(std::move(contextRef), t) {
- pybind11::implicitly_convertible<PyType, DerivedTy>();
- }
+ : BaseTy(std::move(contextRef), t) {}
PyConcreteType(PyType &orig)
: PyConcreteType(orig.getContext(), castFrom(orig)) {}
@@ -914,6 +913,13 @@ class PyConcreteType : public BaseTy {
return printAccum.join();
});
+ if (DerivedTy::getTypeIdFunction) {
+ PyGlobals::get().registerTypeCaster(
+ DerivedTy::getTypeIdFunction(),
+ pybind11::cpp_function(
+ [](PyType pyType) -> DerivedTy { return pyType; }));
+ }
+
DerivedTy::bindDerived(cls);
}
@@ -1009,9 +1015,8 @@ class PyConcreteAttribute : public BaseTy {
return DerivedTy::isaFunction(otherAttr);
},
pybind11::arg("other"));
- cls.def_property_readonly("type", [](PyAttribute &attr) {
- return PyType(attr.getContext(), mlirAttributeGetType(attr));
- });
+ cls.def_property_readonly(
+ "type", [](PyAttribute &attr) { return mlirAttributeGetType(attr); });
DerivedTy::bindDerived(cls);
}
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index 5c089b2f2c506..25307262bddbd 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -334,10 +334,7 @@ class PyComplexType : public PyConcreteType<PyComplexType> {
"Create a complex type");
c.def_property_readonly(
"element_type",
- [](PyComplexType &self) -> PyType {
- MlirType t = mlirComplexTypeGetElementType(self);
- return PyType(self.getContext(), t);
- },
+ [](PyComplexType &self) { return mlirComplexTypeGetElementType(self); },
"Returns element type.");
}
};
@@ -351,10 +348,7 @@ class PyShapedType : public PyConcreteType<PyShapedType> {
static void bindDerived(ClassTy &c) {
c.def_property_readonly(
"element_type",
- [](PyShapedType &self) {
- MlirType t = mlirShapedTypeGetElementType(self);
- return PyType(self.getContext(), t);
- },
+ [](PyShapedType &self) { return mlirShapedTypeGetElementType(self); },
"Returns the element type of the shaped type.");
c.def_property_readonly(
"has_rank",
@@ -641,9 +635,8 @@ class PyTupleType : public PyConcreteType<PyTupleType> {
"Create a tuple type");
c.def(
"get_type",
- [](PyTupleType &self, intptr_t pos) -> PyType {
- MlirType t = mlirTupleTypeGetType(self, pos);
- return PyType(self.getContext(), t);
+ [](PyTupleType &self, intptr_t pos) {
+ return mlirTupleTypeGetType(self, pos);
},
py::arg("pos"), "Returns the pos-th type in the tuple type.");
c.def_property_readonly(
@@ -686,7 +679,7 @@ class PyFunctionType : public PyConcreteType<PyFunctionType> {
py::list types;
for (intptr_t i = 0, e = mlirFunctionTypeGetNumInputs(self); i < e;
++i) {
- types.append(PyType(contextRef, mlirFunctionTypeGetInput(t, i)));
+ types.append(mlirFunctionTypeGetInput(t, i));
}
return types;
},
@@ -698,8 +691,7 @@ class PyFunctionType : public PyConcreteType<PyFunctionType> {
py::list types;
for (intptr_t i = 0, e = mlirFunctionTypeGetNumResults(self); i < e;
++i) {
- types.append(
- PyType(contextRef, mlirFunctionTypeGetResult(self, i)));
+ types.append(mlirFunctionTypeGetResult(self, i));
}
return types;
},
diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index b32b4186fcb9f..cdddfbe50606d 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -16,6 +16,7 @@
namespace py = pybind11;
using namespace mlir;
+using namespace py::literals;
using namespace mlir::python;
// -----------------------------------------------------------------------------
@@ -35,12 +36,12 @@ PYBIND11_MODULE(_mlir, m) {
self.getDialectSearchPrefixes().push_back(std::move(moduleName));
self.clearImportCache();
},
- py::arg("module_name"))
+ "module_name"_a)
.def("_register_dialect_impl", &PyGlobals::registerDialectImpl,
- py::arg("dialect_namespace"), py::arg("dialect_class"),
+ "dialect_namespace"_a, "dialect_class"_a,
"Testing hook for directly registering a dialect")
.def("_register_operation_impl", &PyGlobals::registerOperationImpl,
- py::arg("operation_name"), py::arg("operation_class"),
+ "operation_name"_a, "operation_class"_a,
"Testing hook for directly registering an operation");
// Aside from making the globals accessible to python, having python manage
@@ -58,11 +59,11 @@ PYBIND11_MODULE(_mlir, m) {
PyGlobals::get().registerDialectImpl(dialectNamespace, pyClass);
return pyClass;
},
- py::arg("dialect_class"),
+ "dialect_class"_a,
"Class decorator for registering a custom Dialect wrapper");
m.def(
"register_operation",
- [](py::object dialectClass) -> py::cpp_function {
+ [](const py::object &dialectClass) -> py::cpp_function {
return py::cpp_function(
[dialectClass](py::object opClass) -> py::object {
std::string operationName =
@@ -75,9 +76,17 @@ PYBIND11_MODULE(_mlir, m) {
return opClass;
});
},
- py::arg("dialect_class"),
+ "dialect_class"_a,
"Produce a class decorator for registering an Operation class as part of "
"a dialect");
+ m.def(
+ MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR,
+ [](MlirTypeID mlirTypeID, py::function typeCaster, bool replace) {
+ PyGlobals::get().registerTypeCaster(mlirTypeID, std::move(typeCaster),
+ replace);
+ },
+ "typeid"_a, "type_caster"_a, "replace"_a = false,
+ "Register a type caster for casting MLIR types to custom user types.");
// Define and populate IR submodule.
auto irModule = m.def_submodule("ir", "MLIR IR Bindings");
diff --git a/mlir/lib/CAPI/Dialect/Transform.cpp b/mlir/lib/CAPI/Dialect/Transform.cpp
index 90594b67aacfb..d3cd4e3d0bb28 100644
--- a/mlir/lib/CAPI/Dialect/Transform.cpp
+++ b/mlir/lib/CAPI/Dialect/Transform.cpp
@@ -37,6 +37,10 @@ bool mlirTypeIsATransformOperationType(MlirType type) {
return isa<transform::OperationType>(unwrap(type));
}
+MlirTypeID mlirTransformOperationTypeGetTypeID(void) {
+ return wrap(transform::OperationType::getTypeID());
+}
+
MlirType mlirTransformOperationTypeGet(MlirContext ctx,
MlirStringRef operationName) {
return wrap(
diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
index 1925478c66d41..82c5b5a6147a4 100644
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -324,6 +324,10 @@ MlirType mlirUnrankedTensorTypeGetChecked(MlirLocation loc,
return wrap(UnrankedTensorType::getChecked(unwrap(loc), unwrap(elementType)));
}
+MlirType mlirUnrankedTensorTypeGetElementType(MlirType type) {
+ return wrap(llvm::cast<UnrankedTensorType>(unwrap(type)).getElementType());
+}
+
//===----------------------------------------------------------------------===//
// Ranked / Unranked MemRef type.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index c0cf5977736f4..373e01a1362c7 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -832,6 +832,10 @@ MlirTypeID mlirTypeGetTypeID(MlirType type) {
return wrap(unwrap(type).getTypeID());
}
+MlirDialect mlirTypeGetDialect(MlirType type) {
+ return wrap(&unwrap(type).getDialect());
+}
+
bool mlirTypeEqual(MlirType t1, MlirType t2) {
return unwrap(t1) == unwrap(t2);
}
diff --git a/mlir/lib/CAPI/IR/Support.cpp b/mlir/lib/CAPI/IR/Support.cpp
index cbfbb54769aa9..ea081b2e99b59 100644
--- a/mlir/lib/CAPI/IR/Support.cpp
+++ b/mlir/lib/CAPI/IR/Support.cpp
@@ -23,7 +23,6 @@ bool mlirStringRefEqual(MlirStringRef string, MlirStringRef other) {
//===----------------------------------------------------------------------===//
// TypeID API.
//===----------------------------------------------------------------------===//
-
MlirTypeID mlirTypeIDCreate(const void *ptr) {
assert(reinterpret_cast<uintptr_t>(ptr) % 8 == 0 &&
"ptr must be 8 byte aligned");
diff --git a/mlir/python/mlir/dialects/python_test.py b/mlir/python/mlir/dialects/python_test.py
index 980f237b19391..8465af048a280 100644
--- a/mlir/python/mlir/dialects/python_test.py
+++ b/mlir/python/mlir/dialects/python_test.py
@@ -3,7 +3,7 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from ._python_test_ops_gen import *
-from .._mlir_libs._mlirPythonTest import TestAttr, TestType, TestTensorValue, TestTensorType
+from .._mlir_libs._mlirPythonTest import TestAttr, TestType, TestTensorValue, TestIntegerRankedTensorType
def register_python_test_dialect(context, load=True):
diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py
index 99c21ff9aaef2..10a0f5bd2c6b9 100644
--- a/mlir/python/mlir/ir.py
+++ b/mlir/python/mlir/ir.py
@@ -4,6 +4,7 @@
from ._mlir_libs._mlir.ir import *
from ._mlir_libs._mlir.ir import _GlobalDebug
+from ._mlir_libs._mlir import register_type_caster
# Convenience decorator for registering user-friendly Attribute builders.
diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py
index 72a765c75e52c..5346955c906fc 100644
--- a/mlir/test/python/dialects/python_test.py
+++ b/mlir/test/python/dialects/python_test.py
@@ -369,9 +369,9 @@ def __str__(self):
# Classes of custom types that inherit from concrete types should have
# static_typeid
- assert isinstance(test.TestTensorType.static_typeid, TypeID)
+ assert isinstance(test.TestIntegerRankedTensorType.static_typeid, TypeID)
# And it should be equal to the in-tree concrete type
- assert test.TestTensorType.static_typeid == t.type.typeid
+ assert test.TestIntegerRankedTensorType.static_typeid == t.type.typeid
# CHECK-LABEL: TEST: inferReturnTypeComponents
@@ -424,3 +424,46 @@ def inferReturnTypeComponents():
print("rank:", shaped_type_components.rank)
print("element type:", shaped_type_components.element_type)
print("shape:", shaped_type_components.shape)
+
+
+# CHECK-LABEL: TEST: testCustomTypeTypeCaster
+@run
+def testCustomTypeTypeCaster():
+ with Context() as ctx, Location.unknown():
+ test.register_python_test_dialect(ctx)
+
+ a = test.TestType.get()
+ assert a.typeid is not None
+
+ b = Type.parse("!python_test.test_type")
+ # CHECK: !python_test.test_type
+ print(b)
+ # CHECK: TestType(!python_test.test_type)
+ print(repr(b))
+
+ c = test.TestIntegerRankedTensorType.get([10, 10], 5)
+ # CHECK: tensor<10x10xi5>
+ print(c)
+ # CHECK: TestIntegerRankedTensorType(tensor<10x10xi5>)
+ print(repr(c))
+
+ # CHECK: Type caster is already registered
+ try:
+
+ def type_caster(pytype):
+ return test.TestIntegerRankedTensorType(pytype)
+
+ register_type_caster(c.typeid, type_caster)
+ except RuntimeError as e:
+ print(e)
+
+ def type_caster(pytype):
+ return test.TestIntegerRankedTensorType(pytype)
+
+ register_type_caster(c.typeid, type_caster, replace=True)
+
+ d = tensor.EmptyOp([10, 10], IntegerType.get_signless(5)).result
+ # CHECK: tensor<10x10xi5>
+ print(d.type)
+ # CHECK: TestIntegerRankedTensorType(tensor<10x10xi5>)
+ print(repr(d.type))
diff --git a/mlir/test/python/ir/attributes.py b/mlir/test/python/ir/attributes.py
index 29074052796b4..3ee0691d606f8 100644
--- a/mlir/test/python/ir/attributes.py
+++ b/mlir/test/python/ir/attributes.py
@@ -553,3 +553,42 @@ def testStridedLayoutAttr():
print(f"rank: {len(attr.strides)}")
# CHECK: strides are dynamic: [True, True, True]
print(f"strides are dynamic: {[s == dynamic for s in attr.strides]}")
+
+
+# CHECK-LABEL: TEST: testConcreteTypesRoundTrip
+@run
+def testConcreteTypesRoundTrip():
+ with Context(), Location.unknown():
+
+ def print_item(attr):
+ print(repr(attr.type))
+
+ # CHECK: F32Type(f32)
+ print_item(Attribute.parse("42.0 : f32"))
+ # CHECK: F32Type(f32)
+ print_item(FloatAttr.get_f32(42.0))
+ # CHECK: IntegerType(i64)
+ print_item(IntegerAttr.get(IntegerType.get_signless(64), 42))
+
+ def print_container_item(attr_asm):
+ attr = DenseElementsAttr(Attribute.parse(attr_asm))
+ print(repr(attr.type))
+ print(repr(attr.type.element_type))
+
+ # CHECK: RankedTensorType(tensor<i16>)
+ # CHECK: IntegerType(i16)
+ print_container_item("dense<123> : tensor<i16>")
+
+ # CHECK: RankedTensorType(tensor<f64>)
+ # CHECK: F64Type(f64)
+ print_container_item("dense<1.0> : tensor<f64>")
+
+ raw = Attribute.parse("vector<4xf32>")
+ # CHECK: attr: vector<4xf32>
+ print("attr:", raw)
+ type_attr = TypeAttr(raw)
+
+ # CHECK: VectorType(vector<4xf32>)
+ print(repr(type_attr.value))
+ # CHECK: F32Type(f32)
+ print(repr(type_attr.value.element_type))
diff --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py
index fc484a5050839..99273bab0b495 100644
--- a/mlir/test/python/ir/builtin_types.py
+++ b/mlir/test/python/ir/builtin_types.py
@@ -2,6 +2,7 @@
import gc
from mlir.ir import *
+from mlir.dialects import arith, tensor, func, memref
def run(f):
@@ -382,15 +383,15 @@ def testMemRefType():
f32 = F32Type.get()
shape = [2, 3]
loc = Location.unknown()
- memref = MemRefType.get(shape, f32, memory_space=Attribute.parse("2"))
+ memref_f32 = MemRefType.get(shape, f32, memory_space=Attribute.parse("2"))
# CHECK: memref type: memref<2x3xf32, 2>
- print("memref type:", memref)
+ print("memref type:", memref_f32)
# CHECK: memref layout: affine_map<(d0, d1) -> (d0, d1)>
- print("memref layout:", memref.layout)
+ print("memref layout:", memref_f32.layout)
# CHECK: memref affine map: (d0, d1) -> (d0, d1)
- print("memref affine map:", memref.affine_map)
+ print("memref affine map:", memref_f32.affine_map)
# CHECK: memory space: 2
- print("memory space:", memref.memory_space)
+ print("memory space:", memref_f32.memory_space)
layout = AffineMapAttr.get(AffineMap.get_permutation([1, 0]))
memref_layout = MemRefType.get(shape, f32, layout=layout)
@@ -413,7 +414,7 @@ def testMemRefType():
else:
print("Exception not produced")
- assert memref.shape == shape
+ assert memref_f32.shape == shape
# CHECK-LABEL: TEST: testUnrankedMemRefType
@@ -482,9 +483,9 @@ def testFunctionType():
input_types = [IntegerType.get_signless(32), IntegerType.get_signless(16)]
result_types = [IndexType.get()]
func = FunctionType.get(input_types, result_types)
- # CHECK: INPUTS: [Type(i32), Type(i16)]
+ # CHECK: INPUTS: [IntegerType(i32), IntegerType(i16)]
print("INPUTS:", func.inputs)
- # CHECK: RESULTS: [Type(index)]
+ # CHECK: RESULTS: [IndexType(index)]
print("RESULTS:", func.results)
@@ -599,3 +600,130 @@ def testTypeIDs():
vector_type = Type.parse("vector<2x3xf32>")
# CHECK: True
print(ShapedType(vector_type).typeid == vector_type.typeid)
+
+
+# CHECK-LABEL: TEST: testConcreteTypesRoundTrip
+@run
+def testConcreteTypesRoundTrip():
+    with Context() as ctx, Location.unknown():
+        ctx.allow_unregistered_dialects = True
+        # Round-trip each type through the generic Type, then downcast it.
+        def print_downcasted(typ):
+            downcasted = Type(typ).maybe_downcast()
+            print(type(downcasted).__name__)
+            print(repr(downcasted))
+
+        # CHECK: F16Type
+        # CHECK: F16Type(f16)
+        print_downcasted(F16Type.get())
+        # CHECK: F32Type
+        # CHECK: F32Type(f32)
+        print_downcasted(F32Type.get())
+        # CHECK: F64Type
+        # CHECK: F64Type(f64)
+        print_downcasted(F64Type.get())
+        # CHECK: Float8E4M3B11FNUZType
+        # CHECK: Float8E4M3B11FNUZType(f8E4M3B11FNUZ)
+        print_downcasted(Float8E4M3B11FNUZType.get())
+        # CHECK: Float8E4M3FNType
+        # CHECK: Float8E4M3FNType(f8E4M3FN)
+        print_downcasted(Float8E4M3FNType.get())
+        # CHECK: Float8E4M3FNUZType
+        # CHECK: Float8E4M3FNUZType(f8E4M3FNUZ)
+        print_downcasted(Float8E4M3FNUZType.get())
+        # CHECK: Float8E5M2Type
+        # CHECK: Float8E5M2Type(f8E5M2)
+        print_downcasted(Float8E5M2Type.get())
+        # CHECK: Float8E5M2FNUZType
+        # CHECK: Float8E5M2FNUZType(f8E5M2FNUZ)
+        print_downcasted(Float8E5M2FNUZType.get())
+        # CHECK: BF16Type
+        # CHECK: BF16Type(bf16)
+        print_downcasted(BF16Type.get())
+        # CHECK: IndexType
+        # CHECK: IndexType(index)
+        print_downcasted(IndexType.get())
+        # CHECK: IntegerType
+        # CHECK: IntegerType(i32)
+        print_downcasted(IntegerType.get_signless(32))
+
+        f32 = F32Type.get()
+        ranked_tensor = tensor.EmptyOp([10, 10], f32).result
+        # CHECK: RankedTensorType
+        print(type(ranked_tensor.type).__name__)
+        # CHECK: RankedTensorType(tensor<10x10xf32>)
+        print(repr(ranked_tensor.type))
+
+        cf32 = ComplexType.get(f32)
+        # CHECK: ComplexType
+        print(type(cf32).__name__)
+        # CHECK: ComplexType(complex<f32>)
+        print(repr(cf32))
+
+        # Element types of shaped types are downcast too.
+        # CHECK: F32Type
+        print(type(ranked_tensor.type.element_type).__name__)
+        # CHECK: F32Type(f32)
+        print(repr(ranked_tensor.type.element_type))
+
+        vector = VectorType.get([10, 10], f32)
+        tuple_type = TupleType.get_tuple([f32, vector])
+        # CHECK: TupleType
+        print(type(tuple_type).__name__)
+        # CHECK: TupleType(tuple<f32, vector<10x10xf32>>)
+        print(repr(tuple_type))
+        # CHECK: F32Type(f32)
+        print(repr(tuple_type.get_type(0)))
+        # CHECK: VectorType(vector<10x10xf32>)
+        print(repr(tuple_type.get_type(1)))
+
+        index_type = IndexType.get()
+
+        @func.FuncOp.from_py_func()
+        def default_builder():
+            c0 = arith.ConstantOp(f32, 0.0)
+            unranked_tensor_type = UnrankedTensorType.get(f32)
+            unranked_tensor = tensor.FromElementsOp(unranked_tensor_type, [c0]).result
+            # CHECK: UnrankedTensorType
+            print(type(unranked_tensor.type).__name__)
+            # CHECK: UnrankedTensorType(tensor<*xf32>)
+            print(repr(unranked_tensor.type))
+
+            c10 = arith.ConstantOp(index_type, 10)
+            memref_f32_t = MemRefType.get([10, 10], f32)
+            memref_f32 = memref.AllocOp(memref_f32_t, [c10, c10], []).result
+            # CHECK: MemRefType
+            print(type(memref_f32.type).__name__)
+            # CHECK: MemRefType(memref<10x10xf32>)
+            print(repr(memref_f32.type))
+
+            unranked_memref_t = UnrankedMemRefType.get(f32, Attribute.parse("2"))
+            unranked_memref = memref.AllocOp(unranked_memref_t, [c10, c10], []).result
+            # CHECK: UnrankedMemRefType
+            print(type(unranked_memref.type).__name__)
+            # CHECK: UnrankedMemRefType(memref<*xf32, 2>)
+            print(repr(unranked_memref.type))
+
+            tuple_value = Operation.parse(
+                '"test.make_tuple"() : () -> tuple<i32, f32>'
+            ).result
+            # CHECK: TupleType
+            print(type(tuple_value.type).__name__)
+            # CHECK: TupleType(tuple<i32, f32>)
+            print(repr(tuple_value.type))
+
+            return c0, c10
+
+
+# CHECK-LABEL: TEST: testCustomTypeTypeCaster
+# Tests that a type from a dialect can be materialized *and* have the dialect's
+# registered type caster applied without explicitly importing the dialect,
+# i.e., we get a transform.OperationType without importing the transform dialect.
+@run
+def testCustomTypeTypeCaster():
+    with Context() as ctx, Location.unknown():
+        t = Type.parse('!transform.op<"foo.bar">', ctx)
+        # CHECK: !transform.op<"foo.bar">
+        print(t)
+        # CHECK: OperationType(!transform.op<"foo.bar">)
+        print(repr(t))
diff --git a/mlir/test/python/lib/PythonTestCAPI.cpp b/mlir/test/python/lib/PythonTestCAPI.cpp
index 7b443554440bc..71778a97d83a4 100644
--- a/mlir/test/python/lib/PythonTestCAPI.cpp
+++ b/mlir/test/python/lib/PythonTestCAPI.cpp
@@ -31,6 +31,10 @@ MlirType mlirPythonTestTestTypeGet(MlirContext context) {
return wrap(python_test::TestTypeType::get(unwrap(context)));
}
+MlirTypeID mlirPythonTestTestTypeGetTypeID(void) {
+  return wrap(python_test::TestTypeType::getTypeID()); // For _mlirPythonTest.
+}
+
bool mlirTypeIsAPythonTestTestTensorValue(MlirValue value) {
return mlirTypeIsATensor(wrap(unwrap(value).getType()));
}
diff --git a/mlir/test/python/lib/PythonTestCAPI.h b/mlir/test/python/lib/PythonTestCAPI.h
index 90c5d4a383f95..5f1ed3a5b2ad6 100644
--- a/mlir/test/python/lib/PythonTestCAPI.h
+++ b/mlir/test/python/lib/PythonTestCAPI.h
@@ -27,6 +27,8 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAPythonTestTestType(MlirType type);
MLIR_CAPI_EXPORTED MlirType mlirPythonTestTestTypeGet(MlirContext context);
+MLIR_CAPI_EXPORTED MlirTypeID mlirPythonTestTestTypeGetTypeID(void);
+
MLIR_CAPI_EXPORTED bool mlirTypeIsAPythonTestTestTensorValue(MlirValue value);
#ifdef __cplusplus
diff --git a/mlir/test/python/lib/PythonTestModule.cpp b/mlir/test/python/lib/PythonTestModule.cpp
index 7edeaac86e45c..f533082a0a147 100644
--- a/mlir/test/python/lib/PythonTestModule.cpp
+++ b/mlir/test/python/lib/PythonTestModule.cpp
@@ -7,11 +7,19 @@
//===----------------------------------------------------------------------===//
#include "PythonTestCAPI.h"
+#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
+#include "mlir-c/IR.h"
#include "mlir/Bindings/Python/PybindAdaptors.h"
namespace py = pybind11;
using namespace mlir::python::adaptors;
+using namespace pybind11::literals;
+/// True for ranked tensor types whose element type is an integer type.
+static bool mlirTypeIsARankedIntegerTensor(MlirType type) {
+  return mlirTypeIsARankedTensor(type) &&
+         mlirTypeIsAInteger(mlirShapedTypeGetElementType(type));
+}
PYBIND11_MODULE(_mlirPythonTest, m) {
m.def(
@@ -34,16 +42,38 @@ PYBIND11_MODULE(_mlirPythonTest, m) {
return cls(mlirPythonTestTestAttributeGet(ctx));
},
py::arg("cls"), py::arg("context") = py::none());
- mlir_type_subclass(m, "TestType", mlirTypeIsAPythonTestTestType)
+ mlir_type_subclass(m, "TestType", mlirTypeIsAPythonTestTestType,
+ mlirPythonTestTestTypeGetTypeID)
.def_classmethod(
"get",
[](py::object cls, MlirContext ctx) {
return cls(mlirPythonTestTestTypeGet(ctx));
},
py::arg("cls"), py::arg("context") = py::none());
- mlir_type_subclass(m, "TestTensorType", mlirTypeIsARankedTensor,
- py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
- .attr("RankedTensorType"));
+ auto cls =
+ mlir_type_subclass(m, "TestIntegerRankedTensorType",
+ mlirTypeIsARankedIntegerTensor,
+ py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+ .attr("RankedTensorType"))
+ .def_classmethod(
+ "get",
+ [](const py::object &cls, std::vector<int64_t> shape,
+ unsigned width, MlirContext ctx) {
+ MlirAttribute encoding = mlirAttributeGetNull();
+ return cls(mlirRankedTensorTypeGet(
+ shape.size(), shape.data(), mlirIntegerTypeGet(ctx, width),
+ encoding));
+ },
+ "cls"_a, "shape"_a, "width"_a, "context"_a = py::none());
+ assert(py::hasattr(cls.get_class(), "static_typeid") &&
+ "TestIntegerRankedTensorType has no static_typeid");
+ MlirTypeID mlirTypeID = mlirRankedTensorTypeGetTypeID();
+ py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+ .attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)(
+ mlirTypeID, pybind11::cpp_function([cls](const py::object &mlirType) {
+ return cls.get_class()(mlirType);
+ }),
+ /*replace=*/true);
mlir_value_subclass(m, "TestTensorValue",
mlirTypeIsAPythonTestTestTensorValue)
.def("is_null", [](MlirValue &self) { return mlirValueIsNull(self); });
More information about the Mlir-commits mailing list