[Mlir-commits] [mlir] [MLIR][Python] add unchecked getters (PR #160954)

Maksim Levental llvmlistbot at llvm.org
Fri Sep 26 19:18:25 PDT 2025


https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/160954

>From bfa9fefcafeb8abfa8d21b48ce74c7af4a20464f Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Fri, 26 Sep 2025 14:42:04 -0700
Subject: [PATCH 1/2] [MLIR][Python] rename checked getters and add unchecked
 getters

---
 mlir/lib/Bindings/Python/DialectLLVM.cpp  |  46 ++++++---
 mlir/lib/Bindings/Python/IRAttributes.cpp |  12 +++
 mlir/lib/Bindings/Python/IRTypes.cpp      | 116 +++++++++++++++++++++-
 3 files changed, 154 insertions(+), 20 deletions(-)

diff --git a/mlir/lib/Bindings/Python/DialectLLVM.cpp b/mlir/lib/Bindings/Python/DialectLLVM.cpp
index 55b9331270cdc..38de4a0e329a0 100644
--- a/mlir/lib/Bindings/Python/DialectLLVM.cpp
+++ b/mlir/lib/Bindings/Python/DialectLLVM.cpp
@@ -33,21 +33,37 @@ static void populateDialectLLVMSubmodule(const nanobind::module_ &m) {
   auto llvmStructType =
       mlir_type_subclass(m, "StructType", mlirTypeIsALLVMStructType);
 
-  llvmStructType.def_classmethod(
-      "get_literal",
-      [](const nb::object &cls, const std::vector<MlirType> &elements,
-         bool packed, MlirLocation loc) {
-        CollectDiagnosticsToStringScope scope(mlirLocationGetContext(loc));
-
-        MlirType type = mlirLLVMStructTypeLiteralGetChecked(
-            loc, elements.size(), elements.data(), packed);
-        if (mlirTypeIsNull(type)) {
-          throw nb::value_error(scope.takeMessage().c_str());
-        }
-        return cls(type);
-      },
-      "cls"_a, "elements"_a, nb::kw_only(), "packed"_a = false,
-      "loc"_a = nb::none());
+  llvmStructType
+      .def_classmethod(
+          "get_literal",
+          [](const nb::object &cls, const std::vector<MlirType> &elements,
+             bool packed, MlirLocation loc) {
+            CollectDiagnosticsToStringScope scope(mlirLocationGetContext(loc));
+
+            MlirType type = mlirLLVMStructTypeLiteralGetChecked(
+                loc, elements.size(), elements.data(), packed);
+            if (mlirTypeIsNull(type)) {
+              throw nb::value_error(scope.takeMessage().c_str());
+            }
+            return cls(type);
+          },
+          "cls"_a, "elements"_a, nb::kw_only(), "packed"_a = false,
+          "loc"_a = nb::none())
+      .def_classmethod(
+          "get_literal_unchecked",
+          [](const nb::object &cls, const std::vector<MlirType> &elements,
+             bool packed, MlirContext context) {
+            CollectDiagnosticsToStringScope scope(context);
+
+            MlirType type = mlirLLVMStructTypeLiteralGet(
+                context, elements.size(), elements.data(), packed);
+            if (mlirTypeIsNull(type)) {
+              throw nb::value_error(scope.takeMessage().c_str());
+            }
+            return cls(type);
+          },
+          "cls"_a, "elements"_a, nb::kw_only(), "packed"_a = false,
+          "context"_a = nb::none());
 
   llvmStructType.def_classmethod(
       "get_identified",
diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index c77653f97e6dd..045c0fbf4630f 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -575,6 +575,18 @@ class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> {
         },
         nb::arg("type"), nb::arg("value"), nb::arg("loc") = nb::none(),
         "Gets an uniqued float point attribute associated to a type");
+    c.def_static(
+        "get_unchecked",
+        [](PyType &type, double value, DefaultingPyMlirContext context) {
+          PyMlirContext::ErrorCapture errors(context->getRef());
+          MlirAttribute attr =
+              mlirFloatAttrDoubleGet(context.get()->get(), type, value);
+          if (mlirAttributeIsNull(attr))
+            throw MLIRError("Invalid attribute", errors.take());
+          return PyFloatAttribute(type.getContext(), attr);
+        },
+        nb::arg("type"), nb::arg("value"), nb::arg("context") = nb::none(),
+        "Gets an uniqued float point attribute associated to a type");
     c.def_static(
         "get_f32",
         [](double value, DefaultingPyMlirContext context) {
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index 07dc00521833f..3488d92250b45 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -639,11 +639,16 @@ class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
   using PyConcreteType::PyConcreteType;
 
   static void bindDerived(ClassTy &c) {
-    c.def_static("get", &PyVectorType::get, nb::arg("shape"),
+    c.def_static("get", &PyVectorType::getChecked, nb::arg("shape"),
                  nb::arg("element_type"), nb::kw_only(),
                  nb::arg("scalable") = nb::none(),
                  nb::arg("scalable_dims") = nb::none(),
                  nb::arg("loc") = nb::none(), "Create a vector type")
+        .def_static("get_unchecked", &PyVectorType::get, nb::arg("shape"),
+                    nb::arg("element_type"), nb::kw_only(),
+                    nb::arg("scalable") = nb::none(),
+                    nb::arg("scalable_dims") = nb::none(),
+                    nb::arg("context") = nb::none(), "Create a vector type")
         .def_prop_ro(
             "scalable",
             [](MlirType self) { return mlirVectorTypeIsScalable(self); })
@@ -658,10 +663,11 @@ class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
   }
 
 private:
-  static PyVectorType get(std::vector<int64_t> shape, PyType &elementType,
-                          std::optional<nb::list> scalable,
-                          std::optional<std::vector<int64_t>> scalableDims,
-                          DefaultingPyLocation loc) {
+  static PyVectorType
+  getChecked(std::vector<int64_t> shape, PyType &elementType,
+             std::optional<nb::list> scalable,
+             std::optional<std::vector<int64_t>> scalableDims,
+             DefaultingPyLocation loc) {
     if (scalable && scalableDims) {
       throw nb::value_error("'scalable' and 'scalable_dims' kwargs "
                             "are mutually exclusive.");
@@ -696,6 +702,42 @@ class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
       throw MLIRError("Invalid type", errors.take());
     return PyVectorType(elementType.getContext(), type);
   }
+
+  static PyVectorType get(std::vector<int64_t> shape, PyType &elementType,
+                          std::optional<nb::list> scalable,
+                          std::optional<std::vector<int64_t>> scalableDims,
+                          DefaultingPyMlirContext context) {
+    if (scalable && scalableDims) {
+      throw nb::value_error("'scalable' and 'scalable_dims' kwargs "
+                            "are mutually exclusive.");
+    }
+
+    PyMlirContext::ErrorCapture errors(context->getRef());
+    MlirType type;
+    if (scalable) {
+      if (scalable->size() != shape.size())
+        throw nb::value_error("Expected len(scalable) == len(shape).");
+
+      SmallVector<bool> scalableDimFlags = llvm::to_vector(llvm::map_range(
+          *scalable, [](const nb::handle &h) { return nb::cast<bool>(h); }));
+      type = mlirVectorTypeGetScalable(shape.size(), shape.data(),
+                                       scalableDimFlags.data(), elementType);
+    } else if (scalableDims) {
+      SmallVector<bool> scalableDimFlags(shape.size(), false);
+      for (int64_t dim : *scalableDims) {
+        if (static_cast<size_t>(dim) >= scalableDimFlags.size() || dim < 0)
+          throw nb::value_error("Scalable dimension index out of bounds.");
+        scalableDimFlags[dim] = true;
+      }
+      type = mlirVectorTypeGetScalable(shape.size(), shape.data(),
+                                       scalableDimFlags.data(), elementType);
+    } else {
+      type = mlirVectorTypeGet(shape.size(), shape.data(), elementType);
+    }
+    if (mlirTypeIsNull(type))
+      throw MLIRError("Invalid type", errors.take());
+    return PyVectorType(elementType.getContext(), type);
+  }
 };
 
 /// Ranked Tensor Type subclass - RankedTensorType.
@@ -724,6 +766,22 @@ class PyRankedTensorType
         nb::arg("shape"), nb::arg("element_type"),
         nb::arg("encoding") = nb::none(), nb::arg("loc") = nb::none(),
         "Create a ranked tensor type");
+    c.def_static(
+        "get_unchecked",
+        [](std::vector<int64_t> shape, PyType &elementType,
+           std::optional<PyAttribute> &encodingAttr,
+           DefaultingPyMlirContext context) {
+          PyMlirContext::ErrorCapture errors(context->getRef());
+          MlirType t = mlirRankedTensorTypeGet(
+              shape.size(), shape.data(), elementType,
+              encodingAttr ? encodingAttr->get() : mlirAttributeGetNull());
+          if (mlirTypeIsNull(t))
+            throw MLIRError("Invalid type", errors.take());
+          return PyRankedTensorType(elementType.getContext(), t);
+        },
+        nb::arg("shape"), nb::arg("element_type"),
+        nb::arg("encoding") = nb::none(), nb::arg("context") = nb::none(),
+        "Create a ranked tensor type");
     c.def_prop_ro(
         "encoding",
         [](PyRankedTensorType &self)
@@ -758,6 +816,17 @@ class PyUnrankedTensorType
         },
         nb::arg("element_type"), nb::arg("loc") = nb::none(),
         "Create a unranked tensor type");
+    c.def_static(
+        "get_unchecked",
+        [](PyType &elementType, DefaultingPyMlirContext context) {
+          PyMlirContext::ErrorCapture errors(context->getRef());
+          MlirType t = mlirUnrankedTensorTypeGet(elementType);
+          if (mlirTypeIsNull(t))
+            throw MLIRError("Invalid type", errors.take());
+          return PyUnrankedTensorType(elementType.getContext(), t);
+        },
+        nb::arg("element_type"), nb::arg("context") = nb::none(),
+        "Create a unranked tensor type");
   }
 };
 
@@ -790,6 +859,27 @@ class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
          nb::arg("shape"), nb::arg("element_type"),
          nb::arg("layout") = nb::none(), nb::arg("memory_space") = nb::none(),
          nb::arg("loc") = nb::none(), "Create a memref type")
+        .def_static(
+            "get_unchecked",
+            [](std::vector<int64_t> shape, PyType &elementType,
+               PyAttribute *layout, PyAttribute *memorySpace,
+               DefaultingPyMlirContext context) {
+              PyMlirContext::ErrorCapture errors(context->getRef());
+              MlirAttribute layoutAttr =
+                  layout ? *layout : mlirAttributeGetNull();
+              MlirAttribute memSpaceAttr =
+                  memorySpace ? *memorySpace : mlirAttributeGetNull();
+              MlirType t =
+                  mlirMemRefTypeGet(elementType, shape.size(), shape.data(),
+                                    layoutAttr, memSpaceAttr);
+              if (mlirTypeIsNull(t))
+                throw MLIRError("Invalid type", errors.take());
+              return PyMemRefType(elementType.getContext(), t);
+            },
+            nb::arg("shape"), nb::arg("element_type"),
+            nb::arg("layout") = nb::none(),
+            nb::arg("memory_space") = nb::none(),
+            nb::arg("context") = nb::none(), "Create a memref type")
         .def_prop_ro(
             "layout",
             [](PyMemRefType &self) -> nb::typed<nb::object, PyAttribute> {
@@ -858,6 +948,22 @@ class PyUnrankedMemRefType
          },
          nb::arg("element_type"), nb::arg("memory_space").none(),
          nb::arg("loc") = nb::none(), "Create a unranked memref type")
+        .def_static(
+            "get_unchecked",
+            [](PyType &elementType, PyAttribute *memorySpace,
+               DefaultingPyMlirContext context) {
+              PyMlirContext::ErrorCapture errors(context->getRef());
+              MlirAttribute memSpaceAttr = {};
+              if (memorySpace)
+                memSpaceAttr = *memorySpace;
+
+              MlirType t = mlirUnrankedMemRefTypeGet(elementType, memSpaceAttr);
+              if (mlirTypeIsNull(t))
+                throw MLIRError("Invalid type", errors.take());
+              return PyUnrankedMemRefType(elementType.getContext(), t);
+            },
+            nb::arg("element_type"), nb::arg("memory_space").none(),
+            nb::arg("context") = nb::none(), "Create a unranked memref type")
         .def_prop_ro(
             "memory_space",
             [](PyUnrankedMemRefType &self)

>From bcc8bcad3b7a6dc65b40892fa52d33a2ff813bb8 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Fri, 26 Sep 2025 16:15:13 -0700
Subject: [PATCH 2/2] run ci

---
 mlir/test/python/ir/builtin_types.py | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py
index b42bfd9bc6587..54863253fc770 100644
--- a/mlir/test/python/ir/builtin_types.py
+++ b/mlir/test/python/ir/builtin_types.py
@@ -371,11 +371,16 @@ def testAbstractShapedType():
 # CHECK-LABEL: TEST: testVectorType
 @run
 def testVectorType():
+    shape = [2, 3]
+    with Context():
+        f32 = F32Type.get()
+        # CHECK: unchecked vector type: vector<2x3xf32>
+        print("unchecked vector type:", VectorType.get_unchecked(shape, f32))
+
     with Context(), Location.unknown():
         f32 = F32Type.get()
-        shape = [2, 3]
-        # CHECK: vector type: vector<2x3xf32>
-        print("vector type:", VectorType.get(shape, f32))
+        # CHECK: checked vector type: vector<2x3xf32>
+        print("checked vector type:", VectorType.get(shape, f32))
 
         none = NoneType.get()
         try:



More information about the Mlir-commits mailing list