[Mlir-commits] [mlir] 3834c54 - [MLIR][Python] add unchecked gettors (#160954)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sat Sep 27 11:54:37 PDT 2025
Author: Maksim Levental
Date: 2025-09-27T13:54:33-05:00
New Revision: 3834c5428dcdd8807331b2ff522115dba2f89aae
URL: https://github.com/llvm/llvm-project/commit/3834c5428dcdd8807331b2ff522115dba2f89aae
DIFF: https://github.com/llvm/llvm-project/commit/3834c5428dcdd8807331b2ff522115dba2f89aae.diff
LOG: [MLIR][Python] add unchecked gettors (#160954)
Some of the current gettors require passing locations (i.e., that there be an
active location) because they use the "checked" APIs. This PR adds
"unchecked" gettors, which only require an active context.
Added:
Modified:
mlir/lib/Bindings/Python/DialectLLVM.cpp
mlir/lib/Bindings/Python/IRAttributes.cpp
mlir/lib/Bindings/Python/IRTypes.cpp
mlir/test/python/ir/builtin_types.py
Removed:
################################################################################
diff --git a/mlir/lib/Bindings/Python/DialectLLVM.cpp b/mlir/lib/Bindings/Python/DialectLLVM.cpp
index 55b9331270cdc..38de4a0e329a0 100644
--- a/mlir/lib/Bindings/Python/DialectLLVM.cpp
+++ b/mlir/lib/Bindings/Python/DialectLLVM.cpp
@@ -33,21 +33,37 @@ static void populateDialectLLVMSubmodule(const nanobind::module_ &m) {
auto llvmStructType =
mlir_type_subclass(m, "StructType", mlirTypeIsALLVMStructType);
- llvmStructType.def_classmethod(
- "get_literal",
- [](const nb::object &cls, const std::vector<MlirType> &elements,
- bool packed, MlirLocation loc) {
- CollectDiagnosticsToStringScope scope(mlirLocationGetContext(loc));
-
- MlirType type = mlirLLVMStructTypeLiteralGetChecked(
- loc, elements.size(), elements.data(), packed);
- if (mlirTypeIsNull(type)) {
- throw nb::value_error(scope.takeMessage().c_str());
- }
- return cls(type);
- },
- "cls"_a, "elements"_a, nb::kw_only(), "packed"_a = false,
- "loc"_a = nb::none());
+ llvmStructType
+ .def_classmethod(
+ "get_literal",
+ [](const nb::object &cls, const std::vector<MlirType> &elements,
+ bool packed, MlirLocation loc) {
+ CollectDiagnosticsToStringScope scope(mlirLocationGetContext(loc));
+
+ MlirType type = mlirLLVMStructTypeLiteralGetChecked(
+ loc, elements.size(), elements.data(), packed);
+ if (mlirTypeIsNull(type)) {
+ throw nb::value_error(scope.takeMessage().c_str());
+ }
+ return cls(type);
+ },
+ "cls"_a, "elements"_a, nb::kw_only(), "packed"_a = false,
+ "loc"_a = nb::none())
+ .def_classmethod(
+ "get_literal_unchecked",
+ [](const nb::object &cls, const std::vector<MlirType> &elements,
+ bool packed, MlirContext context) {
+ CollectDiagnosticsToStringScope scope(context);
+
+ MlirType type = mlirLLVMStructTypeLiteralGet(
+ context, elements.size(), elements.data(), packed);
+ if (mlirTypeIsNull(type)) {
+ throw nb::value_error(scope.takeMessage().c_str());
+ }
+ return cls(type);
+ },
+ "cls"_a, "elements"_a, nb::kw_only(), "packed"_a = false,
+ "context"_a = nb::none());
llvmStructType.def_classmethod(
"get_identified",
diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index c77653f97e6dd..045c0fbf4630f 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -575,6 +575,18 @@ class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> {
},
nb::arg("type"), nb::arg("value"), nb::arg("loc") = nb::none(),
"Gets an uniqued float point attribute associated to a type");
+ c.def_static(
+ "get_unchecked",
+ [](PyType &type, double value, DefaultingPyMlirContext context) {
+ PyMlirContext::ErrorCapture errors(context->getRef());
+ MlirAttribute attr =
+ mlirFloatAttrDoubleGet(context.get()->get(), type, value);
+ if (mlirAttributeIsNull(attr))
+ throw MLIRError("Invalid attribute", errors.take());
+ return PyFloatAttribute(type.getContext(), attr);
+ },
+ nb::arg("type"), nb::arg("value"), nb::arg("context") = nb::none(),
+ "Gets an uniqued float point attribute associated to a type");
c.def_static(
"get_f32",
[](double value, DefaultingPyMlirContext context) {
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index 07dc00521833f..3488d92250b45 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -639,11 +639,16 @@ class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
using PyConcreteType::PyConcreteType;
static void bindDerived(ClassTy &c) {
- c.def_static("get", &PyVectorType::get, nb::arg("shape"),
+ c.def_static("get", &PyVectorType::getChecked, nb::arg("shape"),
nb::arg("element_type"), nb::kw_only(),
nb::arg("scalable") = nb::none(),
nb::arg("scalable_dims") = nb::none(),
nb::arg("loc") = nb::none(), "Create a vector type")
+ .def_static("get_unchecked", &PyVectorType::get, nb::arg("shape"),
+ nb::arg("element_type"), nb::kw_only(),
+ nb::arg("scalable") = nb::none(),
+ nb::arg("scalable_dims") = nb::none(),
+ nb::arg("context") = nb::none(), "Create a vector type")
.def_prop_ro(
"scalable",
[](MlirType self) { return mlirVectorTypeIsScalable(self); })
@@ -658,10 +663,11 @@ class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
}
private:
- static PyVectorType get(std::vector<int64_t> shape, PyType &elementType,
- std::optional<nb::list> scalable,
- std::optional<std::vector<int64_t>> scalableDims,
- DefaultingPyLocation loc) {
+ static PyVectorType
+ getChecked(std::vector<int64_t> shape, PyType &elementType,
+ std::optional<nb::list> scalable,
+ std::optional<std::vector<int64_t>> scalableDims,
+ DefaultingPyLocation loc) {
if (scalable && scalableDims) {
throw nb::value_error("'scalable' and 'scalable_dims' kwargs "
"are mutually exclusive.");
@@ -696,6 +702,42 @@ class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
throw MLIRError("Invalid type", errors.take());
return PyVectorType(elementType.getContext(), type);
}
+
+ static PyVectorType get(std::vector<int64_t> shape, PyType &elementType,
+ std::optional<nb::list> scalable,
+ std::optional<std::vector<int64_t>> scalableDims,
+ DefaultingPyMlirContext context) {
+ if (scalable && scalableDims) {
+ throw nb::value_error("'scalable' and 'scalable_dims' kwargs "
+ "are mutually exclusive.");
+ }
+
+ PyMlirContext::ErrorCapture errors(context->getRef());
+ MlirType type;
+ if (scalable) {
+ if (scalable->size() != shape.size())
+ throw nb::value_error("Expected len(scalable) == len(shape).");
+
+ SmallVector<bool> scalableDimFlags = llvm::to_vector(llvm::map_range(
+ *scalable, [](const nb::handle &h) { return nb::cast<bool>(h); }));
+ type = mlirVectorTypeGetScalable(shape.size(), shape.data(),
+ scalableDimFlags.data(), elementType);
+ } else if (scalableDims) {
+ SmallVector<bool> scalableDimFlags(shape.size(), false);
+ for (int64_t dim : *scalableDims) {
+ if (static_cast<size_t>(dim) >= scalableDimFlags.size() || dim < 0)
+ throw nb::value_error("Scalable dimension index out of bounds.");
+ scalableDimFlags[dim] = true;
+ }
+ type = mlirVectorTypeGetScalable(shape.size(), shape.data(),
+ scalableDimFlags.data(), elementType);
+ } else {
+ type = mlirVectorTypeGet(shape.size(), shape.data(), elementType);
+ }
+ if (mlirTypeIsNull(type))
+ throw MLIRError("Invalid type", errors.take());
+ return PyVectorType(elementType.getContext(), type);
+ }
};
/// Ranked Tensor Type subclass - RankedTensorType.
@@ -724,6 +766,22 @@ class PyRankedTensorType
nb::arg("shape"), nb::arg("element_type"),
nb::arg("encoding") = nb::none(), nb::arg("loc") = nb::none(),
"Create a ranked tensor type");
+ c.def_static(
+ "get_unchecked",
+ [](std::vector<int64_t> shape, PyType &elementType,
+ std::optional<PyAttribute> &encodingAttr,
+ DefaultingPyMlirContext context) {
+ PyMlirContext::ErrorCapture errors(context->getRef());
+ MlirType t = mlirRankedTensorTypeGet(
+ shape.size(), shape.data(), elementType,
+ encodingAttr ? encodingAttr->get() : mlirAttributeGetNull());
+ if (mlirTypeIsNull(t))
+ throw MLIRError("Invalid type", errors.take());
+ return PyRankedTensorType(elementType.getContext(), t);
+ },
+ nb::arg("shape"), nb::arg("element_type"),
+ nb::arg("encoding") = nb::none(), nb::arg("context") = nb::none(),
+ "Create a ranked tensor type");
c.def_prop_ro(
"encoding",
[](PyRankedTensorType &self)
@@ -758,6 +816,17 @@ class PyUnrankedTensorType
},
nb::arg("element_type"), nb::arg("loc") = nb::none(),
"Create a unranked tensor type");
+ c.def_static(
+ "get_unchecked",
+ [](PyType &elementType, DefaultingPyMlirContext context) {
+ PyMlirContext::ErrorCapture errors(context->getRef());
+ MlirType t = mlirUnrankedTensorTypeGet(elementType);
+ if (mlirTypeIsNull(t))
+ throw MLIRError("Invalid type", errors.take());
+ return PyUnrankedTensorType(elementType.getContext(), t);
+ },
+ nb::arg("element_type"), nb::arg("context") = nb::none(),
+ "Create a unranked tensor type");
}
};
@@ -790,6 +859,27 @@ class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
nb::arg("shape"), nb::arg("element_type"),
nb::arg("layout") = nb::none(), nb::arg("memory_space") = nb::none(),
nb::arg("loc") = nb::none(), "Create a memref type")
+ .def_static(
+ "get_unchecked",
+ [](std::vector<int64_t> shape, PyType &elementType,
+ PyAttribute *layout, PyAttribute *memorySpace,
+ DefaultingPyMlirContext context) {
+ PyMlirContext::ErrorCapture errors(context->getRef());
+ MlirAttribute layoutAttr =
+ layout ? *layout : mlirAttributeGetNull();
+ MlirAttribute memSpaceAttr =
+ memorySpace ? *memorySpace : mlirAttributeGetNull();
+ MlirType t =
+ mlirMemRefTypeGet(elementType, shape.size(), shape.data(),
+ layoutAttr, memSpaceAttr);
+ if (mlirTypeIsNull(t))
+ throw MLIRError("Invalid type", errors.take());
+ return PyMemRefType(elementType.getContext(), t);
+ },
+ nb::arg("shape"), nb::arg("element_type"),
+ nb::arg("layout") = nb::none(),
+ nb::arg("memory_space") = nb::none(),
+ nb::arg("context") = nb::none(), "Create a memref type")
.def_prop_ro(
"layout",
[](PyMemRefType &self) -> nb::typed<nb::object, PyAttribute> {
@@ -858,6 +948,22 @@ class PyUnrankedMemRefType
},
nb::arg("element_type"), nb::arg("memory_space").none(),
nb::arg("loc") = nb::none(), "Create a unranked memref type")
+ .def_static(
+ "get_unchecked",
+ [](PyType &elementType, PyAttribute *memorySpace,
+ DefaultingPyMlirContext context) {
+ PyMlirContext::ErrorCapture errors(context->getRef());
+ MlirAttribute memSpaceAttr = {};
+ if (memorySpace)
+ memSpaceAttr = *memorySpace;
+
+ MlirType t = mlirUnrankedMemRefTypeGet(elementType, memSpaceAttr);
+ if (mlirTypeIsNull(t))
+ throw MLIRError("Invalid type", errors.take());
+ return PyUnrankedMemRefType(elementType.getContext(), t);
+ },
+ nb::arg("element_type"), nb::arg("memory_space").none(),
+ nb::arg("context") = nb::none(), "Create a unranked memref type")
.def_prop_ro(
"memory_space",
[](PyUnrankedMemRefType &self)
diff --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py
index b42bfd9bc6587..54863253fc770 100644
--- a/mlir/test/python/ir/builtin_types.py
+++ b/mlir/test/python/ir/builtin_types.py
@@ -371,11 +371,16 @@ def testAbstractShapedType():
# CHECK-LABEL: TEST: testVectorType
@run
def testVectorType():
+ shape = [2, 3]
+ with Context():
+ f32 = F32Type.get()
+ # CHECK: unchecked vector type: vector<2x3xf32>
+ print("unchecked vector type:", VectorType.get_unchecked(shape, f32))
+
with Context(), Location.unknown():
f32 = F32Type.get()
- shape = [2, 3]
- # CHECK: vector type: vector<2x3xf32>
- print("vector type:", VectorType.get(shape, f32))
+ # CHECK: checked vector type: vector<2x3xf32>
+ print("checked vector type:", VectorType.get(shape, f32))
none = NoneType.get()
try:
More information about the Mlir-commits
mailing list