[Mlir-commits] [mlir] f223fcf - [MLIR][python bindings] Add some AttrBuilder and port _exts to use them.
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Apr 26 15:50:42 PDT 2023
Author: max
Date: 2023-04-26T17:50:10-05:00
New Revision: f223fcf67f00c58f49a759b5d43e82d277a346d7
URL: https://github.com/llvm/llvm-project/commit/f223fcf67f00c58f49a759b5d43e82d277a346d7
DIFF: https://github.com/llvm/llvm-project/commit/f223fcf67f00c58f49a759b5d43e82d277a346d7.diff
LOG: [MLIR][python bindings] Add some AttrBuilder and port _exts to use them.
Differential Revision: https://reviews.llvm.org/D149287
Added:
Modified:
mlir/python/mlir/dialects/_loop_transform_ops_ext.py
mlir/python/mlir/dialects/_pdl_ops_ext.py
mlir/python/mlir/dialects/_structured_transform_ops_ext.py
mlir/python/mlir/dialects/_transform_ops_ext.py
mlir/python/mlir/ir.py
Removed:
################################################################################
diff --git a/mlir/python/mlir/dialects/_loop_transform_ops_ext.py b/mlir/python/mlir/dialects/_loop_transform_ops_ext.py
index 0dc8fc07431a2..a275ea615378e 100644
--- a/mlir/python/mlir/dialects/_loop_transform_ops_ext.py
+++ b/mlir/python/mlir/dialects/_loop_transform_ops_ext.py
@@ -11,65 +11,63 @@
from typing import Optional, Union
-def _get_int64_attr(arg: Optional[Union[int, IntegerAttr]],
- default_value: int = None):
- if isinstance(arg, IntegerAttr):
- return arg
-
- if arg is None:
- assert default_value is not None, "must provide default value"
- arg = default_value
-
- return IntegerAttr.get(IntegerType.get_signless(64), arg)
-
-
class GetParentForOp:
"""Extension for GetParentForOp."""
- def __init__(self,
- result_type: Type,
- target: Union[Operation, Value],
- *,
- num_loops: int = 1,
- ip=None,
- loc=None):
+ def __init__(
+ self,
+ result_type: Type,
+ target: Union[Operation, Value],
+ *,
+ num_loops: Optional[int] = None,
+ ip=None,
+ loc=None,
+ ):
+ if num_loops is None:
+ num_loops = 1
super().__init__(
result_type,
_get_op_result_or_value(target),
- num_loops=_get_int64_attr(num_loops, default_value=1),
+ num_loops=num_loops,
ip=ip,
- loc=loc)
+ loc=loc,
+ )
class LoopOutlineOp:
"""Extension for LoopOutlineOp."""
- def __init__(self,
- result_type: Type,
- target: Union[Operation, Value],
- *,
- func_name: Union[str, StringAttr],
- ip=None,
- loc=None):
+ def __init__(
+ self,
+ result_type: Type,
+ target: Union[Operation, Value],
+ *,
+ func_name: Union[str, StringAttr],
+ ip=None,
+ loc=None,
+ ):
super().__init__(
result_type,
_get_op_result_or_value(target),
func_name=(func_name if isinstance(func_name, StringAttr) else
StringAttr.get(func_name)),
ip=ip,
- loc=loc)
+ loc=loc,
+ )
class LoopPeelOp:
"""Extension for LoopPeelOp."""
- def __init__(self,
- result_type: Type,
- target: Union[Operation, Value],
- *,
- fail_if_already_divisible: Union[bool, BoolAttr] = False,
- ip=None,
- loc=None):
+ def __init__(
+ self,
+ result_type: Type,
+ target: Union[Operation, Value],
+ *,
+ fail_if_already_divisible: Union[bool, BoolAttr] = False,
+ ip=None,
+ loc=None,
+ ):
super().__init__(
result_type,
_get_op_result_or_value(target),
@@ -77,40 +75,51 @@ def __init__(self,
fail_if_already_divisible, BoolAttr) else
BoolAttr.get(fail_if_already_divisible)),
ip=ip,
- loc=loc)
+ loc=loc,
+ )
class LoopPipelineOp:
"""Extension for LoopPipelineOp."""
- def __init__(self,
- result_type: Type,
- target: Union[Operation, Value],
- *,
- iteration_interval: Optional[Union[int, IntegerAttr]] = None,
- read_latency: Optional[Union[int, IntegerAttr]] = None,
- ip=None,
- loc=None):
+ def __init__(
+ self,
+ result_type: Type,
+ target: Union[Operation, Value],
+ *,
+ iteration_interval: Optional[Union[int, IntegerAttr]] = None,
+ read_latency: Optional[Union[int, IntegerAttr]] = None,
+ ip=None,
+ loc=None,
+ ):
+ if iteration_interval is None:
+ iteration_interval = 1
+ if read_latency is None:
+ read_latency = 10
super().__init__(
result_type,
_get_op_result_or_value(target),
- iteration_interval=_get_int64_attr(iteration_interval, default_value=1),
- read_latency=_get_int64_attr(read_latency, default_value=10),
+ iteration_interval=iteration_interval,
+ read_latency=read_latency,
ip=ip,
- loc=loc)
+ loc=loc,
+ )
class LoopUnrollOp:
"""Extension for LoopUnrollOp."""
- def __init__(self,
- target: Union[Operation, Value],
- *,
- factor: Union[int, IntegerAttr],
- ip=None,
- loc=None):
+ def __init__(
+ self,
+ target: Union[Operation, Value],
+ *,
+ factor: Union[int, IntegerAttr],
+ ip=None,
+ loc=None,
+ ):
super().__init__(
_get_op_result_or_value(target),
- factor=_get_int64_attr(factor),
+ factor=factor,
ip=ip,
- loc=loc)
+ loc=loc,
+ )
diff --git a/mlir/python/mlir/dialects/_pdl_ops_ext.py b/mlir/python/mlir/dialects/_pdl_ops_ext.py
index 428301b18f208..40ccbef6351dc 100644
--- a/mlir/python/mlir/dialects/_pdl_ops_ext.py
+++ b/mlir/python/mlir/dialects/_pdl_ops_ext.py
@@ -8,61 +8,26 @@
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
-from typing import Union, Optional, Sequence, List, Mapping
-from ._ods_common import get_op_result_or_value as _get_value, get_op_results_or_values as _get_values
-
-
-def _get_int_attr(bits: int, value: Union[IntegerAttr, int]) -> IntegerAttr:
- """Converts the given value to signless integer attribute of given bit width."""
- if isinstance(value, int):
- ty = IntegerType.get_signless(bits)
- return IntegerAttr.get(ty, value)
- else:
- return value
-
-
-def _get_array_attr(attrs: Union[ArrayAttr, Sequence[Attribute]]) -> ArrayAttr:
- """Converts the given value to array attribute."""
- if isinstance(attrs, ArrayAttr):
- return attrs
- else:
- return ArrayAttr.get(list(attrs))
-
-
-def _get_str_array_attr(attrs: Union[ArrayAttr, Sequence[str]]) -> ArrayAttr:
- """Converts the given value to string array attribute."""
- if isinstance(attrs, ArrayAttr):
- return attrs
- else:
- return ArrayAttr.get([StringAttr.get(s) for s in attrs])
-
-
-def _get_str_attr(name: Union[StringAttr, str]) -> Optional[StringAttr]:
- """Converts the given value to string attribute."""
- if isinstance(name, str):
- return StringAttr.get(name)
- else:
- return name
-
-
-def _get_type_attr(type: Union[TypeAttr, Type]) -> TypeAttr:
- """Converts the given value to type attribute."""
- if isinstance(type, Type):
- return TypeAttr.get(type)
- else:
- return type
+from typing import Union, Optional, Sequence, Mapping
+from ._ods_common import (
+ get_op_result_or_value as _get_value,
+ get_op_results_or_values as _get_values,
+)
class ApplyNativeConstraintOp:
"""Specialization for PDL apply native constraint op class."""
- def __init__(self,
- name: Union[str, StringAttr],
- args: Sequence[Union[OpView, Operation, Value]] = [],
- *,
- loc=None,
- ip=None):
- name = _get_str_attr(name)
+ def __init__(
+ self,
+ name: Union[str, StringAttr],
+ args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
+ *,
+ loc=None,
+ ip=None,
+ ):
+ if args is None:
+ args = []
args = _get_values(args)
super().__init__(name, args, loc=loc, ip=ip)
@@ -70,14 +35,17 @@ def __init__(self,
class ApplyNativeRewriteOp:
"""Specialization for PDL apply native rewrite op class."""
- def __init__(self,
- results: Sequence[Type],
- name: Union[str, StringAttr],
- args: Sequence[Union[OpView, Operation, Value]] = [],
- *,
- loc=None,
- ip=None):
- name = _get_str_attr(name)
+ def __init__(
+ self,
+ results: Sequence[Type],
+ name: Union[str, StringAttr],
+ args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
+ *,
+ loc=None,
+ ip=None,
+ ):
+ if args is None:
+ args = []
args = _get_values(args)
super().__init__(results, name, args, loc=loc, ip=ip)
@@ -85,12 +53,14 @@ def __init__(self,
class AttributeOp:
"""Specialization for PDL attribute op class."""
- def __init__(self,
- valueType: Optional[Union[OpView, Operation, Value]] = None,
- value: Optional[Attribute] = None,
- *,
- loc=None,
- ip=None):
+ def __init__(
+ self,
+ valueType: Optional[Union[OpView, Operation, Value]] = None,
+ value: Optional[Attribute] = None,
+ *,
+ loc=None,
+ ip=None,
+ ):
valueType = valueType if valueType is None else _get_value(valueType)
result = pdl.AttributeType.get()
super().__init__(result, valueType=valueType, value=value, loc=loc, ip=ip)
@@ -99,11 +69,13 @@ def __init__(self,
class EraseOp:
"""Specialization for PDL erase op class."""
- def __init__(self,
- operation: Optional[Union[OpView, Operation, Value]] = None,
- *,
- loc=None,
- ip=None):
+ def __init__(
+ self,
+ operation: Optional[Union[OpView, Operation, Value]] = None,
+ *,
+ loc=None,
+ ip=None,
+ ):
operation = _get_value(operation)
super().__init__(operation, loc=loc, ip=ip)
@@ -111,11 +83,13 @@ def __init__(self,
class OperandOp:
"""Specialization for PDL operand op class."""
- def __init__(self,
- type: Optional[Union[OpView, Operation, Value]] = None,
- *,
- loc=None,
- ip=None):
+ def __init__(
+ self,
+ type: Optional[Union[OpView, Operation, Value]] = None,
+ *,
+ loc=None,
+ ip=None,
+ ):
type = type if type is None else _get_value(type)
result = pdl.ValueType.get()
super().__init__(result, valueType=type, loc=loc, ip=ip)
@@ -124,11 +98,13 @@ def __init__(self,
class OperandsOp:
"""Specialization for PDL operands op class."""
- def __init__(self,
- types: Optional[Union[OpView, Operation, Value]] = None,
- *,
- loc=None,
- ip=None):
+ def __init__(
+ self,
+ types: Optional[Union[OpView, Operation, Value]] = None,
+ *,
+ loc=None,
+ ip=None,
+ ):
types = types if types is None else _get_value(types)
result = pdl.RangeType.get(pdl.ValueType.get())
super().__init__(result, valueType=types, loc=loc, ip=ip)
@@ -137,15 +113,23 @@ def __init__(self,
class OperationOp:
"""Specialization for PDL operand op class."""
- def __init__(self,
- name: Optional[Union[str, StringAttr]] = None,
- args: Sequence[Union[OpView, Operation, Value]] = [],
- attributes: Mapping[str, Union[OpView, Operation, Value]] = {},
- types: Sequence[Union[OpView, Operation, Value]] = [],
- *,
- loc=None,
- ip=None):
- name = name if name is None else _get_str_attr(name)
+ def __init__(
+ self,
+ name: Optional[Union[str, StringAttr]] = None,
+ args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
+ attributes: Optional[Mapping[str, Union[OpView, Operation,
+ Value]]] = None,
+ types: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
+ *,
+ loc=None,
+ ip=None,
+ ):
+ if types is None:
+ types = []
+ if attributes is None:
+ attributes = {}
+ if args is None:
+ args = []
args = _get_values(args)
attrNames = []
attrValues = []
@@ -155,22 +139,29 @@ def __init__(self,
attrNames = ArrayAttr.get(attrNames)
types = _get_values(types)
result = pdl.OperationType.get()
- super().__init__(result, args, attrValues, attrNames, types, opName=name, loc=loc, ip=ip)
+ super().__init__(result,
+ args,
+ attrValues,
+ attrNames,
+ types,
+ opName=name,
+ loc=loc,
+ ip=ip)
class PatternOp:
"""Specialization for PDL pattern op class."""
- def __init__(self,
- benefit: Union[IntegerAttr, int],
- name: Optional[Union[StringAttr, str]] = None,
- *,
- loc=None,
- ip=None):
+ def __init__(
+ self,
+ benefit: Union[IntegerAttr, int],
+ name: Optional[Union[StringAttr, str]] = None,
+ *,
+ loc=None,
+ ip=None,
+ ):
"""Creates an PDL `pattern` operation."""
- name_attr = None if name is None else _get_str_attr(name)
- benefit_attr = _get_int_attr(16, benefit)
- super().__init__(benefit_attr, sym_name=name_attr, loc=loc, ip=ip)
+ super().__init__(benefit, sym_name=name, loc=loc, ip=ip)
self.regions[0].blocks.append()
@property
@@ -182,13 +173,17 @@ def body(self):
class ReplaceOp:
"""Specialization for PDL replace op class."""
- def __init__(self,
- op: Union[OpView, Operation, Value],
- *,
- with_op: Optional[Union[OpView, Operation, Value]] = None,
- with_values: Sequence[Union[OpView, Operation, Value]] = [],
- loc=None,
- ip=None):
+ def __init__(
+ self,
+ op: Union[OpView, Operation, Value],
+ *,
+ with_op: Optional[Union[OpView, Operation, Value]] = None,
+ with_values: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
+ loc=None,
+ ip=None,
+ ):
+ if with_values is None:
+ with_values = []
op = _get_value(op)
with_op = with_op if with_op is None else _get_value(with_op)
with_values = _get_values(with_values)
@@ -198,13 +193,14 @@ def __init__(self,
class ResultOp:
"""Specialization for PDL result op class."""
- def __init__(self,
- parent: Union[OpView, Operation, Value],
- index: Union[IntegerAttr, int],
- *,
- loc=None,
- ip=None):
- index = _get_int_attr(32, index)
+ def __init__(
+ self,
+ parent: Union[OpView, Operation, Value],
+ index: Union[IntegerAttr, int],
+ *,
+ loc=None,
+ ip=None,
+ ):
parent = _get_value(parent)
result = pdl.ValueType.get()
super().__init__(result, parent, index, loc=loc, ip=ip)
@@ -213,32 +209,36 @@ def __init__(self,
class ResultsOp:
"""Specialization for PDL results op class."""
- def __init__(self,
- result: Type,
- parent: Union[OpView, Operation, Value],
- index: Optional[Union[IntegerAttr, int]] = None,
- *,
- loc=None,
- ip=None):
+ def __init__(
+ self,
+ result: Type,
+ parent: Union[OpView, Operation, Value],
+ index: Optional[Union[IntegerAttr, int]] = None,
+ *,
+ loc=None,
+ ip=None,
+ ):
parent = _get_value(parent)
- index = index if index is None else _get_int_attr(32, index)
super().__init__(result, parent, index=index, loc=loc, ip=ip)
class RewriteOp:
"""Specialization for PDL rewrite op class."""
- def __init__(self,
- root: Optional[Union[OpView, Operation, Value]] = None,
- name: Optional[Union[StringAttr, str]] = None,
- args: Sequence[Union[OpView, Operation, Value]] = [],
- *,
- loc=None,
- ip=None):
+ def __init__(
+ self,
+ root: Optional[Union[OpView, Operation, Value]] = None,
+ name: Optional[Union[StringAttr, str]] = None,
+ args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
+ *,
+ loc=None,
+ ip=None,
+ ):
+ if args is None:
+ args = []
root = root if root is None else _get_value(root)
- name = name if name is None else _get_str_attr(name)
args = _get_values(args)
- super().__init__(args, root=root,name=name, loc=loc, ip=ip)
+ super().__init__(args, root=root, name=name, loc=loc, ip=ip)
def add_body(self):
"""Add body (block) to the rewrite."""
@@ -259,8 +259,6 @@ def __init__(self,
*,
loc=None,
ip=None):
- constantType = constantType if constantType is None else _get_type_attr(
- constantType)
result = pdl.TypeType.get()
super().__init__(result, constantType=constantType, loc=loc, ip=ip)
@@ -268,13 +266,14 @@ def __init__(self,
class TypesOp:
"""Specialization for PDL types op class."""
- def __init__(self,
- constantTypes: Sequence[Union[TypeAttr, Type]] = [],
- *,
- loc=None,
- ip=None):
- constantTypes = _get_array_attr(
- [_get_type_attr(ty) for ty in constantTypes])
- constantTypes = None if not constantTypes else constantTypes
+ def __init__(
+ self,
+ constantTypes: Optional[Sequence[Union[TypeAttr, Type]]] = None,
+ *,
+ loc=None,
+ ip=None,
+ ):
+ if constantTypes is None:
+ constantTypes = []
result = pdl.RangeType.get(pdl.TypeType.get())
super().__init__(result, constantTypes=constantTypes, loc=loc, ip=ip)
diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
index e2c262ca50201..9c051cd3d146d 100644
--- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
+++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
@@ -15,180 +15,159 @@
OptionalIntList = Optional[Union[ArrayAttr, IntOrAttrList]]
-def _get_int64_attr(value: Union[int, Attribute]) -> IntegerAttr:
- if isinstance(value, int):
- return IntegerAttr.get(IntegerType.get_signless(64), value)
- return value
-
-
-def _get_array_attr(
- values: Optional[Union[ArrayAttr, Sequence[Attribute]]]) -> ArrayAttr:
- """Creates an array attribute from its operand."""
- if values is None:
- return ArrayAttr.get([])
- if isinstance(values, ArrayAttr):
- return values
-
- return ArrayAttr.get(values)
-
-
-def _get_int_array_attr(
- values: Optional[Union[ArrayAttr, Sequence[Union[IntegerAttr, int]]]]
-) -> ArrayAttr:
- """Creates an integer array attribute from its operand.
-
- If the operand is already an array attribute, forwards it. Otherwise treats
- the operand as a list of attributes or integers, possibly intersperced, to
- create a new array attribute containing integer attributes. Expects the
- thread-local MLIR context to have been set by the context manager.
- """
- if values is None:
- return ArrayAttr.get([])
- if isinstance(values, ArrayAttr):
- return values
-
- return ArrayAttr.get([_get_int64_attr(v) for v in values])
-
-def _get_dense_int64_array_attr(
- values: Sequence[int]) -> DenseI64ArrayAttr:
- """Creates a dense integer array from a sequence of integers.
- Expects the thread-local MLIR context to have been set by the context
- manager.
- """
- if values is None:
- return DenseI64ArrayAttr.get([])
- return DenseI64ArrayAttr.get(values)
-
def _get_int_int_array_attr(
values: Optional[Union[ArrayAttr, Sequence[Union[ArrayAttr,
IntOrAttrList]]]]
) -> ArrayAttr:
"""Creates an array attribute containing array attributes of integers.
- If the operand is already an array attribute, forwards it. Otherwise treats
- the operand as a list of attributes or integers, potentially interpserced, to
- create a new array-of-array attribute. Expects the thread-local MLIR context
- to have been set by the context manager.
- """
+ If the operand is already an array attribute, forwards it. Otherwise treats
+ the operand as a list of attributes or integers, potentially interpserced, to
+ create a new array-of-array attribute. Expects the thread-local MLIR context
+ to have been set by the context manager.
+ """
if values is None:
return ArrayAttr.get([])
if isinstance(values, ArrayAttr):
return values
+ if isinstance(values, list):
+ values = [
+ ArrayAttr.get(
+ [IntegerAttr.get(IntegerType.get_signless(64), v)
+ for v in value])
+ for value in values
+ ]
- return ArrayAttr.get([_get_int_array_attr(value) for value in values])
+ return ArrayAttr.get(values)
class DecomposeOp:
"""Specialization for DecomposeOp class."""
def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
- super().__init__(
- pdl.OperationType.get(),
- _get_op_result_or_value(target),
- loc=loc,
- ip=ip)
+ super().__init__(pdl.OperationType.get(),
+ _get_op_result_or_value(target),
+ loc=loc,
+ ip=ip)
class GeneralizeOp:
"""Specialization for GeneralizeOp class."""
def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
- super().__init__(
- pdl.OperationType.get(),
- _get_op_result_or_value(target),
- loc=loc,
- ip=ip)
+ super().__init__(pdl.OperationType.get(),
+ _get_op_result_or_value(target),
+ loc=loc,
+ ip=ip)
class InterchangeOp:
"""Specialization for InterchangeOp class."""
- def __init__(self,
- target: Union[Operation, Value],
- *,
- iterator_interchange: OptionalIntList = None,
- loc=None,
- ip=None):
+ def __init__(
+ self,
+ target: Union[Operation, Value],
+ *,
+ iterator_interchange: OptionalIntList = None,
+ loc=None,
+ ip=None,
+ ):
pdl_operation_type = pdl.OperationType.get()
- interchange_attr = _get_dense_int64_array_attr(iterator_interchange)
super().__init__(
pdl_operation_type,
_get_op_result_or_value(target),
- iterator_interchange=interchange_attr,
+ iterator_interchange=iterator_interchange,
loc=loc,
- ip=ip)
+ ip=ip,
+ )
class MatchOp:
"""Specialization for MatchOp class."""
@classmethod
- def match_op_names(MatchOp,
- target: Union[Operation, Value],
- names: Sequence[str],
- loc=None,
- ip=None):
+ def match_op_names(
+ MatchOp,
+ target: Union[Operation, Value],
+ names: Sequence[str],
+ loc=None,
+ ip=None,
+ ):
pdl_operation_type = pdl.OperationType.get()
return MatchOp(
pdl_operation_type,
_get_op_result_or_value(target),
ops=ArrayAttr.get(list(map(lambda s: StringAttr.get(s), names))),
loc=loc,
- ip=ip)
+ ip=ip,
+ )
class MultiTileSizesOp:
"""Specialization for MultitileSizesOp class."""
- def __init__(self,
- result_type: Type,
- target: Union[Operation, Value],
- *,
- dimension: Union[int, IntegerAttr],
- target_size: Union[int, IntegerAttr],
- divisor: Optional[Union[int, IntegerAttr]] = None,
- loc=None,
- ip=None):
+ def __init__(
+ self,
+ result_type: Type,
+ target: Union[Operation, Value],
+ *,
+ dimension: Union[int, IntegerAttr],
+ target_size: Union[int, IntegerAttr],
+ divisor: Optional[Optional[Union[int, IntegerAttr]]] = None,
+ loc=None,
+ ip=None,
+ ):
+ if divisor is None:
+ divisor = 1
super().__init__(
result_type,
result_type,
result_type,
_get_op_result_or_value(target),
- dimension=_get_int64_attr(dimension),
- target_size=_get_int64_attr(target_size),
- divisor=_get_int64_attr(divisor if divisor else 1),
+ dimension=dimension,
+ target_size=target_size,
+ divisor=divisor,
loc=loc,
- ip=ip)
+ ip=ip,
+ )
class PadOp:
"""Specialization for PadOp class."""
- def __init__(self,
- target: Union[Operation, Value],
- *,
- padding_values: Optional[Union[ArrayAttr,
- Sequence[Attribute]]] = None,
- padding_dimensions: OptionalIntList = None,
- pack_paddings: OptionalIntList = None,
- transpose_paddings: Optional[Union[ArrayAttr, Sequence[Union[
- ArrayAttr, IntOrAttrList]]]] = None,
- loc=None,
- ip=None):
+ def __init__(
+ self,
+ target: Union[Operation, Value],
+ *,
+ padding_values: Optional[Optional[Union[ArrayAttr,
+ Sequence[Attribute]]]] = None,
+ padding_dimensions: OptionalIntList = None,
+ pack_paddings: OptionalIntList = None,
+ transpose_paddings: Optional[Union[ArrayAttr, Sequence[Union[
+ ArrayAttr, IntOrAttrList]]]] = None,
+ loc=None,
+ ip=None,
+ ):
+ if transpose_paddings is None:
+ transpose_paddings = []
+ if pack_paddings is None:
+ pack_paddings = []
+ if padding_dimensions is None:
+ padding_dimensions = []
+ if padding_values is None:
+ padding_values = []
pdl_operation_type = pdl.OperationType.get()
- padding_values_attr = _get_array_attr(padding_values)
- padding_dimensions_attr = _get_int_array_attr(padding_dimensions)
- pack_paddings_attr = _get_int_array_attr(pack_paddings)
transpose_paddings_attr = _get_int_int_array_attr(transpose_paddings)
super().__init__(
pdl_operation_type,
_get_op_result_or_value(target),
- padding_values=padding_values_attr,
- padding_dimensions=padding_dimensions_attr,
- pack_paddings=pack_paddings_attr,
+ padding_values=padding_values,
+ padding_dimensions=padding_dimensions,
+ pack_paddings=pack_paddings,
transpose_paddings=transpose_paddings_attr,
loc=loc,
- ip=ip)
+ ip=ip,
+ )
class ScalarizeOp:
@@ -196,29 +175,29 @@ class ScalarizeOp:
def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
pdl_operation_type = pdl.OperationType.get()
- super().__init__(
- pdl_operation_type, _get_op_result_or_value(target), loc=loc, ip=ip)
+ super().__init__(pdl_operation_type,
+ _get_op_result_or_value(target),
+ loc=loc,
+ ip=ip)
class SplitOp:
"""Specialization for SplitOp class."""
- def __init__(self,
- target: Union[Operation, Value],
- dimension: Union[int, Attribute],
- split_point: Union[int, Operation, Value, Attribute],
- *,
- loc=None,
- ip=None):
- dimension = _get_int64_attr(dimension)
+ def __init__(
+ self,
+ target: Union[Operation, Value],
+ dimension: Union[int, Attribute],
+ split_point: Union[int, Operation, Value, Attribute],
+ *,
+ loc=None,
+ ip=None,
+ ):
if isinstance(split_point, int):
- split_point = _get_int64_attr(split_point)
-
- if isinstance(split_point, Attribute):
static_split_point = split_point
dynamic_split_point = None
else:
- static_split_point = _get_int64_attr(ShapedType.get_dynamic_size())
+ static_split_point = ShapedType.get_dynamic_size()
dynamic_split_point = _get_op_result_or_value(split_point)
target = _get_op_result_or_value(target)
@@ -231,44 +210,53 @@ def __init__(self,
static_split_point=static_split_point,
dynamic_split_point=dynamic_split_point,
loc=loc,
- ip=ip)
+ ip=ip,
+ )
class TileOp:
"""Specialization for TileOp class."""
@overload
- def __init__(self,
- loop_types: Union[Type, List[Type]],
- target: Union[Operation, Value],
- *,
- sizes: Optional[Union[Sequence[Union[int, IntegerAttr, Operation,
- Value]], ArrayAttr]] = None,
- interchange: OptionalIntList = None,
- loc=None,
- ip=None):
+ def __init__(
+ self,
+ loop_types: Union[Type, List[Type]],
+ target: Union[Operation, Value],
+ *,
+ sizes: Optional[Union[Sequence[Union[int, IntegerAttr, Operation, Value]],
+ ArrayAttr]] = None,
+ interchange: OptionalIntList = None,
+ loc=None,
+ ip=None,
+ ):
...
@overload
- def __init__(self,
- target: Union[Operation, Value, OpView],
- *,
- sizes: Optional[Union[Sequence[Union[int, IntegerAttr, Operation,
- Value]], ArrayAttr]] = None,
- interchange: OptionalIntList = None,
- loc=None,
- ip=None):
+ def __init__(
+ self,
+ target: Union[Operation, Value, OpView],
+ *,
+ sizes: Optional[Union[Sequence[Union[int, IntegerAttr, Operation, Value]],
+ ArrayAttr]] = None,
+ interchange: OptionalIntList = None,
+ loc=None,
+ ip=None,
+ ):
...
- def __init__(self,
- loop_types_or_target: Union[Type, List[Type], Operation, Value],
- target_or_none: Optional[Union[Operation, Value, OpView]] = None,
- *,
- sizes: Optional[Union[Sequence[Union[int, IntegerAttr, Operation,
- Value]], ArrayAttr]] = None,
- interchange: OptionalIntList = None,
- loc=None,
- ip=None):
+ def __init__(
+ self,
+ loop_types_or_target: Union[Type, List[Type], Operation, Value],
+ target_or_none: Optional[Union[Operation, Value, OpView]] = None,
+ *,
+ sizes: Optional[Union[Sequence[Union[int, IntegerAttr, Operation, Value]],
+ ArrayAttr]] = None,
+ interchange: OptionalIntList = None,
+ loc=None,
+ ip=None,
+ ):
+ if interchange is None:
+ interchange = []
if sizes is None:
sizes = []
@@ -293,8 +281,8 @@ def __init__(self,
target = loop_types_or_target
assert target_or_none is None, "Cannot construct TileOp with two targets."
else:
- loop_types = ([loop_types_or_target] * num_loops) if isinstance(
- loop_types_or_target, Type) else loop_types_or_target
+ loop_types = (([loop_types_or_target] * num_loops) if isinstance(
+ loop_types_or_target, Type) else loop_types_or_target)
target = target_or_none
target = _get_op_result_or_value(target)
@@ -305,10 +293,10 @@ def __init__(self,
target,
dynamic_sizes=dynamic_sizes,
static_sizes=sizes_attr,
- interchange=_get_dense_int64_array_attr(interchange)
- if interchange else None,
+ interchange=interchange,
loc=loc,
- ip=ip)
+ ip=ip,
+ )
def __extract_values(self, attr: Optional[DenseI64ArrayAttr]) -> List[int]:
if not attr:
@@ -319,12 +307,14 @@ def __extract_values(self, attr: Optional[DenseI64ArrayAttr]) -> List[int]:
class VectorizeOp:
"""Specialization for VectorizeOp class."""
- def __init__(self,
- target: Union[Operation, Value],
- *,
- vectorize_padding: Union[bool, BoolAttr] = False,
- loc=None,
- ip=None):
+ def __init__(
+ self,
+ target: Union[Operation, Value],
+ *,
+ vectorize_padding: Union[bool, BoolAttr] = False,
+ loc=None,
+ ip=None,
+ ):
pdl_operation_type = pdl.OperationType.get()
if isinstance(vectorize_padding, bool):
vectorize_padding = UnitAttr.get()
@@ -333,4 +323,5 @@ def __init__(self,
_get_op_result_or_value(target),
vectorize_padding=vectorize_padding,
loc=loc,
- ip=ip)
+ ip=ip,
+ )
diff --git a/mlir/python/mlir/dialects/_transform_ops_ext.py b/mlir/python/mlir/dialects/_transform_ops_ext.py
index 593b8855c935f..8651c76ea7dfc 100644
--- a/mlir/python/mlir/dialects/_transform_ops_ext.py
+++ b/mlir/python/mlir/dialects/_transform_ops_ext.py
@@ -4,102 +4,119 @@
try:
from ..ir import *
- from ._ods_common import get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values
+ from ._ods_common import (
+ get_op_result_or_value as _get_op_result_or_value,
+ get_op_results_or_values as _get_op_results_or_values,
+ )
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
-from argparse import SUPPRESS
-from typing import Optional, overload, Sequence, Union
-
-
-def _get_symbol_ref_attr(value: Union[Attribute, str]):
- if isinstance(value, Attribute):
- return value
- return FlatSymbolRefAttr.get(value)
+from typing import Optional, Sequence, Union
class CastOp:
- def __init__(self, result_type: Type, target: Union[Operation, Value], *, loc=None, ip=None):
- super().__init__(
- result_type,
- _get_op_result_or_value(target),
- loc=loc,
- ip=ip)
+ def __init__(self,
+ result_type: Type,
+ target: Union[Operation, Value],
+ *,
+ loc=None,
+ ip=None):
+ super().__init__(result_type,
+ _get_op_result_or_value(target),
+ loc=loc,
+ ip=ip)
class GetClosestIsolatedParentOp:
- def __init__(self, result_type: Type, target: Union[Operation, Value], *, loc=None, ip=None):
- super().__init__(
- result_type,
- _get_op_result_or_value(target),
- loc=loc,
- ip=ip)
-
-
-class MergeHandlesOp:
-
def __init__(self,
- handles: Sequence[Union[Operation, Value]],
+ result_type: Type,
+ target: Union[Operation, Value],
*,
- deduplicate: bool = False,
loc=None,
ip=None):
+ super().__init__(result_type,
+ _get_op_result_or_value(target),
+ loc=loc,
+ ip=ip)
+
+
+class MergeHandlesOp:
+
+ def __init__(
+ self,
+ handles: Sequence[Union[Operation, Value]],
+ *,
+ deduplicate: bool = False,
+ loc=None,
+ ip=None,
+ ):
super().__init__(
[_get_op_result_or_value(h) for h in handles],
deduplicate=deduplicate,
loc=loc,
- ip=ip)
+ ip=ip,
+ )
class PDLMatchOp:
- def __init__(self,
- result_type: Type,
- target: Union[Operation, Value],
- pattern_name: Union[Attribute, str],
- *,
- loc=None,
- ip=None):
+ def __init__(
+ self,
+ result_type: Type,
+ target: Union[Operation, Value],
+ pattern_name: Union[Attribute, str],
+ *,
+ loc=None,
+ ip=None,
+ ):
super().__init__(
result_type,
_get_op_result_or_value(target),
- _get_symbol_ref_attr(pattern_name),
+ pattern_name,
loc=loc,
- ip=ip)
+ ip=ip,
+ )
class ReplicateOp:
- def __init__(self,
- pattern: Union[Operation, Value],
- handles: Sequence[Union[Operation, Value]],
- *,
- loc=None,
- ip=None):
+ def __init__(
+ self,
+ pattern: Union[Operation, Value],
+ handles: Sequence[Union[Operation, Value]],
+ *,
+ loc=None,
+ ip=None,
+ ):
super().__init__(
[_get_op_result_or_value(h).type for h in handles],
_get_op_result_or_value(pattern),
[_get_op_result_or_value(h) for h in handles],
loc=loc,
- ip=ip)
+ ip=ip,
+ )
class SequenceOp:
- def __init__(self, failure_propagation_mode, results: Sequence[Type],
- target: Union[Operation, Value, Type],
- extra_bindings: Optional[Union[Sequence[Value], Sequence[Type],
- Operation, OpView]] = None):
- root = _get_op_result_or_value(target) if isinstance(
- target, (Operation, Value)) else None
+ def __init__(
+ self,
+ failure_propagation_mode,
+ results: Sequence[Type],
+ target: Union[Operation, Value, Type],
+ extra_bindings: Optional[Union[Sequence[Value], Sequence[Type], Operation,
+ OpView]] = None,
+ ):
+ root = (_get_op_result_or_value(target) if isinstance(
+ target, (Operation, Value)) else None)
root_type = root.type if not isinstance(target, Type) else target
if not isinstance(failure_propagation_mode, Attribute):
failure_propagation_mode_attr = IntegerAttr.get(
IntegerType.get_signless(32), failure_propagation_mode._as_int())
else:
- failure_propagation_mode = failure_propagation_mode
+ failure_propagation_mode_attr = failure_propagation_mode
if extra_bindings is None:
extra_bindings = []
@@ -114,10 +131,12 @@ def __init__(self, failure_propagation_mode, results: Sequence[Type],
else:
extra_binding_types = [v.type for v in extra_bindings]
- super().__init__(results_=results,
- failure_propagation_mode=failure_propagation_mode_attr,
- root=root,
- extra_bindings=extra_bindings)
+ super().__init__(
+ results_=results,
+ failure_propagation_mode=failure_propagation_mode_attr,
+ root=root,
+ extra_bindings=extra_bindings,
+ )
self.regions[0].blocks.append(*tuple([root_type] + extra_binding_types))
@property
@@ -143,10 +162,7 @@ def __init__(self,
root = _get_op_result_or_value(target) if not isinstance(target,
Type) else None
root_type = target if isinstance(target, Type) else root.type
- super().__init__(
- root=root,
- loc=loc,
- ip=ip)
+ super().__init__(root=root, loc=loc, ip=ip)
self.regions[0].blocks.append(root_type)
@property
@@ -160,9 +176,13 @@ def bodyTarget(self) -> Value:
class YieldOp:
- def __init__(self,
- operands: Union[Operation, Sequence[Value]] = [],
- *,
- loc=None,
- ip=None):
+ def __init__(
+ self,
+ operands: Optional[Union[Operation, Sequence[Value]]] = None,
+ *,
+ loc=None,
+ ip=None,
+ ):
+ if operands is None:
+ operands = []
super().__init__(_get_op_results_or_values(operands), loc=loc, ip=ip)
diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py
index 1e24fcbf99e40..714253426b025 100644
--- a/mlir/python/mlir/ir.py
+++ b/mlir/python/mlir/ir.py
@@ -8,9 +8,11 @@
# Convenience decorator for registering user-friendly Attribute builders.
def register_attribute_builder(kind):
+
def decorator_builder(func):
AttrBuilder.insert(kind, func)
return func
+
return decorator_builder
@@ -18,34 +20,77 @@ def decorator_builder(func):
def _boolAttr(x, context):
return BoolAttr.get(x, context=context)
+
@register_attribute_builder("IndexAttr")
def _indexAttr(x, context):
return IntegerAttr.get(IndexType.get(context=context), x)
+
+@register_attribute_builder("I16Attr")
+def _i32Attr(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(16, context=context), x)
+
+
@register_attribute_builder("I32Attr")
def _i32Attr(x, context):
- return IntegerAttr.get(
- IntegerType.get_signless(32, context=context), x)
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), x)
+
@register_attribute_builder("I64Attr")
def _i64Attr(x, context):
- return IntegerAttr.get(
- IntegerType.get_signless(64, context=context), x)
+ return IntegerAttr.get(IntegerType.get_signless(64, context=context), x)
+
@register_attribute_builder("StrAttr")
def _stringAttr(x, context):
return StringAttr.get(x, context=context)
+
@register_attribute_builder("SymbolNameAttr")
def _symbolNameAttr(x, context):
return StringAttr.get(x, context=context)
+
+@register_attribute_builder("SymbolRefAttr")
+def _symbolRefAttr(x, context):
+ return FlatSymbolRefAttr.get(x, context=context)
+
+
+@register_attribute_builder("ArrayAttr")
+def _arrayAttr(x, context):
+ return ArrayAttr.get(x, context=context)
+
+
+@register_attribute_builder("I64ArrayAttr")
+def _i64ArrayAttr(x, context):
+ return ArrayAttr.get([_i64Attr(v, context) for v in x])
+
+
+@register_attribute_builder("DenseI64ArrayAttr")
+def _denseI64ArrayAttr(x, context):
+ return DenseI64ArrayAttr.get(x, context=context)
+
+
+@register_attribute_builder("TypeAttr")
+def _typeAttr(x, context):
+ return TypeAttr.get(x, context=context)
+
+
+@register_attribute_builder("TypeArrayAttr")
+def _typeArrayAttr(x, context):
+ return _arrayAttr([TypeAttr.get(t, context=context) for t in x], context)
+
+
try:
import numpy as np
+
@register_attribute_builder("IndexElementsAttr")
def _indexElementsAttr(x, context):
return DenseElementsAttr.get(
- np.array(x, dtype=np.int64), type=IndexType.get(context=context),
- context=context)
+ np.array(x, dtype=np.int64),
+ type=IndexType.get(context=context),
+ context=context,
+ )
+
except ImportError:
pass
More information about the Mlir-commits
mailing list