[Mlir-commits] [mlir] [mlir][python] fix value builders (PR #68764)

Maksim Levental llvmlistbot at llvm.org
Tue Oct 10 22:47:13 PDT 2023


https://github.com/makslevental created https://github.com/llvm/llvm-project/pull/68764

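I goofed and didn't correctly anticipate the interplay between generated builders and mixin classes (in the `_ext.py` files). This fixes it by generating a check for `__has_mixin__` and branching accordingly in the generated value builder: when the op class carries a mixin, the builder constructs the op through `__base__` instead of the mixed-in class. Value builders for ops with no results now return the constructed op directly instead of passing it through `_get_op_result_or_op_results`.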

>From c53588976c4d8adb58d5a03884db94b9576a2689 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Wed, 11 Oct 2023 00:39:40 -0500
Subject: [PATCH] [mlir][python] fix value builders

---
 mlir/python/mlir/dialects/_ods_common.py      |   1 +
 mlir/test/mlir-tblgen/op-python-bindings.td   | 116 +++++++++++-------
 mlir/test/python/dialects/arith_dialect.py    |  13 ++
 mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp |  14 ++-
 4 files changed, 94 insertions(+), 50 deletions(-)

diff --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py
index 895c3228139b392..0b1319c6f3c9366 100644
--- a/mlir/python/mlir/dialects/_ods_common.py
+++ b/mlir/python/mlir/dialects/_ods_common.py
@@ -71,6 +71,7 @@ class LocalOpView(mixin_cls, parent_opview_cls):
             ) from e
         LocalOpView.__name__ = parent_opview_cls.__name__
         LocalOpView.__qualname__ = parent_opview_cls.__qualname__
+        LocalOpView.__has_mixin__ = True
         return LocalOpView
 
     return class_decorator
diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td
index 8ca23fa9f45c4ab..1ab5846a69a2807 100644
--- a/mlir/test/mlir-tblgen/op-python-bindings.td
+++ b/mlir/test/mlir-tblgen/op-python-bindings.td
@@ -60,8 +60,9 @@ def AttrSizedOperandsOp : TestOp<"attr_sized_operands",
                    Optional<AnyType>:$variadic2);
 }
 
-// CHECK: def attr_sized_operands(variadic1, non_variadic, *, variadic2=None, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(AttrSizedOperandsOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip))
+// CHECK: def attr_sized_operands(variadic1, non_variadic, *, variadic2=None, loc=None, ip=None) -> _ods_ir.Operation:
+// CHECK:   op = AttrSizedOperandsOp.__base__ if getattr(AttrSizedOperandsOp, "__has_mixin__", False) else AttrSizedOperandsOp
+// CHECK:   return op(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip)
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class AttrSizedResultsOp(_ods_ir.OpView):
@@ -107,9 +108,9 @@ def AttrSizedResultsOp : TestOp<"attr_sized_results",
                  Variadic<AnyType>:$variadic2);
 }
 
-// CHECK: def attr_sized_results(variadic1, non_variadic, variadic2, *, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(AttrSizedResultsOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip))
-
+// CHECK: def attr_sized_results(variadic1, non_variadic, variadic2, *, loc=None, ip=None) -> _Sequence[_ods_ir.OpResult]:
+// CHECK:   op = AttrSizedResultsOp.__base__ if getattr(AttrSizedResultsOp, "__has_mixin__", False) else AttrSizedResultsOp
+// CHECK:   return _get_op_result_or_op_results(op(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip))
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class AttributedOp(_ods_ir.OpView):
@@ -157,8 +158,9 @@ def AttributedOp : TestOp<"attributed_op"> {
                    UnitAttr:$unitAttr, I32Attr:$in);
 }
 
-// CHECK: def attributed_op(i32attr, in_, *, optional_f32_attr=None, unit_attr=None, loc=None, ip=None)
-// CHECK:     return _get_op_result_or_op_results(AttributedOp(i32attr=i32attr, in_=in_, optionalF32Attr=optional_f32_attr, unitAttr=unit_attr, loc=loc, ip=ip))
+// CHECK: def attributed_op(i32attr, in_, *, optional_f32_attr=None, unit_attr=None, loc=None, ip=None) -> _ods_ir.Operation:
+// CHECK:   op = AttributedOp.__base__ if getattr(AttributedOp, "__has_mixin__", False) else AttributedOp
+// CHECK:   return op(i32attr=i32attr, in_=in_, optionalF32Attr=optional_f32_attr, unitAttr=unit_attr, loc=loc, ip=ip)
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class AttributedOpWithOperands(_ods_ir.OpView):
@@ -193,8 +195,9 @@ def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> {
   let arguments = (ins I32, UnitAttr:$in, F32, OptionalAttr<F32Attr>:$is);
 }
 
-// CHECK: def attributed_op_with_operands(_gen_arg_0, _gen_arg_2, *, in_=None, is_=None, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(AttributedOpWithOperands(_gen_arg_0=_gen_arg_0, _gen_arg_2=_gen_arg_2, in_=in_, is_=is_, loc=loc, ip=ip))
+// CHECK: def attributed_op_with_operands(_gen_arg_0, _gen_arg_2, *, in_=None, is_=None, loc=None, ip=None) -> _ods_ir.Operation:
+// CHECK:   op = AttributedOpWithOperands.__base__ if getattr(AttributedOpWithOperands, "__has_mixin__", False) else AttributedOpWithOperands
+// CHECK:   return op(_gen_arg_0=_gen_arg_0, _gen_arg_2=_gen_arg_2, in_=in_, is_=is_, loc=loc, ip=ip)
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class DefaultValuedAttrsOp(_ods_ir.OpView):
@@ -217,8 +220,9 @@ def DefaultValuedAttrsOp : TestOp<"default_valued_attrs"> {
   let results = (outs);
 }
 
-// CHECK: def default_valued_attrs(*, arr=None, unsupported=None, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(DefaultValuedAttrsOp(arr=arr, unsupported=unsupported, loc=loc, ip=ip))
+// CHECK: def default_valued_attrs(*, arr=None, unsupported=None, loc=None, ip=None) -> _ods_ir.Operation:
+// CHECK:   op = DefaultValuedAttrsOp.__base__ if getattr(DefaultValuedAttrsOp, "__has_mixin__", False) else DefaultValuedAttrsOp
+// CHECK:   return op(arr=arr, unsupported=unsupported, loc=loc, ip=ip)
 
 // CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_op"
 def DeriveResultTypesOp : TestOp<"derive_result_types_op", [FirstAttrDerivedResultType]> {
@@ -235,8 +239,9 @@ def DeriveResultTypesOp : TestOp<"derive_result_types_op", [FirstAttrDerivedResu
   let results = (outs AnyType:$res, AnyType);
 }
 
-// CHECK: def derive_result_types_op(type_, *, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(DeriveResultTypesOp(type_=type_, loc=loc, ip=ip))
+// CHECK: def derive_result_types_op(type_, *, loc=None, ip=None) -> _Sequence[_ods_ir.OpResult]:
+// CHECK:   op = DeriveResultTypesOp.__base__ if getattr(DeriveResultTypesOp, "__has_mixin__", False) else DeriveResultTypesOp
+// CHECK:   return _get_op_result_or_op_results(op(type_=type_, loc=loc, ip=ip))
 
 // CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_variadic_op"
 def DeriveResultTypesVariadicOp : TestOp<"derive_result_types_variadic_op", [FirstAttrDerivedResultType]> {
@@ -245,8 +250,9 @@ def DeriveResultTypesVariadicOp : TestOp<"derive_result_types_variadic_op", [Fir
   let results = (outs AnyType:$res, Variadic<AnyType>);
 }
 
-// CHECK: def derive_result_types_variadic_op(res, _gen_res_1, type_, *, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(DeriveResultTypesVariadicOp(res=res, _gen_res_1=_gen_res_1, type_=type_, loc=loc, ip=ip))
+// CHECK: def derive_result_types_variadic_op(res, _gen_res_1, type_, *, loc=None, ip=None) -> _Sequence[_ods_ir.OpResult]:
+// CHECK:   op = DeriveResultTypesVariadicOp.__base__ if getattr(DeriveResultTypesVariadicOp, "__has_mixin__", False) else DeriveResultTypesVariadicOp
+// CHECK:   return _get_op_result_or_op_results(op(res=res, _gen_res_1=_gen_res_1, type_=type_, loc=loc, ip=ip))
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class EmptyOp(_ods_ir.OpView):
@@ -262,8 +268,9 @@ def EmptyOp : TestOp<"empty">;
   // CHECK:     attributes=attributes, results=results, operands=operands,
   // CHECK:     successors=_ods_successors, regions=regions, loc=loc, ip=ip))
 
-// CHECK: def empty(*, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(EmptyOp(loc=loc, ip=ip))
+// CHECK: def empty(*, loc=None, ip=None) -> _ods_ir.Operation:
+// CHECK:   op = EmptyOp.__base__ if getattr(EmptyOp, "__has_mixin__", False) else EmptyOp
+// CHECK:   return op(loc=loc, ip=ip)
 
 // CHECK-LABEL: OPERATION_NAME = "test.infer_result_types_implied_op"
 def InferResultTypesImpliedOp : TestOp<"infer_result_types_implied_op"> {
@@ -275,8 +282,9 @@ def InferResultTypesImpliedOp : TestOp<"infer_result_types_implied_op"> {
   let results = (outs I32:$i32, F32:$f32);
 }
 
-// CHECK: def infer_result_types_implied_op(*, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(InferResultTypesImpliedOp(loc=loc, ip=ip))
+// CHECK: def infer_result_types_implied_op(*, loc=None, ip=None) -> _Sequence[_ods_ir.OpResult]:
+// CHECK:   op = InferResultTypesImpliedOp.__base__ if getattr(InferResultTypesImpliedOp, "__has_mixin__", False) else InferResultTypesImpliedOp
+// CHECK:   return _get_op_result_or_op_results(op(loc=loc, ip=ip))
 
 // CHECK-LABEL: OPERATION_NAME = "test.infer_result_types_op"
 def InferResultTypesOp : TestOp<"infer_result_types_op", [InferTypeOpInterface]> {
@@ -288,8 +296,9 @@ def InferResultTypesOp : TestOp<"infer_result_types_op", [InferTypeOpInterface]>
   let results = (outs AnyType, AnyType, AnyType);
 }
 
-// CHECK: def infer_result_types_op(*, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(InferResultTypesOp(loc=loc, ip=ip))
+// CHECK: def infer_result_types_op(*, loc=None, ip=None) -> _Sequence[_ods_ir.OpResult]:
+// CHECK:   op = InferResultTypesOp.__base__ if getattr(InferResultTypesOp, "__has_mixin__", False) else InferResultTypesOp
+// CHECK:   return _get_op_result_or_op_results(op(loc=loc, ip=ip))
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class MissingNamesOp(_ods_ir.OpView):
@@ -326,8 +335,9 @@ def MissingNamesOp : TestOp<"missing_names"> {
   let results = (outs I32:$i32, AnyFloat, I64:$i64);
 }
 
-// CHECK: def missing_names(i32, _gen_res_1, i64, _gen_arg_0, f32, _gen_arg_2, *, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(MissingNamesOp(i32=i32, _gen_res_1=_gen_res_1, i64=i64, _gen_arg_0=_gen_arg_0, f32=f32, _gen_arg_2=_gen_arg_2, loc=loc, ip=ip))
+// CHECK: def missing_names(i32, _gen_res_1, i64, _gen_arg_0, f32, _gen_arg_2, *, loc=None, ip=None) -> _Sequence[_ods_ir.OpResult]:
+// CHECK:   op = MissingNamesOp.__base__ if getattr(MissingNamesOp, "__has_mixin__", False) else MissingNamesOp
+// CHECK:   return _get_op_result_or_op_results(op(i32=i32, _gen_res_1=_gen_res_1, i64=i64, _gen_arg_0=_gen_arg_0, f32=f32, _gen_arg_2=_gen_arg_2, loc=loc, ip=ip))
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class OneOptionalOperandOp(_ods_ir.OpView):
@@ -357,8 +367,9 @@ def OneOptionalOperandOp : TestOp<"one_optional_operand"> {
   // CHECK:   return None if len(self.operation.operands) < 2 else self.operation.operands[1]
 }
 
-// CHECK: def one_optional_operand(non_optional, *, optional=None, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(OneOptionalOperandOp(non_optional=non_optional, optional=optional, loc=loc, ip=ip))
+// CHECK: def one_optional_operand(non_optional, *, optional=None, loc=None, ip=None) -> _ods_ir.Operation:
+// CHECK:   op = OneOptionalOperandOp.__base__ if getattr(OneOptionalOperandOp, "__has_mixin__", False) else OneOptionalOperandOp
+// CHECK:   return op(non_optional=non_optional, optional=optional, loc=loc, ip=ip)
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class OneVariadicOperandOp(_ods_ir.OpView):
@@ -389,8 +400,9 @@ def OneVariadicOperandOp : TestOp<"one_variadic_operand"> {
   let arguments = (ins AnyType:$non_variadic, Variadic<AnyType>:$variadic);
 }
 
-// CHECK: def one_variadic_operand(non_variadic, variadic, *, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(OneVariadicOperandOp(non_variadic=non_variadic, variadic=variadic, loc=loc, ip=ip))
+// CHECK: def one_variadic_operand(non_variadic, variadic, *, loc=None, ip=None) -> _ods_ir.Operation:
+// CHECK:   op = OneVariadicOperandOp.__base__ if getattr(OneVariadicOperandOp, "__has_mixin__", False) else OneVariadicOperandOp
+// CHECK:   return op(non_variadic=non_variadic, variadic=variadic, loc=loc, ip=ip)
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class OneVariadicResultOp(_ods_ir.OpView):
@@ -422,8 +434,9 @@ def OneVariadicResultOp : TestOp<"one_variadic_result"> {
   let results = (outs Variadic<AnyType>:$variadic, AnyType:$non_variadic);
 }
 
-// CHECK: def one_variadic_result(variadic, non_variadic, *, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(OneVariadicResultOp(variadic=variadic, non_variadic=non_variadic, loc=loc, ip=ip))
+// CHECK: def one_variadic_result(variadic, non_variadic, *, loc=None, ip=None) -> _Sequence[_ods_ir.OpResult]:
+// CHECK:   op = OneVariadicResultOp.__base__ if getattr(OneVariadicResultOp, "__has_mixin__", False) else OneVariadicResultOp
+// CHECK:   return _get_op_result_or_op_results(op(variadic=variadic, non_variadic=non_variadic, loc=loc, ip=ip))
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class PythonKeywordOp(_ods_ir.OpView):
@@ -446,8 +459,9 @@ def PythonKeywordOp : TestOp<"python_keyword"> {
   let arguments = (ins AnyType:$in);
 }
 
-// CHECK: def python_keyword(in_, *, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(PythonKeywordOp(in_=in_, loc=loc, ip=ip))
+// CHECK: def python_keyword(in_, *, loc=None, ip=None) -> _ods_ir.Operation:
+// CHECK:   op = PythonKeywordOp.__base__ if getattr(PythonKeywordOp, "__has_mixin__", False) else PythonKeywordOp
+// CHECK:   return op(in_=in_, loc=loc, ip=ip)
 
 // CHECK-LABEL: OPERATION_NAME = "test.same_results"
 def SameResultsOp : TestOp<"same_results", [SameOperandsAndResultType]> {
@@ -460,8 +474,9 @@ def SameResultsOp : TestOp<"same_results", [SameOperandsAndResultType]> {
   let results = (outs AnyType:$res);
 }
 
-// CHECK: def same_results(in1, in2, *, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(SameResultsOp(in1=in1, in2=in2, loc=loc, ip=ip))
+// CHECK: def same_results(in1, in2, *, loc=None, ip=None) -> _ods_ir.OpResult:
+// CHECK:   op = SameResultsOp.__base__ if getattr(SameResultsOp, "__has_mixin__", False) else SameResultsOp
+// CHECK:   return _get_op_result_or_op_results(op(in1=in1, in2=in2, loc=loc, ip=ip))
 
 // CHECK-LABEL: OPERATION_NAME = "test.same_results_variadic"
 def SameResultsVariadicOp : TestOp<"same_results_variadic", [SameOperandsAndResultType]> {
@@ -470,8 +485,9 @@ def SameResultsVariadicOp : TestOp<"same_results_variadic", [SameOperandsAndResu
   let results = (outs Variadic<AnyType>:$res);
 }
 
-// CHECK: def same_results_variadic(res, in1, in2, *, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(SameResultsVariadicOp(res=res, in1=in1, in2=in2, loc=loc, ip=ip))
+// CHECK: def same_results_variadic(res, in1, in2, *, loc=None, ip=None) -> _ods_ir.OpResult:
+// CHECK:   op = SameResultsVariadicOp.__base__ if getattr(SameResultsVariadicOp, "__has_mixin__", False) else SameResultsVariadicOp
+// CHECK:   return _get_op_result_or_op_results(op(res=res, in1=in1, in2=in2, loc=loc, ip=ip))
 
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
@@ -497,8 +513,9 @@ def SameVariadicOperandSizeOp : TestOp<"same_variadic_operand",
                    Variadic<AnyType>:$variadic2);
 }
 
-// CHECK: def same_variadic_operand(variadic1, non_variadic, variadic2, *, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(SameVariadicOperandSizeOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip))
+// CHECK: def same_variadic_operand(variadic1, non_variadic, variadic2, *, loc=None, ip=None) -> _ods_ir.Operation:
+// CHECK:   op = SameVariadicOperandSizeOp.__base__ if getattr(SameVariadicOperandSizeOp, "__has_mixin__", False) else SameVariadicOperandSizeOp
+// CHECK:   return op(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip)
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class SameVariadicResultSizeOp(_ods_ir.OpView):
@@ -523,8 +540,9 @@ def SameVariadicResultSizeOp : TestOp<"same_variadic_result",
                  Variadic<AnyType>:$variadic2);
 }
 
-// CHECK: def same_variadic_result(variadic1, non_variadic, variadic2, *, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(SameVariadicResultSizeOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip))
+// CHECK: def same_variadic_result(variadic1, non_variadic, variadic2, *, loc=None, ip=None) -> _Sequence[_ods_ir.OpResult]:
+// CHECK:   op = SameVariadicResultSizeOp.__base__ if getattr(SameVariadicResultSizeOp, "__has_mixin__", False) else SameVariadicResultSizeOp
+// CHECK:   return _get_op_result_or_op_results(op(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip))
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class SimpleOp(_ods_ir.OpView):
@@ -563,8 +581,9 @@ def SimpleOp : TestOp<"simple"> {
   let results = (outs I64:$i64, AnyFloat:$f64);
 }
 
-// CHECK: def simple(i64, f64, i32, f32, *, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(SimpleOp(i64=i64, f64=f64, i32=i32, f32=f32, loc=loc, ip=ip))
+// CHECK: def simple(i64, f64, i32, f32, *, loc=None, ip=None) -> _Sequence[_ods_ir.OpResult]:
+// CHECK:   op = SimpleOp.__base__ if getattr(SimpleOp, "__has_mixin__", False) else SimpleOp
+// CHECK:   return _get_op_result_or_op_results(op(i64=i64, f64=f64, i32=i32, f32=f32, loc=loc, ip=ip))
 
 // CHECK: class VariadicAndNormalRegionOp(_ods_ir.OpView):
 // CHECK-LABEL: OPERATION_NAME = "test.variadic_and_normal_region"
@@ -590,8 +609,9 @@ def VariadicAndNormalRegionOp : TestOp<"variadic_and_normal_region"> {
   // CHECK:    return self.regions[2:]
 }
 
-// CHECK: def variadic_and_normal_region(num_variadic, *, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(VariadicAndNormalRegionOp(num_variadic=num_variadic, loc=loc, ip=ip))
+// CHECK: def variadic_and_normal_region(num_variadic, *, loc=None, ip=None) -> _ods_ir.Operation:
+// CHECK:   op = VariadicAndNormalRegionOp.__base__ if getattr(VariadicAndNormalRegionOp, "__has_mixin__", False) else VariadicAndNormalRegionOp
+// CHECK:   return op(num_variadic=num_variadic, loc=loc, ip=ip)
 
 // CHECK: class VariadicRegionOp(_ods_ir.OpView):
 // CHECK-LABEL: OPERATION_NAME = "test.variadic_region"
@@ -613,8 +633,9 @@ def VariadicRegionOp : TestOp<"variadic_region"> {
   // CHECK:    return self.regions[0:]
 }
 
-// CHECK: def variadic_region(num_variadic, *, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(VariadicRegionOp(num_variadic=num_variadic, loc=loc, ip=ip))
+// CHECK: def variadic_region(num_variadic, *, loc=None, ip=None) -> _ods_ir.Operation:
+// CHECK:   op = VariadicRegionOp.__base__ if getattr(VariadicRegionOp, "__has_mixin__", False) else VariadicRegionOp
+// CHECK:   return op(num_variadic=num_variadic, loc=loc, ip=ip)
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class WithSuccessorsOp(_ods_ir.OpView):
@@ -628,5 +649,6 @@ def WithSuccessorsOp : TestOp<"with_successors"> {
                               VariadicSuccessor<AnySuccessor>:$successors);
 }
 
-// CHECK: def with_successors(successor, successors, *, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(WithSuccessorsOp(successor=successor, successors=successors, loc=loc, ip=ip))
\ No newline at end of file
+// CHECK: def with_successors(successor, successors, *, loc=None, ip=None) -> _ods_ir.Operation:
+// CHECK:   op = WithSuccessorsOp.__base__ if getattr(WithSuccessorsOp, "__has_mixin__", False) else WithSuccessorsOp
+// CHECK:   return op(successor=successor, successors=successors, loc=loc, ip=ip)
diff --git a/mlir/test/python/dialects/arith_dialect.py b/mlir/test/python/dialects/arith_dialect.py
index f4a793aee4aa14c..6d1c5eab7589847 100644
--- a/mlir/test/python/dialects/arith_dialect.py
+++ b/mlir/test/python/dialects/arith_dialect.py
@@ -33,3 +33,16 @@ def testFastMathFlags():
             )
             # CHECK: %0 = arith.addf %cst, %cst fastmath<nnan,ninf> : f32
             print(r)
+
+
+# CHECK-LABEL: TEST: testArithValueBuilder
+@run
+def testArithValueBuilder():
+    with Context() as ctx, Location.unknown():
+        module = Module.create()
+        f32_t = F32Type.get()
+
+        with InsertionPoint(module.body):
+            a = arith.constant(value=FloatAttr.get(f32_t, 42.42))
+            # CHECK: %cst = arith.constant 4.242000e+01 : f32
+            print(a)
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 2c81538b7b40433..dda7d3ab466d75f 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -39,7 +39,7 @@ except ImportError:
   _ods_ext_module = None
 
 import builtins
-from typing import Sequence as _Sequence, Union as _Union
+from typing import Sequence as _Sequence
 
 )Py";
 
@@ -269,7 +269,14 @@ constexpr const char *regionAccessorTemplate = R"Py(
 
 constexpr const char *valueBuilderTemplate = R"Py(
 def {0}({2}) -> {4}:
-  return _get_op_result_or_op_results({1}({3}))
+  op = {1}.__base__ if getattr({1}, "__has_mixin__", False) else {1}
+  return _get_op_result_or_op_results(op({3}))
+)Py";
+
+constexpr const char *valueBuilderNoResultsTemplate = R"Py(
+def {0}({2}) -> {4}:
+  op = {1}.__base__ if getattr({1}, "__has_mixin__", False) else {1}
+  return op({3})
 )Py";
 
 static llvm::cl::OptionCategory
@@ -1009,7 +1016,8 @@ static void emitValueBuilder(const Operator &op,
         return (lhs + "=" + llvm::convertToSnakeFromCamelCase(lhs)).str();
       });
   os << llvm::formatv(
-      valueBuilderTemplate,
+      op.getNumResults() > 0 ? valueBuilderTemplate
+                             : valueBuilderNoResultsTemplate,
       // Drop dialect name and then sanitize again (to catch e.g. func.return).
       sanitizeName(llvm::join(++splitName.begin(), splitName.end(), "_")),
       op.getCppClassName(), llvm::join(valueBuilderParams, ", "),



More information about the Mlir-commits mailing list