[Mlir-commits] [mlir] 65b2f24 - Fixed mypy type errors in MLIR Python type stubs

Alex Zinenko llvmlistbot at llvm.org
Thu Mar 31 02:56:32 PDT 2022


Author: Sergei Lebedev
Date: 2022-03-31T11:56:26+02:00
New Revision: 65b2f24c50f016828a96b9495cee491a4241f9b9

URL: https://github.com/llvm/llvm-project/commit/65b2f24c50f016828a96b9495cee491a4241f9b9
DIFF: https://github.com/llvm/llvm-project/commit/65b2f24c50f016828a96b9495cee491a4241f9b9.diff

LOG: Fixed mypy type errors in MLIR Python type stubs

This commit fixes or disables all errors reported by:

    python3 -m mypy -p mlir --show-error-codes

Note that unhashable types cannot currently be expressed in a way compatible
with typeshed. See https://github.com/python/typeshed/issues/6243 for details.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D122790

Added: 
    

Modified: 
    mlir/python/mlir/_mlir_libs/__init__.py
    mlir/python/mlir/_mlir_libs/_mlir/ir.pyi

Removed: 
    


################################################################################
diff  --git a/mlir/python/mlir/_mlir_libs/__init__.py b/mlir/python/mlir/_mlir_libs/__init__.py
index 4e2e5f453bc58..23bc502677353 100644
--- a/mlir/python/mlir/_mlir_libs/__init__.py
+++ b/mlir/python/mlir/_mlir_libs/__init__.py
@@ -2,13 +2,19 @@
 # See https://llvm.org/LICENSE.txt for license information.
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
-from typing import Sequence
+from typing import Any, Sequence
 
 import os
 
 _this_dir = os.path.dirname(__file__)
 
 
+# These submodules have no type stubs and are thus opaque to the type checker.
+_mlirConversions: Any
+_mlirTransforms: Any
+_mlirAllPassesRegistration: Any
+
+
 def get_lib_dirs() -> Sequence[str]:
   """Gets the lib directory for linking to shared libraries.
 

diff  --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
index 5bfb9202ecb8e..7b1667fa50214 100644
--- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
+++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
@@ -7,7 +7,10 @@
 #   * Local edits to signatures and types that MyPy did not auto detect (or
 #     detected incorrectly).
 
-from typing import Any, Callable, ClassVar, Dict, List, Optional, Sequence, Tuple
+from typing import (
+    Any, Callable, ClassVar, Dict, List, Optional, Sequence, Tuple,
+    Type as _Type, TypeVar
+)
 
 from typing import overload
 
@@ -121,6 +124,8 @@ class _OperationBase:
     @property
     def results(self) -> OpResultList: ...
 
+_TOperation = TypeVar("_TOperation", bound=_OperationBase)
+
 # TODO: Auto-generated. Audit and fix.
 class AffineExpr:
     def __init__(self, *args, **kwargs) -> None: ...
@@ -379,7 +384,7 @@ class BF16Type(Type):
     def isinstance(arg: Any) -> bool: ...
 
 class Block:
-    __hash__: ClassVar[None] = ...
+    __hash__: ClassVar[None] = ...  # type: ignore
     def append(self, operation: _OperationBase) -> None: ...
     def create_after(self, *args: Type) -> Block: ...
     @staticmethod
@@ -406,7 +411,7 @@ class BlockArgument(Value):
     @property
     def arg_number(self) -> int: ...
     @property
-    def owner(self) -> Block: ...
+    def owner(self) -> Block: ...  # type: ignore[override]
 
 class BlockArgumentList:
     def __add__(self, arg0: BlockArgumentList) -> List[BlockArgument]: ...
@@ -463,7 +468,7 @@ class Context:
     def _get_live_operation_count(self) -> int: ...
     def attach_diagnostic_handler(self, callback: Callable[[Diagnostic], bool]) -> DiagnosticHandler: ...
     def enable_multithreading(self, enable: bool) -> None: ...
-    def get_dialect_descriptor(dialect_name: str) -> DialectDescriptor: ...
+    def get_dialect_descriptor(self, dialect_name: str) -> DialectDescriptor: ...
     def is_registered_operation(self, operation_name: str) -> bool: ...
     def __enter__(self) -> Context: ...
     def __exit__(self, arg0: object, arg1: object, arg2: object) -> None: ...
@@ -748,7 +753,7 @@ class IntegerType(Type):
 
 class Location:
     current: ClassVar[Location] = ...  # read-only
-    __hash__: ClassVar[None] = ...
+    __hash__: ClassVar[None] = ...  # type: ignore
     def _CAPICreate(self) -> Location: ...
     @staticmethod
     def callsite(callee: Location, frames: Sequence[Location], context: Optional[Context] = None) -> Location: ...
@@ -787,6 +792,7 @@ class MemRefType(ShapedType):
 
 class Module:
     def _CAPICreate(self) -> object: ...
+    @staticmethod
     def create(loc: Optional[Location] = None) -> Module: ...
     def dump(self) -> None: ...
     @staticmethod
@@ -858,17 +864,19 @@ class OpView(_OperationBase):
     _ODS_RESULT_SEGMENTS: ClassVar[None] = ...
     def __init__(self, operation: _OperationBase) -> None: ...
     @classmethod
-    def build_generic(cls, results: Optional[Sequence[Type]] = None,
+    def build_generic(
+        cls: _Type[_TOperation],
+        results: Optional[Sequence[Type]] = None,
         operands: Optional[Sequence[Value]] = None,
         attributes: Optional[Dict[str, Attribute]] = None,
         successors: Optional[Sequence[Block]] = None,
         regions: Optional[int] = None,
         loc: Optional[Location] = None,
-        ip: Optional[InsertionPoint] = None) -> _OperationBase: ...
+        ip: Optional[InsertionPoint] = None) -> _TOperation: ...
     @property
     def context(self) -> Context: ...
     @property
-    def operation(self) -> _OperationBase: ...
+    def operation(self) -> Operation: ...
 
 class Operation(_OperationBase):
     def _CAPICreate(self) -> object: ...
@@ -912,7 +920,7 @@ class RankedTensorType(ShapedType):
     def encoding(self) -> Optional[Attribute]: ...
 
 class Region:
-    __hash__: ClassVar[None] = ...
+    __hash__: ClassVar[None] = ...  # type: ignore
     @overload
     def __eq__(self, arg0: Region) -> bool: ...
     @overload


        


More information about the Mlir-commits mailing list