# xref: /petsc/src/binding/petsc4py/test/test_tao_py.py (revision df4cd43f92eaa320656440c40edb1046daee8f75)
1import unittest
2from petsc4py import PETSc
3from sys import getrefcount

# --------------------------------------------------------------------
class Objective:
    """Callable objective f(x) = (x0 - 1)^2 + (x1 - 2)^2, minimized at (1, 2)."""

    def __call__(self, tao, x):
        dx = x[0] - 1.0
        dy = x[1] - 2.0
        return dx**2 + dy**2

class Gradient:
    """Callable gradient of ``Objective``: g = 2 * (x - (1, 2)), then assembled."""

    def __call__(self, tao, x, g):
        # Component i of the minimizer is ``target``; fill g entry by entry.
        for i, target in enumerate((1.0, 2.0)):
            g[i] = 2.0 * (x[i] - target)
        g.assemble()

class MyTao:
    """Python-level TAO implementation that counts every callback it receives.

    Each hook bumps ``self.log[<hook name>]`` so tests can verify how often
    it was dispatched.
    """

    def __init__(self):
        # Callback name -> number of invocations.
        self.log = {}

    def _log(self, method):
        # Increment the counter for ``method``, starting from zero.
        self.log[method] = self.log.get(method, 0) + 1

    def create(self, tao):
        self._log('create')
        # Empty placeholder; setUp() replaces it with a real vector.
        self.testvec = PETSc.Vec()

    def destroy(self, tao):
        self._log('destroy')
        self.testvec.destroy()

    def setFromOptions(self, tao):
        self._log('setFromOptions')

    def setUp(self, tao):
        self._log('setUp')
        solution = tao.getSolution()
        self.testvec = solution.duplicate()

    def solve(self, tao):
        # No-op solver: only records that it was called.
        self._log('solve')

    def step(self, tao, x, g, s):
        """Fill ``s`` with the negative gradient at ``x`` (``g`` receives the gradient)."""
        self._log('step')
        tao.computeGradient(x, g)
        g.copy(s)
        s.scale(-1.0)

    def preStep(self, tao):
        self._log('preStep')

    def postStep(self, tao):
        self._log('postStep')

    def monitor(self, tao):
        self._log('monitor')

class TestTaoPython(unittest.TestCase):
    """Drive a TAO solver whose implementation is the Python class ``MyTao``."""

    def setUp(self):
        self.tao = PETSc.TAO()
        self.tao.createPython(MyTao(), comm=PETSc.COMM_SELF)
        ctx = self.tao.getPythonContext()
        # Expected references: the one held by the TAO (tearDown shows it goes
        # away after destroy), the local ``ctx``, and getrefcount's argument.
        self.assertEqual(getrefcount(ctx), 3)
        self.assertEqual(ctx.log['create'], 1)
        self.nsolve = 0

    def tearDown(self):
        ctx = self.tao.getPythonContext()
        # No extra references (e.g. a still-registered monitor) may remain.
        self.assertEqual(getrefcount(ctx), 3)
        self.assertNotIn('destroy', ctx.log)
        self.tao.destroy()
        self.tao = None
        PETSc.garbage_cleanup()
        # Destroying the TAO must call MyTao.destroy exactly once and release
        # the TAO's reference to the context.
        self.assertEqual(ctx.log['destroy'], 1)
        self.assertEqual(getrefcount(ctx), 2)

    def testGetType(self):
        """getPythonType() reports '<module>.<class>' of the context object."""
        ctx = self.tao.getPythonContext()
        pytype = "{0}.{1}".format(ctx.__module__, type(ctx).__name__)
        self.assertEqual(self.tao.getPythonType(), pytype)

    def testSolve(self):
        """Solve three times and check the results and callback counts.

        The runs are: MyTao.solve; the default solve with MyTao.step; the
        default solve with the default step.
        """
        tao = self.tao
        ctx = tao.getPythonContext()
        x = PETSc.Vec().create(tao.getComm())
        x.setType('standard')
        x.setSizes(2)
        y1 = x.duplicate()
        y2 = x.duplicate()
        tao.setObjective(Objective())
        tao.setGradient(Gradient(), None)
        tao.setMonitor(ctx.monitor)
        try:
            tao.setFromOptions()
            tao.setMaximumIterations(3)
            tao.setSolution(x)

            # Call the solve method of MyTao (it performs no iterations).
            x.set(0.5)
            tao.solve()
            self.assertEqual(tao.getIterationNumber(), 0)

            # Call the default solve method and use step of MyTao.
            ctx.solve = None
            x.set(0.5)
            tao.solve()
            n = tao.getIterationNumber()
            self.assertEqual(n, 3)
            x.copy(y1)

            # Call the default solve method with the default step method.
            ctx.step = None
            x.set(0.5)
            tao.solve()
            n = tao.getIterationNumber()
            self.assertEqual(n, 3)
            x.copy(y2)

            # Both step implementations must reach the same iterate.
            self.assertTrue(y1.equal(y2))
            # Expected counts: monitor/preStep/postStep come from the two
            # default-solve runs, step only from the first of them, and
            # solve/setUp/setFromOptions run once.
            self.assertEqual(ctx.log['monitor'], 2*(n+1))
            self.assertEqual(ctx.log['preStep'], 2*n)
            self.assertEqual(ctx.log['postStep'], 2*n)
            self.assertEqual(ctx.log['solve'], 1)
            self.assertEqual(ctx.log['setUp'], 1)
            self.assertEqual(ctx.log['setFromOptions'], 1)
            self.assertEqual(ctx.log['step'], n)
        finally:
            # The monitor is a bound method of ctx and holds a reference to it.
            # Drop it even when an assertion above fails. Otherwise tearDown's
            # refcount check also fails, and it never destroys the TAO.
            tao.cancelMonitor()

# --------------------------------------------------------------------
130
131import numpy
132if numpy.iscomplexobj(PETSc.ScalarType()):
133    del TestTaoPython
134
135if __name__ == '__main__':
136    unittest.main()

# --------------------------------------------------------------------
