[Mlir-commits] [mlir] [MLIR][Python] add type hints for accessors (PR #158455)
Maksim Levental
llvmlistbot at llvm.org
Wed Sep 17 14:06:48 PDT 2025
https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/158455
>From dca51ee38f4602974e33101849d3cc9383b28252 Mon Sep 17 00:00:00 2001
From: Maksim Levental <maksim.levental at gmail.com>
Date: Sun, 14 Sep 2025 01:01:30 -0400
Subject: [PATCH 1/3] [MLIR][Python] add type hints for accessors
---
mlir/test/mlir-tblgen/op-python-bindings.td | 70 ++++++++--------
mlir/test/python/dialects/python_test.py | 27 +++++-
mlir/test/python/python_test_ops.td | 5 ++
mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp | 84 +++++++++++++------
4 files changed, 122 insertions(+), 64 deletions(-)
diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td
index 90feec9ed8d6b..432cfaef4d7d9 100644
--- a/mlir/test/mlir-tblgen/op-python-bindings.td
+++ b/mlir/test/mlir-tblgen/op-python-bindings.td
@@ -36,7 +36,7 @@ def AttrSizedOperandsOp : TestOp<"attr_sized_operands",
// CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip)
// CHECK: @builtins.property
- // CHECK: def variadic1(self):
+ // CHECK: def variadic1(self) -> _ods_ir.OpOperandList:
// CHECK: operand_range = _ods_segmented_accessor(
// CHECK: self.operation.operands,
// CHECK: self.operation.attributes["operandSegmentSizes"], 0)
@@ -44,14 +44,14 @@ def AttrSizedOperandsOp : TestOp<"attr_sized_operands",
// CHECK-NOT: if len(operand_range)
//
// CHECK: @builtins.property
- // CHECK: def non_variadic(self):
+ // CHECK: def non_variadic(self) -> _ods_ir.Value:
// CHECK: operand_range = _ods_segmented_accessor(
// CHECK: self.operation.operands,
// CHECK: self.operation.attributes["operandSegmentSizes"], 1)
// CHECK: return operand_range[0]
//
// CHECK: @builtins.property
- // CHECK: def variadic2(self):
+ // CHECK: def variadic2(self) -> _Optional[_ods_ir.Value]:
// CHECK: operand_range = _ods_segmented_accessor(
// CHECK: self.operation.operands,
// CHECK: self.operation.attributes["operandSegmentSizes"], 2)
@@ -84,21 +84,21 @@ def AttrSizedResultsOp : TestOp<"attr_sized_results",
// CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip)
// CHECK: @builtins.property
- // CHECK: def variadic1(self):
+ // CHECK: def variadic1(self) -> _Optional[_ods_ir.OpResult]:
// CHECK: result_range = _ods_segmented_accessor(
// CHECK: self.operation.results,
// CHECK: self.operation.attributes["resultSegmentSizes"], 0)
// CHECK: return result_range[0] if len(result_range) > 0 else None
//
// CHECK: @builtins.property
- // CHECK: def non_variadic(self):
+ // CHECK: def non_variadic(self) -> _ods_ir.OpResult:
// CHECK: result_range = _ods_segmented_accessor(
// CHECK: self.operation.results,
// CHECK: self.operation.attributes["resultSegmentSizes"], 1)
// CHECK: return result_range[0]
//
// CHECK: @builtins.property
- // CHECK: def variadic2(self):
+ // CHECK: def variadic2(self) -> _ods_ir.OpResultList:
// CHECK: result_range = _ods_segmented_accessor(
// CHECK: self.operation.results,
// CHECK: self.operation.attributes["resultSegmentSizes"], 2)
@@ -139,21 +139,21 @@ def AttributedOp : TestOp<"attributed_op"> {
// CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip)
// CHECK: @builtins.property
- // CHECK: def i32attr(self):
+ // CHECK: def i32attr(self) -> _ods_ir.Attribute:
// CHECK: return self.operation.attributes["i32attr"]
// CHECK: @builtins.property
- // CHECK: def optionalF32Attr(self):
+ // CHECK: def optionalF32Attr(self) -> _Optional[_ods_ir.Attribute]:
// CHECK: if "optionalF32Attr" not in self.operation.attributes:
// CHECK: return None
// CHECK: return self.operation.attributes["optionalF32Attr"]
// CHECK: @builtins.property
- // CHECK: def unitAttr(self):
+ // CHECK: def unitAttr(self) -> bool:
// CHECK: return "unitAttr" in self.operation.attributes
// CHECK: @builtins.property
- // CHECK: def in_(self):
+ // CHECK: def in_(self) -> _ods_ir.Attribute:
// CHECK: return self.operation.attributes["in"]
let arguments = (ins I32Attr:$i32attr, OptionalAttr<F32Attr>:$optionalF32Attr,
@@ -186,11 +186,11 @@ def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> {
// CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip)
// CHECK: @builtins.property
- // CHECK: def in_(self):
+ // CHECK: def in_(self) -> bool:
// CHECK: return "in" in self.operation.attributes
// CHECK: @builtins.property
- // CHECK: def is_(self):
+ // CHECK: def is_(self) -> _Optional[_ods_ir.Attribute]:
// CHECK: if "is" not in self.operation.attributes:
// CHECK: return None
// CHECK: return self.operation.attributes["is"]
@@ -322,16 +322,16 @@ def MissingNamesOp : TestOp<"missing_names"> {
// CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip)
// CHECK: @builtins.property
- // CHECK: def f32(self):
+ // CHECK: def f32(self) -> _ods_ir.Value:
// CHECK: return self.operation.operands[1]
let arguments = (ins I32, F32:$f32, I64);
// CHECK: @builtins.property
- // CHECK: def i32(self):
+ // CHECK: def i32(self) -> _ods_ir.OpResult:
// CHECK: return self.operation.results[0]
//
// CHECK: @builtins.property
- // CHECK: def i64(self):
+ // CHECK: def i64(self) -> _ods_ir.OpResult:
// CHECK: return self.operation.results[2]
let results = (outs I32:$i32, AnyFloat, I64:$i64);
}
@@ -360,11 +360,11 @@ def OneOptionalOperandOp : TestOp<"one_optional_operand"> {
// CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip)
// CHECK: @builtins.property
- // CHECK: def non_optional(self):
+ // CHECK: def non_optional(self) -> _ods_ir.Value:
// CHECK: return self.operation.operands[0]
// CHECK: @builtins.property
- // CHECK: def optional(self):
+ // CHECK: def optional(self) -> _Optional[_ods_ir.Value]:
// CHECK: return None if len(self.operation.operands) < 2 else self.operation.operands[1]
}
@@ -391,11 +391,11 @@ def OneVariadicOperandOp : TestOp<"one_variadic_operand"> {
// CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip)
// CHECK: @builtins.property
- // CHECK: def non_variadic(self):
+ // CHECK: def non_variadic(self) -> _ods_ir.Value:
// CHECK: return self.operation.operands[0]
//
// CHECK: @builtins.property
- // CHECK: def variadic(self):
+ // CHECK: def variadic(self) -> _ods_ir.OpOperandList:
// CHECK: _ods_variadic_group_length = len(self.operation.operands) - 2 + 1
// CHECK: return self.operation.operands[1:1 + _ods_variadic_group_length]
let arguments = (ins AnyType:$non_variadic, Variadic<AnyType>:$variadic);
@@ -424,12 +424,12 @@ def OneVariadicResultOp : TestOp<"one_variadic_result"> {
// CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip)
// CHECK: @builtins.property
- // CHECK: def variadic(self):
+ // CHECK: def variadic(self) -> _ods_ir.OpResultList:
// CHECK: _ods_variadic_group_length = len(self.operation.results) - 2 + 1
// CHECK: return self.operation.results[0:0 + _ods_variadic_group_length]
//
// CHECK: @builtins.property
- // CHECK: def non_variadic(self):
+ // CHECK: def non_variadic(self) -> _ods_ir.OpResult:
// CHECK: _ods_variadic_group_length = len(self.operation.results) - 2 + 1
// CHECK: return self.operation.results[1 + _ods_variadic_group_length - 1]
let results = (outs Variadic<AnyType>:$variadic, AnyType:$non_variadic);
@@ -456,7 +456,7 @@ def PythonKeywordOp : TestOp<"python_keyword"> {
// CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip)
// CHECK: @builtins.property
- // CHECK: def in_(self):
+ // CHECK: def in_(self) -> _ods_ir.Value:
// CHECK: return self.operation.operands[0]
let arguments = (ins AnyType:$in);
}
@@ -495,17 +495,17 @@ def SameResultsVariadicOp : TestOp<"same_results_variadic", [SameOperandsAndResu
def SameVariadicOperandSizeOp : TestOp<"same_variadic_operand",
[SameVariadicOperandSize]> {
// CHECK: @builtins.property
- // CHECK: def variadic1(self):
+ // CHECK: def variadic1(self) -> _ods_ir.OpOperandList:
// CHECK: start, elements_per_group = _ods_equally_sized_accessor(self.operation.operands, 1, 2, 0, 0)
// CHECK: return self.operation.operands[start:start + elements_per_group]
//
// CHECK: @builtins.property
- // CHECK: def non_variadic(self):
+ // CHECK: def non_variadic(self) -> _ods_ir.Value:
// CHECK: start, elements_per_group = _ods_equally_sized_accessor(self.operation.operands, 1, 2, 0, 1)
// CHECK: return self.operation.operands[start]
//
// CHECK: @builtins.property
- // CHECK: def variadic2(self):
+ // CHECK: def variadic2(self) -> _ods_ir.OpOperandList:
// CHECK: start, elements_per_group = _ods_equally_sized_accessor(self.operation.operands, 1, 2, 1, 1)
// CHECK: return self.operation.operands[start:start + elements_per_group]
let arguments = (ins Variadic<AnyType>:$variadic1, AnyType:$non_variadic,
@@ -521,17 +521,17 @@ def SameVariadicOperandSizeOp : TestOp<"same_variadic_operand",
def SameVariadicResultSizeOp : TestOp<"same_variadic_result",
[SameVariadicResultSize]> {
// CHECK: @builtins.property
- // CHECK: def variadic1(self):
+ // CHECK: def variadic1(self) -> _ods_ir.OpResultList:
// CHECK: start, elements_per_group = _ods_equally_sized_accessor(self.operation.results, 1, 2, 0, 0)
// CHECK: return self.operation.results[start:start + elements_per_group]
//
// CHECK: @builtins.property
- // CHECK: def non_variadic(self):
+ // CHECK: def non_variadic(self) -> _ods_ir.OpResult:
// CHECK: start, elements_per_group = _ods_equally_sized_accessor(self.operation.results, 1, 2, 0, 1)
// CHECK: return self.operation.results[start]
//
// CHECK: @builtins.property
- // CHECK: def variadic2(self):
+ // CHECK: def variadic2(self) -> _ods_ir.OpResultList:
// CHECK: start, elements_per_group = _ods_equally_sized_accessor(self.operation.results, 1, 2, 1, 1)
// CHECK: return self.operation.results[start:start + elements_per_group]
let results = (outs Variadic<AnyType>:$variadic1, AnyType:$non_variadic,
@@ -562,20 +562,20 @@ def SimpleOp : TestOp<"simple"> {
// CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip)
// CHECK: @builtins.property
- // CHECK: def i32(self):
+ // CHECK: def i32(self) -> _ods_ir.Value:
// CHECK: return self.operation.operands[0]
//
// CHECK: @builtins.property
- // CHECK: def f32(self):
+ // CHECK: def f32(self) -> _ods_ir.Value:
// CHECK: return self.operation.operands[1]
let arguments = (ins I32:$i32, F32:$f32);
// CHECK: @builtins.property
- // CHECK: def i64(self):
+ // CHECK: def i64(self) -> _ods_ir.OpResult:
// CHECK: return self.operation.results[0]
//
// CHECK: @builtins.property
- // CHECK: def f64(self):
+ // CHECK: def f64(self) -> _ods_ir.OpResult:
// CHECK: return self.operation.results[1]
let results = (outs I64:$i64, AnyFloat:$f64);
}
@@ -600,11 +600,11 @@ def VariadicAndNormalRegionOp : TestOp<"variadic_and_normal_region"> {
let regions = (region AnyRegion:$region, AnyRegion, VariadicRegion<AnyRegion>:$variadic);
// CHECK: @builtins.property
- // CHECK: def region(self):
+ // CHECK: def region(self) -> _ods_ir.Region:
// CHECK: return self.regions[0]
// CHECK: @builtins.property
- // CHECK: def variadic(self):
+ // CHECK: def variadic(self) -> _ods_ir.RegionSequence:
// CHECK: return self.regions[2:]
}
@@ -628,7 +628,7 @@ def VariadicRegionOp : TestOp<"variadic_region"> {
let regions = (region VariadicRegion<AnyRegion>:$Variadic);
// CHECK: @builtins.property
- // CHECK: def Variadic(self):
+ // CHECK: def Variadic(self) -> _ods_ir.RegionSequence:
// CHECK: return self.regions[0:]
}
diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py
index 17aaef7e1b9f4..db822c641bf47 100644
--- a/mlir/test/python/dialects/python_test.py
+++ b/mlir/test/python/dialects/python_test.py
@@ -629,11 +629,17 @@ def values(lst):
[zero, one], two, [three, four]
)
# CHECK: Value(%{{.*}} = arith.constant 2 : i32)
- print(variadic_operands.non_variadic)
+ non_variadic = variadic_operands.non_variadic
+ print(non_variadic)
+ assert isinstance(non_variadic, Value)
# CHECK: ['Value(%{{.*}} = arith.constant 0 : i32)', 'Value(%{{.*}} = arith.constant 1 : i32)']
- print(values(variadic_operands.variadic1))
+ variadic1 = variadic_operands.variadic1
+ print(values(variadic1))
+ assert isinstance(variadic1, OpOperandList)
# CHECK: ['Value(%{{.*}} = arith.constant 3 : i32)', 'Value(%{{.*}} = arith.constant 4 : i32)']
- print(values(variadic_operands.variadic2))
+ variadic2 = variadic_operands.variadic2
+ print(values(variadic2))
+ assert isinstance(variadic2, OpOperandList)
assert (
inspect.signature(test.same_variadic_operand).return_annotation
@@ -692,7 +698,9 @@ def types(lst):
# CHECK: i1
print(op.non_variadic2.type)
# CHECK: [IntegerType(i2), IntegerType(i3), IntegerType(i4)]
- print(types(op.variadic))
+ variadic = op.variadic
+ print(types(variadic))
+ assert isinstance(variadic, OpResultList)
# Test Variadic-Variadic-Fixed
op = test.SameVariadicResultSizeOpVVF(
@@ -754,3 +762,14 @@ def types(lst):
test.results_variadic([i[0]]),
OpResult,
)
+
+
+# CHECK-LABEL: TEST: testVariadicAndNormalRegion
+@run
+def testVariadicAndNormalRegionOp():
+ with Context() as ctx, Location.unknown(ctx):
+ module = Module.create()
+ with InsertionPoint(module.body):
+ region_op = test.VariadicAndNormalRegionOp(2)
+ assert isinstance(region_op.region, Region)
+ assert isinstance(region_op.variadic, RegionSequence)
diff --git a/mlir/test/python/python_test_ops.td b/mlir/test/python/python_test_ops.td
index 1e94b94dc714b..cfc1d72bb479d 100644
--- a/mlir/test/python/python_test_ops.td
+++ b/mlir/test/python/python_test_ops.td
@@ -269,4 +269,9 @@ def ResultsVariadicOp : TestOp<"results_variadic"> {
let results = (outs Variadic<AnyType>:$res);
}
+def VariadicAndNormalRegionOp : TestOp<"variadic_and_normal_region"> {
+ let regions = (region AnyRegion:$region, VariadicRegion<AnyRegion>:$variadic);
+ let assemblyFormat = "$region $variadic attr-dict";
+}
+
#endif // PYTHON_TEST_OPS
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 21f712e85e6c0..2e33581fbbff8 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -43,7 +43,7 @@ _ods_ir = _ods_cext.ir
_ods_cext.globals.register_traceback_file_exclusion(__file__)
import builtins
-from typing import Sequence as _Sequence, Union as _Union
+from typing import Sequence as _Sequence, Union as _Union, Optional as _Optional
)Py";
@@ -92,9 +92,10 @@ constexpr const char *opClassRegionSpecTemplate = R"Py(
/// {0} is the name of the accessor;
/// {1} is either 'operand' or 'result';
/// {2} is the position in the element list.
+/// {3} is the type hint.
constexpr const char *opSingleTemplate = R"Py(
@builtins.property
- def {0}(self):
+ def {0}(self) -> {3}:
return self.operation.{1}s[{2}]
)Py";
@@ -103,11 +104,12 @@ constexpr const char *opSingleTemplate = R"Py(
/// {1} is either 'operand' or 'result';
/// {2} is the total number of element groups;
/// {3} is the position of the current group in the group list.
+/// {4} is the type hint.
/// This works for both a single variadic group (non-negative length) and an
/// single optional element (zero length if the element is absent).
constexpr const char *opSingleAfterVariableTemplate = R"Py(
@builtins.property
- def {0}(self):
+ def {0}(self) -> {4}:
_ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1
return self.operation.{1}s[{3} + _ods_variadic_group_length - 1]
)Py";
@@ -117,12 +119,13 @@ constexpr const char *opSingleAfterVariableTemplate = R"Py(
/// {1} is either 'operand' or 'result';
/// {2} is the total number of element groups;
/// {3} is the position of the current group in the group list.
+/// {4} is the type hint.
/// This works if we have only one variable-length group (and it's the optional
/// operand/result): we can deduce it's absent if the `len(operation.{1}s)` is
/// smaller than the total number of groups.
constexpr const char *opOneOptionalTemplate = R"Py(
@builtins.property
- def {0}(self):
+ def {0}(self) -> _Optional[{4}]:
return None if len(self.operation.{1}s) < {2} else self.operation.{1}s[{3}]
)Py";
@@ -131,9 +134,10 @@ constexpr const char *opOneOptionalTemplate = R"Py(
/// {1} is either 'operand' or 'result';
/// {2} is the total number of element groups;
/// {3} is the position of the current group in the group list.
+/// {4} is the type hint.
constexpr const char *opOneVariadicTemplate = R"Py(
@builtins.property
- def {0}(self):
+ def {0}(self) -> {4}:
_ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1
return self.operation.{1}s[{3}:{3} + _ods_variadic_group_length]
)Py";
@@ -145,9 +149,10 @@ constexpr const char *opOneVariadicTemplate = R"Py(
/// {3} is the total number of variadic groups;
/// {4} is the number of non-variadic groups preceding the current group;
/// {5} is the number of variadic groups preceding the current group.
+/// {6} is the type hint.
constexpr const char *opVariadicEqualPrefixTemplate = R"Py(
@builtins.property
- def {0}(self):
+ def {0}(self) -> {6}:
start, elements_per_group = _ods_equally_sized_accessor(self.operation.{1}s, {2}, {3}, {4}, {5}))Py";
/// Second part of the template for equally-sized case, accessing a single
@@ -170,9 +175,10 @@ constexpr const char *opVariadicEqualVariadicTemplate = R"Py(
/// {2} is the position of the group in the group list;
/// {3} is a return suffix (expected [0] for single-element, empty for
/// variadic, and opVariadicSegmentOptionalTrailingTemplate for optional).
+/// {4} is the type hint.
constexpr const char *opVariadicSegmentTemplate = R"Py(
@builtins.property
- def {0}(self):
+ def {0}(self) -> {4}:
{1}_range = _ods_segmented_accessor(
self.operation.{1}s,
self.operation.attributes["{1}SegmentSizes"], {2})
@@ -190,7 +196,7 @@ constexpr const char *opVariadicSegmentOptionalTrailingTemplate =
/// {1} is the original name of the attribute.
constexpr const char *attributeGetterTemplate = R"Py(
@builtins.property
- def {0}(self):
+ def {0}(self) -> _ods_ir.Attribute:
return self.operation.attributes["{1}"]
)Py";
@@ -199,7 +205,7 @@ constexpr const char *attributeGetterTemplate = R"Py(
/// {1} is the original name of the attribute.
constexpr const char *optionalAttributeGetterTemplate = R"Py(
@builtins.property
- def {0}(self):
+ def {0}(self) -> _Optional[_ods_ir.Attribute]:
if "{1}" not in self.operation.attributes:
return None
return self.operation.attributes["{1}"]
@@ -212,7 +218,7 @@ constexpr const char *optionalAttributeGetterTemplate = R"Py(
/// {1} is the original name of the attribute.
constexpr const char *unitAttributeGetterTemplate = R"Py(
@builtins.property
- def {0}(self):
+ def {0}(self) -> bool:
return "{1}" in self.operation.attributes
)Py";
@@ -265,7 +271,7 @@ constexpr const char *attributeDeleterTemplate = R"Py(
constexpr const char *regionAccessorTemplate = R"Py(
@builtins.property
- def {0}(self):
+ def {0}(self) -> {2}:
return self.regions[{1}]
)Py";
@@ -357,15 +363,24 @@ static void emitElementAccessors(
seenVariableLength = true;
if (element.name.empty())
continue;
+ const char *type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.Value"
+ : "_ods_ir.OpResult";
if (element.isVariableLength()) {
- os << formatv(element.isOptional() ? opOneOptionalTemplate
- : opOneVariadicTemplate,
- sanitizeName(element.name), kind, numElements, i);
+ if (element.isOptional()) {
+ os << formatv(opOneOptionalTemplate, sanitizeName(element.name), kind,
+ numElements, i, type);
+ } else {
+ type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.OpOperandList"
+ : "_ods_ir.OpResultList";
+ os << formatv(opOneVariadicTemplate, sanitizeName(element.name), kind,
+ numElements, i, type);
+ }
} else if (seenVariableLength) {
os << formatv(opSingleAfterVariableTemplate, sanitizeName(element.name),
- kind, numElements, i);
+ kind, numElements, i, type);
} else {
- os << formatv(opSingleTemplate, sanitizeName(element.name), kind, i);
+ os << formatv(opSingleTemplate, sanitizeName(element.name), kind, i,
+ type);
}
}
return;
@@ -388,9 +403,17 @@ static void emitElementAccessors(
for (unsigned i = 0; i < numElements; ++i) {
const NamedTypeConstraint &element = getElement(op, i);
if (!element.name.empty()) {
+ std::string type;
+ if (element.isVariableLength()) {
+ type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.OpOperandList"
+ : "_ods_ir.OpResultList";
+ } else {
+ type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.Value"
+ : "_ods_ir.OpResult";
+ }
os << formatv(opVariadicEqualPrefixTemplate, sanitizeName(element.name),
kind, numSimpleLength, numVariadicGroups,
- numPrecedingSimple, numPrecedingVariadic);
+ numPrecedingSimple, numPrecedingVariadic, type);
os << formatv(element.isVariableLength()
? opVariadicEqualVariadicTemplate
: opVariadicEqualSimpleTemplate,
@@ -413,13 +436,23 @@ static void emitElementAccessors(
if (element.name.empty())
continue;
std::string trailing;
- if (!element.isVariableLength())
- trailing = "[0]";
- else if (element.isOptional())
- trailing = std::string(
- formatv(opVariadicSegmentOptionalTrailingTemplate, kind));
+ std::string type = std::strcmp(kind, "operand") == 0
+ ? "_ods_ir.OpOperandList"
+ : "_ods_ir.OpResultList";
+ if (!element.isVariableLength() || element.isOptional()) {
+ type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.Value"
+ : "_ods_ir.OpResult";
+ if (!element.isVariableLength()) {
+ trailing = "[0]";
+ } else if (element.isOptional()) {
+ type = "_Optional[" + type + "]";
+ trailing = std::string(
+ formatv(opVariadicSegmentOptionalTrailingTemplate, kind));
+ }
+ }
+
os << formatv(opVariadicSegmentTemplate, sanitizeName(element.name), kind,
- i, trailing);
+ i, trailing, type);
}
return;
}
@@ -980,8 +1013,9 @@ static void emitRegionAccessors(const Operator &op, raw_ostream &os) {
assert((!region.isVariadic() || en.index() == op.getNumRegions() - 1) &&
"expected only the last region to be variadic");
os << formatv(regionAccessorTemplate, sanitizeName(region.name),
- std::to_string(en.index()) +
- (region.isVariadic() ? ":" : ""));
+ std::to_string(en.index()) + (region.isVariadic() ? ":" : ""),
+ region.isVariadic() ? "_ods_ir.RegionSequence"
+ : "_ods_ir.Region");
}
}
>From c2bcd9c1d7401e23197f49c278f00fefb4848a39 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Mon, 15 Sep 2025 12:23:14 -0700
Subject: [PATCH 2/3] add concrete type hints
---
mlir/lib/Bindings/Python/IRAttributes.cpp | 6 +-
mlir/test/mlir-tblgen/op-python-bindings.td | 8 +-
mlir/test/python/dialects/python_test.py | 156 ++++++++++++++++++
mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp | 89 +++++++++-
4 files changed, 244 insertions(+), 15 deletions(-)
diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index b3c768846c74f..4e0ade41fb708 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -1742,9 +1742,9 @@ nb::object integerOrBoolAttributeCaster(PyAttribute &pyAttribute) {
return nb::cast(PyBoolAttribute(pyAttribute));
if (PyIntegerAttribute::isaFunction(pyAttribute))
return nb::cast(PyIntegerAttribute(pyAttribute));
- std::string msg =
- std::string("Can't cast unknown element type DenseArrayAttr (") +
- nb::cast<std::string>(nb::repr(nb::cast(pyAttribute))) + ")";
+ std::string msg = std::string("Can't cast unknown attribute type Attr (") +
+ nb::cast<std::string>(nb::repr(nb::cast(pyAttribute))) +
+ ")";
throw nb::type_error(msg.c_str());
}
diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td
index 432cfaef4d7d9..eee09f9d5cedf 100644
--- a/mlir/test/mlir-tblgen/op-python-bindings.td
+++ b/mlir/test/mlir-tblgen/op-python-bindings.td
@@ -139,11 +139,11 @@ def AttributedOp : TestOp<"attributed_op"> {
// CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip)
// CHECK: @builtins.property
- // CHECK: def i32attr(self) -> _ods_ir.Attribute:
+ // CHECK: def i32attr(self) -> _ods_ir.IntegerAttr:
// CHECK: return self.operation.attributes["i32attr"]
// CHECK: @builtins.property
- // CHECK: def optionalF32Attr(self) -> _Optional[_ods_ir.Attribute]:
+ // CHECK: def optionalF32Attr(self) -> _Optional[_ods_ir.FloatAttr]:
// CHECK: if "optionalF32Attr" not in self.operation.attributes:
// CHECK: return None
// CHECK: return self.operation.attributes["optionalF32Attr"]
@@ -153,7 +153,7 @@ def AttributedOp : TestOp<"attributed_op"> {
// CHECK: return "unitAttr" in self.operation.attributes
// CHECK: @builtins.property
- // CHECK: def in_(self) -> _ods_ir.Attribute:
+ // CHECK: def in_(self) -> _ods_ir.IntegerAttr:
// CHECK: return self.operation.attributes["in"]
let arguments = (ins I32Attr:$i32attr, OptionalAttr<F32Attr>:$optionalF32Attr,
@@ -190,7 +190,7 @@ def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> {
// CHECK: return "in" in self.operation.attributes
// CHECK: @builtins.property
- // CHECK: def is_(self) -> _Optional[_ods_ir.Attribute]:
+ // CHECK: def is_(self) -> _Optional[_ods_ir.FloatAttr]:
// CHECK: if "is" not in self.operation.attributes:
// CHECK: return None
// CHECK: return self.operation.attributes["is"]
diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py
index db822c641bf47..9686cc61bb11c 100644
--- a/mlir/test/python/dialects/python_test.py
+++ b/mlir/test/python/dialects/python_test.py
@@ -2,6 +2,7 @@
# RUN: %PYTHON %s nanobind | FileCheck %s
import inspect
import sys
+import typing
from typing import Union
from mlir.ir import *
@@ -233,6 +234,161 @@ def attrBuilder():
op.verify()
op.print(use_local_scope=True)
+ # fmt: off
+ assert typing.get_type_hints(test.AttributesOp.x_affinemaparr.fset)["value"] is ArrayAttr
+ assert type(op.x_affinemaparr) is typing.get_type_hints(test.AttributesOp.x_affinemaparr.fget)["return"]
+
+ assert typing.get_type_hints(test.AttributesOp.x_affinemap.fset)["value"] is AffineMapAttr
+ assert type(op.x_affinemap) is typing.get_type_hints(test.AttributesOp.x_affinemap.fget)["return"]
+
+ assert typing.get_type_hints(test.AttributesOp.x_arr.fset)["value"] is ArrayAttr
+ assert type(op.x_arr) is typing.get_type_hints(test.AttributesOp.x_arr.fget)["return"]
+
+ assert typing.get_type_hints(test.AttributesOp.x_boolarr.fset)["value"] is ArrayAttr
+ assert type(op.x_boolarr) is typing.get_type_hints(test.AttributesOp.x_boolarr.fget)["return"]
+
+ assert typing.get_type_hints(test.AttributesOp.x_bool.fset)["value"] is BoolAttr
+ assert type(op.x_bool) is typing.get_type_hints(test.AttributesOp.x_bool.fget)["return"]
+
+ assert typing.get_type_hints(test.AttributesOp.x_dboolarr.fset)["value"] is DenseBoolArrayAttr
+ assert type(op.x_dboolarr) is typing.get_type_hints(test.AttributesOp.x_dboolarr.fget)["return"]
+
+ assert typing.get_type_hints(test.AttributesOp.x_df32arr.fset)["value"] is DenseF32ArrayAttr
+ assert type(op.x_df32arr) is typing.get_type_hints(test.AttributesOp.x_df32arr.fget)["return"]
+
+ assert typing.get_type_hints(test.AttributesOp.x_df64arr.fset)["value"] is DenseF64ArrayAttr
+ assert type(op.x_df64arr) is typing.get_type_hints(test.AttributesOp.x_df64arr.fget)["return"]
+
+ assert typing.get_type_hints(test.AttributesOp.x_df16arr.fset)["value"] is DenseI16ArrayAttr
+ assert type(op.x_df16arr) is typing.get_type_hints(test.AttributesOp.x_df16arr.fget)["return"]
+
+ assert typing.get_type_hints(test.AttributesOp.x_di32arr.fset)["value"] is DenseI32ArrayAttr
+ assert type(op.x_di32arr) is typing.get_type_hints(test.AttributesOp.x_di32arr.fget)["return"]
+
+ assert typing.get_type_hints(test.AttributesOp.x_di64arr.fset)["value"] is DenseI64ArrayAttr
+ assert type(op.x_di64arr) is typing.get_type_hints(test.AttributesOp.x_di64arr.fget)["return"]
+
+ assert typing.get_type_hints(test.AttributesOp.x_di8arr.fset)["value"] is DenseI8ArrayAttr
+ assert type(op.x_di8arr) is typing.get_type_hints(test.AttributesOp.x_di8arr.fget)["return"]
+
+ assert typing.get_type_hints(test.AttributesOp.x_dictarr.fset)["value"] is ArrayAttr
+ assert type(op.x_dictarr) is typing.get_type_hints(test.AttributesOp.x_dictarr.fget)["return"]
+
+ assert typing.get_type_hints(test.AttributesOp.x_dict.fset)["value"] is DictAttr
+ assert type(op.x_dict) is typing.get_type_hints(test.AttributesOp.x_dict.fget)["return"]
+
+ assert typing.get_type_hints(test.AttributesOp.x_f32arr.fset)["value"] is ArrayAttr
+ assert type(op.x_f32arr) is typing.get_type_hints(test.AttributesOp.x_f32arr.fget)["return"]
+
+ assert typing.get_type_hints(test.AttributesOp.x_f32.fset)["value"] is FloatAttr
+ assert type(op.x_f32) is typing.get_type_hints(test.AttributesOp.x_f32.fget)["return"]
+
+ assert typing.get_type_hints(test.AttributesOp.x_f64arr.fset)["value"] is ArrayAttr
+ assert type(op.x_f64arr) is typing.get_type_hints(test.AttributesOp.x_f64arr.fget)["return"]
+
+ assert typing.get_type_hints(test.AttributesOp.x_f64.fset)["value"] is FloatAttr
+ assert type(op.x_f64) is typing.get_type_hints(test.AttributesOp.x_f64.fget)["return"]
+
+ assert typing.get_type_hints(test.AttributesOp.x_f64elems.fset)["value"] is DenseFPElementsAttr
+ assert type(op.x_f64elems) is typing.get_type_hints(test.AttributesOp.x_f64elems.fget)["return"]
+
+ assert typing.get_type_hints(test.AttributesOp.x_flatsymrefarr.fset)["value"] is ArrayAttr
+ assert type(op.x_flatsymrefarr) is typing.get_type_hints(test.AttributesOp.x_flatsymrefarr.fget)["return"]
+
+ assert typing.get_type_hints(test.AttributesOp.x_flatsymref.fset)["value"] is FlatSymbolRefAttr
+ assert type(op.x_flatsymref) is typing.get_type_hints(test.AttributesOp.x_flatsymref.fget)["return"]
+
+ assert typing.get_type_hints(test.AttributesOp.x_i16.fset)["value"] is IntegerAttr
+ assert type(op.x_i16) is typing.get_type_hints(test.AttributesOp.x_i16.fget)["return"]
+
+ assert typing.get_type_hints(test.AttributesOp.x_i1.fset)["value"] is BoolAttr
+ assert type(op.x_i1) is typing.get_type_hints(test.AttributesOp.x_i1.fget)["return"]
+
+ assert typing.get_type_hints(test.AttributesOp.x_i32arr.fset)["value"] is ArrayAttr
+ assert type(op.x_i32arr) is typing.get_type_hints(test.AttributesOp.x_i32arr.fget)["return"]
+
+ assert typing.get_type_hints(test.AttributesOp.x_i32.fset)["value"] is IntegerAttr
+ assert type(op.x_i32) is typing.get_type_hints(test.AttributesOp.x_i32.fget)["return"]
+
+ assert typing.get_type_hints(test.AttributesOp.x_i32elems.fset)["value"] is DenseIntElementsAttr
+ assert type(op.x_i32elems) is typing.get_type_hints(test.AttributesOp.x_i32elems.fget)["return"]
+
+ assert typing.get_type_hints(test.AttributesOp.x_i64arr.fset)["value"] is ArrayAttr
+ assert type(op.x_i64arr) is typing.get_type_hints(test.AttributesOp.x_i64arr.fget)["return"]
+
+ assert typing.get_type_hints(test.AttributesOp.x_i64.fset)["value"] is IntegerAttr
+ assert type(op.x_i64) is typing.get_type_hints(test.AttributesOp.x_i64.fget)["return"]
+
+ assert typing.get_type_hints(test.AttributesOp.x_i64elems.fset)["value"] is DenseIntElementsAttr
+ assert type(op.x_i64elems) is typing.get_type_hints(test.AttributesOp.x_i64elems.fget)["return"]
+
+ assert typing.get_type_hints(test.AttributesOp.x_i64svecarr.fset)["value"] is ArrayAttr
+ assert type(op.x_i64svecarr) is typing.get_type_hints(test.AttributesOp.x_i64svecarr.fget)["return"]
+
+ assert typing.get_type_hints(test.AttributesOp.x_i8.fset)["value"] is IntegerAttr
+ assert type(op.x_i8) is typing.get_type_hints(test.AttributesOp.x_i8.fget)["return"]
+
+ assert typing.get_type_hints(test.AttributesOp.x_idx.fset)["value"] is IntegerAttr
+ assert type(op.x_idx) is typing.get_type_hints(test.AttributesOp.x_idx.fget)["return"]
+
+ assert typing.get_type_hints(test.AttributesOp.x_idxelems.fset)["value"] is DenseIntElementsAttr
+ assert type(op.x_idxelems) is typing.get_type_hints(test.AttributesOp.x_idxelems.fget)["return"]
+
+ assert typing.get_type_hints(test.AttributesOp.x_idxlistarr.fset)["value"] is ArrayAttr
+ assert type(op.x_idxlistarr) is typing.get_type_hints(test.AttributesOp.x_idxlistarr.fget)["return"]
+
+ assert typing.get_type_hints(test.AttributesOp.x_si16.fset)["value"] is IntegerAttr
+ assert type(op.x_si16) is typing.get_type_hints(test.AttributesOp.x_si16.fget)["return"]
+
+ assert typing.get_type_hints(test.AttributesOp.x_si1.fset)["value"] is IntegerAttr
+ assert type(op.x_si1) is typing.get_type_hints(test.AttributesOp.x_si1.fget)["return"]
+
+ assert typing.get_type_hints(test.AttributesOp.x_si32.fset)["value"] is IntegerAttr
+ assert type(op.x_si32) is typing.get_type_hints(test.AttributesOp.x_si32.fget)["return"]
+
+ assert typing.get_type_hints(test.AttributesOp.x_si64.fset)["value"] is IntegerAttr
+ assert type(op.x_si64) is typing.get_type_hints(test.AttributesOp.x_si64.fget)["return"]
+
+ assert typing.get_type_hints(test.AttributesOp.x_si8.fset)["value"] is IntegerAttr
+ assert type(op.x_si8) is typing.get_type_hints(test.AttributesOp.x_si8.fget)["return"]
+
+ assert typing.get_type_hints(test.AttributesOp.x_strarr.fset)["value"] is ArrayAttr
+ assert type(op.x_strarr) is typing.get_type_hints(test.AttributesOp.x_strarr.fget)["return"]
+
+ assert typing.get_type_hints(test.AttributesOp.x_str.fset)["value"] is StringAttr
+ assert type(op.x_str) is typing.get_type_hints(test.AttributesOp.x_str.fget)["return"]
+
+ assert typing.get_type_hints(test.AttributesOp.x_sym.fset)["value"] is StringAttr
+ assert type(op.x_sym) is typing.get_type_hints(test.AttributesOp.x_sym.fget)["return"]
+
+ assert typing.get_type_hints(test.AttributesOp.x_symrefarr.fset)["value"] is ArrayAttr
+ assert type(op.x_symrefarr) is typing.get_type_hints(test.AttributesOp.x_symrefarr.fget)["return"]
+
+ assert typing.get_type_hints(test.AttributesOp.x_symref.fset)["value"] is SymbolRefAttr
+ assert type(op.x_symref) is typing.get_type_hints(test.AttributesOp.x_symref.fget)["return"]
+
+ assert typing.get_type_hints(test.AttributesOp.x_typearr.fset)["value"] is ArrayAttr
+ assert type(op.x_typearr) is typing.get_type_hints(test.AttributesOp.x_typearr.fget)["return"]
+
+ assert typing.get_type_hints(test.AttributesOp.x_type.fset)["value"] is TypeAttr
+ assert type(op.x_type) is typing.get_type_hints(test.AttributesOp.x_type.fget)["return"]
+
+ assert typing.get_type_hints(test.AttributesOp.x_ui16.fset)["value"] is IntegerAttr
+ assert type(op.x_ui16) is typing.get_type_hints(test.AttributesOp.x_ui16.fget)["return"]
+
+ assert typing.get_type_hints(test.AttributesOp.x_ui1.fset)["value"] is IntegerAttr
+ assert type(op.x_ui1) is typing.get_type_hints(test.AttributesOp.x_ui1.fget)["return"]
+
+ assert typing.get_type_hints(test.AttributesOp.x_ui32.fset)["value"] is IntegerAttr
+ assert type(op.x_ui32) is typing.get_type_hints(test.AttributesOp.x_ui32.fget)["return"]
+
+ assert typing.get_type_hints(test.AttributesOp.x_ui64.fset)["value"] is IntegerAttr
+ assert type(op.x_ui64) is typing.get_type_hints(test.AttributesOp.x_ui64.fget)["return"]
+
+ assert typing.get_type_hints(test.AttributesOp.x_ui8.fset)["value"] is IntegerAttr
+ assert type(op.x_ui8) is typing.get_type_hints(test.AttributesOp.x_ui8.fget)["return"]
+ # fmt: on
+
# CHECK-LABEL: TEST: inferReturnTypes
@run
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 2e33581fbbff8..f73324afd200c 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -194,18 +194,20 @@ constexpr const char *opVariadicSegmentOptionalTrailingTemplate =
/// Template for an operation attribute getter:
/// {0} is the name of the attribute sanitized for Python;
/// {1} is the original name of the attribute.
+/// {2} is the Python type hint for the getter's return value, e.g.
+/// `_ods_ir.StringAttr` (see getPythonAttrName); it falls back to
+/// `_ods_ir.Attribute` for storage types without a dedicated Python class.
constexpr const char *attributeGetterTemplate = R"Py(
@builtins.property
- def {0}(self) -> _ods_ir.Attribute:
+ def {0}(self) -> {2}:
return self.operation.attributes["{1}"]
)Py";
/// Template for an optional operation attribute getter:
/// {0} is the name of the attribute sanitized for Python;
/// {1} is the original name of the attribute.
+/// {2} is the type hint.
constexpr const char *optionalAttributeGetterTemplate = R"Py(
@builtins.property
- def {0}(self) -> _Optional[_ods_ir.Attribute]:
+ def {0}(self) -> _Optional[{2}]:
if "{1}" not in self.operation.attributes:
return None
return self.operation.attributes["{1}"]
@@ -225,9 +227,10 @@ constexpr const char *unitAttributeGetterTemplate = R"Py(
/// Template for an operation attribute setter:
/// {0} is the name of the attribute sanitized for Python;
/// {1} is the original name of the attribute.
+/// {2} is the type hint.
constexpr const char *attributeSetterTemplate = R"Py(
@{0}.setter
- def {0}(self, value):
+ def {0}(self, value: {2}):
if value is None:
raise ValueError("'None' not allowed as value for mandatory attributes")
self.operation.attributes["{1}"] = value
@@ -237,9 +240,10 @@ constexpr const char *attributeSetterTemplate = R"Py(
/// removes the attribute:
/// {0} is the name of the attribute sanitized for Python;
/// {1} is the original name of the attribute.
+/// {2} is the type hint.
constexpr const char *optionalAttributeSetterTemplate = R"Py(
@{0}.setter
- def {0}(self, value):
+ def {0}(self, value: _Optional[{2}]):
if value is not None:
self.operation.attributes["{1}"] = value
elif "{1}" in self.operation.attributes:
@@ -482,6 +486,72 @@ static void emitResultAccessors(const Operator &op, raw_ostream &os) {
getNumResults(op), getResult);
}
+/// Returns the name of the Python binding class (looked up in `_ods_ir`)
+/// used as the type hint for the getter/setter of the given ODS attribute.
+///
+/// The mapping is keyed on the attribute's C++ storage type; most storage
+/// types map to the Python class of the same name. Exceptions:
+///   - `::mlir::DenseElementsAttr` constraints deriving from
+///     `FloatElementsAttr` or `RankedFloatElementsAttr` map to
+///     `DenseFPElementsAttr`;
+///   - `::mlir::IntegerAttr` constrained as `I1Attr` maps to `BoolAttr`;
+///   - `::mlir::DictionaryAttr` maps to `DictAttr`.
+/// Storage types without a dedicated Python class fall back to `Attribute`.
+static std::string getPythonAttrName(const mlir::tblgen::Attribute &attr) {
+  llvm::StringRef storageTypeStr = attr.getStorageType();
+  if (storageTypeStr == "::mlir::AffineMapAttr")
+    return "AffineMapAttr";
+  if (storageTypeStr == "::mlir::ArrayAttr")
+    return "ArrayAttr";
+  if (storageTypeStr == "::mlir::BoolAttr")
+    return "BoolAttr";
+  if (storageTypeStr == "::mlir::DenseBoolArrayAttr")
+    return "DenseBoolArrayAttr";
+  if (storageTypeStr == "::mlir::DenseElementsAttr") {
+    // Float-element constraints share the generic DenseElementsAttr storage
+    // type, so consult the constraint's superclasses for the precise class.
+    // Record::isSubClassOf queries them by name directly, so no temporary
+    // set of superclass names has to be built on every call.
+    const llvm::Record &def = attr.getDef();
+    if (def.isSubClassOf("FloatElementsAttr") ||
+        def.isSubClassOf("RankedFloatElementsAttr"))
+      return "DenseFPElementsAttr";
+    return "DenseElementsAttr";
+  }
+  if (storageTypeStr == "::mlir::DenseF32ArrayAttr")
+    return "DenseF32ArrayAttr";
+  if (storageTypeStr == "::mlir::DenseF64ArrayAttr")
+    return "DenseF64ArrayAttr";
+  if (storageTypeStr == "::mlir::DenseFPElementsAttr")
+    return "DenseFPElementsAttr";
+  if (storageTypeStr == "::mlir::DenseI16ArrayAttr")
+    return "DenseI16ArrayAttr";
+  if (storageTypeStr == "::mlir::DenseI32ArrayAttr")
+    return "DenseI32ArrayAttr";
+  if (storageTypeStr == "::mlir::DenseI64ArrayAttr")
+    return "DenseI64ArrayAttr";
+  if (storageTypeStr == "::mlir::DenseI8ArrayAttr")
+    return "DenseI8ArrayAttr";
+  if (storageTypeStr == "::mlir::DenseIntElementsAttr")
+    return "DenseIntElementsAttr";
+  if (storageTypeStr == "::mlir::DenseResourceElementsAttr")
+    return "DenseResourceElementsAttr";
+  // The Python class for DictionaryAttr is named DictAttr.
+  if (storageTypeStr == "::mlir::DictionaryAttr")
+    return "DictAttr";
+  if (storageTypeStr == "::mlir::FlatSymbolRefAttr")
+    return "FlatSymbolRefAttr";
+  if (storageTypeStr == "::mlir::FloatAttr")
+    return "FloatAttr";
+  if (storageTypeStr == "::mlir::IntegerAttr") {
+    // I1Attr is stored as an IntegerAttr but is hinted as BoolAttr.
+    if (attr.getAttrDefName() == "I1Attr")
+      return "BoolAttr";
+    return "IntegerAttr";
+  }
+  if (storageTypeStr == "::mlir::IntegerSetAttr")
+    return "IntegerSetAttr";
+  if (storageTypeStr == "::mlir::OpaqueAttr")
+    return "OpaqueAttr";
+  if (storageTypeStr == "::mlir::StridedLayoutAttr")
+    return "StridedLayoutAttr";
+  if (storageTypeStr == "::mlir::StringAttr")
+    return "StringAttr";
+  if (storageTypeStr == "::mlir::SymbolRefAttr")
+    return "SymbolRefAttr";
+  if (storageTypeStr == "::mlir::TypeAttr")
+    return "TypeAttr";
+  if (storageTypeStr == "::mlir::UnitAttr")
+    return "UnitAttr";
+  return "Attribute";
+}
+
/// Emits accessors to Op attributes.
static void emitAttributeAccessors(const Operator &op, raw_ostream &os) {
for (const auto &namedAttr : op.getAttributes()) {
@@ -503,15 +573,18 @@ static void emitAttributeAccessors(const Operator &op, raw_ostream &os) {
continue;
}
+ std::string type = "_ods_ir." + getPythonAttrName(namedAttr.attr);
if (namedAttr.attr.isOptional()) {
os << formatv(optionalAttributeGetterTemplate, sanitizedName,
- namedAttr.name);
+ namedAttr.name, type);
os << formatv(optionalAttributeSetterTemplate, sanitizedName,
- namedAttr.name);
+ namedAttr.name, type);
os << formatv(attributeDeleterTemplate, sanitizedName, namedAttr.name);
} else {
- os << formatv(attributeGetterTemplate, sanitizedName, namedAttr.name);
- os << formatv(attributeSetterTemplate, sanitizedName, namedAttr.name);
+ os << formatv(attributeGetterTemplate, sanitizedName, namedAttr.name,
+ type);
+ os << formatv(attributeSetterTemplate, sanitizedName, namedAttr.name,
+ type);
// Non-optional attributes cannot be deleted.
}
}
>From 69abc3f34eeb2b9368f9eedfe201b30ba2adbc76 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Tue, 16 Sep 2025 09:35:04 -0700
Subject: [PATCH 3/3] use `type(x) is` checks instead of isinstance
---
mlir/test/python/dialects/python_test.py | 94 +++++++++++++++---------
1 file changed, 60 insertions(+), 34 deletions(-)
diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py
index 9686cc61bb11c..5c7d3234c6df2 100644
--- a/mlir/test/python/dialects/python_test.py
+++ b/mlir/test/python/dialects/python_test.py
@@ -1,9 +1,8 @@
# RUN: %PYTHON %s pybind11 | FileCheck %s
# RUN: %PYTHON %s nanobind | FileCheck %s
-import inspect
import sys
import typing
-from typing import Union
+from typing import Union, Optional
from mlir.ir import *
import mlir.dialects.func as func
@@ -449,6 +448,13 @@ def resultTypesDefinedByTraits():
# CHECK-COUNT-2: i32
print(same.one.type)
print(same.two.type)
+ assert (
+ typing.get_type_hints(test.SameOperandAndResultTypeOp.one.fget)[
+ "return"
+ ]
+ is OpResult
+ )
+ assert type(same.one) is OpResult
first_type_attr = test.FirstAttrDeriveTypeAttrOp(
inferred.results[1], TypeAttr.get(IndexType.get())
@@ -491,6 +497,15 @@ def testOptionalOperandOp():
op1 = test.OptionalOperandOp()
# CHECK: op1.input is None: True
print(f"op1.input is None: {op1.input is None}")
+ assert (
+ typing.get_type_hints(test.OptionalOperandOp.input.fget)["return"]
+ is Optional[Value]
+ )
+ assert (
+ typing.get_type_hints(test.OptionalOperandOp.result.fget)["return"]
+ is OpResult
+ )
+ assert type(op1.result) is OpResult
op2 = test.OptionalOperandOp(input=op1)
# CHECK: op2.input is None: False
@@ -754,14 +769,12 @@ def testInferTypeOpInterface():
print(two_operands.result.type)
assert (
- inspect.signature(
- test.infer_results_variadic_inputs_op
- ).return_annotation
+ typing.get_type_hints(test.infer_results_variadic_inputs_op)["return"]
is OpResult
)
- assert isinstance(
- test.infer_results_variadic_inputs_op(single=zero, doubled=zero),
- OpResult,
+ assert (
+ type(test.infer_results_variadic_inputs_op(single=zero, doubled=zero))
+ is OpResult
)
@@ -785,25 +798,36 @@ def values(lst):
[zero, one], two, [three, four]
)
# CHECK: Value(%{{.*}} = arith.constant 2 : i32)
- non_variadic = variadic_operands.non_variadic
- print(non_variadic)
- assert isinstance(non_variadic, Value)
+ print(variadic_operands.non_variadic)
+ assert (
+ typing.get_type_hints(test.SameVariadicOperandSizeOp.non_variadic.fget)[
+ "return"
+ ]
+ is Value
+ )
+ assert type(variadic_operands.non_variadic) is Value
+
# CHECK: ['Value(%{{.*}} = arith.constant 0 : i32)', 'Value(%{{.*}} = arith.constant 1 : i32)']
- variadic1 = variadic_operands.variadic1
- print(values(variadic1))
- assert isinstance(variadic1, OpOperandList)
+ print(values(variadic_operands.variadic1))
+ assert (
+ typing.get_type_hints(test.SameVariadicOperandSizeOp.variadic1.fget)[
+ "return"
+ ]
+ is OpOperandList
+ )
+ assert type(variadic_operands.variadic1) is OpOperandList
+
# CHECK: ['Value(%{{.*}} = arith.constant 3 : i32)', 'Value(%{{.*}} = arith.constant 4 : i32)']
- variadic2 = variadic_operands.variadic2
- print(values(variadic2))
- assert isinstance(variadic2, OpOperandList)
+ print(values(variadic_operands.variadic2))
+ assert type(variadic_operands.variadic2) is OpOperandList
assert (
- inspect.signature(test.same_variadic_operand).return_annotation
+ typing.get_type_hints(test.same_variadic_operand)["return"]
is test.SameVariadicOperandSizeOp
)
- assert isinstance(
- test.same_variadic_operand([zero, one], two, [three, four]),
- test.SameVariadicOperandSizeOp,
+ assert (
+ type(test.same_variadic_operand([zero, one], two, [three, four]))
+ is test.SameVariadicOperandSizeOp
)
@@ -828,12 +852,12 @@ def types(lst):
print(types(op.variadic2))
assert (
- inspect.signature(test.same_variadic_result_vfv).return_annotation
+ typing.get_type_hints(test.same_variadic_result_vfv)["return"]
is Union[OpResult, OpResultList, test.SameVariadicResultSizeOpVFV]
)
- assert isinstance(
- test.same_variadic_result_vfv([i[0], i[1]], i[2], [i[3], i[4]]),
- OpResultList,
+ assert (
+ type(test.same_variadic_result_vfv([i[0], i[1]], i[2], [i[3], i[4]]))
+ is OpResultList
)
# Test Variadic-Variadic-Variadic
@@ -854,9 +878,8 @@ def types(lst):
# CHECK: i1
print(op.non_variadic2.type)
# CHECK: [IntegerType(i2), IntegerType(i3), IntegerType(i4)]
- variadic = op.variadic
- print(types(variadic))
- assert isinstance(variadic, OpResultList)
+ print(types(op.variadic))
+ assert type(op.variadic) is OpResultList
# Test Variadic-Variadic-Fixed
op = test.SameVariadicResultSizeOpVVF(
@@ -911,13 +934,16 @@ def types(lst):
print(op.non_variadic3.type)
assert (
- inspect.signature(test.results_variadic).return_annotation
+ typing.get_type_hints(test.results_variadic)["return"]
is Union[OpResult, OpResultList, test.ResultsVariadicOp]
)
- assert isinstance(
- test.results_variadic([i[0]]),
- OpResult,
+ assert type(test.results_variadic([i[0]])) is OpResult
+ op_res_variadic = test.ResultsVariadicOp([i[0]])
+ assert (
+ typing.get_type_hints(test.ResultsVariadicOp.res.fget)["return"]
+ is OpResultList
)
+ assert type(op_res_variadic.res) is OpResultList
# CHECK-LABEL: TEST: testVariadicAndNormalRegion
@@ -927,5 +953,5 @@ def testVariadicAndNormalRegionOp():
module = Module.create()
with InsertionPoint(module.body):
region_op = test.VariadicAndNormalRegionOp(2)
- assert isinstance(region_op.region, Region)
- assert isinstance(region_op.variadic, RegionSequence)
+ assert type(region_op.region) is Region
+ assert type(region_op.variadic) is RegionSequence
More information about the Mlir-commits
mailing list