"""Generate the ``PETSc.pyi`` type stub by introspecting ``petsc4py.PETSc``.

Run as a script to write ``src/petsc4py/PETSc.pyi`` (relative to the current
directory).  Stub lines are derived from the first line of each object's
docstring (see ``signature``) plus the hand-written ``IMPORTS``, ``OVERRIDE``
and ``TYPING`` fragments defined in this module.
"""

import os
import inspect
import textwrap


def is_cyfunction(obj):
    """Return True if ``obj`` is a Cython-compiled function or method."""
    return type(obj).__name__ == 'cython_function_or_method'


def is_function(obj):
    """Return True if ``obj`` is a builtin or Cython function."""
    return inspect.isbuiltin(obj) or is_cyfunction(obj) or type(obj) is type(ord)


def is_method(obj):
    """Return True if ``obj`` should be emitted as a method stub.

    Accepts method descriptors, bound methods, Cython functions, and objects
    whose type matches the builtin method / slot-wrapper / builtin-function
    types sampled from ``str``.
    """
    return (
        inspect.ismethoddescriptor(obj)
        or inspect.ismethod(obj)
        or is_cyfunction(obj)
        or type(obj)
        in (
            type(str.index),
            type(str.__add__),
            type(str.__new__),
        )
    )


def is_classmethod(obj):
    """Return True if the raw class-dict entry ``obj`` is a class method."""
    return inspect.isbuiltin(obj) or type(obj).__name__ in (
        'classmethod',
        'classmethod_descriptor',
    )


def is_staticmethod(obj):
    """Return True if the raw class-dict entry ``obj`` is a static method."""
    return type(obj).__name__ in ('staticmethod',)


def is_constant(obj):
    """Return True for values emitted as ``Final`` constants."""
    return isinstance(obj, (int, float, str))


def is_datadescr(obj):
    """Return True for data descriptors that are not ``property`` objects."""
    return inspect.isdatadescriptor(obj) and not hasattr(obj, 'fget')


def is_property(obj):
    """Return True for data descriptors exposing an ``fget`` (properties)."""
    return inspect.isdatadescriptor(obj) and hasattr(obj, 'fget')


def is_class(obj):
    """Return True if ``obj`` is a class."""
    return inspect.isclass(obj) or type(obj) is type(int)


class Lines(list):
    """A list of output lines with a current indentation level.

    Assigning to ``add`` appends to the list:

    * ``None`` is ignored;
    * a string is dedented, stripped and split on newlines;
    * any other iterable (e.g. another ``Lines``) is appended item by item.

    Every non-empty appended line is prefixed with ``INDENT * level``.
    """

    INDENT = ' ' * 4
    level = 0

    @property
    def add(self):
        return self

    @add.setter
    def add(self, lines):
        if lines is None:
            return
        if isinstance(lines, str):
            lines = textwrap.dedent(lines).strip().split('\n')
        indent = self.INDENT * self.level
        for line in lines:
            # Keep blank lines empty: no whitespace-only lines in the stub.
            self.append(indent + line if line else line)


def signature(obj):
    """Return a stub signature taken from the first line of ``obj.__doc__``.

    The first docstring line is expected to look like
    ``Class.method(args) -> T`` (presumably the signature embedded by Cython;
    confirm), and the leading ``Class.`` qualifier is dropped.  Objects without
    a docstring fall back to ``'<name>: Any'``.  Returns ``None`` if nothing is
    left after stripping.
    """
    doc = obj.__doc__
    doc = doc or f'{obj.__name__}: Any'  # FIXME remove line
    head = doc.split('\n', 1)[0]
    # Strip the qualifier only from the part before the argument list, so a
    # dot inside the arguments (e.g. a default such as ``1.0``) is preserved.
    name, paren, rest = head.partition('(')
    sig = name.split('.', 1)[-1] + paren + rest
    return sig or None


def visit_constant(constant):
    """Return a ``Final`` stub line for a ``(name, value)`` pair."""
    name, value = constant
    return f'{name}: Final[{type(value).__name__}] = ...'


def visit_function(function):
    """Return a ``def`` stub line for a module-level function."""
    sig = signature(function)
    return f'def {sig}: ...'


# ``Class.method`` names whose stubs get ``# type: ignore[override]``.
incompatible_overrides = [
    'DMDA.create',
    'DMStag.create',
    'DMSwarm.getField',
    'DMSwarm.setType',
    'ViewerHDF5.create',
    'SF.compose',
]


def visit_method(method, clas_name=None):
    """Return a ``def`` stub line for ``method`` of class ``clas_name``.

    Methods listed in ``incompatible_overrides`` get a trailing
    ``# type: ignore[override]`` comment.
    """
    sig = signature(method)
    stub = f'def {sig}: ...'
    if f'{clas_name}.{method.__name__}' in incompatible_overrides:
        stub += ' # type: ignore[override]'
    return stub


def visit_datadescr(datadescr):
    """Return the stub line for a non-property data descriptor."""
    sig = signature(datadescr)
    return f'{sig}'


def visit_property(prop, name=None):
    """Return an attribute stub ``'name: T'`` for a ``property``.

    ``T`` is the return annotation of the getter's signature.  When the getter
    signature has no ``->`` annotation (including the no-docstring fallback of
    ``signature``), ``T`` falls back to ``Any`` rather than producing an
    invalid stub.
    """
    sig = signature(prop.fget)
    pname = name or prop.fget.__name__
    if sig and '->' in sig:
        ptype = sig.rsplit('->', 1)[-1].strip()
    else:
        ptype = 'Any'
    return f'{pname}: {ptype}'


def visit_constructor(cls, name='__init__', args=None):
    """Return a ``__init__`` (or ``__new__``-style) stub line for ``cls``.

    ``args`` replaces the default single ``Optional[cls]`` argument.
    """
    init = name == '__init__'
    argname = cls.__name__.lower()
    argtype = cls.__name__
    initarg = args or f'{argname}: Optional[{argtype}] = None'
    selfarg = 'self' if init else 'cls'
    rettype = 'None' if init else argtype
    arglist = f'{selfarg}, {initarg}'
    sig = f'{name}({arglist}) -> {rettype}'
    return f'def {sig}: ...'


# Qualified names of classes already emitted (guards against duplicates).
visited_classes = set()


def visit_class(cls, outer=None, done=None):
    """Return the stub lines for class ``cls``, or ``''`` if already emitted.

    ``outer`` is the name of the enclosing class for nested classes.  When the
    nested class name starts with it, the prefix is dropped (a class
    ``FooBar`` nested in ``Foo`` is emitted as ``Bar``, qualified ``Foo.Bar``).
    ``done`` is a set of member names treated as already handled; it is
    updated in place.

    Raises ``RuntimeError`` if some member cannot be classified.
    """
    skip = {
        '__doc__',
        '__dict__',
        '__module__',
        '__weakref__',
        '__pyx_vtable__',
        '__enum2str',  # FIXME refactor implementation
        '_traceback_',  # FIXME maybe refactor?
        '__lt__',
        '__le__',
        '__ge__',
        '__gt__',
    }
    # Dunder methods emitted with fixed signatures.
    special = {
        '__len__': '__len__(self) -> int',
        '__bool__': '__bool__(self) -> bool',
        '__hash__': '__hash__(self) -> int',
        '__int__': '__int__(self) -> int',
        '__index__': '__index__(self) -> int',
        '__str__': '__str__(self) -> str',
        '__repr__': '__repr__(self) -> str',
        '__eq__': '__eq__(self, other: object) -> bool',
        '__ne__': '__ne__(self, other: object) -> bool',
    }
    constructor = (
        '__new__',
        '__init__',
    )

    qualname = cls.__name__
    cls_name = cls.__name__
    if outer is not None and cls_name.startswith(outer):
        cls_name = cls_name[len(outer) :]
        qualname = f'{outer}.{cls_name}'
    if qualname in visited_classes:
        return ''
    visited_classes.add(qualname)

    override = OVERRIDE.get(qualname, {})
    done = set() if done is None else done
    lines = Lines()

    # A class that refuses subclassing is marked ``@final``.
    try:

        class sub(cls):
            pass

        final = False
    except TypeError:
        final = True
    if final:
        lines.add = '@final'
    base = cls.__base__
    if base is object:
        lines.add = f'class {cls_name}:'
    else:
        lines.add = f'class {cls_name}({base.__name__}):'

    lines.level += 1
    start = len(lines)

    # Constructors are not emitted unless an override provides them.
    for name in constructor:
        if name in override:
            continue
        if name in cls.__dict__:
            done.add(name)
    # ``__hash__ = None`` marks the class unhashable; emit nothing for it.
    if '__hash__' in cls.__dict__:
        if cls.__hash__ is None:
            done.add('__hash__')

    dct = cls.__dict__
    keys = list(dct.keys())

    def dunder(name):
        return name.startswith('__') and name.endswith('__')

    def members(seq):
        # Yield names still to process; unknown dunders are marked done.
        for name in seq:
            if name in skip:
                continue
            if name in done:
                continue
            if dunder(name):
                if name not in special and name not in override:
                    done.add(name)
                    continue
            yield name

    # Nested classes first.
    for name in members(keys):
        attr = getattr(cls, name)
        if is_class(attr):
            done.add(name)
            lines.add = visit_class(attr, outer=cls_name)

    for name in members(keys):
        if name in override:
            done.add(name)
            lines.add = override[name]
            continue

        if name in special:
            done.add(name)
            sig = special[name]
            lines.add = f'def {sig}: ...'
            continue

        attr = getattr(cls, name)

        if is_method(attr):
            done.add(name)
            if name == attr.__name__:
                # Decorators are decided from the raw class-dict entry.
                obj = dct[name]
                if is_classmethod(obj):
                    lines.add = '@classmethod'
                elif is_staticmethod(obj):
                    lines.add = '@staticmethod'
                lines.add = visit_method(attr, qualname)
            else:
                # Alias of a method defined under another name.
                lines.add = f'{name} = {attr.__name__}'
            continue

        if is_datadescr(attr):
            done.add(name)
            lines.add = visit_datadescr(attr)
            continue

        if is_property(attr):
            done.add(name)
            lines.add = visit_property(attr, name)
            continue

        if is_constant(attr):
            done.add(name)
            lines.add = visit_constant((name, attr))
            continue

    leftovers = [name for name in keys if name not in done and name not in skip]
    if leftovers:
        raise RuntimeError(f'leftovers: {leftovers}')

    if len(lines) == start:
        lines.add = '...'
    lines.level -= 1
    return lines


def visit_module(module, done=None):
    """Return the stub lines for ``module``.

    Emits, in order: integer constants, then classes defined in ``module``
    (each followed by module attributes whose type is exactly that class) and
    functions, then every remaining attribute as a constant or ``OVERRIDE``
    entry.  ``done`` is a set of names treated as already handled; it is
    updated in place.

    Raises ``RuntimeError`` if some attribute was not handled.
    """
    skip = {
        '__doc__',
        '__name__',
        '__loader__',
        '__spec__',
        '__file__',
        '__package__',
        '__builtins__',
        '__pyx_unpickle_Enum',  # FIXME review
    }

    done = set() if done is None else done
    lines = Lines()

    keys = list(module.__dict__.keys())
    # Stable sort: public names first, underscore names last.
    keys.sort(key=lambda name: name.startswith('_'))

    constants = [
        (name, getattr(module, name))
        for name in keys
        if all(
            (
                name not in done and name not in skip,
                isinstance(getattr(module, name), int),
            )
        )
    ]
    for name, value in constants:
        done.add(name)
        if name in OVERRIDE:
            lines.add = OVERRIDE[name]
        else:
            lines.add = visit_constant((name, value))
    if constants:
        lines.add = ''

    for name in keys:
        if name in done or name in skip:
            continue
        value = getattr(module, name)

        if is_class(value):
            done.add(name)
            # Classes imported from other modules are not re-declared.
            if value.__module__ != module.__name__:
                continue
            lines.add = visit_class(value)
            lines.add = ''
            instances = [
                (k, getattr(module, k))
                for k in keys
                if all(
                    (
                        k not in done and k not in skip,
                        type(getattr(module, k)) is value,
                    )
                )
            ]
            for attrname, attrvalue in instances:
                done.add(attrname)
                lines.add = visit_constant((attrname, attrvalue))
            if instances:
                lines.add = ''
            continue

        if is_function(value):
            done.add(name)
            if name == value.__name__:
                lines.add = visit_function(value)
            else:
                lines.add = f'{name} = {value.__name__}'
            continue

    lines.add = ''
    for name in keys:
        if name in done or name in skip:
            continue
        value = getattr(module, name)
        done.add(name)
        if name in OVERRIDE:
            lines.add = OVERRIDE[name]
        else:
            lines.add = visit_constant((name, value))

    leftovers = [name for name in keys if name not in done and name not in skip]
    if leftovers:
        raise RuntimeError(f'leftovers: {leftovers}')
    return lines


# Header of the generated stub.
IMPORTS = """
from __future__ import annotations

import sys
from threading import Lock

from typing import (
    Any,
    Union,
    Optional,
    NoReturn,
    overload,
)

if sys.version_info >= (3, 8):
    from typing import (
        final,
        Final,
        Literal,
    )
else:
    from typing_extensions import (
        final,
        Final,
        Literal,
    )

if sys.version_info >= (3, 9):
    from collections.abc import (
        Callable,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
        Mapping,
    )
else:
    from typing import (
        Callable,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
        Mapping,
    )

if sys.version_info >= (3, 11):
    from typing import Self
else:
    from typing_extensions import Self

from os import PathLike

import numpy
from numpy import (
    dtype,
    ndarray,
)

from mpi4py.MPI import (
    Datatype,
    Intracomm,
    Op,
)

from petsc4py.typing import (
    Scalar,
    ArrayBool,
    ArrayComplex,
    ArrayInt,
    ArrayReal,
    ArrayScalar,
    CSRIndicesSpec,
    CSRSpec,
    DMCoarsenHookFunction,
    DMRestrictHookFunction,
    DimsSpec,
    KSPConvergenceTestFunction,
    KSPMonitorFunction,
    KSPOperatorsFunction,
    KSPPostSolveFunction,
    KSPPreSolveFunction,
    KSPRHSFunction,
    LayoutSizeSpec,
    MatAssemblySpec,
    MatBlockSizeSpec,
    MatNullFunction,
    MatSizeSpec,
    NNZSpec,
    NormTypeSpec,
    PetscOptionsHandlerFunction,
    ScatterModeSpec,
    SNESMonitorFunction,
    SNESObjFunction,
    SNESFunction,
    SNESJacobianFunction,
    SNESGuessFunction,
    SNESUpdateFunction,
    SNESLSPreFunction,
    SNESNGSFunction,
    SNESConvergedFunction,
    TAOConstraintsFunction,
    TAOConstraintsJacobianFunction,
    TAOConvergedFunction,
    TAOGradientFunction,
    TAOHessianFunction,
    TAOJacobianFunction,
    TAOJacobianResidualFunction,
    TAOMonitorFunction,
    TAOObjectiveFunction,
    TAOObjectiveGradientFunction,
    TAOResidualFunction,
    TAOUpdateFunction,
    TAOVariableBoundsFunction,
    TAOLSGradientFunction,
    TAOLSObjectiveFunction,
    TAOLSObjectiveGradientFunction,
    TSI2Function,
    TSI2Jacobian,
    TSI2JacobianP,
    TSIFunction,
    TSIJacobian,
    TSIJacobianP,
    TSIndicatorFunction,
    TSMonitorFunction,
    TSPostEventFunction,
    TSPostStepFunction,
    TSPreStepFunction,
    TSRHSFunction,
    TSRHSJacobian,
    TSRHSJacobianP,
    AccessModeSpec,
    InsertModeSpec,
)

IntType: numpy.dtype = ...
RealType: numpy.dtype = ...
ComplexType: numpy.dtype = ...
ScalarType: numpy.dtype = ...
"""

# Hand-written stubs: per-class member overrides (dict values keyed by the
# class qualname) or whole-line replacements for module attributes.
OVERRIDE = {
    'Error': {
        '__init__': 'def __init__(self, ierr: int = 0) -> None: ...',
    },
    'Options': {
        '__init__': 'def __init__(self, prefix: str | None = None) -> None: ...',
    },
    '__pyx_capi__': '__pyx_capi__: Final[dict[str, Any]] = ...',
    '__type_registry__': '__type_registry__: Final[dict[int, type[Object]]] = ...',
}

# Trailer of the generated stub (currently empty).
TYPING = """
"""


def visit_petsc4py_PETSc(done=None):
    """Return all stub lines for ``petsc4py.PETSc``.

    ``done`` is forwarded to ``visit_module`` as the set of names to skip.
    """
    from petsc4py import PETSc as module

    lines = Lines()
    lines.add = IMPORTS
    lines.add = ''
    lines.add = visit_module(module, done)
    lines.add = TYPING
    return lines


def generate(filename):
    """Write the ``petsc4py.PETSc`` stub to ``filename`` (UTF-8).

    The stub is built before the file is opened, so a failed generation does
    not truncate an existing stub.
    """
    lines = visit_petsc4py_PETSc()
    dirname = os.path.dirname(filename)
    if dirname:
        os.makedirs(dirname, exist_ok=True)
    with open(filename, 'w', encoding='utf-8') as f:
        for line in lines:
            print(line, file=f)


OUTDIR = os.path.join('src', 'petsc4py')

if __name__ == '__main__':
    generate(os.path.join(OUTDIR, 'PETSc.pyi'))