[Mlir-commits] [mlir] [mlir][Python] downcast Value to BlockArgument or OpResult (PR #175264)
Maksim Levental
llvmlistbot at llvm.org
Mon Jan 12 09:04:59 PST 2026
https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/175264
>From 7c2c7f0ac6e721e42445d45f3ec3fff1e1368aa3 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Fri, 9 Jan 2026 15:02:44 -0800
Subject: [PATCH 1/6] [mlir][Python] downcast ir.Value to BlockArgument or
OpResult
---
mlir/include/mlir/Bindings/Python/IRCore.h | 13 ++--
mlir/include/mlir/Bindings/Python/Nanobind.h | 1 +
.../mlir/Bindings/Python/NanobindUtils.h | 2 +-
mlir/lib/Bindings/Python/IRAffine.cpp | 2 +-
mlir/lib/Bindings/Python/IRCore.cpp | 61 +++++++++----------
mlir/test/python/dialects/python_test.py | 2 +-
mlir/test/python/ir/value.py | 22 +++----
7 files changed, 50 insertions(+), 53 deletions(-)
diff --git a/mlir/include/mlir/Bindings/Python/IRCore.h b/mlir/include/mlir/Bindings/Python/IRCore.h
index 330318683c15e..bfc35ca5b9d50 100644
--- a/mlir/include/mlir/Bindings/Python/IRCore.h
+++ b/mlir/include/mlir/Bindings/Python/IRCore.h
@@ -36,20 +36,21 @@
namespace mlir {
namespace python {
namespace MLIR_BINDINGS_PYTHON_DOMAIN {
-
+class DefaultingPyLocation;
+class DefaultingPyMlirContext;
class PyBlock;
+class PyBlockArgument;
class PyDiagnostic;
class PyDiagnosticHandler;
class PyInsertionPoint;
class PyLocation;
-class DefaultingPyLocation;
class PyMlirContext;
-class DefaultingPyMlirContext;
class PyModule;
+class PyOpResult;
class PyOperation;
class PyOperationBase;
-class PyType;
class PySymbolTable;
+class PyType;
class PyValue;
/// Wrapper for the global LLVM debugging flag.
@@ -1188,7 +1189,9 @@ class MLIR_PYTHON_API_EXPORTED PyValue {
/// Gets a capsule wrapping the void* within the MlirValue.
nanobind::object getCapsule();
- nanobind::typed<nanobind::object, PyValue> maybeDownCast();
+ nanobind::typed<nanobind::object,
+ std::variant<PyBlockArgument, PyOpResult, PyValue>>
+ maybeDownCast();
/// Creates a PyValue from the MlirValue wrapped by a capsule. Ownership of
/// the underlying MlirValue is still tied to the owning operation.
diff --git a/mlir/include/mlir/Bindings/Python/Nanobind.h b/mlir/include/mlir/Bindings/Python/Nanobind.h
index 8dc8a0d063d70..f7cf8fd38981d 100644
--- a/mlir/include/mlir/Bindings/Python/Nanobind.h
+++ b/mlir/include/mlir/Bindings/Python/Nanobind.h
@@ -29,6 +29,7 @@
#include <nanobind/stl/string.h>
#include <nanobind/stl/string_view.h>
#include <nanobind/stl/tuple.h>
+#include <nanobind/stl/variant.h>
#include <nanobind/stl/vector.h>
#include <nanobind/typing.h>
#if defined(__clang__) || defined(__GNUC__)
diff --git a/mlir/include/mlir/Bindings/Python/NanobindUtils.h b/mlir/include/mlir/Bindings/Python/NanobindUtils.h
index aea195fecae82..215daf245b902 100644
--- a/mlir/include/mlir/Bindings/Python/NanobindUtils.h
+++ b/mlir/include/mlir/Bindings/Python/NanobindUtils.h
@@ -277,7 +277,7 @@ class Sliceable {
/// Returns the element at the given slice index. Supports negative indices
/// by taking elements in inverse order. Returns a nullptr object if out
/// of bounds.
- nanobind::object getItem(intptr_t index) {
+ nanobind::typed<nanobind::object, ElementTy> getItem(intptr_t index) {
// Negative indices mean we count from the end.
index = wrapIndex(index);
if (index < 0) {
diff --git a/mlir/lib/Bindings/Python/IRAffine.cpp b/mlir/lib/Bindings/Python/IRAffine.cpp
index 3c2da03181e6a..2e760e6e6f830 100644
--- a/mlir/lib/Bindings/Python/IRAffine.cpp
+++ b/mlir/lib/Bindings/Python/IRAffine.cpp
@@ -614,7 +614,7 @@ void populateIRAffine(nb::module_ &m) {
return PyAffineExpr(self.getContext(),
mlirAffineExprCompose(self, other));
})
- .def("maybe_downcast", &PyAffineExpr::maybeDownCast)
+ .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, &PyAffineExpr::maybeDownCast)
.def(
"shift_dims",
[](PyAffineExpr &self, uint32_t numDims, uint32_t shift,
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 19db41fae4fe2..646490bba366b 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -10,6 +10,7 @@
#include "mlir/Bindings/Python/Globals.h"
#include "mlir/Bindings/Python/IRCore.h"
#include "mlir/Bindings/Python/NanobindUtils.h"
+#include "mlir/Bindings/Python/NanobindAdaptors.h"
#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
// clang-format on
#include "mlir-c/BuiltinAttributes.h"
@@ -17,10 +18,6 @@
#include "mlir-c/Diagnostics.h"
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"
-#include "mlir/Bindings/Python/Nanobind.h"
-#include "mlir/Bindings/Python/NanobindAdaptors.h"
-#include "nanobind/nanobind.h"
-#include "nanobind/typing.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
@@ -2003,7 +2000,22 @@ nb::object PyValue::getCapsule() {
return nb::steal<nb::object>(mlirPythonValueToCapsule(get()));
}
-nb::typed<nb::object, PyValue> PyValue::maybeDownCast() {
+static PyOperationRef getValueOwnerRef(MlirValue value) {
+ MlirOperation owner;
+ if (mlirValueIsAOpResult(value))
+ owner = mlirOpResultGetOwner(value);
+ else if (mlirValueIsABlockArgument(value))
+ owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(value));
+ else
+ assert(false && "Value must be an block arg or op result.");
+ if (mlirOperationIsNull(owner))
+ throw nb::python_error();
+ MlirContext ctx = mlirOperationGetContext(owner);
+ return PyOperation::forOperation(PyMlirContext::forContext(ctx), owner);
+}
+
+nb::typed<nb::object, std::variant<PyBlockArgument, PyOpResult, PyValue>>
+PyValue::maybeDownCast() {
MlirType type = mlirValueGetType(get());
MlirTypeID mlirTypeID = mlirTypeGetTypeID(type);
assert(!mlirTypeIDIsNull(mlirTypeID) &&
@@ -2012,26 +2024,23 @@ nb::typed<nb::object, PyValue> PyValue::maybeDownCast() {
PyGlobals::get().lookupValueCaster(mlirTypeID, mlirTypeGetDialect(type));
// nb::rv_policy::move means use std::move to move the return value
// contents into a new instance that will be owned by Python.
- nb::object thisObj = nb::cast(this, nb::rv_policy::move);
- if (!valueCaster)
- return thisObj;
- return valueCaster.value()(thisObj);
+ nb::object thisObj;
+ if (mlirValueIsAOpResult(value))
+ thisObj = nb::cast<PyOpResult>(*this, nb::rv_policy::move);
+ else if (mlirValueIsABlockArgument(value))
+ thisObj = nb::cast<PyBlockArgument>(*this, nb::rv_policy::move);
+ else
+ assert(false && "Value must be an block arg or op result.");
+ if (valueCaster)
+ return valueCaster.value()(thisObj);
+ return thisObj;
}
PyValue PyValue::createFromCapsule(nb::object capsule) {
MlirValue value = mlirPythonCapsuleToValue(capsule.ptr());
if (mlirValueIsNull(value))
throw nb::python_error();
- MlirOperation owner;
- if (mlirValueIsAOpResult(value))
- owner = mlirOpResultGetOwner(value);
- if (mlirValueIsABlockArgument(value))
- owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(value));
- if (mlirOperationIsNull(owner))
- throw nb::python_error();
- MlirContext ctx = mlirOperationGetContext(owner);
- PyOperationRef ownerRef =
- PyOperation::forOperation(PyMlirContext::forContext(ctx), owner);
+ PyOperationRef ownerRef = getValueOwnerRef(value);
return PyValue(ownerRef, value);
}
@@ -2279,15 +2288,7 @@ intptr_t PyOpOperandList::getRawNumElements() {
PyValue PyOpOperandList::getRawElement(intptr_t pos) {
MlirValue operand = mlirOperationGetOperand(operation->get(), pos);
- MlirOperation owner;
- if (mlirValueIsAOpResult(operand))
- owner = mlirOpResultGetOwner(operand);
- else if (mlirValueIsABlockArgument(operand))
- owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(operand));
- else
- assert(false && "Value must be an block arg or op result.");
- PyOperationRef pyOwner =
- PyOperation::forOperation(operation->getContext(), owner);
+ PyOperationRef pyOwner = getValueOwnerRef(operand);
return PyValue(pyOwner, operand);
}
@@ -4679,9 +4680,7 @@ void populateIRCore(nb::module_ &m) {
"with_"_a, "exceptions"_a, kValueReplaceAllUsesExceptDocstring)
.def(
MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
- [](PyValue &self) -> nb::typed<nb::object, PyValue> {
- return self.maybeDownCast();
- },
+ [](PyValue &self) { return self.maybeDownCast(); },
"Downcasts the `Value` to a more specific kind if possible.")
.def_prop_ro(
"location",
diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py
index 9c0966d2d8798..8814f4d66fdf9 100644
--- a/mlir/test/python/dialects/python_test.py
+++ b/mlir/test/python/dialects/python_test.py
@@ -862,7 +862,7 @@ def values(lst):
]
is Value
)
- assert type(variadic_operands.non_variadic) is Value
+ assert type(variadic_operands.non_variadic) is OpResult
# CHECK: ['Value(%{{.*}} = arith.constant 0 : i32)', 'Value(%{{.*}} = arith.constant 1 : i32)']
print(values(variadic_operands.variadic1))
diff --git a/mlir/test/python/ir/value.py b/mlir/test/python/ir/value.py
index 45efb880bab44..4cd1c2fa96c27 100644
--- a/mlir/test/python/ir/value.py
+++ b/mlir/test/python/ir/value.py
@@ -347,13 +347,6 @@ def __init__(self, v):
def __str__(self):
return super().__str__().replace(Value.__name__, NOPResult.__name__)
- class NOPValue(Value):
- def __init__(self, v):
- super().__init__(v)
-
- def __str__(self):
- return super().__str__().replace(Value.__name__, NOPValue.__name__)
-
class NOPBlockArg(BlockArgument):
def __init__(self, v):
super().__init__(v)
@@ -362,14 +355,12 @@ def __str__(self):
return super().__str__().replace(Value.__name__, NOPBlockArg.__name__)
@register_value_caster(IntegerType.static_typeid)
- def cast_int(v) -> Value:
+ def cast_int(v) -> NOPResult | NOPBlockArg:
print("in caster", v.__class__.__name__)
if isinstance(v, OpResult):
return NOPResult(v)
if isinstance(v, BlockArgument):
return NOPBlockArg(v)
- elif isinstance(v, Value):
- return NOPValue(v)
ctx = Context()
ctx.allow_unregistered_dialects = True
@@ -400,12 +391,15 @@ def cast_int(v) -> Value:
# CHECK: "custom.op2"(%0#0, %0#1) : (i32, i32) -> ()
print(op1)
- # CHECK: in caster Value
- # CHECK: operand 0 NOPValue(%0:2 = "custom.op1"() : () -> (i32, i32))
+ # CHECK: in caster OpResult
+ # CHECK: operand 0 NOPResult(%0:2 = "custom.op1"() : () -> (i32, i32))
print("operand 0", op1.operands[0])
- # CHECK: in caster Value
- # CHECK: operand 1 NOPValue(%0:2 = "custom.op1"() : () -> (i32, i32))
+ assert isinstance(op1.operands[0], Value)
+ assert isinstance(op1.operands[0], OpResult)
+ # CHECK: in caster OpResult
+ # CHECK: operand 1 NOPResult(%0:2 = "custom.op1"() : () -> (i32, i32))
print("operand 1", op1.operands[1])
+ assert isinstance(op1.operands[1], OpResult)
# CHECK: in caster BlockArgument
# CHECK: in caster BlockArgument
>From 8424f9c96dd5d88cae92f4a8901ad0932bac03eb Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Fri, 9 Jan 2026 20:28:47 -0800
Subject: [PATCH 2/6] fix __str__
---
mlir/include/mlir/Bindings/Python/IRCore.h | 8 ++++++
mlir/test/python/dialects/python_test.py | 6 ++---
mlir/test/python/dialects/rocdl.py | 2 +-
mlir/test/python/dialects/ub.py | 2 +-
mlir/test/python/ir/value.py | 31 +++++++++++-----------
5 files changed, 29 insertions(+), 20 deletions(-)
diff --git a/mlir/include/mlir/Bindings/Python/IRCore.h b/mlir/include/mlir/Bindings/Python/IRCore.h
index bfc35ca5b9d50..b0a097f4beda8 100644
--- a/mlir/include/mlir/Bindings/Python/IRCore.h
+++ b/mlir/include/mlir/Bindings/Python/IRCore.h
@@ -1570,6 +1570,14 @@ class MLIR_PYTHON_API_EXPORTED PyConcreteValue : public PyValue {
[](DerivedTy &self) -> nanobind::typed<nanobind::object, DerivedTy> {
return self.maybeDownCast();
});
+ cls.def("__str__", [](PyValue &self) {
+ PyPrintAccumulator printAccum;
+ printAccum.parts.append(std::string(DerivedTy::pyClassName) + "(");
+ mlirValuePrint(self.get(), printAccum.getCallback(),
+ printAccum.getUserData());
+ printAccum.parts.append(")");
+ return printAccum.join();
+ });
if (DerivedTy::getTypeIdFunction) {
PyGlobals::get().registerValueCaster(
diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py
index 8814f4d66fdf9..d4b394b8c7e06 100644
--- a/mlir/test/python/dialects/python_test.py
+++ b/mlir/test/python/dialects/python_test.py
@@ -854,7 +854,7 @@ def values(lst):
variadic_operands = test.SameVariadicOperandSizeOp(
[zero, one], two, [three, four]
)
- # CHECK: Value(%{{.*}} = arith.constant 2 : i32)
+ # CHECK: OpResult(%{{.*}} = arith.constant 2 : i32)
print(variadic_operands.non_variadic)
assert (
typing.get_type_hints(test.SameVariadicOperandSizeOp.non_variadic.fget)[
@@ -864,7 +864,7 @@ def values(lst):
)
assert type(variadic_operands.non_variadic) is OpResult
- # CHECK: ['Value(%{{.*}} = arith.constant 0 : i32)', 'Value(%{{.*}} = arith.constant 1 : i32)']
+ # CHECK: ['OpResult(%{{.*}} = arith.constant 0 : i32)', 'OpResult(%{{.*}} = arith.constant 1 : i32)']
print(values(variadic_operands.variadic1))
assert (
typing.get_type_hints(test.SameVariadicOperandSizeOp.variadic1.fget)[
@@ -874,7 +874,7 @@ def values(lst):
)
assert type(variadic_operands.variadic1) is OpOperandList
- # CHECK: ['Value(%{{.*}} = arith.constant 3 : i32)', 'Value(%{{.*}} = arith.constant 4 : i32)']
+ # CHECK: ['OpResult(%{{.*}} = arith.constant 3 : i32)', 'OpResult(%{{.*}} = arith.constant 4 : i32)']
print(values(variadic_operands.variadic2))
assert type(variadic_operands.variadic2) is OpOperandList
diff --git a/mlir/test/python/dialects/rocdl.py b/mlir/test/python/dialects/rocdl.py
index c73a536e03820..ed83504f3ac5c 100644
--- a/mlir/test/python/dialects/rocdl.py
+++ b/mlir/test/python/dialects/rocdl.py
@@ -34,7 +34,7 @@ def testSmoke():
# CHECK: %{{.*}} = "rocdl.wmma.f16.16x16x16.f16"
print(c_frag)
assert isinstance(c_frag, OpView)
- # CHECK: Value(%{{.*}} = "rocdl.wmma.f16.16x16x16.f16"
+ # CHECK: OpResult(%{{.*}} = "rocdl.wmma.f16.16x16x16.f16"
c_frag = rocdl.wmma_f16_16x16x16_f16_(v16f32, a_frag, b_frag, c_frag, opsel=False)
print(c_frag)
assert isinstance(c_frag, Value)
diff --git a/mlir/test/python/dialects/ub.py b/mlir/test/python/dialects/ub.py
index 0d88da82c5e7b..cbbd49d2ec6cc 100644
--- a/mlir/test/python/dialects/ub.py
+++ b/mlir/test/python/dialects/ub.py
@@ -20,7 +20,7 @@ def constructAndPrintInModule(f):
# CHECK-LABEL: testSmoke
@constructAndPrintInModule
def testSmoke():
- # CHECK: Value(%{{.*}} = ub.poison : f32
+ # CHECK: OpResult(%{{.*}} = ub.poison : f32
f32 = F32Type.get()
poison = ub.poison(f32)
print(poison)
diff --git a/mlir/test/python/ir/value.py b/mlir/test/python/ir/value.py
index 4cd1c2fa96c27..18fd1ba0ac775 100644
--- a/mlir/test/python/ir/value.py
+++ b/mlir/test/python/ir/value.py
@@ -229,11 +229,11 @@ def testValuePrintAsOperand():
module = Module.create()
with InsertionPoint(module.body):
value = Operation.create("custom.op1", results=[i32]).results[0]
- # CHECK: Value(%[[VAL1:.*]] = "custom.op1"() : () -> i32)
+ # CHECK: OpResult(%[[VAL1:.*]] = "custom.op1"() : () -> i32)
print(value)
value2 = Operation.create("custom.op2", results=[i32]).results[0]
- # CHECK: Value(%[[VAL2:.*]] = "custom.op2"() : () -> i32)
+ # CHECK: OpResult(%[[VAL2:.*]] = "custom.op2"() : () -> i32)
print(value2)
topFn = func.FuncOp("test", ([i32, i32], []))
@@ -241,10 +241,10 @@ def testValuePrintAsOperand():
with InsertionPoint(entry_block):
value3 = Operation.create("custom.op3", results=[i32]).results[0]
- # CHECK: Value(%[[VAL3:.*]] = "custom.op3"() : () -> i32)
+ # CHECK: OpResult(%[[VAL3:.*]] = "custom.op3"() : () -> i32)
print(value3)
value4 = Operation.create("custom.op4", results=[i32]).results[0]
- # CHECK: Value(%[[VAL4:.*]] = "custom.op4"() : () -> i32)
+ # CHECK: OpResult(%[[VAL4:.*]] = "custom.op4"() : () -> i32)
print(value4)
func.ReturnOp([])
@@ -306,7 +306,7 @@ def testValuePrintAsOperandNamedLocPrefix():
named_value = Operation.create(
"custom.op5", results=[i32], loc=Location.name("apple")
).results[0]
- # CHECK: Value(%[[VAL5:.*]] = "custom.op5"() : () -> i32)
+ # CHECK: OpResult(%[[VAL5:.*]] = "custom.op5"() : () -> i32)
print(named_value)
# CHECK: With use_name_loc_as_prefix
@@ -326,11 +326,11 @@ def testValueSetType():
module = Module.create()
with InsertionPoint(module.body):
value = Operation.create("custom.op1", results=[i32]).results[0]
- # CHECK: Value(%[[VAL1:.*]] = "custom.op1"() : () -> i32)
+ # CHECK: OpResult(%[[VAL1:.*]] = "custom.op1"() : () -> i32)
print(value)
value.set_type(i64)
- # CHECK: Value(%[[VAL1]] = "custom.op1"() : () -> i64)
+ # CHECK: OpResult(%[[VAL1]] = "custom.op1"() : () -> i64)
print(value)
# CHECK: %[[VAL1]] = "custom.op1"() : () -> i64
@@ -345,14 +345,16 @@ def __init__(self, v):
super().__init__(v)
def __str__(self):
- return super().__str__().replace(Value.__name__, NOPResult.__name__)
+ return super().__str__().replace(OpResult.__name__, NOPResult.__name__)
class NOPBlockArg(BlockArgument):
def __init__(self, v):
super().__init__(v)
def __str__(self):
- return super().__str__().replace(Value.__name__, NOPBlockArg.__name__)
+ return (
+ super().__str__().replace(BlockArgument.__name__, NOPBlockArg.__name__)
+ )
@register_value_caster(IntegerType.static_typeid)
def cast_int(v) -> NOPResult | NOPBlockArg:
@@ -420,8 +422,7 @@ def reduction(arg0, arg1):
try:
@register_value_caster(IntegerType.static_typeid)
- def dont_cast_int_shouldnt_register(v):
- ...
+ def dont_cast_int_shouldnt_register(v): ...
except RuntimeError as e:
# CHECK: Value caster is already registered: {{.*}}cast_int
@@ -437,12 +438,12 @@ def dont_cast_int(v) -> OpResult:
i32 = IntegerType.get_signless(32)
module = Module.create()
with InsertionPoint(module.body):
- # CHECK: don't cast 0 Value(%0 = "custom.op1"() : () -> i32)
+ # CHECK: don't cast 0 OpResult(%0 = "custom.op1"() : () -> i32)
new_value = Operation.create("custom.op1", results=[i32]).result
- # CHECK: result 0 Value(%0 = "custom.op1"() : () -> i32)
+ # CHECK: result 0 OpResult(%0 = "custom.op1"() : () -> i32)
print("result", new_value.result_number, new_value)
- # CHECK: don't cast 0 Value(%1 = "custom.op2"() : () -> i32)
+ # CHECK: don't cast 0 OpResult(%1 = "custom.op2"() : () -> i32)
new_value = Operation.create("custom.op2", results=[i32]).results[0]
- # CHECK: result 0 Value(%1 = "custom.op2"() : () -> i32)
+ # CHECK: result 0 OpResult(%1 = "custom.op2"() : () -> i32)
print("result", new_value.result_number, new_value)
>From a72da3c3f6e6d0fa1ac272e3a5567ecda6a3909c Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Fri, 9 Jan 2026 20:30:55 -0800
Subject: [PATCH 3/6] format
---
mlir/test/python/ir/value.py | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/mlir/test/python/ir/value.py b/mlir/test/python/ir/value.py
index 18fd1ba0ac775..380aee724ccbc 100644
--- a/mlir/test/python/ir/value.py
+++ b/mlir/test/python/ir/value.py
@@ -422,7 +422,8 @@ def reduction(arg0, arg1):
try:
@register_value_caster(IntegerType.static_typeid)
- def dont_cast_int_shouldnt_register(v): ...
+ def dont_cast_int_shouldnt_register(v):
+ ...
except RuntimeError as e:
# CHECK: Value caster is already registered: {{.*}}cast_int
>From f882d34d31b1ecbd644d02818c3ee05123c6dd9d Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Fri, 9 Jan 2026 20:36:52 -0800
Subject: [PATCH 4/6] add test
---
mlir/test/python/ir/value.py | 19 +++++++++++++++++++
1 file changed, 19 insertions(+)
diff --git a/mlir/test/python/ir/value.py b/mlir/test/python/ir/value.py
index 380aee724ccbc..6df32fedfb954 100644
--- a/mlir/test/python/ir/value.py
+++ b/mlir/test/python/ir/value.py
@@ -340,6 +340,24 @@ def testValueSetType():
# CHECK-LABEL: TEST: testValueCasters
@run
def testValueCasters():
+ # check before registering casters
+ ctx = Context()
+ ctx.allow_unregistered_dialects = True
+ with Location.unknown(ctx):
+ i32 = IntegerType.get_signless(32)
+ module = Module.create()
+ with InsertionPoint(module.body):
+ value = Operation.create("custom.op0", results=[i32]).result
+ # CHECK: value OpResult(%0 = "custom.op0"() : () -> i32)
+ print("value", value)
+
+ @func.FuncOp.from_py_func(i32, i32)
+ def reduction(arg0, arg1):
+ # CHECK: arg0 BlockArgument(<block argument> of type 'i32' at index: 0)
+ print("arg0", arg0)
+ # CHECK: arg1 BlockArgument(<block argument> of type 'i32' at index: 1)
+ print("arg1", arg1)
+
class NOPResult(OpResult):
def __init__(self, v):
super().__init__(v)
@@ -407,6 +425,7 @@ def cast_int(v) -> NOPResult | NOPBlockArg:
# CHECK: in caster BlockArgument
@func.FuncOp.from_py_func(i32, i32)
def reduction(arg0, arg1):
+ print("arg0", arg0)
# CHECK: as func arg 0 NOPBlockArg
print("as func arg", arg0.arg_number, arg0.__class__.__name__)
# CHECK: as func arg 1 NOPBlockArg
>From 2da195c0187035e314007be223391a41b65394d8 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Fri, 9 Jan 2026 20:42:37 -0800
Subject: [PATCH 5/6] restore predecls
---
mlir/include/mlir/Bindings/Python/IRCore.h | 11 ++++++-----
1 file changed, 6 insertions(+), 5 deletions(-)
diff --git a/mlir/include/mlir/Bindings/Python/IRCore.h b/mlir/include/mlir/Bindings/Python/IRCore.h
index b0a097f4beda8..533a5bf4e7d0b 100644
--- a/mlir/include/mlir/Bindings/Python/IRCore.h
+++ b/mlir/include/mlir/Bindings/Python/IRCore.h
@@ -36,21 +36,20 @@
namespace mlir {
namespace python {
namespace MLIR_BINDINGS_PYTHON_DOMAIN {
-class DefaultingPyLocation;
-class DefaultingPyMlirContext;
+
class PyBlock;
-class PyBlockArgument;
class PyDiagnostic;
class PyDiagnosticHandler;
class PyInsertionPoint;
class PyLocation;
+class DefaultingPyLocation;
class PyMlirContext;
+class DefaultingPyMlirContext;
class PyModule;
-class PyOpResult;
class PyOperation;
class PyOperationBase;
-class PySymbolTable;
class PyType;
+class PySymbolTable;
class PyValue;
/// Wrapper for the global LLVM debugging flag.
@@ -1171,6 +1170,8 @@ class MLIR_PYTHON_API_EXPORTED PyStringAttribute
/// value. For block argument values, this is the operation that contains the
/// block to which the value is an argument (blocks cannot be detached in Python
/// bindings so such operation always exists).
+class PyBlockArgument;
+class PyOpResult;
class MLIR_PYTHON_API_EXPORTED PyValue {
public:
// The virtual here is "load bearing" in that it enables RTTI
>From a70559847da7dc5c3deec6795ae6854c149a5520 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Mon, 12 Jan 2026 09:04:42 -0800
Subject: [PATCH 6/6] address comment
---
mlir/test/python/ir/value.py | 1 +
1 file changed, 1 insertion(+)
diff --git a/mlir/test/python/ir/value.py b/mlir/test/python/ir/value.py
index 6df32fedfb954..7761c2eb4a424 100644
--- a/mlir/test/python/ir/value.py
+++ b/mlir/test/python/ir/value.py
@@ -381,6 +381,7 @@ def cast_int(v) -> NOPResult | NOPBlockArg:
return NOPResult(v)
if isinstance(v, BlockArgument):
return NOPBlockArg(v)
+ raise ValueError(f"expected OpResult or BlockArgument; got {v=}")
ctx = Context()
ctx.allow_unregistered_dialects = True
More information about the Mlir-commits
mailing list