[Mlir-commits] [mlir] [MLIR][Python] add unchecked getters (PR #160954)
Maksim Levental
llvmlistbot at llvm.org
Fri Sep 26 19:18:25 PDT 2025
https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/160954
>From bfa9fefcafeb8abfa8d21b48ce74c7af4a20464f Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Fri, 26 Sep 2025 14:42:04 -0700
Subject: [PATCH 1/2] [MLIR][Python] rename checked getters and add unchecked
getters
---
mlir/lib/Bindings/Python/DialectLLVM.cpp | 46 ++++++---
mlir/lib/Bindings/Python/IRAttributes.cpp | 12 +++
mlir/lib/Bindings/Python/IRTypes.cpp | 116 +++++++++++++++++++++-
3 files changed, 154 insertions(+), 20 deletions(-)
diff --git a/mlir/lib/Bindings/Python/DialectLLVM.cpp b/mlir/lib/Bindings/Python/DialectLLVM.cpp
index 55b9331270cdc..38de4a0e329a0 100644
--- a/mlir/lib/Bindings/Python/DialectLLVM.cpp
+++ b/mlir/lib/Bindings/Python/DialectLLVM.cpp
@@ -33,21 +33,37 @@ static void populateDialectLLVMSubmodule(const nanobind::module_ &m) {
auto llvmStructType =
mlir_type_subclass(m, "StructType", mlirTypeIsALLVMStructType);
- llvmStructType.def_classmethod(
- "get_literal",
- [](const nb::object &cls, const std::vector<MlirType> &elements,
- bool packed, MlirLocation loc) {
- CollectDiagnosticsToStringScope scope(mlirLocationGetContext(loc));
-
- MlirType type = mlirLLVMStructTypeLiteralGetChecked(
- loc, elements.size(), elements.data(), packed);
- if (mlirTypeIsNull(type)) {
- throw nb::value_error(scope.takeMessage().c_str());
- }
- return cls(type);
- },
- "cls"_a, "elements"_a, nb::kw_only(), "packed"_a = false,
- "loc"_a = nb::none());
+ llvmStructType
+ .def_classmethod(
+ "get_literal",
+ [](const nb::object &cls, const std::vector<MlirType> &elements,
+ bool packed, MlirLocation loc) {
+ CollectDiagnosticsToStringScope scope(mlirLocationGetContext(loc));
+
+ MlirType type = mlirLLVMStructTypeLiteralGetChecked(
+ loc, elements.size(), elements.data(), packed);
+ if (mlirTypeIsNull(type)) {
+ throw nb::value_error(scope.takeMessage().c_str());
+ }
+ return cls(type);
+ },
+ "cls"_a, "elements"_a, nb::kw_only(), "packed"_a = false,
+ "loc"_a = nb::none(),
+ "Gets a literal LLVM struct type via the checked C API; invalid "
+ "arguments raise ValueError carrying the emitted diagnostics.")
+ .def_classmethod(
+ "get_literal_unchecked",
+ [](const nb::object &cls, const std::vector<MlirType> &elements,
+ bool packed, MlirContext context) {
+ // The unchecked C API takes a context instead of a location; the
+ // diagnostic scope and null check mirror get_literal for parity.
+ CollectDiagnosticsToStringScope scope(context);
+
+ MlirType type = mlirLLVMStructTypeLiteralGet(
+ context, elements.size(), elements.data(), packed);
+ if (mlirTypeIsNull(type)) {
+ throw nb::value_error(scope.takeMessage().c_str());
+ }
+ return cls(type);
+ },
+ "cls"_a, "elements"_a, nb::kw_only(), "packed"_a = false,
+ "context"_a = nb::none(),
+ "Unchecked variant of get_literal that takes a context instead of a "
+ "location; incompatible element types are not diagnosed (the C API "
+ "may assert). Use get_literal for graceful failure.");
llvmStructType.def_classmethod(
"get_identified",
diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index c77653f97e6dd..045c0fbf4630f 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -575,6 +575,18 @@ class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> {
},
nb::arg("type"), nb::arg("value"), nb::arg("loc") = nb::none(),
"Gets an uniqued float point attribute associated to a type");
+ c.def_static(
+ "get_unchecked",
+ [](PyType &type, double value, DefaultingPyMlirContext context) {
+ PyMlirContext::ErrorCapture errors(context->getRef());
+ MlirAttribute attr = mlirFloatAttrDoubleGet(context->get(), type, value);
+ if (mlirAttributeIsNull(attr))
+ throw MLIRError("Invalid attribute", errors.take());
+ return PyFloatAttribute(type.getContext(), attr);
+ },
+ nb::arg("type"), nb::arg("value"), nb::arg("context") = nb::none(),
+ "Gets a uniqued floating point attribute associated to a type "
+ "(unchecked variant of get: the type is not validated)");
c.def_static(
"get_f32",
[](double value, DefaultingPyMlirContext context) {
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index 07dc00521833f..3488d92250b45 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -639,11 +639,16 @@ class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
using PyConcreteType::PyConcreteType;
static void bindDerived(ClassTy &c) {
- c.def_static("get", &PyVectorType::get, nb::arg("shape"),
+ c.def_static("get", &PyVectorType::getChecked, nb::arg("shape"),
nb::arg("element_type"), nb::kw_only(),
nb::arg("scalable") = nb::none(),
nb::arg("scalable_dims") = nb::none(),
nb::arg("loc") = nb::none(), "Create a vector type")
+ .def_static("get_unchecked", &PyVectorType::get, nb::arg("shape"),
+ nb::arg("element_type"), nb::kw_only(),
+ nb::arg("scalable") = nb::none(),
+ nb::arg("scalable_dims") = nb::none(),
+ nb::arg("context") = nb::none(),
+ "Create a vector type without verifying it (unchecked variant "
+ "of get that takes a context instead of a location)")
.def_prop_ro(
"scalable",
[](MlirType self) { return mlirVectorTypeIsScalable(self); })
@@ -658,10 +663,11 @@ class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
}
private:
- static PyVectorType get(std::vector<int64_t> shape, PyType &elementType,
- std::optional<nb::list> scalable,
- std::optional<std::vector<int64_t>> scalableDims,
- DefaultingPyLocation loc) {
+ static PyVectorType
+ getChecked(std::vector<int64_t> shape, PyType &elementType,
+ std::optional<nb::list> scalable,
+ std::optional<std::vector<int64_t>> scalableDims,
+ DefaultingPyLocation loc) {
if (scalable && scalableDims) {
throw nb::value_error("'scalable' and 'scalable_dims' kwargs "
"are mutually exclusive.");
@@ -696,6 +702,42 @@ class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
throw MLIRError("Invalid type", errors.take());
return PyVectorType(elementType.getContext(), type);
}
+
+ /// Unchecked counterpart of getChecked (bound as VectorType.get_unchecked):
+ /// builds the type via mlirVectorTypeGet / mlirVectorTypeGetScalable.
+ /// At most one of `scalable` (one bool per dimension) and `scalableDims`
+ /// (indices of the scalable dimensions) may be given. `context` only hosts
+ /// the error capture; the result lives in `elementType`'s context.
+ static PyVectorType get(std::vector<int64_t> shape, PyType &elementType,
+ std::optional<nb::list> scalable,
+ std::optional<std::vector<int64_t>> scalableDims,
+ DefaultingPyMlirContext context) {
+ if (scalable && scalableDims)
+ throw nb::value_error("'scalable' and 'scalable_dims' kwargs "
+ "are mutually exclusive.");
+
+ PyMlirContext::ErrorCapture errors(context->getRef());
+ const size_t rank = shape.size();
+
+ // Per-dimension scalability flags; disengaged for a fixed-length vector.
+ std::optional<SmallVector<bool>> flags;
+ if (scalable) {
+ if (scalable->size() != rank)
+ throw nb::value_error("Expected len(scalable) == len(shape).");
+ flags.emplace();
+ for (nb::handle item : *scalable)
+ flags->push_back(nb::cast<bool>(item));
+ } else if (scalableDims) {
+ flags.emplace(rank, false);
+ for (int64_t idx : *scalableDims) {
+ if (idx < 0 || static_cast<size_t>(idx) >= rank)
+ throw nb::value_error("Scalable dimension index out of bounds.");
+ (*flags)[idx] = true;
+ }
+ }
+
+ MlirType result =
+ flags ? mlirVectorTypeGetScalable(rank, shape.data(), flags->data(),
+ elementType)
+ : mlirVectorTypeGet(rank, shape.data(), elementType);
+ if (mlirTypeIsNull(result))
+ throw MLIRError("Invalid type", errors.take());
+ return PyVectorType(elementType.getContext(), result);
+ }
};
/// Ranked Tensor Type subclass - RankedTensorType.
@@ -724,6 +766,22 @@ class PyRankedTensorType
nb::arg("shape"), nb::arg("element_type"),
nb::arg("encoding") = nb::none(), nb::arg("loc") = nb::none(),
"Create a ranked tensor type");
+ c.def_static(
+ "get_unchecked",
+ [](std::vector<int64_t> shape, PyType &elementType,
+ std::optional<PyAttribute> &encodingAttr,
+ DefaultingPyMlirContext context) {
+ PyMlirContext::ErrorCapture errors(context->getRef());
+ MlirType t = mlirRankedTensorTypeGet(
+ shape.size(), shape.data(), elementType,
+ encodingAttr ? encodingAttr->get() : mlirAttributeGetNull());
+ if (mlirTypeIsNull(t))
+ throw MLIRError("Invalid type", errors.take());
+ return PyRankedTensorType(elementType.getContext(), t);
+ },
+ nb::arg("shape"), nb::arg("element_type"),
+ nb::arg("encoding") = nb::none(), nb::arg("context") = nb::none(),
+ "Create a ranked tensor type without verifying it (unchecked variant "
+ "of get that takes a context instead of a location)");
c.def_prop_ro(
"encoding",
[](PyRankedTensorType &self)
@@ -758,6 +816,17 @@ class PyUnrankedTensorType
},
nb::arg("element_type"), nb::arg("loc") = nb::none(),
"Create a unranked tensor type");
+ c.def_static(
+ "get_unchecked",
+ [](PyType &elementType, DefaultingPyMlirContext context) {
+ PyMlirContext::ErrorCapture errors(context->getRef());
+ MlirType t = mlirUnrankedTensorTypeGet(elementType);
+ if (mlirTypeIsNull(t))
+ throw MLIRError("Invalid type", errors.take());
+ return PyUnrankedTensorType(elementType.getContext(), t);
+ },
+ nb::arg("element_type"), nb::arg("context") = nb::none(),
+ "Create an unranked tensor type without verifying it (unchecked "
+ "variant of get that takes a context instead of a location)");
}
};
@@ -790,6 +859,27 @@ class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
nb::arg("shape"), nb::arg("element_type"),
nb::arg("layout") = nb::none(), nb::arg("memory_space") = nb::none(),
nb::arg("loc") = nb::none(), "Create a memref type")
+ .def_static(
+ "get_unchecked",
+ [](std::vector<int64_t> shape, PyType &elementType,
+ PyAttribute *layout, PyAttribute *memorySpace,
+ DefaultingPyMlirContext context) {
+ PyMlirContext::ErrorCapture errors(context->getRef());
+ MlirAttribute layoutAttr =
+ layout ? *layout : mlirAttributeGetNull();
+ MlirAttribute memSpaceAttr =
+ memorySpace ? *memorySpace : mlirAttributeGetNull();
+ MlirType t =
+ mlirMemRefTypeGet(elementType, shape.size(), shape.data(),
+ layoutAttr, memSpaceAttr);
+ if (mlirTypeIsNull(t))
+ throw MLIRError("Invalid type", errors.take());
+ return PyMemRefType(elementType.getContext(), t);
+ },
+ nb::arg("shape"), nb::arg("element_type"),
+ nb::arg("layout") = nb::none(),
+ nb::arg("memory_space") = nb::none(),
+ nb::arg("context") = nb::none(),
+ "Create a memref type without verifying it (unchecked variant of get "
+ "that takes a context instead of a location)")
.def_prop_ro(
"layout",
[](PyMemRefType &self) -> nb::typed<nb::object, PyAttribute> {
@@ -858,6 +948,22 @@ class PyUnrankedMemRefType
},
nb::arg("element_type"), nb::arg("memory_space").none(),
nb::arg("loc") = nb::none(), "Create a unranked memref type")
+ .def_static(
+ "get_unchecked",
+ [](PyType &elementType, PyAttribute *memorySpace,
+ DefaultingPyMlirContext context) {
+ PyMlirContext::ErrorCapture errors(context->getRef());
+ // A null attribute is passed when no memory space is given.
+ MlirAttribute memSpaceAttr =
+ memorySpace ? *memorySpace : mlirAttributeGetNull();
+
+ MlirType t = mlirUnrankedMemRefTypeGet(elementType, memSpaceAttr);
+ if (mlirTypeIsNull(t))
+ throw MLIRError("Invalid type", errors.take());
+ return PyUnrankedMemRefType(elementType.getContext(), t);
+ },
+ nb::arg("element_type"), nb::arg("memory_space").none(),
+ nb::arg("context") = nb::none(),
+ "Create an unranked memref type without verifying it (unchecked "
+ "variant of get that takes a context instead of a location)")
.def_prop_ro(
"memory_space",
[](PyUnrankedMemRefType &self)
>From bcc8bcad3b7a6dc65b40892fa52d33a2ff813bb8 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Fri, 26 Sep 2025 16:15:13 -0700
Subject: [PATCH 2/2] [MLIR][Python] test VectorType.get_unchecked (run ci)
---
mlir/test/python/ir/builtin_types.py | 11 ++++++++---
1 file changed, 8 insertions(+), 3 deletions(-)
diff --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py
index b42bfd9bc6587..54863253fc770 100644
--- a/mlir/test/python/ir/builtin_types.py
+++ b/mlir/test/python/ir/builtin_types.py
@@ -371,11 +371,16 @@ def testAbstractShapedType():
# CHECK-LABEL: TEST: testVectorType
@run
def testVectorType():
+ shape = [2, 3]
+ with Context():
+ f32 = F32Type.get()
+ # CHECK: unchecked vector type: vector<2x3xf32>
+ print("unchecked vector type:", VectorType.get_unchecked(shape, f32))
+
with Context(), Location.unknown():
f32 = F32Type.get()
- shape = [2, 3]
- # CHECK: vector type: vector<2x3xf32>
- print("vector type:", VectorType.get(shape, f32))
+ # CHECK: checked vector type: vector<2x3xf32>
+ print("checked vector type:", VectorType.get(shape, f32))
none = NoneType.get()
try:
More information about the Mlir-commits
mailing list