xref: /petsc/src/binding/petsc4py/test/test_tao_py.py (revision ebead697dbf761eb322f829370bbe90b3bd93fa3)
1import unittest
2from petsc4py import PETSc
3from sys import getrefcount
4import gc
5
6# --------------------------------------------------------------------
class Objective:
    """Objective callback: f(x) = (x[0] - 1)**2 + (x[1] - 2)**2.

    The minimum value 0 is reached at the point (1, 2).
    """

    def __call__(self, tao, x):
        dx = x[0] - 1.0
        dy = x[1] - 2.0
        return dx**2 + dy**2
10
class Gradient:
    """Gradient callback matching Objective: g[i] = 2 * (x[i] - c[i]), c = (1, 2)."""

    def __call__(self, tao, x, g):
        for index, target in enumerate((1.0, 2.0)):
            g[index] = 2.0*(x[index] - target)
        # finalize the element assignments made above
        g.assemble()
16
class MyTao:
    """Python-implemented TAO solver that counts how often each hook runs.

    Each hook method increments a counter in ``self.log`` keyed by the
    hook's own name, so tests can verify which hooks were invoked and how
    many times.
    """

    def __init__(self):
        # hook name -> call count; a name is absent until its hook first runs
        self.log = {}

    def _log(self, method):
        """Increment the call counter stored under ``method``."""
        self.log[method] = self.log.get(method, 0) + 1

    def create(self, tao):
        """Record the call and install an empty placeholder vector."""
        self._log('create')
        self.testvec = PETSc.Vec()

    def destroy(self, tao):
        """Record the call and destroy the work vector."""
        self._log('destroy')
        self.testvec.destroy()

    def setFromOptions(self, tao):
        """Record the call; no options are read."""
        self._log('setFromOptions')

    def setUp(self, tao):
        """Record the call and allocate a work vector shaped like the solution."""
        self._log('setUp')
        solution = tao.getSolution()
        self.testvec = solution.duplicate()

    def solve(self, tao):
        """Record the call; nothing else is done, so no iterations happen."""
        self._log('solve')

    def step(self, tao, x, g, s):
        """Record the call and set the step ``s`` to minus the gradient at ``x``."""
        self._log('step')
        tao.computeGradient(x, g)
        # s = -g
        g.copy(s)
        s.scale(-1.0)

    def preStep(self, tao):
        """Record the call."""
        self._log('preStep')

    def postStep(self, tao):
        """Record the call."""
        self._log('postStep')

    def monitor(self, tao):
        """Record the call."""
        self._log('monitor')
57
class TestTaoPython(unittest.TestCase):
    """Drive a PETSc TAO solver whose implementation is the Python MyTao."""

    def setUp(self):
        self.tao = PETSc.TAO()
        self.tao.createPython(MyTao(), comm=PETSc.COMM_SELF)
        ctx = self.tao.getPythonContext()
        # references: the TAO, the local ``ctx``, and getrefcount's argument
        self.assertEqual(getrefcount(ctx), 3)
        self.assertEqual(ctx.log['create'], 1)
        self.nsolve = 0

    def tearDown(self):
        ctx = self.tao.getPythonContext()
        try:
            # references: the TAO, the local ``ctx``, and getrefcount's argument
            self.assertEqual(getrefcount(ctx), 3)
            self.assertNotIn('destroy', ctx.log)
        finally:
            # destroy the TAO even if the pre-destroy checks fail
            self.tao.destroy()
            self.tao = None
        # destroying the TAO must call MyTao.destroy() exactly once ...
        self.assertEqual(ctx.log['destroy'], 1)
        # ... and drop the TAO's reference to the context
        self.assertEqual(getrefcount(ctx), 2)

    def testGetType(self):
        ctx = self.tao.getPythonContext()
        pytype = "{0}.{1}".format(ctx.__module__, type(ctx).__name__)
        self.assertEqual(self.tao.getPythonType(), pytype)

    def testSolve(self):
        tao = self.tao
        ctx = tao.getPythonContext()
        x = PETSc.Vec().create(tao.getComm())
        x.setType('standard')
        x.setSizes(2)
        y1 = x.duplicate()
        y2 = x.duplicate()
        tao.setObjective(Objective())
        tao.setGradient(Gradient(), None)
        # The bound method ctx.monitor holds a reference to ctx, which
        # tearDown's refcount check counts; cancel the monitor in `finally`
        # so it is removed even when one of the assertions below fails.
        tao.setMonitor(ctx.monitor)
        try:
            tao.setFromOptions()
            tao.setMaximumIterations(3)
            tao.setSolution(x)

            # Call the solve method of MyTao
            x.set(0.5)
            tao.solve()
            n = tao.getIterationNumber()
            self.assertEqual(n, 0)

            # Call the default solve method and use step of MyTao
            ctx.solve = None
            x.set(0.5)
            tao.solve()
            n = tao.getIterationNumber()
            self.assertEqual(n, 3)
            x.copy(y1)

            # Call the default solve method with the default step method
            ctx.step = None
            x.set(0.5)
            tao.solve()
            n = tao.getIterationNumber()
            self.assertEqual(n, 3)
            x.copy(y2)

            # MyTao.step is expected to reproduce the default step exactly
            self.assertTrue(y1.equal(y2))
            # The MyTao.solve run performed no iterations, so only the two
            # default-solver runs (n iterations each) feed these counters;
            # MyTao.step was used by just the first of those runs.
            self.assertEqual(ctx.log['monitor'], 2*(n+1))
            self.assertEqual(ctx.log['preStep'], 2*n)
            self.assertEqual(ctx.log['postStep'], 2*n)
            self.assertEqual(ctx.log['solve'], 1)
            self.assertEqual(ctx.log['setUp'], 1)
            self.assertEqual(ctx.log['setFromOptions'], 1)
            self.assertEqual(ctx.log['step'], n)
        finally:
            tao.cancelMonitor()
128
129# --------------------------------------------------------------------
130
import numpy
# Drop the TAO test case when PETSc is built with complex scalars, so that
# unittest never collects it (presumably because TAO needs real scalars --
# TODO confirm).
if numpy.iscomplexobj(PETSc.ScalarType()):
    del TestTaoPython

# Allow running this file directly as a script.
if __name__ == '__main__':
    unittest.main()
137
138# --------------------------------------------------------------------
139