xref: /petsc/src/binding/petsc4py/conf/stubgen.py (revision 124b60a56262a80503e09b9eaaec281e19388b1e)
1import os
2import inspect
3import textwrap
4
5
def is_cyfunction(obj):
    """Return True if *obj* is a Cython-compiled function or method."""
    type_name = type(obj).__name__
    return type_name == 'cython_function_or_method'
8
9
def is_function(obj):
    """Return True if *obj* is a builtin or Cython-compiled function."""
    if inspect.isbuiltin(obj):
        return True
    # ``type(ord)`` is the builtin_function_or_method type.
    return is_cyfunction(obj) or type(obj) is type(ord)
12
13
def is_method(obj):
    """Return True if *obj* should be rendered as a method in the stub."""
    if inspect.ismethoddescriptor(obj) or inspect.ismethod(obj):
        return True
    if is_cyfunction(obj):
        return True
    # The types of a method descriptor (``str.index``), a slot wrapper
    # (``str.__add__``) and a builtin ``__new__`` (``str.__new__``).
    builtin_method_types = (
        type(str.index),
        type(str.__add__),
        type(str.__new__),
    )
    return type(obj) in builtin_method_types
25    )
26
27
def is_classmethod(obj):
    """Return True if the raw class-``__dict__`` entry *obj* is a classmethod.

    Builtin functions are treated as classmethods as well.
    """
    if inspect.isbuiltin(obj):
        return True
    return type(obj).__name__ in {'classmethod', 'classmethod_descriptor'}
32    )
33
34
def is_staticmethod(obj):
    """Return True if the raw class-``__dict__`` entry *obj* is a staticmethod."""
    return type(obj).__name__ == 'staticmethod'
37
38
def is_constant(obj):
    """Return True for plain ``int``/``float``/``str`` values (``bool`` included)."""
    constant_types = (int, float, str)
    return isinstance(obj, constant_types)
41
42
def is_datadescr(obj):
    """Return True for a data descriptor that has no ``fget`` (i.e. not a property)."""
    if not inspect.isdatadescriptor(obj):
        return False
    return not hasattr(obj, 'fget')
45
46
def is_property(obj):
    """Return True for a data descriptor exposing an ``fget`` (property-like)."""
    if inspect.isdatadescriptor(obj):
        return hasattr(obj, 'fget')
    return False
49
50
def is_class(obj):
    """Return True if *obj* is a class."""
    if inspect.isclass(obj):
        return True
    # ``type(int)`` is ``type`` itself.
    return type(obj) is type(int)
53
54
class Lines(list):
    """A list of output lines that indents whatever is added to it.

    Lines are appended by *assigning* to ``add``; ``level`` sets how many
    ``INDENT`` units prefix each appended non-empty line.
    """

    INDENT = ' ' * 4
    level = 0

    @property
    def add(self):
        """Return the list itself; exists only so ``add`` can have a setter."""
        return self

    @add.setter
    def add(self, lines):
        """Append *lines*, indented by the current ``level``.

        *lines* may be ``None`` (ignored), a string (dedented, stripped of
        surrounding blank lines and split on newlines) or an iterable of
        line strings. Empty lines are appended without the indentation
        prefix so the output never contains whitespace-only lines.
        """
        if lines is None:
            return
        if isinstance(lines, str):
            lines = textwrap.dedent(lines).strip().split('\n')
        indent = self.INDENT * self.level
        for line in lines:
            self.append(indent + line if line else line)
72
73
def signature(obj):
    """Return the signature line taken from the first line of *obj*'s docstring.

    The first docstring line is expected to hold a signature. For methods and
    descriptors it is presumably qualified with the owning class (as Cython's
    ``embedsignature`` does), e.g. ``Vec.norm(self) -> float`` or
    ``Vec.sizes: LayoutSizeSpec``. Only the dotted qualifier of the *name*
    (the text before the first ``(`` or ``:``) is stripped. Dots inside
    annotations or default values, such as ``x: float = 1.0``, are preserved.

    Returns None when the first docstring line is empty.
    """
    doc = obj.__doc__
    doc = doc or f'{obj.__name__}: Any'  # FIXME remove line
    line = doc.split('\n', 1)[0]
    # The (possibly qualified) name ends at the first '(' for callables or
    # the first ':' for annotated attributes.
    cuts = [pos for pos in (line.find('('), line.find(':')) if pos >= 0]
    if cuts:
        cut = min(cuts)
        sig = line[:cut].rsplit('.', 1)[-1] + line[cut:]
    else:
        # No recognizable name boundary: keep the historical behavior.
        sig = line.split('.', 1)[-1]
    return sig or None
79
80
def visit_constant(constant):
    """Render a ``(name, value)`` pair as a ``Final`` stub declaration."""
    name, value = constant
    type_name = type(value).__name__
    return f'{name}: Final[{type_name}] = ...'
84
85
def visit_function(function):
    """Render a module-level function as a one-line ``def ...: ...`` stub."""
    return 'def {}: ...'.format(signature(function))
89
90
# ``Class.method`` names whose generated stub signature conflicts with the
# inherited one; visit_method appends ``# type: ignore[override]`` to them.
incompatible_overrides = [
    'DMDA.create',
    'DMStag.create',
    'DMSwarm.getField',
    'DMSwarm.setType',
    'ViewerHDF5.create',
    'SF.compose',
]


def visit_method(method, clas_name=None):
    """Render *method* of class *clas_name* as a one-line stub definition.

    Methods listed (as ``Class.method``) in ``incompatible_overrides`` get a
    trailing ``type: ignore[override]`` comment.
    """
    sig = signature(method)
    qualified = f'{clas_name}.{method.__name__}'
    ignore = ' # type: ignore[override]' if qualified in incompatible_overrides else ''
    return f'def {sig}: ...' + ignore
106
107
def visit_datadescr(datadescr):
    """Render a data descriptor as the (presumably ``name: type``) first line of its docstring."""
    return str(signature(datadescr))
111
112
def visit_property(prop, name=None):
    """Render a Python ``property`` as an annotated attribute stub line.

    Parameters
    ----------
    prop : property
        The property object; its getter's docstring signature is used.
    name : str, optional
        Attribute name to emit; defaults to the getter's ``__name__``.

    Returns
    -------
    str
        ``'<name>: <type>'``. The type is the return annotation (the text
        after the last ``->``) of the getter's signature. It is ``Any`` when
        the signature has no (or an empty) return annotation, instead of
        pasting the raw signature text in as a type.
    """
    sig = signature(prop.fget) or ''
    pname = name or prop.fget.__name__
    if '->' in sig:
        ptype = sig.rsplit('->', 1)[-1].strip() or 'Any'
    else:
        ptype = 'Any'
    return f'{pname}: {ptype}'
118
119
def visit_constructor(cls, name='__init__', args=None):
    """Render a constructor stub for *cls*.

    For ``__init__`` the first parameter is ``self`` and the return type is
    ``None``; for any other *name* it is ``cls`` returning *cls*. The
    parameter list after the first one is *args*, or by default a single
    optional instance of *cls* named after the lower-cased class name.
    """
    is_init = name == '__init__'
    type_name = cls.__name__
    first = 'self' if is_init else 'cls'
    rest = args if args else f'{type_name.lower()}: Optional[{type_name}] = None'
    returns = 'None' if is_init else type_name
    return f'def {name}({first}, {rest}) -> {returns}: ...'
130
131
# Qualified names of the classes already emitted by visit_class; a class seen
# again is skipped.
visited_classes = set()


def visit_class(cls, outer=None, done=None):
    """Render the stub for *cls*, including its nested classes.

    Parameters
    ----------
    cls : type
        The class to describe; its own ``__dict__`` members are emitted.
    outer : str, optional
        Stub name of the enclosing class when *cls* is visited as a nested
        class. A class name starting with *outer* is shortened
        (``OuterInner`` inside ``Outer`` becomes ``Inner``) and registered
        under the qualified name ``Outer.Inner``.
    done : set, optional
        Member names already handled, which must not be emitted. It is
        updated in place.

    Returns
    -------
    Lines or str
        The class stub lines, or ``''`` if the class was already visited.

    Raises
    ------
    RuntimeError
        If a member of ``cls.__dict__`` is neither skipped nor handled.
    """
    # Members that never appear in the stub.
    skip = {
        '__doc__',
        '__dict__',
        '__module__',
        '__weakref__',
        '__pyx_vtable__',
        '__enum2str',  # FIXME refactor implementation
        '_traceback_',  # FIXME maybe refactor?
        '__lt__',
        '__le__',
        '__ge__',
        '__gt__',
    }
    # Dunder methods emitted with a fixed signature; other dunders are dropped
    # unless OVERRIDE supplies them.
    special = {
        '__len__': '__len__(self) -> int',
        '__bool__': '__bool__(self) -> bool',
        '__hash__': '__hash__(self) -> int',
        '__int__': '__int__(self) -> int',
        '__index__': '__index__(self) -> int',
        '__str__': '__str__(self) -> str',
        '__repr__': '__repr__(self) -> str',
        '__eq__': '__eq__(self, other: object) -> bool',
        '__ne__': '__ne__(self, other: object) -> bool',
    }
    constructor = (
        '__new__',
        '__init__',
    )

    qualname = cls.__name__
    cls_name = cls.__name__
    if outer is not None and cls_name.startswith(outer):
        cls_name = cls_name[len(outer) :]
        qualname = f'{outer}.{cls_name}'

    if qualname in visited_classes:
        return ''

    visited_classes.add(qualname)

    override = OVERRIDE.get(qualname, {})
    done = set() if done is None else done
    lines = Lines()

    # A class that refuses to be subclassed is marked @final in the stub.
    try:

        class sub(cls):
            pass

    except TypeError:
        final = True
    else:
        final = False
    if final:
        lines.add = '@final'
    base = cls.__base__
    if base is object:
        lines.add = f'class {cls_name}:'
    else:
        lines.add = f'class {cls_name}({base.__name__}):'
    lines.level += 1
    start = len(lines)

    # Constructors defined on the class are left out unless OVERRIDE has them.
    for name in constructor:
        if name in override:
            continue
        if name in cls.__dict__:
            done.add(name)

    # ``__hash__ = None`` marks an unhashable class: emit no ``__hash__``.
    if '__hash__' in cls.__dict__:
        if cls.__hash__ is None:
            done.add('__hash__')

    dct = cls.__dict__
    keys = list(dct.keys())

    def dunder(name):
        return name.startswith('__') and name.endswith('__')

    def members(seq):
        # ``done`` is consulted lazily, so names handled by an earlier pass
        # (or earlier in the same pass) are not yielded again.
        for name in seq:
            if name in skip:
                continue
            if name in done:
                continue
            if dunder(name):
                if name not in special and name not in override:
                    done.add(name)
                    continue
            yield name

    # First pass: nested classes.
    for name in members(keys):
        attr = getattr(cls, name)
        if is_class(attr):
            done.add(name)
            nested = visit_class(attr, outer=cls_name)
            # An already-visited nested class yields '', which would otherwise
            # append a whitespace-only line and could hide an empty body.
            if nested:
                lines.add = nested

    # Second pass: every other member.
    for name in members(keys):
        if name in override:
            done.add(name)
            lines.add = override[name]
            continue

        if name in special:
            done.add(name)
            sig = special[name]
            lines.add = f'def {sig}: ...'
            continue

        attr = getattr(cls, name)

        if is_method(attr):
            done.add(name)
            if name == attr.__name__:
                obj = dct[name]
                if is_classmethod(obj):
                    lines.add = '@classmethod'
                elif is_staticmethod(obj):
                    lines.add = '@staticmethod'
                lines.add = visit_method(attr, qualname)
            else:
                # Name bound to a method defined under another name.
                lines.add = f'{name} = {attr.__name__}'
            continue

        if is_datadescr(attr):
            done.add(name)
            lines.add = visit_datadescr(attr)
            continue

        if is_property(attr):
            done.add(name)
            lines.add = visit_property(attr, name)
            continue

        if is_constant(attr):
            done.add(name)
            lines.add = visit_constant((name, attr))
            continue

    leftovers = [name for name in keys if name not in done and name not in skip]
    if leftovers:
        raise RuntimeError(f'leftovers: {leftovers}')

    # A class body cannot be empty.
    if len(lines) == start:
        lines.add = '...'
    lines.level -= 1
    return lines
283
284
def visit_module(module, done=None):
    """Render the stub body for *module* as a ``Lines`` list.

    Names are processed in three passes. Within each pass, public names come
    before underscored ones:

    1. ``int`` values (including ``int`` subclasses such as ``bool``) become
       ``Final`` constants, or their OVERRIDE entry when one exists.
    2. Classes defined in *module* are rendered with ``visit_class``. Each is
       followed by the not-yet-handled module attributes whose exact type is
       that class. Functions (or aliases of functions) are rendered too.
       Classes whose ``__module__`` is another module are skipped entirely.
    3. Everything left over becomes its OVERRIDE entry or a ``Final``
       constant.

    Parameters
    ----------
    module : module
        The module to describe.
    done : set, optional
        Names to leave out of the stub. It is updated in place with every
        name handled.

    Returns
    -------
    Lines
        The generated stub lines.

    Raises
    ------
    RuntimeError
        If a name is still unhandled after the last pass (consistency check).
    """
    # Module attributes that never appear in the stub.
    skip = {
        '__doc__',
        '__name__',
        '__loader__',
        '__spec__',
        '__file__',
        '__package__',
        '__builtins__',
        '__pyx_unpickle_Enum',  # FIXME review
    }

    done = set() if done is None else done
    lines = Lines()

    keys = list(module.__dict__.keys())
    # Stable sort on a boolean key: public names first, underscored names
    # last, each group keeping ``__dict__`` order.
    keys.sort(key=lambda name: name.startswith('_'))

    # Pass 1: integer-valued attributes.
    constants = [
        (name, getattr(module, name))
        for name in keys
        if all(
            (
                name not in done and name not in skip,
                isinstance(getattr(module, name), int),
            )
        )
    ]
    for name, value in constants:
        done.add(name)
        if name in OVERRIDE:
            lines.add = OVERRIDE[name]
        else:
            lines.add = visit_constant((name, value))
    if constants:
        lines.add = ''

    # Pass 2: classes defined in this module, and functions.
    for name in keys:
        if name in done or name in skip:
            continue
        value = getattr(module, name)

        if is_class(value):
            done.add(name)
            # Classes imported from another module are not emitted here.
            if value.__module__ != module.__name__:
                continue
            lines.add = visit_class(value)
            lines.add = ''
            # Remaining attributes whose exact type is this class are
            # declared right after it.
            instances = [
                (k, getattr(module, k))
                for k in keys
                if all(
                    (
                        k not in done and k not in skip,
                        type(getattr(module, k)) is value,
                    )
                )
            ]
            for attrname, attrvalue in instances:
                done.add(attrname)
                lines.add = visit_constant((attrname, attrvalue))
            if instances:
                lines.add = ''
            continue

        if is_function(value):
            done.add(name)
            if name == value.__name__:
                lines.add = visit_function(value)
            else:
                # Name bound to a function defined under another name.
                lines.add = f'{name} = {value.__name__}'
            continue

    # Pass 3: everything left over.
    lines.add = ''
    for name in keys:
        if name in done or name in skip:
            continue
        value = getattr(module, name)
        done.add(name)
        if name in OVERRIDE:
            lines.add = OVERRIDE[name]
        else:
            lines.add = visit_constant((name, value))

    # Consistency check: every non-skipped name must be handled by now.
    leftovers = [name for name in keys if name not in done and name not in skip]
    if leftovers:
        raise RuntimeError(f'leftovers: {leftovers}')
    return lines
373
374
# Header written at the top of the generated stub (see visit_petsc4py_PETSc).
# It imports every name the stub annotations may refer to and declares the
# module-level dtype attributes. The text is emitted verbatim, after
# dedenting and stripping by ``Lines``.
IMPORTS = """
from __future__ import annotations
import sys
from threading import Lock
from typing import (
    Any,
    Union,
    Optional,
    NoReturn,
    overload,
)
if sys.version_info >= (3, 8):
    from typing import (
        final,
        Final,
        Literal,
    )
else:
    from typing_extensions import (
        final,
        Final,
        Literal,
    )
if sys.version_info >= (3, 9):
    from collections.abc import (
        Callable,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
        Mapping,
    )
else:
    from typing import (
        Callable,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
        Mapping,
    )
if sys.version_info >= (3, 11):
    from typing import Self
else:
    from typing_extensions import Self
from os import PathLike

import numpy

from numpy import (
    dtype,
    ndarray,
)

from mpi4py.MPI import (
    Datatype,
    Intracomm,
    Op,
)

from petsc4py.typing import (
    Scalar,
    ArrayBool,
    ArrayComplex,
    ArrayInt,
    ArrayReal,
    ArrayScalar,
    CSRIndicesSpec,
    CSRSpec,
    DMCoarsenHookFunction,
    DMRestrictHookFunction,
    DimsSpec,
    KSPConvergenceTestFunction,
    KSPMonitorFunction,
    KSPOperatorsFunction,
    KSPPostSolveFunction,
    KSPPreSolveFunction,
    KSPRHSFunction,
    LayoutSizeSpec,
    MatAssemblySpec,
    MatBlockSizeSpec,
    MatNullFunction,
    MatSizeSpec,
    NNZSpec,
    NormTypeSpec,
    PetscOptionsHandlerFunction,
    ScatterModeSpec,
    SNESMonitorFunction,
    SNESObjFunction,
    SNESFunction,
    SNESJacobianFunction,
    SNESGuessFunction,
    SNESUpdateFunction,
    SNESLSPreFunction,
    SNESNGSFunction,
    SNESConvergedFunction,
    TAOConstraintsFunction,
    TAOConstraintsJacobianFunction,
    TAOConvergedFunction,
    TAOGradientFunction,
    TAOHessianFunction,
    TAOJacobianFunction,
    TAOJacobianResidualFunction,
    TAOMonitorFunction,
    TAOObjectiveFunction,
    TAOObjectiveGradientFunction,
    TAOResidualFunction,
    TAOUpdateFunction,
    TAOVariableBoundsFunction,
    TAOLSGradientFunction,
    TAOLSObjectiveFunction,
    TAOLSObjectiveGradientFunction,
    TSI2Function,
    TSI2Jacobian,
    TSI2JacobianP,
    TSIFunction,
    TSIJacobian,
    TSIJacobianP,
    TSIndicatorFunction,
    TSMonitorFunction,
    TSPostEventFunction,
    TSPostStepFunction,
    TSPreStepFunction,
    TSRHSFunction,
    TSRHSJacobian,
    TSRHSJacobianP,
    AccessModeSpec,
    InsertModeSpec,
)

IntType: numpy.dtype = ...
RealType: numpy.dtype = ...
ComplexType: numpy.dtype = ...
ScalarType: numpy.dtype = ...
"""

# Hand-written stub text that replaces the generated one. A class qualname
# maps to a ``{member name: stub text}`` dict (looked up by visit_class). A
# module-level name maps directly to its stub line (looked up by visit_module).
OVERRIDE = {
    'Error': {
        '__init__': 'def __init__(self, ierr: int = 0) -> None: ...',
    },
    'Options': {
        '__init__': 'def __init__(self, prefix: str | None = None) -> None: ...',
    },
    '__pyx_capi__': '__pyx_capi__: Final[dict[str, Any]] = ...',
    '__type_registry__': '__type_registry__: Final[dict[int, type[Object]]] = ...',
}

# Trailer appended after the module body (currently empty).
TYPING = """
"""
524
525
def visit_petsc4py_PETSc(done=None):
    """Build the complete stub for ``petsc4py.PETSc``.

    Parameters
    ----------
    done : set, optional
        Module-level names to leave out of the stub. It is forwarded to
        ``visit_module``, which updates it in place.

    Returns
    -------
    Lines
        The ``IMPORTS`` header, a blank line, the module body and the
        ``TYPING`` trailer.
    """
    from petsc4py import PETSc as module

    lines = Lines()
    lines.add = IMPORTS
    lines.add = ''
    lines.add = visit_module(module, done)
    lines.add = TYPING
    return lines
535
536
def generate(filename):
    """Write the ``petsc4py.PETSc`` stub to *filename* as UTF-8 text.

    Missing parent directories are created. A bare file name (no directory
    component) is written to the current directory.
    """
    dirname = os.path.dirname(filename)
    # os.makedirs('') raises FileNotFoundError, so only create a real path.
    if dirname:
        os.makedirs(dirname, exist_ok=True)
    # Build the stub before opening (and truncating) the target, so a failure
    # while introspecting the module does not leave an empty stub behind.
    lines = visit_petsc4py_PETSc()
    with open(filename, 'w', encoding='utf-8') as f:
        for line in lines:
            print(line, file=f)
543
544
# Directory receiving the generated stub, relative to the current working
# directory.
OUTDIR = os.path.join('src', 'petsc4py')

if __name__ == '__main__':
    stub_path = os.path.join(OUTDIR, 'PETSc.pyi')
    generate(stub_path)
549