xref: /petsc/src/binding/petsc4py/test/test_snes_py.py (revision 9371c9d470a9602b6d10a8bf50c9b2280a79e45a)
1# --------------------------------------------------------------------
2
3from petsc4py import PETSc
4import unittest
5from sys import getrefcount
6
7# --------------------------------------------------------------------
8
class MySNES(object):
    """Callback context that counts how often each SNES hook is invoked.

    Every hook defined here forwards to ``_log``, which keeps a per-method
    tally in ``call_log`` and, when ``trace`` is true, also echoes the call
    to stdout.  An instance is installed on a SNES object through
    ``setPythonContext`` by the test case in this file.
    """

    def __init__(self):
        # method name -> number of times the hook has been called
        self.call_log = {}
        # when True, every logged call is also printed
        self.trace = False

    def _log(self, method, *args):
        """Record one call of ``method`` and print it if tracing is enabled.

        ``args`` are only used for the trace line; PETSc objects among them
        are shown by their type name, all other values are shown as-is.
        """
        self.call_log[method] = self.call_log.get(method, 0) + 1
        if self.trace:
            shown = tuple(
                type(arg).__name__ if isinstance(arg, PETSc.Object) else arg
                for arg in args
            )
            owner = self.__class__.__name__
            print('%-20s' % ('%s.%s%s' % (owner, method, shown)))

    def create(self, *args):
        """Log a ``create`` call."""
        self._log('create', *args)

    def destroy(self, *args):
        """Log a ``destroy`` call; when tracing, dump the per-method counts."""
        self._log('destroy', *args)
        if self.trace:
            for name in self.call_log:
                print('%-20s %2d' % (name, self.call_log[name]))

    def view(self, snes, viewer):
        """Log a ``view`` call."""
        self._log('view', snes, viewer)

    def setFromOptions(self, snes):
        """Refresh ``trace`` from the 'trace' option, then log the call.

        The current value of ``trace`` is passed as the default to
        ``getBool``.
        """
        options = PETSc.Options(snes)
        self.trace = options.getBool('trace', self.trace)
        self._log('setFromOptions', snes)

    def setUp(self, snes):
        """Log a ``setUp`` call."""
        self._log('setUp', snes)

    def reset(self, snes):
        """Log a ``reset`` call."""
        self._log('reset', snes)

    # Disabled hooks, kept as reference sketches:
    #
    #def preSolve(self, snes):
    #    self._log('preSolve', snes)
    #
    #def postSolve(self, snes):
    #    self._log('postSolve', snes)

    def preStep(self, snes):
        """Log a ``preStep`` call."""
        self._log('preStep', snes)

    def postStep(self, snes):
        """Log a ``postStep`` call."""
        self._log('postStep', snes)

    # Further disabled hooks, kept as reference sketches:
    #
    #def computeFunction(self, snes, x, F):
    #    self._log('computeFunction', snes, x, F)
    #    snes.computeFunction(x, F)
    #
    #def computeJacobian(self, snes, x, A, B):
    #    self._log('computeJacobian', snes, x, A, B)
    #    flag = snes.computeJacobian(x, A, B)
    #    return flag
    #
    #def linearSolve(self, snes, b, x):
    #    self._log('linearSolve', snes, b, x)
    #    snes.ksp.solve(b,x)
    #    ## return False # not succeed
    #    if snes.ksp.getConvergedReason() < 0:
    #        return False # not succeed
    #    return True # succeed
    #
    #def lineSearch(self, snes, x, y, F):
    #    self._log('lineSearch', snes, x, y, F)
    #    x.axpy(-1,y)
    #    snes.computeFunction(x, F)
    #    ## return False # not succeed
    #    return True # succeed
86
87
88from test_snes import BaseTestSNES
89
class TestSNESPython(BaseTestSNES, unittest.TestCase):
    # Specializes the tests inherited from BaseTestSNES (imported from
    # test_snes) to a SNES of type PYTHON whose context is a MySNES instance.

    SNES_TYPE = PETSc.SNES.Type.PYTHON

    def setUp(self):
        # The base setUp runs first (presumably it creates self.snes using
        # SNES_TYPE -- confirm in test_snes); a fresh MySNES context is then
        # attached for every test.
        super(TestSNESPython, self).setUp()
        self.snes.setPythonContext(MySNES())

    def testGetType(self):
        # The reported Python type must be "<module>.<class name>" of the
        # context object currently attached to the SNES.
        ctx = self.snes.getPythonContext()
        pytype = "{0}.{1}".format(ctx.__module__, type(ctx).__name__)
        # assertEqual performs the same == comparison as the previous
        # assertTrue(a == b) but reports both strings when they differ.
        self.assertEqual(self.snes.getPythonType(), pytype)
102
103# --------------------------------------------------------------------
104
# Run the test cases only when this file is executed as a script.
if "__main__" == __name__:
    unittest.main()
107
108# --------------------------------------------------------------------
109