xref: /petsc/src/binding/petsc4py/conf/stubgen.py (revision 98d129c30f3ee9fdddc40fdbc5a989b7be64f888)
import inspect
import os
import re
import textwrap
4
5
def is_cyfunction(obj):
    """Return True if *obj* is a Cython-compiled function or method."""
    cython_type_name = 'cython_function_or_method'
    return cython_type_name == type(obj).__name__
8
9
def is_function(obj):
    """Return True for builtin or Cython-compiled module-level functions."""
    if inspect.isbuiltin(obj):
        return True
    if is_cyfunction(obj):
        return True
    return type(obj) is type(ord)
12
13
def is_method(obj):
    """Return True if *obj* looks like a method attribute of a class.

    Method descriptors, bound methods, Cython functions and the builtin
    descriptor/method types of ``str`` methods all qualify.
    """
    if inspect.ismethoddescriptor(obj) or inspect.ismethod(obj):
        return True
    if is_cyfunction(obj):
        return True
    # Types of a plain C method, a slot wrapper, and a builtin static method.
    builtin_method_types = (type(str.index), type(str.__add__), type(str.__new__))
    return type(obj) in builtin_method_types
26
27
def is_classmethod(obj):
    """Return True if the raw class-``__dict__`` entry *obj* should get ``@classmethod``.

    Any builtin function also counts (presumably how some compiled class
    methods show up — TODO confirm).
    """
    if inspect.isbuiltin(obj):
        return True
    return type(obj).__name__ in {'classmethod', 'classmethod_descriptor'}
33
34
def is_staticmethod(obj):
    """Return True if the raw class-``__dict__`` entry *obj* is a ``staticmethod``."""
    kind = type(obj).__name__
    return kind == 'staticmethod'
37
38
def is_constant(obj):
    """Return True for values emitted as ``Final`` constants (int, float or str, bool included)."""
    constant_types = (int, float, str)
    return any(isinstance(obj, kind) for kind in constant_types)
41
42
def is_datadescr(obj):
    """Return True for data descriptors that are not property-like (no ``fget``)."""
    if not inspect.isdatadescriptor(obj):
        return False
    return not hasattr(obj, 'fget')
45
46
def is_property(obj):
    """Return True for data descriptors that expose a getter through ``fget``."""
    if inspect.isdatadescriptor(obj):
        return hasattr(obj, 'fget')
    return False
49
50
def is_class(obj):
    """Return True if *obj* is a class (an instance of ``type``)."""
    if inspect.isclass(obj):
        return True
    metatype = type(int)
    return type(obj) is metatype
53
54
class Lines(list):
    """A list of stub text lines with indentation-aware appending.

    Assigning to ``add`` appends to the list:

    * ``None`` is ignored;
    * a ``str`` is dedented, stripped and split into lines;
    * any other iterable (e.g. another ``Lines``) is taken line by line.

    Each appended non-empty line is prefixed with ``INDENT`` repeated
    ``level`` times. Empty lines are kept empty, so the generated stub has
    no whitespace-only lines.
    """

    INDENT = ' ' * 4
    level = 0  # current indentation depth; an instance attribute once modified

    @property
    def add(self):
        # Reading ``add`` returns the list itself. The property exists so that
        # ``lines.add = ...`` can be used as an "append these lines" statement.
        return self

    @add.setter
    def add(self, lines):
        if lines is None:
            return
        if isinstance(lines, str):
            lines = textwrap.dedent(lines).strip().split('\n')
        indent = self.INDENT * self.level
        # Do not indent empty lines: that would only add trailing whitespace.
        self.extend(indent + line if line else line for line in lines)
72
73
def signature(obj):
    """Return the signature line at the top of ``obj.__doc__``.

    The first docstring line is assumed to hold a signature, optionally
    qualified by the owning class (``"Vec.norm(self, ...) -> float"``). Only
    the dotted prefix of the *leading name* is dropped, so dots later in the
    line (defaults such as ``1.0``, types such as ``numpy.ndarray``) are kept
    intact.

    Objects without a docstring fall back to ``"<name>: Any"``. Returns None
    when the first docstring line is empty.
    """
    doc = obj.__doc__
    doc = doc or f'{obj.__name__}: Any'  # FIXME remove line
    first = doc.split('\n', 1)[0].strip()
    # Leading (possibly dotted) identifier, e.g. "Vec.norm" in "Vec.norm(...)".
    match = re.match(r'[A-Za-z_][\w.]*', first)
    if match is not None:
        qualified = match.group()
        first = qualified.rsplit('.', 1)[-1] + first[match.end():]
    return first or None
79
80
def visit_constant(constant):
    """Render a ``(name, value)`` pair as a ``Final`` declaration typed after the value."""
    const_name, const_value = constant
    type_name = type(const_value).__name__
    return f'{const_name}: Final[{type_name}] = ...'
84
85
def visit_function(function):
    """Render a module-level function as a one-line ``def <signature>: ...`` stub."""
    return 'def {}: ...'.format(signature(function))
89
90
def visit_method(method):
    """Render a class method as a one-line ``def <signature>: ...`` stub."""
    return 'def {}: ...'.format(signature(method))
94
95
def visit_datadescr(datadescr):
    """Render a data descriptor as the signature line taken from its docstring."""
    return str(signature(datadescr))
99
100
def visit_property(prop, name=None):
    """Render a property as an annotated attribute ``name: type``.

    The type is the return annotation (text after the last ``->``) of the
    getter's signature. When the getter signature has no return annotation,
    or no signature at all, ``Any`` is used instead of emitting the raw
    signature text as a (syntactically invalid) annotation.

    Parameters
    ----------
    prop : property
        The property object; only ``prop.fget`` is inspected.
    name : str, optional
        Attribute name to emit; defaults to the getter's ``__name__``.
    """
    sig = signature(prop.fget)
    pname = name or prop.fget.__name__
    if sig and '->' in sig:
        ptype = sig.rsplit('->', 1)[-1].strip() or 'Any'
    else:
        ptype = 'Any'
    return f'{pname}: {ptype}'
106
107
def visit_constructor(cls, name='__init__', args=None):
    """Render a one-line ``__init__``/``__new__`` stub for *cls*.

    Unless *args* is given (and non-empty), the single parameter is an
    optional instance of *cls* named after the lower-cased class name.
    ``__init__`` takes ``self`` and returns None; any other name takes
    ``cls`` and returns the class type.
    """
    type_name = cls.__name__
    if args:
        params = args
    else:
        params = f'{type_name.lower()}: Optional[{type_name}] = None'
    if name == '__init__':
        first, returns = 'self', 'None'
    else:
        first, returns = 'cls', type_name
    return f'def {name}({first}, {params}) -> {returns}: ...'
118
119
def visit_class(cls, outer=None, done=None):
    """Render the stub for class *cls*, including its nested classes.

    Parameters
    ----------
    cls : type
        The class to describe.
    outer : str, optional
        Stub name of the enclosing class when *cls* is nested. If the runtime
        class name starts with *outer*, that prefix is removed from the emitted
        name and ``OVERRIDE`` is looked up under ``"<outer>.<rest>"``.
    done : set, optional
        Attribute names already handled; updated in place. A new set is
        created when omitted.

    Returns
    -------
    Lines
        An optional ``@final`` decorator, the ``class`` header and the indented
        body (``pass`` when nothing was emitted).

    Raises
    ------
    RuntimeError
        If an attribute of ``cls.__dict__`` was neither emitted nor skipped.
    """
    # Attributes that never appear in the stub.
    skip = {
        '__doc__',
        '__dict__',
        '__module__',
        '__weakref__',
        '__pyx_vtable__',
        '__enum2str',  # FIXME refactor implementation
        '_traceback_',  # FIXME maybe refactor?
        '__lt__',
        '__le__',
        '__ge__',
        '__gt__',
    }
    # Fixed signatures for the dunder methods that are emitted. Dunders not
    # listed here (and not in this class's OVERRIDE entry) are dropped.
    special = {
        '__len__': '__len__(self) -> int',
        '__bool__': '__bool__(self) -> bool',
        '__hash__': '__hash__(self) -> int',
        '__int__': '__int__(self) -> int',
        '__index__': '__index__(self) -> int',
        '__str__': '__str__(self) -> str',
        '__repr__': '__repr__(self) -> str',
        '__eq__': '__eq__(self, other: object) -> bool',
        '__ne__': '__ne__(self, other: object) -> bool',
    }
    # Constructors are only marked as handled; no stub is emitted for them.
    constructor = (
        '__new__',
        '__init__',
    )

    qualname = cls.__name__
    cls_name = cls.__name__
    if outer is not None and cls_name.startswith(outer):
        cls_name = cls_name[len(outer):]
        qualname = f'{outer}.{cls_name}'

    override = OVERRIDE.get(qualname, {})
    done = set() if done is None else done
    lines = Lines()

    # Probe whether cls can be subclassed: a TypeError means it cannot, and
    # the stub marks it @final.
    try:

        class sub(cls):
            pass

        final = False
    except TypeError:
        final = True
    if final:
        lines.add = '@final'
    # Only the primary base class is reflected in the stub header.
    base = cls.__base__
    if base is object:
        lines.add = f'class {cls_name}:'
    else:
        lines.add = f'class {cls_name}({base.__name__}):'
    lines.level += 1
    start = len(lines)

    for name in constructor:
        if name in cls.__dict__:
            done.add(name)

    # An explicit ``__hash__ = None`` makes the class unhashable; don't emit it.
    if '__hash__' in cls.__dict__:
        if cls.__hash__ is None:
            done.add('__hash__')

    dct = cls.__dict__
    keys = list(dct.keys())

    def dunder(name):
        return name.startswith('__') and name.endswith('__')

    def members(seq):
        # Yield names still to be emitted; dunders without a known signature
        # or override are marked done and dropped.
        for name in seq:
            if name in skip:
                continue
            if name in done:
                continue
            if dunder(name):
                if name not in special and name not in override:
                    done.add(name)
                    continue
            yield name

    # First pass: nested classes, so they come first in the class body.
    for name in members(keys):
        attr = getattr(cls, name)
        if is_class(attr):
            done.add(name)
            lines.add = visit_class(attr, outer=cls_name)

    # Second pass: everything else, in __dict__ order.
    for name in members(keys):
        if name in override:
            done.add(name)
            lines.add = override[name]
            continue

        if name in special:
            done.add(name)
            sig = special[name]
            lines.add = f'def {sig}: ...'
            continue

        attr = getattr(cls, name)

        if is_method(attr):
            done.add(name)
            if name == attr.__name__:
                # Inspect the raw __dict__ entry: getattr() has already
                # unwrapped classmethod/staticmethod objects.
                obj = dct[name]
                if is_classmethod(obj):
                    lines.add = '@classmethod'
                elif is_staticmethod(obj):
                    lines.add = '@staticmethod'
                lines.add = visit_method(attr)
            else:
                # Method bound under a different name: emit an alias.
                lines.add = f'{name} = {attr.__name__}'
            continue

        if is_datadescr(attr):
            done.add(name)
            lines.add = visit_datadescr(attr)
            continue

        if is_property(attr):
            done.add(name)
            lines.add = visit_property(attr, name)
            continue

        if is_constant(attr):
            done.add(name)
            lines.add = visit_constant((name, attr))
            continue

    leftovers = [name for name in keys if name not in done and name not in skip]
    if leftovers:
        raise RuntimeError(f'leftovers: {leftovers}')

    if len(lines) == start:
        lines.add = 'pass'
    lines.level -= 1
    return lines
261
262
def visit_module(module, done=None):
    """Render the stub body for *module*.

    Emission order:

    1. ``int``-valued attributes (bools included), each as its OVERRIDE text
       or a ``Final`` declaration, followed by a blank line if any.
    2. In key order, classes whose ``__module__`` is this module (each followed
       by the module attributes whose exact type is that class), interleaved
       with functions (or ``name = original`` aliases when a function is bound
       under another name). Classes from other modules are marked done but
       not emitted.
    3. A blank line, then every remaining attribute as its OVERRIDE text or a
       ``Final`` declaration.

    Key order is ``module.__dict__`` order with names starting with ``_`` moved
    after the others (the sort is stable).

    Parameters
    ----------
    module : module
        The module object to introspect.
    done : set, optional
        Names already handled; updated in place. A new set is created when
        omitted.

    Returns
    -------
    Lines
        The generated stub lines.
    """
    # Module attributes that never appear in the stub.
    skip = {
        '__doc__',
        '__name__',
        '__loader__',
        '__spec__',
        '__file__',
        '__package__',
        '__builtins__',
        '__pyx_unpickle_Enum',  # FIXME review
    }

    done = set() if done is None else done
    lines = Lines()

    keys = list(module.__dict__.keys())
    # Stable sort on a boolean key: public names first, original order kept.
    keys.sort(key=lambda name: name.startswith('_'))

    # Step 1: int-valued attributes. Note that all() over a tuple evaluates
    # both conditions, so getattr runs even for skipped/done names.
    constants = [
        (name, getattr(module, name))
        for name in keys
        if all(
            (
                name not in done and name not in skip,
                isinstance(getattr(module, name), int),
            )
        )
    ]
    for name, value in constants:
        done.add(name)
        if name in OVERRIDE:
            lines.add = OVERRIDE[name]
        else:
            lines.add = visit_constant((name, value))
    if constants:
        lines.add = ''

    # Step 2: classes and functions, in key order.
    for name in keys:
        if name in done or name in skip:
            continue
        value = getattr(module, name)

        if is_class(value):
            done.add(name)
            # Re-exported classes from other modules are dropped.
            if value.__module__ != module.__name__:
                continue
            lines.add = visit_class(value)
            lines.add = ''
            # Module-level instances of this class are declared right after it.
            instances = [
                (k, getattr(module, k))
                for k in keys
                if all(
                    (
                        k not in done and k not in skip,
                        type(getattr(module, k)) is value,
                    )
                )
            ]
            for attrname, attrvalue in instances:
                done.add(attrname)
                lines.add = visit_constant((attrname, attrvalue))
            if instances:
                lines.add = ''
            continue

        if is_function(value):
            done.add(name)
            if name == value.__name__:
                lines.add = visit_function(value)
            else:
                lines.add = f'{name} = {value.__name__}'
            continue

    # Step 3: everything not handled above.
    lines.add = ''
    for name in keys:
        if name in done or name in skip:
            continue
        value = getattr(module, name)
        done.add(name)
        if name in OVERRIDE:
            lines.add = OVERRIDE[name]
        else:
            lines.add = visit_constant((name, value))

    # Safety net: the loop above marks every remaining non-skipped key as
    # done, so as written this check cannot trigger.
    leftovers = [name for name in keys if name not in done and name not in skip]
    if leftovers:
        raise RuntimeError(f'leftovers: {leftovers}')
    return lines
351
352
# Header of the generated stub file. Every typing name used by the OVERRIDE
# entries (e.g. ``Dict``, ``Final``, ``Any``) must be imported here, otherwise
# the emitted ``.pyi`` references undefined names.
IMPORTS = """
from __future__ import annotations
import sys
from threading import Lock
from typing import (
    Any,
    Dict,
    Union,
    Optional,
    NoReturn,
    overload,
)
if sys.version_info >= (3, 8):
    from typing import (
        final,
        Final,
        Literal,
    )
else:
    from typing_extensions import (
        final,
        Final,
        Literal,
    )
if sys.version_info >= (3, 9):
    from collections.abc import (
        Callable,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
        Mapping,
    )
else:
    from typing import (
        Callable,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
        Mapping,
    )
if sys.version_info >= (3, 11):
    from typing import Self
else:
    from typing_extensions import Self
from os import PathLike

import numpy

IntType: numpy.dtype = ...
RealType: numpy.dtype = ...
ComplexType: numpy.dtype = ...
ScalarType: numpy.dtype = ...
"""
407
# Hand-written stub text that takes precedence over generated output.
# Class names map to a dict of per-member overrides (looked up by visit_class);
# module-level attribute names map to the literal stub line to emit.
OVERRIDE = {}
OVERRIDE['Error'] = {}
OVERRIDE['__pyx_capi__'] = '__pyx_capi__: Final[Dict[str, Any]] = ...'
OVERRIDE['__type_registry__'] = (
    '__type_registry__: Final[Dict[int, type[Object]]] = ...'
)
413
# Trailer appended after the generated module body; currently it contributes
# only a single empty line.
TYPING = '\n'
416
417
def visit_petsc4py_PETSc(done=None):
    """Build the complete stub for ``petsc4py.PETSc``.

    The stub consists of the IMPORTS header, a blank line, the introspected
    module body and the TYPING trailer.

    Parameters
    ----------
    done : set, optional
        Names to treat as already handled. It is forwarded to ``visit_module``
        and updated in place; a new set is used when omitted.

    Returns
    -------
    Lines
        The stub file contents, one entry per line.
    """
    from petsc4py import PETSc as module

    lines = Lines()
    lines.add = IMPORTS
    lines.add = ''
    lines.add = visit_module(module, done=done)
    lines.add = TYPING
    return lines
427
428
def generate(filename):
    """Write the ``petsc4py.PETSc`` type stub to *filename*.

    The stub text is built completely before the file is opened, so a failure
    while introspecting petsc4py does not leave a truncated or empty stub
    behind. Missing parent directories are created. A bare file name (no
    directory part) is written to the current working directory. The file is
    written as UTF-8 regardless of the platform's locale.
    """
    lines = visit_petsc4py_PETSc()
    dirname = os.path.dirname(filename)
    if dirname:
        # os.makedirs('') raises FileNotFoundError, so only create real paths.
        os.makedirs(dirname, exist_ok=True)
    with open(filename, 'w', encoding='utf-8') as f:
        f.writelines(f'{line}\n' for line in lines)
435
436
# Directory receiving the generated stub, relative to the current working
# directory.
OUTDIR = os.path.join('src', 'petsc4py')

if __name__ == '__main__':
    _stub_file = os.path.join(OUTDIR, 'PETSc.pyi')
    generate(_stub_file)
441