[Mlir-commits] [mlir] f223fcf - [MLIR][python bindings] Add some AttrBuilder and port _exts to use them.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Apr 26 15:50:42 PDT 2023


Author: max
Date: 2023-04-26T17:50:10-05:00
New Revision: f223fcf67f00c58f49a759b5d43e82d277a346d7

URL: https://github.com/llvm/llvm-project/commit/f223fcf67f00c58f49a759b5d43e82d277a346d7
DIFF: https://github.com/llvm/llvm-project/commit/f223fcf67f00c58f49a759b5d43e82d277a346d7.diff

LOG: [MLIR][python bindings] Add some AttrBuilder and port _exts to use them.

Differential Revision: https://reviews.llvm.org/D149287

Added: 
    

Modified: 
    mlir/python/mlir/dialects/_loop_transform_ops_ext.py
    mlir/python/mlir/dialects/_pdl_ops_ext.py
    mlir/python/mlir/dialects/_structured_transform_ops_ext.py
    mlir/python/mlir/dialects/_transform_ops_ext.py
    mlir/python/mlir/ir.py

Removed: 
    


################################################################################
diff  --git a/mlir/python/mlir/dialects/_loop_transform_ops_ext.py b/mlir/python/mlir/dialects/_loop_transform_ops_ext.py
index 0dc8fc07431a2..a275ea615378e 100644
--- a/mlir/python/mlir/dialects/_loop_transform_ops_ext.py
+++ b/mlir/python/mlir/dialects/_loop_transform_ops_ext.py
@@ -11,65 +11,63 @@
 from typing import Optional, Union
 
 
-def _get_int64_attr(arg: Optional[Union[int, IntegerAttr]],
-                    default_value: int = None):
-  if isinstance(arg, IntegerAttr):
-    return arg
-
-  if arg is None:
-    assert default_value is not None, "must provide default value"
-    arg = default_value
-
-  return IntegerAttr.get(IntegerType.get_signless(64), arg)
-
-
 class GetParentForOp:
   """Extension for GetParentForOp."""
 
-  def __init__(self,
-               result_type: Type,
-               target: Union[Operation, Value],
-               *,
-               num_loops: int = 1,
-               ip=None,
-               loc=None):
+  def __init__(
+      self,
+      result_type: Type,
+      target: Union[Operation, Value],
+      *,
+      num_loops: Optional[int] = None,
+      ip=None,
+      loc=None,
+  ):
+    if num_loops is None:
+      num_loops = 1
     super().__init__(
         result_type,
         _get_op_result_or_value(target),
-        num_loops=_get_int64_attr(num_loops, default_value=1),
+        num_loops=num_loops,
         ip=ip,
-        loc=loc)
+        loc=loc,
+    )
 
 
 class LoopOutlineOp:
   """Extension for LoopOutlineOp."""
 
-  def __init__(self,
-               result_type: Type,
-               target: Union[Operation, Value],
-               *,
-               func_name: Union[str, StringAttr],
-               ip=None,
-               loc=None):
+  def __init__(
+      self,
+      result_type: Type,
+      target: Union[Operation, Value],
+      *,
+      func_name: Union[str, StringAttr],
+      ip=None,
+      loc=None,
+  ):
     super().__init__(
         result_type,
         _get_op_result_or_value(target),
         func_name=(func_name if isinstance(func_name, StringAttr) else
                    StringAttr.get(func_name)),
         ip=ip,
-        loc=loc)
+        loc=loc,
+    )
 
 
 class LoopPeelOp:
   """Extension for LoopPeelOp."""
 
-  def __init__(self,
-               result_type: Type,
-               target: Union[Operation, Value],
-               *,
-               fail_if_already_divisible: Union[bool, BoolAttr] = False,
-               ip=None,
-               loc=None):
+  def __init__(
+      self,
+      result_type: Type,
+      target: Union[Operation, Value],
+      *,
+      fail_if_already_divisible: Union[bool, BoolAttr] = False,
+      ip=None,
+      loc=None,
+  ):
     super().__init__(
         result_type,
         _get_op_result_or_value(target),
@@ -77,40 +75,51 @@ def __init__(self,
             fail_if_already_divisible, BoolAttr) else
                                    BoolAttr.get(fail_if_already_divisible)),
         ip=ip,
-        loc=loc)
+        loc=loc,
+    )
 
 
 class LoopPipelineOp:
   """Extension for LoopPipelineOp."""
 
-  def __init__(self,
-               result_type: Type,
-               target: Union[Operation, Value],
-               *,
-               iteration_interval: Optional[Union[int, IntegerAttr]] = None,
-               read_latency: Optional[Union[int, IntegerAttr]] = None,
-               ip=None,
-               loc=None):
+  def __init__(
+      self,
+      result_type: Type,
+      target: Union[Operation, Value],
+      *,
+      iteration_interval: Optional[Union[int, IntegerAttr]] = None,
+      read_latency: Optional[Union[int, IntegerAttr]] = None,
+      ip=None,
+      loc=None,
+  ):
+    if iteration_interval is None:
+      iteration_interval = 1
+    if read_latency is None:
+      read_latency = 10
     super().__init__(
         result_type,
         _get_op_result_or_value(target),
-        iteration_interval=_get_int64_attr(iteration_interval, default_value=1),
-        read_latency=_get_int64_attr(read_latency, default_value=10),
+        iteration_interval=iteration_interval,
+        read_latency=read_latency,
         ip=ip,
-        loc=loc)
+        loc=loc,
+    )
 
 
 class LoopUnrollOp:
   """Extension for LoopUnrollOp."""
 
-  def __init__(self,
-               target: Union[Operation, Value],
-               *,
-               factor: Union[int, IntegerAttr],
-               ip=None,
-               loc=None):
+  def __init__(
+      self,
+      target: Union[Operation, Value],
+      *,
+      factor: Union[int, IntegerAttr],
+      ip=None,
+      loc=None,
+  ):
     super().__init__(
         _get_op_result_or_value(target),
-        factor=_get_int64_attr(factor),
+        factor=factor,
         ip=ip,
-        loc=loc)
+        loc=loc,
+    )

diff  --git a/mlir/python/mlir/dialects/_pdl_ops_ext.py b/mlir/python/mlir/dialects/_pdl_ops_ext.py
index 428301b18f208..40ccbef6351dc 100644
--- a/mlir/python/mlir/dialects/_pdl_ops_ext.py
+++ b/mlir/python/mlir/dialects/_pdl_ops_ext.py
@@ -8,61 +8,26 @@
 except ImportError as e:
   raise RuntimeError("Error loading imports from extension module") from e
 
-from typing import Union, Optional, Sequence, List, Mapping
-from ._ods_common import get_op_result_or_value as _get_value, get_op_results_or_values as _get_values
-
-
-def _get_int_attr(bits: int, value: Union[IntegerAttr, int]) -> IntegerAttr:
-  """Converts the given value to signless integer attribute of given bit width."""
-  if isinstance(value, int):
-    ty = IntegerType.get_signless(bits)
-    return IntegerAttr.get(ty, value)
-  else:
-    return value
-
-
-def _get_array_attr(attrs: Union[ArrayAttr, Sequence[Attribute]]) -> ArrayAttr:
-  """Converts the given value to array attribute."""
-  if isinstance(attrs, ArrayAttr):
-    return attrs
-  else:
-    return ArrayAttr.get(list(attrs))
-
-
-def _get_str_array_attr(attrs: Union[ArrayAttr, Sequence[str]]) -> ArrayAttr:
-  """Converts the given value to string array attribute."""
-  if isinstance(attrs, ArrayAttr):
-    return attrs
-  else:
-    return ArrayAttr.get([StringAttr.get(s) for s in attrs])
-
-
-def _get_str_attr(name: Union[StringAttr, str]) -> Optional[StringAttr]:
-  """Converts the given value to string attribute."""
-  if isinstance(name, str):
-    return StringAttr.get(name)
-  else:
-    return name
-
-
-def _get_type_attr(type: Union[TypeAttr, Type]) -> TypeAttr:
-  """Converts the given value to type attribute."""
-  if isinstance(type, Type):
-    return TypeAttr.get(type)
-  else:
-    return type
+from typing import Union, Optional, Sequence, Mapping
+from ._ods_common import (
+    get_op_result_or_value as _get_value,
+    get_op_results_or_values as _get_values,
+)
 
 
 class ApplyNativeConstraintOp:
   """Specialization for PDL apply native constraint op class."""
 
-  def __init__(self,
-               name: Union[str, StringAttr],
-               args: Sequence[Union[OpView, Operation, Value]] = [],
-               *,
-               loc=None,
-               ip=None):
-    name = _get_str_attr(name)
+  def __init__(
+      self,
+      name: Union[str, StringAttr],
+      args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
+      *,
+      loc=None,
+      ip=None,
+  ):
+    if args is None:
+      args = []
     args = _get_values(args)
     super().__init__(name, args, loc=loc, ip=ip)
 
@@ -70,14 +35,17 @@ def __init__(self,
 class ApplyNativeRewriteOp:
   """Specialization for PDL apply native rewrite op class."""
 
-  def __init__(self,
-               results: Sequence[Type],
-               name: Union[str, StringAttr],
-               args: Sequence[Union[OpView, Operation, Value]] = [],
-               *,
-               loc=None,
-               ip=None):
-    name = _get_str_attr(name)
+  def __init__(
+      self,
+      results: Sequence[Type],
+      name: Union[str, StringAttr],
+      args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
+      *,
+      loc=None,
+      ip=None,
+  ):
+    if args is None:
+      args = []
     args = _get_values(args)
     super().__init__(results, name, args, loc=loc, ip=ip)
 
@@ -85,12 +53,14 @@ def __init__(self,
 class AttributeOp:
   """Specialization for PDL attribute op class."""
 
-  def __init__(self,
-               valueType: Optional[Union[OpView, Operation, Value]] = None,
-               value: Optional[Attribute] = None,
-               *,
-               loc=None,
-               ip=None):
+  def __init__(
+      self,
+      valueType: Optional[Union[OpView, Operation, Value]] = None,
+      value: Optional[Attribute] = None,
+      *,
+      loc=None,
+      ip=None,
+  ):
     valueType = valueType if valueType is None else _get_value(valueType)
     result = pdl.AttributeType.get()
     super().__init__(result, valueType=valueType, value=value, loc=loc, ip=ip)
@@ -99,11 +69,13 @@ def __init__(self,
 class EraseOp:
   """Specialization for PDL erase op class."""
 
-  def __init__(self,
-               operation: Optional[Union[OpView, Operation, Value]] = None,
-               *,
-               loc=None,
-               ip=None):
+  def __init__(
+      self,
+      operation: Optional[Union[OpView, Operation, Value]] = None,
+      *,
+      loc=None,
+      ip=None,
+  ):
     operation = _get_value(operation)
     super().__init__(operation, loc=loc, ip=ip)
 
@@ -111,11 +83,13 @@ def __init__(self,
 class OperandOp:
   """Specialization for PDL operand op class."""
 
-  def __init__(self,
-               type: Optional[Union[OpView, Operation, Value]] = None,
-               *,
-               loc=None,
-               ip=None):
+  def __init__(
+      self,
+      type: Optional[Union[OpView, Operation, Value]] = None,
+      *,
+      loc=None,
+      ip=None,
+  ):
     type = type if type is None else _get_value(type)
     result = pdl.ValueType.get()
     super().__init__(result, valueType=type, loc=loc, ip=ip)
@@ -124,11 +98,13 @@ def __init__(self,
 class OperandsOp:
   """Specialization for PDL operands op class."""
 
-  def __init__(self,
-               types: Optional[Union[OpView, Operation, Value]] = None,
-               *,
-               loc=None,
-               ip=None):
+  def __init__(
+      self,
+      types: Optional[Union[OpView, Operation, Value]] = None,
+      *,
+      loc=None,
+      ip=None,
+  ):
     types = types if types is None else _get_value(types)
     result = pdl.RangeType.get(pdl.ValueType.get())
     super().__init__(result, valueType=types, loc=loc, ip=ip)
@@ -137,15 +113,23 @@ def __init__(self,
 class OperationOp:
   """Specialization for PDL operand op class."""
 
-  def __init__(self,
-               name: Optional[Union[str, StringAttr]] = None,
-               args: Sequence[Union[OpView, Operation, Value]] = [],
-               attributes: Mapping[str, Union[OpView, Operation, Value]] = {},
-               types: Sequence[Union[OpView, Operation, Value]] = [],
-               *,
-               loc=None,
-               ip=None):
-    name = name if name is None else _get_str_attr(name)
+  def __init__(
+      self,
+      name: Optional[Union[str, StringAttr]] = None,
+      args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
+      attributes: Optional[Mapping[str, Union[OpView, Operation,
+                                              Value]]] = None,
+      types: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
+      *,
+      loc=None,
+      ip=None,
+  ):
+    if types is None:
+      types = []
+    if attributes is None:
+      attributes = {}
+    if args is None:
+      args = []
     args = _get_values(args)
     attrNames = []
     attrValues = []
@@ -155,22 +139,29 @@ def __init__(self,
     attrNames = ArrayAttr.get(attrNames)
     types = _get_values(types)
     result = pdl.OperationType.get()
-    super().__init__(result, args, attrValues, attrNames, types, opName=name, loc=loc, ip=ip)
+    super().__init__(result,
+                     args,
+                     attrValues,
+                     attrNames,
+                     types,
+                     opName=name,
+                     loc=loc,
+                     ip=ip)
 
 
 class PatternOp:
   """Specialization for PDL pattern op class."""
 
-  def __init__(self,
-               benefit: Union[IntegerAttr, int],
-               name: Optional[Union[StringAttr, str]] = None,
-               *,
-               loc=None,
-               ip=None):
+  def __init__(
+      self,
+      benefit: Union[IntegerAttr, int],
+      name: Optional[Union[StringAttr, str]] = None,
+      *,
+      loc=None,
+      ip=None,
+  ):
     """Creates an PDL `pattern` operation."""
-    name_attr = None if name is None else _get_str_attr(name)
-    benefit_attr = _get_int_attr(16, benefit)
-    super().__init__(benefit_attr, sym_name=name_attr, loc=loc, ip=ip)
+    super().__init__(benefit, sym_name=name, loc=loc, ip=ip)
     self.regions[0].blocks.append()
 
   @property
@@ -182,13 +173,17 @@ def body(self):
 class ReplaceOp:
   """Specialization for PDL replace op class."""
 
-  def __init__(self,
-               op: Union[OpView, Operation, Value],
-               *,
-               with_op: Optional[Union[OpView, Operation, Value]] = None,
-               with_values: Sequence[Union[OpView, Operation, Value]] = [],
-               loc=None,
-               ip=None):
+  def __init__(
+      self,
+      op: Union[OpView, Operation, Value],
+      *,
+      with_op: Optional[Union[OpView, Operation, Value]] = None,
+      with_values: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
+      loc=None,
+      ip=None,
+  ):
+    if with_values is None:
+      with_values = []
     op = _get_value(op)
     with_op = with_op if with_op is None else _get_value(with_op)
     with_values = _get_values(with_values)
@@ -198,13 +193,14 @@ def __init__(self,
 class ResultOp:
   """Specialization for PDL result op class."""
 
-  def __init__(self,
-               parent: Union[OpView, Operation, Value],
-               index: Union[IntegerAttr, int],
-               *,
-               loc=None,
-               ip=None):
-    index = _get_int_attr(32, index)
+  def __init__(
+      self,
+      parent: Union[OpView, Operation, Value],
+      index: Union[IntegerAttr, int],
+      *,
+      loc=None,
+      ip=None,
+  ):
     parent = _get_value(parent)
     result = pdl.ValueType.get()
     super().__init__(result, parent, index, loc=loc, ip=ip)
@@ -213,32 +209,36 @@ def __init__(self,
 class ResultsOp:
   """Specialization for PDL results op class."""
 
-  def __init__(self,
-               result: Type,
-               parent: Union[OpView, Operation, Value],
-               index: Optional[Union[IntegerAttr, int]] = None,
-               *,
-               loc=None,
-               ip=None):
+  def __init__(
+      self,
+      result: Type,
+      parent: Union[OpView, Operation, Value],
+      index: Optional[Union[IntegerAttr, int]] = None,
+      *,
+      loc=None,
+      ip=None,
+  ):
     parent = _get_value(parent)
-    index = index if index is None else _get_int_attr(32, index)
     super().__init__(result, parent, index=index, loc=loc, ip=ip)
 
 
 class RewriteOp:
   """Specialization for PDL rewrite op class."""
 
-  def __init__(self,
-               root: Optional[Union[OpView, Operation, Value]] = None,
-               name: Optional[Union[StringAttr, str]] = None,
-               args: Sequence[Union[OpView, Operation, Value]] = [],
-               *,
-               loc=None,
-               ip=None):
+  def __init__(
+      self,
+      root: Optional[Union[OpView, Operation, Value]] = None,
+      name: Optional[Union[StringAttr, str]] = None,
+      args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
+      *,
+      loc=None,
+      ip=None,
+  ):
+    if args is None:
+      args = []
     root = root if root is None else _get_value(root)
-    name = name if name is None else _get_str_attr(name)
     args = _get_values(args)
-    super().__init__(args, root=root,name=name, loc=loc, ip=ip)
+    super().__init__(args, root=root, name=name, loc=loc, ip=ip)
 
   def add_body(self):
     """Add body (block) to the rewrite."""
@@ -259,8 +259,6 @@ def __init__(self,
                *,
                loc=None,
                ip=None):
-    constantType = constantType if constantType is None else _get_type_attr(
-        constantType)
     result = pdl.TypeType.get()
     super().__init__(result, constantType=constantType, loc=loc, ip=ip)
 
@@ -268,13 +266,14 @@ def __init__(self,
 class TypesOp:
   """Specialization for PDL types op class."""
 
-  def __init__(self,
-               constantTypes: Sequence[Union[TypeAttr, Type]] = [],
-               *,
-               loc=None,
-               ip=None):
-    constantTypes = _get_array_attr(
-        [_get_type_attr(ty) for ty in constantTypes])
-    constantTypes = None if not constantTypes else constantTypes
+  def __init__(
+      self,
+      constantTypes: Optional[Sequence[Union[TypeAttr, Type]]] = None,
+      *,
+      loc=None,
+      ip=None,
+  ):
+    if constantTypes is None:
+      constantTypes = []
     result = pdl.RangeType.get(pdl.TypeType.get())
     super().__init__(result, constantTypes=constantTypes, loc=loc, ip=ip)

diff  --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
index e2c262ca50201..9c051cd3d146d 100644
--- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
+++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
@@ -15,180 +15,159 @@
 OptionalIntList = Optional[Union[ArrayAttr, IntOrAttrList]]
 
 
-def _get_int64_attr(value: Union[int, Attribute]) -> IntegerAttr:
-  if isinstance(value, int):
-    return IntegerAttr.get(IntegerType.get_signless(64), value)
-  return value
-
-
-def _get_array_attr(
-    values: Optional[Union[ArrayAttr, Sequence[Attribute]]]) -> ArrayAttr:
-  """Creates an array attribute from its operand."""
-  if values is None:
-    return ArrayAttr.get([])
-  if isinstance(values, ArrayAttr):
-    return values
-
-  return ArrayAttr.get(values)
-
-
-def _get_int_array_attr(
-    values: Optional[Union[ArrayAttr, Sequence[Union[IntegerAttr, int]]]]
-) -> ArrayAttr:
-  """Creates an integer array attribute from its operand.
-
-  If the operand is already an array attribute, forwards it. Otherwise treats
-  the operand as a list of attributes or integers, possibly intersperced, to
-  create a new array attribute containing integer attributes. Expects the
-  thread-local MLIR context to have been set by the context manager.
-  """
-  if values is None:
-    return ArrayAttr.get([])
-  if isinstance(values, ArrayAttr):
-    return values
-
-  return ArrayAttr.get([_get_int64_attr(v) for v in values])
-
-def _get_dense_int64_array_attr(
-        values: Sequence[int]) -> DenseI64ArrayAttr:
-  """Creates a dense integer array from a sequence of integers.
-    Expects the thread-local MLIR context to have been set by the context 
-    manager.
-    """
-  if values is None:
-    return DenseI64ArrayAttr.get([])
-  return DenseI64ArrayAttr.get(values)
-
 def _get_int_int_array_attr(
     values: Optional[Union[ArrayAttr, Sequence[Union[ArrayAttr,
                                                      IntOrAttrList]]]]
 ) -> ArrayAttr:
   """Creates an array attribute containing array attributes of integers.
 
-  If the operand is already an array attribute, forwards it. Otherwise treats
-  the operand as a list of attributes or integers, potentially interpserced, to
-  create a new array-of-array attribute. Expects the thread-local MLIR context
-  to have been set by the context manager.
-  """
+    If the operand is already an array attribute, forwards it. Otherwise treats
+    the operand as a list of attributes or integers, potentially interpserced, to
+    create a new array-of-array attribute. Expects the thread-local MLIR context
+    to have been set by the context manager.
+    """
   if values is None:
     return ArrayAttr.get([])
   if isinstance(values, ArrayAttr):
     return values
+  if isinstance(values, list):
+    values = [
+        ArrayAttr.get(
+            [IntegerAttr.get(IntegerType.get_signless(64), v)
+             for v in value])
+        for value in values
+    ]
 
-  return ArrayAttr.get([_get_int_array_attr(value) for value in values])
+  return ArrayAttr.get(values)
 
 
 class DecomposeOp:
   """Specialization for DecomposeOp class."""
 
   def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
-    super().__init__(
-        pdl.OperationType.get(),
-        _get_op_result_or_value(target),
-        loc=loc,
-        ip=ip)
+    super().__init__(pdl.OperationType.get(),
+                     _get_op_result_or_value(target),
+                     loc=loc,
+                     ip=ip)
 
 
 class GeneralizeOp:
   """Specialization for GeneralizeOp class."""
 
   def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
-    super().__init__(
-        pdl.OperationType.get(),
-        _get_op_result_or_value(target),
-        loc=loc,
-        ip=ip)
+    super().__init__(pdl.OperationType.get(),
+                     _get_op_result_or_value(target),
+                     loc=loc,
+                     ip=ip)
 
 
 class InterchangeOp:
   """Specialization for InterchangeOp class."""
 
-  def __init__(self,
-               target: Union[Operation, Value],
-               *,
-               iterator_interchange: OptionalIntList = None,
-               loc=None,
-               ip=None):
+  def __init__(
+      self,
+      target: Union[Operation, Value],
+      *,
+      iterator_interchange: OptionalIntList = None,
+      loc=None,
+      ip=None,
+  ):
     pdl_operation_type = pdl.OperationType.get()
-    interchange_attr = _get_dense_int64_array_attr(iterator_interchange)
     super().__init__(
         pdl_operation_type,
         _get_op_result_or_value(target),
-        iterator_interchange=interchange_attr,
+        iterator_interchange=iterator_interchange,
         loc=loc,
-        ip=ip)
+        ip=ip,
+    )
 
 
 class MatchOp:
   """Specialization for MatchOp class."""
 
   @classmethod
-  def match_op_names(MatchOp,
-                     target: Union[Operation, Value],
-                     names: Sequence[str],
-                     loc=None,
-                     ip=None):
+  def match_op_names(
+      MatchOp,
+      target: Union[Operation, Value],
+      names: Sequence[str],
+      loc=None,
+      ip=None,
+  ):
     pdl_operation_type = pdl.OperationType.get()
     return MatchOp(
         pdl_operation_type,
         _get_op_result_or_value(target),
         ops=ArrayAttr.get(list(map(lambda s: StringAttr.get(s), names))),
         loc=loc,
-        ip=ip)
+        ip=ip,
+    )
 
 
 class MultiTileSizesOp:
   """Specialization for MultitileSizesOp class."""
 
-  def __init__(self,
-               result_type: Type,
-               target: Union[Operation, Value],
-               *,
-               dimension: Union[int, IntegerAttr],
-               target_size: Union[int, IntegerAttr],
-               divisor: Optional[Union[int, IntegerAttr]] = None,
-               loc=None,
-               ip=None):
+  def __init__(
+      self,
+      result_type: Type,
+      target: Union[Operation, Value],
+      *,
+      dimension: Union[int, IntegerAttr],
+      target_size: Union[int, IntegerAttr],
+      divisor: Optional[Optional[Union[int, IntegerAttr]]] = None,
+      loc=None,
+      ip=None,
+  ):
+    if divisor is None:
+      divisor = 1
     super().__init__(
         result_type,
         result_type,
         result_type,
         _get_op_result_or_value(target),
-        dimension=_get_int64_attr(dimension),
-        target_size=_get_int64_attr(target_size),
-        divisor=_get_int64_attr(divisor if divisor else 1),
+        dimension=dimension,
+        target_size=target_size,
+        divisor=divisor,
         loc=loc,
-        ip=ip)
+        ip=ip,
+    )
 
 
 class PadOp:
   """Specialization for PadOp class."""
 
-  def __init__(self,
-               target: Union[Operation, Value],
-               *,
-               padding_values: Optional[Union[ArrayAttr,
-                                              Sequence[Attribute]]] = None,
-               padding_dimensions: OptionalIntList = None,
-               pack_paddings: OptionalIntList = None,
-               transpose_paddings: Optional[Union[ArrayAttr, Sequence[Union[
-                   ArrayAttr, IntOrAttrList]]]] = None,
-               loc=None,
-               ip=None):
+  def __init__(
+      self,
+      target: Union[Operation, Value],
+      *,
+      padding_values: Optional[Optional[Union[ArrayAttr,
+                                              Sequence[Attribute]]]] = None,
+      padding_dimensions: OptionalIntList = None,
+      pack_paddings: OptionalIntList = None,
+      transpose_paddings: Optional[Union[ArrayAttr, Sequence[Union[
+          ArrayAttr, IntOrAttrList]]]] = None,
+      loc=None,
+      ip=None,
+  ):
+    if transpose_paddings is None:
+      transpose_paddings = []
+    if pack_paddings is None:
+      pack_paddings = []
+    if padding_dimensions is None:
+      padding_dimensions = []
+    if padding_values is None:
+      padding_values = []
     pdl_operation_type = pdl.OperationType.get()
-    padding_values_attr = _get_array_attr(padding_values)
-    padding_dimensions_attr = _get_int_array_attr(padding_dimensions)
-    pack_paddings_attr = _get_int_array_attr(pack_paddings)
     transpose_paddings_attr = _get_int_int_array_attr(transpose_paddings)
     super().__init__(
         pdl_operation_type,
         _get_op_result_or_value(target),
-        padding_values=padding_values_attr,
-        padding_dimensions=padding_dimensions_attr,
-        pack_paddings=pack_paddings_attr,
+        padding_values=padding_values,
+        padding_dimensions=padding_dimensions,
+        pack_paddings=pack_paddings,
         transpose_paddings=transpose_paddings_attr,
         loc=loc,
-        ip=ip)
+        ip=ip,
+    )
 
 
 class ScalarizeOp:
@@ -196,29 +175,29 @@ class ScalarizeOp:
 
   def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
     pdl_operation_type = pdl.OperationType.get()
-    super().__init__(
-        pdl_operation_type, _get_op_result_or_value(target), loc=loc, ip=ip)
+    super().__init__(pdl_operation_type,
+                     _get_op_result_or_value(target),
+                     loc=loc,
+                     ip=ip)
 
 
 class SplitOp:
   """Specialization for SplitOp class."""
 
-  def __init__(self,
-               target: Union[Operation, Value],
-               dimension: Union[int, Attribute],
-               split_point: Union[int, Operation, Value, Attribute],
-               *,
-               loc=None,
-               ip=None):
-    dimension = _get_int64_attr(dimension)
+  def __init__(
+      self,
+      target: Union[Operation, Value],
+      dimension: Union[int, Attribute],
+      split_point: Union[int, Operation, Value, Attribute],
+      *,
+      loc=None,
+      ip=None,
+  ):
     if isinstance(split_point, int):
-      split_point = _get_int64_attr(split_point)
-
-    if isinstance(split_point, Attribute):
       static_split_point = split_point
       dynamic_split_point = None
     else:
-      static_split_point = _get_int64_attr(ShapedType.get_dynamic_size())
+      static_split_point = ShapedType.get_dynamic_size()
       dynamic_split_point = _get_op_result_or_value(split_point)
 
     target = _get_op_result_or_value(target)
@@ -231,44 +210,53 @@ def __init__(self,
         static_split_point=static_split_point,
         dynamic_split_point=dynamic_split_point,
         loc=loc,
-        ip=ip)
+        ip=ip,
+    )
 
 
 class TileOp:
   """Specialization for TileOp class."""
 
   @overload
-  def __init__(self,
-               loop_types: Union[Type, List[Type]],
-               target: Union[Operation, Value],
-               *,
-               sizes: Optional[Union[Sequence[Union[int, IntegerAttr, Operation,
-                                                    Value]], ArrayAttr]] = None,
-               interchange: OptionalIntList = None,
-               loc=None,
-               ip=None):
+  def __init__(
+      self,
+      loop_types: Union[Type, List[Type]],
+      target: Union[Operation, Value],
+      *,
+      sizes: Optional[Union[Sequence[Union[int, IntegerAttr, Operation, Value]],
+                            ArrayAttr]] = None,
+      interchange: OptionalIntList = None,
+      loc=None,
+      ip=None,
+  ):
     ...
 
   @overload
-  def __init__(self,
-               target: Union[Operation, Value, OpView],
-               *,
-               sizes: Optional[Union[Sequence[Union[int, IntegerAttr, Operation,
-                                                    Value]], ArrayAttr]] = None,
-               interchange: OptionalIntList = None,
-               loc=None,
-               ip=None):
+  def __init__(
+      self,
+      target: Union[Operation, Value, OpView],
+      *,
+      sizes: Optional[Union[Sequence[Union[int, IntegerAttr, Operation, Value]],
+                            ArrayAttr]] = None,
+      interchange: OptionalIntList = None,
+      loc=None,
+      ip=None,
+  ):
     ...
 
-  def __init__(self,
-               loop_types_or_target: Union[Type, List[Type], Operation, Value],
-               target_or_none: Optional[Union[Operation, Value, OpView]] = None,
-               *,
-               sizes: Optional[Union[Sequence[Union[int, IntegerAttr, Operation,
-                                                    Value]], ArrayAttr]] = None,
-               interchange: OptionalIntList = None,
-               loc=None,
-               ip=None):
+  def __init__(
+      self,
+      loop_types_or_target: Union[Type, List[Type], Operation, Value],
+      target_or_none: Optional[Union[Operation, Value, OpView]] = None,
+      *,
+      sizes: Optional[Union[Sequence[Union[int, IntegerAttr, Operation, Value]],
+                            ArrayAttr]] = None,
+      interchange: OptionalIntList = None,
+      loc=None,
+      ip=None,
+  ):
+    if interchange is None:
+      interchange = []
     if sizes is None:
       sizes = []
 
@@ -293,8 +281,8 @@ def __init__(self,
       target = loop_types_or_target
       assert target_or_none is None, "Cannot construct TileOp with two targets."
     else:
-      loop_types = ([loop_types_or_target] * num_loops) if isinstance(
-          loop_types_or_target, Type) else loop_types_or_target
+      loop_types = (([loop_types_or_target] * num_loops) if isinstance(
+          loop_types_or_target, Type) else loop_types_or_target)
       target = target_or_none
 
     target = _get_op_result_or_value(target)
@@ -305,10 +293,10 @@ def __init__(self,
         target,
         dynamic_sizes=dynamic_sizes,
         static_sizes=sizes_attr,
-        interchange=_get_dense_int64_array_attr(interchange)
-        if interchange else None,
+        interchange=interchange,
         loc=loc,
-        ip=ip)
+        ip=ip,
+    )
 
   def __extract_values(self, attr: Optional[DenseI64ArrayAttr]) -> List[int]:
     if not attr:
@@ -319,12 +307,14 @@ def __extract_values(self, attr: Optional[DenseI64ArrayAttr]) -> List[int]:
 class VectorizeOp:
   """Specialization for VectorizeOp class."""
 
-  def __init__(self,
-               target: Union[Operation, Value],
-               *,
-               vectorize_padding: Union[bool, BoolAttr] = False,
-               loc=None,
-               ip=None):
+  def __init__(
+      self,
+      target: Union[Operation, Value],
+      *,
+      vectorize_padding: Union[bool, BoolAttr] = False,
+      loc=None,
+      ip=None,
+  ):
     pdl_operation_type = pdl.OperationType.get()
     if isinstance(vectorize_padding, bool):
       vectorize_padding = UnitAttr.get()
@@ -333,4 +323,5 @@ def __init__(self,
         _get_op_result_or_value(target),
         vectorize_padding=vectorize_padding,
         loc=loc,
-        ip=ip)
+        ip=ip,
+    )

diff  --git a/mlir/python/mlir/dialects/_transform_ops_ext.py b/mlir/python/mlir/dialects/_transform_ops_ext.py
index 593b8855c935f..8651c76ea7dfc 100644
--- a/mlir/python/mlir/dialects/_transform_ops_ext.py
+++ b/mlir/python/mlir/dialects/_transform_ops_ext.py
@@ -4,102 +4,119 @@
 
 try:
   from ..ir import *
-  from ._ods_common import get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values
+  from ._ods_common import (
+      get_op_result_or_value as _get_op_result_or_value,
+      get_op_results_or_values as _get_op_results_or_values,
+  )
 except ImportError as e:
   raise RuntimeError("Error loading imports from extension module") from e
 
-from argparse import SUPPRESS
-from typing import Optional, overload, Sequence, Union
-
-
-def _get_symbol_ref_attr(value: Union[Attribute, str]):
-  if isinstance(value, Attribute):
-    return value
-  return FlatSymbolRefAttr.get(value)
+from typing import Optional, Sequence, Union
 
 
 class CastOp:
 
-  def __init__(self, result_type: Type, target: Union[Operation, Value], *, loc=None, ip=None):
-    super().__init__(
-      result_type,
-      _get_op_result_or_value(target),
-      loc=loc,
-      ip=ip)
+  def __init__(self,
+               result_type: Type,
+               target: Union[Operation, Value],
+               *,
+               loc=None,
+               ip=None):
+    super().__init__(result_type,
+                     _get_op_result_or_value(target),
+                     loc=loc,
+                     ip=ip)
 
 
 class GetClosestIsolatedParentOp:
 
-  def __init__(self, result_type: Type, target: Union[Operation, Value], *, loc=None, ip=None):
-    super().__init__(
-        result_type,
-        _get_op_result_or_value(target),
-        loc=loc,
-        ip=ip)
-
-
-class MergeHandlesOp:
-
   def __init__(self,
-               handles: Sequence[Union[Operation, Value]],
+               result_type: Type,
+               target: Union[Operation, Value],
                *,
-               deduplicate: bool = False,
                loc=None,
                ip=None):
+    super().__init__(result_type,
+                     _get_op_result_or_value(target),
+                     loc=loc,
+                     ip=ip)
+
+
+class MergeHandlesOp:
+
+  def __init__(
+      self,
+      handles: Sequence[Union[Operation, Value]],
+      *,
+      deduplicate: bool = False,
+      loc=None,
+      ip=None,
+  ):
     super().__init__(
         [_get_op_result_or_value(h) for h in handles],
         deduplicate=deduplicate,
         loc=loc,
-        ip=ip)
+        ip=ip,
+    )
 
 
 class PDLMatchOp:
 
-  def __init__(self,
-               result_type: Type,
-               target: Union[Operation, Value],
-               pattern_name: Union[Attribute, str],
-               *,
-               loc=None,
-               ip=None):
+  def __init__(
+      self,
+      result_type: Type,
+      target: Union[Operation, Value],
+      pattern_name: Union[Attribute, str],
+      *,
+      loc=None,
+      ip=None,
+  ):
     super().__init__(
         result_type,
         _get_op_result_or_value(target),
-        _get_symbol_ref_attr(pattern_name),
+        pattern_name,
         loc=loc,
-        ip=ip)
+        ip=ip,
+    )
 
 
 class ReplicateOp:
 
-  def __init__(self,
-               pattern: Union[Operation, Value],
-               handles: Sequence[Union[Operation, Value]],
-               *,
-               loc=None,
-               ip=None):
+  def __init__(
+      self,
+      pattern: Union[Operation, Value],
+      handles: Sequence[Union[Operation, Value]],
+      *,
+      loc=None,
+      ip=None,
+  ):
     super().__init__(
         [_get_op_result_or_value(h).type for h in handles],
         _get_op_result_or_value(pattern),
         [_get_op_result_or_value(h) for h in handles],
         loc=loc,
-        ip=ip)
+        ip=ip,
+    )
 
 
 class SequenceOp:
 
-  def __init__(self, failure_propagation_mode, results: Sequence[Type],
-               target: Union[Operation, Value, Type],
-               extra_bindings: Optional[Union[Sequence[Value], Sequence[Type],
-                                              Operation, OpView]] = None):
-    root = _get_op_result_or_value(target) if isinstance(
-        target, (Operation, Value)) else None
+  def __init__(
+      self,
+      failure_propagation_mode,
+      results: Sequence[Type],
+      target: Union[Operation, Value, Type],
+      extra_bindings: Optional[Union[Sequence[Value], Sequence[Type], Operation,
+                                     OpView]] = None,
+  ):
+    root = (_get_op_result_or_value(target) if isinstance(
+        target, (Operation, Value)) else None)
     root_type = root.type if not isinstance(target, Type) else target
     if not isinstance(failure_propagation_mode, Attribute):
       failure_propagation_mode_attr = IntegerAttr.get(
           IntegerType.get_signless(32), failure_propagation_mode._as_int())
     else:
-      failure_propagation_mode = failure_propagation_mode
+      failure_propagation_mode_attr = failure_propagation_mode
 
     if extra_bindings is None:
       extra_bindings = []
@@ -114,10 +131,12 @@ def __init__(self, failure_propagation_mode, results: Sequence[Type],
       else:
         extra_binding_types = [v.type for v in extra_bindings]
 
-    super().__init__(results_=results,
-                     failure_propagation_mode=failure_propagation_mode_attr,
-                     root=root,
-                     extra_bindings=extra_bindings)
+    super().__init__(
+        results_=results,
+        failure_propagation_mode=failure_propagation_mode_attr,
+        root=root,
+        extra_bindings=extra_bindings,
+    )
     self.regions[0].blocks.append(*tuple([root_type] + extra_binding_types))
 
   @property
@@ -143,10 +162,7 @@ def __init__(self,
     root = _get_op_result_or_value(target) if not isinstance(target,
                                                              Type) else None
     root_type = target if isinstance(target, Type) else root.type
-    super().__init__(
-        root=root,
-        loc=loc,
-        ip=ip)
+    super().__init__(root=root, loc=loc, ip=ip)
     self.regions[0].blocks.append(root_type)
 
   @property
@@ -160,9 +176,13 @@ def bodyTarget(self) -> Value:
 
 class YieldOp:
 
-  def __init__(self,
-               operands: Union[Operation, Sequence[Value]] = [],
-               *,
-               loc=None,
-               ip=None):
+  def __init__(
+      self,
+      operands: Optional[Union[Operation, Sequence[Value]]] = None,
+      *,
+      loc=None,
+      ip=None,
+  ):
+    if operands is None:
+      operands = []
     super().__init__(_get_op_results_or_values(operands), loc=loc, ip=ip)

diff  --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py
index 1e24fcbf99e40..714253426b025 100644
--- a/mlir/python/mlir/ir.py
+++ b/mlir/python/mlir/ir.py
@@ -8,9 +8,11 @@
 
 # Convenience decorator for registering user-friendly Attribute builders.
 def register_attribute_builder(kind):
+
   def decorator_builder(func):
     AttrBuilder.insert(kind, func)
     return func
+
   return decorator_builder
 
 
@@ -18,34 +20,77 @@ def decorator_builder(func):
 def _boolAttr(x, context):
   return BoolAttr.get(x, context=context)
 
+
 @register_attribute_builder("IndexAttr")
 def _indexAttr(x, context):
   return IntegerAttr.get(IndexType.get(context=context), x)
 
+
+@register_attribute_builder("I16Attr")
+def _i32Attr(x, context):
+  return IntegerAttr.get(IntegerType.get_signless(16, context=context), x)
+
+
 @register_attribute_builder("I32Attr")
 def _i32Attr(x, context):
-  return IntegerAttr.get(
-      IntegerType.get_signless(32, context=context), x)
+  return IntegerAttr.get(IntegerType.get_signless(32, context=context), x)
+
 
 @register_attribute_builder("I64Attr")
 def _i64Attr(x, context):
-  return IntegerAttr.get(
-      IntegerType.get_signless(64, context=context), x)
+  return IntegerAttr.get(IntegerType.get_signless(64, context=context), x)
+
 
 @register_attribute_builder("StrAttr")
 def _stringAttr(x, context):
   return StringAttr.get(x, context=context)
 
+
 @register_attribute_builder("SymbolNameAttr")
 def _symbolNameAttr(x, context):
   return StringAttr.get(x, context=context)
 
+
+@register_attribute_builder("SymbolRefAttr")
+def _symbolRefAttr(x, context):
+  return FlatSymbolRefAttr.get(x, context=context)
+
+
+@register_attribute_builder("ArrayAttr")
+def _arrayAttr(x, context):
+  return ArrayAttr.get(x, context=context)
+
+
+@register_attribute_builder("I64ArrayAttr")
+def _i64ArrayAttr(x, context):
+  return ArrayAttr.get([_i64Attr(v, context) for v in x])
+
+
+@register_attribute_builder("DenseI64ArrayAttr")
+def _denseI64ArrayAttr(x, context):
+  return DenseI64ArrayAttr.get(x, context=context)
+
+
+@register_attribute_builder("TypeAttr")
+def _typeAttr(x, context):
+  return TypeAttr.get(x, context=context)
+
+
+@register_attribute_builder("TypeArrayAttr")
+def _typeArrayAttr(x, context):
+  return _arrayAttr([TypeAttr.get(t, context=context) for t in x], context)
+
+
 try:
   import numpy as np
+
   @register_attribute_builder("IndexElementsAttr")
   def _indexElementsAttr(x, context):
     return DenseElementsAttr.get(
-        np.array(x, dtype=np.int64), type=IndexType.get(context=context),
-        context=context)
+        np.array(x, dtype=np.int64),
+        type=IndexType.get(context=context),
+        context=context,
+    )
+
 except ImportError:
   pass


        


More information about the Mlir-commits mailing list