xref: /petsc/src/binding/petsc4py/test/test_tao_py.py (revision b5f0bcd6e9e8ed97648738542f5163d94f7b1782)
1a82e8c82SStefano Zampiniimport unittest
2a82e8c82SStefano Zampinifrom petsc4py import PETSc
3a82e8c82SStefano Zampinifrom sys import getrefcount
46f336411SStefano Zampiniimport numpy
56f336411SStefano Zampini
6a82e8c82SStefano Zampini
7a82e8c82SStefano Zampini# --------------------------------------------------------------------
class Objective:
    """Quadratic objective whose minimum is at the point (1, 2)."""

    def __call__(self, tao, x):
        """Return (x[0] - 1)**2 + (x[1] - 2)**2; ``tao`` is not used."""
        dx = x[0] - 1.0
        dy = x[1] - 2.0
        return dx ** 2 + dy ** 2
11a82e8c82SStefano Zampini
126f336411SStefano Zampini
class Gradient:
    """Gradient of the quadratic objective, written into ``g`` in place."""

    def __call__(self, tao, x, g):
        """Store 2 * (x[i] - c[i]) in g[i] with c = (1, 2), then assemble ``g``.

        ``tao`` is not used.
        """
        for idx, centre in enumerate((1.0, 2.0)):
            g[idx] = 2.0 * (x[idx] - centre)
        g.assemble()
18a82e8c82SStefano Zampini
196f336411SStefano Zampini
class MyTao:
    """Python TAO context that counts how often each hook is invoked.

    ``log`` maps a hook name (e.g. ``'setUp'``) to the number of times that
    hook has been called; the tests compare these counts with expectations.
    """

    def __init__(self):
        self.log = {}

    def _log(self, method):
        # Bump the counter for `method`, starting from zero on first use.
        self.log[method] = self.log.get(method, 0) + 1

    def create(self, tao):
        self._log('create')
        self.testvec = PETSc.Vec()

    def destroy(self, tao):
        self._log('destroy')
        self.testvec.destroy()

    def setFromOptions(self, tao):
        self._log('setFromOptions')

    def setUp(self, tao):
        self._log('setUp')
        # Replace the empty placeholder from `create` with a work vector laid
        # out like the current solution.
        solution = tao.getSolution()
        self.testvec = solution.duplicate()

    def solve(self, tao):
        # Deliberately does nothing; the test then disables this hook to fall
        # back to the default solve.
        self._log('solve')

    def step(self, tao, x, g, s):
        """Set the step ``s`` to the negative gradient evaluated at ``x``."""
        self._log('step')
        tao.computeGradient(x, g)
        g.copy(s)
        s.scale(-1.0)

    def preStep(self, tao):
        self._log('preStep')

    def postStep(self, tao):
        self._log('postStep')

    def monitor(self, tao):
        self._log('monitor')
60a82e8c82SStefano Zampini
61a82e8c82SStefano Zampini
class TestTaoPython(unittest.TestCase):
    """Exercise a TAO of type 'python' driven by a ``MyTao`` context.

    Checks that the Python hooks are called the expected number of times and
    that the context's reference count and ``destroy`` hook behave as expected.
    """

    def setUp(self):
        self.tao = PETSc.TAO()
        self.tao.createPython(MyTao(), comm=PETSc.COMM_SELF)
        # Expect exactly two references: presumably the one held by the TAO
        # plus the temporary passed as getrefcount's argument.
        self.assertEqual(getrefcount(self._getCtx()), 2)
        self.assertEqual(self._getCtx().log['create'], 1)
        self.nsolve = 0

    def tearDown(self):
        self.assertEqual(getrefcount(self._getCtx()), 2)
        self.assertNotIn('destroy', self._getCtx().log)
        # Keep our own reference so the log can be inspected after the TAO
        # has been destroyed.
        ctx = self._getCtx()
        self.tao.destroy()
        self.tao = None
        PETSc.garbage_cleanup()
        self.assertEqual(ctx.log['destroy'], 1)

    def testGetType(self):
        ctx = self.tao.getPythonContext()
        pytype = f'{ctx.__module__}.{type(ctx).__name__}'
        self.assertEqual(self.tao.getPythonType(), pytype)

    def testSolve(self):
        tao = self.tao
        ctx = tao.getPythonContext()
        x = PETSc.Vec().create(tao.getComm())
        x.setType('standard')
        x.setSizes(2)
        y1 = x.duplicate()
        y2 = x.duplicate()
        tao.setObjective(Objective())
        tao.setGradient(Gradient(), None)
        tao.setMonitor(ctx.monitor)
        tao.setFromOptions()
        tao.setMaximumIterations(3)

        def _update(tao, it, cnt):
            # `cnt` is a 0-d numpy array, so `+=` mutates it in place and the
            # count stays visible to this test through `cnt_up`.
            cnt += 1

        cnt_up = numpy.array(0)
        tao.setUpdate(_update, (cnt_up,))
        tao.setSolution(x)

        # Call the solve method of MyTao
        x.set(0.5)
        tao.solve()
        n = tao.getIterationNumber()
        self.assertEqual(n, 0)

        # Call the default solve method and use step of MyTao
        ctx.solve = None
        x.set(0.5)
        tao.solve()
        n = tao.getIterationNumber()
        self.assertGreater(tao.getConvergedReason(), 0)
        self.assertIn(n, [2, 3])
        self.assertAlmostEqual(x[0], 1.0)
        self.assertAlmostEqual(x[1], 2.0)

        # Call the default solve method with the default step method
        ctx.step = None
        x.set(0.5)
        tao.solve()
        n = tao.getIterationNumber()
        self.assertGreater(tao.getConvergedReason(), 0)
        self.assertIn(n, [2, 3])
        self.assertAlmostEqual(x[0], 1.0)
        self.assertAlmostEqual(x[1], 2.0)

        # The counts below accumulate over the two default-solve runs, which
        # these assertions assume both took the same number of iterations n.
        self.assertTrue(y1.equal(y2))
        self.assertEqual(ctx.log['monitor'], 2 * (n + 1))
        self.assertEqual(ctx.log['preStep'], 2 * n)
        self.assertEqual(ctx.log['postStep'], 2 * n)
        self.assertEqual(ctx.log['solve'], 1)
        self.assertEqual(ctx.log['setUp'], 1)
        self.assertEqual(ctx.log['setFromOptions'], 1)
        self.assertEqual(ctx.log['step'], n)
        self.assertEqual(cnt_up, 2 * n)
        tao.cancelMonitor()

    def _getCtx(self):
        return self.tao.getPythonContext()
14439933f97SStefano Zampini
1456f336411SStefano Zampini
class MyGradientDescent:
    """Python TAO context implementing gradient descent with a unit line search.

    The whole optimisation loop lives in ``solve``. The per-step hooks used by
    TAO's builtin solve loop (``step``, ``preStep``, ``postStep``) raise if they
    are ever called, so a test fails if the builtin loop runs instead.
    """

    def __init__(self):
        # Line search created in `create`; None before that and after `destroy`.
        self._ls = None

    def create(self, tao):
        # UNIT line search with an initial step length of 0.2, evaluating the
        # objective/gradient through the routines registered on `tao`.
        self._ls = PETSc.TAOLineSearch().create(comm=PETSc.COMM_SELF)
        self._ls.useTAORoutine(tao)
        self._ls.setType(PETSc.TAOLineSearch.Type.UNIT)
        self._ls.setInitialStepLength(0.2)

    def destroy(self, tao):
        # Tolerate destroy without a prior successful create, and make a
        # repeated destroy a no-op instead of touching a stale handle.
        if self._ls is not None:
            self._ls.destroy()
            self._ls = None

    def setUp(self, tao):
        pass

    def solve(self, tao):
        """Run at most ``tao.getMaximumIterations()`` descent iterations.

        Stops early once ``tao.checkConverged()`` returns a positive reason.

        Raises:
            RuntimeError: if the line search returns a negative reason.
        """
        x = tao.getSolution()
        gradient = tao.getGradient()[0]
        search_direction = gradient.copy()
        for it in range(tao.getMaximumIterations()):
            tao.setIterationNumber(it)

            # search_direction = -gradient
            tao.computeGradient(x, gradient)
            gradient.copy(search_direction)
            search_direction.scale(-1)

            # x = x + .2 search_direction (the returned step length is unused)
            f, _, reason = self._ls.apply(x, gradient, search_direction)

            # Report progress; presumably this also records f and the residual
            # that checkConverged() inspects — confirm against TaoMonitor.
            tao.monitor(f=f, res=gradient.norm())

            if reason < 0:
                raise RuntimeError('LS failed.')

            if tao.checkConverged() > 0:
                break

    def step(self, tao, x, g, s):
        raise RuntimeError('Should only be called by builtin solve.')

    def preStep(self, tao):
        raise RuntimeError('Should only be called by builtin solve.')

    def postStep(self, tao):
        raise RuntimeError('Should only be called by builtin solve.')
193*23e8ad30Spaul.kuehner
194*23e8ad30Spaul.kuehner
class TestTaoPythonOptimiser(unittest.TestCase):
    """Solve the quadratic objective with the pure-Python gradient descent."""

    def setUp(self):
        self.tao = PETSc.TAO()
        self.tao.createPython(MyGradientDescent(), comm=PETSc.COMM_SELF)

    def tearDown(self):
        self.tao.destroy()
        self.tao = None

    def testSolve(self):
        tao = self.tao

        prefix = 'test_tao_python_optimiser_'
        opts = PETSc.Options(prefix)
        opts['tao_max_it'] = 100
        opts['tao_gatol'] = 1e-6
        # These entries live in PETSc's options database, which outlives this
        # test; remove them afterwards so they cannot affect later tests.
        self.addCleanup(opts.delValue, 'tao_max_it')
        self.addCleanup(opts.delValue, 'tao_gatol')

        tao.setOptionsPrefix(prefix)
        tao.setFromOptions()

        x = PETSc.Vec().createSeq(2, comm=tao.getComm())
        x.set(0.5)

        tao.setSolution(x)
        tao.setObjective(Objective())
        tao.setGradient(Gradient(), x.copy())

        tao.solve()

        # The prefixed options must have been picked up by setFromOptions().
        self.assertEqual(tao.getMaximumIterations(), 100)
        self.assertAlmostEqual(tao.getTolerances()[0], 1e-6)
        self.assertGreater(tao.getIterationNumber(), 0)
        self.assertGreater(tao.getConvergedReason(), 0)
        self.assertAlmostEqual(x[0], 1.0, places=5)
        self.assertAlmostEqual(x[1], 2.0, places=5)
        # Strictly positive: presumably verifies that an objective value was
        # actually recorded during solve() rather than left at 0 — confirm.
        self.assertGreater(tao.getObjectiveValue(), 0)
        self.assertAlmostEqual(tao.getObjectiveValue(), 0, places=5)
231*23e8ad30Spaul.kuehner
232*23e8ad30Spaul.kuehner
233a82e8c82SStefano Zampini# --------------------------------------------------------------------
234a82e8c82SStefano Zampini
if numpy.iscomplexobj(PETSc.ScalarType()):
    # Drop the TAO test cases entirely when PETSc uses complex scalars
    # (presumably because these TAO solvers are real-valued only — confirm).
    del TestTaoPython, TestTaoPythonOptimiser

if __name__ == '__main__':
    unittest.main()
241a82e8c82SStefano Zampini
242a82e8c82SStefano Zampini# --------------------------------------------------------------------
243