[Mlir-commits] [mlir] [MLIR] [Python] Relaxed `list` to `Sequence` in most parameter types (PR #188543)

Sergei Lebedev llvmlistbot at llvm.org
Wed Mar 25 09:59:03 PDT 2026


https://github.com/superbobry created https://github.com/llvm/llvm-project/pull/188543

Using `Sequence` frees users from having to convert arguments to `list` in cases where the underlying API does not care about the type of the container (e.g. a `tuple` can now be passed directly).

Note that accepting an `nb::sequence` is marginally slower than accepting `nb::list` directly, because `__getitem__`, `__len__`, etc. need to go through an extra layer of indirection (the generic sequence protocol rather than list-specific fast paths). However, I expect the performance difference to be negligible.

>From 0fd59df89c4ab982c974dd8a7de1236b8dd6dc40 Mon Sep 17 00:00:00 2001
From: Sergei Lebedev <slebedev at google.com>
Date: Wed, 25 Mar 2026 16:53:12 +0000
Subject: [PATCH] [MLIR] [Python] Relaxed `list` to `Sequence` in most
 parameter types

Using `Sequence` frees users from having to convert arguments to `list` in cases
where the underlying API does not care about the type of the container.

Note that accepting an `nb::sequence` is marginally slower than accepting
`nb::list` directly, because `__getitem__`, `__len__`, etc. need to go through
an extra layer of indirection. However, I expect the performance difference
to be negligible.
---
 .../mlir/Bindings/Python/IRAttributes.h       | 10 ++---
 mlir/include/mlir/Bindings/Python/IRCore.h    | 11 ++---
 mlir/include/mlir/Bindings/Python/IRTypes.h   |  4 +-
 mlir/lib/Bindings/Python/DialectTransform.cpp | 15 ++++---
 mlir/lib/Bindings/Python/IRAffine.cpp         | 26 +++++------
 mlir/lib/Bindings/Python/IRAttributes.cpp     | 15 ++++---
 mlir/lib/Bindings/Python/IRCore.cpp           | 44 +++++++++++--------
 mlir/lib/Bindings/Python/IRInterfaces.cpp     | 17 +++----
 mlir/lib/Bindings/Python/IRTypes.cpp          | 12 ++---
 9 files changed, 83 insertions(+), 71 deletions(-)

diff --git a/mlir/include/mlir/Bindings/Python/IRAttributes.h b/mlir/include/mlir/Bindings/Python/IRAttributes.h
index 28a7134dc2d45..7f2f88b552453 100644
--- a/mlir/include/mlir/Bindings/Python/IRAttributes.h
+++ b/mlir/include/mlir/Bindings/Python/IRAttributes.h
@@ -193,7 +193,7 @@ class MLIR_PYTHON_API_EXPORTED PyDenseArrayAttribute
     });
     c.def("__iter__",
           [](const DerivedT &arr) { return PyDenseArrayIterator(arr); });
-    c.def("__add__", [](DerivedT &arr, const nanobind::list &extras) {
+    c.def("__add__", [](DerivedT &arr, const nanobind::sequence &extras) {
       std::vector<EltTy> values;
       intptr_t numOldElements = mlirDenseArrayGetNumElements(arr);
       values.reserve(numOldElements + nanobind::len(extras));
@@ -402,10 +402,10 @@ class MLIR_PYTHON_API_EXPORTED PyDenseElementsAttribute
   static constexpr const char *pyClassName = "DenseElementsAttr";
   using PyConcreteAttribute::PyConcreteAttribute;
 
-  static PyDenseElementsAttribute
-  getFromList(const nanobind::list &attributes,
-              std::optional<PyType> explicitType,
-              DefaultingPyMlirContext contextWrapper);
+  static PyDenseElementsAttribute getFromList(
+      const nanobind::typed<nanobind::sequence, PyAttribute> &attributes,
+      std::optional<PyType> explicitType,
+      DefaultingPyMlirContext contextWrapper);
 
   static PyDenseElementsAttribute
   getFromBuffer(const nb_buffer &array, bool signless,
diff --git a/mlir/include/mlir/Bindings/Python/IRCore.h b/mlir/include/mlir/Bindings/Python/IRCore.h
index e834fca61522f..db8427cfc4f78 100644
--- a/mlir/include/mlir/Bindings/Python/IRCore.h
+++ b/mlir/include/mlir/Bindings/Python/IRCore.h
@@ -373,7 +373,7 @@ class MLIR_PYTHON_API_EXPORTED PyDiagnostic {
   PyDiagnosticSeverity getSeverity();
   PyLocation getLocation();
   nanobind::str getMessage();
-  nanobind::tuple getNotes();
+  nanobind::typed<nanobind::tuple, PyDiagnostic> getNotes();
 
   /// Materialized diagnostic information. This is safe to access outside the
   /// diagnostic callback.
@@ -755,8 +755,8 @@ class MLIR_PYTHON_API_EXPORTED PyOpView : public PyOperationBase {
   buildGeneric(std::string_view name, std::tuple<int, bool> opRegionSpec,
                nanobind::object operandSegmentSpecObj,
                nanobind::object resultSegmentSpecObj,
-               std::optional<nanobind::list> resultTypeList,
-               nanobind::list operandList,
+               std::optional<nanobind::sequence> resultTypeList,
+               nanobind::sequence operandList,
                std::optional<nanobind::dict> attributes,
                std::optional<std::vector<PyBlock *>> successors,
                std::optional<int> regions, PyLocation &location,
@@ -1360,8 +1360,9 @@ inline MlirStringRef toMlirStringRef(const nanobind::bytes &s) {
 /// Create a block, using the current location context if no locations are
 /// specified.
 MlirBlock MLIR_PYTHON_API_EXPORTED
-createBlock(const nanobind::sequence &pyArgTypes,
-            const std::optional<nanobind::sequence> &pyArgLocs);
+createBlock(const nanobind::typed<nanobind::sequence, PyType> &pyArgTypes,
+            const std::optional<nanobind::typed<nanobind::sequence, PyLocation>>
+                &pyArgLocs);
 
 struct MLIR_PYTHON_API_EXPORTED PyAttrBuilderMap {
   static bool dunderContains(const std::string &attributeKind);
diff --git a/mlir/include/mlir/Bindings/Python/IRTypes.h b/mlir/include/mlir/Bindings/Python/IRTypes.h
index 8e81ee9805200..c84ed456de301 100644
--- a/mlir/include/mlir/Bindings/Python/IRTypes.h
+++ b/mlir/include/mlir/Bindings/Python/IRTypes.h
@@ -338,12 +338,12 @@ class MLIR_PYTHON_API_EXPORTED PyVectorType
 private:
   static PyVectorType
   getChecked(std::vector<int64_t> shape, PyType &elementType,
-             std::optional<nanobind::list> scalable,
+             std::optional<nanobind::sequence> scalable,
              std::optional<std::vector<int64_t>> scalableDims,
              DefaultingPyLocation loc);
 
   static PyVectorType get(std::vector<int64_t> shape, PyType &elementType,
-                          std::optional<nanobind::list> scalable,
+                          std::optional<nanobind::sequence> scalable,
                           std::optional<std::vector<int64_t>> scalableDims,
                           DefaultingPyMlirContext context);
 };
diff --git a/mlir/lib/Bindings/Python/DialectTransform.cpp b/mlir/lib/Bindings/Python/DialectTransform.cpp
index 3b1b9abffb0f9..bd72082cea7e8 100644
--- a/mlir/lib/Bindings/Python/DialectTransform.cpp
+++ b/mlir/lib/Bindings/Python/DialectTransform.cpp
@@ -46,18 +46,20 @@ class PyTransformResults {
 
   MlirTransformResults get() const { return results; }
 
-  void setOps(PyValue &result, const nb::list &ops) {
+  void setOps(PyValue &result,
+              const nb::typed<nb::sequence, PyOperationBase> &ops) {
     std::vector<MlirOperation> opsVec;
-    opsVec.reserve(ops.size());
+    opsVec.reserve(nb::len(ops));
     for (auto op : ops) {
       opsVec.push_back(nb::cast<MlirOperation>(op));
     }
     mlirTransformResultsSetOps(results, result, opsVec.size(), opsVec.data());
   }
 
-  void setValues(PyValue &result, const nb::list &values) {
+  void setValues(PyValue &result,
+                 const nb::typed<nb::sequence, PyValue> &values) {
     std::vector<MlirValue> valuesVec;
-    valuesVec.reserve(values.size());
+    valuesVec.reserve(nb::len(values));
     for (auto item : values) {
       valuesVec.push_back(nb::cast<MlirValue>(item));
     }
@@ -65,9 +67,10 @@ class PyTransformResults {
                                   valuesVec.data());
   }
 
-  void setParams(PyValue &result, const nb::list &params) {
+  void setParams(PyValue &result,
+                 const nb::typed<nb::sequence, PyAttribute> &params) {
     std::vector<MlirAttribute> paramsVec;
-    paramsVec.reserve(params.size());
+    paramsVec.reserve(nb::len(params));
     for (auto item : params) {
       paramsVec.push_back(nb::cast<MlirAttribute>(item));
     }
diff --git a/mlir/lib/Bindings/Python/IRAffine.cpp b/mlir/lib/Bindings/Python/IRAffine.cpp
index 2ec13f10df380..07aa4759bae5f 100644
--- a/mlir/lib/Bindings/Python/IRAffine.cpp
+++ b/mlir/lib/Bindings/Python/IRAffine.cpp
@@ -37,7 +37,7 @@ static const char kDumpDocstring[] =
 /// Throws errors in case of failure, using "action" to describe what the caller
 /// was attempting to do.
 template <typename PyType, typename CType>
-static void pyListToVector(const nb::list &list, std::vector<CType> &result,
+static void pyListToVector(const nb::sequence &list, std::vector<CType> &result,
                            std::string_view action) {
   result.reserve(nb::len(list));
   for (nb::handle item : list) {
@@ -733,12 +733,12 @@ void populateIRAffine(nb::module_ &m) {
            })
       .def_static(
           "compress_unused_symbols",
-          [](nb::typed<nb::list, PyAffineMap> affineMaps,
+          [](nb::typed<nb::sequence, PyAffineMap> affineMaps,
              DefaultingPyMlirContext context) {
             std::vector<MlirAffineMap> maps;
             pyListToVector<PyAffineMap, MlirAffineMap>(
                 affineMaps, maps, "attempting to create an AffineMap");
-            std::vector<MlirAffineMap> compressed(affineMaps.size());
+            std::vector<MlirAffineMap> compressed(nb::len(affineMaps));
             auto populate = [](void *result, intptr_t idx, MlirAffineMap m) {
               static_cast<MlirAffineMap *>(result)[idx] = (m);
             };
@@ -762,7 +762,7 @@ void populateIRAffine(nb::module_ &m) {
       .def_static(
           "get",
           [](intptr_t dimCount, intptr_t symbolCount,
-             nb::typed<nb::list, PyAffineExpr> exprs,
+             nb::typed<nb::sequence, PyAffineExpr> exprs,
              DefaultingPyMlirContext context) {
             std::vector<MlirAffineExpr> affineExprs;
             pyListToVector<PyAffineExpr, MlirAffineExpr>(
@@ -930,13 +930,13 @@ void populateIRAffine(nb::module_ &m) {
       .def_static(
           "get",
           [](intptr_t numDims, intptr_t numSymbols,
-             nb::typed<nb::list, PyAffineExpr> exprs, std::vector<bool> eqFlags,
-             DefaultingPyMlirContext context) {
-            if (exprs.size() != eqFlags.size())
+             nb::typed<nb::sequence, PyAffineExpr> exprs,
+             std::vector<bool> eqFlags, DefaultingPyMlirContext context) {
+            if (nb::len(exprs) != eqFlags.size())
               throw nb::value_error(
                   "Expected the number of constraints to match "
                   "that of equality flags");
-            if (exprs.size() == 0)
+            if (nb::len(exprs) == 0)
               throw nb::value_error("Expected non-empty list of constraints");
 
             // std::vector<bool> does not expose a bool* data pointer.
@@ -945,7 +945,7 @@ void populateIRAffine(nb::module_ &m) {
             pyListToVector<PyAffineExpr>(exprs, affineExprs,
                                          "attempting to create an IntegerSet");
             MlirIntegerSet set = mlirIntegerSetGet(
-                context->get(), numDims, numSymbols, exprs.size(),
+                context->get(), numDims, numSymbols, nb::len(exprs),
                 affineExprs.data(), reinterpret_cast<bool *>(flags.data()));
             return PyIntegerSet(context->getRef(), set);
           },
@@ -963,15 +963,15 @@ void populateIRAffine(nb::module_ &m) {
           nb::arg("context") = nb::none())
       .def(
           "get_replaced",
-          [](PyIntegerSet &self, nb::typed<nb::list, PyAffineExpr> dimExprs,
-             nb::typed<nb::list, PyAffineExpr> symbolExprs,
+          [](PyIntegerSet &self, nb::typed<nb::sequence, PyAffineExpr> dimExprs,
+             nb::typed<nb::sequence, PyAffineExpr> symbolExprs,
              intptr_t numResultDims, intptr_t numResultSymbols) {
-            if (static_cast<intptr_t>(dimExprs.size()) !=
+            if (static_cast<intptr_t>(nb::len(dimExprs)) !=
                 mlirIntegerSetGetNumDims(self))
               throw nb::value_error(
                   "Expected the number of dimension replacement expressions "
                   "to match that of dimensions");
-            if (static_cast<intptr_t>(symbolExprs.size()) !=
+            if (static_cast<intptr_t>(nb::len(symbolExprs)) !=
                 mlirIntegerSetGetNumSymbols(self))
               throw nb::value_error(
                   "Expected the number of symbol replacement expressions "
diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index 5aebfabf5bc18..7fada5bbc8502 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -280,7 +280,7 @@ MlirAttribute PyArrayAttribute::getItem(intptr_t i) const {
 void PyArrayAttribute::bindDerived(ClassTy &c) {
   c.def_static(
       "get",
-      [](nb::typed<nb::list, PyAttribute> attributes,
+      [](nb::typed<nb::sequence, PyAttribute> attributes,
          DefaultingPyMlirContext context) {
         std::vector<MlirAttribute> mlirAttributes;
         mlirAttributes.reserve(nb::len(attributes));
@@ -308,7 +308,7 @@ void PyArrayAttribute::bindDerived(ClassTy &c) {
         return PyArrayAttributeIterator(arr);
       });
   c.def("__add__", [](PyArrayAttribute arr,
-                      nb::typed<nb::list, PyAttribute> extras) {
+                      nb::typed<nb::sequence, PyAttribute> extras) {
     std::vector<MlirAttribute> attributes;
     intptr_t numOldElements = mlirArrayAttrGetNumElements(arr);
     attributes.reserve(numOldElements + nb::len(extras));
@@ -579,10 +579,10 @@ void PyOpaqueAttribute::bindDerived(ClassTy &c) {
       "Returns the data for the Opaqued attributes as `bytes`");
 }
 
-PyDenseElementsAttribute
-PyDenseElementsAttribute::getFromList(const nb::list &attributes,
-                                      std::optional<PyType> explicitType,
-                                      DefaultingPyMlirContext contextWrapper) {
+PyDenseElementsAttribute PyDenseElementsAttribute::getFromList(
+    const nb::typed<nb::sequence, PyAttribute> &attributes,
+    std::optional<PyType> explicitType,
+    DefaultingPyMlirContext contextWrapper) {
   const size_t numAttributes = nb::len(attributes);
   if (numAttributes == 0)
     throw nb::value_error("Attributes list must be non-empty.");
@@ -1175,7 +1175,8 @@ void PyDictAttribute::bindDerived(ClassTy &c) {
   c.def("__len__", &PyDictAttribute::dunderLen);
   c.def_static(
       "get",
-      [](const nb::dict &attributes, DefaultingPyMlirContext context) {
+      [](const nb::typed<nb::dict, nb::str, PyAttribute> &attributes,
+         DefaultingPyMlirContext context) {
         std::vector<MlirNamedAttribute> mlirNamedAttributes;
         mlirNamedAttributes.reserve(attributes.size());
         for (std::pair<nb::handle, nb::handle> it : attributes) {
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 7e7e62f510bcd..350eb52765e1a 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -76,8 +76,9 @@ namespace mlir {
 namespace python {
 namespace MLIR_BINDINGS_PYTHON_DOMAIN {
 
-MlirBlock createBlock(const nb::sequence &pyArgTypes,
-                      const std::optional<nb::sequence> &pyArgLocs) {
+MlirBlock createBlock(
+    const nb::typed<nb::sequence, PyType> &pyArgTypes,
+    const std::optional<nb::typed<nb::sequence, PyLocation>> &pyArgLocs) {
   std::vector<MlirType> argTypes;
   argTypes.reserve(nb::len(pyArgTypes));
   for (nb::handle pyType : pyArgTypes)
@@ -767,7 +768,7 @@ nb::str PyDiagnostic::getMessage() {
   return nb::cast<nb::str>(fileObject.attr("getvalue")());
 }
 
-nb::tuple PyDiagnostic::getNotes() {
+nb::typed<nb::tuple, PyDiagnostic> PyDiagnostic::getNotes() {
   checkValid();
   if (materializedNotes)
     return *materializedNotes;
@@ -1442,11 +1443,12 @@ PyOpResultList PyOpResultList::slice(intptr_t startIndex, intptr_t length,
 // PyOpView
 //------------------------------------------------------------------------------
 
-static void populateResultTypes(std::string_view name, nb::list resultTypeList,
+static void populateResultTypes(std::string_view name,
+                                nb::sequence resultTypeList,
                                 const nb::object &resultSegmentSpecObj,
                                 std::vector<int32_t> &resultSegmentLengths,
                                 std::vector<PyType *> &resultTypes) {
-  resultTypes.reserve(resultTypeList.size());
+  resultTypes.reserve(nb::len(resultTypeList));
   if (resultSegmentSpecObj.is_none()) {
     // Non-variadic result unpacking.
     size_t index = 0;
@@ -1465,13 +1467,13 @@ static void populateResultTypes(std::string_view name, nb::list resultTypeList,
   } else {
     // Sized result unpacking.
     auto resultSegmentSpec = nb::cast<std::vector<int>>(resultSegmentSpecObj);
-    if (resultSegmentSpec.size() != resultTypeList.size()) {
+    if (resultSegmentSpec.size() != nb::len(resultTypeList)) {
       throw nb::value_error(
           join("Operation \"", name, "\" requires ", resultSegmentSpec.size(),
-               " result segments but was provided ", resultTypeList.size())
+               " result segments but was provided ", nb::len(resultTypeList))
               .c_str());
     }
-    resultSegmentLengths.reserve(resultTypeList.size());
+    resultSegmentLengths.reserve(nb::len(resultTypeList));
     for (size_t i = 0, e = resultSegmentSpec.size(); i < e; ++i) {
       int segmentSpec = resultSegmentSpec[i];
       if (segmentSpec == 1 || segmentSpec == 0) {
@@ -1565,7 +1567,7 @@ static MlirValue getOpResultOrValue(nb::handle operand) {
 nb::typed<nb::object, PyOperation> PyOpView::buildGeneric(
     std::string_view name, std::tuple<int, bool> opRegionSpec,
     nb::object operandSegmentSpecObj, nb::object resultSegmentSpecObj,
-    std::optional<nb::list> resultTypeList, nb::list operandList,
+    std::optional<nb::sequence> resultTypeList, nb::sequence operandList,
     std::optional<nb::dict> attributes,
     std::optional<std::vector<PyBlock *>> successors,
     std::optional<int> regions, PyLocation &location,
@@ -1626,13 +1628,13 @@ nb::typed<nb::object, PyOperation> PyOpView::buildGeneric(
   } else {
     // Sized operand unpacking.
     auto operandSegmentSpec = nb::cast<std::vector<int>>(operandSegmentSpecObj);
-    if (operandSegmentSpec.size() != operandList.size()) {
+    if (operandSegmentSpec.size() != nb::len(operandList)) {
       throw nb::value_error(
           join("Operation \"", name, "\" requires ", operandSegmentSpec.size(),
-               "operand segments but was provided ", operandList.size())
+               "operand segments but was provided ", nb::len(operandList))
               .c_str());
     }
-    operandSegmentLengths.reserve(operandList.size());
+    operandSegmentLengths.reserve(nb::len(operandList));
     for (size_t i = 0, e = operandSegmentSpec.size(); i < e; ++i) {
       int segmentSpec = operandSegmentSpec[i];
       if (segmentSpec == 1 || segmentSpec == 0) {
@@ -2513,10 +2515,10 @@ void PyOpAttributeMap::bind(nb::module_ &m) {
 
 void PyOpAdaptor::bind(nb::module_ &m) {
   nb::class_<PyOpAdaptor>(m, "OpAdaptor")
-      .def(nb::init<nb::list, PyOpAttributeMap>(),
+      .def(nb::init<nb::typed<nb::list, PyValue>, PyOpAttributeMap>(),
            "Creates an OpAdaptor with the given operands and attributes.",
            "operands"_a, "attributes"_a)
-      .def(nb::init<nb::list, PyOpView &>(),
+      .def(nb::init<nb::typed<nb::list, PyValue>, PyOpView &>(),
            "Creates an OpAdaptor with the given operands and operation view.",
            "operands"_a, "opview"_a)
       .def_prop_ro(
@@ -3944,7 +3946,8 @@ void populateIRCore(nb::module_ &m) {
           [](std::string_view name,
              std::optional<std::vector<PyType *>> results,
              std::optional<std::vector<PyValue *>> operands,
-             std::optional<nb::dict> attributes,
+             std::optional<nb::typed<nb::dict, nb::str, PyAttribute>>
+                 attributes,
              std::optional<std::vector<PyBlock *>> successors, int regions,
              const std::optional<PyLocation> &location,
              const nb::object &maybeIp,
@@ -4046,8 +4049,10 @@ void populateIRCore(nb::module_ &m) {
                  std::tuple<int, bool> opRegionSpec,
                  nb::object operandSegmentSpecObj,
                  nb::object resultSegmentSpecObj,
-                 std::optional<nb::list> resultTypeList, nb::list operandList,
-                 std::optional<nb::dict> attributes,
+                 std::optional<nb::sequence> resultTypeList,
+                 nb::sequence operandList,
+                 std::optional<nb::typed<nb::dict, nb::str, PyAttribute>>
+                     attributes,
                  std::optional<std::vector<PyBlock *>> successors,
                  std::optional<int> regions,
                  const std::optional<PyLocation> &location,
@@ -4093,8 +4098,9 @@ void populateIRCore(nb::module_ &m) {
   // ods_operand_segments/ods_result_segments as arguments to the constructor,
   // rather than to access them as attributes.
   opViewClass.attr("build_generic") = classmethod(
-      [](nb::handle cls, std::optional<nb::list> resultTypeList,
-         nb::list operandList, std::optional<nb::dict> attributes,
+      [](nb::handle cls, std::optional<nb::sequence> resultTypeList,
+         nb::sequence operandList,
+         std::optional<nb::typed<nb::dict, nb::str, PyAttribute>> attributes,
          std::optional<std::vector<PyBlock *>> successors,
          std::optional<int> regions, std::optional<PyLocation> location,
          const nb::object &maybeIp) {
diff --git a/mlir/lib/Bindings/Python/IRInterfaces.cpp b/mlir/lib/Bindings/Python/IRInterfaces.cpp
index e7865cfda9d6f..561316d863f75 100644
--- a/mlir/lib/Bindings/Python/IRInterfaces.cpp
+++ b/mlir/lib/Bindings/Python/IRInterfaces.cpp
@@ -36,16 +36,16 @@ namespace {
 
 /// Takes in an optional ist of operands and converts them into a std::vector
 /// of MlirVlaues. Returns an empty std::vector if the list is empty.
-std::vector<MlirValue> wrapOperands(std::optional<nb::list> operandList) {
+std::vector<MlirValue> wrapOperands(std::optional<nb::sequence> operandList) {
   std::vector<MlirValue> mlirOperands;
 
-  if (!operandList || operandList->size() == 0) {
+  if (!operandList || nb::len(*operandList) == 0) {
     return mlirOperands;
   }
 
   // Note: as the list may contain other lists this may not be final size.
-  mlirOperands.reserve(operandList->size());
-  for (size_t i = 0, e = operandList->size(); i < e; ++i) {
+  mlirOperands.reserve(nb::len(*operandList));
+  for (size_t i = 0, e = nb::len(*operandList); i < e; ++i) {
     nb::handle operand = (*operandList)[i];
     intptr_t index = static_cast<intptr_t>(i);
     if (operand.is_none())
@@ -143,7 +143,7 @@ class PyInferTypeOpInterface
   /// Given the arguments required to build an operation, attempts to infer its
   /// return types. Throws value_error on failure.
   std::vector<PyType>
-  inferReturnTypes(std::optional<nb::list> operandList,
+  inferReturnTypes(std::optional<nb::sequence> operandList,
                    std::optional<PyAttribute> attributes, void *properties,
                    std::optional<std::vector<PyRegion>> regions,
                    DefaultingPyMlirContext context,
@@ -213,14 +213,15 @@ class PyShapedTypeComponents {
             "type.")
         .def_static(
             "get",
-            [](nb::list shape, PyType &elementType) {
+            [](nb::typed<nb::list, nb::int_> shape, PyType &elementType) {
               return PyShapedTypeComponents(std::move(shape), elementType);
             },
             nb::arg("shape"), nb::arg("element_type"),
             "Create a ranked shaped type components object.")
         .def_static(
             "get",
-            [](nb::list shape, PyType &elementType, PyAttribute &attribute) {
+            [](nb::typed<nb::list, nb::int_> shape, PyType &elementType,
+               PyAttribute &attribute) {
               return PyShapedTypeComponents(std::move(shape), elementType,
                                             attribute);
             },
@@ -300,7 +301,7 @@ class PyInferShapedTypeOpInterface
   /// Given the arguments required to build an operation, attempts to infer the
   /// shaped type components. Throws value_error on failure.
   std::vector<PyShapedTypeComponents> inferReturnTypeComponents(
-      std::optional<nb::list> operandList,
+      std::optional<nb::sequence> operandList,
       std::optional<PyAttribute> attributes, void *properties,
       std::optional<std::vector<PyRegion>> regions,
       DefaultingPyMlirContext context, DefaultingPyLocation location) {
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index 340deb019eb20..75fd55c90c2b5 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -465,7 +465,7 @@ void PyVectorType::bindDerived(ClassTy &c) {
 
 PyVectorType
 PyVectorType::getChecked(std::vector<int64_t> shape, PyType &elementType,
-                         std::optional<nb::list> scalable,
+                         std::optional<nb::sequence> scalable,
                          std::optional<std::vector<int64_t>> scalableDims,
                          DefaultingPyLocation loc) {
   if (scalable && scalableDims) {
@@ -476,11 +476,11 @@ PyVectorType::getChecked(std::vector<int64_t> shape, PyType &elementType,
   PyMlirContext::ErrorCapture errors(loc->getContext());
   MlirType type;
   if (scalable) {
-    if (scalable->size() != shape.size())
+    if (nb::len(*scalable) != shape.size())
       throw nb::value_error("Expected len(scalable) == len(shape).");
 
     std::vector<char> scalableDimFlags;
-    scalableDimFlags.reserve(scalable->size());
+    scalableDimFlags.reserve(nb::len(*scalable));
     for (const nb::handle &h : *scalable) {
       scalableDimFlags.push_back(nb::cast<bool>(h) ? 1 : 0);
     }
@@ -507,7 +507,7 @@ PyVectorType::getChecked(std::vector<int64_t> shape, PyType &elementType,
 }
 
 PyVectorType PyVectorType::get(std::vector<int64_t> shape, PyType &elementType,
-                               std::optional<nb::list> scalable,
+                               std::optional<nb::sequence> scalable,
                                std::optional<std::vector<int64_t>> scalableDims,
                                DefaultingPyMlirContext context) {
   if (scalable && scalableDims) {
@@ -518,11 +518,11 @@ PyVectorType PyVectorType::get(std::vector<int64_t> shape, PyType &elementType,
   PyMlirContext::ErrorCapture errors(context->getRef());
   MlirType type;
   if (scalable) {
-    if (scalable->size() != shape.size())
+    if (nb::len(*scalable) != shape.size())
       throw nb::value_error("Expected len(scalable) == len(shape).");
 
     std::vector<char> scalableDimFlags;
-    scalableDimFlags.reserve(scalable->size());
+    scalableDimFlags.reserve(nb::len(*scalable));
     for (const nb::handle &h : *scalable) {
       scalableDimFlags.push_back(nb::cast<bool>(h) ? 1 : 0);
     }



More information about the Mlir-commits mailing list