[Mlir-commits] [mlir] [MLIR][Python] add type hints for accessors (PR #158455)

Maksim Levental llvmlistbot at llvm.org
Wed Sep 17 14:06:48 PDT 2025


https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/158455

>From dca51ee38f4602974e33101849d3cc9383b28252 Mon Sep 17 00:00:00 2001
From: Maksim Levental <maksim.levental at gmail.com>
Date: Sun, 14 Sep 2025 01:01:30 -0400
Subject: [PATCH 1/3] [MLIR][Python] add type hints for accessors

---
 mlir/test/mlir-tblgen/op-python-bindings.td   | 70 ++++++++--------
 mlir/test/python/dialects/python_test.py      | 27 +++++-
 mlir/test/python/python_test_ops.td           |  5 ++
 mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp | 84 +++++++++++++------
 4 files changed, 122 insertions(+), 64 deletions(-)

diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td
index 90feec9ed8d6b..432cfaef4d7d9 100644
--- a/mlir/test/mlir-tblgen/op-python-bindings.td
+++ b/mlir/test/mlir-tblgen/op-python-bindings.td
@@ -36,7 +36,7 @@ def AttrSizedOperandsOp : TestOp<"attr_sized_operands",
   // CHECK:     successors=_ods_successors, regions=regions, loc=loc, ip=ip)
 
   // CHECK: @builtins.property
-  // CHECK: def variadic1(self):
+  // CHECK: def variadic1(self) -> _ods_ir.OpOperandList:
   // CHECK:   operand_range = _ods_segmented_accessor(
   // CHECK:       self.operation.operands,
   // CHECK:       self.operation.attributes["operandSegmentSizes"], 0)
@@ -44,14 +44,14 @@ def AttrSizedOperandsOp : TestOp<"attr_sized_operands",
   // CHECK-NOT: if len(operand_range)
   //
   // CHECK: @builtins.property
-  // CHECK: def non_variadic(self):
+  // CHECK: def non_variadic(self) -> _ods_ir.Value:
   // CHECK:   operand_range = _ods_segmented_accessor(
   // CHECK:       self.operation.operands,
   // CHECK:       self.operation.attributes["operandSegmentSizes"], 1)
   // CHECK:   return operand_range[0]
   //
   // CHECK: @builtins.property
-  // CHECK: def variadic2(self):
+  // CHECK: def variadic2(self) -> _Optional[_ods_ir.Value]:
   // CHECK:   operand_range = _ods_segmented_accessor(
   // CHECK:       self.operation.operands,
   // CHECK:       self.operation.attributes["operandSegmentSizes"], 2)
@@ -84,21 +84,21 @@ def AttrSizedResultsOp : TestOp<"attr_sized_results",
   // CHECK:     successors=_ods_successors, regions=regions, loc=loc, ip=ip)
 
   // CHECK: @builtins.property
-  // CHECK: def variadic1(self):
+  // CHECK: def variadic1(self) -> _Optional[_ods_ir.OpResult]:
   // CHECK:   result_range = _ods_segmented_accessor(
   // CHECK:       self.operation.results,
   // CHECK:       self.operation.attributes["resultSegmentSizes"], 0)
   // CHECK:   return result_range[0] if len(result_range) > 0 else None
   //
   // CHECK: @builtins.property
-  // CHECK: def non_variadic(self):
+  // CHECK: def non_variadic(self) -> _ods_ir.OpResult:
   // CHECK:   result_range = _ods_segmented_accessor(
   // CHECK:       self.operation.results,
   // CHECK:       self.operation.attributes["resultSegmentSizes"], 1)
   // CHECK:   return result_range[0]
   //
   // CHECK: @builtins.property
-  // CHECK: def variadic2(self):
+  // CHECK: def variadic2(self) -> _ods_ir.OpResultList:
   // CHECK:   result_range = _ods_segmented_accessor(
   // CHECK:       self.operation.results,
   // CHECK:       self.operation.attributes["resultSegmentSizes"], 2)
@@ -139,21 +139,21 @@ def AttributedOp : TestOp<"attributed_op"> {
   // CHECK:     successors=_ods_successors, regions=regions, loc=loc, ip=ip)
 
   // CHECK: @builtins.property
-  // CHECK: def i32attr(self):
+  // CHECK: def i32attr(self) -> _ods_ir.Attribute:
   // CHECK:   return self.operation.attributes["i32attr"]
 
   // CHECK: @builtins.property
-  // CHECK: def optionalF32Attr(self):
+  // CHECK: def optionalF32Attr(self) -> _Optional[_ods_ir.Attribute]:
   // CHECK:   if "optionalF32Attr" not in self.operation.attributes:
   // CHECK:     return None
   // CHECK:   return self.operation.attributes["optionalF32Attr"]
 
   // CHECK: @builtins.property
-  // CHECK: def unitAttr(self):
+  // CHECK: def unitAttr(self) -> bool:
   // CHECK:   return "unitAttr" in self.operation.attributes
 
   // CHECK: @builtins.property
-  // CHECK: def in_(self):
+  // CHECK: def in_(self) -> _ods_ir.Attribute:
   // CHECK:   return self.operation.attributes["in"]
 
   let arguments = (ins I32Attr:$i32attr, OptionalAttr<F32Attr>:$optionalF32Attr,
@@ -186,11 +186,11 @@ def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> {
   // CHECK:     successors=_ods_successors, regions=regions, loc=loc, ip=ip)
 
   // CHECK: @builtins.property
-  // CHECK: def in_(self):
+  // CHECK: def in_(self) -> bool:
   // CHECK:   return "in" in self.operation.attributes
 
   // CHECK: @builtins.property
-  // CHECK: def is_(self):
+  // CHECK: def is_(self) -> _Optional[_ods_ir.Attribute]:
   // CHECK:   if "is" not in self.operation.attributes:
   // CHECK:     return None
   // CHECK:   return self.operation.attributes["is"]
@@ -322,16 +322,16 @@ def MissingNamesOp : TestOp<"missing_names"> {
   // CHECK:     successors=_ods_successors, regions=regions, loc=loc, ip=ip)
 
   // CHECK: @builtins.property
-  // CHECK: def f32(self):
+  // CHECK: def f32(self) -> _ods_ir.Value:
   // CHECK:   return self.operation.operands[1]
   let arguments = (ins I32, F32:$f32, I64);
 
   // CHECK: @builtins.property
-  // CHECK: def i32(self):
+  // CHECK: def i32(self) -> _ods_ir.OpResult:
   // CHECK:   return self.operation.results[0]
   //
   // CHECK: @builtins.property
-  // CHECK: def i64(self):
+  // CHECK: def i64(self) -> _ods_ir.OpResult:
   // CHECK:   return self.operation.results[2]
   let results = (outs I32:$i32, AnyFloat, I64:$i64);
 }
@@ -360,11 +360,11 @@ def OneOptionalOperandOp : TestOp<"one_optional_operand"> {
   // CHECK:     successors=_ods_successors, regions=regions, loc=loc, ip=ip)
 
   // CHECK: @builtins.property
-  // CHECK: def non_optional(self):
+  // CHECK: def non_optional(self) -> _ods_ir.Value:
   // CHECK:   return self.operation.operands[0]
 
   // CHECK: @builtins.property
-  // CHECK: def optional(self):
+  // CHECK: def optional(self) -> _Optional[_ods_ir.Value]:
   // CHECK:   return None if len(self.operation.operands) < 2 else self.operation.operands[1]
 }
 
@@ -391,11 +391,11 @@ def OneVariadicOperandOp : TestOp<"one_variadic_operand"> {
   // CHECK:     successors=_ods_successors, regions=regions, loc=loc, ip=ip)
 
   // CHECK: @builtins.property
-  // CHECK: def non_variadic(self):
+  // CHECK: def non_variadic(self) -> _ods_ir.Value:
   // CHECK:   return self.operation.operands[0]
   //
   // CHECK: @builtins.property
-  // CHECK: def variadic(self):
+  // CHECK: def variadic(self) -> _ods_ir.OpOperandList:
   // CHECK:   _ods_variadic_group_length = len(self.operation.operands) - 2 + 1
   // CHECK:   return self.operation.operands[1:1 + _ods_variadic_group_length]
   let arguments = (ins AnyType:$non_variadic, Variadic<AnyType>:$variadic);
@@ -424,12 +424,12 @@ def OneVariadicResultOp : TestOp<"one_variadic_result"> {
   // CHECK:     successors=_ods_successors, regions=regions, loc=loc, ip=ip)
 
   // CHECK: @builtins.property
-  // CHECK: def variadic(self):
+  // CHECK: def variadic(self) -> _ods_ir.OpResultList:
   // CHECK:   _ods_variadic_group_length = len(self.operation.results) - 2 + 1
   // CHECK:   return self.operation.results[0:0 + _ods_variadic_group_length]
   //
   // CHECK: @builtins.property
-  // CHECK: def non_variadic(self):
+  // CHECK: def non_variadic(self) -> _ods_ir.OpResult:
   // CHECK:   _ods_variadic_group_length = len(self.operation.results) - 2 + 1
   // CHECK:   return self.operation.results[1 + _ods_variadic_group_length - 1]
   let results = (outs Variadic<AnyType>:$variadic, AnyType:$non_variadic);
@@ -456,7 +456,7 @@ def PythonKeywordOp : TestOp<"python_keyword"> {
   // CHECK:     successors=_ods_successors, regions=regions, loc=loc, ip=ip)
 
   // CHECK: @builtins.property
-  // CHECK: def in_(self):
+  // CHECK: def in_(self) -> _ods_ir.Value:
   // CHECK:   return self.operation.operands[0]
   let arguments = (ins AnyType:$in);
 }
@@ -495,17 +495,17 @@ def SameResultsVariadicOp : TestOp<"same_results_variadic", [SameOperandsAndResu
 def SameVariadicOperandSizeOp : TestOp<"same_variadic_operand",
                                        [SameVariadicOperandSize]> {
   // CHECK: @builtins.property
-  // CHECK: def variadic1(self):
+  // CHECK: def variadic1(self) -> _ods_ir.OpOperandList:
   // CHECK:   start, elements_per_group = _ods_equally_sized_accessor(self.operation.operands, 1, 2, 0, 0)
   // CHECK:   return self.operation.operands[start:start + elements_per_group]
   //
   // CHECK: @builtins.property
-  // CHECK: def non_variadic(self):
+  // CHECK: def non_variadic(self) -> _ods_ir.Value:
   // CHECK:   start, elements_per_group = _ods_equally_sized_accessor(self.operation.operands, 1, 2, 0, 1)
   // CHECK:   return self.operation.operands[start]
   //
   // CHECK: @builtins.property
-  // CHECK: def variadic2(self):
+  // CHECK: def variadic2(self) -> _ods_ir.OpOperandList:
   // CHECK:   start, elements_per_group = _ods_equally_sized_accessor(self.operation.operands, 1, 2, 1, 1)
   // CHECK:   return self.operation.operands[start:start + elements_per_group]
   let arguments = (ins Variadic<AnyType>:$variadic1, AnyType:$non_variadic,
@@ -521,17 +521,17 @@ def SameVariadicOperandSizeOp : TestOp<"same_variadic_operand",
 def SameVariadicResultSizeOp : TestOp<"same_variadic_result",
                                       [SameVariadicResultSize]> {
   // CHECK: @builtins.property
-  // CHECK: def variadic1(self):
+  // CHECK: def variadic1(self) -> _ods_ir.OpResultList:
   // CHECK:   start, elements_per_group = _ods_equally_sized_accessor(self.operation.results, 1, 2, 0, 0)
   // CHECK:   return self.operation.results[start:start + elements_per_group]
   //
   // CHECK: @builtins.property
-  // CHECK: def non_variadic(self):
+  // CHECK: def non_variadic(self) -> _ods_ir.OpResult:
   // CHECK:   start, elements_per_group = _ods_equally_sized_accessor(self.operation.results, 1, 2, 0, 1)
   // CHECK:   return self.operation.results[start]
   //
   // CHECK: @builtins.property
-  // CHECK: def variadic2(self):
+  // CHECK: def variadic2(self) -> _ods_ir.OpResultList:
   // CHECK:   start, elements_per_group = _ods_equally_sized_accessor(self.operation.results, 1, 2, 1, 1)
   // CHECK:   return self.operation.results[start:start + elements_per_group]
   let results = (outs Variadic<AnyType>:$variadic1, AnyType:$non_variadic,
@@ -562,20 +562,20 @@ def SimpleOp : TestOp<"simple"> {
   // CHECK:     successors=_ods_successors, regions=regions, loc=loc, ip=ip)
 
   // CHECK: @builtins.property
-  // CHECK: def i32(self):
+  // CHECK: def i32(self) -> _ods_ir.Value:
   // CHECK:   return self.operation.operands[0]
   //
   // CHECK: @builtins.property
-  // CHECK: def f32(self):
+  // CHECK: def f32(self) -> _ods_ir.Value:
   // CHECK:   return self.operation.operands[1]
   let arguments = (ins I32:$i32, F32:$f32);
 
   // CHECK: @builtins.property
-  // CHECK: def i64(self):
+  // CHECK: def i64(self) -> _ods_ir.OpResult:
   // CHECK:   return self.operation.results[0]
   //
   // CHECK: @builtins.property
-  // CHECK: def f64(self):
+  // CHECK: def f64(self) -> _ods_ir.OpResult:
   // CHECK:   return self.operation.results[1]
   let results = (outs I64:$i64, AnyFloat:$f64);
 }
@@ -600,11 +600,11 @@ def VariadicAndNormalRegionOp : TestOp<"variadic_and_normal_region"> {
   let regions = (region AnyRegion:$region, AnyRegion, VariadicRegion<AnyRegion>:$variadic);
 
   // CHECK:  @builtins.property
-  // CHECK:  def region(self):
+  // CHECK:  def region(self) -> _ods_ir.Region:
   // CHECK:    return self.regions[0]
 
   // CHECK:  @builtins.property
-  // CHECK:  def variadic(self):
+  // CHECK:  def variadic(self) -> _ods_ir.RegionSequence:
   // CHECK:    return self.regions[2:]
 }
 
@@ -628,7 +628,7 @@ def VariadicRegionOp : TestOp<"variadic_region"> {
   let regions = (region VariadicRegion<AnyRegion>:$Variadic);
 
   // CHECK:  @builtins.property
-  // CHECK:  def Variadic(self):
+  // CHECK:  def Variadic(self) -> _ods_ir.RegionSequence:
   // CHECK:    return self.regions[0:]
 }
 
diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py
index 17aaef7e1b9f4..db822c641bf47 100644
--- a/mlir/test/python/dialects/python_test.py
+++ b/mlir/test/python/dialects/python_test.py
@@ -629,11 +629,17 @@ def values(lst):
                 [zero, one], two, [three, four]
             )
             # CHECK: Value(%{{.*}} = arith.constant 2 : i32)
-            print(variadic_operands.non_variadic)
+            non_variadic = variadic_operands.non_variadic
+            print(non_variadic)
+            assert isinstance(non_variadic, Value)
             # CHECK: ['Value(%{{.*}} = arith.constant 0 : i32)', 'Value(%{{.*}} = arith.constant 1 : i32)']
-            print(values(variadic_operands.variadic1))
+            variadic1 = variadic_operands.variadic1
+            print(values(variadic1))
+            assert isinstance(variadic1, OpOperandList)
             # CHECK: ['Value(%{{.*}} = arith.constant 3 : i32)', 'Value(%{{.*}} = arith.constant 4 : i32)']
-            print(values(variadic_operands.variadic2))
+            variadic2 = variadic_operands.variadic2
+            print(values(variadic2))
+            assert isinstance(variadic2, OpOperandList)
 
             assert (
                 inspect.signature(test.same_variadic_operand).return_annotation
@@ -692,7 +698,9 @@ def types(lst):
             # CHECK: i1
             print(op.non_variadic2.type)
             # CHECK: [IntegerType(i2), IntegerType(i3), IntegerType(i4)]
-            print(types(op.variadic))
+            variadic = op.variadic
+            print(types(variadic))
+            assert isinstance(variadic, OpResultList)
 
             #  Test Variadic-Variadic-Fixed
             op = test.SameVariadicResultSizeOpVVF(
@@ -754,3 +762,14 @@ def types(lst):
                 test.results_variadic([i[0]]),
                 OpResult,
             )
+
+
+# CHECK-LABEL: TEST: testVariadicAndNormalRegion
+@run
+def testVariadicAndNormalRegionOp():
+    with Context() as ctx, Location.unknown(ctx):
+        module = Module.create()
+        with InsertionPoint(module.body):
+            region_op = test.VariadicAndNormalRegionOp(2)
+            assert isinstance(region_op.region, Region)
+            assert isinstance(region_op.variadic, RegionSequence)
diff --git a/mlir/test/python/python_test_ops.td b/mlir/test/python/python_test_ops.td
index 1e94b94dc714b..cfc1d72bb479d 100644
--- a/mlir/test/python/python_test_ops.td
+++ b/mlir/test/python/python_test_ops.td
@@ -269,4 +269,9 @@ def ResultsVariadicOp : TestOp<"results_variadic"> {
   let results = (outs Variadic<AnyType>:$res);
 }
 
+def VariadicAndNormalRegionOp : TestOp<"variadic_and_normal_region"> {
+  let regions = (region AnyRegion:$region, VariadicRegion<AnyRegion>:$variadic);
+  let assemblyFormat = "$region $variadic attr-dict";
+}
+
 #endif // PYTHON_TEST_OPS
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 21f712e85e6c0..2e33581fbbff8 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -43,7 +43,7 @@ _ods_ir = _ods_cext.ir
 _ods_cext.globals.register_traceback_file_exclusion(__file__)
 
 import builtins
-from typing import Sequence as _Sequence, Union as _Union
+from typing import Sequence as _Sequence, Union as _Union, Optional as _Optional
 
 )Py";
 
@@ -92,9 +92,10 @@ constexpr const char *opClassRegionSpecTemplate = R"Py(
 ///   {0} is the name of the accessor;
 ///   {1} is either 'operand' or 'result';
 ///   {2} is the position in the element list.
+///   {3} is the type hint.
 constexpr const char *opSingleTemplate = R"Py(
   @builtins.property
-  def {0}(self):
+  def {0}(self) -> {3}:
     return self.operation.{1}s[{2}]
 )Py";
 
@@ -103,11 +104,12 @@ constexpr const char *opSingleTemplate = R"Py(
 ///   {1} is either 'operand' or 'result';
 ///   {2} is the total number of element groups;
 ///   {3} is the position of the current group in the group list.
+///   {4} is the type hint.
 /// This works for both a single variadic group (non-negative length) and an
 /// single optional element (zero length if the element is absent).
 constexpr const char *opSingleAfterVariableTemplate = R"Py(
   @builtins.property
-  def {0}(self):
+  def {0}(self) -> {4}:
     _ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1
     return self.operation.{1}s[{3} + _ods_variadic_group_length - 1]
 )Py";
@@ -117,12 +119,13 @@ constexpr const char *opSingleAfterVariableTemplate = R"Py(
 ///   {1} is either 'operand' or 'result';
 ///   {2} is the total number of element groups;
 ///   {3} is the position of the current group in the group list.
+///   {4} is the type hint.
 /// This works if we have only one variable-length group (and it's the optional
 /// operand/result): we can deduce it's absent if the `len(operation.{1}s)` is
 /// smaller than the total number of groups.
 constexpr const char *opOneOptionalTemplate = R"Py(
   @builtins.property
-  def {0}(self):
+  def {0}(self) -> _Optional[{4}]:
     return None if len(self.operation.{1}s) < {2} else self.operation.{1}s[{3}]
 )Py";
 
@@ -131,9 +134,10 @@ constexpr const char *opOneOptionalTemplate = R"Py(
 ///   {1} is either 'operand' or 'result';
 ///   {2} is the total number of element groups;
 ///   {3} is the position of the current group in the group list.
+///   {4} is the type hint.
 constexpr const char *opOneVariadicTemplate = R"Py(
   @builtins.property
-  def {0}(self):
+  def {0}(self) -> {4}:
     _ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1
     return self.operation.{1}s[{3}:{3} + _ods_variadic_group_length]
 )Py";
@@ -145,9 +149,10 @@ constexpr const char *opOneVariadicTemplate = R"Py(
 ///   {3} is the total number of variadic groups;
 ///   {4} is the number of non-variadic groups preceding the current group;
 ///   {5} is the number of variadic groups preceding the current group.
+///   {6} is the type hint.
 constexpr const char *opVariadicEqualPrefixTemplate = R"Py(
   @builtins.property
-  def {0}(self):
+  def {0}(self) -> {6}:
     start, elements_per_group = _ods_equally_sized_accessor(self.operation.{1}s, {2}, {3}, {4}, {5}))Py";
 
 /// Second part of the template for equally-sized case, accessing a single
@@ -170,9 +175,10 @@ constexpr const char *opVariadicEqualVariadicTemplate = R"Py(
 ///   {2} is the position of the group in the group list;
 ///   {3} is a return suffix (expected [0] for single-element, empty for
 ///       variadic, and opVariadicSegmentOptionalTrailingTemplate for optional).
+///   {4} is the type hint.
 constexpr const char *opVariadicSegmentTemplate = R"Py(
   @builtins.property
-  def {0}(self):
+  def {0}(self) -> {4}:
     {1}_range = _ods_segmented_accessor(
          self.operation.{1}s,
          self.operation.attributes["{1}SegmentSizes"], {2})
@@ -190,7 +196,7 @@ constexpr const char *opVariadicSegmentOptionalTrailingTemplate =
 ///   {1} is the original name of the attribute.
 constexpr const char *attributeGetterTemplate = R"Py(
   @builtins.property
-  def {0}(self):
+  def {0}(self) -> _ods_ir.Attribute:
     return self.operation.attributes["{1}"]
 )Py";
 
@@ -199,7 +205,7 @@ constexpr const char *attributeGetterTemplate = R"Py(
 ///   {1} is the original name of the attribute.
 constexpr const char *optionalAttributeGetterTemplate = R"Py(
   @builtins.property
-  def {0}(self):
+  def {0}(self) -> _Optional[_ods_ir.Attribute]:
     if "{1}" not in self.operation.attributes:
       return None
     return self.operation.attributes["{1}"]
@@ -212,7 +218,7 @@ constexpr const char *optionalAttributeGetterTemplate = R"Py(
 ///    {1} is the original name of the attribute.
 constexpr const char *unitAttributeGetterTemplate = R"Py(
   @builtins.property
-  def {0}(self):
+  def {0}(self) -> bool:
     return "{1}" in self.operation.attributes
 )Py";
 
@@ -265,7 +271,7 @@ constexpr const char *attributeDeleterTemplate = R"Py(
 
 constexpr const char *regionAccessorTemplate = R"Py(
   @builtins.property
-  def {0}(self):
+  def {0}(self) -> {2}:
     return self.regions[{1}]
 )Py";
 
@@ -357,15 +363,24 @@ static void emitElementAccessors(
         seenVariableLength = true;
       if (element.name.empty())
         continue;
+      const char *type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.Value"
+                                                           : "_ods_ir.OpResult";
       if (element.isVariableLength()) {
-        os << formatv(element.isOptional() ? opOneOptionalTemplate
-                                           : opOneVariadicTemplate,
-                      sanitizeName(element.name), kind, numElements, i);
+        if (element.isOptional()) {
+          os << formatv(opOneOptionalTemplate, sanitizeName(element.name), kind,
+                        numElements, i, type);
+        } else {
+          type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.OpOperandList"
+                                                   : "_ods_ir.OpResultList";
+          os << formatv(opOneVariadicTemplate, sanitizeName(element.name), kind,
+                        numElements, i, type);
+        }
       } else if (seenVariableLength) {
         os << formatv(opSingleAfterVariableTemplate, sanitizeName(element.name),
-                      kind, numElements, i);
+                      kind, numElements, i, type);
       } else {
-        os << formatv(opSingleTemplate, sanitizeName(element.name), kind, i);
+        os << formatv(opSingleTemplate, sanitizeName(element.name), kind, i,
+                      type);
       }
     }
     return;
@@ -388,9 +403,17 @@ static void emitElementAccessors(
     for (unsigned i = 0; i < numElements; ++i) {
       const NamedTypeConstraint &element = getElement(op, i);
       if (!element.name.empty()) {
+        std::string type;
+        if (element.isVariableLength()) {
+          type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.OpOperandList"
+                                                   : "_ods_ir.OpResultList";
+        } else {
+          type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.Value"
+                                                   : "_ods_ir.OpResult";
+        }
         os << formatv(opVariadicEqualPrefixTemplate, sanitizeName(element.name),
                       kind, numSimpleLength, numVariadicGroups,
-                      numPrecedingSimple, numPrecedingVariadic);
+                      numPrecedingSimple, numPrecedingVariadic, type);
         os << formatv(element.isVariableLength()
                           ? opVariadicEqualVariadicTemplate
                           : opVariadicEqualSimpleTemplate,
@@ -413,13 +436,23 @@ static void emitElementAccessors(
       if (element.name.empty())
         continue;
       std::string trailing;
-      if (!element.isVariableLength())
-        trailing = "[0]";
-      else if (element.isOptional())
-        trailing = std::string(
-            formatv(opVariadicSegmentOptionalTrailingTemplate, kind));
+      std::string type = std::strcmp(kind, "operand") == 0
+                             ? "_ods_ir.OpOperandList"
+                             : "_ods_ir.OpResultList";
+      if (!element.isVariableLength() || element.isOptional()) {
+        type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.Value"
+                                                 : "_ods_ir.OpResult";
+        if (!element.isVariableLength()) {
+          trailing = "[0]";
+        } else if (element.isOptional()) {
+          type = "_Optional[" + type + "]";
+          trailing = std::string(
+              formatv(opVariadicSegmentOptionalTrailingTemplate, kind));
+        }
+      }
+
       os << formatv(opVariadicSegmentTemplate, sanitizeName(element.name), kind,
-                    i, trailing);
+                    i, trailing, type);
     }
     return;
   }
@@ -980,8 +1013,9 @@ static void emitRegionAccessors(const Operator &op, raw_ostream &os) {
     assert((!region.isVariadic() || en.index() == op.getNumRegions() - 1) &&
            "expected only the last region to be variadic");
     os << formatv(regionAccessorTemplate, sanitizeName(region.name),
-                  std::to_string(en.index()) +
-                      (region.isVariadic() ? ":" : ""));
+                  std::to_string(en.index()) + (region.isVariadic() ? ":" : ""),
+                  region.isVariadic() ? "_ods_ir.RegionSequence"
+                                      : "_ods_ir.Region");
   }
 }
 

>From c2bcd9c1d7401e23197f49c278f00fefb4848a39 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Mon, 15 Sep 2025 12:23:14 -0700
Subject: [PATCH 2/3] add concrete type hints

---
 mlir/lib/Bindings/Python/IRAttributes.cpp     |   6 +-
 mlir/test/mlir-tblgen/op-python-bindings.td   |   8 +-
 mlir/test/python/dialects/python_test.py      | 156 ++++++++++++++++++
 mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp |  89 +++++++++-
 4 files changed, 244 insertions(+), 15 deletions(-)

diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index b3c768846c74f..4e0ade41fb708 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -1742,9 +1742,9 @@ nb::object integerOrBoolAttributeCaster(PyAttribute &pyAttribute) {
     return nb::cast(PyBoolAttribute(pyAttribute));
   if (PyIntegerAttribute::isaFunction(pyAttribute))
     return nb::cast(PyIntegerAttribute(pyAttribute));
-  std::string msg =
-      std::string("Can't cast unknown element type DenseArrayAttr (") +
-      nb::cast<std::string>(nb::repr(nb::cast(pyAttribute))) + ")";
+  std::string msg = std::string("Can't cast unknown attribute type Attr (") +
+                    nb::cast<std::string>(nb::repr(nb::cast(pyAttribute))) +
+                    ")";
   throw nb::type_error(msg.c_str());
 }
 
diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td
index 432cfaef4d7d9..eee09f9d5cedf 100644
--- a/mlir/test/mlir-tblgen/op-python-bindings.td
+++ b/mlir/test/mlir-tblgen/op-python-bindings.td
@@ -139,11 +139,11 @@ def AttributedOp : TestOp<"attributed_op"> {
   // CHECK:     successors=_ods_successors, regions=regions, loc=loc, ip=ip)
 
   // CHECK: @builtins.property
-  // CHECK: def i32attr(self) -> _ods_ir.Attribute:
+  // CHECK: def i32attr(self) -> _ods_ir.IntegerAttr:
   // CHECK:   return self.operation.attributes["i32attr"]
 
   // CHECK: @builtins.property
-  // CHECK: def optionalF32Attr(self) -> _Optional[_ods_ir.Attribute]:
+  // CHECK: def optionalF32Attr(self) -> _Optional[_ods_ir.FloatAttr]:
   // CHECK:   if "optionalF32Attr" not in self.operation.attributes:
   // CHECK:     return None
   // CHECK:   return self.operation.attributes["optionalF32Attr"]
@@ -153,7 +153,7 @@ def AttributedOp : TestOp<"attributed_op"> {
   // CHECK:   return "unitAttr" in self.operation.attributes
 
   // CHECK: @builtins.property
-  // CHECK: def in_(self) -> _ods_ir.Attribute:
+  // CHECK: def in_(self) -> _ods_ir.IntegerAttr:
   // CHECK:   return self.operation.attributes["in"]
 
   let arguments = (ins I32Attr:$i32attr, OptionalAttr<F32Attr>:$optionalF32Attr,
@@ -190,7 +190,7 @@ def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> {
   // CHECK:   return "in" in self.operation.attributes
 
   // CHECK: @builtins.property
-  // CHECK: def is_(self) -> _Optional[_ods_ir.Attribute]:
+  // CHECK: def is_(self) -> _Optional[_ods_ir.FloatAttr]:
   // CHECK:   if "is" not in self.operation.attributes:
   // CHECK:     return None
   // CHECK:   return self.operation.attributes["is"]
diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py
index db822c641bf47..9686cc61bb11c 100644
--- a/mlir/test/python/dialects/python_test.py
+++ b/mlir/test/python/dialects/python_test.py
@@ -2,6 +2,7 @@
 # RUN: %PYTHON %s nanobind | FileCheck %s
 import inspect
 import sys
+import typing
 from typing import Union
 
 from mlir.ir import *
@@ -233,6 +234,161 @@ def attrBuilder():
         op.verify()
         op.print(use_local_scope=True)
 
+    # fmt: off
+    assert typing.get_type_hints(test.AttributesOp.x_affinemaparr.fset)["value"] is ArrayAttr
+    assert type(op.x_affinemaparr) is typing.get_type_hints(test.AttributesOp.x_affinemaparr.fget)["return"]
+
+    assert typing.get_type_hints(test.AttributesOp.x_affinemap.fset)["value"] is AffineMapAttr
+    assert type(op.x_affinemap) is typing.get_type_hints(test.AttributesOp.x_affinemap.fget)["return"]
+
+    assert typing.get_type_hints(test.AttributesOp.x_arr.fset)["value"] is ArrayAttr
+    assert type(op.x_arr) is typing.get_type_hints(test.AttributesOp.x_arr.fget)["return"]
+
+    assert typing.get_type_hints(test.AttributesOp.x_boolarr.fset)["value"] is ArrayAttr
+    assert type(op.x_boolarr) is typing.get_type_hints(test.AttributesOp.x_boolarr.fget)["return"]
+
+    assert typing.get_type_hints(test.AttributesOp.x_bool.fset)["value"] is BoolAttr
+    assert type(op.x_bool) is typing.get_type_hints(test.AttributesOp.x_bool.fget)["return"]
+
+    assert typing.get_type_hints(test.AttributesOp.x_dboolarr.fset)["value"] is DenseBoolArrayAttr
+    assert type(op.x_dboolarr) is typing.get_type_hints(test.AttributesOp.x_dboolarr.fget)["return"]
+
+    assert typing.get_type_hints(test.AttributesOp.x_df32arr.fset)["value"] is DenseF32ArrayAttr
+    assert type(op.x_df32arr) is typing.get_type_hints(test.AttributesOp.x_df32arr.fget)["return"]
+
+    assert typing.get_type_hints(test.AttributesOp.x_df64arr.fset)["value"] is DenseF64ArrayAttr
+    assert type(op.x_df64arr) is typing.get_type_hints(test.AttributesOp.x_df64arr.fget)["return"]
+
+    assert typing.get_type_hints(test.AttributesOp.x_df16arr.fset)["value"] is DenseI16ArrayAttr
+    assert type(op.x_df16arr) is typing.get_type_hints(test.AttributesOp.x_df16arr.fget)["return"]
+
+    assert typing.get_type_hints(test.AttributesOp.x_di32arr.fset)["value"] is DenseI32ArrayAttr
+    assert type(op.x_di32arr) is typing.get_type_hints(test.AttributesOp.x_di32arr.fget)["return"]
+
+    assert typing.get_type_hints(test.AttributesOp.x_di64arr.fset)["value"] is DenseI64ArrayAttr
+    assert type(op.x_di64arr) is typing.get_type_hints(test.AttributesOp.x_di64arr.fget)["return"]
+
+    assert typing.get_type_hints(test.AttributesOp.x_di8arr.fset)["value"] is DenseI8ArrayAttr
+    assert type(op.x_di8arr) is typing.get_type_hints(test.AttributesOp.x_di8arr.fget)["return"]
+
+    assert typing.get_type_hints(test.AttributesOp.x_dictarr.fset)["value"] is ArrayAttr
+    assert type(op.x_dictarr) is typing.get_type_hints(test.AttributesOp.x_dictarr.fget)["return"]
+
+    assert typing.get_type_hints(test.AttributesOp.x_dict.fset)["value"] is DictAttr
+    assert type(op.x_dict) is typing.get_type_hints(test.AttributesOp.x_dict.fget)["return"]
+
+    assert typing.get_type_hints(test.AttributesOp.x_f32arr.fset)["value"] is ArrayAttr
+    assert type(op.x_f32arr) is typing.get_type_hints(test.AttributesOp.x_f32arr.fget)["return"]
+
+    assert typing.get_type_hints(test.AttributesOp.x_f32.fset)["value"] is FloatAttr
+    assert type(op.x_f32) is typing.get_type_hints(test.AttributesOp.x_f32.fget)["return"]
+
+    assert typing.get_type_hints(test.AttributesOp.x_f64arr.fset)["value"] is ArrayAttr
+    assert type(op.x_f64arr) is typing.get_type_hints(test.AttributesOp.x_f64arr.fget)["return"]
+
+    assert typing.get_type_hints(test.AttributesOp.x_f64.fset)["value"] is FloatAttr
+    assert type(op.x_f64) is typing.get_type_hints(test.AttributesOp.x_f64.fget)["return"]
+
+    assert typing.get_type_hints(test.AttributesOp.x_f64elems.fset)["value"] is DenseFPElementsAttr
+    assert type(op.x_f64elems) is typing.get_type_hints(test.AttributesOp.x_f64elems.fget)["return"]
+
+    assert typing.get_type_hints(test.AttributesOp.x_flatsymrefarr.fset)["value"] is ArrayAttr
+    assert type(op.x_flatsymrefarr) is typing.get_type_hints(test.AttributesOp.x_flatsymrefarr.fget)["return"]
+
+    assert typing.get_type_hints(test.AttributesOp.x_flatsymref.fset)["value"] is FlatSymbolRefAttr
+    assert type(op.x_flatsymref) is typing.get_type_hints(test.AttributesOp.x_flatsymref.fget)["return"]
+
+    assert typing.get_type_hints(test.AttributesOp.x_i16.fset)["value"] is IntegerAttr
+    assert type(op.x_i16) is typing.get_type_hints(test.AttributesOp.x_i16.fget)["return"]
+
+    assert typing.get_type_hints(test.AttributesOp.x_i1.fset)["value"] is BoolAttr
+    assert type(op.x_i1) is typing.get_type_hints(test.AttributesOp.x_i1.fget)["return"]
+
+    assert typing.get_type_hints(test.AttributesOp.x_i32arr.fset)["value"] is ArrayAttr
+    assert type(op.x_i32arr) is typing.get_type_hints(test.AttributesOp.x_i32arr.fget)["return"]
+
+    assert typing.get_type_hints(test.AttributesOp.x_i32.fset)["value"] is IntegerAttr
+    assert type(op.x_i32) is typing.get_type_hints(test.AttributesOp.x_i32.fget)["return"]
+
+    assert typing.get_type_hints(test.AttributesOp.x_i32elems.fset)["value"] is DenseIntElementsAttr
+    assert type(op.x_i32elems) is typing.get_type_hints(test.AttributesOp.x_i32elems.fget)["return"]
+
+    assert typing.get_type_hints(test.AttributesOp.x_i64arr.fset)["value"] is ArrayAttr
+    assert type(op.x_i64arr) is typing.get_type_hints(test.AttributesOp.x_i64arr.fget)["return"]
+
+    assert typing.get_type_hints(test.AttributesOp.x_i64.fset)["value"] is IntegerAttr
+    assert type(op.x_i64) is typing.get_type_hints(test.AttributesOp.x_i64.fget)["return"]
+
+    assert typing.get_type_hints(test.AttributesOp.x_i64elems.fset)["value"] is DenseIntElementsAttr
+    assert type(op.x_i64elems) is typing.get_type_hints(test.AttributesOp.x_i64elems.fget)["return"]
+
+    assert typing.get_type_hints(test.AttributesOp.x_i64svecarr.fset)["value"] is ArrayAttr
+    assert type(op.x_i64svecarr) is typing.get_type_hints(test.AttributesOp.x_i64svecarr.fget)["return"]
+
+    assert typing.get_type_hints(test.AttributesOp.x_i8.fset)["value"] is IntegerAttr
+    assert type(op.x_i8) is typing.get_type_hints(test.AttributesOp.x_i8.fget)["return"]
+
+    assert typing.get_type_hints(test.AttributesOp.x_idx.fset)["value"] is IntegerAttr
+    assert type(op.x_idx) is typing.get_type_hints(test.AttributesOp.x_idx.fget)["return"]
+
+    assert typing.get_type_hints(test.AttributesOp.x_idxelems.fset)["value"] is DenseIntElementsAttr
+    assert type(op.x_idxelems) is typing.get_type_hints(test.AttributesOp.x_idxelems.fget)["return"]
+
+    assert typing.get_type_hints(test.AttributesOp.x_idxlistarr.fset)["value"] is ArrayAttr
+    assert type(op.x_idxlistarr) is typing.get_type_hints(test.AttributesOp.x_idxlistarr.fget)["return"]
+
+    assert typing.get_type_hints(test.AttributesOp.x_si16.fset)["value"] is IntegerAttr
+    assert type(op.x_si16) is typing.get_type_hints(test.AttributesOp.x_si16.fget)["return"]
+
+    assert typing.get_type_hints(test.AttributesOp.x_si1.fset)["value"] is IntegerAttr
+    assert type(op.x_si1) is typing.get_type_hints(test.AttributesOp.x_si1.fget)["return"]
+
+    assert typing.get_type_hints(test.AttributesOp.x_si32.fset)["value"] is IntegerAttr
+    assert type(op.x_si32) is typing.get_type_hints(test.AttributesOp.x_si32.fget)["return"]
+
+    assert typing.get_type_hints(test.AttributesOp.x_si64.fset)["value"] is IntegerAttr
+    assert type(op.x_si64) is typing.get_type_hints(test.AttributesOp.x_si64.fget)["return"]
+
+    assert typing.get_type_hints(test.AttributesOp.x_si8.fset)["value"] is IntegerAttr
+    assert type(op.x_si8) is typing.get_type_hints(test.AttributesOp.x_si8.fget)["return"]
+
+    assert typing.get_type_hints(test.AttributesOp.x_strarr.fset)["value"] is ArrayAttr
+    assert type(op.x_strarr) is typing.get_type_hints(test.AttributesOp.x_strarr.fget)["return"]
+
+    assert typing.get_type_hints(test.AttributesOp.x_str.fset)["value"] is StringAttr
+    assert type(op.x_str) is typing.get_type_hints(test.AttributesOp.x_str.fget)["return"]
+
+    assert typing.get_type_hints(test.AttributesOp.x_sym.fset)["value"] is StringAttr
+    assert type(op.x_sym) is typing.get_type_hints(test.AttributesOp.x_sym.fget)["return"]
+
+    assert typing.get_type_hints(test.AttributesOp.x_symrefarr.fset)["value"] is ArrayAttr
+    assert type(op.x_symrefarr) is typing.get_type_hints(test.AttributesOp.x_symrefarr.fget)["return"]
+
+    assert typing.get_type_hints(test.AttributesOp.x_symref.fset)["value"] is SymbolRefAttr
+    assert type(op.x_symref) is typing.get_type_hints(test.AttributesOp.x_symref.fget)["return"]
+
+    assert typing.get_type_hints(test.AttributesOp.x_typearr.fset)["value"] is ArrayAttr
+    assert type(op.x_typearr) is typing.get_type_hints(test.AttributesOp.x_typearr.fget)["return"]
+
+    assert typing.get_type_hints(test.AttributesOp.x_type.fset)["value"] is TypeAttr
+    assert type(op.x_type) is typing.get_type_hints(test.AttributesOp.x_type.fget)["return"]
+
+    assert typing.get_type_hints(test.AttributesOp.x_ui16.fset)["value"] is IntegerAttr
+    assert type(op.x_ui16) is typing.get_type_hints(test.AttributesOp.x_ui16.fget)["return"]
+
+    assert typing.get_type_hints(test.AttributesOp.x_ui1.fset)["value"] is IntegerAttr
+    assert type(op.x_ui1) is typing.get_type_hints(test.AttributesOp.x_ui1.fget)["return"]
+
+    assert typing.get_type_hints(test.AttributesOp.x_ui32.fset)["value"] is IntegerAttr
+    assert type(op.x_ui32) is typing.get_type_hints(test.AttributesOp.x_ui32.fget)["return"]
+
+    assert typing.get_type_hints(test.AttributesOp.x_ui64.fset)["value"] is IntegerAttr
+    assert type(op.x_ui64) is typing.get_type_hints(test.AttributesOp.x_ui64.fget)["return"]
+
+    assert typing.get_type_hints(test.AttributesOp.x_ui8.fset)["value"] is IntegerAttr
+    assert type(op.x_ui8) is typing.get_type_hints(test.AttributesOp.x_ui8.fget)["return"]
+    # fmt: on
+
 
 # CHECK-LABEL: TEST: inferReturnTypes
 @run
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 2e33581fbbff8..f73324afd200c 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -194,18 +194,20 @@ constexpr const char *opVariadicSegmentOptionalTrailingTemplate =
 /// Template for an operation attribute getter:
 ///   {0} is the name of the attribute sanitized for Python;
 ///   {1} is the original name of the attribute.
+///   {2} is the type hint.
 constexpr const char *attributeGetterTemplate = R"Py(
   @builtins.property
-  def {0}(self) -> _ods_ir.Attribute:
+  def {0}(self) -> {2}:
     return self.operation.attributes["{1}"]
 )Py";
 
 /// Template for an optional operation attribute getter:
 ///   {0} is the name of the attribute sanitized for Python;
 ///   {1} is the original name of the attribute.
+///   {2} is the type hint.
 constexpr const char *optionalAttributeGetterTemplate = R"Py(
   @builtins.property
-  def {0}(self) -> _Optional[_ods_ir.Attribute]:
+  def {0}(self) -> _Optional[{2}]:
     if "{1}" not in self.operation.attributes:
       return None
     return self.operation.attributes["{1}"]
@@ -225,9 +227,10 @@ constexpr const char *unitAttributeGetterTemplate = R"Py(
 /// Template for an operation attribute setter:
 ///    {0} is the name of the attribute sanitized for Python;
 ///    {1} is the original name of the attribute.
+///    {2} is the type hint.
 constexpr const char *attributeSetterTemplate = R"Py(
   @{0}.setter
-  def {0}(self, value):
+  def {0}(self, value: {2}):
     if value is None:
       raise ValueError("'None' not allowed as value for mandatory attributes")
     self.operation.attributes["{1}"] = value
@@ -237,9 +240,10 @@ constexpr const char *attributeSetterTemplate = R"Py(
 /// removes the attribute:
 ///    {0} is the name of the attribute sanitized for Python;
 ///    {1} is the original name of the attribute.
+///    {2} is the type hint.
 constexpr const char *optionalAttributeSetterTemplate = R"Py(
   @{0}.setter
-  def {0}(self, value):
+  def {0}(self, value: _Optional[{2}]):
     if value is not None:
       self.operation.attributes["{1}"] = value
     elif "{1}" in self.operation.attributes:
@@ -482,6 +486,72 @@ static void emitResultAccessors(const Operator &op, raw_ostream &os) {
                        getNumResults(op), getResult);
 }
 
+/// Returns the `_ods_ir` Python class name used as the type hint for `attr`;
+/// unrecognized storage types fall back to the generic "Attribute" class.
+static std::string getPythonAttrName(const mlir::tblgen::Attribute &attr) {
+  auto storageTypeStr = attr.getStorageType();
+  if (storageTypeStr == "::mlir::AffineMapAttr")
+    return "AffineMapAttr";
+  if (storageTypeStr == "::mlir::ArrayAttr")
+    return "ArrayAttr";
+  if (storageTypeStr == "::mlir::BoolAttr")
+    return "BoolAttr";
+  if (storageTypeStr == "::mlir::DenseBoolArrayAttr")
+    return "DenseBoolArrayAttr";
+  if (storageTypeStr == "::mlir::DenseElementsAttr") {
+    // Check the base attr so wrapped forms (e.g. OptionalAttr) match too.
+    const Record &baseDef = attr.getBaseAttr().getDef();
+    if (baseDef.isSubClassOf("FloatElementsAttr") ||
+        baseDef.isSubClassOf("RankedFloatElementsAttr"))
+      return "DenseFPElementsAttr";
+    return "DenseElementsAttr";
+  }
+  if (storageTypeStr == "::mlir::DenseF32ArrayAttr")
+    return "DenseF32ArrayAttr";
+  if (storageTypeStr == "::mlir::DenseF64ArrayAttr")
+    return "DenseF64ArrayAttr";
+  if (storageTypeStr == "::mlir::DenseFPElementsAttr")
+    return "DenseFPElementsAttr";
+  if (storageTypeStr == "::mlir::DenseI16ArrayAttr")
+    return "DenseI16ArrayAttr";
+  if (storageTypeStr == "::mlir::DenseI32ArrayAttr")
+    return "DenseI32ArrayAttr";
+  if (storageTypeStr == "::mlir::DenseI64ArrayAttr")
+    return "DenseI64ArrayAttr";
+  if (storageTypeStr == "::mlir::DenseI8ArrayAttr")
+    return "DenseI8ArrayAttr";
+  if (storageTypeStr == "::mlir::DenseIntElementsAttr")
+    return "DenseIntElementsAttr";
+  if (storageTypeStr == "::mlir::DenseResourceElementsAttr")
+    return "DenseResourceElementsAttr";
+  if (storageTypeStr == "::mlir::DictionaryAttr")
+    return "DictAttr";
+  if (storageTypeStr == "::mlir::FlatSymbolRefAttr")
+    return "FlatSymbolRefAttr";
+  if (storageTypeStr == "::mlir::FloatAttr")
+    return "FloatAttr";
+  if (storageTypeStr == "::mlir::IntegerAttr") {
+    if (attr.getAttrDefName() == "I1Attr")
+      return "BoolAttr";
+    return "IntegerAttr";
+  }
+  if (storageTypeStr == "::mlir::IntegerSetAttr")
+    return "IntegerSetAttr";
+  if (storageTypeStr == "::mlir::OpaqueAttr")
+    return "OpaqueAttr";
+  if (storageTypeStr == "::mlir::StridedLayoutAttr")
+    return "StridedLayoutAttr";
+  if (storageTypeStr == "::mlir::StringAttr")
+    return "StringAttr";
+  if (storageTypeStr == "::mlir::SymbolRefAttr")
+    return "SymbolRefAttr";
+  if (storageTypeStr == "::mlir::TypeAttr")
+    return "TypeAttr";
+  if (storageTypeStr == "::mlir::UnitAttr")
+    return "UnitAttr";
+  return "Attribute";
+}
+
 /// Emits accessors to Op attributes.
 static void emitAttributeAccessors(const Operator &op, raw_ostream &os) {
   for (const auto &namedAttr : op.getAttributes()) {
@@ -503,15 +573,18 @@ static void emitAttributeAccessors(const Operator &op, raw_ostream &os) {
       continue;
     }
 
+    std::string type = "_ods_ir." + getPythonAttrName(namedAttr.attr);
     if (namedAttr.attr.isOptional()) {
       os << formatv(optionalAttributeGetterTemplate, sanitizedName,
-                    namedAttr.name);
+                    namedAttr.name, type);
       os << formatv(optionalAttributeSetterTemplate, sanitizedName,
-                    namedAttr.name);
+                    namedAttr.name, type);
       os << formatv(attributeDeleterTemplate, sanitizedName, namedAttr.name);
     } else {
-      os << formatv(attributeGetterTemplate, sanitizedName, namedAttr.name);
-      os << formatv(attributeSetterTemplate, sanitizedName, namedAttr.name);
+      os << formatv(attributeGetterTemplate, sanitizedName, namedAttr.name,
+                    type);
+      os << formatv(attributeSetterTemplate, sanitizedName, namedAttr.name,
+                    type);
       // Non-optional attributes cannot be deleted.
     }
   }

>From 69abc3f34eeb2b9368f9eedfe201b30ba2adbc76 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Tue, 16 Sep 2025 09:35:04 -0700
Subject: [PATCH 3/3] use `type(x) is T` instead of isinstance

---
 mlir/test/python/dialects/python_test.py | 94 +++++++++++++++---------
 1 file changed, 60 insertions(+), 34 deletions(-)

diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py
index 9686cc61bb11c..5c7d3234c6df2 100644
--- a/mlir/test/python/dialects/python_test.py
+++ b/mlir/test/python/dialects/python_test.py
@@ -1,9 +1,8 @@
 # RUN: %PYTHON %s pybind11 | FileCheck %s
 # RUN: %PYTHON %s nanobind | FileCheck %s
-import inspect
 import sys
 import typing
-from typing import Union
+from typing import Union, Optional
 
 from mlir.ir import *
 import mlir.dialects.func as func
@@ -449,6 +448,13 @@ def resultTypesDefinedByTraits():
             # CHECK-COUNT-2: i32
             print(same.one.type)
             print(same.two.type)
+            assert (
+                typing.get_type_hints(test.SameOperandAndResultTypeOp.one.fget)[
+                    "return"
+                ]
+                is OpResult
+            )
+            assert type(same.one) is OpResult
 
             first_type_attr = test.FirstAttrDeriveTypeAttrOp(
                 inferred.results[1], TypeAttr.get(IndexType.get())
@@ -491,6 +497,15 @@ def testOptionalOperandOp():
             op1 = test.OptionalOperandOp()
             # CHECK: op1.input is None: True
             print(f"op1.input is None: {op1.input is None}")
+            assert (
+                typing.get_type_hints(test.OptionalOperandOp.input.fget)["return"]
+                is Optional[Value]
+            )
+            assert (
+                typing.get_type_hints(test.OptionalOperandOp.result.fget)["return"]
+                is OpResult
+            )
+            assert type(op1.result) is OpResult
 
             op2 = test.OptionalOperandOp(input=op1)
             # CHECK: op2.input is None: False
@@ -754,14 +769,12 @@ def testInferTypeOpInterface():
             print(two_operands.result.type)
 
             assert (
-                inspect.signature(
-                    test.infer_results_variadic_inputs_op
-                ).return_annotation
+                typing.get_type_hints(test.infer_results_variadic_inputs_op)["return"]
                 is OpResult
             )
-            assert isinstance(
-                test.infer_results_variadic_inputs_op(single=zero, doubled=zero),
-                OpResult,
+            assert (
+                type(test.infer_results_variadic_inputs_op(single=zero, doubled=zero))
+                is OpResult
             )
 
 
@@ -785,25 +798,36 @@ def values(lst):
                 [zero, one], two, [three, four]
             )
             # CHECK: Value(%{{.*}} = arith.constant 2 : i32)
-            non_variadic = variadic_operands.non_variadic
-            print(non_variadic)
-            assert isinstance(non_variadic, Value)
+            print(variadic_operands.non_variadic)
+            assert (
+                typing.get_type_hints(test.SameVariadicOperandSizeOp.non_variadic.fget)[
+                    "return"
+                ]
+                is Value
+            )
+            assert type(variadic_operands.non_variadic) is Value
+
             # CHECK: ['Value(%{{.*}} = arith.constant 0 : i32)', 'Value(%{{.*}} = arith.constant 1 : i32)']
-            variadic1 = variadic_operands.variadic1
-            print(values(variadic1))
-            assert isinstance(variadic1, OpOperandList)
+            print(values(variadic_operands.variadic1))
+            assert (
+                typing.get_type_hints(test.SameVariadicOperandSizeOp.variadic1.fget)[
+                    "return"
+                ]
+                is OpOperandList
+            )
+            assert type(variadic_operands.variadic1) is OpOperandList
+
             # CHECK: ['Value(%{{.*}} = arith.constant 3 : i32)', 'Value(%{{.*}} = arith.constant 4 : i32)']
-            variadic2 = variadic_operands.variadic2
-            print(values(variadic2))
-            assert isinstance(variadic2, OpOperandList)
+            print(values(variadic_operands.variadic2))
+            assert type(variadic_operands.variadic2) is OpOperandList
 
             assert (
-                inspect.signature(test.same_variadic_operand).return_annotation
+                typing.get_type_hints(test.same_variadic_operand)["return"]
                 is test.SameVariadicOperandSizeOp
             )
-            assert isinstance(
-                test.same_variadic_operand([zero, one], two, [three, four]),
-                test.SameVariadicOperandSizeOp,
+            assert (
+                type(test.same_variadic_operand([zero, one], two, [three, four]))
+                is test.SameVariadicOperandSizeOp
             )
 
 
@@ -828,12 +852,12 @@ def types(lst):
             print(types(op.variadic2))
 
             assert (
-                inspect.signature(test.same_variadic_result_vfv).return_annotation
+                typing.get_type_hints(test.same_variadic_result_vfv)["return"]
                 is Union[OpResult, OpResultList, test.SameVariadicResultSizeOpVFV]
             )
-            assert isinstance(
-                test.same_variadic_result_vfv([i[0], i[1]], i[2], [i[3], i[4]]),
-                OpResultList,
+            assert (
+                type(test.same_variadic_result_vfv([i[0], i[1]], i[2], [i[3], i[4]]))
+                is OpResultList
             )
 
             #  Test Variadic-Variadic-Variadic
@@ -854,9 +878,8 @@ def types(lst):
             # CHECK: i1
             print(op.non_variadic2.type)
             # CHECK: [IntegerType(i2), IntegerType(i3), IntegerType(i4)]
-            variadic = op.variadic
-            print(types(variadic))
-            assert isinstance(variadic, OpResultList)
+            print(types(op.variadic))
+            assert type(op.variadic) is OpResultList
 
             #  Test Variadic-Variadic-Fixed
             op = test.SameVariadicResultSizeOpVVF(
@@ -911,13 +934,16 @@ def types(lst):
             print(op.non_variadic3.type)
 
             assert (
-                inspect.signature(test.results_variadic).return_annotation
+                typing.get_type_hints(test.results_variadic)["return"]
                 is Union[OpResult, OpResultList, test.ResultsVariadicOp]
             )
-            assert isinstance(
-                test.results_variadic([i[0]]),
-                OpResult,
+            assert type(test.results_variadic([i[0]])) is OpResult
+            op_res_variadic = test.ResultsVariadicOp([i[0]])
+            assert (
+                typing.get_type_hints(test.ResultsVariadicOp.res.fget)["return"]
+                is OpResultList
             )
+            assert type(op_res_variadic.res) is OpResultList
 
 
 # CHECK-LABEL: TEST: testVariadicAndNormalRegion
@@ -927,5 +953,5 @@ def testVariadicAndNormalRegionOp():
         module = Module.create()
         with InsertionPoint(module.body):
             region_op = test.VariadicAndNormalRegionOp(2)
-            assert isinstance(region_op.region, Region)
-            assert isinstance(region_op.variadic, RegionSequence)
+            assert type(region_op.region) is Region
+            assert type(region_op.variadic) is RegionSequence



More information about the Mlir-commits mailing list