[Mlir-commits] [mlir] [MLIR] [Python] `ir.Value` is now generic in the type of the value it holds (PR #166148)

Sergei Lebedev llvmlistbot at llvm.org
Tue Nov 11 02:50:36 PST 2025


https://github.com/superbobry updated https://github.com/llvm/llvm-project/pull/166148

>From 46e29cb8e8e6cf1f34ec61a99b3a0f8aabad5d81 Mon Sep 17 00:00:00 2001
From: Sergei Lebedev <slebedev at google.com>
Date: Mon, 3 Nov 2025 10:55:04 +0000
Subject: [PATCH] [MLIR] [Python] `ir.Value` is now generic in the type of the
 value it holds

This makes it similar to `mlir::TypedValue` in the MLIR C++ API and allows
users to be more specific about the values they produce or accept.
---
 mlir/lib/Bindings/Python/IRCore.cpp           | 28 +++++++++++------
 mlir/test/mlir-tblgen/op-python-bindings.td   | 14 ++++-----
 mlir/test/python/dialects/python_test.py      |  9 +++++-
 mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp | 31 ++++++++++++++++++-
 4 files changed, 64 insertions(+), 18 deletions(-)

diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index cda4fe19c16f8..c21240926ddac 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -18,6 +18,7 @@
 #include "mlir/Bindings/Python/Nanobind.h"
 #include "mlir/Bindings/Python/NanobindAdaptors.h"
 #include "nanobind/nanobind.h"
+#include "nanobind/typing.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/SmallVector.h"
 
@@ -1596,7 +1597,11 @@ class PyConcreteValue : public PyValue {
 
   /// Binds the Python module objects to functions of this class.
   static void bind(nb::module_ &m) {
-    auto cls = ClassTy(m, DerivedTy::pyClassName);
+    auto cls = ClassTy(
+        m, DerivedTy::pyClassName, nb::is_generic(),
+        nb::sig((Twine("class ") + DerivedTy::pyClassName + "(Value[_T])")
+                    .str()
+                    .c_str()));
     cls.def(nb::init<PyValue &>(), nb::keep_alive<0, 1>(), nb::arg("value"));
     cls.def_static(
         "isinstance",
@@ -4278,7 +4283,10 @@ void mlir::python::populateIRCore(nb::module_ &m) {
   //----------------------------------------------------------------------------
   // Mapping of Value.
   //----------------------------------------------------------------------------
-  nb::class_<PyValue>(m, "Value")
+  m.attr("_T") = nb::type_var("_T", nb::arg("bound") = m.attr("Type"));
+
+  nb::class_<PyValue>(m, "Value", nb::is_generic(),
+                      nb::sig("class Value(Generic[_T])"))
       .def(nb::init<PyValue &>(), nb::keep_alive<0, 1>(), nb::arg("value"))
       .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule)
       .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule)
@@ -4371,18 +4379,20 @@ void mlir::python::populateIRCore(nb::module_ &m) {
             return printAccum.join();
           },
           nb::arg("state"), kGetNameAsOperand)
-      .def_prop_ro("type",
-                   [](PyValue &self) -> nb::typed<nb::object, PyType> {
-                     return PyType(self.getParentOperation()->getContext(),
-                                   mlirValueGetType(self.get()))
-                         .maybeDownCast();
-                   })
+      .def_prop_ro(
+          "type",
+          [](PyValue &self) {
+            return PyType(self.getParentOperation()->getContext(),
+                          mlirValueGetType(self.get()))
+                .maybeDownCast();
+          },
+          nb::sig("def type(self) -> _T"))
       .def(
           "set_type",
           [](PyValue &self, const PyType &type) {
             return mlirValueSetType(self.get(), type);
           },
-          nb::arg("type"))
+          nb::arg("type"), nb::sig("def set_type(self, type: _T)"))
       .def(
           "replace_all_uses_with",
           [](PyValue &self, PyValue &with) {
diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td
index 42de7e4eda573..ff16ad8ca0cdd 100644
--- a/mlir/test/mlir-tblgen/op-python-bindings.td
+++ b/mlir/test/mlir-tblgen/op-python-bindings.td
@@ -350,16 +350,16 @@ def MissingNamesOp : TestOp<"missing_names"> {
   // CHECK:     successors=_ods_successors, regions=regions, loc=loc, ip=ip)
 
   // CHECK: @builtins.property
-  // CHECK: def f32(self) -> _ods_ir.Value:
+  // CHECK: def f32(self) -> _ods_ir.Value[_ods_ir.FloatType]:
   // CHECK:   return self.operation.operands[1]
   let arguments = (ins I32, F32:$f32, I64);
 
   // CHECK: @builtins.property
-  // CHECK: def i32(self) -> _ods_ir.OpResult:
+  // CHECK: def i32(self) -> _ods_ir.OpResult[_ods_ir.IntegerType]:
   // CHECK:   return self.operation.results[0]
   //
   // CHECK: @builtins.property
-  // CHECK: def i64(self) -> _ods_ir.OpResult:
+  // CHECK: def i64(self) -> _ods_ir.OpResult[_ods_ir.IntegerType]:
   // CHECK:   return self.operation.results[2]
   let results = (outs I32:$i32, AnyFloat, I64:$i64);
 }
@@ -590,20 +590,20 @@ def SimpleOp : TestOp<"simple"> {
   // CHECK:     successors=_ods_successors, regions=regions, loc=loc, ip=ip)
 
   // CHECK: @builtins.property
-  // CHECK: def i32(self) -> _ods_ir.Value:
+  // CHECK: def i32(self) -> _ods_ir.Value[_ods_ir.IntegerType]:
   // CHECK:   return self.operation.operands[0]
   //
   // CHECK: @builtins.property
-  // CHECK: def f32(self) -> _ods_ir.Value:
+  // CHECK: def f32(self) -> _ods_ir.Value[_ods_ir.FloatType]:
   // CHECK:   return self.operation.operands[1]
   let arguments = (ins I32:$i32, F32:$f32);
 
   // CHECK: @builtins.property
-  // CHECK: def i64(self) -> _ods_ir.OpResult:
+  // CHECK: def i64(self) -> _ods_ir.OpResult[_ods_ir.IntegerType]:
   // CHECK:   return self.operation.results[0]
   //
   // CHECK: @builtins.property
-  // CHECK: def f64(self) -> _ods_ir.OpResult:
+  // CHECK: def f64(self) -> _ods_ir.OpResult[_ods_ir.FloatType]:
   // CHECK:   return self.operation.results[1]
   let results = (outs I64:$i64, AnyFloat:$f64);
 }
diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py
index 1194e32c960c8..f0f74ebc12155 100644
--- a/mlir/test/python/dialects/python_test.py
+++ b/mlir/test/python/dialects/python_test.py
@@ -554,7 +554,7 @@ def testOptionalOperandOp():
             )
             assert (
                 typing.get_type_hints(test.OptionalOperandOp.result.fget)["return"]
-                is OpResult
+                == OpResult[IntegerType]
             )
             assert type(op1.result) is OpResult
 
@@ -662,6 +662,13 @@ def testCustomType():
             raise
 
 
+@run
+# CHECK-LABEL: TEST: testValue
+def testValue():
+    # Check that Value is a generic class at runtime.
+    assert hasattr(Value, "__class_getitem__")
+
+
 @run
 # CHECK-LABEL: TEST: testTensorValue
 def testTensorValue():
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 0172b3fa38a6b..f01563fd49d17 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -341,6 +341,22 @@ static std::string attrSizedTraitForKind(const char *kind) {
                  StringRef(kind).drop_front());
 }
 
+static StringRef getPythonType(StringRef cppType) {
+  return llvm::StringSwitch<StringRef>(cppType)
+      .Case("::mlir::MemRefType", "_ods_ir.MemRefType")
+      .Case("::mlir::UnrankedMemRefType", "_ods_ir.UnrankedMemRefType")
+      .Case("::mlir::RankedTensorType", "_ods_ir.RankedTensorType")
+      .Case("::mlir::UnrankedTensorType", "_ods_ir.UnrankedTensorType")
+      .Case("::mlir::VectorType", "_ods_ir.VectorType")
+      .Case("::mlir::IntegerType", "_ods_ir.IntegerType")
+      .Case("::mlir::FloatType", "_ods_ir.FloatType")
+      .Case("::mlir::IndexType", "_ods_ir.IndexType")
+      .Case("::mlir::ComplexType", "_ods_ir.ComplexType")
+      .Case("::mlir::TupleType", "_ods_ir.TupleType")
+      .Case("::mlir::NoneType", "_ods_ir.NoneType")
+      .Default(StringRef());
+}
+
 /// Emits accessors to "elements" of an Op definition. Currently, the supported
 /// elements are operands and results, indicated by `kind`, which must be either
 /// `operand` or `result` and is used verbatim in the emitted code.
@@ -370,8 +386,11 @@ static void emitElementAccessors(
         seenVariableLength = true;
       if (element.name.empty())
         continue;
-      const char *type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.Value"
+      std::string type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.Value"
                                                            : "_ods_ir.OpResult";
+      if (StringRef pythonType = getPythonType(element.constraint.getCppType());
+          !pythonType.empty())
+        type = llvm::formatv("{0}[{1}]", type, pythonType);
       if (element.isVariableLength()) {
         if (element.isOptional()) {
           os << formatv(opOneOptionalTemplate, sanitizeName(element.name), kind,
@@ -418,6 +437,11 @@ static void emitElementAccessors(
           type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.Value"
                                                    : "_ods_ir.OpResult";
         }
+        if (std::strcmp(kind, "operand") == 0) {
+          StringRef pythonType = getPythonType(element.constraint.getCppType());
+          if (!pythonType.empty())
+            type += "[" + pythonType.str() + "]";
+        }
         os << formatv(opVariadicEqualPrefixTemplate, sanitizeName(element.name),
                       kind, numSimpleLength, numVariadicGroups,
                       numPrecedingSimple, numPrecedingVariadic, type);
@@ -449,6 +473,11 @@ static void emitElementAccessors(
       if (!element.isVariableLength() || element.isOptional()) {
         type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.Value"
                                                  : "_ods_ir.OpResult";
+        if (std::strcmp(kind, "operand") == 0) {
+          StringRef pythonType = getPythonType(element.constraint.getCppType());
+          if (!pythonType.empty())
+            type += "[" + pythonType.str() + "]";
+        }
         if (!element.isVariableLength()) {
           trailing = "[0]";
         } else if (element.isOptional()) {



More information about the Mlir-commits mailing list