[Mlir-commits] [mlir] [MLIR][Python] add type hints for accessors (PR #158455)

Maksim Levental llvmlistbot at llvm.org
Sun Sep 14 17:50:30 PDT 2025


https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/158455

>From c986aac469b8d93e65952b310c89a188ecb27e95 Mon Sep 17 00:00:00 2001
From: Maksim Levental <maksim.levental at gmail.com>
Date: Sun, 14 Sep 2025 01:01:30 -0400
Subject: [PATCH] [MLIR][Python] add type hints for accessors

---
 mlir/test/mlir-tblgen/op-python-bindings.td   | 70 ++++++++--------
 mlir/test/python/dialects/python_test.py      | 27 +++++-
 mlir/test/python/python_test_ops.td           |  5 ++
 mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp | 84 +++++++++++++------
 4 files changed, 122 insertions(+), 64 deletions(-)

diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td
index 3ec69c33b4bb9..d943d0f590c33 100644
--- a/mlir/test/mlir-tblgen/op-python-bindings.td
+++ b/mlir/test/mlir-tblgen/op-python-bindings.td
@@ -36,7 +36,7 @@ def AttrSizedOperandsOp : TestOp<"attr_sized_operands",
   // CHECK:     successors=_ods_successors, regions=regions, loc=loc, ip=ip)
 
   // CHECK: @builtins.property
-  // CHECK: def variadic1(self):
+  // CHECK: def variadic1(self) -> _ods_ir.OpOperandList:
   // CHECK:   operand_range = _ods_segmented_accessor(
   // CHECK:       self.operation.operands,
   // CHECK:       self.operation.attributes["operandSegmentSizes"], 0)
@@ -44,14 +44,14 @@ def AttrSizedOperandsOp : TestOp<"attr_sized_operands",
   // CHECK-NOT: if len(operand_range)
   //
   // CHECK: @builtins.property
-  // CHECK: def non_variadic(self):
+  // CHECK: def non_variadic(self) -> _ods_ir.Value:
   // CHECK:   operand_range = _ods_segmented_accessor(
   // CHECK:       self.operation.operands,
   // CHECK:       self.operation.attributes["operandSegmentSizes"], 1)
   // CHECK:   return operand_range[0]
   //
   // CHECK: @builtins.property
-  // CHECK: def variadic2(self):
+  // CHECK: def variadic2(self) -> _Optional[_ods_ir.Value]:
   // CHECK:   operand_range = _ods_segmented_accessor(
   // CHECK:       self.operation.operands,
   // CHECK:       self.operation.attributes["operandSegmentSizes"], 2)
@@ -84,21 +84,21 @@ def AttrSizedResultsOp : TestOp<"attr_sized_results",
   // CHECK:     successors=_ods_successors, regions=regions, loc=loc, ip=ip)
 
   // CHECK: @builtins.property
-  // CHECK: def variadic1(self):
+  // CHECK: def variadic1(self) -> _Optional[_ods_ir.OpResult]:
   // CHECK:   result_range = _ods_segmented_accessor(
   // CHECK:       self.operation.results,
   // CHECK:       self.operation.attributes["resultSegmentSizes"], 0)
   // CHECK:   return result_range[0] if len(result_range) > 0 else None
   //
   // CHECK: @builtins.property
-  // CHECK: def non_variadic(self):
+  // CHECK: def non_variadic(self) -> _ods_ir.OpResult:
   // CHECK:   result_range = _ods_segmented_accessor(
   // CHECK:       self.operation.results,
   // CHECK:       self.operation.attributes["resultSegmentSizes"], 1)
   // CHECK:   return result_range[0]
   //
   // CHECK: @builtins.property
-  // CHECK: def variadic2(self):
+  // CHECK: def variadic2(self) -> _ods_ir.OpResultList:
   // CHECK:   result_range = _ods_segmented_accessor(
   // CHECK:       self.operation.results,
   // CHECK:       self.operation.attributes["resultSegmentSizes"], 2)
@@ -138,21 +138,21 @@ def AttributedOp : TestOp<"attributed_op"> {
   // CHECK:     successors=_ods_successors, regions=regions, loc=loc, ip=ip)
 
   // CHECK: @builtins.property
-  // CHECK: def i32attr(self):
+  // CHECK: def i32attr(self) -> _ods_ir.Attribute:
   // CHECK:   return self.operation.attributes["i32attr"]
 
   // CHECK: @builtins.property
-  // CHECK: def optionalF32Attr(self):
+  // CHECK: def optionalF32Attr(self) -> _Optional[_ods_ir.Attribute]:
   // CHECK:   if "optionalF32Attr" not in self.operation.attributes:
   // CHECK:     return None
   // CHECK:   return self.operation.attributes["optionalF32Attr"]
 
   // CHECK: @builtins.property
-  // CHECK: def unitAttr(self):
+  // CHECK: def unitAttr(self) -> bool:
   // CHECK:   return "unitAttr" in self.operation.attributes
 
   // CHECK: @builtins.property
-  // CHECK: def in_(self):
+  // CHECK: def in_(self) -> _ods_ir.Attribute:
   // CHECK:   return self.operation.attributes["in"]
 
   let arguments = (ins I32Attr:$i32attr, OptionalAttr<F32Attr>:$optionalF32Attr,
@@ -185,11 +185,11 @@ def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> {
   // CHECK:     successors=_ods_successors, regions=regions, loc=loc, ip=ip)
 
   // CHECK: @builtins.property
-  // CHECK: def in_(self):
+  // CHECK: def in_(self) -> bool:
   // CHECK:   return "in" in self.operation.attributes
 
   // CHECK: @builtins.property
-  // CHECK: def is_(self):
+  // CHECK: def is_(self) -> _Optional[_ods_ir.Attribute]:
   // CHECK:   if "is" not in self.operation.attributes:
   // CHECK:     return None
   // CHECK:   return self.operation.attributes["is"]
@@ -320,16 +320,16 @@ def MissingNamesOp : TestOp<"missing_names"> {
   // CHECK:     successors=_ods_successors, regions=regions, loc=loc, ip=ip)
 
   // CHECK: @builtins.property
-  // CHECK: def f32(self):
+  // CHECK: def f32(self) -> _ods_ir.Value:
   // CHECK:   return self.operation.operands[1]
   let arguments = (ins I32, F32:$f32, I64);
 
   // CHECK: @builtins.property
-  // CHECK: def i32(self):
+  // CHECK: def i32(self) -> _ods_ir.OpResult:
   // CHECK:   return self.operation.results[0]
   //
   // CHECK: @builtins.property
-  // CHECK: def i64(self):
+  // CHECK: def i64(self) -> _ods_ir.OpResult:
   // CHECK:   return self.operation.results[2]
   let results = (outs I32:$i32, AnyFloat, I64:$i64);
 }
@@ -358,11 +358,11 @@ def OneOptionalOperandOp : TestOp<"one_optional_operand"> {
   // CHECK:     successors=_ods_successors, regions=regions, loc=loc, ip=ip)
 
   // CHECK: @builtins.property
-  // CHECK: def non_optional(self):
+  // CHECK: def non_optional(self) -> _ods_ir.Value:
   // CHECK:   return self.operation.operands[0]
 
   // CHECK: @builtins.property
-  // CHECK: def optional(self):
+  // CHECK: def optional(self) -> _Optional[_ods_ir.Value]:
   // CHECK:   return None if len(self.operation.operands) < 2 else self.operation.operands[1]
 }
 
@@ -389,11 +389,11 @@ def OneVariadicOperandOp : TestOp<"one_variadic_operand"> {
   // CHECK:     successors=_ods_successors, regions=regions, loc=loc, ip=ip)
 
   // CHECK: @builtins.property
-  // CHECK: def non_variadic(self):
+  // CHECK: def non_variadic(self) -> _ods_ir.Value:
   // CHECK:   return self.operation.operands[0]
   //
   // CHECK: @builtins.property
-  // CHECK: def variadic(self):
+  // CHECK: def variadic(self) -> _ods_ir.OpOperandList:
   // CHECK:   _ods_variadic_group_length = len(self.operation.operands) - 2 + 1
   // CHECK:   return self.operation.operands[1:1 + _ods_variadic_group_length]
   let arguments = (ins AnyType:$non_variadic, Variadic<AnyType>:$variadic);
@@ -422,12 +422,12 @@ def OneVariadicResultOp : TestOp<"one_variadic_result"> {
   // CHECK:     successors=_ods_successors, regions=regions, loc=loc, ip=ip)
 
   // CHECK: @builtins.property
-  // CHECK: def variadic(self):
+  // CHECK: def variadic(self) -> _ods_ir.OpResultList:
   // CHECK:   _ods_variadic_group_length = len(self.operation.results) - 2 + 1
   // CHECK:   return self.operation.results[0:0 + _ods_variadic_group_length]
   //
   // CHECK: @builtins.property
-  // CHECK: def non_variadic(self):
+  // CHECK: def non_variadic(self) -> _ods_ir.OpResult:
   // CHECK:   _ods_variadic_group_length = len(self.operation.results) - 2 + 1
   // CHECK:   return self.operation.results[1 + _ods_variadic_group_length - 1]
   let results = (outs Variadic<AnyType>:$variadic, AnyType:$non_variadic);
@@ -453,7 +453,7 @@ def PythonKeywordOp : TestOp<"python_keyword"> {
   // CHECK:     successors=_ods_successors, regions=regions, loc=loc, ip=ip)
 
   // CHECK: @builtins.property
-  // CHECK: def in_(self):
+  // CHECK: def in_(self) -> _ods_ir.Value:
   // CHECK:   return self.operation.operands[0]
   let arguments = (ins AnyType:$in);
 }
@@ -491,17 +491,17 @@ def SameResultsVariadicOp : TestOp<"same_results_variadic", [SameOperandsAndResu
 def SameVariadicOperandSizeOp : TestOp<"same_variadic_operand",
                                        [SameVariadicOperandSize]> {
   // CHECK: @builtins.property
-  // CHECK: def variadic1(self):
+  // CHECK: def variadic1(self) -> _ods_ir.OpOperandList:
   // CHECK:   start, elements_per_group = _ods_equally_sized_accessor(self.operation.operands, 1, 2, 0, 0)
   // CHECK:   return self.operation.operands[start:start + elements_per_group]
   //
   // CHECK: @builtins.property
-  // CHECK: def non_variadic(self):
+  // CHECK: def non_variadic(self) -> _ods_ir.Value:
   // CHECK:   start, elements_per_group = _ods_equally_sized_accessor(self.operation.operands, 1, 2, 0, 1)
   // CHECK:   return self.operation.operands[start]
   //
   // CHECK: @builtins.property
-  // CHECK: def variadic2(self):
+  // CHECK: def variadic2(self) -> _ods_ir.OpOperandList:
   // CHECK:   start, elements_per_group = _ods_equally_sized_accessor(self.operation.operands, 1, 2, 1, 1)
   // CHECK:   return self.operation.operands[start:start + elements_per_group]
   let arguments = (ins Variadic<AnyType>:$variadic1, AnyType:$non_variadic,
@@ -517,17 +517,17 @@ def SameVariadicOperandSizeOp : TestOp<"same_variadic_operand",
 def SameVariadicResultSizeOp : TestOp<"same_variadic_result",
                                       [SameVariadicResultSize]> {
   // CHECK: @builtins.property
-  // CHECK: def variadic1(self):
+  // CHECK: def variadic1(self) -> _ods_ir.OpResultList:
   // CHECK:   start, elements_per_group = _ods_equally_sized_accessor(self.operation.results, 1, 2, 0, 0)
   // CHECK:   return self.operation.results[start:start + elements_per_group]
   //
   // CHECK: @builtins.property
-  // CHECK: def non_variadic(self):
+  // CHECK: def non_variadic(self) -> _ods_ir.OpResult:
   // CHECK:   start, elements_per_group = _ods_equally_sized_accessor(self.operation.results, 1, 2, 0, 1)
   // CHECK:   return self.operation.results[start]
   //
   // CHECK: @builtins.property
-  // CHECK: def variadic2(self):
+  // CHECK: def variadic2(self) -> _ods_ir.OpResultList:
   // CHECK:   start, elements_per_group = _ods_equally_sized_accessor(self.operation.results, 1, 2, 1, 1)
   // CHECK:   return self.operation.results[start:start + elements_per_group]
   let results = (outs Variadic<AnyType>:$variadic1, AnyType:$non_variadic,
@@ -557,20 +557,20 @@ def SimpleOp : TestOp<"simple"> {
   // CHECK:     successors=_ods_successors, regions=regions, loc=loc, ip=ip)
 
   // CHECK: @builtins.property
-  // CHECK: def i32(self):
+  // CHECK: def i32(self) -> _ods_ir.Value:
   // CHECK:   return self.operation.operands[0]
   //
   // CHECK: @builtins.property
-  // CHECK: def f32(self):
+  // CHECK: def f32(self) -> _ods_ir.Value:
   // CHECK:   return self.operation.operands[1]
   let arguments = (ins I32:$i32, F32:$f32);
 
   // CHECK: @builtins.property
-  // CHECK: def i64(self):
+  // CHECK: def i64(self) -> _ods_ir.OpResult:
   // CHECK:   return self.operation.results[0]
   //
   // CHECK: @builtins.property
-  // CHECK: def f64(self):
+  // CHECK: def f64(self) -> _ods_ir.OpResult:
   // CHECK:   return self.operation.results[1]
   let results = (outs I64:$i64, AnyFloat:$f64);
 }
@@ -595,11 +595,11 @@ def VariadicAndNormalRegionOp : TestOp<"variadic_and_normal_region"> {
   let regions = (region AnyRegion:$region, AnyRegion, VariadicRegion<AnyRegion>:$variadic);
 
   // CHECK:  @builtins.property
-  // CHECK:  def region(self):
+  // CHECK:  def region(self) -> _ods_ir.Region:
   // CHECK:    return self.regions[0]
 
   // CHECK:  @builtins.property
-  // CHECK:  def variadic(self):
+  // CHECK:  def variadic(self) -> _ods_ir.RegionSequence:
   // CHECK:    return self.regions[2:]
 }
 
@@ -623,7 +623,7 @@ def VariadicRegionOp : TestOp<"variadic_region"> {
   let regions = (region VariadicRegion<AnyRegion>:$Variadic);
 
   // CHECK:  @builtins.property
-  // CHECK:  def Variadic(self):
+  // CHECK:  def Variadic(self) -> _ods_ir.RegionSequence:
   // CHECK:    return self.regions[0:]
 }
 
diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py
index 68262822ca6b5..8efa45c40f262 100644
--- a/mlir/test/python/dialects/python_test.py
+++ b/mlir/test/python/dialects/python_test.py
@@ -323,6 +323,7 @@ def resultTypesDefinedByTraits():
             # CHECK: f32 index
             print(no_infer.single.type, no_infer.doubled.type)
 
+
 # CHECK-LABEL: TEST: testOptionalOperandOp
 @run
 def testOptionalOperandOp():
@@ -615,11 +616,17 @@ def values(lst):
                 [zero, one], two, [three, four]
             )
             # CHECK: Value(%{{.*}} = arith.constant 2 : i32)
-            print(variadic_operands.non_variadic)
+            non_variadic = variadic_operands.non_variadic
+            print(non_variadic)
+            assert isinstance(non_variadic, Value)
             # CHECK: ['Value(%{{.*}} = arith.constant 0 : i32)', 'Value(%{{.*}} = arith.constant 1 : i32)']
-            print(values(variadic_operands.variadic1))
+            variadic1 = variadic_operands.variadic1
+            print(values(variadic1))
+            assert isinstance(variadic1, OpOperandList)
             # CHECK: ['Value(%{{.*}} = arith.constant 3 : i32)', 'Value(%{{.*}} = arith.constant 4 : i32)']
-            print(values(variadic_operands.variadic2))
+            variadic2 = variadic_operands.variadic2
+            print(values(variadic2))
+            assert isinstance(variadic2, OpOperandList)
 
 
 # CHECK-LABEL: TEST: testVariadicResultAccess
@@ -660,7 +667,9 @@ def types(lst):
             # CHECK: i1
             print(op.non_variadic2.type)
             # CHECK: [IntegerType(i2), IntegerType(i3), IntegerType(i4)]
-            print(types(op.variadic))
+            variadic = op.variadic
+            print(types(variadic))
+            assert isinstance(variadic, OpResultList)
 
             #  Test Variadic-Variadic-Fixed
             op = test.SameVariadicResultSizeOpVVF(
@@ -713,3 +722,13 @@ def types(lst):
             print(types(op.variadic2))
             # CHECK: i4
             print(op.non_variadic3.type)
+
+# CHECK-LABEL: TEST: testVariadicAndNormalRegion
+@run
+def testVariadicAndNormalRegionOp():
+    with Context() as ctx, Location.unknown(ctx):
+        module = Module.create()
+        with InsertionPoint(module.body):
+            region_op = test.VariadicAndNormalRegionOp(2)
+            assert isinstance(region_op.region, Region)
+            assert isinstance(region_op.variadic, RegionSequence)
diff --git a/mlir/test/python/python_test_ops.td b/mlir/test/python/python_test_ops.td
index 026e64a3cfc19..5f45aafcb2fda 100644
--- a/mlir/test/python/python_test_ops.td
+++ b/mlir/test/python/python_test_ops.td
@@ -265,4 +265,9 @@ def SameVariadicResultSizeOpFVFVF : TestOp<"same_variadic_result_fvfvf",
                  AnyType:$non_variadic3);
 }
 
+def VariadicAndNormalRegionOp : TestOp<"variadic_and_normal_region"> {
+  let regions = (region AnyRegion:$region, VariadicRegion<AnyRegion>:$variadic);
+  let assemblyFormat = "$region $variadic attr-dict";
+}
+
 #endif // PYTHON_TEST_OPS
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 6a7aa9e3432d5..b975120703c17 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -44,7 +44,7 @@ _ods_ir = _ods_cext.ir
 _ods_cext.globals.register_traceback_file_exclusion(__file__)
 
 import builtins
-from typing import Sequence as _Sequence, Union as _Union
+from typing import Sequence as _Sequence, Union as _Union, Optional as _Optional
 
 )Py";
 
@@ -93,9 +93,10 @@ constexpr const char *opClassRegionSpecTemplate = R"Py(
 ///   {0} is the name of the accessor;
 ///   {1} is either 'operand' or 'result';
 ///   {2} is the position in the element list.
+///   {3} is the type hint.
 constexpr const char *opSingleTemplate = R"Py(
   @builtins.property
-  def {0}(self):
+  def {0}(self) -> {3}:
     return self.operation.{1}s[{2}]
 )Py";
 
@@ -104,11 +105,12 @@ constexpr const char *opSingleTemplate = R"Py(
 ///   {1} is either 'operand' or 'result';
 ///   {2} is the total number of element groups;
 ///   {3} is the position of the current group in the group list.
+///   {4} is the type hint.
 /// This works for both a single variadic group (non-negative length) and an
 /// single optional element (zero length if the element is absent).
 constexpr const char *opSingleAfterVariableTemplate = R"Py(
   @builtins.property
-  def {0}(self):
+  def {0}(self) -> {4}:
     _ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1
     return self.operation.{1}s[{3} + _ods_variadic_group_length - 1]
 )Py";
@@ -118,12 +120,13 @@ constexpr const char *opSingleAfterVariableTemplate = R"Py(
 ///   {1} is either 'operand' or 'result';
 ///   {2} is the total number of element groups;
 ///   {3} is the position of the current group in the group list.
+///   {4} is the type hint.
 /// This works if we have only one variable-length group (and it's the optional
 /// operand/result): we can deduce it's absent if the `len(operation.{1}s)` is
 /// smaller than the total number of groups.
 constexpr const char *opOneOptionalTemplate = R"Py(
   @builtins.property
-  def {0}(self):
+  def {0}(self) -> _Optional[{4}]:
     return None if len(self.operation.{1}s) < {2} else self.operation.{1}s[{3}]
 )Py";
 
@@ -132,9 +135,10 @@ constexpr const char *opOneOptionalTemplate = R"Py(
 ///   {1} is either 'operand' or 'result';
 ///   {2} is the total number of element groups;
 ///   {3} is the position of the current group in the group list.
+///   {4} is the type hint.
 constexpr const char *opOneVariadicTemplate = R"Py(
   @builtins.property
-  def {0}(self):
+  def {0}(self) -> {4}:
     _ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1
     return self.operation.{1}s[{3}:{3} + _ods_variadic_group_length]
 )Py";
@@ -146,9 +150,10 @@ constexpr const char *opOneVariadicTemplate = R"Py(
 ///   {3} is the total number of variadic groups;
 ///   {4} is the number of non-variadic groups preceding the current group;
 ///   {5} is the number of variadic groups preceding the current group.
+///   {6} is the type hint.
 constexpr const char *opVariadicEqualPrefixTemplate = R"Py(
   @builtins.property
-  def {0}(self):
+  def {0}(self) -> {6}:
     start, elements_per_group = _ods_equally_sized_accessor(self.operation.{1}s, {2}, {3}, {4}, {5}))Py";
 
 /// Second part of the template for equally-sized case, accessing a single
@@ -171,9 +176,10 @@ constexpr const char *opVariadicEqualVariadicTemplate = R"Py(
 ///   {2} is the position of the group in the group list;
 ///   {3} is a return suffix (expected [0] for single-element, empty for
 ///       variadic, and opVariadicSegmentOptionalTrailingTemplate for optional).
+///   {4} is the type hint.
 constexpr const char *opVariadicSegmentTemplate = R"Py(
   @builtins.property
-  def {0}(self):
+  def {0}(self) -> {4}:
     {1}_range = _ods_segmented_accessor(
          self.operation.{1}s,
          self.operation.attributes["{1}SegmentSizes"], {2})
@@ -191,7 +197,7 @@ constexpr const char *opVariadicSegmentOptionalTrailingTemplate =
 ///   {1} is the original name of the attribute.
 constexpr const char *attributeGetterTemplate = R"Py(
   @builtins.property
-  def {0}(self):
+  def {0}(self) -> _ods_ir.Attribute:
     return self.operation.attributes["{1}"]
 )Py";
 
@@ -200,7 +206,7 @@ constexpr const char *attributeGetterTemplate = R"Py(
 ///   {1} is the original name of the attribute.
 constexpr const char *optionalAttributeGetterTemplate = R"Py(
   @builtins.property
-  def {0}(self):
+  def {0}(self) -> _Optional[_ods_ir.Attribute]:
     if "{1}" not in self.operation.attributes:
       return None
     return self.operation.attributes["{1}"]
@@ -213,7 +219,7 @@ constexpr const char *optionalAttributeGetterTemplate = R"Py(
 ///    {1} is the original name of the attribute.
 constexpr const char *unitAttributeGetterTemplate = R"Py(
   @builtins.property
-  def {0}(self):
+  def {0}(self) -> bool:
     return "{1}" in self.operation.attributes
 )Py";
 
@@ -266,7 +272,7 @@ constexpr const char *attributeDeleterTemplate = R"Py(
 
 constexpr const char *regionAccessorTemplate = R"Py(
   @builtins.property
-  def {0}(self):
+  def {0}(self) -> {2}:
     return self.regions[{1}]
 )Py";
 
@@ -357,15 +363,24 @@ static void emitElementAccessors(
         seenVariableLength = true;
       if (element.name.empty())
         continue;
+      const char *type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.Value"
+                                                           : "_ods_ir.OpResult";
       if (element.isVariableLength()) {
-        os << formatv(element.isOptional() ? opOneOptionalTemplate
-                                           : opOneVariadicTemplate,
-                      sanitizeName(element.name), kind, numElements, i);
+        if (element.isOptional()) {
+          os << formatv(opOneOptionalTemplate, sanitizeName(element.name), kind,
+                        numElements, i, type);
+        } else {
+          type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.OpOperandList"
+                                                   : "_ods_ir.OpResultList";
+          os << formatv(opOneVariadicTemplate, sanitizeName(element.name), kind,
+                        numElements, i, type);
+        }
       } else if (seenVariableLength) {
         os << formatv(opSingleAfterVariableTemplate, sanitizeName(element.name),
-                      kind, numElements, i);
+                      kind, numElements, i, type);
       } else {
-        os << formatv(opSingleTemplate, sanitizeName(element.name), kind, i);
+        os << formatv(opSingleTemplate, sanitizeName(element.name), kind, i,
+                      type);
       }
     }
     return;
@@ -388,9 +403,17 @@ static void emitElementAccessors(
     for (unsigned i = 0; i < numElements; ++i) {
       const NamedTypeConstraint &element = getElement(op, i);
       if (!element.name.empty()) {
+        std::string type;
+        if (element.isVariableLength()) {
+          type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.OpOperandList"
+                                                   : "_ods_ir.OpResultList";
+        } else {
+          type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.Value"
+                                                   : "_ods_ir.OpResult";
+        }
         os << formatv(opVariadicEqualPrefixTemplate, sanitizeName(element.name),
                       kind, numSimpleLength, numVariadicGroups,
-                      numPrecedingSimple, numPrecedingVariadic);
+                      numPrecedingSimple, numPrecedingVariadic, type);
         os << formatv(element.isVariableLength()
                           ? opVariadicEqualVariadicTemplate
                           : opVariadicEqualSimpleTemplate,
@@ -413,13 +436,23 @@ static void emitElementAccessors(
       if (element.name.empty())
         continue;
       std::string trailing;
-      if (!element.isVariableLength())
-        trailing = "[0]";
-      else if (element.isOptional())
-        trailing = std::string(
-            formatv(opVariadicSegmentOptionalTrailingTemplate, kind));
+      std::string type = std::strcmp(kind, "operand") == 0
+                             ? "_ods_ir.OpOperandList"
+                             : "_ods_ir.OpResultList";
+      if (!element.isVariableLength() || element.isOptional()) {
+        type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.Value"
+                                                 : "_ods_ir.OpResult";
+        if (!element.isVariableLength()) {
+          trailing = "[0]";
+        } else if (element.isOptional()) {
+          type = "_Optional[" + type + "]";
+          trailing = std::string(
+              formatv(opVariadicSegmentOptionalTrailingTemplate, kind));
+        }
+      }
+
       os << formatv(opVariadicSegmentTemplate, sanitizeName(element.name), kind,
-                    i, trailing);
+                    i, trailing, type);
     }
     return;
   }
@@ -980,8 +1013,9 @@ static void emitRegionAccessors(const Operator &op, raw_ostream &os) {
     assert((!region.isVariadic() || en.index() == op.getNumRegions() - 1) &&
            "expected only the last region to be variadic");
     os << formatv(regionAccessorTemplate, sanitizeName(region.name),
-                  std::to_string(en.index()) +
-                      (region.isVariadic() ? ":" : ""));
+                  std::to_string(en.index()) + (region.isVariadic() ? ":" : ""),
+                  region.isVariadic() ? "_ods_ir.RegionSequence"
+                                      : "_ods_ir.Region");
   }
 }
 



More information about the Mlir-commits mailing list