[Mlir-commits] [mlir] [MLIR] [Python] Relaxed `list` to `Sequence` in most parameter types (PR #188543)
Sergei Lebedev
llvmlistbot at llvm.org
Wed Mar 25 09:59:03 PDT 2026
https://github.com/superbobry created https://github.com/llvm/llvm-project/pull/188543
Using `Sequence` frees users from the need to cast to `list` in cases where the underlying API does not really care about the type of the container.
Note that accepting an `nb::sequence` is marginally slower than accepting `nb::list` directly, because `__getitem__`, `__len__`, etc. need to go through an extra layer of indirection. However, I expect the performance difference to be negligible.
From 0fd59df89c4ab982c974dd8a7de1236b8dd6dc40 Mon Sep 17 00:00:00 2001
From: Sergei Lebedev <slebedev at google.com>
Date: Wed, 25 Mar 2026 16:53:12 +0000
Subject: [PATCH] [MLIR] [Python] Relaxed `list` to `Sequence` in most
parameter types
Using `Sequence` frees users from the need to cast to `list` in cases where the
underlying API does not really care about the type of the container.
Note that accepting an `nb::sequence` is marginally slower than accepting
`nb::list` directly, because `__getitem__`, `__len__`, etc. need to go through
an extra layer of indirection. However, I expect the performance difference
to be negligible.
---
.../mlir/Bindings/Python/IRAttributes.h | 10 ++---
mlir/include/mlir/Bindings/Python/IRCore.h | 11 ++---
mlir/include/mlir/Bindings/Python/IRTypes.h | 4 +-
mlir/lib/Bindings/Python/DialectTransform.cpp | 15 ++++---
mlir/lib/Bindings/Python/IRAffine.cpp | 26 +++++------
mlir/lib/Bindings/Python/IRAttributes.cpp | 15 ++++---
mlir/lib/Bindings/Python/IRCore.cpp | 44 +++++++++++--------
mlir/lib/Bindings/Python/IRInterfaces.cpp | 17 +++----
mlir/lib/Bindings/Python/IRTypes.cpp | 12 ++---
9 files changed, 83 insertions(+), 71 deletions(-)
diff --git a/mlir/include/mlir/Bindings/Python/IRAttributes.h b/mlir/include/mlir/Bindings/Python/IRAttributes.h
index 28a7134dc2d45..7f2f88b552453 100644
--- a/mlir/include/mlir/Bindings/Python/IRAttributes.h
+++ b/mlir/include/mlir/Bindings/Python/IRAttributes.h
@@ -193,7 +193,7 @@ class MLIR_PYTHON_API_EXPORTED PyDenseArrayAttribute
});
c.def("__iter__",
[](const DerivedT &arr) { return PyDenseArrayIterator(arr); });
- c.def("__add__", [](DerivedT &arr, const nanobind::list &extras) {
+ c.def("__add__", [](DerivedT &arr, const nanobind::sequence &extras) {
std::vector<EltTy> values;
intptr_t numOldElements = mlirDenseArrayGetNumElements(arr);
values.reserve(numOldElements + nanobind::len(extras));
@@ -402,10 +402,10 @@ class MLIR_PYTHON_API_EXPORTED PyDenseElementsAttribute
static constexpr const char *pyClassName = "DenseElementsAttr";
using PyConcreteAttribute::PyConcreteAttribute;
- static PyDenseElementsAttribute
- getFromList(const nanobind::list &attributes,
- std::optional<PyType> explicitType,
- DefaultingPyMlirContext contextWrapper);
+ static PyDenseElementsAttribute getFromList(
+ const nanobind::typed<nanobind::sequence, PyAttribute> &attributes,
+ std::optional<PyType> explicitType,
+ DefaultingPyMlirContext contextWrapper);
static PyDenseElementsAttribute
getFromBuffer(const nb_buffer &array, bool signless,
diff --git a/mlir/include/mlir/Bindings/Python/IRCore.h b/mlir/include/mlir/Bindings/Python/IRCore.h
index e834fca61522f..db8427cfc4f78 100644
--- a/mlir/include/mlir/Bindings/Python/IRCore.h
+++ b/mlir/include/mlir/Bindings/Python/IRCore.h
@@ -373,7 +373,7 @@ class MLIR_PYTHON_API_EXPORTED PyDiagnostic {
PyDiagnosticSeverity getSeverity();
PyLocation getLocation();
nanobind::str getMessage();
- nanobind::tuple getNotes();
+ nanobind::typed<nanobind::tuple, PyDiagnostic> getNotes();
/// Materialized diagnostic information. This is safe to access outside the
/// diagnostic callback.
@@ -755,8 +755,8 @@ class MLIR_PYTHON_API_EXPORTED PyOpView : public PyOperationBase {
buildGeneric(std::string_view name, std::tuple<int, bool> opRegionSpec,
nanobind::object operandSegmentSpecObj,
nanobind::object resultSegmentSpecObj,
- std::optional<nanobind::list> resultTypeList,
- nanobind::list operandList,
+ std::optional<nanobind::sequence> resultTypeList,
+ nanobind::sequence operandList,
std::optional<nanobind::dict> attributes,
std::optional<std::vector<PyBlock *>> successors,
std::optional<int> regions, PyLocation &location,
@@ -1360,8 +1360,9 @@ inline MlirStringRef toMlirStringRef(const nanobind::bytes &s) {
/// Create a block, using the current location context if no locations are
/// specified.
MlirBlock MLIR_PYTHON_API_EXPORTED
-createBlock(const nanobind::sequence &pyArgTypes,
- const std::optional<nanobind::sequence> &pyArgLocs);
+createBlock(const nanobind::typed<nanobind::sequence, PyType> &pyArgTypes,
+ const std::optional<nanobind::typed<nanobind::sequence, PyLocation>>
+ &pyArgLocs);
struct MLIR_PYTHON_API_EXPORTED PyAttrBuilderMap {
static bool dunderContains(const std::string &attributeKind);
diff --git a/mlir/include/mlir/Bindings/Python/IRTypes.h b/mlir/include/mlir/Bindings/Python/IRTypes.h
index 8e81ee9805200..c84ed456de301 100644
--- a/mlir/include/mlir/Bindings/Python/IRTypes.h
+++ b/mlir/include/mlir/Bindings/Python/IRTypes.h
@@ -338,12 +338,12 @@ class MLIR_PYTHON_API_EXPORTED PyVectorType
private:
static PyVectorType
getChecked(std::vector<int64_t> shape, PyType &elementType,
- std::optional<nanobind::list> scalable,
+ std::optional<nanobind::sequence> scalable,
std::optional<std::vector<int64_t>> scalableDims,
DefaultingPyLocation loc);
static PyVectorType get(std::vector<int64_t> shape, PyType &elementType,
- std::optional<nanobind::list> scalable,
+ std::optional<nanobind::sequence> scalable,
std::optional<std::vector<int64_t>> scalableDims,
DefaultingPyMlirContext context);
};
diff --git a/mlir/lib/Bindings/Python/DialectTransform.cpp b/mlir/lib/Bindings/Python/DialectTransform.cpp
index 3b1b9abffb0f9..bd72082cea7e8 100644
--- a/mlir/lib/Bindings/Python/DialectTransform.cpp
+++ b/mlir/lib/Bindings/Python/DialectTransform.cpp
@@ -46,18 +46,20 @@ class PyTransformResults {
MlirTransformResults get() const { return results; }
- void setOps(PyValue &result, const nb::list &ops) {
+ void setOps(PyValue &result,
+ const nb::typed<nb::sequence, PyOperationBase> &ops) {
std::vector<MlirOperation> opsVec;
- opsVec.reserve(ops.size());
+ opsVec.reserve(nb::len(ops));
for (auto op : ops) {
opsVec.push_back(nb::cast<MlirOperation>(op));
}
mlirTransformResultsSetOps(results, result, opsVec.size(), opsVec.data());
}
- void setValues(PyValue &result, const nb::list &values) {
+ void setValues(PyValue &result,
+ const nb::typed<nb::sequence, PyValue> &values) {
std::vector<MlirValue> valuesVec;
- valuesVec.reserve(values.size());
+ valuesVec.reserve(nb::len(values));
for (auto item : values) {
valuesVec.push_back(nb::cast<MlirValue>(item));
}
@@ -65,9 +67,10 @@ class PyTransformResults {
valuesVec.data());
}
- void setParams(PyValue &result, const nb::list &params) {
+ void setParams(PyValue &result,
+ const nb::typed<nb::sequence, PyAttribute> &params) {
std::vector<MlirAttribute> paramsVec;
- paramsVec.reserve(params.size());
+ paramsVec.reserve(nb::len(params));
for (auto item : params) {
paramsVec.push_back(nb::cast<MlirAttribute>(item));
}
diff --git a/mlir/lib/Bindings/Python/IRAffine.cpp b/mlir/lib/Bindings/Python/IRAffine.cpp
index 2ec13f10df380..07aa4759bae5f 100644
--- a/mlir/lib/Bindings/Python/IRAffine.cpp
+++ b/mlir/lib/Bindings/Python/IRAffine.cpp
@@ -37,7 +37,7 @@ static const char kDumpDocstring[] =
/// Throws errors in case of failure, using "action" to describe what the caller
/// was attempting to do.
template <typename PyType, typename CType>
-static void pyListToVector(const nb::list &list, std::vector<CType> &result,
+static void pyListToVector(const nb::sequence &list, std::vector<CType> &result,
std::string_view action) {
result.reserve(nb::len(list));
for (nb::handle item : list) {
@@ -733,12 +733,12 @@ void populateIRAffine(nb::module_ &m) {
})
.def_static(
"compress_unused_symbols",
- [](nb::typed<nb::list, PyAffineMap> affineMaps,
+ [](nb::typed<nb::sequence, PyAffineMap> affineMaps,
DefaultingPyMlirContext context) {
std::vector<MlirAffineMap> maps;
pyListToVector<PyAffineMap, MlirAffineMap>(
affineMaps, maps, "attempting to create an AffineMap");
- std::vector<MlirAffineMap> compressed(affineMaps.size());
+ std::vector<MlirAffineMap> compressed(nb::len(affineMaps));
auto populate = [](void *result, intptr_t idx, MlirAffineMap m) {
static_cast<MlirAffineMap *>(result)[idx] = (m);
};
@@ -762,7 +762,7 @@ void populateIRAffine(nb::module_ &m) {
.def_static(
"get",
[](intptr_t dimCount, intptr_t symbolCount,
- nb::typed<nb::list, PyAffineExpr> exprs,
+ nb::typed<nb::sequence, PyAffineExpr> exprs,
DefaultingPyMlirContext context) {
std::vector<MlirAffineExpr> affineExprs;
pyListToVector<PyAffineExpr, MlirAffineExpr>(
@@ -930,13 +930,13 @@ void populateIRAffine(nb::module_ &m) {
.def_static(
"get",
[](intptr_t numDims, intptr_t numSymbols,
- nb::typed<nb::list, PyAffineExpr> exprs, std::vector<bool> eqFlags,
- DefaultingPyMlirContext context) {
- if (exprs.size() != eqFlags.size())
+ nb::typed<nb::sequence, PyAffineExpr> exprs,
+ std::vector<bool> eqFlags, DefaultingPyMlirContext context) {
+ if (nb::len(exprs) != eqFlags.size())
throw nb::value_error(
"Expected the number of constraints to match "
"that of equality flags");
- if (exprs.size() == 0)
+ if (nb::len(exprs) == 0)
throw nb::value_error("Expected non-empty list of constraints");
// std::vector<bool> does not expose a bool* data pointer.
@@ -945,7 +945,7 @@ void populateIRAffine(nb::module_ &m) {
pyListToVector<PyAffineExpr>(exprs, affineExprs,
"attempting to create an IntegerSet");
MlirIntegerSet set = mlirIntegerSetGet(
- context->get(), numDims, numSymbols, exprs.size(),
+ context->get(), numDims, numSymbols, nb::len(exprs),
affineExprs.data(), reinterpret_cast<bool *>(flags.data()));
return PyIntegerSet(context->getRef(), set);
},
@@ -963,15 +963,15 @@ void populateIRAffine(nb::module_ &m) {
nb::arg("context") = nb::none())
.def(
"get_replaced",
- [](PyIntegerSet &self, nb::typed<nb::list, PyAffineExpr> dimExprs,
- nb::typed<nb::list, PyAffineExpr> symbolExprs,
+ [](PyIntegerSet &self, nb::typed<nb::sequence, PyAffineExpr> dimExprs,
+ nb::typed<nb::sequence, PyAffineExpr> symbolExprs,
intptr_t numResultDims, intptr_t numResultSymbols) {
- if (static_cast<intptr_t>(dimExprs.size()) !=
+ if (static_cast<intptr_t>(nb::len(dimExprs)) !=
mlirIntegerSetGetNumDims(self))
throw nb::value_error(
"Expected the number of dimension replacement expressions "
"to match that of dimensions");
- if (static_cast<intptr_t>(symbolExprs.size()) !=
+ if (static_cast<intptr_t>(nb::len(symbolExprs)) !=
mlirIntegerSetGetNumSymbols(self))
throw nb::value_error(
"Expected the number of symbol replacement expressions "
diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index 5aebfabf5bc18..7fada5bbc8502 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -280,7 +280,7 @@ MlirAttribute PyArrayAttribute::getItem(intptr_t i) const {
void PyArrayAttribute::bindDerived(ClassTy &c) {
c.def_static(
"get",
- [](nb::typed<nb::list, PyAttribute> attributes,
+ [](nb::typed<nb::sequence, PyAttribute> attributes,
DefaultingPyMlirContext context) {
std::vector<MlirAttribute> mlirAttributes;
mlirAttributes.reserve(nb::len(attributes));
@@ -308,7 +308,7 @@ void PyArrayAttribute::bindDerived(ClassTy &c) {
return PyArrayAttributeIterator(arr);
});
c.def("__add__", [](PyArrayAttribute arr,
- nb::typed<nb::list, PyAttribute> extras) {
+ nb::typed<nb::sequence, PyAttribute> extras) {
std::vector<MlirAttribute> attributes;
intptr_t numOldElements = mlirArrayAttrGetNumElements(arr);
attributes.reserve(numOldElements + nb::len(extras));
@@ -579,10 +579,10 @@ void PyOpaqueAttribute::bindDerived(ClassTy &c) {
"Returns the data for the Opaqued attributes as `bytes`");
}
-PyDenseElementsAttribute
-PyDenseElementsAttribute::getFromList(const nb::list &attributes,
- std::optional<PyType> explicitType,
- DefaultingPyMlirContext contextWrapper) {
+PyDenseElementsAttribute PyDenseElementsAttribute::getFromList(
+ const nb::typed<nb::sequence, PyAttribute> &attributes,
+ std::optional<PyType> explicitType,
+ DefaultingPyMlirContext contextWrapper) {
const size_t numAttributes = nb::len(attributes);
if (numAttributes == 0)
throw nb::value_error("Attributes list must be non-empty.");
@@ -1175,7 +1175,8 @@ void PyDictAttribute::bindDerived(ClassTy &c) {
c.def("__len__", &PyDictAttribute::dunderLen);
c.def_static(
"get",
- [](const nb::dict &attributes, DefaultingPyMlirContext context) {
+ [](const nb::typed<nb::dict, nb::str, PyAttribute> &attributes,
+ DefaultingPyMlirContext context) {
std::vector<MlirNamedAttribute> mlirNamedAttributes;
mlirNamedAttributes.reserve(attributes.size());
for (std::pair<nb::handle, nb::handle> it : attributes) {
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 7e7e62f510bcd..350eb52765e1a 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -76,8 +76,9 @@ namespace mlir {
namespace python {
namespace MLIR_BINDINGS_PYTHON_DOMAIN {
-MlirBlock createBlock(const nb::sequence &pyArgTypes,
- const std::optional<nb::sequence> &pyArgLocs) {
+MlirBlock createBlock(
+ const nb::typed<nb::sequence, PyType> &pyArgTypes,
+ const std::optional<nb::typed<nb::sequence, PyLocation>> &pyArgLocs) {
std::vector<MlirType> argTypes;
argTypes.reserve(nb::len(pyArgTypes));
for (nb::handle pyType : pyArgTypes)
@@ -767,7 +768,7 @@ nb::str PyDiagnostic::getMessage() {
return nb::cast<nb::str>(fileObject.attr("getvalue")());
}
-nb::tuple PyDiagnostic::getNotes() {
+nb::typed<nb::tuple, PyDiagnostic> PyDiagnostic::getNotes() {
checkValid();
if (materializedNotes)
return *materializedNotes;
@@ -1442,11 +1443,12 @@ PyOpResultList PyOpResultList::slice(intptr_t startIndex, intptr_t length,
// PyOpView
//------------------------------------------------------------------------------
-static void populateResultTypes(std::string_view name, nb::list resultTypeList,
+static void populateResultTypes(std::string_view name,
+ nb::sequence resultTypeList,
const nb::object &resultSegmentSpecObj,
std::vector<int32_t> &resultSegmentLengths,
std::vector<PyType *> &resultTypes) {
- resultTypes.reserve(resultTypeList.size());
+ resultTypes.reserve(nb::len(resultTypeList));
if (resultSegmentSpecObj.is_none()) {
// Non-variadic result unpacking.
size_t index = 0;
@@ -1465,13 +1467,13 @@ static void populateResultTypes(std::string_view name, nb::list resultTypeList,
} else {
// Sized result unpacking.
auto resultSegmentSpec = nb::cast<std::vector<int>>(resultSegmentSpecObj);
- if (resultSegmentSpec.size() != resultTypeList.size()) {
+ if (resultSegmentSpec.size() != nb::len(resultTypeList)) {
throw nb::value_error(
join("Operation \"", name, "\" requires ", resultSegmentSpec.size(),
- " result segments but was provided ", resultTypeList.size())
+ " result segments but was provided ", nb::len(resultTypeList))
.c_str());
}
- resultSegmentLengths.reserve(resultTypeList.size());
+ resultSegmentLengths.reserve(nb::len(resultTypeList));
for (size_t i = 0, e = resultSegmentSpec.size(); i < e; ++i) {
int segmentSpec = resultSegmentSpec[i];
if (segmentSpec == 1 || segmentSpec == 0) {
@@ -1565,7 +1567,7 @@ static MlirValue getOpResultOrValue(nb::handle operand) {
nb::typed<nb::object, PyOperation> PyOpView::buildGeneric(
std::string_view name, std::tuple<int, bool> opRegionSpec,
nb::object operandSegmentSpecObj, nb::object resultSegmentSpecObj,
- std::optional<nb::list> resultTypeList, nb::list operandList,
+ std::optional<nb::sequence> resultTypeList, nb::sequence operandList,
std::optional<nb::dict> attributes,
std::optional<std::vector<PyBlock *>> successors,
std::optional<int> regions, PyLocation &location,
@@ -1626,13 +1628,13 @@ nb::typed<nb::object, PyOperation> PyOpView::buildGeneric(
} else {
// Sized operand unpacking.
auto operandSegmentSpec = nb::cast<std::vector<int>>(operandSegmentSpecObj);
- if (operandSegmentSpec.size() != operandList.size()) {
+ if (operandSegmentSpec.size() != nb::len(operandList)) {
throw nb::value_error(
join("Operation \"", name, "\" requires ", operandSegmentSpec.size(),
- "operand segments but was provided ", operandList.size())
+ "operand segments but was provided ", nb::len(operandList))
.c_str());
}
- operandSegmentLengths.reserve(operandList.size());
+ operandSegmentLengths.reserve(nb::len(operandList));
for (size_t i = 0, e = operandSegmentSpec.size(); i < e; ++i) {
int segmentSpec = operandSegmentSpec[i];
if (segmentSpec == 1 || segmentSpec == 0) {
@@ -2513,10 +2515,10 @@ void PyOpAttributeMap::bind(nb::module_ &m) {
void PyOpAdaptor::bind(nb::module_ &m) {
nb::class_<PyOpAdaptor>(m, "OpAdaptor")
- .def(nb::init<nb::list, PyOpAttributeMap>(),
+ .def(nb::init<nb::typed<nb::list, PyValue>, PyOpAttributeMap>(),
"Creates an OpAdaptor with the given operands and attributes.",
"operands"_a, "attributes"_a)
- .def(nb::init<nb::list, PyOpView &>(),
+ .def(nb::init<nb::typed<nb::list, PyValue>, PyOpView &>(),
"Creates an OpAdaptor with the given operands and operation view.",
"operands"_a, "opview"_a)
.def_prop_ro(
@@ -3944,7 +3946,8 @@ void populateIRCore(nb::module_ &m) {
[](std::string_view name,
std::optional<std::vector<PyType *>> results,
std::optional<std::vector<PyValue *>> operands,
- std::optional<nb::dict> attributes,
+ std::optional<nb::typed<nb::dict, nb::str, PyAttribute>>
+ attributes,
std::optional<std::vector<PyBlock *>> successors, int regions,
const std::optional<PyLocation> &location,
const nb::object &maybeIp,
@@ -4046,8 +4049,10 @@ void populateIRCore(nb::module_ &m) {
std::tuple<int, bool> opRegionSpec,
nb::object operandSegmentSpecObj,
nb::object resultSegmentSpecObj,
- std::optional<nb::list> resultTypeList, nb::list operandList,
- std::optional<nb::dict> attributes,
+ std::optional<nb::sequence> resultTypeList,
+ nb::sequence operandList,
+ std::optional<nb::typed<nb::dict, nb::str, PyAttribute>>
+ attributes,
std::optional<std::vector<PyBlock *>> successors,
std::optional<int> regions,
const std::optional<PyLocation> &location,
@@ -4093,8 +4098,9 @@ void populateIRCore(nb::module_ &m) {
// ods_operand_segments/ods_result_segments as arguments to the constructor,
// rather than to access them as attributes.
opViewClass.attr("build_generic") = classmethod(
- [](nb::handle cls, std::optional<nb::list> resultTypeList,
- nb::list operandList, std::optional<nb::dict> attributes,
+ [](nb::handle cls, std::optional<nb::sequence> resultTypeList,
+ nb::sequence operandList,
+ std::optional<nb::typed<nb::dict, nb::str, PyAttribute>> attributes,
std::optional<std::vector<PyBlock *>> successors,
std::optional<int> regions, std::optional<PyLocation> location,
const nb::object &maybeIp) {
diff --git a/mlir/lib/Bindings/Python/IRInterfaces.cpp b/mlir/lib/Bindings/Python/IRInterfaces.cpp
index e7865cfda9d6f..561316d863f75 100644
--- a/mlir/lib/Bindings/Python/IRInterfaces.cpp
+++ b/mlir/lib/Bindings/Python/IRInterfaces.cpp
@@ -36,16 +36,16 @@ namespace {
/// Takes in an optional ist of operands and converts them into a std::vector
/// of MlirVlaues. Returns an empty std::vector if the list is empty.
-std::vector<MlirValue> wrapOperands(std::optional<nb::list> operandList) {
+std::vector<MlirValue> wrapOperands(std::optional<nb::sequence> operandList) {
std::vector<MlirValue> mlirOperands;
- if (!operandList || operandList->size() == 0) {
+ if (!operandList || nb::len(*operandList) == 0) {
return mlirOperands;
}
// Note: as the list may contain other lists this may not be final size.
- mlirOperands.reserve(operandList->size());
- for (size_t i = 0, e = operandList->size(); i < e; ++i) {
+ mlirOperands.reserve(nb::len(*operandList));
+ for (size_t i = 0, e = nb::len(*operandList); i < e; ++i) {
nb::handle operand = (*operandList)[i];
intptr_t index = static_cast<intptr_t>(i);
if (operand.is_none())
@@ -143,7 +143,7 @@ class PyInferTypeOpInterface
/// Given the arguments required to build an operation, attempts to infer its
/// return types. Throws value_error on failure.
std::vector<PyType>
- inferReturnTypes(std::optional<nb::list> operandList,
+ inferReturnTypes(std::optional<nb::sequence> operandList,
std::optional<PyAttribute> attributes, void *properties,
std::optional<std::vector<PyRegion>> regions,
DefaultingPyMlirContext context,
@@ -213,14 +213,15 @@ class PyShapedTypeComponents {
"type.")
.def_static(
"get",
- [](nb::list shape, PyType &elementType) {
+ [](nb::typed<nb::list, nb::int_> shape, PyType &elementType) {
return PyShapedTypeComponents(std::move(shape), elementType);
},
nb::arg("shape"), nb::arg("element_type"),
"Create a ranked shaped type components object.")
.def_static(
"get",
- [](nb::list shape, PyType &elementType, PyAttribute &attribute) {
+ [](nb::typed<nb::list, nb::int_> shape, PyType &elementType,
+ PyAttribute &attribute) {
return PyShapedTypeComponents(std::move(shape), elementType,
attribute);
},
@@ -300,7 +301,7 @@ class PyInferShapedTypeOpInterface
/// Given the arguments required to build an operation, attempts to infer the
/// shaped type components. Throws value_error on failure.
std::vector<PyShapedTypeComponents> inferReturnTypeComponents(
- std::optional<nb::list> operandList,
+ std::optional<nb::sequence> operandList,
std::optional<PyAttribute> attributes, void *properties,
std::optional<std::vector<PyRegion>> regions,
DefaultingPyMlirContext context, DefaultingPyLocation location) {
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index 340deb019eb20..75fd55c90c2b5 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -465,7 +465,7 @@ void PyVectorType::bindDerived(ClassTy &c) {
PyVectorType
PyVectorType::getChecked(std::vector<int64_t> shape, PyType &elementType,
- std::optional<nb::list> scalable,
+ std::optional<nb::sequence> scalable,
std::optional<std::vector<int64_t>> scalableDims,
DefaultingPyLocation loc) {
if (scalable && scalableDims) {
@@ -476,11 +476,11 @@ PyVectorType::getChecked(std::vector<int64_t> shape, PyType &elementType,
PyMlirContext::ErrorCapture errors(loc->getContext());
MlirType type;
if (scalable) {
- if (scalable->size() != shape.size())
+ if (nb::len(*scalable) != shape.size())
throw nb::value_error("Expected len(scalable) == len(shape).");
std::vector<char> scalableDimFlags;
- scalableDimFlags.reserve(scalable->size());
+ scalableDimFlags.reserve(nb::len(*scalable));
for (const nb::handle &h : *scalable) {
scalableDimFlags.push_back(nb::cast<bool>(h) ? 1 : 0);
}
@@ -507,7 +507,7 @@ PyVectorType::getChecked(std::vector<int64_t> shape, PyType &elementType,
}
PyVectorType PyVectorType::get(std::vector<int64_t> shape, PyType &elementType,
- std::optional<nb::list> scalable,
+ std::optional<nb::sequence> scalable,
std::optional<std::vector<int64_t>> scalableDims,
DefaultingPyMlirContext context) {
if (scalable && scalableDims) {
@@ -518,11 +518,11 @@ PyVectorType PyVectorType::get(std::vector<int64_t> shape, PyType &elementType,
PyMlirContext::ErrorCapture errors(context->getRef());
MlirType type;
if (scalable) {
- if (scalable->size() != shape.size())
+ if (nb::len(*scalable) != shape.size())
throw nb::value_error("Expected len(scalable) == len(shape).");
std::vector<char> scalableDimFlags;
- scalableDimFlags.reserve(scalable->size());
+ scalableDimFlags.reserve(nb::len(*scalable));
for (const nb::handle &h : *scalable) {
scalableDimFlags.push_back(nb::cast<bool>(h) ? 1 : 0);
}
More information about the Mlir-commits
mailing list