# --------------------------------------------------------------------

from petsc4py import PETSc
import unittest
from sys import getrefcount

# --------------------------------------------------------------------

class MySNES(object):
    """Python context object installed on the SNES under test.

    Every hook records its invocation in ``call_log`` (hook name ->
    number of calls).  When ``trace`` is true, each call is additionally
    printed, and ``destroy`` prints the accumulated counts.
    """

    def __init__(self):
        # Tracing is off by default; ``setFromOptions`` may enable it.
        self.trace = False
        # Maps hook name -> number of times the hook has been called.
        self.call_log = {}

    def _log(self, method, *args):
        """Count one call of ``method`` and, if tracing, print it.

        ``args`` are the arguments the hook received; instances of
        ``PETSc.Object`` are printed as their type name only.
        """
        self.call_log[method] = self.call_log.get(method, 0) + 1
        if not self.trace:
            return
        clsname = self.__class__.__name__
        # Replace PETSc objects by their class name in the printed tuple.
        pargs = tuple(
            type(a).__name__ if isinstance(a, PETSc.Object) else a
            for a in args
        )
        print('%-20s' % ('%s.%s%s' % (clsname, method, pargs)))

    def create(self, *args):
        """Record a ``create`` call."""
        self._log('create', *args)

    def destroy(self, *args):
        """Record a ``destroy`` call; when tracing, print all call counts."""
        self._log('destroy', *args)
        if not self.trace:
            return
        for k, v in self.call_log.items():
            print('%-20s %2d' % (k, v))

    def view(self, snes, viewer):
        """Record a ``view`` call."""
        self._log('view', snes, viewer)

    def setFromOptions(self, snes):
        """Read the boolean ``trace`` option, then record the call.

        The option is looked up through ``PETSc.Options(snes)``; the
        current value of ``self.trace`` is used as the default.
        """
        OptDB = PETSc.Options(snes)
        self.trace = OptDB.getBool('trace', self.trace)
        self._log('setFromOptions', snes)

    def setUp(self, snes):
        """Record a ``setUp`` call."""
        self._log('setUp', snes)

    def reset(self, snes):
        """Record a ``reset`` call."""
        self._log('reset', snes)

    # The hooks below are intentionally left disabled (commented out).

    #def preSolve(self, snes):
    #    self._log('preSolve', snes)
    #
    #def postSolve(self, snes):
    #    self._log('postSolve', snes)

    def preStep(self, snes):
        """Record a ``preStep`` call."""
        self._log('preStep', snes)

    def postStep(self, snes):
        """Record a ``postStep`` call."""
        self._log('postStep', snes)

    # The hooks below are intentionally left disabled (commented out).

    #def computeFunction(self, snes, x, F):
    #    self._log('computeFunction', snes, x, F)
    #    snes.computeFunction(x, F)
    #
    #def computeJacobian(self, snes, x, A, B):
    #    self._log('computeJacobian', snes, x, A, B)
    #    flag = snes.computeJacobian(x, A, B)
    #    return flag
    #
    #def linearSolve(self, snes, b, x):
    #    self._log('linearSolve', snes, b, x)
    #    snes.ksp.solve(b,x)
    #    ## return False # not succeed
    #    if snes.ksp.getConvergedReason() < 0:
    #        return False # not succeed
    #    return True # succeed
    #
    #def lineSearch(self, snes, x, y, F):
    #    self._log('lineSearch', snes, x, y, F)
    #    x.axpy(-1,y)
    #    snes.computeFunction(x, F)
    #    ## return False # not succeed
    #    return True # succeed


from test_snes import BaseTestSNES

class TestSNESPython(BaseTestSNES, unittest.TestCase):
    """Run the ``BaseTestSNES`` tests with SNES type PYTHON and a ``MySNES`` context."""

    SNES_TYPE = PETSc.SNES.Type.PYTHON

    def setUp(self):
        """Run the base fixture, then install a fresh ``MySNES`` context."""
        super(TestSNESPython, self).setUp()
        self.snes.setPythonContext(MySNES())

    def testGetType(self):
        """``getPythonType`` must equal ``"<module>.<class>"`` of the context."""
        ctx = self.snes.getPythonContext()
        pytype = "{0}.{1}".format(ctx.__module__, type(ctx).__name__)
        # assertEqual reports both values on failure, unlike assertTrue(a == b).
        self.assertEqual(self.snes.getPythonType(), pytype)

# --------------------------------------------------------------------

if __name__ == '__main__':
    unittest.main()

# --------------------------------------------------------------------