[Mlir-commits] [mlir] [MLIR][Python] add type hints for accessors (PR #158455)
Maksim Levental
llvmlistbot at llvm.org
Sun Sep 14 17:50:30 PDT 2025
https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/158455
>From c986aac469b8d93e65952b310c89a188ecb27e95 Mon Sep 17 00:00:00 2001
From: Maksim Levental <maksim.levental at gmail.com>
Date: Sun, 14 Sep 2025 01:01:30 -0400
Subject: [PATCH] [MLIR][Python] add type hints for accessors
---
mlir/test/mlir-tblgen/op-python-bindings.td | 70 ++++++++--------
mlir/test/python/dialects/python_test.py | 27 +++++-
mlir/test/python/python_test_ops.td | 5 ++
mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp | 84 +++++++++++++------
4 files changed, 122 insertions(+), 64 deletions(-)
diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td
index 3ec69c33b4bb9..d943d0f590c33 100644
--- a/mlir/test/mlir-tblgen/op-python-bindings.td
+++ b/mlir/test/mlir-tblgen/op-python-bindings.td
@@ -36,7 +36,7 @@ def AttrSizedOperandsOp : TestOp<"attr_sized_operands",
// CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip)
// CHECK: @builtins.property
- // CHECK: def variadic1(self):
+ // CHECK: def variadic1(self) -> _ods_ir.OpOperandList:
// CHECK: operand_range = _ods_segmented_accessor(
// CHECK: self.operation.operands,
// CHECK: self.operation.attributes["operandSegmentSizes"], 0)
@@ -44,14 +44,14 @@ def AttrSizedOperandsOp : TestOp<"attr_sized_operands",
// CHECK-NOT: if len(operand_range)
//
// CHECK: @builtins.property
- // CHECK: def non_variadic(self):
+ // CHECK: def non_variadic(self) -> _ods_ir.Value:
// CHECK: operand_range = _ods_segmented_accessor(
// CHECK: self.operation.operands,
// CHECK: self.operation.attributes["operandSegmentSizes"], 1)
// CHECK: return operand_range[0]
//
// CHECK: @builtins.property
- // CHECK: def variadic2(self):
+ // CHECK: def variadic2(self) -> _Optional[_ods_ir.Value]:
// CHECK: operand_range = _ods_segmented_accessor(
// CHECK: self.operation.operands,
// CHECK: self.operation.attributes["operandSegmentSizes"], 2)
@@ -84,21 +84,21 @@ def AttrSizedResultsOp : TestOp<"attr_sized_results",
// CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip)
// CHECK: @builtins.property
- // CHECK: def variadic1(self):
+ // CHECK: def variadic1(self) -> _Optional[_ods_ir.OpResult]:
// CHECK: result_range = _ods_segmented_accessor(
// CHECK: self.operation.results,
// CHECK: self.operation.attributes["resultSegmentSizes"], 0)
// CHECK: return result_range[0] if len(result_range) > 0 else None
//
// CHECK: @builtins.property
- // CHECK: def non_variadic(self):
+ // CHECK: def non_variadic(self) -> _ods_ir.OpResult:
// CHECK: result_range = _ods_segmented_accessor(
// CHECK: self.operation.results,
// CHECK: self.operation.attributes["resultSegmentSizes"], 1)
// CHECK: return result_range[0]
//
// CHECK: @builtins.property
- // CHECK: def variadic2(self):
+ // CHECK: def variadic2(self) -> _ods_ir.OpResultList:
// CHECK: result_range = _ods_segmented_accessor(
// CHECK: self.operation.results,
// CHECK: self.operation.attributes["resultSegmentSizes"], 2)
@@ -138,21 +138,21 @@ def AttributedOp : TestOp<"attributed_op"> {
// CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip)
// CHECK: @builtins.property
- // CHECK: def i32attr(self):
+ // CHECK: def i32attr(self) -> _ods_ir.Attribute:
// CHECK: return self.operation.attributes["i32attr"]
// CHECK: @builtins.property
- // CHECK: def optionalF32Attr(self):
+ // CHECK: def optionalF32Attr(self) -> _Optional[_ods_ir.Attribute]:
// CHECK: if "optionalF32Attr" not in self.operation.attributes:
// CHECK: return None
// CHECK: return self.operation.attributes["optionalF32Attr"]
// CHECK: @builtins.property
- // CHECK: def unitAttr(self):
+ // CHECK: def unitAttr(self) -> bool:
// CHECK: return "unitAttr" in self.operation.attributes
// CHECK: @builtins.property
- // CHECK: def in_(self):
+ // CHECK: def in_(self) -> _ods_ir.Attribute:
// CHECK: return self.operation.attributes["in"]
let arguments = (ins I32Attr:$i32attr, OptionalAttr<F32Attr>:$optionalF32Attr,
@@ -185,11 +185,11 @@ def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> {
// CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip)
// CHECK: @builtins.property
- // CHECK: def in_(self):
+ // CHECK: def in_(self) -> bool:
// CHECK: return "in" in self.operation.attributes
// CHECK: @builtins.property
- // CHECK: def is_(self):
+ // CHECK: def is_(self) -> _Optional[_ods_ir.Attribute]:
// CHECK: if "is" not in self.operation.attributes:
// CHECK: return None
// CHECK: return self.operation.attributes["is"]
@@ -320,16 +320,16 @@ def MissingNamesOp : TestOp<"missing_names"> {
// CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip)
// CHECK: @builtins.property
- // CHECK: def f32(self):
+ // CHECK: def f32(self) -> _ods_ir.Value:
// CHECK: return self.operation.operands[1]
let arguments = (ins I32, F32:$f32, I64);
// CHECK: @builtins.property
- // CHECK: def i32(self):
+ // CHECK: def i32(self) -> _ods_ir.OpResult:
// CHECK: return self.operation.results[0]
//
// CHECK: @builtins.property
- // CHECK: def i64(self):
+ // CHECK: def i64(self) -> _ods_ir.OpResult:
// CHECK: return self.operation.results[2]
let results = (outs I32:$i32, AnyFloat, I64:$i64);
}
@@ -358,11 +358,11 @@ def OneOptionalOperandOp : TestOp<"one_optional_operand"> {
// CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip)
// CHECK: @builtins.property
- // CHECK: def non_optional(self):
+ // CHECK: def non_optional(self) -> _ods_ir.Value:
// CHECK: return self.operation.operands[0]
// CHECK: @builtins.property
- // CHECK: def optional(self):
+ // CHECK: def optional(self) -> _Optional[_ods_ir.Value]:
// CHECK: return None if len(self.operation.operands) < 2 else self.operation.operands[1]
}
@@ -389,11 +389,11 @@ def OneVariadicOperandOp : TestOp<"one_variadic_operand"> {
// CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip)
// CHECK: @builtins.property
- // CHECK: def non_variadic(self):
+ // CHECK: def non_variadic(self) -> _ods_ir.Value:
// CHECK: return self.operation.operands[0]
//
// CHECK: @builtins.property
- // CHECK: def variadic(self):
+ // CHECK: def variadic(self) -> _ods_ir.OpOperandList:
// CHECK: _ods_variadic_group_length = len(self.operation.operands) - 2 + 1
// CHECK: return self.operation.operands[1:1 + _ods_variadic_group_length]
let arguments = (ins AnyType:$non_variadic, Variadic<AnyType>:$variadic);
@@ -422,12 +422,12 @@ def OneVariadicResultOp : TestOp<"one_variadic_result"> {
// CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip)
// CHECK: @builtins.property
- // CHECK: def variadic(self):
+ // CHECK: def variadic(self) -> _ods_ir.OpResultList:
// CHECK: _ods_variadic_group_length = len(self.operation.results) - 2 + 1
// CHECK: return self.operation.results[0:0 + _ods_variadic_group_length]
//
// CHECK: @builtins.property
- // CHECK: def non_variadic(self):
+ // CHECK: def non_variadic(self) -> _ods_ir.OpResult:
// CHECK: _ods_variadic_group_length = len(self.operation.results) - 2 + 1
// CHECK: return self.operation.results[1 + _ods_variadic_group_length - 1]
let results = (outs Variadic<AnyType>:$variadic, AnyType:$non_variadic);
@@ -453,7 +453,7 @@ def PythonKeywordOp : TestOp<"python_keyword"> {
// CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip)
// CHECK: @builtins.property
- // CHECK: def in_(self):
+ // CHECK: def in_(self) -> _ods_ir.Value:
// CHECK: return self.operation.operands[0]
let arguments = (ins AnyType:$in);
}
@@ -491,17 +491,17 @@ def SameResultsVariadicOp : TestOp<"same_results_variadic", [SameOperandsAndResu
def SameVariadicOperandSizeOp : TestOp<"same_variadic_operand",
[SameVariadicOperandSize]> {
// CHECK: @builtins.property
- // CHECK: def variadic1(self):
+ // CHECK: def variadic1(self) -> _ods_ir.OpOperandList:
// CHECK: start, elements_per_group = _ods_equally_sized_accessor(self.operation.operands, 1, 2, 0, 0)
// CHECK: return self.operation.operands[start:start + elements_per_group]
//
// CHECK: @builtins.property
- // CHECK: def non_variadic(self):
+ // CHECK: def non_variadic(self) -> _ods_ir.Value:
// CHECK: start, elements_per_group = _ods_equally_sized_accessor(self.operation.operands, 1, 2, 0, 1)
// CHECK: return self.operation.operands[start]
//
// CHECK: @builtins.property
- // CHECK: def variadic2(self):
+ // CHECK: def variadic2(self) -> _ods_ir.OpOperandList:
// CHECK: start, elements_per_group = _ods_equally_sized_accessor(self.operation.operands, 1, 2, 1, 1)
// CHECK: return self.operation.operands[start:start + elements_per_group]
let arguments = (ins Variadic<AnyType>:$variadic1, AnyType:$non_variadic,
@@ -517,17 +517,17 @@ def SameVariadicOperandSizeOp : TestOp<"same_variadic_operand",
def SameVariadicResultSizeOp : TestOp<"same_variadic_result",
[SameVariadicResultSize]> {
// CHECK: @builtins.property
- // CHECK: def variadic1(self):
+ // CHECK: def variadic1(self) -> _ods_ir.OpResultList:
// CHECK: start, elements_per_group = _ods_equally_sized_accessor(self.operation.results, 1, 2, 0, 0)
// CHECK: return self.operation.results[start:start + elements_per_group]
//
// CHECK: @builtins.property
- // CHECK: def non_variadic(self):
+ // CHECK: def non_variadic(self) -> _ods_ir.OpResult:
// CHECK: start, elements_per_group = _ods_equally_sized_accessor(self.operation.results, 1, 2, 0, 1)
// CHECK: return self.operation.results[start]
//
// CHECK: @builtins.property
- // CHECK: def variadic2(self):
+ // CHECK: def variadic2(self) -> _ods_ir.OpResultList:
// CHECK: start, elements_per_group = _ods_equally_sized_accessor(self.operation.results, 1, 2, 1, 1)
// CHECK: return self.operation.results[start:start + elements_per_group]
let results = (outs Variadic<AnyType>:$variadic1, AnyType:$non_variadic,
@@ -557,20 +557,20 @@ def SimpleOp : TestOp<"simple"> {
// CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip)
// CHECK: @builtins.property
- // CHECK: def i32(self):
+ // CHECK: def i32(self) -> _ods_ir.Value:
// CHECK: return self.operation.operands[0]
//
// CHECK: @builtins.property
- // CHECK: def f32(self):
+ // CHECK: def f32(self) -> _ods_ir.Value:
// CHECK: return self.operation.operands[1]
let arguments = (ins I32:$i32, F32:$f32);
// CHECK: @builtins.property
- // CHECK: def i64(self):
+ // CHECK: def i64(self) -> _ods_ir.OpResult:
// CHECK: return self.operation.results[0]
//
// CHECK: @builtins.property
- // CHECK: def f64(self):
+ // CHECK: def f64(self) -> _ods_ir.OpResult:
// CHECK: return self.operation.results[1]
let results = (outs I64:$i64, AnyFloat:$f64);
}
@@ -595,11 +595,11 @@ def VariadicAndNormalRegionOp : TestOp<"variadic_and_normal_region"> {
let regions = (region AnyRegion:$region, AnyRegion, VariadicRegion<AnyRegion>:$variadic);
// CHECK: @builtins.property
- // CHECK: def region(self):
+ // CHECK: def region(self) -> _ods_ir.Region:
// CHECK: return self.regions[0]
// CHECK: @builtins.property
- // CHECK: def variadic(self):
+ // CHECK: def variadic(self) -> _ods_ir.RegionSequence:
// CHECK: return self.regions[2:]
}
@@ -623,7 +623,7 @@ def VariadicRegionOp : TestOp<"variadic_region"> {
let regions = (region VariadicRegion<AnyRegion>:$Variadic);
// CHECK: @builtins.property
- // CHECK: def Variadic(self):
+ // CHECK: def Variadic(self) -> _ods_ir.RegionSequence:
// CHECK: return self.regions[0:]
}
diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py
index 68262822ca6b5..8efa45c40f262 100644
--- a/mlir/test/python/dialects/python_test.py
+++ b/mlir/test/python/dialects/python_test.py
@@ -323,6 +323,7 @@ def resultTypesDefinedByTraits():
# CHECK: f32 index
print(no_infer.single.type, no_infer.doubled.type)
+
# CHECK-LABEL: TEST: testOptionalOperandOp
@run
def testOptionalOperandOp():
@@ -615,11 +616,17 @@ def values(lst):
[zero, one], two, [three, four]
)
# CHECK: Value(%{{.*}} = arith.constant 2 : i32)
- print(variadic_operands.non_variadic)
+ non_variadic = variadic_operands.non_variadic
+ print(non_variadic)
+ assert isinstance(non_variadic, Value)
# CHECK: ['Value(%{{.*}} = arith.constant 0 : i32)', 'Value(%{{.*}} = arith.constant 1 : i32)']
- print(values(variadic_operands.variadic1))
+ variadic1 = variadic_operands.variadic1
+ print(values(variadic1))
+ assert isinstance(variadic1, OpOperandList)
# CHECK: ['Value(%{{.*}} = arith.constant 3 : i32)', 'Value(%{{.*}} = arith.constant 4 : i32)']
- print(values(variadic_operands.variadic2))
+ variadic2 = variadic_operands.variadic2
+ print(values(variadic2))
+ assert isinstance(variadic2, OpOperandList)
# CHECK-LABEL: TEST: testVariadicResultAccess
@@ -660,7 +667,9 @@ def types(lst):
# CHECK: i1
print(op.non_variadic2.type)
# CHECK: [IntegerType(i2), IntegerType(i3), IntegerType(i4)]
- print(types(op.variadic))
+ variadic = op.variadic
+ print(types(variadic))
+ assert isinstance(variadic, OpResultList)
# Test Variadic-Variadic-Fixed
op = test.SameVariadicResultSizeOpVVF(
@@ -713,3 +722,13 @@ def types(lst):
print(types(op.variadic2))
# CHECK: i4
print(op.non_variadic3.type)
+
+# CHECK-LABEL: TEST: testVariadicAndNormalRegionOp
+@run
+def testVariadicAndNormalRegionOp():
+ with Context() as ctx, Location.unknown(ctx):
+ module = Module.create()
+ with InsertionPoint(module.body):
+ region_op = test.VariadicAndNormalRegionOp(2)
+ assert isinstance(region_op.region, Region)
+ assert isinstance(region_op.variadic, RegionSequence)
diff --git a/mlir/test/python/python_test_ops.td b/mlir/test/python/python_test_ops.td
index 026e64a3cfc19..5f45aafcb2fda 100644
--- a/mlir/test/python/python_test_ops.td
+++ b/mlir/test/python/python_test_ops.td
@@ -265,4 +265,9 @@ def SameVariadicResultSizeOpFVFVF : TestOp<"same_variadic_result_fvfvf",
AnyType:$non_variadic3);
}
+def VariadicAndNormalRegionOp : TestOp<"variadic_and_normal_region"> {
+ let regions = (region AnyRegion:$region, VariadicRegion<AnyRegion>:$variadic);
+ let assemblyFormat = "$region $variadic attr-dict";
+}
+
#endif // PYTHON_TEST_OPS
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 6a7aa9e3432d5..b975120703c17 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -44,7 +44,7 @@ _ods_ir = _ods_cext.ir
_ods_cext.globals.register_traceback_file_exclusion(__file__)
import builtins
-from typing import Sequence as _Sequence, Union as _Union
+from typing import Sequence as _Sequence, Union as _Union, Optional as _Optional
)Py";
@@ -93,9 +93,10 @@ constexpr const char *opClassRegionSpecTemplate = R"Py(
/// {0} is the name of the accessor;
/// {1} is either 'operand' or 'result';
/// {2} is the position in the element list.
+/// {3} is the type hint.
constexpr const char *opSingleTemplate = R"Py(
@builtins.property
- def {0}(self):
+ def {0}(self) -> {3}:
return self.operation.{1}s[{2}]
)Py";
@@ -104,11 +105,12 @@ constexpr const char *opSingleTemplate = R"Py(
/// {1} is either 'operand' or 'result';
/// {2} is the total number of element groups;
/// {3} is the position of the current group in the group list.
+/// {4} is the type hint.
/// This works for both a single variadic group (non-negative length) and an
/// single optional element (zero length if the element is absent).
constexpr const char *opSingleAfterVariableTemplate = R"Py(
@builtins.property
- def {0}(self):
+ def {0}(self) -> {4}:
_ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1
return self.operation.{1}s[{3} + _ods_variadic_group_length - 1]
)Py";
@@ -118,12 +120,13 @@ constexpr const char *opSingleAfterVariableTemplate = R"Py(
/// {1} is either 'operand' or 'result';
/// {2} is the total number of element groups;
/// {3} is the position of the current group in the group list.
+/// {4} is the type hint.
/// This works if we have only one variable-length group (and it's the optional
/// operand/result): we can deduce it's absent if the `len(operation.{1}s)` is
/// smaller than the total number of groups.
constexpr const char *opOneOptionalTemplate = R"Py(
@builtins.property
- def {0}(self):
+ def {0}(self) -> _Optional[{4}]:
return None if len(self.operation.{1}s) < {2} else self.operation.{1}s[{3}]
)Py";
@@ -132,9 +135,10 @@ constexpr const char *opOneOptionalTemplate = R"Py(
/// {1} is either 'operand' or 'result';
/// {2} is the total number of element groups;
/// {3} is the position of the current group in the group list.
+/// {4} is the type hint.
constexpr const char *opOneVariadicTemplate = R"Py(
@builtins.property
- def {0}(self):
+ def {0}(self) -> {4}:
_ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1
return self.operation.{1}s[{3}:{3} + _ods_variadic_group_length]
)Py";
@@ -146,9 +150,10 @@ constexpr const char *opOneVariadicTemplate = R"Py(
/// {3} is the total number of variadic groups;
/// {4} is the number of non-variadic groups preceding the current group;
/// {5} is the number of variadic groups preceding the current group.
+/// {6} is the type hint.
constexpr const char *opVariadicEqualPrefixTemplate = R"Py(
@builtins.property
- def {0}(self):
+ def {0}(self) -> {6}:
start, elements_per_group = _ods_equally_sized_accessor(self.operation.{1}s, {2}, {3}, {4}, {5}))Py";
/// Second part of the template for equally-sized case, accessing a single
@@ -171,9 +176,10 @@ constexpr const char *opVariadicEqualVariadicTemplate = R"Py(
/// {2} is the position of the group in the group list;
/// {3} is a return suffix (expected [0] for single-element, empty for
/// variadic, and opVariadicSegmentOptionalTrailingTemplate for optional).
+/// {4} is the type hint.
constexpr const char *opVariadicSegmentTemplate = R"Py(
@builtins.property
- def {0}(self):
+ def {0}(self) -> {4}:
{1}_range = _ods_segmented_accessor(
self.operation.{1}s,
self.operation.attributes["{1}SegmentSizes"], {2})
@@ -191,7 +197,7 @@ constexpr const char *opVariadicSegmentOptionalTrailingTemplate =
/// {1} is the original name of the attribute.
constexpr const char *attributeGetterTemplate = R"Py(
@builtins.property
- def {0}(self):
+ def {0}(self) -> _ods_ir.Attribute:
return self.operation.attributes["{1}"]
)Py";
@@ -200,7 +206,7 @@ constexpr const char *attributeGetterTemplate = R"Py(
/// {1} is the original name of the attribute.
constexpr const char *optionalAttributeGetterTemplate = R"Py(
@builtins.property
- def {0}(self):
+ def {0}(self) -> _Optional[_ods_ir.Attribute]:
if "{1}" not in self.operation.attributes:
return None
return self.operation.attributes["{1}"]
@@ -213,7 +219,7 @@ constexpr const char *optionalAttributeGetterTemplate = R"Py(
/// {1} is the original name of the attribute.
constexpr const char *unitAttributeGetterTemplate = R"Py(
@builtins.property
- def {0}(self):
+ def {0}(self) -> bool:
return "{1}" in self.operation.attributes
)Py";
@@ -266,7 +272,7 @@ constexpr const char *attributeDeleterTemplate = R"Py(
constexpr const char *regionAccessorTemplate = R"Py(
@builtins.property
- def {0}(self):
+ def {0}(self) -> {2}:
return self.regions[{1}]
)Py";
@@ -357,15 +363,24 @@ static void emitElementAccessors(
seenVariableLength = true;
if (element.name.empty())
continue;
+ const char *type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.Value"
+ : "_ods_ir.OpResult";
if (element.isVariableLength()) {
- os << formatv(element.isOptional() ? opOneOptionalTemplate
- : opOneVariadicTemplate,
- sanitizeName(element.name), kind, numElements, i);
+ if (element.isOptional()) {
+ os << formatv(opOneOptionalTemplate, sanitizeName(element.name), kind,
+ numElements, i, type);
+ } else {
+ type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.OpOperandList"
+ : "_ods_ir.OpResultList";
+ os << formatv(opOneVariadicTemplate, sanitizeName(element.name), kind,
+ numElements, i, type);
+ }
} else if (seenVariableLength) {
os << formatv(opSingleAfterVariableTemplate, sanitizeName(element.name),
- kind, numElements, i);
+ kind, numElements, i, type);
} else {
- os << formatv(opSingleTemplate, sanitizeName(element.name), kind, i);
+ os << formatv(opSingleTemplate, sanitizeName(element.name), kind, i,
+ type);
}
}
return;
@@ -388,9 +403,17 @@ static void emitElementAccessors(
for (unsigned i = 0; i < numElements; ++i) {
const NamedTypeConstraint &element = getElement(op, i);
if (!element.name.empty()) {
+ std::string type;
+ if (element.isVariableLength()) {
+ type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.OpOperandList"
+ : "_ods_ir.OpResultList";
+ } else {
+ type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.Value"
+ : "_ods_ir.OpResult";
+ }
os << formatv(opVariadicEqualPrefixTemplate, sanitizeName(element.name),
kind, numSimpleLength, numVariadicGroups,
- numPrecedingSimple, numPrecedingVariadic);
+ numPrecedingSimple, numPrecedingVariadic, type);
os << formatv(element.isVariableLength()
? opVariadicEqualVariadicTemplate
: opVariadicEqualSimpleTemplate,
@@ -413,13 +436,23 @@ static void emitElementAccessors(
if (element.name.empty())
continue;
std::string trailing;
- if (!element.isVariableLength())
- trailing = "[0]";
- else if (element.isOptional())
- trailing = std::string(
- formatv(opVariadicSegmentOptionalTrailingTemplate, kind));
+ std::string type = std::strcmp(kind, "operand") == 0
+ ? "_ods_ir.OpOperandList"
+ : "_ods_ir.OpResultList";
+ if (!element.isVariableLength() || element.isOptional()) {
+ type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.Value"
+ : "_ods_ir.OpResult";
+ if (!element.isVariableLength()) {
+ trailing = "[0]";
+ } else if (element.isOptional()) {
+ type = "_Optional[" + type + "]";
+ trailing = std::string(
+ formatv(opVariadicSegmentOptionalTrailingTemplate, kind));
+ }
+ }
+
os << formatv(opVariadicSegmentTemplate, sanitizeName(element.name), kind,
- i, trailing);
+ i, trailing, type);
}
return;
}
@@ -980,8 +1013,9 @@ static void emitRegionAccessors(const Operator &op, raw_ostream &os) {
assert((!region.isVariadic() || en.index() == op.getNumRegions() - 1) &&
"expected only the last region to be variadic");
os << formatv(regionAccessorTemplate, sanitizeName(region.name),
- std::to_string(en.index()) +
- (region.isVariadic() ? ":" : ""));
+ std::to_string(en.index()) + (region.isVariadic() ? ":" : ""),
+ region.isVariadic() ? "_ods_ir.RegionSequence"
+ : "_ods_ir.Region");
}
}
More information about the Mlir-commits
mailing list