xref: /petsc/src/binding/petsc4py/conf/stubgen.py (revision fbf9dbe564678ed6eff1806adbc4c4f01b9743f4)
1import os
2import inspect
3import textwrap
4
5
def is_cyfunction(obj):
    """Return True if *obj* is a Cython-compiled function or method."""
    # The Cython function type is matched by its type name.
    kind = type(obj).__name__
    return kind == 'cython_function_or_method'
8
9
def is_function(obj):
    """Return True if *obj* should be rendered as a module-level function.

    Builtin (C-implemented) functions and Cython functions qualify.
    """
    if inspect.isbuiltin(obj) or is_cyfunction(obj):
        return True
    return type(obj) is type(ord)
16
17
def is_method(obj):
    """Return True if *obj* looks like a method found on a class.

    Method descriptors, bound methods and Cython functions qualify, as do
    the C-level callable types CPython uses for builtin methods and slots.
    """
    c_method_types = (
        type(str.index),    # method_descriptor
        type(str.__add__),  # wrapper_descriptor (slot wrapper)
        type(str.__new__),  # builtin_function_or_method
    )
    checks = (
        inspect.ismethoddescriptor,
        inspect.ismethod,
        is_cyfunction,
    )
    return any(check(obj) for check in checks) or type(obj) in c_method_types
29
30
def is_classmethod(obj):
    """Return True if the raw class-dict entry *obj* is treated as a class method.

    Builtin function objects are accepted as well as ``classmethod`` and
    ``classmethod_descriptor`` instances.
    """
    if inspect.isbuiltin(obj):
        return True
    return type(obj).__name__ in ('classmethod', 'classmethod_descriptor')
39
40
def is_staticmethod(obj):
    """Return True if the raw class-dict entry *obj* is a ``staticmethod``."""
    return type(obj).__name__ == 'staticmethod'
47
48
def is_constant(obj):
    """Return True for scalar values (``int``, ``float``, ``str``) emitted as constants."""
    scalar_types = (int, float, str)
    return isinstance(obj, scalar_types)
51
52
def is_datadescr(obj):
    """Return True for a data descriptor that is not ``property``-like (no ``fget``)."""
    if not inspect.isdatadescriptor(obj):
        return False
    return not hasattr(obj, 'fget')
55
56
def is_property(obj):
    """Return True for a data descriptor that exposes a getter via ``fget``."""
    if not inspect.isdatadescriptor(obj):
        return False
    return hasattr(obj, 'fget')
59
60
def is_class(obj):
    """Return True if *obj* is a class object."""
    if inspect.isclass(obj):
        return True
    return type(obj) is type(int)
63
64
class Lines(list):
    """A list of output lines with an adjustable indentation level.

    Assigning to :attr:`add` appends text. A string is dedented, stripped
    and split on newlines; any other iterable is taken as a sequence of
    lines. Assigning ``None`` is a no-op. Each non-blank line is prefixed
    with ``INDENT * level``. Blank or whitespace-only lines are appended as
    empty strings, so the generated text carries no trailing whitespace.
    """

    INDENT = " " * 4
    # Class-level default; ``lines.level += 1`` creates an instance attribute.
    level = 0

    @property
    def add(self):
        # Reading ``add`` just returns the list itself. The property exists
        # so that ``lines.add = ...`` works as an append shorthand.
        return self

    @add.setter
    def add(self, lines):
        if lines is None:
            return
        if isinstance(lines, str):
            lines = textwrap.dedent(lines).strip().split("\n")
        indent = self.INDENT * self.level
        for line in lines:
            # Do not indent blank lines: that would only add trailing spaces.
            self.append(indent + line if line.strip() else "")
83
84
def signature(obj):
    """Extract a stub signature from the first line of ``obj.__doc__``.

    The first docstring line is expected to look like
    ``Qualifier.name(args) -> ret`` or ``Qualifier.name: type``. Any leading
    dotted qualifier (``Class.`` or ``Outer.Inner.``) is removed. Dots
    appearing later in the line are kept, e.g. in default values such as
    ``1.0`` or in dotted type names such as ``numpy.ndarray``.

    When ``obj`` has no docstring, ``"<name>: Any"`` is used instead.

    :param obj: object whose ``__doc__`` (and ``__name__``) is read.
    :returns: the signature text, or ``None`` if it is empty.
    """
    doc = obj.__doc__
    doc = doc or f"{obj.__name__}: Any"  # FIXME remove line
    sig = doc.split('\n', 1)[0]
    # Peel off "Identifier." prefixes one at a time. Stop at the first dot
    # that is not preceded by a bare identifier, so that dots inside the
    # argument list or the annotations are left alone.
    head, dot, tail = sig.partition('.')
    while dot and head.strip().isidentifier():
        sig = tail
        head, dot, tail = sig.partition('.')
    return sig or None
90
def visit_constant(constant):
    """Render a ``(name, value)`` pair as a ``Final`` stub declaration.

    The annotated type is the name of ``type(value)``.
    """
    name, value = constant
    type_name = type(value).__name__
    return "{}: Final[{}] = ...".format(name, type_name)
94
95
def visit_function(function):
    """Render *function* as a one-line ``def <signature>: ...`` stub."""
    return "def {}: ...".format(signature(function))
99
100
def visit_method(method):
    """Render *method* as a one-line ``def <signature>: ...`` stub."""
    return "def {}: ...".format(signature(method))
104
105
def visit_datadescr(datadescr):
    """Render a data descriptor as the text of its docstring signature."""
    return format(signature(datadescr))
109
110
def visit_property(prop, name=None):
    """Render a property as an annotated attribute ``name: type``.

    The type is the text after the last ``->`` in the getter's signature.
    *name* defaults to the getter's ``__name__``.
    """
    getter = prop.fget
    getter_sig = signature(getter)
    attr_name = name or getter.__name__
    _, _, return_type = getter_sig.rpartition('->')
    return f"{attr_name}: {return_type.strip()}"
116
117
def visit_constructor(cls, name='__init__', args=None):
    """Render a constructor stub for *cls*.

    For ``__init__`` the stub takes ``self`` and returns ``None``. For any
    other *name* (e.g. ``__new__``) it takes ``cls`` and returns the class.
    The argument list defaults to ``<lowercased class name>: Optional[Class]
    = None`` unless a non-empty *args* string is given.
    """
    cls_name = cls.__name__
    if name == '__init__':
        first, returns = 'self', 'None'
    else:
        first, returns = 'cls', cls_name
    params = args or f"{cls_name.lower()}: Optional[{cls_name}] = None"
    return f"def {name}({first}, {params}) -> {returns}: ..."
128
def visit_class(cls, outer=None, done=None):
    """Render *cls* and its nested classes as ``.pyi`` stub lines.

    :param cls: the class to describe.
    :param outer: name of the enclosing class when *cls* is nested. If the
        class name starts with *outer*, that prefix is stripped from the
        emitted name, and ``"<outer>.<name>"`` is used as the key into
        ``OVERRIDE``.
    :param done: optional set of attribute names already handled. It is
        updated in place.
    :returns: a ``Lines`` object holding the class header (optionally
        preceded by ``@final``) followed by the indented class body.
    :raises RuntimeError: if an entry of ``cls.__dict__`` is neither
        handled nor listed in ``skip``.
    """
    # Attributes that are never emitted.
    skip = {
        '__doc__',
        '__dict__',
        '__module__',
        '__weakref__',
        '__pyx_vtable__',
        '__enum2str',  # FIXME refactor implementation
        '_traceback_', # FIXME maybe refactor?
        '__lt__',
        '__le__',
        '__ge__',
        '__gt__',
    }
    # Dunder methods emitted with a fixed signature. Any other dunder is
    # dropped unless it has an entry in the class's OVERRIDE mapping.
    special = {
        '__len__':   "__len__(self) -> int",
        '__bool__':  "__bool__(self) -> bool",
        '__hash__':  "__hash__(self) -> int",
        '__int__':   "__int__(self) -> int",
        '__index__': "__index__(self) -> int",
        '__str__':   "__str__(self) -> str",
        '__repr__':  "__repr__(self) -> str",
        '__eq__':    "__eq__(self, other: object) -> bool",
        '__ne__':    "__ne__(self, other: object) -> bool",
    }
    # Constructors are marked as handled but not emitted by this function.
    constructor = (
        '__new__',
        '__init__',
    )

    qualname = cls.__name__
    cls_name = cls.__name__
    if outer is not None and cls_name.startswith(outer):
        cls_name = cls_name[len(outer):]
        qualname = f"{outer}.{cls_name}"

    override = OVERRIDE.get(qualname, {})
    done = set() if done is None else done
    lines = Lines()

    # A class that raises TypeError when subclassed is marked @final.
    try:
        class sub(cls):
            pass
        final = False
    except TypeError:
        final = True
    if final:
        lines.add = "@final"
    base = cls.__base__
    if base is object:
        lines.add = f"class {cls_name}:"
    else:
        lines.add = f"class {cls_name}({base.__name__}):"
    lines.level += 1
    start = len(lines)

    for name in constructor:
        if name in cls.__dict__:
            done.add(name)

    # ``__hash__ = None`` marks the class unhashable, so no ``__hash__``
    # method is emitted for it.
    if '__hash__' in cls.__dict__:
        if cls.__hash__ is None:
            done.add('__hash__')

    dct = cls.__dict__
    keys = list(dct.keys())

    def dunder(name):
        return name.startswith('__') and name.endswith('__')

    def members(seq):
        # Yield the names that still need processing. A dunder with neither
        # a special signature nor an override is marked done and dropped.
        for name in seq:
            if name in skip:
                continue
            if name in done:
                continue
            if dunder(name):
                if name not in special and name not in override:
                    done.add(name)
                    continue
            yield name

    # First pass: nested classes, emitted before any other member.
    for name in members(keys):
        attr = getattr(cls, name)
        if is_class(attr):
            done.add(name)
            lines.add = visit_class(attr, outer=cls_name)

    # Second pass: every remaining member.
    for name in members(keys):

        if name in override:
            done.add(name)
            lines.add = override[name]
            continue

        if name in special:
            done.add(name)
            sig = special[name]
            lines.add = f"def {sig}: ..."
            continue

        attr = getattr(cls, name)

        if is_method(attr):
            done.add(name)
            if name == attr.__name__:
                # Inspect the raw class-dict entry: getattr() unwraps
                # classmethod/staticmethod wrappers.
                obj = dct[name]
                if is_classmethod(obj):
                    lines.add = "@classmethod"
                elif is_staticmethod(obj):
                    lines.add = "@staticmethod"
                lines.add = visit_method(attr)
            else:
                # Attribute bound to a method of another name: emit an alias.
                lines.add = f"{name} = {attr.__name__}"
            continue

        if is_datadescr(attr):
            done.add(name)
            lines.add = visit_datadescr(attr)
            continue

        if is_property(attr):
            done.add(name)
            lines.add = visit_property(attr, name)
            continue

        if is_constant(attr):
            done.add(name)
            lines.add = visit_constant((name, attr))
            continue

    # Any attribute that matched no category is reported instead of being
    # silently dropped from the stub.
    leftovers = [name for name in keys if
                 name not in done and name not in skip]
    if leftovers:
        raise RuntimeError(f"leftovers: {leftovers}")

    # An empty class body still needs a statement.
    if len(lines) == start:
        lines.add = "pass"
    lines.level -= 1
    return lines
270
271
def visit_module(module, done=None):
    """Render the namespace of *module* as ``.pyi`` stub lines.

    Names are emitted in four groups:

    1. integer constants;
    2. classes defined in *module*, each followed by the module-level
       instances of that class;
    3. functions, or ``name = other`` aliases when the attribute name
       differs from the function's ``__name__``;
    4. every other remaining name, using ``OVERRIDE`` text when available
       and a ``Final`` declaration otherwise.

    Classes whose ``__module__`` is a different module are marked as
    handled but not emitted.

    :param module: the module object to describe.
    :param done: optional set of names already handled; updated in place.
    :returns: a ``Lines`` object with the stub text.
    :raises RuntimeError: if any name is left unhandled.
    """
    ignored = {
        '__doc__',
        '__name__',
        '__loader__',
        '__spec__',
        '__file__',
        '__package__',
        '__builtins__',
        '__pyx_unpickle_Enum',  # FIXME review
    }
    handled = set() if done is None else done
    out = Lines()

    def pending(key):
        # True while *key* still needs to be emitted.
        return key not in handled and key not in ignored

    # Stable sort: public names keep their relative order and come before
    # underscore-prefixed ones.
    names = sorted(module.__dict__, key=lambda key: key.startswith("_"))

    # Group 1: integer constants.
    int_constants = [
        (key, getattr(module, key))
        for key in names
        if pending(key) and isinstance(getattr(module, key), int)
    ]
    for key, obj in int_constants:
        handled.add(key)
        if key in OVERRIDE:
            out.add = OVERRIDE[key]
        else:
            out.add = visit_constant((key, obj))
    if int_constants:
        out.add = ""

    # Groups 2 and 3: classes (with their instances) and functions.
    for key in names:
        if not pending(key):
            continue
        obj = getattr(module, key)
        if is_class(obj):
            handled.add(key)
            if obj.__module__ == module.__name__:
                out.add = visit_class(obj)
                out.add = ""
                instances = [
                    (other, getattr(module, other))
                    for other in names
                    if pending(other) and type(getattr(module, other)) is obj
                ]
                for other, inst in instances:
                    handled.add(other)
                    out.add = visit_constant((other, inst))
                if instances:
                    out.add = ""
        elif is_function(obj):
            handled.add(key)
            if key == obj.__name__:
                out.add = visit_function(obj)
            else:
                out.add = f"{key} = {obj.__name__}"

    # Group 4: everything still pending.
    out.add = ""
    for key in names:
        if not pending(key):
            continue
        obj = getattr(module, key)
        handled.add(key)
        if key in OVERRIDE:
            out.add = OVERRIDE[key]
        else:
            out.add = visit_constant((key, obj))

    # Defensive check; group 4 marks every pending name as handled.
    unhandled = [key for key in names if pending(key)]
    if unhandled:
        raise RuntimeError(f"leftovers: {unhandled}")
    return out
355
356
# Header of the generated stub. Every name used by the emitted
# declarations (including the OVERRIDE texts below) must be imported here.
IMPORTS = """
from __future__ import annotations
import sys
from threading import Lock
from typing import (
    Any,
    Dict,
    Union,
    Optional,
    NoReturn,
    overload,
)
if sys.version_info >= (3, 8):
    from typing import (
        final,
        Final,
        Literal,
    )
else:
    from typing_extensions import (
        final,
        Final,
        Literal,
    )
if sys.version_info >= (3, 9):
    from collections.abc import (
        Callable,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
        Mapping,
    )
else:
    from typing import (
        Callable,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
        Mapping,
    )
if sys.version_info >= (3, 11):
    from typing import Self
else:
    from typing_extensions import Self
from os import PathLike

import numpy

IntType: numpy.dtype = ...
RealType: numpy.dtype = ...
ComplexType: numpy.dtype = ...
ScalarType: numpy.dtype = ...
"""

# Hand-written stub text that replaces generated output.
# - Class qualname -> dict mapping attribute names to replacement text
#   (looked up by visit_class).
# - Module-level name -> replacement text (used by visit_module for
#   integer constants and remaining names).
OVERRIDE = {
    'Error': {
    },
    '__pyx_capi__': "__pyx_capi__: Final[Dict[str, Any]] = ...",
    '__type_registry__': "__type_registry__: Final[Dict[int, type[Object]]] = ...",
}

# Extra text appended after the generated module body (currently empty).
TYPING = """
"""
421
422
def visit_petsc4py_PETSc(done=None):
    """Build the complete stub for ``petsc4py.PETSc``.

    :param done: optional set of module-level names to treat as already
        handled. It is forwarded to ``visit_module``, which updates it in
        place. Previously this argument was accepted but ignored.
    :returns: a ``Lines`` object containing ``IMPORTS``, a blank line, the
        generated module body, then ``TYPING``.
    """
    from petsc4py import PETSc as module
    lines = Lines()
    lines.add = IMPORTS
    lines.add = ""
    lines.add = visit_module(module, done)
    lines.add = TYPING
    return lines
431
432
def generate(filename):
    """Write the ``petsc4py.PETSc`` stub to *filename*.

    The stub is generated completely before the file is opened, so an
    error during generation leaves any existing file untouched. The parent
    directory is created when needed. Output is written as UTF-8, one stub
    line per text line.
    """
    lines = visit_petsc4py_PETSc()
    dirname = os.path.dirname(filename)
    # A bare file name has an empty directory part; os.makedirs('') raises.
    if dirname:
        os.makedirs(dirname, exist_ok=True)
    with open(filename, 'w', encoding='utf-8') as f:
        for line in lines:
            print(line, file=f)
439
440
# Output directory for the generated stub. The path is relative to the
# current working directory, presumably the petsc4py package root
# (TODO confirm against the build scripts that run this file).
OUTDIR = os.path.join('src', 'petsc4py')

if __name__ == '__main__':
    stub_path = os.path.join(OUTDIR, 'PETSc.pyi')
    generate(stub_path)
445