# --------------------------------------------------------------------

from petsc4py import PETSc
import unittest
from test_snes import BaseTestSNES

# --------------------------------------------------------------------


class MySNES:
    """Python context object for a ``PETSc.SNES`` of type ``python``.

    Each callback below records its invocation in ``call_log`` (callback
    name -> call count). When ``trace`` is enabled, each call is also
    printed, and ``destroy`` prints the accumulated counts.
    """

    def __init__(self):
        self.trace = False  # when True, print every logged callback
        self.call_log = {}  # callback name -> number of times it was called

    def _log(self, method, *args):
        """Count one call to *method*; print it when tracing is on.

        In the printed argument tuple, PETSc objects are replaced by
        their type name instead of being shown in full.
        """
        self.call_log[method] = self.call_log.get(method, 0) + 1
        if self.trace:
            shown = tuple(
                type(arg).__name__ if isinstance(arg, PETSc.Object) else arg
                for arg in args
            )
            print(f'{type(self).__name__}.{method}{shown}')

    def create(self, *args):
        self._log('create', *args)

    def destroy(self, *args):
        """Log the call; when tracing, also dump the per-callback counts."""
        self._log('destroy', *args)
        if self.trace:
            for name, count in self.call_log.items():
                print(f'{name} {count}')

    def view(self, snes, viewer):
        self._log('view', snes, viewer)

    def setFromOptions(self, snes):
        """Update ``trace`` from the options, then log the call.

        The boolean ``trace`` option is read through
        ``PETSc.Options(snes)``, with the current value as the default.
        It is read before logging so that enabling tracing also prints
        this very call.
        """
        opts = PETSc.Options(snes)
        self.trace = opts.getBool('trace', self.trace)
        self._log('setFromOptions', snes)

    def setUp(self, snes):
        self._log('setUp', snes)

    def reset(self, snes):
        self._log('reset', snes)

    # Disabled (commented-out) hooks, kept for reference.
    #
    # def preSolve(self, snes):
    #     self._log('preSolve', snes)
    #
    # def postSolve(self, snes):
    #     self._log('postSolve', snes)

    def preStep(self, snes):
        self._log('preStep', snes)

    def postStep(self, snes):
        self._log('postStep', snes)

    # More disabled (commented-out) hooks, kept for reference.
    #
    # def computeFunction(self, snes, x, F):
    #     self._log('computeFunction', snes, x, F)
    #     snes.computeFunction(x, F)
    #
    # def computeJacobian(self, snes, x, A, B):
    #     self._log('computeJacobian', snes, x, A, B)
    #     flag = snes.computeJacobian(x, A, B)
    #     return flag
    #
    # def linearSolve(self, snes, b, x):
    #     self._log('linearSolve', snes, b, x)
    #     snes.ksp.solve(b,x)
    #     ## return False # not succeed
    #     if snes.ksp.getConvergedReason() < 0:
    #         return False # not succeed
    #     return True # succeed
    #
    # def lineSearch(self, snes, x, y, F):
    #     self._log('lineSearch', snes, x, y, F)
    #     x.axpy(-1,y)
    #     snes.computeFunction(x, F)
    #     ## return False # not succeed
    #     return True # succeed


class TestSNESPython(BaseTestSNES, unittest.TestCase):
    """Run the ``BaseTestSNES`` checks against a ``python``-type SNES.

    Each test gets a new ``MySNES`` instance as the SNES Python context.
    """

    SNES_TYPE = PETSc.SNES.Type.PYTHON

    def setUp(self):
        super().setUp()
        # A new context per test, so call counts never leak between tests.
        self.snes.setPythonContext(MySNES())

    def testGetType(self):
        """Check getPythonType() against the context's '<module>.<class>'."""
        ctx = self.snes.getPythonContext()
        pytype = f'{ctx.__module__}.{type(ctx).__name__}'
        # assertEqual (not assertTrue(a == b)) reports both values on failure.
        self.assertEqual(self.snes.getPythonType(), pytype)


# --------------------------------------------------------------------

if __name__ == '__main__':
    unittest.main()

# --------------------------------------------------------------------