[clang] [libclang/python] Fix some type errors, add type annotations (PR #98745)

Jannick Kremer via cfe-commits cfe-commits at lists.llvm.org
Sat Jul 13 07:06:29 PDT 2024


https://github.com/DeinAlptraum updated https://github.com/llvm/llvm-project/pull/98745

>From 2c31f3fe5d232381b868e96158be6f2acf7da1c6 Mon Sep 17 00:00:00 2001
From: Jannick Kremer <jannick.kremer at mailbox.org>
Date: Sat, 13 Jul 2024 14:12:34 +0100
Subject: [PATCH] [libclang/python] Fix some type errors, add type annotations

---
 clang/bindings/python/clang/cindex.py         | 192 +++++++++++-------
 .../tests/cindex/test_code_completion.py      |  22 +-
 .../python/tests/cindex/test_comment.py       |   4 +-
 3 files changed, 127 insertions(+), 91 deletions(-)

diff --git a/clang/bindings/python/clang/cindex.py b/clang/bindings/python/clang/cindex.py
index 1d3ab89190407..9b50192068213 100644
--- a/clang/bindings/python/clang/cindex.py
+++ b/clang/bindings/python/clang/cindex.py
@@ -43,7 +43,7 @@
 Most object information is exposed using properties, when the underlying API
 call is efficient.
 """
-from __future__ import absolute_import, division, print_function
+from __future__ import annotations
 
 # TODO
 # ====
@@ -64,48 +64,78 @@
 
 from ctypes import *
 
-import collections.abc
 import os
+import sys
 from enum import Enum
 
+from typing import (
+    Any,
+    Callable,
+    Generic,
+    Optional,
+    Type as TType,
+    TypeVar,
+    TYPE_CHECKING,
+    Union as TUnion,
+)
+
+if TYPE_CHECKING:
+    from ctypes import _Pointer
+    from typing_extensions import Protocol, TypeAlias
+
+    StrPath: TypeAlias = TUnion[str, os.PathLike[str]]
+    LibFunc: TypeAlias = TUnion[
+        "tuple[str, Optional[list[Any]]]",
+        "tuple[str, Optional[list[Any]], Any]",
+        "tuple[str, Optional[list[Any]], Any, Callable[..., Any]]",
+    ]
+    CObjP: TypeAlias = _Pointer[Any]
+
+    TSeq = TypeVar("TSeq", covariant=True)
+
+    class NoSliceSequence(Protocol[TSeq]):
+        def __len__(self) -> int: ...
+        def __getitem__(self, key: int) -> TSeq: ...
+
 
 # Python 3 strings are unicode, translate them to/from utf8 for C-interop.
 class c_interop_string(c_char_p):
-    def __init__(self, p=None):
+    def __init__(self, p: str | bytes | None = None):
         if p is None:
             p = ""
         if isinstance(p, str):
             p = p.encode("utf8")
         super(c_char_p, self).__init__(p)
 
-    def __str__(self):
-        return self.value
+    def __str__(self) -> str:
+        return self.value or ""
 
     @property
-    def value(self):
-        if super(c_char_p, self).value is None:
+    def value(self) -> str | None:  # type: ignore [override]
+        val = super(c_char_p, self).value
+        if val is None:
             return None
-        return super(c_char_p, self).value.decode("utf8")
+        return val.decode("utf8")
 
     @classmethod
-    def from_param(cls, param):
+    def from_param(cls, param: str | bytes | None) -> c_interop_string:
         if isinstance(param, str):
             return cls(param)
         if isinstance(param, bytes):
             return cls(param)
         if param is None:
             # Support passing null to C functions expecting char arrays
-            return None
+            return cls.__new__(cls)  # skip __init__ so None maps to NULL, not ""
         raise TypeError(
             "Cannot convert '{}' to '{}'".format(type(param).__name__, cls.__name__)
         )
 
     @staticmethod
-    def to_python_string(x, *args):
+    def to_python_string(x: c_interop_string, *args: Any) -> str | None:
         return x.value
 
 
-def b(x):
+def b(x: str | bytes) -> bytes:
     if isinstance(x, bytes):
         return x
     return x.encode("utf8")
@@ -115,9 +145,7 @@ def b(x):
 # object. This is a problem, because it means that from_parameter will see an
 # integer and pass the wrong value on platforms where int != void*. Work around
 # this by marshalling object arguments as void**.
-c_object_p = POINTER(c_void_p)
-
-callbacks = {}
+c_object_p: TType[CObjP] = POINTER(c_void_p)
 
 ### Exception Classes ###
 
@@ -169,8 +197,11 @@ def __init__(self, enumeration, message):
 
 ### Structures and Utility Classes ###
 
+TInstance = TypeVar("TInstance")
+TResult = TypeVar("TResult")
+
 
-class CachedProperty:
+class CachedProperty(Generic[TInstance, TResult]):
     """Decorator that lazy-loads the value of a property.
 
     The first time the property is accessed, the original property function is
@@ -178,16 +209,20 @@ class CachedProperty:
     property, replacing the original method.
     """
 
-    def __init__(self, wrapped):
+    def __init__(self, wrapped: Callable[[TInstance], TResult]):
         self.wrapped = wrapped
         try:
             self.__doc__ = wrapped.__doc__
         except:
             pass
 
-    def __get__(self, instance, instance_type=None):
+    def __get__(self, instance: TInstance, instance_type: Any = None) -> TResult:
         if instance is None:
-            return self
+            property_name = self.wrapped.__name__
+            class_name = instance_type.__name__
+            raise TypeError(
+                f"'{property_name}' is not a static attribute of '{class_name}'"
+            )
 
         value = self.wrapped(instance)
         setattr(instance, self.wrapped.__name__, value)
@@ -200,13 +235,16 @@ class _CXString(Structure):
 
     _fields_ = [("spelling", c_char_p), ("free", c_int)]
 
-    def __del__(self):
+    def __del__(self) -> None:
         conf.lib.clang_disposeString(self)
 
     @staticmethod
-    def from_result(res, fn=None, args=None):
+    def from_result(res: _CXString, fn: Any = None, args: Any = None) -> str:
         assert isinstance(res, _CXString)
-        return conf.lib.clang_getCString(res)
+        pystr: str | None = conf.lib.clang_getCString(res)
+        if pystr is None:
+            return ""
+        return pystr
 
 
 class SourceLocation(Structure):
@@ -400,15 +438,15 @@ def spelling(self):
         return conf.lib.clang_getDiagnosticSpelling(self)
 
     @property
-    def ranges(self):
+    def ranges(self) -> NoSliceSequence[SourceRange]:
         class RangeIterator:
-            def __init__(self, diag):
+            def __init__(self, diag: Diagnostic):
                 self.diag = diag
 
-            def __len__(self):
+            def __len__(self) -> int:
                 return int(conf.lib.clang_getDiagnosticNumRanges(self.diag))
 
-            def __getitem__(self, key):
+            def __getitem__(self, key: int) -> SourceRange:
                 if key >= len(self):
                     raise IndexError
                 return conf.lib.clang_getDiagnosticRange(self.diag, key)
@@ -416,15 +454,15 @@ def __getitem__(self, key):
         return RangeIterator(self)
 
     @property
-    def fixits(self):
+    def fixits(self) -> NoSliceSequence[FixIt]:
         class FixItIterator:
-            def __init__(self, diag):
+            def __init__(self, diag: Diagnostic):
                 self.diag = diag
 
-            def __len__(self):
+            def __len__(self) -> int:
                 return int(conf.lib.clang_getDiagnosticNumFixIts(self.diag))
 
-            def __getitem__(self, key):
+            def __getitem__(self, key: int) -> FixIt:
                 range = SourceRange()
                 value = conf.lib.clang_getDiagnosticFixIt(self.diag, key, byref(range))
                 if len(value) == 0:
@@ -435,15 +473,15 @@ def __getitem__(self, key):
         return FixItIterator(self)
 
     @property
-    def children(self):
+    def children(self) -> NoSliceSequence[Diagnostic]:
         class ChildDiagnosticsIterator:
-            def __init__(self, diag):
+            def __init__(self, diag: Diagnostic):
                 self.diag_set = conf.lib.clang_getChildDiagnostics(diag)
 
-            def __len__(self):
+            def __len__(self) -> int:
                 return int(conf.lib.clang_getNumDiagnosticsInSet(self.diag_set))
 
-            def __getitem__(self, key):
+            def __getitem__(self, key: int) -> Diagnostic:
                 diag = conf.lib.clang_getDiagnosticInSet(self.diag_set, key)
                 if not diag:
                     raise IndexError
@@ -2030,8 +2068,8 @@ def visitor(child, parent, children):
             children.append(child)
             return 1  # continue
 
-        children = []
-        conf.lib.clang_visitChildren(self, callbacks["cursor_visit"](visitor), children)
+        children: list[Cursor] = []
+        conf.lib.clang_visitChildren(self, cursor_visit_callback(visitor), children)
         return iter(children)
 
     def walk_preorder(self):
@@ -2318,25 +2356,25 @@ def kind(self):
         """Return the kind of this type."""
         return TypeKind.from_id(self._kind_id)
 
-    def argument_types(self):
+    def argument_types(self) -> NoSliceSequence[Type]:
         """Retrieve a container for the non-variadic arguments for this type.
 
         The returned object is iterable and indexable. Each item in the
         container is a Type instance.
         """
 
-        class ArgumentsIterator(collections.abc.Sequence):
-            def __init__(self, parent):
+        class ArgumentsIterator:
+            def __init__(self, parent: Type):
                 self.parent = parent
-                self.length = None
+                self.length: int | None = None
 
-            def __len__(self):
+            def __len__(self) -> int:
                 if self.length is None:
                     self.length = conf.lib.clang_getNumArgTypes(self.parent)
 
                 return self.length
 
-            def __getitem__(self, key):
+            def __getitem__(self, key: int) -> Type:
                 # FIXME Support slice objects.
                 if not isinstance(key, int):
                     raise TypeError("Must supply a non-negative int.")
@@ -2350,7 +2388,7 @@ def __getitem__(self, key):
                         "%d > %d" % (key, len(self))
                     )
 
-                result = conf.lib.clang_getArgType(self.parent, key)
+                result: Type = conf.lib.clang_getArgType(self.parent, key)
                 if result.kind == TypeKind.INVALID:
                     raise IndexError("Argument could not be retrieved.")
 
@@ -2543,10 +2581,8 @@ def visitor(field, children):
             fields.append(field)
             return 1  # continue
 
-        fields = []
-        conf.lib.clang_Type_visitFields(
-            self, callbacks["fields_visit"](visitor), fields
-        )
+        fields: list[Cursor] = []
+        conf.lib.clang_Type_visitFields(self, fields_visit_callback(visitor), fields)
         return iter(fields)
 
     def get_exception_specification_kind(self):
@@ -2820,15 +2856,15 @@ def results(self):
         return self.ptr.contents
 
     @property
-    def diagnostics(self):
+    def diagnostics(self) -> NoSliceSequence[Diagnostic]:
         class DiagnosticsItr:
-            def __init__(self, ccr):
+            def __init__(self, ccr: CodeCompletionResults):
                 self.ccr = ccr
 
-            def __len__(self):
+            def __len__(self) -> int:
                 return int(conf.lib.clang_codeCompleteGetNumDiagnostics(self.ccr))
 
-            def __getitem__(self, key):
+            def __getitem__(self, key: int) -> Diagnostic:
                 return conf.lib.clang_codeCompleteGetDiagnostic(self.ccr, key)
 
         return DiagnosticsItr(self)
@@ -3058,7 +3094,7 @@ def visitor(fobj, lptr, depth, includes):
         # Automatically adapt CIndex/ctype pointers to python objects
         includes = []
         conf.lib.clang_getInclusions(
-            self, callbacks["translation_unit_includes"](visitor), includes
+            self, translation_unit_includes_callback(visitor), includes
         )
 
         return iter(includes)
@@ -3126,19 +3162,19 @@ def get_extent(self, filename, locations):
         return SourceRange.from_locations(start_location, end_location)
 
     @property
-    def diagnostics(self):
+    def diagnostics(self) -> NoSliceSequence[Diagnostic]:
         """
         Return an iterable (and indexable) object containing the diagnostics.
         """
 
         class DiagIterator:
-            def __init__(self, tu):
+            def __init__(self, tu: TranslationUnit):
                 self.tu = tu
 
-            def __len__(self):
+            def __len__(self) -> int:
                 return int(conf.lib.clang_getNumDiagnostics(self.tu))
 
-            def __getitem__(self, key):
+            def __getitem__(self, key: int) -> Diagnostic:
                 diag = conf.lib.clang_getDiagnostic(self.tu, key)
                 if not diag:
                     raise IndexError
@@ -3570,15 +3606,15 @@ def write_main_file_to_stdout(self):
 
 # Now comes the plumbing to hook up the C library.
 
-# Register callback types in common container.
-callbacks["translation_unit_includes"] = CFUNCTYPE(
+# Register callback types
+translation_unit_includes_callback = CFUNCTYPE(
     None, c_object_p, POINTER(SourceLocation), c_uint, py_object
 )
-callbacks["cursor_visit"] = CFUNCTYPE(c_int, Cursor, Cursor, py_object)
-callbacks["fields_visit"] = CFUNCTYPE(c_int, Cursor, py_object)
+cursor_visit_callback = CFUNCTYPE(c_int, Cursor, Cursor, py_object)
+fields_visit_callback = CFUNCTYPE(c_int, Cursor, py_object)
 
 # Functions strictly alphabetical order.
-functionList = [
+functionList: list[LibFunc] = [
     (
         "clang_annotateTokens",
         [TranslationUnit, POINTER(Token), c_uint, POINTER(Cursor)],
@@ -3748,7 +3784,7 @@ def write_main_file_to_stdout(self):
     ("clang_getIncludedFile", [Cursor], c_object_p, File.from_result),
     (
         "clang_getInclusions",
-        [TranslationUnit, callbacks["translation_unit_includes"], py_object],
+        [TranslationUnit, translation_unit_includes_callback, py_object],
     ),
     (
         "clang_getInstantiationLocation",
@@ -3833,7 +3869,7 @@ def write_main_file_to_stdout(self):
         "clang_tokenize",
         [TranslationUnit, SourceRange, POINTER(POINTER(Token)), POINTER(c_uint)],
     ),
-    ("clang_visitChildren", [Cursor, callbacks["cursor_visit"], py_object], c_uint),
+    ("clang_visitChildren", [Cursor, cursor_visit_callback, py_object], c_uint),
     ("clang_Cursor_getNumArguments", [Cursor], c_int),
     ("clang_Cursor_getArgument", [Cursor, c_uint], Cursor, Cursor.from_result),
     ("clang_Cursor_getNumTemplateArguments", [Cursor], c_int),
@@ -3859,19 +3895,19 @@ def write_main_file_to_stdout(self):
     ("clang_Type_getSizeOf", [Type], c_longlong),
     ("clang_Type_getCXXRefQualifier", [Type], c_uint),
     ("clang_Type_getNamedType", [Type], Type, Type.from_result),
-    ("clang_Type_visitFields", [Type, callbacks["fields_visit"], py_object], c_uint),
+    ("clang_Type_visitFields", [Type, fields_visit_callback, py_object], c_uint),
 ]
 
 
 class LibclangError(Exception):
-    def __init__(self, message):
+    def __init__(self, message: str):
         self.m = message
 
-    def __str__(self):
+    def __str__(self) -> str:
         return self.m
 
 
-def register_function(lib, item, ignore_errors):
+def register_function(lib: CDLL, item: LibFunc, ignore_errors: bool) -> None:
     # A function may not exist, if these bindings are used with an older or
     # incompatible version of libclang.so.
     try:
@@ -3895,15 +3931,15 @@ def register_function(lib, item, ignore_errors):
         func.errcheck = item[3]
 
 
-def register_functions(lib, ignore_errors):
+def register_functions(lib: CDLL, ignore_errors: bool) -> None:
     """Register function prototypes with a libclang library instance.
 
     This must be called as part of library instantiation so Python knows how
     to call out to the shared library.
     """
 
-    def register(item):
-        return register_function(lib, item, ignore_errors)
+    def register(item: LibFunc) -> None:
+        register_function(lib, item, ignore_errors)
 
     for f in functionList:
         register(f)
@@ -3911,12 +3947,12 @@ def register(item):
 
 class Config:
-    library_path = None
-    library_file = None
+    library_path: str | None = None
+    library_file: str | None = None
     compatibility_check = True
     loaded = False
 
     @staticmethod
-    def set_library_path(path):
+    def set_library_path(path: StrPath) -> None:
         """Set the path in which to search for libclang"""
         if Config.loaded:
             raise Exception(
@@ -3927,7 +3963,7 @@ def set_library_path(path):
         Config.library_path = os.fspath(path)
 
     @staticmethod
-    def set_library_file(filename):
+    def set_library_file(filename: StrPath) -> None:
         """Set the exact location of libclang"""
         if Config.loaded:
             raise Exception(
@@ -3938,7 +3974,7 @@ def set_library_file(filename):
         Config.library_file = os.fspath(filename)
 
     @staticmethod
-    def set_compatibility_check(check_status):
+    def set_compatibility_check(check_status: bool) -> None:
         """Perform compatibility check when loading libclang
 
         The python bindings are only tested and evaluated with the version of
@@ -3964,13 +4000,13 @@ def set_compatibility_check(check_status):
         Config.compatibility_check = check_status
 
     @CachedProperty
-    def lib(self):
+    def lib(self) -> CDLL:
         lib = self.get_cindex_library()
         register_functions(lib, not Config.compatibility_check)
         Config.loaded = True
         return lib
 
-    def get_filename(self):
+    def get_filename(self) -> str:
         if Config.library_file:
             return Config.library_file
 
@@ -3990,7 +4026,7 @@ def get_filename(self):
 
         return file
 
-    def get_cindex_library(self):
+    def get_cindex_library(self) -> CDLL:
         try:
             library = cdll.LoadLibrary(self.get_filename())
         except OSError as e:
@@ -4003,7 +4039,7 @@ def get_cindex_library(self):
 
         return library
 
-    def function_exists(self, name):
+    def function_exists(self, name: str) -> bool:
         try:
             getattr(self.lib, name)
         except AttributeError:
diff --git a/clang/bindings/python/tests/cindex/test_code_completion.py b/clang/bindings/python/tests/cindex/test_code_completion.py
index ca52fc6f73e1d..1d513dbca2536 100644
--- a/clang/bindings/python/tests/cindex/test_code_completion.py
+++ b/clang/bindings/python/tests/cindex/test_code_completion.py
@@ -53,7 +53,7 @@ def test_code_complete(self):
         expected = [
             "{'int', ResultType} | {'test1', TypedText} || Priority: 50 || Availability: Available || Brief comment: Aaa.",
             "{'void', ResultType} | {'test2', TypedText} | {'(', LeftParen} | {')', RightParen} || Priority: 50 || Availability: Available || Brief comment: Bbb.",
-            "{'return', TypedText} | {';', SemiColon} || Priority: 40 || Availability: Available || Brief comment: None",
+            "{'return', TypedText} | {';', SemiColon} || Priority: 40 || Availability: Available || Brief comment: ",
         ]
         self.check_completion_results(cr, expected)
 
@@ -94,7 +94,7 @@ def test_code_complete_pathlike(self):
         expected = [
             "{'int', ResultType} | {'test1', TypedText} || Priority: 50 || Availability: Available || Brief comment: Aaa.",
             "{'void', ResultType} | {'test2', TypedText} | {'(', LeftParen} | {')', RightParen} || Priority: 50 || Availability: Available || Brief comment: Bbb.",
-            "{'return', TypedText} | {';', SemiColon} || Priority: 40 || Availability: Available || Brief comment: None",
+            "{'return', TypedText} | {';', SemiColon} || Priority: 40 || Availability: Available || Brief comment: ",
         ]
         self.check_completion_results(cr, expected)
 
@@ -128,19 +128,19 @@ class Q : public P {
         cr = tu.codeComplete("fake.cpp", 12, 5, unsaved_files=files)
 
         expected = [
-            "{'const', TypedText} || Priority: 50 || Availability: Available || Brief comment: None",
-            "{'volatile', TypedText} || Priority: 50 || Availability: Available || Brief comment: None",
-            "{'operator', TypedText} || Priority: 40 || Availability: Available || Brief comment: None",
-            "{'P', TypedText} || Priority: 50 || Availability: Available || Brief comment: None",
-            "{'Q', TypedText} || Priority: 50 || Availability: Available || Brief comment: None",
+            "{'const', TypedText} || Priority: 50 || Availability: Available || Brief comment: ",
+            "{'volatile', TypedText} || Priority: 50 || Availability: Available || Brief comment: ",
+            "{'operator', TypedText} || Priority: 40 || Availability: Available || Brief comment: ",
+            "{'P', TypedText} || Priority: 50 || Availability: Available || Brief comment: ",
+            "{'Q', TypedText} || Priority: 50 || Availability: Available || Brief comment: ",
         ]
         self.check_completion_results(cr, expected)
 
         cr = tu.codeComplete("fake.cpp", 13, 5, unsaved_files=files)
         expected = [
-            "{'P', TypedText} | {'::', Text} || Priority: 75 || Availability: Available || Brief comment: None",
-            "{'P &', ResultType} | {'operator=', TypedText} | {'(', LeftParen} | {'const P &', Placeholder} | {')', RightParen} || Priority: 79 || Availability: Available || Brief comment: None",
-            "{'int', ResultType} | {'member', TypedText} || Priority: 35 || Availability: NotAccessible || Brief comment: None",
-            "{'void', ResultType} | {'~P', TypedText} | {'(', LeftParen} | {')', RightParen} || Priority: 79 || Availability: Available || Brief comment: None",
+            "{'P', TypedText} | {'::', Text} || Priority: 75 || Availability: Available || Brief comment: ",
+            "{'P &', ResultType} | {'operator=', TypedText} | {'(', LeftParen} | {'const P &', Placeholder} | {')', RightParen} || Priority: 79 || Availability: Available || Brief comment: ",
+            "{'int', ResultType} | {'member', TypedText} || Priority: 35 || Availability: NotAccessible || Brief comment: ",
+            "{'void', ResultType} | {'~P', TypedText} | {'(', LeftParen} | {')', RightParen} || Priority: 79 || Availability: Available || Brief comment: ",
         ]
         self.check_completion_results(cr, expected)
diff --git a/clang/bindings/python/tests/cindex/test_comment.py b/clang/bindings/python/tests/cindex/test_comment.py
index 0727c6fa35d95..265c6d3d73de0 100644
--- a/clang/bindings/python/tests/cindex/test_comment.py
+++ b/clang/bindings/python/tests/cindex/test_comment.py
@@ -53,5 +53,5 @@ def test_comment(self):
         f = get_cursor(tu, "f")
         raw = f.raw_comment
         brief = f.brief_comment
-        self.assertIsNone(raw)
-        self.assertIsNone(brief)
+        self.assertEqual(raw, "")
+        self.assertEqual(brief, "")



More information about the cfe-commits mailing list