# xref: /petsc/src/binding/petsc4py/test/test_tao_py.py (revision db2b9530cea7203cdddcdfcaf19eada576ef3459)
1import unittest
2from petsc4py import PETSc
3from sys import getrefcount
4import gc
5
6# --------------------------------------------------------------------
class Objective:
    """Objective callback: squared distance of ``x`` from the point (1, 2)."""

    def __call__(self, tao, x):
        """Return ``(x[0] - 1)**2 + (x[1] - 2)**2``.

        ``tao`` is accepted to satisfy the callback signature and is not used.
        """
        offset0 = x[0] - 1.0
        offset1 = x[1] - 2.0
        return offset0**2 + offset1**2
10
class Gradient:
    """Gradient callback for ``(x[0] - 1)**2 + (x[1] - 2)**2``."""

    def __call__(self, tao, x, g):
        """Store the two gradient entries in ``g`` and assemble it.

        ``tao`` is accepted to satisfy the callback signature and is not
        used; the result is written in place and nothing is returned.
        """
        # Entry i of the gradient is 2*(x[i] - center_i).
        for idx, center in enumerate((1.0, 2.0)):
            g[idx] = 2.0*(x[idx] - center)
        g.assemble()
16
class MyTao:
    """Python context for a TAO solver that counts calls to each hook.

    ``log`` maps a hook name (e.g. ``'solve'``) to the number of times
    that hook has been invoked on this instance.
    """

    def __init__(self):
        self.log = {}

    def _log(self, method):
        """Increment the call counter stored under ``method``."""
        self.log[method] = self.log.get(method, 0) + 1

    def create(self, tao):
        """Record the call and bind ``testvec`` to a fresh ``PETSc.Vec()``."""
        self._log('create')
        self.testvec = PETSc.Vec()

    def destroy(self, tao):
        """Record the call and destroy ``testvec``."""
        self._log('destroy')
        self.testvec.destroy()

    def setFromOptions(self, tao):
        """Record the call; no options are processed."""
        self._log('setFromOptions')

    def setUp(self, tao):
        """Record the call and rebind ``testvec`` to a duplicate of the solution."""
        self._log('setUp')
        solution = tao.getSolution()
        self.testvec = solution.duplicate()

    def solve(self, tao):
        """Record the call; no optimization work is done here."""
        self._log('solve')

    def step(self, tao, x, g, s):
        """Record the call and set ``s`` to the negated gradient at ``x``."""
        self._log('step')
        tao.computeGradient(x, g)
        # s <- g, then s <- -s
        g.copy(s)
        s.scale(-1.0)

    def preStep(self, tao):
        """Record the call."""
        self._log('preStep')

    def postStep(self, tao):
        """Record the call."""
        self._log('postStep')

    def monitor(self, tao):
        """Record the call."""
        self._log('monitor')
57
class TestTaoPython(unittest.TestCase):
    """Exercise a Python-type TAO solver backed by a ``MyTao`` context.

    Value comparisons use ``assertEqual`` / ``assertNotIn`` so that a
    failure reports the values involved rather than "False is not true".
    """

    def setUp(self):
        """Create the Python-type TAO and check its context saw ``create``."""
        self.tao = PETSc.TAO()
        self.tao.createPython(MyTao(), comm=PETSc.COMM_SELF)
        ctx = self.tao.getPythonContext()
        # NOTE(review): the expected reference counts here and in tearDown
        # presumably guard against leaked references to the context; the
        # exact values depend on the interpreter and binding internals.
        self.assertEqual(getrefcount(ctx),  3)
        self.assertEqual(ctx.log['create'], 1)
        self.nsolve = 0

    def tearDown(self):
        """Destroy the TAO and check the context saw ``destroy`` exactly once."""
        ctx = self.tao.getPythonContext()
        self.assertEqual(getrefcount(ctx), 4)
        # The destroy hook must not have run before the explicit destroy().
        self.assertNotIn('destroy', ctx.log)
        self.tao.destroy()
        self.tao = None
        self.assertEqual(ctx.log['destroy'], 1)
        self.assertEqual(getrefcount(ctx),   2)

    def testSolve(self):
        """Solve three times: custom solve, custom step, then all defaults."""
        tao = self.tao
        ctx = tao.getPythonContext()
        x = PETSc.Vec().create(tao.getComm())
        x.setType('standard')
        x.setSizes(2)
        y1 = x.duplicate()
        y2 = x.duplicate()
        tao.setObjective(Objective())
        tao.setGradient(Gradient(),None)
        tao.setMonitor(ctx.monitor)
        tao.setFromOptions()
        tao.setMaximumIterations(3)
        tao.setSolution(x)

        # Call the solve method of MyTAO
        x.set(0.5)
        tao.solve()
        n = tao.getIterationNumber()
        self.assertEqual(n, 0)

        # Call the default solve method and use step of MyTAO
        ctx.solve = None
        x.set(0.5)
        tao.solve()
        n = tao.getIterationNumber()
        self.assertEqual(n, 3)
        x.copy(y1)

        # Call the default solve method with the default step method
        ctx.step = None
        x.set(0.5)
        tao.solve()
        n = tao.getIterationNumber()
        self.assertEqual(n, 3)
        x.copy(y2)

        # Both default-solve runs must end at the same iterate.
        self.assertTrue(y1.equal(y2))
        # Hook counters accumulated over all three solves above; the
        # factor 2 corresponds to the two runs of the default solve.
        self.assertEqual(ctx.log['monitor'], 2*(n+1))
        self.assertEqual(ctx.log['preStep'], 2*n)
        self.assertEqual(ctx.log['postStep'], 2*n)
        self.assertEqual(ctx.log['solve'], 1)
        self.assertEqual(ctx.log['setUp'], 1)
        self.assertEqual(ctx.log['setFromOptions'], 1)
        # Only the second solve used MyTao.step.
        self.assertEqual(ctx.log['step'], n)
122
123# --------------------------------------------------------------------
124
import numpy

# Remove the test case from the module when an instance of the PETSc
# scalar type is complex, so it is neither collected nor run.
if numpy.iscomplexobj(PETSc.ScalarType()):
    del TestTaoPython

# Run the remaining tests when this file is executed as a script.
if __name__ == '__main__':
    unittest.main()
131
132# --------------------------------------------------------------------
133