[clang] Fix all mypy --strict errors in clang python binding (PR #101784)

via cfe-commits cfe-commits at lists.llvm.org
Fri Aug 2 20:26:27 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-clang

Author: None (TsXor)

<details>
<summary>Changes</summary>

Related: https://github.com/llvm/llvm-project/issues/76664

I used metadata reflection so that we can import C library functions just by declaring annotated Python functions. This makes the C function types visible to the type checker, which in turn makes most of the typing errors easy to fix.

---

Patch is 153.58 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/101784.diff


5 Files Affected:

- (modified) clang/bindings/python/clang/cindex.py (+1145-751) 
- (added) clang/bindings/python/clang/ctyped.py (+334) 
- (modified) clang/bindings/python/tests/cindex/test_type.py (+1-1) 
- (added) clang/bindings/python/tests/ctyped/__init__.py () 
- (added) clang/bindings/python/tests/ctyped/test_stub_conversion.py (+359) 


``````````diff
diff --git a/clang/bindings/python/clang/cindex.py b/clang/bindings/python/clang/cindex.py
index 2038ef6045c7d..521dc2829ae41 100644
--- a/clang/bindings/python/clang/cindex.py
+++ b/clang/bindings/python/clang/cindex.py
@@ -62,36 +62,50 @@
 #
 # o implement additional SourceLocation, SourceRange, and File methods.
 
-from ctypes import *
+from ctypes import (c_byte, c_ubyte, c_short, c_ushort, c_int, c_uint, c_long,  # pyright: ignore[reportUnusedImport]
+                    c_ulong, c_longlong,c_ulonglong, c_size_t, c_ssize_t,  # pyright: ignore[reportUnusedImport]
+                    c_bool, c_char, c_wchar, c_float, c_double, c_longdouble,  # pyright: ignore[reportUnusedImport]
+                    c_char_p, c_wchar_p, c_void_p)  # pyright: ignore[reportUnusedImport]
+from ctypes import py_object, Structure, POINTER, byref, cast, cdll
+from .ctyped import *
+from .ctyped import ANNO_CONVERTIBLE, generate_metadata
 
 import os
 import sys
 from enum import Enum
 
 from typing import (
+    cast as tcast,
     Any,
     Callable,
+    Dict,
+    Generator,
     Generic,
+    Iterator,
+    List,
     Optional,
+    Tuple,
     Type as TType,
     TypeVar,
     TYPE_CHECKING,
     Union as TUnion,
 )
 
+from typing_extensions import Annotated
+
 if TYPE_CHECKING:
-    from ctypes import _Pointer
-    from typing_extensions import Protocol, TypeAlias
+    from typing_extensions import Protocol, Self, TypeAlias
+    from ctypes import CDLL
 
     StrPath: TypeAlias = TUnion[str, os.PathLike[str]]
-    LibFunc: TypeAlias = TUnion[
-        "tuple[str, Optional[list[Any]]]",
-        "tuple[str, Optional[list[Any]], Any]",
-        "tuple[str, Optional[list[Any]], Any, Callable[..., Any]]",
-    ]
-
+    StrOrBytes: TypeAlias = TUnion[str, bytes]
+    FsPath: TypeAlias = TUnion[StrOrBytes, os.PathLike[str]]
     TSeq = TypeVar("TSeq", covariant=True)
 
+    class SupportsReadStringData(Protocol):
+        def read(self) -> str | bytes:
+            ...
+
     class NoSliceSequence(Protocol[TSeq]):
         def __len__(self) -> int:
             ...
@@ -102,7 +116,7 @@ def __getitem__(self, key: int) -> TSeq:
 
 # Python 3 strings are unicode, translate them to/from utf8 for C-interop.
 class c_interop_string(c_char_p):
-    def __init__(self, p: str | bytes | None = None):
+    def __init__(self, p: 'CInteropString' = None):
         if p is None:
             p = ""
         if isinstance(p, str):
@@ -120,7 +134,7 @@ def value(self) -> str | None:  # type: ignore [override]
         return val.decode("utf8")
 
     @classmethod
-    def from_param(cls, param: str | bytes | None) -> c_interop_string:
+    def from_param(cls, param: 'CInteropString') -> c_interop_string:
         if isinstance(param, str):
             return cls(param)
         if isinstance(param, bytes):
@@ -136,6 +150,8 @@ def from_param(cls, param: str | bytes | None) -> c_interop_string:
     def to_python_string(x: c_interop_string, *args: Any) -> str | None:
         return x.value
 
+CInteropString = Annotated[TUnion[str, bytes, None], ANNO_CONVERTIBLE, c_interop_string]
+
 
 def b(x: str | bytes) -> bytes:
     if isinstance(x, bytes):
@@ -147,7 +163,8 @@ def b(x: str | bytes) -> bytes:
 # object. This is a problem, because it means that from_parameter will see an
 # integer and pass the wrong value on platforms where int != void*. Work around
 # this by marshalling object arguments as void**.
-c_object_p: TType[_Pointer[Any]] = POINTER(c_void_p)
+CObjectP = CPointer[c_void_p]
+c_object_p: TType[CObjectP] = convert_annotation(CObjectP)
 
 ### Exception Classes ###
 
@@ -183,7 +200,7 @@ class TranslationUnitSaveError(Exception):
     # Indicates that the translation unit was somehow invalid.
     ERROR_INVALID_TU = 3
 
-    def __init__(self, enumeration, message):
+    def __init__(self, enumeration: int, message: str):
         assert isinstance(enumeration, int)
 
         if enumeration < 1 or enumeration > 3:
@@ -241,7 +258,7 @@ def __del__(self) -> None:
         conf.lib.clang_disposeString(self)
 
     @staticmethod
-    def from_result(res: _CXString, fn: Any = None, args: Any = None) -> str:
+    def from_result(res: _CXString, fn: Optional[Callable[..., _CXString]] = None, args: Optional[Tuple[Any, ...]] = None) -> str:
         assert isinstance(res, _CXString)
         pystr: str | None = conf.lib.clang_getCString(res)
         if pystr is None:
@@ -255,71 +272,73 @@ class SourceLocation(Structure):
     """
 
     _fields_ = [("ptr_data", c_void_p * 2), ("int_data", c_uint)]
-    _data = None
+    _data: Optional[Tuple[Optional[File], int, int, int]] = None
 
-    def _get_instantiation(self):
+    def _get_instantiation(self) -> Tuple[Optional[File], int, int, int]:
         if self._data is None:
-            f, l, c, o = c_object_p(), c_uint(), c_uint(), c_uint()
+            fp, l, c, o = c_object_p(), c_uint(), c_uint(), c_uint()
             conf.lib.clang_getInstantiationLocation(
-                self, byref(f), byref(l), byref(c), byref(o)
+                self, byref(fp), byref(l), byref(c), byref(o)
             )
-            if f:
-                f = File(f)
+            if fp:
+                f = File(fp)
             else:
                 f = None
             self._data = (f, int(l.value), int(c.value), int(o.value))
         return self._data
 
     @staticmethod
-    def from_position(tu, file, line, column):
+    def from_position(tu: TranslationUnit, file: File, line: int, column: int) -> SourceLocation:
         """
         Retrieve the source location associated with a given file/line/column in
         a particular translation unit.
         """
-        return conf.lib.clang_getLocation(tu, file, line, column)  # type: ignore [no-any-return]
+        return conf.lib.clang_getLocation(tu, file, line, column)
 
     @staticmethod
-    def from_offset(tu, file, offset):
+    def from_offset(tu: TranslationUnit, file: File, offset: int) -> SourceLocation:
         """Retrieve a SourceLocation from a given character offset.
 
         tu -- TranslationUnit file belongs to
         file -- File instance to obtain offset from
         offset -- Integer character offset within file
         """
-        return conf.lib.clang_getLocationForOffset(tu, file, offset)  # type: ignore [no-any-return]
+        return conf.lib.clang_getLocationForOffset(tu, file, offset)
 
     @property
-    def file(self):
+    def file(self) -> Optional[File]:
         """Get the file represented by this source location."""
         return self._get_instantiation()[0]
 
     @property
-    def line(self):
+    def line(self) -> int:
         """Get the line represented by this source location."""
         return self._get_instantiation()[1]
 
     @property
-    def column(self):
+    def column(self) -> int:
         """Get the column represented by this source location."""
         return self._get_instantiation()[2]
 
     @property
-    def offset(self):
+    def offset(self) -> int:
         """Get the file offset represented by this source location."""
         return self._get_instantiation()[3]
 
     @property
-    def is_in_system_header(self):
+    def is_in_system_header(self) -> bool:
         """Returns true if the given source location is in a system header."""
-        return conf.lib.clang_Location_isInSystemHeader(self)  # type: ignore [no-any-return]
+        return conf.lib.clang_Location_isInSystemHeader(self)
 
-    def __eq__(self, other):
-        return conf.lib.clang_equalLocations(self, other)  # type: ignore [no-any-return]
+    def __eq__(self, other: object) -> bool:
+        if not isinstance(other, SourceLocation):
+            return NotImplemented
+        return conf.lib.clang_equalLocations(self, other)
 
-    def __ne__(self, other):
+    def __ne__(self, other: object) -> bool:
         return not self.__eq__(other)
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         if self.file:
             filename = self.file.name
         else:
@@ -346,40 +365,43 @@ class SourceRange(Structure):
     # FIXME: Eliminate this and make normal constructor? Requires hiding ctypes
     # object.
     @staticmethod
-    def from_locations(start, end):
-        return conf.lib.clang_getRange(start, end)  # type: ignore [no-any-return]
+    def from_locations(start: SourceLocation, end: SourceLocation) -> SourceRange:
+        return conf.lib.clang_getRange(start, end)
 
     @property
-    def start(self):
+    def start(self) -> SourceLocation:
         """
         Return a SourceLocation representing the first character within a
         source range.
         """
-        return conf.lib.clang_getRangeStart(self)  # type: ignore [no-any-return]
+        return conf.lib.clang_getRangeStart(self)
 
     @property
-    def end(self):
+    def end(self) -> SourceLocation:
         """
         Return a SourceLocation representing the last character within a
         source range.
         """
-        return conf.lib.clang_getRangeEnd(self)  # type: ignore [no-any-return]
+        return conf.lib.clang_getRangeEnd(self)
 
-    def __eq__(self, other):
-        return conf.lib.clang_equalRanges(self, other)  # type: ignore [no-any-return]
+    def __eq__(self, other: object) -> bool:
+        if not isinstance(other, SourceRange):
+            return NotImplemented
+        return conf.lib.clang_equalRanges(self, other)
 
-    def __ne__(self, other):
+    def __ne__(self, other: object) -> bool:
         return not self.__eq__(other)
 
-    def __contains__(self, other):
+    def __contains__(self, other: object) -> bool:
         """Useful to detect the Token/Lexer bug"""
         if not isinstance(other, SourceLocation):
             return False
-        if other.file is None and self.start.file is None:
-            pass
-        elif (
-            self.start.file.name != other.file.name
-            or other.file.name != self.end.file.name
+        if (
+            other.file is not None
+            and self.start.file is not None
+            and self.end.file is not None
+            and (other.file.name != self.start.file.name
+                 or other.file.name != self.end.file.name)
         ):
             # same file name
             return False
@@ -396,7 +418,7 @@ def __contains__(self, other):
                 return True
         return False
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         return "<SourceRange start %r, end %r>" % (self.start, self.end)
 
 
@@ -421,23 +443,25 @@ class Diagnostic:
     DisplayCategoryName = 0x20
     _FormatOptionsMask = 0x3F
 
-    def __init__(self, ptr):
+    ptr: CObjectP
+
+    def __init__(self, ptr: CObjectP):
         self.ptr = ptr
 
-    def __del__(self):
+    def __del__(self) -> None:
         conf.lib.clang_disposeDiagnostic(self)
 
     @property
-    def severity(self):
-        return conf.lib.clang_getDiagnosticSeverity(self)  # type: ignore [no-any-return]
+    def severity(self) -> int:
+        return conf.lib.clang_getDiagnosticSeverity(self)
 
     @property
-    def location(self):
-        return conf.lib.clang_getDiagnosticLocation(self)  # type: ignore [no-any-return]
+    def location(self) -> SourceLocation:
+        return conf.lib.clang_getDiagnosticLocation(self)
 
     @property
-    def spelling(self):
-        return conf.lib.clang_getDiagnosticSpelling(self)  # type: ignore [no-any-return]
+    def spelling(self) -> str:
+        return conf.lib.clang_getDiagnosticSpelling(self)
 
     @property
     def ranges(self) -> NoSliceSequence[SourceRange]:
@@ -451,7 +475,7 @@ def __len__(self) -> int:
             def __getitem__(self, key: int) -> SourceRange:
                 if key >= len(self):
                     raise IndexError
-                return conf.lib.clang_getDiagnosticRange(self.diag, key)  # type: ignore [no-any-return]
+                return conf.lib.clang_getDiagnosticRange(self.diag, key)
 
         return RangeIterator(self)
 
@@ -492,28 +516,28 @@ def __getitem__(self, key: int) -> Diagnostic:
         return ChildDiagnosticsIterator(self)
 
     @property
-    def category_number(self):
+    def category_number(self) -> int:
         """The category number for this diagnostic or 0 if unavailable."""
-        return conf.lib.clang_getDiagnosticCategory(self)  # type: ignore [no-any-return]
+        return conf.lib.clang_getDiagnosticCategory(self)
 
     @property
-    def category_name(self):
+    def category_name(self) -> str:
         """The string name of the category for this diagnostic."""
-        return conf.lib.clang_getDiagnosticCategoryText(self)  # type: ignore [no-any-return]
+        return conf.lib.clang_getDiagnosticCategoryText(self)
 
     @property
-    def option(self):
+    def option(self) -> str:
         """The command-line option that enables this diagnostic."""
-        return conf.lib.clang_getDiagnosticOption(self, None)  # type: ignore [no-any-return]
+        return conf.lib.clang_getDiagnosticOption(self, None)
 
     @property
-    def disable_option(self):
+    def disable_option(self) -> str:
         """The command-line option that disables this diagnostic."""
         disable = _CXString()
         conf.lib.clang_getDiagnosticOption(self, byref(disable))
         return _CXString.from_result(disable)
 
-    def format(self, options=None):
+    def format(self, options: Optional[int] = None) -> str:
         """
         Format this diagnostic for display. The options argument takes
         Diagnostic.Display* flags, which can be combined using bitwise OR. If
@@ -524,19 +548,19 @@ def format(self, options=None):
             options = conf.lib.clang_defaultDiagnosticDisplayOptions()
         if options & ~Diagnostic._FormatOptionsMask:
             raise ValueError("Invalid format options")
-        return conf.lib.clang_formatDiagnostic(self, options)  # type: ignore [no-any-return]
+        return conf.lib.clang_formatDiagnostic(self, options)
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         return "<Diagnostic severity %r, location %r, spelling %r>" % (
             self.severity,
             self.location,
             self.spelling,
         )
 
-    def __str__(self):
+    def __str__(self) -> str:
         return self.format()
 
-    def from_param(self):
+    def from_param(self) -> CObjectP:
         return self.ptr
 
 
@@ -547,11 +571,14 @@ class FixIt:
     with the given value.
     """
 
-    def __init__(self, range, value):
+    range: SourceRange
+    value: str
+
+    def __init__(self, range: SourceRange, value: str):
         self.range = range
         self.value = value
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         return "<FixIt range %r, value %r>" % (self.range, self.value)
 
 
@@ -570,16 +597,20 @@ class TokenGroup:
     You should not instantiate this class outside of this module.
     """
 
-    def __init__(self, tu, memory, count):
+    _tu: TranslationUnit
+    _memory: CPointer[Token]
+    _count: c_uint
+
+    def __init__(self, tu: TranslationUnit, memory: CPointer[Token], count: c_uint):
         self._tu = tu
         self._memory = memory
         self._count = count
 
-    def __del__(self):
+    def __del__(self) -> None:
         conf.lib.clang_disposeTokens(self._tu, self._memory, self._count)
 
     @staticmethod
-    def get_tokens(tu, extent):
+    def get_tokens(tu: TranslationUnit, extent: SourceRange) -> Generator[Token, None, None]:
         """Helper method to return all tokens in an extent.
 
         This functionality is needed multiple places in this module. We define
@@ -616,16 +647,16 @@ class BaseEnumeration(Enum):
     """
     Common base class for named enumerations held in sync with Index.h values.
     """
+    value: int # pyright: ignore[reportIncompatibleMethodOverride]
 
-
-    def from_param(self):
+    def from_param(self) -> int:
         return self.value
 
     @classmethod
-    def from_id(cls, id):
+    def from_id(cls, id: int) -> Self:
         return cls(id)
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         return "%s.%s" % (
             self.__class__.__name__,
             self.name,
@@ -636,7 +667,7 @@ class TokenKind(BaseEnumeration):
     """Describes a specific type of a Token."""
 
     @classmethod
-    def from_value(cls, value):
+    def from_value(cls, value: int) -> Self:
         """Obtain a registered TokenKind instance from its value."""
         return cls.from_id(value)
 
@@ -653,45 +684,44 @@ class CursorKind(BaseEnumeration):
     """
 
     @staticmethod
-    def get_all_kinds():
+    def get_all_kinds() -> List[CursorKind]:
         """Return all CursorKind enumeration instances."""
         return list(CursorKind)
 
-    def is_declaration(self):
+    def is_declaration(self) -> bool:
         """Test if this is a declaration kind."""
-        return conf.lib.clang_isDeclaration(self)  # type: ignore [no-any-return]
+        return conf.lib.clang_isDeclaration(self)
 
-    def is_reference(self):
+    def is_reference(self) -> bool:
         """Test if this is a reference kind."""
-        return conf.lib.clang_isReference(self)  # type: ignore [no-any-return]
+        return conf.lib.clang_isReference(self)
 
-    def is_expression(self):
+    def is_expression(self) -> bool:
         """Test if this is an expression kind."""
-        return conf.lib.clang_isExpression(self)  # type: ignore [no-any-return]
-
-    def is_statement(self):
+        return conf.lib.clang_isExpression(self)
+    def is_statement(self) -> bool:
         """Test if this is a statement kind."""
-        return conf.lib.clang_isStatement(self)  # type: ignore [no-any-return]
+        return conf.lib.clang_isStatement(self)
 
-    def is_attribute(self):
+    def is_attribute(self) -> bool:
         """Test if this is an attribute kind."""
-        return conf.lib.clang_isAttribute(self)  # type: ignore [no-any-return]
+        return conf.lib.clang_isAttribute(self)
 
-    def is_invalid(self):
+    def is_invalid(self) -> bool:
         """Test if this is an invalid kind."""
-        return conf.lib.clang_isInvalid(self)  # type: ignore [no-any-return]
+        return conf.lib.clang_isInvalid(self)
 
-    def is_translation_unit(self):
+    def is_translation_unit(self) -> bool:
         """Test if this is a translation unit kind."""
-        return conf.lib.clang_isTranslationUnit(self)  # type: ignore [no-any-return]
+        return conf.lib.clang_isTranslationUnit(self)
 
-    def is_preprocessing(self):
+    def is_preprocessing(self) -> bool:
         """Test if this is a preprocessing kind."""
-        return conf.lib.clang_isPreprocessing(self)  # type: ignore [no-any-return]
+        return conf.lib.clang_isPreprocessing(self)
 
-    def is_unexposed(self):
+    def is_unexposed(self) -> bool:
         """Test if this is an unexposed kind."""
-        return conf.lib.clang_isUnexposed(self)  # type: ignore [no-any-return]
+        return conf.lib.clang_isUnexposed(self)
 
 
     ###
@@ -1555,7 +1585,7 @@ class Cursor(Structure):
     _fields_ = [("_kind_id", c_int), ("xdata", c_int), ("data", c_void_p * 3)]
 
     @staticmethod
-    def from_location(tu, location):
+    def from_location(tu: TranslationUnit, location: SourceLocation) -> Cursor:
         # We store a reference to the TU in the instance so the TU won't get
         # collected before the cursor.
         cursor = conf.lib.clang_getCursor(tu, location)
@@ -1563,54 +1593,56 @@ def from_location(tu, location):
 
         return cursor
 
-    def __eq__(self, other):
-        return conf.lib.clang_equalCursors(self, other)  # type: ignore [no-any-return]
+    def __eq__(self, other: object) -> bool:
+        if not isinstance(other, Cursor):
+            return NotImplemented
+        return conf.lib.clang_equalCursors(self, other)
 
-    def __ne__(self, other):
+    def __ne__(self, other: object) -> bool:
         return not self.__eq__(other)
 
-    def is_definition(self):
+    def is_definition(self) -> bool:
         """
         Returns true if the declaration pointed at by the cursor is also a
         definition of that entity.
         """
-        return conf.lib.clang_isCursorDefinition(self)  # type: ignore [no-any-return]
+        return conf.lib.clang_isCursorDefinition(self)
 
-    def is_const_method(self):
+    def is_const_method(self) -> bool:
         """Returns True if the cursor ...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/101784


More information about the cfe-commits mailing list