xref: /petsc/src/binding/petsc4py/test/test_tao_py.py (revision fc47f7259de0629fb6058a3a6076517fd44721f2)
1import unittest
2from petsc4py import PETSc
3from sys import getrefcount
4import numpy
5
6
7# --------------------------------------------------------------------
class Objective:
    """Callable objective: ``(x[0] - 1)**2 + (x[1] - 2)**2``."""

    def __call__(self, tao, x):
        """Return the objective value at ``x``.

        ``tao`` is accepted to match the callback signature but is not used.
        Only the first two entries of ``x`` are read.
        """
        first_residual = x[0] - 1.0
        second_residual = x[1] - 2.0
        return first_residual ** 2 + second_residual ** 2
11
12
class Gradient:
    """Callable that fills ``g`` with ``2 * (x[i] - c[i])`` for ``c = (1, 2)``."""

    def __call__(self, tao, x, g):
        """Write the two gradient components into ``g`` and assemble it.

        ``tao`` is accepted to match the callback signature but is not used.
        Nothing is returned; the result is delivered through ``g``.
        """
        for index, center in enumerate((1.0, 2.0)):
            g[index] = 2.0 * (x[index] - center)
        g.assemble()
18
19
class MyTao:
    """Python context for a TAO solver that counts its callback invocations.

    ``log`` maps a callback name to the number of times that callback has
    run; a name is absent from the mapping until its first invocation.
    """

    def __init__(self):
        self.log = {}

    def _log(self, method):
        """Increment the invocation counter stored under ``method``."""
        self.log[method] = self.log.get(method, 0) + 1

    def create(self, tao):
        """Record the call and set up an initial ``testvec`` handle."""
        self._log('create')
        # destroy() calls self.testvec.destroy(), so the attribute has to
        # exist even when setUp() never replaces it.
        self.testvec = PETSc.Vec()

    def destroy(self, tao):
        """Record the call and destroy the vector held in ``testvec``."""
        self._log('destroy')
        self.testvec.destroy()

    def setFromOptions(self, tao):
        """Record the call; no options are processed."""
        self._log('setFromOptions')

    def setUp(self, tao):
        """Record the call and rebind ``testvec`` to a duplicate of the solution."""
        self._log('setUp')
        self.testvec = tao.getSolution().duplicate()

    def solve(self, tao):
        """Record the call; no iterations are performed here."""
        self._log('solve')

    def step(self, tao, x, g, s):
        """Record the call and produce the step ``s`` as the negated gradient.

        The gradient at ``x`` is computed into ``g``, copied into ``s`` and
        then scaled by ``-1.0``.
        """
        self._log('step')
        tao.computeGradient(x, g)
        g.copy(s)
        s.scale(-1.0)

    def preStep(self, tao):
        """Record the call."""
        self._log('preStep')

    def postStep(self, tao):
        """Record the call."""
        self._log('postStep')

    def monitor(self, tao):
        """Record the call."""
        self._log('monitor')
60
61
class TestTaoPython(unittest.TestCase):
    """Tests for a TAO solver whose implementation is a ``MyTao`` instance.

    The assertions use the specific ``assertEqual`` / ``assertIn`` /
    ``assertNotIn`` helpers so that a failure reports the offending values.
    """

    def setUp(self):
        """Create a Python-typed TAO on ``COMM_SELF`` and check its context."""
        self.tao = PETSc.TAO()
        self.tao.createPython(MyTao(), comm=PETSc.COMM_SELF)
        ctx = self.tao.getPythonContext()
        # getrefcount() counts its own temporary argument; the other two
        # references are presumably the TAO object and the local ``ctx``.
        self.assertEqual(getrefcount(ctx), 3)
        self.assertEqual(ctx.log['create'], 1)
        self.nsolve = 0

    def tearDown(self):
        """Destroy the TAO and check that the context is released exactly once."""
        ctx = self.tao.getPythonContext()
        self.assertEqual(getrefcount(ctx), 3)
        self.assertNotIn('destroy', ctx.log)
        self.tao.destroy()
        self.tao = None
        PETSc.garbage_cleanup()
        self.assertEqual(ctx.log['destroy'], 1)
        # One reference fewer than before the destroy.
        self.assertEqual(getrefcount(ctx), 2)

    def testGetType(self):
        """The reported Python type is ``<module>.<class name>`` of the context."""
        ctx = self.tao.getPythonContext()
        pytype = f'{ctx.__module__}.{type(ctx).__name__}'
        self.assertEqual(self.tao.getPythonType(), pytype)

    def testSolve(self):
        """Solve three times, progressively disabling ``MyTao`` overrides."""
        tao = self.tao
        ctx = tao.getPythonContext()
        x = PETSc.Vec().create(tao.getComm())
        x.setType('standard')
        x.setSizes(2)
        y1 = x.duplicate()
        y2 = x.duplicate()
        tao.setObjective(Objective())
        tao.setGradient(Gradient(), None)
        tao.setMonitor(ctx.monitor)
        tao.setFromOptions()
        tao.setMaximumIterations(3)

        def _update(tao, it, cnt):
            # ``cnt`` is a 0-d numpy array, so ``+=`` mutates the caller's
            # counter in place instead of rebinding a local name.
            cnt += 1

        cnt_up = numpy.array(0)
        tao.setUpdate(_update, (cnt_up,))
        tao.setSolution(x)

        # Call the solve method of MyTAO
        x.set(0.5)
        tao.solve()
        n = tao.getIterationNumber()
        self.assertEqual(n, 0)

        # Call the default solve method and use step of MyTAO
        ctx.solve = None
        x.set(0.5)
        tao.solve()
        n = tao.getIterationNumber()
        self.assertIn(n, [2, 3])
        x.copy(y1)

        # Call the default solve method with the default step method
        ctx.step = None
        x.set(0.5)
        tao.solve()
        n = tao.getIterationNumber()
        self.assertIn(n, [2, 3])
        x.copy(y2)

        # Both default-solve runs must reach the same point; the counters
        # below cover those two runs (the custom solve() only logged itself).
        self.assertTrue(y1.equal(y2))
        self.assertEqual(ctx.log['monitor'], 2 * (n + 1))
        self.assertEqual(ctx.log['preStep'], 2 * n)
        self.assertEqual(ctx.log['postStep'], 2 * n)
        self.assertEqual(ctx.log['solve'], 1)
        self.assertEqual(ctx.log['setUp'], 1)
        self.assertEqual(ctx.log['setFromOptions'], 1)
        self.assertEqual(ctx.log['step'], n)
        self.assertEqual(cnt_up, 2 * n)
        # Drop the monitor, which is a bound method of ``ctx``, before
        # tearDown() checks the context's reference count.
        tao.cancelMonitor()
138
139
140# --------------------------------------------------------------------
141
# Remove the test case from the module namespace when a default-constructed
# PETSc scalar is complex, so that unittest does not pick it up.
# NOTE(review): presumably the real-valued objective/gradient used here are
# not meant for complex-scalar builds -- confirm.
if numpy.iscomplexobj(PETSc.ScalarType()):
    del TestTaoPython

# Run the test suite when this file is executed as a script.
if __name__ == '__main__':
    unittest.main()
147
148# --------------------------------------------------------------------
149