import os
import inspect
import textwrap


def is_cyfunction(obj):
    """Return True if the type of *obj* is named ``cython_function_or_method``."""
    return type(obj).__name__ == 'cython_function_or_method'


def is_function(obj):
    """Return True if *obj* is a builtin, a Cython function, or shares ``ord``'s type."""
    return inspect.isbuiltin(obj) or is_cyfunction(obj) or type(obj) is type(ord)


def is_method(obj):
    """Return True if *obj* looks like a method of an extension type.

    Accepts method descriptors, bound methods, Cython functions, and any
    object whose exact type matches that of ``str.index``, ``str.__add__``
    or ``str.__new__``.
    """
    return (
        inspect.ismethoddescriptor(obj)
        or inspect.ismethod(obj)
        or is_cyfunction(obj)
        or type(obj)
        in (
            type(str.index),
            type(str.__add__),
            type(str.__new__),
        )
    )


def is_classmethod(obj):
    """Return True if *obj* is a builtin or its type is named like a classmethod."""
    return inspect.isbuiltin(obj) or type(obj).__name__ in (
        'classmethod',
        'classmethod_descriptor',
    )


def is_staticmethod(obj):
    """Return True if the type of *obj* is named ``staticmethod``."""
    return type(obj).__name__ == 'staticmethod'


def is_constant(obj):
    """Return True if *obj* is an ``int``, ``float`` or ``str`` instance."""
    return isinstance(obj, (int, float, str))


def is_datadescr(obj):
    """Return True if *obj* is a data descriptor without an ``fget`` attribute."""
    return inspect.isdatadescriptor(obj) and not hasattr(obj, 'fget')


def is_property(obj):
    """Return True if *obj* is a data descriptor that has an ``fget`` attribute."""
    return inspect.isdatadescriptor(obj) and hasattr(obj, 'fget')


def is_class(obj):
    """Return True if *obj* is a class (or its type is exactly ``type``)."""
    return inspect.isclass(obj) or type(obj) is type(int)


class Lines(list):
    """A list of output lines with a write-only style ``add`` accumulator.

    Assigning to ``add`` appends lines, each prefixed with ``INDENT``
    repeated ``level`` times; reading ``add`` returns the list itself.
    """

    INDENT = ' ' * 4
    # Current nesting depth; callers bump it with ``lines.level += 1``.
    level = 0

    @property
    def add(self):
        """Return this list (the getter exists so the setter can be defined)."""
        return self

    @add.setter
    def add(self, lines):
        """Append *lines* at the current indentation level.

        ``None`` is ignored.  A string is dedented, stripped and split on
        newlines first; any other value is iterated as a sequence of lines.
        """
        if lines is None:
            return
        if isinstance(lines, str):
            lines = textwrap.dedent(lines).strip().split('\n')
        indent = self.INDENT * self.level
        for line in lines:
            self.append(indent + line)


def signature(obj):
    """Return the signature found on the first line of ``obj.__doc__``.

    A leading ``Qualifier.`` in front of the name is removed.  When *obj*
    has no docstring, ``'<name>: Any'`` is used instead.  Returns ``None``
    when the resulting text is empty.
    """
    doc = obj.__doc__
    doc = doc or f'{obj.__name__}: Any'  # FIXME remove line
    first = doc.split('\n', 1)[0]
    # Only the name may carry a qualifier.  The name ends at the first '('
    # (callables) or ':' (annotated attributes); dots after that point belong
    # to defaults such as ``1.0`` or annotations such as ``numpy.dtype`` and
    # must not be treated as the qualifier separator.
    stops = [pos for pos in (first.find('('), first.find(':')) if pos != -1]
    cut = min(stops, default=len(first))
    sig = first[:cut].split('.', 1)[-1] + first[cut:]
    return sig or None


def visit_constant(constant):
    """Return a ``Final`` stub declaration for a ``(name, value)`` pair."""
    name, value = constant
    return f'{name}: Final[{type(value).__name__}] = ...'


def visit_function(function):
    """Return a one-line ``def`` stub built from the docstring signature."""
    sig = signature(function)
    return f'def {sig}: ...'


def visit_method(method):
    """Return a one-line ``def`` stub built from the docstring signature."""
    sig = signature(method)
    return f'def {sig}: ...'
def visit_datadescr(datadescr):
    """Return the stub line for a data descriptor (its docstring signature)."""
    sig = signature(datadescr)
    return f'{sig}'


def visit_property(prop, name=None):
    """Return a ``name: type`` stub line for a property.

    The type is whatever follows the last ``->`` in the getter's docstring
    signature; *name* defaults to the getter's ``__name__``.
    """
    sig = signature(prop.fget)
    pname = name or prop.fget.__name__
    ptype = sig.rsplit('->', 1)[-1].strip()
    return f'{pname}: {ptype}'


def visit_constructor(cls, name='__init__', args=None):
    """Return a ``def`` stub for a constructor of *cls*.

    For ``__init__`` the first argument is ``self`` and the return type is
    ``None``; for any other *name* the first argument is ``cls`` and the
    return type is the class name.  When *args* is not given, a single
    optional argument named after the lower-cased class name is used.
    """
    init = name == '__init__'
    argname = cls.__name__.lower()
    argtype = cls.__name__
    initarg = args or f'{argname}: Optional[{argtype}] = None'
    selfarg = 'self' if init else 'cls'
    rettype = 'None' if init else argtype
    arglist = f'{selfarg}, {initarg}'
    sig = f'{name}({arglist}) -> {rettype}'
    return f'def {sig}: ...'


def visit_class(cls, outer=None, done=None):
    """Return the stub lines for class *cls* as a ``Lines`` object.

    Parameters:
        cls: the class to describe.
        outer: name of the enclosing class, if any.  When the class name
            starts with it, that prefix is dropped from the emitted name and
            ``OVERRIDE`` is looked up under ``'<outer>.<name>'``.
        done: set of attribute names already handled; it is updated in place.

    Raises:
        RuntimeError: if an attribute in ``cls.__dict__`` was neither
            skipped nor emitted.
    """
    skip = {
        '__doc__',
        '__dict__',
        '__module__',
        '__weakref__',
        '__pyx_vtable__',
        '__enum2str',  # FIXME refactor implementation
        '_traceback_',  # FIXME maybe refactor?
        '__lt__',
        '__le__',
        '__ge__',
        '__gt__',
    }
    # Dunder methods that are emitted with a fixed, hand-written signature.
    special = {
        '__len__': '__len__(self) -> int',
        '__bool__': '__bool__(self) -> bool',
        '__hash__': '__hash__(self) -> int',
        '__int__': '__int__(self) -> int',
        '__index__': '__index__(self) -> int',
        '__str__': '__str__(self) -> str',
        '__repr__': '__repr__(self) -> str',
        '__eq__': '__eq__(self, other: object) -> bool',
        '__ne__': '__ne__(self, other: object) -> bool',
    }
    constructor = (
        '__new__',
        '__init__',
    )

    qualname = cls.__name__
    cls_name = cls.__name__
    if outer is not None and cls_name.startswith(outer):
        cls_name = cls_name[len(outer):]
        qualname = f'{outer}.{cls_name}'

    override = OVERRIDE.get(qualname, {})
    done = set() if done is None else done
    lines = Lines()

    # A class that refuses to be subclassed (TypeError) is marked @final.
    try:

        class sub(cls):
            pass

    except TypeError:
        final = True
    else:
        final = False
    if final:
        lines.add = '@final'
    base = cls.__base__
    if base is object:
        lines.add = f'class {cls_name}:'
    else:
        lines.add = f'class {cls_name}({base.__name__}):'
    lines.level += 1
    # Remember where the body starts so an empty body can get a 'pass'.
    start = len(lines)

    dct = cls.__dict__
    keys = list(dct.keys())

    # Constructors defined by the class are marked handled without output.
    for name in constructor:
        if name in dct:
            done.add(name)

    # ``__hash__ = None`` marks an unhashable class; emit nothing for it.
    if '__hash__' in dct and cls.__hash__ is None:
        done.add('__hash__')

    def dunder(name):
        return name.startswith('__') and name.endswith('__')

    def members(seq):
        # Yield the names that still need a stub line.  Dunder names with
        # neither a fixed signature nor an override are marked done here.
        for name in seq:
            if name in skip:
                continue
            if name in done:
                continue
            if dunder(name):
                if name not in special and name not in override:
                    done.add(name)
                    continue
            yield name

    # First pass: nested classes, so they come before the other members.
    for name in members(keys):
        attr = getattr(cls, name)
        if is_class(attr):
            done.add(name)
            lines.add = visit_class(attr, outer=cls_name)

    # Second pass: everything else, in ``cls.__dict__`` order.
    for name in members(keys):
        if name in override:
            done.add(name)
            lines.add = override[name]
            continue

        if name in special:
            done.add(name)
            sig = special[name]
            lines.add = f'def {sig}: ...'
            continue

        attr = getattr(cls, name)

        if is_method(attr):
            done.add(name)
            if name == attr.__name__:
                # Inspect the raw dict entry: getattr() has already unwrapped
                # any classmethod/staticmethod wrapper.
                obj = dct[name]
                if is_classmethod(obj):
                    lines.add = '@classmethod'
                elif is_staticmethod(obj):
                    lines.add = '@staticmethod'
                lines.add = visit_method(attr)
            else:
                # The attribute is an alias of a differently named method.
                lines.add = f'{name} = {attr.__name__}'
            continue

        if is_datadescr(attr):
            done.add(name)
            lines.add = visit_datadescr(attr)
            continue

        if is_property(attr):
            done.add(name)
            lines.add = visit_property(attr, name)
            continue

        if is_constant(attr):
            done.add(name)
            lines.add = visit_constant((name, attr))
            continue

    leftovers = [name for name in keys if name not in done and name not in skip]
    if leftovers:
        raise RuntimeError(f'leftovers: {leftovers}')

    if len(lines) == start:
        lines.add = 'pass'
    lines.level -= 1
    return lines


def visit_module(module, done=None):
    """Return the stub lines for *module* as a ``Lines`` object.

    Output order: integer constants, then classes (each followed by the
    module attributes whose exact type is that class) interleaved with
    functions in key order, then every remaining attribute as a constant.
    Names not starting with ``_`` are processed before those that do.

    Parameters:
        module: the module object to describe.
        done: set of names already handled; it is updated in place.

    Raises:
        RuntimeError: if a module attribute was neither skipped nor emitted.
    """
    skip = {
        '__doc__',
        '__name__',
        '__loader__',
        '__spec__',
        '__file__',
        '__package__',
        '__builtins__',
        '__pyx_unpickle_Enum',  # FIXME review
    }

    done = set() if done is None else done
    lines = Lines()

    keys = list(module.__dict__.keys())
    # Stable sort: public names first, underscore-prefixed names last.
    keys.sort(key=lambda name: name.startswith('_'))

    constants = [
        (name, getattr(module, name))
        for name in keys
        if name not in done
        and name not in skip
        and isinstance(getattr(module, name), int)
    ]
    for name, value in constants:
        done.add(name)
        if name in OVERRIDE:
            lines.add = OVERRIDE[name]
        else:
            lines.add = visit_constant((name, value))
    if constants:
        lines.add = ''

    for name in keys:
        if name in done or name in skip:
            continue
        value = getattr(module, name)

        if is_class(value):
            done.add(name)
            # Classes re-exported from other modules are not described.
            if value.__module__ != module.__name__:
                continue
            lines.add = visit_class(value)
            lines.add = ''
            # Module-level instances of this class follow it as constants.
            instances = [
                (k, getattr(module, k))
                for k in keys
                if k not in done
                and k not in skip
                and type(getattr(module, k)) is value
            ]
            for attrname, attrvalue in instances:
                done.add(attrname)
                lines.add = visit_constant((attrname, attrvalue))
            if instances:
                lines.add = ''
            continue

        if is_function(value):
            done.add(name)
            if name == value.__name__:
                lines.add = visit_function(value)
            else:
                # The attribute is an alias of a differently named function.
                lines.add = f'{name} = {value.__name__}'
            continue

    lines.add = ''
    # Whatever is left is emitted as a constant (or its OVERRIDE entry).
    for name in keys:
        if name in done or name in skip:
            continue
        value = getattr(module, name)
        done.add(name)
        if name in OVERRIDE:
            lines.add = OVERRIDE[name]
        else:
            lines.add = visit_constant((name, value))

    leftovers = [name for name in keys if name not in done and name not in skip]
    if leftovers:
        raise RuntimeError(f'leftovers: {leftovers}')
    return lines


# Header written verbatim at the top of the generated stub file.
IMPORTS = """
from __future__ import annotations
import sys
from threading import Lock
from typing import (
    Any,
    Union,
    Optional,
    NoReturn,
    overload,
)
if sys.version_info >= (3, 8):
    from typing import (
        final,
        Final,
        Literal,
    )
else:
    from typing_extensions import (
        final,
        Final,
        Literal,
    )
if sys.version_info >= (3, 9):
    from collections.abc import (
        Callable,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
        Mapping,
    )
else:
    from typing import (
        Callable,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
        Mapping,
    )
if sys.version_info >= (3, 11):
    from typing import Self
else:
    from typing_extensions import Self
from os import PathLike

import numpy

IntType: numpy.dtype = ...
RealType: numpy.dtype = ...
ComplexType: numpy.dtype = ...
ScalarType: numpy.dtype = ...
"""

# Hand-written replacements.  A string value replaces the generated line of
# a module attribute; a dict value maps member names of the class with that
# qualified name to replacement lines.  The builtin ``dict``/``type``
# generics are used because the IMPORTS header does not import ``Dict``.
OVERRIDE = {
    'Error': {},
    '__pyx_capi__': '__pyx_capi__: Final[dict[str, Any]] = ...',
    '__type_registry__': '__type_registry__: Final[dict[int, type[Object]]] = ...',
}

# Trailer appended after the generated module body.
TYPING = """
"""


def visit_petsc4py_PETSc(done=None):
    """Return the complete stub text for ``petsc4py.PETSc`` as ``Lines``.

    *done* is forwarded to ``visit_module``; names it contains are left out
    of the generated module body.
    """
    from petsc4py import PETSc as module

    lines = Lines()
    lines.add = IMPORTS
    lines.add = ''
    lines.add = visit_module(module, done)
    lines.add = TYPING
    return lines


def generate(filename):
    """Generate the ``petsc4py.PETSc`` stub and write it to *filename*.

    Missing parent directories are created.  The stub is generated before
    the file is opened, so a failure leaves any existing file untouched.
    """
    lines = visit_petsc4py_PETSc()
    dirname = os.path.dirname(filename)
    # os.makedirs('') raises FileNotFoundError, so skip it for bare names.
    if dirname:
        os.makedirs(dirname, exist_ok=True)
    with open(filename, 'w', encoding='utf-8') as f:
        for line in lines:
            print(line, file=f)


OUTDIR = os.path.join('src', 'petsc4py')

if __name__ == '__main__':
    generate(os.path.join(OUTDIR, 'PETSc.pyi'))