[Mlir-commits] [mlir] [mlir][python] value casting (PR #68763)
Maksim Levental
llvmlistbot at llvm.org
Tue Oct 10 22:29:52 PDT 2023
https://github.com/makslevental created https://github.com/llvm/llvm-project/pull/68763
None
>From 5f59ee9ddaccd50cd12ff68c9359e2b9bcbe5bdc Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Wed, 11 Oct 2023 00:28:04 -0500
Subject: [PATCH] [mlir][python] value casting
---
mlir/python/mlir/dialects/_ods_common.py | 71 ++++++++++-
mlir/python/mlir/ir.py | 1 +
mlir/test/mlir-tblgen/op-python-bindings.td | 116 +++++++++++-------
mlir/test/python/dialects/arith_dialect.py | 59 ++++++++-
mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp | 30 ++++-
5 files changed, 222 insertions(+), 55 deletions(-)
diff --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py
index 895c3228139b392..6ab516fb87ec609 100644
--- a/mlir/python/mlir/dialects/_ods_common.py
+++ b/mlir/python/mlir/dialects/_ods_common.py
@@ -1,11 +1,18 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+from collections import defaultdict
# Provide a convenient name for sub-packages to resolve the main C-extension
# with a relative import.
from .._mlir_libs import _mlir as _cext
-from typing import Sequence as _Sequence, Union as _Union
+from typing import (
+ Callable as _Callable,
+ Sequence as _Sequence,
+ Type as _Type,
+ TypeVar as _TypeVar,
+ Union as _Union,
+)
__all__ = [
"equally_sized_accessor",
@@ -182,3 +189,65 @@ def get_op_result_or_op_results(
if len(op.results) > 0
else op
)
+
+
+U = _TypeVar("U", bound=_cext.ir.Value)
+SubClassValueT = _Type[U]
+
+TypeCasterT = _Callable[
+ [_Union[_cext.ir.Value, _cext.ir.OpResult]], _Union[SubClassValueT, None]
+]
+
+__VALUE_CASTERS: defaultdict[
+ _cext.ir.TypeID,
+ _Sequence[TypeCasterT],
+] = defaultdict(list)
+
+
+def register_value_caster(typeid: _cext.ir.TypeID, priority: int = None):
+ def wrapper(caster: TypeCasterT):
+ if not isinstance(typeid, _cext.ir.TypeID):
+ raise ValueError(f"{typeid=} is not a TypeID")
+ if priority is None:
+ __VALUE_CASTERS[typeid].append(caster)
+ else:
+ __VALUE_CASTERS[typeid].insert(priority, caster)
+ return caster
+
+ return wrapper
+
+
+def has_value_caster(typeid: _cext.ir.TypeID):
+ if not isinstance(typeid, _cext.ir.TypeID):
+ raise ValueError(f"{typeid=} is not a TypeID")
+ if typeid in __VALUE_CASTERS:
+ return True
+ return False
+
+
+def get_value_caster(typeid: _cext.ir.TypeID):
+ if not has_value_caster(typeid):
+ raise ValueError(f"no registered caster for {typeid=}")
+ return __VALUE_CASTERS[typeid]
+
+
+def maybe_cast(
+ val: _Union[
+ _cext.ir.Value,
+ _cext.ir.OpResult,
+ _Sequence[_cext.ir.Value],
+ _Sequence[_cext.ir.OpResult],
+ _cext.ir.Operation,
+ ]
+) -> _Union[SubClassValueT, _Sequence[SubClassValueT], _cext.ir.Operation]:
+ if isinstance(val, (tuple, list)):
+ return tuple(map(maybe_cast, val))
+
+ if not isinstance(val, _cext.ir.Value) and not isinstance(val, _cext.ir.OpResult):
+ return val
+
+ if has_value_caster(val.type.typeid):
+ for caster in get_value_caster(val.type.typeid):
+ if casted := caster(val):
+ return casted
+ return val
diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py
index 43553f3118a51fc..f7f6b54919ec8df 100644
--- a/mlir/python/mlir/ir.py
+++ b/mlir/python/mlir/ir.py
@@ -5,6 +5,7 @@
from ._mlir_libs._mlir.ir import *
from ._mlir_libs._mlir.ir import _GlobalDebug
from ._mlir_libs._mlir import register_type_caster
+from .dialects._ods_common import register_value_caster as register_value_caster
# Convenience decorator for registering user-friendly Attribute builders.
diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td
index 8ca23fa9f45c4ab..8ea4865ad787048 100644
--- a/mlir/test/mlir-tblgen/op-python-bindings.td
+++ b/mlir/test/mlir-tblgen/op-python-bindings.td
@@ -60,8 +60,9 @@ def AttrSizedOperandsOp : TestOp<"attr_sized_operands",
Optional<AnyType>:$variadic2);
}
-// CHECK: def attr_sized_operands(variadic1, non_variadic, *, variadic2=None, loc=None, ip=None)
-// CHECK: return _get_op_result_or_op_results(AttrSizedOperandsOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip))
+// CHECK: def attr_sized_operands(variadic1, non_variadic, *, variadic2=None, loc=None, ip=None) -> _ods_ir.Operation:
+// CHECK: op = AttrSizedOperandsOp.__base__ if len(AttrSizedOperandsOp.__bases__) > 1 else AttrSizedOperandsOp
+// CHECK: return op(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip)
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class AttrSizedResultsOp(_ods_ir.OpView):
@@ -107,9 +108,9 @@ def AttrSizedResultsOp : TestOp<"attr_sized_results",
Variadic<AnyType>:$variadic2);
}
-// CHECK: def attr_sized_results(variadic1, non_variadic, variadic2, *, loc=None, ip=None)
-// CHECK: return _get_op_result_or_op_results(AttrSizedResultsOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip))
-
+// CHECK: def attr_sized_results(variadic1, non_variadic, variadic2, *, loc=None, ip=None) -> _Sequence[_SubClassValueT]:
+// CHECK: op = AttrSizedResultsOp.__base__ if len(AttrSizedResultsOp.__bases__) > 1 else AttrSizedResultsOp
+// CHECK: return _maybe_cast(_get_op_result_or_op_results(op(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip)))
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class AttributedOp(_ods_ir.OpView):
@@ -157,8 +158,9 @@ def AttributedOp : TestOp<"attributed_op"> {
UnitAttr:$unitAttr, I32Attr:$in);
}
-// CHECK: def attributed_op(i32attr, in_, *, optional_f32_attr=None, unit_attr=None, loc=None, ip=None)
-// CHECK: return _get_op_result_or_op_results(AttributedOp(i32attr=i32attr, in_=in_, optionalF32Attr=optional_f32_attr, unitAttr=unit_attr, loc=loc, ip=ip))
+// CHECK: def attributed_op(i32attr, in_, *, optional_f32_attr=None, unit_attr=None, loc=None, ip=None) -> _ods_ir.Operation:
+// CHECK: op = AttributedOp.__base__ if len(AttributedOp.__bases__) > 1 else AttributedOp
+// CHECK: return op(i32attr=i32attr, in_=in_, optionalF32Attr=optional_f32_attr, unitAttr=unit_attr, loc=loc, ip=ip)
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class AttributedOpWithOperands(_ods_ir.OpView):
@@ -193,8 +195,9 @@ def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> {
let arguments = (ins I32, UnitAttr:$in, F32, OptionalAttr<F32Attr>:$is);
}
-// CHECK: def attributed_op_with_operands(_gen_arg_0, _gen_arg_2, *, in_=None, is_=None, loc=None, ip=None)
-// CHECK: return _get_op_result_or_op_results(AttributedOpWithOperands(_gen_arg_0=_gen_arg_0, _gen_arg_2=_gen_arg_2, in_=in_, is_=is_, loc=loc, ip=ip))
+// CHECK: def attributed_op_with_operands(_gen_arg_0, _gen_arg_2, *, in_=None, is_=None, loc=None, ip=None) -> _ods_ir.Operation:
+// CHECK: op = AttributedOpWithOperands.__base__ if len(AttributedOpWithOperands.__bases__) > 1 else AttributedOpWithOperands
+// CHECK: return op(_gen_arg_0=_gen_arg_0, _gen_arg_2=_gen_arg_2, in_=in_, is_=is_, loc=loc, ip=ip)
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class DefaultValuedAttrsOp(_ods_ir.OpView):
@@ -217,8 +220,9 @@ def DefaultValuedAttrsOp : TestOp<"default_valued_attrs"> {
let results = (outs);
}
-// CHECK: def default_valued_attrs(*, arr=None, unsupported=None, loc=None, ip=None)
-// CHECK: return _get_op_result_or_op_results(DefaultValuedAttrsOp(arr=arr, unsupported=unsupported, loc=loc, ip=ip))
+// CHECK: def default_valued_attrs(*, arr=None, unsupported=None, loc=None, ip=None) -> _ods_ir.Operation:
+// CHECK: op = DefaultValuedAttrsOp.__base__ if len(DefaultValuedAttrsOp.__bases__) > 1 else DefaultValuedAttrsOp
+// CHECK: return op(arr=arr, unsupported=unsupported, loc=loc, ip=ip)
// CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_op"
def DeriveResultTypesOp : TestOp<"derive_result_types_op", [FirstAttrDerivedResultType]> {
@@ -235,8 +239,9 @@ def DeriveResultTypesOp : TestOp<"derive_result_types_op", [FirstAttrDerivedResu
let results = (outs AnyType:$res, AnyType);
}
-// CHECK: def derive_result_types_op(type_, *, loc=None, ip=None)
-// CHECK: return _get_op_result_or_op_results(DeriveResultTypesOp(type_=type_, loc=loc, ip=ip))
+// CHECK: def derive_result_types_op(type_, *, loc=None, ip=None) -> _Sequence[_SubClassValueT]:
+// CHECK: op = DeriveResultTypesOp.__base__ if len(DeriveResultTypesOp.__bases__) > 1 else DeriveResultTypesOp
+// CHECK: return _maybe_cast(_get_op_result_or_op_results(op(type_=type_, loc=loc, ip=ip)))
// CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_variadic_op"
def DeriveResultTypesVariadicOp : TestOp<"derive_result_types_variadic_op", [FirstAttrDerivedResultType]> {
@@ -245,8 +250,9 @@ def DeriveResultTypesVariadicOp : TestOp<"derive_result_types_variadic_op", [Fir
let results = (outs AnyType:$res, Variadic<AnyType>);
}
-// CHECK: def derive_result_types_variadic_op(res, _gen_res_1, type_, *, loc=None, ip=None)
-// CHECK: return _get_op_result_or_op_results(DeriveResultTypesVariadicOp(res=res, _gen_res_1=_gen_res_1, type_=type_, loc=loc, ip=ip))
+// CHECK: def derive_result_types_variadic_op(res, _gen_res_1, type_, *, loc=None, ip=None) -> _Sequence[_SubClassValueT]:
+// CHECK: op = DeriveResultTypesVariadicOp.__base__ if len(DeriveResultTypesVariadicOp.__bases__) > 1 else DeriveResultTypesVariadicOp
+// CHECK: return _maybe_cast(_get_op_result_or_op_results(op(res=res, _gen_res_1=_gen_res_1, type_=type_, loc=loc, ip=ip)))
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class EmptyOp(_ods_ir.OpView):
@@ -262,8 +268,9 @@ def EmptyOp : TestOp<"empty">;
// CHECK: attributes=attributes, results=results, operands=operands,
// CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip))
-// CHECK: def empty(*, loc=None, ip=None)
-// CHECK: return _get_op_result_or_op_results(EmptyOp(loc=loc, ip=ip))
+// CHECK: def empty(*, loc=None, ip=None) -> _ods_ir.Operation:
+// CHECK: op = EmptyOp.__base__ if len(EmptyOp.__bases__) > 1 else EmptyOp
+// CHECK: return op(loc=loc, ip=ip)
// CHECK-LABEL: OPERATION_NAME = "test.infer_result_types_implied_op"
def InferResultTypesImpliedOp : TestOp<"infer_result_types_implied_op"> {
@@ -275,8 +282,9 @@ def InferResultTypesImpliedOp : TestOp<"infer_result_types_implied_op"> {
let results = (outs I32:$i32, F32:$f32);
}
-// CHECK: def infer_result_types_implied_op(*, loc=None, ip=None)
-// CHECK: return _get_op_result_or_op_results(InferResultTypesImpliedOp(loc=loc, ip=ip))
+// CHECK: def infer_result_types_implied_op(*, loc=None, ip=None) -> _Sequence[_SubClassValueT]:
+// CHECK: op = InferResultTypesImpliedOp.__base__ if len(InferResultTypesImpliedOp.__bases__) > 1 else InferResultTypesImpliedOp
+// CHECK: return _maybe_cast(_get_op_result_or_op_results(op(loc=loc, ip=ip)))
// CHECK-LABEL: OPERATION_NAME = "test.infer_result_types_op"
def InferResultTypesOp : TestOp<"infer_result_types_op", [InferTypeOpInterface]> {
@@ -288,8 +296,9 @@ def InferResultTypesOp : TestOp<"infer_result_types_op", [InferTypeOpInterface]>
let results = (outs AnyType, AnyType, AnyType);
}
-// CHECK: def infer_result_types_op(*, loc=None, ip=None)
-// CHECK: return _get_op_result_or_op_results(InferResultTypesOp(loc=loc, ip=ip))
+// CHECK: def infer_result_types_op(*, loc=None, ip=None) -> _Sequence[_SubClassValueT]:
+// CHECK: op = InferResultTypesOp.__base__ if len(InferResultTypesOp.__bases__) > 1 else InferResultTypesOp
+// CHECK: return _maybe_cast(_get_op_result_or_op_results(op(loc=loc, ip=ip)))
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class MissingNamesOp(_ods_ir.OpView):
@@ -326,8 +335,9 @@ def MissingNamesOp : TestOp<"missing_names"> {
let results = (outs I32:$i32, AnyFloat, I64:$i64);
}
-// CHECK: def missing_names(i32, _gen_res_1, i64, _gen_arg_0, f32, _gen_arg_2, *, loc=None, ip=None)
-// CHECK: return _get_op_result_or_op_results(MissingNamesOp(i32=i32, _gen_res_1=_gen_res_1, i64=i64, _gen_arg_0=_gen_arg_0, f32=f32, _gen_arg_2=_gen_arg_2, loc=loc, ip=ip))
+// CHECK: def missing_names(i32, _gen_res_1, i64, _gen_arg_0, f32, _gen_arg_2, *, loc=None, ip=None) -> _Sequence[_SubClassValueT]:
+// CHECK: op = MissingNamesOp.__base__ if len(MissingNamesOp.__bases__) > 1 else MissingNamesOp
+// CHECK: return _maybe_cast(_get_op_result_or_op_results(op(i32=i32, _gen_res_1=_gen_res_1, i64=i64, _gen_arg_0=_gen_arg_0, f32=f32, _gen_arg_2=_gen_arg_2, loc=loc, ip=ip)))
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class OneOptionalOperandOp(_ods_ir.OpView):
@@ -357,8 +367,9 @@ def OneOptionalOperandOp : TestOp<"one_optional_operand"> {
// CHECK: return None if len(self.operation.operands) < 2 else self.operation.operands[1]
}
-// CHECK: def one_optional_operand(non_optional, *, optional=None, loc=None, ip=None)
-// CHECK: return _get_op_result_or_op_results(OneOptionalOperandOp(non_optional=non_optional, optional=optional, loc=loc, ip=ip))
+// CHECK: def one_optional_operand(non_optional, *, optional=None, loc=None, ip=None) -> _ods_ir.Operation:
+// CHECK: op = OneOptionalOperandOp.__base__ if len(OneOptionalOperandOp.__bases__) > 1 else OneOptionalOperandOp
+// CHECK: return op(non_optional=non_optional, optional=optional, loc=loc, ip=ip)
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class OneVariadicOperandOp(_ods_ir.OpView):
@@ -389,8 +400,9 @@ def OneVariadicOperandOp : TestOp<"one_variadic_operand"> {
let arguments = (ins AnyType:$non_variadic, Variadic<AnyType>:$variadic);
}
-// CHECK: def one_variadic_operand(non_variadic, variadic, *, loc=None, ip=None)
-// CHECK: return _get_op_result_or_op_results(OneVariadicOperandOp(non_variadic=non_variadic, variadic=variadic, loc=loc, ip=ip))
+// CHECK: def one_variadic_operand(non_variadic, variadic, *, loc=None, ip=None) -> _ods_ir.Operation:
+// CHECK: op = OneVariadicOperandOp.__base__ if len(OneVariadicOperandOp.__bases__) > 1 else OneVariadicOperandOp
+// CHECK: return op(non_variadic=non_variadic, variadic=variadic, loc=loc, ip=ip)
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class OneVariadicResultOp(_ods_ir.OpView):
@@ -422,8 +434,9 @@ def OneVariadicResultOp : TestOp<"one_variadic_result"> {
let results = (outs Variadic<AnyType>:$variadic, AnyType:$non_variadic);
}
-// CHECK: def one_variadic_result(variadic, non_variadic, *, loc=None, ip=None)
-// CHECK: return _get_op_result_or_op_results(OneVariadicResultOp(variadic=variadic, non_variadic=non_variadic, loc=loc, ip=ip))
+// CHECK: def one_variadic_result(variadic, non_variadic, *, loc=None, ip=None) -> _Sequence[_SubClassValueT]:
+// CHECK: op = OneVariadicResultOp.__base__ if len(OneVariadicResultOp.__bases__) > 1 else OneVariadicResultOp
+// CHECK: return _maybe_cast(_get_op_result_or_op_results(op(variadic=variadic, non_variadic=non_variadic, loc=loc, ip=ip)))
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class PythonKeywordOp(_ods_ir.OpView):
@@ -446,8 +459,9 @@ def PythonKeywordOp : TestOp<"python_keyword"> {
let arguments = (ins AnyType:$in);
}
-// CHECK: def python_keyword(in_, *, loc=None, ip=None)
-// CHECK: return _get_op_result_or_op_results(PythonKeywordOp(in_=in_, loc=loc, ip=ip))
+// CHECK: def python_keyword(in_, *, loc=None, ip=None) -> _ods_ir.Operation:
+// CHECK: op = PythonKeywordOp.__base__ if len(PythonKeywordOp.__bases__) > 1 else PythonKeywordOp
+// CHECK: return op(in_=in_, loc=loc, ip=ip)
// CHECK-LABEL: OPERATION_NAME = "test.same_results"
def SameResultsOp : TestOp<"same_results", [SameOperandsAndResultType]> {
@@ -460,8 +474,9 @@ def SameResultsOp : TestOp<"same_results", [SameOperandsAndResultType]> {
let results = (outs AnyType:$res);
}
-// CHECK: def same_results(in1, in2, *, loc=None, ip=None)
-// CHECK: return _get_op_result_or_op_results(SameResultsOp(in1=in1, in2=in2, loc=loc, ip=ip))
+// CHECK: def same_results(in1, in2, *, loc=None, ip=None) -> _SubClassValueT:
+// CHECK: op = SameResultsOp.__base__ if len(SameResultsOp.__bases__) > 1 else SameResultsOp
+// CHECK: return _maybe_cast(_get_op_result_or_op_results(op(in1=in1, in2=in2, loc=loc, ip=ip)))
// CHECK-LABEL: OPERATION_NAME = "test.same_results_variadic"
def SameResultsVariadicOp : TestOp<"same_results_variadic", [SameOperandsAndResultType]> {
@@ -470,8 +485,9 @@ def SameResultsVariadicOp : TestOp<"same_results_variadic", [SameOperandsAndResu
let results = (outs Variadic<AnyType>:$res);
}
-// CHECK: def same_results_variadic(res, in1, in2, *, loc=None, ip=None)
-// CHECK: return _get_op_result_or_op_results(SameResultsVariadicOp(res=res, in1=in1, in2=in2, loc=loc, ip=ip))
+// CHECK: def same_results_variadic(res, in1, in2, *, loc=None, ip=None) -> _SubClassValueT:
+// CHECK: op = SameResultsVariadicOp.__base__ if len(SameResultsVariadicOp.__bases__) > 1 else SameResultsVariadicOp
+// CHECK: return _maybe_cast(_get_op_result_or_op_results(op(res=res, in1=in1, in2=in2, loc=loc, ip=ip)))
// CHECK: @_ods_cext.register_operation(_Dialect)
@@ -497,8 +513,9 @@ def SameVariadicOperandSizeOp : TestOp<"same_variadic_operand",
Variadic<AnyType>:$variadic2);
}
-// CHECK: def same_variadic_operand(variadic1, non_variadic, variadic2, *, loc=None, ip=None)
-// CHECK: return _get_op_result_or_op_results(SameVariadicOperandSizeOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip))
+// CHECK: def same_variadic_operand(variadic1, non_variadic, variadic2, *, loc=None, ip=None) -> _ods_ir.Operation:
+// CHECK: op = SameVariadicOperandSizeOp.__base__ if len(SameVariadicOperandSizeOp.__bases__) > 1 else SameVariadicOperandSizeOp
+// CHECK: return op(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip)
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class SameVariadicResultSizeOp(_ods_ir.OpView):
@@ -523,8 +540,9 @@ def SameVariadicResultSizeOp : TestOp<"same_variadic_result",
Variadic<AnyType>:$variadic2);
}
-// CHECK: def same_variadic_result(variadic1, non_variadic, variadic2, *, loc=None, ip=None)
-// CHECK: return _get_op_result_or_op_results(SameVariadicResultSizeOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip))
+// CHECK: def same_variadic_result(variadic1, non_variadic, variadic2, *, loc=None, ip=None) -> _Sequence[_SubClassValueT]:
+// CHECK: op = SameVariadicResultSizeOp.__base__ if len(SameVariadicResultSizeOp.__bases__) > 1 else SameVariadicResultSizeOp
+// CHECK: return _maybe_cast(_get_op_result_or_op_results(op(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip)))
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class SimpleOp(_ods_ir.OpView):
@@ -563,8 +581,9 @@ def SimpleOp : TestOp<"simple"> {
let results = (outs I64:$i64, AnyFloat:$f64);
}
-// CHECK: def simple(i64, f64, i32, f32, *, loc=None, ip=None)
-// CHECK: return _get_op_result_or_op_results(SimpleOp(i64=i64, f64=f64, i32=i32, f32=f32, loc=loc, ip=ip))
+// CHECK: def simple(i64, f64, i32, f32, *, loc=None, ip=None) -> _Sequence[_SubClassValueT]:
+// CHECK: op = SimpleOp.__base__ if len(SimpleOp.__bases__) > 1 else SimpleOp
+// CHECK: return _maybe_cast(_get_op_result_or_op_results(op(i64=i64, f64=f64, i32=i32, f32=f32, loc=loc, ip=ip)))
// CHECK: class VariadicAndNormalRegionOp(_ods_ir.OpView):
// CHECK-LABEL: OPERATION_NAME = "test.variadic_and_normal_region"
@@ -590,8 +609,9 @@ def VariadicAndNormalRegionOp : TestOp<"variadic_and_normal_region"> {
// CHECK: return self.regions[2:]
}
-// CHECK: def variadic_and_normal_region(num_variadic, *, loc=None, ip=None)
-// CHECK: return _get_op_result_or_op_results(VariadicAndNormalRegionOp(num_variadic=num_variadic, loc=loc, ip=ip))
+// CHECK: def variadic_and_normal_region(num_variadic, *, loc=None, ip=None) -> _ods_ir.Operation:
+// CHECK: op = VariadicAndNormalRegionOp.__base__ if len(VariadicAndNormalRegionOp.__bases__) > 1 else VariadicAndNormalRegionOp
+// CHECK: return op(num_variadic=num_variadic, loc=loc, ip=ip)
// CHECK: class VariadicRegionOp(_ods_ir.OpView):
// CHECK-LABEL: OPERATION_NAME = "test.variadic_region"
@@ -613,8 +633,9 @@ def VariadicRegionOp : TestOp<"variadic_region"> {
// CHECK: return self.regions[0:]
}
-// CHECK: def variadic_region(num_variadic, *, loc=None, ip=None)
-// CHECK: return _get_op_result_or_op_results(VariadicRegionOp(num_variadic=num_variadic, loc=loc, ip=ip))
+// CHECK: def variadic_region(num_variadic, *, loc=None, ip=None) -> _ods_ir.Operation:
+// CHECK: op = VariadicRegionOp.__base__ if len(VariadicRegionOp.__bases__) > 1 else VariadicRegionOp
+// CHECK: return op(num_variadic=num_variadic, loc=loc, ip=ip)
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class WithSuccessorsOp(_ods_ir.OpView):
@@ -628,5 +649,6 @@ def WithSuccessorsOp : TestOp<"with_successors"> {
VariadicSuccessor<AnySuccessor>:$successors);
}
-// CHECK: def with_successors(successor, successors, *, loc=None, ip=None)
-// CHECK: return _get_op_result_or_op_results(WithSuccessorsOp(successor=successor, successors=successors, loc=loc, ip=ip))
\ No newline at end of file
+// CHECK: def with_successors(successor, successors, *, loc=None, ip=None) -> _ods_ir.Operation:
+// CHECK: op = WithSuccessorsOp.__base__ if len(WithSuccessorsOp.__bases__) > 1 else WithSuccessorsOp
+// CHECK: return op(successor=successor, successors=successors, loc=loc, ip=ip)
diff --git a/mlir/test/python/dialects/arith_dialect.py b/mlir/test/python/dialects/arith_dialect.py
index f4a793aee4aa14c..2a43f27abe5d304 100644
--- a/mlir/test/python/dialects/arith_dialect.py
+++ b/mlir/test/python/dialects/arith_dialect.py
@@ -1,8 +1,9 @@
# RUN: %PYTHON %s | FileCheck %s
+from functools import partialmethod
from mlir.ir import *
-import mlir.dialects.func as func
import mlir.dialects.arith as arith
+import mlir.dialects._arith_ops_ext as arith_ext
def run(f):
@@ -33,3 +34,59 @@ def testFastMathFlags():
)
# CHECK: %0 = arith.addf %cst, %cst fastmath<nnan,ninf> : f32
print(r)
+
+
+# CHECK-LABEL: TEST: testArithValue
+ at run
+def testArithValue():
+ def _binary_op(lhs, rhs, op: str):
+ op = op.capitalize()
+ if arith_ext._is_float_type(lhs.type):
+ op += "F"
+ elif arith_ext._is_integer_like_type(lhs.type):
+ op += "I"
+ else:
+ raise NotImplementedError(f"Unsupported '{op}' operands: {lhs}, {rhs}")
+
+ op = getattr(arith, f"{op}Op")
+ return ArithValue(op(lhs, rhs).result)
+
+ @register_value_caster(F16Type.static_typeid)
+ @register_value_caster(F32Type.static_typeid)
+ @register_value_caster(F64Type.static_typeid)
+ @register_value_caster(IntegerType.static_typeid)
+ class ArithValue(Value):
+ __add__ = partialmethod(_binary_op, op="add")
+ __sub__ = partialmethod(_binary_op, op="sub")
+ __mul__ = partialmethod(_binary_op, op="mul")
+
+ def __str__(self):
+ return super().__str__().replace("Value", "ArithValue")
+
+ with Context() as ctx, Location.unknown():
+ module = Module.create()
+ f16_t = F16Type.get()
+ f32_t = F32Type.get()
+ f64_t = F64Type.get()
+ i32 = IntegerType.get_signless(32)
+
+ with InsertionPoint(module.body):
+ a = arith.constant(value=FloatAttr.get(f16_t, 42.42))
+ b = a + a
+ # CHECK: ArithValue(%0 = arith.addf %cst, %cst : f16)
+ print(b)
+
+ a = arith.constant(value=FloatAttr.get(f32_t, 42.42))
+ b = a - a
+ # CHECK: ArithValue(%1 = arith.subf %cst_0, %cst_0 : f32)
+ print(b)
+
+ a = arith.constant(value=FloatAttr.get(f64_t, 42.42))
+ b = a * a
+ # CHECK: ArithValue(%2 = arith.mulf %cst_1, %cst_1 : f64)
+ print(b)
+
+ a = arith.constant(value=IntegerAttr.get(i32, 1))
+ b = a * a
+ # CHECK: ArithValue(%3 = arith.muli %c1_i32, %c1_i32 : i32)
+ print(b)
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 2c81538b7b40433..ec2a4770b130a33 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -30,7 +30,17 @@ constexpr const char *fileHeader = R"Py(
# Autogenerated by mlir-tblgen; don't manually edit.
from ._ods_common import _cext as _ods_cext
-from ._ods_common import extend_opview_class as _ods_extend_opview_class, segmented_accessor as _ods_segmented_accessor, equally_sized_accessor as _ods_equally_sized_accessor, get_default_loc_context as _ods_get_default_loc_context, get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values, get_op_result_or_op_results as _get_op_result_or_op_results
+from ._ods_common import (
+ SubClassValueT as _SubClassValueT,
+ equally_sized_accessor as _ods_equally_sized_accessor,
+ extend_opview_class as _ods_extend_opview_class,
+ get_default_loc_context as _ods_get_default_loc_context,
+ get_op_result_or_op_results as _get_op_result_or_op_results,
+ get_op_result_or_value as _get_op_result_or_value,
+ get_op_results_or_values as _get_op_results_or_values,
+ maybe_cast as _maybe_cast,
+ segmented_accessor as _ods_segmented_accessor,
+)
_ods_ir = _ods_cext.ir
try:
@@ -39,7 +49,7 @@ except ImportError:
_ods_ext_module = None
import builtins
-from typing import Sequence as _Sequence, Union as _Union
+from typing import Sequence as _Sequence
)Py";
@@ -269,7 +279,14 @@ constexpr const char *regionAccessorTemplate = R"Py(
constexpr const char *valueBuilderTemplate = R"Py(
def {0}({2}) -> {4}:
- return _get_op_result_or_op_results({1}({3}))
+ op = {1}.__base__ if len({1}.__bases__) > 1 else {1}
+ return _maybe_cast(_get_op_result_or_op_results(op({3})))
+)Py";
+// No-result variant (chosen when op.getNumResults() == 0): returns the built op itself, without _maybe_cast. NOTE(review): in both templates the generated local `op` would shadow a builder parameter that is also named `op`; consider a reserved name such as `_ods_op_cls`.
+constexpr const char *valueBuilderNoResultsTemplate = R"Py(
+def {0}({2}) -> {4}:
+ op = {1}.__base__ if len({1}.__bases__) > 1 else {1}
+ return op({3})
)Py";
static llvm::cl::OptionCategory
@@ -1009,14 +1026,15 @@ static void emitValueBuilder(const Operator &op,
return (lhs + "=" + llvm::convertToSnakeFromCamelCase(lhs)).str();
});
os << llvm::formatv(
- valueBuilderTemplate,
+ op.getNumResults() > 0 ? valueBuilderTemplate
+ : valueBuilderNoResultsTemplate,
// Drop dialect name and then sanitize again (to catch e.g. func.return).
sanitizeName(llvm::join(++splitName.begin(), splitName.end(), "_")),
op.getCppClassName(), llvm::join(valueBuilderParams, ", "),
llvm::join(opBuilderArgs, ", "),
(op.getNumResults() > 1
- ? "_Sequence[_ods_ir.OpResult]"
- : (op.getNumResults() > 0 ? "_ods_ir.OpResult"
+ ? "_Sequence[_SubClassValueT]"
+ : (op.getNumResults() > 0 ? "_SubClassValueT"
: "_ods_ir.Operation")));
}
More information about the Mlir-commits
mailing list