xref: /petsc/src/binding/petsc4py/test/test_snes.py (revision 6dd63270497ad23dcf16ae500a87ff2b2a0b7474) !
1# --------------------------------------------------------------------
2
3from petsc4py import PETSc
4import unittest
5from sys import getrefcount
6import numpy as np
7
8# --------------------------------------------------------------------
9
10
class Function:
    """Residual callback for the two-unknown nonlinear test system.

    When called it fills ``f`` with::

        f[0] = x0*x0 + x0*x1 - 3
        f[1] = x0*x1 + x1*x1 - 6

    and then assembles ``f``. The ``snes`` argument is not used.
    """

    def __call__(self, snes, x, f):
        # Read both unknowns once; every residual entry is built from them.
        x0, x1 = x[0], x[1]
        f[0] = (x0 * x0 + x0 * x1 - 3.0).item()
        f[1] = (x0 * x1 + x1 * x1 - 6.0).item()
        f.assemble()
16
17
class Jacobian:
    """Jacobian callback for the two-unknown nonlinear test system.

    Writes the four entries of the 2x2 Jacobian into ``P``, assembles
    ``P`` and, when ``J`` compares unequal to ``P``, assembles ``J`` too.
    The ``snes`` argument is not used.
    """

    def __call__(self, snes, x, J, P):
        x0, x1 = x[0], x[1]
        # (row, col) -> entry, listed in the order they are inserted.
        entries = {
            (0, 0): 2.0 * x0 + x1,
            (0, 1): x0,
            (1, 0): x1,
            (1, 1): x0 + 2.0 * x1,
        }
        for (row, col), value in entries.items():
            P[row, col] = value.item()
        P.assemble()
        if J != P:
            J.assemble()
27
28
29
30# --------------------------------------------------------------------
31
32
class BaseTestSNES:
    """Test methods shared by the concrete SNES test cases in this module.

    This class does not derive from ``unittest.TestCase`` itself; the
    concrete test classes combine it with ``unittest.TestCase``.
    """

    # SNES type applied in setUp when truthy; concrete subclasses override it.
    SNES_TYPE = None
35
36    def setUp(self):
37        snes = PETSc.SNES()
38        snes.create(PETSc.COMM_SELF)
39        if self.SNES_TYPE:
40            snes.setType(self.SNES_TYPE)
41        self.snes = snes
42
    def tearDown(self):
        """Drop the solver reference, then run PETSc's garbage cleanup."""
        # Release our reference first so the cleanup call below can act on it.
        self.snes = None
        PETSc.garbage_cleanup()
46
47    def testGetSetType(self):
48        self.assertEqual(self.snes.getType(), self.SNES_TYPE)
49        self.snes.setType(self.SNES_TYPE)
50        self.assertEqual(self.snes.getType(), self.SNES_TYPE)
51
52    def testTols(self):
53        tols = self.snes.getTolerances()
54        self.snes.setTolerances(*tols)
55        tnames = ('rtol', 'atol', 'stol', 'max_it')
56        tolvals = [getattr(self.snes, t) for t in tnames]
57        self.assertEqual(tuple(tols), tuple(tolvals))
58        dtol = self.snes.getDivergenceTolerance()
59        self.assertTrue(dtol > 0)
60        self.snes.setDivergenceTolerance(PETSc.UNLIMITED)
61        dtol = self.snes.getDivergenceTolerance()
62        self.assertEqual(dtol, PETSc.UNLIMITED)
63        self.snes.setDivergenceTolerance(PETSc.CURRENT)
64        self.assertEqual(dtol, PETSc.UNLIMITED)
65
66    def testProperties(self):
67        snes = self.snes
68        #
69        snes.appctx = (1, 2, 3)
70        self.assertEqual(snes.appctx, (1, 2, 3))
71        snes.appctx = None
72        self.assertEqual(snes.appctx, None)
73        #
74        snes.its = 1
75        self.assertEqual(snes.its, 1)
76        snes.its = 0
77        self.assertEqual(snes.its, 0)
78        #
79        snes.norm = 1
80        self.assertEqual(snes.norm, 1)
81        snes.norm = 0
82        self.assertEqual(snes.norm, 0)
83        #
84        rh, ih = snes.history
85        self.assertTrue(len(rh) == 0)
86        self.assertTrue(len(ih) == 0)
87        #
88        reason = PETSc.SNES.ConvergedReason.CONVERGED_ITS
89        snes.reason = reason
90        self.assertEqual(snes.reason, reason)
91        self.assertTrue(snes.is_converged)
92        self.assertFalse(snes.is_diverged)
93        self.assertFalse(snes.is_iterating)
94        reason = PETSc.SNES.ConvergedReason.DIVERGED_MAX_IT
95        snes.reason = reason
96        self.assertEqual(snes.reason, reason)
97        self.assertFalse(snes.is_converged)
98        self.assertTrue(snes.is_diverged)
99        self.assertFalse(snes.is_iterating)
100        reason = PETSc.SNES.ConvergedReason.CONVERGED_ITERATING
101        snes.reason = reason
102        self.assertEqual(snes.reason, reason)
103        self.assertFalse(snes.is_converged)
104        self.assertFalse(snes.is_diverged)
105        self.assertTrue(snes.is_iterating)
106        #
107        self.assertFalse(snes.use_ew)
108        self.assertFalse(snes.use_mf)
109        self.assertFalse(snes.use_fd)
110        ouse = snes.use_ksp
111        self.assertEqual(ouse, snes.getUseKSP())
112        snes.use_ksp = not ouse
113        self.assertEqual(not ouse, snes.getUseKSP())
114        snes.setUseKSP(ouse)
115        self.assertEqual(ouse, snes.use_ksp)
116
    def testGetSetFunc(self):
        """Check setFunction/getFunction round-trip and callback refcounts.

        The reference-count assertions depend on exactly which local names
        are alive, so statements here should not be reordered or merged.
        """
        # Before anything is set: falsy residual vector and no callback.
        r, func = self.snes.getFunction()
        self.assertFalse(r)
        self.assertTrue(func is None)
        r = PETSc.Vec().createSeq(2)
        func = Function()
        refcnt = getrefcount(func)
        # Setting the same callback twice must add only one reference.
        self.snes.setFunction(func, r)
        self.snes.setFunction(func, r)
        self.assertEqual(getrefcount(func), refcnt + 1)
        # The getter returns the vector and a sequence whose first item is
        # the callback; fetching it must not add references to the callback.
        r2, func2 = self.snes.getFunction()
        self.assertEqual(r, r2)
        self.assertEqual(func, func2[0])
        self.assertEqual(getrefcount(func), refcnt + 1)
        # A second fetch behaves the same way.
        r3, func3 = self.snes.getFunction()
        self.assertEqual(r, r3)
        self.assertEqual(func, func3[0])
        self.assertEqual(getrefcount(func), refcnt + 1)
135
136    def testCompFunc(self):
137        r = PETSc.Vec().createSeq(2)
138        func = Function()
139        self.snes.setFunction(func, r)
140        x, y = r.duplicate(), r.duplicate()
141        x[0], x[1] = [1, 2]
142        self.snes.computeFunction(x, y)
143        self.assertAlmostEqual(abs(y[0]), 0.0, places=5)
144        self.assertAlmostEqual(abs(y[1]), 0.0, places=5)
145
    def testGetSetJac(self):
        """Check setJacobian/getJacobian round-trip and callback refcounts.

        The reference-count assertions depend on exactly which local names
        are alive, so statements here should not be reordered or merged.
        """
        # Before anything is set: falsy matrices and no callback.
        A, P, jac = self.snes.getJacobian()
        self.assertFalse(A)
        self.assertFalse(P)
        self.assertTrue(jac is None)
        J = PETSc.Mat().create(PETSc.COMM_SELF)
        J.setSizes([2, 2])
        J.setType(PETSc.Mat.Type.SEQAIJ)
        J.setUp()
        jac = Jacobian()
        refcnt = getrefcount(jac)
        # Setting the same callback twice must add only one reference.
        self.snes.setJacobian(jac, J)
        self.snes.setJacobian(jac, J)
        self.assertEqual(getrefcount(jac), refcnt + 1)
        # Only J was passed, so both returned matrices are expected to be J.
        J2, P2, jac2 = self.snes.getJacobian()
        self.assertEqual(J, J2)
        self.assertEqual(J2, P2)
        self.assertEqual(jac, jac2[0])
        self.assertEqual(getrefcount(jac), refcnt + 1)
        # A second fetch behaves the same way.
        J3, P3, jac3 = self.snes.getJacobian()
        self.assertEqual(J, J3)
        self.assertEqual(J3, P3)
        self.assertEqual(jac, jac3[0])
        self.assertEqual(getrefcount(jac), refcnt + 1)
170
171    def testCompJac(self):
172        J = PETSc.Mat().create(PETSc.COMM_SELF)
173        J.setSizes([2, 2])
174        J.setType(PETSc.Mat.Type.SEQAIJ)
175        J.setUp()
176        jac = Jacobian()
177        self.snes.setJacobian(jac, J)
178        x = PETSc.Vec().createSeq(2)
179        x[0], x[1] = [1, 2]
180        self.snes.getKSP().getPC()
181        self.snes.computeJacobian(x, J)
182
    def testGetSetUpd(self):
        """Check setUpdate/getUpdate and the references held on the callback.

        The reference-count assertions depend on exactly which local names
        are alive, so statements here should not be reordered or merged.
        """
        self.assertTrue(self.snes.getUpdate() is None)
        upd = lambda snes, it: None
        refcnt = getrefcount(upd)
        # Installing the callback adds one reference; re-installing adds none.
        self.snes.setUpdate(upd)
        self.assertEqual(getrefcount(upd), refcnt + 1)
        self.snes.setUpdate(upd)
        self.assertEqual(getrefcount(upd), refcnt + 1)
        # Clearing the callback releases that reference.
        self.snes.setUpdate(None)
        self.assertTrue(self.snes.getUpdate() is None)
        self.assertEqual(getrefcount(upd), refcnt)
        self.snes.setUpdate(upd)
        self.assertEqual(getrefcount(upd), refcnt + 1)
        # Replacing the callback releases the old one and holds the new one.
        upd2 = lambda snes, it: None
        refcnt2 = getrefcount(upd2)
        self.snes.setUpdate(upd2)
        self.assertEqual(getrefcount(upd), refcnt)
        self.assertEqual(getrefcount(upd2), refcnt2 + 1)
        # The first item of getUpdate() is the callback itself; binding it to
        # ``tmp`` accounts for the extra reference until ``del tmp``.
        tmp = self.snes.getUpdate()[0]
        self.assertTrue(tmp is upd2)
        self.assertEqual(getrefcount(upd2), refcnt2 + 2)
        del tmp
        self.snes.setUpdate(None)
        self.assertTrue(self.snes.getUpdate() is None)
        self.assertEqual(getrefcount(upd2), refcnt2)
208
209    def testGetKSP(self):
210        ksp = self.snes.getKSP()
211        self.assertEqual(ksp.getRefCount(), 2)
212
213    def testSolve(self):
214        J = PETSc.Mat().create(PETSc.COMM_SELF)
215        J.setSizes([2, 2])
216        J.setType(PETSc.Mat.Type.SEQAIJ)
217        J.setUp()
218        r = PETSc.Vec().createSeq(2)
219        x = PETSc.Vec().createSeq(2)
220        b = PETSc.Vec().createSeq(2)
221        self.snes.setFunction(Function(), r)
222        self.snes.setJacobian(Jacobian(), J)
223
224        def _update(snes, it, cnt):
225             cnt += 1
226        cnt_up = np.array(0)
227        self.snes.setUpdate(_update, (cnt_up,) )
228
229        x.setArray([2, 3])
230        b.set(0)
231        self.snes.setConvergenceHistory()
232        self.snes.setFromOptions()
233        self.snes.solve(b, x)
234        self.snes.setUpdate(None)
235        rh, ih = self.snes.getConvergenceHistory()
236        self.snes.setConvergenceHistory(0, reset=True)
237        rh, ih = self.snes.getConvergenceHistory()
238        self.assertEqual(len(rh), 0)
239        self.assertEqual(len(ih), 0)
240        self.assertAlmostEqual(abs(x[0]), 1.0, places=5)
241        self.assertAlmostEqual(abs(x[1]), 2.0, places=5)
242        self.assertEqual(self.snes.getIterationNumber(), cnt_up)
243        # XXX this test should not be here !
244        reason = self.snes.callConvergenceTest(1, 0, 0, 0)
245        self.assertTrue(reason > 0)
246
247        # test interface
248        x = self.snes.getSolution()
249        x.setArray([2, 3])
250        self.snes.solve()
251        self.assertAlmostEqual(abs(x[0]), 1.0, places=5)
252        self.assertAlmostEqual(abs(x[1]), 2.0, places=5)
253
254    def testResetAndSolve(self):
255        self.snes.reset()
256        self.testSolve()
257        self.snes.reset()
258        self.testSolve()
259        self.snes.reset()
260
    def testSetMonitor(self):
        """Check that a monitor callback is held, invoked, and released.

        ``monitor`` reads ``reshist`` through its closure, so rebinding
        ``reshist`` below changes the dict that later calls write to.
        """
        reshist = {}

        def monitor(snes, its, fgnorm):
            reshist[its] = fgnorm

        refcnt = getrefcount(monitor)
        # Installing the monitor adds one reference to it.
        self.snes.setMonitor(monitor)
        self.assertEqual(getrefcount(monitor), refcnt + 1)
        self.testSolve()
        self.assertTrue(len(reshist) > 0)
        # Start from an empty record; after cancelling, a solve must leave it
        # empty and the reference taken above must be gone.
        reshist = {}
        self.snes.monitorCancel()
        self.assertEqual(getrefcount(monitor), refcnt)
        self.testSolve()
        self.assertTrue(len(reshist) == 0)
        # Calling snes.monitor(its, norm) directly forwards to the callback.
        self.snes.setMonitor(monitor)
        self.snes.monitor(1, 7)
        self.assertTrue(reshist[1] == 7)
        ## Monitor = PETSc.SNES.Monitor
        ## self.snes.setMonitor(Monitor())
        ## self.snes.setMonitor(Monitor.DEFAULT)
        ## self.snes.setMonitor(Monitor.SOLUTION)
        ## self.snes.setMonitor(Monitor.RESIDUAL)
        ## self.snes.setMonitor(Monitor.SOLUTION_UPDATE)
286
287    def testSetGetStepFails(self):
288        its = self.snes.getIterationNumber()
289        self.assertEqual(its, 0)
290        fails = self.snes.getNonlinearStepFailures()
291        self.assertEqual(fails, 0)
292        fails = self.snes.getMaxNonlinearStepFailures()
293        self.assertEqual(fails, 1)
294        self.snes.setMaxNonlinearStepFailures(5)
295        fails = self.snes.getMaxNonlinearStepFailures()
296        self.assertEqual(fails, 5)
297        self.snes.setMaxNonlinearStepFailures(1)
298        fails = self.snes.getMaxNonlinearStepFailures()
299        self.assertEqual(fails, 1)
300
301    def testSetGetLinFails(self):
302        its = self.snes.getLinearSolveIterations()
303        self.assertEqual(its, 0)
304        fails = self.snes.getLinearSolveFailures()
305        self.assertEqual(fails, 0)
306        fails = self.snes.getMaxLinearSolveFailures()
307        self.assertEqual(fails, 1)
308        self.snes.setMaxLinearSolveFailures(5)
309        fails = self.snes.getMaxLinearSolveFailures()
310        self.assertEqual(fails, 5)
311        self.snes.setMaxLinearSolveFailures(1)
312        fails = self.snes.getMaxLinearSolveFailures()
313        self.assertEqual(fails, 1)
314
315    def testEW(self):
316        self.snes.setUseEW(False)
317        self.assertFalse(self.snes.getUseEW())
318        self.snes.setUseEW(True)
319        self.assertTrue(self.snes.getUseEW())
320        params = self.snes.getParamsEW()
321        params['version'] = 1
322        self.snes.setParamsEW(**params)
323        params = self.snes.getParamsEW()
324        self.assertEqual(params['version'], 1)
325        params['version'] = PETSc.CURRENT
326        self.snes.setParamsEW(**params)
327        params = self.snes.getParamsEW()
328        self.assertEqual(params['version'], 1)
329
330    def testMF(self):
331        # self.snes.setOptionsPrefix('MF-')
332        # opts = PETSc.Options(self.snes)
333        # opts['mat_mffd_type'] = 'ds'
334        # opts['snes_monitor']  = 'stdout'
335        # opts['ksp_monitor']   = 'stdout'
336        # opts['snes_view']     = 'stdout'
337        J = PETSc.Mat().create(PETSc.COMM_SELF)
338        J.setSizes([2, 2])
339        J.setType(PETSc.Mat.Type.SEQAIJ)
340        J.setUp()
341        r = PETSc.Vec().createSeq(2)
342        x = PETSc.Vec().createSeq(2)
343        b = PETSc.Vec().createSeq(2)
344        fun = Function()
345        jac = Jacobian()
346        self.snes.setFunction(fun, r)
347        self.snes.setJacobian(jac, J)
348        self.assertFalse(self.snes.getUseMF())
349        self.snes.setUseMF(False)
350        self.assertFalse(self.snes.getUseMF())
351        self.snes.setUseMF(True)
352        self.assertTrue(self.snes.getUseMF())
353        self.snes.setFromOptions()
354        if self.snes.getType() != PETSc.SNES.Type.NEWTONTR:
355            x.setArray([2, 3])
356            b.set(0)
357            self.snes.solve(b, x)
358            self.assertAlmostEqual(abs(x[0]), 1.0, places=5)
359            self.assertAlmostEqual(abs(x[1]), 2.0, places=5)
360
361    def testFDColor(self):
362        J = PETSc.Mat().create(PETSc.COMM_SELF)
363        J.setSizes([2, 2])
364        J.setType(PETSc.Mat.Type.SEQAIJ)
365        J.setUp()
366        r = PETSc.Vec().createSeq(2)
367        x = PETSc.Vec().createSeq(2)
368        b = PETSc.Vec().createSeq(2)
369        fun = Function()
370        jac = Jacobian()
371        self.snes.setFunction(fun, r)
372        self.snes.setJacobian(jac, J)
373        self.assertFalse(self.snes.getUseFD())
374        jac(self.snes, x, J, J)
375        self.snes.setUseFD(False)
376        self.assertFalse(self.snes.getUseFD())
377        self.snes.setUseFD(True)
378        self.assertTrue(self.snes.getUseFD())
379        self.snes.setFromOptions()
380        x.setArray([2, 3])
381        b.set(0)
382        self.snes.solve(b, x)
383        self.assertAlmostEqual(abs(x[0]), 1.0, places=4)
384        self.assertAlmostEqual(abs(x[1]), 2.0, places=4)
385
386    def testNPC(self):
387        self.snes.appctx = (1, 2, 3)
388        npc = self.snes.getNPC()
389        self.assertEqual(npc.appctx, (1, 2, 3))
390
391    def testTRAPI(self):
392        newreg = (1,2,3)
393        newup = (1,2,3,4,5)
394        if self.snes.getType() == PETSc.SNES.Type.NEWTONTR:
395            defreg = self.snes.getTRTolerances()
396            defup = self.snes.getTRUpdateParameters()
397        self.snes.setTRTolerances(*newreg)
398        self.snes.setTRUpdateParameters(*newup)
399        if self.snes.getType() == PETSc.SNES.Type.NEWTONTR:
400            self.assertEqual(newreg, self.snes.getTRTolerances())
401            self.assertEqual(newup, self.snes.getTRUpdateParameters())
402        self.snes.setTRTolerances()
403        self.snes.setTRUpdateParameters()
404        if self.snes.getType() == PETSc.SNES.Type.NEWTONTR:
405            self.assertEqual(newreg, self.snes.getTRTolerances())
406            self.assertEqual(newup, self.snes.getTRUpdateParameters())
407        self.snes.setTRTolerances(*(PETSc.DETERMINE,)*3)
408        self.snes.setTRUpdateParameters(*(PETSc.DETERMINE,)*5)
409        if self.snes.getType() == PETSc.SNES.Type.NEWTONTR:
410            self.assertEqual(defreg, self.snes.getTRTolerances())
411            self.assertEqual(defup, self.snes.getTRUpdateParameters())
412
413# --------------------------------------------------------------------
414
415
class TestSNESLS(BaseTestSNES, unittest.TestCase):
    """Run the BaseTestSNES test methods with the NEWTONLS solver type."""

    SNES_TYPE = PETSc.SNES.Type.NEWTONLS
418
419
class TestSNESTR(BaseTestSNES, unittest.TestCase):
    """Run the BaseTestSNES test methods with the NEWTONTR solver type."""

    SNES_TYPE = PETSc.SNES.Type.NEWTONTR
422
423
424# --------------------------------------------------------------------
425
# Run the test cases in this module when it is executed as a script.
if __name__ == '__main__':
    unittest.main()
428