xref: /petsc/src/binding/petsc4py/test/test_snes_py.py (revision 98d129c30f3ee9fdddc40fdbc5a989b7be64f888)
1# --------------------------------------------------------------------
2
3from petsc4py import PETSc
4import unittest
5from test_snes import BaseTestSNES
6
7# --------------------------------------------------------------------
8
9
class MySNES:
    """Python SNES context that tallies every callback it receives.

    Each hook forwards to ``_log``, which bumps a per-method counter in
    ``call_log`` and, when ``trace`` is enabled, echoes the call to stdout.
    """

    def __init__(self):
        # Callback name -> number of invocations seen so far.
        self.call_log = {}
        # Off by default; ``setFromOptions`` may switch it on.
        self.trace = False

    def _log(self, method, *args):
        """Count one call to ``method`` and print it when tracing is on.

        PETSc objects among ``args`` are printed by type name only; any
        other argument is printed as-is.
        """
        self.call_log[method] = self.call_log.get(method, 0) + 1
        if self.trace:
            owner = type(self).__name__
            shown = tuple(
                type(arg).__name__ if isinstance(arg, PETSc.Object) else arg
                for arg in args
            )
            print(f'{owner}.{method}{shown}')

    def create(self, *args):
        """Record a ``create`` callback."""
        self._log('create', *args)

    def destroy(self, *args):
        """Record a ``destroy`` callback and, when tracing, dump all counters."""
        self._log('destroy', *args)
        if self.trace:
            for name, count in self.call_log.items():
                print(f'{name} {count}')

    def view(self, snes, viewer):
        """Record a ``view`` callback."""
        self._log('view', snes, viewer)

    def setFromOptions(self, snes):
        """Refresh ``trace`` from the ``trace`` option, then record the call."""
        # The flag is read before logging so this very call is already traced
        # when the option enables it.
        options = PETSc.Options(snes)
        self.trace = options.getBool('trace', self.trace)
        self._log('setFromOptions', snes)

    def setUp(self, snes):
        """Record a ``setUp`` callback."""
        self._log('setUp', snes)

    def reset(self, snes):
        """Record a ``reset`` callback."""
        self._log('reset', snes)

    # Disabled callbacks, kept for reference:
    #
    # def preSolve(self, snes):
    #    self._log('preSolve', snes)
    #
    # def postSolve(self, snes):
    #    self._log('postSolve', snes)

    def preStep(self, snes):
        """Record a ``preStep`` callback."""
        self._log('preStep', snes)

    def postStep(self, snes):
        """Record a ``postStep`` callback."""
        self._log('postStep', snes)

    # Disabled callbacks, kept for reference:
    #
    # def computeFunction(self, snes, x, F):
    #    self._log('computeFunction', snes, x, F)
    #    snes.computeFunction(x, F)
    #
    # def computeJacobian(self, snes, x, A, B):
    #    self._log('computeJacobian', snes, x, A, B)
    #    flag = snes.computeJacobian(x, A, B)
    #    return flag
    #
    # def linearSolve(self, snes, b, x):
    #    self._log('linearSolve', snes, b, x)
    #    snes.ksp.solve(b,x)
    #    ## return False # not succeed
    #    if snes.ksp.getConvergedReason() < 0:
    #        return False # not succeed
    #    return True # succeed
    #
    # def lineSearch(self, snes, x, y, F):
    #    self._log('lineSearch', snes, x, y, F)
    #    x.axpy(-1,y)
    #    snes.computeFunction(x, F)
    #    ## return False # not succeed
    #    return True # succeed
88
89
class TestSNESPython(BaseTestSNES, unittest.TestCase):
    """SNES tests inherited from ``BaseTestSNES``, run with the Python SNES type.

    A ``MySNES`` instance is installed as the Python context before each test.
    """

    SNES_TYPE = PETSc.SNES.Type.PYTHON

    def setUp(self):
        # NOTE(review): the base setUp presumably creates ``self.snes`` using
        # ``SNES_TYPE`` -- confirm against BaseTestSNES.
        super().setUp()
        self.snes.setPythonContext(MySNES())

    def testGetType(self):
        # The reported Python type must be "<module>.<class name>" of the
        # installed context object.
        ctx = self.snes.getPythonContext()
        pytype = f'{ctx.__module__}.{type(ctx).__name__}'
        # assertEqual shows both strings on failure; assertTrue(a == b) would
        # only report "False is not true".
        self.assertEqual(self.snes.getPythonType(), pytype)
101
102
103# --------------------------------------------------------------------
104
# Run this module's test cases when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()
107
108# --------------------------------------------------------------------
109