[Mlir-commits] [mlir] [mlir][python] update type stubs (PR #75099)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Dec 11 13:23:27 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Maksim Levental (makslevental)

<details>
<summary>Changes</summary>

So I definitely didn't mean to get sucked into this (I just wanted to add `StridedLayoutAttr`) but here we are: I regenerated `ir.pyi` using [`pybind11-stubgen`](https://github.com/sizmailov/pybind11-stubgen) instead of mypy's `stubgen`. It did a pretty good job! Of course I had to spot check and patch things up...

---

Patch is 108.66 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/75099.diff


1 Files Affected:

- (modified) mlir/python/mlir/_mlir_libs/_mlir/ir.pyi (+2044-481) 


``````````diff
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
index 2609117dd220be..577222ce79a9ea 100644
--- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
+++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
@@ -1,15 +1,63 @@
 # Originally imported via:
-#   stubgen {...} -m mlir._mlir_libs._mlir.ir
+#   pybind11-stubgen --print-invalid-expressions-as-is mlir._mlir_libs._mlir.ir
+# but with the following diff (in order to remove pipes from types,
+# which we won't support until bumping minimum python to 3.10)
+#
+# --------------------- diff begins ------------------------------------
+#
+# diff --git a/pybind11_stubgen/printer.py b/pybind11_stubgen/printer.py
+# index 1f755aa..4924927 100644
+# --- a/pybind11_stubgen/printer.py
+# +++ b/pybind11_stubgen/printer.py
+# @@ -283,14 +283,6 @@ class Printer:
+#              return split[0] + "..."
+#
+#      def print_type(self, type_: ResolvedType) -> str:
+# -        if (
+# -            str(type_.name) == "typing.Optional"
+# -            and type_.parameters is not None
+# -            and len(type_.parameters) == 1
+# -        ):
+# -            return f"{self.print_annotation(type_.parameters[0])} | None"
+# -        if str(type_.name) == "typing.Union" and type_.parameters is not None:
+# -            return " | ".join(self.print_annotation(p) for p in type_.parameters)
+#          if type_.parameters:
+#              param_str = (
+#                  "["
+#
+# --------------------- diff ends ------------------------------------
+#
 # Local modifications:
-#   * Rewrite references to 'mlir.ir.' to local types
-#   * Add __all__ with the following incantation:
-#       egrep '^class ' ir.pyi | awk -F ' |:|\\(' '{print "    \"" $2 "\","}'
+#   * Rewrite references to 'mlir.ir.' to local types.
+#   * Drop `typing.` everywhere (top-level import instead).
+#   * list -> List, dict -> Dict, tuple -> Tuple.
+#   * copy-paste Buffer type from typing_extensions.
+#   * Shuffle _OperationBase, AffineExpr, Attribute, Type, Value to the top.
+#   * Patch raw C++ types (like "PyAsmState") with a regex like `Py(.*)`.
+#   * _BaseContext -> Context, MlirType -> Type, MlirTypeID -> TypeID, MlirAttribute -> Attribute.
 #   * Local edits to signatures and types that MyPy did not auto detect (or
 #     detected incorrectly).
+#   * Add MLIRError, _GlobalDebug, _OperationBase to __all__ by hand.
+#   * Fill in `Any`s where possible.
+#   * black formatting.
 
+from __future__ import annotations
+
+import abc
+import collections
+import io
 from typing import (
-    Any, Callable, ClassVar, Dict, List, Optional, Sequence, Tuple,
-    Type as _Type, TypeVar
+    Any,
+    Callable,
+    ClassVar,
+    Dict,
+    List,
+    Optional,
+    Sequence,
+    Tuple,
+    Type as _Type,
+    TypeVar,
+    Union,
 )
 
 from typing import overload
@@ -30,6 +78,8 @@ __all__ = [
     "AffineSymbolExpr",
     "ArrayAttr",
     "ArrayAttributeIterator",
+    "AsmState",
+    "AttrBuilder",
     "Attribute",
     "BF16Type",
     "Block",
@@ -40,29 +90,44 @@ __all__ = [
     "BoolAttr",
     "ComplexType",
     "Context",
+    "DenseBoolArrayAttr",
+    "DenseBoolArrayIterator",
     "DenseElementsAttr",
+    "DenseF32ArrayAttr",
+    "DenseF32ArrayIterator",
+    "DenseF64ArrayAttr",
+    "DenseF64ArrayIterator",
     "DenseFPElementsAttr",
+    "DenseI16ArrayAttr",
+    "DenseI16ArrayIterator",
+    "DenseI32ArrayAttr",
+    "DenseI32ArrayIterator",
+    "DenseI64ArrayAttr",
+    "DenseI64ArrayIterator",
+    "DenseI8ArrayAttr",
+    "DenseI8ArrayIterator",
     "DenseIntElementsAttr",
     "DenseResourceElementsAttr",
-    "Dialect",
-    "DialectDescriptor",
-    "Dialects",
     "Diagnostic",
     "DiagnosticHandler",
     "DiagnosticInfo",
     "DiagnosticSeverity",
+    "Dialect",
+    "DialectDescriptor",
+    "DialectRegistry",
+    "Dialects",
     "DictAttr",
-    "Float8E4M3FNType",
-    "Float8E5M2Type",
-    "Float8E4M3FNUZType",
-    "Float8E4M3B11FNUZType",
-    "Float8E5M2FNUZType",
     "F16Type",
-    "FloatTF32Type",
     "F32Type",
     "F64Type",
     "FlatSymbolRefAttr",
+    "Float8E4M3B11FNUZType",
+    "Float8E4M3FNType",
+    "Float8E4M3FNUZType",
+    "Float8E5M2FNUZType",
+    "Float8E5M2Type",
     "FloatAttr",
+    "FloatTF32Type",
     "FunctionType",
     "IndexType",
     "InferShapedTypeOpInterface",
@@ -76,15 +141,18 @@ __all__ = [
     "Location",
     "MemRefType",
     "Module",
-    "MLIRError",
     "NamedAttribute",
     "NoneType",
-    "OpaqueType",
     "OpAttributeMap",
+    "OpOperand",
+    "OpOperandIterator",
     "OpOperandList",
     "OpResult",
     "OpResultList",
+    "OpSuccessors",
     "OpView",
+    "OpaqueAttr",
+    "OpaqueType",
     "Operation",
     "OperationIterator",
     "OperationList",
@@ -94,11 +162,14 @@ __all__ = [
     "RegionSequence",
     "ShapedType",
     "ShapedTypeComponents",
+    "StridedLayoutAttr",
     "StringAttr",
+    "SymbolRefAttr",
     "SymbolTable",
     "TupleType",
     "Type",
     "TypeAttr",
+    "TypeID",
     "UnitAttr",
     "UnrankedMemRefType",
     "UnrankedTensorType",
@@ -108,222 +179,561 @@ __all__ = [
     "_OperationBase",
 ]
 
-# Base classes: declared first to simplify declarations below.
+if hasattr(collections.abc, "Buffer"):
+    Buffer = collections.abc.Buffer
+else:
+    class Buffer(abc.ABC):
+        pass
+
 class _OperationBase:
-    def detach_from_parent(self) -> OpView: ...
-    def get_asm(self, binary: bool = False, large_elements_limit: Optional[int] = None, enable_debug_info: bool = False, pretty_debug_info: bool = False, print_generic_op_form: bool = False, use_local_scope: bool = False, assume_verified: bool = False) -> object: ...
-    def move_after(self, other: _OperationBase) -> None: ...
-    def move_before(self, other: _OperationBase) -> None: ...
-    def print(self, file: Optional[Any] = None, binary: bool = False, large_elements_limit: Optional[int] = None, enable_debug_info: bool = False, pretty_debug_info: bool = False, print_generic_op_form: bool = False, use_local_scope: bool = False, assume_verified: bool = False) -> None: ...
-    def verify(self) -> bool: ...
     @overload
     def __eq__(self, arg0: _OperationBase) -> bool: ...
     @overload
-    def __eq__(self, arg0: object) -> bool: ...
+    def __eq__(self, arg0: _OperationBase) -> bool: ...
     def __hash__(self) -> int: ...
+    def __str__(self) -> str:
+        """
+        Returns the assembly form of the operation.
+        """
+    def clone(self, ip: InsertionPoint = None) -> OpView: ...
+    def detach_from_parent(self) -> OpView:
+        """
+        Detaches the operation from its parent block.
+        """
+    def erase(self) -> None: ...
+    def get_asm(
+        self,
+        binary: bool = False,
+        large_elements_limit: Optional[int] = None,
+        enable_debug_info: bool = False,
+        pretty_debug_info: bool = False,
+        print_generic_op_form: bool = False,
+        use_local_scope: bool = False,
+        assume_verified: bool = False,
+    ) -> Union[io.BytesIO, io.StringIO]:
+        """
+        Gets the assembly form of the operation with all options available.
+
+        Args:
+          binary: Whether to return a bytes (True) or str (False) object. Defaults to
+            False.
+          ... others ...: See the print() method for common keyword arguments for
+            configuring the printout.
+        Returns:
+          Either a bytes or str object, depending on the setting of the 'binary'
+          argument.
+        """
+    def move_after(self, other: _OperationBase) -> None:
+        """
+        Puts self immediately after the other operation in its parent block.
+        """
+    def move_before(self, other: _OperationBase) -> None:
+        """
+        Puts self immediately before the other operation in its parent block.
+        """
+    @overload
+    def print(
+        self,
+        state: AsmState,
+        file: Optional[Any] = None,
+        binary: bool = False,
+    ) -> None:
+        """
+        Prints the assembly form of the operation to a file like object.
+
+        Args:
+          file: The file like object to write to. Defaults to sys.stdout.
+          binary: Whether to write bytes (True) or str (False). Defaults to False.
+          state: AsmState capturing the operation numbering and flags.
+        """
+    @overload
+    def print(
+        self,
+        large_elements_limit: Optional[int] = None,
+        enable_debug_info: bool = False,
+        pretty_debug_info: bool = False,
+        print_generic_op_form: bool = False,
+        use_local_scope: bool = False,
+        assume_verified: bool = False,
+        file: Optional[Any] = None,
+        binary: bool = False,
+    ) -> None:
+        """
+        Prints the assembly form of the operation to a file like object.
+
+        Args:
+          file: The file like object to write to. Defaults to sys.stdout.
+          binary: Whether to write bytes (True) or str (False). Defaults to False.
+          large_elements_limit: Whether to elide elements attributes above this
+            number of elements. Defaults to None (no limit).
+          enable_debug_info: Whether to print debug/location information. Defaults
+            to False.
+          pretty_debug_info: Whether to format debug information for easier reading
+            by a human (warning: the result is unparseable).
+          print_generic_op_form: Whether to print the generic assembly forms of all
+            ops. Defaults to False.
+          use_local_scope: Whether to print in a way that is more optimized for
+            multi-threaded access but may not be consistent with how the overall
+            module prints.
+          assume_verified: By default, if not printing generic form, the verifier
+            will be run and if it fails, generic form will be printed with a comment
+            about failed verification. While a reasonable default for interactive use,
+            for systematic use, it is often better for the caller to verify explicitly
+            and report failures in a more robust fashion. Set this to True if doing this
+            in order to avoid running a redundant verification. If the IR is actually
+            invalid, behavior is undefined.
+        """
+    def verify(self) -> bool:
+        """
+        Verify the operation. Raises MLIRError if verification fails, and returns true otherwise.
+        """
+    def write_bytecode(self, file: Any, desired_version: Optional[int] = None) -> None:
+        """
+        Write the bytecode form of the operation to a file like object.
+
+        Args:
+          file: The file like object to write to.
+          desired_version: The version of bytecode to emit.
+        Returns:
+          The bytecode writer status.
+        """
     @property
     def _CAPIPtr(self) -> object: ...
     @property
     def attributes(self) -> OpAttributeMap: ...
     @property
-    def context(self) -> Context: ...
+    def context(self) -> Context:
+        """
+        Context that owns the Operation
+        """
     @property
-    def location(self) -> Location: ...
+    def location(self) -> Location:
+        """
+        Returns the source location the operation was defined or derived from.
+        """
     @property
     def name(self) -> str: ...
     @property
     def operands(self) -> OpOperandList: ...
     @property
-    @property
     def parent(self) -> Optional[_OperationBase]: ...
+    @property
     def regions(self) -> RegionSequence: ...
     @property
-    def result(self) -> OpResult: ...
+    def result(self) -> OpResult:
+        """
+        Shortcut to get an op result if it has only one (throws an error otherwise).
+        """
     @property
-    def results(self) -> OpResultList: ...
+    def results(self) -> OpResultList:
+        """
+        Returns the list of Operation results.
+        """
 
 _TOperation = TypeVar("_TOperation", bound=_OperationBase)
 
-# TODO: Auto-generated. Audit and fix.
 class AffineExpr:
-    def __init__(self, *args, **kwargs) -> None: ...
+    @staticmethod
+    @overload
+    def get_add(arg0: AffineExpr, arg1: AffineExpr) -> AffineAddExpr:
+        """
+        Gets an affine expression containing a sum of two expressions.
+        """
+    @staticmethod
+    @overload
+    def get_add(arg0: int, arg1: AffineExpr) -> AffineAddExpr:
+        """
+        Gets an affine expression containing a sum of a constant and another expression.
+        """
+    @staticmethod
+    @overload
+    def get_add(arg0: AffineExpr, arg1: int) -> AffineAddExpr:
+        """
+        Gets an affine expression containing a sum of an expression and a constant.
+        """
+    @staticmethod
+    @overload
+    def get_ceil_div(arg0: AffineExpr, arg1: AffineExpr) -> AffineCeilDivExpr:
+        """
+        Gets an affine expression containing the rounded-up result of dividing one expression by another.
+        """
+    @staticmethod
+    @overload
+    def get_ceil_div(arg0: int, arg1: AffineExpr) -> AffineCeilDivExpr:
+        """
+        Gets a semi-affine expression containing the rounded-up result of dividing a constant by an expression.
+        """
+    @staticmethod
+    @overload
+    def get_ceil_div(arg0: AffineExpr, arg1: int) -> AffineCeilDivExpr:
+        """
+        Gets an affine expression containing the rounded-up result of dividing an expression by a constant.
+        """
+    @staticmethod
+    def get_constant(
+        value: int, context: Optional[Context] = None
+    ) -> AffineConstantExpr:
+        """
+        Gets a constant affine expression with the given value.
+        """
+    @staticmethod
+    def get_dim(position: int, context: Optional[Context] = None) -> AffineDimExpr:
+        """
+        Gets an affine expression of a dimension at the given position.
+        """
+    @staticmethod
+    @overload
+    def get_floor_div(arg0: AffineExpr, arg1: AffineExpr) -> AffineFloorDivExpr:
+        """
+        Gets an affine expression containing the rounded-down result of dividing one expression by another.
+        """
+    @staticmethod
+    @overload
+    def get_floor_div(arg0: int, arg1: AffineExpr) -> AffineFloorDivExpr:
+        """
+        Gets a semi-affine expression containing the rounded-down result of dividing a constant by an expression.
+        """
+    @staticmethod
+    @overload
+    def get_floor_div(arg0: AffineExpr, arg1: int) -> AffineFloorDivExpr:
+        """
+        Gets an affine expression containing the rounded-down result of dividing an expression by a constant.
+        """
+    @staticmethod
+    @overload
+    def get_mod(arg0: AffineExpr, arg1: AffineExpr) -> AffineModExpr:
+        """
+        Gets an affine expression containing the modulo of dividing one expression by another.
+        """
+    @staticmethod
+    @overload
+    def get_mod(arg0: int, arg1: AffineExpr) -> AffineModExpr:
+        """
+        Gets a semi-affine expression containing the modulo of dividing a constant by an expression.
+        """
+    @staticmethod
+    @overload
+    def get_mod(arg0: AffineExpr, arg1: int) -> AffineModExpr:
+        """
+        Gets an affine expression containing the modulo of dividing an expression by a constant.
+        """
+    @staticmethod
+    @overload
+    def get_mul(arg0: AffineExpr, arg1: AffineExpr) -> AffineMulExpr:
+        """
+        Gets an affine expression containing a product of two expressions.
+        """
+    @staticmethod
+    @overload
+    def get_mul(arg0: int, arg1: AffineExpr) -> AffineMulExpr:
+        """
+        Gets an affine expression containing a product of a constant and another expression.
+        """
+    @staticmethod
+    @overload
+    def get_mul(arg0: AffineExpr, arg1: int) -> AffineMulExpr:
+        """
+        Gets an affine expression containing a product of an expression and a constant.
+        """
+    @staticmethod
+    def get_symbol(
+        position: int, context: Optional[Context] = None
+    ) -> AffineSymbolExpr:
+        """
+        Gets an affine expression of a symbol at the given position.
+        """
     def _CAPICreate(self) -> AffineExpr: ...
-    def compose(self, arg0) -> AffineExpr: ...
-    def dump(self) -> None: ...
-    def get_add(self, *args, **kwargs) -> Any: ...
-    def get_ceil_div(self, *args, **kwargs) -> Any: ...
-    def get_constant(self, *args, **kwargs) -> Any: ...
-    def get_dim(self, *args, **kwargs) -> Any: ...
-    def get_floor_div(self, *args, **kwargs) -> Any: ...
-    def get_mod(self, *args, **kwargs) -> Any: ...
-    def get_mul(self, *args, **kwargs) -> Any: ...
-    def get_symbol(self, *args, **kwargs) -> Any: ...
-    def __add__(self, other) -> Any: ...
+    @overload
+    def __add__(self, arg0: AffineExpr) -> AffineAddExpr: ...
+    @overload
+    def __add__(self, arg0: int) -> AffineAddExpr: ...
     @overload
     def __eq__(self, arg0: AffineExpr) -> bool: ...
     @overload
-    def __eq__(self, arg0: object) -> bool: ...
+    def __eq__(self, arg0: Any) -> bool: ...
     def __hash__(self) -> int: ...
-    def __mod__(self, other) -> Any: ...
-    def __mul__(self, other) -> Any: ...
-    def __radd__(self, other) -> Any: ...
-    def __rmod__(self, other) -> Any: ...
-    def __rmul__(self, other) -> Any: ...
-    def __rsub__(self, other) -> Any: ...
-    def __sub__(self, other) -> Any: ...
+    @overload
+    def __mod__(self, arg0: AffineExpr) -> AffineModExpr: ...
+    @overload
+    def __mod__(self, arg0: int) -> AffineModExpr: ...
+    @overload
+    def __mul__(self, arg0: AffineExpr) -> AffineMulExpr: ...
+    @overload
+    def __mul__(self, arg0: int) -> AffineMulExpr: ...
+    def __radd__(self, arg0: int) -> AffineAddExpr: ...
+    def __repr__(self) -> str: ...
+    def __rmod__(self, arg0: int) -> AffineModExpr: ...
+    def __rmul__(self, arg0: int) -> AffineMulExpr: ...
+    def __rsub__(self, arg0: int) -> AffineAddExpr: ...
+    def __str__(self) -> str: ...
+    @overload
+    def __sub__(self, arg0: AffineExpr) -> AffineAddExpr: ...
+    @overload
+    def __sub__(self, arg0: int) -> AffineAddExpr: ...
+    def compose(self, arg0: AffineMap) -> AffineExpr: ...
+    def dump(self) -> None:
+        """
+        Dumps a debug representation of the object to stderr.
+        """
     @property
     def _CAPIPtr(self) -> object: ...
     @property
     def context(self) -> Context: ...
 
 class Attribute:
-    def __init__(self, cast_from_type: Attribute) -> None: ...
-    def _CAPICreate(self) -> Attribute: ...
-    def dump(self) -> None: ...
-    def get_named(self, *args, **kwargs) -> Any: ...
     @staticmethod
-    def parse(asm: str, context: Optional[Context] = None) -> Any: ...
+    def parse(asm: str, context: Optional[Context] = None) -> Attribute:
+        """
+        Parses an attribute from an assembly form. Raises an MLIRError on failure.
+        """
+    def _CAPICreate(self) -> Attribute: ...
     @overload
     def __eq__(self, arg0: Attribute) -> bool: ...
     @overload
     def __eq__(self, arg0: object) -> bool: ...
     def __hash__(self) -> int: ...
+    def __init__(self, cast_from_type: Attribute) -> None:
+        """
+        Casts the passed attribute to the generic Attribute
+        """
+    def __repr__(self) -> str: ...
+    def __str__(self) -> str:
+        """
+        Returns the assembly form of the Attribute.
+        """
+    def dump(self) -> None:
+        """
+        Dumps a debug representation of the object to stderr.
+        """
+    def get_named(self, arg0: str) -> NamedAttribute:
+        """
+        Binds a name to the attribute
+        """
+    def maybe_downcast(self) -> Any: ...
     @property
     def _CAPIPtr(self) -> object: ...
     @property
-    def context(self) -> Context: ...
+    def context(self) -> Context:
+        """
+        Context that owns the Attribute
+        """
     @property
     def type(self) -> Type: ...
+    @property
+    def typeid(self) -> TypeID: ...
 
 class Type:
-    def __init__(self, cast_from_type: Type) -> None: ...
-    def _CAPICreate(self) -> Type: ...
-    def dump(self) -> None: ...
     @staticmethod
-    def parse(asm: str, context: Optional[Context] = None) -> Type: ...
+    def parse(asm: str, context: Optional[Context] = None) -> Type:
+        """
+        Parses the assembly form of a type.
+
+        Returns a Type object or raises an MLI...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/75099


More information about the Mlir-commits mailing list