[Mlir-commits] [mlir] [MLIR][Python] Add optional `results` parameter for building op with inferable result types (PR #156818)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Sep 4 09:43:15 PDT 2025
https://github.com/PragmaTwice updated https://github.com/llvm/llvm-project/pull/156818
>From 689a0926b89d5a0a5561bc46032785c2d519abfe Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Thu, 4 Sep 2025 15:10:28 +0800
Subject: [PATCH 1/3] [MLIR][Python] Add optional results parameter for
building op with inferable result types
---
mlir/test/mlir-tblgen/op-python-bindings.td | 67 +++++++++----------
mlir/test/python/ir/auto_location.py | 6 +-
mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp | 46 +++++++------
3 files changed, 63 insertions(+), 56 deletions(-)
diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td
index c2bd86819666b..4d5d7ee26b775 100644
--- a/mlir/test/mlir-tblgen/op-python-bindings.td
+++ b/mlir/test/mlir-tblgen/op-python-bindings.td
@@ -23,12 +23,12 @@ def AttrSizedOperandsOp : TestOp<"attr_sized_operands",
[AttrSizedOperandSegments]> {
// CHECK: def __init__(self, variadic1, non_variadic, *, variadic2=None, loc=None, ip=None):
// CHECK: operands = []
- // CHECK: results = []
// CHECK: attributes = {}
// CHECK: regions = None
// CHECK: operands.append(_get_op_results_or_values(variadic1))
// CHECK: operands.append(non_variadic)
// CHECK: operands.append(variadic2)
+ // CHECK: results = []
// CHECK: _ods_successors = None
// CHECK: super().__init__(
// CHECK: self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS,
@@ -71,9 +71,9 @@ def AttrSizedResultsOp : TestOp<"attr_sized_results",
[AttrSizedResultSegments]> {
// CHECK: def __init__(self, variadic1, non_variadic, variadic2, *, loc=None, ip=None):
// CHECK: operands = []
- // CHECK: results = []
// CHECK: attributes = {}
// CHECK: regions = None
+ // CHECK: results = []
// CHECK: if variadic1 is not None: results.append(variadic1)
// CHECK: results.append(non_variadic)
// CHECK: results.append(variadic2)
@@ -120,7 +120,6 @@ def AttrSizedResultsOp : TestOp<"attr_sized_results",
def AttributedOp : TestOp<"attributed_op"> {
// CHECK: def __init__(self, i32attr, in_, *, optionalF32Attr=None, unitAttr=None, loc=None, ip=None):
// CHECK: operands = []
- // CHECK: results = []
// CHECK: attributes = {}
// CHECK: regions = None
// CHECK: attributes["i32attr"] = (i32attr if (
@@ -131,6 +130,7 @@ def AttributedOp : TestOp<"attributed_op"> {
// CHECK: if bool(unitAttr): attributes["unitAttr"] = _ods_ir.UnitAttr.get(
// CHECK: _ods_get_default_loc_context(loc))
// CHECK: attributes["in"] = (in_
+ // CHECK: results = []
// CHECK: _ods_successors = None
// CHECK: super().__init__(
// CHECK: self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS,
@@ -170,7 +170,6 @@ def AttributedOp : TestOp<"attributed_op"> {
def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> {
// CHECK: def __init__(self, _gen_arg_0, _gen_arg_2, *, in_=None, is_=None, loc=None, ip=None):
// CHECK: operands = []
- // CHECK: results = []
// CHECK: attributes = {}
// CHECK: regions = None
// CHECK: operands.append(_gen_arg_0)
@@ -178,6 +177,7 @@ def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> {
// CHECK: if bool(in_): attributes["in"] = _ods_ir.UnitAttr.get(
// CHECK: _ods_get_default_loc_context(loc))
// CHECK: if is_ is not None: attributes["is"] = (is_
+ // CHECK: results = []
// CHECK: _ods_successors = None
// CHECK: super().__init__(
// CHECK: self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS,
@@ -205,11 +205,11 @@ def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> {
def DefaultValuedAttrsOp : TestOp<"default_valued_attrs"> {
// CHECK: def __init__(self, *, arr=None, unsupported=None, loc=None, ip=None):
// CHECK: operands = []
- // CHECK: results = []
// CHECK: attributes = {}
// CHECK: regions = None
// CHECK: if arr is not None: attributes["arr"] = (arr
// CHECK: if unsupported is not None: attributes["unsupported"] = (unsupported
+ // CHECK: results = []
// CHECK: _ods_successors = None
// CHECK: super().__init__(
// CHECK: self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS,
@@ -226,21 +226,21 @@ def DefaultValuedAttrsOp : TestOp<"default_valued_attrs"> {
// CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_op"
def DeriveResultTypesOp : TestOp<"derive_result_types_op", [FirstAttrDerivedResultType]> {
- // CHECK: def __init__(self, type_, *, loc=None, ip=None):
+ // CHECK: def __init__(self, type_, *, results=None, loc=None, ip=None):
// CHECK: operands = []
- // CHECK: results = []
- // CHECK: _ods_result_type_source_attr = attributes["type"]
- // CHECK: _ods_derived_result_type = (
+ // CHECK: if results is None:
+ // CHECK: _ods_result_type_source_attr = attributes["type"]
+ // CHECK: _ods_derived_result_type = (
// CHECK: _ods_ir.TypeAttr(_ods_result_type_source_attr).value
// CHECK: if _ods_ir.TypeAttr.isinstance(_ods_result_type_source_attr) else
// CHECK: _ods_result_type_source_attr.type)
- // CHECK: results.extend([_ods_derived_result_type] * 2)
+ // CHECK: results = [_ods_derived_result_type] * 2
let arguments = (ins TypeAttr:$type);
let results = (outs AnyType:$res, AnyType);
}
-// CHECK: def derive_result_types_op(type_, *, loc=None, ip=None)
-// CHECK: return DeriveResultTypesOp(type_=type_, loc=loc, ip=ip).results
+// CHECK: def derive_result_types_op(type_, *, results=None, loc=None, ip=None)
+// CHECK: return DeriveResultTypesOp(type_=type_, results=results, loc=loc, ip=ip).results
// CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_variadic_op"
def DeriveResultTypesVariadicOp : TestOp<"derive_result_types_variadic_op", [FirstAttrDerivedResultType]> {
@@ -258,9 +258,9 @@ def DeriveResultTypesVariadicOp : TestOp<"derive_result_types_variadic_op", [Fir
def EmptyOp : TestOp<"empty">;
// CHECK: def __init__(self, *, loc=None, ip=None):
// CHECK: operands = []
- // CHECK: results = []
// CHECK: attributes = {}
// CHECK: regions = None
+ // CHECK: results = []
// CHECK: _ods_successors = None
// CHECK: super().__init__(
// CHECK: self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS,
@@ -272,31 +272,31 @@ def EmptyOp : TestOp<"empty">;
// CHECK-LABEL: OPERATION_NAME = "test.infer_result_types_implied_op"
def InferResultTypesImpliedOp : TestOp<"infer_result_types_implied_op"> {
- // CHECK: def __init__(self, *, loc=None, ip=None):
+ // CHECK: def __init__(self, *, results=None, loc=None, ip=None):
// CHECK: _ods_context = _ods_get_default_loc_context(loc)
// CHECK: super().__init__(
// CHECK: self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS,
- // CHECK: attributes=attributes, operands=operands,
+ // CHECK: attributes=attributes, results=results, operands=operands,
// CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip)
let results = (outs I32:$i32, F32:$f32);
}
-// CHECK: def infer_result_types_implied_op(*, loc=None, ip=None)
-// CHECK: return InferResultTypesImpliedOp(loc=loc, ip=ip).results
+// CHECK: def infer_result_types_implied_op(*, results=None, loc=None, ip=None)
+// CHECK: return InferResultTypesImpliedOp(results=results, loc=loc, ip=ip).results
// CHECK-LABEL: OPERATION_NAME = "test.infer_result_types_op"
def InferResultTypesOp : TestOp<"infer_result_types_op", [InferTypeOpInterface]> {
- // CHECK: def __init__(self, *, loc=None, ip=None):
+ // CHECK: def __init__(self, *, results=None, loc=None, ip=None):
// CHECK: operands = []
// CHECK: super().__init__(
// CHECK: self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS,
- // CHECK: attributes=attributes, operands=operands,
+ // CHECK: attributes=attributes, results=results, operands=operands,
// CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip)
let results = (outs AnyType, AnyType, AnyType);
}
-// CHECK: def infer_result_types_op(*, loc=None, ip=None)
-// CHECK: return InferResultTypesOp(loc=loc, ip=ip).results
+// CHECK: def infer_result_types_op(*, results=None, loc=None, ip=None)
+// CHECK: return InferResultTypesOp(results=results, loc=loc, ip=ip).results
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class MissingNamesOp(_ods_ir.OpView):
@@ -304,12 +304,12 @@ def InferResultTypesOp : TestOp<"infer_result_types_op", [InferTypeOpInterface]>
def MissingNamesOp : TestOp<"missing_names"> {
// CHECK: def __init__(self, i32, _gen_res_1, i64, _gen_arg_0, f32, _gen_arg_2, *, loc=None, ip=None):
// CHECK: operands = []
- // CHECK: results = []
// CHECK: attributes = {}
// CHECK: regions = None
// CHECK: operands.append(_gen_arg_0)
// CHECK: operands.append(f32)
// CHECK: operands.append(_gen_arg_2)
+ // CHECK: results = []
// CHECK: results.append(i32)
// CHECK: results.append(_gen_res_1)
// CHECK: results.append(i64)
@@ -346,11 +346,11 @@ def OneOptionalOperandOp : TestOp<"one_optional_operand"> {
let arguments = (ins AnyType:$non_optional, Optional<AnyType>:$optional);
// CHECK: def __init__(self, non_optional, *, optional=None, loc=None, ip=None):
// CHECK: operands = []
- // CHECK: results = []
// CHECK: attributes = {}
// CHECK: regions = None
// CHECK: operands.append(non_optional)
// CHECK: if optional is not None: operands.append(optional)
+ // CHECK: results = []
// CHECK: _ods_successors = None
// CHECK: super().__init__(
// CHECK: self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS
@@ -377,11 +377,11 @@ def OneOptionalOperandOp : TestOp<"one_optional_operand"> {
def OneVariadicOperandOp : TestOp<"one_variadic_operand"> {
// CHECK: def __init__(self, non_variadic, variadic, *, loc=None, ip=None):
// CHECK: operands = []
- // CHECK: results = []
// CHECK: attributes = {}
// CHECK: regions = None
// CHECK: operands.append(non_variadic)
// CHECK: operands.extend(_get_op_results_or_values(variadic))
+ // CHECK: results = []
// CHECK: _ods_successors = None
// CHECK: super().__init__(
// CHECK: self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS
@@ -410,9 +410,9 @@ def OneVariadicOperandOp : TestOp<"one_variadic_operand"> {
def OneVariadicResultOp : TestOp<"one_variadic_result"> {
// CHECK: def __init__(self, variadic, non_variadic, *, loc=None, ip=None):
// CHECK: operands = []
- // CHECK: results = []
// CHECK: attributes = {}
// CHECK: regions = None
+ // CHECK: results = []
// CHECK: results.extend(variadic)
// CHECK: results.append(non_variadic)
// CHECK: _ods_successors = None
@@ -442,10 +442,10 @@ def OneVariadicResultOp : TestOp<"one_variadic_result"> {
def PythonKeywordOp : TestOp<"python_keyword"> {
// CHECK: def __init__(self, in_, *, loc=None, ip=None):
// CHECK: operands = []
- // CHECK: results = []
// CHECK: attributes = {}
// CHECK: regions = None
// CHECK: operands.append(in_)
+ // CHECK: results = []
// CHECK: _ods_successors = None
// CHECK: super().__init__(
// CHECK: self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS
@@ -463,17 +463,16 @@ def PythonKeywordOp : TestOp<"python_keyword"> {
// CHECK-LABEL: OPERATION_NAME = "test.same_results"
def SameResultsOp : TestOp<"same_results", [SameOperandsAndResultType]> {
- // CHECK: def __init__(self, in1, in2, *, loc=None, ip=None):
+ // CHECK: def __init__(self, in1, in2, *, results=None, loc=None, ip=None):
// CHECK: operands = []
- // CHECK: results = []
// CHECK: operands.append
- // CHECK: results.extend([operands[0].type] * 1)
+ // CHECK: if results is None: results = [operands[0].type] * 1
let arguments = (ins AnyType:$in1, AnyType:$in2);
let results = (outs AnyType:$res);
}
-// CHECK: def same_results(in1, in2, *, loc=None, ip=None)
-// CHECK: return SameResultsOp(in1=in1, in2=in2, loc=loc, ip=ip)
+// CHECK: def same_results(in1, in2, *, results=None, loc=None, ip=None)
+// CHECK: return SameResultsOp(in1=in1, in2=in2, results=results, loc=loc, ip=ip)
// CHECK-LABEL: OPERATION_NAME = "test.same_results_variadic"
def SameResultsVariadicOp : TestOp<"same_results_variadic", [SameOperandsAndResultType]> {
@@ -544,11 +543,11 @@ def SameVariadicResultSizeOp : TestOp<"same_variadic_result",
def SimpleOp : TestOp<"simple"> {
// CHECK: def __init__(self, i64, f64, i32, f32, *, loc=None, ip=None):
// CHECK: operands = []
- // CHECK: results = []
// CHECK: attributes = {}
// CHECK: regions = None
// CHECK: operands.append(i32)
// CHECK: operands.append(f32)
+ // CHECK: results = []
// CHECK: results.append(i64)
// CHECK: results.append(f64)
// CHECK: _ods_successors = None
@@ -584,9 +583,9 @@ def SimpleOp : TestOp<"simple"> {
def VariadicAndNormalRegionOp : TestOp<"variadic_and_normal_region"> {
// CHECK: def __init__(self, num_variadic, *, loc=None, ip=None):
// CHECK: operands = []
- // CHECK: results = []
// CHECK: attributes = {}
// CHECK: regions = None
+ // CHECK: results = []
// CHECK: _ods_successors = None
// CHECK: regions = 2 + num_variadic
// CHECK: super().__init__(
@@ -612,9 +611,9 @@ def VariadicAndNormalRegionOp : TestOp<"variadic_and_normal_region"> {
def VariadicRegionOp : TestOp<"variadic_region"> {
// CHECK: def __init__(self, num_variadic, *, loc=None, ip=None):
// CHECK: operands = []
- // CHECK: results = []
// CHECK: attributes = {}
// CHECK: regions = None
+ // CHECK: results = []
// CHECK: _ods_successors = None
// CHECK: regions = 0 + num_variadic
// CHECK: super().__init__(
diff --git a/mlir/test/python/ir/auto_location.py b/mlir/test/python/ir/auto_location.py
index a45ca48b5c484..01b5542119b4e 100644
--- a/mlir/test/python/ir/auto_location.py
+++ b/mlir/test/python/ir/auto_location.py
@@ -51,7 +51,7 @@ def testInferLocations():
_cext.globals.register_traceback_file_inclusion(_arith_ops_gen.__file__)
three = arith.constant(IndexType.get(), 3)
# fmt: off
- # CHECK: loc(callsite("ConstantOp.__init__"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]_arith_ops_gen.py":405:4 to :218) at callsite("testInferLocations"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":52:16 to :50) at callsite("run"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":13:4 to :7) at "<module>"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":26:1 to :4)))))
+ # CHECK: loc(callsite("ConstantOp.__init__"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]_arith_ops_gen.py":397:4 to :235) at callsite("testInferLocations"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":52:16 to :50) at callsite("run"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":13:4 to :7) at "<module>"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":26:1 to :4)))))
# fmt: on
print(three.location)
@@ -60,14 +60,14 @@ def foo():
print(four.location)
# fmt: off
- # CHECK: loc(callsite("ConstantOp.__init__"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]_arith_ops_gen.py":405:4 to :218) at callsite("testInferLocations.<locals>.foo"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":59:19 to :53) at callsite("testInferLocations"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":65:8 to :13) at callsite("run"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":13:4 to :7) at "<module>"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":26:1 to :4))))))
+ # CHECK: loc(callsite("ConstantOp.__init__"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]_arith_ops_gen.py":397:4 to :235) at callsite("testInferLocations.<locals>.foo"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":59:19 to :53) at callsite("testInferLocations"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":65:8 to :13) at callsite("run"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":13:4 to :7) at "<module>"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":26:1 to :4))))))
# fmt: on
foo()
_cext.globals.register_traceback_file_exclusion(__file__)
# fmt: off
- # CHECK: loc("ConstantOp.__init__"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]_arith_ops_gen.py":405:4 to :218))
+ # CHECK: loc("ConstantOp.__init__"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]_arith_ops_gen.py":397:4 to :235))
# fmt: on
foo()
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 038f56d5a2150..6a7aa9e3432d5 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -492,7 +492,6 @@ static void emitAttributeAccessors(const Operator &op, raw_ostream &os) {
constexpr const char *initTemplate = R"Py(
def __init__(self, {0}):
operands = []
- results = []
attributes = {{}
regions = None
{1}
@@ -738,18 +737,24 @@ populateBuilderLinesOperand(const Operator &op, ArrayRef<std::string> names,
}
}
-/// Python code template for deriving the operation result types from its
-/// attribute:
+/// Python code template for generating result types for ops with the
+/// FirstAttrDerivedResultType trait.
/// - {0} is the name of the attribute from which to derive the types.
-constexpr const char *deriveTypeFromAttrTemplate =
- R"Py(_ods_result_type_source_attr = attributes["{0}"]
-_ods_derived_result_type = (
+/// - {1} is the number of results.
+constexpr const char *firstAttrDerivedResultTypeTemplate =
+ R"Py(if results is None:
+ _ods_result_type_source_attr = attributes["{0}"]
+ _ods_derived_result_type = (
_ods_ir.TypeAttr(_ods_result_type_source_attr).value
if _ods_ir.TypeAttr.isinstance(_ods_result_type_source_attr) else
- _ods_result_type_source_attr.type))Py";
+ _ods_result_type_source_attr.type)
+ results = [_ods_derived_result_type] * {1})Py";
-/// Python code template appending {0} type {1} times to the results list.
-constexpr const char *appendSameResultsTemplate = "results.extend([{0}] * {1})";
+/// Python code template for generating result types for ops with the
+/// SameOperandsAndResultType trait.
+/// - {0} is the number of results.
+constexpr const char *sameOperandsAndResultTypeTemplate =
+ R"Py(if results is None: results = [operands[0].type] * {0})Py";
/// Appends the given multiline string as individual strings into
/// `builderLines`.
@@ -768,11 +773,10 @@ static void appendLineByLine(StringRef string,
static void
populateBuilderLinesResult(const Operator &op, ArrayRef<std::string> names,
SmallVectorImpl<std::string> &builderLines) {
- bool sizedSegments = op.getTrait(attrSizedTraitForKind("result")) != nullptr;
-
if (hasSameArgumentAndResultTypes(op)) {
- builderLines.push_back(formatv(appendSameResultsTemplate,
- "operands[0].type", op.getNumResults()));
+ appendLineByLine(
+ formatv(sameOperandsAndResultTypeTemplate, op.getNumResults()).str(),
+ builderLines);
return;
}
@@ -780,17 +784,19 @@ populateBuilderLinesResult(const Operator &op, ArrayRef<std::string> names,
const NamedAttribute &firstAttr = op.getAttribute(0);
assert(!firstAttr.name.empty() && "unexpected empty name for the attribute "
"from which the type is derived");
- appendLineByLine(formatv(deriveTypeFromAttrTemplate, firstAttr.name).str(),
+ appendLineByLine(formatv(firstAttrDerivedResultTypeTemplate, firstAttr.name,
+ op.getNumResults())
+ .str(),
builderLines);
- builderLines.push_back(formatv(appendSameResultsTemplate,
- "_ods_derived_result_type",
- op.getNumResults()));
return;
}
if (hasInferTypeInterface(op))
return;
+ bool sizedSegments = op.getTrait(attrSizedTraitForKind("result")) != nullptr;
+ builderLines.push_back("results = []");
+
// For each element, find or generate a name.
for (int i = 0, e = op.getNumResults(); i < e; ++i) {
const NamedTypeConstraint &element = op.getResult(i);
@@ -909,6 +915,9 @@ static SmallVector<std::string> emitDefaultOpBuilder(const Operator &op,
functionArgs.push_back(builderArgs[i]);
}
}
+ if (canInferType(op)) {
+ functionArgs.push_back("results=None");
+ }
functionArgs.push_back("loc=None");
functionArgs.push_back("ip=None");
@@ -918,8 +927,7 @@ static SmallVector<std::string> emitDefaultOpBuilder(const Operator &op,
initArgs.push_back("self._ODS_OPERAND_SEGMENTS");
initArgs.push_back("self._ODS_RESULT_SEGMENTS");
initArgs.push_back("attributes=attributes");
- if (!hasInferTypeInterface(op))
- initArgs.push_back("results=results");
+ initArgs.push_back("results=results");
initArgs.push_back("operands=operands");
initArgs.push_back("successors=_ods_successors");
initArgs.push_back("regions=regions");
>From 7d434c90e0f88b684bb5b052f667e1dbc2d6faab Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Thu, 4 Sep 2025 23:20:58 +0800
Subject: [PATCH 2/3] add test case for results parameter
---
mlir/test/python/dialects/python_test.py | 13 +++++++++++++
1 file changed, 13 insertions(+)
diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py
index 694616696a9e2..68262822ca6b5 100644
--- a/mlir/test/python/dialects/python_test.py
+++ b/mlir/test/python/dialects/python_test.py
@@ -283,6 +283,10 @@ def resultTypesDefinedByTraits():
module = Module.create()
with InsertionPoint(module.body):
inferred = test.InferResultsOp()
+
+ # CHECK: i32 i64
+ print(inferred.single.type, inferred.doubled.type)
+
same = test.SameOperandAndResultTypeOp([inferred.results[0]])
# CHECK-COUNT-2: i32
print(same.one.type)
@@ -309,6 +313,15 @@ def resultTypesDefinedByTraits():
# CHECK: index
print(implied.index.type)
+ # provide the result types to avoid inferring them
+ f64 = F64Type.get()
+ no_imply = test.InferResultsImpliedOp(results=[f64, f64, f64])
+ # CHECK-COUNT-3: f64
+ print(no_imply.integer.type, no_imply.flt.type, no_imply.index.type)
+
+ no_infer = test.InferResultsOp(results=[F32Type.get(), IndexType.get()])
+ # CHECK: f32 index
+ print(no_infer.single.type, no_infer.doubled.type)
# CHECK-LABEL: TEST: testOptionalOperandOp
@run
>From 4e72e8f2dd62041284d7d5ff3b3fda20d997596a Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Fri, 5 Sep 2025 00:41:12 +0800
Subject: [PATCH 3/3] fix test
---
mlir/test/mlir-tblgen/op-python-bindings.td | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td
index 4d5d7ee26b775..3ec69c33b4bb9 100644
--- a/mlir/test/mlir-tblgen/op-python-bindings.td
+++ b/mlir/test/mlir-tblgen/op-python-bindings.td
@@ -234,7 +234,7 @@ def DeriveResultTypesOp : TestOp<"derive_result_types_op", [FirstAttrDerivedResu
// CHECK: _ods_ir.TypeAttr(_ods_result_type_source_attr).value
// CHECK: if _ods_ir.TypeAttr.isinstance(_ods_result_type_source_attr) else
// CHECK: _ods_result_type_source_attr.type)
- // CHECK: results = [_ods_derived_result_type] * 2
+ // CHECK: results = [_ods_derived_result_type] * 2
let arguments = (ins TypeAttr:$type);
let results = (outs AnyType:$res, AnyType);
}
More information about the Mlir-commits
mailing list