[Mlir-commits] [mlir] [MLIR] [Python] The generated op definitions now use typed parameters (PR #188635)
Sergei Lebedev
llvmlistbot at llvm.org
Thu Mar 26 10:57:44 PDT 2026
https://github.com/superbobry updated https://github.com/llvm/llvm-project/pull/188635
>From 893131e2af51ab0a9880047d7a13b129a2e024ad Mon Sep 17 00:00:00 2001
From: Sergei Lebedev <slebedev at google.com>
Date: Wed, 25 Mar 2026 19:55:50 +0000
Subject: [PATCH 1/2] [MLIR] [Python] The generated op definitions now use
typed parameters
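As with operand/result types, the attribute mapping only covers attributes
from the standard (upstream) dialects; any other attribute falls back to `_Any`
for its raw value. I think this is still useful as is, and we could make the
mapping extensible if/when necessary.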
---
mlir/test/mlir-tblgen/op-python-bindings.td | 95 +++++----
mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp | 193 +++++++++++++-----
2 files changed, 191 insertions(+), 97 deletions(-)
diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td
index 141cf430f36ef..23d4194344a66 100644
--- a/mlir/test/mlir-tblgen/op-python-bindings.td
+++ b/mlir/test/mlir-tblgen/op-python-bindings.td
@@ -21,7 +21,7 @@ class TestOp<string mnemonic, list<Trait> traits = []> :
// CHECK: _ODS_OPERAND_SEGMENTS = [-1,1,0,]
def AttrSizedOperandsOp : TestOp<"attr_sized_operands",
[AttrSizedOperandSegments]> {
- // CHECK: def __init__(self, variadic1, non_variadic, *, variadic2=None, loc=None, ip=None):
+ // CHECK: def __init__(self, variadic1: _Sequence[_ods_ir.Value], non_variadic: _ods_ir.Value, *, variadic2: _Optional[_ods_ir.Value] = None, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None):
// CHECK: operands = []
// CHECK: attributes = {}
// CHECK: regions = None
@@ -60,7 +60,7 @@ def AttrSizedOperandsOp : TestOp<"attr_sized_operands",
Optional<AnyType>:$variadic2);
}
-// CHECK: def attr_sized_operands(variadic1, non_variadic, *, variadic2=None, loc=None, ip=None) -> AttrSizedOperandsOp:
+// CHECK: def attr_sized_operands(variadic1: _Sequence[_ods_ir.Value], non_variadic: _ods_ir.Value, *, variadic2: _Optional[_ods_ir.Value] = None, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None) -> AttrSizedOperandsOp:
// CHECK: return AttrSizedOperandsOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip)
// CHECK: @_ods_cext.register_operation(_Dialect)
@@ -69,7 +69,7 @@ def AttrSizedOperandsOp : TestOp<"attr_sized_operands",
// CHECK: _ODS_RESULT_SEGMENTS = [0,1,-1,]
def AttrSizedResultsOp : TestOp<"attr_sized_results",
[AttrSizedResultSegments]> {
- // CHECK: def __init__(self, variadic1, non_variadic, variadic2, *, loc=None, ip=None):
+ // CHECK: def __init__(self, variadic1: _Optional[_ods_ir.Type], non_variadic: _ods_ir.Type, variadic2: _Sequence[_ods_ir.Type], *, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None):
// CHECK: operands = []
// CHECK: attributes = {}
// CHECK: regions = None
@@ -108,7 +108,7 @@ def AttrSizedResultsOp : TestOp<"attr_sized_results",
Variadic<AnyType>:$variadic2);
}
-// CHECK: def attr_sized_results(variadic1, non_variadic, variadic2, *, loc=None, ip=None) -> _Union[_ods_ir.OpResult, _ods_ir.OpResultList, AttrSizedResultsOp]:
+// CHECK: def attr_sized_results(variadic1: _Optional[_ods_ir.Type], non_variadic: _ods_ir.Type, variadic2: _Sequence[_ods_ir.Type], *, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None) -> _Union[_ods_ir.OpResult, _ods_ir.OpResultList, AttrSizedResultsOp]:
// CHECK: op = AttrSizedResultsOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip); results = op.results
// CHECK: return results if len(results) > 1 else (results[0] if len(results) == 1 else op)
@@ -118,7 +118,7 @@ def AttrSizedResultsOp : TestOp<"attr_sized_results",
// CHECK-NOT: _ODS_OPERAND_SEGMENTS
// CHECK-NOT: _ODS_RESULT_SEGMENTS
def AttributedOp : TestOp<"attributed_op"> {
- // CHECK: def __init__(self, i32attr, in_, *, optionalF32Attr=None, unitAttr=None, loc=None, ip=None):
+ // CHECK: def __init__(self, i32attr: _Union[int, _ods_ir.IntegerAttr], in_: _Union[int, _ods_ir.IntegerAttr], *, optionalF32Attr: _Optional[_Union[float, _ods_ir.FloatAttr]] = None, unitAttr: _Optional[bool] = None, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None):
// CHECK: operands = []
// CHECK: attributes = {}
// CHECK: regions = None
@@ -176,7 +176,7 @@ def AttributedOp : TestOp<"attributed_op"> {
// CHECK: def in_(self) -> _ods_ir.IntegerAttr:
// CHECK: return self.attributes["in"]
-// CHECK: def attributed_op(i32attr, in_, *, optional_f32_attr=None, unit_attr=None, loc=None, ip=None) -> AttributedOp:
+// CHECK: def attributed_op(i32attr: _Union[int, _ods_ir.IntegerAttr], in_: _Union[int, _ods_ir.IntegerAttr], *, optional_f32_attr: _Optional[_Union[float, _ods_ir.FloatAttr]] = None, unit_attr: _Optional[bool] = None, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None) -> AttributedOp:
// CHECK: return AttributedOp(i32attr=i32attr, in_=in_, optionalF32Attr=optional_f32_attr, unitAttr=unit_attr, loc=loc, ip=ip)
// CHECK: @_ods_cext.register_operation(_Dialect)
@@ -185,7 +185,7 @@ def AttributedOp : TestOp<"attributed_op"> {
// CHECK-NOT: _ODS_OPERAND_SEGMENTS
// CHECK-NOT: _ODS_RESULT_SEGMENTS
def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> {
- // CHECK: def __init__(self, _gen_arg_0, _gen_arg_2, *, in_=None, is_=None, loc=None, ip=None):
+ // CHECK: def __init__(self, _gen_arg_0: _ods_ir.Value[_ods_ir.IntegerType], _gen_arg_2: _ods_ir.Value[_ods_ir.FloatType], *, in_: _Optional[bool] = None, is_: _Optional[_Union[float, _ods_ir.FloatAttr]] = None, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None):
// CHECK: operands = []
// CHECK: attributes = {}
// CHECK: regions = None
@@ -213,14 +213,14 @@ def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> {
let arguments = (ins I32, UnitAttr:$in, F32, OptionalAttr<F32Attr>:$is);
}
-// CHECK: def attributed_op_with_operands(_gen_arg_0, _gen_arg_2, *, in_=None, is_=None, loc=None, ip=None) -> AttributedOpWithOperands
+// CHECK: def attributed_op_with_operands(_gen_arg_0: _ods_ir.Value[_ods_ir.IntegerType], _gen_arg_2: _ods_ir.Value[_ods_ir.FloatType], *, in_: _Optional[bool] = None, is_: _Optional[_Union[float, _ods_ir.FloatAttr]] = None, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None) -> AttributedOpWithOperands:
// CHECK: return AttributedOpWithOperands(_gen_arg_0=_gen_arg_0, _gen_arg_2=_gen_arg_2, in_=in_, is_=is_, loc=loc, ip=ip)
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK-LABEL: class DefaultValuedAttrsOp(_ods_ir.OpView):
// CHECK: OPERATION_NAME = "test.default_valued_attrs"
def DefaultValuedAttrsOp : TestOp<"default_valued_attrs"> {
- // CHECK: def __init__(self, *, arr=None, unsupported=None, loc=None, ip=None):
+ // CHECK: def __init__(self, *, arr: _Optional[_Union[_Sequence[int], _ods_ir.ArrayAttr]] = None, unsupported: _Optional[_Union[_Sequence[int], _ods_ir.ArrayAttr]] = None, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None):
// CHECK: operands = []
// CHECK: attributes = {}
// CHECK: regions = None
@@ -238,12 +238,12 @@ def DefaultValuedAttrsOp : TestOp<"default_valued_attrs"> {
let results = (outs);
}
-// CHECK: def default_valued_attrs(*, arr=None, unsupported=None, loc=None, ip=None) -> DefaultValuedAttrsOp:
+// CHECK: def default_valued_attrs(*, arr: _Optional[_Union[_Sequence[int], _ods_ir.ArrayAttr]] = None, unsupported: _Optional[_Union[_Sequence[int], _ods_ir.ArrayAttr]] = None, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None) -> DefaultValuedAttrsOp:
// CHECK: return DefaultValuedAttrsOp(arr=arr, unsupported=unsupported, loc=loc, ip=ip)
// CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_op"
def DeriveResultTypesOp : TestOp<"derive_result_types_op", [FirstAttrDerivedResultType]> {
- // CHECK: def __init__(self, type_, *, results=None, loc=None, ip=None):
+ // CHECK: def __init__(self, type_: _Union[_ods_ir.Type, _ods_ir.TypeAttr], *, results: _Optional[_Sequence[_ods_ir.Type]] = None, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None):
// CHECK: operands = []
// CHECK: if results is None:
// CHECK: _ods_result_type_source_attr = attributes["type"]
@@ -256,17 +256,17 @@ def DeriveResultTypesOp : TestOp<"derive_result_types_op", [FirstAttrDerivedResu
let results = (outs AnyType:$res, AnyType);
}
-// CHECK: def derive_result_types_op(type_, *, results=None, loc=None, ip=None) -> _ods_ir.OpResultList:
+// CHECK: def derive_result_types_op(type_: _Union[_ods_ir.Type, _ods_ir.TypeAttr], *, results: _Optional[_Sequence[_ods_ir.Type]] = None, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None) -> _ods_ir.OpResultList:
// CHECK: return DeriveResultTypesOp(type_=type_, results=results, loc=loc, ip=ip).results
// CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_variadic_op"
def DeriveResultTypesVariadicOp : TestOp<"derive_result_types_variadic_op", [FirstAttrDerivedResultType]> {
- // CHECK: def __init__(self, res, _gen_res_1, type_, *, loc=None, ip=None):
+ // CHECK: def __init__(self, res: _ods_ir.Type, _gen_res_1: _Sequence[_ods_ir.Type], type_: _Union[_ods_ir.Type, _ods_ir.TypeAttr], *, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None):
let arguments = (ins TypeAttr:$type);
let results = (outs AnyType:$res, Variadic<AnyType>);
}
-// CHECK: def derive_result_types_variadic_op(res, _gen_res_1, type_, *, loc=None, ip=None) -> _Union[_ods_ir.OpResult, _ods_ir.OpResultList, DeriveResultTypesVariadicOp]:
+// CHECK: def derive_result_types_variadic_op(res: _ods_ir.Type, _gen_res_1: _Sequence[_ods_ir.Type], type_: _Union[_ods_ir.Type, _ods_ir.TypeAttr], *, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None) -> _Union[_ods_ir.OpResult, _ods_ir.OpResultList, DeriveResultTypesVariadicOp]:
// CHECK: op = DeriveResultTypesVariadicOp(res=res, _gen_res_1=_gen_res_1, type_=type_, loc=loc, ip=ip); results = op.results
// CHECK: return results if len(results) > 1 else (results[0] if len(results) == 1 else op)
@@ -302,7 +302,7 @@ def DescriptionOp : TestOp<"description"> {
// CHECK-LABEL: class EmptyOp(_ods_ir.OpView):
// CHECK: OPERATION_NAME = "test.empty"
def EmptyOp : TestOp<"empty">;
- // CHECK: def __init__(self, *, loc=None, ip=None):
+ // CHECK: def __init__(self, *, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None):
// CHECK: operands = []
// CHECK: attributes = {}
// CHECK: regions = None
@@ -313,12 +313,12 @@ def EmptyOp : TestOp<"empty">;
// CHECK: attributes=attributes, results=results, operands=operands,
// CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip)
-// CHECK: def empty(*, loc=None, ip=None) -> EmptyOp:
+// CHECK: def empty(*, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None) -> EmptyOp:
// CHECK: return EmptyOp(loc=loc, ip=ip)
// CHECK-LABEL: OPERATION_NAME = "test.infer_result_types_implied_op"
def InferResultTypesImpliedOp : TestOp<"infer_result_types_implied_op"> {
- // CHECK: def __init__(self, *, results=None, loc=None, ip=None):
+ // CHECK: def __init__(self, *, results: _Optional[_Sequence[_ods_ir.Type]] = None, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None):
// CHECK: _ods_context = _ods_get_default_loc_context(loc)
// CHECK: super().__init__(
// CHECK: self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS,
@@ -327,12 +327,12 @@ def InferResultTypesImpliedOp : TestOp<"infer_result_types_implied_op"> {
let results = (outs I32:$i32, F32:$f32);
}
-// CHECK: def infer_result_types_implied_op(*, results=None, loc=None, ip=None) -> _ods_ir.OpResultList:
+// CHECK: def infer_result_types_implied_op(*, results: _Optional[_Sequence[_ods_ir.Type]] = None, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None) -> _ods_ir.OpResultList:
// CHECK: return InferResultTypesImpliedOp(results=results, loc=loc, ip=ip).results
// CHECK-LABEL: OPERATION_NAME = "test.infer_result_types_op"
def InferResultTypesOp : TestOp<"infer_result_types_op", [InferTypeOpInterface]> {
- // CHECK: def __init__(self, *, results=None, loc=None, ip=None):
+ // CHECK: def __init__(self, *, results: _Optional[_Sequence[_ods_ir.Type]] = None, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None):
// CHECK: operands = []
// CHECK: super().__init__(
// CHECK: self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS,
@@ -341,14 +341,14 @@ def InferResultTypesOp : TestOp<"infer_result_types_op", [InferTypeOpInterface]>
let results = (outs AnyType, AnyType, AnyType);
}
-// CHECK: def infer_result_types_op(*, results=None, loc=None, ip=None) -> _ods_ir.OpResultList:
+// CHECK: def infer_result_types_op(*, results: _Optional[_Sequence[_ods_ir.Type]] = None, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None) -> _ods_ir.OpResultList:
// CHECK: return InferResultTypesOp(results=results, loc=loc, ip=ip).results
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK-LABEL: class MissingNamesOp(_ods_ir.OpView):
// CHECK: OPERATION_NAME = "test.missing_names"
def MissingNamesOp : TestOp<"missing_names"> {
- // CHECK: def __init__(self, i32, _gen_res_1, i64, _gen_arg_0, f32, _gen_arg_2, *, loc=None, ip=None):
+ // CHECK: def __init__(self, i32: _ods_ir.Type, _gen_res_1: _ods_ir.Type, i64: _ods_ir.Type, _gen_arg_0: _ods_ir.Value[_ods_ir.IntegerType], f32: _ods_ir.Value[_ods_ir.FloatType], _gen_arg_2: _ods_ir.Value[_ods_ir.IntegerType], *, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None):
// CHECK: operands = []
// CHECK: attributes = {}
// CHECK: regions = None
@@ -380,7 +380,7 @@ def MissingNamesOp : TestOp<"missing_names"> {
let results = (outs I32:$i32, AnyFloat, I64:$i64);
}
-// CHECK: def missing_names(i32, _gen_res_1, i64, _gen_arg_0, f32, _gen_arg_2, *, loc=None, ip=None) -> _ods_ir.OpResultList:
+// CHECK: def missing_names(i32: _ods_ir.Type, _gen_res_1: _ods_ir.Type, i64: _ods_ir.Type, _gen_arg_0: _ods_ir.Value[_ods_ir.IntegerType], f32: _ods_ir.Value[_ods_ir.FloatType], _gen_arg_2: _ods_ir.Value[_ods_ir.IntegerType], *, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None) -> _ods_ir.OpResultList:
// CHECK: return MissingNamesOp(i32=i32, _gen_res_1=_gen_res_1, i64=i64, _gen_arg_0=_gen_arg_0, f32=f32, _gen_arg_2=_gen_arg_2, loc=loc, ip=ip).results
// CHECK: @_ods_cext.register_operation(_Dialect)
@@ -390,7 +390,7 @@ def MissingNamesOp : TestOp<"missing_names"> {
// CHECK-NOT: _ODS_RESULT_SEGMENTS
def OneOptionalOperandOp : TestOp<"one_optional_operand"> {
let arguments = (ins AnyType:$non_optional, Optional<AnyType>:$optional);
- // CHECK: def __init__(self, non_optional, *, optional=None, loc=None, ip=None):
+ // CHECK: def __init__(self, non_optional: _ods_ir.Value, *, optional: _Optional[_ods_ir.Value] = None, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None):
// CHECK: operands = []
// CHECK: attributes = {}
// CHECK: regions = None
@@ -421,7 +421,7 @@ def OneOptionalOperandOp : TestOp<"one_optional_operand"> {
// CHECK: def optional(self) -> _Optional[_ods_ir.Value]:
// CHECK: return None if len(self.operands) < 2 else self.operands[1]
-// CHECK: def one_optional_operand(non_optional, *, optional=None, loc=None, ip=None) -> OneOptionalOperandOp:
+// CHECK: def one_optional_operand(non_optional: _ods_ir.Value, *, optional: _Optional[_ods_ir.Value] = None, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None) -> OneOptionalOperandOp:
// CHECK: return OneOptionalOperandOp(non_optional=non_optional, optional=optional, loc=loc, ip=ip)
// CHECK: @_ods_cext.register_operation(_Dialect)
@@ -430,7 +430,7 @@ def OneOptionalOperandOp : TestOp<"one_optional_operand"> {
// CHECK-NOT: _ODS_OPERAND_SEGMENTS
// CHECK-NOT: _ODS_RESULT_SEGMENTS
def OneVariadicOperandOp : TestOp<"one_variadic_operand"> {
- // CHECK: def __init__(self, non_variadic, variadic, *, loc=None, ip=None):
+ // CHECK: def __init__(self, non_variadic: _ods_ir.Value, variadic: _Sequence[_ods_ir.Value], *, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None):
// CHECK: operands = []
// CHECK: attributes = {}
// CHECK: regions = None
@@ -464,7 +464,7 @@ def OneVariadicOperandOp : TestOp<"one_variadic_operand"> {
// CHECK: _ods_variadic_group_length = len(self.operands) - 2 + 1
// CHECK: return self.operands[1:1 + _ods_variadic_group_length]
-// CHECK: def one_variadic_operand(non_variadic, variadic, *, loc=None, ip=None) -> OneVariadicOperandOp:
+// CHECK: def one_variadic_operand(non_variadic: _ods_ir.Value, variadic: _Sequence[_ods_ir.Value], *, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None) -> OneVariadicOperandOp:
// CHECK: return OneVariadicOperandOp(non_variadic=non_variadic, variadic=variadic, loc=loc, ip=ip)
// CHECK: @_ods_cext.register_operation(_Dialect)
@@ -473,7 +473,7 @@ def OneVariadicOperandOp : TestOp<"one_variadic_operand"> {
// CHECK-NOT: _ODS_OPERAND_SEGMENTS
// CHECK-NOT: _ODS_RESULT_SEGMENTS
def OneVariadicResultOp : TestOp<"one_variadic_result"> {
- // CHECK: def __init__(self, variadic, non_variadic, *, loc=None, ip=None):
+ // CHECK: def __init__(self, variadic: _Sequence[_ods_ir.Type], non_variadic: _ods_ir.Type, *, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None):
// CHECK: operands = []
// CHECK: attributes = {}
// CHECK: regions = None
@@ -498,7 +498,7 @@ def OneVariadicResultOp : TestOp<"one_variadic_result"> {
let results = (outs Variadic<AnyType>:$variadic, AnyType:$non_variadic);
}
-// CHECK: def one_variadic_result(variadic, non_variadic, *, loc=None, ip=None) -> _Union[_ods_ir.OpResult, _ods_ir.OpResultList, OneVariadicResultOp]:
+// CHECK: def one_variadic_result(variadic: _Sequence[_ods_ir.Type], non_variadic: _ods_ir.Type, *, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None) -> _Union[_ods_ir.OpResult, _ods_ir.OpResultList, OneVariadicResultOp]:
// CHECK: op = OneVariadicResultOp(variadic=variadic, non_variadic=non_variadic, loc=loc, ip=ip); results = op.results
// CHECK: return results if len(results) > 1 else (results[0] if len(results) == 1 else op)
@@ -506,7 +506,7 @@ def OneVariadicResultOp : TestOp<"one_variadic_result"> {
// CHECK-LABEL: class PythonKeywordOp(_ods_ir.OpView):
// CHECK: OPERATION_NAME = "test.python_keyword"
def PythonKeywordOp : TestOp<"python_keyword"> {
- // CHECK: def __init__(self, in_, *, loc=None, ip=None):
+ // CHECK: def __init__(self, in_: _ods_ir.Value, *, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None):
// CHECK: operands = []
// CHECK: attributes = {}
// CHECK: regions = None
@@ -524,12 +524,12 @@ def PythonKeywordOp : TestOp<"python_keyword"> {
let arguments = (ins AnyType:$in);
}
-// CHECK: def python_keyword(in_, *, loc=None, ip=None) -> PythonKeywordOp:
+// CHECK: def python_keyword(in_: _ods_ir.Value, *, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None) -> PythonKeywordOp:
// CHECK: return PythonKeywordOp(in_=in_, loc=loc, ip=ip)
// CHECK-LABEL: OPERATION_NAME = "test.same_results"
def SameResultsOp : TestOp<"same_results", [SameOperandsAndResultType]> {
- // CHECK: def __init__(self, in1, in2, *, results=None, loc=None, ip=None):
+ // CHECK: def __init__(self, in1: _ods_ir.Value, in2: _ods_ir.Value, *, results: _Optional[_Sequence[_ods_ir.Type]] = None, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None):
// CHECK: operands = []
// CHECK: operands.append
// CHECK: if results is None: results = [operands[0].type] * 1
@@ -537,17 +537,17 @@ def SameResultsOp : TestOp<"same_results", [SameOperandsAndResultType]> {
let results = (outs AnyType:$res);
}
-// CHECK: def same_results(in1, in2, *, results=None, loc=None, ip=None) -> _ods_ir.OpResult:
+// CHECK: def same_results(in1: _ods_ir.Value, in2: _ods_ir.Value, *, results: _Optional[_Sequence[_ods_ir.Type]] = None, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None) -> _ods_ir.OpResult:
// CHECK: return SameResultsOp(in1=in1, in2=in2, results=results, loc=loc, ip=ip).result
// CHECK-LABEL: OPERATION_NAME = "test.same_results_variadic"
def SameResultsVariadicOp : TestOp<"same_results_variadic", [SameOperandsAndResultType]> {
- // CHECK: def __init__(self, res, in1, in2, *, loc=None, ip=None):
+ // CHECK: def __init__(self, res: _Sequence[_ods_ir.Type], in1: _ods_ir.Value, in2: _ods_ir.Value, *, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None):
let arguments = (ins AnyType:$in1, AnyType:$in2);
let results = (outs Variadic<AnyType>:$res);
}
-// CHECK: def same_results_variadic(res, in1, in2, *, loc=None, ip=None) -> _Union[_ods_ir.OpResult, _ods_ir.OpResultList, SameResultsVariadicOp]:
+// CHECK: def same_results_variadic(res: _Sequence[_ods_ir.Type], in1: _ods_ir.Value, in2: _ods_ir.Value, *, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None) -> _Union[_ods_ir.OpResult, _ods_ir.OpResultList, SameResultsVariadicOp]:
// CHECK: op = SameResultsVariadicOp(res=res, in1=in1, in2=in2, loc=loc, ip=ip); results = op.results
// CHECK: return results if len(results) > 1 else (results[0] if len(results) == 1 else op)
@@ -575,7 +575,7 @@ def SameVariadicOperandSizeOp : TestOp<"same_variadic_operand",
Variadic<AnyType>:$variadic2);
}
-// CHECK: def same_variadic_operand(variadic1, non_variadic, variadic2, *, loc=None, ip=None) -> SameVariadicOperandSizeOp:
+// CHECK: def same_variadic_operand(variadic1: _Sequence[_ods_ir.Value], non_variadic: _ods_ir.Value, variadic2: _Sequence[_ods_ir.Value], *, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None) -> SameVariadicOperandSizeOp:
// CHECK: return SameVariadicOperandSizeOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip)
// CHECK: @_ods_cext.register_operation(_Dialect)
@@ -601,7 +601,7 @@ def SameVariadicResultSizeOp : TestOp<"same_variadic_result",
Variadic<AnyType>:$variadic2);
}
-// CHECK: def same_variadic_result(variadic1, non_variadic, variadic2, *, loc=None, ip=None) -> _Union[_ods_ir.OpResult, _ods_ir.OpResultList, SameVariadicResultSizeOp]:
+// CHECK: def same_variadic_result(variadic1: _Sequence[_ods_ir.Type], non_variadic: _ods_ir.Type, variadic2: _Sequence[_ods_ir.Type], *, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None) -> _Union[_ods_ir.OpResult, _ods_ir.OpResultList, SameVariadicResultSizeOp]:
// CHECK: op = SameVariadicResultSizeOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip); results = op.results
// CHECK: return results if len(results) > 1 else (results[0] if len(results) == 1 else op)
@@ -609,7 +609,7 @@ def SameVariadicResultSizeOp : TestOp<"same_variadic_result",
// CHECK-LABEL: class SimpleOp(_ods_ir.OpView):
// CHECK: OPERATION_NAME = "test.simple"
def SimpleOp : TestOp<"simple"> {
- // CHECK: def __init__(self, i64, f64, i32, f32, *, loc=None, ip=None):
+ // CHECK: def __init__(self, i64: _ods_ir.Type, f64: _ods_ir.Type, i32: _ods_ir.Value[_ods_ir.IntegerType], f32: _ods_ir.Value[_ods_ir.FloatType], *, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None):
// CHECK: operands = []
// CHECK: attributes = {}
// CHECK: regions = None
@@ -643,7 +643,7 @@ def SimpleOp : TestOp<"simple"> {
let results = (outs I64:$i64, AnyFloat:$f64);
}
-// CHECK: def simple(i64, f64, i32, f32, *, loc=None, ip=None) -> _ods_ir.OpResultList:
+// CHECK: def simple(i64: _ods_ir.Type, f64: _ods_ir.Type, i32: _ods_ir.Value[_ods_ir.IntegerType], f32: _ods_ir.Value[_ods_ir.FloatType], *, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None) -> _ods_ir.OpResultList:
// CHECK: return SimpleOp(i64=i64, f64=f64, i32=i32, f32=f32, loc=loc, ip=ip).results
// CHECK: @_ods_cext.register_operation(_Dialect)
@@ -657,14 +657,13 @@ def SingleTypedResultOp : TestOp<"single_typed_result"> {
let results = (outs I64:$i64);
}
-// CHECK: def single_typed_result(in_, *, results=None, loc=None, ip=None) ->
-// _ods_ir.OpResult[_ods_ir.IntegerType]: CHECK: return
-// SingleTypedResultOp(in_=in_, results=results, loc=loc, ip=ip).result
+// CHECK: def single_typed_result(in_: _ods_ir.Value, *, results: _Optional[_Sequence[_ods_ir.Type]] = None, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None) -> _ods_ir.OpResult[_ods_ir.IntegerType]:
+// CHECK: return SingleTypedResultOp(in_=in_, results=results, loc=loc, ip=ip).result
// CHECK-LABEL: class VariadicAndNormalRegionOp(_ods_ir.OpView):
// CHECK: OPERATION_NAME = "test.variadic_and_normal_region"
def VariadicAndNormalRegionOp : TestOp<"variadic_and_normal_region"> {
- // CHECK: def __init__(self, num_variadic, *, loc=None, ip=None):
+ // CHECK: def __init__(self, num_variadic: int, *, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None):
// CHECK: operands = []
// CHECK: attributes = {}
// CHECK: regions = None
@@ -686,13 +685,13 @@ def VariadicAndNormalRegionOp : TestOp<"variadic_and_normal_region"> {
// CHECK: return self.regions[2:]
}
-// CHECK: def variadic_and_normal_region(num_variadic, *, loc=None, ip=None) -> VariadicAndNormalRegionOp:
+// CHECK: def variadic_and_normal_region(num_variadic: int, *, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None) -> VariadicAndNormalRegionOp:
// CHECK: return VariadicAndNormalRegionOp(num_variadic=num_variadic, loc=loc, ip=ip)
// CHECK-LABEL: class VariadicRegionOp(_ods_ir.OpView):
// CHECK: OPERATION_NAME = "test.variadic_region"
def VariadicRegionOp : TestOp<"variadic_region"> {
- // CHECK: def __init__(self, num_variadic, *, loc=None, ip=None):
+ // CHECK: def __init__(self, num_variadic: int, *, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None):
// CHECK: operands = []
// CHECK: attributes = {}
// CHECK: regions = None
@@ -710,7 +709,7 @@ def VariadicRegionOp : TestOp<"variadic_region"> {
// CHECK: return self.regions[0:]
}
-// CHECK: def variadic_region(num_variadic, *, loc=None, ip=None) -> VariadicRegionOp:
+// CHECK: def variadic_region(num_variadic: int, *, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None) -> VariadicRegionOp:
// CHECK: return VariadicRegionOp(num_variadic=num_variadic, loc=loc, ip=ip)
// CHECK: @_ods_cext.register_operation(_Dialect)
@@ -719,7 +718,7 @@ def VariadicRegionOp : TestOp<"variadic_region"> {
def WithSpecialCharactersOp : TestOp<"123with--special.characters"> {
}
-// CHECK: def _123with__special_characters(*, loc=None, ip=None) -> WithSpecialCharactersOp:
+// CHECK: def _123with__special_characters(*, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None) -> WithSpecialCharactersOp:
// CHECK: return WithSpecialCharactersOp(loc=loc, ip=ip)
// CHECK: @_ods_cext.register_operation(_Dialect)
@@ -734,11 +733,11 @@ def WithSuccessorsOp : TestOp<"with_successors"> {
VariadicSuccessor<AnySuccessor>:$successors);
}
-// CHECK: def with_successors(successor, successors, *, loc=None, ip=None) -> WithSuccessorsOp:
+// CHECK: def with_successors(successor: _ods_ir.Block, successors: _Sequence[_ods_ir.Block], *, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None) -> WithSuccessorsOp:
// CHECK: return WithSuccessorsOp(successor=successor, successors=successors, loc=loc, ip=ip)
// CHECK-LABEL: class snake_case(_ods_ir.OpView):
// CHECK: OPERATION_NAME = "test.snake_case"
def already_snake_case : TestOp<"snake_case"> {}
-// CHECK: def snake_case_(*, loc=None, ip=None) -> snake_case:
+// CHECK: def snake_case_(*, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None) -> snake_case:
// CHECK: return snake_case(loc=loc, ip=ip)
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 39e79e5631479..60074935c357d 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -46,7 +46,7 @@ _ods_ir = _ods_cext.ir
_ods_cext.globals.register_traceback_file_exclusion(__file__)
import builtins
-from typing import Sequence as _Sequence, Union as _Union, Optional as _Optional
+from typing import Any as _Any, Sequence as _Sequence, Union as _Union, Optional as _Optional
)Py";
@@ -638,6 +638,35 @@ static std::string getPythonAttrName(mlir::tblgen::Attribute attr) {
return "Attribute";
}
+/// Returns the Python type of the raw (non-Attribute) value accepted by the
+/// AttrBuilder registered for `attr`, e.g. `int` for `I32Attr`. The lookup is
+/// keyed on `attr.getAttrDefName()`.
+///
+/// The result is spliced into a generated annotation of the form
+/// `_Union[<raw>, _ods_ir.<AttrClass>]`; the caller substitutes `_Any` when
+/// this returns an empty StringRef (no mapping known). Non-empty results refer
+/// to string literals, so the returned StringRef never dangles.
+///
+/// NOTE(review): keep these entries in sync with the Python-side
+/// `register_attribute_builder` registrations (mlir/python/mlir/ir.py). An
+/// entry for an attribute without a builder would advertise raw values that
+/// the generated code cannot actually convert.
+static StringRef getPythonAttrRawType(mlir::tblgen::Attribute attr) {
+  return llvm::StringSwitch<StringRef>(attr.getAttrDefName())
+      // Scalar attributes.
+      .Cases({"BoolAttr", "I1Attr"}, "bool")
+      .Cases({"I8Attr", "I16Attr", "I32Attr", "I64Attr"}, "int")
+      .Cases({"SI1Attr", "SI8Attr", "SI16Attr", "SI32Attr", "SI64Attr"}, "int")
+      .Cases({"UI1Attr", "UI8Attr", "UI16Attr", "UI32Attr", "UI64Attr"}, "int")
+      .Case("IndexAttr", "int")
+      .Cases({"F32Attr", "F64Attr"}, "float")
+      .Cases({"StrAttr", "SymbolNameAttr"}, "str")
+      // Symbol references. Besides a single name, the `SymbolRefAttr` builder
+      // also accepts a list of names, from which it builds a nested reference.
+      .Case("FlatSymbolRefAttr", "str")
+      .Case("SymbolRefAttr", "_Union[str, _Sequence[str]]")
+      .Case("TypeAttr", "_ods_ir.Type")
+      .Case("AffineMapAttr", "_ods_ir.AffineMap")
+      .Case("IntegerSetAttr", "_ods_ir.IntegerSet")
+      .Case("DictionaryAttr", "dict")
+      // Array attributes: the raw value is a sequence of per-element values.
+      .Case("ArrayAttr", "_Sequence[_ods_ir.Attribute]")
+      .Case("TypeArrayAttr", "_Sequence[_ods_ir.Type]")
+      .Case("AffineMapArrayAttr", "_Sequence[_ods_ir.AffineMap]")
+      .Case("DictArrayAttr", "_Sequence[dict]")
+      .Cases({"I32ArrayAttr", "I64ArrayAttr", "I64SmallVectorArrayAttr"},
+             "_Sequence[int]")
+      .Cases({"F32ArrayAttr", "F64ArrayAttr"}, "_Sequence[float]")
+      .Cases({"BoolArrayAttr", "DenseBoolArrayAttr"}, "_Sequence[bool]")
+      .Cases({"StrArrayAttr", "FlatSymbolRefArrayAttr"}, "_Sequence[str]")
+      // Dense array attributes.
+      .Cases({"DenseI8ArrayAttr", "DenseI16ArrayAttr", "DenseI32ArrayAttr",
+              "DenseI64ArrayAttr"},
+             "_Sequence[int]")
+      .Cases({"DenseF32ArrayAttr", "DenseF64ArrayAttr"}, "_Sequence[float]")
+      .Default(StringRef());
+}
+
/// Emits accessors to Op attributes.
static void emitAttributeAccessors(const Operator &op, raw_ostream &os) {
for (const auto &namedAttr : op.getAttributes()) {
@@ -1099,6 +1128,7 @@ static SmallVector<std::string> emitDefaultOpBuilder(const Operator &op,
populateBuilderArgs(op, builderArgs, operandArgNames);
size_t numOperandAttrArgs = builderArgs.size() - numResultArgs;
populateBuilderArgsSuccessors(op, builderArgs, successorArgNames);
+ size_t numSuccessorArgs = successorArgNames.size();
populateBuilderLinesOperand(op, operandArgNames, builderLines);
populateBuilderLinesAttr(op, ArrayRef(builderArgs).drop_front(numResultArgs),
@@ -1108,54 +1138,110 @@ static SmallVector<std::string> emitDefaultOpBuilder(const Operator &op,
populateBuilderLinesSuccessors(op, successorArgNames, builderLines);
populateBuilderRegions(op, builderArgs, builderLines);
- // Layout of builderArgs vector elements:
- // [ result_args operand_attr_args successor_args regions ]
+ // Compute type annotations for each builder arg.
+ SmallVector<std::string> argTypes(builderArgs.size());
+
+ // Result args: user passes Type objects.
+ for (size_t i = 0; i < numResultArgs; ++i) {
+ const NamedTypeConstraint &result = op.getResult(i);
+ if (result.isVariadic())
+ argTypes[i] = "_Sequence[_ods_ir.Type]";
+ else if (result.isOptional())
+ argTypes[i] = "_Optional[_ods_ir.Type]";
+ else
+ argTypes[i] = "_ods_ir.Type";
+ }
+
+ // Operand and attribute args.
+ for (size_t i = 0; i < numOperandAttrArgs; ++i) {
+ size_t idx = numResultArgs + i;
+ Argument arg = op.getArg(i);
+ if (auto *nattr = llvm::dyn_cast_if_present<NamedAttribute *>(arg)) {
+ if (nattr->attr.getStorageType().trim() == "::mlir::UnitAttr") {
+ argTypes[idx] = "bool";
+ } else {
+ std::string attrType = "_ods_ir." + getPythonAttrName(nattr->attr);
+ StringRef rawType = getPythonAttrRawType(nattr->attr);
+ argTypes[idx] =
+ llvm::formatv("_Union[{0}, {1}]",
+ rawType.empty() ? "_Any" : rawType, attrType)
+ .str();
+ }
+ } else if (auto *ntype =
+ llvm::dyn_cast_if_present<NamedTypeConstraint *>(arg)) {
+ if (ntype->isVariadic()) {
+ argTypes[idx] = "_Sequence[_ods_ir.Value]";
+ } else {
+ std::string type = "_ods_ir.Value";
+ if (StringRef pythonType =
+ getPythonType(ntype->constraint.getCppType());
+ !pythonType.empty())
+ type = llvm::formatv("{0}[{1}]", type, pythonType);
+ argTypes[idx] = type;
+ }
+ }
+ // NamedProperty args are skipped (no type hint).
+ }
- // Determine whether the argument corresponding to a given index into the
- // builderArgs vector is a python keyword argument or not.
- auto isKeywordArgFn = [&](size_t builderArgIndex) -> bool {
- // All result, successor, and region arguments are positional arguments.
- if ((builderArgIndex < numResultArgs) ||
- (builderArgIndex >= (numResultArgs + numOperandAttrArgs)))
+ // Successor args.
+ for (size_t i = 0; i < numSuccessorArgs; ++i) {
+ size_t idx = numResultArgs + numOperandAttrArgs + i;
+ const NamedSuccessor &successor = op.getSuccessor(i);
+ argTypes[idx] =
+ successor.isVariadic() ? "_Sequence[_ods_ir.Block]" : "_ods_ir.Block";
+ }
+
+ // Region args (variadic region count).
+ for (size_t i = numResultArgs + numOperandAttrArgs + numSuccessorArgs;
+ i < builderArgs.size(); ++i) {
+ argTypes[i] = "int";
+ }
+
+ // Determine whether a builder arg is a keyword argument.
+ auto isKeywordArg = [&](size_t i) -> bool {
+ // Only operand/attr args can be keyword; results, successors, and regions
+ // are always positional.
+ if (i < numResultArgs || i >= numResultArgs + numOperandAttrArgs)
return false;
- // Keyword arguments:
- // - optional named attributes (including unit attributes)
- // - default-valued named attributes
- // - optional operands
- Argument a = op.getArg(builderArgIndex - numResultArgs);
+ Argument a = op.getArg(i - numResultArgs);
if (auto *nattr = llvm::dyn_cast_if_present<NamedAttribute *>(a))
- return (nattr->attr.isOptional() || nattr->attr.hasDefaultValue());
+ return nattr->attr.isOptional() || nattr->attr.hasDefaultValue();
if (auto *ntype = llvm::dyn_cast_if_present<NamedTypeConstraint *>(a))
return ntype->isOptional();
return false;
};
- // StringRefs in functionArgs refer to strings allocated by builderArgs.
- SmallVector<StringRef> functionArgs;
+ // Format a single function argument with optional type hint and default.
+ auto formatArg = [](StringRef name, StringRef typeHint,
+ bool isKeyword) -> std::string {
+ std::string result = name.str();
+ if (isKeyword && !typeHint.empty()) {
+ result += ": _Optional[" + typeHint.str() + "] = None";
+ } else if (isKeyword) {
+ result += "=None";
+ } else if (!typeHint.empty()) {
+ result += ": " + typeHint.str();
+ }
+ return result;
+ };
- // Add positional arguments.
- for (size_t i = 0, cnt = builderArgs.size(); i < cnt; ++i) {
- if (!isKeywordArgFn(i))
- functionArgs.push_back(builderArgs[i]);
- }
+ // Build the function argument list: positional args, *, keyword args.
+ SmallVector<std::string> functionArgs;
+ for (size_t i = 0, cnt = builderArgs.size(); i < cnt; ++i)
+ if (!isKeywordArg(i))
+ functionArgs.push_back(formatArg(builderArgs[i], argTypes[i], false));
- // Add a bare '*' to indicate that all following arguments must be keyword
- // arguments.
functionArgs.push_back("*");
- // Add a default 'None' value to each keyword arg string, and then add to the
- // function args list.
- for (size_t i = 0, cnt = builderArgs.size(); i < cnt; ++i) {
- if (isKeywordArgFn(i)) {
- builderArgs[i].append("=None");
- functionArgs.push_back(builderArgs[i]);
- }
- }
- if (canInferType(op)) {
- functionArgs.push_back("results=None");
- }
- functionArgs.push_back("loc=None");
- functionArgs.push_back("ip=None");
+ for (size_t i = 0, cnt = builderArgs.size(); i < cnt; ++i)
+ if (isKeywordArg(i))
+ functionArgs.push_back(formatArg(builderArgs[i], argTypes[i], true));
+
+ if (canInferType(op))
+ functionArgs.push_back(
+ "results: _Optional[_Sequence[_ods_ir.Type]] = None");
+ functionArgs.push_back("loc: _Optional[_ods_ir.Location] = None");
+ functionArgs.push_back("ip: _Optional[_ods_ir.InsertionPoint] = None");
SmallVector<std::string> initArgs;
initArgs.push_back("self.OPERATION_NAME");
@@ -1172,8 +1258,7 @@ static SmallVector<std::string> emitDefaultOpBuilder(const Operator &op,
os << formatv(initTemplate, llvm::join(functionArgs, ", "),
llvm::join(builderLines, "\n "), llvm::join(initArgs, ", "));
- return llvm::map_to_vector<8>(functionArgs,
- [](StringRef s) { return s.str(); });
+ return functionArgs;
}
static void emitSegmentSpec(
@@ -1226,23 +1311,33 @@ static void emitRegionAccessors(const Operator &op, raw_ostream &os) {
static void emitValueBuilder(const Operator &op,
SmallVector<std::string> functionArgs,
raw_ostream &os) {
+ // Parse a formatted function arg "name[: type][ = default]" into
+ // (name, type, defaultVal) with whitespace trimmed.
+ auto parseFunctionArg =
+ [](StringRef arg) -> std::tuple<StringRef, StringRef, StringRef> {
+ auto [nameAndType, defaultVal] = arg.split('=');
+ auto [name, type] = nameAndType.split(':');
+ return {name.trim(), type.trim(), defaultVal.trim()};
+ };
+
// Params with (possibly) default args.
auto valueBuilderParams =
- llvm::map_range(functionArgs, [](const std::string &argAndMaybeDefault) {
- SmallVector<StringRef> argMaybeDefault =
- llvm::to_vector<2>(llvm::split(argAndMaybeDefault, "="));
- auto arg = llvm::convertToSnakeFromCamelCase(argMaybeDefault[0]);
- if (argMaybeDefault.size() == 2)
- return arg + "=" + argMaybeDefault[1].str();
- return arg;
+ llvm::map_range(functionArgs, [&](const std::string &arg) {
+ auto [name, type, defaultVal] = parseFunctionArg(arg);
+ std::string result = llvm::convertToSnakeFromCamelCase(name);
+ if (!type.empty())
+ result += ": " + type.str();
+ if (!defaultVal.empty())
+ result += " = " + defaultVal.str();
+ return result;
});
// Actual args passed to op builder (e.g., opParam=op_param).
auto opBuilderArgs = llvm::map_range(
llvm::make_filter_range(functionArgs,
[](const std::string &s) { return s != "*"; }),
- [](const std::string &arg) {
- auto lhs = *llvm::split(arg, "=").begin();
- return (lhs + "=" + llvm::convertToSnakeFromCamelCase(lhs)).str();
+ [&](const std::string &arg) {
+ auto [name, type, defaultVal] = parseFunctionArg(arg);
+ return (name + "=" + llvm::convertToSnakeFromCamelCase(name)).str();
});
std::string nameWithoutDialect = sanitizeName(
op.getOperationName().substr(op.getOperationName().find('.') + 1));
>From e0b4012372298aefbab70d095d5cbd279f5cf73c Mon Sep 17 00:00:00 2001
From: Sergei Lebedev <slebedev at google.com>
Date: Thu, 26 Mar 2026 15:09:04 +0000
Subject: [PATCH 2/2] [MLIR] [Python] `OpOperandList` and `OpResultList` are
now generic
Both classes are now defined as generic classes with a single type parameter.
In general, such lists can be heterogeneously typed, but not when they are
used for variadic operands/results.
I also updated the bindings generator to emit type arguments for both classes.
---
mlir/include/mlir/Bindings/Python/IRCore.h | 2 +
.../mlir/Bindings/Python/NanobindUtils.h | 37 ++++++++++++++++---
mlir/test/mlir-tblgen/op-python-bindings.td | 17 +++++++++
mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp | 36 +++++++++++-------
4 files changed, 74 insertions(+), 18 deletions(-)
diff --git a/mlir/include/mlir/Bindings/Python/IRCore.h b/mlir/include/mlir/Bindings/Python/IRCore.h
index db8427cfc4f78..92c4072867cba 100644
--- a/mlir/include/mlir/Bindings/Python/IRCore.h
+++ b/mlir/include/mlir/Bindings/Python/IRCore.h
@@ -1600,6 +1600,7 @@ class MLIR_PYTHON_API_EXPORTED PyOpResultList
: public Sliceable<PyOpResultList, PyOpResult> {
public:
static constexpr const char *pyClassName = "OpResultList";
+ static constexpr std::array<const char *, 1> typeParams = {"_T"};
using SliceableT = Sliceable<PyOpResultList, PyOpResult>;
PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0,
@@ -1676,6 +1677,7 @@ class MLIR_PYTHON_API_EXPORTED PyOpOperandList
: public Sliceable<PyOpOperandList, PyValue> {
public:
static constexpr const char *pyClassName = "OpOperandList";
+ static constexpr std::array<const char *, 1> typeParams = {"_T"};
using SliceableT = Sliceable<PyOpOperandList, PyValue>;
PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0,
diff --git a/mlir/include/mlir/Bindings/Python/NanobindUtils.h b/mlir/include/mlir/Bindings/Python/NanobindUtils.h
index 8d8f9103f21dd..ea43356d2cf54 100644
--- a/mlir/include/mlir/Bindings/Python/NanobindUtils.h
+++ b/mlir/include/mlir/Bindings/Python/NanobindUtils.h
@@ -13,6 +13,7 @@
#include "mlir-c/Support.h"
#include "mlir/Bindings/Python/Nanobind.h"
+#include <array>
#include <atomic>
#include <fstream>
#include <memory>
@@ -318,11 +319,17 @@ struct PySinglePartStringAccumulator {
/// A derived class may additionally define:
/// - a `static void bindDerived(ClassTy &)` method to bind additional methods
/// the python class.
+/// - a `static constexpr std::array<const char *, N> typeParams` to make the
+/// Python class generic, parameterizable with the given type parameters.
template <typename Derived, typename ElementTy>
class Sliceable {
protected:
using ClassTy = nanobind::class_<Derived>;
+ /// Type parameter names for generic classes. When non-empty, the Python
+ /// class will be made generic with `typing.Generic[...]`.
+ static constexpr std::array<const char *, 0> typeParams = {};
+
/// Transforms `index` into a legal value to access the underlying sequence.
/// Returns <0 on failure.
intptr_t wrapIndex(intptr_t index) {
@@ -475,11 +482,31 @@ class Sliceable {
nanobind::handle elemTyName = nanobind::detail::nb_type_name(elemTyInfo);
std::string sig = std::string("class ") + Derived::pyClassName +
"(collections.abc.Sequence[" +
- nanobind::cast<std::string>(elemTyName) + "])";
- auto clazz = nanobind::class_<Derived>(m, Derived::pyClassName,
- nanobind::type_slots(sequenceSlots),
- nanobind::sig(sig.c_str()))
- .def("__add__", &Sliceable::dunderAdd);
+ nanobind::cast<std::string>(elemTyName) + "]";
+ if constexpr (!Derived::typeParams.empty()) {
+ sig += ", typing.Generic[";
+ for (size_t i = 0; i < Derived::typeParams.size(); ++i) {
+ if (i > 0)
+ sig += ", ";
+ const char *tp = Derived::typeParams[i];
+ sig += tp;
+ if (!nanobind::hasattr(m, tp))
+ m.attr(tp) = nanobind::type_var(tp);
+ }
+ sig += "]";
+ }
+ sig += ")";
+ ClassTy clazz;
+ if constexpr (!Derived::typeParams.empty()) {
+ clazz =
+ ClassTy(m, Derived::pyClassName, nanobind::type_slots(sequenceSlots),
+ nanobind::is_generic(), nanobind::sig(sig.c_str()));
+ } else {
+ clazz =
+ ClassTy(m, Derived::pyClassName, nanobind::type_slots(sequenceSlots),
+ nanobind::sig(sig.c_str()));
+ }
+ clazz.def("__add__", &Sliceable::dunderAdd);
Derived::bindDerived(clazz);
}
diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td
index 23d4194344a66..5e29f3f61e5c8 100644
--- a/mlir/test/mlir-tblgen/op-python-bindings.td
+++ b/mlir/test/mlir-tblgen/op-python-bindings.td
@@ -660,6 +660,23 @@ def SingleTypedResultOp : TestOp<"single_typed_result"> {
// CHECK: def single_typed_result(in_: _ods_ir.Value, *, results: _Optional[_Sequence[_ods_ir.Type]] = None, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None) -> _ods_ir.OpResult[_ods_ir.IntegerType]:
// CHECK: return SingleTypedResultOp(in_=in_, results=results, loc=loc, ip=ip).result
+// CHECK: @_ods_cext.register_operation(_Dialect)
+// CHECK-LABEL: class TypedVariadicOperandOp(_ods_ir.OpView):
+// CHECK: OPERATION_NAME = "test.typed_variadic_operand"
+def TypedVariadicOperandOp : TestOp<"typed_variadic_operand"> {
+ // Covers element-type propagation for a variadic I32 operand and a
+ // non-variadic F32 operand: the builder parameters and the accessor return
+ // types below are expected to be parameterized with IntegerType / FloatType.
+ // CHECK: def __init__(self, variadic: _Sequence[_ods_ir.Value[_ods_ir.IntegerType]], non_variadic: _ods_ir.Value[_ods_ir.FloatType], *, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None):
+
+ // CHECK: @builtins.property
+ // CHECK: def non_variadic(self) -> _ods_ir.Value[_ods_ir.FloatType]:
+ //
+ // CHECK: @builtins.property
+ // CHECK: def variadic(self) -> _ods_ir.OpOperandList[_ods_ir.IntegerType]:
+ let arguments = (ins Variadic<I32>:$variadic, F32:$non_variadic);
+}
+
+// CHECK: def typed_variadic_operand(variadic: _Sequence[_ods_ir.Value[_ods_ir.IntegerType]], non_variadic: _ods_ir.Value[_ods_ir.FloatType], *, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None) -> TypedVariadicOperandOp:
+// CHECK: return TypedVariadicOperandOp(variadic=variadic, non_variadic=non_variadic, loc=loc, ip=ip)
+
// CHECK-LABEL: class VariadicAndNormalRegionOp(_ods_ir.OpView):
// CHECK: OPERATION_NAME = "test.variadic_and_normal_region"
def VariadicAndNormalRegionOp : TestOp<"variadic_and_normal_region"> {
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 60074935c357d..0dbd05f8db10a 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -449,6 +449,10 @@ static void emitElementAccessors(
} else {
type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.OpOperandList"
: "_ods_ir.OpResultList";
+ if (StringRef pythonType =
+ getPythonType(element.constraint.getCppType());
+ !pythonType.empty())
+ type = llvm::formatv("{0}[{1}]", type, pythonType);
os << formatv(opOneVariadicTemplate, sanitizeName(element.name),
pyAttrName, numElements, i, type);
}
@@ -488,12 +492,10 @@ static void emitElementAccessors(
type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.Value"
: "_ods_ir.OpResult";
}
- if (std::strcmp(type.c_str(), "_ods_ir.Value") == 0 ||
- std::strcmp(type.c_str(), "_ods_ir.OpResult") == 0) {
- StringRef pythonType = getPythonType(element.constraint.getCppType());
- if (!pythonType.empty())
- type += "[" + pythonType.str() + "]";
- }
+ if (StringRef pythonType =
+ getPythonType(element.constraint.getCppType());
+ !pythonType.empty())
+ type = llvm::formatv("{0}[{1}]", type, pythonType);
os << formatv(opVariadicEqualPrefixTemplate, sanitizeName(element.name),
pyAttrName, numSimpleLength, numVariadicGroups,
numPrecedingSimple, numPrecedingVariadic, type);
@@ -525,12 +527,10 @@ static void emitElementAccessors(
if (!element.isVariableLength() || element.isOptional()) {
type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.Value"
: "_ods_ir.OpResult";
- if (std::strcmp(type.c_str(), "_ods_ir.Value") == 0 ||
- std::strcmp(type.c_str(), "_ods_ir.OpResult") == 0) {
- StringRef pythonType = getPythonType(element.constraint.getCppType());
- if (!pythonType.empty())
- type += "[" + pythonType.str() + "]";
- }
+ if (StringRef pythonType =
+ getPythonType(element.constraint.getCppType());
+ !pythonType.empty())
+ type = llvm::formatv("{0}[{1}]", type, pythonType);
if (!element.isVariableLength()) {
trailing = "[0]";
} else if (element.isOptional()) {
@@ -538,6 +538,11 @@ static void emitElementAccessors(
trailing = std::string(
formatv(opVariadicSegmentOptionalTrailingTemplate, kind));
}
+ } else {
+ if (StringRef pythonType =
+ getPythonType(element.constraint.getCppType());
+ !pythonType.empty())
+ type = llvm::formatv("{0}[{1}]", type, pythonType);
}
os << formatv(opVariadicSegmentTemplate, sanitizeName(element.name), kind,
@@ -1170,7 +1175,12 @@ static SmallVector<std::string> emitDefaultOpBuilder(const Operator &op,
} else if (auto *ntype =
llvm::dyn_cast_if_present<NamedTypeConstraint *>(arg)) {
if (ntype->isVariadic()) {
- argTypes[idx] = "_Sequence[_ods_ir.Value]";
+ std::string type = "_ods_ir.Value";
+ if (StringRef pythonType =
+ getPythonType(ntype->constraint.getCppType());
+ !pythonType.empty())
+ type = llvm::formatv("{0}[{1}]", type, pythonType);
+ argTypes[idx] = llvm::formatv("_Sequence[{0}]", type);
} else {
std::string type = "_ods_ir.Value";
if (StringRef pythonType =
More information about the Mlir-commits
mailing list