xref: /petsc/src/binding/petsc4py/test/test_snes.py (revision 8aa39e1bf17a5ea28fa0458095c26b0a3b4f2478)
1# --------------------------------------------------------------------
2
3from petsc4py import PETSc
4import unittest
5from sys import getrefcount
6import numpy as np
7
8# --------------------------------------------------------------------
9
10
class Function:
    """Residual callback for the 2x2 nonlinear test system.

    Evaluates::

        F0(x) = x0*x0 + x0*x1 - 3
        F1(x) = x0*x1 + x1*x1 - 6

    whose roots are (1, 2) and (-1, -2), then assembles ``f``.
    """

    def __call__(self, snes, x, f):
        x0, x1 = x[0], x[1]
        cross = x0 * x1  # term shared by both residual components
        f[0] = (x0 * x0 + cross - 3.0).item()
        f[1] = (cross + x1 * x1 - 6.0).item()
        f.assemble()
16
17
class Jacobian:
    """Jacobian callback matching :class:`Function`.

    Writes the analytic Jacobian::

        [[2*x0 + x1, x0       ],
         [x1,        x0 + 2*x1]]

    into ``P`` and assembles it. ``J`` is additionally assembled when it
    does not compare equal to ``P``.
    """

    def __call__(self, snes, x, J, P):
        x0, x1 = x[0], x[1]
        P[0, 0] = (2.0 * x0 + x1).item()
        P[0, 1] = x0.item()
        P[1, 0] = x1.item()
        P[1, 1] = (x0 + 2.0 * x1).item()
        P.assemble()
        same_matrix = not (J != P)
        if not same_matrix:
            J.assemble()
27
28
class FunctionAL:
    """Load-scaled residual callback used with the NEWTONAL solver.

    Same as :class:`Function`, but the constant terms 3 and 6 are scaled by
    the current load parameter obtained from
    ``snes.getNewtonALLoadParameter()``.
    """

    def __call__(self, snes, x, f):
        load = snes.getNewtonALLoadParameter()
        x0, x1 = x[0], x[1]
        cross = x0 * x1
        f[0] = (x0 * x0 + cross - 3.0 * load).item()
        f[1] = (cross + x1 * x1 - 6.0 * load).item()
        f.assemble()
35
36
class FunctionALLoad:
    """Load-vector callback for NEWTONAL.

    Fills ``f`` with the constant vector [3.0, 6.0], i.e. the coefficients
    multiplying the load parameter in :class:`FunctionAL`, then assembles it.
    """

    def __call__(self, snes, x, f):
        for index, value in enumerate((3.0, 6.0)):
            f[index] = value
        f.assemble()
42
43
44# --------------------------------------------------------------------
45
46
class BaseTestSNES:
    """Shared SNES test cases, parameterised by the ``SNES_TYPE`` attribute.

    Concrete subclasses mix this class with :class:`unittest.TestCase` and
    set ``SNES_TYPE``; ``setUp`` creates a fresh SNES on ``PETSc.COMM_SELF``
    for every test and applies that type.
    """

    # SNES type applied in setUp; a falsy value leaves the type unset.
    SNES_TYPE = None

    def setUp(self):
        """Create a SNES on COMM_SELF and apply ``SNES_TYPE`` when set."""
        snes = PETSc.SNES()
        snes.create(PETSc.COMM_SELF)
        if self.SNES_TYPE:
            snes.setType(self.SNES_TYPE)
        self.snes = snes

    def tearDown(self):
        """Drop the SNES reference and run PETSc's garbage cleanup."""
        self.snes = None
        PETSc.garbage_cleanup()

    def testGetSetType(self):
        """``getType`` reports ``SNES_TYPE`` before and after ``setType``."""
        self.assertEqual(self.snes.getType(), self.SNES_TYPE)
        self.snes.setType(self.SNES_TYPE)
        self.assertEqual(self.snes.getType(), self.SNES_TYPE)

    def testTols(self):
        """Round-trip tolerances and check divergence-tolerance sentinels."""
        tols = self.snes.getTolerances()
        self.snes.setTolerances(*tols)
        tnames = ('rtol', 'atol', 'stol', 'max_it')
        tolvals = [getattr(self.snes, t) for t in tnames]
        self.assertEqual(tuple(tols), tuple(tolvals))
        dtol = self.snes.getDivergenceTolerance()
        self.assertGreater(dtol, 0)
        self.snes.setDivergenceTolerance(PETSc.UNLIMITED)
        dtol = self.snes.getDivergenceTolerance()
        self.assertEqual(dtol, PETSc.UNLIMITED)
        # Setting CURRENT is expected to keep the previous value. Re-read it
        # so the assertion checks the SNES state, not the stale local.
        self.snes.setDivergenceTolerance(PETSc.CURRENT)
        dtol = self.snes.getDivergenceTolerance()
        self.assertEqual(dtol, PETSc.UNLIMITED)

    def testProperties(self):
        """Exercise the Python-level SNES properties and their setters."""
        snes = self.snes
        #
        snes.appctx = (1, 2, 3)
        self.assertEqual(snes.appctx, (1, 2, 3))
        snes.appctx = None
        self.assertEqual(snes.appctx, None)
        #
        snes.its = 1
        self.assertEqual(snes.its, 1)
        snes.its = 0
        self.assertEqual(snes.its, 0)
        #
        snes.norm = 1
        self.assertEqual(snes.norm, 1)
        snes.norm = 0
        self.assertEqual(snes.norm, 0)
        #
        rh, ih = snes.history
        self.assertEqual(len(rh), 0)
        self.assertEqual(len(ih), 0)
        #
        reason = PETSc.SNES.ConvergedReason.CONVERGED_ITS
        snes.reason = reason
        self.assertEqual(snes.reason, reason)
        self.assertTrue(snes.is_converged)
        self.assertFalse(snes.is_diverged)
        self.assertFalse(snes.is_iterating)
        reason = PETSc.SNES.ConvergedReason.DIVERGED_MAX_IT
        snes.reason = reason
        self.assertEqual(snes.reason, reason)
        self.assertFalse(snes.is_converged)
        self.assertTrue(snes.is_diverged)
        self.assertFalse(snes.is_iterating)
        reason = PETSc.SNES.ConvergedReason.CONVERGED_ITERATING
        snes.reason = reason
        self.assertEqual(snes.reason, reason)
        self.assertFalse(snes.is_converged)
        self.assertFalse(snes.is_diverged)
        self.assertTrue(snes.is_iterating)
        #
        self.assertFalse(snes.use_ew)
        self.assertFalse(snes.use_mf)
        self.assertFalse(snes.use_fd)
        ouse = snes.use_ksp
        self.assertEqual(ouse, snes.getUseKSP())
        snes.use_ksp = not ouse
        self.assertEqual(not ouse, snes.getUseKSP())
        snes.setUseKSP(ouse)
        self.assertEqual(ouse, snes.use_ksp)

    def testGetSetFunc(self):
        """Setting the residual callback stores exactly one reference to it."""
        r, func = self.snes.getFunction()
        self.assertFalse(r)
        self.assertIsNone(func)
        r = PETSc.Vec().createSeq(2)
        func = Function()
        refcnt = getrefcount(func)
        # Setting twice must not accumulate references.
        self.snes.setFunction(func, r)
        self.snes.setFunction(func, r)
        self.assertEqual(getrefcount(func), refcnt + 1)
        r2, func2 = self.snes.getFunction()
        self.assertEqual(r, r2)
        self.assertEqual(func, func2[0])
        self.assertEqual(getrefcount(func), refcnt + 1)
        r3, func3 = self.snes.getFunction()
        self.assertEqual(r, r3)
        self.assertEqual(func, func3[0])
        self.assertEqual(getrefcount(func), refcnt + 1)

    def testCompFunc(self):
        """``computeFunction`` at the root (1, 2) yields a zero residual."""
        r = PETSc.Vec().createSeq(2)
        func = Function()
        self.snes.setFunction(func, r)
        x, y = r.duplicate(), r.duplicate()
        x[0], x[1] = [1, 2]
        self.snes.computeFunction(x, y)
        self.assertAlmostEqual(abs(y[0]), 0.0, places=5)
        self.assertAlmostEqual(abs(y[1]), 0.0, places=5)

    def testGetSetJac(self):
        """Setting the Jacobian callback stores exactly one reference to it."""
        A, P, jac = self.snes.getJacobian()
        self.assertFalse(A)
        self.assertFalse(P)
        self.assertIsNone(jac)
        J = PETSc.Mat().create(PETSc.COMM_SELF)
        J.setSizes([2, 2])
        J.setType(PETSc.Mat.Type.SEQAIJ)
        J.setUp()
        jac = Jacobian()
        refcnt = getrefcount(jac)
        # Setting twice must not accumulate references.
        self.snes.setJacobian(jac, J)
        self.snes.setJacobian(jac, J)
        self.assertEqual(getrefcount(jac), refcnt + 1)
        J2, P2, jac2 = self.snes.getJacobian()
        self.assertEqual(J, J2)
        self.assertEqual(J2, P2)
        self.assertEqual(jac, jac2[0])
        self.assertEqual(getrefcount(jac), refcnt + 1)
        J3, P3, jac3 = self.snes.getJacobian()
        self.assertEqual(J, J3)
        self.assertEqual(J3, P3)
        self.assertEqual(jac, jac3[0])
        self.assertEqual(getrefcount(jac), refcnt + 1)

    def testCompJac(self):
        """``computeJacobian`` runs with a user Jacobian callback."""
        J = PETSc.Mat().create(PETSc.COMM_SELF)
        J.setSizes([2, 2])
        J.setType(PETSc.Mat.Type.SEQAIJ)
        J.setUp()
        jac = Jacobian()
        self.snes.setJacobian(jac, J)
        x = PETSc.Vec().createSeq(2)
        x[0], x[1] = [1, 2]
        # NOTE(review): result unused; presumably forces KSP/PC creation
        # before computeJacobian -- confirm.
        self.snes.getKSP().getPC()
        self.snes.computeJacobian(x, J)

    def testGetSetUpd(self):
        """Update callbacks are referenced once and released on reset."""
        self.assertIsNone(self.snes.getUpdate())
        upd = lambda snes, it: None
        refcnt = getrefcount(upd)
        self.snes.setUpdate(upd)
        self.assertEqual(getrefcount(upd), refcnt + 1)
        self.snes.setUpdate(upd)
        self.assertEqual(getrefcount(upd), refcnt + 1)
        self.snes.setUpdate(None)
        self.assertIsNone(self.snes.getUpdate())
        self.assertEqual(getrefcount(upd), refcnt)
        self.snes.setUpdate(upd)
        self.assertEqual(getrefcount(upd), refcnt + 1)
        upd2 = lambda snes, it: None
        refcnt2 = getrefcount(upd2)
        # Replacing the callback releases the previous one.
        self.snes.setUpdate(upd2)
        self.assertEqual(getrefcount(upd), refcnt)
        self.assertEqual(getrefcount(upd2), refcnt2 + 1)
        tmp = self.snes.getUpdate()[0]
        self.assertIs(tmp, upd2)
        self.assertEqual(getrefcount(upd2), refcnt2 + 2)
        del tmp
        self.snes.setUpdate(None)
        self.assertIsNone(self.snes.getUpdate())
        self.assertEqual(getrefcount(upd2), refcnt2)

    def testGetKSP(self):
        """The KSP returned by ``getKSP`` reports a PETSc refcount of 2."""
        ksp = self.snes.getKSP()
        self.assertEqual(ksp.getRefCount(), 2)

    def testSolve(self):
        """Solve the 2x2 system via both ``solve(b, x)`` and ``solve()``.

        For NEWTONAL the load-scaled residual and load vector are used, and
        the solution values are not checked.
        """
        J = PETSc.Mat().create(PETSc.COMM_SELF)
        J.setSizes([2, 2])
        J.setType(PETSc.Mat.Type.SEQAIJ)
        J.setUp()
        r = PETSc.Vec().createSeq(2)
        x = PETSc.Vec().createSeq(2)
        b = PETSc.Vec().createSeq(2)
        if self.snes.getType() == PETSc.SNES.Type.NEWTONAL:
            self.snes.setFunction(FunctionAL(), r)
            self.snes.setNewtonALCorrectionType(PETSc.SNES.NewtonALCorrectionType.EXACT)
            self.snes.setNewtonALFunction(FunctionALLoad())
        else:
            self.snes.setFunction(Function(), r)
        self.snes.setJacobian(Jacobian(), J)

        def _update(snes, it, cnt):
            # cnt is a 0-d numpy array; ``+=`` updates it in place, so the
            # caller's cnt_up observes every call.
            cnt += 1

        cnt_up = np.array(0)
        self.snes.setUpdate(_update, (cnt_up,))

        x.setArray([2, 3])
        b.set(0)
        self.snes.setConvergenceHistory()
        self.snes.setFromOptions()
        self.snes.solve(b, x)
        self.snes.setUpdate(None)
        rh, ih = self.snes.getConvergenceHistory()
        self.assertEqual(len(rh), len(ih))
        self.snes.setConvergenceHistory(0, reset=True)
        rh, ih = self.snes.getConvergenceHistory()
        self.assertEqual(len(rh), 0)
        self.assertEqual(len(ih), 0)
        if self.snes.getType() != PETSc.SNES.Type.NEWTONAL:
            self.assertAlmostEqual(abs(x[0]), 1.0, places=5)
            self.assertAlmostEqual(abs(x[1]), 2.0, places=5)
        self.assertEqual(self.snes.getIterationNumber(), int(cnt_up))
        # XXX this test should not be here !
        reason = self.snes.callConvergenceTest(1, 0, 0, 0)
        self.assertGreater(reason, 0)

        # test interface
        x = self.snes.getSolution()
        x.setArray([2, 3])
        self.snes.solve()
        if self.snes.getType() != PETSc.SNES.Type.NEWTONAL:
            self.assertAlmostEqual(abs(x[0]), 1.0, places=5)
            self.assertAlmostEqual(abs(x[1]), 2.0, places=5)

    def testResetAndSolve(self):
        """``testSolve`` still passes after ``reset`` calls."""
        self.snes.reset()
        self.testSolve()
        self.snes.reset()
        self.testSolve()
        self.snes.reset()

    def testSetMonitor(self):
        """Python monitors are invoked during solves and released on cancel."""
        reshist = {}

        def monitor(snes, its, fgnorm):
            reshist[its] = fgnorm

        refcnt = getrefcount(monitor)
        self.snes.setMonitor(monitor)
        self.assertEqual(getrefcount(monitor), refcnt + 1)
        self.testSolve()
        self.assertGreater(len(reshist), 0)
        # Rebinding the closed-over name makes monitor write to a new dict.
        reshist = {}
        self.snes.monitorCancel()
        self.assertEqual(getrefcount(monitor), refcnt)
        self.testSolve()
        self.assertEqual(len(reshist), 0)
        self.snes.setMonitor(monitor)
        self.snes.monitor(1, 7)
        self.assertEqual(reshist[1], 7)
        ## Monitor = PETSc.SNES.Monitor
        ## self.snes.setMonitor(Monitor())
        ## self.snes.setMonitor(Monitor.DEFAULT)
        ## self.snes.setMonitor(Monitor.SOLUTION)
        ## self.snes.setMonitor(Monitor.RESIDUAL)
        ## self.snes.setMonitor(Monitor.SOLUTION_UPDATE)

    def testSetGetStepFails(self):
        """Nonlinear step-failure counters and limits round-trip."""
        its = self.snes.getIterationNumber()
        self.assertEqual(its, 0)
        fails = self.snes.getNonlinearStepFailures()
        self.assertEqual(fails, 0)
        fails = self.snes.getMaxNonlinearStepFailures()
        self.assertEqual(fails, 1)
        self.snes.setMaxNonlinearStepFailures(5)
        fails = self.snes.getMaxNonlinearStepFailures()
        self.assertEqual(fails, 5)
        self.snes.setMaxNonlinearStepFailures(1)
        fails = self.snes.getMaxNonlinearStepFailures()
        self.assertEqual(fails, 1)

    def testSetGetLinFails(self):
        """Linear-solve failure counters and limits round-trip."""
        its = self.snes.getLinearSolveIterations()
        self.assertEqual(its, 0)
        fails = self.snes.getLinearSolveFailures()
        self.assertEqual(fails, 0)
        fails = self.snes.getMaxLinearSolveFailures()
        self.assertEqual(fails, 1)
        self.snes.setMaxLinearSolveFailures(5)
        fails = self.snes.getMaxLinearSolveFailures()
        self.assertEqual(fails, 5)
        self.snes.setMaxLinearSolveFailures(1)
        fails = self.snes.getMaxLinearSolveFailures()
        self.assertEqual(fails, 1)

    def testEW(self):
        """Eisenstat-Walker toggling and parameter round-trip."""
        self.snes.setUseEW(False)
        self.assertFalse(self.snes.getUseEW())
        self.snes.setUseEW(True)
        self.assertTrue(self.snes.getUseEW())
        params = self.snes.getParamsEW()
        params['version'] = 1
        self.snes.setParamsEW(**params)
        params = self.snes.getParamsEW()
        self.assertEqual(params['version'], 1)
        # CURRENT is expected to leave the stored version unchanged.
        params['version'] = PETSc.CURRENT
        self.snes.setParamsEW(**params)
        params = self.snes.getParamsEW()
        self.assertEqual(params['version'], 1)

    def testMF(self):
        """Matrix-free toggling; solve is checked only for NEWTONLS."""
        # self.snes.setOptionsPrefix('MF-')
        # opts = PETSc.Options(self.snes)
        # opts['mat_mffd_type'] = 'ds'
        # opts['snes_monitor']  = 'stdout'
        # opts['ksp_monitor']   = 'stdout'
        # opts['snes_view']     = 'stdout'
        J = PETSc.Mat().create(PETSc.COMM_SELF)
        J.setSizes([2, 2])
        J.setType(PETSc.Mat.Type.SEQAIJ)
        J.setUp()
        r = PETSc.Vec().createSeq(2)
        x = PETSc.Vec().createSeq(2)
        b = PETSc.Vec().createSeq(2)
        fun = Function()
        jac = Jacobian()
        self.snes.setFunction(fun, r)
        self.snes.setJacobian(jac, J)
        self.assertFalse(self.snes.getUseMF())
        self.snes.setUseMF(False)
        self.assertFalse(self.snes.getUseMF())
        self.snes.setUseMF(True)
        self.assertTrue(self.snes.getUseMF())
        self.snes.setFromOptions()
        if self.snes.getType() == PETSc.SNES.Type.NEWTONLS:
            x.setArray([2, 3])
            b.set(0)
            self.snes.solve(b, x)
            self.assertAlmostEqual(abs(x[0]), 1.0, places=5)
            self.assertAlmostEqual(abs(x[1]), 2.0, places=5)

    def testFDColor(self):
        """Finite-difference (colored) Jacobian toggling and solve."""
        J = PETSc.Mat().create(PETSc.COMM_SELF)
        J.setSizes([2, 2])
        J.setType(PETSc.Mat.Type.SEQAIJ)
        J.setUp()
        r = PETSc.Vec().createSeq(2)
        x = PETSc.Vec().createSeq(2)
        b = PETSc.Vec().createSeq(2)
        fun = Function()
        jac = Jacobian()
        self.snes.setFunction(fun, r)
        self.snes.setJacobian(jac, J)
        self.assertFalse(self.snes.getUseFD())
        # NOTE(review): fills and assembles J once, presumably so coloring
        # has a nonzero pattern to work from -- confirm.
        jac(self.snes, x, J, J)
        self.snes.setUseFD(False)
        self.assertFalse(self.snes.getUseFD())
        self.snes.setUseFD(True)
        self.assertTrue(self.snes.getUseFD())
        self.snes.setFromOptions()
        x.setArray([2, 3])
        b.set(0)
        self.snes.solve(b, x)
        if self.snes.getType() != PETSc.SNES.Type.NEWTONAL:
            self.assertAlmostEqual(abs(x[0]), 1.0, places=4)
            self.assertAlmostEqual(abs(x[1]), 2.0, places=4)

    def testNPC(self):
        """The nonlinear preconditioner sees the outer SNES appctx."""
        self.snes.appctx = (1, 2, 3)
        npc = self.snes.getNPC()
        self.assertEqual(npc.appctx, (1, 2, 3))

    def testTRAPI(self):
        """Exercise the trust-region tolerance/update-parameter setters.

        The setters are called for every SNES type; values are only read
        back and compared for NEWTONTR. The test expects argument-less calls
        to keep the stored values and DETERMINE to restore the defaults.
        """
        is_tr = self.snes.getType() == PETSc.SNES.Type.NEWTONTR
        newreg = (1, 2, 3)
        newup = (1, 2, 3, 4, 5)
        defreg = defup = None
        if is_tr:
            defreg = self.snes.getTRTolerances()
            defup = self.snes.getTRUpdateParameters()
        self.snes.setTRTolerances(*newreg)
        self.snes.setTRUpdateParameters(*newup)
        if is_tr:
            self.assertEqual(newreg, self.snes.getTRTolerances())
            self.assertEqual(newup, self.snes.getTRUpdateParameters())
        self.snes.setTRTolerances()
        self.snes.setTRUpdateParameters()
        if is_tr:
            self.assertEqual(newreg, self.snes.getTRTolerances())
            self.assertEqual(newup, self.snes.getTRUpdateParameters())
        self.snes.setTRTolerances(*(PETSc.DETERMINE,) * 3)
        self.snes.setTRUpdateParameters(*(PETSc.DETERMINE,) * 5)
        if is_tr:
            self.assertEqual(defreg, self.snes.getTRTolerances())
            self.assertEqual(defup, self.snes.getTRUpdateParameters())
434
435# --------------------------------------------------------------------
436
437
class TestSNESLS(BaseTestSNES, unittest.TestCase):
    """Run the shared SNES tests with the NEWTONLS (line-search) solver."""

    SNES_TYPE = PETSc.SNES.Type.NEWTONLS
440
441
class TestSNESTR(BaseTestSNES, unittest.TestCase):
    """Run the shared SNES tests with the NEWTONTR (trust-region) solver."""

    SNES_TYPE = PETSc.SNES.Type.NEWTONTR
444
445
class TestSNESAL(BaseTestSNES, unittest.TestCase):
    """Run the shared SNES tests with the NEWTONAL (arc-length) solver."""

    SNES_TYPE = PETSc.SNES.Type.NEWTONAL
448
449
450# --------------------------------------------------------------------
451
452
class TestSNESLineSearchAPI(unittest.TestCase):
    """Smoke tests for the ``PETSc.SNESLineSearch`` Python interface."""

    def test_create_destroy(self):
        """A line search can be created and destroyed explicitly."""
        linesearch = PETSc.SNESLineSearch()
        linesearch.create()
        linesearch.destroy()

    def test_type_set_get(self):
        """The type passed to ``setType`` is reported by ``getType``."""
        linesearch = PETSc.SNESLineSearch().create()
        linesearch.setType(PETSc.SNESLineSearch.Type.BASIC)
        self.assertEqual(linesearch.getType(), 'basic')
        linesearch.destroy()

    def test_tolerances_set_get(self):
        """Tolerances set by keyword come back from ``getTolerances``."""
        linesearch = PETSc.SNESLineSearch().create()
        linesearch.setTolerances(
            rtol=0.125, atol=3, minstep=4, ltol=5, maxstep=6, max_its=7,
        )
        # getTolerances returns (minstep, maxstep, rtol, atol, ltol, max_its).
        minstep, maxstep, rtol, atol, ltol, max_its = linesearch.getTolerances()
        self.assertEqual(
            (rtol, atol, minstep, ltol, maxstep, max_its),
            (0.125, 3, 4, 5, 6, 7),
        )
        linesearch.destroy()

    def test_order_set_get(self):
        """The interpolation order round-trips through set/get."""
        linesearch = PETSc.SNESLineSearch().create()
        linesearch.setOrder(2)
        self.assertEqual(linesearch.getOrder(), 2)
        linesearch.destroy()

    def test_set_from_options(self):
        """``setFromOptions`` runs on a freshly created line search."""
        linesearch = PETSc.SNESLineSearch().create()
        linesearch.setFromOptions()
        # linesearch.view()
        linesearch.destroy()

    def test_snes_linesearch_property(self):
        """``SNES.linesearch`` mirrors ``getLineSearch`` and accepts writes."""
        solver = PETSc.SNES().create()
        linesearch = solver.getLineSearch()
        self.assertIsInstance(linesearch, PETSc.SNESLineSearch)
        # Set/get via property
        self.assertEqual(solver.linesearch, linesearch)
        solver.linesearch = linesearch
        self.assertEqual(solver.linesearch, linesearch)
        solver.destroy()
500
501
502# --------------------------------------------------------------------
503
if __name__ == '__main__':
    # Discover and run every TestCase in this module when run as a script.
    unittest.main()
506