[Mlir-commits] [mlir] [MLIR][Python] fix generated value builder type hints (PR #158449)
Maksim Levental
llvmlistbot at llvm.org
Sat Sep 13 18:13:57 PDT 2025
https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/158449
From 259a29d52e3cbdf2999093d02d64db6ed15fc176 Mon Sep 17 00:00:00 2001
From: Maksim Levental <maksim.levental at gmail.com>
Date: Sat, 13 Sep 2025 20:56:40 -0400
Subject: [PATCH] [MLIR][Python] fix generated value builder type hints
---
mlir/test/mlir-tblgen/op-python-bindings.td | 20 +++++-----
mlir/test/python/dialects/python_test.py | 6 +++
mlir/test/python/ir/auto_location.py | 6 +--
mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp | 38 +++++++------------
4 files changed, 32 insertions(+), 38 deletions(-)
diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td
index 3ec69c33b4bb9..8fa9c19615081 100644
--- a/mlir/test/mlir-tblgen/op-python-bindings.td
+++ b/mlir/test/mlir-tblgen/op-python-bindings.td
@@ -108,8 +108,8 @@ def AttrSizedResultsOp : TestOp<"attr_sized_results",
Variadic<AnyType>:$variadic2);
}
-// CHECK: def attr_sized_results(variadic1, non_variadic, variadic2, *, loc=None, ip=None)
-// CHECK: return _get_op_result_or_op_results(AttrSizedResultsOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip))
+// CHECK: def attr_sized_results(variadic1, non_variadic, variadic2, *, loc=None, ip=None) -> _ods_ir.OpResultList:
+// CHECK: return AttrSizedResultsOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip).results
// CHECK: @_ods_cext.register_operation(_Dialect)
@@ -249,8 +249,8 @@ def DeriveResultTypesVariadicOp : TestOp<"derive_result_types_variadic_op", [Fir
let results = (outs AnyType:$res, Variadic<AnyType>);
}
-// CHECK: def derive_result_types_variadic_op(res, _gen_res_1, type_, *, loc=None, ip=None)
-// CHECK: return _get_op_result_or_op_results(DeriveResultTypesVariadicOp(res=res, _gen_res_1=_gen_res_1, type_=type_, loc=loc, ip=ip))
+// CHECK: def derive_result_types_variadic_op(res, _gen_res_1, type_, *, loc=None, ip=None) -> _ods_ir.OpResultList:
+// CHECK: return DeriveResultTypesVariadicOp(res=res, _gen_res_1=_gen_res_1, type_=type_, loc=loc, ip=ip).results
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class EmptyOp(_ods_ir.OpView):
@@ -433,8 +433,8 @@ def OneVariadicResultOp : TestOp<"one_variadic_result"> {
let results = (outs Variadic<AnyType>:$variadic, AnyType:$non_variadic);
}
-// CHECK: def one_variadic_result(variadic, non_variadic, *, loc=None, ip=None)
-// CHECK: return _get_op_result_or_op_results(OneVariadicResultOp(variadic=variadic, non_variadic=non_variadic, loc=loc, ip=ip))
+// CHECK: def one_variadic_result(variadic, non_variadic, *, loc=None, ip=None) -> _ods_ir.OpResultList:
+// CHECK: return OneVariadicResultOp(variadic=variadic, non_variadic=non_variadic, loc=loc, ip=ip).results
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class PythonKeywordOp(_ods_ir.OpView):
@@ -481,8 +481,8 @@ def SameResultsVariadicOp : TestOp<"same_results_variadic", [SameOperandsAndResu
let results = (outs Variadic<AnyType>:$res);
}
-// CHECK: def same_results_variadic(res, in1, in2, *, loc=None, ip=None)
-// CHECK: return _get_op_result_or_op_results(SameResultsVariadicOp(res=res, in1=in1, in2=in2, loc=loc, ip=ip))
+// CHECK: def same_results_variadic(res, in1, in2, *, loc=None, ip=None) -> _ods_ir.OpResultList:
+// CHECK: return SameResultsVariadicOp(res=res, in1=in1, in2=in2, loc=loc, ip=ip).results
// CHECK: @_ods_cext.register_operation(_Dialect)
@@ -534,8 +534,8 @@ def SameVariadicResultSizeOp : TestOp<"same_variadic_result",
Variadic<AnyType>:$variadic2);
}
-// CHECK: def same_variadic_result(variadic1, non_variadic, variadic2, *, loc=None, ip=None)
-// CHECK: return _get_op_result_or_op_results(SameVariadicResultSizeOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip))
+// CHECK: def same_variadic_result(variadic1, non_variadic, variadic2, *, loc=None, ip=None) -> _ods_ir.OpResultList:
+// CHECK: return SameVariadicResultSizeOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip).results
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class SimpleOp(_ods_ir.OpView):
diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py
index 68262822ca6b5..16d4ccaca1b0f 100644
--- a/mlir/test/python/dialects/python_test.py
+++ b/mlir/test/python/dialects/python_test.py
@@ -323,6 +323,7 @@ def resultTypesDefinedByTraits():
# CHECK: f32 index
print(no_infer.single.type, no_infer.doubled.type)
+
# CHECK-LABEL: TEST: testOptionalOperandOp
@run
def testOptionalOperandOp():
@@ -642,6 +643,11 @@ def types(lst):
# CHECK: [IntegerType(i3), IntegerType(i4)]
print(types(op.variadic2))
+ assert isinstance(
+ test.same_variadic_result_vfv([i[0], i[1]], i[2], [i[3], i[4]]),
+ OpResultList,
+ )
+
# Test Variadic-Variadic-Variadic
op = test.SameVariadicResultSizeOpVVV(
[i[0], i[1]], [i[2], i[3]], [i[4], i[5]]
diff --git a/mlir/test/python/ir/auto_location.py b/mlir/test/python/ir/auto_location.py
index 01b5542119b4e..a063aa972cc48 100644
--- a/mlir/test/python/ir/auto_location.py
+++ b/mlir/test/python/ir/auto_location.py
@@ -51,7 +51,7 @@ def testInferLocations():
_cext.globals.register_traceback_file_inclusion(_arith_ops_gen.__file__)
three = arith.constant(IndexType.get(), 3)
# fmt: off
- # CHECK: loc(callsite("ConstantOp.__init__"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]_arith_ops_gen.py":397:4 to :235) at callsite("testInferLocations"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":52:16 to :50) at callsite("run"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":13:4 to :7) at "<module>"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":26:1 to :4)))))
+ # CHECK: loc(callsite("ConstantOp.__init__"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]_arith_ops_gen.py":396:4 to :235) at callsite("testInferLocations"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":52:16 to :50) at callsite("run"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":13:4 to :7) at "<module>"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":26:1 to :4)))))
# fmt: on
print(three.location)
@@ -60,14 +60,14 @@ def foo():
print(four.location)
# fmt: off
- # CHECK: loc(callsite("ConstantOp.__init__"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]_arith_ops_gen.py":397:4 to :235) at callsite("testInferLocations.<locals>.foo"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":59:19 to :53) at callsite("testInferLocations"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":65:8 to :13) at callsite("run"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":13:4 to :7) at "<module>"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":26:1 to :4))))))
+ # CHECK: loc(callsite("ConstantOp.__init__"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]_arith_ops_gen.py":396:4 to :235) at callsite("testInferLocations.<locals>.foo"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":59:19 to :53) at callsite("testInferLocations"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":65:8 to :13) at callsite("run"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":13:4 to :7) at "<module>"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":26:1 to :4))))))
# fmt: on
foo()
_cext.globals.register_traceback_file_exclusion(__file__)
# fmt: off
- # CHECK: loc("ConstantOp.__init__"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]_arith_ops_gen.py":397:4 to :235))
+ # CHECK: loc("ConstantOp.__init__"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]_arith_ops_gen.py":396:4 to :235))
# fmt: on
foo()
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 6a7aa9e3432d5..3de8bd7da2d89 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -36,7 +36,6 @@ from ._ods_common import _cext as _ods_cext
from ._ods_common import (
equally_sized_accessor as _ods_equally_sized_accessor,
get_default_loc_context as _ods_get_default_loc_context,
- get_op_result_or_op_results as _get_op_result_or_op_results,
get_op_results_or_values as _get_op_results_or_values,
segmented_accessor as _ods_segmented_accessor,
)
@@ -275,11 +274,6 @@ def {0}({2}) -> {4}:
return {1}({3}){5}
)Py";
-constexpr const char *valueBuilderVariadicTemplate = R"Py(
-def {0}({2}) -> {4}:
- return _get_op_result_or_op_results({1}({3}))
-)Py";
-
static llvm::cl::OptionCategory
clOpPythonBindingCat("Options for -gen-python-op-bindings");
@@ -1013,25 +1007,19 @@ static void emitValueBuilder(const Operator &op,
nameWithoutDialect += "_";
std::string params = llvm::join(valueBuilderParams, ", ");
std::string args = llvm::join(opBuilderArgs, ", ");
- const char *type =
- (op.getNumResults() > 1
- ? "_Sequence[_ods_ir.Value]"
- : (op.getNumResults() > 0 ? "_ods_ir.Value" : "_ods_ir.Operation"));
- if (op.getNumVariableLengthResults() > 0) {
- os << formatv(valueBuilderVariadicTemplate, nameWithoutDialect,
- op.getCppClassName(), params, args, type);
- } else {
- const char *results;
- if (op.getNumResults() == 0) {
- results = "";
- } else if (op.getNumResults() == 1) {
- results = ".result";
- } else {
- results = ".results";
- }
- os << formatv(valueBuilderTemplate, nameWithoutDialect,
- op.getCppClassName(), params, args, type, results);
- }
+ std::string type =
+ (op.getNumResults() > 1 || op.getNumVariableLengthResults())
+ ? "_ods_ir.OpResultList"
+ : op.getNumResults() == 1
+ ? "_ods_ir.OpResult"
+ : /*op.getNumResults() == 0*/ op.getCppClassName().str();
+ const char *results = "";
+ if (op.getNumResults() > 1 || op.getNumVariableLengthResults())
+ results = ".results";
+ else if (op.getNumResults() == 1)
+ results = ".result";
+ os << formatv(valueBuilderTemplate, nameWithoutDialect, op.getCppClassName(),
+ params, args, type, results);
}
/// Emits bindings for a specific Op to the given output stream.
More information about the Mlir-commits mailing list