[Mlir-commits] [mlir] [MLIR][Python] rename checked gettors and add unchecked gettors (PR #160954)
Maksim Levental
llvmlistbot at llvm.org
Fri Sep 26 15:30:33 PDT 2025
https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/160954
From bec4ede62b8aa92f1aa4dcb2b135cf8a1e286f70 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Fri, 26 Sep 2025 14:42:04 -0700
Subject: [PATCH] [MLIR][Python] rename checked gettors and add unchecked gettors
---
mlir/lib/Bindings/Python/DialectLLVM.cpp | 46 +++++---
mlir/lib/Bindings/Python/IRAttributes.cpp | 14 ++-
mlir/lib/Bindings/Python/IRTypes.cpp | 124 ++++++++++++++++++++--
mlir/test/python/ir/attributes.py | 2 +-
mlir/test/python/ir/builtin_types.py | 16 +--
5 files changed, 168 insertions(+), 34 deletions(-)
diff --git a/mlir/lib/Bindings/Python/DialectLLVM.cpp b/mlir/lib/Bindings/Python/DialectLLVM.cpp
index 55b9331270cdc..b044965f6ac1a 100644
--- a/mlir/lib/Bindings/Python/DialectLLVM.cpp
+++ b/mlir/lib/Bindings/Python/DialectLLVM.cpp
@@ -33,21 +33,37 @@ static void populateDialectLLVMSubmodule(const nanobind::module_ &m) {
auto llvmStructType =
mlir_type_subclass(m, "StructType", mlirTypeIsALLVMStructType);
- llvmStructType.def_classmethod(
- "get_literal",
- [](const nb::object &cls, const std::vector<MlirType> &elements,
- bool packed, MlirLocation loc) {
- CollectDiagnosticsToStringScope scope(mlirLocationGetContext(loc));
-
- MlirType type = mlirLLVMStructTypeLiteralGetChecked(
- loc, elements.size(), elements.data(), packed);
- if (mlirTypeIsNull(type)) {
- throw nb::value_error(scope.takeMessage().c_str());
- }
- return cls(type);
- },
- "cls"_a, "elements"_a, nb::kw_only(), "packed"_a = false,
- "loc"_a = nb::none());
+ llvmStructType
+ .def_classmethod(
+ "get_literal_checked",
+ [](const nb::object &cls, const std::vector<MlirType> &elements,
+ bool packed, MlirLocation loc) {
+ CollectDiagnosticsToStringScope scope(mlirLocationGetContext(loc));
+
+ MlirType type = mlirLLVMStructTypeLiteralGetChecked(
+ loc, elements.size(), elements.data(), packed);
+ if (mlirTypeIsNull(type)) {
+ throw nb::value_error(scope.takeMessage().c_str());
+ }
+ return cls(type);
+ },
+ "cls"_a, "elements"_a, nb::kw_only(), "packed"_a = false,
+ "loc"_a = nb::none())
+ .def_classmethod(
+ "get_literal",
+ [](const nb::object &cls, const std::vector<MlirType> &elements,
+ bool packed, MlirContext context) {
+ CollectDiagnosticsToStringScope scope(context);
+
+ MlirType type = mlirLLVMStructTypeLiteralGet(
+ context, elements.size(), elements.data(), packed);
+ if (mlirTypeIsNull(type)) {
+ throw nb::value_error(scope.takeMessage().c_str());
+ }
+ return cls(type);
+ },
+ "cls"_a, "elements"_a, nb::kw_only(), "packed"_a = false,
+ "context"_a = nb::none());
llvmStructType.def_classmethod(
"get_identified",
diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index c77653f97e6dd..24e92ffffe8ae 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -565,7 +565,7 @@ class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> {
static void bindDerived(ClassTy &c) {
c.def_static(
- "get",
+ "get_checked",
[](PyType &type, double value, DefaultingPyLocation loc) {
PyMlirContext::ErrorCapture errors(loc->getContext());
MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value);
@@ -575,6 +575,18 @@ class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> {
},
nb::arg("type"), nb::arg("value"), nb::arg("loc") = nb::none(),
"Gets an uniqued float point attribute associated to a type");
+ c.def_static(
+ "get",
+ [](PyType &type, double value, DefaultingPyMlirContext context) {
+ PyMlirContext::ErrorCapture errors(context->getRef());
+ MlirAttribute attr =
+ mlirFloatAttrDoubleGet(context.get()->get(), type, value);
+ if (mlirAttributeIsNull(attr))
+ throw MLIRError("Invalid attribute", errors.take());
+ return PyFloatAttribute(type.getContext(), attr);
+ },
+ nb::arg("type"), nb::arg("value"), nb::arg("context") = nb::none(),
+ "Gets an uniqued float point attribute associated to a type");
c.def_static(
"get_f32",
[](double value, DefaultingPyMlirContext context) {
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index 07dc00521833f..b0d8345096ff2 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -643,7 +643,12 @@ class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
nb::arg("element_type"), nb::kw_only(),
nb::arg("scalable") = nb::none(),
nb::arg("scalable_dims") = nb::none(),
- nb::arg("loc") = nb::none(), "Create a vector type")
+ nb::arg("context") = nb::none(), "Create a vector type")
+ .def_static("get_checked", &PyVectorType::getChecked, nb::arg("shape"),
+ nb::arg("element_type"), nb::kw_only(),
+ nb::arg("scalable") = nb::none(),
+ nb::arg("scalable_dims") = nb::none(),
+ nb::arg("loc") = nb::none(), "Create a vector type")
.def_prop_ro(
"scalable",
[](MlirType self) { return mlirVectorTypeIsScalable(self); })
@@ -658,10 +663,11 @@ class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
}
private:
- static PyVectorType get(std::vector<int64_t> shape, PyType &elementType,
- std::optional<nb::list> scalable,
- std::optional<std::vector<int64_t>> scalableDims,
- DefaultingPyLocation loc) {
+ static PyVectorType
+ getChecked(std::vector<int64_t> shape, PyType &elementType,
+ std::optional<nb::list> scalable,
+ std::optional<std::vector<int64_t>> scalableDims,
+ DefaultingPyLocation loc) {
if (scalable && scalableDims) {
throw nb::value_error("'scalable' and 'scalable_dims' kwargs "
"are mutually exclusive.");
@@ -696,6 +702,42 @@ class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
throw MLIRError("Invalid type", errors.take());
return PyVectorType(elementType.getContext(), type);
}
+
+ static PyVectorType get(std::vector<int64_t> shape, PyType &elementType,
+ std::optional<nb::list> scalable,
+ std::optional<std::vector<int64_t>> scalableDims,
+ DefaultingPyMlirContext context) {
+ if (scalable && scalableDims) {
+ throw nb::value_error("'scalable' and 'scalable_dims' kwargs "
+ "are mutually exclusive.");
+ }
+
+ PyMlirContext::ErrorCapture errors(context->getRef());
+ MlirType type;
+ if (scalable) {
+ if (scalable->size() != shape.size())
+ throw nb::value_error("Expected len(scalable) == len(shape).");
+
+ SmallVector<bool> scalableDimFlags = llvm::to_vector(llvm::map_range(
+ *scalable, [](const nb::handle &h) { return nb::cast<bool>(h); }));
+ type = mlirVectorTypeGetScalable(shape.size(), shape.data(),
+ scalableDimFlags.data(), elementType);
+ } else if (scalableDims) {
+ SmallVector<bool> scalableDimFlags(shape.size(), false);
+ for (int64_t dim : *scalableDims) {
+ if (static_cast<size_t>(dim) >= scalableDimFlags.size() || dim < 0)
+ throw nb::value_error("Scalable dimension index out of bounds.");
+ scalableDimFlags[dim] = true;
+ }
+ type = mlirVectorTypeGetScalable(shape.size(), shape.data(),
+ scalableDimFlags.data(), elementType);
+ } else {
+ type = mlirVectorTypeGet(shape.size(), shape.data(), elementType);
+ }
+ if (mlirTypeIsNull(type))
+ throw MLIRError("Invalid type", errors.take());
+ return PyVectorType(elementType.getContext(), type);
+ }
};
/// Ranked Tensor Type subclass - RankedTensorType.
@@ -710,7 +752,7 @@ class PyRankedTensorType
static void bindDerived(ClassTy &c) {
c.def_static(
- "get",
+ "get_checked",
[](std::vector<int64_t> shape, PyType &elementType,
std::optional<PyAttribute> &encodingAttr, DefaultingPyLocation loc) {
PyMlirContext::ErrorCapture errors(loc->getContext());
@@ -724,6 +766,22 @@ class PyRankedTensorType
nb::arg("shape"), nb::arg("element_type"),
nb::arg("encoding") = nb::none(), nb::arg("loc") = nb::none(),
"Create a ranked tensor type");
+ c.def_static(
+ "get",
+ [](std::vector<int64_t> shape, PyType &elementType,
+ std::optional<PyAttribute> &encodingAttr,
+ DefaultingPyMlirContext context) {
+ PyMlirContext::ErrorCapture errors(context->getRef());
+ MlirType t = mlirRankedTensorTypeGet(
+ shape.size(), shape.data(), elementType,
+ encodingAttr ? encodingAttr->get() : mlirAttributeGetNull());
+ if (mlirTypeIsNull(t))
+ throw MLIRError("Invalid type", errors.take());
+ return PyRankedTensorType(elementType.getContext(), t);
+ },
+ nb::arg("shape"), nb::arg("element_type"),
+ nb::arg("encoding") = nb::none(), nb::arg("context") = nb::none(),
+ "Create a ranked tensor type");
c.def_prop_ro(
"encoding",
[](PyRankedTensorType &self)
@@ -748,7 +806,7 @@ class PyUnrankedTensorType
static void bindDerived(ClassTy &c) {
c.def_static(
- "get",
+ "get_checked",
[](PyType &elementType, DefaultingPyLocation loc) {
PyMlirContext::ErrorCapture errors(loc->getContext());
MlirType t = mlirUnrankedTensorTypeGetChecked(loc, elementType);
@@ -758,6 +816,17 @@ class PyUnrankedTensorType
},
nb::arg("element_type"), nb::arg("loc") = nb::none(),
"Create a unranked tensor type");
+ c.def_static(
+ "get",
+ [](PyType &elementType, DefaultingPyMlirContext context) {
+ PyMlirContext::ErrorCapture errors(context->getRef());
+ MlirType t = mlirUnrankedTensorTypeGet(elementType);
+ if (mlirTypeIsNull(t))
+ throw MLIRError("Invalid type", errors.take());
+ return PyUnrankedTensorType(elementType.getContext(), t);
+ },
+ nb::arg("element_type"), nb::arg("context") = nb::none(),
+ "Create a unranked tensor type");
}
};
@@ -772,7 +841,7 @@ class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
static void bindDerived(ClassTy &c) {
c.def_static(
- "get",
+ "get_checked",
[](std::vector<int64_t> shape, PyType &elementType,
PyAttribute *layout, PyAttribute *memorySpace,
DefaultingPyLocation loc) {
@@ -790,6 +859,27 @@ class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
nb::arg("shape"), nb::arg("element_type"),
nb::arg("layout") = nb::none(), nb::arg("memory_space") = nb::none(),
nb::arg("loc") = nb::none(), "Create a memref type")
+ .def_static(
+ "get",
+ [](std::vector<int64_t> shape, PyType &elementType,
+ PyAttribute *layout, PyAttribute *memorySpace,
+ DefaultingPyMlirContext context) {
+ PyMlirContext::ErrorCapture errors(context->getRef());
+ MlirAttribute layoutAttr =
+ layout ? *layout : mlirAttributeGetNull();
+ MlirAttribute memSpaceAttr =
+ memorySpace ? *memorySpace : mlirAttributeGetNull();
+ MlirType t =
+ mlirMemRefTypeGet(elementType, shape.size(), shape.data(),
+ layoutAttr, memSpaceAttr);
+ if (mlirTypeIsNull(t))
+ throw MLIRError("Invalid type", errors.take());
+ return PyMemRefType(elementType.getContext(), t);
+ },
+ nb::arg("shape"), nb::arg("element_type"),
+ nb::arg("layout") = nb::none(),
+ nb::arg("memory_space") = nb::none(),
+ nb::arg("context") = nb::none(), "Create a memref type")
.def_prop_ro(
"layout",
[](PyMemRefType &self) -> nb::typed<nb::object, PyAttribute> {
@@ -842,7 +932,7 @@ class PyUnrankedMemRefType
static void bindDerived(ClassTy &c) {
c.def_static(
- "get",
+ "get_checked",
[](PyType &elementType, PyAttribute *memorySpace,
DefaultingPyLocation loc) {
PyMlirContext::ErrorCapture errors(loc->getContext());
@@ -858,6 +948,22 @@ class PyUnrankedMemRefType
},
nb::arg("element_type"), nb::arg("memory_space").none(),
nb::arg("loc") = nb::none(), "Create a unranked memref type")
+ .def_static(
+ "get",
+ [](PyType &elementType, PyAttribute *memorySpace,
+ DefaultingPyMlirContext context) {
+ PyMlirContext::ErrorCapture errors(context->getRef());
+ MlirAttribute memSpaceAttr = {};
+ if (memorySpace)
+ memSpaceAttr = *memorySpace;
+
+ MlirType t = mlirUnrankedMemRefTypeGet(elementType, memSpaceAttr);
+ if (mlirTypeIsNull(t))
+ throw MLIRError("Invalid type", errors.take());
+ return PyUnrankedMemRefType(elementType.getContext(), t);
+ },
+ nb::arg("element_type"), nb::arg("memory_space").none(),
+ nb::arg("context") = nb::none(), "Create a unranked memref type")
.def_prop_ro(
"memory_space",
[](PyUnrankedMemRefType &self)
diff --git a/mlir/test/python/ir/attributes.py b/mlir/test/python/ir/attributes.py
index 2f3c4460d3f59..9a9b693c896aa 100644
--- a/mlir/test/python/ir/attributes.py
+++ b/mlir/test/python/ir/attributes.py
@@ -198,7 +198,7 @@ def testFloatAttr():
# CHECK: f64_get: 4.200000e+01 : f64
print("f64_get:", FloatAttr.get_f64(42.0))
try:
- fattr_invalid = FloatAttr.get(IntegerType.get_signless(32), 42)
+ fattr_invalid = FloatAttr.get_checked(IntegerType.get_signless(32), 42)
except MLIRError as e:
# CHECK: Invalid attribute:
# CHECK: error: unknown: expected floating point type
diff --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py
index b42bfd9bc6587..486e2e9d4d563 100644
--- a/mlir/test/python/ir/builtin_types.py
+++ b/mlir/test/python/ir/builtin_types.py
@@ -379,7 +379,7 @@ def testVectorType():
none = NoneType.get()
try:
- VectorType.get(shape, none)
+ VectorType.get_checked(shape, none)
except MLIRError as e:
# CHECK: Invalid type:
# CHECK: error: unknown: failed to verify 'elementType': VectorElementTypeInterface instance
@@ -404,7 +404,7 @@ def testVectorType():
assert scalable_4 == scalable_2
try:
- VectorType.get(shape, f32, scalable=[False, True, True])
+ VectorType.get_checked(shape, f32, scalable=[False, True, True])
except ValueError as e:
# CHECK: Expected len(scalable) == len(shape).
print(e)
@@ -412,7 +412,7 @@ def testVectorType():
print("Exception not produced")
try:
- VectorType.get(shape, f32, scalable=[False, True], scalable_dims=[1])
+ VectorType.get_checked(shape, f32, scalable=[False, True], scalable_dims=[1])
except ValueError as e:
# CHECK: kwargs are mutually exclusive.
print(e)
@@ -420,7 +420,7 @@ def testVectorType():
print("Exception not produced")
try:
- VectorType.get(shape, f32, scalable_dims=[42])
+ VectorType.get_checked(shape, f32, scalable_dims=[42])
except ValueError as e:
# CHECK: Scalable dimension index out of bounds.
print(e)
@@ -440,7 +440,7 @@ def testRankedTensorType():
none = NoneType.get()
try:
- tensor_invalid = RankedTensorType.get(shape, none)
+ tensor_invalid = RankedTensorType.get_checked(shape, none)
except MLIRError as e:
# CHECK: Invalid type:
# CHECK: error: unknown: invalid tensor element type: 'none'
@@ -489,7 +489,7 @@ def testUnrankedTensorType():
none = NoneType.get()
try:
- tensor_invalid = UnrankedTensorType.get(none)
+ tensor_invalid = UnrankedTensorType.get_checked(none)
except MLIRError as e:
# CHECK: Invalid type:
# CHECK: error: unknown: invalid tensor element type: 'none'
@@ -528,7 +528,7 @@ def testMemRefType():
none = NoneType.get()
try:
- memref_invalid = MemRefType.get(shape, none)
+ memref_invalid = MemRefType.get_checked(shape, none)
except MLIRError as e:
# CHECK: Invalid type:
# CHECK: error: unknown: invalid memref element type
@@ -574,7 +574,7 @@ def testUnrankedMemRefType():
none = NoneType.get()
try:
- memref_invalid = UnrankedMemRefType.get(none, Attribute.parse("2"))
+ memref_invalid = UnrankedMemRefType.get_checked(none, Attribute.parse("2"))
except MLIRError as e:
# CHECK: Invalid type:
# CHECK: error: unknown: invalid memref element type
More information about the Mlir-commits mailing list