[Mlir-commits] [mlir] [MLIR][Python] rename checked gettors and add unchecked gettors (PR #160954)
    llvmlistbot at llvm.org 
    llvmlistbot at llvm.org
       
    Fri Sep 26 14:49:12 PDT 2025
    
    
  
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Maksim Levental (makslevental)
<details>
<summary>Changes</summary>
Some of the current gettors required passing locations (i.e., there be an active location) because they're using the "checked" API. This PR renames those gettors (explicitly advertising the checked aspect) and adds "unchecked" gettors which only require an active context.
---
Full diff: https://github.com/llvm/llvm-project/pull/160954.diff
3 Files Affected:
- (modified) mlir/lib/Bindings/Python/DialectLLVM.cpp (+31-15) 
- (modified) mlir/lib/Bindings/Python/IRAttributes.cpp (+13-1) 
- (modified) mlir/lib/Bindings/Python/IRTypes.cpp (+124-8) 
``````````diff
diff --git a/mlir/lib/Bindings/Python/DialectLLVM.cpp b/mlir/lib/Bindings/Python/DialectLLVM.cpp
index 55b9331270cdc..b044965f6ac1a 100644
--- a/mlir/lib/Bindings/Python/DialectLLVM.cpp
+++ b/mlir/lib/Bindings/Python/DialectLLVM.cpp
@@ -33,21 +33,37 @@ static void populateDialectLLVMSubmodule(const nanobind::module_ &m) {
   auto llvmStructType =
       mlir_type_subclass(m, "StructType", mlirTypeIsALLVMStructType);
 
-  llvmStructType.def_classmethod(
-      "get_literal",
-      [](const nb::object &cls, const std::vector<MlirType> &elements,
-         bool packed, MlirLocation loc) {
-        CollectDiagnosticsToStringScope scope(mlirLocationGetContext(loc));
-
-        MlirType type = mlirLLVMStructTypeLiteralGetChecked(
-            loc, elements.size(), elements.data(), packed);
-        if (mlirTypeIsNull(type)) {
-          throw nb::value_error(scope.takeMessage().c_str());
-        }
-        return cls(type);
-      },
-      "cls"_a, "elements"_a, nb::kw_only(), "packed"_a = false,
-      "loc"_a = nb::none());
+  llvmStructType
+      .def_classmethod(
+          "get_literal_checked",
+          [](const nb::object &cls, const std::vector<MlirType> &elements,
+             bool packed, MlirLocation loc) {
+            CollectDiagnosticsToStringScope scope(mlirLocationGetContext(loc));
+
+            MlirType type = mlirLLVMStructTypeLiteralGetChecked(
+                loc, elements.size(), elements.data(), packed);
+            if (mlirTypeIsNull(type)) {
+              throw nb::value_error(scope.takeMessage().c_str());
+            }
+            return cls(type);
+          },
+          "cls"_a, "elements"_a, nb::kw_only(), "packed"_a = false,
+          "loc"_a = nb::none())
+      .def_classmethod(
+          "get_literal",
+          [](const nb::object &cls, const std::vector<MlirType> &elements,
+             bool packed, MlirContext context) {
+            CollectDiagnosticsToStringScope scope(context);
+
+            MlirType type = mlirLLVMStructTypeLiteralGet(
+                context, elements.size(), elements.data(), packed);
+            if (mlirTypeIsNull(type)) {
+              throw nb::value_error(scope.takeMessage().c_str());
+            }
+            return cls(type);
+          },
+          "cls"_a, "elements"_a, nb::kw_only(), "packed"_a = false,
+          "context"_a = nb::none());
 
   llvmStructType.def_classmethod(
       "get_identified",
diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index c77653f97e6dd..24e92ffffe8ae 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -565,7 +565,7 @@ class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> {
 
   static void bindDerived(ClassTy &c) {
     c.def_static(
-        "get",
+        "get_checked",
         [](PyType &type, double value, DefaultingPyLocation loc) {
           PyMlirContext::ErrorCapture errors(loc->getContext());
           MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value);
@@ -575,6 +575,18 @@ class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> {
         },
         nb::arg("type"), nb::arg("value"), nb::arg("loc") = nb::none(),
         "Gets an uniqued float point attribute associated to a type");
+    c.def_static(
+        "get",
+        [](PyType &type, double value, DefaultingPyMlirContext context) {
+          PyMlirContext::ErrorCapture errors(context->getRef());
+          MlirAttribute attr =
+              mlirFloatAttrDoubleGet(context.get()->get(), type, value);
+          if (mlirAttributeIsNull(attr))
+            throw MLIRError("Invalid attribute", errors.take());
+          return PyFloatAttribute(type.getContext(), attr);
+        },
+        nb::arg("type"), nb::arg("value"), nb::arg("context") = nb::none(),
+        "Gets an uniqued float point attribute associated to a type");
     c.def_static(
         "get_f32",
         [](double value, DefaultingPyMlirContext context) {
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index 07dc00521833f..0238a24708962 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -643,7 +643,12 @@ class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
                  nb::arg("element_type"), nb::kw_only(),
                  nb::arg("scalable") = nb::none(),
                  nb::arg("scalable_dims") = nb::none(),
-                 nb::arg("loc") = nb::none(), "Create a vector type")
+                 nb::arg("context") = nb::none(), "Create a vector type")
+        .def_static("get", &PyVectorType::getChecked, nb::arg("shape"),
+                    nb::arg("element_type"), nb::kw_only(),
+                    nb::arg("scalable") = nb::none(),
+                    nb::arg("scalable_dims") = nb::none(),
+                    nb::arg("loc") = nb::none(), "Create a vector type")
         .def_prop_ro(
             "scalable",
             [](MlirType self) { return mlirVectorTypeIsScalable(self); })
@@ -658,10 +663,11 @@ class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
   }
 
 private:
-  static PyVectorType get(std::vector<int64_t> shape, PyType &elementType,
-                          std::optional<nb::list> scalable,
-                          std::optional<std::vector<int64_t>> scalableDims,
-                          DefaultingPyLocation loc) {
+  static PyVectorType
+  getChecked(std::vector<int64_t> shape, PyType &elementType,
+             std::optional<nb::list> scalable,
+             std::optional<std::vector<int64_t>> scalableDims,
+             DefaultingPyLocation loc) {
     if (scalable && scalableDims) {
       throw nb::value_error("'scalable' and 'scalable_dims' kwargs "
                             "are mutually exclusive.");
@@ -696,6 +702,42 @@ class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
       throw MLIRError("Invalid type", errors.take());
     return PyVectorType(elementType.getContext(), type);
   }
+
+  static PyVectorType get(std::vector<int64_t> shape, PyType &elementType,
+                          std::optional<nb::list> scalable,
+                          std::optional<std::vector<int64_t>> scalableDims,
+                          DefaultingPyMlirContext context) {
+    if (scalable && scalableDims) {
+      throw nb::value_error("'scalable' and 'scalable_dims' kwargs "
+                            "are mutually exclusive.");
+    }
+
+    PyMlirContext::ErrorCapture errors(context->getRef());
+    MlirType type;
+    if (scalable) {
+      if (scalable->size() != shape.size())
+        throw nb::value_error("Expected len(scalable) == len(shape).");
+
+      SmallVector<bool> scalableDimFlags = llvm::to_vector(llvm::map_range(
+          *scalable, [](const nb::handle &h) { return nb::cast<bool>(h); }));
+      type = mlirVectorTypeGetScalable(shape.size(), shape.data(),
+                                       scalableDimFlags.data(), elementType);
+    } else if (scalableDims) {
+      SmallVector<bool> scalableDimFlags(shape.size(), false);
+      for (int64_t dim : *scalableDims) {
+        if (static_cast<size_t>(dim) >= scalableDimFlags.size() || dim < 0)
+          throw nb::value_error("Scalable dimension index out of bounds.");
+        scalableDimFlags[dim] = true;
+      }
+      type = mlirVectorTypeGetScalable(shape.size(), shape.data(),
+                                       scalableDimFlags.data(), elementType);
+    } else {
+      type = mlirVectorTypeGet(shape.size(), shape.data(), elementType);
+    }
+    if (mlirTypeIsNull(type))
+      throw MLIRError("Invalid type", errors.take());
+    return PyVectorType(elementType.getContext(), type);
+  }
 };
 
 /// Ranked Tensor Type subclass - RankedTensorType.
@@ -710,7 +752,7 @@ class PyRankedTensorType
 
   static void bindDerived(ClassTy &c) {
     c.def_static(
-        "get",
+        "get_checked",
         [](std::vector<int64_t> shape, PyType &elementType,
            std::optional<PyAttribute> &encodingAttr, DefaultingPyLocation loc) {
           PyMlirContext::ErrorCapture errors(loc->getContext());
@@ -724,6 +766,22 @@ class PyRankedTensorType
         nb::arg("shape"), nb::arg("element_type"),
         nb::arg("encoding") = nb::none(), nb::arg("loc") = nb::none(),
         "Create a ranked tensor type");
+    c.def_static(
+        "get",
+        [](std::vector<int64_t> shape, PyType &elementType,
+           std::optional<PyAttribute> &encodingAttr,
+           DefaultingPyMlirContext context) {
+          PyMlirContext::ErrorCapture errors(context->getRef());
+          MlirType t = mlirRankedTensorTypeGet(
+              shape.size(), shape.data(), elementType,
+              encodingAttr ? encodingAttr->get() : mlirAttributeGetNull());
+          if (mlirTypeIsNull(t))
+            throw MLIRError("Invalid type", errors.take());
+          return PyRankedTensorType(elementType.getContext(), t);
+        },
+        nb::arg("shape"), nb::arg("element_type"),
+        nb::arg("encoding") = nb::none(), nb::arg("context") = nb::none(),
+        "Create a ranked tensor type");
     c.def_prop_ro(
         "encoding",
         [](PyRankedTensorType &self)
@@ -758,6 +816,17 @@ class PyUnrankedTensorType
         },
         nb::arg("element_type"), nb::arg("loc") = nb::none(),
         "Create a unranked tensor type");
+    c.def_static(
+        "get_checked",
+        [](PyType &elementType, DefaultingPyMlirContext context) {
+          PyMlirContext::ErrorCapture errors(context->getRef());
+          MlirType t = mlirUnrankedTensorTypeGet(elementType);
+          if (mlirTypeIsNull(t))
+            throw MLIRError("Invalid type", errors.take());
+          return PyUnrankedTensorType(elementType.getContext(), t);
+        },
+        nb::arg("element_type"), nb::arg("context") = nb::none(),
+        "Create a unranked tensor type");
   }
 };
 
@@ -772,7 +841,7 @@ class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
 
   static void bindDerived(ClassTy &c) {
     c.def_static(
-         "get",
+         "get_checked",
          [](std::vector<int64_t> shape, PyType &elementType,
             PyAttribute *layout, PyAttribute *memorySpace,
             DefaultingPyLocation loc) {
@@ -790,6 +859,27 @@ class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
          nb::arg("shape"), nb::arg("element_type"),
          nb::arg("layout") = nb::none(), nb::arg("memory_space") = nb::none(),
          nb::arg("loc") = nb::none(), "Create a memref type")
+        .def_static(
+            "get",
+            [](std::vector<int64_t> shape, PyType &elementType,
+               PyAttribute *layout, PyAttribute *memorySpace,
+               DefaultingPyMlirContext context) {
+              PyMlirContext::ErrorCapture errors(context->getRef());
+              MlirAttribute layoutAttr =
+                  layout ? *layout : mlirAttributeGetNull();
+              MlirAttribute memSpaceAttr =
+                  memorySpace ? *memorySpace : mlirAttributeGetNull();
+              MlirType t =
+                  mlirMemRefTypeGet(elementType, shape.size(), shape.data(),
+                                    layoutAttr, memSpaceAttr);
+              if (mlirTypeIsNull(t))
+                throw MLIRError("Invalid type", errors.take());
+              return PyMemRefType(elementType.getContext(), t);
+            },
+            nb::arg("shape"), nb::arg("element_type"),
+            nb::arg("layout") = nb::none(),
+            nb::arg("memory_space") = nb::none(),
+            nb::arg("context") = nb::none(), "Create a memref type")
         .def_prop_ro(
             "layout",
             [](PyMemRefType &self) -> nb::typed<nb::object, PyAttribute> {
@@ -842,7 +932,7 @@ class PyUnrankedMemRefType
 
   static void bindDerived(ClassTy &c) {
     c.def_static(
-         "get",
+         "get_checked",
          [](PyType &elementType, PyAttribute *memorySpace,
             DefaultingPyLocation loc) {
            PyMlirContext::ErrorCapture errors(loc->getContext());
@@ -858,6 +948,32 @@ class PyUnrankedMemRefType
          },
          nb::arg("element_type"), nb::arg("memory_space").none(),
          nb::arg("loc") = nb::none(), "Create a unranked memref type")
+        .def_prop_ro(
+            "memory_space",
+            [](PyUnrankedMemRefType &self)
+                -> std::optional<nb::typed<nb::object, PyAttribute>> {
+              MlirAttribute a = mlirUnrankedMemrefGetMemorySpace(self);
+              if (mlirAttributeIsNull(a))
+                return std::nullopt;
+              return PyAttribute(self.getContext(), a).maybeDownCast();
+            },
+            "Returns the memory space of the given Unranked MemRef type.")
+        .def_static(
+            "get",
+            [](PyType &elementType, PyAttribute *memorySpace,
+               DefaultingPyMlirContext context) {
+              PyMlirContext::ErrorCapture errors(context->getRef());
+              MlirAttribute memSpaceAttr = {};
+              if (memorySpace)
+                memSpaceAttr = *memorySpace;
+
+              MlirType t = mlirUnrankedMemRefTypeGet(elementType, memSpaceAttr);
+              if (mlirTypeIsNull(t))
+                throw MLIRError("Invalid type", errors.take());
+              return PyUnrankedMemRefType(elementType.getContext(), t);
+            },
+            nb::arg("element_type"), nb::arg("memory_space").none(),
+            nb::arg("context") = nb::none(), "Create a unranked memref type")
         .def_prop_ro(
             "memory_space",
             [](PyUnrankedMemRefType &self)
``````````
</details>
https://github.com/llvm/llvm-project/pull/160954
    
    
More information about the Mlir-commits
mailing list