[Mlir-commits] [mlir] [mlir][Python] downcast ir.Value to BlockArgument or OpResult (PR #175264)
Maksim Levental
llvmlistbot at llvm.org
Fri Jan 9 17:05:01 PST 2026
https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/175264
>From a895c2e4ce49a0c1732329283161cd1dc2f54b8b Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Fri, 9 Jan 2026 15:02:44 -0800
Subject: [PATCH] [mlir][Python] downcast ir.Value to BlockArgument or OpResult
---
mlir/include/mlir/Bindings/Python/IRCore.h | 13 +++--
.../mlir/Bindings/Python/NanobindUtils.h | 2 +-
mlir/lib/Bindings/Python/IRAffine.cpp | 2 +-
mlir/lib/Bindings/Python/IRCore.cpp | 52 ++++++++++---------
4 files changed, 37 insertions(+), 32 deletions(-)
diff --git a/mlir/include/mlir/Bindings/Python/IRCore.h b/mlir/include/mlir/Bindings/Python/IRCore.h
index 330318683c15e..bfc35ca5b9d50 100644
--- a/mlir/include/mlir/Bindings/Python/IRCore.h
+++ b/mlir/include/mlir/Bindings/Python/IRCore.h
@@ -36,20 +36,21 @@
namespace mlir {
namespace python {
namespace MLIR_BINDINGS_PYTHON_DOMAIN {
-
+class DefaultingPyLocation;
+class DefaultingPyMlirContext;
class PyBlock;
+class PyBlockArgument;
class PyDiagnostic;
class PyDiagnosticHandler;
class PyInsertionPoint;
class PyLocation;
-class DefaultingPyLocation;
class PyMlirContext;
-class DefaultingPyMlirContext;
class PyModule;
+class PyOpResult;
class PyOperation;
class PyOperationBase;
-class PyType;
class PySymbolTable;
+class PyType;
class PyValue;
/// Wrapper for the global LLVM debugging flag.
@@ -1188,7 +1189,9 @@ class MLIR_PYTHON_API_EXPORTED PyValue {
/// Gets a capsule wrapping the void* within the MlirValue.
nanobind::object getCapsule();
- nanobind::typed<nanobind::object, PyValue> maybeDownCast();
+ nanobind::typed<nanobind::object,
+ std::variant<PyBlockArgument, PyOpResult, PyValue>>
+ maybeDownCast();
/// Creates a PyValue from the MlirValue wrapped by a capsule. Ownership of
/// the underlying MlirValue is still tied to the owning operation.
diff --git a/mlir/include/mlir/Bindings/Python/NanobindUtils.h b/mlir/include/mlir/Bindings/Python/NanobindUtils.h
index aea195fecae82..215daf245b902 100644
--- a/mlir/include/mlir/Bindings/Python/NanobindUtils.h
+++ b/mlir/include/mlir/Bindings/Python/NanobindUtils.h
@@ -277,7 +277,7 @@ class Sliceable {
/// Returns the element at the given slice index. Supports negative indices
/// by taking elements in inverse order. Returns a nullptr object if out
/// of bounds.
- nanobind::object getItem(intptr_t index) {
+ nanobind::typed<nanobind::object, ElementTy> getItem(intptr_t index) {
// Negative indices mean we count from the end.
index = wrapIndex(index);
if (index < 0) {
diff --git a/mlir/lib/Bindings/Python/IRAffine.cpp b/mlir/lib/Bindings/Python/IRAffine.cpp
index 3c2da03181e6a..2e760e6e6f830 100644
--- a/mlir/lib/Bindings/Python/IRAffine.cpp
+++ b/mlir/lib/Bindings/Python/IRAffine.cpp
@@ -614,7 +614,7 @@ void populateIRAffine(nb::module_ &m) {
return PyAffineExpr(self.getContext(),
mlirAffineExprCompose(self, other));
})
- .def("maybe_downcast", &PyAffineExpr::maybeDownCast)
+ .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, &PyAffineExpr::maybeDownCast)
.def(
"shift_dims",
[](PyAffineExpr &self, uint32_t numDims, uint32_t shift,
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 19db41fae4fe2..38cbb33821fe7 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -20,6 +20,7 @@
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
#include "nanobind/nanobind.h"
+#include "nanobind/stl/variant.h"
#include "nanobind/typing.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
@@ -2003,7 +2004,22 @@ nb::object PyValue::getCapsule() {
return nb::steal<nb::object>(mlirPythonValueToCapsule(get()));
}
-nb::typed<nb::object, PyValue> PyValue::maybeDownCast() {
+static PyOperationRef getValueOwnerRef(MlirValue value) {
+ MlirOperation owner;
+ if (mlirValueIsAOpResult(value))
+ owner = mlirOpResultGetOwner(value);
+ else if (mlirValueIsABlockArgument(value))
+ owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(value));
+ else
+ assert(false && "Value must be an block arg or op result.");
+ if (mlirOperationIsNull(owner))
+ throw nb::python_error();
+ MlirContext ctx = mlirOperationGetContext(owner);
+ return PyOperation::forOperation(PyMlirContext::forContext(ctx), owner);
+}
+
+nb::typed<nb::object, std::variant<PyBlockArgument, PyOpResult, PyValue>>
+PyValue::maybeDownCast() {
MlirType type = mlirValueGetType(get());
MlirTypeID mlirTypeID = mlirTypeGetTypeID(type);
assert(!mlirTypeIDIsNull(mlirTypeID) &&
@@ -2013,8 +2029,13 @@ nb::typed<nb::object, PyValue> PyValue::maybeDownCast() {
// nb::rv_policy::move means use std::move to move the return value
// contents into a new instance that will be owned by Python.
nb::object thisObj = nb::cast(this, nb::rv_policy::move);
- if (!valueCaster)
- return thisObj;
+ if (!valueCaster) {
+ PyOperationRef ownerRef = getValueOwnerRef(value);
+ if (mlirValueIsAOpResult(value))
+ return nb::cast(PyOpResult(ownerRef, value));
+ if (mlirValueIsABlockArgument(value))
+ return nb::cast(PyBlockArgument(ownerRef, value));
+ }
return valueCaster.value()(thisObj);
}
@@ -2022,16 +2043,7 @@ PyValue PyValue::createFromCapsule(nb::object capsule) {
MlirValue value = mlirPythonCapsuleToValue(capsule.ptr());
if (mlirValueIsNull(value))
throw nb::python_error();
- MlirOperation owner;
- if (mlirValueIsAOpResult(value))
- owner = mlirOpResultGetOwner(value);
- if (mlirValueIsABlockArgument(value))
- owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(value));
- if (mlirOperationIsNull(owner))
- throw nb::python_error();
- MlirContext ctx = mlirOperationGetContext(owner);
- PyOperationRef ownerRef =
- PyOperation::forOperation(PyMlirContext::forContext(ctx), owner);
+ PyOperationRef ownerRef = getValueOwnerRef(value);
return PyValue(ownerRef, value);
}
@@ -2279,15 +2291,7 @@ intptr_t PyOpOperandList::getRawNumElements() {
/// Returns the operand at index `pos` of the wrapped operation as a PyValue.
/// The PyValue keeps a reference to the operation that owns the operand value
/// (its defining op, or the parent op of the block for a block argument).
PyValue PyOpOperandList::getRawElement(intptr_t pos) {
MlirValue operand = mlirOperationGetOperand(operation->get(), pos);
- MlirOperation owner;
- if (mlirValueIsAOpResult(operand))
- owner = mlirOpResultGetOwner(operand);
- else if (mlirValueIsABlockArgument(operand))
- owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(operand));
- else
- assert(false && "Value must be an block arg or op result.");
- PyOperationRef pyOwner =
- PyOperation::forOperation(operation->getContext(), owner);
// Owner resolution is delegated to the shared helper, which throws if no
// owning operation is found. NOTE(review): the context is now derived from
// the owner op rather than taken from operation->getContext() as before.
+ PyOperationRef pyOwner = getValueOwnerRef(operand);
return PyValue(pyOwner, operand);
}
@@ -4679,9 +4683,7 @@ void populateIRCore(nb::module_ &m) {
"with_"_a, "exceptions"_a, kValueReplaceAllUsesExceptDocstring)
.def(
MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
- [](PyValue &self) -> nb::typed<nb::object, PyValue> {
- return self.maybeDownCast();
- },
+ [](PyValue &self) { return self.maybeDownCast(); },
"Downcasts the `Value` to a more specific kind if possible.")
.def_prop_ro(
"location",
More information about the Mlir-commits
mailing list