[Mlir-commits] [mlir] [MLIR] [Python] `ir.Value` is now generic in the type of the value it holds (PR #166148)
Sergei Lebedev
llvmlistbot at llvm.org
Tue Nov 11 02:50:36 PST 2025
https://github.com/superbobry updated https://github.com/llvm/llvm-project/pull/166148
>From 46e29cb8e8e6cf1f34ec61a99b3a0f8aabad5d81 Mon Sep 17 00:00:00 2001
From: Sergei Lebedev <slebedev at google.com>
Date: Mon, 3 Nov 2025 10:55:04 +0000
Subject: [PATCH] [MLIR] [Python] `ir.Value` is now generic in the type of the value it holds
This makes it similar to `mlir::TypedValue` in the MLIR C++ API and allows
users to be more specific about the values they produce or accept.
---
mlir/lib/Bindings/Python/IRCore.cpp | 28 +++++++++++------
mlir/test/mlir-tblgen/op-python-bindings.td | 14 ++++-----
mlir/test/python/dialects/python_test.py | 9 +++++-
mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp | 31 ++++++++++++++++++-
4 files changed, 64 insertions(+), 18 deletions(-)
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index cda4fe19c16f8..c21240926ddac 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -18,6 +18,7 @@
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
#include "nanobind/nanobind.h"
+#include "nanobind/typing.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
@@ -1596,7 +1597,11 @@ class PyConcreteValue : public PyValue {
/// Binds the Python module objects to functions of this class.
static void bind(nb::module_ &m) {
- auto cls = ClassTy(m, DerivedTy::pyClassName);
+ auto cls = ClassTy(
+ m, DerivedTy::pyClassName, nb::is_generic(),
+ nb::sig((Twine("class ") + DerivedTy::pyClassName + "(Value[_T])")
+ .str()
+ .c_str()));
cls.def(nb::init<PyValue &>(), nb::keep_alive<0, 1>(), nb::arg("value"));
cls.def_static(
"isinstance",
@@ -4278,7 +4283,10 @@ void mlir::python::populateIRCore(nb::module_ &m) {
//----------------------------------------------------------------------------
// Mapping of Value.
//----------------------------------------------------------------------------
- nb::class_<PyValue>(m, "Value")
+ m.attr("_T") = nb::type_var("_T", nb::arg("bound") = m.attr("Type"));
+
+ nb::class_<PyValue>(m, "Value", nb::is_generic(),
+ nb::sig("class Value(Generic[_T])"))
.def(nb::init<PyValue &>(), nb::keep_alive<0, 1>(), nb::arg("value"))
.def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule)
.def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule)
@@ -4371,18 +4379,20 @@ void mlir::python::populateIRCore(nb::module_ &m) {
return printAccum.join();
},
nb::arg("state"), kGetNameAsOperand)
- .def_prop_ro("type",
- [](PyValue &self) -> nb::typed<nb::object, PyType> {
- return PyType(self.getParentOperation()->getContext(),
- mlirValueGetType(self.get()))
- .maybeDownCast();
- })
+ .def_prop_ro(
+ "type",
+ [](PyValue &self) {
+ return PyType(self.getParentOperation()->getContext(),
+ mlirValueGetType(self.get()))
+ .maybeDownCast();
+ },
+ nb::sig("def type(self) -> _T"))
.def(
"set_type",
[](PyValue &self, const PyType &type) {
return mlirValueSetType(self.get(), type);
},
- nb::arg("type"))
+ nb::arg("type"), nb::sig("def set_type(self, type: _T)"))
.def(
"replace_all_uses_with",
[](PyValue &self, PyValue &with) {
diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td
index 42de7e4eda573..ff16ad8ca0cdd 100644
--- a/mlir/test/mlir-tblgen/op-python-bindings.td
+++ b/mlir/test/mlir-tblgen/op-python-bindings.td
@@ -350,16 +350,16 @@ def MissingNamesOp : TestOp<"missing_names"> {
// CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip)
// CHECK: @builtins.property
- // CHECK: def f32(self) -> _ods_ir.Value:
+ // CHECK: def f32(self) -> _ods_ir.Value[_ods_ir.FloatType]:
// CHECK: return self.operation.operands[1]
let arguments = (ins I32, F32:$f32, I64);
// CHECK: @builtins.property
- // CHECK: def i32(self) -> _ods_ir.OpResult:
+ // CHECK: def i32(self) -> _ods_ir.OpResult[_ods_ir.IntegerType]:
// CHECK: return self.operation.results[0]
//
// CHECK: @builtins.property
- // CHECK: def i64(self) -> _ods_ir.OpResult:
+ // CHECK: def i64(self) -> _ods_ir.OpResult[_ods_ir.IntegerType]:
// CHECK: return self.operation.results[2]
let results = (outs I32:$i32, AnyFloat, I64:$i64);
}
@@ -590,20 +590,20 @@ def SimpleOp : TestOp<"simple"> {
// CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip)
// CHECK: @builtins.property
- // CHECK: def i32(self) -> _ods_ir.Value:
+ // CHECK: def i32(self) -> _ods_ir.Value[_ods_ir.IntegerType]:
// CHECK: return self.operation.operands[0]
//
// CHECK: @builtins.property
- // CHECK: def f32(self) -> _ods_ir.Value:
+ // CHECK: def f32(self) -> _ods_ir.Value[_ods_ir.FloatType]:
// CHECK: return self.operation.operands[1]
let arguments = (ins I32:$i32, F32:$f32);
// CHECK: @builtins.property
- // CHECK: def i64(self) -> _ods_ir.OpResult:
+ // CHECK: def i64(self) -> _ods_ir.OpResult[_ods_ir.IntegerType]:
// CHECK: return self.operation.results[0]
//
// CHECK: @builtins.property
- // CHECK: def f64(self) -> _ods_ir.OpResult:
+ // CHECK: def f64(self) -> _ods_ir.OpResult[_ods_ir.FloatType]:
// CHECK: return self.operation.results[1]
let results = (outs I64:$i64, AnyFloat:$f64);
}
diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py
index 1194e32c960c8..f0f74ebc12155 100644
--- a/mlir/test/python/dialects/python_test.py
+++ b/mlir/test/python/dialects/python_test.py
@@ -554,7 +554,7 @@ def testOptionalOperandOp():
)
assert (
typing.get_type_hints(test.OptionalOperandOp.result.fget)["return"]
- is OpResult
+ == OpResult[IntegerType]
)
assert type(op1.result) is OpResult
@@ -662,6 +662,13 @@ def testCustomType():
raise
+ at run
+# CHECK-LABEL: TEST: testValue
+def testValue():
+ # Check that Value is a generic class at runtime.
+ assert hasattr(Value, "__class_getitem__")
+
+
@run
# CHECK-LABEL: TEST: testTensorValue
def testTensorValue():
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 0172b3fa38a6b..f01563fd49d17 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -341,6 +341,22 @@ static std::string attrSizedTraitForKind(const char *kind) {
StringRef(kind).drop_front());
}
+/// Returns the name of the `_ods_ir` Python class corresponding to the builtin
+/// C++ type spelled `cppType` (e.g. "::mlir::IntegerType" yields
+/// "_ods_ir.IntegerType"), or an empty StringRef when no mapping exists.
+static StringRef getPythonType(StringRef cppType) {
+  // Every mapped type is spelled with the `::mlir::` qualifier; anything
+  // without it has no mapping.
+  if (!cppType.consume_front("::mlir::"))
+    return StringRef();
+
+  // The Python class names mirror the unqualified C++ class names.
+  return llvm::StringSwitch<StringRef>(cppType)
+      .Case("MemRefType", "_ods_ir.MemRefType")
+      .Case("UnrankedMemRefType", "_ods_ir.UnrankedMemRefType")
+      .Case("RankedTensorType", "_ods_ir.RankedTensorType")
+      .Case("UnrankedTensorType", "_ods_ir.UnrankedTensorType")
+      .Case("VectorType", "_ods_ir.VectorType")
+      .Case("IntegerType", "_ods_ir.IntegerType")
+      .Case("FloatType", "_ods_ir.FloatType")
+      .Case("IndexType", "_ods_ir.IndexType")
+      .Case("ComplexType", "_ods_ir.ComplexType")
+      .Case("TupleType", "_ods_ir.TupleType")
+      .Case("NoneType", "_ods_ir.NoneType")
+      .Default(StringRef());
+}
+
/// Emits accessors to "elements" of an Op definition. Currently, the supported
/// elements are operands and results, indicated by `kind`, which must be either
/// `operand` or `result` and is used verbatim in the emitted code.
@@ -370,8 +386,11 @@ static void emitElementAccessors(
seenVariableLength = true;
if (element.name.empty())
continue;
- const char *type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.Value"
+ std::string type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.Value"
: "_ods_ir.OpResult";
+ if (StringRef pythonType = getPythonType(element.constraint.getCppType());
+ !pythonType.empty())
+ type = llvm::formatv("{0}[{1}]", type, pythonType);
if (element.isVariableLength()) {
if (element.isOptional()) {
os << formatv(opOneOptionalTemplate, sanitizeName(element.name), kind,
@@ -418,6 +437,11 @@ static void emitElementAccessors(
type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.Value"
: "_ods_ir.OpResult";
}
+ if (std::strcmp(kind, "operand") == 0) {
+ StringRef pythonType = getPythonType(element.constraint.getCppType());
+ if (!pythonType.empty())
+ type += "[" + pythonType.str() + "]";
+ }
os << formatv(opVariadicEqualPrefixTemplate, sanitizeName(element.name),
kind, numSimpleLength, numVariadicGroups,
numPrecedingSimple, numPrecedingVariadic, type);
@@ -449,6 +473,11 @@ static void emitElementAccessors(
if (!element.isVariableLength() || element.isOptional()) {
type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.Value"
: "_ods_ir.OpResult";
+ if (std::strcmp(kind, "operand") == 0) {
+ StringRef pythonType = getPythonType(element.constraint.getCppType());
+ if (!pythonType.empty())
+ type += "[" + pythonType.str() + "]";
+ }
if (!element.isVariableLength()) {
trailing = "[0]";
} else if (element.isOptional()) {
More information about the Mlir-commits mailing list