[Mlir-commits] [mlir] [mlir][python] update type stubs (PR #75099)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Dec 11 13:23:27 PST 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Maksim Levental (makslevental)
<details>
<summary>Changes</summary>
So I definitely didn't mean to get sucked into this (I just wanted to add `StridedLayoutAttr`) but here we are: I regenerated `ir.pyi` using [`pybind11-stubgen`](https://github.com/sizmailov/pybind11-stubgen) instead of mypy's `stubgen`. It did a pretty good job! Of course I had to spot check and patch things up...
---
Patch is 108.66 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/75099.diff
1 Files Affected:
- (modified) mlir/python/mlir/_mlir_libs/_mlir/ir.pyi (+2044-481)
``````````diff
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
index 2609117dd220be..577222ce79a9ea 100644
--- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
+++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
@@ -1,15 +1,63 @@
# Originally imported via:
-# stubgen {...} -m mlir._mlir_libs._mlir.ir
+# pybind11-stubgen --print-invalid-expressions-as-is mlir._mlir_libs._mlir.ir
+# but with the following diff applied (to avoid emitting PEP 604 `X | Y` union
+# types, which we can't use until the minimum supported Python is bumped to 3.10)
+#
+# --------------------- diff begins ------------------------------------
+#
+# diff --git a/pybind11_stubgen/printer.py b/pybind11_stubgen/printer.py
+# index 1f755aa..4924927 100644
+# --- a/pybind11_stubgen/printer.py
+# +++ b/pybind11_stubgen/printer.py
+# @@ -283,14 +283,6 @@ class Printer:
+# return split[0] + "..."
+#
+# def print_type(self, type_: ResolvedType) -> str:
+# - if (
+# - str(type_.name) == "typing.Optional"
+# - and type_.parameters is not None
+# - and len(type_.parameters) == 1
+# - ):
+# - return f"{self.print_annotation(type_.parameters[0])} | None"
+# - if str(type_.name) == "typing.Union" and type_.parameters is not None:
+# - return " | ".join(self.print_annotation(p) for p in type_.parameters)
+# if type_.parameters:
+# param_str = (
+# "["
+#
+# --------------------- diff ends ------------------------------------
+#
# Local modifications:
-# * Rewrite references to 'mlir.ir.' to local types
-# * Add __all__ with the following incantation:
-# egrep '^class ' ir.pyi | awk -F ' |:|\\(' '{print " \"" $2 "\","}'
+# * Rewrite references to 'mlir.ir.' to local types.
+# * Drop `typing.` everywhere (top-level import instead).
+# * list -> List, dict -> Dict, tuple -> Tuple.
+# * copy-paste Buffer type from typing_extensions.
+# * Shuffle _OperationBase, AffineExpr, Attribute, Type, Value to the top.
+# * Patch raw C++ types (like "PyAsmState") with a regex like `Py(.*)`.
+# * _BaseContext -> Context, MlirType -> Type, MlirTypeID -> TypeID, MlirAttribute -> Attribute.
# * Local edits to signatures and types that MyPy did not auto detect (or
# detected incorrectly).
+# * Add MLIRError, _GlobalDebug, _OperationBase to __all__ by hand.
+# * Fill in `Any`s where possible.
+# * black formatting.
+from __future__ import annotations
+
+import abc
+import collections
+import io
from typing import (
- Any, Callable, ClassVar, Dict, List, Optional, Sequence, Tuple,
- Type as _Type, TypeVar
+ Any,
+ Callable,
+ ClassVar,
+ Dict,
+ List,
+ Optional,
+ Sequence,
+ Tuple,
+ Type as _Type,
+ TypeVar,
+ Union,
)
from typing import overload
@@ -30,6 +78,8 @@ __all__ = [
"AffineSymbolExpr",
"ArrayAttr",
"ArrayAttributeIterator",
+ "AsmState",
+ "AttrBuilder",
"Attribute",
"BF16Type",
"Block",
@@ -40,29 +90,44 @@ __all__ = [
"BoolAttr",
"ComplexType",
"Context",
+ "DenseBoolArrayAttr",
+ "DenseBoolArrayIterator",
"DenseElementsAttr",
+ "DenseF32ArrayAttr",
+ "DenseF32ArrayIterator",
+ "DenseF64ArrayAttr",
+ "DenseF64ArrayIterator",
"DenseFPElementsAttr",
+ "DenseI16ArrayAttr",
+ "DenseI16ArrayIterator",
+ "DenseI32ArrayAttr",
+ "DenseI32ArrayIterator",
+ "DenseI64ArrayAttr",
+ "DenseI64ArrayIterator",
+ "DenseI8ArrayAttr",
+ "DenseI8ArrayIterator",
"DenseIntElementsAttr",
"DenseResourceElementsAttr",
- "Dialect",
- "DialectDescriptor",
- "Dialects",
"Diagnostic",
"DiagnosticHandler",
"DiagnosticInfo",
"DiagnosticSeverity",
+ "Dialect",
+ "DialectDescriptor",
+ "DialectRegistry",
+ "Dialects",
"DictAttr",
- "Float8E4M3FNType",
- "Float8E5M2Type",
- "Float8E4M3FNUZType",
- "Float8E4M3B11FNUZType",
- "Float8E5M2FNUZType",
"F16Type",
- "FloatTF32Type",
"F32Type",
"F64Type",
"FlatSymbolRefAttr",
+ "Float8E4M3B11FNUZType",
+ "Float8E4M3FNType",
+ "Float8E4M3FNUZType",
+ "Float8E5M2FNUZType",
+ "Float8E5M2Type",
"FloatAttr",
+ "FloatTF32Type",
"FunctionType",
"IndexType",
"InferShapedTypeOpInterface",
@@ -76,15 +141,18 @@ __all__ = [
"Location",
"MemRefType",
"Module",
- "MLIRError",
"NamedAttribute",
"NoneType",
- "OpaqueType",
"OpAttributeMap",
+ "OpOperand",
+ "OpOperandIterator",
"OpOperandList",
"OpResult",
"OpResultList",
+ "OpSuccessors",
"OpView",
+ "OpaqueAttr",
+ "OpaqueType",
"Operation",
"OperationIterator",
"OperationList",
@@ -94,11 +162,14 @@ __all__ = [
"RegionSequence",
"ShapedType",
"ShapedTypeComponents",
+ "StridedLayoutAttr",
"StringAttr",
+ "SymbolRefAttr",
"SymbolTable",
"TupleType",
"Type",
"TypeAttr",
+ "TypeID",
"UnitAttr",
"UnrankedMemRefType",
"UnrankedTensorType",
@@ -108,222 +179,561 @@ __all__ = [
"_OperationBase",
]
-# Base classes: declared first to simplify declarations below.
+if hasattr(collections.abc, "Buffer"):
+ Buffer = collections.abc.Buffer
+else:
+ class Buffer(abc.ABC):
+ pass
+
class _OperationBase:
- def detach_from_parent(self) -> OpView: ...
- def get_asm(self, binary: bool = False, large_elements_limit: Optional[int] = None, enable_debug_info: bool = False, pretty_debug_info: bool = False, print_generic_op_form: bool = False, use_local_scope: bool = False, assume_verified: bool = False) -> object: ...
- def move_after(self, other: _OperationBase) -> None: ...
- def move_before(self, other: _OperationBase) -> None: ...
- def print(self, file: Optional[Any] = None, binary: bool = False, large_elements_limit: Optional[int] = None, enable_debug_info: bool = False, pretty_debug_info: bool = False, print_generic_op_form: bool = False, use_local_scope: bool = False, assume_verified: bool = False) -> None: ...
- def verify(self) -> bool: ...
@overload
def __eq__(self, arg0: _OperationBase) -> bool: ...
@overload
- def __eq__(self, arg0: object) -> bool: ...
+ def __eq__(self, arg0: _OperationBase) -> bool: ...
def __hash__(self) -> int: ...
+ def __str__(self) -> str:
+ """
+ Returns the assembly form of the operation.
+ """
+ def clone(self, ip: InsertionPoint = None) -> OpView: ...
+ def detach_from_parent(self) -> OpView:
+ """
+ Detaches the operation from its parent block.
+ """
+ def erase(self) -> None: ...
+ def get_asm(
+ self,
+ binary: bool = False,
+ large_elements_limit: Optional[int] = None,
+ enable_debug_info: bool = False,
+ pretty_debug_info: bool = False,
+ print_generic_op_form: bool = False,
+ use_local_scope: bool = False,
+ assume_verified: bool = False,
+ ) -> Union[io.BytesIO, io.StringIO]:
+ """
+ Gets the assembly form of the operation with all options available.
+
+ Args:
+ binary: Whether to return a bytes (True) or str (False) object. Defaults to
+ False.
+ ... others ...: See the print() method for common keyword arguments for
+ configuring the printout.
+ Returns:
+ Either a bytes or str object, depending on the setting of the 'binary'
+ argument.
+ """
+ def move_after(self, other: _OperationBase) -> None:
+ """
+ Puts self immediately after the other operation in its parent block.
+ """
+ def move_before(self, other: _OperationBase) -> None:
+ """
+ Puts self immediately before the other operation in its parent block.
+ """
+ @overload
+ def print(
+ self,
+ state: AsmState,
+ file: Optional[Any] = None,
+ binary: bool = False,
+ ) -> None:
+ """
+ Prints the assembly form of the operation to a file like object.
+
+ Args:
+ file: The file like object to write to. Defaults to sys.stdout.
+ binary: Whether to write bytes (True) or str (False). Defaults to False.
+ state: AsmState capturing the operation numbering and flags.
+ """
+ @overload
+ def print(
+ self,
+ large_elements_limit: Optional[int] = None,
+ enable_debug_info: bool = False,
+ pretty_debug_info: bool = False,
+ print_generic_op_form: bool = False,
+ use_local_scope: bool = False,
+ assume_verified: bool = False,
+ file: Optional[Any] = None,
+ binary: bool = False,
+ ) -> None:
+ """
+ Prints the assembly form of the operation to a file like object.
+
+ Args:
+ file: The file like object to write to. Defaults to sys.stdout.
+ binary: Whether to write bytes (True) or str (False). Defaults to False.
+ large_elements_limit: Whether to elide elements attributes above this
+ number of elements. Defaults to None (no limit).
+ enable_debug_info: Whether to print debug/location information. Defaults
+ to False.
+ pretty_debug_info: Whether to format debug information for easier reading
+ by a human (warning: the result is unparseable).
+ print_generic_op_form: Whether to print the generic assembly forms of all
+ ops. Defaults to False.
+ use_local_scope: Whether to print in a way that is more optimized for
+ multi-threaded access but may not be consistent with how the overall
+ module prints.
+ assume_verified: By default, if not printing generic form, the verifier
+ will be run and if it fails, generic form will be printed with a comment
+ about failed verification. While a reasonable default for interactive use,
+ for systematic use, it is often better for the caller to verify explicitly
+ and report failures in a more robust fashion. Set this to True if doing this
+ in order to avoid running a redundant verification. If the IR is actually
+ invalid, behavior is undefined.
+ """
+ def verify(self) -> bool:
+ """
+ Verify the operation. Raises MLIRError if verification fails, and returns true otherwise.
+ """
+ def write_bytecode(self, file: Any, desired_version: Optional[int] = None) -> None:
+ """
+ Write the bytecode form of the operation to a file like object.
+
+ Args:
+ file: The file like object to write to.
+ desired_version: The version of bytecode to emit.
+ Returns:
+ The bytecode writer status.
+ """
@property
def _CAPIPtr(self) -> object: ...
@property
def attributes(self) -> OpAttributeMap: ...
@property
- def context(self) -> Context: ...
+ def context(self) -> Context:
+ """
+ Context that owns the Operation
+ """
@property
- def location(self) -> Location: ...
+ def location(self) -> Location:
+ """
+ Returns the source location the operation was defined or derived from.
+ """
@property
def name(self) -> str: ...
@property
def operands(self) -> OpOperandList: ...
@property
- @property
def parent(self) -> Optional[_OperationBase]: ...
+ @property
def regions(self) -> RegionSequence: ...
@property
- def result(self) -> OpResult: ...
+ def result(self) -> OpResult:
+ """
+ Shortcut to get an op result if it has only one (throws an error otherwise).
+ """
@property
- def results(self) -> OpResultList: ...
+ def results(self) -> OpResultList:
+ """
+ Returns the list of operation results.
+ """
_TOperation = TypeVar("_TOperation", bound=_OperationBase)
-# TODO: Auto-generated. Audit and fix.
class AffineExpr:
- def __init__(self, *args, **kwargs) -> None: ...
+ @staticmethod
+ @overload
+ def get_add(arg0: AffineExpr, arg1: AffineExpr) -> AffineAddExpr:
+ """
+ Gets an affine expression containing a sum of two expressions.
+ """
+ @staticmethod
+ @overload
+ def get_add(arg0: int, arg1: AffineExpr) -> AffineAddExpr:
+ """
+ Gets an affine expression containing a sum of a constant and another expression.
+ """
+ @staticmethod
+ @overload
+ def get_add(arg0: AffineExpr, arg1: int) -> AffineAddExpr:
+ """
+ Gets an affine expression containing a sum of an expression and a constant.
+ """
+ @staticmethod
+ @overload
+ def get_ceil_div(arg0: AffineExpr, arg1: AffineExpr) -> AffineCeilDivExpr:
+ """
+ Gets an affine expression containing the rounded-up result of dividing one expression by another.
+ """
+ @staticmethod
+ @overload
+ def get_ceil_div(arg0: int, arg1: AffineExpr) -> AffineCeilDivExpr:
+ """
+ Gets a semi-affine expression containing the rounded-up result of dividing a constant by an expression.
+ """
+ @staticmethod
+ @overload
+ def get_ceil_div(arg0: AffineExpr, arg1: int) -> AffineCeilDivExpr:
+ """
+ Gets an affine expression containing the rounded-up result of dividing an expression by a constant.
+ """
+ @staticmethod
+ def get_constant(
+ value: int, context: Optional[Context] = None
+ ) -> AffineConstantExpr:
+ """
+ Gets a constant affine expression with the given value.
+ """
+ @staticmethod
+ def get_dim(position: int, context: Optional[Context] = None) -> AffineDimExpr:
+ """
+ Gets an affine expression of a dimension at the given position.
+ """
+ @staticmethod
+ @overload
+ def get_floor_div(arg0: AffineExpr, arg1: AffineExpr) -> AffineFloorDivExpr:
+ """
+ Gets an affine expression containing the rounded-down result of dividing one expression by another.
+ """
+ @staticmethod
+ @overload
+ def get_floor_div(arg0: int, arg1: AffineExpr) -> AffineFloorDivExpr:
+ """
+ Gets a semi-affine expression containing the rounded-down result of dividing a constant by an expression.
+ """
+ @staticmethod
+ @overload
+ def get_floor_div(arg0: AffineExpr, arg1: int) -> AffineFloorDivExpr:
+ """
+ Gets an affine expression containing the rounded-down result of dividing an expression by a constant.
+ """
+ @staticmethod
+ @overload
+ def get_mod(arg0: AffineExpr, arg1: AffineExpr) -> AffineModExpr:
+ """
+ Gets an affine expression containing the modulo of dividing one expression by another.
+ """
+ @staticmethod
+ @overload
+ def get_mod(arg0: int, arg1: AffineExpr) -> AffineModExpr:
+ """
+ Gets a semi-affine expression containing the modulo of dividing a constant by an expression.
+ """
+ @staticmethod
+ @overload
+ def get_mod(arg0: AffineExpr, arg1: int) -> AffineModExpr:
+ """
+ Gets an affine expression containing the modulo of dividing an expression by a constant.
+ """
+ @staticmethod
+ @overload
+ def get_mul(arg0: AffineExpr, arg1: AffineExpr) -> AffineMulExpr:
+ """
+ Gets an affine expression containing a product of two expressions.
+ """
+ @staticmethod
+ @overload
+ def get_mul(arg0: int, arg1: AffineExpr) -> AffineMulExpr:
+ """
+ Gets an affine expression containing a product of a constant and another expression.
+ """
+ @staticmethod
+ @overload
+ def get_mul(arg0: AffineExpr, arg1: int) -> AffineMulExpr:
+ """
+ Gets an affine expression containing a product of an expression and a constant.
+ """
+ @staticmethod
+ def get_symbol(
+ position: int, context: Optional[Context] = None
+ ) -> AffineSymbolExpr:
+ """
+ Gets an affine expression of a symbol at the given position.
+ """
def _CAPICreate(self) -> AffineExpr: ...
- def compose(self, arg0) -> AffineExpr: ...
- def dump(self) -> None: ...
- def get_add(self, *args, **kwargs) -> Any: ...
- def get_ceil_div(self, *args, **kwargs) -> Any: ...
- def get_constant(self, *args, **kwargs) -> Any: ...
- def get_dim(self, *args, **kwargs) -> Any: ...
- def get_floor_div(self, *args, **kwargs) -> Any: ...
- def get_mod(self, *args, **kwargs) -> Any: ...
- def get_mul(self, *args, **kwargs) -> Any: ...
- def get_symbol(self, *args, **kwargs) -> Any: ...
- def __add__(self, other) -> Any: ...
+ @overload
+ def __add__(self, arg0: AffineExpr) -> AffineAddExpr: ...
+ @overload
+ def __add__(self, arg0: int) -> AffineAddExpr: ...
@overload
def __eq__(self, arg0: AffineExpr) -> bool: ...
@overload
- def __eq__(self, arg0: object) -> bool: ...
+ def __eq__(self, arg0: Any) -> bool: ...
def __hash__(self) -> int: ...
- def __mod__(self, other) -> Any: ...
- def __mul__(self, other) -> Any: ...
- def __radd__(self, other) -> Any: ...
- def __rmod__(self, other) -> Any: ...
- def __rmul__(self, other) -> Any: ...
- def __rsub__(self, other) -> Any: ...
- def __sub__(self, other) -> Any: ...
+ @overload
+ def __mod__(self, arg0: AffineExpr) -> AffineModExpr: ...
+ @overload
+ def __mod__(self, arg0: int) -> AffineModExpr: ...
+ @overload
+ def __mul__(self, arg0: AffineExpr) -> AffineMulExpr: ...
+ @overload
+ def __mul__(self, arg0: int) -> AffineMulExpr: ...
+ def __radd__(self, arg0: int) -> AffineAddExpr: ...
+ def __repr__(self) -> str: ...
+ def __rmod__(self, arg0: int) -> AffineModExpr: ...
+ def __rmul__(self, arg0: int) -> AffineMulExpr: ...
+ def __rsub__(self, arg0: int) -> AffineAddExpr: ...
+ def __str__(self) -> str: ...
+ @overload
+ def __sub__(self, arg0: AffineExpr) -> AffineAddExpr: ...
+ @overload
+ def __sub__(self, arg0: int) -> AffineAddExpr: ...
+ def compose(self, arg0: AffineMap) -> AffineExpr: ...
+ def dump(self) -> None:
+ """
+ Dumps a debug representation of the object to stderr.
+ """
@property
def _CAPIPtr(self) -> object: ...
@property
def context(self) -> Context: ...
class Attribute:
- def __init__(self, cast_from_type: Attribute) -> None: ...
- def _CAPICreate(self) -> Attribute: ...
- def dump(self) -> None: ...
- def get_named(self, *args, **kwargs) -> Any: ...
@staticmethod
- def parse(asm: str, context: Optional[Context] = None) -> Any: ...
+ def parse(asm: str, context: Optional[Context] = None) -> Attribute:
+ """
+ Parses an attribute from an assembly form. Raises an MLIRError on failure.
+ """
+ def _CAPICreate(self) -> Attribute: ...
@overload
def __eq__(self, arg0: Attribute) -> bool: ...
@overload
def __eq__(self, arg0: object) -> bool: ...
def __hash__(self) -> int: ...
+ def __init__(self, cast_from_type: Attribute) -> None:
+ """
+ Casts the passed attribute to the generic Attribute
+ """
+ def __repr__(self) -> str: ...
+ def __str__(self) -> str:
+ """
+ Returns the assembly form of the Attribute.
+ """
+ def dump(self) -> None:
+ """
+ Dumps a debug representation of the object to stderr.
+ """
+ def get_named(self, arg0: str) -> NamedAttribute:
+ """
+ Binds a name to the attribute
+ """
+ def maybe_downcast(self) -> Any: ...
@property
def _CAPIPtr(self) -> object: ...
@property
- def context(self) -> Context: ...
+ def context(self) -> Context:
+ """
+ Context that owns the Attribute
+ """
@property
def type(self) -> Type: ...
+ @property
+ def typeid(self) -> TypeID: ...
class Type:
- def __init__(self, cast_from_type: Type) -> None: ...
- def _CAPICreate(self) -> Type: ...
- def dump(self) -> None: ...
@staticmethod
- def parse(asm: str, context: Optional[Context] = None) -> Type: ...
+ def parse(asm: str, context: Optional[Context] = None) -> Type:
+ """
+ Parses the assembly form of a type.
+
+ Returns a Type object or raises an MLI...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/75099
More information about the Mlir-commits
mailing list