import os
import inspect
import textwrap


def is_cyfunction(obj):
    """Return True if *obj* is a Cython-compiled function or method."""
    # Compared by type name because the Cython function type is not
    # importable from here.
    return type(obj).__name__ == 'cython_function_or_method'


def is_function(obj):
    """Return True if *obj* should be emitted as a module-level function."""
    return inspect.isbuiltin(obj) or is_cyfunction(obj) or type(obj) is type(ord)


def is_method(obj):
    """Return True if *obj* should be emitted as a method stub."""
    return (
        inspect.ismethoddescriptor(obj)
        or inspect.ismethod(obj)
        or is_cyfunction(obj)
        # The remaining checks accept the builtin descriptor/function types
        # used by ``str`` methods (sampled from ``str`` itself).
        or type(obj)
        in (
            type(str.index),
            type(str.__add__),
            type(str.__new__),
        )
    )


def is_classmethod(obj):
    """Return True if the raw class-dict entry *obj* is a classmethod."""
    return inspect.isbuiltin(obj) or type(obj).__name__ in (
        'classmethod',
        'classmethod_descriptor',
    )


def is_staticmethod(obj):
    """Return True if the raw class-dict entry *obj* is a staticmethod."""
    return type(obj).__name__ in ('staticmethod',)


def is_constant(obj):
    """Return True for plain ``int``/``float``/``str`` values."""
    return isinstance(obj, (int, float, str))


def is_datadescr(obj):
    """Return True for data descriptors that are not property-like (no ``fget``)."""
    return inspect.isdatadescriptor(obj) and not hasattr(obj, 'fget')


def is_property(obj):
    """Return True for property-like data descriptors (having ``fget``)."""
    return inspect.isdatadescriptor(obj) and hasattr(obj, 'fget')


def is_class(obj):
    """Return True if *obj* is a class."""
    return inspect.isclass(obj) or type(obj) is type(int)


class Lines(list):
    """A list of output lines with an indentation-aware ``add`` sink.

    Assigning to ``add`` appends text indented by ``level`` steps of
    ``INDENT``:

    * a ``str`` is dedented, stripped and split on newlines;
    * any other iterable is taken as an already-split sequence of lines;
    * ``None`` is ignored.

    Blank lines are appended without indentation, so the output never
    contains whitespace-only lines.
    """

    INDENT = ' ' * 4
    level = 0

    @property
    def add(self):
        return self

    @add.setter
    def add(self, lines):
        if lines is None:
            return
        if isinstance(lines, str):
            lines = textwrap.dedent(lines).strip().split('\n')
        indent = self.INDENT * self.level
        for line in lines:
            # Indenting an empty line would only produce trailing whitespace
            # (e.g. when a nested visit returns '').
            self.append(indent + line if line else line)


def signature(obj):
    """Return the signature text from the first line of *obj*'s docstring.

    A leading qualifier such as ``Vec.`` in ``Vec.getSize(self) -> int`` is
    removed. A missing docstring falls back to ``'<name>: Any'``. Returns
    ``None`` when the resulting text is empty.
    """
    doc = obj.__doc__
    doc = doc or f'{obj.__name__}: Any'  # FIXME remove line
    first = doc.split('\n', 1)[0]
    # Strip the qualifier only if the first dot comes before the argument
    # list. A dot inside the arguments (e.g. a default of ``1.5``) must not
    # be mistaken for the qualifier separator.
    if '.' in first.partition('(')[0]:
        first = first.split('.', 1)[1]
    return first or None


def visit_constant(constant):
    """Return a ``Final`` stub line for a ``(name, value)`` pair."""
    name, value = constant
    return f'{name}: Final[{type(value).__name__}] = ...'


def visit_function(function):
    """Return a one-line ``def`` stub for a module-level function."""
    sig = signature(function)
    return f'def {sig}: ...'
# Qualified method names ('<Class>.<method>') whose stub gets a
# '# type: ignore[override]' marker (see visit_method).
incompatible_overrides = [
    'DMDA.create',
    'DMStag.create',
    'DMSwarm.getField',
    'DMSwarm.setType',
    'ViewerHDF5.create',
    'SF.compose',
]


def visit_method(method, clas_name=None):
    """Return a one-line ``def`` stub for *method*.

    *clas_name* is the qualified name of the owning class. Methods listed
    in ``incompatible_overrides`` get a ``# type: ignore[override]`` marker.
    """
    sig = signature(method)
    stub = f'def {sig}: ...'
    if f'{clas_name}.{method.__name__}' in incompatible_overrides:
        stub += ' # type: ignore[override]'

    return stub


def visit_datadescr(datadescr):
    """Return the stub line for a (non-property) data descriptor."""
    sig = signature(datadescr)
    return f'{sig}'


def visit_property(prop, name=None):
    """Return an attribute annotation ``name: type`` for a property.

    The type is taken from the text after the last ``->`` in the getter's
    signature. The name defaults to the getter's ``__name__``.
    """
    sig = signature(prop.fget)
    pname = name or prop.fget.__name__
    ptype = sig.rsplit('->', 1)[-1].strip()
    return f'{pname}: {ptype}'


def visit_constructor(cls, name='__init__', args=None):
    """Return a ``__init__``/``__new__`` stub for *cls*.

    Unless *args* is given, the single argument is an optional instance of
    *cls*, named after the lower-cased class name.
    """
    init = name == '__init__'
    argname = cls.__name__.lower()
    argtype = cls.__name__
    initarg = args or f'{argname}: Optional[{argtype}] = None'
    selfarg = 'self' if init else 'cls'
    rettype = 'None' if init else argtype
    arglist = f'{selfarg}, {initarg}'
    sig = f'{name}({arglist}) -> {rettype}'
    return f'def {sig}: ...'


# Qualified names of classes already emitted; a class is generated only once
# per process, whether nested or top-level.
visited_classes = set()


def visit_class(cls, outer=None, done=None):
    """Return the stub lines for class *cls*, or ``''`` if already emitted.

    *outer* is the name of the enclosing class when *cls* is nested. A class
    named ``f'{outer}X'`` is emitted as ``X`` with qualified name
    ``f'{outer}.X'``. *done* is an optional set of member names to treat as
    already handled; it is updated in place.

    Raises ``RuntimeError`` if some class member could not be classified.
    """
    skip = {
        '__doc__',
        '__dict__',
        '__module__',
        '__weakref__',
        '__pyx_vtable__',
        '__enum2str',  # FIXME refactor implementation
        '_traceback_',  # FIXME maybe refactor?
        '__lt__',
        '__le__',
        '__ge__',
        '__gt__',
    }
    # Dunder methods that get a fixed, hand-written signature.
    special = {
        '__len__': '__len__(self) -> int',
        '__bool__': '__bool__(self) -> bool',
        '__hash__': '__hash__(self) -> int',
        '__int__': '__int__(self) -> int',
        '__index__': '__index__(self) -> int',
        '__str__': '__str__(self) -> str',
        '__repr__': '__repr__(self) -> str',
        '__eq__': '__eq__(self, other: object) -> bool',
        '__ne__': '__ne__(self, other: object) -> bool',
    }
    constructor = (
        '__new__',
        '__init__',
    )

    qualname = cls.__name__
    cls_name = cls.__name__
    if outer is not None and cls_name.startswith(outer):
        cls_name = cls_name[len(outer):]
        qualname = f'{outer}.{cls_name}'

    if qualname in visited_classes:
        return ''

    visited_classes.add(qualname)

    override = OVERRIDE.get(qualname, {})
    done = set() if done is None else done
    lines = Lines()

    # A class that refuses to be subclassed (TypeError) is marked @final.
    try:

        class sub(cls):
            pass

        final = False
    except TypeError:
        final = True
    if final:
        lines.add = '@final'
    base = cls.__base__
    if base is object:
        lines.add = f'class {cls_name}:'
    else:
        lines.add = f'class {cls_name}({base.__name__}):'
    lines.level += 1
    start = len(lines)

    # Constructors are only emitted when OVERRIDE supplies them.
    for name in constructor:
        if name in override:
            continue
        if name in cls.__dict__:
            done.add(name)

    # An explicit ``__hash__ = None`` means unhashable: emit no __hash__ stub.
    if '__hash__' in cls.__dict__:
        if cls.__hash__ is None:
            done.add('__hash__')

    dct = cls.__dict__
    keys = list(dct.keys())

    def dunder(name):
        return name.startswith('__') and name.endswith('__')

    def members(seq):
        # Yield names still to be emitted. Dunders without a special or
        # override entry are silently marked done.
        for name in seq:
            if name in skip:
                continue
            if name in done:
                continue
            if dunder(name):
                if name not in special and name not in override:
                    done.add(name)
                continue
            yield name

    # First pass: nested classes.
    for name in members(keys):
        attr = getattr(cls, name)
        if is_class(attr):
            done.add(name)
            lines.add = visit_class(attr, outer=cls_name)

    # Second pass: overrides, special dunders, methods, descriptors, constants.
    for name in members(keys):
        if name in override:
            done.add(name)
            lines.add = override[name]
            continue

        if name in special:
            done.add(name)
            sig = special[name]
            lines.add = f'def {sig}: ...'
            continue

        attr = getattr(cls, name)

        if is_method(attr):
            done.add(name)
            if name == attr.__name__:
                obj = dct[name]  # raw dict entry keeps classmethod/staticmethod
                if is_classmethod(obj):
                    lines.add = '@classmethod'
                elif is_staticmethod(obj):
                    lines.add = '@staticmethod'
                lines.add = visit_method(attr, qualname)
            else:
                # Stored under a different name: emit an alias.
                lines.add = f'{name} = {attr.__name__}'
            continue

        if is_datadescr(attr):
            done.add(name)
            lines.add = visit_datadescr(attr)
            continue

        if is_property(attr):
            done.add(name)
            lines.add = visit_property(attr, name)
            continue

        if is_constant(attr):
            done.add(name)
            lines.add = visit_constant((name, attr))
            continue

    leftovers = [name for name in keys if name not in done and name not in skip]
    if leftovers:
        raise RuntimeError(f'leftovers: {leftovers}')

    if len(lines) == start:
        lines.add = '...'
    lines.level -= 1
    return lines


def visit_module(module, done=None):
    """Return the stub lines for *module*.

    Emission order: integer constants; classes defined in *module*, each
    followed by module attributes whose type is exactly that class;
    functions (or aliases); then every remaining attribute as a constant.
    Top-level OVERRIDE entries replace the generated line for constants.

    Raises ``RuntimeError`` if some attribute was not emitted.
    """
    skip = {
        '__doc__',
        '__name__',
        '__loader__',
        '__spec__',
        '__file__',
        '__package__',
        '__builtins__',
        '__pyx_unpickle_Enum',  # FIXME review
    }

    done = set() if done is None else done
    lines = Lines()

    # Stable sort: public names first, underscore-prefixed names last.
    keys = list(module.__dict__.keys())
    keys.sort(key=lambda name: name.startswith('_'))

    constants = [
        (name, getattr(module, name))
        for name in keys
        if all(
            (
                name not in done and name not in skip,
                isinstance(getattr(module, name), int),
            )
        )
    ]
    for name, value in constants:
        done.add(name)
        if name in OVERRIDE:
            lines.add = OVERRIDE[name]
        else:
            lines.add = visit_constant((name, value))
    if constants:
        lines.add = ''

    for name in keys:
        if name in done or name in skip:
            continue
        value = getattr(module, name)

        if is_class(value):
            done.add(name)
            # Classes re-exported from other modules are not emitted here.
            if value.__module__ != module.__name__:
                continue
            lines.add = visit_class(value)
            lines.add = ''
            instances = [
                (k, getattr(module, k))
                for k in keys
                if all(
                    (
                        k not in done and k not in skip,
                        type(getattr(module, k)) is value,
                    )
                )
            ]
            for attrname, attrvalue in instances:
                done.add(attrname)
                lines.add = visit_constant((attrname, attrvalue))
            if instances:
                lines.add = ''
            continue

        if is_function(value):
            done.add(name)
            if name == value.__name__:
                lines.add = visit_function(value)
            else:
                lines.add = f'{name} = {value.__name__}'
            continue

    lines.add = ''
    for name in keys:
        if name in done or name in skip:
            continue
        value = getattr(module, name)
        done.add(name)
        if name in OVERRIDE:
            lines.add = OVERRIDE[name]
        else:
            lines.add = visit_constant((name, value))

    leftovers = [name for name in keys if name not in done and name not in skip]
    if leftovers:
        raise RuntimeError(f'leftovers: {leftovers}')
    return lines


# Header written at the top of the generated stub (dedented and stripped by
# Lines.add).
IMPORTS = """
from __future__ import annotations
import sys
from threading import Lock
from typing import (
    Any,
    Union,
    Optional,
    NoReturn,
    overload,
)
if sys.version_info >= (3, 8):
    from typing import (
        final,
        Final,
        Literal,
    )
else:
    from typing_extensions import (
        final,
        Final,
        Literal,
    )
if sys.version_info >= (3, 9):
    from collections.abc import (
        Callable,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
        Mapping,
    )
else:
    from typing import (
        Callable,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
        Mapping,
    )
if sys.version_info >= (3, 11):
    from typing import Self
else:
    from typing_extensions import Self
from os import PathLike

import numpy

from numpy import (
    dtype,
    ndarray,
)

from mpi4py.MPI import (
    Datatype,
    Intracomm,
    Op,
)

from petsc4py.typing import (
    Scalar,
    ArrayBool,
    ArrayComplex,
    ArrayInt,
    ArrayReal,
    ArrayScalar,
    CSRIndicesSpec,
    CSRSpec,
    DMCoarsenHookFunction,
    DMRestrictHookFunction,
    DimsSpec,
    KSPConvergenceTestFunction,
    KSPMonitorFunction,
    KSPOperatorsFunction,
    KSPPostSolveFunction,
    KSPPreSolveFunction,
    KSPRHSFunction,
    LayoutSizeSpec,
    MatAssemblySpec,
    MatBlockSizeSpec,
    MatNullFunction,
    MatSizeSpec,
    NNZSpec,
    NormTypeSpec,
    PetscOptionsHandlerFunction,
    ScatterModeSpec,
    SNESMonitorFunction,
    SNESObjFunction,
    SNESFunction,
    SNESJacobianFunction,
    SNESGuessFunction,
    SNESUpdateFunction,
    SNESLSPreFunction,
    SNESNGSFunction,
    SNESConvergedFunction,
    TAOConstraintsFunction,
    TAOConstraintsJacobianFunction,
    TAOConvergedFunction,
    TAOGradientFunction,
    TAOHessianFunction,
    TAOJacobianFunction,
    TAOJacobianResidualFunction,
    TAOMonitorFunction,
    TAOObjectiveFunction,
    TAOObjectiveGradientFunction,
    TAOResidualFunction,
    TAOUpdateFunction,
    TAOVariableBoundsFunction,
    TAOLSGradientFunction,
    TAOLSObjectiveFunction,
    TAOLSObjectiveGradientFunction,
    TSI2Function,
    TSI2Jacobian,
    TSI2JacobianP,
    TSIFunction,
    TSIJacobian,
    TSIJacobianP,
    TSIndicatorFunction,
    TSMonitorFunction,
    TSPostEventFunction,
    TSPostStepFunction,
    TSPreStepFunction,
    TSRHSFunction,
    TSRHSJacobian,
    TSRHSJacobianP,
    AccessModeSpec,
    InsertModeSpec,
)

IntType: numpy.dtype = ...
RealType: numpy.dtype = ...
ComplexType: numpy.dtype = ...
ScalarType: numpy.dtype = ...
"""

# Hand-written stub text. A dict value maps member names of the class with
# that qualified name to stub lines (used by visit_class). A str value
# replaces the stub line for that top-level module attribute (used by
# visit_module).
OVERRIDE = {
    'Error': {
        '__init__': 'def __init__(self, ierr: int = 0) -> None: ...',
    },
    'Options': {
        '__init__': 'def __init__(self, prefix: str | None = None) -> None: ...',
    },
    '__pyx_capi__': '__pyx_capi__: Final[dict[str, Any]] = ...',
    '__type_registry__': '__type_registry__: Final[dict[int, type[Object]]] = ...',
}

# Extra text appended after the generated module stubs (currently empty).
TYPING = """
"""


def visit_petsc4py_PETSc(done=None):
    """Return the complete stub lines for ``petsc4py.PETSc``."""
    from petsc4py import PETSc as module

    lines = Lines()
    lines.add = IMPORTS
    lines.add = ''
    lines.add = visit_module(module)
    lines.add = TYPING
    return lines


def generate(filename):
    """Write the ``petsc4py.PETSc`` stub to *filename*.

    Parent directories are created as needed. A bare filename (no directory
    part) is written to the current directory.
    """
    dirname = os.path.dirname(filename)
    # os.makedirs('') raises FileNotFoundError, so only create a real path.
    if dirname:
        os.makedirs(dirname, exist_ok=True)
    # Python source (including .pyi stubs) is UTF-8 by default; do not rely
    # on the platform's locale encoding.
    with open(filename, 'w', encoding='utf-8') as f:
        for line in visit_petsc4py_PETSc():
            print(line, file=f)


OUTDIR = os.path.join('src', 'petsc4py')

if __name__ == '__main__':
    generate(os.path.join(OUTDIR, 'PETSc.pyi'))